function.py 113 KB


"""
There are three types of functions implemented in SymPy:

    1) defined functions (in the sense that they can be evaluated) like
       exp or sin; they have a name and a body:
           f = exp
    2) undefined functions, which have a name but no body. Undefined
       functions can be created using the Function class as follows:
           f = Function('f')
       (the result is a new function class, an instance of UndefinedFunction)
    3) anonymous functions (or lambda functions), which have a body (defined
       with dummy variables) but no name:
           f = Lambda(x, exp(x)*x)
           f = Lambda((x, y), exp(x)*y)

    A fourth type of function, composites such as (sin + cos)(x), works in
    SymPy core but is not yet part of SymPy.

    Examples
    ========

    >>> import sympy
    >>> f = sympy.Function("f")
    >>> from sympy.abc import x
    >>> f(x)
    f(x)
    >>> print(sympy.srepr(f(x).func))
    Function('f')
    >>> f(x).args
    (x,)

"""
  28. from __future__ import annotations
  29. from typing import Any
  30. from collections.abc import Iterable
  31. from .add import Add
  32. from .basic import Basic, _atomic
  33. from .cache import cacheit
  34. from .containers import Tuple, Dict
  35. from .decorators import _sympifyit
  36. from .evalf import pure_complex
  37. from .expr import Expr, AtomicExpr
  38. from .logic import fuzzy_and, fuzzy_or, fuzzy_not, FuzzyBool
  39. from .mul import Mul
  40. from .numbers import Rational, Float, Integer
  41. from .operations import LatticeOp
  42. from .parameters import global_parameters
  43. from .rules import Transform
  44. from .singleton import S
  45. from .sympify import sympify, _sympify
  46. from .sorting import default_sort_key, ordered
  47. from sympy.utilities.exceptions import (sympy_deprecation_warning,
  48. SymPyDeprecationWarning, ignore_warnings)
  49. from sympy.utilities.iterables import (has_dups, sift, iterable,
  50. is_sequence, uniq, topological_sort)
  51. from sympy.utilities.lambdify import MPMATH_TRANSLATIONS
  52. from sympy.utilities.misc import as_int, filldedent, func_name
  53. import mpmath
  54. from mpmath.libmp.libmpf import prec_to_dps
  55. import inspect
  56. from collections import Counter
  57. def _coeff_isneg(a):
  58. """Return True if the leading Number is negative.
  59. Examples
  60. ========
  61. >>> from sympy.core.function import _coeff_isneg
  62. >>> from sympy import S, Symbol, oo, pi
  63. >>> _coeff_isneg(-3*pi)
  64. True
  65. >>> _coeff_isneg(S(3))
  66. False
  67. >>> _coeff_isneg(-oo)
  68. True
  69. >>> _coeff_isneg(Symbol('n', negative=True)) # coeff is 1
  70. False
  71. For matrix expressions:
  72. >>> from sympy import MatrixSymbol, sqrt
  73. >>> A = MatrixSymbol("A", 3, 3)
  74. >>> _coeff_isneg(-sqrt(2)*A)
  75. True
  76. >>> _coeff_isneg(sqrt(2)*A)
  77. False
  78. """
  79. if a.is_MatMul:
  80. a = a.args[0]
  81. if a.is_Mul:
  82. a = a.args[0]
  83. return a.is_Number and a.is_extended_negative
  84. class PoleError(Exception):
  85. pass
  86. class ArgumentIndexError(ValueError):
  87. def __str__(self):
  88. return ("Invalid operation with argument number %s for Function %s" %
  89. (self.args[1], self.args[0]))
  90. class BadSignatureError(TypeError):
  91. '''Raised when a Lambda is created with an invalid signature'''
  92. pass
  93. class BadArgumentsError(TypeError):
  94. '''Raised when a Lambda is called with an incorrect number of arguments'''
  95. pass
  96. # Python 3 version that does not raise a Deprecation warning
  97. def arity(cls):
  98. """Return the arity of the function if it is known, else None.
  99. Explanation
  100. ===========
  101. When default values are specified for some arguments, they are
  102. optional and the arity is reported as a tuple of possible values.
  103. Examples
  104. ========
  105. >>> from sympy import arity, log
  106. >>> arity(lambda x: x)
  107. 1
  108. >>> arity(log)
  109. (1, 2)
  110. >>> arity(lambda *x: sum(x)) is None
  111. True
  112. """
  113. eval_ = getattr(cls, 'eval', cls)
  114. parameters = inspect.signature(eval_).parameters.items()
  115. if [p for _, p in parameters if p.kind == p.VAR_POSITIONAL]:
  116. return
  117. p_or_k = [p for _, p in parameters if p.kind == p.POSITIONAL_OR_KEYWORD]
  118. # how many have no default and how many have a default value
  119. no, yes = map(len, sift(p_or_k,
  120. lambda p:p.default == p.empty, binary=True))
  121. return no if not yes else tuple(range(no, no + yes + 1))
  122. class FunctionClass(type):
  123. """
  124. Base class for function classes. FunctionClass is a subclass of type.
  125. Use Function('<function name>' [ , signature ]) to create
  126. undefined function classes.
  127. """
  128. _new = type.__new__
  129. def __init__(cls, *args, **kwargs):
  130. # honor kwarg value or class-defined value before using
  131. # the number of arguments in the eval function (if present)
  132. nargs = kwargs.pop('nargs', cls.__dict__.get('nargs', arity(cls)))
  133. if nargs is None and 'nargs' not in cls.__dict__:
  134. for supcls in cls.__mro__:
  135. if hasattr(supcls, '_nargs'):
  136. nargs = supcls._nargs
  137. break
  138. else:
  139. continue
  140. # Canonicalize nargs here; change to set in nargs.
  141. if is_sequence(nargs):
  142. if not nargs:
  143. raise ValueError(filldedent('''
  144. Incorrectly specified nargs as %s:
  145. if there are no arguments, it should be
  146. `nargs = 0`;
  147. if there are any number of arguments,
  148. it should be
  149. `nargs = None`''' % str(nargs)))
  150. nargs = tuple(ordered(set(nargs)))
  151. elif nargs is not None:
  152. nargs = (as_int(nargs),)
  153. cls._nargs = nargs
  154. # When __init__ is called from UndefinedFunction it is called with
  155. # just one arg but when it is called from subclassing Function it is
  156. # called with the usual (name, bases, namespace) type() signature.
  157. if len(args) == 3:
  158. namespace = args[2]
  159. if 'eval' in namespace and not isinstance(namespace['eval'], classmethod):
  160. raise TypeError("eval on Function subclasses should be a class method (defined with @classmethod)")
  161. @property
  162. def __signature__(self):
  163. """
  164. Allow Python 3's inspect.signature to give a useful signature for
  165. Function subclasses.
  166. """
  167. # Python 3 only, but backports (like the one in IPython) still might
  168. # call this.
  169. try:
  170. from inspect import signature
  171. except ImportError:
  172. return None
  173. # TODO: Look at nargs
  174. return signature(self.eval)
  175. @property
  176. def free_symbols(self):
  177. return set()
  178. @property
  179. def xreplace(self):
  180. # Function needs args so we define a property that returns
  181. # a function that takes args...and then use that function
  182. # to return the right value
  183. return lambda rule, **_: rule.get(self, self)
  184. @property
  185. def nargs(self):
  186. """Return a set of the allowed number of arguments for the function.
  187. Examples
  188. ========
  189. >>> from sympy import Function
  190. >>> f = Function('f')
  191. If the function can take any number of arguments, the set of whole
  192. numbers is returned:
  193. >>> Function('f').nargs
  194. Naturals0
  195. If the function was initialized to accept one or more arguments, a
  196. corresponding set will be returned:
  197. >>> Function('f', nargs=1).nargs
  198. {1}
  199. >>> Function('f', nargs=(2, 1)).nargs
  200. {1, 2}
  201. The undefined function, after application, also has the nargs
  202. attribute; the actual number of arguments is always available by
  203. checking the ``args`` attribute:
  204. >>> f = Function('f')
  205. >>> f(1).nargs
  206. Naturals0
  207. >>> len(f(1).args)
  208. 1
  209. """
  210. from sympy.sets.sets import FiniteSet
  211. # XXX it would be nice to handle this in __init__ but there are import
  212. # problems with trying to import FiniteSet there
  213. return FiniteSet(*self._nargs) if self._nargs else S.Naturals0
  214. def _valid_nargs(self, n : int) -> bool:
  215. """ Return True if the specified integer is a valid number of arguments
  216. The number of arguments n is guaranteed to be an integer and positive
  217. """
  218. if self._nargs:
  219. return n in self._nargs
  220. nargs = self.nargs
  221. return nargs is S.Naturals0 or n in nargs
  222. def __repr__(cls):
  223. return cls.__name__
  224. class Application(Basic, metaclass=FunctionClass):
  225. """
  226. Base class for applied functions.
  227. Explanation
  228. ===========
  229. Instances of Application represent the result of applying an application of
  230. any type to any object.
  231. """
  232. is_Function = True
  233. @cacheit
  234. def __new__(cls, *args, **options):
  235. from sympy.sets.fancysets import Naturals0
  236. from sympy.sets.sets import FiniteSet
  237. args = list(map(sympify, args))
  238. evaluate = options.pop('evaluate', global_parameters.evaluate)
  239. # WildFunction (and anything else like it) may have nargs defined
  240. # and we throw that value away here
  241. options.pop('nargs', None)
  242. if options:
  243. raise ValueError("Unknown options: %s" % options)
  244. if evaluate:
  245. evaluated = cls.eval(*args)
  246. if evaluated is not None:
  247. return evaluated
  248. obj = super().__new__(cls, *args, **options)
  249. # make nargs uniform here
  250. sentinel = object()
  251. objnargs = getattr(obj, "nargs", sentinel)
  252. if objnargs is not sentinel:
  253. # things passing through here:
  254. # - functions subclassed from Function (e.g. myfunc(1).nargs)
  255. # - functions like cos(1).nargs
  256. # - AppliedUndef with given nargs like Function('f', nargs=1)(1).nargs
  257. # Canonicalize nargs here
  258. if is_sequence(objnargs):
  259. nargs = tuple(ordered(set(objnargs)))
  260. elif objnargs is not None:
  261. nargs = (as_int(objnargs),)
  262. else:
  263. nargs = None
  264. else:
  265. # things passing through here:
  266. # - WildFunction('f').nargs
  267. # - AppliedUndef with no nargs like Function('f')(1).nargs
  268. nargs = obj._nargs # note the underscore here
  269. # convert to FiniteSet
  270. obj.nargs = FiniteSet(*nargs) if nargs else Naturals0()
  271. return obj
  272. @classmethod
  273. def eval(cls, *args):
  274. """
  275. Returns a canonical form of cls applied to arguments args.
  276. Explanation
  277. ===========
  278. The ``eval()`` method is called when the class ``cls`` is about to be
  279. instantiated and it should return either some simplified instance
  280. (possible of some other class), or if the class ``cls`` should be
  281. unmodified, return None.
  282. Examples of ``eval()`` for the function "sign"
  283. .. code-block:: python
  284. @classmethod
  285. def eval(cls, arg):
  286. if arg is S.NaN:
  287. return S.NaN
  288. if arg.is_zero: return S.Zero
  289. if arg.is_positive: return S.One
  290. if arg.is_negative: return S.NegativeOne
  291. if isinstance(arg, Mul):
  292. coeff, terms = arg.as_coeff_Mul(rational=True)
  293. if coeff is not S.One:
  294. return cls(coeff) * cls(terms)
  295. """
  296. return
  297. @property
  298. def func(self):
  299. return self.__class__
  300. def _eval_subs(self, old, new):
  301. if (old.is_Function and new.is_Function and
  302. callable(old) and callable(new) and
  303. old == self.func and len(self.args) in new.nargs):
  304. return new(*[i._subs(old, new) for i in self.args])
  305. class Function(Application, Expr):
  306. r"""
  307. Base class for applied mathematical functions.
  308. It also serves as a constructor for undefined function classes.
  309. See the :ref:`custom-functions` guide for details on how to subclass
  310. ``Function`` and what methods can be defined.
  311. Examples
  312. ========
  313. **Undefined Functions**
  314. To create an undefined function, pass a string of the function name to
  315. ``Function``.
  316. >>> from sympy import Function, Symbol
  317. >>> x = Symbol('x')
  318. >>> f = Function('f')
  319. >>> g = Function('g')(x)
  320. >>> f
  321. f
  322. >>> f(x)
  323. f(x)
  324. >>> g
  325. g(x)
  326. >>> f(x).diff(x)
  327. Derivative(f(x), x)
  328. >>> g.diff(x)
  329. Derivative(g(x), x)
  330. Assumptions can be passed to ``Function`` the same as with a
  331. :class:`~.Symbol`. Alternatively, you can use a ``Symbol`` with
  332. assumptions for the function name and the function will inherit the name
  333. and assumptions associated with the ``Symbol``:
  334. >>> f_real = Function('f', real=True)
  335. >>> f_real(x).is_real
  336. True
  337. >>> f_real_inherit = Function(Symbol('f', real=True))
  338. >>> f_real_inherit(x).is_real
  339. True
  340. Note that assumptions on a function are unrelated to the assumptions on
  341. the variables it is called on. If you want to add a relationship, subclass
  342. ``Function`` and define custom assumptions handler methods. See the
  343. :ref:`custom-functions-assumptions` section of the :ref:`custom-functions`
  344. guide for more details.
  345. **Custom Function Subclasses**
  346. The :ref:`custom-functions` guide has several
  347. :ref:`custom-functions-complete-examples` of how to subclass ``Function``
  348. to create a custom function.
  349. """
  350. @property
  351. def _diff_wrt(self):
  352. return False
  353. @cacheit
  354. def __new__(cls, *args, **options):
  355. # Handle calls like Function('f')
  356. if cls is Function:
  357. return UndefinedFunction(*args, **options)
  358. n = len(args)
  359. if not cls._valid_nargs(n):
  360. # XXX: exception message must be in exactly this format to
  361. # make it work with NumPy's functions like vectorize(). See,
  362. # for example, https://github.com/numpy/numpy/issues/1697.
  363. # The ideal solution would be just to attach metadata to
  364. # the exception and change NumPy to take advantage of this.
  365. temp = ('%(name)s takes %(qual)s %(args)s '
  366. 'argument%(plural)s (%(given)s given)')
  367. raise TypeError(temp % {
  368. 'name': cls,
  369. 'qual': 'exactly' if len(cls.nargs) == 1 else 'at least',
  370. 'args': min(cls.nargs),
  371. 'plural': 's'*(min(cls.nargs) != 1),
  372. 'given': n})
  373. evaluate = options.get('evaluate', global_parameters.evaluate)
  374. result = super().__new__(cls, *args, **options)
  375. if evaluate and isinstance(result, cls) and result.args:
  376. _should_evalf = [cls._should_evalf(a) for a in result.args]
  377. pr2 = min(_should_evalf)
  378. if pr2 > 0:
  379. pr = max(_should_evalf)
  380. result = result.evalf(prec_to_dps(pr))
  381. return _sympify(result)
  382. @classmethod
  383. def _should_evalf(cls, arg):
  384. """
  385. Decide if the function should automatically evalf().
  386. Explanation
  387. ===========
  388. By default (in this implementation), this happens if (and only if) the
  389. ARG is a floating point number (including complex numbers).
  390. This function is used by __new__.
  391. Returns the precision to evalf to, or -1 if it should not evalf.
  392. """
  393. if arg.is_Float:
  394. return arg._prec
  395. if not arg.is_Add:
  396. return -1
  397. m = pure_complex(arg)
  398. if m is None:
  399. return -1
  400. # the elements of m are of type Number, so have a _prec
  401. return max(m[0]._prec, m[1]._prec)
  402. @classmethod
  403. def class_key(cls):
  404. from sympy.sets.fancysets import Naturals0
  405. funcs = {
  406. 'exp': 10,
  407. 'log': 11,
  408. 'sin': 20,
  409. 'cos': 21,
  410. 'tan': 22,
  411. 'cot': 23,
  412. 'sinh': 30,
  413. 'cosh': 31,
  414. 'tanh': 32,
  415. 'coth': 33,
  416. 'conjugate': 40,
  417. 're': 41,
  418. 'im': 42,
  419. 'arg': 43,
  420. }
  421. name = cls.__name__
  422. try:
  423. i = funcs[name]
  424. except KeyError:
  425. i = 0 if isinstance(cls.nargs, Naturals0) else 10000
  426. return 4, i, name
  427. def _eval_evalf(self, prec):
  428. def _get_mpmath_func(fname):
  429. """Lookup mpmath function based on name"""
  430. if isinstance(self, AppliedUndef):
  431. # Shouldn't lookup in mpmath but might have ._imp_
  432. return None
  433. if not hasattr(mpmath, fname):
  434. fname = MPMATH_TRANSLATIONS.get(fname, None)
  435. if fname is None:
  436. return None
  437. return getattr(mpmath, fname)
  438. _eval_mpmath = getattr(self, '_eval_mpmath', None)
  439. if _eval_mpmath is None:
  440. func = _get_mpmath_func(self.func.__name__)
  441. args = self.args
  442. else:
  443. func, args = _eval_mpmath()
  444. # Fall-back evaluation
  445. if func is None:
  446. imp = getattr(self, '_imp_', None)
  447. if imp is None:
  448. return None
  449. try:
  450. return Float(imp(*[i.evalf(prec) for i in self.args]), prec)
  451. except (TypeError, ValueError):
  452. return None
  453. # Convert all args to mpf or mpc
  454. # Convert the arguments to *higher* precision than requested for the
  455. # final result.
  456. # XXX + 5 is a guess, it is similar to what is used in evalf.py. Should
  457. # we be more intelligent about it?
  458. try:
  459. args = [arg._to_mpmath(prec + 5) for arg in args]
  460. def bad(m):
  461. from mpmath import mpf, mpc
  462. # the precision of an mpf value is the last element
  463. # if that is 1 (and m[1] is not 1 which would indicate a
  464. # power of 2), then the eval failed; so check that none of
  465. # the arguments failed to compute to a finite precision.
  466. # Note: An mpc value has two parts, the re and imag tuple;
  467. # check each of those parts, too. Anything else is allowed to
  468. # pass
  469. if isinstance(m, mpf):
  470. m = m._mpf_
  471. return m[1] !=1 and m[-1] == 1
  472. elif isinstance(m, mpc):
  473. m, n = m._mpc_
  474. return m[1] !=1 and m[-1] == 1 and \
  475. n[1] !=1 and n[-1] == 1
  476. else:
  477. return False
  478. if any(bad(a) for a in args):
  479. raise ValueError # one or more args failed to compute with significance
  480. except ValueError:
  481. return
  482. with mpmath.workprec(prec):
  483. v = func(*args)
  484. return Expr._from_mpmath(v, prec)
  485. def _eval_derivative(self, s):
  486. # f(x).diff(s) -> x.diff(s) * f.fdiff(1)(s)
  487. i = 0
  488. l = []
  489. for a in self.args:
  490. i += 1
  491. da = a.diff(s)
  492. if da.is_zero:
  493. continue
  494. try:
  495. df = self.fdiff(i)
  496. except ArgumentIndexError:
  497. df = Function.fdiff(self, i)
  498. l.append(df * da)
  499. return Add(*l)
  500. def _eval_is_commutative(self):
  501. return fuzzy_and(a.is_commutative for a in self.args)
  502. def _eval_is_meromorphic(self, x, a):
  503. if not self.args:
  504. return True
  505. if any(arg.has(x) for arg in self.args[1:]):
  506. return False
  507. arg = self.args[0]
  508. if not arg._eval_is_meromorphic(x, a):
  509. return None
  510. return fuzzy_not(type(self).is_singular(arg.subs(x, a)))
  511. _singularities: FuzzyBool | tuple[Expr, ...] = None
  512. @classmethod
  513. def is_singular(cls, a):
  514. """
  515. Tests whether the argument is an essential singularity
  516. or a branch point, or the functions is non-holomorphic.
  517. """
  518. ss = cls._singularities
  519. if ss in (True, None, False):
  520. return ss
  521. return fuzzy_or(a.is_infinite if s is S.ComplexInfinity
  522. else (a - s).is_zero for s in ss)
  523. def as_base_exp(self):
  524. """
  525. Returns the method as the 2-tuple (base, exponent).
  526. """
  527. return self, S.One
  528. def _eval_aseries(self, n, args0, x, logx):
  529. """
  530. Compute an asymptotic expansion around args0, in terms of self.args.
  531. This function is only used internally by _eval_nseries and should not
  532. be called directly; derived classes can overwrite this to implement
  533. asymptotic expansions.
  534. """
  535. raise PoleError(filldedent('''
  536. Asymptotic expansion of %s around %s is
  537. not implemented.''' % (type(self), args0)))
  538. def _eval_nseries(self, x, n, logx, cdir=0):
  539. """
  540. This function does compute series for multivariate functions,
  541. but the expansion is always in terms of *one* variable.
  542. Examples
  543. ========
  544. >>> from sympy import atan2
  545. >>> from sympy.abc import x, y
  546. >>> atan2(x, y).series(x, n=2)
  547. atan2(0, y) + x/y + O(x**2)
  548. >>> atan2(x, y).series(y, n=2)
  549. -y/x + atan2(x, 0) + O(y**2)
  550. This function also computes asymptotic expansions, if necessary
  551. and possible:
  552. >>> from sympy import loggamma
  553. >>> loggamma(1/x)._eval_nseries(x,0,None)
  554. -1/x - log(x)/x + log(x)/2 + O(1)
  555. """
  556. from .symbol import uniquely_named_symbol
  557. from sympy.series.order import Order
  558. from sympy.sets.sets import FiniteSet
  559. args = self.args
  560. args0 = [t.limit(x, 0) for t in args]
  561. if any(t.is_finite is False for t in args0):
  562. from .numbers import oo, zoo, nan
  563. a = [t.as_leading_term(x, logx=logx) for t in args]
  564. a0 = [t.limit(x, 0) for t in a]
  565. if any(t.has(oo, -oo, zoo, nan) for t in a0):
  566. return self._eval_aseries(n, args0, x, logx)
  567. # Careful: the argument goes to oo, but only logarithmically so. We
  568. # are supposed to do a power series expansion "around the
  569. # logarithmic term". e.g.
  570. # f(1+x+log(x))
  571. # -> f(1+logx) + x*f'(1+logx) + O(x**2)
  572. # where 'logx' is given in the argument
  573. a = [t._eval_nseries(x, n, logx) for t in args]
  574. z = [r - r0 for (r, r0) in zip(a, a0)]
  575. p = [Dummy() for _ in z]
  576. q = []
  577. v = None
  578. for ai, zi, pi in zip(a0, z, p):
  579. if zi.has(x):
  580. if v is not None:
  581. raise NotImplementedError
  582. q.append(ai + pi)
  583. v = pi
  584. else:
  585. q.append(ai)
  586. e1 = self.func(*q)
  587. if v is None:
  588. return e1
  589. s = e1._eval_nseries(v, n, logx)
  590. o = s.getO()
  591. s = s.removeO()
  592. s = s.subs(v, zi).expand() + Order(o.expr.subs(v, zi), x)
  593. return s
  594. if (self.func.nargs is S.Naturals0
  595. or (self.func.nargs == FiniteSet(1) and args0[0])
  596. or any(c > 1 for c in self.func.nargs)):
  597. e = self
  598. e1 = e.expand()
  599. if e == e1:
  600. #for example when e = sin(x+1) or e = sin(cos(x))
  601. #let's try the general algorithm
  602. if len(e.args) == 1:
  603. # issue 14411
  604. e = e.func(e.args[0].cancel())
  605. term = e.subs(x, S.Zero)
  606. if term.is_finite is False or term is S.NaN:
  607. raise PoleError("Cannot expand %s around 0" % (self))
  608. series = term
  609. fact = S.One
  610. _x = uniquely_named_symbol('xi', self)
  611. e = e.subs(x, _x)
  612. for i in range(1, n):
  613. fact *= Rational(i)
  614. e = e.diff(_x)
  615. subs = e.subs(_x, S.Zero)
  616. if subs is S.NaN:
  617. # try to evaluate a limit if we have to
  618. subs = e.limit(_x, S.Zero)
  619. if subs.is_finite is False:
  620. raise PoleError("Cannot expand %s around 0" % (self))
  621. term = subs*(x**i)/fact
  622. term = term.expand()
  623. series += term
  624. return series + Order(x**n, x)
  625. return e1.nseries(x, n=n, logx=logx)
  626. arg = self.args[0]
  627. l = []
  628. g = None
  629. # try to predict a number of terms needed
  630. nterms = n + 2
  631. cf = Order(arg.as_leading_term(x), x).getn()
  632. if cf != 0:
  633. nterms = (n/cf).ceiling()
  634. for i in range(nterms):
  635. g = self.taylor_term(i, arg, g)
  636. g = g.nseries(x, n=n, logx=logx)
  637. l.append(g)
  638. return Add(*l) + Order(x**n, x)
  639. def fdiff(self, argindex=1):
  640. """
  641. Returns the first derivative of the function.
  642. """
  643. if not (1 <= argindex <= len(self.args)):
  644. raise ArgumentIndexError(self, argindex)
  645. ix = argindex - 1
  646. A = self.args[ix]
  647. if A._diff_wrt:
  648. if len(self.args) == 1 or not A.is_Symbol:
  649. return _derivative_dispatch(self, A)
  650. for i, v in enumerate(self.args):
  651. if i != ix and A in v.free_symbols:
  652. # it can't be in any other argument's free symbols
  653. # issue 8510
  654. break
  655. else:
  656. return _derivative_dispatch(self, A)
  657. # See issue 4624 and issue 4719, 5600 and 8510
  658. D = Dummy('xi_%i' % argindex, dummy_index=hash(A))
  659. args = self.args[:ix] + (D,) + self.args[ix + 1:]
  660. return Subs(Derivative(self.func(*args), D), D, A)
  661. def _eval_as_leading_term(self, x, logx=None, cdir=0):
  662. """Stub that should be overridden by new Functions to return
  663. the first non-zero term in a series if ever an x-dependent
  664. argument whose leading term vanishes as x -> 0 might be encountered.
  665. See, for example, cos._eval_as_leading_term.
  666. """
  667. from sympy.series.order import Order
  668. args = [a.as_leading_term(x, logx=logx) for a in self.args]
  669. o = Order(1, x)
  670. if any(x in a.free_symbols and o.contains(a) for a in args):
  671. # Whereas x and any finite number are contained in O(1, x),
  672. # expressions like 1/x are not. If any arg simplified to a
  673. # vanishing expression as x -> 0 (like x or x**2, but not
  674. # 3, 1/x, etc...) then the _eval_as_leading_term is needed
  675. # to supply the first non-zero term of the series,
  676. #
  677. # e.g. expression leading term
  678. # ---------- ------------
  679. # cos(1/x) cos(1/x)
  680. # cos(cos(x)) cos(1)
  681. # cos(x) 1 <- _eval_as_leading_term needed
  682. # sin(x) x <- _eval_as_leading_term needed
  683. #
  684. raise NotImplementedError(
  685. '%s has no _eval_as_leading_term routine' % self.func)
  686. else:
  687. return self.func(*args)
  688. class AppliedUndef(Function):
  689. """
  690. Base class for expressions resulting from the application of an undefined
  691. function.
  692. """
  693. is_number = False
  694. def __new__(cls, *args, **options):
  695. args = list(map(sympify, args))
  696. u = [a.name for a in args if isinstance(a, UndefinedFunction)]
  697. if u:
  698. raise TypeError('Invalid argument: expecting an expression, not UndefinedFunction%s: %s' % (
  699. 's'*(len(u) > 1), ', '.join(u)))
  700. obj = super().__new__(cls, *args, **options)
  701. return obj
  702. def _eval_as_leading_term(self, x, logx=None, cdir=0):
  703. return self
  704. @property
  705. def _diff_wrt(self):
  706. """
  707. Allow derivatives wrt to undefined functions.
  708. Examples
  709. ========
  710. >>> from sympy import Function, Symbol
  711. >>> f = Function('f')
  712. >>> x = Symbol('x')
  713. >>> f(x)._diff_wrt
  714. True
  715. >>> f(x).diff(x)
  716. Derivative(f(x), x)
  717. """
  718. return True
  719. class UndefSageHelper:
  720. """
  721. Helper to facilitate Sage conversion.
  722. """
  723. def __get__(self, ins, typ):
  724. import sage.all as sage
  725. if ins is None:
  726. return lambda: sage.function(typ.__name__)
  727. else:
  728. args = [arg._sage_() for arg in ins.args]
  729. return lambda : sage.function(ins.__class__.__name__)(*args)
  730. _undef_sage_helper = UndefSageHelper()
  731. class UndefinedFunction(FunctionClass):
  732. """
  733. The (meta)class of undefined functions.
  734. """
  735. def __new__(mcl, name, bases=(AppliedUndef,), __dict__=None, **kwargs):
  736. from .symbol import _filter_assumptions
  737. # Allow Function('f', real=True)
  738. # and/or Function(Symbol('f', real=True))
  739. assumptions, kwargs = _filter_assumptions(kwargs)
  740. if isinstance(name, Symbol):
  741. assumptions = name._merge(assumptions)
  742. name = name.name
  743. elif not isinstance(name, str):
  744. raise TypeError('expecting string or Symbol for name')
  745. else:
  746. commutative = assumptions.get('commutative', None)
  747. assumptions = Symbol(name, **assumptions).assumptions0
  748. if commutative is None:
  749. assumptions.pop('commutative')
  750. __dict__ = __dict__ or {}
  751. # put the `is_*` for into __dict__
  752. __dict__.update({'is_%s' % k: v for k, v in assumptions.items()})
  753. # You can add other attributes, although they do have to be hashable
  754. # (but seriously, if you want to add anything other than assumptions,
  755. # just subclass Function)
  756. __dict__.update(kwargs)
  757. # add back the sanitized assumptions without the is_ prefix
  758. kwargs.update(assumptions)
  759. # Save these for __eq__
  760. __dict__.update({'_kwargs': kwargs})
  761. # do this for pickling
  762. __dict__['__module__'] = None
  763. obj = super().__new__(mcl, name, bases, __dict__)
  764. obj.name = name
  765. obj._sage_ = _undef_sage_helper
  766. return obj
  767. def __instancecheck__(cls, instance):
  768. return cls in type(instance).__mro__
  769. _kwargs: dict[str, bool | None] = {}
  770. def __hash__(self):
  771. return hash((self.class_key(), frozenset(self._kwargs.items())))
  772. def __eq__(self, other):
  773. return (isinstance(other, self.__class__) and
  774. self.class_key() == other.class_key() and
  775. self._kwargs == other._kwargs)
  776. def __ne__(self, other):
  777. return not self == other
  778. @property
  779. def _diff_wrt(self):
  780. return False
  781. # XXX: The type: ignore on WildFunction is because mypy complains:
  782. #
  783. # sympy/core/function.py:939: error: Cannot determine type of 'sort_key' in
  784. # base class 'Expr'
  785. #
  786. # Somehow this is because of the @cacheit decorator but it is not clear how to
  787. # fix it.
  788. class WildFunction(Function, AtomicExpr): # type: ignore
  789. """
  790. A WildFunction function matches any function (with its arguments).
  791. Examples
  792. ========
  793. >>> from sympy import WildFunction, Function, cos
  794. >>> from sympy.abc import x, y
  795. >>> F = WildFunction('F')
  796. >>> f = Function('f')
  797. >>> F.nargs
  798. Naturals0
  799. >>> x.match(F)
  800. >>> F.match(F)
  801. {F_: F_}
  802. >>> f(x).match(F)
  803. {F_: f(x)}
  804. >>> cos(x).match(F)
  805. {F_: cos(x)}
  806. >>> f(x, y).match(F)
  807. {F_: f(x, y)}
  808. To match functions with a given number of arguments, set ``nargs`` to the
  809. desired value at instantiation:
  810. >>> F = WildFunction('F', nargs=2)
  811. >>> F.nargs
  812. {2}
  813. >>> f(x).match(F)
  814. >>> f(x, y).match(F)
  815. {F_: f(x, y)}
  816. To match functions with a range of arguments, set ``nargs`` to a tuple
  817. containing the desired number of arguments, e.g. if ``nargs = (1, 2)``
  818. then functions with 1 or 2 arguments will be matched.
  819. >>> F = WildFunction('F', nargs=(1, 2))
  820. >>> F.nargs
  821. {1, 2}
  822. >>> f(x).match(F)
  823. {F_: f(x)}
  824. >>> f(x, y).match(F)
  825. {F_: f(x, y)}
  826. >>> f(x, y, 1).match(F)
  827. """
  828. # XXX: What is this class attribute used for?
  829. include: set[Any] = set()
  830. def __init__(cls, name, **assumptions):
  831. from sympy.sets.sets import Set, FiniteSet
  832. cls.name = name
  833. nargs = assumptions.pop('nargs', S.Naturals0)
  834. if not isinstance(nargs, Set):
  835. # Canonicalize nargs here. See also FunctionClass.
  836. if is_sequence(nargs):
  837. nargs = tuple(ordered(set(nargs)))
  838. elif nargs is not None:
  839. nargs = (as_int(nargs),)
  840. nargs = FiniteSet(*nargs)
  841. cls.nargs = nargs
  842. def matches(self, expr, repl_dict=None, old=False):
  843. if not isinstance(expr, (AppliedUndef, Function)):
  844. return None
  845. if len(expr.args) not in self.nargs:
  846. return None
  847. if repl_dict is None:
  848. repl_dict = {}
  849. else:
  850. repl_dict = repl_dict.copy()
  851. repl_dict[self] = expr
  852. return repl_dict
  853. class Derivative(Expr):
  854. """
  855. Carries out differentiation of the given expression with respect to symbols.
  856. Examples
  857. ========
  858. >>> from sympy import Derivative, Function, symbols, Subs
  859. >>> from sympy.abc import x, y
  860. >>> f, g = symbols('f g', cls=Function)
  861. >>> Derivative(x**2, x, evaluate=True)
  862. 2*x
  863. Denesting of derivatives retains the ordering of variables:
  864. >>> Derivative(Derivative(f(x, y), y), x)
  865. Derivative(f(x, y), y, x)
  866. Contiguously identical symbols are merged into a tuple giving
  867. the symbol and the count:
  868. >>> Derivative(f(x), x, x, y, x)
  869. Derivative(f(x), (x, 2), y, x)
  870. If the derivative cannot be performed, and evaluate is True, the
  871. order of the variables of differentiation will be made canonical:
  872. >>> Derivative(f(x, y), y, x, evaluate=True)
  873. Derivative(f(x, y), x, y)
  874. Derivatives with respect to undefined functions can be calculated:
  875. >>> Derivative(f(x)**2, f(x), evaluate=True)
  876. 2*f(x)
  877. Such derivatives will show up when the chain rule is used to
  878. evalulate a derivative:
  879. >>> f(g(x)).diff(x)
  880. Derivative(f(g(x)), g(x))*Derivative(g(x), x)
  881. Substitution is used to represent derivatives of functions with
  882. arguments that are not symbols or functions:
  883. >>> f(2*x + 3).diff(x) == 2*Subs(f(y).diff(y), y, 2*x + 3)
  884. True
  885. Notes
  886. =====
  887. Simplification of high-order derivatives:
  888. Because there can be a significant amount of simplification that can be
  889. done when multiple differentiations are performed, results will be
  890. automatically simplified in a fairly conservative fashion unless the
  891. keyword ``simplify`` is set to False.
  892. >>> from sympy import sqrt, diff, Function, symbols
  893. >>> from sympy.abc import x, y, z
  894. >>> f, g = symbols('f,g', cls=Function)
  895. >>> e = sqrt((x + 1)**2 + x)
  896. >>> diff(e, (x, 5), simplify=False).count_ops()
  897. 136
  898. >>> diff(e, (x, 5)).count_ops()
  899. 30
  900. Ordering of variables:
  901. If evaluate is set to True and the expression cannot be evaluated, the
  902. list of differentiation symbols will be sorted, that is, the expression is
  903. assumed to have continuous derivatives up to the order asked.
  904. Derivative wrt non-Symbols:
  905. For the most part, one may not differentiate wrt non-symbols.
  906. For example, we do not allow differentiation wrt `x*y` because
  907. there are multiple ways of structurally defining where x*y appears
  908. in an expression: a very strict definition would make
  909. (x*y*z).diff(x*y) == 0. Derivatives wrt defined functions (like
  910. cos(x)) are not allowed, either:
  911. >>> (x*y*z).diff(x*y)
  912. Traceback (most recent call last):
  913. ...
  914. ValueError: Can't calculate derivative wrt x*y.
  915. To make it easier to work with variational calculus, however,
  916. derivatives wrt AppliedUndef and Derivatives are allowed.
  917. For example, in the Euler-Lagrange method one may write
  918. F(t, u, v) where u = f(t) and v = f'(t). These variables can be
  919. written explicitly as functions of time::
  920. >>> from sympy.abc import t
  921. >>> F = Function('F')
  922. >>> U = f(t)
  923. >>> V = U.diff(t)
  924. The derivative wrt f(t) can be obtained directly:
  925. >>> direct = F(t, U, V).diff(U)
  926. When differentiation wrt a non-Symbol is attempted, the non-Symbol
  927. is temporarily converted to a Symbol while the differentiation
  928. is performed and the same answer is obtained:
  929. >>> indirect = F(t, U, V).subs(U, x).diff(x).subs(x, U)
  930. >>> assert direct == indirect
  931. The implication of this non-symbol replacement is that all
  932. functions are treated as independent of other functions and the
  933. symbols are independent of the functions that contain them::
  934. >>> x.diff(f(x))
  935. 0
  936. >>> g(x).diff(f(x))
  937. 0
  938. It also means that derivatives are assumed to depend only
  939. on the variables of differentiation, not on anything contained
  940. within the expression being differentiated::
  941. >>> F = f(x)
  942. >>> Fx = F.diff(x)
  943. >>> Fx.diff(F) # derivative depends on x, not F
  944. 0
  945. >>> Fxx = Fx.diff(x)
  946. >>> Fxx.diff(Fx) # derivative depends on x, not Fx
  947. 0
  948. The last example can be made explicit by showing the replacement
  949. of Fx in Fxx with y:
  950. >>> Fxx.subs(Fx, y)
  951. Derivative(y, x)
  952. Since that in itself will evaluate to zero, differentiating
  953. wrt Fx will also be zero:
  954. >>> _.doit()
  955. 0
  956. Replacing undefined functions with concrete expressions
  957. One must be careful to replace undefined functions with expressions
  958. that contain variables consistent with the function definition and
  959. the variables of differentiation or else insconsistent result will
  960. be obtained. Consider the following example:
  961. >>> eq = f(x)*g(y)
  962. >>> eq.subs(f(x), x*y).diff(x, y).doit()
  963. y*Derivative(g(y), y) + g(y)
  964. >>> eq.diff(x, y).subs(f(x), x*y).doit()
  965. y*Derivative(g(y), y)
  966. The results differ because `f(x)` was replaced with an expression
  967. that involved both variables of differentiation. In the abstract
  968. case, differentiation of `f(x)` by `y` is 0; in the concrete case,
  969. the presence of `y` made that derivative nonvanishing and produced
  970. the extra `g(y)` term.
  971. Defining differentiation for an object
  972. An object must define ._eval_derivative(symbol) method that returns
  973. the differentiation result. This function only needs to consider the
  974. non-trivial case where expr contains symbol and it should call the diff()
  975. method internally (not _eval_derivative); Derivative should be the only
  976. one to call _eval_derivative.
  977. Any class can allow derivatives to be taken with respect to
  978. itself (while indicating its scalar nature). See the
  979. docstring of Expr._diff_wrt.
  980. See Also
  981. ========
  982. _sort_variable_count
  983. """
  984. is_Derivative = True
  985. @property
  986. def _diff_wrt(self):
  987. """An expression may be differentiated wrt a Derivative if
  988. it is in elementary form.
  989. Examples
  990. ========
  991. >>> from sympy import Function, Derivative, cos
  992. >>> from sympy.abc import x
  993. >>> f = Function('f')
  994. >>> Derivative(f(x), x)._diff_wrt
  995. True
  996. >>> Derivative(cos(x), x)._diff_wrt
  997. False
  998. >>> Derivative(x + 1, x)._diff_wrt
  999. False
  1000. A Derivative might be an unevaluated form of what will not be
  1001. a valid variable of differentiation if evaluated. For example,
  1002. >>> Derivative(f(f(x)), x).doit()
  1003. Derivative(f(x), x)*Derivative(f(f(x)), f(x))
  1004. Such an expression will present the same ambiguities as arise
  1005. when dealing with any other product, like ``2*x``, so ``_diff_wrt``
  1006. is False:
  1007. >>> Derivative(f(f(x)), x)._diff_wrt
  1008. False
  1009. """
  1010. return self.expr._diff_wrt and isinstance(self.doit(), Derivative)
  1011. def __new__(cls, expr, *variables, **kwargs):
  1012. expr = sympify(expr)
  1013. symbols_or_none = getattr(expr, "free_symbols", None)
  1014. has_symbol_set = isinstance(symbols_or_none, set)
  1015. if not has_symbol_set:
  1016. raise ValueError(filldedent('''
  1017. Since there are no variables in the expression %s,
  1018. it cannot be differentiated.''' % expr))
  1019. # determine value for variables if it wasn't given
  1020. if not variables:
  1021. variables = expr.free_symbols
  1022. if len(variables) != 1:
  1023. if expr.is_number:
  1024. return S.Zero
  1025. if len(variables) == 0:
  1026. raise ValueError(filldedent('''
  1027. Since there are no variables in the expression,
  1028. the variable(s) of differentiation must be supplied
  1029. to differentiate %s''' % expr))
  1030. else:
  1031. raise ValueError(filldedent('''
  1032. Since there is more than one variable in the
  1033. expression, the variable(s) of differentiation
  1034. must be supplied to differentiate %s''' % expr))
  1035. # Split the list of variables into a list of the variables we are diff
  1036. # wrt, where each element of the list has the form (s, count) where
  1037. # s is the entity to diff wrt and count is the order of the
  1038. # derivative.
  1039. variable_count = []
  1040. array_likes = (tuple, list, Tuple)
  1041. from sympy.tensor.array import Array, NDimArray
  1042. for i, v in enumerate(variables):
  1043. if isinstance(v, UndefinedFunction):
  1044. raise TypeError(
  1045. "cannot differentiate wrt "
  1046. "UndefinedFunction: %s" % v)
  1047. if isinstance(v, array_likes):
  1048. if len(v) == 0:
  1049. # Ignore empty tuples: Derivative(expr, ... , (), ... )
  1050. continue
  1051. if isinstance(v[0], array_likes):
  1052. # Derive by array: Derivative(expr, ... , [[x, y, z]], ... )
  1053. if len(v) == 1:
  1054. v = Array(v[0])
  1055. count = 1
  1056. else:
  1057. v, count = v
  1058. v = Array(v)
  1059. else:
  1060. v, count = v
  1061. if count == 0:
  1062. continue
  1063. variable_count.append(Tuple(v, count))
  1064. continue
  1065. v = sympify(v)
  1066. if isinstance(v, Integer):
  1067. if i == 0:
  1068. raise ValueError("First variable cannot be a number: %i" % v)
  1069. count = v
  1070. prev, prevcount = variable_count[-1]
  1071. if prevcount != 1:
  1072. raise TypeError("tuple {} followed by number {}".format((prev, prevcount), v))
  1073. if count == 0:
  1074. variable_count.pop()
  1075. else:
  1076. variable_count[-1] = Tuple(prev, count)
  1077. else:
  1078. count = 1
  1079. variable_count.append(Tuple(v, count))
  1080. # light evaluation of contiguous, identical
  1081. # items: (x, 1), (x, 1) -> (x, 2)
  1082. merged = []
  1083. for t in variable_count:
  1084. v, c = t
  1085. if c.is_negative:
  1086. raise ValueError(
  1087. 'order of differentiation must be nonnegative')
  1088. if merged and merged[-1][0] == v:
  1089. c += merged[-1][1]
  1090. if not c:
  1091. merged.pop()
  1092. else:
  1093. merged[-1] = Tuple(v, c)
  1094. else:
  1095. merged.append(t)
  1096. variable_count = merged
  1097. # sanity check of variables of differentation; we waited
  1098. # until the counts were computed since some variables may
  1099. # have been removed because the count was 0
  1100. for v, c in variable_count:
  1101. # v must have _diff_wrt True
  1102. if not v._diff_wrt:
  1103. __ = '' # filler to make error message neater
  1104. raise ValueError(filldedent('''
  1105. Can't calculate derivative wrt %s.%s''' % (v,
  1106. __)))
  1107. # We make a special case for 0th derivative, because there is no
  1108. # good way to unambiguously print this.
  1109. if len(variable_count) == 0:
  1110. return expr
  1111. evaluate = kwargs.get('evaluate', False)
  1112. if evaluate:
  1113. if isinstance(expr, Derivative):
  1114. expr = expr.canonical
  1115. variable_count = [
  1116. (v.canonical if isinstance(v, Derivative) else v, c)
  1117. for v, c in variable_count]
  1118. # Look for a quick exit if there are symbols that don't appear in
  1119. # expression at all. Note, this cannot check non-symbols like
  1120. # Derivatives as those can be created by intermediate
  1121. # derivatives.
  1122. zero = False
  1123. free = expr.free_symbols
  1124. from sympy.matrices.expressions.matexpr import MatrixExpr
  1125. for v, c in variable_count:
  1126. vfree = v.free_symbols
  1127. if c.is_positive and vfree:
  1128. if isinstance(v, AppliedUndef):
  1129. # these match exactly since
  1130. # x.diff(f(x)) == g(x).diff(f(x)) == 0
  1131. # and are not created by differentiation
  1132. D = Dummy()
  1133. if not expr.xreplace({v: D}).has(D):
  1134. zero = True
  1135. break
  1136. elif isinstance(v, MatrixExpr):
  1137. zero = False
  1138. break
  1139. elif isinstance(v, Symbol) and v not in free:
  1140. zero = True
  1141. break
  1142. else:
  1143. if not free & vfree:
  1144. # e.g. v is IndexedBase or Matrix
  1145. zero = True
  1146. break
  1147. if zero:
  1148. return cls._get_zero_with_shape_like(expr)
  1149. # make the order of symbols canonical
  1150. #TODO: check if assumption of discontinuous derivatives exist
  1151. variable_count = cls._sort_variable_count(variable_count)
  1152. # denest
  1153. if isinstance(expr, Derivative):
  1154. variable_count = list(expr.variable_count) + variable_count
  1155. expr = expr.expr
  1156. return _derivative_dispatch(expr, *variable_count, **kwargs)
  1157. # we return here if evaluate is False or if there is no
  1158. # _eval_derivative method
  1159. if not evaluate or not hasattr(expr, '_eval_derivative'):
  1160. # return an unevaluated Derivative
  1161. if evaluate and variable_count == [(expr, 1)] and expr.is_scalar:
  1162. # special hack providing evaluation for classes
  1163. # that have defined is_scalar=True but have no
  1164. # _eval_derivative defined
  1165. return S.One
  1166. return Expr.__new__(cls, expr, *variable_count)
  1167. # evaluate the derivative by calling _eval_derivative method
  1168. # of expr for each variable
  1169. # -------------------------------------------------------------
  1170. nderivs = 0 # how many derivatives were performed
  1171. unhandled = []
  1172. from sympy.matrices.common import MatrixCommon
  1173. for i, (v, count) in enumerate(variable_count):
  1174. old_expr = expr
  1175. old_v = None
  1176. is_symbol = v.is_symbol or isinstance(v,
  1177. (Iterable, Tuple, MatrixCommon, NDimArray))
  1178. if not is_symbol:
  1179. old_v = v
  1180. v = Dummy('xi')
  1181. expr = expr.xreplace({old_v: v})
  1182. # Derivatives and UndefinedFunctions are independent
  1183. # of all others
  1184. clashing = not (isinstance(old_v, Derivative) or \
  1185. isinstance(old_v, AppliedUndef))
  1186. if v not in expr.free_symbols and not clashing:
  1187. return expr.diff(v) # expr's version of 0
  1188. if not old_v.is_scalar and not hasattr(
  1189. old_v, '_eval_derivative'):
  1190. # special hack providing evaluation for classes
  1191. # that have defined is_scalar=True but have no
  1192. # _eval_derivative defined
  1193. expr *= old_v.diff(old_v)
  1194. obj = cls._dispatch_eval_derivative_n_times(expr, v, count)
  1195. if obj is not None and obj.is_zero:
  1196. return obj
  1197. nderivs += count
  1198. if old_v is not None:
  1199. if obj is not None:
  1200. # remove the dummy that was used
  1201. obj = obj.subs(v, old_v)
  1202. # restore expr
  1203. expr = old_expr
  1204. if obj is None:
  1205. # we've already checked for quick-exit conditions
  1206. # that give 0 so the remaining variables
  1207. # are contained in the expression but the expression
  1208. # did not compute a derivative so we stop taking
  1209. # derivatives
  1210. unhandled = variable_count[i:]
  1211. break
  1212. expr = obj
  1213. # what we have so far can be made canonical
  1214. expr = expr.replace(
  1215. lambda x: isinstance(x, Derivative),
  1216. lambda x: x.canonical)
  1217. if unhandled:
  1218. if isinstance(expr, Derivative):
  1219. unhandled = list(expr.variable_count) + unhandled
  1220. expr = expr.expr
  1221. expr = Expr.__new__(cls, expr, *unhandled)
  1222. if (nderivs > 1) == True and kwargs.get('simplify', True):
  1223. from .exprtools import factor_terms
  1224. from sympy.simplify.simplify import signsimp
  1225. expr = factor_terms(signsimp(expr))
  1226. return expr
  1227. @property
  1228. def canonical(cls):
  1229. return cls.func(cls.expr,
  1230. *Derivative._sort_variable_count(cls.variable_count))
  1231. @classmethod
  1232. def _sort_variable_count(cls, vc):
  1233. """
  1234. Sort (variable, count) pairs into canonical order while
  1235. retaining order of variables that do not commute during
  1236. differentiation:
  1237. * symbols and functions commute with each other
  1238. * derivatives commute with each other
  1239. * a derivative does not commute with anything it contains
  1240. * any other object is not allowed to commute if it has
  1241. free symbols in common with another object
  1242. Examples
  1243. ========
  1244. >>> from sympy import Derivative, Function, symbols
  1245. >>> vsort = Derivative._sort_variable_count
  1246. >>> x, y, z = symbols('x y z')
  1247. >>> f, g, h = symbols('f g h', cls=Function)
  1248. Contiguous items are collapsed into one pair:
  1249. >>> vsort([(x, 1), (x, 1)])
  1250. [(x, 2)]
  1251. >>> vsort([(y, 1), (f(x), 1), (y, 1), (f(x), 1)])
  1252. [(y, 2), (f(x), 2)]
  1253. Ordering is canonical.
  1254. >>> def vsort0(*v):
  1255. ... # docstring helper to
  1256. ... # change vi -> (vi, 0), sort, and return vi vals
  1257. ... return [i[0] for i in vsort([(i, 0) for i in v])]
  1258. >>> vsort0(y, x)
  1259. [x, y]
  1260. >>> vsort0(g(y), g(x), f(y))
  1261. [f(y), g(x), g(y)]
  1262. Symbols are sorted as far to the left as possible but never
  1263. move to the left of a derivative having the same symbol in
  1264. its variables; the same applies to AppliedUndef which are
  1265. always sorted after Symbols:
  1266. >>> dfx = f(x).diff(x)
  1267. >>> assert vsort0(dfx, y) == [y, dfx]
  1268. >>> assert vsort0(dfx, x) == [dfx, x]
  1269. """
  1270. if not vc:
  1271. return []
  1272. vc = list(vc)
  1273. if len(vc) == 1:
  1274. return [Tuple(*vc[0])]
  1275. V = list(range(len(vc)))
  1276. E = []
  1277. v = lambda i: vc[i][0]
  1278. D = Dummy()
  1279. def _block(d, v, wrt=False):
  1280. # return True if v should not come before d else False
  1281. if d == v:
  1282. return wrt
  1283. if d.is_Symbol:
  1284. return False
  1285. if isinstance(d, Derivative):
  1286. # a derivative blocks if any of it's variables contain
  1287. # v; the wrt flag will return True for an exact match
  1288. # and will cause an AppliedUndef to block if v is in
  1289. # the arguments
  1290. if any(_block(k, v, wrt=True)
  1291. for k in d._wrt_variables):
  1292. return True
  1293. return False
  1294. if not wrt and isinstance(d, AppliedUndef):
  1295. return False
  1296. if v.is_Symbol:
  1297. return v in d.free_symbols
  1298. if isinstance(v, AppliedUndef):
  1299. return _block(d.xreplace({v: D}), D)
  1300. return d.free_symbols & v.free_symbols
  1301. for i in range(len(vc)):
  1302. for j in range(i):
  1303. if _block(v(j), v(i)):
  1304. E.append((j,i))
  1305. # this is the default ordering to use in case of ties
  1306. O = dict(zip(ordered(uniq([i for i, c in vc])), range(len(vc))))
  1307. ix = topological_sort((V, E), key=lambda i: O[v(i)])
  1308. # merge counts of contiguously identical items
  1309. merged = []
  1310. for v, c in [vc[i] for i in ix]:
  1311. if merged and merged[-1][0] == v:
  1312. merged[-1][1] += c
  1313. else:
  1314. merged.append([v, c])
  1315. return [Tuple(*i) for i in merged]
  1316. def _eval_is_commutative(self):
  1317. return self.expr.is_commutative
  1318. def _eval_derivative(self, v):
  1319. # If v (the variable of differentiation) is not in
  1320. # self.variables, we might be able to take the derivative.
  1321. if v not in self._wrt_variables:
  1322. dedv = self.expr.diff(v)
  1323. if isinstance(dedv, Derivative):
  1324. return dedv.func(dedv.expr, *(self.variable_count + dedv.variable_count))
  1325. # dedv (d(self.expr)/dv) could have simplified things such that the
  1326. # derivative wrt things in self.variables can now be done. Thus,
  1327. # we set evaluate=True to see if there are any other derivatives
  1328. # that can be done. The most common case is when dedv is a simple
  1329. # number so that the derivative wrt anything else will vanish.
  1330. return self.func(dedv, *self.variables, evaluate=True)
  1331. # In this case v was in self.variables so the derivative wrt v has
  1332. # already been attempted and was not computed, either because it
  1333. # couldn't be or evaluate=False originally.
  1334. variable_count = list(self.variable_count)
  1335. variable_count.append((v, 1))
  1336. return self.func(self.expr, *variable_count, evaluate=False)
  1337. def doit(self, **hints):
  1338. expr = self.expr
  1339. if hints.get('deep', True):
  1340. expr = expr.doit(**hints)
  1341. hints['evaluate'] = True
  1342. rv = self.func(expr, *self.variable_count, **hints)
  1343. if rv!= self and rv.has(Derivative):
  1344. rv = rv.doit(**hints)
  1345. return rv
  1346. @_sympifyit('z0', NotImplementedError)
  1347. def doit_numerically(self, z0):
  1348. """
  1349. Evaluate the derivative at z numerically.
  1350. When we can represent derivatives at a point, this should be folded
  1351. into the normal evalf. For now, we need a special method.
  1352. """
  1353. if len(self.free_symbols) != 1 or len(self.variables) != 1:
  1354. raise NotImplementedError('partials and higher order derivatives')
  1355. z = list(self.free_symbols)[0]
  1356. def eval(x):
  1357. f0 = self.expr.subs(z, Expr._from_mpmath(x, prec=mpmath.mp.prec))
  1358. f0 = f0.evalf(prec_to_dps(mpmath.mp.prec))
  1359. return f0._to_mpmath(mpmath.mp.prec)
  1360. return Expr._from_mpmath(mpmath.diff(eval,
  1361. z0._to_mpmath(mpmath.mp.prec)),
  1362. mpmath.mp.prec)
  1363. @property
  1364. def expr(self):
  1365. return self._args[0]
  1366. @property
  1367. def _wrt_variables(self):
  1368. # return the variables of differentiation without
  1369. # respect to the type of count (int or symbolic)
  1370. return [i[0] for i in self.variable_count]
  1371. @property
  1372. def variables(self):
  1373. # TODO: deprecate? YES, make this 'enumerated_variables' and
  1374. # name _wrt_variables as variables
  1375. # TODO: support for `d^n`?
  1376. rv = []
  1377. for v, count in self.variable_count:
  1378. if not count.is_Integer:
  1379. raise TypeError(filldedent('''
  1380. Cannot give expansion for symbolic count. If you just
  1381. want a list of all variables of differentiation, use
  1382. _wrt_variables.'''))
  1383. rv.extend([v]*count)
  1384. return tuple(rv)
  1385. @property
  1386. def variable_count(self):
  1387. return self._args[1:]
  1388. @property
  1389. def derivative_count(self):
  1390. return sum([count for _, count in self.variable_count], 0)
  1391. @property
  1392. def free_symbols(self):
  1393. ret = self.expr.free_symbols
  1394. # Add symbolic counts to free_symbols
  1395. for _, count in self.variable_count:
  1396. ret.update(count.free_symbols)
  1397. return ret
  1398. @property
  1399. def kind(self):
  1400. return self.args[0].kind
  1401. def _eval_subs(self, old, new):
  1402. # The substitution (old, new) cannot be done inside
  1403. # Derivative(expr, vars) for a variety of reasons
  1404. # as handled below.
  1405. if old in self._wrt_variables:
  1406. # first handle the counts
  1407. expr = self.func(self.expr, *[(v, c.subs(old, new))
  1408. for v, c in self.variable_count])
  1409. if expr != self:
  1410. return expr._eval_subs(old, new)
  1411. # quick exit case
  1412. if not getattr(new, '_diff_wrt', False):
  1413. # case (0): new is not a valid variable of
  1414. # differentiation
  1415. if isinstance(old, Symbol):
  1416. # don't introduce a new symbol if the old will do
  1417. return Subs(self, old, new)
  1418. else:
  1419. xi = Dummy('xi')
  1420. return Subs(self.xreplace({old: xi}), xi, new)
  1421. # If both are Derivatives with the same expr, check if old is
  1422. # equivalent to self or if old is a subderivative of self.
  1423. if old.is_Derivative and old.expr == self.expr:
  1424. if self.canonical == old.canonical:
  1425. return new
  1426. # collections.Counter doesn't have __le__
  1427. def _subset(a, b):
  1428. return all((a[i] <= b[i]) == True for i in a)
  1429. old_vars = Counter(dict(reversed(old.variable_count)))
  1430. self_vars = Counter(dict(reversed(self.variable_count)))
  1431. if _subset(old_vars, self_vars):
  1432. return _derivative_dispatch(new, *(self_vars - old_vars).items()).canonical
  1433. args = list(self.args)
  1434. newargs = [x._subs(old, new) for x in args]
  1435. if args[0] == old:
  1436. # complete replacement of self.expr
  1437. # we already checked that the new is valid so we know
  1438. # it won't be a problem should it appear in variables
  1439. return _derivative_dispatch(*newargs)
  1440. if newargs[0] != args[0]:
  1441. # case (1) can't change expr by introducing something that is in
  1442. # the _wrt_variables if it was already in the expr
  1443. # e.g.
  1444. # for Derivative(f(x, g(y)), y), x cannot be replaced with
  1445. # anything that has y in it; for f(g(x), g(y)).diff(g(y))
  1446. # g(x) cannot be replaced with anything that has g(y)
  1447. syms = {vi: Dummy() for vi in self._wrt_variables
  1448. if not vi.is_Symbol}
  1449. wrt = {syms.get(vi, vi) for vi in self._wrt_variables}
  1450. forbidden = args[0].xreplace(syms).free_symbols & wrt
  1451. nfree = new.xreplace(syms).free_symbols
  1452. ofree = old.xreplace(syms).free_symbols
  1453. if (nfree - ofree) & forbidden:
  1454. return Subs(self, old, new)
  1455. viter = ((i, j) for ((i, _), (j, _)) in zip(newargs[1:], args[1:]))
  1456. if any(i != j for i, j in viter): # a wrt-variable change
  1457. # case (2) can't change vars by introducing a variable
  1458. # that is contained in expr, e.g.
  1459. # for Derivative(f(z, g(h(x), y)), y), y cannot be changed to
  1460. # x, h(x), or g(h(x), y)
  1461. for a in _atomic(self.expr, recursive=True):
  1462. for i in range(1, len(newargs)):
  1463. vi, _ = newargs[i]
  1464. if a == vi and vi != args[i][0]:
  1465. return Subs(self, old, new)
  1466. # more arg-wise checks
  1467. vc = newargs[1:]
  1468. oldv = self._wrt_variables
  1469. newe = self.expr
  1470. subs = []
  1471. for i, (vi, ci) in enumerate(vc):
  1472. if not vi._diff_wrt:
  1473. # case (3) invalid differentiation expression so
  1474. # create a replacement dummy
  1475. xi = Dummy('xi_%i' % i)
  1476. # replace the old valid variable with the dummy
  1477. # in the expression
  1478. newe = newe.xreplace({oldv[i]: xi})
  1479. # and replace the bad variable with the dummy
  1480. vc[i] = (xi, ci)
  1481. # and record the dummy with the new (invalid)
  1482. # differentiation expression
  1483. subs.append((xi, vi))
  1484. if subs:
  1485. # handle any residual substitution in the expression
  1486. newe = newe._subs(old, new)
  1487. # return the Subs-wrapped derivative
  1488. return Subs(Derivative(newe, *vc), *zip(*subs))
  1489. # everything was ok
  1490. return _derivative_dispatch(*newargs)
  1491. def _eval_lseries(self, x, logx, cdir=0):
  1492. dx = self.variables
  1493. for term in self.expr.lseries(x, logx=logx, cdir=cdir):
  1494. yield self.func(term, *dx)
  1495. def _eval_nseries(self, x, n, logx, cdir=0):
  1496. arg = self.expr.nseries(x, n=n, logx=logx)
  1497. o = arg.getO()
  1498. dx = self.variables
  1499. rv = [self.func(a, *dx) for a in Add.make_args(arg.removeO())]
  1500. if o:
  1501. rv.append(o/x)
  1502. return Add(*rv)
  1503. def _eval_as_leading_term(self, x, logx=None, cdir=0):
  1504. series_gen = self.expr.lseries(x)
  1505. d = S.Zero
  1506. for leading_term in series_gen:
  1507. d = diff(leading_term, *self.variables)
  1508. if d != 0:
  1509. break
  1510. return d
  1511. def as_finite_difference(self, points=1, x0=None, wrt=None):
  1512. """ Expresses a Derivative instance as a finite difference.
  1513. Parameters
  1514. ==========
  1515. points : sequence or coefficient, optional
  1516. If sequence: discrete values (length >= order+1) of the
  1517. independent variable used for generating the finite
  1518. difference weights.
  1519. If it is a coefficient, it will be used as the step-size
  1520. for generating an equidistant sequence of length order+1
  1521. centered around ``x0``. Default: 1 (step-size 1)
  1522. x0 : number or Symbol, optional
  1523. the value of the independent variable (``wrt``) at which the
  1524. derivative is to be approximated. Default: same as ``wrt``.
  1525. wrt : Symbol, optional
  1526. "with respect to" the variable for which the (partial)
  1527. derivative is to be approximated for. If not provided it
  1528. is required that the derivative is ordinary. Default: ``None``.
  1529. Examples
  1530. ========
  1531. >>> from sympy import symbols, Function, exp, sqrt, Symbol
  1532. >>> x, h = symbols('x h')
  1533. >>> f = Function('f')
  1534. >>> f(x).diff(x).as_finite_difference()
  1535. -f(x - 1/2) + f(x + 1/2)
  1536. The default step size and number of points are 1 and
  1537. ``order + 1`` respectively. We can change the step size by
  1538. passing a symbol as a parameter:
  1539. >>> f(x).diff(x).as_finite_difference(h)
  1540. -f(-h/2 + x)/h + f(h/2 + x)/h
  1541. We can also specify the discretized values to be used in a
  1542. sequence:
  1543. >>> f(x).diff(x).as_finite_difference([x, x+h, x+2*h])
  1544. -3*f(x)/(2*h) + 2*f(h + x)/h - f(2*h + x)/(2*h)
  1545. The algorithm is not restricted to use equidistant spacing, nor
  1546. do we need to make the approximation around ``x0``, but we can get
  1547. an expression estimating the derivative at an offset:
  1548. >>> e, sq2 = exp(1), sqrt(2)
  1549. >>> xl = [x-h, x+h, x+e*h]
  1550. >>> f(x).diff(x, 1).as_finite_difference(xl, x+h*sq2) # doctest: +ELLIPSIS
  1551. 2*h*((h + sqrt(2)*h)/(2*h) - (-sqrt(2)*h + h)/(2*h))*f(E*h + x)/...
  1552. To approximate ``Derivative`` around ``x0`` using a non-equidistant
  1553. spacing step, the algorithm supports assignment of undefined
  1554. functions to ``points``:
  1555. >>> dx = Function('dx')
  1556. >>> f(x).diff(x).as_finite_difference(points=dx(x), x0=x-h)
  1557. -f(-h + x - dx(-h + x)/2)/dx(-h + x) + f(-h + x + dx(-h + x)/2)/dx(-h + x)
  1558. Partial derivatives are also supported:
  1559. >>> y = Symbol('y')
  1560. >>> d2fdxdy=f(x,y).diff(x,y)
  1561. >>> d2fdxdy.as_finite_difference(wrt=x)
  1562. -Derivative(f(x - 1/2, y), y) + Derivative(f(x + 1/2, y), y)
  1563. We can apply ``as_finite_difference`` to ``Derivative`` instances in
  1564. compound expressions using ``replace``:
  1565. >>> (1 + 42**f(x).diff(x)).replace(lambda arg: arg.is_Derivative,
  1566. ... lambda arg: arg.as_finite_difference())
  1567. 42**(-f(x - 1/2) + f(x + 1/2)) + 1
  1568. See also
  1569. ========
  1570. sympy.calculus.finite_diff.apply_finite_diff
  1571. sympy.calculus.finite_diff.differentiate_finite
  1572. sympy.calculus.finite_diff.finite_diff_weights
  1573. """
  1574. from sympy.calculus.finite_diff import _as_finite_diff
  1575. return _as_finite_diff(self, points, x0, wrt)
  1576. @classmethod
  1577. def _get_zero_with_shape_like(cls, expr):
  1578. return S.Zero
  1579. @classmethod
  1580. def _dispatch_eval_derivative_n_times(cls, expr, v, count):
  1581. # Evaluate the derivative `n` times. If
  1582. # `_eval_derivative_n_times` is not overridden by the current
  1583. # object, the default in `Basic` will call a loop over
  1584. # `_eval_derivative`:
  1585. return expr._eval_derivative_n_times(v, count)
  1586. def _derivative_dispatch(expr, *variables, **kwargs):
  1587. from sympy.matrices.common import MatrixCommon
  1588. from sympy.matrices.expressions.matexpr import MatrixExpr
  1589. from sympy.tensor.array import NDimArray
  1590. array_types = (MatrixCommon, MatrixExpr, NDimArray, list, tuple, Tuple)
  1591. if isinstance(expr, array_types) or any(isinstance(i[0], array_types) if isinstance(i, (tuple, list, Tuple)) else isinstance(i, array_types) for i in variables):
  1592. from sympy.tensor.array.array_derivatives import ArrayDerivative
  1593. return ArrayDerivative(expr, *variables, **kwargs)
  1594. return Derivative(expr, *variables, **kwargs)
  1595. class Lambda(Expr):
  1596. """
  1597. Lambda(x, expr) represents a lambda function similar to Python's
  1598. 'lambda x: expr'. A function of several variables is written as
  1599. Lambda((x, y, ...), expr).
  1600. Examples
  1601. ========
  1602. A simple example:
  1603. >>> from sympy import Lambda
  1604. >>> from sympy.abc import x
  1605. >>> f = Lambda(x, x**2)
  1606. >>> f(4)
  1607. 16
  1608. For multivariate functions, use:
  1609. >>> from sympy.abc import y, z, t
  1610. >>> f2 = Lambda((x, y, z, t), x + y**z + t**z)
  1611. >>> f2(1, 2, 3, 4)
  1612. 73
  1613. It is also possible to unpack tuple arguments:
  1614. >>> f = Lambda(((x, y), z), x + y + z)
  1615. >>> f((1, 2), 3)
  1616. 6
  1617. A handy shortcut for lots of arguments:
  1618. >>> p = x, y, z
  1619. >>> f = Lambda(p, x + y*z)
  1620. >>> f(*p)
  1621. x + y*z
  1622. """
  1623. is_Function = True
  1624. def __new__(cls, signature, expr):
  1625. if iterable(signature) and not isinstance(signature, (tuple, Tuple)):
  1626. sympy_deprecation_warning(
  1627. """
  1628. Using a non-tuple iterable as the first argument to Lambda
  1629. is deprecated. Use Lambda(tuple(args), expr) instead.
  1630. """,
  1631. deprecated_since_version="1.5",
  1632. active_deprecations_target="deprecated-non-tuple-lambda",
  1633. )
  1634. signature = tuple(signature)
  1635. sig = signature if iterable(signature) else (signature,)
  1636. sig = sympify(sig)
  1637. cls._check_signature(sig)
  1638. if len(sig) == 1 and sig[0] == expr:
  1639. return S.IdentityFunction
  1640. return Expr.__new__(cls, sig, sympify(expr))
  1641. @classmethod
  1642. def _check_signature(cls, sig):
  1643. syms = set()
  1644. def rcheck(args):
  1645. for a in args:
  1646. if a.is_symbol:
  1647. if a in syms:
  1648. raise BadSignatureError("Duplicate symbol %s" % a)
  1649. syms.add(a)
  1650. elif isinstance(a, Tuple):
  1651. rcheck(a)
  1652. else:
  1653. raise BadSignatureError("Lambda signature should be only tuples"
  1654. " and symbols, not %s" % a)
  1655. if not isinstance(sig, Tuple):
  1656. raise BadSignatureError("Lambda signature should be a tuple not %s" % sig)
  1657. # Recurse through the signature:
  1658. rcheck(sig)
  1659. @property
  1660. def signature(self):
  1661. """The expected form of the arguments to be unpacked into variables"""
  1662. return self._args[0]
  1663. @property
  1664. def expr(self):
  1665. """The return value of the function"""
  1666. return self._args[1]
  1667. @property
  1668. def variables(self):
  1669. """The variables used in the internal representation of the function"""
  1670. def _variables(args):
  1671. if isinstance(args, Tuple):
  1672. for arg in args:
  1673. yield from _variables(arg)
  1674. else:
  1675. yield args
  1676. return tuple(_variables(self.signature))
  1677. @property
  1678. def nargs(self):
  1679. from sympy.sets.sets import FiniteSet
  1680. return FiniteSet(len(self.signature))
  1681. bound_symbols = variables
  1682. @property
  1683. def free_symbols(self):
  1684. return self.expr.free_symbols - set(self.variables)
  1685. def __call__(self, *args):
  1686. n = len(args)
  1687. if n not in self.nargs: # Lambda only ever has 1 value in nargs
  1688. # XXX: exception message must be in exactly this format to
  1689. # make it work with NumPy's functions like vectorize(). See,
  1690. # for example, https://github.com/numpy/numpy/issues/1697.
  1691. # The ideal solution would be just to attach metadata to
  1692. # the exception and change NumPy to take advantage of this.
  1693. ## XXX does this apply to Lambda? If not, remove this comment.
  1694. temp = ('%(name)s takes exactly %(args)s '
  1695. 'argument%(plural)s (%(given)s given)')
  1696. raise BadArgumentsError(temp % {
  1697. 'name': self,
  1698. 'args': list(self.nargs)[0],
  1699. 'plural': 's'*(list(self.nargs)[0] != 1),
  1700. 'given': n})
  1701. d = self._match_signature(self.signature, args)
  1702. return self.expr.xreplace(d)
  1703. def _match_signature(self, sig, args):
  1704. symargmap = {}
  1705. def rmatch(pars, args):
  1706. for par, arg in zip(pars, args):
  1707. if par.is_symbol:
  1708. symargmap[par] = arg
  1709. elif isinstance(par, Tuple):
  1710. if not isinstance(arg, (tuple, Tuple)) or len(args) != len(pars):
  1711. raise BadArgumentsError("Can't match %s and %s" % (args, pars))
  1712. rmatch(par, arg)
  1713. rmatch(sig, args)
  1714. return symargmap
  1715. @property
  1716. def is_identity(self):
  1717. """Return ``True`` if this ``Lambda`` is an identity function. """
  1718. return self.signature == self.expr
  1719. def _eval_evalf(self, prec):
  1720. return self.func(self.args[0], self.args[1].evalf(n=prec_to_dps(prec)))
  1721. class Subs(Expr):
  1722. """
  1723. Represents unevaluated substitutions of an expression.
  1724. ``Subs(expr, x, x0)`` represents the expression resulting
  1725. from substituting x with x0 in expr.
  1726. Parameters
  1727. ==========
  1728. expr : Expr
  1729. An expression.
  1730. x : tuple, variable
  1731. A variable or list of distinct variables.
  1732. x0 : tuple or list of tuples
  1733. A point or list of evaluation points
  1734. corresponding to those variables.
  1735. Examples
  1736. ========
  1737. >>> from sympy import Subs, Function, sin, cos
  1738. >>> from sympy.abc import x, y, z
  1739. >>> f = Function('f')
  1740. Subs are created when a particular substitution cannot be made. The
  1741. x in the derivative cannot be replaced with 0 because 0 is not a
  1742. valid variables of differentiation:
  1743. >>> f(x).diff(x).subs(x, 0)
  1744. Subs(Derivative(f(x), x), x, 0)
  1745. Once f is known, the derivative and evaluation at 0 can be done:
  1746. >>> _.subs(f, sin).doit() == sin(x).diff(x).subs(x, 0) == cos(0)
  1747. True
  1748. Subs can also be created directly with one or more variables:
  1749. >>> Subs(f(x)*sin(y) + z, (x, y), (0, 1))
  1750. Subs(z + f(x)*sin(y), (x, y), (0, 1))
  1751. >>> _.doit()
  1752. z + f(0)*sin(1)
  1753. Notes
  1754. =====
  1755. ``Subs`` objects are generally useful to represent unevaluated derivatives
  1756. calculated at a point.
  1757. The variables may be expressions, but they are subjected to the limitations
  1758. of subs(), so it is usually a good practice to use only symbols for
  1759. variables, since in that case there can be no ambiguity.
  1760. There's no automatic expansion - use the method .doit() to effect all
  1761. possible substitutions of the object and also of objects inside the
  1762. expression.
  1763. When evaluating derivatives at a point that is not a symbol, a Subs object
  1764. is returned. One is also able to calculate derivatives of Subs objects - in
  1765. this case the expression is always expanded (for the unevaluated form, use
  1766. Derivative()).
  1767. In order to allow expressions to combine before doit is done, a
  1768. representation of the Subs expression is used internally to make
  1769. expressions that are superficially different compare the same:
  1770. >>> a, b = Subs(x, x, 0), Subs(y, y, 0)
  1771. >>> a + b
  1772. 2*Subs(x, x, 0)
  1773. This can lead to unexpected consequences when using methods
  1774. like `has` that are cached:
  1775. >>> s = Subs(x, x, 0)
  1776. >>> s.has(x), s.has(y)
  1777. (True, False)
  1778. >>> ss = s.subs(x, y)
  1779. >>> ss.has(x), ss.has(y)
  1780. (True, False)
  1781. >>> s, ss
  1782. (Subs(x, x, 0), Subs(y, y, 0))
  1783. """
  1784. def __new__(cls, expr, variables, point, **assumptions):
  1785. if not is_sequence(variables, Tuple):
  1786. variables = [variables]
  1787. variables = Tuple(*variables)
  1788. if has_dups(variables):
  1789. repeated = [str(v) for v, i in Counter(variables).items() if i > 1]
  1790. __ = ', '.join(repeated)
  1791. raise ValueError(filldedent('''
  1792. The following expressions appear more than once: %s
  1793. ''' % __))
  1794. point = Tuple(*(point if is_sequence(point, Tuple) else [point]))
  1795. if len(point) != len(variables):
  1796. raise ValueError('Number of point values must be the same as '
  1797. 'the number of variables.')
  1798. if not point:
  1799. return sympify(expr)
  1800. # denest
  1801. if isinstance(expr, Subs):
  1802. variables = expr.variables + variables
  1803. point = expr.point + point
  1804. expr = expr.expr
  1805. else:
  1806. expr = sympify(expr)
  1807. # use symbols with names equal to the point value (with prepended _)
  1808. # to give a variable-independent expression
  1809. pre = "_"
  1810. pts = sorted(set(point), key=default_sort_key)
  1811. from sympy.printing.str import StrPrinter
  1812. class CustomStrPrinter(StrPrinter):
  1813. def _print_Dummy(self, expr):
  1814. return str(expr) + str(expr.dummy_index)
  1815. def mystr(expr, **settings):
  1816. p = CustomStrPrinter(settings)
  1817. return p.doprint(expr)
  1818. while 1:
  1819. s_pts = {p: Symbol(pre + mystr(p)) for p in pts}
  1820. reps = [(v, s_pts[p])
  1821. for v, p in zip(variables, point)]
  1822. # if any underscore-prepended symbol is already a free symbol
  1823. # and is a variable with a different point value, then there
  1824. # is a clash, e.g. _0 clashes in Subs(_0 + _1, (_0, _1), (1, 0))
  1825. # because the new symbol that would be created is _1 but _1
  1826. # is already mapped to 0 so __0 and __1 are used for the new
  1827. # symbols
  1828. if any(r in expr.free_symbols and
  1829. r in variables and
  1830. Symbol(pre + mystr(point[variables.index(r)])) != r
  1831. for _, r in reps):
  1832. pre += "_"
  1833. continue
  1834. break
  1835. obj = Expr.__new__(cls, expr, Tuple(*variables), point)
  1836. obj._expr = expr.xreplace(dict(reps))
  1837. return obj
  1838. def _eval_is_commutative(self):
  1839. return self.expr.is_commutative
  1840. def doit(self, **hints):
  1841. e, v, p = self.args
  1842. # remove self mappings
  1843. for i, (vi, pi) in enumerate(zip(v, p)):
  1844. if vi == pi:
  1845. v = v[:i] + v[i + 1:]
  1846. p = p[:i] + p[i + 1:]
  1847. if not v:
  1848. return self.expr
  1849. if isinstance(e, Derivative):
  1850. # apply functions first, e.g. f -> cos
  1851. undone = []
  1852. for i, vi in enumerate(v):
  1853. if isinstance(vi, FunctionClass):
  1854. e = e.subs(vi, p[i])
  1855. else:
  1856. undone.append((vi, p[i]))
  1857. if not isinstance(e, Derivative):
  1858. e = e.doit()
  1859. if isinstance(e, Derivative):
  1860. # do Subs that aren't related to differentiation
  1861. undone2 = []
  1862. D = Dummy()
  1863. arg = e.args[0]
  1864. for vi, pi in undone:
  1865. if D not in e.xreplace({vi: D}).free_symbols:
  1866. if arg.has(vi):
  1867. e = e.subs(vi, pi)
  1868. else:
  1869. undone2.append((vi, pi))
  1870. undone = undone2
  1871. # differentiate wrt variables that are present
  1872. wrt = []
  1873. D = Dummy()
  1874. expr = e.expr
  1875. free = expr.free_symbols
  1876. for vi, ci in e.variable_count:
  1877. if isinstance(vi, Symbol) and vi in free:
  1878. expr = expr.diff((vi, ci))
  1879. elif D in expr.subs(vi, D).free_symbols:
  1880. expr = expr.diff((vi, ci))
  1881. else:
  1882. wrt.append((vi, ci))
  1883. # inject remaining subs
  1884. rv = expr.subs(undone)
  1885. # do remaining differentiation *in order given*
  1886. for vc in wrt:
  1887. rv = rv.diff(vc)
  1888. else:
  1889. # inject remaining subs
  1890. rv = e.subs(undone)
  1891. else:
  1892. rv = e.doit(**hints).subs(list(zip(v, p)))
  1893. if hints.get('deep', True) and rv != self:
  1894. rv = rv.doit(**hints)
  1895. return rv
  1896. def evalf(self, prec=None, **options):
  1897. return self.doit().evalf(prec, **options)
  1898. n = evalf # type:ignore
  1899. @property
  1900. def variables(self):
  1901. """The variables to be evaluated"""
  1902. return self._args[1]
  1903. bound_symbols = variables
  1904. @property
  1905. def expr(self):
  1906. """The expression on which the substitution operates"""
  1907. return self._args[0]
  1908. @property
  1909. def point(self):
  1910. """The values for which the variables are to be substituted"""
  1911. return self._args[2]
  1912. @property
  1913. def free_symbols(self):
  1914. return (self.expr.free_symbols - set(self.variables) |
  1915. set(self.point.free_symbols))
  1916. @property
  1917. def expr_free_symbols(self):
  1918. sympy_deprecation_warning("""
  1919. The expr_free_symbols property is deprecated. Use free_symbols to get
  1920. the free symbols of an expression.
  1921. """,
  1922. deprecated_since_version="1.9",
  1923. active_deprecations_target="deprecated-expr-free-symbols")
  1924. # Don't show the warning twice from the recursive call
  1925. with ignore_warnings(SymPyDeprecationWarning):
  1926. return (self.expr.expr_free_symbols - set(self.variables) |
  1927. set(self.point.expr_free_symbols))
  1928. def __eq__(self, other):
  1929. if not isinstance(other, Subs):
  1930. return False
  1931. return self._hashable_content() == other._hashable_content()
  1932. def __ne__(self, other):
  1933. return not(self == other)
  1934. def __hash__(self):
  1935. return super().__hash__()
  1936. def _hashable_content(self):
  1937. return (self._expr.xreplace(self.canonical_variables),
  1938. ) + tuple(ordered([(v, p) for v, p in
  1939. zip(self.variables, self.point) if not self.expr.has(v)]))
  1940. def _eval_subs(self, old, new):
  1941. # Subs doit will do the variables in order; the semantics
  1942. # of subs for Subs is have the following invariant for
  1943. # Subs object foo:
  1944. # foo.doit().subs(reps) == foo.subs(reps).doit()
  1945. pt = list(self.point)
  1946. if old in self.variables:
  1947. if _atomic(new) == {new} and not any(
  1948. i.has(new) for i in self.args):
  1949. # the substitution is neutral
  1950. return self.xreplace({old: new})
  1951. # any occurrence of old before this point will get
  1952. # handled by replacements from here on
  1953. i = self.variables.index(old)
  1954. for j in range(i, len(self.variables)):
  1955. pt[j] = pt[j]._subs(old, new)
  1956. return self.func(self.expr, self.variables, pt)
  1957. v = [i._subs(old, new) for i in self.variables]
  1958. if v != list(self.variables):
  1959. return self.func(self.expr, self.variables + (old,), pt + [new])
  1960. expr = self.expr._subs(old, new)
  1961. pt = [i._subs(old, new) for i in self.point]
  1962. return self.func(expr, v, pt)
  1963. def _eval_derivative(self, s):
  1964. # Apply the chain rule of the derivative on the substitution variables:
  1965. f = self.expr
  1966. vp = V, P = self.variables, self.point
  1967. val = Add.fromiter(p.diff(s)*Subs(f.diff(v), *vp).doit()
  1968. for v, p in zip(V, P))
  1969. # these are all the free symbols in the expr
  1970. efree = f.free_symbols
  1971. # some symbols like IndexedBase include themselves and args
  1972. # as free symbols
  1973. compound = {i for i in efree if len(i.free_symbols) > 1}
  1974. # hide them and see what independent free symbols remain
  1975. dums = {Dummy() for i in compound}
  1976. masked = f.xreplace(dict(zip(compound, dums)))
  1977. ifree = masked.free_symbols - dums
  1978. # include the compound symbols
  1979. free = ifree | compound
  1980. # remove the variables already handled
  1981. free -= set(V)
  1982. # add back any free symbols of remaining compound symbols
  1983. free |= {i for j in free & compound for i in j.free_symbols}
  1984. # if symbols of s are in free then there is more to do
  1985. if free & s.free_symbols:
  1986. val += Subs(f.diff(s), self.variables, self.point).doit()
  1987. return val
  1988. def _eval_nseries(self, x, n, logx, cdir=0):
  1989. if x in self.point:
  1990. # x is the variable being substituted into
  1991. apos = self.point.index(x)
  1992. other = self.variables[apos]
  1993. else:
  1994. other = x
  1995. arg = self.expr.nseries(other, n=n, logx=logx)
  1996. o = arg.getO()
  1997. terms = Add.make_args(arg.removeO())
  1998. rv = Add(*[self.func(a, *self.args[1:]) for a in terms])
  1999. if o:
  2000. rv += o.subs(other, x)
  2001. return rv
  2002. def _eval_as_leading_term(self, x, logx=None, cdir=0):
  2003. if x in self.point:
  2004. ipos = self.point.index(x)
  2005. xvar = self.variables[ipos]
  2006. return self.expr.as_leading_term(xvar)
  2007. if x in self.variables:
  2008. # if `x` is a dummy variable, it means it won't exist after the
  2009. # substitution has been performed:
  2010. return self
  2011. # The variable is independent of the substitution:
  2012. return self.expr.as_leading_term(x)
  2013. def diff(f, *symbols, **kwargs):
  2014. """
  2015. Differentiate f with respect to symbols.
  2016. Explanation
  2017. ===========
  2018. This is just a wrapper to unify .diff() and the Derivative class; its
  2019. interface is similar to that of integrate(). You can use the same
  2020. shortcuts for multiple variables as with Derivative. For example,
  2021. diff(f(x), x, x, x) and diff(f(x), x, 3) both return the third derivative
  2022. of f(x).
  2023. You can pass evaluate=False to get an unevaluated Derivative class. Note
  2024. that if there are 0 symbols (such as diff(f(x), x, 0), then the result will
  2025. be the function (the zeroth derivative), even if evaluate=False.
  2026. Examples
  2027. ========
  2028. >>> from sympy import sin, cos, Function, diff
  2029. >>> from sympy.abc import x, y
  2030. >>> f = Function('f')
  2031. >>> diff(sin(x), x)
  2032. cos(x)
  2033. >>> diff(f(x), x, x, x)
  2034. Derivative(f(x), (x, 3))
  2035. >>> diff(f(x), x, 3)
  2036. Derivative(f(x), (x, 3))
  2037. >>> diff(sin(x)*cos(y), x, 2, y, 2)
  2038. sin(x)*cos(y)
  2039. >>> type(diff(sin(x), x))
  2040. cos
  2041. >>> type(diff(sin(x), x, evaluate=False))
  2042. <class 'sympy.core.function.Derivative'>
  2043. >>> type(diff(sin(x), x, 0))
  2044. sin
  2045. >>> type(diff(sin(x), x, 0, evaluate=False))
  2046. sin
  2047. >>> diff(sin(x))
  2048. cos(x)
  2049. >>> diff(sin(x*y))
  2050. Traceback (most recent call last):
  2051. ...
  2052. ValueError: specify differentiation variables to differentiate sin(x*y)
  2053. Note that ``diff(sin(x))`` syntax is meant only for convenience
  2054. in interactive sessions and should be avoided in library code.
  2055. References
  2056. ==========
  2057. .. [1] https://reference.wolfram.com/legacy/v5_2/Built-inFunctions/AlgebraicComputation/Calculus/D.html
  2058. See Also
  2059. ========
  2060. Derivative
  2061. idiff: computes the derivative implicitly
  2062. """
  2063. if hasattr(f, 'diff'):
  2064. return f.diff(*symbols, **kwargs)
  2065. kwargs.setdefault('evaluate', True)
  2066. return _derivative_dispatch(f, *symbols, **kwargs)
  2067. def expand(e, deep=True, modulus=None, power_base=True, power_exp=True,
  2068. mul=True, log=True, multinomial=True, basic=True, **hints):
  2069. r"""
  2070. Expand an expression using methods given as hints.
  2071. Explanation
  2072. ===========
  2073. Hints evaluated unless explicitly set to False are: ``basic``, ``log``,
  2074. ``multinomial``, ``mul``, ``power_base``, and ``power_exp`` The following
  2075. hints are supported but not applied unless set to True: ``complex``,
  2076. ``func``, and ``trig``. In addition, the following meta-hints are
  2077. supported by some or all of the other hints: ``frac``, ``numer``,
  2078. ``denom``, ``modulus``, and ``force``. ``deep`` is supported by all
  2079. hints. Additionally, subclasses of Expr may define their own hints or
  2080. meta-hints.
  2081. The ``basic`` hint is used for any special rewriting of an object that
  2082. should be done automatically (along with the other hints like ``mul``)
  2083. when expand is called. This is a catch-all hint to handle any sort of
  2084. expansion that may not be described by the existing hint names. To use
  2085. this hint an object should override the ``_eval_expand_basic`` method.
  2086. Objects may also define their own expand methods, which are not run by
  2087. default. See the API section below.
  2088. If ``deep`` is set to ``True`` (the default), things like arguments of
  2089. functions are recursively expanded. Use ``deep=False`` to only expand on
  2090. the top level.
  2091. If the ``force`` hint is used, assumptions about variables will be ignored
  2092. in making the expansion.
  2093. Hints
  2094. =====
  2095. These hints are run by default
  2096. mul
  2097. ---
  2098. Distributes multiplication over addition:
  2099. >>> from sympy import cos, exp, sin
  2100. >>> from sympy.abc import x, y, z
  2101. >>> (y*(x + z)).expand(mul=True)
  2102. x*y + y*z
  2103. multinomial
  2104. -----------
  2105. Expand (x + y + ...)**n where n is a positive integer.
  2106. >>> ((x + y + z)**2).expand(multinomial=True)
  2107. x**2 + 2*x*y + 2*x*z + y**2 + 2*y*z + z**2
  2108. power_exp
  2109. ---------
  2110. Expand addition in exponents into multiplied bases.
  2111. >>> exp(x + y).expand(power_exp=True)
  2112. exp(x)*exp(y)
  2113. >>> (2**(x + y)).expand(power_exp=True)
  2114. 2**x*2**y
  2115. power_base
  2116. ----------
  2117. Split powers of multiplied bases.
  2118. This only happens by default if assumptions allow, or if the
  2119. ``force`` meta-hint is used:
  2120. >>> ((x*y)**z).expand(power_base=True)
  2121. (x*y)**z
  2122. >>> ((x*y)**z).expand(power_base=True, force=True)
  2123. x**z*y**z
  2124. >>> ((2*y)**z).expand(power_base=True)
  2125. 2**z*y**z
  2126. Note that in some cases where this expansion always holds, SymPy performs
  2127. it automatically:
  2128. >>> (x*y)**2
  2129. x**2*y**2
  2130. log
  2131. ---
  2132. Pull out power of an argument as a coefficient and split logs products
  2133. into sums of logs.
  2134. Note that these only work if the arguments of the log function have the
  2135. proper assumptions--the arguments must be positive and the exponents must
  2136. be real--or else the ``force`` hint must be True:
  2137. >>> from sympy import log, symbols
  2138. >>> log(x**2*y).expand(log=True)
  2139. log(x**2*y)
  2140. >>> log(x**2*y).expand(log=True, force=True)
  2141. 2*log(x) + log(y)
  2142. >>> x, y = symbols('x,y', positive=True)
  2143. >>> log(x**2*y).expand(log=True)
  2144. 2*log(x) + log(y)
  2145. basic
  2146. -----
  2147. This hint is intended primarily as a way for custom subclasses to enable
  2148. expansion by default.
  2149. These hints are not run by default:
  2150. complex
  2151. -------
  2152. Split an expression into real and imaginary parts.
  2153. >>> x, y = symbols('x,y')
  2154. >>> (x + y).expand(complex=True)
  2155. re(x) + re(y) + I*im(x) + I*im(y)
  2156. >>> cos(x).expand(complex=True)
  2157. -I*sin(re(x))*sinh(im(x)) + cos(re(x))*cosh(im(x))
  2158. Note that this is just a wrapper around ``as_real_imag()``. Most objects
  2159. that wish to redefine ``_eval_expand_complex()`` should consider
  2160. redefining ``as_real_imag()`` instead.
  2161. func
  2162. ----
  2163. Expand other functions.
  2164. >>> from sympy import gamma
  2165. >>> gamma(x + 1).expand(func=True)
  2166. x*gamma(x)
  2167. trig
  2168. ----
  2169. Do trigonometric expansions.
  2170. >>> cos(x + y).expand(trig=True)
  2171. -sin(x)*sin(y) + cos(x)*cos(y)
  2172. >>> sin(2*x).expand(trig=True)
  2173. 2*sin(x)*cos(x)
  2174. Note that the forms of ``sin(n*x)`` and ``cos(n*x)`` in terms of ``sin(x)``
  2175. and ``cos(x)`` are not unique, due to the identity `\sin^2(x) + \cos^2(x)
  2176. = 1`. The current implementation uses the form obtained from Chebyshev
  2177. polynomials, but this may change. See `this MathWorld article
  2178. <https://mathworld.wolfram.com/Multiple-AngleFormulas.html>`_ for more
  2179. information.
  2180. Notes
  2181. =====
  2182. - You can shut off unwanted methods::
  2183. >>> (exp(x + y)*(x + y)).expand()
  2184. x*exp(x)*exp(y) + y*exp(x)*exp(y)
  2185. >>> (exp(x + y)*(x + y)).expand(power_exp=False)
  2186. x*exp(x + y) + y*exp(x + y)
  2187. >>> (exp(x + y)*(x + y)).expand(mul=False)
  2188. (x + y)*exp(x)*exp(y)
  2189. - Use deep=False to only expand on the top level::
  2190. >>> exp(x + exp(x + y)).expand()
  2191. exp(x)*exp(exp(x)*exp(y))
  2192. >>> exp(x + exp(x + y)).expand(deep=False)
  2193. exp(x)*exp(exp(x + y))
  2194. - Hints are applied in an arbitrary, but consistent order (in the current
  2195. implementation, they are applied in alphabetical order, except
  2196. multinomial comes before mul, but this may change). Because of this,
  2197. some hints may prevent expansion by other hints if they are applied
  2198. first. For example, ``mul`` may distribute multiplications and prevent
  2199. ``log`` and ``power_base`` from expanding them. Also, if ``mul`` is
  2200. applied before ``multinomial`, the expression might not be fully
  2201. distributed. The solution is to use the various ``expand_hint`` helper
  2202. functions or to use ``hint=False`` to this function to finely control
  2203. which hints are applied. Here are some examples::
  2204. >>> from sympy import expand, expand_mul, expand_power_base
  2205. >>> x, y, z = symbols('x,y,z', positive=True)
  2206. >>> expand(log(x*(y + z)))
  2207. log(x) + log(y + z)
  2208. Here, we see that ``log`` was applied before ``mul``. To get the mul
  2209. expanded form, either of the following will work::
  2210. >>> expand_mul(log(x*(y + z)))
  2211. log(x*y + x*z)
  2212. >>> expand(log(x*(y + z)), log=False)
  2213. log(x*y + x*z)
  2214. A similar thing can happen with the ``power_base`` hint::
  2215. >>> expand((x*(y + z))**x)
  2216. (x*y + x*z)**x
  2217. To get the ``power_base`` expanded form, either of the following will
  2218. work::
  2219. >>> expand((x*(y + z))**x, mul=False)
  2220. x**x*(y + z)**x
  2221. >>> expand_power_base((x*(y + z))**x)
  2222. x**x*(y + z)**x
  2223. >>> expand((x + y)*y/x)
  2224. y + y**2/x
  2225. The parts of a rational expression can be targeted::
  2226. >>> expand((x + y)*y/x/(x + 1), frac=True)
  2227. (x*y + y**2)/(x**2 + x)
  2228. >>> expand((x + y)*y/x/(x + 1), numer=True)
  2229. (x*y + y**2)/(x*(x + 1))
  2230. >>> expand((x + y)*y/x/(x + 1), denom=True)
  2231. y*(x + y)/(x**2 + x)
  2232. - The ``modulus`` meta-hint can be used to reduce the coefficients of an
  2233. expression post-expansion::
  2234. >>> expand((3*x + 1)**2)
  2235. 9*x**2 + 6*x + 1
  2236. >>> expand((3*x + 1)**2, modulus=5)
  2237. 4*x**2 + x + 1
  2238. - Either ``expand()`` the function or ``.expand()`` the method can be
  2239. used. Both are equivalent::
  2240. >>> expand((x + 1)**2)
  2241. x**2 + 2*x + 1
  2242. >>> ((x + 1)**2).expand()
  2243. x**2 + 2*x + 1
  2244. API
  2245. ===
  2246. Objects can define their own expand hints by defining
  2247. ``_eval_expand_hint()``. The function should take the form::
  2248. def _eval_expand_hint(self, **hints):
  2249. # Only apply the method to the top-level expression
  2250. ...
  2251. See also the example below. Objects should define ``_eval_expand_hint()``
  2252. methods only if ``hint`` applies to that specific object. The generic
  2253. ``_eval_expand_hint()`` method defined in Expr will handle the no-op case.
  2254. Each hint should be responsible for expanding that hint only.
  2255. Furthermore, the expansion should be applied to the top-level expression
  2256. only. ``expand()`` takes care of the recursion that happens when
  2257. ``deep=True``.
  2258. You should only call ``_eval_expand_hint()`` methods directly if you are
  2259. 100% sure that the object has the method, as otherwise you are liable to
  2260. get unexpected ``AttributeError``s. Note, again, that you do not need to
  2261. recursively apply the hint to args of your object: this is handled
  2262. automatically by ``expand()``. ``_eval_expand_hint()`` should
  2263. generally not be used at all outside of an ``_eval_expand_hint()`` method.
  2264. If you want to apply a specific expansion from within another method, use
  2265. the public ``expand()`` function, method, or ``expand_hint()`` functions.
  2266. In order for expand to work, objects must be rebuildable by their args,
  2267. i.e., ``obj.func(*obj.args) == obj`` must hold.
  2268. Expand methods are passed ``**hints`` so that expand hints may use
  2269. 'metahints'--hints that control how different expand methods are applied.
  2270. For example, the ``force=True`` hint described above that causes
  2271. ``expand(log=True)`` to ignore assumptions is such a metahint. The
  2272. ``deep`` meta-hint is handled exclusively by ``expand()`` and is not
  2273. passed to ``_eval_expand_hint()`` methods.
  2274. Note that expansion hints should generally be methods that perform some
  2275. kind of 'expansion'. For hints that simply rewrite an expression, use the
  2276. .rewrite() API.
  2277. Examples
  2278. ========
  2279. >>> from sympy import Expr, sympify
  2280. >>> class MyClass(Expr):
  2281. ... def __new__(cls, *args):
  2282. ... args = sympify(args)
  2283. ... return Expr.__new__(cls, *args)
  2284. ...
  2285. ... def _eval_expand_double(self, *, force=False, **hints):
  2286. ... '''
  2287. ... Doubles the args of MyClass.
  2288. ...
  2289. ... If there more than four args, doubling is not performed,
  2290. ... unless force=True is also used (False by default).
  2291. ... '''
  2292. ... if not force and len(self.args) > 4:
  2293. ... return self
  2294. ... return self.func(*(self.args + self.args))
  2295. ...
  2296. >>> a = MyClass(1, 2, MyClass(3, 4))
  2297. >>> a
  2298. MyClass(1, 2, MyClass(3, 4))
  2299. >>> a.expand(double=True)
  2300. MyClass(1, 2, MyClass(3, 4, 3, 4), 1, 2, MyClass(3, 4, 3, 4))
  2301. >>> a.expand(double=True, deep=False)
  2302. MyClass(1, 2, MyClass(3, 4), 1, 2, MyClass(3, 4))
  2303. >>> b = MyClass(1, 2, 3, 4, 5)
  2304. >>> b.expand(double=True)
  2305. MyClass(1, 2, 3, 4, 5)
  2306. >>> b.expand(double=True, force=True)
  2307. MyClass(1, 2, 3, 4, 5, 1, 2, 3, 4, 5)
  2308. See Also
  2309. ========
  2310. expand_log, expand_mul, expand_multinomial, expand_complex, expand_trig,
  2311. expand_power_base, expand_power_exp, expand_func, sympy.simplify.hyperexpand.hyperexpand
  2312. """
  2313. # don't modify this; modify the Expr.expand method
  2314. hints['power_base'] = power_base
  2315. hints['power_exp'] = power_exp
  2316. hints['mul'] = mul
  2317. hints['log'] = log
  2318. hints['multinomial'] = multinomial
  2319. hints['basic'] = basic
  2320. return sympify(e).expand(deep=deep, modulus=modulus, **hints)
  2321. # This is a special application of two hints
  2322. def _mexpand(expr, recursive=False):
  2323. # expand multinomials and then expand products; this may not always
  2324. # be sufficient to give a fully expanded expression (see
  2325. # test_issue_8247_8354 in test_arit)
  2326. if expr is None:
  2327. return
  2328. was = None
  2329. while was != expr:
  2330. was, expr = expr, expand_mul(expand_multinomial(expr))
  2331. if not recursive:
  2332. break
  2333. return expr
  2334. # These are simple wrappers around single hints.
  2335. def expand_mul(expr, deep=True):
  2336. """
  2337. Wrapper around expand that only uses the mul hint. See the expand
  2338. docstring for more information.
  2339. Examples
  2340. ========
  2341. >>> from sympy import symbols, expand_mul, exp, log
  2342. >>> x, y = symbols('x,y', positive=True)
  2343. >>> expand_mul(exp(x+y)*(x+y)*log(x*y**2))
  2344. x*exp(x + y)*log(x*y**2) + y*exp(x + y)*log(x*y**2)
  2345. """
  2346. return sympify(expr).expand(deep=deep, mul=True, power_exp=False,
  2347. power_base=False, basic=False, multinomial=False, log=False)
  2348. def expand_multinomial(expr, deep=True):
  2349. """
  2350. Wrapper around expand that only uses the multinomial hint. See the expand
  2351. docstring for more information.
  2352. Examples
  2353. ========
  2354. >>> from sympy import symbols, expand_multinomial, exp
  2355. >>> x, y = symbols('x y', positive=True)
  2356. >>> expand_multinomial((x + exp(x + 1))**2)
  2357. x**2 + 2*x*exp(x + 1) + exp(2*x + 2)
  2358. """
  2359. return sympify(expr).expand(deep=deep, mul=False, power_exp=False,
  2360. power_base=False, basic=False, multinomial=True, log=False)
  2361. def expand_log(expr, deep=True, force=False, factor=False):
  2362. """
  2363. Wrapper around expand that only uses the log hint. See the expand
  2364. docstring for more information.
  2365. Examples
  2366. ========
  2367. >>> from sympy import symbols, expand_log, exp, log
  2368. >>> x, y = symbols('x,y', positive=True)
  2369. >>> expand_log(exp(x+y)*(x+y)*log(x*y**2))
  2370. (x + y)*(log(x) + 2*log(y))*exp(x + y)
  2371. """
  2372. from sympy.functions.elementary.exponential import log
  2373. if factor is False:
  2374. def _handle(x):
  2375. x1 = expand_mul(expand_log(x, deep=deep, force=force, factor=True))
  2376. if x1.count(log) <= x.count(log):
  2377. return x1
  2378. return x
  2379. expr = expr.replace(
  2380. lambda x: x.is_Mul and all(any(isinstance(i, log) and i.args[0].is_Rational
  2381. for i in Mul.make_args(j)) for j in x.as_numer_denom()),
  2382. _handle)
  2383. return sympify(expr).expand(deep=deep, log=True, mul=False,
  2384. power_exp=False, power_base=False, multinomial=False,
  2385. basic=False, force=force, factor=factor)
  2386. def expand_func(expr, deep=True):
  2387. """
  2388. Wrapper around expand that only uses the func hint. See the expand
  2389. docstring for more information.
  2390. Examples
  2391. ========
  2392. >>> from sympy import expand_func, gamma
  2393. >>> from sympy.abc import x
  2394. >>> expand_func(gamma(x + 2))
  2395. x*(x + 1)*gamma(x)
  2396. """
  2397. return sympify(expr).expand(deep=deep, func=True, basic=False,
  2398. log=False, mul=False, power_exp=False, power_base=False, multinomial=False)
  2399. def expand_trig(expr, deep=True):
  2400. """
  2401. Wrapper around expand that only uses the trig hint. See the expand
  2402. docstring for more information.
  2403. Examples
  2404. ========
  2405. >>> from sympy import expand_trig, sin
  2406. >>> from sympy.abc import x, y
  2407. >>> expand_trig(sin(x+y)*(x+y))
  2408. (x + y)*(sin(x)*cos(y) + sin(y)*cos(x))
  2409. """
  2410. return sympify(expr).expand(deep=deep, trig=True, basic=False,
  2411. log=False, mul=False, power_exp=False, power_base=False, multinomial=False)
  2412. def expand_complex(expr, deep=True):
  2413. """
  2414. Wrapper around expand that only uses the complex hint. See the expand
  2415. docstring for more information.
  2416. Examples
  2417. ========
  2418. >>> from sympy import expand_complex, exp, sqrt, I
  2419. >>> from sympy.abc import z
  2420. >>> expand_complex(exp(z))
  2421. I*exp(re(z))*sin(im(z)) + exp(re(z))*cos(im(z))
  2422. >>> expand_complex(sqrt(I))
  2423. sqrt(2)/2 + sqrt(2)*I/2
  2424. See Also
  2425. ========
  2426. sympy.core.expr.Expr.as_real_imag
  2427. """
  2428. return sympify(expr).expand(deep=deep, complex=True, basic=False,
  2429. log=False, mul=False, power_exp=False, power_base=False, multinomial=False)
  2430. def expand_power_base(expr, deep=True, force=False):
  2431. """
  2432. Wrapper around expand that only uses the power_base hint.
  2433. A wrapper to expand(power_base=True) which separates a power with a base
  2434. that is a Mul into a product of powers, without performing any other
  2435. expansions, provided that assumptions about the power's base and exponent
  2436. allow.
  2437. deep=False (default is True) will only apply to the top-level expression.
  2438. force=True (default is False) will cause the expansion to ignore
  2439. assumptions about the base and exponent. When False, the expansion will
  2440. only happen if the base is non-negative or the exponent is an integer.
  2441. >>> from sympy.abc import x, y, z
  2442. >>> from sympy import expand_power_base, sin, cos, exp, Symbol
  2443. >>> (x*y)**2
  2444. x**2*y**2
  2445. >>> (2*x)**y
  2446. (2*x)**y
  2447. >>> expand_power_base(_)
  2448. 2**y*x**y
  2449. >>> expand_power_base((x*y)**z)
  2450. (x*y)**z
  2451. >>> expand_power_base((x*y)**z, force=True)
  2452. x**z*y**z
  2453. >>> expand_power_base(sin((x*y)**z), deep=False)
  2454. sin((x*y)**z)
  2455. >>> expand_power_base(sin((x*y)**z), force=True)
  2456. sin(x**z*y**z)
  2457. >>> expand_power_base((2*sin(x))**y + (2*cos(x))**y)
  2458. 2**y*sin(x)**y + 2**y*cos(x)**y
  2459. >>> expand_power_base((2*exp(y))**x)
  2460. 2**x*exp(y)**x
  2461. >>> expand_power_base((2*cos(x))**y)
  2462. 2**y*cos(x)**y
  2463. Notice that sums are left untouched. If this is not the desired behavior,
  2464. apply full ``expand()`` to the expression:
  2465. >>> expand_power_base(((x+y)*z)**2)
  2466. z**2*(x + y)**2
  2467. >>> (((x+y)*z)**2).expand()
  2468. x**2*z**2 + 2*x*y*z**2 + y**2*z**2
  2469. >>> expand_power_base((2*y)**(1+z))
  2470. 2**(z + 1)*y**(z + 1)
  2471. >>> ((2*y)**(1+z)).expand()
  2472. 2*2**z*y**(z + 1)
  2473. The power that is unexpanded can be expanded safely when
  2474. ``y != 0``, otherwise different values might be obtained for the expression:
  2475. >>> prev = _
  2476. If we indicate that ``y`` is positive but then replace it with
  2477. a value of 0 after expansion, the expression becomes 0:
  2478. >>> p = Symbol('p', positive=True)
  2479. >>> prev.subs(y, p).expand().subs(p, 0)
  2480. 0
  2481. But if ``z = -1`` the expression would not be zero:
  2482. >>> prev.subs(y, 0).subs(z, -1)
  2483. 1
  2484. See Also
  2485. ========
  2486. expand
  2487. """
  2488. return sympify(expr).expand(deep=deep, log=False, mul=False,
  2489. power_exp=False, power_base=True, multinomial=False,
  2490. basic=False, force=force)
  2491. def expand_power_exp(expr, deep=True):
  2492. """
  2493. Wrapper around expand that only uses the power_exp hint.
  2494. See the expand docstring for more information.
  2495. Examples
  2496. ========
  2497. >>> from sympy import expand_power_exp, Symbol
  2498. >>> from sympy.abc import x, y
  2499. >>> expand_power_exp(3**(y + 2))
  2500. 9*3**y
  2501. >>> expand_power_exp(x**(y + 2))
  2502. x**(y + 2)
  2503. If ``x = 0`` the value of the expression depends on the
  2504. value of ``y``; if the expression were expanded the result
  2505. would be 0. So expansion is only done if ``x != 0``:
  2506. >>> expand_power_exp(Symbol('x', zero=False)**(y + 2))
  2507. x**2*x**y
  2508. """
  2509. return sympify(expr).expand(deep=deep, complex=False, basic=False,
  2510. log=False, mul=False, power_exp=True, power_base=False, multinomial=False)
  2511. def count_ops(expr, visual=False):
  2512. """
  2513. Return a representation (integer or expression) of the operations in expr.
  2514. Parameters
  2515. ==========
  2516. expr : Expr
  2517. If expr is an iterable, the sum of the op counts of the
  2518. items will be returned.
  2519. visual : bool, optional
  2520. If ``False`` (default) then the sum of the coefficients of the
  2521. visual expression will be returned.
  2522. If ``True`` then the number of each type of operation is shown
  2523. with the core class types (or their virtual equivalent) multiplied by the
  2524. number of times they occur.
  2525. Examples
  2526. ========
  2527. >>> from sympy.abc import a, b, x, y
  2528. >>> from sympy import sin, count_ops
  2529. Although there is not a SUB object, minus signs are interpreted as
  2530. either negations or subtractions:
  2531. >>> (x - y).count_ops(visual=True)
  2532. SUB
  2533. >>> (-x).count_ops(visual=True)
  2534. NEG
  2535. Here, there are two Adds and a Pow:
  2536. >>> (1 + a + b**2).count_ops(visual=True)
  2537. 2*ADD + POW
  2538. In the following, an Add, Mul, Pow and two functions:
  2539. >>> (sin(x)*x + sin(x)**2).count_ops(visual=True)
  2540. ADD + MUL + POW + 2*SIN
  2541. for a total of 5:
  2542. >>> (sin(x)*x + sin(x)**2).count_ops(visual=False)
  2543. 5
  2544. Note that "what you type" is not always what you get. The expression
  2545. 1/x/y is translated by sympy into 1/(x*y) so it gives a DIV and MUL rather
  2546. than two DIVs:
  2547. >>> (1/x/y).count_ops(visual=True)
  2548. DIV + MUL
  2549. The visual option can be used to demonstrate the difference in
  2550. operations for expressions in different forms. Here, the Horner
  2551. representation is compared with the expanded form of a polynomial:
  2552. >>> eq=x*(1 + x*(2 + x*(3 + x)))
  2553. >>> count_ops(eq.expand(), visual=True) - count_ops(eq, visual=True)
  2554. -MUL + 3*POW
  2555. The count_ops function also handles iterables:
  2556. >>> count_ops([x, sin(x), None, True, x + 2], visual=False)
  2557. 2
  2558. >>> count_ops([x, sin(x), None, True, x + 2], visual=True)
  2559. ADD + SIN
  2560. >>> count_ops({x: sin(x), x + 2: y + 1}, visual=True)
  2561. 2*ADD + SIN
  2562. """
  2563. from .relational import Relational
  2564. from sympy.concrete.summations import Sum
  2565. from sympy.integrals.integrals import Integral
  2566. from sympy.logic.boolalg import BooleanFunction
  2567. from sympy.simplify.radsimp import fraction
  2568. expr = sympify(expr)
  2569. if isinstance(expr, Expr) and not expr.is_Relational:
  2570. ops = []
  2571. args = [expr]
  2572. NEG = Symbol('NEG')
  2573. DIV = Symbol('DIV')
  2574. SUB = Symbol('SUB')
  2575. ADD = Symbol('ADD')
  2576. EXP = Symbol('EXP')
  2577. while args:
  2578. a = args.pop()
  2579. # if the following fails because the object is
  2580. # not Basic type, then the object should be fixed
  2581. # since it is the intention that all args of Basic
  2582. # should themselves be Basic
  2583. if a.is_Rational:
  2584. #-1/3 = NEG + DIV
  2585. if a is not S.One:
  2586. if a.p < 0:
  2587. ops.append(NEG)
  2588. if a.q != 1:
  2589. ops.append(DIV)
  2590. continue
  2591. elif a.is_Mul or a.is_MatMul:
  2592. if _coeff_isneg(a):
  2593. ops.append(NEG)
  2594. if a.args[0] is S.NegativeOne:
  2595. a = a.as_two_terms()[1]
  2596. else:
  2597. a = -a
  2598. n, d = fraction(a)
  2599. if n.is_Integer:
  2600. ops.append(DIV)
  2601. if n < 0:
  2602. ops.append(NEG)
  2603. args.append(d)
  2604. continue # won't be -Mul but could be Add
  2605. elif d is not S.One:
  2606. if not d.is_Integer:
  2607. args.append(d)
  2608. ops.append(DIV)
  2609. args.append(n)
  2610. continue # could be -Mul
  2611. elif a.is_Add or a.is_MatAdd:
  2612. aargs = list(a.args)
  2613. negs = 0
  2614. for i, ai in enumerate(aargs):
  2615. if _coeff_isneg(ai):
  2616. negs += 1
  2617. args.append(-ai)
  2618. if i > 0:
  2619. ops.append(SUB)
  2620. else:
  2621. args.append(ai)
  2622. if i > 0:
  2623. ops.append(ADD)
  2624. if negs == len(aargs): # -x - y = NEG + SUB
  2625. ops.append(NEG)
  2626. elif _coeff_isneg(aargs[0]): # -x + y = SUB, but already recorded ADD
  2627. ops.append(SUB - ADD)
  2628. continue
  2629. if a.is_Pow and a.exp is S.NegativeOne:
  2630. ops.append(DIV)
  2631. args.append(a.base) # won't be -Mul but could be Add
  2632. continue
  2633. if a == S.Exp1:
  2634. ops.append(EXP)
  2635. continue
  2636. if a.is_Pow and a.base == S.Exp1:
  2637. ops.append(EXP)
  2638. args.append(a.exp)
  2639. continue
  2640. if a.is_Mul or isinstance(a, LatticeOp):
  2641. o = Symbol(a.func.__name__.upper())
  2642. # count the args
  2643. ops.append(o*(len(a.args) - 1))
  2644. elif a.args and (
  2645. a.is_Pow or
  2646. a.is_Function or
  2647. isinstance(a, Derivative) or
  2648. isinstance(a, Integral) or
  2649. isinstance(a, Sum)):
  2650. # if it's not in the list above we don't
  2651. # consider a.func something to count, e.g.
  2652. # Tuple, MatrixSymbol, etc...
  2653. if isinstance(a.func, UndefinedFunction):
  2654. o = Symbol("FUNC_" + a.func.__name__.upper())
  2655. else:
  2656. o = Symbol(a.func.__name__.upper())
  2657. ops.append(o)
  2658. if not a.is_Symbol:
  2659. args.extend(a.args)
  2660. elif isinstance(expr, Dict):
  2661. ops = [count_ops(k, visual=visual) +
  2662. count_ops(v, visual=visual) for k, v in expr.items()]
  2663. elif iterable(expr):
  2664. ops = [count_ops(i, visual=visual) for i in expr]
  2665. elif isinstance(expr, (Relational, BooleanFunction)):
  2666. ops = []
  2667. for arg in expr.args:
  2668. ops.append(count_ops(arg, visual=True))
  2669. o = Symbol(func_name(expr, short=True).upper())
  2670. ops.append(o)
  2671. elif not isinstance(expr, Basic):
  2672. ops = []
  2673. else: # it's Basic not isinstance(expr, Expr):
  2674. if not isinstance(expr, Basic):
  2675. raise TypeError("Invalid type of expr")
  2676. else:
  2677. ops = []
  2678. args = [expr]
  2679. while args:
  2680. a = args.pop()
  2681. if a.args:
  2682. o = Symbol(type(a).__name__.upper())
  2683. if a.is_Boolean:
  2684. ops.append(o*(len(a.args)-1))
  2685. else:
  2686. ops.append(o)
  2687. args.extend(a.args)
  2688. if not ops:
  2689. if visual:
  2690. return S.Zero
  2691. return 0
  2692. ops = Add(*ops)
  2693. if visual:
  2694. return ops
  2695. if ops.is_Number:
  2696. return int(ops)
  2697. return sum(int((a.args or [1])[0]) for a in Add.make_args(ops))
  2698. def nfloat(expr, n=15, exponent=False, dkeys=False):
  2699. """Make all Rationals in expr Floats except those in exponents
  2700. (unless the exponents flag is set to True) and those in undefined
  2701. functions. When processing dictionaries, do not modify the keys
  2702. unless ``dkeys=True``.
  2703. Examples
  2704. ========
  2705. >>> from sympy import nfloat, cos, pi, sqrt
  2706. >>> from sympy.abc import x, y
  2707. >>> nfloat(x**4 + x/2 + cos(pi/3) + 1 + sqrt(y))
  2708. x**4 + 0.5*x + sqrt(y) + 1.5
  2709. >>> nfloat(x**4 + sqrt(y), exponent=True)
  2710. x**4.0 + y**0.5
  2711. Container types are not modified:
  2712. >>> type(nfloat((1, 2))) is tuple
  2713. True
  2714. """
  2715. from sympy.matrices.matrices import MatrixBase
  2716. kw = {"n": n, "exponent": exponent, "dkeys": dkeys}
  2717. if isinstance(expr, MatrixBase):
  2718. return expr.applyfunc(lambda e: nfloat(e, **kw))
  2719. # handling of iterable containers
  2720. if iterable(expr, exclude=str):
  2721. if isinstance(expr, (dict, Dict)):
  2722. if dkeys:
  2723. args = [tuple((nfloat(i, **kw) for i in a))
  2724. for a in expr.items()]
  2725. else:
  2726. args = [(k, nfloat(v, **kw)) for k, v in expr.items()]
  2727. if isinstance(expr, dict):
  2728. return type(expr)(args)
  2729. else:
  2730. return expr.func(*args)
  2731. elif isinstance(expr, Basic):
  2732. return expr.func(*[nfloat(a, **kw) for a in expr.args])
  2733. return type(expr)([nfloat(a, **kw) for a in expr])
  2734. rv = sympify(expr)
  2735. if rv.is_Number:
  2736. return Float(rv, n)
  2737. elif rv.is_number:
  2738. # evalf doesn't always set the precision
  2739. rv = rv.n(n)
  2740. if rv.is_Number:
  2741. rv = Float(rv.n(n), n)
  2742. else:
  2743. pass # pure_complex(rv) is likely True
  2744. return rv
  2745. elif rv.is_Atom:
  2746. return rv
  2747. elif rv.is_Relational:
  2748. args_nfloat = (nfloat(arg, **kw) for arg in rv.args)
  2749. return rv.func(*args_nfloat)
  2750. # watch out for RootOf instances that don't like to have
  2751. # their exponents replaced with Dummies and also sometimes have
  2752. # problems with evaluating at low precision (issue 6393)
  2753. from sympy.polys.rootoftools import RootOf
  2754. rv = rv.xreplace({ro: ro.n(n) for ro in rv.atoms(RootOf)})
  2755. from .power import Pow
  2756. if not exponent:
  2757. reps = [(p, Pow(p.base, Dummy())) for p in rv.atoms(Pow)]
  2758. rv = rv.xreplace(dict(reps))
  2759. rv = rv.n(n)
  2760. if not exponent:
  2761. rv = rv.xreplace({d.exp: p.exp for p, d in reps})
  2762. else:
  2763. # Pow._eval_evalf special cases Integer exponents so if
  2764. # exponent is suppose to be handled we have to do so here
  2765. rv = rv.xreplace(Transform(
  2766. lambda x: Pow(x.base, Float(x.exp, n)),
  2767. lambda x: x.is_Pow and x.exp.is_Integer))
  2768. return rv.xreplace(Transform(
  2769. lambda x: x.func(*nfloat(x.args, n, exponent)),
  2770. lambda x: isinstance(x, Function) and not isinstance(x, AppliedUndef)))
  2771. from .symbol import Dummy, Symbol