Shortcuts

Source code for flash.core.data.splits

from typing import Any, List, Optional

import numpy as np
from torch.utils.data import Dataset

from flash.core.data.properties import Properties
from flash.core.utilities.stages import RunningStage


class SplitDataset(Properties, Dataset):
    """SplitDataset is used to create Dataset Subset using indices.

    Args:
        dataset: A dataset to be split.
        indices: List of indices to expose from the dataset.
        running_stage: Optional stage to attach to the split. When omitted and ``dataset`` is itself a
            :class:`Properties` instance, the wrapped dataset's ``_running_stage`` is reused.
        use_duplicated_indices: Whether to allow duplicated indices. When ``False`` (default), duplicate
            indices are removed and the remaining ones are sorted in ascending order.

    Raises:
        TypeError: If ``indices`` is not a ``list``.
        ValueError: If any index falls outside ``[0, len(dataset) - 1]``.

    Example::

        split_ds = SplitDataset(dataset, indices=[10, 14, 25])

        split_ds = SplitDataset(dataset, indices=[10, 10, 10, 14, 25], use_duplicated_indices=True)
    """

    def __init__(
        self,
        dataset: Any,
        indices: List[int],
        running_stage: Optional[RunningStage] = None,
        use_duplicated_indices: bool = False,
    ) -> None:
        # An explicit stage wins; otherwise inherit the stage of a Properties-backed dataset.
        kwargs = {}
        if running_stage is not None:
            kwargs = {"running_stage": running_stage}
        elif isinstance(dataset, Properties):
            kwargs = {"running_stage": dataset._running_stage}
        super().__init__(**kwargs)

        if not isinstance(indices, list):
            raise TypeError("indices should be a list")

        if use_duplicated_indices:
            indices = list(indices)
        else:
            # ``np.unique`` de-duplicates and sorts; ``tolist`` converts the numpy scalars back to builtin
            # ``int`` so the wrapped dataset's ``__getitem__`` receives plain Python integers.
            indices = np.unique(indices).tolist()

        # An empty split is valid (length 0); min/max are undefined for it, so only bound-check non-empty input.
        if indices and (max(indices) >= len(dataset) or min(indices) < 0):
            raise ValueError(f"`indices` should be within [0, {len(dataset) -1}].")

        self.dataset = dataset
        self.indices = indices

    def __getattr__(self, key: str):
        """Forward attribute lookups that fail on the split to the wrapped dataset."""
        # ``__getattr__`` only runs when normal lookup fails. "dataset" is excluded so that an instance whose
        # ``dataset`` attribute was never assigned raises AttributeError instead of recursing forever.
        if key != "dataset":
            return getattr(self.dataset, key)
        raise AttributeError(key)

    def __getitem__(self, index: int) -> Any:
        """Return the wrapped dataset's sample at position ``index`` of this split."""
        return self.dataset[self.indices[index]]

    def __len__(self) -> int:
        """Return the number of indices exposed by this split."""
        return len(self.indices)

© Copyright 2020-2021, PyTorch Lightning. Revision a9cedb5a.

Built with Sphinx using a theme provided by Read the Docs.