Shortcuts

Source code for flash.pointcloud.segmentation.model

# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, Iterable, Optional, Tuple, Union

import torch
from torch import Tensor, nn
from torch.nn import functional as F
from torch.utils.data import DataLoader, Sampler

import flash
from flash.core.classification import ClassificationTask
from flash.core.data.io.input import DataKeys, InputBase
from flash.core.data.io.input_transform import InputTransform
from flash.core.data.utilities.collate import wrap_collate
from flash.core.registry import FlashRegistry
from flash.core.utilities.imports import _TM_GREATER_EQUAL_0_10_0, _TOPIC_POINTCLOUD_AVAILABLE
from flash.core.utilities.stability import beta
from flash.core.utilities.types import LOSS_FN_TYPE, LR_SCHEDULER_TYPE, METRICS_TYPE, OPTIMIZER_TYPE
from flash.pointcloud.segmentation.backbones import POINTCLOUD_SEGMENTATION_BACKBONES

if _TOPIC_POINTCLOUD_AVAILABLE:
    from open3d._ml3d.torch.modules.losses.semseg_loss import filter_valid_label
    from open3d.ml.torch.dataloaders import TorchDataloader

if _TM_GREATER_EQUAL_0_10_0:
    from torchmetrics.classification import MulticlassJaccardIndex as JaccardIndex
else:
    from torchmetrics import JaccardIndex


