Shortcuts

Source code for flash.pointcloud.detection.open3d_ml.input

# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from os.path import basename, dirname, exists, isdir, isfile, join
from typing import Any, Dict, List, Optional, Type, Union

import yaml
from pytorch_lightning.utilities.exceptions import MisconfigurationException

from flash.core.data.io.input import BaseDataFormat, Input
from flash.core.utilities.imports import _POINTCLOUD_AVAILABLE

if _POINTCLOUD_AVAILABLE:
    from open3d._ml3d.datasets.kitti import DataProcessing, KITTI


class PointCloudObjectDetectionDataFormat(BaseDataFormat):
    """Names of the dataset layouts available for point cloud object detection."""

    KITTI = "kitti"


class BasePointCloudObjectDetectorLoader:
    """Base interface for point cloud object detection loaders.

    Every loading hook below raises ``NotImplementedError``; subclasses are expected to override them.

    Args:
        image_size: Stored on the instance as ``image_size``.
        **loader_kwargs: Accepted for signature compatibility; not used by this class.
    """

    # TODO: Do we need this?

    def __init__(self, image_size: tuple = (375, 1242), **loader_kwargs):
        self.image_size = image_size

    def load_data(self, folder: str, dataset: Input):
        """Build the list of sample metadata for ``folder``. Must be overridden."""
        raise NotImplementedError

    def load_sample(self, metadata: Dict[str, str], has_label: bool = True) -> Any:
        """Load one sample described by ``metadata``. Must be overridden."""
        raise NotImplementedError

    def predict_load_data(self, data: Union[str, List[str]], dataset: Input):
        """Build the list of sample metadata for prediction inputs. Must be overridden."""
        raise NotImplementedError

    def predict_load_sample(self, metadata: Any):
        """Load one prediction sample described by ``metadata``. Must be overridden."""
        raise NotImplementedError


