# test_subs.py (29 KB)


  1. from sympy.calculus.accumulationbounds import AccumBounds
  2. from sympy.core.add import Add
  3. from sympy.core.basic import Basic
  4. from sympy.core.containers import (Dict, Tuple)
  5. from sympy.core.function import (Derivative, Function, Lambda, Subs)
  6. from sympy.core.mul import Mul
  7. from sympy.core.numbers import (Float, I, Integer, Rational, oo, pi, zoo)
  8. from sympy.core.relational import Eq
  9. from sympy.core.singleton import S
  10. from sympy.core.symbol import (Symbol, Wild, symbols)
  11. from sympy.core.sympify import SympifyError
  12. from sympy.functions.elementary.exponential import (exp, log)
  13. from sympy.functions.elementary.miscellaneous import sqrt
  14. from sympy.functions.elementary.piecewise import Piecewise
  15. from sympy.functions.elementary.trigonometric import (atan2, cos, cot, sin, tan)
  16. from sympy.matrices.dense import (Matrix, zeros)
  17. from sympy.matrices.expressions.special import ZeroMatrix
  18. from sympy.polys.polytools import factor
  19. from sympy.polys.rootoftools import RootOf
  20. from sympy.simplify.cse_main import cse
  21. from sympy.simplify.simplify import nsimplify
  22. from sympy.core.basic import _aresame
  23. from sympy.testing.pytest import XFAIL, raises
  24. from sympy.abc import a, x, y, z, t
  25. def test_subs():
  26. n3 = Rational(3)
  27. e = x
  28. e = e.subs(x, n3)
  29. assert e == Rational(3)
  30. e = 2*x
  31. assert e == 2*x
  32. e = e.subs(x, n3)
  33. assert e == Rational(6)
  34. def test_subs_Matrix():
  35. z = zeros(2)
  36. z1 = ZeroMatrix(2, 2)
  37. assert (x*y).subs({x:z, y:0}) in [z, z1]
  38. assert (x*y).subs({y:z, x:0}) == 0
  39. assert (x*y).subs({y:z, x:0}, simultaneous=True) in [z, z1]
  40. assert (x + y).subs({x: z, y: z}, simultaneous=True) in [z, z1]
  41. assert (x + y).subs({x: z, y: z}) in [z, z1]
  42. # Issue #15528
  43. assert Mul(Matrix([[3]]), x).subs(x, 2.0) == Matrix([[6.0]])
  44. # Does not raise a TypeError, see comment on the MatAdd postprocessor
  45. assert Add(Matrix([[3]]), x).subs(x, 2.0) == Add(Matrix([[3]]), 2.0)
  46. def test_subs_AccumBounds():
  47. e = x
  48. e = e.subs(x, AccumBounds(1, 3))
  49. assert e == AccumBounds(1, 3)
  50. e = 2*x
  51. e = e.subs(x, AccumBounds(1, 3))
  52. assert e == AccumBounds(2, 6)
  53. e = x + x**2
  54. e = e.subs(x, AccumBounds(-1, 1))
  55. assert e == AccumBounds(-1, 2)
  56. def test_trigonometric():
  57. n3 = Rational(3)
  58. e = (sin(x)**2).diff(x)
  59. assert e == 2*sin(x)*cos(x)
  60. e = e.subs(x, n3)
  61. assert e == 2*cos(n3)*sin(n3)
  62. e = (sin(x)**2).diff(x)
  63. assert e == 2*sin(x)*cos(x)
  64. e = e.subs(sin(x), cos(x))
  65. assert e == 2*cos(x)**2
  66. assert exp(pi).subs(exp, sin) == 0
  67. assert cos(exp(pi)).subs(exp, sin) == 1
  68. i = Symbol('i', integer=True)
  69. zoo = S.ComplexInfinity
  70. assert tan(x).subs(x, pi/2) is zoo
  71. assert cot(x).subs(x, pi) is zoo
  72. assert cot(i*x).subs(x, pi) is zoo
  73. assert tan(i*x).subs(x, pi/2) == tan(i*pi/2)
  74. assert tan(i*x).subs(x, pi/2).subs(i, 1) is zoo
  75. o = Symbol('o', odd=True)
  76. assert tan(o*x).subs(x, pi/2) == tan(o*pi/2)
  77. def test_powers():
  78. assert sqrt(1 - sqrt(x)).subs(x, 4) == I
  79. assert (sqrt(1 - x**2)**3).subs(x, 2) == - 3*I*sqrt(3)
  80. assert (x**Rational(1, 3)).subs(x, 27) == 3
  81. assert (x**Rational(1, 3)).subs(x, -27) == 3*(-1)**Rational(1, 3)
  82. assert ((-x)**Rational(1, 3)).subs(x, 27) == 3*(-1)**Rational(1, 3)
  83. n = Symbol('n', negative=True)
  84. assert (x**n).subs(x, 0) is S.ComplexInfinity
  85. assert exp(-1).subs(S.Exp1, 0) is S.ComplexInfinity
  86. assert (x**(4.0*y)).subs(x**(2.0*y), n) == n**2.0
  87. assert (2**(x + 2)).subs(2, 3) == 3**(x + 3)
  88. def test_logexppow(): # no eval()
  89. x = Symbol('x', real=True)
  90. w = Symbol('w')
  91. e = (3**(1 + x) + 2**(1 + x))/(3**x + 2**x)
  92. assert e.subs(2**x, w) != e
  93. assert e.subs(exp(x*log(Rational(2))), w) != e
  94. def test_bug():
  95. x1 = Symbol('x1')
  96. x2 = Symbol('x2')
  97. y = x1*x2
  98. assert y.subs(x1, Float(3.0)) == Float(3.0)*x2
  99. def test_subbug1():
  100. # see that they don't fail
  101. (x**x).subs(x, 1)
  102. (x**x).subs(x, 1.0)
  103. def test_subbug2():
  104. # Ensure this does not cause infinite recursion
  105. assert Float(7.7).epsilon_eq(abs(x).subs(x, -7.7))
  106. def test_dict_set():
  107. a, b, c = map(Wild, 'abc')
  108. f = 3*cos(4*x)
  109. r = f.match(a*cos(b*x))
  110. assert r == {a: 3, b: 4}
  111. e = a/b*sin(b*x)
  112. assert e.subs(r) == r[a]/r[b]*sin(r[b]*x)
  113. assert e.subs(r) == 3*sin(4*x) / 4
  114. s = set(r.items())
  115. assert e.subs(s) == r[a]/r[b]*sin(r[b]*x)
  116. assert e.subs(s) == 3*sin(4*x) / 4
  117. assert e.subs(r) == r[a]/r[b]*sin(r[b]*x)
  118. assert e.subs(r) == 3*sin(4*x) / 4
  119. assert x.subs(Dict((x, 1))) == 1
