Shortcuts

Source code for flash.core.trainer

# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import warnings
from argparse import ArgumentParser, Namespace
from functools import wraps
from typing import Callable, List, Optional, Tuple, Union

import torch
from pytorch_lightning import LightningDataModule, LightningModule
from pytorch_lightning import Trainer as PlTrainer
from pytorch_lightning.callbacks import BaseFinetuning
from pytorch_lightning.utilities.argparse import add_argparse_args, get_init_arguments_and_types, parse_env_variables
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from torch.utils.data import DataLoader

import flash
from flash.core.data.io.output import Output
from flash.core.data.io.output_transform import OutputTransform
from flash.core.data.io.transform_predictions import TransformPredictions
from flash.core.model import Task
from flash.core.registry import FlashRegistry


def from_argparse_args(cls, args: Union[Namespace, ArgumentParser], **kwargs):
    """Instantiate ``cls`` from parsed command line arguments.

    Variant of :func:`pytorch_lightning.utilities.argparse.from_argparse_args` in which the set of accepted
    argument names is always read from the signature of :class:`pytorch_lightning.Trainer`.

    Args:
        cls: The class to instantiate. When ``args`` is an :class:`~argparse.ArgumentParser`,
            ``cls.parse_argparser`` is used to turn it into a namespace.
        args: Either an already parsed :class:`~argparse.Namespace` or the parser itself.
        kwargs: Extra keyword arguments; they take precedence over values found in ``args``.

    Returns:
        The result of calling ``cls`` with the selected keyword arguments.
    """
    namespace = cls.parse_argparser(args) if isinstance(args, ArgumentParser) else args
    parsed = vars(namespace)

    # Only names accepted by the Lightning trainer are forwarded; anything else in the namespace
    # may be a user specific option.
    accepted = inspect.signature(PlTrainer.__init__).parameters
    init_kwargs = {}
    for key in accepted:
        if key in parsed:
            init_kwargs[key] = parsed[key]
    init_kwargs.update(**kwargs)

    return cls(**init_kwargs)
def _defaults_from_env_vars(fn: Callable) -> Callable:
    """Wrap an ``__init__`` so that values parsed from environment variables act as keyword defaults.

    Copy of ``pytorch_lightning.trainer.connectors.env_vars_connector._defaults_from_env_vars``, kept here
    because it is required to fix a build error in readthedocs.

    Args:
        fn: The function to wrap. It is always invoked as ``fn(self, **kwargs)``.

    Returns:
        The wrapping function.
    """

    @wraps(fn)
    def insert_env_defaults(self, *args, **kwargs):
        owner = self.__class__

        if args:
            # In case any positional arguments were passed, pair them (in order) with the argument
            # names of the class and move them into the keyword arguments.
            positional_names = (entry[0] for entry in get_init_arguments_and_types(owner))
            kwargs.update(zip(positional_names, args))

        # Environment values come first so that explicitly passed keywords override them.
        env_defaults = vars(parse_env_variables(owner))
        merged = {**env_defaults, **kwargs}

        # Every positional argument has already been moved into the keywords at this point.
        return fn(self, **merged)

    return insert_env_defaults
