quasirandom.py 7.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180
  1. import torch
  2. from typing import Optional
  3. class SobolEngine:
  4. r"""
  5. The :class:`torch.quasirandom.SobolEngine` is an engine for generating
  6. (scrambled) Sobol sequences. Sobol sequences are an example of low
  7. discrepancy quasi-random sequences.
  8. This implementation of an engine for Sobol sequences is capable of
  9. sampling sequences up to a maximum dimension of 21201. It uses direction
  10. numbers from https://web.maths.unsw.edu.au/~fkuo/sobol/ obtained using the
  11. search criterion D(6) up to the dimension 21201. This is the recommended
  12. choice by the authors.
  13. References:
  14. - Art B. Owen. Scrambling Sobol and Niederreiter-Xing points.
  15. Journal of Complexity, 14(4):466-489, December 1998.
  16. - I. M. Sobol. The distribution of points in a cube and the accurate
  17. evaluation of integrals.
  18. Zh. Vychisl. Mat. i Mat. Phys., 7:784-802, 1967.
  19. Args:
  20. dimension (Int): The dimensionality of the sequence to be drawn
  21. scramble (bool, optional): Setting this to ``True`` will produce
  22. scrambled Sobol sequences. Scrambling is
  23. capable of producing better Sobol
  24. sequences. Default: ``False``.
  25. seed (Int, optional): This is the seed for the scrambling. The seed
  26. of the random number generator is set to this,
  27. if specified. Otherwise, it uses a random seed.
  28. Default: ``None``
  29. Examples::
  30. >>> # xdoctest: +SKIP("unseeded random state")
  31. >>> soboleng = torch.quasirandom.SobolEngine(dimension=5)
  32. >>> soboleng.draw(3)
  33. tensor([[0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
  34. [0.5000, 0.5000, 0.5000, 0.5000, 0.5000],
  35. [0.7500, 0.2500, 0.2500, 0.2500, 0.7500]])
  36. """
  37. MAXBIT = 30
  38. MAXDIM = 21201
  39. def __init__(self, dimension, scramble=False, seed=None):
  40. if dimension > self.MAXDIM or dimension < 1:
  41. raise ValueError("Supported range of dimensionality "
  42. f"for SobolEngine is [1, {self.MAXDIM}]")
  43. self.seed = seed
  44. self.scramble = scramble
  45. self.dimension = dimension
  46. cpu = torch.device("cpu")
  47. self.sobolstate = torch.zeros(dimension, self.MAXBIT, device=cpu, dtype=torch.long)
  48. torch._sobol_engine_initialize_state_(self.sobolstate, self.dimension)
  49. if not self.scramble:
  50. self.shift = torch.zeros(self.dimension, device=cpu, dtype=torch.long)
  51. else:
  52. self._scramble()
  53. self.quasi = self.shift.clone(memory_format=torch.contiguous_format)
  54. self._first_point = (self.quasi / 2 ** self.MAXBIT).reshape(1, -1)
  55. self.num_generated = 0
  56. def draw(self, n: int = 1, out: Optional[torch.Tensor] = None,
  57. dtype: torch.dtype = torch.float32) -> torch.Tensor:
  58. r"""
  59. Function to draw a sequence of :attr:`n` points from a Sobol sequence.
  60. Note that the samples are dependent on the previous samples. The size
  61. of the result is :math:`(n, dimension)`.
  62. Args:
  63. n (Int, optional): The length of sequence of points to draw.
  64. Default: 1
  65. out (Tensor, optional): The output tensor
  66. dtype (:class:`torch.dtype`, optional): the desired data type of the
  67. returned tensor.
  68. Default: ``torch.float32``
  69. """
  70. if self.num_generated == 0:
  71. if n == 1:
  72. result = self._first_point.to(dtype)
  73. else:
  74. result, self.quasi = torch._sobol_engine_draw(
  75. self.quasi, n - 1, self.sobolstate, self.dimension, self.num_generated, dtype=dtype,
  76. )
  77. result = torch.cat((self._first_point, result), dim=-2)
  78. else:
  79. result, self.quasi = torch._sobol_engine_draw(
  80. self.quasi, n, self.sobolstate, self.dimension, self.num_generated - 1, dtype=dtype,
  81. )
  82. self.num_generated += n
  83. if out is not None:
  84. out.resize_as_(result).copy_(result)
  85. return out
  86. return result
  87. def draw_base2(self, m: int, out: Optional[torch.Tensor] = None,
  88. dtype: torch.dtype = torch.float32) -> torch.Tensor:
  89. r"""
  90. Function to draw a sequence of :attr:`2**m` points from a Sobol sequence.
  91. Note that the samples are dependent on the previous samples. The size
  92. of the result is :math:`(2**m, dimension)`.
  93. Args:
  94. m (Int): The (base2) exponent of the number of points to draw.
  95. out (Tensor, optional): The output tensor
  96. dtype (:class:`torch.dtype`, optional): the desired data type of the
  97. returned tensor.
  98. Default: ``torch.float32``
  99. """
  100. n = 2 ** m
  101. total_n = self.num_generated + n
  102. if not (total_n & (total_n - 1) == 0):
  103. raise ValueError("The balance properties of Sobol' points require "
  104. "n to be a power of 2. {0} points have been "
  105. "previously generated, then: n={0}+2**{1}={2}. "
  106. "If you still want to do this, please use "
  107. "'SobolEngine.draw()' instead."
  108. .format(self.num_generated, m, total_n))
  109. return self.draw(n=n, out=out, dtype=dtype)
  110. def reset(self):
  111. r"""
  112. Function to reset the ``SobolEngine`` to base state.
  113. """
  114. self.quasi.copy_(self.shift)
  115. self.num_generated = 0
  116. return self
  117. def fast_forward(self, n):
  118. r"""
  119. Function to fast-forward the state of the ``SobolEngine`` by
  120. :attr:`n` steps. This is equivalent to drawing :attr:`n` samples
  121. without using the samples.
  122. Args:
  123. n (Int): The number of steps to fast-forward by.
  124. """
  125. if self.num_generated == 0:
  126. torch._sobol_engine_ff_(self.quasi, n - 1, self.sobolstate, self.dimension, self.num_generated)
  127. else:
  128. torch._sobol_engine_ff_(self.quasi, n, self.sobolstate, self.dimension, self.num_generated - 1)
  129. self.num_generated += n
  130. return self
  131. def _scramble(self):
  132. g: Optional[torch.Generator] = None
  133. if self.seed is not None:
  134. g = torch.Generator()
  135. g.manual_seed(self.seed)
  136. cpu = torch.device("cpu")
  137. # Generate shift vector
  138. shift_ints = torch.randint(2, (self.dimension, self.MAXBIT), device=cpu, generator=g)
  139. self.shift = torch.mv(shift_ints, torch.pow(2, torch.arange(0, self.MAXBIT, device=cpu)))
  140. # Generate lower triangular matrices (stacked across dimensions)
  141. ltm_dims = (self.dimension, self.MAXBIT, self.MAXBIT)
  142. ltm = torch.randint(2, ltm_dims, device=cpu, generator=g).tril()
  143. torch._sobol_engine_scramble_(self.sobolstate, ltm, self.dimension)
  144. def __repr__(self):
  145. fmt_string = [f'dimension={self.dimension}']
  146. if self.scramble:
  147. fmt_string += ['scramble=True']
  148. if self.seed is not None:
  149. fmt_string += [f'seed={self.seed}']
  150. return self.__class__.__name__ + '(' + ', '.join(fmt_string) + ')'