test_expand.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351
  1. from sympy.core.expr import unchanged
  2. from sympy.core.mul import Mul
  3. from sympy.core.numbers import (I, Rational as R, pi)
  4. from sympy.core.power import Pow
  5. from sympy.core.singleton import S
  6. from sympy.core.symbol import Symbol
  7. from sympy.functions.elementary.exponential import (exp, log)
  8. from sympy.functions.elementary.miscellaneous import sqrt
  9. from sympy.functions.elementary.trigonometric import (cos, sin)
  10. from sympy.series.order import O
  11. from sympy.simplify.radsimp import expand_numer
  12. from sympy.core.function import expand, expand_multinomial, expand_power_base
  13. from sympy.testing.pytest import raises
  14. from sympy.core.random import verify_numerically
  15. from sympy.abc import x, y, z
  16. def test_expand_no_log():
  17. assert (
  18. (1 + log(x**4))**2).expand(log=False) == 1 + 2*log(x**4) + log(x**4)**2
  19. assert ((1 + log(x**4))*(1 + log(x**3))).expand(
  20. log=False) == 1 + log(x**4) + log(x**3) + log(x**4)*log(x**3)
  21. def test_expand_no_multinomial():
  22. assert ((1 + x)*(1 + (1 + x)**4)).expand(multinomial=False) == \
  23. 1 + x + (1 + x)**4 + x*(1 + x)**4
  24. def test_expand_negative_integer_powers():
  25. expr = (x + y)**(-2)
  26. assert expr.expand() == 1 / (2*x*y + x**2 + y**2)
  27. assert expr.expand(multinomial=False) == (x + y)**(-2)
  28. expr = (x + y)**(-3)
  29. assert expr.expand() == 1 / (3*x*x*y + 3*x*y*y + x**3 + y**3)
  30. assert expr.expand(multinomial=False) == (x + y)**(-3)
  31. expr = (x + y)**(2) * (x + y)**(-4)
  32. assert expr.expand() == 1 / (2*x*y + x**2 + y**2)
  33. assert expr.expand(multinomial=False) == (x + y)**(-2)
  34. def test_expand_non_commutative():
  35. A = Symbol('A', commutative=False)
  36. B = Symbol('B', commutative=False)
  37. C = Symbol('C', commutative=False)
  38. a = Symbol('a')
  39. b = Symbol('b')
  40. i = Symbol('i', integer=True)
  41. n = Symbol('n', negative=True)
  42. m = Symbol('m', negative=True)
  43. p = Symbol('p', polar=True)
  44. np = Symbol('p', polar=False)
  45. assert (C*(A + B)).expand() == C*A + C*B
  46. assert (C*(A + B)).expand() != A*C + B*C
  47. assert ((A + B)**2).expand() == A**2 + A*B + B*A + B**2
  48. assert ((A + B)**3).expand() == (A**2*B + B**2*A + A*B**2 + B*A**2 +
  49. A**3 + B**3 + A*B*A + B*A*B)
  50. # issue 6219
  51. assert ((a*A*B*A**-1)**2).expand() == a**2*A*B**2/A
  52. # Note that (a*A*B*A**-1)**2 is automatically converted to a**2*(A*B*A**-1)**2
  53. assert ((a*A*B*A**-1)**2).expand(deep=False) == a**2*(A*B*A**-1)**2
  54. assert ((a*A*B*A**-1)**2).expand() == a**2*(A*B**2*A**-1)
  55. assert ((a*A*B*A**-1)**2).expand(force=True) == a**2*A*B**2*A**(-1)
  56. assert ((a*A*B)**2).expand() == a**2*A*B*A*B
  57. assert ((a*A)**2).expand() == a**2*A**2
  58. assert ((a*A*B)**i).expand() == a**i*(A*B)**i
  59. assert ((a*A*(B*(A*B/A)**2))**i).expand() == a**i*(A*B*A*B**2/A)**i
  60. # issue 6558
  61. assert (A*B*(A*B)**-1).expand() == 1
  62. assert ((a*A)**i).expand() == a**i*A**i
  63. assert ((a*A*B*A**-1)**3).expand() == a**3*A*B**3/A
  64. assert ((a*A*B*A*B/A)**3).expand() == \
  65. a**3*A*B*(A*B**2)*(A*B**2)*A*B*A**(-1)
  66. assert ((a*A*B*A*B/A)**-2).expand() == \
  67. A*B**-1*A**-1*B**-2*A**-1*B**-1*A**-1/a**2
  68. assert ((a*b*A*B*A**-1)**i).expand() == a**i*b**i*(A*B/A)**i
  69. assert ((a*(a*b)**i)**i).expand() == a**i*a**(i**2)*b**(i**2)
  70. e = Pow(Mul(a, 1/a, A, B, evaluate=False), S(2), evaluate=False)
  71. assert e.expand() == A*B*A*B
  72. assert sqrt(a*(A*b)**i).expand() == sqrt(a*b**i*A**i)
  73. assert (sqrt(-a)**a).expand() == sqrt(-a)**a
  74. assert expand((-2*n)**(i/3)) == 2**(i/3)*(-n)**(i/3)
  75. assert expand((-2*n*m)**(i/a)) == (-2)**(i/a)*(-n)**(i/a)*(-m)**(i/a)
  76. assert expand((-2*a*p)**b) == 2**b*p**b*(-a)**b
  77. assert expand((-2*a*np)**b) == 2**b*(-a*np)**b
  78. assert expand(sqrt(A*B)) == sqrt(A*B)
  79. assert expand(sqrt(-2*a*b)) == sqrt(2)*sqrt(-a*b)
  80. def test_expand_radicals():
  81. a = (x + y)**R(1, 2)
  82. assert (a**1).expand() == a
  83. assert (a**3).expand() == x*a + y*a
  84. assert (a**5).expand() == x**2*a + 2*x*y*a + y**2*a
  85. assert (1/a**1).expand() == 1/a
  86. assert (1/a**3).expand() == 1/(x*a + y*a)
  87. assert (1/a**5).expand() == 1/(x**2*a + 2*x*y*a + y**2*a)
  88. a = (x + y)**R(1, 3)
  89. assert (a**1).expand() == a
  90. assert (a**2).expand() == a**2
  91. assert (a**4).expand() == x*a + y*a
  92. assert (a**5).expand() == x*a**2 + y*a**2
  93. assert (a**7).expand() == x**2*a + 2*x*y*a + y**2*a
  94. def test_expand_modulus():
  95. assert ((x + y)**11).expand(modulus=11) == x**11 + y**11
  96. assert ((x + sqrt(2)*y)**11).expand(modulus=11) == x**11 + 10*sqrt(2)*y**11
  97. assert (x + y/2).expand(modulus=1) == y/2
  98. raises(ValueError, lambda: ((x + y)**11).expand(modulus=0))
  99. raises(ValueError, lambda: ((x + y)**11).expand(modulus=x))
  100. def test_issue_5743():
  101. assert (x*sqrt(
  102. x + y)*(1 + sqrt(x + y))).expand() == x**2 + x*y + x*sqrt(x + y)
  103. assert (x*sqrt(
  104. x + y)*(1 + x*sqrt(x + y))).expand() == x**3 + x**2*y + x*sqrt(x + y)
  105. def test_expand_frac():
  106. assert expand((x + y)*y/x/(x + 1), frac=True) == \
  107. (x*y + y**2)/(x**2 + x)
  108. assert expand((x + y)*y/x/(x + 1), numer=True) == \
  109. (x*y + y**2)/(x*(x + 1))
  110. assert expand((x + y)*y/x/(x + 1), denom=True) == \
  111. y*(x + y)/(x**2 + x)
  112. eq = (x + 1)**2/y
  113. assert expand_numer(eq, multinomial=False) == eq
  114. def test_issue_6121():
  115. eq = -I*exp(-3*I*pi/4)/(4*pi**(S(3)/2)*sqrt(x))
  116. assert eq.expand(complex=True) # does not give oo recursion
  117. eq = -I*exp(-3*I*pi/4)/(4*pi**(R(3, 2))*sqrt(x))
  118. assert eq.expand(complex=True) # does not give oo recursion
  119. def test_expand_power_base():
  120. assert expand_power_base((x*y*z)**4) == x**4*y**4*z**4
  121. assert expand_power_base((x*y*z)**x).is_Pow
  122. assert expand_power_base((x*y*z)**x, force=True) == x**x*y**x*z**x
  123. assert expand_power_base((x*(y*z)**2)**3) == x**3*y**6*z**6
  124. assert expand_power_base((sin((x*y)**2)*y)**z).is_Pow
  125. assert expand_power_base(
  126. (sin((x*y)**2)*y)**z, force=True) == sin((x*y)**2)**z*y**z
  127. assert expand_power_base(
  128. (sin((x*y)**2)*y)**z, deep=True) == (sin(x**2*y**2)*y)**z
  129. assert expand_power_base(exp(x)**2) == exp(2*x)
  130. assert expand_power_base((exp(x)*exp(y))**2) == exp(2*x)*exp(2*y)
  131. assert expand_power_base(
  132. (exp((x*y)**z)*exp(y))**2) == exp(2*(x*y)**z)*exp(2*y)
  133. assert expand_power_base((exp((x*y)**z)*exp(
  134. y))**2, deep=True, force=True) == exp(2*x**z*y**z)*exp(2*y)
  135. assert expand_power_base((exp(x)*exp(y))**z).is_Pow
  136. assert expand_power_base(
  137. (exp(x)*exp(y))**z, force=True) == exp(x)**z*exp(y)**z
  138. def test_expand_arit():
  139. a = Symbol("a")
  140. b = Symbol("b", positive=True)
  141. c = Symbol("c")
  142. p = R(5)
  143. e = (a + b)*c
  144. assert e == c*(a + b)
  145. assert (e.expand() - a*c - b*c) == R(0)
  146. e = (a + b)*(a + b)
  147. assert e == (a + b)**2
  148. assert e.expand() == 2*a*b + a**2 + b**2
  149. e = (a + b)*(a + b)**R(2)
  150. assert e == (a + b)**3
  151. assert e.expand() == 3*b*a**2 + 3*a*b**2 + a**3 + b**3
  152. assert e.expand() == 3*b*a**2 + 3*a*b**2 + a**3 + b**3
  153. e = (a + b)*(a + c)*(b + c)
  154. assert e == (a + c)*(a + b)*(b + c)
  155. assert e.expand() == 2*a*b*c + b*a**2 + c*a**2 + b*c**2 + a*c**2 + c*b**2 + a*b**2
  156. e = (a + R(1))**p
  157. assert e == (1 + a)**5
  158. assert e.expand() == 1 + 5*a + 10*a**2 + 10*a**3 + 5*a**4 + a**5
  159. e = (a + b + c)*(a + c + p)
  160. assert e == (5 + a + c)*(a + b + c)
  161. assert e.expand() == 5*a + 5*b + 5*c + 2*a*c + b*c + a*b + a**2 + c**2
  162. x = Symbol("x")
  163. s = exp(x*x) - 1
  164. e = s.nseries(x, 0, 6)/x**2
  165. assert e.expand() == 1 + x**2/2 + O(x**4)
  166. e = (x*(y + z))**(x*(y + z))*(x + y)
  167. assert e.expand(power_exp=False, power_base=False) == x*(x*y + x*
  168. z)**(x*y + x*z) + y*(x*y + x*z)**(x*y + x*z)
  169. assert e.expand(power_exp=False, power_base=False, deep=False) == x* \
  170. (x*(y + z))**(x*(y + z)) + y*(x*(y + z))**(x*(y + z))
  171. e = x * (x + (y + 1)**2)
  172. assert e.expand(deep=False) == x**2 + x*(y + 1)**2
  173. e = (x*(y + z))**z
  174. assert e.expand(power_base=True, mul=True, deep=True) in [x**z*(y +
  175. z)**z, (x*y + x*z)**z]
  176. assert ((2*y)**z).expand() == 2**z*y**z
  177. p = Symbol('p', positive=True)
  178. assert sqrt(-x).expand().is_Pow
  179. assert sqrt(-x).expand(force=True) == I*sqrt(x)
  180. assert ((2*y*p)**z).expand() == 2**z*p**z*y**z
  181. assert ((2*y*p*x)**z).expand() == 2**z*p**z*(x*y)**z
  182. assert ((2*y*p*x)**z).expand(force=True) == 2**z*p**z*x**z*y**z
  183. assert ((2*y*p*-pi)**z).expand() == 2**z*pi**z*p**z*(-y)**z
  184. assert ((2*y*p*-pi*x)**z).expand() == 2**z*pi**z*p**z*(-x*y)**z
  185. n = Symbol('n', negative=True)
  186. m = Symbol('m', negative=True)
  187. assert ((-2*x*y*n)**z).expand() == 2**z*(-n)**z*(x*y)**z
  188. assert ((-2*x*y*n*m)**z).expand() == 2**z*(-m)**z*(-n)**z*(-x*y)**z
  189. # issue 5482
  190. assert sqrt(-2*x*n) == sqrt(2)*sqrt(-n)*sqrt(x)
  191. # issue 5605 (2)
  192. assert (cos(x + y)**2).expand(trig=True) in [
  193. (-sin(x)*sin(y) + cos(x)*cos(y))**2,
  194. sin(x)**2*sin(y)**2 - 2*sin(x)*sin(y)*cos(x)*cos(y) + cos(x)**2*cos(y)**2
  195. ]
  196. # Check that this isn't too slow
  197. x = Symbol('x')
  198. W = 1
  199. for i in range(1, 21):
  200. W = W * (x - i)
  201. W = W.expand()
  202. assert W.has(-1672280820*x**15)
  203. def test_expand_mul():
  204. # part of issue 20597
  205. e = Mul(2, 3, evaluate=False)
  206. assert e.expand() == 6
  207. e = Mul(2, 3, 1/x, evaluate = False)
  208. assert e.expand() == 6/x
  209. e = Mul(2, R(1, 3), evaluate=False)
  210. assert e.expand() == R(2, 3)
  211. def test_power_expand():
  212. """Test for Pow.expand()"""
  213. a = Symbol('a')
  214. b = Symbol('b')
  215. p = (a + b)**2
  216. assert p.expand() == a**2 + b**2 + 2*a*b
  217. p = (1 + 2*(1 + a))**2
  218. assert p.expand() == 9 + 4*(a**2) + 12*a
  219. p = 2**(a + b)
  220. assert p.expand() == 2**a*2**b
  221. A = Symbol('A', commutative=False)
  222. B = Symbol('B', commutative=False)
  223. assert (2**(A + B)).expand() == 2**(A + B)
  224. assert (A**(a + b)).expand() != A**(a + b)
  225. def test_issues_5919_6830():
  226. # issue 5919
  227. n = -1 + 1/x
  228. z = n/x/(-n)**2 - 1/n/x
  229. assert expand(z) == 1/(x**2 - 2*x + 1) - 1/(x - 2 + 1/x) - 1/(-x + 1)
  230. # issue 6830
  231. p = (1 + x)**2
  232. assert expand_multinomial((1 + x*p)**2) == (
  233. x**2*(x**4 + 4*x**3 + 6*x**2 + 4*x + 1) + 2*x*(x**2 + 2*x + 1) + 1)
  234. assert expand_multinomial((1 + (y + x)*p)**2) == (
  235. 2*((x + y)*(x**2 + 2*x + 1)) + (x**2 + 2*x*y + y**2)*
  236. (x**4 + 4*x**3 + 6*x**2 + 4*x + 1) + 1)
  237. A = Symbol('A', commutative=False)
  238. p = (1 + A)**2
  239. assert expand_multinomial((1 + x*p)**2) == (
  240. x**2*(1 + 4*A + 6*A**2 + 4*A**3 + A**4) + 2*x*(1 + 2*A + A**2) + 1)
  241. assert expand_multinomial((1 + (y + x)*p)**2) == (
  242. (x + y)*(1 + 2*A + A**2)*2 + (x**2 + 2*x*y + y**2)*
  243. (1 + 4*A + 6*A**2 + 4*A**3 + A**4) + 1)
  244. assert expand_multinomial((1 + (y + x)*p)**3) == (
  245. (x + y)*(1 + 2*A + A**2)*3 + (x**2 + 2*x*y + y**2)*(1 + 4*A +
  246. 6*A**2 + 4*A**3 + A**4)*3 + (x**3 + 3*x**2*y + 3*x*y**2 + y**3)*(1 + 6*A
  247. + 15*A**2 + 20*A**3 + 15*A**4 + 6*A**5 + A**6) + 1)
  248. # unevaluate powers
  249. eq = (Pow((x + 1)*((A + 1)**2), 2, evaluate=False))
  250. # - in this case the base is not an Add so no further
  251. # expansion is done
  252. assert expand_multinomial(eq) == \
  253. (x**2 + 2*x + 1)*(1 + 4*A + 6*A**2 + 4*A**3 + A**4)
  254. # - but here, the expanded base *is* an Add so it gets expanded
  255. eq = (Pow(((A + 1)**2), 2, evaluate=False))
  256. assert expand_multinomial(eq) == 1 + 4*A + 6*A**2 + 4*A**3 + A**4
  257. # coverage
  258. def ok(a, b, n):
  259. e = (a + I*b)**n
  260. return verify_numerically(e, expand_multinomial(e))
  261. for a in [2, S.Half]:
  262. for b in [3, R(1, 3)]:
  263. for n in range(2, 6):
  264. assert ok(a, b, n)
  265. assert expand_multinomial((x + 1 + O(z))**2) == \
  266. 1 + 2*x + x**2 + O(z)
  267. assert expand_multinomial((x + 1 + O(z))**3) == \
  268. 1 + 3*x + 3*x**2 + x**3 + O(z)
  269. assert expand_multinomial(3**(x + y + 3)) == 27*3**(x + y)
  270. def test_expand_log():
  271. t = Symbol('t', positive=True)
  272. # after first expansion, -2*log(2) + log(4); then 0 after second
  273. assert expand(log(t**2) - log(t**2/4) - 2*log(2)) == 0
  274. def test_issue_23952():
  275. assert (x**(y + z)).expand(force=True) == x**y*x**z
  276. one = Symbol('1', integer=True, prime=True, odd=True, positive=True)
  277. two = Symbol('2', integer=True, prime=True, even=True)
  278. e = two - one
  279. for b in (0, x):
  280. # 0**e = 0, 0**-e = zoo; but if expanded then nan
  281. assert unchanged(Pow, b, e) # power_exp
  282. assert unchanged(Pow, b, -e) # power_exp
  283. assert unchanged(Pow, b, y - x) # power_exp
  284. assert unchanged(Pow, b, 3 - x) # multinomial
  285. assert (b**e).expand().is_Pow # power_exp
  286. assert (b**-e).expand().is_Pow # power_exp
  287. assert (b**(y - x)).expand().is_Pow # power_exp
  288. assert (b**(3 - x)).expand().is_Pow # multinomial
  289. nn1 = Symbol('nn1', nonnegative=True)
  290. nn2 = Symbol('nn2', nonnegative=True)
  291. nn3 = Symbol('nn3', nonnegative=True)
  292. assert (x**(nn1 + nn2)).expand() == x**nn1*x**nn2
  293. assert (x**(-nn1 - nn2)).expand() == x**-nn1*x**-nn2
  294. assert unchanged(Pow, x, nn1 + nn2 - nn3)
  295. assert unchanged(Pow, x, 1 + nn2 - nn3)
  296. assert unchanged(Pow, x, nn1 - nn2)
  297. assert unchanged(Pow, x, 1 - nn2)
  298. assert unchanged(Pow, x, -1 + nn2)