basic.py 71 KB


  1. """Base class for all the objects in SymPy"""
  2. from __future__ import annotations
  3. from collections import defaultdict
  4. from collections.abc import Mapping
  5. from itertools import chain, zip_longest
  6. from .assumptions import _prepare_class_assumptions
  7. from .cache import cacheit
  8. from .core import ordering_of_classes
  9. from .sympify import _sympify, sympify, SympifyError, _external_converter
  10. from .sorting import ordered
  11. from .kind import Kind, UndefinedKind
  12. from ._print_helpers import Printable
  13. from sympy.utilities.decorator import deprecated
  14. from sympy.utilities.exceptions import sympy_deprecation_warning
  15. from sympy.utilities.iterables import iterable, numbered_symbols
  16. from sympy.utilities.misc import filldedent, func_name
  17. from inspect import getmro
  18. def as_Basic(expr):
  19. """Return expr as a Basic instance using strict sympify
  20. or raise a TypeError; this is just a wrapper to _sympify,
  21. raising a TypeError instead of a SympifyError."""
  22. try:
  23. return _sympify(expr)
  24. except SympifyError:
  25. raise TypeError(
  26. 'Argument must be a Basic object, not `%s`' % func_name(
  27. expr))
  28. def _old_compare(x: type, y: type) -> int:
  29. # If the other object is not a Basic subclass, then we are not equal to it.
  30. if not issubclass(y, Basic):
  31. return -1
  32. n1 = x.__name__
  33. n2 = y.__name__
  34. if n1 == n2:
  35. return 0
  36. UNKNOWN = len(ordering_of_classes) + 1
  37. try:
  38. i1 = ordering_of_classes.index(n1)
  39. except ValueError:
  40. i1 = UNKNOWN
  41. try:
  42. i2 = ordering_of_classes.index(n2)
  43. except ValueError:
  44. i2 = UNKNOWN
  45. if i1 == UNKNOWN and i2 == UNKNOWN:
  46. return (n1 > n2) - (n1 < n2)
  47. return (i1 > i2) - (i1 < i2)
