test_digamma.py 1.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142
  1. import numpy as np
  2. from numpy import pi, log, sqrt
  3. from numpy.testing import assert_, assert_equal
  4. from scipy.special._testutils import FuncData
  5. import scipy.special as sc
  6. # Euler-Mascheroni constant
  7. euler = 0.57721566490153286
  8. def test_consistency():
  9. # Make sure the implementation of digamma for real arguments
  10. # agrees with the implementation of digamma for complex arguments.
  11. # It's all poles after -1e16
  12. x = np.r_[-np.logspace(15, -30, 200), np.logspace(-30, 300, 200)]
  13. dataset = np.vstack((x + 0j, sc.digamma(x))).T
  14. FuncData(sc.digamma, dataset, 0, 1, rtol=5e-14, nan_ok=True).check()
  15. def test_special_values():
  16. # Test special values from Gauss's digamma theorem. See
  17. #
  18. # https://en.wikipedia.org/wiki/Digamma_function
  19. dataset = [(1, -euler),
  20. (0.5, -2*log(2) - euler),
  21. (1/3, -pi/(2*sqrt(3)) - 3*log(3)/2 - euler),
  22. (1/4, -pi/2 - 3*log(2) - euler),
  23. (1/6, -pi*sqrt(3)/2 - 2*log(2) - 3*log(3)/2 - euler),
  24. (1/8, -pi/2 - 4*log(2) - (pi + log(2 + sqrt(2)) - log(2 - sqrt(2)))/sqrt(2) - euler)]
  25. dataset = np.asarray(dataset)
  26. FuncData(sc.digamma, dataset, 0, 1, rtol=1e-14).check()
  27. def test_nonfinite():
  28. pts = [0.0, -0.0, np.inf]
  29. std = [-np.inf, np.inf, np.inf]
  30. assert_equal(sc.digamma(pts), std)
  31. assert_(all(np.isnan(sc.digamma([-np.inf, -1]))))