# test_function.py: tests for sympy.core.function


  1. from sympy.concrete.summations import Sum
  2. from sympy.core.basic import Basic, _aresame
  3. from sympy.core.cache import clear_cache
  4. from sympy.core.containers import Dict, Tuple
  5. from sympy.core.expr import Expr, unchanged
  6. from sympy.core.function import (Subs, Function, diff, Lambda, expand,
  7. nfloat, Derivative)
  8. from sympy.core.numbers import E, Float, zoo, Rational, pi, I, oo, nan
  9. from sympy.core.power import Pow
  10. from sympy.core.relational import Eq
  11. from sympy.core.singleton import S
  12. from sympy.core.symbol import symbols, Dummy, Symbol
  13. from sympy.functions.elementary.complexes import im, re
  14. from sympy.functions.elementary.exponential import log, exp
  15. from sympy.functions.elementary.miscellaneous import sqrt
  16. from sympy.functions.elementary.piecewise import Piecewise
  17. from sympy.functions.elementary.trigonometric import sin, cos, acos
  18. from sympy.functions.special.error_functions import expint
  19. from sympy.functions.special.gamma_functions import loggamma, polygamma
  20. from sympy.matrices.dense import Matrix
  21. from sympy.printing.str import sstr
  22. from sympy.series.order import O
  23. from sympy.tensor.indexed import Indexed
  24. from sympy.core.function import (PoleError, _mexpand, arity,
  25. BadSignatureError, BadArgumentsError)
  26. from sympy.core.parameters import _exp_is_pow
  27. from sympy.core.sympify import sympify, SympifyError
  28. from sympy.matrices import MutableMatrix, ImmutableMatrix
  29. from sympy.sets.sets import FiniteSet
  30. from sympy.solvers.solveset import solveset
  31. from sympy.tensor.array import NDimArray
  32. from sympy.utilities.iterables import subsets, variations
  33. from sympy.testing.pytest import XFAIL, raises, warns_deprecated_sympy, _both_exp_pow
  34. from sympy.abc import t, w, x, y, z
  35. f, g, h = symbols('f g h', cls=Function)
  36. _xi_1, _xi_2, _xi_3 = [Dummy() for i in range(3)]
def test_f_expand_complex():
    """``expand(complex=True)`` splits expressions into real and imaginary parts."""
    x = Symbol('x', real=True)

    # an undefined function of a real argument is still split via re/im
    assert f(x).expand(complex=True) == I*im(f(x)) + re(f(x))
    assert exp(x).expand(complex=True) == exp(x)
    assert exp(I*x).expand(complex=True) == cos(x) + I*sin(x)
    # z (from sympy.abc) has no assumptions, so re(z)/im(z) remain
    assert exp(z).expand(complex=True) == cos(im(z))*exp(re(z)) + \
        I*sin(im(z))*exp(re(z))
def test_bug1():
    """Substituting for ``log(w)`` works inside sqrt, with and without a coefficient."""
    e = sqrt(-log(w))
    assert e.subs(log(w), -x) == sqrt(x)

    e = sqrt(-5*log(w))
    assert e.subs(log(w), -x) == sqrt(5*x)
def test_general_function():
    """Differentiate an undefined one-argument function."""
    nu = Function('nu')

    e = nu(x)
    edx = e.diff(x)
    edy = e.diff(y)
    edxdx = e.diff(x).diff(x)
    edxdy = e.diff(x).diff(y)
    assert e == nu(x)
    assert edx != nu(x)
    assert edx == diff(nu(x), x)
    # nu(x) does not depend on y
    assert edy == 0
    assert edxdx == diff(diff(nu(x), x), x)
    assert edxdy == 0
def test_general_function_nullary():
    """A zero-argument undefined function has zero derivative."""
    nu = Function('nu')

    e = nu()
    edx = e.diff(x)
    edxdx = e.diff(x).diff(x)
    assert e == nu()
    assert edx != nu()
    # nu() has no arguments, so it cannot depend on x
    assert edx == 0
    assert edxdx == 0
def test_derivative_subs_bug():
    """``subs`` replaces the function inside a Derivative and can rename its variable."""
    e = diff(g(x), x)
    assert e.subs(g(x), f(x)) != e
    assert e.subs(g(x), f(x)) == Derivative(f(x), x)
    assert e.subs(g(x), -f(x)) == Derivative(-f(x), x)

    assert e.subs(x, y) == Derivative(g(y), y)
def test_derivative_subs_self_bug():
    """A whole Derivative can itself be the target of ``subs``."""
    d = diff(f(x), x)

    assert d.subs(d, y) == y
def test_derivative_linearity():
    """``diff`` pulls out constant factors and applies the product rule."""
    assert diff(-f(x), x) == -diff(f(x), x)
    assert diff(8*f(x), x) == 8*diff(f(x), x)
    assert diff(8*f(x), x) != 7*diff(f(x), x)
    assert diff(8*f(x)*x, x) == 8*f(x) + 8*x*diff(f(x), x)
    assert diff(8*f(x)*y*x, x).expand() == 8*y*f(x) + 8*y*x*diff(f(x), x)
def test_derivative_evaluate():
    """Derivative stays unevaluated until ``doit``; zero total counts are dropped."""
    assert Derivative(sin(x), x) != diff(sin(x), x)
    assert Derivative(sin(x), x).doit() == diff(sin(x), x)

    # a nested Derivative equals the corresponding higher-order one
    assert Derivative(Derivative(f(x), x), x) == diff(f(x), x, x)
    # a zero derivative count returns the expression itself
    assert Derivative(sin(x), x, 0) == sin(x)
    # the counts y and -y for x combine to zero
    assert Derivative(sin(x), (x, y), (x, -y)) == sin(x)
def test_diff_symbols():
    """``diff`` accepts several variables, integer repeat counts and string variables."""
    assert diff(f(x, y, z), x, y, z) == Derivative(f(x, y, z), x, y, z)
    assert diff(f(x, y, z), x, x, x) == Derivative(f(x, y, z), x, x, x) == Derivative(f(x, y, z), (x, 3))
    assert diff(f(x, y, z), x, 3) == Derivative(f(x, y, z), x, 3)

    # issue 5028
    assert [diff(-z + x/y, sym) for sym in (z, x, y)] == [-1, 1/y, -x/y**2]
    # a trailing integer repeats the variable immediately before it
    assert diff(f(x, y, z), x, y, z, 2) == Derivative(f(x, y, z), x, y, z, z)
    assert diff(f(x, y, z), x, y, z, 2, evaluate=False) == \
        Derivative(f(x, y, z), x, y, z, z)
    assert Derivative(f(x, y, z), x, y, z)._eval_derivative(z) == \
        Derivative(f(x, y, z), x, y, z, z)
    assert Derivative(Derivative(f(x, y, z), x), y)._eval_derivative(z) == \
        Derivative(f(x, y, z), x, y, z)

    # with a symbolic count, ``variables`` raises TypeError while
    # ``_wrt_variables`` still lists the differentiation variable
    raises(TypeError, lambda: cos(x).diff((x, y)).variables)
    assert cos(x).diff((x, y))._wrt_variables == [x]

    # issue 23222
    assert sympify("a*x+b").diff("x") == sympify("a")
def test_Function():
    """``nargs`` of a Function subclass is inferred from its ``eval`` signature."""
    class myfunc(Function):
        @classmethod
        def eval(cls):  # zero args
            return

    assert myfunc.nargs == FiniteSet(0)
    assert myfunc().nargs == FiniteSet(0)
    # calling with an unsupported number of arguments is rejected
    raises(TypeError, lambda: myfunc(x).nargs)

    class myfunc(Function):
        @classmethod
        def eval(cls, x):  # one arg
            return

    assert myfunc.nargs == FiniteSet(1)
    assert myfunc(x).nargs == FiniteSet(1)
    raises(TypeError, lambda: myfunc(x, y).nargs)

    class myfunc(Function):
        @classmethod
        def eval(cls, *x):  # star args
            return

    # a *args eval accepts any number of arguments
    assert myfunc.nargs == S.Naturals0
    assert myfunc(x).nargs == S.Naturals0
def test_nargs():
    """``nargs`` of undefined functions, ``nargs=`` overrides and builtin functions."""
    f = Function('f')
    assert f.nargs == S.Naturals0
    assert f(1).nargs == S.Naturals0
    assert Function('f', nargs=2)(1, 2).nargs == FiniteSet(2)
    assert sin.nargs == FiniteSet(1)
    assert sin(2).nargs == FiniteSet(1)
    assert log.nargs == FiniteSet(1, 2)
    assert log(2).nargs == FiniteSet(1, 2)
    assert Function('f', nargs=2).nargs == FiniteSet(2)
    assert Function('f', nargs=0).nargs == FiniteSet(0)
    assert Function('f', nargs=(0, 1)).nargs == FiniteSet(0, 1)
    assert Function('f', nargs=None).nargs == S.Naturals0
    # an empty tuple of allowed arities is invalid
    raises(ValueError, lambda: Function('f', nargs=()))
def test_nargs_inheritance():
    """``nargs`` is inherited until a subclass overrides it (``None`` means any count)."""
    class f1(Function):
        nargs = 2
    class f2(f1):
        pass
    class f3(f2):
        pass
    class f4(f3):
        nargs = 1,2
    class f5(f4):
        pass
    class f6(f5):
        pass
    class f7(f6):
        nargs=None
    class f8(f7):
        pass
    class f9(f8):
        pass
    class f10(f9):
        nargs = 1
    class f11(f10):
        pass
    assert f1.nargs == FiniteSet(2)
    assert f2.nargs == FiniteSet(2)
    assert f3.nargs == FiniteSet(2)
    assert f4.nargs == FiniteSet(1, 2)
    assert f5.nargs == FiniteSet(1, 2)
    assert f6.nargs == FiniteSet(1, 2)
    # nargs=None resets to "any number of arguments"
    assert f7.nargs == S.Naturals0
    assert f8.nargs == S.Naturals0
    assert f9.nargs == S.Naturals0
    assert f10.nargs == FiniteSet(1)
    assert f11.nargs == FiniteSet(1)
