
Training from scratch

Some Flash tasks provide backbones pretrained on large datasets. To accelerate training, you can call the finetune() method with a pretrained backbone, which fine-tunes that backbone into a model customized to your dataset and task. If you would rather not start from pretrained weights, you can instead train a task from scratch.

From the Quick Start guide.

To train a task from scratch:

  1. Load your data and organize it using a DataModule customized for the task (example: ImageClassificationData).

  2. Choose and initialize your Task (setting pretrained=False). Flash tasks have state-of-the-art backbones built in (example: ImageClassifier).

  3. Initialize a flash.core.trainer.Trainer or a pytorch_lightning.trainer.Trainer.

  4. Call flash.core.trainer.Trainer.fit() with your model and DataModule.

  5. Save your trained model.
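The five steps above are not specific to Flash. As a rough illustration of what "from scratch" means, here is a minimal sketch that uses only the Python standard library: the toy dataset, the one-feature logistic model, and the JSON "checkpoint" are all assumptions made for illustration and are not part of Flash's API. The important detail is step 2: the weights start from small random values instead of pretrained ones.

```python
# Conceptual stand-in for the five steps, using only the standard library.
# NOT Flash's API: toy data, a 1-feature logistic model and a JSON file
# standing in for a checkpoint are illustrative assumptions.
import json
import math
import random
import tempfile
from pathlib import Path


def load_data():
    # Step 1 (stand-in for a DataModule): label is 1 when x > 0, else 0.
    xs = [-2.0, -1.5, -1.0, -0.5, 0.5, 1.0, 1.5, 2.0]
    return [(x, 1 if x > 0 else 0) for x in xs]


def init_model(seed):
    # Step 2 with pretrained=False: weights start from small random values.
    rng = random.Random(seed)
    return {"w": rng.uniform(-0.1, 0.1), "b": rng.uniform(-0.1, 0.1)}


def predict(model, x):
    # Sigmoid of the linear score: probability of class 1.
    return 1.0 / (1.0 + math.exp(-(model["w"] * x + model["b"])))


def fit(model, data, max_epochs, lr=0.5):
    # Steps 3-4: stochastic gradient descent on the log loss,
    # stopping after max_epochs passes over the data.
    for _ in range(max_epochs):
        for x, y in data:
            err = predict(model, x) - y  # d(log loss) / d(score)
            model["w"] -= lr * err * x
            model["b"] -= lr * err
    return model


def save_checkpoint(model, path):
    # Step 5: persist the trained weights.
    Path(path).write_text(json.dumps(model))


data = load_data()
model = fit(init_model(seed=42), data, max_epochs=20)
ckpt = Path(tempfile.gettempdir()) / "toy_model.json"
save_checkpoint(model, ckpt)
accuracy = sum((predict(model, x) > 0.5) == bool(y) for x, y in data) / len(data)
print(f"accuracy={accuracy:.2f}")
```

Fine-tuning would differ only in step 2: the weights would be loaded from a pretrained state instead of being drawn at random.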


Here’s an example:

import torch
from pytorch_lightning import seed_everything

import flash
from flash.core.classification import LabelsOutput
from flash.core.data.utils import download_data
from flash.image import ImageClassificationData, ImageClassifier

# set the random seeds.
seed_everything(42)

# 1. Download and organize the data
download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip", "data/")

datamodule = ImageClassificationData.from_folders(
    train_folder="data/hymenoptera_data/train/",
    val_folder="data/hymenoptera_data/val/",
    test_folder="data/hymenoptera_data/test/",
)

# 2. Build the model using desired Task
model = ImageClassifier(backbone="resnet18", num_classes=datamodule.num_classes, pretrained=False)

# 3. Create the trainer (run one epoch for demo)
trainer = flash.Trainer(max_epochs=1, gpus=torch.cuda.device_count())

# 4. Train the model
trainer.fit(model, datamodule=datamodule)

# 5. Save the model!
trainer.save_checkpoint("image_classification_model.pt")
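ImageClassificationData.from_folders reads each split from a directory tree, and datamodule.num_classes is what feeds num_classes in step 2. The hymenoptera data uses one sub-folder per class (ants/ and bees/). The sketch below is a hypothetical stand-in, not Flash's loader, showing how a class list and class count can be derived from such a layout; the alphabetical class ordering is an assumption borrowed from the common folder-per-class convention, so check your Flash version for the exact rule.

```python
# Hypothetical helper (not Flash's implementation): derive class names from a
# folder-per-class layout such as train/ants/*.jpg and train/bees/*.jpg.
# Assumes classes are ordered alphabetically by sub-folder name.
import tempfile
from pathlib import Path


def infer_classes(folder):
    # Every sub-directory is treated as one class; plain files are ignored.
    return sorted(p.name for p in Path(folder).iterdir() if p.is_dir())


with tempfile.TemporaryDirectory() as root:
    train = Path(root) / "train"
    for cls in ("bees", "ants"):  # created out of order on purpose
        (train / cls).mkdir(parents=True)
        (train / cls / "img_0.jpg").touch()  # empty placeholder "image"
    classes = infer_classes(train)
    num_classes = len(classes)
    class_to_idx = {name: i for i, name in enumerate(classes)}

print(classes, num_classes, class_to_idx)
```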

Training options

Flash tasks support many advanced training features out of the box, such as:

  • Limiting the number of epochs

# train for 10 epochs
flash.Trainer(max_epochs=10)
  • Training on GPUs

# train on 1 GPU
flash.Trainer(gpus=1)
  • Training on multiple GPUs

# train on multiple GPUs
flash.Trainer(gpus=4)
# train on GPUs 1, 3 and 5 (3 GPUs total)
flash.Trainer(gpus=[1, 3, 5])
  • Using mixed precision training

# Multi GPU with mixed precision
flash.Trainer(gpus=2, precision=16)
  • Training on TPUs

# Train on TPUs
flash.Trainer(tpu_cores=8)
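The gpus argument in the examples above accepts either a count or an explicit list of device indices. The helper below is a hypothetical sketch, not Lightning's own parser, of how those values map to devices: None or 0 means CPU (which is also what gpus=torch.cuda.device_count() yields on a machine without GPUs), an integer N selects the first N devices, and a list selects exactly those indices. It also treats -1 as "all available GPUs", following Lightning's convention.

```python
# Hypothetical helper (not Lightning's parser) mapping `gpus` values to
# device indices, given how many GPUs the machine has.
from typing import List, Optional, Sequence, Union


def resolve_gpu_ids(gpus: Optional[Union[int, Sequence[int]]], available: int) -> List[int]:
    if gpus is None or gpus == 0:
        return []  # no GPUs requested: train on CPU
    if isinstance(gpus, int):
        if gpus == -1:
            return list(range(available))  # -1: use every available GPU
        if gpus < 0 or gpus > available:
            raise ValueError(f"cannot use {gpus} GPUs; {available} available")
        return list(range(gpus))  # e.g. gpus=4 -> devices 0, 1, 2, 3
    ids = list(gpus)
    missing = [i for i in ids if not 0 <= i < available]
    if missing:
        raise ValueError(f"GPU indices {missing} are not available")
    return ids  # e.g. gpus=[1, 3, 5] -> exactly those devices


print(resolve_gpu_ids(1, available=8))
print(resolve_gpu_ids(4, available=8))
print(resolve_gpu_ids([1, 3, 5], available=8))
```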

You can pass any argument accepted by the Lightning Trainer to the Flash Trainer. Learn more in the Lightning Trainer documentation.
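Lightning arguments work on the Flash Trainer because flash.Trainer is a subclass of the Lightning Trainer. The sketch below uses hypothetical stand-in classes, not the real implementations, to show the pattern: a subclass that forwards its keyword arguments to the parent accepts everything the parent accepts, while layering extra methods on top.

```python
# Stand-in classes (hypothetical names, not the real library code) showing
# why a subclass that forwards *args/**kwargs accepts all parent arguments.
class LightningTrainerStandIn:
    def __init__(self, max_epochs=None, gpus=None, precision=32, tpu_cores=None):
        self.max_epochs = max_epochs
        self.gpus = gpus
        self.precision = precision
        self.tpu_cores = tpu_cores

    def fit(self, model, datamodule=None):
        return f"fit {model} for {self.max_epochs} epoch(s)"


class FlashTrainerStandIn(LightningTrainerStandIn):
    def __init__(self, *args, **kwargs):
        # Nothing is filtered: every parent argument passes straight through.
        super().__init__(*args, **kwargs)

    def finetune(self, model, datamodule=None):
        # Extra entry point on top of the parent's API. The real finetune()
        # also configures how the backbone is fine-tuned; this stand-in
        # simply delegates to fit().
        return self.fit(model, datamodule=datamodule)


trainer = FlashTrainerStandIn(max_epochs=10, gpus=2, precision=16)
print(trainer.precision, trainer.fit("toy-model"))
```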
