Speech Recognition¶
The Task¶
Speech recognition is the task of transcribing audio into text. We rely on Wav2Vec as our backbone, fine-tuned on labeled transcriptions for speech-to-text. Wav2Vec is pre-trained on thousands of hours of unlabeled audio, providing a strong baseline when fine-tuning on downstream tasks such as speech recognition.
Example¶
Let’s fine-tune the model on our own labeled audio transcription data.
Here’s the structure of our CSV file:
file,text
"/path/to/file_1.wav","what was said in file 1."
"/path/to/file_2.wav","what was said in file 2."
"/path/to/file_3.wav","what was said in file 3."
...
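As a minimal sketch of how such a CSV could be produced, the snippet below writes (audio path, transcript) pairs with Python's standard `csv` module. The helper name `write_manifest_csv` and the example rows are hypothetical and not part of Flash; only the `file` and `text` column names come from the layout above.

```python
import csv
import io


def write_manifest_csv(rows, stream):
    """Write (audio_path, transcript) pairs as CSV with `file` and `text` columns.

    Hypothetical helper for illustration; it is not part of Flash.
    """
    writer = csv.DictWriter(stream, fieldnames=["file", "text"], lineterminator="\n")
    writer.writeheader()
    for audio_path, transcript in rows:
        writer.writerow({"file": audio_path, "text": transcript})


# Placeholder rows; the second transcript contains a comma to show that
# the csv module quotes such fields automatically.
example_rows = [
    ("/path/to/file_1.wav", "what was said in file 1."),
    ("/path/to/file_2.wav", "well, what was said in file 2."),
]

# An in-memory buffer stands in for a real file opened with open(..., "w", newline="").
buffer = io.StringIO()
write_manifest_csv(example_rows, buffer)
manifest_text = buffer.getvalue()
print(manifest_text)
```

With its default settings the `csv` module only quotes fields that need it, so the output may look slightly different from the fully quoted sample above while still being equivalent CSV.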
Alternatively, here is the structure of our JSON file, with one object per line:
{"file": "/path/to/file_1.wav", "text": "what was said in file 1."}
{"file": "/path/to/file_2.wav", "text": "what was said in file 2."}
{"file": "/path/to/file_3.wav", "text": "what was said in file 3."}
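Before training, it can help to check that every record carries both fields. The following is a rough sketch that assumes the one-object-per-line layout shown above; `parse_manifest_lines` is a hypothetical helper, not a Flash function.

```python
import json


def parse_manifest_lines(lines):
    """Parse one JSON object per line and require `file` and `text` keys.

    Hypothetical validation helper for illustration; it is not part of Flash.
    """
    records = []
    for line_number, line in enumerate(lines, start=1):
        line = line.strip()
        if not line:
            continue  # tolerate blank lines
        record = json.loads(line)
        missing = {"file", "text"} - set(record)
        if missing:
            raise ValueError(f"line {line_number}: missing keys {sorted(missing)}")
        records.append(record)
    return records


# Placeholder lines mirroring the sample above; in practice, iterate over an open file.
sample_lines = [
    '{"file": "/path/to/file_1.wav", "text": "what was said in file 1."}',
    '{"file": "/path/to/file_2.wav", "text": "what was said in file 2."}',
]
records = parse_manifest_lines(sample_lines)
print(len(records), records[0]["text"])
```

A record that lacks either key raises a `ValueError` naming the offending line, which is usually easier to act on than a failure deep inside data loading.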
Once we’ve downloaded the data using download_data(), we create the SpeechRecognitionData.
We select a pre-trained Wav2Vec backbone to use for our SpeechRecognition task and fine-tune it on a subset of the TIMIT corpus. The backbone can be any Wav2Vec model from HuggingFace transformers.
Next, we use the trained SpeechRecognition task for inference and save the model.
Here’s the full example:
import flash
import torch
from flash.audio import SpeechRecognition, SpeechRecognitionData
from flash.core.data.utils import download_data
# 1. Create the DataModule
download_data("https://pl-flash-data.s3.amazonaws.com/timit_data.zip", "./data")
datamodule = SpeechRecognitionData.from_json(
    "file",
    "text",
    train_file="data/timit/train.json",
    test_file="data/timit/test.json",
    batch_size=4,
)
# 2. Build the task
model = SpeechRecognition(backbone="facebook/wav2vec2-base-960h")
# 3. Create the trainer and finetune the model
trainer = flash.Trainer(max_epochs=1, gpus=torch.cuda.device_count())
trainer.finetune(model, datamodule=datamodule, strategy="freeze")
# 4. Predict on audio files!
datamodule = SpeechRecognitionData.from_files(predict_files=["data/timit/example.wav"], batch_size=4)
predictions = trainer.predict(model, datamodule=datamodule)
print(predictions)
# 5. Save the model!
trainer.save_checkpoint("speech_recognition_model.pt")
To learn how to view the available backbones / heads for this task, see Backbones and Heads.
Flash Zero¶
The speech recognition task can be used directly from the command line with zero code using Flash Zero. You can run the above example with:
flash speech_recognition
To view the configuration options, including those for running the speech recognition task with your own data, use:
flash speech_recognition --help
Serving¶
The SpeechRecognition task is servable.
This means you can call .serve to serve your Task.
Here’s an example:
from flash.audio import SpeechRecognition
model = SpeechRecognition.load_from_checkpoint(
    "https://flash-weights.s3.amazonaws.com/0.7.0/speech_recognition_model.pt"
)
model.serve()
You can now perform inference from your client like this:
import base64
from pathlib import Path
import flash
import requests
with (Path(flash.ASSETS_ROOT) / "example.wav").open("rb") as f:
    audio_str = base64.b64encode(f.read()).decode("UTF-8")
body = {"session": "UUID", "payload": {"inputs": {"data": audio_str}}}
resp = requests.post("http://127.0.0.1:8000/predict", json=body)
print(resp.json())
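The request body in the client above is plain JSON wrapping a base64-encoded audio string. The sketch below isolates that encoding step so it can be checked without a running server. The helper names `build_predict_body` and `decode_audio_from_body` are hypothetical, and the bytes are a placeholder rather than a real recording; the body layout is copied from the client example.

```python
import base64
import json


def build_predict_body(audio_bytes, session="UUID"):
    """Wrap raw audio bytes in the request body layout used by the client above.

    Hypothetical helper for illustration; it is not part of Flash.
    """
    audio_str = base64.b64encode(audio_bytes).decode("UTF-8")
    return {"session": session, "payload": {"inputs": {"data": audio_str}}}


def decode_audio_from_body(body):
    """Recover the raw audio bytes from a request body (inverse of the above)."""
    return base64.b64decode(body["payload"]["inputs"]["data"])


# Placeholder bytes standing in for the contents of a .wav file.
fake_audio = b"RIFF\x00\x00\x00\x00WAVE"
body = build_predict_body(fake_audio)

# The body contains only strings and dicts, so it serializes to JSON directly.
payload_json = json.dumps(body)
print(payload_json)
```

Because base64 is lossless, decoding the `data` field returns exactly the bytes that were read from disk, which makes a round trip a quick sanity check when debugging a client.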