def test_arity():
    """``arity`` returns a count, a (min, max) pair, or None for variadic callables."""
    f = lambda x, y: 1
    assert arity(f) == 2
    def f(x, y, z=None):
        pass
    # optional parameters produce a (min, max) pair
    assert arity(f) == (2, 3)
    # *args callables have no fixed arity
    assert arity(lambda *x: x) is None
    assert arity(log) == (1, 2)
def test_Lambda():
    """Construction, calling, equality and signature validation of Lambda."""
    e = Lambda(x, x**2)
    assert e(4) == 16
    assert e(x) == x**2
    assert e(y) == y**2

    # zero-argument Lambdas
    assert Lambda((), 42)() == 42
    assert unchanged(Lambda, (), 42)
    assert Lambda((), 42) != Lambda((), 43)
    assert Lambda((), f(x))() == f(x)
    assert Lambda((), 42).nargs == FiniteSet(0)

    assert unchanged(Lambda, (x,), x**2)
    # a bare symbol signature is equivalent to a 1-tuple signature
    assert Lambda(x, x**2) == Lambda((x,), x**2)
    assert Lambda(x, x**2) != Lambda(x, x**2 + 1)
    assert Lambda((x, y), x**y) != Lambda((y, x), y**x)
    assert Lambda((x, y), x**y) != Lambda((x, y), y**x)

    assert Lambda((x, y), x**y)(x, y) == x**y
    assert Lambda((x, y), x**y)(3, 3) == 3**3
    assert Lambda((x, y), x**y)(x, 3) == x**3
    assert Lambda((x, y), x**y)(3, y) == 3**y
    assert Lambda(x, f(x))(x) == f(x)
    assert Lambda(x, x**2)(e(x)) == x**4
    assert e(e(x)) == x**4

    # Indexed objects can serve as parameters
    x1, x2 = (Indexed('x', i) for i in (1, 2))
    assert Lambda((x1, x2), x1 + x2)(x, y) == x + y

    assert Lambda((x, y), x + y).nargs == FiniteSet(2)

    p = x, y, z, t
    assert Lambda(p, t*(x + y + z))(*p) == t * (x + y + z)

    eq = Lambda(x, 2*x) + Lambda(y, 2*y)
    # Lambdas in different bound variables are not equal ...
    assert eq != 2*Lambda(x, 2*x)
    # ... but they are equal after as_dummy()
    assert eq.as_dummy() == 2*Lambda(x, 2*x).as_dummy()
    assert Lambda(x, 2*x) not in [ Lambda(x, x) ]
    # invalid signatures: a non-symbol, repeated symbols, numbers in nesting
    raises(BadSignatureError, lambda: Lambda(1, x))
    assert Lambda(x, 1)(1) is S.One

    raises(BadSignatureError, lambda: Lambda((x, x), x + 2))
    raises(BadSignatureError, lambda: Lambda(((x, x), y), x))
    raises(BadSignatureError, lambda: Lambda(((y, x), x), x))
    raises(BadSignatureError, lambda: Lambda(((y, 1), 2), x))

    # a list signature is deprecated but still accepted
    with warns_deprecated_sympy():
        assert Lambda([x, y], x+y) == Lambda((x, y), x+y)

    # nested tuple signatures unpack nested tuple arguments
    flam = Lambda(((x, y),), x + y)
    assert flam((2, 3)) == 5
    flam = Lambda(((x, y), z), x + y + z)
    assert flam((2, 3), 1) == 6
    flam = Lambda((((x, y), z),), x + y + z)
    assert flam(((2, 3), 1)) == 6
    raises(BadArgumentsError, lambda: flam(1, 2, 3))
    flam = Lambda( (x,), (x, x))
    assert flam(1,) == (1, 1)
    assert flam((1,)) == ((1,), (1,))
    flam = Lambda( ((x,),), (x, x))
    raises(BadArgumentsError, lambda: flam(1))
    assert flam((1,)) == (1, 1)

    # Previously TypeError was raised so this is potentially needed for
    # backwards compatibility.
    assert issubclass(BadSignatureError, TypeError)
    assert issubclass(BadArgumentsError, TypeError)

    # These are tested to see they don't raise:
    hash(Lambda(x, 2*x))
    hash(Lambda(x, x))  # IdentityFunction subclass
def test_IdentityFunction():
    """One-variable identity Lambdas are the ``S.IdentityFunction`` singleton."""
    assert Lambda(x, x) is Lambda(y, y) is S.IdentityFunction
    assert Lambda(x, 2*x) is not S.IdentityFunction
    # returning one of several arguments is not the identity
    assert Lambda((x, y), x) is not S.IdentityFunction
def test_Lambda_symbols():
    """Only symbols not bound by the signature appear in ``free_symbols``."""
    assert Lambda(x, 2*x).free_symbols == set()
    assert Lambda(x, x*y).free_symbols == {y}
    assert Lambda((), 42).free_symbols == set()
    assert Lambda((), x*y).free_symbols == {x,y}
def test_functionclas_symbols():
    """An undefined function class (not an application of it) has no free symbols."""
    assert f.free_symbols == set()
def test_Lambda_arguments():
    """Calling a Lambda with the wrong number of arguments raises TypeError."""
    raises(TypeError, lambda: Lambda(x, 2*x)(x, y))
    raises(TypeError, lambda: Lambda((x, y), x + y)(x))
    raises(TypeError, lambda: Lambda((), 42)(x))
def test_Lambda_equality():
    """Lambda equality is structural, including the names of bound variables."""
    assert Lambda((x, y), 2*x) == Lambda((x, y), 2*x)
    # these, of course, should never be equal
    assert Lambda(x, 2*x) != Lambda((x, y), 2*x)
    assert Lambda(x, 2*x) != 2*x
    # But it is tempting to want expressions that differ only
    # in bound symbols to compare the same. But this is not what
    # Python's `==` is intended to do; two objects that compare
    # as equal are indistinguishable and cache to the same
    # value. We wouldn't want expressions that are
    # mathematically the same but written in different variables to be
    # interchanged; otherwise, what is the point of allowing for different
    # variable names?
    assert Lambda(x, 2*x) != Lambda(y, 2*y)
  274. def test_Subs():
  275. assert Subs(1, (), ()) is S.One
  276. # check null subs influence on hashing
  277. assert Subs(x, y, z) != Subs(x, y, 1)
  278. # neutral subs works
  279. assert Subs(x, x, 1).subs(x, y).has(y)
  280. # self mapping var/point
  281. assert Subs(Derivative(f(x), (x, 2)), x, x).doit() == f(x).diff(x, x)
  282. assert Subs(x, x, 0).has(x) # it's a structural answer
  283. assert not Subs(x, x, 0).free_symbols
  284. assert Subs(Subs(x + y, x, 2), y, 1) == Subs(x + y, (x, y), (2, 1))
  285. assert Subs(x, (x,), (0,)) == Subs(x, x, 0)
  286. assert Subs(x, x, 0) == Subs(y, y, 0)
  287. assert Subs(x, x, 0).subs(x, 1) == Subs(x, x, 0)
  288. assert Subs(y, x, 0).subs(y, 1) == Subs(1, x, 0)
  289. assert Subs(f(x), x, 0).doit() == f(0)
  290. assert Subs(f(x**2), x**2, 0).doit() == f(0)
  291. assert Subs(f(x, y, z), (x, y, z), (0, 1, 1)) != \
  292. Subs(f(x, y, z), (x, y, z), (0, 0, 1))
  293. assert Subs(x, y, 2).subs(x, y).doit() == 2
  294. assert Subs(f(x, y), (x, y, z), (0, 1, 1)) != \
  295. Subs(f(x, y) + z, (x, y, z), (0, 1, 0))
  296. assert Subs(f(x, y), (x, y), (0, 1)).doit() == f(0, 1)
  297. assert Subs(Subs(f(x, y), x, 0), y, 1).doit() == f(0, 1)
  298. raises(ValueError, lambda: Subs(f(x, y), (x, y), (0, 0, 1)))
  299. raises(ValueError, lambda: Subs(f(x, y), (x, x, y), (0, 0, 1)))
  300. assert len(Subs(f(x, y), (x, y), (0, 1)).variables) == 2
  301. assert Subs(f(x, y), (x, y), (0, 1)).point == Tuple(0, 1)
  302. assert Subs(f(x), x, 0) == Subs(f(y), y, 0)
  303. assert Subs(f(x, y), (x, y), (0, 1)) == Subs(f(x, y), (y, x), (1, 0))
  304. assert Subs(f(x)*y, (x, y), (0, 1)) == Subs(f(y)*x, (y, x), (0, 1))
  305. assert Subs(f(x)*y, (x, y), (1, 1)) == Subs(f(y)*x, (x, y), (1, 1))
  306. assert Subs(f(x), x, 0).subs(x, 1).doit() == f(0)
  307. assert Subs(f(x), x, y).subs(y, 0) == Subs(f(x), x, 0)
  308. assert Subs(y*f(x), x, y).subs(y, 2) == Subs(2*f(x), x, 2)
  309. assert (2 * Subs(f(x), x, 0)).subs(Subs(f(x), x, 0), y) == 2*y
  310. assert Subs(f(x), x, 0).free_symbols == set()
  311. assert Subs(f(x, y), x, z).free_symbols == {y, z}
  312. assert Subs(f(x).diff(x), x, 0).doit(), Subs(f(x).diff(x), x, 0)
  313. assert Subs(1 + f(x).diff(x), x, 0).doit(), 1 + Subs(f(x).diff(x), x, 0)
  314. assert Subs(y*f(x, y).diff(x), (x, y), (0, 2)).doit() == \
  315. 2*Subs(Derivative(f(x, 2), x), x, 0)
  316. assert Subs(y**2*f(x), x, 0).diff(y) == 2*y*f(0)
  317. e = Subs(y**2*f(x), x, y)
  318. assert e.diff(y) == e.doit().diff(y) == y**2*Derivative(f(y), y) + 2*y*f(y)
  319. assert Subs(f(x), x, 0) + Subs(f(x), x, 0) == 2*Subs(f(x), x, 0)
  320. e1 = Subs(z*f(x), x, 1)
  321. e2 = Subs(z*f(y), y, 1)
  322. assert e1 + e2 == 2*e1
  323. assert e1.__hash__() == e2.__hash__()
  324. assert Subs(z*f(x + 1), x, 1) not in [ e1, e2 ]
  325. assert Derivative(f(x), x).subs(x, g(x)) == Derivative(f(g(x)), g(x))
  326. assert Derivative(f(x), x).subs(x, x + y) == Subs(Derivative(f(x), x),
  327. x, x + y)
  328. assert Subs(f(x)*cos(y) + z, (x, y), (0, pi/3)).n(2) == \
  329. Subs(f(x)*cos(y) + z, (x, y), (0, pi/3)).evalf(2) == \
  330. z + Rational('1/2').n(2)*f(0)
  331. assert f(x).diff(x).subs(x, 0).subs(x, y) == f(x).diff(x).subs(x, 0)
  332. assert (x*f(x).diff(x).subs(x, 0)).subs(x, y) == y*f(x).diff(x).subs(x, 0)
  333. assert Subs(Derivative(g(x)**2, g(x), x), g(x), exp(x)
  334. ).doit() == 2*exp(x)
  335. assert Subs(Derivative(g(x)**2, g(x), x), g(x), exp(x)
  336. ).doit(deep=False) == 2*Derivative(exp(x), x)
  337. assert Derivative(f(x, g(x)), x).doit() == Derivative(
  338. f(x, g(x)), g(x))*Derivative(g(x), x) + Subs(Derivative(
  339. f(y, g(x)), y), y, x)
