TranslationTask
- class flash.text.seq2seq.translation.model.TranslationTask(backbone='t5-small', tokenizer_kwargs=None, max_source_length=128, max_target_length=128, padding='max_length', loss_fn=None, optimizer='Adam', lr_scheduler=None, metrics=None, learning_rate=None, num_beams=4, n_gram=4, smooth=True, enable_ort=False)
The TranslationTask is a Task for Seq2Seq text translation. For more details, see Translation.
You can change the backbone to any translation model from HuggingFace/transformers using the backbone argument.
- Parameters
max_source_length (int) – The maximum length to pad / truncate input sequences to.
max_target_length (int) – The maximum length to pad / truncate target sequences to.
padding (Union[str, bool]) – The type of padding to apply. One of: "longest" or True, "max_length", "do_not_pad" or False.
loss_fn (Optional[TypeVar(LOSS_FN_TYPE, Callable, Mapping, Sequence, None)]) – Loss function for training.
optimizer (TypeVar(OPTIMIZER_TYPE, str, Callable, Tuple[str, Dict[str, Any]], None)) – Optimizer to use for training.
lr_scheduler (Optional[TypeVar(LR_SCHEDULER_TYPE, str, Callable, Tuple[str, Dict[str, Any]], Tuple[str, Dict[str, Any], Dict[str, Any]], None)]) – The LR scheduler to use during training.
metrics (Optional[TypeVar(METRICS_TYPE, Metric, Mapping, Sequence, None)]) – Metrics to compute for training and evaluation. Defaults to calculating the BLEU metric. Changing this argument currently has no effect.
learning_rate (Optional[float]) – Learning rate to use for training. Defaults to 1e-5.
num_beams (Optional[int]) – Number of beams to use in validation when generating predictions. Defaults to 4.
n_gram (int) – Maximum n-gram order to use in metric calculation. Defaults to 4.
smooth (bool) – Apply smoothing in BLEU calculation. Defaults to True.
enable_ort (bool) – Enable Torch ONNX Runtime Optimization: https://onnxruntime.ai/docs/#onnx-runtime-for-training
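As a rough illustration of how max_source_length, max_target_length, and padding interact, the sketch below pads and truncates lists of token IDs in plain Python. The helper name pad_or_truncate and the pad ID 0 are assumptions made for this example only. In the task itself this work is delegated to the HuggingFace tokenizer, so treat this as a model of the behaviour rather than the actual implementation.

```python
from typing import List, Union


def pad_or_truncate(
    batch: List[List[int]],
    max_length: int,
    padding: Union[str, bool] = "max_length",
    pad_id: int = 0,  # hypothetical pad token ID, chosen for illustration
) -> List[List[int]]:
    """Truncate every sequence to max_length, then pad according to padding."""
    truncated = [ids[:max_length] for ids in batch]

    if padding in ("do_not_pad", False):
        # No padding at all: sequences keep their (truncated) lengths.
        return truncated
    if padding in ("longest", True):
        # Pad up to the longest sequence in this batch.
        target = max((len(ids) for ids in truncated), default=0)
    elif padding == "max_length":
        # Pad every sequence up to the fixed maximum length.
        target = max_length
    else:
        raise ValueError(f"Unknown padding strategy: {padding!r}")

    return [ids + [pad_id] * (target - len(ids)) for ids in truncated]


batch = [[1, 2], [3, 4, 5]]
print(pad_or_truncate(batch, max_length=4, padding="max_length"))  # [[1, 2, 0, 0], [3, 4, 5, 0]]
print(pad_or_truncate(batch, max_length=4, padding="longest"))     # [[1, 2, 0], [3, 4, 5]]
print(pad_or_truncate(batch, max_length=4, padding=False))         # [[1, 2], [3, 4, 5]]
print(pad_or_truncate([[1, 2, 3, 4, 5, 6]], max_length=4))         # [[1, 2, 3, 4]]
```

With padding="max_length" every batch has the same shape, which is convenient for fixed-size tensors, while "longest" keeps batches as short as the data allows.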
- classmethod available_finetuning_strategies(cls)
Returns a list containing the keys of the available Finetuning Strategies.
- classmethod available_lr_schedulers(cls)
Returns a list containing the keys of the available LR schedulers.
- classmethod available_optimizers(cls)
Returns a list containing the keys of the available Optimizers.
- classmethod available_outputs(cls)
Returns the list of available outputs (that can be used during prediction or serving) for this Task.
Examples
>>> from flash import Task
>>> print(Task.available_outputs())
['preds', 'raw']
- classmethod load_from_checkpoint(cls, checkpoint_path, map_location=None, hparams_file=None, strict=True, **kwargs)
Primary way of loading a model from a checkpoint. When Lightning saves a checkpoint, it stores the arguments passed to __init__ in the checkpoint under "hyper_parameters".
Any arguments specified through **kwargs will override args stored in "hyper_parameters".
- Parameters
checkpoint_path (Union[str, Path, IO]) – Path to checkpoint. This can also be a URL or a file-like object.
map_location (Union[device, str, int, Callable[[Union[device, str, int]], Union[device, str, int]], Dict[Union[device, str, int], Union[device, str, int]], None]) – If your checkpoint saved a GPU model and you now load on CPUs or a different number of GPUs, use this to map to the new setup. The behaviour is the same as in torch.load().
hparams_file (Union[str, Path, None]) – Optional path to a .yaml or .csv file with hierarchical structure as in this example:
    drop_prob: 0.2
    dataloader:
        batch_size: 32
You most likely won’t need this since Lightning will always save the hyperparameters to the checkpoint. However, if your checkpoint weights don’t have the hyperparameters saved, use this method to pass in a .yaml file with the hparams you’d like to use. These will be converted into a dict and passed into your LightningModule for use.
If your model’s hparams argument is Namespace and the .yaml file has hierarchical structure, you need to refactor your model to treat hparams as dict.
strict (bool) – Whether to strictly enforce that the keys in checkpoint_path match the keys returned by this module’s state dict.
**kwargs – Any extra keyword args needed to init the model. Can also be used to override saved hyperparameter values.
- Return type
Self
- Returns
LightningModule instance with loaded weights and hyperparameters (if available).
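To make the strict parameter more concrete, the following sketch compares the keys of a checkpoint with the keys a model expects and raises when they differ and strict is enabled. The function names compare_keys and check_keys are hypothetical, and plain lists of key names stand in for real state dicts. It approximates the idea of strict key matching and is not the PyTorch or Lightning implementation.

```python
from typing import Dict, Iterable, List


def compare_keys(checkpoint_keys: Iterable[str], model_keys: Iterable[str]) -> Dict[str, List[str]]:
    """Report keys the model expects but the checkpoint lacks, and vice versa."""
    checkpoint_set, model_set = set(checkpoint_keys), set(model_keys)
    return {
        # Expected by the model, absent from the checkpoint.
        "missing": sorted(model_set - checkpoint_set),
        # Present in the checkpoint, unknown to the model.
        "unexpected": sorted(checkpoint_set - model_set),
    }


def check_keys(checkpoint_keys: Iterable[str], model_keys: Iterable[str], strict: bool = True) -> Dict[str, List[str]]:
    """Raise on any mismatch when strict is True; otherwise just return the report."""
    report = compare_keys(checkpoint_keys, model_keys)
    if strict and (report["missing"] or report["unexpected"]):
        raise RuntimeError(f"State dict key mismatch: {report}")
    return report


saved = ["encoder.weight", "decoder.weight", "old_head.bias"]
expected = ["encoder.weight", "decoder.weight", "new_head.bias"]

# Non-strict loading tolerates the mismatch and reports it.
print(check_keys(saved, expected, strict=False))

# Strict loading refuses to continue.
try:
    check_keys(saved, expected, strict=True)
except RuntimeError as err:
    print(err)
```

Passing strict=False is mainly useful when a checkpoint is reused for a model whose layers have been partially renamed or replaced.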
Note
load_from_checkpoint is a class method. You should use your LightningModule class to call it instead of the LightningModule instance.
Example:
    # MyLightningModule, PATH, NEW_PATH, and x are placeholders for your own
    # module class, checkpoint paths, and input batch.

    # load weights without mapping
    model = MyLightningModule.load_from_checkpoint('path/to/checkpoint.ckpt')

    # or load weights mapping all weights from GPU 1 to GPU 0
    map_location = {'cuda:1': 'cuda:0'}
    model = MyLightningModule.load_from_checkpoint(
        'path/to/checkpoint.ckpt',
        map_location=map_location
    )

    # or load weights and hyperparameters from separate files
    model = MyLightningModule.load_from_checkpoint(
        'path/to/checkpoint.ckpt',
        hparams_file='/path/to/hparams_file.yaml'
    )

    # override some of the params with new values
    model = MyLightningModule.load_from_checkpoint(
        PATH,
        num_layers=128,
        pretrained_ckpt_path=NEW_PATH,
    )

    # predict with the loaded model
    model.eval()
    model.freeze()
    y_hat = model(x)
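The pattern described in the note, a class method that rebuilds a model from saved __init__ arguments and lets explicit keyword arguments take precedence, can be sketched without Lightning installed. ToyModule is a hypothetical stand-in for a LightningModule, and the checkpoint is an in-memory dict rather than a file on disk. Both are assumptions made purely for illustration.

```python
from typing import Any, Dict


class ToyModule:
    """Hypothetical stand-in for a LightningModule; not part of Lightning or Flash."""

    def __init__(self, num_layers: int = 2, drop_prob: float = 0.1) -> None:
        self.num_layers = num_layers
        self.drop_prob = drop_prob
        self.weights: Dict[str, Any] = {}

    @classmethod
    def load_from_checkpoint(cls, checkpoint: Dict[str, Any], **kwargs: Any) -> "ToyModule":
        # Start from the saved __init__ arguments, then let explicit kwargs win.
        init_args = {**checkpoint.get("hyper_parameters", {}), **kwargs}
        model = cls(**init_args)
        # Copy the saved weights onto the freshly constructed instance.
        model.weights = dict(checkpoint.get("state_dict", {}))
        return model


checkpoint = {
    "hyper_parameters": {"num_layers": 4, "drop_prob": 0.2},
    "state_dict": {"layer.weight": [0.1, 0.2]},
}

# Called on the class, not on an instance.
restored = ToyModule.load_from_checkpoint(checkpoint)
overridden = ToyModule.load_from_checkpoint(checkpoint, num_layers=128)

print(restored.num_layers, restored.drop_prob)      # 4 0.2
print(overridden.num_layers, overridden.drop_prob)  # 128 0.2
```

Because the method constructs a new object from the saved arguments, it needs no existing instance, which is why it is called on the class.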