The Data

The first step to contributing a task is to implement the classes we need to load some data. For your task, you should implement:

  1. some Input classes (optional)

  2. an InputTransform

  3. a DataModule

  4. a BaseVisualization (optional)

  5. an OutputTransform (optional)


The Input class contains the logic for loading data from different sources such as folders, files, tensors, etc. Every Flash DataModule can be instantiated with from_datasets(). For each additional way you want the user to be able to instantiate your DataModule, you’ll need to create an Input. Each Input has two methods:

  • load_data() takes some dataset metadata (e.g. a folder name) as input and produces a sequence or iterable of samples or sample metadata.

  • load_sample() then takes as input a single element from the output of load_data and returns a sample.

By default these methods just return their input, so you don’t need to implement both load_data() and load_sample() to create an Input. Where possible, you should override one of our existing Input classes.
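To make the split between the two methods concrete, here is a framework-free sketch. `SketchInput` and `InMemoryFolderInput` are hypothetical stand-ins, not Flash classes, and a dict of file names to contents plays the role of a folder: load_data() turns the metadata into a list of lightweight sample descriptions, and load_sample() expands one description into a full sample only when it is accessed.

```python
from typing import Any, Dict, Iterable, List


class SketchInput:
    """Hypothetical stand-in for an Input: both hooks default to returning their argument."""

    def __init__(self, data: Any) -> None:
        # load_data runs once, eagerly, over the dataset metadata.
        self.data: List[Any] = list(self.load_data(data))

    def load_data(self, data: Any) -> Iterable[Any]:
        return data

    def load_sample(self, sample: Any) -> Any:
        return sample

    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, index: int) -> Any:
        # load_sample runs lazily, once per element that is accessed.
        return self.load_sample(self.data[index])


class InMemoryFolderInput(SketchInput):
    """Hypothetical Input over a dict that maps file names to file contents."""

    def load_data(self, folder: Dict[str, str]) -> List[Dict[str, str]]:
        self.folder = folder  # keep a handle so load_sample can read lazily
        # Only record where each sample lives (cheap metadata).
        return [{"path": name} for name in sorted(folder)]

    def load_sample(self, sample: Dict[str, str]) -> Dict[str, str]:
        # Fetch the actual contents only when the sample is requested.
        return {**sample, "input": self.folder[sample["path"]]}
```

Keeping load_data() cheap and deferring the expensive work to load_sample() means only the samples that are actually requested get read.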

Let’s start by implementing a TemplateNumpyClassificationInput, which overrides ClassificationInputMixin. The main Input method that we have to implement is load_data(). ClassificationInputMixin provides utilities for handling targets within Flash, which need to be called from load_data() and load_sample(). In this Input, we’ll also set the num_features attribute so that we can access it later.

Here’s the code for our TemplateNumpyClassificationInput.load_data method:

def load_data(
    self,
    examples: Collection[np.ndarray],
    targets: Optional[Sequence[Any]] = None,
    target_formatter: Optional[TargetFormatter] = None,
) -> Sequence[Dict[str, Any]]:
    """Sets the ``num_features`` attribute, loads the target metadata, and converts the data to samples.

    Args:
        examples: The ``np.ndarray`` (num_examples x num_features).
        targets: Associated targets.
        target_formatter: Optionally provide a ``TargetFormatter`` to control how targets are formatted.

    Returns:
        A sequence of samples / sample metadata.
    """
    # Record the feature dimension (not needed when predicting).
    if not self.predicting and isinstance(examples, np.ndarray):
        self.num_features = examples.shape[1]
    if targets is not None:
        # Utility from ClassificationInputMixin that inspects the targets.
        self.load_target_metadata(targets, target_formatter=target_formatter)
    return to_samples(examples, targets)

and here’s the code for the TemplateNumpyClassificationInput.load_sample method:

def load_sample(self, sample: Dict[str, Any]) -> Any:
    if DataKeys.TARGET in sample:
        sample[DataKeys.TARGET] = self.format_target(sample[DataKeys.TARGET])
    return sample
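The sketch below imitates, without Flash or NumPy, what the two methods above hand to the rest of the pipeline. It assumes that `to_samples` pairs each example with its target in one dict per sample; the plain strings `"input"` and `"target"` stand in for `DataKeys.INPUT` and `DataKeys.TARGET`, and a hypothetical sorted label-to-index mapping replaces the real target formatter.

```python
from typing import Any, Dict, List, Optional, Sequence


def to_samples_sketch(examples: Sequence[Any], targets: Optional[Sequence[Any]] = None) -> List[Dict[str, Any]]:
    """Hypothetical stand-in for ``to_samples``: one dict per example."""
    if targets is None:
        return [{"input": x} for x in examples]
    return [{"input": x, "target": y} for x, y in zip(examples, targets)]


class NumpyLikeClassificationInput:
    def load_data(self, examples: List[List[float]], targets: Optional[Sequence[str]] = None) -> List[Dict[str, Any]]:
        # Equivalent of ``examples.shape[1]`` for a list of equal-length rows.
        self.num_features = len(examples[0])
        if targets is not None:
            # Hypothetical target metadata: sorted labels mapped to class indices.
            self.label_to_index = {label: i for i, label in enumerate(sorted(set(targets)))}
        return to_samples_sketch(examples, targets)

    def load_sample(self, sample: Dict[str, Any]) -> Dict[str, Any]:
        # Format the target only if the sample has one (e.g. not when predicting).
        if "target" in sample:
            sample["target"] = self.label_to_index[sample["target"]]
        return sample
```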


Later, when we add our DataModule implementation, we’ll make num_features available to the user.

For our template Task, it would be cool if the user could provide a scikit-learn Bunch as the data source. To achieve this, we’ll add a TemplateSKLearnClassificationInput whose load_data expects a Bunch as input. We subclass our TemplateNumpyClassificationInput so that we can call super().load_data() with the data and targets extracted from the Bunch. We perform two additional steps here to improve the user experience:

  1. We set the num_classes attribute on the dataset. If num_classes is set, it is automatically made available as a property of the DataModule.

  2. We create and set a ClassificationState. The labels provided here will be shared with the Labels output, so the user doesn’t need to provide them.

Here’s the code for the TemplateSKLearnClassificationInput.load_data method:

def load_data(self, data: Bunch, target_formatter: Optional[TargetFormatter] = None) -> Sequence[Dict[str, Any]]:
    """Gets the ``data`` and ``target`` attributes from the ``Bunch`` and passes them to ``super().load_data``.

    Args:
        data: The scikit-learn data ``Bunch``.
        target_formatter: Optionally provide a ``TargetFormatter`` to control how targets are formatted.

    Returns:
        A sequence of samples / sample metadata.
    """
    return super().load_data(data.data, data.target, target_formatter=target_formatter)

We can customize the behaviour of load_data() for different stages by prepending train_, val_, test_, or predict_ to its name. For our TemplateSKLearnClassificationInput, we don’t want to provide any targets to the model when predicting. We can implement predict_load_data like this:

def predict_load_data(self, data: Bunch) -> Sequence[Dict[str, Any]]:
    """Avoid including targets when predicting.

    Args:
        data: The scikit-learn data ``Bunch``.

    Returns:
        A sequence of samples / sample metadata.
    """
    # Pass only the data; the targets are deliberately left out.
    return super().load_data(data.data)
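The stage-prefixed override can be pictured as a lookup that prefers `predict_load_data` (or `train_load_data`, and so on) and falls back to `load_data`. The helper `resolve_load_data` below is a hypothetical illustration of that behaviour, not Flash’s internal code, and `SimpleNamespace` stands in for a scikit-learn `Bunch`, which likewise exposes `data` and `target` as attributes.

```python
from types import SimpleNamespace
from typing import Any, Callable, Dict, List, Optional


class BaseClassificationInput:
    def load_data(self, examples: List[Any], targets: Optional[List[Any]] = None) -> List[Dict[str, Any]]:
        if targets is None:
            return [{"input": x} for x in examples]
        return [{"input": x, "target": y} for x, y in zip(examples, targets)]


class BunchInput(BaseClassificationInput):
    def load_data(self, data: SimpleNamespace) -> List[Dict[str, Any]]:
        # Unpack the Bunch-like object and delegate to the parent class.
        return super().load_data(data.data, data.target)

    def predict_load_data(self, data: SimpleNamespace) -> List[Dict[str, Any]]:
        # Deliberately drop the targets when predicting.
        return super().load_data(data.data)


def resolve_load_data(input_obj: Any, stage: str) -> Callable[..., List[Dict[str, Any]]]:
    """Hypothetical: use ``<stage>_load_data`` if it is defined, else ``load_data``."""
    return getattr(input_obj, f"{stage}_load_data", input_obj.load_data)
```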


The InputTransform object contains all the data transforms. Internally we inject the InputTransform transforms at several points along the pipeline.

Defining the standard transforms for your InputTransform (typically at least a per_sample_transform) simply involves overriding the relevant hook so that it returns a callable transform.

For our TemplateInputTransform, we’ll just configure an input_per_sample_transform and a target_per_sample_transform. Let’s first define an input_to_tensor transform as a staticmethod:

@staticmethod
def input_to_tensor(input: np.ndarray):
    """Transform which creates a tensor from the given numpy ``ndarray`` and converts it to ``float``"""
    return torch.from_numpy(input).float()

Now in our input_per_sample_transform hook, we return the transform:

def input_per_sample_transform(self) -> Callable:
    return self.input_to_tensor

To convert the targets to a tensor we can simply use torch.as_tensor. Here’s our target_per_sample_transform:

def target_per_sample_transform(self) -> Callable:
    return torch.as_tensor
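Each hook returns a callable rather than transforming data itself; the pipeline then applies that callable to the matching part of every sample. The sketch below shows the idea with hypothetical names (`SketchInputTransform`, `apply_per_sample`) and uses `float`/`int` conversion in place of `torch.from_numpy(...).float()` and `torch.as_tensor`, so it runs without PyTorch.

```python
from typing import Any, Callable, Dict, List


class SketchInputTransform:
    @staticmethod
    def input_to_floats(values: List[Any]) -> List[float]:
        # Plays the role of ``torch.from_numpy(input).float()``.
        return [float(v) for v in values]

    def input_per_sample_transform(self) -> Callable:
        return self.input_to_floats

    def target_per_sample_transform(self) -> Callable:
        # Plays the role of ``torch.as_tensor``.
        return int


def apply_per_sample(transform: SketchInputTransform, sample: Dict[str, Any]) -> Dict[str, Any]:
    """Hypothetical: route each hook's callable to its key of the sample."""
    out = dict(sample)
    out["input"] = transform.input_per_sample_transform()(sample["input"])
    if "target" in sample:
        out["target"] = transform.target_per_sample_transform()(sample["target"])
    return out
```

Returning callables keeps the hooks cheap to query and lets the pipeline decide when, and on which key, each transform runs.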


The DataModule is responsible for creating the DataLoader and injecting the transforms for each stage. When the user calls a from_* method (such as from_numpy()), the following steps take place:

  1. The from_*() method is called with the name of the Input to use and the inputs to provide to load_data() for each stage.

  2. The InputTransform is created from cls.input_transform_cls (if it wasn’t provided by the user) with any provided transforms.

  3. The Input of the provided name is retrieved from the InputTransform.

  4. A BaseAutoDataset is created from the Input for each stage.

  5. The DataModule is instantiated with the datasets.
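The flow above can be condensed into a toy version. Everything here (`ToyInput`, `ToyDataModule`, `from_lists`) is hypothetical and far simpler than Flash; it only shows the shape of a from_* method: one Input per stage, each constructed with its stage and its share of the data, and a DataModule built from the resulting datasets.

```python
from typing import Any, List, Optional


class ToyInput:
    def __init__(self, stage: str, data: Optional[List[Any]]) -> None:
        self.stage = stage
        # In this toy, a stage with no data simply yields an empty dataset.
        self.samples = [] if data is None else [{"input": x} for x in data]

    def __len__(self) -> int:
        return len(self.samples)


class ToyDataModule:
    def __init__(self, train: ToyInput, val: ToyInput, test: ToyInput, predict: ToyInput) -> None:
        self.train_dataset, self.val_dataset = train, val
        self.test_dataset, self.predict_dataset = test, predict

    @classmethod
    def from_lists(
        cls,
        train: Optional[List[Any]] = None,
        val: Optional[List[Any]] = None,
        test: Optional[List[Any]] = None,
        predict: Optional[List[Any]] = None,
    ) -> "ToyDataModule":
        # One Input per stage, then the DataModule is built from them.
        return cls(
            ToyInput("train", train),
            ToyInput("val", val),
            ToyInput("test", test),
            ToyInput("predict", predict),
        )
```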

To create our TemplateData DataModule, we first need to attach our input transform class like this:

input_transform_cls = TemplateInputTransform

Since we provided a NUMPY Input in the TemplateInputTransform, from_numpy() will now work with our TemplateData.

If you’ve defined a fully custom Input (like our TemplateSKLearnClassificationInput), then you will need to write a from_* method for each. Here’s the from_sklearn method for our TemplateData:

@classmethod
def from_sklearn(
    cls,
    train_bunch: Optional[Bunch] = None,
    val_bunch: Optional[Bunch] = None,
    test_bunch: Optional[Bunch] = None,
    predict_bunch: Optional[Bunch] = None,
    train_transform: INPUT_TRANSFORM_TYPE = TemplateInputTransform,
    val_transform: INPUT_TRANSFORM_TYPE = TemplateInputTransform,
    test_transform: INPUT_TRANSFORM_TYPE = TemplateInputTransform,
    predict_transform: INPUT_TRANSFORM_TYPE = TemplateInputTransform,
    input_cls: Type[Input] = TemplateSKLearnClassificationInput,
    transform_kwargs: Optional[Dict] = None,
    **data_module_kwargs: Any,
) -> "TemplateData":
    """This is our custom ``from_*`` method. It expects scikit-learn ``Bunch`` objects as input and creates the
    ``TemplateData`` with them.

    Args:
        train_bunch: The scikit-learn ``Bunch`` containing the train data.
        val_bunch: The scikit-learn ``Bunch`` containing the validation data.
        test_bunch: The scikit-learn ``Bunch`` containing the test data.
        predict_bunch: The scikit-learn ``Bunch`` containing the predict data.
        train_transform: The dictionary of transforms to use during training which maps
            ``InputTransform`` hook names to callable transforms.
        val_transform: The dictionary of transforms to use during validation which maps
            ``InputTransform`` hook names to callable transforms.
        test_transform: The dictionary of transforms to use during testing which maps
            ``InputTransform`` hook names to callable transforms.
        predict_transform: The dictionary of transforms to use during predicting which maps
            ``InputTransform`` hook names to callable transforms.
        input_cls: The ``Input`` class to instantiate for each stage.
        transform_kwargs: Keyword arguments passed on to the transforms.
        data_module_kwargs: Additional keyword arguments passed to the ``DataModule`` constructor.

    Returns:
        The constructed data module.
    """
    # Keyword arguments shared by the Input of every stage.
    ds_kw = dict(
        transform_kwargs=transform_kwargs,
    )

    train_input = input_cls(RunningStage.TRAINING, train_bunch, transform=train_transform, **ds_kw)
    # Reuse the training target formatter so val / test targets are encoded the same way.
    target_formatter = getattr(train_input, "target_formatter", None)

    return cls(
        train_input,
        input_cls(
            RunningStage.VALIDATING,
            val_bunch,
            transform=val_transform,
            target_formatter=target_formatter,
            **ds_kw,
        ),
        input_cls(
            RunningStage.TESTING,
            test_bunch,
            transform=test_transform,
            target_formatter=target_formatter,
            **ds_kw,
        ),
        input_cls(RunningStage.PREDICTING, predict_bunch, transform=predict_transform, **ds_kw),
        **data_module_kwargs,
    )
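Reading `target_formatter` from the training Input and passing it to the validation and test Inputs keeps the class indices consistent across stages. This hypothetical sketch (`LabelFormatter`, `build_stage`) shows what could go wrong otherwise: a formatter inferred independently from a validation split that happens to lack a class assigns different indices to the same labels.

```python
from typing import Dict, List, Optional, Tuple


class LabelFormatter:
    """Hypothetical formatter: maps each label to an index, in sorted label order."""

    def __init__(self, labels: List[str]) -> None:
        self.label_to_index: Dict[str, int] = {lab: i for i, lab in enumerate(sorted(set(labels)))}
        self.num_classes = len(self.label_to_index)

    def __call__(self, label: str) -> int:
        return self.label_to_index[label]


def build_stage(targets: List[str], formatter: Optional[LabelFormatter] = None) -> Tuple[List[int], LabelFormatter]:
    # Infer a formatter only when none was handed down from the training stage.
    if formatter is None:
        formatter = LabelFormatter(targets)
    return [formatter(t) for t in targets], formatter
```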

The final step is to implement the num_features property for our TemplateData. This is a convenience for the user: it looks for the num_features attribute on each of the datasets in turn and returns the first one it finds. Here’s the code:

@property
def num_features(self) -> Optional[int]:
    """Tries to get the ``num_features`` from each dataset in turn and returns the output."""
    n_fts_train = getattr(self.train_dataset, "num_features", None)
    n_fts_val = getattr(self.val_dataset, "num_features", None)
    n_fts_test = getattr(self.test_dataset, "num_features", None)
    return n_fts_train or n_fts_val or n_fts_test
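A detail of the property above is that `n_fts_train or n_fts_val or n_fts_test` skips any falsy value, not only `None`. If you prefer an explicit `None` check, a variant could look like this sketch; `first_attribute` and `ToyData` are hypothetical, and `SimpleNamespace` objects stand in for the datasets. Note that `getattr(None, name, None)` also covers stages that have no dataset at all.

```python
from types import SimpleNamespace
from typing import Any, Optional


def first_attribute(name: str, *objects: Any) -> Optional[Any]:
    """Hypothetical helper: return the first non-None ``name`` attribute among ``objects``."""
    for obj in objects:
        value = getattr(obj, name, None)  # also safe when obj itself is None
        if value is not None:
            return value
    return None


class ToyData:
    def __init__(self, train: Any = None, val: Any = None, test: Any = None) -> None:
        self.train_dataset, self.val_dataset, self.test_dataset = train, val, test

    @property
    def num_features(self) -> Optional[int]:
        return first_attribute("num_features", self.train_dataset, self.val_dataset, self.test_dataset)
```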


An optional step is to implement a BaseVisualization. The BaseVisualization lets you control how data at various points in the pipeline can be visualized. This is extremely useful for debugging purposes, allowing users to view their data and understand the impact of their transforms.


Don’t worry about implementing it right away; you can always come back and add it later!

Here’s the code for our TemplateVisualization which just prints the data:

class TemplateVisualization(BaseVisualization):
    """The ``TemplateVisualization`` class is a ``BaseVisualization`` that just
    prints the data.

    If you want to provide a visualization with your task, you can override these hooks.
    """

    def show_load_sample(self, samples: List[Any], running_stage: RunningStage):
        print(samples)

    def show_per_sample_transform(self, samples: List[Any], running_stage: RunningStage):
        print(samples)

We can configure our custom visualization in the TemplateData using configure_data_fetcher() like this:

@staticmethod
def configure_data_fetcher(*args, **kwargs) -> BaseDataFetcher:
    """We can, *optionally*, provide a data visualization callback using the ``configure_data_fetcher``
    method."""
    return TemplateVisualization(*args, **kwargs)
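To picture how `show_*` hooks like the ones in TemplateVisualization can be driven, here is a hypothetical miniature of the pattern: a fetcher records the samples seen at each named point of the pipeline for each stage, and `show()` dispatches to a `show_<hook>` method when one is defined. The class and method names are illustrative assumptions, not Flash’s BaseDataFetcher API.

```python
from collections import defaultdict
from typing import Any, DefaultDict, List, Tuple


class SketchDataFetcher:
    def __init__(self) -> None:
        # (hook name, stage) -> samples observed at that point of the pipeline
        self.batches: DefaultDict[Tuple[str, str], List[Any]] = defaultdict(list)
        self.shown: List[str] = []

    def record(self, hook: str, stage: str, sample: Any) -> None:
        self.batches[(hook, stage)].append(sample)

    def show(self, stage: str) -> None:
        for (hook, recorded_stage), samples in self.batches.items():
            method = getattr(self, f"show_{hook}", None)
            # Hooks without a matching show_* method are silently skipped.
            if recorded_stage == stage and method is not None:
                method(samples, stage)


class PrintingFetcher(SketchDataFetcher):
    def show_load_sample(self, samples: List[Any], running_stage: str) -> None:
        self.shown.append(f"{running_stage}: {samples}")
        print(samples)
```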


OutputTransform contains any transforms that need to be applied after the model. You may want to use it for: converting tokens back into text, applying an inverse normalization to an output image, resizing a generated image back to the size of the input, etc. As an example, here’s the SemanticSegmentationOutputTransform, which resizes the predictions (and the inputs) back to the original image size stored in the sample metadata:

class SemanticSegmentationOutputTransform(OutputTransform):
    def per_sample_transform(self, sample: Any) -> Any:
        resize = K.geometry.Resize(sample[DataKeys.METADATA]["size"], interpolation="bilinear")
        sample[DataKeys.PREDS] = resize(sample[DataKeys.PREDS])
        sample[DataKeys.INPUT] = resize(sample[DataKeys.INPUT])
        return super().per_sample_transform(sample)

In your Input or InputTransform, you can add metadata to each sample using the METADATA key. Your OutputTransform can then use this metadata in its transforms. You should use this approach if your postprocessing depends on the state of the input before the InputTransform is applied. For example, if you want to resize the predictions to the original size of the inputs, you should add the original image size to the METADATA. Here’s an example from the SemanticSegmentationNumpyInput:

def load_sample(self, sample: Dict[str, Any]) -> Dict[str, Any]:
    sample[DataKeys.INPUT] = torch.from_numpy(sample[DataKeys.INPUT])
    if DataKeys.TARGET in sample:
        sample[DataKeys.TARGET] = torch.from_numpy(sample[DataKeys.TARGET])
    return super().load_sample(sample)
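As a framework-free sketch of the metadata step itself, an Input’s load_sample could record the original height and width before any resizing transform runs. This assumes a plain `"metadata"` key in place of `DataKeys.METADATA` and nested lists in place of image tensors; `load_sample_with_size` is a hypothetical name.

```python
from typing import Any, Dict, List


def load_sample_with_size(sample: Dict[str, Any]) -> Dict[str, Any]:
    """Hypothetical load_sample: remember the original (height, width) of a 2-D input."""
    image: List[List[float]] = sample["input"]
    # Create the metadata dict if needed, then store the size under "size".
    sample.setdefault("metadata", {})["size"] = (len(image), len(image[0]))
    return sample
```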

The METADATA can now be referenced in your OutputTransform. For example, here’s the code for the per_sample_transform method of the SemanticSegmentationOutputTransform:

def per_sample_transform(self, sample: Any) -> Any:
    resize = K.geometry.Resize(sample[DataKeys.METADATA]["size"], interpolation="bilinear")
    sample[DataKeys.PREDS] = resize(sample[DataKeys.PREDS])
    sample[DataKeys.INPUT] = resize(sample[DataKeys.INPUT])
    return super().per_sample_transform(sample)
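To show the output side without kornia, the sketch below reads the size from the metadata and resizes a 2-D prediction back to it. It swaps in nearest-neighbour interpolation for the bilinear `K.geometry.Resize` used above, only touches the predictions, and uses plain `"preds"`/`"metadata"` keys in place of `DataKeys.PREDS` and `DataKeys.METADATA`; all names are hypothetical.

```python
from typing import Any, Dict, List, Tuple


def resize_nearest(grid: List[List[Any]], size: Tuple[int, int]) -> List[List[Any]]:
    """Nearest-neighbour resize of a 2-D list to ``size`` = (height, width)."""
    in_h, in_w = len(grid), len(grid[0])
    out_h, out_w = size
    # Map each output cell back to the input cell it falls into.
    return [
        [grid[r * in_h // out_h][c * in_w // out_w] for c in range(out_w)]
        for r in range(out_h)
    ]


def per_sample_transform(sample: Dict[str, Any]) -> Dict[str, Any]:
    # Read the size recorded at load time and bring the predictions back to it.
    size = sample["metadata"]["size"]
    sample["preds"] = resize_nearest(sample["preds"], size)
    return sample
```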

Now that you’ve got some data, it’s time to add some backbones for your task!
