
Translation

The Task

Translation is the task of translating text from a source language to a target language, such as English to Romanian. It is a subset of Sequence to Sequence (Seq2Seq) tasks, which require the model to generate a variable-length output sequence given an input sequence. In our case, the task takes an English sequence as input and outputs its Romanian translation.
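Here is what "variable length" means in practice. A Seq2Seq model usually produces its output one token at a time and stops when it emits an end-of-sequence token, so the model decides the output length rather than the input length fixing it. The toy sketch below shows that loop with greedy decoding. The EOS token, the lookup-table "model" and the function names are hypothetical stand-ins for a trained decoder and are not part of Flash.

```python
from typing import Callable, List

EOS = "</s>"  # hypothetical end-of-sequence token for this toy example


def greedy_decode(
    source: List[str],
    next_token: Callable[[List[str], List[str]], str],
    max_len: int = 50,
) -> List[str]:
    """Generate target tokens one at a time until EOS or max_len.

    `next_token` stands in for a trained decoder: given the source tokens and
    the target tokens produced so far, it returns the next target token.
    """
    output: List[str] = []
    for _ in range(max_len):
        token = next_token(source, output)
        if token == EOS:
            break  # the "model" decides when the translation is complete
        output.append(token)
    return output


# Hypothetical lookup-table "model" covering a single sentence pair.
TOY_TRANSLATIONS = {
    ("Closure", "of", "sitting"): ["Ridicarea", "şedinţei"],
}


def toy_next_token(source: List[str], output: List[str]) -> str:
    target = TOY_TRANSLATIONS[tuple(source)]
    # Emit the next stored token, or EOS once the stored translation is used up.
    return target[len(output)] if len(output) < len(target) else EOS


print(greedy_decode(["Closure", "of", "sitting"], toy_next_token))
```

Here three English tokens produce two Romanian tokens. The max_len cap stops generation if a model never emits EOS.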


Example

Let’s look at an example. We’ll use WMT16 English/Romanian, a dataset of English-to-Romanian sentence pairs based on the Europarl corpus. The dataset contains two files, train.csv and valid.csv, each formatted like this:

input,target
"Written statements and oral questions (tabling): see Minutes","Declaraţii scrise şi întrebări orale (depunere): consultaţi procesul-verbal"
"Closure of sitting","Ridicarea şedinţei"
...

In the CSV above, the input and target columns hold the English source text and its Romanian translation, respectively. Once we’ve downloaded the data using download_data(), we create the TranslationData from the two CSV files. We then select a pre-trained backbone for our TranslationTask and finetune it on the WMT16 data. The backbone can be any Seq2Seq translation model from HuggingFace/transformers.
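If you want to finetune on your own sentence pairs, writing them in the same two-column input,target layout lets you pass the column names to TranslationData.from_csv exactly as in the example. The sketch below uses only Python's csv module to write pairs in that layout and read them back. The helper names write_pairs and read_pairs are hypothetical and are not part of Flash.

```python
import csv
import io
from typing import List, Tuple


def write_pairs(pairs: List[Tuple[str, str]], fh) -> None:
    """Write (English, Romanian) pairs under an input,target header row."""
    writer = csv.writer(fh)  # quotes fields only when a value needs it
    writer.writerow(["input", "target"])
    writer.writerows(pairs)


def read_pairs(fh) -> List[Tuple[str, str]]:
    """Read the input/target columns back as (source, target) tuples."""
    return [(row["input"], row["target"]) for row in csv.DictReader(fh)]


pairs = [
    ("Closure of sitting", "Ridicarea şedinţei"),
    (
        "Written statements and oral questions (tabling): see Minutes",
        "Declaraţii scrise şi întrebări orale (depunere): consultaţi procesul-verbal",
    ),
]

# An in-memory buffer stands in for a real file here.
buf = io.StringIO(newline="")
write_pairs(pairs, buf)
buf.seek(0)
round_trip = read_pairs(buf)
print(round_trip == pairs)
```

For a real file, open it with open(path, "w", newline="", encoding="utf-8") so that the Romanian diacritics and line endings are written correctly.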

Note

When changing the backbone, make sure you pass in the same backbone to the TranslationData and the TranslationTask!

Next, we use the trained TranslationTask for inference. Finally, we save the model. Here’s the full example:

import flash
import torch
from flash.core.data.utils import download_data
from flash.text import TranslationData, TranslationTask

# 1. Create the DataModule
download_data("https://pl-flash-data.s3.amazonaws.com/wmt_en_ro.zip", "./data")

datamodule = TranslationData.from_csv(
    "input",
    "target",
    train_file="data/wmt_en_ro/train.csv",
    val_file="data/wmt_en_ro/valid.csv",
    batch_size=4,
)

# 2. Build the task
model = TranslationTask(backbone="Helsinki-NLP/opus-mt-en-ro")

# 3. Create the trainer and finetune the model
trainer = flash.Trainer(max_epochs=3, gpus=torch.cuda.device_count())
trainer.finetune(model, datamodule=datamodule, strategy="freeze")

# 4. Translate something!
datamodule = TranslationData.from_lists(
    predict_data=[
        "BBC News went to meet one of the project's first graduates.",
        "A recession has come as quickly as 11 months after the first rate hike and as long as 86 months.",
        "Of course, it's still early in the election cycle.",
    ],
    batch_size=4,
)
predictions = trainer.predict(model, datamodule=datamodule)
print(predictions)

# 5. Save the model!
trainer.save_checkpoint("translation_model_en_ro.pt")

To learn how to view the available backbones / heads for this task, see Backbones and Heads.


Flash Zero

The translation task can be used directly from the command line with zero code using Flash Zero. You can run the above example with:

flash translation

To view the configuration options, including how to run the translation task on your own data, use:

flash translation --help

Serving

The TranslationTask is servable. This means you can call .serve() to serve your task. Here’s an example:

from flash.text import TranslationTask

model = TranslationTask.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/0.7.0/translation_model_en_ro.pt")
model.serve()

You can now perform inference from your client like this:

import requests

text = "Some English text"
body = {"session": "UUID", "payload": {"inputs": {"data": text}}}
resp = requests.post("http://127.0.0.1:8000/predict", json=body)

print(resp.json())
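The request body has a fixed nested shape: a session field plus a payload whose inputs.data field carries the text to translate. The sketch below builds that same body and sends it using only the standard library (urllib) instead of requests. It assumes the server from model.serve() is running at the URL used above and accepts the body shown there. It makes no assumption about the structure of the JSON reply. The function names are hypothetical.

```python
import json
import urllib.request


def build_predict_body(text: str, session: str = "UUID") -> dict:
    """Build the JSON body in the same shape as the requests example above."""
    return {"session": session, "payload": {"inputs": {"data": text}}}


def translate_remote(text: str, url: str = "http://127.0.0.1:8000/predict") -> dict:
    """POST one sentence to a running model.serve() endpoint and decode the JSON reply."""
    data = json.dumps(build_predict_body(text)).encode("utf-8")
    request = urllib.request.Request(
        url, data=data, headers={"Content-Type": "application/json"}, method="POST"
    )
    with urllib.request.urlopen(request) as response:
        return json.loads(response.read().decode("utf-8"))


body = build_predict_body("Some English text")
print(json.dumps(body))
# With the server running: print(translate_remote("Some English text"))
```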

Accelerate Training & Inference with Torch ORT

Torch ORT converts your model into an optimized ONNX graph, speeding up training and inference on NVIDIA or AMD GPUs. Once Torch ORT is installed, enabling it requires only a single flag (enable_ort=True) passed to the TranslationTask. See the installation instructions here.

Note

Not all Transformer models are supported. See this table for the supported models and for branches containing fixes for certain models.

...

model = TranslationTask(backbone="t5-large", enable_ort=True)
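Because the flag only helps when Torch ORT is actually installed, one option is to set it based on whether the package can be imported. The sketch below assumes Torch ORT is importable as torch_ort (the module installed by the torch-ort package). The helper names are hypothetical. It only builds the keyword arguments, so it runs even without Flash installed.

```python
import importlib.util


def torch_ort_available() -> bool:
    """Return True if the torch_ort module can be found on this Python path."""
    return importlib.util.find_spec("torch_ort") is not None


def translation_task_kwargs(backbone: str = "t5-large") -> dict:
    """Keyword arguments for TranslationTask, enabling ORT only when torch_ort is present."""
    return {"backbone": backbone, "enable_ort": torch_ort_available()}


kwargs = translation_task_kwargs()
print(kwargs)
# With Flash installed: model = TranslationTask(**kwargs)
```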