Source code for flash.core.data.data_module
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Type, Union
import numpy as np
import pytorch_lightning as pl
import torch
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from torch.utils.data import DataLoader, Dataset
from torch.utils.data._utils.collate import default_collate
from torch.utils.data.dataset import IterableDataset
from torch.utils.data.sampler import Sampler
import flash
from flash.core.data.base_viz import BaseVisualization
from flash.core.data.callback import BaseDataFetcher
from flash.core.data.io.input import DataKeys, Input, InputBase, IterableInput
from flash.core.data.io.input_transform import (
_create_collate_input_transform_processors,
_InputTransformProcessorV2,
create_transform,
InputTransform,
)
from flash.core.data.io.output_transform import OutputTransform
from flash.core.data.splits import SplitDataset
from flash.core.data.utils import _STAGES_PREFIX
from flash.core.registry import FlashRegistry
from flash.core.utilities.stages import RunningStage
class DatasetInput(Input):
"""The ``DatasetInput`` implements default behaviours for data sources which expect the input to
:meth:`~flash.core.data.io.input.Input.load_data` to be a :class:`torch.utils.data.dataset.Dataset`
"""
def load_sample(self, sample: Any) -> Dict[str, Any]:
if isinstance(sample, tuple) and len(sample) == 2:
return {DataKeys.INPUT: sample[0], DataKeys.TARGET: sample[1]}
return {DataKeys.INPUT: sample}
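# A minimal sketch of how ``DatasetInput`` normalizes raw samples (the constructor call and the
# values below are illustrative assumptions, not from the original source):
#
#   ds_input = DatasetInput(RunningStage.TRAINING)
#   ds_input.load_sample((x, y))  # -> {DataKeys.INPUT: x, DataKeys.TARGET: y}
#   ds_input.load_sample(x)       # -> {DataKeys.INPUT: x}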
class DataModule(pl.LightningDataModule):
"""A basic DataModule class for all Flash tasks. This class includes references to a
:class:`~flash.core.data.io.input.Input` and a :class:`~flash.core.data.callback.BaseDataFetcher`.
Args:
train_input: Input dataset for training. Defaults to None.
val_input: Input dataset for validating model performance during training. Defaults to None.
test_input: Input dataset to test model performance. Defaults to None.
predict_input: Input dataset for predicting. Defaults to None.
data_fetcher: The :class:`~flash.core.data.callback.BaseDataFetcher` to attach to the
:class:`~flash.core.data.io.input_transform.InputTransform`. If ``None``, the output from
:meth:`~flash.core.data.data_module.DataModule.configure_data_fetcher` will be used.
val_split: An optional float which gives the relative amount of the training dataset to use for the validation
dataset.
batch_size: The batch size to be used by the DataLoader.
num_workers: The number of workers to use for parallelized loading.
sampler: A sampler following the :class:`~torch.utils.data.sampler.Sampler` type.
Will be passed to the DataLoader for the training dataset. Defaults to None.
Examples
________
.. testsetup::
>>> from flash import DataModule
>>> from flash.core.data.io.input import Input
>>> from flash.core.utilities.stages import RunningStage
>>> from torch.utils.data.sampler import SequentialSampler, WeightedRandomSampler
>>> class TestInput(Input):
... def train_load_data(self, _):
... return [(0, 1, 2, 3), (0, 1, 2, 3)]
>>> train_input = TestInput(RunningStage.TRAINING, [1])
You can provide the sampler to use for the train dataloader using the ``sampler`` argument.
The sampler can be a callable or type that takes the dataset as its only argument:
.. doctest::
>>> datamodule = DataModule(train_input, sampler=SequentialSampler, batch_size=1)
>>> print(datamodule.train_dataloader().sampler) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
<torch.utils.data.sampler.SequentialSampler object at ...>
Alternatively, you can pass a sampler instance:
.. doctest::
>>> datamodule = DataModule(train_input, sampler=WeightedRandomSampler([0.1, 0.5], 2), batch_size=1)
>>> print(datamodule.train_dataloader().sampler) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
<torch.utils.data.sampler.WeightedRandomSampler object at ...>
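You can also pass ``val_split`` to hold out a random fraction of the training data for validation. A minimal
sketch (the 50% split below is only illustrative):
.. doctest::
>>> datamodule = DataModule(train_input, val_split=0.5, batch_size=1)
>>> len(datamodule.train_dataset), len(datamodule.val_dataset)
(1, 1)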
"""
input_transform_cls = InputTransform
input_transforms_registry: Optional[FlashRegistry] = None
def __init__(
self,
train_input: Optional[Input] = None,
val_input: Optional[Input] = None,
test_input: Optional[Input] = None,
predict_input: Optional[Input] = None,
data_fetcher: Optional[BaseDataFetcher] = None,
val_split: Optional[float] = None,
batch_size: Optional[int] = None,
num_workers: int = 0,
sampler: Optional[Union[Callable, Sampler, Type[Sampler]]] = None,
pin_memory: bool = True,
persistent_workers: bool = False,
) -> None:
if not batch_size:
raise MisconfigurationException("The `batch_size` should be provided to the DataModule on instantiation.")
if flash._IS_TESTING and torch.cuda.is_available():
batch_size = 16
self._input_transform: Optional[InputTransform] = None
self._viz: Optional[BaseVisualization] = None
self._train_input = train_input
self._val_input = val_input
self._test_input = test_input
self._predict_input = predict_input
if self._train_input and self._val_input and isinstance(val_split, float) and val_split > 0:
raise MisconfigurationException(
"A `val_dataset` was provided with `val_split`. Please, choose one or the other."
)
if self._train_input and (val_split is not None and not self._val_input):
self._train_input, self._val_input = self._split_train_val(self._train_input, val_split)
self._data_fetcher: Optional[BaseDataFetcher] = data_fetcher or self.configure_data_fetcher()
self._train_dataloader_collate_fn = self._resolve_dataloader_collate_fn(self._train_input)
self._val_dataloader_collate_fn = self._resolve_dataloader_collate_fn(self._val_input)
self._test_dataloader_collate_fn = self._resolve_dataloader_collate_fn(self._test_input)
self._predict_dataloader_collate_fn = self._resolve_dataloader_collate_fn(self._predict_input)
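# Resolve the per-stage ``on_after_batch_transfer`` hooks up front; sanity checking reuses the
# validation hook since it runs on validation data.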
self._on_after_batch_transfer_fns = {
RunningStage.TRAINING: self._resolve_on_after_batch_transfer_fn(self._train_input),
RunningStage.VALIDATING: self._resolve_on_after_batch_transfer_fn(self._val_input),
RunningStage.SANITY_CHECKING: self._resolve_on_after_batch_transfer_fn(self._val_input),
RunningStage.TESTING: self._resolve_on_after_batch_transfer_fn(self._test_input),
RunningStage.PREDICTING: self._resolve_on_after_batch_transfer_fn(self._predict_input),
}
self._model_on_after_batch_transfer_fns = None
if self._train_input:
self.train_dataloader = self._train_dataloader
if self._val_input:
self.val_dataloader = self._val_dataloader
if self._test_input:
self.test_dataloader = self._test_dataloader
if self._predict_input:
self.predict_dataloader = self._predict_dataloader
self.batch_size = batch_size
self.num_workers = num_workers
self.persistent_workers = persistent_workers and num_workers > 0
self.pin_memory = pin_memory
self.sampler = sampler
super().__init__()
@property
def train_dataset(self) -> Optional[Input]:
"""This property returns the train dataset."""
return self._train_input
@property
def val_dataset(self) -> Optional[Input]:
"""This property returns the validation dataset."""
return self._val_input
@property
def test_dataset(self) -> Optional[Input]:
"""This property returns the test dataset."""
return self._test_input
@property
def predict_dataset(self) -> Optional[Input]:
"""This property returns the predict dataset."""
return self._predict_input
def _resolve_dataloader_collate_fn(self, ds: Optional[Input]) -> Optional[Callable]:
if not ds:
return None
if isinstance(ds.transform, InputTransform):
return ds._create_dataloader_collate_fn([self.data_fetcher])
return default_collate
def _resolve_on_after_batch_transfer_fn(self, ds: Optional[Input]) -> Optional[Callable]:
if not ds:
return None
if isinstance(ds.transform, InputTransform):
return ds._create_on_after_batch_transfer_fn([self.data_fetcher])
return None
def _train_dataloader(self) -> DataLoader:
train_ds: Input = self._train_input
collate_fn = self._train_dataloader_collate_fn
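# When attached to a Trainer, an ``input_transform`` defined on the model takes precedence over
# the collate function resolved from the dataset.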
if isinstance(getattr(self, "trainer", None), pl.Trainer):
input_transform = getattr(self.trainer.lightning_module, "input_transform", None)
if input_transform is not None:
input_transform = create_transform(input_transform, RunningStage.TRAINING)
collate_fn = _create_collate_input_transform_processors(input_transform, [self.data_fetcher])[0]
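# Temporarily unwrap the transform processor so the DataLoader is built with the bare collate
# function; the processor is re-wrapped around the DataLoader's collate_fn below.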
transform_processor = None
if isinstance(collate_fn, _InputTransformProcessorV2):
transform_processor = collate_fn
collate_fn = transform_processor.collate_fn
shuffle: bool = False
if isinstance(train_ds, IterableDataset):
drop_last = False
else:
drop_last = len(train_ds) > self.batch_size
if self.sampler is None:
sampler = None
shuffle = not isinstance(train_ds, IterableDataset)
elif callable(self.sampler):
sampler = self.sampler(train_ds)
else:
sampler = self.sampler
if isinstance(getattr(self, "trainer", None), pl.Trainer):
dataloader = self.trainer.lightning_module.process_train_dataset(
train_ds,
trainer=self.trainer,
batch_size=self.batch_size,
num_workers=self.num_workers,
pin_memory=self.pin_memory,
shuffle=shuffle,
drop_last=drop_last,
collate_fn=collate_fn,
sampler=sampler,
persistent_workers=self.persistent_workers,
)
else:
dataloader = DataLoader(
train_ds,
batch_size=self.batch_size,
shuffle=shuffle,
sampler=sampler,
num_workers=self.num_workers,
pin_memory=self.pin_memory,
drop_last=drop_last,
collate_fn=collate_fn,
persistent_workers=self.persistent_workers,
)
if transform_processor is not None:
transform_processor.collate_fn = dataloader.collate_fn
dataloader.collate_fn = transform_processor
self._model_on_after_batch_transfer_fns = None
return dataloader
def _val_dataloader(self) -> DataLoader:
val_ds: Input = self._val_input
collate_fn = self._val_dataloader_collate_fn
if isinstance(getattr(self, "trainer", None), pl.Trainer):
input_transform = getattr(self.trainer.lightning_module, "input_transform", None)
if input_transform is not None:
input_transform = create_transform(input_transform, RunningStage.VALIDATING)
collate_fn = _create_collate_input_transform_processors(input_transform, [self.data_fetcher])[0]
transform_processor = None
if isinstance(collate_fn, _InputTransformProcessorV2):
transform_processor = collate_fn
collate_fn = transform_processor.collate_fn
if isinstance(getattr(self, "trainer", None), pl.Trainer):
dataloader = self.trainer.lightning_module.process_val_dataset(
val_ds,
trainer=self.trainer,
batch_size=self.batch_size,
num_workers=self.num_workers,
pin_memory=self.pin_memory,
collate_fn=collate_fn,
persistent_workers=self.persistent_workers,
)
else:
dataloader = DataLoader(
val_ds,
batch_size=self.batch_size,
num_workers=self.num_workers,
pin_memory=self.pin_memory,
collate_fn=collate_fn,
persistent_workers=self.persistent_workers,
)
if transform_processor is not None:
transform_processor.collate_fn = dataloader.collate_fn
dataloader.collate_fn = transform_processor
self._model_on_after_batch_transfer_fns = None
return dataloader
def _test_dataloader(self) -> DataLoader:
test_ds: Input = self._test_input
collate_fn = self._test_dataloader_collate_fn
if isinstance(getattr(self, "trainer", None), pl.Trainer):
input_transform = getattr(self.trainer.lightning_module, "input_transform", None)
if input_transform is not None:
input_transform = create_transform(input_transform, RunningStage.TESTING)
collate_fn = _create_collate_input_transform_processors(input_transform, [self.data_fetcher])[0]
transform_processor = None
if isinstance(collate_fn, _InputTransformProcessorV2):
transform_processor = collate_fn
collate_fn = transform_processor.collate_fn
if isinstance(getattr(self, "trainer", None), pl.Trainer):
dataloader = self.trainer.lightning_module.process_test_dataset(
test_ds,
trainer=self.trainer,
batch_size=self.batch_size,
num_workers=self.num_workers,
pin_memory=self.pin_memory,
collate_fn=collate_fn,
persistent_workers=self.persistent_workers,
)
else:
dataloader = DataLoader(
test_ds,
batch_size=self.batch_size,
num_workers=self.num_workers,
pin_memory=self.pin_memory,
collate_fn=collate_fn,
persistent_workers=self.persistent_workers,
)
if transform_processor is not None:
transform_processor.collate_fn = dataloader.collate_fn
dataloader.collate_fn = transform_processor
self._model_on_after_batch_transfer_fns = None
return dataloader
def _predict_dataloader(self) -> DataLoader:
predict_ds: Input = self._predict_input
collate_fn = self._predict_dataloader_collate_fn
if isinstance(getattr(self, "trainer", None), pl.Trainer):
input_transform = getattr(self.trainer.lightning_module, "input_transform", None)
if input_transform is not None:
input_transform = create_transform(input_transform, RunningStage.PREDICTING)
collate_fn = _create_collate_input_transform_processors(input_transform, [self.data_fetcher])[0]
transform_processor = None
if isinstance(collate_fn, _InputTransformProcessorV2):
transform_processor = collate_fn
collate_fn = transform_processor.collate_fn
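# Iterable datasets have no length, so keep the configured batch size; for map-style datasets, cap
# it at the dataset length (minimum 1) so small predict sets still yield a valid batch.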
if isinstance(predict_ds, IterableDataset):
batch_size = self.batch_size
else:
batch_size = min(self.batch_size, len(predict_ds) if len(predict_ds) > 0 else 1)
if isinstance(getattr(self, "trainer", None), pl.Trainer):
dataloader = self.trainer.lightning_module.process_predict_dataset(
predict_ds,
batch_size=batch_size,
num_workers=self.num_workers,
pin_memory=self.pin_memory,
collate_fn=collate_fn,
persistent_workers=self.persistent_workers,
)
else:
dataloader = DataLoader(
predict_ds,
batch_size=batch_size,
num_workers=self.num_workers,
pin_memory=self.pin_memory,
collate_fn=collate_fn,
persistent_workers=self.persistent_workers,
)
if transform_processor is not None:
transform_processor.collate_fn = dataloader.collate_fn
dataloader.collate_fn = transform_processor
self._model_on_after_batch_transfer_fns = None
return dataloader
def _load_model_on_after_batch_transfer_fns(self) -> None:
self._model_on_after_batch_transfer_fns = {}
for stage in [
RunningStage.TRAINING,
RunningStage.VALIDATING,
RunningStage.SANITY_CHECKING,
RunningStage.TESTING,
RunningStage.PREDICTING,
]:
transform = None
if isinstance(getattr(self, "trainer", None), pl.Trainer):
input_transform = getattr(self.trainer.lightning_module, "input_transform", None)
if input_transform is not None:
input_transform = create_transform(
input_transform, stage if stage != RunningStage.SANITY_CHECKING else RunningStage.VALIDATING
)
transform = _create_collate_input_transform_processors(input_transform, [self.data_fetcher])[1]
self._model_on_after_batch_transfer_fns[stage] = transform
def on_after_batch_transfer(self, batch: Any, dataloader_idx: int) -> Any:
if getattr(self, "trainer", None) is None:
return batch
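# Transforms defined on the model can only be resolved once a trainer is attached, so they are
# built lazily on the first batch.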
if self._model_on_after_batch_transfer_fns is None:
self._load_model_on_after_batch_transfer_fns()
stage = self.trainer.state.stage
transform = self._model_on_after_batch_transfer_fns[stage]
if transform is None:
transform = self._on_after_batch_transfer_fns[stage]
if transform:
batch = transform(batch)
return batch
@property
def viz(self) -> BaseVisualization:
return self._viz or DataModule.configure_data_fetcher()
@viz.setter
def viz(self, viz: BaseVisualization) -> None:
self._viz = viz
@staticmethod
def configure_data_fetcher(*args, **kwargs) -> BaseDataFetcher:
"""This function is used to configure a :class:`~flash.core.data.callback.BaseDataFetcher`.
Override with your custom one.
"""
return BaseDataFetcher()
@property
def data_fetcher(self) -> BaseDataFetcher:
"""This property returns the data fetcher."""
return self._data_fetcher or DataModule.configure_data_fetcher()
@data_fetcher.setter
def data_fetcher(self, data_fetcher: BaseDataFetcher) -> None:
self._data_fetcher = data_fetcher
def _reset_iterator(self, stage: str) -> Iterable[Any]:
iter_name = f"_{stage}_iter"
# `num_workers` is forced to 0 so the batch is produced in the main process, where the data fetcher
# callbacks can record it
num_workers = self.num_workers
self.num_workers = 0
dataloader_fn = getattr(self, f"{stage}_dataloader")
iterator = iter(dataloader_fn())
self.num_workers = num_workers
setattr(self, iter_name, iterator)
return iterator
def _show_batch(self, stage: str, func_names: Union[str, List[str]], reset: bool = True) -> None:
"""Fetch a batch from the given stage's dataloader with the data fetcher enabled and visualize the
requested hooks."""
# don't show in CI
if os.getenv("FLASH_TESTING", "0") == "1":
return None
iter_name = f"_{stage}_iter"
if not hasattr(self, iter_name):
self._reset_iterator(stage)
# list of functions to visualise
if isinstance(func_names, str):
func_names = [func_names]
iter_dataloader = getattr(self, iter_name)
with self.data_fetcher.enable():
if reset:
self.data_fetcher.batches[stage] = {}
try:
_ = next(iter_dataloader)
except StopIteration:
iter_dataloader = self._reset_iterator(stage)
_ = next(iter_dataloader)
data_fetcher: BaseVisualization = self.data_fetcher
data_fetcher._show(stage, func_names)
if reset:
self.data_fetcher.batches[stage] = {}
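# Example usage (a sketch; ``datamodule`` stands for any ``DataModule`` built with a train input, and
# "per_sample_transform" assumes the default ``InputTransform`` hook names):
#
#   datamodule.show_train_batch()  # visualize the default "load_sample" hook
#   datamodule.show_train_batch(["load_sample", "per_sample_transform"])  # visualize several hooks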
def show_train_batch(self, hooks_names: Union[str, List[str]] = "load_sample", reset: bool = True) -> None:
"""This function is used to visualize a batch from the train dataloader."""
stage_name: str = _STAGES_PREFIX[RunningStage.TRAINING]
self._show_batch(stage_name, hooks_names, reset=reset)
def show_val_batch(self, hooks_names: Union[str, List[str]] = "load_sample", reset: bool = True) -> None:
"""This function is used to visualize a batch from the validation dataloader."""
stage_name: str = _STAGES_PREFIX[RunningStage.VALIDATING]
self._show_batch(stage_name, hooks_names, reset=reset)
def show_test_batch(self, hooks_names: Union[str, List[str]] = "load_sample", reset: bool = True) -> None:
"""This function is used to visualize a batch from the test dataloader."""
stage_name: str = _STAGES_PREFIX[RunningStage.TESTING]
self._show_batch(stage_name, hooks_names, reset=reset)
def show_predict_batch(self, hooks_names: Union[str, List[str]] = "load_sample", reset: bool = True) -> None:
"""This function is used to visualize a batch from the predict dataloader."""
stage_name: str = _STAGES_PREFIX[RunningStage.PREDICTING]
self._show_batch(stage_name, hooks_names, reset=reset)
def _get_property(self, property_name: str) -> Optional[Any]:
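# Look the attribute up on the train, validation, and test datasets in that order and return the
# first non-``None`` value found.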
train = getattr(self.train_dataset, property_name, None)
val = getattr(self.val_dataset, property_name, None)
test = getattr(self.test_dataset, property_name, None)
filtered = list(filter(lambda x: x is not None, [train, val, test]))
return filtered[0] if len(filtered) > 0 else None
@property
def num_classes(self) -> Optional[int]:
"""Property that returns the number of classes of the datamodule if a multiclass task."""
return self._get_property("num_classes")
@property
def labels(self) -> Optional[List]:
"""Property that returns the labels if this ``DataModule`` contains classification data."""
return self._get_property("labels")
@property
def multi_label(self) -> Optional[bool]:
"""Property that returns ``True`` if this ``DataModule`` contains multi-label data."""
return self._get_property("multi_label")
@property
def inputs(self) -> List[Input]:
"""Property that returns the inputs associated with this ``DataModule``."""
inputs = [self.train_dataset, self.val_dataset, self.test_dataset, self.predict_dataset]
return [input for input in inputs if input]
@staticmethod
def _split_train_val(
train_dataset: Dataset,
val_split: float,
) -> Tuple[Any, Any]:
"""Utility function for splitting the training dataset into a disjoint subset of training samples and
validation samples.
Args:
train_dataset: A instance of a :class:`torch.utils.data.Dataset`.
val_split: A float between 0 and 1 determining the number fraction of samples that should go into the
validation split
Returns:
A tuple containing the training and validation datasets
"""
if not isinstance(val_split, float) or val_split < 0 or val_split > 1:
raise MisconfigurationException(f"`val_split` should be a float between 0 and 1. Found {val_split}.")
if isinstance(train_dataset, IterableInput):
raise MisconfigurationException(
"`val_split` should be `None` when the dataset is built with an IterableDataset."
)
val_num_samples = int(len(train_dataset) * val_split)
indices = list(range(len(train_dataset)))
np.random.shuffle(indices)
val_indices = indices[:val_num_samples]
train_indices = indices[val_num_samples:]
return (
SplitDataset(train_dataset, train_indices, use_duplicated_indices=True),
SplitDataset(train_dataset, val_indices, use_duplicated_indices=True),
)
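# Example (a sketch; ``my_dataset`` is any map-style dataset and is not from the original source):
#
#   train_ds, val_ds = DataModule._split_train_val(my_dataset, val_split=0.2)
#   # len(val_ds) == int(0.2 * len(my_dataset)); the two splits use disjoint, shuffled indices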