Semantic Segmentation

The Task

Semantic segmentation, or image segmentation, is the task of performing classification at the pixel level, meaning each pixel is associated with a given class. See more: https://paperswithcode.com/task/semantic-segmentation

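To make this concrete, here is a minimal, framework-agnostic sketch (the shapes and number of classes are arbitrary placeholders): the input is an RGB image tensor and the target is a mask with one class index per pixel.

import torch

num_classes = 21  # placeholder; depends on your dataset
image = torch.rand(3, 256, 256)  # RGB input image (C, H, W)
mask = torch.randint(0, num_classes, (256, 256))  # one class index per pixel (H, W)

print(image.shape)  # torch.Size([3, 256, 256])
print(mask.shape)   # torch.Size([256, 256])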

Example

Let’s look at an example using a data set generated with the CARLA driving simulator. The data was generated as part of the Kaggle Lyft Udacity Challenge. The data contains one folder of images and another folder with the corresponding segmentation masks. Here’s the structure:

data
├── CameraRGB
│   ├── F61-1.png
│   ├── F61-2.png
│       ...
└── CameraSeg
    ├── F61-1.png
    ├── F61-2.png
        ...

Once we’ve downloaded the data using download_data(), we create the SemanticSegmentationData. We select a pre-trained mobilenetv3_large_100 backbone with an fpn head to use for our SemanticSegmentation task and fine-tune on the CARLA data. We then use the trained SemanticSegmentation for inference. You can check the available pretrained weights for a backbone like this: SemanticSegmentation.available_pretrained_weights("resnet18"). Finally, we save the model. Here’s the full example:

import torch

import flash
from flash.core.data.utils import download_data
from flash.image import SemanticSegmentation, SemanticSegmentationData

# 1. Create the DataModule
# The data was generated with the CARLA self-driving simulator as part of the Kaggle Lyft Udacity Challenge.
# More info here: https://www.kaggle.com/kumaresanmanickavelu/lyft-udacity-challenge
download_data(
    "https://github.com/ongchinkiat/LyftPerceptionChallenge/releases/download/v0.1/carla-capture-20180513A.zip",
    "./data",
)

datamodule = SemanticSegmentationData.from_folders(
    train_folder="data/CameraRGB",
    train_target_folder="data/CameraSeg",
    val_split=0.1,
    image_size=(256, 256),
    num_classes=21,
)

# 2. Build the task
model = SemanticSegmentation(
    backbone="mobilenetv3_large_100",
    head="fpn",
    num_classes=datamodule.num_classes,
)

# 3. Create the trainer and finetune the model
trainer = flash.Trainer(max_epochs=3, gpus=torch.cuda.device_count())
trainer.finetune(model, datamodule=datamodule, strategy="freeze")

# 4. Segment a few images!
predictions = model.predict(
    [
        "data/CameraRGB/F61-1.png",
        "data/CameraRGB/F62-1.png",
        "data/CameraRGB/F63-1.png",
    ]
)
print(predictions)

# 5. Save the model!
trainer.save_checkpoint("semantic_segmentation_model.pt")
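
If you want to experiment with a different backbone or head, you can query the registries on the task class. This is a quick sketch; the helpers below (available_backbones, available_heads, and the available_pretrained_weights call mentioned earlier) are assumed to be present in your installed Flash version, and the names they return vary between releases:

from flash.image import SemanticSegmentation

# List the registered backbones and heads for the task.
print(SemanticSegmentation.available_backbones())
print(SemanticSegmentation.available_heads())

# List the pretrained weights available for a given backbone, e.g. "resnet18".
print(SemanticSegmentation.available_pretrained_weights("resnet18"))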

Flash Zero

The semantic segmentation task can be used directly from the command line with zero code using Flash Zero. You can run the above example with:

flash semantic_segmentation

To view the configuration options, including those for running the semantic segmentation task with your own data, use:

flash semantic_segmentation --help

Loading Data

This section details the available ways to load your own data into the SemanticSegmentationData.

from_folders

Construct the SemanticSegmentationData from folders.

The supported file extensions are: .jpg, .jpeg, .png, .ppm, .bmp, .pgm, .tif, .tiff, .webp.

For train, test, and val data, we expect a folder containing inputs and another folder containing the masks. Here’s the required structure:

train_folder
├── inputs
│   ├── file1.jpg
│   ├── file2.jpg
│   ...
└── masks
    ├── file1.jpg
    ├── file2.jpg
    ...

For prediction, the folder is expected to contain the files for inference, like this:

predict_folder
├── file1.jpg
├── file2.jpg
...

Example:

data_module = SemanticSegmentationData.from_folders(
    train_folder = "./train_folder/inputs",
    train_target_folder = "./train_folder/masks",
    predict_folder = "./predict_folder",
    ...
)

from_files

Construct the SemanticSegmentationData from a list of input images and a corresponding list of target images.

The supported file extensions are: .jpg, .jpeg, .png, .ppm, .bmp, .pgm, .tif, .tiff, .webp.

Example:

train_files = ["file1.jpg", "file2.jpg", "file3.jpg", ...]
train_targets = ["mask1.jpg", "mask2.jpg", "mask3.jpg", ...]

datamodule = SemanticSegmentationData.from_files(
    train_files = train_files,
    train_targets = train_targets,
    ...
)

from_datasets

Construct the SemanticSegmentationData from the given datasets for each stage.

Example:

from torch.utils.data.dataset import Dataset

train_dataset: Dataset = ...

datamodule = SemanticSegmentationData.from_datasets(
    train_dataset = train_dataset,
    ...
)

Note

The __getitem__ of your datasets should return a dictionary with "input" and "target" keys which map to the input and target images as tensors.
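
As an illustration, here is a minimal sketch of such a dataset. The class name and the file-loading logic are hypothetical placeholders; the only requirement from the note above is that __getitem__ returns a dictionary with "input" and "target" tensors.

from PIL import Image
from torch.utils.data import Dataset
from torchvision.transforms import functional as TF


class SegmentationDictDataset(Dataset):
    # Hypothetical dataset returning {"input": image_tensor, "target": mask_tensor}.

    def __init__(self, image_paths, mask_paths):
        self.image_paths = image_paths
        self.mask_paths = mask_paths

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, index):
        image = TF.to_tensor(Image.open(self.image_paths[index]).convert("RGB"))
        # Masks are stored with one class index per pixel.
        mask = TF.pil_to_tensor(Image.open(self.mask_paths[index])).squeeze(0).long()
        return {"input": image, "target": mask}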


Serving

The SemanticSegmentation task is servable. This means you can call .serve to serve your Task. Here’s an example:

from flash.image import SemanticSegmentation
from flash.image.segmentation.serialization import SegmentationLabels

model = SemanticSegmentation.load_from_checkpoint(
    "https://flash-weights.s3.amazonaws.com/semantic_segmentation_model.pt"
)
model.serializer = SegmentationLabels(visualize=False)
model.serve()

You can now perform inference from your client like this:

import base64
from pathlib import Path

import requests

import flash

with (Path(flash.ASSETS_ROOT) / "road.png").open("rb") as f:
    imgstr = base64.b64encode(f.read()).decode("UTF-8")

body = {"session": "UUID", "payload": {"inputs": {"data": imgstr}}}
resp = requests.post("http://127.0.0.1:8000/predict", json=body)
print(resp.json())