GraphClassifier¶
- class flash.graph.classification.model.GraphClassifier(num_features, num_classes=None, labels=None, backbone='GCN', backbone_kwargs={}, pooling_fn='mean', head=None, loss_fn=torch.nn.functional.cross_entropy, learning_rate=None, optimizer='Adam', lr_scheduler=None, metrics=None)[source]¶
The GraphClassifier is a Task for classifying graphs. For more details, see Graph Classification.
- Parameters
  - backbone¶ (Union[str, Tuple[Module, int]]) – Name of the backbone to use.
  - backbone_kwargs¶ (Optional[Dict]) – Dictionary dependent on the backbone, containing for example in_channels, out_channels, hidden_channels or depth (number of layers).
  - pooling_fn¶ (Union[str, Callable, None]) – The global pooling operation to use (one of: "mean", "max", "add" or a callable).
  - loss_fn¶ (TypeVar(LOSS_FN_TYPE, Callable, Mapping, Sequence, None)) – Loss function for training, defaults to cross entropy.
  - learning_rate¶ (Optional[float]) – Learning rate to use for training.
  - optimizer¶ (TypeVar(OPTIMIZER_TYPE, str, Callable, Tuple[str, Dict[str, Any]], None)) – Optimizer to use for training.
  - lr_scheduler¶ (Optional[TypeVar(LR_SCHEDULER_TYPE, str, Callable, Tuple[str, Dict[str, Any]], Tuple[str, Dict[str, Any], Dict[str, Any]], None)]) – The LR scheduler to use during training.
  - metrics¶ (Optional[TypeVar(METRICS_TYPE, Metric, Mapping, Sequence, None)]) – Metrics to compute for training and evaluation.
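Each of the built-in pooling_fn modes reduces the per-node features of a graph to a single graph-level vector before the head produces class scores. As a rough illustration only (plain-Python stand-ins; the library itself applies batched pooling operators to PyTorch tensors), the three modes compute:

```python
# Illustrative stand-ins for the "mean", "max" and "add" global pooling
# modes. These are hypothetical helpers, not the library's implementation:
# each reduces a list of per-node feature vectors to one graph-level vector.

def global_add_pool(node_features):
    # Element-wise sum over all nodes.
    return [sum(col) for col in zip(*node_features)]

def global_mean_pool(node_features):
    # Element-wise mean over all nodes.
    return [sum(col) / len(node_features) for col in zip(*node_features)]

def global_max_pool(node_features):
    # Element-wise maximum over all nodes.
    return [max(col) for col in zip(*node_features)]

nodes = [[1.0, 2.0], [3.0, 4.0], [5.0, 0.0]]  # 3 nodes, 2 features each
print(global_add_pool(nodes))   # [9.0, 6.0]
print(global_mean_pool(nodes))  # [3.0, 2.0]
print(global_max_pool(nodes))   # [5.0, 4.0]
```

Passing a callable as pooling_fn lets you substitute your own reduction in place of these modes.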
- classmethod available_finetuning_strategies()¶
  Returns a list containing the keys of the available finetuning strategies.
- classmethod available_lr_schedulers()¶
  Returns a list containing the keys of the available LR schedulers.
- classmethod available_optimizers()¶
  Returns a list containing the keys of the available optimizers.