
The Task

Once you’ve implemented a Flash DataModule and some backbones, you should implement your Task in model.py. The Task is responsible for setting up the backbone, performing the forward pass of the model, and calculating the loss and any metrics. Remember that, under the hood, the Flash Task is simply a LightningModule with some helpful defaults.

To build your task, you can start by subclassing the base Task or any of the existing Task implementations. For example, in our scikit-learn example, we can just subclass ClassificationTask, which provides good defaults for classification.

You should attach your backbones registry as a class attribute like this:

class TemplateSKLearnClassifier(ClassificationTask):
    backbones: FlashRegistry = TEMPLATE_BACKBONES
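Attaching the registry as a class attribute means every instance (and every further subclass) finds the same backbone builders through ordinary attribute lookup. The sketch below illustrates that pattern without depending on Flash; MiniRegistry, build_mlp_128, and BaseTask are hypothetical stand-ins, not FlashRegistry or Flash's Task, and the "backbone" is just a placeholder string paired with its output width.

```python
from typing import Callable, Dict, List


class MiniRegistry:
    """Hypothetical stand-in for FlashRegistry: maps a name to a backbone builder."""

    def __init__(self, name: str) -> None:
        self.name = name
        self._builders: Dict[str, Callable] = {}

    def register(self, key: str) -> Callable[[Callable], Callable]:
        # Decorator that stores the builder under `key` and returns it unchanged.
        def decorator(fn: Callable) -> Callable:
            self._builders[key] = fn
            return fn
        return decorator

    def get(self, key: str) -> Callable:
        # Fail with a readable message if the backbone name is unknown.
        if key not in self._builders:
            raise KeyError(f"{key!r} not in registry {self.name!r}")
        return self._builders[key]

    def available_keys(self) -> List[str]:
        return sorted(self._builders)


TEMPLATE_BACKBONES = MiniRegistry("backbones")


@TEMPLATE_BACKBONES.register("mlp-128")
def build_mlp_128(num_features: int):
    # Placeholder backbone: a description plus its output width.
    return f"mlp({num_features}->128)", 128


class BaseTask:
    # Base class starts with an empty registry.
    backbones: MiniRegistry = MiniRegistry("empty")


class TemplateClassifier(BaseTask):
    # Overriding the class attribute swaps in this task's backbones.
    backbones: MiniRegistry = TEMPLATE_BACKBONES
```

Here TemplateClassifier.backbones.get("mlp-128")(num_features=4) returns ("mlp(4->128)", 128), while BaseTask keeps its own, empty registry.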

Model architecture and hyper-parameters

In __init__(), you will need to configure defaults for the following:

  • loss function

  • optimizer

  • metrics

  • backbone / model

You will also need to create the backbone from the registry and create the model head. Here’s the code:

def __init__(
    self,
    num_features: int,
    num_classes: Optional[int] = None,
    labels: Optional[List[str]] = None,
    backbone: Union[str, Tuple[nn.Module, int]] = "mlp-128",
    backbone_kwargs: Optional[Dict] = None,
    loss_fn: LOSS_FN_TYPE = None,
    optimizer: OPTIMIZER_TYPE = "Adam",
    lr_scheduler: LR_SCHEDULER_TYPE = None,
    metrics: METRICS_TYPE = None,
    learning_rate: Optional[float] = None,
    multi_label: bool = False,
):
    self.save_hyperparameters()

    if labels is not None and num_classes is None:
        num_classes = len(labels)

    super().__init__(
        model=None,
        loss_fn=loss_fn,
        optimizer=optimizer,
        lr_scheduler=lr_scheduler,
        metrics=metrics,
        learning_rate=learning_rate,
        multi_label=multi_label,
        num_classes=num_classes,
        labels=labels,
    )

    if not backbone_kwargs:
        backbone_kwargs = {}

    if isinstance(backbone, tuple):
        self.backbone, out_features = backbone
    else:
        self.backbone, out_features = self.backbones.get(backbone)(num_features=num_features, **backbone_kwargs)

    self.head = nn.Linear(out_features, num_classes)
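The argument handling in this __init__ boils down to two small decisions: infer num_classes from labels when it is not given, and accept the backbone either as a registry key or as a ready-made (module, out_features) pair. The dependency-free sketch below isolates that logic so it can be checked without PyTorch or Flash. The function names, the plain dict used as a registry, and the explicit None check in head_shape (which the snippet above does not perform; there, nn.Linear would receive None) are illustrative assumptions.

```python
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

# A builder returns (backbone, out_features), like self.backbones.get(...) in the snippet.
Builder = Callable[..., Tuple[Any, int]]


def resolve_num_classes(num_classes: Optional[int], labels: Optional[List[str]]) -> Optional[int]:
    # Infer num_classes from labels only when it was not passed explicitly.
    if labels is not None and num_classes is None:
        return len(labels)
    return num_classes


def resolve_backbone(
    backbone: Union[str, Tuple[Any, int]],
    registry: Dict[str, Builder],
    num_features: int,
    backbone_kwargs: Optional[Dict[str, Any]] = None,
) -> Tuple[Any, int]:
    # A (module, out_features) tuple is used as-is; a string is looked up and built.
    if isinstance(backbone, tuple):
        return backbone
    return registry[backbone](num_features=num_features, **(backbone_kwargs or {}))


def head_shape(out_features: int, num_classes: Optional[int]) -> Tuple[int, int]:
    # (in, out) sizes of the linear head; fail early if the class count is unknown.
    if num_classes is None:
        raise ValueError("num_classes must be passed or inferred from labels")
    return out_features, num_classes
```

Passing labels=["cat", "dog", "bird"] without num_classes yields a head of shape (out_features, 3), mirroring what the Task does before building nn.Linear.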

Note

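We call save_hyperparameters() to save the arguments passed to __init__ as hyperparameters (they then become available as self.hparams). Read more about save_hyperparameters() in the PyTorch Lightning documentation.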
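To see why save_hyperparameters() can be called with no arguments, it helps to know that it reads the calling __init__'s arguments from the call frame. The sketch below imitates that idea with Python's inspect module; it is a hypothetical illustration (HParamsMixin and TinyTask are invented names), not Lightning's implementation. In this sketch the call records the locals at the moment it runs, which is one reason to call it first, before values such as num_classes are reassigned.

```python
import inspect
from typing import Any, Dict


class HParamsMixin:
    """Hypothetical sketch of the idea behind save_hyperparameters()."""

    def save_hyperparameters(self) -> None:
        # One frame up is the __init__ that called this method.
        frame = inspect.currentframe().f_back
        try:
            info = inspect.getargvalues(frame)
            # Keep the named __init__ parameters, skipping `self`.
            self.hparams: Dict[str, Any] = {
                name: info.locals[name] for name in info.args if name != "self"
            }
        finally:
            del frame  # do not keep the frame alive through a reference cycle


class TinyTask(HParamsMixin):
    def __init__(self, num_features: int, learning_rate: float = 1e-3, backbone: str = "mlp-128"):
        self.save_hyperparameters()
        self.num_features = num_features
```

TinyTask(4, learning_rate=0.01).hparams records all three arguments, including the default backbone value.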

Adding the model routines

You should override the training_step, validation_step, test_step, and predict_step methods. The default implementations of these methods in Task expect a tuple containing the input (to be passed to the model) and the target (to be used when computing the loss), and should be suitable for most applications. In our template example, each batch is a mapping, so we just extract the input and target from it and forward them to the super methods. Here’s the code for training_step:

def training_step(self, batch: Any, batch_idx: int) -> Any:
    """For the training step, we just extract the :attr:`~flash.core.data.io.input.DataKeys.INPUT` and
    :attr:`~flash.core.data.io.input.DataKeys.TARGET` keys from the input and forward them to the
    :meth:`~flash.core.model.Task.training_step`."""
    batch = (batch[DataKeys.INPUT], batch[DataKeys.TARGET])
    return super().training_step(batch, batch_idx)
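The override is a small adapter: it converts the mapping-style batch into the (input, target) tuple that the parent step expects, then delegates. The dependency-free sketch below shows the pattern for the training, validation, and test steps sharing one conversion helper. DataKeys here is a hypothetical stand-in Enum for Flash's flash.core.data.io.input.DataKeys, and BaseTask, with its toy model (2 * x) and squared-error loss, is an assumption that only mimics Task's tuple-based steps.

```python
from enum import Enum
from typing import Any, Dict, Tuple


class DataKeys(Enum):
    # Hypothetical stand-in for flash.core.data.io.input.DataKeys.
    INPUT = "input"
    TARGET = "target"


class BaseTask:
    """Stand-in for Task: its steps expect an (input, target) tuple."""

    def forward(self, x: float) -> float:
        return 2.0 * x  # toy model

    def _step(self, batch: Tuple[Any, Any], batch_idx: int) -> float:
        x, y = batch
        return (self.forward(x) - y) ** 2  # squared-error loss

    def training_step(self, batch, batch_idx):
        return self._step(batch, batch_idx)

    def validation_step(self, batch, batch_idx):
        return self._step(batch, batch_idx)

    def test_step(self, batch, batch_idx):
        return self._step(batch, batch_idx)


class TemplateTask(BaseTask):
    @staticmethod
    def _to_tuple(batch: Dict[DataKeys, Any]) -> Tuple[Any, Any]:
        # Pull out the two entries the parent steps need.
        return batch[DataKeys.INPUT], batch[DataKeys.TARGET]

    def training_step(self, batch, batch_idx):
        return super().training_step(self._to_tuple(batch), batch_idx)

    def validation_step(self, batch, batch_idx):
        return super().validation_step(self._to_tuple(batch), batch_idx)

    def test_step(self, batch, batch_idx):
        return super().test_step(self._to_tuple(batch), batch_idx)
```

Factoring the conversion into one helper keeps the three overrides identical apart from the parent method they call.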

We use the same code for validation_step and test_step. For predict_step, we don’t need the targets, so our code looks like this:

def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any:
    """For the predict step, we just extract the :attr:`~flash.core.data.io.input.DataKeys.INPUT` key from the input
    and forward it to the :meth:`~flash.core.model.Task.predict_step`."""
    batch = batch[DataKeys.INPUT]
    return super().predict_step(batch, batch_idx, dataloader_idx=dataloader_idx)
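Because prediction data usually has no labels, the predict adapter reads only the input entry and passes dataloader_idx through unchanged. The sketch below mirrors that with the same kind of hypothetical stand-ins as before (a DataKeys Enum and a BaseTask whose toy model doubles each value); it is not Flash's Task.predict_step.

```python
from enum import Enum
from typing import Any, Dict, List


class DataKeys(Enum):
    # Hypothetical stand-in for flash.core.data.io.input.DataKeys.
    INPUT = "input"
    TARGET = "target"


class BaseTask:
    def forward(self, x: List[float]) -> List[float]:
        return [2.0 * v for v in x]  # toy model

    def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any:
        # Stand-in parent: run the model on whatever input it receives.
        return self.forward(batch)


class TemplateTask(BaseTask):
    def predict_step(self, batch: Dict[DataKeys, Any], batch_idx: int, dataloader_idx: int = 0) -> Any:
        # Prediction batches may lack a target, so only INPUT is read.
        return super().predict_step(batch[DataKeys.INPUT], batch_idx, dataloader_idx=dataloader_idx)
```

A batch containing only {DataKeys.INPUT: [1.0, 2.5]} is accepted, since no target lookup happens.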

Note

You can also replace any of these step methods completely (that is, without calling super()) if your Task needs more custom behaviour at a particular stage.
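As an illustration of a step that does not call super(), the sketch below computes a per-sample weighted mean squared error using a "weight" entry that the parent step knows nothing about. Everything here is hypothetical: plain string keys stand in for DataKeys, the "weight" key is invented for the example, and BaseTask only imitates a tuple-based training step with an unweighted loss.

```python
from typing import Dict, List


class BaseTask:
    def forward(self, x: List[float]) -> List[float]:
        return [2.0 * v for v in x]  # toy model

    def training_step(self, batch, batch_idx):
        # Parent behaviour: plain mean squared error on an (input, target) tuple.
        x, y = batch
        preds = self.forward(x)
        return sum((p - t) ** 2 for p, t in zip(preds, y)) / len(y)


class WeightedTask(BaseTask):
    def training_step(self, batch: Dict[str, List[float]], batch_idx: int) -> float:
        # No super() call: this stage uses its own weighted loss.
        preds = self.forward(batch["input"])
        weights = batch["weight"]  # hypothetical extra key
        errors = [w * (p - t) ** 2 for p, t, w in zip(preds, batch["target"], weights)]
        return sum(errors) / sum(weights)
```

Replacing the step outright is simpler than bending the parent's tuple-based signature when a stage needs data the parent never sees.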

Finally, we use our backbone and head in a custom forward pass:

def forward(self, x) -> Tensor:
    """First call the backbone, then the model head."""
    x = self.backbone(x)
    return self.head(x)
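The forward pass is plain function composition: the backbone maps raw features to a hidden representation, and the head maps that representation to one logit per class. The dependency-free sketch below reproduces that structure with a hand-written linear layer and ReLU on Python lists; linear, relu, and TinyClassifier are hypothetical helpers standing in for nn.Linear and the Task, chosen only so the example runs without PyTorch.

```python
from typing import Callable, List

Vector = List[float]


def linear(weights: List[Vector], bias: Vector) -> Callable[[Vector], Vector]:
    # Returns a function computing W @ x + b for a single sample.
    def apply(x: Vector) -> Vector:
        return [sum(w * v for w, v in zip(row, x)) + b for row, b in zip(weights, bias)]
    return apply


def relu(x: Vector) -> Vector:
    return [max(0.0, v) for v in x]


class TinyClassifier:
    def __init__(self, backbone: Callable[[Vector], Vector], head: Callable[[Vector], Vector]):
        self.backbone = backbone
        self.head = head

    def forward(self, x: Vector) -> Vector:
        # Same order as the Task: backbone features first, then the head's logits.
        x = self.backbone(x)
        return self.head(x)
```

With a 2-to-2 backbone followed by a 2-to-3 head, forward returns three logits, one per class, just as the head built in __init__ has num_classes outputs.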

Now that you’ve got your task, take a look at some optional advanced features you can add, or go ahead and create some examples showing your task in action!