def test_dict_ambigous():  # see issue 3566
    """Overlapping keys (x and exp(x)) make the result depend on order."""
    f = x*exp(x)
    g = z*exp(z)
    df = {x: y, exp(x): y}
    dg = {z: y, exp(z): y}
    # With a dict, both give y**2 whatever the symbol is called.
    assert f.subs(df) == y**2
    assert g.subs(dg) == y**2
    # and this is how order can affect the result
    assert f.subs(x, y).subs(exp(x), y) == y*exp(y)
    assert f.subs(exp(x), y).subs(x, y) == y**2
    # length of args and count_ops are the same so
    # default_sort_key resolves ordering...if one
    # doesn't want this result then an unordered
    # sequence should not be used.
    e = 1 + x*y
    assert e.subs({x: y, y: 2}) == 5
    # here, there are no obviously clashing keys or values
    # but the results depend on the order
    assert exp(x/2 + y).subs({exp(y + 1): 2, x: 2}) == exp(y + 1)
  139. def test_deriv_sub_bug3():
  140. f = Function('f')
  141. pat = Derivative(f(x), x, x)
  142. assert pat.subs(y, y**2) == Derivative(f(x), x, x)
  143. assert pat.subs(y, y**2) != Derivative(f(x), x)
  144. def test_equality_subs1():
  145. f = Function('f')
  146. eq = Eq(f(x)**2, x)
  147. res = Eq(Integer(16), x)
  148. assert eq.subs(f(x), 4) == res
  149. def test_equality_subs2():
  150. f = Function('f')
  151. eq = Eq(f(x)**2, 16)
  152. assert bool(eq.subs(f(x), 3)) is False
  153. assert bool(eq.subs(f(x), 4)) is True
  154. def test_issue_3742():
  155. e = sqrt(x)*exp(y)
  156. assert e.subs(sqrt(x), 1) == exp(y)
