Shortcuts

Source code for flash.core.trainer

# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import warnings
from argparse import ArgumentParser, Namespace
from functools import wraps
from typing import Callable, List, Optional, Tuple, Union

import torch
from pytorch_lightning import LightningDataModule, LightningModule
from pytorch_lightning import Trainer as PlTrainer
from pytorch_lightning.accelerators.tpu import TPUAccelerator
from pytorch_lightning.callbacks import BaseFinetuning
from pytorch_lightning.utilities.argparse import add_argparse_args, get_init_arguments_and_types, parse_env_variables
from torch.utils.data import DataLoader

import flash
from flash.core.data.io.output import Output
from flash.core.data.io.output_transform import OutputTransform
from flash.core.data.io.transform_predictions import TransformPredictions
from flash.core.model import Task
from flash.core.registry import FlashRegistry


def from_argparse_args(cls, args: Union[Namespace, ArgumentParser], **kwargs):
    """Instantiate ``cls`` from command-line arguments, forwarding only ``pytorch_lightning.Trainer`` options.

    Modified version of :func:`pytorch_lightning.utilities.argparse.from_argparse_args` in which the set of
    accepted keyword names is read from the signature of :class:`pytorch_lightning.Trainer`.

    Args:
        cls: The class to instantiate. It must provide ``parse_argparser`` when ``args`` is a parser.
        args: Either an already parsed namespace, or a parser that is first passed to ``cls.parse_argparser``.
        **kwargs: Extra keyword arguments for ``cls``; they take precedence over values found in ``args``.

    Returns:
        The object produced by calling ``cls`` with the collected keyword arguments.
    """
    namespace = cls.parse_argparser(args) if isinstance(args, ArgumentParser) else args
    parsed = vars(namespace)

    # Only names accepted by PlTrainer.__init__ are forwarded; anything else may be user specific.
    accepted = inspect.signature(PlTrainer.__init__).parameters
    selected = {}
    for key in accepted:
        if key in parsed:
            selected[key] = parsed[key]

    # Explicit keyword arguments override whatever was parsed.
    init_kwargs = {**selected, **kwargs}
    return cls(**init_kwargs)
