Training from scratch

Some Flash tasks ship with backbones that have been pretrained on large data sets. To accelerate training, you can call the finetune() method with a pretrained backbone, which fine-tunes the backbone to produce a model customized to your data set and task. Training from scratch skips the pretrained weights and learns every parameter from your own data.

From the Quick Start guide.

To train a task from scratch:

  1. Load your data and organize it using a DataModule customized for the task (example: ImageClassificationData).

  2. Choose and initialize your Task, setting pretrained=False. Tasks have state-of-the-art backbones built in (example: ImageClassifier).

  3. Initialize a flash.core.trainer.Trainer or a pytorch_lightning.trainer.Trainer.

  4. Call flash.core.trainer.Trainer.fit() with your data set.

  5. Save your trained model.


Here’s an example:

import torch  # needed for torch.cuda.device_count() below
from pytorch_lightning import seed_everything

import flash
from flash.core.data.utils import download_data
from flash.image import ImageClassificationData, ImageClassifier

# set the random seeds.
seed_everything(42)

# 1. Download and organize the data
download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip", "data/")

datamodule = ImageClassificationData.from_folders(
    train_folder="data/hymenoptera_data/train/",
    val_folder="data/hymenoptera_data/val/",
    test_folder="data/hymenoptera_data/test/",
    batch_size=1,
)

# 2. Build the model using desired Task
model = ImageClassifier(backbone="resnet18", num_classes=datamodule.num_classes, pretrained=False)

# 3. Create the trainer (run one epoch for demo)
trainer = flash.Trainer(max_epochs=1, gpus=torch.cuda.device_count())

# 4. Train the model
trainer.fit(model, datamodule=datamodule)

# 5. Save the model!
trainer.save_checkpoint("image_classification_model.pt")
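
The example above trains from scratch with fit(). As a rough sketch of how that choice relates to finetune(), the hypothetical helper below (run_training is not part of Flash) dispatches to one method or the other depending on whether the backbone is pretrained. It assumes the trainer is a flash.Trainer, which provides both methods, and that the "freeze" fine-tuning strategy is the one you want.

```python
def run_training(trainer, model, datamodule, pretrained):
    """Call finetune() for a pretrained backbone, fit() otherwise.

    Hypothetical helper for illustration; `trainer` is assumed to be a
    flash.Trainer (or any object exposing fit() and finetune()).
    """
    if pretrained:
        # Assumption: the "freeze" strategy keeps the backbone frozen
        # and trains only the task head.
        trainer.finetune(model, datamodule=datamodule, strategy="freeze")
    else:
        # Training from scratch: all weights are learned from your data.
        trainer.fit(model, datamodule=datamodule)
```

With the objects from the example above, run_training(trainer, model, datamodule, pretrained=False) would behave like step 4.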

Training options

Flash tasks support many advanced training features out of the box, such as:

  • Limiting the number of epochs

# train for 10 epochs
flash.Trainer(max_epochs=10)

  • Training on GPUs

# train on 1 GPU
flash.Trainer(gpus=1)

  • Training on multiple GPUs

# train on multiple GPUs
flash.Trainer(gpus=4)
# train on GPUs 1, 3, 5 (3 GPUs total)
flash.Trainer(gpus=[1, 3, 5])

  • Using mixed precision training

# Multi GPU with mixed precision
flash.Trainer(gpus=2, precision=16)

  • Training on TPUs

# Train on TPUs
flash.Trainer(accelerator="tpu", devices=8)

You can pass any argument accepted by the Lightning Trainer to the Flash Trainer. See the Lightning Trainer documentation to learn more.
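
The options above can be combined in a single Trainer call. One possible way to keep that tidy is to collect the keyword arguments first; build_trainer_kwargs below is a hypothetical helper, not part of Flash or Lightning, and it only uses the max_epochs, gpus, and precision arguments shown in this section.

```python
def build_trainer_kwargs(max_epochs, gpus=None, precision=None):
    """Collect Trainer keyword arguments, leaving out options that are unset.

    Hypothetical helper for illustration; the keys mirror the Trainer
    arguments used in the snippets above.
    """
    if max_epochs < 1:
        raise ValueError("max_epochs must be at least 1")
    kwargs = {"max_epochs": max_epochs}
    if gpus is not None:
        # Either a GPU count (e.g. 2) or a list of GPU indices (e.g. [1, 3, 5]).
        kwargs["gpus"] = gpus
    if precision is not None:
        # For example, 16 for mixed precision training.
        kwargs["precision"] = precision
    return kwargs


# Usage (requires Flash to be installed):
#   import flash
#   trainer = flash.Trainer(**build_trainer_kwargs(10, gpus=2, precision=16))
```

Leaving unset options out of the dictionary means the Trainer falls back to its own defaults for them.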