test_cython_special.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352
  1. from __future__ import annotations
  2. from typing import List, Tuple, Callable, Optional
  3. import pytest
  4. from itertools import product
  5. from numpy.testing import assert_allclose, suppress_warnings
  6. from scipy import special
  7. from scipy.special import cython_special
  8. bint_points = [True, False]
  9. int_points = [-10, -1, 1, 10]
  10. real_points = [-10.0, -1.0, 1.0, 10.0]
  11. complex_points = [complex(*tup) for tup in product(real_points, repeat=2)]
  12. CYTHON_SIGNATURE_MAP = {
  13. 'b': 'bint',
  14. 'f': 'float',
  15. 'd': 'double',
  16. 'g': 'long double',
  17. 'F': 'float complex',
  18. 'D': 'double complex',
  19. 'G': 'long double complex',
  20. 'i': 'int',
  21. 'l': 'long'
  22. }
  23. TEST_POINTS = {
  24. 'b': bint_points,
  25. 'f': real_points,
  26. 'd': real_points,
  27. 'g': real_points,
  28. 'F': complex_points,
  29. 'D': complex_points,
  30. 'G': complex_points,
  31. 'i': int_points,
  32. 'l': int_points,
  33. }
  34. PARAMS: List[Tuple[Callable, Callable, Tuple[str, ...], Optional[str]]] = [
  35. (special.agm, cython_special.agm, ('dd',), None),
  36. (special.airy, cython_special._airy_pywrap, ('d', 'D'), None),
  37. (special.airye, cython_special._airye_pywrap, ('d', 'D'), None),
  38. (special.bdtr, cython_special.bdtr, ('dld', 'ddd'), None),
  39. (special.bdtrc, cython_special.bdtrc, ('dld', 'ddd'), None),
  40. (special.bdtri, cython_special.bdtri, ('dld', 'ddd'), None),
  41. (special.bdtrik, cython_special.bdtrik, ('ddd',), None),
  42. (special.bdtrin, cython_special.bdtrin, ('ddd',), None),
  43. (special.bei, cython_special.bei, ('d',), None),
  44. (special.beip, cython_special.beip, ('d',), None),
  45. (special.ber, cython_special.ber, ('d',), None),
  46. (special.berp, cython_special.berp, ('d',), None),
  47. (special.besselpoly, cython_special.besselpoly, ('ddd',), None),
  48. (special.beta, cython_special.beta, ('dd',), None),
  49. (special.betainc, cython_special.betainc, ('ddd',), None),
  50. (special.betaincinv, cython_special.betaincinv, ('ddd',), None),
  51. (special.betaln, cython_special.betaln, ('dd',), None),
  52. (special.binom, cython_special.binom, ('dd',), None),
  53. (special.boxcox, cython_special.boxcox, ('dd',), None),
  54. (special.boxcox1p, cython_special.boxcox1p, ('dd',), None),
  55. (special.btdtr, cython_special.btdtr, ('ddd',), None),
  56. (special.btdtri, cython_special.btdtri, ('ddd',), None),
  57. (special.btdtria, cython_special.btdtria, ('ddd',), None),
  58. (special.btdtrib, cython_special.btdtrib, ('ddd',), None),
  59. (special.cbrt, cython_special.cbrt, ('d',), None),
  60. (special.chdtr, cython_special.chdtr, ('dd',), None),
  61. (special.chdtrc, cython_special.chdtrc, ('dd',), None),
  62. (special.chdtri, cython_special.chdtri, ('dd',), None),
  63. (special.chdtriv, cython_special.chdtriv, ('dd',), None),
  64. (special.chndtr, cython_special.chndtr, ('ddd',), None),
  65. (special.chndtridf, cython_special.chndtridf, ('ddd',), None),
  66. (special.chndtrinc, cython_special.chndtrinc, ('ddd',), None),
  67. (special.chndtrix, cython_special.chndtrix, ('ddd',), None),
  68. (special.cosdg, cython_special.cosdg, ('d',), None),
  69. (special.cosm1, cython_special.cosm1, ('d',), None),
  70. (special.cotdg, cython_special.cotdg, ('d',), None),
  71. (special.dawsn, cython_special.dawsn, ('d', 'D'), None),
  72. (special.ellipe, cython_special.ellipe, ('d',), None),
  73. (special.ellipeinc, cython_special.ellipeinc, ('dd',), None),
  74. (special.ellipj, cython_special._ellipj_pywrap, ('dd',), None),
  75. (special.ellipkinc, cython_special.ellipkinc, ('dd',), None),
  76. (special.ellipkm1, cython_special.ellipkm1, ('d',), None),
  77. (special.ellipk, cython_special.ellipk, ('d',), None),
  78. (special.elliprc, cython_special.elliprc, ('dd', 'DD'), None),
  79. (special.elliprd, cython_special.elliprd, ('ddd', 'DDD'), None),
  80. (special.elliprf, cython_special.elliprf, ('ddd', 'DDD'), None),
  81. (special.elliprg, cython_special.elliprg, ('ddd', 'DDD'), None),
  82. (special.elliprj, cython_special.elliprj, ('dddd', 'DDDD'), None),
  83. (special.entr, cython_special.entr, ('d',), None),
  84. (special.erf, cython_special.erf, ('d', 'D'), None),
  85. (special.erfc, cython_special.erfc, ('d', 'D'), None),
  86. (special.erfcx, cython_special.erfcx, ('d', 'D'), None),
  87. (special.erfi, cython_special.erfi, ('d', 'D'), None),
  88. (special.erfinv, cython_special.erfinv, ('d',), None),
  89. (special.erfcinv, cython_special.erfcinv, ('d',), None),
  90. (special.eval_chebyc, cython_special.eval_chebyc, ('dd', 'dD', 'ld'), None),
  91. (special.eval_chebys, cython_special.eval_chebys, ('dd', 'dD', 'ld'),
  92. 'd and l differ for negative int'),
  93. (special.eval_chebyt, cython_special.eval_chebyt, ('dd', 'dD', 'ld'),
  94. 'd and l differ for negative int'),
  95. (special.eval_chebyu, cython_special.eval_chebyu, ('dd', 'dD', 'ld'),
  96. 'd and l differ for negative int'),
  97. (special.eval_gegenbauer, cython_special.eval_gegenbauer, ('ddd', 'ddD', 'ldd'),
  98. 'd and l differ for negative int'),
  99. (special.eval_genlaguerre, cython_special.eval_genlaguerre, ('ddd', 'ddD', 'ldd'),
  100. 'd and l differ for negative int'),
  101. (special.eval_hermite, cython_special.eval_hermite, ('ld',), None),
  102. (special.eval_hermitenorm, cython_special.eval_hermitenorm, ('ld',), None),
  103. (special.eval_jacobi, cython_special.eval_jacobi, ('dddd', 'dddD', 'lddd'),
  104. 'd and l differ for negative int'),
  105. (special.eval_laguerre, cython_special.eval_laguerre, ('dd', 'dD', 'ld'),
  106. 'd and l differ for negative int'),
  107. (special.eval_legendre, cython_special.eval_legendre, ('dd', 'dD', 'ld'), None),
  108. (special.eval_sh_chebyt, cython_special.eval_sh_chebyt, ('dd', 'dD', 'ld'), None),
  109. (special.eval_sh_chebyu, cython_special.eval_sh_chebyu, ('dd', 'dD', 'ld'),
  110. 'd and l differ for negative int'),
  111. (special.eval_sh_jacobi, cython_special.eval_sh_jacobi, ('dddd', 'dddD', 'lddd'),
  112. 'd and l differ for negative int'),
  113. (special.eval_sh_legendre, cython_special.eval_sh_legendre, ('dd', 'dD', 'ld'), None),
  114. (special.exp1, cython_special.exp1, ('d', 'D'), None),
  115. (special.exp10, cython_special.exp10, ('d',), None),
  116. (special.exp2, cython_special.exp2, ('d',), None),
  117. (special.expi, cython_special.expi, ('d', 'D'), None),
  118. (special.expit, cython_special.expit, ('f', 'd', 'g'), None),
  119. (special.expm1, cython_special.expm1, ('d', 'D'), None),
  120. (special.expn, cython_special.expn, ('ld', 'dd'), None),
  121. (special.exprel, cython_special.exprel, ('d',), None),
  122. (special.fdtr, cython_special.fdtr, ('ddd',), None),
  123. (special.fdtrc, cython_special.fdtrc, ('ddd',), None),
  124. (special.fdtri, cython_special.fdtri, ('ddd',), None),
  125. (special.fdtridfd, cython_special.fdtridfd, ('ddd',), None),
  126. (special.fresnel, cython_special._fresnel_pywrap, ('d', 'D'), None),
  127. (special.gamma, cython_special.gamma, ('d', 'D'), None),
  128. (special.gammainc, cython_special.gammainc, ('dd',), None),
  129. (special.gammaincc, cython_special.gammaincc, ('dd',), None),
  130. (special.gammainccinv, cython_special.gammainccinv, ('dd',), None),
  131. (special.gammaincinv, cython_special.gammaincinv, ('dd',), None),
  132. (special.gammaln, cython_special.gammaln, ('d',), None),
  133. (special.gammasgn, cython_special.gammasgn, ('d',), None),
  134. (special.gdtr, cython_special.gdtr, ('ddd',), None),
  135. (special.gdtrc, cython_special.gdtrc, ('ddd',), None),
  136. (special.gdtria, cython_special.gdtria, ('ddd',), None),
  137. (special.gdtrib, cython_special.gdtrib, ('ddd',), None),
  138. (special.gdtrix, cython_special.gdtrix, ('ddd',), None),
  139. (special.hankel1, cython_special.hankel1, ('dD',), None),
  140. (special.hankel1e, cython_special.hankel1e, ('dD',), None),
  141. (special.hankel2, cython_special.hankel2, ('dD',), None),
  142. (special.hankel2e, cython_special.hankel2e, ('dD',), None),
  143. (special.huber, cython_special.huber, ('dd',), None),
  144. (special.hyp0f1, cython_special.hyp0f1, ('dd', 'dD'), None),
  145. (special.hyp1f1, cython_special.hyp1f1, ('ddd', 'ddD'), None),
  146. (special.hyp2f1, cython_special.hyp2f1, ('dddd', 'dddD'), None),
  147. (special.hyperu, cython_special.hyperu, ('ddd',), None),
  148. (special.i0, cython_special.i0, ('d',), None),
  149. (special.i0e, cython_special.i0e, ('d',), None),
  150. (special.i1, cython_special.i1, ('d',), None),
  151. (special.i1e, cython_special.i1e, ('d',), None),
  152. (special.inv_boxcox, cython_special.inv_boxcox, ('dd',), None),
  153. (special.inv_boxcox1p, cython_special.inv_boxcox1p, ('dd',), None),
  154. (special.it2i0k0, cython_special._it2i0k0_pywrap, ('d',), None),
  155. (special.it2j0y0, cython_special._it2j0y0_pywrap, ('d',), None),
  156. (special.it2struve0, cython_special.it2struve0, ('d',), None),
  157. (special.itairy, cython_special._itairy_pywrap, ('d',), None),
  158. (special.iti0k0, cython_special._iti0k0_pywrap, ('d',), None),
  159. (special.itj0y0, cython_special._itj0y0_pywrap, ('d',), None),
  160. (special.itmodstruve0, cython_special.itmodstruve0, ('d',), None),
  161. (special.itstruve0, cython_special.itstruve0, ('d',), None),
  162. (special.iv, cython_special.iv, ('dd', 'dD'), None),
  163. (special.ive, cython_special.ive, ('dd', 'dD'), None),
  164. (special.j0, cython_special.j0, ('d',), None),
  165. (special.j1, cython_special.j1, ('d',), None),
  166. (special.jv, cython_special.jv, ('dd', 'dD'), None),
  167. (special.jve, cython_special.jve, ('dd', 'dD'), None),
  168. (special.k0, cython_special.k0, ('d',), None),
  169. (special.k0e, cython_special.k0e, ('d',), None),
  170. (special.k1, cython_special.k1, ('d',), None),
  171. (special.k1e, cython_special.k1e, ('d',), None),
  172. (special.kei, cython_special.kei, ('d',), None),
  173. (special.keip, cython_special.keip, ('d',), None),
  174. (special.kelvin, cython_special._kelvin_pywrap, ('d',), None),
  175. (special.ker, cython_special.ker, ('d',), None),
  176. (special.kerp, cython_special.kerp, ('d',), None),
  177. (special.kl_div, cython_special.kl_div, ('dd',), None),
  178. (special.kn, cython_special.kn, ('ld', 'dd'), None),
  179. (special.kolmogi, cython_special.kolmogi, ('d',), None),
  180. (special.kolmogorov, cython_special.kolmogorov, ('d',), None),
  181. (special.kv, cython_special.kv, ('dd', 'dD'), None),
  182. (special.kve, cython_special.kve, ('dd', 'dD'), None),
  183. (special.log1p, cython_special.log1p, ('d', 'D'), None),
  184. (special.log_expit, cython_special.log_expit, ('f', 'd', 'g'), None),
  185. (special.log_ndtr, cython_special.log_ndtr, ('d', 'D'), None),
  186. (special.ndtri_exp, cython_special.ndtri_exp, ('d',), None),
  187. (special.loggamma, cython_special.loggamma, ('D',), None),
  188. (special.logit, cython_special.logit, ('f', 'd', 'g'), None),
  189. (special.lpmv, cython_special.lpmv, ('ddd',), None),
  190. (special.mathieu_a, cython_special.mathieu_a, ('dd',), None),
  191. (special.mathieu_b, cython_special.mathieu_b, ('dd',), None),
  192. (special.mathieu_cem, cython_special._mathieu_cem_pywrap, ('ddd',), None),
  193. (special.mathieu_modcem1, cython_special._mathieu_modcem1_pywrap, ('ddd',), None),
  194. (special.mathieu_modcem2, cython_special._mathieu_modcem2_pywrap, ('ddd',), None),
  195. (special.mathieu_modsem1, cython_special._mathieu_modsem1_pywrap, ('ddd',), None),
  196. (special.mathieu_modsem2, cython_special._mathieu_modsem2_pywrap, ('ddd',), None),
  197. (special.mathieu_sem, cython_special._mathieu_sem_pywrap, ('ddd',), None),
  198. (special.modfresnelm, cython_special._modfresnelm_pywrap, ('d',), None),
  199. (special.modfresnelp, cython_special._modfresnelp_pywrap, ('d',), None),
  200. (special.modstruve, cython_special.modstruve, ('dd',), None),
  201. (special.nbdtr, cython_special.nbdtr, ('lld', 'ddd'), None),
  202. (special.nbdtrc, cython_special.nbdtrc, ('lld', 'ddd'), None),
  203. (special.nbdtri, cython_special.nbdtri, ('lld', 'ddd'), None),
  204. (special.nbdtrik, cython_special.nbdtrik, ('ddd',), None),
  205. (special.nbdtrin, cython_special.nbdtrin, ('ddd',), None),
  206. (special.ncfdtr, cython_special.ncfdtr, ('dddd',), None),
  207. (special.ncfdtri, cython_special.ncfdtri, ('dddd',), None),
  208. (special.ncfdtridfd, cython_special.ncfdtridfd, ('dddd',), None),
  209. (special.ncfdtridfn, cython_special.ncfdtridfn, ('dddd',), None),
  210. (special.ncfdtrinc, cython_special.ncfdtrinc, ('dddd',), None),
  211. (special.nctdtr, cython_special.nctdtr, ('ddd',), None),
  212. (special.nctdtridf, cython_special.nctdtridf, ('ddd',), None),
  213. (special.nctdtrinc, cython_special.nctdtrinc, ('ddd',), None),
  214. (special.nctdtrit, cython_special.nctdtrit, ('ddd',), None),
  215. (special.ndtr, cython_special.ndtr, ('d', 'D'), None),
  216. (special.ndtri, cython_special.ndtri, ('d',), None),
  217. (special.nrdtrimn, cython_special.nrdtrimn, ('ddd',), None),
  218. (special.nrdtrisd, cython_special.nrdtrisd, ('ddd',), None),
  219. (special.obl_ang1, cython_special._obl_ang1_pywrap, ('dddd',), None),
  220. (special.obl_ang1_cv, cython_special._obl_ang1_cv_pywrap, ('ddddd',), None),
  221. (special.obl_cv, cython_special.obl_cv, ('ddd',), None),
  222. (special.obl_rad1, cython_special._obl_rad1_pywrap, ('dddd',), "see gh-6211"),
  223. (special.obl_rad1_cv, cython_special._obl_rad1_cv_pywrap, ('ddddd',), "see gh-6211"),
  224. (special.obl_rad2, cython_special._obl_rad2_pywrap, ('dddd',), "see gh-6211"),
  225. (special.obl_rad2_cv, cython_special._obl_rad2_cv_pywrap, ('ddddd',), "see gh-6211"),
  226. (special.pbdv, cython_special._pbdv_pywrap, ('dd',), None),
  227. (special.pbvv, cython_special._pbvv_pywrap, ('dd',), None),
  228. (special.pbwa, cython_special._pbwa_pywrap, ('dd',), None),
  229. (special.pdtr, cython_special.pdtr, ('dd', 'dd'), None),
  230. (special.pdtrc, cython_special.pdtrc, ('dd', 'dd'), None),
  231. (special.pdtri, cython_special.pdtri, ('ld', 'dd'), None),
  232. (special.pdtrik, cython_special.pdtrik, ('dd',), None),
  233. (special.poch, cython_special.poch, ('dd',), None),
  234. (special.powm1, cython_special.powm1, ('dd',), None),
  235. (special.pro_ang1, cython_special._pro_ang1_pywrap, ('dddd',), None),
  236. (special.pro_ang1_cv, cython_special._pro_ang1_cv_pywrap, ('ddddd',), None),
  237. (special.pro_cv, cython_special.pro_cv, ('ddd',), None),
  238. (special.pro_rad1, cython_special._pro_rad1_pywrap, ('dddd',), "see gh-6211"),
  239. (special.pro_rad1_cv, cython_special._pro_rad1_cv_pywrap, ('ddddd',), "see gh-6211"),
  240. (special.pro_rad2, cython_special._pro_rad2_pywrap, ('dddd',), "see gh-6211"),
  241. (special.pro_rad2_cv, cython_special._pro_rad2_cv_pywrap, ('ddddd',), "see gh-6211"),
  242. (special.pseudo_huber, cython_special.pseudo_huber, ('dd',), None),
  243. (special.psi, cython_special.psi, ('d', 'D'), None),
  244. (special.radian, cython_special.radian, ('ddd',), None),
  245. (special.rel_entr, cython_special.rel_entr, ('dd',), None),
  246. (special.rgamma, cython_special.rgamma, ('d', 'D'), None),
  247. (special.round, cython_special.round, ('d',), None),
  248. (special.spherical_jn, cython_special.spherical_jn, ('ld', 'ldb', 'lD', 'lDb'), None),
  249. (special.spherical_yn, cython_special.spherical_yn, ('ld', 'ldb', 'lD', 'lDb'), None),
  250. (special.spherical_in, cython_special.spherical_in, ('ld', 'ldb', 'lD', 'lDb'), None),
  251. (special.spherical_kn, cython_special.spherical_kn, ('ld', 'ldb', 'lD', 'lDb'), None),
  252. (special.shichi, cython_special._shichi_pywrap, ('d', 'D'), None),
  253. (special.sici, cython_special._sici_pywrap, ('d', 'D'), None),
  254. (special.sindg, cython_special.sindg, ('d',), None),
  255. (special.smirnov, cython_special.smirnov, ('ld', 'dd'), None),
  256. (special.smirnovi, cython_special.smirnovi, ('ld', 'dd'), None),
  257. (special.spence, cython_special.spence, ('d', 'D'), None),
  258. (special.sph_harm, cython_special.sph_harm, ('lldd', 'dddd'), None),
  259. (special.stdtr, cython_special.stdtr, ('dd',), None),
  260. (special.stdtridf, cython_special.stdtridf, ('dd',), None),
  261. (special.stdtrit, cython_special.stdtrit, ('dd',), None),
  262. (special.struve, cython_special.struve, ('dd',), None),
  263. (special.tandg, cython_special.tandg, ('d',), None),
  264. (special.tklmbda, cython_special.tklmbda, ('dd',), None),
  265. (special.voigt_profile, cython_special.voigt_profile, ('ddd',), None),
  266. (special.wofz, cython_special.wofz, ('D',), None),
  267. (special.wright_bessel, cython_special.wright_bessel, ('ddd',), None),
  268. (special.wrightomega, cython_special.wrightomega, ('D',), None),
  269. (special.xlog1py, cython_special.xlog1py, ('dd', 'DD'), None),
  270. (special.xlogy, cython_special.xlogy, ('dd', 'DD'), None),
  271. (special.y0, cython_special.y0, ('d',), None),
  272. (special.y1, cython_special.y1, ('d',), None),
  273. (special.yn, cython_special.yn, ('ld', 'dd'), None),
  274. (special.yv, cython_special.yv, ('dd', 'dD'), None),
  275. (special.yve, cython_special.yve, ('dd', 'dD'), None),
  276. (special.zetac, cython_special.zetac, ('d',), None),
  277. (special.owens_t, cython_special.owens_t, ('dd',), None)
  278. ]
  279. IDS = [x[0].__name__ for x in PARAMS]
  280. def _generate_test_points(typecodes):
  281. axes = tuple(TEST_POINTS[x] for x in typecodes)
  282. pts = list(product(*axes))
  283. return pts
  284. def test_cython_api_completeness():
  285. # Check that everything is tested
  286. for name in dir(cython_special):
  287. func = getattr(cython_special, name)
  288. if callable(func) and not name.startswith('_'):
  289. for _, cyfun, _, _ in PARAMS:
  290. if cyfun is func:
  291. break
  292. else:
  293. raise RuntimeError(f"{name} missing from tests!")
  294. @pytest.mark.parametrize("param", PARAMS, ids=IDS)
  295. def test_cython_api(param):
  296. pyfunc, cyfunc, specializations, knownfailure = param
  297. if knownfailure:
  298. pytest.xfail(reason=knownfailure)
  299. # Check which parameters are expected to be fused types
  300. max_params = max(len(spec) for spec in specializations)
  301. values = [set() for _ in range(max_params)]
  302. for typecodes in specializations:
  303. for j, v in enumerate(typecodes):
  304. values[j].add(v)
  305. seen = set()
  306. is_fused_code = [False] * len(values)
  307. for j, v in enumerate(values):
  308. vv = tuple(sorted(v))
  309. if vv in seen:
  310. continue
  311. is_fused_code[j] = (len(v) > 1)
  312. seen.add(vv)
  313. # Check results
  314. for typecodes in specializations:
  315. # Pick the correct specialized function
  316. signature = [CYTHON_SIGNATURE_MAP[code]
  317. for j, code in enumerate(typecodes)
  318. if is_fused_code[j]]
  319. if signature:
  320. cy_spec_func = cyfunc[tuple(signature)]
  321. else:
  322. signature = None
  323. cy_spec_func = cyfunc
  324. # Test it
  325. pts = _generate_test_points(typecodes)
  326. for pt in pts:
  327. with suppress_warnings() as sup:
  328. sup.filter(DeprecationWarning)
  329. pyval = pyfunc(*pt)
  330. cyval = cy_spec_func(*pt)
  331. assert_allclose(cyval, pyval, err_msg="{} {} {}".format(pt, typecodes, signature))