Object Detection

The Task

Object detection is the task of identifying objects in images, along with their associated classes and bounding boxes.
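Each detection pairs a class label with a bounding box. A standard way to compare a predicted box against a ground-truth box is intersection over union (IoU). The following is a minimal, self-contained illustration (not part of the Flash API), assuming boxes in (x_min, y_min, x_max, y_max) format:

```python
def iou(box_a, box_b):
    """Intersection over union of two (x_min, y_min, x_max, y_max) boxes."""
    # Corners of the intersection rectangle
    x_min = max(box_a[0], box_b[0])
    y_min = max(box_a[1], box_b[1])
    x_max = min(box_a[2], box_b[2])
    y_max = min(box_a[3], box_b[3])
    # Clamp to zero so non-overlapping boxes give an empty intersection
    inter = max(0, x_max - x_min) * max(0, y_max - y_min)
    area_a = (box_a[2] - box_a[0]) * (box_a[3] - box_a[1])
    area_b = (box_b[2] - box_b[0]) * (box_b[3] - box_b[1])
    union = area_a + area_b - inter
    return inter / union if union else 0.0

print(iou((0, 0, 10, 10), (0, 0, 10, 10)))   # 1.0
print(iou((0, 0, 10, 10), (20, 20, 30, 30))) # 0.0
```

An IoU of 1.0 means the boxes coincide exactly, while 0.0 means they do not overlap at all.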

The ObjectDetector and ObjectDetectionData classes internally rely on IceVision.


Example

Let’s look at object detection with the COCO 128 data set, which contains 80 object classes. This is a subset of COCO train2017 with only 128 images. The data set is organized following the COCO format. Here’s an outline:

coco128
├── annotations
│   └── instances_train2017.json
├── images
│   └── train2017
│       ├── 000000000009.jpg
│       ├── 000000000025.jpg
│       ...
└── labels
    └── train2017
        ├── 000000000009.txt
        ├── 000000000025.txt
        ...

Once we’ve downloaded the data using download_data(), we can create the ObjectDetectionData. We select a pre-trained EfficientDet to use for our ObjectDetector and fine-tune on the COCO 128 data. We then use the trained ObjectDetector for inference. Finally, we save the model. Here’s the full example:

import flash
from flash.core.data.utils import download_data
from flash.image import ObjectDetectionData, ObjectDetector

# 1. Create the DataModule
# Dataset Credit: https://www.kaggle.com/ultralytics/coco128
download_data("https://github.com/zhiqwang/yolov5-rt-stack/releases/download/v0.3.0/coco128.zip", "data/")

datamodule = ObjectDetectionData.from_coco(
    train_folder="data/coco128/images/train2017/",
    train_ann_file="data/coco128/annotations/instances_train2017.json",
    val_split=0.1,
    transform_kwargs={"image_size": 512},
    batch_size=4,
)

# 2. Build the task
model = ObjectDetector(head="efficientdet", backbone="d0", num_classes=datamodule.num_classes, image_size=512)

# 3. Create the trainer and finetune the model
trainer = flash.Trainer(max_epochs=1)
trainer.finetune(model, datamodule=datamodule, strategy="freeze")

# 4. Detect objects in a few images!
datamodule = ObjectDetectionData.from_files(
    predict_files=[
        "data/coco128/images/train2017/000000000625.jpg",
        "data/coco128/images/train2017/000000000626.jpg",
        "data/coco128/images/train2017/000000000629.jpg",
    ],
    transform_kwargs={"image_size": 512},
    batch_size=4,
)
predictions = trainer.predict(model, datamodule=datamodule)
print(predictions)

# 5. Save the model!
trainer.save_checkpoint("object_detection_model.pt")
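Since COCO stores boxes as [x, y, width, height] while many tools expect corner coordinates, a small helper like the following (an illustration, not a Flash function) converts between the two:

```python
def xywh_to_xyxy(box):
    """Convert a COCO-style [x, y, width, height] box to [x_min, y_min, x_max, y_max]."""
    x, y, w, h = box
    return [x, y, x + w, y + h]

print(xywh_to_xyxy([10, 20, 30, 40]))  # [10, 20, 40, 60]
```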

To learn how to view the available backbones / heads for this task, see Backbones and Heads.


Flash Zero

The object detector can be used directly from the command line with zero code using Flash Zero. You can run the above example with:

flash object_detection

To view the configuration options, including those for running the object detector with your own data, use:

flash object_detection --help

Custom Transformations

Flash automatically applies some default image and mask transformations and augmentations, but you may wish to customize these for your own use case. The base InputTransform defines 7 hooks for different stages in the data loading pipeline. For object detection tasks, you can leverage the transformations from Albumentations with the IceVisionTransformAdapter by creating a subclass of InputTransform:

from dataclasses import dataclass
import albumentations as alb
from icevision.tfms import A

from flash import InputTransform
from flash.core.integrations.icevision.transforms import IceVisionTransformAdapter
from flash.image import ObjectDetectionData


@dataclass
class BrightnessContrastTransform(InputTransform):
    image_size: int = 128

    def per_sample_transform(self):
        return IceVisionTransformAdapter(
            [*A.aug_tfms(size=self.image_size), A.Normalize(), alb.RandomBrightnessContrast()]
        )


datamodule = ObjectDetectionData.from_coco(
    train_folder="data/coco128/images/train2017/",
    train_ann_file="data/coco128/annotations/instances_train2017.json",
    val_split=0.1,
    transform=BrightnessContrastTransform,
    batch_size=4,
)

Serving

The ObjectDetector is servable. This means you can call .serve to serve your Task. Here’s an example:

from flash.image import ObjectDetector

model = ObjectDetector.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/0.8.0/object_detection_model.pt")
model.serve()

You can now perform inference from your client like this:

import base64
from pathlib import Path

import flash
import requests

with (Path(flash.ASSETS_ROOT) / "fish.jpg").open("rb") as f:
    imgstr = base64.b64encode(f.read()).decode("UTF-8")

body = {"session": "UUID", "payload": {"inputs": {"data": imgstr}}}
resp = requests.post("http://127.0.0.1:8000/predict", json=body)
print(resp.json())