Shortcuts

Source code for flash.core.data.utilities.loading

# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import glob
import re
from functools import partial
from os import PathLike
from typing import Union
from urllib.parse import parse_qs, quote, urlencode, urlparse

import fsspec
import numpy as np
import pandas as pd
import torch

from flash.core.data.utilities.paths import has_file_allowed_extension
from flash.core.utilities.imports import _TOPIC_AUDIO_AVAILABLE, _TORCHVISION_AVAILABLE, Image

if _TOPIC_AUDIO_AVAILABLE:
    from torchaudio.transforms import Spectrogram

# Take the image extension list from torchvision when `_TORCHVISION_AVAILABLE` is set;
# otherwise fall back to a hard-coded tuple so this module still imports.
if _TORCHVISION_AVAILABLE:
    from torchvision.datasets.folder import IMG_EXTENSIONS
else:
    IMG_EXTENSIONS = (".jpg", ".jpeg", ".png", ".ppm", ".bmp", ".pgm", ".tif", ".tiff", ".webp")

# NOTE: each extension group below is a tuple because the groups are used as
# dictionary keys in the loader tables further down in this module.

# Extensions handled by the `np.load`-based loaders.
NP_EXTENSIONS = (".npy",)

# Extensions routed to the audio loaders.
AUDIO_EXTENSIONS = (
    ".aiff",
    ".au",
    ".avr",
    ".caf",
    ".flac",
    ".mat",
    ".mat4",
    ".mat5",
    ".mpc2k",
    ".ogg",
    ".paf",
    ".pvf",
    ".rf64",
    ".ircam",
    ".voc",
    ".w64",
    ".wav",
    ".nist",
    ".wavex",
)

# Extensions read with `pd.read_csv` using its default separator (note that ".txt" is included).
CSV_EXTENSIONS = (".csv", ".txt")

# Extensions read with `pd.read_csv` using a tab separator.
TSV_EXTENSIONS = (".tsv",)


def _load_image_from_image(file):
    """Open an image file and return it converted to RGB.

    Args:
        file: The path or open file object handed to ``Image.open``.

    Returns:
        The loaded image after ``convert("RGB")``.
    """
    image = Image.open(file)
    # Force the pixel data to be read here, while ``file`` is still available to the caller's context.
    image.load()
    return image.convert("RGB")


def _load_image_from_numpy(file):
    """Load a saved NumPy array and return it as an RGB image.

    Args:
        file: The path or open file object handed to ``np.load``.

    Returns:
        An image built from the array (cast to ``uint8``) and converted to RGB.
    """
    pixels = np.load(file).astype("uint8")
    image = Image.fromarray(pixels)
    return image.convert("RGB")


def _load_spectrogram_from_image(file):
    """Load an image file and return its pixels as a ``float32`` array.

    Args:
        file: The path or open file object of the image to read.

    Returns:
        A NumPy array of the RGB-converted image, cast to ``float32``.
    """
    rgb_image = _load_image_from_image(file)
    pixels = np.array(rgb_image)
    return pixels.astype("float32")


def _load_spectrogram_from_numpy(file):
    """Load a saved NumPy array and return it as ``float32``.

    Args:
        file: The path or open file object handed to ``np.load``.

    Returns:
        The stored array cast to ``float32``.
    """
    stored = np.load(file)
    return stored.astype("float32")


def _load_spectrogram_from_audio(file, sampling_rate: int = 16000, n_fft: int = 400):
    """Load an audio file and turn its waveform into a spectrogram array.

    Args:
        file: The path or open file object handed to ``SoundFile``.
        sampling_rate: The ``sr`` value passed to ``librosa.load``.
        n_fft: The first argument passed to the ``Spectrogram`` transform.

    Returns:
        The normalized spectrogram as a NumPy array, with the leading dimension
        (added via ``unsqueeze(0)``) moved to the last position.
    """
    # Imported lazily so that missing system dependencies do not break importing this module.
    import librosa
    from soundfile import SoundFile

    samples, _ = librosa.load(SoundFile(file), sr=sampling_rate)

    to_spectrogram = Spectrogram(n_fft, normalized=True)
    batched_samples = torch.from_numpy(samples).unsqueeze(0)
    spectrogram = to_spectrogram(batched_samples)
    return spectrogram.permute(1, 2, 0).numpy()


