Beta
Point cloud segmentation is currently in Beta. The API and functionality may change without warning in future releases.
Point Cloud Segmentation¶
The Task¶
A point cloud is a set of data points in space, usually described by x, y and z coordinates.
Point cloud segmentation is the task of performing classification at the point level, meaning each point is assigned to a given class. The current integration builds on top of Open3D-ML.
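As a minimal sketch of what point-level classification means, the NumPy snippet below builds a toy point cloud as an (N, 3) array with one integer class id per point. The class names and values are made up for illustration and are not part of Flash or Open3D-ML.

```python
import numpy as np

# Toy point cloud: 5 points, one row per point, columns are x, y, z.
points = np.array(
    [
        [0.0, 0.0, 0.0],
        [1.0, 0.5, 0.1],
        [2.0, 1.0, 0.0],
        [0.5, 3.0, 1.8],
        [0.6, 3.1, 1.9],
    ],
    dtype=np.float32,
)

# Hypothetical class names; a segmentation model predicts one id per point.
class_names = ["road", "car", "vegetation"]
labels = np.array([0, 0, 1, 2, 2], dtype=np.int64)

# A segmentation result has exactly one label for every point.
assert labels.shape[0] == points.shape[0]

# Count how many points fall in each class.
counts = np.bincount(labels, minlength=len(class_names))
for name, count in zip(class_names, counts):
    print(f"{name}: {count} points")
```

The key property is the one-to-one pairing between rows of `points` and entries of `labels`; everything else in the snippet is illustrative.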
Example¶
Let’s look at an example using a dataset generated from the KITTI Vision Benchmark.
The data is a tiny subset of the original dataset and contains sequences of point clouds.
It consists of multiple folders, one for each sequence, and a meta.yaml file describing the classes and their official color map.
Each sequence should contain one folder for scans and one folder for labels, plus a pose.txt file to re-align the sequence if required.
Here’s the structure:
data
├── meta.yaml
├── 00
│   ├── scans
│   │   ├── 00000.bin
│   │   ├── 00001.bin
│   │   ...
│   ├── labels
│   │   ├── 00000.label
│   │   ├── 00001.label
│   │   ...
│   └── pose.txt
│ ...
│
└── XX
    ├── scans
    │   ├── 00000.bin
    │   ├── 00001.bin
    │   ...
    ├── labels
    │   ├── 00000.label
    │   ├── 00001.label
    │   ...
    └── pose.txt
Learn more: http://www.semantic-kitti.org/dataset.html
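To make the file layout more concrete, here is a hedged sketch of reading one scan and its label file with NumPy. It assumes the files follow the SemanticKITTI binary format (each scan stored as float32 values in groups of four: x, y, z, remission; each label stored as a uint32 whose lower 16 bits hold the semantic class and upper 16 bits the instance id). The helper names `read_scan` and `read_labels` are hypothetical and not part of Flash; the snippet writes a tiny synthetic pair of files so it runs without the dataset.

```python
import tempfile
from pathlib import Path

import numpy as np


def read_scan(path):
    """Read a scan file as an (N, 4) float32 array: x, y, z, remission."""
    return np.fromfile(str(path), dtype=np.float32).reshape(-1, 4)


def read_labels(path):
    """Read a label file and split it into semantic and instance ids."""
    raw = np.fromfile(str(path), dtype=np.uint32)
    semantic = raw & 0xFFFF  # lower 16 bits: semantic class id
    instance = raw >> 16  # upper 16 bits: instance id
    return semantic, instance


# Write a tiny synthetic scan/label pair so the sketch is self-contained.
with tempfile.TemporaryDirectory() as tmp:
    scan_path = Path(tmp) / "00000.bin"
    label_path = Path(tmp) / "00000.label"

    fake_points = np.arange(12, dtype=np.float32).reshape(3, 4)
    fake_labels = np.array([10, 40, (7 << 16) | 10], dtype=np.uint32)
    fake_points.tofile(str(scan_path))
    fake_labels.tofile(str(label_path))

    points = read_scan(scan_path)
    semantic, instance = read_labels(label_path)

print(points.shape)
print(semantic.tolist(), instance.tolist())
```

With real data, the number of rows in the scan and the number of entries in the label file should match, since there is one label per point.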
Once we’ve downloaded the data using download_data(), we create the PointCloudSegmentationData.
We select a pre-trained randlanet_semantic_kitti backbone for our PointCloudSegmentation task.
We then use the trained PointCloudSegmentation for inference.
Finally, we save the model.
Here’s the full example:
import flash
import torch
from flash.core.data.utils import download_data
from flash.pointcloud import PointCloudSegmentation, PointCloudSegmentationData

# 1. Create the DataModule
# Dataset Credit: http://www.semantic-kitti.org/
download_data("https://pl-flash-data.s3.amazonaws.com/SemanticKittiTiny.zip", "data/")

datamodule = PointCloudSegmentationData.from_folders(
    train_folder="data/SemanticKittiTiny/train",
    val_folder="data/SemanticKittiTiny/val",
    batch_size=4,
)

# 2. Build the task
model = PointCloudSegmentation(backbone="randlanet_semantic_kitti", num_classes=datamodule.num_classes)

# 3. Create the trainer and finetune the model
trainer = flash.Trainer(
    max_epochs=1, limit_train_batches=1, limit_val_batches=1, num_sanity_val_steps=0, gpus=torch.cuda.device_count()
)
trainer.fit(model, datamodule)

# 4. Predict what's within a few PointClouds?
datamodule = PointCloudSegmentationData.from_files(
    predict_files=[
        "data/SemanticKittiTiny/predict/000000.bin",
        "data/SemanticKittiTiny/predict/000001.bin",
    ],
    batch_size=4,
)
predictions = trainer.predict(model, datamodule=datamodule)
print(predictions)

# 5. Save the model!
trainer.save_checkpoint("pointcloud_segmentation_model.pt")
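The exact structure of `predictions` depends on the Flash version, so it is not reproduced here. Assuming per-point class ids have already been extracted from it into a NumPy array, the sketch below maps them to RGB colors for visualization. The color map is hypothetical and only stands in for the one described in meta.yaml.

```python
import numpy as np

# Hypothetical color map: class id -> RGB (0-255). A real one would be
# read from meta.yaml; these values are placeholders for illustration.
color_map = {
    0: (0, 0, 0),
    1: (245, 150, 100),
    2: (0, 175, 0),
}

# Assumed model output: one predicted class id per point.
predicted_ids = np.array([0, 1, 1, 2, 0], dtype=np.int64)

# Build a lookup table so colors can be assigned with a single indexing step.
lut = np.zeros((max(color_map) + 1, 3), dtype=np.uint8)
for class_id, rgb in color_map.items():
    lut[class_id] = rgb

colors = lut[predicted_ids]  # shape (N, 3), one RGB triple per point
print(colors.shape)
```

A lookup table keeps the mapping vectorized, which matters when a single scan contains a large number of points.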
To learn how to view the available backbones / heads for this task, see Backbones and Heads.
Flash Zero¶
The point cloud segmentation task can be used directly from the command line with zero code using Flash Zero. You can run the above example with:
flash pointcloud_segmentation
To view configuration options and options for running the point cloud segmentation task with your own data, use:
flash pointcloud_segmentation --help