Finetuning¶
Finetuning (or transfer learning) is the process of adapting a model trained on a large dataset to your particular (likely much smaller) dataset.
Terminology¶
Here are common terms you need to be familiar with:
Term | Definition |
---|---|
Finetuning | The process of adapting a model trained on a large dataset to your particular (likely much smaller) dataset |
Transfer learning | Another common name for finetuning |
Backbone | The neural network that was pretrained on a different dataset |
Head | Another neural network (usually smaller) that maps the backbone's output to your particular dataset |
Freeze | Disabling gradient updates to a model (i.e., not learning) |
Unfreeze | Enabling gradient updates to a model |
Finetuning in Flash¶
From the Quick Start guide.
To use a Task for finetuning:
1. Load your data and organize it using a DataModule customized for the task (example: ImageClassificationData).
2. Choose and initialize your Task, which has state-of-the-art backbones built in (example: ImageClassifier).
3. Init a flash.core.trainer.Trainer.
4. Choose a finetune strategy (example: “freeze”) and call flash.core.trainer.Trainer.finetune() with your data.
5. Save your finetuned model.
Here’s an example of finetuning.
import torch
from pytorch_lightning import seed_everything

import flash
from flash.core.classification import LabelsOutput
from flash.core.data.utils import download_data
from flash.image import ImageClassificationData, ImageClassifier

# set the random seeds.
seed_everything(42)

# 1. Download and organize the data
download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip", "data/")

datamodule = ImageClassificationData.from_folders(
    train_folder="data/hymenoptera_data/train/",
    val_folder="data/hymenoptera_data/val/",
    test_folder="data/hymenoptera_data/test/",
    batch_size=1,
)

# 2. Build the model using desired Task
model = ImageClassifier(backbone="resnet18", labels=datamodule.labels)

# 3. Create the trainer (run one epoch for demo)
# torch.cuda.device_count() returns 0 on a CPU-only machine
trainer = flash.Trainer(max_epochs=1, gpus=torch.cuda.device_count())

# 4. Finetune the model
trainer.finetune(model, datamodule=datamodule, strategy="freeze")

# 5. Save the model!
trainer.save_checkpoint("image_classification_model.pt")
Using a finetuned model¶
Once you’ve finetuned, use the model to predict:
predict_datamodule = ImageClassificationData.from_files(
    predict_files=[
        "data/hymenoptera_data/val/bees/65038344_52a45d090d.jpg",
        "data/hymenoptera_data/val/ants/2255445811_dabcdf7258.jpg",
    ],
    batch_size=1,
)
predictions = trainer.predict(model, datamodule=predict_datamodule, output="labels")
print(predictions)
We get the following output:
[['bees', 'ants']]
Or you can use the saved model for prediction anywhere you want!
from flash import Trainer
from flash.image import ImageClassifier, ImageClassificationData
# load finetuned checkpoint
model = ImageClassifier.load_from_checkpoint("image_classification_model.pt")
trainer = Trainer()
datamodule = ImageClassificationData.from_files(predict_files=["path/to/your/own/image.png"])
predictions = trainer.predict(model, datamodule=datamodule)
Finetune strategies¶
Finetuning is very task specific. Each task encodes the best finetuning practices for that task. However, Flash gives you a few default strategies for finetuning.
Finetuning operates on two things: the model backbone and the head. The backbone is the neural network that was pre-trained. The head is another neural network that bridges between the backbone and your particular dataset.
no_freeze¶
In this strategy, the backbone and the head are unfrozen from the beginning.
trainer.finetune(model, datamodule, strategy="no_freeze")
In pseudocode, this looks like:
backbone = Resnet50()
head = nn.Linear(...)
backbone.unfreeze()
head.unfreeze()
train(backbone, head)
freeze¶
The freeze strategy keeps the backbone frozen throughout.
trainer.finetune(model, datamodule, strategy="freeze")
The pseudocode looks like:
backbone = Resnet50()
head = nn.Linear(...)
# freeze backbone
backbone.freeze()
head.unfreeze()
train(backbone, head)
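To make the freeze pseudocode concrete, here is a minimal sketch in plain PyTorch. It is not how Flash implements the strategy internally (Flash also handles details such as batch-norm layers). The tiny `backbone` and `head` modules, their sizes, and the random data are stand-ins chosen only for illustration.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Hypothetical stand-ins for a pretrained backbone and a task-specific head.
backbone = nn.Sequential(nn.Linear(8, 16), nn.ReLU())
head = nn.Linear(16, 2)

# Freeze: disable gradient updates for every backbone parameter.
for param in backbone.parameters():
    param.requires_grad = False

# Hand only the trainable parameters (here, the head) to the optimizer.
all_params = list(backbone.parameters()) + list(head.parameters())
trainable = [p for p in all_params if p.requires_grad]
optimizer = torch.optim.SGD(trainable, lr=0.1)

# Snapshot the weights so we can compare them after one step.
backbone_before = [p.detach().clone() for p in backbone.parameters()]
head_before = [p.detach().clone() for p in head.parameters()]

# One training step on random data.
inputs = torch.randn(4, 8)
targets = torch.tensor([0, 1, 0, 1])
loss = nn.functional.cross_entropy(head(backbone(inputs)), targets)
optimizer.zero_grad()
loss.backward()
optimizer.step()

backbone_unchanged = all(
    torch.equal(before, after)
    for before, after in zip(backbone_before, backbone.parameters())
)
head_changed = any(
    not torch.equal(before, after)
    for before, after in zip(head_before, head.parameters())
)
print("backbone unchanged:", backbone_unchanged)
print("head changed:", head_changed)
```

Unfreezing is the reverse: set `requires_grad = True` on the backbone parameters and make sure the optimizer knows about them.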
Advanced strategies¶
Every finetune strategy can also be customized.
freeze_unfreeze¶
The freeze_unfreeze strategy keeps the backbone frozen until a certain epoch (provided in a tuple to the strategy argument), after which the backbone will be unfrozen.
For example, to unfreeze after epoch 7:
trainer.finetune(model, datamodule, strategy=("freeze_unfreeze", 7))
Under the hood, the pseudocode looks like:
backbone = Resnet50()
head = nn.Linear(...)
# freeze backbone
backbone.freeze()
head.unfreeze()
train(backbone, head, epochs=7)
# unfreeze after 7 epochs
backbone.unfreeze()
train(backbone, head)
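The schedule behind this strategy can be sketched as a small pure-Python function. This is only an illustration of the timing, not Flash's implementation. The helper name `backbone_trainable` is hypothetical, and the sketch assumes epochs are zero-indexed and that unfreezing takes effect at the start of the given epoch.

```python
def backbone_trainable(current_epoch: int, unfreeze_epoch: int) -> bool:
    """Return True once the backbone should receive gradient updates.

    Assumes zero-indexed epochs; the backbone unfreezes at ``unfreeze_epoch``
    and stays unfrozen afterwards.
    """
    return current_epoch >= unfreeze_epoch


# Schedule for strategy=("freeze_unfreeze", 7) over 10 epochs.
schedule = {epoch: backbone_trainable(epoch, unfreeze_epoch=7) for epoch in range(10)}
for epoch, trainable in schedule.items():
    state = "unfrozen" if trainable else "frozen"
    print(f"epoch {epoch}: backbone {state}")
```

The head is trainable in every epoch; only the backbone's state changes.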
unfreeze_milestones¶
This strategy allows you to unfreeze part of the backbone at predetermined intervals.
Here’s an example where:
- the backbone starts frozen
- at epoch 3 the last 2 layers unfreeze
- at epoch 8 the full backbone unfreezes
trainer.finetune(model, datamodule, strategy=("unfreeze_milestones", ((3, 8), 2)))
Under the hood, the pseudocode looks like:
backbone = Resnet50()
head = nn.Linear(...)
# freeze backbone
backbone.freeze()
head.unfreeze()
# train until epoch 3 with only the head learning
train(backbone, head, epochs=3)
# unfreeze last 2 layers at epoch 3
backbone.unfreeze_last_layers(2)
# keep training until epoch 8
train(backbone, head, epochs=8)
# unfreeze the full backbone at epoch 8
backbone.unfreeze()
train(backbone, head)
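As a rough illustration of the milestone schedule, the sketch below computes which backbone layers would be trainable at a given epoch. It is not Flash's implementation. The function name `trainable_backbone_layers` and the four layer names are hypothetical, and epochs are assumed to be zero-indexed.

```python
from typing import List, Sequence, Tuple


def trainable_backbone_layers(
    epoch: int,
    layers: Sequence[str],
    milestones: Tuple[int, int],
    num_layers: int,
) -> List[str]:
    """Return the backbone layers that receive gradient updates at ``epoch``."""
    first_milestone, second_milestone = milestones
    if epoch >= second_milestone:
        # Second milestone reached: the full backbone is unfrozen.
        return list(layers)
    if epoch >= first_milestone:
        # First milestone reached: only the last ``num_layers`` layers are unfrozen.
        return list(layers[len(layers) - num_layers:])
    # Before the first milestone the whole backbone stays frozen.
    return []


# Hypothetical layer names for a four-block backbone.
layers = ["layer1", "layer2", "layer3", "layer4"]

# Mirrors strategy=("unfreeze_milestones", ((3, 8), 2)).
for epoch in (0, 3, 7, 8):
    print(epoch, trainable_backbone_layers(epoch, layers, milestones=(3, 8), num_layers=2))
```

Slicing with `len(layers) - num_layers` rather than `-num_layers` keeps the result empty when `num_layers` is 0.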
Custom Strategy¶
For even more customization, create your own finetuning callback. Learn more about callbacks here.
from flash.core.finetuning import FlashBaseFinetuning


# Create a finetuning callback
class FeatureExtractorFreezeUnfreeze(FlashBaseFinetuning):
    def __init__(self, unfreeze_epoch: int = 5, train_bn: bool = True):
        # this will set self.attr_names as ["backbone"]
        super().__init__("backbone", train_bn)
        self._unfreeze_epoch = unfreeze_epoch

    def finetune_function(self, pl_module, current_epoch, optimizer, opt_idx):
        # unfreeze any module you want by overriding this function

        # When ``current_epoch`` is 5, backbone will start to be trained.
        if current_epoch == self._unfreeze_epoch:
            self.unfreeze_and_add_param_group(
                pl_module.backbone,
                optimizer,
            )


# Pass the callback to trainer.finetune
trainer.finetune(model, datamodule, strategy=FeatureExtractorFreezeUnfreeze(unfreeze_epoch=5))