Shortcuts

Source code for flash.core.data.transforms

# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Mapping, Sequence, Union

import numpy as np
from torch import nn

from flash.core.data.io.input import DataKeys
from flash.core.data.utils import convert_to_modules
from flash.core.utilities.imports import _ALBUMENTATIONS_AVAILABLE, requires

if _ALBUMENTATIONS_AVAILABLE:
    from albumentations import BasicTransform, Compose
else:
    BasicTransform, Compose = object, object


class AlbumentationsAdapter(nn.Module):
    # mapping from albumentations to Flash
    TRANSFORM_INPUT_MAPPING = {"image": DataKeys.INPUT, "mask": DataKeys.TARGET}

    @requires("albumentations")
    def __init__(
        self,
        transform: Union[BasicTransform, Sequence[BasicTransform]],
        mapping: dict = None,
    ):
        super().__init__()
        if not isinstance(transform, (list, tuple)):
            transform = [transform]
        self.transform = Compose(list(transform))
        if not mapping:
            mapping = self.TRANSFORM_INPUT_MAPPING
        self._mapping_rev = mapping
        self._mapping = {v: k for k, v in mapping.items()}

    def forward(self, x: Any) -> Any:
        if isinstance(x, dict):
            x_ = {self._mapping.get(key, key): np.array(value) for key, value in x.items() if key in self._mapping}
        else:
            x_ = {"image": x}
        x_ = self.transform(**x_)
        if isinstance(x, dict):
            x.update({self._mapping_rev.get(k, k): x_[k] for k in self._mapping_rev if k in x_})
            return x

        return x_["image"]


[docs]class ApplyToKeys(nn.Sequential): """The ``ApplyToKeys`` class is an ``nn.Sequential`` which applies the given transforms to the given keys from the input. When a single key is given, a single value will be passed to the transforms. When multiple keys are given, the corresponding values will be passed to the transforms as a list. Args: keys: The key (``str``) or sequence of keys (``Sequence[str]``) to extract and forward to the transforms. args: The transforms, passed to the ``nn.Sequential`` super constructor. """ def __init__(self, keys: Union[str, Sequence[str]], *args): super().__init__(*(convert_to_modules(arg) for arg in args)) if isinstance(keys, str): keys = [keys] self.keys = keys def forward(self, x: Mapping[str, Any]) -> Mapping[str, Any]: keys = list(filter(lambda key: key in x, self.keys)) inputs = [x[key] for key in keys] result = {} result.update(x) if len(inputs) == 1: result[keys[0]] = super().forward(inputs[0]) elif len(inputs) > 1: try: outputs = super().forward(inputs) except TypeError as e: raise Exception("Failed to apply transforms to multiple keys at the same time.") from e for i, key in enumerate(keys): result[key] = outputs[i] # result is simply returned if len(inputs) == 0 return result def __repr__(self): transform = list(self.children()) keys = self.keys[0] if len(self.keys) == 1 else self.keys transform = transform[0] if len(transform) == 1 else transform return f"{self.__class__.__name__}(keys={repr(keys)}, transform={repr(transform)})"

© Copyright 2020-2021, PyTorch Lightning. Revision a374dd4f.

Built with Sphinx using a theme provided by Read the Docs.
Read the Docs v: latest
Versions
latest
stable
0.8.2
0.8.1.post0
0.8.1
0.8.0
0.7.5
0.7.4
0.7.3
0.7.2
0.7.1
0.7.0
0.6.0
0.5.2
0.5.1
0.5.0
0.4.0
0.3.2
0.3.1
0.3.0
0.2.3
0.2.2
0.2.1
0.2.0
0.1.0post1
Downloads
html
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.