
The Example

Now that you’ve implemented your task, it’s time to add an example showing how cool it is! We usually provide one example in examples/. You can base yours on our examples/template.py example.

The example should:

  1. download the data (we’ll add the example to our CI later on, so choose a dataset small enough for the example to run in a reasonable time)

  2. load the data into a DataModule

  3. create an instance of the Task

  4. create a Trainer

  5. call finetune() or fit() to train your model

  6. generate predictions for a few examples

  7. save the checkpoint
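Because the example will eventually run in CI, it helps to check locally that it finishes quickly and exits cleanly. The sketch below runs a script in a fresh interpreter with a time limit, roughly as a CI job might. The helper name `run_example` and the default 600-second budget are hypothetical choices for illustration, not part of Flash or its CI configuration.

```python
import subprocess
import sys


def run_example(script_path, timeout_s=600.0):
    """Run an example script in a fresh Python interpreter (hypothetical helper).

    Returns the script's stdout. Raises subprocess.TimeoutExpired if the script
    runs longer than timeout_s seconds, and subprocess.CalledProcessError if it
    exits with a non-zero status.
    """
    result = subprocess.run(
        [sys.executable, script_path],
        capture_output=True,  # collect stdout/stderr instead of printing them
        text=True,            # decode output as str
        timeout=timeout_s,    # kill the script if it exceeds the time budget
        check=True,           # raise on a non-zero exit code
    )
    return result.stdout
```

For example, `run_example("examples/template.py", timeout_s=600)` would raise if the template example took longer than ten minutes or crashed.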

For our template example we don’t have a pretrained backbone, so we can just call fit() rather than finetune(). Here’s the full example (examples/template.py):

import flash
import numpy as np
import torch
from flash.template import TemplateData, TemplateSKLearnClassifier
from sklearn import datasets

# 1. Create the DataModule
datamodule = TemplateData.from_sklearn(
    train_bunch=datasets.load_iris(),
    val_split=0.1,
    batch_size=4,
)

# 2. Build the task
model = TemplateSKLearnClassifier(
    num_features=datamodule.num_features,
    num_classes=datamodule.num_classes,
)

# 3. Create the trainer and train the model
trainer = flash.Trainer(max_epochs=3, gpus=torch.cuda.device_count())
trainer.fit(model, datamodule=datamodule)

# 4. Classify a few examples
datamodule = TemplateData.from_numpy(
    predict_data=[
        np.array([4.9, 3.0, 1.4, 0.2]),
        np.array([6.9, 3.2, 5.7, 2.3]),
        np.array([7.2, 3.0, 5.8, 1.6]),
    ],
    batch_size=4,
)
predictions = trainer.predict(model, datamodule=datamodule, output="classes")
print(predictions)

# 5. Save the model!
trainer.save_checkpoint("template_model.pt")

We get this output:

['setosa', 'virginica', 'versicolor']
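The printed values are iris species names. If your own task’s output gives integer class indices instead, a lookup into the dataset’s label list turns them into readable names like these. A minimal sketch, assuming the label order of scikit-learn’s `load_iris().target_names`; `IRIS_LABELS` and `indices_to_labels` are hypothetical names, not part of Flash.

```python
# Assumed label order, matching scikit-learn's load_iris().target_names.
IRIS_LABELS = ["setosa", "versicolor", "virginica"]


def indices_to_labels(indices, labels):
    """Map predicted class indices to label names (hypothetical helper)."""
    return [labels[i] for i in indices]


print(indices_to_labels([0, 2, 1], IRIS_LABELS))  # prints ['setosa', 'virginica', 'versicolor']
```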

Now that you’ve got an example showing your awesome task in action, it’s time to write some tests!