def test_subs_dict1():
    """Single-pair and dict subs, including power keys in a long sum."""
    assert (1 + x*y).subs(x, pi) == 1 + pi*y
    assert (1 + x*y).subs({x: pi, y: 2}) == 1 + 2*pi

    c2, c3, q1p, q2p, c1, s1, s2, s3 = symbols('c2 c3 q1p q2p c1 s1 s2 s3')
    test = (c2**2*q2p*c3 + c1**2*s2**2*q2p*c3 + s1**2*s2**2*q2p*c3
            - c1**2*q1p*c2*s3 - s1**2*q1p*c2*s3)
    # NOTE(review): c3 appears only to the first power in `test`, so the
    # c3**3 key below never matches (the expected value keeps c3 as is).
    # Possibly c3**2 was intended -- TODO confirm.
    assert (test.subs({c1**2: 1 - s1**2, c2**2: 1 - s2**2, c3**3: 1 - s3**2})
            == c3*q2p*(1 - s2**2) + c3*q2p*s2**2*(1 - s1**2)
            - c2*q1p*s3*(1 - s1**2) + c3*q2p*s1**2*s2**2 - c2*q1p*s3*s1**2)
  166. def test_mul():
  167. x, y, z, a, b, c = symbols('x y z a b c')
  168. A, B, C = symbols('A B C', commutative=0)
  169. assert (x*y*z).subs(z*x, y) == y**2
  170. assert (z*x).subs(1/x, z) == 1
  171. assert (x*y/z).subs(1/z, a) == a*x*y
  172. assert (x*y/z).subs(x/z, a) == a*y
  173. assert (x*y/z).subs(y/z, a) == a*x
  174. assert (x*y/z).subs(x/z, 1/a) == y/a
  175. assert (x*y/z).subs(x, 1/a) == y/(z*a)
  176. assert (2*x*y).subs(5*x*y, z) != z*Rational(2, 5)
  177. assert (x*y*A).subs(x*y, a) == a*A
  178. assert (x**2*y**(x*Rational(3, 2))).subs(x*y**(x/2), 2) == 4*y**(x/2)
  179. assert (x*exp(x*2)).subs(x*exp(x), 2) == 2*exp(x)
  180. assert ((x**(2*y))**3).subs(x**y, 2) == 64
  181. assert (x*A*B).subs(x*A, y) == y*B
  182. assert (x*y*(1 + x)*(1 + x*y)).subs(x*y, 2) == 6*(1 + x)
  183. assert ((1 + A*B)*A*B).subs(A*B, x*A*B)
  184. assert (x*a/z).subs(x/z, A) == a*A
  185. assert (x**3*A).subs(x**2*A, a) == a*x
  186. assert (x**2*A*B).subs(x**2*B, a) == a*A
  187. assert (x**2*A*B).subs(x**2*A, a) == a*B
  188. assert (b*A**3/(a**3*c**3)).subs(a**4*c**3*A**3/b**4, z) == \
  189. b*A**3/(a**3*c**3)
  190. assert (6*x).subs(2*x, y) == 3*y
  191. assert (y*exp(x*Rational(3, 2))).subs(y*exp(x), 2) == 2*exp(x/2)
  192. assert (y*exp(x*Rational(3, 2))).subs(y*exp(x), 2) == 2*exp(x/2)
  193. assert (A**2*B*A**2*B*A**2).subs(A*B*A, C) == A*C**2*A
  194. assert (x*A**3).subs(x*A, y) == y*A**2
  195. assert (x**2*A**3).subs(x*A, y) == y**2*A
  196. assert (x*A**3).subs(x*A, B) == B*A**2
  197. assert (x*A*B*A*exp(x*A*B)).subs(x*A, B) == B**2*A*exp(B*B)
  198. assert (x**2*A*B*A*exp(x*A*B)).subs(x*A, B) == B**3*exp(B**2)
  199. assert (x**3*A*exp(x*A*B)*A*exp(x*A*B)).subs(x*A, B) == \
  200. x*B*exp(B**2)*B*exp(B**2)
  201. assert (x*A*B*C*A*B).subs(x*A*B, C) == C**2*A*B
  202. assert (-I*a*b).subs(a*b, 2) == -2*I
  203. # issue 6361
  204. assert (-8*I*a).subs(-2*a, 1) == 4*I
  205. assert (-I*a).subs(-a, 1) == I
  206. # issue 6441
  207. assert (4*x**2).subs(2*x, y) == y**2
  208. assert (2*4*x**2).subs(2*x, y) == 2*y**2
  209. assert (-x**3/9).subs(-x/3, z) == -z**2*x
  210. assert (-x**3/9).subs(x/3, z) == -z**2*x
  211. assert (-2*x**3/9).subs(x/3, z) == -2*x*z**2
  212. assert (-2*x**3/9).subs(-x/3, z) == -2*x*z**2
  213. assert (-2*x**3/9).subs(-2*x, z) == z*x**2/9
  214. assert (-2*x**3/9).subs(2*x, z) == -z*x**2/9
  215. assert (2*(3*x/5/7)**2).subs(3*x/5, z) == 2*(Rational(1, 7))**2*z**2
  216. assert (4*x).subs(-2*x, z) == 4*x # try keep subs literal
  217. def test_subs_simple():
  218. a = symbols('a', commutative=True)
  219. x = symbols('x', commutative=False)
  220. assert (2*a).subs(1, 3) == 2*a
  221. assert (2*a).subs(2, 3) == 3*a
  222. assert (2*a).subs(a, 3) == 6
  223. assert sin(2).subs(1, 3) == sin(2)
  224. assert sin(2).subs(2, 3) == sin(3)
  225. assert sin(a).subs(a, 3) == sin(3)
  226. assert (2*x).subs(1, 3) == 2*x
  227. assert (2*x).subs(2, 3) == 3*x
  228. assert (2*x).subs(x, 3) == 6
  229. assert sin(x).subs(x, 3) == sin(3)
  230. def test_subs_constants():
  231. a, b = symbols('a b', commutative=True)
  232. x, y = symbols('x y', commutative=False)
  233. assert (a*b).subs(2*a, 1) == a*b
  234. assert (1.5*a*b).subs(a, 1) == 1.5*b
  235. assert (2*a*b).subs(2*a, 1) == b
  236. assert (2*a*b).subs(4*a, 1) == 2*a*b
  237. assert (x*y).subs(2*x, 1) == x*y
  238. assert (1.5*x*y).subs(x, 1) == 1.5*y
  239. assert (2*x*y).subs(2*x, 1) == y
  240. assert (2*x*y).subs(4*x, 1) == 2*x*y
  241. def test_subs_commutative():
  242. a, b, c, d, K = symbols('a b c d K', commutative=True)
  243. assert (a*b).subs(a*b, K) == K
  244. assert (a*b*a*b).subs(a*b, K) == K**2
  245. assert (a*a*b*b).subs(a*b, K) == K**2
  246. assert (a*b*c*d).subs(a*b*c, K) == d*K
  247. assert (a*b**c).subs(a, K) == K*b**c
  248. assert (a*b**c).subs(b, K) == a*K**c
  249. assert (a*b**c).subs(c, K) == a*b**K
  250. assert (a*b*c*b*a).subs(a*b, K) == c*K**2
  251. assert (a**3*b**2*a).subs(a*b, K) == a**2*K**2
