
Finetuning

Finetuning (or transfer learning) is the process of adapting a model trained on a large dataset to your particular (likely much smaller) dataset.


Terminology

Here are common terms you need to be familiar with:

Finetuning: The process of adapting a model trained on a large dataset to your particular (likely much smaller) dataset.

Transfer learning: Another common name for finetuning.

Backbone: The neural network that was pretrained on a different dataset.

Head: Another, usually smaller, neural network that maps the backbone's output to your particular dataset.

Freeze: Disabling gradient updates to a model (i.e., the model stops learning).

Unfreeze: Enabling gradient updates to a model.
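
In plain PyTorch terms, freezing and unfreezing just toggle requires_grad on a module's parameters. Here is a minimal sketch (the helper names are ours for illustration, not part of Flash's API):

import torch.nn as nn

def freeze(module: nn.Module) -> None:
    # disable gradient updates: the parameters keep their current values
    for param in module.parameters():
        param.requires_grad = False

def unfreeze(module: nn.Module) -> None:
    # re-enable gradient updates
    for param in module.parameters():
        param.requires_grad = True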


Finetuning in Flash

The following steps are from the Quick Start guide.

To use a Task for finetuning:

  1. Load your data and organize it using a DataModule customized for the task (example: ImageClassificationData).

  2. Choose and initialize your Task which has state-of-the-art backbones built in (example: ImageClassifier).

  3. Initialize a flash.core.trainer.Trainer.

  4. Choose a finetune strategy (example: “freeze”) and call flash.core.trainer.Trainer.finetune() with your data.

  5. Save your finetuned model.


Here’s an example of finetuning.

import torch
from pytorch_lightning import seed_everything

import flash
from flash.core.data.utils import download_data
from flash.image import ImageClassificationData, ImageClassifier

# set the random seeds.
seed_everything(42)

# 1. Download and organize the data
download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip", "data/")

datamodule = ImageClassificationData.from_folders(
    train_folder="data/hymenoptera_data/train/",
    val_folder="data/hymenoptera_data/val/",
    test_folder="data/hymenoptera_data/test/",
    batch_size=1,
)

# 2. Build the model using desired Task
model = ImageClassifier(backbone="resnet18", labels=datamodule.labels)

# 3. Create the trainer (run one epoch for demo)
trainer = flash.Trainer(max_epochs=1, gpus=torch.cuda.device_count())

# 4. Finetune the model
trainer.finetune(model, datamodule=datamodule, strategy="freeze")

# 5. Save the model!
trainer.save_checkpoint("image_classification_model.pt")

Using a finetuned model

Once you’ve finetuned, use the model to predict:

predict_datamodule = ImageClassificationData.from_files(
    predict_files=[
        "data/hymenoptera_data/val/bees/65038344_52a45d090d.jpg",
        "data/hymenoptera_data/val/ants/2255445811_dabcdf7258.jpg",
    ],
    batch_size=1,
)
predictions = trainer.predict(model, datamodule=predict_datamodule, output="labels")
print(predictions)

We get the following output:

[['bees', 'ants']]

Or you can use the saved model for prediction anywhere you want!

from flash import Trainer
from flash.image import ImageClassifier, ImageClassificationData

# load finetuned checkpoint
model = ImageClassifier.load_from_checkpoint("image_classification_model.pt")

trainer = Trainer()
datamodule = ImageClassificationData.from_files(
    predict_files=["path/to/your/own/image.png"],
    batch_size=1,
)
predictions = trainer.predict(model, datamodule=datamodule)
print(predictions)

Finetune strategies

Finetuning is very task specific. Each task encodes the best finetuning practices for that task. However, Flash gives you a few default strategies for finetuning.

Finetuning operates on two components: the backbone and the head. The backbone is the neural network that was pretrained. The head is another neural network that bridges between the backbone and your particular dataset.
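
Put together, the full model is simply the head stacked on top of the backbone. As a rough sketch with a torchvision backbone (the 512-dimensional feature size is specific to resnet18, and the 2-class head is a hypothetical dataset assumption):

import torch.nn as nn
from torchvision.models import resnet18

# backbone: pretrained on ImageNet, with the original classification layer removed
backbone = nn.Sequential(*list(resnet18(pretrained=True).children())[:-1], nn.Flatten())

# head: maps the backbone's 512-dim features to the 2 classes of our dataset
head = nn.Linear(512, 2)

model = nn.Sequential(backbone, head)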

no_freeze

In this strategy, the backbone and the head are unfrozen from the beginning.

trainer.finetune(model, datamodule, strategy="no_freeze")

In pseudocode, this looks like:

backbone = Resnet50()
head = nn.Linear(...)

backbone.unfreeze()
head.unfreeze()

train(backbone, head)

freeze

The freeze strategy keeps the backbone frozen throughout.

trainer.finetune(model, datamodule, strategy="freeze")

The pseudocode looks like:

backbone = Resnet50()
head = nn.Linear(...)

# freeze backbone
backbone.freeze()
head.unfreeze()

train(backbone, head)
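
Flash's built-in strategies are implemented on top of PyTorch Lightning's BaseFinetuning callback, which exposes the same freeze/unfreeze vocabulary. A minimal sketch of the equivalent calls (assuming the backbone from the pseudocode above is a real nn.Module):

from pytorch_lightning.callbacks import BaseFinetuning

# freeze the backbone; train_bn=True keeps BatchNorm layers trainable
BaseFinetuning.freeze(backbone, train_bn=True)

# later, make the backbone trainable again
BaseFinetuning.make_trainable(backbone)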

Advanced strategies

Every finetune strategy can also be customized.

freeze_unfreeze

The freeze_unfreeze strategy keeps the backbone frozen until a certain epoch (provided in a tuple to the strategy argument) after which the backbone will be unfrozen.

For example, to unfreeze after epoch 7:

trainer.finetune(model, datamodule, strategy=("freeze_unfreeze", 7))

Under the hood, the pseudocode looks like:

backbone = Resnet50()
head = nn.Linear(...)

# freeze backbone
backbone.freeze()
head.unfreeze()

train(backbone, head, epochs=7)

# unfreeze the backbone at epoch 7
backbone.unfreeze()

train(backbone, head)

unfreeze_milestones

This strategy allows you to unfreeze part of the backbone at predetermined intervals.

Here’s an example where:

  • backbone starts frozen

  • at epoch 3 the last 2 layers unfreeze

  • at epoch 8 the full backbone unfreezes


trainer.finetune(model, datamodule, strategy=("unfreeze_milestones", ((3, 8), 2)))

Under the hood, the pseudocode looks like:

backbone = Resnet50()
head = nn.Linear(...)

# freeze backbone
backbone.freeze()
head.unfreeze()

train(backbone, head, epochs=3)

# at epoch 3, unfreeze the last 2 layers of the backbone
backbone.unfreeze_last_layers(2)

train(backbone, head, epochs=5)  # train until epoch 8

# at epoch 8, unfreeze the full backbone
backbone.unfreeze()

train(backbone, head)
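
unfreeze_last_layers above is pseudocode; Flash decides internally which parameters belong to the "last" layers. With a plain torchvision backbone, an equivalent helper might look like this (a sketch for illustration, not Flash's actual implementation):

import torch.nn as nn

def unfreeze_last_layers(backbone: nn.Module, num_layers: int) -> None:
    # re-enable gradients only for the last `num_layers` top-level children
    for child in list(backbone.children())[-num_layers:]:
        for param in child.parameters():
            param.requires_grad = True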

Custom Strategy

For even more customization, create your own finetuning callback. Learn more about callbacks in the PyTorch Lightning documentation.

from flash.core.finetuning import FlashBaseFinetuning

# Create a finetuning callback
class FeatureExtractorFreezeUnfreeze(FlashBaseFinetuning):
    def __init__(self, unfreeze_epoch: int = 5, train_bn: bool = True):
        # this will set self.attr_names as ["backbone"]
        super().__init__("backbone", train_bn)
        self._unfreeze_epoch = unfreeze_epoch

    def finetune_function(self, pl_module, current_epoch, optimizer, opt_idx):
        # unfreeze any module you want by overriding this function

        # When ``current_epoch`` is 5, backbone will start to be trained.
        if current_epoch == self._unfreeze_epoch:
            self.unfreeze_and_add_param_group(
                pl_module.backbone,
                optimizer,
            )


# Pass the callback to trainer.finetune
trainer.finetune(model, datamodule, strategy=FeatureExtractorFreezeUnfreeze(unfreeze_epoch=5))