BaseDataFetcher¶
- class flash.core.data.callback.BaseDataFetcher(enabled=False)[source]¶
This class is used to profile `InputTransform` hook outputs.
By default, the callback won’t profile the data being processed, as it may lead to an `OOMError`.
Example:
# Usage example: subclass ``BaseDataFetcher`` to inspect the batches it
# records while a ``DataModule`` produces data.
from typing import Any

from flash.core.data.callback import BaseDataFetcher
from flash.core.data.data_module import DataModule
from flash.core.data.io.input import Input
from flash.core.data.io.input_transform import InputTransform


class CustomInputTransform(InputTransform):
    """Input transform that registers one ``Input`` under the key ``"inputs"``."""

    def __init__(self, **kwargs):
        # ``self`` is required: without it, instantiation raises ``TypeError``
        # because Python passes the instance as the first positional argument.
        super().__init__(
            inputs={"inputs": Input()},
            **kwargs,
        )


class PrintData(BaseDataFetcher):
    """Data fetcher that can dump its recorded batches to stdout."""

    def print(self):
        """Print the ``batches`` attribute of this fetcher."""
        print(self.batches)


class CustomDataModule(DataModule):
    """Data module wired to ``CustomInputTransform`` and ``PrintData``."""

    input_transform_cls = CustomInputTransform

    @staticmethod
    def configure_data_fetcher():
        """Return the fetcher instance to attach to this data module."""
        return PrintData()

    @classmethod
    def from_inputs(
        cls,
        train_data: Any,
        val_data: Any,
        test_data: Any,
        predict_data: Any,
    ) -> "CustomDataModule":
        """Build the data module from raw per-split data.

        Forwards every split to ``cls.from_input`` under the ``"inputs"``
        key with a fixed ``batch_size`` of 5.
        """
        return cls.from_input(
            "inputs",
            train_data=train_data,
            val_data=val_data,
            test_data=test_data,
            predict_data=predict_data,
            batch_size=5,
        )


dm = CustomDataModule.from_inputs(range(5), range(5), range(5), range(5))
data_fetcher = dm.data_fetcher

# By default, the ``data_fetcher`` is disabled to prevent OOM.
# The ``enable`` context manager will activate it.
with data_fetcher.enable():
    # This will fetch the first val dataloader batch.
    _ = next(iter(dm.val_dataloader()))

data_fetcher.print()
# out:
# {
#     'train': {},
#     'test': {},
#     'val': {
#         'load_sample': [0, 1, 2, 3, 4],
#         'per_sample_transform': [0, 1, 2, 3, 4],
#         'collate': [tensor([0, 1, 2, 3, 4])],
#         'per_batch_transform': [tensor([0, 1, 2, 3, 4])]},
#     'predict': {}
# }

data_fetcher.reset()
data_fetcher.print()
# out:
# {
#     'train': {},
#     'test': {},
#     'val': {},
#     'predict': {}
# }