def test_doitdoit():
    """A second ``doit`` on an already-evaluated chain-rule derivative changes nothing."""
    done = Derivative(f(x, g(x)), x, g(x)).doit()
    assert done == done.doit()
@XFAIL
def test_Subs2():
    """Subs over a non-symbol variable (``x**2``) is expected to fail."""
    # this reflects a limitation of subs(), probably won't fix
    assert Subs(f(x), x**2, x).doit() == f(sqrt(x))
def test_expand_function():
    """``expand`` with no hints and with the ``complex`` and ``modulus`` hints."""
    assert expand(x + y) == x + y
    assert expand(x + y, complex=True) == I*im(x) + I*im(y) + re(x) + re(y)
    # coefficients are reduced mod 11, so only the outer binomial terms remain
    assert expand((x + y)**11, modulus=11) == x**11 + y**11
  351. def test_function_comparable():
  352. assert sin(x).is_comparable is False
  353. assert cos(x).is_comparable is False
  354. assert sin(Float('0.1')).is_comparable is True
  355. assert cos(Float('0.1')).is_comparable is True
  356. assert sin(E).is_comparable is True
  357. assert cos(E).is_comparable is True
  358. assert sin(Rational(1, 3)).is_comparable is True
  359. assert cos(Rational(1, 3)).is_comparable is True
  360. def test_function_comparable_infinities():
  361. assert sin(oo).is_comparable is False
  362. assert sin(-oo).is_comparable is False
  363. assert sin(zoo).is_comparable is False
  364. assert sin(nan).is_comparable is False
def test_deriv1():
    """The chain rule with non-symbol arguments yields ``Subs`` evaluated at the argument."""
    # These all require derivatives evaluated at a point (issue 4719) to work.
    # See issue 4624
    assert f(2*x).diff(x) == 2*Subs(Derivative(f(x), x), x, 2*x)
    assert (f(x)**3).diff(x) == 3*f(x)**2*f(x).diff(x)
    assert (f(2*x)**3).diff(x) == 6*f(2*x)**2*Subs(
        Derivative(f(x), x), x, 2*x)

    assert f(2 + x).diff(x) == Subs(Derivative(f(x), x), x, x + 2)
    assert f(2 + 3*x).diff(x) == 3*Subs(
        Derivative(f(x), x), x, 3*x + 2)
    assert f(3*sin(x)).diff(x) == 3*cos(x)*Subs(
        Derivative(f(x), x), x, 3*sin(x))

    # See issue 8510
    assert f(x, x + z).diff(x) == (
        Subs(Derivative(f(y, x + z), y), y, x) +
        Subs(Derivative(f(x, y), y), y, x + z))
    assert f(x, x**2).diff(x) == (
        2*x*Subs(Derivative(f(x, y), y), y, x**2) +
        Subs(Derivative(f(y, x**2), y), y, x))
    # but Subs is not always necessary
    assert f(x, g(y)).diff(g(y)) == Derivative(f(x, g(y)), g(y))
def test_deriv2():
    """``evaluate=False`` returns an unevaluated Derivative, for both method and function forms."""
    assert (x**3).diff(x) == 3*x**2
    assert (x**3).diff(x, evaluate=False) != 3*x**2
    assert (x**3).diff(x, evaluate=False) == Derivative(x**3, x)

    assert diff(x**3, x) == 3*x**2
    assert diff(x**3, x, evaluate=False) != 3*x**2
    assert diff(x**3, x, evaluate=False) == Derivative(x**3, x)
def test_func_deriv():
    """Mixed partials of an undefined function agree regardless of order."""
    assert f(x).diff(x) == Derivative(f(x), x)
    # issue 4534
    assert f(x, y).diff(x, y) - f(x, y).diff(y, x) == 0
    # the unevaluated form records (variable, count) pairs in the order given
    assert Derivative(f(x, y), x, y).args[1:] == ((x, 1), (y, 1))
    assert Derivative(f(x, y), y, x).args[1:] == ((y, 1), (x, 1))
    assert (Derivative(f(x, y), x, y) - Derivative(f(x, y), y, x)).doit() == 0
def test_suppressed_evaluation():
    """``evaluate=False`` keeps ``sin(0)`` as an unevaluated sin node."""
    a = sin(0, evaluate=False)
    assert a != 0
    assert a.func is sin
    assert a.args == (0,)
def test_function_evalf():
    """``evalf`` of sin/exp/log/cos at real and complex points matches reference values."""
    def eq(a, b, eps):
        # absolute-difference tolerance check
        return abs(a - b) < eps
    assert eq(sin(1).evalf(15), Float("0.841470984807897"), 1e-13)
    assert eq(
        sin(2).evalf(25), Float("0.9092974268256816953960199", 25), 1e-23)
    assert eq(sin(1 + I).evalf(
        15), Float("1.29845758141598") + Float("0.634963914784736")*I, 1e-13)
    assert eq(exp(1 + I).evalf(15), Float(
        "1.46869393991588") + Float("2.28735528717884239")*I, 1e-13)
    assert eq(exp(-0.5 + 1.5*I).evalf(15), Float(
        "0.0429042815937374") + Float("0.605011292285002")*I, 1e-13)
    assert eq(log(pi + sqrt(2)*I).evalf(
        15), Float("1.23699044022052") + Float("0.422985442737893")*I, 1e-13)
    assert eq(cos(100).evalf(15), Float("0.86231887228768"), 1e-13)
def test_extensibility_eval():
    """A user-defined ``eval`` result becomes the value of the call."""
    class MyFunc(Function):
        @classmethod
        def eval(cls, *args):
            return (0, 0, 0)
    assert MyFunc(0) == (0, 0, 0)
@_both_exp_pow
def test_function_non_commutative():
    """Functions of a non-commutative symbol are neither commutative nor complex."""
    x = Symbol('x', commutative=False)
    assert f(x).is_commutative is False
    assert sin(x).is_commutative is False
    assert exp(x).is_commutative is False
    assert log(x).is_commutative is False
    assert f(x).is_complex is False
    assert sin(x).is_complex is False
    assert exp(x).is_complex is False
    assert log(x).is_complex is False
def test_function_complex():
    """Commutativity and complexity assumptions for functions of a complex symbol."""
    x = Symbol('x', complex=True)
    xzf = Symbol('x', complex=True, zero=False)
    assert f(x).is_commutative is True
    assert sin(x).is_commutative is True
    assert exp(x).is_commutative is True
    assert log(x).is_commutative is True
    # nothing is known about an undefined function's values
    assert f(x).is_complex is None
    assert sin(x).is_complex is True
    assert exp(x).is_complex is True
    # log is only known to be complex once x is known to be nonzero
    assert log(x).is_complex is None
    assert log(xzf).is_complex is True
