Shortcuts

Source code for flash.audio.speech_recognition.input

# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import base64
import io
import os.path
from pathlib import Path
from typing import Any, Dict, List, Mapping, Optional, Sequence, Union

from torch.utils.data import Dataset

import flash
from flash.core.data.io.input import DataKeys, Input, ServeInput
from flash.core.data.utilities.paths import filter_valid_files, list_valid_files
from flash.core.data.utilities.samples import to_sample, to_samples
from flash.core.utilities.imports import _AUDIO_AVAILABLE, requires

if _AUDIO_AVAILABLE:
    import librosa
    from datasets import load_dataset


class SpeechRecognitionDeserializer(ServeInput):
    """Serve-time input that decodes a base64-encoded audio payload into a sample.

    Args:
        sampling_rate: Target sampling rate handed to ``librosa.load`` for every request.
        **kwargs: Forwarded unchanged to ``ServeInput.__init__``.
    """

    @requires("audio")
    def __init__(self, sampling_rate: int = 16000, **kwargs):
        super().__init__(**kwargs)
        self.sampling_rate = sampling_rate

    def serve_load_sample(self, sample: Any) -> Dict:
        """Decode one base64 string into an audio array plus its sampling-rate metadata.

        Args:
            sample: Base64 text of an encoded audio file.

        Returns:
            A dict with the decoded waveform under ``DataKeys.INPUT`` and
            ``{"sampling_rate": <rate returned by librosa>}`` under ``DataKeys.METADATA``.
        """
        # Appending "===" lets payloads whose trailing padding was stripped still decode;
        # this relies on ``base64.b64decode``'s default non-strict mode tolerating surplus padding.
        raw_bytes = base64.b64decode((sample + "===").encode("ascii"))
        waveform, rate = librosa.load(io.BytesIO(raw_bytes), sr=self.sampling_rate)
        metadata = {"sampling_rate": rate}
        return {DataKeys.INPUT: waveform, DataKeys.METADATA: metadata}

    @property
    def example_input(self) -> str:
        """Base64 text of the bundled ``example.wav`` asset, usable as a sample request body."""
        asset_path = Path(flash.ASSETS_ROOT) / "example.wav"
        with asset_path.open("rb") as handle:
            payload = handle.read()
        return base64.b64encode(payload).decode("UTF-8")


class BaseSpeechRecognition(Input):
    """Shared sample loading for speech recognition inputs whose ``INPUT`` is an audio file path."""

    @staticmethod
    def load_sample(sample: Dict[str, Any], sampling_rate: int = 16000) -> Any:
        """Replace the audio path in ``sample`` with the waveform decoded by ``librosa``.

        A relative path is joined onto ``sample[DataKeys.METADATA]["root"]`` when that key exists.
        The sample is modified in place. Its metadata is replaced (not merged) by
        ``{"sampling_rate": <rate returned by librosa>}``.

        Args:
            sample: Sample dict holding an audio file path under ``DataKeys.INPUT``.
            sampling_rate: Target sampling rate passed to ``librosa.load``.

        Returns:
            The same ``sample`` dict, now holding the waveform.
        """
        audio_path = sample[DataKeys.INPUT]
        # Absolute paths are used as-is; the metadata is only consulted for relative ones.
        if not os.path.isabs(audio_path) and DataKeys.METADATA in sample:
            metadata = sample[DataKeys.METADATA]
            if "root" in metadata:
                audio_path = os.path.join(metadata["root"], audio_path)
        waveform, rate = librosa.load(audio_path, sr=sampling_rate)
        sample[DataKeys.INPUT] = waveform
        sample[DataKeys.METADATA] = {"sampling_rate": rate}
        return sample


class SpeechRecognitionFileInput(BaseSpeechRecognition):
    """Load speech recognition samples from a tabular file through ``datasets.load_dataset``.

    Each row supplies an audio file path in ``input_key`` and, optionally, a target in
    ``target_key``. Every sample records the directory of ``file`` as ``metadata["root"]``, so
    relative audio paths are resolved against it when the sample is loaded.
    """

    sampling_rate: int

    @requires("audio")
    def load_data(
        self,
        file: str,
        input_key: str,
        target_key: Optional[str] = None,
        field: Optional[str] = None,
        sampling_rate: int = 16000,
        filetype: Optional[str] = None,
    ) -> Sequence[Mapping[str, Any]]:
        """Read ``file`` and build one sample dict per row.

        Args:
            file: Path to the data file.
            input_key: Column / key holding the audio file paths.
            target_key: Optional column / key holding the targets.
            field: Only used when ``filetype == "json"``; forwarded to the JSON loader.
            sampling_rate: Sampling rate later used by ``load_sample``.
            filetype: Loader name passed to ``load_dataset`` (the subclasses use ``"csv"`` / ``"json"``).

        Returns:
            A list of sample dicts with ``INPUT``, optional ``TARGET``, and ``METADATA``.
        """
        self.sampling_rate = sampling_rate
        stage = self.running_stage.value
        # ``field`` is only forwarded to the JSON loader; other loaders are called without it.
        if filetype == "json" and field is not None:
            dataset_dict = load_dataset(filetype, data_files={stage: str(file)}, field=field)
        else:
            dataset_dict = load_dataset(filetype, data_files={stage: str(file)})
        dataset = dataset_dict[stage]
        root = os.path.dirname(file)

        # Each sample gets its own metadata dict: a single shared dict would let an in-place
        # edit of one sample's metadata silently change every other sample.
        if target_key is not None:
            return [
                {
                    DataKeys.INPUT: input_file,
                    DataKeys.TARGET: target,
                    DataKeys.METADATA: {"root": root},
                }
                for input_file, target in zip(dataset[input_key], dataset[target_key])
            ]
        return [
            {
                DataKeys.INPUT: input_file,
                DataKeys.METADATA: {"root": root},
            }
            for input_file in dataset[input_key]
        ]

    def load_sample(self, sample: Dict[str, Any]) -> Any:
        """Decode the sample's audio at the sampling rate stored by ``load_data``."""
        return super().load_sample(sample, self.sampling_rate)


