Source code for flash.image.keypoint_detection.data

# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Type, Union

from flash.core.data.data_module import DataModule
from flash.core.data.io.input import Input
from flash.core.integrations.icevision.data import IceVisionInput
from flash.core.utilities.imports import _ICEVISION_AVAILABLE
from flash.core.utilities.stages import RunningStage
from flash.core.utilities.types import INPUT_TRANSFORM_TYPE
from flash.image.keypoint_detection.input_transform import KeypointDetectionInputTransform

if _ICEVISION_AVAILABLE:
    from icevision.core import KeyPoints, KeypointsMetadata
    from icevision.parsers import COCOKeyPointsParser, Parser
else:
    COCOKeyPointsParser = object
    Parser = object
    KeyPoints = object


# Skip doctests if requirements aren't available
if not _ICEVISION_AVAILABLE:
    __doctest_skip__ = ["KeypointDetectionData", "KeypointDetectionData.*"]


class FlashCOCOKeyPointsParser(COCOKeyPointsParser):
    def __init__(
        self,
        annotations_filepath: Union[str, Path, dict],
        img_dir: Union[str, Path],
    ):
        super().__init__(annotations_filepath, img_dir)

        categories = self.annotations_dict["categories"]
        self.keypoint_labels = categories[0]["keypoints"]
        for o in categories[1:]:
            if not o["keypoints"] == self.keypoint_labels:
                raise ValueError(
                    "When performing keypoint detection with multiple categories, all categories are expected to have "
                    f"the same keypoints. Found {self.keypoint_labels} for category with ID {categories[0]['id']} and "
                    f"{o['keypoints']} for category with ID {o['id']}."
                )

    def keypoints(self, o) -> List[KeyPoints]:
        meta = KeypointsMetadata()
        meta.labels = self.keypoint_labels
        return [KeyPoints.from_xyv(o["keypoints"], meta)] if sum(o["keypoints"]) > 0 else []

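
# Illustrative sketch, not part of the library source: COCO stores each annotation's keypoints as a
# flat ``[x1, y1, v1, x2, y2, v2, ...]`` list, where ``v`` is the visibility flag (0 = not labelled,
# 1 = labelled but not visible, 2 = labelled and visible). ``KeyPoints.from_xyv`` above consumes this
# flat layout directly, and annotations whose ``keypoints`` list sums to zero (no labelled keypoints)
# are skipped by the ``keypoints`` method above. Grouping the flat list into (x, y, v) triplets by
# hand looks like this:
_example_flat_keypoints = [10, 15, 2, 20, 30, 2]  # "left ear" at (10, 15), "right ear" at (20, 30), both visible
_example_triplets = [tuple(_example_flat_keypoints[i : i + 3]) for i in range(0, len(_example_flat_keypoints), 3)]
assert _example_triplets == [(10, 15, 2), (20, 30, 2)]
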

class KeypointDetectionData(DataModule):
    """The ``KeypointDetectionData`` class is a :class:`~flash.core.data.data_module.DataModule` with a set of
    classmethods for loading data for keypoint detection."""

    input_transform_cls = KeypointDetectionInputTransform

    @classmethod
    def from_icedata(
        cls,
        train_folder: Optional[str] = None,
        train_ann_file: Optional[str] = None,
        train_parser_kwargs: Optional[Dict[str, Any]] = None,
        val_folder: Optional[str] = None,
        val_ann_file: Optional[str] = None,
        val_parser_kwargs: Optional[Dict[str, Any]] = None,
        test_folder: Optional[str] = None,
        test_ann_file: Optional[str] = None,
        test_parser_kwargs: Optional[Dict[str, Any]] = None,
        predict_folder: Optional[str] = None,
        train_transform: INPUT_TRANSFORM_TYPE = KeypointDetectionInputTransform,
        val_transform: INPUT_TRANSFORM_TYPE = KeypointDetectionInputTransform,
        test_transform: INPUT_TRANSFORM_TYPE = KeypointDetectionInputTransform,
        predict_transform: INPUT_TRANSFORM_TYPE = KeypointDetectionInputTransform,
        parser: Optional[Union[Callable, Type[Parser]]] = None,
        input_cls: Type[Input] = IceVisionInput,
        transform_kwargs: Optional[Dict] = None,
        **data_module_kwargs: Any,
    ) -> "KeypointDetectionData":
        ds_kw = dict(parser=parser, transform_kwargs=transform_kwargs)

        return cls(
            input_cls(
                RunningStage.TRAINING,
                train_folder,
                train_ann_file,
                parser_kwargs=train_parser_kwargs,
                transform=train_transform,
                **ds_kw,
            ),
            input_cls(
                RunningStage.VALIDATING,
                val_folder,
                val_ann_file,
                parser_kwargs=val_parser_kwargs,
                transform=val_transform,
                **ds_kw,
            ),
            input_cls(
                RunningStage.TESTING,
                test_folder,
                test_ann_file,
                parser_kwargs=test_parser_kwargs,
                transform=test_transform,
                **ds_kw,
            ),
            input_cls(RunningStage.PREDICTING, predict_folder, transform=predict_transform, **ds_kw),
            **data_module_kwargs,
        )
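
    # Illustrative sketch, not part of the library source: ``from_icedata`` is the generic
    # constructor that ``from_coco`` below delegates to. Any IceVision ``Parser`` type can be
    # supplied via ``parser=``, for example the ``FlashCOCOKeyPointsParser`` defined above.
    # The folder and annotation file names here are placeholders:
    #
    #     >>> datamodule = KeypointDetectionData.from_icedata(
    #     ...     train_folder="train_folder",
    #     ...     train_ann_file="train_annotations.json",
    #     ...     parser=FlashCOCOKeyPointsParser,
    #     ...     batch_size=2,
    #     ... )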

    @classmethod
    def from_coco(
        cls,
        train_folder: Optional[str] = None,
        train_ann_file: Optional[str] = None,
        val_folder: Optional[str] = None,
        val_ann_file: Optional[str] = None,
        test_folder: Optional[str] = None,
        test_ann_file: Optional[str] = None,
        predict_folder: Optional[str] = None,
        train_transform: INPUT_TRANSFORM_TYPE = KeypointDetectionInputTransform,
        val_transform: INPUT_TRANSFORM_TYPE = KeypointDetectionInputTransform,
        test_transform: INPUT_TRANSFORM_TYPE = KeypointDetectionInputTransform,
        predict_transform: INPUT_TRANSFORM_TYPE = KeypointDetectionInputTransform,
        input_cls: Type[Input] = IceVisionInput,
        transform_kwargs: Optional[Dict] = None,
        **data_module_kwargs: Any,
    ):
        """Creates a :class:`~flash.image.keypoint_detection.data.KeypointDetectionData` object from the given data
        folders and annotation files in the `COCO JSON format <https://cocodataset.org/#format-data>`_.

        For help understanding and using the COCO format, take a look at this tutorial:
        `Create COCO annotations from scratch <COCO>`__.

        To learn how to customize the transforms applied for each stage, read our
        :ref:`customizing transforms guide <customizing_transforms>`.

        Args:
            train_folder: The folder containing images to use when training.
            train_ann_file: The COCO format annotation file to use when training.
            val_folder: The folder containing images to use when validating.
            val_ann_file: The COCO format annotation file to use when validating.
            test_folder: The folder containing images to use when testing.
            test_ann_file: The COCO format annotation file to use when testing.
            predict_folder: The folder containing images to use when predicting.
            train_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when training.
            val_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when validating.
            test_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when testing.
            predict_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when
                predicting.
            input_cls: The :class:`~flash.core.data.io.input.Input` type to use for loading the data.
            transform_kwargs: Dict of keyword arguments to be provided when instantiating the transforms.
            data_module_kwargs: Additional keyword arguments to provide to the
                :class:`~flash.core.data.data_module.DataModule` constructor.

        Returns:
            The constructed :class:`~flash.image.keypoint_detection.data.KeypointDetectionData`.

        Examples
        ________

        .. testsetup::

            >>> import os
            >>> import json
            >>> import numpy as np
            >>> from PIL import Image
            >>> rand_image = Image.fromarray(np.random.randint(0, 255, (64, 64, 3), dtype="uint8"))
            >>> os.makedirs("train_folder", exist_ok=True)
            >>> os.makedirs("predict_folder", exist_ok=True)
            >>> _ = [rand_image.save(os.path.join("train_folder", f"image_{i}.png")) for i in range(1, 4)]
            >>> _ = [rand_image.save(os.path.join("predict_folder", f"predict_image_{i}.png")) for i in range(1, 4)]
            >>> annotations = {"annotations": [
            ...     {"area": 50, "bbox": [10, 20, 5, 10], "num_keypoints": 2, "keypoints": [10, 15, 2, 20, 30, 2],
            ...     "category_id": 1, "id": 1, "image_id": 1, "iscrowd": 0},
            ...     {"area": 100, "bbox": [20, 30, 10, 10], "num_keypoints": 2, "keypoints": [20, 30, 2, 30, 40, 2],
            ...     "category_id": 2, "id": 2, "image_id": 2, "iscrowd": 0},
            ...     {"area": 125, "bbox": [10, 20, 5, 25], "num_keypoints": 2, "keypoints": [10, 15, 2, 20, 45, 2],
            ...     "category_id": 1, "id": 3, "image_id": 3, "iscrowd": 0},
            ... ], "categories": [
            ...     {"id": 1, "name": "cat", "supercategory": "cat", "keypoints": ["left ear", "right ear"]},
            ...     {"id": 2, "name": "dog", "supercategory": "dog", "keypoints": ["left ear", "right ear"]},
            ... ], "images": [
            ...     {"file_name": "image_1.png", "height": 64, "width": 64, "id": 1},
            ...     {"file_name": "image_2.png", "height": 64, "width": 64, "id": 2},
            ...     {"file_name": "image_3.png", "height": 64, "width": 64, "id": 3},
            ... ]}
            >>> with open("train_annotations.json", "w") as annotation_file:
            ...     json.dump(annotations, annotation_file)

        The folder ``train_folder`` has the following contents:

        .. code-block::

            train_folder
            ├── image_1.png
            ├── image_2.png
            ├── image_3.png
            ...

        The file ``train_annotations.json`` contains the following:

        .. code-block::

            {
                "annotations": [
                    {
                        "area": 50, "bbox": [10, 20, 5, 10], "num_keypoints": 2,
                        "keypoints": [10, 15, 2, 20, 30, 2], "category_id": 1, "id": 1, "image_id": 1, "iscrowd": 0
                    },
                    {
                        "area": 100, "bbox": [20, 30, 10, 10], "num_keypoints": 2,
                        "keypoints": [20, 30, 2, 30, 40, 2], "category_id": 2, "id": 2, "image_id": 2, "iscrowd": 0
                    },
                    {
                        "area": 125, "bbox": [10, 20, 5, 25], "num_keypoints": 2,
                        "keypoints": [10, 15, 2, 20, 45, 2], "category_id": 1, "id": 3, "image_id": 3, "iscrowd": 0
                    }
                ],
                "categories": [
                    {"id": 1, "name": "cat", "supercategory": "cat", "keypoints": ["left ear", "right ear"]},
                    {"id": 2, "name": "dog", "supercategory": "dog", "keypoints": ["left ear", "right ear"]}
                ],
                "images": [
                    {"file_name": "image_1.png", "height": 64, "width": 64, "id": 1},
                    {"file_name": "image_2.png", "height": 64, "width": 64, "id": 2},
                    {"file_name": "image_3.png", "height": 64, "width": 64, "id": 3}
                ]
            }

        .. doctest::

            >>> from flash import Trainer
            >>> from flash.image import KeypointDetector, KeypointDetectionData
            >>> datamodule = KeypointDetectionData.from_coco(
            ...     train_folder="train_folder",
            ...     train_ann_file="train_annotations.json",
            ...     predict_folder="predict_folder",
            ...     transform_kwargs=dict(image_size=(128, 128)),
            ...     batch_size=2,
            ... )
            >>> datamodule.num_classes
            3
            >>> datamodule.labels
            ['background', 'cat', 'dog']
            >>> model = KeypointDetector(2, num_classes=datamodule.num_classes)
            >>> trainer = Trainer(fast_dev_run=True)
            >>> trainer.fit(model, datamodule=datamodule)  # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
            Training...
            >>> trainer.predict(model, datamodule=datamodule)  # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
            Predicting...

        .. testcleanup::

            >>> import shutil
            >>> shutil.rmtree("train_folder")
            >>> shutil.rmtree("predict_folder")
            >>> os.remove("train_annotations.json")
        """
        return cls.from_icedata(
            train_folder=train_folder,
            train_ann_file=train_ann_file,
            val_folder=val_folder,
            val_ann_file=val_ann_file,
            test_folder=test_folder,
            test_ann_file=test_ann_file,
            predict_folder=predict_folder,
            train_transform=train_transform,
            val_transform=val_transform,
            test_transform=test_transform,
            predict_transform=predict_transform,
            parser=FlashCOCOKeyPointsParser,
            input_cls=input_cls,
            transform_kwargs=transform_kwargs,
            **data_module_kwargs,
        )

    @classmethod
    def from_folders(
        cls,
        predict_folder: Optional[str] = None,
        predict_transform: INPUT_TRANSFORM_TYPE = KeypointDetectionInputTransform,
        input_cls: Type[Input] = IceVisionInput,
        transform_kwargs: Optional[Dict] = None,
        **data_module_kwargs: Any,
    ) -> "DataModule":
        """Creates a :class:`~flash.core.data.data_module.DataModule` object from the given folders.

        This is supported only for the predicting stage.

        Args:
            predict_folder: The folder containing the predict data.
            predict_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when
                predicting.
            input_cls: The :class:`~flash.core.data.io.input.Input` used to create the dataset.
            transform_kwargs: Keyword arguments provided to the transform on instantiation.
            data_module_kwargs: The keyword arguments for creating the datamodule.

        Returns:
            The constructed data module.
        """
        ds_kw = dict(transform=predict_transform, transform_kwargs=transform_kwargs)

        return cls(
            predict_input=input_cls(RunningStage.PREDICTING, predict_folder, **ds_kw),
            **data_module_kwargs,
        )

    @classmethod
    def from_files(
        cls,
        predict_files: Optional[List[str]] = None,
        predict_transform: INPUT_TRANSFORM_TYPE = KeypointDetectionInputTransform,
        input_cls: Type[Input] = IceVisionInput,
        transform_kwargs: Optional[Dict] = None,
        **data_module_kwargs: Any,
    ) -> "DataModule":
        """Creates a :class:`~flash.core.data.data_module.DataModule` object from the given list of files.

        This is supported only for the predicting stage.

        Args:
            predict_files: The list of files containing the predict data.
            predict_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when
                predicting.
            input_cls: The :class:`~flash.core.data.io.input.Input` used to create the dataset.
            transform_kwargs: Keyword arguments provided to the transform on instantiation.
            data_module_kwargs: The keyword arguments for creating the datamodule.

        Returns:
            The constructed data module.
        """
        ds_kw = dict(transform=predict_transform, transform_kwargs=transform_kwargs)

        return cls(
            predict_input=input_cls(RunningStage.PREDICTING, predict_files, **ds_kw),
            **data_module_kwargs,
        )
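

# A minimal usage sketch, not part of the library source. It assumes icevision is installed and that
# a local "predict_folder" containing images exists (the folder name is a placeholder). The
# ``__main__`` guard keeps the demo from running when the module is imported.
if __name__ == "__main__":
    from flash import Trainer
    from flash.image import KeypointDetector

    # Build a predict-only datamodule from a folder of images.
    datamodule = KeypointDetectionData.from_folders(
        predict_folder="predict_folder",
        transform_kwargs=dict(image_size=(128, 128)),
        batch_size=2,
    )

    # Two keypoints per object and three classes (background + "cat" + "dog"), matching the
    # ``from_coco`` doctest above.
    model = KeypointDetector(2, num_classes=3)

    trainer = Trainer(fast_dev_run=True)
    predictions = trainer.predict(model, datamodule=datamodule)
    print(predictions)

    # ``from_files`` works the same way when you have explicit image paths instead of a folder:
    # datamodule = KeypointDetectionData.from_files(predict_files=["predict_folder/image_1.png"], batch_size=1)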
