Shortcuts

Source code for flash.core.data.io.input_transform

# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from dataclasses import dataclass
from functools import partial
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple

from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.enums import LightningEnum

from flash.core.data.callback import ControlFlow
from flash.core.data.utilities.collate import default_collate
from flash.core.data.utils import _STAGES_PREFIX
from flash.core.utilities.stages import RunningStage
from flash.core.utilities.types import INPUT_TRANSFORM_TYPE


class InputTransformPlacement(LightningEnum):
    """Names of the transform hooks an ``InputTransform`` can provide.

    Each value is the stage-agnostic hook name. ``InputTransform`` iterates over this enum (in definition order)
    and prefixes each value with a stage prefix to find the hook method to call, e.g. ``f"{stage}_{value}"``.
    """

    PER_SAMPLE_TRANSFORM = "per_sample_transform"
    PER_BATCH_TRANSFORM = "per_batch_transform"
    COLLATE = "collate"
    PER_SAMPLE_TRANSFORM_ON_DEVICE = "per_sample_transform_on_device"
    PER_BATCH_TRANSFORM_ON_DEVICE = "per_batch_transform_on_device"


# Running stages that ``InputTransform.__post_init__`` skips when resolving transforms.
INVALID_STAGES_FOR_INPUT_TRANSFORMS = [RunningStage.SANITY_CHECKING, RunningStage.TUNING]


@dataclass
class _InputTransformPerStage:
    """Transforms resolved by an ``InputTransform`` for one running stage."""

    # ``False`` when a ``per_sample_transform_on_device`` hook is defined for the stage; collation is then
    # performed by the device-side processor instead of the worker-side one.
    collate_in_worker: bool
    # Maps a hook name (an ``InputTransformPlacement`` value) to the callable returned by that hook.
    transforms: Optional[Dict[str, Callable]] = None


