Shortcuts

Source code for flash.core.data.transforms

# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, Mapping, Sequence, Union

import torch
from torch import nn
from torch.utils.data._utils.collate import default_collate

from flash.core.data.utils import convert_to_modules


class ApplyToKeys(nn.Sequential):
    """The ``ApplyToKeys`` class is an ``nn.Sequential`` which applies the given transforms to the given keys from
    the input. When a single key is given, a single value will be passed to the transforms. When multiple keys are
    given, the corresponding values will be passed to the transforms as a list.

    Keys that are not present in the input are skipped. If none of the keys are present, a shallow copy of the input
    is returned unchanged.

    Args:
        keys: The key (``str``) or sequence of keys (``Sequence[str]``) to extract and forward to the transforms.
        args: The transforms, passed to the ``nn.Sequential`` super constructor.
    """

    def __init__(self, keys: Union[str, Sequence[str]], *args):
        # NOTE(review): ``convert_to_modules`` comes from ``flash.core.data.utils``; presumably it wraps plain
        # callables so that ``nn.Sequential`` accepts them as children -- confirm there.
        super().__init__(*(convert_to_modules(arg) for arg in args))
        if isinstance(keys, str):
            keys = [keys]
        self.keys = keys

    def forward(self, x: Mapping[str, Any]) -> Mapping[str, Any]:
        """Apply the transforms to the values stored under ``self.keys`` in ``x``.

        Args:
            x: The input mapping.

        Returns:
            A new ``dict`` holding every entry of ``x``, where the values of the selected (present) keys are
            replaced by the transform output.

        Raises:
            TypeError: If several keys are present and the transforms cannot be applied to the list of their values
                at once (the original ``TypeError`` is chained as the cause).
        """
        keys = [key for key in self.keys if key in x]
        inputs = [x[key] for key in keys]
        result = dict(x)

        if len(inputs) == 1:
            # ``nn.Sequential.forward`` chains every child transform on the single value.
            result[keys[0]] = super().forward(inputs[0])
        elif len(inputs) > 1:
            try:
                outputs = super().forward(inputs)
            except TypeError as e:
                # Raise a specific (still ``Exception``-derived) type instead of a bare ``Exception``.
                raise TypeError(
                    "Failed to apply transforms to multiple keys at the same time,"
                    " try using KorniaParallelTransforms."
                ) from e
            # Outputs are expected to be indexable in the same order as ``keys``.
            for i, key in enumerate(keys):
                result[key] = outputs[i]
        # result is simply a copy of ``x`` if none of the keys are present
        return result

    def __repr__(self):
        """Show the key(s) and transform(s), unwrapping single-element collections for readability."""
        transform = list(self.children())
        keys = self.keys[0] if len(self.keys) == 1 else self.keys
        transform = transform[0] if len(transform) == 1 else transform
        return f"{self.__class__.__name__}(keys={repr(keys)}, transform={repr(transform)})"


class KorniaParallelTransforms(nn.Sequential):
    """The ``KorniaParallelTransforms`` class is an ``nn.Sequential`` which will apply the given transforms to each
    input (to ``.forward``) in parallel, whilst sharing the random state (``._params``). This should be used when
    multiple elements need to be augmented in the same way (e.g. an image and corresponding segmentation mask).

    Args:
        args: The transforms, passed to the ``nn.Sequential`` super constructor.
    """

    def __init__(self, *args):
        # NOTE(review): ``convert_to_modules`` comes from ``flash.core.data.utils``; presumably it wraps plain
        # callables so that ``nn.Sequential`` accepts them as children -- confirm there.
        super().__init__(*(convert_to_modules(arg) for arg in args))

    def forward(self, inputs: Any):
        """Apply every transform to each element of ``inputs``, sharing sampled parameters across elements.

        For each transform in turn, the first element is transformed on its own. If the transform then exposes a
        non-empty ``_params`` attribute, that value is passed as the second positional argument when transforming
        the remaining elements, so every element receives the same augmentation. Transforms without a non-empty
        ``_params`` are applied to each element independently.

        Args:
            inputs: A sequence of elements to transform together, or a single element.

        Returns:
            A ``list`` of the transformed elements in input order. A non-sequence input yields a one-element list,
            and an empty sequence yields an empty list.
        """
        result = list(inputs) if isinstance(inputs, Sequence) else [inputs]
        if not result:
            # Nothing to transform; indexing ``result[0]`` below would otherwise raise IndexError.
            return result

        for transform in self.children():
            # We enforce the first element to be transformed alone so a random transform samples its params.
            result[0] = transform(result[0])
            params = getattr(transform, "_params", None)
            share_params = bool(params)

            # Apply the same transform to elements 1..n, reusing the sampled params when there are any.
            for index, element in enumerate(result[1:], start=1):
                if share_params:
                    result[index] = transform(element, params)
                else:
                    # Case for non-random transforms.
                    result[index] = transform(element)

            # Clear the recorded params once shared; presumably so a later call cannot reuse stale ones.
            if getattr(transform, "_params", None):
                transform._params = None
        return result


def kornia_collate(samples: Sequence[Dict[str, Any]]) -> Dict[str, Any]:
    """Collate samples whose tensors were produced by Kornia transforms.

    Kornia transforms add a leading batch dimension which needs to be removed before batching. Every 4-dimensional
    tensor value in each sample is squeezed along dimension 0, in place, and the samples are then merged with
    ``torch.utils.data._utils.collate.default_collate``.

    Args:
        samples: The samples to collate. A batch wrapped in a single-element outer ``list`` is unwrapped first.

    Returns:
        The collated batch produced by ``default_collate``.
    """
    is_wrapped = len(samples) == 1 and isinstance(samples[0], list)
    batch = samples[0] if is_wrapped else samples

    for sample in batch:
        for key in sample:
            value = sample[key]
            if torch.is_tensor(value) and value.ndim == 4:
                # Written back into the caller's sample mapping (the squeeze is applied in place on the dict).
                sample[key] = value.squeeze(0)
    return default_collate(batch)

© Copyright 2020-2021, PyTorch Lightning. Revision 1c7d8e08.

Built with Sphinx using a theme provided by Read the Docs.
Read the Docs v: 0.7.4
Versions
latest
stable
0.7.4
0.7.3
0.7.2
0.7.1
0.7.0
0.6.0
0.5.2
0.5.1
0.5.0
0.4.0
0.3.2
0.3.1
0.3.0
0.2.3
0.2.2
0.2.1
0.2.0
0.1.0post1
docs-fix_typing
Downloads
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.