The Data

The first step to contributing a task is to implement the classes we need to load some data. You should implement:

  1. some DataSource classes (optional)

  2. a Preprocess

  3. a DataModule

  4. a BaseVisualization (optional)

  5. a Postprocess (optional)


The DataSource

The DataSource class contains the logic for data loading from different sources such as folders, files, tensors, etc. Every Flash DataModule can be instantiated with from_datasets(). For each additional way you want the user to be able to instantiate your DataModule, you'll need to create a DataSource. Each DataSource has 2 methods:

  • load_data() takes some dataset metadata (e.g. a folder name) as input and produces a sequence or iterable of samples or sample metadata.

  • load_sample() then takes as input a single element from the output of load_data and returns a sample.

By default these methods just return their input, so you don’t need both a load_data() and a load_sample() to create a DataSource. Where possible, you should override one of our existing DataSource classes.
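To make the contract concrete, here is a minimal, self-contained sketch of the two-method pattern in plain Python (the class and the filename-based labels are hypothetical illustrations, not part of Flash):

```python
from typing import Any, List, Mapping


class FolderDataSource:
    """Illustrative only: mirrors the load_data / load_sample contract.

    ``load_data`` maps dataset metadata (here, a list of file paths) to a
    sequence of sample metadata; ``load_sample`` turns one element of that
    sequence into an actual sample.
    """

    def load_data(self, data: List[str]) -> List[Mapping[str, Any]]:
        # Produce lightweight sample metadata, one entry per file,
        # deriving a (hypothetical) label from the filename prefix.
        return [{"input": path, "target": path.rsplit("_", 1)[0]} for path in data]

    def load_sample(self, sample: Mapping[str, Any]) -> Mapping[str, Any]:
        # In a real data source this might read the file from disk;
        # here we just pass the metadata through.
        return dict(sample)


source = FolderDataSource()
samples = [source.load_sample(s) for s in source.load_data(["cat_1.png", "dog_2.png"])]
```

Because `load_sample` defaults to the identity, a data source that produces complete samples directly from `load_data` needs no `load_sample` at all.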

Let’s start by implementing a TemplateNumpyDataSource, which overrides NumpyDataSource. The main DataSource method that we have to implement is load_data(). As we’re extending the NumpyDataSource, we expect the same data argument (in this case, a tuple containing data and corresponding target arrays).

We can also take the dataset argument. Any attributes we set on dataset will be available on the Dataset generated by our DataSource. In this data source, we’ll set the num_features attribute.

Here’s the code for our TemplateNumpyDataSource.load_data method:

def load_data(self, data: Tuple[np.ndarray, Sequence[Any]], dataset: Any) -> Sequence[Mapping[str, Any]]:
    """Sets the ``num_features`` attribute and calls ``super().load_data``.

    Args:
        data: The tuple of ``np.ndarray`` (num_examples x num_features) and associated targets.
        dataset: The object that we can set attributes (such as ``num_features``) on.

    Returns:
        A sequence of samples / sample metadata.
    """
    dataset.num_features = data[0].shape[1]
    return super().load_data(data, dataset)


Later, when we add our DataModule implementation, we’ll make num_features available to the user.

Sometimes you need something a bit more custom. When creating a custom DataSource, the type of the data argument is up to you. For our template Task, it would be cool if the user could provide a scikit-learn Bunch as the data source. To achieve this, we'll add a TemplateSKLearnDataSource whose load_data expects a Bunch as input. We override our TemplateNumpyDataSource so that we can call super with the data and targets extracted from the Bunch. We perform two additional steps here to improve the user experience:

  1. We set the num_classes attribute on the dataset. If num_classes is set, it is automatically made available as a property of the DataModule.

  2. We create and set a LabelsState. The labels provided here will be shared with the Labels serializer, so the user doesn’t need to provide them.

Here’s the code for the TemplateSKLearnDataSource.load_data method:

def load_data(self, data: Bunch, dataset: Any) -> Sequence[Mapping[str, Any]]:
    """Gets the ``data`` and ``target`` attributes from the ``Bunch`` and passes them to ``super().load_data``.

    Args:
        data: The scikit-learn data ``Bunch``.
        dataset: The object that we can set attributes (such as ``num_classes``) on.

    Returns:
        A sequence of samples / sample metadata.
    """
    dataset.num_classes = len(data.target_names)
    self.set_state(LabelsState(data.target_names))
    return super().load_data((data.data, data.target), dataset=dataset)
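The LabelsState sharing described above relies on a simple set_state / get_state registry keyed by state type. Here is an illustrative, Flash-free sketch of the pattern (the StateRegistry class is a hypothetical stand-in, not Flash's actual implementation):

```python
from typing import Dict, Optional, Sequence


class LabelsState:
    """Illustrative container for a list of class labels."""

    def __init__(self, labels: Sequence[str]):
        self.labels = list(labels)


class StateRegistry:
    """Minimal sketch of a shared-state registry keyed by state type."""

    def __init__(self):
        self._state: Dict[type, object] = {}

    def set_state(self, state: object) -> None:
        # Store the state under its own type, replacing any previous value.
        self._state[type(state)] = state

    def get_state(self, state_type: type) -> Optional[object]:
        return self._state.get(state_type)


# The data source sets the labels once...
registry = StateRegistry()
registry.set_state(LabelsState(["setosa", "versicolor", "virginica"]))

# ...and a serializer elsewhere in the pipeline can retrieve them,
# so the user never has to pass the labels around manually.
labels = registry.get_state(LabelsState).labels
```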

We can customize the behaviour of our load_data() for different stages by prepending train, val, test, or predict to the method name. For our TemplateSKLearnDataSource, we don't want to provide any targets to the model when predicting. We can implement predict_load_data like this:

def predict_load_data(self, data: Bunch) -> Sequence[Mapping[str, Any]]:
    """Avoid including targets when predicting.

    Args:
        data: The scikit-learn data ``Bunch``.

    Returns:
        A sequence of samples / sample metadata.
    """
    return super().predict_load_data(data.data)
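Conceptually, the stage-prefix convention boils down to a getattr lookup with a fallback to the generic hook. A simplified sketch (this is an illustration of the idea, not Flash's internal resolution code):

```python
class SketchDataSource:
    """Hypothetical data source with a generic and a predict-specific hook."""

    def load_data(self, data):
        return [{"input": x, "target": y} for x, y in data]

    def predict_load_data(self, data):
        # No targets at predict time.
        return [{"input": x} for x in data]


def resolve_load_data(source, stage):
    # Prefer a stage-specific hook (e.g. ``predict_load_data``),
    # falling back to the generic ``load_data``.
    return getattr(source, f"{stage}_load_data", source.load_data)


source = SketchDataSource()
train_samples = resolve_load_data(source, "train")([(1, 0), (2, 1)])
predict_samples = resolve_load_data(source, "predict")([1, 2])
```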

DataSource vs Dataset

A DataSource is not the same as a Dataset. When a from_* method is called on your DataModule, it gets the DataSource to use from the Preprocess. A Dataset is then created from the DataSource for each stage (train, val, test, predict) using the provided metadata (e.g. folder name, numpy array etc.).

The output of the load_data() can just be a Dataset instance. If the library that your Task is based on provides a custom dataset, you don't need to re-write it as a DataSource. For example, the load_data() of the VideoClassificationPathsDataSource just creates an EncodedVideoDataset from the given folder. Here's how it looks (from video/...):

def load_data(self, data: str, dataset: Optional[Any] = None) -> "LabeledVideoDataset":
    ds = self._make_encoded_video_dataset(data)
    if self.training:
        label_to_class_mapping = {p[1]: p[0].split("/")[-2] for p in ds._labeled_videos._paths_and_labels}
        self.set_state(LabelsState(label_to_class_mapping))
        dataset.num_classes = len(np.unique([s[1]["label"] for s in ds._labeled_videos]))
    return ds


The Preprocess

The Preprocess object contains all the data transforms. Internally we inject the Preprocess transforms at several points along the pipeline.

Defining the standard transforms (typically at least a to_tensor_transform should be defined) for your Preprocess is as simple as implementing the default_transforms method. The Preprocess must take train_transform, val_transform, test_transform, and predict_transform arguments in the __init__. These arguments can be provided by the user (when creating the DataModule) to override the default transforms. Any additional arguments are up to you.

Inside the __init__, we make a call to super. This is where we register our data sources. Data sources should be given as a dictionary which maps data source name to data source object. The name can be anything, but if you want to take advantage of our built-in from_* classmethods, you should use DefaultDataSources as the names. In our case, we have both a NUMPY and a custom scikit-learn data source (which we’ll call “sklearn”).

You should also provide a default_data_source. This is the name of the data source to use by default when predicting. It’d be cool if we could get predictions just from a numpy array, so we’ll use NUMPY as the default.

Here’s our TemplatePreprocess.__init__:

def __init__(
    self,
    train_transform: Optional[Dict[str, Callable]] = None,
    val_transform: Optional[Dict[str, Callable]] = None,
    test_transform: Optional[Dict[str, Callable]] = None,
    predict_transform: Optional[Dict[str, Callable]] = None,
):
    super().__init__(
        train_transform=train_transform,
        val_transform=val_transform,
        test_transform=test_transform,
        predict_transform=predict_transform,
        data_sources={
            DefaultDataSources.NUMPY: TemplateNumpyDataSource(),
            "sklearn": TemplateSKLearnDataSource(),
        },
        default_data_source=DefaultDataSources.NUMPY,
    )

For our TemplatePreprocess, we’ll just configure a default to_tensor_transform. Let’s first define the transform as a staticmethod:

@staticmethod
def input_to_tensor(input: np.ndarray):
    """Transform which creates a tensor from the given numpy ``ndarray`` and converts it to ``float``."""
    return torch.from_numpy(input).float()

Our input samples will be dictionaries whose keys are in the DefaultDataKeys. You can map each key to different transforms using ApplyToKeys. Here's our default_transforms method:

def default_transforms(self) -> Optional[Dict[str, Callable]]:
    """Configures the default ``to_tensor_transform``.

    Returns:
        Our dictionary of transforms.
    """
    return {
        "to_tensor_transform": nn.Sequential(
            ApplyToKeys(DefaultDataKeys.INPUT, self.input_to_tensor),
            ApplyToKeys(DefaultDataKeys.TARGET, torch.as_tensor),
        ),
    }
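The ApplyToKeys idea is easy to see in isolation: it wraps a transform so that it only touches the value at one key of the sample dictionary. A minimal sketch (the ApplyToKey class and the key names here are hypothetical stand-ins, not the Flash implementation):

```python
from typing import Any, Callable, Dict


class ApplyToKey:
    """Sketch of an ``ApplyToKeys``-style wrapper: applies ``fn`` to one key
    of a sample dict and leaves the other keys untouched."""

    def __init__(self, key: str, fn: Callable[[Any], Any]):
        self.key = key
        self.fn = fn

    def __call__(self, sample: Dict[str, Any]) -> Dict[str, Any]:
        sample = dict(sample)  # copy, to avoid mutating the caller's dict
        sample[self.key] = self.fn(sample[self.key])
        return sample


# Apply a float conversion to the "input" key only; "target" is untouched.
to_float = ApplyToKey("input", lambda xs: [float(x) for x in xs])
result = to_float({"input": [1, 2, 3], "target": 0})
```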


The DataModule

The DataModule is responsible for creating the DataLoader and injecting the transforms for each stage. When the user calls a from_* method (such as from_numpy()), the following steps take place:

  1. The from_data_source() method is called with the name of the DataSource to use and the inputs to provide to load_data() for each stage.

  2. The Preprocess is created from cls.preprocess_cls (if it wasn’t provided by the user) with any provided transforms.

  3. The DataSource of the provided name is retrieved from the Preprocess.

  4. A BaseAutoDataset is created from the DataSource for each stage.

  5. The DataModule is instantiated with the data sets.
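The five steps above can be sketched end to end in a few lines (everything here is a simplified, hypothetical stand-in for the real Flash classes):

```python
from typing import Any, Dict


class SketchPreprocess:
    """Holds named data sources, like a ``Preprocess`` does."""

    def __init__(self):
        self._data_sources = {"numpy": lambda data: [{"input": x} for x in data]}

    def data_source_of_name(self, name: str):
        return self._data_sources[name]


class SketchDataModule:
    preprocess_cls = SketchPreprocess

    def __init__(self, datasets: Dict[str, Any]):
        self.datasets = datasets

    @classmethod
    def from_data_source(cls, name: str, train_data, predict_data):
        preprocess = cls.preprocess_cls()                  # step 2: build the preprocess
        load_data = preprocess.data_source_of_name(name)   # step 3: look up the data source
        datasets = {                                       # step 4: one dataset per stage
            stage: load_data(data)
            for stage, data in [("train", train_data), ("predict", predict_data)]
            if data is not None
        }
        return cls(datasets)                               # step 5: instantiate the DataModule


dm = SketchDataModule.from_data_source("numpy", train_data=[1, 2], predict_data=None)
```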

To create our TemplateData DataModule, we first need to attach our preprocess class like this:

preprocess_cls = TemplatePreprocess

Since we provided a NUMPY DataSource in the TemplatePreprocess, from_numpy() will now work with our TemplateData.

If you’ve defined a fully custom DataSource (like our TemplateSKLearnDataSource), then you will need to write a from_* method for each. Here’s the from_sklearn method for our TemplateData:

@classmethod
def from_sklearn(
    cls,
    train_bunch: Optional[Bunch] = None,
    val_bunch: Optional[Bunch] = None,
    test_bunch: Optional[Bunch] = None,
    predict_bunch: Optional[Bunch] = None,
    train_transform: Optional[Dict[str, Callable]] = None,
    val_transform: Optional[Dict[str, Callable]] = None,
    test_transform: Optional[Dict[str, Callable]] = None,
    predict_transform: Optional[Dict[str, Callable]] = None,
    data_fetcher: Optional[BaseDataFetcher] = None,
    preprocess: Optional[Preprocess] = None,
    val_split: Optional[float] = None,
    batch_size: int = 4,
    num_workers: int = 0,
    **preprocess_kwargs: Any,
):
    """This is our custom ``from_*`` method. It expects scikit-learn ``Bunch`` objects as input and passes them
    through to the ``from_data_source`` method underneath.

    Args:
        train_bunch: The scikit-learn ``Bunch`` containing the train data.
        val_bunch: The scikit-learn ``Bunch`` containing the validation data.
        test_bunch: The scikit-learn ``Bunch`` containing the test data.
        predict_bunch: The scikit-learn ``Bunch`` containing the predict data.
        train_transform: The dictionary of transforms to use during training which maps
            ``Preprocess`` hook names to callable transforms.
        val_transform: The dictionary of transforms to use during validation which maps
            ``Preprocess`` hook names to callable transforms.
        test_transform: The dictionary of transforms to use during testing which maps
            ``Preprocess`` hook names to callable transforms.
        predict_transform: The dictionary of transforms to use during predicting which maps
            ``Preprocess`` hook names to callable transforms.
        data_fetcher: The ``BaseDataFetcher`` to pass to the ``DataModule``.
        preprocess: The ``Preprocess`` to pass to the ``DataModule``. If ``None``,
            ``cls.preprocess_cls`` will be constructed and used.
        val_split: The ``val_split`` argument to pass to the ``DataModule``.
        batch_size: The ``batch_size`` argument to pass to the ``DataModule``.
        num_workers: The ``num_workers`` argument to pass to the ``DataModule``.
        preprocess_kwargs: Additional keyword arguments to use when constructing the preprocess. Will only be used
            if ``preprocess = None``.

    Returns:
        The constructed data module.
    """
    return super().from_data_source(
        "sklearn",
        train_bunch,
        val_bunch,
        test_bunch,
        predict_bunch,
        train_transform=train_transform,
        val_transform=val_transform,
        test_transform=test_transform,
        predict_transform=predict_transform,
        data_fetcher=data_fetcher,
        preprocess=preprocess,
        val_split=val_split,
        batch_size=batch_size,
        num_workers=num_workers,
        **preprocess_kwargs,
    )

The final step is to implement the num_features property for our TemplateData. This is just a convenience for the user: it looks for the num_features attribute on each of the data sets in turn and returns the first one it finds. Here's the code:

@property
def num_features(self) -> Optional[int]:
    """Tries to get the ``num_features`` from each dataset in turn and returns the output."""
    n_fts_train = getattr(self.train_dataset, "num_features", None)
    n_fts_val = getattr(self.val_dataset, "num_features", None)
    n_fts_test = getattr(self.test_dataset, "num_features", None)
    return n_fts_train or n_fts_val or n_fts_test
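The pattern here is just getattr with a None default plus or-chaining, which picks the first dataset that actually defines the attribute. A tiny standalone demonstration (with throwaway classes; note that ``or`` would also skip a legitimate value of 0, which is fine for a feature count):

```python
class _Dataset:
    """Throwaway stand-in for a stage dataset."""


train_dataset = _Dataset()  # no num_features set, e.g. no train data provided
val_dataset = _Dataset()
val_dataset.num_features = 4

# ``getattr`` with a default returns ``None`` when the attribute is missing,
# and ``or`` then falls through to the next dataset in the chain.
num_features = (
    getattr(train_dataset, "num_features", None)
    or getattr(val_dataset, "num_features", None)
)
```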


The BaseVisualization

An optional step is to implement a BaseVisualization. The BaseVisualization lets you control how data at various points in the pipeline can be visualized. This is extremely useful for debugging purposes, allowing users to view their data and understand the impact of their transforms.


Don’t worry about implementing it right away, you can always come back and add it later!

Here’s the code for our TemplateVisualization which just prints the data:

class TemplateVisualization(BaseVisualization):
    """The ``TemplateVisualization`` class is a ``BaseVisualization`` that just prints the data.

    If you want to provide a visualization with your task, you can override these hooks.
    """

    def show_load_sample(self, samples: List[Any], running_stage: RunningStage):
        print(samples)

    def show_pre_tensor_transform(self, samples: List[Any], running_stage: RunningStage):
        print(samples)

    def show_to_tensor_transform(self, samples: List[Any], running_stage: RunningStage):
        print(samples)

    def show_post_tensor_transform(self, samples: List[Any], running_stage: RunningStage):
        print(samples)

    def show_per_batch_transform(self, batch: List[Any], running_stage):
        print(batch)

We can configure our custom visualization in the TemplateData using configure_data_fetcher() like this:

@staticmethod
def configure_data_fetcher(*args, **kwargs) -> BaseDataFetcher:
    """We can, *optionally*, provide a data visualization callback using the ``configure_data_fetcher`` method."""
    return TemplateVisualization(*args, **kwargs)


The Postprocess

The Postprocess contains any transforms that need to be applied after the model. You may want to use it to convert tokens back into text, apply an inverse normalization to an output image, resize a generated image back to the size of the input, and so on. As an example, here's the TextClassificationPostprocess, which gets the logits from a SequenceClassifierOutput:

class TextClassificationPostprocess(Postprocess):
    def per_batch_transform(self, batch: Any) -> Any:
        if isinstance(batch, SequenceClassifierOutput):
            batch = batch.logits
        return super().per_batch_transform(batch)
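A common concrete case is mapping predicted token ids back into text in a per-sample transform. Here is a minimal, Flash-free sketch (the vocabulary, class, and padding convention are all hypothetical illustrations):

```python
from typing import List

# Hypothetical id-to-token vocabulary, just for illustration.
ID_TO_TOKEN = {0: "<pad>", 1: "hello", 2: "world"}


class SketchTextPostprocess:
    """Sketch of a postprocess step that turns predicted token ids into a string."""

    def per_sample_transform(self, sample: List[int]) -> str:
        # Drop padding ids, then join the remaining tokens into text.
        tokens = [ID_TO_TOKEN[i] for i in sample if i != 0]
        return " ".join(tokens)


post = SketchTextPostprocess()
text = post.per_sample_transform([1, 2, 0, 0])
```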

In your DataSource or Preprocess, you can add metadata to the batch using the METADATA key. Your Postprocess can then use this metadata in its transforms. You should use this approach if your postprocessing depends on the state of the input before the Preprocess transforms. For example, if you want to resize the predictions to the original size of the inputs you should add the original image size in the METADATA. Here’s an example from the SemanticSegmentationNumpyDataSource:

def load_sample(self, sample: Dict[str, Any], dataset: Optional[Any] = None) -> Dict[str, Any]:
    img = torch.from_numpy(sample[DefaultDataKeys.INPUT]).float()
    sample[DefaultDataKeys.INPUT] = img
    sample[DefaultDataKeys.METADATA] = {"size": img.shape}
    return sample

The METADATA can now be referenced in your Postprocess. For example, here’s the code for the per_sample_transform method of the SemanticSegmentationPostprocess:

def per_sample_transform(self, sample: Any) -> Any:
    resize = K.geometry.Resize(sample[DefaultDataKeys.METADATA]["size"][-2:], interpolation="bilinear")
    sample[DefaultDataKeys.PREDS] = resize(sample[DefaultDataKeys.PREDS])
    sample[DefaultDataKeys.INPUT] = resize(sample[DefaultDataKeys.INPUT])
    return super().per_sample_transform(sample)

Now that you’ve got some data, it’s time to add some backbones for your task!

Read the Docs v: 0.5.1