Semantic Segmentation¶
The Task¶
Semantic Segmentation, or image segmentation, is the task of performing classification at the pixel level, meaning each pixel is associated with a given class. See more: https://paperswithcode.com/task/semantic-segmentation
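Concretely, for an H x W input image the model predicts an H x W mask in which every entry is a class index. Here's a minimal, purely illustrative sketch (the classes and values below are made up for demonstration, not taken from any dataset):

import torch

# A hypothetical 4x4 prediction for a 3-class problem: 0=road, 1=car, 2=sky.
mask = torch.tensor([
    [2, 2, 2, 2],
    [2, 1, 1, 2],
    [0, 1, 1, 0],
    [0, 0, 0, 0],
])
# Each pixel carries exactly one class label.
print(mask.shape)     # torch.Size([4, 4])
print(mask.unique())  # tensor([0, 1, 2])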
Example¶
Let’s look at an example using a data set generated with the CARLA driving simulator. The data was generated as part of the Kaggle Lyft Udacity Challenge. The data contains one folder of images and another folder with the corresponding segmentation masks. Here’s the structure:
data
├── CameraRGB
│ ├── F61-1.png
│ ├── F61-2.png
│ ...
└── CameraSeg
├── F61-1.png
├── F61-2.png
...
Once we’ve downloaded the data using download_data(), we create the SemanticSegmentationData.
We select a pre-trained mobilenetv3_large_100 backbone with an fpn head to use for our SemanticSegmentation task and fine-tune on the CARLA data.
We then use the trained SemanticSegmentation for inference. You can check the available pretrained weights for a backbone like this: SemanticSegmentation.available_pretrained_weights("resnet18").
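For example, you can list the options before choosing a configuration. This is a quick sketch; the exact strings returned depend on your installed Flash version:

from flash.image import SemanticSegmentation

print(SemanticSegmentation.available_backbones())  # e.g. includes "mobilenetv3_large_100"
print(SemanticSegmentation.available_heads())      # e.g. includes "fpn"
print(SemanticSegmentation.available_pretrained_weights("resnet18"))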
Finally, we save the model.
Here’s the full example:
import torch
import flash
from flash.core.data.utils import download_data
from flash.image import SemanticSegmentation, SemanticSegmentationData
# 1. Create the DataModule
# The data was generated with the CARLA self-driving simulator as part of the Kaggle Lyft Udacity Challenge.
# More info here: https://www.kaggle.com/kumaresanmanickavelu/lyft-udacity-challenge
download_data(
    "https://github.com/ongchinkiat/LyftPerceptionChallenge/releases/download/v0.1/carla-capture-20180513A.zip",
    "./data",
)
datamodule = SemanticSegmentationData.from_folders(
    train_folder="data/CameraRGB",
    train_target_folder="data/CameraSeg",
    val_split=0.1,
    transform_kwargs=dict(image_size=(256, 256)),
    num_classes=21,
    batch_size=4,
)
# 2. Build the task
model = SemanticSegmentation(
    backbone="mobilenetv3_large_100",
    head="fpn",
    num_classes=datamodule.num_classes,
)
# 3. Create the trainer and finetune the model
trainer = flash.Trainer(max_epochs=3, gpus=torch.cuda.device_count())
trainer.finetune(model, datamodule=datamodule, strategy="freeze")
# 4. Segment a few images!
datamodule = SemanticSegmentationData.from_files(
    predict_files=[
        "data/CameraRGB/F61-1.png",
        "data/CameraRGB/F62-1.png",
        "data/CameraRGB/F63-1.png",
    ],
    batch_size=3,
)
predictions = trainer.predict(model, datamodule=datamodule)
print(predictions)
# 5. Save the model!
trainer.save_checkpoint("semantic_segmentation_model.pt")
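The saved checkpoint can later be restored with load_from_checkpoint, here assuming the local file written in step 5 above:

from flash.image import SemanticSegmentation

# Reload the fine-tuned task from the checkpoint saved in step 5.
model = SemanticSegmentation.load_from_checkpoint("semantic_segmentation_model.pt")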
To learn how to view the available backbones / heads for this task, see Backbones and Heads.
Flash Zero¶
The semantic segmentation task can be used directly from the command line with zero code using Flash Zero. You can run the above example with:
flash semantic_segmentation
To view configuration options and options for running the semantic segmentation task with your own data, use:
flash semantic_segmentation --help
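Flash Zero also accepts trainer arguments on the command line. For instance, assuming the standard Flash Zero trainer flags apply to this task, you could override the number of epochs with:

flash semantic_segmentation --trainer.max_epochs 3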
Serving¶
The SemanticSegmentation task is servable. This means you can call .serve to serve your Task.
Here’s an example:
from flash.image import SemanticSegmentation
from flash.image.segmentation.output import SegmentationLabelsOutput
model = SemanticSegmentation.load_from_checkpoint(
    "https://flash-weights.s3.amazonaws.com/0.7.0/semantic_segmentation_model.pt"
)
model.output = SegmentationLabelsOutput(visualize=False)
model.serve()
You can now perform inference from your client like this:
import base64
from pathlib import Path
import requests
import flash
with (Path(flash.ASSETS_ROOT) / "road.png").open("rb") as f:
    imgstr = base64.b64encode(f.read()).decode("UTF-8")
body = {"session": "UUID", "payload": {"inputs": {"data": imgstr}}}
resp = requests.post("http://127.0.0.1:8000/predict", json=body)
print(resp.json())