test_radsimp.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490
  1. from sympy.core.add import Add
  2. from sympy.core.function import (Derivative, Function, diff)
  3. from sympy.core.mul import Mul
  4. from sympy.core.numbers import (I, Rational)
  5. from sympy.core.power import Pow
  6. from sympy.core.singleton import S
  7. from sympy.core.symbol import (Symbol, Wild, symbols)
  8. from sympy.functions.elementary.complexes import Abs
  9. from sympy.functions.elementary.exponential import (exp, log)
  10. from sympy.functions.elementary.miscellaneous import (root, sqrt)
  11. from sympy.functions.elementary.trigonometric import (cos, sin)
  12. from sympy.polys.polytools import factor
  13. from sympy.series.order import O
  14. from sympy.simplify.radsimp import (collect, collect_const, fraction, radsimp, rcollect)
  15. from sympy.core.expr import unchanged
  16. from sympy.core.mul import _unevaluated_Mul as umul
  17. from sympy.simplify.radsimp import (_unevaluated_Add,
  18. collect_sqrt, fraction_expand, collect_abs)
  19. from sympy.testing.pytest import raises
  20. from sympy.abc import x, y, z, a, b, c, d
  def test_radsimp():
      """Check radsimp() results against expected expressions.

      The cases cover numeric and symbolic radical denominators, nested
      radicals, powers, recursion into function arguments, and the
      ``symbolic=False`` and ``max_terms`` options.
      """
      r2 = sqrt(2)
      r3 = sqrt(3)
      r5 = sqrt(5)
      r7 = sqrt(7)
      # purely numeric denominators with an increasing number of radical terms
      assert fraction(radsimp(1/r2)) == (sqrt(2), 2)
      assert radsimp(1/(1 + r2)) == \
          -1 + sqrt(2)
      assert radsimp(1/(r2 + r3)) == \
          -sqrt(2) + sqrt(3)
      assert fraction(radsimp(1/(1 + r2 + r3))) == \
          (-sqrt(6) + sqrt(2) + 2, 4)
      assert fraction(radsimp(1/(r2 + r3 + r5))) == \
          (-sqrt(30) + 2*sqrt(3) + 3*sqrt(2), 12)
      assert fraction(radsimp(1/(1 + r2 + r3 + r5))) == (
          (-34*sqrt(10) - 26*sqrt(15) - 55*sqrt(3) - 61*sqrt(2) + 14*sqrt(30) +
          93 + 46*sqrt(6) + 53*sqrt(5), 71))
      assert fraction(radsimp(1/(r2 + r3 + r5 + r7))) == (
          (-50*sqrt(42) - 133*sqrt(5) - 34*sqrt(70) - 145*sqrt(3) + 22*sqrt(105)
          + 185*sqrt(2) + 62*sqrt(30) + 135*sqrt(7), 215))
      # z is rebound locally here; it shadows the z imported from sympy.abc
      z = radsimp(1/(1 + r2/3 + r3/5 + r5 + r7))
      assert len((3616791619821680643598*z).args) == 16
      # the default call leaves 1/z unchanged, while max_terms=20 recovers
      # the original denominator after expansion
      assert radsimp(1/z) == 1/z
      assert radsimp(1/z, max_terms=20).expand() == 1 + r2/3 + r3/5 + r5 + r7
      assert radsimp(1/(r2*3)) == \
          sqrt(2)/6
      # denominators mixing radicals with the symbols a, b, c, d
      assert radsimp(1/(r2*a + r3 + r5 + r7)) == (
          (8*sqrt(2)*a**7 - 8*sqrt(7)*a**6 - 8*sqrt(5)*a**6 - 8*sqrt(3)*a**6 -
          180*sqrt(2)*a**5 + 8*sqrt(30)*a**5 + 8*sqrt(42)*a**5 + 8*sqrt(70)*a**5
          - 24*sqrt(105)*a**4 + 84*sqrt(3)*a**4 + 100*sqrt(5)*a**4 +
          116*sqrt(7)*a**4 - 72*sqrt(70)*a**3 - 40*sqrt(42)*a**3 -
          8*sqrt(30)*a**3 + 782*sqrt(2)*a**3 - 462*sqrt(3)*a**2 -
          302*sqrt(7)*a**2 - 254*sqrt(5)*a**2 + 120*sqrt(105)*a**2 -
          795*sqrt(2)*a - 62*sqrt(30)*a + 82*sqrt(42)*a + 98*sqrt(70)*a -
          118*sqrt(105) + 59*sqrt(7) + 295*sqrt(5) + 531*sqrt(3))/(16*a**8 -
          480*a**6 + 3128*a**4 - 6360*a**2 + 3481))
      assert radsimp(1/(r2*a + r2*b + r3 + r7)) == (
          (sqrt(2)*a*(a + b)**2 - 5*sqrt(2)*a + sqrt(42)*a + sqrt(2)*b*(a +
          b)**2 - 5*sqrt(2)*b + sqrt(42)*b - sqrt(7)*(a + b)**2 - sqrt(3)*(a +
          b)**2 - 2*sqrt(3) + 2*sqrt(7))/(2*a**4 + 8*a**3*b + 12*a**2*b**2 -
          20*a**2 + 8*a*b**3 - 40*a*b + 2*b**4 - 20*b**2 + 8))
      assert radsimp(1/(r2*a + r2*b + r2*c + r2*d)) == \
          sqrt(2)/(2*a + 2*b + 2*c + 2*d)
      assert radsimp(1/(1 + r2*a + r2*b + r2*c + r2*d)) == (
          (sqrt(2)*a + sqrt(2)*b + sqrt(2)*c + sqrt(2)*d - 1)/(2*a**2 + 4*a*b +
          4*a*c + 4*a*d + 2*b**2 + 4*b*c + 4*b*d + 2*c**2 + 4*c*d + 2*d**2 - 1))
      assert radsimp((y**2 - x)/(y - sqrt(x))) == \
          sqrt(x) + y
      assert radsimp(-(y**2 - x)/(y - sqrt(x))) == \
          -(sqrt(x) + y)
      assert radsimp(1/(1 - I + a*I)) == \
          (-I*a + 1 + I)/(a**2 - 2*a + 2)
      assert radsimp(1/((-x + y)*(x - sqrt(y)))) == \
          (-x - sqrt(y))/((x - y)*(x**2 - y))
      e = (3 + 3*sqrt(2))*x*(3*x - 3*sqrt(y))
      assert radsimp(e) == x*(3 + 3*sqrt(2))*(3*x - 3*sqrt(y))
      assert radsimp(1/e) == (
          (-9*x + 9*sqrt(2)*x - 9*sqrt(y) + 9*sqrt(2)*sqrt(y))/(9*x*(9*x**2 -
          9*y)))
      assert radsimp(1 + 1/(1 + sqrt(3))) == \
          Mul(S.Half, -1 + sqrt(3), evaluate=False) + 1
      # an expression containing a non-commutative symbol comes back unchanged
      A = symbols("A", commutative=False)
      assert radsimp(x**2 + sqrt(2)*x**2 - sqrt(2)*x*A) == \
          x**2 + sqrt(2)*x**2 - sqrt(2)*x*A
      assert radsimp(1/sqrt(5 + 2 * sqrt(6))) == -sqrt(2) + sqrt(3)
      assert radsimp(1/sqrt(5 + 2 * sqrt(6))**3) == -(-sqrt(3) + sqrt(2))**3

      # issue 6532
      assert fraction(radsimp(1/sqrt(x))) == (sqrt(x), x)
      assert fraction(radsimp(1/sqrt(2*x + 3))) == (sqrt(2*x + 3), 2*x + 3)
      assert fraction(radsimp(1/sqrt(2*(x + 3)))) == (sqrt(2*x + 6), 2*x + 6)

      # issue 5994
      e = S('-(2 + 2*sqrt(2) + 4*2**(1/4))/'
          '(1 + 2**(3/4) + 3*2**(1/4) + 3*sqrt(2))')
      assert radsimp(e).expand() == -2*2**Rational(3, 4) - 2*2**Rational(1, 4) + 2 + 2*sqrt(2)

      # issue 5986 (modifications to radsimp didn't initially recognize this so
      # the test is included here)
      assert radsimp(1/(-sqrt(5)/2 - S.Half + (-sqrt(5)/2 - S.Half)**2)) == 1

      # from issue 5934
      eq = (
          (-240*sqrt(2)*sqrt(sqrt(5) + 5)*sqrt(8*sqrt(5) + 40) -
          360*sqrt(2)*sqrt(-8*sqrt(5) + 40)*sqrt(-sqrt(5) + 5) -
          120*sqrt(10)*sqrt(-8*sqrt(5) + 40)*sqrt(-sqrt(5) + 5) +
          120*sqrt(2)*sqrt(-sqrt(5) + 5)*sqrt(8*sqrt(5) + 40) +
          120*sqrt(2)*sqrt(-8*sqrt(5) + 40)*sqrt(sqrt(5) + 5) +
          120*sqrt(10)*sqrt(-sqrt(5) + 5)*sqrt(8*sqrt(5) + 40) +
          120*sqrt(10)*sqrt(-8*sqrt(5) + 40)*sqrt(sqrt(5) + 5))/(-36000 -
          7200*sqrt(5) + (12*sqrt(10)*sqrt(sqrt(5) + 5) +
          24*sqrt(10)*sqrt(-sqrt(5) + 5))**2))
      assert radsimp(eq) is S.NaN  # it's 0/0

      # work with normal form
      e = 1/sqrt(sqrt(7)/7 + 2*sqrt(2) + 3*sqrt(3) + 5*sqrt(5)) + 3
      assert radsimp(e) == (
          -sqrt(sqrt(7) + 14*sqrt(2) + 21*sqrt(3) +
          35*sqrt(5))*(-11654899*sqrt(35) - 1577436*sqrt(210) - 1278438*sqrt(15)
          - 1346996*sqrt(10) + 1635060*sqrt(6) + 5709765 + 7539830*sqrt(14) +
          8291415*sqrt(21))/1300423175 + 3)

      # obey power rules
      base = sqrt(3) - sqrt(2)
      assert radsimp(1/base**3) == (sqrt(3) + sqrt(2))**3
      assert radsimp(1/(-base)**3) == -(sqrt(2) + sqrt(3))**3
      assert radsimp(1/(-base)**x) == (-base)**(-x)
      assert radsimp(1/base**x) == (sqrt(2) + sqrt(3))**x
      assert radsimp(root(1/(-1 - sqrt(2)), -x)) == (-1)**(-1/x)*(1 + sqrt(2))**(1/x)

      # recurse
      e = cos(1/(1 + sqrt(2)))
      assert radsimp(e) == cos(-sqrt(2) + 1)
      assert radsimp(e/2) == cos(-sqrt(2) + 1)/2
      assert radsimp(1/e) == 1/cos(-sqrt(2) + 1)
      assert radsimp(2/e) == 2/cos(-sqrt(2) + 1)
      assert fraction(radsimp(e/sqrt(x))) == (sqrt(x)*cos(-sqrt(2)+1), x)

      # test that symbolic denominators are not processed
      r = 1 + sqrt(2)
      assert radsimp(x/r, symbolic=False) == -x*(-sqrt(2) + 1)
      assert radsimp(x/(y + r), symbolic=False) == x/(y + 1 + sqrt(2))
      assert radsimp(x/(y + r)/r, symbolic=False) == \
          -x*(-sqrt(2) + 1)/(y + 1 + sqrt(2))

      # issue 7408
      eq = sqrt(x)/sqrt(y)
      assert radsimp(eq) == umul(sqrt(x), sqrt(y), 1/y)
      assert radsimp(eq, symbolic=False) == eq

      # issue 7498
      assert radsimp(sqrt(x)/sqrt(y)**3) == umul(sqrt(x), sqrt(y**3), 1/y**3)

      # for coverage
      eq = sqrt(x)/y**2
      assert radsimp(eq) == eq
  def test_radsimp_issue_3214():
      """radsimp on a complex quotient built from positive symbols (issue 3214)."""
      c, p = symbols('c p', positive=True)
      root_term = sqrt(c**2 - p**2)
      numerator = c + I*p - root_term
      denominator = c + I*p + root_term
      expected = -I*(c + I*p - sqrt(c**2 - p**2))**2/(2*c*p)
      assert radsimp(numerator/denominator) == expected
  def test_collect_1():
      """Collect with respect to Symbol"""
      x, y, z, n = symbols('x,y,z,n')

      # (result of collect, expected expression) pairs
      simple_cases = [
          (collect(1, x), 1),
          (collect(x + y*x, x), x * (1 + y)),
          (collect(x + x**2, x), x + x**2),
          (collect(x**2 + y*x**2, x), (x**2)*(1 + y)),
          (collect(x**2 + y*x, x), x*y + x**2),
          (collect(2*x**2 + y*x**2 + 3*x*y, [x]), x**2*(2 + y) + 3*x*y),
          (collect(2*x**2 + y*x**2 + 3*x*y, [y]), 2*x**2 + y*(x**2 + 3*x)),
      ]
      for collected, expected in simple_cases:
          assert collected == expected

      # expanded quartic, regrouped by powers of x
      quartic = ((1 + y + x)**4).expand()
      by_powers = (
          ((1 + y)**4).expand()
          + x*(4*(1 + y)**3).expand()
          + x**2*(6*(1 + y)**2).expand()
          + x**3*(4*(1 + y)).expand()
          + x**4
      )
      assert collect(quartic, x) == by_powers

      # symbols can be given as any iterable
      total = x + y
      assert collect(total, total.free_symbols) == total

      # exact=None cases
      short_sum = x*exp(x) + sin(x)*y + sin(x)*2 + 3*x
      assert collect(short_sum, x, exact=None) == (
          x*exp(x) + 3*x + (y + 2)*sin(x))
      long_sum = (x*exp(x) + sin(x)*y + sin(x)*2 + 3*x + y*x +
                  y*x*exp(x))
      assert collect(long_sum, x, exact=None) == (
          x*exp(x)*(y + 1) + (3 + y)*x + (y + 2)*sin(x))
  def test_collect_2():
      """Collect with respect to a sum"""
      a, b, x = symbols('a,b,x')
      trig_sum = cos(x) + sin(x)
      collected = collect(a*trig_sum + b*trig_sum, sin(x) + cos(x))
      assert collected == (a + b)*trig_sum
  def test_collect_3():
      """Collect with respect to a product"""
      a, b, c = symbols('a,b,c')
      f = Function('f')
      x, y, z, n = symbols('x,y,z,n')

      assert collect(-x/8 + x*y, -x) == x*(y - Rational(1, 8))

      # products of plain symbols
      xy = x*y
      assert collect(1 + x*(y**2), xy) == 1 + x*(y**2)
      assert collect(xy + a*x*y, xy) == xy*(1 + a)
      assert collect(1 + xy + a*x*y, xy) == 1 + xy*(1 + a)

      # products of a symbol and a function of it
      x_f = x*f(x)
      assert collect(a*x*f(x) + b*x_f, x_f) == x*(a + b)*f(x)
      x_log = x*log(x)
      assert collect(a*x*log(x) + b*x_log, x_log) == x*(a + b)*log(x)
      assert collect(a*x**2*log(x)**2 + b*x_log**2, x_log) == (
          x**2*log(x)**2*(a + b))

      # with respect to a product of three symbols
      xyz = x*y*z
      assert collect(y*x*z + a*x*y*z, xyz) == (1 + a)*x*y*z
  def test_collect_4():
      """Collect with respect to a power"""
      a, b, c, x = symbols('a,b,c,x')

      single = x**c
      assert collect(a*single + b*single, single) == single*(a + b)

      # issue 6096: 2 stays with c (unless c is integer or x is positive)
      doubled = x**(2*c)
      assert collect(a*doubled + b*doubled, x**c) == doubled*(a + b)
  def test_collect_5():
      """Collect with respect to a tuple"""
      a, x, y, z, n = symbols('a,x,y,z,n')

      # either grouping is accepted as a valid answer
      acceptable = [
          z*(1 + a + x**2*y**4) + x**2*y**4,
          z*(1 + a) + x**2*y**4*(1 + z),
      ]
      result = collect(x**2*y**4 + z*(x*y**2)**2 + z + a*z, [x*y**2, z])
      assert result in acceptable

      expanded = (1 + (x + y) + (x + y)**2).expand()
      assert collect(expanded, [x, y]) == 1 + y + x*(1 + 2*y) + x**2 + y**2
  def test_collect_pr19431():
      """Unevaluated collect with respect to a product"""
      a = symbols('a')
      square = a**2
      coefficients = collect(square*(square + 1), square, evaluate=False)
      assert coefficients[square] == (square + 1)
  def test_collect_D():
      """Collect with respect to Derivative objects."""
      D = Derivative
      f = Function('f')
      x, a, b = symbols('x,a,b')
      fx = D(f(x), x)
      fxx = D(f(x), x, x)

      assert collect(a*fx + b*fx, fx) == (a + b)*fx
      second = D(fx, x)
      assert collect(a*second + b*second, fx) == (a + b)*second
      assert collect(a*fxx + b*fxx, fx) == (a + b)*second

      # issue 4784
      assert collect(5*f(x) + 3*fx, fx) == 5*f(x) + 3*fx

      # the same sum is collected with and without exact=True
      product_sum = f(x) + f(x)*diff(f(x), x) + x*diff(f(x), x)*f(x)
      product_grouped = (x*f(x) + f(x))*D(f(x), x) + f(x)
      assert collect(product_sum, f(x).diff(x)) == product_grouped
      assert collect(product_sum, f(x).diff(x), exact=True) == product_grouped

      quotient_sum = 1/f(x) + 1/f(x)*diff(f(x), x) + x*diff(f(x), x)/f(x)
      assert collect(quotient_sum, f(x).diff(x), exact=True) == (
          (1/f(x) + x/f(x))*D(f(x), x) + 1/f(x))

      ratio = (1 + x*fx + fx)/f(x)
      assert collect(ratio.expand(), fx) == fx*(x/f(x) + 1/f(x)) + 1/f(x)
  def test_collect_func():
      """Collect with and without a post-processing function and evaluate=False."""
      cubic = ((x + a + 1)**3).expand()

      plain = (a**3 + 3*a**2 + 3*a + x**3 + x**2*(3*a + 3) +
               x*(3*a**2 + 6*a + 3) + 1)
      assert collect(cubic, x) == plain

      factored = (x**3 + 3*x**2*(a + 1) + 3*x*(a + 1)**2 +
                  (a + 1)**3)
      assert collect(cubic, x, factor) == factored

      plain_terms = {
          S.One: a**3 + 3*a**2 + 3*a + 1,
          x: 3*a**2 + 6*a + 3,
          x**2: 3*a + 3,
          x**3: 1,
      }
      assert collect(cubic, x, evaluate=False) == plain_terms

      factored_terms = {
          S.One: (a + 1)**3,
          x: 3*(a + 1)**2,
          x**2: umul(S(3), a + 1),
          x**3: 1,
      }
      assert collect(cubic, x, factor, evaluate=False) == factored_terms
  def test_collect_order():
      """Collect expressions that carry an Order term."""
      a, b, x, t = symbols('a,b,x,t')
      order3 = O(x**3)

      assert collect(t + t*x + t*x**2 + order3, t) == (
          t*(1 + x + x**2 + O(x**3)))
      assert collect(t + t*x + x**2 + order3, t) == (
          t*(1 + x + O(x**3)) + x**2 + O(x**3))

      # c and d are the module-level symbols; a, b, x are the local ones
      ungrouped = a*x + b*x + c*x**2 + d*x**2 + O(x**3)
      grouped = x*(a + b) + x**2*(c + d) + O(x**3)
      assert collect(ungrouped, x) == grouped
      assert collect(ungrouped, x, distribute_order_term=False) == grouped

      # series of sin(a + b) in b, collected on sin(a) and cos(a)
      series_sum = sin(a + b).series(b, 0, 10)
      cos_series = cos(b).series(b, 0, 10)
      sin_series = sin(b).series(b, 0, 10)
      targets = [sin(a), cos(a)]
      assert collect(series_sum, targets) == (
          sin(a)*cos_series + cos(a)*sin_series)
      assert collect(series_sum, targets, distribute_order_term=False) == (
          sin(a)*cos_series.removeO() +
          cos(a)*sin_series.removeO() + O(b**10))
  def test_rcollect():
      """rcollect on a quotient and on an expression lacking the target symbol."""
      quotient = (x**2*y + x*y + x + y)/(x + y)
      assert rcollect(quotient, y) == (x + y*(1 + x + x**2))/(x + y)

      # z does not occur in the expression, which comes back unchanged
      radical = sqrt(-((x + 1)*(y + 1)))
      assert rcollect(radical, z) == sqrt(-((x + 1)*(y + 1)))
  def test_collect_D_0():
      """Collect with respect to a second derivative."""
      D = Derivative
      f = Function('f')
      x, a, b = symbols('x,a,b')
      fxx = D(f(x), x, x)

      collected = collect(a*fxx + b*fxx, fxx)
      assert collected == (a + b)*fxx
  def test_collect_Wild():
      """Collect with respect to functions with Wild argument"""
      a, b, x, y = symbols('a b x y')
      f = Function('f')
      w1 = Wild('.1')
      w2 = Wild('.2')

      # function patterns
      assert collect(f(x) + a*f(x), f(w1)) == (1 + a)*f(x)
      two_arg_sum = f(x, y) + a*f(x, y)
      assert collect(two_arg_sum, f(w1)) == f(x, y) + a*f(x, y)
      assert collect(two_arg_sum, f(w1, w2)) == (1 + a)*f(x, y)
      assert collect(two_arg_sum, f(w1, w1)) == f(x, y) + a*f(x, y)
      assert collect(f(x, x) + a*f(x, x), f(w1, w1)) == (1 + a)*f(x, x)

      # power patterns
      power = (x + 1)**y
      power_sum = a*power + power
      assert collect(power_sum, w1**y) == (1 + a)*power
      assert collect(power_sum, w1**b) == a*(x + 1)**y + (x + 1)**y
      assert collect(power_sum, (x + 1)**w2) == (1 + a)*power
      assert collect(power_sum, w1**w2) == (1 + a)*power
  def test_collect_const():
      """Check collect_const and collect_sqrt on sums with constant factors."""
      # coverage not provided by above tests
      mixed = 2*sqrt(3) + 4*a*sqrt(5)
      # let the primitive reabsorb
      assert collect_const(mixed) == 2*(2*sqrt(5)*a + sqrt(3))
      assert collect_const(mixed, sqrt(3)) == 2*sqrt(3) + 4*a*sqrt(5)
      assert collect_const(sqrt(2)*(1 + sqrt(2)) + sqrt(3) + x*sqrt(2)) == (
          sqrt(2)*(x + 1 + sqrt(2)) + sqrt(3))

      # issue 5290
      linear = 2*x + 2*y + 1
      linear_grouped = Add(S.One, Mul(2, x + y, evaluate=False), evaluate=False)
      assert collect_const(linear, 2) == collect_const(linear) == linear_grouped
      assert collect_const(-y - z) == Mul(-1, y + z, evaluate=False)
      difference = 2*x - 2*y - 2*z
      assert collect_const(difference, 2) == Mul(2, x - y - z, evaluate=False)
      assert collect_const(difference, -2) == (
          _unevaluated_Add(2*x, Mul(-2, y + z, evaluate=False)))

      # this is why the content_primitive is used
      nested = (sqrt(15 + 5*sqrt(2))*x + sqrt(3 + sqrt(2))*y)*2
      assert collect_sqrt(nested + 2) == (
          2*sqrt(sqrt(2) + 3)*(sqrt(5)*x + y) + 2)

      # issue 16296
      halves_grouped = a + b + Mul(S.Half, x + y, evaluate=False)
      assert collect_const(a + b + x/2 + y/2) == halves_grouped
  def test_issue_13143():
      """Collect sums of a function and its derivative (issue 13143)."""
      f = Function('f')
      fx = f(x).diff(x)

      simple = f(x) + fx + f(x)*fx
      # collect function before derivative
      assert collect(simple, Wild('w')) == f(x)*(fx + 1) + fx

      with_x = f(x) + f(x)*fx + x*fx*f(x)
      assert collect(with_x, fx) == (x*f(x) + f(x))*fx + f(x)
      assert collect(with_x, f(x)) == (x*fx + fx + 1)*f(x)

      # the order of the collection targets decides what is grouped first
      assert collect(simple, [f(x), fx]) == f(x)*(1 + fx) + fx
      assert collect(simple, [fx, f(x)]) == fx*(1 + f(x)) + f(x)
  def test_issue_6097():
      """Collect powers with a float multiple of the exponent (issue 6097)."""
      # same check for a symbolic base and for the integer base 2
      for base in (y, 2):
          doubled = base**(2.0*x)
          collected = collect(a*doubled + b*doubled, base**x)
          assert collected == (a + b)*(base**x)**2.0
  def test_fraction_expand():
      """fraction_expand matches expand(frac=True) and differs from plain expand."""
      quotient = (x + y)*y/x
      single_fraction = (x*y + y**2)/x
      assert quotient.expand(frac=True) == fraction_expand(quotient) == single_fraction
      assert quotient.expand() == y + y**2/x
  def test_fraction():
      """Check the (numerator, denominator) pairs returned by fraction()."""
      x, y, z = (Symbol(name) for name in 'xyz')
      A = Symbol('A', commutative=False)

      # (result of fraction, expected pair) for commutative inputs
      basic_cases = [
          (fraction(S.Half), (1, 2)),
          (fraction(x), (x, 1)),
          (fraction(1/x), (1, x)),
          (fraction(x/y), (x, y)),
          (fraction(x/2), (x, 2)),
          (fraction(x*y/z), (x*y, z)),
          (fraction(x/(y*z)), (x, y*z)),
          (fraction(1/y**2), (1, y**2)),
          (fraction(x/y**2), (x, y**2)),
          (fraction((x**2 + 1)/y), (x**2 + 1, y)),
          (fraction(x*(y + 1)/y**7), (x*(y + 1), y**7)),
          (fraction(exp(-x), exact=True), (exp(-x), 1)),
          (fraction((1/(x + y))/2, exact=True),
           (1, Mul(2, (x + y), evaluate=False))),
      ]
      for pair, expected in basic_cases:
          assert pair == expected

      # non-commutative factor in the numerator
      assert fraction(x*A/y) == (x*A, y)
      assert fraction(x*A**-1/y) == (x*A**-1, y)

      # exponentials of symbols with a known sign
      n = symbols('n', negative=True)
      assert fraction(exp(n)) == (1, exp(-n))
      assert fraction(exp(-n)) == (exp(-n), 1)

      p = symbols('p', positive=True)
      assert fraction(exp(-p)*log(p), exact=True) == (exp(-p)*log(p), 1)

      # unevaluated products
      half_product = Mul(1, 1, S.Half, evaluate=False)
      assert fraction(half_product) == (1, 2)
      assert fraction(half_product, exact=True) == (Mul(1, 1, evaluate=False), 2)

      quarter_product = Mul(1, 1, S.Half, S.Half, Pow(1, -1, evaluate=False), evaluate=False)
      assert fraction(quarter_product) == (1, 4)
      assert fraction(quarter_product, exact=True) == (
          Mul(1, 1, evaluate=False), Mul(2, 2, 1, evaluate=False))
  def test_issue_5615():
      """Collecting on [aA**3/Re, a] leaves the expanded expression unchanged."""
      aA, Re, a, b, D = symbols('aA Re a b D')
      expanded = ((D**3*a + b*aA**3)/Re).expand()
      targets = [aA**3/Re, a]
      assert collect(expanded, targets) == expanded
  def test_issue_5933():
      """Denominator of a pentagon centroid coordinate stays away from zero."""
      from sympy.geometry.polygon import (Polygon, RegularPolygon)
      from sympy.simplify.radsimp import denom

      pentagon = RegularPolygon((0, 0), 1, 5)
      centroid_x = Polygon(*pentagon.vertices).centroid.x
      threshold = 1e-12
      assert abs(denom(centroid_x).n()) > threshold
      # in case simplify didn't handle it
      assert abs(denom(radsimp(centroid_x))) > threshold
  def test_issue_14608():
      """Collect behaviour when non-commutative symbols are involved."""
      a, b = symbols('a b', commutative=False)
      x, y = symbols('x y')
      noncommutative_sum = a*b + b*a

      # collecting on a non-commutative symbol raises AttributeError
      raises(AttributeError, lambda: collect(noncommutative_sum, a))

      commutative_sum = x*y + y*(x+1)
      assert collect(commutative_sum, a) == x*y + y*(x+1)
      assert collect(commutative_sum + a*b + b*a, y) == (
          y*(2*x + 1) + a*b + b*a)
  def test_collect_abs():
      """collect_abs merges products of Abs and leaves sums of Abs alone."""
      abs_sum = abs(x) + abs(y)
      assert collect_abs(abs_sum) == abs_sum

      # the product of two Abs terms does not merge on its own
      assert unchanged(Mul, abs(x), abs(y))
      merged = Abs(x*y)
      assert isinstance(merged, Abs)
      assert collect_abs(abs(x)*abs(y)) == merged
      assert collect_abs(1 + exp(abs(x)*abs(y))) == 1 + exp(merged)

      # See https://github.com/sympy/sympy/issues/12910
      p = Symbol('p', positive=True)
      ratio = p/abs(1-p)
      assert collect_abs(ratio).is_commutative is True
  381. # See https://github.com/sympy/sympy/issues/12910
  382. p = Symbol('p', positive=True)
  383. assert collect_abs(p/abs(1-p)).is_commutative is True
  384. def test_issue_19149():
  385. eq = exp(3*x/4)
  386. assert collect(eq, exp(x)) == eq
  387. def test_issue_19719():
  388. a, b = symbols('a, b')
  389. expr = a**2 * (b + 1) + (7 + 1/b)/a
  390. collected = collect(expr, (a**2, 1/a), evaluate=False)
  391. # Would return {_Dummy_20**(-2): b + 1, 1/a: 7 + 1/b} without xreplace
  392. assert collected == {a**2: b + 1, 1/a: 7 + 1/b}
  def test_issue_21355():
      """radsimp leaves 1/(x +/- sqrt(x**2)) unchanged (issue 21355)."""
      root_of_square = sqrt(x**2)
      for denominator in (x + root_of_square, x - root_of_square):
          assert radsimp(1/denominator) == 1/denominator