# Source code for flash.audio.classification.data
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, Collection, Dict, List, Optional, Sequence, Type, Union
import numpy as np
import pandas as pd
import torch
from flash.audio.classification.input import (
AudioClassificationCSVInput,
AudioClassificationDataFrameInput,
AudioClassificationFilesInput,
AudioClassificationFolderInput,
AudioClassificationNumpyInput,
AudioClassificationTensorInput,
)
from flash.audio.classification.input_transform import AudioClassificationInputTransform
from flash.core.data.callback import BaseDataFetcher
from flash.core.data.data_module import DataModule
from flash.core.data.io.input import Input
from flash.core.data.io.input_transform import INPUT_TRANSFORM_TYPE
from flash.core.data.utilities.classification import TargetFormatter
from flash.core.data.utilities.paths import PATH_TYPE
from flash.core.registry import FlashRegistry
from flash.core.utilities.imports import _AUDIO_TESTING
from flash.core.utilities.stages import RunningStage
from flash.image.classification.data import MatplotlibVisualization
# Skip doctests if requirements aren't available
if not _AUDIO_TESTING:
    __doctest_skip__ = ["AudioClassificationData", "AudioClassificationData.*"]
class AudioClassificationData(DataModule):
    """The ``AudioClassificationData`` class is a :class:`~flash.core.data.data_module.DataModule` with a set of
    classmethods for loading data for audio classification."""

    # Default :class:`InputTransform` used by every ``from_*`` constructor below.
    input_transform_cls = AudioClassificationInputTransform
    # Registry used to resolve input transforms referenced by name.
    input_transforms_registry = FlashRegistry("input_transforms")
