Finetuning¶
Finetuning (or transfer learning) is the process of adapting a model trained on a large dataset to your particular (likely much smaller) dataset.
Terminology¶
Here are common terms you need to be familiar with:
Term | Definition
---|---
Finetuning | The process of adapting a model trained on a large dataset to your particular (likely much smaller) dataset
Transfer learning | The common name for finetuning
Backbone | The neural network that was pretrained on a different dataset
Head | Another (usually smaller) neural network that maps the backbone's output to your particular dataset
Freeze | Disabling gradient updates to a model (i.e. not learning)
Unfreeze | Enabling gradient updates to a model
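Under the hood, freezing and unfreezing amount to toggling gradient updates on a module's parameters. Here is a minimal plain-PyTorch sketch of what that means (illustrative only, not the Flash API):
import torch.nn as nn
layer = nn.Linear(10, 2)
# freeze: disable gradient updates so the weights stay fixed during training
for param in layer.parameters():
    param.requires_grad = False
# unfreeze: re-enable gradient updates so the weights can learn again
for param in layer.parameters():
    param.requires_grad = True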
Finetuning in Flash¶
From the Quick Start guide.
To use a Task for finetuning:
1. Load your data and organize it using a DataModule customized for the task (example: ImageClassificationData).
2. Choose and initialize your Task, which has state-of-the-art backbones built in (example: ImageClassifier).
3. Init a flash.core.trainer.Trainer.
4. Choose a finetune strategy (example: "freeze") and call flash.core.trainer.Trainer.finetune() with your data.
5. Save your finetuned model.
Here’s an example of finetuning.
import torch
from pytorch_lightning import seed_everything
import flash
from flash.core.data.utils import download_data
from flash.image import ImageClassificationData, ImageClassifier
# set the random seeds.
seed_everything(42)
# 1. Download and organize the data
download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip", "data/")
datamodule = ImageClassificationData.from_folders(
train_folder="data/hymenoptera_data/train/",
val_folder="data/hymenoptera_data/val/",
test_folder="data/hymenoptera_data/test/",
batch_size=1,
)
# 2. Build the model using desired Task
model = ImageClassifier(backbone="resnet18", labels=datamodule.labels)
# 3. Create the trainer (run one epoch for demo)
trainer = flash.Trainer(max_epochs=1, gpus=torch.cuda.device_count())
# 4. Finetune the model
trainer.finetune(model, datamodule=datamodule, strategy="freeze")
# 5. Save the model!
trainer.save_checkpoint("image_classification_model.pt")
Using a finetuned model¶
Once you’ve finetuned, use the model to predict:
predict_datamodule = ImageClassificationData.from_files(
predict_files=[
"data/hymenoptera_data/val/bees/65038344_52a45d090d.jpg",
"data/hymenoptera_data/val/ants/2255445811_dabcdf7258.jpg",
],
batch_size=1,
)
predictions = trainer.predict(model, datamodule=predict_datamodule, output="labels")
print(predictions)
We get the following output:
[['bees', 'ants']]
Or you can use the saved model for prediction anywhere you want!
from flash import Trainer
from flash.image import ImageClassifier, ImageClassificationData
# load finetuned checkpoint
model = ImageClassifier.load_from_checkpoint("image_classification_model.pt")
trainer = Trainer()
datamodule = ImageClassificationData.from_files(predict_files=["path/to/your/own/image.png"])
predictions = trainer.predict(model, datamodule=datamodule)
Finetune strategies¶
Finetuning is very task-specific, and each task encodes the best finetuning practices for that task. On top of this, Flash gives you a few default finetuning strategies.
Finetuning operates on two parts of the model: the backbone and the head. The backbone is the neural network that was pretrained on a different dataset, and the head is another (usually smaller) network that bridges between the backbone and your particular dataset.
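As a rough illustration in plain PyTorch (not the Flash API; the 512-dimensional feature size and the 2-class head are assumptions for this example):
import torch.nn as nn
from torchvision.models import resnet18
# pretrained backbone with its original classification layer removed
backbone = nn.Sequential(*list(resnet18(pretrained=True).children())[:-1], nn.Flatten())
# new head mapping the backbone's 512-dim features to a 2-class dataset
head = nn.Linear(512, 2)
# finetuning trains this combined model on your (smaller) dataset
finetuned_model = nn.Sequential(backbone, head)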
no_freeze¶
In this strategy, the backbone and the head are unfrozen from the beginning.
trainer.finetune(model, datamodule, strategy="no_freeze")
In pseudocode, this looks like:
backbone = Resnet50()
head = nn.Linear(...)
backbone.unfreeze()
head.unfreeze()
train(backbone, head)
freeze¶
The freeze strategy keeps the backbone frozen throughout training; only the head is trained.
trainer.finetune(model, datamodule, strategy="freeze")
The pseudocode looks like:
backbone = Resnet50()
head = nn.Linear(...)
# freeze backbone
backbone.freeze()
head.unfreeze()
train(backbone, head)
Advanced strategies¶
Every finetune strategy can also be customized.
freeze_unfreeze¶
The freeze_unfreeze
strategy keeps the backbone frozen until a certain epoch (provided in a tuple to the strategy
argument) after which the backbone will be unfrozen.
For example, to unfreeze after epoch 7:
trainer.finetune(model, datamodule, strategy=("freeze_unfreeze", 7))
Under the hood, the pseudocode looks like:
backbone = Resnet50()
head = nn.Linear(...)
# freeze backbone for the first 7 epochs
backbone.freeze()
head.unfreeze()
train(backbone, head, epochs=7)
# unfreeze the backbone at epoch 7
backbone.unfreeze()
train(backbone, head)
unfreeze_milestones¶
This strategy allows you to unfreeze parts of the backbone at predetermined epochs (milestones).
Here's an example where:
- the backbone starts frozen
- at epoch 3, the last 2 layers of the backbone unfreeze
- at epoch 8, the full backbone unfreezes
trainer.finetune(model, datamodule, strategy=("unfreeze_milestones", ((3, 8), 2)))
Under the hood, the pseudocode looks like:
backbone = Resnet50()
head = nn.Linear(...)
# freeze the backbone for the first 3 epochs
backbone.freeze()
head.unfreeze()
train(backbone, head, epochs=3)
# unfreeze the last 2 layers of the backbone at epoch 3
backbone.unfreeze_last_layers(2)
# continue training until epoch 8
train(backbone, head, epochs=8)
# unfreeze the full backbone at epoch 8
backbone.unfreeze()
train(backbone, head)
Custom Strategy¶
For even more customization, create your own finetuning callback. Learn more about callbacks in the PyTorch Lightning documentation.
from flash.core.finetuning import FlashBaseFinetuning

# Create a finetuning callback
class FeatureExtractorFreezeUnfreeze(FlashBaseFinetuning):
    def __init__(self, unfreeze_epoch: int = 5, train_bn: bool = True):
        # this will set self.attr_names as ["backbone"]
        super().__init__("backbone", train_bn)
        self._unfreeze_epoch = unfreeze_epoch

    def finetune_function(self, pl_module, current_epoch, optimizer, opt_idx):
        # unfreeze any module you want by overriding this function
        # when ``current_epoch`` reaches ``unfreeze_epoch`` (5 here), the backbone starts training
        if current_epoch == self._unfreeze_epoch:
            self.unfreeze_and_add_param_group(
                pl_module.backbone,
                optimizer,
            )

# Pass the callback to trainer.finetune
trainer.finetune(model, datamodule, strategy=FeatureExtractorFreezeUnfreeze(unfreeze_epoch=5))
Working with DeepSpeed¶
If you are using DeepSpeed, you can use the following strategies. They are used in the same way as the strategies listed above, but note that finetuning with DeepSpeed doesn't yet support loading and storing of its parameters.
- freeze_deepspeed
- no_freeze_deepspeed
- freeze_unfreeze_deepspeed
- unfreeze_milestones_deepspeed
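For example, a minimal sketch of pairing a DeepSpeed-enabled trainer with the matching finetune strategy (the Trainer arguments shown here are assumptions and depend on your setup; freeze_deepspeed mirrors the freeze strategy above):
# assumes at least one GPU and the DeepSpeed plugin are available
trainer = flash.Trainer(max_epochs=1, gpus=1, strategy="deepspeed")
trainer.finetune(model, datamodule=datamodule, strategy="freeze_deepspeed")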