
Text Classification

The Task

Text classification is the task of assigning a piece of text (a word, sentence, or document) an appropriate class, or category. The categories depend on the chosen data set; they can be topics, for instance, or sentiment, as in the example below.
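To make the input/output contract concrete (text in, one label out), here is a deliberately naive keyword-counting classifier. It is a hypothetical toy, not how TextClassifier works (TextClassifier fine-tunes a pre-trained backbone). The word lists and the tie-breaking rule are arbitrary assumptions.

```python
import re

# Hypothetical toy lexicons; TextClassifier learns from data instead.
POSITIVE_WORDS = {"best", "great", "excellent", "wonderful", "loved"}
NEGATIVE_WORDS = {"worst", "awful", "boring", "turgid", "feeble"}


def classify(text: str) -> str:
    """Assign one label to a piece of text by counting lexicon hits."""
    tokens = re.findall(r"[a-z']+", text.lower())
    score = sum(token in POSITIVE_WORDS for token in tokens)
    score -= sum(token in NEGATIVE_WORDS for token in tokens)
    # Arbitrary tie-breaking rule: a score of 0 counts as "positive".
    return "positive" if score >= 0 else "negative"


for sentence in ["Best movie ever", "The worst movie in the history of cinema."]:
    print(f"{sentence!r} -> {classify(sentence)}")
```

A learned model replaces the hand-written word lists with parameters fitted to labeled examples, which is what the rest of this page does.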


Example

Let’s train a model to classify text as expressing either positive or negative sentiment. We will use the IMDB data set, which contains a train.csv and a valid.csv file. Here’s the structure:

review,sentiment
"Japanese indie film with humor ... ",positive
"Isaac Florentine has made some ...",negative
"After seeing the low-budget ...",negative
"I've seen the original English version ...",positive
"Hunters chase what they think is a man through ...",negative
...
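Before training, it can help to confirm that a file really has this layout. The sketch below uses only Python's standard csv module on a few of the rows shown above, parsed from an in-memory string. It is an illustrative check, not part of Flash (Flash reads the files itself via from_csv), and load_rows is a hypothetical helper.

```python
import csv
import io
from collections import Counter

# A few rows in the same "review,sentiment" layout as train.csv / valid.csv.
SAMPLE = '''review,sentiment
"Japanese indie film with humor ... ",positive
"Isaac Florentine has made some ...",negative
"After seeing the low-budget ...",negative
'''


def load_rows(handle):
    """Return (review, sentiment) pairs, failing early if a column is missing."""
    reader = csv.DictReader(handle)
    # These are the column names passed to TextClassificationData.from_csv.
    missing = {"review", "sentiment"} - set(reader.fieldnames or [])
    if missing:
        raise ValueError(f"missing columns: {sorted(missing)}")
    return [(row["review"], row["sentiment"]) for row in reader]


rows = load_rows(io.StringIO(SAMPLE))
print(len(rows), Counter(label for _, label in rows))
```

To check a real file, open it with `open("data/imdb/train.csv", newline="", encoding="utf-8")` and pass the handle to load_rows; the UTF-8 encoding is an assumption.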

Once we’ve downloaded the data using download_data(), we create the TextClassificationData. We select a pre-trained backbone to use for our TextClassifier and finetune on the IMDB data. The backbone can be any BERT classification model from HuggingFace/transformers.

Note

When changing the backbone, make sure you pass in the same backbone to the TextClassifier and the TextClassificationData!

Next, we use the trained TextClassifier for inference. Finally, we save the model. Here’s the full example:

import flash
import torch
from flash.core.data.utils import download_data
from flash.text import TextClassificationData, TextClassifier

# 1. Create the DataModule
download_data("https://pl-flash-data.s3.amazonaws.com/imdb.zip", "./data/")

datamodule = TextClassificationData.from_csv(
    "review",
    "sentiment",
    train_file="data/imdb/train.csv",
    val_file="data/imdb/valid.csv",
    batch_size=4,
)

# 2. Build the task
model = TextClassifier(backbone="prajjwal1/bert-medium", labels=datamodule.labels)

# 3. Create the trainer and finetune the model
trainer = flash.Trainer(max_epochs=3, gpus=torch.cuda.device_count())
trainer.finetune(model, datamodule=datamodule, strategy="freeze")

# 4. Classify a few sentences! How was the movie?
datamodule = TextClassificationData.from_lists(
    predict_data=[
        "Turgid dialogue, feeble characterization - Harvey Keitel a judge?.",
        "The worst movie in the history of cinema.",
        "I come from Bulgaria where it 's almost impossible to have a tornado.",
    ],
    batch_size=4,
)
predictions = trainer.predict(model, datamodule=datamodule, output="labels")
print(predictions)

# 5. Save the model!
trainer.save_checkpoint("text_classification_model.pt")
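The output="labels" argument asks for class names rather than raw scores. For a single-label task like this one, that conversion amounts roughly to taking the highest-scoring class and looking up its name. The standard-library sketch below illustrates the idea only. It is not Flash's implementation, and the label order and scores are made-up placeholders rather than real model outputs.

```python
from typing import List, Sequence


def logits_to_labels(
    batch_scores: Sequence[Sequence[float]], labels: Sequence[str]
) -> List[str]:
    """Map each row of per-class scores to the name of its highest-scoring class."""
    predictions = []
    for row in batch_scores:
        if len(row) != len(labels):
            raise ValueError("expected one score per label")
        best_index = max(range(len(row)), key=lambda i: row[i])
        predictions.append(labels[best_index])
    return predictions


# Placeholder label order and scores, for illustration only.
labels = ["negative", "positive"]
scores = [[2.1, -0.3], [1.7, 0.2], [-0.5, 0.9]]
print(logits_to_labels(scores, labels))
```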

To learn how to view the available backbones / heads for this task, see Backbones and Heads.


Flash Zero

The text classifier can be used directly from the command line with zero code using Flash Zero. You can run the above example with:

flash text_classification

To view the configuration options, including how to run the text classifier with your own data, use:

flash text_classification --help

Serving

The TextClassifier is servable. This means you can call .serve to serve your Task. Here’s an example:

from flash.text import TextClassifier

model = TextClassifier.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/0.9.0/text_classification_model.pt")
model.serve()

You can now perform inference from your client like this:

import requests

text = "Best movie ever"
body = {"session": "UUID", "payload": {"inputs": {"data": text}}}
resp = requests.post("http://127.0.0.1:8000/predict", json=body)

print(resp.json())
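If you would rather not depend on requests, the same call can be made with the standard library. This sketch assumes the server started by model.serve() is listening at the address used above. The helpers build_body, make_request, and predict are hypothetical, and because the response schema is not described here, the sketch simply returns the decoded JSON.

```python
import json
import urllib.request

DEFAULT_URL = "http://127.0.0.1:8000/predict"


def build_body(text: str, session: str = "UUID") -> dict:
    # Same JSON shape as the requests example above.
    return {"session": session, "payload": {"inputs": {"data": text}}}


def make_request(text: str, url: str = DEFAULT_URL) -> urllib.request.Request:
    data = json.dumps(build_body(text)).encode("utf-8")
    return urllib.request.Request(
        url, data=data, headers={"Content-Type": "application/json"}, method="POST"
    )


def predict(text: str, url: str = DEFAULT_URL, timeout: float = 10.0):
    # Sends the request and returns the decoded JSON response as-is.
    with urllib.request.urlopen(make_request(text, url), timeout=timeout) as response:
        return json.loads(response.read().decode("utf-8"))
```

With the server running, call `print(predict("Best movie ever"))`. Splitting request construction from sending keeps the payload easy to inspect without a live server.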

Accelerate Training & Inference with Torch ORT

Torch ORT converts your model into an optimized ONNX graph, speeding up training and inference on NVIDIA or AMD GPUs. Once Torch ORT is installed, enabling it requires only a single flag passed to the TextClassifier. See the Torch ORT installation instructions.
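Because the flag only helps once Torch ORT is installed, one option is to set it conditionally. This sketch assumes the package is importable as torch_ort. The enable_ort argument is the one used in the example at the end of this section, and ort_kwargs is a hypothetical helper.

```python
import importlib.util


def ort_kwargs() -> dict:
    """Return the enable_ort keyword, set to True only if torch_ort is importable."""
    # find_spec returns None when the top-level package cannot be found.
    return {"enable_ort": importlib.util.find_spec("torch_ort") is not None}


print(ort_kwargs())
```

Pass `**ort_kwargs()` alongside the other TextClassifier arguments. This does not check whether the chosen backbone is one that Torch ORT supports.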

Note

Not all Transformer models are supported. See the table of supported models, which also lists branches containing fixes for certain models.

...

model = TextClassifier(backbone="facebook/bart-large", num_classes=datamodule.num_classes, enable_ort=True)