@dataclass
class InputTransform:
    """Collection of hooks describing how samples and batches are transformed for every running stage.

    Subclasses override the stage-specific hooks (``train_*``, ``val_*``, ``test_*``, ``predict_*``, ``serve_*``)
    or the stage-agnostic hooks they default to, and return a callable from them. The hooks are resolved once,
    in :meth:`__post_init__`, for every ``RunningStage`` not listed in ``INVALID_STAGES_FOR_INPUT_TRANSFORMS``.
    """

    def __post_init__(self):
        self.callbacks: Optional[List] = None

        # used to keep track of provided transforms
        self._transform: Dict[RunningStage, _InputTransformPerStage] = {}

        # For all the stages possible, set/load the transforms.
        for stage in RunningStage:
            if stage not in INVALID_STAGES_FOR_INPUT_TRANSFORMS:
                self._populate_transforms_for_stage(stage)

    def current_transform(self, stage: RunningStage, current_fn: str) -> Callable:
        """Returns the transform resolved for ``stage`` under the hook name ``current_fn``.

        Falls back to the identity function when no transform was registered for that hook.
        """
        return self._transform[stage].transforms.get(current_fn, self._identity)

    ########################
    # PER SAMPLE TRANSFORM #
    ########################

    def per_sample_transform(self) -> Callable:
        """Defines the transform to be applied on a single sample on cpu for all stages.

        The input data of the transform would have the following form::

            {
                DataKeys.INPUT: ...,
                DataKeys.TARGET: ...,
                DataKeys.METADATA: ...,
            }

        You would need to use :class:`flash.core.data.transforms.ApplyToKeys` as follows:

        .. code-block:: python

            from flash.core.data.transforms import ApplyToKeys


            class MyInputTransform(InputTransform):
                def per_sample_transform(self) -> Callable:
                    return ApplyToKeys("input", my_func)

        Returning ``None`` (the default) registers no transform for this hook.
        """
        return None

    def train_per_sample_transform(self) -> Callable:
        """Defines the transform to be applied on a single sample on cpu for the training stage.

        The transform receives a sample of the form ``{DataKeys.INPUT: ..., DataKeys.TARGET: ...,
        DataKeys.METADATA: ...}``. Defaults to :meth:`per_sample_transform`; override this hook (for example
        returning ``ApplyToKeys("input", my_func)``) to customise this stage only.
        """
        return self.per_sample_transform()

    def val_per_sample_transform(self) -> Callable:
        """Defines the transform to be applied on a single sample on cpu for the validating stage.

        The transform receives a sample of the form ``{DataKeys.INPUT: ..., DataKeys.TARGET: ...,
        DataKeys.METADATA: ...}``. Defaults to :meth:`per_sample_transform`; override this hook (for example
        returning ``ApplyToKeys("input", my_func)``) to customise this stage only.
        """
        return self.per_sample_transform()

    def test_per_sample_transform(self) -> Callable:
        """Defines the transform to be applied on a single sample on cpu for the testing stage.

        The transform receives a sample of the form ``{DataKeys.INPUT: ..., DataKeys.TARGET: ...,
        DataKeys.METADATA: ...}``. Defaults to :meth:`per_sample_transform`; override this hook (for example
        returning ``ApplyToKeys("input", my_func)``) to customise this stage only.
        """
        return self.per_sample_transform()

    def predict_per_sample_transform(self) -> Callable:
        """Defines the transform to be applied on a single sample on cpu for the predicting stage.

        The transform receives a sample of the form ``{DataKeys.INPUT: ..., DataKeys.TARGET: ...,
        DataKeys.METADATA: ...}``. Defaults to :meth:`per_sample_transform`; override this hook (for example
        returning ``ApplyToKeys("input", my_func)``) to customise this stage only.
        """
        return self.per_sample_transform()

    def serve_per_sample_transform(self) -> Callable:
        """Defines the transform to be applied on a single sample on cpu for the serving stage.

        The transform receives a sample of the form ``{DataKeys.INPUT: ..., DataKeys.TARGET: ...,
        DataKeys.METADATA: ...}``. Defaults to :meth:`per_sample_transform`; override this hook (for example
        returning ``ApplyToKeys("input", my_func)``) to customise this stage only.
        """
        return self.per_sample_transform()

    ##################################
    # PER SAMPLE TRANSFORM ON DEVICE #
    ##################################

    def per_sample_transform_on_device(self) -> Callable:
        """Defines the transform to be applied on a single sample on device for all stages.

        The input data of the transform would have the following form::

            {
                DataKeys.INPUT: ...,
                DataKeys.TARGET: ...,
                DataKeys.METADATA: ...,
            }

        You would need to use :class:`flash.core.data.transforms.ApplyToKeys` as follows:

        .. code-block:: python

            from flash.core.data.transforms import ApplyToKeys


            class MyInputTransform(InputTransform):
                def per_sample_transform_on_device(self) -> Callable:
                    return ApplyToKeys("input", my_func)

        Returning ``None`` (the default) registers no transform for this hook.
        """
        return None

    def train_per_sample_transform_on_device(self) -> Callable:
        """Defines the transform to be applied on a single sample on device for the training stage.

        The transform receives a sample of the form ``{DataKeys.INPUT: ..., DataKeys.TARGET: ...,
        DataKeys.METADATA: ...}``. Defaults to :meth:`per_sample_transform_on_device`; override this hook (for
        example returning ``ApplyToKeys("input", my_func)``) to customise this stage only.
        """
        return self.per_sample_transform_on_device()

    def val_per_sample_transform_on_device(self) -> Callable:
        """Defines the transform to be applied on a single sample on device for the validating stage.

        The transform receives a sample of the form ``{DataKeys.INPUT: ..., DataKeys.TARGET: ...,
        DataKeys.METADATA: ...}``. Defaults to :meth:`per_sample_transform_on_device`; override this hook (for
        example returning ``ApplyToKeys("input", my_func)``) to customise this stage only.
        """
        return self.per_sample_transform_on_device()

    def test_per_sample_transform_on_device(self) -> Callable:
        """Defines the transform to be applied on a single sample on device for the testing stage.

        The transform receives a sample of the form ``{DataKeys.INPUT: ..., DataKeys.TARGET: ...,
        DataKeys.METADATA: ...}``. Defaults to :meth:`per_sample_transform_on_device`; override this hook (for
        example returning ``ApplyToKeys("input", my_func)``) to customise this stage only.
        """
        return self.per_sample_transform_on_device()

    def predict_per_sample_transform_on_device(self) -> Callable:
        """Defines the transform to be applied on a single sample on device for the predicting stage.

        The transform receives a sample of the form ``{DataKeys.INPUT: ..., DataKeys.TARGET: ...,
        DataKeys.METADATA: ...}``. Defaults to :meth:`per_sample_transform_on_device`; override this hook (for
        example returning ``ApplyToKeys("input", my_func)``) to customise this stage only.
        """
        return self.per_sample_transform_on_device()

    def serve_per_sample_transform_on_device(self) -> Callable:
        """Defines the transform to be applied on a single sample on device for the serving stage.

        The transform receives a sample of the form ``{DataKeys.INPUT: ..., DataKeys.TARGET: ...,
        DataKeys.METADATA: ...}``. Defaults to :meth:`per_sample_transform_on_device`; override this hook (for
        example returning ``ApplyToKeys("input", my_func)``) to customise this stage only.
        """
        return self.per_sample_transform_on_device()

    #######################
    # PER BATCH TRANSFORM #
    #######################

    def per_batch_transform(self) -> Callable:
        """Defines the transform to be applied on a batch of data on cpu for all stages.

        The input data of the transform would have the following form::

            {
                DataKeys.INPUT: ...,
                DataKeys.TARGET: ...,
                DataKeys.METADATA: ...,
            }

        You would need to use :class:`flash.core.data.transforms.ApplyToKeys` as follows:

        .. code-block:: python

            from flash.core.data.transforms import ApplyToKeys


            class MyInputTransform(InputTransform):
                def per_batch_transform(self) -> Callable:
                    return ApplyToKeys("input", my_func)

        Returning ``None`` (the default) registers no transform for this hook.
        """
        return None

    def train_per_batch_transform(self) -> Callable:
        """Defines the transform to be applied on a batch of data on cpu for the training stage.

        The transform receives a batch of the form ``{DataKeys.INPUT: ..., DataKeys.TARGET: ...,
        DataKeys.METADATA: ...}``. Defaults to :meth:`per_batch_transform`; override this hook (for example
        returning ``ApplyToKeys("input", my_func)``) to customise this stage only.
        """
        return self.per_batch_transform()

    def val_per_batch_transform(self) -> Callable:
        """Defines the transform to be applied on a batch of data on cpu for the validating stage.

        The transform receives a batch of the form ``{DataKeys.INPUT: ..., DataKeys.TARGET: ...,
        DataKeys.METADATA: ...}``. Defaults to :meth:`per_batch_transform`; override this hook (for example
        returning ``ApplyToKeys("input", my_func)``) to customise this stage only.
        """
        return self.per_batch_transform()

    def test_per_batch_transform(self) -> Callable:
        """Defines the transform to be applied on a batch of data on cpu for the testing stage.

        The transform receives a batch of the form ``{DataKeys.INPUT: ..., DataKeys.TARGET: ...,
        DataKeys.METADATA: ...}``. Defaults to :meth:`per_batch_transform`; override this hook (for example
        returning ``ApplyToKeys("input", my_func)``) to customise this stage only.
        """
        return self.per_batch_transform()

    def predict_per_batch_transform(self) -> Callable:
        """Defines the transform to be applied on a batch of data on cpu for the predicting stage.

        The transform receives a batch of the form ``{DataKeys.INPUT: ..., DataKeys.TARGET: ...,
        DataKeys.METADATA: ...}``. Defaults to :meth:`per_batch_transform`; override this hook (for example
        returning ``ApplyToKeys("input", my_func)``) to customise this stage only.
        """
        return self.per_batch_transform()

    def serve_per_batch_transform(self) -> Callable:
        """Defines the transform to be applied on a batch of data on cpu for the serving stage.

        The transform receives a batch of the form ``{DataKeys.INPUT: ..., DataKeys.TARGET: ...,
        DataKeys.METADATA: ...}``. Defaults to :meth:`per_batch_transform`; override this hook (for example
        returning ``ApplyToKeys("input", my_func)``) to customise this stage only.
        """
        return self.per_batch_transform()

    #################################
    # PER BATCH TRANSFORM ON DEVICE #
    #################################

    def per_batch_transform_on_device(self) -> Callable:
        """Defines the transform to be applied on a batch of data on device for all stages.

        The input data of the transform would have the following form::

            {
                DataKeys.INPUT: ...,
                DataKeys.TARGET: ...,
                DataKeys.METADATA: ...,
            }

        You would need to use :class:`flash.core.data.transforms.ApplyToKeys` as follows:

        .. code-block:: python

            from flash.core.data.transforms import ApplyToKeys


            class MyInputTransform(InputTransform):
                def per_batch_transform_on_device(self) -> Callable:
                    return ApplyToKeys("input", my_func)

        Returning ``None`` (the default) registers no transform for this hook.
        """
        return None

    def train_per_batch_transform_on_device(self) -> Callable:
        """Defines the transform to be applied on a batch of data on device for the training stage.

        The transform receives a batch of the form ``{DataKeys.INPUT: ..., DataKeys.TARGET: ...,
        DataKeys.METADATA: ...}``. Defaults to :meth:`per_batch_transform_on_device`; override this hook (for
        example returning ``ApplyToKeys("input", my_func)``) to customise this stage only.
        """
        return self.per_batch_transform_on_device()

    def val_per_batch_transform_on_device(self) -> Callable:
        """Defines the transform to be applied on a batch of data on device for the validating stage.

        The transform receives a batch of the form ``{DataKeys.INPUT: ..., DataKeys.TARGET: ...,
        DataKeys.METADATA: ...}``. Defaults to :meth:`per_batch_transform_on_device`; override this hook (for
        example returning ``ApplyToKeys("input", my_func)``) to customise this stage only.
        """
        return self.per_batch_transform_on_device()

    def test_per_batch_transform_on_device(self) -> Callable:
        """Defines the transform to be applied on a batch of data on device for the testing stage.

        The transform receives a batch of the form ``{DataKeys.INPUT: ..., DataKeys.TARGET: ...,
        DataKeys.METADATA: ...}``. Defaults to :meth:`per_batch_transform_on_device`; override this hook (for
        example returning ``ApplyToKeys("input", my_func)``) to customise this stage only.
        """
        return self.per_batch_transform_on_device()

    def predict_per_batch_transform_on_device(self) -> Callable:
        """Defines the transform to be applied on a batch of data on device for the predicting stage.

        The transform receives a batch of the form ``{DataKeys.INPUT: ..., DataKeys.TARGET: ...,
        DataKeys.METADATA: ...}``. Defaults to :meth:`per_batch_transform_on_device`; override this hook (for
        example returning ``ApplyToKeys("input", my_func)``) to customise this stage only.
        """
        return self.per_batch_transform_on_device()

    def serve_per_batch_transform_on_device(self) -> Callable:
        """Defines the transform to be applied on a batch of data on device for the serving stage.

        The transform receives a batch of the form ``{DataKeys.INPUT: ..., DataKeys.TARGET: ...,
        DataKeys.METADATA: ...}``. Defaults to :meth:`per_batch_transform_on_device`; override this hook (for
        example returning ``ApplyToKeys("input", my_func)``) to customise this stage only.
        """
        return self.per_batch_transform_on_device()

    ###########
    # COLLATE #
    ###########

    def collate(self) -> Callable:
        """Defines the transform to be applied on a list of samples to create a batch for all stages.

        Returns ``default_collate`` unless overridden.
        """
        return default_collate

    def train_collate(self) -> Callable:
        """Defines the transform to be applied on a list of training samples to create a training batch.

        Defaults to :meth:`collate`.
        """
        return self.collate()

    def val_collate(self) -> Callable:
        """Defines the transform to be applied on a list of validating samples to create a validating batch.

        Defaults to :meth:`collate`.
        """
        return self.collate()

    def test_collate(self) -> Callable:
        """Defines the transform to be applied on a list of testing samples to create a testing batch.

        Defaults to :meth:`collate`.
        """
        return self.collate()

    def predict_collate(self) -> Callable:
        """Defines the transform to be applied on a list of predicting samples to create a predicting batch.

        Defaults to :meth:`collate`.
        """
        return self.collate()

    def serve_collate(self) -> Callable:
        """Defines the transform to be applied on a list of serving samples to create a serving batch.

        Defaults to :meth:`collate`.
        """
        return self.collate()

    ########################################
    # HOOKS CALLED INTERNALLY WITHIN FLASH #
    ########################################

    def _per_sample_transform(self, sample: Any, stage: RunningStage) -> Any:
        """Applies the ``per_sample_transform`` resolved for ``stage`` to ``sample``.

        When ``sample`` is a list, the transform is applied to each element and a list is returned.
        """
        fn = self.current_transform(stage=stage, current_fn="per_sample_transform")
        if isinstance(sample, list):
            return [fn(s) for s in sample]
        return fn(sample)

    def _per_batch_transform(self, batch: Any, stage: RunningStage) -> Any:
        """Transforms to apply to a whole batch (if possible use this for efficiency).

        .. note::

            This option is mutually exclusive with :meth:`per_sample_transform_on_device`, since if both are
            specified, uncollation has to be applied.
        """
        return self.current_transform(stage=stage, current_fn="per_batch_transform")(batch)

    def _collate(self, samples: Sequence, stage: RunningStage) -> Any:
        """Transform to convert a sequence of samples to a collated batch."""
        return self.current_transform(stage=stage, current_fn="collate")(samples)

    def _per_sample_transform_on_device(self, sample: Any, stage: RunningStage) -> Any:
        """Transforms to apply to the data before the collation (per-sample basis).

        When ``sample`` is a list, the transform is applied to each element and a list is returned.

        .. note::

            This option is mutually exclusive with :meth:`per_batch_transform`, since if both are specified,
            uncollation has to be applied.

        .. note::

            This function won't be called within the dataloader workers, since to make that happen each of the
            workers would have to create its own CUDA-context which would pollute GPU memory (if on GPU).
        """
        fn = self.current_transform(stage=stage, current_fn="per_sample_transform_on_device")
        if isinstance(sample, list):
            return [fn(s) for s in sample]
        return fn(sample)

    def _per_batch_transform_on_device(self, batch: Any, stage: RunningStage) -> Any:
        """Transforms to apply to a whole batch (if possible use this for efficiency).

        .. note::

            This function won't be called within the dataloader workers, since to make that happen each of the
            workers would have to create its own CUDA-context which would pollute GPU memory (if on GPU).
        """
        return self.current_transform(stage=stage, current_fn="per_batch_transform_on_device")(batch)

    #############
    # UTILITIES #
    #############

    def inject_collate_fn(self, collate_fn: Callable):
        """Overrides the collate transform of every populated stage with ``collate_fn``.

        Nothing happens when ``collate_fn`` is ``default_collate``.
        """
        if collate_fn is default_collate:
            return
        # Use the same stage filter as ``__post_init__`` so only stages present in ``self._transform`` are touched.
        for stage in RunningStage:
            if stage not in INVALID_STAGES_FOR_INPUT_TRANSFORMS:
                self._transform[stage].transforms[InputTransformPlacement.COLLATE.value] = collate_fn

    def _populate_transforms_for_stage(self, running_stage: RunningStage):
        """Resolves and validates the hooks of ``running_stage`` and stores them in ``self._transform``."""
        transform, collate_in_worker = self.__check_transforms(
            transform=self.__resolve_transforms(running_stage),
        )
        self._transform[running_stage] = _InputTransformPerStage(
            collate_in_worker=collate_in_worker,
            transforms=transform,
        )

    def __resolve_transforms(self, running_stage: RunningStage) -> Optional[Dict[str, Callable]]:
        """Calls every ``{stage}_{placement}`` hook and collects the callables they return.

        Hooks returning ``None`` are skipped.

        Raises:
            AttributeError: If looking up or calling a hook raises ``AttributeError``; the original error is
                chained as the cause.
            TypeError: If a hook returns something that is neither ``None`` nor callable.
        """
        transforms = {}
        stage = _STAGES_PREFIX[running_stage]

        # iterate over all transforms hook name
        for placement in InputTransformPlacement:
            transform_name = placement.value
            method_name = f"{stage}_{transform_name}"

            # get associated transform
            try:
                fn = getattr(self, method_name)()
            except AttributeError as e:
                raise AttributeError(
                    str(e) + ". Make sure you include a call to super().__init__(...) in your __init__ after setting "
                    "all attributes."
                ) from e

            if fn is None:
                continue

            if not callable(fn):
                raise TypeError(f"The hook {method_name} should return a callable.")

            transforms[transform_name] = fn

        return transforms

    def __check_transforms(self, transform: Dict[str, Callable]) -> Tuple[Dict[str, Callable], Optional[bool]]:
        """Validates the resolved transforms and decides where collation happens.

        Returns:
            The unchanged ``transform`` mapping and ``collate_in_worker``, which is ``False`` exactly when a
            ``per_sample_transform_on_device`` is present.

        Raises:
            TypeError: If both ``per_batch_transform`` and ``per_sample_transform_on_device`` are present.
        """
        is_per_batch_transform_in = "per_batch_transform" in transform
        is_per_sample_transform_on_device_in = "per_sample_transform_on_device" in transform

        if is_per_batch_transform_in and is_per_sample_transform_on_device_in:
            raise TypeError(
                f"{transform}: `per_batch_transform` and `per_sample_transform_on_device` are mutually exclusive."
            )

        collate_in_worker: Optional[bool] = not is_per_sample_transform_on_device_in
        return transform, collate_in_worker

    @staticmethod
    def _identity(x: Any) -> Any:
        """Returns ``x`` unchanged; used when no transform is registered for a hook."""
        return x

    def __str__(self) -> str:
        return f"{self.__class__.__name__}(" + f"transform={self._transform})"
