
The Backbones

Now that you’ve got a way of loading data, you should implement some backbones to use with your Task. Create a FlashRegistry in backbones.py to hold them.

The registry allows you to register backbones for your task that can be selected by the user. The backbones can come from anywhere as long as you can register a function that loads the backbone. Furthermore, the user can add their own models to the existing backbones, without having to write their own Task!
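To make the idea concrete, here is a minimal, pure-Python stand-in for a registry. It is not Flash's FlashRegistry implementation: the class name `ToyRegistry` and its `get` and `available_keys` methods are hypothetical, chosen only to illustrate how a decorator can record loader functions under a name and namespace so that a user can pick one by name, or add their own, without touching the Task.

```python
from typing import Any, Callable, Dict, List, Optional, Tuple


class ToyRegistry:
    """Hypothetical, simplified registry (not flash's FlashRegistry)."""

    def __init__(self, name: str) -> None:
        self.name = name
        # Maps a backbone name to (loader function, metadata such as namespace).
        self._entries: Dict[str, Tuple[Callable[..., Any], Dict[str, Any]]] = {}

    def __call__(self, name: Optional[str] = None, **metadata: Any) -> Callable:
        """Return a decorator that records the decorated loader function."""

        def decorator(fn: Callable[..., Any]) -> Callable[..., Any]:
            key = name if name is not None else fn.__name__
            if key in self._entries:
                raise KeyError(f"{key!r} is already registered in {self.name!r}")
            self._entries[key] = (fn, metadata)
            return fn  # the function itself is returned unchanged

        return decorator

    def get(self, key: str) -> Callable[..., Any]:
        """Look up a loader function by its registered name."""
        return self._entries[key][0]

    def available_keys(self) -> List[str]:
        return sorted(self._entries)


BACKBONES = ToyRegistry("backbones")


@BACKBONES(name="identity", namespace="template/classification")
def load_identity(num_features, **_):
    # Stand-in "backbone": returns its input unchanged, so output size == input size.
    return (lambda x: x), num_features


# A user can register their own model the same way, without writing a new Task.
@BACKBONES(name="my-model", namespace="template/classification")
def load_my_model(num_features, **_):
    # Stand-in "backbone" that keeps only the first two features.
    return (lambda x: x[:2]), 2


backbone, out_features = BACKBONES.get("identity")(num_features=4)
print(BACKBONES.available_keys(), out_features)
```

The Task only needs the registered name to find the loader, which is what lets user-supplied models slot in next to the built-in ones.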

You can create a registry like this:

from flash.core.registry import FlashRegistry
from torch import nn

TEMPLATE_BACKBONES = FlashRegistry("backbones")

Let’s add a simple MLP backbone to our registry. We need a function that creates the backbone and returns it along with its output size (so that we can create the model head in our Task). You can use any name for the function, although by convention we use load_{model name}. You also need to provide the name and namespace of the backbone. The standard for namespace is data_type/task_type, so for an image classification task the namespace will be image/classification. Here’s the code:

@TEMPLATE_BACKBONES(name="mlp-128", namespace="template/classification")
def load_mlp_128(num_features, **_):
    """A simple MLP backbone with 128 hidden units."""
    return (
        nn.Sequential(
            nn.Linear(num_features, 128),
            nn.ReLU(True),
            nn.BatchNorm1d(128),
        ),
        128,
    )
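The returned output size is what lets the Task size its head without knowing anything about the backbone's internals. Below is a minimal sketch, assuming PyTorch and a plain linear classification head; the `num_classes` value, the batch size, and the use of `nn.Linear` as the head are illustrative assumptions, not necessarily what the template Task does. The loader is repeated without the registry decorator so the block runs on its own.

```python
import torch
from torch import nn


def load_mlp_128(num_features, **_):
    """A simple MLP backbone with 128 hidden units (same as above, unregistered)."""
    return (
        nn.Sequential(
            nn.Linear(num_features, 128),
            nn.ReLU(True),
            nn.BatchNorm1d(128),
        ),
        128,
    )


num_features, num_classes = 16, 3  # illustrative sizes

backbone, out_features = load_mlp_128(num_features=num_features)
# The head only needs the backbone's declared output size.
head = nn.Linear(out_features, num_classes)
model = nn.Sequential(backbone, head)

# BatchNorm1d needs more than one sample per batch in training mode.
x = torch.randn(8, num_features)
logits = model(x)
print(logits.shape)  # torch.Size([8, 3])
```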

Here’s another example with a slightly more complex model:

@TEMPLATE_BACKBONES(name="mlp-128-256", namespace="template/classification")
def load_mlp_128_256(num_features, **_):
    """A two-layer MLP backbone with 128 and 256 hidden units respectively."""
    return (
        nn.Sequential(
            nn.Linear(num_features, 128),
            nn.ReLU(True),
            nn.BatchNorm1d(128),
            nn.Linear(128, 256),
            nn.ReLU(True),
            nn.BatchNorm1d(256),
        ),
        256,
    )
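Because the output size is declared by hand, it can drift from the actual model, for example after adding a layer. A small check can catch this: it runs a dummy batch through a loader and compares the last dimension of the result with the declared size. This is a sketch that assumes the backbone accepts input of shape `(batch, num_features)`; the helper name `check_output_size` is hypothetical.

```python
import torch
from torch import nn


def load_mlp_128_256(num_features, **_):
    """A two-layer MLP backbone with 128 and 256 hidden units respectively."""
    return (
        nn.Sequential(
            nn.Linear(num_features, 128),
            nn.ReLU(True),
            nn.BatchNorm1d(128),
            nn.Linear(128, 256),
            nn.ReLU(True),
            nn.BatchNorm1d(256),
        ),
        256,
    )


def check_output_size(loader, num_features: int, batch_size: int = 4) -> int:
    """Hypothetical helper: verify a loader's declared output size with a dummy batch."""
    backbone, declared = loader(num_features=num_features)
    backbone.eval()  # use BatchNorm running statistics so any batch size works
    with torch.no_grad():
        out = backbone(torch.randn(batch_size, num_features))
    actual = out.shape[-1]
    if actual != declared:
        raise ValueError(f"declared output size {declared}, but backbone produced {actual}")
    return actual


print(check_output_size(load_mlp_128_256, num_features=10))  # 256
```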

Here’s another example, from flash/image/classification/backbones/transformers.py, which adds a pretrained DINO model from PyTorch Hub to IMAGE_CLASSIFIER_BACKBONES:

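import torch


def dino_vitb16(*_, **__):
    # Load the pretrained DINO ViT-B/16 model from PyTorch Hub (downloaded and cached on first use).
    backbone = torch.hub.load("facebookresearch/dino:main", "dino_vitb16")
    # 768 is the embedding size of ViT-B/16, returned as the backbone's output size.
    return backbone, 768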
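Note the catch-all signatures: `**_` in the MLP loaders and `*_, **__` in dino_vitb16. Assuming the Task calls whichever loader the user selects with the same set of keyword arguments (such as `num_features`), these catch-alls let a loader ignore arguments it does not need. The pure-Python sketch below uses hypothetical stand-in loaders that return strings instead of models, and a hypothetical `task_kwargs` dictionary, to show the effect.

```python
def load_needs_features(num_features, **_):
    # Uses num_features; any other keyword arguments are ignored.
    return f"mlp({num_features})", 128


def load_pretrained(*_, **__):
    # Ignores every argument, like dino_vitb16: the architecture is fixed.
    return "pretrained-vit", 768


def load_strict(num_features):
    # No catch-all: an unexpected keyword argument raises TypeError.
    return f"strict({num_features})", num_features


task_kwargs = {"num_features": 32, "pretrained": True}  # hypothetical kwargs from a Task

for loader in (load_needs_features, load_pretrained):
    print(loader(**task_kwargs))
# ('mlp(32)', 128)
# ('pretrained-vit', 768)

try:
    load_strict(**task_kwargs)
except TypeError as err:
    print("load_strict failed:", err)
```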

Once you’ve got some data and some backbones, implement your task!
