test_refine.py 8.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227
  1. from sympy.assumptions.ask import Q
  2. from sympy.assumptions.refine import refine
  3. from sympy.core.expr import Expr
  4. from sympy.core.numbers import (I, Rational, nan, pi)
  5. from sympy.core.singleton import S
  6. from sympy.core.symbol import Symbol
  7. from sympy.functions.elementary.complexes import (Abs, arg, im, re, sign)
  8. from sympy.functions.elementary.exponential import exp
  9. from sympy.functions.elementary.miscellaneous import sqrt
  10. from sympy.functions.elementary.trigonometric import (atan, atan2)
  11. from sympy.abc import w, x, y, z
  12. from sympy.core.relational import Eq, Ne
  13. from sympy.functions.elementary.piecewise import Piecewise
  14. from sympy.matrices.expressions.matexpr import MatrixSymbol
  15. def test_Abs():
  16. assert refine(Abs(x), Q.positive(x)) == x
  17. assert refine(1 + Abs(x), Q.positive(x)) == 1 + x
  18. assert refine(Abs(x), Q.negative(x)) == -x
  19. assert refine(1 + Abs(x), Q.negative(x)) == 1 - x
  20. assert refine(Abs(x**2)) != x**2
  21. assert refine(Abs(x**2), Q.real(x)) == x**2
  22. def test_pow1():
  23. assert refine((-1)**x, Q.even(x)) == 1
  24. assert refine((-1)**x, Q.odd(x)) == -1
  25. assert refine((-2)**x, Q.even(x)) == 2**x
  26. # nested powers
  27. assert refine(sqrt(x**2)) != Abs(x)
  28. assert refine(sqrt(x**2), Q.complex(x)) != Abs(x)
  29. assert refine(sqrt(x**2), Q.real(x)) == Abs(x)
  30. assert refine(sqrt(x**2), Q.positive(x)) == x
  31. assert refine((x**3)**Rational(1, 3)) != x
  32. assert refine((x**3)**Rational(1, 3), Q.real(x)) != x
  33. assert refine((x**3)**Rational(1, 3), Q.positive(x)) == x
  34. assert refine(sqrt(1/x), Q.real(x)) != 1/sqrt(x)
  35. assert refine(sqrt(1/x), Q.positive(x)) == 1/sqrt(x)
  36. # powers of (-1)
  37. assert refine((-1)**(x + y), Q.even(x)) == (-1)**y
  38. assert refine((-1)**(x + y + z), Q.odd(x) & Q.odd(z)) == (-1)**y
  39. assert refine((-1)**(x + y + 1), Q.odd(x)) == (-1)**y
  40. assert refine((-1)**(x + y + 2), Q.odd(x)) == (-1)**(y + 1)
  41. assert refine((-1)**(x + 3)) == (-1)**(x + 1)
  42. # continuation
  43. assert refine((-1)**((-1)**x/2 - S.Half), Q.integer(x)) == (-1)**x
  44. assert refine((-1)**((-1)**x/2 + S.Half), Q.integer(x)) == (-1)**(x + 1)
  45. assert refine((-1)**((-1)**x/2 + 5*S.Half), Q.integer(x)) == (-1)**(x + 1)
  46. def test_pow2():
  47. assert refine((-1)**((-1)**x/2 - 7*S.Half), Q.integer(x)) == (-1)**(x + 1)
  48. assert refine((-1)**((-1)**x/2 - 9*S.Half), Q.integer(x)) == (-1)**x
  49. # powers of Abs
  50. assert refine(Abs(x)**2, Q.real(x)) == x**2
  51. assert refine(Abs(x)**3, Q.real(x)) == Abs(x)**3
  52. assert refine(Abs(x)**2) == Abs(x)**2
  53. def test_exp():
  54. x = Symbol('x', integer=True)
  55. assert refine(exp(pi*I*2*x)) == 1
  56. assert refine(exp(pi*I*2*(x + S.Half))) == -1
  57. assert refine(exp(pi*I*2*(x + Rational(1, 4)))) == I
  58. assert refine(exp(pi*I*2*(x + Rational(3, 4)))) == -I
  59. def test_Piecewise():
  60. assert refine(Piecewise((1, x < 0), (3, True)), (x < 0)) == 1
  61. assert refine(Piecewise((1, x < 0), (3, True)), ~(x < 0)) == 3
  62. assert refine(Piecewise((1, x < 0), (3, True)), (y < 0)) == \
  63. Piecewise((1, x < 0), (3, True))
  64. assert refine(Piecewise((1, x > 0), (3, True)), (x > 0)) == 1
  65. assert refine(Piecewise((1, x > 0), (3, True)), ~(x > 0)) == 3
  66. assert refine(Piecewise((1, x > 0), (3, True)), (y > 0)) == \
  67. Piecewise((1, x > 0), (3, True))
  68. assert refine(Piecewise((1, x <= 0), (3, True)), (x <= 0)) == 1
  69. assert refine(Piecewise((1, x <= 0), (3, True)), ~(x <= 0)) == 3
  70. assert refine(Piecewise((1, x <= 0), (3, True)), (y <= 0)) == \
  71. Piecewise((1, x <= 0), (3, True))
  72. assert refine(Piecewise((1, x >= 0), (3, True)), (x >= 0)) == 1
  73. assert refine(Piecewise((1, x >= 0), (3, True)), ~(x >= 0)) == 3
  74. assert refine(Piecewise((1, x >= 0), (3, True)), (y >= 0)) == \
  75. Piecewise((1, x >= 0), (3, True))
  76. assert refine(Piecewise((1, Eq(x, 0)), (3, True)), (Eq(x, 0)))\
  77. == 1
  78. assert refine(Piecewise((1, Eq(x, 0)), (3, True)), (Eq(0, x)))\
  79. == 1
  80. assert refine(Piecewise((1, Eq(x, 0)), (3, True)), ~(Eq(x, 0)))\
  81. == 3
  82. assert refine(Piecewise((1, Eq(x, 0)), (3, True)), ~(Eq(0, x)))\
  83. == 3
  84. assert refine(Piecewise((1, Eq(x, 0)), (3, True)), (Eq(y, 0)))\
  85. == Piecewise((1, Eq(x, 0)), (3, True))
  86. assert refine(Piecewise((1, Ne(x, 0)), (3, True)), (Ne(x, 0)))\
  87. == 1
  88. assert refine(Piecewise((1, Ne(x, 0)), (3, True)), ~(Ne(x, 0)))\
  89. == 3
  90. assert refine(Piecewise((1, Ne(x, 0)), (3, True)), (Ne(y, 0)))\
  91. == Piecewise((1, Ne(x, 0)), (3, True))
  92. def test_atan2():
  93. assert refine(atan2(y, x), Q.real(y) & Q.positive(x)) == atan(y/x)
  94. assert refine(atan2(y, x), Q.negative(y) & Q.positive(x)) == atan(y/x)
  95. assert refine(atan2(y, x), Q.negative(y) & Q.negative(x)) == atan(y/x) - pi
  96. assert refine(atan2(y, x), Q.positive(y) & Q.negative(x)) == atan(y/x) + pi
  97. assert refine(atan2(y, x), Q.zero(y) & Q.negative(x)) == pi
  98. assert refine(atan2(y, x), Q.positive(y) & Q.zero(x)) == pi/2
  99. assert refine(atan2(y, x), Q.negative(y) & Q.zero(x)) == -pi/2
  100. assert refine(atan2(y, x), Q.zero(y) & Q.zero(x)) is nan
  101. def test_re():
  102. assert refine(re(x), Q.real(x)) == x
  103. assert refine(re(x), Q.imaginary(x)) is S.Zero
  104. assert refine(re(x+y), Q.real(x) & Q.real(y)) == x + y
  105. assert refine(re(x+y), Q.real(x) & Q.imaginary(y)) == x
  106. assert refine(re(x*y), Q.real(x) & Q.real(y)) == x * y
  107. assert refine(re(x*y), Q.real(x) & Q.imaginary(y)) == 0
  108. assert refine(re(x*y*z), Q.real(x) & Q.real(y) & Q.real(z)) == x * y * z
  109. def test_im():
  110. assert refine(im(x), Q.imaginary(x)) == -I*x
  111. assert refine(im(x), Q.real(x)) is S.Zero
  112. assert refine(im(x+y), Q.imaginary(x) & Q.imaginary(y)) == -I*x - I*y
  113. assert refine(im(x+y), Q.real(x) & Q.imaginary(y)) == -I*y
  114. assert refine(im(x*y), Q.imaginary(x) & Q.real(y)) == -I*x*y
  115. assert refine(im(x*y), Q.imaginary(x) & Q.imaginary(y)) == 0
  116. assert refine(im(1/x), Q.imaginary(x)) == -I/x
  117. assert refine(im(x*y*z), Q.imaginary(x) & Q.imaginary(y)
  118. & Q.imaginary(z)) == -I*x*y*z
  119. def test_complex():
  120. assert refine(re(1/(x + I*y)), Q.real(x) & Q.real(y)) == \
  121. x/(x**2 + y**2)
  122. assert refine(im(1/(x + I*y)), Q.real(x) & Q.real(y)) == \
  123. -y/(x**2 + y**2)
  124. assert refine(re((w + I*x) * (y + I*z)), Q.real(w) & Q.real(x) & Q.real(y)
  125. & Q.real(z)) == w*y - x*z
  126. assert refine(im((w + I*x) * (y + I*z)), Q.real(w) & Q.real(x) & Q.real(y)
  127. & Q.real(z)) == w*z + x*y
  128. def test_sign():
  129. x = Symbol('x', real = True)
  130. assert refine(sign(x), Q.positive(x)) == 1
  131. assert refine(sign(x), Q.negative(x)) == -1
  132. assert refine(sign(x), Q.zero(x)) == 0
  133. assert refine(sign(x), True) == sign(x)
  134. assert refine(sign(Abs(x)), Q.nonzero(x)) == 1
  135. x = Symbol('x', imaginary=True)
  136. assert refine(sign(x), Q.positive(im(x))) == S.ImaginaryUnit
  137. assert refine(sign(x), Q.negative(im(x))) == -S.ImaginaryUnit
  138. assert refine(sign(x), True) == sign(x)
  139. x = Symbol('x', complex=True)
  140. assert refine(sign(x), Q.zero(x)) == 0
  141. def test_arg():
  142. x = Symbol('x', complex = True)
  143. assert refine(arg(x), Q.positive(x)) == 0
  144. assert refine(arg(x), Q.negative(x)) == pi
  145. def test_func_args():
  146. class MyClass(Expr):
  147. # A class with nontrivial .func
  148. def __init__(self, *args):
  149. self.my_member = ""
  150. @property
  151. def func(self):
  152. def my_func(*args):
  153. obj = MyClass(*args)
  154. obj.my_member = self.my_member
  155. return obj
  156. return my_func
  157. x = MyClass()
  158. x.my_member = "A very important value"
  159. assert x.my_member == refine(x).my_member
  160. def test_issue_refine_9384():
  161. assert refine(Piecewise((1, x < 0), (0, True)), Q.positive(x)) == 0
  162. assert refine(Piecewise((1, x < 0), (0, True)), Q.negative(x)) == 1
  163. assert refine(Piecewise((1, x > 0), (0, True)), Q.positive(x)) == 1
  164. assert refine(Piecewise((1, x > 0), (0, True)), Q.negative(x)) == 0
  165. def test_eval_refine():
  166. class MockExpr(Expr):
  167. def _eval_refine(self, assumptions):
  168. return True
  169. mock_obj = MockExpr()
  170. assert refine(mock_obj)
  171. def test_refine_issue_12724():
  172. expr1 = refine(Abs(x * y), Q.positive(x))
  173. expr2 = refine(Abs(x * y * z), Q.positive(x))
  174. assert expr1 == x * Abs(y)
  175. assert expr2 == x * Abs(y * z)
  176. y1 = Symbol('y1', real = True)
  177. expr3 = refine(Abs(x * y1**2 * z), Q.positive(x))
  178. assert expr3 == x * y1**2 * Abs(z)
  179. def test_matrixelement():
  180. x = MatrixSymbol('x', 3, 3)
  181. i = Symbol('i', positive = True)
  182. j = Symbol('j', positive = True)
  183. assert refine(x[0, 1], Q.symmetric(x)) == x[0, 1]
  184. assert refine(x[1, 0], Q.symmetric(x)) == x[0, 1]
  185. assert refine(x[i, j], Q.symmetric(x)) == x[j, i]
  186. assert refine(x[j, i], Q.symmetric(x)) == x[j, i]