def create_or_configure_input_transform(
    transform: INPUT_TRANSFORM_TYPE,
    transform_kwargs: Optional[Dict] = None,
) -> Optional[InputTransform]:
    """Normalises the supported ``transform`` formats into an ``InputTransform`` instance (or ``None``).

    * an ``InputTransform`` instance is returned as is;
    * an ``InputTransform`` subclass is instantiated with ``transform_kwargs`` (deprecated, emits a warning);
    * a ``functools.partial`` is called with ``transform_kwargs``;
    * any other falsy value yields ``None``.

    Raises:
        ValueError: If ``transform`` matches none of the formats above.
    """
    kwargs = transform_kwargs if transform_kwargs else {}

    if isinstance(transform, InputTransform):
        return transform

    if inspect.isclass(transform) and issubclass(transform, InputTransform):
        # Deprecation Warning
        rank_zero_warn(
            "Please pass an instantiated object of the `InputTransform` class. Passing the Class and keyword arguments"
            " separately has been deprecated since v0.8.0 and will be removed in v0.9.0.",
            stacklevel=8,
            category=FutureWarning,
        )
        return transform(**kwargs)

    if isinstance(transform, partial):
        return transform(**kwargs)

    if transform:
        raise ValueError(f"The format for the transform isn't correct. Found {transform}")

    return None


