Lightning Flash¶
Quick Start¶
Flash is a high-level deep learning framework for fast prototyping, baselining, finetuning and solving deep learning problems. It features a set of tasks for you to use for inference and finetuning out of the box, and an easy to implement API to customize every step of the process for full flexibility.
Flash is built for beginners, with a simple API that requires very little deep learning background, and for data scientists, Kagglers, applied ML practitioners and deep learning researchers who want a quick way to get a deep learning baseline with the advanced features PyTorch Lightning offers.
Why Flash?¶
For getting started with Deep Learning¶
Easy to learn¶
If you are just getting started with deep learning, Flash offers common deep learning tasks you can use out-of-the-box in a few lines of code, no math, fancy nn.Modules or research experience required!
Easy to scale¶
Flash is built on top of PyTorch Lightning, a powerful deep learning research framework for training models at scale. With the power of Lightning, you can train your flash tasks on any hardware: CPUs, GPUs or TPUs without any code changes.
Easy to upskill¶
If you want to create more complex and customized models, you can refactor any part of flash with PyTorch or PyTorch Lightning components to get all the flexibility you need. Lightning is just organized PyTorch with the unnecessary engineering details abstracted away.
Flash (high-level)
Lightning (mid-level)
PyTorch (low-level)
When you need more flexibility you can build your own tasks or simply use Lightning directly.
For Deep learning research¶
Quickest way to a baseline¶
PyTorch Lightning is designed to abstract away unnecessary boilerplate while enabling maximal flexibility. To provide that flexibility, solving very common deep learning problems such as classification in Lightning still requires some boilerplate, so it can take quite some time to get a baseline model running on a new dataset or out-of-domain task. We created Flash to answer our users' need for a quick way to build a Lightning baseline using proven backbones for common data patterns. Flash aims to be the easiest starting point for your research: start with a Flash Task to benchmark against, and override any part of Flash with Lightning or PyTorch components on your way to SOTA research.
Flexibility where you want it¶
Flash tasks are essentially LightningModules, and the Flash Trainer is a thin wrapper for the Lightning Trainer. You can use your own LightningModule instead of the Flash task, the Lightning Trainer instead of the Flash Trainer, etc. Flash helps you focus on your research and less on everything else.
Standard best practices¶
Flash tasks implement the standard best practices for a variety of different models and domains, to save you time digging through different implementations. Flash abstracts even more details than Lightning, allowing deep learning experts to share their tips and tricks for solving scoped deep learning problems.
Tasks¶
Flash is a collection of Tasks. Flash tasks are laser-focused objects designed to solve a well-defined type of problem using state-of-the-art methods.
A Flash task contains all the relevant information to solve the task at hand: the number of class labels you want to predict, the number of columns in your dataset, as well as details of the model architecture used, such as the loss function, optimizers, etc.
Here are examples of tasks:
from flash.text import TextClassifier
from flash.image import ImageClassifier
from flash.tabular import TabularClassifier
Note
Tasks are inflexible by definition! To get more flexibility, you can simply use LightningModule directly or modify an existing task in just a few lines.
Inference¶
Inference is the process of generating predictions from trained models. To use a task for inference:
1. Init your task with pretrained weights using a checkpoint (a checkpoint is simply a file that captures the exact value of all parameters used by a model). A local file or a URL works.
2. Pass in the data to flash.core.model.Task.predict().
Here’s an example of inference:
# import our libraries
from flash.text import TextClassifier
# 1. Init the finetuned task from URL
model = TextClassifier.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/text_classification_model.pt")
# 2. Perform inference from list of sequences
predictions = model.predict(
[
"Turgid dialogue, feeble characterization - Harvey Keitel a judge?.",
"The worst movie in the history of cinema.",
"This guy has done a great job with this movie!",
]
)
print(predictions)
We get the following output:
["negative", "negative", "positive"]
Finetuning¶
Finetuning (or transfer-learning) is the process of tweaking a model trained on a large dataset, to your particular (likely much smaller) dataset. All Flash tasks have pre-trained backbones that are already trained on large datasets such as ImageNet. Finetuning on pretrained models decreases training time significantly.
To use a Task for finetuning:
1. Load your data and organize it using a DataModule customized for the task (example: ImageClassificationData).
2. Choose and initialize your Task which has state-of-the-art backbones built in (example: ImageClassifier).
3. Init a flash.core.trainer.Trainer.
4. Choose a finetune strategy (example: “freeze”) and call flash.core.trainer.Trainer.finetune() with your data.
5. Save your finetuned model.
Here’s an example of finetuning.
import torch
from pytorch_lightning import seed_everything
import flash
from flash.core.classification import Labels
from flash.core.data.utils import download_data
from flash.image import ImageClassificationData, ImageClassifier
# set the random seeds.
seed_everything(42)
# 1. Download and organize the data
download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip", "data/")
datamodule = ImageClassificationData.from_folders(
train_folder="data/hymenoptera_data/train/",
val_folder="data/hymenoptera_data/val/",
test_folder="data/hymenoptera_data/test/",
)
# 2. Build the model using desired Task
model = ImageClassifier(backbone="resnet18", num_classes=datamodule.num_classes)
# 3. Create the trainer (run one epoch for demo)
trainer = flash.Trainer(max_epochs=1, gpus=torch.cuda.device_count())
# 4. Finetune the model
trainer.finetune(model, datamodule=datamodule, strategy="freeze")
# 5. Save the model!
trainer.save_checkpoint("image_classification_model.pt")
Using a finetuned model¶
Once you’ve finetuned, use the model to predict:
# Serialize predictions as labels, automatically inferred from the training data in part 2.
model.serializer = Labels()
predictions = model.predict(
[
"data/hymenoptera_data/val/bees/65038344_52a45d090d.jpg",
"data/hymenoptera_data/val/ants/2255445811_dabcdf7258.jpg",
]
)
print(predictions)
We get the following output:
['bees', 'ants']
Or you can use the saved model for prediction anywhere you want!
from flash.image import ImageClassifier
# load finetuned checkpoint
model = ImageClassifier.load_from_checkpoint("image_classification_model.pt")
predictions = model.predict("path/to/your/own/image.png")
Training¶
When you have enough data, you’re likely better off training from scratch instead of finetuning.
To train a task from scratch:
1. Load your data and organize it using a DataModule customized for the task (example: ImageClassificationData).
2. Choose and initialize your Task (setting pretrained=False) which has state-of-the-art backbones built in (example: ImageClassifier).
3. Init a flash.core.trainer.Trainer or a pytorch_lightning.trainer.Trainer.
4. Call flash.core.trainer.Trainer.fit() with your data set.
5. Save your trained model.
Here’s an example:
import torch
from pytorch_lightning import seed_everything
import flash
from flash.core.classification import Labels
from flash.core.data.utils import download_data
from flash.image import ImageClassificationData, ImageClassifier
# set the random seeds.
seed_everything(42)
# 1. Download and organize the data
download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip", "data/")
datamodule = ImageClassificationData.from_folders(
train_folder="data/hymenoptera_data/train/",
val_folder="data/hymenoptera_data/val/",
test_folder="data/hymenoptera_data/test/",
)
# 2. Build the model using desired Task
model = ImageClassifier(backbone="resnet18", num_classes=datamodule.num_classes, pretrained=False)
# 3. Create the trainer (run one epoch for demo)
trainer = flash.Trainer(max_epochs=1, gpus=torch.cuda.device_count())
# 4. Train the model
trainer.fit(model, datamodule=datamodule)
# 5. Save the model!
trainer.save_checkpoint("image_classification_model.pt")
A few Built-in Tasks¶
More tasks coming soon!
Contribute a task¶
The Lightning + Flash team is hard at work building more tasks for common deep-learning use cases. But we’re looking for incredible contributors like you to submit new tasks!
Join our Slack to get help becoming a contributor!
Installation¶
Install with pip¶
pip install lightning-flash
Optionally, you can install Flash with extra packages for each domain.
For a single domain, use: pip install 'lightning-flash[{DOMAIN}]'.
pip install 'lightning-flash[image]'
pip install 'lightning-flash[tabular]'
pip install 'lightning-flash[text]'
...
For multiple domains, use: pip install 'lightning-flash[{DOMAIN_1, DOMAIN_2, ...}]'.
pip install 'lightning-flash[audio,image]'
...
For contributors, please install Flash with packages for testing Flash and building docs.
# Clone Flash repository locally
git clone https://github.com/[your username]/lightning-flash.git
cd lightning-flash
# Install Flash in editable mode with extra packages for development
pip install -e '.[dev]'
Install with conda¶
Flash is available via conda-forge. Install it with:
conda install -c conda-forge lightning-flash
Install from source¶
You can install Flash from source without any domain specific dependencies with:
pip install 'git+https://github.com/PyTorchLightning/lightning-flash.git'
To install Flash with domain dependencies, use:
pip install 'git+https://github.com/PyTorchLightning/lightning-flash.git#egg=lightning-flash[image]'
You can again install dependencies for multiple domains by separating them with commas as above.
Flash Zero¶
Flash Zero is a zero-code machine learning platform. Here’s an image classification example to illustrate it, using one of the dozens of tasks available.
Flash Zero in 3 steps¶
1. Select your task¶
flash {TASK_NAME}
Here is the list of currently supported tasks.
audio_classification Classify audio spectrograms.
graph_classification Classify graphs.
image_classification Classify images.
instance_segmentation Segment object instances in images.
keypoint_detection Detect keypoints in images.
object_detection Detect objects in images.
pointcloud_detection Detect objects in point clouds.
pointcloud_segmentation Segment objects in point clouds.
question_answering Extractive Question Answering.
semantic_segmentation Segment objects in images.
speech_recognition Speech recognition.
style_transfer Image style transfer.
summarization Summarize text.
tabular_classification Classify tabular data.
text_classification Classify text.
translation Translate text.
video_classification Classify videos.
2. Pass in your own data¶
flash image_classification from_folders --train_folder data/hymenoptera_data/train
3. Modify the model and training parameters¶
flash image_classification --trainer.max_epochs 10 --model.backbone resnet50 from_folders --train_folder data/hymenoptera_data/train
Note
The trainer and model arguments should be placed before the source subcommand. Here it is from_folders.
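The ordering rule in the note above can be sketched in a few lines of Python. The helper below, `build_flash_command`, is hypothetical (it is not part of Flash); it only assembles the argument list so that the trainer and model options always precede the source subcommand.

```python
import shlex
from typing import Dict, List


def build_flash_command(
    task: str,
    source: str,
    task_args: Dict[str, object],
    source_args: Dict[str, object],
) -> List[str]:
    """Assemble a Flash Zero command with the arguments in the required order."""
    command = ["flash", task]
    # Trainer and model options (e.g. "trainer.max_epochs") go before the subcommand.
    for name, value in task_args.items():
        command += [f"--{name}", str(value)]
    # The data source subcommand and its own options come last.
    command.append(source)
    for name, value in source_args.items():
        command += [f"--{name}", str(value)]
    return command


command = build_flash_command(
    task="image_classification",
    source="from_folders",
    task_args={"trainer.max_epochs": 10, "model.backbone": "resnet50"},
    source_args={"train_folder": "data/hymenoptera_data/train"},
)
print(shlex.join(command))
```

The printed string can be pasted into a shell, or the list can be passed to `subprocess.run` directly.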
Other Examples¶
Image Object Detection¶
To train an Object Detector on the COCO 2017 dataset, you could use the following command:
flash object_detection from_coco --train_folder data/coco128/images/train2017/ --train_ann_file data/coco128/annotations/instances_train2017.json --val_split .3 --batch_size 8 --num_workers 4
Image Object Segmentation¶
To train an Image Segmenter on the CARLA driving simulator dataset, you could use the following command:
flash semantic_segmentation from_folders --train_folder data/CameraRGB --train_target_folder data/CameraSeg --num_classes 21
Below is an example where the head, the backbone and its pretrained weights are customized.
flash semantic_segmentation --model.head fpn --model.backbone efficientnet-b0 --model.pretrained advprop from_folders --train_folder data/CameraRGB --train_target_folder data/CameraSeg --num_classes 21
Video Classification¶
To train a Video Classifier on the Kinetics dataset, you could use the following command:
flash video_classification from_folders --train_folder data/kinetics/train/ --clip_duration 1 --num_workers 0
CLI options¶
Flash Zero is built on top of the lightning CLI, so the trainer and model arguments can be configured either from the command line or from a config file. For example, to run the image classifier for 10 epochs with a resnet50 backbone you can use:
flash image_classification --trainer.max_epochs 10 --model.backbone resnet50
To view all of the available options for a task, run:
flash image_classification --help
Using Your Own Data¶
Flash Zero works with your own data through subcommands. The available subcommands for each task are given at the bottom of their help pages (e.g. when running flash image_classification --help). You can then use the required subcommand to train on your own data. Let’s look at an example using the Hymenoptera data from the Image Classification guide. First, download and unzip your data:
curl https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip -o hymenoptera_data.zip
unzip hymenoptera_data.zip
Now train with Flash Zero:
flash image_classification from_folders --train_folder ./hymenoptera_data/train
Getting Help¶
To find all available tasks, you can run:
flash --help
This will output the following:
Commands:
audio_classification Classify audio spectrograms.
graph_classification Classify graphs.
image_classification Classify images.
instance_segmentation Segment object instances in images.
keypoint_detection Detect keypoints in images.
object_detection Detect objects in images.
pointcloud_detection Detect objects in point clouds.
pointcloud_segmentation Segment objects in point clouds.
question_answering Extractive Question Answering.
semantic_segmentation Segment objects in images.
speech_recognition Speech recognition.
style_transfer Image style transfer.
summarization Summarize text.
tabular_classification Classify tabular data.
text_classification Classify text.
translation Translate text.
video_classification Classify videos.
To get more information about a specific task, you can do the following:
flash image_classification --help
You can view the help page for each subcommand. For example, to view the options for training an image classifier from folders, you can run:
flash image_classification from_folders --help
Finally, you can generate a config.yaml file from the command line, which makes it easier to modify parameters, by running:
flash image_classification --print_config > config.yaml
Flash in Production¶
Flash Serve¶
Flash Serve makes model deployment simple.
Server Side¶
from flash.image import SemanticSegmentation
from flash.image.segmentation.serialization import SegmentationLabels
model = SemanticSegmentation.load_from_checkpoint(
"https://flash-weights.s3.amazonaws.com/semantic_segmentation_model.pt"
)
model.serializer = SegmentationLabels(visualize=False)
model.serve()
Client Side¶
import base64
from pathlib import Path
import requests
import flash
with (Path(flash.ASSETS_ROOT) / "road.png").open("rb") as f:
imgstr = base64.b64encode(f.read()).decode("UTF-8")
body = {"session": "UUID", "payload": {"inputs": {"data": imgstr}}}
resp = requests.post("http://127.0.0.1:8000/predict", json=body)
print(resp.json())
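To make the request format in the client example easier to inspect without a running server, here is a minimal sketch that builds the same JSON body and decodes it again. Both helper functions are hypothetical names introduced for illustration, and the decoding side is only an assumption about what a receiver has to do with the payload, not Flash Serve's implementation.

```python
import base64
import json
from typing import Any, Dict


def build_predict_body(image_bytes: bytes, session: str = "UUID") -> Dict[str, Any]:
    """Wrap raw image bytes in the JSON structure used by the client example."""
    encoded = base64.b64encode(image_bytes).decode("UTF-8")
    return {"session": session, "payload": {"inputs": {"data": encoded}}}


def extract_image_bytes(body: Dict[str, Any]) -> bytes:
    """Recover the raw bytes from a request body (hypothetical receiving side)."""
    return base64.b64decode(body["payload"]["inputs"]["data"])


# Stand-in bytes instead of a real image file.
fake_image = b"\x89PNG\r\n\x1a\n" + bytes(range(16))
body = build_predict_body(fake_image)

# The body survives a JSON round trip because the image is base64 text.
wire = json.dumps(body)
restored = extract_image_bytes(json.loads(wire))
print(restored == fake_image)
```

Base64 is used because raw bytes cannot be placed in a JSON document directly.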
Credits to @rlizzo, @hhsecond, @lantiga, @luiscape for building the Flash Serve Engine. Read all about it here.
Training from scratch¶
Some Flash tasks have been pretrained on large data sets.
To accelerate your training, you can call the finetune() method with a pretrained backbone, which fine-tunes the backbone to produce a model customized to your data set and desired task.
From the Quick Start guide.
To train a task from scratch:
1. Load your data and organize it using a DataModule customized for the task (example: ImageClassificationData).
2. Choose and initialize your Task (setting pretrained=False) which has state-of-the-art backbones built in (example: ImageClassifier).
3. Init a flash.core.trainer.Trainer or a pytorch_lightning.trainer.Trainer.
4. Call flash.core.trainer.Trainer.fit() with your data set.
5. Save your trained model.
Here’s an example:
import torch
from pytorch_lightning import seed_everything
import flash
from flash.core.classification import Labels
from flash.core.data.utils import download_data
from flash.image import ImageClassificationData, ImageClassifier
# set the random seeds.
seed_everything(42)
# 1. Download and organize the data
download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip", "data/")
datamodule = ImageClassificationData.from_folders(
train_folder="data/hymenoptera_data/train/",
val_folder="data/hymenoptera_data/val/",
test_folder="data/hymenoptera_data/test/",
)
# 2. Build the model using desired Task
model = ImageClassifier(backbone="resnet18", num_classes=datamodule.num_classes, pretrained=False)
# 3. Create the trainer (run one epoch for demo)
trainer = flash.Trainer(max_epochs=1, gpus=torch.cuda.device_count())
# 4. Train the model
trainer.fit(model, datamodule=datamodule)
# 5. Save the model!
trainer.save_checkpoint("image_classification_model.pt")
Training options¶
Flash tasks support many advanced training functionalities out-of-the-box, such as:
Limiting the number of epochs
# train for 10 epochs
flash.Trainer(max_epochs=10)
Training on GPUs
# train on 1 GPU
flash.Trainer(gpus=1)
Training on multiple GPUs
# train on multiple GPUs
flash.Trainer(gpus=4)
# train on gpu 1, 3, 5 (3 gpus total)
flash.Trainer(gpus=[1, 3, 5])
Using mixed precision training
# Multi GPU with mixed precision
flash.Trainer(gpus=2, precision=16)
Training on TPUs
# Train on TPUs
flash.Trainer(tpu_cores=8)
You can add to the flash Trainer any argument from the Lightning trainer! Learn more about the Lightning Trainer here.
Finetuning¶
Finetuning (or transfer-learning) is the process of tweaking a model trained on a large dataset, to your particular (likely much smaller) dataset.
Terminology¶
Here are common terms you need to be familiar with:
Term | Definition |
---|---|
Finetuning | The process of tweaking a model trained on a large dataset, to your particular (likely much smaller) dataset |
Transfer learning | The common name for finetuning |
Backbone | The neural network that was pretrained on a different dataset |
Head | Another neural network (usually smaller) that maps the backbone to your particular dataset |
Freeze | Disabling gradient updates to a model (i.e. not learning) |
Unfreeze | Enabling gradient updates to a model |
Finetuning in Flash¶
From the Quick Start guide.
To use a Task for finetuning:
1. Load your data and organize it using a DataModule customized for the task (example: ImageClassificationData).
2. Choose and initialize your Task which has state-of-the-art backbones built in (example: ImageClassifier).
3. Init a flash.core.trainer.Trainer.
4. Choose a finetune strategy (example: “freeze”) and call flash.core.trainer.Trainer.finetune() with your data.
5. Save your finetuned model.
Here’s an example of finetuning.
import torch
from pytorch_lightning import seed_everything
import flash
from flash.core.classification import Labels
from flash.core.data.utils import download_data
from flash.image import ImageClassificationData, ImageClassifier
# set the random seeds.
seed_everything(42)
# 1. Download and organize the data
download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip", "data/")
datamodule = ImageClassificationData.from_folders(
train_folder="data/hymenoptera_data/train/",
val_folder="data/hymenoptera_data/val/",
test_folder="data/hymenoptera_data/test/",
)
# 2. Build the model using desired Task
model = ImageClassifier(backbone="resnet18", num_classes=datamodule.num_classes)
# 3. Create the trainer (run one epoch for demo)
trainer = flash.Trainer(max_epochs=1, gpus=torch.cuda.device_count())
# 4. Finetune the model
trainer.finetune(model, datamodule=datamodule, strategy="freeze")
# 5. Save the model!
trainer.save_checkpoint("image_classification_model.pt")
Using a finetuned model¶
Once you’ve finetuned, use the model to predict:
# Serialize predictions as labels, automatically inferred from the training data in part 2.
model.serializer = Labels()
predictions = model.predict(
[
"data/hymenoptera_data/val/bees/65038344_52a45d090d.jpg",
"data/hymenoptera_data/val/ants/2255445811_dabcdf7258.jpg",
]
)
print(predictions)
We get the following output:
['bees', 'ants']
Or you can use the saved model for prediction anywhere you want!
from flash.image import ImageClassifier
# load finetuned checkpoint
model = ImageClassifier.load_from_checkpoint("image_classification_model.pt")
predictions = model.predict("path/to/your/own/image.png")
Finetune strategies¶
Finetuning is very task specific. Each task encodes the best finetuning practices for that task. However, Flash gives you a few default strategies for finetuning.
Finetuning operates on two things, the model backbone and the head. The backbone is the neural network that was pre-trained. The head is another neural network that bridges between the backbone and your particular dataset.
no_freeze¶
In this strategy, the backbone and the head are unfrozen from the beginning.
trainer.finetune(model, datamodule, strategy="no_freeze")
In pseudocode, this looks like:
backbone = Resnet50()
head = nn.Linear(...)
backbone.unfreeze()
head.unfreeze()
train(backbone, head)
freeze¶
The freeze strategy keeps the backbone frozen throughout.
trainer.finetune(model, datamodule, strategy="freeze")
The pseudocode looks like:
backbone = Resnet50()
head = nn.Linear(...)
# freeze backbone
backbone.freeze()
head.unfreeze()
train(backbone, head)
Advanced strategies¶
Every finetune strategy can also be customized.
freeze_unfreeze¶
By default, in this strategy the backbone is frozen for 10 epochs then unfrozen:
trainer.finetune(model, datamodule, strategy="freeze_unfreeze")
Or we can customize it to unfreeze the backbone after a different epoch. For example, to unfreeze after epoch 7:
from flash.core.finetuning import FreezeUnfreeze
trainer.finetune(model, datamodule, strategy=FreezeUnfreeze(unfreeze_epoch=7))
Under the hood, the pseudocode looks like:
backbone = Resnet50()
head = nn.Linear(...)
# freeze backbone
backbone.freeze()
head.unfreeze()
train(backbone, head, epochs=10)
# unfreeze after 10 epochs
backbone.unfreeze()
train(backbone, head)
unfreeze_milestones¶
This strategy allows you to unfreeze part of the backbone at predetermined intervals.
Here’s an example where:
- the backbone starts frozen
- at epoch 3 the last 2 layers unfreeze
- at epoch 8 the full backbone unfreezes
from flash.core.finetuning import UnfreezeMilestones
trainer.finetune(model, datamodule, strategy=UnfreezeMilestones(unfreeze_milestones=(3, 8), num_layers=2))
Under the hood, the pseudocode looks like:
backbone = Resnet50()
head = nn.Linear(...)
# freeze backbone
backbone.freeze()
head.unfreeze()
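# train with the backbone frozen until epoch 3
train(backbone, head, epochs=3)
# unfreeze last 2 layers at epoch 3
backbone.unfreeze_last_layers(2)
# continue training until epoch 8
train(backbone, head, epochs=8)
# unfreeze the full backbone
backbone.unfreeze()
# train the full model for the remaining epochs
train(backbone, head)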
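The milestone schedule described above can also be written as a small pure-Python function. This is only a sketch of the schedule's logic, not Flash's implementation: the function name, the 10-layer backbone and the assumption that a milestone takes effect at the start of the named epoch are all illustrative.

```python
from typing import Tuple


def trainable_backbone_layers(
    epoch: int,
    total_layers: int,
    unfreeze_milestones: Tuple[int, int] = (3, 8),
    num_layers: int = 2,
) -> int:
    """Return how many trailing backbone layers receive gradient updates at `epoch`."""
    first_milestone, second_milestone = unfreeze_milestones
    if epoch < first_milestone:
        return 0  # backbone fully frozen, only the head trains
    if epoch < second_milestone:
        return min(num_layers, total_layers)  # last `num_layers` layers unfrozen
    return total_layers  # full backbone unfrozen


# Hypothetical 10-layer backbone, default milestones (3, 8).
for epoch in (0, 3, 7, 8):
    print(epoch, trainable_backbone_layers(epoch, total_layers=10))
```

With `unfreeze_milestones=(3, 8)` and `num_layers=2`, nothing in the backbone trains before epoch 3, two layers train from epoch 3 to 7, and everything trains from epoch 8 onward.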
Custom Strategy¶
For even more customization, create your own finetuning callback. Learn more about callbacks here.
from flash.core.finetuning import FlashBaseFinetuning


# Create a finetuning callback
class FeatureExtractorFreezeUnfreeze(FlashBaseFinetuning):
    def __init__(self, unfreeze_epoch: int = 5, train_bn: bool = True):
        # this will set self.attr_names as ["backbone"]
        super().__init__("backbone", train_bn)
        self._unfreeze_epoch = unfreeze_epoch

    def finetune_function(self, pl_module, current_epoch, optimizer, opt_idx):
        # unfreeze any module you want by overriding this function
        # When ``current_epoch`` is 5, backbone will start to be trained.
        if current_epoch == self._unfreeze_epoch:
            self.unfreeze_and_add_param_group(
                pl_module.backbone,
                optimizer,
            )


# Pass the callback to trainer.finetune
trainer.finetune(model, datamodule, strategy=FeatureExtractorFreezeUnfreeze(unfreeze_epoch=5))
Predictions (inference)¶
You can use Flash to get predictions on pretrained or finetuned models.
Predict on a single sample of data¶
You can pass in a sample of data (image file path, a string of text, etc.) to the predict() method.
from flash.core.data.utils import download_data
from flash.image import ImageClassifier
# 1. Download the data set
download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip", "data/")
# 2. Load the model from a checkpoint
model = ImageClassifier.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/image_classification_model.pt")
# 3. Predict whether the image contains an ant or a bee
predictions = model.predict("data/hymenoptera_data/val/bees/65038344_52a45d090d.jpg")
print(predictions)
Predict on a csv file¶
from flash.core.data.utils import download_data
from flash.tabular import TabularClassifier
# 1. Download the data
download_data("https://pl-flash-data.s3.amazonaws.com/titanic.zip", "data/")
# 2. Load the model from a checkpoint
model = TabularClassifier.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/tabnet_classification_model.pt")
# 3. Generate predictions from a csv file! Who would survive?
predictions = model.predict("data/titanic/titanic.csv")
print(predictions)
Serializing predictions¶
To change how predictions are serialized you can attach a Serializer to your Task. For example, you can choose to serialize outputs as probabilities (for more options see the API reference below).
from flash.core.classification import Probabilities
from flash.core.data.utils import download_data
from flash.image import ImageClassifier
# 1. Download the data set
download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip", "data/")
# 2. Load the model from a checkpoint
model = ImageClassifier.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/image_classification_model.pt")
# 3. Attach the Serializer
model.serializer = Probabilities()
# 4. Predict whether the image contains an ant or a bee
predictions = model.predict("data/hymenoptera_data/val/bees/65038344_52a45d090d.jpg")
print(predictions)
# out: [[0.5926494598388672, 0.40735048055648804]]
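To illustrate how probabilities like these relate to label output, here is a minimal pure-Python sketch that picks the most probable class per sample. The function name and the class order are assumptions made for the example; this is not the code of any Flash serializer.

```python
from typing import List, Sequence


def probabilities_to_labels(
    probabilities: Sequence[Sequence[float]], labels: Sequence[str]
) -> List[str]:
    """Pick the label with the highest probability for each sample."""
    results = []
    for sample in probabilities:
        best_index = max(range(len(sample)), key=lambda i: sample[i])
        results.append(labels[best_index])
    return results


# The class order below is an assumption made for this sketch.
class_names = ["ants", "bees"]
print(probabilities_to_labels([[0.2, 0.8], [0.9, 0.1]], class_names))
```

Keeping probabilities as the serialized output is useful when you want to apply your own threshold instead of always taking the top class.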
TorchScript JIT Support¶
We test all of our tasks for compatibility with torch.jit.
This table gives a breakdown of the supported features.
Task | | | |
---|---|---|---|
| Yes | Yes | Yes |
| Yes | No | Yes |
| Yes | Yes | Yes |
| No | Yes | Yes |
| No | Yes | Yes |
| No | Yes | No |
| No | Yes * | Yes |
| No | Yes | Yes |
| No | Yes | Yes |
| No | Yes | Yes |
* with strict=False
Data¶
How to use out-of-the-box Flash DataModules¶
Flash provides several DataModules with helper functions. Check out the Image Classification section (or the sections for any of our other tasks) to learn more.
Data Processing¶
Currently, it is common practice to implement a torch.utils.data.Dataset and provide it to a torch.utils.data.DataLoader.
However, after model training, it takes a lot of engineering overhead to run inference on raw data and deploy the model in a production environment.
Usually, extra processing logic has to be added to bridge the gap between training data and raw data.
The DataSource class can be used to generate data sets from multiple sources (e.g. folders, numpy, etc.), that can then all be transformed in the same way.
The Preprocess and Postprocess classes can be used to manage the preprocessing and postprocessing transforms.
The Serializer class provides the logic for converting Postprocess outputs to the desired predict format (e.g. classes, labels, probabilities, etc.).
By providing a series of hooks that can be overridden with custom data processing logic (or just targeted with transforms), Flash gives the user much more granular control over their data processing flow.
Here are the primary advantages:
Making inference on raw data simple
Making the code more readable, modular and self-contained
Making data augmentation experimentation simpler
To change the processing behavior only on specific stages for a given hook, you can prefix each of the Preprocess and Postprocess hooks with train, val, test or predict.
Check out Preprocess for some examples.
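The stage-prefix idea can be sketched with plain Python. The class and the `resolve_hook` helper below are hypothetical and only illustrate one way a prefixed hook could take precedence over the unprefixed one; Flash's own lookup logic may differ.

```python
from typing import Any, Callable


class StagedHooks:
    """Toy stand-in for a processing class with stage-prefixed hooks."""

    def to_tensor_transform(self, sample: Any) -> Any:
        # Default hook, used when no stage-specific variant exists.
        return ("default", sample)

    def train_to_tensor_transform(self, sample: Any) -> Any:
        # Stage-specific hook, only used while training.
        return ("train", sample)


def resolve_hook(obj: Any, hook_name: str, stage: str) -> Callable:
    """Prefer `<stage>_<hook_name>` and fall back to `<hook_name>`."""
    prefixed = getattr(obj, f"{stage}_{hook_name}", None)
    return prefixed if callable(prefixed) else getattr(obj, hook_name)


hooks = StagedHooks()
print(resolve_hook(hooks, "to_tensor_transform", "train")(1))
print(resolve_hook(hooks, "to_tensor_transform", "val")(1))
```

In this sketch the training stage uses its own transform, while validation falls back to the shared default.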
How to customize existing DataModules¶
Any Flash DataModule can be created directly from datasets using the from_datasets() method like this:
from flash import DataModule, Trainer
data_module = DataModule.from_datasets(train_dataset=MyDataset())
trainer = Trainer()
trainer.fit(model, datamodule=data_module)
The DataModule provides additional classmethod helpers (from_*) for loading data from various sources.
In each from_* method, the DataModule internally retrieves the correct DataSource to use from the Preprocess.
Flash AutoDataset instances are created from the DataSource for train, val, test, and predict.
The DataModule populates the DataLoader for each stage with the corresponding AutoDataset.
Customize preprocessing of DataModules¶
The Preprocess contains the processing logic related to a given task.
Each Preprocess provides some default transforms through the default_transforms() method.
Users can easily override these by providing their own transforms to the DataModule.
Here’s an example:
from flash.core.data.transforms import ApplyToKeys
from flash.image import ImageClassificationData, ImageClassifier
transform = {"to_tensor_transform": ApplyToKeys("input", my_to_tensor_transform)}
datamodule = ImageClassificationData.from_folders(
train_folder="data/hymenoptera_data/train/",
val_folder="data/hymenoptera_data/val/",
test_folder="data/hymenoptera_data/test/",
train_transform=transform,
val_transform=transform,
test_transform=transform,
)
Alternatively, the user may directly override the hooks for their needs like this:
from typing import Any, Dict
from flash.image import ImageClassificationData, ImageClassifier, ImageClassificationPreprocess
class CustomImageClassificationPreprocess(ImageClassificationPreprocess):
    def to_tensor_transform(self, sample: Dict[str, Any]) -> Dict[str, Any]:
        sample["input"] = my_to_tensor_transform(sample["input"])
        return sample
datamodule = ImageClassificationData.from_folders(
train_folder="data/hymenoptera_data/train/",
val_folder="data/hymenoptera_data/val/",
test_folder="data/hymenoptera_data/test/",
preprocess=CustomImageClassificationPreprocess(),
)
Create your own Preprocess and DataModule¶
The example below shows a very simple ImageClassificationPreprocess with a single ImageClassificationFoldersDataSource and an ImageClassificationDataModule.
1. User-Facing API design¶
Designing an easy-to-use API is key. This is the first and most important step.
We want the ImageClassificationDataModule to generate a dataset from folders of images arranged in this way.
Example:
train/dog/xxx.png
train/dog/xxy.png
train/dog/xxz.png
train/cat/123.png
train/cat/nsdf3.png
train/cat/asd932.png
Example:
dm = ImageClassificationDataModule.from_folders(
train_folder="./data/train",
val_folder="./data/val",
test_folder="./data/test",
predict_folder="./data/predict",
)
model = ImageClassifier(...)
trainer = Trainer(...)
trainer.fit(model, dm)
2. The DataSource¶
We start by implementing the ImageClassificationFoldersDataSource.
The load_data method will produce a list of files and targets from the given directory.
The load_sample method will load the given file as a PIL.Image.
Here’s the full ImageClassificationFoldersDataSource:
import os
from typing import Any, Dict, Iterable

import numpy as np
from PIL import Image
from torchvision.datasets.folder import IMG_EXTENSIONS, make_dataset

from flash.core.data.data_source import DataSource, DefaultDataKeys


class ImageClassificationFoldersDataSource(DataSource):
    def load_data(self, folder: str, dataset: Any) -> Iterable:
        # The dataset is optional but can be useful to save some metadata.
        # `metadata` contains the image path and its corresponding label
        # with the following structure:
        # [(image_path_1, label_1), ... (image_path_n, label_n)].
        # Recent torchvision releases infer the classes from the sub-folder names.
        metadata = make_dataset(folder, extensions=IMG_EXTENSIONS)
        # for the train `AutoDataset`, we want to store the `num_classes`.
        if self.training:
            dataset.num_classes = len(np.unique([m[1] for m in metadata]))
        return [
            {
                DefaultDataKeys.INPUT: file,
                DefaultDataKeys.TARGET: target,
            }
            for file, target in metadata
        ]

    def predict_load_data(self, predict_folder: str) -> Iterable:
        # This returns [image_path_1, ... image_path_m].
        return [
            {DefaultDataKeys.INPUT: os.path.join(predict_folder, file)}
            for file in os.listdir(predict_folder)
        ]

    def load_sample(self, sample: Dict[str, Any]) -> Dict[str, Any]:
        sample[DefaultDataKeys.INPUT] = Image.open(sample[DefaultDataKeys.INPUT])
        return sample
Note
We return samples as dictionaries using the DefaultDataKeys by convention. This is the recommended (although not required) way to represent data in Flash.
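The folder-to-samples step can be tried without Flash or torchvision. The following is only a minimal sketch: `load_folder_samples` is a hypothetical helper, and the string keys "input" and "target" are assumed stand-ins for the DefaultDataKeys values.

```python
import tempfile
from pathlib import Path
from typing import Any, Dict, List

INPUT, TARGET = "input", "target"  # assumed stand-ins for the DefaultDataKeys values


def load_folder_samples(folder: str) -> List[Dict[str, Any]]:
    """Return one dict per file, labelling each file by its class sub-folder."""
    root = Path(folder)
    # Sort the class names so the label indices are deterministic.
    classes = sorted(p.name for p in root.iterdir() if p.is_dir())
    class_to_idx = {name: idx for idx, name in enumerate(classes)}
    return [
        {INPUT: str(path), TARGET: class_to_idx[name]}
        for name in classes
        for path in sorted((root / name).iterdir())
        if path.is_file()
    ]


# Build a tiny throwaway dataset that follows the train/<class>/<file> layout.
with tempfile.TemporaryDirectory() as tmp:
    for cls, files in {"cat": ["123.png"], "dog": ["xxx.png", "xxy.png"]}.items():
        (Path(tmp) / cls).mkdir()
        for name in files:
            (Path(tmp) / cls / name).touch()
    samples = load_folder_samples(tmp)

targets = [sample[TARGET] for sample in samples]
num_classes = len(set(targets))
print(targets, num_classes)
```

Each sample only stores a path and a label at this point; opening the image is deferred to the per-sample loading step.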
3. The Preprocess¶
Next, implement your custom ImageClassificationPreprocess with some default transforms and a reference to the data source:
from typing import Any, Callable, Dict, Optional
import torchvision.transforms.functional as T
from flash.core.data.data_source import DefaultDataKeys, DefaultDataSources
from flash.core.data.process import Preprocess
from flash.core.data.transforms import ApplyToKeys
# Subclass `Preprocess`
class ImageClassificationPreprocess(Preprocess):
    def __init__(
        self,
        train_transform: Optional[Dict[str, Callable]] = None,
        val_transform: Optional[Dict[str, Callable]] = None,
        test_transform: Optional[Dict[str, Callable]] = None,
        predict_transform: Optional[Dict[str, Callable]] = None,
    ):
        super().__init__(
            train_transform=train_transform,
            val_transform=val_transform,
            test_transform=test_transform,
            predict_transform=predict_transform,
            data_sources={
                DefaultDataSources.FOLDERS: ImageClassificationFoldersDataSource(),
            },
            default_data_source=DefaultDataSources.FOLDERS,
        )
    def get_state_dict(self) -> Dict[str, Any]:
        return {**self.transforms}
    @classmethod
    def load_state_dict(cls, state_dict: Dict[str, Any], strict: bool = False):
        return cls(**state_dict)
    def default_transforms(self) -> Dict[str, Callable]:
        # Convert the PIL image stored under the input key to a tensor.
        return {"to_tensor_transform": ApplyToKeys(DefaultDataKeys.INPUT, T.to_tensor)}
4. The DataModule¶
Finally, let’s implement the ImageClassificationDataModule.
We get the from_folders classmethod for free as we’ve registered a DefaultDataSources.FOLDERS data source in our ImageClassificationPreprocess.
All we need to do is attach our Preprocess class like this:
from flash import DataModule
class ImageClassificationDataModule(DataModule):
    # Set `preprocess_cls` with your custom `Preprocess`.
    preprocess_cls = ImageClassificationPreprocess
How it works behind the scenes¶
DataSource¶
Note
The load_data() and load_sample() methods will be used to generate an AutoDataset object.
Here is the AutoDataset pseudo-code.
class AutoDataset:
    def __init__(
        self,
        data: List[Any],  # output of `DataSource.load_data`
        data_source: DataSource,
        running_stage: RunningStage,
    ):
        self.data = data
        self.data_source = data_source
    def __getitem__(self, index: int):
        return self.data_source.load_sample(self.data[index])
    def __len__(self):
        return len(self.data)
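A runnable version of the same idea, as a sketch only: `ToyDataSource` and `LazyDataset` are hypothetical names, and strings stand in for image files. It shows that the metadata is stored up front while each sample is only loaded when it is indexed.

```python
from typing import Any, Dict, List


class ToyDataSource:
    """Hypothetical data source: `load_data` is cheap, `load_sample` does the work."""

    def load_data(self, words: List[str]) -> List[Dict[str, Any]]:
        return [{"input": word} for word in words]

    def load_sample(self, sample: Dict[str, Any]) -> Dict[str, Any]:
        # Return a new dict so that repeated indexing stays side-effect free.
        return {**sample, "input": sample["input"].upper()}


class LazyDataset:
    """Mirrors the pseudo-code: keep the metadata, load a sample on access."""

    def __init__(self, data: List[Dict[str, Any]], data_source: ToyDataSource):
        self.data = data
        self.data_source = data_source

    def __getitem__(self, index: int) -> Dict[str, Any]:
        return self.data_source.load_sample(self.data[index])

    def __len__(self) -> int:
        return len(self.data)


source = ToyDataSource()
dataset = LazyDataset(source.load_data(["ant", "bee"]), source)
print(len(dataset), dataset[1])
```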
Preprocess¶
Note
The pre_tensor_transform(), to_tensor_transform(), post_tensor_transform(), collate(), and per_batch_transform() hooks are injected as the torch.utils.data.DataLoader.collate_fn function of the DataLoader.
Here is the pseudo-code using the preprocess hook names. Flash takes care of calling the right hooks for each stage.
Example:
# This will be wrapped into a :class:`~flash.core.data.batch._Preprocessor`.
def collate_fn(samples: Sequence[Any]) -> Any:
    # This will be wrapped into a :class:`~flash.core.data.batch._Sequential`
    processed = []
    for sample in samples:
        sample = pre_tensor_transform(sample)
        sample = to_tensor_transform(sample)
        sample = post_tensor_transform(sample)
        processed.append(sample)
    samples = type(samples)(processed)
    # if the :func:`flash.core.data.process.Preprocess.per_sample_transform_on_device` hook is overridden,
    # the functions below will be no-ops
    samples = collate(samples)
    samples = per_batch_transform(samples)
    return samples
dataloader = DataLoader(dataset, collate_fn=collate_fn)
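To see the order in which the hooks fire, here is a small self-contained sketch. The hooks are hypothetical identity functions that only log their name, and the toy `collate` gathers dictionaries into lists instead of stacking tensors.

```python
from typing import Any, Callable, Dict, List, Sequence

calls: List[str] = []  # records the order in which the hooks run


def make_hook(name: str) -> Callable[[Any], Any]:
    """Build an identity hook that logs its own name when called."""

    def hook(value: Any) -> Any:
        calls.append(name)
        return value

    return hook


pre_tensor_transform = make_hook("pre_tensor_transform")
to_tensor_transform = make_hook("to_tensor_transform")
post_tensor_transform = make_hook("post_tensor_transform")
per_batch_transform = make_hook("per_batch_transform")


def collate(samples: Sequence[Dict[str, Any]]) -> Dict[str, List[Any]]:
    """Toy collate: gather each key of the samples into a list."""
    calls.append("collate")
    return {key: [sample[key] for sample in samples] for key in samples[0]}


def collate_fn(samples: Sequence[Dict[str, Any]]) -> Dict[str, List[Any]]:
    processed = []
    for sample in samples:
        # The three per-sample hooks run for every sample first...
        sample = pre_tensor_transform(sample)
        sample = to_tensor_transform(sample)
        sample = post_tensor_transform(sample)
        processed.append(sample)
    # ...then the samples are collated and the batch-level hook runs once.
    return per_batch_transform(collate(processed))


batch = collate_fn([{"input": 1, "target": 0}, {"input": 2, "target": 1}])
print(batch)
print(calls)
```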
Note
The per_sample_transform_on_device, collate, and per_batch_transform_on_device hooks are injected after the LightningModule transfer_batch_to_device hook.
Here is the pseudo-code using the preprocess hook names. Flash takes care of calling the right hooks for each stage.
Example:
# This will be wrapped into a :class:`~flash.core.data.batch._Preprocessor`
def collate_fn(samples: Sequence[Any]) -> Any:
    # if the ``per_batch_transform`` hook is overridden, the functions below will be no-ops
    samples = [per_sample_transform_on_device(sample) for sample in samples]
    samples = type(samples)(samples)
    samples = collate(samples)
    samples = per_batch_transform_on_device(samples)
    return samples
# move the data to device
data = lightning_module.transfer_batch_to_device(data)
data = collate_fn(data)
predictions = lightning_module(data)
Postprocess and Serializer¶
Once the predictions have been generated by the Flash Task, the Flash DataPipeline will execute the Postprocess hooks and the Serializer behind the scenes.
First, the per_batch_transform() hooks will be applied on the batch predictions.
Then, the uncollate() will split the batch into individual predictions.
Next, the per_sample_transform() will be applied on each prediction.
Finally, the serialize() method will be called to serialize the predictions.
Note
The transform can be applied either on device or CPU.
Here is the pseudo-code:
Example:
# This will be wrapped into a :class:`~flash.core.data.batch._Postprocessor`
def uncollate_fn(batch: Any) -> Any:
    batch = per_batch_transform(batch)
    samples = uncollate(batch)
    samples = [per_sample_transform(sample) for sample in samples]
    # only if serializers are enabled.
    return [serialize(sample) for sample in samples]
predictions = lightning_module(data)
outputs = uncollate_fn(predictions)
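The same four steps can be run end to end on plain Python data. This is a sketch under assumptions: the class names in `LABELS`, the "scores" key, and the argmax-based `per_sample_transform` are illustrative choices, not Flash's built-in behaviour.

```python
from typing import Any, Dict, List

LABELS = ["ants", "bees"]  # hypothetical class names


def per_batch_transform(batch: Dict[str, Any]) -> Dict[str, Any]:
    return batch  # identity in this sketch


def uncollate(batch: Dict[str, List[Any]]) -> List[Dict[str, Any]]:
    """Split a dict of lists into a list of per-sample dicts."""
    keys = list(batch)
    return [dict(zip(keys, values)) for values in zip(*(batch[key] for key in keys))]


def per_sample_transform(sample: Dict[str, Any]) -> int:
    """Reduce the class scores of one sample to the index of the best class."""
    scores = sample["scores"]
    return max(range(len(scores)), key=scores.__getitem__)


def serialize(class_index: int) -> str:
    return LABELS[class_index]


def uncollate_fn(batch: Dict[str, Any]) -> List[str]:
    batch = per_batch_transform(batch)
    samples = uncollate(batch)
    samples = [per_sample_transform(sample) for sample in samples]
    return [serialize(sample) for sample in samples]


predictions = {"scores": [[0.1, 0.9], [0.8, 0.2]]}
outputs = uncollate_fn(predictions)
print(outputs)
```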
Registry¶
Available Registries¶
Registries are Flash's internal key-value databases that store a mapping between a name and a function.
In simple words, they are advanced dictionaries that store a function under a string key.
Registries help organize code and make the functions accessible across the Flash codebase.
Each Flash Task can have several registries as static attributes.
Currently, Flash uses registries internally only for backbones, but more components will be added.
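To illustrate the "advanced dictionary" idea, here is a minimal dictionary-backed sketch. `MiniRegistry` is a hypothetical class written for this illustration; it is not the FlashRegistry implementation and only mimics registering a function directly or through a decorator.

```python
from functools import partial
from typing import Callable, Dict, List, Optional


class MiniRegistry:
    """Hypothetical name -> function store (not the real FlashRegistry)."""

    def __init__(self, name: str):
        self.name = name
        self._functions: Dict[str, Callable] = {}

    def __call__(self, fn: Optional[Callable] = None, name: Optional[str] = None):
        if fn is not None:
            # Direct registration; `partial` objects have no __name__, so pass `name`.
            self._functions[name or fn.__name__] = fn
            return fn

        def decorator(func: Callable) -> Callable:
            self._functions[name or func.__name__] = func
            return func

        return decorator

    def get(self, key: str) -> Callable:
        return self._functions[key]

    def available_keys(self) -> List[str]:
        return sorted(self._functions)


backbones = MiniRegistry("backbones")


def build(width: int, pretrained: bool = True):
    return f"net-{width}", width


# Direct registration, freezing one argument with `partial`.
backbones(fn=partial(build, 64), name="user/partial")


# Registration through the decorator form.
@backbones(name="user/decorated")
def small(pretrained: bool = True):
    return "net-small", 16


print(backbones.available_keys())
print(backbones.get("user/partial")(pretrained=False))
```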
1. Imports¶
from functools import partial
from flash import Task
from flash.core.registry import FlashRegistry
2. Init a Registry¶
It is good practice to associate one or more registries with a Task as follows:
# creating a custom `Task` with its own registry
class MyImageClassifier(Task):
    backbones = FlashRegistry("backbones")
    def __init__(
        self,
        backbone: str = "resnet18",
        pretrained: bool = True,
    ):
        ...
        self.backbone, self.num_features = self.backbones.get(backbone)(pretrained=pretrained)
3. Adding new functions¶
Your custom functions can be registered within a FlashRegistry either directly or with a decorator.
# Option 1: Used with partial.
def fn(backbone: str, pretrained: bool = True):
    # Create backbone and backbone output dimension (`num_features`)
    backbone, num_features = None, None
    return backbone, num_features
# HINT 1: Use `from functools import partial` if you want to store some arguments.
MyImageClassifier.backbones(fn=partial(fn, backbone="my_backbone"), name="username/partial_backbone")
# Option 2: Using decorator.
@MyImageClassifier.backbones(name="username/decorated_backbone")
def fn(pretrained: bool = True):
    # Create backbone and backbone output dimension (`num_features`)
    backbone, num_features = None, None
    return backbone, num_features
4. Accessing registered functions¶
You can now access your function from your task!
# 3.b Optional: List available backbones
print(MyImageClassifier.available_backbones())
# 4. Build the model
model = MyImageClassifier(backbone="username/decorated_backbone")
Here’s the output:
['username/decorated_backbone', 'username/partial_backbone']
5. Pre-registered backbones¶
Flash provides populated registries containing lots of available backbones.
Example:
from flash.image.backbones import OBJ_DETECTION_BACKBONES
from flash.image.classification.backbones import IMAGE_CLASSIFIER_BACKBONES
print(IMAGE_CLASSIFIER_BACKBONES.available_keys())
""" out:
['adv_inception_v3', 'cspdarknet53', 'cspdarknet53_iabn', 430+.., 'xception71']
"""
Flash Serve¶
Flash Serve is a library to easily serve models in production.
Terminology¶
Here are common terms you need to be familiar with:
Term | Definition |
---|---|
de-serialization | Transform data encoded as text into tensors. |
inference function | A function that takes the decoded tensors and forwards them through the model to produce predictions. |
serialization | Transform the prediction tensors back to a text encoding. |
Example¶
In this tutorial, we will serve a ResNet18 from the torchvision library in 3 steps.
The entire tutorial can be found under flash_examples/serve/generic.
Introduction¶
Traditionally, an inference pipeline is made up of 3 steps:
de-serialization: Transform data encoded as text into tensors.
inference function: A function that takes the decoded tensors and forwards them through the model to produce predictions.
serialization: Transform the prediction tensors back to text.
In this example, we will implement only the inference function, as Flash Serve already provides built-in de-serialization and serialization functions through its Image type.
Step 1 - Create a ModelComponent¶
Inside inference_server.py, we will implement a ClassificationInference class, which subclasses ModelComponent.
First, we need to make the following imports:
import torch
import torchvision
from flash.core.serve import Composition, Servable, ModelComponent, expose
from flash.core.serve.types import Image, Label
To implement ClassificationInference, we need to write a method that performs the inference function and decorate it with the expose() function.
The name of the inference method isn’t constrained, but we will use classify, as it suits this example.
Our classify method will take a tensor image, apply some normalization to it, and forward it through the model.
def classify(self, img):
    img = img.float() / 255
    mean = torch.tensor([[[0.485, 0.456, 0.406]]]).float()
    std = torch.tensor([[[0.229, 0.224, 0.225]]]).float()
    img = (img - mean) / std
    img = img.permute(0, 3, 2, 1)
    out = self.model(img)
    return out.argmax()
The expose() function is a Python decorator that extends the decorated function with the de-serialization and serialization steps.
Note
Flash Serve was designed this way to enable several models to be chained together by removing the decorator.
The expose() function takes 2 arguments:
inputs: Dictionary mapping the decorated function inputs to BaseType objects.
outputs: Dictionary mapping the decorated function outputs to BaseType objects.
A BaseType is a Python dataclass which implements a serialize and a deserialize function.
Note
Flash Serve already has several built-in BaseType classes, such as Image or Text.
class ClassificationInference(ModelComponent):
    def __init__(self, model: Servable):
        self.model = model
    @expose(
        inputs={"img": Image()},
        outputs={"prediction": Label(path="imagenet_labels.txt")},
    )
    def classify(self, img):
        img = img.float() / 255
        mean = torch.tensor([[[0.485, 0.456, 0.406]]]).float()
        std = torch.tensor([[[0.229, 0.224, 0.225]]]).float()
        img = (img - mean) / std
        img = img.permute(0, 3, 2, 1)
        out = self.model(img)
        return out.argmax()
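The mechanics of a serialize/deserialize type combined with a wrapping decorator can be sketched in plain Python. Everything here is hypothetical and written for illustration only: `TextBytes` and `expose_sketch` are not part of Flash Serve, and the real expose() may work differently internally.

```python
import base64
from dataclasses import dataclass
from functools import wraps
from typing import Any, Callable, Dict


@dataclass
class TextBytes:
    """Hypothetical BaseType-like class: bytes <-> base64 text."""

    encoding: str = "utf-8"

    def deserialize(self, data: str) -> bytes:
        return base64.b64decode(data.encode(self.encoding))

    def serialize(self, value: bytes) -> str:
        return base64.b64encode(value).decode(self.encoding)


def expose_sketch(inputs: Dict[str, Any], outputs: Dict[str, Any]) -> Callable:
    """Wrap a function so that it accepts and returns text-encoded payloads."""

    def decorator(fn: Callable) -> Callable:
        @wraps(fn)
        def wrapper(payload: Dict[str, str]) -> Dict[str, str]:
            # De-serialize every declared input before calling the function.
            kwargs = {name: kind.deserialize(payload[name]) for name, kind in inputs.items()}
            result = fn(**kwargs)
            # This sketch supports exactly one declared output.
            ((out_name, out_kind),) = outputs.items()
            return {out_name: out_kind.serialize(result)}

        return wrapper

    return decorator


@expose_sketch(inputs={"blob": TextBytes()}, outputs={"reversed": TextBytes()})
def reverse(blob: bytes) -> bytes:
    return blob[::-1]


request = {"blob": base64.b64encode(b"flash").decode("utf-8")}
response = reverse(request)
decoded = base64.b64decode(response["reversed"])
print(decoded)
```

The undecorated function works on decoded values only, which is what lets several such functions be chained without repeated encoding and decoding.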
Step 2 - Create a scripted Model¶
Using the torchvision library, we create a resnet18 and use torch.jit.script to script the model.
Note
TorchScript is a way to create serializable and optimizable models from PyTorch code. Any TorchScript program can be saved from a Python process and loaded in a process where there is no Python dependency.
model = torchvision.models.resnet18(pretrained=True).eval()
torch.jit.script(model).save("resnet.pt")
Step 3 - Serve the model¶
The Servable takes the path to the TorchScripted model as its argument and is then passed to our ClassificationInference class.
The ClassificationInference instance is passed as an argument to a Composition class.
Once the Composition class is instantiated, just call its serve() method.
resnet = Servable("resnet.pt")
comp = ClassificationInference(resnet)
composition = Composition(classification=comp)
composition.serve()
Launching the server.¶
In Terminal 1¶
Just run:
python inference_server.py
The server's start-up output should then appear in your terminal.
You should also see a Swagger UI already built for you at http://127.0.0.1:8000/docs
In Terminal 2¶
Run this script from another terminal:
import base64
from pathlib import Path
import requests
with Path("fish.jpg").open("rb") as f:
imgstr = base64.b64encode(f.read()).decode("UTF-8")
body = {"session": "UUID", "payload": {"img": {"data": imgstr}}}
resp = requests.post("http://127.0.0.1:8000/predict", json=body)
print(resp.json())
# {'session': 'UUID', 'result': {'prediction': 'goldfish, Carassius auratus'}}
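The request body used above is plain JSON with the image bytes base64-encoded as text. The sketch below builds such a body and checks that the bytes survive the round trip, without contacting a server; `build_body` is a hypothetical helper and the byte string is a stand-in for a real image file.

```python
import base64
import json


def build_body(raw: bytes, input_name: str = "img", session: str = "UUID") -> dict:
    """Encode raw bytes as base64 text and nest them under the input name."""
    encoded = base64.b64encode(raw).decode("utf-8")
    return {"session": session, "payload": {input_name: {"data": encoded}}}


raw = bytes(range(256))  # stand-in for the contents of an image file
body = build_body(raw)

# JSON can only carry text, which is why the bytes are base64-encoded first.
wire = json.dumps(body)
received = json.loads(wire)
restored = base64.b64decode(received["payload"]["img"]["data"])
print(restored == raw)
```

The key under "payload" ("img" here) has to match the input name declared in the expose() decorator.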
Credits to @rlizzo, @hhsecond, @lantiga, @luiscape for building Flash Serve Engine.
Backbones and Heads¶
Backbones are the pre-trained models that can be used with a task.
The backbones or heads that are available can be found by using the available_backbones and available_heads methods.
To get the available backbones for a task like ImageClassifier, run:
from flash.image import ImageClassifier
# get the backbones available for ImageClassifier
backbones = ImageClassifier.available_backbones()
# print the backbones
print(backbones)
To get the available heads for a task like SemanticSegmentation, run:
from flash.image import SemanticSegmentation
# get the heads available for SemanticSegmentation
heads = SemanticSegmentation.available_heads()
# print the heads
print(heads)
Optimization (Optimizers and Schedulers)¶
Using optimizers and learning rate schedulers with Flash has become easier and cleaner than ever.
With the use of Registry, an optimizer or a learning rate scheduler can be instantiated from just a string.
Setting an optimizer to a task¶
Each task has a built-in method available_optimizers() which will list all the optimizers registered with Flash.
>>> from flash.core.classification import ClassificationTask
>>> ClassificationTask.available_optimizers()
['adadelta', ..., 'sgd']
To train / finetune a Task of your choice, just pass the optimizer name as a string.
from flash.image import ImageClassifier
model = ImageClassifier(num_classes=10, backbone="resnet18", optimizer="Adam", learning_rate=1e-4)
To customize specific parameters of the optimizer, pass a tuple containing the optimizer name and a dictionary of keyword arguments.
from flash.image import ImageClassifier
model = ImageClassifier(num_classes=10, backbone="resnet18", optimizer=("Adam", {"amsgrad": True}), learning_rate=1e-4)
An alternative to customizing an optimizer using a tuple is to pass it as a callable.
from functools import partial
from torch.optim import Adam
from flash.image import ImageClassifier
model = ImageClassifier(num_classes=10, backbone="resnet18", optimizer=partial(Adam, amsgrad=True), learning_rate=1e-4)
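The three accepted forms (string, tuple, callable) could be resolved along the following lines. This is a sketch under assumptions, not Flash's implementation: `FakeAdam` is a stand-in class, `OPTIMIZERS` is a hypothetical registry, and the case-insensitive lookup is assumed.

```python
from functools import partial
from typing import Any, Callable, Dict, Iterable, Tuple, Union


class FakeAdam:
    """Stand-in for an optimizer class; it only stores its arguments."""

    def __init__(self, params: Iterable[Any], lr: float, amsgrad: bool = False):
        self.params = list(params)
        self.lr = lr
        self.amsgrad = amsgrad


OPTIMIZERS: Dict[str, Callable[..., Any]] = {"adam": FakeAdam}  # hypothetical registry

OptimizerSpec = Union[str, Tuple[str, Dict[str, Any]], Callable[..., Any]]


def resolve_optimizer(spec: OptimizerSpec, params: Iterable[Any], learning_rate: float) -> Any:
    if isinstance(spec, str):
        # Form 1: a registered name.
        return OPTIMIZERS[spec.lower()](params, lr=learning_rate)
    if isinstance(spec, tuple):
        # Form 2: a registered name plus extra keyword arguments.
        name, kwargs = spec
        return OPTIMIZERS[name.lower()](params, lr=learning_rate, **kwargs)
    if callable(spec):
        # Form 3: a callable, for example one built with `functools.partial`.
        return spec(params, lr=learning_rate)
    raise TypeError(f"Unsupported optimizer spec: {spec!r}")


weights = [0.0, 1.0]
by_name = resolve_optimizer("Adam", weights, 1e-4)
by_tuple = resolve_optimizer(("Adam", {"amsgrad": True}), weights, 1e-4)
by_callable = resolve_optimizer(partial(FakeAdam, amsgrad=True), weights, 1e-4)
print(by_name.amsgrad, by_tuple.amsgrad, by_callable.amsgrad)
```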
Setting a Learning Rate Scheduler¶
Each task has a built-in method available_lr_schedulers() which will list all the learning rate schedulers registered with Flash.
>>> from flash.core.classification import ClassificationTask
>>> ClassificationTask.available_lr_schedulers()
['lambdalr', ..., 'cosineannealingwarmrestarts']
To train / finetune a Task of your choice, just pass the scheduler name as a string.
from flash.image import ImageClassifier
model = ImageClassifier(
num_classes=10, backbone="resnet18", optimizer="Adam", learning_rate=1e-4, lr_scheduler="constant_schedule"
)
Note
"constant_schedule" and a few other LR schedulers will be available only if you have installed the transformers library from Hugging Face.
To customize specific parameters of the LR scheduler, pass a tuple containing the scheduler name and a dictionary of keyword arguments.
from flash.image import ImageClassifier
model = ImageClassifier(
num_classes=10,
backbone="resnet18",
optimizer="Adam",
learning_rate=1e-4,
lr_scheduler=("StepLR", {"step_size": 10}),
)
An alternative to customizing the LR Scheduler using a tuple is to pass it as a callable.
from functools import partial
from torch.optim.lr_scheduler import CyclicLR
from flash.image import ImageClassifier
model = ImageClassifier(
num_classes=10,
backbone="resnet18",
optimizer="Adam",
learning_rate=1e-4,
lr_scheduler=partial(CyclicLR, step_size_up=1500, mode="exp_range", gamma=0.5),
)
Additionally, the lr_scheduler parameter accepts a Lightning Scheduler configuration, which can be passed as an extra element of the tuple.
The Lightning Scheduler configuration is a dictionary which contains the scheduler and its associated configuration. The default configuration is shown below.
lr_scheduler_config = {
    # REQUIRED: The scheduler instance
    "scheduler": lr_scheduler,
    # The unit of the scheduler's step size, could also be 'step'.
    # 'epoch' updates the scheduler on epoch end whereas 'step'
    # updates it after an optimizer update.
    "interval": "epoch",
    # How many epochs/steps should pass between calls to
    # `scheduler.step()`. 1 corresponds to updating the learning
    # rate after every epoch/step.
    "frequency": 1,
    # Metric to monitor for schedulers like `ReduceLROnPlateau`
    "monitor": "val_loss",
    # If set to `True`, will enforce that the value specified in 'monitor'
    # is available when the scheduler is updated, thus stopping
    # training if not found. If set to `False`, it will only produce a warning
    "strict": True,
    # If using the `LearningRateMonitor` callback to monitor the
    # learning rate progress, this keyword can be used to specify
    # a custom logged name
    "name": None,
}
For schedulers whose .step() method is conditioned on a value, such as torch.optim.lr_scheduler.ReduceLROnPlateau, Flash requires that the Lightning Scheduler configuration contains the keyword "monitor" set to the name of the metric that the scheduler should be conditioned on.
Below is an example:
from flash.image import ImageClassifier
model = ImageClassifier(
num_classes=10,
backbone="resnet18",
optimizer="Adam",
learning_rate=1e-4,
lr_scheduler=("reducelronplateau", {"mode": "max"}, {"monitor": "val_accuracy"}),
)
Note
Do not set the "scheduler" key in the Lightning Scheduler configuration; it will be overridden with an instance of the provided scheduler.
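One way to picture how the three-element tuple turns into a full configuration is the sketch below. It is an assumption-laden illustration rather than Flash's code: `build_scheduler_config` and `fake_factory` are hypothetical, and the defaults are copied from the configuration shown above.

```python
from typing import Any, Callable, Dict, Tuple

# Defaults copied from the configuration shown in the text.
DEFAULT_CONFIG: Dict[str, Any] = {
    "interval": "epoch",
    "frequency": 1,
    "monitor": "val_loss",
    "strict": True,
    "name": None,
}


def build_scheduler_config(
    spec: Tuple[str, Dict[str, Any], Dict[str, Any]],
    scheduler_factory: Callable[..., Any],
) -> Dict[str, Any]:
    """Merge user overrides into the defaults, then attach the scheduler."""
    name, scheduler_kwargs, config_overrides = spec
    config = {**DEFAULT_CONFIG, **config_overrides}
    # The "scheduler" key is always set last, so a user-provided value is replaced.
    config["scheduler"] = scheduler_factory(name, **scheduler_kwargs)
    return config


def fake_factory(name: str, **kwargs: Any) -> Tuple[str, Dict[str, Any]]:
    """Stand-in for building a real scheduler instance from its name."""
    return (name, kwargs)


spec = ("reducelronplateau", {"mode": "max"}, {"monitor": "val_accuracy", "scheduler": "ignored"})
config = build_scheduler_config(spec, fake_factory)
print(config["monitor"], config["scheduler"])
```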
Pre-Registering optimizers and scheduler recipes¶
The Flash registry also provides the flexibility of registering functions. This feature is also available in the Optimizer and Scheduler registries.
Using the optimizers and lr_schedulers decorators of each Task, custom optimizer and LR scheduler recipes can be pre-registered.
import torch
from flash.image import ImageClassifier
@ImageClassifier.lr_schedulers
def my_flash_steplr_recipe(optimizer):
    return torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)
model = ImageClassifier(backbone="resnet18", num_classes=2, optimizer="Adam", lr_scheduler="my_flash_steplr_recipe")
Provider specific requirements¶
Schedulers¶
Certain LR Schedulers provided by Hugging Face require both num_training_steps and num_warmup_steps.
To use them in Flash, provide num_warmup_steps as a float between 0 and 1, which indicates the fraction of the training steps that will be used as warmup steps. Flash’s Trainer will take care of computing the number of training steps and the number of warmup steps based on the flags that are set in the Trainer.
from flash.image import ImageClassifier
model = ImageClassifier(
backbone="resnet18",
num_classes=2,
optimizer="Adam",
lr_scheduler=("cosine_schedule_with_warmup", {"num_warmup_steps": 0.1}),
)
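As a rough illustration of the arithmetic, the sketch below turns a warmup fraction into step counts. The formula is an assumption made for this example (it ignores gradient accumulation and multiple devices, and truncates the result); the exact computation inside Flash's Trainer may differ.

```python
import math
from typing import Tuple


def warmup_steps(num_samples: int, batch_size: int, max_epochs: int, warmup_fraction: float) -> Tuple[int, int]:
    """Return (num_training_steps, num_warmup_steps) for a simple training run."""
    if not 0.0 <= warmup_fraction <= 1.0:
        raise ValueError("warmup_fraction must be between 0 and 1")
    # One optimizer step per batch; the last batch of an epoch may be partial.
    steps_per_epoch = math.ceil(num_samples / batch_size)
    num_training_steps = steps_per_epoch * max_epochs
    return num_training_steps, int(num_training_steps * warmup_fraction)


total, warmup = warmup_steps(num_samples=1000, batch_size=32, max_epochs=3, warmup_fraction=0.1)
print(total, warmup)
```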
Image Classification¶
The Task¶
The task of identifying what is in an image is called image classification. Typically, Image Classification is used to identify images containing a single object. The task predicts which ‘class’ the image most likely belongs to with a degree of certainty. A class is a label that describes what is in an image, such as ‘car’, ‘house’, ‘cat’ etc.
Example¶
Let’s look at the task of predicting whether images contain Ants or Bees using the hymenoptera dataset.
The dataset contains train and val folders, and each folder contains a bees folder, with pictures of bees, and an ants folder with images of, you guessed it, ants.
hymenoptera_data
├── train
│   ├── ants
│   │   ├── 0013035.jpg
│   │   ├── 1030023514_aad5c608f9.jpg
│   │   ...
│   └── bees
│       ├── 1092977343_cb42b38d62.jpg
│       ├── 1093831624_fb5fbe2308.jpg
│       ...
└── val
    ├── ants
    │   ├── 10308379_1b6c72e180.jpg
    │   ├── 1053149811_f62a3410d3.jpg
    │   ...
    └── bees
        ├── 1032546534_06907fe3b3.jpg
        ├── 10870992_eebeeb3a12.jpg
        ...
Once we’ve downloaded the data using download_data(), we create the ImageClassificationData.
We select a pre-trained backbone to use for our ImageClassifier and fine-tune on the hymenoptera data.
We then use the trained ImageClassifier for inference.
Finally, we save the model.
Here’s the full example:
import torch
import flash
from flash.core.data.utils import download_data
from flash.image import ImageClassificationData, ImageClassifier
# 1. Create the DataModule
download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip", "./data")
datamodule = ImageClassificationData.from_folders(
train_folder="data/hymenoptera_data/train/",
val_folder="data/hymenoptera_data/val/",
)
# 2. Build the task
model = ImageClassifier(backbone="resnet18", num_classes=datamodule.num_classes)
# 3. Create the trainer and finetune the model
trainer = flash.Trainer(max_epochs=3, gpus=torch.cuda.device_count())
trainer.finetune(model, datamodule=datamodule, strategy="freeze")
# 4. Predict what's on a few images! ants or bees?
predictions = model.predict(
[
"data/hymenoptera_data/val/bees/65038344_52a45d090d.jpg",
"data/hymenoptera_data/val/bees/590318879_68cf112861.jpg",
"data/hymenoptera_data/val/ants/540543309_ddbb193ee5.jpg",
]
)
print(predictions)
# 5. Save the model!
trainer.save_checkpoint("image_classification_model.pt")
Flash Zero¶
The image classifier can be used directly from the command line with zero code using Flash Zero. You can run the hymenoptera example with:
flash image_classification
To view configuration options and options for running the image classifier with your own data, use:
flash image_classification --help
Loading Data¶
This section details the available ways to load your own data into the ImageClassificationData.
from_folders¶
Construct the ImageClassificationData from folders.
The supported file extensions are: .jpg, .jpeg, .png, .ppm, .bmp, .pgm, .tif, .tiff, .webp, .npy.
For train, test, and val data, the folders are expected to contain a sub-folder for each class. Here’s the required structure:
train_folder
├── class_1
│   ├── file1.jpg
│   ├── file2.jpg
│   ...
└── class_2
    ├── file1.jpg
    ├── file2.jpg
    ...
For prediction, the folder is expected to contain the files for inference, like this:
predict_folder
├── file1.jpg
├── file2.jpg
...
Example:
data_module = ImageClassificationData.from_folders(
train_folder = "./train_folder",
predict_folder = "./predict_folder",
...
)
from_files¶
Construct the ImageClassificationData from lists of files and corresponding lists of targets.
The supported file extensions are: .jpg, .jpeg, .png, .ppm, .bmp, .pgm, .tif, .tiff, .webp, .npy.
Example:
train_files = ["file1.jpg", "file2.jpg", "file3.jpg", ...]
train_targets = [0, 1, 0, ...]
datamodule = ImageClassificationData.from_files(
train_files = train_files,
train_targets = train_targets,
...
)
from_datasets¶
Construct the ImageClassificationData from the given datasets for each stage.
Example:
from torch.utils.data.dataset import Dataset
train_dataset: Dataset = ...
datamodule = ImageClassificationData.from_datasets(
train_dataset = train_dataset,
...
)
Note
The __getitem__ of your datasets should return a dictionary with "input" and "target" keys which map to the input image (as a PIL.Image) and the target (as an int or list of ints) respectively.
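If an existing dataset returns (image, target) tuples instead, a thin wrapper can adapt it to the dictionary format. This is a minimal sketch: `DictDataset` is a hypothetical class, and plain strings stand in for PIL images so that the example runs without extra dependencies.

```python
from typing import Any, Dict, Sequence, Tuple


class DictDataset:
    """Wrap a tuple-style dataset so that each item is an input/target dict."""

    def __init__(self, base: Sequence[Tuple[Any, Any]]):
        self.base = base

    def __len__(self) -> int:
        return len(self.base)

    def __getitem__(self, index: int) -> Dict[str, Any]:
        image, target = self.base[index]
        return {"input": image, "target": target}


# Strings stand in for PIL images in this sketch.
pairs = [("image_a", 0), ("image_b", 1)]
wrapped = DictDataset(pairs)
print(len(wrapped), wrapped[1])
```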
Custom Transformations¶
Flash automatically applies some default image transformations and augmentations, but you may wish to customize these for your own use case.
The base Preprocess defines 7 hooks for different stages in the data loading pipeline.
To apply image augmentations you can directly import the default_transforms from flash.image.classification.transforms and then merge your custom image transformations with them using the merge_transforms() helper function.
Here’s an example where we load the default transforms and merge with custom torchvision transformations.
We use the post_tensor_transform hook to apply the transformations after the image has been converted to a torch.Tensor.
from torchvision import transforms as T
import flash
from flash.core.data.data_source import DefaultDataKeys
from flash.core.data.transforms import ApplyToKeys, merge_transforms
from flash.image import ImageClassificationData, ImageClassifier
from flash.image.classification.transforms import default_transforms
post_tensor_transform = ApplyToKeys(
DefaultDataKeys.INPUT,
T.Compose([T.RandomHorizontalFlip(), T.ColorJitter(), T.RandomAutocontrast(), T.RandomPerspective()]),
)
new_transforms = merge_transforms(default_transforms((64, 64)), {"post_tensor_transform": post_tensor_transform})
datamodule = ImageClassificationData.from_folders(
train_folder="data/hymenoptera_data/train/", val_folder="data/hymenoptera_data/val/", train_transform=new_transforms
)
model = ImageClassifier(backbone="resnet18", num_classes=datamodule.num_classes)
trainer = flash.Trainer(max_epochs=1)
trainer.finetune(model, datamodule=datamodule, strategy="freeze")
Serving¶
The ImageClassifier is servable.
This means you can call .serve to serve your Task.
Here’s an example:
from flash.image import ImageClassifier
model = ImageClassifier.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/image_classification_model.pt")
model.serve()
You can now perform inference from your client like this:
import base64
from pathlib import Path
import requests
import flash
with (Path(flash.ASSETS_ROOT) / "fish.jpg").open("rb") as f:
imgstr = base64.b64encode(f.read()).decode("UTF-8")
body = {"session": "UUID", "payload": {"inputs": {"data": imgstr}}}
resp = requests.post("http://127.0.0.1:8000/predict", json=body)
print(resp.json())
Multi-label Image Classification¶
The Task¶
Multi-label classification is the task of assigning a number of labels from a fixed set to each data point, which can be in any modality (images in this case).
Multi-label image classification is supported by the ImageClassifier via the multi_label argument.
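In a multi-label setting, each target is usually represented as a multi-hot vector with one slot per class. The sketch below shows that encoding for the five genres used in the example that follows; `multi_hot` is a hypothetical helper, not part of Flash.

```python
from typing import List, Sequence

GENRES = ["Action", "Romance", "Crime", "Thriller", "Adventure"]


def multi_hot(labels: Sequence[str], classes: Sequence[str] = GENRES) -> List[int]:
    """Encode a set of labels as a 0/1 vector with one entry per class."""
    unknown = set(labels) - set(classes)
    if unknown:
        raise ValueError(f"Unknown labels: {sorted(unknown)}")
    return [1 if name in labels else 0 for name in classes]


# A poster can belong to several genres at once, or to none of them.
print(multi_hot(["Crime", "Action"]))
print(multi_hot([]))
```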
Example¶
Let’s look at the task of trying to predict the movie genres from an image of the movie poster.
The data we will use is a subset of the awesome movie poster genre prediction data set from the paper “Movie Genre Classification based on Poster Images with Deep Neural Networks” by Wei-Ta Chu and Hung-Jui Guo, resized to 128 by 128.
Take a look at their paper (and please consider citing their paper if you use the data) here: www.cs.ccu.edu.tw/~wtchu/projects/MoviePoster/.
The data set contains train and val folders, and each folder contains images and a metadata.csv which stores the labels.
Here’s an overview:
movie_posters
├── train
│   ├── metadata.csv
│   ├── tt0084058.jpg
│   ├── tt0084867.jpg
│   ...
└── val
    ├── metadata.csv
    ├── tt0200465.jpg
    ├── tt0326965.jpg
    ...
Once we’ve downloaded the data using download_data(), we need to create the ImageClassificationData.
We load the image IDs and their genre labels from each metadata.csv using from_csv().
We select a pre-trained backbone to use for our ImageClassifier and fine-tune on the posters data.
We then use the trained ImageClassifier for inference.
Finally, we save the model.
Here’s the full example:
import torch
import flash
from flash.core.data.utils import download_data
from flash.image import ImageClassificationData, ImageClassifier
# 1. Create the DataModule
# Data set from the paper “Movie Genre Classification based on Poster Images with Deep Neural Networks”.
# More info here: https://www.cs.ccu.edu.tw/~wtchu/projects/MoviePoster/
download_data("https://pl-flash-data.s3.amazonaws.com/movie_posters.zip")
datamodule = ImageClassificationData.from_csv(
"Id",
["Action", "Romance", "Crime", "Thriller", "Adventure"],
train_file="data/movie_posters/train/metadata.csv",
val_file="data/movie_posters/val/metadata.csv",
image_size=(128, 128),
)
# 2. Build the task
model = ImageClassifier(backbone="resnet18", num_classes=datamodule.num_classes, multi_label=True)
# 3. Create the trainer and finetune the model
trainer = flash.Trainer(max_epochs=3, gpus=torch.cuda.device_count())
trainer.finetune(model, datamodule=datamodule, strategy="freeze")
# 4. Predict the genre of a few movies!
predictions = model.predict(
[
"data/movie_posters/predict/tt0085318.jpg",
"data/movie_posters/predict/tt0089461.jpg",
"data/movie_posters/predict/tt0097179.jpg",
]
)
print(predictions)
# 5. Save the model!
trainer.save_checkpoint("image_classification_multi_label_model.pt")
Flash Zero¶
The multi-label image classifier can be used directly from the command line with zero code using Flash Zero. You can run the movie posters example with:
flash image_classification from_movie_posters
To view configuration options and options for running the image classifier with your own data, use:
flash image_classification --help
Serving¶
The ImageClassifier is servable.
For more information, see Image Classification.
Image Embedder¶
The Task¶
Image embedding encodes an image into a vector of features which can be used for a downstream task. This could include: clustering, similarity search, or classification.
The ImageEmbedder internally relies on VISSL.
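As an example of a downstream use, similarity search over embeddings is often done with cosine similarity. The sketch below uses short hand-written vectors in place of real embeddings; `cosine_similarity` and `most_similar` are hypothetical helpers, not part of Flash or VISSL.

```python
import math
from typing import Dict, Sequence


def cosine_similarity(a: Sequence[float], b: Sequence[float]) -> float:
    """Cosine of the angle between two vectors of equal length."""
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    if norm == 0.0:
        raise ValueError("cosine similarity is undefined for zero vectors")
    return dot / norm


def most_similar(query: Sequence[float], gallery: Dict[str, Sequence[float]]) -> str:
    """Return the key of the gallery vector closest to the query."""
    return max(gallery, key=lambda name: cosine_similarity(query, gallery[name]))


# Toy 2-dimensional "embeddings"; real ones have many more dimensions.
gallery = {"ant": [1.0, 0.0], "bee": [0.0, 1.0]}
best = most_similar([0.9, 0.1], gallery)
print(best)
```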
Example¶
Let’s see how to configure a training strategy for the ImageEmbedder task.
A vanilla DataModule object can be created using standard Datasets as shown below.
Then the user can configure the ImageEmbedder task with training_strategy, backbone, head and pretraining_transform.
Additional arguments for these selections can be passed through options such as training_strategy_kwargs and pretraining_transform_kwargs.
This task can now be sent to the fit() method of Trainer.
Note
Many VISSL loss functions use hard-coded torch.distributed methods, so we suggest using accelerator=ddp even with a single GPU.
Only the barlow_twins training strategy works on the CPU. All other loss functions are configured to work on GPUs.
import torch
from torchvision.datasets import CIFAR10
import flash
from flash.core.data.utils import download_data
from flash.image import ImageClassificationData, ImageEmbedder
# 1. Download the data and prepare the datamodule
datamodule = ImageClassificationData.from_datasets(
train_dataset=CIFAR10(".", download=True),
batch_size=16,
)
# 2. Build the task
embedder = ImageEmbedder(
backbone="resnet",
training_strategy="barlow_twins",
head="simclr_head",
pretraining_transform="barlow_twins_transform",
training_strategy_kwargs={"latent_embedding_dim": 128},
pretraining_transform_kwargs={"size_crops": [196]},
)
# 3. Create the trainer and pre-train the encoder
# use accelerator='ddp' when using GPU(s),
# i.e. flash.Trainer(max_epochs=3, gpus=1, accelerator='ddp')
trainer = flash.Trainer(max_epochs=3, gpus=torch.cuda.device_count())
trainer.fit(embedder, datamodule=datamodule)
# 4. Save the model!
trainer.save_checkpoint("image_embedder_model.pt")
# 5. Download the downstream prediction dataset and generate embeddings
download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip", "data/")
embeddings = embedder.predict(
[
"data/hymenoptera_data/predict/153783656_85f9c3ac70.jpg",
"data/hymenoptera_data/predict/2039585088_c6f47c592e.jpg",
]
)
# list of embeddings for images sent to the predict function
print(embeddings)
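As a rough illustration of the similarity-search use case mentioned at the start of this section, the sketch below ranks a few embedding vectors by cosine similarity to a query vector. It uses only the standard library. The three-dimensional vectors are made-up placeholders (real embeddings are much longer), and the helper names cosine_similarity and rank_by_similarity are hypothetical, not part of Flash.

```python
import math


def cosine_similarity(a, b):
    """Cosine of the angle between two equal-length vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


def rank_by_similarity(query, candidates):
    """Return candidate names sorted from most to least similar to the query."""
    scores = {name: cosine_similarity(query, vec) for name, vec in candidates.items()}
    return sorted(scores, key=scores.get, reverse=True)


# Placeholder vectors standing in for the output of embedder.predict(...)
query = [1.0, 0.0, 1.0]
candidates = {
    "image_a": [0.9, 0.1, 1.1],
    "image_b": [-1.0, 0.5, 0.0],
    "image_c": [0.0, 1.0, 0.0],
}
ranking = rank_by_similarity(query, candidates)
print(ranking)
```

Cosine similarity ignores vector length, which is one common choice for comparing embeddings; other distance measures may suit a given downstream task better.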
Object Detection¶
The Task¶
Object detection is the task of identifying objects in images and their associated classes and bounding boxes.
The ObjectDetector
and ObjectDetectionData
classes internally rely on IceVision.
Example¶
Let’s look at object detection with the COCO 128 data set, which contains 80 object classes. This is a subset of COCO train2017 with only 128 images. The data set is organized following the COCO format. Here’s an outline:
coco128
├── annotations
│ └── instances_train2017.json
├── images
│ └── train2017
│ ├── 000000000009.jpg
│ ├── 000000000025.jpg
│ ...
└── labels
└── train2017
├── 000000000009.txt
├── 000000000025.txt
...
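To make the annotation file in the outline above more concrete, here is a minimal sketch that reads a COCO-style JSON document and groups bounding boxes by image and by class. The JSON string is a tiny hand-written stand-in for instances_train2017.json with made-up values. The field names (images, categories, annotations, bbox as [x, y, width, height]) follow the COCO format, but the helper to_corners is hypothetical and this is not how Flash or IceVision parse the file internally.

```python
import json
from collections import Counter

# A tiny hand-written stand-in for instances_train2017.json (values are made up).
coco_json = """
{
  "images": [
    {"id": 9, "file_name": "000000000009.jpg", "width": 640, "height": 480},
    {"id": 25, "file_name": "000000000025.jpg", "width": 640, "height": 426}
  ],
  "categories": [
    {"id": 1, "name": "person"},
    {"id": 18, "name": "dog"}
  ],
  "annotations": [
    {"id": 100, "image_id": 9, "category_id": 1, "bbox": [10.0, 20.0, 50.0, 80.0]},
    {"id": 101, "image_id": 9, "category_id": 18, "bbox": [200.0, 150.0, 60.0, 40.0]},
    {"id": 102, "image_id": 25, "category_id": 1, "bbox": [5.0, 5.0, 30.0, 90.0]}
  ]
}
"""

data = json.loads(coco_json)
category_names = {c["id"]: c["name"] for c in data["categories"]}
file_names = {img["id"]: img["file_name"] for img in data["images"]}

# Count how many boxes each class has.
boxes_per_class = Counter(category_names[a["category_id"]] for a in data["annotations"])


def to_corners(bbox):
    """Convert a COCO [x, y, width, height] box to [x1, y1, x2, y2]."""
    x, y, w, h = bbox
    return [x, y, x + w, y + h]


# Group the converted boxes under the file name of the image they belong to.
boxes_per_image = {}
for ann in data["annotations"]:
    boxes_per_image.setdefault(file_names[ann["image_id"]], []).append(to_corners(ann["bbox"]))

print(boxes_per_class)
print(boxes_per_image)
```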
Once we’ve downloaded the data using download_data()
, we can create the ObjectDetectionData
.
We select a pre-trained EfficientDet to use for our ObjectDetector
and fine-tune on the COCO 128 data.
We then use the trained ObjectDetector
for inference.
Finally, we save the model.
Here’s the full example:
import flash
from flash.core.data.utils import download_data
from flash.image import ObjectDetectionData, ObjectDetector
# 1. Create the DataModule
# Dataset Credit: https://www.kaggle.com/ultralytics/coco128
download_data("https://github.com/zhiqwang/yolov5-rt-stack/releases/download/v0.3.0/coco128.zip", "data/")
datamodule = ObjectDetectionData.from_coco(
train_folder="data/coco128/images/train2017/",
train_ann_file="data/coco128/annotations/instances_train2017.json",
val_split=0.1,
image_size=128,
)
# 2. Build the task
model = ObjectDetector(head="efficientdet", backbone="d0", num_classes=datamodule.num_classes, image_size=128)
# 3. Create the trainer and finetune the model
trainer = flash.Trainer(max_epochs=1)
trainer.finetune(model, datamodule=datamodule, strategy="freeze")
# 4. Detect objects in a few images!
predictions = model.predict(
[
"data/coco128/images/train2017/000000000625.jpg",
"data/coco128/images/train2017/000000000626.jpg",
"data/coco128/images/train2017/000000000629.jpg",
]
)
print(predictions)
# 5. Save the model!
trainer.save_checkpoint("object_detection_model.pt")
Flash Zero¶
The object detector can be used directly from the command line with zero code using Flash Zero. You can run the above example with:
flash object_detection
To view configuration options and options for running the object detector with your own data, use:
flash object_detection --help
Custom Transformations¶
Flash automatically applies some default image / mask transformations and augmentations, but you may wish to customize these for your own use case.
The base Preprocess
defines 7 hooks for different stages in the data loading pipeline.
For object-detection tasks, you can leverage the transformations from Albumentations with the IceVisionTransformAdapter
.
import albumentations as alb
from icevision.tfms import A
from flash.core.integrations.icevision.transforms import IceVisionTransformAdapter
from flash.image import ObjectDetectionData
train_transform = {
"pre_tensor_transform": IceVisionTransformAdapter(
[*A.resize_and_pad(128), A.Normalize(), A.Flip(0.4), alb.RandomBrightnessContrast()]
)
}
datamodule = ObjectDetectionData.from_coco(
train_folder="data/coco128/images/train2017/",
train_ann_file="data/coco128/annotations/instances_train2017.json",
val_split=0.1,
image_size=128,
train_transform=train_transform,
)
Keypoint Detection¶
The Task¶
Keypoint detection is the task of identifying keypoints in images and their associated classes.
The KeypointDetector
and KeypointDetectionData
classes internally rely on IceVision.
Example¶
Let’s look at keypoint detection with BIWI Sample Keypoints (center of face) from IceData.
Once we’ve downloaded the data, we can create the KeypointDetectionData
.
We select a keypoint_rcnn
with a resnet18_fpn
backbone to use for our KeypointDetector
and fine-tune on the BIWI data.
We then use the trained KeypointDetector
for inference.
Finally, we save the model.
Here’s the full example:
import flash
from flash.core.utilities.imports import example_requires
from flash.image import KeypointDetectionData, KeypointDetector
example_requires("image")
import icedata # noqa: E402
# 1. Create the DataModule
data_dir = icedata.biwi.load_data()
datamodule = KeypointDetectionData.from_folders(
train_folder=data_dir,
val_split=0.1,
parser=icedata.biwi.parser,
)
# 2. Build the task
model = KeypointDetector(
head="keypoint_rcnn",
backbone="resnet18_fpn",
num_keypoints=1,
num_classes=datamodule.num_classes,
)
# 3. Create the trainer and finetune the model
trainer = flash.Trainer(max_epochs=1)
trainer.finetune(model, datamodule=datamodule, strategy="freeze")
# 4. Detect keypoints in a few images!
predictions = model.predict(
[
str(data_dir / "biwi_sample/images/0.jpg"),
str(data_dir / "biwi_sample/images/1.jpg"),
str(data_dir / "biwi_sample/images/10.jpg"),
]
)
print(predictions)
# 5. Save the model!
trainer.save_checkpoint("keypoint_detection_model.pt")
Flash Zero¶
The keypoint detector can be used directly from the command line with zero code using Flash Zero. You can run the above example with:
flash keypoint_detection
To view configuration options and options for running the keypoint detector with your own data, use:
flash keypoint_detection --help
Instance Segmentation¶
The Task¶
Instance segmentation is the task of segmenting objects in images and determining their associated classes.
The InstanceSegmentation
and InstanceSegmentationData
classes internally rely on IceVision.
Example¶
Let’s look at instance segmentation with The Oxford-IIIT Pet Dataset from IceData.
Once we’ve downloaded the data, we can create the InstanceSegmentationData
.
We select a mask_rcnn
with a resnet18_fpn
backbone to use for our InstanceSegmentation
and fine-tune on the pets data.
We then use the trained InstanceSegmentation
for inference.
Finally, we save the model.
Here’s the full example:
from functools import partial
import flash
from flash.core.utilities.imports import example_requires
from flash.image import InstanceSegmentation, InstanceSegmentationData
example_requires("image")
import icedata # noqa: E402
# 1. Create the DataModule
data_dir = icedata.pets.load_data()
datamodule = InstanceSegmentationData.from_folders(
train_folder=data_dir,
val_split=0.1,
parser=partial(icedata.pets.parser, mask=True),
)
# 2. Build the task
model = InstanceSegmentation(
head="mask_rcnn",
backbone="resnet18_fpn",
num_classes=datamodule.num_classes,
)
# 3. Create the trainer and finetune the model
trainer = flash.Trainer(max_epochs=1)
trainer.finetune(model, datamodule=datamodule, strategy="freeze")
# 4. Detect objects in a few images!
predictions = model.predict(
[
str(data_dir / "images/yorkshire_terrier_9.jpg"),
str(data_dir / "images/yorkshire_terrier_12.jpg"),
str(data_dir / "images/yorkshire_terrier_13.jpg"),
]
)
print(predictions)
# 5. Save the model!
trainer.save_checkpoint("instance_segmentation_model.pt")
Flash Zero¶
The instance segmentation task can be used directly from the command line with zero code using Flash Zero. You can run the above example with:
flash instance_segmentation
To view configuration options and options for running the instance segmentation task with your own data, use:
flash instance_segmentation --help
Semantic Segmentation¶
The Task¶
Semantic Segmentation, or image segmentation, is the task of performing classification at the pixel level, meaning each pixel is assigned to a class. See more: https://paperswithcode.com/task/semantic-segmentation
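The idea of per-pixel classification can be sketched without any deep learning library: given one score per class for every pixel, the predicted mask holds the index of the highest-scoring class at each position. The scores below are made up, and the segment helper is a hypothetical illustration rather than anything Flash exposes.

```python
def segment(scores):
    """Pick, for every pixel, the index of the class with the highest score.

    scores[c][y][x] is the score of class c at pixel (y, x).
    """
    num_classes = len(scores)
    height = len(scores[0])
    width = len(scores[0][0])
    return [
        [max(range(num_classes), key=lambda c: scores[c][y][x]) for x in range(width)]
        for y in range(height)
    ]


# Made-up scores for 3 classes on a 2x2 image.
scores = [
    [[0.9, 0.1], [0.2, 0.3]],  # class 0
    [[0.05, 0.7], [0.2, 0.3]],  # class 1
    [[0.05, 0.2], [0.6, 0.4]],  # class 2
]
mask = segment(scores)
print(mask)
```

The resulting mask has the same height and width as the image, with one class index per pixel.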
Example¶
Let’s look at an example using a data set generated with the CARLA driving simulator. The data was generated as part of the Kaggle Lyft Udacity Challenge. The data contains one folder of images and another folder with the corresponding segmentation masks. Here’s the structure:
data
├── CameraRGB
│ ├── F61-1.png
│ ├── F61-2.png
│ ...
└── CameraSeg
├── F61-1.png
├── F61-2.png
...
Once we’ve downloaded the data using download_data()
, we create the SemanticSegmentationData
.
We select a pre-trained mobilenet_v3_large
backbone with an fpn
head to use for our SemanticSegmentation
task and fine-tune on the CARLA data.
We then use the trained SemanticSegmentation
for inference. You can check the available pretrained weights for a backbone like this: SemanticSegmentation.available_pretrained_weights("resnet18").
Finally, we save the model.
Here’s the full example:
import torch
import flash
from flash.core.data.utils import download_data
from flash.image import SemanticSegmentation, SemanticSegmentationData
# 1. Create the DataModule
# The data was generated with the CARLA self-driving simulator as part of the Kaggle Lyft Udacity Challenge.
# More info here: https://www.kaggle.com/kumaresanmanickavelu/lyft-udacity-challenge
download_data(
"https://github.com/ongchinkiat/LyftPerceptionChallenge/releases/download/v0.1/carla-capture-20180513A.zip",
"./data",
)
datamodule = SemanticSegmentationData.from_folders(
train_folder="data/CameraRGB",
train_target_folder="data/CameraSeg",
val_split=0.1,
image_size=(256, 256),
num_classes=21,
)
# 2. Build the task
model = SemanticSegmentation(
backbone="mobilenetv3_large_100",
head="fpn",
num_classes=datamodule.num_classes,
)
# 3. Create the trainer and finetune the model
trainer = flash.Trainer(max_epochs=3, gpus=torch.cuda.device_count())
trainer.finetune(model, datamodule=datamodule, strategy="freeze")
# 4. Segment a few images!
predictions = model.predict(
[
"data/CameraRGB/F61-1.png",
"data/CameraRGB/F62-1.png",
"data/CameraRGB/F63-1.png",
]
)
print(predictions)
# 5. Save the model!
trainer.save_checkpoint("semantic_segmentation_model.pt")
Flash Zero¶
The semantic segmentation task can be used directly from the command line with zero code using Flash Zero. You can run the above example with:
flash semantic_segmentation
To view configuration options and options for running the semantic segmentation task with your own data, use:
flash semantic_segmentation --help
Loading Data¶
This section details the available ways to load your own data into the SemanticSegmentationData
.
from_folders¶
Construct the SemanticSegmentationData
from folders.
The supported file extensions are: .jpg, .jpeg, .png, .ppm, .bmp, .pgm, .tif, .tiff, .webp.
For train, test, and val data, we expect a folder containing inputs and another folder containing the masks. Here’s the required structure:
train_folder
├── inputs
│ ├── file1.jpg
│ ├── file2.jpg
│ ...
└── masks
├── file1.jpg
├── file2.jpg
...
For prediction, the folder is expected to contain the files for inference, like this:
predict_folder
├── file1.jpg
├── file2.jpg
...
Example:
data_module = SemanticSegmentationData.from_folders(
train_folder = "./train_folder/inputs",
train_target_folder = "./train_folder/masks",
predict_folder = "./predict_folder",
...
)
from_files¶
Construct the SemanticSegmentationData
from lists of input images and corresponding list of target images.
The supported file extensions are: .jpg, .jpeg, .png, .ppm, .bmp, .pgm, .tif, .tiff, .webp.
Example:
train_files = ["file1.jpg", "file2.jpg", "file3.jpg", ...]
train_targets = ["mask1.jpg", "mask2.jpg", "mask3.jpg", ...]
datamodule = SemanticSegmentationData.from_files(
train_files = train_files,
train_targets = train_targets,
...
)
from_datasets¶
Construct the SemanticSegmentationData
from the given datasets for each stage.
Example:
from torch.utils.data.dataset import Dataset
train_dataset: Dataset = ...
datamodule = SemanticSegmentationData.from_datasets(
train_dataset = train_dataset,
...
)
Note
The __getitem__
of your datasets should return a dictionary with "input"
and "target"
keys which map to the input and target images as tensors.
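As a minimal sketch of the item layout described in the note, the class below returns a dictionary with "input" and "target" keys from __getitem__. To stay dependency-free it uses nested Python lists as stand-ins for tensors; in real use the class would typically subclass torch.utils.data.Dataset and return tensors. The class name ToySegmentationDataset and the sample values are hypothetical.

```python
class ToySegmentationDataset:
    """Map-style dataset whose items are {"input": ..., "target": ...} dictionaries."""

    def __init__(self, images, masks):
        if len(images) != len(masks):
            raise ValueError("images and masks must have the same length")
        self.images = images
        self.masks = masks

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        # In real use these values would be tensors rather than nested lists.
        return {"input": self.images[index], "target": self.masks[index]}


# Two made-up 2x2 single-channel "images" with per-pixel class masks.
images = [[[0.1, 0.2], [0.3, 0.4]], [[0.5, 0.6], [0.7, 0.8]]]
masks = [[[0, 0], [1, 1]], [[1, 0], [0, 1]]]
dataset = ToySegmentationDataset(images, masks)
sample = dataset[1]
print(len(dataset), sorted(sample))
```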
Serving¶
The SemanticSegmentation
task is servable.
This means you can call .serve
to serve your Task
.
Here’s an example:
from flash.image import SemanticSegmentation
from flash.image.segmentation.serialization import SegmentationLabels
model = SemanticSegmentation.load_from_checkpoint(
"https://flash-weights.s3.amazonaws.com/semantic_segmentation_model.pt"
)
model.serializer = SegmentationLabels(visualize=False)
model.serve()
You can now perform inference from your client like this:
import base64
from pathlib import Path
import requests
import flash
with (Path(flash.ASSETS_ROOT) / "road.png").open("rb") as f:
imgstr = base64.b64encode(f.read()).decode("UTF-8")
body = {"session": "UUID", "payload": {"inputs": {"data": imgstr}}}
resp = requests.post("http://127.0.0.1:8000/predict", json=body)
print(resp.json())
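The request body used by the client can be inspected without a running server. The sketch below builds the same dictionary layout from raw bytes and checks that the bytes survive a JSON round trip. The helper name build_predict_body and the placeholder bytes are assumptions made for illustration; only the body layout is taken from the client code above.

```python
import base64
import json


def build_predict_body(raw_bytes, session="UUID"):
    """Wrap raw file bytes in the JSON body layout used by the client above."""
    encoded = base64.b64encode(raw_bytes).decode("UTF-8")
    return {"session": session, "payload": {"inputs": {"data": encoded}}}


# Placeholder bytes standing in for the contents of an image file.
fake_image = b"\x89PNG placeholder bytes"
body = build_predict_body(fake_image)

# The body is JSON-serializable, and the original bytes can be recovered from it.
serialized = json.dumps(body)
decoded = base64.b64decode(json.loads(serialized)["payload"]["inputs"]["data"])
print(decoded == fake_image)
```

Base64 is used because JSON cannot carry raw binary data directly.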
Style Transfer¶
The Task¶
The Neural Style Transfer Task is an optimization method which extracts the style from one image and applies it to another image while preserving its content. The goal is that the output image looks like the content image, but “painted” in the style of the style reference image.
The StyleTransfer
and StyleTransferData
classes internally rely on pystiche.
Example¶
Let’s look at transferring the style from The Starry Night onto the images from the COCO 128 data set from the Object Detection Guide.
Once we’ve downloaded the data using download_data()
, we create the StyleTransferData
.
Next, we create our StyleTransfer
task with the desired style image and fit on the COCO 128 images.
We then use the trained StyleTransfer
for inference.
Finally, we save the model.
Here’s the full example:
import os
import torch
import flash
from flash.core.data.utils import download_data
from flash.image.style_transfer import StyleTransfer, StyleTransferData
# 1. Create the DataModule
download_data("https://github.com/zhiqwang/yolov5-rt-stack/releases/download/v0.3.0/coco128.zip", "./data")
datamodule = StyleTransferData.from_folders(train_folder="data/coco128/images/train2017")
# 2. Build the task
model = StyleTransfer(os.path.join(flash.ASSETS_ROOT, "starry_night.jpg"))
# 3. Create the trainer and train the model
trainer = flash.Trainer(max_epochs=3, gpus=torch.cuda.device_count())
trainer.fit(model, datamodule=datamodule)
# 4. Apply style transfer to a few images!
predictions = model.predict(
[
"data/coco128/images/train2017/000000000625.jpg",
"data/coco128/images/train2017/000000000626.jpg",
"data/coco128/images/train2017/000000000629.jpg",
]
)
print(predictions)
# 5. Save the model!
trainer.save_checkpoint("style_transfer_model.pt")
Flash Zero¶
The style transfer task can be used directly from the command line with zero code using Flash Zero. You can run the above example with:
flash style_transfer
To view configuration options and options for running the style transfer task with your own data, use:
flash style_transfer --help
Video Classification¶
The Task¶
Typically, Video Classification refers to the task of producing a label for actions identified in a given video. The task is to predict which class the video clip belongs to.
Lightning Flash VideoClassifier
and VideoClassificationData
classes internally rely on PyTorchVideo.
Example¶
Let’s develop a model to classify video clips of humans performing actions (such as archery, bowling, etc.). We’ll use data from the Kinetics dataset. Here’s an outline of the folder structure:
video_dataset
├── train
│ ├── archery
│ │ ├── -1q7jA3DXQM_000005_000015.mp4
│ │ ├── -5NN5hdIwTc_000036_000046.mp4
│ │ ...
│ ├── bowling
│ │ ├── -5ExwuF5IUI_000030_000040.mp4
│ │ ├── -7sTNNI1Bcg_000075_000085.mp4
│ ... ...
└── val
├── archery
│ ├── 0S-P4lr_c7s_000022_000032.mp4
│ ├── 2x1lIrgKxYo_000589_000599.mp4
│ ...
├── bowling
│ ├── 1W7HNDBA4pA_000002_000012.mp4
│ ├── 4JxH3S5JwMs_000003_000013.mp4
... ...
Once we’ve downloaded the data using download_data()
, we create the VideoClassificationData
.
We select a pre-trained backbone to use for our VideoClassifier
and fine-tune on the Kinetics data.
The backbone can be any model from the PyTorchVideo Model Zoo.
We then use the trained VideoClassifier
for inference.
Finally, we save the model.
Here’s the full example:
import os
import torch
import flash
from flash.core.data.utils import download_data
from flash.video import VideoClassificationData, VideoClassifier
# 1. Create the DataModule
# Find more datasets at https://pytorchvideo.readthedocs.io/en/latest/data.html
download_data("https://pl-flash-data.s3.amazonaws.com/kinetics.zip", "./data")
datamodule = VideoClassificationData.from_folders(
train_folder=os.path.join(os.getcwd(), "data/kinetics/train"),
val_folder=os.path.join(os.getcwd(), "data/kinetics/val"),
clip_sampler="uniform",
clip_duration=1,
decode_audio=False,
)
# 2. Build the task
model = VideoClassifier(backbone="x3d_xs", num_classes=datamodule.num_classes, pretrained=False)
# 3. Create the trainer and finetune the model
trainer = flash.Trainer(max_epochs=3, gpus=torch.cuda.device_count())
trainer.finetune(model, datamodule=datamodule, strategy="freeze")
# 4. Make a prediction
predictions = model.predict(os.path.join(os.getcwd(), "data/kinetics/predict"))
print(predictions)
# 5. Save the model!
trainer.save_checkpoint("video_classification.pt")
Flash Zero¶
The video classifier can be used directly from the command line with zero code using Flash Zero. You can run the above example with:
flash video_classification
To view configuration options and options for running the video classifier with your own data, use:
flash video_classification --help
Audio Classification¶
The Task¶
The task of identifying what is in an audio file is called audio classification. Typically, Audio Classification is used to identify audio files containing sounds or words. The task predicts which ‘class’ the sound or words most likely belongs to with a degree of certainty. A class is a label that describes the sounds in an audio file, such as ‘children_playing’, ‘jackhammer’, ‘siren’ etc.
Example¶
Let’s look at the task of predicting whether an audio file contains sounds of an air_conditioner, car_horn, children_playing, dog_bark, drilling, engine_idling, gun_shot, jackhammer, siren, or street_music using the UrbanSound8k spectrogram images dataset.
The dataset contains train
, val
and test
folders. Each of these contains one folder per class: an air_conditioner folder with spectrograms generated from air-conditioner sounds, a siren folder with spectrograms generated from siren sounds, and so on for the other classes.
urban8k_images
├── train
│ ├── air_conditioner
│ ├── car_horn
│ ├── children_playing
│ ├── dog_bark
│ ├── drilling
│ ├── engine_idling
│ ├── gun_shot
│ ├── jackhammer
│ ├── siren
│ └── street_music
├── test
│ ├── air_conditioner
│ ├── car_horn
│ ├── children_playing
│ ├── dog_bark
│ ├── drilling
│ ├── engine_idling
│ ├── gun_shot
│ ├── jackhammer
│ ├── siren
│ └── street_music
└── val
├── air_conditioner
├── car_horn
├── children_playing
├── dog_bark
├── drilling
├── engine_idling
├── gun_shot
├── jackhammer
├── siren
└── street_music
...
Once we’ve downloaded the data using download_data()
, we create the AudioClassificationData
.
We select a pre-trained backbone to use for our ImageClassifier
and fine-tune on the UrbanSound8k spectrogram images data.
We then use the trained ImageClassifier
for inference.
Finally, we save the model.
Here’s the full example:
import torch
import flash
from flash.audio import AudioClassificationData
from flash.core.data.utils import download_data
from flash.core.finetuning import FreezeUnfreeze
from flash.image import ImageClassifier
# 1. Create the DataModule
download_data("https://pl-flash-data.s3.amazonaws.com/urban8k_images.zip", "./data")
datamodule = AudioClassificationData.from_folders(
train_folder="data/urban8k_images/train",
val_folder="data/urban8k_images/val",
spectrogram_size=(64, 64),
)
# 2. Build the model.
model = ImageClassifier(backbone="resnet18", num_classes=datamodule.num_classes)
# 3. Create the trainer and finetune the model
trainer = flash.Trainer(max_epochs=3, gpus=torch.cuda.device_count())
trainer.finetune(model, datamodule=datamodule, strategy=FreezeUnfreeze(unfreeze_epoch=1))
# 4. Predict what's in a few images! air_conditioner, children_playing, siren, etc.
predictions = model.predict(
[
"data/urban8k_images/test/air_conditioner/13230-0-0-5.wav.jpg",
"data/urban8k_images/test/children_playing/9223-2-0-15.wav.jpg",
"data/urban8k_images/test/jackhammer/22883-7-10-0.wav.jpg",
]
)
print(predictions)
# 5. Save the model!
trainer.save_checkpoint("audio_classification_model.pt")
Flash Zero¶
The audio classifier can be used directly from the command line with zero code using Flash Zero. You can run the above example with:
flash audio_classification
To view configuration options and options for running the audio classifier with your own data, use:
flash audio_classification --help
Loading Data¶
This section details the available ways to load your own data into the AudioClassificationData
.
from_folders¶
Construct the AudioClassificationData
from folders.
The supported file extensions are: .jpg, .jpeg, .png, .ppm, .bmp, .pgm, .tif, .tiff, .webp, .npy.
For train, test, and val data, the folders are expected to contain a sub-folder for each class. Here’s the required structure:
train_folder
├── class_1
│ ├── file1.jpg
│ ├── file2.jpg
│ ...
└── class_2
├── file1.jpg
├── file2.jpg
...
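The class-per-folder layout above can be turned into (file, label) pairs with a few lines of standard-library code. The sketch below builds a throwaway folder tree and derives integer labels from the sub-folder names. Assigning indices in alphabetical order is an assumption made for this illustration, not confirmed Flash behaviour, and list_labeled_files is a hypothetical helper.

```python
import tempfile
from pathlib import Path


def list_labeled_files(root):
    """Return (class_names, [(file_path, class_index), ...]) for a class-per-folder layout."""
    root = Path(root)
    class_names = sorted(p.name for p in root.iterdir() if p.is_dir())
    class_to_index = {name: i for i, name in enumerate(class_names)}
    samples = [
        (path, class_to_index[name])
        for name in class_names
        for path in sorted((root / name).iterdir())
        if path.is_file()
    ]
    return class_names, samples


# Build a throwaway folder tree that mirrors the structure above.
with tempfile.TemporaryDirectory() as tmp:
    layout = {"class_1": ["file1.jpg", "file2.jpg"], "class_2": ["file1.jpg"]}
    for class_name, files in layout.items():
        folder = Path(tmp) / class_name
        folder.mkdir()
        for file_name in files:
            (folder / file_name).touch()

    class_names, samples = list_labeled_files(tmp)
    labels = [index for _, index in samples]
    print(class_names, labels)
```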
For prediction, the folder is expected to contain the files for inference, like this:
predict_folder
├── file1.jpg
├── file2.jpg
...
Example:
data_module = AudioClassificationData.from_folders(
train_folder = "./train_folder",
predict_folder = "./predict_folder",
...
)
from_files¶
Construct the AudioClassificationData
from lists of files and corresponding lists of targets.
The supported file extensions are: .jpg, .jpeg, .png, .ppm, .bmp, .pgm, .tif, .tiff, .webp, .npy.
Example:
train_files = ["file1.jpg", "file2.jpg", "file3.jpg", ...]
train_targets = [0, 1, 0, ...]
datamodule = AudioClassificationData.from_files(
train_files = train_files,
train_targets = train_targets,
...
)
from_datasets¶
Construct the AudioClassificationData
from the given datasets for each stage.
Example:
from torch.utils.data.dataset import Dataset
train_dataset: Dataset = ...
datamodule = AudioClassificationData.from_datasets(
train_dataset = train_dataset,
...
)
Note
The __getitem__
of your datasets should return a dictionary with "input"
and "target"
keys which map to the input spectrogram image (as a NumPy array) and the target (as an int or list of ints) respectively.
Speech Recognition¶
The Task¶
Speech recognition is the task of classifying audio into a text transcription. We rely on Wav2Vec as our backbone, fine-tuned on labeled transcriptions for speech to text. Wav2Vec is pre-trained on thousands of hours of unlabeled audio, providing a strong baseline when fine-tuning to downstream tasks such as Speech Recognition.
Example¶
Let’s fine-tune the model onto our own labeled audio transcription data:
Here’s the structure of our CSV file:
file,text
"/path/to/file_1.wav","what was said in file 1."
"/path/to/file_2.wav","what was said in file 2."
"/path/to/file_3.wav","what was said in file 3."
...
Alternatively, here is the structure of our JSON file:
{"file": "/path/to/file_1.wav", "text": "what was said in file 1."}
{"file": "/path/to/file_2.wav", "text": "what was said in file 2."}
{"file": "/path/to/file_3.wav", "text": "what was said in file 3."}
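The CSV and JSON layouts above carry the same two fields, so one can be converted into the other with the standard library. The following sketch reads the CSV text and writes one JSON object per line. It keeps the sample content in a string to stay self-contained; the variable names are illustrative only.

```python
import csv
import io
import json

# The CSV content from above, held in a string so the sketch is self-contained.
csv_text = '''file,text
"/path/to/file_1.wav","what was said in file 1."
"/path/to/file_2.wav","what was said in file 2."
'''

# DictReader uses the header row ("file", "text") as keys for each record.
records = list(csv.DictReader(io.StringIO(csv_text)))

# Write one JSON object per line, matching the JSON layout shown above.
json_lines = "\n".join(json.dumps(record) for record in records)
print(json_lines)
```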
Once we’ve downloaded the data using download_data()
, we create the SpeechRecognitionData
.
We select a pre-trained Wav2Vec backbone to use for our SpeechRecognition
and finetune on a subset of the TIMIT corpus.
The backbone can be any Wav2Vec model from HuggingFace transformers.
Next, we use the trained SpeechRecognition
for inference and save the model.
Here’s the full example:
import torch
import flash
from flash.audio import SpeechRecognition, SpeechRecognitionData
from flash.core.data.utils import download_data
# 1. Create the DataModule
download_data("https://pl-flash-data.s3.amazonaws.com/timit_data.zip", "./data")
datamodule = SpeechRecognitionData.from_json(
input_fields="file",
target_fields="text",
train_file="data/timit/train.json",
test_file="data/timit/test.json",
)
# 2. Build the task
model = SpeechRecognition(backbone="facebook/wav2vec2-base-960h")
# 3. Create the trainer and finetune the model
trainer = flash.Trainer(max_epochs=1, gpus=torch.cuda.device_count())
trainer.finetune(model, datamodule=datamodule, strategy="no_freeze")
# 4. Predict on audio files!
predictions = model.predict(["data/timit/example.wav"])
print(predictions)
# 5. Save the model!
trainer.save_checkpoint("speech_recognition_model.pt")
Flash Zero¶
The speech recognition task can be used directly from the command line with zero code using Flash Zero. You can run the above example with:
flash speech_recognition
To view configuration options and options for running the speech recognition task with your own data, use:
flash speech_recognition --help
Serving¶
The SpeechRecognition
is servable.
This means you can call .serve
to serve your Task
.
Here’s an example:
from flash.audio import SpeechRecognition
model = SpeechRecognition.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/speech_recognition_model.pt")
model.serve()
You can now perform inference from your client like this:
import base64
from pathlib import Path
import requests
import flash
with (Path(flash.ASSETS_ROOT) / "example.wav").open("rb") as f:
audio_str = base64.b64encode(f.read()).decode("UTF-8")
body = {"session": "UUID", "payload": {"inputs": {"data": audio_str}}}
resp = requests.post("http://127.0.0.1:8000/predict", json=body)
print(resp.json())
Tabular Classification¶
The Task¶
Tabular classification is the task of assigning a class to samples of structured or relational data.
The TabularClassifier
task can be used for classification of samples in more than two classes (multi-class classification).
Example¶
Let’s look at training a model to predict passenger survival on the Titanic using the classic Kaggle data set. The data is provided in CSV files that look like this:
PassengerId,Survived,Pclass,Name,Sex,Age,SibSp,Parch,Ticket,Fare,Cabin,Embarked
1,0,3,"Braund, Mr. Owen Harris",male,22,1,0,A/5 21171,7.25,,S
3,1,3,"Heikkinen, Miss. Laina",female,26,0,0,STON/O2. 3101282,7.925,,S
5,0,3,"Allen, Mr. William Henry",male,35,0,0,373450,8.05,,S
6,0,3,"Moran, Mr. James",male,,0,0,330877,8.4583,,Q
...
Once we’ve downloaded the data using download_data()
, we can create the TabularClassificationData
from our CSV files using the from_csv()
method.
From the API reference
, we need to provide:
cat_cols- A list of the names of columns that contain categorical data (strings or integers).
num_cols- A list of the names of columns that contain numerical continuous data (floats).
target- The name of the column we want to predict.
train_csv- A CSV file containing the training data converted to a Pandas DataFrame
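To illustrate the difference between categorical and numerical columns in the list above, the sketch below encodes a few of the Titanic rows by hand: categorical values become integer codes and numerical values become floats. This is a conceptual illustration under simplifying assumptions (a reduced set of columns, codes assigned in sorted order); it does not reproduce how Flash preprocesses tabular data.

```python
import csv
import io

csv_text = """PassengerId,Survived,Pclass,Name,Sex,Age,SibSp,Parch,Ticket,Fare,Cabin,Embarked
1,0,3,"Braund, Mr. Owen Harris",male,22,1,0,A/5 21171,7.25,,S
3,1,3,"Heikkinen, Miss. Laina",female,26,0,0,STON/O2. 3101282,7.925,,S
6,0,3,"Moran, Mr. James",male,,0,0,330877,8.4583,,Q
"""

# A reduced, illustrative choice of columns.
categorical = ["Sex", "Embarked"]
numerical = ["Fare"]
target = "Survived"

rows = list(csv.DictReader(io.StringIO(csv_text)))

# Categorical columns: map each distinct string to an integer code.
codes = {
    col: {value: i for i, value in enumerate(sorted({r[col] for r in rows}))}
    for col in categorical
}

# One feature row per passenger: categorical codes first, then numerical values.
features = [
    [codes[col][r[col]] for col in categorical] + [float(r[col]) for col in numerical]
    for r in rows
]
targets = [int(r[target]) for r in rows]
print(codes)
print(features, targets)
```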
Next, we create the TabularClassifier
and finetune on the Titanic data.
We then use the trained TabularClassifier
for inference.
Finally, we save the model.
Here’s the full example:
import torch
import flash
from flash.core.data.utils import download_data
from flash.tabular import TabularClassificationData, TabularClassifier
# 1. Create the DataModule
download_data("https://pl-flash-data.s3.amazonaws.com/titanic.zip", "./data")
datamodule = TabularClassificationData.from_csv(
["Sex", "Age", "SibSp", "Parch", "Ticket", "Cabin", "Embarked"],
"Fare",
target_fields="Survived",
train_file="data/titanic/titanic.csv",
val_split=0.1,
)
# 2. Build the task
model = TabularClassifier.from_data(datamodule)
# 3. Create the trainer and train the model
trainer = flash.Trainer(max_epochs=3, gpus=torch.cuda.device_count())
trainer.fit(model, datamodule=datamodule)
# 4. Generate predictions from a CSV
predictions = model.predict("data/titanic/titanic.csv")
print(predictions)
# 5. Save the model!
trainer.save_checkpoint("tabular_classification_model.pt")
Flash Zero¶
The tabular classifier can be used directly from the command line with zero code using Flash Zero. You can run the above example with:
flash tabular_classifier
To view configuration options and options for running the tabular classifier with your own data, use:
flash tabular_classifier --help
Serving¶
The TabularClassifier
is servable.
This means you can call .serve
to serve your Task
.
Here’s an example:
from flash.core.classification import Labels
from flash.tabular import TabularClassifier
model = TabularClassifier.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/tabular_classification_model.pt")
model.serializer = Labels(["Did not survive", "Survived"])
model.serve()
You can now perform inference from your client like this:
import pandas as pd
import requests
from flash.core.data.utils import download_data
# 1. Download the data
download_data("https://pl-flash-data.s3.amazonaws.com/titanic.zip", "data/")
df = pd.read_csv("./data/titanic/predict.csv")
text = str(df.to_csv())
body = {"session": "UUID", "payload": {"inputs": {"data": text}}}
resp = requests.post("http://127.0.0.1:8000/predict", json=body)
print(resp.json())
Text Classification¶
The Task¶
Text classification is the task of assigning a piece of text (word, sentence or document) an appropriate class, or category. The categories depend on the chosen data set; they can be topics, for example.
Example¶
Let’s train a model to classify text as expressing either positive or negative sentiment.
We will be using the IMDB data set, which contains a train.csv
and valid.csv
.
Here’s the structure:
review,sentiment
"Japanese indie film with humor ... ",positive
"Isaac Florentine has made some ...",negative
"After seeing the low-budget ...",negative
"I've seen the original English version ...",positive
"Hunters chase what they think is a man through ...",negative
...
Once we’ve downloaded the data using download_data()
, we create the TextClassificationData
.
We select a pre-trained backbone to use for our TextClassifier
and finetune on the IMDB data.
The backbone can be any BERT classification model from HuggingFace/transformers.
Note
When changing the backbone, make sure you pass in the same backbone to the TextClassifier
and the TextClassificationData
!
Next, we use the trained TextClassifier
for inference.
Finally, we save the model.
Here’s the full example:
import torch
import flash
from flash.core.data.utils import download_data
from flash.text import TextClassificationData, TextClassifier
# 1. Create the DataModule
download_data("https://pl-flash-data.s3.amazonaws.com/imdb.zip", "./data/")
datamodule = TextClassificationData.from_csv(
"review",
"sentiment",
train_file="data/imdb/train.csv",
val_file="data/imdb/valid.csv",
backbone="prajjwal1/bert-medium",
)
# 2. Build the task
model = TextClassifier(backbone="prajjwal1/bert-medium", num_classes=datamodule.num_classes)
# 3. Create the trainer and finetune the model
trainer = flash.Trainer(max_epochs=3, gpus=torch.cuda.device_count())
trainer.finetune(model, datamodule=datamodule, strategy="freeze")
# 4. Classify a few sentences! How was the movie?
predictions = model.predict(
[
"Turgid dialogue, feeble characterization - Harvey Keitel a judge?.",
"The worst movie in the history of cinema.",
"I come from Bulgaria where it 's almost impossible to have a tornado.",
]
)
print(predictions)
# 5. Save the model!
trainer.save_checkpoint("text_classification_model.pt")
Flash Zero¶
The text classifier can be used directly from the command line with zero code using Flash Zero. You can run the above example with:
flash text_classification
To view configuration options and options for running the text classifier with your own data, use:
flash text_classification --help
Serving¶
The TextClassifier
is servable.
This means you can call .serve
to serve your Task
.
Here’s an example:
from flash.text import TextClassifier
model = TextClassifier.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/text_classification_model.pt")
model.serve()
You can now perform inference from your client like this:
import requests
text = "Best movie ever"
body = {"session": "UUID", "payload": {"inputs": {"data": text}}}
resp = requests.post("http://127.0.0.1:8000/predict", json=body)
print(resp.json())
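The request body in the client example has a fixed nested shape. As a minimal sketch, it can be built by a small helper and inspected before it is sent. The helper name build_predict_body is hypothetical and not part of Flash; the field names are copied from the example above.

```python
import json


def build_predict_body(text, session="UUID"):
    """Wrap raw text in the request shape used by the client example above."""
    return {"session": session, "payload": {"inputs": {"data": text}}}


body = build_predict_body("Best movie ever")
# Roughly what `requests.post(..., json=body)` puts on the wire.
encoded = json.dumps(body)
print(encoded)
```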
Accelerate Training & Inference with Torch ORT¶
Torch ORT converts your model into an optimized ONNX graph, speeding up training & inference when using NVIDIA or AMD GPUs. Enabling Torch ORT requires a single flag passed to the TextClassifier
once installed. See installation instructions here.
Note
Not all Transformer models are supported. See this table for supported models + branches containing fixes for certain models.
...
model = TextClassifier(backbone="facebook/bart-large", num_classes=datamodule.num_classes, enable_ort=True)
Multi-label Text Classification¶
The Task¶
Multi-label classification is the task of assigning a number of labels from a fixed set to each data point, which can be in any modality (text in this case).
Multi-label text classification is supported by the TextClassifier
via the multi_label
argument.
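The difference from ordinary (single-label) classification can be sketched in plain Python. This is an illustration of the general idea only, not Flash's internal code; the 0.5 threshold and the logit values are assumptions made up for the example.

```python
import math


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def multi_class_decision(logits):
    """Exactly one label: the index of the largest logit."""
    return max(range(len(logits)), key=lambda i: logits[i])


def multi_label_decision(logits, threshold=0.5):
    """Any number of labels: every index whose sigmoid score passes the threshold."""
    return [i for i, z in enumerate(logits) if sigmoid(z) >= threshold]


logits = [2.0, -1.0, 0.5, -3.0]  # made-up scores for four labels
print(multi_class_decision(logits))
print(multi_label_decision(logits))
```

A multi-label decision may return several indices or none at all, whereas the multi-class decision always returns exactly one.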
Example¶
Let’s look at the task of classifying comment toxicity. The data we will use in this example is from the Kaggle Toxic Comment Classification Challenge by Jigsaw: www.kaggle.com/c/jigsaw-toxic-comment-classification-challenge. The data is stored in CSV files with this structure:
"id","comment_text","toxic","severe_toxic","obscene","threat","insult","identity_hate"
"0000997932d777bf","...",0,0,0,0,0,0
"0002bcb3da6cb337","...",1,1,1,0,1,0
"0005c987bdfc9d4b","...",1,0,0,0,0,0
...
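To make the layout concrete, here is a small sketch that reads rows of this shape with the standard csv module and turns the six label columns into one 0/1 vector per comment. It only illustrates the file structure; how TextClassificationData parses the file internally is not shown here.

```python
import csv
import io

LABELS = ["toxic", "severe_toxic", "obscene", "threat", "insult", "identity_hate"]

# Three rows in the layout shown above; the comment text is elided as in the source.
SAMPLE = '''"id","comment_text","toxic","severe_toxic","obscene","threat","insult","identity_hate"
"0000997932d777bf","...",0,0,0,0,0,0
"0002bcb3da6cb337","...",1,1,1,0,1,0
"0005c987bdfc9d4b","...",1,0,0,0,0,0
'''

rows = list(csv.DictReader(io.StringIO(SAMPLE)))
# One 0/1 entry per label column, in LABELS order.
targets = [[int(row[name]) for name in LABELS] for row in rows]
print(targets)
```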
Once we’ve downloaded the data using download_data()
, we create the TextClassificationData
.
We select a pre-trained backbone to use for our TextClassifier
and finetune on the toxic comments data.
The backbone can be any BERT classification model from HuggingFace/transformers.
Note
When changing the backbone, make sure you pass in the same backbone to the TextClassifier
and the TextClassificationData
!
Next, we use the trained TextClassifier
for inference.
Finally, we save the model.
Here’s the full example:
import torch
import flash
from flash.core.data.utils import download_data
from flash.text import TextClassificationData, TextClassifier
# 1. Create the DataModule
# Data from the Kaggle Toxic Comment Classification Challenge:
# https://www.kaggle.com/c/jigsaw-toxic-comment-classification-challenge
download_data("https://pl-flash-data.s3.amazonaws.com/jigsaw_toxic_comments.zip", "./data")
datamodule = TextClassificationData.from_csv(
"comment_text",
["toxic", "severe_toxic", "obscene", "threat", "insult", "identity_hate"],
train_file="data/jigsaw_toxic_comments/train.csv",
val_split=0.1,
backbone="unitary/toxic-bert",
)
# 2. Build the task
model = TextClassifier(
backbone="unitary/toxic-bert",
num_classes=datamodule.num_classes,
multi_label=True,
)
# 3. Create the trainer and finetune the model
trainer = flash.Trainer(max_epochs=1, gpus=torch.cuda.device_count())
trainer.finetune(model, datamodule=datamodule, strategy="freeze")
# 4. Generate predictions for a few comments!
predictions = model.predict(
[
"No, he is an arrogant, self serving, immature idiot. Get it right.",
"U SUCK HANNAH MONTANA",
"Would you care to vote? Thx.",
]
)
print(predictions)
# 5. Save the model!
trainer.save_checkpoint("text_classification_multi_label_model.pt")
Flash Zero¶
The multi-label text classifier can be used directly from the command line with zero code using Flash Zero. You can run the above example with:
flash text_classification from_toxic
To view configuration options and options for running the text classifier with your own data, use:
flash text_classification --help
Serving¶
The TextClassifier
is servable.
For more information, see Text Classification.
Question Answering¶
The Task¶
Question Answering is the task of answering questions about some known context. For example, given a context about a historical figure, any question pertaining to that context should be answerable. In our case the context and the question are the input, and the answer is the output sequence from the model.
Note
We currently only support Extractive Question Answering, the task performed on SQUAD-like datasets.
Example¶
Let’s look at an example.
We’ll use the SQUAD 2.0 dataset, which contains train-v2.0.json
and dev-v2.0.json
.
Each JSON file looks like this:
{
"answers": {
"answer_start": [94, 87, 94, 94],
"text": ["10th and 11th centuries", "in the 10th and 11th centuries", "10th and 11th centuries", "10th and 11th centuries"]
},
"context": "\"The Normans (Norman: Nourmands; French: Normands; Latin: Normanni) were the people who in the 10th and 11th centuries gave thei...",
"id": "56ddde6b9a695914005b9629",
"question": "When were the Normans in Normandy?",
"title": "Normans"
}
...
In the above, the context
key represents the context used for the question and answer, the question
key represents the question being asked with respect to the context, and the answers
key stores the answer(s) for the question.
id
and title
are used for unique identification and grouping concepts together respectively.
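In extractive question answering, every answer is a span of the context, and each answer_start value is the character offset where that span begins. The sketch below checks this on the sample record. It assumes the offsets refer to the context without the leading quote character shown in the sample, and it uses the context only up to the point where the source truncates it.

```python
# Context and answers copied from the sample record above (context truncated as in the source,
# and without the leading quote character).
context = (
    "The Normans (Norman: Nourmands; French: Normands; Latin: Normanni) "
    "were the people who in the 10th and 11th centuries gave thei"
)
answers = {
    "answer_start": [94, 87, 94, 94],
    "text": [
        "10th and 11th centuries",
        "in the 10th and 11th centuries",
        "10th and 11th centuries",
        "10th and 11th centuries",
    ],
}

# Each answer is the slice of the context that begins at its start offset.
spans = [
    context[start : start + len(text)]
    for start, text in zip(answers["answer_start"], answers["text"])
]
print(spans == answers["text"])
```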
Once we’ve downloaded the data using download_data()
, we create the QuestionAnsweringData
.
We select a pre-trained backbone to use for our QuestionAnsweringTask
and finetune on the SQUAD 2.0 data.
The backbone can be any Question Answering model from HuggingFace/transformers.
Note
When changing the backbone, make sure you pass in the same backbone to the QuestionAnsweringData
and the QuestionAnsweringTask
!
Next, we use the trained QuestionAnsweringTask
for inference.
Finally, we save the model.
Here’s the full example:
from flash import Trainer
from flash.core.data.utils import download_data
from flash.text import QuestionAnsweringData, QuestionAnsweringTask
# 1. Create the DataModule
download_data("https://pl-flash-data.s3.amazonaws.com/squad_tiny.zip", "./data/")
datamodule = QuestionAnsweringData.from_squad_v2(
train_file="./data/squad_tiny/train.json",
val_file="./data/squad_tiny/val.json",
)
# 2. Build the task
model = QuestionAnsweringTask()
# 3. Create the trainer and finetune the model
trainer = Trainer(max_epochs=3, limit_train_batches=1, limit_val_batches=1)
trainer.finetune(model, datamodule=datamodule)
# 4. Answer some Questions!
predictions = model.predict(
{
"id": ["56ddde6b9a695914005b9629", "56ddde6b9a695914005b9628"],
"context": [
"""
The Normans (Norman: Nourmands; French: Normands; Latin: Normanni) were the people who in the 10th
and 11th centuries gave their name to Normandy, a region in France. They were descended from Norse
("Norman" comes from "Norseman") raiders and pirates from Denmark, Iceland and Norway who, under
their leader Rollo, agreed to swear fealty to King Charles III of West Francia. Through generations
of assimilation and mixing with the native Frankish and Roman-Gaulish populations, their
descendants would gradually merge with the Carolingian-based cultures of West Francia. The distinct
cultural and ethnic identity of the Normans emerged initially in the first half of the 10th
century, and it continued to evolve over the succeeding centuries.
""",
"""
The Normans (Norman: Nourmands; French: Normands; Latin: Normanni) were the people who in the 10th
and 11th centuries gave their name to Normandy, a region in France. They were descended from Norse
("Norman" comes from "Norseman") raiders and pirates from Denmark, Iceland and Norway who, under
their leader Rollo, agreed to swear fealty to King Charles III of West Francia. Through generations
of assimilation and mixing with the native Frankish and Roman-Gaulish populations, their
descendants would gradually merge with the Carolingian-based cultures of West Francia. The distinct
cultural and ethnic identity of the Normans emerged initially in the first half of the 10th
century, and it continued to evolve over the succeeding centuries.
""",
],
"question": ["When were the Normans in Normandy?", "In what country is Normandy located?"],
}
)
print(predictions)
# 5. Save the model!
trainer.save_checkpoint("question_answering_on_squad_v2.pt")
Accelerate Training & Inference with Torch ORT¶
Torch ORT converts your model into an optimized ONNX graph, speeding up training & inference when using NVIDIA or AMD GPUs. Enabling Torch ORT requires a single flag passed to the QuestionAnsweringTask
once installed. See installation instructions here.
Note
Not all Transformer models are supported. See this table for supported models + branches containing fixes for certain models.
...
model = QuestionAnsweringTask(backbone="distilbert-base-uncased", max_answer_length=30, enable_ort=True)
Summarization¶
The Task¶
Summarization is the task of summarizing text from a larger document/article into a short sentence/description. For example, taking a web article and describing the topic in a short sentence. This task is a subset of Sequence to Sequence tasks, which require the model to generate a variable length sequence given an input sequence. In our case the article would be our input sequence, and the short description/sentence would be the output sequence from the model.
Example¶
Let’s look at an example.
We’ll use the XSUM dataset, which contains a train.csv
and valid.csv
.
Each CSV file looks like this:
input,target
"The researchers have sequenced the genome of a strain of bacterium that causes the virulent infection...","A team of UK scientists hopes to shed light on the mysteries of bleeding canker, a disease that is threatening the nation's horse chestnut trees."
"Knight was shot in the leg by an unknown gunman at Miami's Shore Club where West was holding a pre-MTV Awards...",Hip hop star Kanye West is being sued by Death Row Records founder Suge Knight over a shooting at a beach party in August 2005.
...
In the above, the input column holds the long articles/documents, and the target column holds the short description the model should learn to produce.
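The sample rows mix quoted and unquoted fields, and article text may contain doubled quotes as CSV escapes. A short sketch with the standard csv module shows how such rows read back as (input, target) pairs. The two rows are shortened, partly invented placeholders in the same layout, not lines from the XSUM files.

```python
import csv
import io

# Two shortened rows in the input,target layout; the second leaves its target unquoted,
# which is valid CSV as long as the field contains no comma or quote.
SAMPLE = '''input,target
"Greengrocer Derek Chong said Camilla had been ""nice and pleasant"", he added.","Royal couple visit a market."
"Knight was shot in the leg by an unknown gunman.",Hip hop star is being sued over a shooting.
'''

pairs = [(row["input"], row["target"]) for row in csv.DictReader(io.StringIO(SAMPLE))]
for article, summary in pairs:
    print(len(article), "->", summary)
```

The doubled quotes around "nice and pleasant" come back as single quote characters once the row is parsed.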
Once we’ve downloaded the data using download_data()
, we create the SummarizationData
.
We select a pre-trained backbone to use for our SummarizationTask
and finetune on the XSUM data.
The backbone can be any Seq2Seq summarization model from HuggingFace/transformers.
Note
When changing the backbone, make sure you pass in the same backbone to the SummarizationData
and the SummarizationTask
!
Next, we use the trained SummarizationTask
for inference.
Finally, we save the model.
Here’s the full example:
from flash import Trainer
from flash.core.data.utils import download_data
from flash.text import SummarizationData, SummarizationTask
# 1. Create the DataModule
download_data("https://pl-flash-data.s3.amazonaws.com/xsum.zip", "./data/")
datamodule = SummarizationData.from_csv(
"input",
"target",
train_file="data/xsum/train.csv",
val_file="data/xsum/valid.csv",
)
# 2. Build the task
model = SummarizationTask()
# 3. Create the trainer and finetune the model
trainer = Trainer(max_epochs=3)
trainer.finetune(model, datamodule=datamodule)
# 4. Summarize some text!
predictions = model.predict(
"""
Camilla bought a box of mangoes with a Brixton £10 note, introduced last year to try to keep the money of local
people within the community.The couple were surrounded by shoppers as they walked along Electric Avenue.
They came to Brixton to see work which has started to revitalise the borough.
It was Charles' first visit to the area since 1996, when he was accompanied by the former
South African president Nelson Mandela.Greengrocer Derek Chong, who has run a stall on Electric Avenue
for 20 years, said Camilla had been ""nice and pleasant"" when she purchased the fruit.
""She asked me what was nice, what would I recommend, and I said we've got some nice mangoes.
She asked me were they ripe and I said yes - they're from the Dominican Republic.""
Mr Chong is one of 170 local retailers who accept the Brixton Pound.
Customers exchange traditional pound coins for Brixton Pounds and then spend them at the market
or in participating shops.
During the visit, Prince Charles spent time talking to youth worker Marcus West, who works with children
nearby on an estate off Coldharbour Lane. Mr West said:
""He's on the level, really down-to-earth. They were very cheery. The prince is a lovely man.""
He added: ""I told him I was working with young kids and he said, 'Keep up all the good work.'""
Prince Charles also visited the Railway Hotel, at the invitation of his charity The Prince's Regeneration Trust.
The trust hopes to restore and refurbish the building,
where once Jimi Hendrix and The Clash played, as a new community and business centre."
"""
)
print(predictions)
# 5. Save the model!
trainer.save_checkpoint("summarization_model_xsum.pt")
Flash Zero¶
The summarization task can be used directly from the command line with zero code using Flash Zero. You can run the above example with:
flash summarization
To view configuration options and options for running the summarization task with your own data, use:
flash summarization --help
Serving¶
The SummarizationTask
is servable.
This means you can call .serve
to serve your Task
.
Here’s an example:
from flash.text import SummarizationTask
model = SummarizationTask.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/summarization_model_xsum.pt")
model.serve()
You can now perform inference from your client like this:
import requests
text = """
Camilla bought a box of mangoes with a Brixton £10 note, introduced last year to try to keep the money of local
people within the community.The couple were surrounded by shoppers as they walked along Electric Avenue.
They came to Brixton to see work which has started to revitalise the borough.
It was Charles' first visit to the area since 1996, when he was accompanied by the former
South African president Nelson Mandela.Greengrocer Derek Chong, who has run a stall on Electric Avenue
for 20 years, said Camilla had been ""nice and pleasant"" when she purchased the fruit.
""She asked me what was nice, what would I recommend, and I said we've got some nice mangoes.
She asked me were they ripe and I said yes - they're from the Dominican Republic.""
Mr Chong is one of 170 local retailers who accept the Brixton Pound.
Customers exchange traditional pound coins for Brixton Pounds and then spend them at the market
or in participating shops.
During the visit, Prince Charles spent time talking to youth worker Marcus West, who works with children
nearby on an estate off Coldharbour Lane. Mr West said:
""He's on the level, really down-to-earth. They were very cheery. The prince is a lovely man.""
He added: ""I told him I was working with young kids and he said, 'Keep up all the good work.'""
Prince Charles also visited the Railway Hotel, at the invitation of his charity The Prince's Regeneration Trust.
The trust hopes to restore and refurbish the building,
where once Jimi Hendrix and The Clash played, as a new community and business centre."
"""
body = {"session": "UUID", "payload": {"inputs": {"data": text}}}
resp = requests.post("http://127.0.0.1:8000/predict", json=body)
print(resp.json())
Accelerate Training & Inference with Torch ORT¶
Torch ORT converts your model into an optimized ONNX graph, speeding up training & inference when using NVIDIA or AMD GPUs. Enabling Torch ORT requires a single flag passed to the SummarizationTask
once installed. See installation instructions here.
Note
Not all Transformer models are supported. See this table for supported models + branches containing fixes for certain models.
...
model = SummarizationTask(backbone="t5-large", enable_ort=True)
Translation¶
The Task¶
Translation is the task of translating text from a source language into a target language, such as English to Romanian. This task is a subset of Sequence to Sequence tasks, which require the model to generate a variable length sequence given an input sequence. In our case, the task will take an English sequence as input and output the same sequence in Romanian.
Example¶
Let’s look at an example.
We’ll use WMT16 English/Romanian, a dataset of English to Romanian samples, based on the Europarl corpora.
The data set contains a train.csv
and valid.csv
.
Each CSV file looks like this:
input,target
"Written statements and oral questions (tabling): see Minutes","Declaraţii scrise şi întrebări orale (depunere): consultaţi procesul-verbal"
"Closure of sitting","Ridicarea şedinţei"
...
In the above, the input and target columns hold the English text and its Romanian translation, respectively.
Once we’ve downloaded the data using download_data()
, we create the TranslationData
.
We select a pre-trained backbone to use for our TranslationTask
and finetune on the WMT16 data.
The backbone can be any Seq2Seq translation model from HuggingFace/transformers.
Note
When changing the backbone, make sure you pass in the same backbone to the TranslationData
and the TranslationTask
!
Next, we use the trained TranslationTask
for inference.
Finally, we save the model.
Here’s the full example:
import torch
import flash
from flash.core.data.utils import download_data
from flash.text import TranslationData, TranslationTask
# 1. Create the DataModule
download_data("https://pl-flash-data.s3.amazonaws.com/wmt_en_ro.zip", "./data")
datamodule = TranslationData.from_csv(
"input",
"target",
train_file="data/wmt_en_ro/train.csv",
val_file="data/wmt_en_ro/valid.csv",
backbone="Helsinki-NLP/opus-mt-en-ro",
)
# 2. Build the task
model = TranslationTask(backbone="Helsinki-NLP/opus-mt-en-ro")
# 3. Create the trainer and finetune the model
trainer = flash.Trainer(max_epochs=3, gpus=torch.cuda.device_count())
trainer.finetune(model, datamodule=datamodule)
# 4. Translate something!
predictions = model.predict(
[
"BBC News went to meet one of the project's first graduates.",
"A recession has come as quickly as 11 months after the first rate hike and as long as 86 months.",
"Of course, it's still early in the election cycle.",
]
)
print(predictions)
# 5. Save the model!
trainer.save_checkpoint("translation_model_en_ro.pt")
Flash Zero¶
The translation task can be used directly from the command line with zero code using Flash Zero. You can run the above example with:
flash translation
To view configuration options and options for running the translation task with your own data, use:
flash translation --help
Serving¶
The TranslationTask
is servable.
This means you can call .serve
to serve your Task
.
Here’s an example:
from flash.text import TranslationTask
model = TranslationTask.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/translation_model_en_ro.pt")
model.serve()
You can now perform inference from your client like this:
import requests
text = "Some English text"
body = {"session": "UUID", "payload": {"inputs": {"data": text}}}
resp = requests.post("http://127.0.0.1:8000/predict", json=body)
print(resp.json())
Accelerate Training & Inference with Torch ORT¶
Torch ORT converts your model into an optimized ONNX graph, speeding up training & inference when using NVIDIA or AMD GPUs. Enabling Torch ORT requires a single flag passed to the TranslationTask
once installed. See installation instructions here.
Note
Not all Transformer models are supported. See this table for supported models + branches containing fixes for certain models.
...
model = TranslationTask(backbone="t5-large", enable_ort=True)
Point Cloud Segmentation¶
The Task¶
A Point Cloud is a set of data points in space, usually described by x
, y
and z
coordinates.
PointCloud Segmentation is the task of performing classification at a point-level, meaning each point will be associated with a given class. The current integration builds on top of Open3D-ML.
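As a rough sketch of what "one class per point" means at the file level, the snippet below packs and unpacks a few made-up points. It assumes the binary layout documented for KITTI scans (four float32 values per point: x, y, z, intensity) and for SemanticKITTI labels (one uint32 per point, class id in the low 16 bits). The values are invented, and Flash and Open3D-ML handle this decoding for you.

```python
import struct

# Three made-up points as (x, y, z, intensity); KITTI scans store four float32 values per point.
points = [(1.0, 2.0, 0.5, 0.25), (4.0, -1.0, 0.0, 0.5), (0.0, 0.0, 1.5, 1.0)]
scan_bytes = b"".join(struct.pack("<4f", *p) for p in points)

# SemanticKITTI labels are one uint32 per point: low 16 bits class id, high 16 bits instance id.
raw_labels = [(7 << 16) | 10, (0 << 16) | 40, (3 << 16) | 10]
label_bytes = struct.pack("<3I", *raw_labels)

decoded_points = [struct.unpack_from("<4f", scan_bytes, i * 16) for i in range(len(scan_bytes) // 16)]
decoded_labels = struct.unpack("<%dI" % (len(label_bytes) // 4), label_bytes)
class_ids = [value & 0xFFFF for value in decoded_labels]

# Segmentation assigns exactly one class per point.
assert len(decoded_points) == len(class_ids)
print(class_ids)
```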
Example¶
Let’s look at an example using a data set generated from the KITTI Vision Benchmark.
The data is a tiny subset of the original dataset and contains sequences of point clouds.
The data contains multiple folders, one for each sequence, and a meta.yaml file describing the classes and their official associated color map.
A sequence should contain one folder for scans and one folder for labels, plus a pose.txt
to re-align the sequence if required.
Here’s the structure:
data
├── meta.yaml
├── 00
│ ├── scans
| | ├── 00000.bin
| | ├── 00001.bin
| | ...
│ ├── labels
| | ├── 00000.label
| | ├── 00001.label
| | ...
| ├── pose.txt
│ ...
|
└── XX
├── scans
| ├── 00000.bin
| ├── 00001.bin
| ...
├── labels
| ├── 00000.label
| ├── 00001.label
| ...
├── pose.txt
Learn more: http://www.semantic-kitti.org/dataset.html
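Before training on your own data, it can help to confirm that each sequence folder follows this layout. The check below is a minimal sketch using pathlib; the function name unmatched_scans is hypothetical, and it only verifies that every scan has a label file with the same stem.

```python
import tempfile
from pathlib import Path


def unmatched_scans(sequence_dir):
    """Return scan stems in one sequence folder that have no label file."""
    seq = Path(sequence_dir)
    scans = {p.stem for p in (seq / "scans").glob("*.bin")}
    labels = {p.stem for p in (seq / "labels").glob("*.label")}
    return sorted(scans - labels)


# Build a throwaway sequence folder to exercise the check.
with tempfile.TemporaryDirectory() as root:
    seq = Path(root) / "00"
    (seq / "scans").mkdir(parents=True)
    (seq / "labels").mkdir()
    for stem in ("00000", "00001"):
        (seq / "scans" / f"{stem}.bin").touch()
    (seq / "labels" / "00000.label").touch()
    missing = unmatched_scans(seq)

print(missing)
```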
Once we’ve downloaded the data using download_data()
, we create the PointCloudSegmentationData
.
We select a pre-trained randlanet_semantic_kitti
backbone for our PointCloudSegmentation
task.
We then use the trained PointCloudSegmentation
for inference.
Finally, we save the model.
Here’s the full example:
import torch
import flash
from flash.core.data.utils import download_data
from flash.pointcloud import PointCloudSegmentation, PointCloudSegmentationData
# 1. Create the DataModule
# Dataset Credit: http://www.semantic-kitti.org/
download_data("https://pl-flash-data.s3.amazonaws.com/SemanticKittiTiny.zip", "data/")
datamodule = PointCloudSegmentationData.from_folders(
train_folder="data/SemanticKittiTiny/train",
val_folder="data/SemanticKittiTiny/val",
)
# 2. Build the task
model = PointCloudSegmentation(backbone="randlanet_semantic_kitti", num_classes=datamodule.num_classes)
# 3. Create the trainer and finetune the model
trainer = flash.Trainer(
max_epochs=1, limit_train_batches=1, limit_val_batches=1, num_sanity_val_steps=0, gpus=torch.cuda.device_count()
)
trainer.fit(model, datamodule)
# 4. Predict what's within a few PointClouds?
predictions = model.predict(
[
"data/SemanticKittiTiny/predict/000000.bin",
"data/SemanticKittiTiny/predict/000001.bin",
]
)
# 5. Save the model!
trainer.save_checkpoint("pointcloud_segmentation_model.pt")
Flash Zero¶
The point cloud segmentation task can be used directly from the command line with zero code using Flash Zero. You can run the above example with:
flash pointcloud_segmentation
To view configuration options and options for running the point cloud segmentation task with your own data, use:
flash pointcloud_segmentation --help
Point Cloud Object Detection¶
The Task¶
A Point Cloud is a set of data points in space, usually described by x
, y
and z
coordinates.
PointCloud Object Detection is the task of identifying 3D objects in point clouds and their associated classes and 3D bounding boxes.
The current integration builds on top of Open3D-ML.
Example¶
Let’s look at an example using a data set generated from the KITTI Vision Benchmark. The data is a tiny subset of the original dataset and contains sequences of point clouds.
The data contains:
one folder for scans
one folder for scan calibrations
one folder for labels
a meta.yaml file describing the classes and their official associated color map.
Here’s the structure:
data
├── meta.yaml
├── train
│ ├── scans
| | ├── 00000.bin
| | ├── 00001.bin
| | ...
│ ├── calibs
| | ├── 00000.txt
| | ├── 00001.txt
| | ...
│ ├── labels
| | ├── 00000.txt
| | ├── 00001.txt
│ ...
├── val
│ ...
├── predict
├── scans
| ├── 00000.bin
| ├── 00001.bin
|
├── calibs
| ├── 00000.txt
| ├── 00001.txt
├── meta.yaml
Learn more: http://www.semantic-kitti.org/dataset.html
Once we’ve downloaded the data using download_data()
, we create the PointCloudObjectDetectorData
.
We select a pre-trained pointpillars_kitti
backbone for our PointCloudObjectDetector
task.
We then use the trained PointCloudObjectDetector
for inference.
Finally, we save the model.
Here’s the full example:
import torch
import flash
from flash.core.data.utils import download_data
from flash.pointcloud import PointCloudObjectDetector, PointCloudObjectDetectorData
# 1. Create the DataModule
# Dataset Credit: http://www.semantic-kitti.org/
download_data("https://pl-flash-data.s3.amazonaws.com/KITTI_tiny.zip", "data/")
datamodule = PointCloudObjectDetectorData.from_folders(
train_folder="data/KITTI_Tiny/Kitti/train",
val_folder="data/KITTI_Tiny/Kitti/val",
)
# 2. Build the task
model = PointCloudObjectDetector(backbone="pointpillars_kitti", num_classes=datamodule.num_classes)
# 3. Create the trainer and finetune the model
trainer = flash.Trainer(
max_epochs=1, limit_train_batches=1, limit_val_batches=1, num_sanity_val_steps=0, gpus=torch.cuda.device_count()
)
trainer.fit(model, datamodule)
# 4. Predict what's within a few PointClouds?
predictions = model.predict(
[
"data/KITTI_Tiny/Kitti/predict/scans/000000.bin",
"data/KITTI_Tiny/Kitti/predict/scans/000001.bin",
]
)
# 5. Save the model!
trainer.save_checkpoint("pointcloud_detection_model.pt")
Flash Zero¶
The point cloud object detector can be used directly from the command line with zero code using Flash Zero. You can run the above example with:
flash pointcloud_detection
To view configuration options and options for running the point cloud object detector with your own data, use:
flash pointcloud_detection --help
Graph Classification¶
The Task¶
This task consists of classifying graphs. The task predicts which ‘class’ a graph belongs to. A class is a label that indicates the kind of graph. For example, a label may indicate whether one molecule interacts with another.
The GraphClassifier
and GraphClassificationData
classes internally rely on pytorch-geometric.
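The key point is that the label belongs to the whole graph, not to individual nodes. The sketch below uses a hypothetical TinyGraph container with an edge list in the two-row [sources, targets] layout that pytorch-geometric uses, and mean pooling as one common way to turn per-node features into a single graph-level vector. It is an illustration only and does not describe how GraphClassifier pools features internally.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class TinyGraph:
    """Hypothetical stand-in for one sample: node features, edges, and one graph-level label."""

    node_features: List[List[float]]  # one feature vector per node
    edge_index: List[List[int]]  # [sources, targets], the COO layout used by pytorch-geometric
    label: int  # a single class for the whole graph


def mean_readout(graph):
    """Collapse per-node features into one fixed-size vector for the whole graph."""
    n = len(graph.node_features)
    dims = len(graph.node_features[0])
    return [sum(node[d] for node in graph.node_features) / n for d in range(dims)]


triangle = TinyGraph(
    node_features=[[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]],
    edge_index=[[0, 1, 2], [1, 2, 0]],
    label=1,
)
print(mean_readout(triangle))
```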
Example¶
Let’s look at the task of classifying graphs from the KKI data set from TU Dortmund University.
Once we’ve created the TUDataset, we create the GraphClassificationData
.
We then create our GraphClassifier
and train on the KKI data.
Next, we use the trained GraphClassifier
for inference.
Finally, we save the model.
Here’s the full example:
import torch
import flash
from flash.core.utilities.imports import example_requires
from flash.graph import GraphClassificationData, GraphClassifier
example_requires("graph")
from torch_geometric.datasets import TUDataset # noqa: E402
# 1. Create the DataModule
dataset = TUDataset(root="data", name="KKI")
datamodule = GraphClassificationData.from_datasets(
train_dataset=dataset,
val_split=0.1,
)
# 2. Build the task
model = GraphClassifier(num_features=datamodule.num_features, num_classes=datamodule.num_classes)
# 3. Create the trainer and fit the model
trainer = flash.Trainer(max_epochs=3, gpus=torch.cuda.device_count())
trainer.fit(model, datamodule=datamodule)
# 4. Classify some graphs!
predictions = model.predict(dataset[:3])
print(predictions)
# 5. Save the model!
trainer.save_checkpoint("graph_classification.pt")
Flash Zero¶
The graph classifier can be used directly from the command line with zero code using Flash Zero. You can run the above example with:
flash graph_classification
To view configuration options and options for running the graph classifier with your own data, use:
flash graph_classification --help
Providers¶
Flash is a framework integrator. We rely on many open source frameworks for our tasks, visualizations and backbones. Here’s a list of some of the providers we use for backbones and heads within Flash (check them out and star their repos to spread the open source love!):
airctic/IceVision (https://github.com/airctic/icevision)
Facebook Research/dino (https://github.com/facebookresearch/dino)
Facebook Research/PyTorchVideo (https://github.com/facebookresearch/pytorchvideo)
Facebook Research/vissl (https://github.com/facebookresearch/vissl)
Hugging Face/transformers (https://github.com/huggingface/transformers)
Intelligent Systems Lab Org/Open3D-ML (https://github.com/isl-org/Open3D-ML)
learnables/learn2learn (https://github.com/learnables/learn2learn)
OpenMMLab/MMDetection (https://github.com/open-mmlab/mmdetection)
pystiche/pystiche (https://github.com/pystiche/pystiche)
PyTorch/fairseq (https://github.com/pytorch/fairseq)
PyTorch/torchvision (https://github.com/pytorch/vision)
qubvel/segmentation_models.pytorch (https://github.com/qubvel/segmentation_models.pytorch)
rwightman/efficientdet-pytorch (https://github.com/rwightman/efficientdet-pytorch)
rwightman/pytorch-image-models (https://github.com/rwightman/pytorch-image-models)
Ultralytics/YOLOV5 (https://github.com/ultralytics/yolov5)
You can also read our guides for some of our larger integrations:
BaaL¶
Bayesian Active Learning (BaaL) is an active learning library developed at ElementAI.
Active Learning is a sub-field in AI, focusing on adding a human in the learning loop. The most uncertain samples will be labelled by the human to accelerate the model training cycle.
Credit to ElementAI / Baal Team for creating this diagram flow
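The idea of picking "the most uncertain samples" can be sketched with a simple entropy ranking over predicted class probabilities. This is a generic illustration with made-up probabilities, not BaaL's implementation; BaaL provides its own heuristics for this step.

```python
import math


def entropy(probs):
    """Shannon entropy of one predicted class distribution (higher means less certain)."""
    return -sum(p * math.log(p) for p in probs if p > 0)


def most_uncertain(pool_probs, k):
    """Indices of the k unlabelled samples whose predictions have the highest entropy."""
    ranked = sorted(range(len(pool_probs)), key=lambda i: entropy(pool_probs[i]), reverse=True)
    return ranked[:k]


# Made-up class probabilities for four unlabelled images in a two-class problem.
pool = [[0.99, 0.01], [0.55, 0.45], [0.80, 0.20], [0.50, 0.50]]
to_label = most_uncertain(pool, k=2)
print(to_label)
```

The selected indices are the samples a human would be asked to label next; confident predictions such as the first one stay in the unlabelled pool.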
With its integration within Flash, the Active Learning process is simpler than ever before.
import torch
import flash
from flash.core.classification import Probabilities
from flash.core.data.utils import download_data
from flash.image import ImageClassificationData, ImageClassifier
from flash.image.classification.integrations.baal import ActiveLearningDataModule, ActiveLearningLoop
# 1. Create the DataModule
download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip", "./data")
# Implement the research use-case where we mask labels from labelled dataset.
datamodule = ActiveLearningDataModule(
ImageClassificationData.from_folders(train_folder="data/hymenoptera_data/train/", batch_size=2),
initial_num_labels=5,
val_split=0.1,
)
# 2. Build the task
head = torch.nn.Sequential(
torch.nn.Dropout(p=0.1),
torch.nn.Linear(512, datamodule.num_classes),
)
model = ImageClassifier(backbone="resnet18", head=head, num_classes=datamodule.num_classes, serializer=Probabilities())
# 3.1 Create the trainer
trainer = flash.Trainer(max_epochs=3)
# 3.2 Create the active learning loop and connect it to the trainer
active_learning_loop = ActiveLearningLoop(label_epoch_frequency=1)
active_learning_loop.connect(trainer.fit_loop)
trainer.fit_loop = active_learning_loop
# 3.3 Finetune
trainer.finetune(model, datamodule=datamodule, strategy="freeze")
# 4. Predict what's on a few images! ants or bees?
predictions = model.predict("data/hymenoptera_data/val/bees/65038344_52a45d090d.jpg")
print(predictions)
# 5. Save the model!
trainer.save_checkpoint("image_classification_model.pt")
FiftyOne¶
We have collaborated with the team at Voxel51 to integrate their tool, FiftyOne, into Lightning Flash.
FiftyOne is an open-source tool for building high-quality datasets and computer vision models. The FiftyOne API and App enable you to visualize datasets and interpret models faster and more effectively.
This integration allows you to view predictions generated by your tasks in the FiftyOne App, as well as easily incorporate FiftyOne Datasets into your tasks. All image and video tasks are supported!
Installation¶
In order to utilize this integration, you will need to install FiftyOne:
pip install fiftyone
Visualizing Flash predictions¶
This section shows you how to augment your existing Lightning Flash workflows with a couple of lines of code that let you visualize predictions in FiftyOne. You can visualize predictions for classification, object detection, and semantic segmentation tasks. Doing so is as easy as updating your model to use one of the following serializers:
The visualize()
function then lets you visualize
your predictions in the
FiftyOne App. This function accepts a list of
dictionaries containing FiftyOne Label objects
and filepaths, which is exactly the output of the FiftyOne serializers when the
return_filepath=True
option is specified.
from itertools import chain
import torch
import flash
from flash.core.classification import FiftyOneLabels, Labels
from flash.core.data.utils import download_data
from flash.core.finetuning import FreezeUnfreeze
from flash.core.integrations.fiftyone import visualize
from flash.image import ImageClassificationData, ImageClassifier
# 1 Download data
download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip")
# 2 Load data
datamodule = ImageClassificationData.from_folders(
train_folder="data/hymenoptera_data/train/",
val_folder="data/hymenoptera_data/val/",
test_folder="data/hymenoptera_data/test/",
predict_folder="data/hymenoptera_data/predict/",
)
# 3 Fine tune a model
model = ImageClassifier(
backbone="resnet18",
num_classes=datamodule.num_classes,
serializer=Labels(),
)
trainer = flash.Trainer(
max_epochs=1,
gpus=torch.cuda.device_count(),
limit_train_batches=1,
limit_val_batches=1,
)
trainer.finetune(
model,
datamodule=datamodule,
strategy=FreezeUnfreeze(unfreeze_epoch=1),
)
trainer.save_checkpoint("image_classification_model.pt")
# 4 Predict from checkpoint
model = ImageClassifier.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/image_classification_model.pt")
model.serializer = FiftyOneLabels(return_filepath=True) # output FiftyOne format
predictions = trainer.predict(model, datamodule=datamodule)
predictions = list(chain.from_iterable(predictions)) # flatten batches
# 5 Visualize predictions in FiftyOne App
# Optional: pass `wait=True` to block execution until App is closed
session = visualize(predictions)
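The "flatten batches" step in the example above can be looked at in isolation. The following is a minimal sketch that uses plain dictionaries as stand-ins for the serializer output; the "filepath" and "predictions" key names and the sample values are assumptions made for illustration, not output copied from Flash or FiftyOne.

```python
from itertools import chain

# Hypothetical stand-in for what trainer.predict() returns: one list per batch,
# each holding one dictionary per sample. Key names and values are assumed for
# illustration only; real entries would hold FiftyOne Label objects.
batched_predictions = [
    [
        {"filepath": "predict/img_0.jpg", "predictions": "ants"},
        {"filepath": "predict/img_1.jpg", "predictions": "bees"},
    ],
    [
        {"filepath": "predict/img_2.jpg", "predictions": "ants"},
    ],
]

# chain.from_iterable concatenates the per-batch lists into one flat list,
# i.e. a single list of dictionaries with one entry per sample.
flat_predictions = list(chain.from_iterable(batched_predictions))

print(len(flat_predictions))
print(flat_predictions[2]["filepath"])
```

Without the flattening step, the outer list would contain batches rather than samples, which is not the "list of dictionaries" shape described above.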
The visualize() function can be used in all of the following environments:
Local Python shell: The App will launch in a new tab in your default web browser
Remote Python shell: Pass the remote=True option and then follow the instructions printed to your remote shell to open the App in your browser on your local machine
Jupyter notebook: The App will launch in the output of your current cell
Google Colab: The App will launch in the output of your current cell
Python script: Pass the wait=True option to block execution of your script until the App is closed
See this page for more information about using the FiftyOne App in different environments.
Using FiftyOne datasets¶
The above workflow is great for visualizing model predictions. However, if you store your data in a FiftyOne Dataset initially, then you can also visualize ground truth annotations. This allows you to perform more complex analysis with views into your data and evaluation of your model results.
The from_fiftyone() method allows you to load your FiftyOne datasets directly into a DataModule to be used for training, testing, or inference.
from itertools import chain
import fiftyone as fo
import torch
import flash
from flash.core.classification import FiftyOneLabels, Labels
from flash.core.data.utils import download_data
from flash.core.finetuning import FreezeUnfreeze
from flash.image import ImageClassificationData, ImageClassifier
# 1 Download data
download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip")
# 2 Load data into FiftyOne
train_dataset = fo.Dataset.from_dir(
dataset_dir="data/hymenoptera_data/train/",
dataset_type=fo.types.ImageClassificationDirectoryTree,
)
val_dataset = fo.Dataset.from_dir(
dataset_dir="data/hymenoptera_data/val/",
dataset_type=fo.types.ImageClassificationDirectoryTree,
)
test_dataset = fo.Dataset.from_dir(
dataset_dir="data/hymenoptera_data/test/",
dataset_type=fo.types.ImageClassificationDirectoryTree,
)
# 3 Load data into Flash
datamodule = ImageClassificationData.from_fiftyone(
train_dataset=train_dataset,
val_dataset=val_dataset,
test_dataset=test_dataset,
)
# 4 Fine tune model
model = ImageClassifier(
backbone="resnet18",
num_classes=datamodule.num_classes,
serializer=Labels(),
)
trainer = flash.Trainer(
max_epochs=1,
gpus=torch.cuda.device_count(),
limit_train_batches=1,
limit_val_batches=1,
)
trainer.finetune(
model,
datamodule=datamodule,
strategy=FreezeUnfreeze(unfreeze_epoch=1),
)
trainer.save_checkpoint("image_classification_model.pt")
# 5 Predict from checkpoint on data with ground truth
model = ImageClassifier.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/image_classification_model.pt")
model.serializer = FiftyOneLabels(return_filepath=False) # output FiftyOne format
datamodule = ImageClassificationData.from_fiftyone(predict_dataset=test_dataset)
predictions = trainer.predict(model, datamodule=datamodule)
predictions = list(chain.from_iterable(predictions)) # flatten batches
# 6 Add predictions to dataset
test_dataset.set_values("predictions", predictions)
# 7 Evaluate your model
results = test_dataset.evaluate_classifications("predictions", gt_field="ground_truth", eval_key="eval")
results.print_report()
plot = results.plot_confusion_matrix()
plot.show()
# 8 Visualize results in the App
session = fo.launch_app(test_dataset)
# Optional: block execution until App is closed
session.wait()
Visualizing embeddings¶
FiftyOne provides methods for dimensionality reduction and interactive plotting. When combined with embedding tasks in Flash, you can accomplish powerful workflows like clustering, similarity search, pre-annotation, and more in only a few lines of code.
import fiftyone as fo
import fiftyone.brain as fob
import numpy as np
from flash.core.data.utils import download_data
from flash.image import ImageEmbedder
# 1 Download data
download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip")
# 2 Load data into FiftyOne
dataset = fo.Dataset.from_dir(
"data/hymenoptera_data/test/",
fo.types.ImageClassificationDirectoryTree,
)
# 3 Load model
embedder = ImageEmbedder(backbone="resnet101")
# 4 Generate embeddings
filepaths = dataset.values("filepath")
embeddings = np.stack(embedder.predict(filepaths))
# 5 Visualize in FiftyOne App
results = fob.compute_visualization(dataset, embeddings=embeddings)
session = fo.launch_app(dataset)
plot = results.visualize(labels="ground_truth.label")
plot.show()
# Optional: block execution until App is closed
session.wait()
IceVision¶
IceVision from airctic is an awesome computer vision framework which offers a curated collection of hundreds of high-quality pre-trained models for object detection, keypoint detection, and instance segmentation. In Flash, we’ve integrated the IceVision framework to provide data loading, augmentation, backbones, and heads. We use IceVision components in our object detection, instance segmentation, and keypoint detection tasks. Take a look at their documentation and star IceVision on GitHub to spread the open source love!
IceData¶
The IceData library is a community driven dataset hub for IceVision.
All of the datasets in IceData can be used out of the box with Flash using our .from_folders methods and the parser argument.
Take a look at our Keypoint Detection page for an example.
Albumentations with IceVision and Flash¶
IceVision provides two utilities for using the albumentations library with their models:
- the Adapter helper class for adapting any albumentations transform to work with IceVision records,
- the aug_tfms utility function that returns a standard augmentation recipe to get the most out of your model.
In Flash, we use the aug_tfms as default transforms for the object detection, instance segmentation, and keypoint detection tasks.
You can also provide custom transforms from albumentations using the IceVisionTransformAdapter (which relies on the IceVision Adapter underneath).
Here’s an example:
import albumentations as A
from flash.core.integrations.icevision.transforms import IceVisionTransformAdapter
from flash.image import ObjectDetectionData
train_transform = {
"pre_tensor_transform": IceVisionTransformAdapter([A.HorizontalFlip(), A.Normalize()]),
}
datamodule = ObjectDetectionData.from_coco(
...,
train_transform=train_transform,
)
Learn2Learn¶
Learn2Learn is a software library for meta-learning research by Sébastien M. R. Arnold et al. (Aug 2020).
What is Meta-Learning and why you should care?¶
Humans can distinguish between new objects with little or no training data. However, machine learning models often require thousands, millions, or even billions of annotated data samples to achieve good performance when extrapolating their learned knowledge to unseen objects.
A machine learning model which could learn, or learn to learn, from only a few new samples (K-shot learning) would have tremendous applications once deployed in production. In an extreme case, a model performing 1-shot or 0-shot learning could be the source of a new kind of AI application.
Meta-Learning is a sub-field of AI dedicated to the study of few-shot learning algorithms. This is often characterized as teaching deep learning models to learn with only a few labeled data samples. The goal is to repeatedly learn from K-shot examples during training that match the structure of the final K-shot examples used in production. It is important to note that the K-shot examples seen in production are very likely to be completely out-of-distribution, containing new objects.
How does Meta-Learning work?¶
In meta-learning, the model is trained over multiple meta tasks. A meta task is the smallest unit of data and it represents the data available to the model once in its deployment environment. Because the training tasks mirror that deployment setting, the model is optimised directly for it and tends to achieve better results.
For image classification, a meta task comprises shot + query elements for each class. The shot samples are used to adapt the parameters and the query samples to update the original model weights. The classes used in validation and testing shouldn’t be present within the training dataset, as the goal is to optimise the model performance on out-of-distribution (OOD) data with little labelled data.
When training the model with the meta-learning algorithm, the model will average its gradients over meta_batch_size meta tasks before performing an optimizer step. Traditionally, a meta epoch is composed of multiple meta batches.
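To make the vocabulary concrete, here is a minimal, framework-free sketch of sampling one meta task and averaging gradients over a meta batch. The function names, the toy label list, and the use of scalar "gradients" are hypothetical simplifications for illustration; this is not how Flash or Learn2Learn implement task sampling.

```python
import random
from collections import defaultdict


def sample_meta_task(labels, ways, shots, queries, rng):
    """Sample one meta task: `ways` classes with `shots` + `queries` indices each."""
    indices_by_class = defaultdict(list)
    for index, label in enumerate(labels):
        indices_by_class[label].append(index)

    task_classes = rng.sample(sorted(indices_by_class), ways)
    support, query = [], []
    for cls in task_classes:
        picked = rng.sample(indices_by_class[cls], shots + queries)
        support.extend(picked[:shots])  # shot samples: adapt the parameters
        query.extend(picked[shots:])    # query samples: update the original weights
    return support, query


def average_gradients(per_task_gradients):
    """Average one (scalar, toy) gradient per meta task before the optimizer step."""
    return sum(per_task_gradients) / len(per_task_gradients)


rng = random.Random(0)
labels = [i % 10 for i in range(200)]  # toy dataset: 10 classes, 20 samples each
support, query = sample_meta_task(labels, ways=5, shots=1, queries=15, rng=rng)
print(len(support), len(query))
print(average_gradients([1.0, 2.0, 3.0, 6.0]))
```

With 5 ways, 1 shot and 15 queries, each task holds 5 shot samples and 75 query samples, and the two sets never share an index.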
Use Meta-Learning with Flash¶
With its integration within Flash, Meta Learning has never been simpler. Flash takes care of all the hard work: task sampling, meta optimizer updates, distributed training, etc.
Note
Users are required to provide a training dataset and a testing dataset with no overlapping classes. Flash doesn’t create such a split for you out of the box.
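One way to build such a split yourself is to partition the class ids first and then filter samples by class. The sketch below is a hypothetical helper written with the standard library only; the function names are made up for illustration, and the 64 / 16 / 20 sizes are chosen to mirror the miniImageNet split.

```python
import random


def split_classes(class_ids, num_train, num_val, seed=0):
    """Shuffle class ids and cut them into disjoint train / val / test groups."""
    shuffled = sorted(class_ids)
    random.Random(seed).shuffle(shuffled)
    train = set(shuffled[:num_train])
    val = set(shuffled[num_train:num_train + num_val])
    test = set(shuffled[num_train + num_val:])
    return train, val, test


def select_indices(labels, allowed_classes):
    """Keep only the sample indices whose label belongs to `allowed_classes`."""
    return [i for i, label in enumerate(labels) if label in allowed_classes]


train_classes, val_classes, test_classes = split_classes(range(100), num_train=64, num_val=16)
print(len(train_classes), len(val_classes), len(test_classes))
```

Because the three groups are slices of one shuffled list, no class can appear in more than one of them.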
Once done, users are left to play with the hyper-parameters associated with the meta-learning algorithm.
Here is an example using the miniImageNet dataset, which contains 100 classes divided into 64 training, 16 validation, and 20 test classes.
# adapted from https://github.com/learnables/learn2learn/blob/master/examples/vision/protonet_miniimagenet.py#L154
import warnings
import kornia.augmentation as Ka
import kornia.geometry as Kg
import learn2learn as l2l
import torch
import torchvision
from torch import nn
import flash
from flash.core.data.data_source import DefaultDataKeys
from flash.core.data.transforms import ApplyToKeys, kornia_collate
from flash.image import ImageClassificationData, ImageClassifier
warnings.simplefilter("ignore")
# download MiniImagenet
train_dataset = l2l.vision.datasets.MiniImagenet(root="data", mode="train", download=True)
val_dataset = l2l.vision.datasets.MiniImagenet(root="data", mode="validation", download=True)
test_dataset = l2l.vision.datasets.MiniImagenet(root="data", mode="test", download=True)
train_transform = {
"to_tensor_transform": nn.Sequential(
ApplyToKeys(DefaultDataKeys.INPUT, torchvision.transforms.ToTensor()),
ApplyToKeys(DefaultDataKeys.TARGET, torch.as_tensor),
),
"post_tensor_transform": ApplyToKeys(
DefaultDataKeys.INPUT,
Kg.Resize((196, 196)),
# SPATIAL
Ka.RandomHorizontalFlip(p=0.25),
Ka.RandomRotation(degrees=90.0, p=0.25),
Ka.RandomAffine(degrees=1 * 5.0, shear=1 / 5, translate=1 / 20, p=0.25),
Ka.RandomPerspective(distortion_scale=1 / 25, p=0.25),
# PIXEL-LEVEL
Ka.ColorJitter(brightness=1 / 30, p=0.25), # brightness
Ka.ColorJitter(saturation=1 / 30, p=0.25), # saturation
Ka.ColorJitter(contrast=1 / 30, p=0.25), # contrast
Ka.ColorJitter(hue=1 / 30, p=0.25), # hue
Ka.RandomMotionBlur(kernel_size=2 * (4 // 3) + 1, angle=1, direction=1.0, p=0.25),
Ka.RandomErasing(scale=(1 / 100, 1 / 50), ratio=(1 / 20, 1), p=0.25),
),
"collate": kornia_collate,
"per_batch_transform_on_device": ApplyToKeys(
DefaultDataKeys.INPUT,
Ka.RandomHorizontalFlip(p=0.25),
),
}
# construct datamodule
datamodule = ImageClassificationData.from_tensors(
train_data=train_dataset.x,
train_targets=torch.from_numpy(train_dataset.y.astype(int)),
val_data=val_dataset.x,
val_targets=torch.from_numpy(val_dataset.y.astype(int)),
test_data=test_dataset.x,
test_targets=torch.from_numpy(test_dataset.y.astype(int)),
num_workers=4,
train_transform=train_transform,
)
model = ImageClassifier(
backbone="resnet18",
training_strategy="prototypicalnetworks",
training_strategy_kwargs={
"epoch_length": 10 * 16,
"meta_batch_size": 4,
"num_tasks": 200,
"test_num_tasks": 2000,
"ways": datamodule.num_classes,
"shots": 1,
"test_ways": 5,
"test_shots": 1,
"test_queries": 15,
},
optimizer=torch.optim.Adam,
optimizer_kwargs={"lr": 0.001},
)
trainer = flash.Trainer(
max_epochs=200,
gpus=2,
accelerator="ddp_shared",
precision=16,
)
trainer.finetune(model, datamodule=datamodule, strategy="no_freeze")
You can read their paper Learn2Learn: A Library for Meta-Learning Research.
And don’t forget to cite the Learn2Learn repository in your academic publications. Find their BibTeX on their repository.
VISSL¶
VISSL is a library from Facebook AI Research for state-of-the-art self-supervised learning. We integrate VISSL models and algorithms into Flash with the image embedder task.
Using VISSL with Flash¶
The ImageEmbedder task in Flash can be configured with different backbones, projection heads, image transforms and loss functions so that you can train your feature extractor using a SOTA SSL method.
from flash.image import ImageEmbedder
embedder = ImageEmbedder(
backbone="resnet",
training_strategy="barlow_twins",
head="simclr_head",
pretraining_transform="barlow_twins_transform",
training_strategy_kwargs={"latent_embedding_dim": 256, "dims": [2048, 2048, 256]},
pretraining_transform_kwargs={"size_crops": [196]},
)
The user can pass arguments to the training strategy, image transforms and backbones using the optional dictionary arguments the ImageEmbedder task accepts.
The training strategies club together the projection head, the loss function as well as VISSL hooks for a particular algorithm, and the arguments to customize these can be passed via training_strategy_kwargs.
As an example, in the above code block, the latent_embedding_dim is an argument to the BarlowTwins loss function from VISSL, while the dims argument configures the projection head to output 256 dim vectors for the loss function.
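As a rough illustration of how a list like dims can describe a projection head, the sketch below pairs consecutive entries into (in_features, out_features) layer shapes. This assumes an MLP-style head that stacks one linear layer between each pair of sizes; the helper name is hypothetical and the code does not call VISSL or Flash.

```python
def projection_layer_shapes(dims):
    """Pair consecutive entries of `dims` into (in_features, out_features) tuples."""
    return list(zip(dims[:-1], dims[1:]))


shapes = projection_layer_shapes([2048, 2048, 256])
print(shapes)
print(shapes[-1][1])  # output size of the last layer
```

Under that assumption, the last entry of dims (256 here) is the size of the vectors handed to the loss function.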
If you find the VISSL integration in Flash useful for your research, please don’t forget to cite us and the VISSL library. You can find our BibTeX on Flash and VISSL’s BibTeX on their GitHub page.
flash¶
A basic DataModule class for all Flash tasks.
A general Task.
flash.core¶
flash.core.adapter¶
flash.core.classification¶
A base class for classification serializers.
flash.core.finetuning¶
FlashBaseFinetuning can be used to create a custom Flash Finetuning Callback.
flash.core.integrations.fiftyone¶
Visualizes predictions from a model with a FiftyOne Serializer in the FiftyOne App.
flash.core.integrations.icevision¶
The default transforms from IceVision.
The default augmentations from IceVision.
flash.core.model¶
Specialized callback only used during testing Keeps track metrics during training.
A general Task.
flash.core.registry¶
This class is used to register function or …
flash.core.optimizers¶
Extends SGD in PyTorch with LARS scaling from the paper Large batch training of Convolutional Networks.
Extends ADAM in pytorch to incorporate LAMB algorithm from the paper: Large batch optimization for deep learning: Training BERT in 76 minutes.
Sets the learning rate of each parameter group to follow a linear warmup schedule between warmup_start_lr and base_lr followed by a cosine annealing schedule between base_lr and eta_min.
Utilities¶
Modified version of …
This decorator is used as context manager to put model in eval mode before running predict and reset to train after.
flash.core.data¶
flash.core.data.auto_dataset¶
flash.core.data.base_viz¶
This Base Class is used to create visualization tool on top of …
flash.core.data.batch¶
This function is used to uncollate a batch into samples.
flash.core.data.callback¶
This class is used to profile …
flash.core.data.data_module¶
A basic DataModule class for all Flash tasks.
flash.core.data.data_pipeline¶
DataPipeline holds the engineering logic to connect …
A class to store and share all process states once a …
flash.core.data.data_source¶
Checks if a file is an allowed extension.
Generates a list of samples of a form (path_to_sample, class).
flash.core.data.process¶
Deserializer Mapping.
Deserializer.
If the model output is a dictionary, then the …
flash.core.data.properties¶
Base class for all process states.
flash.core.data.splits¶
SplitDataset is used to create Dataset Subset using indices.
flash.core.data.transforms¶
Utility function to merge two transform dictionaries.
Kornia transforms add batch dimension which need to be removed.
flash.core.data.utils¶
This class is used to wrap a callable within a nn.Module and apply the wrapped function in __call__
Download file with progressbar.
flash.core.serve¶
alias of …
Create a composition which define computations / endpoints to create & run.
An endpoint maps a route and request/response payload to components.
ModuleWrapperBase around a model object to enable serving at scale.
Expose a function/method via a web API for serving model inference.
flash.image¶
Classification¶
Data module for image classification tasks.
Preprocessing of data for image classification.
Process and show the image batch and its associated label using matplotlib.
The default transforms for image classification: resize the image, convert the image and target to a tensor, collate the batch, and apply normalization.
During training, we apply the default transforms with additional …
Object Detection¶
Keypoint Detection¶
Instance Segmentation¶
Embedding¶
Segmentation¶
Data module for semantic segmentation tasks.
Process and show the image batch and its associated label using matplotlib.
The default transforms for semantic segmentation: resize the image and mask, collate the batch, and apply normalization.
Convert the target mask to long and remove the channel dimension.
During training, we apply the default transforms with additional …
Style Transfer¶
flash.image.data¶
flash.audio¶
Classification¶
Data module for audio classification.
Speech Recognition¶
Data Module for text classification tasks.
flash.pointcloud¶
Segmentation¶
Object Detection¶
flash.tabular¶
Classification¶
Regression¶
flash.tabular.data¶
Data module for tabular tasks.
flash.text¶
Classification¶
Data Module for text classification tasks.
Question Answering¶
Data module for QuestionAnswering task.
Summarization¶
Translation¶
Data module for Translation tasks.
General Seq2Seq¶
General Task for Sequence2Sequence.
Data module for Seq2Seq tasks.
Freezes the embedding layers during Seq2Seq training.
Calculate BLEU score of machine translated text with one or more references.
Aggregates rouge scores and provides confidence intervals.
Metric used for automatic summarization.
flash.video¶
Classification¶
Task that classifies videos.
Data module for Video classification tasks.
flash.graph¶
Classification¶
Data module for graph classification tasks.
flash.graph.data¶
Introduction / Set-up¶
Welcome¶
Before you begin, we’d like to express our gratitude to you for wanting to add a task to Flash. With Flash our aim is to create a great user experience, enabling awesome advanced applications with just a few lines of code. We’re really pleased with what we’ve achieved with Flash and we hope you will be too. Now let’s dive in!
Set-up¶
The Task template is designed to guide you through contributing a task to Flash. It contains the code, tests, and examples for a task that performs classification with a multi-layer perceptron, intended for use with the classic data sets from scikit-learn. The Flash tasks are organized in folders by data-type (image, text, video, etc.), with sub-folders for different task types (classification, regression, etc.).
Copy the files in flash/template/classification to a new sub-directory under the relevant data-type.
If a data-type folder already exists for your task, then a task type sub-folder should be added containing the template files.
If a data-type folder doesn’t exist, then you will need to add that too.
You should also copy the files from tests/template/classification to the corresponding data-type, task type folder in tests.
For example, if you were adding an image classification task, you would do:
mkdir flash/image/classification
cp flash/template/classification/* flash/image/classification/
mkdir tests/image/classification
cp tests/template/classification/* tests/image/classification/
Tutorials¶
The tutorials in this section will walk you through all of the components you need to implement (or adapt from the template) for your custom task.
The Data: our first tutorial goes over the best practices for implementing everything you need to connect data to your task
The Backbones: the second tutorial shows you how to create an extensible backbone registry for your task
The Task: now that we have the data and the models, in this tutorial we create our custom task
Optional Extras: this tutorial covers some optional extras you can add if needed for your particular task
The Example: this tutorial guides you through creating some simple examples showing your task in action
The Tests: in this tutorial, we cover best practices for writing some tests for your new task
The Docs: in our final tutorial, we provide a template for you to create the docs page for your task
The Data¶
The first step to contributing a task is to implement the classes we need to load some data. Inside data.py you should implement:
some DataSource classes (optional)
a BaseVisualization (optional)
a Postprocess (optional)
DataSource¶
The DataSource class contains the logic for data loading from different sources such as folders, files, tensors, etc.
Every Flash DataModule can be instantiated with from_datasets().
For each additional way you want the user to be able to instantiate your DataModule, you’ll need to create a DataSource.
Each DataSource has 2 methods:
load_data() takes some dataset metadata (e.g. a folder name) as input and produces a sequence or iterable of samples or sample metadata.
load_sample() then takes as input a single element from the output of load_data and returns a sample.
By default these methods just return their input, so you don’t need both a load_data() and a load_sample() to create a DataSource.
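The division of labour between the two methods can be sketched without Flash at all. In the toy classes below (hypothetical stand-ins, not Flash’s DataSource), load_data only lists file paths, while load_sample reads one file at a time; the base class returns its input unchanged, mirroring the default behaviour described above.

```python
import os
import tempfile


class ToyDataSource:
    """Hypothetical stand-in for a DataSource; both hooks default to identity."""

    def load_data(self, data):
        return data

    def load_sample(self, sample):
        return sample


class ToyTextFolderDataSource(ToyDataSource):
    """Lists files up front, then reads each file lazily."""

    def load_data(self, folder):
        # Cheap step: only collect per-sample metadata (file paths).
        return [os.path.join(folder, name) for name in sorted(os.listdir(folder))]

    def load_sample(self, path):
        # Expensive step: actually read one sample from disk.
        with open(path) as handle:
            return {"input": handle.read(), "metadata": {"filepath": path}}


with tempfile.TemporaryDirectory() as folder:
    for name, text in [("a.txt", "hello"), ("b.txt", "world")]:
        with open(os.path.join(folder, name), "w") as handle:
            handle.write(text)

    source = ToyTextFolderDataSource()
    paths = source.load_data(folder)
    samples = [source.load_sample(path) for path in paths]

print([sample["input"] for sample in samples])
```

Keeping load_data cheap means the full list of samples can be built up front, while the costly work happens per sample.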
Where possible, you should override one of our existing DataSource classes.
Let’s start by implementing a TemplateNumpyDataSource, which overrides NumpyDataSource.
The main DataSource method that we have to implement is load_data().
As we’re extending the NumpyDataSource, we expect the same data argument (in this case, a tuple containing data and corresponding target arrays).
We can also take the dataset argument.
Any attributes we set on dataset will be available on the Dataset generated by our DataSource.
In this data source, we’ll set the num_features attribute.
Here’s the code for our TemplateNumpyDataSource.load_data method:
def load_data(self, data: Tuple[np.ndarray, Sequence[Any]], dataset: Any) -> Sequence[Mapping[str, Any]]:
    """Sets the ``num_features`` attribute and calls ``super().load_data``.

    Args:
        data: The tuple of ``np.ndarray`` (num_examples x num_features) and associated targets.
        dataset: The object that we can set attributes (such as ``num_features``) on.

    Returns:
        A sequence of samples / sample metadata.
    """
    dataset.num_features = data[0].shape[1]
    return super().load_data(data, dataset)
Note
Later, when we add our DataModule implementation, we’ll make num_features available to the user.
Sometimes you need to do something a bit more custom.
When creating a custom DataSource, the type of the data argument is up to you.
For our template Task, it would be cool if the user could provide a scikit-learn Bunch as the data source.
To achieve this, we’ll add a TemplateSKLearnDataSource whose load_data expects a Bunch as input.
We override our TemplateNumpyDataSource so that we can call super with the data and targets extracted from the Bunch.
We perform two additional steps here to improve the user experience:
We set the num_classes attribute on the dataset. If num_classes is set, it is automatically made available as a property of the DataModule.
We create and set a LabelsState. The labels provided here will be shared with the Labels serializer, so the user doesn’t need to provide them.
Here’s the code for the TemplateSKLearnDataSource.load_data method:
def load_data(self, data: Bunch, dataset: Any) -> Sequence[Mapping[str, Any]]:
    """Gets the ``data`` and ``target`` attributes from the ``Bunch`` and passes them to ``super().load_data``.

    Args:
        data: The scikit-learn data ``Bunch``.
        dataset: The object that we can set attributes (such as ``num_classes``) on.

    Returns:
        A sequence of samples / sample metadata.
    """
    dataset.num_classes = len(data.target_names)
    self.set_state(LabelsState(data.target_names))
    return super().load_data((data.data, data.target), dataset=dataset)
We can customize the behaviour of our load_data() for different stages, by prepending train, val, test, or predict.
For our TemplateSKLearnDataSource, we don’t want to provide any targets to the model when predicting.
We can implement predict_load_data like this:
def predict_load_data(self, data: Bunch) -> Sequence[Mapping[str, Any]]:
    """Avoid including targets when predicting.

    Args:
        data: The scikit-learn data ``Bunch``.

    Returns:
        A sequence of samples / sample metadata.
    """
    return super().predict_load_data(data.data)
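The stage-prefix convention can be pictured as a name lookup with a fallback. The sketch below is a hypothetical illustration of that idea in plain Python (the resolve helper and the toy class are assumptions, not Flash’s internal dispatch code): a "predict_load_data" method is preferred for the predict stage, and every other stage falls back to "load_data".

```python
class ToyStageDataSource:
    """Hypothetical sketch of stage-specific hook lookup, not Flash's implementation."""

    def load_data(self, data):
        inputs, targets = data
        return [{"input": x, "target": y} for x, y in zip(inputs, targets)]

    def predict_load_data(self, data):
        # No targets are available (or wanted) when predicting.
        return [{"input": x} for x in data]

    def resolve(self, hook_name, stage):
        # Prefer "<stage>_<hook_name>" if it exists, else fall back to "<hook_name>".
        return getattr(self, f"{stage}_{hook_name}", getattr(self, hook_name))


source = ToyStageDataSource()
train_samples = source.resolve("load_data", "train")(([1.0, 2.0], [0, 1]))
predict_samples = source.resolve("load_data", "predict")([3.0])
print(train_samples)
print(predict_samples)
```

Only the stages that need special handling get their own method; the rest share the generic one.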
DataSource vs Dataset¶
A DataSource is not the same as a torch.utils.data.Dataset.
When a from_* method is called on your DataModule, it gets the DataSource to use from the Preprocess.
A Dataset is then created from the DataSource for each stage (train, val, test, predict) using the provided metadata (e.g. folder name, numpy array etc.).
The output of the load_data() can just be a torch.utils.data.Dataset instance.
If the library that your Task is based on provides a custom dataset, you don’t need to re-write it as a DataSource.
For example, the load_data() of the VideoClassificationPathsDataSource just creates an EncodedVideoDataset from the given folder.
Here’s how it looks (from video/classification/data.py):
def load_data(self, data: str, dataset: Optional[Any] = None) -> "LabeledVideoDataset":
    ds = self._make_encoded_video_dataset(data)
    if self.training:
        label_to_class_mapping = {p[1]: p[0].split("/")[-2] for p in ds._labeled_videos._paths_and_labels}
        self.set_state(LabelsState(label_to_class_mapping))
        dataset.num_classes = len(np.unique([s[1]["label"] for s in ds._labeled_videos]))
    return ds
Preprocess¶
The Preprocess object contains all the data transforms.
Internally we inject the Preprocess transforms at several points along the pipeline.
Defining the standard transforms (typically at least a to_tensor_transform should be defined) for your Preprocess is as simple as implementing the default_transforms method.
The Preprocess must take train_transform, val_transform, test_transform, and predict_transform arguments in the __init__.
These arguments can be provided by the user (when creating the DataModule) to override the default transforms.
Any additional arguments are up to you.
Inside the __init__, we make a call to super.
This is where we register our data sources.
Data sources should be given as a dictionary which maps data source name to data source object.
The name can be anything, but if you want to take advantage of our built-in from_* classmethods, you should use DefaultDataSources as the names.
In our case, we have both a NUMPY and a custom scikit-learn data source (which we’ll call “sklearn”).
You should also provide a default_data_source.
This is the name of the data source to use by default when predicting.
It’d be cool if we could get predictions just from a numpy array, so we’ll use NUMPY as the default.
Here’s our TemplatePreprocess.__init__:
def __init__(
    self,
    train_transform: Optional[Dict[str, Callable]] = None,
    val_transform: Optional[Dict[str, Callable]] = None,
    test_transform: Optional[Dict[str, Callable]] = None,
    predict_transform: Optional[Dict[str, Callable]] = None,
):
    super().__init__(
        train_transform=train_transform,
        val_transform=val_transform,
        test_transform=test_transform,
        predict_transform=predict_transform,
        data_sources={
            DefaultDataSources.NUMPY: TemplateNumpyDataSource(),
            "sklearn": TemplateSKLearnDataSource(),
        },
        default_data_source=DefaultDataSources.NUMPY,
    )
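The data source dictionary behaves like a small registry keyed by name, with one entry marked as the default. The following is a hypothetical, framework-free sketch of that idea; the ToyPreprocess class, its get_data_source method, and the string placeholders are assumptions for illustration and do not reproduce the Flash Preprocess API.

```python
class ToyPreprocess:
    """Hypothetical sketch of a name -> data source registry with a default."""

    def __init__(self, data_sources, default_data_source):
        if default_data_source not in data_sources:
            raise ValueError(f"Unknown default data source: {default_data_source!r}")
        self._data_sources = dict(data_sources)
        self._default = default_data_source

    def get_data_source(self, name=None):
        # Fall back to the default when no name is given, e.g. when predicting.
        key = self._default if name is None else name
        try:
            return self._data_sources[key]
        except KeyError:
            raise KeyError(f"No data source registered under {key!r}") from None


preprocess = ToyPreprocess(
    data_sources={"numpy": "numpy-source", "sklearn": "sklearn-source"},
    default_data_source="numpy",
)
print(preprocess.get_data_source("sklearn"))
print(preprocess.get_data_source())
```

Using well-known names for the keys is what lets generic from_* helpers find the right entry, while custom names such as "sklearn" need their own entry point.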
For our TemplatePreprocess, we’ll just configure a default to_tensor_transform.
Let’s first define the transform as a staticmethod:
@staticmethod
def input_to_tensor(input: np.ndarray):
    """Transform which creates a tensor from the given numpy ``ndarray`` and converts it to ``float``"""
    return torch.from_numpy(input).float()
Our input samples will be dictionaries whose keys are in the DefaultDataKeys.
You can map each key to different transforms using ApplyToKeys.
Here’s our default_transforms method:
def default_transforms(self) -> Optional[Dict[str, Callable]]:
    """Configures the default ``to_tensor_transform``.

    Returns:
        Our dictionary of transforms.
    """
    return {
        "to_tensor_transform": nn.Sequential(
            ApplyToKeys(DefaultDataKeys.INPUT, self.input_to_tensor),
            ApplyToKeys(DefaultDataKeys.TARGET, torch.as_tensor),
        ),
    }
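The idea of mapping keys to transforms can be sketched with plain dictionaries. The ApplyToKey and compose helpers below are hypothetical simplifications written for this illustration (they are not Flash’s ApplyToKeys or torch’s nn.Sequential): each wrapper touches a single key of the sample, and the wrappers are chained in order.

```python
class ApplyToKey:
    """Hypothetical stand-in that applies `transform` to one key of a sample dict."""

    def __init__(self, key, transform):
        self.key = key
        self.transform = transform

    def __call__(self, sample):
        result = dict(sample)  # leave the caller's dictionary untouched
        if self.key in result:
            result[self.key] = self.transform(result[self.key])
        return result


def compose(*transforms):
    """Chain transforms left to right, similar in spirit to nn.Sequential."""

    def run(sample):
        for transform in transforms:
            sample = transform(sample)
        return sample

    return run


to_tensor_like = compose(
    ApplyToKey("input", lambda values: [float(v) for v in values]),
    ApplyToKey("target", int),
)
print(to_tensor_like({"input": [1, 2], "target": "3"}))
```

Samples without a given key (for example, no target when predicting) simply pass through that wrapper unchanged in this sketch.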
DataModule¶
The DataModule is responsible for creating the DataLoader and injecting the transforms for each stage.
When the user calls a from_* method (such as from_numpy()), the following steps take place:
The from_data_source() method is called with the name of the DataSource to use and the inputs to provide to load_data() for each stage.
The Preprocess is created from cls.preprocess_cls (if it wasn’t provided by the user) with any provided transforms.
The DataSource of the provided name is retrieved from the Preprocess.
A BaseAutoDataset is created from the DataSource for each stage.
The DataModule is instantiated with the data sets.
To create our TemplateData DataModule, we first need to attach our preprocess class like this:
preprocess_cls = TemplatePreprocess
Since we provided a NUMPY DataSource in the TemplatePreprocess, from_numpy() will now work with our TemplateData.
If you’ve defined a fully custom DataSource (like our TemplateSKLearnDataSource), then you will need to write a from_* method for each.
Here’s the from_sklearn method for our TemplateData:
@classmethod
def from_sklearn(
    cls,
    train_bunch: Optional[Bunch] = None,
    val_bunch: Optional[Bunch] = None,
    test_bunch: Optional[Bunch] = None,
    predict_bunch: Optional[Bunch] = None,
    train_transform: Optional[Dict[str, Callable]] = None,
    val_transform: Optional[Dict[str, Callable]] = None,
    test_transform: Optional[Dict[str, Callable]] = None,
    predict_transform: Optional[Dict[str, Callable]] = None,
    data_fetcher: Optional[BaseDataFetcher] = None,
    preprocess: Optional[Preprocess] = None,
    val_split: Optional[float] = None,
    batch_size: int = 4,
    num_workers: int = 0,
    **preprocess_kwargs: Any,
):
    """This is our custom ``from_*`` method. It expects scikit-learn ``Bunch`` objects as input and passes them
    through to the :meth:`~flash.core.data.data_module.DataModule.from_data_source` method underneath.

    Args:
        train_bunch: The scikit-learn ``Bunch`` containing the train data.
        val_bunch: The scikit-learn ``Bunch`` containing the validation data.
        test_bunch: The scikit-learn ``Bunch`` containing the test data.
        predict_bunch: The scikit-learn ``Bunch`` containing the predict data.
        train_transform: The dictionary of transforms to use during training which maps
            :class:`~flash.core.data.process.Preprocess` hook names to callable transforms.
        val_transform: The dictionary of transforms to use during validation which maps
            :class:`~flash.core.data.process.Preprocess` hook names to callable transforms.
        test_transform: The dictionary of transforms to use during testing which maps
            :class:`~flash.core.data.process.Preprocess` hook names to callable transforms.
        predict_transform: The dictionary of transforms to use during predicting which maps
            :class:`~flash.core.data.process.Preprocess` hook names to callable transforms.
        data_fetcher: The :class:`~flash.core.data.callback.BaseDataFetcher` to pass to the
            :class:`~flash.core.data.data_module.DataModule`.
        preprocess: The :class:`~flash.core.data.process.Preprocess` to pass to the
            :class:`~flash.core.data.data_module.DataModule`. If ``None``, ``cls.preprocess_cls`` will be
            constructed and used.
        val_split: The ``val_split`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`.
        batch_size: The ``batch_size`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`.
        num_workers: The ``num_workers`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`.
        preprocess_kwargs: Additional keyword arguments to use when constructing the preprocess. Will only be used
            if ``preprocess = None``.

    Returns:
        The constructed data module.
    """
    return super().from_data_source(
        "sklearn",
        train_bunch,
        val_bunch,
        test_bunch,
        predict_bunch,
        train_transform=train_transform,
        val_transform=val_transform,
        test_transform=test_transform,
        predict_transform=predict_transform,
        data_fetcher=data_fetcher,
        preprocess=preprocess,
        val_split=val_split,
        batch_size=batch_size,
        num_workers=num_workers,
        **preprocess_kwargs,
    )
The final step is to implement the num_features property for our TemplateData.
This is just a convenience for the user that finds the num_features attribute on any of the data sets and returns it.
Here’s the code:
@property
def num_features(self) -> Optional[int]:
    """Tries to get the ``num_features`` from each dataset in turn and returns the output."""
    n_fts_train = getattr(self.train_dataset, "num_features", None)
    n_fts_val = getattr(self.val_dataset, "num_features", None)
    n_fts_test = getattr(self.test_dataset, "num_features", None)
    return n_fts_train or n_fts_val or n_fts_test
BaseVisualization¶
An optional step is to implement a BaseVisualization.
The BaseVisualization lets you control how data at various points in the pipeline can be visualized.
This is extremely useful for debugging purposes, allowing users to view their data and understand the impact of their transforms.
Note
Don’t worry about implementing it right away, you can always come back and add it later!
Here’s the code for our TemplateVisualization which just prints the data:
class TemplateVisualization(BaseVisualization):
    """The ``TemplateVisualization`` class is a :class:`~flash.core.data.callbacks.BaseVisualization` that just
    prints the data.

    If you want to provide a visualization with your task, you can override these hooks.
    """

    def show_load_sample(self, samples: List[Any], running_stage: RunningStage):
        print(samples)

    def show_pre_tensor_transform(self, samples: List[Any], running_stage: RunningStage):
        print(samples)

    def show_to_tensor_transform(self, samples: List[Any], running_stage: RunningStage):
        print(samples)

    def show_post_tensor_transform(self, samples: List[Any], running_stage: RunningStage):
        print(samples)

    def show_per_batch_transform(self, batch: List[Any], running_stage):
        print(batch)
We can configure our custom visualization in the TemplateData using configure_data_fetcher() like this:
@staticmethod
def configure_data_fetcher(*args, **kwargs) -> BaseDataFetcher:
    """We can, *optionally*, provide a data visualization callback using the ``configure_data_fetcher``
    method."""
    return TemplateVisualization(*args, **kwargs)
Postprocess¶
Postprocess contains any transforms that need to be applied after the model.
You may want to use it for: converting tokens back into text, applying an inverse normalization to an output image, resizing a generated image back to the size of the input, etc.
As an example, here’s the TextClassificationPostprocess which gets the logits from a SequenceClassifierOutput:
class TextClassificationPostprocess(Postprocess):
    def per_batch_transform(self, batch: Any) -> Any:
        if isinstance(batch, SequenceClassifierOutput):
            batch = batch.logits
        return super().per_batch_transform(batch)
In your DataSource or Preprocess, you can add metadata to the batch using the METADATA key.
Your Postprocess can then use this metadata in its transforms.
You should use this approach if your postprocessing depends on the state of the input before the Preprocess transforms.
For example, if you want to resize the predictions to the original size of the inputs you should add the original image size in the METADATA.
Here’s an example from the SemanticSegmentationNumpyDataSource:
def load_sample(self, sample: Dict[str, Any], dataset: Optional[Any] = None) -> Dict[str, Any]:
    img = torch.from_numpy(sample[DefaultDataKeys.INPUT]).float()
    sample[DefaultDataKeys.INPUT] = img
    sample[DefaultDataKeys.METADATA] = {"size": img.shape}
    return sample
The METADATA can now be referenced in your Postprocess.
For example, here’s the code for the per_sample_transform method of the SemanticSegmentationPostprocess:
def per_sample_transform(self, sample: Any) -> Any:
    resize = K.geometry.Resize(sample[DefaultDataKeys.METADATA]["size"][-2:], interpolation="bilinear")
    sample[DefaultDataKeys.PREDS] = resize(sample[DefaultDataKeys.PREDS])
    sample[DefaultDataKeys.INPUT] = resize(sample[DefaultDataKeys.INPUT])
    return super().per_sample_transform(sample)
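The metadata round trip can be illustrated without any image libraries. The following is a simplified sketch, not Flash code: plain string keys stand in for DefaultDataKeys, padding a list stands in for resizing an image, and `fake_model` is a hypothetical placeholder for the task.

```python
from typing import Any, Dict, List

# Plain string keys stand in for DefaultDataKeys in this sketch.
INPUT, PREDS, METADATA = "input", "preds", "metadata"


def load_sample(sample: Dict[str, Any], target_len: int = 6) -> Dict[str, Any]:
    """Record the original length, then pad the input to a fixed length."""
    values: List[float] = list(sample[INPUT])
    sample[METADATA] = {"size": len(values)}
    sample[INPUT] = (values + [0.0] * target_len)[:target_len]
    return sample


def fake_model(sample: Dict[str, Any]) -> Dict[str, Any]:
    """Hypothetical model: one prediction per (padded) input element."""
    sample[PREDS] = [value * 2 for value in sample[INPUT]]
    return sample


def per_sample_transform(sample: Dict[str, Any]) -> Dict[str, Any]:
    """Use the metadata to cut the predictions back to the original length."""
    size = sample[METADATA]["size"]
    sample[PREDS] = sample[PREDS][:size]
    return sample


result = per_sample_transform(fake_model(load_sample({INPUT: [1.0, 2.0, 3.0]})))
print(result[PREDS])
```

The point of the pattern is that the postprocessing step never has to guess the original shape: it reads what the loading step stored before any transform changed the input.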
Now that you’ve got some data, it’s time to add some backbones for your task!
The Backbones¶
Now that you’ve got a way of loading data, you should implement some backbones to use with your Task.
Create a FlashRegistry to use with your Task in backbones.py.
The registry allows you to register backbones for your task that can be selected by the user.
The backbones can come from anywhere as long as you can register a function that loads the backbone.
Furthermore, the user can add their own models to the existing backbones, without having to write their own Task!
You can create a registry like this:
TEMPLATE_BACKBONES = FlashRegistry("backbones")
Let’s add a simple MLP backbone to our registry.
We need a function that creates the backbone and returns it along with the output size (so that we can create the model head in our Task).
You can use any name for the function, although we use load_{model name} by convention.
You also need to provide name and namespace of the backbone.
The standard for namespace is data_type/task_type, so for an image classification task the namespace will be image/classification.
Here’s the code:
@TEMPLATE_BACKBONES(name="mlp-128", namespace="template/classification")
def load_mlp_128(num_features, **_):
    """A simple MLP backbone with 128 hidden units."""
    return (
        nn.Sequential(
            nn.Linear(num_features, 128),
            nn.ReLU(True),
            nn.BatchNorm1d(128),
        ),
        128,
    )
Here’s another example with a slightly more complex model:
@TEMPLATE_BACKBONES(name="mlp-128-256", namespace="template/classification")
def load_mlp_128_256(num_features, **_):
    """A two-layer MLP backbone with 128 and 256 hidden units respectively."""
    return (
        nn.Sequential(
            nn.Linear(num_features, 128),
            nn.ReLU(True),
            nn.BatchNorm1d(128),
            nn.Linear(128, 256),
            nn.ReLU(True),
            nn.BatchNorm1d(256),
        ),
        256,
    )
Here’s another example, which adds a DINO pretrained model from PyTorch Hub to the IMAGE_CLASSIFIER_BACKBONES, from flash/image/classification/backbones/transformers.py:
def dino_vitb16(*_, **__):
    backbone = torch.hub.load("facebookresearch/dino:main", "dino_vitb16")
    return backbone, 768
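If the decorator-based registration feels opaque, the following sketch shows the general pattern with a hypothetical `MiniRegistry` class. It is a simplified stand-in written for illustration, not the FlashRegistry implementation: it ignores the namespace, and the "backbones" are plain strings so that the example has no dependencies.

```python
from typing import Callable, Dict, List, Tuple


class MiniRegistry:
    """Hypothetical, simplified stand-in for a backbone registry."""

    def __init__(self, name: str):
        self.name = name
        self._functions: Dict[str, Callable] = {}

    def __call__(self, name: str, namespace: str = "") -> Callable:
        """Return a decorator that stores the decorated function under ``name``."""

        def decorator(fn: Callable) -> Callable:
            self._functions[name] = fn
            return fn

        return decorator

    def get(self, name: str) -> Callable:
        if name not in self._functions:
            raise KeyError(f"'{name}' is not registered in '{self.name}'")
        return self._functions[name]

    def available_keys(self) -> List[str]:
        return sorted(self._functions)


BACKBONES = MiniRegistry("backbones")


@BACKBONES(name="mlp-128", namespace="template/classification")
def load_mlp_128(num_features: int, **_) -> Tuple[str, int]:
    # A string describes the layers so that the sketch stays dependency-free.
    return f"Linear({num_features}, 128) -> ReLU -> BatchNorm1d(128)", 128


# A user can register their own backbone without touching the task code.
@BACKBONES(name="my-mlp-64", namespace="template/classification")
def load_my_mlp_64(num_features: int, **_) -> Tuple[str, int]:
    return f"Linear({num_features}, 64) -> ReLU", 64


backbone, out_features = BACKBONES.get("my-mlp-64")(num_features=4)
print(BACKBONES.available_keys(), out_features)
```

Each registered function returns the backbone together with its output size, which is what lets the task build a matching head later on.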
Once you’ve got some data and some backbones, implement your task!
The Task¶
Once you’ve implemented a Flash DataModule and some backbones, you should implement your Task in model.py.
The Task is responsible for: setting up the backbone, performing the forward pass of the model, and calculating the loss and any metrics.
Remember that, under the hood, the Flash Task is simply a LightningModule with some helpful defaults.
To build your task, you can start by overriding the base Task or any of the existing Task implementations.
For example, in our scikit-learn example, we can just override ClassificationTask which provides good defaults for classification.
You should attach your backbones registry as a class attribute like this:
class TemplateSKLearnClassifier(ClassificationTask):
    backbones: FlashRegistry = TEMPLATE_BACKBONES
Model architecture and hyper-parameters¶
In the __init__(), you will need to configure defaults for the:
loss function
optimizer
metrics
backbone / model
You will also need to create the backbone from the registry and create the model head. Here’s the code:
def __init__(
    self,
    num_features: int,
    num_classes: int,
    backbone: Union[str, Tuple[nn.Module, int]] = "mlp-128",
    backbone_kwargs: Optional[Dict] = None,
    loss_fn: LOSS_FN_TYPE = None,
    optimizer: OPTIMIZER_TYPE = "Adam",
    lr_scheduler: LR_SCHEDULER_TYPE = None,
    metrics: METRICS_TYPE = None,
    learning_rate: float = 1e-2,
    multi_label: bool = False,
    serializer: SERIALIZER_TYPE = None,
):
    super().__init__(
        model=None,
        loss_fn=loss_fn,
        optimizer=optimizer,
        lr_scheduler=lr_scheduler,
        metrics=metrics,
        learning_rate=learning_rate,
        multi_label=multi_label,
        serializer=serializer or Labels(),
    )

    self.save_hyperparameters()

    if not backbone_kwargs:
        backbone_kwargs = {}

    if isinstance(backbone, tuple):
        self.backbone, out_features = backbone
    else:
        self.backbone, out_features = self.backbones.get(backbone)(num_features=num_features, **backbone_kwargs)

    self.head = nn.Linear(out_features, num_classes)
Note
We call save_hyperparameters() to log the arguments to the __init__ as hyperparameters. Read more here.
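The backbone argument accepts either a registered name or a ready-made (backbone, output size) tuple. The sketch below isolates just that branching logic. `REGISTRY` and `resolve_backbone` are hypothetical names introduced for illustration, and strings are used in place of real modules.

```python
from typing import Callable, Dict, Optional, Tuple, Union

# Hypothetical registry: maps a name to a function returning (backbone, out_features).
REGISTRY: Dict[str, Callable[..., Tuple[object, int]]] = {
    "mlp-128": lambda num_features, **_: (f"mlp({num_features}->128)", 128),
}


def resolve_backbone(
    backbone: Union[str, Tuple[object, int]],
    num_features: int,
    backbone_kwargs: Optional[Dict] = None,
) -> Tuple[object, int]:
    """Return ``(backbone, out_features)`` for a name or a user-provided tuple."""
    backbone_kwargs = backbone_kwargs or {}
    if isinstance(backbone, tuple):
        # The user supplied their own backbone together with its output size.
        return backbone
    # Otherwise look the name up and call the registered loading function.
    return REGISTRY[backbone](num_features=num_features, **backbone_kwargs)


from_name = resolve_backbone("mlp-128", num_features=4)
from_tuple = resolve_backbone(("custom-backbone", 32), num_features=4)
print(from_name, from_tuple)
```

Either way the caller ends up with the output size it needs to create a head of the right shape.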
Adding the model routines¶
You should override the {train,val,test,predict}_step methods.
The default {train,val,test,predict}_step implementations in Task expect a tuple containing the input (to be passed to the model) and target (to be used when computing the loss), and should be suitable for most applications.
In our template example, we just extract the input and target from the input mapping and forward them to the super methods.
Here’s the code for the training_step:
def training_step(self, batch: Any, batch_idx: int) -> Any:
    """For the training step, we just extract the :attr:`~flash.core.data.data_source.DefaultDataKeys.INPUT` and
    :attr:`~flash.core.data.data_source.DefaultDataKeys.TARGET` keys from the input and forward them to the
    :meth:`~flash.core.model.Task.training_step`."""
    batch = (batch[DefaultDataKeys.INPUT], batch[DefaultDataKeys.TARGET])
    return super().training_step(batch, batch_idx)
We use the same code for the validation_step and test_step.
For predict_step we don’t need the targets, so our code looks like this:
def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any:
    """For the predict step, we just extract the :attr:`~flash.core.data.data_source.DefaultDataKeys.INPUT` key
    from the input and forward it to the :meth:`~flash.core.model.Task.predict_step`."""
    batch = batch[DefaultDataKeys.INPUT]
    return super().predict_step(batch, batch_idx, dataloader_idx=dataloader_idx)
Note
You can completely replace the {train,val,test,predict}_step methods (that is, without a call to super) if you need more custom behaviour for your Task at a particular stage.
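The unpack-then-delegate pattern used by these step methods can be tried out in isolation. In this sketch `BaseTask` is a hypothetical base class with a toy loss, and plain string keys stand in for DefaultDataKeys; none of it is Flash code.

```python
from typing import Any, Dict, List, Tuple

INPUT, TARGET = "input", "target"  # stand-ins for DefaultDataKeys


class BaseTask:
    """Hypothetical base class whose steps expect an ``(input, target)`` tuple."""

    def training_step(self, batch: Tuple[List[float], List[float]], batch_idx: int) -> float:
        x, y = batch
        # Toy "loss": mean absolute difference between inputs and targets.
        return sum(abs(a - b) for a, b in zip(x, y)) / len(x)

    def predict_step(self, batch: List[float], batch_idx: int, dataloader_idx: int = 0) -> List[float]:
        return [value * 2 for value in batch]


class MappingTask(BaseTask):
    """Accepts dict batches, unpacks them, and delegates to the base class."""

    def training_step(self, batch: Dict[str, Any], batch_idx: int) -> float:
        return super().training_step((batch[INPUT], batch[TARGET]), batch_idx)

    def predict_step(self, batch: Dict[str, Any], batch_idx: int, dataloader_idx: int = 0) -> List[float]:
        # No target is needed for prediction, so only the input is forwarded.
        return super().predict_step(batch[INPUT], batch_idx, dataloader_idx=dataloader_idx)


task = MappingTask()
loss = task.training_step({INPUT: [1.0, 2.0], TARGET: [1.0, 4.0]}, batch_idx=0)
preds = task.predict_step({INPUT: [1.0, 2.0]}, batch_idx=0)
print(loss, preds)
```

Keeping the override this thin means the base class remains responsible for the loss and metrics, and the subclass only adapts the batch format.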
Finally, we use our backbone and head in a custom forward pass:
def forward(self, x) -> torch.Tensor:
    """First call the backbone, then the model head."""
    x = self.backbone(x)
    return self.head(x)
Now that you’ve got your task, take a look at some optional advanced features you can add or go ahead and create some examples showing your task in action!
Optional Extras¶
Organize your transforms in transforms.py¶
If you have a lot of default transforms, it can be useful to put them all in a transforms.py file, to be referenced in your Preprocess.
Here’s an example from image/classification/transforms.py which creates some default transforms given the desired image size:
def default_transforms(image_size: Tuple[int, int]) -> Dict[str, Callable]:
    """The default transforms for image classification: resize the image, convert the image and target to a tensor,
    collate the batch, and apply normalization."""
    if _KORNIA_AVAILABLE and os.getenv("FLASH_TESTING", "0") != "1":
        # Better approach as all transforms are applied on tensor directly
        return {
            "to_tensor_transform": nn.Sequential(
                ApplyToKeys(DefaultDataKeys.INPUT, torchvision.transforms.ToTensor()),
                ApplyToKeys(DefaultDataKeys.TARGET, torch.as_tensor),
            ),
            "post_tensor_transform": ApplyToKeys(
                DefaultDataKeys.INPUT,
                K.geometry.Resize(image_size),
            ),
            "collate": kornia_collate,
            "per_batch_transform_on_device": ApplyToKeys(
                DefaultDataKeys.INPUT,
                K.augmentation.Normalize(torch.tensor([0.485, 0.456, 0.406]), torch.tensor([0.229, 0.224, 0.225])),
            ),
        }
    return {
        "pre_tensor_transform": ApplyToKeys(DefaultDataKeys.INPUT, T.Resize(image_size)),
        "to_tensor_transform": nn.Sequential(
            ApplyToKeys(DefaultDataKeys.INPUT, torchvision.transforms.ToTensor()),
            ApplyToKeys(DefaultDataKeys.TARGET, torch.as_tensor),
        ),
        "post_tensor_transform": ApplyToKeys(
            DefaultDataKeys.INPUT,
            T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ),
        "collate": kornia_collate,
    }
Here’s how we create our transforms in the ImageClassificationPreprocess:
def default_transforms(self) -> Optional[Dict[str, Callable]]:
    return default_transforms(self.image_size)
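The idea of a dictionary that maps hook names to callables can be sketched in plain Python. This is an illustration under simplifying assumptions: `apply_to_key` is a hypothetical stand-in for ApplyToKeys, `run` is a hypothetical driver that applies only the per-sample hooks and then collates, and a division stands in for real image transforms.

```python
from typing import Any, Callable, Dict, List


def apply_to_key(key: str, fn: Callable) -> Callable:
    """Simplified stand-in for ApplyToKeys: apply ``fn`` to one entry of a sample dict."""

    def wrapped(sample: Dict[str, Any]) -> Dict[str, Any]:
        out = dict(sample)
        out[key] = fn(out[key])
        return out

    return wrapped


def default_transforms(scale: float) -> Dict[str, Callable]:
    """Build the hook-name -> callable mapping for a given scale."""
    return {
        "pre_tensor_transform": apply_to_key("input", lambda xs: [x / scale for x in xs]),
        "to_tensor_transform": apply_to_key("target", int),
        "collate": lambda samples: {
            "input": [s["input"] for s in samples],
            "target": [s["target"] for s in samples],
        },
    }


def identity(sample: Any) -> Any:
    return sample


def run(samples: List[Dict[str, Any]], transforms: Dict[str, Callable]) -> Dict[str, Any]:
    """Apply the per-sample hooks in order (missing hooks are skipped), then collate."""
    for hook in ("pre_tensor_transform", "to_tensor_transform", "post_tensor_transform"):
        fn = transforms.get(hook, identity)
        samples = [fn(sample) for sample in samples]
    return transforms["collate"](samples)


batch = run(
    [{"input": [2.0, 4.0], "target": "1"}, {"input": [6.0, 8.0], "target": "0"}],
    default_transforms(scale=2.0),
)
print(batch)
```

Keeping the mapping in one function makes it easy to swap a single hook, for example the collate function, without touching the others.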
Add output serializers to your Task¶
We recommend that you do most of the heavy lifting in the Postprocess.
Specifically, it should include any formatting and transforms that should always be applied to the predictions.
If you want to support different use cases that require different prediction formats, you should add some Serializer implementations in a serialization.py file.
Some good examples are in flash/core/classification.py.
Here’s the Classes Serializer:
class Classes(PredsClassificationSerializer):
    """A :class:`.Serializer` which applies an argmax to the model outputs (either logits or probabilities) and
    converts to a list.

    Args:
        multi_label: If true, treats outputs as multi label logits.
        threshold: The threshold to use for multi_label classification.
    """

    def __init__(self, multi_label: bool = False, threshold: float = 0.5):
        super().__init__(multi_label)
        self.threshold = threshold

    def serialize(self, sample: Any) -> Union[int, List[int]]:
        sample = super().serialize(sample)
        if self.multi_label:
            one_hot = (sample.sigmoid() > self.threshold).int().tolist()
            result = []
            for index, value in enumerate(one_hot):
                if value == 1:
                    result.append(index)
            return result
        return torch.argmax(sample, -1).tolist()
Alternatively, here’s the Logits Serializer:
class Logits(PredsClassificationSerializer):
    """A :class:`.Serializer` which simply converts the model outputs (assumed to be logits) to a list."""

    def serialize(self, sample: Any) -> Any:
        return super().serialize(sample).tolist()
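To make the difference between the two output formats concrete, here is a dependency-free sketch of the same logic applied to a single sample. The functions `to_classes` and `to_logits` are hypothetical names chosen for this illustration and operate on plain Python lists rather than tensors.

```python
import math
from typing import List, Sequence, Union


def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))


def to_classes(
    logits: Sequence[float], multi_label: bool = False, threshold: float = 0.5
) -> Union[int, List[int]]:
    """Mimic the class-index format: argmax, or thresholded indices when multi-label."""
    if multi_label:
        # Keep every class whose sigmoid probability exceeds the threshold.
        return [index for index, value in enumerate(logits) if sigmoid(value) > threshold]
    # Single-label case: the index of the largest logit.
    return max(range(len(logits)), key=lambda index: logits[index])


def to_logits(logits: Sequence[float]) -> List[float]:
    """Mimic the raw format: return the model outputs unchanged as a list."""
    return list(logits)


sample_logits = [-2.0, 0.5, 3.0]
print(to_classes(sample_logits))
print(to_classes(sample_logits, multi_label=True))
print(to_logits(sample_logits))
```

The same model output therefore yields a single index, a list of indices, or the raw values, depending only on which serializer is attached.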
Take a look at Predictions (inference) to learn more.
Once you’ve added any optional extras, it’s time to create some examples showing your task in action!
The Example¶
Now you’ve implemented your task, it’s time to add an example showing how cool it is!
We usually provide one example in flash_examples/.
You can base these off of our template.py examples.
The example should:
download the data (we’ll add the example to our CI later on, so choose a dataset small enough that it runs in reasonable time)
load the data into a DataModule
create an instance of the Task
create a Trainer
call finetune() or fit() to train your model
generate predictions for a few examples
save the checkpoint
For our template example we don’t have a pretrained backbone, so we can just call fit() rather than finetune().
Here’s the full example (flash_examples/template.py):
import numpy as np
import torch
from sklearn import datasets

import flash
from flash.template import TemplateData, TemplateSKLearnClassifier

# 1. Create the DataModule
datamodule = TemplateData.from_sklearn(
    train_bunch=datasets.load_iris(),
    val_split=0.1,
)

# 2. Build the task
model = TemplateSKLearnClassifier(num_features=datamodule.num_features, num_classes=datamodule.num_classes)

# 3. Create the trainer and train the model
trainer = flash.Trainer(max_epochs=3, gpus=torch.cuda.device_count())
trainer.fit(model, datamodule=datamodule)

# 4. Classify a few examples
predictions = model.predict(
    [
        np.array([4.9, 3.0, 1.4, 0.2]),
        np.array([6.9, 3.2, 5.7, 2.3]),
        np.array([7.2, 3.0, 5.8, 1.6]),
    ]
)
print(predictions)

# 5. Save the model!
trainer.save_checkpoint("template_model.pt")
We get this output:
['setosa', 'virginica', 'versicolor']
Now that you’ve got an example showing your awesome task in action, it’s time to write some tests!
The Tests¶
Our next step is to create some tests for our Task.
For the TemplateSKLearnClassifier, we will just create some basic tests.
You should expand on these to include tests for any specific functionality you have in your Task.
Smoke tests¶
We use smoke tests, usually called test_smoke, throughout.
These just instantiate the class we are testing, to check that it can be created without raising any errors.
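As a rough illustration, a smoke test can be as small as the following. `TinyClassifier` is a hypothetical stand-in for the class under test, so the sketch runs without Flash installed.

```python
class TinyClassifier:
    """Hypothetical stand-in for a task class under test."""

    def __init__(self, num_features: int, num_classes: int):
        if num_features <= 0 or num_classes <= 0:
            raise ValueError("num_features and num_classes must be positive")
        self.num_features = num_features
        self.num_classes = num_classes


def test_smoke():
    """Instantiate the class and check that no error is raised."""
    model = TinyClassifier(num_features=4, num_classes=3)
    assert model is not None


test_smoke()  # pytest would normally discover and call this for you
```

A smoke test does not verify behaviour, but it catches import errors and broken constructors cheaply.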
tests/examples/test_scripts.py¶
Before we write our custom tests, we should add our examples to the CI.
To do this, add a line for each example (finetuning and predict) to the annotation of test_example in tests/examples/test_scripts.py.
Here’s how those lines look for our template.py examples:
pytest.param(
    "finetuning", "template.py", marks=pytest.mark.skipif(not _SKLEARN_AVAILABLE, reason="sklearn isn't installed")
),
...
pytest.param(
    "predict", "template.py", marks=pytest.mark.skipif(not _SKLEARN_AVAILABLE, reason="sklearn isn't installed")
),
test_data.py¶
The most important tests in test_data.py check that the from_* methods work correctly.
In the class TestTemplateData, we have two of these: test_from_numpy and test_from_sklearn.
In general, there should be one test_from_* method for each data_source you have configured.
Here’s the code for test_from_numpy:
def test_from_numpy(self):
    """Tests that ``TemplateData`` is properly created when using the ``from_numpy`` method."""
    data = np.random.rand(10, self.num_features)
    targets = np.random.randint(0, self.num_classes, (10,))

    # instantiate the data module
    dm = TemplateData.from_numpy(
        train_data=data,
        train_targets=targets,
        val_data=data,
        val_targets=targets,
        test_data=data,
        test_targets=targets,
        batch_size=2,
        num_workers=0,
    )
    assert dm is not None
    assert dm.train_dataloader() is not None
    assert dm.val_dataloader() is not None
    assert dm.test_dataloader() is not None

    # check training data
    data = next(iter(dm.train_dataloader()))
    rows, targets = data[DefaultDataKeys.INPUT], data[DefaultDataKeys.TARGET]
    assert rows.shape == (2, self.num_features)
    assert targets.shape == (2,)

    # check val data
    data = next(iter(dm.val_dataloader()))
    rows, targets = data[DefaultDataKeys.INPUT], data[DefaultDataKeys.TARGET]
    assert rows.shape == (2, self.num_features)
    assert targets.shape == (2,)

    # check test data
    data = next(iter(dm.test_dataloader()))
    rows, targets = data[DefaultDataKeys.INPUT], data[DefaultDataKeys.TARGET]
    assert rows.shape == (2, self.num_features)
    assert targets.shape == (2,)
test_model.py¶
In test_model.py, we first have test_forward and test_train.
These test that tensors can be passed to the forward and that the Task can be trained.
Here’s the code for test_forward and test_train:
@pytest.mark.skipif(not _SKLEARN_AVAILABLE, reason="sklearn isn't installed")
@pytest.mark.parametrize("num_classes", [4, 256])
@pytest.mark.parametrize("shape", [(1, 3), (2, 128)])
def test_forward(num_classes, shape):
    """Tests that a tensor can be given to the model forward and gives the correct output size."""
    model = TemplateSKLearnClassifier(
        num_features=shape[1],
        num_classes=num_classes,
    )
    model.eval()

    row = torch.rand(*shape)

    out = model(row)
    assert out.shape == (shape[0], num_classes)


@pytest.mark.skipif(not _SKLEARN_AVAILABLE, reason="sklearn isn't installed")
def test_train(tmpdir):
    """Tests that the model can be trained on our ``DummyDataset``."""
    model = TemplateSKLearnClassifier(num_features=DummyDataset.num_features, num_classes=DummyDataset.num_classes)
    train_dl = torch.utils.data.DataLoader(DummyDataset(), batch_size=4)
    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
    trainer.fit(model, train_dl)
We also include tests for validating and testing: test_val and test_test.
These tests are very similar to test_train, but here they are for completeness:
@pytest.mark.skipif(not _SKLEARN_AVAILABLE, reason="sklearn isn't installed")
def test_val(tmpdir):
    """Tests that the model can be validated on our ``DummyDataset``."""
    model = TemplateSKLearnClassifier(num_features=DummyDataset.num_features, num_classes=DummyDataset.num_classes)
    val_dl = torch.utils.data.DataLoader(DummyDataset(), batch_size=4)
    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
    trainer.validate(model, val_dl)


@pytest.mark.skipif(not _SKLEARN_AVAILABLE, reason="sklearn isn't installed")
def test_test(tmpdir):
    """Tests that the model can be tested on our ``DummyDataset``."""
    model = TemplateSKLearnClassifier(num_features=DummyDataset.num_features, num_classes=DummyDataset.num_classes)
    test_dl = torch.utils.data.DataLoader(DummyDataset(), batch_size=4)
    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
    trainer.test(model, test_dl)
We also include tests for prediction named test_predict_* for each of our data sources.
In our case, we have test_predict_numpy and test_predict_sklearn.
These tests should use the data_source argument to predict() to select the required DataSource.
Here’s test_predict_sklearn as an example:
@pytest.mark.skipif(not _SKLEARN_AVAILABLE, reason="sklearn isn't installed")
def test_predict_sklearn():
    """Tests that we can generate predictions from a scikit-learn ``Bunch``."""
    bunch = datasets.load_iris()
    model = TemplateSKLearnClassifier(num_features=DummyDataset.num_features, num_classes=DummyDataset.num_classes)
    data_pipe = DataPipeline(preprocess=TemplatePreprocess())
    out = model.predict(bunch, data_source="sklearn", data_pipeline=data_pipe)
    assert isinstance(out[0], int)
Now that you’ve written the tests, it’s time to add some docs!
The Docs¶
The final step is to add some docs.
For each Task in Flash, we have a docs page in docs/source/reference.
You should create a .rst file there with the following:
a brief description of the task
the predict example
the finetuning example
any relevant API reference
Here are the contents of docs/source/reference/template.rst which breaks down each of these steps:
.. _template:

########
Template
########

********
The Task
********

Here you should add a description of your task. For example:
Classification is the task of assigning one of a number of classes to each data point.

------

*******
Example
*******

.. note::

    Here you should add a short intro to your example, and then use ``literalinclude`` to add it.
    To make it simple, you can fill in this template.

Let's look at the task of <describe the task> using the <data set used in the example>.
The dataset contains <describe the data>.
Here's an outline:

.. code-block::

    <present the folder structure of the data or some data samples here>

Once we've downloaded the data using :func:`~flash.core.data.download_data`, we create the <link to the DataModule with ``:class:``>.
We select a pre-trained backbone to use for our <link to the Task with ``:class:``> and finetune on the <name of the data set> data.
We then use the trained <link to the Task with ``:class:``> for inference.
Finally, we save the model.
Here's the full example:

<include the example with ``literalinclude``>

.. literalinclude:: ../../../flash_examples/template.py
    :language: python
    :lines: 14-
Once the docs are done, it’s finally time to open a PR and wait for some reviews!
Congratulations on adding your first Task to Flash, we hope to see you again soon!
Flash Governance | Persons of interest¶
Leads¶
William Falcon (williamFalcon)
Thomas Chaton (tchaton)
Ethan Harris (ethanwharris)
Core Maintainers¶
Jirka Borovec (Borda)
Kaushik Bokka (kaushikb11)
Justus Schock (justusschock)
Carlos Mocholí (carmocca)
Sean Narenthiran (SeanNaren)
Akihiro Nitta (akihironitta)
Aniket Maurya (aniketmaurya)
Ananya Harsh Jha (ananyahjha93)
Sivaraman Karthik Rangasai (karthikrangasai)
Contributing¶
Welcome to the PyTorch Lightning community! We’re building the most advanced research platform on the planet to implement the latest, best practices that the amazing PyTorch team rolls out!
Flash Design Principles¶
We encourage all sorts of contributions you’re interested in adding! When coding for Flash, please follow these principles.
Simple Internal Code¶
It’s useful for users to look at the code and understand very quickly what’s happening. Many users won’t be engineers. Thus we need to value clear, simple code over condensed ninja moves. While that’s super cool, this isn’t the project for that :)
Force User Decisions To Best Practices¶
There are 1,000 ways to do something. However, eventually one popular solution becomes standard practice, and everyone follows. We try to find the best way to solve a particular problem, and then force our users to use it for readability and simplicity.
When something becomes a best practice, we add it to the framework. This is usually something like bits of code in utils or in the model file that everyone keeps adding over and over again across projects. When this happens, bring that code inside the trainer and add a flag for it.
Backward-compatible API¶
We all hate updating our deep learning packages because we don’t want to refactor a bunch of stuff. In Flash, we make sure every change we make which could break an API is backward compatible with good deprecation warnings.
Gain User Trust¶
As a researcher, you can’t have any part of your code going wrong. So, make thorough tests to ensure that every implementation of a new trick or subtle change is correct.
Interoperability¶
PyTorch Lightning Flash is highly interoperable with PyTorch Lightning and PyTorch.
Contribution Types¶
We are always looking for help implementing new features or fixing bugs.
A lot of good work has already been done in project mechanics (requirements.txt, setup.py, pep8, badges, ci, etc…) so we’re in a good state there thanks to all the early contributors (even pre-beta release)!
Bug Fixes:¶
If you find a bug please submit a GitHub issue.
Make sure the title explains the issue.
Describe your setup, what you are trying to do, expected vs. actual behaviour. Please add configs and code samples.
Add details on how to reproduce the issue - a minimal test case is always best, colab is also great. Note that the sample code should be minimal and, if needed, use publicly available data.
Try to fix it or recommend a solution. We highly recommend using a test-driven approach:
Convert your minimal code example to a unit/integration test with assert on expected results.
Start by debugging the issue… You can run just this particular test in your IDE and draft a fix.
Verify that your test case fails on the master branch and only passes with the fix applied.
Submit a PR!
Note, even if you do not find the solution, sending a PR with a test covering the issue is a valid contribution and we can help you or finish it with you :]
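As a sketch of what converting a minimal example into a test can look like, consider the hypothetical helper below. Both the `val_size` function and the bug it describes are invented for illustration and do not refer to an actual Flash issue.

```python
def val_size(num_samples: int, val_split: float) -> int:
    """Hypothetical helper, shown with the fix applied: keep at least one validation sample."""
    return max(1, int(num_samples * val_split))


def test_val_size_small_dataset():
    # Minimal reproduction from the (hypothetical) report: before the fix,
    # 3 samples with val_split=0.25 produced an empty validation set.
    assert val_size(3, 0.25) == 1


def test_val_size_regular_dataset():
    # Guard against the fix changing behaviour in the common case.
    assert val_size(100, 0.25) == 25


test_val_size_small_dataset()
test_val_size_regular_dataset()
```

The first test should fail on the unfixed code and pass once the fix is applied, which is exactly the check described in the steps above.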
New Features:¶
Submit a GitHub issue - describe the motivation for the feature (adding a use case or an example is helpful).
Let’s discuss to determine the feature scope.
Submit a PR! We recommend a test-driven approach to adding new features as well:
Write a test for the functionality you want to add.
Write the functional code until the test passes.
Add/update the relevant tests!
This PR is a good example for adding a new metric, and this one for a new logger.
New Tasks:¶
Flash is a framework of tasks for fast prototyping, baselining, finetuning and solving business and scientific problems with deep learning. Following are general guidelines for adding new tasks.
Models which are standard baselines
Whose results are reproduced properly either by us or by authors.
Top models which are not SOTA but highly cited for production usage / for other uses. (E.g. Mobile BERT, MobileNets, FBNets).
Do not reinvent the wheel, natively support torchvision, torchtext, torchaudio models.
Use open source licensed models.
Please raise an issue before adding a new task. Please let us know why the particular task is important for Flash.
Test cases:¶
Want to keep Lightning Flash healthy? Love seeing those green tests? So do we! How do we keep it that way? We write tests! We value test contributions even more than new features.
Tests are written using pytest.
Have a look at sample tests here.
After you have added the respective tests, you can run the tests locally with the make script:
make test
Want to add a new test case and not sure how? Talk to us!
Guidelines¶
For this section, please refer to the parent PL guidelines.
Reminder
All added or edited code must be the original work of the contributor.
If you use a third-party implementation, all such blocks/functions/modules must be properly credited and, if possible, approved by the code’s author. For example - This code is inspired from http://...
If you are adding new dependencies, make sure that they are compatible with the actual PyTorch Lightning license (i.e. dependencies should be at least as permissive as the PyTorch Lightning license).
How to rebase my PR?¶
We recommend creating a PR in a separate branch other than master, especially if you plan to submit several changes and do not want to wait until the first one is resolved (we can work on them in parallel).
First, make sure you have set upstream by running:
git remote add upstream https://github.com/PyTorchLightning/lightning-flash.git
You’ll know it’s set up right if you run git remote -v and see something similar to this:
origin https://github.com/{YOUR_USERNAME}/lightning-flash.git (fetch)
origin https://github.com/{YOUR_USERNAME}/lightning-flash.git (push)
upstream https://github.com/PyTorchLightning/lightning-flash.git (fetch)
upstream https://github.com/PyTorchLightning/lightning-flash.git (push)
Checkout your feature branch and rebase it with upstream’s master before pushing up your feature branch:
git fetch --all --prune
git rebase upstream/master
# follow git instructions to resolve conflicts
git push -f
Question & Answer¶
How can I help/contribute?
All help is extremely welcome - reporting bugs, fixing documentation, adding test cases, solving issues and preparing bug fixes. To solve some issues you can start with the label good first issue or choose something close to your domain with the label help wanted. Before you start to implement anything, check that the issue description is clear and self-assign the task (if that is not possible, just comment that you are taking it and we will assign it to you…).
Is there a recommendation for branch names?
We do not rely on a naming convention as long as you are working with your own fork. Still, it would be nice to follow this convention:
<type>/<issue-id>_<short-name>
where the types are: bugfix, feature, docs, tests, …
I have a model in a framework other than PyTorch, how do I add it here?
Since PyTorch Lightning is written on top of PyTorch, we need models in PyTorch only. Also, we would need the same or equivalent results with PyTorch Lightning after converting the models from other frameworks.
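Returning to the branch name recommendation in the Q&A above, a quick local check could look like the sketch below. It is only an illustration: the project does not enforce the convention, the list of types covers only those named explicitly, and treating the issue id as purely numeric is an assumption.

```python
import re

# Only the types named explicitly in the convention; the original list is open-ended.
BRANCH_TYPES = ("bugfix", "feature", "docs", "tests")

# Assumes a numeric issue id and a short name made of letters, digits, "-" and "_".
BRANCH_PATTERN = re.compile(r"^(?P<type>[a-z]+)/(?P<issue>\d+)_(?P<name>[A-Za-z0-9][A-Za-z0-9_-]*)$")


def follows_convention(branch: str) -> bool:
    """Return True if ``branch`` looks like ``<type>/<issue-id>_<short-name>``."""
    match = BRANCH_PATTERN.match(branch)
    return bool(match) and match.group("type") in BRANCH_TYPES


for name in ("bugfix/123_fix-typo", "feature/45_new_task", "my-branch"):
    print(name, follows_convention(name))
```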
Changelog¶
All notable changes to this project will be documented in this file.
The format is based on Keep a Changelog.
[0.5.1] - 2021-10-26¶
[0.5.1] - Added¶
Added LabelStudio integration (#554)
Added support learn2learn training_strategy for ImageClassifier (#737)
Added vissl training_strategies for ImageEmbedder (#682)
Added support for from_data_frame to TextClassificationData (#785)
Added FastFace integration (#606)
Added support for from_lists to TextClassificationData (#805)
[0.5.1] - Changed¶
[0.5.1] - Fixed¶
[0.5.0] - 2021-09-07¶
[0.5.0] - Added¶
Added support for (input, target) style datasets (e.g. torchvision) to the from_datasets method (#552)
Added support for from_csv and from_data_frame to ImageClassificationData (#556)
Added SimCLR, SwAV, Barlow-twins pretrained weights for resnet50 backbone in ImageClassifier task (#560)
Added support for Semantic Segmentation backbones and heads from segmentation-models.pytorch (#562)
Added support for nesting of Task objects (#575)
Added PointCloudSegmentation Task (#566)
Added PointCloudObjectDetection Task (#600)
Added a GraphClassifier task (#73)
Added the option to pass pretrained as a string to SemanticSegmentation to change pretrained weights to load from segmentation-models.pytorch (#587)
Added support for field parameter for loading JSON based datasets in text tasks. (#585)
Added AudioClassificationData and an example for classifying audio spectrograms (#594)
Added a SpeechRecognition task for speech to text using Wav2Vec (#586)
Added Flash Zero, a zero code command line ML platform built with flash (#611)
Added support for .npy and .npz files to ImageClassificationData and AudioClassificationData (#651)
Added support for from_csv to the AudioClassificationData (#651)
Added option to pass a resolver to the from_csv and from_pandas methods of ImageClassificationData, which is used to resolve filenames given IDs (#651)
Added integration with IceVision for the ObjectDetector (#608)
Added keypoint detection task (#608)
Added instance segmentation task (#608)
Added Torch ORT support to Transformer based tasks (#667)
Added support for flash zero with the InstanceSegmentation and KeypointDetector tasks (#672)
Added support for in_chans argument to the flash ResNet to control the expected number of input channels (#673)
Added a QuestionAnswering task for extractive question answering (#607)
Added automatic unwrapping of IceVision prediction objects (#727)
Added support for the ObjectDetector with FiftyOne (#727)
Added support for MP3 files to the SpeechRecognition task with librosa (#726)
Added support for from_numpy and from_tensors to AudioClassificationData (#745)
[0.5.0] - Changed¶
Changed how pretrained flag works for loading weights for ImageClassifier task (#560)
Removed bolts pretrained weights for SSL from ImageClassifier task (#560)
Changed the behaviour of the sampler argument of the DataModule to take a Sampler type rather than instantiated object (#651)
Changed arguments to ObjectDetector, use head instead of model and append _fpn to the backbone name instead of the fpn argument (#608)
[0.5.0] - Fixed¶
Fixed a bug where serve sanity checking would not be triggered using the latest PyTorchLightning version (#493)
Fixed a bug where train and validation metrics weren’t being correctly computed (#559)
Fixed a bug where an uncaught ValueError could be raised when checking if a module is available (#615)
Fixed a bug where some tasks were not compatible with PyTorch 1.7 due to use of torch.jit.isinstance (#611)
Fixed a bug where custom samplers would not be properly forwarded to the data loader (#651)
Fixed a bug where it was not possible to pass no metrics to the ImageClassifier or TextClassifier (#660)
Fixed a bug where drop_last would be set to True during prediction and testing (#671)
Fixed a bug where flash was not compatible with pytorch-lightning >= 1.4.3 (#690)
[0.4.0] - 2021-06-22¶
[0.4.0] - Added¶
Added integration with FiftyOne (#360)
Added flash.serve (#399)
Added support for torch.jit to tasks where possible and documented task JIT compatibility (#389)
Added option to provide a Sampler to the DataModule to use when creating a DataLoader (#390)
Added support for multi-label text classification and toxic comments example (#401)
Added a sanity checking feature to flash.serve (#423)
[0.4.0] - Changed¶
Split backbone argument to SemanticSegmentation into backbone and head arguments (#412)
[0.4.0] - Fixed¶
[0.3.2] - 2021-06-08¶
[0.3.2] - Fixed¶
Fixed a bug where flash.Trainer.from_argparse_args + finetune would not work (#382)
[0.3.1] - 2021-06-08¶
[0.3.1] - Added¶
Added deeplabv3, lraspp, and unet backbones for the SemanticSegmentation task (#370)
[0.3.1] - Changed¶
[0.3.1] - Deprecated¶
Deprecated SemanticSegmentation backbone names torchvision/fcn_resnet50 and torchvision/fcn_resnet101, use fcn_resnet50 and fcn_resnet101 instead (#370)
[0.3.1] - Fixed¶
Fixed flash.Trainer.add_argparse_args not adding any arguments (#343)
Fixed a bug where the translation task wasn’t decoding tokens properly (#332)
Fixed a bug where huggingface tokenizers were sometimes being pickled (#332)
Fixed an issue with KorniaParallelTransforms to ensure the random state is shared between transforms (#351)
Fixed a bug where using val_split with overfit_batches would give an infinite recursion (#375)
Fixed a bug where some timm models were mistakenly given a global_pool argument (#377)
Fixed flash.Trainer.from_argparse_args not passing arguments correctly (#380)
[0.3.0] - 2021-05-20¶
[0.3.0] - Added¶
Added timm integration (#196)
Added BaseViz Callback (#201)
Added backbone API (#204)
Added support for Iterable auto dataset (#227)
Added multi label support (#230)
Added support for schedulers (#232)
Added visualisation callback for image classification (#228)
Added Video Classification task (#216)
Added Dino backbone for image classification (#259)
Refactor preprocess_cls to preprocess, add Serializer, add DataPipelineState (#229)
Added Object detection prediction example (#283)
Added Style Transfer task and accompanying finetuning and prediction examples (#262)
Added a Template task and tutorials showing how to contribute a task to flash (#306)
[0.3.0] - Changed¶
[0.3.0] - Fixed¶
[0.2.3] - 2021-04-17¶
[0.2.3] - Added¶
Added TIMM integration as backbones (#196)
[0.2.3] - Fixed¶
Fixed nltk.download (#210)
[0.2.2] - 2021-04-05¶
[0.2.2] - Changed¶
[0.2.2] - Fixed¶
[0.2.1] - 2021-03-06¶
[0.2.1] - Added¶
[0.2.1] - Changed¶
Set inputs as optional (#109)
[0.2.1] - Fixed¶
[0.2.0] - 2021-02-12¶
[0.2.0] - Added¶
[0.2.0] - Changed¶
[0.2.0] - Fixed¶
[0.2.0] - Removed¶
[0.1.0] - 2021-02-02¶
[0.1.0] - Added¶
Template¶
The Task¶
Here you should add a description of your task. For example: Classification is the task of assigning one of a number of classes to each data point.
Example¶
Note
Here you should add a short intro to your example, and then use literalinclude to add it.
To make it simple, you can fill in this template.
Let’s look at the task of <describe the task> using the <data set used in the example>. The dataset contains <describe the data>. Here’s an outline:
<present the folder structure of the data or some data samples here>
Once we’ve downloaded the data using download_data(), we create the <link to the DataModule with :class:>.
We select a pre-trained backbone to use for our <link to the Task with :class:> and finetune on the <name of the data set> data.
We then use the trained <link to the Task with :class:> for inference.
Finally, we save the model.
Here’s the full example:
<include the example with literalinclude>
import numpy as np
import torch
from sklearn import datasets
import flash
from flash.template import TemplateData, TemplateSKLearnClassifier
# 1. Create the DataModule
datamodule = TemplateData.from_sklearn(
    train_bunch=datasets.load_iris(),
    val_split=0.1,
)
# 2. Build the task
model = TemplateSKLearnClassifier(num_features=datamodule.num_features, num_classes=datamodule.num_classes)
# 3. Create the trainer and train the model
trainer = flash.Trainer(max_epochs=3, gpus=torch.cuda.device_count())
trainer.fit(model, datamodule=datamodule)
# 4. Classify a few examples
predictions = model.predict(
    [
        np.array([4.9, 3.0, 1.4, 0.2]),
        np.array([6.9, 3.2, 5.7, 2.3]),
        np.array([7.2, 3.0, 5.8, 1.6]),
    ]
)
print(predictions)
# 5. Save the model!
trainer.save_checkpoint("template_model.pt")