
Speech Recognition

The Task

Speech recognition is the task of transcribing audio into text. We rely on Wav2Vec as our backbone, fine-tuned on labeled transcriptions for speech-to-text. Wav2Vec is pre-trained on thousands of hours of unlabeled audio, which provides a strong baseline when fine-tuning on downstream tasks such as speech recognition.
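
The transcription is produced frame by frame. Checkpoints such as facebook/wav2vec2-base-960h (used in the example below) are Wav2Vec2 models with a CTC (connectionist temporal classification) head: the model predicts one character label per audio frame, and those labels become text by collapsing repeated labels and dropping a special blank label. The sketch below shows greedy CTC decoding in plain Python to illustrate the idea. It is not how Flash decodes internally, the toy vocabulary is hypothetical, and we assume, as in the Hugging Face Wav2Vec2 CTC checkpoints, that id 0 acts as the blank and "|" marks word boundaries.

```python
from itertools import groupby
from typing import Sequence


def ctc_greedy_decode(
    frame_ids: Sequence[int],
    vocab: Sequence[str],
    blank_id: int = 0,
    word_delimiter: str = "|",
) -> str:
    """Turn per-frame label ids (e.g. the argmax of CTC logits) into text."""
    # 1. Collapse runs of the same label emitted on consecutive frames.
    collapsed = [label for label, _ in groupby(frame_ids)]
    # 2. Drop the blank label; a blank between two identical labels is what
    #    lets CTC emit genuine double letters such as the "LL" in "HELLO".
    tokens = [vocab[i] for i in collapsed if i != blank_id]
    # 3. Turn word delimiters into spaces and trim the ends.
    return "".join(tokens).replace(word_delimiter, " ").strip()


# Hypothetical toy vocabulary: id 0 is the blank, "|" separates words.
vocab = ["<pad>", "|", "H", "E", "L", "O", "I"]
frames = [2, 2, 3, 0, 4, 4, 0, 4, 5, 1, 1, 2, 6, 6]
print(ctc_greedy_decode(frames, vocab))  # HELLO HI
```

Greedy decoding simply takes the most likely label per frame; beam search with a language model is a common, more accurate alternative.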


Example

Let’s fine-tune the model on our own labeled audio transcription data.

Here’s the structure of our CSV file:

file,text
"/path/to/file_1.wav","what was said in file 1."
"/path/to/file_2.wav","what was said in file 2."
"/path/to/file_3.wav","what was said in file 3."
...
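
If your transcripts live in Python, the standard csv module can produce this layout. Using it rather than joining strings by hand matters because transcripts often contain commas or quotes, which must be quoted or escaped. The sketch below uses hypothetical helper names and the placeholder paths from above. It quotes a field only when needed, which differs cosmetically from the fully quoted sample but parses to the same rows.

```python
import csv
import tempfile
from pathlib import Path
from typing import Dict, Iterable, List, Tuple


def write_transcription_csv(rows: Iterable[Tuple[str, str]], out_path: Path) -> None:
    """Write (audio path, transcript) pairs under a `file,text` header."""
    # newline="" lets the csv module control line endings itself.
    with open(out_path, "w", newline="", encoding="utf-8") as f:
        writer = csv.writer(f)  # quotes a field only when it needs it
        writer.writerow(["file", "text"])
        writer.writerows(rows)


def read_transcription_csv(in_path: Path) -> List[Dict[str, str]]:
    """Read the file back as one dict per row, keyed by the header."""
    with open(in_path, newline="", encoding="utf-8") as f:
        return list(csv.DictReader(f))


# Hypothetical paths and transcripts; the second transcript contains a comma.
rows = [
    ("/path/to/file_1.wav", "what was said in file 1."),
    ("/path/to/file_2.wav", "well, what was said in file 2."),
]
with tempfile.TemporaryDirectory() as tmp:
    csv_path = Path(tmp) / "train.csv"
    write_transcription_csv(rows, csv_path)
    raw_lines = csv_path.read_text(encoding="utf-8").splitlines()
    loaded = read_transcription_csv(csv_path)

print(raw_lines[2])  # /path/to/file_2.wav,"well, what was said in file 2."
```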

Alternatively, here is the structure of our JSON file, with one JSON object per line:

{"file": "/path/to/file_1.wav", "text": "what was said in file 1."}
{"file": "/path/to/file_2.wav", "text": "what was said in file 2."}
{"file": "/path/to/file_3.wav", "text": "what was said in file 3."}
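
The JSON layout above is one object per line rather than a single JSON array, so each record is serialized on its own line. A minimal sketch with the standard json module follows. The helper names are hypothetical, and the keys "file" and "text" match the field names the full example passes to from_json.

```python
import json
import tempfile
from pathlib import Path
from typing import Dict, Iterable, List, Tuple


def write_transcription_jsonl(rows: Iterable[Tuple[str, str]], out_path: Path) -> None:
    """Write one {"file": ..., "text": ...} object per line."""
    with open(out_path, "w", encoding="utf-8") as f:
        for audio_path, transcript in rows:
            f.write(json.dumps({"file": audio_path, "text": transcript}) + "\n")


def read_transcription_jsonl(in_path: Path) -> List[Dict[str, str]]:
    """Parse each non-empty line as its own JSON object."""
    with open(in_path, encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]


# Hypothetical paths and transcripts, mirroring the sample above.
rows = [
    ("/path/to/file_1.wav", "what was said in file 1."),
    ("/path/to/file_2.wav", "what was said in file 2."),
]
with tempfile.TemporaryDirectory() as tmp:
    json_path = Path(tmp) / "train.json"
    write_transcription_jsonl(rows, json_path)
    first_line = json_path.read_text(encoding="utf-8").splitlines()[0]
    records = read_transcription_jsonl(json_path)

print(first_line)  # {"file": "/path/to/file_1.wav", "text": "what was said in file 1."}
```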

Once we’ve downloaded the data using download_data(), we create the SpeechRecognitionData; the first two arguments to from_json name the fields that hold the audio file path and its transcription. We then select a pre-trained Wav2Vec backbone for our SpeechRecognition task and fine-tune it on a subset of the TIMIT corpus. The backbone can be any Wav2Vec model from Hugging Face transformers. Finally, we use the fine-tuned SpeechRecognition model for inference and save it as a checkpoint. Here’s the full example:

import flash
import torch
from flash.audio import SpeechRecognition, SpeechRecognitionData
from flash.core.data.utils import download_data

# 1. Create the DataModule
download_data("https://pl-flash-data.s3.amazonaws.com/timit_data.zip", "./data")

datamodule = SpeechRecognitionData.from_json(
    "file",
    "text",
    train_file="data/timit/train.json",
    test_file="data/timit/test.json",
    batch_size=4,
)

# 2. Build the task
model = SpeechRecognition(backbone="facebook/wav2vec2-base-960h")

# 3. Create the trainer and finetune the model
trainer = flash.Trainer(max_epochs=1, gpus=torch.cuda.device_count())
trainer.finetune(model, datamodule=datamodule, strategy="freeze")

# 4. Predict on audio files!
datamodule = SpeechRecognitionData.from_files(predict_files=["data/timit/example.wav"], batch_size=4)
predictions = trainer.predict(model, datamodule=datamodule)
print(predictions)

# 5. Save the model!
trainer.save_checkpoint("speech_recognition_model.pt")
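
The example prints its predictions but does not score them. Word error rate (WER), the word-level edit distance between a prediction and its reference transcription divided by the number of reference words, is the usual metric for speech recognition. The plain-Python sketch below is an illustration, not part of the Flash API; it compares whitespace-separated words exactly, so any normalization (case, punctuation) and the pairing of predictions with reference texts are left to you.

```python
from typing import List


def word_error_rate(reference: str, hypothesis: str) -> float:
    """Word-level edit distance divided by the number of reference words."""
    ref: List[str] = reference.split()
    hyp: List[str] = hypothesis.split()
    if not ref:
        raise ValueError("reference must contain at least one word")
    # At the start of iteration i, prev[j] is the edit distance between the
    # first i-1 reference words and the first j hypothesis words.
    prev = list(range(len(hyp) + 1))
    for i in range(1, len(ref) + 1):
        curr = [i] + [0] * len(hyp)
        for j in range(1, len(hyp) + 1):
            cost = 0 if ref[i - 1] == hyp[j - 1] else 1
            curr[j] = min(
                prev[j] + 1,         # deletion (reference word missing)
                curr[j - 1] + 1,     # insertion (extra hypothesis word)
                prev[j - 1] + cost,  # match or substitution
            )
        prev = curr
    return prev[-1] / len(ref)


print(word_error_rate("what was said in file 1.", "what was sad in file 1."))  # 0.16666666666666666
```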

To learn how to view the available backbones / heads for this task, see Backbones and Heads.
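
Wav2Vec 2.0 checkpoints such as facebook/wav2vec2-base-960h were pre-trained and fine-tuned on 16 kHz speech, so audio recorded at another rate (for example 8 kHz telephone audio) may transcribe poorly unless it is resampled. Before fine-tuning on your own files, a quick check with Python's standard wave module can flag mismatches. This is a sketch: the helper names are hypothetical, the 16 kHz target is an assumption tied to that checkpoint, and wave only reads uncompressed PCM WAV files, so other formats need a different reader.

```python
import tempfile
import wave
from pathlib import Path
from typing import Dict, Iterable

# Assumed target rate: the wav2vec2-base-960h model card asks for 16 kHz input.
EXPECTED_RATE = 16_000


def wav_sample_rate(path: Path) -> int:
    """Return the sample rate stored in a PCM WAV file's header."""
    with wave.open(str(path), "rb") as wav:
        return wav.getframerate()


def find_rate_mismatches(paths: Iterable[Path], expected: int = EXPECTED_RATE) -> Dict[str, int]:
    """Map each file whose rate differs from `expected` to its actual rate."""
    mismatches: Dict[str, int] = {}
    for path in paths:
        rate = wav_sample_rate(path)
        if rate != expected:
            mismatches[str(path)] = rate
    return mismatches


def write_silence(path: Path, rate: int, seconds: float = 0.1) -> None:
    """Write a short mono 16-bit silent WAV file (used only for this demo)."""
    with wave.open(str(path), "wb") as wav:
        wav.setnchannels(1)
        wav.setsampwidth(2)  # 2 bytes = 16-bit samples
        wav.setframerate(rate)
        wav.writeframes(b"\x00\x00" * int(rate * seconds))


with tempfile.TemporaryDirectory() as tmp:
    ok_file = Path(tmp) / "ok.wav"
    phone_file = Path(tmp) / "phone.wav"
    write_silence(ok_file, 16_000)
    write_silence(phone_file, 8_000)
    mismatches = find_rate_mismatches([ok_file, phone_file])

# Show only file names so the output does not depend on the temp directory.
mismatch_names = {Path(p).name: rate for p, rate in mismatches.items()}
print(mismatch_names)  # {'phone.wav': 8000}
```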


Flash Zero

The speech recognition task can be used directly from the command line with zero code using Flash Zero. You can run the above example with:

flash speech_recognition

To view the configuration options, including how to run the speech recognition task on your own data, use:

flash speech_recognition --help

Serving

The SpeechRecognition task is servable, which means you can call .serve() to serve your trained model. Here’s an example:

from flash.audio import SpeechRecognition

model = SpeechRecognition.load_from_checkpoint(
    "https://flash-weights.s3.amazonaws.com/0.7.0/speech_recognition_model.pt"
)
model.serve()

You can now perform inference from your client like this:

import base64
from pathlib import Path

import flash
import requests

with (Path(flash.ASSETS_ROOT) / "example.wav").open("rb") as f:
    audio_str = base64.b64encode(f.read()).decode("UTF-8")

body = {"session": "UUID", "payload": {"inputs": {"data": audio_str}}}
resp = requests.post("http://127.0.0.1:8000/predict", json=body)

print(resp.json())
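
Most of the client code is encoding: JSON cannot carry raw bytes, so the audio file's contents are base64-encoded into ASCII text before being placed in the request body. The sketch below factors that step into two helper functions (hypothetical names, not part of Flash) and checks the round trip offline, without starting a server. It assumes only the body layout shown above; how Flash's server decodes the payload internally is not covered here.

```python
import base64
import json
from typing import Any, Dict


def build_predict_body(audio_bytes: bytes, session: str = "UUID") -> Dict[str, Any]:
    """Wrap raw audio bytes in the request body layout used by the client above."""
    # Base64 turns arbitrary bytes into ASCII text that can live inside JSON.
    audio_str = base64.b64encode(audio_bytes).decode("UTF-8")
    return {"session": session, "payload": {"inputs": {"data": audio_str}}}


def decode_audio_from_body(body: Dict[str, Any]) -> bytes:
    """Reverse the encoding; a server must do this before reading the audio."""
    return base64.b64decode(body["payload"]["inputs"]["data"])


# Stand-in bytes; in practice you would read a real .wav file opened in "rb" mode.
fake_audio = b"RIFF\x00\x01\x02\xffnot really audio"
body = build_predict_body(fake_audio)
wire = json.dumps(body)  # roughly what requests.post(..., json=body) sends
round_trip = decode_audio_from_body(json.loads(wire))
print(round_trip == fake_audio)  # True
```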