def _load_audio_from_audio(file, sampling_rate: int = 16000):
    """Load the waveform of an audio file with librosa.

    Args:
        file: The path or open file object handed to ``librosa.load``.
        sampling_rate: The ``sr`` value passed to ``librosa.load``.

    Returns:
        The first element returned by ``librosa.load`` (the waveform).
    """
    # Imported lazily so that missing system dependencies do not break importing this module.
    import librosa

    samples, _loaded_rate = librosa.load(file, sr=sampling_rate)
    return samples


def _load_data_frame_from_csv(file, encoding: str):
    """Read a CSV file into a pandas data frame.

    Args:
        file: The path or open file object handed to ``pd.read_csv``.
        encoding: The text encoding passed to ``pd.read_csv``.

    Returns:
        The parsed ``pandas.DataFrame``.
    """
    data_frame = pd.read_csv(file, encoding=encoding)
    return data_frame


def _load_data_frame_from_tsv(file, encoding: str):
    """Read a tab-separated file into a pandas data frame.

    Args:
        file: The path or open file object handed to ``pd.read_csv``.
        encoding: The text encoding passed to ``pd.read_csv``.

    Returns:
        The parsed ``pandas.DataFrame``.
    """
    data_frame = pd.read_csv(file, encoding=encoding, sep="\t")
    return data_frame


# Dispatch tables: each maps a tuple of file extensions to the function that loads files with
# those extensions. `_get_loader` walks a table in insertion order and returns the first match.

# Loaders that produce an RGB image.
_image_loaders = {
    IMG_EXTENSIONS: _load_image_from_image,
    NP_EXTENSIONS: _load_image_from_numpy,
}


# Loaders that produce a float32 spectrogram array from an image, a saved array, or audio.
_spectrogram_loaders = {
    IMG_EXTENSIONS: _load_spectrogram_from_image,
    NP_EXTENSIONS: _load_spectrogram_from_numpy,
    AUDIO_EXTENSIONS: _load_spectrogram_from_audio,
}


# Loaders that produce a raw waveform.
_audio_loaders = {
    AUDIO_EXTENSIONS: _load_audio_from_audio,
}


# Loaders that produce a pandas data frame.
_data_frame_loaders = {
    CSV_EXTENSIONS: _load_data_frame_from_csv,
    TSV_EXTENSIONS: _load_data_frame_from_tsv,
}


def _get_loader(file_path: str, loaders):
    """Pick the loader whose extensions match ``file_path``.

    Args:
        file_path: The path whose extension decides which loader is used.
        loaders: A mapping from tuples of extensions to loader callables.

    Returns:
        The loader of the first entry in ``loaders`` whose extensions match.

    Raises:
        ValueError: If no entry in ``loaders`` matches ``file_path``.
    """
    for candidate_extensions, loader in loaders.items():
        if not has_file_allowed_extension(file_path, candidate_extensions):
            continue
        return loader

    # Flatten the tuple keys into a single list for the error message.
    supported = list(sum(loaders.keys(), ()))
    raise ValueError(f"File: {file_path} has an unsupported extension. Supported extensions: {supported}.")


# Matches a drive letter and colon followed by a single backslash or single forward slash
# and at least one further character (so "c://host" is not treated as a Windows path).
WINDOWS_FILE_PATH_RE = re.compile("^[a-zA-Z]:(\\\\[^\\\\]|/[^/]).*")


def is_local_path(file_path: str) -> bool:
    """Tell whether ``file_path`` refers to the local file system.

    Args:
        file_path: The path or URL to inspect.

    Returns:
        ``True`` for Windows drive paths and for anything whose URL scheme is empty or
        ``file``; ``False`` otherwise.
    """
    # Check Windows drive paths first: ``urlparse`` would report the drive letter as a scheme.
    if WINDOWS_FILE_PATH_RE.fullmatch(file_path) is not None:
        return True

    scheme = urlparse(file_path).scheme
    return scheme in ("", "file")


def escape_url(url: str) -> str:
    """Percent-encode the path and re-encode the query string of a URL.

    The path is passed through ``quote`` (so characters such as ``*``, ``[`` and spaces are
    escaped) and the query string is parsed and re-encoded with ``urlencode``. Any fragment
    is not included in the result.

    Args:
        url: The URL to escape.

    Returns:
        The escaped URL. A ``?`` is only emitted when there is a query string to append.
    """
    parsed = urlparse(url)
    escaped = f"{parsed.scheme}://{parsed.netloc}{quote(parsed.path)}"

    # Keep blank-valued parameters (e.g. "a=") so that re-encoding does not silently drop them.
    query = urlencode(parse_qs(parsed.query, keep_blank_values=True), doseq=True)

    # Avoid a dangling "?" when the URL has no query string.
    return f"{escaped}?{query}" if query else escaped


