ImageClassifier
- class flash.image.classification.model.ImageClassifier(num_classes=None, labels=None, backbone='resnet18', backbone_kwargs=None, head='linear', pretrained=True, loss_fn=None, optimizer='Adam', lr_scheduler=None, metrics=None, learning_rate=None, multi_label=False, training_strategy='default', training_strategy_kwargs=None)
The ImageClassifier is a Task for classifying images. For more details, see Image Classification. The ImageClassifier also supports multi-label classification with multi_label=True. For more details, see Multi-label Image Classification.

You can register custom backbones to use with the ImageClassifier:

```python
from torch import nn
import torchvision
from flash.image import ImageClassifier


# Register a new backbone so that it is accessible from `ImageClassifier`
@ImageClassifier.backbones(name="resnet18")
def fn_resnet(pretrained: bool = True):
    model = torchvision.models.resnet18(pretrained)
    # remove the last two layers & turn it into a Sequential model
    backbone = nn.Sequential(*list(model.children())[:-2])
    num_features = model.fc.in_features
    # backbones need to return the num_features to build the head
    return backbone, num_features
```
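To make the registration pattern above more concrete, here is a minimal, dependency-free sketch of a name-keyed registry used as a decorator. `MiniRegistry`, `fn_tiny`, and the placeholder backbone are hypothetical stand-ins for illustration; this is not Flash's registry implementation. It only shows the contract the example relies on: a function is stored under a name, and calling it returns a `(backbone, num_features)` pair.

```python
from typing import Callable, Dict, List, Optional


class MiniRegistry:
    """Hypothetical stand-in for a name-keyed function registry (not Flash's class)."""

    def __init__(self) -> None:
        self._functions: Dict[str, Callable] = {}

    def __call__(self, fn: Optional[Callable] = None, *, name: Optional[str] = None):
        # Supports both `@registry(name="...")` and `registry(fn, name="...")`.
        def register(func: Callable) -> Callable:
            self._functions[name or func.__name__] = func
            return func

        return register(fn) if fn is not None else register

    def get(self, key: str) -> Callable:
        return self._functions[key]

    def available_keys(self) -> List[str]:
        return sorted(self._functions)


backbones = MiniRegistry()


@backbones(name="tiny")
def fn_tiny(pretrained: bool = True):
    # A real backbone function would return an nn.Module here; a string
    # placeholder keeps this sketch free of third-party dependencies.
    num_features = 16
    return "tiny-backbone-placeholder", num_features


# Look the function up by name and unpack the (backbone, num_features) pair.
backbone, num_features = backbones.get("tiny")(pretrained=False)
```

The second element of the pair is what a head needs in order to size its input layer, which is why the registered function has to return it alongside the module.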
- Parameters
  - num_classes (Optional[int]) – Number of classes to classify.
  - backbone (Union[str, Tuple[Module, int]]) – A string or (model, num_features) tuple to use to compute image features, defaults to "resnet18".
  - head (Union[str, function, Module]) – A string from ImageClassifier.available_heads(), an nn.Module, or a function of (num_features, num_classes) which returns an nn.Module to use as the model head.
  - pretrained (Union[bool, str]) – A bool or string to specify the pretrained weights of the backbone, defaults to True which loads the default supervised pretrained weights.
  - loss_fn (Optional[TypeVar(LOSS_FN_TYPE, Callable, Mapping, Sequence, None)]) – Loss function for training, defaults to torch.nn.functional.cross_entropy().
  - optimizer (TypeVar(OPTIMIZER_TYPE, str, Callable, Tuple[str, Dict[str, Any]], None)) – Optimizer to use for training.
  - lr_scheduler (Optional[TypeVar(LR_SCHEDULER_TYPE, str, Callable, Tuple[str, Dict[str, Any]], Tuple[str, Dict[str, Any], Dict[str, Any]], None)]) – The LR scheduler to use during training.
  - metrics (Optional[TypeVar(METRICS_TYPE, Metric, Mapping, Sequence, None)]) – Metrics to compute for training and evaluation. Can either be a metric from the torchmetrics package, a custom metric inheriting from torchmetrics.Metric, a callable function, or a list/dict containing a combination of the aforementioned. In all cases, each metric needs to have the signature metric(preds, target) and return a single scalar tensor. Defaults to torchmetrics.Accuracy.
  - learning_rate (Optional[float]) – Learning rate to use for training, defaults to 1e-3.
  - multi_label (bool) – Whether the targets are multi-label or not.
  - training_strategy (Optional[str]) – String indicating the training strategy. Adjust if you want to use learn2learn for meta-learning research.
  - training_strategy_kwargs (Optional[Dict[str, Any]]) – Additional kwargs for setting the training strategy.
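As a rough illustration of how these parameters fit together, the sketch below assembles constructor arguments, including a head given as a function of (num_features, num_classes). The names `mlp_head`, `classifier_kwargs`, and `build_classifier`, the hidden width of 128, and the specific values chosen are assumptions for the example; only the parameter names and the `flash.image` import come from the documentation above. Imports of torch and Flash are deferred into the functions so the argument dictionary can be inspected without either package installed.

```python
from typing import Any, Dict


def mlp_head(num_features: int, num_classes: int):
    """Hypothetical custom head: a function of (num_features, num_classes) returning an nn.Module."""
    from torch import nn  # imported lazily; requires PyTorch when the head is built

    return nn.Sequential(
        nn.Linear(num_features, 128),  # hidden width of 128 is an arbitrary choice
        nn.ReLU(),
        nn.Linear(128, num_classes),
    )


# Keyword arguments named after the documented parameters; values are illustrative.
classifier_kwargs: Dict[str, Any] = {
    "num_classes": 10,
    "backbone": "resnet18",
    "head": mlp_head,
    "pretrained": False,
    "optimizer": "Adam",
    "learning_rate": 1e-3,
    "multi_label": False,
}


def build_classifier(**overrides: Any):
    """Construct the task, assuming Lightning Flash and its image extras are installed."""
    from flash.image import ImageClassifier  # imported lazily

    return ImageClassifier(**{**classifier_kwargs, **overrides})
```

Calling `build_classifier(multi_label=True)` would, under these assumptions, create a multi-label variant while reusing the remaining settings.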
- classmethod available_finetuning_strategies(cls)
Returns a list containing the keys of the available Finetuning Strategies.
- classmethod available_lr_schedulers(cls)
Returns a list containing the keys of the available LR schedulers.
- classmethod available_optimizers(cls)
Returns a list containing the keys of the available Optimizers.
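The three class methods above can be queried before constructing the task, for example to validate a user-supplied option. The following is a minimal sketch: `list_available_options` and `pick_option` are hypothetical helpers, the case-insensitive matching is an assumption about how keys might be compared, and the sample key list is made up rather than real output from Flash.

```python
from typing import Dict, List, Sequence


def list_available_options() -> Dict[str, List[str]]:
    """Collect the keys reported by the class methods (requires Lightning Flash)."""
    from flash.image import ImageClassifier  # imported lazily

    return {
        "finetuning_strategies": ImageClassifier.available_finetuning_strategies(),
        "lr_schedulers": ImageClassifier.available_lr_schedulers(),
        "optimizers": ImageClassifier.available_optimizers(),
    }


def pick_option(requested: str, available: Sequence[str], default: str) -> str:
    """Hypothetical helper: return the listed key matching `requested`, else `default`."""
    # Case-insensitive matching is an assumption made for this sketch.
    lookup = {key.lower(): key for key in available}
    return lookup.get(requested.lower(), default)


# Made-up key list for illustration; call list_available_options() for real values.
sample_optimizers = ["adam", "sgd"]
chosen = pick_option("Adam", sample_optimizers, default="sgd")
fallback = pick_option("lion", sample_optimizers, default="sgd")
```

Checking against the reported keys first gives a clearer error path than passing an unknown string straight to the constructor.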