
Lightning Flash




Quick Start

Flash is a high-level deep learning framework for fast prototyping, baselining, finetuning, and solving deep learning problems. It features a set of tasks that you can use for inference and finetuning out of the box, and an easy-to-implement API for customizing every step of the process for full flexibility.

Flash is built for beginners with a simple API that requires very little deep learning background, and for data scientists, Kagglers, applied ML practitioners, and deep learning researchers who want a quick way to get a deep learning baseline with the advanced features PyTorch Lightning offers.

Why Flash?

For getting started with Deep Learning

Easy to learn

If you are just getting started with deep learning, Flash offers common deep learning tasks you can use out-of-the-box in a few lines of code, no math, fancy nn.Modules or research experience required!

Easy to scale

Flash is built on top of PyTorch Lightning, a powerful deep learning research framework for training models at scale. With the power of Lightning, you can train your flash tasks on any hardware: CPUs, GPUs or TPUs without any code changes.

Easy to upskill

If you want to create more complex and customized models, you can refactor any part of flash with PyTorch or PyTorch Lightning components to get all the flexibility you need. Lightning is just organized PyTorch with the unnecessary engineering details abstracted away.

  • Flash (high-level)

  • Lightning (mid-level)

  • PyTorch (low-level)

When you need more flexibility you can build your own tasks or simply use Lightning directly.
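As a sketch of what building your own task can look like (assuming the generic flash.Task accepts a model, a loss function, and an optimizer; the architecture and hyper-parameters here are illustrative), you can wrap a plain PyTorch model in a Task:

import torch
from torch import nn

import flash

# a plain PyTorch model (illustrative architecture)
net = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 128), nn.ReLU(), nn.Linear(128, 10))

# wrapping it in the generic Task gives it training, validation and prediction loops
task = flash.Task(net, loss_fn=nn.functional.cross_entropy, optimizer="Adam", learning_rate=1e-3)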

For Deep learning research

Quickest way to a baseline

PyTorch Lightning is designed to abstract away unnecessary boilerplate while enabling maximal flexibility. In order to provide full flexibility, solving very common deep learning problems such as classification in Lightning still requires some boilerplate, and it can still take quite some time to get a baseline model running on a new dataset or an out-of-domain task. We created Flash to answer our users' need for a super quick way to baseline with Lightning, using proven backbones for common data patterns. Flash aims to be the easiest starting point for your research: start with a Flash Task to benchmark against, and override any part of Flash with Lightning or PyTorch components on your way to SOTA research.

Flexibility where you want it

Flash tasks are essentially LightningModules, and the Flash Trainer is a thin wrapper around the Lightning Trainer. You can use your own LightningModule instead of the Flash task, the Lightning Trainer instead of the Flash Trainer, etc. Flash helps you focus more on your research, and less on everything else.
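As a sketch of this interchangeability (the data paths and hyper-parameters are illustrative), a Flash task can be trained with the plain Lightning Trainer directly:

import pytorch_lightning as pl

from flash.image import ImageClassificationData, ImageClassifier

datamodule = ImageClassificationData.from_folders(
    train_folder="data/hymenoptera_data/train/",
    batch_size=4,
)
model = ImageClassifier(backbone="resnet18", num_classes=datamodule.num_classes)

# a Flash task is a LightningModule, so the vanilla Lightning Trainer can train it
trainer = pl.Trainer(max_epochs=1)
trainer.fit(model, datamodule=datamodule)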

Standard best practices

Flash tasks implement the standard best practices for a variety of different models and domains, to save you time digging through different implementations. Flash abstracts even more details than Lightning, allowing deep learning experts to share their tips and tricks for solving scoped deep learning problems.


Tasks

Flash is composed of a collection of Tasks. The Flash tasks are laser-focused objects designed to solve a well-defined type of problem, using state-of-the-art methods.

The Flash tasks contain all the relevant information to solve the task at hand: the number of class labels you want to predict, the number of columns in your dataset, as well as details of the model architecture used, such as the loss function, optimizers, etc.

Here are examples of tasks:

from flash.text import TextClassifier
from flash.image import ImageClassifier
from flash.tabular import TabularClassifier

Note

Tasks are inflexible by definition! To get more flexibility, you can simply use LightningModule directly or modify an existing task in just a few lines.
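For example, here is a minimal sketch (relying only on standard LightningModule hooks; the backbone and hyper-parameters are illustrative) of modifying an existing task by overriding one of its hooks:

import torch

from flash.image import ImageClassifier


class MyImageClassifier(ImageClassifier):
    # override a standard LightningModule hook to swap in a custom optimizer
    def configure_optimizers(self):
        return torch.optim.AdamW(self.parameters(), lr=1e-3, weight_decay=1e-2)


model = MyImageClassifier(backbone="resnet18", num_classes=2)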


Inference

Inference is the process of generating predictions from trained models. To use a task for inference:

  1. Init your task with pretrained weights using a checkpoint (a checkpoint is simply a file that captures the exact value of all parameters used by a model). A local file or a URL works.

  2. Load your data into a DataModule and pass it to Trainer.predict.


Here’s an example of inference:

# import our libraries
from flash import Trainer
from flash.text import TextClassifier, TextClassificationData

# 1. Init the finetuned task from URL
model = TextClassifier.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/0.7.0/text_classification_model.pt")

# 2. Perform inference from list of sequences
trainer = Trainer()
datamodule = TextClassificationData.from_lists(
    predict_data=[
        "Turgid dialogue, feeble characterization - Harvey Keitel a judge?.",
        "The worst movie in the history of cinema.",
        "This guy has done a great job with this movie!",
    ],
    batch_size=4,
)
predictions = trainer.predict(model, datamodule=datamodule, output="labels")
print(predictions)

We get the following output:

[["negative", "negative", "positive"]]

Finetuning

Finetuning (or transfer-learning) is the process of tweaking a model trained on a large dataset, to your particular (likely much smaller) dataset. All Flash tasks have pre-trained backbones that are already trained on large datasets such as ImageNet. Finetuning on pretrained models decreases training time significantly.

To use a Task for finetuning:

  1. Load your data and organize it using a DataModule customized for the task (example: ImageClassificationData).

  2. Choose and initialize your Task which has state-of-the-art backbones built in (example: ImageClassifier).

  3. Init a flash.core.trainer.Trainer.

  4. Choose a finetune strategy (example: “freeze”) and call flash.core.trainer.Trainer.finetune() with your data.

  5. Save your finetuned model.


Here’s an example of finetuning.

import torch
from pytorch_lightning import seed_everything

import flash
from flash.core.classification import LabelsOutput
from flash.core.data.utils import download_data
from flash.image import ImageClassificationData, ImageClassifier

# set the random seeds.
seed_everything(42)

# 1. Download and organize the data
download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip", "data/")

datamodule = ImageClassificationData.from_folders(
    train_folder="data/hymenoptera_data/train/",
    val_folder="data/hymenoptera_data/val/",
    test_folder="data/hymenoptera_data/test/",
    batch_size=1,
)

# 2. Build the model using desired Task
model = ImageClassifier(backbone="resnet18", labels=datamodule.labels)

# 3. Create the trainer (run one epoch for demo)
trainer = flash.Trainer(max_epochs=1, gpus=torch.cuda.device_count())

# 4. Finetune the model
trainer.finetune(model, datamodule=datamodule, strategy="freeze")

# 5. Save the model!
trainer.save_checkpoint("image_classification_model.pt")

Using a finetuned model

Once you’ve finetuned, use the model to predict:

predict_datamodule = ImageClassificationData.from_files(
    predict_files=[
        "data/hymenoptera_data/val/bees/65038344_52a45d090d.jpg",
        "data/hymenoptera_data/val/ants/2255445811_dabcdf7258.jpg",
    ],
    batch_size=1,
)
predictions = trainer.predict(model, datamodule=predict_datamodule, output="labels")
print(predictions)

We get the following output:

[['bees', 'ants']]

Or you can use the saved model for prediction anywhere you want!

from flash import Trainer
from flash.image import ImageClassifier, ImageClassificationData

# load finetuned checkpoint
model = ImageClassifier.load_from_checkpoint("image_classification_model.pt")

trainer = Trainer()
datamodule = ImageClassificationData.from_files(predict_files=["path/to/your/own/image.png"])
predictions = trainer.predict(model, datamodule=datamodule)

Training

When you have enough data, you’re likely better off training from scratch instead of finetuning.

To train a task from scratch:

  1. Load your data and organize it using a DataModule customized for the task (example: ImageClassificationData).

  2. Choose and initialize your Task (setting pretrained=False) which has state-of-the-art backbones built in (example: ImageClassifier).

  3. Init a flash.core.trainer.Trainer or a pytorch_lightning.trainer.Trainer.

  4. Call flash.core.trainer.Trainer.fit() with your data set.

  5. Save your trained model.


Here’s an example:

import torch
from pytorch_lightning import seed_everything

import flash
from flash.core.classification import LabelsOutput
from flash.core.data.utils import download_data
from flash.image import ImageClassificationData, ImageClassifier

# set the random seeds.
seed_everything(42)

# 1. Download and organize the data
download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip", "data/")

datamodule = ImageClassificationData.from_folders(
    train_folder="data/hymenoptera_data/train/",
    val_folder="data/hymenoptera_data/val/",
    test_folder="data/hymenoptera_data/test/",
    batch_size=1,
)

# 2. Build the model using desired Task
model = ImageClassifier(backbone="resnet18", num_classes=datamodule.num_classes, pretrained=False)

# 3. Create the trainer (run one epoch for demo)
trainer = flash.Trainer(max_epochs=1, gpus=torch.cuda.device_count())

# 4. Train the model
trainer.fit(model, datamodule=datamodule)

# 5. Save the model!
trainer.save_checkpoint("image_classification_model.pt")

A few Built-in Tasks

More tasks coming soon!

Contribute a task

The Lightning + Flash team is hard at work building more tasks for common deep-learning use cases. But we're looking for incredible contributors like you to submit new tasks!

Join our Slack to get help becoming a contributor!

Installation & Troubleshooting

Installation options

Install with pip

pip install lightning-flash

Optionally, you can install Flash with extra packages for each domain.

For a single domain, use: pip install 'lightning-flash[{DOMAIN}]'.

pip install 'lightning-flash[image]'
pip install 'lightning-flash[tabular]'
pip install 'lightning-flash[text]'
...

For multiple domains, use: pip install 'lightning-flash[{DOMAIN_1, DOMAIN_2, ...}]'.

pip install 'lightning-flash[audio,image]'
...

For contributors, please install Flash with packages for testing Flash and building docs.

# Clone Flash repository locally
git clone https://github.com/[your username]/lightning-flash.git
cd lightning-flash

# Install Flash in editable mode with extra packages for development
pip install -e '.[dev]'

Install with conda

Flash is available via conda forge. Install it with:

conda install -c conda-forge lightning-flash

Install from source

You can install Flash from source without any domain specific dependencies with:

pip install 'git+https://github.com/PyTorchLightning/lightning-flash.git'

To install Flash with domain dependencies, use:

pip install 'git+https://github.com/PyTorchLightning/lightning-flash.git#egg=lightning-flash[image]'

You can again install dependencies for multiple domains by separating them with commas as above.

Troubleshooting

Torchtext incompatibility

If you install Flash in an environment that already has a version of torchtext installed, you may see an error like this when you try to import it:

ImportError: /usr/local/lib/python3.7/dist-packages/torchtext/_torchtext.so: undefined symbol: _ZN2at6detail10noopDeleteEPv

The workaround is to uninstall torchtext before installing Flash, like this:

pip uninstall -y torchtext
pip install lightning-flash[...]

FiftyOne incompatibility on Google Colab

When installing Flash (or PyTorch Lightning) alongside FiftyOne in a Google Colab environment, you may get the following error when importing FiftyOne:

ServiceListenTimeout: fiftyone.core.service.DatabaseService failed to bind to port

There is no known workaround for this issue at the time of writing, but you can view the latest updates on the associated GitHub issue.

Flash Zero

Flash Zero is a zero-code machine learning platform. Here's an image classification example to illustrate, using one of the dozens of tasks available.

Flash Zero in 3 steps

1. Select your task

flash {TASK_NAME}

Here is the list of currently supported tasks.

audio_classification     Classify audio spectrograms.
graph_classification     Classify graphs.
image_classification     Classify images.
instance_segmentation    Segment object instances in images.
keypoint_detection       Detect keypoints in images.
object_detection         Detect objects in images.
pointcloud_detection     Detect objects in point clouds.
pointcloud_segmentation  Segment objects in point clouds.
question_answering       Extractive Question Answering.
semantic_segmentation    Segment objects in images.
speech_recognition       Speech recognition.
style_transfer           Image style transfer.
summarization            Summarize text.
tabular_classification   Classify tabular data.
text_classification      Classify text.
translation              Translate text.
video_classification     Classify videos.

2. Pass in your own data

flash image_classification from_folders --train_folder data/hymenoptera_data/train

3. Modify the model and training parameters

flash image_classification --trainer.max_epochs 10 --model.backbone resnet50 from_folders --train_folder data/hymenoptera_data/train

Note

The trainer and model arguments should be placed before the source subcommand, which here is from_folders.

Other Examples

Image Object Detection

To train an Object Detector on the COCO 2017 dataset, you could use the following command:

flash object_detection from_coco --train_folder data/coco128/images/train2017/ --train_ann_file data/coco128/annotations/instances_train2017.json --val_split .3 --batch_size 8 --num_workers 4

Image Object Segmentation

To train an Image Segmenter on the CARLA driving simulator dataset:

flash semantic_segmentation from_folders --train_folder data/CameraRGB --train_target_folder data/CameraSeg --num_classes 21

Below is an example where the head, the backbone and its pretrained weights are customized.

flash semantic_segmentation --model.head fpn --model.backbone efficientnet-b0 --model.pretrained advprop from_folders --train_folder data/CameraRGB --train_target_folder data/CameraSeg --num_classes 21

Video Classification

To train a Video Classifier on the Kinetics dataset, you could use the following command:

flash video_classification from_folders --train_folder data/kinetics/train/ --clip_duration 1 --num_workers 0

CLI options

Flash Zero is built on top of the lightning CLI, so the trainer and model arguments can be configured either from the command line or from a config file. For example, to run the image classifier for 10 epochs with a resnet50 backbone you can use:

flash image_classification --trainer.max_epochs 10 --model.backbone resnet50

To view all of the available options for a task, run:

flash image_classification --help

Using Your Own Data

Flash Zero works with your own data through subcommands. The available subcommands for each task are given at the bottom of their help pages (e.g. when running flash image_classification --help). You can then use the required subcommand to train on your own data. Let's look at an example using the Hymenoptera data from the Image Classification guide. First, download and unzip your data:

curl https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip -o hymenoptera_data.zip
unzip hymenoptera_data.zip

Now train with Flash Zero:

flash image_classification from_folders --train_folder ./hymenoptera_data/train

Getting Help

To find all available tasks, you can run:

flash --help

This will output the following:

Commands:
audio_classification     Classify audio spectrograms.
graph_classification     Classify graphs.
image_classification     Classify images.
instance_segmentation    Segment object instances in images.
keypoint_detection       Detect keypoints in images.
object_detection         Detect objects in images.
pointcloud_detection     Detect objects in point clouds.
pointcloud_segmentation  Segment objects in point clouds.
question_answering       Extractive Question Answering.
semantic_segmentation    Segment objects in images.
speech_recognition       Speech recognition.
style_transfer           Image style transfer.
summarization            Summarize text.
tabular_classification   Classify tabular data.
text_classification      Classify text.
translation              Translate text.
video_classification     Classify videos.

To get more information about a specific task, you can do the following:

flash image_classification --help

You can view the help page for each subcommand. For example, to view the options for training an image classifier from folders, you can run:

flash image_classification from_folders --help

Finally, you can generate a config.yaml file from the CLI, to make it easier to modify parameters, by running:

flash image_classification --print_config > config.yaml
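Since Flash Zero is built on the Lightning CLI, the generated file can then be passed back in with --config (a sketch assuming the standard Lightning CLI behaviour):

# edit config.yaml, then reuse it for training
flash image_classification --config config.yaml from_folders --train_folder data/hymenoptera_data/train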

Flash in Production

Flash Serve

Flash Serve makes model deployment simple.

Server Side

from flash.image import SemanticSegmentation
from flash.image.segmentation.output import SegmentationLabelsOutput

model = SemanticSegmentation.load_from_checkpoint(
    "https://flash-weights.s3.amazonaws.com/0.7.0/semantic_segmentation_model.pt"
)
model.output = SegmentationLabelsOutput(visualize=False)
model.serve()

Client Side

import base64
from pathlib import Path

import requests

import flash

with (Path(flash.ASSETS_ROOT) / "road.png").open("rb") as f:
    imgstr = base64.b64encode(f.read()).decode("UTF-8")

body = {"session": "UUID", "payload": {"inputs": {"data": imgstr}}}
resp = requests.post("http://127.0.0.1:8000/predict", json=body)
print(resp.json())

Credits to @rlizzo, @hhsecond, @lantiga, @luiscape for building the Flash Serve Engine. Read all about it here.

Electricity Price Forecasting with N-BEATS

This tutorial covers using Lightning Flash and its integration with PyTorch Forecasting to train an autoregressive model (N-BEATS) on hourly electricity pricing data. We show how the built-in interpretability tools from PyTorch Forecasting can be used with Flash to plot the trend and daily seasonality in our data, as discovered by the model. We also cover how features from PyTorch Lightning, such as the learning rate finder, can be used easily with Flash. As a bonus, we show that we can resample daily observations from the data to discover weekly trends instead.


Open In Colab

Give us a ⭐ on Github | Check out the documentation | Join us on Slack

Setup

This notebook requires some packages besides pytorch-lightning.

[1]:
! pip install --quiet "pytorch-lightning>=1.3" "ipython[notebook]" "torch>=1.6, <1.9" "pandas==1.1.5" "torchmetrics>=0.3" "lightning-flash[tabular]>=0.6.0"

In this tutorial we'll look at using Lightning Flash and its integration with PyTorch Forecasting for autoregressive modelling of electricity prices using the N-BEATS model. We'll start by using N-BEATS to uncover daily patterns (seasonality) from hourly observations and then show how we can resample daily averages to uncover weekly patterns too.

Along the way, we’ll see how the built-in tools from PyTorch Lightning, like the learning rate finder, can be used seamlessly with Flash to help make the process of putting a model together as smooth as possible.

[2]:

import os
from typing import Any, Dict

import flash
import matplotlib.pyplot as plt
import pandas as pd
import torch
from flash.core.data.utils import download_data
from flash.core.integrations.pytorch_forecasting import convert_predictions
from flash.tabular.forecasting import TabularForecaster, TabularForecastingData

DATASET_PATH = os.environ.get("PATH_DATASETS", "data/")

Loading the data

We’ll use the Spanish hourly energy demand generation and weather data set from Kaggle: https://www.kaggle.com/nicholasjhana/energy-consumption-generation-prices-and-weather

First, download the data:

[3]:
download_data("https://pl-flash-data.s3.amazonaws.com/kaggle_electricity.zip", DATASET_PATH)
/usr/local/lib/python3.9/dist-packages/urllib3/connectionpool.py:1013: InsecureRequestWarning: Unverified HTTPS request is being made to host 'pl-flash-data.s3.amazonaws.com'. Adding certificate verification is strongly advised. See: https://urllib3.readthedocs.io/en/1.26.x/advanced-usage.html#ssl-warnings
  warnings.warn(

Data loading

To load the data, we start by loading the CSV file into a pandas DataFrame:

[4]:
df_energy_hourly = pd.read_csv(f"{DATASET_PATH}/energy_dataset.csv", parse_dates=["time"])

Before we can load the data into Flash, there are a few preprocessing steps we need to take. The first preprocessing step is to set the time field as the index (formatted as a datetime). The second step is to resample the data to the desired observation frequency, in case the raw data was recorded at a different frequency. Since we are performing autoregressive modelling, we can remove all columns except for "price actual".

For the third preprocessing step, we need to create a “time_idx” column. The “time_idx” column should contain integers corresponding to the observation index (e.g. in our case the difference between two “time_idx” values is the number of hours between the observations). To do this we convert the datetime to an index by taking the nanoseconds value and dividing by the number of nanoseconds in a single unit of our chosen frequency. We then subtract the minimum value so it starts at zero (although it would still work without this step).

The Flash TabularForecastingData (which uses the TimeSeriesDataSet from PyTorch Forecasting internally) also supports loading data from multiple time series (e.g. you may have electricity data from multiple countries). To indicate that our data all comes from the same series, we add a "constant" column with a fixed value of zero.

Here’s the full preprocessing function:

[5]:


def preprocess(df: pd.DataFrame, frequency: str = "1H") -> pd.DataFrame:
    df["time"] = pd.to_datetime(df["time"], utc=True, infer_datetime_format=True)
    df.set_index("time", inplace=True)
    df = df.resample(frequency).mean()
    df = df.filter(["price actual"])
    df["time_idx"] = (df.index.view(int) / pd.Timedelta(frequency).value).astype(int)
    df["time_idx"] -= df["time_idx"].min()
    df["constant"] = 0
    return df


df_energy_hourly = preprocess(df_energy_hourly)

Creating the Flash DataModule

Now, we can create a TabularForecastingData. The role of the TabularForecastingData is to split up our time series into windows which include a region to encode (of size max_encoder_length) and a region to predict (of size max_prediction_length) which will be used to compute the loss. The size of the prediction window should be chosen depending on the kinds of trends we would like our model to uncover. In our case, we are interested in how electricity prices change throughout the day, so a one day prediction window (max_prediction_length = 24) makes sense here. The size of the encoding window can vary, however, in the N-BEATS paper the authors suggest using an encoder length of between two and ten times the prediction length. We therefore choose two days (max_encoder_length = 48) as the encoder length.

[6]:
max_prediction_length = 24
max_encoder_length = 24 * 2

training_cutoff = df_energy_hourly["time_idx"].max() - max_prediction_length

datamodule = TabularForecastingData.from_data_frame(
    time_idx="time_idx",
    target="price actual",
    group_ids=["constant"],
    max_encoder_length=max_encoder_length,
    max_prediction_length=max_prediction_length,
    time_varying_unknown_reals=["price actual"],
    train_data_frame=df_energy_hourly[df_energy_hourly["time_idx"] <= training_cutoff],
    val_data_frame=df_energy_hourly,
    batch_size=256,
)
/home/AzDevOps_azpcontainer/.local/lib/python3.9/site-packages/pytorch_lightning/core/datamodule.py:73: LightningDeprecationWarning: DataModule property `train_transforms` was deprecated in v1.5 and will be removed in v1.7.
  rank_zero_deprecation(

Creating the Flash Task

Now, we're ready to create a TabularForecaster. The N-BEATS model has two primary hyper-parameters: "widths" and "backcast_loss_ratio". In the PyTorch Forecasting documentation, the authors recommend using "widths" of [32, 512]. In order to prevent overfitting with smaller datasets, a good rule of thumb is to limit the number of parameters of your model. For this reason, we use "widths" of [16, 256].

To understand the "backcast_loss_ratio", let’s take a look at this diagram of the model taken from the arXiv paper:

N-BEATS diagram

Each ‘block’ within the N-BEATS architecture includes a forecast output and a backcast which can each yield their own loss. The "backcast_loss_ratio" is the ratio of the backcast loss to the forecast loss. A value of 1.0 means that the loss function is simply the sum of the forecast and backcast losses.

[7]:
model = TabularForecaster(
    datamodule.parameters, backbone="n_beats", backbone_kwargs={"widths": [16, 256], "backcast_loss_ratio": 1.0}
)
/usr/local/lib/python3.9/dist-packages/numpy/lib/nanfunctions.py:1119: RuntimeWarning: All-NaN slice encountered
  r, k = function_base._ureduce(a, func=_nanmedian, axis=axis, out=out,
Using 'n_beats' provided by jdb78/PyTorch-Forecasting (https://github.com/jdb78/pytorch-forecasting).

Finding the learning rate

Tabular models can be particularly sensitive to the choice of learning rate. Helpfully, PyTorch Lightning provides a built-in learning rate finder that suggests a suitable learning rate automatically. To use it, we first create our Trainer. We apply gradient clipping (a common technique for tabular tasks) with gradient_clip_val=0.01 in order to help prevent our model from over-fitting. Here’s how to find the learning rate:

[8]:
trainer = flash.Trainer(
    max_epochs=3,
    gpus=int(torch.cuda.is_available()),
    gradient_clip_val=0.01,
)

res = trainer.tuner.lr_find(model, datamodule=datamodule, min_lr=1e-5)
print(f"Suggested learning rate: {res.suggestion()}")
res.plot(show=True, suggest=True).show()
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]
/home/AzDevOps_azpcontainer/.local/lib/python3.9/site-packages/pytorch_lightning/trainer/data_loading.py:116: UserWarning: The dataloader, val_dataloader 0, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 12 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(
/home/AzDevOps_azpcontainer/.local/lib/python3.9/site-packages/pytorch_lightning/trainer/data_loading.py:116: UserWarning: The dataloader, train_dataloader, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 12 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(
/home/AzDevOps_azpcontainer/.local/lib/python3.9/site-packages/pytorch_lightning/utilities/data.py:59: UserWarning: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 256. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.
  warning_cache.warn(
LR finder stopped early after 81 steps due to diverging loss.
Restoring states from the checkpoint path at /__w/1/s/lr_find_temp_model_18513590-95dc-42e0-874c-d3607a3fe945.ckpt
Suggested learning rate: 0.0007079457843841378
(Plot: learning rate finder results, with the suggested learning rate highlighted)

Once the suggested learning rate has been found, we can update our model with it:

[9]:
model.learning_rate = res.suggestion()

Training the model

Now all we have to do is train the model!

[10]:
trainer.fit(model, datamodule=datamodule)
/home/AzDevOps_azpcontainer/.local/lib/python3.9/site-packages/pytorch_lightning/core/datamodule.py:469: LightningDeprecationWarning: DataModule.setup has already been called, so it will not be called again. In v1.6 this behavior will change to always call DataModule.setup.
  rank_zero_deprecation(
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]

  | Name          | Type                      | Params
------------------------------------------------------------
0 | train_metrics | ModuleDict                | 0
1 | val_metrics   | ModuleDict                | 0
2 | test_metrics  | ModuleDict                | 0
3 | adapter       | PyTorchForecastingAdapter | 454 K
------------------------------------------------------------
454 K     Trainable params
0         Non-trainable params
454 K     Total params
1.820     Total estimated model params size (MB)
/home/AzDevOps_azpcontainer/.local/lib/python3.9/site-packages/pytorch_lightning/core/datamodule.py:469: LightningDeprecationWarning: DataModule.teardown has already been called, so it will not be called again. In v1.6 this behavior will change to always call DataModule.teardown.
  rank_zero_deprecation(

Plot the interpretation

An important feature of the N-BEATS model is that it can be configured to produce an interpretable prediction that is split into both a low frequency (trend) component and a high frequency (seasonality) component. For hourly observations, we might expect the trend component to show us how electricity prices are changing from one day to the next (for example, whether prices were generally higher or lower than yesterday). In contrast, the seasonality component would be expected to show us the general pattern in prices through the day (for example, if there is typically a peak in price around lunch time or a drop at night).

It is often useful to visualize this decomposition and the TabularForecaster makes it simple. First, we load the best model from our training run and generate some predictions. Next, we convert the predictions to the format expected by PyTorch Forecasting using the convert_predictions utility function. Finally, we plot the interpretation using the pytorch_forecasting_model attribute. Here’s the full function:

[11]:


def plot_interpretation(model_path: str, predict_df: pd.DataFrame, parameters: Dict[str, Any]):
    model = TabularForecaster.load_from_checkpoint(model_path)
    datamodule = TabularForecastingData.from_data_frame(
        parameters=parameters,
        predict_data_frame=predict_df,
        batch_size=256,
    )
    trainer = flash.Trainer(gpus=int(torch.cuda.is_available()))
    predictions = trainer.predict(model, datamodule=datamodule)
    predictions, inputs = convert_predictions(predictions)
    model.pytorch_forecasting_model.plot_interpretation(inputs, predictions, idx=0)
    plt.show()

And now we run the function to plot the trend and seasonality curves:

[12]:
plot_interpretation(trainer.checkpoint_callback.best_model_path, df_energy_hourly, datamodule.parameters)
/usr/local/lib/python3.9/dist-packages/numpy/lib/nanfunctions.py:1119: RuntimeWarning: All-NaN slice encountered
  r, k = function_base._ureduce(a, func=_nanmedian, axis=axis, out=out,
Using 'n_beats' provided by jdb78/PyTorch-Forecasting (https://github.com/jdb78/pytorch-forecasting).
/home/AzDevOps_azpcontainer/.local/lib/python3.9/site-packages/pytorch_lightning/core/datamodule.py:73: LightningDeprecationWarning: DataModule property `train_transforms` was deprecated in v1.5 and will be removed in v1.7.
  rank_zero_deprecation(
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]
/home/AzDevOps_azpcontainer/.local/lib/python3.9/site-packages/pytorch_lightning/trainer/data_loading.py:116: UserWarning: The dataloader, predict_dataloader 0, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 12 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(
/home/AzDevOps_azpcontainer/.local/lib/python3.9/site-packages/pytorch_lightning/loops/epoch/prediction_epoch_loop.py:172: UserWarning: Lightning couldn't infer the indices fetched for your dataloader.
  warning_cache.warn("Lightning couldn't infer the indices fetched for your dataloader.")
(Plot: N-BEATS interpretation, decomposed into trend and seasonality components)

It worked! The plot shows that the TabularForecaster does a reasonable job of modelling the time series and also breaks it down into a trend component and a seasonality component (in this case showing daily fluctuations in electricity prices).

Closing thoughts and next steps!

This tutorial has shown how Flash and PyTorch Forecasting can be used to train state-of-the-art auto-regressive forecasting models (such as N-BEATS). We’ve seen how we can influence the kinds of trends and patterns uncovered by the model by resampling the data and changing the hyper-parameters.

There are plenty of ways you could take this tutorial further. For example, you could try a more complex model, such as the temporal fusion transformer, which can handle additional inputs (the Kaggle dataset we used also includes weather data).

Alternatively, if you want to be a bit more adventurous, you could look at some of the other problems that can be solved with Lightning Flash.

Congratulations - Time to Join the Community!

Congratulations on completing this notebook tutorial! If you enjoyed this and would like to join the Lightning movement, you can do so in the following ways!

Star Lightning on GitHub

The easiest way to help our community is just by starring the GitHub repos! This helps raise awareness of the cool tools we’re building.

Join our Slack!

The best way to keep up to date on the latest advancements is to join our community! Make sure to introduce yourself and share your interests in the #general channel.

Contributions !

The best way to contribute to our community is to become a code contributor! At any time you can go to Lightning or Bolt GitHub Issues page and filter for “good first issue”.

Great thanks from the entire PyTorch Lightning team for your interest!


Training from scratch

Some Flash tasks have been pretrained on large data sets. To accelerate your training, calling the finetune() method using a pretrained backbone will fine-tune the backbone to generate a model customized to your data set and desired task.

From the Quick Start guide.

To train a task from scratch:

  1. Load your data and organize it using a DataModule customized for the task (example: ImageClassificationData).

  2. Choose and initialize your Task (setting pretrained=False) which has state-of-the-art backbones built in (example: ImageClassifier).

  3. Init a flash.core.trainer.Trainer or a pytorch_lightning.trainer.Trainer.

  4. Call flash.core.trainer.Trainer.fit() with your data set.

  5. Save your trained model.


Here’s an example:

import torch
from pytorch_lightning import seed_everything

import flash
from flash.core.classification import LabelsOutput
from flash.core.data.utils import download_data
from flash.image import ImageClassificationData, ImageClassifier

# set the random seeds.
seed_everything(42)

# 1. Download and organize the data
download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip", "data/")

datamodule = ImageClassificationData.from_folders(
    train_folder="data/hymenoptera_data/train/",
    val_folder="data/hymenoptera_data/val/",
    test_folder="data/hymenoptera_data/test/",
    batch_size=1,
)

# 2. Build the model using desired Task
model = ImageClassifier(backbone="resnet18", num_classes=datamodule.num_classes, pretrained=False)

# 3. Create the trainer (run one epoch for demo)
trainer = flash.Trainer(max_epochs=1, gpus=torch.cuda.device_count())

# 4. Train the model
trainer.fit(model, datamodule=datamodule)

# 5. Save the model!
trainer.save_checkpoint("image_classification_model.pt")

Training options

Flash tasks support many advanced training functionalities out-of-the-box, such as:

  • Limit the number of epochs

# train for 10 epochs
flash.Trainer(max_epochs=10)
  • Training on GPUs

# train on 1 GPU
flash.Trainer(gpus=1)
  • Training on multiple GPUs

# train on multiple GPUs
flash.Trainer(gpus=4)
# train on gpu 1, 3, 5 (3 gpus total)
flash.Trainer(gpus=[1, 3, 5])
  • Using mixed precision training

# Multi GPU with mixed precision
flash.Trainer(gpus=2, precision=16)
  • Training on TPUs

# Train on TPUs
flash.Trainer(tpu_cores=8)

You can add to the flash Trainer any argument from the Lightning trainer! Learn more about the Lightning Trainer here.
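For example, here is a sketch that forwards a few standard Lightning Trainer arguments through flash.Trainer (the values are illustrative):

import flash

trainer = flash.Trainer(
    max_epochs=10,
    gradient_clip_val=0.5,       # Lightning argument: clip gradients to stabilize training
    accumulate_grad_batches=4,   # Lightning argument: accumulate gradients over 4 batches
    limit_train_batches=0.25,    # Lightning argument: use 25% of the training data each epoch
)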

Finetuning

Finetuning (or transfer-learning) is the process of tweaking a model trained on a large dataset, to your particular (likely much smaller) dataset.


Terminology

Here are common terms you need to be familiar with:

Finetuning: The process of tweaking a model trained on a large dataset, to your particular (likely much smaller) dataset.

Transfer learning: The common name for finetuning.

Backbone: The neural network that was pretrained on a different dataset.

Head: Another neural network (usually smaller) that maps the backbone to your particular dataset.

Freeze: Disabling gradient updates to a model (i.e. not learning).

Unfreeze: Enabling gradient updates to a model.


Finetuning in Flash

From the Quick Start guide.

To use a Task for finetuning:

  1. Load your data and organize it using a DataModule customized for the task (example: ImageClassificationData).

  2. Choose and initialize your Task which has state-of-the-art backbones built in (example: ImageClassifier).

  3. Init a flash.core.trainer.Trainer.

  4. Choose a finetune strategy (example: “freeze”) and call flash.core.trainer.Trainer.finetune() with your data.

  5. Save your finetuned model.


Here’s an example of finetuning.

import torch
from pytorch_lightning import seed_everything

import flash
from flash.core.classification import LabelsOutput
from flash.core.data.utils import download_data
from flash.image import ImageClassificationData, ImageClassifier

# set the random seeds.
seed_everything(42)

# 1. Download and organize the data
download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip", "data/")

datamodule = ImageClassificationData.from_folders(
    train_folder="data/hymenoptera_data/train/",
    val_folder="data/hymenoptera_data/val/",
    test_folder="data/hymenoptera_data/test/",
    batch_size=1,
)

# 2. Build the model using desired Task
model = ImageClassifier(backbone="resnet18", labels=datamodule.labels)

# 3. Create the trainer (run one epoch for demo)
trainer = flash.Trainer(max_epochs=1, gpus=torch.cuda.device_count())

# 4. Finetune the model
trainer.finetune(model, datamodule=datamodule, strategy="freeze")

# 5. Save the model!
trainer.save_checkpoint("image_classification_model.pt")

Using a finetuned model

Once you’ve finetuned, use the model to predict:

predict_datamodule = ImageClassificationData.from_files(
    predict_files=[
        "data/hymenoptera_data/val/bees/65038344_52a45d090d.jpg",
        "data/hymenoptera_data/val/ants/2255445811_dabcdf7258.jpg",
    ],
    batch_size=1,
)
predictions = trainer.predict(model, datamodule=predict_datamodule, output="labels")
print(predictions)

We get the following output:

[['bees', 'ants']]

Or you can use the saved model for prediction anywhere you want!

from flash import Trainer
from flash.image import ImageClassifier, ImageClassificationData

# load finetuned checkpoint
model = ImageClassifier.load_from_checkpoint("image_classification_model.pt")

trainer = Trainer()
datamodule = ImageClassificationData.from_files(predict_files=["path/to/your/own/image.png"])
predictions = trainer.predict(model, datamodule=datamodule)

Finetune strategies

Finetuning is very task specific. Each task encodes the best finetuning practices for that task. However, Flash gives you a few default strategies for finetuning.

Finetuning operates on two things, the model backbone and the head. The backbone is the neural network that was pre-trained. The head is another neural network that bridges between the backbone and your particular dataset.

no_freeze

In this strategy, the backbone and the head are unfrozen from the beginning.

trainer.finetune(model, datamodule, strategy="no_freeze")

In pseudocode, this looks like:

backbone = Resnet50()
head = nn.Linear(...)

backbone.unfreeze()
head.unfreeze()

train(backbone, head)

freeze

The freeze strategy keeps the backbone frozen throughout.

trainer.finetune(model, datamodule, strategy="freeze")

The pseudocode looks like:

backbone = Resnet50()
head = nn.Linear(...)

# freeze backbone
backbone.freeze()
head.unfreeze()

train(backbone, head)

Advanced strategies

Every finetune strategy can also be customized.

freeze_unfreeze

The freeze_unfreeze strategy keeps the backbone frozen until a certain epoch (provided in a tuple to the strategy argument) after which the backbone will be unfrozen.

For example, to unfreeze after epoch 7:

trainer.finetune(model, datamodule, strategy=("freeze_unfreeze", 7))

Under the hood, the pseudocode looks like:

backbone = Resnet50()
head = nn.Linear(...)

# freeze backbone
backbone.freeze()
head.unfreeze()

train(backbone, head, epochs=7)

# unfreeze after 7 epochs
backbone.unfreeze()

train(backbone, head)

unfreeze_milestones

This strategy allows you to unfreeze part of the backbone at predetermined intervals.

Here’s an example where:

  • backbone starts frozen

  • at epoch 3 the last 2 layers unfreeze

  • at epoch 8 the full backbone unfreezes


trainer.finetune(model, datamodule, strategy=("unfreeze_milestones", ((3, 8), 2)))

Under the hood, the pseudocode looks like:

backbone = Resnet50()
head = nn.Linear(...)

# freeze backbone
backbone.freeze()
head.unfreeze()

train(backbone, head, epochs=3)

# unfreeze last 2 layers at epoch 3
backbone.unfreeze_last_layers(2)

train(backbone, head, epochs=8)

# unfreeze the full backbone at epoch 8
backbone.unfreeze()

train(backbone, head)

Custom Strategy

For even more customization, create your own finetuning callback. Learn more about callbacks here.

from flash.core.finetuning import FlashBaseFinetuning

# Create a finetuning callback
class FeatureExtractorFreezeUnfreeze(FlashBaseFinetuning):
    def __init__(self, unfreeze_epoch: int = 5, train_bn: bool = True):
        # this will set self.attr_names as ["backbone"]
        super().__init__("backbone", train_bn)
        self._unfreeze_epoch = unfreeze_epoch

    def finetune_function(self, pl_module, current_epoch, optimizer, opt_idx):
        # unfreeze any module you want by overriding this function

        # When ``current_epoch`` is 5, backbone will start to be trained.
        if current_epoch == self._unfreeze_epoch:
            self.unfreeze_and_add_param_group(
                pl_module.backbone,
                optimizer,
            )


# Pass the callback to trainer.finetune
trainer.finetune(model, datamodule, strategy=FeatureExtractorFreezeUnfreeze(unfreeze_epoch=5))

Predictions (inference)

You can use Flash to get predictions on pretrained or finetuned models.

First create a DataModule with some predict data, then pass it to the Trainer.predict method.

from flash import Trainer
from flash.core.data.utils import download_data
from flash.image import ImageClassifier, ImageClassificationData

# 1. Download the data set
download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip", "data/")

# 2. Load the model from a checkpoint
model = ImageClassifier.load_from_checkpoint(
    "https://flash-weights.s3.amazonaws.com/0.7.0/image_classification_model.pt"
)

# 3. Predict whether the image contains an ant or a bee
trainer = Trainer()
datamodule = ImageClassificationData.from_files(
    predict_files=["data/hymenoptera_data/val/bees/65038344_52a45d090d.jpg"]
)
predictions = trainer.predict(model, datamodule=datamodule)
print(predictions)
# out: [["bees"]]

Serializing predictions

To change the output format of predictions you can attach an Output to your Task. For example, you can choose to output probabilities (for more options see the API reference below).

from flash.core.classification import ProbabilitiesOutput
from flash.core.data.utils import download_data
from flash.image import ImageClassifier


# 1. Download the data set
download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip", "data/")

# 2. Load the model from a checkpoint
model = ImageClassifier.load_from_checkpoint(
    "https://flash-weights.s3.amazonaws.com/0.7.0/image_classification_model.pt"
)

# 3. Attach the Output
model.output = ProbabilitiesOutput()

# 4. Predict whether the image contains an ant or a bee
trainer = Trainer()
datamodule = ImageClassificationData.from_files(
    predict_files=["data/hymenoptera_data/val/bees/65038344_52a45d090d.jpg"]
)
predictions = trainer.predict(model, datamodule=datamodule)
print(predictions)
# out: [[[0.5926494598388672, 0.40735048055648804]]]

Note

PyTorch Lightning does not return predictions directly from predict when using a multi-GPU configuration (DDP). Instead you should use a pytorch_lightning.callbacks.BasePredictionWriter.
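Below is a minimal sketch of such a prediction writer (the output directory, file naming, and the two-GPU DDP setup are illustrative assumptions):

import os

import torch
from pytorch_lightning.callbacks import BasePredictionWriter

import flash


class TorchPredictionWriter(BasePredictionWriter):
    def __init__(self, output_dir: str):
        super().__init__(write_interval="epoch")
        self.output_dir = output_dir

    def write_on_epoch_end(self, trainer, pl_module, predictions, batch_indices):
        # each process writes its own shard of predictions to disk
        os.makedirs(self.output_dir, exist_ok=True)
        torch.save(predictions, os.path.join(self.output_dir, f"predictions_rank_{trainer.global_rank}.pt"))


trainer = flash.Trainer(gpus=2, strategy="ddp", callbacks=[TorchPredictionWriter("predictions/")])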

TorchScript JIT Support

We test all of our tasks for compatibility with torch.jit. This table gives a breakdown of the supported features.

Task                  torch.jit.script()  torch.jit.trace()  torch.jit.save()
ImageClassifier       Yes                 Yes                Yes
ObjectDetector        Yes                 No                 Yes
ImageEmbedder         Yes                 Yes                Yes
SemanticSegmentation  No                  Yes                Yes
StyleTransfer         No                  Yes                Yes
TabularClassifier     No                  Yes                No
TextClassifier        No                  Yes *              Yes
SummarizationTask     No                  Yes                Yes
TranslationTask       No                  Yes                Yes
VideoClassifier       No                  Yes                Yes

* with strict=False
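For example, a task listed with torch.jit.script() support can be exported like this (a sketch; the backbone, number of classes, and file name are illustrative):

import torch

from flash.image import ImageClassifier

model = ImageClassifier(backbone="resnet18", num_classes=2)
model.eval()

# script the task and save the resulting TorchScript module
scripted = torch.jit.script(model)
torch.jit.save(scripted, "image_classifier.pt")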

Data

Note

The contents of this page are currently being updated. Stay tuned!

Registry

Available Registries

Registries are Flash's internal key-value stores that map a name to a function.

In simple terms, they are advanced dictionaries that store a function under a string key.

Registries help organize code and make the functions accessible across the Flash codebase. Each Flash Task can have several registries as static attributes.

Currently, Flash uses registries internally only for backbones, but more components will be added.

1. Imports

from functools import partial

from flash import Task
from flash.core.registry import FlashRegistry

2. Init a Registry

It is good practice to associate one or more registries with a Task, as follows:

# creating a custom `Task` with its own registry
class MyImageClassifier(Task):

    backbones = FlashRegistry("backbones")

    def __init__(
        self,
        backbone: str = "resnet18",
        pretrained: bool = True,
    ):
        ...

        self.backbone, self.num_features = self.backbones.get(backbone)(pretrained=pretrained)

3. Adding new functions

Your custom functions can be registered within a FlashRegistry as a decorator or directly.

# Option 1: Used with partial.
def fn(backbone: str, pretrained: bool = True):
    # Create backbone and backbone output dimension (`num_features`)
    backbone, num_features = None, None
    return backbone, num_features


# HINT 1: Use `from functools import partial` if you want to store some arguments.
MyImageClassifier.backbones(fn=partial(fn, backbone="my_backbone"), name="username/partial_backbone")

# Option 2: Using decorator.
@MyImageClassifier.backbones(name="username/decorated_backbone")
def fn(pretrained: bool = True):
    # Create backbone and backbone output dimension (`num_features`)
    backbone, num_features = None, None
    return backbone, num_features

4. Accessing registered functions

You can now access your function from your task!

# 3.b Optional: List available backbones
print(MyImageClassifier.available_backbones())

# 4. Build the model
model = MyImageClassifier(backbone="username/decorated_backbone")

Here’s the output:

['username/decorated_backbone', 'username/partial_backbone']

5. Pre-registered backbones

Flash provides populated registries containing lots of available backbones.

Example:

from flash.image.backbones import OBJ_DETECTION_BACKBONES
from flash.image.classification.backbones import IMAGE_CLASSIFIER_BACKBONES

print(IMAGE_CLASSIFIER_BACKBONES.available_keys())
""" out:
['adv_inception_v3', 'cspdarknet53', 'cspdarknet53_iabn', 430+.., 'xception71']
"""

Flash Serve

Flash Serve is a library to easily serve models in production.

Terminology

Here are common terms you need to be familiar with:

de-serialization: Transforms data encoded as text into tensors.

inference function: A function that takes the decoded tensors and forwards them through the model to produce predictions.

serialization: Transforms the prediction tensors back into a text encoding.

ModelComponent: Contains the de-serialization, inference and serialization functions.

Servable: A helper that tracks the asset files related to a model.

Composition: Defines the computations / endpoints to create and run.

expose(): A Python decorator used to augment the ModelComponent inference function with de-serialization and serialization.

Example

In this tutorial, we will serve a ResNet-18 from the torchvision library in 3 steps.

The entire tutorial can be found under flash_examples/serve/generic.

Introduction

Traditionally, an inference pipeline is made out of 3 steps:

  • de-serialization: Transform data encoded as text into tensors.

  • inference function: A function taking the decoded tensors and forwarding them through the model to produce predictions.

  • serialization: Transform the prediction tensors back into text.

In this example, we will implement only the inference function, as Flash Serve already provides built-in de-serialization and serialization functions such as Image.

Step 1 - Create a ModelComponent

Inside inference_server.py, we will implement a ClassificationInference class, which overrides ModelComponent.

First, we need to make the following imports:

import torch
import torchvision

from flash.core.serve import Composition, Servable, ModelComponent, expose
from flash.core.serve.types import Image, Label

To implement ClassificationInference, we need to implement a method responsible for the inference function and decorate it with the expose() function.

The name of the inference method isn't constrained; we will use classify, as it is appropriate in this example.

Our classify function will take a tensor image, apply some normalization to it, and forward it through the model.

def classify(img):
    img = img.float() / 255
    mean = torch.tensor([[[0.485, 0.456, 0.406]]]).float()
    std = torch.tensor([[[0.229, 0.224, 0.225]]]).float()
    img = (img - mean) / std
    img = img.permute(0, 3, 2, 1)
    out = self.model(img)
    return out.argmax()

expose() is a Python decorator that extends the decorated function with the de-serialization and serialization steps.

Note

Flash Serve was designed this way to enable several models to be chained together by removing the decorator.

The expose() function takes 2 arguments:

  • inputs: Dictionary mapping the decorated function inputs to BaseType objects.

  • outputs: Dictionary mapping the decorated function outputs to BaseType objects.

A BaseType is a python dataclass which implements a serialize and deserialize function.

Note

Flash Serve already has several built-in BaseType implementations, such as Image or Text.

class ClassificationInference(ModelComponent):
    def __init__(self, model: Servable):
        self.model = model

    @expose(
        inputs={"img": Image()},
        outputs={"prediction": Label(path="imagenet_labels.txt")},
    )
    def classify(self, img):
        img = img.float() / 255
        mean = torch.tensor([[[0.485, 0.456, 0.406]]]).float()
        std = torch.tensor([[[0.229, 0.224, 0.225]]]).float()
        img = (img - mean) / std
        img = img.permute(0, 3, 2, 1)
        out = self.model(img)
        return out.argmax()

Step 2 - Create a scripted Model

Using the torchvision library, we create a resnet18 and use torch.jit.script to script the model.

Note

TorchScript is a way to create serializable and optimizable models from PyTorch code. Any TorchScript program can be saved from a Python process and loaded in a process where there is no Python dependency.

model = torchvision.models.resnet18(pretrained=True).eval()
torch.jit.script(model).save("resnet.pt")

Step 3 - Serve the model

The Servable takes the path to the TorchScripted model as an argument and is then passed to our ClassificationInference class.

The ClassificationInference instance is passed as an argument to a Composition class.

Once the Composition class is instantiated, just call its serve() method.

resnet = Servable("resnet.pt")
comp = ClassificationInference(resnet)
composition = Composition(classification=comp)
composition.serve()

Launching the server.

In Terminal 1

Just run:

python inference_server.py

And you should see something like this in your terminal:

(Screenshot: Flash Serve terminal output)

You should also see a Swagger UI already built for you at http://127.0.0.1:8000/docs

(Screenshot: Swagger UI)
In Terminal 2

Run this script from another terminal:

import base64
from pathlib import Path

import requests

with Path("fish.jpg").open("rb") as f:
    imgstr = base64.b64encode(f.read()).decode("UTF-8")

body = {"session": "UUID", "payload": {"img": {"data": imgstr}}}
resp = requests.post("http://127.0.0.1:8000/predict", json=body)
print(resp.json())
# {'session': 'UUID', 'result': {'prediction': 'goldfish, Carassius auratus'}}

Credits to @rlizzo, @hhsecond, @lantiga, @luiscape for building Flash Serve Engine.

Backbones and Heads

Backbones are the pretrained models that can be used with a task. The backbones or heads that are available can be found by using the available_backbones and available_heads methods.

To get the available backbones for a task like ImageClassifier, run:

from flash.image import ImageClassifier

# get the backbones available for ImageClassifier
backbones = ImageClassifier.available_backbones()

# print the backbones
print(backbones)

To get the available heads for a task like SemanticSegmentation, run:

from flash.image import SemanticSegmentation

# get the heads available for SemanticSegmentation
heads = SemanticSegmentation.available_heads()

# print the heads
print(heads)

Optimization (Optimizers and Schedulers)

Using optimizers and learning rate schedulers with Flash has become easier and cleaner than ever.

With the use of a Registry, an optimizer or a learning rate scheduler can be instantiated with just a string.

Setting an optimizer to a task

Each task has a built-in method available_optimizers() which will list all the optimizers registered with Flash.

>>> from flash.core.classification import ClassificationTask
>>> ClassificationTask.available_optimizers()  
[...'adadelta', ..., 'sgd'...]

To train / finetune a Task of your choice, just pass in a string.

>>> from flash.image import ImageClassifier
>>> model = ImageClassifier(num_classes=10, backbone="resnet18", optimizer="Adam", learning_rate=1e-4)
>>> model.configure_optimizers()  
Adam ...

In order to customize specific parameters of the optimizer, pass a tuple of the string and a dictionary of kwargs.

>>> from flash.image import ImageClassifier
>>> model = ImageClassifier(num_classes=10, backbone="resnet18", optimizer=("Adam", {"amsgrad": True}), learning_rate=1e-4)
>>> model.configure_optimizers()  
Adam ( ... amsgrad: True ...)

An alternative to customizing an optimizer using a tuple is to pass it as a callable.

>>> from functools import partial
>>> from torch.optim import Adam
>>> from flash.image import ImageClassifier
>>> model = ImageClassifier(num_classes=10, backbone="resnet18", optimizer=partial(Adam, amsgrad=True), learning_rate=1e-4)
>>> model.configure_optimizers()  
Adam ( ... amsgrad: True ...)

Setting a Learning Rate Scheduler

Each task has a built-in method available_lr_schedulers() which will list all the learning rate schedulers registered with Flash.

>>> from flash.core.classification import ClassificationTask
>>> ClassificationTask.available_lr_schedulers()  
[...'cosineannealingwarmrestarts', ..., 'lambdalr'...]

To train / finetune a Task with a scheduler of your choice, just pass in the name:

>>> from flash.image import ImageClassifier
>>> model = ImageClassifier(
...     num_classes=10, backbone="resnet18", optimizer="Adam", learning_rate=1e-4, lr_scheduler="constant_schedule"
... )
>>> model.configure_optimizers()  
([Adam ...], [{'scheduler': ...}])

Note

"constant_schedule" and a few other lr schedulers will be available only if you have installed the transformers library from Hugging Face.

In order to customize specific parameters of the LR scheduler, pass a tuple of the string and a dictionary of kwargs.

>>> from flash.image import ImageClassifier
>>> model = ImageClassifier(
...     num_classes=10,
...     backbone="resnet18",
...     optimizer="Adam",
...     learning_rate=1e-4,
...     lr_scheduler=("StepLR", {"step_size": 10}),
... )
>>> scheduler = model.configure_optimizers()[1][0]["scheduler"]
>>> scheduler.step_size  
10

An alternative to customizing the LR Scheduler using a tuple is to pass it as a callable.

>>> from functools import partial
>>> from torch.optim.lr_scheduler import CyclicLR
>>> from flash.image import ImageClassifier
>>> model = ImageClassifier(
...     num_classes=10,
...     backbone="resnet18",
...     optimizer="SGD",
...     learning_rate=1e-4,
...     lr_scheduler=partial(CyclicLR, base_lr=0.001, max_lr=0.1, mode="exp_range", gamma=0.5),
... )
>>> scheduler = model.configure_optimizers()[1][0]["scheduler"]
>>> (scheduler.mode, scheduler.gamma)  
('exp_range', 0.5)

Additionally, the lr_scheduler parameter accepts the Lightning Scheduler configuration, which can be passed as an additional element of the tuple.

The Lightning Scheduler configuration is a dictionary which contains the scheduler and its associated configuration. The default configuration is shown below.

lr_scheduler_config = {
    # REQUIRED: The scheduler instance
    "scheduler": lr_scheduler,
    # The unit of the scheduler's step size, could also be 'step'.
    # 'epoch' updates the scheduler on epoch end whereas 'step'
    # updates it after an optimizer update.
    "interval": "epoch",
    # How many epochs/steps should pass between calls to
    # `scheduler.step()`. 1 corresponds to updating the learning
    # rate after every epoch/step.
    "frequency": 1,
    # Metric to monitor for schedulers like `ReduceLROnPlateau`
    "monitor": "val_loss",
    # If set to `True`, will enforce that the value specified in 'monitor'
    # is available when the scheduler is updated, thus stopping
    # training if not found. If set to `False`, it will only produce a warning
    "strict": True,
    # If using the `LearningRateMonitor` callback to monitor the
    # learning rate progress, this keyword can be used to specify
    # a custom logged name
    "name": None,
}

For schedulers whose .step() method is conditioned on a value, such as torch.optim.lr_scheduler.ReduceLROnPlateau, Flash requires that the Lightning Scheduler configuration contain the keyword "monitor" set to the metric name that the scheduler should be conditioned on. Below is an example:

>>> from flash.image import ImageClassifier
>>> model = ImageClassifier(
...     num_classes=10,
...     backbone="resnet18",
...     optimizer="Adam",
...     learning_rate=1e-4,
...     lr_scheduler=("reducelronplateau", {"mode": "max"}, {"monitor": "val_accuracy"}),
... )
>>> model.configure_optimizers()  
([Adam ...], [{'scheduler': ..., 'monitor': 'val_accuracy', ...}])

Note

Do not set the "scheduler" key in the Lightning Scheduler configuration; it will be overridden with an instance of the scheduler specified by the provided key.

Pre-Registering optimizers and scheduler recipes

The Flash registry also provides the flexibility of registering functions. This feature is also available in the optimizer and LR scheduler registries.

Using the optimizers and lr_schedulers decorators on each Task, custom optimizer and LR scheduler recipes can be pre-registered.

>>> import torch
>>> from flash.image import ImageClassifier
>>> @ImageClassifier.lr_schedulers
... def my_flash_steplr_recipe(optimizer):
...     return torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)
...
>>> model = ImageClassifier(backbone="resnet18", num_classes=2, optimizer="Adam", lr_scheduler="my_flash_steplr_recipe")
>>> scheduler = model.configure_optimizers()[1][0]["scheduler"]
>>> scheduler.step_size  
10
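
The optimizers decorator can be used in the same way to register a custom optimizer recipe. Below is a rough sketch; the exact signature of the registered callable (here assumed to receive the model parameters and the learning rate) is an assumption for illustration and may differ between Flash versions.

import torch

from flash.image import ImageClassifier


# Register a custom optimizer recipe under the name "my_flash_sgd_recipe".
# NOTE: the (parameters, lr) signature is an assumption made for illustration.
@ImageClassifier.optimizers
def my_flash_sgd_recipe(parameters, lr):
    return torch.optim.SGD(parameters, lr=lr, momentum=0.9)


model = ImageClassifier(
    backbone="resnet18", num_classes=2, optimizer="my_flash_sgd_recipe", learning_rate=1e-3
)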

Provider specific requirements

Schedulers

Certain LR Schedulers provided by Hugging Face require both num_training_steps and num_warmup_steps.

To use them in Flash, just provide num_warmup_steps as a float between 0 and 1, indicating the fraction of the training steps that will be used as warmup steps. Flash’s Trainer will take care of computing the number of training steps and the number of warmup steps based on the flags that are set in the Trainer.

>>> from flash import Trainer
>>> from flash.image import ImageClassifier, ImageClassificationData
>>> datamodule = ImageClassificationData.from_files(
...     train_files=["image_1.png", "image_2.png", "image_3.png"],
...     train_targets=["cat", "dog", "cat"],
...     predict_files=["predict_image_1.png", "predict_image_2.png", "predict_image_3.png"],
...     transform_kwargs=dict(image_size=(128, 128)),
...     batch_size=2,
... )
>>> model = ImageClassifier(
...     backbone="resnet18",
...     num_classes=datamodule.num_classes,
...     optimizer="Adam",
...     lr_scheduler=("cosine_schedule_with_warmup", {"num_warmup_steps": 0.1}),
... )
>>> trainer = Trainer(fast_dev_run=True)
>>> trainer.fit(model, datamodule=datamodule)  
Training...
>>> trainer.predict(model, datamodule=datamodule)  
Predicting...

Formatting Classification Targets

This guide details the different target formats supported by classification tasks in Flash. By default, the target format and any additional metadata (labels, num_classes, multi_label) will be inferred from your training data. You can override this behaviour by passing your own TargetFormatter using the target_formatter argument.

Single Label

Classification targets are described as single label (DataModule.multi_label = False) if each data sample corresponds to a single class.

Class Indexes

Targets formatted as class indexes are represented by a single number, e.g. train_targets = [0, 1, 0]. No labels will be inferred. The inferred num_classes is the maximum index plus one (we assume that class indexes are zero-based). Here’s an example:

>>> from flash import Trainer
>>> from flash.image import ImageClassifier, ImageClassificationData
>>> datamodule = ImageClassificationData.from_files(
...     train_files=["image_1.png", "image_2.png", "image_3.png"],
...     train_targets=[0, 1, 0],
...     transform_kwargs=dict(image_size=(128, 128)),
...     batch_size=2,
... )
>>> datamodule.num_classes
2
>>> datamodule.labels is None
True
>>> datamodule.multi_label
False

Alternatively, you can provide a SingleNumericTargetFormatter to override the behaviour. Here’s an example:

>>> from flash import Trainer
>>> from flash.image import ImageClassifier, ImageClassificationData
>>> from flash.core.data.utilities.classification import SingleNumericTargetFormatter
>>> datamodule = ImageClassificationData.from_files(
...     train_files=["image_1.png", "image_2.png", "image_3.png"],
...     train_targets=[0, 1, 0],
...     target_formatter=SingleNumericTargetFormatter(labels=["dog", "cat", "rabbit"]),
...     transform_kwargs=dict(image_size=(128, 128)),
...     batch_size=2,
... )
>>> datamodule.num_classes
3
>>> datamodule.labels
['dog', 'cat', 'rabbit']
>>> datamodule.multi_label
False

Labels

Targets formatted as labels are represented by a single string, e.g. train_targets = ["cat", "dog", "cat"]. The inferred labels will be the unique labels in the train targets sorted alphanumerically. The inferred num_classes is the number of labels. Here’s an example:

>>> from flash import Trainer
>>> from flash.image import ImageClassifier, ImageClassificationData
>>> datamodule = ImageClassificationData.from_files(
...     train_files=["image_1.png", "image_2.png", "image_3.png"],
...     train_targets=["cat", "dog", "cat"],
...     transform_kwargs=dict(image_size=(128, 128)),
...     batch_size=2,
... )
>>> datamodule.num_classes
2
>>> datamodule.labels
['cat', 'dog']
>>> datamodule.multi_label
False

Alternatively, you can provide a SingleLabelTargetFormatter to override the behaviour. Here’s an example:

>>> from flash import Trainer
>>> from flash.image import ImageClassifier, ImageClassificationData
>>> from flash.core.data.utilities.classification import SingleLabelTargetFormatter
>>> datamodule = ImageClassificationData.from_files(
...     train_files=["image_1.png", "image_2.png", "image_3.png"],
...     train_targets=["cat", "dog", "cat"],
...     target_formatter=SingleLabelTargetFormatter(labels=["dog", "cat", "rabbit"]),
...     transform_kwargs=dict(image_size=(128, 128)),
...     batch_size=2,
... )
>>> datamodule.num_classes
3
>>> datamodule.labels
['dog', 'cat', 'rabbit']
>>> datamodule.multi_label
False

One-hot Binaries

Targets formatted as one-hot binaries are represented by a binary list with a single index (the target class index) set to 1, e.g. train_targets = [[1, 0], [0, 1], [1, 0]]. No labels will be inferred. The inferred num_classes is the length of the binary list. Here’s an example:

>>> from flash import Trainer
>>> from flash.image import ImageClassifier, ImageClassificationData
>>> datamodule = ImageClassificationData.from_files(
...     train_files=["image_1.png", "image_2.png", "image_3.png"],
...     train_targets=[[1, 0], [0, 1], [1, 0]],
...     transform_kwargs=dict(image_size=(128, 128)),
...     batch_size=2,
... )
>>> datamodule.num_classes
2
>>> datamodule.labels is None
True
>>> datamodule.multi_label
False

Alternatively, you can provide a SingleBinaryTargetFormatter to override the behaviour. Here’s an example:

>>> from flash import Trainer
>>> from flash.image import ImageClassifier, ImageClassificationData
>>> from flash.core.data.utilities.classification import SingleBinaryTargetFormatter
>>> datamodule = ImageClassificationData.from_files(
...     train_files=["image_1.png", "image_2.png", "image_3.png"],
...     train_targets=[[1, 0], [0, 1], [1, 0]],
...     target_formatter=SingleBinaryTargetFormatter(labels=["dog", "cat"]),
...     transform_kwargs=dict(image_size=(128, 128)),
...     batch_size=2,
... )
>>> datamodule.num_classes
2
>>> datamodule.labels
['dog', 'cat']
>>> datamodule.multi_label
False

Multi Label

Classification targets are described as multi label (DataModule.multi_label = True) if each data sample corresponds to zero or more (and perhaps many) classes.

Class Indexes

Targets formatted as multi label class indexes are represented by a list of class indexes, e.g. train_targets = [[0], [0, 1], [1, 2]]. No labels will be inferred. The inferred num_classes is the maximum target value plus one (we assume that targets are zero-based). Here’s an example:

>>> from flash import Trainer
>>> from flash.image import ImageClassifier, ImageClassificationData
>>> datamodule = ImageClassificationData.from_files(
...     train_files=["image_1.png", "image_2.png", "image_3.png"],
...     train_targets=[[0], [0, 1], [1, 2]],
...     transform_kwargs=dict(image_size=(128, 128)),
...     batch_size=2,
... )
>>> datamodule.num_classes
3
>>> datamodule.labels is None
True
>>> datamodule.multi_label
True

Alternatively, you can provide a MultiNumericTargetFormatter to override the behaviour. Here’s an example:

>>> from flash import Trainer
>>> from flash.image import ImageClassifier, ImageClassificationData
>>> from flash.core.data.utilities.classification import MultiNumericTargetFormatter
>>> datamodule = ImageClassificationData.from_files(
...     train_files=["image_1.png", "image_2.png", "image_3.png"],
...     train_targets=[[0], [0, 1], [1, 2]],
...     target_formatter=MultiNumericTargetFormatter(labels=["dog", "cat", "rabbit"]),
...     transform_kwargs=dict(image_size=(128, 128)),
...     batch_size=2,
... )
>>> datamodule.num_classes
3
>>> datamodule.labels
['dog', 'cat', 'rabbit']
>>> datamodule.multi_label
True

Labels

Targets formatted as multi label are represented by a list of strings, e.g. train_targets = [["cat"], ["cat", "dog"], ["dog", "rabbit"]]. The inferred labels will be the unique labels in the train targets sorted alphanumerically. The inferred num_classes is the number of labels. Here’s an example:

>>> from flash import Trainer
>>> from flash.image import ImageClassifier, ImageClassificationData
>>> datamodule = ImageClassificationData.from_files(
...     train_files=["image_1.png", "image_2.png", "image_3.png"],
...     train_targets=[["cat"], ["cat", "dog"], ["dog", "rabbit"]],
...     transform_kwargs=dict(image_size=(128, 128)),
...     batch_size=2,
... )
>>> datamodule.num_classes
3
>>> datamodule.labels
['cat', 'dog', 'rabbit']
>>> datamodule.multi_label
True

Alternatively, you can provide a MultiLabelTargetFormatter to override the behaviour. Here’s an example:

>>> from flash import Trainer
>>> from flash.image import ImageClassifier, ImageClassificationData
>>> from flash.core.data.utilities.classification import MultiLabelTargetFormatter
>>> datamodule = ImageClassificationData.from_files(
...     train_files=["image_1.png", "image_2.png", "image_3.png"],
...     train_targets=[["cat"], ["cat", "dog"], ["dog", "rabbit"]],
...     target_formatter=MultiLabelTargetFormatter(labels=["dog", "cat", "rabbit"]),
...     transform_kwargs=dict(image_size=(128, 128)),
...     batch_size=2,
... )
>>> datamodule.num_classes
3
>>> datamodule.labels
['dog', 'cat', 'rabbit']
>>> datamodule.multi_label
True

Comma Delimited

Targets formatted as comma delimited multi label are given as comma delimited strings, e.g. train_targets = ["cat", "cat,dog", "dog,rabbit"]. The inferred labels will be the unique labels in the train targets sorted alphanumerically. The inferred num_classes is the number of labels. Here’s an example:

>>> from flash import Trainer
>>> from flash.image import ImageClassifier, ImageClassificationData
>>> datamodule = ImageClassificationData.from_files(
...     train_files=["image_1.png", "image_2.png", "image_3.png"],
...     train_targets=["cat", "cat,dog", "dog,rabbit"],
...     transform_kwargs=dict(image_size=(128, 128)),
...     batch_size=2,
... )
>>> datamodule.num_classes
3
>>> datamodule.labels
['cat', 'dog', 'rabbit']
>>> datamodule.multi_label
True

Alternatively, you can provide a CommaDelimitedMultiLabelTargetFormatter to override the behaviour. Here’s an example:

>>> from flash import Trainer
>>> from flash.image import ImageClassifier, ImageClassificationData
>>> from flash.core.data.utilities.classification import CommaDelimitedMultiLabelTargetFormatter
>>> datamodule = ImageClassificationData.from_files(
...     train_files=["image_1.png", "image_2.png", "image_3.png"],
...     train_targets=["cat", "cat,dog", "dog,rabbit"],
...     target_formatter=CommaDelimitedMultiLabelTargetFormatter(labels=["dog", "cat", "rabbit"]),
...     transform_kwargs=dict(image_size=(128, 128)),
...     batch_size=2,
... )
>>> datamodule.num_classes
3
>>> datamodule.labels
['dog', 'cat', 'rabbit']
>>> datamodule.multi_label
True

Space Delimited

Targets formatted as space delimited multi label are given as space delimited strings, e.g. train_targets = ["cat", "cat dog", "dog rabbit"]. The inferred labels will be the unique labels in the train targets sorted alphanumerically. The inferred num_classes is the number of labels. Here’s an example:

>>> from flash import Trainer
>>> from flash.image import ImageClassifier, ImageClassificationData
>>> datamodule = ImageClassificationData.from_files(
...     train_files=["image_1.png", "image_2.png", "image_3.png"],
...     train_targets=["cat", "cat dog", "dog rabbit"],
...     transform_kwargs=dict(image_size=(128, 128)),
...     batch_size=2,
... )
>>> datamodule.num_classes
3
>>> datamodule.labels
['cat', 'dog', 'rabbit']
>>> datamodule.multi_label
True

Alternatively, you can provide a SpaceDelimitedTargetFormatter to override the behaviour. Here’s an example:

>>> from flash import Trainer
>>> from flash.image import ImageClassifier, ImageClassificationData
>>> from flash.core.data.utilities.classification import SpaceDelimitedTargetFormatter
>>> datamodule = ImageClassificationData.from_files(
...     train_files=["image_1.png", "image_2.png", "image_3.png"],
...     train_targets=["cat", "cat dog", "dog rabbit"],
...     target_formatter=SpaceDelimitedTargetFormatter(labels=["dog", "cat", "rabbit"]),
...     transform_kwargs=dict(image_size=(128, 128)),
...     batch_size=2,
... )
>>> datamodule.num_classes
3
>>> datamodule.labels
['dog', 'cat', 'rabbit']
>>> datamodule.multi_label
True

Multi-hot Binaries

Targets formatted as multi-hot binaries are represented by a binary list with zero or more indices (the target class indices) set to 1, e.g. train_targets = [[1, 0, 0], [1, 1, 0], [0, 1, 1]]. No labels will be inferred. The inferred num_classes is the length of the binary list. Here’s an example:

>>> from flash import Trainer
>>> from flash.image import ImageClassifier, ImageClassificationData
>>> datamodule = ImageClassificationData.from_files(
...     train_files=["image_1.png", "image_2.png", "image_3.png"],
...     train_targets=[[1, 0, 0], [1, 1, 0], [0, 1, 1]],
...     transform_kwargs=dict(image_size=(128, 128)),
...     batch_size=2,
... )
>>> datamodule.num_classes
3
>>> datamodule.labels is None
True
>>> datamodule.multi_label
True

Alternatively, you can provide a MultiBinaryTargetFormatter to override the behaviour. Here’s an example:

>>> from flash import Trainer
>>> from flash.image import ImageClassifier, ImageClassificationData
>>> from flash.core.data.utilities.classification import MultiBinaryTargetFormatter
>>> datamodule = ImageClassificationData.from_files(
...     train_files=["image_1.png", "image_2.png", "image_3.png"],
...     train_targets=[[1, 0, 0], [1, 1, 0], [0, 1, 1]],
...     target_formatter=MultiBinaryTargetFormatter(labels=["dog", "cat", "rabbit"]),
...     transform_kwargs=dict(image_size=(128, 128)),
...     batch_size=2,
... )
>>> datamodule.num_classes
3
>>> datamodule.labels
['dog', 'cat', 'rabbit']
>>> datamodule.multi_label
True

Customizing Transforms

Note

The contents of this page are currently being updated. Stay tuned!

Image Classification

The Task

The task of identifying what is in an image is called image classification. Typically, Image Classification is used to identify images containing a single object. The task predicts which ‘class’ the image most likely belongs to with a degree of certainty. A class is a label that describes what is in an image, such as ‘car’, ‘house’, ‘cat’ etc.


Example

Let’s look at the task of predicting whether images contain Ants or Bees using the hymenoptera dataset. The dataset contains train and validation folders, and then each folder contains a bees folder, with pictures of bees, and an ants folder with images of, you guessed it, ants.

hymenoptera_data
├── train
│   ├── ants
│   │   ├── 0013035.jpg
│   │   ├── 1030023514_aad5c608f9.jpg
│   │   ...
│   └── bees
│       ├── 1092977343_cb42b38d62.jpg
│       ├── 1093831624_fb5fbe2308.jpg
│       ...
└── val
    ├── ants
    │   ├── 10308379_1b6c72e180.jpg
    │   ├── 1053149811_f62a3410d3.jpg
    │   ...
    └── bees
        ├── 1032546534_06907fe3b3.jpg
        ├── 10870992_eebeeb3a12.jpg
        ...

Once we’ve downloaded the data using download_data(), we create the ImageClassificationData. We select a pre-trained backbone to use for our ImageClassifier and fine-tune on the hymenoptera data. We then use the trained ImageClassifier for inference. Finally, we save the model. Here’s the full example:

import torch

import flash
from flash.core.data.utils import download_data
from flash.image import ImageClassificationData, ImageClassifier

# 1. Create the DataModule
download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip", "./data")

datamodule = ImageClassificationData.from_folders(
    train_folder="data/hymenoptera_data/train/",
    val_folder="data/hymenoptera_data/val/",
    batch_size=4,
    transform_kwargs={"image_size": (196, 196), "mean": (0.485, 0.456, 0.406), "std": (0.229, 0.224, 0.225)},
)

# 2. Build the task
model = ImageClassifier(backbone="resnet18", labels=datamodule.labels)

# 3. Create the trainer and finetune the model
trainer = flash.Trainer(max_epochs=3, gpus=torch.cuda.device_count())
trainer.finetune(model, datamodule=datamodule, strategy="freeze")

# 4. Predict what's on a few images! ants or bees?
datamodule = ImageClassificationData.from_files(
    predict_files=[
        "data/hymenoptera_data/val/bees/65038344_52a45d090d.jpg",
        "data/hymenoptera_data/val/bees/590318879_68cf112861.jpg",
        "data/hymenoptera_data/val/ants/540543309_ddbb193ee5.jpg",
    ],
    batch_size=3,
)
predictions = trainer.predict(model, datamodule=datamodule, output="labels")
print(predictions)

# 5. Save the model!
trainer.save_checkpoint("image_classification_model.pt")

To learn how to view the available backbones / heads for this task, see Backbones and Heads. Benchmarks for backbones provided by PyTorch Image Models (TIMM) can be found here: https://github.com/rwightman/pytorch-image-models/blob/master/results/results-imagenet-real.csv


Flash Zero

The image classifier can be used directly from the command line with zero code using Flash Zero. You can run the hymenoptera example with:

flash image_classification

To view configuration options and options for running the image classifier with your own data, use:

flash image_classification --help

Custom Transformations

Flash automatically applies some default image transformations and augmentations, but you may wish to customize these for your own use case. The base InputTransform defines 7 hooks for different stages in the data loading pipeline. To apply custom image augmentations you can create your own InputTransform. Here’s an example:

from dataclasses import dataclass
from typing import Callable, Tuple, Union

import torch
from torchvision import transforms as T

import flash
from flash.image import ImageClassificationData, ImageClassifier
from flash.core.data.io.input_transform import InputTransform


@dataclass
class ImageClassificationInputTransform(InputTransform):

    image_size: Tuple[int, int] = (196, 196)
    mean: Union[float, Tuple[float, float, float]] = (0.485, 0.456, 0.406)
    std: Union[float, Tuple[float, float, float]] = (0.229, 0.224, 0.225)

    def input_per_sample_transform(self):
        return T.Compose([T.ToTensor(), T.Resize(self.image_size), T.Normalize(self.mean, self.std)])

    def train_input_per_sample_transform(self):
        return T.Compose(
            [
                T.ToTensor(),
                T.Resize(self.image_size),
                T.Normalize(self.mean, self.std),
                T.RandomHorizontalFlip(),
                T.ColorJitter(),
                T.RandomAutocontrast(),
                T.RandomPerspective(),
            ]
        )

    def target_per_sample_transform(self) -> Callable:
        return torch.as_tensor


datamodule = ImageClassificationData.from_folders(
    train_folder="data/hymenoptera_data/train/",
    val_folder="data/hymenoptera_data/val/",
    train_transform=ImageClassificationInputTransform,
    transform_kwargs=dict(image_size=(128, 128)),
    batch_size=1,
)

model = ImageClassifier(backbone="resnet18", num_classes=datamodule.num_classes)

trainer = flash.Trainer(max_epochs=1)
trainer.finetune(model, datamodule=datamodule, strategy="freeze")

Serving

The ImageClassifier is servable. This means you can call .serve to serve your Task. Here’s an example:

from flash.image import ImageClassifier

model = ImageClassifier.load_from_checkpoint(
    "https://flash-weights.s3.amazonaws.com/0.7.0/image_classification_model.pt"
)
model.serve(output="labels")

You can now perform inference from your client like this:

import base64
from pathlib import Path

import requests

import flash

with (Path(flash.ASSETS_ROOT) / "fish.jpg").open("rb") as f:
    imgstr = base64.b64encode(f.read()).decode("UTF-8")

body = {"session": "UUID", "payload": {"inputs": {"data": imgstr}}}
resp = requests.post("http://127.0.0.1:8000/predict", json=body)
print(resp.json())

Multi-label Image Classification

The Task

Multi-label classification is the task of assigning a number of labels from a fixed set to each data point, which can be in any modality (images in this case). Multi-label image classification is supported by the ImageClassifier via the multi_label argument.


Example

Let’s look at the task of trying to predict the movie genres from an image of the movie poster. The data we will use is a subset of the awesome movie poster genre prediction data set from the paper “Movie Genre Classification based on Poster Images with Deep Neural Networks” by Wei-Ta Chu and Hung-Jui Guo, resized to 128 by 128. Take a look at their paper (and please consider citing their paper if you use the data) here: www.cs.ccu.edu.tw/~wtchu/projects/MoviePoster/. The data set contains train and validation folders, and then each folder contains images and a metadata.csv which stores the labels. Here’s an overview:

movie_posters
├── train
│   ├── metadata.csv
│   ├── tt0084058.jpg
│   ├── tt0084867.jpg
│   ...
└── val
    ├── metadata.csv
    ├── tt0200465.jpg
    ├── tt0326965.jpg
    ...

Once we’ve downloaded the data using download_data(), we need to create the ImageClassificationData. We first create a small resolver function that maps each Id in the metadata.csv to its image file path, and pass it to from_csv(). We select a pre-trained backbone to use for our ImageClassifier and fine-tune on the posters data. We then use the trained ImageClassifier for inference. Finally, we save the model. Here’s the full example:

import os

import torch

import flash
from flash.core.data.utils import download_data
from flash.image import ImageClassificationData, ImageClassifier

# 1. Create the DataModule
# Data set from the paper "Movie Genre Classification based on Poster Images with Deep Neural Networks".
# More info here: https://www.cs.ccu.edu.tw/~wtchu/projects/MoviePoster/
download_data("https://pl-flash-data.s3.amazonaws.com/movie_posters.zip")


def resolver(root, file_id):
    return os.path.join(root, f"{file_id}.jpg")


datamodule = ImageClassificationData.from_csv(
    "Id",
    ["Action", "Romance", "Crime", "Thriller", "Adventure"],
    train_file="data/movie_posters/train/metadata.csv",
    train_resolver=resolver,
    val_file="data/movie_posters/val/metadata.csv",
    val_resolver=resolver,
    transform_kwargs={"image_size": (128, 128)},
    batch_size=1,
)

# 2. Build the task
model = ImageClassifier(backbone="resnet18", labels=datamodule.labels, multi_label=datamodule.multi_label)

# 3. Create the trainer and finetune the model
trainer = flash.Trainer(max_epochs=3, gpus=torch.cuda.device_count())
trainer.finetune(model, datamodule=datamodule, strategy="freeze")

# 4. Predict the genre of a few movies!
datamodule = ImageClassificationData.from_files(
    predict_files=[
        "data/movie_posters/predict/tt0085318.jpg",
        "data/movie_posters/predict/tt0089461.jpg",
        "data/movie_posters/predict/tt0097179.jpg",
    ],
    batch_size=3,
)
predictions = trainer.predict(model, datamodule=datamodule, output="labels")
print(predictions)

# 5. Save the model!
trainer.save_checkpoint("image_classification_multi_label_model.pt")

To learn how to view the available backbones / heads for this task, see Backbones and Heads.


Flash Zero

The multi-label image classifier can be used directly from the command line with zero code using Flash Zero. You can run the movie posters example with:

flash image_classification from_movie_posters

To view configuration options and options for running the image classifier with your own data, use:

flash image_classification --help

Serving

The ImageClassifier is servable. For more information, see Image Classification.
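
As a minimal sketch (reusing the local checkpoint saved in step 5 of the example above), serving the multi-label model looks the same as the single-label case:

from flash.image import ImageClassifier

# Load the multi-label checkpoint saved above and start the Flash Serve server
model = ImageClassifier.load_from_checkpoint("image_classification_multi_label_model.pt")
model.serve(output="labels")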

Image Embedder

Warning

Multi-gpu training is not currently supported by the ImageEmbedder task.

The Task

Image embedding encodes an image into a vector of features which can be used for a downstream task. This could include: clustering, similarity search, or classification.

The Flash ImageEmbedder can be trained with Self Supervised Learning (SSL) to improve the quality of the embeddings it produces for your data. The ImageEmbedder internally relies on VISSL. You can read more about our integration with VISSL here: VISSL.


Example

Let’s see how to configure a training strategy for the ImageEmbedder task. First we create an ImageClassificationData object using a Dataset from torchvision. Next, we configure the ImageEmbedder task with training_strategy, backbone, head and pretraining_transform. Finally, we construct a Trainer and call fit(). Here’s the full example:

import torch
from torchvision.datasets import CIFAR10

import flash
from flash.core.data.utils import download_data
from flash.image import ImageClassificationData, ImageEmbedder

# 1. Download the data and prepare the datamodule
datamodule = ImageClassificationData.from_datasets(
    train_dataset=CIFAR10(".", download=True),
    batch_size=4,
)

# 2. Build the task
embedder = ImageEmbedder(
    backbone="vision_transformer",
    training_strategy="barlow_twins",
    head="barlow_twins_head",
    pretraining_transform="barlow_twins_transform",
    training_strategy_kwargs={"latent_embedding_dim": 128},
    pretraining_transform_kwargs={"size_crops": [32]},
)

# 3. Create the trainer and pre-train the encoder
trainer = flash.Trainer(max_epochs=1, gpus=torch.cuda.device_count())
trainer.fit(embedder, datamodule=datamodule)

# 4. Save the model!
trainer.save_checkpoint("image_embedder_model.pt")

# 5. Download the downstream prediction dataset and generate embeddings
download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip", "data/")

datamodule = ImageClassificationData.from_files(
    predict_files=[
        "data/hymenoptera_data/predict/153783656_85f9c3ac70.jpg",
        "data/hymenoptera_data/predict/2039585088_c6f47c592e.jpg",
    ],
    batch_size=3,
)
embeddings = trainer.predict(embedder, datamodule=datamodule)

# list of embeddings for images sent to the predict function
print(embeddings)

To learn how to view the available backbones / heads for this task, see Backbones and Heads. You can view the available training strategies with the available_training_strategies() method.
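
For example, a quick sketch (assuming the image extras are installed) to list the registered training strategies:

from flash.image import ImageEmbedder

# Print the training strategies registered with the ImageEmbedder task
print(ImageEmbedder.available_training_strategies())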

Note

The "dino" training strategy only supports single GPU training with strategy="ddp".

The head and pretraining_transform arguments should match the choice of training_strategy following this table:

training_strategy    head                 pretraining_transform
simclr               simclr_head          simclr_transform
barlow_twins         barlow_twins_head    barlow_twins_transform
swav                 swav_head            swav_transform
dino                 dino_head            dino_transform
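
For instance, a minimal sketch selecting the simclr row of the table (constructor arguments follow the example above; strategy-specific training_strategy_kwargs or pretraining_transform_kwargs may still be needed depending on the strategy):

from flash.image import ImageEmbedder

# The head and pretraining transform names track the chosen training strategy
embedder = ImageEmbedder(
    backbone="vision_transformer",
    training_strategy="simclr",
    head="simclr_head",
    pretraining_transform="simclr_transform",
)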

Object Detection

The Task

Object detection is the task of identifying objects in images and their associated classes and bounding boxes.

The ObjectDetector and ObjectDetectionData classes internally rely on IceVision.


Example

Let’s look at object detection with the COCO 128 data set, which contains 80 object classes. This is a subset of COCO train2017 with only 128 images. The data set is organized following the COCO format. Here’s an outline:

coco128
├── annotations
│   └── instances_train2017.json
├── images
│   └── train2017
│       ├── 000000000009.jpg
│       ├── 000000000025.jpg
│       ...
└── labels
    └── train2017
        ├── 000000000009.txt
        ├── 000000000025.txt
        ...

Once we’ve downloaded the data using download_data(), we can create the ObjectDetectionData. We select a pre-trained EfficientDet to use for our ObjectDetector and fine-tune on the COCO 128 data. We then use the trained ObjectDetector for inference. Finally, we save the model. Here’s the full example:

import flash
from flash.core.data.utils import download_data
from flash.image import ObjectDetectionData, ObjectDetector

# 1. Create the DataModule
# Dataset Credit: https://www.kaggle.com/ultralytics/coco128
download_data("https://github.com/zhiqwang/yolov5-rt-stack/releases/download/v0.3.0/coco128.zip", "data/")

datamodule = ObjectDetectionData.from_coco(
    train_folder="data/coco128/images/train2017/",
    train_ann_file="data/coco128/annotations/instances_train2017.json",
    val_split=0.1,
    transform_kwargs={"image_size": 512},
    batch_size=4,
)

# 2. Build the task
model = ObjectDetector(head="efficientdet", backbone="d0", num_classes=datamodule.num_classes, image_size=512)

# 3. Create the trainer and finetune the model
trainer = flash.Trainer(max_epochs=1)
trainer.finetune(model, datamodule=datamodule, strategy="freeze")

# 4. Detect objects in a few images!
datamodule = ObjectDetectionData.from_files(
    predict_files=[
        "data/coco128/images/train2017/000000000625.jpg",
        "data/coco128/images/train2017/000000000626.jpg",
        "data/coco128/images/train2017/000000000629.jpg",
    ],
    transform_kwargs={"image_size": 512},
    batch_size=4,
)
predictions = trainer.predict(model, datamodule=datamodule)
print(predictions)

# 5. Save the model!
trainer.save_checkpoint("object_detection_model.pt")

To learn how to view the available backbones / heads for this task, see Backbones and Heads.


Flash Zero

The object detector can be used directly from the command line with zero code using Flash Zero. You can run the above example with:

flash object_detection

To view configuration options and options for running the object detector with your own data, use:

flash object_detection --help

Custom Transformations

Flash automatically applies some default image / mask transformations and augmentations, but you may wish to customize these for your own use case. The base InputTransform defines 7 hooks for different stages in the data loading pipeline. For object-detection tasks, you can leverage the transformations from Albumentations with the IceVisionTransformAdapter by creating a subclass of InputTransform. Here’s an example:

from dataclasses import dataclass
import albumentations as alb
from icevision.tfms import A

from flash import InputTransform
from flash.core.integrations.icevision.transforms import IceVisionTransformAdapter
from flash.image import ObjectDetectionData


@dataclass
class BrightnessContrastTransform(InputTransform):
    image_size: int = 128

    def per_sample_transform(self):
        return IceVisionTransformAdapter(
            [*A.aug_tfms(size=self.image_size), A.Normalize(), alb.RandomBrightnessContrast()]
        )


datamodule = ObjectDetectionData.from_coco(
    train_folder="data/coco128/images/train2017/",
    train_ann_file="data/coco128/annotations/instances_train2017.json",
    val_split=0.1,
    train_transform=BrightnessContrastTransform,
    batch_size=4,
)

Keypoint Detection

The Task

Keypoint detection is the task of identifying keypoints in images and their associated classes.

The KeypointDetector and KeypointDetectionData classes internally rely on IceVision.


Example

Let’s look at keypoint detection with BIWI Sample Keypoints (center of face) from IceData. Once we’ve downloaded the data, we can create the KeypointDetectionData. We select a keypoint_rcnn with a resnet18_fpn backbone to use for our KeypointDetector and fine-tune on the BIWI data. We then use the trained KeypointDetector for inference. Finally, we save the model. Here’s the full example:

import flash
from flash.core.utilities.imports import example_requires
from flash.image import KeypointDetectionData, KeypointDetector

example_requires("image")

import icedata  # noqa: E402

# 1. Create the DataModule
data_dir = icedata.biwi.load_data()

datamodule = KeypointDetectionData.from_icedata(
    train_folder=data_dir,
    val_split=0.1,
    parser=icedata.biwi.parser,
    batch_size=4,
)

# 2. Build the task
model = KeypointDetector(
    head="keypoint_rcnn",
    backbone="resnet18_fpn",
    num_keypoints=1,
    num_classes=datamodule.num_classes,
)

# 3. Create the trainer and finetune the model
trainer = flash.Trainer(max_epochs=1)
trainer.finetune(model, datamodule=datamodule, strategy="freeze")

# 4. Detect objects in a few images!
datamodule = KeypointDetectionData.from_files(
    predict_files=[
        str(data_dir / "biwi_sample/images/0.jpg"),
        str(data_dir / "biwi_sample/images/1.jpg"),
        str(data_dir / "biwi_sample/images/10.jpg"),
    ],
    batch_size=4,
)
predictions = trainer.predict(model, datamodule=datamodule)
print(predictions)

# 5. Save the model!
trainer.save_checkpoint("keypoint_detection_model.pt")

To learn how to view the available backbones / heads for this task, see Backbones and Heads.


Flash Zero

The keypoint detector can be used directly from the command line with zero code using Flash Zero. You can run the above example with:

flash keypoint_detection

To view configuration options and options for running the keypoint detector with your own data, use:

flash keypoint_detection --help

Instance Segmentation

The Task

Instance segmentation is the task of segmenting individual objects in images and determining their associated classes.

The InstanceSegmentation and InstanceSegmentationData classes internally rely on IceVision.


Example

Let’s look at instance segmentation with The Oxford-IIIT Pet Dataset from IceData. Once we’ve downloaded the data, we can create the InstanceSegmentationData. We select a mask_rcnn with a resnet18_fpn backbone to use for our InstanceSegmentation and fine-tune on the pets data. We then use the trained InstanceSegmentation for inference. Finally, we save the model. Here’s the full example:

from functools import partial

import flash
from flash.core.utilities.imports import example_requires
from flash.image import InstanceSegmentation, InstanceSegmentationData

example_requires("image")

import icedata  # noqa: E402

# 1. Create the DataModule
data_dir = icedata.pets.load_data()

datamodule = InstanceSegmentationData.from_icedata(
    train_folder=data_dir,
    val_split=0.1,
    parser=partial(icedata.pets.parser, mask=True),
    batch_size=4,
)

# 2. Build the task
model = InstanceSegmentation(
    head="mask_rcnn",
    backbone="resnet18_fpn",
    num_classes=datamodule.num_classes,
)

# 3. Create the trainer and finetune the model
trainer = flash.Trainer(max_epochs=1)
trainer.finetune(model, datamodule=datamodule, strategy="freeze")

# 4. Detect objects in a few images!
datamodule = InstanceSegmentationData.from_files(
    predict_files=[
        str(data_dir / "images/yorkshire_terrier_9.jpg"),
        str(data_dir / "images/yorkshire_terrier_12.jpg"),
        str(data_dir / "images/yorkshire_terrier_13.jpg"),
    ],
    batch_size=4,
)
predictions = trainer.predict(model, datamodule=datamodule)
print(predictions)

# 5. Save the model!
trainer.save_checkpoint("instance_segmentation_model.pt")

To learn how to view the available backbones / heads for this task, see Backbones and Heads.


Flash Zero

The instance segmentation task can be used directly from the command line with zero code using Flash Zero. You can run the above example with:

flash instance_segmentation

To view configuration options and options for running the instance segmentation task with your own data, use:

flash instance_segmentation --help

Semantic Segmentation

The Task

Semantic Segmentation, or image segmentation, is the task of performing classification at a pixel level, meaning each pixel will be associated with a given class. See more: https://paperswithcode.com/task/semantic-segmentation


Example

Let’s look at an example using a data set generated with the CARLA driving simulator. The data was generated as part of the Kaggle Lyft Udacity Challenge. The data contains one folder of images and another folder with the corresponding segmentation masks. Here’s the structure:

data
├── CameraRGB
│   ├── F61-1.png
│   ├── F61-2.png
│       ...
└── CameraSeg
    ├── F61-1.png
    ├── F61-2.png
        ...

Once we’ve downloaded the data using download_data(), we create the SemanticSegmentationData. We select a pre-trained mobilenetv3_large_100 backbone with an fpn head to use for our SemanticSegmentation task and fine-tune on the CARLA data. We then use the trained SemanticSegmentation for inference. You can check the available pretrained weights for the backbones like this: SemanticSegmentation.available_pretrained_weights("resnet18"). Finally, we save the model. Here’s the full example:

import torch

import flash
from flash.core.data.utils import download_data
from flash.image import SemanticSegmentation, SemanticSegmentationData

# 1. Create the DataModule
# The data was generated with the  CARLA self-driving simulator as part of the Kaggle Lyft Udacity Challenge.
# More info here: https://www.kaggle.com/kumaresanmanickavelu/lyft-udacity-challenge
download_data(
    "https://github.com/ongchinkiat/LyftPerceptionChallenge/releases/download/v0.1/carla-capture-20180513A.zip",
    "./data",
)

datamodule = SemanticSegmentationData.from_folders(
    train_folder="data/CameraRGB",
    train_target_folder="data/CameraSeg",
    val_split=0.1,
    transform_kwargs=dict(image_size=(256, 256)),
    num_classes=21,
    batch_size=4,
)

# 2. Build the task
model = SemanticSegmentation(
    backbone="mobilenetv3_large_100",
    head="fpn",
    num_classes=datamodule.num_classes,
)

# 3. Create the trainer and finetune the model
trainer = flash.Trainer(max_epochs=3, gpus=torch.cuda.device_count())
trainer.finetune(model, datamodule=datamodule, strategy="freeze")

# 4. Segment a few images!
datamodule = SemanticSegmentationData.from_files(
    predict_files=[
        "data/CameraRGB/F61-1.png",
        "data/CameraRGB/F62-1.png",
        "data/CameraRGB/F63-1.png",
    ],
    batch_size=3,
)
predictions = trainer.predict(model, datamodule=datamodule)
print(predictions)

# 5. Save the model!
trainer.save_checkpoint("semantic_segmentation_model.pt")

To learn how to view the available backbones / heads for this task, see Backbones and Heads.
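
As mentioned above, a quick sketch to inspect the pretrained weights available for a given backbone:

from flash.image import SemanticSegmentation

# List the pretrained weights registered for the "resnet18" backbone
print(SemanticSegmentation.available_pretrained_weights("resnet18"))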


Flash Zero

The semantic segmentation task can be used directly from the command line with zero code using Flash Zero. You can run the above example with:

flash semantic_segmentation

To view configuration options and options for running the semantic segmentation task with your own data, use:

flash semantic_segmentation --help

Serving

The SemanticSegmentation task is servable. This means you can call .serve to serve your Task. Here’s an example:

from flash.image import SemanticSegmentation
from flash.image.segmentation.output import SegmentationLabelsOutput

model = SemanticSegmentation.load_from_checkpoint(
    "https://flash-weights.s3.amazonaws.com/0.7.0/semantic_segmentation_model.pt"
)
model.output = SegmentationLabelsOutput(visualize=False)
model.serve()

You can now perform inference from your client like this:

import base64
from pathlib import Path

import requests

import flash

with (Path(flash.ASSETS_ROOT) / "road.png").open("rb") as f:
    imgstr = base64.b64encode(f.read()).decode("UTF-8")

body = {"session": "UUID", "payload": {"inputs": {"data": imgstr}}}
resp = requests.post("http://127.0.0.1:8000/predict", json=body)
print(resp.json())

Style Transfer

The Task

The Neural Style Transfer Task is an optimization method which extracts the style from an image and applies it to another image while preserving that image’s content. The goal is that the output image looks like the content image, but “painted” in the style of the style reference image.

[Figure: style transfer example]

The StyleTransfer and StyleTransferData classes internally rely on pystiche.


Example

Let’s look at transferring the style from The Starry Night onto the images from the COCO 128 data set from the Object Detection Guide. Once we’ve downloaded the data using download_data(), we create the StyleTransferData. Next, we create our StyleTransfer task with the desired style image and fit on the COCO 128 images. We then use the trained StyleTransfer for inference. Finally, we save the model. Here’s the full example:

import os

import torch

import flash
from flash.core.data.utils import download_data
from flash.image.style_transfer import StyleTransfer, StyleTransferData

# 1. Create the DataModule
download_data("https://github.com/zhiqwang/yolov5-rt-stack/releases/download/v0.3.0/coco128.zip", "./data")

datamodule = StyleTransferData.from_folders(train_folder="data/coco128/images/train2017", batch_size=1)

# 2. Build the task
model = StyleTransfer(os.path.join(flash.ASSETS_ROOT, "starry_night.jpg"))

# 3. Create the trainer and train the model
trainer = flash.Trainer(max_epochs=3, gpus=torch.cuda.device_count())
trainer.fit(model, datamodule=datamodule)

# 4. Apply style transfer to a few images!
datamodule = StyleTransferData.from_files(
    predict_files=[
        "data/coco128/images/train2017/000000000625.jpg",
        "data/coco128/images/train2017/000000000626.jpg",
        "data/coco128/images/train2017/000000000629.jpg",
    ],
    batch_size=3,
)
predictions = trainer.predict(model, datamodule=datamodule)
print(predictions)

# 5. Save the model!
trainer.save_checkpoint("style_transfer_model.pt")

To learn how to view the available backbones / heads for this task, see Backbones and Heads.


Flash Zero

The style transfer task can be used directly from the command line with zero code using Flash Zero. You can run the above example with:

flash style_transfer

To view configuration options and options for running the style transfer task with your own data, use:

flash style_transfer --help

Video Classification

The Task

Typically, Video Classification refers to the task of producing a label for actions identified in a given video. The task is to predict which class the video clip belongs to.

Lightning Flash VideoClassifier and VideoClassificationData classes internally rely on PyTorchVideo.


Example

Let’s develop a model to classify video clips of humans performing actions (such as archery, bowling, etc.). We’ll use data from the Kinetics dataset. Here’s an outline of the folder structure:

video_dataset
├── train
│   ├── archery
│   │   ├── -1q7jA3DXQM_000005_000015.mp4
│   │   ├── -5NN5hdIwTc_000036_000046.mp4
│   │   ...
│   ├── bowling
│   │   ├── -5ExwuF5IUI_000030_000040.mp4
│   │   ├── -7sTNNI1Bcg_000075_000085.mp4
│   ... ...
└── val
    ├── archery
    │   ├── 0S-P4lr_c7s_000022_000032.mp4
    │   ├── 2x1lIrgKxYo_000589_000599.mp4
    │   ...
    ├── bowling
    │   ├── 1W7HNDBA4pA_000002_000012.mp4
    │   ├── 4JxH3S5JwMs_000003_000013.mp4
    ... ...

Once we’ve downloaded the data using download_data(), we create the VideoClassificationData. We select a pre-trained backbone to use for our VideoClassifier and fine-tune on the Kinetics data. The backbone can be any model from the PyTorchVideo Model Zoo. We then use the trained VideoClassifier for inference. Finally, we save the model. Here’s the full example:

import torch

import flash
from flash.core.data.utils import download_data
from flash.video import VideoClassificationData, VideoClassifier

# 1. Create the DataModule
# Find more datasets at https://pytorchvideo.readthedocs.io/en/latest/data.html
download_data("https://pl-flash-data.s3.amazonaws.com/kinetics.zip", "./data")

datamodule = VideoClassificationData.from_folders(
    train_folder="data/kinetics/train",
    val_folder="data/kinetics/val",
    clip_sampler="uniform",
    clip_duration=1,
    decode_audio=False,
    batch_size=1,
)

# 2. Build the task
model = VideoClassifier(backbone="x3d_xs", labels=datamodule.labels, pretrained=False)

# 3. Create the trainer and finetune the model
trainer = flash.Trainer(
    max_epochs=1, gpus=torch.cuda.device_count(), strategy="ddp" if torch.cuda.device_count() > 1 else None
)
trainer.finetune(model, datamodule=datamodule, strategy="freeze")

# 4. Make a prediction
datamodule = VideoClassificationData.from_folders(predict_folder="data/kinetics/predict", batch_size=1)
predictions = trainer.predict(model, datamodule=datamodule, output="labels")
print(predictions)

# 5. Save the model!
trainer.save_checkpoint("video_classification.pt")

To learn how to view the available backbones / heads for this task, see Backbones and Heads.


Flash Zero

The video classifier can be used directly from the command line with zero code using Flash Zero. You can run the above example with:

flash video_classification

To view configuration options and options for running the video classifier with your own data, use:

flash video_classification --help

Audio Classification

The Task

The task of identifying what is in an audio file is called audio classification. Typically, Audio Classification is used to identify audio files containing sounds or words. The task predicts which ‘class’ the sound or words most likely belong to with a degree of certainty. A class is a label that describes the sounds in an audio file, such as ‘children_playing’, ‘jackhammer’, ‘siren’ etc.


Example

Let’s look at the task of predicting whether an audio file contains sounds of an air_conditioner, car_horn, children_playing, dog_bark, drilling, engine_idling, gun_shot, jackhammer, siren, or street_music using the UrbanSound8k spectrogram images dataset. The dataset contains train, val and test folders, and each of these contains one folder per class (for example, an air_conditioner folder with spectrograms generated from air-conditioner sounds, a siren folder with spectrograms generated from siren sounds, and so on).

urban8k_images
├── train
│   ├── air_conditioner
│   ├── car_horn
│   ├── children_playing
│   ├── dog_bark
│   ├── drilling
│   ├── engine_idling
│   ├── gun_shot
│   ├── jackhammer
│   ├── siren
│   └── street_music
├── test
│   ├── air_conditioner
│   ├── car_horn
│   ├── children_playing
│   ├── dog_bark
│   ├── drilling
│   ├── engine_idling
│   ├── gun_shot
│   ├── jackhammer
│   ├── siren
│   └── street_music
└── val
    ├── air_conditioner
    ├── car_horn
    ├── children_playing
    ├── dog_bark
    ├── drilling
    ├── engine_idling
    ├── gun_shot
    ├── jackhammer
    ├── siren
    └── street_music

Once we’ve downloaded the data using download_data(), we create the AudioClassificationData. We select a pre-trained backbone to use for our ImageClassifier and fine-tune on the UrbanSound8k spectrogram images data. We then use the trained ImageClassifier for inference. Finally, we save the model. Here’s the full example:

import torch

import flash
from flash.audio import AudioClassificationData
from flash.core.data.utils import download_data
from flash.image import ImageClassifier

# 1. Create the DataModule
download_data("https://pl-flash-data.s3.amazonaws.com/urban8k_images.zip", "./data")

datamodule = AudioClassificationData.from_folders(
    train_folder="data/urban8k_images/train",
    val_folder="data/urban8k_images/val",
    transform_kwargs=dict(spectrogram_size=(64, 64)),
    batch_size=4,
)

# 2. Build the model.
model = ImageClassifier(backbone="resnet18", labels=datamodule.labels)

# 3. Create the trainer and finetune the model
trainer = flash.Trainer(max_epochs=3, gpus=torch.cuda.device_count())
trainer.finetune(model, datamodule=datamodule, strategy=("freeze_unfreeze", 1))

# 4. Predict what's on a few images! air_conditioner, children_playing, siren etc.
datamodule = AudioClassificationData.from_files(
    predict_files=[
        "data/urban8k_images/test/air_conditioner/13230-0-0-5.wav.jpg",
        "data/urban8k_images/test/children_playing/9223-2-0-15.wav.jpg",
        "data/urban8k_images/test/jackhammer/22883-7-10-0.wav.jpg",
    ],
    batch_size=3,
)
predictions = trainer.predict(model, datamodule=datamodule, output="labels")
print(predictions)

# 5. Save the model!
trainer.save_checkpoint("audio_classification_model.pt")

To learn how to view the available backbones / heads for this task, see Backbones and Heads.


Flash Zero

The audio classifier can be used directly from the command line with zero code using Flash Zero. You can run the above example with:

flash audio_classification

To view configuration options and options for running the audio classifier with your own data, use:

flash audio_classification --help

Speech Recognition

The Task

Speech recognition is the task of transcribing audio into text. We rely on Wav2Vec as our backbone, fine-tuned on labeled transcriptions for speech to text. Wav2Vec is pre-trained on thousands of hours of unlabeled audio, providing a strong baseline when fine-tuning to downstream tasks such as speech recognition.


Example

Let’s fine-tune the model on our own labeled audio transcription data.

Here’s the structure of our CSV file:

file,text
"/path/to/file_1.wav","what was said in file 1."
"/path/to/file_2.wav","what was said in file 2."
"/path/to/file_3.wav","what was said in file 3."
...

Alternatively, here is the structure of our JSON file:

{"file": "/path/to/file_1.wav", "text": "what was said in file 1."}
{"file": "/path/to/file_2.wav", "text": "what was said in file 2."}
{"file": "/path/to/file_3.wav", "text": "what was said in file 3."}

Once we’ve downloaded the data using download_data(), we create the SpeechRecognitionData. We select a pre-trained Wav2Vec backbone to use for our SpeechRecognition and finetune on a subset of the TIMIT corpus. The backbone can be any Wav2Vec model from HuggingFace transformers. Next, we use the trained SpeechRecognition for inference and save the model. Here’s the full example:

import torch

import flash
from flash.audio import SpeechRecognition, SpeechRecognitionData
from flash.core.data.utils import download_data

# 1. Create the DataModule
download_data("https://pl-flash-data.s3.amazonaws.com/timit_data.zip", "./data")

datamodule = SpeechRecognitionData.from_json(
    "file",
    "text",
    train_file="data/timit/train.json",
    test_file="data/timit/test.json",
    batch_size=4,
)

# 2. Build the task
model = SpeechRecognition(backbone="facebook/wav2vec2-base-960h")

# 3. Create the trainer and finetune the model
trainer = flash.Trainer(max_epochs=1, gpus=torch.cuda.device_count())
trainer.finetune(model, datamodule=datamodule, strategy="freeze")

# 4. Predict on audio files!
datamodule = SpeechRecognitionData.from_files(predict_files=["data/timit/example.wav"], batch_size=4)
predictions = trainer.predict(model, datamodule=datamodule)
print(predictions)

# 5. Save the model!
trainer.save_checkpoint("speech_recognition_model.pt")

To learn how to view the available backbones / heads for this task, see Backbones and Heads.


Flash Zero

The speech recognition task can be used directly from the command line with zero code using Flash Zero. You can run the above example with:

flash speech_recognition

To view configuration options and options for running the speech recognition task with your own data, use:

flash speech_recognition --help

Serving

The SpeechRecognition is servable. This means you can call .serve to serve your Task. Here’s an example:

from flash.audio import SpeechRecognition

model = SpeechRecognition.load_from_checkpoint(
    "https://flash-weights.s3.amazonaws.com/0.7.0/speech_recognition_model.pt"
)
model.serve()

You can now perform inference from your client like this:

import base64
from pathlib import Path

import requests

import flash

with (Path(flash.ASSETS_ROOT) / "example.wav").open("rb") as f:
    audio_str = base64.b64encode(f.read()).decode("UTF-8")

body = {"session": "UUID", "payload": {"inputs": {"data": audio_str}}}
resp = requests.post("http://127.0.0.1:8000/predict", json=body)

print(resp.json())

Tabular Classification

The Task

Tabular classification is the task of assigning a class to samples of structured or relational data. The TabularClassifier task can be used for classification of samples in more than two classes (multi-class classification).


Example

Let’s look at training a model to predict passenger survival on the Titanic using the classic Kaggle data set. The data is provided in CSV files that look like this:

PassengerId,Survived,Pclass,Name,Sex,Age,SibSp,Parch,Ticket,Fare,Cabin,Embarked
1,0,3,"Braund, Mr. Owen Harris",male,22,1,0,A/5 21171,7.25,,S
3,1,3,"Heikkinen, Miss. Laina",female,26,0,0,STON/O2. 3101282,7.925,,S
5,0,3,"Allen, Mr. William Henry",male,35,0,0,373450,8.05,,S
6,0,3,"Moran, Mr. James",male,,0,0,330877,8.4583,,Q
...

Once we’ve downloaded the data using download_data(), we can create the TabularClassificationData from our CSV files using the from_csv() method. From the API reference, we need to provide:

  • categorical_fields - A list of the names of columns that contain categorical data (strings or integers).

  • numerical_fields - A list of the names of columns that contain numerical continuous data (floats).

  • target_fields - The name of the column we want to predict.

  • train_file - The path to a CSV file containing the training data.

Next, we create the TabularClassifier and finetune on the Titanic data. We then use the trained TabularClassifier for inference. Finally, we save the model. Here’s the full example:

import torch

import flash
from flash.core.data.utils import download_data
from flash.tabular import TabularClassificationData, TabularClassifier

# 1. Create the DataModule
download_data("https://pl-flash-data.s3.amazonaws.com/titanic.zip", "./data")

datamodule = TabularClassificationData.from_csv(
    categorical_fields=["Sex", "Age", "SibSp", "Parch", "Ticket", "Cabin", "Embarked"],
    numerical_fields="Fare",
    target_fields="Survived",
    train_file="data/titanic/titanic.csv",
    val_split=0.1,
    batch_size=8,
)

# 2. Build the task
model = TabularClassifier.from_data(datamodule, backbone="fttransformer")

# 3. Create the trainer and train the model
trainer = flash.Trainer(max_epochs=3, gpus=torch.cuda.device_count())
trainer.fit(model, datamodule=datamodule)

# 4. Generate predictions from a CSV
datamodule = TabularClassificationData.from_csv(
    predict_file="data/titanic/titanic.csv",
    parameters=datamodule.parameters,
    batch_size=8,
)
predictions = trainer.predict(model, datamodule=datamodule, output="classes")
print(predictions)

# 5. Save the model!
trainer.save_checkpoint("tabular_classification_model.pt")

To learn how to view the available backbones / heads for this task, see Backbones and Heads.


Flash Zero

The tabular classifier can be used directly from the command line with zero code using Flash Zero. You can run the above example with:

flash tabular_classifier

To view configuration options and options for running the tabular classifier with your own data, use:

flash tabular_classifier --help

Serving

The TabularClassifier is servable. This means you can call .serve to serve your Task. Here’s an example:

from flash.core.classification import LabelsOutput
from flash.tabular import TabularClassifier

model = TabularClassifier.load_from_checkpoint(
    "https://flash-weights.s3.amazonaws.com/0.7.0/tabular_classification_model.pt"
)
model.output = LabelsOutput(["Did not survive", "Survived"])
model.serve()

You can now perform inference from your client like this:

import pandas as pd
import requests

from flash.core.data.utils import download_data

# 1. Download the data
download_data("https://pl-flash-data.s3.amazonaws.com/titanic.zip", "data/")

df = pd.read_csv("./data/titanic/predict.csv")
text = str(df.to_csv())
body = {"session": "UUID", "payload": {"inputs": {"data": text}}}
resp = requests.post("http://127.0.0.1:8000/predict", json=body)
print(resp.json())

Tabular Forecasting

The Task

Tabular (or timeseries) forecasting is the task of using historical data to predict future trends in a time-varying quantity, such as stock prices or temperature. The TabularForecaster and TabularForecastingData enable timeseries forecasting in Flash using PyTorch Forecasting.


Example

Let’s look at training the NBeats model on some synthetic data with seasonal changes. The data could represent many naturally occurring timeseries, such as energy demand, which fluctuates throughout the day but is also expected to change with the season. This example is a reimplementation of the NBeats tutorial from the PyTorch Forecasting docs in Flash. The NBeats model takes no additional inputs, unlike other, more complex models such as the Temporal Fusion Transformer.

Once we’ve created our DataFrame, we can create the TabularForecastingData from it using the from_data_frame() method. To this method, we provide any configuration arguments that should be used when internally constructing the TimeSeriesDataSet.

Next, we create the TabularForecaster and train on the data. We then use the trained TabularForecaster for inference. Finally, we save the model. Here’s the full example:

import torch

import flash
from flash.core.utilities.imports import example_requires
from flash.tabular.forecasting import TabularForecaster, TabularForecastingData

example_requires("tabular")

import pandas as pd  # noqa: E402
from pytorch_forecasting.data import NaNLabelEncoder  # noqa: E402
from pytorch_forecasting.data.examples import generate_ar_data  # noqa: E402

# Example based on this tutorial: https://pytorch-forecasting.readthedocs.io/en/latest/tutorials/ar.html
# 1. Create the DataModule
data = generate_ar_data(seasonality=10.0, timesteps=400, n_series=100, seed=42)
data["date"] = pd.Timestamp("2020-01-01") + pd.to_timedelta(data.time_idx, "D")

max_encoder_length = 60
max_prediction_length = 20

training_cutoff = data["time_idx"].max() - max_prediction_length

datamodule = TabularForecastingData.from_data_frame(
    time_idx="time_idx",
    target="value",
    categorical_encoders={"series": NaNLabelEncoder().fit(data.series)},
    group_ids=["series"],
    # only unknown variable is "value" - and N-Beats can also not take any additional variables
    time_varying_unknown_reals=["value"],
    max_encoder_length=max_encoder_length,
    max_prediction_length=max_prediction_length,
    train_data_frame=data[lambda x: x.time_idx <= training_cutoff],
    # validate on the last sequence
    val_data_frame=data[lambda x: x.time_idx > training_cutoff - max_encoder_length],
    batch_size=32,
)

# 2. Build the task
model = TabularForecaster(
    datamodule.parameters,
    backbone="n_beats",
    backbone_kwargs={"widths": [32, 512], "backcast_loss_ratio": 0.1},
)

# 3. Create the trainer and train the model
trainer = flash.Trainer(max_epochs=1, gpus=torch.cuda.device_count(), gradient_clip_val=0.01)
trainer.fit(model, datamodule=datamodule)

# 4. Generate predictions
datamodule = TabularForecastingData.from_data_frame(
    predict_data_frame=data[lambda x: x.time_idx > training_cutoff - max_encoder_length],
    parameters=datamodule.parameters,
    batch_size=32,
)
predictions = trainer.predict(model, datamodule=datamodule)
print(predictions)

# 5. Save the model!
trainer.save_checkpoint("tabular_forecasting_model.pt")

To learn how to view the available backbones / heads for this task, see Backbones and Heads.

Note

Read more about our integration with PyTorch Forecasting to see how to use your Flash model with their built-in plotting capabilities.

Text Classification

The Task

Text classification is the task of assigning a piece of text (word, sentence or document) an appropriate class, or category. The categories depend on the chosen data set and can range from topics (such as news categories) to sentiment (such as positive or negative).


Example

Let’s train a model to classify text as expressing either positive or negative sentiment. We will be using the IMDB data set, which contains a train.csv and valid.csv. Here’s the structure:

review,sentiment
"Japanese indie film with humor ... ",positive
"Isaac Florentine has made some ...",negative
"After seeing the low-budget ...",negative
"I've seen the original English version ...",positive
"Hunters chase what they think is a man through ...",negative
...

Once we’ve downloaded the data using download_data(), we create the TextClassificationData. We select a pre-trained backbone to use for our TextClassifier and finetune on the IMDB data. The backbone can be any BERT classification model from HuggingFace/transformers.

Note

When changing the backbone, make sure you pass in the same backbone to the TextClassifier and the TextClassificationData!

Next, we use the trained TextClassifier for inference. Finally, we save the model. Here’s the full example:

import torch

import flash
from flash.core.data.utils import download_data
from flash.text import TextClassificationData, TextClassifier

# 1. Create the DataModule
download_data("https://pl-flash-data.s3.amazonaws.com/imdb.zip", "./data/")

datamodule = TextClassificationData.from_csv(
    "review",
    "sentiment",
    train_file="data/imdb/train.csv",
    val_file="data/imdb/valid.csv",
    batch_size=4,
)

# 2. Build the task
model = TextClassifier(backbone="prajjwal1/bert-medium", labels=datamodule.labels)

# 3. Create the trainer and finetune the model
trainer = flash.Trainer(max_epochs=3, gpus=torch.cuda.device_count())
trainer.finetune(model, datamodule=datamodule, strategy="freeze")

# 4. Classify a few sentences! How was the movie?
datamodule = TextClassificationData.from_lists(
    predict_data=[
        "Turgid dialogue, feeble characterization - Harvey Keitel a judge?.",
        "The worst movie in the history of cinema.",
        "I come from Bulgaria where it 's almost impossible to have a tornado.",
    ],
    batch_size=4,
)
predictions = trainer.predict(model, datamodule=datamodule, output="labels")
print(predictions)

# 5. Save the model!
trainer.save_checkpoint("text_classification_model.pt")

To learn how to view the available backbones / heads for this task, see Backbones and Heads.


Flash Zero

The text classifier can be used directly from the command line with zero code using Flash Zero. You can run the above example with:

flash text_classification

To view configuration options and options for running the text classifier with your own data, use:

flash text_classification --help

Serving

The TextClassifier is servable. This means you can call .serve to serve your Task. Here’s an example:

from flash.text import TextClassifier

model = TextClassifier.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/0.7.0/text_classification_model.pt")
model.serve()

You can now perform inference from your client like this:

import requests

text = "Best movie ever"
body = {"session": "UUID", "payload": {"inputs": {"data": text}}}
resp = requests.post("http://127.0.0.1:8000/predict", json=body)

print(resp.json())

Accelerate Training & Inference with Torch ORT

Torch ORT converts your model into an optimized ONNX graph, speeding up training & inference when using NVIDIA or AMD GPUs. Enabling Torch ORT requires a single flag passed to the TextClassifier once installed. See installation instructions here.

Note

Not all Transformer models are supported. See this table for supported models + branches containing fixes for certain models.

...

model = TextClassifier(backbone="facebook/bart-large", num_classes=datamodule.num_classes, enable_ort=True)

Multi-label Text Classification

The Task

Multi-label classification is the task of assigning a number of labels from a fixed set to each data point, which can be in any modality (text in this case). Multi-label text classification is supported by the TextClassifier via the multi-label argument.


Example

Let’s look at the task of classifying comment toxicity. The data we will use in this example is from the Kaggle Toxic Comment Classification Challenge by Jigsaw: www.kaggle.com/c/jigsaw-toxic-comment-classification-challenge. The data is stored in CSV files with this structure:

"id","comment_text","toxic","severe_toxic","obscene","threat","insult","identity_hate"
"0000997932d777bf","...",0,0,0,0,0,0
"0002bcb3da6cb337","...",1,1,1,0,1,0
"0005c987bdfc9d4b","...",1,0,0,0,0,0
...

Once we’ve downloaded the data using download_data(), we create the TextClassificationData. We select a pre-trained backbone to use for our TextClassifier and finetune on the toxic comments data. The backbone can be any BERT classification model from HuggingFace/transformers.

Note

When changing the backbone, make sure you pass in the same backbone to the TextClassifier and the TextClassificationData!

Next, we use the trained TextClassifier for inference. Finally, we save the model. Here’s the full example:

import torch

import flash
from flash.core.data.utils import download_data
from flash.text import TextClassificationData, TextClassifier

# 1. Create the DataModule
# Data from the Kaggle Toxic Comment Classification Challenge:
# https://www.kaggle.com/c/jigsaw-toxic-comment-classification-challenge
download_data("https://pl-flash-data.s3.amazonaws.com/jigsaw_toxic_comments.zip", "./data")

datamodule = TextClassificationData.from_csv(
    "comment_text",
    ["toxic", "severe_toxic", "obscene", "threat", "insult", "identity_hate"],
    train_file="data/jigsaw_toxic_comments/train.csv",
    val_split=0.1,
    batch_size=4,
)

# 2. Build the task
model = TextClassifier(
    backbone="unitary/toxic-bert",
    labels=datamodule.labels,
    multi_label=datamodule.multi_label,
)

# 3. Create the trainer and finetune the model
trainer = flash.Trainer(max_epochs=1, gpus=torch.cuda.device_count())
trainer.finetune(model, datamodule=datamodule, strategy="freeze")

# 4. Generate predictions for a few comments!
datamodule = TextClassificationData.from_lists(
    predict_data=[
        "No, he is an arrogant, self serving, immature idiot. Get it right.",
        "U SUCK HANNAH MONTANA",
        "Would you care to vote? Thx.",
    ],
    batch_size=4,
)
predictions = trainer.predict(model, datamodule=datamodule, output="labels")
print(predictions)

# 5. Save the model!
trainer.save_checkpoint("text_classification_multi_label_model.pt")

To learn how to view the available backbones / heads for this task, see Backbones and Heads.


Flash Zero

The multi-label text classifier can be used directly from the command line with zero code using Flash Zero. You can run the above example with:

flash text_classification from_toxic

To view configuration options and options for running the text classifier with your own data, use:

flash text_classification --help

Serving

The TextClassifier is servable. For more information, see Text Classification.

Question Answering

The Task

Question Answering is the task of answering questions that pertain to some known context. For example, given a context about some historical figure, any question pertaining to that context should be answerable. In our case, the context and the question together form the input, and the answer is the output sequence from the model.

Note

We currently only support Extractive Question Answering, such as the task performed on SQuAD-like datasets.


Example

Let’s look at an example. We’ll use the SQuAD 2.0 dataset, which contains train-v2.0.json and dev-v2.0.json. Each entry in these JSON files looks like this:

{
            "answers": {
                    "answer_start": [94, 87, 94, 94],
                    "text": ["10th and 11th centuries", "in the 10th and 11th centuries", "10th and 11th centuries", "10th and 11th centuries"]
            },
            "context": "\"The Normans (Norman: Nourmands; French: Normands; Latin: Normanni) were the people who in the 10th and 11th centuries gave thei...",
            "id": "56ddde6b9a695914005b9629",
            "question": "When were the Normans in Normandy?",
            "title": "Normans"
    }
...

In the above, the context key contains the passage used for the question and answer, the question key holds the question being asked with respect to the context, and the answers key stores the answer(s) to the question. The id and title keys are used for unique identification and for grouping concepts together, respectively. Once we’ve downloaded the data using download_data(), we create the QuestionAnsweringData. We select a pre-trained backbone to use for our QuestionAnsweringTask and finetune on the SQuAD 2.0 data. The backbone can be any Question Answering model from HuggingFace/transformers.

Note

When changing the backbone, make sure you pass in the same backbone to the QuestionAnsweringData and the QuestionAnsweringTask!

Next, we use the trained QuestionAnsweringTask for inference. Finally, we save the model. Here’s the full example:

from flash import Trainer
from flash.core.data.utils import download_data
from flash.core.utilities.imports import example_requires
from flash.text import QuestionAnsweringData, QuestionAnsweringTask

example_requires("text")

import nltk  # noqa: E402

nltk.download("punkt")

# 1. Create the DataModule
download_data("https://pl-flash-data.s3.amazonaws.com/squad_tiny.zip", "./data/")

datamodule = QuestionAnsweringData.from_squad_v2(
    train_file="./data/squad_tiny/train.json",
    val_file="./data/squad_tiny/val.json",
    batch_size=4,
)

# 2. Build the task
model = QuestionAnsweringTask(backbone="distilbert-base-uncased")

# 3. Create the trainer and finetune the model
trainer = Trainer(max_epochs=3)
trainer.finetune(model, datamodule=datamodule)

# 4. Answer some Questions!
datamodule = QuestionAnsweringData.from_dicts(
    predict_data={
        "id": ["56ddde6b9a695914005b9629", "56ddde6b9a695914005b9628"],
        "context": [
            """
        The Normans (Norman: Nourmands; French: Normands; Latin: Normanni) were the people who in the 10th
        and 11th centuries gave their name to Normandy, a region in France. They were descended from Norse
        ("Norman" comes from "Norseman") raiders and pirates from Denmark, Iceland and Norway who, under
        their leader Rollo, agreed to swear fealty to King Charles III of West Francia. Through generations
        of assimilation and mixing with the native Frankish and Roman-Gaulish populations, their
        descendants would gradually merge with the Carolingian-based cultures of West Francia. The distinct
        cultural and ethnic identity of the Normans emerged initially in the first half of the 10th
        century, and it continued to evolve over the succeeding centuries.
        """,
            """
        The Normans (Norman: Nourmands; French: Normands; Latin: Normanni) were the people who in the 10th
        and 11th centuries gave their name to Normandy, a region in France. They were descended from Norse
        ("Norman" comes from "Norseman") raiders and pirates from Denmark, Iceland and Norway who, under
        their leader Rollo, agreed to swear fealty to King Charles III of West Francia. Through generations
        of assimilation and mixing with the native Frankish and Roman-Gaulish populations, their
        descendants would gradually merge with the Carolingian-based cultures of West Francia. The distinct
        cultural and ethnic identity of the Normans emerged initially in the first half of the 10th
        century, and it continued to evolve over the succeeding centuries.
        """,
        ],
        "question": ["When were the Normans in Normandy?", "In what country is Normandy located?"],
    },
    batch_size=4,
)
predictions = trainer.predict(model, datamodule=datamodule)
print(predictions)

# 5. Save the model!
trainer.save_checkpoint("question_answering_on_sqaud_v2.pt")

To learn how to view the available backbones / heads for this task, see Backbones and Heads.


Accelerate Training & Inference with Torch ORT

Torch ORT converts your model into an optimized ONNX graph, speeding up training & inference when using NVIDIA or AMD GPUs. Enabling Torch ORT requires a single flag passed to the QuestionAnsweringTask once installed. See installation instructions here.

Note

Not all Transformer models are supported. See this table for supported models + branches containing fixes for certain models.

...

model = QuestionAnsweringTask(backbone="distilbert-base-uncased", max_answer_length=30, enable_ort=True)

Summarization

The Task

Summarization is the task of summarizing text from a larger document/article into a short sentence/description. For example, taking a web article and describing the topic in a short sentence. This task is a subset of Sequence to Sequence tasks, which require the model to generate a variable length sequence given an input sequence. In our case the article would be our input sequence, and the short description/sentence would be the output sequence from the model.


Example

Let’s look at an example. We’ll use the XSUM dataset, which contains a train.csv and valid.csv. Each CSV file looks like this:

input,target
"The researchers have sequenced the genome of a strain of bacterium that causes the virulent infection...","A team of UK scientists hopes to shed light on the mysteries of bleeding canker, a disease that is threatening the nation's horse chestnut trees."
"Knight was shot in the leg by an unknown gunman at Miami's Shore Club where West was holding a pre-MTV Awards...",Hip hop star Kanye West is being sued by Death Row Records founder Suge Knight over a shooting at a beach party in August 2005.
...

In the above, the input column contains the long articles/documents, and the target column contains the short descriptions used as targets for the model. Once we’ve downloaded the data using download_data(), we create the SummarizationData. We select a pre-trained backbone to use for our SummarizationTask and finetune on the XSUM data. The backbone can be any Seq2Seq summarization model from HuggingFace/transformers.

Note

When changing the backbone, make sure you pass in the same backbone to the SummarizationData and the SummarizationTask!

Next, we use the trained SummarizationTask for inference. Finally, we save the model. Here’s the full example:

from flash import Trainer
from flash.core.data.utils import download_data
from flash.text import SummarizationData, SummarizationTask

# 1. Create the DataModule
download_data("https://pl-flash-data.s3.amazonaws.com/xsum.zip", "./data/")

datamodule = SummarizationData.from_csv(
    "input",
    "target",
    train_file="data/xsum/train.csv",
    val_file="data/xsum/valid.csv",
    batch_size=4,
)

# 2. Build the task
model = SummarizationTask()

# 3. Create the trainer and finetune the model
trainer = Trainer(max_epochs=3)
trainer.finetune(model, datamodule=datamodule, strategy="freeze")

# 4. Summarize some text!
datamodule = SummarizationData.from_lists(
    predict_data=[
        """
        Camilla bought a box of mangoes with a Brixton £10 note, introduced last year to try to keep the money of local
        people within the community.The couple were surrounded by shoppers as they walked along Electric Avenue.
        They came to Brixton to see work which has started to revitalise the borough.
        It was Charles' first visit to the area since 1996, when he was accompanied by the former
        South African president Nelson Mandela.Greengrocer Derek Chong, who has run a stall on Electric Avenue
        for 20 years, said Camilla had been ""nice and pleasant"" when she purchased the fruit.
        ""She asked me what was nice, what would I recommend, and I said we've got some nice mangoes.
        She asked me were they ripe and I said yes - they're from the Dominican Republic.""
        Mr Chong is one of 170 local retailers who accept the Brixton Pound.
        Customers exchange traditional pound coins for Brixton Pounds and then spend them at the market
        or in participating shops.
        During the visit, Prince Charles spent time talking to youth worker Marcus West, who works with children
        nearby on an estate off Coldharbour Lane. Mr West said:
        ""He's on the level, really down-to-earth. They were very cheery. The prince is a lovely man.""
        He added: ""I told him I was working with young kids and he said, 'Keep up all the good work.'""
        Prince Charles also visited the Railway Hotel, at the invitation of his charity The Prince's Regeneration Trust.
        The trust hopes to restore and refurbish the building,
        where once Jimi Hendrix and The Clash played, as a new community and business centre."
        """,
        """
        "The problem is affecting people using the older versions of the PlayStation 3, called the ""Fat"" model. The
        problem isn't affecting the newer PS3 Slim systems that have been on sale since September last year. Sony have
        also said they are aiming to have the problem fixed shortly but is advising some users to avoid using their
        console for the time being.""We hope to resolve this problem within the next 24 hours,"" a statement reads.
        ""In the meantime, if you have a model other than the new slim PS3, we advise that you do not use your PS3
        system, as doing so may result in errors in some functionality, such as recording obtained trophies, and not
        being able to restore certain data.""We believe we have identified that this problem is being caused by a bug
        in the clock functionality incorporated in the system.""The PlayStation Network is used by millions of people
        around the world.It allows users to play their friends at games like Fifa over the internet and also do things
        like download software or visit online stores.",Sony has told owners of older models of its PlayStation 3
        console to stop using the machine because of a problem with the PlayStation Network.
        """,
    ],
    batch_size=4,
)
predictions = trainer.predict(model, datamodule=datamodule)
print(predictions)

# 5. Save the model!
trainer.save_checkpoint("summarization_model_xsum.pt")

To learn how to view the available backbones / heads for this task, see Backbones and Heads.


Flash Zero

The summarization task can be used directly from the command line with zero code using Flash Zero. You can run the above example with:

flash summarization

To view configuration options and options for running the summarization task with your own data, use:

flash summarization --help

Serving

The SummarizationTask is servable. This means you can call .serve to serve your Task. Here’s an example:

from flash.text import SummarizationTask

model = SummarizationTask.load_from_checkpoint(
    "https://flash-weights.s3.amazonaws.com/0.7.0/summarization_model_xsum.pt"
)
model.serve()

You can now perform inference from your client like this:

import requests

text = """
Camilla bought a box of mangoes with a Brixton £10 note, introduced last year to try to keep the money of local
people within the community.The couple were surrounded by shoppers as they walked along Electric Avenue.
They came to Brixton to see work which has started to revitalise the borough.
It was Charles' first visit to the area since 1996, when he was accompanied by the former
South African president Nelson Mandela.Greengrocer Derek Chong, who has run a stall on Electric Avenue
for 20 years, said Camilla had been ""nice and pleasant"" when she purchased the fruit.
""She asked me what was nice, what would I recommend, and I said we've got some nice mangoes.
She asked me were they ripe and I said yes - they're from the Dominican Republic.""
Mr Chong is one of 170 local retailers who accept the Brixton Pound.
Customers exchange traditional pound coins for Brixton Pounds and then spend them at the market
or in participating shops.
During the visit, Prince Charles spent time talking to youth worker Marcus West, who works with children
nearby on an estate off Coldharbour Lane. Mr West said:
""He's on the level, really down-to-earth. They were very cheery. The prince is a lovely man.""
He added: ""I told him I was working with young kids and he said, 'Keep up all the good work.'""
Prince Charles also visited the Railway Hotel, at the invitation of his charity The Prince's Regeneration Trust.
The trust hopes to restore and refurbish the building,
where once Jimi Hendrix and The Clash played, as a new community and business centre."
"""
body = {"session": "UUID", "payload": {"inputs": {"data": text}}}
resp = requests.post("http://127.0.0.1:8000/predict", json=body)

print(resp.json())

Accelerate Training & Inference with Torch ORT

Torch ORT converts your model into an optimized ONNX graph, speeding up training & inference when using NVIDIA or AMD GPUs. Enabling Torch ORT requires a single flag passed to the SummarizationTask once installed. See installation instructions here.

Note

Not all Transformer models are supported. See this table for supported models + branches containing fixes for certain models.

...

model = SummarizationTask(backbone="t5-large", enable_ort=True)

Translation

The Task

Translation is the task of translating text from a source language to another, such as English to Romanian. This task is a subset of Sequence to Sequence tasks, which require the model to generate a variable length sequence given an input sequence. In our case, the task will take an English sequence as input and output the same sequence in Romanian.


Example

Let’s look at an example. We’ll use WMT16 English/Romanian, a dataset of English to Romanian samples, based on the Europarl corpora. The data set contains a train.csv and valid.csv. Each CSV file looks like this:

input,target
"Written statements and oral questions (tabling): see Minutes","Declaraţii scrise şi întrebări orale (depunere): consultaţi procesul-verbal"
"Closure of sitting","Ridicarea şedinţei"
...

In the above, the input and target columns contain the English text and the Romanian translation, respectively. Once we’ve downloaded the data using download_data(), we create the TranslationData. We select a pre-trained backbone to use for our TranslationTask and finetune on the WMT16 data. The backbone can be any Seq2Seq translation model from HuggingFace/transformers.

Note

When changing the backbone, make sure you pass in the same backbone to the TranslationData and the TranslationTask!

Next, we use the trained TranslationTask for inference. Finally, we save the model. Here’s the full example:

import torch

import flash
from flash.core.data.utils import download_data
from flash.text import TranslationData, TranslationTask

# 1. Create the DataModule
download_data("https://pl-flash-data.s3.amazonaws.com/wmt_en_ro.zip", "./data")

datamodule = TranslationData.from_csv(
    "input",
    "target",
    train_file="data/wmt_en_ro/train.csv",
    val_file="data/wmt_en_ro/valid.csv",
    batch_size=4,
)

# 2. Build the task
model = TranslationTask(backbone="Helsinki-NLP/opus-mt-en-ro")

# 3. Create the trainer and finetune the model
trainer = flash.Trainer(max_epochs=3, gpus=torch.cuda.device_count())
trainer.finetune(model, datamodule=datamodule, strategy="freeze")

# 4. Translate something!
datamodule = TranslationData.from_lists(
    predict_data=[
        "BBC News went to meet one of the project's first graduates.",
        "A recession has come as quickly as 11 months after the first rate hike and as long as 86 months.",
        "Of course, it's still early in the election cycle.",
    ],
    batch_size=4,
)
predictions = trainer.predict(model, datamodule=datamodule)
print(predictions)

# 5. Save the model!
trainer.save_checkpoint("translation_model_en_ro.pt")

To learn how to view the available backbones / heads for this task, see Backbones and Heads.


Flash Zero

The translation task can be used directly from the command line with zero code using Flash Zero. You can run the above example with:

flash translation

To view configuration options and options for running the translation task with your own data, use:

flash translation --help

Serving

The TranslationTask is servable. This means you can call .serve to serve your Task. Here’s an example:

from flash.text import TranslationTask

model = TranslationTask.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/0.7.0/translation_model_en_ro.pt")
model.serve()

You can now perform inference from your client like this:

import requests

text = "Some English text"
body = {"session": "UUID", "payload": {"inputs": {"data": text}}}
resp = requests.post("http://127.0.0.1:8000/predict", json=body)

print(resp.json())

Accelerate Training & Inference with Torch ORT

Torch ORT converts your model into an optimized ONNX graph, speeding up training & inference when using NVIDIA or AMD GPUs. Enabling Torch ORT requires a single flag passed to the TranslationTask once installed. See installation instructions here.

Note

Not all Transformer models are supported. See this table for supported models + branches containing fixes for certain models.

...

model = TranslationTask(backbone="t5-large", enable_ort=True)

Text Embedder

The Task

This task consists of creating a Sentence Embedding: a vector representation of a sentence which can be used for downstream tasks. The TextEmbedder implementation relies on components from sentence-transformers.


Example

Let’s look at an example of generating sentence embeddings.

We start by loading some sentences for prediction with the TextClassificationData class. Next, we create our TextEmbedder with a pretrained backbone from the HuggingFace hub. Finally, we create a Trainer and generate sentence embeddings. Here’s the full example:

import torch

import flash
from flash.text import TextClassificationData, TextEmbedder

# 1. Create the DataModule
datamodule = TextClassificationData.from_lists(
    predict_data=[
        "Turgid dialogue, feeble characterization - Harvey Keitel a judge?.",
        "The worst movie in the history of cinema.",
        "I come from Bulgaria where it 's almost impossible to have a tornado.",
    ],
    batch_size=4,
)

# 2. Create the TextEmbedder with a pre-trained backbone
model = TextEmbedder(backbone="sentence-transformers/all-MiniLM-L6-v2")

# 3. Generate embeddings for the 3 example sentences
trainer = flash.Trainer(gpus=torch.cuda.device_count())
predictions = trainer.predict(model, datamodule=datamodule)
print(predictions)
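
The predictions returned by Trainer.predict() are grouped per batch. As a small, hedged illustration (not part of the original example, and assuming each prediction element is a torch.Tensor embedding), you could flatten them into a single tensor for downstream use:

# Hypothetical post-processing: flatten the per-batch predictions into one tensor.
# Assumes each element yielded by trainer.predict() above is a torch.Tensor embedding.
embeddings = torch.stack([embedding for batch in predictions for embedding in batch])
print(embeddings.shape)  # e.g. (3, 384) for the all-MiniLM-L6-v2 backbone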

To learn how to view the available backbones / heads for this task, see Backbones and Heads.

Point Cloud Segmentation

The Task

A Point Cloud is a set of data points in space, usually described by x, y and z coordinates.

PointCloud Segmentation is the task of performing classification at the point level, meaning each point is associated with a given class. The current integration builds on top of Open3D-ML.


Example

Let’s look at an example using a data set generated from the KITTI Vision Benchmark. The data is a tiny subset of the original dataset and contains sequences of point clouds. It contains multiple folders, one for each sequence, and a meta.yaml file describing the classes and their official associated color map. A sequence should contain one folder for scans and one folder for labels, plus a pose.txt to re-align the sequence if required. Here’s the structure:

data
├── meta.yaml
├── 00
│   ├── scans
|   |    ├── 00000.bin
|   |    ├── 00001.bin
|   |    ...
│   ├── labels
|   |    ├── 00000.label
|   |    ├── 00001.label
|   |   ...
|   ├── pose.txt
│   ...
|
└── XX
   ├── scans
   |    ├── 00000.bin
   |    ├── 00001.bin
   |    ...
   ├── labels
   |    ├── 00000.label
   |    ├── 00001.label
   |   ...
   ├── pose.txt

Learn more: http://www.semantic-kitti.org/dataset.html

Once we’ve downloaded the data using download_data(), we create the PointCloudSegmentationData. We select a pre-trained randlanet_semantic_kitti backbone for our PointCloudSegmentation task. We then use the trained PointCloudSegmentation for inference. Finally, we save the model. Here’s the full example:

import torch

import flash
from flash.core.data.utils import download_data
from flash.pointcloud import PointCloudSegmentation, PointCloudSegmentationData

# 1. Create the DataModule
# Dataset Credit: http://www.semantic-kitti.org/
download_data("https://pl-flash-data.s3.amazonaws.com/SemanticKittiTiny.zip", "data/")

datamodule = PointCloudSegmentationData.from_folders(
    train_folder="data/SemanticKittiTiny/train",
    val_folder="data/SemanticKittiTiny/val",
    batch_size=4,
)

# 2. Build the task
model = PointCloudSegmentation(backbone="randlanet_semantic_kitti", num_classes=datamodule.num_classes)

# 3. Create the trainer and finetune the model
trainer = flash.Trainer(
    max_epochs=1, limit_train_batches=1, limit_val_batches=1, num_sanity_val_steps=0, gpus=torch.cuda.device_count()
)
trainer.fit(model, datamodule)

# 4. Predict what's within a few PointClouds?
datamodule = PointCloudSegmentationData.from_files(
    predict_files=[
        "data/SemanticKittiTiny/predict/000000.bin",
        "data/SemanticKittiTiny/predict/000001.bin",
    ],
    batch_size=4,
)
predictions = trainer.predict(model, datamodule=datamodule)
print(predictions)

# 5. Save the model!
trainer.save_checkpoint("pointcloud_segmentation_model.pt")

To learn how to view the available backbones / heads for this task, see Backbones and Heads.

https://raw.githubusercontent.com/intel-isl/Open3D-ML/master/docs/images/getting_started_ml_visualizer.gif

Flash Zero

The point cloud segmentation task can be used directly from the command line with zero code using Flash Zero. You can run the above example with:

flash pointcloud_segmentation

To view configuration options and options for running the point cloud segmentation task with your own data, use:

flash pointcloud_segmentation --help

Point Cloud Object Detection

The Task

A Point Cloud is a set of data points in space, usually described by x, y and z coordinates.

PointCloud Object Detection is the task of identifying 3D objects in point clouds and their associated classes and 3D bounding boxes.

The current integration builds on top of Open3D-ML.


Example

Let’s look at an example using a data set generated from the KITTI Vision Benchmark. The data is a tiny subset of the original dataset and contains sequences of point clouds.

The data contains:
  • one folder for scans

  • one folder for scan calibrations

  • one folder for labels

  • a meta.yaml file describing the classes and their official associated color map.

Here’s the structure:

data
├── meta.yaml
├── train
│   ├── scans
|   |    ├── 00000.bin
|   |    ├── 00001.bin
|   |    ...
│   ├── calibs
|   |    ├── 00000.txt
|   |    ├── 00001.txt
|   |   ...
│   ├── labels
|   |    ├── 00000.txt
|   |    ├── 00001.txt
│   ...
├── val
│   ...
├── predict
    ├── scans
    |   ├── 00000.bin
    |   ├── 00001.bin
    |
    ├── calibs
    |   ├── 00000.txt
    |   ├── 00001.txt
    ├── meta.yaml

Learn more: http://www.semantic-kitti.org/dataset.html

Once we’ve downloaded the data using download_data(), we create the PointCloudObjectDetectorData. We select a pre-trained pointpillars_kitti backbone for our PointCloudObjectDetector task. We then use the trained PointCloudObjectDetector for inference. Finally, we save the model. Here’s the full example:

import torch

import flash
from flash.core.data.utils import download_data
from flash.pointcloud import PointCloudObjectDetector, PointCloudObjectDetectorData

# 1. Create the DataModule
# Dataset Credit: http://www.semantic-kitti.org/
download_data("https://pl-flash-data.s3.amazonaws.com/KITTI_tiny.zip", "data/")

datamodule = PointCloudObjectDetectorData.from_folders(
    train_folder="data/KITTI_Tiny/Kitti/train",
    val_folder="data/KITTI_Tiny/Kitti/val",
    batch_size=4,
)

# 2. Build the task
model = PointCloudObjectDetector(backbone="pointpillars_kitti", num_classes=datamodule.num_classes)

# 3. Create the trainer and finetune the model
trainer = flash.Trainer(
    max_epochs=1, limit_train_batches=1, limit_val_batches=1, num_sanity_val_steps=0, gpus=torch.cuda.device_count()
)
trainer.fit(model, datamodule)

# 4. Predict what's within a few PointClouds?
datamodule = PointCloudObjectDetectorData.from_files(
    predict_files=[
        "data/KITTI_Tiny/Kitti/predict/scans/000000.bin",
        "data/KITTI_Tiny/Kitti/predict/scans/000001.bin",
    ],
    batch_size=4,
)
predictions = trainer.predict(model, datamodule=datamodule)
print(predictions)

# 5. Save the model!
trainer.save_checkpoint("pointcloud_detection_model.pt")

To learn how to view the available backbones / heads for this task, see Backbones and Heads.

https://raw.githubusercontent.com/intel-isl/Open3D-ML/master/docs/images/visualizer_BoundingBoxes.png

Flash Zero

The point cloud object detector can be used directly from the command line with zero code using Flash Zero. You can run the above example with:

flash pointcloud_detection

To view configuration options and options for running the point cloud object detector with your own data, use:

flash pointcloud_detection --help

Graph Classification

The Task

This task consists of classifying graphs. The task predicts which ‘class’ the graph belongs to. A class is a label that indicates the kind of graph. For example, a label may indicate whether one molecule interacts with another.

The GraphClassifier and GraphClassificationData classes internally rely on pytorch-geometric.


Example

Let’s look at the task of classifying graphs from the KKI data set from TU Dortmund University.

Once we’ve created the TUDataset, we create the GraphClassificationData. We then create our GraphClassifier and train on the KKI data. Next, we use the trained GraphClassifier for inference. Finally, we save the model. Here’s the full example:

import torch

import flash
from flash.core.utilities.imports import example_requires
from flash.graph import GraphClassificationData, GraphClassifier

example_requires("graph")

from torch_geometric.datasets import TUDataset  # noqa: E402

# 1. Create the DataModule
dataset = TUDataset(root="data", name="KKI")

datamodule = GraphClassificationData.from_datasets(
    train_dataset=dataset,
    val_split=0.1,
    batch_size=4,
)
# 2. Build the task
backbone_kwargs = {"hidden_channels": 512, "num_layers": 4}
model = GraphClassifier(
    num_features=datamodule.num_features, num_classes=datamodule.num_classes, backbone_kwargs=backbone_kwargs
)

# 3. Create the trainer and fit the model
trainer = flash.Trainer(max_epochs=3, gpus=torch.cuda.device_count())
trainer.fit(model, datamodule=datamodule)

# 4. Classify some graphs!
datamodule = GraphClassificationData.from_datasets(
    predict_dataset=dataset[:3],
    batch_size=4,
)
predictions = trainer.predict(model, datamodule=datamodule, output="classes")
print(predictions)

# 5. Save the model!
trainer.save_checkpoint("graph_classification_model.pt")

To learn how to view the available backbones / heads for this task, see Backbones and Heads.


Flash Zero

The graph classifier can be used directly from the command line with zero code using Flash Zero. You can run the above example with:

flash graph_classification

To view configuration options and options for running the graph classifier with your own data, use:

flash graph_classification --help

Graph Embedder

The Task

This task consists of creating an embedding of a graph. That is, a vector of features which can be used for a downstream task. The GraphEmbedder and GraphClassificationData classes internally rely on pytorch-geometric.


Example

Let’s look at generating embeddings of graphs from the KKI data set from TU Dortmund University.

We start by creating the TUDataset. Next, we load a trained GraphEmbedder (from a previously trained GraphClassifier). Finally, we save the model. Here’s the full example:

import torch
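
The body of this example appears to have been truncated in the source. As a hedged sketch only, modeled on the Graph Classification example above (the checkpoint URL, the use of the classifier's backbone attribute, and the exact arguments are assumptions rather than content from this document), the remaining steps might look like this:

import flash
from flash.core.utilities.imports import example_requires
from flash.graph import GraphClassificationData, GraphClassifier, GraphEmbedder

example_requires("graph")

from torch_geometric.datasets import TUDataset  # noqa: E402

# 1. Create the DataModule (prediction only)
dataset = TUDataset(root="data", name="KKI")

datamodule = GraphClassificationData.from_datasets(
    predict_dataset=dataset[:3],
    batch_size=4,
)

# 2. Build a GraphEmbedder from the backbone of a previously trained GraphClassifier
#    (assumed checkpoint URL, shown for illustration only)
classifier = GraphClassifier.load_from_checkpoint(
    "https://flash-weights.s3.amazonaws.com/0.7.0/graph_classification_model.pt"
)
model = GraphEmbedder(classifier.backbone)

# 3. Generate embeddings for the first 3 graphs
trainer = flash.Trainer(gpus=torch.cuda.device_count())
predictions = trainer.predict(model, datamodule=datamodule)
print(predictions)

# 4. Save the model!
trainer.save_checkpoint("graph_embedder_model.pt")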

To learn how to view the available backbones / heads for this task, see Backbones and Heads.

Providers

Flash is a framework integrator. We rely on many open source frameworks for our tasks, visualizations and backbones. Here’s a list of some of the providers we use for backbones and heads within Flash (check them out and star their repos to spread the open source love!):

You can also read our guides for some of our larger integrations:

BaaL

Bayesian Active Learning (BaaL) is an active learning library developed at ElementAI.

Active Learning is a sub-field of AI that focuses on adding a human to the learning loop. The most uncertain samples are labelled by the human to accelerate the model training cycle.

Credit to the ElementAI / BaaL team for creating this diagram.


With its integration within Flash, the Active Learning process is simpler than ever before.

import torch

import flash
from flash.core.data.utils import download_data
from flash.image import ImageClassificationData, ImageClassifier
from flash.image.classification.integrations.baal import ActiveLearningDataModule, ActiveLearningLoop

# 1. Create the DataModule
download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip", "./data")

# Implement the research use-case where we mask labels from labelled dataset.
datamodule = ActiveLearningDataModule(
    ImageClassificationData.from_folders(train_folder="data/hymenoptera_data/train/", batch_size=2),
    initial_num_labels=10,
)

# 2. Build the task
head = torch.nn.Sequential(
    torch.nn.Dropout(p=0.1),
    torch.nn.Linear(512, datamodule.num_classes),
)
model = ImageClassifier(backbone="resnet18", head=head, num_classes=datamodule.num_classes)

# 3.1 Create the trainer
trainer = flash.Trainer(max_epochs=3)

# 3.2 Create the active learning loop and connect it to the trainer
active_learning_loop = ActiveLearningLoop(label_epoch_frequency=1)
active_learning_loop.connect(trainer.fit_loop)
trainer.fit_loop = active_learning_loop

# 3.3 Finetune
trainer.finetune(model, datamodule=datamodule, strategy="freeze")

# 4. Predict what's on a few images! ants or bees?
datamodule = ImageClassificationData.from_files(
    predict_files=["data/hymenoptera_data/val/bees/65038344_52a45d090d.jpg"],
    batch_size=1,
)
predictions = trainer.predict(model, datamodule=datamodule, output="probabilities")
print(predictions)

# 5. Save the model!
trainer.save_checkpoint("image_classification_model.pt")

FiftyOne

We have collaborated with the team at Voxel51 to integrate their tool, FiftyOne, into Lightning Flash.

FiftyOne is an open-source tool for building high-quality datasets and computer vision models. The FiftyOne API and App enable you to visualize datasets and interpret models faster and more effectively.

This integration allows you to view predictions generated by your tasks in the FiftyOne App, as well as easily incorporate FiftyOne Datasets into your tasks. All image and video tasks are supported!

Installation

In order to utilize this integration, you will need to install FiftyOne:

pip install fiftyone

Visualizing Flash predictions

This section shows you how to augment your existing Lightning Flash workflows with a couple of lines of code that let you visualize predictions in FiftyOne. You can visualize predictions for classification, object detection, and semantic segmentation tasks. Doing so is as easy as updating your model to use one of the FiftyOne outputs, such as FiftyOneLabelsOutput.

The visualize() function then lets you visualize your predictions in the FiftyOne App. This function accepts a list of dictionaries containing FiftyOne Label objects and filepaths, which is exactly what the FiftyOne outputs return when the return_filepath=True option is specified.

import torch

import flash
from flash.core.data.utils import download_data
from flash.core.integrations.fiftyone import visualize
from flash.image import ImageClassificationData, ImageClassifier

# 1 Download data
download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip")

# 2 Load data
datamodule = ImageClassificationData.from_folders(
    train_folder="data/hymenoptera_data/train/",
    val_folder="data/hymenoptera_data/val/",
    test_folder="data/hymenoptera_data/test/",
    predict_folder="data/hymenoptera_data/predict/",
    batch_size=16,
)

# 3 Fine tune a model
model = ImageClassifier(
    backbone="resnet18",
    labels=datamodule.labels,
)
trainer = flash.Trainer(
    max_epochs=1,
    gpus=torch.cuda.device_count(),
    fast_dev_run=True,
)
trainer.finetune(
    model,
    datamodule=datamodule,
    strategy=("freeze_unfreeze", 1),
)
trainer.save_checkpoint("image_classification_model.pt")

# 4 Predict from checkpoint
model = ImageClassifier.load_from_checkpoint("image_classification_model.pt")
predictions = trainer.predict(model, datamodule=datamodule, output="fiftyone")  # output FiftyOne format

# 5 Visualize predictions in FiftyOne App
# Optional: pass `wait=True` to block execution until App is closed
session = visualize(predictions, wait=True)

The visualize() function can be used in all of the following environments:

  • Local Python shell: The App will launch in a new tab in your default web browser

  • Remote Python shell: Pass the remote=True option and then follow the instructions printed to your remote shell to open the App in your browser on your local machine

  • Jupyter notebook: The App will launch in the output of your current cell

  • Google Colab: The App will launch in the output of your current cell

  • Python script: Pass the wait=True option to block execution of your script until the App is closed

See this page for more information about using the FiftyOne App in different environments.

Using FiftyOne datasets

The above workflow is great for visualizing model predictions. However, if you store your data in a FiftyOne Dataset initially, then you can also visualize ground truth annotations. This allows you to perform more complex analysis with views into your data and evaluation of your model results.

The from_fiftyone() method allows you to load your FiftyOne datasets directly into a DataModule to be used for training, testing, or inference.

from itertools import chain

import fiftyone as fo
import torch

import flash
from flash.core.classification import FiftyOneLabelsOutput
from flash.core.data.utils import download_data
from flash.image import ImageClassificationData, ImageClassifier

# 1 Download data
download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip")

# 2 Load data into FiftyOne
train_dataset = fo.Dataset.from_dir(
    dataset_dir="data/hymenoptera_data/train/",
    dataset_type=fo.types.ImageClassificationDirectoryTree,
)
val_dataset = fo.Dataset.from_dir(
    dataset_dir="data/hymenoptera_data/val/",
    dataset_type=fo.types.ImageClassificationDirectoryTree,
)
test_dataset = fo.Dataset.from_dir(
    dataset_dir="data/hymenoptera_data/test/",
    dataset_type=fo.types.ImageClassificationDirectoryTree,
)

# 3 Load data into Flash
datamodule = ImageClassificationData.from_fiftyone(
    train_dataset=train_dataset,
    val_dataset=val_dataset,
    test_dataset=test_dataset,
    batch_size=4,
)

# 4 Fine tune model
model = ImageClassifier(
    backbone="resnet18",
    labels=datamodule.labels,
)
trainer = flash.Trainer(
    max_epochs=1,
    gpus=torch.cuda.device_count(),
    limit_train_batches=1,
    limit_val_batches=1,
)
trainer.finetune(
    model,
    datamodule=datamodule,
    strategy=("freeze_unfreeze", 1),
)
trainer.save_checkpoint("image_classification_model.pt")

# 5 Predict from checkpoint on data with ground truth
model = ImageClassifier.load_from_checkpoint("image_classification_model.pt")
datamodule = ImageClassificationData.from_fiftyone(predict_dataset=test_dataset, batch_size=4)
predictions = trainer.predict(
    model, datamodule=datamodule, output=FiftyOneLabelsOutput(model.labels, return_filepath=False)
)  # output FiftyOne format
predictions = list(chain.from_iterable(predictions))

# 6 Add predictions to dataset
test_dataset.set_values("predictions", predictions)

# 7 Evaluate your model
results = test_dataset.evaluate_classifications("predictions", gt_field="ground_truth", eval_key="eval")
results.print_report()
plot = results.plot_confusion_matrix()
plot.show()

# 8 Visualize results in the App
session = fo.launch_app(test_dataset)

# Optional: block execution until App is closed
session.wait()

Visualizing embeddings

FiftyOne provides methods for dimensionality reduction and interactive plotting. When combined with embedding tasks in Flash, you can accomplish powerful workflows like clustering, similarity search, pre-annotation, and more in only a few lines of code.

import fiftyone as fo
import fiftyone.brain as fob
import numpy as np

from flash.core.data.utils import download_data
from flash.image import ImageEmbedder

# 1 Download data
download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip")

# 2 Load data into FiftyOne
dataset = fo.Dataset.from_dir(
    "data/hymenoptera_data/test/",
    fo.types.ImageClassificationDirectoryTree,
)

# 3 Load model
embedder = ImageEmbedder(backbone="resnet101")

# 4 Generate embeddings
filepaths = dataset.values("filepath")
embeddings = np.stack(embedder.predict(filepaths))

# 5 Visualize in FiftyOne App
results = fob.compute_visualization(dataset, embeddings=embeddings)
session = fo.launch_app(dataset)
plot = results.visualize(labels="ground_truth.label")
plot.show()

# Optional: block execution until App is closed
session.wait()

IceVision

IceVision from airctic is an awesome computer vision framework which offers a curated collection of hundreds of high-quality pre-trained models for object detection, keypoint detection, and instance segmentation. In Flash, we’ve integrated the IceVision framework to provide data loading, augmentation, backbones, and heads. We use IceVision components in our object detection, instance segmentation, and keypoint detection tasks. Take a look at their documentation and star IceVision on GitHub to spread the open source love!

IceData

The IceData library is a community driven dataset hub for IceVision. All of the datasets in IceData can be used out of the box with Flash using our .from_folders methods and the parser argument, as sketched below. Take a look at our Keypoint Detection page for a complete example.
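
As a rough, hedged illustration only (the icedata helpers and the exact keyword arguments shown here are assumptions based on the description above, not content taken from this document), loading an IceData dataset might look something like this:

import icedata  # community dataset hub for IceVision

from flash.image import KeypointDetectionData

# Download the BIWI keypoints dataset via icedata (assumed helper)
data_dir = icedata.biwi.load_data()

# Pass the matching icedata parser through the `parser` argument described above (assumed usage)
datamodule = KeypointDetectionData.from_folders(
    train_folder=data_dir,
    val_split=0.1,
    parser=icedata.biwi.parser,
    batch_size=4,
)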

Albumentations with IceVision and Flash

IceVision provides two utilities for using the albumentations library with their models:

  • the Adapter helper class for adapting any albumentations transform to work with IceVision records,

  • the aug_tfms utility function that returns a standard augmentation recipe to get the most out of your model.

In Flash, we use the aug_tfms as default transforms for the: object detection, instance segmentation, and keypoint detection tasks. You can also provide custom transforms from albumentations using the IceVisionTransformAdapter (which relies on the IceVision Adapter underneath). Here’s an example:

import albumentations as A

from flash.core.integrations.icevision.transforms import IceVisionTransformAdapter
from flash.image import ObjectDetectionData

train_transform = {
    "per_sample_transform": IceVisionTransformAdapter([A.HorizontalFlip(), A.Normalize()]),
}

datamodule = ObjectDetectionData.from_coco(
    ...,
    train_transform=train_transform,
)

Learn2Learn

Learn2Learn is a software library for meta-learning research by Sébastien M. R. Arnold et al. (Aug 2020).


What is Meta-Learning and why should you care?

Humans can distinguish between new objects with little or no training data. However, machine learning models often require thousands, millions, or even billions of annotated data samples to achieve good performance while extrapolating their learned knowledge to unseen objects.

A machine learning model which could learn, or learn to learn, from only a few new samples (K-shot learning) would have tremendous applications once deployed in production. In an extreme case, a model performing 1-shot or 0-shot learning could be the source of a new kind of AI application.

Meta-Learning is a sub-field of AI dedicated to the study of few-shot learning algorithms. This is often characterized as teaching deep learning models to learn with only a few labeled examples. The goal is to repeatedly learn from K-shot examples during training that match the structure of the final K-shot examples used in production. It is important to note that the K-shot examples seen in production are very likely to be completely out-of-distribution, containing new objects.

How does Meta-Learning work?

In meta-learning, the model is trained over multiple meta tasks. A meta task is the smallest unit of data and represents the data available to the model once in its deployment environment. By doing so, we can optimise the model and achieve better results.


For image classification, a meta task is comprised of shot + query elements for each class. The shot samples are used to adapt the parameters and the query samples are used to update the original model weights. The classes used for validation and testing shouldn’t be present within the training dataset, as the goal is to optimise model performance on out-of-distribution (OOD) data with little labeled data.

When training the model with the meta-learning algorithm, the model will average its gradients over meta_batch_size meta tasks before performing an optimizer step. Traditionally, a meta epoch is composed of multiple meta batches.

Use Meta-Learning with Flash

With its integration within Flash, Meta Learning has never been simpler. Flash takes care of all the hard work: the tasks sampling, meta optimizer update, distributed training, etc…

Note

The user is required to provide a training dataset and a testing dataset with no overlapping classes. Flash doesn’t support this feature out-of-the-box.

Once done, the user is left to tune the hyper-parameters associated with the meta-learning algorithm.

Here is an example using the miniImageNet dataset, which contains 100 classes divided into 64 training, 16 validation, and 20 test classes.

# adapted from https://github.com/learnables/learn2learn/blob/master/examples/vision/protonet_miniimagenet.py#L154

import warnings

import kornia.augmentation as Ka
import kornia.geometry as Kg
import learn2learn as l2l
import torch
import torchvision
from torch import nn

import flash
from flash.core.data.io.input import DataKeys
from flash.core.data.transforms import ApplyToKeys, kornia_collate
from flash.image import ImageClassificationData, ImageClassifier

warnings.simplefilter("ignore")

# download MiniImagenet
train_dataset = l2l.vision.datasets.MiniImagenet(root="data", mode="train", download=True)
val_dataset = l2l.vision.datasets.MiniImagenet(root="data", mode="validation", download=True)

train_transform = {
    "per_sample_transform": nn.Sequential(
        ApplyToKeys(
            DataKeys.INPUT,
            nn.Sequential(
                torchvision.transforms.ToTensor(),
                Kg.Resize((196, 196)),
                # SPATIAL
                Ka.RandomHorizontalFlip(p=0.25),
                Ka.RandomRotation(degrees=90.0, p=0.25),
                Ka.RandomAffine(degrees=1 * 5.0, shear=1 / 5, translate=1 / 20, p=0.25),
                Ka.RandomPerspective(distortion_scale=1 / 25, p=0.25),
                # PIXEL-LEVEL
                Ka.ColorJitter(brightness=1 / 30, p=0.25),  # brightness
                Ka.ColorJitter(saturation=1 / 30, p=0.25),  # saturation
                Ka.ColorJitter(contrast=1 / 30, p=0.25),  # contrast
                Ka.ColorJitter(hue=1 / 30, p=0.25),  # hue
                Ka.RandomMotionBlur(kernel_size=2 * (4 // 3) + 1, angle=1, direction=1.0, p=0.25),
                Ka.RandomErasing(scale=(1 / 100, 1 / 50), ratio=(1 / 20, 1), p=0.25),
            ),
        ),
        ApplyToKeys(DataKeys.TARGET, torch.as_tensor),
    ),
    "collate": kornia_collate,
    "per_batch_transform_on_device": ApplyToKeys(
        DataKeys.INPUT,
        Ka.RandomHorizontalFlip(p=0.25),
    ),
}

# construct datamodule
datamodule = ImageClassificationData.from_tensors(
    train_data=train_dataset.x,
    train_targets=torch.from_numpy(train_dataset.y.astype(int)),
    val_data=val_dataset.x,
    val_targets=torch.from_numpy(val_dataset.y.astype(int)),
    train_transform=train_transform,
)

model = ImageClassifier(
    backbone="resnet18",
    training_strategy="prototypicalnetworks",
    training_strategy_kwargs={
        "epoch_length": 10 * 16,
        "meta_batch_size": 4,
        "num_tasks": 200,
        "test_num_tasks": 2000,
        "ways": datamodule.num_classes,
        "shots": 1,
        "test_ways": 5,
        "test_shots": 1,
        "test_queries": 15,
    },
    optimizer=torch.optim.Adam,
    learning_rate=0.001,
)

trainer = flash.Trainer(
    max_epochs=200,
    gpus=2,
    accelerator="ddp_sharded",
    precision=16,
)
trainer.finetune(model, datamodule=datamodule, strategy="no_freeze")

You can read their paper Learn2Learn: A Library for Meta-Learning Research.

And don’t forget to cite the Learn2Learn repository in your academic publications. You can find their BibTeX in their repository.

PyTorch Forecasting

PyTorch Forecasting provides the models and data loading for the Tabular Forecasting task in Flash. As with all of our tasks, you won’t typically interact with the components from PyTorch Forecasting directly. However, PyTorch Forecasting provides some built-in plotting and analysis methods, different for each model, which cannot be used directly with the TabularForecaster. Instead, you can access the underlying PyTorch Forecasting model object using the pytorch_forecasting_model attribute. In addition, we provide the convert_predictions() utility to convert predictions from the Flash format into the format PyTorch Forecasting expects. With these, you can train your model and perform inference using Flash but still make use of the plotting and analysis tools built into PyTorch Forecasting.

Here’s an example, plotting the predictions and interpretation analysis from the NBeats model trained in the Tabular Forecasting documentation:

import torch

import flash
from flash.core.integrations.pytorch_forecasting import convert_predictions
from flash.core.utilities.imports import example_requires
from flash.tabular.forecasting import TabularForecaster, TabularForecastingData

example_requires(["tabular", "matplotlib"])

import matplotlib.pyplot as plt  # noqa: E402
import pandas as pd  # noqa: E402
from pytorch_forecasting.data import NaNLabelEncoder  # noqa: E402
from pytorch_forecasting.data.examples import generate_ar_data  # noqa: E402

# Example based on this tutorial: https://pytorch-forecasting.readthedocs.io/en/latest/tutorials/ar.html
# 1. Create the DataModule
data = generate_ar_data(seasonality=10.0, timesteps=400, n_series=100, seed=42)
data["date"] = pd.Timestamp("2020-01-01") + pd.to_timedelta(data.time_idx, "D")

max_prediction_length = 20

training_cutoff = data["time_idx"].max() - max_prediction_length

datamodule = TabularForecastingData.from_data_frame(
    time_idx="time_idx",
    target="value",
    categorical_encoders={"series": NaNLabelEncoder().fit(data.series)},
    group_ids=["series"],
    # only unknown variable is "value" - and N-Beats can also not take any additional variables
    time_varying_unknown_reals=["value"],
    max_encoder_length=60,
    max_prediction_length=max_prediction_length,
    train_data_frame=data[lambda x: x.time_idx <= training_cutoff],
    val_data_frame=data,
    batch_size=32,
)

# 2. Build the task
model = TabularForecaster(
    datamodule.parameters,
    backbone="n_beats",
    backbone_kwargs={"widths": [32, 512], "backcast_loss_ratio": 0.1},
)

# 3. Create the trainer and train the model
trainer = flash.Trainer(max_epochs=1, gpus=torch.cuda.device_count(), gradient_clip_val=0.01)
trainer.fit(model, datamodule=datamodule)

# 4. Generate predictions
datamodule = TabularForecastingData.from_data_frame(predict_data_frame=data, parameters=datamodule.parameters)
predictions = trainer.predict(model, datamodule=datamodule)
print(predictions)

# Plot with PyTorch Forecasting!
predictions, inputs = convert_predictions(predictions)

fig, axs = plt.subplots(2, 3, sharex="col")

for idx in range(3):
    model.pytorch_forecasting_model.plot_interpretation(inputs, predictions, idx=idx, ax=[axs[0][idx], axs[1][idx]])

plt.show()

Here’s the visualization:

https://pl-flash-data.s3.amazonaws.com/assets/pytorch_forecasting_plot.png

VISSL

VISSL is a library from Facebook AI Research for state-of-the-art self-supervised learning. We integrate VISSL models and algorithms into Flash with the image embedder task.

Using VISSL with Flash

The ImageEmbedder task in Flash can be configured with different backbones, projection heads, image transforms and loss functions so that you can train your feature extractor using a SOTA SSL method.

from flash.image import ImageEmbedder

embedder = ImageEmbedder(
    backbone="resnet",
    training_strategy="barlow_twins",
    head="simclr_head",
    pretraining_transform="barlow_twins_transform",
    training_strategy_kwargs={"latent_embedding_dim": 256, "dims": [2048, 2048, 256]},
    pretraining_transform_kwargs={"size_crops": [196]},
)

The user can pass arguments to the training strategy, image transforms, and backbones using the optional dictionary arguments the ImageEmbedder task accepts. Each training strategy bundles together the projection head, the loss function, and the VISSL hooks for a particular algorithm, and the arguments to customize these can be passed via training_strategy_kwargs. As an example, in the above code block, latent_embedding_dim is an argument to the Barlow Twins loss function from VISSL, while the dims argument configures the projection head to output 256-dimensional vectors for the loss function.
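
Once configured, the embedder can be trained like any other Flash task. Here is a minimal sketch, assuming a standard image DataModule; CIFAR10 from torchvision is used purely as an easily downloadable placeholder dataset:

import torch
from torchvision.datasets import CIFAR10

import flash
from flash.image import ImageClassificationData

# Any image DataModule works for self-supervised pretraining; CIFAR10 is only a placeholder.
datamodule = ImageClassificationData.from_datasets(
    train_dataset=CIFAR10(".", download=True),
    batch_size=16,
)

trainer = flash.Trainer(max_epochs=1, gpus=torch.cuda.device_count())
trainer.fit(embedder, datamodule=datamodule)  # `embedder` is the ImageEmbedder from the snippet above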

If you find the VISSL integration in Flash useful for your research, please don’t forget to cite us and the VISSL library. You can find our BibTeX in the Flash repository and VISSL’s BibTeX on their GitHub page.

flash

Input

DataModule

A basic DataModule class for all Flash tasks.

FlashCallback

FlashCallback is an extension of pytorch_lightning.callbacks.Callback.

OutputTransform

The OutputTransform encapsulates all the data processing logic that should run after the model.

Output

An Output encapsulates a single transform() method which is used to convert the model output into the desired output format when predicting.

InputTransform

Task

A general Task.

Trainer

flash.core

flash.core.adapter

Adapter

The Adapter is a lightweight interface that can be used to encapsulate the logic from a particular provider within a Task.

AdapterTask

The AdapterTask is a Task which wraps an Adapter and forwards all of the hooks.

flash.core.classification

ClassesOutput

A Output which applies an argmax to the model outputs (either logits or probabilities) and converts to a list.

ClassificationOutput

A base class for classification outputs.

ClassificationTask

FiftyOneLabelsOutput

A Output which converts the model outputs to FiftyOne classification format.

LabelsOutput

A Output which converts the model outputs (either logits or probabilities) to the label of the argmax classification.

LogitsOutput

A Output which simply converts the model outputs (assumed to be logits) to a list.

PredsClassificationOutput

A ClassificationOutput which gets the PREDS from the sample.

ProbabilitiesOutput

A Output which applies a softmax to the model outputs (assumed to be logits) and converts to a list.

flash.core.finetuning

FlashBaseFinetuning

FlashBaseFinetuning can be used to create a custom Flash Finetuning Callback.

FineTuningHooks

Hooks to be used in Task and FlashBaseTuning.

Freeze

FreezeUnfreeze

NoFreeze

UnfreezeMilestones

flash.core.integrations.fiftyone

visualize

Visualizes predictions from a model with a FiftyOne Output in the FiftyOne App.

flash.core.integrations.icevision

IceVisionTransformAdapter

transform type: List[Callable]

flash.core.integrations.pytorch_forecasting

convert_predictions

Return type: Tuple[Dict[str, Any], List]

flash.core.model

BenchmarkConvergenceCI

Specialized callback only used during testing. Keeps track of metrics during training.

CheckDependenciesMeta

ModuleWrapperBase

The ModuleWrapperBase is a base for classes which wrap a LightningModule or an instance of ModuleWrapperBase.

DatasetProcessor

The DatasetProcessor mixin provides hooks for classes which need custom logic for producing the data loaders for each running stage given the corresponding dataset.

Task

A general Task.

flash.core.registry

FlashRegistry

This class is used to register functions or functools.partial objects to a registry.

ExternalRegistry

The ExternalRegistry is a FlashRegistry that can point to an external provider via a getter function.

ConcatRegistry

The ConcatRegistry can be used to concatenate multiple registries of different types together.

flash.core.optimizers

LARS

Extends SGD in PyTorch with LARS scaling from the paper Large batch training of Convolutional Networks.

LAMB

Extends Adam in PyTorch to incorporate the LAMB algorithm from the paper: Large batch optimization for deep learning: Training BERT in 76 minutes.

LinearWarmupCosineAnnealingLR

Sets the learning rate of each parameter group to follow a linear warmup schedule between warmup_start_lr and base_lr followed by a cosine annealing schedule between base_lr and eta_min.

Utilities

from_argparse_args

Modified version of pytorch_lightning.utilities.argparse.from_argparse_args() which populates valid_kwargs from pytorch_lightning.Trainer.

get_callable_name

Return type: str

get_callable_dict

Return type: Union[Dict, Mapping]

flash.core.data

flash.core.data.base_viz

BaseVisualization

This base class is used to create visualization tools on top of the InputTransform hooks.

flash.core.data.batch

default_uncollate

This function is used to uncollate a batch into samples.

flash.core.data.callback

BaseDataFetcher

This class is used to profile InputTransform hook outputs.

ControlFlow

FlashCallback

FlashCallback is an extension of pytorch_lightning.callbacks.Callback.

flash.core.data.data_module

DataModule

A basic DataModule class for all Flash tasks.

flash.core.data.utilities.classification

TargetFormatter

A TargetFormatter is used to convert targets of a given type to a standard format required by the loss function.

SingleNumericTargetFormatter

A TargetFormatter for targets that contain a single numeric value (the class index).

SingleLabelTargetFormatter

A TargetFormatter for targets that contain a single string label.

SingleBinaryTargetFormatter

A TargetFormatter for targets that are one-hot encoded binaries.

MultiNumericTargetFormatter

A TargetFormatter for targets that contain multiple numeric values (the class indices).

MultiLabelTargetFormatter

A TargetFormatter for targets that contain multiple string labels in a list.

CommaDelimitedMultiLabelTargetFormatter

A TargetFormatter for targets that contain a string with multiple comma-delimited labels.

SpaceDelimitedTargetFormatter

A TargetFormatter for targets that contain a string with multiple space-delimited labels.

MultiBinaryTargetFormatter

A TargetFormatter for targets that are multi-hot binary.

get_target_formatter

Get the TargetFormatter object to use for the given targets.

flash.core.data.properties

Properties

flash.core.data.splits

SplitDataset

SplitDataset is used to create dataset subsets using indices.

flash.core.data.transforms

ApplyToKeys

The ApplyToKeys class is an nn.Sequential which applies the given transforms to the given keys from the input.

KorniaParallelTransforms

The KorniaParallelTransforms class is an nn.Sequential which will apply the given transforms to each input (to .forward) in parallel, whilst sharing the random state (._params).

kornia_collate

Kornia transforms add a batch dimension which needs to be removed.

flash.core.data.utils

FuncModule

This class is used to wrap a callable within an nn.Module and apply the wrapped function in __call__.

convert_to_modules

download_data

Download file with progressbar.

flash.core.data.io.input

Input

DataKeys

The DataKeys enum contains the keys that are used by built-in data sources to refer to inputs and targets.

InputFormat

The InputFormat enum contains the data source names used by all of the default from_* methods in DataModule.

flash.core.data.io.classification_input

ClassificationInputMixin

The ClassificationInputMixin class provides utility methods for handling classification targets.

flash.core.data.io.input_transform

InputTransform

flash.core.data.io.output

Output

An Output encapsulates a single transform() method which is used to convert the model output into the desired output format when predicting.

flash.core.data.io.output_transform

OutputTransform

The OutputTransform encapsulates all the data processing logic that should run after the model.

flash.core.serve

ModelComponent

alias of object

Composition

Create a composition which defines the computations / endpoints to create & run.

Endpoint

An endpoint maps a route and request/response payload to components.

Servable

ModuleWrapperBase around a model object to enable serving at scale.

expose

Expose a function/method via a web API for serving model inference.

flash.image

Classification

ImageClassifier

The ImageClassifier is a Task for classifying images.

ImageClassificationFiftyOneInput

ImageClassificationData

The ImageClassificationData class is a DataModule with a set of classmethods for loading data for image classification.

ImageClassificationInputTransform

classification.data.MatplotlibVisualization

Process and show the image batch and its associated label using matplotlib.

Object Detection

ObjectDetector

The ObjectDetector is a Task for detecting objects in images.

ObjectDetectionData

The ObjectDetectionData class is a DataModule with a set of classmethods for loading data for object detection.

detection.input.FiftyOneParser

detection.input.ObjectDetectionFiftyOneInput

detection.output.FiftyOneDetectionLabelsOutput

A Output which converts model outputs to FiftyOne detection format.

Keypoint Detection

KeypointDetector

The KeypointDetector is a Task for detecting keypoints in images.

KeypointDetectionData

The KeypointDetectionData class is a DataModule with a set of classmethods for loading data for keypoint detection.

Instance Segmentation

InstanceSegmentation

The InstanceSegmentation is a Task for segmenting object instances in images.

InstanceSegmentationData

Embedding

ImageEmbedder

The ImageEmbedder is a Task for obtaining feature vectors (embeddings) from images.

Segmentation

SemanticSegmentation

SemanticSegmentation is a Task for semantic segmentation of images.

SemanticSegmentationData

The SemanticSegmentationData class is a DataModule with a set of classmethods for loading data for semantic segmentation.

SemanticSegmentationInputTransform

segmentation.data.SegmentationMatplotlibVisualization

Process and show the image batch and its associated label using matplotlib.

segmentation.input.SemanticSegmentationInput

segmentation.input.SemanticSegmentationFilesInput

segmentation.input.SemanticSegmentationFolderInput

segmentation.input.SemanticSegmentationNumpyInput

segmentation.input.SemanticSegmentationTensorInput

segmentation.input.SemanticSegmentationFiftyOneInput

segmentation.input.SemanticSegmentationDeserializer

segmentation.model.SemanticSegmentationOutputTransform

segmentation.output.FiftyOneSegmentationLabelsOutput

A Output which converts the model outputs to FiftyOne segmentation format.

segmentation.output.SegmentationLabelsOutput

A Output which converts the model outputs to the label of the argmax classification per pixel in the image for semantic segmentation tasks.

segmentation.input_transform.prepare_target

Convert the target mask to long and remove the channel dimension.

Style Transfer

StyleTransfer

StyleTransfer is a Task for transferring the style from one image onto another.

StyleTransferData

The StyleTransferData class is a DataModule with a set of classmethods for loading data for image style transfer.

StyleTransferInputTransform

raise_not_supported

Return type: NoReturn

flash.image.data

ImageDeserializer

ImageNumpyInput

ImageTensorInput

flash.audio

Classification

AudioClassificationData

The AudioClassificationData class is a DataModule with a set of classmethods for loading data for audio classification.

AudioClassificationInputTransform

Speech Recognition

SpeechRecognitionData

The SpeechRecognitionData class is a DataModule with a set of classmethods for loading data for speech recognition.

SpeechRecognition

The SpeechRecognition task is a Task for converting speech to text.

speech_recognition.input.SpeechRecognitionCSVInput

speech_recognition.input.SpeechRecognitionJSONInput

speech_recognition.input.BaseSpeechRecognition

speech_recognition.input.SpeechRecognitionFileInput

speech_recognition.input.SpeechRecognitionPathsInput

speech_recognition.input.SpeechRecognitionDatasetInput

speech_recognition.input.SpeechRecognitionDeserializer

speech_recognition.output_transform.SpeechRecognitionOutputTransform

flash.pointcloud

Segmentation

PointCloudSegmentation

The PointCloudSegmentation is a ClassificationTask that segments pointcloud data.

PointCloudSegmentationData

segmentation.input.PointCloudSegmentationFoldersInput

segmentation.input.PointCloudSegmentationDatasetInput

Object Detection

PointCloudObjectDetector

The PointCloudObjectDetector is a ClassificationTask that detects objects in pointcloud data.

PointCloudObjectDetectorData

detection.input.PointCloudObjectDetectorDatasetInput

detection.open3d_ml.input.PointCloudObjectDetectorFoldersInput

flash.tabular

Classification

TabularClassifier

The TabularClassifier is a Task for classifying tabular data.

TabularClassificationData

The TabularClassificationData class is a DataModule with a set of classmethods for loading data for tabular classification.

classification.input.TabularClassificationDataFrameInput

classification.input.TabularClassificationCSVInput

Regression

TabularRegressor

The TabularRegressor is a Task for regression on tabular data.

TabularRegressionData

The TabularRegressionData class is a DataModule with a set of classmethods for loading data for tabular regression.

regression.input.TabularRegressionDataFrameInput

regression.input.TabularRegressionCSVInput

Forecasting

TabularForecaster

TabularForecastingData

The TabularForecastingData class is a DataModule with a set of classmethods for loading data for tabular forecasting.

forecasting.input.TabularForecastingDataFrameInput

flash.tabular.data

TabularData

input.TabularDataFrameInput

input.TabularDeserializer

flash.text

Classification

TextClassifier

The TextClassifier is a Task for classifying text.

TextClassificationData

The TextClassificationData class is a DataModule with a set of classmethods for loading data for text classification.

classification.input.TextClassificationInput

classification.input.TextClassificationCSVInput

classification.input.TextClassificationJSONInput

classification.input.TextClassificationDataFrameInput

classification.input.TextClassificationParquetInput

classification.input.TextClassificationListInput

Embedding

TextEmbedder

The TextEmbedder is a Task for generating sentence embeddings, with support for training and validation.

Question Answering

QuestionAnsweringTask

The QuestionAnsweringTask is a Task for extractive question answering.

QuestionAnsweringData

The QuestionAnsweringData class is a DataModule with a set of classmethods for loading data for extractive question answering.

question_answering.input.QuestionAnsweringInputBase

question_answering.input.QuestionAnsweringCSVInput

question_answering.input.QuestionAnsweringJSONInput

question_answering.input.QuestionAnsweringSQuADInput

question_answering.input.QuestionAnsweringDictionaryInput

question_answering.output_transform.QuestionAnsweringOutputTransform

Summarization

SummarizationTask

The SummarizationTask is a Task for Seq2Seq text summarization.

SummarizationData

The SummarizationData class is a DataModule with a set of classmethods for loading data for text summarization.

Translation

TranslationTask

The TranslationTask is a Task for Seq2Seq text translation.

TranslationData

The TranslationData class is a DataModule with a set of classmethods for loading data for text translation.

General Seq2Seq

Seq2SeqTask

General Task for Sequence2Sequence.

seq2seq.core.input.Seq2SeqInputBase

seq2seq.core.input.Seq2SeqCSVInput

seq2seq.core.input.Seq2SeqJSONInput

seq2seq.core.input.Seq2SeqListInput

flash.text.input

input.TextDeserializer

flash.video

Classification

VideoClassifier

Task that classifies videos.

VideoClassificationData

The VideoClassificationData class is a DataModule with a set of classmethods for loading data for video classification.

classification.input.VideoClassificationInput

classification.input.VideoClassificationFiftyOneInput

classification.input.VideoClassificationFoldersInput

classification.input.VideoClassificationFilesInput

classification.input.VideoClassificationDataFrameInput

classification.input.VideoClassificationCSVInput

classification.input.VideoClassificationPathsPredictInput

classification.input.VideoClassificationDataFramePredictInput

classification.input.VideoClassificationCSVPredictInput

classification.input_transform.VideoClassificationInputTransform


flash.graph

Classification

GraphClassifier

The GraphClassifier is a Task for classifying graphs.

GraphClassificationData

The GraphClassificationData class is a DataModule with a set of classmethods for loading data for graph classification.

classification.input.GraphClassificationDatasetInput

classification.input_transform.GraphClassificationInputTransform

Embedding

GraphEmbedder

The GraphEmbedder is a Task for obtaining feature vectors (embeddings) from graphs.

Introduction / Set-up

Welcome

Before you begin, we’d like to express our gratitude to you for wanting to add a task to Flash. With Flash our aim is to create a great user experience, enabling awesome advanced applications with just a few lines of code. We’re really pleased with what we’ve achieved with Flash and we hope you will be too. Now let’s dive in!

Set-up

The Task template is designed to guide you through contributing a task to Flash. It contains the code, tests, and examples for a task that performs classification with a multi-layer perceptron, intended for use with the classic data sets from scikit-learn. The Flash tasks are organized in folders by data-type (image, text, video, etc.), with sub-folders for different task types (classification, regression, etc.).

Copy the files in flash/template/classification to a new sub-directory under the relevant data-type. If a data-type folder already exists for your task, then a task type sub-folder should be added containing the template files. If a data-type folder doesn’t exist, then you will need to add that too. You should also copy the files from tests/template/classification to the corresponding data-type, task type folder in tests. For example, if you were adding an image classification task, you would do:

mkdir flash/image/classification
cp flash/template/classification/* flash/image/classification/
mkdir tests/image/classification
cp tests/template/classification/* tests/image/classification/

Tutorials

The tutorials in this section will walk you through all of the components you need to implement (or adapt from the template) for your custom task.

  • The Data: our first tutorial goes over the best practices for implementing everything you need to connect data to your task

  • The Backbones: the second tutorial shows you how to create an extensible backbone registry for your task

  • The Task: now that we have the data and the models, in this tutorial we create our custom task

  • Optional Extras: this tutorial covers some optional extras you can add if needed for your particular task

  • The Example: this tutorial guides you through creating some simple examples showing your task in action

  • The Tests: in this tutorial, we cover best practices for writing some tests for your new task

  • The Docs: in our final tutorial, we provide a template for you to create the docs page for your task

The Data

The first step to contributing a task is to implement the classes we need to load some data. Inside data.py you should implement:

  1. some Input classes (optional)

  2. a InputTransform

  3. a DataModule

  4. a BaseVisualization (optional)

  5. a OutputTransform (optional)

Input

The Input class contains the logic for loading data from different sources such as folders, files, tensors, etc. Every Flash DataModule can be instantiated with from_datasets(). For each additional way you want the user to be able to instantiate your DataModule, you’ll need to create an Input. Each Input has 2 methods:

  • load_data() takes some dataset metadata (e.g. a folder name) as input and produces a sequence or iterable of samples or sample metadata.

  • load_sample() then takes as input a single element from the output of load_data and returns a sample.

By default these methods just return their input, so you don’t need both a load_data() and a load_sample() to create an Input. Where possible, you should override one of our existing Input classes.

Let’s start by implementing a TemplateNumpyClassificationInput, which overrides ClassificationInputMixin. The main Input method that we have to implement is load_data(). ClassificationInputMixin provides utilities for handling targets within Flash, which need to be called from load_data() and load_sample(). In this Input, we’ll also set the num_features attribute so that we can access it later.

Here’s the code for our TemplateNumpyClassificationInput.load_data method:

def load_data(
    self,
    examples: Collection[np.ndarray],
    targets: Optional[Sequence[Any]] = None,
    target_formatter: Optional[TargetFormatter] = None,
) -> Sequence[Dict[str, Any]]:
    """Sets the ``num_features`` attribute and calls ``super().load_data``.

    Args:
        examples: The ``np.ndarray`` (num_examples x num_features).
        targets: Associated targets.
        target_formatter: Optionally provide a ``TargetFormatter`` to control how targets are formatted.

    Returns:
        A sequence of samples / sample metadata.
    """
    if not self.predicting and isinstance(examples, np.ndarray):
        self.num_features = examples.shape[1]
    if targets is not None:
        self.load_target_metadata(targets, target_formatter=target_formatter)
    return to_samples(examples, targets)

and here’s the code for the TemplateNumpyClassificationInput.load_sample method:

def load_sample(self, sample: Dict[str, Any]) -> Any:
    if DataKeys.TARGET in sample:
        sample[DataKeys.TARGET] = self.format_target(sample[DataKeys.TARGET])
    return sample

Note

Later, when we add our DataModule implementation, we’ll make num_features available to the user.

For our template Task, it would be cool if the user could provide a scikit-learn Bunch as the data source. To achieve this, we’ll add a TemplateSKLearnClassificationInput whose load_data expects a Bunch as input. We override our TemplateNumpyClassificationInput so that we can call super with the data and targets extracted from the Bunch. We perform two additional steps here to improve the user experience:

  1. We set the num_classes attribute on the dataset. If num_classes is set, it is automatically made available as a property of the DataModule.

  2. We create and set a ClassificationState. The labels provided here will be shared with the Labels output, so the user doesn’t need to provide them.

Here’s the code for the TemplateSKLearnClassificationInput.load_data method:

def load_data(self, data: Bunch, target_formatter: Optional[TargetFormatter] = None) -> Sequence[Dict[str, Any]]:
    """Gets the ``data`` and ``target`` attributes from the ``Bunch`` and passes them to ``super().load_data``.

    Args:
        data: The scikit-learn data ``Bunch``.
        target_formatter: Optionally provide a ``TargetFormatter`` to control how targets are formatted.

    Returns:
        A sequence of samples / sample metadata.
    """
    return super().load_data(data.data, data.target, target_formatter=target_formatter)

We can customize the behaviour of our load_data() for different stages by prepending train_, val_, test_, or predict_ to the method name. For our TemplateSKLearnClassificationInput, we don’t want to provide any targets to the model when predicting. We can implement predict_load_data like this:

def predict_load_data(self, data: Bunch) -> Sequence[Dict[str, Any]]:
    """Avoid including targets when predicting.

    Args:
        data: The scikit-learn data ``Bunch``.

    Returns:
        A sequence of samples / sample metadata.
    """
    return super().load_data(data.data)

InputTransform

The InputTransform object contains all the data transforms. Internally we inject the InputTransform transforms at several points along the pipeline.

Defining the standard transforms (typically at least a per_sample_transform should be defined) for your InputTransform involves simply overriding the required hook to return a callable transform.

For our TemplateInputTransform, we’ll just configure an input_per_sample_transform and a target_per_sample_transform. Let’s first define a to_tensor transform as a staticmethod:

@staticmethod
def input_to_tensor(input: np.ndarray):
    """Transform which creates a tensor from the given numpy ``ndarray`` and converts it to ``float``"""
    return torch.from_numpy(input).float()

Now in our input_per_sample_transform hook, we return the transform:

def input_per_sample_transform(self) -> Callable:
    return self.input_to_tensor

To convert the targets to a tensor we can simply use torch.as_tensor. Here’s our target_per_sample_transform:

def target_per_sample_transform(self) -> Callable:
    return torch.as_tensor

DataModule

The DataModule is responsible for creating the DataLoader and injecting the transforms for each stage. When the user calls a from_* method (such as from_numpy()), the following steps take place:

  1. The from_* method is called with the name of the Input to use and the inputs to provide to load_data() for each stage.

  2. The InputTransform is created from cls.input_transform_cls (if it wasn’t provided by the user) with any provided transforms.

  3. The Input of the provided name is retrieved from the InputTransform.

  4. A BaseAutoDataset is created from the Input for each stage.

  5. The DataModule is instantiated with the data sets.


To create our TemplateData DataModule, we first need to attach our input transform class like this:

input_transform_cls = TemplateInputTransform

Since we provided a NUMPY Input in the TemplateInputTransform, from_numpy() will now work with our TemplateData.
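
For example, a user could now construct the DataModule directly from arrays (a minimal sketch; the shapes and values are illustrative):

import numpy as np

datamodule = TemplateData.from_numpy(
    train_data=np.random.rand(10, 4),              # 10 examples with 4 features each
    train_targets=np.random.randint(0, 3, (10,)),  # 3 illustrative classes
    batch_size=2,
)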

If you’ve defined a fully custom Input (like our TemplateSKLearnClassificationInput), then you will need to write a from_* method for each. Here’s the from_sklearn method for our TemplateData:

@classmethod
def from_sklearn(
    cls,
    train_bunch: Optional[Bunch] = None,
    val_bunch: Optional[Bunch] = None,
    test_bunch: Optional[Bunch] = None,
    predict_bunch: Optional[Bunch] = None,
    train_transform: INPUT_TRANSFORM_TYPE = TemplateInputTransform,
    val_transform: INPUT_TRANSFORM_TYPE = TemplateInputTransform,
    test_transform: INPUT_TRANSFORM_TYPE = TemplateInputTransform,
    predict_transform: INPUT_TRANSFORM_TYPE = TemplateInputTransform,
    input_cls: Type[Input] = TemplateSKLearnClassificationInput,
    transform_kwargs: Optional[Dict] = None,
    **data_module_kwargs: Any,
) -> "TemplateData":
    """This is our custom ``from_*`` method. It expects scikit-learn ``Bunch`` objects as input and creates the
    ``TemplateData`` with them.

    Args:
        train_bunch: The scikit-learn ``Bunch`` containing the train data.
        val_bunch: The scikit-learn ``Bunch`` containing the validation data.
        test_bunch: The scikit-learn ``Bunch`` containing the test data.
        predict_bunch: The scikit-learn ``Bunch`` containing the predict data.
        train_transform: The dictionary of transforms to use during training which maps
            :class:`~flash.core.data.io.input_transform.InputTransform` hook names to callable transforms.
        val_transform: The dictionary of transforms to use during validation which maps
            :class:`~flash.core.data.io.input_transform.InputTransform` hook names to callable transforms.
        test_transform: The dictionary of transforms to use during testing which maps
            :class:`~flash.core.data.io.input_transform.InputTransform` hook names to callable transforms.
        predict_transform: The dictionary of transforms to use during predicting which maps
            :class:`~flash.core.data.io.input_transform.InputTransform` hook names to callable transforms.

    Returns:
        The constructed data module.
    """
    ds_kw = dict(
        transform_kwargs=transform_kwargs,
        input_transforms_registry=cls.input_transforms_registry,
    )

    train_input = input_cls(RunningStage.TRAINING, train_bunch, transform=train_transform, **ds_kw)
    target_formatter = getattr(train_input, "target_formatter", None)

    return cls(
        train_input,
        input_cls(
            RunningStage.VALIDATING,
            val_bunch,
            transform=val_transform,
            target_formatter=target_formatter,
            **ds_kw,
        ),
        input_cls(
            RunningStage.TESTING,
            test_bunch,
            transform=test_transform,
            target_formatter=target_formatter,
            **ds_kw,
        ),
        input_cls(RunningStage.PREDICTING, predict_bunch, transform=predict_transform, **ds_kw),
        **data_module_kwargs,
    )

The final step is to implement the num_features property for our TemplateData. This is just a convenience for the user that finds the num_features attribute on any of the data sets and returns it. Here’s the code:

@property
def num_features(self) -> Optional[int]:
    """Tries to get the ``num_features`` from each dataset in turn and returns the output."""
    n_fts_train = getattr(self.train_dataset, "num_features", None)
    n_fts_val = getattr(self.val_dataset, "num_features", None)
    n_fts_test = getattr(self.test_dataset, "num_features", None)
    return n_fts_train or n_fts_val or n_fts_test

BaseVisualization

An optional step is to implement a BaseVisualization. The BaseVisualization lets you control how data at various points in the pipeline can be visualized. This is extremely useful for debugging purposes, allowing users to view their data and understand the impact of their transforms.

Note

Don’t worry about implementing it right away, you can always come back and add it later!

Here’s the code for our TemplateVisualization which just prints the data:

class TemplateVisualization(BaseVisualization):
    """The ``TemplateVisualization`` class is a :class:`~flash.core.data.callbacks.BaseVisualization` that just
    prints the data.

    If you want to provide a visualization with your task, you can override these hooks.
    """

    def show_load_sample(self, samples: List[Any], running_stage: RunningStage):
        print(samples)

    def show_per_sample_transform(self, samples: List[Any], running_stage: RunningStage):
        print(samples)

We can configure our custom visualization in the TemplateData using configure_data_fetcher() like this:

@staticmethod
def configure_data_fetcher(*args, **kwargs) -> BaseDataFetcher:
    """We can, *optionally*, provide a data visualization callback using the ``configure_data_fetcher``
    method."""
    return TemplateVisualization(*args, **kwargs)

OutputTransform

OutputTransform contains any transforms that need to be applied after the model. You may want to use it for: converting tokens back into text, applying an inverse normalization to an output image, resizing a generated image back to the size of the input, etc. As an example, here’s the SemanticSegmentationOutputTransform, which resizes the predictions and inputs back to their original size:

class SemanticSegmentationOutputTransform(OutputTransform):
    def per_sample_transform(self, sample: Any) -> Any:
        resize = K.geometry.Resize(sample[DataKeys.METADATA]["size"], interpolation="bilinear")
        sample[DataKeys.PREDS] = resize(sample[DataKeys.PREDS])
        sample[DataKeys.INPUT] = resize(sample[DataKeys.INPUT])
        return super().per_sample_transform(sample)

In your Input or InputTransform, you can add metadata to the batch using the METADATA key. Your OutputTransform can then use this metadata in its transforms. You should use this approach if your postprocessing depends on the state of the input before the InputTransform transforms. For example, if you want to resize the predictions to the original size of the inputs you should add the original image size in the METADATA. Here’s an example from the SemanticSegmentationNumpyInput:

def load_sample(self, sample: Dict[str, Any]) -> Dict[str, Any]:
    sample[DataKeys.INPUT] = torch.from_numpy(sample[DataKeys.INPUT])
    if DataKeys.TARGET in sample:
        sample[DataKeys.TARGET] = torch.from_numpy(sample[DataKeys.TARGET])
    # Record the original spatial size in the METADATA so that the OutputTransform
    # can resize predictions back to it (illustrative addition).
    sample[DataKeys.METADATA] = {"size": sample[DataKeys.INPUT].shape[-2:]}
    return super().load_sample(sample)

The METADATA can now be referenced in your OutputTransform. For example, here’s the code for the per_sample_transform method of the SemanticSegmentationOutputTransform:

def per_sample_transform(self, sample: Any) -> Any:
    resize = K.geometry.Resize(sample[DataKeys.METADATA]["size"], interpolation="bilinear")
    sample[DataKeys.PREDS] = resize(sample[DataKeys.PREDS])
    sample[DataKeys.INPUT] = resize(sample[DataKeys.INPUT])
    return super().per_sample_transform(sample)

Now that you’ve got some data, it’s time to add some backbones for your task!

The Backbones

Now that you’ve got a way of loading data, you should implement some backbones to use with your Task. Create a FlashRegistry to use with your Task in backbones.py.

The registry allows you to register backbones for your task that can be selected by the user. The backbones can come from anywhere as long as you can register a function that loads the backbone. Furthermore, the user can add their own models to the existing backbones, without having to write their own Task!

You can create a registry like this:

TEMPLATE_BACKBONES = FlashRegistry("backbones")

Let’s add a simple MLP backbone to our registry. We need a function that creates the backbone and returns it along with the output size (so that we can create the model head in our Task). You can use any name for the function, although we use load_{model name} by convention. You also need to provide the name and namespace of the backbone. The standard for the namespace is data_type/task_type, so for an image classification task the namespace would be image/classification. Here’s the code:

@TEMPLATE_BACKBONES(name="mlp-128", namespace="template/classification")
def load_mlp_128(num_features, **_):
    """A simple MLP backbone with 128 hidden units."""
    return (
        nn.Sequential(
            nn.Linear(num_features, 128),
            nn.ReLU(True),
            nn.BatchNorm1d(128),
        ),
        128,
    )

Here’s another example with a slightly more complex model:

@TEMPLATE_BACKBONES(name="mlp-128-256", namespace="template/classification")
def load_mlp_128_256(num_features, **_):
    """Two layer MLP backbone with 128 and 256 hidden units respectively."""
    return (
        nn.Sequential(
            nn.Linear(num_features, 128),
            nn.ReLU(True),
            nn.BatchNorm1d(128),
            nn.Linear(128, 256),
            nn.ReLU(True),
            nn.BatchNorm1d(256),
        ),
        256,
    )

Here’s another example, which adds a DINO pretrained model from PyTorch Hub to IMAGE_CLASSIFIER_BACKBONES, from flash/image/classification/backbones/transformers.py:

def dino_vitb16(*_, **__):
    backbone = torch.hub.load("facebookresearch/dino:main", "dino_vitb16")
    return backbone, 768
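
Because a registry is just an object, a user can also register their own backbone on it (or, once it is attached to a Task, via the Task’s backbones attribute) without writing a new Task. Here’s a minimal sketch; the name and layer sizes are illustrative:

from torch import nn

@TEMPLATE_BACKBONES(name="my-mlp-64", namespace="template/classification")
def load_my_mlp_64(num_features, **_):
    """A user-provided MLP backbone with a single hidden layer of 64 units."""
    return nn.Sequential(nn.Linear(num_features, 64), nn.ReLU(True)), 64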

Once you’ve got some data and some backbones, implement your task!

The Task

Once you’ve implemented a Flash DataModule and some backbones, you should implement your Task in model.py. The Task is responsible for: setting up the backbone, performing the forward pass of the model, and calculating the loss and any metrics. Remember that, under the hood, the Flash Task is simply a LightningModule with some helpful defaults.

To build your task, you can start by overriding the base Task or any of the existing Task implementations. For example, in our scikit-learn example, we can just override ClassificationTask which provides good defaults for classification.

You should attach your backbones registry as a class attribute like this:

class TemplateSKLearnClassifier(ClassificationTask):

    backbones: FlashRegistry = TEMPLATE_BACKBONES

Model architecture and hyper-parameters

In the __init__(), you will need to configure defaults for the:

  • loss function

  • optimizer

  • metrics

  • backbone / model

You will also need to create the backbone from the registry and create the model head. Here’s the code:

def __init__(
    self,
    num_features: int,
    num_classes: Optional[int] = None,
    labels: Optional[List[str]] = None,
    backbone: Union[str, Tuple[nn.Module, int]] = "mlp-128",
    backbone_kwargs: Optional[Dict] = None,
    loss_fn: LOSS_FN_TYPE = None,
    optimizer: OPTIMIZER_TYPE = "Adam",
    lr_scheduler: LR_SCHEDULER_TYPE = None,
    metrics: METRICS_TYPE = None,
    learning_rate: Optional[float] = None,
    multi_label: bool = False,
):
    self.save_hyperparameters()

    if labels is not None and num_classes is None:
        num_classes = len(labels)

    super().__init__(
        model=None,
        loss_fn=loss_fn,
        optimizer=optimizer,
        lr_scheduler=lr_scheduler,
        metrics=metrics,
        learning_rate=learning_rate,
        multi_label=multi_label,
        num_classes=num_classes,
        labels=labels,
    )

    if not backbone_kwargs:
        backbone_kwargs = {}

    if isinstance(backbone, tuple):
        self.backbone, out_features = backbone
    else:
        self.backbone, out_features = self.backbones.get(backbone)(num_features=num_features, **backbone_kwargs)

    self.head = nn.Linear(out_features, num_classes)

Note

We call save_hyperparameters() to log the arguments to the __init__ as hyperparameters. Read more here.

Adding the model routines

You should override the {train,val,test,predict}_step methods. The default {train,val,test,predict}_step implementations in Task expect a tuple containing the input (to be passed to the model) and target (to be used when computing the loss), and should be suitable for most applications. In our template example, we just extract the input and target from the input mapping and forward them to the super methods. Here’s the code for the training_step:

def training_step(self, batch: Any, batch_idx: int) -> Any:
    """For the training step, we just extract the :attr:`~flash.core.data.io.input.DataKeys.INPUT` and
    :attr:`~flash.core.data.io.input.DataKeys.TARGET` keys from the input and forward them to the
    :meth:`~flash.core.model.Task.training_step`."""
    batch = (batch[DataKeys.INPUT], batch[DataKeys.TARGET])
    return super().training_step(batch, batch_idx)

We use the same code for the validation_step and test_step. For predict_step we don’t need the targets, so our code looks like this:

def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any:
    """For the predict step, we just extract the :attr:`~flash.core.data.io.input.DataKeys.INPUT` key from the
    input and forward it to the :meth:`~flash.core.model.Task.predict_step`."""
    batch = batch[DataKeys.INPUT]
    return super().predict_step(batch, batch_idx, dataloader_idx=dataloader_idx)

Note

You can completely replace the {train,val,test,predict}_step methods (that is, without a call to super) if you need more custom behaviour for your Task at a particular stage.

Finally, we use our backbone and head in a custom forward pass:

def forward(self, x) -> torch.Tensor:
    """First call the backbone, then the model head."""
    x = self.backbone(x)
    return self.head(x)

Now that you’ve got your task, take a look at some optional advanced features you can add or go ahead and create some examples showing your task in action!

Optional Extras

Organize your transforms in input_transform.py

It can be useful to define your InputTransform in an input_transform.py file. Here’s an example from image/classification/input_transform.py:

@dataclass
class ImageClassificationInputTransform(InputTransform):

    image_size: Tuple[int, int] = (196, 196)
    mean: Union[float, Tuple[float, float, float]] = (0.485, 0.456, 0.406)
    std: Union[float, Tuple[float, float, float]] = (0.229, 0.224, 0.225)

    def input_per_sample_transform(self):
        return T.Compose([T.ToTensor(), T.Resize(self.image_size), T.Normalize(self.mean, self.std)])

    def train_input_per_sample_transform(self):
        return T.Compose(
            [T.ToTensor(), T.Resize(self.image_size), T.Normalize(self.mean, self.std), T.RandomHorizontalFlip()]
        )

    def target_per_sample_transform(self) -> Callable:
        return torch.as_tensor

    def collate(self) -> Callable:
        # TODO: Remove kornia collate for default_collate
        return kornia_collate

Add outputs to your Task

We recommend that you do most of the heavy lifting in the OutputTransform. Specifically, it should include any formatting and transforms that should always be applied to the predictions. If you want to support different use cases that require different prediction formats, you should add some Output implementations in an output.py file.

Some good examples are in flash/core/classification.py. Here’s the ClassesOutput Output:

@CLASSIFICATION_OUTPUTS(name="classes")
class ClassesOutput(PredsClassificationOutput):
    """A :class:`.Output` which applies an argmax to the model outputs (either logits or probabilities) and
    converts to a list.

    Args:
        multi_label: If true, treats outputs as multi label logits.
        threshold: The threshold to use for multi_label classification.
    """

    def __init__(self, multi_label: bool = False, threshold: float = 0.5):
        super().__init__(multi_label)

        self.threshold = threshold

    def transform(self, sample: Any) -> Union[int, List[int]]:
        sample = super().transform(sample)
        if self.multi_label:
            one_hot = (sample.sigmoid() > self.threshold).int().tolist()
            result = []
            for index, value in enumerate(one_hot):
                if value == 1:
                    result.append(index)
            return result
        return torch.argmax(sample, -1).tolist()

Alternatively, here’s the LogitsOutput Output:

@CLASSIFICATION_OUTPUTS(name="logits")
class LogitsOutput(PredsClassificationOutput):
    """A :class:`.Output` which simply converts the model outputs (assumed to be logits) to a list."""

    def transform(self, sample: Any) -> Any:
        return super().transform(sample).tolist()

Take a look at Predictions (inference) to learn more.


Once you’ve added any optional extras, it’s time to create some examples showing your task in action!

The Example

Now that you’ve implemented your task, it’s time to add an example showing how cool it is! We usually provide one example in flash_examples/. You can base yours on our template.py example.

The example should:

  1. download the data (we’ll add the example to our CI later on, so choose a dataset small enough that it runs in reasonable time)

  2. load the data into a DataModule

  3. create an instance of the Task

  4. create a Trainer

  5. call finetune() or fit() to train your model

  6. generate predictions for a few examples

  7. save the checkpoint

For our template example we don’t have a pretrained backbone, so we can just call fit() rather than finetune(). Here’s the full example (flash_examples/template.py):

import numpy as np
import torch
from sklearn import datasets

import flash
from flash.template import TemplateData, TemplateSKLearnClassifier

# 1. Create the DataModule
datamodule = TemplateData.from_sklearn(
    train_bunch=datasets.load_iris(),
    val_split=0.1,
    batch_size=4,
)

# 2. Build the task
model = TemplateSKLearnClassifier(num_features=datamodule.num_features, num_classes=datamodule.num_classes)

# 3. Create the trainer and train the model
trainer = flash.Trainer(max_epochs=3, gpus=torch.cuda.device_count())
trainer.fit(model, datamodule=datamodule)

# 4. Classify a few examples
datamodule = TemplateData.from_numpy(
    predict_data=[
        np.array([4.9, 3.0, 1.4, 0.2]),
        np.array([6.9, 3.2, 5.7, 2.3]),
        np.array([7.2, 3.0, 5.8, 1.6]),
    ],
    batch_size=4,
)
predictions = trainer.predict(model, datamodule=datamodule, output="classes")
print(predictions)

# 5. Save the model!
trainer.save_checkpoint("template_model.pt")

We get this output:

['setosa', 'virginica', 'versicolor']

Now that you’ve got an example showing your awesome task in action, it’s time to write some tests!

The Tests

Our next step is to create some tests for our Task. For the TemplateSKLearnClassifier, we will just create some basic tests. You should expand on these to include tests for any specific functionality you have in your Task.

Smoke tests

We use smoke tests, usually called test_smoke, throughout. These just instantiate the class we are testing, to see that they can be created without raising any errors.
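
For example, a smoke test for the template task could look like this (a minimal sketch; the constructor arguments are illustrative):

def test_smoke():
    """Tests that the task can be instantiated without raising any errors."""
    model = TemplateSKLearnClassifier(num_features=16, num_classes=10)
    assert model is not None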

tests/examples/test_scripts.py

Before we write our custom tests, we should add our examples to the CI. To do this, add a line for each example (finetuning and predict) to the parametrization of test_example in tests/examples/test_scripts.py. Here’s how those lines look for our template.py examples:

pytest.param(
    "finetuning", "template.py", marks=pytest.mark.skipif(not _SKLEARN_AVAILABLE, reason="sklearn isn't installed")
),
...
pytest.param(
    "predict", "template.py", marks=pytest.mark.skipif(not _SKLEARN_AVAILABLE, reason="sklearn isn't installed")
),

test_data.py

The most important tests in test_data.py check that the from_* methods work correctly. In the class TestTemplateData, we have two of these: test_from_numpy and test_from_sklearn. In general, there should be one test_from_* method for each input you have configured.

Here’s the code for test_from_numpy:

    def test_from_numpy(self):
        """Tests that ``TemplateData`` is properly created when using the ``from_numpy`` method."""
        data = np.random.rand(10, self.num_features)
        targets = np.random.randint(0, self.num_classes, (10,))

        # instantiate the data module
        dm = TemplateData.from_numpy(
            train_data=data,
            train_targets=targets,
            val_data=data,
            val_targets=targets,
            test_data=data,
            test_targets=targets,
            batch_size=2,
            num_workers=0,
        )
        assert dm is not None
        assert dm.train_dataloader() is not None
        assert dm.val_dataloader() is not None
        assert dm.test_dataloader() is not None

        # check training data
        data = next(iter(dm.train_dataloader()))
        rows, targets = data[DataKeys.INPUT], data[DataKeys.TARGET]
        assert rows.shape == (2, self.num_features)
        assert targets.shape == (2,)

        # check val data
        data = next(iter(dm.val_dataloader()))
        rows, targets = data[DataKeys.INPUT], data[DataKeys.TARGET]
        assert rows.shape == (2, self.num_features)
        assert targets.shape == (2,)

        # check test data
        data = next(iter(dm.test_dataloader()))
        rows, targets = data[DataKeys.INPUT], data[DataKeys.TARGET]
        assert rows.shape == (2, self.num_features)
        assert targets.shape == (2,)

test_model.py

In test_model.py, we first have test_forward and test_train. These test that tensors can be passed to the forward and that the Task can be trained. Here’s the code for test_forward and test_train:

@pytest.mark.skipif(not _SKLEARN_AVAILABLE, reason="sklearn isn't installed")
@pytest.mark.parametrize("num_classes", [4, 256])
@pytest.mark.parametrize("shape", [(1, 3), (2, 128)])
def test_forward(num_classes, shape):
    """Tests that a tensor can be given to the model forward and gives the correct output size."""
    model = TemplateSKLearnClassifier(
        num_features=shape[1],
        num_classes=num_classes,
    )
    model.eval()

    row = torch.rand(*shape)

    out = model(row)
    assert out.shape == (shape[0], num_classes)


@pytest.mark.skipif(not _SKLEARN_AVAILABLE, reason="sklearn isn't installed")
def test_train(tmpdir):
    """Tests that the model can be trained on our ``DummyDataset``."""
    model = TemplateSKLearnClassifier(num_features=DummyDataset.num_features, num_classes=DummyDataset.num_classes)
    train_dl = torch.utils.data.DataLoader(DummyDataset(), batch_size=4)
    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
    trainer.fit(model, train_dl)

We also include tests for validating and testing: test_val, and test_test. These tests are very similar to test_train, but here they are for completeness:

@pytest.mark.skipif(not _SKLEARN_AVAILABLE, reason="sklearn isn't installed")
def test_val(tmpdir):
    """Tests that the model can be validated on our ``DummyDataset``."""
    model = TemplateSKLearnClassifier(num_features=DummyDataset.num_features, num_classes=DummyDataset.num_classes)
    val_dl = torch.utils.data.DataLoader(DummyDataset(), batch_size=4)
    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
    trainer.validate(model, val_dl)


@pytest.mark.skipif(not _SKLEARN_AVAILABLE, reason="sklearn isn't installed")
def test_test(tmpdir):
    """Tests that the model can be tested on our ``DummyDataset``."""
    model = TemplateSKLearnClassifier(num_features=DummyDataset.num_features, num_classes=DummyDataset.num_classes)
    test_dl = torch.utils.data.DataLoader(DummyDataset(), batch_size=4)
    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
    trainer.test(model, test_dl)

We also include tests for prediction named test_predict_* for each of our data sources. In our case, we have test_predict_numpy and test_predict_sklearn. These tests should load the data with a DataModule and generate predictions with Trainer.predict. Here’s test_predict_sklearn as an example:

@pytest.mark.skipif(not _SKLEARN_AVAILABLE, reason="sklearn isn't installed")
def test_predict_sklearn():
    """Tests that we can generate predictions from a scikit-learn ``Bunch``."""
    bunch = datasets.load_iris()
    model = TemplateSKLearnClassifier(num_features=DummyDataset.num_features, num_classes=DummyDataset.num_classes)
    datamodule = TemplateData.from_sklearn(predict_bunch=bunch, batch_size=1)
    trainer = Trainer()
    out = trainer.predict(model, datamodule=datamodule, output="classes")
    assert isinstance(out[0][0], int)
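
For reference, test_predict_numpy follows the same pattern using TemplateData.from_numpy. Here’s a sketch (it assumes numpy is imported as np in the test file; adapt it to your own data sources):

@pytest.mark.skipif(not _SKLEARN_AVAILABLE, reason="sklearn isn't installed")
def test_predict_numpy():
    """Sketch: tests that we can generate predictions from a numpy array."""
    row = np.random.rand(1, DummyDataset.num_features)
    model = TemplateSKLearnClassifier(num_features=DummyDataset.num_features, num_classes=DummyDataset.num_classes)
    datamodule = TemplateData.from_numpy(predict_data=row, batch_size=1)
    trainer = Trainer()
    out = trainer.predict(model, datamodule=datamodule, output="classes")
    assert isinstance(out[0][0], int)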

Now that you’ve written the tests, it’s time to add some docs!

The Docs

The final step is to add some docs. For each Task in Flash, we have a docs page in docs/source/reference. You should create a .rst file there with the following:

  • a brief description of the task

  • the predict example

  • the finetuning example

  • any relevant API reference

Here are the contents of docs/source/reference/template.rst, which cover each of these steps:


.. _template:

########
Template
########

********
The Task
********

Here you should add a description of your task. For example:
Classification is the task of assigning one of a number of classes to each data point.

------

*******
Example
*******

.. note::

    Here you should add a short intro to your example, and then use ``literalinclude`` to add it.
    To make it simple, you can fill in this template.

Let's look at the task of <describe the task> using the <data set used in the example>.
The dataset contains <describe the data>.
Here's an outline:

.. code-block::

    <present the folder structure of the data or some data samples here>

Once we've downloaded the data using :func:`~flash.core.data.download_data`, we create the <link to the DataModule with ``:class:``>.
We select a pre-trained backbone to use for our <link to the Task with ``:class:``> and finetune on the <name of the data set> data.
We then use the trained <link to the Task with ``:class:``> for inference.
Finally, we save the model.
Here's the full example:

<include the example with ``literalinclude``>

.. literalinclude:: ../../../flash_examples/template.py
    :language: python
    :lines: 14-

Here’s the rendered doc page!
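
You will also typically want the new page to appear in the docs navigation. A minimal sketch, assuming the reference pages are collected in a toctree in docs/source/index.rst (the caption and depth below are placeholders):

.. toctree::
   :maxdepth: 1
   :caption: Task Reference

   reference/template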


Once the docs are done, it’s finally time to open a PR and wait for some reviews!


Congratulations on adding your first Task to Flash! We hope to see you again soon!

Flash Governance | Persons of interest

Leads

Core Maintainers

Contributing

Welcome to the PyTorch Lightning community! We’re building the most advanced research platform on the planet to implement the latest, best practices that the amazing PyTorch team rolls out!

Flash Design Principles

We encourage all sorts of contributions you’re interested in adding! When coding for Flash, please follow these principles.

Simple Internal Code

It’s useful for users to look at the code and understand very quickly what’s happening. Many users won’t be engineers. Thus we need to value clear, simple code over condensed ninja moves. While that’s super cool, this isn’t the project for that :)

Force User Decisions To Best Practices

There are 1,000 ways to do something. However, eventually one popular solution becomes standard practice, and everyone follows. We try to find the best way to solve a particular problem, and then force our users to use it for readability and simplicity.

When something becomes a best practice, we add it to the framework. This is usually something like bits of code in utils or in the model file that everyone keeps adding over and over again across projects. When this happens, bring that code inside the trainer and add a flag for it.

Backward-compatible API

We all hate updating our deep learning packages because we don’t want to refactor a bunch of stuff. In Flash, we make sure that any change which could break the API remains backward compatible and comes with a clear deprecation warning.
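
As an illustration only (this is not Flash’s actual deprecation machinery), the usual pattern is to keep the old name working while warning the user. The serializer-to-output rename below is taken from the changelog, but the function and its signature are made up for the example:

import warnings

def predict(data, output=None, serializer=None):
    """Hypothetical function showing the deprecation-warning pattern."""
    if serializer is not None:
        # Keep the deprecated argument working, but point users at the new one.
        warnings.warn(
            "`serializer` is deprecated and will be removed in a future release; use `output` instead.",
            FutureWarning,
        )
        output = serializer
    ...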

Gain User Trust

As a researcher, you can’t afford to have any part of your code going wrong. So write thorough tests to ensure that every implementation of a new trick or subtle change is correct.

Interoperability

PyTorch Lightning Flash is highly interoperable with PyTorch Lightning and PyTorch.


Contribution Types

We are always looking for help implementing new features or fixing bugs.

A lot of good work has already been done in project mechanics (requirements.txt, setup.py, PEP 8, badges, CI, etc.) so we’re in a good state there thanks to all the early contributors (even pre-beta release)!

Bug Fixes:

  1. If you find a bug please submit a GitHub issue.

    • Make sure the title explains the issue.

    • Describe your setup, what you are trying to do, expected vs. actual behaviour. Please add configs and code samples.

    • Add details on how to reproduce the issue - a minimal test case is always best, and a Colab notebook is also great. Note that the sample code should be minimal and, if needed, use publicly available data.

  2. Try to fix it or recommend a solution. We highly recommend using a test-driven approach:

    • Convert your minimal code example to a unit/integration test that asserts on the expected results (see the sketch after this list).

    • Start by debugging the issue: you can run just this particular test in your IDE and draft a fix.

    • Verify that your test case fails on the master branch and only passes with the fix applied.

  3. Submit a PR!

Note that even if you do not find the solution, sending a PR with a test covering the issue is a valid contribution, and we can help you finish it :]
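
To make the test-driven workflow above concrete, here is a minimal sketch of what such a regression test could look like. The bug, the test name, and the expected value are hypothetical; substitute the Task or DataModule you are actually fixing:

from sklearn import datasets

from flash import DataKeys
from flash.template import TemplateData


def test_train_dataloader_respects_batch_size():
    """Hypothetical regression test: the (made-up) bug is that ``batch_size`` is ignored."""
    datamodule = TemplateData.from_sklearn(train_bunch=datasets.load_iris(), val_split=0.1, batch_size=2)
    batch = next(iter(datamodule.train_dataloader()))
    # This assert should fail on the master branch and only pass once the fix is applied.
    assert batch[DataKeys.INPUT].shape[0] == 2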

New Features:

  1. Submit a GitHub issue - describe the motivation for the feature (adding a use case or an example is helpful).

  2. Let’s discuss the feature to determine its scope.

  3. Submit a PR! We recommend a test-driven approach for adding new features as well:

    • Write a test for the functionality you want to add.

    • Write the functional code until the test passes.

  4. Add/update the relevant tests!

New Tasks:

Flash is a framework of tasks for fast prototyping, baselining, finetuning and solving business and scientific problems with deep learning. The following are general guidelines for adding new tasks:

  1. Models which are standard baselines.

  2. Models whose results have been properly reproduced, either by us or by the original authors.

  3. Top models which are not SOTA but are highly cited for production or other uses (e.g. MobileBERT, MobileNets, FBNets).

  4. Do not reinvent the wheel; natively support torchvision, torchtext and torchaudio models.

  5. Use open-source licensed models.

Please raise an issue before adding a new task, and let us know why the particular task is important for Flash.

Test cases:

Want to keep Lightning Flash healthy? Love seeing those green tests? So do we! How do we keep it that way? We write tests! We value test contributions even more than new features.

Tests are written using pytest.

Have a look at sample tests here.

After you have added the respective tests, you can run them locally with the make script:

make test
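
If you only want to run the tests for your new Task rather than the whole suite, you can also invoke pytest directly. The path below is an assumption - point it at wherever your new tests live:

python -m pytest tests/template -v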

Want to add a new test case and not sure how? Talk to us!


Guidelines

For this section, please refer to the parent PyTorch Lightning guidelines.

Reminder

All added or edited code shall be the original work of the particular contributor. If you use some third-party implementation, all such blocks/functions/modules shall be properly referenced and, if possible, also agreed to by the code’s author. For example - This code is inspired from http://.... If you are adding new dependencies, make sure that they are compatible with the PyTorch Lightning license (i.e. dependencies should be at least as permissive as the PyTorch Lightning license).

How to rebase my PR?

We recommend creating a PR in a separate branch other than master, especially if you plan to submit several changes and do not want to wait until the first one is resolved (we can work on them in parallel).

First, make sure you have set upstream by running:

git remote add upstream https://github.com/PyTorchLightning/lightning-flash.git

You’ll know it’s set up correctly if you run git remote -v and see something similar to this:

origin  https://github.com/{YOUR_USERNAME}/lightning-flash.git (fetch)
origin  https://github.com/{YOUR_USERNAME}/lightning-flash.git (push)
upstream        https://github.com/PyTorchLightning/lightning-flash.git (fetch)
upstream        https://github.com/PyTorchLightning/lightning-flash.git (push)

Checkout your feature branch and rebase it with upstream’s master before pushing up your feature branch:

git fetch --all --prune
git rebase upstream/master
# follow git instructions to resolve conflicts
git push -f

Question & Answer

  1. How can I help/contribute?

    All help is extremely welcome - reporting bugs, fixing documentation, adding test cases, solving issues and preparing bug fixes. To solve some issues you can start with the label good first issue or choose something close to your domain with the label help wanted. Before you start implementing anything, check that the issue description is clear and self-assign the task (if that is not possible, just comment that you are taking it and we will assign it to you…).

  2. Is there a recommendation for branch names?

    We do not rely on a naming convention as long as you are working with your own fork. Still, it would be nice to follow this convention: <type>/<issue-id>_<short-name>, where the types are: bugfix, feature, docs, tests, …

  3. I have a model in other framework than PyTorch, how do I add it here?

    Since PyTorch Lightning is written on top of PyTorch, we need models in PyTorch only. Also, we would need the same or equivalent results with PyTorch Lightning after converting the models from other frameworks.

Changelog

All notable changes to this project will be documented in this file.

The format is based on Keep a Changelog.

[0.7.4] - 2022-04-27

[0.7.4] - Fixed

  • Fixed a bug where LR schedulers from HuggingFace could not be used with newer versions of PyTorch Lightning (#1307)

  • Fixed a bug where the default Flash zero configurations for ObjectDetector, InstanceSegmentation, and KeypointDetector would error with the latest version of some requirements (#1306)

  • Fixed plain LightningModule support for Flash data modules. (#1281)

[0.7.3] - 2022-04-13

[0.7.3] - Fixed

  • Fixed a bug where some backbones were incorrectly listed as available for the ObjectDetector, InstanceSegmentation, and KeypointDetector (#1267)

  • Fixed a bug where the backbone would not be frozen when finetuning the SpeechRecognition task (#1275)

  • Fixed a bug where the backbone would not be frozen when finetuning the QuestionAnswering task with certain model types (#1275)

[0.7.2] - 2022-03-30

[0.7.2] - Fixed

  • Fixed examples (question answering), where NLTK’s punkt module needs to be downloaded first. (#1215)

  • Fixed normalizing inputs to video classification (#1213)

  • Fixed a bug where pretraining_transforms in the ImageEmbedder was never called. (#1196)

  • Fixed a bug where BASE_MODEL_NAME was not in the dict for dino and moco strategies. (#1196)

  • Fixed support for torch==1.11.0 (#1234)

  • Fixed DDP spawn support for ObjectDetector, InstanceSegmentation, and KeypointDetector (#1222)

  • Fixed a bug where InstanceSegmentation would fail if samples had an inconsistent number of bboxes, labels, and masks (these will now be treated as negative samples) (#1222)

  • Fixed a bug where collate functions were never called in the ImageEmbedder class. (#1217)

  • Fixed a bug where ObjectDetector, InstanceSegmentation, and KeypointDetector would log train and validation metrics with the same name (#1252)

  • Fixed a bug where using ReduceLROnPlateau would raise an error (#1251)

  • Fixed GPU support for self-supervised training with the ImageEmbedder (#1256)

[0.7.1] - 2022-03-01

[0.7.1] - Added

  • Added the normalization parameters of torchvision.transforms.Normalize as transform_kwargs in the ImageClassificationInputTransform (#1178)

  • Added available_outputs method to the Task (#1206)

[0.7.1] - Fixed

  • Fixed a bug where DDP would not work with Flash tasks (#1182)

  • Fixed DDP support for VideoClassifier (#1189)

  • Fixed a bug where buffers in loss functions were not correctly registered in the Task (#1203)

  • Fixed support for passing a sampler instance to from_* methods / the DataModule (#1204)

[0.7.0] - 2022-02-15

[0.7.0] - Added

  • Added support for multi-label, space delimited, targets (#1076)

  • Added support for tabular classification / regression backbones from PyTorch Tabular (#1098)

  • Added Flash zero support for tabular regression (#1098)

  • Added support for COCO annotations with non-default keypoint labels to KeypointDetectionData.from_coco (#1102)

  • Added support for from_csv and from_data_frame to VideoClassificationData (#1117)

  • Added support for SemanticSegmentationData.from_folders where mask files have different extensions to the image files (#1130)

  • Added FlashRegistry of Available Heads for flash.image.ImageClassifier (#1152)

  • Added support for ObjectDetectionData.from_files (#1154)

  • Added support for passing the Output object (or a string e.g. "labels") to the flash.Trainer.predict method (#1157)

  • Added support for passing the TargetFormatter object to from_* methods for classification to override target handling (#1171)

[0.7.0] - Changed

  • Changed Wav2Vec2Processor to AutoProcessor and separate it from backbone [optional] (#1075)

  • Renamed ClassificationInput to ClassificationInputMixin (#1116)

  • Changed the default learning_rate for all tasks to be None, corresponding to the default for your chosen optimizer (#1172)

[0.7.0] - Fixed

  • Fixed a bug when not explicitly passing embedding_sizes to the TabularClassifier and TabularRegressor tasks (#1067)

  • Fixed a bug where under some circumstances transforms would not get called (#1072)

  • Fixed a bug where prediction would sometimes give the wrong number of outputs (#1077)

  • Fixed a bug where passing the val_split to the DataModule would not have the desired effect (#1079)

  • Fixed a bug where passing predict_data_frame to ImageClassificationData.from_data_frame raised an error (#1088)

  • Fixed a bug where segmentation files / masks were loaded with an inconsistent ordering (#1094)

  • Fixed a bug with AudioClassificationData.from_numpy (#1096)

  • Fixed a bug when using SpeechRecognitionData.from_files for training / validating / testing (#1097)

  • Fixed a bug when using SpeechRecognitionData.from_csv or from_json when predicting without targets (#1097)

  • Fixed a bug where SpeechRecognitionData.from_datasets did not work as expected (#1097)

  • Fixed a bug where loading data for prediction with SemanticSegmentationData.from_folders raised an error (#1101)

  • Fixed a bug when passing a predict_folder argument to from_coco / from_voc / from_via in IceVision tasks (#1102)

  • Fixed ObjectDetectionData.from_voc and ObjectDetectionData.from_via (#1102)

  • Fixed a bug where InstanceSegmentationData.from_coco would raise an error if not using file-based masks (#1102)

  • Fixed InstanceSegmentationData.from_voc (#1102)

  • Fixed a bug when loading tabular data for prediction without a target field / column (#1114)

  • Fixed a bug when loading prediction data for graph classification without targets (#1121)

  • Fixed a bug where loading Seq2Seq data for prediction would not work if the target field was not present (#1128)

  • Fixed a bug where from_fiftyone classmethods did not work correctly with a predict_dataset (#1136)

  • Fixed a bug where the labels property would return None when using ObjectDetectionData.from_fiftyone (#1136)

  • Fixed a bug where TabularData would not work correctly with no categorical variables (#1144)

  • Fixed a bug where loading TabularForecastingData for prediction would only yield a single sample per series (#1149)

  • Fixed a bug where backbones for the ObjectDetector, KeypointDetector, and InstanceSegmentation tasks were not always frozen correctly when finetuning (#1163)

  • Fixed a bug where DataModule.multi_label would sometimes be None when it had been inferred to be False (#1165)

[0.7.0] - Removed

  • Removed the Seq2SeqData base class (use TranslationData or SummarizationData directly) (#1128)

  • Removed the ability to attach the Output object directly to the model (#1157)

[0.6.0] - 2021-12-13

[0.6.0] - Added

  • Added TextEmbedder task (#996)

  • Added predict_kwargs in ObjectDetector, InstanceSegmentation, KeypointDetector (#990)

  • Added backbones for GraphClassifier (#592)

  • Added GraphEmbedder task (#592)

  • Added support for comma delimited multi-label targets to the ImageClassifier (#997)

  • Added datapipeline_state on dataset creation within the from_* methods from the DataModule (#1018)

[0.6.0] - Changed

  • Changed DataSource to Input (#929)

  • Changed Preprocess to InputTransform (#951)

  • Changed classes named *Serializer and properties / variables named serializer to be *Output and output respectively (#927)

  • Changed Postprocess to OutputTransform (#942)

  • Changed loading of RGBA images to drop alpha channel by default (#946)

  • Updated FlashFinetuning callback to use separate hooks that lets users use the freezing logic provided out-of-the-box from flash, route FlashFinetuning through a registry. (#830)

  • Changed the SpeechRecognition task to use AutoModelForCTC rather than just Wav2Vec2ForCTC (#874)

  • Changed the Deserializer to subclass ServeInput (#1013)

  • Added Output suffix to Preds, FiftyOneDetectionLabels, SegmentationLabels, DetectionLabels, Classes, FiftyOneLabels, Labels, Logits, Probabilities (#1011)

  • Changed from_files and from_folders from ObjectDetectionData, InstanceSegmentationData, KeypointDetectionData to support only the predicting stage (#1018)

  • Changed Image Classification Task to use the new DataModule API (#1025)

[0.6.0] - Deprecated

  • Deprecated flash.core.data.process.Serializer in favour of flash.core.data.io.output.Output (#927)

  • Deprecated Task.serializer in favour of Task.output (#927)

  • Deprecated flash.text.seq2seq.core.metrics in favour of torchmetrics[text] (#648)

  • Deprecated flash.core.data.data_source.DefaultDataKeys in favour of flash.DataKeys (#929)

  • Deprecated data_source argument to flash.Task.predict in favour of input (#929)

[0.6.0] - Fixed

  • Fixed a bug where using image classification with DDP spawn would trigger an infinite recursion (#969)

  • Fixed a bug where Flash could not be used with IceVision 0.11.0 (#989)

  • Fixed a bug where backbone weights were sometimes not frozen correctly (#992)

  • Fixed a bug where translation metrics were not computed correctly (#992)

  • Fixed a bug where additional DataModule keyword arguments could not be configured with Flash Zero for some tasks (#994)

  • Fixed a bug where the TabularForecaster would not work with some versions of pandas (#995)

[0.6.0] - Removed

  • Removed OutputMapping (#939)

  • Removed Output.enable and Output.disable (#939)

  • Removed OutputTransform.save_sample and save_data hooks (#948)

  • Removed InputTransform pre_tensor_transform, to_tensor_transform, post_tensor_transform hooks in favour of per_sample_transform (#1010)

  • Removed Task.predict, use Trainer.predict instead (#1030)

  • Removed the backbone argument from TextClassificationData, it is now sufficient to only provide a backbone argument to the TextClassifier (#1022)

  • Removed support for the serve_sanity_check argument in flash.Trainer (#1062)

[0.5.2] - 2021-11-05

[0.5.2] - Added

  • Added a TabularForecaster task based on PyTorch Forecasting (#647)

  • Added a TabularRegressor task (#892)

[0.5.2] - Fixed

  • Fixed a bug where test metrics were not logged correctly with active learning (#879)

  • Fixed a bug where validation metrics could be aggregated together with test metrics in some cases (#900)

  • Fixed a bug where the latest versions of torchmetrics and Lightning Flash could not be installed together (#902)

  • Fixed compatibility with PyTorch-Lightning 1.5 (#933)

[0.5.1] - 2021-10-26

[0.5.1] - Added

  • Added LabelStudio integration (#554)

  • Added support for learn2learn training_strategy for ImageClassifier (#737)

  • Added vissl training_strategies for ImageEmbedder (#682)

  • Added support for from_data_frame to TextClassificationData (#785)

  • Added FastFace integration (#606)

  • Added support for from_lists to TextClassificationData (#805)

[0.5.1] - Changed

  • Changed the default num_workers on linux to 0 (matching the default for other OS) (#759)

  • Optimizer and LR Scheduler registry are used to get the respective inputs to the Task using a string (or a callable). (#777)

[0.5.1] - Fixed

  • Fixed a bug where additional kwargs (e.g. sampler) passed to tabular data would be ignored (#792)

  • Fixed a bug where loading text data with additional non-numeric columns (not input or target) would give an error (#888)

[0.5.0] - 2021-09-07

[0.5.0] - Added

  • Added support for (input, target) style datasets (e.g. torchvision) to the from_datasets method (#552)

  • Added support for from_csv and from_data_frame to ImageClassificationData (#556)

  • Added SimCLR, SwAV, Barlow-twins pretrained weights for resnet50 backbone in ImageClassifier task (#560)

  • Added support for Semantic Segmentation backbones and heads from segmentation-models.pytorch (#562)

  • Added support for nesting of Task objects (#575)

  • Added PointCloudSegmentation Task (#566)

  • Added PointCloudObjectDetection Task (#600)

  • Added a GraphClassifier task (#73)

  • Added the option to pass pretrained as a string to SemanticSegmentation to change pretrained weights to load from segmentation-models.pytorch (#587)

  • Added support for the field parameter for loading JSON based datasets in text tasks. (#585)

  • Added AudioClassificationData and an example for classifying audio spectrograms (#594)

  • Added a SpeechRecognition task for speech to text using Wav2Vec (#586)

  • Added Flash Zero, a zero code command line ML platform built with flash (#611)

  • Added support for .npy and .npz files to ImageClassificationData and AudioClassificationData (#651)

  • Added support for from_csv to the AudioClassificationData (#651)

  • Added option to pass a resolver to the from_csv and from_pandas methods of ImageClassificationData, which is used to resolve filenames given IDs (#651)

  • Added integration with IceVision for the ObjectDetector (#608)

  • Added keypoint detection task (#608)

  • Added instance segmentation task (#608)

  • Added Torch ORT support to Transformer based tasks (#667)

  • Added support for flash zero with the InstanceSegmentation and KeypointDetector tasks (#672)

  • Added support for in_chans argument to the flash ResNet to control the expected number of input channels (#673)

  • Added a QuestionAnswering task for extractive question answering (#607)

  • Added automatic unwrapping of IceVision prediction objects (#727)

  • Added support for the ObjectDetector with FiftyOne (#727)

  • Added support for MP3 files to the SpeechRecognition task with librosa (#726)

  • Added support for from_numpy and from_tensors to AudioClassificationData (#745)

[0.5.0] - Changed

  • Changed how pretrained flag works for loading weights for ImageClassifier task (#560)

  • Removed bolts pretrained weights for SSL from ImageClassifier task (#560)

  • Changed the behaviour of the sampler argument of the DataModule to take a Sampler type rather than instantiated object (#651)

  • Changed arguments to ObjectDetector, use head instead of model and append _fpn to the backbone name instead of the fpn argument (#608)

[0.5.0] - Fixed

  • Fixed a bug where serve sanity checking would not be triggered using the latest PyTorchLightning version (#493)

  • Fixed a bug where train and validation metrics weren’t being correctly computed (#559)

  • Fixed a bug where an uncaught ValueError could be raised when checking if a module is available (#615)

  • Fixed a bug where some tasks were not compatible with PyTorch 1.7 due to use of torch.jit.isinstance (#611)

  • Fixed a bug where custom samplers would not be properly forwarded to the data loader (#651)

  • Fixed a bug where it was not possible to pass no metrics to the ImageClassifier or TextClassifier (#660)

  • Fixed a bug where drop_last would be set to True during prediction and testing (#671)

  • Fixed a bug where flash was not compatible with pytorch-lightning >= 1.4.3 (#690)

[0.4.0] - 2021-06-22

[0.4.0] - Added

  • Added integration with FiftyOne (#360)

  • Added flash.serve (#399)

  • Added support for torch.jit to tasks where possible and documented task JIT compatibility (#389)

  • Added option to provide a Sampler to the DataModule to use when creating a DataLoader (#390)

  • Added support for multi-label text classification and toxic comments example (#401)

  • Added a sanity checking feature to flash.serve (#423)

[0.4.0] - Changed

  • Split backbone argument to SemanticSegmentation into backbone and head arguments (#412)

[0.4.0] - Fixed

  • Fixed a bug where the DefaultDataKeys.METADATA couldn’t be a dict (#393)

  • Fixed a bug where the SemanticSegmentation task would not work as expected with finetuning callbacks (#412)

  • Fixed a bug where predict batches could not be visualized with ImageClassificationData (#438)

[0.3.2] - 2021-06-08

[0.3.2] - Fixed

  • Fixed a bug where flash.Trainer.from_argparse_args + finetune would not work (#382)

[0.3.1] - 2021-06-08

[0.3.1] - Added

  • Added deeplabv3, lraspp, and unet backbones for the SemanticSegmentation task (#370)

[0.3.1] - Changed

  • Changed the installation command for extra features (#346)

  • Change resize interpolation default mode to nearest (#352)

[0.3.1] - Deprecated

  • Deprecated SemanticSegmentation backbone names torchvision/fcn_resnet50 and torchvision/fcn_resnet101, use fcn_resnet50 and fcn_resnet101 instead (#370)

[0.3.1] - Fixed

  • Fixed flash.Trainer.add_argparse_args not adding any arguments (#343)

  • Fixed a bug where the translation task wasn’t decoding tokens properly (#332)

  • Fixed a bug where huggingface tokenizers were sometimes being pickled (#332)

  • Fixed issue with KorniaParallelTransforms to assure to share the random state between transforms (#351)

  • Fixed a bug where using val_split with overfit_batches would give an infinite recursion (#375)

  • Fixed a bug where some timm models were mistakenly given a global_pool argument (#377)

  • Fixed flash.Trainer.from_argparse_args not passing arguments correctly (#380)

[0.3.0] - 2021-05-20

[0.3.0] - Added

  • Added DataPipeline API (#188 #141 #207)

  • Added timm integration (#196)

  • Added BaseViz Callback (#201)

  • Added backbone API (#204)

  • Added support for Iterable auto dataset (#227)

  • Added multi label support (#230)

  • Added support for schedulers (#232)

  • Added visualisation callback for image classification (#228)

  • Added Video Classification task (#216)

  • Added Dino backbone for image classification (#259)

  • Added Data Sources API (#256 #264 #272)

  • Refactor preprocess_cls to preprocess, add Serializer, add DataPipelineState (#229)

  • Added Semantic Segmentation task (#239 #287 #290)

  • Added Object detection prediction example (#283)

  • Added Style Transfer task and accompanying finetuning and prediction examples (#262)

  • Added a Template task and tutorials showing how to contribute a task to flash (#306)

[0.3.0] - Changed

  • Rename valid_ to val_ (#197)

  • Refactor preprocess_cls to preprocess, add Serializer, add DataPipelineState (#229)

[0.3.0] - Fixed

  • Fix DataPipeline resolution in Task (#212)

  • Fixed a bug where the backbone used in summarization was not correctly passed to the postprocess (#296)

[0.2.3] - 2021-04-17

[0.2.3] - Added

  • Added TIMM integration as backbones (#196)

[0.2.3] - Fixed

  • Fixed nltk.download (#210)

[0.2.2] - 2021-04-05

[0.2.2] - Changed

  • Switch to use torchmetrics (#169)

  • Better support for optimizer and schedulers (#232)

  • Update lightning version to v1.2 (#133)

[0.2.2] - Fixed

  • Fixed classification softmax (#169)

  • Fixed a bug where loading from a local checkpoint that had pretrained=True without an internet connection would sometimes raise an error (#237)

  • Don’t download data if exists (#157)

[0.2.1] - 2021-03-06

[0.2.1] - Added

  • Added RetinaNet & backbones to ObjectDetector Task (#121)

  • Added .csv image loading utils (#116, #117, #118)

[0.2.1] - Changed

  • Set inputs as optional (#109)

[0.2.1] - Fixed

  • Set minimal requirements (#62)

  • Fixed VGG backbone num_features (#154)

[0.2.0] - 2021-02-12

[0.2.0] - Added

  • Added ObjectDetector Task (#56)

  • Added TabNet for tabular classification (#101)

  • Added support for more backbones (mobilenet, vgg, densenet, resnext) (#45)

  • Added backbones for image embedding model (#63)

  • Added SWAV and SimCLR models to imageclassifier + backbone reorg (#68)

[0.2.0] - Changed

  • Applied transform in FilePathDataset (#97)

  • Moved classification integration from vision root to folder (#86)

[0.2.0] - Fixed

  • Unfreeze default number of workers in datamodule (#57)

  • Fixed wrong label in FilePathDataset (#94)

[0.2.0] - Removed

  • Removed densenet161 duplicate in DENSENET_MODELS (#76)

  • Removed redundant num_features arg from Classification model (#88)

[0.1.0] - 2021-02-02

[0.1.0] - Added

  • Added flash_notebook examples (#9)

  • Added strategy to trainer.finetune with NoFreeze, Freeze, FreezeUnfreeze, UnfreezeMilestones Callbacks (#39)

  • Added SummarizationData, SummarizationTask and TranslationData, TranslationTask (#37)

  • Added ImageEmbedder (#36)

Template

The Task

Here you should add a description of your task. For example: Classification is the task of assigning one of a number of classes to each data point.


Example

Note

Here you should add a short intro to your example, and then use literalinclude to add it. To make it simple, you can fill in this template.

Let’s look at the task of <describe the task> using the <data set used in the example>. The dataset contains <describe the data>. Here’s an outline:

<present the folder structure of the data or some data samples here>

Once we’ve downloaded the data using download_data(), we create the <link to the DataModule with :class:>. We select a pre-trained backbone to use for our <link to the Task with :class:> and finetune on the <name of the data set> data. We then use the trained <link to the Task with :class:> for inference. Finally, we save the model. Here’s the full example:

<include the example with literalinclude>

import numpy as np
import torch
from sklearn import datasets

import flash
from flash.template import TemplateData, TemplateSKLearnClassifier

# 1. Create the DataModule
datamodule = TemplateData.from_sklearn(
    train_bunch=datasets.load_iris(),
    val_split=0.1,
    batch_size=4,
)

# 2. Build the task
model = TemplateSKLearnClassifier(num_features=datamodule.num_features, num_classes=datamodule.num_classes)

# 3. Create the trainer and train the model
trainer = flash.Trainer(max_epochs=3, gpus=torch.cuda.device_count())
trainer.fit(model, datamodule=datamodule)

# 4. Classify a few examples
datamodule = TemplateData.from_numpy(
    predict_data=[
        np.array([4.9, 3.0, 1.4, 0.2]),
        np.array([6.9, 3.2, 5.7, 2.3]),
        np.array([7.2, 3.0, 5.8, 1.6]),
    ],
    batch_size=4,
)
predictions = trainer.predict(model, datamodule=datamodule, output="classes")
print(predictions)

# 5. Save the model!
trainer.save_checkpoint("template_model.pt")