def test_function__eval_nseries():
    """Truncated expansions via ``_eval_nseries``, ``nseries`` and ``series``."""
    n = Symbol('n')

    assert sin(x)._eval_nseries(x, 2, None) == x + O(x**2)
    assert sin(x + 1)._eval_nseries(x, 2, None) == x*cos(1) + sin(1) + O(x**2)
    assert sin(pi*(1 - x))._eval_nseries(x, 2, None) == pi*x + O(x**2)
    assert acos(1 - x**2)._eval_nseries(x, 2, None) == sqrt(2)*sqrt(x**2) + O(x**2)
    assert polygamma(n, x + 1)._eval_nseries(x, 2, None) == \
        polygamma(n, 1) + polygamma(n + 1, 1)*x + O(x**2)
    # sin(1/x) cannot be expanded about x = 0
    raises(PoleError, lambda: sin(1/x)._eval_nseries(x, 2, None))
    assert acos(1 - x)._eval_nseries(x, 2, None) == sqrt(2)*sqrt(x) + sqrt(2)*x**(S(3)/2)/12 + O(x**2)
    assert acos(1 + x)._eval_nseries(x, 2, None) == sqrt(2)*sqrt(-x) + sqrt(2)*(-x)**(S(3)/2)/12 + O(x**2)
    assert loggamma(1/x)._eval_nseries(x, 0, None) == \
        log(x)/2 - log(x)/x - 1/x + O(1, x)
    assert loggamma(log(1/x)).nseries(x, n=1, logx=y) == loggamma(-y)

    # issue 6725:
    assert expint(Rational(3, 2), -x)._eval_nseries(x, 5, None) == \
        2 - 2*sqrt(pi)*sqrt(-x) - 2*x + x**2 + x**3/3 + x**4/12 + 4*I*x**(S(3)/2)*sqrt(-x)/3 + \
        2*I*x**(S(5)/2)*sqrt(-x)/5 + 2*I*x**(S(7)/2)*sqrt(-x)/21 + O(x**5)
    assert sin(sqrt(x))._eval_nseries(x, 3, None) == \
        sqrt(x) - x**Rational(3, 2)/6 + x**Rational(5, 2)/120 + O(x**3)

    # issue 19065:
    # the series of f(., y) introduces an auxiliary symbol named 'xi', or
    # 'xi0' when a symbol named 'xi' is already present
    s1 = f(x,y).series(y, n=2)
    assert {i.name for i in s1.atoms(Symbol)} == {'x', 'xi', 'y'}
    xi = Symbol('xi')
    s2 = f(xi, y).series(y, n=2)
    assert {i.name for i in s2.atoms(Symbol)} == {'xi', 'xi0', 'y'}
def test_doit():
    """``Derivative.doit`` of a Sum, with and without evaluating the Sum."""
    n = Symbol('n', integer=True)
    f = Sum(2 * n * x, (n, 1, 3))
    d = Derivative(f, x)
    # 2*(1 + 2 + 3) == 12
    assert d.doit() == 12
    # deep=False differentiates the summand but leaves the Sum unevaluated
    assert d.doit(deep=False) == Sum(2*n, (n, 1, 3))
  481. def test_evalf_default():
  482. from sympy.functions.special.gamma_functions import polygamma
  483. assert type(sin(4.0)) == Float
  484. assert type(re(sin(I + 1.0))) == Float
  485. assert type(im(sin(I + 1.0))) == Float
  486. assert type(sin(4)) == sin
  487. assert type(polygamma(2.0, 4.0)) == Float
  488. assert type(sin(Rational(1, 4))) == sin
  489. def test_issue_5399():
  490. args = [x, y, S(2), S.Half]
  491. def ok(a):
  492. """Return True if the input args for diff are ok"""
  493. if not a:
  494. return False
  495. if a[0].is_Symbol is False:
  496. return False
  497. s_at = [i for i in range(len(a)) if a[i].is_Symbol]
  498. n_at = [i for i in range(len(a)) if not a[i].is_Symbol]
  499. # every symbol is followed by symbol or int
  500. # every number is followed by a symbol
  501. return (all(a[i + 1].is_Symbol or a[i + 1].is_Integer
  502. for i in s_at if i + 1 < len(a)) and
  503. all(a[i + 1].is_Symbol
  504. for i in n_at if i + 1 < len(a)))
  505. eq = x**10*y**8
  506. for a in subsets(args):
  507. for v in variations(a, len(a)):
  508. if ok(v):
  509. eq.diff(*v) # does not raise
  510. else:
  511. raises(ValueError, lambda: eq.diff(*v))
def test_derivative_numerically():
    """``doit_numerically`` of d/dx sin(x) agrees with cos at a random sample point."""
    z0 = x._random()
    assert abs(Derivative(sin(x), x).doit_numerically(z0) - cos(z0)) < 1e-15
def test_fdiff_argument_index_error():
    """If ``fdiff`` raises ArgumentIndexError, ``diff`` returns an unevaluated Derivative."""
    from sympy.core.function import ArgumentIndexError

    class myfunc(Function):
        nargs = 1  # define since there is no eval routine

        def fdiff(self, idx):
            raise ArgumentIndexError
    mf = myfunc(x)
    assert mf.diff(x) == Derivative(mf, x)
    # nargs = 1, so a two-argument call is rejected
    raises(TypeError, lambda: myfunc(x, x))
def test_deriv_wrt_function():
    """Differentiate w.r.t. t and w.r.t. the function f(t) and its derivatives."""
    # x and y are rebound locally to functions of t
    x = f(t)
    xd = diff(x, t)
    xdd = diff(xd, t)
    y = g(t)
    yd = diff(y, t)

    assert diff(x, t) == xd
    assert diff(2 * x + 4, t) == 2 * xd
    assert diff(2 * x + 4 + y, t) == 2 * xd + yd
    assert diff(2 * x + 4 + y * x, t) == 2 * xd + x * yd + xd * y
    assert diff(2 * x + 4 + y * x, x) == 2 + y
    assert (diff(4 * x**2 + 3 * x + x * y, t) == 3 * xd + x * yd + xd * y +
            8 * x * xd)
    assert (diff(4 * x**2 + 3 * xd + x * y, t) == 3 * xdd + x * yd + xd * y +
            8 * x * xd)
    # f(t), f'(t) and f''(t) are independent differentiation variables
    assert diff(4 * x**2 + 3 * xd + x * y, xd) == 3
    assert diff(4 * x**2 + 3 * xd + x * y, xdd) == 0
    assert diff(sin(x), t) == xd * cos(x)
    assert diff(exp(x), t) == xd * exp(x)
    assert diff(sqrt(x), t) == xd / (2 * sqrt(x))
def test_diff_wrt_value():
    """``_diff_wrt`` flags which objects may be used as differentiation variables."""
    assert Expr()._diff_wrt is False
    assert x._diff_wrt is True
    assert f(x)._diff_wrt is True
    assert Derivative(f(x), x)._diff_wrt is True
    # a Derivative of an explicit expression is not a valid variable
    assert Derivative(x**2, x)._diff_wrt is False
def test_diff_wrt():
    """Differentiation w.r.t. f(x) and its derivatives, and the chain rule."""
    fx = f(x)
    dfx = diff(f(x), x)
    ddfx = diff(f(x), x, x)

    assert diff(sin(fx) + fx**2, fx) == cos(fx) + 2*fx
    assert diff(sin(dfx) + dfx**2, dfx) == cos(dfx) + 2*dfx
    assert diff(sin(ddfx) + ddfx**2, ddfx) == cos(ddfx) + 2*ddfx
    # f(x) and its derivatives are treated as independent of each other
    assert diff(fx**2, dfx) == 0
    assert diff(fx**2, ddfx) == 0
    assert diff(dfx**2, fx) == 0
    assert diff(dfx**2, ddfx) == 0
    assert diff(ddfx**2, dfx) == 0

    assert diff(fx*dfx*ddfx, fx) == dfx*ddfx
    assert diff(fx*dfx*ddfx, dfx) == fx*ddfx
    assert diff(fx*dfx*ddfx, ddfx) == fx*dfx

    assert diff(f(x), x).diff(f(x)) == 0
    assert (sin(f(x)) - cos(diff(f(x), x))).diff(f(x)) == cos(f(x))

    assert diff(sin(fx), fx, x) == diff(sin(fx), x, fx)

    # Chain rule cases
    assert f(g(x)).diff(x) == \
        Derivative(g(x), x)*Derivative(f(g(x)), g(x))
    assert diff(f(g(x), h(y)), x) == \
        Derivative(g(x), x)*Derivative(f(g(x), h(y)), g(x))
    assert diff(f(g(x), h(x)), x) == (
        Derivative(f(g(x), h(x)), g(x))*Derivative(g(x), x) +
        Derivative(f(g(x), h(x)), h(x))*Derivative(h(x), x))
    assert f(
        sin(x)).diff(x) == cos(x)*Subs(Derivative(f(x), x), x, sin(x))

    assert diff(f(g(x)), g(x)) == Derivative(f(g(x)), g(x))
def test_diff_wrt_func_subs():
    """Substituting a Lambda for g in a chain-rule derivative, then ``doit``."""
    assert f(g(x)).diff(x).subs(g, Lambda(x, 2*x)).doit() == f(2*x).diff(x)
