Source code for flash.core.adapter

# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from abc import abstractmethod
from typing import Any, Callable, Optional

import torch.jit
from torch import Tensor, nn
from torch.utils.data import DataLoader, Sampler

import flash
from flash.core.data.io.input import InputBase
from flash.core.data.io.input_transform import InputTransform
from flash.core.model import DatasetProcessor, ModuleWrapperBase, Task
from flash.core.utilities.types import INPUT_TRANSFORM_TYPE


class Adapter(DatasetProcessor, ModuleWrapperBase, nn.Module):
    """The ``Adapter`` is a lightweight interface that can be used to encapsulate the logic from a particular
    provider within a :class:`~flash.core.model.Task`."""

    @classmethod
    @abstractmethod
    def from_task(cls, task: "flash.Task", **kwargs) -> "Adapter":
        """Instantiate the adapter from the given :class:`~flash.core.model.Task`.

        This includes resolution / creation of backbones / heads and any other provider specific options.
        """

    # The hooks below are no-op stubs; concrete provider adapters override the ones they need.
    def forward(self, x: Any) -> Any:
        pass

    def training_step(self, batch: Any, batch_idx: int) -> Any:
        pass

    def validation_step(self, batch: Any, batch_idx: int) -> None:
        pass

    def test_step(self, batch: Any, batch_idx: int) -> None:
        pass

    def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any:
        pass

    def training_epoch_end(self, outputs) -> None:
        pass

    def validation_epoch_end(self, outputs) -> None:
        pass

    def test_epoch_end(self, outputs) -> None:
        pass
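# Illustrative sketch (not part of flash.core.adapter): a provider integration
# subclasses ``Adapter``, resolving its modules in ``from_task`` and delegating
# the hooks to them. ``ExampleAdapter`` and its flatten-plus-linear head are
# hypothetical stand-ins; a real provider would resolve backbones / heads
# registered on the given task.
class ExampleAdapter(Adapter):
    def __init__(self, backbone: nn.Module, head: nn.Module):
        super().__init__()
        self.backbone = backbone
        self.head = head

    @classmethod
    def from_task(cls, task: "flash.Task", num_features: int = 128, num_classes: int = 2, **kwargs) -> "ExampleAdapter":
        # A real provider would inspect ``task`` here to build these modules.
        return cls(backbone=nn.Flatten(), head=nn.Linear(num_features, num_classes))

    def forward(self, x: Any) -> Any:
        return self.head(self.backbone(x))

    def training_step(self, batch: Any, batch_idx: int) -> Any:
        inputs, targets = batch
        return nn.functional.cross_entropy(self(inputs), targets)
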
def identity_collate_fn(x):
    """A pass-through collate function: returns the batch unchanged."""
    return x
class AdapterTask(Task):
    """The ``AdapterTask`` is a :class:`~flash.core.model.Task` which wraps an
    :class:`~flash.core.adapter.Adapter` and forwards all of the hooks.

    Args:
        adapter: The :class:`~flash.core.adapter.Adapter` to wrap.
        kwargs: Keyword arguments to be passed to the base :class:`~flash.core.model.Task`.
    """

    def __init__(self, adapter: Adapter, **kwargs):
        super().__init__(**kwargs)

        self.adapter = adapter

    @torch.jit.unused
    @property
    def input_transform(self) -> Optional[INPUT_TRANSFORM_TYPE]:
        return self.adapter.input_transform

    @input_transform.setter
    def input_transform(self, input_transform: INPUT_TRANSFORM_TYPE) -> None:
        self.adapter.input_transform = input_transform

    @torch.jit.unused
    @property
    def collate_fn(self) -> Optional[Callable]:
        return self.adapter.collate_fn

    @collate_fn.setter
    def collate_fn(self, collate_fn: Callable) -> None:
        self.adapter.collate_fn = collate_fn

    @torch.jit.unused
    @property
    def backbone(self) -> nn.Module:
        return self.adapter.backbone

    def forward(self, x: Tensor) -> Any:
        return self.adapter.forward(x)

    def training_step(self, batch: Any, batch_idx: int) -> Any:
        return self.adapter.training_step(batch, batch_idx)

    def validation_step(self, batch: Any, batch_idx: int) -> None:
        return self.adapter.validation_step(batch, batch_idx)

    def test_step(self, batch: Any, batch_idx: int) -> None:
        return self.adapter.test_step(batch, batch_idx)

    def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any:
        return self.adapter.predict_step(batch, batch_idx, dataloader_idx=dataloader_idx)

    def training_epoch_end(self, outputs) -> None:
        return self.adapter.training_epoch_end(outputs)

    def validation_epoch_end(self, outputs) -> None:
        return self.adapter.validation_epoch_end(outputs)

    def test_epoch_end(self, outputs) -> None:
        return self.adapter.test_epoch_end(outputs)

    def process_train_dataset(
        self,
        dataset: InputBase,
        batch_size: int,
        num_workers: int = 0,
        pin_memory: bool = False,
        shuffle: bool = True,
        drop_last: bool = True,
        sampler: Optional[Sampler] = None,
        persistent_workers: bool = False,
        input_transform: Optional[InputTransform] = None,
        trainer: Optional["flash.Trainer"] = None,
    ) -> DataLoader:
        return self.adapter.process_train_dataset(
            dataset,
            batch_size,
            num_workers=num_workers,
            pin_memory=pin_memory,
            shuffle=shuffle,
            drop_last=drop_last,
            sampler=sampler,
            persistent_workers=persistent_workers,
            input_transform=input_transform,
            trainer=trainer,
        )

    def process_val_dataset(
        self,
        dataset: InputBase,
        batch_size: int,
        num_workers: int = 0,
        pin_memory: bool = False,
        shuffle: bool = False,
        drop_last: bool = False,
        sampler: Optional[Sampler] = None,
        persistent_workers: bool = False,
        input_transform: Optional[InputTransform] = None,
        trainer: Optional["flash.Trainer"] = None,
    ) -> DataLoader:
        return self.adapter.process_val_dataset(
            dataset,
            batch_size,
            num_workers=num_workers,
            pin_memory=pin_memory,
            shuffle=shuffle,
            drop_last=drop_last,
            sampler=sampler,
            persistent_workers=persistent_workers,
            input_transform=input_transform,
            trainer=trainer,
        )

    def process_test_dataset(
        self,
        dataset: InputBase,
        batch_size: int,
        num_workers: int = 0,
        pin_memory: bool = False,
        shuffle: bool = False,
        drop_last: bool = False,
        sampler: Optional[Sampler] = None,
        persistent_workers: bool = False,
        input_transform: Optional[InputTransform] = None,
        trainer: Optional["flash.Trainer"] = None,
    ) -> DataLoader:
        return self.adapter.process_test_dataset(
            dataset,
            batch_size,
            num_workers=num_workers,
            pin_memory=pin_memory,
            shuffle=shuffle,
            drop_last=drop_last,
            sampler=sampler,
            persistent_workers=persistent_workers,
            input_transform=input_transform,
            trainer=trainer,
        )

    def process_predict_dataset(
        self,
        dataset: InputBase,
        batch_size: int,
        num_workers: int = 0,
        pin_memory: bool = False,
        shuffle: bool = False,
        drop_last: bool = False,
        sampler: Optional[Sampler] = None,
        persistent_workers: bool = False,
        input_transform: Optional[InputTransform] = None,
        trainer: Optional["flash.Trainer"] = None,
    ) -> DataLoader:
        return self.adapter.process_predict_dataset(
            dataset,
            batch_size,
            num_workers=num_workers,
            pin_memory=pin_memory,
            shuffle=shuffle,
            drop_last=drop_last,
            sampler=sampler,
            persistent_workers=persistent_workers,
            input_transform=input_transform,
            trainer=trainer,
        )
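
# Usage sketch (illustrative only): a concrete task typically builds its
# adapter with ``from_task`` and passes it to ``AdapterTask``, which forwards
# every hook above to it. ``ExampleTask`` uses the hypothetical
# ``ExampleAdapter`` sketched earlier; real tasks may wire this up differently
# (for example through a registry of providers).
class ExampleTask(AdapterTask):
    def __init__(self, num_features: int = 128, num_classes: int = 2, **kwargs):
        adapter = ExampleAdapter.from_task(self, num_features=num_features, num_classes=num_classes)
        super().__init__(adapter, **kwargs)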
