
GraphClassificationData

class flash.graph.classification.data.GraphClassificationData(train_input=None, val_input=None, test_input=None, predict_input=None, data_fetcher=None, transform=<class 'flash.core.data.io.input_transform.InputTransform'>, transform_kwargs=None, val_split=None, batch_size=None, num_workers=0, sampler=None, pin_memory=True, persistent_workers=False)

The GraphClassificationData class is a DataModule with a set of classmethods for loading data for graph classification.
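The constructor also accepts val_split, which this page does not describe further. The snippet below is a minimal sketch assuming val_split is a float fraction of the training samples held out for validation; split_train_val and its seed argument are hypothetical illustrations in plain Python, not part of Flash.

```python
import random


def split_train_val(num_samples, val_split, seed=0):
    """Hypothetical helper: hold out a fraction of training indices for validation.

    Illustrates the idea behind a ``val_split`` fraction only; this is not
    Flash's implementation.
    """
    if not 0.0 < val_split < 1.0:
        raise ValueError("val_split must be a float strictly between 0 and 1")
    indices = list(range(num_samples))
    random.Random(seed).shuffle(indices)  # deterministic shuffle for reproducibility
    num_val = int(num_samples * val_split)
    return indices[num_val:], indices[:num_val]  # (train_indices, val_indices)


train_idx, val_idx = split_train_val(10, val_split=0.2)
print(len(train_idx), len(val_idx))  # 8 2
```

Shuffling before slicing keeps the held-out portion from simply being the last samples of the dataset.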

classmethod from_datasets(train_dataset=None, val_dataset=None, test_dataset=None, predict_dataset=None, input_cls=<class 'flash.graph.classification.input.GraphClassificationDatasetInput'>, transform=<class 'flash.graph.classification.input_transform.GraphClassificationInputTransform'>, transform_kwargs=None, target_formatter=None, **data_module_kwargs)

Load the GraphClassificationData from PyTorch Dataset objects.

The Dataset objects should be one of the following:

  • A PyTorch Dataset where the __getitem__ returns a tuple: (PyTorch Geometric Data object, target)

  • A PyTorch Dataset where the __getitem__ returns a dict:

    {"input": PyTorch Geometric Data object, "target": target}
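Both formats carry the same information: a graph in the "input" role and an optional target. As a rough illustration (not Flash's internal code), the hypothetical helper to_sample_dict below normalizes either format, plus a bare graph as returned by a prediction-only dataset, into one dict layout. Plain strings stand in for PyTorch Geometric Data objects so the sketch runs without torch installed.

```python
def to_sample_dict(item):
    """Hypothetical helper: normalize a Dataset item into a dict keyed by "input"/"target".

    Accepts a tuple (graph, target), a dict {"input": graph, "target": target},
    or a bare graph with no target (as returned for prediction).
    """
    if isinstance(item, dict):
        if "input" not in item:
            raise KeyError('dict samples must contain an "input" key')
        return dict(item)
    if isinstance(item, tuple):
        graph, target = item
        return {"input": graph, "target": target}
    return {"input": item}  # prediction-time sample without a target


graph = "graph-0"  # stand-in for a torch_geometric.data.Data object
print(to_sample_dict((graph, "cat")))                      # {'input': 'graph-0', 'target': 'cat'}
print(to_sample_dict({"input": graph, "target": "dog"}))  # {'input': 'graph-0', 'target': 'dog'}
print(to_sample_dict(graph))                               # {'input': 'graph-0'}
```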

To learn how to customize the transforms applied for each stage, read our customizing transforms guide.

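Parameters

  • train_dataset: The Dataset to use when training.

  • val_dataset: The Dataset to use when validating.

  • test_dataset: The Dataset to use when testing.

  • predict_dataset: The Dataset to use when predicting.

  • input_cls: The Input type to use for loading the data.

  • transform: The InputTransform type to use.

  • transform_kwargs: Dict of keyword arguments to be provided when instantiating the transforms.

  • target_formatter: Optionally provide a TargetFormatter to control how targets are handled.

  • data_module_kwargs: Additional keyword arguments to use when constructing the DataModule.

Return type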

GraphClassificationData

Returns

The constructed GraphClassificationData.

Examples

A PyTorch Dataset where the __getitem__ returns a tuple: (PyTorch Geometric Data object, target):

>>> import torch
>>> from torch.utils.data import Dataset
>>> from torch_geometric.data import Data
>>> from flash import Trainer
>>> from flash.graph import GraphClassificationData, GraphClassifier
>>> from flash.core.data.utilities.classification import SingleLabelTargetFormatter
>>>
>>> class CustomDataset(Dataset):
...     def __init__(self, targets=None):
...         self.targets = targets
...     def __getitem__(self, index):
...         edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]], dtype=torch.long)
...         x = torch.tensor([[-1], [0], [1]], dtype=torch.float)
...         data = Data(x=x, edge_index=edge_index)
...         if self.targets is not None:
...             return data, self.targets[index]
...         return data
...     def __len__(self):
...         return len(self.targets) if self.targets is not None else 3
...
>>> datamodule = GraphClassificationData.from_datasets(
...     train_dataset=CustomDataset(["cat", "dog", "cat"]),
...     predict_dataset=CustomDataset(),
...     target_formatter=SingleLabelTargetFormatter(labels=["cat", "dog"]),
...     batch_size=2,
... )
>>> datamodule.num_features
1
>>> datamodule.num_classes
2
>>> datamodule.labels
['cat', 'dog']
>>> model = GraphClassifier(num_features=datamodule.num_features, num_classes=datamodule.num_classes)
>>> trainer = Trainer(fast_dev_run=True)
>>> trainer.fit(model, datamodule=datamodule)  # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
Training...
>>> trainer.predict(model, datamodule=datamodule)  # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
Predicting...
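In the examples, every sample is the same 3-node path graph: edge_index = [[0, 1, 1, 2], [1, 0, 2, 1]] lists source nodes in the first row and target nodes in the second, i.e. the directed edges 0->1, 1->0, 1->2 and 2->1, and x holds one feature per node. When batch_size=2 groups graphs together, PyTorch Geometric (for example via Batch.from_data_list) merges them into one disconnected graph: node rows are concatenated, each graph's edge_index is shifted by the number of nodes before it, and a batch vector records which graph each node came from. Assuming Flash relies on this kind of collation, the idea can be sketched with nested lists; collate_graphs is a hypothetical helper, not library code.

```python
def collate_graphs(graphs):
    """Hypothetical helper: merge graphs into one disconnected graph, PyG-style.

    Each graph is a dict with "x" (list of node feature rows) and
    "edge_index" (two lists: source nodes and target nodes).
    """
    x, all_src, all_dst, batch = [], [], [], []
    offset = 0  # number of nodes already placed in the merged graph
    for graph_id, graph in enumerate(graphs):
        num_nodes = len(graph["x"])
        x.extend(graph["x"])
        src, dst = graph["edge_index"]
        # Shift node indices so they point at this graph's rows in the merged x.
        all_src.extend(s + offset for s in src)
        all_dst.extend(d + offset for d in dst)
        batch.extend([graph_id] * num_nodes)  # which graph each node came from
        offset += num_nodes
    return {"x": x, "edge_index": [all_src, all_dst], "batch": batch}


# The same 3-node path graph that CustomDataset builds in the examples.
path_graph = {"x": [[-1.0], [0.0], [1.0]], "edge_index": [[0, 1, 1, 2], [1, 0, 2, 1]]}

merged = collate_graphs([path_graph, path_graph])  # a batch of size 2
print(merged["edge_index"])  # [[0, 1, 1, 2, 3, 4, 4, 5], [1, 0, 2, 1, 4, 3, 5, 4]]
print(merged["batch"])       # [0, 0, 0, 1, 1, 1]
print(len(merged["x"][0]))   # 1
```

Because edges never cross between the shifted index ranges, the merged graph keeps the two samples separate, and the batch vector is what a graph-level pooling step would use to aggregate nodes back into one output per graph.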

A PyTorch Dataset where the __getitem__ returns a dict: {"input": PyTorch Geometric Data object, "target": target}:

>>> import torch  # noqa: F811
>>> from torch.utils.data import Dataset
>>> from torch_geometric.data import Data  # noqa: F811
>>> from flash import Trainer
>>> from flash.graph import GraphClassificationData, GraphClassifier
>>> from flash.core.data.utilities.classification import SingleLabelTargetFormatter
>>>
>>> class CustomDataset(Dataset):
...     def __init__(self, targets=None):
...         self.targets = targets
...     def __getitem__(self, index):
...         edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]], dtype=torch.long)
...         x = torch.tensor([[-1], [0], [1]], dtype=torch.float)
...         data = Data(x=x, edge_index=edge_index)
...         if self.targets is not None:
...             return {"input": data, "target": self.targets[index]}
...         return {"input": data}
...     def __len__(self):
...         return len(self.targets) if self.targets is not None else 3
...
>>> datamodule = GraphClassificationData.from_datasets(
...     train_dataset=CustomDataset(["cat", "dog", "cat"]),
...     predict_dataset=CustomDataset(),
...     target_formatter=SingleLabelTargetFormatter(labels=["cat", "dog"]),
...     batch_size=2,
... )
>>> datamodule.num_features
1
>>> datamodule.num_classes
2
>>> datamodule.labels
['cat', 'dog']
>>> model = GraphClassifier(num_features=datamodule.num_features, num_classes=datamodule.num_classes)
>>> trainer = Trainer(fast_dev_run=True)
>>> trainer.fit(model, datamodule=datamodule)  # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
Training...
>>> trainer.predict(model, datamodule=datamodule)  # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
Predicting...
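In both examples, SingleLabelTargetFormatter(labels=["cat", "dog"]) fixes the label order, which is why datamodule.labels is ['cat', 'dog'] and datamodule.num_classes is 2. Conceptually, single-label formatting replaces each string target with its position in that list. The sketch below assumes exactly that mapping; format_single_label is a hypothetical stand-in, not the formatter's real implementation.

```python
def format_single_label(targets, labels):
    """Hypothetical stand-in for single-label target formatting.

    Maps each string target to its index in ``labels``; the number of
    classes is simply the number of labels.
    """
    label_to_index = {label: index for index, label in enumerate(labels)}
    unknown = [t for t in targets if t not in label_to_index]
    if unknown:
        raise ValueError(f"targets not found in labels: {sorted(set(unknown))}")
    return [label_to_index[t] for t in targets], len(labels)


indices, num_classes = format_single_label(["cat", "dog", "cat"], labels=["cat", "dog"])
print(indices)      # [0, 1, 0]
print(num_classes)  # 2
```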

property num_features

The number of features per node in the graphs contained in this GraphClassificationData.
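For the example graphs, x has one column (three nodes, one feature each), so num_features is 1 and is passed straight to GraphClassifier. Assuming num_features is the width of each graph's node feature matrix x, it can be sketched as below; infer_num_features is a hypothetical helper that also checks that all graphs agree on that width.

```python
def infer_num_features(node_feature_matrices):
    """Hypothetical helper: number of features per node across a set of graphs.

    Each matrix is a list of per-node feature rows (the equivalent of a graph's ``x``).
    """
    widths = {len(row) for x in node_feature_matrices for row in x}
    if not widths:
        raise ValueError("no nodes found")
    if len(widths) != 1:
        raise ValueError(f"inconsistent node feature sizes: {sorted(widths)}")
    return widths.pop()


# x from the examples: three nodes, one feature each.
example_x = [[-1.0], [0.0], [1.0]]
print(infer_num_features([example_x, example_x, example_x]))  # 1
```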
