Shortcuts

Source code for flash.core.data.splits

from typing import Any, List, Optional

import numpy as np
from torch.utils.data import Dataset

from flash.core.data.properties import Properties
from flash.core.utilities.stages import RunningStage


class SplitDataset(Properties, Dataset):
    """SplitDataset is used to create Dataset Subset using indices.

    Args:
        dataset: A dataset to be split
        indices: List of indices to expose from the dataset
        running_stage: Running stage forwarded to ``Properties``. When ``None`` and
            ``dataset`` is itself a ``Properties`` instance, the stage of ``dataset`` is reused.
        use_duplicated_indices: Whether to allow duplicated indices.

    Raises:
        TypeError: If ``indices`` is not a ``list``.
        ValueError: If any index falls outside ``[0, len(dataset) - 1]``.

    Example::

        split_ds = SplitDataset(dataset, indices=[10, 14, 25])

        split_ds = SplitDataset(dataset, indices=[10, 10, 10, 14, 25], use_duplicated_indices=True)
    """

    def __init__(
        self,
        dataset: Any,
        indices: List[int],
        running_stage: Optional[RunningStage] = None,
        use_duplicated_indices: bool = False,
    ) -> None:
        # An explicit running stage wins; otherwise inherit it from the wrapped dataset
        # when that dataset carries one.
        kwargs = {}
        if running_stage is not None:
            kwargs = {"running_stage": running_stage}
        elif isinstance(dataset, Properties):
            kwargs = {"running_stage": dataset._running_stage}
        super().__init__(**kwargs)

        if not isinstance(indices, list):
            raise TypeError("indices should be a list")

        if use_duplicated_indices:
            # Copy so later mutation of the caller's list does not affect this split.
            indices = list(indices)
        else:
            # ``np.unique`` sorts and de-duplicates; ``tolist`` converts the NumPy scalars
            # back to native Python values so the stored indices match ``List[int]``.
            indices = np.unique(indices).tolist()

        # ``np.max`` / ``np.min`` raise on an empty sequence, so only range-check
        # non-empty selections; an empty split is simply a dataset of length 0.
        if indices and (np.max(indices) >= len(dataset) or np.min(indices) < 0):
            raise ValueError(f"`indices` should be within [0, {len(dataset) - 1}].")

        self.dataset = dataset
        self.indices = indices

    def __getattr__(self, key: str):
        """Delegate lookups of attributes not found on this object to the wrapped dataset."""
        # ``__getattr__`` only runs after normal lookup failed. If ``dataset`` itself is
        # missing (e.g. before ``__init__`` assigned it), delegating would recurse forever,
        # so that single name fails immediately.
        if key != "dataset":
            return getattr(self.dataset, key)
        raise AttributeError(key)

    def __getitem__(self, index: int) -> Any:
        """Return the wrapped dataset's item at the ``index``-th exposed position."""
        return self.dataset[self.indices[index]]

    def __len__(self) -> int:
        """Return the number of exposed indices."""
        return len(self.indices)

© Copyright 2020-2021, PyTorch Lightning. Revision a9cedb5a.

Built with Sphinx using a theme provided by Read the Docs.
Read the Docs v: stable
Versions
latest
stable
0.8.2
0.8.1.post0
0.8.1
0.8.0
0.7.5
0.7.4
0.7.3
0.7.2
0.7.1
0.7.0
0.6.0
0.5.2
0.5.1
0.5.0
0.4.0
0.3.2
0.3.1
0.3.0
0.2.3
0.2.2
0.2.1
0.2.0
0.1.0post1
Downloads
html
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.