Shortcuts

Source code for flash.image.segmentation.input_transform

# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Any, Callable, Dict, Tuple

from flash.core.data.io.input import DataKeys
from flash.core.data.io.input_transform import InputTransform
from flash.core.data.transforms import AlbumentationsAdapter, ApplyToKeys
from flash.core.utilities.imports import _ALBUMENTATIONS_AVAILABLE, _TORCHVISION_AVAILABLE, requires

if _ALBUMENTATIONS_AVAILABLE:
    import albumentations as alb
else:
    alb = None

if _TORCHVISION_AVAILABLE:
    from torchvision import transforms as T


[docs]def prepare_target(batch: Dict[str, Any]) -> Dict[str, Any]: """Convert the target mask to long and remove the channel dimension.""" if DataKeys.TARGET in batch: batch[DataKeys.TARGET] = batch[DataKeys.TARGET].squeeze().long() return batch
def permute_target(sample: Dict[str, Any]) -> Dict[str, Any]: if DataKeys.TARGET in sample: target = sample[DataKeys.TARGET] if target.ndim == 2: target = target[None, :, :] sample[DataKeys.TARGET] = target.transpose((1, 2, 0)) return sample def remove_extra_dimensions(batch: Dict[str, Any]): if isinstance(batch[DataKeys.INPUT], list): assert len(batch[DataKeys.INPUT]) == 1 batch[DataKeys.INPUT] = batch[DataKeys.INPUT][0] return batch
[docs]@dataclass class SemanticSegmentationInputTransform(InputTransform): # https://albumentations.ai/docs/examples/pytorch_semantic_segmentation image_size: Tuple[int, int] = (128, 128) mean: Tuple[float, float, float] = (0.485, 0.456, 0.406) std: Tuple[float, float, float] = (0.229, 0.224, 0.225) @requires("image") def train_per_sample_transform(self) -> Callable: return T.Compose( [ permute_target, AlbumentationsAdapter( [ alb.Resize(*self.image_size), alb.HorizontalFlip(p=0.5), alb.Normalize(mean=self.mean, std=self.std), ] ), ApplyToKeys( DataKeys.INPUT, T.ToTensor(), ), ] ) @requires("image") def per_sample_transform(self) -> Callable: return T.Compose( [ permute_target, AlbumentationsAdapter( [ alb.Resize(*self.image_size), alb.Normalize(mean=self.mean, std=self.std), ] ), ApplyToKeys( DataKeys.INPUT, T.ToTensor(), ), ] ) def per_batch_transform(self) -> Callable: return T.Compose([prepare_target, remove_extra_dimensions])

© Copyright 2020-2021, PyTorch Lightning. Revision a374dd4f.

Built with Sphinx using a theme provided by Read the Docs.