def test_subs_in_derivative():
    """``subs`` on Derivatives: direct replacement, wrapping in Subs, and function substitution."""
    expr = sin(x*exp(y))
    u = Function('u')
    v = Function('v')
    assert Derivative(expr, y).subs(expr, y) == Derivative(y, y)
    assert Derivative(expr, y).subs(y, x).doit() == \
        Derivative(expr, y).doit().subs(y, x)
    assert Derivative(f(x, y), y).subs(y, x) == Subs(Derivative(f(x, y), y), y, x)
    assert Derivative(f(x, y), y).subs(x, y) == Subs(Derivative(f(x, y), y), x, y)
    assert Derivative(f(x, y), y).subs(y, g(x, y)) == Subs(Derivative(f(x, y), y), y, g(x, y)).doit()
    assert Derivative(f(x, y), y).subs(x, g(x, y)) == Subs(Derivative(f(x, y), y), x, g(x, y))
    assert Derivative(f(x, y), g(y)).subs(x, g(x, y)) == Derivative(f(g(x, y), y), g(y))
    assert Derivative(f(u(x), h(y)), h(y)).subs(h(y), g(x, y)) == \
        Subs(Derivative(f(u(x), h(y)), h(y)), h(y), g(x, y)).doit()
    assert Derivative(f(x, y), y).subs(y, z) == Derivative(f(x, z), z)
    assert Derivative(f(x, y), y).subs(y, g(y)) == Derivative(f(x, g(y)), g(y))
    assert Derivative(f(g(x), h(y)), h(y)).subs(h(y), u(y)) == \
        Derivative(f(g(x), u(y)), u(y))
    assert Derivative(f(x, f(x, x)), f(x, x)).subs(
        f, Lambda((x, y), x + y)) == Subs(
        Derivative(z + x, z), z, 2*x)
    assert Subs(Derivative(f(f(x)), x), f, cos).doit() == sin(x)*sin(cos(x))
    assert Subs(Derivative(f(f(x)), f(x)), f, cos).doit() == -sin(cos(x))
    # Issue 13791. No comparison (it's a long formula) but this used to raise an exception.
    assert isinstance(v(x, y, u(x, y)).diff(y).diff(x).diff(y), Expr)
    # This is also related to issues 13791 and 13795; issue 15190
    F = Lambda((x, y), exp(2*x + 3*y))
    abstract = f(x, f(x, x)).diff(x, 2)
    concrete = F(x, F(x, x)).diff(x, 2)
    assert (abstract.subs(f, F).doit() - concrete).simplify() == 0
    # don't introduce a new symbol if not necessary
    assert x in f(x).diff(x).subs(x, 0).atoms()
    # case (4)
    assert Derivative(f(x,f(x,y)), x, y).subs(x, g(y)
        ) == Subs(Derivative(f(x, f(x, y)), x, y), x, g(y))

    assert Derivative(f(x, x), x).subs(x, 0
        ) == Subs(Derivative(f(x, x), x), x, 0)
    # issue 15194
    assert Derivative(f(y, g(x)), (x, z)).subs(z, x
        ) == Derivative(f(y, g(x)), (x, x))

    df = f(x).diff(x)
    assert df.subs(df, 1) is S.One
    assert df.diff(df) is S.One
    dxy = Derivative(f(x, y), x, y)
    dyx = Derivative(f(x, y), y, x)
    # a mixed partial is matched regardless of the variable order
    assert dxy.subs(Derivative(f(x, y), y, x), 1) is S.One
    assert dxy.diff(dyx) is S.One
    assert Derivative(f(x, y), x, 2, y, 3).subs(
        dyx, g(x, y)) == Derivative(g(x, y), x, 1, y, 2)
    assert Derivative(f(x, x - y), y).subs(x, x + y) == Subs(
        Derivative(f(x, x - y), y), x, x + y)
def test_diff_wrt_not_allowed():
    """Differentiating w.r.t. an invalid variable raises ValueError unless the count is 0."""
    # issue 7027 included
    for wrt in (
            cos(x), re(x), x**2, x*y, 1 + x,
            Derivative(cos(x), x), Derivative(f(f(x)), x)):
        raises(ValueError, lambda: diff(f(x), wrt))
    # if we don't differentiate wrt then don't raise error
    assert diff(exp(x*y), x*y, 0) == exp(x*y)
def test_diff_wrt_intlike():
    """An object that only defines ``__int__`` is accepted as a derivative count."""
    class Two:
        def __int__(self):
            return 2

    assert cos(x).diff(x, Two()) == -cos(x)
def test_klein_gordon_lagrangian():
    """The Euler-Lagrange equation of the Klein-Gordon Lagrangian matches the field equation."""
    m = Symbol('m')
    phi = f(x, t)

    L = -(diff(phi, t)**2 - diff(phi, x)**2 - m**2*phi**2)/2
    eqna = Eq(
        diff(L, phi) - diff(L, diff(phi, x), x) - diff(L, diff(phi, t), t), 0)
    eqnb = Eq(diff(phi, t, t) - diff(phi, x, x) + m**2*phi, 0)
    assert eqna == eqnb
def test_sho_lagrangian():
    """Simple harmonic oscillator Lagrangian: differentiate w.r.t. x = f(t) and its derivative."""
    m = Symbol('m')
    k = Symbol('k')
    x = f(t)

    L = m*diff(x, t)**2/2 - k*x**2/2
    eqna = Eq(diff(L, x), diff(L, diff(x, t), t))
    eqnb = Eq(-k*x, m*diff(x, t, t))
    assert eqna == eqnb

    assert diff(L, x, t) == diff(L, t, x)
    assert diff(L, diff(x, t), t) == m*diff(x, t, 2)
    assert diff(L, t, diff(x, t)) == -k*x + m*diff(x, t, 2)
