# sampler.py

import torch
from torch import Tensor

from typing import Iterator, Iterable, Optional, Sequence, List, TypeVar, Generic, Sized, Union

__all__ = [
    "BatchSampler",
    "RandomSampler",
    "Sampler",
    "SequentialSampler",
    "SubsetRandomSampler",
    "WeightedRandomSampler",
]

T_co = TypeVar('T_co', covariant=True)


class Sampler(Generic[T_co]):
    r"""Base class for all Samplers.

    Every Sampler subclass has to provide an :meth:`__iter__` method, providing a
    way to iterate over indices of dataset elements, and a :meth:`__len__` method
    that returns the length of the returned iterators.

    .. note:: The :meth:`__len__` method isn't strictly required by
              :class:`~torch.utils.data.DataLoader`, but is expected in any
              calculation involving the length of a :class:`~torch.utils.data.DataLoader`.
    """

    def __init__(self, data_source: Optional[Sized]) -> None:
        pass

    def __iter__(self) -> Iterator[T_co]:
        raise NotImplementedError
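
# A subclassing sketch: a sampler is just an iterable of indices. The
# `EvenSampler` name and its even-index policy below are illustrative, not
# part of this module.
#
# >>> class EvenSampler(Sampler[int]):
# ...     def __init__(self, data_source: Sized) -> None:
# ...         self.data_source = data_source
# ...     def __iter__(self) -> Iterator[int]:
# ...         return iter(range(0, len(self.data_source), 2))
# ...     def __len__(self) -> int:
# ...         return (len(self.data_source) + 1) // 2
# >>> list(EvenSampler(range(5)))
# [0, 2, 4]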
# NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ]
#
# Many times we have an abstract class representing a collection/iterable of
# data, e.g., `torch.utils.data.Sampler`, with its subclasses optionally
# implementing a `__len__` method. In such cases, we must make sure to not
# provide a default implementation, because both straightforward default
# implementations have their issues:
#
#   + `return NotImplemented`:
#     Calling `len(subclass_instance)` raises:
#       TypeError: 'NotImplementedType' object cannot be interpreted as an integer
#
#   + `raise NotImplementedError()`:
#     This prevents triggering some fallback behavior. E.g., the built-in
#     `list(X)` tries to call `len(X)` first, and executes a different code
#     path if the method is not found or `NotImplemented` is returned, while
#     raising a `NotImplementedError` will propagate and make the call fail
#     where it could have used `__iter__` to complete the call.
#
# Thus, the only two sensible things to do are
#
#   + **not** provide a default `__len__`.
#
#   + raise a `TypeError` instead, which is what Python uses when users call
#     a method that is not defined on an object.
#     (@ssnl verifies that this works on at least Python 3.7.)
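
# A quick illustration of the fallback described above (hypothetical class,
# not part of this module): `list()` falls back to `__iter__` when `__len__`
# is simply absent, while `len()` raises a plain TypeError.
#
# >>> class NoLen:
# ...     def __iter__(self):
# ...         return iter([1, 2, 3])
# >>> list(NoLen())
# [1, 2, 3]
# >>> len(NoLen())
# Traceback (most recent call last):
#     ...
# TypeError: object of type 'NoLen' has no len()
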
class SequentialSampler(Sampler[int]):
    r"""Samples elements sequentially, always in the same order.

    Args:
        data_source (Dataset): dataset to sample from
    """
    data_source: Sized

    def __init__(self, data_source: Sized) -> None:
        self.data_source = data_source

    def __iter__(self) -> Iterator[int]:
        return iter(range(len(self.data_source)))

    def __len__(self) -> int:
        return len(self.data_source)
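
# Usage sketch: sequential order is deterministic, so the output below is stable.
#
# >>> list(SequentialSampler(range(4)))
# [0, 1, 2, 3]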

class RandomSampler(Sampler[int]):
    r"""Samples elements randomly. If sampling without replacement, samples are
    drawn from a shuffled dataset. If sampling with replacement, the user can
    specify :attr:`num_samples` to draw.

    Args:
        data_source (Dataset): dataset to sample from
        replacement (bool): samples are drawn on-demand with replacement if ``True``, default=``False``
        num_samples (int): number of samples to draw, default=`len(dataset)`.
        generator (Generator): Generator used in sampling.
    """
    data_source: Sized
    replacement: bool

    def __init__(self, data_source: Sized, replacement: bool = False,
                 num_samples: Optional[int] = None, generator=None) -> None:
        self.data_source = data_source
        self.replacement = replacement
        self._num_samples = num_samples
        self.generator = generator

        if not isinstance(self.replacement, bool):
            raise TypeError("replacement should be a boolean value, but got "
                            "replacement={}".format(self.replacement))

        if not isinstance(self.num_samples, int) or self.num_samples <= 0:
            raise ValueError("num_samples should be a positive integer "
                             "value, but got num_samples={}".format(self.num_samples))

    @property
    def num_samples(self) -> int:
        # dataset size might change at runtime
        if self._num_samples is None:
            return len(self.data_source)
        return self._num_samples

    def __iter__(self) -> Iterator[int]:
        n = len(self.data_source)
        if self.generator is None:
            # No generator supplied: derive a fresh seed so each iteration differs.
            seed = int(torch.empty((), dtype=torch.int64).random_().item())
            generator = torch.Generator()
            generator.manual_seed(seed)
        else:
            generator = self.generator
        if self.replacement:
            # Draw indices in chunks of 32 to amortize the per-call overhead
            # of torch.randint.
            for _ in range(self.num_samples // 32):
                yield from torch.randint(high=n, size=(32,), dtype=torch.int64, generator=generator).tolist()
            yield from torch.randint(high=n, size=(self.num_samples % 32,), dtype=torch.int64, generator=generator).tolist()
        else:
            # Without replacement: emit full permutations, then a truncated one
            # to reach exactly num_samples indices.
            for _ in range(self.num_samples // n):
                yield from torch.randperm(n, generator=generator).tolist()
            yield from torch.randperm(n, generator=generator).tolist()[:self.num_samples % n]

    def __len__(self) -> int:
        return self.num_samples
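
# Usage sketch (outputs are non-deterministic, so the values shown are only
# representative):
#
# >>> list(RandomSampler(range(5)))                                    # doctest: +SKIP
# [3, 0, 4, 1, 2]   # some permutation of 0..4
# >>> list(RandomSampler(range(5), replacement=True, num_samples=8))   # doctest: +SKIP
# [2, 2, 0, 4, 1, 4, 3, 0]   # 8 draws; duplicates possible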

class SubsetRandomSampler(Sampler[int]):
    r"""Samples elements randomly from a given list of indices, without replacement.

    Args:
        indices (sequence): a sequence of indices
        generator (Generator): Generator used in sampling.
    """
    indices: Sequence[int]

    def __init__(self, indices: Sequence[int], generator=None) -> None:
        self.indices = indices
        self.generator = generator

    def __iter__(self) -> Iterator[int]:
        # `i` is a 0-d int64 tensor; it implements `__index__`, so it can
        # index directly into the `indices` sequence.
        for i in torch.randperm(len(self.indices), generator=self.generator):
            yield self.indices[i]

    def __len__(self) -> int:
        return len(self.indices)
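
# Usage sketch: the sampler shuffles a fixed subset of indices, so sorting the
# output recovers the subset deterministically.
#
# >>> sorted(SubsetRandomSampler([2, 5, 7]))
# [2, 5, 7]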

class WeightedRandomSampler(Sampler[int]):
    r"""Samples elements from ``[0,..,len(weights)-1]`` with given probabilities (weights).

    Args:
        weights (sequence): a sequence of weights, not necessarily summing up to one
        num_samples (int): number of samples to draw
        replacement (bool): if ``True``, samples are drawn with replacement.
            If not, they are drawn without replacement, which means that when a
            sample index is drawn for a row, it cannot be drawn again for that row.
        generator (Generator): Generator used in sampling.

    Example:
        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
        >>> list(WeightedRandomSampler([0.1, 0.9, 0.4, 0.7, 3.0, 0.6], 5, replacement=True))
        [4, 4, 1, 4, 5]
        >>> list(WeightedRandomSampler([0.9, 0.4, 0.05, 0.2, 0.3, 0.1], 5, replacement=False))
        [0, 1, 4, 3, 2]
    """
    weights: Tensor
    num_samples: int
    replacement: bool

    def __init__(self, weights: Sequence[float], num_samples: int,
                 replacement: bool = True, generator=None) -> None:
        if not isinstance(num_samples, int) or isinstance(num_samples, bool) or \
                num_samples <= 0:
            raise ValueError("num_samples should be a positive integer "
                             "value, but got num_samples={}".format(num_samples))
        if not isinstance(replacement, bool):
            raise ValueError("replacement should be a boolean value, but got "
                             "replacement={}".format(replacement))

        weights_tensor = torch.as_tensor(weights, dtype=torch.double)
        if len(weights_tensor.shape) != 1:
            raise ValueError("weights should be a 1d sequence but given "
                             "weights have shape {}".format(tuple(weights_tensor.shape)))

        self.weights = weights_tensor
        self.num_samples = num_samples
        self.replacement = replacement
        self.generator = generator

    def __iter__(self) -> Iterator[int]:
        rand_tensor = torch.multinomial(self.weights, self.num_samples, self.replacement, generator=self.generator)
        yield from iter(rand_tensor.tolist())

    def __len__(self) -> int:
        return self.num_samples
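
# Reproducibility sketch: a seeded torch.Generator makes the draw repeatable.
# The seed value 0 is arbitrary, and the resulting indices depend on the
# PyTorch version, so no expected output is shown.
#
# >>> g = torch.Generator().manual_seed(0)                                  # doctest: +SKIP
# >>> list(WeightedRandomSampler([0.1, 0.9], num_samples=4, generator=g))   # doctest: +SKIP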

class BatchSampler(Sampler[List[int]]):
    r"""Wraps another sampler to yield a mini-batch of indices.

    Args:
        sampler (Sampler or Iterable): Base sampler. Can be any iterable object
        batch_size (int): Size of mini-batch.
        drop_last (bool): If ``True``, the sampler will drop the last batch if
            its size would be less than ``batch_size``

    Example:
        >>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=False))
        [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
        >>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=True))
        [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
    """

    def __init__(self, sampler: Union[Sampler[int], Iterable[int]], batch_size: int, drop_last: bool) -> None:
        # Since collections.abc.Iterable does not check for `__getitem__`, which
        # is one way for an object to be an iterable, we don't do an `isinstance`
        # check here.
        if not isinstance(batch_size, int) or isinstance(batch_size, bool) or \
                batch_size <= 0:
            raise ValueError("batch_size should be a positive integer value, "
                             "but got batch_size={}".format(batch_size))
        if not isinstance(drop_last, bool):
            raise ValueError("drop_last should be a boolean value, but got "
                             "drop_last={}".format(drop_last))
        self.sampler = sampler
        self.batch_size = batch_size
        self.drop_last = drop_last

    def __iter__(self) -> Iterator[List[int]]:
        # Implemented based on the benchmarking in https://github.com/pytorch/pytorch/pull/76951
        if self.drop_last:
            sampler_iter = iter(self.sampler)
            while True:
                try:
                    # A short final batch raises StopIteration inside the list
                    # comprehension and is discarded, matching drop_last semantics.
                    batch = [next(sampler_iter) for _ in range(self.batch_size)]
                    yield batch
                except StopIteration:
                    break
        else:
            # Pre-allocate the batch list and overwrite it in place; this was
            # benchmarked as faster than appending (see the PR linked above).
            batch = [0] * self.batch_size
            idx_in_batch = 0
            for idx in self.sampler:
                batch[idx_in_batch] = idx
                idx_in_batch += 1
                if idx_in_batch == self.batch_size:
                    yield batch
                    idx_in_batch = 0
                    batch = [0] * self.batch_size
            if idx_in_batch > 0:
                yield batch[:idx_in_batch]

    def __len__(self) -> int:
        # Can only be called if self.sampler has __len__ implemented
        # We cannot enforce this condition, so we turn off typechecking for the
        # implementation below.
        # Somewhat related: see NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ]
        if self.drop_last:
            return len(self.sampler) // self.batch_size  # type: ignore[arg-type]
        else:
            return (len(self.sampler) + self.batch_size - 1) // self.batch_size  # type: ignore[arg-type]
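
# Sanity check of the __len__ arithmetic above: with 10 indices and
# batch_size=3, drop_last=False gives ceil(10 / 3) = 4 batches, while
# drop_last=True gives 10 // 3 = 3.
#
# >>> len(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=False))
# 4
# >>> len(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=True))
# 3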