Object Detection
The Task
Object detection is the task of identifying objects in images and their associated classes and bounding boxes.
The ObjectDetector and ObjectDetectionData classes internally rely on IceVision.
Example
Let’s look at object detection with the COCO 128 data set, which contains 80 object classes. This is a subset of COCO train2017 with only 128 images. The data set is organized following the COCO format. Here’s an outline:
coco128
├── annotations
│   └── instances_train2017.json
├── images
│   └── train2017
│       ├── 000000000009.jpg
│       ├── 000000000025.jpg
│       ...
└── labels
    └── train2017
        ├── 000000000009.txt
        ├── 000000000025.txt
        ...
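The annotation file is plain JSON. As a rough sketch of what the COCO format holds, the snippet below uses only the standard library to map category ids to names and count boxes per image and per class. It assumes the standard COCO keys `images`, `annotations` and `categories`, with each `bbox` stored as `[x_min, y_min, width, height]` in pixels. The helpers `load_coco` and `summarize_coco` and the toy annotations are hypothetical, made up for illustration; this is not Flash or IceVision code.

```python
import json
from collections import Counter
from typing import Any, Dict


def load_coco(path: str) -> Dict[str, Any]:
    """Load a COCO annotation file such as instances_train2017.json."""
    with open(path, encoding="utf-8") as f:
        return json.load(f)


def summarize_coco(coco: Dict[str, Any]) -> Dict[str, Any]:
    """Count images, categories and boxes in a COCO-style annotation dict."""
    # Map category ids to human-readable names.
    names = {c["id"]: c["name"] for c in coco["categories"]}
    # Number of boxes per image id; Counter returns 0 for images with no boxes.
    boxes_per_image = Counter(a["image_id"] for a in coco["annotations"])
    return {
        "num_images": len(coco["images"]),
        "num_categories": len(names),
        "num_boxes": len(coco["annotations"]),
        "images_without_boxes": sum(
            1 for img in coco["images"] if boxes_per_image[img["id"]] == 0
        ),
        "boxes_per_class": Counter(names[a["category_id"]] for a in coco["annotations"]),
    }


# Toy annotations (hypothetical, not taken from COCO 128).
toy = {
    "images": [
        {"id": 1, "file_name": "a.jpg", "width": 640, "height": 480},
        {"id": 2, "file_name": "b.jpg", "width": 640, "height": 480},
    ],
    "categories": [{"id": 1, "name": "person"}, {"id": 17, "name": "cat"}],
    "annotations": [
        # Each bbox is [x_min, y_min, width, height] in pixels.
        {"id": 10, "image_id": 2, "category_id": 17, "bbox": [10.0, 20.0, 100.0, 80.0]},
        {"id": 11, "image_id": 2, "category_id": 17, "bbox": [200.0, 40.0, 50.0, 60.0]},
    ],
}
summary = summarize_coco(toy)
print(summary["num_images"], summary["images_without_boxes"], summary["boxes_per_class"])
```

Once the data is downloaded, `summarize_coco(load_coco("data/coco128/annotations/instances_train2017.json"))` would give the same kind of summary for the real file.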
Once we’ve downloaded the data using download_data(), we can create the ObjectDetectionData. We select a pre-trained EfficientDet to use for our ObjectDetector and fine-tune it on the COCO 128 data. We then use the trained ObjectDetector for inference. Finally, we save the model.
Here’s the full example:
import flash
from flash.core.data.utils import download_data
from flash.image import ObjectDetectionData, ObjectDetector

# 1. Create the DataModule
# Dataset Credit: https://www.kaggle.com/ultralytics/coco128
download_data("https://github.com/zhiqwang/yolov5-rt-stack/releases/download/v0.3.0/coco128.zip", "data/")

datamodule = ObjectDetectionData.from_coco(
    train_folder="data/coco128/images/train2017/",
    train_ann_file="data/coco128/annotations/instances_train2017.json",
    val_split=0.1,
    transform_kwargs={"image_size": 512},
    batch_size=4,
)

# 2. Build the task
model = ObjectDetector(head="efficientdet", backbone="d0", num_classes=datamodule.num_classes, image_size=512)

# 3. Create the trainer and finetune the model
trainer = flash.Trainer(max_epochs=1)
trainer.finetune(model, datamodule=datamodule, strategy="freeze")

# 4. Detect objects in a few images!
datamodule = ObjectDetectionData.from_files(
    predict_files=[
        "data/coco128/images/train2017/000000000625.jpg",
        "data/coco128/images/train2017/000000000626.jpg",
        "data/coco128/images/train2017/000000000629.jpg",
    ],
    transform_kwargs={"image_size": 512},
    batch_size=4,
)
predictions = trainer.predict(model, datamodule=datamodule)
print(predictions)

# 5. Save the model!
trainer.save_checkpoint("object_detection_model.pt")
To learn how to view the available backbones / heads for this task, see Backbones and Heads.
Flash Zero
The object detector can be used directly from the command line with zero code using Flash Zero. You can run the above example with:
flash object_detection
To view the configuration options, including how to run the object detector on your own data, use:
flash object_detection --help
Custom Transformations
Flash automatically applies some default transformations and augmentations, but you may wish to customize these for your own use case.
The base InputTransform defines 7 hooks for different stages in the data loading pipeline.
For object detection tasks, you can use transformations from Albumentations through the IceVisionTransformAdapter by creating a subclass of InputTransform:
from dataclasses import dataclass

import albumentations as alb
from icevision.tfms import A

from flash import InputTransform
from flash.core.integrations.icevision.transforms import IceVisionTransformAdapter
from flash.image import ObjectDetectionData


@dataclass
class BrightnessContrastTransform(InputTransform):
    image_size: int = 128

    def per_sample_transform(self):
        return IceVisionTransformAdapter(
            [*A.aug_tfms(size=self.image_size), A.Normalize(), alb.RandomBrightnessContrast()]
        )


datamodule = ObjectDetectionData.from_coco(
    train_folder="data/coco128/images/train2017/",
    train_ann_file="data/coco128/annotations/instances_train2017.json",
    val_split=0.1,
    transform=BrightnessContrastTransform,
    batch_size=4,
)
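Photometric augmentations such as RandomBrightnessContrast leave the bounding boxes unchanged, but geometric ones such as resizing to `image_size` or flipping the image must move each box along with the pixels. The plain-Python sketch below shows that arithmetic for COCO-style `[x_min, y_min, width, height]` boxes. It is only an illustration: `resize_box` and `hflip_box` are hypothetical helpers, not Flash, IceVision or Albumentations code, and the image and box sizes are made up.

```python
from typing import Tuple

# COCO-style box: (x_min, y_min, width, height) in pixels.
Box = Tuple[float, float, float, float]


def resize_box(box: Box, scale_x: float, scale_y: float) -> Box:
    """Scale a box when the image is resized by (scale_x, scale_y)."""
    x, y, w, h = box
    return (x * scale_x, y * scale_y, w * scale_x, h * scale_y)


def hflip_box(box: Box, image_width: float) -> Box:
    """Mirror a box for an image flipped left-to-right."""
    x, y, w, h = box
    # The old right edge (x + w) becomes the new left edge.
    return (image_width - x - w, y, w, h)


# Hypothetical 640x480 image resized so its longer side is 512 pixels, then flipped.
image_w, image_h = 640, 480
scale = 512 / max(image_w, image_h)
box = (100.0, 50.0, 200.0, 100.0)
resized = resize_box(box, scale, scale)
flipped = hflip_box(resized, image_w * scale)
print(resized, flipped)
```

Real libraries also handle details this sketch ignores, such as padding, clipping boxes at the image border, and dropping boxes that become too small.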
Serving
The ObjectDetector is servable. This means you can call .serve to serve your Task.
Here’s an example:
from flash.image import ObjectDetector
model = ObjectDetector.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/0.8.0/object_detection_model.pt")
model.serve()
You can now perform inference from your client like this:
import base64
from pathlib import Path
import flash
import requests
with (Path(flash.ASSETS_ROOT) / "fish.jpg").open("rb") as f:
    imgstr = base64.b64encode(f.read()).decode("UTF-8")
body = {"session": "UUID", "payload": {"inputs": {"data": imgstr}}}
resp = requests.post("http://127.0.0.1:8000/predict", json=body)
print(resp.json())
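The request body wraps the base64-encoded image under `payload.inputs.data`. If you want to reuse that from your own client code, a small sketch like the one below builds the same body from raw bytes and posts it using `urllib.request` from the standard library instead of `requests`. The helper names `build_predict_body` and `post_predict` are hypothetical, and the sketch assumes the server started by `model.serve()` is listening at `http://127.0.0.1:8000/predict` as in the client example.

```python
import base64
import json
import urllib.request
from typing import Any, Dict

PREDICT_URL = "http://127.0.0.1:8000/predict"  # assumed default from the example


def build_predict_body(image_bytes: bytes, session: str = "UUID") -> Dict[str, Any]:
    """Wrap raw image bytes in the JSON body used by the client example."""
    encoded = base64.b64encode(image_bytes).decode("UTF-8")
    return {"session": session, "payload": {"inputs": {"data": encoded}}}


def post_predict(body: Dict[str, Any], url: str = PREDICT_URL) -> Any:
    """POST the body as JSON and return the decoded JSON response."""
    request = urllib.request.Request(
        url,
        data=json.dumps(body).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )
    with urllib.request.urlopen(request) as response:
        return json.loads(response.read().decode("utf-8"))


# Building the body needs no running server; the bytes here are a stand-in, not a real image.
body = build_predict_body(b"\x89PNG stand-in bytes")
print(sorted(body))
```

With the server running, `post_predict(build_predict_body(open("fish.jpg", "rb").read()))` would send a real image the same way the `requests` example does.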