import random

import torch
from torch.utils.data import Sampler, SequentialSampler
from torch.utils.data.datapipes._decorator import functional_datapipe
from torch.utils.data.datapipes.datapipe import IterDataPipe
from typing import Dict, Iterator, List, Optional, Sized, Tuple, Type, TypeVar

__all__ = [
    "SamplerIterDataPipe",
    "ShufflerIterDataPipe",
]

T_co = TypeVar('T_co', covariant=True)


class SamplerIterDataPipe(IterDataPipe[T_co]):
    r"""
    Generates sample elements using the provided ``Sampler`` (defaults to :class:`SequentialSampler`).

    Args:
        datapipe: IterDataPipe to sample from
        sampler: Sampler class to generate sample elements from input DataPipe.
            Default is :class:`SequentialSampler` for IterDataPipe
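
    Example:
        >>> # xdoctest: +SKIP
        >>> # Illustrative sketch only; ``IterableWrapper`` comes from torchdata,
        >>> # as in the ``shuffle`` example below.
        >>> from torchdata.datapipes.iter import IterableWrapper
        >>> dp = IterableWrapper([1, 3, 5, 7, 9])
        >>> # With the default SequentialSampler, the sampler (and therefore this
        >>> # datapipe) yields the indices 0 .. len(dp) - 1 in order.
        >>> list(SamplerIterDataPipe(dp))
        [0, 1, 2, 3, 4]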
    """
    datapipe: IterDataPipe
    sampler: Sampler

    def __init__(self,
                 datapipe: IterDataPipe,
                 sampler: Type[Sampler] = SequentialSampler,
                 sampler_args: Optional[Tuple] = None,
                 sampler_kwargs: Optional[Dict] = None
                 ) -> None:
        assert isinstance(datapipe, Sized), \
            "Sampler class requires an input datapipe that implements `__len__`"
        super().__init__()
        self.datapipe = datapipe
        self.sampler_args = () if sampler_args is None else sampler_args
        self.sampler_kwargs = {} if sampler_kwargs is None else sampler_kwargs
        # https://github.com/python/mypy/pull/9629 will solve
        self.sampler = sampler(data_source=self.datapipe, *self.sampler_args, **self.sampler_kwargs)  # type: ignore[misc]

    def __iter__(self) -> Iterator[T_co]:
        return iter(self.sampler)

    def __len__(self) -> int:
        # Dataset has been tested as `Sized`
        if isinstance(self.sampler, Sized):
            return len(self.sampler)
        raise TypeError("{} instance doesn't have valid length".format(type(self).__name__))


@functional_datapipe('shuffle')
class ShufflerIterDataPipe(IterDataPipe[T_co]):
    r"""
    Shuffles the input DataPipe with a buffer (functional name: ``shuffle``).

    The buffer with ``buffer_size`` is filled with elements from the datapipe first. Then,
    each item will be yielded from the buffer by reservoir sampling via iterator.

    ``buffer_size`` is required to be larger than ``0``. For ``buffer_size == 1``, the
    datapipe is not shuffled. In order to fully shuffle all elements from the datapipe,
    ``buffer_size`` is required to be greater than or equal to the size of the datapipe.

    When it is used with :class:`torch.utils.data.DataLoader`, the methods to
    set up random seed are different based on :attr:`num_workers`.

    For single-process mode (:attr:`num_workers == 0`), the random seed is set before
    the :class:`~torch.utils.data.DataLoader` in the main process. For multi-process
    mode (:attr:`num_workers > 0`), ``worker_init_fn`` is used to set up a random seed
    for each worker process.

    Args:
        datapipe: The IterDataPipe being shuffled
        buffer_size: The buffer size for shuffling (default: ``10000``)
        unbatch_level: Specifies if it is necessary to unbatch source data before
            applying the shuffle

    Example:
        >>> # xdoctest: +SKIP
        >>> from torchdata.datapipes.iter import IterableWrapper
        >>> dp = IterableWrapper(range(10))
        >>> shuffle_dp = dp.shuffle()
        >>> list(shuffle_dp)
        [0, 4, 1, 6, 3, 2, 9, 5, 7, 8]
    """
    datapipe: IterDataPipe[T_co]
    buffer_size: int
    _buffer: List[T_co]
    _enabled: bool
    _seed: Optional[int]
    _rng: random.Random

    def __init__(self,
                 datapipe: IterDataPipe[T_co],
                 *,
                 buffer_size: int = 10000,
                 unbatch_level: int = 0
                 ) -> None:
        super().__init__()
        # TODO: Performance optimization
        #   buffer can be a fixed size and remove expensive `append()` and `len()` operations
        self._buffer: List[T_co] = []
        assert buffer_size > 0, "buffer_size should be larger than 0"
        if unbatch_level == 0:
            self.datapipe = datapipe
        else:
            self.datapipe = datapipe.unbatch(unbatch_level=unbatch_level)
        self.buffer_size = buffer_size
        self._enabled = True
        self._seed = None
        self._rng = random.Random()

    def set_shuffle(self, shuffle=True):
        self._enabled = shuffle
        return self

    def set_seed(self, seed: int):
        self._seed = seed
        return self

    def __iter__(self) -> Iterator[T_co]:
        if not self._enabled:
            for x in self.datapipe:
                yield x
        else:
            for x in self.datapipe:
                if len(self._buffer) == self.buffer_size:
                    # Buffer is full: swap the incoming element with a randomly
                    # chosen buffered element and yield the evicted one.
                    idx = self._rng.randint(0, len(self._buffer) - 1)
                    val, self._buffer[idx] = self._buffer[idx], x
                    yield val
                else:
                    self._buffer.append(x)
            # Source exhausted: drain the remaining buffered elements in random order.
            while self._buffer:
                idx = self._rng.randint(0, len(self._buffer) - 1)
                yield self._buffer.pop(idx)

    def __len__(self) -> int:
        if isinstance(self.datapipe, Sized):
            return len(self.datapipe)
        raise TypeError("{} instance doesn't have valid length".format(type(self).__name__))

    def reset(self) -> None:
        self._buffer = []
        if self._enabled:
            if self._seed is None:
                # No seed was provided via `set_seed`; draw one from torch's default RNG.
                self._seed = int(torch.empty((), dtype=torch.int64).random_().item())
            self._rng.seed(self._seed)
            self._seed = None

    def __getstate__(self):
        state = (
            self.datapipe,
            self.buffer_size,
            self._enabled,
            self._seed,
            self._buffer,
            self._rng.getstate(),
            self._valid_iterator_id,
            self._number_of_samples_yielded,
        )
        if IterDataPipe.getstate_hook is not None:
            return IterDataPipe.getstate_hook(state)
        return state

    def __setstate__(self, state):
        (
            self.datapipe,
            self.buffer_size,
            self._enabled,
            self._seed,
            self._buffer,
            rng_state,
            self._valid_iterator_id,
            self._number_of_samples_yielded,
        ) = state
        self._rng = random.Random()
        self._rng.setstate(rng_state)

    def __del__(self):
        self._buffer.clear()