def _defaults_from_env_vars(fn: Callable) -> Callable: """Copy of ``pytorch_lightning.trainer.connectors.env_vars_connector._defaults_from_env_vars``. Required to fix build error in readthedocs. """ @wraps(fn) def insert_env_defaults(self, *args, **kwargs): cls = self.__class__ # get the class if args: # inace any args passed move them to kwargs # parse only the argument names cls_arg_names = [arg[0] for arg in get_init_arguments_and_types(cls)] # convert args to kwargs kwargs.update(dict(zip(cls_arg_names, args))) env_variables = vars(parse_env_variables(cls)) # update the kwargs by env variables kwargs = dict(list(env_variables.items()) + list(kwargs.items())) # all args were already moved to kwargs return fn(self, **kwargs) return insert_env_defaults
class Trainer(PlTrainer):
    """Extended Trainer for Flash tasks.

    >>> Trainer()  # doctest: +ELLIPSIS
    <...trainer.Trainer object at ...>
    """

    @_defaults_from_env_vars
    def __init__(self, *args, **kwargs):
        # When Flash's testing flag is set, the device and loop-limit options are forced
        # regardless of what the caller passed.
        if flash._IS_TESTING:
            if torch.cuda.is_available():
                kwargs["gpus"] = -1
                kwargs["limit_train_batches"] = 1.0
                kwargs["limit_val_batches"] = 1.0
                kwargs["limit_test_batches"] = 1.0
                kwargs["fast_dev_run"] = False
            else:
                kwargs["fast_dev_run"] = True
                kwargs["gpus"] = None
                kwargs["accelerator"] = None
                kwargs["precision"] = 32
        super().__init__(*args, **kwargs)

    def fit(
        self,
        model: LightningModule,
        train_dataloader: Optional[DataLoader] = None,
        val_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None,
        datamodule: Optional[LightningDataModule] = None,
    ):
        r"""Runs the full optimization routine. Same as :meth:`pytorch_lightning.Trainer.fit`

        A warning is emitted when one of the trainer's callbacks is a
        :class:`pytorch_lightning.callbacks.BaseFinetuning` instance.

        Args:
            datamodule: An instance of :class:`LightningDataModule`.
            model: Model to fit.
            train_dataloader: A PyTorch DataLoader with training samples. If the model has
                a predefined train_dataloader method this will be skipped.
            val_dataloaders: Either a single PyTorch Dataloader or a list of them, specifying validation samples.
                If the model has a predefined val_dataloaders method this will be skipped
        """
        if any(isinstance(c, BaseFinetuning) for c in self.callbacks):
            # TODO: if we find a finetuning callback in the trainer should we remove it? or just warn the user?
            warnings.warn("Warning: You are calling fit(), but your trainer is using a fine-tuning callback")
        return super().fit(model, train_dataloader, val_dataloaders, datamodule)

    def finetune(
        self,
        model: LightningModule,
        train_dataloader: Optional[DataLoader] = None,
        val_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None,
        datamodule: Optional[LightningDataModule] = None,
        strategy: Union[str, BaseFinetuning, Tuple[str, int], Tuple[str, Tuple[int, int]]] = "no_freeze",
        train_bn: bool = True,
    ):
        r"""Runs the full optimization routine. Same as :meth:`pytorch_lightning.Trainer.fit`, but unfreezes
        layers of the backbone throughout training.

        The finetuning callback obtained from ``model.configure_finetune_callback`` is merged into the
        trainer's callbacks before fitting starts.

        Args:
            model: Model to fit.
            train_dataloader: A PyTorch DataLoader with training samples. If the model has
                a predefined train_dataloader method this will be skipped.
            val_dataloaders: Either a single PyTorch Dataloader or a list of them, specifying validation samples.
                If the model has a predefined val_dataloaders method this will be skipped
            datamodule: An instance of :class:`LightningDataModule`.
            strategy: Should either be a string, or a Tuple, or a finetuning callback subclassing
                :class:`pytorch_lightning.callbacks.BaseFinetuning`.
                Default strategies can be enabled with these inputs:

                - ``"no_freeze"``
                - ``"freeze"``
                - ``("freeze_unfreeze", integer: unfreeze_epoch)``
                - ``("unfreeze_milestones", ((integer: unfreeze_epoch_num_layers, integer: unfreeze_epoch_all_layers), integer: num_layers))``

                where ``integer`` can be any integer.
                By default, ``no_freeze`` strategy will be used.
            train_bn: Whether to train Batch Norm layer
        """
        self._resolve_callbacks(model, strategy, train_bn=train_bn)
        return super().fit(model, train_dataloader, val_dataloaders, datamodule)

    def predict(
        self,
        model: Optional[LightningModule] = None,
        dataloaders: Optional[Union[DataLoader, LightningDataModule]] = None,
        output: Optional[Union[Output, str]] = None,
        **kwargs,
    ):
        r"""Run inference on your data. This will call the model forward function to compute predictions. Useful
        to perform distributed and batched predictions. Logging is disabled in the prediction hooks.

        A :class:`~flash.core.data.io.transform_predictions.TransformPredictions` callback is installed for the
        duration of the call only; the trainer's original callbacks are put back afterwards, including when
        prediction raises.

        Args:
            model: The model to predict with. Falls back to ``self.lightning_module`` when not given.
            dataloaders: A :class:`torch.utils.data.DataLoader` or a sequence of them,
                or a :class:`~pytorch_lightning.core.datamodule.LightningDataModule` specifying prediction samples.
            output: The :class:`~flash.core.data.io.output.Output` to use to transform predict outputs. A string
                is looked up in the ``outputs`` registry of the model when the model is a
                :class:`~flash.core.model.Task`. ``None`` selects a plain ``Output()``.
            kwargs: Additional keyword arguments to pass to :meth:`~pytorch_lightning.Trainer.predict`.

        Returns:
            Returns a list of dictionaries, one for each provided dataloader containing their respective predictions.

        Raises:
            NotImplementedError: If the accelerator is a ``TPUAccelerator`` and more than one device is in use.
        """
        # Note: Prediction on TPU device with multi cores is not supported yet
        if isinstance(self.accelerator, TPUAccelerator) and self.num_devices > 1:
            raise NotImplementedError(
                f"Prediction on TPU device with multi-cores (requested cores: {self.num_devices}) is not supported yet."
            )
        model = model or self.lightning_module
        output_transform = getattr(model, "_output_transform", None) or OutputTransform()
        if output is None:
            output = Output()
        if isinstance(output, str) and isinstance(model, Task):
            output = getattr(model, "outputs", FlashRegistry("outputs")).get(output).from_task(model)

        old_callbacks = self.callbacks
        self.callbacks = self._merge_callbacks(self.callbacks, [TransformPredictions(output_transform, output)])
        try:
            result = super().predict(model, dataloaders, **kwargs)
        finally:
            # Restore even on failure so the temporary prediction callback never leaks into later calls.
            self.callbacks = old_callbacks
        return result

    def _resolve_callbacks(
        self,
        model: Task,
        strategy: Union[str, BaseFinetuning, Tuple[str, int], Tuple[str, Tuple[int, int]]] = "no_freeze",
        train_bn: bool = True,
    ):
        """This function is used to select the `BaseFinetuning` to be used for finetuning.

        Args:
            model: The task whose ``configure_finetune_callback`` provides the callback list.
            strategy: The finetuning strategy, forwarded to ``model.configure_finetune_callback``.
            train_bn: Forwarded to ``model.configure_finetune_callback``.

        Raises:
            ValueError: If the model returns more than one finetuning callback.
        """
        if isinstance(strategy, str) and strategy == "no_freeze":
            warnings.warn("The model contains a default finetune callback.", UserWarning)
        finetuning_callback = model.configure_finetune_callback(strategy=strategy, train_bn=train_bn)
        if len(finetuning_callback) > 1:
            raise ValueError("Create a list with only 1 finetuning callback.")
        self.callbacks = self._merge_callbacks(self.callbacks, finetuning_callback)

    @staticmethod
    def _merge_callbacks(old_callbacks: List, new_callbacks: List) -> List:
        """This function keeps only 1 instance of each callback type, giving priority to ``new_callbacks``.

        Args:
            old_callbacks: The callbacks currently in use.
            new_callbacks: The callbacks to add; they replace old callbacks of exactly the same type.

        Returns:
            ``old_callbacks`` itself when ``new_callbacks`` is empty. Otherwise a new list holding
            ``new_callbacks`` followed by every old callback whose exact type is not among the new ones.
            Neither input list is modified.
        """
        if not new_callbacks:
            return old_callbacks
        # Exact type comparison is deliberate: a subclass instance does not override its parent type.
        new_types = {type(c) for c in new_callbacks}
        kept = [c for c in old_callbacks if type(c) not in new_types]
        return [*new_callbacks, *kept]

    @classmethod
    def add_argparse_args(cls, *args, **kwargs) -> ArgumentParser:
        """See :func:`pytorch_lightning.utilities.argparse.add_argparse_args`."""
        # the lightning trainer implementation does not support subclasses.
        # context: https://github.com/Lightning-AI/lightning-flash/issues/342#issuecomment-848892447
        return add_argparse_args(PlTrainer, *args, **kwargs)

    @classmethod
    def from_argparse_args(cls, args: Union[Namespace, ArgumentParser], **kwargs) -> "Trainer":
        """Modified version of :func:`pytorch_lightning.utilities.argparse.from_argparse_args` which populates
        ``valid_kwargs`` from :class:`pytorch_lightning.Trainer`."""
        # the lightning trainer implementation does not support subclasses.
        # context: https://github.com/Lightning-AI/lightning-flash/issues/342#issuecomment-848892447
        return from_argparse_args(Trainer, args, **kwargs)

    @property
    def estimated_stepping_batches(self) -> Union[int, float]:
        """Estimated stepping batches for the complete training inferred from DataLoaders, gradient
        accumulation factor and distributed setup.

        Examples
        ________

        .. code-block:: python

            def configure_optimizers(self):
                optimizer = ...
                scheduler = torch.optim.lr_scheduler.OneCycleLR(
                    optimizer, max_lr=1e-3, total_steps=self.trainer.estimated_stepping_batches
                )
                return [optimizer], [scheduler]
        """
        return super().estimated_stepping_batches

© Copyright 2020-2021, PyTorch Lightning. Revision a374dd4f.

Built with Sphinx using a theme provided by Read the Docs.
Read the Docs v: latest
Versions
latest
stable
0.8.2
0.8.1.post0
0.8.1
0.8.0
0.7.5
0.7.4
0.7.3
0.7.2
0.7.1
0.7.0
0.6.0
0.5.2
0.5.1
0.5.0
0.4.0
0.3.2
0.3.1
0.3.0
0.2.3
0.2.2
0.2.1
0.2.0
0.1.0post1
Downloads
html
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.