test_cdflib.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424
  1. """
  2. Test cdflib functions versus mpmath, if available.
  3. The following functions still need tests:
  4. - ncfdtr
  5. - ncfdtri
  6. - ncfdtridfn
  7. - ncfdtridfd
  8. - ncfdtrinc
  9. - nbdtrik
  10. - nbdtrin
  11. - nrdtrimn
  12. - nrdtrisd
  13. - pdtrik
  14. - nctdtr
  15. - nctdtrit
  16. - nctdtridf
  17. - nctdtrinc
  18. """
  19. import itertools
  20. import numpy as np
  21. from numpy.testing import assert_equal, assert_allclose
  22. import pytest
  23. import scipy.special as sp
  24. from scipy.special._testutils import (
  25. MissingModule, check_version, FuncData)
  26. from scipy.special._mptestutils import (
  27. Arg, IntArg, get_args, mpf2float, assert_mpmath_equal)
  28. try:
  29. import mpmath
  30. except ImportError:
  31. mpmath = MissingModule('mpmath')
  32. class ProbArg:
  33. """Generate a set of probabilities on [0, 1]."""
  34. def __init__(self):
  35. # Include the endpoints for compatibility with Arg et. al.
  36. self.a = 0
  37. self.b = 1
  38. def values(self, n):
  39. """Return an array containing approximatively n numbers."""
  40. m = max(1, n//3)
  41. v1 = np.logspace(-30, np.log10(0.3), m)
  42. v2 = np.linspace(0.3, 0.7, m + 1, endpoint=False)[1:]
  43. v3 = 1 - np.logspace(np.log10(0.3), -15, m)
  44. v = np.r_[v1, v2, v3]
  45. return np.unique(v)
  46. class EndpointFilter:
  47. def __init__(self, a, b, rtol, atol):
  48. self.a = a
  49. self.b = b
  50. self.rtol = rtol
  51. self.atol = atol
  52. def __call__(self, x):
  53. mask1 = np.abs(x - self.a) < self.rtol*np.abs(self.a) + self.atol
  54. mask2 = np.abs(x - self.b) < self.rtol*np.abs(self.b) + self.atol
  55. return np.where(mask1 | mask2, False, True)
  56. class _CDFData:
  57. def __init__(self, spfunc, mpfunc, index, argspec, spfunc_first=True,
  58. dps=20, n=5000, rtol=None, atol=None,
  59. endpt_rtol=None, endpt_atol=None):
  60. self.spfunc = spfunc
  61. self.mpfunc = mpfunc
  62. self.index = index
  63. self.argspec = argspec
  64. self.spfunc_first = spfunc_first
  65. self.dps = dps
  66. self.n = n
  67. self.rtol = rtol
  68. self.atol = atol
  69. if not isinstance(argspec, list):
  70. self.endpt_rtol = None
  71. self.endpt_atol = None
  72. elif endpt_rtol is not None or endpt_atol is not None:
  73. if isinstance(endpt_rtol, list):
  74. self.endpt_rtol = endpt_rtol
  75. else:
  76. self.endpt_rtol = [endpt_rtol]*len(self.argspec)
  77. if isinstance(endpt_atol, list):
  78. self.endpt_atol = endpt_atol
  79. else:
  80. self.endpt_atol = [endpt_atol]*len(self.argspec)
  81. else:
  82. self.endpt_rtol = None
  83. self.endpt_atol = None
  84. def idmap(self, *args):
  85. if self.spfunc_first:
  86. res = self.spfunc(*args)
  87. if np.isnan(res):
  88. return np.nan
  89. args = list(args)
  90. args[self.index] = res
  91. with mpmath.workdps(self.dps):
  92. res = self.mpfunc(*tuple(args))
  93. # Imaginary parts are spurious
  94. res = mpf2float(res.real)
  95. else:
  96. with mpmath.workdps(self.dps):
  97. res = self.mpfunc(*args)
  98. res = mpf2float(res.real)
  99. args = list(args)
  100. args[self.index] = res
  101. res = self.spfunc(*tuple(args))
  102. return res
  103. def get_param_filter(self):
  104. if self.endpt_rtol is None and self.endpt_atol is None:
  105. return None
  106. filters = []
  107. for rtol, atol, spec in zip(self.endpt_rtol, self.endpt_atol, self.argspec):
  108. if rtol is None and atol is None:
  109. filters.append(None)
  110. continue
  111. elif rtol is None:
  112. rtol = 0.0
  113. elif atol is None:
  114. atol = 0.0
  115. filters.append(EndpointFilter(spec.a, spec.b, rtol, atol))
  116. return filters
  117. def check(self):
  118. # Generate values for the arguments
  119. args = get_args(self.argspec, self.n)
  120. param_filter = self.get_param_filter()
  121. param_columns = tuple(range(args.shape[1]))
  122. result_columns = args.shape[1]
  123. args = np.hstack((args, args[:,self.index].reshape(args.shape[0], 1)))
  124. FuncData(self.idmap, args,
  125. param_columns=param_columns, result_columns=result_columns,
  126. rtol=self.rtol, atol=self.atol, vectorized=False,
  127. param_filter=param_filter).check()
  128. def _assert_inverts(*a, **kw):
  129. d = _CDFData(*a, **kw)
  130. d.check()
  131. def _binomial_cdf(k, n, p):
  132. k, n, p = mpmath.mpf(k), mpmath.mpf(n), mpmath.mpf(p)
  133. if k <= 0:
  134. return mpmath.mpf(0)
  135. elif k >= n:
  136. return mpmath.mpf(1)
  137. onemp = mpmath.fsub(1, p, exact=True)
  138. return mpmath.betainc(n - k, k + 1, x2=onemp, regularized=True)
  139. def _f_cdf(dfn, dfd, x):
  140. if x < 0:
  141. return mpmath.mpf(0)
  142. dfn, dfd, x = mpmath.mpf(dfn), mpmath.mpf(dfd), mpmath.mpf(x)
  143. ub = dfn*x/(dfn*x + dfd)
  144. res = mpmath.betainc(dfn/2, dfd/2, x2=ub, regularized=True)
  145. return res
  146. def _student_t_cdf(df, t, dps=None):
  147. if dps is None:
  148. dps = mpmath.mp.dps
  149. with mpmath.workdps(dps):
  150. df, t = mpmath.mpf(df), mpmath.mpf(t)
  151. fac = mpmath.hyp2f1(0.5, 0.5*(df + 1), 1.5, -t**2/df)
  152. fac *= t*mpmath.gamma(0.5*(df + 1))
  153. fac /= mpmath.sqrt(mpmath.pi*df)*mpmath.gamma(0.5*df)
  154. return 0.5 + fac
  155. def _noncentral_chi_pdf(t, df, nc):
  156. res = mpmath.besseli(df/2 - 1, mpmath.sqrt(nc*t))
  157. res *= mpmath.exp(-(t + nc)/2)*(t/nc)**(df/4 - 1/2)/2
  158. return res
  159. def _noncentral_chi_cdf(x, df, nc, dps=None):
  160. if dps is None:
  161. dps = mpmath.mp.dps
  162. x, df, nc = mpmath.mpf(x), mpmath.mpf(df), mpmath.mpf(nc)
  163. with mpmath.workdps(dps):
  164. res = mpmath.quad(lambda t: _noncentral_chi_pdf(t, df, nc), [0, x])
  165. return res
  166. def _tukey_lmbda_quantile(p, lmbda):
  167. # For lmbda != 0
  168. return (p**lmbda - (1 - p)**lmbda)/lmbda
  169. @pytest.mark.slow
  170. @check_version(mpmath, '0.19')
  171. class TestCDFlib:
  172. @pytest.mark.xfail(run=False)
  173. def test_bdtrik(self):
  174. _assert_inverts(
  175. sp.bdtrik,
  176. _binomial_cdf,
  177. 0, [ProbArg(), IntArg(1, 1000), ProbArg()],
  178. rtol=1e-4)
  179. def test_bdtrin(self):
  180. _assert_inverts(
  181. sp.bdtrin,
  182. _binomial_cdf,
  183. 1, [IntArg(1, 1000), ProbArg(), ProbArg()],
  184. rtol=1e-4, endpt_atol=[None, None, 1e-6])
  185. def test_btdtria(self):
  186. _assert_inverts(
  187. sp.btdtria,
  188. lambda a, b, x: mpmath.betainc(a, b, x2=x, regularized=True),
  189. 0, [ProbArg(), Arg(0, 1e2, inclusive_a=False),
  190. Arg(0, 1, inclusive_a=False, inclusive_b=False)],
  191. rtol=1e-6)
  192. def test_btdtrib(self):
  193. # Use small values of a or mpmath doesn't converge
  194. _assert_inverts(
  195. sp.btdtrib,
  196. lambda a, b, x: mpmath.betainc(a, b, x2=x, regularized=True),
  197. 1, [Arg(0, 1e2, inclusive_a=False), ProbArg(),
  198. Arg(0, 1, inclusive_a=False, inclusive_b=False)],
  199. rtol=1e-7, endpt_atol=[None, 1e-18, 1e-15])
  200. @pytest.mark.xfail(run=False)
  201. def test_fdtridfd(self):
  202. _assert_inverts(
  203. sp.fdtridfd,
  204. _f_cdf,
  205. 1, [IntArg(1, 100), ProbArg(), Arg(0, 100, inclusive_a=False)],
  206. rtol=1e-7)
  207. def test_gdtria(self):
  208. _assert_inverts(
  209. sp.gdtria,
  210. lambda a, b, x: mpmath.gammainc(b, b=a*x, regularized=True),
  211. 0, [ProbArg(), Arg(0, 1e3, inclusive_a=False),
  212. Arg(0, 1e4, inclusive_a=False)], rtol=1e-7,
  213. endpt_atol=[None, 1e-7, 1e-10])
  214. def test_gdtrib(self):
  215. # Use small values of a and x or mpmath doesn't converge
  216. _assert_inverts(
  217. sp.gdtrib,
  218. lambda a, b, x: mpmath.gammainc(b, b=a*x, regularized=True),
  219. 1, [Arg(0, 1e2, inclusive_a=False), ProbArg(),
  220. Arg(0, 1e3, inclusive_a=False)], rtol=1e-5)
  221. def test_gdtrix(self):
  222. _assert_inverts(
  223. sp.gdtrix,
  224. lambda a, b, x: mpmath.gammainc(b, b=a*x, regularized=True),
  225. 2, [Arg(0, 1e3, inclusive_a=False), Arg(0, 1e3, inclusive_a=False),
  226. ProbArg()], rtol=1e-7,
  227. endpt_atol=[None, 1e-7, 1e-10])
  228. def test_stdtr(self):
  229. # Ideally the left endpoint for Arg() should be 0.
  230. assert_mpmath_equal(
  231. sp.stdtr,
  232. _student_t_cdf,
  233. [IntArg(1, 100), Arg(1e-10, np.inf)], rtol=1e-7)
  234. @pytest.mark.xfail(run=False)
  235. def test_stdtridf(self):
  236. _assert_inverts(
  237. sp.stdtridf,
  238. _student_t_cdf,
  239. 0, [ProbArg(), Arg()], rtol=1e-7)
  240. def test_stdtrit(self):
  241. _assert_inverts(
  242. sp.stdtrit,
  243. _student_t_cdf,
  244. 1, [IntArg(1, 100), ProbArg()], rtol=1e-7,
  245. endpt_atol=[None, 1e-10])
  246. def test_chdtriv(self):
  247. _assert_inverts(
  248. sp.chdtriv,
  249. lambda v, x: mpmath.gammainc(v/2, b=x/2, regularized=True),
  250. 0, [ProbArg(), IntArg(1, 100)], rtol=1e-4)
  251. @pytest.mark.xfail(run=False)
  252. def test_chndtridf(self):
  253. # Use a larger atol since mpmath is doing numerical integration
  254. _assert_inverts(
  255. sp.chndtridf,
  256. _noncentral_chi_cdf,
  257. 1, [Arg(0, 100, inclusive_a=False), ProbArg(),
  258. Arg(0, 100, inclusive_a=False)],
  259. n=1000, rtol=1e-4, atol=1e-15)
  260. @pytest.mark.xfail(run=False)
  261. def test_chndtrinc(self):
  262. # Use a larger atol since mpmath is doing numerical integration
  263. _assert_inverts(
  264. sp.chndtrinc,
  265. _noncentral_chi_cdf,
  266. 2, [Arg(0, 100, inclusive_a=False), IntArg(1, 100), ProbArg()],
  267. n=1000, rtol=1e-4, atol=1e-15)
  268. def test_chndtrix(self):
  269. # Use a larger atol since mpmath is doing numerical integration
  270. _assert_inverts(
  271. sp.chndtrix,
  272. _noncentral_chi_cdf,
  273. 0, [ProbArg(), IntArg(1, 100), Arg(0, 100, inclusive_a=False)],
  274. n=1000, rtol=1e-4, atol=1e-15,
  275. endpt_atol=[1e-6, None, None])
  276. def test_tklmbda_zero_shape(self):
  277. # When lmbda = 0 the CDF has a simple closed form
  278. one = mpmath.mpf(1)
  279. assert_mpmath_equal(
  280. lambda x: sp.tklmbda(x, 0),
  281. lambda x: one/(mpmath.exp(-x) + one),
  282. [Arg()], rtol=1e-7)
  283. def test_tklmbda_neg_shape(self):
  284. _assert_inverts(
  285. sp.tklmbda,
  286. _tukey_lmbda_quantile,
  287. 0, [ProbArg(), Arg(-25, 0, inclusive_b=False)],
  288. spfunc_first=False, rtol=1e-5,
  289. endpt_atol=[1e-9, 1e-5])
  290. @pytest.mark.xfail(run=False)
  291. def test_tklmbda_pos_shape(self):
  292. _assert_inverts(
  293. sp.tklmbda,
  294. _tukey_lmbda_quantile,
  295. 0, [ProbArg(), Arg(0, 100, inclusive_a=False)],
  296. spfunc_first=False, rtol=1e-5)
  297. def test_nonfinite():
  298. funcs = [
  299. ("btdtria", 3),
  300. ("btdtrib", 3),
  301. ("bdtrik", 3),
  302. ("bdtrin", 3),
  303. ("chdtriv", 2),
  304. ("chndtr", 3),
  305. ("chndtrix", 3),
  306. ("chndtridf", 3),
  307. ("chndtrinc", 3),
  308. ("fdtridfd", 3),
  309. ("ncfdtr", 4),
  310. ("ncfdtri", 4),
  311. ("ncfdtridfn", 4),
  312. ("ncfdtridfd", 4),
  313. ("ncfdtrinc", 4),
  314. ("gdtrix", 3),
  315. ("gdtrib", 3),
  316. ("gdtria", 3),
  317. ("nbdtrik", 3),
  318. ("nbdtrin", 3),
  319. ("nrdtrimn", 3),
  320. ("nrdtrisd", 3),
  321. ("pdtrik", 2),
  322. ("stdtr", 2),
  323. ("stdtrit", 2),
  324. ("stdtridf", 2),
  325. ("nctdtr", 3),
  326. ("nctdtrit", 3),
  327. ("nctdtridf", 3),
  328. ("nctdtrinc", 3),
  329. ("tklmbda", 2),
  330. ]
  331. np.random.seed(1)
  332. for func, numargs in funcs:
  333. func = getattr(sp, func)
  334. args_choices = [(float(x), np.nan, np.inf, -np.inf) for x in
  335. np.random.rand(numargs)]
  336. for args in itertools.product(*args_choices):
  337. res = func(*args)
  338. if any(np.isnan(x) for x in args):
  339. # Nan inputs should result to nan output
  340. assert_equal(res, np.nan)
  341. else:
  342. # All other inputs should return something (but not
  343. # raise exceptions or cause hangs)
  344. pass
  345. def test_chndtrix_gh2158():
  346. # test that gh-2158 is resolved; previously this blew up
  347. res = sp.chndtrix(0.999999, 2, np.arange(20.)+1e-6)
  348. # Generated in R
  349. # options(digits=16)
  350. # ncp <- seq(0, 19) + 1e-6
  351. # print(qchisq(0.999999, df = 2, ncp = ncp))
  352. res_exp = [27.63103493142305, 35.25728589950540, 39.97396073236288,
  353. 43.88033702110538, 47.35206403482798, 50.54112500166103,
  354. 53.52720257322766, 56.35830042867810, 59.06600769498512,
  355. 61.67243118946381, 64.19376191277179, 66.64228141346548,
  356. 69.02756927200180, 71.35726934749408, 73.63759723904816,
  357. 75.87368842650227, 78.06984431185720, 80.22971052389806,
  358. 82.35640899964173, 84.45263768373256]
  359. assert_allclose(res, res_exp)