test_sf_error.py 3.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122
  1. import sys
  2. import warnings
  3. from numpy.testing import assert_, assert_equal, IS_PYPY
  4. import pytest
  5. from pytest import raises as assert_raises
  6. import scipy.special as sc
  7. from scipy.special._ufuncs import _sf_error_test_function
  8. _sf_error_code_map = {
  9. # skip 'ok'
  10. 'singular': 1,
  11. 'underflow': 2,
  12. 'overflow': 3,
  13. 'slow': 4,
  14. 'loss': 5,
  15. 'no_result': 6,
  16. 'domain': 7,
  17. 'arg': 8,
  18. 'other': 9
  19. }
  20. _sf_error_actions = [
  21. 'ignore',
  22. 'warn',
  23. 'raise'
  24. ]
  25. def _check_action(fun, args, action):
  26. if action == 'warn':
  27. with pytest.warns(sc.SpecialFunctionWarning):
  28. fun(*args)
  29. elif action == 'raise':
  30. with assert_raises(sc.SpecialFunctionError):
  31. fun(*args)
  32. else:
  33. # action == 'ignore', make sure there are no warnings/exceptions
  34. with warnings.catch_warnings():
  35. warnings.simplefilter("error")
  36. fun(*args)
  37. def test_geterr():
  38. err = sc.geterr()
  39. for key, value in err.items():
  40. assert_(key in _sf_error_code_map)
  41. assert_(value in _sf_error_actions)
  42. def test_seterr():
  43. entry_err = sc.geterr()
  44. try:
  45. for category, error_code in _sf_error_code_map.items():
  46. for action in _sf_error_actions:
  47. geterr_olderr = sc.geterr()
  48. seterr_olderr = sc.seterr(**{category: action})
  49. assert_(geterr_olderr == seterr_olderr)
  50. newerr = sc.geterr()
  51. assert_(newerr[category] == action)
  52. geterr_olderr.pop(category)
  53. newerr.pop(category)
  54. assert_(geterr_olderr == newerr)
  55. _check_action(_sf_error_test_function, (error_code,), action)
  56. finally:
  57. sc.seterr(**entry_err)
  58. @pytest.mark.skipif(IS_PYPY, reason="Test not meaningful on PyPy")
  59. def test_sf_error_special_refcount():
  60. # Regression test for gh-16233.
  61. # Check that the reference count of scipy.special is not increased
  62. # when a SpecialFunctionError is raised.
  63. refcount_before = sys.getrefcount(sc)
  64. with sc.errstate(all='raise'):
  65. with pytest.raises(sc.SpecialFunctionError, match='domain error'):
  66. sc.ndtri(2.0)
  67. refcount_after = sys.getrefcount(sc)
  68. assert refcount_after == refcount_before
  69. def test_errstate_pyx_basic():
  70. olderr = sc.geterr()
  71. with sc.errstate(singular='raise'):
  72. with assert_raises(sc.SpecialFunctionError):
  73. sc.loggamma(0)
  74. assert_equal(olderr, sc.geterr())
  75. def test_errstate_c_basic():
  76. olderr = sc.geterr()
  77. with sc.errstate(domain='raise'):
  78. with assert_raises(sc.SpecialFunctionError):
  79. sc.spence(-1)
  80. assert_equal(olderr, sc.geterr())
  81. def test_errstate_cpp_basic():
  82. olderr = sc.geterr()
  83. with sc.errstate(underflow='raise'):
  84. with assert_raises(sc.SpecialFunctionError):
  85. sc.wrightomega(-1000)
  86. assert_equal(olderr, sc.geterr())
  87. def test_errstate():
  88. for category, error_code in _sf_error_code_map.items():
  89. for action in _sf_error_actions:
  90. olderr = sc.geterr()
  91. with sc.errstate(**{category: action}):
  92. _check_action(_sf_error_test_function, (error_code,), action)
  93. assert_equal(olderr, sc.geterr())
  94. def test_errstate_all_but_one():
  95. olderr = sc.geterr()
  96. with sc.errstate(all='raise', singular='ignore'):
  97. sc.gammaln(0)
  98. with assert_raises(sc.SpecialFunctionError):
  99. sc.spence(-1.0)
  100. assert_equal(olderr, sc.geterr())