ImageClassifier

class flash.image.classification.model.ImageClassifier(num_classes=None, backbone='resnet18', backbone_kwargs=None, head=None, pretrained=True, loss_fn=None, optimizer='Adam', lr_scheduler=None, metrics=None, learning_rate=0.001, multi_label=False, output=None, training_strategy='default', training_strategy_kwargs=None)

The ImageClassifier is a Task for classifying images. For more details, see Image Classification. The ImageClassifier also supports multi-label classification with multi_label=True. For more details, see Multi-label Image Classification.

You can register custom backbones to use with the ImageClassifier:

from torch import nn
import torchvision
from flash.image import ImageClassifier

# Registering a backbone makes it available by name from `ImageClassifier`
@ImageClassifier.backbones(name="resnet18")
def fn_resnet(pretrained: bool = True):
    model = torchvision.models.resnet18(pretrained=pretrained)
    # Drop the last two layers (pooling and fc) and wrap the rest in a Sequential model
    backbone = nn.Sequential(*list(model.children())[:-2])
    # Input width of the original fc layer, i.e. the backbone's feature size
    num_features = model.fc.in_features
    # Backbones must also return num_features so that the head can be built
    return backbone, num_features
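
The decorator-based registration above can be illustrated with a dependency-free sketch. `BackboneRegistry`, `toy_backbone`, and the placeholder return value are hypothetical names introduced for illustration only; this is not Flash's registry implementation, just one plausible shape for a name-to-factory lookup whose factories return a `(backbone, num_features)` pair.

```python
from typing import Any, Callable, Dict, Tuple


class BackboneRegistry:
    """Hypothetical name -> factory registry (an illustration, not Flash's class)."""

    def __init__(self) -> None:
        self._factories: Dict[str, Callable[..., Tuple[Any, int]]] = {}

    def __call__(self, name: str) -> Callable:
        # Calling the registry with a name returns a decorator that
        # stores the decorated function under that name.
        def decorator(fn: Callable[..., Tuple[Any, int]]) -> Callable[..., Tuple[Any, int]]:
            self._factories[name] = fn
            return fn

        return decorator

    def get(self, name: str) -> Callable[..., Tuple[Any, int]]:
        # Look up a previously registered factory by name.
        return self._factories[name]


backbones = BackboneRegistry()


@backbones(name="toy_backbone")
def toy_backbone(pretrained: bool = True) -> Tuple[Any, int]:
    # A real factory would return an nn.Module; a string placeholder
    # keeps this sketch free of third-party dependencies.
    return "toy-feature-extractor", 512


# Resolve the factory by name, then call it to get the pair.
backbone, num_features = backbones.get("toy_backbone")(pretrained=False)
```

Returning the feature size alongside the model lets whatever consumes the pair build a matching head without inspecting the backbone's layers.
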

Parameters
  • num_classes (Optional[int]) – Number of classes to classify.

  • backbone (Union[str, Tuple[Module, int]]) – A string or (model, num_features) tuple to use to compute image features, defaults to "resnet18".

  • pretrained (Union[bool, str]) – A bool or string to specify the pretrained weights of the backbone, defaults to True, which loads the default supervised pretrained weights.

  • loss_fn (Optional[~LOSS_FN_TYPE]) – Loss function for training, defaults to torch.nn.functional.cross_entropy().

  • optimizer (~OPTIMIZER_TYPE) – Optimizer to use for training.

  • lr_scheduler (Optional[~LR_SCHEDULER_TYPE]) – The LR scheduler to use during training.

  • metrics (Optional[~METRICS_TYPE]) – Metrics to compute for training and evaluation. Can be a metric from the torchmetrics package, a custom metric inheriting from torchmetrics.Metric, a callable function, or a list/dict containing a combination of these. In all cases, each metric needs to have the signature metric(preds, target) and return a single scalar tensor. Defaults to torchmetrics.Accuracy.

  • learning_rate (float) – Learning rate to use for training, defaults to 1e-3.

  • multi_label (bool) – Whether the targets are multi-label or not.

  • output (Optional[~OUTPUT_TYPE]) – The Output to use when formatting prediction outputs.

  • training_strategy (Optional[str]) – String indicating the training strategy. Adjust this if you want to use learn2learn for meta-learning research.

  • training_strategy_kwargs (Optional[Dict[str, Any]]) – Additional kwargs for setting the training strategy.
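
The metrics requirement in the list above (a callable with the signature metric(preds, target) that returns a single scalar tensor) can be sketched as follows. The function name `top1_accuracy` and the sample tensors are illustrative assumptions, and the sketch assumes single-label targets with `preds` of shape (batch, num_classes); it is not Flash's built-in metric.

```python
import torch


def top1_accuracy(preds: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """Fraction of samples whose highest-scoring class equals the target."""
    predicted = preds.argmax(dim=-1)
    # The mean of a 1-D tensor is a 0-dimensional (scalar) tensor.
    return (predicted == target).float().mean()


# Four samples, three classes; the third sample is misclassified.
scores = torch.tensor(
    [
        [2.0, 0.1, 0.3],
        [0.2, 0.1, 3.0],
        [0.5, 1.5, 0.1],
        [1.0, 0.2, 0.1],
    ]
)
target = torch.tensor([0, 2, 0, 0])

score = top1_accuracy(scores, target)
```

Based on the parameter description, a function of this form could presumably be passed as metrics=top1_accuracy when constructing the ImageClassifier; with multi_label=True the targets differ, so an argmax-based metric like this one would not apply.
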
