DataModule

class flash.core.data.data_module.DataModule(train_input=None, val_input=None, test_input=None, predict_input=None, data_fetcher=None, val_split=None, batch_size=None, num_workers=0, sampler=None, pin_memory=True, persistent_workers=False)

A basic DataModule class for all Flash tasks. This class includes references to an Input and a BaseDataFetcher.

Parameters
  • train_input (Optional[Input]) – Input dataset for training. Defaults to None.

  • val_input (Optional[Input]) – Input dataset for validating model performance during training. Defaults to None.

  • test_input (Optional[Input]) – Input dataset to test model performance. Defaults to None.

  • predict_input (Optional[Input]) – Input dataset for predicting. Defaults to None.

  • data_fetcher (Optional[BaseDataFetcher]) – The BaseDataFetcher to attach to the InputTransform. If None, the output from configure_data_fetcher() will be used.

  • val_split (Optional[float]) – An optional float which gives the relative amount of the training dataset to use for the validation dataset.

  • batch_size (Optional[int]) – The batch size to be used by the DataLoader.

  • num_workers (int) – The number of workers to use for parallelized loading.

  • sampler (Union[Callable, Sampler, Type[Sampler], None]) – A sampler following the Sampler type. Will be passed to the DataLoader for the training dataset. Defaults to None.
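
As an illustration of what a relative val_split means, the following is a minimal sketch in plain Python, not Flash's implementation. The helper name split_indices, the seed argument, the shuffling, and the round-down behaviour are all assumptions made for this example. How Flash itself selects and rounds the validation subset is not specified here.

```python
import random
from typing import List, Tuple


def split_indices(num_samples: int, val_split: float, seed: int = 0) -> Tuple[List[int], List[int]]:
    """Hypothetical helper: partition range(num_samples) into (train, val) index lists."""
    if not 0.0 < val_split < 1.0:
        raise ValueError("val_split must be strictly between 0 and 1")
    indices = list(range(num_samples))
    # Shuffle with a fixed seed so the split is reproducible (an assumption of this sketch).
    random.Random(seed).shuffle(indices)
    # Round down: the validation set receives at most val_split of the samples.
    num_val = int(num_samples * val_split)
    return indices[num_val:], indices[:num_val]


train_idx, val_idx = split_indices(num_samples=10, val_split=0.2)
print(len(train_idx), len(val_idx))  # prints: 8 2
```

With val_split=0.2 and ten training samples, two samples are held out for validation and the remaining eight are used for training.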

Examples

You can provide the sampler to use for the train dataloader using the sampler argument. The sampler can be a function or type that needs the dataset as an argument:

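>>> from torch.utils.data import SequentialSampler
>>> from flash.core.data.data_module import DataModule
>>> # train_input is assumed to be an existing Input instance
>>> datamodule = DataModule(train_input, sampler=SequentialSampler, batch_size=1)
>>> print(datamodule.train_dataloader().sampler)
<torch.utils.data.sampler.SequentialSampler object at ...>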

Alternatively, you can pass a sampler instance:

>>> from torch.utils.data import WeightedRandomSampler
>>> datamodule = DataModule(train_input, sampler=WeightedRandomSampler([0.1, 0.5], 2), batch_size=1)
>>> print(datamodule.train_dataloader().sampler)
<torch.utils.data.sampler.WeightedRandomSampler object at ...>

static configure_data_fetcher(*args, **kwargs)

This function is used to configure a BaseDataFetcher.

Override with your custom one.

Return type

BaseDataFetcher
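
The override pattern described above can be sketched without installing Flash. In the snippet below, BaseDataFetcher and DataModule are simplified stand-ins, not the real classes, and LoggingFetcher and MyDataModule are hypothetical names. The only behaviour modelled is the documented fallback: when data_fetcher is None, the output of configure_data_fetcher() is used.

```python
class BaseDataFetcher:
    """Stand-in for flash.core.data.callback.BaseDataFetcher (not the real class)."""


class DataModule:
    """Stand-in for Flash's DataModule, reduced to the data fetcher logic."""

    def __init__(self, data_fetcher=None):
        # Fall back to configure_data_fetcher() only when no fetcher is passed.
        if data_fetcher is None:
            data_fetcher = self.configure_data_fetcher()
        self._data_fetcher = data_fetcher

    @staticmethod
    def configure_data_fetcher(*args, **kwargs):
        return BaseDataFetcher()

    @property
    def data_fetcher(self):
        return self._data_fetcher


class LoggingFetcher(BaseDataFetcher):
    """Hypothetical custom fetcher used only for this illustration."""


class MyDataModule(DataModule):
    @staticmethod
    def configure_data_fetcher(*args, **kwargs):
        # Override with a custom fetcher, as the method's description suggests.
        return LoggingFetcher()


default_dm = MyDataModule()
explicit_dm = MyDataModule(data_fetcher=BaseDataFetcher())
print(type(default_dm.data_fetcher).__name__)   # prints: LoggingFetcher
print(type(explicit_dm.data_fetcher).__name__)  # prints: BaseDataFetcher
```

An explicitly passed data_fetcher takes precedence, so the override only affects modules constructed without one.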

property data_fetcher: flash.core.data.callback.BaseDataFetcher

This property returns the data fetcher.

Return type

BaseDataFetcher

input_transform_cls

alias of flash.core.data.io.input_transform.InputTransform

property inputs: Optional[Union[flash.core.data.io.input.Input, List[flash.core.data.io.input.InputBase]]]

Property that returns the inputs associated with this DataModule.

Return type

Union[Input, List[InputBase], None]

property labels: Optional[int]

Property that returns the labels if this DataModule contains classification data.

Return type

Optional[int]

property multi_label: Optional[bool]

Property that returns True if this DataModule contains multi-label data.

Return type

Optional[bool]

property num_classes: Optional[int]

Property that returns the number of classes in this DataModule if it is used for a multiclass task.

Return type

Optional[int]

property predict_dataset: Optional[flash.core.data.io.input.Input]

This property returns the predict dataset.

Return type

Optional[Input]

show_predict_batch(hooks_names='load_sample', reset=True)

This function is used to visualize a batch from the predict dataloader.

Return type

None

show_test_batch(hooks_names='load_sample', reset=True)

This function is used to visualize a batch from the test dataloader.

Return type

None

show_train_batch(hooks_names='load_sample', reset=True)

This function is used to visualize a batch from the train dataloader.

Return type

None

show_val_batch(hooks_names='load_sample', reset=True)

This function is used to visualize a batch from the validation dataloader.

Return type

None

property test_dataset: Optional[flash.core.data.io.input.Input]

This property returns the test dataset.

Return type

Optional[Input]

property train_dataset: Optional[flash.core.data.io.input.Input]

This property returns the train dataset.

Return type

Optional[Input]

property val_dataset: Optional[flash.core.data.io.input.Input]

This property returns the validation dataset.

Return type

Optional[Input]
