Data

Terminology

Here are common terms you need to be familiar with:

Deserializer

The Deserializer provides a single deserialize() method.

DataModule

The DataModule contains the datasets, transforms and dataloaders.

DataPipeline

The DataPipeline is Flash's internal object for managing the Deserializer, Input, InputTransform, OutputTransform, and Output objects.

Input

The Input provides load_data() and load_sample() hooks for creating data sets from metadata (such as folder names).

InputTransform

The InputTransform provides a simple hook-based API to encapsulate your pre-processing logic.

These hooks (such as per_sample_transform()) enable transformations to be applied to your data at every point along the pipeline (including on the device). The DataPipeline contains a system to call the right hooks when needed. The InputTransform hooks can be either overridden directly or provided as a dictionary of transforms (mapping hook name to callable transform).

OutputTransform

The OutputTransform provides a simple hook-based API to encapsulate your post-processing logic.

The OutputTransform hooks cover everything from model outputs to the export of predictions.

Output

The Output provides a single transform() method that is used to convert model outputs (after the OutputTransform) to the desired output format during prediction.

How to use out-of-the-box Flash DataModules

Flash provides several DataModules with helper functions. Check out the Image Classification section (or the sections for any of our other tasks) to learn more.

Data Processing

Currently, it is common practice to implement a torch.utils.data.Dataset and provide it to a torch.utils.data.DataLoader. However, after model training, running inference on raw data and deploying the model in a production environment requires a lot of engineering overhead. Usually, extra processing logic must be added to bridge the gap between training data and raw data.

The Input class can be used to generate data sets from multiple sources (e.g. folders, numpy, etc.), that can then all be transformed in the same way.

The InputTransform and OutputTransform classes can be used to manage the input and output transforms. The Output class provides the logic for converting OutputTransform outputs to the desired predict format (e.g. classes, labels, probabilities, etc.).

By providing a series of hooks that can be overridden with custom data processing logic (or just targeted with transforms), Flash gives the user much more granular control over their data processing flow.

Here are the primary advantages:

  • Making inference on raw data simple

  • Making the code more readable, modular, and self-contained

  • Making data augmentation experimentation simpler

To change the processing behavior only on specific stages for a given hook, you can prefix the name of any InputTransform or OutputTransform hook with train, val, test, or predict.
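The stage-prefix lookup can be sketched in plain Python. This is an illustrative sketch of the idea, not Flash's internal implementation; `SketchTransform` and `resolve_hook` are hypothetical names:

```python
# Illustrative sketch of stage-prefixed hook resolution (not Flash internals):
# a stage-specific hook such as `train_per_sample_transform` takes priority
# over the generic `per_sample_transform` when it is defined.

class SketchTransform:
    def per_sample_transform(self, sample):
        return sample * 2  # generic behaviour, used for all stages

    def train_per_sample_transform(self, sample):
        return sample * 10  # training-only override


def resolve_hook(transform, stage, hook_name):
    """Return the stage-prefixed hook if present, else the generic one."""
    return getattr(transform, f"{stage}_{hook_name}", None) or getattr(transform, hook_name)


t = SketchTransform()
assert resolve_hook(t, "train", "per_sample_transform")(3) == 30  # train override
assert resolve_hook(t, "val", "per_sample_transform")(3) == 6     # generic fallback
```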

Check out InputTransform for some examples.

How to customize existing DataModules

Any Flash DataModule can be created directly from datasets using the from_datasets() classmethod like this:

from flash import DataModule, Trainer

data_module = DataModule.from_datasets(train_dataset=MyDataset())
trainer = Trainer()
trainer.fit(model, datamodule=data_module)

The DataModule provides additional classmethod helpers (from_*) for loading data from various sources. In each from_* method, the DataModule internally retrieves the correct Input to use from the InputTransform. Flash AutoDataset instances are created from the Input for train, val, test, and predict. The DataModule populates the DataLoader for each stage with the corresponding AutoDataset.

Customize preprocessing of DataModules

The InputTransform contains the processing logic related to a given task. Each InputTransform provides some default transforms through the default_transforms() method. Users can easily override these by providing their own transforms to the DataModule. Here’s an example:

from flash.core.data.transforms import ApplyToKeys
from flash.image import ImageClassificationData, ImageClassifier

transform = {"per_sample_transform": ApplyToKeys("input", my_per_sample_transform)}

datamodule = ImageClassificationData.from_folders(
    train_folder="data/hymenoptera_data/train/",
    val_folder="data/hymenoptera_data/val/",
    test_folder="data/hymenoptera_data/test/",
    train_transform=transform,
    val_transform=transform,
    test_transform=transform,
)

Alternatively, the user may directly override the hooks for their needs like this:

from typing import Any, Dict
from flash.image import ImageClassificationData, ImageClassifier, ImageClassificationInputTransform


class CustomImageClassificationInputTransform(ImageClassificationInputTransform):
    def per_sample_transform(self, sample: Dict[str, Any]) -> Dict[str, Any]:
        sample["input"] = my_per_sample_transform(sample["input"])
        return sample


datamodule = ImageClassificationData.from_folders(
    train_folder="data/hymenoptera_data/train/",
    val_folder="data/hymenoptera_data/val/",
    test_folder="data/hymenoptera_data/test/",
    input_transform=CustomImageClassificationInputTransform(),
)

Create your own InputTransform and DataModule

The example below shows a very simple ImageClassificationInputTransform with a single ImageClassificationFoldersInput and an ImageClassificationDataModule.

1. User-Facing API design

Designing an easy-to-use API is key. This is the first and most important step. We want the ImageClassificationDataModule to generate a dataset from folders of images arranged in this way.

Example:

train/dog/xxx.png
train/dog/xxy.png
train/dog/xxz.png
train/cat/123.png
train/cat/nsdf3.png
train/cat/asd932.png

Example:

dm = ImageClassificationDataModule.from_folders(
    train_folder="./data/train",
    val_folder="./data/val",
    test_folder="./data/test",
    predict_folder="./data/predict",
)

model = ImageClassifier(...)
trainer = Trainer(...)

trainer.fit(model, dm)

2. The Input

We start by implementing the ImageClassificationFoldersInput. The load_data method will produce a list of files and targets from the given directory. The load_sample method will load the given file as a PIL.Image. Here’s the full ImageClassificationFoldersInput:

import os
from typing import Any, Dict, Iterable

import numpy as np
from PIL import Image
from torchvision.datasets.folder import make_dataset
from flash.core.data.io.input import Input, DataKeys


class ImageClassificationFoldersInput(Input):
    def load_data(self, folder: str, dataset: Any) -> Iterable:
        # The dataset is optional but can be useful to save some metadata.

        # `metadata` contains the image path and its corresponding label
        # with the following structure:
        # [(image_path_1, label_1), ... (image_path_n, label_n)].
        metadata = make_dataset(folder)

        # for the train `AutoDataset`, we want to store the `num_classes`.
        if self.training:
            dataset.num_classes = len(np.unique([m[1] for m in metadata]))

        return [
            {
                DataKeys.INPUT: file,
                DataKeys.TARGET: target,
            }
            for file, target in metadata
        ]

    def predict_load_data(self, predict_folder: str) -> Iterable:
        # This returns [image_path_1, ... image_path_m].
        return [{DataKeys.INPUT: os.path.join(predict_folder, file)} for file in os.listdir(predict_folder)]

    def load_sample(self, sample: Dict[str, Any]) -> Dict[str, Any]:
        sample[DataKeys.INPUT] = Image.open(sample[DataKeys.INPUT])
        return sample

Note

We return samples as dictionaries using the DataKeys by convention. This is the recommended (although not required) way to represent data in Flash.
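The dictionary convention can be sketched in plain Python. Here a simple Enum stands in for flash.core.data.io.input.DataKeys (an assumption for illustration; the real class lives in Flash):

```python
from enum import Enum


# Stand-in for flash.core.data.io.input.DataKeys (illustrative only).
class DataKeys(Enum):
    INPUT = "input"
    TARGET = "target"


# Samples flow through the pipeline as dictionaries keyed by DataKeys,
# so each hook can pick out just the part it needs.
sample = {DataKeys.INPUT: "path/to/image.png", DataKeys.TARGET: 3}


def per_sample_transform(sample):
    # Transform only the input, leaving the target untouched.
    sample[DataKeys.INPUT] = sample[DataKeys.INPUT].upper()
    return sample


transformed = per_sample_transform(dict(sample))
assert transformed[DataKeys.INPUT] == "PATH/TO/IMAGE.PNG"
assert transformed[DataKeys.TARGET] == 3
```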

3. The InputTransform

Next, implement your custom ImageClassificationInputTransform with some default transforms and a reference to the data source:

from typing import Any, Callable, Dict, Optional

import torchvision.transforms.functional as T
from flash.core.data.io.input import DataKeys, InputFormat
from flash.core.data.io.input_transform import InputTransform
from flash.core.data.transforms import ApplyToKeys

# Subclass `InputTransform`
class ImageClassificationInputTransform(InputTransform):
    def __init__(
        self,
        train_transform: Optional[Dict[str, Callable]] = None,
        val_transform: Optional[Dict[str, Callable]] = None,
        test_transform: Optional[Dict[str, Callable]] = None,
        predict_transform: Optional[Dict[str, Callable]] = None,
    ):
        super().__init__(
            train_transform=train_transform,
            val_transform=val_transform,
            test_transform=test_transform,
            predict_transform=predict_transform,
            inputs={
                InputFormat.FOLDERS: ImageClassificationFoldersInput(),
            },
            default_input=InputFormat.FOLDERS,
        )

    def get_state_dict(self) -> Dict[str, Any]:
        return {**self.transforms}

    @classmethod
    def load_state_dict(cls, state_dict: Dict[str, Any], strict: bool = False):
        return cls(**state_dict)

    def default_transforms(self) -> Dict[str, Callable]:
        return {"per_sample_transform": ApplyToKeys(DataKeys.INPUT, T.to_tensor)}

4. The DataModule

Finally, let’s implement the ImageClassificationDataModule. We get the from_folders classmethod for free, as we’ve registered an InputFormat.FOLDERS data source in our ImageClassificationInputTransform. All we need to do is attach our InputTransform class like this:

from flash import DataModule


class ImageClassificationDataModule(DataModule):

    # Set `input_transform_cls` with your custom `InputTransform`.
    input_transform_cls = ImageClassificationInputTransform

How it works behind the scenes

Input

Note

The load_data() and load_sample() hooks are used to generate an AutoDataset object.

Here is the AutoDataset pseudo-code.

class AutoDataset:
    def __init__(
        self,
        data: List[Any],  # output of `Input.load_data`
        input: Input,
        running_stage: RunningStage,
    ):

        self.data = data
        self.input = input
        self.running_stage = running_stage

    def __getitem__(self, index: int):
        return self.input.load_sample(self.data[index])

    def __len__(self):
        return len(self.data)

InputTransform

Note

The per_sample_transform(), collate(), and per_batch_transform() hooks are injected as the collate_fn of the torch.utils.data.DataLoader.

Here is pseudo-code using the InputTransform hook names. Flash takes care of calling the right hooks for each stage.

Example:

# This will be wrapped into a :class:`~flash.core.data.io.input_transform._InputTransformProcessor`.
def collate_fn(samples: Sequence[Any]) -> Any:

    # This will be wrapped into a :class:`~flash.core.data.io.input_transform._InputTransformSequential`
    for sample in samples:
        sample = per_sample_transform(sample)

    samples = type(samples)(samples)

    # if :func:`flash.core.data.io.input_transform.InputTransform.per_sample_transform_on_device` hook is overridden,
    # those functions below will be no-ops

    samples = collate(samples)
    samples = per_batch_transform(samples)
    return samples

dataloader = DataLoader(dataset, collate_fn=collate_fn)
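As a concrete illustration of that hook order, here is a dependency-free toy version using plain Python values instead of tensors (the transform bodies are arbitrary stand-ins):

```python
# Toy, dependency-free illustration of the collate_fn hook order above.

def per_sample_transform(sample):
    return sample * 2  # applied to each sample individually

def collate(samples):
    return list(samples)  # stand-in for stacking samples into a batch

def per_batch_transform(batch):
    return [x + 1 for x in batch]  # applied to the whole batch at once

def collate_fn(samples):
    # per-sample hooks first, then collation, then per-batch hooks.
    samples = [per_sample_transform(s) for s in samples]
    batch = collate(samples)
    return per_batch_transform(batch)

assert collate_fn([1, 2, 3]) == [3, 5, 7]
```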

Note

The per_sample_transform_on_device(), collate(), and per_batch_transform_on_device() hooks are injected after the LightningModule transfer_batch_to_device hook.

Here is pseudo-code using the InputTransform hook names. Flash takes care of calling the right hooks for each stage.

Example:

# This will be wrapped into a :class:`~flash.core.data.io.input_transform._InputTransformProcessor`
def collate_fn(samples: Sequence[Any]) -> Any:

    # if ``per_batch_transform`` hook is overridden, those functions below will be no-ops
    samples = [per_sample_transform_on_device(sample) for sample in samples]
    samples = type(samples)(samples)
    samples = collate(samples)

    samples = per_batch_transform_on_device(samples)
    return samples

# move the data to device
data = lightning_module.transfer_batch_to_device(data)
data = collate_fn(data)
predictions = lightning_module(data)

OutputTransform and Output

Once the predictions have been generated by the Flash Task, the Flash DataPipeline will execute the OutputTransform hooks and the Output behind the scenes.

First, the per_batch_transform() hooks will be applied on the batch predictions. Then, the uncollate() will split the batch into individual predictions. Next, the per_sample_transform() will be applied on each prediction. Finally, the transform() method will be called to serialize the predictions.

Note

The transform can be applied either on device or CPU.

Here is the pseudo-code:

Example:

# This will be wrapped into a :class:`~flash.core.data.batch._OutputTransformProcessor`
def uncollate_fn(batch: Any) -> Any:

    batch = per_batch_transform(batch)

    samples = uncollate(batch)

    samples = [per_sample_transform(sample) for sample in samples]

    return [output.transform(sample) for sample in samples]

predictions = lightning_module(data)
return uncollate_fn(predictions)
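The same flow can be run end-to-end with a toy Output that maps class indices to labels. Everything here is an illustrative stand-in (the label set, the argmax-style per_sample_transform, and `ToyOutput` are assumptions, not Flash code):

```python
# Toy, dependency-free illustration of the OutputTransform/Output flow above.

LABELS = ["cat", "dog"]  # illustrative label set


class ToyOutput:
    def transform(self, sample):
        # Convert a class index into a human-readable label.
        return LABELS[sample]


def per_batch_transform(batch):
    return batch  # identity here; could e.g. normalize the scores


def uncollate(batch):
    return list(batch)  # split the batch into individual predictions


def per_sample_transform(sample):
    # Stand-in for argmax over per-class scores.
    return max(range(len(sample)), key=sample.__getitem__)


output = ToyOutput()


def uncollate_fn(batch):
    batch = per_batch_transform(batch)
    samples = uncollate(batch)
    samples = [per_sample_transform(s) for s in samples]
    return [output.transform(s) for s in samples]


assert uncollate_fn([[0.9, 0.1], [0.2, 0.8]]) == ["cat", "dog"]
```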