# test_sqrtdenest.py
  1. from sympy.core.mul import Mul
  2. from sympy.core.numbers import (I, Integer, Rational)
  3. from sympy.core.symbol import Symbol
  4. from sympy.functions.elementary.miscellaneous import (root, sqrt)
  5. from sympy.functions.elementary.trigonometric import cos
  6. from sympy.integrals.integrals import Integral
  7. from sympy.simplify.sqrtdenest import sqrtdenest
  8. from sympy.simplify.sqrtdenest import (
  9. _subsets as subsets, _sqrt_numeric_denest)
  10. r2, r3, r5, r6, r7, r10, r15, r29 = [sqrt(x) for x in (2, 3, 5, 6, 7, 10,
  11. 15, 29)]
  12. def test_sqrtdenest():
  13. d = {sqrt(5 + 2 * r6): r2 + r3,
  14. sqrt(5. + 2 * r6): sqrt(5. + 2 * r6),
  15. sqrt(5. + 4*sqrt(5 + 2 * r6)): sqrt(5.0 + 4*r2 + 4*r3),
  16. sqrt(r2): sqrt(r2),
  17. sqrt(5 + r7): sqrt(5 + r7),
  18. sqrt(3 + sqrt(5 + 2*r7)):
  19. 3*r2*(5 + 2*r7)**Rational(1, 4)/(2*sqrt(6 + 3*r7)) +
  20. r2*sqrt(6 + 3*r7)/(2*(5 + 2*r7)**Rational(1, 4)),
  21. sqrt(3 + 2*r3): 3**Rational(3, 4)*(r6/2 + 3*r2/2)/3}
  22. for i in d:
  23. assert sqrtdenest(i) == d[i], i
  24. def test_sqrtdenest2():
  25. assert sqrtdenest(sqrt(16 - 2*r29 + 2*sqrt(55 - 10*r29))) == \
  26. r5 + sqrt(11 - 2*r29)
  27. e = sqrt(-r5 + sqrt(-2*r29 + 2*sqrt(-10*r29 + 55) + 16))
  28. assert sqrtdenest(e) == root(-2*r29 + 11, 4)
  29. r = sqrt(1 + r7)
  30. assert sqrtdenest(sqrt(1 + r)) == sqrt(1 + r)
  31. e = sqrt(((1 + sqrt(1 + 2*sqrt(3 + r2 + r5)))**2).expand())
  32. assert sqrtdenest(e) == 1 + sqrt(1 + 2*sqrt(r2 + r5 + 3))
  33. assert sqrtdenest(sqrt(5*r3 + 6*r2)) == \
  34. sqrt(2)*root(3, 4) + root(3, 4)**3
  35. assert sqrtdenest(sqrt(((1 + r5 + sqrt(1 + r3))**2).expand())) == \
  36. 1 + r5 + sqrt(1 + r3)
  37. assert sqrtdenest(sqrt(((1 + r5 + r7 + sqrt(1 + r3))**2).expand())) == \
  38. 1 + sqrt(1 + r3) + r5 + r7
  39. e = sqrt(((1 + cos(2) + cos(3) + sqrt(1 + r3))**2).expand())
  40. assert sqrtdenest(e) == cos(3) + cos(2) + 1 + sqrt(1 + r3)
  41. e = sqrt(-2*r10 + 2*r2*sqrt(-2*r10 + 11) + 14)
  42. assert sqrtdenest(e) == sqrt(-2*r10 - 2*r2 + 4*r5 + 14)
  43. # check that the result is not more complicated than the input
  44. z = sqrt(-2*r29 + cos(2) + 2*sqrt(-10*r29 + 55) + 16)
  45. assert sqrtdenest(z) == z
  46. assert sqrtdenest(sqrt(r6 + sqrt(15))) == sqrt(r6 + sqrt(15))
  47. z = sqrt(15 - 2*sqrt(31) + 2*sqrt(55 - 10*r29))
  48. assert sqrtdenest(z) == z
  49. def test_sqrtdenest_rec():
  50. assert sqrtdenest(sqrt(-4*sqrt(14) - 2*r6 + 4*sqrt(21) + 33)) == \
  51. -r2 + r3 + 2*r7
  52. assert sqrtdenest(sqrt(-28*r7 - 14*r5 + 4*sqrt(35) + 82)) == \
  53. -7 + r5 + 2*r7
  54. assert sqrtdenest(sqrt(6*r2/11 + 2*sqrt(22)/11 + 6*sqrt(11)/11 + 2)) == \
  55. sqrt(11)*(r2 + 3 + sqrt(11))/11
  56. assert sqrtdenest(sqrt(468*r3 + 3024*r2 + 2912*r6 + 19735)) == \
  57. 9*r3 + 26 + 56*r6
  58. z = sqrt(-490*r3 - 98*sqrt(115) - 98*sqrt(345) - 2107)
  59. assert sqrtdenest(z) == sqrt(-1)*(7*r5 + 7*r15 + 7*sqrt(23))
  60. z = sqrt(-4*sqrt(14) - 2*r6 + 4*sqrt(21) + 34)
  61. assert sqrtdenest(z) == z
  62. assert sqrtdenest(sqrt(-8*r2 - 2*r5 + 18)) == -r10 + 1 + r2 + r5
  63. assert sqrtdenest(sqrt(8*r2 + 2*r5 - 18)) == \
  64. sqrt(-1)*(-r10 + 1 + r2 + r5)
  65. assert sqrtdenest(sqrt(8*r2/3 + 14*r5/3 + Rational(154, 9))) == \
  66. -r10/3 + r2 + r5 + 3
  67. assert sqrtdenest(sqrt(sqrt(2*r6 + 5) + sqrt(2*r7 + 8))) == \
  68. sqrt(1 + r2 + r3 + r7)
  69. assert sqrtdenest(sqrt(4*r15 + 8*r5 + 12*r3 + 24)) == 1 + r3 + r5 + r15
  70. w = 1 + r2 + r3 + r5 + r7
  71. assert sqrtdenest(sqrt((w**2).expand())) == w
  72. z = sqrt((w**2).expand() + 1)
  73. assert sqrtdenest(z) == z
  74. z = sqrt(2*r10 + 6*r2 + 4*r5 + 12 + 10*r15 + 30*r3)
  75. assert sqrtdenest(z) == z
  76. def test_issue_6241():
  77. z = sqrt( -320 + 32*sqrt(5) + 64*r15)
  78. assert sqrtdenest(z) == z
  79. def test_sqrtdenest3():
  80. z = sqrt(13 - 2*r10 + 2*r2*sqrt(-2*r10 + 11))
  81. assert sqrtdenest(z) == -1 + r2 + r10
  82. assert sqrtdenest(z, max_iter=1) == -1 + sqrt(2) + sqrt(10)
  83. z = sqrt(sqrt(r2 + 2) + 2)
  84. assert sqrtdenest(z) == z
  85. assert sqrtdenest(sqrt(-2*r10 + 4*r2*sqrt(-2*r10 + 11) + 20)) == \
  86. sqrt(-2*r10 - 4*r2 + 8*r5 + 20)
  87. assert sqrtdenest(sqrt((112 + 70*r2) + (46 + 34*r2)*r5)) == \
  88. r10 + 5 + 4*r2 + 3*r5
  89. z = sqrt(5 + sqrt(2*r6 + 5)*sqrt(-2*r29 + 2*sqrt(-10*r29 + 55) + 16))
  90. r = sqrt(-2*r29 + 11)
  91. assert sqrtdenest(z) == sqrt(r2*r + r3*r + r10 + r15 + 5)
  92. n = sqrt(2*r6/7 + 2*r7/7 + 2*sqrt(42)/7 + 2)
  93. d = sqrt(16 - 2*r29 + 2*sqrt(55 - 10*r29))
  94. assert sqrtdenest(n/d) == r7*(1 + r6 + r7)/(Mul(7, (sqrt(-2*r29 + 11) + r5),
  95. evaluate=False))
  96. def test_sqrtdenest4():
  97. # see Denest_en.pdf in https://github.com/sympy/sympy/issues/3192
  98. z = sqrt(8 - r2*sqrt(5 - r5) - sqrt(3)*(1 + r5))
  99. z1 = sqrtdenest(z)
  100. c = sqrt(-r5 + 5)
  101. z1 = ((-r15*c - r3*c + c + r5*c - r6 - r2 + r10 + sqrt(30))/4).expand()
  102. assert sqrtdenest(z) == z1
  103. z = sqrt(2*r2*sqrt(r2 + 2) + 5*r2 + 4*sqrt(r2 + 2) + 8)
  104. assert sqrtdenest(z) == r2 + sqrt(r2 + 2) + 2
  105. w = 2 + r2 + r3 + (1 + r3)*sqrt(2 + r2 + 5*r3)
  106. z = sqrt((w**2).expand())
  107. assert sqrtdenest(z) == w.expand()
  108. def test_sqrt_symbolic_denest():
  109. x = Symbol('x')
  110. z = sqrt(((1 + sqrt(sqrt(2 + x) + 3))**2).expand())
  111. assert sqrtdenest(z) == sqrt((1 + sqrt(sqrt(2 + x) + 3))**2)
  112. z = sqrt(((1 + sqrt(sqrt(2 + cos(1)) + 3))**2).expand())
  113. assert sqrtdenest(z) == 1 + sqrt(sqrt(2 + cos(1)) + 3)
  114. z = ((1 + cos(2))**4 + 1).expand()
  115. assert sqrtdenest(z) == z
  116. z = sqrt(((1 + sqrt(sqrt(2 + cos(3*x)) + 3))**2 + 1).expand())
  117. assert sqrtdenest(z) == z
  118. c = cos(3)
  119. c2 = c**2
  120. assert sqrtdenest(sqrt(2*sqrt(1 + r3)*c + c2 + 1 + r3*c2)) == \
  121. -1 - sqrt(1 + r3)*c
  122. ra = sqrt(1 + r3)
  123. z = sqrt(20*ra*sqrt(3 + 3*r3) + 12*r3*ra*sqrt(3 + 3*r3) + 64*r3 + 112)
  124. assert sqrtdenest(z) == z
  125. def test_issue_5857():
  126. from sympy.abc import x, y
  127. z = sqrt(1/(4*r3 + 7) + 1)
  128. ans = (r2 + r6)/(r3 + 2)
  129. assert sqrtdenest(z) == ans
  130. assert sqrtdenest(1 + z) == 1 + ans
  131. assert sqrtdenest(Integral(z + 1, (x, 1, 2))) == \
  132. Integral(1 + ans, (x, 1, 2))
  133. assert sqrtdenest(x + sqrt(y)) == x + sqrt(y)
  134. ans = (r2 + r6)/(r3 + 2)
  135. assert sqrtdenest(z) == ans
  136. assert sqrtdenest(1 + z) == 1 + ans
  137. assert sqrtdenest(Integral(z + 1, (x, 1, 2))) == \
  138. Integral(1 + ans, (x, 1, 2))
  139. assert sqrtdenest(x + sqrt(y)) == x + sqrt(y)
  140. def test_subsets():
  141. assert subsets(1) == [[1]]
  142. assert subsets(4) == [
  143. [1, 0, 0, 0], [0, 1, 0, 0], [1, 1, 0, 0], [0, 0, 1, 0], [1, 0, 1, 0],
  144. [0, 1, 1, 0], [1, 1, 1, 0], [0, 0, 0, 1], [1, 0, 0, 1], [0, 1, 0, 1],
  145. [1, 1, 0, 1], [0, 0, 1, 1], [1, 0, 1, 1], [0, 1, 1, 1], [1, 1, 1, 1]]
  146. def test_issue_5653():
  147. assert sqrtdenest(
  148. sqrt(2 + sqrt(2 + sqrt(2)))) == sqrt(2 + sqrt(2 + sqrt(2)))
  149. def test_issue_12420():
  150. assert sqrtdenest((3 - sqrt(2)*sqrt(4 + 3*I) + 3*I)/2) == I
  151. e = 3 - sqrt(2)*sqrt(4 + I) + 3*I
  152. assert sqrtdenest(e) == e
  153. def test_sqrt_ratcomb():
  154. assert sqrtdenest(sqrt(1 + r3) + sqrt(3 + 3*r3) - sqrt(10 + 6*r3)) == 0
  155. def test_issue_18041():
  156. e = -sqrt(-2 + 2*sqrt(3)*I)
  157. assert sqrtdenest(e) == -1 - sqrt(3)*I
  158. def test_issue_19914():
  159. a = Integer(-8)
  160. b = Integer(-1)
  161. r = Integer(63)
  162. d2 = a*a - b*b*r
  163. assert _sqrt_numeric_denest(a, b, r, d2) == \
  164. sqrt(14)*I/2 + 3*sqrt(2)*I/2
  165. assert sqrtdenest(sqrt(-8-sqrt(63))) == sqrt(14)*I/2 + 3*sqrt(2)*I/2