DataModule
- class flash.core.data.data_module.DataModule(train_input=None, val_input=None, test_input=None, predict_input=None, data_fetcher=None, transform=<class 'flash.core.data.io.input_transform.InputTransform'>, transform_kwargs=None, val_split=None, batch_size=None, num_workers=0, sampler=None, pin_memory=True, persistent_workers=False)
A basic DataModule class for all Flash tasks. This class includes references to an Input and a BaseDataFetcher.
- Parameters
  - train_input (Optional[Input]) – Input dataset for training. Defaults to None.
  - val_input (Optional[Input]) – Input dataset for validating model performance during training. Defaults to None.
  - test_input (Optional[Input]) – Input dataset to test model performance. Defaults to None.
  - predict_input (Optional[Input]) – Input dataset for predicting. Defaults to None.
  - data_fetcher (Optional[BaseDataFetcher]) – The BaseDataFetcher to attach to the InputTransform. If None, the output from configure_data_fetcher() will be used.
  - transform (TypeVar(INPUT_TRANSFORM_TYPE, Type[flash.core.data.io.input_transform.InputTransform], Callable, Tuple[Union[StrEnum, str], Dict[str, Any]], Union[StrEnum, str], None)) – The InputTransform type to use.
  - transform_kwargs (Optional[Dict]) – Dict of keyword arguments to be provided when instantiating the transforms.
  - val_split (Optional[float]) – An optional float which gives the relative amount of the training dataset to use for the validation dataset.
  - batch_size (Optional[int]) – The batch size to be used by the DataLoader.
  - num_workers (int) – The number of workers to use for parallelized loading.
  - sampler (Union[Callable, Sampler, Type[Sampler], None]) – A sampler following the Sampler type. Will be passed to the DataLoader for the training dataset. Defaults to None.
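To make the val_split parameter concrete, the following is a minimal pure-Python sketch of splitting a dataset by a relative fraction. The helper split_indices is hypothetical and not part of Flash. It assumes the validation share is rounded down and taken from the end without shuffling, and Flash's own splitting logic may select and round differently.

```python
from typing import List, Tuple


def split_indices(num_samples: int, val_split: float) -> Tuple[List[int], List[int]]:
    """Split range(num_samples) into (train, val) index lists.

    Assumption: the validation share is rounded down and taken from the
    end of the index range, with no shuffling. This is an illustration,
    not Flash's implementation.
    """
    if not 0.0 < val_split < 1.0:
        raise ValueError("val_split must be strictly between 0 and 1")
    num_val = int(num_samples * val_split)
    indices = list(range(num_samples))
    cut = num_samples - num_val
    return indices[:cut], indices[cut:]


# With 10 samples and val_split=0.2, two samples go to validation.
train_idx, val_idx = split_indices(num_samples=10, val_split=0.2)
print(len(train_idx), len(val_idx))
```

The two index lists are disjoint and together cover the original training data, which is the property a validation split needs regardless of how the samples are chosen.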
Examples
You can provide the sampler to use for the train dataloader using the sampler argument. The sampler can be a function or type that needs the dataset as an argument:
    >>> from flash.core.data.data_module import DataModule
    >>> from torch.utils.data.sampler import SequentialSampler, WeightedRandomSampler
    >>> datamodule = DataModule(train_input, sampler=SequentialSampler, batch_size=1)
    >>> print(datamodule.train_dataloader().sampler)
    <torch.utils.data.sampler.SequentialSampler object at ...>
Alternatively, you can pass a sampler instance:
    >>> datamodule = DataModule(train_input, sampler=WeightedRandomSampler([0.1, 0.5], 2), batch_size=1)
    >>> print(datamodule.train_dataloader().sampler)
    <torch.utils.data.sampler.WeightedRandomSampler object at ...>
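The difference between passing a sampler type (or function) and passing a sampler instance can be sketched in plain Python. The names InOrderSampler and resolve_sampler below are hypothetical stand-ins, and the resolution rule is an assumption about how such an argument could be handled, not a copy of Flash's code.

```python
import inspect
from typing import Any, Iterator, Sequence


class InOrderSampler:
    """Hypothetical stand-in for a sampler type: yields indices 0..n-1."""

    def __init__(self, dataset: Sequence[Any]) -> None:
        self.dataset = dataset

    def __iter__(self) -> Iterator[int]:
        return iter(range(len(self.dataset)))


def resolve_sampler(sampler: Any, dataset: Sequence[Any]) -> Any:
    """Return a ready-to-use sampler (assumed rule, not Flash's code)."""
    if sampler is None:
        return None
    # A class or a plain function is treated as a factory that needs the dataset.
    if inspect.isclass(sampler) or inspect.isfunction(sampler):
        return sampler(dataset)
    # Anything else is assumed to be an already constructed sampler instance.
    return sampler


dataset = ["a", "b", "c"]
from_type = resolve_sampler(InOrderSampler, dataset)
from_instance = resolve_sampler(InOrderSampler(dataset), dataset)
print(list(from_type), list(from_instance))
```

Passing a type or function defers construction until the dataset is known, while passing an instance is needed for samplers such as WeightedRandomSampler whose constructor takes weights rather than a dataset.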
- static configure_data_fetcher(*args, **kwargs)
This function is used to configure a BaseDataFetcher. Override it to provide a custom one.
- Return type
BaseDataFetcher
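The override pattern, together with the fallback described for the data_fetcher parameter, can be illustrated with stand-in classes. Everything below (SketchFetcher, VerboseFetcher, SketchDataModule, VerboseDataModule) is hypothetical and only mirrors the documented behavior. It does not import Flash or reproduce its implementation.

```python
from typing import Any, Optional


class SketchFetcher:
    """Hypothetical stand-in for BaseDataFetcher."""


class VerboseFetcher(SketchFetcher):
    """Hypothetical custom fetcher."""


class SketchDataModule:
    """Stand-in for DataModule that shows only the fetcher fallback."""

    def __init__(self, data_fetcher: Optional[SketchFetcher] = None) -> None:
        # Fall back to the configured default when no fetcher is passed.
        if data_fetcher is None:
            data_fetcher = self.configure_data_fetcher()
        self.data_fetcher = data_fetcher

    @staticmethod
    def configure_data_fetcher(*args: Any, **kwargs: Any) -> SketchFetcher:
        return SketchFetcher()


class VerboseDataModule(SketchDataModule):
    """Subclass that overrides the static hook with a custom fetcher."""

    @staticmethod
    def configure_data_fetcher(*args: Any, **kwargs: Any) -> SketchFetcher:
        return VerboseFetcher()


print(type(SketchDataModule().data_fetcher).__name__)
print(type(VerboseDataModule().data_fetcher).__name__)
```

Because the hook is looked up through self, the subclass override is picked up without changing the constructor, and an explicitly passed fetcher still takes precedence.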
- property data_fetcher: flash.core.data.callback.BaseDataFetcher
This property returns the data fetcher.
- Return type
BaseDataFetcher
- property input_transform: flash.core.data.io.input_transform.InputTransform
This property returns the input transform.
- Return type
InputTransform
- input_transform_cls
- property labels: Optional[int]
Property that returns the labels if this DataModule contains classification data.
- property multi_label: Optional[bool]
Property that returns True if this DataModule contains multi-label data.
- property num_classes: Optional[int]
Property that returns the number of classes of the datamodule if this is a multiclass task.
- property predict_dataset: Optional[flash.core.data.io.input.Input]
This property returns the prediction dataset.
- show_predict_batch(hooks_names='load_sample', reset=True, limit_nb_samples=None, figsize=(6.4, 4.8))
This function is used to visualize a batch from the prediction dataloader.
- Return type
None
- show_test_batch(hooks_names='load_sample', reset=True, limit_nb_samples=None, figsize=(6.4, 4.8))
This function is used to visualize a batch from the test dataloader.
- Return type
None
- show_train_batch(hooks_names='load_sample', reset=True, limit_nb_samples=None, figsize=(6.4, 4.8))
This function is used to visualize a batch from the train dataloader.
- Return type
None
- show_val_batch(hooks_names='load_sample', reset=True, limit_nb_samples=None, figsize=(6.4, 4.8))
This function is used to visualize a batch from the validation dataloader.
- Return type
None
- property test_dataset: Optional[flash.core.data.io.input.Input]
This property returns the test dataset.
- property train_dataset: Optional[flash.core.data.io.input.Input]
This property returns the train dataset.
- property val_dataset: Optional[flash.core.data.io.input.Input]
This property returns the validation dataset.