Shortcuts

BaseDataFetcher

class flash.core.data.callback.BaseDataFetcher(enabled=False)[source]

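This class is used to profile the outputs of the InputTransform hooks (for example load_sample, per_sample_transform, collate and per_batch_transform, as shown in the example below).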

By default, the callback does not profile the data being processed, because keeping that data in memory may lead to an out-of-memory (OOM) error. Call enable() to turn profiling on.

Example:

from typing import Any

from flash.core.data.callback import BaseDataFetcher
from flash.core.data.data_module import DataModule
from flash.core.data.io.input import Input
from flash.core.data.io.input_transform import InputTransform

class CustomInputTransform(InputTransform):
    """Input transform that registers one ``Input`` under the ``"inputs"`` key.

    Any extra keyword arguments are forwarded unchanged to ``InputTransform``.
    """

    def __init__(self, **kwargs):
        # ``self`` is required: zero-argument ``super()`` needs the instance
        # argument, otherwise instantiation raises ``RuntimeError``.
        super().__init__(
            inputs={"inputs": Input()},
            **kwargs,
        )

class PrintData(BaseDataFetcher):
    """Data fetcher that can dump everything it has collected to stdout."""

    def print(self):
        """Write this fetcher's ``batches`` attribute to standard output."""
        # Inside the method body, ``print`` still resolves to the builtin:
        # class-level names are not visible from function scope.
        collected = self.batches
        print(collected)

class CustomDataModule(DataModule):
    """Data module that uses ``CustomInputTransform`` and a ``PrintData`` fetcher."""

    input_transform_cls = CustomInputTransform

    @staticmethod
    def configure_data_fetcher():
        """Return a new ``PrintData`` instance to serve as the data fetcher."""
        return PrintData()

    @classmethod
    def from_inputs(
        cls,
        train_data: Any,
        val_data: Any,
        test_data: Any,
        predict_data: Any,
        batch_size: int = 5,
    ) -> "CustomDataModule":
        """Build a ``CustomDataModule`` from raw data for every stage.

        Args:
            train_data: Data for the training stage.
            val_data: Data for the validation stage.
            test_data: Data for the test stage.
            predict_data: Data for the prediction stage.
            batch_size: Batch size forwarded to ``from_input`` (default 5).

        Returns:
            The data module created by ``from_input`` using the ``"inputs"`` key.
        """
        return cls.from_input(
            "inputs",
            train_data=train_data,
            val_data=val_data,
            test_data=test_data,
            predict_data=predict_data,
            batch_size=batch_size,
        )

dm = CustomDataModule.from_inputs(range(5), range(5), range(5), range(5))
data_fetcher = dm.data_fetcher

# By default, the ``data_fetcher`` is disabled to prevent OOM.
# The ``enable`` context manager will activate it.
with data_fetcher.enable():

    # This will fetch the first val dataloader batch.
    _ = next(iter(dm.val_dataloader()))

data_fetcher.print()
# out:
# {
#     'train': {},
#     'test': {},
#     'val': {
#         'load_sample': [0, 1, 2, 3, 4],
#         'per_sample_transform': [0, 1, 2, 3, 4],
#         'collate': [tensor([0, 1, 2, 3, 4])],
#         'per_batch_transform': [tensor([0, 1, 2, 3, 4])],
#     },
#     'predict': {}
# }

# ``reset`` clears what was collected, leaving every stage empty.
data_fetcher.reset()
data_fetcher.print()
# out:
# {
#     'train': {},
#     'test': {},
#     'val': {},
#     'predict': {}
# }
enable()[source]

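This function is used to enable the BaseDataFetcher. Use it as a context manager, e.g. with data_fetcher.enable(): ..., so that hook outputs are recorded for the code run inside the with block.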

Read the Docs v: stable
Versions
latest
stable
0.7.5
0.7.4
0.7.3
0.7.2
0.7.1
0.7.0
0.6.0
0.5.2
0.5.1
0.5.0
0.4.0
0.3.2
0.3.1
0.3.0
0.2.3
0.2.2
0.2.1
0.2.0
0.1.0post1
docs-fix_typing
Downloads
html
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.