class KITTIPointCloudObjectDetectorLoader(BasePointCloudObjectDetectorLoader):
    """Loader for point cloud object detection data stored in KITTI-style folders.

    A dataset folder must hold exactly three sub-directories (scans, labels and calibrations) and its parent
    directory must hold a ``meta.yaml`` file with at least the ``label_to_names`` field.

    Args:
        image_size: Passed to ``DataProcessing.remove_outside_points`` when reducing a scan.
        scans_folder_name: Name of the sub-directory containing the scans.
        labels_folder_name: Name of the sub-directory containing the labels.
        calibrations_folder_name: Name of the sub-directory containing the calibrations.
    """

    meta: dict

    def __init__(
        self,
        image_size: tuple = (375, 1242),
        scans_folder_name: Optional[str] = "scans",
        labels_folder_name: Optional[str] = "labels",
        calibrations_folder_name: Optional[str] = "calibs",
    ):
        super().__init__(image_size)
        self.scans_folder_name = scans_folder_name
        self.labels_folder_name = labels_folder_name
        self.calibrations_folder_name = calibrations_folder_name

    def load_meta(self, root_dir, dataset: Input):
        """Read ``meta.yaml`` from ``root_dir`` and copy the class information onto ``dataset``.

        Sets ``dataset.num_classes``, ``dataset.label_to_names`` and ``dataset.color_map``.

        Raises:
            MisconfigurationException: If ``meta.yaml`` is missing or has no ``label_to_names`` field.
        """
        meta_file = join(root_dir, "meta.yaml")
        if not exists(meta_file):
            raise MisconfigurationException(f"The {root_dir} should contain a `meta.yaml` file about the classes.")

        with open(meta_file) as f:
            # ``safe_load`` yields ``None`` for an empty document; fall back to an empty mapping so the check
            # below reports the missing field instead of failing with a ``TypeError``.
            self.meta = yaml.safe_load(f) or {}

        if "label_to_names" not in self.meta:
            raise MisconfigurationException(
                f"The {root_dir} should contain a `meta.yaml` file about the classes with the field `label_to_names`."
            )

        dataset.num_classes = len(self.meta["label_to_names"])
        dataset.label_to_names = self.meta["label_to_names"]
        # ``color_map`` is read unconditionally: a ``KeyError`` surfaces if the field is absent.
        dataset.color_map = self.meta["color_map"]

    def load_data(self, folder: str, dataset: Input):
        """List the scans, labels and calibrations of ``folder`` and pair them up.

        Files are paired positionally after sorting each directory listing, so the three directories are
        expected to use file names that sort in the same order.

        Returns:
            One dict per sample with the keys ``scan_path``, ``label_path`` and ``calibration_path``.

        Raises:
            MisconfigurationException: If ``folder`` does not contain exactly the three expected
                sub-directories, or if the three directories do not hold the same number of files.
        """
        sub_directories = os.listdir(folder)
        if len(sub_directories) != 3:
            raise MisconfigurationException(
                f"Using KITTI Format, the {folder} should contains 3 directories "
                "for ``calibrations``, ``labels`` and ``scans``."
            )

        # Explicit exceptions rather than ``assert``: assertions are stripped when Python runs with ``-O``.
        for folder_name in (self.scans_folder_name, self.labels_folder_name, self.calibrations_folder_name):
            if folder_name not in sub_directories:
                raise MisconfigurationException(
                    f"Using KITTI Format, the {folder} should contain a `{folder_name}` directory. "
                    f"Found {sub_directories}."
                )

        scans_dir = join(folder, self.scans_folder_name)
        labels_dir = join(folder, self.labels_folder_name)
        calibrations_dir = join(folder, self.calibrations_folder_name)

        # ``os.listdir`` returns entries in arbitrary order; sort every listing so that the positional
        # pairing below matches each scan with its own label and calibration.
        scan_paths = [join(scans_dir, f) for f in sorted(os.listdir(scans_dir))]
        label_paths = [join(labels_dir, f) for f in sorted(os.listdir(labels_dir))]
        calibration_paths = [join(calibrations_dir, f) for f in sorted(os.listdir(calibrations_dir))]

        if not len(scan_paths) == len(label_paths) == len(calibration_paths):
            raise MisconfigurationException(
                f"The number of scans ({len(scan_paths)}), labels ({len(label_paths)}) and calibrations "
                f"({len(calibration_paths)}) found in {folder} should be the same."
            )

        # The class metadata lives next to the dataset folder, not inside it.
        self.load_meta(dirname(folder), dataset)

        dataset.path_list = scan_paths

        return [
            {"scan_path": scan_path, "label_path": label_path, "calibration_path": calibration_path}
            for scan_path, label_path, calibration_path in zip(scan_paths, label_paths, calibration_paths)
        ]

    def load_sample(self, metadata: Dict[str, str], has_label: bool = True) -> Any:
        """Read the scan, calibration and (optionally) label referenced by ``metadata``.

        Args:
            metadata: Dict with ``scan_path`` and ``calibration_path``, plus ``label_path`` when
                ``has_label`` is true.
            has_label: Whether to read the label file.

        Returns:
            A ``(data, attr)`` tuple; ``attr`` is also stored under ``data["attr"]``.
        """
        pc = KITTI.read_lidar(metadata["scan_path"])
        calib = KITTI.read_calib(metadata["calibration_path"])
        label = KITTI.read_label(metadata["label_path"], calib) if has_label else None

        reduced_pc = DataProcessing.remove_outside_points(pc, calib["world_cam"], calib["cam_img"], self.image_size)

        attr = {
            "name": basename(metadata["scan_path"]),
            "path": metadata["scan_path"],
            "calibration_path": metadata["calibration_path"],
            "label_path": metadata["label_path"] if has_label else None,
            "split": "val",
        }

        data = {
            "point": reduced_pc,
            "full_point": pc,
            "feat": None,
            "calib": calib,
            "bounding_boxes": label,
            "attr": attr,
        }
        return data, attr

    def load_files(self, scan_paths: Union[str, List[str]], dataset: Input):
        """Build prediction metadata for one scan path or a list of scan paths.

        The calibration path is derived from each scan path by substituting the scans folder name with the
        calibrations folder name and ``.bin`` with ``.txt``.

        Returns:
            One dict per scan with the keys ``scan_path`` and ``calibration_path``.
        """
        if isinstance(scan_paths, str):
            scan_paths = [scan_paths]

        def clean_fn(path: str) -> str:
            # NOTE: ``str.replace`` substitutes every occurrence, so a scans folder name (or ``.bin``) that
            # also appears elsewhere in the path is rewritten too.
            return path.replace(self.scans_folder_name, self.calibrations_folder_name).replace(".bin", ".txt")

        dataset.path_list = scan_paths

        return [{"scan_path": scan_path, "calibration_path": clean_fn(scan_path)} for scan_path in scan_paths]

    def predict_load_data(self, data, dataset: Input):
        """Build prediction metadata from a scan file path or a list of scan file paths.

        Raises:
            NotImplementedError: If ``data`` is a directory path (not supported).
        """
        is_single_file = isinstance(data, str) and isfile(data)
        is_file_list = isinstance(data, list) and all(isfile(p) for p in data)
        if is_single_file or is_file_list:
            return self.load_files(data, dataset)
        if isinstance(data, str) and isdir(data):
            raise NotImplementedError
        # Any other input falls through and yields ``None``.

    def predict_load_sample(self, metadata: Dict[str, str]):
        """Load an unlabelled sample and mark it as belonging to the ``test`` split."""
        data, attr = self.load_sample(metadata, has_label=False)
        # hack to prevent manipulation of labels
        attr["split"] = "test"
        return data, attr


