Image Classification

The Task

The task of identifying what is in an image is called image classification. Typically, image classification is used to identify images containing a single object. The task predicts the ‘class’ an image most likely belongs to, along with a degree of certainty. A class is a label that describes what is in an image, such as ‘car’, ‘house’, or ‘cat’.
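Under the hood, a classifier typically produces one raw score (logit) per class and converts the scores into probabilities with a softmax; the class with the highest probability is the prediction, and that probability is the degree of certainty. Here's a minimal sketch in plain Python, using made-up scores purely for illustration:

```python
import math

# Hypothetical raw scores (logits) a classifier might output for one image.
logits = {"ant": 0.8, "bee": 2.5, "wasp": -0.3}

# Softmax turns the scores into a probability per class.
exps = {cls: math.exp(score) for cls, score in logits.items()}
total = sum(exps.values())
probs = {cls: e / total for cls, e in exps.items()}

# The highest-probability class is the prediction.
best = max(probs, key=probs.get)
print(best, round(probs[best], 2))  # prints: bee 0.8
```

The probabilities always sum to 1, so a prediction of 0.8 for ‘bee’ means the model assigns the remaining 0.2 across the other classes.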


Example

Let’s look at the task of predicting whether images contain ants or bees, using the hymenoptera dataset. The dataset contains train and validation folders, each of which contains a bees folder, with images of bees, and an ants folder with images of, you guessed it, ants.

hymenoptera_data
├── train
│   ├── ants
│   │   ├── 0013035.jpg
│   │   ├── 1030023514_aad5c608f9.jpg
│   │   ...
│   └── bees
│       ├── 1092977343_cb42b38d62.jpg
│       ├── 1093831624_fb5fbe2308.jpg
│       ...
└── val
    ├── ants
    │   ├── 10308379_1b6c72e180.jpg
    │   ├── 1053149811_f62a3410d3.jpg
    │   ...
    └── bees
        ├── 1032546534_06907fe3b3.jpg
        ├── 10870992_eebeeb3a12.jpg
        ...

Once we’ve downloaded the data using download_data(), we create the ImageClassificationData. We select a pre-trained backbone to use for our ImageClassifier and fine-tune on the hymenoptera data. We then use the trained ImageClassifier for inference. Finally, we save the model. Here’s the full example:

import torch

import flash
from flash.core.data.utils import download_data
from flash.image import ImageClassificationData, ImageClassifier

# 1. Create the DataModule
download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip", "./data")

datamodule = ImageClassificationData.from_folders(
    train_folder="data/hymenoptera_data/train/",
    val_folder="data/hymenoptera_data/val/",
)

# 2. Build the task
model = ImageClassifier(backbone="resnet18", num_classes=datamodule.num_classes)

# 3. Create the trainer and finetune the model
trainer = flash.Trainer(max_epochs=3, gpus=torch.cuda.device_count())
trainer.finetune(model, datamodule=datamodule, strategy="freeze")

# 4. Predict what's on a few images! ants or bees?
predictions = model.predict(
    [
        "data/hymenoptera_data/val/bees/65038344_52a45d090d.jpg",
        "data/hymenoptera_data/val/bees/590318879_68cf112861.jpg",
        "data/hymenoptera_data/val/ants/540543309_ddbb193ee5.jpg",
    ]
)
print(predictions)

# 5. Save the model!
trainer.save_checkpoint("image_classification_model.pt")

Flash Zero

The image classifier can be used directly from the command line with zero code using Flash Zero. You can run the hymenoptera example with:

flash image_classification

To view the configuration options, including those for running the image classifier on your own data, use:

flash image_classification --help

Loading Data

This section details the available ways to load your own data into the ImageClassificationData.

from_folders

Construct the ImageClassificationData from folders.

The supported file extensions are: .jpg, .jpeg, .png, .ppm, .bmp, .pgm, .tif, .tiff, .webp, .npy.

For train, test, and val data, the folders are expected to contain a sub-folder for each class. Here’s the required structure:

train_folder
├── class_1
│   ├── file1.jpg
│   ├── file2.jpg
│   ...
└── class_2
    ├── file1.jpg
    ├── file2.jpg
    ...

For prediction, the folder is expected to contain the files for inference, like this:

predict_folder
├── file1.jpg
├── file2.jpg
...

Example:

data_module = ImageClassificationData.from_folders(
    train_folder="./train_folder",
    predict_folder="./predict_folder",
    ...
)

from_files

Construct the ImageClassificationData from lists of files and corresponding lists of targets.

The supported file extensions are: .jpg, .jpeg, .png, .ppm, .bmp, .pgm, .tif, .tiff, .webp, .npy.

Example:

train_files = ["file1.jpg", "file2.jpg", "file3.jpg", ...]
train_targets = [0, 1, 0, ...]

datamodule = ImageClassificationData.from_files(
    train_files=train_files,
    train_targets=train_targets,
    ...
)

from_datasets

Construct the ImageClassificationData from the given datasets for each stage.

Example:

from torch.utils.data.dataset import Dataset

train_dataset: Dataset = ...

datamodule = ImageClassificationData.from_datasets(
    train_dataset=train_dataset,
    ...
)

Note

The __getitem__ of your datasets should return a dictionary with "input" and "target" keys which map to the input image (as a PIL.Image) and the target (as an int or list of ints) respectively.
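As a sketch of that format, here is a hypothetical dataset holding pre-loaded PIL images in memory (the class name and constructor are illustrative, not part of the Flash API):

```python
from typing import Any, Dict, List

from PIL import Image
from torch.utils.data import Dataset


class InMemoryImageDataset(Dataset):
    """Hypothetical dataset returning the {"input", "target"} dictionaries Flash expects."""

    def __init__(self, images: List[Image.Image], targets: List[int]):
        self.images = images    # pre-loaded PIL images
        self.targets = targets  # e.g. 0 for ants, 1 for bees

    def __len__(self) -> int:
        return len(self.images)

    def __getitem__(self, index: int) -> Dict[str, Any]:
        # Flash reads the image from "input" and the label from "target".
        return {"input": self.images[index], "target": self.targets[index]}
```

A dataset like this could then be passed to ImageClassificationData.from_datasets() as shown above.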


Custom Transformations

Flash automatically applies some default image transformations and augmentations, but you may wish to customize these for your own use case. The base Preprocess defines 7 hooks for different stages in the data loading pipeline. To apply image augmentations you can directly import the default_transforms from flash.image.classification.transforms and then merge your custom image transformations with them using the merge_transforms() helper function. Here’s an example where we load the default transforms and merge with custom torchvision transformations. We use the post_tensor_transform hook to apply the transformations after the image has been converted to a torch.Tensor.

from torchvision import transforms as T

import flash
from flash.core.data.data_source import DefaultDataKeys
from flash.core.data.transforms import ApplyToKeys, merge_transforms
from flash.image import ImageClassificationData, ImageClassifier
from flash.image.classification.transforms import default_transforms

post_tensor_transform = ApplyToKeys(
    DefaultDataKeys.INPUT,
    T.Compose([T.RandomHorizontalFlip(), T.ColorJitter(), T.RandomAutocontrast(), T.RandomPerspective()]),
)

new_transforms = merge_transforms(default_transforms((64, 64)), {"post_tensor_transform": post_tensor_transform})

datamodule = ImageClassificationData.from_folders(
    train_folder="data/hymenoptera_data/train/", val_folder="data/hymenoptera_data/val/", train_transform=new_transforms
)

model = ImageClassifier(backbone="resnet18", num_classes=datamodule.num_classes)

trainer = flash.Trainer(max_epochs=1)
trainer.finetune(model, datamodule=datamodule, strategy="freeze")

Serving

The ImageClassifier is servable. This means you can call .serve to serve your Task. Here’s an example:

from flash.image import ImageClassifier

model = ImageClassifier.load_from_checkpoint(
    "https://flash-weights.s3.amazonaws.com/0.5.2/image_classification_model.pt"
)
model.serve()

You can now perform inference from your client like this:

import base64
from pathlib import Path

import requests

import flash

with (Path(flash.ASSETS_ROOT) / "fish.jpg").open("rb") as f:
    imgstr = base64.b64encode(f.read()).decode("UTF-8")

body = {"session": "UUID", "payload": {"inputs": {"data": imgstr}}}
resp = requests.post("http://127.0.0.1:8000/predict", json=body)
print(resp.json())