Lightning Flash




Quick Start

Flash is a high-level deep learning framework for fast prototyping, baselining, finetuning, and solving deep learning problems. It features a set of tasks that you can use for inference and finetuning out of the box, and an easy-to-implement API to customize every step of the process for full flexibility.

Flash is built for beginners with a simple API that requires very little deep learning background, and for data scientists, Kagglers, applied ML practitioners and deep learning researchers who want a quick way to get a deep learning baseline with the advanced features PyTorch Lightning offers.

Why Flash?

For getting started with Deep Learning

Easy to learn

If you are just getting started with deep learning, Flash offers common deep learning tasks you can use out-of-the-box in a few lines of code, no math, fancy nn.Modules or research experience required!

Easy to scale

Flash is built on top of PyTorch Lightning, a powerful deep learning research framework for training models at scale. With the power of Lightning, you can train your Flash tasks on any hardware: CPUs, GPUs or TPUs without any code changes.
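
For example, scaling to different hardware is just a Trainer flag. A minimal sketch (the task, epoch count and device counts are illustrative):

import flash
from flash.image import ImageClassifier

model = ImageClassifier(backbone="resnet18", num_classes=2)

# the same code runs on CPU, GPU, or TPU; only the Trainer flags change
trainer = flash.Trainer(max_epochs=1)                  # CPU
# trainer = flash.Trainer(max_epochs=1, gpus=2)        # 2 GPUs
# trainer = flash.Trainer(max_epochs=1, tpu_cores=8)   # TPU cores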

Easy to upskill

If you want to create more complex and customized models, you can refactor any part of flash with PyTorch or PyTorch Lightning components to get all the flexibility you need. Lightning is just organized PyTorch with the unnecessary engineering details abstracted away.

  • Flash (high-level)

  • Lightning (mid-level)

  • PyTorch (low-level)

When you need more flexibility you can build your own tasks or simply use Lightning directly.

For Deep learning research

Quickest way to a baseline

PyTorch Lightning is designed to abstract away unnecessary boilerplate while enabling maximal flexibility. Because it provides full flexibility, solving very common deep learning problems, such as classification, still requires some boilerplate in Lightning, and it can take quite some time to get a baseline model running on a new dataset or an out-of-domain task. We created Flash to answer our users' need for a super quick way to baseline in Lightning, using proven backbones for common data patterns. Flash aims to be the easiest starting point for your research: start with a Flash Task to benchmark against, and override any part of Flash with Lightning or PyTorch components on your way to SOTA research.

Flexibility where you want it

Flash tasks are essentially LightningModules, and the Flash Trainer is a thin wrapper around the Lightning Trainer. You can use your own LightningModule instead of the Flash task, the Lightning Trainer instead of the Flash Trainer, and so on. Flash helps you focus more on your research and less on everything else.
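
Because a Flash task is a LightningModule, the plain Lightning Trainer can train it directly. A minimal sketch (the dataset path and trainer arguments are illustrative):

import pytorch_lightning as pl
from flash.image import ImageClassificationData, ImageClassifier

datamodule = ImageClassificationData.from_folders(train_folder="data/hymenoptera_data/train/")
model = ImageClassifier(backbone="resnet18", num_classes=datamodule.num_classes)

# use the Lightning Trainer instead of flash.Trainer
trainer = pl.Trainer(max_epochs=1)
trainer.fit(model, datamodule=datamodule)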

Standard best practices

Flash tasks implement the standard best practices for a variety of different models and domains, to save you time digging through different implementations. Flash abstracts even more details than Lightning, allowing deep learning experts to share their tips and tricks for solving scoped deep learning problems.


Tasks

Flash is composed of a collection of Tasks. The Flash tasks are laser-focused objects designed to solve a well-defined type of problem, using state-of-the-art methods.

The Flash tasks contain all the relevant information to solve the task at hand: the number of class labels you want to predict, the number of columns in your dataset, as well as details of the model architecture used, such as the loss function, optimizers, etc.

Here are examples of tasks:

from flash.text import TextClassifier
from flash.image import ImageClassifier
from flash.tabular import TabularClassifier
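
Each task is configured with the details of your problem when you instantiate it. A small sketch (the backbone, class count, and learning rate are illustrative):

from flash.image import ImageClassifier

model = ImageClassifier(
    backbone="resnet18",   # model architecture to use
    num_classes=10,        # number of class labels to predict
    learning_rate=1e-3,
)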

Note

Tasks are inflexible by definition! To get more flexibility, you can simply use LightningModule directly or modify an existing task in just a few lines.
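
For example, since every task is a LightningModule, you can subclass one and override any Lightning hook. A hedged sketch (the subclass and the override are purely illustrative):

from flash.image import ImageClassifier


class MyImageClassifier(ImageClassifier):
    def training_step(self, batch, batch_idx):
        # add your own logic here, or fall back to the task's default behaviour
        return super().training_step(batch, batch_idx)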


Inference

Inference is the process of generating predictions from trained models. To use a task for inference:

  1. Init your task with pretrained weights using a checkpoint (a checkpoint is simply a file that captures the exact values of all parameters used by a model). A local file or a URL works.

  2. Pass in the data to flash.core.model.Task.predict().


Here’s an example of inference:

# import our libraries
from flash.text import TextClassifier

# 1. Init the finetuned task from URL
model = TextClassifier.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/text_classification_model.pt")

# 2. Perform inference from list of sequences
predictions = model.predict(
    [
        "Turgid dialogue, feeble characterization - Harvey Keitel a judge?.",
        "The worst movie in the history of cinema.",
        "This guy has done a great job with this movie!",
    ]
)
print(predictions)

We get the following output:

["negative", "negative", "positive"]

Finetuning

Finetuning (or transfer-learning) is the process of tweaking a model trained on a large dataset, to your particular (likely much smaller) dataset. All Flash tasks have pre-trained backbones that are already trained on large datasets such as ImageNet. Finetuning on pretrained models decreases training time significantly.

To use a Task for finetuning:

  1. Load your data and organize it using a DataModule customized for the task (example: ImageClassificationData).

  2. Choose and initialize your Task which has state-of-the-art backbones built in (example: ImageClassifier).

  3. Init a flash.core.trainer.Trainer.

  4. Choose a finetune strategy (example: “freeze”) and call flash.core.trainer.Trainer.finetune() with your data.

  5. Save your finetuned model.


Here’s an example of finetuning.

import torch
from pytorch_lightning import seed_everything

import flash
from flash.core.classification import Labels
from flash.core.data.utils import download_data
from flash.image import ImageClassificationData, ImageClassifier

# set the random seeds.
seed_everything(42)

# 1. Download and organize the data
download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip", "data/")

datamodule = ImageClassificationData.from_folders(
    train_folder="data/hymenoptera_data/train/",
    val_folder="data/hymenoptera_data/val/",
    test_folder="data/hymenoptera_data/test/",
)

# 2. Build the model using desired Task
model = ImageClassifier(backbone="resnet18", num_classes=datamodule.num_classes)

# 3. Create the trainer (run one epoch for demo)
trainer = flash.Trainer(max_epochs=1, gpus=torch.cuda.device_count())

# 4. Finetune the model
trainer.finetune(model, datamodule=datamodule, strategy="freeze")

# 5. Save the model!
trainer.save_checkpoint("image_classification_model.pt")

Using a finetuned model

Once you’ve finetuned, use the model to predict:

# Serialize predictions as labels, automatically inferred from the training data in part 2.
model.serializer = Labels()

predictions = model.predict(
    [
        "data/hymenoptera_data/val/bees/65038344_52a45d090d.jpg",
        "data/hymenoptera_data/val/ants/2255445811_dabcdf7258.jpg",
    ]
)
print(predictions)

We get the following output:

['bees', 'ants']

Or you can use the saved model for prediction anywhere you want!

from flash.image import ImageClassifier

# load finetuned checkpoint
model = ImageClassifier.load_from_checkpoint("image_classification_model.pt")

predictions = model.predict("path/to/your/own/image.png")

Training

When you have enough data, you’re likely better off training from scratch instead of finetuning.

To train a task from scratch:

  1. Load your data and organize it using a DataModule customized for the task (example: ImageClassificationData).

  2. Choose and initialize your Task (setting pretrained=False) which has state-of-the-art backbones built in (example: ImageClassifier).

  3. Init a flash.core.trainer.Trainer or a pytorch_lightning.trainer.Trainer.

  4. Call flash.core.trainer.Trainer.fit() with your data set.

  5. Save your trained model.


Here’s an example:

import torch
from pytorch_lightning import seed_everything

import flash
from flash.core.classification import Labels
from flash.core.data.utils import download_data
from flash.image import ImageClassificationData, ImageClassifier

# set the random seeds.
seed_everything(42)

# 1. Download and organize the data
download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip", "data/")

datamodule = ImageClassificationData.from_folders(
    train_folder="data/hymenoptera_data/train/",
    val_folder="data/hymenoptera_data/val/",
    test_folder="data/hymenoptera_data/test/",
)

# 2. Build the model using desired Task
model = ImageClassifier(backbone="resnet18", num_classes=datamodule.num_classes, pretrained=False)

# 3. Create the trainer (run one epoch for demo)
trainer = flash.Trainer(max_epochs=1, gpus=torch.cuda.device_count())

# 4. Train the model
trainer.fit(model, datamodule=datamodule)

# 5. Save the model!
trainer.save_checkpoint("image_classification_model.pt")

A few Built-in Tasks

More tasks coming soon!

Contribute a task

The Lightning + Flash team is hard at work building more tasks for common deep-learning use cases. But we’re looking for incredible contributors like you to submit new tasks!

Join our Slack to get help becoming a contributor!

Installation

Install with pip

pip install lightning-flash

Optionally, you can install Flash with extra packages for each domain.

For a single domain, use: pip install 'lightning-flash[{DOMAIN}]'.

pip install 'lightning-flash[image]'
pip install 'lightning-flash[tabular]'
pip install 'lightning-flash[text]'
...

For multiple domains, use: pip install 'lightning-flash[{DOMAIN_1,DOMAIN_2,...}]'.

pip install 'lightning-flash[audio,image]'
...

For contributors, please install Flash with packages for testing Flash and building docs.

# Clone Flash repository locally
git clone https://github.com/[your username]/lightning-flash.git
cd lightning-flash

# Install Flash in editable mode with extra packages for development
pip install -e '.[dev]'

Install with conda

Flash is available via conda forge. Install it with:

conda install -c conda-forge lightning-flash

Install from source

You can install Flash from source without any domain specific dependencies with:

pip install 'git+https://github.com/PyTorchLightning/lightning-flash.git'

To install Flash with domain dependencies, use:

pip install 'git+https://github.com/PyTorchLightning/lightning-flash.git#egg=lightning-flash[image]'

You can again install dependencies for multiple domains by separating them with commas as above.

Flash Zero

Flash Zero is a zero-code machine learning platform. Here’s an image classification example to illustrate with one of the dozens of tasks available.

Flash Zero in 3 steps

1. Select your task

flash {TASK_NAME}

Here is the list of currently supported tasks.

audio_classification     Classify audio spectrograms.
graph_classification     Classify graphs.
image_classification     Classify images.
instance_segmentation    Segment object instances in images.
keypoint_detection       Detect keypoints in images.
object_detection         Detect objects in images.
pointcloud_detection     Detect objects in point clouds.
pointcloud_segmentation  Segment objects in point clouds.
question_answering       Extractive Question Answering.
semantic_segmentation    Segment objects in images.
speech_recognition       Speech recognition.
style_transfer           Image style transfer.
summarization            Summarize text.
tabular_classification   Classify tabular data.
text_classification      Classify text.
translation              Translate text.
video_classification     Classify videos.

2. Pass in your own data

flash image_classification from_folders --train_folder data/hymenoptera_data/train

3. Modify the model and training parameters

flash image_classification --trainer.max_epochs 10 --model.backbone resnet50 from_folders --train_folder data/hymenoptera_data/train

Note

The trainer and model arguments should be placed before the source subcommand, which here is from_folders.

Other Examples

Image Object Detection

To train an Object Detector on the COCO 2017 dataset, you could use the following command:

flash object_detection from_coco --train_folder data/coco128/images/train2017/ --train_ann_file data/coco128/annotations/instances_train2017.json --val_split .3 --batch_size 8 --num_workers 4

Image Object Segmentation

To train an Image Segmenter on the CARLA driving simulator dataset:

flash semantic_segmentation from_folders --train_folder data/CameraRGB --train_target_folder data/CameraSeg --num_classes 21

Below is an example where the head, the backbone and its pretrained weights are customized.

flash semantic_segmentation --model.head fpn --model.backbone efficientnet-b0 --model.pretrained advprop from_folders --train_folder data/CameraRGB --train_target_folder data/CameraSeg --num_classes 21

Video Classification

To train a Video Classifier on the Kinetics dataset, you could use the following command:

flash video_classification from_folders --train_folder data/kinetics/train/ --clip_duration 1 --num_workers 0

CLI options

Flash Zero is built on top of the lightning CLI, so the trainer and model arguments can be configured either from the command line or from a config file. For example, to run the image classifier for 10 epochs with a resnet50 backbone you can use:

flash image_classification --trainer.max_epochs 10 --model.backbone resnet50

To view all of the available options for a task, run:

flash image_classification --help

Using Your Own Data

Flash Zero works with your own data through subcommands. The available subcommands for each task are given at the bottom of their help pages (e.g. when running flash image_classification --help). You can then use the required subcommand to train on your own data. Let’s look at an example using the Hymenoptera data from the Image Classification guide. First, download and unzip your data:

curl https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip -o hymenoptera_data.zip
unzip hymenoptera_data.zip

Now train with Flash Zero:

flash image_classification from_folders --train_folder ./hymenoptera_data/train

Getting Help

To find all available tasks, you can run:

flash --help

This will output the following:

Commands:
audio_classification     Classify audio spectrograms.
graph_classification     Classify graphs.
image_classification     Classify images.
instance_segmentation    Segment object instances in images.
keypoint_detection       Detect keypoints in images.
object_detection         Detect objects in images.
pointcloud_detection     Detect objects in point clouds.
pointcloud_segmentation  Segment objects in point clouds.
question_answering       Extractive Question Answering.
semantic_segmentation    Segment objects in images.
speech_recognition       Speech recognition.
style_transfer           Image style transfer.
summarization            Summarize text.
tabular_classification   Classify tabular data.
text_classification      Classify text.
translation              Translate text.
video_classification     Classify videos.

To get more information about a specific task, you can do the following:

flash image_classification --help

You can view the help page for each subcommand. For example, to view the options for training an image classifier from folders, you can run:

flash image_classification from_folders --help

Finally, you can generate a config.yaml file from the CLI to ease parameter modification by running:

flash image_classification --print_config > config.yaml
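
After editing the generated file, it can be passed back in, assuming the standard Lightning CLI --config flag (a hedged sketch; the config path and flag placement mirror the argument ordering shown above):

flash image_classification --config config.yaml from_folders --train_folder ./hymenoptera_data/train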

Flash in Production

Flash Serve

Flash Serve makes model deployment simple.

Server Side

from flash.image import SemanticSegmentation
from flash.image.segmentation.serialization import SegmentationLabels

model = SemanticSegmentation.load_from_checkpoint(
    "https://flash-weights.s3.amazonaws.com/semantic_segmentation_model.pt"
)
model.serializer = SegmentationLabels(visualize=False)
model.serve()

Client Side

import base64
from pathlib import Path

import requests

import flash

with (Path(flash.ASSETS_ROOT) / "road.png").open("rb") as f:
    imgstr = base64.b64encode(f.read()).decode("UTF-8")

body = {"session": "UUID", "payload": {"inputs": {"data": imgstr}}}
resp = requests.post("http://127.0.0.1:8000/predict", json=body)
print(resp.json())

Credits to @rlizzo, @hhsecond, @lantiga, @luiscape for building the Flash Serve Engine. Read all about it here.

Training from scratch

Some Flash tasks have been pretrained on large data sets. To accelerate your training, calling the finetune() method using a pretrained backbone will fine-tune the backbone to generate a model customized to your data set and desired task.

From the Quick Start guide.

To train a task from scratch:

  1. Load your data and organize it using a DataModule customized for the task (example: ImageClassificationData).

  2. Choose and initialize your Task (setting pretrained=False) which has state-of-the-art backbones built in (example: ImageClassifier).

  3. Init a flash.core.trainer.Trainer or a pytorch_lightning.trainer.Trainer.

  4. Call flash.core.trainer.Trainer.fit() with your data set.

  5. Save your trained model.


Here’s an example:

import torch
from pytorch_lightning import seed_everything

import flash
from flash.core.classification import Labels
from flash.core.data.utils import download_data
from flash.image import ImageClassificationData, ImageClassifier

# set the random seeds.
seed_everything(42)

# 1. Download and organize the data
download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip", "data/")

datamodule = ImageClassificationData.from_folders(
    train_folder="data/hymenoptera_data/train/",
    val_folder="data/hymenoptera_data/val/",
    test_folder="data/hymenoptera_data/test/",
)

# 2. Build the model using desired Task
model = ImageClassifier(backbone="resnet18", num_classes=datamodule.num_classes, pretrained=False)

# 3. Create the trainer (run one epoch for demo)
trainer = flash.Trainer(max_epochs=1, gpus=torch.cuda.device_count())

# 4. Train the model
trainer.fit(model, datamodule=datamodule)

# 5. Save the model!
trainer.save_checkpoint("image_classification_model.pt")

Training options

Flash tasks support many advanced training functionalities out-of-the-box, such as:

  • Limiting the number of epochs

# train for 10 epochs
flash.Trainer(max_epochs=10)
  • Training on GPUs

# train on 1 GPU
flash.Trainer(gpus=1)
  • Training on multiple GPUs

# train on multiple GPUs
flash.Trainer(gpus=4)
# train on gpu 1, 3, 5 (3 gpus total)
flash.Trainer(gpus=[1, 3, 5])
  • Using mixed precision training

# Multi GPU with mixed precision
flash.Trainer(gpus=2, precision=16)
  • Training on TPUs

# Train on TPUs
flash.Trainer(tpu_cores=8)

You can pass any argument from the Lightning Trainer to the Flash Trainer! Learn more about the Lightning Trainer here.
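
For example, Lightning Trainer flags can be passed straight through. A small sketch (the specific flags and values are illustrative):

import flash

trainer = flash.Trainer(
    max_epochs=10,
    gpus=1,
    precision=16,
    limit_train_batches=0.5,   # use only half of the training data each epoch
    val_check_interval=0.25,   # run validation 4 times per training epoch
)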

Finetuning

Finetuning (or transfer-learning) is the process of tweaking a model trained on a large dataset, to your particular (likely much smaller) dataset.


Terminology

Here are common terms you need to be familiar with:

  • Finetuning: The process of tweaking a model trained on a large dataset, to your particular (likely much smaller) dataset.

  • Transfer learning: The common name for finetuning.

  • Backbone: The neural network that was pretrained on a different dataset.

  • Head: Another neural network (usually smaller) that maps the backbone to your particular dataset.

  • Freeze: Disabling gradient updates to a model (i.e. not learning).

  • Unfreeze: Enabling gradient updates to a model.


Finetuning in Flash

From the Quick Start guide.

To use a Task for finetuning:

  1. Load your data and organize it using a DataModule customized for the task (example: ImageClassificationData).

  2. Choose and initialize your Task which has state-of-the-art backbones built in (example: ImageClassifier).

  3. Init a flash.core.trainer.Trainer.

  4. Choose a finetune strategy (example: “freeze”) and call flash.core.trainer.Trainer.finetune() with your data.

  5. Save your finetuned model.


Here’s an example of finetuning.

import torch
from pytorch_lightning import seed_everything

import flash
from flash.core.classification import Labels
from flash.core.data.utils import download_data
from flash.image import ImageClassificationData, ImageClassifier

# set the random seeds.
seed_everything(42)

# 1. Download and organize the data
download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip", "data/")

datamodule = ImageClassificationData.from_folders(
    train_folder="data/hymenoptera_data/train/",
    val_folder="data/hymenoptera_data/val/",
    test_folder="data/hymenoptera_data/test/",
)

# 2. Build the model using desired Task
model = ImageClassifier(backbone="resnet18", num_classes=datamodule.num_classes)

# 3. Create the trainer (run one epoch for demo)
trainer = flash.Trainer(max_epochs=1, gpus=torch.cuda.device_count())

# 4. Finetune the model
trainer.finetune(model, datamodule=datamodule, strategy="freeze")

# 5. Save the model!
trainer.save_checkpoint("image_classification_model.pt")

Using a finetuned model

Once you’ve finetuned, use the model to predict:

# Serialize predictions as labels, automatically inferred from the training data in part 2.
model.serializer = Labels()

predictions = model.predict(
    [
        "data/hymenoptera_data/val/bees/65038344_52a45d090d.jpg",
        "data/hymenoptera_data/val/ants/2255445811_dabcdf7258.jpg",
    ]
)
print(predictions)

We get the following output:

['bees', 'ants']

Or you can use the saved model for prediction anywhere you want!

from flash.image import ImageClassifier

# load finetuned checkpoint
model = ImageClassifier.load_from_checkpoint("image_classification_model.pt")

predictions = model.predict("path/to/your/own/image.png")

Finetune strategies

Finetuning is very task specific. Each task encodes the best finetuning practices for that task. However, Flash gives you a few default strategies for finetuning.

Finetuning operates on two things, the model backbone and the head. The backbone is the neural network that was pre-trained. The head is another neural network that bridges between the backbone and your particular dataset.

no_freeze

In this strategy, the backbone and the head are unfrozen from the beginning.

trainer.finetune(model, datamodule, strategy="no_freeze")

In pseudocode, this looks like:

backbone = Resnet50()
head = nn.Linear(...)

backbone.unfreeze()
head.unfreeze()

train(backbone, head)

freeze

The freeze strategy keeps the backbone frozen throughout.

trainer.finetune(model, datamodule, strategy="freeze")

The pseudocode looks like:

backbone = Resnet50()
head = nn.Linear(...)

# freeze backbone
backbone.freeze()
head.unfreeze()

train(backbone, head)

Advanced strategies

Every finetune strategy can also be customized.

freeze_unfreeze

By default, in this strategy the backbone is frozen for 10 epochs then unfrozen:

trainer.finetune(model, datamodule, strategy="freeze_unfreeze")

Or we can customize it to unfreeze the backbone after a different epoch. For example, to unfreeze after epoch 7:

from flash.core.finetuning import FreezeUnfreeze

trainer.finetune(model, datamodule, strategy=FreezeUnfreeze(unfreeze_epoch=7))

Under the hood, the pseudocode looks like:

backbone = Resnet50()
head = nn.Linear(...)

# freeze backbone
backbone.freeze()
head.unfreeze()

train(backbone, head, epochs=10)

# unfreeze after 10 epochs
backbone.unfreeze()

train(backbone, head)

unfreeze_milestones

This strategy allows you to unfreeze part of the backbone at predetermined intervals.

Here’s an example where:

  • the backbone starts frozen

  • at epoch 3 the last 2 layers unfreeze

  • at epoch 8 the full backbone unfreezes


from flash.core.finetuning import UnfreezeMilestones

trainer.finetune(model, datamodule, strategy=UnfreezeMilestones(unfreeze_milestones=(3, 8), num_layers=2))

Under the hood, the pseudocode looks like:

backbone = Resnet50()
head = nn.Linear(...)

# freeze backbone
backbone.freeze()
head.unfreeze()

train(backbone, head, epochs=3)

# unfreeze last 2 layers at epoch 3
backbone.unfreeze_last_layers(2)

train(backbone, head, epochs=8)

# unfreeze the full backbone
backbone.unfreeze()

Custom Strategy

For even more customization, create your own finetuning callback. Learn more about callbacks here.

from flash.core.finetuning import FlashBaseFinetuning

# Create a finetuning callback
class FeatureExtractorFreezeUnfreeze(FlashBaseFinetuning):
    def __init__(self, unfreeze_epoch: int = 5, train_bn: bool = True):
        # this will set self.attr_names as ["backbone"]
        super().__init__("backbone", train_bn)
        self._unfreeze_epoch = unfreeze_epoch

    def finetune_function(self, pl_module, current_epoch, optimizer, opt_idx):
        # unfreeze any module you want by overriding this function

        # When ``current_epoch`` is 5, backbone will start to be trained.
        if current_epoch == self._unfreeze_epoch:
            self.unfreeze_and_add_param_group(
                pl_module.backbone,
                optimizer,
            )


# Pass the callback to trainer.finetune
trainer.finetune(model, datamodule, strategy=FeatureExtractorFreezeUnfreeze(unfreeze_epoch=5))

Predictions (inference)

You can use Flash to get predictions on pretrained or finetuned models.

Predict on a single sample of data

You can pass in a sample of data (image file path, a string of text, etc) to the predict() method.

from flash.core.data.utils import download_data
from flash.image import ImageClassifier


# 1. Download the data set
download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip", "data/")

# 2. Load the model from a checkpoint
model = ImageClassifier.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/image_classification_model.pt")

# 3. Predict whether the image contains an ant or a bee
predictions = model.predict("data/hymenoptera_data/val/bees/65038344_52a45d090d.jpg")
print(predictions)

Predict on a csv file

from flash.core.data.utils import download_data
from flash.tabular import TabularClassifier

# 1. Download the data
download_data("https://pl-flash-data.s3.amazonaws.com/titanic.zip", "data/")

# 2. Load the model from a checkpoint
model = TabularClassifier.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/tabnet_classification_model.pt")

# 3. Generate predictions from a csv file! Who would survive?
predictions = model.predict("data/titanic/titanic.csv")
print(predictions)

Serializing predictions

To change how predictions are serialized you can attach a Serializer to your Task. For example, you can choose to serialize outputs as probabilities (for more options see the API reference below).

from flash.core.classification import Probabilities
from flash.core.data.utils import download_data
from flash.image import ImageClassifier


# 1. Download the data set
download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip", "data/")

# 2. Load the model from a checkpoint
model = ImageClassifier.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/image_classification_model.pt")

# 3. Attach the Serializer
model.serializer = Probabilities()

# 4. Predict whether the image contains an ant or a bee
predictions = model.predict("data/hymenoptera_data/val/bees/65038344_52a45d090d.jpg")
print(predictions)
# out: [[0.5926494598388672, 0.40735048055648804]]

TorchScript JIT Support

We test all of our tasks for compatibility with torch.jit. This table gives a breakdown of the supported features.

Task                   torch.jit.script()   torch.jit.trace()   torch.jit.save()

ImageClassifier        Yes                  Yes                 Yes
ObjectDetector         Yes                  No                  Yes
ImageEmbedder          Yes                  Yes                 Yes
SemanticSegmentation   No                   Yes                 Yes
StyleTransfer          No                   Yes                 Yes
TabularClassifier      No                   Yes                 No
TextClassifier         No                   Yes *               Yes
SummarizationTask      No                   Yes                 Yes
TranslationTask        No                   Yes                 Yes
VideoClassifier        No                   Yes                 Yes

* with strict=False
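
For example, a task that supports torch.jit.script() and torch.jit.save() (such as ImageClassifier in the table above) can be exported like this. A small sketch; the backbone, class count, and output path are illustrative:

import torch
from flash.image import ImageClassifier

model = ImageClassifier(backbone="resnet18", num_classes=2)

# script the task and save the TorchScript program to disk
scripted = torch.jit.script(model)
torch.jit.save(scripted, "image_classifier_scripted.pt")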

Data

Terminology

Here are common terms you need to be familiar with:

  • Deserializer: The Deserializer provides a single deserialize() method.

  • DataModule: The DataModule contains the datasets, transforms and dataloaders.

  • DataPipeline: The DataPipeline is Flash’s internal object for managing the Deserializer, DataSource, Preprocess, Postprocess, and Serializer objects.

  • DataSource: The DataSource provides load_data() and load_sample() hooks for creating data sets from metadata (such as folder names).

  • Preprocess: The Preprocess provides a simple hook-based API to encapsulate your pre-processing logic. These hooks (such as pre_tensor_transform()) enable transformations to be applied to your data at every point along the pipeline (including on the device). The DataPipeline contains a system to call the right hooks when needed. The Preprocess hooks can be either overridden directly or provided as a dictionary of transforms (mapping hook name to callable transform).

  • Postprocess: The Postprocess provides a simple hook-based API to encapsulate your post-processing logic. The Postprocess hooks cover everything from model outputs to prediction export.

  • Serializer: The Serializer provides a single serialize() method that is used to convert model outputs (after the Postprocess) to the desired output format during prediction.

How to use out-of-the-box Flash DataModules

Flash provides several DataModules with helper functions. Check out the Image Classification section (or the sections for any of our other tasks) to learn more.

Data Processing

Currently, it is common practice to implement a torch.utils.data.Dataset and provide it to a torch.utils.data.DataLoader. However, after training, making inference on raw data and deploying the model in a production environment requires a lot of engineering overhead: extra processing logic usually has to be added to bridge the gap between training data and raw data.

The DataSource class can be used to generate data sets from multiple sources (e.g. folders, numpy, etc.), that can then all be transformed in the same way. The Preprocess and Postprocess classes can be used to manage the preprocessing and postprocessing transforms. The Serializer class provides the logic for converting Postprocess outputs to the desired predict format (e.g. classes, labels, probabilities, etc.).

By providing a series of hooks that can be overridden with custom data processing logic (or just targeted with transforms), Flash gives the user much more granular control over their data processing flow.

Here are the primary advantages:

  • Making inference on raw data simple

  • Make the code more readable, modular and self-contained

  • Data Augmentation experimentation is simpler

To change the processing behavior only on specific stages for a given hook, you can prefix each of the Preprocess and Postprocess hooks with train, val, test or predict.

Check out Preprocess for some examples.
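
For instance, a hook prefixed with train_ only runs during the training stage. A hedged sketch (the subclass is illustrative and my_train_only_transform is a hypothetical callable):

from typing import Any, Dict

from flash.image import ImageClassificationPreprocess


class MyPreprocess(ImageClassificationPreprocess):
    def train_pre_tensor_transform(self, sample: Dict[str, Any]) -> Dict[str, Any]:
        # applied to each sample, but only during the training stage
        sample["input"] = my_train_only_transform(sample["input"])  # hypothetical transform
        return sample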

How to customize existing DataModules

Any Flash DataModule can be created directly from datasets using from_datasets() like this:

from flash import DataModule, Trainer

data_module = DataModule.from_datasets(train_dataset=MyDataset())
trainer = Trainer()
trainer.fit(model, datamodule=data_module)

The DataModule provides additional classmethod helpers (from_*) for loading data from various sources. In each from_* method, the DataModule internally retrieves the correct DataSource to use from the Preprocess. Flash AutoDataset instances are created from the DataSource for train, val, test, and predict. The DataModule populates the DataLoader for each stage with the corresponding AutoDataset.

Customize preprocessing of DataModules

The Preprocess contains the processing logic related to a given task. Each Preprocess provides some default transforms through the default_transforms() method. Users can easily override these by providing their own transforms to the DataModule. Here’s an example:

from flash.core.data.transforms import ApplyToKeys
from flash.image import ImageClassificationData, ImageClassifier

transform = {"to_tensor_transform": ApplyToKeys("input", my_to_tensor_transform)}

datamodule = ImageClassificationData.from_folders(
    train_folder="data/hymenoptera_data/train/",
    val_folder="data/hymenoptera_data/val/",
    test_folder="data/hymenoptera_data/test/",
    train_transform=transform,
    val_transform=transform,
    test_transform=transform,
)

Alternatively, the user may directly override the hooks for their needs like this:

from typing import Any, Dict
from flash.image import ImageClassificationData, ImageClassifier, ImageClassificationPreprocess


class CustomImageClassificationPreprocess(ImageClassificationPreprocess):
    def to_tensor_transform(self, sample: Dict[str, Any]) -> Dict[str, Any]:
        sample["input"] = my_to_tensor_transform(sample["input"])
        return sample


datamodule = ImageClassificationData.from_folders(
    train_folder="data/hymenoptera_data/train/",
    val_folder="data/hymenoptera_data/val/",
    test_folder="data/hymenoptera_data/test/",
    preprocess=CustomImageClassificationPreprocess(),
)

Create your own Preprocess and DataModule

The example below shows a very simple ImageClassificationPreprocess with a single ImageClassificationFoldersDataSource and an ImageClassificationDataModule.

1. User-Facing API design

Designing an easy-to-use API is key. This is the first and most important step. We want the ImageClassificationDataModule to generate a dataset from folders of images arranged in this way.

Example:

train/dog/xxx.png
train/dog/xxy.png
train/dog/xxz.png
train/cat/123.png
train/cat/nsdf3.png
train/cat/asd932.png

Example:

dm = ImageClassificationDataModule.from_folders(
    train_folder="./data/train",
    val_folder="./data/val",
    test_folder="./data/test",
    predict_folder="./data/predict",
)

model = ImageClassifier(...)
trainer = Trainer(...)

trainer.fit(model, dm)

2. The DataSource

We start by implementing the ImageClassificationFoldersDataSource. The load_data method will produce a list of files and targets from the given directory. The load_sample method will load the given file as a PIL.Image. Here’s the full ImageClassificationFoldersDataSource:

import os
from typing import Any, Dict, Iterable

import numpy as np
from PIL import Image
from torchvision.datasets.folder import make_dataset

from flash.core.data.data_source import DataSource, DefaultDataKeys


class ImageClassificationFoldersDataSource(DataSource):
    def load_data(self, folder: str, dataset: Any) -> Iterable:
        # The dataset is optional but can be useful to save some metadata.

        # `metadata` contains the image path and its corresponding label
        # with the following structure:
        # [(image_path_1, label_1), ... (image_path_n, label_n)].
        metadata = make_dataset(folder)

        # for the train `AutoDataset`, we want to store the `num_classes`.
        if self.training:
            dataset.num_classes = len(np.unique([m[1] for m in metadata]))

        return [
            {
                DefaultDataKeys.INPUT: file,
                DefaultDataKeys.TARGET: target,
            }
            for file, target in metadata
        ]

    def predict_load_data(self, predict_folder: str) -> Iterable:
        # This returns [{"input": image_path_1}, ... {"input": image_path_m}].
        return [{DefaultDataKeys.INPUT: os.path.join(predict_folder, file)} for file in os.listdir(predict_folder)]

    def load_sample(self, sample: Dict[str, Any]) -> Dict[str, Any]:
        sample[DefaultDataKeys.INPUT] = Image.open(sample[DefaultDataKeys.INPUT])
        return sample

Note

We return samples as dictionaries using the DefaultDataKeys by convention. This is the recommended (although not required) way to represent data in Flash.
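
Concretely, a single sample flowing through the pipeline might look like this (a sketch; the path and target value are illustrative):

from flash.core.data.data_source import DefaultDataKeys

sample = {
    DefaultDataKeys.INPUT: "train/cat/123.png",  # what load_sample will open
    DefaultDataKeys.TARGET: 0,                   # the class index for this sample
}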

3. The Preprocess

Next, implement your custom ImageClassificationPreprocess with some default transforms and a reference to the data source:

from typing import Any, Callable, Dict, Optional

import torchvision.transforms.functional as T

from flash.core.data.data_source import DefaultDataKeys, DefaultDataSources
from flash.core.data.process import Preprocess
from flash.core.data.transforms import ApplyToKeys

# Subclass `Preprocess`
class ImageClassificationPreprocess(Preprocess):
    def __init__(
        self,
        train_transform: Optional[Dict[str, Callable]] = None,
        val_transform: Optional[Dict[str, Callable]] = None,
        test_transform: Optional[Dict[str, Callable]] = None,
        predict_transform: Optional[Dict[str, Callable]] = None,
    ):
        super().__init__(
            train_transform=train_transform,
            val_transform=val_transform,
            test_transform=test_transform,
            predict_transform=predict_transform,
            data_sources={
                DefaultDataSources.FOLDERS: ImageClassificationFoldersDataSource(),
            },
            default_data_source=DefaultDataSources.FOLDERS,
        )

    def get_state_dict(self) -> Dict[str, Any]:
        return {**self.transforms}

    @classmethod
    def load_state_dict(cls, state_dict: Dict[str, Any], strict: bool = False):
        return cls(**state_dict)

    def default_transforms(self) -> Dict[str, Callable]:
        return {"to_tensor_transform": ApplyToKeys(DefaultDataKeys.INPUT, T.to_tensor)}

4. The DataModule

Finally, let’s implement the ImageClassificationDataModule. We get the from_folders classmethod for free as we’ve registered a DefaultDataSources.FOLDERS data source in our ImageClassificationPreprocess. All we need to do is attach our Preprocess class like this:

from flash import DataModule


class ImageClassificationDataModule(DataModule):

    # Set `preprocess_cls` with your custom `Preprocess`.
    preprocess_cls = ImageClassificationPreprocess

How it works behind the scenes

DataSource

Note

The load_data() and load_sample() hooks will be used to generate an AutoDataset object.

Here is the AutoDataset pseudo-code.

class AutoDataset:
    def __init__(
        self,
        data: List[Any],  # output of `DataSource.load_data`
        data_source: DataSource,
        running_stage: RunningStage,
    ):

        self.data = data
        self.data_source = data_source

    def __getitem__(self, index: int):
        return self.data_source.load_sample(self.data[index])

    def __len__(self):
        return len(self.data)

Preprocess

Note

The pre_tensor_transform(), to_tensor_transform(), post_tensor_transform(), collate(), and per_batch_transform() hooks are injected as the torch.utils.data.DataLoader.collate_fn function of the DataLoader.

Here is pseudocode using the Preprocess hook names. Flash takes care of calling the right hooks for each stage.

Example:

# This will be wrapped into a flash.core.data.batch._Preprocessor.
def collate_fn(samples: Sequence[Any]) -> Any:

    # This will be wrapped into a flash.core.data.batch._Sequential
    for sample in samples:
        sample = pre_tensor_transform(sample)
        sample = to_tensor_transform(sample)
        sample = post_tensor_transform(sample)

    samples = type(samples)(samples)

    # if the Preprocess.per_sample_transform_on_device hook is overridden,
    # those functions below will be no-ops

    samples = collate(samples)
    samples = per_batch_transform(samples)
    return samples

dataloader = DataLoader(dataset, collate_fn=collate_fn)

Note

The per_sample_transform_on_device(), collate(), and per_batch_transform_on_device() hooks are injected after the LightningModule transfer_batch_to_device hook.

Here is pseudocode using the Preprocess hook names. Flash takes care of calling the right hooks for each stage.

Example:

# This will be wrapped into a flash.core.data.batch._Preprocessor
def collate_fn(samples: Sequence[Any]) -> Any:

    # if ``per_batch_transform`` hook is overridden, those functions below will be no-ops
    samples = [per_sample_transform_on_device(sample) for sample in samples]
    samples = type(samples)(samples)
    samples = collate(samples)

    samples = per_batch_transform_on_device(samples)
    return samples

# move the data to device
data = lightning_module.transfer_data_to_device(data)
data = collate_fn(data)
predictions = lightning_module(data)

Postprocess and Serializer

Once the predictions have been generated by the Flash Task, the Flash DataPipeline will execute the Postprocess hooks and the Serializer behind the scenes.

First, the per_batch_transform() hooks will be applied to the batch predictions. Then, uncollate() will split the batch into individual predictions. Next, per_sample_transform() will be applied to each prediction. Finally, the serialize() method will be called to serialize the predictions.

Note

The transform can be applied either on device or CPU.

Here is the pseudo-code:

Example:

# This will be wrapped into a flash.core.data.batch._Postprocessor
def uncollate_fn(batch: Any) -> Any:

    batch = per_batch_transform(batch)

    samples = uncollate(batch)

    samples = [per_sample_transform(sample) for sample in samples]
    # only if serializers are enabled.
    return [serialize(sample) for sample in samples]

predictions = lightning_module(data)
return uncollate_fn(predictions)

Registry

Available Registries

Registries are Flash’s internal key-value stores used to map a name to a function.

In simple terms, they are just advanced dictionaries that store functions under string keys.

Registries help organize code and make the functions accessible all across the Flash codebase. Each Flash Task can have several registries as static attributes.

Currently, Flash uses registries internally only for backbones, but more components will be added.

1. Imports

from functools import partial

from flash import Task
from flash.core.registry import FlashRegistry

2. Init a Registry

It is good practice to associate one or multiple registries with a Task as follows:

# creating a custom `Task` with its own registry
class MyImageClassifier(Task):

    backbones = FlashRegistry("backbones")

    def __init__(
        self,
        backbone: str = "resnet18",
        pretrained: bool = True,
    ):
        ...

        self.backbone, self.num_features = self.backbones.get(backbone)(pretrained=pretrained)

3. Adding new functions

Your custom functions can be registered within a FlashRegistry as a decorator or directly.

# Option 1: Used with partial.
def fn(backbone: str, pretrained: bool = True):
    # Create backbone and backbone output dimension (`num_features`)
    backbone, num_features = None, None
    return backbone, num_features


# HINT 1: Use `from functools import partial` if you want to store some arguments.
MyImageClassifier.backbones(fn=partial(fn, backbone="my_backbone"), name="username/partial_backbone")

# Option 2: Using decorator.
@MyImageClassifier.backbones(name="username/decorated_backbone")
def fn(pretrained: bool = True):
    # Create backbone and backbone output dimension (`num_features`)
    backbone, num_features = None, None
    return backbone, num_features

4. Accessing registered functions

You can now access your function from your task!

# 3.b Optional: List available backbones
print(MyImageClassifier.available_backbones())

# 4. Build the model
model = MyImageClassifier(backbone="username/decorated_backbone")

Here’s the output:

['username/decorated_backbone', 'username/partial_backbone']

5. Pre-registered backbones

Flash provides populated registries containing lots of available backbones.

Example:

from flash.image.backbones import OBJ_DETECTION_BACKBONES
from flash.image.classification.backbones import IMAGE_CLASSIFIER_BACKBONES

print(IMAGE_CLASSIFIER_BACKBONES.available_keys())
""" out:
['adv_inception_v3', 'cspdarknet53', 'cspdarknet53_iabn', 430+.., 'xception71']
"""

Flash Serve

Flash Serve is a library to easily serve models in production.

Terminology

Here are common terms you need to be familiar with:

  • de-serialization: Transform data encoded as text into tensors.

  • inference function: A function taking the decoded tensors and forwarding them through the model to produce predictions.

  • serialization: Transform the prediction tensors back into a text encoding.

  • ModelComponent: The ModelComponent contains the de-serialization, inference and serialization functions.

  • Servable: The Servable is a helper that tracks the asset files related to a model.

  • Composition: The Composition defines the computations / endpoints to create & run.

  • expose(): The expose() function is a python decorator used to augment the ModelComponent inference function with de-serialization and serialization.

Example

In this tutorial, we will serve a ResNet-18 from the torchvision library in 3 steps.

The entire tutorial can be found under flash_examples/serve/generic.

Introduction

Traditionally, an inference pipeline is made out of 3 steps:

  • de-serialization: Transform data encoded as text into tensors.

  • inference function: A function taking the decoded tensors and forwarding them through the model to produce predictions.

  • serialization: Transform the prediction tensors back into text.

In this example, we will implement only the inference function, as Flash Serve already provides some built-in de-serialization and serialization functions, such as the Image and Label types used below.

Step 1 - Create a ModelComponent

Inside inference_server.py, we will implement a ClassificationInference class, which subclasses ModelComponent.

First, we need to make the following imports:

import torch
import torchvision

from flash.core.serve import Composition, Servable, ModelComponent, expose
from flash.core.serve.types import Image, Label

To implement ClassificationInference, we need to implement a method responsible for the inference, decorated with the expose() function.

The name of the inference method isn’t constrained, but we will use classify as appropriate in this example.

Our classify function will take a tensor image, apply some normalization on it, and forward it through the model.

def classify(self, img):
    img = img.float() / 255
    mean = torch.tensor([[[0.485, 0.456, 0.406]]]).float()
    std = torch.tensor([[[0.229, 0.224, 0.225]]]).float()
    img = (img - mean) / std
    img = img.permute(0, 3, 2, 1)
    out = self.model(img)
    return out.argmax()

expose() is a Python decorator extending the decorated function with the de-serialization and serialization steps.

Note

Flash Serve was designed this way to enable several models to be chained together by removing the decorator.

The expose() function takes 2 arguments:

  • inputs: Dictionary mapping the decorated function inputs to BaseType objects.

  • outputs: Dictionary mapping the decorated function outputs to BaseType objects.

A BaseType is a python dataclass which implements a serialize and deserialize function.

Note

Flash Serve already has several built-in BaseType implementations, such as Image or Text.

class ClassificationInference(ModelComponent):
    def __init__(self, model: Servable):
        self.model = model

    @expose(
        inputs={"img": Image()},
        outputs={"prediction": Label(path="imagenet_labels.txt")},
    )
    def classify(self, img):
        img = img.float() / 255
        mean = torch.tensor([[[0.485, 0.456, 0.406]]]).float()
        std = torch.tensor([[[0.229, 0.224, 0.225]]]).float()
        img = (img - mean) / std
        img = img.permute(0, 3, 2, 1)
        out = self.model(img)
        return out.argmax()

Step 2 - Create a scripted Model

Using the torchvision library, we create a resnet18 and use torch.jit.script to script the model.

Note

TorchScript is a way to create serializable and optimizable models from PyTorch code. Any TorchScript program can be saved from a Python process and loaded in a process where there is no Python dependency.

model = torchvision.models.resnet18(pretrained=True).eval()
torch.jit.script(model).save("resnet.pt")

Step 3 - Serve the model

The Servable takes the path to the TorchScripted model as an argument; it is then passed to our ClassificationInference class.

The ClassificationInference instance will be passed as argument to a Composition class.

Once the Composition class is instantiated, just call its serve() method.

resnet = Servable("resnet.pt")
comp = ClassificationInference(resnet)
composition = Composition(classification=comp)
composition.serve()

Launching the server.

In Terminal 1

Just run:

python inference_server.py

And you should see the server start up in your terminal.

You should also see a Swagger UI already built for you at http://127.0.0.1:8000/docs

In Terminal 2

Run this script from another terminal:

import base64
from pathlib import Path

import requests

with Path("fish.jpg").open("rb") as f:
    imgstr = base64.b64encode(f.read()).decode("UTF-8")

body = {"session": "UUID", "payload": {"img": {"data": imgstr}}}
resp = requests.post("http://127.0.0.1:8000/predict", json=body)
print(resp.json())
# {'session': 'UUID', 'result': {'prediction': 'goldfish, Carassius auratus'}}

Credits to @rlizzo, @hhsecond, @lantiga, @luiscape for building Flash Serve Engine.

Backbones and Heads

Backbones are the pre-trained models that can be used with a task. The available backbones and heads can be found by using the available_backbones and available_heads methods.

To get the available backbones for a task like ImageClassifier, run:

from flash.image import ImageClassifier

# get the backbones available for ImageClassifier
backbones = ImageClassifier.available_backbones()

# print the backbones
print(backbones)

To get the available heads for a task like SemanticSegmentation, run:

from flash.image import SemanticSegmentation

# get the heads available for SemanticSegmentation
heads = SemanticSegmentation.available_heads()

# print the heads
print(heads)

Optimization (Optimizers and Schedulers)

Using optimizers and learning rate schedulers with Flash has become easier and cleaner than ever.

With the use of the Registry, instantiating an optimizer or a learning rate scheduler can be done with just a string.

Setting an optimizer to a task

Each task has a built-in method available_optimizers() which will list all the optimizers registered with Flash.

>>> from flash.core.classification import ClassificationTask
>>> ClassificationTask.available_optimizers()
['adadelta', ..., 'sgd']

To train / finetune a Task of your choice, just pass in a string.

from flash.image import ImageClassifier

model = ImageClassifier(num_classes=10, backbone="resnet18", optimizer="Adam", learning_rate=1e-4)

In order to customize specific parameters of the optimizer, pass the string together with a dictionary of kwargs as a tuple.

from flash.image import ImageClassifier

model = ImageClassifier(num_classes=10, backbone="resnet18", optimizer=("Adam", {"amsgrad": True}), learning_rate=1e-4)

An alternative to customizing an optimizer using a tuple is to pass it as a callable.

from functools import partial
from torch.optim import Adam
from flash.image import ImageClassifier

model = ImageClassifier(num_classes=10, backbone="resnet18", optimizer=partial(Adam, amsgrad=True), learning_rate=1e-4)

Setting a Learning Rate Scheduler

Each task has a built-in method available_lr_schedulers() which will list all the learning rate schedulers registered with Flash.

>>> from flash.core.classification import ClassificationTask
>>> ClassificationTask.available_lr_schedulers()
['lambdalr', ..., 'cosineannealingwarmrestarts']

To train / finetune a Task of your choice, just pass in a string.

from flash.image import ImageClassifier

model = ImageClassifier(
    num_classes=10, backbone="resnet18", optimizer="Adam", learning_rate=1e-4, lr_scheduler="constant_schedule"
)

Note

"constant_schedule" and a few other lr schedulers will be available only if you have installed the transformers library from Hugging Face.

In order to customize specific parameters of the LR scheduler, pass the string together with a dictionary of kwargs as a tuple.

from flash.image import ImageClassifier

model = ImageClassifier(
    num_classes=10,
    backbone="resnet18",
    optimizer="Adam",
    learning_rate=1e-4,
    lr_scheduler=("StepLR", {"step_size": 10}),
)

An alternative to customizing the LR Scheduler using a tuple is to pass it as a callable.

from functools import partial
from torch.optim.lr_scheduler import CyclicLR
from flash.image import ImageClassifier

model = ImageClassifier(
    num_classes=10,
    backbone="resnet18",
    optimizer="Adam",
    learning_rate=1e-4,
    lr_scheduler=partial(CyclicLR, step_size_up=1500, mode="exp_range", gamma=0.5),
)

Additionally, the lr_scheduler parameter also accepts the Lightning Scheduler configuration which can be passed on using a tuple.

The Lightning Scheduler configuration is a dictionary which contains the scheduler and its associated configuration. The default configuration is shown below.

lr_scheduler_config = {
    # REQUIRED: The scheduler instance
    "scheduler": lr_scheduler,
    # The unit of the scheduler's step size, could also be 'step'.
    # 'epoch' updates the scheduler on epoch end whereas 'step'
    # updates it after a optimizer update.
    "interval": "epoch",
    # How many epochs/steps should pass between calls to
    # `scheduler.step()`. 1 corresponds to updating the learning
    # rate after every epoch/step.
    "frequency": 1,
    # Metric to monitor for schedulers like `ReduceLROnPlateau`
    "monitor": "val_loss",
    # If set to `True`, will enforce that the value specified 'monitor'
    # is available when the scheduler is updated, thus stopping
    # training if not found. If set to `False`, it will only produce a warning
    "strict": True,
    # If using the `LearningRateMonitor` callback to monitor the
    # learning rate progress, this keyword can be used to specify
    # a custom logged name
    "name": None,
}

When there are schedulers in which the .step() method is conditioned on a value, such as the torch.optim.lr_scheduler.ReduceLROnPlateau scheduler, Flash requires that the Lightning Scheduler configuration contains the keyword "monitor" set to the metric name that the scheduler should be conditioned on. Below is an example for this:

from flash.image import ImageClassifier

model = ImageClassifier(
    num_classes=10,
    backbone="resnet18",
    optimizer="Adam",
    learning_rate=1e-4,
    lr_scheduler=("reducelronplateau", {"mode": "max"}, {"monitor": "val_accuracy"}),
)

Note

Do not set the "scheduler" key in the Lightning Scheduler configuration, it will overriden with an instance of the provided scheduler key.

Pre-Registering optimizers and scheduler recipes

The Flash registry also provides the flexibility of registering functions. This feature is also provided in the Optimizer and Scheduler registries.

Using the optimizers and lr_schedulers decorators pertaining to each Task, custom optimizer and LR scheduler recipes can be pre-registered.

import torch
from flash.image import ImageClassifier


@ImageClassifier.lr_schedulers
def my_flash_steplr_recipe(optimizer):
    return torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)


model = ImageClassifier(backbone="resnet18", num_classes=2, optimizer="Adam", lr_scheduler="my_flash_steplr_recipe")

Provider specific requirements

Schedulers

Certain LR Schedulers provided by Hugging Face require both num_training_steps and num_warmup_steps.

In order to use them in Flash, just provide num_warmup_steps as a float between 0 and 1, which indicates the fraction of the training steps that will be used as warmup steps. Flash’s Trainer will take care of computing the number of training steps and the number of warmup steps based on the flags that are set in the Trainer.

from flash.image import ImageClassifier

model = ImageClassifier(
    backbone="resnet18",
    num_classes=2,
    optimizer="Adam",
    lr_scheduler=("cosine_schedule_with_warmup", {"num_warmup_steps": 0.1}),
)

Image Classification

The Task

The task of identifying what is in an image is called image classification. Typically, Image Classification is used to identify images containing a single object. The task predicts which ‘class’ the image most likely belongs to with a degree of certainty. A class is a label that describes what is in an image, such as ‘car’, ‘house’, ‘cat’ etc.


Example

Let’s look at the task of predicting whether images contain Ants or Bees using the hymenoptera dataset. The dataset contains train and validation folders, and then each folder contains a bees folder, with pictures of bees, and an ants folder with images of, you guessed it, ants.

hymenoptera_data
├── train
│   ├── ants
│   │   ├── 0013035.jpg
│   │   ├── 1030023514_aad5c608f9.jpg
│   │   ...
│   └── bees
│       ├── 1092977343_cb42b38d62.jpg
│       ├── 1093831624_fb5fbe2308.jpg
│       ...
└── val
    ├── ants
    │   ├── 10308379_1b6c72e180.jpg
    │   ├── 1053149811_f62a3410d3.jpg
    │   ...
    └── bees
        ├── 1032546534_06907fe3b3.jpg
        ├── 10870992_eebeeb3a12.jpg
        ...

Once we’ve downloaded the data using download_data(), we create the ImageClassificationData. We select a pre-trained backbone to use for our ImageClassifier and fine-tune on the hymenoptera data. We then use the trained ImageClassifier for inference. Finally, we save the model. Here’s the full example:

import torch

import flash
from flash.core.data.utils import download_data
from flash.image import ImageClassificationData, ImageClassifier

# 1. Create the DataModule
download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip", "./data")

datamodule = ImageClassificationData.from_folders(
    train_folder="data/hymenoptera_data/train/",
    val_folder="data/hymenoptera_data/val/",
)

# 2. Build the task
model = ImageClassifier(backbone="resnet18", num_classes=datamodule.num_classes)

# 3. Create the trainer and finetune the model
trainer = flash.Trainer(max_epochs=3, gpus=torch.cuda.device_count())
trainer.finetune(model, datamodule=datamodule, strategy="freeze")

# 4. Predict what's on a few images! ants or bees?
predictions = model.predict(
    [
        "data/hymenoptera_data/val/bees/65038344_52a45d090d.jpg",
        "data/hymenoptera_data/val/bees/590318879_68cf112861.jpg",
        "data/hymenoptera_data/val/ants/540543309_ddbb193ee5.jpg",
    ]
)
print(predictions)

# 5. Save the model!
trainer.save_checkpoint("image_classification_model.pt")

Flash Zero

The image classifier can be used directly from the command line with zero code using Flash Zero. You can run the hymenoptera example with:

flash image_classification

To view configuration options and options for running the image classifier with your own data, use:

flash image_classification --help

Loading Data

This section details the available ways to load your own data into the ImageClassificationData.

from_folders

Construct the ImageClassificationData from folders.

The supported file extensions are: .jpg, .jpeg, .png, .ppm, .bmp, .pgm, .tif, .tiff, .webp, .npy.

For train, test, and val data, the folders are expected to contain a sub-folder for each class. Here’s the required structure:

train_folder
├── class_1
│   ├── file1.jpg
│   ├── file2.jpg
│   ...
└── class_2
    ├── file1.jpg
    ├── file2.jpg
    ...

For prediction, the folder is expected to contain the files for inference, like this:

predict_folder
├── file1.jpg
├── file2.jpg
...

Example:

data_module = ImageClassificationData.from_folders(
    train_folder = "./train_folder",
    predict_folder = "./predict_folder",
    ...
)

from_files

Construct the ImageClassificationData from lists of files and corresponding lists of targets.

The supported file extensions are: .jpg, .jpeg, .png, .ppm, .bmp, .pgm, .tif, .tiff, .webp, .npy.

Example:

train_files = ["file1.jpg", "file2.jpg", "file3.jpg", ...]
train_targets = [0, 1, 0, ...]

datamodule = ImageClassificationData.from_files(
    train_files = train_files,
    train_targets = train_targets,
    ...
)

from_datasets

Construct the ImageClassificationData from the given datasets for each stage.

Example:

from torch.utils.data.dataset import Dataset

train_dataset: Dataset = ...

datamodule = ImageClassificationData.from_datasets(
    train_dataset = train_dataset,
    ...
)

Note

The __getitem__ of your datasets should return a dictionary with "input" and "target" keys which map to the input image (as a PIL.Image) and the target (as an int or list of ints) respectively.
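
For illustration, here’s a minimal sketch of such a dataset; the class name and the file/target lists are hypothetical:

from typing import List

from PIL import Image
from torch.utils.data import Dataset


class MyImageDataset(Dataset):
    """Hypothetical dataset that returns the dictionary format expected by from_datasets."""

    def __init__(self, files: List[str], targets: List[int]):
        self.files = files
        self.targets = targets

    def __len__(self):
        return len(self.files)

    def __getitem__(self, index):
        # "input" is a PIL.Image, "target" is an int (or a list of ints for multi-label)
        return {"input": Image.open(self.files[index]), "target": self.targets[index]}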


Custom Transformations

Flash automatically applies some default image transformations and augmentations, but you may wish to customize these for your own use case. The base Preprocess defines 7 hooks for different stages in the data loading pipeline. To apply image augmentations you can directly import the default_transforms from flash.image.classification.transforms and then merge your custom image transformations with them using the merge_transforms() helper function. Here’s an example where we load the default transforms and merge with custom torchvision transformations. We use the post_tensor_transform hook to apply the transformations after the image has been converted to a torch.Tensor.

from torchvision import transforms as T

import flash
from flash.core.data.data_source import DefaultDataKeys
from flash.core.data.transforms import ApplyToKeys, merge_transforms
from flash.image import ImageClassificationData, ImageClassifier
from flash.image.classification.transforms import default_transforms

post_tensor_transform = ApplyToKeys(
    DefaultDataKeys.INPUT,
    T.Compose([T.RandomHorizontalFlip(), T.ColorJitter(), T.RandomAutocontrast(), T.RandomPerspective()]),
)

new_transforms = merge_transforms(default_transforms((64, 64)), {"post_tensor_transform": post_tensor_transform})

datamodule = ImageClassificationData.from_folders(
    train_folder="data/hymenoptera_data/train/", val_folder="data/hymenoptera_data/val/", train_transform=new_transforms
)

model = ImageClassifier(backbone="resnet18", num_classes=datamodule.num_classes)

trainer = flash.Trainer(max_epochs=1)
trainer.finetune(model, datamodule=datamodule, strategy="freeze")

Serving

The ImageClassifier is servable. This means you can call .serve to serve your Task. Here’s an example:

from flash.image import ImageClassifier

model = ImageClassifier.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/image_classification_model.pt")
model.serve()

You can now perform inference from your client like this:

import base64
from pathlib import Path

import requests

import flash

with (Path(flash.ASSETS_ROOT) / "fish.jpg").open("rb") as f:
    imgstr = base64.b64encode(f.read()).decode("UTF-8")

body = {"session": "UUID", "payload": {"inputs": {"data": imgstr}}}
resp = requests.post("http://127.0.0.1:8000/predict", json=body)
print(resp.json())

Multi-label Image Classification

The Task

Multi-label classification is the task of assigning a number of labels from a fixed set to each data point, which can be in any modality (images in this case). Multi-label image classification is supported by the ImageClassifier via the multi-label argument.


Example

Let’s look at the task of trying to predict the movie genres from an image of the movie poster. The data we will use is a subset of the awesome movie poster genre prediction data set from the paper “Movie Genre Classification based on Poster Images with Deep Neural Networks” by Wei-Ta Chu and Hung-Jui Guo, resized to 128 by 128. Take a look at their paper (and please consider citing their paper if you use the data) here: www.cs.ccu.edu.tw/~wtchu/projects/MoviePoster/. The data set contains train and validation folders, and then each folder contains images and a metadata.csv which stores the labels. Here’s an overview:

movie_posters
├── train
│   ├── metadata.csv
│   ├── tt0084058.jpg
│   ├── tt0084867.jpg
│   ...
└── val
    ├── metadata.csv
    ├── tt0200465.jpg
    ├── tt0326965.jpg
    ...

Once we’ve downloaded the data using download_data(), we need to create the ImageClassificationData. We use from_csv() to load the images and their genre labels from the metadata.csv files. We select a pre-trained backbone to use for our ImageClassifier and fine-tune on the posters data. We then use the trained ImageClassifier for inference. Finally, we save the model. Here’s the full example:

import torch

import flash
from flash.core.data.utils import download_data
from flash.image import ImageClassificationData, ImageClassifier

# 1. Create the DataModule
# Data set from the paper “Movie Genre Classification based on Poster Images with Deep Neural Networks”.
# More info here: https://www.cs.ccu.edu.tw/~wtchu/projects/MoviePoster/
download_data("https://pl-flash-data.s3.amazonaws.com/movie_posters.zip")

datamodule = ImageClassificationData.from_csv(
    "Id",
    ["Action", "Romance", "Crime", "Thriller", "Adventure"],
    train_file="data/movie_posters/train/metadata.csv",
    val_file="data/movie_posters/val/metadata.csv",
    image_size=(128, 128),
)

# 2. Build the task
model = ImageClassifier(backbone="resnet18", num_classes=datamodule.num_classes, multi_label=True)

# 3. Create the trainer and finetune the model
trainer = flash.Trainer(max_epochs=3, gpus=torch.cuda.device_count())
trainer.finetune(model, datamodule=datamodule, strategy="freeze")

# 4. Predict the genre of a few movies!
predictions = model.predict(
    [
        "data/movie_posters/predict/tt0085318.jpg",
        "data/movie_posters/predict/tt0089461.jpg",
        "data/movie_posters/predict/tt0097179.jpg",
    ]
)
print(predictions)

# 5. Save the model!
trainer.save_checkpoint("image_classification_multi_label_model.pt")

Flash Zero

The multi-label image classifier can be used directly from the command line with zero code using Flash Zero. You can run the movie posters example with:

flash image_classification from_movie_posters

To view configuration options and options for running the image classifier with your own data, use:

flash image_classification --help

Serving

The ImageClassifier is servable. For more information, see Image Classification.

Image Embedder

The Task

Image embedding encodes an image into a vector of features which can be used for a downstream task. This could include: clustering, similarity search, or classification.

The ImageEmbedder internally relies on VISSL.


Example

Let’s see how to configure a training strategy for the ImageEmbedder task. A vanilla DataModule object can be created using standard Datasets as shown below. Then the user can configure the ImageEmbedder task with training_strategy, backbone, head and pretraining_transform. Options are provided to pass additional arguments to the selected configurations. The task can then be sent to the fit() method of the Trainer.

Note

Many VISSL loss functions use hard-coded torch.distributed methods, so we suggest using accelerator='ddp' even with a single GPU. Only the barlow_twins training strategy works on the CPU; all other loss functions are configured to work on GPUs.

import torch
from torchvision.datasets import CIFAR10

import flash
from flash.core.data.utils import download_data
from flash.image import ImageClassificationData, ImageEmbedder

# 1. Download the data and prepare the datamodule
datamodule = ImageClassificationData.from_datasets(
    train_dataset=CIFAR10(".", download=True),
    batch_size=16,
)

# 2. Build the task
embedder = ImageEmbedder(
    backbone="resnet",
    training_strategy="barlow_twins",
    head="simclr_head",
    pretraining_transform="barlow_twins_transform",
    training_strategy_kwargs={"latent_embedding_dim": 128},
    pretraining_transform_kwargs={"size_crops": [196]},
)

# 3. Create the trainer and pre-train the encoder
# use accelerator='ddp' when using GPU(s),
# i.e. flash.Trainer(max_epochs=3, gpus=1, accelerator='ddp')
trainer = flash.Trainer(max_epochs=3, gpus=torch.cuda.device_count())
trainer.fit(embedder, datamodule=datamodule)

# 4. Save the model!
trainer.save_checkpoint("image_embedder_model.pt")

# 5. Download the downstream prediction dataset and generate embeddings
download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip", "data/")

embeddings = embedder.predict(
    [
        "data/hymenoptera_data/predict/153783656_85f9c3ac70.jpg",
        "data/hymenoptera_data/predict/2039585088_c6f47c592e.jpg",
    ]
)
# list of embeddings for images sent to the predict function
print(embeddings)
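
As a quick illustration of a downstream use such as similarity search, the embeddings returned above can be compared directly. This is only a sketch and assumes each prediction can be converted to a 1-D torch.Tensor:

import torch
import torch.nn.functional as F

# Compare the two embeddings with cosine similarity (illustrative only).
emb_a = torch.as_tensor(embeddings[0]).flatten().float()
emb_b = torch.as_tensor(embeddings[1]).flatten().float()
similarity = F.cosine_similarity(emb_a.unsqueeze(0), emb_b.unsqueeze(0))
print(f"Cosine similarity between the two images: {similarity.item():.3f}")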

Object Detection

The Task

Object detection is the task of identifying objects in images and their associated classes and bounding boxes.

The ObjectDetector and ObjectDetectionData classes internally rely on IceVision.


Example

Let’s look at object detection with the COCO 128 data set, which contains 91 object classes. This is a subset of COCO train2017 with only 128 images. The data set is organized following the COCO format. Here’s an outline:

coco128
├── annotations
│   └── instances_train2017.json
├── images
│   └── train2017
│       ├── 000000000009.jpg
│       ├── 000000000025.jpg
│       ...
└── labels
    └── train2017
        ├── 000000000009.txt
        ├── 000000000025.txt
        ...

Once we’ve downloaded the data using download_data(), we can create the ObjectDetectionData. We select an EfficientDet (an "efficientdet" head with a "d0" backbone) to use for our ObjectDetector and fine-tune on the COCO 128 data. We then use the trained ObjectDetector for inference. Finally, we save the model. Here’s the full example:

import flash
from flash.core.data.utils import download_data
from flash.image import ObjectDetectionData, ObjectDetector

# 1. Create the DataModule
# Dataset Credit: https://www.kaggle.com/ultralytics/coco128
download_data("https://github.com/zhiqwang/yolov5-rt-stack/releases/download/v0.3.0/coco128.zip", "data/")

datamodule = ObjectDetectionData.from_coco(
    train_folder="data/coco128/images/train2017/",
    train_ann_file="data/coco128/annotations/instances_train2017.json",
    val_split=0.1,
    image_size=128,
)

# 2. Build the task
model = ObjectDetector(head="efficientdet", backbone="d0", num_classes=datamodule.num_classes, image_size=128)

# 3. Create the trainer and finetune the model
trainer = flash.Trainer(max_epochs=1)
trainer.finetune(model, datamodule=datamodule, strategy="freeze")

# 4. Detect objects in a few images!
predictions = model.predict(
    [
        "data/coco128/images/train2017/000000000625.jpg",
        "data/coco128/images/train2017/000000000626.jpg",
        "data/coco128/images/train2017/000000000629.jpg",
    ]
)
print(predictions)

# 5. Save the model!
trainer.save_checkpoint("object_detection_model.pt")

Flash Zero

The object detector can be used directly from the command line with zero code using Flash Zero. You can run the above example with:

flash object_detection

To view configuration options and options for running the object detector with your own data, use:

flash object_detection --help

Custom Transformations

Flash automatically applies some default image / mask transformations and augmentations, but you may wish to customize these for your own use case. The base Preprocess defines 7 hooks for different stages in the data loading pipeline. For object-detection tasks, you can leverage the transformations from Albumentations with the IceVisionTransformAdapter.

import albumentations as alb
from icevision.tfms import A

from flash.core.integrations.icevision.transforms import IceVisionTransformAdapter
from flash.image import ObjectDetectionData

train_transform = {
    "pre_tensor_transform": transforms.IceVisionTransformAdapter(
        [*A.resize_and_pad(128), A.Normalize(), A.Flip(0.4), alb.RandomBrightnessContrast()]
    )
}

datamodule = ObjectDetectionData.from_coco(
    train_folder="data/coco128/images/train2017/",
    train_ann_file="data/coco128/annotations/instances_train2017.json",
    val_split=0.1,
    image_size=128,
    train_transform=train_transform,
)

Keypoint Detection

The Task

Keypoint detection is the task of identifying keypoints in images and their associated classes.

The KeypointDetector and KeypointDetectionData classes internally rely on IceVision.


Example

Let’s look at keypoint detection with BIWI Sample Keypoints (center of face) from IceData. Once we’ve downloaded the data, we can create the KeypointDetectionData. We select a keypoint_rcnn with a resnet18_fpn backbone to use for our KeypointDetector and fine-tune on the BIWI data. We then use the trained KeypointDetector for inference. Finally, we save the model. Here’s the full example:

import flash
from flash.core.utilities.imports import example_requires
from flash.image import KeypointDetectionData, KeypointDetector

example_requires("image")

import icedata  # noqa: E402

# 1. Create the DataModule
data_dir = icedata.biwi.load_data()

datamodule = KeypointDetectionData.from_folders(
    train_folder=data_dir,
    val_split=0.1,
    parser=icedata.biwi.parser,
)

# 2. Build the task
model = KeypointDetector(
    head="keypoint_rcnn",
    backbone="resnet18_fpn",
    num_keypoints=1,
    num_classes=datamodule.num_classes,
)

# 3. Create the trainer and finetune the model
trainer = flash.Trainer(max_epochs=1)
trainer.finetune(model, datamodule=datamodule, strategy="freeze")

# 4. Detect keypoints in a few images!
predictions = model.predict(
    [
        str(data_dir / "biwi_sample/images/0.jpg"),
        str(data_dir / "biwi_sample/images/1.jpg"),
        str(data_dir / "biwi_sample/images/10.jpg"),
    ]
)
print(predictions)

# 5. Save the model!
trainer.save_checkpoint("keypoint_detection_model.pt")

Flash Zero

The keypoint detector can be used directly from the command line with zero code using Flash Zero. You can run the above example with:

flash keypoint_detection

To view configuration options and options for running the keypoint detector with your own data, use:

flash keypoint_detection --help

Instance Segmentation

The Task

Instance segmentation is the task of segmenting objects in images and determining their associated classes.

The InstanceSegmentation and InstanceSegmentationData classes internally rely on IceVision.


Example

Let’s look at instance segmentation with The Oxford-IIIT Pet Dataset from IceData. Once we’ve downloaded the data, we can create the InstanceSegmentationData. We select a mask_rcnn with a resnet18_fpn backbone to use for our InstanceSegmentation and fine-tune on the pets data. We then use the trained InstanceSegmentation for inference. Finally, we save the model. Here’s the full example:

from functools import partial

import flash
from flash.core.utilities.imports import example_requires
from flash.image import InstanceSegmentation, InstanceSegmentationData

example_requires("image")

import icedata  # noqa: E402

# 1. Create the DataModule
data_dir = icedata.pets.load_data()

datamodule = InstanceSegmentationData.from_folders(
    train_folder=data_dir,
    val_split=0.1,
    parser=partial(icedata.pets.parser, mask=True),
)

# 2. Build the task
model = InstanceSegmentation(
    head="mask_rcnn",
    backbone="resnet18_fpn",
    num_classes=datamodule.num_classes,
)

# 3. Create the trainer and finetune the model
trainer = flash.Trainer(max_epochs=1)
trainer.finetune(model, datamodule=datamodule, strategy="freeze")

# 4. Segment objects in a few images!
predictions = model.predict(
    [
        str(data_dir / "images/yorkshire_terrier_9.jpg"),
        str(data_dir / "images/yorkshire_terrier_12.jpg"),
        str(data_dir / "images/yorkshire_terrier_13.jpg"),
    ]
)
print(predictions)

# 5. Save the model!
trainer.save_checkpoint("instance_segmentation_model.pt")

Flash Zero

The instance segmentation task can be used directly from the command line with zero code using Flash Zero. You can run the above example with:

flash instance_segmentation

To view configuration options and options for running the instance segmentation task with your own data, use:

flash instance_segmentation --help

Semantic Segmentation

The Task

Semantic Segmentation, or image segmentation, is the task of performing classification at the pixel level, meaning each pixel will be associated with a given class. See more: https://paperswithcode.com/task/semantic-segmentation


Example

Let’s look at an example using a data set generated with the CARLA driving simulator. The data was generated as part of the Kaggle Lyft Udacity Challenge. The data contains one folder of images and another folder with the corresponding segmentation masks. Here’s the structure:

data
├── CameraRGB
│   ├── F61-1.png
│   ├── F61-2.png
│       ...
└── CameraSeg
    ├── F61-1.png
    ├── F61-2.png
        ...

Once we’ve downloaded the data using download_data(), we create the SemanticSegmentationData. We select a pre-trained mobilenetv3_large_100 backbone with an fpn head to use for our SemanticSegmentation task and fine-tune on the CARLA data. We then use the trained SemanticSegmentation for inference. You can check the available pretrained weights for a backbone with SemanticSegmentation.available_pretrained_weights("resnet18"). Finally, we save the model. Here’s the full example:

import torch

import flash
from flash.core.data.utils import download_data
from flash.image import SemanticSegmentation, SemanticSegmentationData

# 1. Create the DataModule
# The data was generated with the  CARLA self-driving simulator as part of the Kaggle Lyft Udacity Challenge.
# More info here: https://www.kaggle.com/kumaresanmanickavelu/lyft-udacity-challenge
download_data(
    "https://github.com/ongchinkiat/LyftPerceptionChallenge/releases/download/v0.1/carla-capture-20180513A.zip",
    "./data",
)

datamodule = SemanticSegmentationData.from_folders(
    train_folder="data/CameraRGB",
    train_target_folder="data/CameraSeg",
    val_split=0.1,
    image_size=(256, 256),
    num_classes=21,
)

# 2. Build the task
model = SemanticSegmentation(
    backbone="mobilenetv3_large_100",
    head="fpn",
    num_classes=datamodule.num_classes,
)

# 3. Create the trainer and finetune the model
trainer = flash.Trainer(max_epochs=3, gpus=torch.cuda.device_count())
trainer.finetune(model, datamodule=datamodule, strategy="freeze")

# 4. Segment a few images!
predictions = model.predict(
    [
        "data/CameraRGB/F61-1.png",
        "data/CameraRGB/F62-1.png",
        "data/CameraRGB/F63-1.png",
    ]
)
print(predictions)

# 5. Save the model!
trainer.save_checkpoint("semantic_segmentation_model.pt")

Flash Zero

The semantic segmentation task can be used directly from the command line with zero code using Flash Zero. You can run the above example with:

flash semantic_segmentation

To view configuration options and options for running the semantic segmentation task with your own data, use:

flash semantic_segmentation --help

Loading Data

This section details the available ways to load your own data into the SemanticSegmentationData.

from_folders

Construct the SemanticSegmentationData from folders.

The supported file extensions are: .jpg, .jpeg, .png, .ppm, .bmp, .pgm, .tif, .tiff, .webp.

For train, test, and val data, we expect a folder containing inputs and another folder containing the masks. Here’s the required structure:

train_folder
├── inputs
│   ├── file1.jpg
│   ├── file2.jpg
│   ...
└── masks
    ├── file1.jpg
    ├── file2.jpg
    ...

For prediction, the folder is expected to contain the files for inference, like this:

predict_folder
├── file1.jpg
├── file2.jpg
...

Example:

data_module = SemanticSegmentationData.from_folders(
    train_folder = "./train_folder/inputs",
    train_target_folder = "./train_folder/masks",
    predict_folder = "./predict_folder",
    ...
)

from_files

Construct the SemanticSegmentationData from lists of input images and corresponding list of target images.

The supported file extensions are: .jpg, .jpeg, .png, .ppm, .bmp, .pgm, .tif, .tiff, .webp.

Example:

train_files = ["file1.jpg", "file2.jpg", "file3.jpg", ...]
train_targets = ["mask1.jpg", "mask2.jpg", "mask3.jpg", ...]

datamodule = SemanticSegmentationData.from_files(
    train_files = train_files,
    train_targets = train_targets,
    ...
)

from_datasets

Construct the SemanticSegmentationData from the given datasets for each stage.

Example:

from torch.utils.data.dataset import Dataset

train_dataset: Dataset = ...

datamodule = SemanticSegmentationData.from_datasets(
    train_dataset = train_dataset,
    ...
)

Note

The __getitem__ of your datasets should return a dictionary with "input" and "target" keys which map to the input and target images as tensors.
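
For illustration, here’s a minimal sketch of such a dataset; the class name, image size and number of classes are hypothetical, and random tensors stand in for real images and masks:

import torch
from torch.utils.data import Dataset


class MySegmentationDataset(Dataset):
    """Hypothetical dataset that returns the dictionary format expected by from_datasets."""

    def __init__(self, num_samples: int = 8, num_classes: int = 21):
        self.num_samples = num_samples
        self.num_classes = num_classes

    def __len__(self):
        return self.num_samples

    def __getitem__(self, index):
        # "input" is the image tensor, "target" is the per-pixel class mask
        image = torch.rand(3, 256, 256)
        mask = torch.randint(0, self.num_classes, (256, 256))
        return {"input": image, "target": mask}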


Serving

The SemanticSegmentation task is servable. This means you can call .serve to serve your Task. Here’s an example:

from flash.image import SemanticSegmentation
from flash.image.segmentation.serialization import SegmentationLabels

model = SemanticSegmentation.load_from_checkpoint(
    "https://flash-weights.s3.amazonaws.com/semantic_segmentation_model.pt"
)
model.serializer = SegmentationLabels(visualize=False)
model.serve()

You can now perform inference from your client like this:

import base64
from pathlib import Path

import requests

import flash

with (Path(flash.ASSETS_ROOT) / "road.png").open("rb") as f:
    imgstr = base64.b64encode(f.read()).decode("UTF-8")

body = {"session": "UUID", "payload": {"inputs": {"data": imgstr}}}
resp = requests.post("http://127.0.0.1:8000/predict", json=body)
print(resp.json())

Style Transfer

The Task

The Neural Style Transfer Task is an optimization method which extracts the style from one image and applies it to another image while preserving the content of the latter. The goal is that the output image looks like the content image, but “painted” in the style of the style reference image.


The StyleTransfer and StyleTransferData classes internally rely on pystiche.


Example

Let’s look at transferring the style from The Starry Night onto the images from the COCO 128 data set from the Object Detection Guide. Once we’ve downloaded the data using download_data(), we create the StyleTransferData. Next, we create our StyleTransfer task with the desired style image and fit on the COCO 128 images. We then use the trained StyleTransfer for inference. Finally, we save the model. Here’s the full example:

import os

import torch

import flash
from flash.core.data.utils import download_data
from flash.image.style_transfer import StyleTransfer, StyleTransferData

# 1. Create the DataModule
download_data("https://github.com/zhiqwang/yolov5-rt-stack/releases/download/v0.3.0/coco128.zip", "./data")

datamodule = StyleTransferData.from_folders(train_folder="data/coco128/images/train2017")

# 2. Build the task
model = StyleTransfer(os.path.join(flash.ASSETS_ROOT, "starry_night.jpg"))

# 3. Create the trainer and train the model
trainer = flash.Trainer(max_epochs=3, gpus=torch.cuda.device_count())
trainer.fit(model, datamodule=datamodule)

# 4. Apply style transfer to a few images!
predictions = model.predict(
    [
        "data/coco128/images/train2017/000000000625.jpg",
        "data/coco128/images/train2017/000000000626.jpg",
        "data/coco128/images/train2017/000000000629.jpg",
    ]
)
print(predictions)

# 5. Save the model!
trainer.save_checkpoint("style_transfer_model.pt")

Flash Zero

The style transfer task can be used directly from the command line with zero code using Flash Zero. You can run the above example with:

flash style_transfer

To view configuration options and options for running the style transfer task with your own data, use:

flash style_transfer --help

Video Classification

The Task

Typically, Video Classification refers to the task of producing a label for actions identified in a given video. The task is to predict which class the video clip belongs to.

Lightning Flash VideoClassifier and VideoClassificationData classes internally rely on PyTorchVideo.


Example

Let’s develop a model to classify video clips of humans performing actions (such as archery, bowling, etc.). We’ll use data from the Kinetics dataset. Here’s an outline of the folder structure:

video_dataset
├── train
│   ├── archery
│   │   ├── -1q7jA3DXQM_000005_000015.mp4
│   │   ├── -5NN5hdIwTc_000036_000046.mp4
│   │   ...
│   ├── bowling
│   │   ├── -5ExwuF5IUI_000030_000040.mp4
│   │   ├── -7sTNNI1Bcg_000075_000085.mp4
│   ... ...
└── val
    ├── archery
    │   ├── 0S-P4lr_c7s_000022_000032.mp4
    │   ├── 2x1lIrgKxYo_000589_000599.mp4
    │   ...
    ├── bowling
    │   ├── 1W7HNDBA4pA_000002_000012.mp4
    │   ├── 4JxH3S5JwMs_000003_000013.mp4
    ... ...

Once we’ve downloaded the data using download_data(), we create the VideoClassificationData. We select a pre-trained backbone to use for our VideoClassifier and fine-tune on the Kinetics data. The backbone can be any model from the PyTorchVideo Model Zoo. We then use the trained VideoClassifier for inference. Finally, we save the model. Here’s the full example:

import os

import torch

import flash
from flash.core.data.utils import download_data
from flash.video import VideoClassificationData, VideoClassifier

# 1. Create the DataModule
# Find more datasets at https://pytorchvideo.readthedocs.io/en/latest/data.html
download_data("https://pl-flash-data.s3.amazonaws.com/kinetics.zip", "./data")

datamodule = VideoClassificationData.from_folders(
    train_folder=os.path.join(os.getcwd(), "data/kinetics/train"),
    val_folder=os.path.join(os.getcwd(), "data/kinetics/val"),
    clip_sampler="uniform",
    clip_duration=1,
    decode_audio=False,
)

# 2. Build the task
model = VideoClassifier(backbone="x3d_xs", num_classes=datamodule.num_classes, pretrained=False)

# 3. Create the trainer and finetune the model
trainer = flash.Trainer(max_epochs=3, gpus=torch.cuda.device_count())
trainer.finetune(model, datamodule=datamodule, strategy="freeze")

# 4. Make a prediction
predictions = model.predict(os.path.join(os.getcwd(), "data/kinetics/predict"))
print(predictions)

# 5. Save the model!
trainer.save_checkpoint("video_classification.pt")

Flash Zero

The video classifier can be used directly from the command line with zero code using Flash Zero. You can run the above example with:

flash video_classification

To view configuration options and options for running the video classifier with your own data, use:

flash video_classification --help

Audio Classification

The Task

The task of identifying what is in an audio file is called audio classification. Typically, Audio Classification is used to identify audio files containing sounds or words. The task predicts which ‘class’ the sound or words most likely belongs to with a degree of certainty. A class is a label that describes the sounds in an audio file, such as ‘children_playing’, ‘jackhammer’, ‘siren’ etc.


Example

Let’s look at the task of predicting whether an audio file contains sounds of an air conditioner, car horn, children playing, dog bark, drilling, engine idling, gun shot, jackhammer, siren, or street music, using the UrbanSound8k spectrogram images dataset. The dataset contains train, val and test folders, and each of these contains an air_conditioner folder with spectrograms generated from air-conditioner sounds, a siren folder with spectrograms generated from siren sounds, and so on for the other classes.

urban8k_images
├── train
│   ├── air_conditioner
│   ├── car_horn
│   ├── children_playing
│   ├── dog_bark
│   ├── drilling
│   ├── engine_idling
│   ├── gun_shot
│   ├── jackhammer
│   ├── siren
│   └── street_music
├── test
│   ├── air_conditioner
│   ├── car_horn
│   ├── children_playing
│   ├── dog_bark
│   ├── drilling
│   ├── engine_idling
│   ├── gun_shot
│   ├── jackhammer
│   ├── siren
│   └── street_music
└── val
    ├── air_conditioner
    ├── car_horn
    ├── children_playing
    ├── dog_bark
    ├── drilling
    ├── engine_idling
    ├── gun_shot
    ├── jackhammer
    ├── siren
    └── street_music


Once we’ve downloaded the data using download_data(), we create the AudioClassificationData. We select a pre-trained backbone to use for our ImageClassifier and fine-tune on the UrbanSound8k spectrogram images data. We then use the trained ImageClassifier for inference. Finally, we save the model. Here’s the full example:

import torch

import flash
from flash.audio import AudioClassificationData
from flash.core.data.utils import download_data
from flash.core.finetuning import FreezeUnfreeze
from flash.image import ImageClassifier

# 1. Create the DataModule
download_data("https://pl-flash-data.s3.amazonaws.com/urban8k_images.zip", "./data")

datamodule = AudioClassificationData.from_folders(
    train_folder="data/urban8k_images/train",
    val_folder="data/urban8k_images/val",
    spectrogram_size=(64, 64),
)

# 2. Build the model.
model = ImageClassifier(backbone="resnet18", num_classes=datamodule.num_classes)

# 3. Create the trainer and finetune the model
trainer = flash.Trainer(max_epochs=3, gpus=torch.cuda.device_count())
trainer.finetune(model, datamodule=datamodule, strategy=FreezeUnfreeze(unfreeze_epoch=1))

# 4. Predict what's in a few images! air_conditioner, children_playing, siren, etc.
predictions = model.predict(
    [
        "data/urban8k_images/test/air_conditioner/13230-0-0-5.wav.jpg",
        "data/urban8k_images/test/children_playing/9223-2-0-15.wav.jpg",
        "data/urban8k_images/test/jackhammer/22883-7-10-0.wav.jpg",
    ]
)
print(predictions)

# 5. Save the model!
trainer.save_checkpoint("audio_classification_model.pt")

Flash Zero

The audio classifier can be used directly from the command line with zero code using Flash Zero. You can run the above example with:

flash audio_classification

To view configuration options and options for running the audio classifier with your own data, use:

flash audio_classification --help

Loading Data

This section details the available ways to load your own data into the AudioClassificationData.

from_folders

Construct the AudioClassificationData from folders.

The supported file extensions are: .jpg, .jpeg, .png, .ppm, .bmp, .pgm, .tif, .tiff, .webp, .npy.

For train, test, and val data, the folders are expected to contain a sub-folder for each class. Here’s the required structure:

train_folder
├── class_1
│   ├── file1.jpg
│   ├── file2.jpg
│   ...
└── class_2
    ├── file1.jpg
    ├── file2.jpg
    ...

For prediction, the folder is expected to contain the files for inference, like this:

predict_folder
├── file1.jpg
├── file2.jpg
...

Example:

data_module = AudioClassificationData.from_folders(
    train_folder = "./train_folder",
    predict_folder = "./predict_folder",
    ...
)

from_files

Construct the AudioClassificationData from lists of files and corresponding lists of targets.

The supported file extensions are: .jpg, .jpeg, .png, .ppm, .bmp, .pgm, .tif, .tiff, .webp, .npy.

Example:

train_files = ["file1.jpg", "file2.jpg", "file3.jpg", ...]
train_targets = [0, 1, 0, ...]

datamodule = AudioClassificationData.from_files(
    train_files = train_files,
    train_targets = train_targets,
    ...
)

from_datasets

Construct the AudioClassificationData from the given datasets for each stage.

Example:

from torch.utils.data.dataset import Dataset

train_dataset: Dataset = ...

datamodule = AudioClassificationData.from_datasets(
    train_dataset = train_dataset,
    ...
)

Note

The __getitem__ of your datasets should return a dictionary with "input" and "target" keys which map to the input spectrogram image (as a NumPy array) and the target (as an int or list of ints) respectively.
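
For illustration, here’s a minimal sketch of such a dataset; the class name and the .npy spectrogram files are hypothetical:

from typing import List

import numpy as np
from torch.utils.data import Dataset


class MySpectrogramDataset(Dataset):
    """Hypothetical dataset that returns the dictionary format expected by from_datasets."""

    def __init__(self, spectrogram_files: List[str], targets: List[int]):
        self.spectrogram_files = spectrogram_files
        self.targets = targets

    def __len__(self):
        return len(self.spectrogram_files)

    def __getitem__(self, index):
        # "input" is the spectrogram as a NumPy array, "target" is an int label
        return {"input": np.load(self.spectrogram_files[index]), "target": self.targets[index]}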

Speech Recognition

The Task

Speech recognition is the task of transcribing audio into text. We rely on Wav2Vec as our backbone, fine-tuned on labeled transcriptions for speech-to-text. Wav2Vec is pre-trained on thousands of hours of unlabeled audio, providing a strong baseline when fine-tuning to downstream tasks such as speech recognition.


Example

Let’s fine-tune the model on our own labeled audio transcription data.

Here’s the structure of our CSV file:

file,text
"/path/to/file_1.wav","what was said in file 1."
"/path/to/file_2.wav","what was said in file 2."
"/path/to/file_3.wav","what was said in file 3."
...

Alternatively, here is the structure of our JSON file:

{"file": "/path/to/file_1.wav", "text": "what was said in file 1."}
{"file": "/path/to/file_2.wav", "text": "what was said in file 2."}
{"file": "/path/to/file_3.wav", "text": "what was said in file 3."}

Once we’ve downloaded the data using download_data(), we create the SpeechRecognitionData. We select a pre-trained Wav2Vec backbone to use for our SpeechRecognition and finetune on a subset of the TIMIT corpus. The backbone can be any Wav2Vec model from HuggingFace transformers. Next, we use the trained SpeechRecognition for inference and save the model. Here’s the full example:

import torch

import flash
from flash.audio import SpeechRecognition, SpeechRecognitionData
from flash.core.data.utils import download_data

# 1. Create the DataModule
download_data("https://pl-flash-data.s3.amazonaws.com/timit_data.zip", "./data")

datamodule = SpeechRecognitionData.from_json(
    input_fields="file",
    target_fields="text",
    train_file="data/timit/train.json",
    test_file="data/timit/test.json",
)

# 2. Build the task
model = SpeechRecognition(backbone="facebook/wav2vec2-base-960h")

# 3. Create the trainer and finetune the model
trainer = flash.Trainer(max_epochs=1, gpus=torch.cuda.device_count())
trainer.finetune(model, datamodule=datamodule, strategy="no_freeze")

# 4. Predict on audio files!
predictions = model.predict(["data/timit/example.wav"])
print(predictions)

# 5. Save the model!
trainer.save_checkpoint("speech_recognition_model.pt")

Flash Zero

The speech recognition task can be used directly from the command line with zero code using Flash Zero. You can run the above example with:

flash speech_recognition

To view configuration options and options for running the speech recognition task with your own data, use:

flash speech_recognition --help

Serving

The SpeechRecognition is servable. This means you can call .serve to serve your Task. Here’s an example:

from flash.audio import SpeechRecognition

model = SpeechRecognition.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/speech_recognition_model.pt")
model.serve()

You can now perform inference from your client like this:

import base64
from pathlib import Path

import requests

import flash

with (Path(flash.ASSETS_ROOT) / "example.wav").open("rb") as f:
    audio_str = base64.b64encode(f.read()).decode("UTF-8")

body = {"session": "UUID", "payload": {"inputs": {"data": audio_str}}}
resp = requests.post("http://127.0.0.1:8000/predict", json=body)

print(resp.json())

Tabular Classification

The Task

Tabular classification is the task of assigning a class to samples of structured or relational data. The TabularClassifier task can be used for classification of samples in more than two classes (multi-class classification).


Example

Let’s look at training a model to predict passenger survival on the Titanic using the classic Kaggle data set. The data is provided in CSV files that look like this:

PassengerId,Survived,Pclass,Name,Sex,Age,SibSp,Parch,Ticket,Fare,Cabin,Embarked
1,0,3,"Braund, Mr. Owen Harris",male,22,1,0,A/5 21171,7.25,,S
3,1,3,"Heikkinen, Miss. Laina",female,26,0,0,STON/O2. 3101282,7.925,,S
5,0,3,"Allen, Mr. William Henry",male,35,0,0,373450,8.05,,S
6,0,3,"Moran, Mr. James",male,,0,0,330877,8.4583,,Q
...

Once we’ve downloaded the data using download_data(), we can create the TabularClassificationData from our CSV files using the from_csv() method. From the API reference, we need to provide:

  • cat_cols - A list of the names of columns that contain categorical data (strings or integers).

  • num_cols - A list of the names of columns that contain numerical continuous data (floats).

  • target - The name of the column we want to predict.

  • train_csv - A CSV file containing the training data, converted to a Pandas DataFrame.

Next, we create the TabularClassifier and finetune on the Titanic data. We then use the trained TabularClassifier for inference. Finally, we save the model. Here’s the full example:

import torch

import flash
from flash.core.data.utils import download_data
from flash.tabular import TabularClassificationData, TabularClassifier

# 1. Create the DataModule
download_data("https://pl-flash-data.s3.amazonaws.com/titanic.zip", "./data")

datamodule = TabularClassificationData.from_csv(
    ["Sex", "Age", "SibSp", "Parch", "Ticket", "Cabin", "Embarked"],
    "Fare",
    target_fields="Survived",
    train_file="data/titanic/titanic.csv",
    val_split=0.1,
)

# 2. Build the task
model = TabularClassifier.from_data(datamodule)

# 3. Create the trainer and train the model
trainer = flash.Trainer(max_epochs=3, gpus=torch.cuda.device_count())
trainer.fit(model, datamodule=datamodule)

# 4. Generate predictions from a CSV
predictions = model.predict("data/titanic/titanic.csv")
print(predictions)

# 5. Save the model!
trainer.save_checkpoint("tabular_classification_model.pt")

Flash Zero

The tabular classifier can be used directly from the command line with zero code using Flash Zero. You can run the above example with:

flash tabular_classifier

To view configuration options and options for running the tabular classifier with your own data, use:

flash tabular_classifier --help

Serving

The TabularClassifier is servable. This means you can call .serve to serve your Task. Here’s an example:

from flash.core.classification import Labels
from flash.tabular import TabularClassifier

model = TabularClassifier.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/tabular_classification_model.pt")
model.serializer = Labels(["Did not survive", "Survived"])
model.serve()

You can now perform inference from your client like this:

import pandas as pd
import requests

from flash.core.data.utils import download_data

# 1. Download the data
download_data("https://pl-flash-data.s3.amazonaws.com/titanic.zip", "data/")

df = pd.read_csv("./data/titanic/predict.csv")
text = str(df.to_csv())
body = {"session": "UUID", "payload": {"inputs": {"data": text}}}
resp = requests.post("http://127.0.0.1:8000/predict", json=body)
print(resp.json())

Text Classification

The Task

Text classification is the task of assigning a piece of text (word, sentence or document) an appropriate class, or category. The categories depend on the chosen data set and can range from broad topics to sentiment, as in the example below.


Example

Let’s train a model to classify text as expressing either positive or negative sentiment. We will be using the IMDB data set, which contains a train.csv and valid.csv. Here’s the structure:

review,sentiment
"Japanese indie film with humor ... ",positive
"Isaac Florentine has made some ...",negative
"After seeing the low-budget ...",negative
"I've seen the original English version ...",positive
"Hunters chase what they think is a man through ...",negative
...

Once we’ve downloaded the data using download_data(), we create the TextClassificationData. We select a pre-trained backbone to use for our TextClassifier and finetune on the IMDB data. The backbone can be any BERT classification model from HuggingFace/transformers.

Note

When changing the backbone, make sure you pass in the same backbone to the TextClassifier and the TextClassificationData!

Next, we use the trained TextClassifier for inference. Finally, we save the model. Here’s the full example:

import torch

import flash
from flash.core.data.utils import download_data
from flash.text import TextClassificationData, TextClassifier

# 1. Create the DataModule
download_data("https://pl-flash-data.s3.amazonaws.com/imdb.zip", "./data/")

datamodule = TextClassificationData.from_csv(
    "review",
    "sentiment",
    train_file="data/imdb/train.csv",
    val_file="data/imdb/valid.csv",
    backbone="prajjwal1/bert-medium",
)

# 2. Build the task
model = TextClassifier(backbone="prajjwal1/bert-medium", num_classes=datamodule.num_classes)

# 3. Create the trainer and finetune the model
trainer = flash.Trainer(max_epochs=3, gpus=torch.cuda.device_count())
trainer.finetune(model, datamodule=datamodule, strategy="freeze")

# 4. Classify a few sentences! How was the movie?
predictions = model.predict(
    [
        "Turgid dialogue, feeble characterization - Harvey Keitel a judge?.",
        "The worst movie in the history of cinema.",
        "I come from Bulgaria where it 's almost impossible to have a tornado.",
    ]
)
print(predictions)

# 5. Save the model!
trainer.save_checkpoint("text_classification_model.pt")

Flash Zero

The text classifier can be used directly from the command line with zero code using Flash Zero. You can run the above example with:

flash text_classification

To view configuration options and options for running the text classifier with your own data, use:

flash text_classification --help

Serving

The TextClassifier is servable. This means you can call .serve to serve your Task. Here’s an example:

from flash.text import TextClassifier

model = TextClassifier.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/text_classification_model.pt")
model.serve()

You can now perform inference from your client like this:

import requests

text = "Best movie ever"
body = {"session": "UUID", "payload": {"inputs": {"data": text}}}
resp = requests.post("http://127.0.0.1:8000/predict", json=body)

print(resp.json())

Accelerate Training & Inference with Torch ORT

Torch ORT converts your model into an optimized ONNX graph, speeding up training & inference when using NVIDIA or AMD GPUs. Once installed, enabling Torch ORT requires a single flag passed to the TextClassifier. See installation instructions here.

Note

Not all Transformer models are supported. See this table for supported models + branches containing fixes for certain models.

...

model = TextClassifier(backbone="facebook/bart-large", num_classes=datamodule.num_classes, enable_ort=True)

Multi-label Text Classification

The Task

Multi-label classification is the task of assigning a number of labels from a fixed set to each data point, which can be in any modality (text in this case). Multi-label text classification is supported by the TextClassifier via the multi-label argument.


Example

Let’s look at the task of classifying comment toxicity. The data we will use in this example is from the kaggle toxic comment classification challenge by jigsaw: www.kaggle.com/c/jigsaw-toxic-comment-classification-challenge. The data is stored in CSV files with this structure:

"id","comment_text","toxic","severe_toxic","obscene","threat","insult","identity_hate"
"0000997932d777bf","...",0,0,0,0,0,0
"0002bcb3da6cb337","...",1,1,1,0,1,0
"0005c987bdfc9d4b","...",1,0,0,0,0,0
...

Once we’ve downloaded the data using download_data(), we create the TextClassificationData. We select a pre-trained backbone to use for our TextClassifier and finetune on the toxic comments data. The backbone can be any BERT classification model from HuggingFace/transformers.

Note

When changing the backbone, make sure you pass in the same backbone to the TextClassifier and the TextClassificationData!

Next, we use the trained TextClassifier for inference. Finally, we save the model. Here’s the full example:

import torch

import flash
from flash.core.data.utils import download_data
from flash.text import TextClassificationData, TextClassifier

# 1. Create the DataModule
# Data from the Kaggle Toxic Comment Classification Challenge:
# https://www.kaggle.com/c/jigsaw-toxic-comment-classification-challenge
download_data("https://pl-flash-data.s3.amazonaws.com/jigsaw_toxic_comments.zip", "./data")

datamodule = TextClassificationData.from_csv(
    "comment_text",
    ["toxic", "severe_toxic", "obscene", "threat", "insult", "identity_hate"],
    train_file="data/jigsaw_toxic_comments/train.csv",
    val_split=0.1,
    backbone="unitary/toxic-bert",
)

# 2. Build the task
model = TextClassifier(
    backbone="unitary/toxic-bert",
    num_classes=datamodule.num_classes,
    multi_label=True,
)

# 3. Create the trainer and finetune the model
trainer = flash.Trainer(max_epochs=1, gpus=torch.cuda.device_count())
trainer.finetune(model, datamodule=datamodule, strategy="freeze")

# 4. Generate predictions for a few comments!
predictions = model.predict(
    [
        "No, he is an arrogant, self serving, immature idiot. Get it right.",
        "U SUCK HANNAH MONTANA",
        "Would you care to vote? Thx.",
    ]
)
print(predictions)

# 5. Save the model!
trainer.save_checkpoint("text_classification_multi_label_model.pt")

Flash Zero

The multi-label text classifier can be used directly from the command line with zero code using Flash Zero. You can run the above example with:

flash text_classification from_toxic

To view configuration options and options for running the text classifier with your own data, use:

flash text_classification --help

Serving

The TextClassifier is servable. For more information, see Text Classification.

Question Answering

The Task

Question Answering is the task of answering questions pertaining to some known context. For example, given a context about some historical figure, any question pertaining to the context should be answerable. In our case, the context and the question are the inputs, and the answer is the output sequence from the model.

Note

We currently only support extractive question answering, like the task performed on SQuAD-like datasets.


Example

Let’s look at an example. We’ll use the SQuAD 2.0 dataset, which contains train-v2.0.json and dev-v2.0.json. Each JSON file looks like this:

{
            "answers": {
                    "answer_start": [94, 87, 94, 94],
                    "text": ["10th and 11th centuries", "in the 10th and 11th centuries", "10th and 11th centuries", "10th and 11th centuries"]
            },
            "context": "\"The Normans (Norman: Nourmands; French: Normands; Latin: Normanni) were the people who in the 10th and 11th centuries gave thei...",
            "id": "56ddde6b9a695914005b9629",
            "question": "When were the Normans in Normandy?",
            "title": "Normans"
    }
...

In the above, the context key represents the context used for the question and answer, the question key represents the question being asked with respect to the context, and the answer key stores the answer(s) to the question. The id and title keys are used for unique identification and for grouping concepts together, respectively. Once we’ve downloaded the data using download_data(), we create the QuestionAnsweringData. We select a pre-trained backbone to use for our QuestionAnsweringTask and finetune on the SQuAD 2.0 data. The backbone can be any Question Answering model from HuggingFace/transformers.

Note

When changing the backbone, make sure you pass in the same backbone to the QuestionAnsweringData and the QuestionAnsweringTask!

Next, we use the trained QuestionAnsweringTask for inference. Finally, we save the model. Here’s the full example:

from flash import Trainer
from flash.core.data.utils import download_data
from flash.text import QuestionAnsweringData, QuestionAnsweringTask

# 1. Create the DataModule
download_data("https://pl-flash-data.s3.amazonaws.com/squad_tiny.zip", "./data/")

datamodule = QuestionAnsweringData.from_squad_v2(
    train_file="./data/squad_tiny/train.json",
    val_file="./data/squad_tiny/val.json",
)

# 2. Build the task
model = QuestionAnsweringTask()

# 3. Create the trainer and finetune the model
trainer = Trainer(max_epochs=3, limit_train_batches=1, limit_val_batches=1)
trainer.finetune(model, datamodule=datamodule)

# 4. Answer some Questions!
predictions = model.predict(
    {
        "id": ["56ddde6b9a695914005b9629", "56ddde6b9a695914005b9628"],
        "context": [
            """
        The Normans (Norman: Nourmands; French: Normands; Latin: Normanni) were the people who in the 10th
        and 11th centuries gave their name to Normandy, a region in France. They were descended from Norse
        ("Norman" comes from "Norseman") raiders and pirates from Denmark, Iceland and Norway who, under
        their leader Rollo, agreed to swear fealty to King Charles III of West Francia. Through generations
        of assimilation and mixing with the native Frankish and Roman-Gaulish populations, their
        descendants would gradually merge with the Carolingian-based cultures of West Francia. The distinct
        cultural and ethnic identity of the Normans emerged initially in the first half of the 10th
        century, and it continued to evolve over the succeeding centuries.
        """,
            """
        The Normans (Norman: Nourmands; French: Normands; Latin: Normanni) were the people who in the 10th
        and 11th centuries gave their name to Normandy, a region in France. They were descended from Norse
        ("Norman" comes from "Norseman") raiders and pirates from Denmark, Iceland and Norway who, under
        their leader Rollo, agreed to swear fealty to King Charles III of West Francia. Through generations
        of assimilation and mixing with the native Frankish and Roman-Gaulish populations, their
        descendants would gradually merge with the Carolingian-based cultures of West Francia. The distinct
        cultural and ethnic identity of the Normans emerged initially in the first half of the 10th
        century, and it continued to evolve over the succeeding centuries.
        """,
        ],
        "question": ["When were the Normans in Normandy?", "In what country is Normandy located?"],
    }
)
print(predictions)

# 5. Save the model!
trainer.save_checkpoint("question_answering_on_sqaud_v2.pt")

Accelerate Training & Inference with Torch ORT

Torch ORT converts your model into an optimized ONNX graph, speeding up training & inference when using NVIDIA or AMD GPUs. Once installed, enabling Torch ORT requires a single flag passed to the QuestionAnsweringTask. See installation instructions here.

Note

Not all Transformer models are supported. See this table for supported models + branches containing fixes for certain models.

...

model = QuestionAnsweringTask(backbone="distilbert-base-uncased", max_answer_length=30, enable_ort=True)

Summarization

The Task

Summarization is the task of summarizing text from a larger document/article into a short sentence/description. For example, taking a web article and describing the topic in a short sentence. This task is a subset of Sequence to Sequence tasks, which require the model to generate a variable length sequence given an input sequence. In our case the article would be our input sequence, and the short description/sentence would be the output sequence from the model.


Example

Let’s look at an example. We’ll use the XSUM dataset, which contains a train.csv and valid.csv. Each CSV file looks like this:

input,target
"The researchers have sequenced the genome of a strain of bacterium that causes the virulent infection...","A team of UK scientists hopes to shed light on the mysteries of bleeding canker, a disease that is threatening the nation's horse chestnut trees."
"Knight was shot in the leg by an unknown gunman at Miami's Shore Club where West was holding a pre-MTV Awards...",Hip hop star Kanye West is being sued by Death Row Records founder Suge Knight over a shooting at a beach party in August 2005.
...

In the above, the input column contains the long articles/documents, and the target column is the short summary used as the label. Once we’ve downloaded the data using download_data(), we create the SummarizationData. We select a pre-trained backbone to use for our SummarizationTask and finetune on the XSUM data. The backbone can be any Seq2Seq summarization model from HuggingFace/transformers.

Note

When changing the backbone, make sure you pass in the same backbone to the SummarizationData and the SummarizationTask!

Next, we use the trained SummarizationTask for inference. Finally, we save the model. Here’s the full example:

from flash import Trainer
from flash.core.data.utils import download_data
from flash.text import SummarizationData, SummarizationTask

# 1. Create the DataModule
download_data("https://pl-flash-data.s3.amazonaws.com/xsum.zip", "./data/")

datamodule = SummarizationData.from_csv(
    "input",
    "target",
    train_file="data/xsum/train.csv",
    val_file="data/xsum/valid.csv",
)

# 2. Build the task
model = SummarizationTask()

# 3. Create the trainer and finetune the model
trainer = Trainer(max_epochs=3)
trainer.finetune(model, datamodule=datamodule)

# 4. Summarize some text!
predictions = model.predict(
    """
    Camilla bought a box of mangoes with a Brixton £10 note, introduced last year to try to keep the money of local
    people within the community.The couple were surrounded by shoppers as they walked along Electric Avenue.
    They came to Brixton to see work which has started to revitalise the borough.
    It was Charles' first visit to the area since 1996, when he was accompanied by the former
    South African president Nelson Mandela.Greengrocer Derek Chong, who has run a stall on Electric Avenue
    for 20 years, said Camilla had been ""nice and pleasant"" when she purchased the fruit.
    ""She asked me what was nice, what would I recommend, and I said we've got some nice mangoes.
    She asked me were they ripe and I said yes - they're from the Dominican Republic.""
    Mr Chong is one of 170 local retailers who accept the Brixton Pound.
    Customers exchange traditional pound coins for Brixton Pounds and then spend them at the market
    or in participating shops.
    During the visit, Prince Charles spent time talking to youth worker Marcus West, who works with children
    nearby on an estate off Coldharbour Lane. Mr West said:
    ""He's on the level, really down-to-earth. They were very cheery. The prince is a lovely man.""
    He added: ""I told him I was working with young kids and he said, 'Keep up all the good work.'""
    Prince Charles also visited the Railway Hotel, at the invitation of his charity The Prince's Regeneration Trust.
    The trust hopes to restore and refurbish the building,
    where once Jimi Hendrix and The Clash played, as a new community and business centre."
    """
)
print(predictions)

# 5. Save the model!
trainer.save_checkpoint("summarization_model_xsum.pt")

Flash Zero

The summarization task can be used directly from the command line with zero code using Flash Zero. You can run the above example with:

flash summarization

To view configuration options and options for running the summarization task with your own data, use:

flash summarization --help

Serving

The SummarizationTask is servable. This means you can call .serve to serve your Task. Here’s an example:

from flash.text import SummarizationTask

model = SummarizationTask.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/summarization_model_xsum.pt")
model.serve()

You can now perform inference from your client like this:

import requests

text = """
Camilla bought a box of mangoes with a Brixton £10 note, introduced last year to try to keep the money of local
people within the community.The couple were surrounded by shoppers as they walked along Electric Avenue.
They came to Brixton to see work which has started to revitalise the borough.
It was Charles' first visit to the area since 1996, when he was accompanied by the former
South African president Nelson Mandela.Greengrocer Derek Chong, who has run a stall on Electric Avenue
for 20 years, said Camilla had been ""nice and pleasant"" when she purchased the fruit.
""She asked me what was nice, what would I recommend, and I said we've got some nice mangoes.
She asked me were they ripe and I said yes - they're from the Dominican Republic.""
Mr Chong is one of 170 local retailers who accept the Brixton Pound.
Customers exchange traditional pound coins for Brixton Pounds and then spend them at the market
or in participating shops.
During the visit, Prince Charles spent time talking to youth worker Marcus West, who works with children
nearby on an estate off Coldharbour Lane. Mr West said:
""He's on the level, really down-to-earth. They were very cheery. The prince is a lovely man.""
He added: ""I told him I was working with young kids and he said, 'Keep up all the good work.'""
Prince Charles also visited the Railway Hotel, at the invitation of his charity The Prince's Regeneration Trust.
The trust hopes to restore and refurbish the building,
where once Jimi Hendrix and The Clash played, as a new community and business centre."
"""
body = {"session": "UUID", "payload": {"inputs": {"data": text}}}
resp = requests.post("http://127.0.0.1:8000/predict", json=body)

print(resp.json())

Accelerate Training & Inference with Torch ORT

Torch ORT converts your model into an optimized ONNX graph, speeding up training & inference when using NVIDIA or AMD GPUs. Once Torch ORT is installed, enabling it only requires passing a single flag to the SummarizationTask. See installation instructions here.

Note

Not all Transformer models are supported. See this table for supported models + branches containing fixes for certain models.

...

model = SummarizationTask(backbone="t5-large", enable_ort=True)

Translation

The Task

Translation is the task of translating text from a source language into another, such as English to Romanian. This task is a subset of Sequence to Sequence tasks, which require the model to generate a variable length sequence given an input sequence. In our case, the task will take an English sequence as input, and output the same sequence in Romanian.


Example

Let’s look at an example. We’ll use WMT16 English/Romanian, a dataset of English to Romanian samples, based on the Europarl corpora. The data set contains a train.csv and valid.csv. Each CSV file looks like this:

input,target
"Written statements and oral questions (tabling): see Minutes","Declaraţii scrise şi întrebări orale (depunere): consultaţi procesul-verbal"
"Closure of sitting","Ridicarea şedinţei"
...

In the above, the input and target columns contain the English text and its Romanian translation, respectively. Once we’ve downloaded the data using download_data(), we create the TranslationData. We select a pre-trained backbone to use for our TranslationTask and finetune on the WMT16 data. The backbone can be any Seq2Seq translation model from HuggingFace/transformers.

Note

When changing the backbone, make sure you pass in the same backbone to the TranslationData and the TranslationTask!

Next, we use the trained TranslationTask for inference. Finally, we save the model. Here’s the full example:

import torch

import flash
from flash.core.data.utils import download_data
from flash.text import TranslationData, TranslationTask

# 1. Create the DataModule
download_data("https://pl-flash-data.s3.amazonaws.com/wmt_en_ro.zip", "./data")

datamodule = TranslationData.from_csv(
    "input",
    "target",
    train_file="data/wmt_en_ro/train.csv",
    val_file="data/wmt_en_ro/valid.csv",
    backbone="Helsinki-NLP/opus-mt-en-ro",
)

# 2. Build the task
model = TranslationTask(backbone="Helsinki-NLP/opus-mt-en-ro")

# 3. Create the trainer and finetune the model
trainer = flash.Trainer(max_epochs=3, gpus=torch.cuda.device_count())
trainer.finetune(model, datamodule=datamodule)

# 4. Translate something!
predictions = model.predict(
    [
        "BBC News went to meet one of the project's first graduates.",
        "A recession has come as quickly as 11 months after the first rate hike and as long as 86 months.",
        "Of course, it's still early in the election cycle.",
    ]
)
print(predictions)

# 5. Save the model!
trainer.save_checkpoint("translation_model_en_ro.pt")

Flash Zero

The translation task can be used directly from the command line with zero code using Flash Zero. You can run the above example with:

flash translation

To view configuration options and options for running the translation task with your own data, use:

flash translation --help

Serving

The TranslationTask is servable. This means you can call .serve to serve your Task. Here’s an example:

from flash.text import TranslationTask

model = TranslationTask.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/translation_model_en_ro.pt")
model.serve()

You can now perform inference from your client like this:

import requests

text = "Some English text"
body = {"session": "UUID", "payload": {"inputs": {"data": text}}}
resp = requests.post("http://127.0.0.1:8000/predict", json=body)

print(resp.json())

Accelerate Training & Inference with Torch ORT

Torch ORT converts your model into an optimized ONNX graph, speeding up training & inference when using NVIDIA or AMD GPUs. Once Torch ORT is installed, enabling it only requires passing a single flag to the TranslationTask. See installation instructions here.

Note

Not all Transformer models are supported. See this table for supported models + branches containing fixes for certain models.

...

model = TranslationTask(backbone="t5-large", enable_ort=True)

Point Cloud Segmentation

The Task

A Point Cloud is a set of data points in space, usually described by x, y and z coordinates.

PointCloud Segmentation is the task of performing classification at the point level, meaning each point is associated with a given class. The current integration builds on top of Open3D-ML.


Example

Let’s look at an example using a dataset generated from the KITTI Vision Benchmark. The data is a tiny subset of the original dataset and contains sequences of point clouds. It contains multiple folders, one for each sequence, and a meta.yaml file describing the classes and their official associated color map. A sequence should contain one folder for scans and one folder for labels, plus a pose.txt to re-align the sequence if required. Here’s the structure:

data
├── meta.yaml
├── 00
│   ├── scans
|   |    ├── 00000.bin
|   |    ├── 00001.bin
|   |    ...
│   ├── labels
|   |    ├── 00000.label
|   |    ├── 00001.label
|   |   ...
|   ├── pose.txt
│   ...
|
└── XX
   ├── scans
   |    ├── 00000.bin
   |    ├── 00001.bin
   |    ...
   ├── labels
   |    ├── 00000.label
   |    ├── 00001.label
   |   ...
   ├── pose.txt

Learn more: http://www.semantic-kitti.org/dataset.html

Once we’ve downloaded the data using download_data(), we create the PointCloudSegmentationData. We select a pre-trained randlanet_semantic_kitti backbone for our PointCloudSegmentation task. We then use the trained PointCloudSegmentation for inference. Finally, we save the model. Here’s the full example:

import torch

import flash
from flash.core.data.utils import download_data
from flash.pointcloud import PointCloudSegmentation, PointCloudSegmentationData

# 1. Create the DataModule
# Dataset Credit: http://www.semantic-kitti.org/
download_data("https://pl-flash-data.s3.amazonaws.com/SemanticKittiTiny.zip", "data/")

datamodule = PointCloudSegmentationData.from_folders(
    train_folder="data/SemanticKittiTiny/train",
    val_folder="data/SemanticKittiTiny/val",
)

# 2. Build the task
model = PointCloudSegmentation(backbone="randlanet_semantic_kitti", num_classes=datamodule.num_classes)

# 3. Create the trainer and finetune the model
trainer = flash.Trainer(
    max_epochs=1, limit_train_batches=1, limit_val_batches=1, num_sanity_val_steps=0, gpus=torch.cuda.device_count()
)
trainer.fit(model, datamodule)

# 4. Predict what's within a few PointClouds?
predictions = model.predict(
    [
        "data/SemanticKittiTiny/predict/000000.bin",
        "data/SemanticKittiTiny/predict/000001.bin",
    ]
)

# 5. Save the model!
trainer.save_checkpoint("pointcloud_segmentation_model.pt")
https://raw.githubusercontent.com/intel-isl/Open3D-ML/master/docs/images/getting_started_ml_visualizer.gif

Flash Zero

The point cloud segmentation task can be used directly from the command line with zero code using Flash Zero. You can run the above example with:

flash pointcloud_segmentation

To view configuration options and options for running the point cloud segmentation task with your own data, use:

flash pointcloud_segmentation --help

Point Cloud Object Detection

The Task

A Point Cloud is a set of data points in space, usually described by x, y and z coordinates.

PointCloud Object Detection is the task of identifying 3D objects in point clouds and their associated classes and 3D bounding boxes.

The current integration builds on top of Open3D-ML.


Example

Let’s look at an example using a dataset generated from the KITTI Vision Benchmark. The data is a tiny subset of the original dataset and contains sequences of point clouds.

The data contains:
  • one folder for scans

  • one folder for scan calibrations

  • one folder for labels

  • a meta.yaml file describing the classes and their official associated color map.

Here’s the structure:

data
├── meta.yaml
├── train
│   ├── scans
|   |    ├── 00000.bin
|   |    ├── 00001.bin
|   |    ...
│   ├── calibs
|   |    ├── 00000.txt
|   |    ├── 00001.txt
|   |   ...
│   ├── labels
|   |    ├── 00000.txt
|   |    ├── 00001.txt
│   ...
├── val
│   ...
├── predict
    ├── scans
    |   ├── 00000.bin
    |   ├── 00001.bin
    |
    ├── calibs
    |   ├── 00000.txt
    |   ├── 00001.txt
    ├── meta.yaml

Learn more: http://www.semantic-kitti.org/dataset.html

Once we’ve downloaded the data using download_data(), we create the PointCloudObjectDetectorData. We select a pre-trained pointpillars_kitti backbone for our PointCloudObjectDetector task. We then use the trained PointCloudObjectDetector for inference. Finally, we save the model. Here’s the full example:

import torch

import flash
from flash.core.data.utils import download_data
from flash.pointcloud import PointCloudObjectDetector, PointCloudObjectDetectorData

# 1. Create the DataModule
# Dataset Credit: http://www.semantic-kitti.org/
download_data("https://pl-flash-data.s3.amazonaws.com/KITTI_tiny.zip", "data/")

datamodule = PointCloudObjectDetectorData.from_folders(
    train_folder="data/KITTI_Tiny/Kitti/train",
    val_folder="data/KITTI_Tiny/Kitti/val",
)

# 2. Build the task
model = PointCloudObjectDetector(backbone="pointpillars_kitti", num_classes=datamodule.num_classes)

# 3. Create the trainer and finetune the model
trainer = flash.Trainer(
    max_epochs=1, limit_train_batches=1, limit_val_batches=1, num_sanity_val_steps=0, gpus=torch.cuda.device_count()
)
trainer.fit(model, datamodule)

# 4. Predict what's within a few PointClouds?
predictions = model.predict(
    [
        "data/KITTI_Tiny/Kitti/predict/scans/000000.bin",
        "data/KITTI_Tiny/Kitti/predict/scans/000001.bin",
    ]
)

# 5. Save the model!
trainer.save_checkpoint("pointcloud_detection_model.pt")
https://raw.githubusercontent.com/intel-isl/Open3D-ML/master/docs/images/visualizer_BoundingBoxes.png

Flash Zero

The point cloud object detector can be used directly from the command line with zero code using Flash Zero. You can run the above example with:

flash pointcloud_detection

To view configuration options and options for running the point cloud object detector with your own data, use:

flash pointcloud_detection --help

Graph Classification

The Task

This task consists of classifying graphs. The task predicts which ‘class’ a graph belongs to. A class is a label that indicates the kind of graph. For example, a label may indicate whether one molecule interacts with another.

The GraphClassifier and GraphClassificationData classes internally rely on pytorch-geometric.


Example

Let’s look at the task of classifying graphs from the KKI data set from TU Dortmund University.

Once we’ve created the TUDataset, we create the GraphClassificationData. We then create our GraphClassifier and train on the KKI data. Next, we use the trained GraphClassifier for inference. Finally, we save the model. Here’s the full example:

import torch

import flash
from flash.core.utilities.imports import example_requires
from flash.graph import GraphClassificationData, GraphClassifier

example_requires("graph")

from torch_geometric.datasets import TUDataset  # noqa: E402

# 1. Create the DataModule
dataset = TUDataset(root="data", name="KKI")

datamodule = GraphClassificationData.from_datasets(
    train_dataset=dataset,
    val_split=0.1,
)

# 2. Build the task
model = GraphClassifier(num_features=datamodule.num_features, num_classes=datamodule.num_classes)

# 3. Create the trainer and fit the model
trainer = flash.Trainer(max_epochs=3, gpus=torch.cuda.device_count())
trainer.fit(model, datamodule=datamodule)

# 4. Classify some graphs!
predictions = model.predict(dataset[:3])
print(predictions)

# 5. Save the model!
trainer.save_checkpoint("graph_classification.pt")

Flash Zero

The graph classifier can be used directly from the command line with zero code using Flash Zero. You can run the above example with:

flash graph_classification

To view configuration options and options for running the graph classifier with your own data, use:

flash graph_classification --help

Providers

Flash is a framework integrator. We rely on many open source frameworks for our tasks, visualizations and backbones. Here’s a list of some of the providers we use for backbones and heads within Flash (check them out and star their repos to spread the open source love!):

You can also read our guides for some of our larger integrations:

BaaL

Bayesian Active Learning (BaaL) is an active learning library developed at ElementAI.

Active Learning is a sub-field of AI focused on adding a human to the learning loop. The most uncertain samples are sent to the human for labelling, accelerating the model training cycle.

Credit to the ElementAI / BaaL team for creating this flow diagram.


With its integration within Flash, the Active Learning process is simpler than ever before.

import torch

import flash
from flash.core.classification import Probabilities
from flash.core.data.utils import download_data
from flash.image import ImageClassificationData, ImageClassifier
from flash.image.classification.integrations.baal import ActiveLearningDataModule, ActiveLearningLoop

# 1. Create the DataModule
download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip", "./data")

# Implement the research use-case where we mask labels from labelled dataset.
datamodule = ActiveLearningDataModule(
    ImageClassificationData.from_folders(train_folder="data/hymenoptera_data/train/", batch_size=2),
    initial_num_labels=5,
    val_split=0.1,
)

# 2. Build the task
head = torch.nn.Sequential(
    torch.nn.Dropout(p=0.1),
    torch.nn.Linear(512, datamodule.num_classes),
)
model = ImageClassifier(backbone="resnet18", head=head, num_classes=datamodule.num_classes, serializer=Probabilities())


# 3.1 Create the trainer
trainer = flash.Trainer(max_epochs=3)

# 3.2 Create the active learning loop and connect it to the trainer
active_learning_loop = ActiveLearningLoop(label_epoch_frequency=1)
active_learning_loop.connect(trainer.fit_loop)
trainer.fit_loop = active_learning_loop

# 3.3 Finetune
trainer.finetune(model, datamodule=datamodule, strategy="freeze")

# 4. Predict what's on a few images! ants or bees?
predictions = model.predict("data/hymenoptera_data/val/bees/65038344_52a45d090d.jpg")
print(predictions)

# 5. Save the model!
trainer.save_checkpoint("image_classification_model.pt")

FiftyOne

We have collaborated with the team at Voxel51 to integrate their tool, FiftyOne, into Lightning Flash.

FiftyOne is an open-source tool for building high-quality datasets and computer vision models. The FiftyOne API and App enable you to visualize datasets and interpret models faster and more effectively.

This integration allows you to view predictions generated by your tasks in the FiftyOne App, as well as easily incorporate FiftyOne Datasets into your tasks. All image and video tasks are supported!

Installation

In order to utilize this integration, you will need to install FiftyOne:

pip install fiftyone

Visualizing Flash predictions

This section shows you how to augment your existing Lightning Flash workflows with a couple of lines of code that let you visualize predictions in FiftyOne. You can visualize predictions for classification, object detection, and semantic segmentation tasks. Doing so is as easy as updating your model to use one of the FiftyOne serializers: FiftyOneLabels, FiftyOneDetectionLabels, or FiftyOneSegmentationLabels.

The visualize() function then lets you visualize your predictions in the FiftyOne App. This function accepts a list of dictionaries containing FiftyOne Label objects and filepaths, which is exactly the output of the FiftyOne serializers when the return_filepath=True option is specified.

from itertools import chain

import torch

import flash
from flash.core.classification import FiftyOneLabels, Labels
from flash.core.data.utils import download_data
from flash.core.finetuning import FreezeUnfreeze
from flash.core.integrations.fiftyone import visualize
from flash.image import ImageClassificationData, ImageClassifier

# 1 Download data
download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip")

# 2 Load data
datamodule = ImageClassificationData.from_folders(
    train_folder="data/hymenoptera_data/train/",
    val_folder="data/hymenoptera_data/val/",
    test_folder="data/hymenoptera_data/test/",
    predict_folder="data/hymenoptera_data/predict/",
)

# 3 Fine tune a model
model = ImageClassifier(
    backbone="resnet18",
    num_classes=datamodule.num_classes,
    serializer=Labels(),
)
trainer = flash.Trainer(
    max_epochs=1,
    gpus=torch.cuda.device_count(),
    limit_train_batches=1,
    limit_val_batches=1,
)
trainer.finetune(
    model,
    datamodule=datamodule,
    strategy=FreezeUnfreeze(unfreeze_epoch=1),
)
trainer.save_checkpoint("image_classification_model.pt")

# 4 Predict from checkpoint
model = ImageClassifier.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/image_classification_model.pt")
model.serializer = FiftyOneLabels(return_filepath=True)  # output FiftyOne format
predictions = trainer.predict(model, datamodule=datamodule)
predictions = list(chain.from_iterable(predictions))  # flatten batches

# 5 Visualize predictions in FiftyOne App
# Optional: pass `wait=True` to block execution until App is closed
session = visualize(predictions)

The visualize() function can be used in all of the following environments:

  • Local Python shell: The App will launch in a new tab in your default web browser

  • Remote Python shell: Pass the remote=True option and then follow the instructions printed to your remote shell to open the App in your browser on your local machine

  • Jupyter notebook: The App will launch in the output of your current cell

  • Google Colab: The App will launch in the output of your current cell

  • Python script: Pass the wait=True option to block execution of your script until the App is closed

See this page for more information about using the FiftyOne App in different environments.
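For example, a rough sketch of how these options are passed (reusing the predictions from the example above):

from flash.core.integrations.fiftyone import visualize

# In a remote Python shell: follow the printed instructions to open the App locally
session = visualize(predictions, remote=True)

# In a standalone script: block execution until the App is closed
session = visualize(predictions, wait=True)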

Using FiftyOne datasets

The above workflow is great for visualizing model predictions. However, if you store your data in a FiftyOne Dataset initially, then you can also visualize ground truth annotations. This allows you to perform more complex analysis with views into your data and evaluation of your model results.

The from_fiftyone() method allows you to load your FiftyOne datasets directly into a DataModule to be used for training, testing, or inference.

from itertools import chain

import fiftyone as fo
import torch

import flash
from flash.core.classification import FiftyOneLabels, Labels
from flash.core.data.utils import download_data
from flash.core.finetuning import FreezeUnfreeze
from flash.image import ImageClassificationData, ImageClassifier

# 1 Download data
download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip")

# 2 Load data into FiftyOne
train_dataset = fo.Dataset.from_dir(
    dataset_dir="data/hymenoptera_data/train/",
    dataset_type=fo.types.ImageClassificationDirectoryTree,
)
val_dataset = fo.Dataset.from_dir(
    dataset_dir="data/hymenoptera_data/val/",
    dataset_type=fo.types.ImageClassificationDirectoryTree,
)
test_dataset = fo.Dataset.from_dir(
    dataset_dir="data/hymenoptera_data/test/",
    dataset_type=fo.types.ImageClassificationDirectoryTree,
)

# 3 Load data into Flash
datamodule = ImageClassificationData.from_fiftyone(
    train_dataset=train_dataset,
    val_dataset=val_dataset,
    test_dataset=test_dataset,
)

# 4 Fine tune model
model = ImageClassifier(
    backbone="resnet18",
    num_classes=datamodule.num_classes,
    serializer=Labels(),
)
trainer = flash.Trainer(
    max_epochs=1,
    gpus=torch.cuda.device_count(),
    limit_train_batches=1,
    limit_val_batches=1,
)
trainer.finetune(
    model,
    datamodule=datamodule,
    strategy=FreezeUnfreeze(unfreeze_epoch=1),
)
trainer.save_checkpoint("image_classification_model.pt")

# 5 Predict from checkpoint on data with ground truth
model = ImageClassifier.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/image_classification_model.pt")
model.serializer = FiftyOneLabels(return_filepath=False)  # output FiftyOne format
datamodule = ImageClassificationData.from_fiftyone(predict_dataset=test_dataset)
predictions = trainer.predict(model, datamodule=datamodule)
predictions = list(chain.from_iterable(predictions))  # flatten batches

# 6 Add predictions to dataset
test_dataset.set_values("predictions", predictions)

# 7 Evaluate your model
results = test_dataset.evaluate_classifications("predictions", gt_field="ground_truth", eval_key="eval")
results.print_report()
plot = results.plot_confusion_matrix()
plot.show()

# 8 Visualize results in the App
session = fo.launch_app(test_dataset)

# Optional: block execution until App is closed
session.wait()

Visualizing embeddings

FiftyOne provides methods for dimensionality reduction and interactive plotting. When combined with embedding tasks in Flash, you can accomplish powerful workflows like clustering, similarity search, pre-annotation, and more in only a few lines of code.

import fiftyone as fo
import fiftyone.brain as fob
import numpy as np

from flash.core.data.utils import download_data
from flash.image import ImageEmbedder

# 1 Download data
download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip")

# 2 Load data into FiftyOne
dataset = fo.Dataset.from_dir(
    "data/hymenoptera_data/test/",
    fo.types.ImageClassificationDirectoryTree,
)

# 3 Load model
embedder = ImageEmbedder(backbone="resnet101")

# 4 Generate embeddings
filepaths = dataset.values("filepath")
embeddings = np.stack(embedder.predict(filepaths))

# 5 Visualize in FiftyOne App
results = fob.compute_visualization(dataset, embeddings=embeddings)
session = fo.launch_app(dataset)
plot = results.visualize(labels="ground_truth.label")
plot.show()

# Optional: block execution until App is closed
session.wait()

IceVision

IceVision from airctic is an awesome computer vision framework which offers a curated collection of hundreds of high-quality pre-trained models for: object detection, keypoint detection, and instance segmentation. In Flash, we’ve integrated the IceVision framework to provide: data loading, augmentation, backbones, and heads. We use IceVision components in our: object detection, instance segmentation, and keypoint detection tasks. Take a look at their documentation and star IceVision on GitHub to spread the open source love!

IceData

The IceData library is a community driven dataset hub for IceVision. All of the datasets in IceData can be used out of the box with flash using our .from_folders methods and the parser argument. Take a look at our Keypoint Detection page for an example.
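As a rough sketch (assuming the icedata package is installed; the BIWI keypoints dataset and its parser are used purely for illustration):

import icedata

from flash.image import KeypointDetectionData

# Download an IceData dataset and pass its IceVision parser to a Flash from_folders method
data_dir = icedata.biwi.load_data()

datamodule = KeypointDetectionData.from_folders(
    train_folder=data_dir,
    val_split=0.1,
    parser=icedata.biwi.parser,
)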

Albumentations with IceVision and Flash

IceVision provides two utilities for using the albumentations library with their models:

  • the Adapter helper class for adapting any albumentations transform to work with IceVision records,

  • the aug_tfms utility function that returns a standard augmentation recipe to get the most out of your model.

In Flash, we use the aug_tfms as default transforms for the: object detection, instance segmentation, and keypoint detection tasks. You can also provide custom transforms from albumentations using the IceVisionTransformAdapter (which relies on the IceVision Adapter underneath). Here’s an example:

import albumentations as A

from flash.core.integrations.icevision.transforms import IceVisionTransformAdapter
from flash.image import ObjectDetectionData

train_transform = {
    "pre_tensor_transform": IceVisionTransformAdapter([A.HorizontalFlip(), A.Normalize()]),
}

datamodule = ObjectDetectionData.from_coco(
    ...,
    train_transform=train_transform,
)

Learn2Learn

Learn2Learn is a software library for meta-learning research by Sébastien M. R. Arnold et al. (Aug 2020).


What is Meta-Learning and why should you care?

Humans can distinguish between new objects with little or no training data. However, machine learning models often require thousands or even millions of annotated samples to achieve good performance when extrapolating their learned knowledge to unseen objects.

A machine learning model which could learn, or learn to learn, from only a few new samples (K-shot learning) would have tremendous applications once deployed in production. In the extreme case, a model performing 1-shot or 0-shot learning could be the source of a whole new kind of AI application.

Meta-Learning is a sub-field of AI dedicated to the study of few-shot learning algorithms. It is often characterized as teaching deep learning models to learn from only a few labeled samples. The goal is to repeatedly learn from K-shot examples during training that match the structure of the final K-shot examples used in production. It is important to note that the K-shot examples seen in production are very likely to be completely out-of-distribution, containing new objects.

How does Meta-Learning work?

In meta-learning, the model is trained over multiple meta tasks. A meta task is the smallest unit of data; it represents the data that will be available to the model once it is in its deployment environment. Training this way optimises the model for the few-shot setting it will face in production.


For image classification, a meta task is comprised of shot and query elements for each class. The shot samples are used to adapt the model parameters, and the query samples are used to update the original model weights. The classes used for validation and testing shouldn’t be present in the training dataset, as the goal is to optimise the model’s performance on out-of-distribution (OOD) data with little labeled data.

When training with a meta-learning algorithm, the model averages its gradients over meta_batch_size meta tasks before performing an optimizer step. Traditionally, a meta epoch is composed of multiple meta batches.

Use Meta-Learning with Flash

With its integration within Flash, Meta-Learning has never been simpler. Flash takes care of all the hard work: task sampling, the meta optimizer update, distributed training, etc.

Note

You are required to provide a training dataset and a testing dataset with no overlapping classes; Flash doesn’t perform this class split for you out of the box.

Once done, all that is left is to tune the hyper-parameters associated with the meta-learning algorithm.

Here is an example using the miniImageNet dataset, which contains 100 classes divided into 64 training, 16 validation, and 20 test classes.

# adapted from https://github.com/learnables/learn2learn/blob/master/examples/vision/protonet_miniimagenet.py#L154

import warnings

import kornia.augmentation as Ka
import kornia.geometry as Kg
import learn2learn as l2l
import torch
import torchvision
from torch import nn

import flash
from flash.core.data.data_source import DefaultDataKeys
from flash.core.data.transforms import ApplyToKeys, kornia_collate
from flash.image import ImageClassificationData, ImageClassifier

warnings.simplefilter("ignore")

# download MiniImagenet
train_dataset = l2l.vision.datasets.MiniImagenet(root="data", mode="train", download=True)
val_dataset = l2l.vision.datasets.MiniImagenet(root="data", mode="validation", download=True)
test_dataset = l2l.vision.datasets.MiniImagenet(root="data", mode="test", download=True)

train_transform = {
    "to_tensor_transform": nn.Sequential(
        ApplyToKeys(DefaultDataKeys.INPUT, torchvision.transforms.ToTensor()),
        ApplyToKeys(DefaultDataKeys.TARGET, torch.as_tensor),
    ),
    "post_tensor_transform": ApplyToKeys(
        DefaultDataKeys.INPUT,
        Kg.Resize((196, 196)),
        # SPATIAL
        Ka.RandomHorizontalFlip(p=0.25),
        Ka.RandomRotation(degrees=90.0, p=0.25),
        Ka.RandomAffine(degrees=1 * 5.0, shear=1 / 5, translate=1 / 20, p=0.25),
        Ka.RandomPerspective(distortion_scale=1 / 25, p=0.25),
        # PIXEL-LEVEL
        Ka.ColorJitter(brightness=1 / 30, p=0.25),  # brightness
        Ka.ColorJitter(saturation=1 / 30, p=0.25),  # saturation
        Ka.ColorJitter(contrast=1 / 30, p=0.25),  # contrast
        Ka.ColorJitter(hue=1 / 30, p=0.25),  # hue
        Ka.RandomMotionBlur(kernel_size=2 * (4 // 3) + 1, angle=1, direction=1.0, p=0.25),
        Ka.RandomErasing(scale=(1 / 100, 1 / 50), ratio=(1 / 20, 1), p=0.25),
    ),
    "collate": kornia_collate,
    "per_batch_transform_on_device": ApplyToKeys(
        DefaultDataKeys.INPUT,
        Ka.RandomHorizontalFlip(p=0.25),
    ),
}

# construct datamodule
datamodule = ImageClassificationData.from_tensors(
    train_data=train_dataset.x,
    train_targets=torch.from_numpy(train_dataset.y.astype(int)),
    val_data=val_dataset.x,
    val_targets=torch.from_numpy(val_dataset.y.astype(int)),
    test_data=test_dataset.x,
    test_targets=torch.from_numpy(test_dataset.y.astype(int)),
    num_workers=4,
    train_transform=train_transform,
)

model = ImageClassifier(
    backbone="resnet18",
    training_strategy="prototypicalnetworks",
    training_strategy_kwargs={
        "epoch_length": 10 * 16,
        "meta_batch_size": 4,
        "num_tasks": 200,
        "test_num_tasks": 2000,
        "ways": datamodule.num_classes,
        "shots": 1,
        "test_ways": 5,
        "test_shots": 1,
        "test_queries": 15,
    },
    optimizer=torch.optim.Adam,
    optimizer_kwargs={"lr": 0.001},
)

trainer = flash.Trainer(
    max_epochs=200,
    gpus=2,
    accelerator="ddp_shared",
    precision=16,
)
trainer.finetune(model, datamodule=datamodule, strategy="no_freeze")

You can read their paper Learn2Learn: A Library for Meta-Learning Research.

And don’t forget to cite the Learn2Learn repository in your academic publications. You can find their BibTeX on their repository.

VISSL

VISSL is a library from Facebook AI Research for state-of-the-art self-supervised learning. We integrate VISSL models and algorithms into Flash with the image embedder task.

Using VISSL with Flash

The ImageEmbedder task in Flash can be configured with different backbones, projection heads, image transforms and loss functions so that you can train your feature extractor using a SOTA SSL method.

from flash.image import ImageEmbedder

embedder = ImageEmbedder(
    backbone="resnet",
    training_strategy="barlow_twins",
    head="simclr_head",
    pretraining_transform="barlow_twins_transform",
    training_strategy_kwargs={"latent_embedding_dim": 256, "dims": [2048, 2048, 256]},
    pretraining_transform_kwargs={"size_crops": [196]},
)

The user can pass arguments to the training strategy, image transforms and backbones using the optional dictionary arguments that the ImageEmbedder task accepts. A training strategy bundles together the projection head, the loss function and the VISSL hooks for a particular algorithm, and the arguments to customize these can be passed via training_strategy_kwargs. As an example, in the above code block, latent_embedding_dim is an argument to the Barlow Twins loss function from VISSL, while the dims argument configures the projection head to output 256-dimensional vectors for the loss function.
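As a rough sketch of how the configured embedder might then be used (reusing the hymenoptera data from earlier examples; the training recipe here is illustrative, not prescriptive):

import flash
from flash.core.data.utils import download_data
from flash.image import ImageClassificationData

# Pre-train the embedder on images, then extract embeddings for a file
download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip", "./data")
datamodule = ImageClassificationData.from_folders(train_folder="data/hymenoptera_data/train/")

trainer = flash.Trainer(max_epochs=1)
trainer.fit(embedder, datamodule=datamodule)

embeddings = embedder.predict(["data/hymenoptera_data/val/bees/65038344_52a45d090d.jpg"])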

If you find the VISSL integration in Flash useful for your research, please don’t forget to cite us and the VISSL library. You can find our BibTeX on the Flash repository and VISSL’s BibTeX on their GitHub page.

flash

DataSource

The DataSource class encapsulates two hooks: load_data and load_sample.

DataModule

A basic DataModule class for all Flash tasks.

FlashCallback

FlashCallback is an extension of pytorch_lightning.callbacks.Callback.

Preprocess

The Preprocess encapsulates all the data processing logic that should run before the data is passed to the model.

Postprocess

The Postprocess encapsulates all the data processing logic that should run after the model.

Serializer

A Serializer encapsulates a single serialize method which is used to convert the model output into the desired output format when predicting.

Task

A general Task.

Trainer

flash.core

flash.core.adapter

Adapter

The Adapter is a lightweight interface that can be used to encapsulate the logic from a particular provider within a Task.

AdapterTask

The AdapterTask is a Task which wraps an Adapter and forwards all of the hooks.

flash.core.classification

Classes

A Serializer which applies an argmax to the model outputs (either logits or probabilities) and converts to a list.

ClassificationSerializer

A base class for classification serializers.

ClassificationTask

FiftyOneLabels

A Serializer which converts the model outputs to FiftyOne classification format.

Labels

A Serializer which converts the model outputs (either logits or probabilities) to the label of the argmax classification.

Logits

A Serializer which simply converts the model outputs (assumed to be logits) to a list.

PredsClassificationSerializer

A ClassificationSerializer which gets the PREDS from the sample.

Probabilities

A Serializer which applies a softmax to the model outputs (assumed to be logits) and converts to a list.

flash.core.finetuning

FlashBaseFinetuning

FlashBaseFinetuning can be used to create a custom Flash Finetuning Callback.

FreezeUnfreeze

NoFreeze

UnfreezeMilestones

flash.core.integrations.fiftyone

visualize

Visualizes predictions from a model with a FiftyOne Serializer in the FiftyOne App.

flash.core.integrations.icevision

IceVisionTransformAdapter

transform: List[Callable]

default_transforms

The default transforms from IceVision.

train_default_transforms

The default augmentations from IceVision.

flash.core.model

BenchmarkConvergenceCI

Specialized callback only used during testing. Keeps track of metrics during training.

CheckDependenciesMeta

ModuleWrapperBase

The ModuleWrapperBase is a base for classes which wrap a LightningModule or an instance of ModuleWrapperBase.

DatasetProcessor

The DatasetProcessor mixin provides hooks for classes which need custom logic for producing the data loaders for each running stage given the corresponding dataset.

Task

A general Task.

flash.core.registry

FlashRegistry

This class is used to register functions or functools.partial objects to a registry.

ExternalRegistry

The ExternalRegistry is a FlashRegistry that can point to an external provider via a getter function.

ConcatRegistry

The ConcatRegistry can be used to concatenate multiple registries of different types together.

flash.core.optimizers

LARS

Extends SGD in PyTorch with LARS scaling from the paper Large batch training of Convolutional Networks.

LAMB

Extends ADAM in pytorch to incorporate LAMB algorithm from the paper: Large batch optimization for deep learning: Training BERT in 76 minutes.

LinearWarmupCosineAnnealingLR

Sets the learning rate of each parameter group to follow a linear warmup schedule between warmup_start_lr and base_lr followed by a cosine annealing schedule between base_lr and eta_min.

Utilities

from_argparse_args

Modified version of pytorch_lightning.utilities.argparse.from_argparse_args() which populates valid_kwargs from pytorch_lightning.Trainer.

get_callable_name

Return type: str

get_callable_dict

Return type: Union[Dict, Mapping]

predict_context

This decorator is used as a context manager to put the model in eval mode before running predict and reset it to train mode afterwards.

flash.core.data

flash.core.data.auto_dataset

AutoDataset

The AutoDataset is a BaseAutoDataset and a Dataset.

BaseAutoDataset

The BaseAutoDataset class wraps the output of a call to load_data() and a DataSource and provides the _call_load_sample method to call load_sample() with the correct CurrentRunningStageFuncContext for the current running_stage.

IterableAutoDataset

The IterableAutoDataset is a BaseAutoDataset and a IterableDataset.

flash.core.data.base_viz

BaseVisualization

This Base Class is used to create visualization tool on top of Preprocess hooks.

flash.core.data.batch

default_uncollate

This function is used to uncollate a batch into samples.

flash.core.data.callback

BaseDataFetcher

This class is used to profile Preprocess hook outputs.

ControlFlow

FlashCallback

FlashCallback is an extension of pytorch_lightning.callbacks.Callback.

flash.core.data.data_module

DataModule

A basic DataModule class for all Flash tasks.

flash.core.data.data_pipeline

DataPipeline

DataPipeline holds the engineering logic to connect Preprocess and/or Postprocess objects to the DataModule, Flash Task and Trainer.

DataPipelineState

A class to store and share all process states once a DataPipeline has been initialized.

flash.core.data.data_source

DatasetDataSource

The DatasetDataSource implements default behaviours for data sources which expect the input to load_data() to be a torch.utils.data.dataset.Dataset

DataSource

The DataSource class encapsulates two hooks: load_data and load_sample.

DefaultDataKeys

The DefaultDataKeys enum contains the keys that are used by built-in data sources to refer to inputs and targets.

DefaultDataSources

The DefaultDataSources enum contains the data source names used by all of the default from_* methods in DataModule.

FiftyOneDataSource

The FiftyOneDataSource expects the input to load_data() to be a fiftyone.core.collections.SampleCollection.

ImageLabelsMap

LabelsState

A ProcessState containing labels, a mapping from class index to label.

MockDataset

The MockDataset catches any metadata that is attached through __setattr__.

NumpyDataSource

The NumpyDataSource is a SequenceDataSource which expects the input to load_data() to be a sequence of np.ndarray objects.

PathsDataSource

The PathsDataSource implements default behaviours for data sources which expect the input to load_data() to be either a directory with a subdirectory for each class or a tuple containing list of files and corresponding list of targets.

SequenceDataSource

The SequenceDataSource implements default behaviours for data sources which expect the input to load_data() to be a sequence of tuples ((input, target) where target can be None).

TensorDataSource

The TensorDataSource is a SequenceDataSource which expects the input to load_data() to be a sequence of torch.Tensor objects.

has_file_allowed_extension

Checks if a file is an allowed extension.

has_len

Return type: bool

make_dataset

Generates a list of samples of the form (path_to_sample, class).

flash.core.data.process

BasePreprocess

DefaultPreprocess

DeserializerMapping

Deserializer Mapping.

Deserializer

Deserializer.

Postprocess

The Postprocess encapsulates all the data processing logic that should run after the model.

Preprocess

The Preprocess encapsulates all the data processing logic that should run before the data is passed to the model.

SerializerMapping

If the model output is a dictionary, then the SerializerMapping enables each entry in the dictionary to be passed to its own Serializer.

Serializer

A Serializer encapsulates a single serialize method which is used to convert the model output into the desired output format when predicting.

flash.core.data.properties

ProcessState

Base class for all process states.

Properties

flash.core.data.splits

SplitDataset

SplitDataset is used to create a Dataset subset using indices.

flash.core.data.transforms

ApplyToKeys

The ApplyToKeys class is an nn.Sequential which applies the given transforms to the given keys from the input.

KorniaParallelTransforms

The KorniaParallelTransforms class is an nn.Sequential which will apply the given transforms to each input (to .forward) in parallel, whilst sharing the random state (._params).

merge_transforms

Utility function to merge two transform dictionaries.

kornia_collate

Kornia transforms add a batch dimension which needs to be removed.

flash.core.data.utils

CurrentFuncContext

CurrentRunningStageContext

CurrentRunningStageFuncContext

FuncModule

This class is used to wrap a callable within an nn.Module and apply the wrapped function in __call__.

convert_to_modules

download_data

Download file with progressbar.

flash.core.serve

ModelComponent

alias of object

Composition

Create a composition which defines the computations / endpoints to create & run.

Endpoint

An endpoint maps a route and request/response payload to components.

Servable

ModuleWrapperBase around a model object to enable serving at scale.

expose

Expose a function/method via a web API for serving model inference.

flash.image

Classification

ImageClassifier

The ImageClassifier is a Task for classifying images.

ImageClassificationData

Data module for image classification tasks.

ImageClassificationPreprocess

Preprocessing of data for image classification.

classification.data.MatplotlibVisualization

Process and show the image batch and its associated label using matplotlib.

classification.transforms.default_transforms

The default transforms for image classification: resize the image, convert the image and target to a tensor, collate the batch, and apply normalization.

classification.transforms.train_default_transforms

During training, we apply the default transforms with additional RandomHorizontalFlip.

Object Detection

ObjectDetector

The ObjectDetector is a Task for detecting objects in images.

ObjectDetectionData

detection.data.FiftyOneParser

detection.data.ObjectDetectionFiftyOneDataSource

detection.data.ObjectDetectionPreprocess

detection.serialization.FiftyOneDetectionLabels

A Serializer which converts model outputs to FiftyOne detection format.

Keypoint Detection

KeypointDetector

The KeypointDetector is a Task for detecting keypoints in images.

KeypointDetectionData

keypoint_detection.data.KeypointDetectionPreprocess

Instance Segmentation

InstanceSegmentation

The InstanceSegmentation is a Task for segmenting object instances in images.

InstanceSegmentationData

instance_segmentation.data.InstanceSegmentationPreprocess

Embedding

ImageEmbedder

The ImageEmbedder is a Task for obtaining feature vectors (embeddings) from images.

Segmentation

SemanticSegmentation

SemanticSegmentation is a Task for semantic segmentation of images.

SemanticSegmentationData

Data module for semantic segmentation tasks.

SemanticSegmentationPreprocess

segmentation.data.SegmentationMatplotlibVisualization

Process and show the image batch and its associated label using matplotlib.

segmentation.data.SemanticSegmentationNumpyDataSource

segmentation.data.SemanticSegmentationTensorDataSource

segmentation.data.SemanticSegmentationPathsDataSource

segmentation.data.SemanticSegmentationFiftyOneDataSource

segmentation.data.SemanticSegmentationDeserializer

segmentation.model.SemanticSegmentationPostprocess

segmentation.serialization.FiftyOneSegmentationLabels

A Serializer which converts the model outputs to FiftyOne segmentation format.

segmentation.serialization.SegmentationLabels

A Serializer which converts the model outputs to the label of the argmax classification per pixel in the image for semantic segmentation tasks.

segmentation.transforms.default_transforms

The default transforms for semantic segmentation: resize the image and mask, collate the batch, and apply normalization.

segmentation.transforms.prepare_target

Convert the target mask to long and remove the channel dimension.

segmentation.transforms.train_default_transforms

During training, we apply the default transforms with additional RandomHorizontalFlip and ColorJitter.

Style Transfer

StyleTransfer

StyleTransfer is a Task for transferring the style from one image onto another.

StyleTransferData

StyleTransferPreprocess

raise_not_supported

Return type: NoReturn

flash.image.data

ImageDeserializer

ImageFiftyOneDataSource

ImageNumpyDataSource

ImagePathsDataSource

ImageTensorDataSource

flash.audio

Classification

AudioClassificationData

Data module for audio classification.

AudioClassificationPreprocess

Speech Recognition

SpeechRecognitionData

Data module for speech recognition tasks.

SpeechRecognition

The SpeechRecognition task is a Task for converting speech to text.

speech_recognition.data.SpeechRecognitionPreprocess

speech_recognition.data.SpeechRecognitionBackboneState

The SpeechRecognitionBackboneState stores the backbone in use by the SpeechRecognitionPostprocess

speech_recognition.data.SpeechRecognitionPostprocess

speech_recognition.data.SpeechRecognitionCSVDataSource

speech_recognition.data.SpeechRecognitionJSONDataSource

speech_recognition.data.BaseSpeechRecognition

speech_recognition.data.SpeechRecognitionFileDataSource

speech_recognition.data.SpeechRecognitionPathsDataSource

speech_recognition.data.SpeechRecognitionDatasetDataSource

speech_recognition.data.SpeechRecognitionDeserializer

flash.pointcloud

Segmentation

PointCloudSegmentation

The PointCloudSegmentation task is a ClassificationTask that segments pointcloud data.

PointCloudSegmentationData

segmentation.data.PointCloudSegmentationPreprocess

segmentation.data.PointCloudSegmentationFoldersDataSource

segmentation.data.PointCloudSegmentationDatasetDataSource

Object Detection

PointCloudObjectDetector

The PointCloudObjectDetector is a Task that detects objects in pointcloud data.

PointCloudObjectDetectorData

detection.data.PointCloudObjectDetectorPreprocess

detection.data.PointCloudObjectDetectorFoldersDataSource

detection.data.PointCloudObjectDetectorDatasetDataSource

flash.tabular

Classification

TabularClassifier

The TabularClassifier is a Task for classifying tabular data.

TabularClassificationData

Regression

TabularRegressionData

flash.tabular.data

TabularData

Data module for tabular tasks.

TabularDataFrameDataSource

TabularCSVDataSource

TabularDeserializer

TabularPreprocess

TabularPostprocess

flash.text

Classification

TextClassifier

The TextClassifier is a Task for classifying text.

TextClassificationData

Data Module for text classification tasks.

classification.data.TextClassificationPostprocess

classification.data.TextClassificationPreprocess

classification.data.TextDeserializer

classification.data.TextDataSource

classification.data.TextCSVDataSource

classification.data.TextJSONDataSource

classification.data.TextDataFrameDataSource

classification.data.TextParquetDataSource

classification.data.TextHuggingFaceDatasetDataSource

classification.data.TextListDataSource

Question Answering

QuestionAnsweringTask

The QuestionAnsweringTask is a Task for extractive question answering.

QuestionAnsweringData

Data module for QuestionAnswering task.

question_answering.data.QuestionAnsweringBackboneState

The QuestionAnsweringBackboneState stores the backbone in use by the QuestionAnsweringPreprocess

question_answering.data.QuestionAnsweringCSVDataSource

question_answering.data.QuestionAnsweringDataSource

question_answering.data.QuestionAnsweringDictionaryDataSource

question_answering.data.QuestionAnsweringFileDataSource

question_answering.data.QuestionAnsweringJSONDataSource

question_answering.data.QuestionAnsweringPostprocess

question_answering.data.QuestionAnsweringPreprocess

question_answering.data.SQuADDataSource

Summarization

SummarizationTask

The SummarizationTask is a Task for Seq2Seq text summarization.

SummarizationData

seq2seq.summarization.data.SummarizationPreprocess

Translation

TranslationTask

The TranslationTask is a Task for Seq2Seq text translation.

TranslationData

Data module for Translation tasks.

seq2seq.translation.data.TranslationPreprocess

General Seq2Seq

Seq2SeqTask

General Task for Sequence2Sequence.

Seq2SeqData

Data module for Seq2Seq tasks.

Seq2SeqFreezeEmbeddings

Freezes the embedding layers during Seq2Seq training.

seq2seq.core.data.Seq2SeqBackboneState

The Seq2SeqBackboneState stores the backbone in use by the Seq2SeqPreprocess

seq2seq.core.data.Seq2SeqCSVDataSource

seq2seq.core.data.Seq2SeqDataSource

seq2seq.core.data.Seq2SeqFileDataSource

seq2seq.core.data.Seq2SeqJSONDataSource

seq2seq.core.data.Seq2SeqPostprocess

seq2seq.core.data.Seq2SeqPreprocess

seq2seq.core.data.Seq2SeqSentencesDataSource

seq2seq.core.metrics.BLEUScore

Calculate BLEU score of machine translated text with one or more references.

seq2seq.core.metrics.RougeBatchAggregator

Aggregates rouge scores and provides confidence intervals.

seq2seq.core.metrics.RougeMetric

Metric used for automatic summarization.

flash.video

Classification

VideoClassifier

Task that classifies videos.

VideoClassificationData

Data module for Video classification tasks.

classification.data.BaseVideoClassification

classification.data.VideoClassificationFiftyOneDataSource

classification.data.VideoClassificationPathsDataSource

classification.data.VideoClassificationPreprocess

classification.model.VideoClassifierFinetuning

flash.graph

Classification

GraphClassifier

The GraphClassifier is a Task for classifying graphs.

GraphClassificationData

Data module for graph classification tasks.

classification.data.GraphClassificationPreprocess

flash.graph.data

GraphDatasetDataSource

Introduction / Set-up

Welcome

Before you begin, we’d like to express our gratitude to you for wanting to add a task to Flash. With Flash our aim is to create a great user experience, enabling awesome advanced applications with just a few lines of code. We’re really pleased with what we’ve achieved with Flash and we hope you will be too. Now let’s dive in!

Set-up

The Task template is designed to guide you through contributing a task to Flash. It contains the code, tests, and examples for a task that performs classification with a multi-layer perceptron, intended for use with the classic data sets from scikit-learn. The Flash tasks are organized in folders by data-type (image, text, video, etc.), with sub-folders for different task types (classification, regression, etc.).

Copy the files in flash/template/classification to a new sub-directory under the relevant data-type. If a data-type folder already exists for your task, then a task type sub-folder should be added containing the template files. If a data-type folder doesn’t exist, then you will need to add that too. You should also copy the files from tests/template/classification to the corresponding data-type, task type folder in tests. For example, if you were adding an image classification task, you would do:

mkdir flash/image/classification
cp flash/template/classification/* flash/image/classification/
mkdir tests/image/classification
cp tests/template/classification/* tests/image/classification/

Tutorials

The tutorials in this section will walk you through all of the components you need to implement (or adapt from the template) for your custom task.

  • The Data: our first tutorial goes over the best practices for implementing everything you need to connect data to your task

  • The Backbones: the second tutorial shows you how to create an extensible backbone registry for your task

  • The Task: now that we have the data and the models, in this tutorial we create our custom task

  • Optional Extras: this tutorial covers some optional extras you can add if needed for your particular task

  • The Example: this tutorial guides you through creating some simple examples showing your task in action

  • The Tests: in this tutorial, we cover best practices for writing some tests for your new task

  • The Docs: in our final tutorial, we provide a template for you to create the docs page for your task

The Data

The first step to contributing a task is to implement the classes we need to load some data. Inside data.py you should implement:

  1. some DataSource classes (optional)

  2. a Preprocess

  3. a DataModule

  4. a BaseVisualization (optional)

  5. a Postprocess (optional)

DataSource

The DataSource class contains the logic for data loading from different sources such as folders, files, tensors, etc. Every Flash DataModule can be instantiated with from_datasets(). For each additional way you want the user to be able to instantiate your DataModule, you’ll need to create a DataSource. Each DataSource has 2 methods:

  • load_data() takes some dataset metadata (e.g. a folder name) as input and produces a sequence or iterable of samples or sample metadata.

  • load_sample() then takes as input a single element from the output of load_data and returns a sample.

By default these methods just return their input, so you don’t need both a load_data() and a load_sample() to create a DataSource. Where possible, you should override one of our existing DataSource classes.

Let’s start by implementing a TemplateNumpyDataSource, which overrides NumpyDataSource. The main DataSource method that we have to implement is load_data(). As we’re extending the NumpyDataSource, we expect the same data argument (in this case, a tuple containing data and corresponding target arrays).

We can also take the dataset argument. Any attributes we set on dataset will be available on the Dataset generated by our DataSource. In this data source, we’ll set the num_features attribute.

Here’s the code for our TemplateNumpyDataSource.load_data method:

def load_data(self, data: Tuple[np.ndarray, Sequence[Any]], dataset: Any) -> Sequence[Mapping[str, Any]]:
    """Sets the ``num_features`` attribute and calls ``super().load_data``.

    Args:
        data: The tuple of ``np.ndarray`` (num_examples x num_features) and associated targets.
        dataset: The object that we can set attributes (such as ``num_features``) on.

    Returns:
        A sequence of samples / sample metadata.
    """
    dataset.num_features = data[0].shape[1]
    return super().load_data(data, dataset)

Note

Later, when we add our DataModule implementation, we’ll make num_features available to the user.

Sometimes you need to do something a bit more custom. When creating a custom DataSource, the type of the data argument is up to you. For our template Task, it would be cool if the user could provide a scikit-learn Bunch as the data source. To achieve this, we’ll add a TemplateSKLearnDataSource whose load_data expects a Bunch as input. We override our TemplateNumpyDataSource so that we can call super with the data and targets extracted from the Bunch. We perform two additional steps here to improve the user experience:

  1. We set the num_classes attribute on the dataset. If num_classes is set, it is automatically made available as a property of the DataModule.

  2. We create and set a LabelsState. The labels provided here will be shared with the Labels serializer, so the user doesn’t need to provide them.

Here’s the code for the TemplateSKLearnDataSource.load_data method:

def load_data(self, data: Bunch, dataset: Any) -> Sequence[Mapping[str, Any]]:
    """Gets the ``data`` and ``target`` attributes from the ``Bunch`` and passes them to ``super().load_data``.

    Args:
        data: The scikit-learn data ``Bunch``.
        dataset: The object that we can set attributes (such as ``num_classes``) on.

    Returns:
        A sequence of samples / sample metadata.
    """
    dataset.num_classes = len(data.target_names)
    self.set_state(LabelsState(data.target_names))
    return super().load_data((data.data, data.target), dataset=dataset)

We can customize the behaviour of our load_data() for different stages, by prepending train, val, test, or predict. For our TemplateSKLearnDataSource, we don’t want to provide any targets to the model when predicting. We can implement predict_load_data like this:

def predict_load_data(self, data: Bunch) -> Sequence[Mapping[str, Any]]:
    """Avoid including targets when predicting.

    Args:
        data: The scikit-learn data ``Bunch``.

    Returns:
        A sequence of samples / sample metadata.
    """
    return super().predict_load_data(data.data)

DataSource vs Dataset

A DataSource is not the same as a torch.utils.data.Dataset. When a from_* method is called on your DataModule, it gets the DataSource to use from the Preprocess. A Dataset is then created from the DataSource for each stage (train, val, test, predict) using the provided metadata (e.g. folder name, numpy array etc.).

The output of the load_data() can just be a torch.utils.data.Dataset instance. If the library that your Task is based on provides a custom dataset, you don’t need to re-write it as a DataSource. For example, the load_data() of the VideoClassificationPathsDataSource just creates an EncodedVideoDataset from the given folder. Here’s how it looks (from video/classification/data.py):

def load_data(self, data: str, dataset: Optional[Any] = None) -> "LabeledVideoDataset":
    ds = self._make_encoded_video_dataset(data)
    if self.training:
        label_to_class_mapping = {p[1]: p[0].split("/")[-2] for p in ds._labeled_videos._paths_and_labels}
        self.set_state(LabelsState(label_to_class_mapping))
        dataset.num_classes = len(np.unique([s[1]["label"] for s in ds._labeled_videos]))
    return ds

Preprocess

The Preprocess object contains all the data transforms. Internally we inject the Preprocess transforms at several points along the pipeline.

Defining the standard transforms (typically at least a to_tensor_transform should be defined) for your Preprocess is as simple as implementing the default_transforms method. The Preprocess must take train_transform, val_transform, test_transform, and predict_transform arguments in the __init__. These arguments can be provided by the user (when creating the DataModule) to override the default transforms. Any additional arguments are up to you.

Inside the __init__, we make a call to super. This is where we register our data sources. Data sources should be given as a dictionary which maps data source name to data source object. The name can be anything, but if you want to take advantage of our built-in from_* classmethods, you should use DefaultDataSources as the names. In our case, we have both a NUMPY and a custom scikit-learn data source (which we’ll call “sklearn”).

You should also provide a default_data_source. This is the name of the data source to use by default when predicting. It’d be cool if we could get predictions just from a numpy array, so we’ll use NUMPY as the default.

Here’s our TemplatePreprocess.__init__:

def __init__(
    self,
    train_transform: Optional[Dict[str, Callable]] = None,
    val_transform: Optional[Dict[str, Callable]] = None,
    test_transform: Optional[Dict[str, Callable]] = None,
    predict_transform: Optional[Dict[str, Callable]] = None,
):
    super().__init__(
        train_transform=train_transform,
        val_transform=val_transform,
        test_transform=test_transform,
        predict_transform=predict_transform,
        data_sources={
            DefaultDataSources.NUMPY: TemplateNumpyDataSource(),
            "sklearn": TemplateSKLearnDataSource(),
        },
        default_data_source=DefaultDataSources.NUMPY,
    )

For our TemplatePreprocess, we’ll just configure a default to_tensor_transform. Let’s first define the transform as a staticmethod:

@staticmethod
def input_to_tensor(input: np.ndarray):
    """Transform which creates a tensor from the given numpy ``ndarray`` and converts it to ``float``"""
    return torch.from_numpy(input).float()

Our input samples will be dictionaries whose keys are in the DefaultDataKeys. You can map each key to different transforms using ApplyToKeys. Here’s our default_transforms method:

def default_transforms(self) -> Optional[Dict[str, Callable]]:
    """Configures the default ``to_tensor_transform``.

    Returns:
        Our dictionary of transforms.
    """
    return {
        "to_tensor_transform": nn.Sequential(
            ApplyToKeys(DefaultDataKeys.INPUT, self.input_to_tensor),
            ApplyToKeys(DefaultDataKeys.TARGET, torch.as_tensor),
        ),
    }

DataModule

The DataModule is responsible for creating the DataLoader and injecting the transforms for each stage. When the user calls a from_* method (such as from_numpy()), the following steps take place:

  1. The from_data_source() method is called with the name of the DataSource to use and the inputs to provide to load_data() for each stage.

  2. The Preprocess is created from cls.preprocess_cls (if it wasn’t provided by the user) with any provided transforms.

  3. The DataSource of the provided name is retrieved from the Preprocess.

  4. A BaseAutoDataset is created from the DataSource for each stage.

  5. The DataModule is instantiated with the data sets.


To create our TemplateData DataModule, we first need to attach our preprocess class like this:

preprocess_cls = TemplatePreprocess

Since we provided a NUMPY DataSource in the TemplatePreprocess, from_numpy() will now work with our TemplateData.
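
For example, given some arrays already in memory, a user could now create the data module like this (a sketch; the random arrays are just placeholders):

import numpy as np

from flash.template import TemplateData

data = np.random.rand(20, 4)
targets = np.random.randint(0, 3, (20,))

# ``from_numpy`` comes for free because we registered a ``DefaultDataSources.NUMPY`` data source.
datamodule = TemplateData.from_numpy(
    train_data=data,
    train_targets=targets,
    val_split=0.1,
    batch_size=4,
)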

If you’ve defined a fully custom DataSource (like our TemplateSKLearnDataSource), then you will need to write a from_* method for each. Here’s the from_sklearn method for our TemplateData:

@classmethod
def from_sklearn(
    cls,
    train_bunch: Optional[Bunch] = None,
    val_bunch: Optional[Bunch] = None,
    test_bunch: Optional[Bunch] = None,
    predict_bunch: Optional[Bunch] = None,
    train_transform: Optional[Dict[str, Callable]] = None,
    val_transform: Optional[Dict[str, Callable]] = None,
    test_transform: Optional[Dict[str, Callable]] = None,
    predict_transform: Optional[Dict[str, Callable]] = None,
    data_fetcher: Optional[BaseDataFetcher] = None,
    preprocess: Optional[Preprocess] = None,
    val_split: Optional[float] = None,
    batch_size: int = 4,
    num_workers: int = 0,
    **preprocess_kwargs: Any,
):
    """This is our custom ``from_*`` method. It expects scikit-learn ``Bunch`` objects as input and passes them
    through to the :meth:`~flash.core.data.data_module.DataModule.from_data_source` method underneath.

    Args:
        train_bunch: The scikit-learn ``Bunch`` containing the train data.
        val_bunch: The scikit-learn ``Bunch`` containing the validation data.
        test_bunch: The scikit-learn ``Bunch`` containing the test data.
        predict_bunch: The scikit-learn ``Bunch`` containing the predict data.
        train_transform: The dictionary of transforms to use during training which maps
            :class:`~flash.core.data.process.Preprocess` hook names to callable transforms.
        val_transform: The dictionary of transforms to use during validation which maps
            :class:`~flash.core.data.process.Preprocess` hook names to callable transforms.
        test_transform: The dictionary of transforms to use during testing which maps
            :class:`~flash.core.data.process.Preprocess` hook names to callable transforms.
        predict_transform: The dictionary of transforms to use during predicting which maps
            :class:`~flash.core.data.process.Preprocess` hook names to callable transforms.
        data_fetcher: The :class:`~flash.core.data.callback.BaseDataFetcher` to pass to the
            :class:`~flash.core.data.data_module.DataModule`.
        preprocess: The :class:`~flash.core.data.data.Preprocess` to pass to the
            :class:`~flash.core.data.data_module.DataModule`. If ``None``, ``cls.preprocess_cls`` will be
            constructed and used.
        val_split: The ``val_split`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`.
        batch_size: The ``batch_size`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`.
        num_workers: The ``num_workers`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`.
        preprocess_kwargs: Additional keyword arguments to use when constructing the preprocess. Will only be used
            if ``preprocess = None``.

    Returns:
        The constructed data module.
    """
    return super().from_data_source(
        "sklearn",
        train_bunch,
        val_bunch,
        test_bunch,
        predict_bunch,
        train_transform=train_transform,
        val_transform=val_transform,
        test_transform=test_transform,
        predict_transform=predict_transform,
        data_fetcher=data_fetcher,
        preprocess=preprocess,
        val_split=val_split,
        batch_size=batch_size,
        num_workers=num_workers,
        **preprocess_kwargs,
    )

The final step is to implement the num_features property for our TemplateData. This is just a convenience for the user; it looks up the num_features attribute on each of the data sets in turn and returns the first one it finds. Here’s the code:

@property
def num_features(self) -> Optional[int]:
    """Tries to get the ``num_features`` from each dataset in turn and returns the output."""
    n_fts_train = getattr(self.train_dataset, "num_features", None)
    n_fts_val = getattr(self.val_dataset, "num_features", None)
    n_fts_test = getattr(self.test_dataset, "num_features", None)
    return n_fts_train or n_fts_val or n_fts_test

BaseVisualization

An optional step is to implement a BaseVisualization. The BaseVisualization lets you control how data at various points in the pipeline can be visualized. This is extremely useful for debugging purposes, allowing users to view their data and understand the impact of their transforms.

Note

Don’t worry about implementing it right away, you can always come back and add it later!

Here’s the code for our TemplateVisualization which just prints the data:

class TemplateVisualization(BaseVisualization):
    """The ``TemplateVisualization`` class is a :class:`~flash.core.data.callbacks.BaseVisualization` that just
    prints the data.

    If you want to provide a visualization with your task, you can override these hooks.
    """

    def show_load_sample(self, samples: List[Any], running_stage: RunningStage):
        print(samples)

    def show_pre_tensor_transform(self, samples: List[Any], running_stage: RunningStage):
        print(samples)

    def show_to_tensor_transform(self, samples: List[Any], running_stage: RunningStage):
        print(samples)

    def show_post_tensor_transform(self, samples: List[Any], running_stage: RunningStage):
        print(samples)

    def show_per_batch_transform(self, batch: List[Any], running_stage):
        print(batch)

We can configure our custom visualization in the TemplateData using configure_data_fetcher() like this:

@staticmethod
def configure_data_fetcher(*args, **kwargs) -> BaseDataFetcher:
    """We can, *optionally*, provide a data visualization callback using the ``configure_data_fetcher``
    method."""
    return TemplateVisualization(*args, **kwargs)
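
With the data fetcher configured, a user can then visualize batches at different points in the pipeline. Here's a sketch of how that might look, assuming the iris Bunch from scikit-learn:

from sklearn import datasets

from flash.template import TemplateData

datamodule = TemplateData.from_sklearn(
    train_bunch=datasets.load_iris(),
    batch_size=4,
)

# Show the raw samples and the tensors produced by the ``to_tensor_transform``.
datamodule.show_train_batch(["load_sample", "to_tensor_transform"])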

Postprocess

Postprocess contains any transforms that need to be applied after the model. You may want to use it for: converting tokens back into text, applying an inverse normalization to an output image, resizing a generated image back to the size of the input, etc. As an example, here’s the TextClassificationPostprocess which gets the logits from a SequenceClassifierOutput:

class TextClassificationPostprocess(Postprocess):
    def per_batch_transform(self, batch: Any) -> Any:
        if isinstance(batch, SequenceClassifierOutput):
            batch = batch.logits
        return super().per_batch_transform(batch)

In your DataSource or Preprocess, you can add metadata to the batch using the METADATA key. Your Postprocess can then use this metadata in its transforms. You should use this approach if your postprocessing depends on the state of the input before the Preprocess transforms. For example, if you want to resize the predictions to the original size of the inputs you should add the original image size in the METADATA. Here’s an example from the SemanticSegmentationNumpyDataSource:

def load_sample(self, sample: Dict[str, Any], dataset: Optional[Any] = None) -> Dict[str, Any]:
    img = torch.from_numpy(sample[DefaultDataKeys.INPUT]).float()
    sample[DefaultDataKeys.INPUT] = img
    sample[DefaultDataKeys.METADATA] = {"size": img.shape}
    return sample

The METADATA can now be referenced in your Postprocess. For example, here’s the code for the per_sample_transform method of the SemanticSegmentationPostprocess:

def per_sample_transform(self, sample: Any) -> Any:
    resize = K.geometry.Resize(sample[DefaultDataKeys.METADATA]["size"][-2:], interpolation="bilinear")
    sample[DefaultDataKeys.PREDS] = resize(sample[DefaultDataKeys.PREDS])
    sample[DefaultDataKeys.INPUT] = resize(sample[DefaultDataKeys.INPUT])
    return super().per_sample_transform(sample)

Now that you’ve got some data, it’s time to add some backbones for your task!

The Backbones

Now that you’ve got a way of loading data, you should implement some backbones to use with your Task. Create a FlashRegistry to use with your Task in backbones.py.

The registry allows you to register backbones for your task that can be selected by the user. The backbones can come from anywhere as long as you can register a function that loads the backbone. Furthermore, the user can add their own models to the existing backbones, without having to write their own Task!

You can create a registry like this:

TEMPLATE_BACKBONES = FlashRegistry("backbones")

Let’s add a simple MLP backbone to our registry. We need a function that creates the backbone and returns it along with the output size (so that we can create the model head in our Task). You can use any name for the function, although we use load_{model name} by convention. You also need to provide the name and namespace of the backbone. The standard for namespace is data_type/task_type, so for an image classification task the namespace will be image/classification. Here’s the code:

@TEMPLATE_BACKBONES(name="mlp-128", namespace="template/classification")
def load_mlp_128(num_features, **_):
    """A simple MLP backbone with 128 hidden units."""
    return (
        nn.Sequential(
            nn.Linear(num_features, 128),
            nn.ReLU(True),
            nn.BatchNorm1d(128),
        ),
        128,
    )

Here’s another example with a slightly more complex model:

@TEMPLATE_BACKBONES(name="mlp-128-256", namespace="template/classification")
def load_mlp_128_256(num_features, **_):
    """An two layer MLP backbone with 128 and 256 hidden units respectively."""
    return (
        nn.Sequential(
            nn.Linear(num_features, 128),
            nn.ReLU(True),
            nn.BatchNorm1d(128),
            nn.Linear(128, 256),
            nn.ReLU(True),
            nn.BatchNorm1d(256),
        ),
        256,
    )

Here’s another example, which adds a DINO pretrained model from PyTorch Hub to the IMAGE_CLASSIFIER_BACKBONES, from flash/image/classification/backbones/transformers.py:

def dino_vitb16(*_, **__):
    backbone = torch.hub.load("facebookresearch/dino:main", "dino_vitb16")
    return backbone, 768
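
Because the registry is just a FlashRegistry, a user (or another part of the codebase) can register extra backbones against it without touching the Task itself. Here's a small sketch; the backbone name and architecture are made up:

@TEMPLATE_BACKBONES(name="mlp-64", namespace="template/classification")
def load_mlp_64(num_features, **_):
    """A hypothetical single layer MLP backbone with 64 hidden units."""
    return nn.Sequential(nn.Linear(num_features, 64), nn.ReLU(True)), 64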

Once you’ve got some data and some backbones, implement your task!

The Task

Once you’ve implemented a Flash DataModule and some backbones, you should implement your Task in model.py. The Task is responsible for: setting up the backbone, performing the forward pass of the model, and calculating the loss and any metrics. Remember that, under the hood, the Flash Task is simply a LightningModule with some helpful defaults.

To build your task, you can start by overriding the base Task or any of the existing Task implementations. For example, in our scikit-learn example, we can just override ClassificationTask which provides good defaults for classification.

You should attach your backbones registry as a class attribute like this:

class TemplateSKLearnClassifier(ClassificationTask):

    backbones: FlashRegistry = TEMPLATE_BACKBONES

Model architecture and hyper-parameters

In the __init__(), you will need to configure defaults for the:

  • loss function

  • optimizer

  • metrics

  • backbone / model

You will also need to create the backbone from the registry and create the model head. Here’s the code:

def __init__(
    self,
    num_features: int,
    num_classes: int,
    backbone: Union[str, Tuple[nn.Module, int]] = "mlp-128",
    backbone_kwargs: Optional[Dict] = None,
    loss_fn: LOSS_FN_TYPE = None,
    optimizer: OPTIMIZER_TYPE = "Adam",
    lr_scheduler: LR_SCHEDULER_TYPE = None,
    metrics: METRICS_TYPE = None,
    learning_rate: float = 1e-2,
    multi_label: bool = False,
    serializer: SERIALIZER_TYPE = None,
):
    super().__init__(
        model=None,
        loss_fn=loss_fn,
        optimizer=optimizer,
        lr_scheduler=lr_scheduler,
        metrics=metrics,
        learning_rate=learning_rate,
        multi_label=multi_label,
        serializer=serializer or Labels(),
    )

    self.save_hyperparameters()

    if not backbone_kwargs:
        backbone_kwargs = {}

    if isinstance(backbone, tuple):
        self.backbone, out_features = backbone
    else:
        self.backbone, out_features = self.backbones.get(backbone)(num_features=num_features, **backbone_kwargs)

    self.head = nn.Linear(out_features, num_classes)

Note

We call save_hyperparameters() to log the arguments to the __init__ as hyperparameters. Read more here.

Adding the model routines

You should override the {train,val,test,predict}_step methods. The default {train,val,test,predict}_step implementations in Task expect a tuple containing the input (to be passed to the model) and target (to be used when computing the loss), and should be suitable for most applications. In our template example, we just extract the input and target from the input mapping and forward them to the super methods. Here’s the code for the training_step:

def training_step(self, batch: Any, batch_idx: int) -> Any:
    """For the training step, we just extract the :attr:`~flash.core.data.data_source.DefaultDataKeys.INPUT` and
    :attr:`~flash.core.data.data_source.DefaultDataKeys.TARGET` keys from the input and forward them to the
    :meth:`~flash.core.model.Task.training_step`."""
    batch = (batch[DefaultDataKeys.INPUT], batch[DefaultDataKeys.TARGET])
    return super().training_step(batch, batch_idx)

We use the same code for the validation_step and test_step. For predict_step we don’t need the targets, so our code looks like this:

def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any:
    """For the predict step, we just extract the :attr:`~flash.core.data.data_source.DefaultDataKeys.INPUT` key
    from the input and forward it to the :meth:`~flash.core.model.Task.predict_step`."""
    batch = batch[DefaultDataKeys.INPUT]
    return super().predict_step(batch, batch_idx, dataloader_idx=dataloader_idx)

Note

You can completely replace the {train,val,test,predict}_step methods (that is, without a call to super) if you need more custom behaviour for your Task at a particular stage.

Finally, we use our backbone and head in a custom forward pass:

def forward(self, x) -> torch.Tensor:
    """First call the backbone, then the model head."""
    x = self.backbone(x)
    return self.head(x)

Now that you’ve got your task, take a look at some optional advanced features you can add or go ahead and create some examples showing your task in action!

Optional Extras

Organize your transforms in transforms.py

If you have a lot of default transforms, it can be useful to put them all in a transforms.py file, to be referenced in your Preprocess. Here’s an example from image/classification/transforms.py which creates some default transforms given the desired image size:

def default_transforms(image_size: Tuple[int, int]) -> Dict[str, Callable]:
    """The default transforms for image classification: resize the image, convert the image and target to a tensor,
    collate the batch, and apply normalization."""
    if _KORNIA_AVAILABLE and os.getenv("FLASH_TESTING", "0") != "1":
        #  Better approach as all transforms are applied on tensor directly
        return {
            "to_tensor_transform": nn.Sequential(
                ApplyToKeys(DefaultDataKeys.INPUT, torchvision.transforms.ToTensor()),
                ApplyToKeys(DefaultDataKeys.TARGET, torch.as_tensor),
            ),
            "post_tensor_transform": ApplyToKeys(
                DefaultDataKeys.INPUT,
                K.geometry.Resize(image_size),
            ),
            "collate": kornia_collate,
            "per_batch_transform_on_device": ApplyToKeys(
                DefaultDataKeys.INPUT,
                K.augmentation.Normalize(torch.tensor([0.485, 0.456, 0.406]), torch.tensor([0.229, 0.224, 0.225])),
            ),
        }
    return {
        "pre_tensor_transform": ApplyToKeys(DefaultDataKeys.INPUT, T.Resize(image_size)),
        "to_tensor_transform": nn.Sequential(
            ApplyToKeys(DefaultDataKeys.INPUT, torchvision.transforms.ToTensor()),
            ApplyToKeys(DefaultDataKeys.TARGET, torch.as_tensor),
        ),
        "post_tensor_transform": ApplyToKeys(
            DefaultDataKeys.INPUT,
            T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ),
        "collate": kornia_collate,
    }

Here’s how we create our transforms in the ImageClassificationPreprocess:

    def default_transforms(self) -> Optional[Dict[str, Callable]]:
        return default_transforms(self.image_size)

Add output serializers to your Task

We recommend that you do most of the heavy lifting in the Postprocess. Specifically, it should include any formatting and transforms that should always be applied to the predictions. If you want to support different use cases that require different prediction formats, you should add some Serializer implementations in a serialization.py file.

Some good examples are in flash/core/classification.py. Here’s the Classes Serializer:

class Classes(PredsClassificationSerializer):
    """A :class:`.Serializer` which applies an argmax to the model outputs (either logits or probabilities) and
    converts to a list.

    Args:
        multi_label: If true, treats outputs as multi label logits.
        threshold: The threshold to use for multi_label classification.
    """

    def __init__(self, multi_label: bool = False, threshold: float = 0.5):
        super().__init__(multi_label)

        self.threshold = threshold

    def serialize(self, sample: Any) -> Union[int, List[int]]:
        sample = super().serialize(sample)
        if self.multi_label:
            one_hot = (sample.sigmoid() > self.threshold).int().tolist()
            result = []
            for index, value in enumerate(one_hot):
                if value == 1:
                    result.append(index)
            return result
        return torch.argmax(sample, -1).tolist()

Alternatively, here’s the Logits Serializer:

class Logits(PredsClassificationSerializer):
    """A :class:`.Serializer` which simply converts the model outputs (assumed to be logits) to a list."""

    def serialize(self, sample: Any) -> Any:
        return super().serialize(sample).tolist()
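
Users can then control the prediction format by setting the serializer on the task before calling predict. Here's a sketch, assuming a trained classification task called model (model and predict_data are placeholders):

from flash.core.classification import Logits

model.serializer = Logits()  # return raw logits instead of class labels
predictions = model.predict(predict_data)  # ``predict_data`` is a placeholder for your inputs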

Take a look at Predictions (inference) to learn more.


Once you’ve added any optional extras, it’s time to create some examples showing your task in action!

The Example

Now you’ve implemented your task, it’s time to add an example showing how cool it is! We usually provide one example in flash_examples/. You can base these off of our template.py examples.

The example should:

  1. download the data (we’ll add the example to our CI later on, so choose a dataset small enough that it runs in reasonable time)

  2. load the data into a DataModule

  3. create an instance of the Task

  4. create a Trainer

  5. call finetune() or fit() to train your model

  6. generate predictions for a few examples

  7. save the checkpoint

For our template example we don’t have a pretrained backbone, so we can just call fit() rather than finetune(). Here’s the full example (flash_examples/template.py):

import numpy as np
import torch
from sklearn import datasets

import flash
from flash.template import TemplateData, TemplateSKLearnClassifier

# 1. Create the DataModule
datamodule = TemplateData.from_sklearn(
    train_bunch=datasets.load_iris(),
    val_split=0.1,
)

# 2. Build the task
model = TemplateSKLearnClassifier(num_features=datamodule.num_features, num_classes=datamodule.num_classes)

# 3. Create the trainer and train the model
trainer = flash.Trainer(max_epochs=3, gpus=torch.cuda.device_count())
trainer.fit(model, datamodule=datamodule)

# 4. Classify a few examples
predictions = model.predict(
    [
        np.array([4.9, 3.0, 1.4, 0.2]),
        np.array([6.9, 3.2, 5.7, 2.3]),
        np.array([7.2, 3.0, 5.8, 1.6]),
    ]
)
print(predictions)

# 5. Save the model!
trainer.save_checkpoint("template_model.pt")

We get this output:

['setosa', 'virginica', 'versicolor']
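
The saved checkpoint can later be reloaded for inference like any LightningModule checkpoint. Here's a sketch, reusing the imports from the example above (the input row is just a placeholder):

model = TemplateSKLearnClassifier.load_from_checkpoint("template_model.pt")
predictions = model.predict([np.array([5.1, 3.5, 1.4, 0.2])])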

Now that you’ve got an example showing your awesome task in action, it’s time to write some tests!

The Tests

Our next step is to create some tests for our Task. For the TemplateSKLearnClassifier, we will just create some basic tests. You should expand on these to include tests for any specific functionality you have in your Task.

Smoke tests

We use smoke tests, usually called test_smoke, throughout. These just instantiate the class we are testing, to check that it can be created without raising any errors.
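
For example, a smoke test for the preprocess could look something like this (a sketch; the real tests follow the same pattern):

class TestTemplatePreprocess:
    """Tests ``TemplatePreprocess``."""

    @staticmethod
    def test_smoke():
        """A simple test that the class can be instantiated."""
        prep = TemplatePreprocess()
        assert prep is not None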

tests/examples/test_scripts.py

Before we write our custom tests, we should add our examples to the CI. To do this, add a line for each example (finetuning and predict) to the annotation of test_example in tests/examples/test_scripts.py. Here’s how those lines look for our template.py examples:

pytest.param(
    "finetuning", "template.py", marks=pytest.mark.skipif(not _SKLEARN_AVAILABLE, reason="sklearn isn't installed")
),
...
pytest.param(
    "predict", "template.py", marks=pytest.mark.skipif(not _SKLEARN_AVAILABLE, reason="sklearn isn't installed")
),

test_data.py

The most important tests in test_data.py check that the from_* methods work correctly. In the class TestTemplateData, we have two of these: test_from_numpy and test_from_sklearn. In general, there should be one test_from_* method for each data_source you have configured.

Here’s the code for test_from_numpy:

    def test_from_numpy(self):
        """Tests that ``TemplateData`` is properly created when using the ``from_numpy`` method."""
        data = np.random.rand(10, self.num_features)
        targets = np.random.randint(0, self.num_classes, (10,))

        # instantiate the data module
        dm = TemplateData.from_numpy(
            train_data=data,
            train_targets=targets,
            val_data=data,
            val_targets=targets,
            test_data=data,
            test_targets=targets,
            batch_size=2,
            num_workers=0,
        )
        assert dm is not None
        assert dm.train_dataloader() is not None
        assert dm.val_dataloader() is not None
        assert dm.test_dataloader() is not None

        # check training data
        data = next(iter(dm.train_dataloader()))
        rows, targets = data[DefaultDataKeys.INPUT], data[DefaultDataKeys.TARGET]
        assert rows.shape == (2, self.num_features)
        assert targets.shape == (2,)

        # check val data
        data = next(iter(dm.val_dataloader()))
        rows, targets = data[DefaultDataKeys.INPUT], data[DefaultDataKeys.TARGET]
        assert rows.shape == (2, self.num_features)
        assert targets.shape == (2,)

        # check test data
        data = next(iter(dm.test_dataloader()))
        rows, targets = data[DefaultDataKeys.INPUT], data[DefaultDataKeys.TARGET]
        assert rows.shape == (2, self.num_features)
        assert targets.shape == (2,)

test_model.py

In test_model.py, we first have test_forward and test_train. These test that tensors can be passed to the forward and that the Task can be trained. Here’s the code for test_forward and test_train:

@pytest.mark.skipif(not _SKLEARN_AVAILABLE, reason="sklearn isn't installed")
@pytest.mark.parametrize("num_classes", [4, 256])
@pytest.mark.parametrize("shape", [(1, 3), (2, 128)])
def test_forward(num_classes, shape):
    """Tests that a tensor can be given to the model forward and gives the correct output size."""
    model = TemplateSKLearnClassifier(
        num_features=shape[1],
        num_classes=num_classes,
    )
    model.eval()

    row = torch.rand(*shape)

    out = model(row)
    assert out.shape == (shape[0], num_classes)


@pytest.mark.skipif(not _SKLEARN_AVAILABLE, reason="sklearn isn't installed")
def test_train(tmpdir):
    """Tests that the model can be trained on our ``DummyDataset``."""
    model = TemplateSKLearnClassifier(num_features=DummyDataset.num_features, num_classes=DummyDataset.num_classes)
    train_dl = torch.utils.data.DataLoader(DummyDataset(), batch_size=4)
    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
    trainer.fit(model, train_dl)

We also include tests for validating and testing: test_val, and test_test. These tests are very similar to test_train, but here they are for completeness:

@pytest.mark.skipif(not _SKLEARN_AVAILABLE, reason="sklearn isn't installed")
def test_val(tmpdir):
    """Tests that the model can be validated on our ``DummyDataset``."""
    model = TemplateSKLearnClassifier(num_features=DummyDataset.num_features, num_classes=DummyDataset.num_classes)
    val_dl = torch.utils.data.DataLoader(DummyDataset(), batch_size=4)
    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
    trainer.validate(model, val_dl)


@pytest.mark.skipif(not _SKLEARN_AVAILABLE, reason="sklearn isn't installed")
def test_test(tmpdir):
    """Tests that the model can be tested on our ``DummyDataset``."""
    model = TemplateSKLearnClassifier(num_features=DummyDataset.num_features, num_classes=DummyDataset.num_classes)
    test_dl = torch.utils.data.DataLoader(DummyDataset(), batch_size=4)
    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
    trainer.test(model, test_dl)

We also include tests for prediction named test_predict_* for each of our data sources. In our case, we have test_predict_numpy and test_predict_sklearn. These tests should use the data_source argument to predict() to select the required DataSource. Here’s test_predict_sklearn as an example:

@pytest.mark.skipif(not _SKLEARN_AVAILABLE, reason="sklearn isn't installed")
def test_predict_sklearn():
    """Tests that we can generate predictions from a scikit-learn ``Bunch``."""
    bunch = datasets.load_iris()
    model = TemplateSKLearnClassifier(num_features=DummyDataset.num_features, num_classes=DummyDataset.num_classes)
    data_pipe = DataPipeline(preprocess=TemplatePreprocess())
    out = model.predict(bunch, data_source="sklearn", data_pipeline=data_pipe)
    assert isinstance(out[0], int)
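
The corresponding test_predict_numpy might look something like this (a sketch following the same pattern; the exact assertions are up to you):

@pytest.mark.skipif(not _SKLEARN_AVAILABLE, reason="sklearn isn't installed")
def test_predict_numpy():
    """Tests that we can generate predictions from a numpy array."""
    row = np.random.rand(1, DummyDataset.num_features)
    model = TemplateSKLearnClassifier(num_features=DummyDataset.num_features, num_classes=DummyDataset.num_classes)
    data_pipe = DataPipeline(preprocess=TemplatePreprocess())
    out = model.predict(row, data_source="numpy", data_pipeline=data_pipe)
    assert isinstance(out[0], int)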

Now that you’ve written the tests, it’s time to add some docs!

The Docs

The final step is to add some docs. For each Task in Flash, we have a docs page in docs/source/reference. You should create a .rst file there with the following:

  • a brief description of the task

  • the predict example

  • the finetuning example

  • any relevant API reference

Here are the contents of docs/source/reference/template.rst, which covers each of these steps:


.. _template:

########
Template
########

********
The Task
********

Here you should add a description of your task. For example:
Classification is the task of assigning one of a number of classes to each data point.

------

*******
Example
*******

.. note::

    Here you should add a short intro to your example, and then use ``literalinclude`` to add it.
    To make it simple, you can fill in this template.

Let's look at the task of <describe the task> using the <data set used in the example>.
The dataset contains <describe the data>.
Here's an outline:

.. code-block::

    <present the folder structure of the data or some data samples here>

Once we've downloaded the data using :func:`~flash.core.data.download_data`, we create the <link to the DataModule with ``:class:``>.
We select a pre-trained backbone to use for our <link to the Task with ``:class:``> and finetune on the <name of the data set> data.
We then use the trained <link to the Task with ``:class:``> for inference.
Finally, we save the model.
Here's the full example:

<include the example with ``literalinclude``>

.. literalinclude:: ../../../flash_examples/template.py
    :language: python
    :lines: 14-

Here’s the rendered doc page!


Once the docs are done, it’s finally time to open a PR and wait for some reviews!


Congratulations on adding your first Task to Flash, we hope to see you again soon!

Flash Governance | Persons of interest

Leads

Core Maintainers

Contributing

Welcome to the PyTorch Lightning community! We’re building the most advanced research platform on the planet to implement the latest, best practices that the amazing PyTorch team rolls out!

Flash Design Principles

We encourage all sorts of contributions you’re interested in adding! When coding for Flash, please follow these principles.

Simple Internal Code

It’s useful for users to look at the code and understand very quickly what’s happening. Many users won’t be engineers. Thus we need to value clear, simple code over condensed ninja moves. While that’s super cool, this isn’t the project for that :)

Force User Decisions To Best Practices

There are 1,000 ways to do something. However, eventually one popular solution becomes standard practice, and everyone follows. We try to find the best way to solve a particular problem, and then force our users to use it for readability and simplicity.

When something becomes a best practice, we add it to the framework. This is usually something like bits of code in utils or in the model file that everyone keeps adding over and over again across projects. When this happens, bring that code inside the trainer and add a flag for it.

Backward-compatible API

We all hate updating our deep learning packages because we don’t want to refactor a bunch of stuff. In Flash, we make sure every change we make which could break an API is backward compatible with good deprecation warnings.

Gain User Trust

As a researcher, you can’t have any part of your code going wrong. So, make thorough tests to ensure that every implementation of a new trick or subtle change is correct.

Interoperability

PyTorch Lightning Flash is highly interoperable with PyTorch Lightning and PyTorch.


Contribution Types

We are always looking for help implementing new features or fixing bugs.

A lot of good work has already been done in project mechanics (requirements.txt, setup.py, pep8, badges, ci, etc…) so we’re in a good state there thanks to all the early contributors (even pre-beta release)!

Bug Fixes:

  1. If you find a bug please submit a GitHub issue.

    • Make sure the title explains the issue.

    • Describe your setup, what you are trying to do, expected vs. actual behaviour. Please add configs and code samples.

    • Add details on how to reproduce the issue - a minimal test case is always best, and a Colab is also great. Note that the sample code should be minimal and, if needed, use publicly available data.

  2. Try to fix it or recommend a solution. We highly recommend using a test-driven approach:

    • Convert your minimal code example to a unit/integration test with assert on expected results.

    • Start by debugging the issue… You can run just this particular test in your IDE and draft a fix.

    • Verify that your test case fails on the master branch and only passes with the fix applied.

  3. Submit a PR!

Note that even if you do not find the solution, sending a PR with a test covering the issue is a valid contribution, and we can help you or finish it with you :]

New Features:

  1. Submit a GitHub issue - describe the motivation for the feature (adding a use case or an example is helpful).

  2. Let’s discuss to determine the feature scope.

  3. Submit a PR! We recommend a test-driven approach for adding new features as well:

    • Write a test for the functionality you want to add.

    • Write the functional code until the test passes.

  4. Add/update the relevant tests!

New Tasks:

Flash is a framework of tasks for fast prototyping, baselining, finetuning and solving business and scientific problems with deep learning. The following are general guidelines for adding new tasks:

  1. Add models which are standard baselines.

  2. Results should be properly reproduced, either by us or by the authors.

  3. Top models which are not SOTA but are highly cited, for production or other uses (e.g. MobileBERT, MobileNets, FBNets), are also welcome.

  4. Do not reinvent the wheel; natively support torchvision, torchtext, and torchaudio models.

  5. Use models with open source licenses.

Please raise an issue before adding a new task. Please let us know why the particular task is important for Flash.

Test cases:

Want to keep Lightning Flash healthy? Love seeing those green tests? So do we! How do we keep it that way? We write tests! We value test contributions even more than new features.

Tests are written using pytest.

Have a look at sample tests here.

After you have added the respective tests, you can run the tests locally with the make script:

make test

Want to add a new test case and not sure how? Talk to us!


Guidelines

For this section, we refer you to the parent PyTorch Lightning guidelines.

Reminder

All added or edited code shall be the original work of the particular contributor. If you use a third-party implementation, all such blocks/functions/modules shall be properly referenced and, if possible, also agreed to by the code’s author. For example: "This code is inspired from http://...". If you are adding new dependencies, make sure that they are compatible with the PyTorch Lightning license (i.e. dependencies should be at least as permissive as the PyTorch Lightning license).

How to rebase my PR?

We recommend creating a PR in a separate branch other than master, especially if you plan to submit several changes and do not want to wait until the first one is resolved (we can work on them in parallel).

First, make sure you have set upstream by running:

git remote add upstream https://github.com/PyTorchLightning/lightning-flash.git

You’ll know it’s set up right if you run git remote -v and see something similar to this:

origin  https://github.com/{YOUR_USERNAME}/lightning-flash.git (fetch)
origin  https://github.com/{YOUR_USERNAME}/lightning-flash.git (push)
upstream        https://github.com/PyTorchLightning/lightning-flash.git (fetch)
upstream        https://github.com/PyTorchLightning/lightning-flash.git (push)

Checkout your feature branch and rebase it with upstream’s master before pushing up your feature branch:

git fetch --all --prune
git rebase upstream/master
# follow git instructions to resolve conflicts
git push -f

Question & Answer

  1. How can I help/contribute?

    All help is extremely welcome - reporting bugs, fixing documentation, adding test cases, solving issues and preparing bug fixes. To solve some issues you can start with the good first issue label or choose something close to your domain with the help wanted label. Before you start to implement anything, check that the issue description is clear, then self-assign the task to yourself (if that is not possible, just comment that you are taking it and we will assign it to you…).

  2. Is there a recommendation for branch names?

    We do not rely on a branch naming convention as long as you are working in your own fork. However, it would be nice to follow this convention: <type>/<issue-id>_<short-name>, where the types are: bugfix, feature, docs, tests, …

  3. I have a model in a framework other than PyTorch, how do I add it here?

    Since PyTorch Lightning is written on top of PyTorch, we need models in PyTorch only. Also, we would need the same or equivalent results with PyTorch Lightning after converting the model from the other framework.

Changelog

All notable changes to this project will be documented in this file.

The format is based on Keep a Changelog.

[0.5.1] - 2021-10-26

[0.5.1] - Added

  • Added LabelStudio integration (#554)

  • Added support learn2learn training_strategy for ImageClassifier (#737)

  • Added vissl training_strategies for ImageEmbedder (#682)

  • Added support for from_data_frame to TextClassificationData (#785)

  • Added FastFace integration (#606)

  • Added support for from_lists to TextClassificationData (#805)

[0.5.1] - Changed

  • Changed the default num_workers on linux to 0 (matching the default for other OS) (#759)

  • Optimizer and LR Scheduler registry are used to get the respective inputs to the Task using a string (or a callable). (#777)

[0.5.1] - Fixed

  • Fixed a bug where additional kwargs (e.g. sampler) passed to tabular data would be ignored (#792)

  • Fixed a bug where loading text data with additional non-numeric columns (not input or target) would give an error (#888)

[0.5.0] - 2021-09-07

[0.5.0] - Added

  • Added support for (input, target) style datasets (e.g. torchvision) to the from_datasets method (#552)

  • Added support for from_csv and from_data_frame to ImageClassificationData (#556)

  • Added SimCLR, SwAV, Barlow-twins pretrained weights for resnet50 backbone in ImageClassifier task (#560)

  • Added support for Semantic Segmentation backbones and heads from segmentation-models.pytorch (#562)

  • Added support for nesting of Task objects (#575)

  • Added PointCloudSegmentation Task (#566)

  • Added PointCloudObjectDetection Task (#600)

  • Added a GraphClassifier task (#73)

  • Added the option to pass pretrained as a string to SemanticSegmentation to change pretrained weights to load from segmentation-models.pytorch (#587)

  • Added support for the field parameter for loading JSON based datasets in text tasks. (#585)

  • Added AudioClassificationData and an example for classifying audio spectrograms (#594)

  • Added a SpeechRecognition task for speech to text using Wav2Vec (#586)

  • Added Flash Zero, a zero code command line ML platform built with flash (#611)

  • Added support for .npy and .npz files to ImageClassificationData and AudioClassificationData (#651)

  • Added support for from_csv to the AudioClassificationData (#651)

  • Added option to pass a resolver to the from_csv and from_pandas methods of ImageClassificationData, which is used to resolve filenames given IDs (#651)

  • Added integration with IceVision for the ObjectDetector (#608)

  • Added keypoint detection task (#608)

  • Added instance segmentation task (#608)

  • Added Torch ORT support to Transformer based tasks (#667)

  • Added support for flash zero with the InstanceSegmentation and KeypointDetector tasks (#672)

  • Added support for in_chans argument to the flash ResNet to control the expected number of input channels (#673)

  • Added a QuestionAnswering task for extractive question answering (#607)

  • Added automatic unwrapping of IceVision prediction objects (#727)

  • Added support for the ObjectDetector with FiftyOne (#727)

  • Added support for MP3 files to the SpeechRecognition task with librosa (#726)

  • Added support for from_numpy and from_tensors to AudioClassificationData (#745)

[0.5.0] - Changed

  • Changed how pretrained flag works for loading weights for ImageClassifier task (#560)

  • Removed bolts pretrained weights for SSL from ImageClassifier task (#560)

  • Changed the behaviour of the sampler argument of the DataModule to take a Sampler type rather than instantiated object (#651)

  • Changed arguments to ObjectDetector, use head instead of model and append _fpn to the backbone name instead of the fpn argument (#608)

[0.5.0] - Fixed

  • Fixed a bug where serve sanity checking would not be triggered using the latest PyTorchLightning version (#493)

  • Fixed a bug where train and validation metrics weren’t being correctly computed (#559)

  • Fixed a bug where an uncaught ValueError could be raised when checking if a module is available (#615)

  • Fixed a bug where some tasks were not compatible with PyTorch 1.7 due to use of torch.jit.isinstance (#611)

  • Fixed a bug where custom samplers would not be properly forwarded to the data loader (#651)

  • Fixed a bug where it was not possible to pass no metrics to the ImageClassifier or TextClassifier (#660)

  • Fixed a bug where drop_last would be set to True during prediction and testing (#671)

  • Fixed a bug where flash was not compatible with pytorch-lightning >= 1.4.3 (#690)

[0.4.0] - 2021-06-22

[0.4.0] - Added

  • Added integration with FiftyOne (#360)

  • Added flash.serve (#399)

  • Added support for torch.jit to tasks where possible and documented task JIT compatibility (#389)

  • Added option to provide a Sampler to the DataModule to use when creating a DataLoader (#390)

  • Added support for multi-label text classification and toxic comments example (#401)

  • Added a sanity checking feature to flash.serve (#423)

[0.4.0] - Changed

  • Split backbone argument to SemanticSegmentation into backbone and head arguments (#412)

[0.4.0] - Fixed

  • Fixed a bug where the DefaultDataKeys.METADATA couldn’t be a dict (#393)

  • Fixed a bug where the SemanticSegmentation task would not work as expected with finetuning callbacks (#412)

  • Fixed a bug where predict batches could not be visualized with ImageClassificationData (#438)

[0.3.2] - 2021-06-08

[0.3.2] - Fixed

  • Fixed a bug where flash.Trainer.from_argparse_args + finetune would not work (#382)

[0.3.1] - 2021-06-08

[0.3.1] - Added

  • Added deeplabv3, lraspp, and unet backbones for the SemanticSegmentation task (#370)

[0.3.1] - Changed

  • Changed the installation command for extra features (#346)

  • Change resize interpolation default mode to nearest (#352)

[0.3.1] - Deprecated

  • Deprecated SemanticSegmentation backbone names torchvision/fcn_resnet50 and torchvision/fcn_resnet101, use fcn_resnet50 and fcn_resnet101 instead (#370)

[0.3.1] - Fixed

  • Fixed flash.Trainer.add_argparse_args not adding any arguments (#343)

  • Fixed a bug where the translation task wasn’t decoding tokens properly (#332)

  • Fixed a bug where huggingface tokenizers were sometimes being pickled (#332)

  • Fixed issue with KorniaParallelTransforms to assure to share the random state between transforms (#351)

  • Fixed a bug where using val_split with overfit_batches would give an infinite recursion (#375)

  • Fixed a bug where some timm models were mistakenly given a global_pool argument (#377)

  • Fixed flash.Trainer.from_argparse_args not passing arguments correctly (#380)

[0.3.0] - 2021-05-20

[0.3.0] - Added

  • Added DataPipeline API (#188 #141 #207)

  • Added timm integration (#196)

  • Added BaseViz Callback (#201)

  • Added backbone API (#204)

  • Added support for Iterable auto dataset (#227)

  • Added multi label support (#230)

  • Added support for schedulers (#232)

  • Added visualisation callback for image classification (#228)

  • Added Video Classification task (#216)

  • Added Dino backbone for image classification (#259)

  • Added Data Sources API (#256 #264 #272)

  • Refactor preprocess_cls to preprocess, add Serializer, add DataPipelineState (#229)

  • Added Semantic Segmentation task (#239 #287 #290)

  • Added Object detection prediction example (#283)

  • Added Style Transfer task and accompanying finetuning and prediction examples (#262)

  • Added a Template task and tutorials showing how to contribute a task to flash (#306)

[0.3.0] - Changed

  • Rename valid_ to val_ (#197)

  • Refactor preprocess_cls to preprocess, add Serializer, add DataPipelineState (#229)

[0.3.0] - Fixed

  • Fix DataPipeline resolution in Task (#212)

  • Fixed a bug where the backbone used in summarization was not correctly passed to the postprocess (#296)

[0.2.3] - 2021-04-17

[0.2.3] - Added

  • Added TIMM integration as backbones (#196)

[0.2.3] - Fixed

  • Fixed nltk.download (#210)

[0.2.2] - 2021-04-05

[0.2.2] - Changed

  • Switch to use torchmetrics (#169)

  • Better support for optimizer and schedulers (#232)

  • Update lightning version to v1.2 (#133)

[0.2.2] - Fixed

  • Fixed classification softmax (#169)

  • Fixed a bug where loading from a local checkpoint that had pretrained=True without an internet connection would sometimes raise an error (#237)

  • Don’t download data if exists (#157)

[0.2.1] - 2021-3-06

[0.2.1] - Added

  • Added RetinaNet & backbones to ObjectDetector Task (#121)

  • Added .csv image loading utils (#116, #117, #118)

[0.2.1] - Changed

  • Set inputs as optional (#109)

[0.2.1] - Fixed

  • Set minimal requirements (#62)

  • Fixed VGG backbone num_features (#154)

[0.2.0] - 2021-02-12

[0.2.0] - Added

  • Added ObjectDetector Task (#56)

  • Added TabNet for tabular classification (#101)

  • Added support for more backbones (mobilenet, vgg, densenet, resnext) (#45)

  • Added backbones for image embedding model (#63)

  • Added SWAV and SimCLR models to imageclassifier + backbone reorg (#68)

[0.2.0] - Changed

  • Applied transform in FilePathDataset (#97)

  • Moved classification integration from vision root to folder (#86)

[0.2.0] - Fixed

  • Unfreeze default number of workers in datamodule (#57)

  • Fixed wrong label in FilePathDataset (#94)

[0.2.0] - Removed

  • Removed densenet161 duplicate in DENSENET_MODELS (#76)

  • Removed redundant num_features arg from Classification model (#88)

[0.1.0] - 2021-02-02

[0.1.0] - Added

  • Added flash_notebook examples (#9)

  • Added strategy to trainer.finetune with NoFreeze, Freeze, FreezeUnfreeze, UnfreezeMilestones Callbacks (#39)

  • Added SummarizationData, SummarizationTask and TranslationData, TranslationTask (#37)

  • Added ImageEmbedder (#36)

Template

The Task

Here you should add a description of your task. For example: Classification is the task of assigning one of a number of classes to each data point.


Example

Note

Here you should add a short intro to your example, and then use literalinclude to add it. To make it simple, you can fill in this template.

Let’s look at the task of <describe the task> using the <data set used in the example>. The dataset contains <describe the data>. Here’s an outline:

<present the folder structure of the data or some data samples here>

Once we’ve downloaded the data using download_data(), we create the <link to the DataModule with :class:>. We select a pre-trained backbone to use for our <link to the Task with :class:> and finetune on the <name of the data set> data. We then use the trained <link to the Task with :class:> for inference. Finally, we save the model. Here’s the full example:

<include the example with literalinclude>

import numpy as np
import torch
from sklearn import datasets

import flash
from flash.template import TemplateData, TemplateSKLearnClassifier

# 1. Create the DataModule
datamodule = TemplateData.from_sklearn(
    train_bunch=datasets.load_iris(),
    val_split=0.1,
)

# 2. Build the task
model = TemplateSKLearnClassifier(num_features=datamodule.num_features, num_classes=datamodule.num_classes)

# 3. Create the trainer and train the model
trainer = flash.Trainer(max_epochs=3, gpus=torch.cuda.device_count())
trainer.fit(model, datamodule=datamodule)

# 4. Classify a few examples
predictions = model.predict(
    [
        np.array([4.9, 3.0, 1.4, 0.2]),
        np.array([6.9, 3.2, 5.7, 2.3]),
        np.array([7.2, 3.0, 5.8, 1.6]),
    ]
)
print(predictions)

# 5. Save the model!
trainer.save_checkpoint("template_model.pt")
