Task

class flash.core.model.Task(model=None, loss_fn=None, optimizer=torch.optim.Adam, optimizer_kwargs=None, scheduler=None, scheduler_kwargs=None, metrics=None, learning_rate=5e-05, deserializer=None, preprocess=None, postprocess=None, serializer=None)

A general Task.
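The optimizer, optimizer_kwargs and learning_rate arguments in the signature above work together to configure training. The sketch below shows one plausible way they combine: the optimizer class is instantiated with the model parameters, lr=learning_rate, and any extra keyword arguments. The helper build_optimizer and the RecordingOptimizer stand-in are hypothetical and exist only so the example runs without PyTorch. This is not the library's own code.

```python
from typing import Any, Callable, Dict, Iterable, Optional


def build_optimizer(
    optimizer_cls: Callable[..., Any],
    params: Iterable[Any],
    learning_rate: float = 5e-05,
    optimizer_kwargs: Optional[Dict[str, Any]] = None,
) -> Any:
    """Instantiate an optimizer class with a learning rate and extra kwargs.

    Hypothetical helper: one way the optimizer, optimizer_kwargs and
    learning_rate arguments of Task could be combined.
    """
    # Treat a missing optimizer_kwargs as "no extra arguments".
    kwargs = dict(optimizer_kwargs or {})
    return optimizer_cls(params, lr=learning_rate, **kwargs)


class RecordingOptimizer:
    """Stand-in for a torch.optim class so the sketch runs without PyTorch."""

    def __init__(self, params, lr, **kwargs):
        self.params = list(params)
        self.lr = lr
        self.kwargs = kwargs


opt = build_optimizer(RecordingOptimizer, [1.0, 2.0], optimizer_kwargs={"weight_decay": 0.01})
print(opt.lr, opt.kwargs)  # 5e-05 {'weight_decay': 0.01}
```

With PyTorch installed, the same call pattern works with a real optimizer class, for example build_optimizer(torch.optim.Adam, model.parameters(), 1e-3).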

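Parameters
  • model – The model to use for the task.

  • loss_fn – Loss function used for training.

  • optimizer – Optimizer to use for training; defaults to torch.optim.Adam.

  • optimizer_kwargs – Additional keyword arguments passed to the optimizer.

  • scheduler – Learning rate scheduler to use during training, if any.

  • scheduler_kwargs – Additional keyword arguments passed to the scheduler.

  • metrics – Metrics to compute during training and evaluation.

  • learning_rate – Learning rate to use for training; defaults to 5e-05.

  • deserializer – Deserializer used to convert raw input when predicting.

  • preprocess – Preprocess to use as the default for this task.

  • postprocess – Postprocess to use as the default for this task.

  • serializer – Serializer used to convert model outputs into the desired output format when predicting.
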
static apply_filtering(y, y_hat)

Filter out labels (y) or predictions (y_hat) that do not conform to the expected format, returning the filtered (y, y_hat) pair.

Return type

Tuple[Tensor, Tensor]
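As an illustration of what an override of this hook might do, the sketch below drops every (label, prediction) pair whose label equals a sentinel value. The IGNORE_INDEX = -1 rule is a hypothetical filtering criterion, and the sketch uses plain Python lists instead of Tensors so it runs without PyTorch. A real override would operate on the Tensors the hook receives.

```python
from typing import List, Sequence, Tuple

IGNORE_INDEX = -1  # hypothetical sentinel marking labels that should be skipped


def apply_filtering(y: Sequence[int], y_hat: Sequence[float]) -> Tuple[List[int], List[float]]:
    """Drop (label, prediction) pairs whose label equals IGNORE_INDEX."""
    if len(y) != len(y_hat):
        raise ValueError("y and y_hat must have the same length")
    kept = [(t, p) for t, p in zip(y, y_hat) if t != IGNORE_INDEX]
    # Split back into two parallel lists (both empty if nothing was kept).
    return [t for t, _ in kept], [p for _, p in kept]


labels, preds = apply_filtering([0, -1, 2], [0.1, 0.7, 0.9])
print(labels, preds)  # [0, 2] [0.1, 0.9]
```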

build_data_pipeline(data_source=None, deserializer=None, data_pipeline=None)

Build a DataPipeline incorporating the available Preprocess and Postprocess objects. Sources are resolved in the following order, lowest priority first, so each later source overrides the ones before it:

  • Lightning DataModule, either attached to the Trainer or to the Task.

  • Task defaults given to Task.__init__().

  • Manual overrides made by setting the Task's data_pipeline.

  • DataPipeline passed to this method.
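The resolution order above amounts to "the last non-empty source wins." A minimal sketch of that rule follows. The resolve helper is hypothetical and the strings stand in for Preprocess objects. The actual method may resolve Preprocess and Postprocess separately and do more work than this.

```python
from typing import Optional, TypeVar

T = TypeVar("T")


def resolve(*candidates: Optional[T]) -> Optional[T]:
    """Return the highest-priority candidate that is not None.

    Candidates are given lowest priority first, matching the list above.
    """
    resolved: Optional[T] = None
    for candidate in candidates:
        if candidate is not None:
            resolved = candidate  # later (higher-priority) sources win
    return resolved


print(resolve("datamodule", "task default", None, None))  # task default
print(resolve("datamodule", None, "manual override", "argument"))  # argument
```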

Parameters

data_pipeline (Optional[DataPipeline]) – Optional highest priority source of Preprocess and Postprocess.

Return type

Optional[DataPipeline]

Returns

The fully resolved DataPipeline.

get_num_training_steps()

Total training steps inferred from datamodule and devices.

Return type

int
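The total number of training steps can be inferred with simple arithmetic. Batches are shared across devices, gradients may be accumulated over several batches per optimizer step, and the result is multiplied by the number of epochs. The sketch below assumes that num_batches_per_epoch counts batches across all devices and that a positive max_steps caps the estimate. The function and its parameter names are hypothetical, not this method's actual implementation.

```python
from typing import Optional


def estimate_num_training_steps(
    num_batches_per_epoch: int,
    max_epochs: int,
    num_devices: int = 1,
    accumulate_grad_batches: int = 1,
    max_steps: Optional[int] = None,
) -> int:
    """Estimate the number of optimizer steps for a run (hypothetical helper)."""
    # Each optimizer step consumes one batch per device for every accumulated batch.
    effective_batches = max(1, num_devices) * max(1, accumulate_grad_batches)
    steps = (num_batches_per_epoch // effective_batches) * max_epochs
    # An explicit, positive step budget caps the estimate.
    if max_steps is not None and max_steps > 0:
        steps = min(steps, max_steps)
    return steps


print(estimate_num_training_steps(1000, 3, num_devices=2, accumulate_grad_batches=4))  # 375
```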

predict(x, data_source=None, deserializer=None, data_pipeline=None)

Run predictions on raw or processed data.

Parameters
  • x (Any) – Input to predict. Can be raw data or processed data. If a str, it is assumed to be a folder of data.

  • data_pipeline (Optional[DataPipeline]) – Use this to override the current data pipeline.

Return type

Any

Returns

The post-processed model predictions.
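Conceptually, prediction passes each raw input through a chain of stages: deserialize, preprocess, run the model, and postprocess. The sketch below uses plain callables for each stage. The function predict_sketch, its stage arguments, and their per-sample ordering are assumptions for illustration. The real method works through a DataPipeline, batches its inputs, and can be driven by data_source.

```python
from typing import Any, Callable, List, Sequence


def predict_sketch(
    inputs: Sequence[Any],
    deserialize: Callable[[Any], Any],
    preprocess: Callable[[Any], Any],
    model: Callable[[Any], Any],
    postprocess: Callable[[Any], Any],
) -> List[Any]:
    """Run raw inputs through hypothetical pipeline stages in order."""
    outputs = []
    for raw in inputs:
        sample = preprocess(deserialize(raw))  # raw data -> model-ready input
        outputs.append(postprocess(model(sample)))  # model output -> final prediction
    return outputs


preds = predict_sketch(
    ["1", "2", "3"],
    deserialize=int,
    preprocess=lambda v: v - 1,
    model=lambda v: v * v,
    postprocess=lambda v: f"score={v}",
)
print(preds)  # ['score=0', 'score=1', 'score=4']
```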

step(batch, batch_idx, metrics)

The training/validation/test step.

Override for custom behavior.

Return type

Any
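Because one step is shared by training, validation, and testing, an override typically unpacks the batch, runs the model, computes the loss, and evaluates each metric. The sketch below shows that shape in pure Python. The function shared_step, its model and loss_fn arguments (which a Task would hold itself), and the keys of the returned dict are assumptions, not the documented return format of step.

```python
from typing import Any, Callable, Dict, Sequence, Tuple

Metric = Callable[[Sequence[float], Sequence[float]], float]


def shared_step(
    batch: Tuple[Sequence[float], Sequence[float]],
    batch_idx: int,
    metrics: Dict[str, Metric],
    model: Callable[[float], float],
    loss_fn: Metric,
) -> Dict[str, Any]:
    """One training/validation/test step over a batch (hypothetical sketch)."""
    del batch_idx  # unused here; kept only to mirror step(batch, batch_idx, metrics)
    x, y = batch
    y_hat = [model(v) for v in x]
    loss = loss_fn(y_hat, y)
    # Evaluate every metric on the same predictions, then log the loss too.
    logs = {name: metric(y_hat, y) for name, metric in metrics.items()}
    logs["loss"] = loss
    return {"loss": loss, "logs": logs, "y_hat": y_hat, "y": list(y)}


def mse(pred: Sequence[float], target: Sequence[float]) -> float:
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(target)


def mae(pred: Sequence[float], target: Sequence[float]) -> float:
    return sum(abs(p - t) for p, t in zip(pred, target)) / len(target)


out = shared_step(([1.0, 2.0], [2.0, 7.0]), 0, {"mae": mae}, model=lambda v: 2 * v, loss_fn=mse)
print(out["loss"], out["logs"]["mae"])  # 4.5 1.5
```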
