
GraphClassifier

class flash.graph.classification.model.GraphClassifier(num_features, num_classes, backbone='GCN', backbone_kwargs={}, pooling_fn='mean', head=None, loss_fn=torch.nn.functional.cross_entropy, learning_rate=0.001, optimizer='Adam', lr_scheduler=None, metrics=None)

The GraphClassifier is a Task for classifying graphs. For more details, see Graph Classification.
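As a rough mental model, suggested by the backbone, pooling_fn and head parameters, a graph classifier runs three stages. A backbone computes node embeddings, a global pooling step reduces each graph's nodes to a single vector, and a head maps that vector to class scores. The pure-Python sketch below walks one toy graph through those stages. It is not Flash's implementation: toy_backbone, toy_mean_pool, toy_head and classify_graph are hypothetical names, and the single neighbour-summing step only stands in for a real GNN backbone such as GCN.

```python
from typing import List, Sequence, Tuple

Vector = List[float]


def toy_backbone(x: Sequence[Vector], edges: Sequence[Tuple[int, int]]) -> List[Vector]:
    """Hypothetical stand-in for a GNN layer: each node adds its neighbours' features to its own."""
    out = [list(row) for row in x]
    for src, dst in edges:  # edges are treated as undirected
        for k in range(len(x[src])):
            out[dst][k] += x[src][k]
            out[src][k] += x[dst][k]
    return out


def toy_mean_pool(h: Sequence[Vector]) -> Vector:
    """Global mean pooling: average all node embeddings into one graph-level vector."""
    return [sum(col) / len(h) for col in zip(*h)]


def toy_head(z: Vector, weights: Sequence[Vector], bias: Vector) -> Vector:
    """Linear head: one logit per class, computed as weights @ z + bias."""
    return [sum(w * v for w, v in zip(row, z)) + b for row, b in zip(weights, bias)]


def classify_graph(
    x: Sequence[Vector],
    edges: Sequence[Tuple[int, int]],
    weights: Sequence[Vector],
    bias: Vector,
) -> int:
    """Run backbone -> pooling -> head and return the index of the highest-scoring class."""
    logits = toy_head(toy_mean_pool(toy_backbone(x, edges)), weights, bias)
    return max(range(len(logits)), key=logits.__getitem__)


# A 3-node path graph with 2 features per node and 2 classes (all values made up).
x = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
edges = [(0, 1), (1, 2)]
weights = [[1.0, -1.0], [-1.0, 1.0]]
bias = [0.0, 0.0]
print(classify_graph(x, edges, weights, bias))  # prints 1
```

Replacing the mean with a max or a sum in toy_mean_pool corresponds to the other pooling choices listed for pooling_fn.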

Parameters
  • num_features (int) – The number of features in the input.

  • num_classes (int) – Number of classes to classify.

  • backbone (Union[str, Tuple[Module, int]]) – Name of the backbone to use, or a tuple of a backbone module and the number of features it outputs.

  • backbone_kwargs (Optional[Dict]) – Keyword arguments for the backbone. The accepted keys depend on the backbone, for example in_channels, out_channels, hidden_channels or depth (number of layers).

  • pooling_fn (Union[str, Callable, None]) – The global pooling operation to use (one of: “mean”, “max”, “add” or a callable).

  • head (Union[Callable, Module, None]) – The head applied to the pooled graph representation to produce the class scores.

  • loss_fn (LOSS_FN_TYPE) – Loss function for training; defaults to cross entropy.

  • learning_rate (float) – Learning rate to use for training.

  • optimizer (OPTIMIZER_TYPE) – Optimizer to use for training.

  • lr_scheduler (Optional[LR_SCHEDULER_TYPE]) – The LR scheduler to use during training.

  • metrics (Optional[METRICS_TYPE]) – Metrics to compute for training and evaluation.
