Preprocess

class flash.core.data.process.Preprocess(train_transform=None, val_transform=None, test_transform=None, predict_transform=None, data_sources=None, deserializer=None, default_data_source=None)[source]

The Preprocess encapsulates all the data processing logic that should run before the data is passed to the model. It is particularly useful when you want to provide an end-to-end implementation that works across four stages: train, validation, test, and inference (predict).

The Preprocess supports the following hooks:

  • pre_tensor_transform: Performs transforms on a single data sample.

    Example:

    * Input: Receive a PIL Image and its label.
    
    * Action: Rotate the PIL Image.
    
    * Output: Return the rotated PIL image and its label.
    
  • to_tensor_transform: Converts a single data sample to a tensor / data structure containing tensors.

    Example:

    * Input: Receive the rotated PIL Image and its label.
    
    * Action: Convert the rotated PIL Image to a tensor.
    
    * Output: Return the tensored image and its label.
    
  • post_tensor_transform: Performs transform on a single tensor sample.

    Example:

    * Input: Receive the tensored image and its label.
    
    * Action: Flip the tensored image randomly.
    
    * Output: Return the tensored image and its label.
    
  • per_batch_transform: Performs transforms on a batch.

    In this example, we decided not to override the hook.

  • per_sample_transform_on_device: Performs transform on a sample already on a GPU or TPU.

    Example:

    * Input: Receive a tensored image on device and its label.
    
    * Action: Apply random transforms.
    
    * Output: Return an augmented tensored image on device and its label.
    
  • collate: Converts a sequence of data samples into a batch.

    Defaults to torch.utils.data._utils.collate.default_collate. Example:

    * Input: Receive a list of augmented tensored images and their respective labels.
    
    * Action: Collate the list of images into batch.
    
    * Output: Return a batch of images and their labels.
    
  • per_batch_transform_on_device: Performs transform on a batch already on GPU or TPU.

    Example:

    * Input: Receive a batch of images and their labels.
    
    * Action: Apply normalization to the batch by subtracting the ImageNet mean and dividing by the ImageNet standard deviation.
    
    * Output: Return a normalized augmented batch of images and their labels.
    

Note

The per_sample_transform_on_device and per_batch_transform hooks are mutually exclusive: if both were specified, the batch would have to be uncollated and re-collated, which hurts performance.
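The hooks above run in a fixed order: per-sample transforms first, then collation, then batch transforms. A minimal, framework-free sketch of that order (the function names mirror the Preprocess hooks, but the bodies are toy plain-Python stand-ins, not flash's implementations):

```python
def pre_tensor_transform(sample):
    # e.g. rotate a PIL Image; here we just tag the sample
    return ("rotated", sample)

def to_tensor_transform(sample):
    # e.g. convert to a tensor; here we wrap in a list as a stand-in
    return [sample]

def post_tensor_transform(sample):
    # e.g. random flip on the tensor; identity here
    return sample

def collate(samples):
    # e.g. default_collate; a plain list is our stand-in for a batch
    return list(samples)

def per_batch_transform_on_device(batch):
    # e.g. normalize the batch on GPU/TPU; identity here
    return batch

def run_pipeline(raw_samples):
    # Per-sample hooks run first, then collate, then batch hooks.
    per_sample = [
        post_tensor_transform(to_tensor_transform(pre_tensor_transform(s)))
        for s in raw_samples
    ]
    return per_batch_transform_on_device(collate(per_sample))

batch = run_pipeline(["img0", "img1"])
```

The real hooks operate on PIL Images and tensors; this sketch only illustrates the call order.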

Data processing can be configured by overriding hooks or through transforms. The preprocess transforms are given as a mapping from hook names to callables. Default transforms can be configured by overriding the default_transforms or {train,val,test,predict}_default_transforms methods. These can then be overridden by the user with the {train,val,test,predict}_transform arguments to the Preprocess. All of the hooks can be used in the transform mappings.

Example:

class CustomPreprocess(Preprocess):

    @staticmethod
    def default_transforms() -> Mapping[str, Callable]:
        return {
            "to_tensor_transform": transforms.ToTensor(),
            "collate": torch.utils.data._utils.collate.default_collate,
        }

    @staticmethod
    def train_default_transforms() -> Mapping[str, Callable]:
        return {
            "pre_tensor_transform": transforms.RandomHorizontalFlip(),
            "to_tensor_transform": transforms.ToTensor(),
            "collate": torch.utils.data._utils.collate.default_collate,
        }

When overriding hooks for particular stages, you can prefix with train, val, test or predict. For example, you can achieve the same as the above example by implementing train_pre_tensor_transform and train_to_tensor_transform.

Example:

class CustomPreprocess(Preprocess):

    def train_pre_tensor_transform(self, sample: PIL.Image) -> PIL.Image:
        return transforms.RandomHorizontalFlip()(sample)

    def to_tensor_transform(self, sample: PIL.Image) -> torch.Tensor:
        return transforms.ToTensor()(sample)

    def collate(self, samples: List[torch.Tensor]) -> torch.Tensor:
        return torch.utils.data._utils.collate.default_collate(samples)

Each hook is aware of the Trainer running stage through booleans. These are useful for adapting functionality for a stage without duplicating code.

Example:

class CustomPreprocess(Preprocess):

    def pre_tensor_transform(self, sample: PIL.Image) -> PIL.Image:

        if self.training:
            ...  # logic for training

        elif self.validating:
            ...  # logic for validation

        elif self.testing:
            ...  # logic for testing

        elif self.predicting:
            ...  # logic for predicting

        return sample

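A runnable stand-in for the stage booleans (training, validating, testing, predicting). StagePreprocess below is a toy class, not flash's Preprocess; it only illustrates the branching pattern above:

```python
class StagePreprocess:
    """Toy class mimicking the Preprocess stage flags."""

    def __init__(self, stage):
        self.training = stage == "train"
        self.validating = stage == "val"
        self.testing = stage == "test"
        self.predicting = stage == "predict"

    def pre_tensor_transform(self, sample):
        if self.training:
            # heavy augmentation only during training
            return ("augmented", sample)
        elif self.validating or self.testing or self.predicting:
            # deterministic path for evaluation and inference
            return ("plain", sample)
        return sample

train_out = StagePreprocess("train").pre_tensor_transform("img")
val_out = StagePreprocess("val").pre_tensor_transform("img")
```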
available_data_sources()[source]

Get the list of available data source names for use with this Preprocess.

Return type

Sequence[str]

Returns

The list of data source names.

collate(samples, metadata=None)[source]

Transform to convert a sequence of samples to a collated batch.

Return type

Any
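A hedged sketch of what collate does, using plain Python lists in place of tensors: a sequence of (image, label) samples becomes a single (images, labels) batch, much as default_collate stacks each field across samples:

```python
def collate_pairs(samples):
    # Transpose a list of (image, label) pairs into
    # (list of images, list of labels).
    images, labels = zip(*samples)
    return list(images), list(labels)

images, labels = collate_pairs([("img0", 0), ("img1", 1), ("img2", 0)])
```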

data_source_of_name(data_source_name)[source]

Get the DataSource of the given name from the Preprocess.

Parameters

data_source_name (str) – The name of the data source to look up.

Return type

DataSource

Returns

The DataSource of the given name.

Raises

MisconfigurationException – If the requested data source is not configured by this Preprocess.
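A hypothetical sketch of the lookup this method performs: a name-to-source mapping with an error on unknown names. The real Preprocess raises MisconfigurationException; a plain ValueError stands in here so the example has no dependencies:

```python
def data_source_of_name(data_sources, name):
    # Look up a configured data source by name, failing loudly
    # when the name is not registered.
    if name not in data_sources:
        raise ValueError(
            f"No '{name}' data source is configured. "
            f"Available: {sorted(data_sources)}"
        )
    return data_sources[name]

sources = {"files": object(), "folders": object()}
found = data_source_of_name(sources, "files")

try:
    data_source_of_name(sources, "numpy")
    raised = False
except ValueError:
    raised = True
```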

static default_transforms()[source]

The default transforms to use.

Will be overridden by transforms passed to the __init__.

Return type

Optional[Dict[str, Callable]]
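A hedged sketch of the "overridden by transforms passed to the __init__" behavior. Here user-supplied transforms replace the defaults wholesale when provided; that replacement rule is an assumption for illustration, and the real resolution lives inside flash's Preprocess:

```python
def resolve_transforms(defaults, user_transforms):
    # User-provided transforms (from __init__) take precedence over
    # the class defaults; None means "keep the defaults".
    return user_transforms if user_transforms is not None else defaults

defaults = {"to_tensor_transform": "ToTensor", "collate": "default_collate"}
resolved = resolve_transforms(defaults, {"collate": "custom_collate"})
unchanged = resolve_transforms(defaults, None)
```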

per_batch_transform(batch)[source]

Transforms to apply to a whole batch (if possible use this for efficiency).

Note

This option is mutually exclusive with per_sample_transform_on_device(), since if both are specified, uncollation has to be applied.

Return type

Any

per_batch_transform_on_device(batch)[source]

Transforms to apply to a whole batch (if possible use this for efficiency).

Note

This function won’t be called within the dataloader workers, since that would require each worker to create its own CUDA context, which would pollute GPU memory (if running on GPU).

Return type

Any

per_sample_transform_on_device(sample)[source]

Transforms to apply to the data before the collation (per-sample basis).

Note

This option is mutually exclusive with per_batch_transform(), since if both are specified, uncollation has to be applied.

Note

This function won’t be called within the dataloader workers, since that would require each worker to create its own CUDA context, which would pollute GPU memory (if running on GPU).

Return type

Any

post_tensor_transform(sample)[source]

Transforms to apply on a tensor.

Return type

Tensor

pre_tensor_transform(sample)[source]

Transforms to apply on a single object.

Return type

Any

to_tensor_transform(sample)[source]

Transforms to convert single object to a tensor.

Return type

Tensor

property transforms: Dict[str, Optional[Dict[str, Callable]]]

The transforms currently being used by this Preprocess.

Return type

Dict[str, Optional[Dict[str, Callable]]]
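To illustrate the declared return type, a sketch of the nested shape this property exposes: stage-transform keys mapping to an optional hook-name-to-callable dictionary. The specific keys shown are illustrative assumptions based on the constructor arguments:

```python
# Outer keys name per-stage transform sets; each value is either None
# (no transforms configured) or a mapping from hook name to callable.
current_transforms = {
    "train_transform": {"pre_tensor_transform": lambda sample: sample},
    "val_transform": None,
}
```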
