SemanticSegmentationData¶
- class flash.image.segmentation.data.SemanticSegmentationData(train_input=None, val_input=None, test_input=None, predict_input=None, data_fetcher=None, transform=<class 'flash.core.data.io.input_transform.InputTransform'>, transform_kwargs=None, val_split=None, batch_size=None, num_workers=0, sampler=None, pin_memory=True, persistent_workers=False)[source]¶
The SemanticSegmentationData class is a DataModule with a set of classmethods for loading data for semantic segmentation.

- classmethod from_fiftyone(train_dataset=None, val_dataset=None, test_dataset=None, predict_dataset=None, input_cls=<class 'flash.image.segmentation.input.SemanticSegmentationFiftyOneInput'>, num_classes=None, labels_map=None, transform=<class 'flash.image.segmentation.input_transform.SemanticSegmentationInputTransform'>, transform_kwargs=None, label_field='ground_truth', **data_module_kwargs)[source]¶
Load the SemanticSegmentationData from FiftyOne SampleCollection objects.

The supported file extensions are: .jpg, .jpeg, .png, .ppm, .bmp, .pgm, .tif, .tiff, .webp, and .npy. Mask image file paths will be extracted from the label_field in the SampleCollection objects. To learn how to customize the transforms applied for each stage, read our customizing transforms guide.

- Parameters
  - train_dataset (Optional[object]) – The SampleCollection to use when training.
  - val_dataset (Optional[object]) – The SampleCollection to use when validating.
  - test_dataset (Optional[object]) – The SampleCollection to use when testing.
  - predict_dataset (Optional[object]) – The SampleCollection to use when predicting.
  - label_field (str) – The field in the SampleCollection objects containing the targets.
  - input_cls (Type[Input]) – The Input type to use for loading the data.
  - num_classes (Optional[int]) – The number of segmentation classes.
  - labels_map (Optional[Dict[int, Tuple[int, int, int]]]) – An optional mapping from class to RGB tuple indicating the colour to use when visualizing masks. If not provided, a random mapping will be used.
  - transform (INPUT_TRANSFORM_TYPE) – The InputTransform type to use.
  - transform_kwargs (Optional[Dict]) – Dict of keyword arguments to be provided when instantiating the transforms.
  - data_module_kwargs (Any) – Additional keyword arguments to provide to the DataModule constructor.
- Return type
  SemanticSegmentationData
- Returns
  The constructed SemanticSegmentationData.
Examples
>>> import numpy as np
>>> import fiftyone as fo
>>> from flash import Trainer
>>> from flash.image import SemanticSegmentation, SemanticSegmentationData
>>> train_dataset = fo.Dataset.from_images(
...     ["image_1.png", "image_2.png", "image_3.png"]
... )
>>> samples = [train_dataset[filepath] for filepath in train_dataset.values("filepath")]
>>> for sample in samples:
...     sample["ground_truth"] = fo.Segmentation(mask=np.random.randint(0, 10, (64, 64), dtype="uint8"))
...     sample.save()
...
>>> predict_dataset = fo.Dataset.from_images(
...     ["predict_image_1.png", "predict_image_2.png", "predict_image_3.png"]
... )
>>> datamodule = SemanticSegmentationData.from_fiftyone(
...     train_dataset=train_dataset,
...     predict_dataset=predict_dataset,
...     transform_kwargs=dict(image_size=(128, 128)),
...     num_classes=10,
...     batch_size=2,
... )
>>> datamodule.num_classes
10
>>> model = SemanticSegmentation(backbone="resnet18", num_classes=datamodule.num_classes)
>>> trainer = Trainer(fast_dev_run=True)
>>> trainer.fit(model, datamodule=datamodule)
Training...
>>> trainer.predict(model, datamodule=datamodule)
Predicting...
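A labels_map can also be built by hand. The sketch below is illustrative only: the class colours are invented, and random_labels_map is a hypothetical helper (not part of Flash) that merely mirrors the documented fallback of picking a random RGB tuple per class.

```python
import random

# Hand-written mapping from class index to RGB colour
# (invented example classes -- adapt to your own dataset).
labels_map = {
    0: (0, 0, 0),      # class 0 (e.g. background) -> black
    1: (255, 0, 0),    # class 1 -> red
    2: (0, 255, 0),    # class 2 -> green
}


def random_labels_map(num_classes, seed=42):
    # Hypothetical helper mirroring the documented fallback:
    # one random RGB tuple per class when no labels_map is given.
    rng = random.Random(seed)
    return {i: tuple(rng.randrange(256) for _ in range(3)) for i in range(num_classes)}


print(len(random_labels_map(10)))
```

Passing such a dict as labels_map=labels_map keeps mask colours stable across runs, which the random fallback does not.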
- classmethod from_files(train_files=None, train_targets=None, val_files=None, val_targets=None, test_files=None, test_targets=None, predict_files=None, input_cls=<class 'flash.image.segmentation.input.SemanticSegmentationFilesInput'>, num_classes=None, labels_map=None, transform=<class 'flash.image.segmentation.input_transform.SemanticSegmentationInputTransform'>, transform_kwargs=None, **data_module_kwargs)[source]¶
Load the SemanticSegmentationData from lists of image files and corresponding lists of mask files.

The supported file extensions are: .jpg, .jpeg, .png, .ppm, .bmp, .pgm, .tif, .tiff, .webp, and .npy. To learn how to customize the transforms applied for each stage, read our customizing transforms guide.

- Parameters
  - train_files (Optional[Sequence[str]]) – The list of image files to use when training.
  - train_targets (Optional[Sequence[str]]) – The list of mask files to use when training.
  - val_files (Optional[Sequence[str]]) – The list of image files to use when validating.
  - val_targets (Optional[Sequence[str]]) – The list of mask files to use when validating.
  - test_files (Optional[Sequence[str]]) – The list of image files to use when testing.
  - test_targets (Optional[Sequence[str]]) – The list of mask files to use when testing.
  - predict_files (Optional[Sequence[str]]) – The list of image files to use when predicting.
  - input_cls (Type[Input]) – The Input type to use for loading the data.
  - num_classes (Optional[int]) – The number of segmentation classes.
  - labels_map (Optional[Dict[int, Tuple[int, int, int]]]) – An optional mapping from class to RGB tuple indicating the colour to use when visualizing masks. If not provided, a random mapping will be used.
  - transform (INPUT_TRANSFORM_TYPE) – The InputTransform type to use.
  - transform_kwargs (Optional[Dict]) – Dict of keyword arguments to be provided when instantiating the transforms.
  - data_module_kwargs (Any) – Additional keyword arguments to provide to the DataModule constructor.
- Return type
  SemanticSegmentationData
- Returns
  The constructed SemanticSegmentationData.
Examples
>>> from flash import Trainer
>>> from flash.image import SemanticSegmentation, SemanticSegmentationData
>>> datamodule = SemanticSegmentationData.from_files(
...     train_files=["image_1.png", "image_2.png", "image_3.png"],
...     train_targets=["mask_1.npy", "mask_2.npy", "mask_3.npy"],
...     predict_files=["predict_image_1.png", "predict_image_2.png", "predict_image_3.png"],
...     transform_kwargs=dict(image_size=(128, 128)),
...     num_classes=10,
...     batch_size=2,
... )
>>> datamodule.num_classes
10
>>> model = SemanticSegmentation(backbone="resnet18", num_classes=datamodule.num_classes)
>>> trainer = Trainer(fast_dev_run=True)
>>> trainer.fit(model, datamodule=datamodule)
Training...
>>> trainer.predict(model, datamodule=datamodule)
Predicting...
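The train_files and train_targets lists must be parallel: the i-th mask belongs to the i-th image. When images and masks live in different places, one way to assemble the two lists is to match filename stems. The helper below is a hypothetical sketch (not part of Flash):

```python
from pathlib import Path


def pair_files(image_paths, mask_paths):
    # Index masks by filename stem, then walk the images in order,
    # keeping only those with a matching mask.
    masks_by_stem = {Path(p).stem: p for p in mask_paths}
    files, targets = [], []
    for image in image_paths:
        stem = Path(image).stem
        if stem in masks_by_stem:
            files.append(image)
            targets.append(masks_by_stem[stem])
    return files, targets


# Mask list may be unordered or incomplete; pairing is by stem.
files, targets = pair_files(
    ["image_1.png", "image_2.png", "image_3.png"],
    ["image_2.npy", "image_1.npy"],
)
print(files)    # ['image_1.png', 'image_2.png']
print(targets)  # ['image_1.npy', 'image_2.npy']
```

The resulting parallel lists can be passed directly as train_files and train_targets.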
- classmethod from_folders(train_folder=None, train_target_folder=None, val_folder=None, val_target_folder=None, test_folder=None, test_target_folder=None, predict_folder=None, input_cls=<class 'flash.image.segmentation.input.SemanticSegmentationFolderInput'>, num_classes=None, labels_map=None, transform=<class 'flash.image.segmentation.input_transform.SemanticSegmentationInputTransform'>, transform_kwargs=None, **data_module_kwargs)[source]¶
Load the SemanticSegmentationData from folders containing image files and folders containing mask files.

The supported file extensions are: .jpg, .jpeg, .png, .ppm, .bmp, .pgm, .tif, .tiff, .webp, and .npy. For train, validation, and test data, each image folder is expected to have a corresponding target folder containing, for each image, a mask file with the same name. For example, if your train_images folder (passed to the train_folder argument) looks like this:

train_images
├── image_1.png
├── image_2.png
├── image_3.png
...
your train_masks folder (passed to the train_target_folder argument) would need to look like this (although the file extensions could be different):

train_masks
├── image_1.png
├── image_2.png
├── image_3.png
...
For prediction, the folder is expected to contain the files for inference, like this:

predict_folder
├── predict_image_1.png
├── predict_image_2.png
├── predict_image_3.png
...
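The invariant behind this layout is simply that the image folder and the target folder share filenames. A small standard-library sketch that creates placeholder files in that shape and checks the invariant (the files here are empty stand-ins, not real images):

```python
import tempfile
from pathlib import Path

# Build the expected layout under a throwaway directory.
root = Path(tempfile.mkdtemp())
images = root / "train_images"
masks = root / "train_masks"
images.mkdir()
masks.mkdir()

for name in ["image_1.png", "image_2.png", "image_3.png"]:
    (images / name).touch()  # placeholder image file
    (masks / name).touch()   # mask shares the image's filename

# The property from_folders relies on: matching names in both folders.
assert sorted(p.name for p in images.iterdir()) == sorted(p.name for p in masks.iterdir())
print(sorted(p.name for p in images.iterdir()))
```

With that invariant satisfied, str(images) and str(masks) can be passed as train_folder and train_target_folder.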
To learn how to customize the transforms applied for each stage, read our customizing transforms guide.
- Parameters
  - train_folder (Optional[str]) – The folder containing images to use when training.
  - train_target_folder (Optional[str]) – The folder containing masks to use when training (files should have the same names as the files in the train_folder).
  - val_folder (Optional[str]) – The folder containing images to use when validating.
  - val_target_folder (Optional[str]) – The folder containing masks to use when validating (files should have the same names as the files in the val_folder).
  - test_folder (Optional[str]) – The folder containing images to use when testing.
  - test_target_folder (Optional[str]) – The folder containing masks to use when testing (files should have the same names as the files in the test_folder).
  - predict_folder (Optional[str]) – The folder containing images to use when predicting.
  - input_cls (Type[Input]) – The Input type to use for loading the data.
  - num_classes (Optional[int]) – The number of segmentation classes.
  - labels_map (Optional[Dict[int, Tuple[int, int, int]]]) – An optional mapping from class to RGB tuple indicating the colour to use when visualizing masks. If not provided, a random mapping will be used.
  - transform (INPUT_TRANSFORM_TYPE) – The InputTransform type to use.
  - transform_kwargs (Optional[Dict]) – Dict of keyword arguments to be provided when instantiating the transforms.
  - data_module_kwargs (Any) – Additional keyword arguments to provide to the DataModule constructor.
- Return type
  SemanticSegmentationData
- Returns
  The constructed SemanticSegmentationData.
Examples
>>> from flash import Trainer
>>> from flash.image import SemanticSegmentation, SemanticSegmentationData
>>> datamodule = SemanticSegmentationData.from_folders(
...     train_folder="train_images",
...     train_target_folder="train_masks",
...     predict_folder="predict_folder",
...     transform_kwargs=dict(image_size=(128, 128)),
...     num_classes=10,
...     batch_size=2,
... )
>>> datamodule.num_classes
10
>>> model = SemanticSegmentation(backbone="resnet18", num_classes=datamodule.num_classes)
>>> trainer = Trainer(fast_dev_run=True)
>>> trainer.fit(model, datamodule=datamodule)
Training...
>>> trainer.predict(model, datamodule=datamodule)
Predicting...
- classmethod from_numpy(train_data=None, train_targets=None, val_data=None, val_targets=None, test_data=None, test_targets=None, predict_data=None, input_cls=<class 'flash.image.segmentation.input.SemanticSegmentationNumpyInput'>, num_classes=None, labels_map=None, transform=<class 'flash.image.segmentation.input_transform.SemanticSegmentationInputTransform'>, transform_kwargs=None, **data_module_kwargs)[source]¶
Load the SemanticSegmentationData from numpy arrays (or lists of arrays) containing images and corresponding numpy arrays (or lists of arrays) containing masks.

To learn how to customize the transforms applied for each stage, read our customizing transforms guide.

- Parameters
  - train_data (Optional[Collection[ndarray]]) – The numpy array or list of arrays containing images to use when training.
  - train_targets (Optional[Collection[ndarray]]) – The numpy array or list of arrays containing masks to use when training.
  - val_data (Optional[Collection[ndarray]]) – The numpy array or list of arrays containing images to use when validating.
  - val_targets (Optional[Collection[ndarray]]) – The numpy array or list of arrays containing masks to use when validating.
  - test_data (Optional[Collection[ndarray]]) – The numpy array or list of arrays containing images to use when testing.
  - test_targets (Optional[Collection[ndarray]]) – The numpy array or list of arrays containing masks to use when testing.
  - predict_data (Optional[Collection[ndarray]]) – The numpy array or list of arrays to use when predicting.
  - input_cls (Type[Input]) – The Input type to use for loading the data.
  - num_classes (Optional[int]) – The number of segmentation classes.
  - labels_map (Optional[Dict[int, Tuple[int, int, int]]]) – An optional mapping from class to RGB tuple indicating the colour to use when visualizing masks. If not provided, a random mapping will be used.
  - transform (INPUT_TRANSFORM_TYPE) – The InputTransform type to use.
  - transform_kwargs (Optional[Dict]) – Dict of keyword arguments to be provided when instantiating the transforms.
  - data_module_kwargs (Any) – Additional keyword arguments to provide to the DataModule constructor.
- Return type
  SemanticSegmentationData
- Returns
  The constructed SemanticSegmentationData.
Examples
>>> import numpy as np
>>> from flash import Trainer
>>> from flash.image import SemanticSegmentation, SemanticSegmentationData
>>> datamodule = SemanticSegmentationData.from_numpy(
...     train_data=[np.random.rand(3, 64, 64), np.random.rand(3, 64, 64), np.random.rand(3, 64, 64)],
...     train_targets=[
...         np.random.randint(0, 10, (1, 64, 64), dtype="uint8"),
...         np.random.randint(0, 10, (1, 64, 64), dtype="uint8"),
...         np.random.randint(0, 10, (1, 64, 64), dtype="uint8"),
...     ],
...     predict_data=[np.random.rand(3, 64, 64)],
...     transform_kwargs=dict(image_size=(128, 128)),
...     num_classes=10,
...     batch_size=2,
... )
>>> datamodule.num_classes
10
>>> model = SemanticSegmentation(backbone="resnet18", num_classes=datamodule.num_classes)
>>> trainer = Trainer(fast_dev_run=True)
>>> trainer.fit(model, datamodule=datamodule)
Training...
>>> trainer.predict(model, datamodule=datamodule)
Predicting...
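The example above uses channels-first conventions: images are (C, H, W) float arrays and masks are (1, H, W) integer arrays of class indices in the range [0, num_classes). A short sketch of those assumed shape conventions:

```python
import numpy as np

num_classes = 10

# Image: channels-first float array, shape (C, H, W).
image = np.random.rand(3, 64, 64)

# Mask: one channel of integer class indices, shape (1, H, W),
# with every value a valid class index.
mask = np.random.randint(0, num_classes, (1, 64, 64), dtype="uint8")

assert image.shape == (3, 64, 64)
assert mask.shape == (1, 64, 64)
assert mask.min() >= 0 and mask.max() < num_classes
print(image.shape, mask.shape)
```

Lists of such arrays are what train_data and train_targets accept; the spatial size may differ from (64, 64), since transform_kwargs=dict(image_size=...) resizes inputs.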
- classmethod from_tensors(train_data=None, train_targets=None, val_data=None, val_targets=None, test_data=None, test_targets=None, predict_data=None, input_cls=<class 'flash.image.segmentation.input.SemanticSegmentationTensorInput'>, num_classes=None, labels_map=None, transform=<class 'flash.image.segmentation.input_transform.SemanticSegmentationInputTransform'>, transform_kwargs=None, **data_module_kwargs)[source]¶
Load the SemanticSegmentationData from torch tensors (or lists of tensors) containing images and corresponding torch tensors (or lists of tensors) containing masks.

To learn how to customize the transforms applied for each stage, read our customizing transforms guide.

- Parameters
  - train_data (Optional[Collection[Tensor]]) – The torch tensor or list of tensors containing images to use when training.
  - train_targets (Optional[Collection[Tensor]]) – The torch tensor or list of tensors containing masks to use when training.
  - val_data (Optional[Collection[Tensor]]) – The torch tensor or list of tensors containing images to use when validating.
  - val_targets (Optional[Collection[Tensor]]) – The torch tensor or list of tensors containing masks to use when validating.
  - test_data (Optional[Collection[Tensor]]) – The torch tensor or list of tensors containing images to use when testing.
  - test_targets (Optional[Collection[Tensor]]) – The torch tensor or list of tensors containing masks to use when testing.
  - predict_data (Optional[Collection[Tensor]]) – The torch tensor or list of tensors to use when predicting.
  - input_cls (Type[Input]) – The Input type to use for loading the data.
  - num_classes (Optional[int]) – The number of segmentation classes.
  - labels_map (Optional[Dict[int, Tuple[int, int, int]]]) – An optional mapping from class to RGB tuple indicating the colour to use when visualizing masks. If not provided, a random mapping will be used.
  - transform (INPUT_TRANSFORM_TYPE) – The InputTransform type to use.
  - transform_kwargs (Optional[Dict]) – Dict of keyword arguments to be provided when instantiating the transforms.
  - data_module_kwargs (Any) – Additional keyword arguments to provide to the DataModule constructor.
- Return type
  SemanticSegmentationData
- Returns
  The constructed SemanticSegmentationData.
Examples
>>> import torch
>>> from flash import Trainer
>>> from flash.image import SemanticSegmentation, SemanticSegmentationData
>>> datamodule = SemanticSegmentationData.from_tensors(
...     train_data=[torch.rand(3, 64, 64), torch.rand(3, 64, 64), torch.rand(3, 64, 64)],
...     train_targets=[
...         torch.randint(10, (1, 64, 64)),
...         torch.randint(10, (1, 64, 64)),
...         torch.randint(10, (1, 64, 64)),
...     ],
...     predict_data=[torch.rand(3, 64, 64)],
...     transform_kwargs=dict(image_size=(128, 128)),
...     num_classes=10,
...     batch_size=2,
... )
>>> datamodule.num_classes
10
>>> model = SemanticSegmentation(backbone="resnet18", num_classes=datamodule.num_classes)
>>> trainer = Trainer(fast_dev_run=True)
>>> trainer.fit(model, datamodule=datamodule)
Training...
>>> trainer.predict(model, datamodule=datamodule)
Predicting...
- input_transform_cls¶
alias of flash.image.segmentation.input_transform.SemanticSegmentationInputTransform