class PointCloudObjectDetectorFoldersInput(Input):
    """Input that loads point cloud object detection data from folders (or scan files when predicting).

    The actual reading is delegated to a loader selected from ``loaders`` by data format; the KITTI loader is
    used when no format is given.
    """

    loaders: Dict[PointCloudObjectDetectionDataFormat, Type[BasePointCloudObjectDetectorLoader]] = {
        PointCloudObjectDetectionDataFormat.KITTI: KITTIPointCloudObjectDetectorLoader
    }

    def _get_loader(
        self, data_format: Optional[BaseDataFormat] = None, image_size: tuple = (375, 1242), **loader_kwargs: Any
    ) -> BasePointCloudObjectDetectorLoader:
        """Instantiate the loader registered for ``data_format`` (KITTI when ``data_format`` is falsy)."""
        loader_cls = self.loaders[data_format or PointCloudObjectDetectionDataFormat.KITTI]
        return loader_cls(image_size=image_size, **loader_kwargs)

    def _validate_data(self, folder: str) -> None:
        """Check that ``folder`` is a string pointing to an existing directory.

        Raises:
            MisconfigurationException: If ``folder`` is not a string or is not a directory.
        """
        msg = f"The provided dataset for stage {self._running_stage} should be a folder. Found {folder}."
        if not isinstance(folder, str) or not isdir(folder):
            raise MisconfigurationException(msg)

    def load_data(
        self,
        folder: str,
        data_format: Optional[BaseDataFormat] = None,
        image_size: tuple = (375, 1242),
        **loader_kwargs: Any,
    ) -> Any:
        """Validate ``folder``, create the loader and return the loader's sample metadata for ``folder``."""
        self._validate_data(folder)
        self.loader = self._get_loader(data_format, image_size, **loader_kwargs)
        return self.loader.load_data(folder, self)

    def _apply_transforms(self, data: Any, metadata: Any) -> Any:
        """Apply ``input_transform_fn`` and then ``transform_fn`` to ``data`` when those attributes are set."""
        for fn_name in ("input_transform_fn", "transform_fn"):
            fn = getattr(self, fn_name, None)
            if fn:
                data = fn(data, metadata)
        return data

    def load_sample(self, metadata: Dict[str, str]) -> Any:
        """Load one sample through the loader and return it as ``{"data": ..., "attr": ...}``."""
        data, attr = self.loader.load_sample(metadata)
        return {"data": self._apply_transforms(data, attr), "attr": attr}

    def _validate_predict_data(self, data: Union[str, List[str]]) -> None:
        """Check that ``data`` is an existing file, an existing folder, or a list of existing files.

        Raises:
            MisconfigurationException: If ``data`` has another type or references a missing path.
        """
        msg = f"The provided predict data should be a either a folder or a single/list of scan path(s). Found {data}."
        if not isinstance(data, str) and not isinstance(data, list):
            raise MisconfigurationException(msg)

        # A string is valid when it is a file OR a folder, so reject it only when it is neither.
        # (Rejecting on ``not isfile(...) or not isdir(...)`` would refuse every string.)
        if isinstance(data, str) and not isfile(data) and not isdir(data):
            raise MisconfigurationException(msg)

        if isinstance(data, list) and not all(isfile(p) for p in data):
            raise MisconfigurationException(msg)

    def predict_load_data(
        self,
        data: Union[str, List[str]],
        data_format: Optional[BaseDataFormat] = None,
        image_size: tuple = (375, 1242),
        **loader_kwargs: Any,
    ) -> Any:
        """Validate ``data``, create the loader and return the loader's prediction metadata for ``data``."""
        self._validate_predict_data(data)
        self.loader = self._get_loader(data_format, image_size, **loader_kwargs)
        return self.loader.predict_load_data(data, self)

    def predict_load_sample(self, metadata: Dict[str, str]) -> Any:
        """Load one prediction sample through the loader and return it as ``{"data": ..., "attr": ...}``."""
        data, attr = self.loader.predict_load_sample(metadata)
        return {"data": self._apply_transforms(data, attr), "attr": attr}

© Copyright 2020-2021, PyTorch Lightning. Revision 8db29e8e.

Built with Sphinx using a theme provided by Read the Docs.
Read the Docs v: 0.7.3
Versions
latest
stable
0.7.3
0.7.2
0.7.1
0.7.0
0.6.0
0.5.2
0.5.1
0.5.0
0.4.0
0.3.2
0.3.1
0.3.0
0.2.3
0.2.2
0.2.1
0.2.0
0.1.0post1
docs-fix_typing
Downloads
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.