
Multi-label Image Classification

The Task

Multi-label classification is the task of assigning any number of labels from a fixed set to each data point. The data can be in any modality; in this case it is images. Multi-label image classification is supported by the ImageClassifier via the multi_label argument.
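To make the difference from ordinary (single-label) classification concrete, here is a minimal stdlib-only sketch of how multi-label problems are commonly set up: targets are multi-hot vectors, each label gets its own independent sigmoid, every label whose probability exceeds a threshold is predicted, and the loss is a mean per-label binary cross-entropy. The 0.5 threshold, the helper names, and the logits are assumptions for illustration; this is not Flash's internal code.

```python
import math

LABELS = ["Action", "Romance", "Crime", "Thriller", "Adventure"]


def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))


def to_multi_hot(names, labels=LABELS):
    """Encode a set of label names as a 0/1 vector over the fixed label set."""
    return [1.0 if label in names else 0.0 for label in labels]


def predict_labels(logits, labels=LABELS, threshold=0.5):
    """Apply an independent sigmoid per label and keep every label above the threshold."""
    probs = [sigmoid(z) for z in logits]
    return [label for label, p in zip(labels, probs) if p > threshold]


def binary_cross_entropy(logits, targets):
    """Mean binary cross-entropy over labels, a usual loss for multi-label problems."""
    total = 0.0
    for z, y in zip(logits, targets):
        p = sigmoid(z)
        total += -(y * math.log(p) + (1 - y) * math.log(1 - p))
    return total / len(logits)


# Hypothetical raw model outputs (logits) for a single poster.
logits = [2.0, -1.5, 0.3, -0.2, -3.0]
print(predict_labels(logits))  # ['Action', 'Crime']
print(binary_cross_entropy(logits, to_multi_hot(["Action", "Crime"])))
```

Unlike a softmax over classes, the per-label sigmoids do not compete with each other, so a poster can receive two genres, one, or none.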


Example

Let’s look at the task of predicting movie genres from an image of the movie poster. The data we will use is a subset of the awesome movie poster genre prediction data set from the paper “Movie Genre Classification based on Poster Images with Deep Neural Networks” by Wei-Ta Chu and Hung-Jui Guo, with the images resized to 128 by 128 pixels. Take a look at their paper (and please consider citing it if you use the data) here: www.cs.ccu.edu.tw/~wtchu/projects/MoviePoster/. The data set contains train and validation (val) folders; each folder contains the images and a metadata.csv file that stores their labels. Here’s an overview:

movie_posters
├── train
│   ├── metadata.csv
│   ├── tt0084058.jpg
│   ├── tt0084867.jpg
│   ...
└── val
    ├── metadata.csv
    ├── tt0200465.jpg
    ├── tt0326965.jpg
    ...
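The from_csv() call in the full example reads an "Id" column and one column per genre from each metadata.csv. The sketch below shows one way to decode such a file into a list of genres per image, assuming each genre column holds 0 or 1. The sample rows, their IDs, and read_labels are hypothetical and are not real labels from the data set.

```python
import csv
import io

GENRES = ["Action", "Romance", "Crime", "Thriller", "Adventure"]

# Hypothetical rows illustrating the assumed layout: an "Id" column plus one
# 0/1 column per genre. These are NOT real labels from the data set.
SAMPLE_CSV = """Id,Action,Romance,Crime,Thriller,Adventure
tt0000001,1,0,0,1,0
tt0000002,0,1,0,0,0
tt0000003,0,0,0,0,0
"""


def read_labels(f, id_column="Id", label_columns=GENRES):
    """Map each image ID to the list of genres whose column is set to 1."""
    labels = {}
    for row in csv.DictReader(f):
        labels[row[id_column]] = [c for c in label_columns if float(row[c]) == 1]
    return labels


labels = read_labels(io.StringIO(SAMPLE_CSV))
for image_id, genres in labels.items():
    print(image_id, genres)
```

To inspect a real split you could pass an opened file instead, e.g. `open("data/movie_posters/train/metadata.csv", newline="")`. Note that an image whose genre columns are all 0 ends up with an empty label list, which is valid in the multi-label setting.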

Once we’ve downloaded the data using download_data(), we need to create the ImageClassificationData. We first define a resolver function that maps each image ID in metadata.csv to the path of its image file, and then pass it, along with the train and validation CSV files, to from_csv(). We select a pre-trained backbone (resnet18) for our ImageClassifier and fine-tune it on the posters data. We then use the trained ImageClassifier to predict the genres of a few posters loaded with from_files(). Finally, we save the model. Here’s the full example:
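For intuition about how the resolver is used, here is a minimal stdlib-only sketch that pairs each resolved image path with its list of labels. It assumes the resolver receives the split folder (for example data/movie_posters/train) as root; load_data and the sample IDs are hypothetical, and this is not Flash's implementation of from_csv().

```python
import os


# Same resolver as in the full example: map an image ID from metadata.csv
# to the path of the matching .jpg file inside the split folder.
def resolver(root, file_id):
    return os.path.join(root, f"{file_id}.jpg")


def load_data(root, labels_by_id):
    """Hypothetical helper: turn {image_id: [genres]} into parallel lists of
    image paths and label lists, sorted by ID so the order is reproducible."""
    files, targets = [], []
    for file_id in sorted(labels_by_id):
        files.append(resolver(root, file_id))
        targets.append(labels_by_id[file_id])
    return files, targets


# Hypothetical IDs and labels, for illustration only.
files, targets = load_data(
    "data/movie_posters/train",
    {"tt0000002": ["Romance"], "tt0000001": ["Action", "Thriller"]},
)
for path, genres in zip(files, targets):
    print(path, genres)
```

Keeping the resolver as a separate function means the CSV only needs to store bare IDs; the file naming convention (here `<id>.jpg`) lives in one place.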

import os

import torch

import flash
from flash.core.data.utils import download_data
from flash.image import ImageClassificationData, ImageClassifier

# 1. Create the DataModule
# Data set from the paper "Movie Genre Classification based on Poster Images with Deep Neural Networks".
# More info here: https://www.cs.ccu.edu.tw/~wtchu/projects/MoviePoster/
download_data("https://pl-flash-data.s3.amazonaws.com/movie_posters.zip")


def resolver(root, file_id):
    return os.path.join(root, f"{file_id}.jpg")


datamodule = ImageClassificationData.from_csv(
    "Id",
    ["Action", "Romance", "Crime", "Thriller", "Adventure"],
    train_file="data/movie_posters/train/metadata.csv",
    train_resolver=resolver,
    val_file="data/movie_posters/val/metadata.csv",
    val_resolver=resolver,
    transform_kwargs={"image_size": (128, 128)},
    batch_size=1,
)

# 2. Build the task
model = ImageClassifier(backbone="resnet18", labels=datamodule.labels, multi_label=datamodule.multi_label)

# 3. Create the trainer and finetune the model
trainer = flash.Trainer(max_epochs=3, gpus=torch.cuda.device_count())
trainer.finetune(model, datamodule=datamodule, strategy="freeze")

# 4. Predict the genre of a few movies!
datamodule = ImageClassificationData.from_files(
    predict_files=[
        "data/movie_posters/predict/tt0085318.jpg",
        "data/movie_posters/predict/tt0089461.jpg",
        "data/movie_posters/predict/tt0097179.jpg",
    ],
    batch_size=3,
)
predictions = trainer.predict(model, datamodule=datamodule, output="labels")
print(predictions)

# 5. Save the model!
trainer.save_checkpoint("image_classification_multi_label_model.pt")

Flash Zero

The multi-label image classifier can be used directly from the command line with zero code using Flash Zero. You can run the movie posters example with:

flash image_classification from_movie_posters

To view the configuration options, including options for running the image classifier with your own data, use:

flash image_classification --help

Serving

The ImageClassifier is servable. For more information, see Image Classification.

Read the Docs v: 0.7.0