Training from scratch¶
Some Flash tasks have been pretrained on large data sets. To accelerate your training, calling the finetune() method with a pretrained backbone will fine-tune that backbone to produce a model customized to your data set and desired task (see the Quick Start guide). If you want to train the task from scratch instead, do the following.
To train a task from scratch:

1. Load your data and organize it using a DataModule customized for the task (example: ImageClassificationData).
2. Choose and initialize your Task (setting pretrained=False), which has state-of-the-art backbones built in (example: ImageClassifier).
3. Init a flash.core.trainer.Trainer or a pytorch_lightning.trainer.Trainer.
4. Call flash.core.trainer.Trainer.fit() with your data set.
5. Save your trained model.
Here’s an example:
from pytorch_lightning import seed_everything

import torch  # needed below to count available GPUs
import flash
from flash.core.data.utils import download_data
from flash.image import ImageClassificationData, ImageClassifier

# Set the random seeds for reproducibility.
seed_everything(42)
# 1. Download and organize the data
download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip", "data/")
datamodule = ImageClassificationData.from_folders(
train_folder="data/hymenoptera_data/train/",
val_folder="data/hymenoptera_data/val/",
test_folder="data/hymenoptera_data/test/",
batch_size=1,
)
# 2. Build the model using desired Task
model = ImageClassifier(backbone="resnet18", num_classes=datamodule.num_classes, pretrained=False)
# 3. Create the trainer (run one epoch for demo)
trainer = flash.Trainer(max_epochs=1, gpus=torch.cuda.device_count())
# 4. Train the model
trainer.fit(model, datamodule=datamodule)
# 5. Save the model!
trainer.save_checkpoint("image_classification_model.pt")
Training options¶
Flash tasks support many advanced training functionalities out-of-the-box, such as:
Limit the number of epochs
# train for 10 epochs
flash.Trainer(max_epochs=10)
Training on GPUs
# train on 1 GPU
flash.Trainer(gpus=1)
Training on multiple GPUs
# train on multiple GPUs
flash.Trainer(gpus=4)
# train on gpu 1, 3, 5 (3 gpus total)
flash.Trainer(gpus=[1, 3, 5])
Using mixed precision training
# Multi GPU with mixed precision
flash.Trainer(gpus=2, precision=16)
Training on TPUs
# Train on 8 TPU cores
flash.Trainer(accelerator="tpu", devices=8)
You can pass any argument of the Lightning Trainer to the Flash Trainer! Learn more about the Lightning Trainer here.