Distributions.h 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518
  1. #pragma once
  2. #include <ATen/native/Math.h>
  3. #include <c10/macros/Macros.h>
  4. #include <c10/util/MathConstants.h>
  5. // ROCM hcc doesn't work well with using std:: in kernel functions
  6. #if defined(__CUDA_ARCH__)
  7. #include <c10/cuda/CUDAMathCompat.h>
  8. #define compat_exp c10::cuda::compat::exp
  9. #define compat_ceil c10::cuda::compat::ceil
  10. #define compat_floor c10::cuda::compat::floor
  11. #define compat_log c10::cuda::compat::log
  12. #define compat_pow c10::cuda::compat::pow
  13. #define compat_sqrt c10::cuda::compat::sqrt
  14. #define compat_tan c10::cuda::compat::tan
  15. #define compat_abs c10::cuda::compat::abs
  16. #define compat_log1p c10::cuda::compat::log1p
  17. #elif defined(__HIPCC__)
  18. #include <c10/hip/HIPMathCompat.h>
  19. #define compat_exp c10::hip::compat::exp
  20. #define compat_ceil c10::hip::compat::ceil
  21. #define compat_floor c10::hip::compat::floor
  22. #define compat_log c10::hip::compat::log
  23. #define compat_pow c10::hip::compat::pow
  24. #define compat_sqrt c10::hip::compat::sqrt
  25. #define compat_tan c10::hip::compat::tan
  26. #define compat_abs c10::hip::compat::abs
  27. #define compat_log1p c10::hip::compat::log1p
  28. #else
  29. #define compat_exp std::exp
  30. #define compat_ceil std::ceil
  31. #define compat_floor std::floor
  32. #define compat_log std::log
  33. #define compat_pow std::pow
  34. #define compat_sqrt std::sqrt
  35. #define compat_tan std::tan
  36. #define compat_abs std::abs
  37. #define compat_log1p std::log1p
  38. #endif
  39. namespace {
  40. #if !defined(__CUDA_ARCH__) && !defined(__HIPCC__)
  41. // we cannot use std::isnan directly due to some incompatibility of
  42. // gcc constexpr'ing and nvcc
  43. using std::isnan;
  44. #endif
  // Thin wrapper that stores a callable and exposes it through sample().
  //
  // sampler_t is expected to be callable with no arguments and to yield a
  // value convertible to scalar_t. On GPU the callable is a device function;
  // because ROCm has no equivalent of nvstd::function, the callable's type is
  // captured as a template parameter rather than being type-erased.
  template <typename scalar_t, typename sampler_t>
  struct BaseSampler {
    // The wrapped callable (held by value).
    sampler_t sampler;

    // Copies `fn` into the wrapper.
    C10_DEVICE BaseSampler(const sampler_t& fn) : sampler(fn) {}

    // Invokes the wrapped callable once and returns its result as scalar_t.
    C10_DEVICE scalar_t sample() {
      return sampler();
    }
  };
  57. // The function `sample_gamma`
  58. // is adapted from Numpy's distributions.c implementation.
  59. // It is MIT licensed, so here is the copyright:
  60. /* Copyright 2005 Robert Kern (robert.kern@gmail.com)
  61. *
  62. * Permission is hereby granted, free of charge, to any person obtaining a
  63. * copy of this software and associated documentation files (the
  64. * "Software"), to deal in the Software without restriction, including
  65. * without limitation the rights to use, copy, modify, merge, publish,
  66. * distribute, sublicense, and/or sell copies of the Software, and to
  67. * permit persons to whom the Software is furnished to do so, subject to
  68. * the following conditions:
  69. *
  70. * The above copyright notice and this permission notice shall be included
  71. * in all copies or substantial portions of the Software.
  72. *
  73. * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
  74. * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
  75. * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
  76. * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
  77. * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
  78. * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
  79. * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
  80. */
  81. template<typename scalar_t, typename accscalar_t, typename uniform_sampler_t, typename normal_sampler_t>
  82. C10_DEVICE scalar_t sample_gamma(scalar_t alpha, BaseSampler<accscalar_t, uniform_sampler_t>& standard_uniform, BaseSampler<accscalar_t, normal_sampler_t>& standard_normal) {
  83. accscalar_t scale = 1.0f;
  84. // Boost alpha for higher acceptance probability.
  85. if (alpha < 1.0f) {
  86. if (alpha == 0.f) return 0.f;
  87. scale *= compat_pow(1 - standard_uniform.sample(), 1.0f / alpha);
  88. alpha += 1.0f;
  89. }
  90. // This implements the acceptance-rejection method of Marsaglia and Tsang (2000)
  91. // doi:10.1145/358407.358414
  92. const accscalar_t d = alpha - 1.0f / 3.0f;
  93. const accscalar_t c = 1.0f / compat_sqrt(9.0f * d);
  94. for (;;) {
  95. accscalar_t x, y;
  96. do {
  97. x = standard_normal.sample();
  98. y = 1.0f + c * x;
  99. } while (y <= 0);
  100. const accscalar_t v = y * y * y;
  101. const accscalar_t u = 1 - standard_uniform.sample();
  102. const accscalar_t xx = x * x;
  103. if (u < 1.0f - 0.0331f * xx * xx)
  104. return static_cast<scalar_t>(scale * d * v);
  105. if (compat_log(u) < 0.5f * xx + d * (1.0f - v + compat_log(v)))
  106. return static_cast<scalar_t>(scale * d * v);
  107. }
  108. }
  109. /* the functions stirling_approx_tail, binomial_inversion, and btrs are adapted
  110. * from TensorFlow's random_binomial_op.cc implementation. That code is under
  111. * copyright: 2019 The TensorFlow Authors.
  112. *
  113. * It was released under the Apache License, Version 2.0 (the "License"), available at:
  114. * http://www.apache.org/licenses/LICENSE-2.0
  115. */
  // Tail correction of the Stirling approximation, as a function of k.
  //
  // For k <= 9 the value is read from a ten-entry table indexed by k
  // truncated to size_t (callers are expected to pass k >= 0). For any other
  // k a short series in 1 / (k + 1) is evaluated.
  template <typename scalar_t>
  C10_DEVICE scalar_t stirling_approx_tail(scalar_t k) {
    static const scalar_t kTable[] = {
        0.0810614667953272,
        0.0413406959554092,
        0.0276779256849983,
        0.02079067210376509,
        0.0166446911898211,
        0.0138761288230707,
        0.0118967099458917,
        0.0104112652619720,
        0.00925546218271273,
        0.00833056343336287};

    if (k <= 9) {
      const size_t slot = static_cast<size_t>(k);
      return kTable[slot];
    }

    const auto kp1 = k + 1;
    const scalar_t kp1sq = kp1 * kp1;
    return (1.0 / 12 - (1.0 / 360 - 1.0 / 1260 / kp1sq) / kp1sq) / kp1;
  }
  136. template<typename scalar_t, typename accscalar_t, typename uniform_sampler_t>
  137. C10_DEVICE scalar_t binomial_inversion(scalar_t count, scalar_t prob, BaseSampler<accscalar_t, uniform_sampler_t>& standard_uniform) {
  138. accscalar_t U;
  139. accscalar_t geom_sum = 0;
  140. scalar_t num_geom = 0;
  141. accscalar_t logprob = compat_log1p(-prob);
  142. while (1) {
  143. U = standard_uniform.sample();
  144. accscalar_t geom = compat_ceil(compat_log(U) / logprob);
  145. geom_sum += geom;
  146. if (geom_sum > count) {
  147. break;
  148. }
  149. num_geom = num_geom + 1;
  150. }
  151. return num_geom;
  152. }
  153. template<typename scalar_t, typename accscalar_t, typename uniform_sampler_t>
  154. C10_DEVICE scalar_t btrs(scalar_t count, scalar_t prob, BaseSampler<accscalar_t, uniform_sampler_t>& standard_uniform) {
  155. scalar_t k;
  156. accscalar_t U, V, us;
  157. // This is spq in the paper.
  158. const accscalar_t stddev = compat_sqrt(count * prob * (1 - prob));
  159. // Other coefficients for Transformed Rejection sampling.
  160. const accscalar_t b = 1.15 + 2.53 * stddev;
  161. const accscalar_t a = -0.0873 + 0.0248 * b + 0.01 * prob;
  162. const accscalar_t c = count * prob + 0.5;
  163. const accscalar_t v_r = 0.92 - 4.2 / b;
  164. const accscalar_t r = prob / (1 - prob);
  165. const accscalar_t alpha = (2.83 + 5.1 / b) * stddev;
  166. const accscalar_t m = compat_floor((count + 1) * prob);
  167. while (1) {
  168. U = standard_uniform.sample() - 0.5;
  169. V = standard_uniform.sample();
  170. us = 0.5 - compat_abs(U);
  171. k = static_cast<scalar_t>(compat_floor((2 * a / us + b) * U + c));
  172. // Reject non-sensical answers.
  173. if (k < 0 || k > count) {
  174. continue;
  175. }
  176. // Region for which the box is tight, and we can return our calculated value.
  177. // This should happen 0.86 * v_r times. In the limit as n * p is large,
  178. // the acceptance rate converges to ~79% (and in the lower regime it is ~24%).
  179. if (us >= 0.07 && V <= v_r) {
  180. return k;
  181. }
  182. // This deviates from Hormann's BTRS algorithm, as there is a log missing.
  183. // For all (u, v) pairs outside of the bounding box, this calculates the
  184. // transformed-reject ratio.
  185. V = compat_log(V * alpha / (a / (us * us) + b));
  186. accscalar_t upperbound =
  187. ((m + 0.5) * compat_log((m + 1) / (r * (count - m + 1))) +
  188. (count + 1) * compat_log((count - m + 1) / (count - k + 1)) +
  189. (k + 0.5) * compat_log(r * (count - k + 1) / (k + 1)) +
  190. stirling_approx_tail<accscalar_t>(m) + stirling_approx_tail<accscalar_t>(count - m) -
  191. stirling_approx_tail<accscalar_t>(k) - stirling_approx_tail<accscalar_t>(count - k));
  192. if (V <= upperbound) {
  193. return k;
  194. }
  195. }
  196. }
  197. template<typename scalar_t, typename accscalar_t, typename uniform_sampler_t>
  198. C10_DEVICE scalar_t sample_binomial(scalar_t count, scalar_t prob, BaseSampler<accscalar_t, uniform_sampler_t>& standard_uniform) {
  199. if (count <= 0.0 || prob <= 0.0) {
  200. return 0;
  201. } else if (prob >= 1.0) {
  202. return count;
  203. } else if (prob <= 0.5) {
  204. if (count * prob >= 10.0) {
  205. // btrs
  206. return btrs<scalar_t, accscalar_t, uniform_sampler_t>(count, prob, standard_uniform);
  207. } else {
  208. // binomial inversion
  209. return binomial_inversion<scalar_t, accscalar_t, uniform_sampler_t>(count, prob, standard_uniform);
  210. }
  211. } else if (prob > 0.5) {
  212. scalar_t qprob = 1.0 - prob;
  213. if (count * qprob >= 10.0) {
  214. // btrs
  215. return count - btrs<scalar_t, accscalar_t, uniform_sampler_t>(count, qprob, standard_uniform);
  216. } else {
  217. // count - binomial inversion
  218. return count - binomial_inversion<scalar_t, accscalar_t, uniform_sampler_t>(count, qprob, standard_uniform);
  219. }
  220. } else {
  221. // prob is nan?
  222. return static_cast<scalar_t>(NAN);
  223. }
  224. }
  225. /*
  226. * This function is derived from the implementation of the digamma function in the Cephes Math Library.
  227. * See note [3-Clause BSD License for the Cephes Math Library] in ATen/native/Math.h.
  228. */
  229. template<typename scalar_t, typename accscalar_t>
  230. C10_DEVICE static inline scalar_t digamma_one(scalar_t x) {
  231. constexpr accscalar_t PSI_10 = 2.25175258906672110764;
  232. if (x == 0) {
  233. return INFINITY;
  234. }
  235. accscalar_t additional_summand = 0;
  236. int x_is_integer = x == compat_floor(x);
  237. if (x < 0) {
  238. if (x_is_integer) {
  239. return INFINITY;
  240. }
  241. // it is more standard to write this as recursion, but
  242. // nvcc does not like that
  243. additional_summand = -c10::pi<scalar_t> /
  244. compat_tan(c10::pi<scalar_t> * x);
  245. x = 1 - x;
  246. }
  247. // Push x to be >= 10
  248. accscalar_t result = 0;
  249. while (x < 10) {
  250. result -= 1 / x;
  251. x += 1;
  252. }
  253. if (x == 10) {
  254. return result + PSI_10 + additional_summand;
  255. }
  256. // Compute asymptotic digamma
  257. static const accscalar_t A[] = {
  258. 8.33333333333333333333E-2,
  259. -2.10927960927960927961E-2,
  260. 7.57575757575757575758E-3,
  261. -4.16666666666666666667E-3,
  262. 3.96825396825396825397E-3,
  263. -8.33333333333333333333E-3,
  264. 8.33333333333333333333E-2,
  265. };
  266. accscalar_t y = 0;
  267. if (x < 1.0e17f) {
  268. accscalar_t z = 1.0 / (x * x);
  269. y = z * polevl<accscalar_t>(z, A, 6);
  270. }
  271. return static_cast<scalar_t>(
  272. result + compat_log(x) - (0.5f / x) - y + additional_summand);
  273. }
  274. // Computes the reparameterized gradient -(d/dalpha cdf(x;alpha)) / pdf(x;alpha)
  275. // for random number x drawn from a standard Gamma distribution Gamma(alpha).
  276. template <typename scalar_t, typename accscalar_t>
  277. C10_HOST_DEVICE scalar_t standard_gamma_grad_one(scalar_t alpha_, scalar_t x_) {
  278. // Use a Taylor series expansion for small x.
  279. accscalar_t x = static_cast<accscalar_t>(x_);
  280. accscalar_t alpha = static_cast<accscalar_t>(alpha_);
  281. if (x < 0.8f) {
  282. accscalar_t numer = 1;
  283. accscalar_t denom = alpha;
  284. auto series1 = numer / denom;
  285. auto series2 = numer / (denom * denom);
  286. for (int i = 1; i <= 5; ++i) {
  287. numer *= -x / static_cast<accscalar_t>(i);
  288. denom += 1;
  289. series1 += numer / denom;
  290. series2 += numer / (denom * denom);
  291. }
  292. const auto pow_x_alpha = compat_pow(x, alpha);
  293. const auto gamma_pdf = compat_pow(x, alpha - 1) * compat_exp(-x);
  294. const auto gamma_cdf = pow_x_alpha * series1;
  295. const auto gamma_cdf_alpha =
  296. (compat_log(x) - digamma_one<accscalar_t, accscalar_t>(alpha)) *
  297. gamma_cdf -
  298. pow_x_alpha * series2;
  299. const auto result = -gamma_cdf_alpha / gamma_pdf;
  300. return isnan(result) ? static_cast<scalar_t>( 0.f ) : static_cast<scalar_t>(result);
  301. }
  302. // Use a Rice saddle point expansion for large alpha.
  303. if (alpha > 8.0f) {
  304. if (0.9f * alpha <= x && x <= 1.1f * alpha) {
  305. const auto numer_1 = 1 + 24 * alpha * (1 + 12 * alpha);
  306. const auto numer_2 = 1440 * (alpha * alpha) + 6 * x * (53 - 120 * x)
  307. - 65 * x * x / alpha + alpha * (107 + 3600 * x);
  308. const auto denom = 1244160 * (alpha * alpha) * (alpha * alpha);
  309. return static_cast<scalar_t>(numer_1 * numer_2 / denom);
  310. }
  311. const auto denom = compat_sqrt(8 * alpha);
  312. const auto term2 = denom / (alpha - x);
  313. const auto term3 = compat_pow(
  314. x - alpha - alpha * compat_log(x / alpha),
  315. static_cast<accscalar_t>(-1.5));
  316. const auto term23 = (x < alpha) ? term2 - term3 : term2 + term3;
  317. const auto term1 = compat_log(x / alpha) * term23 -
  318. compat_sqrt(2 / alpha) * (alpha + x) / ((alpha - x) * (alpha - x));
  319. const auto stirling = 1 + 1 / (12 * alpha) * (1 + 1 / (24 * alpha));
  320. const auto numer = x * term1;
  321. return static_cast<scalar_t>(-stirling * numer / denom);
  322. }
  323. // Use a bivariate rational approximation to the reparameterized gradient.
  324. const auto u = compat_log(x / alpha);
  325. const auto v = compat_log(alpha);
  326. static const accscalar_t coef_uv[3][8] = {
  327. {0.16009398, -0.094634809, 0.025146376, -0.0030648343,
  328. 1, 0.32668115, 0.10406089, 0.0014179084},
  329. {0.53487893, 0.1298071, 0.065735949, -0.0015649758,
  330. 0.16639465, 0.020070113, -0.0035938915, -0.00058392623},
  331. {0.040121004, -0.0065914022, -0.0026286047, -0.0013441777,
  332. 0.017050642, -0.0021309326, 0.00085092367, -1.5247877e-07},
  333. };
  334. accscalar_t coef_v[8];
  335. for (int i = 0; i < 8; ++ i) {
  336. coef_v[i] = coef_uv[0][i] + u * (coef_uv[1][i] + u * coef_uv[2][i]);
  337. }
  338. const auto p = coef_v[0] + v * (coef_v[1] + v * (coef_v[2] + v * coef_v[3]));
  339. const auto q = coef_v[4] + v * (coef_v[5] + v * (coef_v[6] + v * coef_v[7]));
  340. return static_cast<scalar_t>(compat_exp(p / q));
  341. }
  342. // Approximate reparameterized gradient of Beta(x,alpha,beta) wrt alpha.
  343. // Assumes x is close to zero and uses a Taylor expansion.
  344. template <typename scalar_t, typename accscalar_t>
  345. C10_DEVICE static inline scalar_t _beta_grad_alpha_small(scalar_t x, scalar_t alpha, scalar_t beta) {
  346. const scalar_t factor = digamma_one<scalar_t, accscalar_t>(alpha)
  347. - digamma_one<scalar_t, accscalar_t>(alpha + beta) - compat_log(x);
  348. scalar_t numer = 1;
  349. scalar_t series = numer / alpha * (factor + 1 / alpha);
  350. for (int i = 1; i <= 10; ++i) {
  351. scalar_t casted_i = static_cast<scalar_t>(i);
  352. numer *= (casted_i - beta) * x / casted_i;
  353. const scalar_t denom = alpha + casted_i;
  354. series += numer / denom * (factor + 1 / denom);
  355. }
  356. const scalar_t result = x * compat_pow(1 - x, -beta) * series;
  357. return isnan(result) ? static_cast<scalar_t>( 0.f ) : result;
  358. }
  359. // Approximate reparameterized gradient of Beta(x,alpha,beta) wrt beta.
  360. // Assumes x is close to zero and uses a Taylor expansion.
  361. template <typename scalar_t, typename accscalar_t>
  362. C10_DEVICE static inline scalar_t _beta_grad_beta_small(scalar_t x, scalar_t alpha, scalar_t beta) {
  363. const scalar_t factor = digamma_one<scalar_t, accscalar_t>(alpha + beta) - digamma_one<scalar_t, accscalar_t>(beta);
  364. scalar_t numer = 1, betas = 1, dbetas = 0, series = factor / alpha;
  365. for (int i = 1; i <= 8; ++i) {
  366. scalar_t casted_i = static_cast<scalar_t>(i);
  367. numer *= -x / casted_i;
  368. dbetas = dbetas * (beta - casted_i) + betas;
  369. betas = betas * (beta - casted_i);
  370. series += numer / (alpha + casted_i) * (dbetas + factor * betas);
  371. }
  372. const scalar_t result = -compat_pow(1 - x, 1 - beta) * series;
  373. return isnan(result) ? static_cast<scalar_t>( 0.f ) : result;
  374. }
  // Approximate reparameterized gradient of Beta(x,alpha,beta) wrt alpha.
  // Assumes alpha and beta are both large and uses a Rice saddle point expansion.
  // To ensure numerical stability, the inputs and all intermediate values are
  // accscalar_t; only the returned value is scalar_t.
  template <typename scalar_t, typename accscalar_t>
  C10_DEVICE static inline scalar_t _beta_grad_alpha_mid(accscalar_t x, accscalar_t alpha, accscalar_t beta) {
    const accscalar_t total = alpha + beta;
    const accscalar_t mean = alpha / total;
    const accscalar_t stddev = compat_sqrt(alpha * beta / (total + 1)) / total;

    const bool near_mean = mean - 0.1 * stddev <= x && x <= mean + 0.1 * stddev;
    if (near_mean) {
      // Avoid the singularity at x = mean.
      const accscalar_t poly = 47 * x * (beta * beta) * (beta * beta) + alpha * (
          (43 + 20 * (16 + 27 * beta) * x) * (beta * beta) * beta + alpha * (
          3 * (59 + 180 * beta - 90 * x) * (beta * beta) + alpha * (
          (453 + 1620 * beta * (1 - x) - 455 * x) * beta + alpha * (
          8 * (1 - x) * (135 * beta - 11)))));
      const accscalar_t prefactor_num = (1 + 12 * alpha) * (1 + 12 * beta) / (total * total);
      const accscalar_t prefactor_den = 12960 * alpha * alpha * alpha * beta * beta * (1 + 12 * total);
      return static_cast<scalar_t>(prefactor_num / (1 - x) * poly / prefactor_den);
    }

    const accscalar_t prefactor = -x / compat_sqrt(2 * alpha * beta / total);

    // Ratio of second-order correction factors for alpha, beta and total.
    const accscalar_t stirling = (1 + 1 / (12 * alpha) + 1 / (288 * alpha * alpha))
        * (1 + 1 / (12 * beta) + 1 / (288 * beta * beta))
        / (1 + 1 / (12 * total) + 1 / (288 * total * total));

    const accscalar_t term1_num = 2 * (alpha * alpha) * (x - 1) + alpha * beta * (x - 1) - x * (beta * beta);
    const accscalar_t axbx = alpha * (x - 1) + beta * x;
    const accscalar_t term1_den = compat_sqrt(2 * alpha / beta) * compat_pow(total, static_cast<accscalar_t>(1.5f)) * axbx * axbx;
    const accscalar_t term1 = term1_num / term1_den;

    const accscalar_t term2 = 0.5f * compat_log(alpha / (total * x));

    const accscalar_t term3_num = compat_sqrt(8 * alpha * beta / total);
    const accscalar_t term3_den = beta * x + alpha * (x - 1);
    const accscalar_t term3 = term3_num / term3_den;

    const accscalar_t term4_base = beta * compat_log(beta / (total * (1 - x))) +
        alpha * compat_log(alpha / (total * x));
    const accscalar_t term4 = compat_pow(term4_base, static_cast<accscalar_t>(-1.5f));

    // The sign of term4 depends on which side of the mean x lies.
    const accscalar_t signed_term4 = x < mean ? term4 : -term4;
    const accscalar_t term1234 = term1 + term2 * (term3 + signed_term4);
    return static_cast<scalar_t>(stirling * prefactor * term1234);
  }
  412. // Computes a scaled reparameterized gradient
  413. // -(d/dalpha cdf(x;alpha,beta)) / pdf(x;alpha,beta) / (1-x)
  414. // for random number x drawn from a Beta distribution Beta(alpha,beta).
  415. // This function inputs total=alpha+beta to make it easy to implement
  416. // Dirichlet reparameterized gradients in terms of Betas.
  417. template<typename scalar_t, typename accscalar_t>
  418. C10_HOST_DEVICE static inline scalar_t dirichlet_grad_one(scalar_t x, scalar_t alpha, scalar_t total) {
  419. accscalar_t x_ = static_cast<accscalar_t>(x);
  420. accscalar_t alpha_ = static_cast<accscalar_t>(alpha);
  421. accscalar_t total_ = static_cast<accscalar_t>(total);
  422. const scalar_t beta = total - alpha;
  423. const accscalar_t beta_ = total_ - alpha_;
  424. const scalar_t boundary = total * x * (1 - x);
  425. // Use an asymptotic approximation for x close to 0.
  426. if (x <= 0.5f && boundary < 2.5f) {
  427. return _beta_grad_alpha_small<scalar_t, accscalar_t>(x, alpha, beta);
  428. }
  429. // Use an asymptotic approximation for x close to 1.
  430. if (x >= 0.5f && boundary < 0.75f) {
  431. return -_beta_grad_beta_small<scalar_t, accscalar_t>(1 - x, beta, alpha);
  432. }
  433. // Use an asymptotic approximation when alpha and (total - alpha) are both large.
  434. if (alpha > 6 && beta > 6) {
  435. return _beta_grad_alpha_mid<scalar_t, accscalar_t>(x_, alpha_, beta_);
  436. }
  437. // Use a rational correction to an analytic approximation.
  438. static const accscalar_t c[2][3][3][4] = {
  439. {{{1.003668233, -0.01061107488, -0.0657888334, 0.01201642863},
  440. {0.6336835991, -0.3557432599, 0.05486251648, -0.001465281033},
  441. {-0.03276231906, 0.004474107445, 0.002429354597, -0.0001557569013}},
  442. {{0.221950385, -0.3187676331, 0.01799915743, 0.01074823814},
  443. {-0.2951249643, 0.06219954479, 0.01535556598, 0.001550077057},
  444. {0.02155310298, 0.004170831599, 0.001292462449, 6.976601077e-05}},
  445. {{-0.05980841433, 0.008441916499, 0.01085618172, 0.002319392565},
  446. {0.02911413504, 0.01400243777, -0.002721828457, 0.000751041181},
  447. {0.005900514878, -0.001936558688, -9.495446725e-06, 5.385558597e-05}}},
  448. {{{1, -0.02924021934, -0.04438342661, 0.007285809825},
  449. {0.6357567472, -0.3473456711, 0.05454656494, -0.002407477521},
  450. {-0.03301322327, 0.004845219414, 0.00231480583, -0.0002307248149}},
  451. {{0.5925320577, -0.1757678135, 0.01505928619, 0.000564515273},
  452. {0.1014815858, -0.06589186703, 0.01272886114, -0.0007316646956},
  453. {-0.007258481865, 0.001096195486, 0.0003934994223, -4.12701925e-05}},
  454. {{0.06469649321, -0.0236701437, 0.002902096474, -5.896963079e-05},
  455. {0.001925008108, -0.002869809258, 0.0008000589141, -6.063713228e-05},
  456. {-0.0003477407336, 6.959756487e-05, 1.097287507e-05, -1.650964693e-06}}},
  457. };
  458. const accscalar_t u = compat_log(x_);
  459. const accscalar_t a = compat_log(alpha_) - u;
  460. const accscalar_t b = compat_log(total_) - a;
  461. const accscalar_t pow_u[3] = {1, u, u * u};
  462. const accscalar_t pow_a[3] = {1, a, a * a};
  463. accscalar_t p = 0.0;
  464. accscalar_t q = 0.0;
  465. for (int i = 0; i < 3; ++i) {
  466. for (int j = 0; j < 3; ++j) {
  467. const accscalar_t ua = pow_u[i] * pow_a[j];
  468. p += ua * (c[0][i][j][0] + b * (c[0][i][j][1] + b * (c[0][i][j][2] + b * c[0][i][j][3])));
  469. q += ua * (c[1][i][j][0] + b * (c[1][i][j][1] + b * (c[1][i][j][2] + b * c[1][i][j][3])));
  470. }
  471. }
  472. const accscalar_t approx = x_ * (digamma_one<scalar_t, accscalar_t>(total_) - digamma_one<scalar_t, accscalar_t>(alpha_)) / beta_;
  473. return static_cast<scalar_t>(p / q * approx);
  474. }
  475. } // namespace