
Graph Classification

The Task

This task consists of classifying graphs: it predicts which ‘class’ a graph belongs to. A class is a label that indicates the kind of graph. For example, a label may indicate whether one molecule interacts with another.
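The toy sketch below makes "graph-level" prediction concrete in plain Python. It is not Flash's implementation. The node feature vectors of a graph are averaged into a single graph vector, and a linear layer turns that vector into one score per class. The highest-scoring class is the prediction. The weights and graphs are hypothetical. The sketch also ignores edges on purpose: a real GNN backbone first mixes information between neighbouring nodes before pooling.

```python
# Toy graph-level classifier: pool node features, then score each class.
# WEIGHTS, BIASES and the example graphs are hypothetical illustration values.

def mean_pool(node_features):
    """Average the feature vectors of all nodes into one graph vector."""
    num_nodes = len(node_features)
    dim = len(node_features[0])
    return [sum(x[d] for x in node_features) / num_nodes for d in range(dim)]


def class_scores(graph_vector, weights, biases):
    """Linear layer: one score per class."""
    return [sum(w * g for w, g in zip(row, graph_vector)) + b
            for row, b in zip(weights, biases)]


def predict_class(node_features, weights, biases):
    """Return the index of the highest-scoring class (argmax)."""
    scores = class_scores(mean_pool(node_features), weights, biases)
    return max(range(len(scores)), key=scores.__getitem__)


# Two features per node, two classes (hypothetical).
WEIGHTS = [[1.0, -1.0], [-1.0, 1.0]]
BIASES = [0.0, 0.0]

graph_a = [[1.0, 0.0], [0.8, 0.2], [0.9, 0.1]]  # feature 0 dominates
graph_b = [[0.1, 0.9], [0.0, 1.0]]              # feature 1 dominates

print(predict_class(graph_a, WEIGHTS, BIASES))  # 0
print(predict_class(graph_b, WEIGHTS, BIASES))  # 1
```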

The GraphClassifier and GraphClassificationData classes internally rely on PyTorch Geometric.
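PyTorch Geometric stores a graph's connectivity as an edge_index array of shape 2 × num_edges in coordinate (COO) format. The first row holds source nodes and the second row holds target nodes. Undirected edges are usually stored once in each direction.

The sketch below builds that layout with plain Python lists instead of a torch tensor. The function name to_edge_index is hypothetical. The sketch assumes the input contains no self-loops, because a self-loop would be emitted twice.

```python
def to_edge_index(undirected_edges):
    """Convert undirected (u, v) pairs into a 2 x E COO list of lists.

    Each undirected edge is emitted in both directions, mirroring how
    undirected graphs are typically stored for PyTorch Geometric.
    Assumes no self-loops (u == v would be duplicated).
    """
    sources, targets = [], []
    for u, v in undirected_edges:
        sources += [u, v]
        targets += [v, u]
    return [sources, targets]


# A triangle on nodes 0, 1, 2 (hypothetical example graph).
edge_index = to_edge_index([(0, 1), (1, 2), (2, 0)])
print(edge_index)  # [[0, 1, 1, 2, 2, 0], [1, 0, 2, 1, 0, 2]]
```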


Example

Let’s look at the task of classifying graphs from the KKI data set from TU Dortmund University.
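The TU Dortmund benchmark collection distributes each data set as plain-text files. For a data set named DS, these include:

- DS_A.txt: one "row, col" edge per line, using 1-based node ids shared across all graphs.
- DS_graph_indicator.txt: line i gives the graph id of node i.
- DS_graph_labels.txt: line j gives the class of graph j.

TUDataset downloads and parses these files for you. The sketch below is a simplified, hypothetical parser for just those three files. It works on made-up strings rather than the real KKI files. It is not PyTorch Geometric's loader, and it ignores the optional node-label and attribute files.

```python
from collections import defaultdict


def parse_tu_graphs(a_txt, indicator_txt, labels_txt):
    """Split TU-format text into per-graph (edge_list, label) pairs.

    a_txt:         one "row, col" edge per line, 1-based global node ids
    indicator_txt: line i holds the 1-based graph id of node i
    labels_txt:    line j holds the class label of graph j
    """
    node_to_graph = [int(tok) for tok in indicator_txt.split()]
    labels = [int(tok) for tok in labels_txt.split()]

    # Renumber nodes so that each graph's nodes start at 0.
    local_index, seen = [], defaultdict(int)
    for graph_id in node_to_graph:
        local_index.append(seen[graph_id])
        seen[graph_id] += 1

    # Assign every edge to the graph that owns its source node.
    edges = defaultdict(list)
    for line in a_txt.strip().splitlines():
        row, col = (int(tok) for tok in line.split(","))
        graph_id = node_to_graph[row - 1]
        edges[graph_id].append((local_index[row - 1], local_index[col - 1]))

    return [(edges[g], labels[g - 1]) for g in range(1, len(labels) + 1)]


# Made-up data: a 3-node path (graph 1) and a single edge (graph 2).
graphs = parse_tu_graphs(
    a_txt="1, 2\n2, 1\n2, 3\n3, 2\n4, 5\n5, 4",
    indicator_txt="1\n1\n1\n2\n2",
    labels_txt="0\n1",
)
for edges, label in graphs:
    print(label, edges)
# 0 [(0, 1), (1, 0), (1, 2), (2, 1)]
# 1 [(0, 1), (1, 0)]
```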

Once we’ve created the TUDataset, we create the GraphClassificationData. We then create our GraphClassifier and train on the KKI data. Next, we use the trained GraphClassifier for inference. Finally, we save the model. Here’s the full example:

import flash
import torch
from flash.core.utilities.imports import example_requires
from flash.graph import GraphClassificationData, GraphClassifier

example_requires("graph")

from torch_geometric.datasets import TUDataset  # noqa: E402

# 1. Create the DataModule
dataset = TUDataset(root="data", name="KKI")

datamodule = GraphClassificationData.from_datasets(
    train_dataset=dataset,
    val_split=0.1,
    batch_size=4,
)
# 2. Build the task
backbone_kwargs = {"hidden_channels": 512, "num_layers": 4}
model = GraphClassifier(
    num_features=datamodule.num_features, num_classes=datamodule.num_classes, backbone_kwargs=backbone_kwargs
)

# 3. Create the trainer and fit the model
trainer = flash.Trainer(max_epochs=3, gpus=torch.cuda.device_count())
trainer.fit(model, datamodule=datamodule)

# 4. Classify some graphs!
datamodule = GraphClassificationData.from_datasets(
    predict_dataset=dataset[:3],
    batch_size=4,
)
predictions = trainer.predict(model, datamodule=datamodule, output="classes")
print(predictions)

# 5. Save the model!
trainer.save_checkpoint("graph_classification_model.pt")
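With batch_size=4, each training step sees up to four graphs. The prediction step above fits its three graphs into a single batch.

PyTorch Geometric batches graphs by merging them into one large, disconnected graph. Each graph's edge indices are shifted by the number of nodes that precede it. A batch vector records which graph every node came from, so graph-level pooling can later be done per graph.

The following is a simplified, hypothetical sketch of that collation step on plain lists. The name collate and the (num_nodes, edge_index) input layout are assumptions. PyTorch Geometric's real batching also concatenates node features and labels.

```python
def collate(graphs):
    """Merge small graphs into one disconnected graph, PyG-style.

    Each graph is a (num_nodes, [sources, targets]) pair with local,
    0-based node ids. Returns the merged edge_index and a batch vector
    mapping every node to the position of its graph in the input list.
    """
    sources, targets, batch = [], [], []
    offset = 0
    for graph_id, (num_nodes, (src, dst)) in enumerate(graphs):
        # Shift this graph's node ids past all nodes of earlier graphs.
        sources += [s + offset for s in src]
        targets += [t + offset for t in dst]
        batch += [graph_id] * num_nodes
        offset += num_nodes
    return [sources, targets], batch


# Hypothetical mini-batch: a 2-node graph followed by a 3-node path.
g0 = (2, ([0, 1], [1, 0]))
g1 = (3, ([0, 1, 1, 2], [1, 0, 2, 1]))
edge_index, batch = collate([g0, g1])
print(edge_index)  # [[0, 1, 2, 3, 3, 4], [1, 0, 3, 2, 4, 3]]
print(batch)       # [0, 0, 1, 1, 1]
```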

To learn how to view the available backbones / heads for this task, see Backbones and Heads.
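The backbone_kwargs in the example (hidden_channels=512, num_layers=4) are presumably forwarded to the chosen graph neural network backbone. In message-passing GNNs, num_layers controls how many rounds of neighbour aggregation are applied. After k layers, a node's representation can depend on nodes up to k hops away.

The parameter-free sketch below illustrates only that receptive-field effect, using mean aggregation over scalar features. It has no learned weights, so it does not model hidden_channels, which would set the width of those learned per-layer representations. It is not one of Flash's backbones.

```python
def mean_aggregate(features, edge_index):
    """One parameter-free message-passing step.

    Each node's new value is the mean of its own value and the values of
    its in-neighbours (edge_index is [sources, targets]).
    """
    sums = list(features)
    counts = [1] * len(features)
    for src, dst in zip(*edge_index):
        sums[dst] += features[src]
        counts[dst] += 1
    return [s / c for s, c in zip(sums, counts)]


def stack_layers(features, edge_index, num_layers):
    """Apply mean_aggregate num_layers times."""
    for _ in range(num_layers):
        features = mean_aggregate(features, edge_index)
    return features


# Path graph 0 - 1 - 2 - 3, stored in both directions.
PATH = [[0, 1, 1, 2, 2, 3], [1, 0, 2, 1, 3, 2]]
signal = [1.0, 0.0, 0.0, 0.0]  # only node 0 starts non-zero

for k in range(1, 4):
    reached = [i for i, x in enumerate(stack_layers(signal, PATH, k)) if x > 0]
    print(k, reached)
# 1 [0, 1]
# 2 [0, 1, 2]
# 3 [0, 1, 2, 3]
```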


Flash Zero

The graph classifier can be used directly from the command line with zero code using Flash Zero. You can run the above example with:

flash graph_classification

To view the configuration options, including how to run the graph classifier on your own data, use:

flash graph_classification --help