
The Example

Now that you’ve implemented your task, it’s time to add an example showing how cool it is! We usually provide one example in flash_examples/. You can base yours on our template.py example.

The example should:

  1. download the data (we’ll add the example to our CI later on, so choose a dataset small enough that it runs in reasonable time)

  2. load the data into a DataModule

  3. create an instance of the Task

  4. create a Trainer

  5. call finetune() or fit() to train your model

  6. generate predictions for a few examples

  7. save the checkpoint

For our template example we don’t have a pretrained backbone, so we can just call fit() rather than finetune(). Here’s the full example (flash_examples/template.py):

import numpy as np
import torch
from sklearn import datasets

import flash
from flash.template import TemplateData, TemplateSKLearnClassifier

# 1. Create the DataModule
datamodule = TemplateData.from_sklearn(
    train_bunch=datasets.load_iris(),
    val_split=0.1,
)

# 2. Build the task
model = TemplateSKLearnClassifier(num_features=datamodule.num_features, num_classes=datamodule.num_classes)

# 3. Create the trainer and train the model
trainer = flash.Trainer(max_epochs=3, gpus=torch.cuda.device_count())
trainer.fit(model, datamodule=datamodule)

# 4. Classify a few examples
predictions = model.predict(
    [
        np.array([4.9, 3.0, 1.4, 0.2]),
        np.array([6.9, 3.2, 5.7, 2.3]),
        np.array([7.2, 3.0, 5.8, 1.6]),
    ]
)
print(predictions)

# 5. Save the model!
trainer.save_checkpoint("template_model.pt")

We get this output:

['setosa', 'virginica', 'versicolor']
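The template calls fit() because it has no pretrained backbone. If your task does have one, the training step would call finetune() instead. The sketch below shows that choice as a small helper. It assumes `trainer` behaves like a flash.Trainer. `train_task` and `RecordingTrainer` are hypothetical names used only for illustration, and `RecordingTrainer` is a stub that records calls so the logic can run without Flash installed. The `strategy="freeze"` argument mirrors Flash's own finetuning examples and keeps the backbone frozen while the head trains.

```python
from typing import Any


def train_task(trainer: Any, model: Any, datamodule: Any, has_pretrained_backbone: bool) -> None:
    """Hypothetical helper: choose finetune() or fit() for an example script.

    `trainer` is assumed to be a flash.Trainer and `model` a Flash task.
    """
    if has_pretrained_backbone:
        # Pretrained backbone: finetune, keeping the backbone frozen ("freeze" strategy).
        trainer.finetune(model, datamodule=datamodule, strategy="freeze")
    else:
        # No pretrained weights (like the template task): plain training with fit().
        trainer.fit(model, datamodule=datamodule)


class RecordingTrainer:
    """Hypothetical stand-in for flash.Trainer that only records which method was called."""

    def __init__(self) -> None:
        self.calls = []

    def fit(self, model: Any, datamodule: Any = None) -> None:
        self.calls.append(("fit", datamodule))

    def finetune(self, model: Any, datamodule: Any = None, strategy: Any = None) -> None:
        self.calls.append(("finetune", strategy))


trainer = RecordingTrainer()
train_task(trainer, model=None, datamodule="dm", has_pretrained_backbone=False)
train_task(trainer, model=None, datamodule="dm", has_pretrained_backbone=True)
print(trainer.calls)  # [('fit', 'dm'), ('finetune', 'freeze')]
```

In a real example script you would pass the flash.Trainer, task, and DataModule from the template above. Only the flag changes depending on whether the task ships with a pretrained backbone.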

Now that you’ve got an example showing your awesome task in action, it’s time to write some tests!