def test_straight_line():
    """Differentiate sqrt(1 + f'(x)**2) w.r.t. f(x) and f'(x)."""
    F = f(x)
    Fd = F.diff(x)
    L = sqrt(1 + Fd**2)
    # L depends only on the derivative, not on F itself
    assert diff(L, F) == 0
    assert diff(L, Fd) == Fd/sqrt(1 + Fd**2)
  670. def test_sort_variable():
  671. vsort = Derivative._sort_variable_count
  672. def vsort0(*v, reverse=False):
  673. return [i[0] for i in vsort([(i, 0) for i in (
  674. reversed(v) if reverse else v)])]
  675. for R in range(2):
  676. assert vsort0(y, x, reverse=R) == [x, y]
  677. assert vsort0(f(x), x, reverse=R) == [x, f(x)]
  678. assert vsort0(f(y), f(x), reverse=R) == [f(x), f(y)]
  679. assert vsort0(g(x), f(y), reverse=R) == [f(y), g(x)]
  680. assert vsort0(f(x, y), f(x), reverse=R) == [f(x), f(x, y)]
  681. fx = f(x).diff(x)
  682. assert vsort0(fx, y, reverse=R) == [y, fx]
  683. fy = f(y).diff(y)
  684. assert vsort0(fy, fx, reverse=R) == [fx, fy]
  685. fxx = fx.diff(x)
  686. assert vsort0(fxx, fx, reverse=R) == [fx, fxx]
  687. assert vsort0(Basic(x), f(x), reverse=R) == [f(x), Basic(x)]
  688. assert vsort0(Basic(y), Basic(x), reverse=R) == [Basic(x), Basic(y)]
  689. assert vsort0(Basic(y, z), Basic(x), reverse=R) == [
  690. Basic(x), Basic(y, z)]
  691. assert vsort0(fx, x, reverse=R) == [
  692. x, fx] if R else [fx, x]
  693. assert vsort0(Basic(x), x, reverse=R) == [
  694. x, Basic(x)] if R else [Basic(x), x]
  695. assert vsort0(Basic(f(x)), f(x), reverse=R) == [
  696. f(x), Basic(f(x))] if R else [Basic(f(x)), f(x)]
  697. assert vsort0(Basic(x, z), Basic(x), reverse=R) == [
  698. Basic(x), Basic(x, z)] if R else [Basic(x, z), Basic(x)]
  699. assert vsort([]) == []
  700. assert _aresame(vsort([(x, 1)]), [Tuple(x, 1)])
  701. assert vsort([(x, y), (x, z)]) == [(x, y + z)]
  702. assert vsort([(y, 1), (x, 1 + y)]) == [(x, 1 + y), (y, 1)]
  703. # coverage complete; legacy tests below
  704. assert vsort([(x, 3), (y, 2), (z, 1)]) == [(x, 3), (y, 2), (z, 1)]
  705. assert vsort([(h(x), 1), (g(x), 1), (f(x), 1)]) == [
  706. (f(x), 1), (g(x), 1), (h(x), 1)]
  707. assert vsort([(z, 1), (y, 2), (x, 3), (h(x), 1), (g(x), 1),
  708. (f(x), 1)]) == [(x, 3), (y, 2), (z, 1), (f(x), 1), (g(x), 1),
  709. (h(x), 1)]
  710. assert vsort([(x, 1), (f(x), 1), (y, 1), (f(y), 1)]) == [(x, 1),
  711. (y, 1), (f(x), 1), (f(y), 1)]
  712. assert vsort([(y, 1), (x, 2), (g(x), 1), (f(x), 1), (z, 1),
  713. (h(x), 1), (y, 2), (x, 1)]) == [(x, 3), (y, 3), (z, 1),
  714. (f(x), 1), (g(x), 1), (h(x), 1)]
  715. assert vsort([(z, 1), (y, 1), (f(x), 1), (x, 1), (f(x), 1),
  716. (g(x), 1)]) == [(x, 1), (y, 1), (z, 1), (f(x), 2), (g(x), 1)]
  717. assert vsort([(z, 1), (y, 2), (f(x), 1), (x, 2), (f(x), 2),
  718. (g(x), 1), (z, 2), (z, 1), (y, 1), (x, 1)]) == [(x, 3), (y, 3),
  719. (z, 4), (f(x), 3), (g(x), 1)]
  720. assert vsort(((y, 2), (x, 1), (y, 1), (x, 1))) == [(x, 2), (y, 3)]
  721. assert isinstance(vsort([(x, 3), (y, 2), (z, 1)])[0], Tuple)
  722. assert vsort([(x, 1), (f(x), 1), (x, 1)]) == [(x, 2), (f(x), 1)]
  723. assert vsort([(y, 2), (x, 3), (z, 1)]) == [(x, 3), (y, 2), (z, 1)]
  724. assert vsort([(h(y), 1), (g(x), 1), (f(x), 1)]) == [
  725. (f(x), 1), (g(x), 1), (h(y), 1)]
  726. assert vsort([(x, 1), (y, 1), (x, 1)]) == [(x, 2), (y, 1)]
  727. assert vsort([(f(x), 1), (f(y), 1), (f(x), 1)]) == [
  728. (f(x), 2), (f(y), 1)]
  729. dfx = f(x).diff(x)
  730. self = [(dfx, 1), (x, 1)]
  731. assert vsort(self) == self
  732. assert vsort([
  733. (dfx, 1), (y, 1), (f(x), 1), (x, 1), (f(y), 1), (x, 1)]) == [
  734. (y, 1), (f(x), 1), (f(y), 1), (dfx, 1), (x, 2)]
  735. dfy = f(y).diff(y)
  736. assert vsort([(dfy, 1), (dfx, 1)]) == [(dfx, 1), (dfy, 1)]
  737. d2fx = dfx.diff(x)
  738. assert vsort([(d2fx, 1), (dfx, 1)]) == [(dfx, 1), (d2fx, 1)]
  739. def test_multiple_derivative():
  740. # Issue #15007
  741. assert f(x, y).diff(y, y, x, y, x
  742. ) == Derivative(f(x, y), (x, 2), (y, 3))
  743. def test_unhandled():
  744. class MyExpr(Expr):
  745. def _eval_derivative(self, s):
  746. if not s.name.startswith('xi'):
  747. return self
  748. else:
  749. return None
  750. eq = MyExpr(f(x), y, z)
  751. assert diff(eq, x, y, f(x), z) == Derivative(eq, f(x))
  752. assert diff(eq, f(x), x) == Derivative(eq, f(x))
  753. assert f(x, y).diff(x,(y, z)) == Derivative(f(x, y), x, (y, z))
  754. assert f(x, y).diff(x,(y, 0)) == Derivative(f(x, y), x)
  755. def test_nfloat():
  756. from sympy.core.basic import _aresame
  757. from sympy.polys.rootoftools import rootof
  758. x = Symbol("x")
  759. eq = x**Rational(4, 3) + 4*x**(S.One/3)/3
  760. assert _aresame(nfloat(eq), x**Rational(4, 3) + (4.0/3)*x**(S.One/3))
  761. assert _aresame(nfloat(eq, exponent=True), x**(4.0/3) + (4.0/3)*x**(1.0/3))
  762. eq = x**Rational(4, 3) + 4*x**(x/3)/3
  763. assert _aresame(nfloat(eq), x**Rational(4, 3) + (4.0/3)*x**(x/3))
  764. big = 12345678901234567890
  765. # specify precision to match value used in nfloat
  766. Float_big = Float(big, 15)
  767. assert _aresame(nfloat(big), Float_big)
  768. assert _aresame(nfloat(big*x), Float_big*x)
  769. assert _aresame(nfloat(x**big, exponent=True), x**Float_big)
  770. assert nfloat(cos(x + sqrt(2))) == cos(x + nfloat(sqrt(2)))
  771. # issue 6342
  772. f = S('x*lamda + lamda**3*(x/2 + 1/2) + lamda**2 + 1/4')
  773. assert not any(a.free_symbols for a in solveset(f.subs(x, -0.139)))
  774. # issue 6632
  775. assert nfloat(-100000*sqrt(2500000001) + 5000000001) == \
  776. 9.99999999800000e-11
  777. # issue 7122
  778. eq = cos(3*x**4 + y)*rootof(x**5 + 3*x**3 + 1, 0)
  779. assert str(nfloat(eq, exponent=False, n=1)) == '-0.7*cos(3.0*x**4 + y)'
  780. # issue 10933
  781. for ti in (dict, Dict):
  782. d = ti({S.Half: S.Half})
  783. n = nfloat(d)
  784. assert isinstance(n, ti)
  785. assert _aresame(list(n.items()).pop(), (S.Half, Float(.5)))
  786. for ti in (dict, Dict):
  787. d = ti({S.Half: S.Half})
  788. n = nfloat(d, dkeys=True)
  789. assert isinstance(n, ti)
  790. assert _aresame(list(n.items()).pop(), (Float(.5), Float(.5)))
  791. d = [S.Half]
  792. n = nfloat(d)
  793. assert type(n) is list
  794. assert _aresame(n[0], Float(.5))
  795. assert _aresame(nfloat(Eq(x, S.Half)).rhs, Float(.5))
  796. assert _aresame(nfloat(S(True)), S(True))
  797. assert _aresame(nfloat(Tuple(S.Half))[0], Float(.5))
  798. assert nfloat(Eq((3 - I)**2/2 + I, 0)) == S.false
  799. # pass along kwargs
  800. assert nfloat([{S.Half: x}], dkeys=True) == [{Float(0.5): x}]
  801. # Issue 17706
  802. A = MutableMatrix([[1, 2], [3, 4]])
  803. B = MutableMatrix(
  804. [[Float('1.0', precision=53), Float('2.0', precision=53)],
  805. [Float('3.0', precision=53), Float('4.0', precision=53)]])
  806. assert _aresame(nfloat(A), B)
  807. A = ImmutableMatrix([[1, 2], [3, 4]])
  808. B = ImmutableMatrix(
  809. [[Float('1.0', precision=53), Float('2.0', precision=53)],
  810. [Float('3.0', precision=53), Float('4.0', precision=53)]])
  811. assert _aresame(nfloat(A), B)
  812. # issue 22524
  813. f = Function('f')
  814. assert not nfloat(f(2)).atoms(Float)
  815. def test_issue_7068():
  816. from sympy.abc import a, b
  817. f = Function('f')
  818. y1 = Dummy('y')
  819. y2 = Dummy('y')
  820. func1 = f(a + y1 * b)
  821. func2 = f(a + y2 * b)
  822. func1_y = func1.diff(y1)
  823. func2_y = func2.diff(y2)
  824. assert func1_y != func2_y
  825. z1 = Subs(f(a), a, y1)
  826. z2 = Subs(f(a), a, y2)
  827. assert z1 != z2
def test_issue_7231():
    """Taylor series of an undefined function about a symbolic point ``a``.

    With the default number of terms the expansion stops at
    O((x - a)**6). The series is computed twice and both results are
    compared with the same explicit expansion.
    """
    from sympy.abc import a
    ans1 = f(x).series(x, a)
    # each k-th term is (x - a)**k * f^(k)(a) / k!, written with Subs
    res = (f(a) + (-a + x)*Subs(Derivative(f(y), y), y, a) +
           (-a + x)**2*Subs(Derivative(f(y), y, y), y, a)/2 +
           (-a + x)**3*Subs(Derivative(f(y), y, y, y),
                            y, a)/6 +
           (-a + x)**4*Subs(Derivative(f(y), y, y, y, y),
                            y, a)/24 +
           (-a + x)**5*Subs(Derivative(f(y), y, y, y, y, y),
                            y, a)/120 + O((-a + x)**6, (x, a)))
    assert res == ans1
    # a repeated call must agree with the first one (presumably guarding
    # against cache-dependent results -- TODO confirm intent)
    ans2 = f(x).series(x, a)
    assert res == ans2
  842. def test_issue_7687():
  843. from sympy.core.function import Function
  844. from sympy.abc import x
  845. f = Function('f')(x)
  846. ff = Function('f')(x)
  847. match_with_cache = ff.matches(f)
  848. assert isinstance(f, type(ff))
  849. clear_cache()
  850. ff = Function('f')(x)
  851. assert isinstance(f, type(ff))
  852. assert match_with_cache == ff.matches(f)
  853. def test_issue_7688():
  854. from sympy.core.function import Function, UndefinedFunction
  855. f = Function('f') # actually an UndefinedFunction
  856. clear_cache()
  857. class A(UndefinedFunction):
  858. pass
  859. a = A('f')
  860. assert isinstance(a, type(f))
  861. def test_mexpand():
  862. from sympy.abc import x
  863. assert _mexpand(None) is None
  864. assert _mexpand(1) is S.One
  865. assert _mexpand(x*(x + 1)**2) == (x*(x + 1)**2).expand()
  866. def test_issue_8469():
  867. # This should not take forever to run
  868. N = 40
  869. def g(w, theta):
  870. return 1/(1+exp(w-theta))
  871. ws = symbols(['w%i'%i for i in range(N)])
  872. import functools
  873. expr = functools.reduce(g, ws)
  874. assert isinstance(expr, Pow)
  875. def test_issue_12996():
  876. # foo=True imitates the sort of arguments that Derivative can get
  877. # from Integral when it passes doit to the expression
  878. assert Derivative(im(x), x).doit(foo=True) == Derivative(im(x), x)
  879. def test_should_evalf():
  880. # This should not take forever to run (see #8506)
  881. assert isinstance(sin((1.0 + 1.0*I)**10000 + 1), sin)
