
Task

class flash.core.model.Task(model=None, loss_fn=None, learning_rate=None, optimizer='Adam', lr_scheduler=None, metrics=None, output_transform=None)[source]

A general Task.

Parameters
static apply_filtering(y, y_hat)[source]

This function filters out labels or predictions that do not conform to what the task expects.

Return type

Tuple[Tensor, Tensor]
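The filtering criteria are not specified here. As a hedged, stdlib-only sketch of what an override might do, the snippet below assumes a hypothetical rule that drops every (label, prediction) pair whose label equals an ignore sentinel (`IGNORE_LABEL = -1`, an invented value). It returns the pair in the same `(y, y_hat)` order as the documented `Tuple[Tensor, Tensor]` return type, with plain lists standing in for tensors.

```python
from typing import List, Sequence, Tuple

IGNORE_LABEL = -1  # hypothetical sentinel marking a non-conforming label


def apply_filtering(y: Sequence[int], y_hat: Sequence[float]) -> Tuple[List[int], List[float]]:
    """Drop (label, prediction) pairs whose label is the ignore sentinel.

    Lists stand in for tensors; the filtering rule itself is an assumption.
    """
    if len(y) != len(y_hat):
        raise ValueError("labels and predictions must have the same length")
    kept = [(label, pred) for label, pred in zip(y, y_hat) if label != IGNORE_LABEL]
    # Unzip back into two parallel lists, preserving the (y, y_hat) order.
    filtered_y = [label for label, _ in kept]
    filtered_y_hat = [pred for _, pred in kept]
    return filtered_y, filtered_y_hat


y, y_hat = apply_filtering([0, -1, 2, 1], [0.1, 0.9, 0.7, 0.4])
print(y, y_hat)  # [0, 2, 1] [0.1, 0.7, 0.4]
```

Labels and predictions are filtered together so the two outputs stay aligned element by element.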

as_embedder(layer)[source]

Convert this task to an embedder. Note that the parameters are not copied, so any optimization of the embedder will also apply to the converted Task.

Parameters

layer (str) – The layer to embed to. This should be one of the layers returned by available_layers().
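The shared-parameter behaviour can be pictured with plain Python objects. This is a hypothetical sketch: `ToyTask`, `ToyEmbedder` and the layer names are invented for illustration and are not Flash classes. The embedder keeps a reference to the task instead of a copy, so an in-place update made through the embedder is visible on the task. `as_embedder` also rejects any name not returned by `available_layers()`.

```python
from typing import Dict, List


class ToyTask:
    """Hypothetical stand-in for a Task: named layers holding mutable parameters."""

    def __init__(self) -> None:
        # Each "layer" maps to a list of parameter values (stand-ins for tensors).
        self.layers: Dict[str, List[float]] = {"backbone": [0.5, -0.2], "head": [1.0]}

    def available_layers(self) -> List[str]:
        return list(self.layers)

    def as_embedder(self, layer: str) -> "ToyEmbedder":
        if layer not in self.available_layers():
            raise ValueError(f"{layer!r} is not one of {self.available_layers()}")
        # Pass the task itself, not a copy, so parameters stay shared.
        return ToyEmbedder(self, layer)


class ToyEmbedder:
    def __init__(self, task: ToyTask, layer: str) -> None:
        self.task = task
        self.layer = layer

    def parameters(self) -> List[float]:
        # Returns the task's own list object, so in-place updates propagate.
        return self.task.layers[self.layer]


task = ToyTask()
embedder = task.as_embedder("backbone")
embedder.parameters()[0] += 1.0  # stand-in for an optimizer step on the embedder
print(task.layers["backbone"])  # [1.5, -0.2]
```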

classmethod available_finetuning_strategies()[source]

Returns a list containing the keys of the available Finetuning Strategies.

Return type

List[str]

available_layers()[source]

Get the list of available layers for use with the as_embedder() method.

classmethod available_lr_schedulers()[source]

Returns a list containing the keys of the available LR schedulers.

Return type

List[str]

classmethod available_optimizers()[source]

Returns a list containing the keys of the available Optimizers.

Return type

List[str]
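The `available_*` classmethods above each return the string keys of a collection of named implementations. The minimal sketch below shows that pattern. It is hypothetical: the `Registry` class and its entries are assumptions, not Flash's own registry API. It lists the keys, then looks an implementation up by name, much as a string such as the default `optimizer='Adam'` names an optimizer.

```python
from typing import Callable, Dict, List


class Registry:
    """Hypothetical name -> factory mapping."""

    def __init__(self) -> None:
        self._entries: Dict[str, Callable[..., object]] = {}

    def register(self, name: str, factory: Callable[..., object]) -> None:
        if name in self._entries:
            raise KeyError(f"{name!r} is already registered")
        self._entries[name] = factory

    def available_keys(self) -> List[str]:
        # Sorted so that the listing is deterministic.
        return sorted(self._entries)

    def get(self, name: str) -> Callable[..., object]:
        if name not in self._entries:
            raise KeyError(f"{name!r} not found; choose one of {self.available_keys()}")
        return self._entries[name]


optimizers = Registry()
optimizers.register("Adam", lambda lr: {"name": "Adam", "lr": lr})
optimizers.register("SGD", lambda lr: {"name": "SGD", "lr": lr})

print(optimizers.available_keys())  # ['Adam', 'SGD']
print(optimizers.get("Adam")(1e-3))  # {'name': 'Adam', 'lr': 0.001}
```

Listing the keys first lets a caller validate a user-supplied name before construction, which is why an unknown key raises an error naming the valid choices.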

classmethod available_outputs()[source]

Returns the list of available outputs (that can be used during prediction or serving) for this Task.

Examples

>>> from flash import Task
>>> print(Task.available_outputs())
['preds', 'raw']

Return type

List[str]

configure_optimizers()[source]

Configure the optimizer and, optionally, the learning rate schedulers.

Return type

Union[Optimizer, Tuple[List[Optimizer], List[_LRScheduler]]]
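The return type allows two shapes: a single optimizer, or a pair of lists `([optimizer], [scheduler])`. The sketch below shows that branching. It is hedged and assumes the scheduler list is returned only when a scheduler was requested. `SimpleOptimizer`, `StepDecay` and the `use_scheduler` flag are invented stand-ins for `Optimizer`, `_LRScheduler` and the `lr_scheduler` argument.

```python
from dataclasses import dataclass
from typing import List, Tuple, Union


@dataclass
class SimpleOptimizer:
    """Hypothetical stand-in for a torch Optimizer."""
    name: str
    lr: float


@dataclass
class StepDecay:
    """Hypothetical stand-in for an LR scheduler that multiplies the LR by gamma."""
    optimizer: SimpleOptimizer
    gamma: float = 0.5

    def step(self) -> None:
        self.optimizer.lr *= self.gamma


def configure_optimizers(
    optimizer_name: str = "Adam",
    learning_rate: float = 1e-3,
    use_scheduler: bool = False,
) -> Union[SimpleOptimizer, Tuple[List[SimpleOptimizer], List[StepDecay]]]:
    optimizer = SimpleOptimizer(optimizer_name, learning_rate)
    if not use_scheduler:
        # No scheduler requested: return the optimizer on its own.
        return optimizer
    # With a scheduler: return two lists, matching the Tuple branch of the return type.
    return [optimizer], [StepDecay(optimizer)]


print(configure_optimizers())
optimizers, schedulers = configure_optimizers(learning_rate=0.1, use_scheduler=True)
schedulers[0].step()
print(optimizers[0].lr)  # 0.05
```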

get_num_training_steps()[source]

Return the total number of training steps, inferred from the datamodule and the devices in use.

Return type

int
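The inference rule is not spelled out above. A common formula, used here as an assumption rather than as Flash's exact implementation, works in three steps:

1. Divide the number of training batches by devices × gradient-accumulation steps.
2. Multiply by the number of epochs.
3. Cap the result at `max_steps` if one is set.

All arguments are hypothetical stand-ins for values that would come from the datamodule and trainer.

```python
from typing import Optional


def estimate_training_steps(
    num_train_batches: int,
    num_devices: int = 1,
    accumulate_grad_batches: int = 1,
    max_epochs: int = 1,
    max_steps: Optional[int] = None,
) -> int:
    """Estimate total optimizer steps (a sketch; the exact rule is an assumption)."""
    num_devices = max(1, num_devices)  # treat CPU-only runs as a single device
    effective_multiplier = num_devices * accumulate_grad_batches
    steps_per_epoch = num_train_batches // effective_multiplier
    estimated = steps_per_epoch * max_epochs
    # An explicit positive max_steps limit wins if it is smaller than the estimate.
    if max_steps is not None and 0 < max_steps < estimated:
        return max_steps
    return estimated


print(estimate_training_steps(1000, num_devices=4, accumulate_grad_batches=2, max_epochs=3))  # 375
print(estimate_training_steps(1000, num_devices=4, max_epochs=3, max_steps=500))  # 500
```

A step count like this is typically needed by schedulers that must know the total schedule length in advance.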

modules_to_freeze()[source]

By default, return the task's backbone attribute, or None if the task does not have one.

Return type

Optional[Module]

Returns

The backbone Module to freeze or None if this task does not have a backbone attribute.
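The default described above amounts to an attribute lookup with a fallback. Here is a minimal sketch. `TaskWithBackbone`, `TaskWithoutBackbone` and the `Linear` placeholder are hypothetical; real tasks would hold `torch.nn.Module` instances.

```python
from typing import Optional


class Linear:
    """Placeholder for a torch.nn.Module."""


class TaskWithBackbone:
    def __init__(self) -> None:
        self.backbone = Linear()


class TaskWithoutBackbone:
    pass


def modules_to_freeze(task: object) -> Optional[object]:
    # Return the `backbone` attribute if it exists, otherwise None.
    return getattr(task, "backbone", None)


print(type(modules_to_freeze(TaskWithBackbone())).__name__)  # Linear
print(modules_to_freeze(TaskWithoutBackbone()))  # None
```

A task whose frozen component is stored under a different attribute name would override this method instead of relying on the default.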

serve(host='127.0.0.1', port=8000, sanity_check=True, input_cls=None, transform=<class 'flash.core.data.io.input_transform.InputTransform'>, transform_kwargs=None, output=None)[source]

Serve the Task. Override this method to provide a default input_cls, transform, and transform_kwargs.

Parameters
  • host (str) – The IP address to host the Task on.

  • port (int) – The port to host on.

  • sanity_check (bool) – If True, runs a sanity check before serving.

  • input_cls (Optional[Type[ServeInput]]) – The ServeInput type to use.

  • transform (TypeVar(INPUT_TRANSFORM_TYPE, Type[flash.core.data.io.input_transform.InputTransform], Callable, Tuple[Union[LightningEnum, str], Dict[str, Any]], Union[LightningEnum, str], None)) – The transform to use when serving.

  • transform_kwargs (Optional[Dict]) – Keyword arguments used to instantiate the transform.

Return type

Composition

step(batch, batch_idx, metrics)[source]

Implement the core logic for the training/validation/test step. By default this includes:

  • Inference on the current batch

  • Calculating the loss

  • Calculating relevant metrics

Override for custom behavior.

Parameters
  • batch (Any) – The output of your dataloader. Can either be a single Tensor or a list of Tensors.

  • batch_idx (int) – The index of this batch.

  • metrics (ModuleDict) – A module dict containing metrics for calculating relevant training statistics.

Return type

Any

Returns

A dict containing both the loss and the relevant metrics.
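The default flow listed above can be sketched in plain Python: inference, then loss, then metrics, returned together in a dict. Everything in the sketch is hypothetical:

- The batch is assumed to be an `(x, y)` pair.
- The "model" is a plain function.
- Losses and metrics are dicts of callables.
- Summing multiple losses is an assumed way of combining them.
- The output keys `"loss"` and `"logs"` are illustrative, not Flash's own output keys.

```python
from typing import Callable, Dict, List, Sequence, Tuple

Batch = Tuple[Sequence[float], Sequence[float]]
Fn = Callable[[List[float], Sequence[float]], float]


def step(
    batch: Batch,
    batch_idx: int,  # unused here; kept to mirror the documented signature
    model: Callable[[Sequence[float]], List[float]],
    loss_fns: Dict[str, Fn],
    metrics: Dict[str, Fn],
) -> Dict[str, object]:
    x, y = batch
    y_hat = model(x)  # 1. inference on the current batch
    losses = {name: fn(y_hat, y) for name, fn in loss_fns.items()}  # 2. losses
    loss = sum(losses.values())  # combine multiple losses into one scalar
    logs = {name: fn(y_hat, y) for name, fn in metrics.items()}  # 3. metrics
    logs.update(losses)
    return {"loss": loss, "logs": logs}


def mse(pred: List[float], target: Sequence[float]) -> float:
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(target)


def mae(pred: List[float], target: Sequence[float]) -> float:
    return sum(abs(p - t) for p, t in zip(pred, target)) / len(target)


def double(xs: Sequence[float]) -> List[float]:
    return [2 * v for v in xs]


out = step(([1.0, 2.0], [2.0, 5.0]), 0, double, {"mse": mse}, {"mae": mae})
print(out["loss"], out["logs"])  # 0.5 {'mae': 0.5, 'mse': 0.5}
```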
