# Source code for flash.pointcloud.detection.open3d_ml.input
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from os.path import basename, dirname, exists, isdir, isfile, join
from typing import Any, Dict, List, Optional, Type, Union
import yaml
from flash.core.data.io.input import BaseDataFormat, Input
from flash.core.utilities.imports import _TOPIC_POINTCLOUD_AVAILABLE
if _TOPIC_POINTCLOUD_AVAILABLE:
from open3d._ml3d.datasets.kitti import KITTI, DataProcessing
class PointCloudObjectDetectionDataFormat(BaseDataFormat):
    """Names of the dataset layouts that can be selected for point cloud object detection.

    The members are used as keys when looking up the loader class for a given format.
    """

    # KITTI-style layout, handled by ``KITTIPointCloudObjectDetectorLoader``.
    KITTI = "kitti"
class BasePointCloudObjectDetectorLoader:
    """Interface shared by the point cloud object detection loaders.

    Only ``image_size`` is stored here; every loading hook raises ``NotImplementedError`` and is
    meant to be overridden by a format-specific sub-class.
    """

    # TODO: Do we need this?
    def __init__(self, image_size: tuple = (375, 1242), **loader_kwargs):
        """Remember ``image_size``; any additional ``loader_kwargs`` are accepted and ignored."""
        self.image_size = image_size

    def load_data(self, folder: str, dataset: Input):
        """Collect the per-sample metadata found in ``folder`` (to be implemented by sub-classes)."""
        raise NotImplementedError()

    def load_sample(self, metadata: Dict[str, str], has_label: bool = True) -> Any:
        """Read the sample described by ``metadata`` (to be implemented by sub-classes)."""
        raise NotImplementedError()

    def predict_load_data(self, data: Union[str, List[str]], dataset: Input):
        """Collect the per-sample metadata for prediction (to be implemented by sub-classes)."""
        raise NotImplementedError()

    def predict_load_sample(self, metadata: Any):
        """Read one prediction sample (to be implemented by sub-classes)."""
        raise NotImplementedError()
