SemanticSegmentationData
- class flash.image.segmentation.data.SemanticSegmentationData(train_input=None, val_input=None, test_input=None, predict_input=None, data_fetcher=None, transform=<class 'flash.core.data.io.input_transform.InputTransform'>, transform_kwargs=None, val_split=None, batch_size=None, num_workers=0, sampler=None, pin_memory=True, persistent_workers=False)
The SemanticSegmentationData class is a DataModule with a set of classmethods for loading data for semantic segmentation.
- classmethod from_fiftyone(train_dataset=None, val_dataset=None, test_dataset=None, predict_dataset=None, input_cls=<class 'flash.image.segmentation.input.SemanticSegmentationFiftyOneInput'>, num_classes=None, labels_map=None, transform=<class 'flash.image.segmentation.input_transform.SemanticSegmentationInputTransform'>, transform_kwargs=None, label_field='ground_truth', **data_module_kwargs)
Load the SemanticSegmentationData from FiftyOne SampleCollection objects.
The supported file extensions are: .jpg, .jpeg, .png, .ppm, .bmp, .pgm, .tif, .tiff, .webp, and .npy. Mask image file paths will be extracted from the label_field in the SampleCollection objects. To learn how to customize the transforms applied for each stage, read our customizing transforms guide.
- Parameters
train_dataset (Optional[object]) – The SampleCollection to use when training.
val_dataset (Optional[object]) – The SampleCollection to use when validating.
test_dataset (Optional[object]) – The SampleCollection to use when testing.
predict_dataset (Optional[object]) – The SampleCollection to use when predicting.
label_field (str) – The field in the SampleCollection objects containing the targets.
input_cls (Type[Input]) – The Input type to use for loading the data.
num_classes (Optional[int]) – The number of segmentation classes.
labels_map (Optional[Dict[int, Tuple[int, int, int]]]) – An optional mapping from class to RGB tuple indicating the colour to use when visualizing masks. If not provided, a random mapping will be used.
transform (TypeVar(INPUT_TRANSFORM_TYPE, Type[flash.core.data.io.input_transform.InputTransform], Callable, Tuple[Union[StrEnum, str], Dict[str, Any]], Union[StrEnum, str], None)) – The InputTransform type to use.
transform_kwargs (Optional[Dict]) – Dict of keyword arguments to be provided when instantiating the transforms.
data_module_kwargs (Any) – Additional keyword arguments to provide to the DataModule constructor.
- Return type
SemanticSegmentationData
- Returns
The constructed SemanticSegmentationData.
Examples
>>> import numpy as np
>>> import fiftyone as fo
>>> from flash import Trainer
>>> from flash.image import SemanticSegmentation, SemanticSegmentationData
>>> train_dataset = fo.Dataset.from_images(
...     ["image_1.png", "image_2.png", "image_3.png"]
... )
...
>>> samples = [train_dataset[filepath] for filepath in train_dataset.values("filepath")]
>>> for sample in samples:
...     sample["ground_truth"] = fo.Segmentation(mask=np.random.randint(0, 10, (64, 64), dtype="uint8"))
...     sample.save()
...
>>> predict_dataset = fo.Dataset.from_images(
...     ["predict_image_1.png", "predict_image_2.png", "predict_image_3.png"]
... )
...
>>> datamodule = SemanticSegmentationData.from_fiftyone(
...     train_dataset=train_dataset,
...     predict_dataset=predict_dataset,
...     transform_kwargs=dict(image_size=(128, 128)),
...     num_classes=10,
...     batch_size=2,
... )
>>> datamodule.num_classes
10
>>> model = SemanticSegmentation(backbone="resnet18", num_classes=datamodule.num_classes)
>>> trainer = Trainer(fast_dev_run=True)
>>> trainer.fit(model, datamodule=datamodule)
Training...
>>> trainer.predict(model, datamodule=datamodule)
Predicting...
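The labels_map parameter above takes a mapping from class index to an RGB tuple. As a rough illustration of that shape only, the sketch below builds such a mapping with the standard library and applies it to a tiny mask. make_labels_map and colorize are hypothetical helpers written for this example, not Flash functions, and the random mapping Flash falls back to may be generated differently.

```python
import random
from typing import Dict, List, Tuple

RGB = Tuple[int, int, int]


def make_labels_map(num_classes: int, seed: int = 0) -> Dict[int, RGB]:
    """Hypothetical helper: assign one random RGB colour to each class index."""
    rng = random.Random(seed)  # seeded so the colours are reproducible
    return {
        index: (rng.randrange(256), rng.randrange(256), rng.randrange(256))
        for index in range(num_classes)
    }


def colorize(mask: List[List[int]], labels_map: Dict[int, RGB]) -> List[List[RGB]]:
    """Hypothetical helper: replace each class index in a 2D mask by its colour."""
    return [[labels_map[index] for index in row] for row in mask]


labels_map = make_labels_map(num_classes=10)
mask = [[0, 1], [9, 0]]  # a 2x2 mask of class indices
coloured = colorize(mask, labels_map)
```

A dictionary of this shape matches the documented type Dict[int, Tuple[int, int, int]] and could be passed as the labels_map argument of any classmethod on this page.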
- classmethod from_files(train_files=None, train_targets=None, val_files=None, val_targets=None, test_files=None, test_targets=None, predict_files=None, input_cls=<class 'flash.image.segmentation.input.SemanticSegmentationFilesInput'>, num_classes=None, labels_map=None, transform=<class 'flash.image.segmentation.input_transform.SemanticSegmentationInputTransform'>, transform_kwargs=None, **data_module_kwargs)
Load the SemanticSegmentationData from lists of input files and corresponding lists of mask files.
The supported file extensions are: .jpg, .jpeg, .png, .ppm, .bmp, .pgm, .tif, .tiff, .webp, and .npy. To learn how to customize the transforms applied for each stage, read our customizing transforms guide.
- Parameters
train_files (Optional[Sequence[str]]) – The list of image files to use when training.
train_targets (Optional[Sequence[str]]) – The list of mask files to use when training.
val_files (Optional[Sequence[str]]) – The list of image files to use when validating.
val_targets (Optional[Sequence[str]]) – The list of mask files to use when validating.
test_files (Optional[Sequence[str]]) – The list of image files to use when testing.
test_targets (Optional[Sequence[str]]) – The list of mask files to use when testing.
predict_files (Optional[Sequence[str]]) – The list of image files to use when predicting.
input_cls (Type[Input]) – The Input type to use for loading the data.
num_classes (Optional[int]) – The number of segmentation classes.
labels_map (Optional[Dict[int, Tuple[int, int, int]]]) – An optional mapping from class to RGB tuple indicating the colour to use when visualizing masks. If not provided, a random mapping will be used.
transform (TypeVar(INPUT_TRANSFORM_TYPE, Type[flash.core.data.io.input_transform.InputTransform], Callable, Tuple[Union[StrEnum, str], Dict[str, Any]], Union[StrEnum, str], None)) – The InputTransform type to use.
transform_kwargs (Optional[Dict]) – Dict of keyword arguments to be provided when instantiating the transforms.
data_module_kwargs (Any) – Additional keyword arguments to provide to the DataModule constructor.
- Return type
SemanticSegmentationData
- Returns
The constructed SemanticSegmentationData.
Examples
>>> from flash import Trainer
>>> from flash.image import SemanticSegmentation, SemanticSegmentationData
>>> datamodule = SemanticSegmentationData.from_files(
...     train_files=["image_1.png", "image_2.png", "image_3.png"],
...     train_targets=["mask_1.npy", "mask_2.npy", "mask_3.npy"],
...     predict_files=["predict_image_1.png", "predict_image_2.png", "predict_image_3.png"],
...     transform_kwargs=dict(image_size=(128, 128)),
...     num_classes=10,
...     batch_size=2,
... )
>>> datamodule.num_classes
10
>>> model = SemanticSegmentation(backbone="resnet18", num_classes=datamodule.num_classes)
>>> trainer = Trainer(fast_dev_run=True)
>>> trainer.fit(model, datamodule=datamodule)
Training...
>>> trainer.predict(model, datamodule=datamodule)
Predicting...
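Because from_files takes the images and the masks as two separate lists, it can help to check that the lists line up before building the data module. The sketch below is one possible pre-flight check using only the standard library. check_file_lists is a hypothetical helper, not part of Flash, and comparing extensions case-insensitively is an assumption of this example rather than documented behaviour.

```python
from pathlib import Path
from typing import Sequence

# Extensions listed as supported in the from_files documentation above.
SUPPORTED_EXTENSIONS = {
    ".jpg", ".jpeg", ".png", ".ppm", ".bmp", ".pgm", ".tif", ".tiff", ".webp", ".npy",
}


def check_file_lists(files: Sequence[str], targets: Sequence[str]) -> None:
    """Hypothetical check: equal list lengths and supported extensions only."""
    if len(files) != len(targets):
        raise ValueError(f"Got {len(files)} image files but {len(targets)} mask files.")
    for path in list(files) + list(targets):
        # Assumption: extensions are compared case-insensitively.
        if Path(path).suffix.lower() not in SUPPORTED_EXTENSIONS:
            raise ValueError(f"Unsupported file extension: {path}")


# Passes silently for the lists used in the example above.
check_file_lists(
    ["image_1.png", "image_2.png", "image_3.png"],
    ["mask_1.npy", "mask_2.npy", "mask_3.npy"],
)
```

The check only inspects the file names, so it does not confirm that the files exist or that each mask has the same spatial size as its image.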
- classmethod from_folders(train_folder=None, train_target_folder=None, val_folder=None, val_target_folder=None, test_folder=None, test_target_folder=None, predict_folder=None, input_cls=<class 'flash.image.segmentation.input.SemanticSegmentationFolderInput'>, num_classes=None, labels_map=None, transform=<class 'flash.image.segmentation.input_transform.SemanticSegmentationInputTransform'>, transform_kwargs=None, **data_module_kwargs)
Load the SemanticSegmentationData from folders containing image files and folders containing mask files.
The supported file extensions are: .jpg, .jpeg, .png, .ppm, .bmp, .pgm, .tif, .tiff, .webp, and .npy. For train, test, and validation data, each image folder is expected to have a corresponding target folder that contains the mask for each image in a file of the same name. For example, if your train_images folder (passed to the train_folder argument) looks like this:
train_images
├── image_1.png
├── image_2.png
├── image_3.png
...
your train_masks folder (passed to the train_target_folder argument) would need to look like this (although the file extensions could be different):
train_masks
├── image_1.png
├── image_2.png
├── image_3.png
...
For prediction, the folder is expected to contain the files for inference, like this:
predict_folder
├── predict_image_1.png
├── predict_image_2.png
├── predict_image_3.png
...
To learn how to customize the transforms applied for each stage, read our customizing transforms guide.
- Parameters
train_folder (Optional[str]) – The folder containing images to use when training.
train_target_folder (Optional[str]) – The folder containing masks to use when training (files should have the same name as the files in the train_folder).
val_folder (Optional[str]) – The folder containing images to use when validating.
val_target_folder (Optional[str]) – The folder containing masks to use when validating (files should have the same name as the files in the val_folder).
test_folder (Optional[str]) – The folder containing images to use when testing.
test_target_folder (Optional[str]) – The folder containing masks to use when testing (files should have the same name as the files in the test_folder).
predict_folder (Optional[str]) – The folder containing images to use when predicting.
input_cls (Type[Input]) – The Input type to use for loading the data.
num_classes (Optional[int]) – The number of segmentation classes.
labels_map (Optional[Dict[int, Tuple[int, int, int]]]) – An optional mapping from class to RGB tuple indicating the colour to use when visualizing masks. If not provided, a random mapping will be used.
transform (TypeVar(INPUT_TRANSFORM_TYPE, Type[flash.core.data.io.input_transform.InputTransform], Callable, Tuple[Union[StrEnum, str], Dict[str, Any]], Union[StrEnum, str], None)) – The InputTransform type to use.
transform_kwargs (Optional[Dict]) – Dict of keyword arguments to be provided when instantiating the transforms.
data_module_kwargs (Any) – Additional keyword arguments to provide to the DataModule constructor.
- Return type
SemanticSegmentationData
- Returns
The constructed SemanticSegmentationData.
Examples
>>> from flash import Trainer
>>> from flash.image import SemanticSegmentation, SemanticSegmentationData
>>> datamodule = SemanticSegmentationData.from_folders(
...     train_folder="train_images",
...     train_target_folder="train_masks",
...     predict_folder="predict_folder",
...     transform_kwargs=dict(image_size=(128, 128)),
...     num_classes=10,
...     batch_size=2,
... )
>>> datamodule.num_classes
10
>>> model = SemanticSegmentation(backbone="resnet18", num_classes=datamodule.num_classes)
>>> trainer = Trainer(fast_dev_run=True)
>>> trainer.fit(model, datamodule=datamodule)
Training...
>>> trainer.predict(model, datamodule=datamodule)
Predicting...
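The folder convention described for from_folders (each mask shares its image's file name, though the extension may differ) can be sanity-checked before training. The sketch below pairs files by name without extension, using only the standard library. pair_by_stem is a hypothetical helper for illustration, and matching on the file stem is an assumption about how "same name" should be read, not a description of Flash's internal matching.

```python
import tempfile
from pathlib import Path
from typing import List, Tuple


def pair_by_stem(image_folder: str, mask_folder: str) -> List[Tuple[Path, Path]]:
    """Hypothetical helper: pair each image with the mask that has the same stem."""
    masks = {path.stem: path for path in Path(mask_folder).iterdir() if path.is_file()}
    pairs = []
    for image in sorted(Path(image_folder).iterdir()):
        if image.is_file():
            if image.stem not in masks:
                raise FileNotFoundError(f"No mask found for {image.name}")
            pairs.append((image, masks[image.stem]))
    return pairs


# Build a throwaway layout like the one documented above to exercise the helper.
with tempfile.TemporaryDirectory() as root:
    images, masks = Path(root, "train_images"), Path(root, "train_masks")
    images.mkdir()
    masks.mkdir()
    for index in (1, 2, 3):
        (images / f"image_{index}.png").touch()
        (masks / f"image_{index}.npy").touch()  # a different extension is allowed
    pairs = pair_by_stem(str(images), str(masks))
    pair_names = [(image.name, mask.name) for image, mask in pairs]
```

Raising on a missing mask surfaces a mismatched folder early, before any data loading starts.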
- classmethod from_numpy(train_data=None, train_targets=None, val_data=None, val_targets=None, test_data=None, test_targets=None, predict_data=None, input_cls=<class 'flash.image.segmentation.input.SemanticSegmentationNumpyInput'>, num_classes=None, labels_map=None, transform=<class 'flash.image.segmentation.input_transform.SemanticSegmentationInputTransform'>, transform_kwargs=None, **data_module_kwargs)
Load the SemanticSegmentationData from numpy arrays containing images (or lists of arrays) and corresponding numpy arrays containing masks (or lists of arrays).
To learn how to customize the transforms applied for each stage, read our customizing transforms guide.
- Parameters
train_data (Optional[Collection[ndarray]]) – The numpy array or list of arrays containing images to use when training.
train_targets (Optional[Collection[ndarray]]) – The numpy array or list of arrays containing masks to use when training.
val_data (Optional[Collection[ndarray]]) – The numpy array or list of arrays containing images to use when validating.
val_targets (Optional[Collection[ndarray]]) – The numpy array or list of arrays containing masks to use when validating.
test_data (Optional[Collection[ndarray]]) – The numpy array or list of arrays containing images to use when testing.
test_targets (Optional[Collection[ndarray]]) – The numpy array or list of arrays containing masks to use when testing.
predict_data (Optional[Collection[ndarray]]) – The numpy array or list of arrays to use when predicting.
input_cls (Type[Input]) – The Input type to use for loading the data.
num_classes (Optional[int]) – The number of segmentation classes.
labels_map (Optional[Dict[int, Tuple[int, int, int]]]) – An optional mapping from class to RGB tuple indicating the colour to use when visualizing masks. If not provided, a random mapping will be used.
transform (TypeVar(INPUT_TRANSFORM_TYPE, Type[flash.core.data.io.input_transform.InputTransform], Callable, Tuple[Union[StrEnum, str], Dict[str, Any]], Union[StrEnum, str], None)) – The InputTransform type to use.
transform_kwargs (Optional[Dict]) – Dict of keyword arguments to be provided when instantiating the transforms.
data_module_kwargs (Any) – Additional keyword arguments to provide to the DataModule constructor.
- Return type
SemanticSegmentationData
- Returns
The constructed SemanticSegmentationData.
Examples
>>> import numpy as np
>>> from flash import Trainer
>>> from flash.image import SemanticSegmentation, SemanticSegmentationData
>>> datamodule = SemanticSegmentationData.from_numpy(
...     train_data=[np.random.rand(3, 64, 64), np.random.rand(3, 64, 64), np.random.rand(3, 64, 64)],
...     train_targets=[
...         np.random.randint(0, 10, (1, 64, 64), dtype="uint8"),
...         np.random.randint(0, 10, (1, 64, 64), dtype="uint8"),
...         np.random.randint(0, 10, (1, 64, 64), dtype="uint8"),
...     ],
...     predict_data=[np.random.rand(3, 64, 64)],
...     transform_kwargs=dict(image_size=(128, 128)),
...     num_classes=10,
...     batch_size=2,
... )
>>> datamodule.num_classes
10
>>> model = SemanticSegmentation(backbone="resnet18", num_classes=datamodule.num_classes)
>>> trainer = Trainer(fast_dev_run=True)
>>> trainer.fit(model, datamodule=datamodule)
Training...
>>> trainer.predict(model, datamodule=datamodule)
Predicting...
- classmethod from_tensors(train_data=None, train_targets=None, val_data=None, val_targets=None, test_data=None, test_targets=None, predict_data=None, input_cls=<class 'flash.image.segmentation.input.SemanticSegmentationTensorInput'>, num_classes=None, labels_map=None, transform=<class 'flash.image.segmentation.input_transform.SemanticSegmentationInputTransform'>, transform_kwargs=None, **data_module_kwargs)
Load the SemanticSegmentationData from torch tensors containing images (or lists of tensors) and corresponding torch tensors containing masks (or lists of tensors).
To learn how to customize the transforms applied for each stage, read our customizing transforms guide.
- Parameters
train_data (Optional[Collection[Tensor]]) – The torch tensor or list of tensors containing images to use when training.
train_targets (Optional[Collection[Tensor]]) – The torch tensor or list of tensors containing masks to use when training.
val_data (Optional[Collection[Tensor]]) – The torch tensor or list of tensors containing images to use when validating.
val_targets (Optional[Collection[Tensor]]) – The torch tensor or list of tensors containing masks to use when validating.
test_data (Optional[Collection[Tensor]]) – The torch tensor or list of tensors containing images to use when testing.
test_targets (Optional[Collection[Tensor]]) – The torch tensor or list of tensors containing masks to use when testing.
predict_data (Optional[Collection[Tensor]]) – The torch tensor or list of tensors to use when predicting.
input_cls (Type[Input]) – The Input type to use for loading the data.
num_classes (Optional[int]) – The number of segmentation classes.
labels_map (Optional[Dict[int, Tuple[int, int, int]]]) – An optional mapping from class to RGB tuple indicating the colour to use when visualizing masks. If not provided, a random mapping will be used.
transform (TypeVar(INPUT_TRANSFORM_TYPE, Type[flash.core.data.io.input_transform.InputTransform], Callable, Tuple[Union[StrEnum, str], Dict[str, Any]], Union[StrEnum, str], None)) – The InputTransform type to use.
transform_kwargs (Optional[Dict]) – Dict of keyword arguments to be provided when instantiating the transforms.
data_module_kwargs (Any) – Additional keyword arguments to provide to the DataModule constructor.
- Return type
SemanticSegmentationData
- Returns
The constructed SemanticSegmentationData.
Examples
>>> import torch
>>> from flash import Trainer
>>> from flash.image import SemanticSegmentation, SemanticSegmentationData
>>> datamodule = SemanticSegmentationData.from_tensors(
...     train_data=[torch.rand(3, 64, 64), torch.rand(3, 64, 64), torch.rand(3, 64, 64)],
...     train_targets=[
...         torch.randint(10, (1, 64, 64)),
...         torch.randint(10, (1, 64, 64)),
...         torch.randint(10, (1, 64, 64)),
...     ],
...     predict_data=[torch.rand(3, 64, 64)],
...     transform_kwargs=dict(image_size=(128, 128)),
...     num_classes=10,
...     batch_size=2,
... )
>>> datamodule.num_classes
10
>>> model = SemanticSegmentation(backbone="resnet18", num_classes=datamodule.num_classes)
>>> trainer = Trainer(fast_dev_run=True)
>>> trainer.fit(model, datamodule=datamodule)
Training...
>>> trainer.predict(model, datamodule=datamodule)
Predicting...
- input_transform_cls
alias of flash.image.segmentation.input_transform.SemanticSegmentationInputTransform