[docs] @classmethod
def from_files(
cls,
train_files: Optional[Sequence[str]] = None,
train_targets: Optional[Sequence[Any]] = None,
val_files: Optional[Sequence[str]] = None,
val_targets: Optional[Sequence[Any]] = None,
test_files: Optional[Sequence[str]] = None,
test_targets: Optional[Sequence[Any]] = None,
predict_files: Optional[Sequence[str]] = None,
train_transform: INPUT_TRANSFORM_TYPE = AudioClassificationInputTransform,
val_transform: INPUT_TRANSFORM_TYPE = AudioClassificationInputTransform,
test_transform: INPUT_TRANSFORM_TYPE = AudioClassificationInputTransform,
predict_transform: INPUT_TRANSFORM_TYPE = AudioClassificationInputTransform,
target_formatter: Optional[TargetFormatter] = None,
input_cls: Type[Input] = AudioClassificationFilesInput,
transform_kwargs: Optional[Dict] = None,
**data_module_kwargs: Any,
) -> "AudioClassificationData":
"""Load the :class:`~flash.audio.classification.data.AudioClassificationData` from lists of files and
corresponding lists of targets.
The supported file extensions are: ``.jpg``, ``.jpeg``, ``.png``, ``.ppm``, ``.bmp``, ``.pgm``, ``.tif``,
``.tiff``, ``.webp``, and ``.npy``.
The targets can be in any of our
:ref:`supported classification target formats <formatting_classification_targets>`.
To learn how to customize the transforms applied for each stage, read our
:ref:`customizing transforms guide <customizing_transforms>`.
Args:
train_files: The list of spectrogram image files to use when training.
train_targets: The list of targets to use when training.
val_files: The list of spectrogram image files to use when validating.
val_targets: The list of targets to use when validating.
test_files: The list of spectrogram image files to use when testing.
test_targets: The list of targets to use when testing.
predict_files: The list of spectrogram image files to use when predicting.
train_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when training.
val_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when validating.
test_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when testing.
predict_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when
predicting.
target_formatter: Optionally provide a :class:`~flash.core.data.utilities.classification.TargetFormatter` to
control how targets are handled. See :ref:`formatting_classification_targets` for more details.
input_cls: The :class:`~flash.core.data.io.input.Input` type to use for loading the data.
transform_kwargs: Dict of keyword arguments to be provided when instantiating the transforms.
data_module_kwargs: Additional keyword arguments to provide to the
:class:`~flash.core.data.data_module.DataModule` constructor.
Returns:
The constructed :class:`~flash.audio.classification.data.AudioClassificationData`.
Examples
________
.. testsetup::
>>> from PIL import Image
>>> rand_image = Image.fromarray(np.random.randint(0, 255, (64, 64, 3), dtype="uint8"))
>>> _ = [rand_image.save(f"spectrogram_{i}.png") for i in range(1, 4)]
>>> _ = [rand_image.save(f"predict_spectrogram_{i}.png") for i in range(1, 4)]
.. doctest::
>>> from flash import Trainer
>>> from flash.audio import AudioClassificationData
>>> from flash.image import ImageClassifier
>>> datamodule = AudioClassificationData.from_files(
... train_files=["spectrogram_1.png", "spectrogram_2.png", "spectrogram_3.png"],
... train_targets=["meow", "bark", "meow"],
... predict_files=[
... "predict_spectrogram_1.png",
... "predict_spectrogram_2.png",
... "predict_spectrogram_3.png",
... ],
... transform_kwargs=dict(spectrogram_size=(128, 128)),
... batch_size=2,
... )
>>> datamodule.num_classes
2
>>> datamodule.labels
['bark', 'meow']
>>> model = ImageClassifier(backbone="resnet18", num_classes=datamodule.num_classes)
>>> trainer = Trainer(fast_dev_run=True)
>>> trainer.fit(model, datamodule=datamodule) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
Training...
>>> trainer.predict(model, datamodule=datamodule) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
Predicting...
.. testcleanup::
>>> import os
>>> _ = [os.remove(f"spectrogram_{i}.png") for i in range(1, 4)]
>>> _ = [os.remove(f"predict_spectrogram_{i}.png") for i in range(1, 4)]
"""
ds_kw = dict(
target_formatter=target_formatter,
transform_kwargs=transform_kwargs,
input_transforms_registry=cls.input_transforms_registry,
)
train_input = input_cls(RunningStage.TRAINING, train_files, train_targets, transform=train_transform, **ds_kw)
ds_kw["target_formatter"] = getattr(train_input, "target_formatter", None)
return cls(
train_input,
input_cls(RunningStage.VALIDATING, val_files, val_targets, transform=val_transform, **ds_kw),
input_cls(RunningStage.TESTING, test_files, test_targets, transform=test_transform, **ds_kw),
input_cls(RunningStage.PREDICTING, predict_files, transform=predict_transform, **ds_kw),
**data_module_kwargs,
)
    @classmethod
    def from_folders(
        cls,
        train_folder: Optional[str] = None,
        val_folder: Optional[str] = None,
        test_folder: Optional[str] = None,
        predict_folder: Optional[str] = None,
        train_transform: INPUT_TRANSFORM_TYPE = AudioClassificationInputTransform,
        val_transform: INPUT_TRANSFORM_TYPE = AudioClassificationInputTransform,
        test_transform: INPUT_TRANSFORM_TYPE = AudioClassificationInputTransform,
        predict_transform: INPUT_TRANSFORM_TYPE = AudioClassificationInputTransform,
        target_formatter: Optional[TargetFormatter] = None,
        input_cls: Type[Input] = AudioClassificationFolderInput,
        transform_kwargs: Optional[Dict] = None,
        **data_module_kwargs: Any,
    ) -> "AudioClassificationData":
        """Load the :class:`~flash.audio.classification.data.AudioClassificationData` from folders containing
        spectrogram images.

        The supported file extensions are: ``.jpg``, ``.jpeg``, ``.png``, ``.ppm``, ``.bmp``, ``.pgm``, ``.tif``,
        ``.tiff``, ``.webp``, and ``.npy``.
        For train, test, and validation data, the folders are expected to contain a sub-folder for each class.
        Here's the required structure:

        .. code-block::

            train_folder
            ├── meow
            │   ├── spectrogram_1.png
            │   ├── spectrogram_3.png
            │   ...
            └── bark
                ├── spectrogram_2.png
                ...

        For prediction, the folder is expected to contain the files for inference, like this:

        .. code-block::

            predict_folder
            ├── predict_spectrogram_1.png
            ├── predict_spectrogram_2.png
            ├── predict_spectrogram_3.png
            ...

        To learn how to customize the transforms applied for each stage, read our
        :ref:`customizing transforms guide <customizing_transforms>`.

        Args:
            train_folder: The folder containing spectrogram images to use when training.
            val_folder: The folder containing spectrogram images to use when validating.
            test_folder: The folder containing spectrogram images to use when testing.
            predict_folder: The folder containing spectrogram images to use when predicting.
            train_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when training.
            val_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when validating.
            test_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when testing.
            predict_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when
                predicting.
            target_formatter: Optionally provide a :class:`~flash.core.data.utilities.classification.TargetFormatter` to
                control how targets are handled. See :ref:`formatting_classification_targets` for more details.
            input_cls: The :class:`~flash.core.data.io.input.Input` type to use for loading the data.
            transform_kwargs: Dict of keyword arguments to be provided when instantiating the transforms.
            data_module_kwargs: Additional keyword arguments to provide to the
                :class:`~flash.core.data.data_module.DataModule` constructor.

        Returns:
            The constructed :class:`~flash.audio.classification.data.AudioClassificationData`.

        Examples
        ________

        .. testsetup::

            >>> import os
            >>> from PIL import Image
            >>> rand_image = Image.fromarray(np.random.randint(0, 255, (64, 64, 3), dtype="uint8"))
            >>> os.makedirs(os.path.join("train_folder", "meow"), exist_ok=True)
            >>> os.makedirs(os.path.join("train_folder", "bark"), exist_ok=True)
            >>> os.makedirs("predict_folder", exist_ok=True)
            >>> rand_image.save(os.path.join("train_folder", "meow", "spectrogram_1.png"))
            >>> rand_image.save(os.path.join("train_folder", "bark", "spectrogram_2.png"))
            >>> rand_image.save(os.path.join("train_folder", "meow", "spectrogram_3.png"))
            >>> _ = [rand_image.save(
            ...     os.path.join("predict_folder", f"predict_spectrogram_{i}.png")
            ... ) for i in range(1, 4)]

        .. doctest::

            >>> from flash import Trainer
            >>> from flash.audio import AudioClassificationData
            >>> from flash.image import ImageClassifier
            >>> datamodule = AudioClassificationData.from_folders(
            ...     train_folder="train_folder",
            ...     predict_folder="predict_folder",
            ...     transform_kwargs=dict(spectrogram_size=(128, 128)),
            ...     batch_size=2,
            ... )
            >>> datamodule.num_classes
            2
            >>> datamodule.labels
            ['bark', 'meow']
            >>> model = ImageClassifier(backbone="resnet18", num_classes=datamodule.num_classes)
            >>> trainer = Trainer(fast_dev_run=True)
            >>> trainer.fit(model, datamodule=datamodule)  # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
            Training...
            >>> trainer.predict(model, datamodule=datamodule)  # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
            Predicting...

        .. testcleanup::

            >>> import shutil
            >>> shutil.rmtree("train_folder")
            >>> shutil.rmtree("predict_folder")
        """
        # Keyword arguments shared by the ``Input`` for every stage.
        ds_kw = dict(
            target_formatter=target_formatter,
            transform_kwargs=transform_kwargs,
            input_transforms_registry=cls.input_transforms_registry,
        )

        train_input = input_cls(RunningStage.TRAINING, train_folder, transform=train_transform, **ds_kw)
        # Share the formatter fitted on the training targets with the other stages so that
        # targets are encoded consistently.
        ds_kw["target_formatter"] = getattr(train_input, "target_formatter", None)

        return cls(
            train_input,
            input_cls(RunningStage.VALIDATING, val_folder, transform=val_transform, **ds_kw),
            input_cls(RunningStage.TESTING, test_folder, transform=test_transform, **ds_kw),
            input_cls(RunningStage.PREDICTING, predict_folder, transform=predict_transform, **ds_kw),
            **data_module_kwargs,
        )
