test_boost_ufuncs.py 1.6 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344
  1. import pytest
  2. import numpy as np
  3. from numpy.testing import assert_allclose
  4. from scipy.stats import _boost
  5. type_char_to_type_tol = {'f': (np.float32, 32*np.finfo(np.float32).eps),
  6. 'd': (np.float64, 32*np.finfo(np.float64).eps),
  7. 'g': (np.longdouble, 32*np.finfo(np.longdouble).eps)}
  8. # Each item in this list is
  9. # (func, args, expected_value)
  10. # All the values can be represented exactly, even with np.float32.
  11. #
  12. # This is not an exhaustive test data set of all the functions!
  13. # It is a spot check of several functions, primarily for
  14. # checking that the different data types are handled correctly.
  15. test_data = [
  16. (_boost._beta_cdf, (0.5, 2, 3), 0.6875),
  17. (_boost._beta_ppf, (0.6875, 2, 3), 0.5),
  18. (_boost._beta_pdf, (0.5, 2, 3), 1.5),
  19. (_boost._beta_sf, (0.5, 2, 1), 0.75),
  20. (_boost._beta_isf, (0.75, 2, 1), 0.5),
  21. (_boost._binom_cdf, (1, 3, 0.5), 0.5),
  22. (_boost._binom_pdf, (1, 4, 0.5), 0.25),
  23. (_boost._hypergeom_cdf, (2, 3, 5, 6), 0.5),
  24. (_boost._nbinom_cdf, (1, 4, 0.25), 0.015625),
  25. (_boost._ncf_mean, (10, 12, 2.5), 1.5),
  26. ]
  27. @pytest.mark.filterwarnings('ignore::RuntimeWarning')
  28. @pytest.mark.parametrize('func, args, expected', test_data)
  29. def test_stats_boost_ufunc(func, args, expected):
  30. type_sigs = func.types
  31. type_chars = [sig.split('->')[-1] for sig in type_sigs]
  32. for type_char in type_chars:
  33. typ, rtol = type_char_to_type_tol[type_char]
  34. args = [typ(arg) for arg in args]
  35. value = func(*args)
  36. assert isinstance(value, typ)
  37. assert_allclose(value, expected, rtol=rtol)