def escape_file_path(file_path: Union[str, PathLike]) -> str:
    """Escape a local path or a URL, depending on which one ``file_path`` is.

    Args:
        file_path: The path or URL to escape; it is converted with ``str`` first.

    Returns:
        The result of ``glob.escape`` for local paths, or of ``escape_url`` otherwise.
    """
    path_text = str(file_path)
    if is_local_path(path_text):
        return glob.escape(path_text)
    return escape_url(path_text)


def load(file_path: str, loaders):
    """Open ``file_path`` with fsspec and load it with the loader matching its extension.

    Args:
        file_path: The path or URL of the file to load.
        loaders: A mapping from tuples of extensions to loader callables.

    Returns:
        Whatever the selected loader returns for the opened file.

    Raises:
        ValueError: If no loader in ``loaders`` supports the extension of ``file_path``.
    """
    loader = _get_loader(file_path, loaders)
    # The path is escaped so that fsspec does not treat it as a glob pattern
    # (fsspec ignores `expand=False` in read mode).
    escaped_path = escape_file_path(file_path)
    with fsspec.open(escaped_path) as handle:
        return loader(handle)


def load_image(file_path: str):
    """Load an image from a file.

    Args:
        file_path: The path or URL of the image file to load.

    Returns:
        The result of the image loader matching the extension of ``file_path``.
    """
    image = load(file_path, _image_loaders)
    return image
def load_spectrogram(file_path: str, sampling_rate: int = 16000, n_fft: int = 400):
    """Load a spectrogram from an image, NumPy, or audio file.

    Args:
        file_path: The path or URL of the file to load.
        sampling_rate: The sampling rate to resample to if loading from an audio file.
        n_fft: The size of the FFT to use when creating a spectrogram from an audio file.

    Returns:
        The result of the spectrogram loader matching the extension of ``file_path``.
    """
    # Bind the audio-only options onto the audio loader; the other loaders take no options.
    audio_loader = partial(_spectrogram_loaders[AUDIO_EXTENSIONS], sampling_rate=sampling_rate, n_fft=n_fft)
    # Build a fresh table so the module-level one is never mutated.
    loaders = {**_spectrogram_loaders, AUDIO_EXTENSIONS: audio_loader}
    return load(file_path, loaders)
def load_audio(file_path: str, sampling_rate: int = 16000):
    """Load a waveform from an audio file.

    Args:
        file_path: The path or URL of the file to load.
        sampling_rate: The sampling rate to resample to.

    Returns:
        The result of the audio loader matching the extension of ``file_path``.
    """
    # Bind the sampling rate onto every audio loader without touching the module-level table.
    loaders = {}
    for extensions, loader in _audio_loaders.items():
        loaders[extensions] = partial(loader, sampling_rate=sampling_rate)
    return load(file_path, loaders)
def load_data_frame(file_path: str, encoding: str = "utf-8"):
    """Load a data frame from a CSV (or similar) file.

    Args:
        file_path: The path or URL of the file to load.
        encoding: The encoding to use when reading the file.

    Returns:
        The result of the data frame loader matching the extension of ``file_path``.
    """
    # Bind the encoding onto every data frame loader without touching the module-level table.
    loaders = {}
    for extensions, loader in _data_frame_loaders.items():
        loaders[extensions] = partial(loader, encoding=encoding)
    return load(file_path, loaders)

© Copyright 2020-2021, PyTorch Lightning. Revision a9cedb5a.

Built with Sphinx using a theme provided by Read the Docs.
Read the Docs v: stable
Versions
latest
stable
0.8.2
0.8.1.post0
0.8.1
0.8.0
0.7.5
0.7.4
0.7.3
0.7.2
0.7.1
0.7.0
0.6.0
0.5.2
0.5.1
0.5.0
0.4.0
0.3.2
0.3.1
0.3.0
0.2.3
0.2.2
0.2.1
0.2.0
0.1.0post1
Downloads
html
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.