
Multi-label Image Classification

The Task

Multi-label classification is the task of assigning one or more labels from a fixed set to each data point, which can be in any modality (here, images). Multi-label image classification is supported by the ImageClassifier via the multi_label argument.
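To make the idea concrete, here is a minimal, framework-agnostic sketch of how multi-label targets are commonly represented and decoded. Each image gets a multi-hot vector, and each label is scored independently with a sigmoid rather than a softmax, so any number of labels can be predicted. This is a general illustration, not Flash's internal implementation; the helper names encode and decode and the 0.5 threshold are assumptions.

```python
import math
from typing import List

# Illustrative label set (the genres used in the example further down)
LABELS = ["Action", "Romance", "Crime", "Thriller", "Adventure"]


def encode(genres: List[str], labels: List[str] = LABELS) -> List[int]:
    """Multi-hot encode: 1 for every label the image has, 0 otherwise."""
    present = set(genres)
    return [1 if label in present else 0 for label in labels]


def decode(logits: List[float], labels: List[str] = LABELS, threshold: float = 0.5) -> List[str]:
    """Apply an independent sigmoid per label and keep every label above the threshold.

    Unlike softmax, these probabilities do not compete with each other,
    so zero, one, or several labels can be returned.
    """
    probs = [1.0 / (1.0 + math.exp(-z)) for z in logits]
    return [label for label, p in zip(labels, probs) if p > threshold]
```

For example, encode(["Action", "Thriller"]) gives [1, 0, 0, 1, 0], and decode([2.0, -1.0, 0.3, 1.5, -3.0]) keeps "Action", "Crime", and "Thriller" because their sigmoid scores exceed 0.5.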


Let’s look at the task of predicting movie genres from an image of the movie poster. The data we will use is a subset of the awesome movie poster genre prediction data set from the paper “Movie Genre Classification based on Poster Images with Deep Neural Networks” by Wei-Ta Chu and Hung-Jui Guo, resized to 128 by 128 pixels. Take a look at their paper (and please consider citing it if you use the data) here:

The data set contains train and val folders, and each folder contains the images and a metadata.csv file which stores the labels. Here’s an overview:

├── train
│   ├── metadata.csv
│   ├── tt0084058.jpg
│   ├── tt0084867.jpg
│   ...
└── val
    ├── metadata.csv
    ├── tt0200465.jpg
    ├── tt0326965.jpg
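If you would rather build plain lists of image paths and labels yourself, a helper along these lines could read one split folder. This is a minimal sketch under stated assumptions: the load_data name is hypothetical, and it assumes metadata.csv has an ID column (named "Id" here) plus one 0/1 column per genre, which may not match the actual file layout.

```python
import csv
import os
from typing import List, Tuple


def load_data(csv_path: str, genres: List[str], id_field: str = "Id") -> Tuple[List[str], List[List[str]]]:
    """Return (image_paths, genre_lists) for one split folder such as train/ or val/.

    Assumed (hypothetical) layout: an ID column named `id_field` and one
    0/1 column per genre in metadata.csv.
    """
    root = os.path.dirname(csv_path)
    files, targets = [], []
    with open(csv_path, newline="") as f:
        for row in csv.DictReader(f):
            # Map the file ID to its image: <split folder>/<file ID>.jpg
            files.append(os.path.join(root, f"{row[id_field]}.jpg"))
            # Keep every genre whose column is set to 1
            targets.append([g for g in genres if row[g].strip() == "1"])
    return files, targets
```

The two returned lists are parallel (one entry per image), so they could be passed as train_files and train_targets to from_files(), assuming your Flash version accepts lists of label names as multi-label targets.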

Once we’ve downloaded the data using download_data(), we need to create the ImageClassificationData. Because each metadata.csv refers to images by file ID, we define a resolver function that maps an ID to its image path and pass it to from_csv() together with the genre columns. We then select a pre-trained backbone for our ImageClassifier and fine-tune it on the posters data. Next, we use the trained ImageClassifier for inference on a few posters loaded with from_files(). Finally, we save the model. Here’s the full example:

import os

import flash
import torch
from flash.core.data.utils import download_data
from flash.image import ImageClassificationData, ImageClassifier

# 1. Create the DataModule
# Data set from the paper "Movie Genre Classification based on Poster Images with Deep Neural Networks".
# More info here:

def resolver(root, file_id):
    return os.path.join(root, f"{file_id}.jpg")

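datamodule = ImageClassificationData.from_csv(
    "Id",  # assumed name of the metadata.csv column holding the file ID
    ["Action", "Romance", "Crime", "Thriller", "Adventure"],
    train_file="data/movie_posters/train/metadata.csv",  # assumed download location
    train_resolver=resolver,
    val_file="data/movie_posters/val/metadata.csv",
    val_resolver=resolver,
    transform_kwargs={"image_size": (128, 128)},
)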

# 2. Build the task
model = ImageClassifier(backbone="resnet18", labels=datamodule.labels, multi_label=datamodule.multi_label)

# 3. Create the trainer and finetune the model
trainer = flash.Trainer(max_epochs=3, gpus=torch.cuda.device_count())
trainer.finetune(model, datamodule=datamodule, strategy="freeze")

# 4. Predict the genre of a few movies!
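datamodule = ImageClassificationData.from_files(
    # Two validation posters from the tree above (path prefix assumed)
    predict_files=[
        "data/movie_posters/val/tt0200465.jpg",
        "data/movie_posters/val/tt0326965.jpg",
    ],
    transform_kwargs={"image_size": (128, 128)},
    batch_size=2,
)
predictions = trainer.predict(model, datamodule=datamodule, output="labels")
print(predictions)

# 5. Save the model!
trainer.save_checkpoint("multi_label_image_classifier.pt")  # illustrative file name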

To learn how to view the available backbones / heads for this task, see Backbones and Heads.

Flash Zero

The multi-label image classifier can be used directly from the command line with zero code using Flash Zero. You can run the movie posters example with:

flash image_classification from_movie_posters

To view the configuration options, including options for running the image classifier on your own data, use:

flash image_classification --help


The ImageClassifier is servable. For more information, see Image Classification.