Tabular Classification
The Task
Tabular classification is the task of assigning a class to samples of structured or relational data.
The TabularClassifier task can be used for classification of samples in more than two classes (multi-class classification).
Example
Let’s look at training a model to predict passenger survival on the Titanic using the classic Kaggle data set. The data is provided in CSV files that look like this:
PassengerId,Survived,Pclass,Name,Sex,Age,SibSp,Parch,Ticket,Fare,Cabin,Embarked
1,0,3,"Braund, Mr. Owen Harris",male,22,1,0,A/5 21171,7.25,,S
3,1,3,"Heikkinen, Miss. Laina",female,26,0,0,STON/O2. 3101282,7.925,,S
5,0,3,"Allen, Mr. William Henry",male,35,0,0,373450,8.05,,S
6,0,3,"Moran, Mr. James",male,,0,0,330877,8.4583,,Q
...
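Before choosing columns, it can help to check how the file parses. The following is a minimal sketch using only the Python standard library on the sample rows shown above (not the full data set). It illustrates two properties of this CSV: the Name field is quoted because it contains commas, and some fields (Age, Cabin) are empty.

```python
import csv
import io

# The sample rows shown above, embedded as text so the sketch is self-contained.
SAMPLE = '''PassengerId,Survived,Pclass,Name,Sex,Age,SibSp,Parch,Ticket,Fare,Cabin,Embarked
1,0,3,"Braund, Mr. Owen Harris",male,22,1,0,A/5 21171,7.25,,S
3,1,3,"Heikkinen, Miss. Laina",female,26,0,0,STON/O2. 3101282,7.925,,S
5,0,3,"Allen, Mr. William Henry",male,35,0,0,373450,8.05,,S
6,0,3,"Moran, Mr. James",male,,0,0,330877,8.4583,,Q
'''

# DictReader handles the quoted Name field, so the embedded comma is preserved.
rows = list(csv.DictReader(io.StringIO(SAMPLE)))

# Count empty fields per column.
missing = {name: sum(1 for row in rows if row[name] == "") for name in rows[0]}

print(rows[0]["Name"])
print({name: count for name, count in missing.items() if count})
```

Columns with many empty values may need extra care when deciding whether to treat them as categorical or numerical.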
Once we’ve downloaded the data using download_data(), we can create the TabularClassificationData from our CSV files using the from_csv() method.
From the API reference, we need to provide:
categorical_fields - The name, or a list of the names, of columns that contain categorical data (strings or integers).
numerical_fields - The name, or a list of the names, of columns that contain numerical continuous data (floats).
target_fields - The name of the column we want to predict.
train_file - The path to a CSV file containing the training data, which is read into a Pandas DataFrame.
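One way to get a first draft of the categorical and numerical column lists is to look at the dtypes pandas infers. The sketch below is only a heuristic under stated assumptions: `split_fields` is a hypothetical helper, not part of Flash, and the inline CSV is a trimmed version of the sample rows above. It treats float columns as numerical and everything else as categorical, following the definitions in the list above.

```python
import io

import pandas as pd
from pandas.api.types import is_float_dtype

# Trimmed version of the sample rows above (some columns omitted for brevity).
SAMPLE = """PassengerId,Survived,Pclass,Sex,Age,SibSp,Fare,Embarked
1,0,3,male,22,1,7.25,S
3,1,3,female,26,0,7.925,S
6,0,3,male,,0,8.4583,Q
"""


def split_fields(df, target, ignore=()):
    """Hypothetical helper: float columns -> numerical, all others -> categorical."""
    categorical, numerical = [], []
    for name in df.columns:
        if name == target or name in ignore:
            continue  # skip the target and identifier-like columns
        if is_float_dtype(df[name]):
            numerical.append(name)
        else:
            categorical.append(name)
    return categorical, numerical


df = pd.read_csv(io.StringIO(SAMPLE))
categorical_fields, numerical_fields = split_fields(
    df, target="Survived", ignore=("PassengerId",)
)
print(categorical_fields)
print(numerical_fields)
```

Note that Age is read as a float column here because of its missing value, so the heuristic places it with the numerical columns, whereas the full example on this page lists Age as categorical. Treat the output as a starting point to review by hand.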
Next, we create the TabularClassifier and train it on the Titanic data.
We then use the trained TabularClassifier for inference.
Finally, we save the model.
Here’s the full example:
import torch
import flash
from flash.core.data.utils import download_data
from flash.tabular import TabularClassificationData, TabularClassifier
# 1. Create the DataModule
download_data("https://pl-flash-data.s3.amazonaws.com/titanic.zip", "./data")
datamodule = TabularClassificationData.from_csv(
    categorical_fields=["Sex", "Age", "SibSp", "Parch", "Ticket", "Cabin", "Embarked"],
    numerical_fields="Fare",
    target_fields="Survived",
    train_file="data/titanic/titanic.csv",
    val_split=0.1,
    batch_size=8,
)
# 2. Build the task
model = TabularClassifier.from_data(datamodule, backbone="fttransformer")
# 3. Create the trainer and train the model
trainer = flash.Trainer(max_epochs=3, gpus=torch.cuda.device_count())
trainer.fit(model, datamodule=datamodule)
# 4. Generate predictions from a CSV
datamodule = TabularClassificationData.from_csv(
    predict_file="data/titanic/titanic.csv",
    parameters=datamodule.parameters,
    batch_size=8,
)
predictions = trainer.predict(model, datamodule=datamodule, output="classes")
print(predictions)
# 5. Save the model!
trainer.save_checkpoint("tabular_classification_model.pt")
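The printed predictions can be made easier to read by mapping class indices to names. The sketch below assumes that `predictions` is a list of batches, each a list of integer class indices; that shape is an assumption about the output and should be checked against what your installed version prints. `to_labels` is a hypothetical helper, the label names are the ones used in the Serving section of this page, and the example values are made up.

```python
from itertools import chain

# Label names as used in the Serving section; index 0 = "Did not survive".
LABELS = ["Did not survive", "Survived"]


def to_labels(batched_predictions, labels=LABELS):
    """Hypothetical helper: flatten per-batch class indices and map them to names."""
    flat = chain.from_iterable(batched_predictions)
    return [labels[int(index)] for index in flat]


# Made-up class indices standing in for real model output.
example_predictions = [[0, 1, 1], [0]]
print(to_labels(example_predictions))
```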
Flash Zero
The tabular classifier can be used directly from the command line with zero code using Flash Zero. You can run the above example with:
flash tabular_classifier
To view configuration options and options for running the tabular classifier with your own data, use:
flash tabular_classifier --help
Serving
The TabularClassifier is servable.
This means you can call .serve to serve your Task.
Here’s an example:
from flash.core.classification import LabelsOutput
from flash.tabular import TabularClassifier
model = TabularClassifier.load_from_checkpoint(
    "https://flash-weights.s3.amazonaws.com/0.7.0/tabular_classification_model.pt"
)
model.output = LabelsOutput(["Did not survive", "Survived"])
model.serve()
You can now perform inference from your client like this:
import pandas as pd
import requests
from flash.core.data.utils import download_data
# 1. Download the data
download_data("https://pl-flash-data.s3.amazonaws.com/titanic.zip", "data/")
# 2. Serialize the rows to predict as CSV text (to_csv() already returns a string)
df = pd.read_csv("./data/titanic/predict.csv")
text = df.to_csv()
# 3. Send the request to the running server and print the response
body = {"session": "UUID", "payload": {"inputs": {"data": text}}}
resp = requests.post("http://127.0.0.1:8000/predict", json=body)
print(resp.json())
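If you send requests from several places, the request body can be built in one function. This is a minimal sketch using only the standard library: `build_predict_body` is a hypothetical helper, the body layout is copied from the client example above, and the sample row is made up for illustration rather than taken from predict.csv.

```python
import json


def build_predict_body(csv_text, session="UUID"):
    """Hypothetical helper: wrap CSV text in the body layout used by the client above."""
    if not csv_text.strip():
        raise ValueError("csv_text must contain at least a header row")
    return {"session": session, "payload": {"inputs": {"data": csv_text}}}


# Made-up row for illustration only.
sample_csv = "Pclass,Sex,Fare\n3,male,7.5\n"
body = build_predict_body(sample_csv)

# The body is plain JSON-serializable data.
print(json.dumps(body, indent=2))
```

The resulting dictionary can be passed as the `json=` argument to `requests.post`, as in the client example.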