GraphClassificationData
- class flash.graph.classification.data.GraphClassificationData(train_input=None, val_input=None, test_input=None, predict_input=None, data_fetcher=None, transform=<class 'flash.core.data.io.input_transform.InputTransform'>, transform_kwargs=None, val_split=None, batch_size=None, num_workers=0, sampler=None, pin_memory=True, persistent_workers=False)
The GraphClassificationData class is a DataModule with a set of classmethods for loading data for graph classification.
- classmethod from_datasets(train_dataset=None, val_dataset=None, test_dataset=None, predict_dataset=None, input_cls=<class 'flash.graph.classification.input.GraphClassificationDatasetInput'>, transform=<class 'flash.graph.classification.input_transform.GraphClassificationInputTransform'>, transform_kwargs=None, target_formatter=None, **data_module_kwargs)
Load the GraphClassificationData from PyTorch Dataset objects.
The Dataset objects should be one of the following:
- A PyTorch Dataset where the __getitem__ returns a tuple: (PyTorch Geometric Data object, target)
- A PyTorch Dataset where the __getitem__ returns a dict: {"input": PyTorch Geometric Data object, "target": target}
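Both accepted item shapes carry the same information in different containers. The sketch below is a hypothetical illustration. `normalize_sample` is not a Flash API, and this page does not show how GraphClassificationDatasetInput actually reads samples. The function maps either shape onto one dict. It also handles a bare graph with no target, which is how the prediction datasets in the Examples section return items. Plain objects stand in for PyTorch Geometric Data instances, so the sketch runs without any graph library.

```python
from typing import Any, Dict


def normalize_sample(sample: Any) -> Dict[str, Any]:
    """Map a raw Dataset item onto a {"input": ..., "target": ...} dict.

    Hypothetical helper for illustration only; it is not part of Flash.
    Accepted shapes:
      * a (data, target) tuple
      * a {"input": data, "target": target} dict (target may be absent)
      * a bare data object with no target (e.g. for prediction)
    """
    if isinstance(sample, dict):
        if "input" not in sample:
            raise KeyError('dict samples must contain an "input" key')
        return {"input": sample["input"], "target": sample.get("target")}
    if isinstance(sample, tuple) and len(sample) == 2:
        data, target = sample
        return {"input": data, "target": target}
    # Anything else is treated as a bare input without a target.
    return {"input": sample, "target": None}


graph = object()  # stand-in for a torch_geometric.data.Data instance
# Both shapes describe the same (graph, "cat") sample.
print(normalize_sample((graph, "cat")) == normalize_sample({"input": graph, "target": "cat"}))  # True
```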
To learn how to customize the transforms applied for each stage, read our customizing transforms guide.
- Parameters
  - train_dataset (Optional[Dataset]) – The Dataset to use when training.
  - val_dataset (Optional[Dataset]) – The Dataset to use when validating.
  - test_dataset (Optional[Dataset]) – The Dataset to use when testing.
  - predict_dataset (Optional[Dataset]) – The Dataset to use when predicting.
  - target_formatter (Optional[TargetFormatter]) – Optionally provide a TargetFormatter to control how targets are handled. If None, no formatting is applied to targets.
  - input_cls (Type[Input]) – The Input type to use for loading the data.
  - transform (TypeVar(INPUT_TRANSFORM_TYPE, Type[flash.core.data.io.input_transform.InputTransform], Callable, Tuple[Union[StrEnum, str], Dict[str, Any]], Union[StrEnum, str], None)) – The InputTransform type to use.
  - transform_kwargs (Optional[Dict]) – Dict of keyword arguments to be provided when instantiating the transforms.
  - data_module_kwargs – Additional keyword arguments to provide to the DataModule constructor.
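The examples in the Examples section pass SingleLabelTargetFormatter(labels=["cat", "dog"]) and then report num_classes as 2 and labels as ['cat', 'dog']. The sketch below shows the general idea of single-label target formatting: each declared class name maps to an integer index. `SimpleLabelFormatter` is a hypothetical stand-in. Flash's own SingleLabelTargetFormatter is not reproduced here and may differ in details such as label inference or error handling.

```python
from typing import Dict, List, Sequence


class SimpleLabelFormatter:
    """Hypothetical stand-in for a single-label target formatter.

    Maps each class name to its position in ``labels``. This only
    illustrates turning raw targets into class indices; it is not
    Flash's SingleLabelTargetFormatter.
    """

    def __init__(self, labels: Sequence[str]):
        self.labels: List[str] = list(labels)
        self.label_to_index: Dict[str, int] = {name: i for i, name in enumerate(self.labels)}

    @property
    def num_classes(self) -> int:
        return len(self.labels)

    def __call__(self, target: str) -> int:
        # Raises KeyError for a label that was not declared up front.
        return self.label_to_index[target]


formatter = SimpleLabelFormatter(labels=["cat", "dog"])
print([formatter(t) for t in ["cat", "dog", "cat"]])  # [0, 1, 0]
print(formatter.num_classes)  # 2
```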
- Return type
  GraphClassificationData
- Returns
  The constructed GraphClassificationData.
Examples
A PyTorch Dataset where the __getitem__ returns a tuple: (PyTorch Geometric Data object, target):
>>> import torch
>>> from torch.utils.data import Dataset
>>> from torch_geometric.data import Data
>>> from flash import Trainer
>>> from flash.graph import GraphClassificationData, GraphClassifier
>>> from flash.core.data.utilities.classification import SingleLabelTargetFormatter
>>>
>>> class CustomDataset(Dataset):
...     def __init__(self, targets=None):
...         self.targets = targets
...     def __getitem__(self, index):
...         edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]], dtype=torch.long)
...         x = torch.tensor([[-1], [0], [1]], dtype=torch.float)
...         data = Data(x=x, edge_index=edge_index)
...         if self.targets is not None:
...             return data, self.targets[index]
...         return data
...     def __len__(self):
...         return len(self.targets) if self.targets is not None else 3
...
>>> datamodule = GraphClassificationData.from_datasets(
...     train_dataset=CustomDataset(["cat", "dog", "cat"]),
...     predict_dataset=CustomDataset(),
...     target_formatter=SingleLabelTargetFormatter(labels=["cat", "dog"]),
...     batch_size=2,
... )
>>> datamodule.num_features
1
>>> datamodule.num_classes
2
>>> datamodule.labels
['cat', 'dog']
>>> model = GraphClassifier(num_features=datamodule.num_features, num_classes=datamodule.num_classes)
>>> trainer = Trainer(fast_dev_run=True)
>>> trainer.fit(model, datamodule=datamodule)
Training...
>>> trainer.predict(model, datamodule=datamodule)
Predicting...
A PyTorch Dataset where the __getitem__ returns a dict: {"input": PyTorch Geometric Data object, "target": target}:
>>> import torch  # noqa: F811
>>> from torch.utils.data import Dataset
>>> from torch_geometric.data import Data  # noqa: F811
>>> from flash import Trainer
>>> from flash.graph import GraphClassificationData, GraphClassifier
>>> from flash.core.data.utilities.classification import SingleLabelTargetFormatter
>>>
>>> class CustomDataset(Dataset):
...     def __init__(self, targets=None):
...         self.targets = targets
...     def __getitem__(self, index):
...         edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]], dtype=torch.long)
...         x = torch.tensor([[-1], [0], [1]], dtype=torch.float)
...         data = Data(x=x, edge_index=edge_index)
...         if self.targets is not None:
...             return {"input": data, "target": self.targets[index]}
...         return {"input": data}
...     def __len__(self):
...         return len(self.targets) if self.targets is not None else 3
...
>>> datamodule = GraphClassificationData.from_datasets(
...     train_dataset=CustomDataset(["cat", "dog", "cat"]),
...     predict_dataset=CustomDataset(),
...     target_formatter=SingleLabelTargetFormatter(labels=["cat", "dog"]),
...     batch_size=2,
... )
>>> datamodule.num_features
1
>>> datamodule.num_classes
2
>>> datamodule.labels
['cat', 'dog']
>>> model = GraphClassifier(num_features=datamodule.num_features, num_classes=datamodule.num_classes)
>>> trainer = Trainer(fast_dev_run=True)
>>> trainer.fit(model, datamodule=datamodule)
Training...
>>> trainer.predict(model, datamodule=datamodule)
Predicting...
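Both examples use batch_size=2 over three small graphs. In PyTorch Geometric, a mini-batch of graphs is represented as one large disconnected graph. Node features are concatenated, edge indices are shifted by the node count of the preceding graphs, and a `batch` vector records which graph each node belongs to. This is the behavior of torch_geometric's Batch. This page does not confirm that Flash collates graph batches this way, so treat that link as an assumption. The plain-Python sketch below mimics the idea with hypothetical helpers (`collate_graphs`, `make_example_graph`). It is not PyTorch Geometric's implementation.

```python
from typing import Dict, List

Graph = Dict[str, list]  # {"x": node feature rows, "edge_index": [sources, targets]}


def collate_graphs(graphs: List[Graph]) -> Dict[str, list]:
    """Merge several small graphs into one disjoint graph (illustrative sketch)."""
    x: list = []
    sources: List[int] = []
    targets: List[int] = []
    batch: List[int] = []
    offset = 0  # number of nodes already placed in the merged graph
    for graph_id, g in enumerate(graphs):
        num_nodes = len(g["x"])
        x.extend(g["x"])
        src, dst = g["edge_index"]
        # Shift this graph's node ids past all previously merged nodes.
        sources.extend(s + offset for s in src)
        targets.extend(t + offset for t in dst)
        batch.extend([graph_id] * num_nodes)
        offset += num_nodes
    return {"x": x, "edge_index": [sources, targets], "batch": batch}


def make_example_graph() -> Graph:
    # The 3-node path 0 - 1 - 2 used in the examples, one feature per node.
    return {"x": [[-1.0], [0.0], [1.0]], "edge_index": [[0, 1, 1, 2], [1, 0, 2, 1]]}


dataset = [make_example_graph() for _ in range(3)]
batch_size = 2
batches = [collate_graphs(dataset[i:i + batch_size]) for i in range(0, len(dataset), batch_size)]
print(len(batches))  # 2 (the last batch holds a single graph)
print(batches[0]["edge_index"])  # [[0, 1, 1, 2, 3, 4, 4, 5], [1, 0, 2, 1, 4, 3, 5, 4]]
print(batches[0]["batch"])  # [0, 0, 0, 1, 1, 1]
```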
- property num_features
  The number of features per node in the graphs contained in this GraphClassificationData.
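In the examples, every graph has the node feature matrix x = [[-1], [0], [1]]. That is three nodes with one feature each, which matches num_features being 1. The plain-Python sketch below computes that width. It also decodes the examples' edge_index, which lists edges as parallel source and target rows (PyTorch Geometric's COO convention). The helper names are hypothetical. Whether Flash derives num_features exactly this way (for instance from PyG's Data.num_node_features) is an assumption.

```python
from typing import Dict, List, Set


def num_node_features(x: List[List[float]]) -> int:
    """Number of feature columns per node row (hypothetical helper)."""
    if not x:
        raise ValueError("graph has no nodes")
    widths = {len(row) for row in x}
    if len(widths) != 1:
        raise ValueError("every node must have the same number of features")
    return widths.pop()


def adjacency(edge_index: List[List[int]]) -> Dict[int, Set[int]]:
    """Turn a [sources, targets] COO edge list into neighbour sets."""
    neighbours: Dict[int, Set[int]] = {}
    for src, dst in zip(*edge_index):
        neighbours.setdefault(src, set()).add(dst)
    return neighbours


x = [[-1.0], [0.0], [1.0]]
edge_index = [[0, 1, 1, 2], [1, 0, 2, 1]]
print(num_node_features(x))  # 1
print(adjacency(edge_index))  # {0: {1}, 1: {0, 2}, 2: {1}}
```

Each undirected edge appears twice in edge_index, once per direction. That is why the four columns describe only the two edges 0-1 and 1-2.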