test_bdtr.py 3.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112
  1. import numpy as np
  2. import scipy.special as sc
  3. import pytest
  4. from numpy.testing import assert_allclose, assert_array_equal, suppress_warnings
  5. class TestBdtr:
  6. def test(self):
  7. val = sc.bdtr(0, 1, 0.5)
  8. assert_allclose(val, 0.5)
  9. def test_sum_is_one(self):
  10. val = sc.bdtr([0, 1, 2], 2, 0.5)
  11. assert_array_equal(val, [0.25, 0.75, 1.0])
  12. def test_rounding(self):
  13. double_val = sc.bdtr([0.1, 1.1, 2.1], 2, 0.5)
  14. int_val = sc.bdtr([0, 1, 2], 2, 0.5)
  15. assert_array_equal(double_val, int_val)
  16. @pytest.mark.parametrize('k, n, p', [
  17. (np.inf, 2, 0.5),
  18. (1.0, np.inf, 0.5),
  19. (1.0, 2, np.inf)
  20. ])
  21. def test_inf(self, k, n, p):
  22. with suppress_warnings() as sup:
  23. sup.filter(DeprecationWarning)
  24. val = sc.bdtr(k, n, p)
  25. assert np.isnan(val)
  26. def test_domain(self):
  27. val = sc.bdtr(-1.1, 1, 0.5)
  28. assert np.isnan(val)
  29. class TestBdtrc:
  30. def test_value(self):
  31. val = sc.bdtrc(0, 1, 0.5)
  32. assert_allclose(val, 0.5)
  33. def test_sum_is_one(self):
  34. val = sc.bdtrc([0, 1, 2], 2, 0.5)
  35. assert_array_equal(val, [0.75, 0.25, 0.0])
  36. def test_rounding(self):
  37. double_val = sc.bdtrc([0.1, 1.1, 2.1], 2, 0.5)
  38. int_val = sc.bdtrc([0, 1, 2], 2, 0.5)
  39. assert_array_equal(double_val, int_val)
  40. @pytest.mark.parametrize('k, n, p', [
  41. (np.inf, 2, 0.5),
  42. (1.0, np.inf, 0.5),
  43. (1.0, 2, np.inf)
  44. ])
  45. def test_inf(self, k, n, p):
  46. with suppress_warnings() as sup:
  47. sup.filter(DeprecationWarning)
  48. val = sc.bdtrc(k, n, p)
  49. assert np.isnan(val)
  50. def test_domain(self):
  51. val = sc.bdtrc(-1.1, 1, 0.5)
  52. val2 = sc.bdtrc(2.1, 1, 0.5)
  53. assert np.isnan(val2)
  54. assert_allclose(val, 1.0)
  55. def test_bdtr_bdtrc_sum_to_one(self):
  56. bdtr_vals = sc.bdtr([0, 1, 2], 2, 0.5)
  57. bdtrc_vals = sc.bdtrc([0, 1, 2], 2, 0.5)
  58. vals = bdtr_vals + bdtrc_vals
  59. assert_allclose(vals, [1.0, 1.0, 1.0])
  60. class TestBdtri:
  61. def test_value(self):
  62. val = sc.bdtri(0, 1, 0.5)
  63. assert_allclose(val, 0.5)
  64. def test_sum_is_one(self):
  65. val = sc.bdtri([0, 1], 2, 0.5)
  66. actual = np.asarray([1 - 1/np.sqrt(2), 1/np.sqrt(2)])
  67. assert_allclose(val, actual)
  68. def test_rounding(self):
  69. double_val = sc.bdtri([0.1, 1.1], 2, 0.5)
  70. int_val = sc.bdtri([0, 1], 2, 0.5)
  71. assert_allclose(double_val, int_val)
  72. @pytest.mark.parametrize('k, n, p', [
  73. (np.inf, 2, 0.5),
  74. (1.0, np.inf, 0.5),
  75. (1.0, 2, np.inf)
  76. ])
  77. def test_inf(self, k, n, p):
  78. with suppress_warnings() as sup:
  79. sup.filter(DeprecationWarning)
  80. val = sc.bdtri(k, n, p)
  81. assert np.isnan(val)
  82. @pytest.mark.parametrize('k, n, p', [
  83. (-1.1, 1, 0.5),
  84. (2.1, 1, 0.5)
  85. ])
  86. def test_domain(self, k, n, p):
  87. val = sc.bdtri(k, n, p)
  88. assert np.isnan(val)
  89. def test_bdtr_bdtri_roundtrip(self):
  90. bdtr_vals = sc.bdtr([0, 1, 2], 2, 0.5)
  91. roundtrip_vals = sc.bdtri([0, 1, 2], 2, bdtr_vals)
  92. assert_allclose(roundtrip_vals, [0.5, 0.5, np.nan])