class SpeechRecognitionCSVInput(SpeechRecognitionFileInput):
    """Speech recognition input read from a CSV file."""

    @requires("audio")
    def load_data(
        self,
        file: str,
        input_key: str,
        target_key: Optional[str] = None,
        sampling_rate: int = 16000,
    ):
        """Delegate to the generic file loader with the ``"csv"`` loader selected.

        Args:
            file: Path to the CSV file.
            input_key: Column holding the audio file paths.
            target_key: Optional column holding the targets.
            sampling_rate: Sampling rate later used when decoding audio.
        """
        return super().load_data(
            file,
            input_key,
            target_key=target_key,
            sampling_rate=sampling_rate,
            filetype="csv",
        )


class SpeechRecognitionJSONInput(SpeechRecognitionFileInput):
    """Speech recognition input read from a JSON file."""

    @requires("audio")
    def load_data(
        self,
        file: str,
        input_key: str,
        target_key: Optional[str] = None,
        field: Optional[str] = None,
        sampling_rate: int = 16000,
    ):
        """Delegate to the generic file loader with the ``"json"`` loader selected.

        Args:
            file: Path to the JSON file.
            input_key: Key holding the audio file paths.
            target_key: Optional key holding the targets.
            field: Optional field forwarded to the JSON loader.
            sampling_rate: Sampling rate later used when decoding audio.
        """
        return super().load_data(
            file,
            input_key,
            target_key=target_key,
            field=field,
            sampling_rate=sampling_rate,
            filetype="json",
        )


class SpeechRecognitionDatasetInput(BaseSpeechRecognition):
    """Speech recognition input backed by an existing ``torch.utils.data.Dataset``."""

    sampling_rate: int

    @requires("audio")
    def load_data(self, dataset: Dataset, sampling_rate: int = 16000) -> Sequence[Mapping[str, Any]]:
        """Remember ``sampling_rate`` and hand ``dataset`` to the base ``Input.load_data``."""
        self.sampling_rate = sampling_rate
        return super().load_data(dataset)

    def load_sample(self, sample: Any) -> Any:
        """Normalise ``sample`` with ``to_sample``; decode its audio only when the input is a path."""
        normalized = to_sample(sample)
        if not isinstance(normalized[DataKeys.INPUT], (str, Path)):
            # Non-path inputs (presumably already-decoded audio) pass through unchanged.
            return normalized
        return super().load_sample(normalized, self.sampling_rate)


class SpeechRecognitionPathsInput(BaseSpeechRecognition):
    """Speech recognition input built from audio file paths and optional targets."""

    sampling_rate: int

    @requires("audio")
    def load_data(
        self,
        paths: Union[str, List[str]],
        targets: Optional[List[str]] = None,
        sampling_rate: int = 16000,
    ) -> Sequence:
        """Collect audio files with an accepted extension and wrap them as samples.

        Args:
            paths: A path or list of paths to audio files.
            targets: Optional targets aligned with ``paths``; filtered alongside them.
            sampling_rate: Sampling rate later used when decoding audio.
        """
        self.sampling_rate = sampling_rate
        # NOTE(review): the extensions carry no leading dot. If the helpers match with a plain
        # ``endswith``, names such as "format" would also pass the "mat" filter — confirm.
        audio_extensions = ("wav", "ogg", "flac", "mat", "mp3")
        if targets is not None:
            return to_samples(*filter_valid_files(paths, targets, valid_extensions=audio_extensions))
        return to_samples(list_valid_files(paths, audio_extensions))

    def load_sample(self, sample: Dict[str, Any]) -> Any:
        """Decode the sample's audio at the sampling rate stored by ``load_data``."""
        return super().load_sample(sample, self.sampling_rate)

© Copyright 2020-2021, PyTorch Lightning. Revision 8db29e8e.

Built with Sphinx using a theme provided by Read the Docs.
Read the Docs v: 0.7.3
Versions
latest
stable
0.7.3
0.7.2
0.7.1
0.7.0
0.6.0
0.5.2
0.5.1
0.5.0
0.4.0
0.3.2
0.3.1
0.3.0
0.2.3
0.2.2
0.2.1
0.2.0
0.1.0post1
docs-fix_typing
Downloads
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.