SpeechRecognitionData

class flash.audio.speech_recognition.data.SpeechRecognitionData(train_input=None, val_input=None, test_input=None, predict_input=None, data_fetcher=None, transform=<class 'flash.core.data.io.input_transform.InputTransform'>, transform_kwargs=None, val_split=None, batch_size=None, num_workers=0, sampler=None, pin_memory=True, persistent_workers=False)[source]

The SpeechRecognitionData class is a DataModule with a set of classmethods for loading data for speech recognition.

classmethod from_csv(input_field, target_field=None, train_file=None, val_file=None, test_file=None, predict_file=None, sampling_rate=16000, input_cls=<class 'flash.audio.speech_recognition.input.SpeechRecognitionCSVInput'>, transform=<class 'flash.core.data.io.input_transform.InputTransform'>, transform_kwargs=None, **data_module_kwargs)[source]

Load the SpeechRecognitionData from CSV files containing audio file paths and their corresponding targets.

Input audio file paths will be extracted from the input_field column in the CSV files. The supported file extensions are: .aiff, .au, .avr, .caf, .flac, .mat, .mat4, .mat5, .mpc2k, .ogg, .paf, .pvf, .rf64, .ircam, .voc, .w64, .wav, .nist, and .wavex. The targets will be extracted from the target_field in the CSV files. To learn how to customize the transforms applied for each stage, read our customizing transforms guide.

Parameters
  • input_field (str) – The field (column name) in the CSV files containing the audio file paths.

  • target_field (Optional[str]) – The field (column name) in the CSV files containing the targets.

  • train_file (Optional[str]) – The CSV file to use when training.

  • val_file (Optional[str]) – The CSV file to use when validating.

  • test_file (Optional[str]) – The CSV file to use when testing.

  • predict_file (Optional[str]) – The CSV file to use when predicting.

  • sampling_rate (int) – Sampling rate to use when loading the audio files.

  • input_cls (Type[Input]) – The Input type to use for loading the data.

  • transform (TypeVar(INPUT_TRANSFORM_TYPE, Type[flash.core.data.io.input_transform.InputTransform], Callable, Tuple[Union[StrEnum, str], Dict[str, Any]], Union[StrEnum, str], None)) – The InputTransform type to use.

  • transform_kwargs (Optional[Dict]) – Dict of keyword arguments to be provided when instantiating the transforms.

  • data_module_kwargs (Any) – Additional keyword arguments to provide to the DataModule constructor.

Return type

SpeechRecognitionData

Returns

The constructed SpeechRecognitionData.

Examples

The files can be in Comma Separated Values (CSV) format with either a .csv or .txt extension.

The file train_data.csv contains the following:

speech_files,targets
speech_1.wav,some speech
speech_2.wav,some other speech
speech_3.wav,some more speech

The file predict_data.csv contains the following:

speech_files
predict_speech_1.wav
predict_speech_2.wav
predict_speech_3.wav
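For a quick local test, CSV files like those above can be generated with a short script (a sketch using Python's standard csv module; the file and column names match the example above):

```python
import csv

# Write a training CSV with audio file paths and their transcripts.
with open("train_data.csv", "w", newline="") as f:
    writer = csv.writer(f)
    writer.writerow(["speech_files", "targets"])
    writer.writerow(["speech_1.wav", "some speech"])
    writer.writerow(["speech_2.wav", "some other speech"])
    writer.writerow(["speech_3.wav", "some more speech"])

# Write a prediction CSV with audio file paths only.
with open("predict_data.csv", "w", newline="") as f:
    writer = csv.writer(f)
    writer.writerow(["speech_files"])
    for name in ["predict_speech_1.wav", "predict_speech_2.wav", "predict_speech_3.wav"]:
        writer.writerow([name])
```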
>>> from flash import Trainer
>>> from flash.audio import SpeechRecognitionData, SpeechRecognition
>>> datamodule = SpeechRecognitionData.from_csv(
...     "speech_files",
...     "targets",
...     train_file="train_data.csv",
...     predict_file="predict_data.csv",
...     batch_size=2,
... )
>>> model = SpeechRecognition(backbone="patrickvonplaten/wav2vec2_tiny_random_robust")
>>> trainer = Trainer(fast_dev_run=True)
>>> trainer.fit(model, datamodule=datamodule)  
Training...
>>> trainer.predict(model, datamodule=datamodule)  
Predicting...

Alternatively, the files can be in Tab Separated Values (TSV) format with a .tsv extension.

The file train_data.tsv contains the following:

speech_files    targets
speech_1.wav    some speech
speech_2.wav    some other speech
speech_3.wav    some more speech

The file predict_data.tsv contains the following:

speech_files
predict_speech_1.wav
predict_speech_2.wav
predict_speech_3.wav
>>> from flash import Trainer
>>> from flash.audio import SpeechRecognitionData, SpeechRecognition
>>> datamodule = SpeechRecognitionData.from_csv(
...     "speech_files",
...     "targets",
...     train_file="train_data.tsv",
...     predict_file="predict_data.tsv",
...     batch_size=2,
... )
>>> model = SpeechRecognition(backbone="patrickvonplaten/wav2vec2_tiny_random_robust")
>>> trainer = Trainer(fast_dev_run=True)
>>> trainer.fit(model, datamodule=datamodule)  
Training...
>>> trainer.predict(model, datamodule=datamodule)  
Predicting...
classmethod from_datasets(train_dataset=None, val_dataset=None, test_dataset=None, predict_dataset=None, sampling_rate=16000, input_cls=<class 'flash.audio.speech_recognition.input.SpeechRecognitionDatasetInput'>, transform=<class 'flash.core.data.io.input_transform.InputTransform'>, transform_kwargs=None, **data_module_kwargs)[source]

Load the SpeechRecognitionData from PyTorch Dataset objects.

The Dataset objects should be one of the following:

  • A PyTorch Dataset where the __getitem__ returns a tuple: (file_path, target)

  • A PyTorch Dataset where the __getitem__ returns a dict: {"input": file_path, "target": target}

The supported file extensions are: .aiff, .au, .avr, .caf, .flac, .mat, .mat4, .mat5, .mpc2k, .ogg, .paf, .pvf, .rf64, .ircam, .voc, .w64, .wav, .nist, and .wavex. To learn how to customize the transforms applied for each stage, read our customizing transforms guide.

Parameters
  • train_dataset (Optional[Dataset]) – The Dataset to use when training.

  • val_dataset (Optional[Dataset]) – The Dataset to use when validating.

  • test_dataset (Optional[Dataset]) – The Dataset to use when testing.

  • predict_dataset (Optional[Dataset]) – The Dataset to use when predicting.

  • sampling_rate (int) – Sampling rate to use when loading the audio files.

  • input_cls (Type[Input]) – The Input type to use for loading the data.

  • transform (TypeVar(INPUT_TRANSFORM_TYPE, Type[flash.core.data.io.input_transform.InputTransform], Callable, Tuple[Union[StrEnum, str], Dict[str, Any]], Union[StrEnum, str], None)) – The InputTransform type to use.

  • transform_kwargs (Optional[Dict]) – Dict of keyword arguments to be provided when instantiating the transforms.

  • data_module_kwargs (Any) – Additional keyword arguments to provide to the DataModule constructor.

Return type
Return type

SpeechRecognitionData

Returns

The constructed SpeechRecognitionData.

Examples

A PyTorch Dataset where the __getitem__ returns a tuple: (file_path, target):

>>> from torch.utils.data import Dataset
>>> from flash import Trainer
>>> from flash.audio import SpeechRecognitionData, SpeechRecognition
>>>
>>> class CustomDataset(Dataset):
...     def __init__(self, files, targets=None):
...         self.files = files
...         self.targets = targets
...     def __getitem__(self, index):
...         if self.targets is not None:
...             return self.files[index], self.targets[index]
...         return self.files[index]
...     def __len__(self):
...         return len(self.files)
...
>>>
>>> datamodule = SpeechRecognitionData.from_datasets(
...     train_dataset=CustomDataset(
...         ["speech_1.wav", "speech_2.wav", "speech_3.wav"],
...         ["some speech", "some other speech", "some more speech"],
...     ),
...     predict_dataset=CustomDataset(
...         ["predict_speech_1.wav", "predict_speech_2.wav", "predict_speech_3.wav"],
...     ),
...     batch_size=2,
... )
>>> model = SpeechRecognition(backbone="patrickvonplaten/wav2vec2_tiny_random_robust")
>>> trainer = Trainer(fast_dev_run=True)
>>> trainer.fit(model, datamodule=datamodule)  
Training...
>>> trainer.predict(model, datamodule=datamodule)  
Predicting...

A PyTorch Dataset where the __getitem__ returns a dict: {"input": file_path, "target": target}:

>>> from torch.utils.data import Dataset
>>> from flash import Trainer
>>> from flash.audio import SpeechRecognitionData, SpeechRecognition
>>>
>>> class CustomDataset(Dataset):
...     def __init__(self, files, targets=None):
...         self.files = files
...         self.targets = targets
...     def __getitem__(self, index):
...         if self.targets is not None:
...             return {"input": self.files[index], "target": self.targets[index]}
...         return {"input": self.files[index]}
...     def __len__(self):
...         return len(self.files)
...
>>>
>>> datamodule = SpeechRecognitionData.from_datasets(
...     train_dataset=CustomDataset(
...         ["speech_1.wav", "speech_2.wav", "speech_3.wav"],
...         ["some speech", "some other speech", "some more speech"],
...     ),
...     predict_dataset=CustomDataset(
...         ["predict_speech_1.wav", "predict_speech_2.wav", "predict_speech_3.wav"],
...     ),
...     batch_size=2,
... )
>>> model = SpeechRecognition(backbone="patrickvonplaten/wav2vec2_tiny_random_robust")
>>> trainer = Trainer(fast_dev_run=True)
>>> trainer.fit(model, datamodule=datamodule)  
Training...
>>> trainer.predict(model, datamodule=datamodule)  
Predicting...
classmethod from_files(train_files=None, train_targets=None, val_files=None, val_targets=None, test_files=None, test_targets=None, predict_files=None, sampling_rate=16000, input_cls=<class 'flash.audio.speech_recognition.input.SpeechRecognitionPathsInput'>, transform=<class 'flash.core.data.io.input_transform.InputTransform'>, transform_kwargs=None, **data_module_kwargs)[source]

Load the SpeechRecognitionData from lists of audio files and corresponding lists of targets.

The supported file extensions are: .aiff, .au, .avr, .caf, .flac, .mat, .mat4, .mat5, .mpc2k, .ogg, .paf, .pvf, .rf64, .ircam, .voc, .w64, .wav, .nist, and .wavex. To learn how to customize the transforms applied for each stage, read our customizing transforms guide.

Parameters
  • train_files (Optional[Sequence[str]]) – The list of audio files to use when training.

  • train_targets (Optional[Sequence[str]]) – The list of targets to use when training.

  • val_files (Optional[Sequence[str]]) – The list of audio files to use when validating.

  • val_targets (Optional[Sequence[str]]) – The list of targets to use when validating.

  • test_files (Optional[Sequence[str]]) – The list of audio files to use when testing.

  • test_targets (Optional[Sequence[str]]) – The list of targets to use when testing.

  • predict_files (Optional[Sequence[str]]) – The list of audio files to use when predicting.

  • sampling_rate (int) – Sampling rate to use when loading the audio files.

  • input_cls (Type[Input]) – The Input type to use for loading the data.

  • transform (TypeVar(INPUT_TRANSFORM_TYPE, Type[flash.core.data.io.input_transform.InputTransform], Callable, Tuple[Union[StrEnum, str], Dict[str, Any]], Union[StrEnum, str], None)) – The InputTransform type to use.

  • transform_kwargs (Optional[Dict]) – Dict of keyword arguments to be provided when instantiating the transforms.

  • data_module_kwargs (Any) – Additional keyword arguments to provide to the DataModule constructor.

Return type
Return type

SpeechRecognitionData

Returns

The constructed SpeechRecognitionData.

Examples

>>> from flash import Trainer
>>> from flash.audio import SpeechRecognitionData, SpeechRecognition
>>> datamodule = SpeechRecognitionData.from_files(
...     train_files=["speech_1.wav", "speech_2.wav", "speech_3.wav"],
...     train_targets=["some speech", "some other speech", "some more speech"],
...     predict_files=["predict_speech_1.wav", "predict_speech_2.wav", "predict_speech_3.wav"],
...     batch_size=2,
... )
>>> model = SpeechRecognition(backbone="patrickvonplaten/wav2vec2_tiny_random_robust")
>>> trainer = Trainer(fast_dev_run=True)
>>> trainer.fit(model, datamodule=datamodule)  
Training...
>>> trainer.predict(model, datamodule=datamodule)  
Predicting...
classmethod from_json(input_field, target_field=None, train_file=None, val_file=None, test_file=None, predict_file=None, sampling_rate=16000, field=None, input_cls=<class 'flash.audio.speech_recognition.input.SpeechRecognitionJSONInput'>, transform=<class 'flash.core.data.io.input_transform.InputTransform'>, transform_kwargs=None, **data_module_kwargs)[source]

Load the SpeechRecognitionData from JSON files containing audio file paths and their corresponding targets.

Input audio file paths will be extracted from the input_field field in the JSON files. The supported file extensions are: .aiff, .au, .avr, .caf, .flac, .mat, .mat4, .mat5, .mpc2k, .ogg, .paf, .pvf, .rf64, .ircam, .voc, .w64, .wav, .nist, and .wavex. The targets will be extracted from the target_field field in the JSON files. To learn how to customize the transforms applied for each stage, read our customizing transforms guide.

Parameters
  • input_field (str) – The field in the JSON files containing the audio file paths.

  • target_field (Optional[str]) – The field in the JSON files containing the targets.

  • train_file (Optional[str]) – The JSON file to use when training.

  • val_file (Optional[str]) – The JSON file to use when validating.

  • test_file (Optional[str]) – The JSON file to use when testing.

  • predict_file (Optional[str]) – The JSON file to use when predicting.

  • sampling_rate (int) – Sampling rate to use when loading the audio files.

  • field (Optional[str]) – The field that holds the data in the JSON file.

  • input_cls (Type[Input]) – The Input type to use for loading the data.

  • transform (TypeVar(INPUT_TRANSFORM_TYPE, Type[flash.core.data.io.input_transform.InputTransform], Callable, Tuple[Union[StrEnum, str], Dict[str, Any]], Union[StrEnum, str], None)) – The InputTransform type to use.

  • transform_kwargs (Optional[Dict]) – Dict of keyword arguments to be provided when instantiating the transforms.

  • data_module_kwargs (Any) – Additional keyword arguments to provide to the DataModule constructor.

Return type

SpeechRecognitionData

Returns

The constructed SpeechRecognitionData.

Examples

The file train_data.json contains the following:

{"speech_files":"speech_1.wav","targets":"some speech"}
{"speech_files":"speech_2.wav","targets":"some other speech"}
{"speech_files":"speech_3.wav","targets":"some more speech"}

The file predict_data.json contains the following:

{"speech_files":"predict_speech_1.wav"}
{"speech_files":"predict_speech_2.wav"}
{"speech_files":"predict_speech_3.wav"}
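Files in this line-delimited JSON format can be produced with the standard json module (a sketch; the field names match the example above):

```python
import json

train_rows = [
    {"speech_files": "speech_1.wav", "targets": "some speech"},
    {"speech_files": "speech_2.wav", "targets": "some other speech"},
    {"speech_files": "speech_3.wav", "targets": "some more speech"},
]

# Write one JSON object per line (JSON Lines), as shown above.
with open("train_data.json", "w") as f:
    for row in train_rows:
        f.write(json.dumps(row) + "\n")
```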
>>> from flash import Trainer
>>> from flash.audio import SpeechRecognitionData, SpeechRecognition
>>> datamodule = SpeechRecognitionData.from_json(
...     "speech_files",
...     "targets",
...     train_file="train_data.json",
...     predict_file="predict_data.json",
...     batch_size=2,
... )
>>> model = SpeechRecognition(backbone="patrickvonplaten/wav2vec2_tiny_random_robust")
>>> trainer = Trainer(fast_dev_run=True)
>>> trainer.fit(model, datamodule=datamodule)  
Training...
>>> trainer.predict(model, datamodule=datamodule)  
Predicting...
input_transform_cls

alias of flash.core.data.io.input_transform.InputTransform
