refine.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405
  1. from __future__ import annotations
  2. from typing import Callable
  3. from sympy.core import S, Add, Expr, Basic, Mul, Pow, Rational
  4. from sympy.core.logic import fuzzy_not
  5. from sympy.logic.boolalg import Boolean
  6. from sympy.assumptions import ask, Q # type: ignore
  7. def refine(expr, assumptions=True):
  8. """
  9. Simplify an expression using assumptions.
  10. Explanation
  11. ===========
  12. Unlike :func:`~.simplify()` which performs structural simplification
  13. without any assumption, this function transforms the expression into
  14. the form which is only valid under certain assumptions. Note that
  15. ``simplify()`` is generally not done in refining process.
  16. Refining boolean expression involves reducing it to ``S.true`` or
  17. ``S.false``. Unlike :func:`~.ask()`, the expression will not be reduced
  18. if the truth value cannot be determined.
  19. Examples
  20. ========
  21. >>> from sympy import refine, sqrt, Q
  22. >>> from sympy.abc import x
  23. >>> refine(sqrt(x**2), Q.real(x))
  24. Abs(x)
  25. >>> refine(sqrt(x**2), Q.positive(x))
  26. x
  27. >>> refine(Q.real(x), Q.positive(x))
  28. True
  29. >>> refine(Q.positive(x), Q.real(x))
  30. Q.positive(x)
  31. See Also
  32. ========
  33. sympy.simplify.simplify.simplify : Structural simplification without assumptions.
  34. sympy.assumptions.ask.ask : Query for boolean expressions using assumptions.
  35. """
  36. if not isinstance(expr, Basic):
  37. return expr
  38. if not expr.is_Atom:
  39. args = [refine(arg, assumptions) for arg in expr.args]
  40. # TODO: this will probably not work with Integral or Polynomial
  41. expr = expr.func(*args)
  42. if hasattr(expr, '_eval_refine'):
  43. ref_expr = expr._eval_refine(assumptions)
  44. if ref_expr is not None:
  45. return ref_expr
  46. name = expr.__class__.__name__
  47. handler = handlers_dict.get(name, None)
  48. if handler is None:
  49. return expr
  50. new_expr = handler(expr, assumptions)
  51. if (new_expr is None) or (expr == new_expr):
  52. return expr
  53. if not isinstance(new_expr, Expr):
  54. return new_expr
  55. return refine(new_expr, assumptions)
  56. def refine_abs(expr, assumptions):
  57. """
  58. Handler for the absolute value.
  59. Examples
  60. ========
  61. >>> from sympy import Q, Abs
  62. >>> from sympy.assumptions.refine import refine_abs
  63. >>> from sympy.abc import x
  64. >>> refine_abs(Abs(x), Q.real(x))
  65. >>> refine_abs(Abs(x), Q.positive(x))
  66. x
  67. >>> refine_abs(Abs(x), Q.negative(x))
  68. -x
  69. """
  70. from sympy.functions.elementary.complexes import Abs
  71. arg = expr.args[0]
  72. if ask(Q.real(arg), assumptions) and \
  73. fuzzy_not(ask(Q.negative(arg), assumptions)):
  74. # if it's nonnegative
  75. return arg
  76. if ask(Q.negative(arg), assumptions):
  77. return -arg
  78. # arg is Mul
  79. if isinstance(arg, Mul):
  80. r = [refine(abs(a), assumptions) for a in arg.args]
  81. non_abs = []
  82. in_abs = []
  83. for i in r:
  84. if isinstance(i, Abs):
  85. in_abs.append(i.args[0])
  86. else:
  87. non_abs.append(i)
  88. return Mul(*non_abs) * Abs(Mul(*in_abs))
  89. def refine_Pow(expr, assumptions):
  90. """
  91. Handler for instances of Pow.
  92. Examples
  93. ========
  94. >>> from sympy import Q
  95. >>> from sympy.assumptions.refine import refine_Pow
  96. >>> from sympy.abc import x,y,z
  97. >>> refine_Pow((-1)**x, Q.real(x))
  98. >>> refine_Pow((-1)**x, Q.even(x))
  99. 1
  100. >>> refine_Pow((-1)**x, Q.odd(x))
  101. -1
  102. For powers of -1, even parts of the exponent can be simplified:
  103. >>> refine_Pow((-1)**(x+y), Q.even(x))
  104. (-1)**y
  105. >>> refine_Pow((-1)**(x+y+z), Q.odd(x) & Q.odd(z))
  106. (-1)**y
  107. >>> refine_Pow((-1)**(x+y+2), Q.odd(x))
  108. (-1)**(y + 1)
  109. >>> refine_Pow((-1)**(x+3), True)
  110. (-1)**(x + 1)
  111. """
  112. from sympy.functions.elementary.complexes import Abs
  113. from sympy.functions import sign
  114. if isinstance(expr.base, Abs):
  115. if ask(Q.real(expr.base.args[0]), assumptions) and \
  116. ask(Q.even(expr.exp), assumptions):
  117. return expr.base.args[0] ** expr.exp
  118. if ask(Q.real(expr.base), assumptions):
  119. if expr.base.is_number:
  120. if ask(Q.even(expr.exp), assumptions):
  121. return abs(expr.base) ** expr.exp
  122. if ask(Q.odd(expr.exp), assumptions):
  123. return sign(expr.base) * abs(expr.base) ** expr.exp
  124. if isinstance(expr.exp, Rational):
  125. if isinstance(expr.base, Pow):
  126. return abs(expr.base.base) ** (expr.base.exp * expr.exp)
  127. if expr.base is S.NegativeOne:
  128. if expr.exp.is_Add:
  129. old = expr
  130. # For powers of (-1) we can remove
  131. # - even terms
  132. # - pairs of odd terms
  133. # - a single odd term + 1
  134. # - A numerical constant N can be replaced with mod(N,2)
  135. coeff, terms = expr.exp.as_coeff_add()
  136. terms = set(terms)
  137. even_terms = set()
  138. odd_terms = set()
  139. initial_number_of_terms = len(terms)
  140. for t in terms:
  141. if ask(Q.even(t), assumptions):
  142. even_terms.add(t)
  143. elif ask(Q.odd(t), assumptions):
  144. odd_terms.add(t)
  145. terms -= even_terms
  146. if len(odd_terms) % 2:
  147. terms -= odd_terms
  148. new_coeff = (coeff + S.One) % 2
  149. else:
  150. terms -= odd_terms
  151. new_coeff = coeff % 2
  152. if new_coeff != coeff or len(terms) < initial_number_of_terms:
  153. terms.add(new_coeff)
  154. expr = expr.base**(Add(*terms))
  155. # Handle (-1)**((-1)**n/2 + m/2)
  156. e2 = 2*expr.exp
  157. if ask(Q.even(e2), assumptions):
  158. if e2.could_extract_minus_sign():
  159. e2 *= expr.base
  160. if e2.is_Add:
  161. i, p = e2.as_two_terms()
  162. if p.is_Pow and p.base is S.NegativeOne:
  163. if ask(Q.integer(p.exp), assumptions):
  164. i = (i + 1)/2
  165. if ask(Q.even(i), assumptions):
  166. return expr.base**p.exp
  167. elif ask(Q.odd(i), assumptions):
  168. return expr.base**(p.exp + 1)
  169. else:
  170. return expr.base**(p.exp + i)
  171. if old != expr:
  172. return expr
  173. def refine_atan2(expr, assumptions):
  174. """
  175. Handler for the atan2 function.
  176. Examples
  177. ========
  178. >>> from sympy import Q, atan2
  179. >>> from sympy.assumptions.refine import refine_atan2
  180. >>> from sympy.abc import x, y
  181. >>> refine_atan2(atan2(y,x), Q.real(y) & Q.positive(x))
  182. atan(y/x)
  183. >>> refine_atan2(atan2(y,x), Q.negative(y) & Q.negative(x))
  184. atan(y/x) - pi
  185. >>> refine_atan2(atan2(y,x), Q.positive(y) & Q.negative(x))
  186. atan(y/x) + pi
  187. >>> refine_atan2(atan2(y,x), Q.zero(y) & Q.negative(x))
  188. pi
  189. >>> refine_atan2(atan2(y,x), Q.positive(y) & Q.zero(x))
  190. pi/2
  191. >>> refine_atan2(atan2(y,x), Q.negative(y) & Q.zero(x))
  192. -pi/2
  193. >>> refine_atan2(atan2(y,x), Q.zero(y) & Q.zero(x))
  194. nan
  195. """
  196. from sympy.functions.elementary.trigonometric import atan
  197. y, x = expr.args
  198. if ask(Q.real(y) & Q.positive(x), assumptions):
  199. return atan(y / x)
  200. elif ask(Q.negative(y) & Q.negative(x), assumptions):
  201. return atan(y / x) - S.Pi
  202. elif ask(Q.positive(y) & Q.negative(x), assumptions):
  203. return atan(y / x) + S.Pi
  204. elif ask(Q.zero(y) & Q.negative(x), assumptions):
  205. return S.Pi
  206. elif ask(Q.positive(y) & Q.zero(x), assumptions):
  207. return S.Pi/2
  208. elif ask(Q.negative(y) & Q.zero(x), assumptions):
  209. return -S.Pi/2
  210. elif ask(Q.zero(y) & Q.zero(x), assumptions):
  211. return S.NaN
  212. else:
  213. return expr
  214. def refine_re(expr, assumptions):
  215. """
  216. Handler for real part.
  217. Examples
  218. ========
  219. >>> from sympy.assumptions.refine import refine_re
  220. >>> from sympy import Q, re
  221. >>> from sympy.abc import x
  222. >>> refine_re(re(x), Q.real(x))
  223. x
  224. >>> refine_re(re(x), Q.imaginary(x))
  225. 0
  226. """
  227. arg = expr.args[0]
  228. if ask(Q.real(arg), assumptions):
  229. return arg
  230. if ask(Q.imaginary(arg), assumptions):
  231. return S.Zero
  232. return _refine_reim(expr, assumptions)
  233. def refine_im(expr, assumptions):
  234. """
  235. Handler for imaginary part.
  236. Explanation
  237. ===========
  238. >>> from sympy.assumptions.refine import refine_im
  239. >>> from sympy import Q, im
  240. >>> from sympy.abc import x
  241. >>> refine_im(im(x), Q.real(x))
  242. 0
  243. >>> refine_im(im(x), Q.imaginary(x))
  244. -I*x
  245. """
  246. arg = expr.args[0]
  247. if ask(Q.real(arg), assumptions):
  248. return S.Zero
  249. if ask(Q.imaginary(arg), assumptions):
  250. return - S.ImaginaryUnit * arg
  251. return _refine_reim(expr, assumptions)
  252. def refine_arg(expr, assumptions):
  253. """
  254. Handler for complex argument
  255. Explanation
  256. ===========
  257. >>> from sympy.assumptions.refine import refine_arg
  258. >>> from sympy import Q, arg
  259. >>> from sympy.abc import x
  260. >>> refine_arg(arg(x), Q.positive(x))
  261. 0
  262. >>> refine_arg(arg(x), Q.negative(x))
  263. pi
  264. """
  265. rg = expr.args[0]
  266. if ask(Q.positive(rg), assumptions):
  267. return S.Zero
  268. if ask(Q.negative(rg), assumptions):
  269. return S.Pi
  270. return None
  271. def _refine_reim(expr, assumptions):
  272. # Helper function for refine_re & refine_im
  273. expanded = expr.expand(complex = True)
  274. if expanded != expr:
  275. refined = refine(expanded, assumptions)
  276. if refined != expanded:
  277. return refined
  278. # Best to leave the expression as is
  279. return None
  280. def refine_sign(expr, assumptions):
  281. """
  282. Handler for sign.
  283. Examples
  284. ========
  285. >>> from sympy.assumptions.refine import refine_sign
  286. >>> from sympy import Symbol, Q, sign, im
  287. >>> x = Symbol('x', real = True)
  288. >>> expr = sign(x)
  289. >>> refine_sign(expr, Q.positive(x) & Q.nonzero(x))
  290. 1
  291. >>> refine_sign(expr, Q.negative(x) & Q.nonzero(x))
  292. -1
  293. >>> refine_sign(expr, Q.zero(x))
  294. 0
  295. >>> y = Symbol('y', imaginary = True)
  296. >>> expr = sign(y)
  297. >>> refine_sign(expr, Q.positive(im(y)))
  298. I
  299. >>> refine_sign(expr, Q.negative(im(y)))
  300. -I
  301. """
  302. arg = expr.args[0]
  303. if ask(Q.zero(arg), assumptions):
  304. return S.Zero
  305. if ask(Q.real(arg)):
  306. if ask(Q.positive(arg), assumptions):
  307. return S.One
  308. if ask(Q.negative(arg), assumptions):
  309. return S.NegativeOne
  310. if ask(Q.imaginary(arg)):
  311. arg_re, arg_im = arg.as_real_imag()
  312. if ask(Q.positive(arg_im), assumptions):
  313. return S.ImaginaryUnit
  314. if ask(Q.negative(arg_im), assumptions):
  315. return -S.ImaginaryUnit
  316. return expr
  317. def refine_matrixelement(expr, assumptions):
  318. """
  319. Handler for symmetric part.
  320. Examples
  321. ========
  322. >>> from sympy.assumptions.refine import refine_matrixelement
  323. >>> from sympy import MatrixSymbol, Q
  324. >>> X = MatrixSymbol('X', 3, 3)
  325. >>> refine_matrixelement(X[0, 1], Q.symmetric(X))
  326. X[0, 1]
  327. >>> refine_matrixelement(X[1, 0], Q.symmetric(X))
  328. X[0, 1]
  329. """
  330. from sympy.matrices.expressions.matexpr import MatrixElement
  331. matrix, i, j = expr.args
  332. if ask(Q.symmetric(matrix), assumptions):
  333. if (i - j).could_extract_minus_sign():
  334. return expr
  335. return MatrixElement(matrix, j, i)
  336. handlers_dict: dict[str, Callable[[Expr, Boolean], Expr]] = {
  337. 'Abs': refine_abs,
  338. 'Pow': refine_Pow,
  339. 'atan2': refine_atan2,
  340. 're': refine_re,
  341. 'im': refine_im,
  342. 'arg': refine_arg,
  343. 'sign': refine_sign,
  344. 'MatrixElement': refine_matrixelement
  345. }