Trainer
- class flash.core.trainer.Trainer(*args, **kwargs)
Extended Trainer for Flash tasks.
>>> Trainer()
<...trainer.Trainer object at ...>
- classmethod add_argparse_args(*args, **kwargs)
See pytorch_lightning.utilities.argparse.add_argparse_args().
- Return type
- property estimated_stepping_batches: Union[int, float]
The estimated number of stepping batches for the complete training run, inferred from the DataLoaders, the gradient accumulation factor, and the distributed setup.
Examples
    def configure_optimizers(self):
        optimizer = ...
        scheduler = torch.optim.lr_scheduler.OneCycleLR(
            optimizer, max_lr=1e-3, total_steps=self.trainer.estimated_stepping_batches
        )
        return [optimizer], [scheduler]
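As a rough illustration of how the inputs named above can combine into a single step count, here is a simplified sketch. It is not the library's implementation: `estimate_stepping_batches` and all of its parameters are hypothetical, `batches_per_epoch` is assumed to already be the per-process batch count after any distributed sharding, and the `float` case in the declared return type is not modelled.

```python
import math


def estimate_stepping_batches(
    batches_per_epoch: int,
    accumulate_grad_batches: int = 1,
    max_epochs: int = 1,
    max_steps: int = -1,
) -> int:
    """Rough estimate of optimizer steps for a full training run.

    Hypothetical helper for illustration only; the real property also
    accounts for the distributed setup and for unsized dataloaders.
    """
    # One optimizer step is taken per group of accumulated batches.
    steps_per_epoch = math.ceil(batches_per_epoch / accumulate_grad_batches)
    total = steps_per_epoch * max(max_epochs, 1)
    # An explicit step limit caps the estimate (-1 means "no limit" here).
    return total if max_steps == -1 else min(total, max_steps)


print(estimate_stepping_batches(100, accumulate_grad_batches=4, max_epochs=3))
```

Under these assumptions, a scheduler such as `OneCycleLR` would receive the number of optimizer steps rather than the number of raw batches, which is why gradient accumulation has to be part of the estimate.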
- finetune(model, train_dataloader=None, val_dataloaders=None, datamodule=None, strategy='no_freeze', train_bn=True)
Runs the full optimization routine. Same as pytorch_lightning.Trainer.fit(), but unfreezes layers of the backbone throughout training.
- Parameters
  model (LightningModule) – Model to fit.
  train_dataloader (Optional[DataLoader]) – A PyTorch DataLoader with training samples. If the model has a predefined train_dataloader method, this will be skipped.
  val_dataloaders (Union[DataLoader, List[DataLoader], None]) – Either a single PyTorch DataLoader or a list of them, specifying validation samples. If the model has a predefined val_dataloaders method, this will be skipped.
  datamodule (Optional[LightningDataModule]) – An instance of LightningDataModule.
  strategy (Union[str, BaseFinetuning, Tuple[str, int], Tuple[str, Tuple[int, int]]]) – Should be either a string, a tuple, or a finetuning callback subclassing pytorch_lightning.callbacks.BaseFinetuning. Default strategies can be enabled with these inputs:
    "no_freeze"
    "freeze"
    ("freeze_unfreeze", integer: unfreeze_epoch)
    ("unfreeze_milestones", ((integer: unfreeze_epoch_num_layers, integer: unfreeze_epoch_all_layers), integer: num_layers))
  where integer can be any integer. By default, the no_freeze strategy will be used.
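The four default strategy inputs listed above differ only in shape, so a small sketch may help tell them apart. The helper below is hypothetical (`describe_strategy` is not part of Flash) and only mirrors the textual forms in the list; it does not handle `BaseFinetuning` callback instances, and it makes no claim about how Flash itself validates the argument.

```python
def describe_strategy(strategy) -> str:
    """Hypothetical helper: report which documented strategy form was passed."""
    if isinstance(strategy, str):
        if strategy in ("no_freeze", "freeze"):
            return f"string strategy {strategy!r}"
        raise ValueError(f"unknown string strategy: {strategy!r}")

    if isinstance(strategy, tuple) and len(strategy) == 2:
        name, arg = strategy
        if name == "freeze_unfreeze" and isinstance(arg, int):
            return f"freeze_unfreeze with unfreeze_epoch={arg}"
        if name == "unfreeze_milestones":
            # Documented shape: ((epoch_num_layers, epoch_all_layers), num_layers)
            (epoch_num_layers, epoch_all_layers), num_layers = arg
            return (
                "unfreeze_milestones with "
                f"unfreeze_epoch_num_layers={epoch_num_layers}, "
                f"unfreeze_epoch_all_layers={epoch_all_layers}, "
                f"num_layers={num_layers}"
            )

    raise ValueError(f"unsupported strategy: {strategy!r}")


examples = [
    "no_freeze",
    "freeze",
    ("freeze_unfreeze", 5),
    ("unfreeze_milestones", ((5, 10), 2)),
]
for example in examples:
    print(describe_strategy(example))
```

Any of the values in `examples` is the kind of object that would be passed as the `strategy` argument of `finetune`; the epoch and layer numbers here are arbitrary.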
- fit(model, train_dataloader=None, val_dataloaders=None, datamodule=None)
Runs the full optimization routine. Same as pytorch_lightning.Trainer.fit().
- Parameters
  datamodule (Optional[LightningDataModule]) – An instance of LightningDataModule.
  model (LightningModule) – Model to fit.
  train_dataloader (Optional[DataLoader]) – A PyTorch DataLoader with training samples. If the model has a predefined train_dataloader method, this will be skipped.
  val_dataloaders (Union[DataLoader, List[DataLoader], None]) – Either a single PyTorch DataLoader or a list of them, specifying validation samples. If the model has a predefined val_dataloaders method, this will be skipped.
- classmethod from_argparse_args(args, **kwargs)
Modified version of pytorch_lightning.utilities.argparse.from_argparse_args() which populates valid_kwargs from pytorch_lightning.Trainer.
- Return type
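The general pattern behind `add_argparse_args` and `from_argparse_args` can be sketched with the standard library alone. Everything below is an illustrative assumption: `TinyTrainer` and its two options are hypothetical stand-ins, and the filtering by constructor signature is one plausible reading of "populates valid_kwargs", not the Flash or PyTorch Lightning implementation.

```python
import argparse
import inspect


class TinyTrainer:
    """Hypothetical stand-in for a trainer class; not part of Flash."""

    def __init__(self, max_epochs: int = 1, fast_dev_run: bool = False):
        self.max_epochs = max_epochs
        self.fast_dev_run = fast_dev_run

    @classmethod
    def add_argparse_args(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
        # Expose each constructor argument as a command-line option.
        parser.add_argument("--max_epochs", type=int, default=1)
        parser.add_argument("--fast_dev_run", action="store_true")
        return parser

    @classmethod
    def from_argparse_args(cls, args: argparse.Namespace, **kwargs):
        # Keep only the namespace entries that the constructor accepts.
        valid_kwargs = inspect.signature(cls.__init__).parameters
        params = {k: v for k, v in vars(args).items() if k in valid_kwargs}
        params.update(kwargs)  # explicit keyword arguments take precedence
        return cls(**params)


parser = argparse.ArgumentParser()
parser.add_argument("--data_dir", default="data")  # unrelated option, filtered out
parser = TinyTrainer.add_argparse_args(parser)
args = parser.parse_args(["--max_epochs", "3"])
trainer = TinyTrainer.from_argparse_args(args, fast_dev_run=True)
print(trainer.max_epochs, trainer.fast_dev_run)
```

Filtering against the constructor signature lets one parser carry options for several components (here the unrelated `--data_dir`) without the trainer receiving arguments it does not understand.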
- predict(model=None, dataloaders=None, output=None, **kwargs)
Run inference on your data.
This will call the model forward function to compute predictions. Useful to perform distributed and batched predictions. Logging is disabled in the prediction hooks.
- Parameters
  model (Optional[LightningModule]) – The model to predict with.
  dataloaders (Union[DataLoader, LightningDataModule, None]) – A torch.utils.data.DataLoader or a sequence of them, or a LightningDataModule specifying prediction samples.
  output (Union[str, Output, None]) – The Output to use to transform predict outputs.
  kwargs – Additional keyword arguments to pass to predict().
- Returns
Returns a list of dictionaries, one for each provided dataloader containing their respective predictions.