

class flash.core.trainer.Trainer(*args, **kwargs)

An extension of pytorch_lightning.Trainer for Flash tasks.

>>> Trainer()  
<...trainer.Trainer object at ...>
classmethod add_argparse_args(*args, **kwargs)

See pytorch_lightning.utilities.argparse.add_argparse_args().

Return type


property estimated_stepping_batches: Union[int, float]

The estimated number of stepping batches (optimizer steps) for the complete training run, inferred from the DataLoaders, the gradient accumulation factor, and the distributed setup.


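Example:

import torch

def configure_optimizers(self):
    optimizer = ...  # any torch.optim optimizer built from self.parameters()
    # OneCycleLR needs the total number of optimizer steps up front.
    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer, max_lr=1e-3, total_steps=self.trainer.estimated_stepping_batches
    )
    return [optimizer], [scheduler]
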
Return type

Union[int, float]
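The arithmetic behind this estimate can be sketched in plain Python. This is a simplified illustration under stated assumptions, not Lightning's implementation: the helper names batches_per_process and estimate_stepping_batches are hypothetical, the distributed split is assumed to behave like a DistributedSampler with drop_last=False, and -1 is used (as in Lightning) to mean "no limit" for max_epochs and max_steps.

```python
import math


def batches_per_process(num_samples, batch_size, world_size=1):
    """Hypothetical helper: training batches each process sees per epoch."""
    # Assumes a DistributedSampler-style split without dropping samples:
    # every process receives ceil(num_samples / world_size) samples.
    samples = math.ceil(num_samples / world_size)
    # Assumes the DataLoader keeps the last partial batch (drop_last=False).
    return math.ceil(samples / batch_size)


def estimate_stepping_batches(num_batches_per_epoch, accumulate_grad_batches=1,
                              max_epochs=1, max_steps=-1):
    """Hypothetical, simplified estimate of optimizer steps for a full run."""
    if max_epochs == -1:
        # Epochs are unbounded, so only max_steps limits training.
        return float("inf") if max_steps == -1 else max_steps
    # One optimizer step per `accumulate_grad_batches` batches; a partial
    # accumulation window at the end of an epoch still counts as a step.
    steps_per_epoch = math.ceil(num_batches_per_epoch / accumulate_grad_batches)
    total = steps_per_epoch * max_epochs
    # An explicit max_steps caps the estimate.
    if max_steps != -1:
        total = min(total, max_steps)
    return total


# 1000 samples, batch size 32, 2 processes, 4-step accumulation, 3 epochs.
per_process = batches_per_process(1000, 32, world_size=2)
print(per_process, estimate_stepping_batches(per_process, accumulate_grad_batches=4, max_epochs=3))
# 16 12
```

In this example each process sees 16 batches per epoch; accumulating over 4 batches gives 4 optimizer steps per epoch, or 12 over 3 epochs.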

finetune(model, train_dataloader=None, val_dataloaders=None, datamodule=None, strategy='no_freeze', train_bn=True)

Runs the full optimization routine. Same as fit(), but freezes or unfreezes layers of the backbone during training according to strategy.

Parameters

  • model (LightningModule) – Model to fit.

  • train_dataloader (Optional[DataLoader]) – A PyTorch DataLoader with training samples. If the model has a predefined train_dataloader method, this will be skipped.

  • val_dataloaders (Union[DataLoader, List[DataLoader], None]) – Either a single PyTorch DataLoader or a list of them, specifying validation samples. If the model has a predefined val_dataloader method, this will be skipped.

  • datamodule (Optional[LightningDataModule]) – An instance of LightningDataModule.

  • strategy (Union[str, BaseFinetuning, Tuple[str, int], Tuple[str, Tuple[int, int]]]) –

    Either a string, a tuple, or a finetuning callback subclassing pytorch_lightning.callbacks.BaseFinetuning.

    The built-in strategies are selected with these inputs:

    • "no_freeze"

    • "freeze"

    • ("freeze_unfreeze", integer: unfreeze_epoch)

    • ("unfreeze_milestones", ((integer: unfreeze_epoch_num_layers, integer: unfreeze_epoch_all_layers), integer: num_layers))

    where each integer placeholder can be any integer. By default, the "no_freeze" strategy is used.

  • train_bn (bool) – Whether to train the Batch Norm layers.
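The string and tuple forms of strategy can be illustrated with a small, hypothetical helper, trainable_backbone_layers, that reports how many backbone layers would be trainable at a given epoch. This is a sketch of the documented semantics only, not Flash's implementation, and it does not cover BaseFinetuning callbacks. It assumes that "freeze_unfreeze" unfreezes the whole backbone at unfreeze_epoch, and that "unfreeze_milestones" first unfreezes the last num_layers layers and later the remaining ones.

```python
from typing import Tuple, Union

# Hypothetical alias for the string and tuple forms described above.
Strategy = Union[str, Tuple[str, int], Tuple[str, Tuple[Tuple[int, int], int]]]


def trainable_backbone_layers(strategy: Strategy, epoch: int, total_layers: int) -> int:
    """Hypothetical helper: number of trainable backbone layers (counted from the end)."""
    if isinstance(strategy, str):
        if strategy == "no_freeze":
            return total_layers  # the backbone is trained from the start
        if strategy == "freeze":
            return 0  # the backbone stays frozen for the whole run
        raise ValueError(f"Unsupported strategy: {strategy!r}")

    name, args = strategy
    if name == "freeze_unfreeze":
        unfreeze_epoch = args
        # Frozen until unfreeze_epoch, fully trainable afterwards.
        return total_layers if epoch >= unfreeze_epoch else 0
    if name == "unfreeze_milestones":
        (epoch_num_layers, epoch_all_layers), num_layers = args
        if epoch >= epoch_all_layers:
            return total_layers  # second milestone: everything is unfrozen
        if epoch >= epoch_num_layers:
            return min(num_layers, total_layers)  # first milestone: last num_layers layers
        return 0
    raise ValueError(f"Unsupported strategy: {strategy!r}")


schedule = ("unfreeze_milestones", ((2, 5), 3))
print([trainable_backbone_layers(schedule, epoch, total_layers=10) for epoch in range(7)])
# [0, 0, 3, 3, 3, 10, 10]
```

Passing the schedule as nested tuples keeps the milestone epochs and the layer count together, which mirrors the shape of the ("unfreeze_milestones", ...) input listed above.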

fit(model, train_dataloader=None, val_dataloaders=None, datamodule=None)

Runs the full optimization routine. Same as pytorch_lightning.Trainer.fit().

Parameters

  • datamodule (Optional[LightningDataModule]) – An instance of LightningDataModule.

  • model (LightningModule) – Model to fit.

  • train_dataloader (Optional[DataLoader]) – A PyTorch DataLoader with training samples. If the model has a predefined train_dataloader method, this will be skipped.

  • val_dataloaders (Union[DataLoader, List[DataLoader], None]) – Either a single PyTorch DataLoader or a list of them, specifying validation samples. If the model has a predefined val_dataloader method, this will be skipped.

classmethod from_argparse_args(args, **kwargs)

Modified version of pytorch_lightning.utilities.argparse.from_argparse_args() which populates valid_kwargs from pytorch_lightning.Trainer.

Return type

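The reason for the modified version can be sketched without Lightning installed. Because flash.core.trainer.Trainer is declared as Trainer(*args, **kwargs), inspecting the subclass signature would not reveal which keyword arguments are valid, so the valid names are read from the parent trainer's signature instead. In this minimal sketch, BaseTrainer and FlashLikeTrainer are hypothetical stand-ins for pytorch_lightning.Trainer and the Flash subclass; the filtering logic is an assumption about the general approach, not a copy of either library's code.

```python
import inspect
from argparse import Namespace


class BaseTrainer:
    """Hypothetical stand-in for pytorch_lightning.Trainer with explicit keyword arguments."""

    def __init__(self, max_epochs=1, gpus=None):
        self.max_epochs = max_epochs
        self.gpus = gpus


class FlashLikeTrainer(BaseTrainer):
    """Hypothetical subclass whose signature is only (*args, **kwargs)."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    @classmethod
    def from_argparse_args(cls, args, **kwargs):
        params = vars(args)
        # inspect.signature(cls.__init__) would only show *args/**kwargs here,
        # so take the valid keyword names from the parent class instead.
        valid_kwargs = inspect.signature(BaseTrainer.__init__).parameters
        trainer_kwargs = {name: params[name] for name in valid_kwargs if name in params}
        # Explicit keyword arguments override values parsed from the command line.
        trainer_kwargs.update(**kwargs)
        return cls(**trainer_kwargs)


# learning_rate is not a trainer argument, so it is filtered out.
args = Namespace(max_epochs=5, gpus=None, learning_rate=0.01)
trainer = FlashLikeTrainer.from_argparse_args(args, gpus=1)
print(trainer.max_epochs, trainer.gpus)
# 5 1
```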

predict(model=None, dataloaders=None, output=None, **kwargs)

Run inference on your data.

This calls the model's forward function to compute predictions, which is useful for distributed and batched prediction. Logging is disabled in the prediction hooks.


Returns a list of dictionaries, one for each provided dataloader, containing their respective predictions.