
Tabular Classification

The Task

Tabular classification is the task of assigning a class to samples of structured or relational data. The TabularClassifier task can be used for binary classification (as in the Titanic example below) as well as for classifying samples into more than two classes (multi-class classification).


Example

Let’s look at training a model to predict passenger survival on the Titanic using the classic Kaggle data set. The data is provided in CSV files that look like this:

PassengerId,Survived,Pclass,Name,Sex,Age,SibSp,Parch,Ticket,Fare,Cabin,Embarked
1,0,3,"Braund, Mr. Owen Harris",male,22,1,0,A/5 21171,7.25,,S
3,1,3,"Heikkinen, Miss. Laina",female,26,0,0,STON/O2. 3101282,7.925,,S
5,0,3,"Allen, Mr. William Henry",male,35,0,0,373450,8.05,,S
6,0,3,"Moran, Mr. James",male,,0,0,330877,8.4583,,Q
...
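Two details of this format matter when loading it: the Name column contains commas and is therefore quoted, and some cells (such as Age for Moran, and every Cabin shown here) are empty. Here is a minimal standard-library sketch of reading the rows above, with empty cells turned into None. load_rows is a hypothetical helper for inspecting the data, not how Flash reads the file.

```python
import csv
import io

# The four sample rows shown above (the "..." continuation is omitted).
SAMPLE = """PassengerId,Survived,Pclass,Name,Sex,Age,SibSp,Parch,Ticket,Fare,Cabin,Embarked
1,0,3,"Braund, Mr. Owen Harris",male,22,1,0,A/5 21171,7.25,,S
3,1,3,"Heikkinen, Miss. Laina",female,26,0,0,STON/O2. 3101282,7.925,,S
5,0,3,"Allen, Mr. William Henry",male,35,0,0,373450,8.05,,S
6,0,3,"Moran, Mr. James",male,,0,0,330877,8.4583,,Q
"""


def load_rows(text):
    """Parse CSV text into a list of dicts, turning empty cells into None."""
    reader = csv.DictReader(io.StringIO(text))
    # csv handles the quoted "Last, Title. First" names; empty strings mark missing values.
    return [{k: (v if v != "" else None) for k, v in row.items()} for row in reader]


rows = load_rows(SAMPLE)
print(rows[3]["Name"], rows[3]["Age"])  # Moran, Mr. James None
```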

We can create the TabularClassificationData from our CSV file using the from_csv() method. From the API reference, we need to provide:

  • categorical_fields - The names of the columns that contain categorical data (strings or integers).

  • numerical_fields - The name, or a list of names, of the columns that contain numerical continuous data (floats).

  • target_fields - The name of the column we want to predict.

  • train_file - The path or URL of a CSV file containing the training data.

The example also sets val_split=0.1, which holds out 10% of the training data for validation, and batch_size=8.
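Categorical and numerical columns need different treatment before they reach a model: categories have to become integer codes, while continuous values are usually rescaled. The sketch below illustrates that idea with two hypothetical helpers, fit_parameters and encode. It is a simplified assumption about this kind of preprocessing, not Flash's actual implementation. The encodings are learned from the training rows only and then reused unchanged, which is presumably why the prediction step of the full example passes parameters=datamodule.parameters.

```python
from statistics import mean, pstdev


def fit_parameters(rows, categorical_fields, numerical_fields):
    """Learn encodings from the training rows only (hypothetical helper)."""
    params = {"codes": {}, "mean": {}, "std": {}}
    for col in categorical_fields:
        # Map each distinct string value to an integer code, in sorted order.
        values = sorted({row[col] for row in rows})
        params["codes"][col] = {value: code for code, value in enumerate(values)}
    for col in numerical_fields:
        nums = [float(row[col]) for row in rows]
        params["mean"][col] = mean(nums)
        # Fall back to 1.0 so a constant column does not cause division by zero.
        params["std"][col] = pstdev(nums) or 1.0
    return params


def encode(row, params):
    """Turn one row into (categorical codes, standardized numbers)."""
    cats = [params["codes"][col][row[col]] for col in params["codes"]]
    nums = [
        (float(row[col]) - params["mean"][col]) / params["std"][col]
        for col in params["mean"]
    ]
    return cats, nums


train = [
    {"Sex": "male", "Embarked": "S", "Fare": "7.25"},
    {"Sex": "female", "Embarked": "S", "Fare": "7.925"},
    {"Sex": "male", "Embarked": "Q", "Fare": "8.4583"},
]
params = fit_parameters(train, ["Sex", "Embarked"], ["Fare"])
print(encode(train[2], params)[0])  # [1, 0]
```

A value never seen during training would raise a KeyError in this sketch; a real pipeline needs a policy for unknown categories.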

Next, we create the TabularClassifier and train it on the Titanic data. We then use the trained TabularClassifier for inference. Finally, we save the model. Here’s the full example:

import flash
import torch
from flash.tabular import TabularClassificationData, TabularClassifier

# 1. Create the DataModule
datamodule = TabularClassificationData.from_csv(
    categorical_fields=["Sex", "Age", "SibSp", "Parch", "Ticket", "Cabin", "Embarked"],
    numerical_fields="Fare",
    target_fields="Survived",
    train_file="https://pl-flash-data.s3.amazonaws.com/titanic.csv",
    val_split=0.1,
    batch_size=8,
)

# 2. Build the task
model = TabularClassifier.from_data(datamodule, backbone="fttransformer")

# 3. Create the trainer and train the model
trainer = flash.Trainer(max_epochs=3, gpus=torch.cuda.device_count())
trainer.fit(model, datamodule=datamodule)

# 4. Generate predictions from a CSV
datamodule = TabularClassificationData.from_csv(
    predict_file="https://pl-flash-data.s3.amazonaws.com/titanic.csv",
    parameters=datamodule.parameters,
    batch_size=8,
)
predictions = trainer.predict(model, datamodule=datamodule, output="classes")
print(predictions)

# 5. Save the model!
trainer.save_checkpoint("tabular_classification_model.pt")
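The val_split=0.1 argument in step 1 holds out a fraction of the training rows for validation. Below is a minimal sketch of such a hold-out split, assuming a seeded shuffle followed by truncation. train_val_split is a hypothetical helper, and Flash's own splitting logic may differ in details such as rounding and randomness.

```python
import random


def train_val_split(rows, val_split=0.1, seed=42):
    """Shuffle row indices and hold out a fraction for validation (hypothetical helper)."""
    indices = list(range(len(rows)))
    random.Random(seed).shuffle(indices)  # seeded so the split is reproducible
    n_val = int(len(rows) * val_split)
    val_idx = set(indices[:n_val])
    train = [r for i, r in enumerate(rows) if i not in val_idx]
    val = [r for i, r in enumerate(rows) if i in val_idx]
    return train, val


train_rows, val_rows = train_val_split(list(range(100)), val_split=0.1)
print(len(train_rows), len(val_rows))  # 90 10
```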

To learn how to view the available backbones / heads for this task, see Backbones and Heads.


Flash Zero

The tabular classifier can be used directly from the command line with zero code using Flash Zero. You can run the above example with:

flash tabular_classifier

To view the configuration options, including how to run the tabular classifier on your own data, use:

flash tabular_classifier --help

Serving

The TabularClassifier is servable. This means you can call .serve on a trained model to serve predictions over HTTP. Here’s an example that loads a trained checkpoint:

from flash.core.classification import LabelsOutput
from flash.tabular import TabularClassifier

model = TabularClassifier.load_from_checkpoint(
    "https://flash-weights.s3.amazonaws.com/0.7.0/tabular_classification_model.pt"
)
model.serve(output=LabelsOutput(["Did not survive", "Survived"]))
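LabelsOutput replaces each predicted class index with the matching entry in the list you pass. Index 0 becomes "Did not survive" and index 1 becomes "Survived", presumably corresponding to Survived values 0 and 1. The plain-Python sketch below pictures that post-processing as a softmax followed by an argmax and a lookup. softmax, to_label, and the example logits are illustrative assumptions, not Flash's implementation.

```python
import math


def softmax(logits):
    """Convert raw scores into probabilities that sum to 1."""
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def to_label(logits, labels):
    """Pick the most probable class index and look up its name."""
    probs = softmax(logits)
    best = max(range(len(probs)), key=probs.__getitem__)
    return labels[best]


LABELS = ["Did not survive", "Survived"]
print(to_label([0.3, 1.2], LABELS))  # Survived
```

Because softmax preserves ordering, taking the argmax of the raw logits picks the same label. The probabilities are only needed if you also want to report a confidence.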

You can now perform inference from your client like this:

import pandas as pd
import requests
from flash.core.data.utils import download_data

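# 1. Download the data
download_data("https://pl-flash-data.s3.amazonaws.com/titanic.zip", "data/")

# 2. Serialize the rows to predict as CSV text (to_csv() returns a string when no path is given)
df = pd.read_csv("./data/titanic/predict.csv")
text = df.to_csv()

# 3. Send the CSV text to the running server and print its response
body = {"session": "UUID", "payload": {"inputs": {"data": text}}}
resp = requests.post("http://127.0.0.1:8000/predict", json=body)
print(resp.json())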
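If you prefer a client without pandas or requests, the same JSON body can be built with the standard library. This sketch assumes the server started above is listening at http://127.0.0.1:8000/predict and accepts the body shape used in the client example. It also assumes the server accepts CSV without the leading index column that df.to_csv() writes by default. build_request is a hypothetical helper.

```python
import csv
import io
import json
import urllib.request


def build_request(rows, url="http://127.0.0.1:8000/predict"):
    """Serialize rows to CSV text and wrap it in the request body used above."""
    buffer = io.StringIO()
    writer = csv.DictWriter(buffer, fieldnames=list(rows[0]))
    writer.writeheader()
    writer.writerows(rows)
    body = {"session": "UUID", "payload": {"inputs": {"data": buffer.getvalue()}}}
    return urllib.request.Request(
        url,
        data=json.dumps(body).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )


rows = [{"Sex": "male", "Age": "22", "Fare": "7.25"}]
req = build_request(rows)
# With the server running, send it with:
# with urllib.request.urlopen(req) as resp:
#     print(json.loads(resp.read()))
```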