test_diff.py 5.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160
  1. from sympy.concrete.summations import Sum
  2. from sympy.core.expr import Expr
  3. from sympy.core.function import (Derivative, Function, diff, Subs)
  4. from sympy.core.numbers import (I, Rational, pi)
  5. from sympy.core.relational import Eq
  6. from sympy.core.singleton import S
  7. from sympy.core.symbol import Symbol
  8. from sympy.functions.combinatorial.factorials import factorial
  9. from sympy.functions.elementary.complexes import (im, re)
  10. from sympy.functions.elementary.exponential import (exp, log)
  11. from sympy.functions.elementary.miscellaneous import Max
  12. from sympy.functions.elementary.piecewise import Piecewise
  13. from sympy.functions.elementary.trigonometric import (cos, cot, sin, tan)
  14. from sympy.tensor.array.ndim_array import NDimArray
  15. from sympy.testing.pytest import raises
  16. from sympy.abc import a, b, c, x, y, z
  17. def test_diff():
  18. assert Rational(1, 3).diff(x) is S.Zero
  19. assert I.diff(x) is S.Zero
  20. assert pi.diff(x) is S.Zero
  21. assert x.diff(x, 0) == x
  22. assert (x**2).diff(x, 2, x) == 0
  23. assert (x**2).diff((x, 2), x) == 0
  24. assert (x**2).diff((x, 1), x) == 2
  25. assert (x**2).diff((x, 1), (x, 1)) == 2
  26. assert (x**2).diff((x, 2)) == 2
  27. assert (x**2).diff(x, y, 0) == 2*x
  28. assert (x**2).diff(x, (y, 0)) == 2*x
  29. assert (x**2).diff(x, y) == 0
  30. raises(ValueError, lambda: x.diff(1, x))
  31. p = Rational(5)
  32. e = a*b + b**p
  33. assert e.diff(a) == b
  34. assert e.diff(b) == a + 5*b**4
  35. assert e.diff(b).diff(a) == Rational(1)
  36. e = a*(b + c)
  37. assert e.diff(a) == b + c
  38. assert e.diff(b) == a
  39. assert e.diff(b).diff(a) == Rational(1)
  40. e = c**p
  41. assert e.diff(c, 6) == Rational(0)
  42. assert e.diff(c, 5) == Rational(120)
  43. e = c**Rational(2)
  44. assert e.diff(c) == 2*c
  45. e = a*b*c
  46. assert e.diff(c) == a*b
  47. def test_diff2():
  48. n3 = Rational(3)
  49. n2 = Rational(2)
  50. n6 = Rational(6)
  51. e = n3*(-n2 + x**n2)*cos(x) + x*(-n6 + x**n2)*sin(x)
  52. assert e == 3*(-2 + x**2)*cos(x) + x*(-6 + x**2)*sin(x)
  53. assert e.diff(x).expand() == x**3*cos(x)
  54. e = (x + 1)**3
  55. assert e.diff(x) == 3*(x + 1)**2
  56. e = x*(x + 1)**3
  57. assert e.diff(x) == (x + 1)**3 + 3*x*(x + 1)**2
  58. e = 2*exp(x*x)*x
  59. assert e.diff(x) == 2*exp(x**2) + 4*x**2*exp(x**2)
  60. def test_diff3():
  61. p = Rational(5)
  62. e = a*b + sin(b**p)
  63. assert e == a*b + sin(b**5)
  64. assert e.diff(a) == b
  65. assert e.diff(b) == a + 5*b**4*cos(b**5)
  66. e = tan(c)
  67. assert e == tan(c)
  68. assert e.diff(c) in [cos(c)**(-2), 1 + sin(c)**2/cos(c)**2, 1 + tan(c)**2]
  69. e = c*log(c) - c
  70. assert e == -c + c*log(c)
  71. assert e.diff(c) == log(c)
  72. e = log(sin(c))
  73. assert e == log(sin(c))
  74. assert e.diff(c) in [sin(c)**(-1)*cos(c), cot(c)]
  75. e = (Rational(2)**a/log(Rational(2)))
  76. assert e == 2**a*log(Rational(2))**(-1)
  77. assert e.diff(a) == 2**a
  78. def test_diff_no_eval_derivative():
  79. class My(Expr):
  80. def __new__(cls, x):
  81. return Expr.__new__(cls, x)
  82. # My doesn't have its own _eval_derivative method
  83. assert My(x).diff(x).func is Derivative
  84. assert My(x).diff(x, 3).func is Derivative
  85. assert re(x).diff(x, 2) == Derivative(re(x), (x, 2)) # issue 15518
  86. assert diff(NDimArray([re(x), im(x)]), (x, 2)) == NDimArray(
  87. [Derivative(re(x), (x, 2)), Derivative(im(x), (x, 2))])
  88. # it doesn't have y so it shouldn't need a method for this case
  89. assert My(x).diff(y) == 0
  90. def test_speed():
  91. # this should return in 0.0s. If it takes forever, it's wrong.
  92. assert x.diff(x, 10**8) == 0
  93. def test_deriv_noncommutative():
  94. A = Symbol("A", commutative=False)
  95. f = Function("f")
  96. assert A*f(x)*A == f(x)*A**2
  97. assert A*f(x).diff(x)*A == f(x).diff(x) * A**2
  98. def test_diff_nth_derivative():
  99. f = Function("f")
  100. n = Symbol("n", integer=True)
  101. expr = diff(sin(x), (x, n))
  102. expr2 = diff(f(x), (x, 2))
  103. expr3 = diff(f(x), (x, n))
  104. assert expr.subs(sin(x), cos(-x)) == Derivative(cos(-x), (x, n))
  105. assert expr.subs(n, 1).doit() == cos(x)
  106. assert expr.subs(n, 2).doit() == -sin(x)
  107. assert expr2.subs(Derivative(f(x), x), y) == Derivative(y, x)
  108. # Currently not supported (cannot determine if `n > 1`):
  109. #assert expr3.subs(Derivative(f(x), x), y) == Derivative(y, (x, n-1))
  110. assert expr3 == Derivative(f(x), (x, n))
  111. assert diff(x, (x, n)) == Piecewise((x, Eq(n, 0)), (1, Eq(n, 1)), (0, True))
  112. assert diff(2*x, (x, n)).dummy_eq(
  113. Sum(Piecewise((2*x*factorial(n)/(factorial(y)*factorial(-y + n)),
  114. Eq(y, 0) & Eq(Max(0, -y + n), 0)),
  115. (2*factorial(n)/(factorial(y)*factorial(-y + n)), Eq(y, 0) & Eq(Max(0,
  116. -y + n), 1)), (0, True)), (y, 0, n)))
  117. # TODO: assert diff(x**2, (x, n)) == x**(2-n)*ff(2, n)
  118. exprm = x*sin(x)
  119. mul_diff = diff(exprm, (x, n))
  120. assert isinstance(mul_diff, Sum)
  121. for i in range(5):
  122. assert mul_diff.subs(n, i).doit() == exprm.diff((x, i)).expand()
  123. exprm2 = 2*y*x*sin(x)*cos(x)*log(x)*exp(x)
  124. dex = exprm2.diff((x, n))
  125. assert isinstance(dex, Sum)
  126. for i in range(7):
  127. assert dex.subs(n, i).doit().expand() == \
  128. exprm2.diff((x, i)).expand()
  129. assert (cos(x)*sin(y)).diff([[x, y, z]]) == NDimArray([
  130. -sin(x)*sin(y), cos(x)*cos(y), 0])
  131. def test_issue_16160():
  132. assert Derivative(x**3, (x, x)).subs(x, 2) == Subs(
  133. Derivative(x**3, (x, 2)), x, 2)
  134. assert Derivative(1 + x**3, (x, x)).subs(x, 0
  135. ) == Derivative(1 + y**3, (y, 0)).subs(y, 0)