test_continuous_basic.py 41 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997
  1. import numpy as np
  2. import numpy.testing as npt
  3. import pytest
  4. from pytest import raises as assert_raises
  5. from scipy.integrate import IntegrationWarning
  6. import itertools
  7. from scipy import stats
  8. from .common_tests import (check_normalization, check_moment, check_mean_expect,
  9. check_var_expect, check_skew_expect,
  10. check_kurt_expect, check_entropy,
  11. check_private_entropy, check_entropy_vect_scale,
  12. check_edge_support, check_named_args,
  13. check_random_state_property,
  14. check_meth_dtype, check_ppf_dtype, check_cmplx_deriv,
  15. check_pickling, check_rvs_broadcast, check_freezing,
  16. check_deprecation_warning_gh5982_moment,
  17. check_deprecation_warning_gh5982_interval)
  18. from scipy.stats._distr_params import distcont
  19. from scipy.stats._distn_infrastructure import rv_continuous_frozen
  20. """
  21. Test all continuous distributions.
  22. Parameters were chosen for those distributions that pass the
  23. Kolmogorov-Smirnov test. This provides safe parameters for each
  24. distributions so that we can perform further testing of class methods.
  25. These tests currently check only/mostly for serious errors and exceptions,
  26. not for numerically exact results.
  27. """
  28. # Note that you need to add new distributions you want tested
  29. # to _distr_params
  30. DECIMAL = 5 # specify the precision of the tests # increased from 0 to 5
  31. # For skipping test_cont_basic
  32. distslow = ['recipinvgauss', 'vonmises', 'kappa4', 'vonmises_line',
  33. 'gausshyper', 'norminvgauss', 'geninvgauss', 'genhyperbolic',
  34. 'truncnorm', 'truncweibull_min']
  35. # distxslow are sorted by speed (very slow to slow)
  36. distxslow = ['studentized_range', 'kstwo', 'ksone', 'wrapcauchy', 'genexpon']
  37. # For skipping test_moments, which is already marked slow
  38. distxslow_test_moments = ['studentized_range', 'vonmises', 'vonmises_line',
  39. 'ksone', 'kstwo', 'recipinvgauss', 'genexpon']
  40. # skip check_fit_args (test is slow)
  41. skip_fit_test_mle = ['exponpow', 'exponweib', 'gausshyper', 'genexpon',
  42. 'halfgennorm', 'gompertz', 'johnsonsb', 'johnsonsu',
  43. 'kappa4', 'ksone', 'kstwo', 'kstwobign', 'mielke', 'ncf',
  44. 'nct', 'powerlognorm', 'powernorm', 'recipinvgauss',
  45. 'trapezoid', 'vonmises', 'vonmises_line', 'levy_stable',
  46. 'rv_histogram_instance', 'studentized_range']
  47. # these were really slow in `test_fit`.py.
  48. # note that this list is used to skip both fit_test and fit_fix tests
  49. slow_fit_test_mm = ['argus', 'exponpow', 'exponweib', 'gausshyper', 'genexpon',
  50. 'genhalflogistic', 'halfgennorm', 'gompertz', 'johnsonsb',
  51. 'kappa4', 'kstwobign', 'recipinvgauss',
  52. 'trapezoid', 'truncexpon', 'vonmises', 'vonmises_line',
  53. 'studentized_range']
  54. # pearson3 fails due to something weird
  55. # the first list fails due to non-finite distribution moments encountered
  56. # most of the rest fail due to integration warnings
  57. # pearson3 is overriden as not implemented due to gh-11746
  58. fail_fit_test_mm = (['alpha', 'betaprime', 'bradford', 'burr', 'burr12',
  59. 'cauchy', 'crystalball', 'f', 'fisk', 'foldcauchy',
  60. 'genextreme', 'genpareto', 'halfcauchy', 'invgamma',
  61. 'kappa3', 'levy', 'levy_l', 'loglaplace', 'lomax',
  62. 'mielke', 'nakagami', 'ncf', 'skewcauchy', 't',
  63. 'tukeylambda', 'invweibull']
  64. + ['genhyperbolic', 'johnsonsu', 'ksone', 'kstwo',
  65. 'nct', 'pareto', 'powernorm', 'powerlognorm']
  66. + ['pearson3'])
  67. skip_fit_test = {"MLE": skip_fit_test_mle,
  68. "MM": slow_fit_test_mm + fail_fit_test_mm}
  69. # skip check_fit_args_fix (test is slow)
  70. skip_fit_fix_test_mle = ['burr', 'exponpow', 'exponweib', 'gausshyper',
  71. 'genexpon', 'halfgennorm', 'gompertz', 'johnsonsb',
  72. 'johnsonsu', 'kappa4', 'ksone', 'kstwo', 'kstwobign',
  73. 'levy_stable', 'mielke', 'ncf', 'ncx2',
  74. 'powerlognorm', 'powernorm', 'rdist', 'recipinvgauss',
  75. 'trapezoid', 'vonmises', 'vonmises_line',
  76. 'studentized_range']
  77. # the first list fails due to non-finite distribution moments encountered
  78. # most of the rest fail due to integration warnings
  79. # pearson3 is overriden as not implemented due to gh-11746
  80. fail_fit_fix_test_mm = (['alpha', 'betaprime', 'burr', 'burr12', 'cauchy',
  81. 'crystalball', 'f', 'fisk', 'foldcauchy',
  82. 'genextreme', 'genpareto', 'halfcauchy', 'invgamma',
  83. 'kappa3', 'levy', 'levy_l', 'loglaplace', 'lomax',
  84. 'mielke', 'nakagami', 'ncf', 'nct', 'skewcauchy', 't',
  85. 'truncpareto', 'invweibull']
  86. + ['genhyperbolic', 'johnsonsu', 'ksone', 'kstwo',
  87. 'pareto', 'powernorm', 'powerlognorm']
  88. + ['pearson3'])
  89. skip_fit_fix_test = {"MLE": skip_fit_fix_test_mle,
  90. "MM": slow_fit_test_mm + fail_fit_fix_test_mm}
  91. # These distributions fail the complex derivative test below.
  92. # Here 'fail' mean produce wrong results and/or raise exceptions, depending
  93. # on the implementation details of corresponding special functions.
  94. # cf https://github.com/scipy/scipy/pull/4979 for a discussion.
  95. fails_cmplx = set(['argus', 'beta', 'betaprime', 'chi', 'chi2', 'cosine',
  96. 'dgamma', 'dweibull', 'erlang', 'f', 'gamma',
  97. 'gausshyper', 'gengamma', 'genhyperbolic',
  98. 'geninvgauss', 'gennorm', 'genpareto',
  99. 'halfgennorm', 'invgamma',
  100. 'ksone', 'kstwo', 'kstwobign', 'levy_l', 'loggamma',
  101. 'logistic', 'loguniform', 'maxwell', 'nakagami',
  102. 'ncf', 'nct', 'ncx2', 'norminvgauss', 'pearson3', 'rdist',
  103. 'reciprocal', 'rice', 'skewnorm', 't', 'truncweibull_min',
  104. 'tukeylambda', 'vonmises', 'vonmises_line',
  105. 'rv_histogram_instance', 'truncnorm', 'studentized_range'])
