
SpeechRecognitionData

class flash.audio.speech_recognition.data.SpeechRecognitionData(train_input=None, val_input=None, test_input=None, predict_input=None, data_fetcher=None, transform=<class 'flash.core.data.io.input_transform.InputTransform'>, transform_kwargs=None, val_split=None, batch_size=None, num_workers=0, sampler=None, pin_memory=True, persistent_workers=False)[source]

The SpeechRecognitionData class is a DataModule with a set of classmethods for loading data for speech recognition.

classmethod from_csv(input_field, target_field=None, train_file=None, val_file=None, test_file=None, predict_file=None, sampling_rate=16000, input_cls=<class 'flash.audio.speech_recognition.input.SpeechRecognitionCSVInput'>, transform=<class 'flash.core.data.io.input_transform.InputTransform'>, transform_kwargs=None, **data_module_kwargs)[source]

Load the SpeechRecognitionData from CSV files containing audio file paths and their corresponding targets.

Input audio file paths will be extracted from the input_field column in the CSV files. The supported file extensions are: .aiff, .au, .avr, .caf, .flac, .mat, .mat4, .mat5, .mpc2k, .ogg, .paf, .pvf, .rf64, .ircam, .voc, .w64, .wav, .nist, and .wavex. The targets will be extracted from the target_field in the CSV files. To learn how to customize the transforms applied for each stage, read our customizing transforms guide.
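Before building the CSV files, it can help to check that every audio path uses one of the supported extensions listed above. The helpers below (`has_supported_extension`, `unsupported_files`) are a hypothetical sketch and not part of Flash. They only compare file suffixes, and they lowercase them first; whether Flash itself treats extensions case-insensitively is an assumption not stated here.

```python
from pathlib import Path
from typing import Iterable, List

# Extensions copied from the list above. Hypothetical helper, not a Flash API.
SUPPORTED_EXTENSIONS = {
    ".aiff", ".au", ".avr", ".caf", ".flac", ".mat", ".mat4", ".mat5",
    ".mpc2k", ".ogg", ".paf", ".pvf", ".rf64", ".ircam", ".voc", ".w64",
    ".wav", ".nist", ".wavex",
}


def has_supported_extension(path: str) -> bool:
    """Return True if the (lowercased) file suffix is in the supported list."""
    return Path(path).suffix.lower() in SUPPORTED_EXTENSIONS


def unsupported_files(paths: Iterable[str]) -> List[str]:
    """Collect the paths whose extension is not in the supported list."""
    return [p for p in paths if not has_supported_extension(p)]


print(unsupported_files(["speech_1.wav", "speech_2.mp3", "speech_3.FLAC"]))  # ['speech_2.mp3']
```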

Parameters
  • input_field (str) – The field (column name) in the CSV files containing the audio file paths.

  • target_field (Optional[str]) – The field (column name) in the CSV files containing the targets.

  • train_file (Optional[str]) – The CSV file to use when training.

  • val_file (Optional[str]) – The CSV file to use when validating.

  • test_file (Optional[str]) – The CSV file to use when testing.

  • predict_file (Optional[str]) – The CSV file to use when predicting.

  • sampling_rate (int) – Sampling rate to use when loading the audio files.

  • input_cls (Type[Input]) – The Input type to use for loading the data.

  • transform (TypeVar(INPUT_TRANSFORM_TYPE, Type[flash.core.data.io.input_transform.InputTransform], Callable, Tuple[Union[StrEnum, str], Dict[str, Any]], Union[StrEnum, str], None)) – The InputTransform type to use.

  • transform_kwargs (Optional[Dict]) – Dict of keyword arguments to be provided when instantiating the transforms.

  • data_module_kwargs (Any) – Additional keyword arguments to provide to the DataModule constructor.

Return type

SpeechRecognitionData

Returns

The constructed SpeechRecognitionData.

Examples

The files can be in Comma Separated Values (CSV) format with either a .csv or .txt extension.

The file train_data.csv contains the following:

speech_files,targets
speech_1.wav,some speech
speech_2.wav,some other speech
speech_3.wav,some more speech

The file predict_data.csv contains the following:

speech_files
predict_speech_1.wav
predict_speech_2.wav
predict_speech_3.wav

>>> from flash import Trainer
>>> from flash.audio import SpeechRecognitionData, SpeechRecognition
>>> datamodule = SpeechRecognitionData.from_csv(
...     "speech_files",
...     "targets",
...     train_file="train_data.csv",
...     predict_file="predict_data.csv",
...     batch_size=2,
... )
>>> model = SpeechRecognition(backbone="patrickvonplaten/wav2vec2_tiny_random_robust")
>>> trainer = Trainer(fast_dev_run=True)
>>> trainer.fit(model, datamodule=datamodule)  
Training...
>>> trainer.predict(model, datamodule=datamodule)  
Predicting...
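The CSV files in the example above contain only text, so they can be produced with Python's standard `csv` module. This is a minimal sketch: it writes the file contents shown above but does not create the referenced `.wav` files, which must already exist for training and prediction to run.

```python
import csv

# Rows taken from the example above; the audio files themselves are not created here.
train_rows = [
    ("speech_1.wav", "some speech"),
    ("speech_2.wav", "some other speech"),
    ("speech_3.wav", "some more speech"),
]
predict_files = ["predict_speech_1.wav", "predict_speech_2.wav", "predict_speech_3.wav"]

# newline="" lets the csv module control line endings, as its documentation recommends.
with open("train_data.csv", "w", newline="") as f:
    writer = csv.writer(f)
    writer.writerow(["speech_files", "targets"])  # header: input_field, target_field
    writer.writerows(train_rows)

with open("predict_data.csv", "w", newline="") as f:
    writer = csv.writer(f)
    writer.writerow(["speech_files"])  # the predict file only needs the input column
    writer.writerows([name] for name in predict_files)
```

For the TSV variant, passing `delimiter="\t"` to `csv.writer` (and using a `.tsv` filename) produces the tab-separated layout instead.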

Alternatively, the files can be in Tab Separated Values (TSV) format with a .tsv extension.

The file train_data.tsv contains the following:

speech_files    targets
speech_1.wav    some speech
speech_2.wav    some other speech
speech_3.wav    some more speech

The file predict_data.tsv contains the following:

speech_files
predict_speech_1.wav
predict_speech_2.wav
predict_speech_3.wav

>>> from flash import Trainer
>>> from flash.audio import SpeechRecognitionData, SpeechRecognition
>>> datamodule = SpeechRecognitionData.from_csv(
...     "speech_files",
...     "targets",
...     train_file="train_data.tsv",
...     predict_file="predict_data.tsv",
...     batch_size=2,
... )
>>> model = SpeechRecognition(backbone="patrickvonplaten/wav2vec2_tiny_random_robust")
>>> trainer = Trainer(fast_dev_run=True)
>>> trainer.fit(model, datamodule=datamodule)  
Training...
>>> trainer.predict(model, datamodule=datamodule)  
Predicting...

classmethod from_datasets(train_dataset=None, val_dataset=None, test_dataset=None, predict_dataset=None, sampling_rate=16000, input_cls=<class 'flash.audio.speech_recognition.input.SpeechRecognitionDatasetInput'>, transform=<class 'flash.core.data.io.input_transform.InputTransform'>, transform_kwargs=None, **data_module_kwargs)[source]

Load the SpeechRecognitionData from PyTorch Dataset objects.

The Dataset objects should be one of the following:

  • A PyTorch Dataset where the __getitem__ returns a tuple: (file_path, target)

  • A PyTorch Dataset where the __getitem__ returns a dict: {"input": file_path, "target": target}

The supported file extensions are: .aiff, .au, .avr, .caf, .flac, .mat, .mat4, .mat5, .mpc2k, .ogg, .paf, .pvf, .rf64, .ircam, .voc, .w64, .wav, .nist, and .wavex. To learn how to customize the transforms applied for each stage, read our customizing transforms guide.
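If you already have a dataset that returns (file_path, target) tuples, one way to obtain the dict form is to wrap it. `DictSampleDataset` below is a hypothetical adapter, not part of Flash. It relies only on the `__getitem__`/`__len__` protocol, so it is written without importing torch here; in practice you would likely subclass `torch.utils.data.Dataset`, as the examples on this page do.

```python
from typing import Any, Dict, Sequence


class DictSampleDataset:
    """Hypothetical adapter: turns (file_path, target) tuples into {"input", "target"} dicts."""

    def __init__(self, base: Sequence[Any]):
        self.base = base

    def __getitem__(self, index: int) -> Dict[str, Any]:
        sample = self.base[index]
        if isinstance(sample, tuple):
            file_path, target = sample
            return {"input": file_path, "target": target}
        # A bare file path (e.g. a predict sample) has no target.
        return {"input": sample}

    def __len__(self) -> int:
        return len(self.base)


pairs = [("speech_1.wav", "some speech"), ("speech_2.wav", "some other speech")]
wrapped = DictSampleDataset(pairs)
print(wrapped[0])  # {'input': 'speech_1.wav', 'target': 'some speech'}
```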

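Parameters
  • train_dataset (Optional[Dataset]) – The Dataset to use when training.

  • val_dataset (Optional[Dataset]) – The Dataset to use when validating.

  • test_dataset (Optional[Dataset]) – The Dataset to use when testing.

  • predict_dataset (Optional[Dataset]) – The Dataset to use when predicting.

  • sampling_rate (int) – Sampling rate to use when loading the audio files.

  • input_cls (Type[Input]) – The Input type to use for loading the data.

  • transform (TypeVar(INPUT_TRANSFORM_TYPE, Type[flash.core.data.io.input_transform.InputTransform], Callable, Tuple[Union[StrEnum, str], Dict[str, Any]], Union[StrEnum, str], None)) – The InputTransform type to use.

  • transform_kwargs (Optional[Dict]) – Dict of keyword arguments to be provided when instantiating the transforms.

  • data_module_kwargs (Any) – Additional keyword arguments to provide to the DataModule constructor.

Return type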

SpeechRecognitionData

Returns

The constructed SpeechRecognitionData.

Examples

A PyTorch Dataset where the __getitem__ returns a tuple: (file_path, target):

>>> from torch.utils.data import Dataset
>>> from flash import Trainer
>>> from flash.audio import SpeechRecognitionData, SpeechRecognition
>>>
>>> class CustomDataset(Dataset):
...     def __init__(self, files, targets=None):
...         self.files = files
...         self.targets = targets
...     def __getitem__(self, index):
...         if self.targets is not None:
...             return self.files[index], self.targets[index]
...         return self.files[index]
...     def __len__(self):
...         return len(self.files)
...
>>>
>>> datamodule = SpeechRecognitionData.from_datasets(
...     train_dataset=CustomDataset(
...         ["speech_1.wav", "speech_2.wav", "speech_3.wav"],
...         ["some speech", "some other speech", "some more speech"],
...     ),
...     predict_dataset=CustomDataset(
...         ["predict_speech_1.wav", "predict_speech_2.wav", "predict_speech_3.wav"],
...     ),
...     batch_size=2,
... )
>>> model = SpeechRecognition(backbone="patrickvonplaten/wav2vec2_tiny_random_robust")
>>> trainer = Trainer(fast_dev_run=True)
>>> trainer.fit(model, datamodule=datamodule)  
Training...
>>> trainer.predict(model, datamodule=datamodule)  
Predicting...

A PyTorch Dataset where the __getitem__ returns a dict: {"input": file_path, "target": target}:

>>> from torch.utils.data import Dataset
>>> from flash import Trainer
>>> from flash.audio import SpeechRecognitionData, SpeechRecognition
>>>
>>> class CustomDataset(Dataset):
...     def __init__(self, files, targets=None):
...         self.files = files
...         self.targets = targets
...     def __getitem__(self, index):
...         if self.targets is not None:
...             return {"input": self.files[index], "target": self.targets[index]}
...         return {"input": self.files[index]}
...     def __len__(self):
...         return len(self.files)
...
>>>
>>> datamodule = SpeechRecognitionData.from_datasets(
...     train_dataset=CustomDataset(
...         ["speech_1.wav", "speech_2.wav", "speech_3.wav"],
...         ["some speech", "some other speech", "some more speech"],
...     ),
...     predict_dataset=CustomDataset(
...         ["predict_speech_1.wav", "predict_speech_2.wav", "predict_speech_3.wav"],
...     ),
...     batch_size=2,
... )
>>> model = SpeechRecognition(backbone="patrickvonplaten/wav2vec2_tiny_random_robust")
>>> trainer = Trainer(fast_dev_run=True)
>>> trainer.fit(model, datamodule=datamodule)  
Training...
>>> trainer.predict(model, datamodule=datamodule)  
Predicting...

classmethod from_files(train_files=None, train_targets=None, val_files=None, val_targets=None, test_files=None, test_targets=None, predict_files=None, sampling_rate=16000, input_cls=<class 'flash.audio.speech_recognition.input.SpeechRecognitionPathsInput'>, transform=<class 'flash.core.data.io.input_transform.InputTransform'>, transform_kwargs=None, **data_module_kwargs)[source]

Load the SpeechRecognitionData from lists of audio files and corresponding lists of targets.

The supported file extensions are: .aiff, .au, .avr, .caf, .flac, .mat, .mat4, .mat5, .mpc2k, .ogg, .paf, .pvf, .rf64, .ircam, .voc, .w64, .wav, .nist, and .wavex. To learn how to customize the transforms applied for each stage, read our customizing transforms guide.
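The lists passed to from_files often come from a directory on disk. The sketch below assumes a hypothetical layout in which every `name.wav` has a sibling `name.txt` holding its transcript; `collect_files_and_targets` is an illustrative helper, not part of Flash. It keeps the two returned lists aligned by index, which is what the train_files/train_targets pairing relies on.

```python
import tempfile
from pathlib import Path
from typing import List, Tuple


def collect_files_and_targets(directory: str) -> Tuple[List[str], List[str]]:
    """Pair each .wav file with the text of a same-named .txt file (hypothetical layout)."""
    files, targets = [], []
    for audio in sorted(Path(directory).glob("*.wav")):
        transcript = audio.with_suffix(".txt")
        if not transcript.exists():
            continue  # skip audio without a transcript so the lists stay aligned
        files.append(str(audio))
        targets.append(transcript.read_text().strip())
    return files, targets


# Demo on a throwaway directory: a.wav has a transcript, b.wav does not.
with tempfile.TemporaryDirectory() as tmp:
    Path(tmp, "a.wav").write_bytes(b"")
    Path(tmp, "a.txt").write_text("some speech\n")
    Path(tmp, "b.wav").write_bytes(b"")
    demo_files, demo_targets = collect_files_and_targets(tmp)
    demo_names = [Path(p).name for p in demo_files]

print(demo_names, demo_targets)  # ['a.wav'] ['some speech']
```

The two lists could then be passed as train_files and train_targets, in the same way as the literal lists in the example below.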

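Parameters
  • train_files (Optional[Sequence[str]]) – The list of audio files to use when training.

  • train_targets (Optional[Sequence[str]]) – The list of targets to use when training.

  • val_files (Optional[Sequence[str]]) – The list of audio files to use when validating.

  • val_targets (Optional[Sequence[str]]) – The list of targets to use when validating.

  • test_files (Optional[Sequence[str]]) – The list of audio files to use when testing.

  • test_targets (Optional[Sequence[str]]) – The list of targets to use when testing.

  • predict_files (Optional[Sequence[str]]) – The list of audio files to use when predicting.

  • sampling_rate (int) – Sampling rate to use when loading the audio files.

  • input_cls (Type[Input]) – The Input type to use for loading the data.

  • transform (TypeVar(INPUT_TRANSFORM_TYPE, Type[flash.core.data.io.input_transform.InputTransform], Callable, Tuple[Union[StrEnum, str], Dict[str, Any]], Union[StrEnum, str], None)) – The InputTransform type to use.

  • transform_kwargs (Optional[Dict]) – Dict of keyword arguments to be provided when instantiating the transforms.

  • data_module_kwargs (Any) – Additional keyword arguments to provide to the DataModule constructor.

Return type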

SpeechRecognitionData

Returns

The constructed SpeechRecognitionData.

Examples

>>> from flash import Trainer
>>> from flash.audio import SpeechRecognitionData, SpeechRecognition
>>> datamodule = SpeechRecognitionData.from_files(
...     train_files=["speech_1.wav", "speech_2.wav", "speech_3.wav"],
...     train_targets=["some speech", "some other speech", "some more speech"],
...     predict_files=["predict_speech_1.wav", "predict_speech_2.wav", "predict_speech_3.wav"],
...     batch_size=2,
... )
>>> model = SpeechRecognition(backbone="patrickvonplaten/wav2vec2_tiny_random_robust")
>>> trainer = Trainer(fast_dev_run=True)
>>> trainer.fit(model, datamodule=datamodule)  
Training...
>>> trainer.predict(model, datamodule=datamodule)  
Predicting...

classmethod from_json(input_field, target_field=None, train_file=None, val_file=None, test_file=None, predict_file=None, sampling_rate=16000, field=None, input_cls=<class 'flash.audio.speech_recognition.input.SpeechRecognitionJSONInput'>, transform=<class 'flash.core.data.io.input_transform.InputTransform'>, transform_kwargs=None, **data_module_kwargs)[source]

Load the SpeechRecognitionData from JSON files containing audio file paths and their corresponding targets.

Input audio file paths will be extracted from the input_field field in the JSON files. The supported file extensions are: .aiff, .au, .avr, .caf, .flac, .mat, .mat4, .mat5, .mpc2k, .ogg, .paf, .pvf, .rf64, .ircam, .voc, .w64, .wav, .nist, and .wavex. The targets will be extracted from the target_field field in the JSON files. To learn how to customize the transforms applied for each stage, read our customizing transforms guide.

Parameters
  • input_field (str) – The field in the JSON files containing the audio file paths.

  • target_field (Optional[str]) – The field in the JSON files containing the targets.

  • train_file (Optional[str]) – The JSON file to use when training.

  • val_file (Optional[str]) – The JSON file to use when validating.

  • test_file (Optional[str]) – The JSON file to use when testing.

  • predict_file (Optional[str]) – The JSON file to use when predicting.

  • sampling_rate (int) – Sampling rate to use when loading the audio files.

  • field (Optional[str]) – The field that holds the data in the JSON file.

  • input_cls (Type[Input]) – The Input type to use for loading the data.

  • transform (TypeVar(INPUT_TRANSFORM_TYPE, Type[flash.core.data.io.input_transform.InputTransform], Callable, Tuple[Union[StrEnum, str], Dict[str, Any]], Union[StrEnum, str], None)) – The InputTransform type to use.

  • transform_kwargs (Optional[Dict]) – Dict of keyword arguments to be provided when instantiating the transforms.

  • data_module_kwargs (Any) – Additional keyword arguments to provide to the DataModule constructor.
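The example below stores one JSON object per line. The field parameter is meant for files where the records instead sit under a single top-level key. This reading is based on how the Hugging Face `datasets` JSON loader treats its own `field` argument; that Flash forwards field to that loader unchanged is an assumption. The sketch writes both layouts with the standard `json` module; the key `"data"` and the filename `train_data_nested.json` are hypothetical.

```python
import json

records = [
    {"speech_files": "speech_1.wav", "targets": "some speech"},
    {"speech_files": "speech_2.wav", "targets": "some other speech"},
    {"speech_files": "speech_3.wav", "targets": "some more speech"},
]

# Layout 1: JSON Lines, one object per line (as in the example below); field=None.
with open("train_data.json", "w") as f:
    for record in records:
        f.write(json.dumps(record) + "\n")

# Layout 2: records nested under one key; this is the case field="data" is intended for.
with open("train_data_nested.json", "w") as f:
    json.dump({"data": records}, f)
```

With the nested file, the call would mirror the example below with field="data" added.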

Return type

SpeechRecognitionData

Returns

The constructed SpeechRecognitionData.

Examples

The file train_data.json contains the following:

{"speech_files":"speech_1.wav","targets":"some speech"}
{"speech_files":"speech_2.wav","targets":"some other speech"}
{"speech_files":"speech_3.wav","targets":"some more speech"}

The file predict_data.json contains the following:

{"speech_files":"predict_speech_1.wav"}
{"speech_files":"predict_speech_2.wav"}
{"speech_files":"predict_speech_3.wav"}

>>> from flash import Trainer
>>> from flash.audio import SpeechRecognitionData, SpeechRecognition
>>> datamodule = SpeechRecognitionData.from_json(
...     "speech_files",
...     "targets",
...     train_file="train_data.json",
...     predict_file="predict_data.json",
...     batch_size=2,
... )  
Downloading...
>>> model = SpeechRecognition(backbone="patrickvonplaten/wav2vec2_tiny_random_robust")
>>> trainer = Trainer(fast_dev_run=True)
>>> trainer.fit(model, datamodule=datamodule)  
Training...
>>> trainer.predict(model, datamodule=datamodule)  
Predicting...

input_transform_cls

alias of flash.core.data.io.input_transform.InputTransform