Text Classification¶
The Task¶
Text classification is the task of assigning a piece of text (word, sentence or document) an appropriate class, or category. The categories depend on the chosen data set and can range from topics to sentiment labels.
Example¶
Let’s train a model to classify text as expressing either positive or negative sentiment.
We will be using the IMDB data set, which contains a train.csv and a valid.csv.
Here’s the structure:
review,sentiment
"Japanese indie film with humor ... ",positive
"Isaac Florentine has made some ...",negative
"After seeing the low-budget ...",negative
"I've seen the original English version ...",positive
"Hunters chase what they think is a man through ...",negative
...
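If you want to sanity-check the files before training, a quick look with pandas is enough to confirm the column names used below (a minimal sketch, not part of the Flash example; it assumes pandas is installed and that the data has already been downloaded to ./data/ as in the next step):
import pandas as pd

# Inspect the training split; the columns should match the header shown above.
train_df = pd.read_csv("data/imdb/train.csv")
print(train_df.columns.tolist())           # expected: ['review', 'sentiment']
print(train_df["sentiment"].value_counts())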
Once we’ve downloaded the data using download_data(), we create the TextClassificationData.
We select a pre-trained backbone to use for our TextClassifier and finetune on the IMDB data. The backbone can be any BERT classification model from HuggingFace/transformers (see the sketch after the note below).
Note
When changing the backbone, make sure you pass in the same backbone to the TextClassifier and the TextClassificationData!
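For example, swapping in a different HuggingFace backbone only requires changing the backbone string (a minimal sketch; "bert-base-uncased" is just an illustrative choice, and datamodule is the TextClassificationData created in the full example below):
from flash.text import TextClassifier

# Illustrative only: any BERT-style sequence-classification model from
# HuggingFace/transformers can be named here. `datamodule` is the
# TextClassificationData built in the full example below.
model = TextClassifier(backbone="bert-base-uncased", labels=datamodule.labels)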
Next, we use the trained TextClassifier for inference.
Finally, we save the model.
Here’s the full example:
import flash
import torch
from flash.core.data.utils import download_data
from flash.text import TextClassificationData, TextClassifier
# 1. Create the DataModule
download_data("https://pl-flash-data.s3.amazonaws.com/imdb.zip", "./data/")
datamodule = TextClassificationData.from_csv(
    "review",
    "sentiment",
    train_file="data/imdb/train.csv",
    val_file="data/imdb/valid.csv",
    batch_size=4,
)
# 2. Build the task
model = TextClassifier(backbone="prajjwal1/bert-medium", labels=datamodule.labels)
# 3. Create the trainer and finetune the model
trainer = flash.Trainer(max_epochs=3, gpus=torch.cuda.device_count())
trainer.finetune(model, datamodule=datamodule, strategy="freeze")
# 4. Classify a few sentences! How was the movie?
datamodule = TextClassificationData.from_lists(
    predict_data=[
        "Turgid dialogue, feeble characterization - Harvey Keitel a judge?.",
        "The worst movie in the history of cinema.",
        "I come from Bulgaria where it 's almost impossible to have a tornado.",
    ],
    batch_size=4,
)
predictions = trainer.predict(model, datamodule=datamodule, output="labels")
print(predictions)
# 5. Save the model!
trainer.save_checkpoint("text_classification_model.pt")
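The saved checkpoint can be reloaded later for inference without retraining (a short sketch using the same load_from_checkpoint call as in the Serving section below, pointing at the file saved above):
from flash.text import TextClassifier

# Reload the finetuned weights written by trainer.save_checkpoint(...) above.
model = TextClassifier.load_from_checkpoint("text_classification_model.pt")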
To learn how to view the available backbones / heads for this task, see Backbones and Heads.
Flash Zero¶
The text classifier can be used directly from the command line with zero code using Flash Zero. You can run the above example with:
flash text_classification
To view the available configuration options, including those for running the text classifier on your own data, use:
flash text_classification --help
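Flash Zero also exposes trainer arguments on the command line. For example, the following is a sketch of overriding the number of epochs (the exact flags depend on your Flash version, so check the --help output above):
flash text_classification --trainer.max_epochs 3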
Serving¶
The TextClassifier is servable. This means you can call .serve to serve your Task.
Here’s an example:
from flash.text import TextClassifier
model = TextClassifier.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/0.9.0/text_classification_model.pt")
model.serve()
You can now perform inference from your client like this:
import requests
text = "Best movie ever"
body = {"session": "UUID", "payload": {"inputs": {"data": text}}}
resp = requests.post("http://127.0.0.1:8000/predict", json=body)
print(resp.json())
Accelerate Training & Inference with Torch ORT¶
Torch ORT converts your model into an optimized ONNX graph, speeding up training and inference when using NVIDIA or AMD GPUs. Once installed, Torch ORT can be enabled by passing a single flag (enable_ort=True) to the TextClassifier. See the installation instructions here.
Note
Not all Transformer models are supported. See this table for the supported models and the branches containing fixes for certain models.
...
model = TextClassifier(backbone="facebook/bart-large", num_classes=datamodule.num_classes, enable_ort=True)