AudioClassificationData
- class flash.audio.classification.data.AudioClassificationData(train_input=None, val_input=None, test_input=None, predict_input=None, data_fetcher=None, transform=<class 'flash.core.data.io.input_transform.InputTransform'>, transform_kwargs=None, val_split=None, batch_size=None, num_workers=0, sampler=None, pin_memory=True, persistent_workers=False)
The AudioClassificationData class is a DataModule with a set of class methods for loading data for audio classification.

- classmethod from_csv(input_field, target_fields=None, train_file=None, train_images_root=None, train_resolver=None, val_file=None, val_images_root=None, val_resolver=None, test_file=None, test_images_root=None, test_resolver=None, predict_file=None, predict_images_root=None, predict_resolver=None, sampling_rate=16000, n_fft=400, input_cls=<class 'flash.audio.classification.input.AudioClassificationCSVInput'>, transform=<class 'flash.audio.classification.input_transform.AudioClassificationInputTransform'>, transform_kwargs=None, target_formatter=None, **data_module_kwargs)
Load the AudioClassificationData from CSV files containing spectrogram image file paths and their corresponding targets.

Input spectrogram images will be extracted from the input_field column in the CSV files. The supported file extensions for precomputed spectrograms are: .jpg, .jpeg, .png, .ppm, .bmp, .pgm, .tif, .tiff, .webp, and .npy. The supported file extensions for raw audio (where spectrograms will be computed automatically) are: .aiff, .au, .avr, .caf, .flac, .mat, .mat4, .mat5, .mpc2k, .ogg, .paf, .pvf, .rf64, .ircam, .voc, .w64, .wav, .nist, and .wavex. The targets will be extracted from the target_fields in the CSV files and can be in any of our supported classification target formats. To learn how to customize the transforms applied for each stage, read our customizing transforms guide.

- Parameters
  - input_field (str) – The field (column name) in the CSV files containing the spectrogram image file paths.
  - target_fields (Union[str, List[str], None]) – The field (column name) or list of fields in the CSV files containing the targets.
  - train_file (Union[str, bytes, PathLike, None]) – The CSV file to use when training.
  - train_images_root (Union[str, bytes, PathLike, None]) – The root directory containing train spectrogram images.
  - train_resolver (Optional[Callable[[Union[str, bytes, PathLike], Any], Union[str, bytes, PathLike]]]) – Optionally provide a function which converts an entry from the input_field into a spectrogram image file path (see the resolver sketch after the Examples below).
  - val_file (Union[str, bytes, PathLike, None]) – The CSV file to use when validating.
  - val_images_root (Union[str, bytes, PathLike, None]) – The root directory containing validation spectrogram images.
  - val_resolver (Optional[Callable[[Union[str, bytes, PathLike], Any], Union[str, bytes, PathLike]]]) – Optionally provide a function which converts an entry from the input_field into a spectrogram image file path.
  - test_file (Optional[str]) – The CSV file to use when testing.
  - test_images_root (Optional[str]) – The root directory containing test spectrogram images.
  - test_resolver (Optional[Callable[[Union[str, bytes, PathLike], Any], Union[str, bytes, PathLike]]]) – Optionally provide a function which converts an entry from the input_field into a spectrogram image file path.
  - predict_file (Optional[str]) – The CSV file to use when predicting.
  - predict_images_root (Optional[str]) – The root directory containing predict spectrogram images.
  - predict_resolver (Optional[Callable[[Union[str, bytes, PathLike], Any], Union[str, bytes, PathLike]]]) – Optionally provide a function which converts an entry from the input_field into a spectrogram image file path.
  - sampling_rate (int) – Sampling rate to use when loading raw audio files.
  - n_fft (int) – The size of the FFT to use when creating spectrograms from raw audio.
  - target_formatter (Optional[TargetFormatter]) – Optionally provide a TargetFormatter to control how targets are handled. See Formatting Classification Targets for more details.
  - input_cls (Type[Input]) – The Input type to use for loading the data.
  - transform (TypeVar(INPUT_TRANSFORM_TYPE, Type[flash.core.data.io.input_transform.InputTransform], Callable, Tuple[Union[StrEnum, str], Dict[str, Any]], Union[StrEnum, str], None)) – The InputTransform type to use.
  - transform_kwargs (Optional[Dict]) – Dict of keyword arguments to be provided when instantiating the transforms.
  - data_module_kwargs (Any) – Additional keyword arguments to provide to the DataModule constructor.
- Return type
  AudioClassificationData
- Returns
  The constructed AudioClassificationData.
Examples
The files can be in Comma Separated Values (CSV) format with either a .csv or .txt extension.

The file train_data.csv contains the following:

images,targets
spectrogram_1.png,meow
spectrogram_2.png,bark
spectrogram_3.png,meow

The file predict_data.csv contains the following:

images
predict_spectrogram_1.png
predict_spectrogram_2.png
predict_spectrogram_3.png

>>> from flash import Trainer
>>> from flash.audio import AudioClassificationData
>>> from flash.image import ImageClassifier
>>> datamodule = AudioClassificationData.from_csv(
...     "images",
...     "targets",
...     train_file="train_data.csv",
...     train_images_root="train_folder",
...     predict_file="predict_data.csv",
...     predict_images_root="predict_folder",
...     transform_kwargs=dict(spectrogram_size=(128, 128)),
...     batch_size=2,
... )
>>> datamodule.num_classes
2
>>> datamodule.labels
['bark', 'meow']
>>> model = ImageClassifier(backbone="resnet18", num_classes=datamodule.num_classes)
>>> trainer = Trainer(fast_dev_run=True)
>>> trainer.fit(model, datamodule=datamodule)
Training...
>>> trainer.predict(model, datamodule=datamodule)
Predicting...
Alternatively, the files can be in Tab Separated Values (TSV) format with a .tsv extension.

The file train_data.tsv contains the following:

images              targets
spectrogram_1.png   meow
spectrogram_2.png   bark
spectrogram_3.png   meow

The file predict_data.tsv contains the following:

images
predict_spectrogram_1.png
predict_spectrogram_2.png
predict_spectrogram_3.png

>>> from flash import Trainer
>>> from flash.audio import AudioClassificationData
>>> from flash.image import ImageClassifier
>>> datamodule = AudioClassificationData.from_csv(
...     "images",
...     "targets",
...     train_file="train_data.tsv",
...     train_images_root="train_folder",
...     predict_file="predict_data.tsv",
...     predict_images_root="predict_folder",
...     transform_kwargs=dict(spectrogram_size=(128, 128)),
...     batch_size=2,
... )
>>> datamodule.num_classes
2
>>> datamodule.labels
['bark', 'meow']
>>> model = ImageClassifier(backbone="resnet18", num_classes=datamodule.num_classes)
>>> trainer = Trainer(fast_dev_run=True)
>>> trainer.fit(model, datamodule=datamodule)
Training...
>>> trainer.predict(model, datamodule=datamodule)
Predicting...
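If the input_field column does not hold complete file paths, a resolver can build them. A minimal sketch, assuming the column holds bare file stems; per the Callable[[root, entry], path] signature documented above, the resolver receives the images root and the raw column value, and stem_resolver is an illustrative name, not part of the API:

>>> import os
>>> def stem_resolver(root, file_id):
...     # Hypothetical resolver: join the images root with the bare ID
...     # from the CSV and append a .png extension.
...     return os.path.join(root, f"{file_id}.png")
...
>>> datamodule = AudioClassificationData.from_csv(
...     "images",
...     "targets",
...     train_file="train_data.csv",
...     train_images_root="train_folder",
...     train_resolver=stem_resolver,
...     batch_size=2,
... )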
- classmethod from_data_frame(input_field, target_fields=None, train_data_frame=None, train_images_root=None, train_resolver=None, val_data_frame=None, val_images_root=None, val_resolver=None, test_data_frame=None, test_images_root=None, test_resolver=None, predict_data_frame=None, predict_images_root=None, predict_resolver=None, sampling_rate=16000, n_fft=400, input_cls=<class 'flash.audio.classification.input.AudioClassificationDataFrameInput'>, transform=<class 'flash.audio.classification.input_transform.AudioClassificationInputTransform'>, transform_kwargs=None, target_formatter=None, **data_module_kwargs)
Load the AudioClassificationData from pandas DataFrame objects containing spectrogram image file paths and their corresponding targets.

Input spectrogram image paths will be extracted from the input_field in the DataFrame. The supported file extensions for precomputed spectrograms are: .jpg, .jpeg, .png, .ppm, .bmp, .pgm, .tif, .tiff, .webp, and .npy. The supported file extensions for raw audio (where spectrograms will be computed automatically) are: .aiff, .au, .avr, .caf, .flac, .mat, .mat4, .mat5, .mpc2k, .ogg, .paf, .pvf, .rf64, .ircam, .voc, .w64, .wav, .nist, and .wavex. The targets will be extracted from the target_fields in the DataFrame and can be in any of our supported classification target formats. To learn how to customize the transforms applied for each stage, read our customizing transforms guide.

- Parameters
  - input_field (str) – The field (column name) in the DataFrames containing the spectrogram image file paths.
  - target_fields (Union[str, Sequence[str], None]) – The field (column name) or list of fields in the DataFrames containing the targets (a multi-column sketch follows the Examples below).
  - train_data_frame (Optional[DataFrame]) – The pandas DataFrame to use when training.
  - train_images_root (Optional[str]) – The root directory containing train spectrogram images.
  - train_resolver (Optional[Callable[[str, str], str]]) – Optionally provide a function which converts an entry from the input_field into a spectrogram image file path.
  - val_data_frame (Optional[DataFrame]) – The pandas DataFrame to use when validating.
  - val_images_root (Optional[str]) – The root directory containing validation spectrogram images.
  - val_resolver (Optional[Callable[[str, str], str]]) – Optionally provide a function which converts an entry from the input_field into a spectrogram image file path.
  - test_data_frame (Optional[DataFrame]) – The pandas DataFrame to use when testing.
  - test_images_root (Optional[str]) – The root directory containing test spectrogram images.
  - test_resolver (Optional[Callable[[str, str], str]]) – Optionally provide a function which converts an entry from the input_field into a spectrogram image file path.
  - predict_data_frame (Optional[DataFrame]) – The pandas DataFrame to use when predicting.
  - predict_images_root (Optional[str]) – The root directory containing predict spectrogram images.
  - predict_resolver (Optional[Callable[[str, str], str]]) – Optionally provide a function which converts an entry from the input_field into a spectrogram image file path.
  - sampling_rate (int) – Sampling rate to use when loading raw audio files.
  - n_fft (int) – The size of the FFT to use when creating spectrograms from raw audio.
  - target_formatter (Optional[TargetFormatter]) – Optionally provide a TargetFormatter to control how targets are handled. See Formatting Classification Targets for more details.
  - input_cls (Type[Input]) – The Input type to use for loading the data.
  - transform (TypeVar(INPUT_TRANSFORM_TYPE, Type[flash.core.data.io.input_transform.InputTransform], Callable, Tuple[Union[StrEnum, str], Dict[str, Any]], Union[StrEnum, str], None)) – The InputTransform type to use.
  - transform_kwargs (Optional[Dict]) – Dict of keyword arguments to be provided when instantiating the transforms.
  - data_module_kwargs (Any) – Additional keyword arguments to provide to the DataModule constructor.
- Return type
  AudioClassificationData
- Returns
  The constructed AudioClassificationData.
Examples
>>> from pandas import DataFrame
>>> from flash import Trainer
>>> from flash.audio import AudioClassificationData
>>> from flash.image import ImageClassifier
>>> train_data_frame = DataFrame.from_dict(
...     {
...         "images": ["spectrogram_1.png", "spectrogram_2.png", "spectrogram_3.png"],
...         "targets": ["meow", "bark", "meow"],
...     }
... )
>>> predict_data_frame = DataFrame.from_dict(
...     {
...         "images": [
...             "predict_spectrogram_1.png",
...             "predict_spectrogram_2.png",
...             "predict_spectrogram_3.png",
...         ],
...     }
... )
>>> datamodule = AudioClassificationData.from_data_frame(
...     "images",
...     "targets",
...     train_data_frame=train_data_frame,
...     train_images_root="train_folder",
...     predict_data_frame=predict_data_frame,
...     predict_images_root="predict_folder",
...     transform_kwargs=dict(spectrogram_size=(128, 128)),
...     batch_size=2,
... )
>>> datamodule.num_classes
2
>>> datamodule.labels
['bark', 'meow']
>>> model = ImageClassifier(backbone="resnet18", num_classes=datamodule.num_classes)
>>> trainer = Trainer(fast_dev_run=True)
>>> trainer.fit(model, datamodule=datamodule)
Training...
>>> trainer.predict(model, datamodule=datamodule)
Predicting...
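When target_fields is a list of columns, the values from those columns are combined to form each sample's target (for example, one binary column per class). A minimal, hedged sketch; the column names are illustrative, and the exact interpretation of multi-column targets follows our supported classification target formats:

>>> multilabel_df = DataFrame.from_dict(
...     {
...         "images": ["spectrogram_1.png", "spectrogram_2.png"],
...         "meow": [1, 0],  # hypothetical binary column per class
...         "bark": [0, 1],
...     }
... )
>>> datamodule = AudioClassificationData.from_data_frame(
...     "images",
...     ["meow", "bark"],
...     train_data_frame=multilabel_df,
...     train_images_root="train_folder",
...     batch_size=2,
... )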
- classmethod from_files(train_files=None, train_targets=None, val_files=None, val_targets=None, test_files=None, test_targets=None, predict_files=None, sampling_rate=16000, n_fft=400, input_cls=<class 'flash.audio.classification.input.AudioClassificationFilesInput'>, transform=<class 'flash.audio.classification.input_transform.AudioClassificationInputTransform'>, transform_kwargs=None, target_formatter=None, **data_module_kwargs)
Load the AudioClassificationData from lists of files and corresponding lists of targets.

The supported file extensions for precomputed spectrograms are: .jpg, .jpeg, .png, .ppm, .bmp, .pgm, .tif, .tiff, .webp, and .npy. The supported file extensions for raw audio (where spectrograms will be computed automatically) are: .aiff, .au, .avr, .caf, .flac, .mat, .mat4, .mat5, .mpc2k, .ogg, .paf, .pvf, .rf64, .ircam, .voc, .w64, .wav, .nist, and .wavex (a raw-audio sketch follows the Examples below). The targets can be in any of our supported classification target formats. To learn how to customize the transforms applied for each stage, read our customizing transforms guide.

- Parameters
  - train_files (Optional[Sequence[str]]) – The list of spectrogram image files to use when training.
  - train_targets (Optional[Sequence[Any]]) – The list of targets to use when training.
  - val_files (Optional[Sequence[str]]) – The list of spectrogram image files to use when validating.
  - val_targets (Optional[Sequence[Any]]) – The list of targets to use when validating.
  - test_files (Optional[Sequence[str]]) – The list of spectrogram image files to use when testing.
  - test_targets (Optional[Sequence[Any]]) – The list of targets to use when testing.
  - predict_files (Optional[Sequence[str]]) – The list of spectrogram image files to use when predicting.
  - sampling_rate (int) – Sampling rate to use when loading raw audio files.
  - n_fft (int) – The size of the FFT to use when creating spectrograms from raw audio.
  - target_formatter (Optional[TargetFormatter]) – Optionally provide a TargetFormatter to control how targets are handled. See Formatting Classification Targets for more details.
  - input_cls (Type[Input]) – The Input type to use for loading the data.
  - transform (TypeVar(INPUT_TRANSFORM_TYPE, Type[flash.core.data.io.input_transform.InputTransform], Callable, Tuple[Union[StrEnum, str], Dict[str, Any]], Union[StrEnum, str], None)) – The InputTransform type to use.
  - transform_kwargs (Optional[Dict]) – Dict of keyword arguments to be provided when instantiating the transforms.
  - data_module_kwargs (Any) – Additional keyword arguments to provide to the DataModule constructor.
- Return type
  AudioClassificationData
- Returns
  The constructed AudioClassificationData.
Examples
>>> from flash import Trainer
>>> from flash.audio import AudioClassificationData
>>> from flash.image import ImageClassifier
>>> datamodule = AudioClassificationData.from_files(
...     train_files=["spectrogram_1.png", "spectrogram_2.png", "spectrogram_3.png"],
...     train_targets=["meow", "bark", "meow"],
...     predict_files=[
...         "predict_spectrogram_1.png",
...         "predict_spectrogram_2.png",
...         "predict_spectrogram_3.png",
...     ],
...     transform_kwargs=dict(spectrogram_size=(128, 128)),
...     batch_size=2,
... )
>>> datamodule.num_classes
2
>>> datamodule.labels
['bark', 'meow']
>>> model = ImageClassifier(backbone="resnet18", num_classes=datamodule.num_classes)
>>> trainer = Trainer(fast_dev_run=True)
>>> trainer.fit(model, datamodule=datamodule)
Training...
>>> trainer.predict(model, datamodule=datamodule)
Predicting...
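Since raw audio extensions are also supported, audio files can be passed directly and the spectrograms computed with the given sampling_rate and n_fft. A minimal sketch; the .wav file names are placeholders:

>>> datamodule = AudioClassificationData.from_files(
...     train_files=["call_1.wav", "call_2.wav", "call_3.wav"],  # placeholder raw audio files
...     train_targets=["meow", "bark", "meow"],
...     sampling_rate=16000,  # resample raw audio to 16 kHz on load
...     n_fft=400,  # FFT size used when computing the spectrograms
...     transform_kwargs=dict(spectrogram_size=(128, 128)),
...     batch_size=2,
... )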
- classmethod from_folders(train_folder=None, val_folder=None, test_folder=None, predict_folder=None, sampling_rate=16000, n_fft=400, input_cls=<class 'flash.audio.classification.input.AudioClassificationFolderInput'>, transform=<class 'flash.audio.classification.input_transform.AudioClassificationInputTransform'>, transform_kwargs=None, target_formatter=None, **data_module_kwargs)
Load the AudioClassificationData from folders containing spectrogram images.

The supported file extensions for precomputed spectrograms are: .jpg, .jpeg, .png, .ppm, .bmp, .pgm, .tif, .tiff, .webp, and .npy. The supported file extensions for raw audio (where spectrograms will be computed automatically) are: .aiff, .au, .avr, .caf, .flac, .mat, .mat4, .mat5, .mpc2k, .ogg, .paf, .pvf, .rf64, .ircam, .voc, .w64, .wav, .nist, and .wavex. For train, test, and validation data, the folders are expected to contain a sub-folder for each class (a sketch using val_folder and test_folder follows the Examples below). Here’s the required structure:

train_folder
├── meow
│   ├── spectrogram_1.png
│   ├── spectrogram_3.png
│   ...
└── bark
    ├── spectrogram_2.png
    ...
For prediction, the folder is expected to contain the files for inference, like this:

predict_folder
├── predict_spectrogram_1.png
├── predict_spectrogram_2.png
├── predict_spectrogram_3.png
...
To learn how to customize the transforms applied for each stage, read our customizing transforms guide.
- Parameters
  - train_folder (Optional[str]) – The folder containing spectrogram images to use when training.
  - val_folder (Optional[str]) – The folder containing spectrogram images to use when validating.
  - test_folder (Optional[str]) – The folder containing spectrogram images to use when testing.
  - predict_folder (Optional[str]) – The folder containing spectrogram images to use when predicting.
  - sampling_rate (int) – Sampling rate to use when loading raw audio files.
  - n_fft (int) – The size of the FFT to use when creating spectrograms from raw audio.
  - target_formatter (Optional[TargetFormatter]) – Optionally provide a TargetFormatter to control how targets are handled. See Formatting Classification Targets for more details.
  - input_cls (Type[Input]) – The Input type to use for loading the data.
  - transform (TypeVar(INPUT_TRANSFORM_TYPE, Type[flash.core.data.io.input_transform.InputTransform], Callable, Tuple[Union[StrEnum, str], Dict[str, Any]], Union[StrEnum, str], None)) – The InputTransform type to use.
  - transform_kwargs (Optional[Dict]) – Dict of keyword arguments to be provided when instantiating the transforms.
  - data_module_kwargs (Any) – Additional keyword arguments to provide to the DataModule constructor.
- Return type
  AudioClassificationData
- Returns
  The constructed AudioClassificationData.
Examples
>>> from flash import Trainer
>>> from flash.audio import AudioClassificationData
>>> from flash.image import ImageClassifier
>>> datamodule = AudioClassificationData.from_folders(
...     train_folder="train_folder",
...     predict_folder="predict_folder",
...     transform_kwargs=dict(spectrogram_size=(128, 128)),
...     batch_size=2,
... )
>>> datamodule.num_classes
2
>>> datamodule.labels
['bark', 'meow']
>>> model = ImageClassifier(backbone="resnet18", num_classes=datamodule.num_classes)
>>> trainer = Trainer(fast_dev_run=True)
>>> trainer.fit(model, datamodule=datamodule)
Training...
>>> trainer.predict(model, datamodule=datamodule)
Predicting...
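The val_folder and test_folder arguments expect the same class-sub-folder layout as train_folder. A minimal sketch, assuming val_folder and test_folder exist on disk with that structure:

>>> datamodule = AudioClassificationData.from_folders(
...     train_folder="train_folder",
...     val_folder="val_folder",  # assumed to contain meow/ and bark/ sub-folders
...     test_folder="test_folder",  # same layout as the train folder
...     transform_kwargs=dict(spectrogram_size=(128, 128)),
...     batch_size=2,
... )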
- classmethod from_numpy(train_data=None, train_targets=None, val_data=None, val_targets=None, test_data=None, test_targets=None, predict_data=None, input_cls=<class 'flash.audio.classification.input.AudioClassificationNumpyInput'>, transform=<class 'flash.audio.classification.input_transform.AudioClassificationInputTransform'>, transform_kwargs=None, target_formatter=None, **data_module_kwargs)
Load the AudioClassificationData from numpy arrays (or lists of arrays) and corresponding lists of targets.

The targets can be in any of our supported classification target formats. To learn how to customize the transforms applied for each stage, read our customizing transforms guide. A sketch for computing spectrogram arrays from audio files follows the Examples below.
- Parameters
  - train_data (Optional[Collection[ndarray]]) – The numpy array or list of arrays to use when training.
  - train_targets (Optional[Collection[Any]]) – The list of targets to use when training.
  - val_data (Optional[Collection[ndarray]]) – The numpy array or list of arrays to use when validating.
  - val_targets (Optional[Sequence[Any]]) – The list of targets to use when validating.
  - test_data (Optional[Collection[ndarray]]) – The numpy array or list of arrays to use when testing.
  - test_targets (Optional[Sequence[Any]]) – The list of targets to use when testing.
  - predict_data (Optional[Collection[ndarray]]) – The numpy array or list of arrays to use when predicting.
  - target_formatter (Optional[TargetFormatter]) – Optionally provide a TargetFormatter to control how targets are handled. See Formatting Classification Targets for more details.
  - input_cls (Type[Input]) – The Input type to use for loading the data.
  - transform (TypeVar(INPUT_TRANSFORM_TYPE, Type[flash.core.data.io.input_transform.InputTransform], Callable, Tuple[Union[StrEnum, str], Dict[str, Any]], Union[StrEnum, str], None)) – The InputTransform type to use.
  - transform_kwargs (Optional[Dict]) – Dict of keyword arguments to be provided when instantiating the transforms.
  - data_module_kwargs (Any) – Additional keyword arguments to provide to the DataModule constructor.
- Return type
  AudioClassificationData
- Returns
  The constructed AudioClassificationData.
Examples
>>> import numpy as np
>>> from flash import Trainer
>>> from flash.audio import AudioClassificationData
>>> from flash.image import ImageClassifier
>>> datamodule = AudioClassificationData.from_numpy(
...     train_data=[np.random.rand(3, 64, 64), np.random.rand(3, 64, 64), np.random.rand(3, 64, 64)],
...     train_targets=["meow", "bark", "meow"],
...     predict_data=[np.random.rand(3, 64, 64)],
...     transform_kwargs=dict(spectrogram_size=(128, 128)),
...     batch_size=2,
... )
>>> datamodule.num_classes
2
>>> datamodule.labels
['bark', 'meow']
>>> model = ImageClassifier(backbone="resnet18", num_classes=datamodule.num_classes)
>>> trainer = Trainer(fast_dev_run=True)
>>> trainer.fit(model, datamodule=datamodule)
Training...
>>> trainer.predict(model, datamodule=datamodule)
Predicting...
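To produce spectrogram arrays like these from audio on disk, a mel spectrogram can be computed up front. A minimal sketch, assuming librosa is installed (it is not required by this API) and using a placeholder file name:

>>> import numpy as np
>>> import librosa  # assumed dependency, not required by Flash
>>> y, sr = librosa.load("call_1.wav", sr=16000)  # placeholder file, resampled to 16 kHz
>>> mel = librosa.feature.melspectrogram(y=y, sr=sr, n_fft=400)  # (n_mels, time) array
>>> mel = np.stack([mel] * 3)  # match the (channels, height, width) layout used above
>>> datamodule = AudioClassificationData.from_numpy(
...     train_data=[mel],
...     train_targets=["meow"],
...     transform_kwargs=dict(spectrogram_size=(128, 128)),
...     batch_size=1,
... )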
- classmethod from_tensors(train_data=None, train_targets=None, val_data=None, val_targets=None, test_data=None, test_targets=None, predict_data=None, input_cls=<class 'flash.audio.classification.input.AudioClassificationTensorInput'>, transform=<class 'flash.audio.classification.input_transform.AudioClassificationInputTransform'>, transform_kwargs=None, target_formatter=None, **data_module_kwargs)
Load the AudioClassificationData from torch tensors (or lists of tensors) and corresponding lists of targets.

The targets can be in any of our supported classification target formats. To learn how to customize the transforms applied for each stage, read our customizing transforms guide. A sketch for computing spectrogram tensors from audio files follows the Examples below.
- Parameters
  - train_data (Optional[Collection[Tensor]]) – The torch tensor or list of tensors to use when training.
  - train_targets (Optional[Collection[Any]]) – The list of targets to use when training.
  - val_data (Optional[Collection[Tensor]]) – The torch tensor or list of tensors to use when validating.
  - val_targets (Optional[Sequence[Any]]) – The list of targets to use when validating.
  - test_data (Optional[Collection[Tensor]]) – The torch tensor or list of tensors to use when testing.
  - test_targets (Optional[Sequence[Any]]) – The list of targets to use when testing.
  - predict_data (Optional[Collection[Tensor]]) – The torch tensor or list of tensors to use when predicting.
  - target_formatter (Optional[TargetFormatter]) – Optionally provide a TargetFormatter to control how targets are handled. See Formatting Classification Targets for more details.
  - input_cls (Type[Input]) – The Input type to use for loading the data.
  - transform (TypeVar(INPUT_TRANSFORM_TYPE, Type[flash.core.data.io.input_transform.InputTransform], Callable, Tuple[Union[StrEnum, str], Dict[str, Any]], Union[StrEnum, str], None)) – The InputTransform type to use.
  - transform_kwargs (Optional[Dict]) – Dict of keyword arguments to be provided when instantiating the transforms.
  - data_module_kwargs (Any) – Additional keyword arguments to provide to the DataModule constructor.
- Return type
  AudioClassificationData
- Returns
  The constructed AudioClassificationData.
Examples
>>> import torch
>>> from flash import Trainer
>>> from flash.audio import AudioClassificationData
>>> from flash.image import ImageClassifier
>>> datamodule = AudioClassificationData.from_tensors(
...     train_data=[torch.rand(3, 64, 64), torch.rand(3, 64, 64), torch.rand(3, 64, 64)],
...     train_targets=["meow", "bark", "meow"],
...     predict_data=[torch.rand(3, 64, 64)],
...     transform_kwargs=dict(spectrogram_size=(128, 128)),
...     batch_size=2,
... )
>>> datamodule.num_classes
2
>>> datamodule.labels
['bark', 'meow']
>>> model = ImageClassifier(backbone="resnet18", num_classes=datamodule.num_classes)
>>> trainer = Trainer(fast_dev_run=True)
>>> trainer.fit(model, datamodule=datamodule)
Training...
>>> trainer.predict(model, datamodule=datamodule)
Predicting...
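Similarly, spectrogram tensors can be prepared from audio files with a mel spectrogram transform. A minimal sketch, assuming torchaudio is installed (not required by this API), a mono input file, and a placeholder file name:

>>> import torchaudio  # assumed dependency, not required by Flash
>>> waveform, sr = torchaudio.load("call_1.wav")  # (channels, samples), placeholder file
>>> mel = torchaudio.transforms.MelSpectrogram(sample_rate=sr, n_fft=400)(waveform)
>>> mel = mel.expand(3, -1, -1)  # mono (1, n_mels, time) -> (3, height, width) as used above
>>> datamodule = AudioClassificationData.from_tensors(
...     train_data=[mel],
...     train_targets=["meow"],
...     transform_kwargs=dict(spectrogram_size=(128, 128)),
...     batch_size=1,
... )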
- input_transform_cls
  alias of flash.audio.classification.input_transform.AudioClassificationInputTransform