class Trainer(PlTrainer):
    """Subclass of :class:`pytorch_lightning.Trainer` used by Flash.

    On top of the Lightning trainer it adds :meth:`finetune`, attaches a ``TransformPredictions`` callback for
    the duration of :meth:`predict`, and overrides a handful of constructor arguments when
    ``flash._IS_TESTING`` is set.
    """

    @_defaults_from_env_vars
    def __init__(self, *args, **kwargs):
        """Create the trainer, forcing test-friendly settings when ``flash._IS_TESTING`` is truthy.

        The decorator calls this function with keyword arguments only, so ``args`` is expected to be empty.
        """
        if flash._IS_TESTING:
            if torch.cuda.is_available():
                kwargs["gpus"] = -1
                kwargs["limit_train_batches"] = 1.0
                kwargs["limit_val_batches"] = 1.0
                kwargs["limit_test_batches"] = 1.0
                kwargs["fast_dev_run"] = False
            else:
                kwargs["fast_dev_run"] = True
                kwargs["gpus"] = None
                kwargs["accelerator"] = None
                kwargs["precision"] = 32
        super().__init__(*args, **kwargs)

    def fit(
        self,
        model: LightningModule,
        train_dataloader: Optional[DataLoader] = None,
        val_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None,
        datamodule: Optional[LightningDataModule] = None,
    ):
        r"""
        Runs the full optimization routine. Same as :meth:`pytorch_lightning.Trainer.fit`

        A warning is emitted when one of the trainer callbacks is a
        :class:`pytorch_lightning.callbacks.BaseFinetuning` instance.

        Args:
            datamodule: An instance of :class:`LightningDataModule`.

            model: Model to fit.

            train_dataloader: A PyTorch DataLoader with training samples. If the model has
                a predefined train_dataloader method this will be skipped.

            val_dataloaders: Either a single PyTorch Dataloader or a list of them, specifying validation samples.
                If the model has a predefined val_dataloaders method this will be skipped
        """
        if any(isinstance(c, BaseFinetuning) for c in self.callbacks):
            # TODO: if we find a finetuning callback in the trainer should we remove it? or just warn the user?
            warnings.warn("Warning: You are calling fit(), but your trainer is using a fine-tuning callback")
        return super().fit(model, train_dataloader, val_dataloaders, datamodule)

    def finetune(
        self,
        model: LightningModule,
        train_dataloader: Optional[DataLoader] = None,
        val_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None,
        datamodule: Optional[LightningDataModule] = None,
        strategy: Union[str, BaseFinetuning, Tuple[str, int], Tuple[str, Tuple[int, int]]] = "no_freeze",
        train_bn: bool = True,
    ):
        r"""
        Runs the full optimization routine. Same as :meth:`pytorch_lightning.Trainer.fit`, but unfreezes
        layers of the backbone throughout training.

        Args:
            model: Model to fit.

            train_dataloader: A PyTorch DataLoader with training samples. If the model has
                a predefined train_dataloader method this will be skipped.

            val_dataloaders: Either a single PyTorch Dataloader or a list of them, specifying validation samples.
                If the model has a predefined val_dataloaders method this will be skipped

            datamodule: An instance of :class:`LightningDataModule`.

            strategy: Should either be a string, or a Tuple, or a finetuning callback subclassing
                :class:`pytorch_lightning.callbacks.BaseFinetuning`.
                Default strategies can be enabled with these inputs:

                - ``"no_freeze"``
                - ``"freeze"``
                - ``("freeze_unfreeze", integer: unfreeze_epoch)``
                - ``("unfreeze_milestones", ((integer: unfreeze_epoch_num_layers, integer: unfreeze_epoch_all_layers),
                  integer: num_layers))``

                where ``integer`` can be any integer.
                By default, ``no_freeze`` strategy will be used.

            train_bn: Whether to train Batch Norm layer
        """
        self._resolve_callbacks(model, strategy, train_bn=train_bn)
        return super().fit(model, train_dataloader, val_dataloaders, datamodule)

    def predict(
        self,
        model: Optional[LightningModule] = None,
        dataloaders: Optional[Union[DataLoader, LightningDataModule]] = None,
        output: Optional[Union[Output, str]] = None,
        **kwargs,
    ):
        r"""
        Run inference on your data. This will call the model forward function to compute predictions. Useful to
        perform distributed and batched predictions. Logging is disabled in the prediction hooks.

        A ``TransformPredictions`` callback is attached to the trainer for the duration of the call; the previous
        callbacks are put back afterwards, including when prediction raises.

        Args:
            model: The model to predict with. When ``None``, ``self.lightning_module`` is used.
            dataloaders: A :class:`torch.utils.data.DataLoader` or a sequence of them,
                or a :class:`~pytorch_lightning.core.datamodule.LightningDataModule` specifying prediction samples.
            output: The :class:`~flash.core.data.io.output.Output` to use to transform predict outputs. A string is
                looked up in the ``outputs`` registry of the model when the model is a ``Task``.
            kwargs: Additional keyword arguments to pass to ``pytorch_lightning.Trainer.predict``.

        Returns:
            Returns a list of dictionaries, one for each provided dataloader containing their respective predictions.
        """
        # Compare against ``None`` explicitly so a model that happens to be falsy is not silently replaced.
        if model is None:
            model = self.lightning_module
        output_transform = getattr(model, "_output_transform", None) or OutputTransform()
        if output is None:
            output = Output()
        if isinstance(output, str) and isinstance(model, Task):
            output = getattr(model, "outputs", FlashRegistry("outputs")).get(output).from_task(model)

        old_callbacks = self.callbacks
        self.callbacks = self._merge_callbacks(self.callbacks, [TransformPredictions(output_transform, output)])
        try:
            return super().predict(model, dataloaders, **kwargs)
        finally:
            # Always drop the temporary callback, otherwise a failed prediction would leave it on the trainer.
            self.callbacks = old_callbacks

    def _resolve_callbacks(
        self,
        model: Task,
        strategy: Union[str, BaseFinetuning, Tuple[str, int], Tuple[str, Tuple[int, int]]] = "no_freeze",
        train_bn: bool = True,
    ):
        """This function is used to select the `BaseFinetuning` to be used for finetuning.

        The callbacks returned by ``model.configure_finetune_callback`` are merged into ``self.callbacks``.

        Raises:
            MisconfigurationException: If the model returns more than one finetuning callback.
        """
        if isinstance(strategy, str) and strategy == "no_freeze":
            warnings.warn("The model contains a default finetune callback.", UserWarning)
        finetuning_callback = model.configure_finetune_callback(strategy=strategy, train_bn=train_bn)
        if len(finetuning_callback) > 1:
            raise MisconfigurationException("Create a list with only 1 finetuning callback.")
        self.callbacks = self._merge_callbacks(self.callbacks, finetuning_callback)

    @staticmethod
    def _merge_callbacks(old_callbacks: List, new_callbacks: List) -> List:
        """This function keeps only 1 instance of each callback type, giving priority to ``new_callbacks``.

        Args:
            old_callbacks: The callbacks currently in use.
            new_callbacks: The callbacks to add. Old callbacks whose exact type also appears here are dropped.

        Returns:
            ``old_callbacks`` itself when ``new_callbacks`` is empty, otherwise a new list holding
            ``new_callbacks`` followed by the remaining old callbacks. Neither input list is modified.
        """
        if not new_callbacks:
            return old_callbacks
        new_callbacks_types = {type(c) for c in new_callbacks}
        # Build a fresh list rather than extending ``new_callbacks``: the caller may still hold a reference to it.
        return [*new_callbacks, *(c for c in old_callbacks if type(c) not in new_callbacks_types)]

    @classmethod
    def add_argparse_args(cls, *args, **kwargs) -> ArgumentParser:
        """See :func:`pytorch_lightning.utilities.argparse.add_argparse_args`."""
        # the lightning trainer implementation does not support subclasses.
        # context: https://github.com/PyTorchLightning/lightning-flash/issues/342#issuecomment-848892447
        return add_argparse_args(PlTrainer, *args, **kwargs)

    @classmethod
    def from_argparse_args(cls, args: Union[Namespace, ArgumentParser], **kwargs) -> "Trainer":
        """Modified version of :func:`pytorch_lightning.utilities.argparse.from_argparse_args` which populates
        ``valid_kwargs`` from :class:`pytorch_lightning.Trainer`."""
        # the lightning trainer implementation does not support subclasses.
        # context: https://github.com/PyTorchLightning/lightning-flash/issues/342#issuecomment-848892447
        return from_argparse_args(Trainer, args, **kwargs)

© Copyright 2020-2021, PyTorch Lightning. Revision 8db29e8e.

Built with Sphinx using a theme provided by Read the Docs.
Read the Docs v: 0.7.3
Versions
latest
stable
0.7.3
0.7.2
0.7.1
0.7.0
0.6.0
0.5.2
0.5.1
0.5.0
0.4.0
0.3.2
0.3.1
0.3.0
0.2.3
0.2.2
0.2.1
0.2.0
0.1.0post1
docs-fix_typing
Downloads
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.