# rv_histogram instances, with uniform and non-uniform bins;
# stored as (dist, arg) tuples for cases_test_cont_basic
# and cases_test_moments.
histogram_test_instances = []
case1 = {'a': [1, 2, 2, 3, 3, 3, 4, 4, 4, 4, 5, 5, 5, 5, 5, 6,
               6, 6, 6, 7, 7, 7, 8, 8, 9], 'bins': 8}  # equal width bins
case2 = {'a': [1, 1], 'bins': [0, 1, 10]}  # unequal width bins
# One instance per (case, density) combination: 2 cases x 2 flags = 4.
# The same `density` flag is given to np.histogram and to rv_histogram.
# NOTE: the loop variables (case, density, _hist, _rv_hist) remain bound
# as module-level names after the loop.
for case, density in itertools.product([case1, case2], [True, False]):
    _hist = np.histogram(**case, density=density)
    _rv_hist = stats.rv_histogram(_hist, density=density)
    # empty `arg` tuple: no shape arguments are passed for these instances
    histogram_test_instances.append((_rv_hist, tuple()))
  117. def cases_test_cont_basic():
  118. for distname, arg in distcont[:] + histogram_test_instances:
  119. if distname == 'levy_stable':
  120. continue
  121. elif distname in distslow:
  122. yield pytest.param(distname, arg, marks=pytest.mark.slow)
  123. elif distname in distxslow:
  124. yield pytest.param(distname, arg, marks=pytest.mark.xslow)
  125. else:
  126. yield distname, arg
