Source code for flash.core.integrations.icevision.transforms

# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional, Tuple

import numpy as np
from torch import Tensor, nn

from flash.core.data.io.input import DataKeys
from flash.core.data.io.input_transform import InputTransform
from flash.core.utilities.imports import (
    _ICEVISION_AVAILABLE,
    _ICEVISION_GREATER_EQUAL_0_11_0,
    _TOPIC_IMAGE_AVAILABLE,
    requires,
)

if _TOPIC_IMAGE_AVAILABLE:
    from PIL import Image

if _ICEVISION_AVAILABLE:
    from icevision.core import tasks
    from icevision.core.bbox import BBox
    from icevision.core.keypoints import KeyPoints
    from icevision.core.mask import Mask, MaskArray
    from icevision.core.record import BaseRecord
    from icevision.core.record_components import (
        BBoxesRecordComponent,
        ClassMapRecordComponent,
        FilepathRecordComponent,
        ImageRecordComponent,
        InstancesLabelsRecordComponent,
        KeyPointsRecordComponent,
        RecordComponent,
        RecordIDRecordComponent,
    )
    from icevision.data.prediction import Prediction
    from icevision.tfms import A
else:
    MaskArray = object
    RecordComponent = object

    class tasks:
        common = object


if _ICEVISION_AVAILABLE and _ICEVISION_GREATER_EQUAL_0_11_0:
    from icevision.core.record_components import InstanceMasksRecordComponent
elif _ICEVISION_AVAILABLE:
    from icevision.core.record_components import MasksRecordComponent


def _split_mask_array(mask_array: MaskArray) -> List[MaskArray]:
    """Utility to split a single ``MaskArray`` object into a list of ``MaskArray`` objects (one per mask)."""
    return [MaskArray(mask) for mask in mask_array.data]

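# Illustrative sketch (not part of the original module): a ``MaskArray`` wrapping a
# stacked ``(N, H, W)`` array is split into ``N`` single-mask ``MaskArray`` objects::
#
#     stacked = MaskArray(np.zeros((3, 64, 64), dtype=np.uint8))
#     per_instance = _split_mask_array(stacked)  # list of 3 MaskArray objects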

class OriginalSizeRecordComponent(RecordComponent):
    def __init__(self, original_size: Optional[Tuple[int, int]], task=tasks.common):
        super().__init__(task=task)
        # original_size: (h, w)
        self.original_size: Optional[Tuple[int, int]] = original_size


def to_icevision_record(sample: Dict[str, Any]):
    record = BaseRecord([])

    metadata = sample.get(DataKeys.METADATA, None) or {}

    if "image_id" in metadata:
        record_id_component = RecordIDRecordComponent()
        record_id_component.set_record_id(metadata["image_id"])

    component = ClassMapRecordComponent(tasks.detection)
    component.set_class_map(metadata.get("class_map", None))
    record.add_component(component)

    if isinstance(sample[DataKeys.INPUT], str):
        input_component = FilepathRecordComponent()
        input_component.set_filepath(sample[DataKeys.INPUT])
    else:
        if "filepath" in metadata:
            input_component = FilepathRecordComponent()
            input_component.filepath = metadata["filepath"]
        else:
            input_component = ImageRecordComponent()
        input_component.composite = record
        image = sample[DataKeys.INPUT]
        if isinstance(image, Tensor):
            image = image.permute(1, 2, 0).numpy()
        elif isinstance(image, Image.Image):
            image = np.array(image)
        input_component.set_img(image)

        record.add_component(OriginalSizeRecordComponent(metadata.get("size", image.shape[:2])))
    record.add_component(input_component)

    if DataKeys.TARGET in sample:
        if "labels" in sample[DataKeys.TARGET]:
            labels_component = InstancesLabelsRecordComponent()
            labels_component.add_labels_by_id(sample[DataKeys.TARGET]["labels"])
            record.add_component(labels_component)

        if "bboxes" in sample[DataKeys.TARGET]:
            bboxes = [
                BBox.from_xywh(bbox["xmin"], bbox["ymin"], bbox["width"], bbox["height"])
                for bbox in sample[DataKeys.TARGET]["bboxes"]
            ]
            bboxes_component = BBoxesRecordComponent()
            bboxes_component.set_bboxes(bboxes)
            record.add_component(bboxes_component)

        if _ICEVISION_GREATER_EQUAL_0_11_0:
            masks = sample[DataKeys.TARGET].get("masks", None)

            if masks is not None:
                component = InstanceMasksRecordComponent()

                if len(masks) > 0 and isinstance(masks[0], Mask):
                    component.set_masks(masks)
                else:
                    # TODO: This treats invalid examples as negative examples
                    if len(masks) == 0 or not (
                        len(masks) == len(record.detection.bboxes) == len(record.detection.label_ids)
                    ):
                        data = np.zeros((0, record.height, record.width), np.uint8)
                        labels_component.label_ids = []
                        bboxes_component.bboxes = []
                    else:
                        data = np.stack(masks, axis=0)
                    mask_array = MaskArray(data)
                    component.set_mask_array(mask_array)
                    component.set_masks(_split_mask_array(mask_array))

                record.add_component(component)
        else:
            mask_array = sample[DataKeys.TARGET].get("mask_array", None)
            if mask_array is not None:
                component = MasksRecordComponent()
                component.set_masks(mask_array)
                record.add_component(component)

        if "keypoints" in sample[DataKeys.TARGET]:
            keypoints = []

            keypoints_metadata = sample[DataKeys.TARGET].get(
                "keypoints_metadata", [None] * len(sample[DataKeys.TARGET]["keypoints"])
            )

            for keypoints_list, keypoint_metadata in zip(sample[DataKeys.TARGET]["keypoints"], keypoints_metadata):
                xyv = []
                for keypoint in keypoints_list:
                    xyv.extend((keypoint["x"], keypoint["y"], keypoint["visible"]))

                keypoints.append(KeyPoints.from_xyv(xyv, keypoint_metadata))
            component = KeyPointsRecordComponent()
            component.set_keypoints(keypoints)
            record.add_component(component)

    return record

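# Illustrative sketch (not part of the original module): ``to_icevision_record`` expects
# a Flash-style sample dict such as the following (all target keys are optional)::
#
#     sample = {
#         DataKeys.INPUT: image,  # filepath str, np.ndarray (H, W, C), PIL.Image, or Tensor (C, H, W)
#         DataKeys.TARGET: {
#             "labels": [1, 3],
#             "bboxes": [
#                 {"xmin": 10, "ymin": 20, "width": 50, "height": 60},
#                 {"xmin": 5, "ymin": 5, "width": 30, "height": 40},
#             ],
#         },
#         DataKeys.METADATA: {"image_id": 0, "class_map": None},  # optional
#     }
#     record = to_icevision_record(sample)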

def from_icevision_detection(record: "BaseRecord"):
    detection = record.detection

    result = {}

    if hasattr(detection, "bboxes"):
        result["bboxes"] = [
            {
                "xmin": bbox.xmin,
                "ymin": bbox.ymin,
                "width": bbox.width,
                "height": bbox.height,
            }
            for bbox in detection.bboxes
        ]

    masks = getattr(detection, "masks", None)
    mask_array = getattr(detection, "mask_array", None)
    if mask_array is not None or not _ICEVISION_GREATER_EQUAL_0_11_0:
        if not isinstance(mask_array, MaskArray) or len(mask_array.data) == 0:
            mask_array = MaskArray.from_masks(masks, record.height, record.width)

        result["masks"] = [mask.data[0] for mask in _split_mask_array(mask_array)]
    elif masks is not None:
        result["masks"] = masks  # Note - this doesn't unpack IceVision objects

    if hasattr(detection, "keypoints"):
        keypoints = detection.keypoints

        result["keypoints"] = []
        result["keypoints_metadata"] = []

        for keypoint in keypoints:
            keypoints_list = []
            for x, y, v in keypoint.xyv:
                keypoints_list.append(
                    {
                        "x": x,
                        "y": y,
                        "visible": v,
                    }
                )
            result["keypoints"].append(keypoints_list)

            # TODO: Unpack keypoints_metadata
            result["keypoints_metadata"].append(keypoint.metadata)

    if getattr(detection, "label_ids", None) is not None:
        result["labels"] = list(detection.label_ids)

    if getattr(detection, "scores", None) is not None:
        result["scores"] = list(detection.scores)

    return result

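# Illustrative sketch (not part of the original module): depending on which components
# the record carries, the dict returned above may contain any of::
#
#     {
#         "bboxes": [{"xmin": ..., "ymin": ..., "width": ..., "height": ...}, ...],
#         "masks": [...],                # one binary (H, W) array per instance
#         "keypoints": [...],            # one list of {"x", "y", "visible"} dicts per instance
#         "keypoints_metadata": [...],
#         "labels": [...],
#         "scores": [...],               # typically only present on prediction records
#     }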

def from_icevision_record(record: "BaseRecord"):
    sample = {
        DataKeys.METADATA: {
            "size": getattr(record, "original_size", (record.height, record.width)),
            "output_size": (record.height, record.width),
        }
    }

    if getattr(record, "record_id", None) is not None:
        sample[DataKeys.METADATA]["image_id"] = record.record_id

    if getattr(record, "filepath", None) is not None:
        sample[DataKeys.METADATA]["filepath"] = record.filepath

    if record.img is not None:
        sample[DataKeys.INPUT] = record.img
        filepath = getattr(record, "filepath", None)
        if filepath is not None:
            sample[DataKeys.METADATA]["filepath"] = filepath
    elif getattr(record, "filepath", None) is not None:
        sample[DataKeys.INPUT] = record.filepath

    sample[DataKeys.TARGET] = from_icevision_detection(record)

    if getattr(record.detection, "class_map", None) is not None:
        sample[DataKeys.METADATA]["class_map"] = record.detection.class_map

    return sample


def from_icevision_predictions(predictions: List["Prediction"]):
    result = []
    for prediction in predictions:
        result.append(from_icevision_detection(prediction.pred))
    return result


class IceVisionTransformAdapter(nn.Module):
    """Adapter that applies a list of IceVision (albumentations) transforms to a Flash sample dict.

    Args:
        transform: list of transformation functions to apply
    """

    def __init__(self, transform: List[Callable]):
        super().__init__()
        self.transform = A.Adapter(transform)

    def forward(self, x):
        record = to_icevision_record(x)
        record = self.transform(record)
        return from_icevision_record(record)

@dataclass
class IceVisionInputTransform(InputTransform):
    image_size: int = 128

    @requires("image", "icevision")
    def per_sample_transform(self):
        return IceVisionTransformAdapter([*A.resize_and_pad(self.image_size), A.Normalize()])

    @requires("image", "icevision")
    def train_per_sample_transform(self):
        return IceVisionTransformAdapter([*A.aug_tfms(size=self.image_size), A.Normalize()])

    def collate(self) -> Callable:
        return self._identity
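
# Usage sketch (illustrative, not part of the original module): the adapter can be
# applied directly to a Flash-style sample dict; ``image`` below is a placeholder array::
#
#     image = np.zeros((320, 480, 3), dtype=np.uint8)
#     sample = {
#         DataKeys.INPUT: image,
#         DataKeys.TARGET: {
#             "labels": [1],
#             "bboxes": [{"xmin": 10, "ymin": 20, "width": 50, "height": 60}],
#         },
#     }
#     adapter = IceVisionTransformAdapter([*A.resize_and_pad(128), A.Normalize()])
#     transformed = adapter(sample)  # resized, normalized image with rescaled bboxes
#
# ``IceVisionInputTransform(image_size=128).per_sample_transform()`` builds the same
# adapter for validation/test, while ``train_per_sample_transform`` adds ``A.aug_tfms``.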