[docs]@beta("Point cloud segmentation is currently in Beta.") class PointCloudSegmentation(ClassificationTask): """The ``PointCloudClassifier`` is a :class:`~flash.core.classification.ClassificationTask` that classifies pointcloud data. Args: num_classes: The number of classes (outputs) for this :class:`~flash.core.model.Task`. backbone: The backbone name (or a tuple of ``nn.Module``, output size) to use. backbone_kwargs: Any additional kwargs to pass to the backbone constructor. head: a `nn.Module` to use on top of the backbone. The output dimension should match the `num_classes` argument. If not set will default to a single linear layer. loss_fn: The loss function to use. If ``None``, a default will be selected by the :class:`~flash.core.classification.ClassificationTask` depending on the ``multi_label`` argument. optimizer: Optimizer to use for training. lr_scheduler: The LR scheduler to use during training. metrics: Any metrics to use with this :class:`~flash.core.model.Task`. If ``None``, a default will be selected by the :class:`~flash.core.classification.ClassificationTask` depending on the ``multi_label`` argument. learning_rate: The learning rate for the optimizer. multi_label: If ``True``, this will be treated as a multi-label classification problem. """ backbones: FlashRegistry = POINTCLOUD_SEGMENTATION_BACKBONES required_extras: str = "pointcloud" def __init__( self, num_classes: int, backbone: Union[str, Tuple[nn.Module, int]] = "randlanet", backbone_kwargs: Optional[Dict] = None, head: Optional[nn.Module] = None, loss_fn: LOSS_FN_TYPE = torch.nn.functional.cross_entropy, optimizer: OPTIMIZER_TYPE = "Adam", lr_scheduler: LR_SCHEDULER_TYPE = None, metrics: METRICS_TYPE = None, learning_rate: Optional[float] = None, multi_label: bool = False, ): import flash if metrics is None: metrics = JaccardIndex(num_classes=num_classes) super().__init__( model=None, loss_fn=loss_fn, optimizer=optimizer, lr_scheduler=lr_scheduler, metrics=metrics, learning_rate=learning_rate, multi_label=multi_label, ) self.save_hyperparameters() if not backbone_kwargs: backbone_kwargs = {"num_classes": num_classes} if isinstance(backbone, tuple): self.backbone, out_features = backbone else: self.backbone, out_features, collate_fn = self.backbones.get(backbone)(**backbone_kwargs) self.collate_fn = wrap_collate(collate_fn) # replace latest layer if not flash._IS_TESTING: self.backbone.fc = nn.Identity() self.head = nn.Identity() if flash._IS_TESTING else (head or nn.Linear(out_features, num_classes)) def apply_filtering(self, labels, scores): scores, labels = filter_valid_label(scores, labels, self.hparams.num_classes, [0], self.device) return labels, scores def to_metrics_format(self, x: Tensor) -> Tensor: return F.softmax(self.to_loss_format(x), dim=-1) def to_loss_format(self, x: Tensor) -> Tensor: return x.reshape(-1, x.shape[-1]) def training_step(self, batch: Any, batch_idx: int) -> Any: batch = (batch[DataKeys.INPUT], batch[DataKeys.INPUT]["labels"].view(-1)) return super().training_step(batch, batch_idx) def validation_step(self, batch: Any, batch_idx: int) -> Any: batch = (batch[DataKeys.INPUT], batch[DataKeys.INPUT]["labels"].view(-1)) return super().validation_step(batch, batch_idx) def test_step(self, batch: Any, batch_idx: int) -> Any: batch = (batch[DataKeys.INPUT], batch[DataKeys.INPUT]["labels"].view(-1)) return super().test_step(batch, batch_idx) def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any: batch[DataKeys.PREDS] = self(batch[DataKeys.INPUT]) batch[DataKeys.TARGET] = batch[DataKeys.INPUT]["labels"] # drop sub-sampled pointclouds batch[DataKeys.INPUT] = batch[DataKeys.INPUT]["xyz"][0] return batch
[docs] def forward(self, x) -> Tensor: """First call the backbone, then the model head.""" # hack to enable backbone to work properly. self.backbone.device = self.device x = self.backbone(x) if self.head is not None: x = self.head(x) return x
def _patch_dataset(self, dataset: InputBase): if not isinstance(dataset.dataset, TorchDataloader): dataset.dataset = TorchDataloader( dataset.dataset, preprocess=self.backbone.preprocess, transform=self.backbone.transform, use_cache=False, ) def process_train_dataset( self, dataset: InputBase, batch_size: int, num_workers: int = 0, pin_memory: bool = False, shuffle: bool = True, drop_last: bool = True, sampler: Optional[Sampler] = None, persistent_workers: bool = False, input_transform: Optional[InputTransform] = None, trainer: Optional["flash.Trainer"] = None, ) -> DataLoader: self._patch_dataset(dataset) return super().process_train_dataset( dataset, batch_size, num_workers=num_workers, pin_memory=pin_memory, shuffle=shuffle, drop_last=drop_last, sampler=sampler, persistent_workers=persistent_workers, input_transform=input_transform, trainer=trainer, ) def process_val_dataset( self, dataset: InputBase, batch_size: int, num_workers: int = 0, pin_memory: bool = False, shuffle: bool = False, drop_last: bool = False, sampler: Optional[Sampler] = None, persistent_workers: bool = False, input_transform: Optional[InputTransform] = None, trainer: Optional["flash.Trainer"] = None, ) -> DataLoader: self._patch_dataset(dataset) return super().process_val_dataset( dataset, batch_size, num_workers=num_workers, pin_memory=pin_memory, shuffle=shuffle, drop_last=drop_last, sampler=sampler, persistent_workers=persistent_workers, input_transform=input_transform, trainer=trainer, ) def process_test_dataset( self, dataset: InputBase, batch_size: int, num_workers: int = 0, pin_memory: bool = False, shuffle: bool = False, drop_last: bool = False, sampler: Optional[Sampler] = None, persistent_workers: bool = False, input_transform: Optional[InputTransform] = None, trainer: Optional["flash.Trainer"] = None, ) -> DataLoader: self._patch_dataset(dataset) return super().process_test_dataset( dataset, batch_size, num_workers=num_workers, pin_memory=pin_memory, shuffle=shuffle, drop_last=drop_last, sampler=sampler, persistent_workers=persistent_workers, input_transform=input_transform, trainer=trainer, ) def process_predict_dataset( self, dataset: InputBase, batch_size: int, num_workers: int = 0, pin_memory: bool = False, shuffle: bool = False, drop_last: bool = False, sampler: Optional[Sampler] = None, persistent_workers: bool = False, input_transform: Optional[InputTransform] = None, trainer: Optional["flash.Trainer"] = None, ) -> DataLoader: self._patch_dataset(dataset) return super().process_predict_dataset( dataset, batch_size, num_workers=num_workers, pin_memory=pin_memory, shuffle=shuffle, drop_last=drop_last, sampler=sampler, persistent_workers=persistent_workers, input_transform=input_transform, trainer=trainer, )
[docs] def modules_to_freeze(self) -> Union[nn.Module, Iterable[Union[nn.Module, Iterable]]]: """Return the module attributes of the model to be frozen.""" return list(self.backbone.children())

© Copyright 2020-2021, PyTorch Lightning. Revision a374dd4f.

Built with Sphinx using a theme provided by Read the Docs.