Point Cloud Object Detection
The Task
A point cloud is a set of data points in space, usually described by x, y, and z coordinates.
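As a minimal sketch of this definition, a point cloud can be held in a NumPy array with one row per point and one column per coordinate. The values below are made up for illustration and do not come from the KITTI data.

```python
import numpy as np

# A toy point cloud: one row per point, columns are x, y, z (illustrative values)
points = np.array(
    [
        [0.0, 0.0, 0.0],
        [2.0, 0.0, 1.0],
        [2.0, 4.0, 1.0],
        [0.0, 4.0, 2.0],
    ]
)

# The centroid is the mean position over all points
centroid = points.mean(axis=0)

# The axis-aligned extent is the size of the tightest box around the points
extent = points.max(axis=0) - points.min(axis=0)

print(points.shape, centroid, extent)
```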
Point cloud object detection is the task of identifying 3D objects in point clouds, along with their associated classes and 3D bounding boxes.
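To make the notion of a 3D bounding box concrete, here is a small sketch that assumes a common parameterization: a center, a (length, width, height) size, and a yaw rotation about the z axis. The helper names `box_corners` and `points_in_box` are hypothetical and are not part of Flash or Open3D-ML; the box format those libraries use may differ.

```python
import numpy as np


def _yaw_matrix(yaw):
    """Rotation matrix for a counter-clockwise rotation of `yaw` radians about z."""
    c, s = np.cos(yaw), np.sin(yaw)
    return np.array([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])


def box_corners(center, size, yaw):
    """Return the 8 corners, shape (8, 3), of a box rotated by `yaw` about z."""
    signs = np.array(
        [[sx, sy, sz] for sx in (-1, 1) for sy in (-1, 1) for sz in (-1, 1)],
        dtype=float,
    )
    # Corner offsets in the box's own frame, before rotation
    local = signs * np.asarray(size, dtype=float) / 2.0
    # Rotate into the world frame, then translate to the box center
    return local @ _yaw_matrix(yaw).T + np.asarray(center, dtype=float)


def points_in_box(points, center, size, yaw):
    """Return a boolean mask marking the points that fall inside the box."""
    # Map world coordinates into the box frame: local = R^T (p - center)
    local = (np.asarray(points, dtype=float) - np.asarray(center, dtype=float)) @ _yaw_matrix(yaw)
    half = np.asarray(size, dtype=float) / 2.0
    return np.all(np.abs(local) <= half, axis=1)


# A 4 x 2 x 2 box at the origin, turned by 90 degrees about z
corners = box_corners(center=(0, 0, 0), size=(4, 2, 2), yaw=np.pi / 2)
mask = points_in_box(
    [[0.0, 1.5, 0.0], [1.5, 0.0, 0.0]], center=(0, 0, 0), size=(4, 2, 2), yaw=np.pi / 2
)
print(corners.shape, mask)
```

After the 90 degree turn the long side of the box lies along y, so the first query point is inside the box and the second is not.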
The current integration builds on top of Open3D-ML.
Example
Let’s look at an example using a dataset generated from the KITTI Vision Benchmark. The data is a tiny subset of the original dataset and contains sequences of point clouds.
The data contains:
- one folder for scans
- one folder for scan calibrations
- one folder for labels
- a meta.yaml file describing the classes and their official associated color map.
Here’s the structure:
data
├── meta.yaml
├── train
│   ├── scans
│   │   ├── 00000.bin
│   │   ├── 00001.bin
│   │   ...
│   ├── calibs
│   │   ├── 00000.txt
│   │   ├── 00001.txt
│   │   ...
│   ├── labels
│   │   ├── 00000.txt
│   │   ├── 00001.txt
│   ...
├── val
│   ...
├── predict
    ├── scans
    │   ├── 00000.bin
    │   ├── 00001.bin
    │
    ├── calibs
    │   ├── 00000.txt
    │   ├── 00001.txt
    ├── meta.yaml
Learn more: http://www.semantic-kitti.org/dataset.html
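The layout above pairs each scan with a calibration file and a label file that share the same stem. The sketch below illustrates that pairing on a throwaway synthetic split; it is not how Flash loads the data. The helpers `write_tiny_split`, `index_split`, and `load_scan` are hypothetical, the text files hold placeholders, and the scan format (float32 rows of x, y, z, intensity, as in KITTI Velodyne files) is an assumption about this dataset.

```python
import tempfile
from pathlib import Path

import numpy as np


def write_tiny_split(root, n_scans=2):
    """Create a toy split with scans/, calibs/ and labels/ (synthetic content)."""
    for sub in ("scans", "calibs", "labels"):
        (root / sub).mkdir(parents=True, exist_ok=True)
    rng = np.random.default_rng(0)
    for i in range(n_scans):
        stem = f"{i:05d}"
        # Assumption: each scan is a flat float32 buffer of (x, y, z, intensity) rows
        rng.random((10, 4), dtype=np.float32).tofile(str(root / "scans" / f"{stem}.bin"))
        (root / "calibs" / f"{stem}.txt").write_text("placeholder calibration\n")
        (root / "labels" / f"{stem}.txt").write_text("placeholder label\n")


def index_split(root):
    """Match every scan with its calibration and label file by shared stem."""
    samples = []
    for scan in sorted((root / "scans").glob("*.bin")):
        calib = root / "calibs" / f"{scan.stem}.txt"
        label = root / "labels" / f"{scan.stem}.txt"
        if not calib.exists():
            raise FileNotFoundError(f"missing calibration for {scan.name}")
        # Labels are optional here because the predict folder has none
        samples.append({"scan": scan, "calib": calib, "label": label if label.exists() else None})
    return samples


def load_scan(path):
    """Read a scan into an (N, 4) array under the float32 assumption above."""
    return np.fromfile(str(path), dtype=np.float32).reshape(-1, 4)


with tempfile.TemporaryDirectory() as tmp:
    train = Path(tmp) / "train"
    write_tiny_split(train)
    samples = index_split(train)
    num_samples = len(samples)
    first_stem = samples[0]["scan"].stem
    first_scan = load_scan(samples[0]["scan"])

print(num_samples, first_stem, first_scan.shape)
```

Raising on a missing calibration file while tolerating a missing label is a design choice of this sketch; it mirrors the fact that the predict folder ships scans and calibrations only.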
Once we’ve downloaded the data using download_data(), we create the PointCloudObjectDetectorData. We select a pre-trained pointpillars_kitti backbone for our PointCloudObjectDetector task and train it briefly. We then use the trained PointCloudObjectDetector for inference. Finally, we save the model.
Here’s the full example:
import torch

import flash
from flash.core.data.utils import download_data
from flash.pointcloud import PointCloudObjectDetector, PointCloudObjectDetectorData

# 1. Create the DataModule
# Dataset Credit: http://www.semantic-kitti.org/
download_data("https://pl-flash-data.s3.amazonaws.com/KITTI_tiny.zip", "data/")

datamodule = PointCloudObjectDetectorData.from_folders(
    train_folder="data/KITTI_Tiny/Kitti/train",
    val_folder="data/KITTI_Tiny/Kitti/val",
    batch_size=4,
)

# 2. Build the task
model = PointCloudObjectDetector(backbone="pointpillars_kitti", num_classes=datamodule.num_classes)

# 3. Create the trainer and finetune the model
trainer = flash.Trainer(
    max_epochs=1,
    limit_train_batches=1,
    limit_val_batches=1,
    num_sanity_val_steps=0,
    gpus=torch.cuda.device_count(),
)
trainer.fit(model, datamodule)

# 4. Predict what's within a few PointClouds?
datamodule = PointCloudObjectDetectorData.from_files(
    predict_files=[
        "data/KITTI_Tiny/Kitti/predict/scans/000000.bin",
        "data/KITTI_Tiny/Kitti/predict/scans/000001.bin",
    ],
    batch_size=4,
)
predictions = trainer.predict(model, datamodule=datamodule)
print(predictions)

# 5. Save the model!
trainer.save_checkpoint("pointcloud_detection_model.pt")
Flash Zero
The point cloud object detector can be used directly from the command line, with zero code, using Flash Zero. You can run the above example with:
flash pointcloud_detection
To view configuration options and options for running the point cloud object detector with your own data, use:
flash pointcloud_detection --help