combinatorics.py 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125
  1. import random
  2. import torch
  3. from torch.utils.data.datapipes.datapipe import IterDataPipe, MapDataPipe
  4. from typing import Iterator, List, Optional, TypeVar
  5. __all__ = ["ShufflerIterDataPipe", ]
  6. T_co = TypeVar('T_co', covariant=True)
  7. # @functional_datapipe('shuffle')
  8. class ShufflerIterDataPipe(IterDataPipe[T_co]):
  9. r"""
  10. Shuffle the input MapDataPipe via its indices (functional name: ``shuffle``).
  11. When it is used with :class:`~torch.utils.data.DataLoader`, the methods to
  12. set up random seed are different based on :attr:`num_workers`.
  13. For single-process mode (:attr:`num_workers == 0`), the random seed is set before
  14. the :class:`~torch.utils.data.DataLoader` in the main process. For multi-process
  15. mode (:attr:`num_worker > 0`), ``worker_init_fn`` is used to set up a random seed
  16. for each worker process.
  17. Args:
  18. datapipe: MapDataPipe being shuffled
  19. indices: a list of indices of the MapDataPipe. If not provided, we assume it uses 0-based indexing
  20. Example:
  21. >>> # xdoctest: +SKIP
  22. >>> from torchdata.datapipes.map import SequenceWrapper
  23. >>> dp = SequenceWrapper(range(10))
  24. >>> shuffle_dp = dp.shuffle().set_seed(0)
  25. >>> list(shuffle_dp)
  26. [7, 8, 1, 5, 3, 4, 2, 0, 9, 6]
  27. >>> list(shuffle_dp)
  28. [6, 1, 9, 5, 2, 4, 7, 3, 8, 0]
  29. >>> # Reset seed for Shuffler
  30. >>> shuffle_dp = shuffle_dp.set_seed(0)
  31. >>> list(shuffle_dp)
  32. [7, 8, 1, 5, 3, 4, 2, 0, 9, 6]
  33. Note:
  34. Even thought this ``shuffle`` operation takes a ``MapDataPipe`` as the input, it would return an
  35. ``IterDataPipe`` rather than a ``MapDataPipe``, because ``MapDataPipe`` should be non-sensitive to
  36. the order of data order for the sake of random reads, but ``IterDataPipe`` depends on the order
  37. of data during data-processing.
  38. """
  39. datapipe: MapDataPipe[T_co]
  40. _enabled: bool
  41. _seed: Optional[int]
  42. _rng: random.Random
  43. def __init__(self,
  44. datapipe: MapDataPipe[T_co],
  45. *,
  46. indices: Optional[List] = None,
  47. ) -> None:
  48. super().__init__()
  49. self.datapipe = datapipe
  50. self.indices = list(range(len(datapipe))) if indices is None else indices
  51. self._enabled = True
  52. self._seed = None
  53. self._rng = random.Random()
  54. self._shuffled_indices: List = self.indices
  55. def set_shuffle(self, shuffle=True):
  56. self._enabled = shuffle
  57. return self
  58. def set_seed(self, seed: int):
  59. self._seed = seed
  60. return self
  61. def __iter__(self) -> Iterator[T_co]:
  62. if not self._enabled:
  63. for idx in self.indices:
  64. yield self.datapipe[idx]
  65. else:
  66. while self._shuffled_indices:
  67. idx = self._shuffled_indices.pop()
  68. yield self.datapipe[idx]
  69. def reset(self) -> None:
  70. if self._enabled and self._seed is None:
  71. self._seed = int(torch.empty((), dtype=torch.int64).random_().item())
  72. self._rng.seed(self._seed)
  73. self._seed = None
  74. self._shuffled_indices = self._rng.sample(self.indices, len(self.indices))
  75. def __len__(self) -> int:
  76. return len(self.datapipe)
  77. def __getstate__(self):
  78. state = (
  79. self.datapipe,
  80. self.indices,
  81. self._enabled,
  82. self._seed,
  83. self._rng.getstate(),
  84. self._shuffled_indices,
  85. self._valid_iterator_id,
  86. self._number_of_samples_yielded,
  87. )
  88. if IterDataPipe.getstate_hook is not None:
  89. return IterDataPipe.getstate_hook(state)
  90. return state
  91. def __setstate__(self, state):
  92. (
  93. self.datapipe,
  94. self.indices,
  95. self._enabled,
  96. self._seed,
  97. rng_state,
  98. self._shuffled_indices,
  99. self._valid_iterator_id,
  100. self._number_of_samples_yielded,
  101. ) = state
  102. self._rng = random.Random()
  103. self._rng.setstate(rng_state)
  104. MapDataPipe.register_datapipe_as_function("shuffle", ShufflerIterDataPipe)