// DistributionTemplates.h

#pragma once

#include <ATen/CPUApplyUtils.h>
#include <ATen/Dispatch.h>
#include <ATen/ExpandBase.h>
#include <ATen/core/DistributionsHelper.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/cpu/Loops.h>
#include <limits>
#include <mutex>

#ifdef CPU_CAPABILITY_AVX2
#include <ATen/native/cpu/avx_mathfun.h>
#include <c10/util/irange.h>
#endif

namespace at {
namespace native {
namespace templates {
namespace cpu {
namespace {

// ==================================================== Random ========================================================

template<typename RNG>
void random_from_to_kernel(TensorIteratorBase& iter, uint64_t range, int64_t base, RNG generator) {
  AT_DISPATCH_ALL_TYPES_AND3(at::ScalarType::Bool, at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "random_from_to_kernel_cpu", [&] {
    std::lock_guard<std::mutex> lock(generator->mutex_);
    cpu_serial_kernel(iter, [range, base, generator]() -> scalar_t {
      uniform_int_from_to_distribution<scalar_t> random(range, base);
      return random(generator);
    });
  });
}
// This is a special kernel that handles the single specific case:
// from (inclusive) = std::numeric_limits<int64_t>::lowest()
// to (exclusive) = None (= std::numeric_limits<int64_t>::max() + 1)
template<typename RNG>
void random_full_64_bits_range_kernel(TensorIteratorBase& iter, RNG generator) {
  AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::BFloat16, iter.dtype(), "random_full_64_bits_range_kernel_cpu", [&] {
    std::lock_guard<std::mutex> lock(generator->mutex_);
    if (std::is_same<scalar_t, int64_t>::value ||
        std::is_same<scalar_t, double>::value ||
        std::is_same<scalar_t, float>::value ||
        std::is_same<scalar_t, at::BFloat16>::value) {
      cpu_serial_kernel(iter, [generator]() -> scalar_t {
        uniform_int_full_range_distribution<scalar_t> random;
        return random(generator);
      });
    } else {
      TORCH_CHECK(false, "random_full_64_bits_range_kernel_cpu handles only int64, double, float and bfloat16");
    }
  });
}

template<typename RNG>
struct RandomFromToKernel {
  void operator()(TensorIteratorBase& iter, uint64_t range, int64_t base, c10::optional<Generator> gen) {
    random_from_to_kernel(iter, range, base, check_generator<RNG>(gen));
  }
  void operator()(TensorIteratorBase& iter, c10::optional<Generator> gen) {
    random_full_64_bits_range_kernel(iter, check_generator<RNG>(gen));
  }
};

template<typename RNG>
void random_kernel(TensorIteratorBase& iter, RNG generator) {
  std::lock_guard<std::mutex> lock(generator->mutex_);
  AT_DISPATCH_ALL_TYPES_AND3(at::ScalarType::Half, at::ScalarType::BFloat16, at::ScalarType::Bool, iter.dtype(), "random_kernel_cpu", [&] {
    cpu_serial_kernel(iter, [generator]() -> scalar_t {
      uniform_int_distribution<scalar_t> random;
      return random(generator);
    });
  });
}

template<typename RNG>
struct RandomKernel {
  void operator()(TensorIteratorBase& iter, c10::optional<Generator> gen) {
    random_kernel(iter, check_generator<RNG>(gen));
  }
};
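
// Usage sketch (illustrative only, not part of this header): the functors above
// are normally reached through the dispatch stubs in
// ATen/native/DistributionTemplates.h. Called directly, they would look roughly
// like this, assuming the default CPU generator type at::CPUGeneratorImpl.
// `range` and `base` encode the half-open interval [base, base + range):
//
//   at::Tensor t = at::empty({8}, at::kLong);
//   auto iter = at::TensorIterator::borrowing_nullary_op(t);
//   // Draw uniformly from [10, 20), i.e. base = 10, range = 10.
//   RandomFromToKernel<at::CPUGeneratorImpl>()(
//       iter, /*range=*/10, /*base=*/10, at::detail::getDefaultCPUGenerator());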
// ==================================================== Normal ========================================================

#ifdef CPU_CAPABILITY_AVX2
static void normal_fill_16_AVX2(float *data,
                                const __m256* two_pi,
                                const __m256* one,
                                const __m256* minus_two,
                                const __m256* mean,
                                const __m256* std_v) {
  const __m256 u1 = _mm256_sub_ps(*one, _mm256_loadu_ps(data));
  const __m256 u2 = _mm256_loadu_ps(data + 8);
  // sincos256_ps and log256_ps are from avx_mathfun.h
  const __m256 radius = _mm256_sqrt_ps(_mm256_mul_ps(*minus_two, log256_ps(u1)));
  const __m256 theta = _mm256_mul_ps(*two_pi, u2);
  __m256 sintheta, costheta;
  sincos256_ps(theta, &sintheta, &costheta);
  const __m256 n1 = _mm256_mul_ps(radius, costheta);
  const __m256 n2 = _mm256_mul_ps(radius, sintheta);
  _mm256_storeu_ps(data, _mm256_fmadd_ps(n1, *std_v, *mean));
  _mm256_storeu_ps(data + 8, _mm256_fmadd_ps(n2, *std_v, *mean));
}

template<typename RNG>
void normal_fill_AVX2(const TensorBase &self, const float mean, const float std, RNG generator) {
  float *data = self.data_ptr<float>();
  auto size = self.numel();
  std::lock_guard<std::mutex> lock(generator->mutex_);
  for (const auto i : c10::irange(size)) {
    at::uniform_real_distribution<float> uniform(0, 1);
    data[i] = uniform(generator);
  }
  const __m256 two_pi = _mm256_set1_ps(2.0f * c10::pi<double>);
  const __m256 one = _mm256_set1_ps(1.0f);
  const __m256 minus_two = _mm256_set1_ps(-2.0f);
  const __m256 mean_v = _mm256_set1_ps(mean);
  const __m256 std_v = _mm256_set1_ps(std);
  for (int64_t i = 0; i < size - 15; i += 16) {
    normal_fill_16_AVX2(data + i, &two_pi, &one, &minus_two, &mean_v, &std_v);
  }
  if (size % 16 != 0) {
    // Recompute the last 16 values.
    data = data + size - 16;
    for (const auto i : c10::irange(16)) {
      at::uniform_real_distribution<float> uniform(0, 1);
      data[i] = uniform(generator);
    }
    normal_fill_16_AVX2(data, &two_pi, &one, &minus_two, &mean_v, &std_v);
  }
}
#endif

template <typename scalar_t>
static void normal_fill_16(scalar_t *data, const scalar_t mean, const scalar_t std) {
  for (const auto j : c10::irange(8)) {
    const scalar_t u1 = 1 - data[j]; // [0, 1) -> (0, 1] for log.
    const scalar_t u2 = data[j + 8];
    const scalar_t radius = std::sqrt(-2 * std::log(u1));
    const scalar_t theta = 2.0f * c10::pi<double> * u2;
    data[j] = radius * std::cos(theta) * std + mean;
    data[j + 8] = radius * std::sin(theta) * std + mean;
  }
}
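
// Note on the math: both normal_fill_16 variants apply the Box-Muller
// transform. Given independent u1, u2 ~ Uniform(0, 1),
//   r  = sqrt(-2 * ln(u1))
//   z0 = r * cos(2 * pi * u2)
//   z1 = r * sin(2 * pi * u2)
// are two independent standard normal samples; scaling by `std` and shifting
// by `mean` yields N(mean, std^2). The output buffer is first filled with
// uniforms and then transformed in place, 16 values (8 pairs) at a time.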
template <typename scalar_t, typename RNG>
void normal_fill(const TensorBase &self, const scalar_t mean, const scalar_t std, RNG generator) {
  scalar_t *data = self.data_ptr<scalar_t>();
  auto size = self.numel();
  std::lock_guard<std::mutex> lock(generator->mutex_);
  for (const auto i : c10::irange(size)) {
    at::uniform_real_distribution<scalar_t> uniform(0, 1);
    data[i] = uniform(generator);
  }
  for (int64_t i = 0; i < size - 15; i += 16) {
    normal_fill_16<scalar_t>(data + i, mean, std);
  }
  if (size % 16 != 0) {
    // Recompute the last 16 values.
    data = data + size - 16;
    for (const auto i : c10::irange(16)) {
      at::uniform_real_distribution<scalar_t> uniform(0, 1);
      data[i] = uniform(generator);
    }
    normal_fill_16<scalar_t>(data, mean, std);
  }
}

template<typename RNG>
void normal_kernel(const TensorBase &self, double mean, double std, RNG generator) {
  auto size = self.numel();
  if (self.scalar_type() == ScalarType::Float && size >= 16 && self.is_contiguous()) {
#ifdef CPU_CAPABILITY_AVX2
    normal_fill_AVX2(self, static_cast<float>(mean), static_cast<float>(std), generator);
#else
    normal_fill(self, static_cast<float>(mean), static_cast<float>(std), generator);
#endif
  } else {
    AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, self.scalar_type(), "normal_kernel_cpu", [&] {
      if (size >= 16 && self.is_contiguous()) {
        normal_fill<scalar_t>(self, static_cast<scalar_t>(mean), static_cast<scalar_t>(std), generator);
      } else {
        auto iter = TensorIterator::borrowing_nullary_op(self);
        std::lock_guard<std::mutex> lock(generator->mutex_);
        cpu_serial_kernel(iter, [mean, std, generator]() -> scalar_t {
          at::normal_distribution<double> normal(mean, std);
          return static_cast<scalar_t>(normal(generator));
        });
      }
    });
  }
}

template<typename RNG>
struct NormalKernel {
  void operator()(Tensor& self, double mean, double std, c10::optional<Generator> gen) {
    normal_kernel(self, mean, std, check_generator<RNG>(gen));
  }
};
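
// Usage sketch (illustrative only): unlike the iterator-based functors,
// NormalKernel operates on the tensor directly, so the vectorized
// normal_fill / normal_fill_AVX2 fast path can be taken for contiguous float
// tensors with at least 16 elements. Assuming the default CPU generator type
// at::CPUGeneratorImpl:
//
//   at::Tensor t = at::empty({1024}, at::kFloat);
//   NormalKernel<at::CPUGeneratorImpl>()(t, /*mean=*/0.0, /*std=*/1.0,
//                                        at::detail::getDefaultCPUGenerator());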
// ==================================================== Uniform =======================================================

template<typename RNG>
void uniform_kernel(TensorIteratorBase& iter, double from_, double to_, RNG generator) {
  AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, iter.dtype(), "uniform_kernel_cpu", [&]() {
    std::lock_guard<std::mutex> lock(generator->mutex_);
    auto from = static_cast<scalar_t>(from_);
    auto to = static_cast<scalar_t>(to_);
    at::uniform_real_distribution<scalar_t> uniform(from, to);
    cpu_serial_kernel(iter, [&uniform, generator]() -> scalar_t {
      return static_cast<scalar_t>(uniform(generator));
    });
  });
}

template<typename RNG>
struct UniformKernel {
  void operator()(TensorIteratorBase& iter, double from, double to, c10::optional<Generator> gen) {
    uniform_kernel(iter, from, to, check_generator<RNG>(gen));
  }
};
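
// Usage sketch (illustrative only): UniformKernel and the Cauchy, LogNormal,
// Geometric, and Exponential functors below all follow the same shape: build a
// nullary TensorIterator over the output, lock the generator, and fill it
// serially with cpu_serial_kernel. Assuming at::CPUGeneratorImpl as the
// generator type:
//
//   at::Tensor t = at::empty({1024}, at::kFloat);
//   auto iter = at::TensorIterator::borrowing_nullary_op(t);
//   UniformKernel<at::CPUGeneratorImpl>()(iter, /*from=*/0.0, /*to=*/1.0,
//                                         at::detail::getDefaultCPUGenerator());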
// ==================================================== Cauchy ========================================================

template<typename RNG>
void cauchy_kernel(TensorIteratorBase& iter, double median, double sigma, RNG generator) {
  AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, iter.dtype(), "cauchy_cpu", [&]() {
    std::lock_guard<std::mutex> lock(generator->mutex_);
    at::cauchy_distribution<double> cauchy(median, sigma);
    cpu_serial_kernel(iter, [&cauchy, generator]() -> scalar_t {
      return static_cast<scalar_t>(cauchy(generator));
    });
  });
}

template<typename RNG>
struct CauchyKernel {
  void operator()(TensorIteratorBase& iter, double median, double sigma, c10::optional<Generator> gen) {
    cauchy_kernel(iter, median, sigma, check_generator<RNG>(gen));
  }
};

// ================================================== LogNormal =======================================================

template<typename RNG>
void log_normal_kernel(TensorIteratorBase& iter, double mean, double std, RNG generator) {
  AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "log_normal_cpu", [&]() {
    std::lock_guard<std::mutex> lock(generator->mutex_);
    at::lognormal_distribution<double> logNormal(mean, std);
    cpu_serial_kernel(iter, [&logNormal, generator]() -> scalar_t {
      return static_cast<scalar_t>(logNormal(generator));
    });
  });
}

template<typename RNG>
struct LogNormalKernel {
  void operator()(TensorIteratorBase& iter, double mean, double std, c10::optional<Generator> gen) {
    log_normal_kernel(iter, mean, std, check_generator<RNG>(gen));
  }
};

// =================================================== Geometric ======================================================

template<typename RNG>
void geometric_kernel(TensorIteratorBase& iter, double p, RNG generator) {
  AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "geometric_cpu", [&]() {
    std::lock_guard<std::mutex> lock(generator->mutex_);
    at::geometric_distribution<double> geometric(p);
    cpu_serial_kernel(iter, [&geometric, generator]() -> scalar_t {
      return static_cast<scalar_t>(geometric(generator));
    });
  });
}

template<typename RNG>
struct GeometricKernel {
  void operator()(TensorIteratorBase& iter, double p, c10::optional<Generator> gen) {
    geometric_kernel(iter, p, check_generator<RNG>(gen));
  }
};

// ================================================== Exponential =====================================================

template<typename RNG>
void exponential_kernel(TensorIteratorBase& iter, double lambda, RNG generator) {
  TORCH_CHECK(isFloatingType(iter.dtype()), "Exponential distribution is a continuous probability distribution. dtype must be a floating point but you specified ", iter.dtype());
  AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "exponential_cpu", [&]() {
    std::lock_guard<std::mutex> lock(generator->mutex_);
    at::exponential_distribution<double> exponential(lambda);
    cpu_serial_kernel(iter, [&exponential, generator]() -> scalar_t {
      return static_cast<scalar_t>(exponential(generator));
    });
  });
}

template<typename RNG>
struct ExponentialKernel {
  void operator()(TensorIteratorBase& iter, double lambda, c10::optional<Generator> gen) {
    exponential_kernel(iter, lambda, check_generator<RNG>(gen));
  }
};
// ================================================== Bernoulli =======================================================

template<typename RNG>
void bernoulli_kernel(const TensorBase &self, const TensorBase &p_, RNG generator) {
  AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Bool, at::ScalarType::BFloat16, self.scalar_type(), "bernoulli_tensor_cpu_self_", [&] {
    // See Note [Acquire lock when using random generators]
    std::lock_guard<std::mutex> lock(generator->mutex_);
    using self_t = scalar_t;
    auto p_cpu = p_.to(kCPU);
    auto p = expand_inplace(self, p_cpu);
    auto iter = TensorIteratorConfig()
        .add_output(self)
        .add_input(*p)
        .check_all_same_dtype(false)
        .build();
    if (p->scalar_type() == kDouble) {
      cpu_serial_kernel(iter, [&](const double p_val) -> self_t {
        at::bernoulli_distribution<double> bernoulli(p_val);
        return static_cast<self_t>(bernoulli(generator));
      });
    } else {
      AT_DISPATCH_FLOATING_TYPES_AND(at::ScalarType::BFloat16, p->scalar_type(), "bernoulli_tensor_cpu_p_", [&] {
        using p_t = scalar_t;
        cpu_serial_kernel(iter, [&](const p_t p_val) -> self_t {
          at::bernoulli_distribution<float> bernoulli(p_val);
          return static_cast<self_t>(bernoulli(generator));
        });
      });
    }
  });
}

template<typename RNG>
void bernoulli_kernel(const TensorBase &self, double p, RNG generator) {
  AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Bool, at::ScalarType::BFloat16, self.scalar_type(), "bernoulli_scalar_cpu_", [&] {
    // See Note [Acquire lock when using random generators]
    std::lock_guard<std::mutex> lock(generator->mutex_);
    auto iter = TensorIterator::borrowing_nullary_op(self);
    cpu_serial_kernel(iter, [p, generator]() -> scalar_t {
      at::bernoulli_distribution<double> bernoulli(p);
      return static_cast<scalar_t>(bernoulli(generator));
    });
  });
}

template<typename RNG>
struct BernoulliKernel {
  void operator()(const TensorBase &self, double p, c10::optional<Generator> gen) {
    bernoulli_kernel(self, p, check_generator<RNG>(gen));
  }
  void operator()(const TensorBase &self, const TensorBase &p_, c10::optional<Generator> gen) {
    bernoulli_kernel(self, p_, check_generator<RNG>(gen));
  }
};
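
// Usage sketch (illustrative only): BernoulliKernel has two entry points. The
// scalar overload fills `self` with draws at a fixed probability p; the tensor
// overload reads per-element probabilities from `p_`, expanded to self's shape.
// Assuming at::CPUGeneratorImpl as the generator type:
//
//   at::Tensor out = at::empty({4, 4}, at::kBool);
//   // Fixed probability 0.3 for every element.
//   BernoulliKernel<at::CPUGeneratorImpl>()(out, 0.3, at::detail::getDefaultCPUGenerator());
//   // Per-element probabilities.
//   at::Tensor probs = at::rand({4, 4});
//   BernoulliKernel<at::CPUGeneratorImpl>()(out, probs, at::detail::getDefaultCPUGenerator());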

} // namespace
} // namespace cpu
} // namespace templates
} // namespace native
} // namespace at