test_nan_inputs.py 1.8 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364
  1. """Test how the ufuncs in special handle nan inputs.
  2. """
  3. from typing import Callable, Dict
  4. import numpy as np
  5. from numpy.testing import assert_array_equal, assert_, suppress_warnings
  6. import pytest
  7. import scipy.special as sc
  8. KNOWNFAILURES: Dict[str, Callable] = {}
  9. POSTPROCESSING: Dict[str, Callable] = {}
  10. def _get_ufuncs():
  11. ufuncs = []
  12. ufunc_names = []
  13. for name in sorted(sc.__dict__):
  14. obj = sc.__dict__[name]
  15. if not isinstance(obj, np.ufunc):
  16. continue
  17. msg = KNOWNFAILURES.get(obj)
  18. if msg is None:
  19. ufuncs.append(obj)
  20. ufunc_names.append(name)
  21. else:
  22. fail = pytest.mark.xfail(run=False, reason=msg)
  23. ufuncs.append(pytest.param(obj, marks=fail))
  24. ufunc_names.append(name)
  25. return ufuncs, ufunc_names
  26. UFUNCS, UFUNC_NAMES = _get_ufuncs()
  27. @pytest.mark.parametrize("func", UFUNCS, ids=UFUNC_NAMES)
  28. def test_nan_inputs(func):
  29. args = (np.nan,)*func.nin
  30. with suppress_warnings() as sup:
  31. # Ignore warnings about unsafe casts from legacy wrappers
  32. sup.filter(RuntimeWarning,
  33. "floating point number truncated to an integer")
  34. try:
  35. with suppress_warnings() as sup:
  36. sup.filter(DeprecationWarning)
  37. res = func(*args)
  38. except TypeError:
  39. # One of the arguments doesn't take real inputs
  40. return
  41. if func in POSTPROCESSING:
  42. res = POSTPROCESSING[func](*res)
  43. msg = "got {} instead of nan".format(res)
  44. assert_array_equal(np.isnan(res), True, err_msg=msg)
  45. def test_legacy_cast():
  46. with suppress_warnings() as sup:
  47. sup.filter(RuntimeWarning,
  48. "floating point number truncated to an integer")
  49. res = sc.bdtrc(np.nan, 1, 0.5)
  50. assert_(np.isnan(res))