SpeechRecognitionData¶
- class flash.audio.speech_recognition.data.SpeechRecognitionData(train_input=None, val_input=None, test_input=None, predict_input=None, data_fetcher=None, transform=<class 'flash.core.data.io.input_transform.InputTransform'>, transform_kwargs=None, val_split=None, batch_size=None, num_workers=0, sampler=None, pin_memory=True, persistent_workers=False)[source]¶
The SpeechRecognitionData class is a DataModule with a set of classmethods for loading data for speech recognition.

- classmethod from_csv(input_field, target_field=None, train_file=None, val_file=None, test_file=None, predict_file=None, sampling_rate=16000, input_cls=<class 'flash.audio.speech_recognition.input.SpeechRecognitionCSVInput'>, transform=<class 'flash.core.data.io.input_transform.InputTransform'>, transform_kwargs=None, **data_module_kwargs)[source]¶
Load the SpeechRecognitionData from CSV files containing audio file paths and their corresponding targets.

Input audio file paths will be extracted from the input_field column in the CSV files. The supported file extensions are: .aiff, .au, .avr, .caf, .flac, .mat, .mat4, .mat5, .mpc2k, .ogg, .paf, .pvf, .rf64, .ircam, .voc, .w64, .wav, .nist, and .wavex. The targets will be extracted from the target_field column in the CSV files.

To learn how to customize the transforms applied for each stage, read our customizing transforms guide.

- Parameters
input_field (str) – The field (column name) in the CSV files containing the audio file paths.

target_field (Optional[str]) – The field (column name) in the CSV files containing the targets.

train_file (Optional[str]) – The CSV file to use when training.

val_file (Optional[str]) – The CSV file to use when validating.

test_file (Optional[str]) – The CSV file to use when testing.

predict_file (Optional[str]) – The CSV file to use when predicting.

sampling_rate (int) – Sampling rate to use when loading the audio files.

input_cls (Type[Input]) – The Input type to use for loading the data.

transform (INPUT_TRANSFORM_TYPE) – The InputTransform type to use.

transform_kwargs (Optional[Dict]) – Dict of keyword arguments to be provided when instantiating the transforms.

data_module_kwargs (Any) – Additional keyword arguments to provide to the DataModule constructor.
- Return type
SpeechRecognitionData

- Returns
The constructed SpeechRecognitionData.
Examples
The files can be in Comma Separated Values (CSV) format with either a .csv or a .txt extension.

The file train_data.csv contains the following:

speech_files,targets
speech_1.wav,some speech
speech_2.wav,some other speech
speech_3.wav,some more speech

The file predict_data.csv contains the following:

speech_files
predict_speech_1.wav
predict_speech_2.wav
predict_speech_3.wav
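The doctest below expects these CSV files and the referenced audio files to exist. As a minimal setup sketch (not part of the Flash API, assuming the numpy and soundfile packages are installed), the dummy files could be created like this:

# Hypothetical setup: create short dummy audio clips and the two CSV files
# used in the example below.
import numpy as np
import soundfile as sf

for name in [
    "speech_1.wav", "speech_2.wav", "speech_3.wav",
    "predict_speech_1.wav", "predict_speech_2.wav", "predict_speech_3.wav",
]:
    # One second of noise at the default 16 kHz sampling rate.
    sf.write(name, np.random.randn(16000).astype("float32"), 16000)

with open("train_data.csv", "w") as f:
    f.write(
        "speech_files,targets\n"
        "speech_1.wav,some speech\n"
        "speech_2.wav,some other speech\n"
        "speech_3.wav,some more speech\n"
    )
with open("predict_data.csv", "w") as f:
    f.write("speech_files\npredict_speech_1.wav\npredict_speech_2.wav\npredict_speech_3.wav\n")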
>>> from flash import Trainer
>>> from flash.audio import SpeechRecognitionData, SpeechRecognition
>>> datamodule = SpeechRecognitionData.from_csv(
...     "speech_files",
...     "targets",
...     train_file="train_data.csv",
...     predict_file="predict_data.csv",
...     batch_size=2,
... )
>>> model = SpeechRecognition(backbone="patrickvonplaten/wav2vec2_tiny_random_robust")
>>> trainer = Trainer(fast_dev_run=True)
>>> trainer.fit(model, datamodule=datamodule)
Training...
>>> trainer.predict(model, datamodule=datamodule)
Predicting...
Alternatively, the files can be in Tab Separated Values (TSV) format with a .tsv extension.

The file train_data.tsv contains the following:

speech_files    targets
speech_1.wav    some speech
speech_2.wav    some other speech
speech_3.wav    some more speech

The file predict_data.tsv contains the following:

speech_files
predict_speech_1.wav
predict_speech_2.wav
predict_speech_3.wav
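A similar hypothetical setup sketch for the TSV variant, reusing the dummy audio files created above (the only difference is the tab delimiter):

# Hypothetical setup: write the TSV files used in the example below.
rows = [
    ("speech_files", "targets"),
    ("speech_1.wav", "some speech"),
    ("speech_2.wav", "some other speech"),
    ("speech_3.wav", "some more speech"),
]
with open("train_data.tsv", "w") as f:
    f.writelines("\t".join(row) + "\n" for row in rows)

with open("predict_data.tsv", "w") as f:
    f.writelines(line + "\n" for line in [
        "speech_files", "predict_speech_1.wav", "predict_speech_2.wav", "predict_speech_3.wav",
    ])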
>>> from flash import Trainer
>>> from flash.audio import SpeechRecognitionData, SpeechRecognition
>>> datamodule = SpeechRecognitionData.from_csv(
...     "speech_files",
...     "targets",
...     train_file="train_data.tsv",
...     predict_file="predict_data.tsv",
...     batch_size=2,
... )
>>> model = SpeechRecognition(backbone="patrickvonplaten/wav2vec2_tiny_random_robust")
>>> trainer = Trainer(fast_dev_run=True)
>>> trainer.fit(model, datamodule=datamodule)
Training...
>>> trainer.predict(model, datamodule=datamodule)
Predicting...
- classmethod from_datasets(train_dataset=None, val_dataset=None, test_dataset=None, predict_dataset=None, sampling_rate=16000, input_cls=<class 'flash.audio.speech_recognition.input.SpeechRecognitionDatasetInput'>, transform=<class 'flash.core.data.io.input_transform.InputTransform'>, transform_kwargs=None, **data_module_kwargs)[source]¶
Load the SpeechRecognitionData from PyTorch Dataset objects.

The Dataset objects should be one of the following:

- A PyTorch Dataset where the __getitem__ returns a tuple: (file_path, target)
- A PyTorch Dataset where the __getitem__ returns a dict: {"input": file_path, "target": target}
The supported file extensions are: .aiff, .au, .avr, .caf, .flac, .mat, .mat4, .mat5, .mpc2k, .ogg, .paf, .pvf, .rf64, .ircam, .voc, .w64, .wav, .nist, and .wavex.

To learn how to customize the transforms applied for each stage, read our customizing transforms guide.

- Parameters
train_dataset (Optional[Dataset]) – The Dataset to use when training.

val_dataset (Optional[Dataset]) – The Dataset to use when validating.

test_dataset (Optional[Dataset]) – The Dataset to use when testing.

predict_dataset (Optional[Dataset]) – The Dataset to use when predicting.

sampling_rate (int) – Sampling rate to use when loading the audio files.

input_cls (Type[Input]) – The Input type to use for loading the data.

transform (INPUT_TRANSFORM_TYPE) – The InputTransform type to use.

transform_kwargs (Optional[Dict]) – Dict of keyword arguments to be provided when instantiating the transforms.

data_module_kwargs (Any) – Additional keyword arguments to provide to the DataModule constructor.
- Return type
SpeechRecognitionData

- Returns
The constructed SpeechRecognitionData.
Examples
A PyTorch Dataset where the __getitem__ returns a tuple: (file_path, target):

>>> from torch.utils.data import Dataset
>>> from flash import Trainer
>>> from flash.audio import SpeechRecognitionData, SpeechRecognition
>>>
>>> class CustomDataset(Dataset):
...     def __init__(self, files, targets=None):
...         self.files = files
...         self.targets = targets
...     def __getitem__(self, index):
...         if self.targets is not None:
...             return self.files[index], self.targets[index]
...         return self.files[index]
...     def __len__(self):
...         return len(self.files)
...
>>>
>>> datamodule = SpeechRecognitionData.from_datasets(
...     train_dataset=CustomDataset(
...         ["speech_1.wav", "speech_2.wav", "speech_3.wav"],
...         ["some speech", "some other speech", "some more speech"],
...     ),
...     predict_dataset=CustomDataset(
...         ["predict_speech_1.wav", "predict_speech_2.wav", "predict_speech_3.wav"],
...     ),
...     batch_size=2,
... )
>>> model = SpeechRecognition(backbone="patrickvonplaten/wav2vec2_tiny_random_robust")
>>> trainer = Trainer(fast_dev_run=True)
>>> trainer.fit(model, datamodule=datamodule)
Training...
>>> trainer.predict(model, datamodule=datamodule)
Predicting...
A PyTorch Dataset where the __getitem__ returns a dict: {"input": file_path, "target": target}:

>>> from torch.utils.data import Dataset
>>> from flash import Trainer
>>> from flash.audio import SpeechRecognitionData, SpeechRecognition
>>>
>>> class CustomDataset(Dataset):
...     def __init__(self, files, targets=None):
...         self.files = files
...         self.targets = targets
...     def __getitem__(self, index):
...         if self.targets is not None:
...             return {"input": self.files[index], "target": self.targets[index]}
...         return {"input": self.files[index]}
...     def __len__(self):
...         return len(self.files)
...
>>>
>>> datamodule = SpeechRecognitionData.from_datasets(
...     train_dataset=CustomDataset(
...         ["speech_1.wav", "speech_2.wav", "speech_3.wav"],
...         ["some speech", "some other speech", "some more speech"],
...     ),
...     predict_dataset=CustomDataset(
...         ["predict_speech_1.wav", "predict_speech_2.wav", "predict_speech_3.wav"],
...     ),
...     batch_size=2,
... )
>>> model = SpeechRecognition(backbone="patrickvonplaten/wav2vec2_tiny_random_robust")
>>> trainer = Trainer(fast_dev_run=True)
>>> trainer.fit(model, datamodule=datamodule)
Training...
>>> trainer.predict(model, datamodule=datamodule)
Predicting...
- classmethod from_files(train_files=None, train_targets=None, val_files=None, val_targets=None, test_files=None, test_targets=None, predict_files=None, sampling_rate=16000, input_cls=<class 'flash.audio.speech_recognition.input.SpeechRecognitionPathsInput'>, transform=<class 'flash.core.data.io.input_transform.InputTransform'>, transform_kwargs=None, **data_module_kwargs)[source]¶
Load the SpeechRecognitionData from lists of audio files and corresponding lists of targets.

The supported file extensions are: .aiff, .au, .avr, .caf, .flac, .mat, .mat4, .mat5, .mpc2k, .ogg, .paf, .pvf, .rf64, .ircam, .voc, .w64, .wav, .nist, and .wavex.

To learn how to customize the transforms applied for each stage, read our customizing transforms guide.

- Parameters
train_files (Optional[Sequence[str]]) – The list of audio files to use when training.

train_targets (Optional[Sequence[str]]) – The list of targets (ground truth speech transcripts) to use when training.

val_files (Optional[Sequence[str]]) – The list of audio files to use when validating.

val_targets (Optional[Sequence[str]]) – The list of targets (ground truth speech transcripts) to use when validating.

test_files (Optional[Sequence[str]]) – The list of audio files to use when testing.

test_targets (Optional[Sequence[str]]) – The list of targets (ground truth speech transcripts) to use when testing.

predict_files (Optional[Sequence[str]]) – The list of audio files to use when predicting.

sampling_rate (int) – Sampling rate to use when loading the audio files.

input_cls (Type[Input]) – The Input type to use for loading the data.

transform (INPUT_TRANSFORM_TYPE) – The InputTransform type to use.

transform_kwargs (Optional[Dict]) – Dict of keyword arguments to be provided when instantiating the transforms.

data_module_kwargs (Any) – Additional keyword arguments to provide to the DataModule constructor.
- Return type
SpeechRecognitionData

- Returns
The constructed SpeechRecognitionData.
Examples
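The example below assumes the listed .wav files exist on disk; a minimal hypothetical setup sketch (assuming the numpy and soundfile packages, as in the from_csv example) could create them:

# Hypothetical setup: create the dummy .wav files referenced below.
import numpy as np
import soundfile as sf

for name in [
    "speech_1.wav", "speech_2.wav", "speech_3.wav",
    "predict_speech_1.wav", "predict_speech_2.wav", "predict_speech_3.wav",
]:
    # One second of noise at the default 16 kHz sampling rate.
    sf.write(name, np.random.randn(16000).astype("float32"), 16000)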
>>> from flash import Trainer
>>> from flash.audio import SpeechRecognitionData, SpeechRecognition
>>> datamodule = SpeechRecognitionData.from_files(
...     train_files=["speech_1.wav", "speech_2.wav", "speech_3.wav"],
...     train_targets=["some speech", "some other speech", "some more speech"],
...     predict_files=["predict_speech_1.wav", "predict_speech_2.wav", "predict_speech_3.wav"],
...     batch_size=2,
... )
>>> model = SpeechRecognition(backbone="patrickvonplaten/wav2vec2_tiny_random_robust")
>>> trainer = Trainer(fast_dev_run=True)
>>> trainer.fit(model, datamodule=datamodule)
Training...
>>> trainer.predict(model, datamodule=datamodule)
Predicting...
- classmethod from_json(input_field, target_field=None, train_file=None, val_file=None, test_file=None, predict_file=None, sampling_rate=16000, field=None, input_cls=<class 'flash.audio.speech_recognition.input.SpeechRecognitionJSONInput'>, transform=<class 'flash.core.data.io.input_transform.InputTransform'>, transform_kwargs=None, **data_module_kwargs)[source]¶
Load the SpeechRecognitionData from JSON files containing audio file paths and their corresponding targets.

Input audio file paths will be extracted from the input_field field in the JSON files. The supported file extensions are: .aiff, .au, .avr, .caf, .flac, .mat, .mat4, .mat5, .mpc2k, .ogg, .paf, .pvf, .rf64, .ircam, .voc, .w64, .wav, .nist, and .wavex. The targets will be extracted from the target_field field in the JSON files.

To learn how to customize the transforms applied for each stage, read our customizing transforms guide.

- Parameters
input_field (str) – The field in the JSON files containing the audio file paths.

target_field (Optional[str]) – The field in the JSON files containing the targets.

train_file (Optional[str]) – The JSON file to use when training.

val_file (Optional[str]) – The JSON file to use when validating.

test_file (Optional[str]) – The JSON file to use when testing.

predict_file (Optional[str]) – The JSON file to use when predicting.

sampling_rate (int) – Sampling rate to use when loading the audio files.

field (Optional[str]) – The field that holds the data in the JSON file.

input_cls (Type[Input]) – The Input type to use for loading the data.

transform (INPUT_TRANSFORM_TYPE) – The InputTransform type to use.

transform_kwargs (Optional[Dict]) – Dict of keyword arguments to be provided when instantiating the transforms.

data_module_kwargs (Any) – Additional keyword arguments to provide to the DataModule constructor.
- Return type
SpeechRecognitionData

- Returns
The constructed SpeechRecognitionData.
Examples
The file train_data.json contains the following:

{"speech_files":"speech_1.wav","targets":"some speech"}
{"speech_files":"speech_2.wav","targets":"some other speech"}
{"speech_files":"speech_3.wav","targets":"some more speech"}

The file predict_data.json contains the following:

{"speech_files":"predict_speech_1.wav"}
{"speech_files":"predict_speech_2.wav"}
{"speech_files":"predict_speech_3.wav"}
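These files are in JSON Lines format (one JSON object per line). A hypothetical setup sketch using the standard json module, reusing the dummy audio files from the earlier examples:

# Hypothetical setup: write the JSON Lines files used in the example below.
import json

train_rows = [
    {"speech_files": "speech_1.wav", "targets": "some speech"},
    {"speech_files": "speech_2.wav", "targets": "some other speech"},
    {"speech_files": "speech_3.wav", "targets": "some more speech"},
]
with open("train_data.json", "w") as f:
    f.writelines(json.dumps(row) + "\n" for row in train_rows)

with open("predict_data.json", "w") as f:
    f.writelines(json.dumps({"speech_files": f"predict_speech_{i}.wav"}) + "\n" for i in (1, 2, 3))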
>>> from flash import Trainer
>>> from flash.audio import SpeechRecognitionData, SpeechRecognition
>>> datamodule = SpeechRecognitionData.from_json(
...     "speech_files",
...     "targets",
...     train_file="train_data.json",
...     predict_file="predict_data.json",
...     batch_size=2,
... )
Downloading...
>>> model = SpeechRecognition(backbone="patrickvonplaten/wav2vec2_tiny_random_robust")
>>> trainer = Trainer(fast_dev_run=True)
>>> trainer.fit(model, datamodule=datamodule)
Training...
>>> trainer.predict(model, datamodule=datamodule)
Predicting...
- input_transform_cls¶
alias of InputTransform