@pytest.mark.filterwarnings('ignore::RuntimeWarning')
@pytest.mark.parametrize('distname,arg', cases_test_cont_basic())
@pytest.mark.parametrize('sn, n_fit_samples', [(500, 200)])
def test_cont_basic(distname, arg, sn, n_fit_samples):
    """Run the basic method-consistency checks on one continuous distribution.

    Parameters
    ----------
    distname : str or rv_histogram instance
        Name of a distribution in `scipy.stats`; for the histogram cases it
        is the rv_histogram object itself.
    arg : tuple
        Shape parameters forwarded to the distribution methods.
    sn : int
        Number of random variates drawn for the sample-based checks.
    n_fit_samples : int
        Number of those variates (taken from the front of the sample) passed
        to the ``fit`` checks.
    """
    # this test skips slow distributions

    try:
        distfn = getattr(stats, distname)
    except TypeError:
        # getattr raises TypeError for a non-string name, i.e. an
        # rv_histogram instance: use it directly and label it with the
        # name that the module-level skip lists refer to.
        distfn = distname
        distname = 'rv_histogram_instance'

    # fixed seed so the sample-based checks are reproducible
    rng = np.random.RandomState(765456)
    rvs = distfn.rvs(size=sn, *arg, random_state=rng)
    m, v = distfn.stats(*arg)

    if distname not in {'laplace_asymmetric'}:
        check_sample_meanvar_(m, v, rvs)
    # consistency between related methods (cdf/ppf, sf/isf, pdf/cdf, logs)
    check_cdf_ppf(distfn, arg, distname)
    check_sf_isf(distfn, arg, distname)
    check_pdf(distfn, arg, distname)
    check_pdf_logpdf(distfn, arg, distname)
    check_pdf_logpdf_at_endpoints(distfn, arg, distname)
    check_cdf_logcdf(distfn, arg, distname)
    check_sf_logsf(distfn, arg, distname)
    check_ppf_broadcast(distfn, arg, distname)
    check_deprecation_warning_gh5982_moment(distfn, arg, distname)
    check_deprecation_warning_gh5982_interval(distfn, arg, distname)

    # K-S goodness-of-fit of the drawn sample
    alpha = 0.01
    if distname == 'rv_histogram_instance':
        # the instance has no name in scipy.stats, so pass its cdf
        check_distribution_rvs(distfn.cdf, arg, alpha, rvs)
    elif distname != 'geninvgauss':
        # skip kstest for geninvgauss since cdf is too slow; see test for
        # rv generation in TestGenInvGauss in test_distributions.py
        check_distribution_rvs(distname, arg, alpha, rvs)

    locscale_defaults = (0, 1)
    meths = [distfn.pdf, distfn.logpdf, distfn.cdf, distfn.logcdf,
             distfn.logsf]
    # make sure arguments are within support
    spec_x = {'weibull_max': -0.5, 'levy_l': -0.5,
              'pareto': 1.5, 'truncpareto': 3.2, 'tukeylambda': 0.3,
              'rv_histogram_instance': 5.0}
    x = spec_x.get(distname, 0.5)
    # NOTE(review): from here on invweibull and ksone are checked with
    # replacement shape parameters rather than their `distcont` values;
    # the reason is not visible in this file -- confirm.
    if distname == 'invweibull':
        arg = (1,)
    elif distname == 'ksone':
        arg = (3,)

    check_named_args(distfn, x, arg, locscale_defaults, meths)
    check_random_state_property(distfn, arg)
    check_pickling(distfn, arg)
    check_freezing(distfn, arg)

    # Entropy
    if distname not in ['kstwobign', 'kstwo', 'ncf']:
        check_entropy(distfn, arg, distname)

    if distfn.numargs == 0:
        check_vecentropy(distfn, arg)

    # only when the distribution overrides the generic rv_continuous._entropy
    if (distfn.__class__._entropy != stats.rv_continuous._entropy
            and distname != 'vonmises'):
        check_private_entropy(distfn, arg, stats.rv_continuous)

    with npt.suppress_warnings() as sup:
        sup.filter(IntegrationWarning, "The occurrence of roundoff error")
        sup.filter(IntegrationWarning, "Extremely bad integrand")
        sup.filter(RuntimeWarning, "invalid value")
        check_entropy_vect_scale(distfn, arg)

    check_retrieving_support(distfn, arg)
    check_edge_support(distfn, arg)

    check_meth_dtype(distfn, arg, meths)
    check_ppf_dtype(distfn, arg)

    if distname not in fails_cmplx:
        check_cmplx_deriv(distfn, arg)

    if distname != 'truncnorm':
        check_ppf_private(distfn, arg, distname)

    # fit checks for both estimation methods, unless listed as slow/failing
    for method in ["MLE", "MM"]:
        if distname not in skip_fit_test[method]:
            check_fit_args(distfn, arg, rvs[:n_fit_samples], method)

        if distname not in skip_fit_fix_test[method]:
            check_fit_args_fix(distfn, arg, rvs[:n_fit_samples], method)
  201. @pytest.mark.parametrize('distname,arg', cases_test_cont_basic())
  202. def test_rvs_scalar(distname, arg):
  203. # rvs should return a scalar when given scalar arguments (gh-12428)
  204. try:
  205. distfn = getattr(stats, distname)
  206. except TypeError:
  207. distfn = distname
  208. distname = 'rv_histogram_instance'
  209. assert np.isscalar(distfn.rvs(*arg))
  210. assert np.isscalar(distfn.rvs(*arg, size=()))
  211. assert np.isscalar(distfn.rvs(*arg, size=None))
  212. def test_levy_stable_random_state_property():
  213. # levy_stable only implements rvs(), so it is skipped in the
  214. # main loop in test_cont_basic(). Here we apply just the test
  215. # check_random_state_property to levy_stable.
  216. check_random_state_property(stats.levy_stable, (0.5, 0.1))
  217. def cases_test_moments():
  218. fail_normalization = set()
  219. fail_higher = set(['ncf'])
  220. for distname, arg in distcont[:] + histogram_test_instances:
  221. if distname == 'levy_stable':
  222. continue
  223. if distname in distxslow_test_moments:
  224. yield pytest.param(distname, arg, True, True, True,
  225. marks=pytest.mark.xslow(reason="too slow"))
  226. continue
  227. cond1 = distname not in fail_normalization
  228. cond2 = distname not in fail_higher
  229. marks = list()
  230. # Currently unused, `marks` can be used to add a timeout to a test of
  231. # a specific distribution. For example, this shows how a timeout could
  232. # be added for the 'skewnorm' distribution:
  233. #
  234. # marks = list()
  235. # if distname == 'skewnorm':
  236. # marks.append(pytest.mark.timeout(300))
  237. yield pytest.param(distname, arg, cond1, cond2, False, marks=marks)
  238. if not cond1 or not cond2:
  239. # Run the distributions that have issues twice, once skipping the
  240. # not_ok parts, once with the not_ok parts but marked as knownfail
  241. yield pytest.param(distname, arg, True, True, True,
  242. marks=[pytest.mark.xfail] + marks)
  243. @pytest.mark.slow
  244. @pytest.mark.parametrize('distname,arg,normalization_ok,higher_ok,'
  245. 'is_xfailing',
  246. cases_test_moments())
  247. def test_moments(distname, arg, normalization_ok, higher_ok, is_xfailing):
  248. try:
  249. distfn = getattr(stats, distname)
  250. except TypeError:
  251. distfn = distname
  252. distname = 'rv_histogram_instance'
  253. with npt.suppress_warnings() as sup:
  254. sup.filter(IntegrationWarning,
  255. "The integral is probably divergent, or slowly convergent.")
  256. sup.filter(IntegrationWarning,
  257. "The maximum number of subdivisions.")
  258. if is_xfailing:
  259. sup.filter(IntegrationWarning)
  260. m, v, s, k = distfn.stats(*arg, moments='mvsk')
  261. with np.errstate(all="ignore"):
  262. if normalization_ok:
  263. check_normalization(distfn, arg, distname)
  264. if higher_ok:
  265. check_mean_expect(distfn, arg, m, distname)
  266. check_skew_expect(distfn, arg, m, v, s, distname)
  267. check_var_expect(distfn, arg, m, v, distname)
  268. check_kurt_expect(distfn, arg, m, v, k, distname)
  269. check_moment(distfn, arg, m, v, distname)
  270. @pytest.mark.parametrize('dist,shape_args', distcont)
  271. def test_rvs_broadcast(dist, shape_args):
  272. if dist in ['gausshyper', 'genexpon', 'studentized_range']:
  273. pytest.skip("too slow")
  274. # If shape_only is True, it means the _rvs method of the
  275. # distribution uses more than one random number to generate a random
  276. # variate. That means the result of using rvs with broadcasting or
  277. # with a nontrivial size will not necessarily be the same as using the
  278. # numpy.vectorize'd version of rvs(), so we can only compare the shapes
  279. # of the results, not the values.
  280. # Whether or not a distribution is in the following list is an
  281. # implementation detail of the distribution, not a requirement. If
  282. # the implementation the rvs() method of a distribution changes, this
  283. # test might also have to be changed.
  284. shape_only = dist in ['argus', 'betaprime', 'dgamma', 'dweibull',
  285. 'exponnorm', 'genhyperbolic', 'geninvgauss',
  286. 'levy_stable', 'nct', 'norminvgauss', 'rice',
  287. 'skewnorm', 'semicircular', 'gennorm', 'loggamma']
  288. distfunc = getattr(stats, dist)
  289. loc = np.zeros(2)
  290. scale = np.ones((3, 1))
  291. nargs = distfunc.numargs
  292. allargs = []
  293. bshape = [3, 2]
  294. # Generate shape parameter arguments...
  295. for k in range(nargs):
  296. shp = (k + 4,) + (1,)*(k + 2)
  297. allargs.append(shape_args[k]*np.ones(shp))
  298. bshape.insert(0, k + 4)
  299. allargs.extend([loc, scale])
  300. # bshape holds the expected shape when loc, scale, and the shape
  301. # parameters are all broadcast together.
  302. check_rvs_broadcast(distfunc, dist, allargs, bshape, shape_only, 'd')
  303. # Expected values of the SF, CDF, PDF were computed using
  304. # mpmath with mpmath.mp.dps = 50 and output at 20:
  305. #
  306. # def ks(x, n):
  307. # x = mpmath.mpf(x)
  308. # logp = -mpmath.power(6.0*n*x+1.0, 2)/18.0/n
  309. # sf, cdf = mpmath.exp(logp), -mpmath.expm1(logp)
  310. # pdf = (6.0*n*x+1.0) * 2 * sf/3
  311. # print(mpmath.nstr(sf, 20), mpmath.nstr(cdf, 20), mpmath.nstr(pdf, 20))
  312. #
  313. # Tests use 1/n < x < 1-1/n and n > 1e6 to use the asymptotic computation.
  314. # Larger x has a smaller sf.
  315. @pytest.mark.parametrize('x,n,sf,cdf,pdf,rtol',
  316. [(2.0e-5, 1000000000,
  317. 0.44932297307934442379, 0.55067702692065557621,
  318. 35946.137394996276407, 5e-15),
  319. (2.0e-9, 1000000000,
  320. 0.99999999061111115519, 9.3888888448132728224e-9,
  321. 8.6666665852962971765, 5e-14),
  322. (5.0e-4, 1000000000,
  323. 7.1222019433090374624e-218, 1.0,
  324. 1.4244408634752704094e-211, 5e-14)])
  325. def test_gh17775_regression(x, n, sf, cdf, pdf, rtol):
  326. # Regression test for gh-17775. In scipy 1.9.3 and earlier,
  327. # these test would fail.
  328. #
  329. # KS one asymptotic sf ~ e^(-(6nx+1)^2 / 18n)
  330. # Given a large 32-bit integer n, 6n will overflow in the c implementation.
  331. # Example of broken behaviour:
  332. # ksone.sf(2.0e-5, 1000000000) == 0.9374359693473666
  333. ks = stats.ksone
  334. vals = np.array([ks.sf(x, n), ks.cdf(x, n), ks.pdf(x, n)])
  335. expected = np.array([sf, cdf, pdf])
  336. npt.assert_allclose(vals, expected, rtol=rtol)
  337. # The sf+cdf must sum to 1.0.
  338. npt.assert_equal(vals[0] + vals[1], 1.0)
  339. # Check inverting the (potentially very small) sf (uses a lower tolerance)
  340. npt.assert_allclose([ks.isf(sf, n)], [x], rtol=1e-8)
  341. def test_rvs_gh2069_regression():
  342. # Regression tests for gh-2069. In scipy 0.17 and earlier,
  343. # these tests would fail.
  344. #
  345. # A typical example of the broken behavior:
  346. # >>> norm.rvs(loc=np.zeros(5), scale=np.ones(5))
  347. # array([-2.49613705, -2.49613705, -2.49613705, -2.49613705, -2.49613705])
  348. rng = np.random.RandomState(123)
  349. vals = stats.norm.rvs(loc=np.zeros(5), scale=1, random_state=rng)
  350. d = np.diff(vals)
  351. npt.assert_(np.all(d != 0), "All the values are equal, but they shouldn't be!")
  352. vals = stats.norm.rvs(loc=0, scale=np.ones(5), random_state=rng)
  353. d = np.diff(vals)
  354. npt.assert_(np.all(d != 0), "All the values are equal, but they shouldn't be!")
  355. vals = stats.norm.rvs(loc=np.zeros(5), scale=np.ones(5), random_state=rng)
  356. d = np.diff(vals)
  357. npt.assert_(np.all(d != 0), "All the values are equal, but they shouldn't be!")
  358. vals = stats.norm.rvs(loc=np.array([[0], [0]]), scale=np.ones(5),
  359. random_state=rng)
  360. d = np.diff(vals.ravel())
  361. npt.assert_(np.all(d != 0), "All the values are equal, but they shouldn't be!")
  362. assert_raises(ValueError, stats.norm.rvs, [[0, 0], [0, 0]],
  363. [[1, 1], [1, 1]], 1)
  364. assert_raises(ValueError, stats.gamma.rvs, [2, 3, 4, 5], 0, 1, (2, 2))
  365. assert_raises(ValueError, stats.gamma.rvs, [1, 1, 1, 1], [0, 0, 0, 0],
  366. [[1], [2]], (4,))
  367. def test_nomodify_gh9900_regression():
  368. # Regression test for gh-9990
  369. # Prior to gh-9990, calls to stats.truncnorm._cdf() use what ever was
  370. # set inside the stats.truncnorm instance during stats.truncnorm.cdf().
  371. # This could cause issues wth multi-threaded code.
  372. # Since then, the calls to cdf() are not permitted to modify the global
  373. # stats.truncnorm instance.
  374. tn = stats.truncnorm
  375. # Use the right-half truncated normal
  376. # Check that the cdf and _cdf return the same result.
  377. npt.assert_almost_equal(tn.cdf(1, 0, np.inf), 0.6826894921370859)
  378. npt.assert_almost_equal(tn._cdf([1], [0], [np.inf]), 0.6826894921370859)
  379. # Now use the left-half truncated normal
  380. npt.assert_almost_equal(tn.cdf(-1, -np.inf, 0), 0.31731050786291415)
  381. npt.assert_almost_equal(tn._cdf([-1], [-np.inf], [0]), 0.31731050786291415)
  382. # Check that the right-half truncated normal _cdf hasn't changed
  383. npt.assert_almost_equal(tn._cdf([1], [0], [np.inf]), 0.6826894921370859) # noqa, NOT 1.6826894921370859
  384. npt.assert_almost_equal(tn.cdf(1, 0, np.inf), 0.6826894921370859)
  385. # Check that the left-half truncated normal _cdf hasn't changed
  386. npt.assert_almost_equal(tn._cdf([-1], [-np.inf], [0]), 0.31731050786291415) # noqa, Not -0.6826894921370859
  387. npt.assert_almost_equal(tn.cdf(1, -np.inf, 0), 1) # Not 1.6826894921370859
  388. npt.assert_almost_equal(tn.cdf(-1, -np.inf, 0), 0.31731050786291415) # Not -0.6826894921370859
  389. def test_broadcast_gh9990_regression():
  390. # Regression test for gh-9990
  391. # The x-value 7 only lies within the support of 4 of the supplied
  392. # distributions. Prior to 9990, one array passed to
  393. # stats.reciprocal._cdf would have 4 elements, but an array
  394. # previously stored by stats.reciprocal_argcheck() would have 6, leading
  395. # to a broadcast error.
  396. a = np.array([1, 2, 3, 4, 5, 6])
  397. b = np.array([8, 16, 1, 32, 1, 48])
  398. ans = [stats.reciprocal.cdf(7, _a, _b) for _a, _b in zip(a,b)]
  399. npt.assert_array_almost_equal(stats.reciprocal.cdf(7, a, b), ans)
  400. ans = [stats.reciprocal.cdf(1, _a, _b) for _a, _b in zip(a,b)]
  401. npt.assert_array_almost_equal(stats.reciprocal.cdf(1, a, b), ans)
  402. ans = [stats.reciprocal.cdf(_a, _a, _b) for _a, _b in zip(a,b)]
  403. npt.assert_array_almost_equal(stats.reciprocal.cdf(a, a, b), ans)
  404. ans = [stats.reciprocal.cdf(_b, _a, _b) for _a, _b in zip(a,b)]
  405. npt.assert_array_almost_equal(stats.reciprocal.cdf(b, a, b), ans)
  406. def test_broadcast_gh7933_regression():
  407. # Check broadcast works
  408. stats.truncnorm.logpdf(
  409. np.array([3.0, 2.0, 1.0]),
  410. a=(1.5 - np.array([6.0, 5.0, 4.0])) / 3.0,
  411. b=np.inf,
  412. loc=np.array([6.0, 5.0, 4.0]),
  413. scale=3.0
  414. )
  415. def test_gh2002_regression():
  416. # Add a check that broadcast works in situations where only some
  417. # x-values are compatible with some of the shape arguments.
  418. x = np.r_[-2:2:101j]
  419. a = np.r_[-np.ones(50), np.ones(51)]
  420. expected = [stats.truncnorm.pdf(_x, _a, np.inf) for _x, _a in zip(x, a)]
  421. ans = stats.truncnorm.pdf(x, a, np.inf)
  422. npt.assert_array_almost_equal(ans, expected)
  423. def test_gh1320_regression():
  424. # Check that the first example from gh-1320 now works.
  425. c = 2.62
  426. stats.genextreme.ppf(0.5, np.array([[c], [c + 0.5]]))
  427. # The other examples in gh-1320 appear to have stopped working
  428. # some time ago.
  429. # ans = stats.genextreme.moment(2, np.array([c, c + 0.5]))
  430. # expected = np.array([25.50105963, 115.11191437])
  431. # stats.genextreme.moment(5, np.array([[c], [c + 0.5]]))
  432. # stats.genextreme.moment(5, np.array([c, c + 0.5]))
  433. def test_method_of_moments():
  434. # example from https://en.wikipedia.org/wiki/Method_of_moments_(statistics)
  435. np.random.seed(1234)
  436. x = [0, 0, 0, 0, 1]
  437. a = 1/5 - 2*np.sqrt(3)/5
  438. b = 1/5 + 2*np.sqrt(3)/5
  439. # force use of method of moments (uniform.fit is overriden)
  440. loc, scale = super(type(stats.uniform), stats.uniform).fit(x, method="MM")
  441. npt.assert_almost_equal(loc, a, decimal=4)
  442. npt.assert_almost_equal(loc+scale, b, decimal=4)
  443. def check_sample_meanvar_(popmean, popvar, sample):
  444. if np.isfinite(popmean):
  445. check_sample_mean(sample, popmean)
  446. if np.isfinite(popvar):
  447. check_sample_var(sample, popvar)
  448. def check_sample_mean(sample, popmean):
  449. # Checks for unlikely difference between sample mean and population mean
  450. prob = stats.ttest_1samp(sample, popmean).pvalue
  451. assert prob > 0.01
  452. def check_sample_var(sample, popvar):
  453. # check that population mean lies within the CI bootstrapped from the
  454. # sample. This used to be a chi-squared test for variance, but there were
  455. # too many false positives
  456. res = stats.bootstrap(
  457. (sample,),
  458. lambda x, axis: x.var(ddof=1, axis=axis),
  459. confidence_level=0.995,
  460. )
  461. conf = res.confidence_interval
  462. low, high = conf.low, conf.high
  463. assert low <= popvar <= high
  464. def check_cdf_ppf(distfn, arg, msg):
  465. values = [0.001, 0.5, 0.999]
  466. npt.assert_almost_equal(distfn.cdf(distfn.ppf(values, *arg), *arg),
  467. values, decimal=DECIMAL, err_msg=msg +
  468. ' - cdf-ppf roundtrip')
  469. def check_sf_isf(distfn, arg, msg):
  470. npt.assert_almost_equal(distfn.sf(distfn.isf([0.1, 0.5, 0.9], *arg), *arg),
  471. [0.1, 0.5, 0.9], decimal=DECIMAL, err_msg=msg +
  472. ' - sf-isf roundtrip')
  473. npt.assert_almost_equal(distfn.cdf([0.1, 0.9], *arg),
  474. 1.0 - distfn.sf([0.1, 0.9], *arg),
  475. decimal=DECIMAL, err_msg=msg +
  476. ' - cdf-sf relationship')
  477. def check_pdf(distfn, arg, msg):
  478. # compares pdf at median with numerical derivative of cdf
  479. median = distfn.ppf(0.5, *arg)
  480. eps = 1e-6
  481. pdfv = distfn.pdf(median, *arg)
  482. if (pdfv < 1e-4) or (pdfv > 1e4):
  483. # avoid checking a case where pdf is close to zero or
  484. # huge (singularity)
  485. median = median + 0.1
  486. pdfv = distfn.pdf(median, *arg)
  487. cdfdiff = (distfn.cdf(median + eps, *arg) -
  488. distfn.cdf(median - eps, *arg))/eps/2.0
  489. # replace with better diff and better test (more points),
  490. # actually, this works pretty well
  491. msg += ' - cdf-pdf relationship'
  492. npt.assert_almost_equal(pdfv, cdfdiff, decimal=DECIMAL, err_msg=msg)
  493. def check_pdf_logpdf(distfn, args, msg):
  494. # compares pdf at several points with the log of the pdf
  495. points = np.array([0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8])
  496. vals = distfn.ppf(points, *args)
  497. vals = vals[np.isfinite(vals)]
  498. pdf = distfn.pdf(vals, *args)
  499. logpdf = distfn.logpdf(vals, *args)
  500. pdf = pdf[(pdf != 0) & np.isfinite(pdf)]
  501. logpdf = logpdf[np.isfinite(logpdf)]
  502. msg += " - logpdf-log(pdf) relationship"
  503. npt.assert_almost_equal(np.log(pdf), logpdf, decimal=7, err_msg=msg)
  504. def check_pdf_logpdf_at_endpoints(distfn, args, msg):
  505. # compares pdf with the log of the pdf at the (finite) end points
  506. points = np.array([0, 1])
  507. vals = distfn.ppf(points, *args)
  508. vals = vals[np.isfinite(vals)]
  509. pdf = distfn.pdf(vals, *args)
  510. logpdf = distfn.logpdf(vals, *args)
  511. pdf = pdf[(pdf != 0) & np.isfinite(pdf)]
  512. logpdf = logpdf[np.isfinite(logpdf)]
  513. msg += " - logpdf-log(pdf) relationship"
  514. npt.assert_almost_equal(np.log(pdf), logpdf, decimal=7, err_msg=msg)
  515. def check_sf_logsf(distfn, args, msg):
  516. # compares sf at several points with the log of the sf
  517. points = np.array([0.0, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 1.0])
  518. vals = distfn.ppf(points, *args)
  519. vals = vals[np.isfinite(vals)]
  520. sf = distfn.sf(vals, *args)
  521. logsf = distfn.logsf(vals, *args)
  522. sf = sf[sf != 0]
  523. logsf = logsf[np.isfinite(logsf)]
  524. msg += " - logsf-log(sf) relationship"
  525. npt.assert_almost_equal(np.log(sf), logsf, decimal=7, err_msg=msg)
  526. def check_cdf_logcdf(distfn, args, msg):
  527. # compares cdf at several points with the log of the cdf
  528. points = np.array([0, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 1.0])
  529. vals = distfn.ppf(points, *args)
  530. vals = vals[np.isfinite(vals)]
  531. cdf = distfn.cdf(vals, *args)
  532. logcdf = distfn.logcdf(vals, *args)
  533. cdf = cdf[cdf != 0]
  534. logcdf = logcdf[np.isfinite(logcdf)]
  535. msg += " - logcdf-log(cdf) relationship"
  536. npt.assert_almost_equal(np.log(cdf), logcdf, decimal=7, err_msg=msg)
  537. def check_ppf_broadcast(distfn, arg, msg):
  538. # compares ppf for multiple argsets.
  539. num_repeats = 5
  540. args = [] * num_repeats
  541. if arg:
  542. args = [np.array([_] * num_repeats) for _ in arg]
  543. median = distfn.ppf(0.5, *arg)
  544. medians = distfn.ppf(0.5, *args)
  545. msg += " - ppf multiple"
  546. npt.assert_almost_equal(medians, [median] * num_repeats, decimal=7, err_msg=msg)
  547. def check_distribution_rvs(dist, args, alpha, rvs):
  548. # dist is either a cdf function or name of a distribution in scipy.stats.
  549. # args are the args for scipy.stats.dist(*args)
  550. # alpha is a significance level, ~0.01
  551. # rvs is array_like of random variables
  552. # test from scipy.stats.tests
  553. # this version reuses existing random variables
  554. D, pval = stats.kstest(rvs, dist, args=args, N=1000)
  555. if (pval < alpha):
  556. # The rvs passed in failed the K-S test, which _could_ happen
  557. # but is unlikely if alpha is small enough.
  558. # Repeat the test with a new sample of rvs.
  559. # Generate 1000 rvs, perform a K-S test that the new sample of rvs
  560. # are distributed according to the distribution.
  561. D, pval = stats.kstest(dist, dist, args=args, N=1000)
  562. npt.assert_(pval > alpha, "D = " + str(D) + "; pval = " + str(pval) +
  563. "; alpha = " + str(alpha) + "\nargs = " + str(args))
  564. def check_vecentropy(distfn, args):
  565. npt.assert_equal(distfn.vecentropy(*args), distfn._entropy(*args))
  566. def check_loc_scale(distfn, arg, m, v, msg):
  567. # Make `loc` and `scale` arrays to catch bugs like gh-13580 where
  568. # `loc` and `scale` arrays improperly broadcast with shapes.
  569. loc, scale = np.array([10.0, 20.0]), np.array([10.0, 20.0])
  570. mt, vt = distfn.stats(loc=loc, scale=scale, *arg)
  571. npt.assert_allclose(m*scale + loc, mt)
  572. npt.assert_allclose(v*scale*scale, vt)
  573. def check_ppf_private(distfn, arg, msg):
  574. # fails by design for truncnorm self.nb not defined
  575. ppfs = distfn._ppf(np.array([0.1, 0.5, 0.9]), *arg)
  576. npt.assert_(not np.any(np.isnan(ppfs)), msg + 'ppf private is nan')
  577. def check_retrieving_support(distfn, args):
  578. loc, scale = 1, 2
  579. supp = distfn.support(*args)
  580. supp_loc_scale = distfn.support(*args, loc=loc, scale=scale)
  581. npt.assert_almost_equal(np.array(supp)*scale + loc,
  582. np.array(supp_loc_scale))
  583. def check_fit_args(distfn, arg, rvs, method):
  584. with np.errstate(all='ignore'), npt.suppress_warnings() as sup:
  585. sup.filter(category=RuntimeWarning,
  586. message="The shape parameter of the erlang")
  587. sup.filter(category=RuntimeWarning,
  588. message="floating point number truncated")
  589. vals = distfn.fit(rvs, method=method)
  590. vals2 = distfn.fit(rvs, optimizer='powell', method=method)
  591. # Only check the length of the return; accuracy tested in test_fit.py
  592. npt.assert_(len(vals) == 2+len(arg))
  593. npt.assert_(len(vals2) == 2+len(arg))
  594. def check_fit_args_fix(distfn, arg, rvs, method):
  595. with np.errstate(all='ignore'), npt.suppress_warnings() as sup:
  596. sup.filter(category=RuntimeWarning,
  597. message="The shape parameter of the erlang")
  598. vals = distfn.fit(rvs, floc=0, method=method)
  599. vals2 = distfn.fit(rvs, fscale=1, method=method)
  600. npt.assert_(len(vals) == 2+len(arg))
  601. npt.assert_(vals[-2] == 0)
  602. npt.assert_(vals2[-1] == 1)
  603. npt.assert_(len(vals2) == 2+len(arg))
  604. if len(arg) > 0:
  605. vals3 = distfn.fit(rvs, f0=arg[0], method=method)
  606. npt.assert_(len(vals3) == 2+len(arg))
  607. npt.assert_(vals3[0] == arg[0])
  608. if len(arg) > 1:
  609. vals4 = distfn.fit(rvs, f1=arg[1], method=method)
  610. npt.assert_(len(vals4) == 2+len(arg))
  611. npt.assert_(vals4[1] == arg[1])
  612. if len(arg) > 2:
  613. vals5 = distfn.fit(rvs, f2=arg[2], method=method)
  614. npt.assert_(len(vals5) == 2+len(arg))
  615. npt.assert_(vals5[2] == arg[2])
  616. @pytest.mark.filterwarnings('ignore::RuntimeWarning')
  617. @pytest.mark.parametrize('method', ['pdf', 'logpdf', 'cdf', 'logcdf',
  618. 'sf', 'logsf', 'ppf', 'isf'])
  619. @pytest.mark.parametrize('distname, args', distcont)
  620. def test_methods_with_lists(method, distname, args):
  621. # Test that the continuous distributions can accept Python lists
  622. # as arguments.
  623. dist = getattr(stats, distname)
  624. f = getattr(dist, method)
  625. if distname == 'invweibull' and method.startswith('log'):
  626. x = [1.5, 2]
  627. else:
  628. x = [0.1, 0.2]
  629. shape2 = [[a]*2 for a in args]
  630. loc = [0, 0.1]
  631. scale = [1, 1.01]
  632. result = f(x, *shape2, loc=loc, scale=scale)
  633. npt.assert_allclose(result,
  634. [f(*v) for v in zip(x, *shape2, loc, scale)],
  635. rtol=1e-14, atol=5e-14)
  636. @pytest.mark.parametrize('method', ['pdf', 'logpdf', 'cdf', 'logcdf',
  637. 'sf', 'logsf', 'ppf', 'isf'])
  638. def test_gilbrat_deprecation(method):
  639. expected = getattr(stats.gibrat, method)(1)
  640. with pytest.warns(
  641. DeprecationWarning,
  642. match=rf"\s*`gilbrat\.{method}` is deprecated,.*",
  643. ):
  644. result = getattr(stats.gilbrat, method)(1)
  645. assert result == expected
  646. @pytest.mark.parametrize('method', ['pdf', 'logpdf', 'cdf', 'logcdf',
  647. 'sf', 'logsf', 'ppf', 'isf'])
  648. def test_gilbrat_deprecation_frozen(method):
  649. expected = getattr(stats.gibrat, method)(1)
  650. with pytest.warns(DeprecationWarning, match=r"\s*`gilbrat` is deprecated"):
  651. # warn on instantiation of frozen distribution...
  652. g = stats.gilbrat()
  653. # ... not on its methods
  654. result = getattr(g, method)(1)
  655. assert result == expected
  656. def test_burr_fisk_moment_gh13234_regression():
  657. vals0 = stats.burr.moment(1, 5, 4)
  658. assert isinstance(vals0, float)
  659. vals1 = stats.fisk.moment(1, 8)
  660. assert isinstance(vals1, float)
  661. def test_moments_with_array_gh12192_regression():
  662. # array loc and scalar scale
  663. vals0 = stats.norm.moment(order=1, loc=np.array([1, 2, 3]), scale=1)
  664. expected0 = np.array([1., 2., 3.])
  665. npt.assert_equal(vals0, expected0)
  666. # array loc and invalid scalar scale
  667. vals1 = stats.norm.moment(order=1, loc=np.array([1, 2, 3]), scale=-1)
  668. expected1 = np.array([np.nan, np.nan, np.nan])
  669. npt.assert_equal(vals1, expected1)
  670. # array loc and array scale with invalid entries
  671. vals2 = stats.norm.moment(order=1, loc=np.array([1, 2, 3]),
  672. scale=[-3, 1, 0])
  673. expected2 = np.array([np.nan, 2., np.nan])
  674. npt.assert_equal(vals2, expected2)
  675. # (loc == 0) & (scale < 0)
  676. vals3 = stats.norm.moment(order=2, loc=0, scale=-4)
  677. expected3 = np.nan
  678. npt.assert_equal(vals3, expected3)
  679. assert isinstance(vals3, expected3.__class__)
  680. # array loc with 0 entries and scale with invalid entries
  681. vals4 = stats.norm.moment(order=2, loc=[1, 0, 2], scale=[3, -4, -5])
  682. expected4 = np.array([10., np.nan, np.nan])
  683. npt.assert_equal(vals4, expected4)
  684. # all(loc == 0) & (array scale with invalid entries)
  685. vals5 = stats.norm.moment(order=2, loc=[0, 0, 0], scale=[5., -2, 100.])
  686. expected5 = np.array([25., np.nan, 10000.])
  687. npt.assert_equal(vals5, expected5)
  688. # all( (loc == 0) & (scale < 0) )
  689. vals6 = stats.norm.moment(order=2, loc=[0, 0, 0], scale=[-5., -2, -100.])
  690. expected6 = np.array([np.nan, np.nan, np.nan])
  691. npt.assert_equal(vals6, expected6)
  692. # scalar args, loc, and scale
  693. vals7 = stats.chi.moment(order=2, df=1, loc=0, scale=0)
  694. expected7 = np.nan
  695. npt.assert_equal(vals7, expected7)
  696. assert isinstance(vals7, expected7.__class__)
  697. # array args, scalar loc, and scalar scale
  698. vals8 = stats.chi.moment(order=2, df=[1, 2, 3], loc=0, scale=0)
  699. expected8 = np.array([np.nan, np.nan, np.nan])
  700. npt.assert_equal(vals8, expected8)
  701. # array args, array loc, and array scale
  702. vals9 = stats.chi.moment(order=2, df=[1, 2, 3], loc=[1., 0., 2.],
  703. scale=[1., -3., 0.])
  704. expected9 = np.array([3.59576912, np.nan, np.nan])
  705. npt.assert_allclose(vals9, expected9, rtol=1e-8)
  706. # (n > 4), all(loc != 0), and all(scale != 0)
  707. vals10 = stats.norm.moment(5, [1., 2.], [1., 2.])
  708. expected10 = np.array([26., 832.])
  709. npt.assert_allclose(vals10, expected10, rtol=1e-13)
  710. # test broadcasting and more
  711. a = [-1.1, 0, 1, 2.2, np.pi]
  712. b = [-1.1, 0, 1, 2.2, np.pi]
  713. loc = [-1.1, 0, np.sqrt(2)]
  714. scale = [-2.1, 0, 1, 2.2, np.pi]
  715. a = np.array(a).reshape((-1, 1, 1, 1))
  716. b = np.array(b).reshape((-1, 1, 1))
  717. loc = np.array(loc).reshape((-1, 1))
  718. scale = np.array(scale)
  719. vals11 = stats.beta.moment(order=2, a=a, b=b, loc=loc, scale=scale)
  720. a, b, loc, scale = np.broadcast_arrays(a, b, loc, scale)
  721. for i in np.ndenumerate(a):
  722. with np.errstate(invalid='ignore', divide='ignore'):
  723. i = i[0] # just get the index
  724. # check against same function with scalar input
  725. expected = stats.beta.moment(order=2, a=a[i], b=b[i],
  726. loc=loc[i], scale=scale[i])
  727. np.testing.assert_equal(vals11[i], expected)
  728. def test_broadcasting_in_moments_gh12192_regression():
  729. vals0 = stats.norm.moment(order=1, loc=np.array([1, 2, 3]), scale=[[1]])
  730. expected0 = np.array([[1., 2., 3.]])
  731. npt.assert_equal(vals0, expected0)
  732. assert vals0.shape == expected0.shape
  733. vals1 = stats.norm.moment(order=1, loc=np.array([[1], [2], [3]]),
  734. scale=[1, 2, 3])
  735. expected1 = np.array([[1., 1., 1.], [2., 2., 2.], [3., 3., 3.]])
  736. npt.assert_equal(vals1, expected1)
  737. assert vals1.shape == expected1.shape
  738. vals2 = stats.chi.moment(order=1, df=[1., 2., 3.], loc=0., scale=1.)
  739. expected2 = np.array([0.79788456, 1.25331414, 1.59576912])
  740. npt.assert_allclose(vals2, expected2, rtol=1e-8)
  741. assert vals2.shape == expected2.shape
  742. vals3 = stats.chi.moment(order=1, df=[[1.], [2.], [3.]], loc=[0., 1., 2.],
  743. scale=[-1., 0., 3.])
  744. expected3 = np.array([[np.nan, np.nan, 4.39365368],
  745. [np.nan, np.nan, 5.75994241],
  746. [np.nan, np.nan, 6.78730736]])
  747. npt.assert_allclose(vals3, expected3, rtol=1e-8)
  748. assert vals3.shape == expected3.shape
  749. def test_kappa3_array_gh13582():
  750. # https://github.com/scipy/scipy/pull/15140#issuecomment-994958241
  751. shapes = [0.5, 1.5, 2.5, 3.5, 4.5]
  752. moments = 'mvsk'
  753. res = np.array([[stats.kappa3.stats(shape, moments=moment)
  754. for shape in shapes] for moment in moments])
  755. res2 = np.array(stats.kappa3.stats(shapes, moments=moments))
  756. npt.assert_allclose(res, res2)
  757. @pytest.mark.xslow
  758. def test_kappa4_array_gh13582():
  759. h = np.array([-0.5, 2.5, 3.5, 4.5, -3])
  760. k = np.array([-0.5, 1, -1.5, 0, 3.5])
  761. moments = 'mvsk'
  762. res = np.array([[stats.kappa4.stats(h[i], k[i], moments=moment)
  763. for i in range(5)] for moment in moments])
  764. res2 = np.array(stats.kappa4.stats(h, k, moments=moments))
  765. npt.assert_allclose(res, res2)
  766. # https://github.com/scipy/scipy/pull/15250#discussion_r775112913
  767. h = np.array([-1, -1/4, -1/4, 1, -1, 0])
  768. k = np.array([1, 1, 1/2, -1/3, -1, 0])
  769. res = np.array([[stats.kappa4.stats(h[i], k[i], moments=moment)
  770. for i in range(6)] for moment in moments])
  771. res2 = np.array(stats.kappa4.stats(h, k, moments=moments))
  772. npt.assert_allclose(res, res2)
  773. # https://github.com/scipy/scipy/pull/15250#discussion_r775115021
  774. h = np.array([-1, -0.5, 1])
  775. k = np.array([-1, -0.5, 0, 1])[:, None]
  776. res2 = np.array(stats.kappa4.stats(h, k, moments=moments))
  777. assert res2.shape == (4, 4, 3)
  778. def test_frozen_attributes():
  779. # gh-14827 reported that all frozen distributions had both pmf and pdf
  780. # attributes; continuous should have pdf and discrete should have pmf.
  781. message = "'rv_continuous_frozen' object has no attribute"
  782. with pytest.raises(AttributeError, match=message):
  783. stats.norm().pmf
  784. with pytest.raises(AttributeError, match=message):
  785. stats.norm().logpmf
  786. stats.norm.pmf = "herring"
  787. frozen_norm = stats.norm()
  788. assert isinstance(frozen_norm, rv_continuous_frozen)
  789. delattr(stats.norm, 'pmf')
  790. def test_skewnorm_pdf_gh16038():
  791. rng = np.random.default_rng(0)
  792. x, a = -np.inf, 0
  793. npt.assert_equal(stats.skewnorm.pdf(x, a), stats.norm.pdf(x))
  794. x, a = rng.random(size=(3, 3)), rng.random(size=(3, 3))
  795. mask = rng.random(size=(3, 3)) < 0.5
  796. a[mask] = 0
  797. x_norm = x[mask]
  798. res = stats.skewnorm.pdf(x, a)
  799. npt.assert_equal(res[mask], stats.norm.pdf(x_norm))
  800. npt.assert_equal(res[~mask], stats.skewnorm.pdf(x[~mask], a[~mask]))
  801. # for scalar input, these functions should return scalar output
  802. scalar_out = [['rvs', []], ['pdf', [0]], ['logpdf', [0]], ['cdf', [0]],
  803. ['logcdf', [0]], ['sf', [0]], ['logsf', [0]], ['ppf', [0]],
  804. ['isf', [0]], ['moment', [1]], ['entropy', []], ['expect', []],
  805. ['median', []], ['mean', []], ['std', []], ['var', []]]
  806. scalars_out = [['interval', [0.95]], ['support', []], ['stats', ['mv']]]
  807. @pytest.mark.parametrize('case', scalar_out + scalars_out)
  808. def test_scalar_for_scalar(case):
  809. # Some rv_continuous functions returned 0d array instead of NumPy scalar
  810. # Guard against regression
  811. method_name, args = case
  812. method = getattr(stats.norm(), method_name)
  813. res = method(*args)
  814. if case in scalar_out:
  815. assert isinstance(res, np.number)
  816. else:
  817. assert isinstance(res[0], np.number)
  818. assert isinstance(res[1], np.number)
  819. def test_scalar_for_scalar2():
  820. # test methods that are not attributes of frozen distributions
  821. res = stats.norm.fit([1, 2, 3])
  822. assert isinstance(res[0], np.number)
  823. assert isinstance(res[1], np.number)
  824. res = stats.norm.fit_loc_scale([1, 2, 3])
  825. assert isinstance(res[0], np.number)
  826. assert isinstance(res[1], np.number)
  827. res = stats.norm.nnlf((0, 1), [1, 2, 3])
  828. assert isinstance(res, np.number)