Task

class flash.core.model.Task(model=None, loss_fn=None, learning_rate=5e-05, optimizer='Adam', lr_scheduler=None, metrics=None, deserializer=None, input_transform=None, output_transform=None, output=None)[source]

A general Task.

Parameters
  • model (Optional[~MODEL_TYPE]) – Model to use for the task.

  • loss_fn (Optional[~LOSS_FN_TYPE]) – Loss function for training.

  • learning_rate (float) – Learning rate to use for training, defaults to 5e-5.

  • optimizer (~OPTIMIZER_TYPE) – Optimizer to use for training.

  • lr_scheduler (Optional[~LR_SCHEDULER_TYPE]) – The LR scheduler to use during training.

  • metrics (Optional[~METRICS_TYPE]) – Metrics to compute for training and evaluation. Can either be a metric from the torchmetrics package, a custom metric inheriting from torchmetrics.Metric, a callable function, or a list/dict containing a combination of the aforementioned. In all cases, each metric needs to have the signature metric(preds, target) and return a single scalar tensor.

  • deserializer (Optional[~DESERIALIZER_TYPE]) – Either a single Deserializer or a mapping of these used to deserialize the input.

  • input_transform (Optional[~INPUT_TRANSFORM_TYPE]) – InputTransform to use as the default for this task.

  • output_transform (Optional[~OUTPUT_TRANSFORM_TYPE]) – OutputTransform to use as the default for this task.

  • output (Optional[~OUTPUT_TYPE]) – The Output to use when formatting prediction outputs.
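The metric signature required above can be illustrated with a plain callable. This is a hedged sketch: plain Python numbers stand in for tensors so the example runs without torch, and the metric name is hypothetical.

```python
# Hypothetical sketch: any metric passed to Task must accept (preds, target)
# and return a single scalar. Plain floats stand in for tensors here.

def accuracy(preds, target):
    """Toy metric with the required metric(preds, target) signature."""
    correct = sum(p == t for p, t in zip(preds, target))
    return correct / len(target)

# A Task accepts a single callable, a list, or a dict of such metrics:
metrics = {"accuracy": accuracy}
score = metrics["accuracy"]([1, 0, 1, 1], [1, 0, 0, 1])
# score == 0.75
```

In practice a torchmetrics.Metric is usually preferred over a bare function, since it handles batching and device placement for you.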

static apply_filtering(y, y_hat)[source]

This function is used to filter out labels or predictions that do not conform.

Return type

Tuple[Tensor, Tensor]

classmethod available_finetuning_strategies()[source]

Returns a list containing the keys of the available Finetuning Strategies.

Return type

List[str]

classmethod available_lr_schedulers()[source]

Returns a list containing the keys of the available LR schedulers.

Return type

List[str]

classmethod available_optimizers()[source]

Returns a list containing the keys of the available Optimizers.

Return type

List[str]
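The three available_* classmethods follow the same pattern: each returns the string keys of an internal registry. A minimal stdlib-only sketch of that pattern (the class and registry contents below are illustrative stand-ins, not the actual Flash registries):

```python
class TaskSketch:
    # Illustrative registry; the real Task consults Flash's optimizer,
    # LR-scheduler, and finetuning-strategy registries.
    _optimizers = {"adam": object, "adamw": object, "sgd": object}

    @classmethod
    def available_optimizers(cls):
        """Return the keys of the available optimizers, as Task does."""
        return sorted(cls._optimizers)

TaskSketch.available_optimizers()
# → ['adam', 'adamw', 'sgd']
```

The returned keys are the strings you can pass to the corresponding `Task.__init__()` arguments, e.g. `optimizer="adam"`.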

build_data_pipeline(input=None, deserializer=None, data_pipeline=None)[source]

Build a DataPipeline incorporating the available InputTransform and OutputTransform objects. These are resolved in the following order (lowest priority first):

  • Lightning Datamodule, either attached to the Trainer or to the Task.

  • Task defaults given to Task.__init__().

  • Task manual overrides by setting data_pipeline.

  • DataPipeline passed to this method.

Return type

Optional[DataPipeline]

Returns

The fully resolved DataPipeline.
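The resolution order above amounts to "last non-None wins", walking from lowest to highest priority. A minimal sketch of that logic (the function name is hypothetical, not the actual implementation):

```python
def resolve(*candidates):
    """Return the highest-priority candidate that is not None.

    Candidates are passed lowest priority first, mirroring the
    resolution order used by build_data_pipeline.
    """
    result = None
    for candidate in candidates:
        if candidate is not None:
            result = candidate
    return result

# A DataPipeline passed directly to the method beats the task default:
resolve("datamodule_transform", "task_default", None, "explicit_pipeline")
# → 'explicit_pipeline'
```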

configure_optimizers()[source]

Implement how the optimizer and, optionally, learning rate schedulers should be configured.

Return type

Union[Optimizer, Tuple[List[Optimizer], List[_LRScheduler]]]
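The return type admits two shapes: a bare optimizer, or a tuple of parallel optimizer and scheduler lists. A schematic sketch with placeholder classes (these stand in for torch.optim.Optimizer and _LRScheduler; the function names are illustrative):

```python
class FakeOptimizer:
    """Placeholder standing in for torch.optim.Optimizer."""

class FakeScheduler:
    """Placeholder standing in for torch.optim.lr_scheduler._LRScheduler."""

def configure_optimizers_simple():
    # Shape 1: a single optimizer, no scheduler.
    return FakeOptimizer()

def configure_optimizers_scheduled():
    # Shape 2: parallel lists of optimizers and LR schedulers.
    return [FakeOptimizer()], [FakeScheduler()]
```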

get_num_training_steps()[source]

Total training steps inferred from datamodule and devices.

Return type

int
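The inference typically amounts to: batches per epoch, divided across devices, times the number of epochs. A hedged arithmetic sketch (the real method reads these values from the attached trainer and datamodule, and this version assumes no gradient accumulation):

```python
import math

def num_training_steps(num_samples, batch_size, num_devices, max_epochs):
    """Estimate total optimizer steps, assuming no gradient accumulation."""
    batches_per_epoch = math.ceil(num_samples / batch_size)
    steps_per_epoch = math.ceil(batches_per_epoch / num_devices)
    return steps_per_epoch * max_epochs

num_training_steps(num_samples=1000, batch_size=32, num_devices=2, max_epochs=3)
# → 48
```

This count is what warmup-based LR schedulers (e.g. linear warmup with decay) typically need up front.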

modules_to_freeze()[source]

By default, we try to get the backbone attribute from the task and return it, or None if it is not present.

Return type

Optional[Module]

Returns

The backbone Module to freeze or None if this task does not have a backbone attribute.
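The default behaviour described above boils down to an attribute lookup with a None fallback. A minimal sketch with dummy classes (hypothetical names, standing in for the real task and Module types):

```python
class Backbone:
    """Stand-in for a torch.nn.Module backbone."""

class TaskWithBackbone:
    def __init__(self):
        self.backbone = Backbone()

    def modules_to_freeze(self):
        # Return the backbone attribute if present, else None.
        return getattr(self, "backbone", None)

class TaskWithoutBackbone:
    def modules_to_freeze(self):
        return getattr(self, "backbone", None)

TaskWithBackbone().modules_to_freeze()     # a Backbone instance
TaskWithoutBackbone().modules_to_freeze()  # None
```

Finetuning strategies use this hook to decide which parameters stay frozen during the early epochs.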

serve(host='127.0.0.1', port=8000, sanity_check=True, input_cls=None, transform=<class 'flash.core.data.io.input_transform.InputTransform'>, transform_kwargs=None)[source]

Serve the Task. Override this method to provide a default input_cls, transform, and transform_kwargs.

Parameters
  • host (str) – The IP address to host the Task on.

  • port (int) – The port to host on.

  • sanity_check (bool) – If True, runs a sanity check before serving.

  • input_cls (Optional[Type[ServeInput]]) – The ServeInput type to use.

  • transform (~INPUT_TRANSFORM_TYPE) – The transform to use when serving.

  • transform_kwargs (Optional[Dict]) – Keyword arguments used to instantiate the transform.

Return type

Composition

step(batch, batch_idx, metrics)[source]

Implement the core logic for the training/validation/test step. By default this includes:

  • Inference on the current batch

  • Calculating the loss

  • Calculating relevant metrics

Override for custom behavior.

Parameters
  • batch (Any) – The output of your dataloader. Can either be a single Tensor or a list of Tensors.

  • batch_idx (int) – The index of the current batch.

  • metrics (ModuleDict) – A ModuleDict containing the metrics used to compute relevant training statistics.

Return type

Any

Returns

A dict containing both the loss and the relevant metrics.
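A schematic sketch of what a step implementation produces, following the three bullets above: run inference, compute the loss, compute the metrics, and return them together in a dict. Plain floats stand in for tensors, and a plain dict of callables stands in for the ModuleDict; all names are illustrative.

```python
def step(batch, batch_idx, metrics):
    """Toy step: inference, loss, and metrics on a (preds, target) batch."""
    preds, target = batch  # stand-in for inference on the current batch
    # Mean absolute error as a stand-in loss.
    loss = sum(abs(p - t) for p, t in zip(preds, target)) / len(target)
    computed = {name: fn(preds, target) for name, fn in metrics.items()}
    return {"loss": loss, **computed}

mae = lambda p, t: sum(abs(a - b) for a, b in zip(p, t)) / len(t)
out = step(([1.0, 0.0], [1.0, 1.0]), 0, {"mae": mae})
# out == {"loss": 0.5, "mae": 0.5}
```

Overriding step lets a custom task change any of the three stages while keeping the surrounding training/validation/test plumbing intact.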
