SpeechRecognitionData

class flash.audio.speech_recognition.data.SpeechRecognitionData(train_input=None, val_input=None, test_input=None, predict_input=None, data_fetcher=None, val_split=None, batch_size=None, num_workers=0, sampler=None, pin_memory=True, persistent_workers=False)[source]

The SpeechRecognitionData class is a DataModule with a set of classmethods for loading data for speech recognition.

classmethod from_csv(input_field, target_field=None, train_file=None, val_file=None, test_file=None, predict_file=None, sampling_rate=16000, train_transform=<class 'flash.core.data.io.input_transform.InputTransform'>, val_transform=<class 'flash.core.data.io.input_transform.InputTransform'>, test_transform=<class 'flash.core.data.io.input_transform.InputTransform'>, predict_transform=<class 'flash.core.data.io.input_transform.InputTransform'>, input_cls=<class 'flash.audio.speech_recognition.input.SpeechRecognitionCSVInput'>, transform_kwargs=None, **data_module_kwargs)[source]

Load the SpeechRecognitionData from CSV files containing audio file paths and their corresponding targets.

Input audio file paths will be extracted from the input_field column in the CSV files. The supported file extensions are: wav, ogg, flac, mat, and mp3. The targets will be extracted from the target_field in the CSV files. To learn how to customize the transforms applied for each stage, read our customizing transforms guide.
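As a sketch, a training CSV of the expected shape could be written with the standard csv module (the column names "speech_files" and "targets" are arbitrary; whatever names you use are passed as input_field and target_field):

```python
import csv

# Write a minimal training CSV: one column of audio file paths,
# one column of transcription targets.
rows = [
    {"speech_files": "speech_1.wav", "targets": "some speech"},
    {"speech_files": "speech_2.wav", "targets": "some other speech"},
]
with open("train_data.csv", "w", newline="") as f:
    writer = csv.DictWriter(f, fieldnames=["speech_files", "targets"])
    writer.writeheader()
    writer.writerows(rows)
```

A prediction CSV is the same minus the target column.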

Parameters
Return type

SpeechRecognitionData

Returns

The constructed SpeechRecognitionData.

Examples

The file train_data.csv contains the following:

speech_files,targets
speech_1.wav,some speech
speech_2.wav,some other speech
speech_3.wav,some more speech

The file predict_data.csv contains the following:

speech_files
predict_speech_1.wav
predict_speech_2.wav
predict_speech_3.wav
>>> from flash import Trainer
>>> from flash.audio import SpeechRecognitionData, SpeechRecognition
>>> datamodule = SpeechRecognitionData.from_csv(
...     "speech_files",
...     "targets",
...     train_file="train_data.csv",
...     predict_file="predict_data.csv",
...     batch_size=2,
... )  
Downloading...
>>> model = SpeechRecognition(backbone="patrickvonplaten/wav2vec2_tiny_random_robust")
>>> trainer = Trainer(fast_dev_run=True)
>>> trainer.fit(model, datamodule=datamodule)  
Training...
>>> trainer.predict(model, datamodule=datamodule)  
Predicting...
classmethod from_datasets(train_dataset=None, val_dataset=None, test_dataset=None, predict_dataset=None, train_transform=<class 'flash.core.data.io.input_transform.InputTransform'>, val_transform=<class 'flash.core.data.io.input_transform.InputTransform'>, test_transform=<class 'flash.core.data.io.input_transform.InputTransform'>, predict_transform=<class 'flash.core.data.io.input_transform.InputTransform'>, sampling_rate=16000, input_cls=<class 'flash.audio.speech_recognition.input.SpeechRecognitionDatasetInput'>, transform_kwargs=None, **data_module_kwargs)[source]

Load the SpeechRecognitionData from PyTorch Dataset objects.

The Dataset objects should be one of the following:

  • A PyTorch Dataset where the __getitem__ returns a tuple: (file_path, target)

  • A PyTorch Dataset where the __getitem__ returns a dict: {"input": file_path, "target": target}

The supported file extensions are: wav, ogg, flac, mat, and mp3. To learn how to customize the transforms applied for each stage, read our customizing transforms guide.

Parameters
Return type

SpeechRecognitionData

Returns

The constructed SpeechRecognitionData.

Examples

A PyTorch Dataset where the __getitem__ returns a tuple: (file_path, target):

>>> from torch.utils.data import Dataset
>>> from flash import Trainer
>>> from flash.audio import SpeechRecognitionData, SpeechRecognition
>>>
>>> class CustomDataset(Dataset):
...     def __init__(self, files, targets=None):
...         self.files = files
...         self.targets = targets
...     def __getitem__(self, index):
...         if self.targets is not None:
...             return self.files[index], self.targets[index]
...         return self.files[index]
...     def __len__(self):
...         return len(self.files)
...
>>>
>>> datamodule = SpeechRecognitionData.from_datasets(
...     train_dataset=CustomDataset(
...         ["speech_1.wav", "speech_2.wav", "speech_3.wav"],
...         ["some speech", "some other speech", "some more speech"],
...     ),
...     predict_dataset=CustomDataset(
...         ["predict_speech_1.wav", "predict_speech_2.wav", "predict_speech_3.wav"],
...     ),
...     batch_size=2,
... )
>>> model = SpeechRecognition(backbone="patrickvonplaten/wav2vec2_tiny_random_robust")
>>> trainer = Trainer(fast_dev_run=True)
>>> trainer.fit(model, datamodule=datamodule)  
Training...
>>> trainer.predict(model, datamodule=datamodule)  
Predicting...

A PyTorch Dataset where the __getitem__ returns a dict: {"input": file_path, "target": target}:

>>> from torch.utils.data import Dataset
>>> from flash import Trainer
>>> from flash.audio import SpeechRecognitionData, SpeechRecognition
>>>
>>> class CustomDataset(Dataset):
...     def __init__(self, files, targets=None):
...         self.files = files
...         self.targets = targets
...     def __getitem__(self, index):
...         if self.targets is not None:
...             return {"input": self.files[index], "target": self.targets[index]}
...         return {"input": self.files[index]}
...     def __len__(self):
...         return len(self.files)
...
>>>
>>> datamodule = SpeechRecognitionData.from_datasets(
...     train_dataset=CustomDataset(
...         ["speech_1.wav", "speech_2.wav", "speech_3.wav"],
...         ["some speech", "some other speech", "some more speech"],
...     ),
...     predict_dataset=CustomDataset(
...         ["predict_speech_1.wav", "predict_speech_2.wav", "predict_speech_3.wav"],
...     ),
...     batch_size=2,
... )
>>> model = SpeechRecognition(backbone="patrickvonplaten/wav2vec2_tiny_random_robust")
>>> trainer = Trainer(fast_dev_run=True)
>>> trainer.fit(model, datamodule=datamodule)  
Training...
>>> trainer.predict(model, datamodule=datamodule)  
Predicting...
classmethod from_files(train_files=None, train_targets=None, val_files=None, val_targets=None, test_files=None, test_targets=None, predict_files=None, sampling_rate=16000, train_transform=<class 'flash.core.data.io.input_transform.InputTransform'>, val_transform=<class 'flash.core.data.io.input_transform.InputTransform'>, test_transform=<class 'flash.core.data.io.input_transform.InputTransform'>, predict_transform=<class 'flash.core.data.io.input_transform.InputTransform'>, input_cls=<class 'flash.audio.speech_recognition.input.SpeechRecognitionPathsInput'>, transform_kwargs=None, **data_module_kwargs)[source]

Load the SpeechRecognitionData from lists of audio files and corresponding lists of targets.

The supported file extensions are: wav, ogg, flac, mat, and mp3. To learn how to customize the transforms applied for each stage, read our customizing transforms guide.
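One way to assemble the file lists is to glob a directory for the supported extensions with pathlib. This is a sketch, not part of the Flash API; the directory layout is hypothetical:

```python
from pathlib import Path

# Extensions supported by SpeechRecognitionData.
SUPPORTED = {".wav", ".ogg", ".flac", ".mat", ".mp3"}

def collect_audio_files(root):
    """Return a sorted list of audio file paths under ``root`` whose
    extension (case-insensitive) is one of the supported formats."""
    root = Path(root)
    return sorted(str(p) for p in root.rglob("*") if p.suffix.lower() in SUPPORTED)
```

The resulting list can be passed directly as train_files or predict_files, with a parallel list of transcriptions as train_targets.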

Parameters
Return type

SpeechRecognitionData

Returns

The constructed SpeechRecognitionData.

Examples

>>> from flash import Trainer
>>> from flash.audio import SpeechRecognitionData, SpeechRecognition
>>> datamodule = SpeechRecognitionData.from_files(
...     train_files=["speech_1.wav", "speech_2.wav", "speech_3.wav"],
...     train_targets=["some speech", "some other speech", "some more speech"],
...     predict_files=["predict_speech_1.wav", "predict_speech_2.wav", "predict_speech_3.wav"],
...     batch_size=2,
... )
>>> model = SpeechRecognition(backbone="patrickvonplaten/wav2vec2_tiny_random_robust")
>>> trainer = Trainer(fast_dev_run=True)
>>> trainer.fit(model, datamodule=datamodule)  
Training...
>>> trainer.predict(model, datamodule=datamodule)  
Predicting...
classmethod from_json(input_field, target_field=None, train_file=None, val_file=None, test_file=None, predict_file=None, sampling_rate=16000, field=None, train_transform=<class 'flash.core.data.io.input_transform.InputTransform'>, val_transform=<class 'flash.core.data.io.input_transform.InputTransform'>, test_transform=<class 'flash.core.data.io.input_transform.InputTransform'>, predict_transform=<class 'flash.core.data.io.input_transform.InputTransform'>, input_cls=<class 'flash.audio.speech_recognition.input.SpeechRecognitionJSONInput'>, transform_kwargs=None, **data_module_kwargs)[source]

Load the SpeechRecognitionData from JSON files containing audio file paths and their corresponding targets.

Input audio file paths will be extracted from the input_field field in the JSON files. The supported file extensions are: wav, ogg, flac, mat, and mp3. The targets will be extracted from the target_field field in the JSON files. To learn how to customize the transforms applied for each stage, read our customizing transforms guide.
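The example files below are in JSON Lines format (one JSON object per line). As a sketch, such a file could be produced with the standard json module (the field names "speech_files" and "targets" are arbitrary; whatever names you use are passed as input_field and target_field):

```python
import json

# Write a minimal JSON Lines training file: one object per line,
# each with an audio path and its transcription.
samples = [
    {"speech_files": "speech_1.wav", "targets": "some speech"},
    {"speech_files": "speech_2.wav", "targets": "some other speech"},
]
with open("train_data.json", "w") as f:
    for sample in samples:
        f.write(json.dumps(sample) + "\n")
```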

Parameters
Return type

SpeechRecognitionData

Returns

The constructed SpeechRecognitionData.

Examples

The file train_data.json contains the following:

{"speech_files":"speech_1.wav","targets":"some speech"}
{"speech_files":"speech_2.wav","targets":"some other speech"}
{"speech_files":"speech_3.wav","targets":"some more speech"}

The file predict_data.json contains the following:

{"speech_files":"predict_speech_1.wav"}
{"speech_files":"predict_speech_2.wav"}
{"speech_files":"predict_speech_3.wav"}
>>> from flash import Trainer
>>> from flash.audio import SpeechRecognitionData, SpeechRecognition
>>> datamodule = SpeechRecognitionData.from_json(
...     "speech_files",
...     "targets",
...     train_file="train_data.json",
...     predict_file="predict_data.json",
...     batch_size=2,
... )  
Downloading...
>>> model = SpeechRecognition(backbone="patrickvonplaten/wav2vec2_tiny_random_robust")
>>> trainer = Trainer(fast_dev_run=True)
>>> trainer.fit(model, datamodule=datamodule)  
Training...
>>> trainer.predict(model, datamodule=datamodule)  
Predicting...
input_transform_cls

alias of flash.core.data.io.input_transform.InputTransform

Read the Docs v: 0.7.3