
Source code for flash.core.data.callback

from contextlib import contextmanager
from typing import Any, List, Sequence

from pytorch_lightning.callbacks import Callback
from torch import Tensor

import flash
from flash.core.data.utils import _STAGES_PREFIX
from flash.core.utilities.stages import RunningStage


class FlashCallback(Callback):
    """``FlashCallback`` is an extension of :class:`pytorch_lightning.callbacks.Callback`.

    A callback is a self-contained program that can be reused across projects. Flash and Lightning have a callback
    system to execute callbacks when needed. Callbacks should capture any NON-ESSENTIAL logic that is NOT required
    for your lightning module to run.

    Same as PyTorch Lightning, callbacks can be provided directly to the Trainer::

        trainer = Trainer(callbacks=[MyCustomCallback()])
    """

    def on_load_sample(self, sample: Any, running_stage: RunningStage) -> None:
        """Called once a sample has been loaded using ``load_sample``."""

    def on_per_sample_transform(self, sample: Tensor, running_stage: RunningStage) -> None:
        """Called once ``per_sample_transform`` has been applied to a sample."""

    def on_per_batch_transform(self, batch: Any, running_stage: RunningStage) -> None:
        """Called once ``per_batch_transform`` has been applied to a batch."""

    def on_collate(self, batch: Sequence, running_stage: RunningStage) -> None:
        """Called once ``collate`` has been applied to a sequence of samples."""

    def on_per_sample_transform_on_device(self, sample: Any, running_stage: RunningStage) -> None:
        """Called once ``per_sample_transform_on_device`` has been applied to a sample."""

    def on_per_batch_transform_on_device(self, batch: Any, running_stage: RunningStage) -> None:
        """Called once ``per_batch_transform_on_device`` has been applied to a batch."""
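

# Usage sketch (illustrative; not part of the Flash source). Subclass
# ``FlashCallback`` and override only the hooks you need; ``LogCollate`` is a
# hypothetical name. As the docstring above notes, the callback can then be
# passed directly to the Trainer, e.g. ``Trainer(callbacks=[LogCollate()])``.


class LogCollate(FlashCallback):
    def on_collate(self, batch: Sequence, running_stage: RunningStage) -> None:
        # Runs right after ``collate`` has produced a batch for this stage.
        print(f"collate produced a batch during {running_stage}")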


class ControlFlow(FlashCallback):
    """Dispatches each hook call to every callback in a list of ``FlashCallback`` objects."""

    def __init__(self, callbacks: List[FlashCallback]):
        self._callbacks = callbacks

    def run_for_all_callbacks(self, *args, method_name: str, **kwargs):
        if self._callbacks:
            for cb in self._callbacks:
                getattr(cb, method_name)(*args, **kwargs)

    def on_load_sample(self, sample: Any, running_stage: RunningStage) -> None:
        self.run_for_all_callbacks(sample, running_stage, method_name="on_load_sample")

    def on_per_sample_transform(self, sample: Any, running_stage: RunningStage) -> None:
        self.run_for_all_callbacks(sample, running_stage, method_name="on_per_sample_transform")

    def on_per_batch_transform(self, batch: Any, running_stage: RunningStage) -> None:
        self.run_for_all_callbacks(batch, running_stage, method_name="on_per_batch_transform")

    def on_collate(self, batch: Sequence, running_stage: RunningStage) -> None:
        self.run_for_all_callbacks(batch, running_stage, method_name="on_collate")

    def on_per_sample_transform_on_device(self, sample: Any, running_stage: RunningStage) -> None:
        self.run_for_all_callbacks(sample, running_stage, method_name="on_per_sample_transform_on_device")

    def on_per_batch_transform_on_device(self, batch: Any, running_stage: RunningStage) -> None:
        self.run_for_all_callbacks(batch, running_stage, method_name="on_per_batch_transform_on_device")
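

# Sketch of how ``ControlFlow`` fans one hook call out to every registered
# callback (illustrative; not part of the Flash source; ``_Echo`` is a
# hypothetical callback defined only for this demonstration):


class _Echo(FlashCallback):
    def on_collate(self, batch: Sequence, running_stage: RunningStage) -> None:
        print(f"on_collate({batch!r}, {running_stage})")


flow = ControlFlow([_Echo(), _Echo()])
flow.on_collate(Tensor([0.0, 1.0]), RunningStage.TRAINING)  # printed twice, once per callback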


class BaseDataFetcher(FlashCallback):
    """This class is used to profile :class:`~flash.core.data.io.input_transform.InputTransform` hook outputs.

    By default, the callback won't profile the data being processed as it may lead to ``OOMError``.

    Example::

        from flash.core.data.callback import BaseDataFetcher
        from flash.core.data.data_module import DataModule
        from flash.core.data.io.input import Input
        from flash.core.data.io.input_transform import InputTransform

        class CustomInputTransform(InputTransform):

            def __init__(self, **kwargs):
                super().__init__(
                    inputs={"inputs": Input()},
                    **kwargs,
                )

        class PrintData(BaseDataFetcher):

            def print(self):
                print(self.batches)

        class CustomDataModule(DataModule):

            input_transform_cls = CustomInputTransform

            @staticmethod
            def configure_data_fetcher():
                return PrintData()

            @classmethod
            def from_inputs(
                cls,
                train_data: Any,
                val_data: Any,
                test_data: Any,
                predict_data: Any,
            ) -> "CustomDataModule":
                return cls.from_input(
                    "inputs",
                    train_data=train_data,
                    val_data=val_data,
                    test_data=test_data,
                    predict_data=predict_data,
                    batch_size=5,
                )

        dm = CustomDataModule.from_inputs(range(5), range(5), range(5), range(5))
        data_fetcher = dm.data_fetcher

        # By default, the ``data_fetcher`` is disabled to prevent OOM.
        # The ``enable`` context manager will activate it.
        with data_fetcher.enable():
            # This will fetch the first val dataloader batch.
            _ = next(iter(dm.val_dataloader()))

        data_fetcher.print()
        # out:
        {
            'train': {},
            'test': {},
            'val': {
                'load_sample': [0, 1, 2, 3, 4],
                'per_sample_transform': [0, 1, 2, 3, 4],
                'collate': [tensor([0, 1, 2, 3, 4])],
                'per_batch_transform': [tensor([0, 1, 2, 3, 4])],
            },
            'predict': {},
        }

        data_fetcher.reset()
        data_fetcher.print()
        # out:
        {'train': {}, 'test': {}, 'val': {}, 'predict': {}}
    """

    batches: dict

    def __init__(self, enabled: bool = False):
        self.enabled = enabled
        self._input_transform = None
        self.reset()

    def _store(self, data: Any, fn_name: str, running_stage: RunningStage) -> None:
        if self.enabled:
            store = self.batches[_STAGES_PREFIX[running_stage]]
            store.setdefault(fn_name, [])
            store[fn_name].append(data)

    def on_load_sample(self, sample: Any, running_stage: RunningStage) -> None:
        self._store(sample, "load_sample", running_stage)

    def on_per_sample_transform(self, sample: Any, running_stage: RunningStage) -> None:
        self._store(sample, "per_sample_transform", running_stage)

    def on_per_batch_transform(self, batch: Any, running_stage: RunningStage) -> None:
        self._store(batch, "per_batch_transform", running_stage)

    def on_collate(self, batch: Sequence, running_stage: RunningStage) -> None:
        self._store(batch, "collate", running_stage)

    def on_per_sample_transform_on_device(self, samples: Sequence, running_stage: RunningStage) -> None:
        self._store(samples, "per_sample_transform_on_device", running_stage)

    def on_per_batch_transform_on_device(self, batch: Any, running_stage: RunningStage) -> None:
        self._store(batch, "per_batch_transform_on_device", running_stage)

    @contextmanager
    def enable(self):
        """Context manager used to enable the ``BaseDataFetcher``, restoring the disabled state on exit."""
        self.enabled = True
        yield
        self.enabled = False

    def attach_to_input_transform(self, input_transform: "flash.core.data.io.input_transform.InputTransform") -> None:
        """Register this fetcher as a callback on the given ``InputTransform``."""
        input_transform.add_callbacks([self])
        self._input_transform = input_transform

    def reset(self):
        """Clear any stored data for all running stages."""
        self.batches = {k: {} for k in _STAGES_PREFIX.values()}
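

# Standalone sketch of the fetcher's gating behaviour (illustrative; not part
# of the Flash source). Hook calls are recorded only while ``enable()`` is
# active, keyed by the stage prefix ("train"/"val"/"test"/"predict"):

fetcher = BaseDataFetcher()
fetcher.on_load_sample(0, RunningStage.VALIDATING)  # ignored: fetcher is disabled by default
with fetcher.enable():
    fetcher.on_load_sample(1, RunningStage.VALIDATING)
print(fetcher.batches["val"])  # {'load_sample': [1]}
fetcher.reset()  # back to {'train': {}, 'test': {}, 'val': {}, 'predict': {}}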
