
Source code for flash.core.data.io.input_transform

# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from dataclasses import dataclass
from functools import partial, wraps
from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Union

from pytorch_lightning.utilities.enums import LightningEnum
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from torch.utils.data._utils.collate import default_collate

from flash.core.data.callback import ControlFlow, FlashCallback
from flash.core.data.io.input import DataKeys
from flash.core.data.properties import Properties
from flash.core.data.transforms import ApplyToKeys
from flash.core.data.utils import _INPUT_TRANSFORM_FUNCS, _STAGES_PREFIX
from flash.core.registry import FlashRegistry
from flash.core.utilities.stages import RunningStage
from flash.core.utilities.types import INPUT_TRANSFORM_TYPE


class InputTransformPlacement(LightningEnum):

    PER_SAMPLE_TRANSFORM = "per_sample_transform"
    PER_BATCH_TRANSFORM = "per_batch_transform"
    COLLATE = "collate"
    PER_SAMPLE_TRANSFORM_ON_DEVICE = "per_sample_transform_on_device"
    PER_BATCH_TRANSFORM_ON_DEVICE = "per_batch_transform_on_device"


class ApplyToKeyPrefix(LightningEnum):

    INPUT = "input"
    TARGET = "target"


def transform_context(func: Callable, current_fn: str) -> Callable:
    """Decorator that records ``current_fn`` on ``self`` while ``func`` runs, so that
    ``InputTransform.current_transform`` can resolve the transform for the active hook."""

    @wraps(func)
    def wrapper(self, *args, **kwargs) -> Any:
        self.current_fn = current_fn
        result = func(self, *args, **kwargs)
        self.current_fn = None
        return result

    return wrapper
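

# A minimal, self-contained sketch (editor's addition, not part of the original
# module) of what ``transform_context`` does: while the wrapped method runs,
# ``self.current_fn`` names the active hook, and it is reset to ``None`` afterwards.
class _ExampleHookOwner:
    current_fn: Optional[str] = None

    @partial(transform_context, current_fn="per_sample_transform")
    def run(self, x: Any) -> Tuple[Optional[str], Any]:
        # ``current_fn`` is set by the decorator for the duration of this call.
        return self.current_fn, x


assert _ExampleHookOwner().run(1) == ("per_sample_transform", 1)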


# Credit to Torchvision Team:
# https://pytorch.org/vision/stable/_modules/torchvision/transforms/transforms.html#Compose
class Compose:
    """Composes several transforms together.

    This transform does not support torchscript.
    """

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, x):
        for t in self.transforms:
            x = t(x)
        return x

    def __repr__(self):
        format_string = self.__class__.__name__ + "("
        for t in self.transforms:
            format_string += "\n"
            format_string += f"{t}"
        format_string += "\n)"
        return format_string
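
# A quick usage sketch (editor's addition, not part of the original module):
# ``Compose`` chains callables left to right, feeding each transform's output
# into the next one.
_example_pipeline = Compose([lambda x: x * 2, lambda x: x - 1])
assert _example_pipeline(3) == 5  # (3 * 2) - 1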


@dataclass
class InputTransform(Properties):

    running_stage: RunningStage

    def __post_init__(self):
        # used to keep track of provided transforms
        self._collate_in_worker_from_transform: Optional[bool] = None
        self._transform = None
        self._transform = self._check_transforms(self._resolve_transforms(self.running_stage), self.running_stage)

        # Hack
        Properties.__init__(self, running_stage=self.running_stage)

    @property
    def current_transform(self) -> Callable:
        if self._transform:
            return self._get_transform(self._transform)
        return self._identity

    @property
    def transforms(self) -> Dict[str, Optional[Dict[str, Callable]]]:
        """The transforms currently being used by this
        :class:`~flash.core.data.io.input_transform.InputTransform`."""
        return {
            "transform": self._transform,
        }

    ########################
    # PER SAMPLE TRANSFORM #
    ########################
    def per_sample_transform(self) -> Callable:
        """Defines the transform to be applied on a single sample on cpu for all stages.

        The input data of the transform would have the following form::

            {
                DataKeys.INPUT: ...,
                DataKeys.TARGET: ...,
                DataKeys.METADATA: ...,
            }

        You would need to use :class:`flash.core.data.transforms.ApplyToKeys` as follows:

        .. code-block:: python

            from flash.core.data.transforms import ApplyToKeys


            class MyInputTransform(InputTransform):
                def per_sample_transform(self) -> Callable:
                    return ApplyToKeys("input", my_func)
        """
        return self._identity

    def input_per_sample_transform(self) -> Callable:
        """Defines the transform to be applied on the value associated with the "input" key of each sample on cpu
        for all stages."""
        return self._identity

    def target_per_sample_transform(self) -> Callable:
        """Defines the transform to be applied on the value associated with the "target" key of each sample on cpu
        for all stages."""
        return self._identity

    def train_per_sample_transform(self) -> Callable:
        """Defines the transform to be applied on a single sample on cpu for the training stage.

        The input data of the transform would have the following form::

            {
                DataKeys.INPUT: ...,
                DataKeys.TARGET: ...,
                DataKeys.METADATA: ...,
            }
        """
        return self._identity

    def train_input_per_sample_transform(self) -> Callable:
        """Defines the transform to be applied on the value associated with the "input" key of each sample on cpu
        for the training stage."""
        return self._identity

    def train_target_per_sample_transform(self) -> Callable:
        """Defines the transform to be applied on the value associated with the "target" key of each sample on cpu
        for the training stage."""
        return self._identity

    def val_per_sample_transform(self) -> Callable:
        """Defines the transform to be applied on a single sample on cpu for the validating stage.

        The input data of the transform would have the following form::

            {
                DataKeys.INPUT: ...,
                DataKeys.TARGET: ...,
                DataKeys.METADATA: ...,
            }

        You would need to use :class:`flash.core.data.transforms.ApplyToKeys` as follows:

        .. code-block:: python

            from flash.core.data.transforms import ApplyToKeys


            class MyInputTransform(InputTransform):
                def per_sample_transform(self) -> Callable:
                    return ApplyToKeys("input", my_func)
        """
        return self._identity

    def val_input_per_sample_transform(self) -> Callable:
        """Defines the transform to be applied on the value associated with the "input" key of each sample on cpu
        for the validating stage."""
        return self._identity

    def val_target_per_sample_transform(self) -> Callable:
        """Defines the transform to be applied on the value associated with the "target" key of each sample on cpu
        for the validating stage."""
        return self._identity

    def test_per_sample_transform(self) -> Callable:
        """Defines the transform to be applied on a single sample on cpu for the testing stage.

        The input data of the transform would have the following form::

            {
                DataKeys.INPUT: ...,
                DataKeys.TARGET: ...,
                DataKeys.METADATA: ...,
            }
        """
        return self._identity

    def test_input_per_sample_transform(self) -> Callable:
        """Defines the transform to be applied on the value associated with the "input" key of each sample on cpu
        for the testing stage."""
        return self._identity

    def test_target_per_sample_transform(self) -> Callable:
        """Defines the transform to be applied on the value associated with the "target" key of each sample on cpu
        for the testing stage."""
        return self._identity

    def predict_per_sample_transform(self) -> Callable:
        """Defines the transform to be applied on a single sample on cpu for the predicting stage.

        The input data of the transform would have the following form::

            {
                DataKeys.INPUT: ...,
                DataKeys.TARGET: ...,
                DataKeys.METADATA: ...,
            }

        You would need to use :class:`flash.core.data.transforms.ApplyToKeys` as follows:

        .. code-block:: python

            from flash.core.data.transforms import ApplyToKeys


            class MyInputTransform(InputTransform):
                def per_sample_transform(self) -> Callable:
                    return ApplyToKeys("input", my_func)
        """
        return self._identity

    def predict_input_per_sample_transform(self) -> Callable:
        """Defines the transform to be applied on the value associated with the "input" key of each sample on cpu
        for the predicting stage."""
        return self._identity

    def predict_target_per_sample_transform(self) -> Callable:
        """Defines the transform to be applied on the value associated with the "target" key of each sample on cpu
        for the predicting stage."""
        return self._identity

    def serve_per_sample_transform(self) -> Callable:
        """Defines the transform to be applied on a single sample on cpu for the serving stage.

        The input data of the transform would have the following form::

            {
                DataKeys.INPUT: ...,
                DataKeys.TARGET: ...,
                DataKeys.METADATA: ...,
            }

        You would need to use :class:`flash.core.data.transforms.ApplyToKeys` as follows:

        .. code-block:: python

            from flash.core.data.transforms import ApplyToKeys


            class MyInputTransform(InputTransform):
                def per_sample_transform(self) -> Callable:
                    return ApplyToKeys("input", my_func)
        """
        return self._identity

    def serve_input_per_sample_transform(self) -> Callable:
        """Defines the transform to be applied on the value associated with the "input" key of each sample on cpu
        for the serving stage."""
        return self._identity

    def serve_target_per_sample_transform(self) -> Callable:
        """Defines the transform to be applied on the value associated with the "target" key of each sample on cpu
        for the serving stage."""
        return self._identity
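    # Editor's note (added example, not in the original source): overriding a
    # stage- and key-specialized hook. Flash wraps ``input_``/``target_`` prefixed
    # hooks in ``ApplyToKeys`` automatically, so returning a bare callable is
    # enough; ``my_input_fn`` below is hypothetical.
    #
    #     class MyInputTransform(InputTransform):
    #         def train_input_per_sample_transform(self) -> Callable:
    #             return my_input_fn  # applied only to sample[DataKeys.INPUT] while training
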
    ##################################
    # PER SAMPLE TRANSFORM ON DEVICE #
    ##################################

    def per_sample_transform_on_device(self) -> Callable:
        """Defines the transform to be applied on a single sample on device for all stages.

        The input data of the transform would have the following form::

            {
                DataKeys.INPUT: ...,
                DataKeys.TARGET: ...,
                DataKeys.METADATA: ...,
            }

        You would need to use :class:`flash.core.data.transforms.ApplyToKeys` as follows:

        .. code-block:: python

            from flash.core.data.transforms import ApplyToKeys


            class MyInputTransform(InputTransform):
                def per_sample_transform_on_device(self) -> Callable:
                    return ApplyToKeys("input", my_func)
        """
        return self._identity

    def input_per_sample_transform_on_device(self) -> Callable:
        """Defines the transform to be applied on the value associated with the "input" key of each sample on device
        for all stages."""
        return self._identity

    def target_per_sample_transform_on_device(self) -> Callable:
        """Defines the transform to be applied on the value associated with the "target" key of each sample on device
        for all stages."""
        return self._identity

    def train_per_sample_transform_on_device(self) -> Callable:
        """Defines the transform to be applied on a single sample on device for the training stage.

        The input data of the transform would have the following form::

            {
                DataKeys.INPUT: ...,
                DataKeys.TARGET: ...,
                DataKeys.METADATA: ...,
            }
        """
        return self._identity

    def train_input_per_sample_transform_on_device(self) -> Callable:
        """Defines the transform to be applied on the value associated with the "input" key of each sample on device
        for the training stage."""
        return self._identity

    def train_target_per_sample_transform_on_device(self) -> Callable:
        """Defines the transform to be applied on the value associated with the "target" key of each sample on device
        for the training stage."""
        return self._identity

    def val_per_sample_transform_on_device(self) -> Callable:
        """Defines the transform to be applied on a single sample on device for the validating stage.

        The input data of the transform would have the following form::

            {
                DataKeys.INPUT: ...,
                DataKeys.TARGET: ...,
                DataKeys.METADATA: ...,
            }

        You would need to use :class:`flash.core.data.transforms.ApplyToKeys` as follows:

        .. code-block:: python

            from flash.core.data.transforms import ApplyToKeys


            class MyInputTransform(InputTransform):
                def per_sample_transform_on_device(self) -> Callable:
                    return ApplyToKeys("input", my_func)
        """
        return self._identity

    def val_input_per_sample_transform_on_device(self) -> Callable:
        """Defines the transform to be applied on the value associated with the "input" key of each sample on device
        for the validating stage."""
        return self._identity

    def val_target_per_sample_transform_on_device(self) -> Callable:
        """Defines the transform to be applied on the value associated with the "target" key of each sample on device
        for the validating stage."""
        return self._identity

    def test_per_sample_transform_on_device(self) -> Callable:
        """Defines the transform to be applied on a single sample on device for the testing stage.

        The input data of the transform would have the following form::

            {
                DataKeys.INPUT: ...,
                DataKeys.TARGET: ...,
                DataKeys.METADATA: ...,
            }
        """
        return self._identity

    def test_input_per_sample_transform_on_device(self) -> Callable:
        """Defines the transform to be applied on the value associated with the "input" key of each sample on device
        for the testing stage."""
        return self._identity

    def test_target_per_sample_transform_on_device(self) -> Callable:
        """Defines the transform to be applied on the value associated with the "target" key of each sample on device
        for the testing stage."""
        return self._identity

    def predict_per_sample_transform_on_device(self) -> Callable:
        """Defines the transform to be applied on a single sample on device for the predicting stage.

        The input data of the transform would have the following form::

            {
                DataKeys.INPUT: ...,
                DataKeys.TARGET: ...,
                DataKeys.METADATA: ...,
            }

        You would need to use :class:`flash.core.data.transforms.ApplyToKeys` as follows:

        .. code-block:: python

            from flash.core.data.transforms import ApplyToKeys


            class MyInputTransform(InputTransform):
                def per_sample_transform_on_device(self) -> Callable:
                    return ApplyToKeys("input", my_func)
        """
        return self._identity

    def predict_input_per_sample_transform_on_device(self) -> Callable:
        """Defines the transform to be applied on the value associated with the "input" key of each sample on device
        for the predicting stage."""
        return self._identity

    def predict_target_per_sample_transform_on_device(self) -> Callable:
        """Defines the transform to be applied on the value associated with the "target" key of each sample on device
        for the predicting stage."""
        return self._identity
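    # Editor's note (added, not in the original source): the ``*_on_device`` hooks
    # run in the main process after the batch reaches the accelerator, never inside
    # DataLoader workers. A hypothetical GPU-side augmentation could look like:
    #
    #     class MyInputTransform(InputTransform):
    #         def train_per_sample_transform_on_device(self) -> Callable:
    #             return ApplyToKeys(DataKeys.INPUT, my_gpu_augmentation)
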
    #######################
    # PER BATCH TRANSFORM #
    #######################

    def per_batch_transform(self) -> Callable:
        """Defines the transform to be applied on a batch of data on cpu for all stages.

        The input data of the transform would have the following form::

            {
                DataKeys.INPUT: ...,
                DataKeys.TARGET: ...,
                DataKeys.METADATA: ...,
            }

        You would need to use :class:`flash.core.data.transforms.ApplyToKeys` as follows:

        .. code-block:: python

            from flash.core.data.transforms import ApplyToKeys


            class MyInputTransform(InputTransform):
                def per_batch_transform(self) -> Callable:
                    return ApplyToKeys("input", my_func)
        """
        return self._identity

    def input_per_batch_transform(self) -> Callable:
        """Defines the transform to be applied on the value associated with the "input" key of the batch on cpu
        for all stages."""
        return self._identity

    def target_per_batch_transform(self) -> Callable:
        """Defines the transform to be applied on the value associated with the "target" key of the batch on cpu
        for all stages."""
        return self._identity

    def train_per_batch_transform(self) -> Callable:
        """Defines the transform to be applied on a batch of data on cpu for the training stage.

        The input data of the transform would have the following form::

            {
                DataKeys.INPUT: ...,
                DataKeys.TARGET: ...,
                DataKeys.METADATA: ...,
            }
        """
        return self._identity

    def train_input_per_batch_transform(self) -> Callable:
        """Defines the transform to be applied on the value associated with the "input" key of the batch on cpu
        for the training stage."""
        return self._identity

    def train_target_per_batch_transform(self) -> Callable:
        """Defines the transform to be applied on the value associated with the "target" key of the batch on cpu
        for the training stage."""
        return self._identity

    def val_per_batch_transform(self) -> Callable:
        """Defines the transform to be applied on a batch of data on cpu for the validating stage.

        The input data of the transform would have the following form::

            {
                DataKeys.INPUT: ...,
                DataKeys.TARGET: ...,
                DataKeys.METADATA: ...,
            }

        You would need to use :class:`flash.core.data.transforms.ApplyToKeys` as follows:

        .. code-block:: python

            from flash.core.data.transforms import ApplyToKeys


            class MyInputTransform(InputTransform):
                def per_batch_transform(self) -> Callable:
                    return ApplyToKeys("input", my_func)
        """
        return self._identity

    def val_input_per_batch_transform(self) -> Callable:
        """Defines the transform to be applied on the value associated with the "input" key of the batch on cpu
        for the validating stage."""
        return self._identity

    def val_target_per_batch_transform(self) -> Callable:
        """Defines the transform to be applied on the value associated with the "target" key of the batch on cpu
        for the validating stage."""
        return self._identity

    def test_per_batch_transform(self) -> Callable:
        """Defines the transform to be applied on a batch of data on cpu for the testing stage.

        The input data of the transform would have the following form::

            {
                DataKeys.INPUT: ...,
                DataKeys.TARGET: ...,
                DataKeys.METADATA: ...,
            }
        """
        return self._identity

    def test_input_per_batch_transform(self) -> Callable:
        """Defines the transform to be applied on the value associated with the "input" key of the batch on cpu
        for the testing stage."""
        return self._identity

    def test_target_per_batch_transform(self) -> Callable:
        """Defines the transform to be applied on the value associated with the "target" key of the batch on cpu
        for the testing stage."""
        return self._identity

    def predict_per_batch_transform(self) -> Callable:
        """Defines the transform to be applied on a batch of data on cpu for the predicting stage.

        The input data of the transform would have the following form::

            {
                DataKeys.INPUT: ...,
                DataKeys.TARGET: ...,
                DataKeys.METADATA: ...,
            }

        You would need to use :class:`flash.core.data.transforms.ApplyToKeys` as follows:

        .. code-block:: python

            from flash.core.data.transforms import ApplyToKeys


            class MyInputTransform(InputTransform):
                def per_batch_transform(self) -> Callable:
                    return ApplyToKeys("input", my_func)
        """
        return self._identity

    def predict_input_per_batch_transform(self) -> Callable:
        """Defines the transform to be applied on the value associated with the "input" key of the batch on cpu
        for the predicting stage."""
        return self._identity

    def predict_target_per_batch_transform(self) -> Callable:
        """Defines the transform to be applied on the value associated with the "target" key of the batch on cpu
        for the predicting stage."""
        return self._identity

    def serve_per_batch_transform(self) -> Callable:
        """Defines the transform to be applied on a batch of data on cpu for the serving stage.

        The input data of the transform would have the following form::

            {
                DataKeys.INPUT: ...,
                DataKeys.TARGET: ...,
                DataKeys.METADATA: ...,
            }

        You would need to use :class:`flash.core.data.transforms.ApplyToKeys` as follows:

        .. code-block:: python

            from flash.core.data.transforms import ApplyToKeys


            class MyInputTransform(InputTransform):
                def per_batch_transform(self) -> Callable:
                    return ApplyToKeys("input", my_func)
        """
        return self._identity

    def serve_input_per_batch_transform(self) -> Callable:
        """Defines the transform to be applied on the value associated with the "input" key of the batch on cpu
        for the serving stage."""
        return self._identity

    def serve_target_per_batch_transform(self) -> Callable:
        """Defines the transform to be applied on the value associated with the "target" key of the batch on cpu
        for the serving stage."""
        return self._identity
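    # Editor's note (added, not in the original source): per-batch hooks receive
    # the collated batch, so a transform can be vectorized over the whole batch at
    # once instead of sample by sample. ``normalize_batch`` is hypothetical:
    #
    #     class MyInputTransform(InputTransform):
    #         def per_batch_transform(self) -> Callable:
    #             return ApplyToKeys(DataKeys.INPUT, normalize_batch)
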
    #################################
    # PER BATCH TRANSFORM ON DEVICE #
    #################################

    def per_batch_transform_on_device(self) -> Callable:
        """Defines the transform to be applied on a batch of data on device for all stages.

        The input data of the transform would have the following form::

            {
                DataKeys.INPUT: ...,
                DataKeys.TARGET: ...,
                DataKeys.METADATA: ...,
            }

        You would need to use :class:`flash.core.data.transforms.ApplyToKeys` as follows:

        .. code-block:: python

            from flash.core.data.transforms import ApplyToKeys


            class MyInputTransform(InputTransform):
                def per_batch_transform_on_device(self) -> Callable:
                    return ApplyToKeys("input", my_func)
        """
        return self._identity

    def input_per_batch_transform_on_device(self) -> Callable:
        """Defines the transform to be applied on the value associated with the "input" key of the batch on device
        for all stages."""
        return self._identity

    def target_per_batch_transform_on_device(self) -> Callable:
        """Defines the transform to be applied on the value associated with the "target" key of the batch on device
        for all stages."""
        return self._identity

    def train_per_batch_transform_on_device(self) -> Callable:
        """Defines the transform to be applied on a batch of data on device for the training stage.

        The input data of the transform would have the following form::

            {
                DataKeys.INPUT: ...,
                DataKeys.TARGET: ...,
                DataKeys.METADATA: ...,
            }
        """
        return self._identity

    def train_input_per_batch_transform_on_device(self) -> Callable:
        """Defines the transform to be applied on the value associated with the "input" key of the batch on device
        for the training stage."""
        return self._identity

    def train_target_per_batch_transform_on_device(self) -> Callable:
        """Defines the transform to be applied on the value associated with the "target" key of the batch on device
        for the training stage."""
        return self._identity

    def val_per_batch_transform_on_device(self) -> Callable:
        """Defines the transform to be applied on a batch of data on device for the validating stage.

        The input data of the transform would have the following form::

            {
                DataKeys.INPUT: ...,
                DataKeys.TARGET: ...,
                DataKeys.METADATA: ...,
            }

        You would need to use :class:`flash.core.data.transforms.ApplyToKeys` as follows:

        .. code-block:: python

            from flash.core.data.transforms import ApplyToKeys


            class MyInputTransform(InputTransform):
                def per_batch_transform_on_device(self) -> Callable:
                    return ApplyToKeys("input", my_func)
        """
        return self._identity

    def val_input_per_batch_transform_on_device(self) -> Callable:
        """Defines the transform to be applied on the value associated with the "input" key of the batch on device
        for the validating stage."""
        return self._identity

    def val_target_per_batch_transform_on_device(self) -> Callable:
        """Defines the transform to be applied on the value associated with the "target" key of the batch on device
        for the validating stage."""
        return self._identity

    def test_per_batch_transform_on_device(self) -> Callable:
        """Defines the transform to be applied on a batch of data on device for the testing stage.

        The input data of the transform would have the following form::

            {
                DataKeys.INPUT: ...,
                DataKeys.TARGET: ...,
                DataKeys.METADATA: ...,
            }
        """
        return self._identity

    def test_input_per_batch_transform_on_device(self) -> Callable:
        """Defines the transform to be applied on the value associated with the "input" key of the batch on device
        for the testing stage."""
        return self._identity

    def test_target_per_batch_transform_on_device(self) -> Callable:
        """Defines the transform to be applied on the value associated with the "target" key of the batch on device
        for the testing stage."""
        return self._identity

    def predict_per_batch_transform_on_device(self) -> Callable:
        """Defines the transform to be applied on a batch of data on device for the predicting stage.

        The input data of the transform would have the following form::

            {
                DataKeys.INPUT: ...,
                DataKeys.TARGET: ...,
                DataKeys.METADATA: ...,
            }

        You would need to use :class:`flash.core.data.transforms.ApplyToKeys` as follows:

        .. code-block:: python

            from flash.core.data.transforms import ApplyToKeys


            class MyInputTransform(InputTransform):
                def per_batch_transform_on_device(self) -> Callable:
                    return ApplyToKeys("input", my_func)
        """
        return self._identity

    def predict_input_per_batch_transform_on_device(self) -> Callable:
        """Defines the transform to be applied on the value associated with the "input" key of the batch on device
        for the predicting stage."""
        return self._identity

    def predict_target_per_batch_transform_on_device(self) -> Callable:
        """Defines the transform to be applied on the value associated with the "target" key of the batch on device
        for the predicting stage."""
        return self._identity
    ###########
    # COLLATE #
    ###########

    def train_collate(self) -> Callable:
        """Defines the transform to be applied on a list of training samples to create a training batch."""
        return default_collate

    def val_collate(self) -> Callable:
        """Defines the transform to be applied on a list of validating samples to create a validating batch."""
        return default_collate

    def test_collate(self) -> Callable:
        """Defines the transform to be applied on a list of testing samples to create a testing batch."""
        return default_collate

    def predict_collate(self) -> Callable:
        """Defines the transform to be applied on a list of predicting samples to create a predicting batch."""
        return default_collate

    def serve_collate(self) -> Callable:
        """Defines the transform to be applied on a list of serving samples to create a serving batch."""
        return default_collate

    def collate(self) -> Callable:
        """Defines the transform to be applied on a list of samples to create a batch for all stages."""
        return default_collate
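    # Editor's note (added, not in the original source): ``collate`` can be
    # overridden for variable-length samples; ``pad_collate`` is a hypothetical
    # function that pads each sample's "input" tensor to the longest in the list:
    #
    #     class MyInputTransform(InputTransform):
    #         def collate(self) -> Callable:
    #             return pad_collate
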
    ########################################
    # HOOKS CALLED INTERNALLY WITHIN FLASH #
    ########################################

    @partial(transform_context, current_fn="per_sample_transform")
    def _per_sample_transform(self, sample: Any) -> Any:
        fn = self.current_transform
        if isinstance(sample, list):
            return [fn(s) for s in sample]
        return fn(sample)

    @partial(transform_context, current_fn="per_batch_transform")
    def _per_batch_transform(self, batch: Any) -> Any:
        """Transforms to apply to a whole batch (if possible use this for efficiency).

        .. note::

            This option is mutually exclusive with :meth:`per_sample_transform_on_device`, since if both are
            specified, uncollation has to be applied.
        """
        return self.current_transform(batch)

    @partial(transform_context, current_fn="collate")
    def _collate(self, samples: Sequence, metadata=None) -> Any:
        """Transform to convert a sequence of samples to a collated batch."""
        collate_fn = self.current_transform
        parameters = inspect.signature(collate_fn).parameters
        if len(parameters) > 1 and DataKeys.METADATA in parameters:
            return collate_fn(samples, metadata)
        return collate_fn(samples)

    @partial(transform_context, current_fn="per_sample_transform_on_device")
    def _per_sample_transform_on_device(self, sample: Any) -> Any:
        """Transforms to apply to the data before the collation (per-sample basis).

        .. note::

            This option is mutually exclusive with :meth:`per_batch_transform`, since if both are specified,
            uncollation has to be applied.

        .. note::

            This function won't be called within the dataloader workers, since to make that happen each of the
            workers would have to create its own CUDA context which would pollute GPU memory (if on GPU).
        """
        fn = self.current_transform
        if isinstance(sample, list):
            return [fn(s) for s in sample]
        return fn(sample)

    @partial(transform_context, current_fn="per_batch_transform_on_device")
    def _per_batch_transform_on_device(self, batch: Any) -> Any:
        """Transforms to apply to a whole batch (if possible use this for efficiency).

        .. note::

            This function won't be called within the dataloader workers, since to make that happen each of the
            workers would have to create its own CUDA context which would pollute GPU memory (if on GPU).
        """
        return self.current_transform(batch)

    #############
    # UTILITIES #
    #############

    def _resolve_transforms(self, running_stage: RunningStage) -> Optional[Dict[str, Callable]]:
        from flash.core.data.data_pipeline import DataPipeline

        transforms_out = {}
        stage = _STAGES_PREFIX[running_stage]

        # iterate over all transform hook names
        for transform_name in InputTransformPlacement:

            transforms = {}
            transform_name = transform_name.value

            # iterate over all prefixes
            for key in ApplyToKeyPrefix:

                # get the resolved hook name based on the current stage
                resolved_name = DataPipeline._resolve_function_hierarchy(
                    transform_name, self, running_stage, InputTransform
                )
                # check if the hook name is specialized
                is_specialized_name = resolved_name.startswith(stage)

                # get the resolved hook name for apply to key on the current stage
                resolved_apply_to_key_name = DataPipeline._resolve_function_hierarchy(
                    f"{key}_{transform_name}", self, running_stage, InputTransform
                )
                # check if resolved hook name for apply to key is specialized
                is_specialized_apply_to_key_name = resolved_apply_to_key_name.startswith(stage)

                # check if they are overridden by the user
                resolve_name_overridden = DataPipeline._is_overridden(resolved_name, self, InputTransform)
                resolved_apply_to_key_name_overridden = DataPipeline._is_overridden(
                    resolved_apply_to_key_name, self, InputTransform
                )

                if resolve_name_overridden and resolved_apply_to_key_name_overridden:
                    # specialized hook names take priority, so if both or neither of the
                    # overridden hooks are specialized, raise an exception.
                    if not (is_specialized_name ^ is_specialized_apply_to_key_name):
                        raise MisconfigurationException(
                            f"Only one of {resolved_name} or {resolved_apply_to_key_name} can be overridden."
                        )

                    method_name = resolved_name if is_specialized_name else resolved_apply_to_key_name
                else:
                    method_name = (
                        resolved_apply_to_key_name if resolved_apply_to_key_name_overridden else resolved_name
                    )

                # get associated transform
                try:
                    fn = getattr(self, method_name)()
                except AttributeError as e:
                    raise AttributeError(str(e) + ". Hint: Call super().__init__(...) after setting all attributes.")

                if not callable(fn):
                    raise MisconfigurationException(f"The hook {method_name} should return a function.")

                # if the default hook is used, it should return identity, skip it.
                if fn is self._identity:
                    continue

                # wrap apply to key hook into `ApplyToKeys` with the associated key.
                if method_name == resolved_apply_to_key_name:
                    fn = ApplyToKeys(key.value, fn)

                if method_name not in transforms:
                    transforms[method_name] = fn

            # store the transforms.
            if transforms:
                transforms = list(transforms.values())
                transforms_out[transform_name] = Compose(transforms) if len(transforms) > 1 else transforms[0]

        return transforms_out

    def _check_transforms(
        self, transform: Optional[Dict[str, Callable]], stage: RunningStage
    ) -> Optional[Dict[str, Callable]]:
        if transform is None:
            return transform

        keys_diff = set(transform.keys()).difference([v.value for v in InputTransformPlacement])

        if len(keys_diff) > 0:
            raise MisconfigurationException(
                f"{stage}_transform contains {keys_diff}. Only {_INPUT_TRANSFORM_FUNCS} keys are supported."
            )

        is_per_batch_transform_in = "per_batch_transform" in transform
        is_per_sample_transform_on_device_in = "per_sample_transform_on_device" in transform

        if is_per_batch_transform_in and is_per_sample_transform_on_device_in:
            raise MisconfigurationException(
                f"{transform}: `per_batch_transform` and `per_sample_transform_on_device` are mutually exclusive."
            )

        collate_in_worker: Optional[bool] = None

        if is_per_batch_transform_in or (not is_per_batch_transform_in and not is_per_sample_transform_on_device_in):
            collate_in_worker = True

        elif is_per_sample_transform_on_device_in:
            collate_in_worker = False

        self._collate_in_worker_from_transform = collate_in_worker
        return transform

    @staticmethod
    def _identity(x: Any) -> Any:
        return x

    def _get_transform(self, transform: Dict[str, Callable]) -> Callable:
        if self.current_fn in transform:
            return transform[self.current_fn]
        return self._identity

    def __str__(self) -> str:
        return f"{self.__class__.__name__}(" + f"running_stage={self.running_stage}, transform={self._transform})"

    def __getitem__(self, placement: InputTransformPlacement) -> Callable:
        return self._transform[placement]
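

# Editor's sketch (not part of the original module): a minimal concrete subclass.
# The generic ``per_sample_transform`` hook applies to the full sample dict at
# every stage, so ``ApplyToKeys`` restricts it to the "input" value; the negation
# transform is illustrative only.
class _ExampleInputTransform(InputTransform):
    def per_sample_transform(self) -> Callable:
        return ApplyToKeys(DataKeys.INPUT, lambda x: -x)


# Hedged usage (instantiation resolves the hooks for the given stage):
#
#     transform = _ExampleInputTransform(running_stage=RunningStage.TRAINING)
#     transform["per_sample_transform"]  # the resolved ApplyToKeys transform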


@dataclass
class LambdaInputTransform(InputTransform):

    transform: Callable = InputTransform._identity

    def per_sample_transform(self) -> Callable:
        return self.transform


def _sanitize_registry_transform(
    transform: Tuple[Union[LightningEnum, str], Any], input_transforms_registry: Optional[FlashRegistry]
) -> Tuple[Union[LightningEnum, str], Dict]:
    msg = "The transform should be provided as a tuple with the following types (LightningEnum, Dict[str, Any]) "
    msg += "when requesting transform from the registry."
    if not input_transforms_registry:
        raise MisconfigurationException("You requested a transform from the registry, but it is empty.")
    if isinstance(transform, tuple) and len(transform) > 2:
        raise MisconfigurationException(msg)
    if isinstance(transform, (LightningEnum, str)):
        enum = transform
        transform_kwargs = {}
    else:
        enum, transform_kwargs = transform
    if not isinstance(enum, (LightningEnum, str)):
        raise MisconfigurationException(msg)
    if not isinstance(transform_kwargs, Dict):
        raise MisconfigurationException(msg)
    return enum, transform_kwargs


def create_transform(
    transform: INPUT_TRANSFORM_TYPE,
    running_stage: RunningStage,
    input_transforms_registry: Optional[FlashRegistry] = None,
    transform_kwargs: Optional[Dict] = None,
) -> Optional["InputTransform"]:

    if not transform_kwargs:
        transform_kwargs = {}

    if isinstance(transform, InputTransform):
        return transform

    if inspect.isclass(transform) and issubclass(transform, InputTransform):
        return transform(running_stage=running_stage, **transform_kwargs)

    if isinstance(transform, partial):
        return transform(running_stage=running_stage, **transform_kwargs)

    if isinstance(transform, Callable):
        return LambdaInputTransform(
            running_stage=running_stage,
            transform=transform,
            **transform_kwargs,
        )

    if isinstance(transform, (tuple, LightningEnum, str)):
        enum, transform_kwargs = _sanitize_registry_transform(transform, input_transforms_registry)
        transform_cls = input_transforms_registry.get(enum)
        return transform_cls(running_stage, **transform_kwargs)

    if not transform:
        return None

    raise MisconfigurationException(f"The format for the transform isn't correct. Found {transform}")


def _make_collates(input_transform: "InputTransform", on_device: bool, collate: Callable) -> Tuple[Callable, Callable]:
    if on_device:
        return input_transform._identity, collate
    return collate, input_transform._identity


class _InputTransformProcessorV2:
    """This class is used to encapsulate the following functions of an ``InputTransform`` object:

    Inside a worker:
        per_sample_transform: Function to transform an individual sample.
        collate: Function to merge samples into a batch.
        per_batch_transform: Function to transform an individual batch.

    Inside the main process:
        per_sample_transform_on_device: Function to transform an individual sample.
        collate: Function to merge samples into a batch.
        per_batch_transform_on_device: Function to transform an individual batch.
    """

    def __init__(
        self,
        input_transform: InputTransform,
        collate_fn: Callable,
        per_sample_transform: Callable,
        per_batch_transform: Callable,
        stage: RunningStage,
        apply_per_sample_transform: bool = True,
        on_device: bool = False,
        callbacks: Optional[List[FlashCallback]] = None,
    ):
        super().__init__()
        self.input_transform = input_transform
        self.callback = ControlFlow(callbacks or [])
        self.collate_fn = collate_fn
        self.per_sample_transform = per_sample_transform
        self.per_batch_transform = per_batch_transform
        self.apply_per_sample_transform = apply_per_sample_transform
        self.stage = stage
        self.on_device = on_device

    @staticmethod
    def _extract_metadata(
        samples: List[Dict[str, Any]],
    ) -> Tuple[List[Dict[str, Any]], Optional[List[Dict[str, Any]]]]:
        metadata = [s.pop(DataKeys.METADATA, None) if isinstance(s, Mapping) else None for s in samples]
        return samples, metadata if any(m is not None for m in metadata) else None

    def __call__(self, samples: Sequence[Any]) -> Any:
        if not self.on_device:
            for sample in samples:
                self.callback.on_load_sample(sample, self.stage)

        if self.apply_per_sample_transform:
            if not isinstance(samples, list):
                list_samples = [samples]
            else:
                list_samples = samples

            transformed_samples = [self.per_sample_transform(sample) for sample in list_samples]

            for sample in transformed_samples:
                if self.on_device:
                    self.callback.on_per_sample_transform_on_device(sample, self.stage)
                else:
                    self.callback.on_per_sample_transform(sample, self.stage)

            extracted_samples, metadata = self._extract_metadata(transformed_samples)
            try:
                collated_samples = self.collate_fn(extracted_samples, metadata)
            except TypeError:
                collated_samples = self.collate_fn(extracted_samples)
            if metadata and isinstance(collated_samples, dict):
                collated_samples[DataKeys.METADATA] = metadata
            self.callback.on_collate(collated_samples, self.stage)
        else:
            collated_samples = samples

        transformed_collated_samples = self.per_batch_transform(collated_samples)
        if self.on_device:
            self.callback.on_per_batch_transform_on_device(transformed_collated_samples, self.stage)
        else:
            self.callback.on_per_batch_transform(transformed_collated_samples, self.stage)
        return transformed_collated_samples

    def __str__(self) -> str:
        # todo: define repr function which would take object and string attributes to be shown
        return (
            "_InputTransformProcessor:\n"
            f"\t(per_sample_transform): {str(self.per_sample_transform)}\n"
            f"\t(collate_fn): {str(self.collate_fn)}\n"
            f"\t(per_batch_transform): {str(self.per_batch_transform)}\n"
            f"\t(apply_per_sample_transform): {str(self.apply_per_sample_transform)}\n"
            f"\t(on_device): {str(self.on_device)}\n"
            f"\t(stage): {str(self.stage)}"
        )


def _create_collate_input_transform_processors(
    input_transform: "InputTransform", callbacks: List[FlashCallback]
) -> Tuple[_InputTransformProcessorV2, _InputTransformProcessorV2]:
    """This utility is used to create the two `_InputTransformProcessorV2` objects which contain the transforms
    used as the DataLoader `collate_fn` and the DataModule `on_after_batch_transfer` hook."""
    from flash.core.data.data_pipeline import DataPipeline

    prefix: str = _STAGES_PREFIX[input_transform.running_stage]

    per_batch_transform_overridden: bool = DataPipeline._is_overridden_recursive(
        "per_batch_transform", input_transform, InputTransform, prefix=prefix
    )

    per_sample_transform_on_device_overridden: bool = DataPipeline._is_overridden_recursive(
        "per_sample_transform_on_device", input_transform, InputTransform, prefix=prefix
    )

    is_per_overridden = per_batch_transform_overridden and per_sample_transform_on_device_overridden
    if input_transform._collate_in_worker_from_transform is None and is_per_overridden:
        raise MisconfigurationException(
            f"{input_transform.__class__.__name__}: `per_batch_transform` and `per_sample_transform_on_device` "
            f"are mutually exclusive for stage {input_transform.running_stage}"
        )

    if isinstance(input_transform._collate_in_worker_from_transform, bool):
        worker_collate_fn, device_collate_fn = _make_collates(
            input_transform, not input_transform._collate_in_worker_from_transform, input_transform._collate
        )
    else:
        worker_collate_fn, device_collate_fn = _make_collates(
            input_transform, per_sample_transform_on_device_overridden, input_transform._collate
        )

    worker_collate_fn = (
        worker_collate_fn.collate_fn
        if isinstance(worker_collate_fn, _InputTransformProcessorV2)
        else worker_collate_fn
    )

    worker_input_transform_processor = _InputTransformProcessorV2(
        input_transform,
        worker_collate_fn,
        input_transform._per_sample_transform,
        input_transform._per_batch_transform,
        input_transform.running_stage,
        callbacks=callbacks,
    )
    device_input_transform_processor = _InputTransformProcessorV2(
        input_transform,
        device_collate_fn,
        input_transform._per_sample_transform_on_device,
        input_transform._per_batch_transform_on_device,
        input_transform.running_stage,
        apply_per_sample_transform=device_collate_fn != input_transform._identity,
        on_device=True,
        callbacks=callbacks,
    )
    return worker_input_transform_processor, device_input_transform_processor
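

# Editor's sketch (not part of the original module): ``create_transform`` accepts
# several input formats. A plain callable, for instance, is wrapped into a
# ``LambdaInputTransform`` whose ``per_sample_transform`` returns that callable:
#
#     transform = create_transform(lambda sample: sample, RunningStage.PREDICTING)
#     assert isinstance(transform, LambdaInputTransform)
#
# Registry lookups work analogously: passing a ``str``/``LightningEnum`` (optionally
# with kwargs as a 2-tuple) resolves the class via ``input_transforms_registry``.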