[docs] @classmethod
def from_numpy(
cls,
train_data: Optional[Collection[np.ndarray]] = None,
train_targets: Optional[Collection[Any]] = None,
val_data: Optional[Collection[np.ndarray]] = None,
val_targets: Optional[Sequence[Any]] = None,
test_data: Optional[Collection[np.ndarray]] = None,
test_targets: Optional[Sequence[Any]] = None,
predict_data: Optional[Collection[np.ndarray]] = None,
train_transform: INPUT_TRANSFORM_TYPE = AudioClassificationInputTransform,
val_transform: INPUT_TRANSFORM_TYPE = AudioClassificationInputTransform,
test_transform: INPUT_TRANSFORM_TYPE = AudioClassificationInputTransform,
predict_transform: INPUT_TRANSFORM_TYPE = AudioClassificationInputTransform,
target_formatter: Optional[TargetFormatter] = None,
input_cls: Type[Input] = AudioClassificationNumpyInput,
transform_kwargs: Optional[Dict] = None,
**data_module_kwargs: Any,
) -> "AudioClassificationData":
"""Load the :class:`~flash.audio.classification.data.AudioClassificationData` from numpy arrays (or lists
of arrays) and corresponding lists of targets.
The targets can be in any of our
:ref:`supported classification target formats <formatting_classification_targets>`.
To learn how to customize the transforms applied for each stage, read our
:ref:`customizing transforms guide <customizing_transforms>`.
Args:
train_data: The numpy array or list of arrays to use when training.
train_targets: The list of targets to use when training.
val_data: The numpy array or list of arrays to use when validating.
val_targets: The list of targets to use when validating.
test_data: The numpy array or list of arrays to use when testing.
test_targets: The list of targets to use when testing.
predict_data: The numpy array or list of arrays to use when predicting.
train_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when training.
val_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when validating.
test_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when testing.
predict_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when
predicting.
target_formatter: Optionally provide a :class:`~flash.core.data.utilities.classification.TargetFormatter` to
control how targets are handled. See :ref:`formatting_classification_targets` for more details.
input_cls: The :class:`~flash.core.data.io.input.Input` type to use for loading the data.
transform_kwargs: Dict of keyword arguments to be provided when instantiating the transforms.
data_module_kwargs: Additional keyword arguments to provide to the
:class:`~flash.core.data.data_module.DataModule` constructor.
Returns:
The constructed :class:`~flash.audio.classification.data.AudioClassificationData`.
Examples
________
.. doctest::
>>> import numpy as np
>>> from flash import Trainer
>>> from flash.audio import AudioClassificationData
>>> from flash.image import ImageClassifier
>>> datamodule = AudioClassificationData.from_numpy(
... train_data=[np.random.rand(3, 64, 64), np.random.rand(3, 64, 64), np.random.rand(3, 64, 64)],
... train_targets=["meow", "bark", "meow"],
... predict_data=[np.random.rand(3, 64, 64)],
... transform_kwargs=dict(spectrogram_size=(128, 128)),
... batch_size=2,
... )
>>> datamodule.num_classes
2
>>> datamodule.labels
['bark', 'meow']
>>> model = ImageClassifier(backbone="resnet18", num_classes=datamodule.num_classes)
>>> trainer = Trainer(fast_dev_run=True)
>>> trainer.fit(model, datamodule=datamodule) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
Training...
>>> trainer.predict(model, datamodule=datamodule) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
Predicting...
"""
ds_kw = dict(
target_formatter=target_formatter,
transform_kwargs=transform_kwargs,
input_transforms_registry=cls.input_transforms_registry,
)
train_input = input_cls(RunningStage.TRAINING, train_data, train_targets, transform=train_transform, **ds_kw)
ds_kw["target_formatter"] = getattr(train_input, "target_formatter", None)
return cls(
train_input,
input_cls(RunningStage.VALIDATING, val_data, val_targets, transform=val_transform, **ds_kw),
input_cls(RunningStage.TESTING, test_data, test_targets, transform=test_transform, **ds_kw),
input_cls(RunningStage.PREDICTING, predict_data, transform=predict_transform, **ds_kw),
**data_module_kwargs,
)
[docs] @classmethod
def from_tensors(
cls,
train_data: Optional[Collection[torch.Tensor]] = None,
train_targets: Optional[Collection[Any]] = None,
val_data: Optional[Collection[torch.Tensor]] = None,
val_targets: Optional[Sequence[Any]] = None,
test_data: Optional[Collection[torch.Tensor]] = None,
test_targets: Optional[Sequence[Any]] = None,
predict_data: Optional[Collection[torch.Tensor]] = None,
train_transform: INPUT_TRANSFORM_TYPE = AudioClassificationInputTransform,
val_transform: INPUT_TRANSFORM_TYPE = AudioClassificationInputTransform,
test_transform: INPUT_TRANSFORM_TYPE = AudioClassificationInputTransform,
predict_transform: INPUT_TRANSFORM_TYPE = AudioClassificationInputTransform,
target_formatter: Optional[TargetFormatter] = None,
input_cls: Type[Input] = AudioClassificationTensorInput,
transform_kwargs: Optional[Dict] = None,
**data_module_kwargs: Any,
) -> "AudioClassificationData":
"""Load the :class:`~flash.audio.classification.data.AudioClassificationData` from torch tensors (or lists
of tensors) and corresponding lists of targets.
The targets can be in any of our
:ref:`supported classification target formats <formatting_classification_targets>`.
To learn how to customize the transforms applied for each stage, read our
:ref:`customizing transforms guide <customizing_transforms>`.
Args:
train_data: The torch tensor or list of tensors to use when training.
train_targets: The list of targets to use when training.
val_data: The torch tensor or list of tensors to use when validating.
val_targets: The list of targets to use when validating.
test_data: The torch tensor or list of tensors to use when testing.
test_targets: The list of targets to use when testing.
predict_data: The torch tensor or list of tensors to use when predicting.
train_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when training.
val_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when validating.
test_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when testing.
predict_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when
predicting.
target_formatter: Optionally provide a :class:`~flash.core.data.utilities.classification.TargetFormatter` to
control how targets are handled. See :ref:`formatting_classification_targets` for more details.
input_cls: The :class:`~flash.core.data.io.input.Input` type to use for loading the data.
transform_kwargs: Dict of keyword arguments to be provided when instantiating the transforms.
data_module_kwargs: Additional keyword arguments to provide to the
:class:`~flash.core.data.data_module.DataModule` constructor.
Returns:
The constructed :class:`~flash.audio.classification.data.AudioClassificationData`.
Examples
________
.. doctest::
>>> import torch
>>> from flash import Trainer
>>> from flash.audio import AudioClassificationData
>>> from flash.image import ImageClassifier
>>> datamodule = AudioClassificationData.from_tensors(
... train_data=[torch.rand(3, 64, 64), torch.rand(3, 64, 64), torch.rand(3, 64, 64)],
... train_targets=["meow", "bark", "meow"],
... predict_data=[torch.rand(3, 64, 64)],
... transform_kwargs=dict(spectrogram_size=(128, 128)),
... batch_size=2,
... )
>>> datamodule.num_classes
2
>>> datamodule.labels
['bark', 'meow']
>>> model = ImageClassifier(backbone="resnet18", num_classes=datamodule.num_classes)
>>> trainer = Trainer(fast_dev_run=True)
>>> trainer.fit(model, datamodule=datamodule) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
Training...
>>> trainer.predict(model, datamodule=datamodule) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
Predicting...
"""
ds_kw = dict(
target_formatter=target_formatter,
transform_kwargs=transform_kwargs,
input_transforms_registry=cls.input_transforms_registry,
)
train_input = input_cls(RunningStage.TRAINING, train_data, train_targets, transform=train_transform, **ds_kw)
ds_kw["target_formatter"] = getattr(train_input, "target_formatter", None)
return cls(
train_input,
input_cls(RunningStage.VALIDATING, val_data, val_targets, transform=val_transform, **ds_kw),
input_cls(RunningStage.TESTING, test_data, test_targets, transform=test_transform, **ds_kw),
input_cls(RunningStage.PREDICTING, predict_data, transform=predict_transform, **ds_kw),
**data_module_kwargs,
)
    @classmethod
    def from_data_frame(
        cls,
        input_field: str,
        target_fields: Optional[Union[str, Sequence[str]]] = None,
        train_data_frame: Optional[pd.DataFrame] = None,
        train_images_root: Optional[str] = None,
        train_resolver: Optional[Callable[[str, str], str]] = None,
        val_data_frame: Optional[pd.DataFrame] = None,
        val_images_root: Optional[str] = None,
        val_resolver: Optional[Callable[[str, str], str]] = None,
        test_data_frame: Optional[pd.DataFrame] = None,
        test_images_root: Optional[str] = None,
        test_resolver: Optional[Callable[[str, str], str]] = None,
        predict_data_frame: Optional[pd.DataFrame] = None,
        predict_images_root: Optional[str] = None,
        predict_resolver: Optional[Callable[[str, str], str]] = None,
        train_transform: INPUT_TRANSFORM_TYPE = AudioClassificationInputTransform,
        val_transform: INPUT_TRANSFORM_TYPE = AudioClassificationInputTransform,
        test_transform: INPUT_TRANSFORM_TYPE = AudioClassificationInputTransform,
        predict_transform: INPUT_TRANSFORM_TYPE = AudioClassificationInputTransform,
        target_formatter: Optional[TargetFormatter] = None,
        input_cls: Type[Input] = AudioClassificationDataFrameInput,
        transform_kwargs: Optional[Dict] = None,
        **data_module_kwargs: Any,
    ) -> "AudioClassificationData":
        """Load the :class:`~flash.audio.classification.data.AudioClassificationData` from pandas DataFrame objects
        containing spectrogram image file paths and their corresponding targets.

        Input spectrogram image paths will be extracted from the ``input_field`` in the DataFrame.
        The supported file extensions are: ``.jpg``, ``.jpeg``, ``.png``, ``.ppm``, ``.bmp``, ``.pgm``, ``.tif``,
        ``.tiff``, ``.webp``, and ``.npy``.
        The targets will be extracted from the ``target_fields`` in the DataFrame and can be in any of our
        :ref:`supported classification target formats <formatting_classification_targets>`.
        To learn how to customize the transforms applied for each stage, read our
        :ref:`customizing transforms guide <customizing_transforms>`.

        Args:
            input_field: The field (column name) in the DataFrames containing the spectrogram image file paths.
            target_fields: The field (column name) or list of fields in the DataFrames containing the targets.
            train_data_frame: The pandas DataFrame to use when training.
            train_images_root: The root directory containing train spectrogram images.
            train_resolver: Optionally provide a function which converts an entry from the ``input_field`` into a
                spectrogram image file path.
            val_data_frame: The pandas DataFrame to use when validating.
            val_images_root: The root directory containing validation spectrogram images.
            val_resolver: Optionally provide a function which converts an entry from the ``input_field`` into a
                spectrogram image file path.
            test_data_frame: The pandas DataFrame to use when testing.
            test_images_root: The root directory containing test spectrogram images.
            test_resolver: Optionally provide a function which converts an entry from the ``input_field`` into a
                spectrogram image file path.
            predict_data_frame: The pandas DataFrame to use when predicting.
            predict_images_root: The root directory containing predict spectrogram images.
            predict_resolver: Optionally provide a function which converts an entry from the ``input_field`` into a
                spectrogram image file path.
            train_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when training.
            val_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when validating.
            test_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when testing.
            predict_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when
                predicting.
            target_formatter: Optionally provide a :class:`~flash.core.data.utilities.classification.TargetFormatter` to
                control how targets are handled. See :ref:`formatting_classification_targets` for more details.
            input_cls: The :class:`~flash.core.data.io.input.Input` type to use for loading the data.
            transform_kwargs: Dict of keyword arguments to be provided when instantiating the transforms.
            data_module_kwargs: Additional keyword arguments to provide to the
                :class:`~flash.core.data.data_module.DataModule` constructor.

        Returns:
            The constructed :class:`~flash.audio.classification.data.AudioClassificationData`.

        Examples
        ________

        .. testsetup::

            >>> import os
            >>> from PIL import Image
            >>> rand_image = Image.fromarray(np.random.randint(0, 255, (64, 64, 3), dtype="uint8"))
            >>> os.makedirs("train_folder", exist_ok=True)
            >>> os.makedirs("predict_folder", exist_ok=True)
            >>> _ = [rand_image.save(os.path.join("train_folder", f"spectrogram_{i}.png")) for i in range(1, 4)]
            >>> _ = [rand_image.save(
            ...     os.path.join("predict_folder", f"predict_spectrogram_{i}.png")
            ... ) for i in range(1, 4)]

        .. doctest::

            >>> from pandas import DataFrame
            >>> from flash import Trainer
            >>> from flash.audio import AudioClassificationData
            >>> from flash.image import ImageClassifier
            >>> train_data_frame = DataFrame.from_dict(
            ...     {
            ...         "images": ["spectrogram_1.png", "spectrogram_2.png", "spectrogram_3.png"],
            ...         "targets": ["meow", "bark", "meow"],
            ...     }
            ... )
            >>> predict_data_frame = DataFrame.from_dict(
            ...     {
            ...         "images": [
            ...             "predict_spectrogram_1.png",
            ...             "predict_spectrogram_2.png",
            ...             "predict_spectrogram_3.png",
            ...         ],
            ...     }
            ... )
            >>> datamodule = AudioClassificationData.from_data_frame(
            ...     "images",
            ...     "targets",
            ...     train_data_frame=train_data_frame,
            ...     train_images_root="train_folder",
            ...     predict_data_frame=predict_data_frame,
            ...     predict_images_root="predict_folder",
            ...     transform_kwargs=dict(spectrogram_size=(128, 128)),
            ...     batch_size=2,
            ... )
            >>> datamodule.num_classes
            2
            >>> datamodule.labels
            ['bark', 'meow']
            >>> model = ImageClassifier(backbone="resnet18", num_classes=datamodule.num_classes)
            >>> trainer = Trainer(fast_dev_run=True)
            >>> trainer.fit(model, datamodule=datamodule)  # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
            Training...
            >>> trainer.predict(model, datamodule=datamodule)  # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
            Predicting...

        .. testcleanup::

            >>> import shutil
            >>> shutil.rmtree("train_folder")
            >>> shutil.rmtree("predict_folder")
            >>> del train_data_frame
            >>> del predict_data_frame
        """
        # Keyword arguments shared by the ``Input`` for every stage.
        ds_kw = dict(
            target_formatter=target_formatter,
            transform_kwargs=transform_kwargs,
            input_transforms_registry=cls.input_transforms_registry,
        )

        # Positional argument bundles for each stage's ``Input``; the predict stage has no targets.
        train_data = (train_data_frame, input_field, target_fields, train_images_root, train_resolver)
        val_data = (val_data_frame, input_field, target_fields, val_images_root, val_resolver)
        test_data = (test_data_frame, input_field, target_fields, test_images_root, test_resolver)
        predict_data = (predict_data_frame, input_field, None, predict_images_root, predict_resolver)

        train_input = input_cls(RunningStage.TRAINING, *train_data, transform=train_transform, **ds_kw)
        # Share the formatter fitted on the training targets with the other stages so that
        # targets are encoded consistently.
        ds_kw["target_formatter"] = getattr(train_input, "target_formatter", None)

        return cls(
            train_input,
            input_cls(RunningStage.VALIDATING, *val_data, transform=val_transform, **ds_kw),
            input_cls(RunningStage.TESTING, *test_data, transform=test_transform, **ds_kw),
            input_cls(RunningStage.PREDICTING, *predict_data, transform=predict_transform, **ds_kw),
            **data_module_kwargs,
        )
