Trainer

class flash.core.trainer.Trainer(*args, serve_sanity_check=False, **kwargs)[source]
classmethod add_argparse_args(*args, **kwargs)[source]

See pytorch_lightning.utilities.argparse.add_argparse_args().

Return type

ArgumentParser

finetune(model, train_dataloader=None, val_dataloaders=None, datamodule=None, strategy='no_freeze', train_bn=True)[source]

Runs the full optimization routine. Same as pytorch_lightning.Trainer.fit(), but unfreezes layers of the backbone throughout training.

Parameters
  • model (LightningModule) – Model to fit.

  • train_dataloader (Optional[DataLoader]) – A PyTorch DataLoader with training samples. If the model has a predefined train_dataloader method this will be skipped.

  • val_dataloaders (Union[DataLoader, List[DataLoader], None]) – Either a single PyTorch DataLoader or a list of them, specifying validation samples. If the model has a predefined val_dataloaders method this will be skipped.

  • datamodule (Optional[LightningDataModule]) – An instance of LightningDataModule.

  • strategy (Union[str, BaseFinetuning, Tuple[str, int], Tuple[str, Tuple[int, int]]]) –

    Should be a string, a tuple, or a finetuning callback subclassing pytorch_lightning.callbacks.BaseFinetuning.

    Default strategies can be enabled with these inputs:

    • "no_freeze"

    • "freeze"

    • ("freeze_unfreeze", integer: unfreeze_epoch)

    • ("unfreeze_milestones", ((integer: unfreeze_epoch_num_layers, integer: unfreeze_epoch_all_layers), integer: num_layers))

    where each integer can be any integer. By default, the "no_freeze" strategy is used.

  • train_bn (bool) – Whether to train Batch Norm layers.
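The four accepted shapes of the strategy argument can be illustrated with a small stdlib-only sketch. describe_strategy below is a hypothetical helper, not part of Flash, that simply recognizes and describes the documented forms:

```python
# Hypothetical helper (not part of Flash) that recognizes the four
# documented shapes of the ``strategy`` argument to ``Trainer.finetune``.

def describe_strategy(strategy):
    """Return a short description of a finetuning strategy input."""
    # Plain string strategies: "no_freeze" or "freeze".
    if strategy in ("no_freeze", "freeze"):
        return f"string strategy: {strategy}"
    # ("freeze_unfreeze", unfreeze_epoch)
    if (
        isinstance(strategy, tuple)
        and strategy[0] == "freeze_unfreeze"
        and isinstance(strategy[1], int)
    ):
        return f"unfreeze the backbone at epoch {strategy[1]}"
    # ("unfreeze_milestones", ((epoch_a, epoch_b), num_layers))
    if (
        isinstance(strategy, tuple)
        and strategy[0] == "unfreeze_milestones"
        and isinstance(strategy[1], tuple)
    ):
        (first_epoch, all_epoch), num_layers = strategy[1]
        return (
            f"unfreeze the last {num_layers} layers at epoch {first_epoch}, "
            f"all layers at epoch {all_epoch}"
        )
    raise ValueError(f"unsupported strategy: {strategy!r}")

print(describe_strategy("no_freeze"))
print(describe_strategy(("freeze_unfreeze", 10)))
print(describe_strategy(("unfreeze_milestones", ((5, 10), 2))))
```

A BaseFinetuning subclass can also be passed directly when none of these presets fit.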

fit(model, train_dataloader=None, val_dataloaders=None, datamodule=None)[source]

Runs the full optimization routine. Same as pytorch_lightning.Trainer.fit().

Parameters
  • datamodule (Optional[LightningDataModule]) – An instance of LightningDataModule.

  • model (LightningModule) – Model to fit.

  • train_dataloader (Optional[DataLoader]) – A PyTorch DataLoader with training samples. If the model has a predefined train_dataloader method this will be skipped.

  • val_dataloaders (Union[DataLoader, List[DataLoader], None]) – Either a single PyTorch DataLoader or a list of them, specifying validation samples. If the model has a predefined val_dataloaders method this will be skipped.
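The "predefined method wins" rule above can be sketched in plain Python. This is a hypothetical illustration of the documented behavior, not Flash's actual implementation:

```python
# Hypothetical sketch (not Flash's implementation) of the documented rule:
# an explicitly passed train_dataloader is skipped when the model already
# defines its own ``train_dataloader`` method.

class ModelWithLoader:
    def train_dataloader(self):
        return "model-defined loader"

class ModelWithoutLoader:
    pass

def resolve_train_dataloader(model, train_dataloader=None):
    """Prefer the model's own loader when it defines one, else the passed one."""
    if callable(getattr(model, "train_dataloader", None)):
        return model.train_dataloader()
    return train_dataloader

print(resolve_train_dataloader(ModelWithLoader(), "passed loader"))    # model wins
print(resolve_train_dataloader(ModelWithoutLoader(), "passed loader"))  # argument used
```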

classmethod from_argparse_args(args, **kwargs)[source]

Modified version of pytorch_lightning.utilities.argparse.from_argparse_args() which populates valid_kwargs from pytorch_lightning.Trainer.

Return type

Trainer
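The add_argparse_args / from_argparse_args pair implements a common round-trip pattern: constructor keywords become CLI flags, and the parsed namespace is filtered back down to valid constructor keywords. The stdlib-only sketch below uses a stand-in TinyTrainer class to illustrate the pattern; it is not Flash's actual implementation:

```python
import argparse
import inspect

class TinyTrainer:
    """Stand-in class illustrating the argparse round-trip pattern."""

    def __init__(self, max_epochs=10, gpus=0):
        self.max_epochs = max_epochs
        self.gpus = gpus

    @classmethod
    def add_argparse_args(cls, parser):
        # Add one CLI flag per constructor keyword, reusing its default.
        for name, param in inspect.signature(cls.__init__).parameters.items():
            if name == "self":
                continue
            parser.add_argument(f"--{name}", type=int, default=param.default)
        return parser

    @classmethod
    def from_argparse_args(cls, args, **kwargs):
        # Keep only the parsed values the constructor actually accepts
        # (the "valid_kwargs" filtering mentioned in the docstring above).
        valid = set(inspect.signature(cls.__init__).parameters) - {"self"}
        init_kwargs = {k: v for k, v in vars(args).items() if k in valid}
        init_kwargs.update(kwargs)
        return cls(**init_kwargs)

parser = TinyTrainer.add_argparse_args(argparse.ArgumentParser())
args = parser.parse_args(["--max_epochs", "5"])
trainer = TinyTrainer.from_argparse_args(args)
print(trainer.max_epochs, trainer.gpus)  # 5 0
```

Flash's version additionally merges keyword names accepted by pytorch_lightning.Trainer, so Lightning's own flags remain usable from the command line.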

request_dataloader(*args, **kwargs)[source]

Handles downloading data in the GPU or TPU case.

Return type

Union[DataLoader, List[DataLoader]]

Returns

The dataloader