def test_subs_noncommutative():
    """Subs of words and powers built from noncommutative symbols."""
    w, x, y, z, L = symbols('w x y z L', commutative=False)
    alpha = symbols('alpha', commutative=True)
    someint = symbols('someint', commutative=True, integer=True)

    # Only contiguous occurrences of the pattern, in order, are replaced.
    assert (x*y).subs(x*y, L) == L
    assert (w*y*x).subs(x*y, L) == w*y*x
    assert (w*x*y*z).subs(x*y, L) == w*L*z
    assert (x*y*x*y).subs(x*y, L) == L**2
    assert (x*x*y).subs(x*y, L) == x*L
    assert (x*x*y*y).subs(x*y, L) == x*L*y
    assert (w*x*y).subs(x*y*z, L) == w*x*y
    assert (x*y**z).subs(x, L) == L*y**z
    assert (x*y**z).subs(y, L) == x*L**z
    assert (x*y**z).subs(z, L) == x*y**L
    assert (w*x*y*z*x*y).subs(x*y*z, L) == w*L*x*y
    assert (w*x*y*y*w*x*x*y*x*y*y*x*y).subs(x*y, L) == w*L*y*w*x*L**2*y*L

    # Check fractional power substitutions. It should not do
    # substitutions that choose a value for noncommutative log,
    # or inverses that don't already appear in the expressions.
    assert (x*x*x).subs(x*x, L) == L*x
    assert (x*x*x*y*x*x*x*x).subs(x*x, L) == L*x*y*L**2
    # x**k with pattern x**p: k//p copies of L, remainder left as x**(k % p).
    for p in range(1, 5):
        for k in range(10):
            assert (y * x**k).subs(x**p, L) == y * L**(k//p) * x**(k % p)
    assert (x**Rational(3, 2)).subs(x**S.Half, L) == x**Rational(3, 2)
    assert (x**S.Half).subs(x**S.Half, L) == L
    assert (x**Rational(-1, 2)).subs(x**S.Half, L) == x**Rational(-1, 2)
    assert (x**Rational(-1, 2)).subs(x**Rational(-1, 2), L) == L

    # Symbolic integer exponents can be split; a non-integer (alpha) cannot.
    assert (x**(2*someint)).subs(x**someint, L) == L**2
    assert (x**(2*someint + 3)).subs(x**someint, L) == L**2*x**3
    assert (x**(3*someint + 3)).subs(x**someint, L) == L**3*x**3
    assert (x**(3*someint)).subs(x**(2*someint), L) == L * x**someint
    assert (x**(4*someint)).subs(x**(2*someint), L) == L**2
    assert (x**(4*someint + 1)).subs(x**(2*someint), L) == L**2 * x
    assert (x**(4*someint)).subs(x**(3*someint), L) == L * x**someint
    assert (x**(4*someint + 1)).subs(x**(3*someint), L) == L * x**(someint + 1)
    assert (x**(2*alpha)).subs(x**alpha, L) == x**(2*alpha)
    assert (x**(2*alpha + 2)).subs(x**2, L) == x**(2*alpha + 2)
    assert ((2*z)**alpha).subs(z**alpha, y) == (2*z)**alpha
    assert (x**(2*someint*alpha)).subs(x**someint, L) == x**(2*someint*alpha)
    assert (x**(2*someint + alpha)).subs(x**someint, L) == x**(2*someint + alpha)

    # This could in principle be substituted, but is not currently
    # because it requires recognizing that someint**2 is divisible by
    # someint.
    assert (x**(someint**2 + 3)).subs(x**someint, L) == x**(someint**2 + 3)

    # alpha**z := exp(log(alpha) z) is usually well-defined
    assert (4**z).subs(2**z, y) == y**2

    # Negative powers
    assert (x**(-1)).subs(x**3, L) == x**(-1)
    assert (x**(-2)).subs(x**3, L) == x**(-2)
    assert (x**(-3)).subs(x**3, L) == L**(-1)
    assert (x**(-4)).subs(x**3, L) == L**(-1) * x**(-1)
    assert (x**(-5)).subs(x**3, L) == L**(-1) * x**(-2)

    assert (x**(-1)).subs(x**(-3), L) == x**(-1)
    assert (x**(-2)).subs(x**(-3), L) == x**(-2)
    assert (x**(-3)).subs(x**(-3), L) == L
    assert (x**(-4)).subs(x**(-3), L) == L * x**(-1)
    assert (x**(-5)).subs(x**(-3), L) == L * x**(-2)

    assert (x**1).subs(x**(-3), L) == x
    assert (x**2).subs(x**(-3), L) == x**2
    assert (x**3).subs(x**(-3), L) == L**(-1)
    assert (x**4).subs(x**(-3), L) == L**(-1) * x
    assert (x**5).subs(x**(-3), L) == L**(-1) * x**2
  315. def test_subs_basic_funcs():
  316. a, b, c, d, K = symbols('a b c d K', commutative=True)
  317. w, x, y, z, L = symbols('w x y z L', commutative=False)
  318. assert (x + y).subs(x + y, L) == L
  319. assert (x - y).subs(x - y, L) == L
  320. assert (x/y).subs(x, L) == L/y
  321. assert (x**y).subs(x, L) == L**y
  322. assert (x**y).subs(y, L) == x**L
  323. assert ((a - c)/b).subs(b, K) == (a - c)/K
  324. assert (exp(x*y - z)).subs(x*y, L) == exp(L - z)
  325. assert (a*exp(x*y - w*z) + b*exp(x*y + w*z)).subs(z, 0) == \
  326. a*exp(x*y) + b*exp(x*y)
  327. assert ((a - b)/(c*d - a*b)).subs(c*d - a*b, K) == (a - b)/K
  328. assert (w*exp(a*b - c)*x*y/4).subs(x*y, L) == w*exp(a*b - c)*L/4
  329. def test_subs_wild():
  330. R, S, T, U = symbols('R S T U', cls=Wild)
  331. assert (R*S).subs(R*S, T) == T
  332. assert (S*R).subs(R*S, T) == T
  333. assert (R + S).subs(R + S, T) == T
  334. assert (R**S).subs(R, T) == T**S
  335. assert (R**S).subs(S, T) == R**T
  336. assert (R*S**T).subs(R, U) == U*S**T
  337. assert (R*S**T).subs(S, U) == R*U**T
  338. assert (R*S**T).subs(T, U) == R*S**U
def test_subs_mixed():
    """Subs in expressions mixing commutative, noncommutative and Wild symbols."""
    a, b, c, d, K = symbols('a b c d K', commutative=True)
    w, x, y, z, L = symbols('w x y z L', commutative=False)
    R, S, T, U = symbols('R S T U', cls=Wild)
    assert (a*x*y).subs(x*y, L) == a*L
    assert (a*b*x*y*x).subs(x*y, L) == a*b*L*x
    assert (R*x*y*exp(x*y)).subs(x*y, L) == R*L*exp(L)
    assert (a*x*y*y*x - x*y*z*exp(a*b)).subs(x*y, L) == a*L*y*x - L*z*exp(a*b)
    e = c*y*x*y*x**(R*S - a*b) - T*(a*R*b*S)
    # Chained subs: noncommutative x*y, then commutative a*b, then Wild R*S.
    assert e.subs(x*y, L).subs(a*b, K).subs(R*S, U) == \
        c*y*L*x**(U - K) - T*(U*K)
  350. def test_division():
  351. a, b, c = symbols('a b c', commutative=True)
  352. x, y, z = symbols('x y z', commutative=True)
  353. assert (1/a).subs(a, c) == 1/c
  354. assert (1/a**2).subs(a, c) == 1/c**2
  355. assert (1/a**2).subs(a, -2) == Rational(1, 4)
  356. assert (-(1/a**2)).subs(a, -2) == Rational(-1, 4)
  357. assert (1/x).subs(x, z) == 1/z
  358. assert (1/x**2).subs(x, z) == 1/z**2
  359. assert (1/x**2).subs(x, -2) == Rational(1, 4)
  360. assert (-(1/x**2)).subs(x, -2) == Rational(-1, 4)
  361. #issue 5360
  362. assert (1/x).subs(x, 0) == 1/S.Zero
def test_add():
    """Subs of partial sums within Add, and of sums inside Mul factors."""
    a, b, c, d, x, y, t = symbols('a b c d x y t')

    # Where the pattern covers only part of the Add, either the replaced or
    # the untouched form is accepted.
    assert (a**2 - b - c).subs(a**2 - b, d) in [d - c, a**2 - b - c]
    assert (a**2 - c).subs(a**2 - c, d) == d
    assert (a**2 - b - c).subs(a**2 - c, d) in [d - b, a**2 - b - c]
    assert (a**2 - x - c).subs(a**2 - c, d) in [d - x, a**2 - x - c]
    assert (a**2 - b - sqrt(a)).subs(a**2 - sqrt(a), c) == c - b
    assert (a + b + exp(a + b)).subs(a + b, c) == c + exp(c)
    assert (c + b + exp(c + b)).subs(c + b, a) == a + exp(a)
    assert (a + b + c + d).subs(b + c, x) == a + d + x
    assert (a + b + c + d).subs(-b - c, x) == a + d - x
    # A sum appearing as a Mul factor is rewritten in terms of t.
    assert ((x + 1)*y).subs(x + 1, t) == t*y
    assert ((-x - 1)*y).subs(x + 1, t) == -t*y
    assert ((x - 1)*y).subs(x + 1, t) == y*(t - 2)
    assert ((-x + 1)*y).subs(x + 1, t) == y*(-t + 2)

    # this should work every time:
    e = a**2 - b - c
    assert e.subs(Add(*e.args[:2]), d) == d + e.args[2]
    assert e.subs(a**2 - c, d) == d - b

    # the fallback should recognize when a change has
    # been made; while .1 == Rational(1, 10) they are not the same
    # and the change should be made
    assert (0.1 + a).subs(0.1, Rational(1, 10)) == Rational(1, 10) + a

    e = (-x*(-y + 1) - y*(y - 1))
    ans = (-x*(x) - y*(-x)).expand()
    assert e.subs(-y + 1, x) == ans

    # Test issue 18747
    assert (exp(x) + cos(x)).subs(x, oo) == oo
    assert Add(*[AccumBounds(-1, 1), oo]) == oo
    assert Add(*[oo, AccumBounds(-1, 1)]) == oo
  393. def test_subs_issue_4009():
  394. assert (I*Symbol('a')).subs(1, 2) == I*Symbol('a')
def test_functions_subs():
    """Replacing function classes with Lambdas or other function classes."""
    f, g = symbols('f g', cls=Function)
    l = Lambda((x, y), sin(x) + y)
    assert (g(y, x) + cos(x)).subs(g, l) == sin(y) + x + cos(x)
    assert (f(x)**2).subs(f, sin) == sin(x)**2
    assert (f(x, y)).subs(f, log) == log(x, y)
    # Replacing f with sin leaves the two-argument call f(x, y) unchanged.
    assert (f(x, y)).subs(f, sin) == f(x, y)
    # Several class replacements can be given as a list of pairs.
    assert (sin(x) + atan2(x, y)).subs([[atan2, f], [sin, g]]) == \
        f(x, y) + g(x)
    assert (g(f(x + y, x))).subs([[f, l], [g, exp]]) == exp(x + sin(x + y))
def test_derivative_subs():
    """Substituting into Derivatives, including the differentiation variable."""
    f = Function('f')
    g = Function('g')
    assert Derivative(f(x), x).subs(f(x), y) != 0
    # need xreplace to put the function back, see #13803
    assert Derivative(f(x), x).subs(f(x), y).xreplace({y: f(x)}) == \
        Derivative(f(x), x)
    # issues 5085, 5037
    assert cse(Derivative(f(x), x) + f(x))[1][0].has(Derivative)
    assert cse(Derivative(f(x, y), x) +
               Derivative(f(x, y), y))[1][0].has(Derivative)
    eq = Derivative(g(x), g(x))
    assert eq.subs(g, f) == Derivative(f(x), f(x))
    assert eq.subs(g(x), f(x)) == Derivative(f(x), f(x))
    # With g replaced by a concrete function, the result is an unevaluated Subs.
    assert eq.subs(g, cos) == Subs(Derivative(y, y), y, cos(x))
def test_derivative_subs2():
    """A Derivative pattern matches when its variables are contained in the target's."""
    f_func, g_func = symbols('f g', cls=Function)
    f, g = f_func(x, y, z), g_func(x, y, z)
    assert Derivative(f, x, y).subs(Derivative(f, x, y), g) == g
    # The order in which the variables are listed does not matter.
    assert Derivative(f, y, x).subs(Derivative(f, x, y), g) == g
    # A partial match leaves the remaining differentiations applied to g.
    assert Derivative(f, x, y).subs(Derivative(f, x), g) == Derivative(g, y)
    assert Derivative(f, x, y).subs(Derivative(f, y), g) == Derivative(g, x)
    assert (Derivative(f, x, y, z).subs(
        Derivative(f, x, z), g) == Derivative(g, y))
    assert (Derivative(f, x, y, z).subs(
        Derivative(f, z, y), g) == Derivative(g, x))
    assert (Derivative(f, x, y, z).subs(
        Derivative(f, z, y, x), g) == g)

    # Issue 9135
    # Patterns needing more differentiations than the target has do not match.
    assert (Derivative(f, x, x, y).subs(
        Derivative(f, y, y), g) == Derivative(f, x, x, y))
    assert (Derivative(f, x, y, y, z).subs(
        Derivative(f, x, y, y, y), g) == Derivative(f, x, y, y, z))

    # A derivative of a different application, f(x), does not match f(x, y, z).
    assert Derivative(f, x, y).subs(Derivative(f_func(x), x, y), g) == Derivative(f, x, y)
  439. def test_derivative_subs3():
  440. dex = Derivative(exp(x), x)
  441. assert Derivative(dex, x).subs(dex, exp(x)) == dex
  442. assert dex.subs(exp(x), dex) == Derivative(exp(x), x, x)
  443. def test_issue_5284():
  444. A, B = symbols('A B', commutative=False)
  445. assert (x*A).subs(x**2*A, B) == x*A
  446. assert (A**2).subs(A**3, B) == A**2
  447. assert (A**6).subs(A**3, B) == B**2
  448. def test_subs_iter():
  449. assert x.subs(reversed([[x, y]])) == y
  450. it = iter([[x, y]])
  451. assert x.subs(it) == y
  452. assert x.subs(Tuple((x, y))) == y
def test_subs_dict():
    """Dict versus ordered-sequence subs, including order-dependent results."""
    a, b, c, d, e = symbols('a b c d e')

    # String keys are accepted.
    assert (2*x + y + z).subs({"x": 1, "y": 2}) == 4 + z

    l = [(sin(x), 2), (x, 1)]
    assert (sin(x)).subs(l) == \
           (sin(x)).subs(dict(l)) == 2
    # Reversed, x -> 1 is applied first, so sin(x) no longer matches.
    assert sin(x).subs(reversed(l)) == sin(1)

    expr = sin(2*x) + sqrt(sin(2*x))*cos(2*x)*sin(exp(x)*x)
    reps = {sin(2*x): c,
            sqrt(sin(2*x)): a,
            cos(2*x): b,
            exp(x): e,
            x: d,}
    assert expr.subs(reps) == c + a*b*sin(d*e)

    # With a list, later pairs see the results of earlier ones.
    l = [(x, 3), (y, x**2)]
    assert (x + y).subs(l) == 3 + x**2
    assert (x + y).subs(reversed(l)) == 12

    # If changes are made to convert lists into dictionaries and do
    # a dictionary-lookup replacement, these tests will help to catch
    # some logical errors that might occur
    l = [(y, z + 2), (1 + z, 5), (z, 2)]
    assert (y - 1 + 3*x).subs(l) == 5 + 3*x
    l = [(y, z + 2), (z, 3)]
    assert (y - 2).subs(l) == 3
  477. def test_no_arith_subs_on_floats():
  478. assert (x + 3).subs(x + 3, a) == a
  479. assert (x + 3).subs(x + 2, a) == a + 1
  480. assert (x + y + 3).subs(x + 3, a) == a + y
  481. assert (x + y + 3).subs(x + 2, a) == a + y + 1
  482. assert (x + 3.0).subs(x + 3.0, a) == a
  483. assert (x + 3.0).subs(x + 2.0, a) == x + 3.0
  484. assert (x + y + 3.0).subs(x + 3.0, a) == a + y
  485. assert (x + y + 3.0).subs(x + 2.0, a) == x + y + 3.0
  486. def test_issue_5651():
  487. a, b, c, K = symbols('a b c K', commutative=True)
  488. assert (a/(b*c)).subs(b*c, K) == a/K
  489. assert (a/(b**2*c**3)).subs(b*c, K) == a/(c*K**2)
  490. assert (1/(x*y)).subs(x*y, 2) == S.Half
  491. assert ((1 + x*y)/(x*y)).subs(x*y, 1) == 2
  492. assert (x*y*z).subs(x*y, 2) == 2*z
  493. assert ((1 + x*y)/(x*y)/z).subs(x*y, 1) == 2/z
  494. def test_issue_6075():
  495. assert Tuple(1, True).subs(1, 2) == Tuple(2, True)
def test_issue_6079():
    """Integer and Float constants are distinct subs targets; checked via _aresame."""
    # since x + 2.0 == x + 2 we can't do a simple equality test
    assert _aresame((x + 2.0).subs(2, 3), x + 2.0)
    assert _aresame((x + 2.0).subs(2.0, 3), x + 3)
    # _aresame itself separates equal-comparing objects of different kinds.
    assert not _aresame(x + 2, x + 2.0)
    assert not _aresame(Basic(cos(x), S(1)), Basic(cos(x), S(1.)))
    assert _aresame(cos, cos)
    assert not _aresame(1, S.One)
    assert not _aresame(x, symbols('x', positive=True))
  505. def test_issue_4680():
  506. N = Symbol('N')
  507. assert N.subs({"N": 3}) == 3
  508. def test_issue_6158():
  509. assert (x - 1).subs(1, y) == x - y
  510. assert (x - 1).subs(-1, y) == x + y
  511. assert (x - oo).subs(oo, y) == x - y
  512. assert (x - oo).subs(-oo, y) == x + y
  513. def test_Function_subs():
  514. f, g, h, i = symbols('f g h i', cls=Function)
  515. p = Piecewise((g(f(x, y)), x < -1), (g(x), x <= 1))
  516. assert p.subs(g, h) == Piecewise((h(f(x, y)), x < -1), (h(x), x <= 1))
  517. assert (f(y) + g(x)).subs({f: h, g: i}) == i(x) + h(y)
def test_simultaneous_subs():
    """simultaneous=True makes x/y and y/x agree after replacing both with 0."""
    reps = {x: 0, y: 0}
    assert (x/y).subs(reps) != (y/x).subs(reps)
    assert (x/y).subs(reps, simultaneous=True) == \
        (y/x).subs(reps, simultaneous=True)
    # Same checks with an items view in place of the dict.
    reps = reps.items()
    assert (x/y).subs(reps) != (y/x).subs(reps)
    assert (x/y).subs(reps, simultaneous=True) == \
        (y/x).subs(reps, simultaneous=True)
    # Replacing a differentiation variable yields an unevaluated Subs.
    assert Derivative(x, y, z).subs(reps, simultaneous=True) == \
        Subs(Derivative(0, y, z), y, 0)
  529. def test_issue_6419_6421():
  530. assert (1/(1 + x/y)).subs(x/y, x) == 1/(1 + x)
  531. assert (-2*I).subs(2*I, x) == -x
  532. assert (-I*x).subs(I*x, x) == -x
  533. assert (-3*I*y**4).subs(3*I*y**2, x) == -x*y**2
def test_issue_6559():
    """Negated-pattern subs, plus a cse result that once broke Mul._eval_subs."""
    assert (-12*x + y).subs(-x, 1) == 12 + y
    # though this involves cse it generated a failure in Mul._eval_subs
    x0, x1 = symbols('x0 x1')
    e = -log(-12*sqrt(2) + 17)/24 - log(-2*sqrt(2) + 3)/12 + sqrt(2)/3
    # XXX modify cse so x1 is eliminated and x0 = -sqrt(2)?
    assert cse(e) == (
        [(x0, sqrt(2))], [x0/3 - log(-12*x0 + 17)/24 - log(-2*x0 + 3)/12])
  542. def test_issue_5261():
  543. x = symbols('x', real=True)
  544. e = I*x
  545. assert exp(e).subs(exp(x), y) == y**I
  546. assert (2**e).subs(2**x, y) == y**I
  547. eq = (-2)**e
  548. assert eq.subs((-2)**x, y) == eq
  549. def test_issue_6923():
  550. assert (-2*x*sqrt(2)).subs(2*x, y) == -sqrt(2)*y
  551. def test_2arg_hack():
  552. N = Symbol('N', commutative=False)
  553. ans = Mul(2, y + 1, evaluate=False)
  554. assert (2*x*(y + 1)).subs(x, 1, hack2=True) == ans
  555. assert (2*(y + 1 + N)).subs(N, 0, hack2=True) == ans
@XFAIL
def test_mul2():
    """When this XFAIL test stops failing (the assertion passes), remove
    the things labelled "2-arg hack":
    1) remove special handling in the fallback of subs that
    was added in the same commit as this test
    2) remove the special handling in Mul.flatten
    """
    # Marked XFAIL: this assertion is expected to fail while the hack exists.
    assert (2*(x + 1)).is_Mul
  564. def test_noncommutative_subs():
  565. x,y = symbols('x,y', commutative=False)
  566. assert (x*y*x).subs([(x, x*y), (y, x)], simultaneous=True) == (x*y*x**2*y)
  567. def test_issue_2877():
  568. f = Float(2.0)
  569. assert (x + f).subs({f: 2}) == x + 2
  570. def r(a, b, c):
  571. return factor(a*x**2 + b*x + c)
  572. e = r(5.0/6, 10, 5)
  573. assert nsimplify(e) == 5*x**2/6 + 10*x + 5
  574. def test_issue_5910():
  575. t = Symbol('t')
  576. assert (1/(1 - t)).subs(t, 1) is zoo
  577. n = t
  578. d = t - 1
  579. assert (n/d).subs(t, 1) is zoo
  580. assert (-n/-d).subs(t, 1) is zoo
  581. def test_issue_5217():
  582. s = Symbol('s')
  583. z = (1 - 2*x*x)
  584. w = (1 + 2*x*x)
  585. q = 2*x*x*2*y*y
  586. sub = {2*x*x: s}
  587. assert w.subs(sub) == 1 + s
  588. assert z.subs(sub) == 1 - s
  589. assert q == 4*x**2*y**2
  590. assert q.subs(sub) == 2*y**2*s
  591. def test_issue_10829():
  592. assert (4**x).subs(2**x, y) == y**2
  593. assert (9**x).subs(3**x, y) == y**2
def test_pow_eval_subs_no_cache():
    """sqrt(x**2) in 1/sqrt(x**2) is still replaced after the cache is cleared."""
    # Tests pull request 9376 is working
    from sympy.core.cache import clear_cache

    s = 1/sqrt(x**2)
    # This bug only appeared when the cache was turned off.
    # We need to approximate running this test without the cache.
    # This creates approximately the same situation.
    clear_cache()

    # This used to fail with a wrong result.
    # It incorrectly returned 1/sqrt(x**2) before this pull request.
    result = s.subs(sqrt(x**2), y)
    assert result == 1/y
  606. def test_RootOf_issue_10092():
  607. x = Symbol('x', real=True)
  608. eq = x**3 - 17*x**2 + 81*x - 118
  609. r = RootOf(eq, 0)
  610. assert (x < r).subs(x, r) is S.false
  611. def test_issue_8886():
  612. from sympy.physics.mechanics import ReferenceFrame as R
  613. # if something can't be sympified we assume that it
  614. # doesn't play well with SymPy and disallow the
  615. # substitution
  616. v = R('A').x
  617. raises(SympifyError, lambda: x.subs(x, v))
  618. raises(SympifyError, lambda: v.subs(v, x))
  619. assert v.__eq__(x) is False
  620. def test_issue_12657():
  621. # treat -oo like the atom that it is
  622. reps = [(-oo, 1), (oo, 2)]
  623. assert (x < -oo).subs(reps) == (x < 1)
  624. assert (x < -oo).subs(list(reversed(reps))) == (x < 1)
  625. reps = [(-oo, 2), (oo, 1)]
  626. assert (x < oo).subs(reps) == (x < 1)
  627. assert (x < oo).subs(list(reversed(reps))) == (x < 1)
  628. def test_recurse_Application_args():
  629. F = Lambda((x, y), exp(2*x + 3*y))
  630. f = Function('f')
  631. A = f(x, f(x, x))
  632. C = F(x, F(x, x))
  633. assert A.subs(f, F) == A.replace(f, F) == C
def test_Subs_subs():
    """subs applied to unevaluated Subs objects."""
    # The bound variable x is kept; only the point value changes.
    assert Subs(x*y, x, x).subs(x, y) == Subs(x*y, x, y)
    assert Subs(x*y, x, x + 1).subs(x, y) == \
        Subs(x*y, x, y + 1)
    # x is not bound here, so it is replaced in the expression and the point.
    assert Subs(x*y, y, x + 1).subs(x, y) == \
        Subs(y**2, y, y + 1)
    a = Subs(x*y*z, (y, x, z), (x + 1, x + z, x))
    b = Subs(x*y*z, (y, x, z), (x + 1, y + z, y))
    # subs and doit commute for this multi-variable Subs.
    assert a.subs(x, y) == b and \
        a.doit().subs(x, y) == a.subs(x, y).doit()
    f = Function('f')
    g = Function('g')
    # A symbol free in the expression is absorbed as an extra Subs variable.
    assert Subs(2*f(x, y) + g(x), f(x, y), 1).subs(y, 2) == Subs(
        2*f(x, y) + g(x), (f(x, y), y), (1, 2))
  648. def test_issue_13333():
  649. eq = 1/x
  650. assert eq.subs({"x": '1/2'}) == 2
  651. assert eq.subs({"x": '(1/2)'}) == 2
  652. def test_issue_15234():
  653. x, y = symbols('x y', real=True)
  654. p = 6*x**5 + x**4 - 4*x**3 + 4*x**2 - 2*x + 3
  655. p_subbed = 6*x**5 - 4*x**3 - 2*x + y**4 + 4*y**2 + 3
  656. assert p.subs([(x**i, y**i) for i in [2, 4]]) == p_subbed
  657. x, y = symbols('x y', complex=True)
  658. p = 6*x**5 + x**4 - 4*x**3 + 4*x**2 - 2*x + 3
  659. p_subbed = 6*x**5 - 4*x**3 - 2*x + y**4 + 4*y**2 + 3
  660. assert p.subs([(x**i, y**i) for i in [2, 4]]) == p_subbed
def test_issue_6976():
    """Power subs with generic symbols versus nonnegative symbols."""
    x, y = symbols('x y')
    assert (sqrt(x)**3 + sqrt(x) + x + x**2).subs(sqrt(x), y) == \
        y**4 + y**3 + y**2 + y
    # Generic x: only integer multiples of the pattern exponent are replaced.
    assert (x**4 + x**3 + x**2 + x + sqrt(x)).subs(x**2, y) == \
        sqrt(x) + x**3 + x + y**2 + y
    assert x.subs(x**3, y) == x
    assert x.subs(x**Rational(1, 3), y) == y**3

    # More substitutions are possible with nonnegative symbols
    x, y = symbols('x y', nonnegative=True)
    assert (x**4 + x**3 + x**2 + x + sqrt(x)).subs(x**2, y) == \
        y**Rational(1, 4) + y**Rational(3, 2) + sqrt(y) + y**2 + y
    assert x.subs(x**3, y) == y**Rational(1, 3)
  674. def test_issue_11746():
  675. assert (1/x).subs(x**2, 1) == 1/x
  676. assert (1/(x**3)).subs(x**2, 1) == x**(-3)
  677. assert (1/(x**4)).subs(x**2, 1) == 1
  678. assert (1/(x**3)).subs(x**4, 1) == x**(-3)
  679. assert (1/(y**5)).subs(x**5, 1) == y**(-5)
def test_issue_17823():
    """A dict of dynamicsymbols and their derivatives is substituted correctly."""
    from sympy.physics.mechanics import dynamicsymbols
    q1, q2 = dynamicsymbols('q1, q2')
    expr = q1.diff().diff()**2*q1 + q1.diff()*q2.diff()
    reps={q1: a, q1.diff(): a*x*y, q1.diff().diff(): z}
    # NOTE(review): the expected value assumes `t` from sympy.abc is the same
    # time symbol dynamicsymbols differentiates with respect to -- confirm.
    assert expr.subs(reps) == a*x*y*Derivative(q2, t) + a*z**2
  686. def test_issue_19326():
  687. x, y = [i(t) for i in map(Function, 'xy')]
  688. assert (x*y).subs({x: 1 + x, y: x}) == (1 + x)*x
def test_issue_19558():
    """Substituting oo into oscillating expressions yields AccumBounds."""
    e = (7*x*cos(x) - 12*log(x)**3)*(-log(x)**4 + 2*sin(x) + 1)**2/ \
        (2*(x*cos(x) - 2*log(x)**3)*(3*log(x)**4 - 7*sin(x) + 3)**2)

    assert e.subs(x, oo) == AccumBounds(-oo, oo)
    assert (sin(x) + cos(x)).subs(x, oo) == AccumBounds(-2, 2)
  694. def test_issue_22033():
  695. xr = Symbol('xr', real=True)
  696. e = (1/xr)
  697. assert e.subs(xr**2, y) == e
def test_guard_against_indeterminate_evaluation():
    """The order of substitution decides whether 1**oo (NaN) is formed."""
    eq = x**y
    assert eq.subs([(x, 1), (y, oo)]) == 1  # because 1**y == 1
    # Putting y = oo in first, or substituting all at once, yields NaN.
    assert eq.subs([(y, oo), (x, 1)]) is S.NaN
    assert eq.subs({x: 1, y: oo}) is S.NaN
    assert eq.subs([(x, 1), (y, oo)], simultaneous=True) is S.NaN