KeypointDetectionData

class flash.image.keypoint_detection.data.KeypointDetectionData(train_input=None, val_input=None, test_input=None, predict_input=None, data_fetcher=None, val_split=None, batch_size=None, num_workers=0, sampler=None, pin_memory=True, persistent_workers=False)[source]

The KeypointDetectionData class is a DataModule with a set of classmethods for loading data for keypoint detection.

classmethod from_coco(train_folder=None, train_ann_file=None, val_folder=None, val_ann_file=None, test_folder=None, test_ann_file=None, predict_folder=None, train_transform=<class 'flash.image.keypoint_detection.input_transform.KeypointDetectionInputTransform'>, val_transform=<class 'flash.image.keypoint_detection.input_transform.KeypointDetectionInputTransform'>, test_transform=<class 'flash.image.keypoint_detection.input_transform.KeypointDetectionInputTransform'>, predict_transform=<class 'flash.image.keypoint_detection.input_transform.KeypointDetectionInputTransform'>, input_cls=<class 'flash.core.integrations.icevision.data.IceVisionInput'>, transform_kwargs=None, **data_module_kwargs)[source]

Creates a KeypointDetectionData object from the given data folders and annotation files in the COCO JSON format.

For help understanding and using the COCO format, take a look at this tutorial: Create COCO annotations from scratch.

To learn how to customize the transforms applied for each stage, read our customizing transforms guide.

Parameters
  • train_folder (Optional[str]) – The folder containing images to use when training.
  • train_ann_file (Optional[str]) – The COCO format annotation file to use when training.
  • val_folder (Optional[str]) – The folder containing images to use when validating.
  • val_ann_file (Optional[str]) – The COCO format annotation file to use when validating.
  • test_folder (Optional[str]) – The folder containing images to use when testing.
  • test_ann_file (Optional[str]) – The COCO format annotation file to use when testing.
  • predict_folder (Optional[str]) – The folder containing images to use when predicting.
  • train_transform (INPUT_TRANSFORM_TYPE) – The transform to apply to samples when training.
  • val_transform (INPUT_TRANSFORM_TYPE) – The transform to apply to samples when validating.
  • test_transform (INPUT_TRANSFORM_TYPE) – The transform to apply to samples when testing.
  • predict_transform (INPUT_TRANSFORM_TYPE) – The transform to apply to samples when predicting.
  • input_cls (Type[Input]) – The Input used to create the dataset.
  • transform_kwargs (Optional[Dict]) – Keyword arguments provided to the transform on instantiation.
  • data_module_kwargs (Any) – The keyword arguments for creating the datamodule.

Returns

The constructed KeypointDetectionData.

Examples

The folder train_folder has the following contents:

train_folder
├── image_1.png
├── image_2.png
├── image_3.png
...

The file train_annotations.json contains the following:

{
    "annotations": [
        {
            "area": 50, "bbox": [10, 20, 5, 10], "num_keypoints": 2, "keypoints": [10, 15, 2, 20, 30, 2],
            "category_id": 1, "id": 1, "image_id": 1, "iscrowd": 0
        }, {
            "area": 100, "bbox": [20, 30, 10, 10], "num_keypoints": 2, "keypoints": [20, 30, 2, 30, 40, 2],
            "category_id": 2, "id": 2, "image_id": 2, "iscrowd": 0
        }, {
            "area": 125, "bbox": [10, 20, 5, 25], "num_keypoints": 2, "keypoints": [10, 15, 2, 20, 45, 2],
            "category_id": 1, "id": 3, "image_id": 3, "iscrowd": 0
        }
    ], "categories": [
        {"id": 1, "name": "cat", "supercategory": "cat", "keypoints": ["left ear", "right ear"]},
        {"id": 2, "name": "dog", "supercategory": "dog", "keypoints": ["left ear", "right ear"]}
    ], "images": [
        {"file_name": "image_1.png", "height": 64, "width": 64, "id": 1},
        {"file_name": "image_2.png", "height": 64, "width": 64, "id": 2},
        {"file_name": "image_3.png", "height": 64, "width": 64, "id": 3}
    ]
}
>>> from flash import Trainer
>>> from flash.image import KeypointDetector, KeypointDetectionData
>>> datamodule = KeypointDetectionData.from_coco(
...     train_folder="train_folder",
...     train_ann_file="train_annotations.json",
...     predict_folder="predict_folder",
...     transform_kwargs=dict(image_size=(128, 128)),
...     batch_size=2,
... )
>>> datamodule.num_classes
3
>>> datamodule.labels
['background', 'cat', 'dog']
>>> model = KeypointDetector(2, num_classes=datamodule.num_classes)
>>> trainer = Trainer(fast_dev_run=True)
>>> trainer.fit(model, datamodule=datamodule)  
Training...
>>> trainer.predict(model, datamodule=datamodule)  
Predicting...
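
Validation and test data follow the same pattern as the training data above. The sketch below is a minimal illustration only: the val_folder, test_folder, val_annotations.json, and test_annotations.json paths are hypothetical stand-ins for directories and COCO annotation files laid out like train_folder and train_annotations.json.

>>> datamodule = KeypointDetectionData.from_coco(
...     train_folder="train_folder",
...     train_ann_file="train_annotations.json",
...     val_folder="val_folder",
...     val_ann_file="val_annotations.json",
...     test_folder="test_folder",
...     test_ann_file="test_annotations.json",
...     transform_kwargs=dict(image_size=(128, 128)),
...     batch_size=2,
... )
>>> trainer.fit(model, datamodule=datamodule)
Training...
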
classmethod from_files(predict_files=None, predict_transform=<class 'flash.image.keypoint_detection.input_transform.KeypointDetectionInputTransform'>, input_cls=<class 'flash.core.integrations.icevision.data.IceVisionInput'>, transform_kwargs=None, **data_module_kwargs)[source]

Creates a DataModule object from the given list of files.

This is supported only for the predicting stage.

Parameters
  • predict_files (Optional[List[str]]) – The list of files containing the predict data.

  • predict_transform (INPUT_TRANSFORM_TYPE) – The transform to apply to samples when predicting.

  • input_cls (Type[Input]) – The Input used to create the dataset.

  • transform_kwargs (Optional[Dict]) – Keyword arguments provided to the transform on instantiation.

  • data_module_kwargs (Any) – The keyword arguments for creating the datamodule.

Return type

DataModule

Returns

The constructed data module.
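
A minimal sketch of predicting from a list of files, assuming the image files exist on disk and reusing the model and trainer from the from_coco example above; the file names are hypothetical:

>>> datamodule = KeypointDetectionData.from_files(
...     predict_files=["image_1.png", "image_2.png", "image_3.png"],
...     transform_kwargs=dict(image_size=(128, 128)),
...     batch_size=2,
... )
>>> trainer.predict(model, datamodule=datamodule)
Predicting...
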

classmethod from_folders(predict_folder=None, predict_transform=<class 'flash.image.keypoint_detection.input_transform.KeypointDetectionInputTransform'>, input_cls=<class 'flash.core.integrations.icevision.data.IceVisionInput'>, transform_kwargs=None, **data_module_kwargs)[source]

Creates a DataModule object from the given folders.

This is supported only for the predicting stage.

Parameters
  • predict_folder (Optional[str]) – The folder containing the predict data.

  • predict_transform (INPUT_TRANSFORM_TYPE) – The transform to apply to samples when predicting.

  • input_cls (Type[Input]) – The Input used to create the dataset.

  • transform_kwargs (Optional[Dict]) – Keyword arguments provided to the transform on instantiation.

  • data_module_kwargs (Any) – The keyword arguments for creating the datamodule.

Return type

DataModule

Returns

The constructed data module.
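
A minimal sketch of predicting from a folder of images, assuming a predict_folder directory like the one in the from_coco example and reusing the same model and trainer:

>>> datamodule = KeypointDetectionData.from_folders(
...     predict_folder="predict_folder",
...     transform_kwargs=dict(image_size=(128, 128)),
...     batch_size=2,
... )
>>> trainer.predict(model, datamodule=datamodule)
Predicting...
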
