DataModule¶
- class flash.core.data.data_module.DataModule(train_input=None, val_input=None, test_input=None, predict_input=None, data_fetcher=None, val_split=None, batch_size=None, num_workers=0, sampler=None, pin_memory=True, persistent_workers=False)[source]¶
A basic DataModule class for all Flash tasks. This class includes references to an Input and a BaseDataFetcher.

- Parameters
  - train_input¶ (Optional[Input]) – Input dataset for training. Defaults to None.
  - val_input¶ (Optional[Input]) – Input dataset for validating model performance during training. Defaults to None.
  - test_input¶ (Optional[Input]) – Input dataset to test model performance. Defaults to None.
  - predict_input¶ (Optional[Input]) – Input dataset for predicting. Defaults to None.
  - data_fetcher¶ (Optional[BaseDataFetcher]) – The BaseDataFetcher to attach to the InputTransform. If None, the output from configure_data_fetcher() will be used.
  - val_split¶ (Optional[float]) – An optional float which gives the relative amount of the training dataset to use for the validation dataset.
  - batch_size¶ (Optional[int]) – The batch size to be used by the DataLoader.
  - num_workers¶ (int) – The number of workers to use for parallelized loading.
  - sampler¶ (Union[Callable, Sampler, Type[Sampler], None]) – A sampler following the Sampler type. Will be passed to the DataLoader for the training dataset. Defaults to None.
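The val_split fraction carves a validation set out of the training input. As a rough sketch of the semantics (the exact rounding and shuffling behaviour inside Flash is an assumption here, and split_sizes is a hypothetical helper, not part of the API):

```python
# Hypothetical illustration of how a val_split fraction divides a training
# set. The truncating-int rounding rule below is an assumption, not
# necessarily Flash's exact implementation.
def split_sizes(n_samples, val_split):
    """Return (train_size, val_size) for a dataset of n_samples."""
    n_val = int(n_samples * val_split)
    return n_samples - n_val, n_val

train_n, val_n = split_sizes(100, 0.2)
print(train_n, val_n)  # 80 20
```

So with val_split=0.2, roughly one fifth of the training samples are held out for validation and the rest remain for training.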
Examples
You can provide the sampler to use for the train dataloader using the sampler argument. The sampler can be a function or type that takes the dataset as an argument:

>>> datamodule = DataModule(train_input, sampler=SequentialSampler, batch_size=1)
>>> print(datamodule.train_dataloader().sampler)
<torch.utils.data.sampler.SequentialSampler object at ...>

Alternatively, you can pass a sampler instance:

>>> datamodule = DataModule(train_input, sampler=WeightedRandomSampler([0.1, 0.5], 2), batch_size=1)
>>> print(datamodule.train_dataloader().sampler)
<torch.utils.data.sampler.WeightedRandomSampler object at ...>
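To illustrate the callable form of the sampler argument without the Flash classes, here is the same pattern against a plain PyTorch DataLoader. The factory function make_sampler is a hypothetical stand-in for whatever callable you would hand to DataModule, which calls it with the train dataset:

```python
import torch
from torch.utils.data import DataLoader, SequentialSampler, TensorDataset

dataset = TensorDataset(torch.arange(6).float())

# A callable that takes the dataset and returns a sampler instance — the
# same shape of argument DataModule accepts alongside a Sampler type or a
# ready-made Sampler instance.
def make_sampler(ds):
    return SequentialSampler(ds)

loader = DataLoader(dataset, sampler=make_sampler(dataset), batch_size=2)
order = [int(x) for batch in loader for x in batch[0]]
print(order)  # sequential sampling preserves dataset order
```

Passing the type (sampler=SequentialSampler) and passing a factory callable are equivalent here: both are invoked with the dataset to build the sampler.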
- static configure_data_fetcher(*args, **kwargs)[source]¶
This function is used to configure a BaseDataFetcher. Override it with your custom one.
- Return type
  BaseDataFetcher
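Overriding follows the usual static-method pattern. A minimal sketch with stand-in classes (BaseFetcher, PlainModule, and MyDataModule are hypothetical placeholders for flash's BaseDataFetcher and DataModule, not the real classes):

```python
class BaseFetcher:
    """Stand-in for flash's BaseDataFetcher."""
    name = "base"

class CustomFetcher(BaseFetcher):
    """A custom fetcher a user might want to attach."""
    name = "custom"

class PlainModule:
    # Mirrors the default behaviour: return the stock fetcher.
    @staticmethod
    def configure_data_fetcher(*args, **kwargs):
        return BaseFetcher()

class MyDataModule(PlainModule):
    # Override the static method so the module builds a custom fetcher
    # whenever no data_fetcher argument is supplied at construction time.
    @staticmethod
    def configure_data_fetcher(*args, **kwargs):
        return CustomFetcher()

print(MyDataModule.configure_data_fetcher().name)
```

Because the method is static, the override needs no access to instance state; it simply returns the fetcher the module should use by default.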
- property data_fetcher: flash.core.data.callback.BaseDataFetcher¶
This property returns the data fetcher.
- Return type
  BaseDataFetcher
- input_transform_cls¶
- property inputs: Optional[Union[flash.core.data.io.input.Input, List[flash.core.data.io.input.InputBase]]]¶
Property that returns the inputs associated with this DataModule.
- property labels: Optional[int]¶
Property that returns the labels if this DataModule contains classification data.
- property multi_label: Optional[bool]¶
Property that returns True if this DataModule contains multi-label data.
- property num_classes: Optional[int]¶
Property that returns the number of classes of the DataModule if it is a multiclass task.
- property predict_dataset: Optional[flash.core.data.io.input.Input]¶
This property returns the prediction dataset.
- show_predict_batch(hooks_names='load_sample', reset=True)[source]¶
This function is used to visualize a batch from the prediction dataloader.
- Return type
  None
- show_test_batch(hooks_names='load_sample', reset=True)[source]¶
This function is used to visualize a batch from the test dataloader.
- Return type
  None
- show_train_batch(hooks_names='load_sample', reset=True)[source]¶
This function is used to visualize a batch from the train dataloader.
- Return type
  None
- show_val_batch(hooks_names='load_sample', reset=True)[source]¶
This function is used to visualize a batch from the validation dataloader.
- Return type
  None
- property test_dataset: Optional[flash.core.data.io.input.Input]¶
This property returns the test dataset.
- property train_dataset: Optional[flash.core.data.io.input.Input]¶
This property returns the train dataset.
- property val_dataset: Optional[flash.core.data.io.input.Input]¶
This property returns the validation dataset.