
StyleTransferData

class flash.image.style_transfer.data.StyleTransferData(train_input=None, val_input=None, test_input=None, predict_input=None, data_fetcher=None, transform=<class 'flash.core.data.io.input_transform.InputTransform'>, transform_kwargs=None, val_split=None, batch_size=None, num_workers=0, sampler=None, pin_memory=True, persistent_workers=False)

The StyleTransferData class is a DataModule with a set of classmethods for loading data for image style transfer.
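
The constructor also accepts val_split, which we take to be the fraction of the training data held out for validation. The sketch below illustrates that idea with a seeded shuffle followed by a fractional split. split_train_val is a hypothetical helper, and both the shuffling strategy and rounding the validation count down are our own assumptions, not a description of how StyleTransferData actually performs the split.

```python
import random


def split_train_val(items, val_split, seed=0):
    """Hypothetical helper: illustrate a fractional train/validation split.

    This is NOT Flash's implementation; it only shows what a val_split
    fraction means for a list of training samples.
    """
    if not 0.0 <= val_split < 1.0:
        raise ValueError("val_split must be in [0, 1)")
    shuffled = list(items)
    random.Random(seed).shuffle(shuffled)  # seeded so the split is reproducible
    n_val = int(len(shuffled) * val_split)  # rounding down is an assumption
    return shuffled[n_val:], shuffled[:n_val]  # (train, validation)


train, val = split_train_val([f"image_{i}.png" for i in range(10)], val_split=0.2)
print(len(train), len(val))  # 8 2
```

With val_split=0.2 and ten images, two end up in the validation portion and eight remain for training.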

classmethod from_files(train_files=None, predict_files=None, input_cls=<class 'flash.image.classification.input.ImageClassificationFilesInput'>, transform=<class 'flash.image.style_transfer.input_transform.StyleTransferInputTransform'>, transform_kwargs=None, **data_module_kwargs)

Load the StyleTransferData from lists of image files.

The supported file extensions are: .jpg, .jpeg, .png, .ppm, .bmp, .pgm, .tif, .tiff, .webp, and .npy. To learn how to customize the transforms applied for each stage, read our customizing transforms guide.
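
Because only the extensions listed above are supported, it can help to filter a file list before passing it as train_files or predict_files. This is a minimal standard-library sketch: filter_supported_files and SUPPORTED_EXTENSIONS are hypothetical names, and comparing extensions case-insensitively is our own choice rather than documented behaviour of the loader.

```python
from pathlib import Path

# Extensions copied from the list above; lower-casing the suffix before
# comparing is an assumption made by this sketch.
SUPPORTED_EXTENSIONS = {
    ".jpg", ".jpeg", ".png", ".ppm", ".bmp",
    ".pgm", ".tif", ".tiff", ".webp", ".npy",
}


def filter_supported_files(paths):
    """Hypothetical helper: keep only paths whose suffix is a supported extension."""
    return [str(p) for p in paths if Path(p).suffix.lower() in SUPPORTED_EXTENSIONS]


files = ["image_1.png", "notes.txt", "image_2.JPG", "scan.tiff", "archive.zip"]
print(filter_supported_files(files))  # ['image_1.png', 'image_2.JPG', 'scan.tiff']
```

The filtered list can then be passed as train_files (or predict_files) to StyleTransferData.from_files.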

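Parameters

train_files – The list of image files to use when training.
predict_files – The list of image files to use when predicting.
input_cls – The Input type to use for loading the data.
transform – The InputTransform type to use.
transform_kwargs – Dict of keyword arguments to be provided when instantiating the transforms.
data_module_kwargs – Additional keyword arguments to use when constructing the DataModule (for example, batch_size).

Return type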

StyleTransferData

Returns

The constructed StyleTransferData.

Examples

>>> from flash import Trainer
>>> from flash.image import StyleTransfer, StyleTransferData
>>> datamodule = StyleTransferData.from_files(
...     train_files=["image_1.png", "image_2.png", "image_3.png"],
...     predict_files=["predict_image_1.png", "predict_image_2.png", "predict_image_3.png"],
...     transform_kwargs=dict(image_size=(128, 128)),
...     batch_size=2,
... )
>>> model = StyleTransfer()
>>> trainer = Trainer(fast_dev_run=True)
>>> trainer.fit(model, datamodule=datamodule)  
Training...
>>> trainer.predict(model, datamodule=datamodule)  
Predicting...

classmethod from_folders(train_folder=None, predict_folder=None, input_cls=<class 'flash.image.classification.input.ImageClassificationFolderInput'>, transform=<class 'flash.image.style_transfer.input_transform.StyleTransferInputTransform'>, transform_kwargs=None, **data_module_kwargs)

Load the StyleTransferData from folders containing images.

The supported file extensions are: .jpg, .jpeg, .png, .ppm, .bmp, .pgm, .tif, .tiff, .webp, and .npy. Here’s the required folder structure:

train_folder
├── image_1.png
├── image_2.png
├── image_3.png
...

To learn how to customize the transforms applied for each stage, read our customizing transforms guide.
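
If your images are spread across several directories, one way to produce the flat layout shown above is to copy them into a single folder. This is a minimal standard-library sketch: build_flat_folder is a hypothetical helper, and the demo uses empty placeholder files in place of real images. Whether from_folders also accepts nested subfolders is not covered here, so the sketch sticks to the flat structure documented above.

```python
import shutil
import tempfile
from pathlib import Path


def build_flat_folder(image_paths, folder):
    """Hypothetical helper: copy images into one flat folder (no subdirectories).

    Files that share a name overwrite each other, so make names unique first.
    Returns the copied paths in sorted order.
    """
    folder = Path(folder)
    folder.mkdir(parents=True, exist_ok=True)
    copied = []
    for path in image_paths:
        target = folder / Path(path).name  # keep only the file name
        shutil.copy(path, target)
        copied.append(target)
    return sorted(copied)


# Demo with empty placeholder files standing in for real images.
with tempfile.TemporaryDirectory() as tmp:
    sources = []
    for i in range(1, 4):
        src = Path(tmp) / "src" / f"image_{i}.png"
        src.parent.mkdir(parents=True, exist_ok=True)
        src.touch()
        sources.append(src)
    result = build_flat_folder(sources, Path(tmp) / "train_folder")
    print([p.name for p in result])  # ['image_1.png', 'image_2.png', 'image_3.png']
```

The resulting folder path is what you would pass as train_folder (or predict_folder).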

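Parameters

train_folder – The folder containing images to use when training.
predict_folder – The folder containing images to use when predicting.
input_cls – The Input type to use for loading the data.
transform – The InputTransform type to use.
transform_kwargs – Dict of keyword arguments to be provided when instantiating the transforms.
data_module_kwargs – Additional keyword arguments to use when constructing the DataModule (for example, batch_size).

Return type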

StyleTransferData

Returns

The constructed StyleTransferData.

Examples

>>> from flash import Trainer
>>> from flash.image import StyleTransfer, StyleTransferData
>>> datamodule = StyleTransferData.from_folders(
...     train_folder="train_folder",
...     predict_folder="predict_folder",
...     transform_kwargs=dict(image_size=(128, 128)),
...     batch_size=2,
... )
>>> model = StyleTransfer()
>>> trainer = Trainer(fast_dev_run=True)
>>> trainer.fit(model, datamodule=datamodule)  
Training...
>>> trainer.predict(model, datamodule=datamodule)  
Predicting...

classmethod from_numpy(train_data=None, predict_data=None, input_cls=<class 'flash.image.data.ImageNumpyInput'>, transform=<class 'flash.image.style_transfer.input_transform.StyleTransferInputTransform'>, transform_kwargs=None, **data_module_kwargs)

Load the StyleTransferData from numpy arrays (or lists of arrays).

To learn how to customize the transforms applied for each stage, read our customizing transforms guide.
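
The example for this method passes arrays shaped (3, 64, 64), that is channels-first (channels, height, width). If your arrays are channels-last (height, width, channels), you would presumably need to reorder the axes first; with NumPy that is array.transpose(2, 0, 1). The dependency-free sketch below performs the same reordering on nested Python lists to make the index mapping explicit. hwc_to_chw is a hypothetical helper, and the assumption that the input expects a channels-first layout is inferred from the example rather than stated in this documentation.

```python
def hwc_to_chw(image):
    """Hypothetical helper: reorder a nested-list image from
    [height][width][channel] to [channel][height][width].

    Equivalent to numpy's array.transpose(2, 0, 1) for a 3-D array.
    """
    height = len(image)
    width = len(image[0])
    channels = len(image[0][0])
    return [
        [[image[y][x][c] for x in range(width)] for y in range(height)]
        for c in range(channels)
    ]


# A 2x2 RGB image stored channels-last.
hwc = [
    [[1, 2, 3], [4, 5, 6]],
    [[7, 8, 9], [10, 11, 12]],
]
print(hwc_to_chw(hwc))  # [[[1, 4], [7, 10]], [[2, 5], [8, 11]], [[3, 6], [9, 12]]]
```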

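Parameters

train_data – The numpy array or list of arrays to use when training.
predict_data – The numpy array or list of arrays to use when predicting.
input_cls – The Input type to use for loading the data.
transform – The InputTransform type to use.
transform_kwargs – Dict of keyword arguments to be provided when instantiating the transforms.
data_module_kwargs – Additional keyword arguments to use when constructing the DataModule (for example, batch_size).

Return type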

StyleTransferData

Returns

The constructed StyleTransferData.

Examples

>>> import numpy as np
>>> from flash import Trainer
>>> from flash.image import StyleTransfer, StyleTransferData
>>> datamodule = StyleTransferData.from_numpy(
...     train_data=[np.random.rand(3, 64, 64), np.random.rand(3, 64, 64), np.random.rand(3, 64, 64)],
...     predict_data=[np.random.rand(3, 64, 64)],
...     transform_kwargs=dict(image_size=(128, 128)),
...     batch_size=2,
... )
>>> model = StyleTransfer()
>>> trainer = Trainer(fast_dev_run=True)
>>> trainer.fit(model, datamodule=datamodule)  
Training...
>>> trainer.predict(model, datamodule=datamodule)  
Predicting...

classmethod from_tensors(train_data=None, predict_data=None, input_cls=<class 'flash.image.data.ImageTensorInput'>, transform=<class 'flash.image.style_transfer.input_transform.StyleTransferInputTransform'>, transform_kwargs=None, **data_module_kwargs)

Load the StyleTransferData from torch tensors (or lists of tensors).

To learn how to customize the transforms applied for each stage, read our customizing transforms guide.
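
The example for this method passes one three-channel, channels-first tensor of shape (3, 64, 64) per image. A quick pre-flight check can catch tensors in a different layout (for example channels-last, or with an extra batch dimension) before they reach the DataModule. check_chw_shapes is a hypothetical helper that works on plain shape tuples; since torch.Size is a tuple subclass, [t.shape for t in tensors] can be passed to it directly. Treating every other layout as invalid is an assumption based on the example, not a documented requirement.

```python
def check_chw_shapes(shapes, channels=3):
    """Hypothetical helper: check that every shape is (channels, height, width).

    Returns a list of (index, shape) pairs that do not match; an empty list
    means all shapes look like the ones used in the example.
    """
    bad = []
    for index, shape in enumerate(shapes):
        shape = tuple(shape)  # accepts tuples, lists, or torch.Size
        if len(shape) != 3 or shape[0] != channels:
            bad.append((index, shape))
    return bad


# With real tensors you would pass [t.shape for t in tensors].
shapes = [(3, 64, 64), (3, 64, 64), (64, 64, 3), (1, 3, 64, 64)]
print(check_chw_shapes(shapes))  # [(2, (64, 64, 3)), (3, (1, 3, 64, 64))]
```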

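Parameters

train_data – The torch tensor or list of tensors to use when training.
predict_data – The torch tensor or list of tensors to use when predicting.
input_cls – The Input type to use for loading the data.
transform – The InputTransform type to use.
transform_kwargs – Dict of keyword arguments to be provided when instantiating the transforms.
data_module_kwargs – Additional keyword arguments to use when constructing the DataModule (for example, batch_size).

Return type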

StyleTransferData

Returns

The constructed StyleTransferData.

Examples

>>> import torch
>>> from flash import Trainer
>>> from flash.image import StyleTransfer, StyleTransferData
>>> datamodule = StyleTransferData.from_tensors(
...     train_data=[torch.rand(3, 64, 64), torch.rand(3, 64, 64), torch.rand(3, 64, 64)],
...     predict_data=[torch.rand(3, 64, 64)],
...     transform_kwargs=dict(image_size=(128, 128)),
...     batch_size=2,
... )
>>> model = StyleTransfer()
>>> trainer = Trainer(fast_dev_run=True)
>>> trainer.fit(model, datamodule=datamodule)  
Training...
>>> trainer.predict(model, datamodule=datamodule)  
Predicting...

input_transform_cls

alias of flash.image.style_transfer.input_transform.StyleTransferInputTransform