class KITTIPointCloudObjectDetectorLoader(BasePointCloudObjectDetectorLoader):
    """Loader for point cloud detection data stored in a KITTI-style folder layout.

    A split folder must contain exactly three sub-directories (scans, labels and calibrations,
    named by the constructor arguments) and its parent directory must contain a ``meta.yaml``
    file describing the classes.
    """

    # Parsed content of ``meta.yaml``; set by ``load_meta``.
    meta: dict

    def __init__(
        self,
        image_size: tuple = (375, 1242),
        scans_folder_name: Optional[str] = "scans",
        labels_folder_name: Optional[str] = "labels",
        calibrations_folder_name: Optional[str] = "calibs",
    ):
        """Store the image size and the names of the three expected sub-directories.

        Args:
            image_size: Forwarded to ``DataProcessing.remove_outside_points`` in ``load_sample``.
            scans_folder_name: Name of the sub-directory holding the scans.
            labels_folder_name: Name of the sub-directory holding the labels.
            calibrations_folder_name: Name of the sub-directory holding the calibrations.
        """
        super().__init__(image_size)
        self.scans_folder_name = scans_folder_name
        self.labels_folder_name = labels_folder_name
        self.calibrations_folder_name = calibrations_folder_name

    def load_meta(self, root_dir, dataset: Input):
        """Read ``<root_dir>/meta.yaml`` and copy the class information onto ``dataset``.

        Sets ``dataset.num_classes``, ``dataset.label_to_names`` and ``dataset.color_map``.

        Raises:
            ValueError: If ``meta.yaml`` is missing or has no ``label_to_names`` field.
            KeyError: If ``meta.yaml`` has no ``color_map`` field.
        """
        meta_file = join(root_dir, "meta.yaml")
        if not exists(meta_file):
            raise ValueError(f"The {root_dir} should contain a `meta.yaml` file about the classes.")

        with open(meta_file) as f:
            self.meta = yaml.safe_load(f)

        if "label_to_names" not in self.meta:
            raise ValueError(
                f"The {root_dir} should contain a `meta.yaml` file about the classes with the field `label_to_names`."
            )

        dataset.num_classes = len(self.meta["label_to_names"])
        dataset.label_to_names = self.meta["label_to_names"]
        dataset.color_map = self.meta["color_map"]

    def load_data(self, folder: str, dataset: Input):
        """List the scans, labels and calibrations of ``folder`` and pair them up.

        Also loads the class metadata from the parent directory of ``folder`` and stores the
        scan paths in ``dataset.path_list``.

        Returns:
            One dict per sample with the keys ``scan_path``, ``label_path`` and
            ``calibration_path``.

        Raises:
            ValueError: If ``folder`` does not contain exactly the three expected sub-directories,
                or if they do not hold the same number of files.
        """
        sub_directories = os.listdir(folder)
        if len(sub_directories) != 3:
            raise ValueError(
                f"Using KITTI Format, the {folder} should contains 3 directories "
                "for ``calibrations``, ``labels`` and ``scans``."
            )

        # Explicit errors rather than ``assert``: assertions are removed under ``python -O``.
        expected = (self.scans_folder_name, self.labels_folder_name, self.calibrations_folder_name)
        missing = [name for name in expected if name not in sub_directories]
        if missing:
            raise ValueError(f"The {folder} is missing the expected sub-directories: {missing}.")

        scans_dir = join(folder, self.scans_folder_name)
        labels_dir = join(folder, self.labels_folder_name)
        calibrations_dir = join(folder, self.calibrations_folder_name)

        # ``os.listdir`` returns entries in arbitrary order, so each listing is sorted before the
        # three lists are zipped together; otherwise a scan could be paired with the wrong label
        # or calibration. NOTE(review): this assumes matching files share the same file stem
        # across the three directories - confirm against the dataset layout.
        scan_paths = [join(scans_dir, f) for f in sorted(os.listdir(scans_dir))]
        label_paths = [join(labels_dir, f) for f in sorted(os.listdir(labels_dir))]
        calibration_paths = [join(calibrations_dir, f) for f in sorted(os.listdir(calibrations_dir))]

        if not len(scan_paths) == len(label_paths) == len(calibration_paths):
            raise ValueError(
                f"The sub-directories of {folder} should contain the same number of files. Found "
                f"{len(scan_paths)} scans, {len(label_paths)} labels and {len(calibration_paths)} calibrations."
            )

        self.load_meta(dirname(folder), dataset)
        dataset.path_list = scan_paths

        return [
            {"scan_path": scan_path, "label_path": label_path, "calibration_path": calibration_path}
            for scan_path, label_path, calibration_path in zip(scan_paths, label_paths, calibration_paths)
        ]

    def load_sample(self, metadata: Dict[str, str], has_label: bool = True) -> Any:
        """Read the scan, calibration and (optionally) label files referenced by ``metadata``.

        Args:
            metadata: Dict with ``scan_path`` and ``calibration_path``, plus ``label_path`` when
                ``has_label`` is true.
            has_label: Whether the label file should be read.

        Returns:
            A ``(data, attr)`` tuple; ``attr`` is also available as ``data["attr"]``.
        """
        pc = KITTI.read_lidar(metadata["scan_path"])
        calib = KITTI.read_calib(metadata["calibration_path"])

        # Stays ``None`` for unlabelled (prediction) samples.
        label = KITTI.read_label(metadata["label_path"], calib) if has_label else None

        reduced_pc = DataProcessing.remove_outside_points(pc, calib["world_cam"], calib["cam_img"], self.image_size)

        attr = {
            "name": basename(metadata["scan_path"]),
            "path": metadata["scan_path"],
            "calibration_path": metadata["calibration_path"],
            "label_path": metadata["label_path"] if has_label else None,
            # Default split; ``predict_load_sample`` overwrites it with "test".
            "split": "val",
        }

        data = {
            "point": reduced_pc,
            "full_point": pc,
            "feat": None,
            "calib": calib,
            "bounding_boxes": label,
            "attr": attr,
        }
        return data, attr

    def load_files(self, scan_paths: Union[str, List[str]], dataset: Input):
        """Build prediction metadata for one scan path or a list of scan paths.

        The calibration path of each scan is derived from the scan path by substituting the scans
        folder name with the calibrations folder name and ``.bin`` with ``.txt``.
        """
        if isinstance(scan_paths, str):
            scan_paths = [scan_paths]

        def clean_fn(path: str) -> str:
            return path.replace(self.scans_folder_name, self.calibrations_folder_name).replace(".bin", ".txt")

        dataset.path_list = scan_paths
        return [{"scan_path": scan_path, "calibration_path": clean_fn(scan_path)} for scan_path in scan_paths]

    def predict_load_data(self, data, dataset: Input):
        """Return the prediction metadata for a scan file or a list of scan files.

        Raises:
            NotImplementedError: If ``data`` is a directory, which is not supported yet.

        Returns:
            The result of ``load_files`` for file inputs, ``None`` for anything else.
        """
        is_single_file = isinstance(data, str) and isfile(data)
        is_file_list = isinstance(data, list) and all(isfile(p) for p in data)
        if is_single_file or is_file_list:
            return self.load_files(data, dataset)

        if isinstance(data, str) and isdir(data):
            raise NotImplementedError

        return None

    def predict_load_sample(self, metadata: Dict[str, str]):
        """Load an unlabelled sample and tag it with the ``test`` split."""
        data, attr = self.load_sample(metadata, has_label=False)
        # hack to prevent manipulation of labels
        attr["split"] = "test"
        return data, attr
