common_tests.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450
  1. import pickle
  2. import re
  3. import numpy as np
  4. import numpy.testing as npt
  5. from numpy.testing import assert_allclose, assert_equal
  6. from pytest import raises as assert_raises
  7. import numpy.ma.testutils as ma_npt
  8. from scipy._lib._util import getfullargspec_no_self as _getfullargspec
  9. from scipy import stats
  10. def check_named_results(res, attributes, ma=False):
  11. for i, attr in enumerate(attributes):
  12. if ma:
  13. ma_npt.assert_equal(res[i], getattr(res, attr))
  14. else:
  15. npt.assert_equal(res[i], getattr(res, attr))
  16. def check_normalization(distfn, args, distname):
  17. norm_moment = distfn.moment(0, *args)
  18. npt.assert_allclose(norm_moment, 1.0)
  19. if distname == "rv_histogram_instance":
  20. atol, rtol = 1e-5, 0
  21. else:
  22. atol, rtol = 1e-7, 1e-7
  23. normalization_expect = distfn.expect(lambda x: 1, args=args)
  24. npt.assert_allclose(normalization_expect, 1.0, atol=atol, rtol=rtol,
  25. err_msg=distname, verbose=True)
  26. _a, _b = distfn.support(*args)
  27. normalization_cdf = distfn.cdf(_b, *args)
  28. npt.assert_allclose(normalization_cdf, 1.0)
  29. def check_moment(distfn, arg, m, v, msg):
  30. m1 = distfn.moment(1, *arg)
  31. m2 = distfn.moment(2, *arg)
  32. if not np.isinf(m):
  33. npt.assert_almost_equal(m1, m, decimal=10, err_msg=msg +
  34. ' - 1st moment')
  35. else: # or np.isnan(m1),
  36. npt.assert_(np.isinf(m1),
  37. msg + ' - 1st moment -infinite, m1=%s' % str(m1))
  38. if not np.isinf(v):
  39. npt.assert_almost_equal(m2 - m1 * m1, v, decimal=10, err_msg=msg +
  40. ' - 2ndt moment')
  41. else: # or np.isnan(m2),
  42. npt.assert_(np.isinf(m2),
  43. msg + ' - 2nd moment -infinite, m2=%s' % str(m2))
  44. def check_mean_expect(distfn, arg, m, msg):
  45. if np.isfinite(m):
  46. m1 = distfn.expect(lambda x: x, arg)
  47. npt.assert_almost_equal(m1, m, decimal=5, err_msg=msg +
  48. ' - 1st moment (expect)')
  49. def check_var_expect(distfn, arg, m, v, msg):
  50. kwargs = {'rtol': 5e-6} if msg == "rv_histogram_instance" else {}
  51. if np.isfinite(v):
  52. m2 = distfn.expect(lambda x: x*x, arg)
  53. npt.assert_allclose(m2, v + m*m, **kwargs)
  54. def check_skew_expect(distfn, arg, m, v, s, msg):
  55. if np.isfinite(s):
  56. m3e = distfn.expect(lambda x: np.power(x-m, 3), arg)
  57. npt.assert_almost_equal(m3e, s * np.power(v, 1.5),
  58. decimal=5, err_msg=msg + ' - skew')
  59. else:
  60. npt.assert_(np.isnan(s))
  61. def check_kurt_expect(distfn, arg, m, v, k, msg):
  62. if np.isfinite(k):
  63. m4e = distfn.expect(lambda x: np.power(x-m, 4), arg)
  64. npt.assert_allclose(m4e, (k + 3.) * np.power(v, 2), atol=1e-5, rtol=1e-5,
  65. err_msg=msg + ' - kurtosis')
  66. elif not np.isposinf(k):
  67. npt.assert_(np.isnan(k))
  68. def check_entropy(distfn, arg, msg):
  69. ent = distfn.entropy(*arg)
  70. npt.assert_(not np.isnan(ent), msg + 'test Entropy is nan')
  71. def check_private_entropy(distfn, args, superclass):
  72. # compare a generic _entropy with the distribution-specific implementation
  73. npt.assert_allclose(distfn._entropy(*args),
  74. superclass._entropy(distfn, *args))
  75. def check_entropy_vect_scale(distfn, arg):
  76. # check 2-d
  77. sc = np.asarray([[1, 2], [3, 4]])
  78. v_ent = distfn.entropy(*arg, scale=sc)
  79. s_ent = [distfn.entropy(*arg, scale=s) for s in sc.ravel()]
  80. s_ent = np.asarray(s_ent).reshape(v_ent.shape)
  81. assert_allclose(v_ent, s_ent, atol=1e-14)
  82. # check invalid value, check cast
  83. sc = [1, 2, -3]
  84. v_ent = distfn.entropy(*arg, scale=sc)
  85. s_ent = [distfn.entropy(*arg, scale=s) for s in sc]
  86. s_ent = np.asarray(s_ent).reshape(v_ent.shape)
  87. assert_allclose(v_ent, s_ent, atol=1e-14)
  88. def check_edge_support(distfn, args):
  89. # Make sure that x=self.a and self.b are handled correctly.
  90. x = distfn.support(*args)
  91. if isinstance(distfn, stats.rv_discrete):
  92. x = x[0]-1, x[1]
  93. npt.assert_equal(distfn.cdf(x, *args), [0.0, 1.0])
  94. npt.assert_equal(distfn.sf(x, *args), [1.0, 0.0])
  95. if distfn.name not in ('skellam', 'dlaplace'):
  96. # with a = -inf, log(0) generates warnings
  97. npt.assert_equal(distfn.logcdf(x, *args), [-np.inf, 0.0])
  98. npt.assert_equal(distfn.logsf(x, *args), [0.0, -np.inf])
  99. npt.assert_equal(distfn.ppf([0.0, 1.0], *args), x)
  100. npt.assert_equal(distfn.isf([0.0, 1.0], *args), x[::-1])
  101. # out-of-bounds for isf & ppf
  102. npt.assert_(np.isnan(distfn.isf([-1, 2], *args)).all())
  103. npt.assert_(np.isnan(distfn.ppf([-1, 2], *args)).all())
  104. def check_named_args(distfn, x, shape_args, defaults, meths):
  105. ## Check calling w/ named arguments.
  106. # check consistency of shapes, numargs and _parse signature
  107. signature = _getfullargspec(distfn._parse_args)
  108. npt.assert_(signature.varargs is None)
  109. npt.assert_(signature.varkw is None)
  110. npt.assert_(not signature.kwonlyargs)
  111. npt.assert_(list(signature.defaults) == list(defaults))
  112. shape_argnames = signature.args[:-len(defaults)] # a, b, loc=0, scale=1
  113. if distfn.shapes:
  114. shapes_ = distfn.shapes.replace(',', ' ').split()
  115. else:
  116. shapes_ = ''
  117. npt.assert_(len(shapes_) == distfn.numargs)
  118. npt.assert_(len(shapes_) == len(shape_argnames))
  119. # check calling w/ named arguments
  120. shape_args = list(shape_args)
  121. vals = [meth(x, *shape_args) for meth in meths]
  122. npt.assert_(np.all(np.isfinite(vals)))
  123. names, a, k = shape_argnames[:], shape_args[:], {}
  124. while names:
  125. k.update({names.pop(): a.pop()})
  126. v = [meth(x, *a, **k) for meth in meths]
  127. npt.assert_array_equal(vals, v)
  128. if 'n' not in k.keys():
  129. # `n` is first parameter of moment(), so can't be used as named arg
  130. npt.assert_equal(distfn.moment(1, *a, **k),
  131. distfn.moment(1, *shape_args))
  132. # unknown arguments should not go through:
  133. k.update({'kaboom': 42})
  134. assert_raises(TypeError, distfn.cdf, x, **k)
  135. def check_random_state_property(distfn, args):
  136. # check the random_state attribute of a distribution *instance*
  137. # This test fiddles with distfn.random_state. This breaks other tests,
  138. # hence need to save it and then restore.
  139. rndm = distfn.random_state
  140. # baseline: this relies on the global state
  141. np.random.seed(1234)
  142. distfn.random_state = None
  143. r0 = distfn.rvs(*args, size=8)
  144. # use an explicit instance-level random_state
  145. distfn.random_state = 1234
  146. r1 = distfn.rvs(*args, size=8)
  147. npt.assert_equal(r0, r1)
  148. distfn.random_state = np.random.RandomState(1234)
  149. r2 = distfn.rvs(*args, size=8)
  150. npt.assert_equal(r0, r2)
  151. # check that np.random.Generator can be used (numpy >= 1.17)
  152. if hasattr(np.random, 'default_rng'):
  153. # obtain a np.random.Generator object
  154. rng = np.random.default_rng(1234)
  155. distfn.rvs(*args, size=1, random_state=rng)
  156. # can override the instance-level random_state for an individual .rvs call
  157. distfn.random_state = 2
  158. orig_state = distfn.random_state.get_state()
  159. r3 = distfn.rvs(*args, size=8, random_state=np.random.RandomState(1234))
  160. npt.assert_equal(r0, r3)
  161. # ... and that does not alter the instance-level random_state!
  162. npt.assert_equal(distfn.random_state.get_state(), orig_state)
  163. # finally, restore the random_state
  164. distfn.random_state = rndm
  165. def check_meth_dtype(distfn, arg, meths):
  166. q0 = [0.25, 0.5, 0.75]
  167. x0 = distfn.ppf(q0, *arg)
  168. x_cast = [x0.astype(tp) for tp in
  169. (np.int_, np.float16, np.float32, np.float64)]
  170. for x in x_cast:
  171. # casting may have clipped the values, exclude those
  172. distfn._argcheck(*arg)
  173. x = x[(distfn.a < x) & (x < distfn.b)]
  174. for meth in meths:
  175. val = meth(x, *arg)
  176. npt.assert_(val.dtype == np.float_)
  177. def check_ppf_dtype(distfn, arg):
  178. q0 = np.asarray([0.25, 0.5, 0.75])
  179. q_cast = [q0.astype(tp) for tp in (np.float16, np.float32, np.float64)]
  180. for q in q_cast:
  181. for meth in [distfn.ppf, distfn.isf]:
  182. val = meth(q, *arg)
  183. npt.assert_(val.dtype == np.float_)
  184. def check_cmplx_deriv(distfn, arg):
  185. # Distributions allow complex arguments.
  186. def deriv(f, x, *arg):
  187. x = np.asarray(x)
  188. h = 1e-10
  189. return (f(x + h*1j, *arg)/h).imag
  190. x0 = distfn.ppf([0.25, 0.51, 0.75], *arg)
  191. x_cast = [x0.astype(tp) for tp in
  192. (np.int_, np.float16, np.float32, np.float64)]
  193. for x in x_cast:
  194. # casting may have clipped the values, exclude those
  195. distfn._argcheck(*arg)
  196. x = x[(distfn.a < x) & (x < distfn.b)]
  197. pdf, cdf, sf = distfn.pdf(x, *arg), distfn.cdf(x, *arg), distfn.sf(x, *arg)
  198. assert_allclose(deriv(distfn.cdf, x, *arg), pdf, rtol=1e-5)
  199. assert_allclose(deriv(distfn.logcdf, x, *arg), pdf/cdf, rtol=1e-5)
  200. assert_allclose(deriv(distfn.sf, x, *arg), -pdf, rtol=1e-5)
  201. assert_allclose(deriv(distfn.logsf, x, *arg), -pdf/sf, rtol=1e-5)
  202. assert_allclose(deriv(distfn.logpdf, x, *arg),
  203. deriv(distfn.pdf, x, *arg) / distfn.pdf(x, *arg),
  204. rtol=1e-5)
  205. def check_pickling(distfn, args):
  206. # check that a distribution instance pickles and unpickles
  207. # pay special attention to the random_state property
  208. # save the random_state (restore later)
  209. rndm = distfn.random_state
  210. # check unfrozen
  211. distfn.random_state = 1234
  212. distfn.rvs(*args, size=8)
  213. s = pickle.dumps(distfn)
  214. r0 = distfn.rvs(*args, size=8)
  215. unpickled = pickle.loads(s)
  216. r1 = unpickled.rvs(*args, size=8)
  217. npt.assert_equal(r0, r1)
  218. # also smoke test some methods
  219. medians = [distfn.ppf(0.5, *args), unpickled.ppf(0.5, *args)]
  220. npt.assert_equal(medians[0], medians[1])
  221. npt.assert_equal(distfn.cdf(medians[0], *args),
  222. unpickled.cdf(medians[1], *args))
  223. # check frozen pickling/unpickling with rvs
  224. frozen_dist = distfn(*args)
  225. pkl = pickle.dumps(frozen_dist)
  226. unpickled = pickle.loads(pkl)
  227. r0 = frozen_dist.rvs(size=8)
  228. r1 = unpickled.rvs(size=8)
  229. npt.assert_equal(r0, r1)
  230. # check pickling/unpickling of .fit method
  231. if hasattr(distfn, "fit"):
  232. fit_function = distfn.fit
  233. pickled_fit_function = pickle.dumps(fit_function)
  234. unpickled_fit_function = pickle.loads(pickled_fit_function)
  235. assert fit_function.__name__ == unpickled_fit_function.__name__ == "fit"
  236. # restore the random_state
  237. distfn.random_state = rndm
  238. def check_freezing(distfn, args):
  239. # regression test for gh-11089: freezing a distribution fails
  240. # if loc and/or scale are specified
  241. if isinstance(distfn, stats.rv_continuous):
  242. locscale = {'loc': 1, 'scale': 2}
  243. else:
  244. locscale = {'loc': 1}
  245. rv = distfn(*args, **locscale)
  246. assert rv.a == distfn(*args).a
  247. assert rv.b == distfn(*args).b
  248. def check_rvs_broadcast(distfunc, distname, allargs, shape, shape_only, otype):
  249. np.random.seed(123)
  250. sample = distfunc.rvs(*allargs)
  251. assert_equal(sample.shape, shape, "%s: rvs failed to broadcast" % distname)
  252. if not shape_only:
  253. rvs = np.vectorize(lambda *allargs: distfunc.rvs(*allargs), otypes=otype)
  254. np.random.seed(123)
  255. expected = rvs(*allargs)
  256. assert_allclose(sample, expected, rtol=1e-13)
  257. def check_deprecation_warning_gh5982_moment(distfn, arg, distname):
  258. # See description of cases that need to be tested in the definition of
  259. # scipy.stats.rv_generic.moment
  260. shapes = [] if distfn.shapes is None else distfn.shapes.split(", ")
  261. kwd_shapes = dict(zip(shapes, arg or [])) # dictionary of shape kwds
  262. n = kwd_shapes.pop('n', None)
  263. message1 = "moment() missing 1 required positional argument"
  264. message2 = "_parse_args() missing 1 required positional argument: 'n'"
  265. message3 = "moment() got multiple values for first argument"
  266. if 'n' in shapes:
  267. expected = distfn.mean(n=n, **kwd_shapes)
  268. # A1
  269. res = distfn.moment(1, n=n, **kwd_shapes)
  270. assert_allclose(res, expected)
  271. # A2
  272. with assert_raises(TypeError, match=re.escape(message1)):
  273. distfn.moment(n=n, **kwd_shapes)
  274. # A3
  275. # if `n` is not provided at all
  276. with assert_raises(TypeError, match=re.escape(message2)):
  277. distfn.moment(1, **kwd_shapes)
  278. # if `n` is provided as a positional argument
  279. res = distfn.moment(1, *arg)
  280. assert_allclose(res, expected)
  281. # A4
  282. with assert_raises(TypeError, match=re.escape(message1)):
  283. distfn.moment(**kwd_shapes)
  284. else:
  285. expected = distfn.mean(**kwd_shapes)
  286. # B1
  287. with assert_raises(TypeError, match=re.escape(message3)):
  288. res = distfn.moment(1, n=1, **kwd_shapes)
  289. # B2
  290. with np.testing.assert_warns(DeprecationWarning):
  291. res = distfn.moment(n=1, **kwd_shapes)
  292. assert_allclose(res, expected)
  293. # B3
  294. res = distfn.moment(1, *arg)
  295. assert_allclose(res, expected)
  296. # B4
  297. with assert_raises(TypeError, match=re.escape(message1)):
  298. distfn.moment(**kwd_shapes)
  299. def check_deprecation_warning_gh5982_interval(distfn, arg, distname):
  300. # See description of cases that need to be tested in the definition of
  301. # scipy.stats.rv_generic.moment
  302. shapes = [] if distfn.shapes is None else distfn.shapes.split(", ")
  303. kwd_shapes = dict(zip(shapes, arg or [])) # dictionary of shape kwds
  304. alpha = kwd_shapes.pop('alpha', None)
  305. def my_interval(*args, **kwds):
  306. return (distfn.ppf(0.25, *args, **kwds),
  307. distfn.ppf(0.75, *args, **kwds))
  308. message1 = "interval() missing 1 required positional argument"
  309. message2 = "_parse_args() missing 1 required positional argument: 'alpha'"
  310. message3 = "interval() got multiple values for first argument"
  311. if 'alpha' in shapes:
  312. expected = my_interval(alpha=alpha, **kwd_shapes)
  313. # A1
  314. res = distfn.interval(0.5, alpha=alpha, **kwd_shapes)
  315. assert_allclose(res, expected)
  316. # A2
  317. with assert_raises(TypeError, match=re.escape(message1)):
  318. distfn.interval(alpha=alpha, **kwd_shapes)
  319. # A3
  320. # if `alpha` is not provided at all
  321. with assert_raises(TypeError, match=re.escape(message2)):
  322. distfn.interval(0.5, **kwd_shapes)
  323. # if `alpha` is provided as a positional argument
  324. res = distfn.interval(0.5, *arg)
  325. assert_allclose(res, expected)
  326. # A4
  327. with assert_raises(TypeError, match=re.escape(message1)):
  328. distfn.interval(**kwd_shapes)
  329. else:
  330. expected = my_interval(**kwd_shapes)
  331. # B1
  332. with assert_raises(TypeError, match=re.escape(message3)):
  333. res = distfn.interval(0.5, alpha=1, **kwd_shapes)
  334. # B2
  335. with np.testing.assert_warns(DeprecationWarning):
  336. res = distfn.interval(alpha=0.5, **kwd_shapes)
  337. assert_allclose(res, expected)
  338. # B3
  339. res = distfn.interval(0.5, *arg)
  340. assert_allclose(res, expected)
  341. # B4
  342. with assert_raises(TypeError, match=re.escape(message1)):
  343. distfn.interval(**kwd_shapes)