normal_distribution.hpp 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374
/* boost random/normal_distribution.hpp header file
 *
 * Copyright Jens Maurer 2000-2001
 * Copyright Steven Watanabe 2010-2011
 * Distributed under the Boost Software License, Version 1.0. (See
 * accompanying file LICENSE_1_0.txt or copy at
 * http://www.boost.org/LICENSE_1_0.txt)
 *
 * See http://www.boost.org for most recent version including documentation.
 *
 * $Id$
 *
 * Revision history
 *  2001-02-18  moved to individual header files
 */
  16. #ifndef BOOST_RANDOM_NORMAL_DISTRIBUTION_HPP
  17. #define BOOST_RANDOM_NORMAL_DISTRIBUTION_HPP
  18. #include <boost/config/no_tr1/cmath.hpp>
  19. #include <istream>
  20. #include <iosfwd>
  21. #include <boost/assert.hpp>
  22. #include <boost/limits.hpp>
  23. #include <boost/static_assert.hpp>
  24. #include <boost/random/detail/config.hpp>
  25. #include <boost/random/detail/operators.hpp>
  26. #include <boost/random/detail/int_float_pair.hpp>
  27. #include <boost/random/uniform_01.hpp>
  28. #include <boost/random/exponential_distribution.hpp>
  29. namespace boost {
  30. namespace random {
  31. namespace detail {
  32. // tables for the ziggurat algorithm
  33. template<class RealType>
  34. struct normal_table {
  35. static const RealType table_x[129];
  36. static const RealType table_y[129];
  37. };
  38. template<class RealType>
  39. const RealType normal_table<RealType>::table_x[129] = {
  40. 3.7130862467403632609, 3.4426198558966521214, 3.2230849845786185446, 3.0832288582142137009,
  41. 2.9786962526450169606, 2.8943440070186706210, 2.8231253505459664379, 2.7611693723841538514,
  42. 2.7061135731187223371, 2.6564064112581924999, 2.6109722484286132035, 2.5690336259216391328,
  43. 2.5300096723854666170, 2.4934545220919507609, 2.4590181774083500943, 2.4264206455302115930,
  44. 2.3954342780074673425, 2.3658713701139875435, 2.3375752413355307354, 2.3104136836950021558,
  45. 2.2842740596736568056, 2.2590595738653295251, 2.2346863955870569803, 2.2110814088747278106,
  46. 2.1881804320720206093, 2.1659267937448407377, 2.1442701823562613518, 2.1231657086697899595,
  47. 2.1025731351849988838, 2.0824562379877246441, 2.0627822745039633575, 2.0435215366506694976,
  48. 2.0246469733729338782, 2.0061338699589668403, 1.9879595741230607243, 1.9701032608497132242,
  49. 1.9525457295488889058, 1.9352692282919002011, 1.9182573008597320303, 1.9014946531003176140,
  50. 1.8849670357028692380, 1.8686611409895420085, 1.8525645117230870617, 1.8366654602533840447,
  51. 1.8209529965910050740, 1.8054167642140487420, 1.7900469825946189862, 1.7748343955807692457,
  52. 1.7597702248942318749, 1.7448461281083765085, 1.7300541605582435350, 1.7153867407081165482,
  53. 1.7008366185643009437, 1.6863968467734863258, 1.6720607540918522072, 1.6578219209482075462,
  54. 1.6436741568569826489, 1.6296114794646783962, 1.6156280950371329644, 1.6017183802152770587,
  55. 1.5878768648844007019, 1.5740982160167497219, 1.5603772223598406870, 1.5467087798535034608,
  56. 1.5330878776675560787, 1.5195095847593707806, 1.5059690368565502602, 1.4924614237746154081,
  57. 1.4789819769830978546, 1.4655259573357946276, 1.4520886428822164926, 1.4386653166774613138,
  58. 1.4252512545068615734, 1.4118417124397602509, 1.3984319141236063517, 1.3850170377251486449,
  59. 1.3715922024197322698, 1.3581524543224228739, 1.3446927517457130432, 1.3312079496576765017,
  60. 1.3176927832013429910, 1.3041418501204215390, 1.2905495919178731508, 1.2769102735516997175,
  61. 1.2632179614460282310, 1.2494664995643337480, 1.2356494832544811749, 1.2217602305309625678,
  62. 1.2077917504067576028, 1.1937367078237721994, 1.1795873846544607035, 1.1653356361550469083,
  63. 1.1509728421389760651, 1.1364898520030755352, 1.1218769225722540661, 1.1071236475235353980,
  64. 1.0922188768965537614, 1.0771506248819376573, 1.0619059636836193998, 1.0464709007525802629,
  65. 1.0308302360564555907, 1.0149673952392994716, 0.99886423348064351303, 0.98250080350276038481,
  66. 0.96585507938813059489, 0.94890262549791195381, 0.93161619660135381056, 0.91396525100880177644,
  67. 0.89591535256623852894, 0.87742742909771569142, 0.85845684317805086354, 0.83895221428120745572,
  68. 0.81885390668331772331, 0.79809206062627480454, 0.77658398787614838598, 0.75423066443451007146,
  69. 0.73091191062188128150, 0.70647961131360803456, 0.68074791864590421664, 0.65347863871504238702,
  70. 0.62435859730908822111, 0.59296294244197797913, 0.55869217837551797140, 0.52065603872514491759,
  71. 0.47743783725378787681, 0.42654798630330512490, 0.36287143102841830424, 0.27232086470466385065,
  72. 0
  73. };
  74. template<class RealType>
  75. const RealType normal_table<RealType>::table_y[129] = {
  76. 0, 0.0026696290839025035092, 0.0055489952208164705392, 0.0086244844129304709682,
  77. 0.011839478657982313715, 0.015167298010672042468, 0.018592102737165812650, 0.022103304616111592615,
  78. 0.025693291936149616572, 0.029356317440253829618, 0.033087886146505155566, 0.036884388786968774128,
  79. 0.040742868074790604632, 0.044660862200872429800, 0.048636295860284051878, 0.052667401903503169793,
  80. 0.056752663481538584188, 0.060890770348566375972, 0.065080585213631873753, 0.069321117394180252601,
  81. 0.073611501884754893389, 0.077950982514654714188, 0.082338898242957408243, 0.086774671895542968998,
  82. 0.091257800827634710201, 0.09578784912257815216, 0.10036444102954554013, 0.10498725541035453978,
  83. 0.10965602101581776100, 0.11437051244988827452, 0.11913054670871858767, 0.12393598020398174246,
  84. 0.12878670619710396109, 0.13368265258464764118, 0.13862377998585103702, 0.14361008009193299469,
  85. 0.14864157424369696566, 0.15371831220958657066, 0.15884037114093507813, 0.16400785468492774791,
  86. 0.16922089223892475176, 0.17447963833240232295, 0.17978427212496211424, 0.18513499701071343216,
  87. 0.19053204032091372112, 0.19597565311811041399, 0.20146611007620324118, 0.20700370944187380064,
  88. 0.21258877307373610060, 0.21822164655637059599, 0.22390269938713388747, 0.22963232523430270355,
  89. 0.23541094226572765600, 0.24123899354775131610, 0.24711694751469673582, 0.25304529850976585934,
  90. 0.25902456739871074263, 0.26505530225816194029, 0.27113807914102527343, 0.27727350292189771153,
  91. 0.28346220822601251779, 0.28970486044581049771, 0.29600215684985583659, 0.30235482778947976274,
  92. 0.30876363800925192282, 0.31522938806815752222, 0.32175291587920862031, 0.32833509837615239609,
  93. 0.33497685331697116147, 0.34167914123501368412, 0.34844296754987246935, 0.35526938485154714435,
  94. 0.36215949537303321162, 0.36911445366827513952, 0.37613546951445442947, 0.38322381105988364587,
  95. 0.39038080824138948916, 0.39760785649804255208, 0.40490642081148835099, 0.41227804010702462062,
  96. 0.41972433205403823467, 0.42724699830956239880, 0.43484783025466189638, 0.44252871528024661483,
  97. 0.45029164368692696086, 0.45813871627287196483, 0.46607215269457097924, 0.47409430069824960453,
  98. 0.48220764633483869062, 0.49041482528932163741, 0.49871863547658432422, 0.50712205108130458951,
  99. 0.51562823824987205196, 0.52424057267899279809, 0.53296265938998758838, 0.54179835503172412311,
  100. 0.55075179312105527738, 0.55982741271069481791, 0.56902999107472161225, 0.57836468112670231279,
  101. 0.58783705444182052571, 0.59745315095181228217, 0.60721953663260488551, 0.61714337082656248870,
  102. 0.62723248525781456578, 0.63749547734314487428, 0.64794182111855080873, 0.65858200005865368016,
  103. 0.66942766735770616891, 0.68049184100641433355, 0.69178914344603585279, 0.70333609902581741633,
  104. 0.71515150742047704368, 0.72725691835450587793, 0.73967724368333814856, 0.75244155918570380145,
  105. 0.76558417390923599480, 0.77914608594170316563, 0.79317701178385921053, 0.80773829469612111340,
  106. 0.82290721139526200050, 0.83878360531064722379, 0.85550060788506428418, 0.87324304892685358879,
  107. 0.89228165080230272301, 0.91304364799203805999, 0.93628268170837107547, 0.96359969315576759960,
  108. 1
  109. };
  110. template<class RealType = double>
  111. struct unit_normal_distribution
  112. {
  113. template<class Engine>
  114. RealType operator()(Engine& eng) {
  115. const double * const table_x = normal_table<double>::table_x;
  116. const double * const table_y = normal_table<double>::table_y;
  117. for(;;) {
  118. std::pair<RealType, int> vals = generate_int_float_pair<RealType, 8>(eng);
  119. int i = vals.second;
  120. int sign = (i & 1) * 2 - 1;
  121. i = i >> 1;
  122. RealType x = vals.first * RealType(table_x[i]);
  123. if(x < table_x[i + 1]) return x * sign;
  124. if(i == 0) return generate_tail(eng) * sign;
  125. RealType y01 = uniform_01<RealType>()(eng);
  126. RealType y = RealType(table_y[i]) + y01 * RealType(table_y[i + 1] - table_y[i]);
  127. // These store the value y - bound, or something proportional to that difference:
  128. RealType y_above_ubound, y_above_lbound;
  129. // There are three cases to consider:
  130. // - convex regions (where x[i] > x[j] >= 1)
  131. // - concave regions (where 1 <= x[i] < x[j])
  132. // - region containing the inflection point (where x[i] > 1 > x[j])
  133. // For convex (concave), exp^(-x^2/2) is bounded below (above) by the tangent at
  134. // (x[i],y[i]) and is bounded above (below) by the diagonal line from (x[i+1],y[i+1]) to
  135. // (x[i],y[i]).
  136. //
  137. // *If* the inflection point region satisfies slope(x[i+1]) < slope(diagonal), then we
  138. // can treat the inflection region as a convex region: this condition is necessary and
  139. // sufficient to ensure that the curve lies entirely below the diagonal (that, in turn,
  140. // also implies that it will be above the tangent at x[i]).
  141. //
  142. // For the current table size (128), this is satisfied: slope(x[i+1]) = -0.60653 <
  143. // slope(diag) = -0.60649, and so we have only two cases below instead of three.
  144. if (table_x[i] >= 1) { // convex (incl. inflection)
  145. y_above_ubound = RealType(table_x[i] - table_x[i+1]) * y01 - (RealType(table_x[i]) - x);
  146. y_above_lbound = y - (RealType(table_y[i]) + (RealType(table_x[i]) - x) * RealType(table_y[i]) * RealType(table_x[i]));
  147. }
  148. else { // concave
  149. y_above_lbound = RealType(table_x[i] - table_x[i+1]) * y01 - (RealType(table_x[i]) - x);
  150. y_above_ubound = y - (RealType(table_y[i]) + (RealType(table_x[i]) - x) * RealType(table_y[i]) * RealType(table_x[i]));
  151. }
  152. if (y_above_ubound < 0 // if above the upper bound reject immediately
  153. &&
  154. (
  155. y_above_lbound < 0 // If below the lower bound accept immediately
  156. ||
  157. y < f(x) // Otherwise it's between the bounds and we need a full check
  158. )
  159. ) {
  160. return x * sign;
  161. }
  162. }
  163. }
  164. static RealType f(RealType x) {
  165. using std::exp;
  166. return exp(-(x*x/2));
  167. }
  168. // Generate from the tail using rejection sampling from the exponential(x_1) distribution,
  169. // shifted by x_1. This looks a little different from the usual rejection sampling because it
  170. // transforms the condition by taking the log of both sides, thus avoiding the costly exp() call
  171. // on the RHS, then takes advantage of the fact that -log(unif01) is simply generating an
  172. // exponential (by inverse cdf sampling) by replacing the log(unif01) on the LHS with a
  173. // exponential(1) draw, y.
  174. template<class Engine>
  175. RealType generate_tail(Engine& eng) {
  176. const RealType tail_start = RealType(normal_table<double>::table_x[1]);
  177. boost::random::exponential_distribution<RealType> exp_x(tail_start);
  178. boost::random::exponential_distribution<RealType> exp_y;
  179. for(;;) {
  180. RealType x = exp_x(eng);
  181. RealType y = exp_y(eng);
  182. // If we were doing non-transformed rejection sampling, this condition would be:
  183. // if (unif01 < exp(-.5*x*x)) return x + tail_start;
  184. if(2*y > x*x) return x + tail_start;
  185. }
  186. }
  187. };
  188. } // namespace detail
  189. /**
  190. * Instantiations of class template normal_distribution model a
  191. * \random_distribution. Such a distribution produces random numbers
  192. * @c x distributed with probability density function
  193. * \f$\displaystyle p(x) =
  194. * \frac{1}{\sqrt{2\pi}\sigma} e^{-\frac{(x-\mu)^2}{2\sigma^2}}
  195. * \f$,
  196. * where mean and sigma are the parameters of the distribution.
  197. *
  198. * The implementation uses the "ziggurat" algorithm, as described in
  199. *
  200. * @blockquote
  201. * "The Ziggurat Method for Generating Random Variables",
  202. * George Marsaglia and Wai Wan Tsang, Journal of Statistical Software,
  203. * Volume 5, Number 8 (2000), 1-7.
  204. * @endblockquote
  205. */
  206. template<class RealType = double>
  207. class normal_distribution
  208. {
  209. public:
  210. typedef RealType input_type;
  211. typedef RealType result_type;
  212. class param_type {
  213. public:
  214. typedef normal_distribution distribution_type;
  215. /**
  216. * Constructs a @c param_type with a given mean and
  217. * standard deviation.
  218. *
  219. * Requires: sigma >= 0
  220. */
  221. explicit param_type(RealType mean_arg = RealType(0.0),
  222. RealType sigma_arg = RealType(1.0))
  223. : _mean(mean_arg),
  224. _sigma(sigma_arg)
  225. {}
  226. /** Returns the mean of the distribution. */
  227. RealType mean() const { return _mean; }
  228. /** Returns the standand deviation of the distribution. */
  229. RealType sigma() const { return _sigma; }
  230. /** Writes a @c param_type to a @c std::ostream. */
  231. BOOST_RANDOM_DETAIL_OSTREAM_OPERATOR(os, param_type, parm)
  232. { os << parm._mean << " " << parm._sigma ; return os; }
  233. /** Reads a @c param_type from a @c std::istream. */
  234. BOOST_RANDOM_DETAIL_ISTREAM_OPERATOR(is, param_type, parm)
  235. { is >> parm._mean >> std::ws >> parm._sigma; return is; }
  236. /** Returns true if the two sets of parameters are the same. */
  237. BOOST_RANDOM_DETAIL_EQUALITY_OPERATOR(param_type, lhs, rhs)
  238. { return lhs._mean == rhs._mean && lhs._sigma == rhs._sigma; }
  239. /** Returns true if the two sets of parameters are the different. */
  240. BOOST_RANDOM_DETAIL_INEQUALITY_OPERATOR(param_type)
  241. private:
  242. RealType _mean;
  243. RealType _sigma;
  244. };
  245. /**
  246. * Constructs a @c normal_distribution object. @c mean and @c sigma are
  247. * the parameters for the distribution.
  248. *
  249. * Requires: sigma >= 0
  250. */
  251. explicit normal_distribution(const RealType& mean_arg = RealType(0.0),
  252. const RealType& sigma_arg = RealType(1.0))
  253. : _mean(mean_arg), _sigma(sigma_arg)
  254. {
  255. BOOST_ASSERT(_sigma >= RealType(0));
  256. }
  257. /**
  258. * Constructs a @c normal_distribution object from its parameters.
  259. */
  260. explicit normal_distribution(const param_type& parm)
  261. : _mean(parm.mean()), _sigma(parm.sigma())
  262. {}
  263. /** Returns the mean of the distribution. */
  264. RealType mean() const { return _mean; }
  265. /** Returns the standard deviation of the distribution. */
  266. RealType sigma() const { return _sigma; }
  267. /** Returns the smallest value that the distribution can produce. */
  268. RealType min BOOST_PREVENT_MACRO_SUBSTITUTION () const
  269. { return -std::numeric_limits<RealType>::infinity(); }
  270. /** Returns the largest value that the distribution can produce. */
  271. RealType max BOOST_PREVENT_MACRO_SUBSTITUTION () const
  272. { return std::numeric_limits<RealType>::infinity(); }
  273. /** Returns the parameters of the distribution. */
  274. param_type param() const { return param_type(_mean, _sigma); }
  275. /** Sets the parameters of the distribution. */
  276. void param(const param_type& parm)
  277. {
  278. _mean = parm.mean();
  279. _sigma = parm.sigma();
  280. }
  281. /**
  282. * Effects: Subsequent uses of the distribution do not depend
  283. * on values produced by any engine prior to invoking reset.
  284. */
  285. void reset() { }
  286. /** Returns a normal variate. */
  287. template<class Engine>
  288. result_type operator()(Engine& eng)
  289. {
  290. detail::unit_normal_distribution<RealType> impl;
  291. return impl(eng) * _sigma + _mean;
  292. }
  293. /** Returns a normal variate with parameters specified by @c param. */
  294. template<class URNG>
  295. result_type operator()(URNG& urng, const param_type& parm)
  296. {
  297. return normal_distribution(parm)(urng);
  298. }
  299. /** Writes a @c normal_distribution to a @c std::ostream. */
  300. BOOST_RANDOM_DETAIL_OSTREAM_OPERATOR(os, normal_distribution, nd)
  301. {
  302. os << nd._mean << " " << nd._sigma;
  303. return os;
  304. }
  305. /** Reads a @c normal_distribution from a @c std::istream. */
  306. BOOST_RANDOM_DETAIL_ISTREAM_OPERATOR(is, normal_distribution, nd)
  307. {
  308. is >> std::ws >> nd._mean >> std::ws >> nd._sigma;
  309. return is;
  310. }
  311. /**
  312. * Returns true if the two instances of @c normal_distribution will
  313. * return identical sequences of values given equal generators.
  314. */
  315. BOOST_RANDOM_DETAIL_EQUALITY_OPERATOR(normal_distribution, lhs, rhs)
  316. {
  317. return lhs._mean == rhs._mean && lhs._sigma == rhs._sigma;
  318. }
  319. /**
  320. * Returns true if the two instances of @c normal_distribution will
  321. * return different sequences of values given equal generators.
  322. */
  323. BOOST_RANDOM_DETAIL_INEQUALITY_OPERATOR(normal_distribution)
  324. private:
  325. RealType _mean, _sigma;
  326. };
  327. } // namespace random
  328. using random::normal_distribution;
  329. } // namespace boost
  330. #endif // BOOST_RANDOM_NORMAL_DISTRIBUTION_HPP