Shortcuts

Source code for flash.graph.classification.data

# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict, Optional, Type

from torch.utils.data import Dataset

from flash.core.data.data_module import DataModule
from flash.core.data.io.input import Input
from flash.core.data.utilities.classification import TargetFormatter
from flash.core.utilities.imports import _GRAPH_TESTING
from flash.core.utilities.stages import RunningStage
from flash.core.utilities.types import INPUT_TRANSFORM_TYPE
from flash.graph.classification.input import GraphClassificationDatasetInput
from flash.graph.classification.input_transform import GraphClassificationInputTransform

# Skip doctests if requirements aren't available
if not _GRAPH_TESTING:
    __doctest_skip__ = ["GraphClassificationData", "GraphClassificationData.*"]


[docs]class GraphClassificationData(DataModule): """The ``GraphClassificationData`` class is a :class:`~flash.core.data.data_module.DataModule` with a set of classmethods for loading data for graph classification.""" input_transform_cls = GraphClassificationInputTransform
[docs] @classmethod def from_datasets( cls, train_dataset: Optional[Dataset] = None, val_dataset: Optional[Dataset] = None, test_dataset: Optional[Dataset] = None, predict_dataset: Optional[Dataset] = None, train_transform: INPUT_TRANSFORM_TYPE = GraphClassificationInputTransform, val_transform: INPUT_TRANSFORM_TYPE = GraphClassificationInputTransform, test_transform: INPUT_TRANSFORM_TYPE = GraphClassificationInputTransform, predict_transform: INPUT_TRANSFORM_TYPE = GraphClassificationInputTransform, target_formatter: Optional[TargetFormatter] = None, input_cls: Type[Input] = GraphClassificationDatasetInput, transform_kwargs: Optional[Dict] = None, **data_module_kwargs, ) -> "GraphClassificationData": """Load the :class:`~flash.graph.classification.data.GraphClassificationData` from PyTorch Dataset objects. The Dataset objects should be one of the following: * A PyTorch Dataset where the ``__getitem__`` returns a tuple: ``(PyTorch Geometric Data object, target)`` * A PyTorch Dataset where the ``__getitem__`` returns a dict: ``{"input": PyTorch Geometric Data object, "target": target}`` To learn how to customize the transforms applied for each stage, read our :ref:`customizing transforms guide <customizing_transforms>`. Args: train_dataset: The Dataset to use when training. val_dataset: The Dataset to use when validating. test_dataset: The Dataset to use when testing. predict_dataset: The Dataset to use when predicting. train_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when training. val_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when validating. test_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when testing. predict_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when predicting. target_formatter: Optionally provide a :class:`~flash.core.data.utilities.classification.TargetFormatter` to control how targets are handled. If ``None`` then no formatting will be applied to targets. input_cls: The :class:`~flash.core.data.io.input.Input` type to use for loading the data. transform_kwargs: Dict of keyword arguments to be provided when instantiating the transforms. data_module_kwargs: Additional keyword arguments to provide to the :class:`~flash.core.data.data_module.DataModule` constructor. Returns: The constructed :class:`~flash.graph.classification.data.GraphClassificationData`. Examples ________ A PyTorch Dataset where the ``__getitem__`` returns a tuple: ``(PyTorch Geometric Data object, target)``: .. doctest:: >>> import torch >>> from torch.utils.data import Dataset >>> from torch_geometric.data import Data >>> from flash import Trainer >>> from flash.graph import GraphClassificationData, GraphClassifier >>> from flash.core.data.utilities.classification import SingleLabelTargetFormatter >>> >>> class CustomDataset(Dataset): ... def __init__(self, targets=None): ... self.targets = targets ... def __getitem__(self, index): ... edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]], dtype=torch.long) ... x = torch.tensor([[-1], [0], [1]], dtype=torch.float) ... data = Data(x=x, edge_index=edge_index) ... if self.targets is not None: ... return data, self.targets[index] ... return data ... def __len__(self): ... return len(self.targets) if self.targets is not None else 3 ... >>> datamodule = GraphClassificationData.from_datasets( ... train_dataset=CustomDataset(["cat", "dog", "cat"]), ... predict_dataset=CustomDataset(), ... target_formatter=SingleLabelTargetFormatter(labels=["cat", "dog"]), ... batch_size=2, ... ) >>> datamodule.num_features 1 >>> datamodule.num_classes 2 >>> datamodule.labels ['cat', 'dog'] >>> model = GraphClassifier(num_features=datamodule.num_features, num_classes=datamodule.num_classes) >>> trainer = Trainer(fast_dev_run=True) >>> trainer.fit(model, datamodule=datamodule) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE Training... >>> trainer.predict(model, datamodule=datamodule) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE Predicting... A PyTorch Dataset where the ``__getitem__`` returns a dict: ``{"input": PyTorch Geometric Data object, "target": target}``: .. doctest:: >>> import torch # noqa: F811 >>> from torch.utils.data import Dataset >>> from torch_geometric.data import Data # noqa: F811 >>> from flash import Trainer >>> from flash.graph import GraphClassificationData, GraphClassifier >>> from flash.core.data.utilities.classification import SingleLabelTargetFormatter >>> >>> class CustomDataset(Dataset): ... def __init__(self, targets=None): ... self.targets = targets ... def __getitem__(self, index): ... edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]], dtype=torch.long) ... x = torch.tensor([[-1], [0], [1]], dtype=torch.float) ... data = Data(x=x, edge_index=edge_index) ... if self.targets is not None: ... return {"input": data, "target": self.targets[index]} ... return {"input": data} ... def __len__(self): ... return len(self.targets) if self.targets is not None else 3 ... >>> datamodule = GraphClassificationData.from_datasets( ... train_dataset=CustomDataset(["cat", "dog", "cat"]), ... predict_dataset=CustomDataset(), ... target_formatter=SingleLabelTargetFormatter(labels=["cat", "dog"]), ... batch_size=2, ... ) >>> datamodule.num_features 1 >>> datamodule.num_classes 2 >>> datamodule.labels ['cat', 'dog'] >>> model = GraphClassifier(num_features=datamodule.num_features, num_classes=datamodule.num_classes) >>> trainer = Trainer(fast_dev_run=True) >>> trainer.fit(model, datamodule=datamodule) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE Training... >>> trainer.predict(model, datamodule=datamodule) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE Predicting... """ ds_kw = dict( target_formatter=target_formatter, transform_kwargs=transform_kwargs, input_transforms_registry=cls.input_transforms_registry, ) train_input = input_cls(RunningStage.TRAINING, train_dataset, transform=train_transform, **ds_kw) ds_kw["target_formatter"] = getattr(train_input, "target_formatter", None) return cls( train_input, input_cls(RunningStage.VALIDATING, val_dataset, transform=val_transform, **ds_kw), input_cls(RunningStage.TESTING, test_dataset, transform=test_transform, **ds_kw), input_cls(RunningStage.PREDICTING, predict_dataset, transform=predict_transform, **ds_kw), **data_module_kwargs, )
@property def num_features(self): """The number of features per node in the graphs contained in this ``GraphClassificationData``.""" n_cls_train = getattr(self.train_dataset, "num_features", None) n_cls_val = getattr(self.val_dataset, "num_features", None) n_cls_test = getattr(self.test_dataset, "num_features", None) return n_cls_train or n_cls_val or n_cls_test

© Copyright 2020-2021, PyTorch Lightning. Revision a5e68476.

Built with Sphinx using a theme provided by Read the Docs.
Read the Docs v: 0.7.2
Versions
latest
stable
0.7.2
0.7.1
0.7.0
0.6.0
0.5.2
0.5.1
0.5.0
0.4.0
0.3.2
0.3.1
0.3.0
0.2.3
0.2.2
0.2.1
0.2.0
0.1.0post1
docs-fix_typing
Downloads
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.