Image Classification¶
The Task¶
The task of identifying what is in an image is called image classification. Typically, image classification is used to identify images containing a single object. The task predicts which ‘class’ the image most likely belongs to, with a degree of certainty. A class is a label that describes what is in an image, such as ‘car’, ‘house’, or ‘cat’.
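To make the “degree of certainty” concrete: a classifier typically produces one raw score (logit) per class, and a softmax turns those scores into probabilities. Here’s a minimal, self-contained sketch; the class names and scores are made up for illustration:

import torch

# Hypothetical raw scores (logits) from a classifier for one image
classes = ["car", "house", "cat"]
logits = torch.tensor([2.0, 0.5, -1.0])

# Softmax converts the scores into probabilities that sum to 1
probs = torch.softmax(logits, dim=0)
for name, p in zip(classes, probs.tolist()):
    print(f"{name}: {p:.3f}")  # car: ~0.786, house: ~0.175, cat: ~0.039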
Example¶
Let’s look at the task of predicting whether images contain ants or bees, using the hymenoptera dataset. The dataset contains train and validation folders, and each of these contains a bees folder, with pictures of bees, and an ants folder with images of, you guessed it, ants.
hymenoptera_data
├── train
│   ├── ants
│   │   ├── 0013035.jpg
│   │   ├── 1030023514_aad5c608f9.jpg
│   │   ...
│   └── bees
│       ├── 1092977343_cb42b38d62.jpg
│       ├── 1093831624_fb5fbe2308.jpg
│       ...
└── val
    ├── ants
    │   ├── 10308379_1b6c72e180.jpg
    │   ├── 1053149811_f62a3410d3.jpg
    │   ...
    └── bees
        ├── 1032546534_06907fe3b3.jpg
        ├── 10870992_eebeeb3a12.jpg
        ...
Once we’ve downloaded the data using download_data(), we create the ImageClassificationData. We select a pre-trained backbone to use for our ImageClassifier and fine-tune on the hymenoptera data. We then use the trained ImageClassifier for inference. Finally, we save the model. Here’s the full example:
import flash
import torch
from flash.core.data.utils import download_data
from flash.image import ImageClassificationData, ImageClassifier

# 1. Create the DataModule
download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip", "./data")

datamodule = ImageClassificationData.from_folders(
    train_folder="data/hymenoptera_data/train/",
    val_folder="data/hymenoptera_data/val/",
    batch_size=4,
    transform_kwargs={"image_size": (196, 196), "mean": (0.485, 0.456, 0.406), "std": (0.229, 0.224, 0.225)},
)

# 2. Build the task
model = ImageClassifier(backbone="resnet18", labels=datamodule.labels)

# 3. Create the trainer and finetune the model
trainer = flash.Trainer(max_epochs=3, gpus=torch.cuda.device_count())
trainer.finetune(model, datamodule=datamodule, strategy="freeze")

# 4. Predict what's on a few images! ants or bees?
datamodule = ImageClassificationData.from_files(
    predict_files=[
        "https://pl-flash-data.s3.amazonaws.com/images/ant_1.jpg",
        "https://pl-flash-data.s3.amazonaws.com/images/ant_2.jpg",
        "https://pl-flash-data.s3.amazonaws.com/images/bee_1.jpg",
    ],
    batch_size=3,
)
predictions = trainer.predict(model, datamodule=datamodule, output="labels")
print(predictions)

# 5. Save the model!
trainer.save_checkpoint("image_classification_model.pt")
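If you want to reuse the saved checkpoint later, you can load it back the same way the Serving section below does. A minimal sketch, assuming the checkpoint path from step 5:

from flash.image import ImageClassifier

# Restore the fine-tuned task from the saved checkpoint
model = ImageClassifier.load_from_checkpoint("image_classification_model.pt")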
To learn how to view the available backbones / heads for this task, see Backbones and Heads. Benchmarks for backbones provided by PyTorch Image Models (TIMM) can be found here: https://github.com/rwightman/pytorch-image-models/blob/master/results/results-imagenet-real.csv
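You can also list the registered backbones programmatically. A short sketch, assuming the available_backbones() helper exposed by Flash tasks:

from flash.image import ImageClassifier

# Print the names of all backbones registered for this task
print(ImageClassifier.available_backbones())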
Flash Zero¶
The image classifier can be used directly from the command line with zero code using Flash Zero. You can run the hymenoptera example with:
flash image_classification
To view configuration options and options for running the image classifier with your own data, use:
flash image_classification --help
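For instance, to point Flash Zero at folders laid out like the hymenoptera data, something like the following should work; this sketch assumes the from_folders subcommand and --train_folder flag mirror the Python API shown above, so check --help for the exact flags:

flash image_classification from_folders --train_folder data/hymenoptera_data/train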
Custom Transformations¶
Flash automatically applies some default image transformations and augmentations, but you may wish to customize these for your own use case. The base InputTransform defines 7 hooks for different stages in the data loading pipeline. To apply custom image augmentations, you can create your own InputTransform. Here’s an example:
from dataclasses import dataclass
from typing import Tuple, Union

import torch
from torchvision import transforms as T

import flash
from flash.core.data.io.input_transform import InputTransform
from flash.core.data.transforms import ApplyToKeys
from flash.image import ImageClassificationData, ImageClassifier


@dataclass
class ImageClassificationInputTransform(InputTransform):
    image_size: Tuple[int, int] = (196, 196)
    mean: Union[float, Tuple[float, float, float]] = (0.485, 0.456, 0.406)
    std: Union[float, Tuple[float, float, float]] = (0.229, 0.224, 0.225)

    def per_sample_transform(self):
        return T.Compose(
            [
                ApplyToKeys(
                    "input",
                    T.Compose([T.ToTensor(), T.Resize(self.image_size), T.Normalize(self.mean, self.std)]),
                ),
                ApplyToKeys("target", torch.as_tensor),
            ]
        )

    def train_per_sample_transform(self):
        return T.Compose(
            [
                ApplyToKeys(
                    "input",
                    T.Compose(
                        [
                            T.ToTensor(),
                            T.Resize(self.image_size),
                            T.Normalize(self.mean, self.std),
                            T.RandomHorizontalFlip(),
                            T.ColorJitter(),
                            T.RandomAutocontrast(),
                            T.RandomPerspective(),
                        ]
                    ),
                ),
                ApplyToKeys("target", torch.as_tensor),
            ]
        )


datamodule = ImageClassificationData.from_folders(
    train_folder="data/hymenoptera_data/train/",
    val_folder="data/hymenoptera_data/val/",
    transform=ImageClassificationInputTransform,
    transform_kwargs=dict(image_size=(128, 128)),
    batch_size=1,
)

model = ImageClassifier(backbone="resnet18", num_classes=datamodule.num_classes)

trainer = flash.Trainer(max_epochs=1)
trainer.finetune(model, datamodule=datamodule, strategy="freeze")
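Note that the values in transform_kwargs are forwarded to the dataclass constructor, so image_size=(128, 128) here overrides the (196, 196) default declared on ImageClassificationInputTransform.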
Serving¶
The ImageClassifier is servable. This means you can call .serve to serve your Task. Here’s an example:
from flash.image import ImageClassifier

model = ImageClassifier.load_from_checkpoint(
    "https://flash-weights.s3.amazonaws.com/0.7.0/image_classification_model.pt"
)
model.serve(output="labels")
You can now perform inference from your client like this:
import base64
from pathlib import Path

import flash
import requests

with (Path(flash.ASSETS_ROOT) / "fish.jpg").open("rb") as f:
    imgstr = base64.b64encode(f.read()).decode("UTF-8")

body = {"session": "UUID", "payload": {"inputs": {"data": imgstr}}}
resp = requests.post("http://127.0.0.1:8000/predict", json=body)
print(resp.json())