Trainer
- class flash.core.trainer.Trainer(*args, **kwargs)
Extended Trainer for Flash tasks.
>>> Trainer()
<...trainer.Trainer object at ...>
- classmethod add_argparse_args(*args, **kwargs)
See pytorch_lightning.utilities.argparse.add_argparse_args().
- Return type
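As a rough illustration of what generating command-line arguments from a constructor involves, the sketch below adds one `--flag` per defaulted `__init__` parameter using only the standard library. It is a simplified, hypothetical stand-in: `ToyTrainer` and `build_parser_from_init` are assumptions made for this example, not the Lightning implementation, which handles many more argument types.

```python
import argparse
import inspect


class ToyTrainer:
    """Hypothetical stand-in for a trainer configured through keyword arguments."""

    def __init__(self, max_epochs: int = 1, precision: int = 32, fast_dev_run: bool = False):
        self.max_epochs = max_epochs
        self.precision = precision
        self.fast_dev_run = fast_dev_run


def _str_to_bool(value: str) -> bool:
    # argparse's type=bool would treat any non-empty string as True, so parse explicitly.
    return value.lower() in {"1", "true", "yes"}


def build_parser_from_init(cls) -> argparse.ArgumentParser:
    """Hypothetical helper: add one --flag per __init__ parameter that has a default."""
    parser = argparse.ArgumentParser()
    for name, param in inspect.signature(cls.__init__).parameters.items():
        if name == "self" or param.default is inspect.Parameter.empty:
            continue
        default = param.default
        # Check bool before falling back to type(default), since bool("false") is True.
        arg_type = _str_to_bool if isinstance(default, bool) else type(default)
        parser.add_argument(f"--{name}", type=arg_type, default=default)
    return parser


parser = build_parser_from_init(ToyTrainer)
args = parser.parse_args(["--max_epochs", "5", "--fast_dev_run", "true"])
print(args.max_epochs, args.precision, args.fast_dev_run)  # 5 32 True
```

Parameters without a default are skipped here because they have no natural command-line default; a real implementation may treat them differently.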
- property estimated_stepping_batches: Union[int, float]
Estimated number of stepping batches (optimizer steps) for the complete training run, inferred from the DataLoaders, the gradient accumulation factor, and the distributed setup.
Examples
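```python
import torch


def configure_optimizers(self):
    optimizer = ...  # placeholder: any torch.optim optimizer over self.parameters()
    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer,
        max_lr=1e-3,
        total_steps=self.trainer.estimated_stepping_batches,
    )
    # OneCycleLR expects one scheduler step per optimizer step, so schedule it per step
    return [optimizer], [{"scheduler": scheduler, "interval": "step"}]
```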
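To make the inference concrete, the arithmetic can be approximated as below. This is a simplified sketch under stated assumptions (data sharded evenly across devices without `drop_last`, a fixed accumulation factor, and an optional `max_steps` cap); `estimate_stepping_batches` is a hypothetical helper, not the property's actual implementation.

```python
import math


def estimate_stepping_batches(
    num_samples: int,
    batch_size: int,
    world_size: int = 1,
    accumulate_grad_batches: int = 1,
    max_epochs: int = 1,
    max_steps: int = -1,
) -> int:
    """Hypothetical, simplified estimate of optimizer steps for a whole training run."""
    # Assumption: each device sees an equal shard of the data (padded, no drop_last).
    samples_per_device = math.ceil(num_samples / world_size)
    batches_per_epoch = math.ceil(samples_per_device / batch_size)
    # With gradient accumulation, the optimizer steps once every `accumulate_grad_batches` batches.
    steps_per_epoch = math.ceil(batches_per_epoch / accumulate_grad_batches)
    total = steps_per_epoch * max(max_epochs, 1)
    # A max_steps value other than -1 caps the run.
    return min(total, max_steps) if max_steps != -1 else total


# 10_000 samples, batch size 32, 4 devices, accumulate 2 batches, 10 epochs
print(estimate_stepping_batches(10_000, 32, world_size=4, accumulate_grad_batches=2, max_epochs=10))  # 400
```

Here each device gets 2,500 samples, which is 79 batches per epoch; accumulating 2 batches per step gives 40 optimizer steps per epoch and 400 over 10 epochs.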
- finetune(model, train_dataloader=None, val_dataloaders=None, datamodule=None, strategy='no_freeze', train_bn=True)
Runs the full optimization routine. Same as pytorch_lightning.Trainer.fit(), but unfreezes layers of the backbone throughout training.
- Parameters
  - model (LightningModule) – Model to fit.
  - train_dataloader (Optional[DataLoader]) – A PyTorch DataLoader with training samples. If the model has a predefined train_dataloader method, this will be skipped.
  - val_dataloaders (Union[DataLoader, List[DataLoader], None]) – Either a single PyTorch DataLoader or a list of them, specifying validation samples. If the model has a predefined val_dataloaders method, this will be skipped.
  - datamodule (Optional[LightningDataModule]) – An instance of LightningDataModule.
  - strategy (Union[str, BaseFinetuning, Tuple[str, int], Tuple[str, Tuple[int, int]]]) – Should be either a string, a tuple, or a finetuning callback subclassing pytorch_lightning.callbacks.BaseFinetuning. Default strategies can be enabled with these inputs:
    - "no_freeze"
    - "freeze"
    - ("freeze_unfreeze", integer: unfreeze_epoch)
    - ("unfreeze_milestones", ((integer: unfreeze_epoch_num_layers, integer: unfreeze_epoch_all_layers), integer: num_layers))

    where integer can be any integer. By default, the no_freeze strategy is used.
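The accepted strategy inputs have three distinct shapes (plain string, string plus epoch, string plus nested milestones), which can be easy to mix up. The sketch below is a hypothetical shape checker, written only to illustrate those shapes; `describe_strategy` is not part of Flash, and the one-line summaries are paraphrased from the strategy and parameter names listed above.

```python
from typing import Tuple, Union

# Hypothetical alias mirroring the built-in strategy inputs described above.
StrategyInput = Union[str, Tuple[str, int], Tuple[str, Tuple[Tuple[int, int], int]]]


def describe_strategy(strategy: StrategyInput) -> str:
    """Hypothetical helper: summarize a built-in finetuning strategy input."""
    if strategy == "no_freeze":
        return "train the whole model, backbone included"
    if strategy == "freeze":
        return "keep the backbone frozen for the entire run"
    if isinstance(strategy, tuple) and len(strategy) == 2:
        name, arg = strategy
        if name == "freeze_unfreeze" and isinstance(arg, int):
            return f"freeze the backbone, then unfreeze it at epoch {arg}"
        if name == "unfreeze_milestones" and isinstance(arg, tuple) and len(arg) == 2:
            # ((unfreeze_epoch_num_layers, unfreeze_epoch_all_layers), num_layers)
            milestones, num_layers = arg
            epoch_num_layers, epoch_all_layers = milestones
            return (
                f"unfreeze {num_layers} backbone layers at epoch {epoch_num_layers}, "
                f"then all layers at epoch {epoch_all_layers}"
            )
    raise ValueError(f"Unsupported strategy: {strategy!r}")


print(describe_strategy(("unfreeze_milestones", ((5, 10), 3))))
```

In real use these values are passed straight to the method, for example trainer.finetune(model, datamodule=datamodule, strategy=("freeze_unfreeze", 2)), where model and datamodule stand in for your own objects.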
- fit(model, train_dataloader=None, val_dataloaders=None, datamodule=None)
Runs the full optimization routine. Same as pytorch_lightning.Trainer.fit().
- Parameters
  - datamodule (Optional[LightningDataModule]) – An instance of LightningDataModule.
  - model (LightningModule) – Model to fit.
  - train_dataloader (Optional[DataLoader]) – A PyTorch DataLoader with training samples. If the model has a predefined train_dataloader method, this will be skipped.
  - val_dataloaders (Union[DataLoader, List[DataLoader], None]) – Either a single PyTorch DataLoader or a list of them, specifying validation samples. If the model has a predefined val_dataloaders method, this will be skipped.
- classmethod from_argparse_args(args, **kwargs)
Modified version of pytorch_lightning.utilities.argparse.from_argparse_args() which populates valid_kwargs from pytorch_lightning.Trainer.
- Return type
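The core idea of populating valid keyword arguments, keeping only the parsed options that the target constructor accepts, can be sketched with the standard library alone. `ToyTrainer` and `from_argparse_args_sketch` are hypothetical; per the description above, the Flash method takes the valid parameters from pytorch_lightning.Trainer rather than from a toy class.

```python
import argparse
import inspect


class ToyTrainer:
    """Hypothetical stand-in for a trainer class."""

    def __init__(self, max_epochs: int = 1, precision: int = 32):
        self.max_epochs = max_epochs
        self.precision = precision


def from_argparse_args_sketch(cls, args: argparse.Namespace, **kwargs):
    """Hypothetical helper: build cls from the Namespace entries its __init__ accepts."""
    valid_kwargs = set(inspect.signature(cls.__init__).parameters) - {"self"}
    params = {key: value for key, value in vars(args).items() if key in valid_kwargs}
    params.update(kwargs)  # explicit keyword arguments take precedence over parsed ones
    return cls(**params)


# The namespace may also carry unrelated options, such as a model learning rate.
args = argparse.Namespace(max_epochs=3, precision=16, learning_rate=0.01)
trainer = from_argparse_args_sketch(ToyTrainer, args, precision=32)
print(trainer.max_epochs, trainer.precision)  # 3 32
```

Filtering by signature lets a single parser hold both trainer and model options without the unrelated ones (here `learning_rate`) reaching the trainer's constructor.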
- predict(model=None, dataloaders=None, output=None, **kwargs)
Run inference on your data.
This calls the model's forward function to compute predictions, which is useful for distributed and batched prediction. Logging is disabled in the prediction hooks.
- Parameters
  - model (Optional[LightningModule]) – The model to predict with.
  - dataloaders (Union[DataLoader, LightningDataModule, None]) – A torch.utils.data.DataLoader or a sequence of them, or a LightningDataModule specifying prediction samples.
  - output (Union[str, Output, None]) – The Output to use to transform the prediction outputs.
  - kwargs – Additional keyword arguments to pass to pytorch_lightning.Trainer.predict().
- Returns
A list of dictionaries, one for each provided dataloader, containing their respective predictions.
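The output argument post-processes raw model predictions into a user-facing form. As a minimal, hypothetical sketch of that idea, the snippet maps per-class scores to labels one prediction at a time; `ToyLabelsOutput` and its `transform` method are assumptions for illustration, not Flash's Output class or its interface.

```python
from typing import List, Sequence


class ToyLabelsOutput:
    """Hypothetical output transform: per-class scores -> best-scoring label."""

    def __init__(self, labels: Sequence[str]):
        self.labels = list(labels)

    def transform(self, scores: Sequence[float]) -> str:
        # Pick the index of the highest score and look up its label.
        best = max(range(len(scores)), key=scores.__getitem__)
        return self.labels[best]


raw_predictions: List[List[float]] = [[0.1, 0.7, 0.2], [0.8, 0.1, 0.1]]
output = ToyLabelsOutput(["cat", "dog", "bird"])
print([output.transform(p) for p in raw_predictions])  # ['dog', 'cat']
```

Keeping the transform separate from the model means the same trained model can return raw scores, labels, or another format just by swapping the output object.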