Shortcuts

Source code for flash.core.data.splits

from typing import Any, List

import numpy as np
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from torch.utils.data import Dataset

from flash.core.data.properties import Properties


class SplitDataset(Properties, Dataset):
    """Expose a subset of a dataset through a list of indices.

    Attribute lookups that are not resolved on the ``SplitDataset`` itself are
    forwarded to the wrapped dataset.

    Args:
        dataset: A dataset to be split.
        indices: List of indices to expose from the dataset. ``None`` (the
            default) is treated as an empty list and yields an empty split.
        use_duplicated_indices: Whether to allow duplicated indices. When
            ``False``, the indices are de-duplicated (and sorted) with
            ``np.unique``.

    Raises:
        MisconfigurationException: If ``indices`` is not a list, or if any
            index falls outside ``[0, len(dataset) - 1]``.

    Example::

        split_ds = SplitDataset(dataset, indices=[10, 14, 25])

        split_ds = SplitDataset(dataset, indices=[10, 10, 10, 14, 25], use_duplicated_indices=True)
    """

    def __init__(self, dataset: Any, indices: List[int] = None, use_duplicated_indices: bool = False) -> None:
        kwargs = {}
        if isinstance(dataset, Properties):
            # Carry the running stage of the wrapped dataset over to the split.
            kwargs = dict(
                running_stage=dataset._running_stage,
            )
        super().__init__(**kwargs)

        if indices is None:
            indices = []
        if not isinstance(indices, list):
            raise MisconfigurationException("indices should be a list")

        if use_duplicated_indices:
            # Copy so later mutation of the caller's list does not affect the split.
            indices = list(indices)
        else:
            indices = list(np.unique(indices))

        # np.max / np.min raise ValueError on an empty sequence, so only
        # range-check when there is at least one index; an empty split is valid.
        if indices and (np.max(indices) >= len(dataset) or np.min(indices) < 0):
            raise MisconfigurationException(f"`indices` should be within [0, {len(dataset) - 1}].")

        self.dataset = dataset
        self.indices = indices

    def __getattr__(self, key: str):
        """Forward attributes missing on the split to the wrapped dataset."""
        # ``__getattr__`` only runs when normal lookup fails. If ``dataset``
        # itself is missing, forwarding would look it up again and recurse,
        # so fail immediately for that key.
        if key != "dataset":
            return getattr(self.dataset, key)
        raise AttributeError(key)

    def __getitem__(self, index: int) -> Any:
        """Return the item of the wrapped dataset at the ``index``-th exposed index."""
        return self.dataset[self.indices[index]]

    def __len__(self) -> int:
        """Return the number of exposed indices."""
        return len(self.indices)

© Copyright 2020-2021, PyTorch Lightning. Revision 8e9123c7.

Built with Sphinx using a theme provided by Read the Docs.
Read the Docs v: 0.7.1
Versions
latest
stable
0.7.1
0.7.0
0.6.0
0.5.2
0.5.1
0.5.0
0.4.0
0.3.2
0.3.1
0.3.0
0.2.3
0.2.2
0.2.1
0.2.0
0.1.0post1
docs-fix_typing
Downloads
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.