def test_Derivative_as_finite_difference():
    """Finite-difference stencils from Derivative.as_finite_difference.

    Each assertion subtracts a hand-written reference stencil from the
    generated one and requires the simplified difference to vanish.
    Calls pass either a list of points (optionally followed by the point
    at which the derivative is approximated), a step size, or nothing.
    """
    # Central 1st derivative at gridpoint
    # ``h`` here is a real step-size Symbol; it shadows the module-level
    # function ``h`` inside this test
    x, h = symbols('x h', real=True)
    dfdx = f(x).diff(x)
    assert (dfdx.as_finite_difference([x-2, x-1, x, x+1, x+2]) -
            (S.One/12*(f(x-2)-f(x+2)) + Rational(2, 3)*(f(x+1)-f(x-1)))).simplify() == 0

    # Central 1st derivative "half-way"
    # with no argument the stencil uses the points x - 1/2 and x + 1/2
    assert (dfdx.as_finite_difference() -
            (f(x + S.Half)-f(x - S.Half))).simplify() == 0
    # a single step size h gives the points x - h/2 and x + h/2
    assert (dfdx.as_finite_difference(h) -
            (f(x + h/S(2))-f(x - h/S(2)))/h).simplify() == 0
    assert (dfdx.as_finite_difference([x - 3*h, x-h, x+h, x + 3*h]) -
            (S(9)/(8*2*h)*(f(x+h) - f(x-h)) +
             S.One/(24*2*h)*(f(x - 3*h) - f(x + 3*h)))).simplify() == 0

    # One sided 1st derivative at gridpoint
    # the second argument is the point where the derivative is approximated
    assert (dfdx.as_finite_difference([0, 1, 2], 0) -
            (Rational(-3, 2)*f(0) + 2*f(1) - f(2)/2)).simplify() == 0
    assert (dfdx.as_finite_difference([x, x+h], x) -
            (f(x+h) - f(x))/h).simplify() == 0
    assert (dfdx.as_finite_difference([x-h, x, x+h], x-h) -
            (-S(3)/(2*h)*f(x-h) + 2/h*f(x) -
             S.One/(2*h)*f(x+h))).simplify() == 0

    # One sided 1st derivative "half-way"
    assert (dfdx.as_finite_difference([x-h, x+h, x + 3*h, x + 5*h, x + 7*h])
            - 1/(2*h)*(-S(11)/(12)*f(x-h) + S(17)/(24)*f(x+h)
                       + Rational(3, 8)*f(x + 3*h) - Rational(5, 24)*f(x + 5*h)
                       + S.One/24*f(x + 7*h))).simplify() == 0

    d2fdx2 = f(x).diff(x, 2)

    # Central 2nd derivative at gridpoint
    assert (d2fdx2.as_finite_difference([x-h, x, x+h]) -
            h**-2 * (f(x-h) + f(x+h) - 2*f(x))).simplify() == 0

    assert (d2fdx2.as_finite_difference([x - 2*h, x-h, x, x+h, x + 2*h]) -
            h**-2 * (Rational(-1, 12)*(f(x - 2*h) + f(x + 2*h)) +
                     Rational(4, 3)*(f(x+h) + f(x-h)) - Rational(5, 2)*f(x))).simplify() == 0

    # Central 2nd derivative "half-way"
    assert (d2fdx2.as_finite_difference([x - 3*h, x-h, x+h, x + 3*h]) -
            (2*h)**-2 * (S.Half*(f(x - 3*h) + f(x + 3*h)) -
                         S.Half*(f(x+h) + f(x-h)))).simplify() == 0

    # One sided 2nd derivative at gridpoint
    assert (d2fdx2.as_finite_difference([x, x+h, x + 2*h, x + 3*h]) -
            h**-2 * (2*f(x) - 5*f(x+h) +
                     4*f(x+2*h) - f(x+3*h))).simplify() == 0

    # One sided 2nd derivative at "half-way"
    assert (d2fdx2.as_finite_difference([x-h, x+h, x + 3*h, x + 5*h]) -
            (2*h)**-2 * (Rational(3, 2)*f(x-h) - Rational(7, 2)*f(x+h) + Rational(5, 2)*f(x + 3*h) -
                         S.Half*f(x + 5*h))).simplify() == 0

    d3fdx3 = f(x).diff(x, 3)

    # Central 3rd derivative at gridpoint
    assert (d3fdx3.as_finite_difference() -
            (-f(x - Rational(3, 2)) + 3*f(x - S.Half) -
             3*f(x + S.Half) + f(x + Rational(3, 2)))).simplify() == 0

    assert (d3fdx3.as_finite_difference(
        [x - 3*h, x - 2*h, x-h, x, x+h, x + 2*h, x + 3*h]) -
        h**-3 * (S.One/8*(f(x - 3*h) - f(x + 3*h)) - f(x - 2*h) +
                 f(x + 2*h) + Rational(13, 8)*(f(x-h) - f(x+h)))).simplify() == 0

    # Central 3rd derivative at "half-way"
    assert (d3fdx3.as_finite_difference([x - 3*h, x-h, x+h, x + 3*h]) -
            (2*h)**-3 * (f(x + 3*h)-f(x - 3*h) +
                         3*(f(x-h)-f(x+h)))).simplify() == 0

    # One sided 3rd derivative at gridpoint
    assert (d3fdx3.as_finite_difference([x, x+h, x + 2*h, x + 3*h]) -
            h**-3 * (f(x + 3*h)-f(x) + 3*(f(x+h)-f(x + 2*h)))).simplify() == 0

    # One sided 3rd derivative at "half-way"
    assert (d3fdx3.as_finite_difference([x-h, x+h, x + 3*h, x + 5*h]) -
            (2*h)**-3 * (f(x + 5*h)-f(x-h) +
                         3*(f(x+h)-f(x + 3*h)))).simplify() == 0

    # issue 11007
    # mixed partial: with wrt=x only the x-derivative is discretized and
    # the y-derivative stays a Derivative; with no wrt both are discretized
    y = Symbol('y', real=True)
    d2fdxdy = f(x, y).diff(x, y)

    ref0 = Derivative(f(x + S.Half, y), y) - Derivative(f(x - S.Half, y), y)
    assert (d2fdxdy.as_finite_difference(wrt=x) - ref0).simplify() == 0

    half = S.Half
    xm, xp, ym, yp = x-half, x+half, y-half, y+half
    ref2 = f(xm, ym) + f(xp, yp) - f(xp, ym) - f(xm, yp)
    assert (d2fdxdy.as_finite_difference() - ref2).simplify() == 0
  957. def test_issue_11159():
  958. # Tests Application._eval_subs
  959. with _exp_is_pow(False):
  960. expr1 = E
  961. expr0 = expr1 * expr1
  962. expr1 = expr0.subs(expr1,expr0)
  963. assert expr0 == expr1
  964. with _exp_is_pow(True):
  965. expr1 = E
  966. expr0 = expr1 * expr1
  967. expr2 = expr0.subs(expr1, expr0)
  968. assert expr2 == E ** 4
  969. def test_issue_12005():
  970. e1 = Subs(Derivative(f(x), x), x, x)
  971. assert e1.diff(x) == Derivative(f(x), x, x)
  972. e2 = Subs(Derivative(f(x), x), x, x**2 + 1)
  973. assert e2.diff(x) == 2*x*Subs(Derivative(f(x), x, x), x, x**2 + 1)
  974. e3 = Subs(Derivative(f(x) + y**2 - y, y), y, y**2)
  975. assert e3.diff(y) == 4*y
  976. e4 = Subs(Derivative(f(x + y), y), y, (x**2))
  977. assert e4.diff(y) is S.Zero
  978. e5 = Subs(Derivative(f(x), x), (y, z), (y, z))
  979. assert e5.diff(x) == Derivative(f(x), x, x)
  980. assert f(g(x)).diff(g(x), g(x)) == Derivative(f(g(x)), g(x), g(x))
  981. def test_issue_13843():
  982. x = symbols('x')
  983. f = Function('f')
  984. m, n = symbols('m n', integer=True)
  985. assert Derivative(Derivative(f(x), (x, m)), (x, n)) == Derivative(f(x), (x, m + n))
  986. assert Derivative(Derivative(f(x), (x, m+5)), (x, n+3)) == Derivative(f(x), (x, m + n + 8))
  987. assert Derivative(f(x), (x, n)).doit() == Derivative(f(x), (x, n))
  988. def test_order_could_be_zero():
  989. x, y = symbols('x, y')
  990. n = symbols('n', integer=True, nonnegative=True)
  991. m = symbols('m', integer=True, positive=True)
  992. assert diff(y, (x, n)) == Piecewise((y, Eq(n, 0)), (0, True))
  993. assert diff(y, (x, n + 1)) is S.Zero
  994. assert diff(y, (x, m)) is S.Zero
  995. def test_undefined_function_eq():
  996. f = Function('f')
  997. f2 = Function('f')
  998. g = Function('g')
  999. f_real = Function('f', is_real=True)
  1000. # This test may only be meaningful if the cache is turned off
  1001. assert f == f2
  1002. assert hash(f) == hash(f2)
  1003. assert f == f
  1004. assert f != g
  1005. assert f != f_real
  1006. def test_function_assumptions():
  1007. x = Symbol('x')
  1008. f = Function('f')
  1009. f_real = Function('f', real=True)
  1010. f_real1 = Function('f', real=1)
  1011. f_real_inherit = Function(Symbol('f', real=True))
  1012. assert f_real == f_real1 # assumptions are sanitized
  1013. assert f != f_real
  1014. assert f(x) != f_real(x)
  1015. assert f(x).is_real is None
  1016. assert f_real(x).is_real is True
  1017. assert f_real_inherit(x).is_real is True and f_real_inherit.name == 'f'
  1018. # Can also do it this way, but it won't be equal to f_real because of the
  1019. # way UndefinedFunction.__new__ works. Any non-recognized assumptions
  1020. # are just added literally as something which is used in the hash
  1021. f_real2 = Function('f', is_real=True)
  1022. assert f_real2(x).is_real is True
  1023. def test_undef_fcn_float_issue_6938():
  1024. f = Function('ceil')
  1025. assert not f(0.3).is_number
  1026. f = Function('sin')
  1027. assert not f(0.3).is_number
  1028. assert not f(pi).evalf().is_number
  1029. x = Symbol('x')
  1030. assert not f(x).evalf(subs={x:1.2}).is_number
  1031. def test_undefined_function_eval():
  1032. # Issue 15170. Make sure UndefinedFunction with eval defined works
  1033. # properly.
  1034. fdiff = lambda self, argindex=1: cos(self.args[argindex - 1])
  1035. eval = classmethod(lambda cls, t: None)
  1036. _imp_ = classmethod(lambda cls, t: sin(t))
  1037. temp = Function('temp', fdiff=fdiff, eval=eval, _imp_=_imp_)
  1038. expr = temp(t)
  1039. assert sympify(expr) == expr
  1040. assert type(sympify(expr)).fdiff.__name__ == "<lambda>"
  1041. assert expr.diff(t) == cos(t)
  1042. def test_issue_15241():
  1043. F = f(x)
  1044. Fx = F.diff(x)
  1045. assert (F + x*Fx).diff(x, Fx) == 2
  1046. assert (F + x*Fx).diff(Fx, x) == 1
  1047. assert (x*F + x*Fx*F).diff(F, x) == x*Fx.diff(x) + Fx + 1
  1048. assert (x*F + x*Fx*F).diff(x, F) == x*Fx.diff(x) + Fx + 1
  1049. y = f(x)
  1050. G = f(y)
  1051. Gy = G.diff(y)
  1052. assert (G + y*Gy).diff(y, Gy) == 2
  1053. assert (G + y*Gy).diff(Gy, y) == 1
  1054. assert (y*G + y*Gy*G).diff(G, y) == y*Gy.diff(y) + Gy + 1
  1055. assert (y*G + y*Gy*G).diff(y, G) == y*Gy.diff(y) + Gy + 1
  1056. def test_issue_15226():
  1057. assert Subs(Derivative(f(y), x, y), y, g(x)).doit() != 0
  1058. def test_issue_7027():
  1059. for wrt in (cos(x), re(x), Derivative(cos(x), x)):
  1060. raises(ValueError, lambda: diff(f(x), wrt))
  1061. def test_derivative_quick_exit():
  1062. assert f(x).diff(y) == 0
  1063. assert f(x).diff(y, f(x)) == 0
  1064. assert f(x).diff(x, f(y)) == 0
  1065. assert f(f(x)).diff(x, f(x), f(y)) == 0
  1066. assert f(f(x)).diff(x, f(x), y) == 0
  1067. assert f(x).diff(g(x)) == 0
  1068. assert f(x).diff(x, f(x).diff(x)) == 1
  1069. df = f(x).diff(x)
  1070. assert f(x).diff(df) == 0
  1071. dg = g(x).diff(x)
  1072. assert dg.diff(df).doit() == 0
