Graph Classification
The Task
This task consists of classifying graphs: it predicts which ‘class’ a graph belongs to, where a class is a label that indicates the kind of graph. For example, a label may indicate whether one molecule interacts with another.
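Unlike node classification, graph classification assigns one label to a whole graph. A common recipe is to pool per-node feature vectors into a single graph-level vector (a "readout") and then score that vector once per class. The sketch below shows only this readout-and-score step in plain Python. It assumes mean pooling and hand-picked, hypothetical weights, and it ignores edges entirely. A real graph neural network, such as the one behind GraphClassifier, would first update node features with learned message-passing layers that use the edges. This is not Flash's implementation.

```python
from typing import List


def mean_pool(node_features: List[List[float]]) -> List[float]:
    """Average the node feature vectors into one graph-level vector (readout)."""
    num_nodes = len(node_features)
    dim = len(node_features[0])
    return [sum(node[i] for node in node_features) / num_nodes for i in range(dim)]


def classify_graph(
    node_features: List[List[float]],
    weights: List[List[float]],
    biases: List[float],
) -> int:
    """Score each class with a linear layer on the pooled vector; return the best class index."""
    pooled = mean_pool(node_features)
    scores = [
        sum(w * x for w, x in zip(row, pooled)) + b
        for row, b in zip(weights, biases)
    ]
    return max(range(len(scores)), key=scores.__getitem__)


# Hypothetical toy graph: 3 nodes, 2 features per node (edges are not used here).
nodes = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
# Hypothetical parameters for 2 classes: one weight row and one bias per class.
weights = [[1.0, -1.0], [-1.0, 1.0]]
biases = [0.0, 0.1]

print(classify_graph(nodes, weights, biases))  # 1
```

Because the pooled vector has a fixed size regardless of how many nodes a graph has, graphs of different sizes can share the same classifier.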
The GraphClassifier and GraphClassificationData classes internally rely on pytorch-geometric.
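In pytorch-geometric, a graph's connectivity is stored as `edge_index`, a `[2, num_edges]` tensor in COO format. Row 0 holds source nodes and row 1 holds target nodes, and an undirected edge is listed once in each direction. The minimal sketch below builds that layout from an edge list using plain Python lists instead of a tensor. `to_edge_index` is a hypothetical helper written for illustration; it is not part of pytorch-geometric or Flash.

```python
from typing import List, Tuple


def to_edge_index(edges: List[Tuple[int, int]], undirected: bool = True) -> List[List[int]]:
    """Build a COO-style [sources, targets] pair of lists from an edge list.

    Hypothetical helper mirroring the layout of pytorch-geometric's `edge_index`.
    """
    sources: List[int] = []
    targets: List[int] = []
    for u, v in edges:
        sources.append(u)
        targets.append(v)
        if undirected and u != v:
            # Undirected graphs store the reverse direction as a separate edge.
            sources.append(v)
            targets.append(u)
    return [sources, targets]


# Path graph 0 - 1 - 2
edge_index = to_edge_index([(0, 1), (1, 2)])
print(edge_index)  # [[0, 1, 1, 2], [1, 0, 2, 1]]
```

With torch and pytorch-geometric installed, such lists would typically be converted with `torch.tensor(edge_index, dtype=torch.long)` and passed to `torch_geometric.data.Data(x=..., edge_index=...)` together with node features; datasets such as TUDataset already provide graphs in this form.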
Example
Let’s look at the task of classifying graphs in the KKI data set, provided by TU Dortmund University.
Once we’ve created the TUDataset, we create the GraphClassificationData. We then create our GraphClassifier and train on the KKI data. Next, we use the trained GraphClassifier for inference. Finally, we save the model.
Here’s the full example:
import flash
import torch
from flash.core.utilities.imports import example_requires
from flash.graph import GraphClassificationData, GraphClassifier
example_requires("graph")
from torch_geometric.datasets import TUDataset # noqa: E402
# 1. Create the DataModule
dataset = TUDataset(root="data", name="KKI")
datamodule = GraphClassificationData.from_datasets(
train_dataset=dataset,
val_split=0.1,
batch_size=4,
)
# 2. Build the task
backbone_kwargs = {"hidden_channels": 512, "num_layers": 4}
model = GraphClassifier(
num_features=datamodule.num_features, num_classes=datamodule.num_classes, backbone_kwargs=backbone_kwargs
)
# 3. Create the trainer and fit the model
trainer = flash.Trainer(max_epochs=3, gpus=torch.cuda.device_count())
trainer.fit(model, datamodule=datamodule)
# 4. Classify some graphs!
datamodule = GraphClassificationData.from_datasets(
predict_dataset=dataset[:3],
batch_size=4,
)
predictions = trainer.predict(model, datamodule=datamodule, output="classes")
print(predictions)
# 5. Save the model!
trainer.save_checkpoint("graph_classification_model.pt")
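Passing `output="classes"` to `trainer.predict` asks for class labels rather than raw model scores. The minimal sketch below shows the underlying idea. It assumes the model emits one logit per class, with the predicted class being the index of the largest logit. `logits_to_classes` is a hypothetical helper, and Flash's own output handling may differ in details, for example for multi-label problems.

```python
from typing import List


def logits_to_classes(batch_logits: List[List[float]]) -> List[int]:
    """Return the argmax class index for each sample's logits (hypothetical helper)."""
    return [max(range(len(row)), key=row.__getitem__) for row in batch_logits]


# Hypothetical logits for three graphs and two classes.
batch = [[2.0, -1.0], [0.3, 0.9], [-0.5, -0.2]]
print(logits_to_classes(batch))  # [0, 1, 1]
```

Taking the argmax of the logits gives the same result as taking it after a softmax, because softmax preserves the ordering of the scores.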
To learn how to view the available backbones and heads for this task, see Backbones and Heads.
Flash Zero
The graph classifier can be used directly from the command line with zero code using Flash Zero. You can run the above example with:
flash graph_classification
To view the configuration options, including how to run the graph classifier on your own data, use:
flash graph_classification --help