The Data

The first step in contributing a task is to implement the classes we need to load some data. You should implement:

  1. some Input classes (optional)

  2. an InputTransform

  3. a DataModule

  4. a BaseVisualization (optional)

  5. an OutputTransform (optional)


The Input class contains the logic for data loading from different sources such as folders, files, tensors, etc. Every Flash DataModule can be instantiated with from_datasets(). For each additional way you want the user to be able to instantiate your DataModule, you’ll need to create an Input. Each Input has two methods:

  • load_data() takes some dataset metadata (e.g. a folder name) as input and produces a sequence or iterable of samples or sample metadata.

  • load_sample() then takes as input a single element from the output of load_data and returns a sample.

By default these methods just return their input, so you don’t need both a load_data() and a load_sample() to create an Input. Where possible, you should override one of our existing Input classes.
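To make the two-method contract concrete, here is a minimal, framework-free sketch. The FolderInput class and its string "loading" are hypothetical stand-ins, purely for illustration; a real Input would subclass the Flash base class:

```python
from typing import Sequence


class FolderInput:
    """Hypothetical Input-style loader: load_data emits cheap metadata,
    load_sample turns one metadata entry into a full sample."""

    def load_data(self, files: Sequence[str]) -> Sequence[dict]:
        # Produce one metadata dict per file; nothing is loaded yet.
        return [{"path": p} for p in files]

    def load_sample(self, sample: dict) -> dict:
        # A real Input would read the file here; we fake it with upper().
        sample["input"] = sample["path"].upper()
        return sample


src = FolderInput()
samples = [src.load_sample(m) for m in src.load_data(["a.png", "b.png"])]
print(samples[0])  # {'path': 'a.png', 'input': 'A.PNG'}
```

Keeping load_data cheap (metadata only) and deferring real work to load_sample is what lets the generated Dataset load examples lazily.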

Let’s start by implementing a TemplateNumpyInput, which overrides NumpyInput. The main Input method that we have to implement is load_data(). As we’re extending the NumpyInput, we expect the same data argument (in this case, a tuple containing data and corresponding target arrays).

We can also take the dataset argument. Any attributes we set on dataset will be available on the Dataset generated by our Input. In this Input, we’ll set the num_features attribute.

Here’s the code for our TemplateNumpyInput.load_data method:

def load_data(self, data: Tuple[np.ndarray, Sequence[Any]], dataset: Any) -> Sequence[Mapping[str, Any]]:
    """Sets the ``num_features`` attribute and calls ``super().load_data``.

    Args:
        data: The tuple of ``np.ndarray`` (num_examples x num_features) and associated targets.
        dataset: The object that we can set attributes (such as ``num_features``) on.

    Returns:
        A sequence of samples / sample metadata.
    """
    dataset.num_features = data[0].shape[1]
    return super().load_data(data, dataset)


Later, when we add our DataModule implementation, we’ll make num_features available to the user.
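As a quick check of that shape logic, here is a standalone numpy snippet (not Flash code) in the same layout: a tuple of an (num_examples x num_features) array and its targets:

```python
import numpy as np

# Hypothetical data tuple in the expected layout:
# (num_examples x num_features array, sequence of targets)
data = (np.zeros((10, 4)), list(range(10)))

# shape[1] of the feature array is the number of features per example
num_features = data[0].shape[1]
print(num_features)  # 4
```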

Sometimes you need to do something a bit more custom. When creating a custom Input, the type of the data argument is up to you. For our template Task, it would be cool if the user could provide a scikit-learn Bunch as the data source. To achieve this, we’ll add a TemplateSKLearnInput whose load_data expects a Bunch as input. We override our TemplateNumpyInput so that we can call super with the data and targets extracted from the Bunch. We perform two additional steps here to improve the user experience:

  1. We set the num_classes attribute on the dataset. If num_classes is set, it is automatically made available as a property of the DataModule.

  2. We create and set a ClassificationState. The labels provided here will be shared with the Labels output, so the user doesn’t need to provide them.

Here’s the code for the TemplateSKLearnInput.load_data method:

def load_data(self, data: Bunch, dataset: Any) -> Sequence[Mapping[str, Any]]:
    """Gets the ``data`` and ``target`` attributes from the ``Bunch`` and passes them to ``super().load_data``.

    Args:
        data: The scikit-learn data ``Bunch``.
        dataset: The object that we can set attributes (such as ``num_classes``) on.

    Returns:
        A sequence of samples / sample metadata.
    """
    dataset.num_classes = len(data.target_names)
    self.set_state(ClassificationState(data.target_names))
    return super().load_data((data.data, data.target), dataset=dataset)

We can customize the behaviour of our load_data() for different stages by prepending train_, val_, test_, or predict_ to the method name. For our TemplateSKLearnInput, we don’t want to provide any targets to the model when predicting. We can implement predict_load_data like this:

def predict_load_data(self, data: Bunch) -> Sequence[Mapping[str, Any]]:
    """Avoid including targets when predicting.

    Args:
        data: The scikit-learn data ``Bunch``.

    Returns:
        A sequence of samples / sample metadata.
    """
    return super().predict_load_data(data.data)
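The stage-prefix convention can be pictured with a small stand-alone sketch. The resolve_hook helper and TinyInput class are hypothetical; Flash performs an equivalent lookup internally:

```python
class TinyInput:
    """Hypothetical Input with a generic hook and a predict-specific one."""

    def load_data(self, data):
        return [{"input": x, "target": 0} for x in data]

    def predict_load_data(self, data):
        # No targets when predicting.
        return [{"input": x} for x in data]


def resolve_hook(obj, stage, name="load_data"):
    # Prefer the stage-prefixed hook; fall back to the generic one.
    return getattr(obj, f"{stage}_{name}", getattr(obj, name))


src = TinyInput()
print(resolve_hook(src, "train")(["a"]))    # falls back: [{'input': 'a', 'target': 0}]
print(resolve_hook(src, "predict")(["a"]))  # [{'input': 'a'}]
```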

Input vs Dataset

An Input is not the same as a Dataset. When a from_* method is called on your DataModule, it gets the Input to use from the InputTransform. A Dataset is then created from the Input for each stage (train, val, test, predict) using the provided metadata (e.g. folder name, numpy array etc.).

The output of the load_data() can just be a Dataset instance. If the library that your Task is based on provides a custom dataset, you don’t need to re-write it as an Input. For example, the load_data() of the VideoClassificationFoldersInput just creates a LabeledVideoDataset from the given folder. Here’s how it looks (from video/…):

def load_data(
    self,
    path: str,
    clip_sampler: Union[str, "ClipSampler"] = "random",
    clip_duration: float = 2,
    clip_sampler_kwargs: Dict[str, Any] = None,
    video_sampler: Type[Sampler] = torch.utils.data.RandomSampler,
    decode_audio: bool = False,
    decoder: str = "pyav",
) -> "LabeledVideoDataset":
    dataset = labeled_video_dataset(
        path,
        _make_clip_sampler(clip_sampler, clip_duration, clip_sampler_kwargs),
        video_sampler=video_sampler,
        decode_audio=decode_audio,
        decoder=decoder,
    )
    return super().load_data(dataset)


The InputTransform object contains all the data transforms. Internally we inject the InputTransform transforms at several points along the pipeline.

Defining the standard transforms (typically at least a per_sample_transform should be defined) for your InputTransform is as simple as implementing the default_transforms method. The InputTransform must take train_transform, val_transform, test_transform, and predict_transform arguments in the __init__. These arguments can be provided by the user (when creating the DataModule) to override the default transforms. Any additional arguments are up to you.
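The override rule itself is simple dictionary precedence: hooks the user supplies win, and unspecified hooks keep their defaults. A minimal sketch of the idea (hypothetical helper, not the actual Flash resolution code):

```python
def resolve_transforms(defaults, user):
    # User-provided hooks win; unspecified hooks keep their defaults.
    return {**defaults, **(user or {})}


defaults = {"per_sample_transform": "default_fn", "per_batch_transform": "identity"}
user = {"per_sample_transform": "user_fn"}
print(resolve_transforms(defaults, user))
# {'per_sample_transform': 'user_fn', 'per_batch_transform': 'identity'}
```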

Inside the __init__, we make a call to super. This is where we register our Inputs. Inputs should be given as a dictionary which maps Input name to Input object. The name can be anything, but if you want to take advantage of our built-in from_* classmethods, you should use InputFormat as the names. In our case, we have both a NUMPY Input and a custom scikit-learn Input (which we’ll call “sklearn”).

You should also provide a default_input. This is the name of the Input to use by default when predicting. It’d be cool if we could get predictions just from a numpy array, so we’ll use NUMPY as the default.

Here’s our TemplateInputTransform.__init__:

def __init__(
    self,
    train_transform: Optional[Dict[str, Callable]] = None,
    val_transform: Optional[Dict[str, Callable]] = None,
    test_transform: Optional[Dict[str, Callable]] = None,
    predict_transform: Optional[Dict[str, Callable]] = None,
):
    super().__init__(
        train_transform=train_transform,
        val_transform=val_transform,
        test_transform=test_transform,
        predict_transform=predict_transform,
        inputs={
            InputFormat.NUMPY: TemplateNumpyInput(),
            "sklearn": TemplateSKLearnInput(),
        },
        default_input=InputFormat.NUMPY,
    )

For our TemplateInputTransform, we’ll just configure a default per_sample_transform. Let’s first define the transform as a staticmethod:

@staticmethod
def input_to_tensor(input: np.ndarray):
    """Transform which creates a tensor from the given numpy ``ndarray`` and converts it to ``float``."""
    return torch.from_numpy(input).float()

Our input samples will be dictionaries whose keys are in the DataKeys. You can map each key to different transforms using ApplyToKeys. Here’s our default_transforms method:

def default_transforms(self) -> Optional[Dict[str, Callable]]:
    """Configures the default ``per_sample_transform``.

    Returns:
        Our dictionary of transforms.
    """
    return {
        "per_sample_transform": nn.Sequential(
            ApplyToKeys(DataKeys.INPUT, self.input_to_tensor),
            ApplyToKeys(DataKeys.TARGET, torch.as_tensor),
        ),
    }


The DataModule is responsible for creating the DataLoader and injecting the transforms for each stage. When the user calls a from_* method (such as from_numpy()), the following steps take place:

  1. The from_input() method is called with the name of the Input to use and the inputs to provide to load_data() for each stage.

  2. The InputTransform is created from cls.input_transform_cls (if it wasn’t provided by the user) with any provided transforms.

  3. The Input of the provided name is retrieved from the InputTransform.

  4. A BaseAutoDataset is created from the Input for each stage.

  5. The DataModule is instantiated with the data sets.
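The five steps above can be sketched, heavily simplified, in plain Python. All names here (TinyInputTransform, from_source, input_of_name) are hypothetical stand-ins, not Flash classes:

```python
class TinyInputTransform:
    """Stand-in for InputTransform: owns the named Inputs."""

    def __init__(self):
        self.inputs = {"numpy": lambda data: [{"input": row} for row in data]}

    def input_of_name(self, name):
        return self.inputs[name]


def from_source(name, train_data, input_transform=None):
    input_transform = input_transform or TinyInputTransform()  # step 2: create if not provided
    load_data = input_transform.input_of_name(name)            # step 3: look up the Input
    train_dataset = load_data(train_data)                      # step 4: build the dataset
    return {"train_dataset": train_dataset}                    # step 5: 'DataModule' stand-in


dm = from_source("numpy", [[0.1, 0.2]])
print(dm["train_dataset"])  # [{'input': [0.1, 0.2]}]
```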

To create our TemplateData DataModule, we first need to attach our input transform class like this:

input_transform_cls = TemplateInputTransform

Since we provided a NUMPY Input in the TemplateInputTransform, from_numpy() will now work with our TemplateData.

If you’ve defined a fully custom Input (like our TemplateSKLearnInput), then you will need to write a from_* method for each. Here’s the from_sklearn method for our TemplateData:

@classmethod
def from_sklearn(
    cls,
    train_bunch: Optional[Bunch] = None,
    val_bunch: Optional[Bunch] = None,
    test_bunch: Optional[Bunch] = None,
    predict_bunch: Optional[Bunch] = None,
    train_transform: Optional[Dict[str, Callable]] = None,
    val_transform: Optional[Dict[str, Callable]] = None,
    test_transform: Optional[Dict[str, Callable]] = None,
    predict_transform: Optional[Dict[str, Callable]] = None,
    data_fetcher: Optional[BaseDataFetcher] = None,
    input_transform: Optional[InputTransform] = None,
    val_split: Optional[float] = None,
    batch_size: int = 4,
    num_workers: int = 0,
    **input_transform_kwargs: Any,
) -> "TemplateData":
    """This is our custom ``from_*`` method. It expects scikit-learn ``Bunch`` objects as input and passes them
    through to the ``from_input`` method underneath.

    Args:
        train_bunch: The scikit-learn ``Bunch`` containing the train data.
        val_bunch: The scikit-learn ``Bunch`` containing the validation data.
        test_bunch: The scikit-learn ``Bunch`` containing the test data.
        predict_bunch: The scikit-learn ``Bunch`` containing the predict data.
        train_transform: The dictionary of transforms to use during training which maps ``InputTransform`` hook
            names to callable transforms.
        val_transform: The dictionary of transforms to use during validation which maps ``InputTransform`` hook
            names to callable transforms.
        test_transform: The dictionary of transforms to use during testing which maps ``InputTransform`` hook
            names to callable transforms.
        predict_transform: The dictionary of transforms to use during predicting which maps ``InputTransform``
            hook names to callable transforms.
        data_fetcher: The ``BaseDataFetcher`` to pass to the ``DataModule``.
        input_transform: The ``InputTransform`` to pass to the ``DataModule``. If ``None``,
            ``cls.input_transform_cls`` will be constructed and used.
        val_split: The ``val_split`` argument to pass to the ``DataModule``.
        batch_size: The ``batch_size`` argument to pass to the ``DataModule``.
        num_workers: The ``num_workers`` argument to pass to the ``DataModule``.
        input_transform_kwargs: Additional keyword arguments to use when constructing the input_transform.
            Will only be used if ``input_transform = None``.

    Returns:
        The constructed data module.
    """
    return super().from_input(
        "sklearn",
        train_bunch,
        val_bunch,
        test_bunch,
        predict_bunch,
        train_transform=train_transform,
        val_transform=val_transform,
        test_transform=test_transform,
        predict_transform=predict_transform,
        data_fetcher=data_fetcher,
        input_transform=input_transform,
        val_split=val_split,
        batch_size=batch_size,
        num_workers=num_workers,
        **input_transform_kwargs,
    )

The final step is to implement the num_features property for our TemplateData. This is just a convenience for the user: it looks up the num_features attribute on each of the data sets in turn and returns the first one it finds. Here’s the code:

@property
def num_features(self) -> Optional[int]:
    """Tries to get the ``num_features`` from each dataset in turn and returns the output."""
    n_fts_train = getattr(self.train_dataset, "num_features", None)
    n_fts_val = getattr(self.val_dataset, "num_features", None)
    n_fts_test = getattr(self.test_dataset, "num_features", None)
    return n_fts_train or n_fts_val or n_fts_test
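The or-chaining means the first dataset that actually defines the attribute wins. A quick stand-alone illustration, with a hypothetical FakeDataset stand-in:

```python
class FakeDataset:
    """Hypothetical dataset that may or may not carry num_features."""

    def __init__(self, num_features=None):
        if num_features is not None:
            self.num_features = num_features


train, val, test = FakeDataset(), FakeDataset(16), FakeDataset()
num_features = (
    getattr(train, "num_features", None)
    or getattr(val, "num_features", None)
    or getattr(test, "num_features", None)
)
print(num_features)  # 16: train has no attribute, so val's value is used
```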


An optional step is to implement a BaseVisualization. The BaseVisualization lets you control how data at various points in the pipeline can be visualized. This is extremely useful for debugging purposes, allowing users to view their data and understand the impact of their transforms.


Don’t worry about implementing it right away, you can always come back and add it later!

Here’s the code for our TemplateVisualization which just prints the data:

class TemplateVisualization(BaseVisualization):
    """The ``TemplateVisualization`` class is a ``BaseVisualization`` that just prints the data.

    If you want to provide a visualization with your task, you can override these hooks.
    """

    def show_load_sample(self, samples: List[Any], running_stage: RunningStage):
        print(samples)

    def show_per_sample_transform(self, samples: List[Any], running_stage: RunningStage):
        print(samples)

We can configure our custom visualization in the TemplateData using configure_data_fetcher() like this:

@staticmethod
def configure_data_fetcher(*args, **kwargs) -> BaseDataFetcher:
    """We can, *optionally*, provide a data visualization callback using the ``configure_data_fetcher`` method."""
    return TemplateVisualization(*args, **kwargs)


The OutputTransform contains any transforms that need to be applied after the model. You may want to use it for: converting tokens back into text, applying an inverse normalization to an output image, resizing a generated image back to the size of the input, etc. As an example, here’s the TextClassificationOutputTransform which gets the logits from a SequenceClassifierOutput:

class TextClassificationOutputTransform(OutputTransform):
    def per_batch_transform(self, batch: Any) -> Any:
        if isinstance(batch, SequenceClassifierOutput):
            batch = batch.logits
        return super().per_batch_transform(batch)
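The same unwrap pattern works for any structured model output. Here is a dependency-free sketch; FakeModelOutput is a hypothetical stand-in for something like SequenceClassifierOutput:

```python
class FakeModelOutput:
    """Hypothetical structured output wrapping logits."""

    def __init__(self, logits):
        self.logits = logits


def per_batch_transform(batch):
    # Unwrap structured outputs; pass plain batches through untouched.
    if isinstance(batch, FakeModelOutput):
        batch = batch.logits
    return batch


print(per_batch_transform(FakeModelOutput([0.1, 0.9])))  # [0.1, 0.9]
print(per_batch_transform([0.5, 0.5]))                   # unchanged: [0.5, 0.5]
```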

In your Input or InputTransform, you can add metadata to the batch using the METADATA key. Your OutputTransform can then use this metadata in its transforms. You should use this approach if your postprocessing depends on the state of the input before the InputTransform transforms were applied. For example, if you want to resize the predictions to the original size of the inputs you should add the original image size in the METADATA. Here’s an example from the SemanticSegmentationNumpyInput:

def load_sample(self, sample: Dict[str, Any]) -> Dict[str, Any]:
    sample[DataKeys.INPUT] = torch.from_numpy(sample[DataKeys.INPUT])
    if DataKeys.TARGET in sample:
        sample[DataKeys.TARGET] = torch.from_numpy(sample[DataKeys.TARGET])
    sample[DataKeys.METADATA] = {"size": sample[DataKeys.INPUT].shape[-2:]}
    return super().load_sample(sample)

The METADATA can now be referenced in your OutputTransform. For example, here’s the code for the per_sample_transform method of the SemanticSegmentationOutputTransform:

def per_sample_transform(self, sample: Any) -> Any:
    resize = K.geometry.Resize(sample[DataKeys.METADATA]["size"], interpolation="bilinear")
    sample[DataKeys.PREDS] = resize(sample[DataKeys.PREDS])
    sample[DataKeys.INPUT] = resize(sample[DataKeys.INPUT])
    return super().per_sample_transform(sample)
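Stripped of the Kornia resize and tensor details, the METADATA round trip looks like this. All keys and the truncation stand-in below are hypothetical simplifications:

```python
def load_sample(sample):
    # Input side: record the original size before any transforms run.
    sample["metadata"] = {"size": len(sample["input"])}
    return sample


def per_sample_transform(sample):
    # Output side: read the recorded size back to restore the prediction.
    size = sample["metadata"]["size"]
    sample["preds"] = sample["preds"][:size]  # stand-in for resizing
    return sample


s = load_sample({"input": [1, 2, 3]})
s["preds"] = [9, 9, 9, 0, 0]  # e.g. padded model output
print(per_sample_transform(s)["preds"])  # [9, 9, 9]
```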

Now that you’ve got some data, it’s time to add some backbones for your task!
