12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364 |
- """Test how the ufuncs in special handle nan inputs.
- """
- from typing import Callable, Dict
- import numpy as np
- from numpy.testing import assert_array_equal, assert_, suppress_warnings
- import pytest
- import scipy.special as sc
# Ufuncs expected to fail the NaN test, mapped to the reason string passed to
# ``pytest.mark.xfail``. Keys are the ufunc objects themselves: ``_get_ufuncs``
# looks entries up by ufunc (``KNOWNFAILURES.get(obj)``), not by exported name.
KNOWNFAILURES: Dict[np.ufunc, str] = {}

# Ufuncs whose raw output must be transformed before the NaN check, mapped to
# a callable that receives the unpacked outputs (``POSTPROCESSING[func](*res)``)
# and returns the value(s) to check. Keys are the ufunc objects themselves.
POSTPROCESSING: Dict[np.ufunc, Callable] = {}
def _get_ufuncs():
    """Collect every ufunc exported by ``scipy.special`` for parametrization.

    Returns
    -------
    ufuncs : list
        One entry per ufunc found in ``sc.__dict__``, ordered by sorted name.
        Ufuncs present in ``KNOWNFAILURES`` are wrapped in ``pytest.param``
        with an ``xfail(run=False)`` mark whose reason is the mapped message;
        all others are the bare ufunc objects.
    ufunc_names : list of str
        The corresponding names, parallel to ``ufuncs``; used as test ids.
    """
    ufuncs = []
    ufunc_names = []
    for name in sorted(sc.__dict__):
        obj = sc.__dict__[name]
        if not isinstance(obj, np.ufunc):
            continue
        # KNOWNFAILURES is keyed by the ufunc object, not by its name.
        msg = KNOWNFAILURES.get(obj)
        if msg is None:
            ufuncs.append(obj)
        else:
            # run=False: the ufunc is not called; it is reported as xfail.
            fail = pytest.mark.xfail(run=False, reason=msg)
            ufuncs.append(pytest.param(obj, marks=fail))
        # The id is recorded identically in both branches.
        ufunc_names.append(name)
    return ufuncs, ufunc_names


# Computed once at import time so the parametrized test below can use them.
UFUNCS, UFUNC_NAMES = _get_ufuncs()
@pytest.mark.parametrize("func", UFUNCS, ids=UFUNC_NAMES)
def test_nan_inputs(func):
    """Check that ``func`` returns NaN when every one of its inputs is NaN.

    Ufuncs that raise ``TypeError`` for float NaN inputs are treated as
    passing: the test returns early without checking any output.
    """
    args = (np.nan,)*func.nin
    with suppress_warnings() as sup:
        # Ignore warnings about unsafe casts from legacy wrappers
        sup.filter(RuntimeWarning,
                   "floating point number truncated to an integer")
        try:
            # The nested suppressor gets its own name rather than shadowing
            # ``sup``; it additionally silences DeprecationWarning.
            with suppress_warnings() as sup_dep:
                sup_dep.filter(DeprecationWarning)
                res = func(*args)
        except TypeError:
            # One of the arguments doesn't take real inputs
            return
    if func in POSTPROCESSING:
        # The outputs are unpacked into the registered callable, so entries
        # presumably target multi-output ufuncs -- confirm when adding one.
        res = POSTPROCESSING[func](*res)

    msg = "got {} instead of nan".format(res)
    # np.isnan is applied element-wise (tuple outputs become an array), and
    # every element must be NaN for the assertion to pass.
    assert_array_equal(np.isnan(res), True, err_msg=msg)
def test_legacy_cast():
    """Check that ``sc.bdtrc`` returns NaN when its first argument is NaN.

    The call may go through a legacy cast that warns about truncating a
    float to an integer; that specific RuntimeWarning is suppressed.
    """
    with suppress_warnings() as warning_ctx:
        warning_ctx.filter(
            category=RuntimeWarning,
            message="floating point number truncated to an integer",
        )
        value = sc.bdtrc(np.nan, 1, 0.5)
        is_nan = np.isnan(value)
        assert_(is_nan)
|