class PointCloudObjectDetectorFoldersInput(Input):
    """Input that reads point cloud object detection data from folders or scan files.

    The reading itself is delegated to a format-specific loader taken from ``loaders``.
    """

    # Maps each supported data format to the loader class that handles it.
    loaders: Dict[PointCloudObjectDetectionDataFormat, Type[BasePointCloudObjectDetectorLoader]] = {
        PointCloudObjectDetectionDataFormat.KITTI: KITTIPointCloudObjectDetectorLoader
    }

    def _get_loader(
        self, data_format: Optional[BaseDataFormat] = None, image_size: tuple = (375, 1242), **loader_kwargs: Any
    ) -> BasePointCloudObjectDetectorLoader:
        """Instantiate the loader registered for ``data_format`` (KITTI when it is not given)."""
        return self.loaders[data_format or PointCloudObjectDetectionDataFormat.KITTI](
            image_size=image_size, **loader_kwargs
        )

    def _validate_data(self, folder: str) -> None:
        """Check that ``folder`` is a string naming an existing directory.

        Raises:
            ValueError: If ``folder`` is not a string or is not an existing directory.
        """
        msg = f"The provided dataset for stage {self._running_stage} should be a folder. Found {folder}."
        if not isinstance(folder, str) or not isdir(folder):
            raise ValueError(msg)

    def load_data(
        self,
        folder: str,
        data_format: Optional[BaseDataFormat] = None,
        image_size: tuple = (375, 1242),
        **loader_kwargs: Any,
    ) -> Any:
        """Validate ``folder``, create the loader and return what its ``load_data`` returns."""
        self._validate_data(folder)
        self.loader = self._get_loader(data_format, image_size, **loader_kwargs)
        return self.loader.load_data(folder, self)

    def load_sample(self, metadata: Dict[str, str]) -> Any:
        """Load one sample through the loader and apply the optional transform hooks.

        Returns:
            A dict with the (possibly transformed) sample under ``data`` and its attributes
            under ``attr``.
        """
        data, metadata = self.loader.load_sample(metadata)

        # Both hooks are optional attributes; they are skipped when absent or falsy.
        input_transform_fn = getattr(self, "input_transform_fn", None)
        if input_transform_fn:
            data = input_transform_fn(data, metadata)

        transform_fn = getattr(self, "transform_fn", None)
        if transform_fn:
            data = transform_fn(data, metadata)

        return {"data": data, "attr": metadata}

    def _validate_predict_data(self, data: Union[str, List[str]]) -> None:
        """Check that ``data`` is an existing file, an existing folder or a list of existing files.

        Raises:
            ValueError: If ``data`` is none of the above.
        """
        msg = f"The provided predict data should be a either a folder or a single/list of scan path(s). Found {data}."

        if isinstance(data, str):
            # A path is never both a file and a directory, so it is valid as soon as it is either
            # one. Rejecting when it is "not a file OR not a directory" refused every string.
            valid = isfile(data) or isdir(data)
        elif isinstance(data, list):
            valid = all(isfile(p) for p in data)
        else:
            valid = False

        if not valid:
            raise ValueError(msg)

    def predict_load_data(
        self,
        data: Union[str, List[str]],
        data_format: Optional[BaseDataFormat] = None,
        image_size: tuple = (375, 1242),
        **loader_kwargs: Any,
    ) -> Any:
        """Validate ``data``, create the loader and return what its ``predict_load_data`` returns."""
        self._validate_predict_data(data)
        self.loader = self._get_loader(data_format, image_size, **loader_kwargs)
        return self.loader.predict_load_data(data, self)

    def predict_load_sample(self, metadata: Dict[str, str]) -> Any:
        """Load one prediction sample through the loader and apply the optional transform hooks.

        Returns:
            A dict with the (possibly transformed) sample under ``data`` and its attributes
            under ``attr``.
        """
        data, metadata = self.loader.predict_load_sample(metadata)

        # Both hooks are optional attributes; they are skipped when absent or falsy.
        input_transform_fn = getattr(self, "input_transform_fn", None)
        if input_transform_fn:
            data = input_transform_fn(data, metadata)

        transform_fn = getattr(self, "transform_fn", None)
        if transform_fn:
            data = transform_fn(data, metadata)

        return {"data": data, "attr": metadata}