StyleTransferData
- class flash.image.style_transfer.data.StyleTransferData(train_input=None, val_input=None, test_input=None, predict_input=None, data_fetcher=None, val_split=None, batch_size=None, num_workers=0, sampler=None, pin_memory=True, persistent_workers=False)
The StyleTransferData class is a DataModule with a set of classmethods for loading data for image style transfer.
- classmethod from_files(train_files=None, predict_files=None, train_transform=<class 'flash.image.style_transfer.input_transform.StyleTransferInputTransform'>, predict_transform=<class 'flash.image.style_transfer.input_transform.StyleTransferInputTransform'>, input_cls=<class 'flash.image.classification.input.ImageClassificationFilesInput'>, transform_kwargs=None, **data_module_kwargs)
Load the StyleTransferData from lists of image files.
The supported file extensions are: .jpg, .jpeg, .png, .ppm, .bmp, .pgm, .tif, .tiff, .webp, and .npy.
To learn how to customize the transforms applied for each stage, read our customizing transforms guide.
- Parameters
train_files (Optional[Sequence[str]]) – The list of image files to use when training.
predict_files (Optional[Sequence[str]]) – The list of image files to use when predicting.
train_transform (TypeVar(INPUT_TRANSFORM_TYPE, Type[flash.core.data.io.input_transform.InputTransform], Callable, Tuple[Union[LightningEnum, str], Dict[str, Any]], Union[LightningEnum, str], None)) – The InputTransform type to use when training.
predict_transform (TypeVar(INPUT_TRANSFORM_TYPE, Type[flash.core.data.io.input_transform.InputTransform], Callable, Tuple[Union[LightningEnum, str], Dict[str, Any]], Union[LightningEnum, str], None)) – The InputTransform type to use when predicting.
input_cls (Type[Input]) – The Input type to use for loading the data.
transform_kwargs (Optional[Dict]) – Dict of keyword arguments to be provided when instantiating the transforms.
data_module_kwargs (Any) – Additional keyword arguments to provide to the DataModule constructor.
- Return type
StyleTransferData
- Returns
The constructed StyleTransferData.
Examples
>>> from flash import Trainer
>>> from flash.image import StyleTransfer, StyleTransferData
>>> datamodule = StyleTransferData.from_files(
...     train_files=["image_1.png", "image_2.png", "image_3.png"],
...     predict_files=["predict_image_1.png", "predict_image_2.png", "predict_image_3.png"],
...     transform_kwargs=dict(image_size=(128, 128)),
...     batch_size=2,
... )
>>> model = StyleTransfer()
>>> trainer = Trainer(fast_dev_run=True)
>>> trainer.fit(model, datamodule=datamodule)
Training...
>>> trainer.predict(model, datamodule=datamodule)
Predicting...
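The extension list above can be applied as a pre-filter before building the file lists passed to from_files. This is a minimal sketch, not Flash's own check: `SUPPORTED_EXTENSIONS` and `filter_supported` are hypothetical names, and treating extensions case-insensitively is an assumption.

```python
from pathlib import Path
from typing import List, Sequence

# Extensions listed as supported for from_files / from_folders on this page.
SUPPORTED_EXTENSIONS = (
    ".jpg", ".jpeg", ".png", ".ppm", ".bmp",
    ".pgm", ".tif", ".tiff", ".webp", ".npy",
)


def filter_supported(files: Sequence[str]) -> List[str]:
    """Hypothetical helper: keep only paths whose suffix is a supported extension.

    Assumption: the comparison is case-insensitive, so "image_2.JPG" is kept.
    """
    return [f for f in files if Path(f).suffix.lower() in SUPPORTED_EXTENSIONS]


files = ["image_1.png", "notes.txt", "image_2.JPG", "scan.tiff", "archive.zip"]
print(filter_supported(files))
# ['image_1.png', 'image_2.JPG', 'scan.tiff']
```

The filtered list could then be passed as train_files or predict_files.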
- classmethod from_folders(train_folder=None, predict_folder=None, train_transform=<class 'flash.image.style_transfer.input_transform.StyleTransferInputTransform'>, predict_transform=<class 'flash.image.style_transfer.input_transform.StyleTransferInputTransform'>, input_cls=<class 'flash.image.classification.input.ImageClassificationFolderInput'>, transform_kwargs=None, **data_module_kwargs)
Load the StyleTransferData from folders containing images.
The supported file extensions are: .jpg, .jpeg, .png, .ppm, .bmp, .pgm, .tif, .tiff, .webp, and .npy.
Here’s the required folder structure:
train_folder
├── image_1.png
├── image_2.png
├── image_3.png
...
To learn how to customize the transforms applied for each stage, read our customizing transforms guide.
- Parameters
train_folder (Optional[str]) – The folder containing images to use when training.
predict_folder (Optional[str]) – The folder containing images to use when predicting.
train_transform (TypeVar(INPUT_TRANSFORM_TYPE, Type[flash.core.data.io.input_transform.InputTransform], Callable, Tuple[Union[LightningEnum, str], Dict[str, Any]], Union[LightningEnum, str], None)) – The InputTransform type to use when training.
predict_transform (TypeVar(INPUT_TRANSFORM_TYPE, Type[flash.core.data.io.input_transform.InputTransform], Callable, Tuple[Union[LightningEnum, str], Dict[str, Any]], Union[LightningEnum, str], None)) – The InputTransform type to use when predicting.
input_cls (Type[Input]) – The Input type to use for loading the data.
transform_kwargs (Optional[Dict]) – Dict of keyword arguments to be provided when instantiating the transforms.
data_module_kwargs (Any) – Additional keyword arguments to provide to the DataModule constructor.
- Return type
StyleTransferData
- Returns
The constructed StyleTransferData.
Examples
>>> from flash import Trainer
>>> from flash.image import StyleTransfer, StyleTransferData
>>> datamodule = StyleTransferData.from_folders(
...     train_folder="train_folder",
...     predict_folder="predict_folder",
...     transform_kwargs=dict(image_size=(128, 128)),
...     batch_size=2,
... )
>>> model = StyleTransfer()
>>> trainer = Trainer(fast_dev_run=True)
>>> trainer.fit(model, datamodule=datamodule)
Training...
>>> trainer.predict(model, datamodule=datamodule)
Predicting...
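The flat folder layout above can be mimicked with a short script that builds train_folder in a temporary directory and lists the files a top-level scan would pick up. This is a minimal sketch under two assumptions: only files directly inside the folder are considered (subfolders are not searched), and the extension check is case-insensitive. `list_images` is a hypothetical helper, not Flash's own folder loader.

```python
import tempfile
from pathlib import Path
from typing import List

# Extensions listed as supported on this page.
SUPPORTED_EXTENSIONS = {
    ".jpg", ".jpeg", ".png", ".ppm", ".bmp",
    ".pgm", ".tif", ".tiff", ".webp", ".npy",
}


def list_images(folder: str) -> List[str]:
    """Hypothetical helper: sorted names of supported files directly inside `folder`.

    Assumption: only the top level is scanned, matching the flat layout shown above.
    """
    return sorted(
        p.name
        for p in Path(folder).iterdir()
        if p.is_file() and p.suffix.lower() in SUPPORTED_EXTENSIONS
    )


with tempfile.TemporaryDirectory() as root:
    train_folder = Path(root) / "train_folder"
    train_folder.mkdir()
    # Recreate the flat layout: three images plus one unsupported file.
    for name in ["image_1.png", "image_2.png", "image_3.png", "README.txt"]:
        (train_folder / name).touch()
    print(list_images(str(train_folder)))
    # ['image_1.png', 'image_2.png', 'image_3.png']
```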
- classmethod from_numpy(train_data=None, predict_data=None, train_transform=<class 'flash.image.style_transfer.input_transform.StyleTransferInputTransform'>, predict_transform=<class 'flash.image.style_transfer.input_transform.StyleTransferInputTransform'>, input_cls=<class 'flash.image.data.ImageNumpyInput'>, transform_kwargs=None, **data_module_kwargs)
Load the StyleTransferData from numpy arrays (or lists of arrays).
To learn how to customize the transforms applied for each stage, read our customizing transforms guide.
- Parameters
train_data (Optional[Collection[ndarray]]) – The numpy array or list of arrays to use when training.
predict_data (Optional[Collection[ndarray]]) – The numpy array or list of arrays to use when predicting.
train_transform (TypeVar(INPUT_TRANSFORM_TYPE, Type[flash.core.data.io.input_transform.InputTransform], Callable, Tuple[Union[LightningEnum, str], Dict[str, Any]], Union[LightningEnum, str], None)) – The InputTransform type to use when training.
predict_transform (TypeVar(INPUT_TRANSFORM_TYPE, Type[flash.core.data.io.input_transform.InputTransform], Callable, Tuple[Union[LightningEnum, str], Dict[str, Any]], Union[LightningEnum, str], None)) – The InputTransform type to use when predicting.
input_cls (Type[Input]) – The Input type to use for loading the data.
transform_kwargs (Optional[Dict]) – Dict of keyword arguments to be provided when instantiating the transforms.
data_module_kwargs (Any) – Additional keyword arguments to provide to the DataModule constructor.
- Return type
StyleTransferData
- Returns
The constructed StyleTransferData.
Examples
>>> import numpy as np
>>> from flash import Trainer
>>> from flash.image import StyleTransfer, StyleTransferData
>>> datamodule = StyleTransferData.from_numpy(
...     train_data=[np.random.rand(3, 64, 64), np.random.rand(3, 64, 64), np.random.rand(3, 64, 64)],
...     predict_data=[np.random.rand(3, 64, 64)],
...     transform_kwargs=dict(image_size=(128, 128)),
...     batch_size=2,
... )
>>> model = StyleTransfer()
>>> trainer = Trainer(fast_dev_run=True)
>>> trainer.fit(model, datamodule=datamodule)
Training...
>>> trainer.predict(model, datamodule=datamodule)
Predicting...
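The from_numpy example passes a list of channels-first arrays of shape (3, 64, 64), while the parameters also accept a single numpy array. A minimal sketch, assuming a single array is read as a stack of images with shape (N, C, H, W) (iterating over a NumPy array yields its slices along the first axis), and assuming images stored channels-last (H, W, C) should be transposed to match the example's layout. `to_channels_first` is a hypothetical helper, not part of Flash.

```python
import numpy as np

# A list of three channels-first (C, H, W) images, as in the example above.
images = [np.random.rand(3, 64, 64) for _ in range(3)]

# The same data as one stacked (N, C, H, W) array; iterating over it
# yields one (C, H, W) array per image.
stacked = np.stack(images)
print(stacked.shape)               # (3, 3, 64, 64)
print([a.shape for a in stacked])  # [(3, 64, 64), (3, 64, 64), (3, 64, 64)]


def to_channels_first(image: np.ndarray) -> np.ndarray:
    """Hypothetical helper: convert an (H, W, C) array to (C, H, W)."""
    return np.transpose(image, (2, 0, 1))


hwc = np.zeros((64, 48, 3))  # e.g. a 64x48 RGB image stored channels-last
print(to_channels_first(hwc).shape)  # (3, 64, 48)
```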
- classmethod from_tensors(train_data=None, predict_data=None, train_transform=<class 'flash.image.style_transfer.input_transform.StyleTransferInputTransform'>, predict_transform=<class 'flash.image.style_transfer.input_transform.StyleTransferInputTransform'>, input_cls=<class 'flash.image.data.ImageTensorInput'>, transform_kwargs=None, **data_module_kwargs)
Load the StyleTransferData from torch tensors (or lists of tensors).
To learn how to customize the transforms applied for each stage, read our customizing transforms guide.
- Parameters
train_data (Optional[Collection[Tensor]]) – The torch tensor or list of tensors to use when training.
predict_data (Optional[Collection[Tensor]]) – The torch tensor or list of tensors to use when predicting.
train_transform (TypeVar(INPUT_TRANSFORM_TYPE, Type[flash.core.data.io.input_transform.InputTransform], Callable, Tuple[Union[LightningEnum, str], Dict[str, Any]], Union[LightningEnum, str], None)) – The InputTransform type to use when training.
predict_transform (TypeVar(INPUT_TRANSFORM_TYPE, Type[flash.core.data.io.input_transform.InputTransform], Callable, Tuple[Union[LightningEnum, str], Dict[str, Any]], Union[LightningEnum, str], None)) – The InputTransform type to use when predicting.
input_cls (Type[Input]) – The Input type to use for loading the data.
transform_kwargs (Optional[Dict]) – Dict of keyword arguments to be provided when instantiating the transforms.
data_module_kwargs (Any) – Additional keyword arguments to provide to the DataModule constructor.
- Return type
StyleTransferData
- Returns
The constructed StyleTransferData.
Examples
>>> import torch
>>> from flash import Trainer
>>> from flash.image import StyleTransfer, StyleTransferData
>>> datamodule = StyleTransferData.from_tensors(
...     train_data=[torch.rand(3, 64, 64), torch.rand(3, 64, 64), torch.rand(3, 64, 64)],
...     predict_data=[torch.rand(3, 64, 64)],
...     transform_kwargs=dict(image_size=(128, 128)),
...     batch_size=2,
... )
>>> model = StyleTransfer()
>>> trainer = Trainer(fast_dev_run=True)
>>> trainer.fit(model, datamodule=datamodule)
Training...
>>> trainer.predict(model, datamodule=datamodule)
Predicting...
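The from_tensors example passes three training tensors with batch_size=2, so the samples do not divide evenly into batches. The sketch below uses a hypothetical `batch_sizes` helper to show how three samples split under the two settings of the drop_last flag of PyTorch's DataLoader. This page does not state which setting Flash applies for each stage, so treat both results as possibilities rather than a description of Flash's behavior.

```python
from typing import List


def batch_sizes(num_samples: int, batch_size: int, drop_last: bool = False) -> List[int]:
    """Hypothetical helper: sizes of the batches a sequential loader would produce."""
    full, remainder = divmod(num_samples, batch_size)
    sizes = [batch_size] * full
    if remainder and not drop_last:
        # The final, smaller batch is kept unless drop_last is set.
        sizes.append(remainder)
    return sizes


# Three training tensors with batch_size=2, as in the from_tensors example.
print(batch_sizes(3, 2))                  # [2, 1]
print(batch_sizes(3, 2, drop_last=True))  # [2]
```

With fast_dev_run=True the Trainer runs a single batch per stage, so the example finishes quickly in either case.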
- input_transform_cls
alias of flash.image.style_transfer.input_transform.StyleTransferInputTransform
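Each loader above documents transform_kwargs as a dict of keyword arguments provided when instantiating the transforms, and the default transform class is StyleTransferInputTransform. A minimal sketch of that keyword-forwarding pattern, using a hypothetical stand-in dataclass `ToyInputTransform` (not Flash's class) with an `image_size` field that mirrors `transform_kwargs=dict(image_size=(128, 128))` from the examples. The default of 256 and the helper `instantiate_transform` are illustrative assumptions.

```python
from dataclasses import dataclass
from typing import Any, Dict, Optional, Tuple, Type, Union


@dataclass
class ToyInputTransform:
    """Hypothetical stand-in for an InputTransform with one configurable field."""

    image_size: Union[int, Tuple[int, int]] = 256  # illustrative default only


def instantiate_transform(
    transform_cls: Type[ToyInputTransform],
    transform_kwargs: Optional[Dict[str, Any]] = None,
) -> ToyInputTransform:
    """Build the transform, forwarding transform_kwargs to its constructor."""
    return transform_cls(**(transform_kwargs or {}))


print(instantiate_transform(ToyInputTransform))
# ToyInputTransform(image_size=256)
print(instantiate_transform(ToyInputTransform, dict(image_size=(128, 128))))
# ToyInputTransform(image_size=(128, 128))
```

Passing None (the documented default for transform_kwargs) leaves the transform's own defaults in place.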