class _InputTransformProcessor:
    """
    This class is used to encapsulate the following functions of an `InputTransform` Object:
    Inside a worker:
        per_sample_transform: Function to transform an individual sample
        collate: Function to merge sample into a batch
        per_batch_transform: Function to transform an individual batch

    Inside main process:
        per_sample_transform_on_device: Function to transform an individual sample
        collate: Function to merge sample into a batch
        per_batch_transform_on_device: Function to transform an individual batch
    """

    def __init__(
        self,
        input_transform: InputTransform,
        collate_fn: Callable,
        per_sample_transform: Callable,
        per_batch_transform: Callable,
        stage: RunningStage,
        apply_per_sample_transform: bool = True,
        on_device: bool = False,
    ):
        super().__init__()
        self.input_transform = input_transform
        # Callbacks registered on the input transform are notified after each processing step.
        self.callback = ControlFlow(self.input_transform.callbacks or [])
        self.collate_fn = collate_fn
        self.per_sample_transform = per_sample_transform
        self.per_batch_transform = per_batch_transform
        self.apply_per_sample_transform = apply_per_sample_transform
        self.stage = stage
        self.on_device = on_device

    def __call__(self, samples: Sequence[Any]) -> Any:
        """Runs per-sample transform, collate and per-batch transform on ``samples``, firing callbacks in between.

        When ``apply_per_sample_transform`` is ``False`` the input is treated as already collated and only the
        per-batch transform is applied.
        """
        stage = self.stage
        device_side = self.on_device

        if not device_side:
            for raw in samples:
                self.callback.on_load_sample(raw, stage)

        if not self.apply_per_sample_transform:
            batch = samples
        else:
            # A single (non-list) input is wrapped so it can be processed like a list of samples.
            sample_list = samples if isinstance(samples, list) else [samples]
            processed = [self.per_sample_transform(item, stage) for item in sample_list]

            for item in processed:
                if device_side:
                    self.callback.on_per_sample_transform_on_device(item, stage)
                else:
                    self.callback.on_per_sample_transform(item, stage)

            batch = self.collate_fn(processed, stage)
            self.callback.on_collate(batch, stage)

        result = self.per_batch_transform(batch, stage)
        if device_side:
            self.callback.on_per_batch_transform_on_device(result, stage)
        else:
            self.callback.on_per_batch_transform(result, stage)
        return result

    def __str__(self) -> str:
        # todo: define repr function which would take object and string attributes to be shown
        lines = (
            "_InputTransformProcessor:",
            f"\t(per_sample_transform): {self.per_sample_transform!s}",
            f"\t(collate_fn): {self.collate_fn!s}",
            f"\t(per_batch_transform): {self.per_batch_transform!s}",
            f"\t(apply_per_sample_transform): {self.apply_per_sample_transform!s}",
            f"\t(on_device): {self.on_device!s}",
            f"\t(stage): {self.stage!s}",
        )
        return "\n".join(lines)


