SemanticSegmentation

class flash.image.segmentation.model.SemanticSegmentation(num_classes, backbone='resnet50', backbone_kwargs=None, head='fpn', head_kwargs=None, pretrained=True, loss_fn=None, optimizer='Adam', lr_scheduler=None, metrics=None, learning_rate=None, multi_label=False, output_transform=None)[source]

SemanticSegmentation is a Task for semantic segmentation of images. For more details, see Semantic Segmentation.

Parameters
  • num_classes (int) – Number of classes to classify.

  • backbone (Union[str, Module]) – A string or model to use to compute image features.

  • backbone_kwargs (Optional[Dict]) – Additional arguments for the backbone configuration.

  • head (str) – A string specifying the segmentation head to use (default: 'fpn').

  • head_kwargs (Optional[Dict]) – Additional arguments for the head configuration.

  • pretrained (Union[bool, str]) – Use a pretrained backbone.

  • loss_fn (Optional[TypeVar(LOSS_FN_TYPE, Callable, Mapping, Sequence, None)]) – Loss function for training.

  • optimizer (TypeVar(OPTIMIZER_TYPE, str, Callable, Tuple[str, Dict[str, Any]], None)) – Optimizer to use for training.

  • lr_scheduler (Optional[TypeVar(LR_SCHEDULER_TYPE, str, Callable, Tuple[str, Dict[str, Any]], Tuple[str, Dict[str, Any], Dict[str, Any]], None)]) – The LR scheduler to use during training.

  • metrics (Optional[TypeVar(METRICS_TYPE, Metric, Mapping, Sequence, None)]) – Metrics to compute for training and evaluation. Can either be a metric from the torchmetrics package, a custom metric inheriting from torchmetrics.Metric, a callable function, or a list/dict containing a combination of the aforementioned. In all cases, each metric needs to have the signature metric(preds, target) and return a single scalar tensor. Defaults to torchmetrics.IOU.

  • learning_rate (Optional[float]) – Learning rate to use for training. If None (the default) then the default LR for your chosen optimizer will be used.

  • multi_label (bool) – Whether the targets are multi-label or not.

  • output – The Output to use when formatting prediction outputs.

  • output_transform (Optional[TypeVar(OUTPUT_TRANSFORM_TYPE, flash.core.data.io.output_transform.OutputTransform, None)]) – The OutputTransform to use for post-processing samples.
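
As a quick illustration, the Task can be constructed directly with the arguments above. The following is a minimal sketch (the num_classes value is illustrative, and it assumes lightning-flash is installed with the image extras); backbone, head, and optimizer mirror the defaults in the signature:

from flash.image import SemanticSegmentation

# Minimal sketch: an illustrative 21-class setup; backbone, head, and
# optimizer mirror the documented defaults.
model = SemanticSegmentation(
    num_classes=21,
    backbone="resnet50",
    head="fpn",
    pretrained=True,
    optimizer="Adam",
)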

classmethod available_finetuning_strategies()

Returns a list containing the keys of the available Finetuning Strategies.

Return type

List[str]

classmethod available_lr_schedulers()

Returns a list containing the keys of the available LR schedulers.

Return type

List[str]

classmethod available_optimizers()

Returns a list containing the keys of the available Optimizers.

Return type

List[str]
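
These registries can be queried on the class itself. A short sketch (the exact keys returned depend on the installed version):

from flash.image import SemanticSegmentation

# List the registered options; contents vary with the installed version.
print(SemanticSegmentation.available_finetuning_strategies())
print(SemanticSegmentation.available_lr_schedulers())
print(SemanticSegmentation.available_optimizers())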

classmethod available_outputs()

Returns the list of available outputs (that can be used during prediction or serving) for this Task.

Examples

>>> from flash import Task
>>> print(Task.available_outputs())
['preds', 'raw']

Return type

List[str]

classmethod load_from_checkpoint(checkpoint_path, map_location=None, hparams_file=None, strict=True, **kwargs)

Primary way of loading a model from a checkpoint. When Lightning saves a checkpoint it stores the arguments passed to __init__ in the checkpoint under "hyper_parameters".

Any arguments specified through **kwargs will override args stored in "hyper_parameters".

Parameters
  • checkpoint_path (Union[str, Path, IO]) – Path to the checkpoint. This can also be a URL or a file-like object.

  • map_location (Union[device, str, int, Callable[[Union[device, str, int]], Union[device, str, int]], Dict[Union[device, str, int], Union[device, str, int]], None]) – If your checkpoint saved a GPU model and you now load on CPUs or a different number of GPUs, use this to map to the new setup. The behaviour is the same as in torch.load().

  • hparams_file (Union[str, Path, None]) –

    Optional path to a .yaml or .csv file with hierarchical structure as in this example:

    drop_prob: 0.2
    dataloader:
        batch_size: 32
    

    You most likely won't need this since Lightning will always save the hyperparameters to the checkpoint. However, if your checkpoint weights don't have the hyperparameters saved, use this argument to pass in a .yaml file with the hparams you'd like to use. These will be converted into a dict and passed into your LightningModule for use.

    If your model's hparams argument is a Namespace and the .yaml file has a hierarchical structure, you need to refactor your model to treat hparams as a dict.

  • strict (bool) – Whether to strictly enforce that the keys in checkpoint_path match the keys returned by this module’s state dict.

  • **kwargs – Any extra keyword args needed to init the model. Can also be used to override saved hyperparameter values.

Return type

Self

Returns

LightningModule instance with loaded weights and hyperparameters (if available).

Note

load_from_checkpoint is a class method. You should use your LightningModule class to call it instead of the LightningModule instance.

Example:

# load weights without mapping ...
model = MyLightningModule.load_from_checkpoint('path/to/checkpoint.ckpt')

# or load weights mapping all weights from GPU 1 to GPU 0 ...
map_location = {'cuda:1':'cuda:0'}
model = MyLightningModule.load_from_checkpoint(
    'path/to/checkpoint.ckpt',
    map_location=map_location
)

# or load weights and hyperparameters from separate files.
model = MyLightningModule.load_from_checkpoint(
    'path/to/checkpoint.ckpt',
    hparams_file='/path/to/hparams_file.yaml'
)

# override some of the params with new values
model = MyLightningModule.load_from_checkpoint(
    PATH,
    num_layers=128,
    pretrained_ckpt_path=NEW_PATH,
)

# predict
model.eval()
model.freeze()
y_hat = model(x)
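
The same pattern applies to this Task. A brief sketch (the checkpoint path is hypothetical; extra keyword arguments would override the stored hyperparameters):

from flash.image import SemanticSegmentation

# Hypothetical checkpoint path; kwargs passed here override saved hyperparameters.
model = SemanticSegmentation.load_from_checkpoint("path/to/semantic_segmentation.ckpt")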