GraphClassificationData

class flash.graph.classification.data.GraphClassificationData(train_input=None, val_input=None, test_input=None, predict_input=None, data_fetcher=None, val_split=None, batch_size=None, num_workers=0, sampler=None, pin_memory=True, persistent_workers=False)

The GraphClassificationData class is a DataModule with a set of classmethods for loading data for graph classification.

classmethod from_datasets(train_dataset=None, val_dataset=None, test_dataset=None, predict_dataset=None, train_transform=<class 'flash.graph.classification.input_transform.GraphClassificationInputTransform'>, val_transform=<class 'flash.graph.classification.input_transform.GraphClassificationInputTransform'>, test_transform=<class 'flash.graph.classification.input_transform.GraphClassificationInputTransform'>, predict_transform=<class 'flash.graph.classification.input_transform.GraphClassificationInputTransform'>, target_formatter=None, input_cls=<class 'flash.graph.classification.input.GraphClassificationDatasetInput'>, transform_kwargs=None, **data_module_kwargs)

Load the GraphClassificationData from PyTorch Dataset objects.

Each Dataset object should be one of the following:

  • A PyTorch Dataset whose __getitem__ returns a tuple: (PyTorch Geometric Data object, target)

  • A PyTorch Dataset whose __getitem__ returns a dict:

    {"input": PyTorch Geometric Data object, "target": target}

To learn how to customize the transforms applied for each stage, read our customizing transforms guide.
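
The two sample layouts listed above carry the same information. The following dependency-free sketch illustrates that equivalence. Note that normalize_sample is a hypothetical helper written only for this illustration; it is not part of Flash and is not meant to describe how Flash processes samples internally.

```python
from typing import Any, Dict


def normalize_sample(sample: Any) -> Dict[str, Any]:
    """Hypothetical helper: map either sample layout onto the dict layout."""
    if isinstance(sample, dict):
        # Already in dict form: {"input": ..., "target": ...} (target optional).
        return dict(sample)
    if isinstance(sample, tuple) and len(sample) == 2:
        # Tuple form: (graph, target).
        graph, target = sample
        return {"input": graph, "target": target}
    # A bare graph with no target, as an unlabeled prediction sample would be.
    return {"input": sample}


# A plain string stands in for a PyTorch Geometric Data object here.
graph = "graph-placeholder"

as_tuple = normalize_sample((graph, "cat"))
as_dict = normalize_sample({"input": graph, "target": "cat"})
unlabeled = normalize_sample(graph)
```

Under these assumptions, as_tuple and as_dict compare equal, and unlabeled contains only the "input" key.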

Parameters
Return type

GraphClassificationData

Returns

The constructed GraphClassificationData.

Examples

A PyTorch Dataset whose __getitem__ returns a tuple: (PyTorch Geometric Data object, target):

>>> import torch
>>> from torch.utils.data import Dataset
>>> from torch_geometric.data import Data
>>> from flash import Trainer
>>> from flash.graph import GraphClassificationData, GraphClassifier
>>> from flash.core.data.utilities.classification import SingleLabelTargetFormatter
>>>
>>> class CustomDataset(Dataset):
...     def __init__(self, targets=None):
...         self.targets = targets
...     def __getitem__(self, index):
...         edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]], dtype=torch.long)
...         x = torch.tensor([[-1], [0], [1]], dtype=torch.float)
...         data = Data(x=x, edge_index=edge_index)
...         if self.targets is not None:
...             return data, self.targets[index]
...         return data
...     def __len__(self):
...         return len(self.targets) if self.targets is not None else 3
...
>>> datamodule = GraphClassificationData.from_datasets(
...     train_dataset=CustomDataset(["cat", "dog", "cat"]),
...     predict_dataset=CustomDataset(),
...     target_formatter=SingleLabelTargetFormatter(labels=["cat", "dog"]),
...     batch_size=2,
... )
>>> datamodule.num_features
1
>>> datamodule.num_classes
2
>>> datamodule.labels
['cat', 'dog']
>>> model = GraphClassifier(num_features=datamodule.num_features, num_classes=datamodule.num_classes)
>>> trainer = Trainer(fast_dev_run=True)
>>> trainer.fit(model, datamodule=datamodule)  
Training...
>>> trainer.predict(model, datamodule=datamodule)  
Predicting...

A PyTorch Dataset whose __getitem__ returns a dict: {"input": PyTorch Geometric Data object, "target": target}:

>>> import torch  # noqa: F811
>>> from torch.utils.data import Dataset
>>> from torch_geometric.data import Data  # noqa: F811
>>> from flash import Trainer
>>> from flash.graph import GraphClassificationData, GraphClassifier
>>> from flash.core.data.utilities.classification import SingleLabelTargetFormatter
>>>
>>> class CustomDataset(Dataset):
...     def __init__(self, targets=None):
...         self.targets = targets
...     def __getitem__(self, index):
...         edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]], dtype=torch.long)
...         x = torch.tensor([[-1], [0], [1]], dtype=torch.float)
...         data = Data(x=x, edge_index=edge_index)
...         if self.targets is not None:
...             return {"input": data, "target": self.targets[index]}
...         return {"input": data}
...     def __len__(self):
...         return len(self.targets) if self.targets is not None else 3
...
>>> datamodule = GraphClassificationData.from_datasets(
...     train_dataset=CustomDataset(["cat", "dog", "cat"]),
...     predict_dataset=CustomDataset(),
...     target_formatter=SingleLabelTargetFormatter(labels=["cat", "dog"]),
...     batch_size=2,
... )
>>> datamodule.num_features
1
>>> datamodule.num_classes
2
>>> datamodule.labels
['cat', 'dog']
>>> model = GraphClassifier(num_features=datamodule.num_features, num_classes=datamodule.num_classes)
>>> trainer = Trainer(fast_dev_run=True)
>>> trainer.fit(model, datamodule=datamodule)  
Training...
>>> trainer.predict(model, datamodule=datamodule)  
Predicting...

property num_features

The number of features per node in the graphs contained in this GraphClassificationData.
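
As a rough illustration of what this number refers to: a node feature matrix has one row per node and one column per feature, so the feature count is the row width. The sketch below uses plain Python lists in place of the x tensor of a PyTorch Geometric Data object. The name count_node_features is hypothetical and used only here; it is not a Flash function.

```python
from typing import Sequence


def count_node_features(x: Sequence[Sequence[float]]) -> int:
    """Hypothetical helper: width of a (num_nodes, num_features) matrix."""
    if not x:
        return 0
    widths = {len(row) for row in x}
    if len(widths) != 1:
        raise ValueError("every node must have the same number of features")
    return widths.pop()


# Same values as the x tensor in the examples above: three nodes, one feature each.
node_features = [[-1.0], [0.0], [1.0]]
num_features = count_node_features(node_features)
```

With a real Data object holding a two-dimensional x tensor, the corresponding quantity is data.x.shape[1].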
