Multi-label Image Classification
The Task
Multi-label classification is the task of assigning any number of labels from a fixed set to each data point. The data can be of any modality; here it is images.
Multi-label image classification is supported by the ImageClassifier via the multi_label argument.
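As a rough illustration of what changes in multi-label mode, here is a generic sketch (not Flash's internal code): instead of picking the single highest-scoring class, each label gets its own independent sigmoid score, and every label whose score reaches a threshold is kept, so a poster can receive several genres or none. The 0.5 threshold and the example logits are assumptions for illustration.

```python
import math

GENRES = ["Action", "Romance", "Crime", "Thriller", "Adventure"]


def sigmoid(x: float) -> float:
    """Map a raw logit to an independent probability in (0, 1)."""
    return 1.0 / (1.0 + math.exp(-x))


def decode_multi_label(logits, labels, threshold=0.5):
    """Multi-label decoding: keep every label whose sigmoid score reaches the threshold."""
    return [label for label, z in zip(labels, logits) if sigmoid(z) >= threshold]


def decode_single_label(logits, labels):
    """Single-label decoding: always return exactly one label (the argmax)."""
    best = max(range(len(logits)), key=lambda i: logits[i])
    return labels[best]


# Hypothetical logits for one poster (made up for illustration).
logits = [2.1, -1.3, 0.4, 1.7, -0.2]
print(decode_multi_label(logits, GENRES))   # ['Action', 'Crime', 'Thriller']
print(decode_single_label(logits, GENRES))  # Action
```

Because each label is scored independently, adding or removing a genre does not change the scores of the others, which is what allows several genres per poster.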
Example
Let’s look at the task of trying to predict the movie genres from an image of the movie poster.
The data we will use is a subset of the awesome movie poster genre prediction data set from the paper “Movie Genre Classification based on Poster Images with Deep Neural Networks” by Wei-Ta Chu and Hung-Jui Guo, with the images resized to 128 by 128 pixels.
Take a look at their paper (and please consider citing their paper if you use the data) here: www.cs.ccu.edu.tw/~wtchu/projects/MoviePoster/.
The data set contains train and val folders; each folder contains the images and a metadata.csv file that stores the labels.
Here’s an overview:
```
movie_posters
├── train
│   ├── metadata.csv
│   ├── tt0084058.jpg
│   ├── tt0084867.jpg
│   ...
└── val
    ├── metadata.csv
    ├── tt0200465.jpg
    ├── tt0326965.jpg
    ...
```
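The sketch below reads the label columns of a metadata.csv file into multi-hot target vectors, which is the shape of target a multi-label classifier learns from. It assumes, as the full example does, an Id column plus one column per genre, and additionally assumes that each genre column holds 0 or 1. The IDs and rows are made up, and a real metadata.csv may contain more genre columns than the five selected here.

```python
import csv
import io

GENRES = ["Action", "Romance", "Crime", "Thriller", "Adventure"]

# Made-up rows in the assumed metadata.csv layout (Id + one 0/1 column per genre).
SAMPLE_CSV = """Id,Action,Romance,Crime,Thriller,Adventure
poster_a,1,0,1,1,0
poster_b,0,1,0,0,0
"""


def load_targets(csv_text, id_column, label_columns):
    """Return {image_id: multi-hot list}, reading only the requested label columns."""
    targets = {}
    for row in csv.DictReader(io.StringIO(csv_text)):
        targets[row[id_column]] = [int(row[col]) for col in label_columns]
    return targets


targets = load_targets(SAMPLE_CSV, "Id", GENRES)
print(targets["poster_a"])  # [1, 0, 1, 1, 0]
```

Selecting columns by name means any extra genre columns in the file are simply ignored, mirroring how only five genres are passed to from_csv() in the full example.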
Once we’ve downloaded the data using download_data(), we need to create the ImageClassificationData.
We first define a function (resolver) that maps each Id in metadata.csv to the path of the corresponding image. We then pass the Id column, the genre columns to predict, the metadata.csv files, and the resolver to from_csv().
We select a pre-trained backbone to use for our ImageClassifier and fine-tune it on the posters data.
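The finetune call in the full example uses strategy="freeze", which, as we understand Flash's fine-tuning strategies, keeps the pre-trained backbone fixed and trains only the classification head. The plain-PyTorch sketch below illustrates that idea with a tiny stand-in network (an assumption; it is not resnet18 and not Flash's fine-tuning callback), using a per-label binary cross-entropy loss, the usual choice for multi-hot targets. The batch and targets are random.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Tiny stand-in for a pre-trained backbone (real runs would use e.g. resnet18).
backbone = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
)
head = nn.Linear(8, 5)  # one output logit per genre
model = nn.Sequential(backbone, head)

# "Freeze": stop gradients for every backbone parameter.
for param in backbone.parameters():
    param.requires_grad = False

# Only the parameters that still require gradients (the head) go to the optimizer.
optimizer = torch.optim.SGD([p for p in model.parameters() if p.requires_grad], lr=0.1)

images = torch.randn(4, 3, 32, 32)             # random batch of 4 RGB images
targets = torch.randint(0, 2, (4, 5)).float()  # random multi-hot genre targets

backbone_before = [p.clone() for p in backbone.parameters()]
loss = nn.functional.binary_cross_entropy_with_logits(model(images), targets)
optimizer.zero_grad()
loss.backward()
optimizer.step()

backbone_unchanged = all(torch.equal(a, b) for a, b in zip(backbone_before, backbone.parameters()))
print(backbone_unchanged)  # True
```

Freezing keeps the pre-trained features intact while the small head adapts to the five genres, which is a common starting point when the fine-tuning data set is small.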
We then use the trained ImageClassifier for inference.
Finally, we save the model.
Here’s the full example:
```python
import os

import flash
import torch
from flash.core.data.utils import download_data
from flash.image import ImageClassificationData, ImageClassifier

# 1. Create the DataModule
# Data set from the paper "Movie Genre Classification based on Poster Images with Deep Neural Networks".
# More info here: https://www.cs.ccu.edu.tw/~wtchu/projects/MoviePoster/
download_data("https://pl-flash-data.s3.amazonaws.com/movie_posters.zip")


def resolver(root, file_id):
    # Map an Id from metadata.csv to the path of its poster image: <root>/<Id>.jpg
    return os.path.join(root, f"{file_id}.jpg")


datamodule = ImageClassificationData.from_csv(
    "Id",  # column holding the image IDs
    ["Action", "Romance", "Crime", "Thriller", "Adventure"],  # one target column per genre
    train_file="data/movie_posters/train/metadata.csv",
    train_resolver=resolver,
    val_file="data/movie_posters/val/metadata.csv",
    val_resolver=resolver,
    transform_kwargs={"image_size": (128, 128)},
    batch_size=1,
)

# 2. Build the task
model = ImageClassifier(backbone="resnet18", labels=datamodule.labels, multi_label=datamodule.multi_label)

# 3. Create the trainer and finetune the model
trainer = flash.Trainer(max_epochs=3, gpus=torch.cuda.device_count())
trainer.finetune(model, datamodule=datamodule, strategy="freeze")

# 4. Predict the genre of a few movies!
datamodule = ImageClassificationData.from_files(
    predict_files=[
        "data/movie_posters/predict/tt0085318.jpg",
        "data/movie_posters/predict/tt0089461.jpg",
        "data/movie_posters/predict/tt0097179.jpg",
    ],
    batch_size=3,
)
predictions = trainer.predict(model, datamodule=datamodule, output="labels")
print(predictions)

# 5. Save the model!
trainer.save_checkpoint("image_classification_multi_label_model.pt")
```
To learn how to view the available backbones / heads for this task, see Backbones and Heads.
Flash Zero
The multi-label image classifier can be used directly from the command line with zero code using Flash Zero. You can run the movie posters example with:
```bash
flash image_classification from_movie_posters
```
To view the configuration options, including how to run the image classifier on your own data, use:
```bash
flash image_classification --help
```
Serving
The ImageClassifier is servable.
For more information, see Image Classification.