def __make_collates(input_transform: InputTransform, on_device: bool, collate: Callable) -> Tuple[Callable, Callable]:
    """Returns the ``(worker, device)`` pair of collate functions.

    ``collate`` is placed on the device side when ``on_device`` is true and on the worker side otherwise; the other
    side gets the identity function.
    """
    identity = input_transform._identity
    return (identity, collate) if on_device else (collate, identity)


def __configure_worker_and_device_collate_fn(
    running_stage: RunningStage, input_transform: InputTransform
) -> Tuple[Callable, Callable]:
    """Builds the ``(worker, device)`` collate functions for ``running_stage`` from the resolved transforms."""
    per_stage: _InputTransformPerStage = input_transform._transform[running_stage]
    collate_on_device = not per_stage.collate_in_worker
    worker_collate_fn, device_collate_fn = __make_collates(
        input_transform, collate_on_device, input_transform._collate
    )
    return worker_collate_fn, device_collate_fn


def create_worker_input_transform_processor(
    running_stage: RunningStage, input_transform: InputTransform
) -> _InputTransformProcessor:
    """This utility is used to create the `_InputTransformProcessor` object which contains the transforms used as
    the DataLoader `collate_fn`."""
    collate_fns = __configure_worker_and_device_collate_fn(
        running_stage=running_stage, input_transform=input_transform
    )
    worker_collate_fn = collate_fns[0]
    return _InputTransformProcessor(
        input_transform,
        worker_collate_fn,
        input_transform._per_sample_transform,
        input_transform._per_batch_transform,
        running_stage,
    )


def create_device_input_transform_processor(
    running_stage: RunningStage, input_transform: InputTransform
) -> _InputTransformProcessor:
    """This utility is used to create a `_InputTransformProcessor` object which contains the transforms used as the
    DataModule `on_after_batch_transfer` hook."""
    collate_fns = __configure_worker_and_device_collate_fn(
        running_stage=running_stage, input_transform=input_transform
    )
    device_collate_fn = collate_fns[1]
    # Per-sample transforms only run on device when collation was deferred to the device side.
    needs_per_sample = device_collate_fn != input_transform._identity
    return _InputTransformProcessor(
        input_transform,
        device_collate_fn,
        input_transform._per_sample_transform_on_device,
        input_transform._per_batch_transform_on_device,
        running_stage,
        apply_per_sample_transform=needs_per_sample,
        on_device=True,
    )

© Copyright 2020-2021, PyTorch Lightning. Revision a9cedb5a.

Built with Sphinx using a theme provided by Read the Docs.