Source code for flash.core.trainer
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import warnings
from argparse import ArgumentParser, Namespace
from functools import wraps
from typing import Callable, List, Optional, Tuple, Union
import torch
from pytorch_lightning import LightningDataModule, LightningModule
from pytorch_lightning import Trainer as PlTrainer
from pytorch_lightning.accelerators.tpu import TPUAccelerator
from pytorch_lightning.callbacks import BaseFinetuning
from pytorch_lightning.utilities.argparse import add_argparse_args, get_init_arguments_and_types, parse_env_variables
from torch.utils.data import DataLoader
import flash
from flash.core.data.io.output import Output
from flash.core.data.io.output_transform import OutputTransform
from flash.core.data.io.transform_predictions import TransformPredictions
from flash.core.model import Task
from flash.core.registry import FlashRegistry
def from_argparse_args(cls, args: Union[Namespace, ArgumentParser], **kwargs):
"""Modified version of :func:`pytorch_lightning.utilities.argparse.from_argparse_args` which populates
``valid_kwargs`` from :class:`pytorch_lightning.Trainer`."""
if isinstance(args, ArgumentParser):
args = cls.parse_argparser(args)
params = vars(args)
# we only want to pass in valid PLTrainer args, the rest may be user specific
valid_kwargs = inspect.signature(PlTrainer.__init__).parameters
trainer_kwargs = {name: params[name] for name in valid_kwargs if name in params}
trainer_kwargs.update(**kwargs)
return cls(**trainer_kwargs)
def _defaults_from_env_vars(fn: Callable) -> Callable:
"""Copy of ``pytorch_lightning.trainer.connectors.env_vars_connector._defaults_from_env_vars``.
Required to fix a build error on readthedocs.
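
Example (illustrative; assumes the ``PL_TRAINER_<ARG>`` environment variables read by :func:`pytorch_lightning.utilities.argparse.parse_env_variables`)::

import os

os.environ["PL_TRAINER_MAX_EPOCHS"] = "3"
trainer = Trainer()  # picks up max_epochs=3 unless it is passed explicitly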
"""
@wraps(fn)
def insert_env_defaults(self, *args, **kwargs):
cls = self.__class__ # get the class
if args:  # in case any args were passed, move them to kwargs
# parse only the argument names
cls_arg_names = [arg[0] for arg in get_init_arguments_and_types(cls)]
# convert args to kwargs
kwargs.update(dict(zip(cls_arg_names, args)))
env_variables = vars(parse_env_variables(cls))
# update the kwargs by env variables
kwargs = dict(list(env_variables.items()) + list(kwargs.items()))
# all args were already moved to kwargs
return fn(self, **kwargs)
return insert_env_defaults
class Trainer(PlTrainer):
"""Extended Trainer for Flash tasks.
>>> Trainer() # doctest: +ELLIPSIS
<...trainer.Trainer object at ...>
"""
@_defaults_from_env_vars
def __init__(self, *args, **kwargs):
if flash._IS_TESTING:
if torch.cuda.is_available():
kwargs["gpus"] = -1
kwargs["limit_train_batches"] = 1.0
kwargs["limit_val_batches"] = 1.0
kwargs["limit_test_batches"] = 1.0
kwargs["fast_dev_run"] = False
else:
kwargs["fast_dev_run"] = True
kwargs["gpus"] = None
kwargs["accelerator"] = None
kwargs["precision"] = 32
super().__init__(*args, **kwargs)
def fit(
self,
model: LightningModule,
train_dataloader: Optional[DataLoader] = None,
val_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None,
datamodule: Optional[LightningDataModule] = None,
):
r"""Runs the full optimization routine. Same as :meth:`pytorch_lightning.Trainer.fit`
Args:
datamodule: A instance of :class:`LightningDataModule`.
model: Model to fit.
train_dataloader: A Pytorch DataLoader with training samples. If the model has
a predefined train_dataloader method this will be skipped.
val_dataloaders: Either a single Pytorch Dataloader or a list of them, specifying validation samples.
If the model has a predefined val_dataloaders method this will be skipped
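
Example (a minimal sketch; ``model`` stands in for any Flash task and ``datamodule`` for a matching data module)::

trainer = Trainer(max_epochs=3)
trainer.fit(model, datamodule=datamodule)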
"""
if any(isinstance(c, BaseFinetuning) for c in self.callbacks):
# TODO: if we find a finetuning callback in the trainer should we remove it? or just warn the user?
warnings.warn("Warning: You are calling fit(), but your trainer is using a fine-tuning callback")
return super().fit(model, train_dataloader, val_dataloaders, datamodule)
def finetune(
self,
model: LightningModule,
train_dataloader: Optional[DataLoader] = None,
val_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None,
datamodule: Optional[LightningDataModule] = None,
strategy: Union[str, BaseFinetuning, Tuple[str, int], Tuple[str, Tuple[int, int]]] = "no_freeze",
train_bn: bool = True,
):
r"""Runs the full optimization routine. Same as :meth:`pytorch_lightning.Trainer.fit`, but unfreezes layers of
the backbone throughout training layers of the backbone throughout training.
Args:
model: Model to fit.
train_dataloader: A PyTorch DataLoader with training samples. If the model has
a predefined train_dataloader method, this will be skipped.
val_dataloaders: Either a single PyTorch DataLoader or a list of them, specifying validation samples.
If the model has a predefined val_dataloaders method, this will be skipped.
datamodule: An instance of :class:`LightningDataModule`.
strategy: Should be a string, a tuple, or a finetuning callback subclassing
:class:`pytorch_lightning.callbacks.BaseFinetuning`.
Default strategies can be enabled with these inputs:
- ``"no_freeze"``
- ``"freeze"``
- ``("freeze_unfreeze", integer: unfreeze_epoch)``
- ``("unfreeze_milestones", ((integer: unfreeze_epoch_num_layers, integer: unfreeze_epoch_all_layers),
integer: num_layers))``
where ``integer`` can be any integer.
By default, ``no_freeze`` strategy will be used.
train_bn: Whether to train Batch Norm layers.
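
Example (a minimal sketch; ``model`` stands in for any Flash task with a backbone and ``datamodule`` for a matching data module)::

trainer = Trainer(max_epochs=3)
trainer.finetune(model, datamodule=datamodule, strategy=("freeze_unfreeze", 1))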
"""
self._resolve_callbacks(model, strategy, train_bn=train_bn)
return super().fit(model, train_dataloader, val_dataloaders, datamodule)
def predict(
self,
model: Optional[LightningModule] = None,
dataloaders: Optional[Union[DataLoader, LightningDataModule]] = None,
output: Optional[Union[Output, str]] = None,
**kwargs,
):
r"""Run inference on your data.
This will call the model forward function to compute predictions.
Useful to perform distributed and batched predictions. Logging is disabled in the prediction hooks.
Args:
model: The model to predict with.
dataloaders: A :class:`torch.utils.data.DataLoader` or a sequence of them,
or a :class:`~pytorch_lightning.core.datamodule.LightningDataModule` specifying prediction samples.
output: The :class:`~flash.core.data.io.output.Output` to use to transform predict outputs.
kwargs: Additional keyword arguments to pass to :meth:`~pytorch_lightning.Trainer.predict`.
Returns:
A list of dictionaries, one for each provided dataloader, containing their respective predictions.
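
Example (illustrative; assumes a classification-style task whose ``outputs`` registry includes a ``"labels"`` output)::

predictions = trainer.predict(model, datamodule=datamodule, output="labels")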
"""
# Note: Prediction on TPU device with multi cores is not supported yet
if isinstance(self.accelerator, TPUAccelerator) and self.num_devices > 1:
raise NotImplementedError(
f"Prediction on TPU device with multi-cores (requested cores: {self.num_devices}) is not supported yet."
)
model = model or self.lightning_module
output_transform = getattr(model, "_output_transform", None) or OutputTransform()
if output is None:
output = Output()
if isinstance(output, str) and isinstance(model, Task):
output = getattr(model, "outputs", FlashRegistry("outputs")).get(output).from_task(model)
old_callbacks = self.callbacks
self.callbacks = self._merge_callbacks(self.callbacks, [TransformPredictions(output_transform, output)])
result = super().predict(model, dataloaders, **kwargs)
self.callbacks = old_callbacks
return result
def _resolve_callbacks(
self,
model: Task,
strategy: Union[str, BaseFinetuning, Tuple[str, int], Tuple[str, Tuple[int, int]]] = "no_freeze",
train_bn: bool = True,
):
"""This function is used to select the `BaseFinetuning` to be used for finetuning."""
if isinstance(strategy, str) and strategy == "no_freeze":
warnings.warn("The model contains a default finetune callback.", UserWarning)
finetuning_callback = model.configure_finetune_callback(strategy=strategy, train_bn=train_bn)
if len(finetuning_callback) > 1:
raise ValueError("Create a list with only 1 finetuning callback.")
self.callbacks = self._merge_callbacks(self.callbacks, finetuning_callback)
@staticmethod
def _merge_callbacks(old_callbacks: List, new_callbacks: List) -> List:
"""This function keeps only 1 instance of each callback type, extending new_callbacks with old_callbacks."""
if len(new_callbacks) == 0:
return old_callbacks
new_callbacks_types = {type(c) for c in new_callbacks}
old_callbacks_types = {type(c) for c in old_callbacks}
override_types = new_callbacks_types.intersection(old_callbacks_types)
new_callbacks.extend(c for c in old_callbacks if type(c) not in override_types)
return new_callbacks
@classmethod
def add_argparse_args(cls, *args, **kwargs) -> ArgumentParser:
"""See :func:`pytorch_lightning.utilities.argparse.add_argparse_args`."""
# the lightning trainer implementation does not support subclasses.
# context: https://github.com/Lightning-AI/lightning-flash/issues/342#issuecomment-848892447
return add_argparse_args(PlTrainer, *args, **kwargs)
@classmethod
def from_argparse_args(cls, args: Union[Namespace, ArgumentParser], **kwargs) -> "Trainer":
"""Modified version of :func:`pytorch_lightning.utilities.argparse.from_argparse_args` which populates
``valid_kwargs`` from :class:`pytorch_lightning.Trainer`.
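
Example (a minimal sketch; ``--max_epochs`` is one of the Trainer flags added by :meth:`add_argparse_args`)::

from argparse import ArgumentParser

parser = Trainer.add_argparse_args(ArgumentParser())
args = parser.parse_args(["--max_epochs", "3"])
trainer = Trainer.from_argparse_args(args)
"""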
# the lightning trainer implementation does not support subclasses.
# context: https://github.com/Lightning-AI/lightning-flash/issues/342#issuecomment-848892447
return from_argparse_args(Trainer, args, **kwargs)
@property
def estimated_stepping_batches(self) -> Union[int, float]:
"""Estimated stepping batches for the complete training inferred from DataLoaders, gradient accumulation factor
and distributed setup.
Examples
________
.. code-block:: python
def configure_optimizers(self):
optimizer = ...
scheduler = torch.optim.lr_scheduler.OneCycleLR(
optimizer, max_lr=1e-3, total_steps=self.trainer.estimated_stepping_batches
)
return [optimizer], [scheduler]
"""
return super().estimated_stepping_batches