Shortcuts

Source code for flash.core.data.transforms

# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Mapping, Sequence, Union

import numpy as np
from torch import nn

from flash.core.data.io.input import DataKeys
from flash.core.data.utils import convert_to_modules
from flash.core.utilities.imports import _ALBUMENTATIONS_AVAILABLE, requires

if _ALBUMENTATIONS_AVAILABLE:
    from albumentations import BasicTransform, Compose
else:
    BasicTransform, Compose = object, object


class AlbumentationsAdapter(nn.Module):
    """Wrap one or more albumentations transforms so they can be called like an ``nn.Module``.

    The given transform(s) are combined into a single ``Compose``. When the adapter is called with a ``dict``, the
    entries whose keys appear in the mapping are renamed to their albumentations argument names, converted with
    ``np.array`` and passed to the composed transform; the results are then written back into the same ``dict`` under
    their original keys. Any other input is passed to the composed transform as ``image``.

    Args:
        transform: A single transform or a ``list`` / ``tuple`` of transforms.
        mapping: Mapping from albumentations argument names to the keys used in the input ``dict``. Falls back to
            ``TRANSFORM_INPUT_MAPPING`` when ``None`` or empty.
    """

    # mapping from albumentations to Flash
    TRANSFORM_INPUT_MAPPING = {"image": DataKeys.INPUT, "mask": DataKeys.TARGET}

    @requires("albumentations")
    def __init__(
        self,
        transform: Union[BasicTransform, Sequence[BasicTransform]],
        mapping: dict = None,
    ):
        super().__init__()
        transforms = list(transform) if isinstance(transform, (list, tuple)) else [transform]
        self.transform = Compose(transforms)

        albumentations_to_flash = mapping if mapping else self.TRANSFORM_INPUT_MAPPING
        # ``_mapping_rev``: albumentations name -> input key; ``_mapping``: input key -> albumentations name.
        self._mapping_rev = albumentations_to_flash
        self._mapping = {flash_key: alb_key for alb_key, flash_key in albumentations_to_flash.items()}

    def forward(self, x: Any) -> Any:
        """Apply the composed transform to ``x``.

        Args:
            x: Either a ``dict`` sample or a single value that is treated as the image.

        Returns:
            The input ``dict`` (updated in place with the transformed values) when ``x`` is a ``dict``, otherwise the
            transformed ``"image"`` entry.
        """
        if not isinstance(x, dict):
            return self.transform(image=x)["image"]

        # Only the mapped entries are forwarded; everything else in the sample is left untouched.
        transform_kwargs = {}
        for flash_key, value in x.items():
            if flash_key in self._mapping:
                transform_kwargs[self._mapping[flash_key]] = np.array(value)

        transformed = self.transform(**transform_kwargs)

        # Write the results back under the keys the sample originally used.
        for alb_key, flash_key in self._mapping_rev.items():
            if alb_key in transformed:
                x[flash_key] = transformed[alb_key]
        return x


class ApplyToKeys(nn.Sequential):
    """The ``ApplyToKeys`` class is an ``nn.Sequential`` which applies the given transforms to the given keys from the
    input. When a single key is given, a single value will be passed to the transforms. When multiple keys are given,
    the corresponding values will be passed to the transforms as a list.

    Keys that are not present in the input are ignored, and the input mapping itself is never modified: a shallow
    copy is updated and returned.

    Args:
        keys: The key (``str``) or sequence of keys (``Sequence[str]``) to extract and forward to the transforms.
        args: The transforms, passed to the ``nn.Sequential`` super constructor.
    """

    def __init__(self, keys: Union[str, Sequence[str]], *args):
        super().__init__(*(convert_to_modules(arg) for arg in args))
        if isinstance(keys, str):
            keys = [keys]
        self.keys = keys

    def forward(self, x: Mapping[str, Any]) -> Mapping[str, Any]:
        """Apply the transforms to the values of ``x`` stored under ``self.keys``.

        Args:
            x: The input mapping.

        Returns:
            A new ``dict`` with the contents of ``x`` where the values under the requested keys (those present in
            ``x``) have been replaced by the transform outputs.

        Raises:
            Exception: If several keys are present and the transforms raise a ``TypeError`` when given the list of
                their values.
        """
        present_keys = [key for key in self.keys if key in x]
        inputs = [x[key] for key in present_keys]

        # Work on a shallow copy so the caller's mapping is left untouched.
        result = dict(x)

        if len(inputs) == 1:
            # A single matching key: pass the bare value rather than a one-element list.
            result[present_keys[0]] = super().forward(inputs[0])
        elif len(inputs) > 1:
            try:
                outputs = super().forward(inputs)
            except TypeError as e:
                raise Exception("Failed to apply transforms to multiple keys at the same time.") from e

            # Outputs are matched back to keys by position.
            for i, key in enumerate(present_keys):
                result[key] = outputs[i]

        # result is simply returned if len(inputs) == 0
        return result

    def __repr__(self):
        transform = list(self.children())

        # Unwrap single-element collections for a more compact representation.
        keys = self.keys[0] if len(self.keys) == 1 else self.keys
        transform = transform[0] if len(transform) == 1 else transform

        return f"{self.__class__.__name__}(keys={repr(keys)}, transform={repr(transform)})"

© Copyright 2020-2021, PyTorch Lightning. Revision a9cedb5a.

Built with Sphinx using a theme provided by Read the Docs.