GraphClassificationData
- class flash.graph.classification.data.GraphClassificationData(train_input=None, val_input=None, test_input=None, predict_input=None, data_fetcher=None, transform=<class 'flash.core.data.io.input_transform.InputTransform'>, transform_kwargs=None, val_split=None, batch_size=None, num_workers=0, sampler=None, pin_memory=True, persistent_workers=False)
The GraphClassificationData class is a DataModule with a set of classmethods for loading data for graph classification.
- classmethod from_datasets(train_dataset=None, val_dataset=None, test_dataset=None, predict_dataset=None, input_cls=<class 'flash.graph.classification.input.GraphClassificationDatasetInput'>, transform=<class 'flash.graph.classification.input_transform.GraphClassificationInputTransform'>, transform_kwargs=None, target_formatter=None, **data_module_kwargs)
Load the GraphClassificationData from PyTorch Dataset objects.
The Dataset objects should be one of the following:
- A PyTorch Dataset where the __getitem__ returns a tuple: (PyTorch Geometric Data object, target)
- A PyTorch Dataset where the __getitem__ returns a dict: {"input": PyTorch Geometric Data object, "target": target}
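The two accepted sample layouts carry the same information. The sketch below is only an illustration of that equivalence, not Flash's internal loading code: `ToyGraph` is a hypothetical stand-in for a PyTorch Geometric Data object, and `normalize_sample` is a hypothetical helper that maps either layout onto the dict form.

```python
from dataclasses import dataclass
from typing import Any, Dict, List


@dataclass
class ToyGraph:
    """Hypothetical stand-in for a PyTorch Geometric Data object."""

    x: List[List[float]]           # node feature matrix, one row per node
    edge_index: List[List[int]]    # edges as [source nodes, target nodes]


def normalize_sample(sample: Any) -> Dict[str, Any]:
    """Hypothetical helper: coerce a sample into the {"input", "target"} dict form."""
    if isinstance(sample, dict):
        # Dict form: already keyed; "target" may be absent when predicting.
        return dict(sample)
    if isinstance(sample, tuple) and len(sample) == 2:
        # Tuple form: (graph, target).
        graph, target = sample
        return {"input": graph, "target": target}
    # A bare graph with no target, as a predict dataset might return.
    return {"input": sample}


graph = ToyGraph(x=[[-1.0], [0.0], [1.0]], edge_index=[[0, 1, 1, 2], [1, 0, 2, 1]])
from_tuple = normalize_sample((graph, "cat"))
from_dict = normalize_sample({"input": graph, "target": "cat"})
without_target = normalize_sample(graph)
```

Under these assumptions, `from_tuple` and `from_dict` are equal, and `without_target` contains only the "input" key.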
To learn how to customize the transforms applied for each stage, read our customizing transforms guide.
- Parameters
  - train_dataset (Optional[Dataset]) – The Dataset to use when training.
  - val_dataset (Optional[Dataset]) – The Dataset to use when validating.
  - test_dataset (Optional[Dataset]) – The Dataset to use when testing.
  - predict_dataset (Optional[Dataset]) – The Dataset to use when predicting.
  - target_formatter (Optional[TargetFormatter]) – Optionally provide a TargetFormatter to control how targets are handled. If None then no formatting will be applied to targets.
  - input_cls (Type[Input]) – The Input type to use for loading the data.
  - transform (TypeVar(INPUT_TRANSFORM_TYPE, Type[flash.core.data.io.input_transform.InputTransform], Callable, Tuple[Union[StrEnum, str], Dict[str, Any]], Union[StrEnum, str], None)) – The InputTransform type to use.
  - transform_kwargs (Optional[Dict]) – Dict of keyword arguments to be provided when instantiating the transforms.
  - data_module_kwargs – Additional keyword arguments to provide to the DataModule constructor.
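To make the role of target_formatter more concrete, here is a minimal sketch of what single-label formatting amounts to conceptually. It assumes that each label is mapped to its position in the supplied labels list; `build_label_index` and `format_targets` are hypothetical names for illustration and are not part of the Flash API.

```python
from typing import Dict, List, Sequence


def build_label_index(labels: Sequence[str]) -> Dict[str, int]:
    """Hypothetical helper: map each label to its position in the labels list."""
    return {label: index for index, label in enumerate(labels)}


def format_targets(targets: Sequence[str], labels: Sequence[str]) -> List[int]:
    """Hypothetical helper: replace string targets with integer class indices."""
    lookup = build_label_index(labels)
    # A KeyError here means a target was not declared in `labels`.
    return [lookup[target] for target in targets]


labels = ["cat", "dog"]
formatted = format_targets(["cat", "dog", "cat"], labels)
num_classes = len(labels)  # assumed: one class per declared label
```

With the labels from the examples in this section, the string targets become the indices 0, 1, 0, and the class count is 2.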
- Return type
  GraphClassificationData
- Returns
  The constructed GraphClassificationData.
Examples
A PyTorch Dataset where the __getitem__ returns a tuple: (PyTorch Geometric Data object, target):

>>> import torch
>>> from torch.utils.data import Dataset
>>> from torch_geometric.data import Data
>>> from flash import Trainer
>>> from flash.graph import GraphClassificationData, GraphClassifier
>>> from flash.core.data.utilities.classification import SingleLabelTargetFormatter
>>>
>>> class CustomDataset(Dataset):
...     def __init__(self, targets=None):
...         self.targets = targets
...     def __getitem__(self, index):
...         edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]], dtype=torch.long)
...         x = torch.tensor([[-1], [0], [1]], dtype=torch.float)
...         data = Data(x=x, edge_index=edge_index)
...         if self.targets is not None:
...             return data, self.targets[index]
...         return data
...     def __len__(self):
...         return len(self.targets) if self.targets is not None else 3
...
>>> datamodule = GraphClassificationData.from_datasets(
...     train_dataset=CustomDataset(["cat", "dog", "cat"]),
...     predict_dataset=CustomDataset(),
...     target_formatter=SingleLabelTargetFormatter(labels=["cat", "dog"]),
...     batch_size=2,
... )
>>> datamodule.num_features
1
>>> datamodule.num_classes
2
>>> datamodule.labels
['cat', 'dog']
>>> model = GraphClassifier(num_features=datamodule.num_features, num_classes=datamodule.num_classes)
>>> trainer = Trainer(fast_dev_run=True)
>>> trainer.fit(model, datamodule=datamodule)
Training...
>>> trainer.predict(model, datamodule=datamodule)
Predicting...
A PyTorch Dataset where the __getitem__ returns a dict: {"input": PyTorch Geometric Data object, "target": target}:

>>> import torch  # noqa: F811
>>> from torch.utils.data import Dataset
>>> from torch_geometric.data import Data  # noqa: F811
>>> from flash import Trainer
>>> from flash.graph import GraphClassificationData, GraphClassifier
>>> from flash.core.data.utilities.classification import SingleLabelTargetFormatter
>>>
>>> class CustomDataset(Dataset):
...     def __init__(self, targets=None):
...         self.targets = targets
...     def __getitem__(self, index):
...         edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]], dtype=torch.long)
...         x = torch.tensor([[-1], [0], [1]], dtype=torch.float)
...         data = Data(x=x, edge_index=edge_index)
...         if self.targets is not None:
...             return {"input": data, "target": self.targets[index]}
...         return {"input": data}
...     def __len__(self):
...         return len(self.targets) if self.targets is not None else 3
...
>>> datamodule = GraphClassificationData.from_datasets(
...     train_dataset=CustomDataset(["cat", "dog", "cat"]),
...     predict_dataset=CustomDataset(),
...     target_formatter=SingleLabelTargetFormatter(labels=["cat", "dog"]),
...     batch_size=2,
... )
>>> datamodule.num_features
1
>>> datamodule.num_classes
2
>>> datamodule.labels
['cat', 'dog']
>>> model = GraphClassifier(num_features=datamodule.num_features, num_classes=datamodule.num_classes)
>>> trainer = Trainer(fast_dev_run=True)
>>> trainer.fit(model, datamodule=datamodule)
Training...
>>> trainer.predict(model, datamodule=datamodule)
Predicting...
- property num_features
  The number of features per node in the graphs contained in this GraphClassificationData.
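As a rough illustration of "features per node", the sketch below treats the node feature matrix as a plain list of rows, one row per node, and reads the feature count from the row width. The helper `features_per_node` is hypothetical and only mirrors the idea; it is not how Flash computes the property.

```python
from typing import List, Sequence


def features_per_node(x: Sequence[Sequence[float]]) -> int:
    """Hypothetical helper: width of a node feature matrix with one row per node."""
    if not x:
        raise ValueError("the graph has no nodes")
    widths = {len(row) for row in x}
    if len(widths) != 1:
        raise ValueError("every node must have the same number of features")
    return widths.pop()


# Same node features as the examples in this section: three nodes, one feature each.
x: List[List[float]] = [[-1.0], [0.0], [1.0]]
num_nodes = len(x)
num_features = features_per_node(x)
```

For this matrix the sketch yields one feature per node across three nodes, which matches the `datamodule.num_features` value of 1 shown in the examples.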