# test_logit.py

  import numpy as np
  from numpy.testing import (assert_equal, assert_almost_equal,
                             assert_allclose)
  from scipy.special import logit, expit, log_expit
  class TestLogit:
      """Tests for ``scipy.special.logit``."""

      def check_logit_out(self, dtype, expected):
          """Check ``logit`` on 10 evenly spaced points in [0, 1].

          The points are cast to ``dtype`` before evaluation. The result must
          match ``expected`` at ``assert_almost_equal``'s default precision
          and must keep the input dtype.
          """
          a = np.linspace(0, 1, 10)
          a = np.array(a, dtype=dtype)
          # The endpoints 0 and 1 map to -inf and +inf (the expected arrays
          # start and end with infinities). Silence numpy's divide-by-zero
          # warning for them.
          with np.errstate(divide='ignore'):
              actual = logit(a)
          assert_almost_equal(actual, expected)
          assert_equal(actual.dtype, np.dtype(dtype))

      def test_float32(self):
          """Single-precision input gives single-precision output."""
          expected = np.array([-np.inf, -2.07944155,
                               -1.25276291, -0.69314718,
                               -0.22314353, 0.22314365,
                               0.6931473, 1.25276303,
                               2.07944155, np.inf], dtype=np.float32)
          self.check_logit_out('f4', expected)

      def test_float64(self):
          """Double-precision input gives double-precision output."""
          expected = np.array([-np.inf, -2.07944154,
                               -1.25276297, -0.69314718,
                               -0.22314355, 0.22314355,
                               0.69314718, 1.25276297,
                               2.07944154, np.inf])
          self.check_logit_out('f8', expected)

      def test_nan(self):
          """Arguments outside [0, 1] produce NaN."""
          expected = np.array([np.nan] * 4)
          # Out-of-domain inputs raise numpy's invalid-value warning.
          with np.errstate(invalid='ignore'):
              actual = logit(np.array([-3., -2., 2., 3.]))
          # assert_equal takes (actual, desired); NaNs compare equal here.
          assert_equal(actual, expected)
  class TestExpit:
      """Tests for ``scipy.special.expit`` (the logistic sigmoid)."""

      def check_expit_out(self, dtype, expected):
          """Check ``expit`` on 10 evenly spaced points in [-4, 4].

          The points are cast to ``dtype`` before evaluation. The result must
          match ``expected`` at ``assert_almost_equal``'s default precision
          and must keep the input dtype.
          """
          a = np.linspace(-4, 4, 10)
          a = np.array(a, dtype=dtype)
          actual = expit(a)
          assert_almost_equal(actual, expected)
          assert_equal(actual.dtype, np.dtype(dtype))

      def test_float32(self):
          """Single-precision input gives single-precision output."""
          expected = np.array([0.01798621, 0.04265125,
                               0.09777259, 0.20860852,
                               0.39068246, 0.60931754,
                               0.79139149, 0.9022274,
                               0.95734876, 0.98201376], dtype=np.float32)
          self.check_expit_out('f4', expected)

      def test_float64(self):
          """Double-precision input gives double-precision output."""
          expected = np.array([0.01798621, 0.04265125,
                               0.0977726, 0.20860853,
                               0.39068246, 0.60931754,
                               0.79139147, 0.9022274,
                               0.95734875, 0.98201379])
          self.check_expit_out('f8', expected)

      def test_large(self):
          """For large |x|, expit saturates to 1 and 0 and keeps the dtype."""
          # These magnitudes presumably straddle the point where exp()
          # overflows in float32 (~88.7), float64 (~709.8) and 80-bit
          # extended long double (~11356.5).
          for dtype in (np.float32, np.float64, np.longdouble):
              for n in (88, 89, 709, 710, 11356, 11357):
                  # Use a separate name instead of rebinding the loop variable.
                  x = np.array(n, dtype=dtype)
                  assert_allclose(expit(x), 1.0, atol=1e-20)
                  assert_allclose(expit(-x), 0.0, atol=1e-20)
                  assert_equal(expit(x).dtype, dtype)
                  assert_equal(expit(-x).dtype, dtype)
  class TestLogExpit:
      """Tests for ``scipy.special.log_expit`` (log of the logistic sigmoid)."""

      def test_large_negative(self):
          """For very negative x, log_expit(x) equals x in double precision."""
          x = np.array([-10000.0, -750.0, -500.0, -35.0])
          y = log_expit(x)
          assert_equal(y, x)

      def test_large_positive(self):
          """For large positive x, log_expit(x) is zero."""
          x = np.array([750.0, 1000.0, 10000.0])
          y = log_expit(x)
          # y will contain -0.0, and -0.0 is used in the expected value,
          # but assert_equal on arrays does not check the sign of zeros, and
          # the sign is not an essential part of the test (i.e. it would
          # probably be OK if log_expit(1000) returned 0.0 instead of -0.0).
          assert_equal(y, np.array([-0.0, -0.0, -0.0]))

      def test_basic_float64(self):
          """Compare against mpmath reference values in double precision.

          The inputs cover both tails. For x >= 710 the expected results are
          subnormal doubles.
          """
          x = np.array([-32, -20, -10, -3, -1, -0.1, -1e-9,
                        0, 1e-9, 0.1, 1, 10, 100, 500, 710, 725, 735])
          y = log_expit(x)
          #
          # Expected values were computed with mpmath:
          #
          #   import mpmath
          #
          #   mpmath.mp.dps = 100
          #
          #   def mp_log_expit(x):
          #       return -mpmath.log1p(mpmath.exp(-x))
          #
          #   expected = [float(mp_log_expit(t)) for t in x]
          #
          expected = [-32.000000000000014, -20.000000002061153,
                      -10.000045398899218, -3.048587351573742,
                      -1.3132616875182228, -0.7443966600735709,
                      -0.6931471810599453, -0.6931471805599453,
                      -0.6931471800599454, -0.6443966600735709,
                      -0.3132616875182228, -4.539889921686465e-05,
                      -3.720075976020836e-44, -7.124576406741286e-218,
                      -4.47628622567513e-309, -1.36930634e-315,
                      -6.217e-320]
          # When tested locally, only one value in y was not exactly equal to
          # expected. That was for x=1, and the y value differed from the
          # expected by 1 ULP. For this test, however, rtol=1e-15 is used.
          assert_allclose(y, expected, rtol=1e-15)

      def test_basic_float32(self):
          """Compare against mpmath reference values in single precision.

          For x = 100 the expected result is a float32 subnormal.
          """
          x = np.array([-32, -20, -10, -3, -1, -0.1, -1e-9,
                        0, 1e-9, 0.1, 1, 10, 100], dtype=np.float32)
          y = log_expit(x)
          #
          # Expected values were computed with mpmath:
          #
          #   import mpmath
          #
          #   mpmath.mp.dps = 100
          #
          #   def mp_log_expit(x):
          #       return -mpmath.log1p(mpmath.exp(-x))
          #
          #   expected = [np.float32(mp_log_expit(t)) for t in x]
          #
          expected = np.array([-32.0, -20.0, -10.000046, -3.0485873,
                               -1.3132616, -0.7443967, -0.6931472,
                               -0.6931472, -0.6931472, -0.64439666,
                               -0.3132617, -4.5398898e-05, -3.8e-44],
                              dtype=np.float32)
          assert_allclose(y, expected, rtol=5e-7)