[docs] @classmethod
def from_csv(
cls,
input_field: str,
target_fields: Optional[Union[str, List[str]]] = None,
train_file: Optional[PATH_TYPE] = None,
train_images_root: Optional[PATH_TYPE] = None,
train_resolver: Optional[Callable[[PATH_TYPE, Any], PATH_TYPE]] = None,
val_file: Optional[PATH_TYPE] = None,
val_images_root: Optional[PATH_TYPE] = None,
val_resolver: Optional[Callable[[PATH_TYPE, Any], PATH_TYPE]] = None,
test_file: Optional[str] = None,
test_images_root: Optional[str] = None,
test_resolver: Optional[Callable[[PATH_TYPE, Any], PATH_TYPE]] = None,
predict_file: Optional[str] = None,
predict_images_root: Optional[str] = None,
predict_resolver: Optional[Callable[[PATH_TYPE, Any], PATH_TYPE]] = None,
train_transform: INPUT_TRANSFORM_TYPE = AudioClassificationInputTransform,
val_transform: INPUT_TRANSFORM_TYPE = AudioClassificationInputTransform,
test_transform: INPUT_TRANSFORM_TYPE = AudioClassificationInputTransform,
predict_transform: INPUT_TRANSFORM_TYPE = AudioClassificationInputTransform,
target_formatter: Optional[TargetFormatter] = None,
input_cls: Type[Input] = AudioClassificationCSVInput,
transform_kwargs: Optional[Dict] = None,
**data_module_kwargs: Any,
) -> "AudioClassificationData":
"""Load the :class:`~flash.audio.classification.data.AudioClassificationData` from CSV files containing
spectrogram image file paths and their corresponding targets.
Input spectrogram images will be extracted from the ``input_field`` column in the CSV files.
The supported file extensions are: ``.jpg``, ``.jpeg``, ``.png``, ``.ppm``, ``.bmp``, ``.pgm``, ``.tif``,
``.tiff``, ``.webp``, and ``.npy``.
The targets will be extracted from the ``target_fields`` in the CSV files and can be in any of our
:ref:`supported classification target formats <formatting_classification_targets>`.
To learn how to customize the transforms applied for each stage, read our
:ref:`customizing transforms guide <customizing_transforms>`.
Args:
input_field: The field (column name) in the CSV files containing the spectrogram image file paths.
target_fields: The field (column name) or list of fields in the CSV files containing the targets.
train_file: The CSV file to use when training.
train_images_root: The root directory containing train spectrogram images.
train_resolver: Optionally provide a function which converts an entry from the ``input_field`` into a
spectrogram image file path.
val_file: The CSV file to use when validating.
val_images_root: The root directory containing validation spectrogram images.
val_resolver: Optionally provide a function which converts an entry from the ``input_field`` into a
spectrogram image file path.
test_file: The CSV file to use when testing.
test_images_root: The root directory containing test spectrogram images.
test_resolver: Optionally provide a function which converts an entry from the ``input_field`` into a
spectrogram image file path.
predict_file: The CSV file to use when predicting.
predict_images_root: The root directory containing predict spectrogram images.
predict_resolver: Optionally provide a function which converts an entry from the ``input_field`` into a
spectrogram image file path.
train_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when training.
val_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when validating.
test_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when testing.
predict_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when
predicting.
target_formatter: Optionally provide a :class:`~flash.core.data.utilities.classification.TargetFormatter` to
control how targets are handled. See :ref:`formatting_classification_targets` for more details.
input_cls: The :class:`~flash.core.data.io.input.Input` type to use for loading the data.
transform_kwargs: Dict of keyword arguments to be provided when instantiating the transforms.
data_module_kwargs: Additional keyword arguments to provide to the
:class:`~flash.core.data.data_module.DataModule` constructor.
Returns:
The constructed :class:`~flash.audio.classification.data.AudioClassificationData`.
Examples
________
.. testsetup::
>>> import os
>>> from PIL import Image
>>> from pandas import DataFrame
>>> rand_image = Image.fromarray(np.random.randint(0, 255, (64, 64, 3), dtype="uint8"))
>>> os.makedirs("train_folder", exist_ok=True)
>>> os.makedirs("predict_folder", exist_ok=True)
>>> _ = [rand_image.save(os.path.join("train_folder", f"spectrogram_{i}.png")) for i in range(1, 4)]
>>> _ = [rand_image.save(
... os.path.join("predict_folder", f"predict_spectrogram_{i}.png")
... ) for i in range(1, 4)]
>>> DataFrame.from_dict({
... "images": ["spectrogram_1.png", "spectrogram_2.png", "spectrogram_3.png"],
... "targets": ["meow", "bark", "meow"],
... }).to_csv("train_data.csv", index=False)
>>> DataFrame.from_dict({
... "images": ["predict_spectrogram_1.png", "predict_spectrogram_2.png", "predict_spectrogram_3.png"],
... }).to_csv("predict_data.csv", index=False)
The file ``train_data.csv`` contains the following:
.. code-block::
images,targets
spectrogram_1.png,meow
spectrogram_2.png,bark
spectrogram_3.png,meow
The file ``predict_data.csv`` contains the following:
.. code-block::
images
predict_spectrogram_1.png
predict_spectrogram_2.png
predict_spectrogram_3.png
.. doctest::
>>> from flash import Trainer
>>> from flash.audio import AudioClassificationData
>>> from flash.image import ImageClassifier
>>> datamodule = AudioClassificationData.from_csv(
... "images",
... "targets",
... train_file="train_data.csv",
... train_images_root="train_folder",
... predict_file="predict_data.csv",
... predict_images_root="predict_folder",
... transform_kwargs=dict(spectrogram_size=(128, 128)),
... batch_size=2,
... )
>>> datamodule.num_classes
2
>>> datamodule.labels
['bark', 'meow']
>>> model = ImageClassifier(backbone="resnet18", num_classes=datamodule.num_classes)
>>> trainer = Trainer(fast_dev_run=True)
>>> trainer.fit(model, datamodule=datamodule) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
Training...
>>> trainer.predict(model, datamodule=datamodule) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
Predicting...
.. testcleanup::
>>> import shutil
>>> shutil.rmtree("train_folder")
>>> shutil.rmtree("predict_folder")
>>> os.remove("train_data.csv")
>>> os.remove("predict_data.csv")
"""
ds_kw = dict(
target_formatter=target_formatter,
transform_kwargs=transform_kwargs,
input_transforms_registry=cls.input_transforms_registry,
)
train_data = (train_file, input_field, target_fields, train_images_root, train_resolver)
val_data = (val_file, input_field, target_fields, val_images_root, val_resolver)
test_data = (test_file, input_field, target_fields, test_images_root, test_resolver)
predict_data = (predict_file, input_field, None, predict_images_root, predict_resolver)
train_input = input_cls(RunningStage.TRAINING, *train_data, transform=train_transform, **ds_kw)
ds_kw["target_formatter"] = getattr(train_input, "target_formatter", None)
return cls(
train_input,
input_cls(RunningStage.VALIDATING, *val_data, transform=val_transform, **ds_kw),
input_cls(RunningStage.TESTING, *test_data, transform=test_transform, **ds_kw),
input_cls(RunningStage.PREDICTING, *predict_data, transform=predict_transform, **ds_kw),
**data_module_kwargs,
)
    def set_block_viz_window(self, value: bool) -> None:
        """Setter method to switch on/off matplotlib to pop up windows.

        Args:
            value: When ``True``, showing a visualization blocks until its window is closed.
        """
        self.data_fetcher.block_viz_window = value
    @staticmethod
    def configure_data_fetcher(*args, **kwargs) -> BaseDataFetcher:
        """Construct the matplotlib-based data fetcher used to visualize batches.

        All arguments are forwarded to :class:`~flash.image.classification.data.MatplotlibVisualization`.
        """
        return MatplotlibVisualization(*args, **kwargs)