Source code for flash.core.data.io.input_transform

# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from dataclasses import dataclass
from functools import partial
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple

from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.enums import LightningEnum

from flash.core.data.callback import ControlFlow
from flash.core.data.utilities.collate import default_collate
from flash.core.data.utils import _STAGES_PREFIX
from flash.core.utilities.stages import RunningStage
from flash.core.utilities.types import INPUT_TRANSFORM_TYPE


class InputTransformPlacement(LightningEnum):
    PER_SAMPLE_TRANSFORM = "per_sample_transform"
    PER_BATCH_TRANSFORM = "per_batch_transform"
    COLLATE = "collate"
    PER_SAMPLE_TRANSFORM_ON_DEVICE = "per_sample_transform_on_device"
    PER_BATCH_TRANSFORM_ON_DEVICE = "per_batch_transform_on_device"


INVALID_STAGES_FOR_INPUT_TRANSFORMS = [RunningStage.SANITY_CHECKING, RunningStage.TUNING]


@dataclass
class _InputTransformPerStage:
    collate_in_worker: bool
    transforms: Optional[Dict[str, Callable]] = None


@dataclass
class InputTransform:
    def __post_init__(self):
        self.callbacks: Optional[List] = None

        # used to keep track of provided transforms
        self._transform: Dict[RunningStage, _InputTransformPerStage] = {}

        # For all the stages possible, set/load the transforms.
        for stage in RunningStage:
            if stage not in INVALID_STAGES_FOR_INPUT_TRANSFORMS:
                self._populate_transforms_for_stage(stage)

    def current_transform(self, stage: RunningStage, current_fn: str) -> Callable:
        return self._transform[stage].transforms.get(current_fn, self._identity)

    ########################
    # PER SAMPLE TRANSFORM #
    ########################
    def per_sample_transform(self) -> Callable:
        """Defines the transform to be applied on a single sample on cpu for all stages.

        The input data of the transform would have the following form::

            {
                DataKeys.INPUT: ...,
                DataKeys.TARGET: ...,
                DataKeys.METADATA: ...,
            }

        You would need to use :class:`flash.core.data.transforms.ApplyToKeys` as follows:

        .. code-block:: python

            from flash.core.data.transforms import ApplyToKeys


            class MyInputTransform(InputTransform):
                def per_sample_transform(self) -> Callable:
                    return ApplyToKeys("input", my_func)
        """
        pass
    def train_per_sample_transform(self) -> Callable:
        """Defines the transform to be applied on a single sample on cpu for the training stage.

        The input data of the transform would have the following form::

            {
                DataKeys.INPUT: ...,
                DataKeys.TARGET: ...,
                DataKeys.METADATA: ...,
            }
        """
        return self.per_sample_transform()

    def val_per_sample_transform(self) -> Callable:
        """Defines the transform to be applied on a single sample on cpu for the validating stage.

        The input data of the transform would have the following form::

            {
                DataKeys.INPUT: ...,
                DataKeys.TARGET: ...,
                DataKeys.METADATA: ...,
            }

        You would need to use :class:`flash.core.data.transforms.ApplyToKeys` as follows:

        .. code-block:: python

            from flash.core.data.transforms import ApplyToKeys


            class MyInputTransform(InputTransform):
                def per_sample_transform(self) -> Callable:
                    return ApplyToKeys("input", my_func)
        """
        return self.per_sample_transform()

    def test_per_sample_transform(self) -> Callable:
        """Defines the transform to be applied on a single sample on cpu for the testing stage.

        The input data of the transform would have the following form::

            {
                DataKeys.INPUT: ...,
                DataKeys.TARGET: ...,
                DataKeys.METADATA: ...,
            }
        """
        return self.per_sample_transform()

    def predict_per_sample_transform(self) -> Callable:
        """Defines the transform to be applied on a single sample on cpu for the predicting stage.

        The input data of the transform would have the following form::

            {
                DataKeys.INPUT: ...,
                DataKeys.TARGET: ...,
                DataKeys.METADATA: ...,
            }

        You would need to use :class:`flash.core.data.transforms.ApplyToKeys` as follows:

        .. code-block:: python

            from flash.core.data.transforms import ApplyToKeys


            class MyInputTransform(InputTransform):
                def per_sample_transform(self) -> Callable:
                    return ApplyToKeys("input", my_func)
        """
        return self.per_sample_transform()

    def serve_per_sample_transform(self) -> Callable:
        """Defines the transform to be applied on a single sample on cpu for the serving stage.

        The input data of the transform would have the following form::

            {
                DataKeys.INPUT: ...,
                DataKeys.TARGET: ...,
                DataKeys.METADATA: ...,
            }

        You would need to use :class:`flash.core.data.transforms.ApplyToKeys` as follows:

        .. code-block:: python

            from flash.core.data.transforms import ApplyToKeys


            class MyInputTransform(InputTransform):
                def per_sample_transform(self) -> Callable:
                    return ApplyToKeys("input", my_func)
        """
        return self.per_sample_transform()
    ##################################
    # PER SAMPLE TRANSFORM ON DEVICE #
    ##################################
    def per_sample_transform_on_device(self) -> Callable:
        """Defines the transform to be applied on a single sample on device for all stages.

        The input data of the transform would have the following form::

            {
                DataKeys.INPUT: ...,
                DataKeys.TARGET: ...,
                DataKeys.METADATA: ...,
            }

        You would need to use :class:`flash.core.data.transforms.ApplyToKeys` as follows:

        .. code-block:: python

            from flash.core.data.transforms import ApplyToKeys


            class MyInputTransform(InputTransform):
                def per_sample_transform_on_device(self) -> Callable:
                    return ApplyToKeys("input", my_func)
        """
        pass
    def train_per_sample_transform_on_device(self) -> Callable:
        """Defines the transform to be applied on a single sample on device for the training stage.

        The input data of the transform would have the following form::

            {
                DataKeys.INPUT: ...,
                DataKeys.TARGET: ...,
                DataKeys.METADATA: ...,
            }
        """
        return self.per_sample_transform_on_device()

    def val_per_sample_transform_on_device(self) -> Callable:
        """Defines the transform to be applied on a single sample on device for the validating stage.

        The input data of the transform would have the following form::

            {
                DataKeys.INPUT: ...,
                DataKeys.TARGET: ...,
                DataKeys.METADATA: ...,
            }

        You would need to use :class:`flash.core.data.transforms.ApplyToKeys` as follows:

        .. code-block:: python

            from flash.core.data.transforms import ApplyToKeys


            class MyInputTransform(InputTransform):
                def per_sample_transform_on_device(self) -> Callable:
                    return ApplyToKeys("input", my_func)
        """
        return self.per_sample_transform_on_device()

    def test_per_sample_transform_on_device(self) -> Callable:
        """Defines the transform to be applied on a single sample on device for the testing stage.

        The input data of the transform would have the following form::

            {
                DataKeys.INPUT: ...,
                DataKeys.TARGET: ...,
                DataKeys.METADATA: ...,
            }
        """
        return self.per_sample_transform_on_device()

    def predict_per_sample_transform_on_device(self) -> Callable:
        """Defines the transform to be applied on a single sample on device for the predicting stage.

        The input data of the transform would have the following form::

            {
                DataKeys.INPUT: ...,
                DataKeys.TARGET: ...,
                DataKeys.METADATA: ...,
            }

        You would need to use :class:`flash.core.data.transforms.ApplyToKeys` as follows:

        .. code-block:: python

            from flash.core.data.transforms import ApplyToKeys


            class MyInputTransform(InputTransform):
                def per_sample_transform_on_device(self) -> Callable:
                    return ApplyToKeys("input", my_func)
        """
        return self.per_sample_transform_on_device()

    def serve_per_sample_transform_on_device(self) -> Callable:
        """Defines the transform to be applied on a single sample on device for the serving stage.

        The input data of the transform would have the following form::

            {
                DataKeys.INPUT: ...,
                DataKeys.TARGET: ...,
                DataKeys.METADATA: ...,
            }

        You would need to use :class:`flash.core.data.transforms.ApplyToKeys` as follows:

        .. code-block:: python

            from flash.core.data.transforms import ApplyToKeys


            class MyInputTransform(InputTransform):
                def serve_per_sample_transform_on_device(self) -> Callable:
                    return ApplyToKeys("input", my_func)
        """
        return self.per_sample_transform_on_device()
    #######################
    # PER BATCH TRANSFORM #
    #######################
    def per_batch_transform(self) -> Callable:
        """Defines the transform to be applied on a batch of data on cpu for all stages.

        The input data of the transform would have the following form::

            {
                DataKeys.INPUT: ...,
                DataKeys.TARGET: ...,
                DataKeys.METADATA: ...,
            }

        You would need to use :class:`flash.core.data.transforms.ApplyToKeys` as follows:

        .. code-block:: python

            from flash.core.data.transforms import ApplyToKeys


            class MyInputTransform(InputTransform):
                def per_batch_transform(self) -> Callable:
                    return ApplyToKeys("input", my_func)
        """
        pass
    def train_per_batch_transform(self) -> Callable:
        """Defines the transform to be applied on a batch of data on cpu for the training stage.

        The input data of the transform would have the following form::

            {
                DataKeys.INPUT: ...,
                DataKeys.TARGET: ...,
                DataKeys.METADATA: ...,
            }
        """
        return self.per_batch_transform()

    def val_per_batch_transform(self) -> Callable:
        """Defines the transform to be applied on a batch of data on cpu for the validating stage.

        The input data of the transform would have the following form::

            {
                DataKeys.INPUT: ...,
                DataKeys.TARGET: ...,
                DataKeys.METADATA: ...,
            }

        You would need to use :class:`flash.core.data.transforms.ApplyToKeys` as follows:

        .. code-block:: python

            from flash.core.data.transforms import ApplyToKeys


            class MyInputTransform(InputTransform):
                def per_batch_transform(self) -> Callable:
                    return ApplyToKeys("input", my_func)
        """
        return self.per_batch_transform()

    def test_per_batch_transform(self) -> Callable:
        """Defines the transform to be applied on a batch of data on cpu for the testing stage.

        The input data of the transform would have the following form::

            {
                DataKeys.INPUT: ...,
                DataKeys.TARGET: ...,
                DataKeys.METADATA: ...,
            }
        """
        return self.per_batch_transform()

    def predict_per_batch_transform(self) -> Callable:
        """Defines the transform to be applied on a batch of data on cpu for the predicting stage.

        The input data of the transform would have the following form::

            {
                DataKeys.INPUT: ...,
                DataKeys.TARGET: ...,
                DataKeys.METADATA: ...,
            }

        You would need to use :class:`flash.core.data.transforms.ApplyToKeys` as follows:

        .. code-block:: python

            from flash.core.data.transforms import ApplyToKeys


            class MyInputTransform(InputTransform):
                def per_batch_transform(self) -> Callable:
                    return ApplyToKeys("input", my_func)
        """
        return self.per_batch_transform()

    def serve_per_batch_transform(self) -> Callable:
        """Defines the transform to be applied on a batch of data on cpu for the serving stage.

        The input data of the transform would have the following form::

            {
                DataKeys.INPUT: ...,
                DataKeys.TARGET: ...,
                DataKeys.METADATA: ...,
            }

        You would need to use :class:`flash.core.data.transforms.ApplyToKeys` as follows:

        .. code-block:: python

            from flash.core.data.transforms import ApplyToKeys


            class MyInputTransform(InputTransform):
                def per_batch_transform(self) -> Callable:
                    return ApplyToKeys("input", my_func)
        """
        return self.per_batch_transform()
    #################################
    # PER BATCH TRANSFORM ON DEVICE #
    #################################
    def per_batch_transform_on_device(self) -> Callable:
        """Defines the transform to be applied on a batch of data on device for all stages.

        The input data of the transform would have the following form::

            {
                DataKeys.INPUT: ...,
                DataKeys.TARGET: ...,
                DataKeys.METADATA: ...,
            }

        You would need to use :class:`flash.core.data.transforms.ApplyToKeys` as follows:

        .. code-block:: python

            from flash.core.data.transforms import ApplyToKeys


            class MyInputTransform(InputTransform):
                def per_batch_transform_on_device(self) -> Callable:
                    return ApplyToKeys("input", my_func)
        """
        pass
    def train_per_batch_transform_on_device(self) -> Callable:
        """Defines the transform to be applied on a batch of data on device for the training stage.

        The input data of the transform would have the following form::

            {
                DataKeys.INPUT: ...,
                DataKeys.TARGET: ...,
                DataKeys.METADATA: ...,
            }
        """
        return self.per_batch_transform_on_device()

    def val_per_batch_transform_on_device(self) -> Callable:
        """Defines the transform to be applied on a batch of data on device for the validating stage.

        The input data of the transform would have the following form::

            {
                DataKeys.INPUT: ...,
                DataKeys.TARGET: ...,
                DataKeys.METADATA: ...,
            }

        You would need to use :class:`flash.core.data.transforms.ApplyToKeys` as follows:

        .. code-block:: python

            from flash.core.data.transforms import ApplyToKeys


            class MyInputTransform(InputTransform):
                def per_batch_transform_on_device(self) -> Callable:
                    return ApplyToKeys("input", my_func)
        """
        return self.per_batch_transform_on_device()

    def test_per_batch_transform_on_device(self) -> Callable:
        """Defines the transform to be applied on a batch of data on device for the testing stage.

        The input data of the transform would have the following form::

            {
                DataKeys.INPUT: ...,
                DataKeys.TARGET: ...,
                DataKeys.METADATA: ...,
            }
        """
        return self.per_batch_transform_on_device()

    def predict_per_batch_transform_on_device(self) -> Callable:
        """Defines the transform to be applied on a batch of data on device for the predicting stage.

        The input data of the transform would have the following form::

            {
                DataKeys.INPUT: ...,
                DataKeys.TARGET: ...,
                DataKeys.METADATA: ...,
            }

        You would need to use :class:`flash.core.data.transforms.ApplyToKeys` as follows:

        .. code-block:: python

            from flash.core.data.transforms import ApplyToKeys


            class MyInputTransform(InputTransform):
                def per_batch_transform_on_device(self) -> Callable:
                    return ApplyToKeys("input", my_func)
        """
        return self.per_batch_transform_on_device()

    def serve_per_batch_transform_on_device(self) -> Callable:
        """Defines the transform to be applied on a batch of data on device for the serving stage.

        The input data of the transform would have the following form::

            {
                DataKeys.INPUT: ...,
                DataKeys.TARGET: ...,
                DataKeys.METADATA: ...,
            }

        You would need to use :class:`flash.core.data.transforms.ApplyToKeys` as follows:

        .. code-block:: python

            from flash.core.data.transforms import ApplyToKeys


            class MyInputTransform(InputTransform):
                def serve_per_batch_transform_on_device(self) -> Callable:
                    return ApplyToKeys("input", my_func)
        """
        return self.per_batch_transform_on_device()
    ###########
    # COLLATE #
    ###########

    def collate(self) -> Callable:
        """Defines the transform to be applied on a list of samples to create a batch for all stages."""
        return default_collate

    def train_collate(self) -> Callable:
        """Defines the transform to be applied on a list of training samples to create a training batch."""
        return self.collate()

    def val_collate(self) -> Callable:
        """Defines the transform to be applied on a list of validating samples to create a validating batch."""
        return self.collate()

    def test_collate(self) -> Callable:
        """Defines the transform to be applied on a list of testing samples to create a testing batch."""
        return self.collate()

    def predict_collate(self) -> Callable:
        """Defines the transform to be applied on a list of predicting samples to create a predicting batch."""
        return self.collate()

    def serve_collate(self) -> Callable:
        """Defines the transform to be applied on a list of serving samples to create a serving batch."""
        return self.collate()
    ########################################
    # HOOKS CALLED INTERNALLY WITHIN FLASH #
    ########################################

    def _per_sample_transform(self, sample: Any, stage: RunningStage) -> Any:
        fn = self.current_transform(stage=stage, current_fn="per_sample_transform")
        if isinstance(sample, list):
            return [fn(s) for s in sample]
        return fn(sample)

    def _per_batch_transform(self, batch: Any, stage: RunningStage) -> Any:
        """Transforms to apply to a whole batch (if possible use this for efficiency).

        .. note::

            This option is mutually exclusive with :meth:`per_sample_transform_on_device`, since if both are
            specified, uncollation has to be applied.
        """
        return self.current_transform(stage=stage, current_fn="per_batch_transform")(batch)

    def _collate(self, samples: Sequence, stage: RunningStage) -> Any:
        """Transform to convert a sequence of samples to a collated batch."""
        return self.current_transform(stage=stage, current_fn="collate")(samples)

    def _per_sample_transform_on_device(self, sample: Any, stage: RunningStage) -> Any:
        """Transforms to apply to the data before the collation (per-sample basis).

        .. note::

            This option is mutually exclusive with :meth:`per_batch_transform`, since if both are specified,
            uncollation has to be applied.

        .. note::

            This function won't be called within the dataloader workers, since to make that happen each of the
            workers would have to create its own CUDA context, which would pollute GPU memory (if on GPU).
        """
        fn = self.current_transform(stage=stage, current_fn="per_sample_transform_on_device")
        if isinstance(sample, list):
            return [fn(s) for s in sample]
        return fn(sample)

    def _per_batch_transform_on_device(self, batch: Any, stage: RunningStage) -> Any:
        """Transforms to apply to a whole batch (if possible use this for efficiency).

        .. note::

            This function won't be called within the dataloader workers, since to make that happen each of the
            workers would have to create its own CUDA context, which would pollute GPU memory (if on GPU).
        """
        return self.current_transform(stage=stage, current_fn="per_batch_transform_on_device")(batch)

    #############
    # UTILITIES #
    #############

    def inject_collate_fn(self, collate_fn: Callable):
        # For all the stages possible, set the collate function.
        if collate_fn is not default_collate:
            for stage in RunningStage:
                if stage not in INVALID_STAGES_FOR_INPUT_TRANSFORMS:
                    self._transform[stage].transforms[InputTransformPlacement.COLLATE.value] = collate_fn

    def _populate_transforms_for_stage(self, running_stage: RunningStage):
        transform, collate_in_worker = self.__check_transforms(
            transform=self.__resolve_transforms(running_stage),
        )
        self._transform[running_stage] = _InputTransformPerStage(
            collate_in_worker=collate_in_worker,
            transforms=transform,
        )

    def __resolve_transforms(self, running_stage: RunningStage) -> Optional[Dict[str, Callable]]:
        transforms = {}
        stage = _STAGES_PREFIX[running_stage]

        # iterate over all the transform hook names
        for transform_name in InputTransformPlacement:
            transform_name = transform_name.value
            method_name = f"{stage}_{transform_name}"

            # get the associated transform
            try:
                fn = getattr(self, method_name)()
            except AttributeError as e:
                raise AttributeError(
                    str(e) + ". Make sure you include a call to super().__init__(...) in your __init__ after "
                    "setting all attributes."
                )

            if fn is None:
                continue

            if not callable(fn):
                raise TypeError(f"The hook {method_name} should return a callable.")

            transforms[transform_name] = fn

        return transforms

    def __check_transforms(self, transform: Dict[str, Callable]) -> Tuple[Dict[str, Callable], Optional[bool]]:
        is_per_batch_transform_in = "per_batch_transform" in transform
        is_per_sample_transform_on_device_in = "per_sample_transform_on_device" in transform

        if is_per_batch_transform_in and is_per_sample_transform_on_device_in:
            raise TypeError(
                f"{transform}: `per_batch_transform` and `per_sample_transform_on_device` are mutually exclusive."
            )

        collate_in_worker: Optional[bool] = not is_per_sample_transform_on_device_in
        return transform, collate_in_worker

    @staticmethod
    def _identity(x: Any) -> Any:
        return x

    def __str__(self) -> str:
        return f"{self.__class__.__name__}(transform={self._transform})"
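# A minimal sketch of a custom transform, assuming samples are dictionaries
# keyed by "input" as described in the hook docstrings above. The names
# `ExampleInputTransform` and `_double_input` are illustrative only and are
# not part of this module.


def _double_input(sample: Dict[str, Any]) -> Dict[str, Any]:
    # Hypothetical helper: rescales the "input" entry of a sample dict.
    sample["input"] = sample["input"] * 2
    return sample


@dataclass
class ExampleInputTransform(InputTransform):
    # Only the training stage is affected; every other stage falls back to
    # `per_sample_transform`, which returns None here, so those hooks are
    # skipped during resolution and replaced by the identity at call time.
    def train_per_sample_transform(self) -> Callable:
        return _double_input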
def create_or_configure_input_transform(
    transform: INPUT_TRANSFORM_TYPE,
    transform_kwargs: Optional[Dict] = None,
) -> Optional[InputTransform]:
    if not transform_kwargs:
        transform_kwargs = {}

    if isinstance(transform, InputTransform):
        return transform

    if inspect.isclass(transform) and issubclass(transform, InputTransform):
        # Deprecation Warning
        rank_zero_warn(
            "Please pass an instantiated object of the `InputTransform` class. Passing the Class and keyword"
            " arguments separately has been deprecated since v0.8.0 and will be removed in v0.9.0.",
            stacklevel=8,
            category=FutureWarning,
        )
        return transform(**transform_kwargs)

    if isinstance(transform, partial):
        return transform(**transform_kwargs)

    if not transform:
        return None

    raise ValueError(f"The format for the transform isn't correct. Found {transform}")


class _InputTransformProcessor:
    """This class is used to encapsulate the following functions of an `InputTransform` object:

    Inside a worker:
        per_sample_transform: Function to transform an individual sample.
        collate: Function to merge samples into a batch.
        per_batch_transform: Function to transform an individual batch.

    Inside the main process:
        per_sample_transform_on_device: Function to transform an individual sample.
        collate: Function to merge samples into a batch.
        per_batch_transform_on_device: Function to transform an individual batch.
    """

    def __init__(
        self,
        input_transform: InputTransform,
        collate_fn: Callable,
        per_sample_transform: Callable,
        per_batch_transform: Callable,
        stage: RunningStage,
        apply_per_sample_transform: bool = True,
        on_device: bool = False,
    ):
        super().__init__()
        self.input_transform = input_transform
        self.callback = ControlFlow(self.input_transform.callbacks or [])
        self.collate_fn = collate_fn
        self.per_sample_transform = per_sample_transform
        self.per_batch_transform = per_batch_transform
        self.apply_per_sample_transform = apply_per_sample_transform
        self.stage = stage
        self.on_device = on_device

    def __call__(self, samples: Sequence[Any]) -> Any:
        if not self.on_device:
            for sample in samples:
                self.callback.on_load_sample(sample, self.stage)

        if self.apply_per_sample_transform:
            list_samples = [samples] if not isinstance(samples, list) else samples

            transformed_samples = [self.per_sample_transform(sample, self.stage) for sample in list_samples]

            for sample in transformed_samples:
                if self.on_device:
                    self.callback.on_per_sample_transform_on_device(sample, self.stage)
                else:
                    self.callback.on_per_sample_transform(sample, self.stage)

            collated_samples = self.collate_fn(transformed_samples, self.stage)
            self.callback.on_collate(collated_samples, self.stage)
        else:
            collated_samples = samples

        transformed_collated_samples = self.per_batch_transform(collated_samples, self.stage)
        if self.on_device:
            self.callback.on_per_batch_transform_on_device(transformed_collated_samples, self.stage)
        else:
            self.callback.on_per_batch_transform(transformed_collated_samples, self.stage)
        return transformed_collated_samples

    def __str__(self) -> str:
        # todo: define a repr function which would take the object and the string attributes to be shown
        return (
            "_InputTransformProcessor:\n"
            f"\t(per_sample_transform): {str(self.per_sample_transform)}\n"
            f"\t(collate_fn): {str(self.collate_fn)}\n"
            f"\t(per_batch_transform): {str(self.per_batch_transform)}\n"
            f"\t(apply_per_sample_transform): {str(self.apply_per_sample_transform)}\n"
            f"\t(on_device): {str(self.on_device)}\n"
            f"\t(stage): {str(self.stage)}"
        )


def __make_collates(input_transform: InputTransform, on_device: bool, collate: Callable) -> Tuple[Callable, Callable]:
    """Returns the appropriate collate functions based on whether the transforms happen in a DataLoader worker or
    on the device (main process)."""
    if on_device:
        return input_transform._identity, collate
    return collate, input_transform._identity


def __configure_worker_and_device_collate_fn(
    running_stage: RunningStage, input_transform: InputTransform
) -> Tuple[Callable, Callable]:
    transform_for_stage: _InputTransformPerStage = input_transform._transform[running_stage]

    worker_collate_fn, device_collate_fn = __make_collates(
        input_transform, not transform_for_stage.collate_in_worker, input_transform._collate
    )

    return worker_collate_fn, device_collate_fn


def create_worker_input_transform_processor(
    running_stage: RunningStage, input_transform: InputTransform
) -> _InputTransformProcessor:
    """This utility is used to create the `_InputTransformProcessor` object containing the transforms used as the
    DataLoader `collate_fn`."""
    worker_collate_fn, _ = __configure_worker_and_device_collate_fn(
        running_stage=running_stage, input_transform=input_transform
    )
    return _InputTransformProcessor(
        input_transform,
        worker_collate_fn,
        input_transform._per_sample_transform,
        input_transform._per_batch_transform,
        running_stage,
    )


def create_device_input_transform_processor(
    running_stage: RunningStage, input_transform: InputTransform
) -> _InputTransformProcessor:
    """This utility is used to create a `_InputTransformProcessor` object containing the transforms used as the
    DataModule `on_after_batch_transfer` hook."""
    _, device_collate_fn = __configure_worker_and_device_collate_fn(
        running_stage=running_stage, input_transform=input_transform
    )
    return _InputTransformProcessor(
        input_transform,
        device_collate_fn,
        input_transform._per_sample_transform_on_device,
        input_transform._per_batch_transform_on_device,
        running_stage,
        apply_per_sample_transform=device_collate_fn != input_transform._identity,
        on_device=True,
    )
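# A minimal usage sketch, assuming a list of sample dictionaries as the
# dataset; `train_dataset` and the DataLoader wiring below are illustrative
# and not part of this module. Per the docstrings above, the worker processor
# serves as the DataLoader `collate_fn`, while the device processor backs the
# DataModule `on_after_batch_transfer` hook.
if __name__ == "__main__":
    from torch.utils.data import DataLoader

    input_transform = InputTransform()
    train_dataset = [{"input": float(i), "target": i % 2} for i in range(8)]

    worker_processor = create_worker_input_transform_processor(RunningStage.TRAINING, input_transform)
    device_processor = create_device_input_transform_processor(RunningStage.TRAINING, input_transform)

    loader = DataLoader(train_dataset, batch_size=4, collate_fn=worker_processor)
    for batch in loader:
        # With the default `collate_in_worker=True`, the device processor skips
        # collation and only applies the on-device batch transform (the
        # identity unless overridden).
        batch = device_processor(batch)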
