# combinatorics.py
  1. import random
  2. import torch
  3. from torch.utils.data import Sampler, SequentialSampler
  4. from torch.utils.data.datapipes._decorator import functional_datapipe
  5. from torch.utils.data.datapipes.datapipe import IterDataPipe
  6. from typing import Dict, Iterator, List, Optional, Sized, Tuple, Type, TypeVar
  7. __all__ = [
  8. "SamplerIterDataPipe",
  9. "ShufflerIterDataPipe",
  10. ]
  11. T_co = TypeVar('T_co', covariant=True)
  12. class SamplerIterDataPipe(IterDataPipe[T_co]):
  13. r"""
  14. Generates sample elements using the provided ``Sampler`` (defaults to :class:`SequentialSampler`).
  15. Args:
  16. datapipe: IterDataPipe to sample from
  17. sampler: Sampler class to generate sample elements from input DataPipe.
  18. Default is :class:`SequentialSampler` for IterDataPipe
  19. """
  20. datapipe: IterDataPipe
  21. sampler: Sampler
  22. def __init__(self,
  23. datapipe: IterDataPipe,
  24. sampler: Type[Sampler] = SequentialSampler,
  25. sampler_args: Optional[Tuple] = None,
  26. sampler_kwargs: Optional[Dict] = None
  27. ) -> None:
  28. assert isinstance(datapipe, Sized), \
  29. "Sampler class requires input datapipe implemented `__len__`"
  30. super().__init__()
  31. self.datapipe = datapipe
  32. self.sampler_args = () if sampler_args is None else sampler_args
  33. self.sampler_kwargs = {} if sampler_kwargs is None else sampler_kwargs
  34. # https://github.com/python/mypy/pull/9629 will solve
  35. self.sampler = sampler(data_source=self.datapipe, *self.sampler_args, **self.sampler_kwargs) # type: ignore[misc]
  36. def __iter__(self) -> Iterator[T_co]:
  37. return iter(self.sampler)
  38. def __len__(self) -> int:
  39. # Dataset has been tested as `Sized`
  40. if isinstance(self.sampler, Sized):
  41. return len(self.sampler)
  42. raise TypeError("{} instance doesn't have valid length".format(type(self).__name__))
  43. @functional_datapipe('shuffle')
  44. class ShufflerIterDataPipe(IterDataPipe[T_co]):
  45. r"""
  46. Shuffles the input DataPipe with a buffer (functional name: ``shuffle``). The buffer
  47. with ``buffer_size`` is filled with elements from the datapipe first. Then,
  48. each item will be yielded from the buffer by reservoir sampling via iterator.
  49. ``buffer_size`` is required to be larger than ``0``. For ``buffer_size == 1``, the
  50. datapipe is not shuffled. In order to fully shuffle all elements from datapipe,
  51. ``buffer_size`` is required to be greater than or equal to the size of datapipe.
  52. When it is used with :class:`torch.utils.data.DataLoader`, the methods to
  53. set up random seed are different based on :attr:`num_workers`.
  54. For single-process mode (:attr:`num_workers == 0`), the random seed is set before
  55. the :class:`~torch.utils.data.DataLoader` in the main process. For multi-process
  56. mode (:attr:`num_worker > 0`), `worker_init_fn` is used to set up a random seed
  57. for each worker process.
  58. Args:
  59. datapipe: The IterDataPipe being shuffled
  60. buffer_size: The buffer size for shuffling (default to ``10000``)
  61. unbatch_level: Specifies if it is necessary to unbatch source data before
  62. applying the shuffle
  63. Example:
  64. >>> # xdoctest: +SKIP
  65. >>> from torchdata.datapipes.iter import IterableWrapper
  66. >>> dp = IterableWrapper(range(10))
  67. >>> shuffle_dp = dp.shuffle()
  68. >>> list(shuffle_dp)
  69. [0, 4, 1, 6, 3, 2, 9, 5, 7, 8]
  70. """
  71. datapipe: IterDataPipe[T_co]
  72. buffer_size: int
  73. _buffer: List[T_co]
  74. _enabled: bool
  75. _seed: Optional[int]
  76. _rng: random.Random
  77. def __init__(self,
  78. datapipe: IterDataPipe[T_co],
  79. *,
  80. buffer_size: int = 10000,
  81. unbatch_level: int = 0
  82. ) -> None:
  83. super().__init__()
  84. # TODO: Performance optimization
  85. # buffer can be a fixed size and remove expensive `append()` and `len()` operations
  86. self._buffer: List[T_co] = []
  87. assert buffer_size > 0, "buffer_size should be larger than 0"
  88. if unbatch_level == 0:
  89. self.datapipe = datapipe
  90. else:
  91. self.datapipe = datapipe.unbatch(unbatch_level=unbatch_level)
  92. self.buffer_size = buffer_size
  93. self._enabled = True
  94. self._seed = None
  95. self._rng = random.Random()
  96. def set_shuffle(self, shuffle=True):
  97. self._enabled = shuffle
  98. return self
  99. def set_seed(self, seed: int):
  100. self._seed = seed
  101. return self
  102. def __iter__(self) -> Iterator[T_co]:
  103. if not self._enabled:
  104. for x in self.datapipe:
  105. yield x
  106. else:
  107. for x in self.datapipe:
  108. if len(self._buffer) == self.buffer_size:
  109. idx = self._rng.randint(0, len(self._buffer) - 1)
  110. val, self._buffer[idx] = self._buffer[idx], x
  111. yield val
  112. else:
  113. self._buffer.append(x)
  114. while self._buffer:
  115. idx = self._rng.randint(0, len(self._buffer) - 1)
  116. yield self._buffer.pop(idx)
  117. def __len__(self) -> int:
  118. if isinstance(self.datapipe, Sized):
  119. return len(self.datapipe)
  120. raise TypeError("{} instance doesn't have valid length".format(type(self).__name__))
  121. def reset(self) -> None:
  122. self._buffer = []
  123. if self._enabled:
  124. if self._seed is None:
  125. self._seed = int(torch.empty((), dtype=torch.int64).random_().item())
  126. self._rng.seed(self._seed)
  127. self._seed = None
  128. def __getstate__(self):
  129. state = (
  130. self.datapipe,
  131. self.buffer_size,
  132. self._enabled,
  133. self._seed,
  134. self._buffer,
  135. self._rng.getstate(),
  136. self._valid_iterator_id,
  137. self._number_of_samples_yielded,
  138. )
  139. if IterDataPipe.getstate_hook is not None:
  140. return IterDataPipe.getstate_hook(state)
  141. return state
  142. def __setstate__(self, state):
  143. (
  144. self.datapipe,
  145. self.buffer_size,
  146. self._enabled,
  147. self._seed,
  148. self._buffer,
  149. rng_state,
  150. self._valid_iterator_id,
  151. self._number_of_samples_yielded,
  152. ) = state
  153. self._rng = random.Random()
  154. self._rng.setstate(rng_state)
  155. def __del__(self):
  156. self._buffer.clear()