class Basic(Printable):
    """
    Base class for all SymPy objects.

    Notes and conventions
    =====================

    1) Always use ``.args``, when accessing parameters of some instance:

    >>> from sympy import cot
    >>> from sympy.abc import x, y

    >>> cot(x).args
    (x,)

    >>> cot(x).args[0]
    x

    >>> (x*y).args
    (x, y)

    >>> (x*y).args[1]
    y


    2) Never use internal methods or variables (the ones prefixed with ``_``):

    >>> cot(x)._args    # do not use this, use cot(x).args instead
    (x,)


    3)  By "SymPy object" we mean something that can be returned by
        ``sympify``.  But not all objects one encounters using SymPy are
        subclasses of Basic.  For example, mutable objects are not:

        >>> from sympy import Basic, Matrix, sympify
        >>> A = Matrix([[1, 2], [3, 4]]).as_mutable()
        >>> isinstance(A, Basic)
        False

        >>> B = sympify(A)
        >>> isinstance(B, Basic)
        True
    """
    __slots__ = ('_mhash',              # hash value
                 '_args',               # arguments
                 '_assumptions'         # set from cls.default_assumptions in __new__
                )

    # Raw constructor arguments, stored unchanged by __new__.
    _args: tuple[Basic, ...]
    # Memoised hash; None until __hash__ computes and stores it.
    _mhash: int | None
  84. @property
  85. def __sympy__(self):
  86. return True
  87. def __init_subclass__(cls):
  88. # Initialize the default_assumptions FactKB and also any assumptions
  89. # property methods. This method will only be called for subclasses of
  90. # Basic but not for Basic itself so we call
  91. # _prepare_class_assumptions(Basic) below the class definition.
  92. _prepare_class_assumptions(cls)
    # Structural type flags. Each is False on Basic and is overridden with
    # True in the appropriate subclasses, giving cheap attribute-based type
    # tests (e.g. ``expr.is_Add``) without isinstance.
    is_number = False
    is_Atom = False
    is_Symbol = False
    is_symbol = False
    is_Indexed = False
    is_Dummy = False
    is_Wild = False
    is_Function = False
    is_Add = False
    is_Mul = False
    is_Pow = False
    is_Number = False
    is_Float = False
    is_Rational = False
    is_Integer = False
    is_NumberSymbol = False
    is_Order = False
    is_Derivative = False
    is_Piecewise = False
    is_Poly = False
    is_AlgebraicNumber = False
    is_Relational = False
    is_Equality = False
    is_Boolean = False
    is_Not = False
    is_Matrix = False
    is_Vector = False
    is_Point = False
    is_MatAdd = False
    is_MatMul = False

    # Assumption queries (True/False/None). These are bare annotations for
    # type checkers and create no class attribute here. NOTE(review): the
    # runtime values presumably come from _prepare_class_assumptions; that
    # is not visible in this chunk.
    is_real: bool | None
    is_extended_real: bool | None
    is_zero: bool | None
    is_negative: bool | None
    is_commutative: bool | None

    # Kind of the object; UndefinedKind unless a subclass overrides it.
    kind: Kind = UndefinedKind
  130. def __new__(cls, *args):
  131. obj = object.__new__(cls)
  132. obj._assumptions = cls.default_assumptions
  133. obj._mhash = None # will be set by __hash__ method.
  134. obj._args = args # all items in args must be Basic objects
  135. return obj
  136. def copy(self):
  137. return self.func(*self.args)
  138. def __getnewargs__(self):
  139. return self.args
  140. def __getstate__(self):
  141. return None
  142. def __setstate__(self, state):
  143. for name, value in state.items():
  144. setattr(self, name, value)
  145. def __reduce_ex__(self, protocol):
  146. if protocol < 2:
  147. msg = "Only pickle protocol 2 or higher is supported by SymPy"
  148. raise NotImplementedError(msg)
  149. return super().__reduce_ex__(protocol)
  150. def __hash__(self) -> int:
  151. # hash cannot be cached using cache_it because infinite recurrence
  152. # occurs as hash is needed for setting cache dictionary keys
  153. h = self._mhash
  154. if h is None:
  155. h = hash((type(self).__name__,) + self._hashable_content())
  156. self._mhash = h
  157. return h
  158. def _hashable_content(self):
  159. """Return a tuple of information about self that can be used to
  160. compute the hash. If a class defines additional attributes,
  161. like ``name`` in Symbol, then this method should be updated
  162. accordingly to return such relevant attributes.
  163. Defining more than _hashable_content is necessary if __eq__ has
  164. been defined by a class. See note about this in Basic.__eq__."""
  165. return self._args
  166. @property
  167. def assumptions0(self):
  168. """
  169. Return object `type` assumptions.
  170. For example:
  171. Symbol('x', real=True)
  172. Symbol('x', integer=True)
  173. are different objects. In other words, besides Python type (Symbol in
  174. this case), the initial assumptions are also forming their typeinfo.
  175. Examples
  176. ========
  177. >>> from sympy import Symbol
  178. >>> from sympy.abc import x
  179. >>> x.assumptions0
  180. {'commutative': True}
  181. >>> x = Symbol("x", positive=True)
  182. >>> x.assumptions0
  183. {'commutative': True, 'complex': True, 'extended_negative': False,
  184. 'extended_nonnegative': True, 'extended_nonpositive': False,
  185. 'extended_nonzero': True, 'extended_positive': True, 'extended_real':
  186. True, 'finite': True, 'hermitian': True, 'imaginary': False,
  187. 'infinite': False, 'negative': False, 'nonnegative': True,
  188. 'nonpositive': False, 'nonzero': True, 'positive': True, 'real':
  189. True, 'zero': False}
  190. """
  191. return {}
  192. def compare(self, other):
  193. """
  194. Return -1, 0, 1 if the object is smaller, equal, or greater than other.
  195. Not in the mathematical sense. If the object is of a different type
  196. from the "other" then their classes are ordered according to
  197. the sorted_classes list.
  198. Examples
  199. ========
  200. >>> from sympy.abc import x, y
  201. >>> x.compare(y)
  202. -1
  203. >>> x.compare(x)
  204. 0
  205. >>> y.compare(x)
  206. 1
  207. """
  208. # all redefinitions of __cmp__ method should start with the
  209. # following lines:
  210. if self is other:
  211. return 0
  212. n1 = self.__class__
  213. n2 = other.__class__
  214. c = _old_compare(n1, n2)
  215. if c:
  216. return c
  217. #
  218. st = self._hashable_content()
  219. ot = other._hashable_content()
  220. c = (len(st) > len(ot)) - (len(st) < len(ot))
  221. if c:
  222. return c
  223. for l, r in zip(st, ot):
  224. l = Basic(*l) if isinstance(l, frozenset) else l
  225. r = Basic(*r) if isinstance(r, frozenset) else r
  226. if isinstance(l, Basic):
  227. c = l.compare(r)
  228. else:
  229. c = (l > r) - (l < r)
  230. if c:
  231. return c
  232. return 0
  233. @staticmethod
  234. def _compare_pretty(a, b):
  235. from sympy.series.order import Order
  236. if isinstance(a, Order) and not isinstance(b, Order):
  237. return 1
  238. if not isinstance(a, Order) and isinstance(b, Order):
  239. return -1
  240. if a.is_Rational and b.is_Rational:
  241. l = a.p * b.q
  242. r = b.p * a.q
  243. return (l > r) - (l < r)
  244. else:
  245. from .symbol import Wild
  246. p1, p2, p3 = Wild("p1"), Wild("p2"), Wild("p3")
  247. r_a = a.match(p1 * p2**p3)
  248. if r_a and p3 in r_a:
  249. a3 = r_a[p3]
  250. r_b = b.match(p1 * p2**p3)
  251. if r_b and p3 in r_b:
  252. b3 = r_b[p3]
  253. c = Basic.compare(a3, b3)
  254. if c != 0:
  255. return c
  256. return Basic.compare(a, b)
  257. @classmethod
  258. def fromiter(cls, args, **assumptions):
  259. """
  260. Create a new object from an iterable.
  261. This is a convenience function that allows one to create objects from
  262. any iterable, without having to convert to a list or tuple first.
  263. Examples
  264. ========
  265. >>> from sympy import Tuple
  266. >>> Tuple.fromiter(i for i in range(5))
  267. (0, 1, 2, 3, 4)
  268. """
  269. return cls(*tuple(args), **assumptions)
  270. @classmethod
  271. def class_key(cls):
  272. """Nice order of classes."""
  273. return 5, 0, cls.__name__
  274. @cacheit
  275. def sort_key(self, order=None):
  276. """
  277. Return a sort key.
  278. Examples
  279. ========
  280. >>> from sympy import S, I
  281. >>> sorted([S(1)/2, I, -I], key=lambda x: x.sort_key())
  282. [1/2, -I, I]
  283. >>> S("[x, 1/x, 1/x**2, x**2, x**(1/2), x**(1/4), x**(3/2)]")
  284. [x, 1/x, x**(-2), x**2, sqrt(x), x**(1/4), x**(3/2)]
  285. >>> sorted(_, key=lambda x: x.sort_key())
  286. [x**(-2), 1/x, x**(1/4), sqrt(x), x, x**(3/2), x**2]
  287. """
  288. # XXX: remove this when issue 5169 is fixed
  289. def inner_key(arg):
  290. if isinstance(arg, Basic):
  291. return arg.sort_key(order)
  292. else:
  293. return arg
  294. args = self._sorted_args
  295. args = len(args), tuple([inner_key(arg) for arg in args])
  296. return self.class_key(), args, S.One.sort_key(), S.One
  297. def _do_eq_sympify(self, other):
  298. """Returns a boolean indicating whether a == b when either a
  299. or b is not a Basic. This is only done for types that were either
  300. added to `converter` by a 3rd party or when the object has `_sympy_`
  301. defined. This essentially reuses the code in `_sympify` that is
  302. specific for this use case. Non-user defined types that are meant
  303. to work with SymPy should be handled directly in the __eq__ methods
  304. of the `Basic` classes it could equate to and not be converted. Note
  305. that after conversion, `==` is used again since it is not
  306. necessarily clear whether `self` or `other`'s __eq__ method needs
  307. to be used."""
  308. for superclass in type(other).__mro__:
  309. conv = _external_converter.get(superclass)
  310. if conv is not None:
  311. return self == conv(other)
  312. if hasattr(other, '_sympy_'):
  313. return self == other._sympy_()
  314. return NotImplemented
  315. def __eq__(self, other):
  316. """Return a boolean indicating whether a == b on the basis of
  317. their symbolic trees.
  318. This is the same as a.compare(b) == 0 but faster.
  319. Notes
  320. =====
  321. If a class that overrides __eq__() needs to retain the
  322. implementation of __hash__() from a parent class, the
  323. interpreter must be told this explicitly by setting
  324. __hash__ : Callable[[object], int] = <ParentClass>.__hash__.
  325. Otherwise the inheritance of __hash__() will be blocked,
  326. just as if __hash__ had been explicitly set to None.
  327. References
  328. ==========
  329. from https://docs.python.org/dev/reference/datamodel.html#object.__hash__
  330. """
  331. if self is other:
  332. return True
  333. if not isinstance(other, Basic):
  334. return self._do_eq_sympify(other)
  335. # check for pure number expr
  336. if not (self.is_Number and other.is_Number) and (
  337. type(self) != type(other)):
  338. return False
  339. a, b = self._hashable_content(), other._hashable_content()
  340. if a != b:
  341. return False
  342. # check number *in* an expression
  343. for a, b in zip(a, b):
  344. if not isinstance(a, Basic):
  345. continue
  346. if a.is_Number and type(a) != type(b):
  347. return False
  348. return True
  349. def __ne__(self, other):
  350. """``a != b`` -> Compare two symbolic trees and see whether they are different
  351. this is the same as:
  352. ``a.compare(b) != 0``
  353. but faster
  354. """
  355. return not self == other
  356. def dummy_eq(self, other, symbol=None):
  357. """
  358. Compare two expressions and handle dummy symbols.
  359. Examples
  360. ========
  361. >>> from sympy import Dummy
  362. >>> from sympy.abc import x, y
  363. >>> u = Dummy('u')
  364. >>> (u**2 + 1).dummy_eq(x**2 + 1)
  365. True
  366. >>> (u**2 + 1) == (x**2 + 1)
  367. False
  368. >>> (u**2 + y).dummy_eq(x**2 + y, x)
  369. True
  370. >>> (u**2 + y).dummy_eq(x**2 + y, y)
  371. False
  372. """
  373. s = self.as_dummy()
  374. o = _sympify(other)
  375. o = o.as_dummy()
  376. dummy_symbols = [i for i in s.free_symbols if i.is_Dummy]
  377. if len(dummy_symbols) == 1:
  378. dummy = dummy_symbols.pop()
  379. else:
  380. return s == o
  381. if symbol is None:
  382. symbols = o.free_symbols
  383. if len(symbols) == 1:
  384. symbol = symbols.pop()
  385. else:
  386. return s == o
  387. tmp = dummy.__class__()
  388. return s.xreplace({dummy: tmp}) == o.xreplace({symbol: tmp})
  389. def atoms(self, *types):
  390. """Returns the atoms that form the current object.
  391. By default, only objects that are truly atomic and cannot
  392. be divided into smaller pieces are returned: symbols, numbers,
  393. and number symbols like I and pi. It is possible to request
  394. atoms of any type, however, as demonstrated below.
  395. Examples
  396. ========
  397. >>> from sympy import I, pi, sin
  398. >>> from sympy.abc import x, y
  399. >>> (1 + x + 2*sin(y + I*pi)).atoms()
  400. {1, 2, I, pi, x, y}
  401. If one or more types are given, the results will contain only
  402. those types of atoms.
  403. >>> from sympy import Number, NumberSymbol, Symbol
  404. >>> (1 + x + 2*sin(y + I*pi)).atoms(Symbol)
  405. {x, y}
  406. >>> (1 + x + 2*sin(y + I*pi)).atoms(Number)
  407. {1, 2}
  408. >>> (1 + x + 2*sin(y + I*pi)).atoms(Number, NumberSymbol)
  409. {1, 2, pi}
  410. >>> (1 + x + 2*sin(y + I*pi)).atoms(Number, NumberSymbol, I)
  411. {1, 2, I, pi}
  412. Note that I (imaginary unit) and zoo (complex infinity) are special
  413. types of number symbols and are not part of the NumberSymbol class.
  414. The type can be given implicitly, too:
  415. >>> (1 + x + 2*sin(y + I*pi)).atoms(x) # x is a Symbol
  416. {x, y}
  417. Be careful to check your assumptions when using the implicit option
  418. since ``S(1).is_Integer = True`` but ``type(S(1))`` is ``One``, a special type
  419. of SymPy atom, while ``type(S(2))`` is type ``Integer`` and will find all
  420. integers in an expression:
  421. >>> from sympy import S
  422. >>> (1 + x + 2*sin(y + I*pi)).atoms(S(1))
  423. {1}
  424. >>> (1 + x + 2*sin(y + I*pi)).atoms(S(2))
  425. {1, 2}
  426. Finally, arguments to atoms() can select more than atomic atoms: any
  427. SymPy type (loaded in core/__init__.py) can be listed as an argument
  428. and those types of "atoms" as found in scanning the arguments of the
  429. expression recursively:
  430. >>> from sympy import Function, Mul
  431. >>> from sympy.core.function import AppliedUndef
  432. >>> f = Function('f')
  433. >>> (1 + f(x) + 2*sin(y + I*pi)).atoms(Function)
  434. {f(x), sin(y + I*pi)}
  435. >>> (1 + f(x) + 2*sin(y + I*pi)).atoms(AppliedUndef)
  436. {f(x)}
  437. >>> (1 + x + 2*sin(y + I*pi)).atoms(Mul)
  438. {I*pi, 2*sin(y + I*pi)}
  439. """
  440. if types:
  441. types = tuple(
  442. [t if isinstance(t, type) else type(t) for t in types])
  443. nodes = _preorder_traversal(self)
  444. if types:
  445. result = {node for node in nodes if isinstance(node, types)}
  446. else:
  447. result = {node for node in nodes if not node.args}
  448. return result
  449. @property
  450. def free_symbols(self) -> set[Basic]:
  451. """Return from the atoms of self those which are free symbols.
  452. Not all free symbols are ``Symbol``. Eg: IndexedBase('I')[0].free_symbols
  453. For most expressions, all symbols are free symbols. For some classes
  454. this is not true. e.g. Integrals use Symbols for the dummy variables
  455. which are bound variables, so Integral has a method to return all
  456. symbols except those. Derivative keeps track of symbols with respect
  457. to which it will perform a derivative; those are
  458. bound variables, too, so it has its own free_symbols method.
  459. Any other method that uses bound variables should implement a
  460. free_symbols method."""
  461. empty: set[Basic] = set()
  462. return empty.union(*(a.free_symbols for a in self.args))
  463. @property
  464. def expr_free_symbols(self):
  465. sympy_deprecation_warning("""
  466. The expr_free_symbols property is deprecated. Use free_symbols to get
  467. the free symbols of an expression.
  468. """,
  469. deprecated_since_version="1.9",
  470. active_deprecations_target="deprecated-expr-free-symbols")
  471. return set()
    def as_dummy(self):
        """Return the expression with any objects having structurally
        bound symbols replaced with unique, canonical symbols within
        the object in which they appear and having only the default
        assumption for commutativity being True. When applied to a
        symbol a new symbol having only the same commutativity will be
        returned.

        Examples
        ========

        >>> from sympy import Integral, Symbol
        >>> from sympy.abc import x
        >>> r = Symbol('r', real=True)
        >>> Integral(r, (r, x)).as_dummy()
        Integral(_0, (_0, x))
        >>> _.variables[0].is_real is None
        True
        >>> r.as_dummy()
        _r

        Notes
        =====

        Any object that has structurally bound variables should have
        a property, `bound_symbols` that returns those symbols
        appearing in the object.
        """
        from .symbol import Dummy, Symbol

        def can(x):
            # Canonicalise a single object that defines bound_symbols.
            # mask free that shadow bound
            free = x.free_symbols
            bound = set(x.bound_symbols)
            d = {i: Dummy() for i in bound & free}
            x = x.subs(d)
            # replace bound with canonical names
            x = x.xreplace(x.canonical_variables)
            # return after undoing masking
            return x.xreplace({v: k for k, v in d.items()})

        # Without any Symbol there is nothing to canonicalise.
        if not self.has(Symbol):
            return self
        # Apply `can` to every subexpression that has bound_symbols.
        return self.replace(
            lambda x: hasattr(x, 'bound_symbols'),
            can,
            simultaneous=False)
  513. @property
  514. def canonical_variables(self):
  515. """Return a dictionary mapping any variable defined in
  516. ``self.bound_symbols`` to Symbols that do not clash
  517. with any free symbols in the expression.
  518. Examples
  519. ========
  520. >>> from sympy import Lambda
  521. >>> from sympy.abc import x
  522. >>> Lambda(x, 2*x).canonical_variables
  523. {x: _0}
  524. """
  525. if not hasattr(self, 'bound_symbols'):
  526. return {}
  527. dums = numbered_symbols('_')
  528. reps = {}
  529. # watch out for free symbol that are not in bound symbols;
  530. # those that are in bound symbols are about to get changed
  531. bound = self.bound_symbols
  532. names = {i.name for i in self.free_symbols - set(bound)}
  533. for b in bound:
  534. d = next(dums)
  535. if b.is_Symbol:
  536. while d.name in names:
  537. d = next(dums)
  538. reps[b] = d
  539. return reps
  540. def rcall(self, *args):
  541. """Apply on the argument recursively through the expression tree.
  542. This method is used to simulate a common abuse of notation for
  543. operators. For instance, in SymPy the following will not work:
  544. ``(x+Lambda(y, 2*y))(z) == x+2*z``,
  545. however, you can use:
  546. >>> from sympy import Lambda
  547. >>> from sympy.abc import x, y, z
  548. >>> (x + Lambda(y, 2*y)).rcall(z)
  549. x + 2*z
  550. """
  551. return Basic._recursive_call(self, args)
  552. @staticmethod
  553. def _recursive_call(expr_to_call, on_args):
  554. """Helper for rcall method."""
  555. from .symbol import Symbol
  556. def the_call_method_is_overridden(expr):
  557. for cls in getmro(type(expr)):
  558. if '__call__' in cls.__dict__:
  559. return cls != Basic
  560. if callable(expr_to_call) and the_call_method_is_overridden(expr_to_call):
  561. if isinstance(expr_to_call, Symbol): # XXX When you call a Symbol it is
  562. return expr_to_call # transformed into an UndefFunction
  563. else:
  564. return expr_to_call(*on_args)
  565. elif expr_to_call.args:
  566. args = [Basic._recursive_call(
  567. sub, on_args) for sub in expr_to_call.args]
  568. return type(expr_to_call)(*args)
  569. else:
  570. return expr_to_call
  571. def is_hypergeometric(self, k):
  572. from sympy.simplify.simplify import hypersimp
  573. from sympy.functions.elementary.piecewise import Piecewise
  574. if self.has(Piecewise):
  575. return None
  576. return hypersimp(self, k) is not None
  577. @property
  578. def is_comparable(self):
  579. """Return True if self can be computed to a real number
  580. (or already is a real number) with precision, else False.
  581. Examples
  582. ========
  583. >>> from sympy import exp_polar, pi, I
  584. >>> (I*exp_polar(I*pi/2)).is_comparable
  585. True
  586. >>> (I*exp_polar(I*pi*2)).is_comparable
  587. False
  588. A False result does not mean that `self` cannot be rewritten
  589. into a form that would be comparable. For example, the
  590. difference computed below is zero but without simplification
  591. it does not evaluate to a zero with precision:
  592. >>> e = 2**pi*(1 + 2**pi)
  593. >>> dif = e - e.expand()
  594. >>> dif.is_comparable
  595. False
  596. >>> dif.n(2)._prec
  597. 1
  598. """
  599. is_extended_real = self.is_extended_real
  600. if is_extended_real is False:
  601. return False
  602. if not self.is_number:
  603. return False
  604. # don't re-eval numbers that are already evaluated since
  605. # this will create spurious precision
  606. n, i = [p.evalf(2) if not p.is_Number else p
  607. for p in self.as_real_imag()]
  608. if not (i.is_Number and n.is_Number):
  609. return False
  610. if i:
  611. # if _prec = 1 we can't decide and if not,
  612. # the answer is False because numbers with
  613. # imaginary parts can't be compared
  614. # so return False
  615. return False
  616. else:
  617. return n._prec != 1
  618. @property
  619. def func(self):
  620. """
  621. The top-level function in an expression.
  622. The following should hold for all objects::
  623. >> x == x.func(*x.args)
  624. Examples
  625. ========
  626. >>> from sympy.abc import x
  627. >>> a = 2*x
  628. >>> a.func
  629. <class 'sympy.core.mul.Mul'>
  630. >>> a.args
  631. (2, x)
  632. >>> a.func(*a.args)
  633. 2*x
  634. >>> a == a.func(*a.args)
  635. True
  636. """
  637. return self.__class__
  638. @property
  639. def args(self) -> tuple[Basic, ...]:
  640. """Returns a tuple of arguments of 'self'.
  641. Examples
  642. ========
  643. >>> from sympy import cot
  644. >>> from sympy.abc import x, y
  645. >>> cot(x).args
  646. (x,)
  647. >>> cot(x).args[0]
  648. x
  649. >>> (x*y).args
  650. (x, y)
  651. >>> (x*y).args[1]
  652. y
  653. Notes
  654. =====
  655. Never use self._args, always use self.args.
  656. Only use _args in __new__ when creating a new function.
  657. Do not override .args() from Basic (so that it is easy to
  658. change the interface in the future if needed).
  659. """
  660. return self._args
  661. @property
  662. def _sorted_args(self):
  663. """
  664. The same as ``args``. Derived classes which do not fix an
  665. order on their arguments should override this method to
  666. produce the sorted representation.
  667. """
  668. return self.args
  669. def as_content_primitive(self, radical=False, clear=True):
  670. """A stub to allow Basic args (like Tuple) to be skipped when computing
  671. the content and primitive components of an expression.
  672. See Also
  673. ========
  674. sympy.core.expr.Expr.as_content_primitive
  675. """
  676. return S.One, self
  677. def subs(self, *args, **kwargs):
  678. """
  679. Substitutes old for new in an expression after sympifying args.
  680. `args` is either:
  681. - two arguments, e.g. foo.subs(old, new)
  682. - one iterable argument, e.g. foo.subs(iterable). The iterable may be
  683. o an iterable container with (old, new) pairs. In this case the
  684. replacements are processed in the order given with successive
  685. patterns possibly affecting replacements already made.
  686. o a dict or set whose key/value items correspond to old/new pairs.
  687. In this case the old/new pairs will be sorted by op count and in
  688. case of a tie, by number of args and the default_sort_key. The
  689. resulting sorted list is then processed as an iterable container
  690. (see previous).
  691. If the keyword ``simultaneous`` is True, the subexpressions will not be
  692. evaluated until all the substitutions have been made.
  693. Examples
  694. ========
  695. >>> from sympy import pi, exp, limit, oo
  696. >>> from sympy.abc import x, y
  697. >>> (1 + x*y).subs(x, pi)
  698. pi*y + 1
  699. >>> (1 + x*y).subs({x:pi, y:2})
  700. 1 + 2*pi
  701. >>> (1 + x*y).subs([(x, pi), (y, 2)])
  702. 1 + 2*pi
  703. >>> reps = [(y, x**2), (x, 2)]
  704. >>> (x + y).subs(reps)
  705. 6
  706. >>> (x + y).subs(reversed(reps))
  707. x**2 + 2
  708. >>> (x**2 + x**4).subs(x**2, y)
  709. y**2 + y
  710. To replace only the x**2 but not the x**4, use xreplace:
  711. >>> (x**2 + x**4).xreplace({x**2: y})
  712. x**4 + y
  713. To delay evaluation until all substitutions have been made,
  714. set the keyword ``simultaneous`` to True:
  715. >>> (x/y).subs([(x, 0), (y, 0)])
  716. 0
  717. >>> (x/y).subs([(x, 0), (y, 0)], simultaneous=True)
  718. nan
  719. This has the added feature of not allowing subsequent substitutions
  720. to affect those already made:
  721. >>> ((x + y)/y).subs({x + y: y, y: x + y})
  722. 1
  723. >>> ((x + y)/y).subs({x + y: y, y: x + y}, simultaneous=True)
  724. y/(x + y)
  725. In order to obtain a canonical result, unordered iterables are
  726. sorted by count_op length, number of arguments and by the
  727. default_sort_key to break any ties. All other iterables are left
  728. unsorted.
  729. >>> from sympy import sqrt, sin, cos
  730. >>> from sympy.abc import a, b, c, d, e
  731. >>> A = (sqrt(sin(2*x)), a)
  732. >>> B = (sin(2*x), b)
  733. >>> C = (cos(2*x), c)
  734. >>> D = (x, d)
  735. >>> E = (exp(x), e)
  736. >>> expr = sqrt(sin(2*x))*sin(exp(x)*x)*cos(2*x) + sin(2*x)
  737. >>> expr.subs(dict([A, B, C, D, E]))
  738. a*c*sin(d*e) + b
  739. The resulting expression represents a literal replacement of the
  740. old arguments with the new arguments. This may not reflect the
  741. limiting behavior of the expression:
  742. >>> (x**3 - 3*x).subs({x: oo})
  743. nan
  744. >>> limit(x**3 - 3*x, x, oo)
  745. oo
  746. If the substitution will be followed by numerical
  747. evaluation, it is better to pass the substitution to
  748. evalf as
  749. >>> (1/x).evalf(subs={x: 3.0}, n=21)
  750. 0.333333333333333333333
  751. rather than
  752. >>> (1/x).subs({x: 3.0}).evalf(21)
  753. 0.333333333333333314830
  754. as the former will ensure that the desired level of precision is
  755. obtained.
  756. See Also
  757. ========
  758. replace: replacement capable of doing wildcard-like matching,
  759. parsing of match, and conditional replacements
  760. xreplace: exact node replacement in expr tree; also capable of
  761. using matching rules
  762. sympy.core.evalf.EvalfMixin.evalf: calculates the given formula to a desired level of precision
  763. """
  764. from .containers import Dict
  765. from .symbol import Dummy, Symbol
  766. from .numbers import _illegal
  767. unordered = False
  768. if len(args) == 1:
  769. sequence = args[0]
  770. if isinstance(sequence, set):
  771. unordered = True
  772. elif isinstance(sequence, (Dict, Mapping)):
  773. unordered = True
  774. sequence = sequence.items()
  775. elif not iterable(sequence):
  776. raise ValueError(filldedent("""
  777. When a single argument is passed to subs
  778. it should be a dictionary of old: new pairs or an iterable
  779. of (old, new) tuples."""))
  780. elif len(args) == 2:
  781. sequence = [args]
  782. else:
  783. raise ValueError("subs accepts either 1 or 2 arguments")
  784. def sympify_old(old):
  785. if isinstance(old, str):
  786. # Use Symbol rather than parse_expr for old
  787. return Symbol(old)
  788. elif isinstance(old, type):
  789. # Allow a type e.g. Function('f') or sin
  790. return sympify(old, strict=False)
  791. else:
  792. return sympify(old, strict=True)
  793. def sympify_new(new):
  794. if isinstance(new, (str, type)):
  795. # Allow a type or parse a string input
  796. return sympify(new, strict=False)
  797. else:
  798. return sympify(new, strict=True)
  799. sequence = [(sympify_old(s1), sympify_new(s2)) for s1, s2 in sequence]
  800. # skip if there is no change
  801. sequence = [(s1, s2) for s1, s2 in sequence if not _aresame(s1, s2)]
  802. simultaneous = kwargs.pop('simultaneous', False)
  803. if unordered:
  804. from .sorting import _nodes, default_sort_key
  805. sequence = dict(sequence)
  806. # order so more complex items are first and items
  807. # of identical complexity are ordered so
  808. # f(x) < f(y) < x < y
  809. # \___ 2 __/ \_1_/ <- number of nodes
  810. #
  811. # For more complex ordering use an unordered sequence.
  812. k = list(ordered(sequence, default=False, keys=(
  813. lambda x: -_nodes(x),
  814. default_sort_key,
  815. )))
  816. sequence = [(k, sequence[k]) for k in k]
  817. # do infinities first
  818. if not simultaneous:
  819. redo = [i for i, seq in enumerate(sequence) if seq[1] in _illegal]
  820. for i in reversed(redo):
  821. sequence.insert(0, sequence.pop(i))
  822. if simultaneous: # XXX should this be the default for dict subs?
  823. reps = {}
  824. rv = self
  825. kwargs['hack2'] = True
  826. m = Dummy('subs_m')
  827. for old, new in sequence:
  828. com = new.is_commutative
  829. if com is None:
  830. com = True
  831. d = Dummy('subs_d', commutative=com)
  832. # using d*m so Subs will be used on dummy variables
  833. # in things like Derivative(f(x, y), x) in which x
  834. # is both free and bound
  835. rv = rv._subs(old, d*m, **kwargs)
  836. if not isinstance(rv, Basic):
  837. break
  838. reps[d] = new
  839. reps[m] = S.One # get rid of m
  840. return rv.xreplace(reps)
  841. else:
  842. rv = self
  843. for old, new in sequence:
  844. rv = rv._subs(old, new, **kwargs)
  845. if not isinstance(rv, Basic):
  846. break
  847. return rv
    @cacheit
    def _subs(self, old, new, **hints):
        """Substitutes an expression old -> new.

        If self is not equal to old then _eval_subs is called.
        If _eval_subs does not want to make any special replacement
        then a None is received which indicates that the fallback
        should be applied wherein a search for replacements is made
        amongst the arguments of self.

        ``hints`` are forwarded unchanged to every recursive ``_subs``
        call. The ``hack2`` hint (set by ``subs`` when
        ``simultaneous=True``) enables special handling of Mul in the
        fallback, see the comment there.

        >>> from sympy import Add
        >>> from sympy.abc import x, y, z

        Examples
        ========

        Add's _eval_subs knows how to target x + y in the following
        so it makes the change:

        >>> (x + y + z).subs(x + y, 1)
        z + 1

        Add's _eval_subs does not need to know how to find x + y in
        the following:

        >>> Add._eval_subs(z*(x + y) + 3, x + y, 1) is None
        True

        The returned None will cause the fallback routine to traverse the args and
        pass the z*(x + y) arg to Mul where the change will take place and the
        substitution will succeed:

        >>> (z*(x + y) + 3).subs(x + y, 1)
        z + 3

        ** Developers Notes **

        An _eval_subs routine for a class should be written if:

            1) any arguments are not instances of Basic (e.g. bool, tuple);

            2) some arguments should not be targeted (as in integration
               variables);

            3) if there is something other than a literal replacement
               that should be attempted (as in Piecewise where the condition
               may be updated without doing a replacement).

        If it is overridden, here are some special cases that might arise:

            1) If it turns out that no special change was made and all
               the original sub-arguments should be checked for
               replacements then None should be returned.

            2) If it is necessary to do substitutions on a portion of
               the expression then _subs should be called. _subs will
               handle the case of any sub-expression being equal to old
               (which usually would not be the case) while its fallback
               will handle the recursion into the sub-arguments. For
               example, after Add's _eval_subs removes some matching terms
               it must process the remaining terms so it calls _subs
               on each of the un-matched terms and then adds them
               onto the terms previously obtained.

            3) If the initial expression should remain unchanged then
               the original expression should be returned. (Whenever an
               expression is returned, modified or not, no further
               substitution of old -> new is attempted.) Sum's _eval_subs
               routine uses this strategy when a substitution is attempted
               on any of its summation variables.
        """

        def fallback(self, old, new):
            """
            Try to replace old with new in any of self's arguments.

            Returns a rebuilt expression if any argument changed,
            otherwise ``self`` itself.
            """
            hit = False
            args = list(self.args)
            for i, arg in enumerate(args):
                # args that do not implement _eval_subs (non-Basic objects)
                # are left as they are
                if not hasattr(arg, '_eval_subs'):
                    continue
                arg = arg._subs(old, new, **hints)
                # structural identity (not ==) decides whether something
                # changed, so e.g. 2 -> 2.0 still counts as a change
                if not _aresame(arg, args[i]):
                    hit = True
                    args[i] = arg
            if hit:
                rv = self.func(*args)
                hack2 = hints.get('hack2', False)
                if hack2 and self.is_Mul and not rv.is_Mul:  # 2-arg hack
                    # Rebuilding the Mul gave a non-Mul. Instead, return the
                    # product of the non-number factors, wrapped in an
                    # unevaluated 2-arg Mul with the numeric coefficient
                    # when that coefficient is not 1.
                    # NOTE(review): presumably this stops the coefficient
                    # being distributed during simultaneous subs -- confirm.
                    coeff = S.One
                    nonnumber = []
                    for i in args:
                        if i.is_Number:
                            coeff *= i
                        else:
                            nonnumber.append(i)
                    nonnumber = self.func(*nonnumber)
                    if coeff is S.One:
                        return nonnumber
                    else:
                        return self.func(coeff, nonnumber, evaluate=False)
                return rv
            return self

        # an exact structural match of the whole expression is replaced outright
        if _aresame(self, old):
            return new

        # give the class a chance to do its own substitution; None means
        # "not handled" and triggers the generic recursion over the args
        rv = self._eval_subs(old, new)
        if rv is None:
            rv = fallback(self, old, new)
        return rv
  938. def _eval_subs(self, old, new):
  939. """Override this stub if you want to do anything more than
  940. attempt a replacement of old with new in the arguments of self.
  941. See also
  942. ========
  943. _subs
  944. """
  945. return None
  946. def xreplace(self, rule):
  947. """
  948. Replace occurrences of objects within the expression.
  949. Parameters
  950. ==========
  951. rule : dict-like
  952. Expresses a replacement rule
  953. Returns
  954. =======
  955. xreplace : the result of the replacement
  956. Examples
  957. ========
  958. >>> from sympy import symbols, pi, exp
  959. >>> x, y, z = symbols('x y z')
  960. >>> (1 + x*y).xreplace({x: pi})
  961. pi*y + 1
  962. >>> (1 + x*y).xreplace({x: pi, y: 2})
  963. 1 + 2*pi
  964. Replacements occur only if an entire node in the expression tree is
  965. matched:
  966. >>> (x*y + z).xreplace({x*y: pi})
  967. z + pi
  968. >>> (x*y*z).xreplace({x*y: pi})
  969. x*y*z
  970. >>> (2*x).xreplace({2*x: y, x: z})
  971. y
  972. >>> (2*2*x).xreplace({2*x: y, x: z})
  973. 4*z
  974. >>> (x + y + 2).xreplace({x + y: 2})
  975. x + y + 2
  976. >>> (x + 2 + exp(x + 2)).xreplace({x + 2: y})
  977. x + exp(y) + 2
  978. xreplace does not differentiate between free and bound symbols. In the
  979. following, subs(x, y) would not change x since it is a bound symbol,
  980. but xreplace does:
  981. >>> from sympy import Integral
  982. >>> Integral(x, (x, 1, 2*x)).xreplace({x: y})
  983. Integral(y, (y, 1, 2*y))
  984. Trying to replace x with an expression raises an error:
  985. >>> Integral(x, (x, 1, 2*x)).xreplace({x: 2*y}) # doctest: +SKIP
  986. ValueError: Invalid limits given: ((2*y, 1, 4*y),)
  987. See Also
  988. ========
  989. replace: replacement capable of doing wildcard-like matching,
  990. parsing of match, and conditional replacements
  991. subs: substitution of subexpressions as defined by the objects
  992. themselves.
  993. """
  994. value, _ = self._xreplace(rule)
  995. return value
  996. def _xreplace(self, rule):
  997. """
  998. Helper for xreplace. Tracks whether a replacement actually occurred.
  999. """
  1000. if self in rule:
  1001. return rule[self], True
  1002. elif rule:
  1003. args = []
  1004. changed = False
  1005. for a in self.args:
  1006. _xreplace = getattr(a, '_xreplace', None)
  1007. if _xreplace is not None:
  1008. a_xr = _xreplace(rule)
  1009. args.append(a_xr[0])
  1010. changed |= a_xr[1]
  1011. else:
  1012. args.append(a)
  1013. args = tuple(args)
  1014. if changed:
  1015. return self.func(*args), True
  1016. return self, False
  1017. @cacheit
  1018. def has(self, *patterns):
  1019. """
  1020. Test whether any subexpression matches any of the patterns.
  1021. Examples
  1022. ========
  1023. >>> from sympy import sin
  1024. >>> from sympy.abc import x, y, z
  1025. >>> (x**2 + sin(x*y)).has(z)
  1026. False
  1027. >>> (x**2 + sin(x*y)).has(x, y, z)
  1028. True
  1029. >>> x.has(x)
  1030. True
  1031. Note ``has`` is a structural algorithm with no knowledge of
  1032. mathematics. Consider the following half-open interval:
  1033. >>> from sympy import Interval
  1034. >>> i = Interval.Lopen(0, 5); i
  1035. Interval.Lopen(0, 5)
  1036. >>> i.args
  1037. (0, 5, True, False)
  1038. >>> i.has(4) # there is no "4" in the arguments
  1039. False
  1040. >>> i.has(0) # there *is* a "0" in the arguments
  1041. True
  1042. Instead, use ``contains`` to determine whether a number is in the
  1043. interval or not:
  1044. >>> i.contains(4)
  1045. True
  1046. >>> i.contains(0)
  1047. False
  1048. Note that ``expr.has(*patterns)`` is exactly equivalent to
  1049. ``any(expr.has(p) for p in patterns)``. In particular, ``False`` is
  1050. returned when the list of patterns is empty.
  1051. >>> x.has()
  1052. False
  1053. """
  1054. return self._has(iterargs, *patterns)
  1055. def has_xfree(self, s: set[Basic]):
  1056. """Return True if self has any of the patterns in s as a
  1057. free argument, else False. This is like `Basic.has_free`
  1058. but this will only report exact argument matches.
  1059. Examples
  1060. ========
  1061. >>> from sympy import Function
  1062. >>> from sympy.abc import x, y
  1063. >>> f = Function('f')
  1064. >>> f(x).has_xfree({f})
  1065. False
  1066. >>> f(x).has_xfree({f(x)})
  1067. True
  1068. >>> f(x + 1).has_xfree({x})
  1069. True
  1070. >>> f(x + 1).has_xfree({x + 1})
  1071. True
  1072. >>> f(x + y + 1).has_xfree({x + 1})
  1073. False
  1074. """
  1075. # protect O(1) containment check by requiring:
  1076. if type(s) is not set:
  1077. raise TypeError('expecting set argument')
  1078. return any(a in s for a in iterfreeargs(self))
  1079. @cacheit
  1080. def has_free(self, *patterns):
  1081. """Return True if self has object(s) ``x`` as a free expression
  1082. else False.
  1083. Examples
  1084. ========
  1085. >>> from sympy import Integral, Function
  1086. >>> from sympy.abc import x, y
  1087. >>> f = Function('f')
  1088. >>> g = Function('g')
  1089. >>> expr = Integral(f(x), (f(x), 1, g(y)))
  1090. >>> expr.free_symbols
  1091. {y}
  1092. >>> expr.has_free(g(y))
  1093. True
  1094. >>> expr.has_free(*(x, f(x)))
  1095. False
  1096. This works for subexpressions and types, too:
  1097. >>> expr.has_free(g)
  1098. True
  1099. >>> (x + y + 1).has_free(y + 1)
  1100. True
  1101. """
  1102. if not patterns:
  1103. return False
  1104. p0 = patterns[0]
  1105. if len(patterns) == 1 and iterable(p0) and not isinstance(p0, Basic):
  1106. # Basic can contain iterables (though not non-Basic, ideally)
  1107. # but don't encourage mixed passing patterns
  1108. raise TypeError(filldedent('''
  1109. Expecting 1 or more Basic args, not a single
  1110. non-Basic iterable. Don't forget to unpack
  1111. iterables: `eq.has_free(*patterns)`'''))
  1112. # try quick test first
  1113. s = set(patterns)
  1114. rv = self.has_xfree(s)
  1115. if rv:
  1116. return rv
  1117. # now try matching through slower _has
  1118. return self._has(iterfreeargs, *patterns)
  1119. def _has(self, iterargs, *patterns):
  1120. # separate out types and unhashable objects
  1121. type_set = set() # only types
  1122. p_set = set() # hashable non-types
  1123. for p in patterns:
  1124. if isinstance(p, type) and issubclass(p, Basic):
  1125. type_set.add(p)
  1126. continue
  1127. if not isinstance(p, Basic):
  1128. try:
  1129. p = _sympify(p)
  1130. except SympifyError:
  1131. continue # Basic won't have this in it
  1132. p_set.add(p) # fails if object defines __eq__ but
  1133. # doesn't define __hash__
  1134. types = tuple(type_set) #
  1135. for i in iterargs(self): #
  1136. if i in p_set: # <--- here, too
  1137. return True
  1138. if isinstance(i, types):
  1139. return True
  1140. # use matcher if defined, e.g. operations defines
  1141. # matcher that checks for exact subset containment,
  1142. # (x + y + 1).has(x + 1) -> True
  1143. for i in p_set - type_set: # types don't have matchers
  1144. if not hasattr(i, '_has_matcher'):
  1145. continue
  1146. match = i._has_matcher()
  1147. if any(match(arg) for arg in iterargs(self)):
  1148. return True
  1149. # no success
  1150. return False
  1151. def replace(self, query, value, map=False, simultaneous=True, exact=None):
  1152. """
  1153. Replace matching subexpressions of ``self`` with ``value``.
  1154. If ``map = True`` then also return the mapping {old: new} where ``old``
  1155. was a sub-expression found with query and ``new`` is the replacement
  1156. value for it. If the expression itself does not match the query, then
  1157. the returned value will be ``self.xreplace(map)`` otherwise it should
  1158. be ``self.subs(ordered(map.items()))``.
  1159. Traverses an expression tree and performs replacement of matching
  1160. subexpressions from the bottom to the top of the tree. The default
  1161. approach is to do the replacement in a simultaneous fashion so
  1162. changes made are targeted only once. If this is not desired or causes
  1163. problems, ``simultaneous`` can be set to False.
  1164. In addition, if an expression containing more than one Wild symbol
  1165. is being used to match subexpressions and the ``exact`` flag is None
  1166. it will be set to True so the match will only succeed if all non-zero
  1167. values are received for each Wild that appears in the match pattern.
  1168. Setting this to False accepts a match of 0; while setting it True
  1169. accepts all matches that have a 0 in them. See example below for
  1170. cautions.
  1171. The list of possible combinations of queries and replacement values
  1172. is listed below:
  1173. Examples
  1174. ========
  1175. Initial setup
  1176. >>> from sympy import log, sin, cos, tan, Wild, Mul, Add
  1177. >>> from sympy.abc import x, y
  1178. >>> f = log(sin(x)) + tan(sin(x**2))
  1179. 1.1. type -> type
  1180. obj.replace(type, newtype)
  1181. When object of type ``type`` is found, replace it with the
  1182. result of passing its argument(s) to ``newtype``.
  1183. >>> f.replace(sin, cos)
  1184. log(cos(x)) + tan(cos(x**2))
  1185. >>> sin(x).replace(sin, cos, map=True)
  1186. (cos(x), {sin(x): cos(x)})
  1187. >>> (x*y).replace(Mul, Add)
  1188. x + y
  1189. 1.2. type -> func
  1190. obj.replace(type, func)
  1191. When object of type ``type`` is found, apply ``func`` to its
  1192. argument(s). ``func`` must be written to handle the number
  1193. of arguments of ``type``.
  1194. >>> f.replace(sin, lambda arg: sin(2*arg))
  1195. log(sin(2*x)) + tan(sin(2*x**2))
  1196. >>> (x*y).replace(Mul, lambda *args: sin(2*Mul(*args)))
  1197. sin(2*x*y)
  1198. 2.1. pattern -> expr
  1199. obj.replace(pattern(wild), expr(wild))
  1200. Replace subexpressions matching ``pattern`` with the expression
  1201. written in terms of the Wild symbols in ``pattern``.
  1202. >>> a, b = map(Wild, 'ab')
  1203. >>> f.replace(sin(a), tan(a))
  1204. log(tan(x)) + tan(tan(x**2))
  1205. >>> f.replace(sin(a), tan(a/2))
  1206. log(tan(x/2)) + tan(tan(x**2/2))
  1207. >>> f.replace(sin(a), a)
  1208. log(x) + tan(x**2)
  1209. >>> (x*y).replace(a*x, a)
  1210. y
  1211. Matching is exact by default when more than one Wild symbol
  1212. is used: matching fails unless the match gives non-zero
  1213. values for all Wild symbols:
  1214. >>> (2*x + y).replace(a*x + b, b - a)
  1215. y - 2
  1216. >>> (2*x).replace(a*x + b, b - a)
  1217. 2*x
  1218. When set to False, the results may be non-intuitive:
  1219. >>> (2*x).replace(a*x + b, b - a, exact=False)
  1220. 2/x
  1221. 2.2. pattern -> func
  1222. obj.replace(pattern(wild), lambda wild: expr(wild))
  1223. All behavior is the same as in 2.1 but now a function in terms of
  1224. pattern variables is used rather than an expression:
  1225. >>> f.replace(sin(a), lambda a: sin(2*a))
  1226. log(sin(2*x)) + tan(sin(2*x**2))
  1227. 3.1. func -> func
  1228. obj.replace(filter, func)
  1229. Replace subexpression ``e`` with ``func(e)`` if ``filter(e)``
  1230. is True.
  1231. >>> g = 2*sin(x**3)
  1232. >>> g.replace(lambda expr: expr.is_Number, lambda expr: expr**2)
  1233. 4*sin(x**9)
  1234. The expression itself is also targeted by the query but is done in
  1235. such a fashion that changes are not made twice.
  1236. >>> e = x*(x*y + 1)
  1237. >>> e.replace(lambda x: x.is_Mul, lambda x: 2*x)
  1238. 2*x*(2*x*y + 1)
  1239. When matching a single symbol, `exact` will default to True, but
  1240. this may or may not be the behavior that is desired:
  1241. Here, we want `exact=False`:
  1242. >>> from sympy import Function
  1243. >>> f = Function('f')
  1244. >>> e = f(1) + f(0)
  1245. >>> q = f(a), lambda a: f(a + 1)
  1246. >>> e.replace(*q, exact=False)
  1247. f(1) + f(2)
  1248. >>> e.replace(*q, exact=True)
  1249. f(0) + f(2)
  1250. But here, the nature of matching makes selecting
  1251. the right setting tricky:
  1252. >>> e = x**(1 + y)
  1253. >>> (x**(1 + y)).replace(x**(1 + a), lambda a: x**-a, exact=False)
  1254. x
  1255. >>> (x**(1 + y)).replace(x**(1 + a), lambda a: x**-a, exact=True)
  1256. x**(-x - y + 1)
  1257. >>> (x**y).replace(x**(1 + a), lambda a: x**-a, exact=False)
  1258. x
  1259. >>> (x**y).replace(x**(1 + a), lambda a: x**-a, exact=True)
  1260. x**(1 - y)
  1261. It is probably better to use a different form of the query
  1262. that describes the target expression more precisely:
  1263. >>> (1 + x**(1 + y)).replace(
  1264. ... lambda x: x.is_Pow and x.exp.is_Add and x.exp.args[0] == 1,
  1265. ... lambda x: x.base**(1 - (x.exp - 1)))
  1266. ...
  1267. x**(1 - y) + 1
  1268. See Also
  1269. ========
  1270. subs: substitution of subexpressions as defined by the objects
  1271. themselves.
  1272. xreplace: exact node replacement in expr tree; also capable of
  1273. using matching rules
  1274. """
  1275. try:
  1276. query = _sympify(query)
  1277. except SympifyError:
  1278. pass
  1279. try:
  1280. value = _sympify(value)
  1281. except SympifyError:
  1282. pass
  1283. if isinstance(query, type):
  1284. _query = lambda expr: isinstance(expr, query)
  1285. if isinstance(value, type):
  1286. _value = lambda expr, result: value(*expr.args)
  1287. elif callable(value):
  1288. _value = lambda expr, result: value(*expr.args)
  1289. else:
  1290. raise TypeError(
  1291. "given a type, replace() expects another "
  1292. "type or a callable")
  1293. elif isinstance(query, Basic):
  1294. _query = lambda expr: expr.match(query)
  1295. if exact is None:
  1296. from .symbol import Wild
  1297. exact = (len(query.atoms(Wild)) > 1)
  1298. if isinstance(value, Basic):
  1299. if exact:
  1300. _value = lambda expr, result: (value.subs(result)
  1301. if all(result.values()) else expr)
  1302. else:
  1303. _value = lambda expr, result: value.subs(result)
  1304. elif callable(value):
  1305. # match dictionary keys get the trailing underscore stripped
  1306. # from them and are then passed as keywords to the callable;
  1307. # if ``exact`` is True, only accept match if there are no null
  1308. # values amongst those matched.
  1309. if exact:
  1310. _value = lambda expr, result: (value(**
  1311. {str(k)[:-1]: v for k, v in result.items()})
  1312. if all(val for val in result.values()) else expr)
  1313. else:
  1314. _value = lambda expr, result: value(**
  1315. {str(k)[:-1]: v for k, v in result.items()})
  1316. else:
  1317. raise TypeError(
  1318. "given an expression, replace() expects "
  1319. "another expression or a callable")
  1320. elif callable(query):
  1321. _query = query
  1322. if callable(value):
  1323. _value = lambda expr, result: value(expr)
  1324. else:
  1325. raise TypeError(
  1326. "given a callable, replace() expects "
  1327. "another callable")
  1328. else:
  1329. raise TypeError(
  1330. "first argument to replace() must be a "
  1331. "type, an expression or a callable")
  1332. def walk(rv, F):
  1333. """Apply ``F`` to args and then to result.
  1334. """
  1335. args = getattr(rv, 'args', None)
  1336. if args is not None:
  1337. if args:
  1338. newargs = tuple([walk(a, F) for a in args])
  1339. if args != newargs:
  1340. rv = rv.func(*newargs)
  1341. if simultaneous:
  1342. # if rv is something that was already
  1343. # matched (that was changed) then skip
  1344. # applying F again
  1345. for i, e in enumerate(args):
  1346. if rv == e and e != newargs[i]:
  1347. return rv
  1348. rv = F(rv)
  1349. return rv
  1350. mapping = {} # changes that took place
  1351. def rec_replace(expr):
  1352. result = _query(expr)
  1353. if result or result == {}:
  1354. v = _value(expr, result)
  1355. if v is not None and v != expr:
  1356. if map:
  1357. mapping[expr] = v
  1358. expr = v
  1359. return expr
  1360. rv = walk(self, rec_replace)
  1361. return (rv, mapping) if map else rv
  1362. def find(self, query, group=False):
  1363. """Find all subexpressions matching a query."""
  1364. query = _make_find_query(query)
  1365. results = list(filter(query, _preorder_traversal(self)))
  1366. if not group:
  1367. return set(results)
  1368. else:
  1369. groups = {}
  1370. for result in results:
  1371. if result in groups:
  1372. groups[result] += 1
  1373. else:
  1374. groups[result] = 1
  1375. return groups
  1376. def count(self, query):
  1377. """Count the number of matching subexpressions."""
  1378. query = _make_find_query(query)
  1379. return sum(bool(query(sub)) for sub in _preorder_traversal(self))
  1380. def matches(self, expr, repl_dict=None, old=False):
  1381. """
  1382. Helper method for match() that looks for a match between Wild symbols
  1383. in self and expressions in expr.
  1384. Examples
  1385. ========
  1386. >>> from sympy import symbols, Wild, Basic
  1387. >>> a, b, c = symbols('a b c')
  1388. >>> x = Wild('x')
  1389. >>> Basic(a + x, x).matches(Basic(a + b, c)) is None
  1390. True
  1391. >>> Basic(a + x, x).matches(Basic(a + b + c, b + c))
  1392. {x_: b + c}
  1393. """
  1394. expr = sympify(expr)
  1395. if not isinstance(expr, self.__class__):
  1396. return None
  1397. if repl_dict is None:
  1398. repl_dict = {}
  1399. else:
  1400. repl_dict = repl_dict.copy()
  1401. if self == expr:
  1402. return repl_dict
  1403. if len(self.args) != len(expr.args):
  1404. return None
  1405. d = repl_dict # already a copy
  1406. for arg, other_arg in zip(self.args, expr.args):
  1407. if arg == other_arg:
  1408. continue
  1409. if arg.is_Relational:
  1410. try:
  1411. d = arg.xreplace(d).matches(other_arg, d, old=old)
  1412. except TypeError: # Should be InvalidComparisonError when introduced
  1413. d = None
  1414. else:
  1415. d = arg.xreplace(d).matches(other_arg, d, old=old)
  1416. if d is None:
  1417. return None
  1418. return d
  1419. def match(self, pattern, old=False):
  1420. """
  1421. Pattern matching.
  1422. Wild symbols match all.
  1423. Return ``None`` when expression (self) does not match
  1424. with pattern. Otherwise return a dictionary such that::
  1425. pattern.xreplace(self.match(pattern)) == self
  1426. Examples
  1427. ========
  1428. >>> from sympy import Wild, Sum
  1429. >>> from sympy.abc import x, y
  1430. >>> p = Wild("p")
  1431. >>> q = Wild("q")
  1432. >>> r = Wild("r")
  1433. >>> e = (x+y)**(x+y)
  1434. >>> e.match(p**p)
  1435. {p_: x + y}
  1436. >>> e.match(p**q)
  1437. {p_: x + y, q_: x + y}
  1438. >>> e = (2*x)**2
  1439. >>> e.match(p*q**r)
  1440. {p_: 4, q_: x, r_: 2}
  1441. >>> (p*q**r).xreplace(e.match(p*q**r))
  1442. 4*x**2
  1443. Structurally bound symbols are ignored during matching:
  1444. >>> Sum(x, (x, 1, 2)).match(Sum(y, (y, 1, p)))
  1445. {p_: 2}
  1446. But they can be identified if desired:
  1447. >>> Sum(x, (x, 1, 2)).match(Sum(q, (q, 1, p)))
  1448. {p_: 2, q_: x}
  1449. The ``old`` flag will give the old-style pattern matching where
  1450. expressions and patterns are essentially solved to give the
  1451. match. Both of the following give None unless ``old=True``:
  1452. >>> (x - 2).match(p - x, old=True)
  1453. {p_: 2*x - 2}
  1454. >>> (2/x).match(p*x, old=True)
  1455. {p_: 2/x**2}
  1456. """
  1457. pattern = sympify(pattern)
  1458. # match non-bound symbols
  1459. canonical = lambda x: x if x.is_Symbol else x.as_dummy()
  1460. m = canonical(pattern).matches(canonical(self), old=old)
  1461. if m is None:
  1462. return m
  1463. from .symbol import Wild
  1464. from .function import WildFunction
  1465. from ..tensor.tensor import WildTensor, WildTensorIndex, WildTensorHead
  1466. wild = pattern.atoms(Wild, WildFunction, WildTensor, WildTensorIndex, WildTensorHead)
  1467. # sanity check
  1468. if set(m) - wild:
  1469. raise ValueError(filldedent('''
  1470. Some `matches` routine did not use a copy of repl_dict
  1471. and injected unexpected symbols. Report this as an
  1472. error at https://github.com/sympy/sympy/issues'''))
  1473. # now see if bound symbols were requested
  1474. bwild = wild - set(m)
  1475. if not bwild:
  1476. return m
  1477. # replace free-Wild symbols in pattern with match result
  1478. # so they will match but not be in the next match
  1479. wpat = pattern.xreplace(m)
  1480. # identify remaining bound wild
  1481. w = wpat.matches(self, old=old)
  1482. # add them to m
  1483. if w:
  1484. m.update(w)
  1485. # done
  1486. return m
  1487. def count_ops(self, visual=None):
  1488. """Wrapper for count_ops that returns the operation count."""
  1489. from .function import count_ops
  1490. return count_ops(self, visual)
  1491. def doit(self, **hints):
  1492. """Evaluate objects that are not evaluated by default like limits,
  1493. integrals, sums and products. All objects of this kind will be
  1494. evaluated recursively, unless some species were excluded via 'hints'
  1495. or unless the 'deep' hint was set to 'False'.
  1496. >>> from sympy import Integral
  1497. >>> from sympy.abc import x
  1498. >>> 2*Integral(x, x)
  1499. 2*Integral(x, x)
  1500. >>> (2*Integral(x, x)).doit()
  1501. x**2
  1502. >>> (2*Integral(x, x)).doit(deep=False)
  1503. 2*Integral(x, x)
  1504. """
  1505. if hints.get('deep', True):
  1506. terms = [term.doit(**hints) if isinstance(term, Basic) else term
  1507. for term in self.args]
  1508. return self.func(*terms)
  1509. else:
  1510. return self
  1511. def simplify(self, **kwargs):
  1512. """See the simplify function in sympy.simplify"""
  1513. from sympy.simplify.simplify import simplify
  1514. return simplify(self, **kwargs)
  1515. def refine(self, assumption=True):
  1516. """See the refine function in sympy.assumptions"""
  1517. from sympy.assumptions.refine import refine
  1518. return refine(self, assumption)
  1519. def _eval_derivative_n_times(self, s, n):
  1520. # This is the default evaluator for derivatives (as called by `diff`
  1521. # and `Derivative`), it will attempt a loop to derive the expression
  1522. # `n` times by calling the corresponding `_eval_derivative` method,
  1523. # while leaving the derivative unevaluated if `n` is symbolic. This
  1524. # method should be overridden if the object has a closed form for its
  1525. # symbolic n-th derivative.
  1526. from .numbers import Integer
  1527. if isinstance(n, (int, Integer)):
  1528. obj = self
  1529. for i in range(n):
  1530. obj2 = obj._eval_derivative(s)
  1531. if obj == obj2 or obj2 is None:
  1532. break
  1533. obj = obj2
  1534. return obj2
  1535. else:
  1536. return None
  1537. def rewrite(self, *args, deep=True, **hints):
  1538. """
  1539. Rewrite *self* using a defined rule.
  1540. Rewriting transforms an expression to another, which is mathematically
  1541. equivalent but structurally different. For example you can rewrite
  1542. trigonometric functions as complex exponentials or combinatorial
  1543. functions as gamma function.
  1544. This method takes a *pattern* and a *rule* as positional arguments.
  1545. *pattern* is optional parameter which defines the types of expressions
  1546. that will be transformed. If it is not passed, all possible expressions
  1547. will be rewritten. *rule* defines how the expression will be rewritten.
  1548. Parameters
  1549. ==========
  1550. args : Expr
  1551. A *rule*, or *pattern* and *rule*.
  1552. - *pattern* is a type or an iterable of types.
  1553. - *rule* can be any object.
  1554. deep : bool, optional
  1555. If ``True``, subexpressions are recursively transformed. Default is
  1556. ``True``.
  1557. Examples
  1558. ========
  1559. If *pattern* is unspecified, all possible expressions are transformed.
  1560. >>> from sympy import cos, sin, exp, I
  1561. >>> from sympy.abc import x
  1562. >>> expr = cos(x) + I*sin(x)
  1563. >>> expr.rewrite(exp)
  1564. exp(I*x)
  1565. Pattern can be a type or an iterable of types.
  1566. >>> expr.rewrite(sin, exp)
  1567. exp(I*x)/2 + cos(x) - exp(-I*x)/2
  1568. >>> expr.rewrite([cos,], exp)
  1569. exp(I*x)/2 + I*sin(x) + exp(-I*x)/2
  1570. >>> expr.rewrite([cos, sin], exp)
  1571. exp(I*x)
  1572. Rewriting behavior can be implemented by defining ``_eval_rewrite()``
  1573. method.
  1574. >>> from sympy import Expr, sqrt, pi
  1575. >>> class MySin(Expr):
  1576. ... def _eval_rewrite(self, rule, args, **hints):
  1577. ... x, = args
  1578. ... if rule == cos:
  1579. ... return cos(pi/2 - x, evaluate=False)
  1580. ... if rule == sqrt:
  1581. ... return sqrt(1 - cos(x)**2)
  1582. >>> MySin(MySin(x)).rewrite(cos)
  1583. cos(-cos(-x + pi/2) + pi/2)
  1584. >>> MySin(x).rewrite(sqrt)
  1585. sqrt(1 - cos(x)**2)
  1586. Defining ``_eval_rewrite_as_[...]()`` method is supported for backwards
  1587. compatibility reason. This may be removed in the future and using it is
  1588. discouraged.
  1589. >>> class MySin(Expr):
  1590. ... def _eval_rewrite_as_cos(self, *args, **hints):
  1591. ... x, = args
  1592. ... return cos(pi/2 - x, evaluate=False)
  1593. >>> MySin(x).rewrite(cos)
  1594. cos(-x + pi/2)
  1595. """
  1596. if not args:
  1597. return self
  1598. hints.update(deep=deep)
  1599. pattern = args[:-1]
  1600. rule = args[-1]
  1601. # support old design by _eval_rewrite_as_[...] method
  1602. if isinstance(rule, str):
  1603. method = "_eval_rewrite_as_%s" % rule
  1604. elif hasattr(rule, "__name__"):
  1605. # rule is class or function
  1606. clsname = rule.__name__
  1607. method = "_eval_rewrite_as_%s" % clsname
  1608. else:
  1609. # rule is instance
  1610. clsname = rule.__class__.__name__
  1611. method = "_eval_rewrite_as_%s" % clsname
  1612. if pattern:
  1613. if iterable(pattern[0]):
  1614. pattern = pattern[0]
  1615. pattern = tuple(p for p in pattern if self.has(p))
  1616. if not pattern:
  1617. return self
  1618. # hereafter, empty pattern is interpreted as all pattern.
  1619. return self._rewrite(pattern, rule, method, **hints)
  1620. def _rewrite(self, pattern, rule, method, **hints):
  1621. deep = hints.pop('deep', True)
  1622. if deep:
  1623. args = [a._rewrite(pattern, rule, method, **hints)
  1624. for a in self.args]
  1625. else:
  1626. args = self.args
  1627. if not pattern or any(isinstance(self, p) for p in pattern):
  1628. meth = getattr(self, method, None)
  1629. if meth is not None:
  1630. rewritten = meth(*args, **hints)
  1631. else:
  1632. rewritten = self._eval_rewrite(rule, args, **hints)
  1633. if rewritten is not None:
  1634. return rewritten
  1635. if not args:
  1636. return self
  1637. return self.func(*args)
  1638. def _eval_rewrite(self, rule, args, **hints):
  1639. return None
  1640. _constructor_postprocessor_mapping = {} # type: ignore
  1641. @classmethod
  1642. def _exec_constructor_postprocessors(cls, obj):
  1643. # WARNING: This API is experimental.
  1644. # This is an experimental API that introduces constructor
  1645. # postprosessors for SymPy Core elements. If an argument of a SymPy
  1646. # expression has a `_constructor_postprocessor_mapping` attribute, it will
  1647. # be interpreted as a dictionary containing lists of postprocessing
  1648. # functions for matching expression node names.
  1649. clsname = obj.__class__.__name__
  1650. postprocessors = defaultdict(list)
  1651. for i in obj.args:
  1652. try:
  1653. postprocessor_mappings = (
  1654. Basic._constructor_postprocessor_mapping[cls].items()
  1655. for cls in type(i).mro()
  1656. if cls in Basic._constructor_postprocessor_mapping
  1657. )
  1658. for k, v in chain.from_iterable(postprocessor_mappings):
  1659. postprocessors[k].extend([j for j in v if j not in postprocessors[k]])
  1660. except TypeError:
  1661. pass
  1662. for f in postprocessors.get(clsname, []):
  1663. obj = f(obj)
  1664. return obj
  1665. def _sage_(self):
  1666. """
  1667. Convert *self* to a symbolic expression of SageMath.
  1668. This version of the method is merely a placeholder.
  1669. """
  1670. old_method = self._sage_
  1671. from sage.interfaces.sympy import sympy_init
  1672. sympy_init() # may monkey-patch _sage_ method into self's class or superclasses
  1673. if old_method == self._sage_:
  1674. raise NotImplementedError('conversion to SageMath is not implemented')
  1675. else:
  1676. # call the freshly monkey-patched method
  1677. return self._sage_()
  1678. def could_extract_minus_sign(self):
  1679. return False # see Expr.could_extract_minus_sign
