123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125 |
- import random
- import torch
- from torch.utils.data.datapipes.datapipe import IterDataPipe, MapDataPipe
- from typing import Iterator, List, Optional, TypeVar
__all__ = ["ShufflerIterDataPipe"]

# Covariant type variable for the elements yielded by the DataPipe.
T_co = TypeVar("T_co", covariant=True)


# @functional_datapipe('shuffle')
# The decorator above is left disabled on purpose: this class consumes a
# MapDataPipe and is registered as MapDataPipe's ``shuffle`` functional form
# by the explicit ``register_datapipe_as_function`` call after the class.
class ShufflerIterDataPipe(IterDataPipe[T_co]):
    r"""
    Shuffle the input MapDataPipe via its indices (functional name: ``shuffle``).

    When it is used with :class:`~torch.utils.data.DataLoader`, the methods to
    set up random seed are different based on :attr:`num_workers`.

    For single-process mode (:attr:`num_workers == 0`), the random seed is set before
    the :class:`~torch.utils.data.DataLoader` in the main process. For multi-process
    mode (:attr:`num_workers > 0`), ``worker_init_fn`` is used to set up a random seed
    for each worker process.

    Args:
        datapipe: MapDataPipe being shuffled
        indices: a list of indices of the MapDataPipe. If not provided, we assume it uses
            0-based indexing. The indices are copied, so the caller's list is never
            modified. When shuffling is disabled, items are yielded in this order.

    Example:
        >>> # xdoctest: +SKIP
        >>> from torchdata.datapipes.map import SequenceWrapper
        >>> dp = SequenceWrapper(range(10))
        >>> shuffle_dp = dp.shuffle().set_seed(0)
        >>> list(shuffle_dp)
        [7, 8, 1, 5, 3, 4, 2, 0, 9, 6]
        >>> list(shuffle_dp)
        [6, 1, 9, 5, 2, 4, 7, 3, 8, 0]
        >>> # Reset seed for Shuffler
        >>> shuffle_dp = shuffle_dp.set_seed(0)
        >>> list(shuffle_dp)
        [7, 8, 1, 5, 3, 4, 2, 0, 9, 6]

    Note:
        Even though this ``shuffle`` operation takes a ``MapDataPipe`` as the input, it returns an
        ``IterDataPipe`` rather than a ``MapDataPipe``, because a ``MapDataPipe`` should be insensitive
        to the order of its data for the sake of random reads, whereas an ``IterDataPipe`` depends on
        the order of data during data-processing.
    """

    datapipe: MapDataPipe[T_co]
    _enabled: bool
    _seed: Optional[int]
    _rng: random.Random

    def __init__(
        self,
        datapipe: MapDataPipe[T_co],
        *,
        indices: Optional[List] = None,
    ) -> None:
        super().__init__()
        self.datapipe = datapipe
        # Always own a private copy: ``__iter__`` consumes index lists with
        # ``pop()``, which must never reach the caller's object. ``list()`` also
        # accepts any iterable of indices (e.g. a ``range``).
        self.indices = list(range(len(datapipe))) if indices is None else list(indices)
        self._enabled = True
        self._seed = None
        self._rng = random.Random()
        # Indices still to be yielded in the current pass. This must be a list
        # distinct from ``self.indices`` because ``__iter__`` pops from it.
        # ``reset`` replaces it with a freshly shuffled list.
        self._shuffled_indices: List = list(self.indices)

    def set_shuffle(self, shuffle=True):
        """Enable or disable shuffling. Returns ``self`` so calls can be chained."""
        self._enabled = shuffle
        return self

    def set_seed(self, seed: int):
        """Set the seed used by the next :meth:`reset`. Returns ``self`` for chaining.

        ``reset`` clears the seed after using it, so it applies to one reset only.
        """
        self._seed = seed
        return self

    def __iter__(self) -> Iterator[T_co]:
        """Yield ``datapipe[idx]`` for each prepared index.

        With shuffling disabled, indices come from ``self.indices`` in order.
        Otherwise they are popped from the list prepared by :meth:`reset`.
        """
        if not self._enabled:
            for idx in self.indices:
                yield self.datapipe[idx]
        else:
            # Popping leaves only the not-yet-yielded indices in
            # ``_shuffled_indices``, which is what ``__getstate__`` serializes.
            while self._shuffled_indices:
                idx = self._shuffled_indices.pop()
                yield self.datapipe[idx]

    def reset(self) -> None:
        """Seed the private RNG and prepare a new random ordering of ``self.indices``.

        If shuffling is enabled and no seed was set via :meth:`set_seed`, a seed is
        drawn from torch's default random generator. The seed is cleared afterwards.
        """
        if self._enabled and self._seed is None:
            self._seed = int(torch.empty((), dtype=torch.int64).random_().item())
        self._rng.seed(self._seed)
        self._seed = None
        # ``sample`` returns a new list, leaving ``self.indices`` untouched.
        self._shuffled_indices = self._rng.sample(self.indices, len(self.indices))

    def __len__(self) -> int:
        """Return the number of items one full iteration yields, i.e. ``len(self.indices)``."""
        # ``__iter__`` yields exactly one item per index, so the length must come from
        # ``indices``. With default indices this equals ``len(self.datapipe)``.
        return len(self.indices)

    def __getstate__(self):
        """Capture the state for pickling, including the RNG state and the remaining indices."""
        state = (
            self.datapipe,
            self.indices,
            self._enabled,
            self._seed,
            self._rng.getstate(),
            self._shuffled_indices,
            self._valid_iterator_id,
            self._number_of_samples_yielded,
        )
        if IterDataPipe.getstate_hook is not None:
            return IterDataPipe.getstate_hook(state)
        return state

    def __setstate__(self, state):
        """Restore the state produced by :meth:`__getstate__`, rebuilding the RNG."""
        (
            self.datapipe,
            self.indices,
            self._enabled,
            self._seed,
            rng_state,
            self._shuffled_indices,
            self._valid_iterator_id,
            self._number_of_samples_yielded,
        ) = state
        self._rng = random.Random()
        self._rng.setstate(rng_state)
# Expose ShufflerIterDataPipe as the ``shuffle`` functional form of MapDataPipe,
# which is what lets a map-style DataPipe call ``.shuffle()`` as shown in the
# class docstring example.
MapDataPipe.register_datapipe_as_function(
    "shuffle",
    ShufflerIterDataPipe,
)
|