def test_issue_15084_13166():
    """diff/doit of functions whose arguments themselves depend on x.

    The expected results write the Subs placeholders with the
    module-level dummies ``_xi_1`` and ``_xi_2``.
    """
    eq = f(x, g(x))
    # differentiating y times w.r.t. g(x) stays unevaluated
    assert eq.diff((g(x), y)) == Derivative(f(x, g(x)), (g(x), y))
    # issue 13166
    # second x-derivative: chain rule through both arguments of f
    assert eq.diff(x, 2).doit() == (
        (Derivative(f(x, g(x)), (g(x), 2))*Derivative(g(x), x) +
         Subs(Derivative(f(x, _xi_2), _xi_2, x), _xi_2, g(x)))*Derivative(g(x),
        x) + Derivative(f(x, g(x)), g(x))*Derivative(g(x), (x, 2)) +
        Derivative(g(x), x)*Subs(Derivative(f(_xi_1, g(x)), _xi_1, g(x)),
        _xi_1, x) + Subs(Derivative(f(_xi_1, g(x)), (_xi_1, 2)), _xi_1, x))
    # issue 6681
    assert diff(f(x, t, g(x, t)), x).doit() == (
        Derivative(f(x, t, g(x, t)), g(x, t))*Derivative(g(x, t), x) +
        Subs(Derivative(f(_xi_1, t, g(x, t)), _xi_1), _xi_1, x))
    # make sure the order doesn't matter when using diff
    assert eq.diff(x, g(x)) == eq.diff(g(x), x)
  1089. def test_negative_counts():
  1090. # issue 13873
  1091. raises(ValueError, lambda: sin(x).diff(x, -1))
  1092. def test_Derivative__new__():
  1093. raises(TypeError, lambda: f(x).diff((x, 2), 0))
  1094. assert f(x, y).diff([(x, y), 0]) == f(x, y)
  1095. assert f(x, y).diff([(x, y), 1]) == NDimArray([
  1096. Derivative(f(x, y), x), Derivative(f(x, y), y)])
  1097. assert f(x,y).diff(y, (x, z), y, x) == Derivative(
  1098. f(x, y), (x, z + 1), (y, 2))
  1099. assert Matrix([x]).diff(x, 2) == Matrix([0]) # is_zero exit
  1100. def test_issue_14719_10150():
  1101. class V(Expr):
  1102. _diff_wrt = True
  1103. is_scalar = False
  1104. assert V().diff(V()) == Derivative(V(), V())
  1105. assert (2*V()).diff(V()) == 2*Derivative(V(), V())
  1106. class X(Expr):
  1107. _diff_wrt = True
  1108. assert X().diff(X()) == 1
  1109. assert (2*X()).diff(X()) == 2
  1110. def test_noncommutative_issue_15131():
  1111. x = Symbol('x', commutative=False)
  1112. t = Symbol('t', commutative=False)
  1113. fx = Function('Fx', commutative=False)(x)
  1114. ft = Function('Ft', commutative=False)(t)
  1115. A = Symbol('A', commutative=False)
  1116. eq = fx * A * ft
  1117. eqdt = eq.diff(t)
  1118. assert eqdt.args[-1] == ft.diff(t)
  1119. def test_Subs_Derivative():
  1120. a = Derivative(f(g(x), h(x)), g(x), h(x),x)
  1121. b = Derivative(Derivative(f(g(x), h(x)), g(x), h(x)),x)
  1122. c = f(g(x), h(x)).diff(g(x), h(x), x)
  1123. d = f(g(x), h(x)).diff(g(x), h(x)).diff(x)
  1124. e = Derivative(f(g(x), h(x)), x)
  1125. eqs = (a, b, c, d, e)
  1126. subs = lambda arg: arg.subs(f, Lambda((x, y), exp(x + y))
  1127. ).subs(g(x), 1/x).subs(h(x), x**3)
  1128. ans = 3*x**2*exp(1/x)*exp(x**3) - exp(1/x)*exp(x**3)/x**2
  1129. assert all(subs(i).doit().expand() == ans for i in eqs)
  1130. assert all(subs(i.doit()).doit().expand() == ans for i in eqs)
  1131. def test_issue_15360():
  1132. f = Function('f')
  1133. assert f.name == 'f'
  1134. def test_issue_15947():
  1135. assert f._diff_wrt is False
  1136. raises(TypeError, lambda: f(f))
  1137. raises(TypeError, lambda: f(x).diff(f))
  1138. def test_Derivative_free_symbols():
  1139. f = Function('f')
  1140. n = Symbol('n', integer=True, positive=True)
  1141. assert diff(f(x), (x, n)).free_symbols == {n, x}
  1142. def test_issue_20683():
  1143. x = Symbol('x')
  1144. y = Symbol('y')
  1145. z = Symbol('z')
  1146. y = Derivative(z, x).subs(x,0)
  1147. assert y.doit() == 0
  1148. y = Derivative(8, x).subs(x,0)
  1149. assert y.doit() == 0
  1150. def test_issue_10503():
  1151. f = exp(x**3)*cos(x**6)
  1152. assert f.series(x, 0, 14) == 1 + x**3 + x**6/2 + x**9/6 - 11*x**12/24 + O(x**14)
  1153. def test_issue_17382():
  1154. # copied from sympy/core/tests/test_evalf.py
  1155. def NS(e, n=15, **options):
  1156. return sstr(sympify(e).evalf(n, **options), full_prec=True)
  1157. x = Symbol('x')
  1158. expr = solveset(2 * cos(x) * cos(2 * x) - 1, x, S.Reals)
  1159. expected = "Union(" \
  1160. "ImageSet(Lambda(_n, 6.28318530717959*_n + 5.79812359592087), Integers), " \
  1161. "ImageSet(Lambda(_n, 6.28318530717959*_n + 0.485061711258717), Integers))"
  1162. assert NS(expr) == expected
  1163. def test_eval_sympified():
  1164. # Check both arguments and return types from eval are sympified
  1165. class F(Function):
  1166. @classmethod
  1167. def eval(cls, x):
  1168. assert x is S.One
  1169. return 1
  1170. assert F(1) is S.One
  1171. # String arguments are not allowed
  1172. class F2(Function):
  1173. @classmethod
  1174. def eval(cls, x):
  1175. if x == 0:
  1176. return '1'
  1177. raises(SympifyError, lambda: F2(0))
  1178. F2(1) # Doesn't raise
  1179. # TODO: Disable string inputs (https://github.com/sympy/sympy/issues/11003)
  1180. # raises(SympifyError, lambda: F2('2'))
def test_eval_classmethod_check():
    # Defining ``eval`` on a Function subclass as a plain method (not a
    # classmethod) must raise TypeError when the class statement executes.
    with raises(TypeError):
        class F(Function):
            def eval(self, x):
                pass