# For all Basic subclasses _prepare_class_assumptions is called by
# Basic.__init_subclass__. That hook does not run for the class that defines
# it, so Basic itself is prepared explicitly here.
_prepare_class_assumptions(Basic)
  1684. class Atom(Basic):
  1685. """
  1686. A parent class for atomic things. An atom is an expression with no subexpressions.
  1687. Examples
  1688. ========
  1689. Symbol, Number, Rational, Integer, ...
  1690. But not: Add, Mul, Pow, ...
  1691. """
  1692. is_Atom = True
  1693. __slots__ = ()
  1694. def matches(self, expr, repl_dict=None, old=False):
  1695. if self == expr:
  1696. if repl_dict is None:
  1697. return {}
  1698. return repl_dict.copy()
  1699. def xreplace(self, rule, hack2=False):
  1700. return rule.get(self, self)
  1701. def doit(self, **hints):
  1702. return self
  1703. @classmethod
  1704. def class_key(cls):
  1705. return 2, 0, cls.__name__
  1706. @cacheit
  1707. def sort_key(self, order=None):
  1708. return self.class_key(), (1, (str(self),)), S.One.sort_key(), S.One
  1709. def _eval_simplify(self, **kwargs):
  1710. return self
  1711. @property
  1712. def _sorted_args(self):
  1713. # this is here as a safeguard against accidentally using _sorted_args
  1714. # on Atoms -- they cannot be rebuilt as atom.func(*atom._sorted_args)
  1715. # since there are no args. So the calling routine should be checking
  1716. # to see that this property is not called for Atoms.
  1717. raise AttributeError('Atoms have no args. It might be necessary'
  1718. ' to make a check for Atoms in the calling code.')
  1719. def _aresame(a, b):
  1720. """Return True if a and b are structurally the same, else False.
  1721. Examples
  1722. ========
  1723. In SymPy (as in Python) two numbers compare the same if they
  1724. have the same underlying base-2 representation even though
  1725. they may not be the same type:
  1726. >>> from sympy import S
  1727. >>> 2.0 == S(2)
  1728. True
  1729. >>> 0.5 == S.Half
  1730. True
  1731. This routine was written to provide a query for such cases that
  1732. would give false when the types do not match:
  1733. >>> from sympy.core.basic import _aresame
  1734. >>> _aresame(S(2.0), S(2))
  1735. False
  1736. """
  1737. from .numbers import Number
  1738. from .function import AppliedUndef, UndefinedFunction as UndefFunc
  1739. if isinstance(a, Number) and isinstance(b, Number):
  1740. return a == b and a.__class__ == b.__class__
  1741. for i, j in zip_longest(_preorder_traversal(a), _preorder_traversal(b)):
  1742. if i != j or type(i) != type(j):
  1743. if ((isinstance(i, UndefFunc) and isinstance(j, UndefFunc)) or
  1744. (isinstance(i, AppliedUndef) and isinstance(j, AppliedUndef))):
  1745. if i.class_key() != j.class_key():
  1746. return False
  1747. else:
  1748. return False
  1749. return True
  1750. def _ne(a, b):
  1751. # use this as a second test after `a != b` if you want to make
  1752. # sure that things are truly equal, e.g.
  1753. # a, b = 0.5, S.Half
  1754. # a !=b or _ne(a, b) -> True
  1755. from .numbers import Number
  1756. # 0.5 == S.Half
  1757. if isinstance(a, Number) and isinstance(b, Number):
  1758. return a.__class__ != b.__class__
  1759. def _atomic(e, recursive=False):
  1760. """Return atom-like quantities as far as substitution is
  1761. concerned: Derivatives, Functions and Symbols. Do not
  1762. return any 'atoms' that are inside such quantities unless
  1763. they also appear outside, too, unless `recursive` is True.
  1764. Examples
  1765. ========
  1766. >>> from sympy import Derivative, Function, cos
  1767. >>> from sympy.abc import x, y
  1768. >>> from sympy.core.basic import _atomic
  1769. >>> f = Function('f')
  1770. >>> _atomic(x + y)
  1771. {x, y}
  1772. >>> _atomic(x + f(y))
  1773. {x, f(y)}
  1774. >>> _atomic(Derivative(f(x), x) + cos(x) + y)
  1775. {y, cos(x), Derivative(f(x), x)}
  1776. """
  1777. pot = _preorder_traversal(e)
  1778. seen = set()
  1779. if isinstance(e, Basic):
  1780. free = getattr(e, "free_symbols", None)
  1781. if free is None:
  1782. return {e}
  1783. else:
  1784. return set()
  1785. from .symbol import Symbol
  1786. from .function import Derivative, Function
  1787. atoms = set()
  1788. for p in pot:
  1789. if p in seen:
  1790. pot.skip()
  1791. continue
  1792. seen.add(p)
  1793. if isinstance(p, Symbol) and p in free:
  1794. atoms.add(p)
  1795. elif isinstance(p, (Derivative, Function)):
  1796. if not recursive:
  1797. pot.skip()
  1798. atoms.add(p)
  1799. return atoms
  1800. def _make_find_query(query):
  1801. """Convert the argument of Basic.find() into a callable"""
  1802. try:
  1803. query = _sympify(query)
  1804. except SympifyError:
  1805. pass
  1806. if isinstance(query, type):
  1807. return lambda expr: isinstance(expr, query)
  1808. elif isinstance(query, Basic):
  1809. return lambda expr: expr.match(query) is not None
  1810. return query
# Delayed to avoid cyclic import: these names are imported only after all
# the definitions above exist, so they must stay at the bottom of the module.
from .singleton import S
from .traversal import (preorder_traversal as _preorder_traversal,
    iterargs, iterfreeargs)
# Backwards-compatible alias: ``preorder_traversal`` is still importable from
# this submodule, but it is wrapped in ``deprecated`` (deprecated since 1.10).
# The implementation itself is ``.traversal.preorder_traversal``.
preorder_traversal = deprecated(
"""
Using preorder_traversal from the sympy.core.basic submodule is
deprecated.
Instead, use preorder_traversal from the top-level sympy namespace, like
sympy.preorder_traversal
""",
deprecated_since_version="1.10",
active_deprecations_target="deprecated-traversal-functions-moved",
)(_preorder_traversal)