ImageClassifier
- class flash.image.classification.model.ImageClassifier(num_classes=None, labels=None, backbone='resnet18', backbone_kwargs=None, head='linear', pretrained=True, loss_fn=None, optimizer='Adam', lr_scheduler=None, metrics=None, learning_rate=None, multi_label=False, training_strategy='default', training_strategy_kwargs=None)
The ImageClassifier is a Task for classifying images. For more details, see Image Classification. The ImageClassifier also supports multi-label classification with multi_label=True. For more details, see Multi-label Image Classification.

You can register custom backbones to use with the ImageClassifier:

    from torch import nn
    import torchvision
    from flash.image import ImageClassifier

    # This is useful to create a new backbone and make it accessible from `ImageClassifier`
    @ImageClassifier.backbones(name="resnet18")
    def fn_resnet(pretrained: bool = True):
        model = torchvision.models.resnet18(pretrained)
        # remove the last two layers & turn it into a Sequential model
        backbone = nn.Sequential(*list(model.children())[:-2])
        num_features = model.fc.in_features
        # backbones need to return the num_features to build the head
        return backbone, num_features
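Once registered, the backbone can be selected by name when constructing the task. A minimal construction sketch (the num_classes value below is illustrative, not a default):

    from flash.image import ImageClassifier

    # build the task; "resnet18" is resolved through the backbone registry above
    model = ImageClassifier(num_classes=10, backbone="resnet18", pretrained=True)

    # multi-label classification only requires flipping the flag
    multi_label_model = ImageClassifier(num_classes=10, multi_label=True)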
- Parameters
  - num_classes (Optional[int]) – Number of classes to classify.
  - backbone (Union[str, Tuple[Module, int]]) – A string or a (model, num_features) tuple to use to compute image features, defaults to "resnet18".
  - head (Union[str, Callable, Module]) – A string from ImageClassifier.available_heads(), an nn.Module, or a function of (num_features, num_classes) which returns an nn.Module to use as the model head (see the sketch after this list).
  - pretrained (Union[bool, str]) – A bool or string to specify the pretrained weights of the backbone, defaults to True, which loads the default supervised pretrained weights.
  - loss_fn (Union[Callable, Mapping, Sequence, None]) – Loss function for training, defaults to torch.nn.functional.cross_entropy().
  - optimizer (Union[str, Callable, Tuple[str, Dict[str, Any]], None]) – Optimizer to use for training.
  - lr_scheduler (Union[str, Callable, Tuple[str, Dict[str, Any]], Tuple[str, Dict[str, Any], Dict[str, Any]], None]) – The LR scheduler to use during training.
  - metrics (Union[Metric, Mapping, Sequence, None]) – Metrics to compute for training and evaluation. Can either be a metric from the torchmetrics package, a custom metric inheriting from torchmetrics.Metric, a callable function, or a list/dict containing a combination of the aforementioned. In all cases, each metric needs to have the signature metric(preds, target) and return a single scalar tensor. Defaults to torchmetrics.Accuracy.
  - learning_rate (Optional[float]) – Learning rate to use for training, defaults to 1e-3.
  - multi_label (bool) – Whether the targets are multi-label or not.
  - training_strategy (Optional[str]) – A string indicating the training strategy. Adjust if you want to use learn2learn for meta-learning research.
  - training_strategy_kwargs (Optional[Dict[str, Any]]) – Additional kwargs for setting the training strategy.
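As a sketch of how several of these arguments combine, the snippet below passes a custom head function, an optimizer name, and an explicit metric. The head definition, the "steplr" scheduler name, and the hyperparameter values are illustrative assumptions; check available_lr_schedulers() and your installed torchmetrics version for the exact names and signatures:

    import torchmetrics
    from torch import nn
    from flash.image import ImageClassifier

    # a head may be a function of (num_features, num_classes) returning an nn.Module
    def my_head(num_features: int, num_classes: int) -> nn.Module:
        return nn.Sequential(nn.Dropout(0.1), nn.Linear(num_features, num_classes))

    model = ImageClassifier(
        num_classes=10,
        backbone="resnet18",
        head=my_head,                                # custom head function
        optimizer="Adam",                            # a key from available_optimizers()
        lr_scheduler=("steplr", {"step_size": 10}),  # assumed (name, kwargs) form; see available_lr_schedulers()
        metrics=torchmetrics.Accuracy(),             # newer torchmetrics may require a task=... argument
        learning_rate=1e-3,
    )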
- classmethod available_finetuning_strategies()
  Returns a list containing the keys of the available Finetuning Strategies.
- classmethod available_lr_schedulers()
  Returns a list containing the keys of the available LR schedulers.
- classmethod available_optimizers()
  Returns a list containing the keys of the available Optimizers.
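These helpers can be used to discover valid string values before constructing the task. A quick illustration (the returned contents depend on the installed Flash/Lightning versions):

    from flash.image import ImageClassifier

    # each call returns a list of registered names (strings)
    print(ImageClassifier.available_optimizers())
    print(ImageClassifier.available_lr_schedulers())
    print(ImageClassifier.available_finetuning_strategies())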
- classmethod available_outputs()
  Returns the list of available outputs (that can be used during prediction or serving) for this Task.

  Examples

      >>> from flash import Task
      >>> print(Task.available_outputs())
      ['preds', 'raw']
- classmethod load_from_checkpoint(checkpoint_path, map_location=None, hparams_file=None, strict=True, **kwargs)
  Primary way of loading a model from a checkpoint. When Lightning saves a checkpoint it stores the arguments passed to __init__ in the checkpoint under "hyper_parameters".

  Any arguments specified through **kwargs will override args stored in "hyper_parameters".
.- Parameters
  - checkpoint_path (Union[str, Path, IO]) – Path to checkpoint. This can also be a URL or a file-like object.
  - map_location (Union[device, str, int, Callable[[Union[device, str, int]], Union[device, str, int]], Dict[Union[device, str, int], Union[device, str, int]], None]) – If your checkpoint saved a GPU model and you now load on CPUs or a different number of GPUs, use this to map to the new setup. The behaviour is the same as in torch.load().
  - hparams_file (Union[str, Path, None]) – Optional path to a .yaml or .csv file with hierarchical structure as in this example:

        drop_prob: 0.2
        dataloader:
            batch_size: 32

    You most likely won't need this since Lightning will always save the hyperparameters to the checkpoint. However, if your checkpoint weights don't have the hyperparameters saved, use this argument to pass in a .yaml file with the hparams you'd like to use. These will be converted into a dict and passed into your LightningModule for use. If your model's hparams argument is Namespace and the .yaml file has hierarchical structure, you need to refactor your model to treat hparams as dict.
  - strict (bool) – Whether to strictly enforce that the keys in checkpoint_path match the keys returned by this module's state dict.
  - **kwargs – Any extra keyword args needed to init the model. Can also be used to override saved hyperparameter values.
- Return type
Self
- Returns
  LightningModule instance with loaded weights and hyperparameters (if available).

Note

load_from_checkpoint is a class method. You should use your LightningModule class to call it instead of the LightningModule instance.

Example:
    # load weights without mapping ...
    model = MyLightningModule.load_from_checkpoint('path/to/checkpoint.ckpt')

    # or load weights mapping all weights from GPU 1 to GPU 0 ...
    map_location = {'cuda:1': 'cuda:0'}
    model = MyLightningModule.load_from_checkpoint(
        'path/to/checkpoint.ckpt',
        map_location=map_location
    )

    # or load weights and hyperparameters from separate files.
    model = MyLightningModule.load_from_checkpoint(
        'path/to/checkpoint.ckpt',
        hparams_file='/path/to/hparams_file.yaml'
    )

    # override some of the params with new values
    model = MyLightningModule.load_from_checkpoint(
        PATH,
        num_layers=128,
        pretrained_ckpt_path=NEW_PATH,
    )

    # predict
    pretrained_model.eval()
    pretrained_model.freeze()
    y_hat = pretrained_model(x)
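The same class method applies to Flash tasks such as ImageClassifier; a short sketch (the checkpoint path is a placeholder):

    from flash.image import ImageClassifier

    # load a previously trained ImageClassifier checkpoint (path is illustrative)
    model = ImageClassifier.load_from_checkpoint("path/to/image_classifier.ckpt")
    model.eval()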