
Tabular Classification

The Task

Tabular classification is the task of assigning a class to samples of structured or relational data. The TabularClassifier task can be used to classify samples into more than two classes (multi-class classification).


Example

Let’s look at training a model to predict passenger survival on the Titanic using the classic Kaggle dataset. The data is provided in CSV files that look like this:

PassengerId,Survived,Pclass,Name,Sex,Age,SibSp,Parch,Ticket,Fare,Cabin,Embarked
1,0,3,"Braund, Mr. Owen Harris",male,22,1,0,A/5 21171,7.25,,S
3,1,3,"Heikkinen, Miss. Laina",female,26,0,0,STON/O2. 3101282,7.925,,S
5,0,3,"Allen, Mr. William Henry",male,35,0,0,373450,8.05,,S
6,0,3,"Moran, Mr. James",male,,0,0,330877,8.4583,,Q
...
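These files must be read with a CSV-aware parser rather than split on commas, because the Name column holds commas inside double quotes. As a quick standalone check, here is a minimal sketch that uses only Python's standard csv module on the sample rows above. The inline string is copied from the excerpt, not the full file.

```python
import csv
import io

# First rows of the Titanic CSV, copied from the excerpt above.
SAMPLE = '''PassengerId,Survived,Pclass,Name,Sex,Age,SibSp,Parch,Ticket,Fare,Cabin,Embarked
1,0,3,"Braund, Mr. Owen Harris",male,22,1,0,A/5 21171,7.25,,S
3,1,3,"Heikkinen, Miss. Laina",female,26,0,0,STON/O2. 3101282,7.925,,S
5,0,3,"Allen, Mr. William Henry",male,35,0,0,373450,8.05,,S
6,0,3,"Moran, Mr. James",male,,0,0,330877,8.4583,,Q
'''

# DictReader honours the double quotes, so "Braund, Mr. Owen Harris" stays one field.
rows = list(csv.DictReader(io.StringIO(SAMPLE)))

for row in rows:
    # Empty cells (such as a missing Age or Cabin) come back as empty strings.
    age = row["Age"] or "missing"
    print(row["Name"], "| survived:", row["Survived"], "| age:", age)
```

A missing value such as Moran's Age arrives as an empty string, so numeric columns need explicit handling of blanks before they can be converted to floats.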

Once we’ve downloaded the data using download_data(), we can create the TabularClassificationData from our CSV files using the from_csv() method. From the API reference, we need to provide:

  • categorical_fields - A list of the names of columns that contain categorical data (strings or integers).

  • numerical_fields - The name of a column, or a list of names of columns, that contain continuous numerical data (floats).

  • target_fields - The name of the column we want to predict.

  • train_file - The path to a CSV file containing the training data.
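Before calling from_csv(), it can help to confirm that every column you name actually exists in the file, that the target is not also listed as a feature, and that no column is marked both categorical and numerical. The sketch below is a hypothetical helper (check_fields is not part of Flash) that uses only the standard library and the column lists from the full example.

```python
import csv
import io


def check_fields(header, categorical_fields, numerical_fields, target_fields):
    """Hypothetical helper (not part of Flash): validate column names against a CSV header.

    Returns a list of problems; an empty list means the names look consistent.
    """
    # The example passes numerical_fields as a single string, so normalise to lists.
    def as_list(fields):
        return [fields] if isinstance(fields, str) else list(fields)

    cats = as_list(categorical_fields)
    nums = as_list(numerical_fields)
    targets = as_list(target_fields)

    problems = []
    for name in cats + nums + targets:
        if name not in header:
            problems.append(f"column {name!r} is not in the CSV header")
    # A column should not be both an input feature and the prediction target.
    for name in set(targets) & set(cats + nums):
        problems.append(f"target {name!r} is also listed as a feature")
    # A column should be either categorical or numerical, not both.
    for name in set(cats) & set(nums):
        problems.append(f"column {name!r} is both categorical and numerical")
    return problems


header_line = "PassengerId,Survived,Pclass,Name,Sex,Age,SibSp,Parch,Ticket,Fare,Cabin,Embarked"
header = next(csv.reader(io.StringIO(header_line)))

print(check_fields(
    header,
    categorical_fields=["Sex", "Age", "SibSp", "Parch", "Ticket", "Cabin", "Embarked"],
    numerical_fields="Fare",
    target_fields="Survived",
))
```

With the column lists from the full example, the helper prints an empty list.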

Next, we create the TabularClassifier and train it on the Titanic data. We then use the trained TabularClassifier for inference. Finally, we save the model. Here’s the full example:

import torch

import flash
from flash.core.data.utils import download_data
from flash.tabular import TabularClassificationData, TabularClassifier

# 1. Create the DataModule
download_data("https://pl-flash-data.s3.amazonaws.com/titanic.zip", "./data")

datamodule = TabularClassificationData.from_csv(
    categorical_fields=["Sex", "Age", "SibSp", "Parch", "Ticket", "Cabin", "Embarked"],
    numerical_fields="Fare",
    target_fields="Survived",
    train_file="data/titanic/titanic.csv",
    val_split=0.1,
    batch_size=8,
)

# 2. Build the task
model = TabularClassifier.from_data(datamodule, backbone="fttransformer")

# 3. Create the trainer and train the model
trainer = flash.Trainer(max_epochs=3, gpus=torch.cuda.device_count())
trainer.fit(model, datamodule=datamodule)

# 4. Generate predictions from a CSV
datamodule = TabularClassificationData.from_csv(
    predict_file="data/titanic/titanic.csv",
    parameters=datamodule.parameters,
    batch_size=8,
)
predictions = trainer.predict(model, datamodule=datamodule, output="classes")
print(predictions)

# 5. Save the model!
trainer.save_checkpoint("tabular_classification_model.pt")
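Conceptually, a classifier produces one score per class for each row, and asking for output="classes" keeps the index of the highest-scoring class. The plain-Python sketch below illustrates that reduction; it is not Flash's implementation, and the scores are made-up numbers rather than real model outputs.

```python
import math


def softmax(scores):
    """Turn raw per-class scores (logits) into probabilities that sum to 1."""
    # Subtract the max for numerical stability; it does not change the result.
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def to_class(scores):
    """Return the index of the highest-scoring class (the argmax)."""
    return max(range(len(scores)), key=lambda i: scores[i])


# Illustrative logits for three passengers and two classes
# (0 = did not survive, 1 = survived). These numbers are invented.
batch = [[2.0, -1.0], [-0.5, 1.5], [0.1, 0.0]]

for scores in batch:
    probs = softmax(scores)
    print(to_class(scores), [round(p, 3) for p in probs])
```

Because softmax preserves the ordering of the scores, taking the argmax of the raw logits or of the probabilities gives the same class.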

Flash Zero

The tabular classifier can be used directly from the command line with zero code using Flash Zero. You can run the above example with:

flash tabular_classifier

To view the configuration options, including how to run the tabular classifier on your own data, use:

flash tabular_classifier --help

Serving

The TabularClassifier is servable. This means you can call .serve() to serve your Task. Here’s an example:

from flash.core.classification import LabelsOutput
from flash.tabular import TabularClassifier

model = TabularClassifier.load_from_checkpoint(
    "https://flash-weights.s3.amazonaws.com/0.7.0/tabular_classification_model.pt"
)
model.output = LabelsOutput(["Did not survive", "Survived"])
model.serve()
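Passing ["Did not survive", "Survived"] to LabelsOutput names each class by its position in the list, so class index 0 is reported as "Did not survive" and index 1 as "Survived". In this example those indices correspond to the 0/1 values of the Survived column, so the list order matters. A dependency-free sketch of that mapping follows; indices_to_labels is a hypothetical name, not a Flash API.

```python
def indices_to_labels(indices, labels):
    """Hypothetical helper: map predicted class indices to human-readable names.

    Mirrors the idea behind passing a label list to LabelsOutput: the label at
    position i in `labels` names class i.
    """
    out = []
    for i in indices:
        if not 0 <= i < len(labels):
            raise ValueError(f"class index {i} has no label (only {len(labels)} labels given)")
        out.append(labels[i])
    return out


LABELS = ["Did not survive", "Survived"]
print(indices_to_labels([0, 1, 1, 0], LABELS))
# prints ['Did not survive', 'Survived', 'Survived', 'Did not survive']
```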

You can now perform inference from your client like this:

import pandas as pd
import requests

from flash.core.data.utils import download_data

# 1. Download the data
download_data("https://pl-flash-data.s3.amazonaws.com/titanic.zip", "data/")

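df = pd.read_csv("./data/titanic/predict.csv")
# 2. Serialize the rows to CSV text (to_csv() returns a str when no path is given)
text = df.to_csv()
body = {"session": "UUID", "payload": {"inputs": {"data": text}}}
# 3. Send the request to the server started with model.serve()
resp = requests.post("http://127.0.0.1:8000/predict", json=body)
print(resp.json())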
Read the Docs v: 0.7.0