
Source code for flash.core.finetuning

# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from functools import partial
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union

from pytorch_lightning import LightningModule
from pytorch_lightning.callbacks import BaseFinetuning
from torch.nn import Module
from torch.optim import Optimizer

from flash.core.registry import FlashRegistry

if not os.environ.get("READTHEDOCS", False):
    from pytorch_lightning.utilities.enums import LightningEnum
else:
    # ReadTheDocs mocks the `LightningEnum` import to be a regular type, so we replace it with a plain Enum here.
    from enum import Enum

    LightningEnum = Enum


class FinetuningStrategies(LightningEnum):
    """The ``FinetuningStrategies`` enum contains the keys that are used internally by the ``FlashBaseFinetuning`` when
    choosing the strategy to perform."""

    NO_FREEZE = "no_freeze"
    FREEZE = "freeze"
    FREEZE_UNFREEZE = "freeze_unfreeze"
    UNFREEZE_MILESTONES = "unfreeze_milestones"

    # TODO: Create a FlashEnum class???
    def __hash__(self) -> int:
        return hash(self.value)
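
Since the enum values double as the user-facing strategy keys, members can be recovered by value, and the ``__hash__`` override makes a member hash like its plain string key. A small illustration (not part of the module itself):

    # Sketch: enum values are the same strings users pass as the `strategy` argument.
    assert FinetuningStrategies("freeze") is FinetuningStrategies.FREEZE
    assert FinetuningStrategies.FREEZE.value == "freeze"

    # The __hash__ override above makes a member hash identically to its string value,
    # which (together with LightningEnum's string-based equality) lets strings and
    # members be mixed in hash-based lookups such as registries.
    assert hash(FinetuningStrategies.FREEZE) == hash("freeze")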


class FlashBaseFinetuning(BaseFinetuning):
    """FlashBaseFinetuning can be used to create a custom Flash Finetuning Callback."""

    def __init__(
        self,
        strategy_key: Union[str, FinetuningStrategies],
        strategy_metadata: Optional[Union[int, Tuple[Tuple[int, int], int]]] = None,
        train_bn: bool = True,
    ):
        """
        Args:
            strategy_key: The finetuning strategy to be used. See :meth:`~flash.core.trainer.Trainer.finetune`
                for the available strategies.
            strategy_metadata: Data that accompanies certain finetuning strategies, such as the epoch number or
                the number of layers.
            train_bn: Whether to train Batch Norm layers.
        """
        super().__init__()

        self.strategy: FinetuningStrategies = strategy_key
        self.strategy_metadata: Optional[Union[int, Tuple[Tuple[int, int], int]]] = strategy_metadata
        self.train_bn: bool = train_bn

        if self.strategy == FinetuningStrategies.FREEZE_UNFREEZE and not isinstance(self.strategy_metadata, int):
            raise TypeError(
                "The `freeze_unfreeze` strategy requires an integer denoting the epoch number to unfreeze at. "
                "Example: `strategy=('freeze_unfreeze', 7)`"
            )
        if self.strategy == FinetuningStrategies.UNFREEZE_MILESTONES and not (
            isinstance(self.strategy_metadata, Tuple)
            and isinstance(self.strategy_metadata[0], Tuple)
            and isinstance(self.strategy_metadata[1], int)
            and isinstance(self.strategy_metadata[0][0], int)
            and isinstance(self.strategy_metadata[0][1], int)
        ):
            raise TypeError(
                "The `unfreeze_milestones` strategy requires the format Tuple[Tuple[int, int], int]. Example: "
                "`strategy=('unfreeze_milestones', ((5, 10), 15))`"
            )

    def _get_modules_to_freeze(self, pl_module: LightningModule) -> Union[Module, Iterable[Union[Module, Iterable]]]:
        modules_to_freeze = getattr(pl_module, "modules_to_freeze", None)
        if modules_to_freeze is None:
            raise AttributeError(
                "LightningModule missing instance method 'modules_to_freeze'. "
                "Please implement the method, which should return None, a Module, or an Iterable of Modules."
            )
        return modules_to_freeze()

    def freeze_before_training(self, pl_module: Union[Module, Iterable[Union[Module, Iterable]]]) -> None:
        if self.strategy != FinetuningStrategies.NO_FREEZE:
            modules = self._get_modules_to_freeze(pl_module=pl_module)
            if modules is not None:
                if isinstance(modules, Module):
                    modules = [modules]
                self.freeze(modules=modules, train_bn=self.train_bn)

    def unfreeze_and_extend_param_group(
        self,
        modules: Union[Module, Iterable[Union[Module, Iterable]]],
        optimizer: Optimizer,
        train_bn: bool = True,
    ) -> None:
        self.make_trainable(modules)

        params = self.filter_params(modules, train_bn=train_bn, requires_grad=True)
        params = self.filter_on_optimizer(optimizer, params)
        if params:
            optimizer.param_groups[0]["params"].extend(params)

    def _freeze_unfreeze_function(
        self,
        pl_module: Union[Module, Iterable[Union[Module, Iterable]]],
        epoch: int,
        optimizer: Optimizer,
        opt_idx: int,
        strategy_metadata: int,
    ):
        unfreeze_epoch: int = strategy_metadata
        if epoch != unfreeze_epoch:
            return

        modules = self._get_modules_to_freeze(pl_module=pl_module)
        if modules is not None:
            self.unfreeze_and_extend_param_group(
                modules=modules,
                optimizer=optimizer,
                train_bn=self.train_bn,
            )

    def _unfreeze_milestones_function(
        self,
        pl_module: Union[Module, Iterable[Union[Module, Iterable]]],
        epoch: int,
        optimizer: Optimizer,
        opt_idx: int,
        strategy_metadata: Tuple[Tuple[int, int], int],
    ):
        unfreeze_milestones: Tuple[int, int] = strategy_metadata[0]
        num_layers: int = strategy_metadata[1]

        modules = self._get_modules_to_freeze(pl_module=pl_module)
        if modules is not None:
            if epoch == unfreeze_milestones[0]:
                # unfreeze num_layers last layers
                backbone_modules = BaseFinetuning.flatten_modules(modules=modules)[-num_layers:]
                self.unfreeze_and_extend_param_group(
                    modules=backbone_modules,
                    optimizer=optimizer,
                    train_bn=self.train_bn,
                )
            elif epoch == unfreeze_milestones[1]:
                # unfreeze remaining layers
                backbone_modules = BaseFinetuning.flatten_modules(modules=modules)[:-num_layers]
                self.unfreeze_and_extend_param_group(
                    modules=backbone_modules,
                    optimizer=optimizer,
                    train_bn=self.train_bn,
                )

    def finetune_function(
        self,
        pl_module: Union[Module, Iterable[Union[Module, Iterable]]],
        epoch: int,
        optimizer: Optimizer,
        opt_idx: int,
    ):
        if self.strategy == FinetuningStrategies.FREEZE_UNFREEZE:
            self._freeze_unfreeze_function(pl_module, epoch, optimizer, opt_idx, self.strategy_metadata)
        elif self.strategy == FinetuningStrategies.UNFREEZE_MILESTONES:
            self._unfreeze_milestones_function(pl_module, epoch, optimizer, opt_idx, self.strategy_metadata)
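
Concretely, the callback above expects the LightningModule to expose a ``modules_to_freeze`` method. Below is a minimal sketch of that contract; the model, layer sizes, and training loop are hypothetical and not part of this module:

    import torch
    from torch import nn
    from pytorch_lightning import LightningModule, Trainer


    class TinyModel(LightningModule):
        def __init__(self):
            super().__init__()
            self.backbone = nn.Sequential(nn.Linear(32, 64), nn.ReLU(), nn.Linear(64, 64))
            self.head = nn.Linear(64, 2)

        def modules_to_freeze(self) -> nn.Module:
            # Modules returned here are frozen by `freeze_before_training` and
            # progressively unfrozen by the `freeze_unfreeze` / `unfreeze_milestones` strategies.
            return self.backbone

        def training_step(self, batch, batch_idx):
            x, y = batch
            return nn.functional.cross_entropy(self.head(self.backbone(x)), y)

        def configure_optimizers(self):
            return torch.optim.Adam(self.parameters(), lr=1e-3)


    # Freeze the backbone at the start, unfreeze it at epoch 2, and extend the
    # optimizer's first param group with its parameters.
    callback = FlashBaseFinetuning(FinetuningStrategies.FREEZE_UNFREEZE, strategy_metadata=2)
    trainer = Trainer(max_epochs=5, callbacks=[callback])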
_FINETUNING_STRATEGIES_REGISTRY = FlashRegistry("finetuning_strategies")

for strategy in FinetuningStrategies:
    _FINETUNING_STRATEGIES_REGISTRY(
        name=strategy.value,
        fn=partial(FlashBaseFinetuning, strategy_key=strategy),
    )
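
For illustration, a strategy callback could be resolved from the registry by its key. This sketch assumes ``FlashRegistry.get`` returns the registered ``fn`` (the partial above), with the remaining arguments supplied at call time:

    # Sketch only: resolve a strategy from the registry and instantiate it.
    freeze_unfreeze_cls = _FINETUNING_STRATEGIES_REGISTRY.get("freeze_unfreeze")
    callback = freeze_unfreeze_cls(strategy_metadata=7, train_bn=True)
    assert isinstance(callback, FlashBaseFinetuning)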
class NoFreeze(FlashBaseFinetuning):
    def __init__(self, train_bn: bool = True):
        # Pass train_bn by keyword so it is not consumed as strategy_metadata.
        super().__init__(FinetuningStrategies.NO_FREEZE, train_bn=train_bn)

class Freeze(FlashBaseFinetuning):
    def __init__(self, train_bn: bool = True):
        # Pass train_bn by keyword so it is not consumed as strategy_metadata.
        super().__init__(FinetuningStrategies.FREEZE, train_bn=train_bn)

class FreezeUnfreeze(FlashBaseFinetuning):
    def __init__(
        self,
        strategy_metadata: int,
        train_bn: bool = True,
    ):
        super().__init__(FinetuningStrategies.FREEZE_UNFREEZE, strategy_metadata, train_bn)

class UnfreezeMilestones(FlashBaseFinetuning):
    def __init__(
        self,
        strategy_metadata: Tuple[Tuple[int, int], int],
        train_bn: bool = True,
    ):
        super().__init__(FinetuningStrategies.UNFREEZE_MILESTONES, strategy_metadata, train_bn)
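
In practice these strategies are selected through :meth:`~flash.core.trainer.Trainer.finetune`. The string and tuple forms below mirror the examples given in the ``TypeError`` messages above; ``model`` and ``datamodule`` are placeholders for any Flash task and DataModule:

    from flash import Trainer

    trainer = Trainer(max_epochs=10)

    # Equivalent ways of selecting a strategy; string/tuple keys resolve through the registry.
    trainer.finetune(model, datamodule=datamodule, strategy="freeze")
    trainer.finetune(model, datamodule=datamodule, strategy=("freeze_unfreeze", 5))
    trainer.finetune(model, datamodule=datamodule, strategy=("unfreeze_milestones", ((5, 10), 15)))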
class FlashDeepSpeedFinetuning(FlashBaseFinetuning):
    """FlashDeepSpeedFinetuning can be used to create a custom Flash Finetuning Callback which works with DeepSpeed.

    DeepSpeed cannot store and load its parameters when working with Lightning, so FlashDeepSpeedFinetuning
    overrides `_store` to not store its parameters.
    """

    def _store(
        self,
        pl_module: LightningModule,
        opt_idx: int,
        num_param_groups: int,
        current_param_groups: List[Dict[str, Any]],
    ) -> None:
        pass


class NoFreezeDeepSpeed(FlashDeepSpeedFinetuning):
    def __init__(self, train_bn: bool = True):
        # Pass train_bn by keyword so it is not consumed as strategy_metadata.
        super().__init__(FinetuningStrategies.NO_FREEZE, train_bn=train_bn)


class FreezeDeepSpeed(FlashDeepSpeedFinetuning):
    def __init__(self, train_bn: bool = True):
        # Pass train_bn by keyword so it is not consumed as strategy_metadata.
        super().__init__(FinetuningStrategies.FREEZE, train_bn=train_bn)


class FreezeUnfreezeDeepSpeed(FlashDeepSpeedFinetuning):
    def __init__(
        self,
        strategy_metadata: int,
        train_bn: bool = True,
    ):
        super().__init__(FinetuningStrategies.FREEZE_UNFREEZE, strategy_metadata, train_bn)


class UnfreezeMilestonesDeepSpeed(FlashDeepSpeedFinetuning):
    def __init__(
        self,
        strategy_metadata: Tuple[Tuple[int, int], int],
        train_bn: bool = True,
    ):
        super().__init__(FinetuningStrategies.UNFREEZE_MILESTONES, strategy_metadata, train_bn)


for strategy in FinetuningStrategies:
    _FINETUNING_STRATEGIES_REGISTRY(
        name=f"{strategy.value}_deepspeed",
        fn=partial(FlashDeepSpeedFinetuning, strategy_key=strategy),
    )
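
As a rough sketch of how the DeepSpeed variants might be wired up (the DeepSpeed integration itself lives outside this module, and the ``strategy="deepspeed"`` Trainer argument is an assumption based on recent PyTorch Lightning releases):

    from pytorch_lightning import Trainer

    # Unfreeze at epoch 3; the overridden _store keeps the callback from saving its
    # parameter-group state, which DeepSpeed cannot restore.
    callback = FreezeUnfreezeDeepSpeed(strategy_metadata=3)
    trainer = Trainer(max_epochs=10, strategy="deepspeed", callbacks=[callback])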
