
Tabular Classification

The Task

Tabular classification is the task of assigning a class to samples of structured or relational data. The TabularClassifier task can be used to classify samples into two or more classes (binary or multi-class classification).


Example

Let’s look at training a model to predict whether a passenger survived the Titanic disaster, using the classic Kaggle data set. The data is provided in CSV files that look like this:

PassengerId,Survived,Pclass,Name,Sex,Age,SibSp,Parch,Ticket,Fare,Cabin,Embarked
1,0,3,"Braund, Mr. Owen Harris",male,22,1,0,A/5 21171,7.25,,S
3,1,3,"Heikkinen, Miss. Laina",female,26,0,0,STON/O2. 3101282,7.925,,S
5,0,3,"Allen, Mr. William Henry",male,35,0,0,373450,8.05,,S
6,0,3,"Moran, Mr. James",male,,0,0,330877,8.4583,,Q
...
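Two details of this format matter for parsing. The Name field contains a comma, so it is wrapped in double quotes, and empty fields (Age for Mr. Moran, Cabin for every row shown) mark missing values. The following sketch checks both with only the standard library; the `SAMPLE` string is simply the excerpt above copied in, not the full data set.

```python
import csv
import io

# The first rows of the Titanic CSV shown above, copied verbatim.
SAMPLE = """PassengerId,Survived,Pclass,Name,Sex,Age,SibSp,Parch,Ticket,Fare,Cabin,Embarked
1,0,3,"Braund, Mr. Owen Harris",male,22,1,0,A/5 21171,7.25,,S
3,1,3,"Heikkinen, Miss. Laina",female,26,0,0,STON/O2. 3101282,7.925,,S
5,0,3,"Allen, Mr. William Henry",male,35,0,0,373450,8.05,,S
6,0,3,"Moran, Mr. James",male,,0,0,330877,8.4583,,Q
"""

# csv.DictReader honours the double quotes, so the comma inside Name does not split the field.
rows = list(csv.DictReader(io.StringIO(SAMPLE)))

# Count empty strings per column: these are the missing values.
missing = {col: sum(1 for r in rows if r[col] == "") for col in rows[0]}

print(len(rows))                          # 4
print(rows[0]["Name"])                    # Braund, Mr. Owen Harris
print(missing["Age"], missing["Cabin"])   # 1 4
```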

Once we’ve downloaded the data using download_data(), we can create the TabularClassificationData from our CSV files using the from_csv() method. From the API reference, we need to provide:

  • categorical_fields - A list of the names of columns that contain categorical data (strings or integers).

  • numerical_fields - The name, or list of names, of columns that contain continuous numerical data (floats).

  • target_fields - The name of the column we want to predict.

  • train_file - The path to the CSV file containing the training data, which is loaded into a pandas DataFrame.
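The split between categorical and numerical columns matters because the two kinds are encoded differently before they reach the model. The sketch below is a simplified, hypothetical illustration of that idea and not Flash's own preprocessing code: categorical values are mapped to integer codes, and numerical values are standardized with the mean and standard deviation of the training data. The helper names `fit_codes` and `standardize` are made up for this example.

```python
from statistics import mean, pstdev


def fit_codes(values):
    """Map each distinct categorical value (in order of first appearance) to an integer code."""
    codes = {}
    for v in values:
        if v not in codes:
            codes[v] = len(codes)
    return codes


def standardize(values):
    """Rescale numerical values to zero mean and unit (population) standard deviation."""
    mu, sigma = mean(values), pstdev(values)
    # A constant column has sigma == 0; map every value to 0.0 instead of dividing by zero.
    return [(v - mu) / sigma if sigma else 0.0 for v in values]


# Two columns taken from the sample rows above (a hypothetical, tiny training set).
sex = ["male", "female", "male", "male"]
fare = [7.25, 7.925, 8.05, 8.4583]

sex_codes = fit_codes(sex)
print(sex_codes)                     # {'male': 0, 'female': 1}
print([sex_codes[v] for v in sex])   # [0, 1, 0, 0]
print([round(x, 2) for x in standardize(fare)])
```

Fitting the codes and statistics on the training split only, and reusing them at prediction time, keeps the encoding consistent between training and inference.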

Next, we create the TabularClassifier and train it on the Titanic data. We then use the trained TabularClassifier for inference. Finally, we save the model. Here’s the full example:

import torch

import flash
from flash.core.data.utils import download_data
from flash.tabular import TabularClassificationData, TabularClassifier

# 1. Create the DataModule
download_data("https://pl-flash-data.s3.amazonaws.com/titanic.zip", "./data")

datamodule = TabularClassificationData.from_csv(
    ["Sex", "Age", "SibSp", "Parch", "Ticket", "Cabin", "Embarked"],
    "Fare",
    target_fields="Survived",
    train_file="data/titanic/titanic.csv",
    val_split=0.1,
)

# 2. Build the task
model = TabularClassifier.from_data(datamodule)

# 3. Create the trainer and train the model
trainer = flash.Trainer(max_epochs=3, gpus=torch.cuda.device_count())
trainer.fit(model, datamodule=datamodule)

# 4. Generate predictions from a CSV
predictions = model.predict("data/titanic/titanic.csv")
print(predictions)

# 5. Save the model!
trainer.save_checkpoint("tabular_classification_model.pt")
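In the example, `val_split=0.1` holds out a fraction of the training file for validation. Conceptually this is a random split of row indices. The sketch below illustrates the idea with a hypothetical `split_indices` helper, a fixed seed and a hypothetical 1,000-row file; Flash's actual splitting logic may differ in details such as rounding and seeding.

```python
import random


def split_indices(n_rows, val_split, seed=42):
    """Shuffle row indices and hold out roughly `val_split` of them for validation."""
    indices = list(range(n_rows))
    random.Random(seed).shuffle(indices)   # seeded so the split is reproducible
    n_val = int(n_rows * val_split)        # truncates, so tiny datasets may get fewer val rows
    return indices[n_val:], indices[:n_val]


train_idx, val_idx = split_indices(1000, 0.1)
print(len(train_idx), len(val_idx))  # 900 100
```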

Flash Zero

The tabular classifier can be used directly from the command line with zero code using Flash Zero. You can run the above example with:

flash tabular_classifier

To view configuration options and options for running the tabular classifier with your own data, use:

flash tabular_classifier --help

Serving

The TabularClassifier is servable, which means you can call .serve() on it to start a prediction server for your Task. Here’s an example:

from flash.core.classification import Labels
from flash.tabular import TabularClassifier

model = TabularClassifier.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/tabular_classification_model.pt")
model.serializer = Labels(["Did not survive", "Survived"])
model.serve()
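The `Labels` serializer turns the model's raw output into the strings passed to it, in order, so class index 0 becomes "Did not survive" and class index 1 becomes "Survived". Conceptually, for a single sample this is an argmax over the class scores followed by a lookup. The function below is a hypothetical stand-in to illustrate that mapping, not Flash's implementation.

```python
LABELS = ["Did not survive", "Survived"]  # same order as passed to Labels(...) above


def to_label(scores, labels=LABELS):
    """Pick the highest-scoring class index and return its human-readable name."""
    best = max(range(len(scores)), key=lambda i: scores[i])
    return labels[best]


print(to_label([2.1, -0.3]))  # Did not survive
print(to_label([0.2, 1.7]))   # Survived
```

Because the lookup is purely positional, the label list must follow the same class order the model was trained with.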

You can now perform inference from your client like this:

import pandas as pd
import requests

from flash.core.data.utils import download_data

# 1. Download the data
download_data("https://pl-flash-data.s3.amazonaws.com/titanic.zip", "data/")

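df = pd.read_csv("./data/titanic/predict.csv")
# to_csv() with no path returns the CSV text as a str
text = df.to_csv()
# The CSV text is sent under payload -> inputs -> data
body = {"session": "UUID", "payload": {"inputs": {"data": text}}}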
resp = requests.post("http://127.0.0.1:8000/predict", json=body)
print(resp.json())
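The client only has to send JSON with the CSV text nested under `payload`, `inputs`, `data`, as in the `body` dictionary above. If pandas and requests are not available, the same request can be sketched with the standard library alone. This assumes the server started by `model.serve()` is listening at `http://127.0.0.1:8000/predict` as in the example; `build_body` and `post_predict` are hypothetical helpers, and the row reuses the Moran sample from the top of the page without its Survived column.

```python
import csv
import io
import json
import urllib.request


def build_body(header, rows):
    """Serialize rows to CSV text and wrap it in the payload structure used above."""
    buf = io.StringIO()
    writer = csv.writer(buf)  # quotes fields that contain commas, such as Name
    writer.writerow(header)
    writer.writerows(rows)
    return {"session": "UUID", "payload": {"inputs": {"data": buf.getvalue()}}}


def post_predict(body, url="http://127.0.0.1:8000/predict"):
    """POST the JSON body to the serving endpoint and decode the JSON response."""
    req = urllib.request.Request(
        url,
        data=json.dumps(body).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )
    with urllib.request.urlopen(req) as resp:
        return json.loads(resp.read().decode("utf-8"))


header = ["PassengerId", "Pclass", "Name", "Sex", "Age", "SibSp", "Parch",
          "Ticket", "Fare", "Cabin", "Embarked"]
body = build_body(header, [[6, 3, "Moran, Mr. James", "male", "", 0, 0,
                            "330877", 8.4583, "", "Q"]])
# print(post_predict(body))  # requires the server from the previous example to be running
```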