
ImageClassifier

class flash.image.classification.model.ImageClassifier(num_classes=None, labels=None, backbone='resnet18', backbone_kwargs=None, head='linear', pretrained=True, loss_fn=None, optimizer='Adam', lr_scheduler=None, metrics=None, learning_rate=None, multi_label=False, training_strategy='default', training_strategy_kwargs=None)

The ImageClassifier is a Task for classifying images. For more details, see Image Classification. The ImageClassifier also supports multi-label classification with multi_label=True. For more details, see Multi-label Image Classification.

You can register custom backbones to use with the ImageClassifier:

from torch import nn
import torchvision
from flash.image import ImageClassifier

# Register a new backbone so it can be selected by name, e.g. ImageClassifier(backbone="custom_resnet18").
# A distinct name avoids any clash with the built-in "resnet18" entry.
@ImageClassifier.backbones(name="custom_resnet18")
def fn_resnet(pretrained: bool = True):
    model = torchvision.models.resnet18(pretrained=pretrained)
    # drop the final avgpool and fc layers & turn the rest into a Sequential model
    backbone = nn.Sequential(*list(model.children())[:-2])
    # channel count of the backbone's output feature map (512 for resnet18)
    num_features = model.fc.in_features
    # backbones need to return num_features so the head can be built
    return backbone, num_features
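
The decorator above stores fn_resnet under a name so the backbone can later be requested by string. A minimal sketch of that name-to-factory registry pattern in plain Python follows; SimpleRegistry and its methods are hypothetical and illustrate the idea only, not Flash's actual registry implementation.

```python
from typing import Callable, Dict, List, Optional


class SimpleRegistry:
    """Hypothetical name -> factory registry (illustration only)."""

    def __init__(self, name: str) -> None:
        self.name = name
        self._fns: Dict[str, Callable] = {}

    def __call__(self, fn: Optional[Callable] = None, name: Optional[str] = None):
        # Support both ``@registry`` and ``@registry(name="...")`` usage.
        def register(f: Callable) -> Callable:
            key = name or f.__name__
            if key in self._fns:
                raise KeyError(f"{key!r} is already registered in {self.name}")
            self._fns[key] = f
            return f

        return register(fn) if fn is not None else register

    def get(self, key: str) -> Callable:
        # Look a factory up by the string a user passes, e.g. backbone="...".
        return self._fns[key]

    def available_keys(self) -> List[str]:
        # List registered names, much like the available_* classmethods.
        return sorted(self._fns)


backbones = SimpleRegistry("backbones")


@backbones(name="toy_backbone")
def toy_backbone(pretrained: bool = True):
    # A real backbone returns (nn.Module, num_features); a string stands in here.
    return "toy-model", 512


model, num_features = backbones.get("toy_backbone")(pretrained=False)
```

Raising on a duplicate name is a design choice in this sketch that keeps string lookups unambiguous.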

Parameters
  • num_classes (Optional[int]) – Number of classes to classify.

  • backbone (Union[str, Tuple[Module, int]]) – A string or (model, num_features) tuple to use to compute image features, defaults to "resnet18".

  • head (Union[str, function, Module]) – A string from ImageClassifier.available_heads(), an nn.Module, or a function of (num_features, num_classes) which returns an nn.Module to use as the model head, defaults to "linear".

  • pretrained (Union[bool, str]) – A bool or string to specify the pretrained weights of the backbone, defaults to True which loads the default supervised pretrained weights.

  • loss_fn (Optional[TypeVar(LOSS_FN_TYPE, Callable, Mapping, Sequence, None)]) – Loss function for training, defaults to torch.nn.functional.cross_entropy().

  • optimizer (TypeVar(OPTIMIZER_TYPE, str, Callable, Tuple[str, Dict[str, Any]], None)) – Optimizer to use for training, defaults to "Adam".

  • lr_scheduler (Optional[TypeVar(LR_SCHEDULER_TYPE, str, Callable, Tuple[str, Dict[str, Any]], Tuple[str, Dict[str, Any], Dict[str, Any]], None)]) – The LR scheduler to use during training.

  • metrics (Optional[TypeVar(METRICS_TYPE, Metric, Mapping, Sequence, None)]) – Metrics to compute for training and evaluation. Can be a metric from the torchmetrics package, a custom metric inheriting from torchmetrics.Metric, a callable function, or a list/dict containing a combination of these. In all cases, each metric must have the signature metric(preds, target) and return a single scalar tensor. Defaults to torchmetrics.Accuracy.

  • learning_rate (Optional[float]) – Learning rate to use for training, defaults to 1e-3.

  • multi_label (bool) – Whether the targets are multi-label or not, defaults to False.

  • training_strategy (Optional[str]) – String indicating the training strategy, defaults to "default". Change it to use learn2learn for meta-learning research.

  • training_strategy_kwargs (Optional[Dict[str, Any]]) – Additional kwargs for setting the training strategy
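
The head parameter accepts a function of (num_features, num_classes) that returns an nn.Module. A minimal sketch of such a function follows; the name dropout_linear_head and the dropout rate are illustrative choices, and the sketch assumes the head receives pooled features of shape (batch_size, num_features).

```python
import torch
from torch import nn


def dropout_linear_head(num_features: int, num_classes: int) -> nn.Module:
    # Hypothetical head factory: regularize pooled features, then map to class logits.
    return nn.Sequential(
        nn.Dropout(p=0.2),
        nn.Linear(num_features, num_classes),
    )


head = dropout_linear_head(num_features=512, num_classes=10)
head.eval()  # disable dropout so the forward pass is deterministic
features = torch.randn(4, 512)  # assumed pooled backbone output: (batch, num_features)
logits = head(features)
```

Such a function would be passed as ImageClassifier(num_classes=10, head=dropout_linear_head).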
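
The optimizer and lr_scheduler types accept a plain string, a callable, or a (name, kwargs) tuple, and schedulers additionally accept a (name, kwargs, extra_kwargs) triple. A minimal pure-Python sketch that normalizes these accepted forms into one shape follows; normalize_spec is hypothetical and only illustrates the type signatures listed above, not how Flash resolves them internally.

```python
from typing import Any, Callable, Dict, Tuple, Union

Spec = Union[
    str,
    Callable,
    Tuple[str, Dict[str, Any]],
    Tuple[str, Dict[str, Any], Dict[str, Any]],
]


def normalize_spec(spec: Spec) -> Tuple[Union[str, Callable], Dict[str, Any], Dict[str, Any]]:
    """Hypothetical helper: return (name_or_callable, kwargs, extra_kwargs)."""
    if isinstance(spec, str) or callable(spec):
        # A bare name such as "Adam", or a factory function: no extra arguments.
        return spec, {}, {}
    if isinstance(spec, tuple) and len(spec) == 2:
        name, kwargs = spec
        return name, dict(kwargs), {}
    if isinstance(spec, tuple) and len(spec) == 3:
        # Scheduler form; the third dict is passed through unchanged here.
        name, kwargs, extra = spec
        return name, dict(kwargs), dict(extra)
    raise TypeError(f"Unsupported spec: {spec!r}")


optimizer_spec = normalize_spec(("SGD", {"momentum": 0.9}))
```

Whether a given name such as "SGD" is valid in a particular install can be checked with ImageClassifier.available_optimizers() or ImageClassifier.available_lr_schedulers().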
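
Each metric must have the signature metric(preds, target) and return a single scalar tensor. A minimal sketch of a plain callable that satisfies this contract follows, assuming preds holds logits of shape (batch_size, num_classes) and target holds integer class indices; top1_accuracy is a hypothetical name.

```python
import torch


def top1_accuracy(preds: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    # Pick the highest-scoring class per row and compare with the integer targets.
    predicted = preds.argmax(dim=1)
    # Mean of the boolean match vector, returned as a 0-dim (scalar) tensor.
    return (predicted == target).float().mean()


preds = torch.tensor([[2.0, 0.1, 0.3],
                      [0.2, 1.5, 0.1],
                      [0.9, 0.4, 0.8]])
target = torch.tensor([0, 1, 2])
score = top1_accuracy(preds, target)  # rows 0 and 1 match, row 2 does not
```

Unlike a torchmetrics.Metric, a plain callable like this keeps no state across batches.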
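
With multi_label=False each image has exactly one class, which suits the default torch.nn.functional.cross_entropy() over integer targets; with multi_label=True an image may carry several labels at once. The sketch below contrasts the two target encodings. It uses binary_cross_entropy_with_logits and a 0.5 threshold as the common choices for multi-label targets; both are assumptions here, not a statement of what ImageClassifier selects internally.

```python
import torch
import torch.nn.functional as F

logits = torch.tensor([[2.0, -1.0, 0.5],
                       [-0.5, 1.0, 1.5]])  # (batch=2, num_classes=3)

# Single-label: one integer class index per image; softmax across classes.
single_targets = torch.tensor([0, 2])
single_loss = F.cross_entropy(logits, single_targets)

# Multi-label: a multi-hot float vector per image; an independent sigmoid per class.
multi_targets = torch.tensor([[1.0, 0.0, 1.0],
                              [0.0, 1.0, 1.0]])
multi_loss = F.binary_cross_entropy_with_logits(logits, multi_targets)

# Predictions differ accordingly: argmax vs. a per-class threshold (assumed 0.5).
single_preds = logits.argmax(dim=1)
multi_preds = (torch.sigmoid(logits) > 0.5).int()
```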

classmethod available_finetuning_strategies(cls)

Returns a list containing the keys of the available Finetuning Strategies.

Return type

List[str]

classmethod available_lr_schedulers(cls)

Returns a list containing the keys of the available LR schedulers.

Return type

List[str]

classmethod available_optimizers(cls)

Returns a list containing the keys of the available Optimizers.

Return type

List[str]

classmethod available_outputs(cls)

Returns the list of available outputs (that can be used during prediction or serving) for this Task.

Examples

>>> from flash import Task
>>> print(Task.available_outputs())
['preds', 'raw']

Return type

List[str]

Read the Docs v: 0.8.1.post0