test_zeta_functions.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286
  1. from sympy.concrete.summations import Sum
  2. from sympy.core.function import expand_func
  3. from sympy.core.numbers import (Float, I, Rational, nan, oo, pi, zoo)
  4. from sympy.core.singleton import S
  5. from sympy.core.symbol import Symbol
  6. from sympy.functions.elementary.complexes import (Abs, polar_lift)
  7. from sympy.functions.elementary.exponential import (exp, exp_polar, log)
  8. from sympy.functions.elementary.miscellaneous import sqrt
  9. from sympy.functions.special.zeta_functions import (dirichlet_eta, lerchphi, polylog, riemann_xi, stieltjes, zeta)
  10. from sympy.series.order import O
  11. from sympy.core.function import ArgumentIndexError
  12. from sympy.functions.combinatorial.numbers import bernoulli, factorial, genocchi, harmonic
  13. from sympy.testing.pytest import raises
  14. from sympy.core.random import (test_derivative_numerically as td,
  15. random_complex_number as randcplx, verify_numerically)
  16. x = Symbol('x')
  17. a = Symbol('a')
  18. b = Symbol('b', negative=True)
  19. z = Symbol('z')
  20. s = Symbol('s')
  21. def test_zeta_eval():
  22. assert zeta(nan) is nan
  23. assert zeta(x, nan) is nan
  24. assert zeta(0) == Rational(-1, 2)
  25. assert zeta(0, x) == S.Half - x
  26. assert zeta(0, b) == S.Half - b
  27. assert zeta(1) is zoo
  28. assert zeta(1, 2) is zoo
  29. assert zeta(1, -7) is zoo
  30. assert zeta(1, x) is zoo
  31. assert zeta(2, 1) == pi**2/6
  32. assert zeta(3, 1) == zeta(3)
  33. assert zeta(2) == pi**2/6
  34. assert zeta(4) == pi**4/90
  35. assert zeta(6) == pi**6/945
  36. assert zeta(4, 3) == pi**4/90 - Rational(17, 16)
  37. assert zeta(7, 4) == zeta(7) - Rational(282251, 279936)
  38. assert zeta(S.Half, 2).func == zeta
  39. assert expand_func(zeta(S.Half, 2)) == zeta(S.Half) - 1
  40. assert zeta(x, 3).func == zeta
  41. assert expand_func(zeta(x, 3)) == zeta(x) - 1 - 1/2**x
  42. assert zeta(2, 0) is nan
  43. assert zeta(3, -1) is nan
  44. assert zeta(4, -2) is nan
  45. assert zeta(oo) == 1
  46. assert zeta(-1) == Rational(-1, 12)
  47. assert zeta(-2) == 0
  48. assert zeta(-3) == Rational(1, 120)
  49. assert zeta(-4) == 0
  50. assert zeta(-5) == Rational(-1, 252)
  51. assert zeta(-1, 3) == Rational(-37, 12)
  52. assert zeta(-1, 7) == Rational(-253, 12)
  53. assert zeta(-1, -4) == Rational(-121, 12)
  54. assert zeta(-1, -9) == Rational(-541, 12)
  55. assert zeta(-4, 3) == -17
  56. assert zeta(-4, -8) == 8772
  57. assert zeta(0, 1) == Rational(-1, 2)
  58. assert zeta(0, -1) == Rational(3, 2)
  59. assert zeta(0, 2) == Rational(-3, 2)
  60. assert zeta(0, -2) == Rational(5, 2)
  61. assert zeta(
  62. 3).evalf(20).epsilon_eq(Float("1.2020569031595942854", 20), 1e-19)
  63. def test_zeta_series():
  64. assert zeta(x, a).series(a, z, 2) == \
  65. zeta(x, z) - x*(a-z)*zeta(x+1, z) + O((a-z)**2, (a, z))
  66. def test_dirichlet_eta_eval():
  67. assert dirichlet_eta(0) == S.Half
  68. assert dirichlet_eta(-1) == Rational(1, 4)
  69. assert dirichlet_eta(1) == log(2)
  70. assert dirichlet_eta(1, S.Half).simplify() == pi/2
  71. assert dirichlet_eta(1, 2) == 1 - log(2)
  72. assert dirichlet_eta(2) == pi**2/12
  73. assert dirichlet_eta(4) == pi**4*Rational(7, 720)
  74. assert str(dirichlet_eta(I).evalf(n=10)) == '0.5325931818 + 0.2293848577*I'
  75. assert str(dirichlet_eta(I, I).evalf(n=10)) == '3.462349253 + 0.220285771*I'
  76. def test_riemann_xi_eval():
  77. assert riemann_xi(2) == pi/6
  78. assert riemann_xi(0) == Rational(1, 2)
  79. assert riemann_xi(1) == Rational(1, 2)
  80. assert riemann_xi(3).rewrite(zeta) == 3*zeta(3)/(2*pi)
  81. assert riemann_xi(4) == pi**2/15
  82. def test_rewriting():
  83. from sympy.functions.elementary.piecewise import Piecewise
  84. assert isinstance(dirichlet_eta(x).rewrite(zeta), Piecewise)
  85. assert isinstance(dirichlet_eta(x).rewrite(genocchi), Piecewise)
  86. assert zeta(x).rewrite(dirichlet_eta) == dirichlet_eta(x)/(1 - 2**(1 - x))
  87. assert zeta(x).rewrite(dirichlet_eta, a=2) == zeta(x)
  88. assert verify_numerically(dirichlet_eta(x), dirichlet_eta(x).rewrite(zeta), x)
  89. assert verify_numerically(dirichlet_eta(x), dirichlet_eta(x).rewrite(genocchi), x)
  90. assert verify_numerically(zeta(x), zeta(x).rewrite(dirichlet_eta), x)
  91. assert zeta(x, a).rewrite(lerchphi) == lerchphi(1, x, a)
  92. assert polylog(s, z).rewrite(lerchphi) == lerchphi(z, s, 1)*z
  93. assert lerchphi(1, x, a).rewrite(zeta) == zeta(x, a)
  94. assert z*lerchphi(z, s, 1).rewrite(polylog) == polylog(s, z)
  95. def test_derivatives():
  96. from sympy.core.function import Derivative
  97. assert zeta(x, a).diff(x) == Derivative(zeta(x, a), x)
  98. assert zeta(x, a).diff(a) == -x*zeta(x + 1, a)
  99. assert lerchphi(
  100. z, s, a).diff(z) == (lerchphi(z, s - 1, a) - a*lerchphi(z, s, a))/z
  101. assert lerchphi(z, s, a).diff(a) == -s*lerchphi(z, s + 1, a)
  102. assert polylog(s, z).diff(z) == polylog(s - 1, z)/z
  103. b = randcplx()
  104. c = randcplx()
  105. assert td(zeta(b, x), x)
  106. assert td(polylog(b, z), z)
  107. assert td(lerchphi(c, b, x), x)
  108. assert td(lerchphi(x, b, c), x)
  109. raises(ArgumentIndexError, lambda: lerchphi(c, b, x).fdiff(2))
  110. raises(ArgumentIndexError, lambda: lerchphi(c, b, x).fdiff(4))
  111. raises(ArgumentIndexError, lambda: polylog(b, z).fdiff(1))
  112. raises(ArgumentIndexError, lambda: polylog(b, z).fdiff(3))
  113. def myexpand(func, target):
  114. expanded = expand_func(func)
  115. if target is not None:
  116. return expanded == target
  117. if expanded == func: # it didn't expand
  118. return False
  119. # check to see that the expanded and original evaluate to the same value
  120. subs = {}
  121. for a in func.free_symbols:
  122. subs[a] = randcplx()
  123. return abs(func.subs(subs).n()
  124. - expanded.replace(exp_polar, exp).subs(subs).n()) < 1e-10
  125. def test_polylog_expansion():
  126. assert polylog(s, 0) == 0
  127. assert polylog(s, 1) == zeta(s)
  128. assert polylog(s, -1) == -dirichlet_eta(s)
  129. assert polylog(s, exp_polar(I*pi*Rational(4, 3))) == polylog(s, exp(I*pi*Rational(4, 3)))
  130. assert polylog(s, exp_polar(I*pi)/3) == polylog(s, exp(I*pi)/3)
  131. assert myexpand(polylog(1, z), -log(1 - z))
  132. assert myexpand(polylog(0, z), z/(1 - z))
  133. assert myexpand(polylog(-1, z), z/(1 - z)**2)
  134. assert ((1-z)**3 * expand_func(polylog(-2, z))).simplify() == z*(1 + z)
  135. assert myexpand(polylog(-5, z), None)
  136. def test_polylog_series():
  137. assert polylog(1, z).series(z, n=5) == z + z**2/2 + z**3/3 + z**4/4 + O(z**5)
  138. assert polylog(1, sqrt(z)).series(z, n=3) == z/2 + z**2/4 + sqrt(z)\
  139. + z**(S(3)/2)/3 + z**(S(5)/2)/5 + O(z**3)
  140. # https://github.com/sympy/sympy/issues/9497
  141. assert polylog(S(3)/2, -z).series(z, 0, 5) == -z + sqrt(2)*z**2/4\
  142. - sqrt(3)*z**3/9 + z**4/8 + O(z**5)
  143. def test_issue_8404():
  144. i = Symbol('i', integer=True)
  145. assert Abs(Sum(1/(3*i + 1)**2, (i, 0, S.Infinity)).doit().n(4)
  146. - 1.122) < 0.001
  147. def test_polylog_values():
  148. assert polylog(2, 2) == pi**2/4 - I*pi*log(2)
  149. assert polylog(2, S.Half) == pi**2/12 - log(2)**2/2
  150. for z in [S.Half, 2, (sqrt(5)-1)/2, -(sqrt(5)-1)/2, -(sqrt(5)+1)/2, (3-sqrt(5))/2]:
  151. assert Abs(polylog(2, z).evalf() - polylog(2, z, evaluate=False).evalf()) < 1e-15
  152. z = Symbol("z")
  153. for s in [-1, 0]:
  154. for _ in range(10):
  155. assert verify_numerically(polylog(s, z), polylog(s, z, evaluate=False),
  156. z, a=-3, b=-2, c=S.Half, d=2)
  157. assert verify_numerically(polylog(s, z), polylog(s, z, evaluate=False),
  158. z, a=2, b=-2, c=5, d=2)
  159. from sympy.integrals.integrals import Integral
  160. assert polylog(0, Integral(1, (x, 0, 1))) == -S.Half
  161. def test_lerchphi_expansion():
  162. assert myexpand(lerchphi(1, s, a), zeta(s, a))
  163. assert myexpand(lerchphi(z, s, 1), polylog(s, z)/z)
  164. # direct summation
  165. assert myexpand(lerchphi(z, -1, a), a/(1 - z) + z/(1 - z)**2)
  166. assert myexpand(lerchphi(z, -3, a), None)
  167. # polylog reduction
  168. assert myexpand(lerchphi(z, s, S.Half),
  169. 2**(s - 1)*(polylog(s, sqrt(z))/sqrt(z)
  170. - polylog(s, polar_lift(-1)*sqrt(z))/sqrt(z)))
  171. assert myexpand(lerchphi(z, s, 2), -1/z + polylog(s, z)/z**2)
  172. assert myexpand(lerchphi(z, s, Rational(3, 2)), None)
  173. assert myexpand(lerchphi(z, s, Rational(7, 3)), None)
  174. assert myexpand(lerchphi(z, s, Rational(-1, 3)), None)
  175. assert myexpand(lerchphi(z, s, Rational(-5, 2)), None)
  176. # hurwitz zeta reduction
  177. assert myexpand(lerchphi(-1, s, a),
  178. 2**(-s)*zeta(s, a/2) - 2**(-s)*zeta(s, (a + 1)/2))
  179. assert myexpand(lerchphi(I, s, a), None)
  180. assert myexpand(lerchphi(-I, s, a), None)
  181. assert myexpand(lerchphi(exp(I*pi*Rational(2, 5)), s, a), None)
  182. def test_stieltjes():
  183. assert isinstance(stieltjes(x), stieltjes)
  184. assert isinstance(stieltjes(x, a), stieltjes)
  185. # Zero'th constant EulerGamma
  186. assert stieltjes(0) == S.EulerGamma
  187. assert stieltjes(0, 1) == S.EulerGamma
  188. # Not defined
  189. assert stieltjes(nan) is nan
  190. assert stieltjes(0, nan) is nan
  191. assert stieltjes(-1) is S.ComplexInfinity
  192. assert stieltjes(1.5) is S.ComplexInfinity
  193. assert stieltjes(z, 0) is S.ComplexInfinity
  194. assert stieltjes(z, -1) is S.ComplexInfinity
  195. def test_stieltjes_evalf():
  196. assert abs(stieltjes(0).evalf() - 0.577215664) < 1E-9
  197. assert abs(stieltjes(0, 0.5).evalf() - 1.963510026) < 1E-9
  198. assert abs(stieltjes(1, 2).evalf() + 0.072815845) < 1E-9
  199. def test_issue_10475():
  200. a = Symbol('a', extended_real=True)
  201. b = Symbol('b', extended_positive=True)
  202. s = Symbol('s', zero=False)
  203. assert zeta(2 + I).is_finite
  204. assert zeta(1).is_finite is False
  205. assert zeta(x).is_finite is None
  206. assert zeta(x + I).is_finite is None
  207. assert zeta(a).is_finite is None
  208. assert zeta(b).is_finite is None
  209. assert zeta(-b).is_finite is True
  210. assert zeta(b**2 - 2*b + 1).is_finite is None
  211. assert zeta(a + I).is_finite is True
  212. assert zeta(b + 1).is_finite is True
  213. assert zeta(s + 1).is_finite is True
  214. def test_issue_14177():
  215. n = Symbol('n', nonnegative=True, integer=True)
  216. assert zeta(-n).rewrite(bernoulli) == bernoulli(n+1) / (-n-1)
  217. assert zeta(-n, a).rewrite(bernoulli) == bernoulli(n+1, a) / (-n-1)
  218. z2n = -(2*I*pi)**(2*n)*bernoulli(2*n) / (2*factorial(2*n))
  219. assert zeta(2*n).rewrite(bernoulli) == z2n
  220. assert expand_func(zeta(s, n+1)) == zeta(s) - harmonic(n, s)
  221. assert expand_func(zeta(-b, -n)) is nan
  222. assert expand_func(zeta(-b, n)) == zeta(-b, n)
  223. n = Symbol('n')
  224. assert zeta(2*n) == zeta(2*n) # As sign of z (= 2*n) is not determined