
Multi-label Text Classification

The Task

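Multi-label classification is the task of assigning zero or more labels from a fixed set to each data point, which can be in any modality (text, in this case). Unlike multi-class classification, where each data point receives exactly one label, the labels here are not mutually exclusive. Multi-label text classification is supported by the TextClassifier via the multi_label argument.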
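Because each data point may carry none, one, or several labels, targets are naturally represented as multi-hot vectors, and each label is scored as an independent yes/no decision (a sigmoid per label) rather than with a softmax over classes. The plain-Python sketch below illustrates that idea. It is not Flash's implementation: the helper names are hypothetical, and using binary cross-entropy on logits as the loss is an assumption about what a multi-label classifier typically optimizes.

```python
import math

# Label set taken from the toxic-comments example.
LABELS = ["toxic", "severe_toxic", "obscene", "threat", "insult", "identity_hate"]


def to_multi_hot(active, labels=LABELS):
    """Hypothetical helper: encode a set of active label names as a 0/1 vector."""
    return [1.0 if name in active else 0.0 for name in labels]


def bce_with_logits(logits, targets):
    """Mean binary cross-entropy over labels, computed directly from raw logits.

    Uses the numerically stable form max(z, 0) - z*y + log(1 + exp(-|z|)),
    which equals -(y*log(sigmoid(z)) + (1-y)*log(1-sigmoid(z))).
    """
    total = 0.0
    for z, y in zip(logits, targets):
        total += max(z, 0.0) - z * y + math.log1p(math.exp(-abs(z)))
    return total / len(logits)


# A comment that is both toxic and insulting gets two active labels.
target = to_multi_hot({"toxic", "insult"})
print(target)  # [1.0, 0.0, 0.0, 0.0, 1.0, 0.0]
```

A clean comment simply encodes as an all-zero vector, which a softmax-based multi-class setup could not express.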


Example

Let’s look at the task of classifying comment toxicity. The data we use in this example comes from the Kaggle Toxic Comment Classification Challenge by Jigsaw: www.kaggle.com/c/jigsaw-toxic-comment-classification-challenge. The data is stored in CSV files with the following structure, where each of the six label columns is 1 if that label applies to the comment and 0 otherwise:

"id","comment_text","toxic","severe_toxic","obscene","threat","insult","identity_hate"
"0000997932d777bf","...",0,0,0,0,0,0
"0002bcb3da6cb337","...",1,1,1,0,1,0
"0005c987bdfc9d4b","...",1,0,0,0,0,0
...
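To make the layout concrete, the sketch below reads rows in this format with Python's standard csv module and turns the six 0/1 columns into a list of active label names per comment. This only illustrates the file structure; it is not how TextClassificationData parses the file, and the read_rows helper is hypothetical. The "..." comment texts are the placeholders from the sample above.

```python
import csv
import io

# The sample rows shown above (comment texts elided as "...").
SAMPLE = '''"id","comment_text","toxic","severe_toxic","obscene","threat","insult","identity_hate"
"0000997932d777bf","...",0,0,0,0,0,0
"0002bcb3da6cb337","...",1,1,1,0,1,0
"0005c987bdfc9d4b","...",1,0,0,0,0,0
'''

LABEL_COLUMNS = ["toxic", "severe_toxic", "obscene", "threat", "insult", "identity_hate"]


def read_rows(handle, text_column="comment_text", label_columns=LABEL_COLUMNS):
    """Hypothetical helper: yield (id, text, active_labels) for each CSV row."""
    for row in csv.DictReader(handle):
        active = [name for name in label_columns if int(row[name]) == 1]
        yield row["id"], row[text_column], active


for row_id, text, active in read_rows(io.StringIO(SAMPLE)):
    print(row_id, active)
```

The first row has no active labels, while the second has four; this is exactly the kind of target a multi-label classifier must handle.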

Once we’ve downloaded the data using download_data(), we create the TextClassificationData, passing it the name of the text column and the list of label columns. We then select a pre-trained backbone for our TextClassifier and fine-tune it on the toxic comments data. The backbone can be any BERT classification model from HuggingFace/transformers.
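The example below passes val_split=0.1 to from_csv, which holds out part of the training file for validation. As a rough illustration of what a 10% hold-out means, here is a minimal plain-Python sketch. It assumes a seeded random shuffle before splitting; Flash's actual splitting logic is not described on this page, so treat the hold_out helper as hypothetical.

```python
import random


def hold_out(items, val_split=0.1, seed=42):
    """Hypothetical helper: shuffle a copy of items and return (train, val).

    The validation part holds int(len(items) * val_split) items; the rest go to training.
    """
    shuffled = list(items)
    random.Random(seed).shuffle(shuffled)
    n_val = int(len(shuffled) * val_split)
    return shuffled[n_val:], shuffled[:n_val]


train, val = hold_out(range(1000), val_split=0.1)
print(len(train), len(val))  # 900 100
```

Fixing the seed keeps the split reproducible across runs, so validation metrics stay comparable between experiments.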

Note

When changing the backbone, make sure you pass in the same backbone to the TextClassifier and the TextClassificationData!

Next, we use the trained TextClassifier for inference. Finally, we save the model. Here’s the full example:

import flash
import torch
from flash.core.data.utils import download_data
from flash.text import TextClassificationData, TextClassifier

# 1. Create the DataModule
# Data from the Kaggle Toxic Comment Classification Challenge:
# https://www.kaggle.com/c/jigsaw-toxic-comment-classification-challenge
download_data("https://pl-flash-data.s3.amazonaws.com/jigsaw_toxic_comments.zip", "./data")

datamodule = TextClassificationData.from_csv(
    "comment_text",
    ["toxic", "severe_toxic", "obscene", "threat", "insult", "identity_hate"],
    train_file="data/jigsaw_toxic_comments/train.csv",
    val_split=0.1,
    batch_size=4,
)

# 2. Build the task
model = TextClassifier(
    backbone="unitary/toxic-bert",
    labels=datamodule.labels,
    multi_label=datamodule.multi_label,
)

# 3. Create the trainer and finetune the model
trainer = flash.Trainer(max_epochs=1, gpus=torch.cuda.device_count())
trainer.finetune(model, datamodule=datamodule, strategy="freeze")

# 4. Generate predictions for a few comments!
datamodule = TextClassificationData.from_lists(
    predict_data=[
        "No, he is an arrogant, self serving, immature idiot. Get it right.",
        "U SUCK HANNAH MONTANA",
        "Would you care to vote? Thx.",
    ],
    batch_size=4,
)
predictions = trainer.predict(model, datamodule=datamodule, output="labels")
print(predictions)

# 5. Save the model!
trainer.save_checkpoint("text_classification_multi_label_model.pt")
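With output="labels", the predictions printed above are label names rather than raw scores. For a multi-label model, one common way to get from scores to names is to apply a sigmoid to each logit and keep every label whose probability reaches a threshold, so a comment can receive several labels or none. The sketch below shows that decoding step in plain Python; the 0.5 threshold and the logits_to_labels helper are assumptions for illustration, not Flash's documented behaviour, and the logits are made up rather than produced by the model.

```python
import math

LABELS = ["toxic", "severe_toxic", "obscene", "threat", "insult", "identity_hate"]


def logits_to_labels(logits, labels=LABELS, threshold=0.5):
    """Hypothetical helper: return every label whose sigmoid probability is >= threshold."""
    probs = [1.0 / (1.0 + math.exp(-z)) for z in logits]
    return [name for name, p in zip(labels, probs) if p >= threshold]


# Made-up logits for one comment (illustrative only, not real model output).
print(logits_to_labels([2.3, -1.5, 0.8, -4.0, 1.1, -3.2]))  # ['toxic', 'obscene', 'insult']
```

Since sigmoid(z) >= 0.5 exactly when z >= 0, a 0.5 threshold keeps the labels with non-negative logits; raising the threshold trades recall for precision on each label independently.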

To learn how to view the available backbones / heads for this task, see Backbones and Heads.


Flash Zero

The multi-label text classifier can be used directly from the command line with zero code using Flash Zero. You can run the above example with:

flash text_classification from_toxic

To view the configuration options, including how to run the text classifier on your own data, use:

flash text_classification --help

Serving

The TextClassifier is servable. For more information, see Text Classification.
