
Trainer

class flash.core.trainer.Trainer(*args, **kwargs)[source]

Extended Trainer for Flash tasks.

>>> Trainer()  
<...trainer.Trainer object at ...>
classmethod add_argparse_args(*args, **kwargs)[source]

See pytorch_lightning.utilities.argparse.add_argparse_args().

Return type

ArgumentParser
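
A minimal sketch of wiring the Trainer's arguments into an argument parser; the parser description is illustrative.

from argparse import ArgumentParser

from flash import Trainer

parser = ArgumentParser(description="Train a Flash task")
# Adds the supported Trainer constructor arguments (e.g. --max_epochs) as CLI flags.
parser = Trainer.add_argparse_args(parser)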

property estimated_stepping_batches: Union[int, float]

Estimated stepping batches for the complete training inferred from DataLoaders, gradient accumulation factor and distributed setup.

Examples

import torch

# Inside your LightningModule:
def configure_optimizers(self):
    optimizer = ...
    # Size the one-cycle schedule to the total number of optimizer steps the Trainer expects.
    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer, max_lr=1e-3, total_steps=self.trainer.estimated_stepping_batches
    )
    return [optimizer], [scheduler]
Return type

Union[int, float]

finetune(model, train_dataloader=None, val_dataloaders=None, datamodule=None, strategy='no_freeze', train_bn=True)[source]

Runs the full optimization routine. Same as pytorch_lightning.Trainer.fit(), but unfreezes layers of the backbone throughout training.

Parameters
  • model (LightningModule) – Model to fit.

  • train_dataloader (Optional[DataLoader]) – A PyTorch DataLoader with training samples. If the model has a predefined train_dataloader method this will be skipped.

  • val_dataloaders (Union[DataLoader, List[DataLoader], None]) – Either a single PyTorch DataLoader or a list of them, specifying validation samples. If the model has a predefined val_dataloaders method this will be skipped.

  • datamodule (Optional[LightningDataModule]) – An instance of LightningDataModule.

  • strategy (Union[str, BaseFinetuning, Tuple[str, int], Tuple[str, Tuple[int, int]]]) –

    Should be a string, a tuple, or a finetuning callback subclassing pytorch_lightning.callbacks.BaseFinetuning.

    Default strategies can be enabled with these inputs:

    • "no_freeze"

    • "freeze"

    • ("freeze_unfreeze", integer: unfreeze_epoch)

    • ("unfreeze_milestones", ((integer: unfreeze_epoch_num_layers, integer: unfreeze_epoch_all_layers), integer: num_layers))

    where integer can be any integer. By default, the "no_freeze" strategy is used (see the usage sketch after this parameter list).

  • train_bn (bool) – Whether to train Batch Norm layers.
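
A minimal usage sketch of finetune, assuming the image classification task from flash.image; the backbone name, folder paths, and batch size are illustrative, and the strategy values follow the options listed above.

from flash import Trainer
from flash.image import ImageClassificationData, ImageClassifier

datamodule = ImageClassificationData.from_folders(
    train_folder="data/train", val_folder="data/val", batch_size=4
)
model = ImageClassifier(backbone="resnet18", num_classes=datamodule.num_classes)

trainer = Trainer(max_epochs=3)
# Keep the backbone frozen for the whole run:
trainer.finetune(model, datamodule=datamodule, strategy="freeze")
# Or unfreeze the backbone after the first epoch:
# trainer.finetune(model, datamodule=datamodule, strategy=("freeze_unfreeze", 1))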

fit(model, train_dataloader=None, val_dataloaders=None, datamodule=None)[source]

Runs the full optimization routine. Same as pytorch_lightning.Trainer.fit().

Parameters
  • datamodule (Optional[LightningDataModule]) – An instance of LightningDataModule.

  • model (LightningModule) – Model to fit.

  • train_dataloader (Optional[DataLoader]) – A PyTorch DataLoader with training samples. If the model has a predefined train_dataloader method this will be skipped.

  • val_dataloaders (Union[DataLoader, List[DataLoader], None]) – Either a single PyTorch DataLoader or a list of them, specifying validation samples. If the model has a predefined val_dataloaders method this will be skipped.
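
A corresponding sketch of fit, reusing the illustrative model and datamodule from the finetune example above; no backbone freezing is applied here.

trainer = Trainer(max_epochs=3)
trainer.fit(model, datamodule=datamodule)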

classmethod from_argparse_args(args, **kwargs)[source]

Modified version of pytorch_lightning.utilities.argparse.from_argparse_args() which populates valid_kwargs from pytorch_lightning.Trainer.

Return type

Trainer
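
Continuing the argparse sketch above, the parsed namespace can be used to construct a configured Trainer; the argument values are illustrative.

args = parser.parse_args(["--max_epochs", "3"])
trainer = Trainer.from_argparse_args(args)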

predict(model=None, dataloaders=None, output=None, **kwargs)[source]

Run inference on your data.

This will call the model forward function to compute predictions. Useful to perform distributed and batched predictions. Logging is disabled in the prediction hooks.

Parameters
  • model (Optional[LightningModule]) – The model on which to run inference. If None, the model currently attached to the Trainer is used.

  • dataloaders – A PyTorch DataLoader (or a collection of them) or a LightningDataModule providing prediction samples.

  • output – The output (or the name of a registered output) used to convert the raw predictions.

Returns

Returns a list of dictionaries, one for each provided dataloader containing their respective predictions.
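
A minimal prediction sketch, reusing the illustrative ImageClassifier and Trainer from the finetune example above; the predict folder is a placeholder, and the DataModule is forwarded to the underlying Lightning predict call.

predict_datamodule = ImageClassificationData.from_folders(
    predict_folder="data/predict", batch_size=4
)
predictions = trainer.predict(model, datamodule=predict_datamodule)
print(predictions)  # one list of batch predictions per dataloader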