Source code for flash.audio.classification.data

# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, Collection, Dict, List, Optional, Sequence, Type, Union

import numpy as np
import pandas as pd
import torch

from flash.audio.classification.input import (
    AudioClassificationCSVInput,
    AudioClassificationDataFrameInput,
    AudioClassificationFilesInput,
    AudioClassificationFolderInput,
    AudioClassificationNumpyInput,
    AudioClassificationTensorInput,
)
from flash.audio.classification.input_transform import AudioClassificationInputTransform
from flash.core.data.callback import BaseDataFetcher
from flash.core.data.data_module import DataModule
from flash.core.data.io.input import Input
from flash.core.data.io.input_transform import INPUT_TRANSFORM_TYPE
from flash.core.data.utilities.classification import TargetFormatter
from flash.core.data.utilities.paths import PATH_TYPE
from flash.core.registry import FlashRegistry
from flash.core.utilities.imports import _AUDIO_TESTING
from flash.core.utilities.stages import RunningStage
from flash.image.classification.data import MatplotlibVisualization

# Skip doctests if requirements aren't available
if not _AUDIO_TESTING:
    __doctest_skip__ = ["AudioClassificationData", "AudioClassificationData.*"]


class AudioClassificationData(DataModule):
    """The ``AudioClassificationData`` class is a :class:`~flash.core.data.data_module.DataModule` with a set of
    classmethods for loading data for audio classification."""

    input_transform_cls = AudioClassificationInputTransform
    input_transforms_registry = FlashRegistry("input_transforms")

    @classmethod
    def from_files(
        cls,
        train_files: Optional[Sequence[str]] = None,
        train_targets: Optional[Sequence[Any]] = None,
        val_files: Optional[Sequence[str]] = None,
        val_targets: Optional[Sequence[Any]] = None,
        test_files: Optional[Sequence[str]] = None,
        test_targets: Optional[Sequence[Any]] = None,
        predict_files: Optional[Sequence[str]] = None,
        train_transform: INPUT_TRANSFORM_TYPE = AudioClassificationInputTransform,
        val_transform: INPUT_TRANSFORM_TYPE = AudioClassificationInputTransform,
        test_transform: INPUT_TRANSFORM_TYPE = AudioClassificationInputTransform,
        predict_transform: INPUT_TRANSFORM_TYPE = AudioClassificationInputTransform,
        target_formatter: Optional[TargetFormatter] = None,
        input_cls: Type[Input] = AudioClassificationFilesInput,
        transform_kwargs: Optional[Dict] = None,
        **data_module_kwargs: Any,
    ) -> "AudioClassificationData":
        """Load the :class:`~flash.audio.classification.data.AudioClassificationData` from lists of files and
        corresponding lists of targets.

        The supported file extensions are: ``.jpg``, ``.jpeg``, ``.png``, ``.ppm``, ``.bmp``, ``.pgm``, ``.tif``,
        ``.tiff``, ``.webp``, and ``.npy``.
        The targets can be in any of our
        :ref:`supported classification target formats <formatting_classification_targets>`.
        To learn how to customize the transforms applied for each stage, read our
        :ref:`customizing transforms guide <customizing_transforms>`.

        Args:
            train_files: The list of spectrogram image files to use when training.
            train_targets: The list of targets to use when training.
            val_files: The list of spectrogram image files to use when validating.
            val_targets: The list of targets to use when validating.
            test_files: The list of spectrogram image files to use when testing.
            test_targets: The list of targets to use when testing.
            predict_files: The list of spectrogram image files to use when predicting.
            train_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when training.
            val_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when validating.
            test_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when testing.
            predict_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when
                predicting.
            target_formatter: Optionally provide a :class:`~flash.core.data.utilities.classification.TargetFormatter`
                to control how targets are handled. See :ref:`formatting_classification_targets` for more details.
            input_cls: The :class:`~flash.core.data.io.input.Input` type to use for loading the data.
            transform_kwargs: Dict of keyword arguments to be provided when instantiating the transforms.
            data_module_kwargs: Additional keyword arguments to provide to the
                :class:`~flash.core.data.data_module.DataModule` constructor.

        Returns:
            The constructed :class:`~flash.audio.classification.data.AudioClassificationData`.

        Examples
        ________

        .. testsetup::

            >>> from PIL import Image
            >>> rand_image = Image.fromarray(np.random.randint(0, 255, (64, 64, 3), dtype="uint8"))
            >>> _ = [rand_image.save(f"spectrogram_{i}.png") for i in range(1, 4)]
            >>> _ = [rand_image.save(f"predict_spectrogram_{i}.png") for i in range(1, 4)]

        .. doctest::

            >>> from flash import Trainer
            >>> from flash.audio import AudioClassificationData
            >>> from flash.image import ImageClassifier
            >>> datamodule = AudioClassificationData.from_files(
            ...     train_files=["spectrogram_1.png", "spectrogram_2.png", "spectrogram_3.png"],
            ...     train_targets=["meow", "bark", "meow"],
            ...     predict_files=[
            ...         "predict_spectrogram_1.png",
            ...         "predict_spectrogram_2.png",
            ...         "predict_spectrogram_3.png",
            ...     ],
            ...     transform_kwargs=dict(spectrogram_size=(128, 128)),
            ...     batch_size=2,
            ... )
            >>> datamodule.num_classes
            2
            >>> datamodule.labels
            ['bark', 'meow']
            >>> model = ImageClassifier(backbone="resnet18", num_classes=datamodule.num_classes)
            >>> trainer = Trainer(fast_dev_run=True)
            >>> trainer.fit(model, datamodule=datamodule)  # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
            Training...
            >>> trainer.predict(model, datamodule=datamodule)  # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
            Predicting...

        .. testcleanup::

            >>> import os
            >>> _ = [os.remove(f"spectrogram_{i}.png") for i in range(1, 4)]
            >>> _ = [os.remove(f"predict_spectrogram_{i}.png") for i in range(1, 4)]
        """
        ds_kw = dict(
            target_formatter=target_formatter,
            transform_kwargs=transform_kwargs,
            input_transforms_registry=cls.input_transforms_registry,
        )

        train_input = input_cls(RunningStage.TRAINING, train_files, train_targets, transform=train_transform, **ds_kw)
        ds_kw["target_formatter"] = getattr(train_input, "target_formatter", None)

        return cls(
            train_input,
            input_cls(RunningStage.VALIDATING, val_files, val_targets, transform=val_transform, **ds_kw),
            input_cls(RunningStage.TESTING, test_files, test_targets, transform=test_transform, **ds_kw),
            input_cls(RunningStage.PREDICTING, predict_files, transform=predict_transform, **ds_kw),
            **data_module_kwargs,
        )

    @classmethod
    def from_folders(
        cls,
        train_folder: Optional[str] = None,
        val_folder: Optional[str] = None,
        test_folder: Optional[str] = None,
        predict_folder: Optional[str] = None,
        train_transform: INPUT_TRANSFORM_TYPE = AudioClassificationInputTransform,
        val_transform: INPUT_TRANSFORM_TYPE = AudioClassificationInputTransform,
        test_transform: INPUT_TRANSFORM_TYPE = AudioClassificationInputTransform,
        predict_transform: INPUT_TRANSFORM_TYPE = AudioClassificationInputTransform,
        target_formatter: Optional[TargetFormatter] = None,
        input_cls: Type[Input] = AudioClassificationFolderInput,
        transform_kwargs: Optional[Dict] = None,
        **data_module_kwargs: Any,
    ) -> "AudioClassificationData":
        """Load the :class:`~flash.audio.classification.data.AudioClassificationData` from folders containing
        spectrogram images.

        The supported file extensions are: ``.jpg``, ``.jpeg``, ``.png``, ``.ppm``, ``.bmp``, ``.pgm``, ``.tif``,
        ``.tiff``, ``.webp``, and ``.npy``.
        For train, test, and validation data, the folders are expected to contain a sub-folder for each class.
        Here's the required structure:

        .. code-block::

            train_folder
            ├── meow
            │   ├── spectrogram_1.png
            │   ├── spectrogram_3.png
            │   ...
            └── bark
                ├── spectrogram_2.png
                ...

        For prediction, the folder is expected to contain the files for inference, like this:

        .. code-block::

            predict_folder
            ├── predict_spectrogram_1.png
            ├── predict_spectrogram_2.png
            ├── predict_spectrogram_3.png
            ...

        To learn how to customize the transforms applied for each stage, read our
        :ref:`customizing transforms guide <customizing_transforms>`.

        Args:
            train_folder: The folder containing spectrogram images to use when training.
            val_folder: The folder containing spectrogram images to use when validating.
            test_folder: The folder containing spectrogram images to use when testing.
            predict_folder: The folder containing spectrogram images to use when predicting.
            train_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when training.
            val_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when validating.
            test_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when testing.
            predict_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when
                predicting.
            target_formatter: Optionally provide a :class:`~flash.core.data.utilities.classification.TargetFormatter`
                to control how targets are handled. See :ref:`formatting_classification_targets` for more details.
            input_cls: The :class:`~flash.core.data.io.input.Input` type to use for loading the data.
            transform_kwargs: Dict of keyword arguments to be provided when instantiating the transforms.
            data_module_kwargs: Additional keyword arguments to provide to the
                :class:`~flash.core.data.data_module.DataModule` constructor.

        Returns:
            The constructed :class:`~flash.audio.classification.data.AudioClassificationData`.

        Examples
        ________

        .. testsetup::

            >>> import os
            >>> from PIL import Image
            >>> rand_image = Image.fromarray(np.random.randint(0, 255, (64, 64, 3), dtype="uint8"))
            >>> os.makedirs(os.path.join("train_folder", "meow"), exist_ok=True)
            >>> os.makedirs(os.path.join("train_folder", "bark"), exist_ok=True)
            >>> os.makedirs("predict_folder", exist_ok=True)
            >>> rand_image.save(os.path.join("train_folder", "meow", "spectrogram_1.png"))
            >>> rand_image.save(os.path.join("train_folder", "bark", "spectrogram_2.png"))
            >>> rand_image.save(os.path.join("train_folder", "meow", "spectrogram_3.png"))
            >>> _ = [rand_image.save(
            ...     os.path.join("predict_folder", f"predict_spectrogram_{i}.png")
            ... ) for i in range(1, 4)]

        .. doctest::

            >>> from flash import Trainer
            >>> from flash.audio import AudioClassificationData
            >>> from flash.image import ImageClassifier
            >>> datamodule = AudioClassificationData.from_folders(
            ...     train_folder="train_folder",
            ...     predict_folder="predict_folder",
            ...     transform_kwargs=dict(spectrogram_size=(128, 128)),
            ...     batch_size=2,
            ... )
            >>> datamodule.num_classes
            2
            >>> datamodule.labels
            ['bark', 'meow']
            >>> model = ImageClassifier(backbone="resnet18", num_classes=datamodule.num_classes)
            >>> trainer = Trainer(fast_dev_run=True)
            >>> trainer.fit(model, datamodule=datamodule)  # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
            Training...
            >>> trainer.predict(model, datamodule=datamodule)  # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
            Predicting...

        .. testcleanup::

            >>> import shutil
            >>> shutil.rmtree("train_folder")
            >>> shutil.rmtree("predict_folder")
        """
        ds_kw = dict(
            target_formatter=target_formatter,
            transform_kwargs=transform_kwargs,
            input_transforms_registry=cls.input_transforms_registry,
        )

        train_input = input_cls(RunningStage.TRAINING, train_folder, transform=train_transform, **ds_kw)
        ds_kw["target_formatter"] = getattr(train_input, "target_formatter", None)

        return cls(
            train_input,
            input_cls(RunningStage.VALIDATING, val_folder, transform=val_transform, **ds_kw),
            input_cls(RunningStage.TESTING, test_folder, transform=test_transform, **ds_kw),
            input_cls(RunningStage.PREDICTING, predict_folder, transform=predict_transform, **ds_kw),
            **data_module_kwargs,
        )

    @classmethod
    def from_numpy(
        cls,
        train_data: Optional[Collection[np.ndarray]] = None,
        train_targets: Optional[Collection[Any]] = None,
        val_data: Optional[Collection[np.ndarray]] = None,
        val_targets: Optional[Sequence[Any]] = None,
        test_data: Optional[Collection[np.ndarray]] = None,
        test_targets: Optional[Sequence[Any]] = None,
        predict_data: Optional[Collection[np.ndarray]] = None,
        train_transform: INPUT_TRANSFORM_TYPE = AudioClassificationInputTransform,
        val_transform: INPUT_TRANSFORM_TYPE = AudioClassificationInputTransform,
        test_transform: INPUT_TRANSFORM_TYPE = AudioClassificationInputTransform,
        predict_transform: INPUT_TRANSFORM_TYPE = AudioClassificationInputTransform,
        target_formatter: Optional[TargetFormatter] = None,
        input_cls: Type[Input] = AudioClassificationNumpyInput,
        transform_kwargs: Optional[Dict] = None,
        **data_module_kwargs: Any,
    ) -> "AudioClassificationData":
        """Load the :class:`~flash.audio.classification.data.AudioClassificationData` from numpy arrays (or lists
        of arrays) and corresponding lists of targets.

        The targets can be in any of our
        :ref:`supported classification target formats <formatting_classification_targets>`.
        To learn how to customize the transforms applied for each stage, read our
        :ref:`customizing transforms guide <customizing_transforms>`.

        Args:
            train_data: The numpy array or list of arrays to use when training.
            train_targets: The list of targets to use when training.
            val_data: The numpy array or list of arrays to use when validating.
            val_targets: The list of targets to use when validating.
            test_data: The numpy array or list of arrays to use when testing.
            test_targets: The list of targets to use when testing.
            predict_data: The numpy array or list of arrays to use when predicting.
            train_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when training.
            val_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when validating.
            test_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when testing.
            predict_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when
                predicting.
            target_formatter: Optionally provide a :class:`~flash.core.data.utilities.classification.TargetFormatter`
                to control how targets are handled. See :ref:`formatting_classification_targets` for more details.
            input_cls: The :class:`~flash.core.data.io.input.Input` type to use for loading the data.
            transform_kwargs: Dict of keyword arguments to be provided when instantiating the transforms.
            data_module_kwargs: Additional keyword arguments to provide to the
                :class:`~flash.core.data.data_module.DataModule` constructor.

        Returns:
            The constructed :class:`~flash.audio.classification.data.AudioClassificationData`.

        Examples
        ________

        .. doctest::

            >>> import numpy as np
            >>> from flash import Trainer
            >>> from flash.audio import AudioClassificationData
            >>> from flash.image import ImageClassifier
            >>> datamodule = AudioClassificationData.from_numpy(
            ...     train_data=[np.random.rand(3, 64, 64), np.random.rand(3, 64, 64), np.random.rand(3, 64, 64)],
            ...     train_targets=["meow", "bark", "meow"],
            ...     predict_data=[np.random.rand(3, 64, 64)],
            ...     transform_kwargs=dict(spectrogram_size=(128, 128)),
            ...     batch_size=2,
            ... )
            >>> datamodule.num_classes
            2
            >>> datamodule.labels
            ['bark', 'meow']
            >>> model = ImageClassifier(backbone="resnet18", num_classes=datamodule.num_classes)
            >>> trainer = Trainer(fast_dev_run=True)
            >>> trainer.fit(model, datamodule=datamodule)  # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
            Training...
            >>> trainer.predict(model, datamodule=datamodule)  # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
            Predicting...
        """
        ds_kw = dict(
            target_formatter=target_formatter,
            transform_kwargs=transform_kwargs,
            input_transforms_registry=cls.input_transforms_registry,
        )

        train_input = input_cls(RunningStage.TRAINING, train_data, train_targets, transform=train_transform, **ds_kw)
        ds_kw["target_formatter"] = getattr(train_input, "target_formatter", None)

        return cls(
            train_input,
            input_cls(RunningStage.VALIDATING, val_data, val_targets, transform=val_transform, **ds_kw),
            input_cls(RunningStage.TESTING, test_data, test_targets, transform=test_transform, **ds_kw),
            input_cls(RunningStage.PREDICTING, predict_data, transform=predict_transform, **ds_kw),
            **data_module_kwargs,
        )

    @classmethod
    def from_tensors(
        cls,
        train_data: Optional[Collection[torch.Tensor]] = None,
        train_targets: Optional[Collection[Any]] = None,
        val_data: Optional[Collection[torch.Tensor]] = None,
        val_targets: Optional[Sequence[Any]] = None,
        test_data: Optional[Collection[torch.Tensor]] = None,
        test_targets: Optional[Sequence[Any]] = None,
        predict_data: Optional[Collection[torch.Tensor]] = None,
        train_transform: INPUT_TRANSFORM_TYPE = AudioClassificationInputTransform,
        val_transform: INPUT_TRANSFORM_TYPE = AudioClassificationInputTransform,
        test_transform: INPUT_TRANSFORM_TYPE = AudioClassificationInputTransform,
        predict_transform: INPUT_TRANSFORM_TYPE = AudioClassificationInputTransform,
        target_formatter: Optional[TargetFormatter] = None,
        input_cls: Type[Input] = AudioClassificationTensorInput,
        transform_kwargs: Optional[Dict] = None,
        **data_module_kwargs: Any,
    ) -> "AudioClassificationData":
        """Load the :class:`~flash.audio.classification.data.AudioClassificationData` from torch tensors (or lists
        of tensors) and corresponding lists of targets.

        The targets can be in any of our
        :ref:`supported classification target formats <formatting_classification_targets>`.
        To learn how to customize the transforms applied for each stage, read our
        :ref:`customizing transforms guide <customizing_transforms>`.

        Args:
            train_data: The torch tensor or list of tensors to use when training.
            train_targets: The list of targets to use when training.
            val_data: The torch tensor or list of tensors to use when validating.
            val_targets: The list of targets to use when validating.
            test_data: The torch tensor or list of tensors to use when testing.
            test_targets: The list of targets to use when testing.
            predict_data: The torch tensor or list of tensors to use when predicting.
            train_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when training.
            val_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when validating.
            test_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when testing.
            predict_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when
                predicting.
            target_formatter: Optionally provide a :class:`~flash.core.data.utilities.classification.TargetFormatter`
                to control how targets are handled. See :ref:`formatting_classification_targets` for more details.
            input_cls: The :class:`~flash.core.data.io.input.Input` type to use for loading the data.
            transform_kwargs: Dict of keyword arguments to be provided when instantiating the transforms.
            data_module_kwargs: Additional keyword arguments to provide to the
                :class:`~flash.core.data.data_module.DataModule` constructor.

        Returns:
            The constructed :class:`~flash.audio.classification.data.AudioClassificationData`.

        Examples
        ________

        .. doctest::

            >>> import torch
            >>> from flash import Trainer
            >>> from flash.audio import AudioClassificationData
            >>> from flash.image import ImageClassifier
            >>> datamodule = AudioClassificationData.from_tensors(
            ...     train_data=[torch.rand(3, 64, 64), torch.rand(3, 64, 64), torch.rand(3, 64, 64)],
            ...     train_targets=["meow", "bark", "meow"],
            ...     predict_data=[torch.rand(3, 64, 64)],
            ...     transform_kwargs=dict(spectrogram_size=(128, 128)),
            ...     batch_size=2,
            ... )
            >>> datamodule.num_classes
            2
            >>> datamodule.labels
            ['bark', 'meow']
            >>> model = ImageClassifier(backbone="resnet18", num_classes=datamodule.num_classes)
            >>> trainer = Trainer(fast_dev_run=True)
            >>> trainer.fit(model, datamodule=datamodule)  # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
            Training...
            >>> trainer.predict(model, datamodule=datamodule)  # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
            Predicting...
        """
        ds_kw = dict(
            target_formatter=target_formatter,
            transform_kwargs=transform_kwargs,
            input_transforms_registry=cls.input_transforms_registry,
        )

        train_input = input_cls(RunningStage.TRAINING, train_data, train_targets, transform=train_transform, **ds_kw)
        ds_kw["target_formatter"] = getattr(train_input, "target_formatter", None)

        return cls(
            train_input,
            input_cls(RunningStage.VALIDATING, val_data, val_targets, transform=val_transform, **ds_kw),
            input_cls(RunningStage.TESTING, test_data, test_targets, transform=test_transform, **ds_kw),
            input_cls(RunningStage.PREDICTING, predict_data, transform=predict_transform, **ds_kw),
            **data_module_kwargs,
        )

    @classmethod
    def from_data_frame(
        cls,
        input_field: str,
        target_fields: Optional[Union[str, Sequence[str]]] = None,
        train_data_frame: Optional[pd.DataFrame] = None,
        train_images_root: Optional[str] = None,
        train_resolver: Optional[Callable[[str, str], str]] = None,
        val_data_frame: Optional[pd.DataFrame] = None,
        val_images_root: Optional[str] = None,
        val_resolver: Optional[Callable[[str, str], str]] = None,
        test_data_frame: Optional[pd.DataFrame] = None,
        test_images_root: Optional[str] = None,
        test_resolver: Optional[Callable[[str, str], str]] = None,
        predict_data_frame: Optional[pd.DataFrame] = None,
        predict_images_root: Optional[str] = None,
        predict_resolver: Optional[Callable[[str, str], str]] = None,
        train_transform: INPUT_TRANSFORM_TYPE = AudioClassificationInputTransform,
        val_transform: INPUT_TRANSFORM_TYPE = AudioClassificationInputTransform,
        test_transform: INPUT_TRANSFORM_TYPE = AudioClassificationInputTransform,
        predict_transform: INPUT_TRANSFORM_TYPE = AudioClassificationInputTransform,
        target_formatter: Optional[TargetFormatter] = None,
        input_cls: Type[Input] = AudioClassificationDataFrameInput,
        transform_kwargs: Optional[Dict] = None,
        **data_module_kwargs: Any,
    ) -> "AudioClassificationData":
        """Load the :class:`~flash.audio.classification.data.AudioClassificationData` from pandas DataFrame objects
        containing spectrogram image file paths and their corresponding targets.

        Input spectrogram image paths will be extracted from the ``input_field`` in the DataFrame.
        The supported file extensions are: ``.jpg``, ``.jpeg``, ``.png``, ``.ppm``, ``.bmp``, ``.pgm``, ``.tif``,
        ``.tiff``, ``.webp``, and ``.npy``.
        The targets will be extracted from the ``target_fields`` in the DataFrame and can be in any of our
        :ref:`supported classification target formats <formatting_classification_targets>`.
        To learn how to customize the transforms applied for each stage, read our
        :ref:`customizing transforms guide <customizing_transforms>`.

        Args:
            input_field: The field (column name) in the DataFrames containing the spectrogram image file paths.
            target_fields: The field (column name) or list of fields in the DataFrames containing the targets.
            train_data_frame: The pandas DataFrame to use when training.
            train_images_root: The root directory containing train spectrogram images.
            train_resolver: Optionally provide a function which converts an entry from the ``input_field`` into a
                spectrogram image file path.
            val_data_frame: The pandas DataFrame to use when validating.
            val_images_root: The root directory containing validation spectrogram images.
            val_resolver: Optionally provide a function which converts an entry from the ``input_field`` into a
                spectrogram image file path.
            test_data_frame: The pandas DataFrame to use when testing.
            test_images_root: The root directory containing test spectrogram images.
            test_resolver: Optionally provide a function which converts an entry from the ``input_field`` into a
                spectrogram image file path.
            predict_data_frame: The pandas DataFrame to use when predicting.
            predict_images_root: The root directory containing predict spectrogram images.
            predict_resolver: Optionally provide a function which converts an entry from the ``input_field`` into a
                spectrogram image file path.
            train_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when training.
            val_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when validating.
            test_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when testing.
            predict_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when
                predicting.
            target_formatter: Optionally provide a :class:`~flash.core.data.utilities.classification.TargetFormatter`
                to control how targets are handled. See :ref:`formatting_classification_targets` for more details.
            input_cls: The :class:`~flash.core.data.io.input.Input` type to use for loading the data.
            transform_kwargs: Dict of keyword arguments to be provided when instantiating the transforms.
            data_module_kwargs: Additional keyword arguments to provide to the
                :class:`~flash.core.data.data_module.DataModule` constructor.

        Returns:
            The constructed :class:`~flash.audio.classification.data.AudioClassificationData`.

        Examples
        ________

        .. testsetup::

            >>> import os
            >>> from PIL import Image
            >>> rand_image = Image.fromarray(np.random.randint(0, 255, (64, 64, 3), dtype="uint8"))
            >>> os.makedirs("train_folder", exist_ok=True)
            >>> os.makedirs("predict_folder", exist_ok=True)
            >>> _ = [rand_image.save(os.path.join("train_folder", f"spectrogram_{i}.png")) for i in range(1, 4)]
            >>> _ = [rand_image.save(
            ...     os.path.join("predict_folder", f"predict_spectrogram_{i}.png")
            ... ) for i in range(1, 4)]

        .. doctest::

            >>> from pandas import DataFrame
            >>> from flash import Trainer
            >>> from flash.audio import AudioClassificationData
            >>> from flash.image import ImageClassifier
            >>> train_data_frame = DataFrame.from_dict(
            ...     {
            ...         "images": ["spectrogram_1.png", "spectrogram_2.png", "spectrogram_3.png"],
            ...         "targets": ["meow", "bark", "meow"],
            ...     }
            ... )
            >>> predict_data_frame = DataFrame.from_dict(
            ...     {
            ...         "images": [
            ...             "predict_spectrogram_1.png",
            ...             "predict_spectrogram_2.png",
            ...             "predict_spectrogram_3.png",
            ...         ],
            ...     }
            ... )
            >>> datamodule = AudioClassificationData.from_data_frame(
            ...     "images",
            ...     "targets",
            ...     train_data_frame=train_data_frame,
            ...     train_images_root="train_folder",
            ...     predict_data_frame=predict_data_frame,
            ...     predict_images_root="predict_folder",
            ...     transform_kwargs=dict(spectrogram_size=(128, 128)),
            ...     batch_size=2,
            ... )
            >>> datamodule.num_classes
            2
            >>> datamodule.labels
            ['bark', 'meow']
            >>> model = ImageClassifier(backbone="resnet18", num_classes=datamodule.num_classes)
            >>> trainer = Trainer(fast_dev_run=True)
            >>> trainer.fit(model, datamodule=datamodule)  # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
            Training...
            >>> trainer.predict(model, datamodule=datamodule)  # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
            Predicting...

        .. testcleanup::

            >>> import shutil
            >>> shutil.rmtree("train_folder")
            >>> shutil.rmtree("predict_folder")
            >>> del train_data_frame
            >>> del predict_data_frame
        """
        ds_kw = dict(
            target_formatter=target_formatter,
            transform_kwargs=transform_kwargs,
            input_transforms_registry=cls.input_transforms_registry,
        )

        train_data = (train_data_frame, input_field, target_fields, train_images_root, train_resolver)
        val_data = (val_data_frame, input_field, target_fields, val_images_root, val_resolver)
        test_data = (test_data_frame, input_field, target_fields, test_images_root, test_resolver)
        predict_data = (predict_data_frame, input_field, None, predict_images_root, predict_resolver)

        train_input = input_cls(RunningStage.TRAINING, *train_data, transform=train_transform, **ds_kw)
        ds_kw["target_formatter"] = getattr(train_input, "target_formatter", None)

        return cls(
            train_input,
            input_cls(RunningStage.VALIDATING, *val_data, transform=val_transform, **ds_kw),
            input_cls(RunningStage.TESTING, *test_data, transform=test_transform, **ds_kw),
            input_cls(RunningStage.PREDICTING, *predict_data, transform=predict_transform, **ds_kw),
            **data_module_kwargs,
        )

    @classmethod
    def from_csv(
        cls,
        input_field: str,
        target_fields: Optional[Union[str, List[str]]] = None,
        train_file: Optional[PATH_TYPE] = None,
        train_images_root: Optional[PATH_TYPE] = None,
        train_resolver: Optional[Callable[[PATH_TYPE, Any], PATH_TYPE]] = None,
        val_file: Optional[PATH_TYPE] = None,
        val_images_root: Optional[PATH_TYPE] = None,
        val_resolver: Optional[Callable[[PATH_TYPE, Any], PATH_TYPE]] = None,
        test_file: Optional[str] = None,
        test_images_root: Optional[str] = None,
        test_resolver: Optional[Callable[[PATH_TYPE, Any], PATH_TYPE]] = None,
        predict_file: Optional[str] = None,
        predict_images_root: Optional[str] = None,
        predict_resolver: Optional[Callable[[PATH_TYPE, Any], PATH_TYPE]] = None,
        train_transform: INPUT_TRANSFORM_TYPE = AudioClassificationInputTransform,
        val_transform: INPUT_TRANSFORM_TYPE = AudioClassificationInputTransform,
        test_transform: INPUT_TRANSFORM_TYPE = AudioClassificationInputTransform,
        predict_transform: INPUT_TRANSFORM_TYPE = AudioClassificationInputTransform,
        target_formatter: Optional[TargetFormatter] = None,
        input_cls: Type[Input] = AudioClassificationCSVInput,
        transform_kwargs: Optional[Dict] = None,
        **data_module_kwargs: Any,
    ) -> "AudioClassificationData":
        """Load the :class:`~flash.audio.classification.data.AudioClassificationData` from CSV files containing
        spectrogram image file paths and their corresponding targets.

        Input spectrogram image paths will be extracted from the ``input_field`` column in the CSV files.
        The supported file extensions are: ``.jpg``, ``.jpeg``, ``.png``, ``.ppm``, ``.bmp``, ``.pgm``, ``.tif``,
        ``.tiff``, ``.webp``, and ``.npy``.
        The targets will be extracted from the ``target_fields`` in the CSV files and can be in any of our
        :ref:`supported classification target formats <formatting_classification_targets>`.
        To learn how to customize the transforms applied for each stage, read our
        :ref:`customizing transforms guide <customizing_transforms>`.

        Args:
            input_field: The field (column name) in the CSV files containing the spectrogram image file paths.
            target_fields: The field (column name) or list of fields in the CSV files containing the targets.
            train_file: The CSV file to use when training.
            train_images_root: The root directory containing train spectrogram images.
            train_resolver: Optionally provide a function which converts an entry from the ``input_field`` into a
                spectrogram image file path.
            val_file: The CSV file to use when validating.
            val_images_root: The root directory containing validation spectrogram images.
            val_resolver: Optionally provide a function which converts an entry from the ``input_field`` into a
                spectrogram image file path.
            test_file: The CSV file to use when testing.
            test_images_root: The root directory containing test spectrogram images.
            test_resolver: Optionally provide a function which converts an entry from the ``input_field`` into a
                spectrogram image file path.
            predict_file: The CSV file to use when predicting.
            predict_images_root: The root directory containing predict spectrogram images.
            predict_resolver: Optionally provide a function which converts an entry from the ``input_field`` into a
                spectrogram image file path.
            train_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when training.
            val_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when validating.
            test_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when testing.
            predict_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when
                predicting.
            target_formatter: Optionally provide a :class:`~flash.core.data.utilities.classification.TargetFormatter`
                to control how targets are handled. See :ref:`formatting_classification_targets` for more details.
            input_cls: The :class:`~flash.core.data.io.input.Input` type to use for loading the data.
            transform_kwargs: Dict of keyword arguments to be provided when instantiating the transforms.
            data_module_kwargs: Additional keyword arguments to provide to the
                :class:`~flash.core.data.data_module.DataModule` constructor.

        Returns:
            The constructed :class:`~flash.audio.classification.data.AudioClassificationData`.

        Examples
        ________

        .. testsetup::

            >>> import os
            >>> from PIL import Image
            >>> from pandas import DataFrame
            >>> rand_image = Image.fromarray(np.random.randint(0, 255, (64, 64, 3), dtype="uint8"))
            >>> os.makedirs("train_folder", exist_ok=True)
            >>> os.makedirs("predict_folder", exist_ok=True)
            >>> _ = [rand_image.save(os.path.join("train_folder", f"spectrogram_{i}.png")) for i in range(1, 4)]
            >>> _ = [rand_image.save(
            ...     os.path.join("predict_folder", f"predict_spectrogram_{i}.png")
            ... ) for i in range(1, 4)]
            >>> DataFrame.from_dict({
            ...     "images": ["spectrogram_1.png", "spectrogram_2.png", "spectrogram_3.png"],
            ...     "targets": ["meow", "bark", "meow"],
            ... }).to_csv("train_data.csv", index=False)
            >>> DataFrame.from_dict({
            ...     "images": ["predict_spectrogram_1.png", "predict_spectrogram_2.png", "predict_spectrogram_3.png"],
            ... }).to_csv("predict_data.csv", index=False)

        The file ``train_data.csv`` contains the following:

        .. code-block::

            images,targets
            spectrogram_1.png,meow
            spectrogram_2.png,bark
            spectrogram_3.png,meow

        The file ``predict_data.csv`` contains the following:

        .. code-block::

            images
            predict_spectrogram_1.png
            predict_spectrogram_2.png
            predict_spectrogram_3.png

        .. doctest::

            >>> from flash import Trainer
            >>> from flash.audio import AudioClassificationData
            >>> from flash.image import ImageClassifier
            >>> datamodule = AudioClassificationData.from_csv(
            ...     "images",
            ...     "targets",
            ...     train_file="train_data.csv",
            ...     train_images_root="train_folder",
            ...     predict_file="predict_data.csv",
            ...     predict_images_root="predict_folder",
            ...     transform_kwargs=dict(spectrogram_size=(128, 128)),
            ...     batch_size=2,
            ... )
            >>> datamodule.num_classes
            2
            >>> datamodule.labels
            ['bark', 'meow']
            >>> model = ImageClassifier(backbone="resnet18", num_classes=datamodule.num_classes)
            >>> trainer = Trainer(fast_dev_run=True)
            >>> trainer.fit(model, datamodule=datamodule)  # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
            Training...
            >>> trainer.predict(model, datamodule=datamodule)  # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
            Predicting...

        .. testcleanup::

            >>> import shutil
            >>> shutil.rmtree("train_folder")
            >>> shutil.rmtree("predict_folder")
            >>> os.remove("train_data.csv")
            >>> os.remove("predict_data.csv")
        """
        ds_kw = dict(
            target_formatter=target_formatter,
            transform_kwargs=transform_kwargs,
            input_transforms_registry=cls.input_transforms_registry,
        )

        train_data = (train_file, input_field, target_fields, train_images_root, train_resolver)
        val_data = (val_file, input_field, target_fields, val_images_root, val_resolver)
        test_data = (test_file, input_field, target_fields, test_images_root, test_resolver)
        predict_data = (predict_file, input_field, None, predict_images_root, predict_resolver)

        train_input = input_cls(RunningStage.TRAINING, *train_data, transform=train_transform, **ds_kw)
        ds_kw["target_formatter"] = getattr(train_input, "target_formatter", None)

        return cls(
            train_input,
            input_cls(RunningStage.VALIDATING, *val_data, transform=val_transform, **ds_kw),
            input_cls(RunningStage.TESTING, *test_data, transform=test_transform, **ds_kw),
            input_cls(RunningStage.PREDICTING, *predict_data, transform=predict_transform, **ds_kw),
            **data_module_kwargs,
        )

    def set_block_viz_window(self, value: bool) -> None:
        """Setter method to switch on/off the blocking of matplotlib pop-up visualization windows."""
        self.data_fetcher.block_viz_window = value

    @staticmethod
    def configure_data_fetcher(*args, **kwargs) -> BaseDataFetcher:
        return MatplotlibVisualization(*args, **kwargs)
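
A minimal end-to-end usage sketch (an illustration, not part of the module above): it assumes the ``train_folder``/``predict_folder`` layout shown in the ``from_folders`` docstring and reuses the ``resnet18`` backbone from the doctests. Only names that appear in this module's doctests are used.

.. code-block::

    from flash import Trainer
    from flash.audio import AudioClassificationData
    from flash.image import ImageClassifier

    # Build the DataModule from folders of spectrogram images (one sub-folder per class);
    # the folder layout is assumed to match the ``from_folders`` docstring above.
    datamodule = AudioClassificationData.from_folders(
        train_folder="train_folder",
        predict_folder="predict_folder",
        transform_kwargs=dict(spectrogram_size=(128, 128)),
        batch_size=2,
    )

    # Block matplotlib pop-up windows (e.g. on a headless machine); this sets
    # ``data_fetcher.block_viz_window`` via the setter defined above.
    datamodule.set_block_viz_window(True)

    # Spectrograms are images, so the ImageClassifier backbone applies directly.
    model = ImageClassifier(backbone="resnet18", num_classes=datamodule.num_classes)
    trainer = Trainer(fast_dev_run=True)
    trainer.fit(model, datamodule=datamodule)
    predictions = trainer.predict(model, datamodule=datamodule)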
