order.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517
  1. from sympy.core import S, sympify, Expr, Dummy, Add, Mul
  2. from sympy.core.cache import cacheit
  3. from sympy.core.containers import Tuple
  4. from sympy.core.function import Function, PoleError, expand_power_base, expand_log
  5. from sympy.core.sorting import default_sort_key
  6. from sympy.functions.elementary.exponential import exp, log
  7. from sympy.sets.sets import Complement
  8. from sympy.utilities.iterables import uniq, is_sequence
  9. class Order(Expr):
  10. r""" Represents the limiting behavior of some function.
  11. Explanation
  12. ===========
  13. The order of a function characterizes the function based on the limiting
  14. behavior of the function as it goes to some limit. Only taking the limit
  15. point to be a number is currently supported. This is expressed in
  16. big O notation [1]_.
  17. The formal definition for the order of a function `g(x)` about a point `a`
  18. is such that `g(x) = O(f(x))` as `x \rightarrow a` if and only if there
  19. exists a `\delta > 0` and an `M > 0` such that `|g(x)| \leq M|f(x)|` for
  20. `|x-a| < \delta`. This is equivalent to `\limsup_{x \rightarrow a}
  21. |g(x)/f(x)| < \infty`.
  22. Let's illustrate it on the following example by taking the expansion of
  23. `\sin(x)` about 0:
  24. .. math ::
  25. \sin(x) = x - x^3/3! + O(x^5)
  26. where in this case `O(x^5) = x^5/5! - x^7/7! + \cdots`. By the definition
  27. of `O`, there is a `\delta > 0` and an `M` such that:
  28. .. math ::
  29. |x^5/5! - x^7/7! + ....| <= M|x^5| \text{ for } |x| < \delta
  30. or by the alternate definition:
  31. .. math ::
  32. \lim_{x \rightarrow 0} | (x^5/5! - x^7/7! + ....) / x^5| < \infty
  33. which surely is true, because
  34. .. math ::
  35. \lim_{x \rightarrow 0} | (x^5/5! - x^7/7! + ....) / x^5| = 1/5!
  36. As it is usually used, the order of a function can be intuitively thought
  37. of representing all terms of powers greater than the one specified. For
  38. example, `O(x^3)` corresponds to any terms proportional to `x^3,
  39. x^4,\ldots` and any higher power. For a polynomial, this leaves terms
  40. proportional to `x^2`, `x` and constants.
  41. Examples
  42. ========
  43. >>> from sympy import O, oo, cos, pi
  44. >>> from sympy.abc import x, y
  45. >>> O(x + x**2)
  46. O(x)
  47. >>> O(x + x**2, (x, 0))
  48. O(x)
  49. >>> O(x + x**2, (x, oo))
  50. O(x**2, (x, oo))
  51. >>> O(1 + x*y)
  52. O(1, x, y)
  53. >>> O(1 + x*y, (x, 0), (y, 0))
  54. O(1, x, y)
  55. >>> O(1 + x*y, (x, oo), (y, oo))
  56. O(x*y, (x, oo), (y, oo))
  57. >>> O(1) in O(1, x)
  58. True
  59. >>> O(1, x) in O(1)
  60. False
  61. >>> O(x) in O(1, x)
  62. True
  63. >>> O(x**2) in O(x)
  64. True
  65. >>> O(x)*x
  66. O(x**2)
  67. >>> O(x) - O(x)
  68. O(x)
  69. >>> O(cos(x))
  70. O(1)
  71. >>> O(cos(x), (x, pi/2))
  72. O(x - pi/2, (x, pi/2))
  73. References
  74. ==========
  75. .. [1] `Big O notation <https://en.wikipedia.org/wiki/Big_O_notation>`_
  76. Notes
  77. =====
  78. In ``O(f(x), x)`` the expression ``f(x)`` is assumed to have a leading
  79. term. ``O(f(x), x)`` is automatically transformed to
  80. ``O(f(x).as_leading_term(x),x)``.
  81. ``O(expr*f(x), x)`` is ``O(f(x), x)``
  82. ``O(expr, x)`` is ``O(1)``
  83. ``O(0, x)`` is 0.
  84. Multivariate O is also supported:
  85. ``O(f(x, y), x, y)`` is transformed to
  86. ``O(f(x, y).as_leading_term(x,y).as_leading_term(y), x, y)``
  87. In the multivariate case, it is assumed the limits w.r.t. the various
  88. symbols commute.
  89. If no symbols are passed then all symbols in the expression are used
  90. and the limit point is assumed to be zero.
  91. """
  92. is_Order = True
  93. __slots__ = ()
  94. @cacheit
  95. def __new__(cls, expr, *args, **kwargs):
  96. expr = sympify(expr)
  97. if not args:
  98. if expr.is_Order:
  99. variables = expr.variables
  100. point = expr.point
  101. else:
  102. variables = list(expr.free_symbols)
  103. point = [S.Zero]*len(variables)
  104. else:
  105. args = list(args if is_sequence(args) else [args])
  106. variables, point = [], []
  107. if is_sequence(args[0]):
  108. for a in args:
  109. v, p = list(map(sympify, a))
  110. variables.append(v)
  111. point.append(p)
  112. else:
  113. variables = list(map(sympify, args))
  114. point = [S.Zero]*len(variables)
  115. if not all(v.is_symbol for v in variables):
  116. raise TypeError('Variables are not symbols, got %s' % variables)
  117. if len(list(uniq(variables))) != len(variables):
  118. raise ValueError('Variables are supposed to be unique symbols, got %s' % variables)
  119. if expr.is_Order:
  120. expr_vp = dict(expr.args[1:])
  121. new_vp = dict(expr_vp)
  122. vp = dict(zip(variables, point))
  123. for v, p in vp.items():
  124. if v in new_vp.keys():
  125. if p != new_vp[v]:
  126. raise NotImplementedError(
  127. "Mixing Order at different points is not supported.")
  128. else:
  129. new_vp[v] = p
  130. if set(expr_vp.keys()) == set(new_vp.keys()):
  131. return expr
  132. else:
  133. variables = list(new_vp.keys())
  134. point = [new_vp[v] for v in variables]
  135. if expr is S.NaN:
  136. return S.NaN
  137. if any(x in p.free_symbols for x in variables for p in point):
  138. raise ValueError('Got %s as a point.' % point)
  139. if variables:
  140. if any(p != point[0] for p in point):
  141. raise NotImplementedError(
  142. "Multivariable orders at different points are not supported.")
  143. if point[0] in (S.Infinity, S.Infinity*S.ImaginaryUnit):
  144. s = {k: 1/Dummy() for k in variables}
  145. rs = {1/v: 1/k for k, v in s.items()}
  146. ps = [S.Zero for p in point]
  147. elif point[0] in (S.NegativeInfinity, S.NegativeInfinity*S.ImaginaryUnit):
  148. s = {k: -1/Dummy() for k in variables}
  149. rs = {-1/v: -1/k for k, v in s.items()}
  150. ps = [S.Zero for p in point]
  151. elif point[0] is not S.Zero:
  152. s = {k: Dummy() + point[0] for k in variables}
  153. rs = {(v - point[0]).together(): k - point[0] for k, v in s.items()}
  154. ps = [S.Zero for p in point]
  155. else:
  156. s = ()
  157. rs = ()
  158. ps = list(point)
  159. expr = expr.subs(s)
  160. if expr.is_Add:
  161. expr = expr.factor()
  162. if s:
  163. args = tuple([r[0] for r in rs.items()])
  164. else:
  165. args = tuple(variables)
  166. if len(variables) > 1:
  167. # XXX: better way? We need this expand() to
  168. # workaround e.g: expr = x*(x + y).
  169. # (x*(x + y)).as_leading_term(x, y) currently returns
  170. # x*y (wrong order term!). That's why we want to deal with
  171. # expand()'ed expr (handled in "if expr.is_Add" branch below).
  172. expr = expr.expand()
  173. old_expr = None
  174. while old_expr != expr:
  175. old_expr = expr
  176. if expr.is_Add:
  177. lst = expr.extract_leading_order(args)
  178. expr = Add(*[f.expr for (e, f) in lst])
  179. elif expr:
  180. try:
  181. expr = expr.as_leading_term(*args)
  182. except PoleError:
  183. if isinstance(expr, Function) or\
  184. all(isinstance(arg, Function) for arg in expr.args):
  185. # It is not possible to simplify an expression
  186. # containing only functions (which raise error on
  187. # call to leading term) further
  188. pass
  189. else:
  190. orders = []
  191. pts = tuple(zip(args, ps))
  192. for arg in expr.args:
  193. try:
  194. lt = arg.as_leading_term(*args)
  195. except PoleError:
  196. lt = arg
  197. if lt not in args:
  198. order = Order(lt)
  199. else:
  200. order = Order(lt, *pts)
  201. orders.append(order)
  202. if expr.is_Add:
  203. new_expr = Order(Add(*orders), *pts)
  204. if new_expr.is_Add:
  205. new_expr = Order(Add(*[a.expr for a in new_expr.args]), *pts)
  206. expr = new_expr.expr
  207. elif expr.is_Mul:
  208. expr = Mul(*[a.expr for a in orders])
  209. elif expr.is_Pow:
  210. e = expr.exp
  211. b = expr.base
  212. expr = exp(e * log(b))
  213. # It would probably be better to handle this somewhere
  214. # else. This is needed for a testcase in which there is a
  215. # symbol with the assumptions zero=True.
  216. if expr.is_zero:
  217. expr = S.Zero
  218. else:
  219. expr = expr.as_independent(*args, as_Add=False)[1]
  220. expr = expand_power_base(expr)
  221. expr = expand_log(expr)
  222. if len(args) == 1:
  223. # The definition of O(f(x)) symbol explicitly stated that
  224. # the argument of f(x) is irrelevant. That's why we can
  225. # combine some power exponents (only "on top" of the
  226. # expression tree for f(x)), e.g.:
  227. # x**p * (-x)**q -> x**(p+q) for real p, q.
  228. x = args[0]
  229. margs = list(Mul.make_args(
  230. expr.as_independent(x, as_Add=False)[1]))
  231. for i, t in enumerate(margs):
  232. if t.is_Pow:
  233. b, q = t.args
  234. if b in (x, -x) and q.is_real and not q.has(x):
  235. margs[i] = x**q
  236. elif b.is_Pow and not b.exp.has(x):
  237. b, r = b.args
  238. if b in (x, -x) and r.is_real:
  239. margs[i] = x**(r*q)
  240. elif b.is_Mul and b.args[0] is S.NegativeOne:
  241. b = -b
  242. if b.is_Pow and not b.exp.has(x):
  243. b, r = b.args
  244. if b in (x, -x) and r.is_real:
  245. margs[i] = x**(r*q)
  246. expr = Mul(*margs)
  247. expr = expr.subs(rs)
  248. if expr.is_Order:
  249. expr = expr.expr
  250. if not expr.has(*variables) and not expr.is_zero:
  251. expr = S.One
  252. # create Order instance:
  253. vp = dict(zip(variables, point))
  254. variables.sort(key=default_sort_key)
  255. point = [vp[v] for v in variables]
  256. args = (expr,) + Tuple(*zip(variables, point))
  257. obj = Expr.__new__(cls, *args)
  258. return obj
  259. def _eval_nseries(self, x, n, logx, cdir=0):
  260. return self
  261. @property
  262. def expr(self):
  263. return self.args[0]
  264. @property
  265. def variables(self):
  266. if self.args[1:]:
  267. return tuple(x[0] for x in self.args[1:])
  268. else:
  269. return ()
  270. @property
  271. def point(self):
  272. if self.args[1:]:
  273. return tuple(x[1] for x in self.args[1:])
  274. else:
  275. return ()
  276. @property
  277. def free_symbols(self):
  278. return self.expr.free_symbols | set(self.variables)
  279. def _eval_power(b, e):
  280. if e.is_Number and e.is_nonnegative:
  281. return b.func(b.expr ** e, *b.args[1:])
  282. if e == O(1):
  283. return b
  284. return
  285. def as_expr_variables(self, order_symbols):
  286. if order_symbols is None:
  287. order_symbols = self.args[1:]
  288. else:
  289. if (not all(o[1] == order_symbols[0][1] for o in order_symbols) and
  290. not all(p == self.point[0] for p in self.point)): # pragma: no cover
  291. raise NotImplementedError('Order at points other than 0 '
  292. 'or oo not supported, got %s as a point.' % self.point)
  293. if order_symbols and order_symbols[0][1] != self.point[0]:
  294. raise NotImplementedError(
  295. "Multiplying Order at different points is not supported.")
  296. order_symbols = dict(order_symbols)
  297. for s, p in dict(self.args[1:]).items():
  298. if s not in order_symbols.keys():
  299. order_symbols[s] = p
  300. order_symbols = sorted(order_symbols.items(), key=lambda x: default_sort_key(x[0]))
  301. return self.expr, tuple(order_symbols)
  302. def removeO(self):
  303. return S.Zero
  304. def getO(self):
  305. return self
  306. @cacheit
  307. def contains(self, expr):
  308. r"""
  309. Return True if expr belongs to Order(self.expr, \*self.variables).
  310. Return False if self belongs to expr.
  311. Return None if the inclusion relation cannot be determined
  312. (e.g. when self and expr have different symbols).
  313. """
  314. expr = sympify(expr)
  315. if expr.is_zero:
  316. return True
  317. if expr is S.NaN:
  318. return False
  319. point = self.point[0] if self.point else S.Zero
  320. if expr.is_Order:
  321. if (any(p != point for p in expr.point) or
  322. any(p != point for p in self.point)):
  323. return None
  324. if expr.expr == self.expr:
  325. # O(1) + O(1), O(1) + O(1, x), etc.
  326. return all(x in self.args[1:] for x in expr.args[1:])
  327. if expr.expr.is_Add:
  328. return all(self.contains(x) for x in expr.expr.args)
  329. if self.expr.is_Add and point.is_zero:
  330. return any(self.func(x, *self.args[1:]).contains(expr)
  331. for x in self.expr.args)
  332. if self.variables and expr.variables:
  333. common_symbols = tuple(
  334. [s for s in self.variables if s in expr.variables])
  335. elif self.variables:
  336. common_symbols = self.variables
  337. else:
  338. common_symbols = expr.variables
  339. if not common_symbols:
  340. return None
  341. if (self.expr.is_Pow and len(self.variables) == 1
  342. and self.variables == expr.variables):
  343. symbol = self.variables[0]
  344. other = expr.expr.as_independent(symbol, as_Add=False)[1]
  345. if (other.is_Pow and other.base == symbol and
  346. self.expr.base == symbol):
  347. if point.is_zero:
  348. rv = (self.expr.exp - other.exp).is_nonpositive
  349. if point.is_infinite:
  350. rv = (self.expr.exp - other.exp).is_nonnegative
  351. if rv is not None:
  352. return rv
  353. from sympy.simplify.powsimp import powsimp
  354. r = None
  355. ratio = self.expr/expr.expr
  356. ratio = powsimp(ratio, deep=True, combine='exp')
  357. for s in common_symbols:
  358. from sympy.series.limits import Limit
  359. l = Limit(ratio, s, point).doit(heuristics=False)
  360. if not isinstance(l, Limit):
  361. l = l != 0
  362. else:
  363. l = None
  364. if r is None:
  365. r = l
  366. else:
  367. if r != l:
  368. return
  369. return r
  370. if self.expr.is_Pow and len(self.variables) == 1:
  371. symbol = self.variables[0]
  372. other = expr.as_independent(symbol, as_Add=False)[1]
  373. if (other.is_Pow and other.base == symbol and
  374. self.expr.base == symbol):
  375. if point.is_zero:
  376. rv = (self.expr.exp - other.exp).is_nonpositive
  377. if point.is_infinite:
  378. rv = (self.expr.exp - other.exp).is_nonnegative
  379. if rv is not None:
  380. return rv
  381. obj = self.func(expr, *self.args[1:])
  382. return self.contains(obj)
  383. def __contains__(self, other):
  384. result = self.contains(other)
  385. if result is None:
  386. raise TypeError('contains did not evaluate to a bool')
  387. return result
  388. def _eval_subs(self, old, new):
  389. if old in self.variables:
  390. newexpr = self.expr.subs(old, new)
  391. i = self.variables.index(old)
  392. newvars = list(self.variables)
  393. newpt = list(self.point)
  394. if new.is_symbol:
  395. newvars[i] = new
  396. else:
  397. syms = new.free_symbols
  398. if len(syms) == 1 or old in syms:
  399. if old in syms:
  400. var = self.variables[i]
  401. else:
  402. var = syms.pop()
  403. # First, try to substitute self.point in the "new"
  404. # expr to see if this is a fixed point.
  405. # E.g. O(y).subs(y, sin(x))
  406. point = new.subs(var, self.point[i])
  407. if point != self.point[i]:
  408. from sympy.solvers.solveset import solveset
  409. d = Dummy()
  410. sol = solveset(old - new.subs(var, d), d)
  411. if isinstance(sol, Complement):
  412. e1 = sol.args[0]
  413. e2 = sol.args[1]
  414. sol = set(e1) - set(e2)
  415. res = [dict(zip((d, ), sol))]
  416. point = d.subs(res[0]).limit(old, self.point[i])
  417. newvars[i] = var
  418. newpt[i] = point
  419. elif old not in syms:
  420. del newvars[i], newpt[i]
  421. if not syms and new == self.point[i]:
  422. newvars.extend(syms)
  423. newpt.extend([S.Zero]*len(syms))
  424. else:
  425. return
  426. return Order(newexpr, *zip(newvars, newpt))
  427. def _eval_conjugate(self):
  428. expr = self.expr._eval_conjugate()
  429. if expr is not None:
  430. return self.func(expr, *self.args[1:])
  431. def _eval_derivative(self, x):
  432. return self.func(self.expr.diff(x), *self.args[1:]) or self
  433. def _eval_transpose(self):
  434. expr = self.expr._eval_transpose()
  435. if expr is not None:
  436. return self.func(expr, *self.args[1:])
  437. def __neg__(self):
  438. return self
# Short public alias: ``O(expr, ...)`` is the same constructor as ``Order(expr, ...)``.
O = Order