rootoftools.py 40 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140114111421143114411451146114711481149115011511152115311541155115611571158115911601161116211631164116511661167116811691170117111721173117411751176117711781179118011811182118311841185118611871188118911901191119211931194119511961197119811991200120112021203120412051206120712081209121012111212121312141215121612171218121912201221122212231224122512261227122812291230123112321233123412351236123712381239124012411242
  1. """Implementation of RootOf class and related tools. """
  2. from sympy.core.basic import Basic
  3. from sympy.core import (S, Expr, Integer, Float, I, oo, Add, Lambda,
  4. symbols, sympify, Rational, Dummy)
  5. from sympy.core.cache import cacheit
  6. from sympy.core.relational import is_le
  7. from sympy.core.sorting import ordered
  8. from sympy.polys.domains import QQ
  9. from sympy.polys.polyerrors import (
  10. MultivariatePolynomialError,
  11. GeneratorsNeeded,
  12. PolynomialError,
  13. DomainError)
  14. from sympy.polys.polyfuncs import symmetrize, viete
  15. from sympy.polys.polyroots import (
  16. roots_linear, roots_quadratic, roots_binomial,
  17. preprocess_roots, roots)
  18. from sympy.polys.polytools import Poly, PurePoly, factor
  19. from sympy.polys.rationaltools import together
  20. from sympy.polys.rootisolation import (
  21. dup_isolate_complex_roots_sqf,
  22. dup_isolate_real_roots_sqf)
  23. from sympy.utilities import lambdify, public, sift, numbered_symbols
  24. from mpmath import mpf, mpc, findroot, workprec
  25. from mpmath.libmp.libmpf import dps_to_prec, prec_to_dps
  26. from sympy.multipledispatch import dispatch
  27. from itertools import chain
# Names exported by ``from sympy.polys.rootoftools import *``.
# NOTE(review): the ``@public`` decorator applied to definitions below
# presumably adds further names to ``__all__`` -- confirm against
# ``sympy.utilities.public``.
__all__ = ['CRootOf']
  29. class _pure_key_dict:
  30. """A minimal dictionary that makes sure that the key is a
  31. univariate PurePoly instance.
  32. Examples
  33. ========
  34. Only the following actions are guaranteed:
  35. >>> from sympy.polys.rootoftools import _pure_key_dict
  36. >>> from sympy import PurePoly
  37. >>> from sympy.abc import x, y
  38. 1) creation
  39. >>> P = _pure_key_dict()
  40. 2) assignment for a PurePoly or univariate polynomial
  41. >>> P[x] = 1
  42. >>> P[PurePoly(x - y, x)] = 2
  43. 3) retrieval based on PurePoly key comparison (use this
  44. instead of the get method)
  45. >>> P[y]
  46. 1
  47. 4) KeyError when trying to retrieve a nonexisting key
  48. >>> P[y + 1]
  49. Traceback (most recent call last):
  50. ...
  51. KeyError: PurePoly(y + 1, y, domain='ZZ')
  52. 5) ability to query with ``in``
  53. >>> x + 1 in P
  54. False
  55. NOTE: this is a *not* a dictionary. It is a very basic object
  56. for internal use that makes sure to always address its cache
  57. via PurePoly instances. It does not, for example, implement
  58. ``get`` or ``setdefault``.
  59. """
  60. def __init__(self):
  61. self._dict = {}
  62. def __getitem__(self, k):
  63. if not isinstance(k, PurePoly):
  64. if not (isinstance(k, Expr) and len(k.free_symbols) == 1):
  65. raise KeyError
  66. k = PurePoly(k, expand=False)
  67. return self._dict[k]
  68. def __setitem__(self, k, v):
  69. if not isinstance(k, PurePoly):
  70. if not (isinstance(k, Expr) and len(k.free_symbols) == 1):
  71. raise ValueError('expecting univariate expression')
  72. k = PurePoly(k, expand=False)
  73. self._dict[k] = v
  74. def __contains__(self, k):
  75. try:
  76. self[k]
  77. return True
  78. except KeyError:
  79. return False
# Module-level caches of root-isolating intervals, keyed by polynomial
# (normalized to PurePoly by ``_pure_key_dict``). ``_reals_cache`` maps a
# factor to its real isolating intervals, ``_complexes_cache`` to its
# complex (non-real) ones. Both names are rebound to fresh, empty instances
# by ``ComplexRootOf.clear_cache``, so code must look them up at call time.
_reals_cache = _pure_key_dict()
_complexes_cache = _pure_key_dict()
  82. def _pure_factors(poly):
  83. _, factors = poly.factor_list()
  84. return [(PurePoly(f, expand=False), m) for f, m in factors]
  85. def _imag_count_of_factor(f):
  86. """Return the number of imaginary roots for irreducible
  87. univariate polynomial ``f``.
  88. """
  89. terms = [(i, j) for (i,), j in f.terms()]
  90. if any(i % 2 for i, j in terms):
  91. return 0
  92. # update signs
  93. even = [(i, I**i*j) for i, j in terms]
  94. even = Poly.from_dict(dict(even), Dummy('x'))
  95. return int(even.count_roots(-oo, oo))
@public
def rootof(f, x, index=None, radicals=True, expand=True):
    """An indexed root of a univariate polynomial.

    Returns either a :obj:`ComplexRootOf` object or an explicit
    expression involving radicals.

    Parameters
    ==========

    f : Expr
        Univariate polynomial.
    x : Symbol, optional
        Generator for ``f``. When ``index`` is omitted, the integer root
        index may be passed in this position instead.
    index : int or Integer
        Position of the wanted root in the fixed root ordering.
    radicals : bool
        Return a radical expression if possible.
    expand : bool
        Expand ``f``.
    """
    # Thin wrapper around CRootOf; note ``ComplexRootOf.__new__`` itself
    # defaults ``radicals`` to False, whereas this function defaults to True.
    return CRootOf(f, x, index=index, radicals=radicals, expand=expand)
  114. @public
  115. class RootOf(Expr):
  116. """Represents a root of a univariate polynomial.
  117. Base class for roots of different kinds of polynomials.
  118. Only complex roots are currently supported.
  119. """
  120. __slots__ = ('poly',)
  121. def __new__(cls, f, x, index=None, radicals=True, expand=True):
  122. """Construct a new ``CRootOf`` object for ``k``-th root of ``f``."""
  123. return rootof(f, x, index=index, radicals=radicals, expand=expand)
  124. @public
  125. class ComplexRootOf(RootOf):
  126. """Represents an indexed complex root of a polynomial.
  127. Roots of a univariate polynomial separated into disjoint
  128. real or complex intervals and indexed in a fixed order:
  129. * real roots come first and are sorted in increasing order;
  130. * complex roots come next and are sorted primarily by increasing
  131. real part, secondarily by increasing imaginary part.
  132. Currently only rational coefficients are allowed.
  133. Can be imported as ``CRootOf``. To avoid confusion, the
  134. generator must be a Symbol.
  135. Examples
  136. ========
  137. >>> from sympy import CRootOf, rootof
  138. >>> from sympy.abc import x
  139. CRootOf is a way to reference a particular root of a
  140. polynomial. If there is a rational root, it will be returned:
  141. >>> CRootOf.clear_cache() # for doctest reproducibility
  142. >>> CRootOf(x**2 - 4, 0)
  143. -2
  144. Whether roots involving radicals are returned or not
  145. depends on whether the ``radicals`` flag is true (which is
  146. set to True with rootof):
  147. >>> CRootOf(x**2 - 3, 0)
  148. CRootOf(x**2 - 3, 0)
  149. >>> CRootOf(x**2 - 3, 0, radicals=True)
  150. -sqrt(3)
  151. >>> rootof(x**2 - 3, 0)
  152. -sqrt(3)
  153. The following cannot be expressed in terms of radicals:
  154. >>> r = rootof(4*x**5 + 16*x**3 + 12*x**2 + 7, 0); r
  155. CRootOf(4*x**5 + 16*x**3 + 12*x**2 + 7, 0)
  156. The root bounds can be seen, however, and they are used by the
  157. evaluation methods to get numerical approximations for the root.
  158. >>> interval = r._get_interval(); interval
  159. (-1, 0)
  160. >>> r.evalf(2)
  161. -0.98
  162. The evalf method refines the width of the root bounds until it
  163. guarantees that any decimal approximation within those bounds
  164. will satisfy the desired precision. It then stores the refined
  165. interval so subsequent requests at or below the requested
  166. precision will not have to recompute the root bounds and will
  167. return very quickly.
  168. Before evaluation above, the interval was
  169. >>> interval
  170. (-1, 0)
  171. After evaluation it is now
  172. >>> r._get_interval() # doctest: +SKIP
  173. (-165/169, -206/211)
  174. To reset all intervals for a given polynomial, the :meth:`_reset` method
  175. can be called from any CRootOf instance of the polynomial:
  176. >>> r._reset()
  177. >>> r._get_interval()
  178. (-1, 0)
  179. The :meth:`eval_approx` method will also find the root to a given
  180. precision but the interval is not modified unless the search
  181. for the root fails to converge within the root bounds. The
  182. secant method is used to find the root. (The ``evalf``
  183. method uses bisection and will always update the interval.)
  184. >>> r.eval_approx(2)
  185. -0.98
  186. The interval needed to be slightly updated to find that root:
  187. >>> r._get_interval()
  188. (-1, -1/2)
  189. The ``eval_rational`` method will compute a rational approximation
  190. of the root to the desired accuracy or precision.
  191. >>> r.eval_rational(n=2)
  192. -69629/71318
  193. >>> t = CRootOf(x**3 + 10*x + 1, 1)
  194. >>> t.eval_rational(1e-1)
  195. 15/256 - 805*I/256
  196. >>> t.eval_rational(1e-1, 1e-4)
  197. 3275/65536 - 414645*I/131072
  198. >>> t.eval_rational(1e-4, 1e-4)
  199. 6545/131072 - 414645*I/131072
  200. >>> t.eval_rational(n=2)
  201. 104755/2097152 - 6634255*I/2097152
  202. Notes
  203. =====
  204. Although a PurePoly can be constructed from a non-symbol generator
  205. RootOf instances of non-symbols are disallowed to avoid confusion
  206. over what root is being represented.
  207. >>> from sympy import exp, PurePoly
  208. >>> PurePoly(x) == PurePoly(exp(x))
  209. True
  210. >>> CRootOf(x - 1, 0)
  211. 1
  212. >>> CRootOf(exp(x) - 1, 0) # would correspond to x == 0
  213. Traceback (most recent call last):
  214. ...
  215. sympy.polys.polyerrors.PolynomialError: generator must be a Symbol
  216. See Also
  217. ========
  218. eval_approx
  219. eval_rational
  220. """
  221. __slots__ = ('index',)
  222. is_complex = True
  223. is_number = True
  224. is_finite = True
  225. def __new__(cls, f, x, index=None, radicals=False, expand=True):
  226. """ Construct an indexed complex root of a polynomial.
  227. See ``rootof`` for the parameters.
  228. The default value of ``radicals`` is ``False`` to satisfy
  229. ``eval(srepr(expr) == expr``.
  230. """
  231. x = sympify(x)
  232. if index is None and x.is_Integer:
  233. x, index = None, x
  234. else:
  235. index = sympify(index)
  236. if index is not None and index.is_Integer:
  237. index = int(index)
  238. else:
  239. raise ValueError("expected an integer root index, got %s" % index)
  240. poly = PurePoly(f, x, greedy=False, expand=expand)
  241. if not poly.is_univariate:
  242. raise PolynomialError("only univariate polynomials are allowed")
  243. if not poly.gen.is_Symbol:
  244. # PurePoly(sin(x) + 1) == PurePoly(x + 1) but the roots of
  245. # x for each are not the same: issue 8617
  246. raise PolynomialError("generator must be a Symbol")
  247. degree = poly.degree()
  248. if degree <= 0:
  249. raise PolynomialError("Cannot construct CRootOf object for %s" % f)
  250. if index < -degree or index >= degree:
  251. raise IndexError("root index out of [%d, %d] range, got %d" %
  252. (-degree, degree - 1, index))
  253. elif index < 0:
  254. index += degree
  255. dom = poly.get_domain()
  256. if not dom.is_Exact:
  257. poly = poly.to_exact()
  258. roots = cls._roots_trivial(poly, radicals)
  259. if roots is not None:
  260. return roots[index]
  261. coeff, poly = preprocess_roots(poly)
  262. dom = poly.get_domain()
  263. if not dom.is_ZZ:
  264. raise NotImplementedError("CRootOf is not supported over %s" % dom)
  265. root = cls._indexed_root(poly, index, lazy=True)
  266. return coeff * cls._postprocess_root(root, radicals)
    @classmethod
    def _new(cls, poly, index):
        """Construct new ``CRootOf`` object from raw data.

        No validation is performed: ``poly`` is stored as a ``PurePoly``
        and ``index`` is stored exactly as given.
        """
        obj = Expr.__new__(cls)
        obj.poly = PurePoly(poly)
        obj.index = index
        # Copy any cached isolating intervals for ``poly`` to the PurePoly
        # key. A KeyError (nothing cached, or ``poly`` is not a valid
        # _pure_key_dict key) is deliberately ignored; if only the complexes
        # lookup fails, the reals entry has already been copied.
        try:
            _reals_cache[obj.poly] = _reals_cache[poly]
            _complexes_cache[obj.poly] = _complexes_cache[poly]
        except KeyError:
            pass
        return obj
    def _hashable_content(self):
        # Identity of a root: its PurePoly and its index.
        # NOTE(review): presumably used by Basic for hashing/equality -- confirm.
        return (self.poly, self.index)
    @property
    def expr(self):
        """The defining polynomial ``self.poly`` as an expression."""
        return self.poly.as_expr()
    @property
    def args(self):
        """``(expr, Integer(index))`` -- the arguments ``self.func`` accepts.

        ``_eval_conjugate`` unpacks this pair and rebuilds a root from it.
        """
        return (self.expr, Integer(self.index))
    @property
    def free_symbols(self):
        """Always the empty set: a root is a number."""
        # CRootOf currently only works with univariate expressions
        # whose poly attribute should be a PurePoly with no free
        # symbols
        return set()
  293. def _eval_is_real(self):
  294. """Return ``True`` if the root is real. """
  295. self._ensure_reals_init()
  296. return self.index < len(_reals_cache[self.poly])
  297. def _eval_is_imaginary(self):
  298. """Return ``True`` if the root is imaginary. """
  299. self._ensure_reals_init()
  300. if self.index >= len(_reals_cache[self.poly]):
  301. ivl = self._get_interval()
  302. return ivl.ax*ivl.bx <= 0 # all others are on one side or the other
  303. return False # XXX is this necessary?
    @classmethod
    def real_roots(cls, poly, radicals=True):
        """Get real roots of a polynomial.

        Returns a list (with multiplicity) of explicit expressions or
        ``CRootOf`` objects; see ``_get_roots``.
        """
        return cls._get_roots("_real_roots", poly, radicals)
    @classmethod
    def all_roots(cls, poly, radicals=True):
        """Get real and complex roots of a polynomial.

        Real roots are listed first, then non-real roots (with
        multiplicity); see ``_get_roots``.
        """
        return cls._get_roots("_all_roots", poly, radicals)
  312. @classmethod
  313. def _get_reals_sqf(cls, currentfactor, use_cache=True):
  314. """Get real root isolating intervals for a square-free factor."""
  315. if use_cache and currentfactor in _reals_cache:
  316. real_part = _reals_cache[currentfactor]
  317. else:
  318. _reals_cache[currentfactor] = real_part = \
  319. dup_isolate_real_roots_sqf(
  320. currentfactor.rep.rep, currentfactor.rep.dom, blackbox=True)
  321. return real_part
  322. @classmethod
  323. def _get_complexes_sqf(cls, currentfactor, use_cache=True):
  324. """Get complex root isolating intervals for a square-free factor."""
  325. if use_cache and currentfactor in _complexes_cache:
  326. complex_part = _complexes_cache[currentfactor]
  327. else:
  328. _complexes_cache[currentfactor] = complex_part = \
  329. dup_isolate_complex_roots_sqf(
  330. currentfactor.rep.rep, currentfactor.rep.dom, blackbox=True)
  331. return complex_part
  332. @classmethod
  333. def _get_reals(cls, factors, use_cache=True):
  334. """Compute real root isolating intervals for a list of factors. """
  335. reals = []
  336. for currentfactor, k in factors:
  337. try:
  338. if not use_cache:
  339. raise KeyError
  340. r = _reals_cache[currentfactor]
  341. reals.extend([(i, currentfactor, k) for i in r])
  342. except KeyError:
  343. real_part = cls._get_reals_sqf(currentfactor, use_cache)
  344. new = [(root, currentfactor, k) for root in real_part]
  345. reals.extend(new)
  346. reals = cls._reals_sorted(reals)
  347. return reals
  348. @classmethod
  349. def _get_complexes(cls, factors, use_cache=True):
  350. """Compute complex root isolating intervals for a list of factors. """
  351. complexes = []
  352. for currentfactor, k in ordered(factors):
  353. try:
  354. if not use_cache:
  355. raise KeyError
  356. c = _complexes_cache[currentfactor]
  357. complexes.extend([(i, currentfactor, k) for i in c])
  358. except KeyError:
  359. complex_part = cls._get_complexes_sqf(currentfactor, use_cache)
  360. new = [(root, currentfactor, k) for root in complex_part]
  361. complexes.extend(new)
  362. complexes = cls._complexes_sorted(complexes)
  363. return complexes
  364. @classmethod
  365. def _reals_sorted(cls, reals):
  366. """Make real isolating intervals disjoint and sort roots. """
  367. cache = {}
  368. for i, (u, f, k) in enumerate(reals):
  369. for j, (v, g, m) in enumerate(reals[i + 1:]):
  370. u, v = u.refine_disjoint(v)
  371. reals[i + j + 1] = (v, g, m)
  372. reals[i] = (u, f, k)
  373. reals = sorted(reals, key=lambda r: r[0].a)
  374. for root, currentfactor, _ in reals:
  375. if currentfactor in cache:
  376. cache[currentfactor].append(root)
  377. else:
  378. cache[currentfactor] = [root]
  379. for currentfactor, root in cache.items():
  380. _reals_cache[currentfactor] = root
  381. return reals
    @classmethod
    def _refine_imaginary(cls, complexes):
        """Refine complex intervals so only imaginary roots touch ``x = 0``.

        ``complexes`` is a list of ``(interval, factor, multiplicity)``
        triples. For each factor (in ``ordered`` order), intervals are
        refined until all but ``_imag_count_of_factor(factor)`` of them have
        x-bounds strictly positive or strictly negative. The triples are
        returned grouped by factor.
        """
        # group the triples by factor
        sifted = sift(complexes, lambda c: c[1])
        complexes = []
        for f in ordered(sifted):
            nimag = _imag_count_of_factor(f)
            if nimag == 0:
                # refine until xbounds are neg or pos
                # (``f`` is rebound to the triple's factor, equal to the key)
                for u, f, k in sifted[f]:
                    while u.ax*u.bx <= 0:
                        u = u._inner_refine()
                    complexes.append((u, f, k))
            else:
                # refine until all but nimag xbounds are neg or pos
                potential_imag = list(range(len(sifted[f])))
                while True:
                    assert len(potential_imag) > 1
                    for i in list(potential_imag):
                        u, f, k = sifted[f][i]
                        if u.ax*u.bx > 0:
                            # x-bounds off the imaginary axis: not imaginary
                            potential_imag.remove(i)
                        elif u.ax != u.bx:
                            u = u._inner_refine()
                            sifted[f][i] = u, f, k
                    if len(potential_imag) == nimag:
                        break
                complexes.extend(sifted[f])
        return complexes
    @classmethod
    def _refine_complexes(cls, complexes):
        """return complexes such that no bounding rectangles of non-conjugate
        roots would intersect. In addition, assure that neither ay nor by is
        0 to guarantee that non-real roots are distinct from real roots in
        terms of the y-bounds.

        The first (pairwise-disjoint) phase modifies the passed list in
        place; the returned list comes from ``_refine_imaginary``.
        """
        # get the intervals pairwise-disjoint.
        # If rectangles were drawn around the coordinates of the bounding
        # rectangles, no rectangles would intersect after this procedure.
        for i, (u, f, k) in enumerate(complexes):
            for j, (v, g, m) in enumerate(complexes[i + 1:]):
                u, v = u.refine_disjoint(v)
                complexes[i + j + 1] = (v, g, m)
            complexes[i] = (u, f, k)
        # refine until the x-bounds are unambiguously positive or negative
        # for non-imaginary roots
        complexes = cls._refine_imaginary(complexes)
        # make sure that all y bounds are off the real axis
        # and on the same side of the axis
        for i, (u, f, k) in enumerate(complexes):
            while u.ay*u.by <= 0:
                u = u.refine()
            complexes[i] = u, f, k
        return complexes
    @classmethod
    def _complexes_sorted(cls, complexes):
        """Make complex isolating intervals disjoint and sort roots.

        Despite the name, no explicit sort is performed (see the XXX note
        below): the order produced by ``_refine_complexes`` is kept and only
        checked. The intervals are then stored in ``_complexes_cache``, one
        list per factor.
        """
        complexes = cls._refine_complexes(complexes)
        # XXX don't sort until you are sure that it is compatible
        # with the indexing method but assert that the desired state
        # is not broken
        C, F = 0, 1  # location of ComplexInterval and factor
        fs = {i[F] for i in complexes}
        for i in range(1, len(complexes)):
            if complexes[i][F] != complexes[i - 1][F]:
                # if this fails the factors of a root were not
                # contiguous because a discontinuity should only
                # happen once
                # (set.remove raises KeyError for a factor already removed)
                fs.remove(complexes[i - 1][F])
        for i, cmplx in enumerate(complexes):
            # negative im part (conj=True) comes before
            # positive im part (conj=False)
            assert cmplx[C].conj is (i % 2 == 0)
        # update cache
        cache = {}
        # -- collate
        for root, currentfactor, _ in complexes:
            cache.setdefault(currentfactor, []).append(root)
        # -- store
        for currentfactor, root in cache.items():
            _complexes_cache[currentfactor] = root
        return complexes
  463. @classmethod
  464. def _reals_index(cls, reals, index):
  465. """
  466. Map initial real root index to an index in a factor where
  467. the root belongs.
  468. """
  469. i = 0
  470. for j, (_, currentfactor, k) in enumerate(reals):
  471. if index < i + k:
  472. poly, index = currentfactor, 0
  473. for _, currentfactor, _ in reals[:j]:
  474. if currentfactor == poly:
  475. index += 1
  476. return poly, index
  477. else:
  478. i += k
  479. @classmethod
  480. def _complexes_index(cls, complexes, index):
  481. """
  482. Map initial complex root index to an index in a factor where
  483. the root belongs.
  484. """
  485. i = 0
  486. for j, (_, currentfactor, k) in enumerate(complexes):
  487. if index < i + k:
  488. poly, index = currentfactor, 0
  489. for _, currentfactor, _ in complexes[:j]:
  490. if currentfactor == poly:
  491. index += 1
  492. index += len(_reals_cache[poly])
  493. return poly, index
  494. else:
  495. i += k
  496. @classmethod
  497. def _count_roots(cls, roots):
  498. """Count the number of real or complex roots with multiplicities."""
  499. return sum([k for _, _, k in roots])
  500. @classmethod
  501. def _indexed_root(cls, poly, index, lazy=False):
  502. """Get a root of a composite polynomial by index. """
  503. factors = _pure_factors(poly)
  504. # If the given poly is already irreducible, then the index does not
  505. # need to be adjusted, and we can postpone the heavy lifting of
  506. # computing and refining isolating intervals until that is needed.
  507. if lazy and len(factors) == 1 and factors[0][1] == 1:
  508. return poly, index
  509. reals = cls._get_reals(factors)
  510. reals_count = cls._count_roots(reals)
  511. if index < reals_count:
  512. return cls._reals_index(reals, index)
  513. else:
  514. complexes = cls._get_complexes(factors)
  515. return cls._complexes_index(complexes, index - reals_count)
  516. def _ensure_reals_init(self):
  517. """Ensure that our poly has entries in the reals cache. """
  518. if self.poly not in _reals_cache:
  519. self._indexed_root(self.poly, self.index)
  520. def _ensure_complexes_init(self):
  521. """Ensure that our poly has entries in the complexes cache. """
  522. if self.poly not in _complexes_cache:
  523. self._indexed_root(self.poly, self.index)
  524. @classmethod
  525. def _real_roots(cls, poly):
  526. """Get real roots of a composite polynomial. """
  527. factors = _pure_factors(poly)
  528. reals = cls._get_reals(factors)
  529. reals_count = cls._count_roots(reals)
  530. roots = []
  531. for index in range(0, reals_count):
  532. roots.append(cls._reals_index(reals, index))
  533. return roots
  534. def _reset(self):
  535. """
  536. Reset all intervals
  537. """
  538. self._all_roots(self.poly, use_cache=False)
  539. @classmethod
  540. def _all_roots(cls, poly, use_cache=True):
  541. """Get real and complex roots of a composite polynomial. """
  542. factors = _pure_factors(poly)
  543. reals = cls._get_reals(factors, use_cache=use_cache)
  544. reals_count = cls._count_roots(reals)
  545. roots = []
  546. for index in range(0, reals_count):
  547. roots.append(cls._reals_index(reals, index))
  548. complexes = cls._get_complexes(factors, use_cache=use_cache)
  549. complexes_count = cls._count_roots(complexes)
  550. for index in range(0, complexes_count):
  551. roots.append(cls._complexes_index(complexes, index))
  552. return roots
  553. @classmethod
  554. @cacheit
  555. def _roots_trivial(cls, poly, radicals):
  556. """Compute roots in linear, quadratic and binomial cases. """
  557. if poly.degree() == 1:
  558. return roots_linear(poly)
  559. if not radicals:
  560. return None
  561. if poly.degree() == 2:
  562. return roots_quadratic(poly)
  563. elif poly.length() == 2 and poly.TC():
  564. return roots_binomial(poly)
  565. else:
  566. return None
  567. @classmethod
  568. def _preprocess_roots(cls, poly):
  569. """Take heroic measures to make ``poly`` compatible with ``CRootOf``."""
  570. dom = poly.get_domain()
  571. if not dom.is_Exact:
  572. poly = poly.to_exact()
  573. coeff, poly = preprocess_roots(poly)
  574. dom = poly.get_domain()
  575. if not dom.is_ZZ:
  576. raise NotImplementedError(
  577. "sorted roots not supported over %s" % dom)
  578. return coeff, poly
  579. @classmethod
  580. def _postprocess_root(cls, root, radicals):
  581. """Return the root if it is trivial or a ``CRootOf`` object. """
  582. poly, index = root
  583. roots = cls._roots_trivial(poly, radicals)
  584. if roots is not None:
  585. return roots[index]
  586. else:
  587. return cls._new(poly, index)
  588. @classmethod
  589. def _get_roots(cls, method, poly, radicals):
  590. """Return postprocessed roots of specified kind. """
  591. if not poly.is_univariate:
  592. raise PolynomialError("only univariate polynomials are allowed")
  593. # get rid of gen and it's free symbol
  594. d = Dummy()
  595. poly = poly.subs(poly.gen, d)
  596. x = symbols('x')
  597. # see what others are left and select x or a numbered x
  598. # that doesn't clash
  599. free_names = {str(i) for i in poly.free_symbols}
  600. for x in chain((symbols('x'),), numbered_symbols('x')):
  601. if x.name not in free_names:
  602. poly = poly.xreplace({d: x})
  603. break
  604. coeff, poly = cls._preprocess_roots(poly)
  605. roots = []
  606. for root in getattr(cls, method)(poly):
  607. roots.append(coeff*cls._postprocess_root(root, radicals))
  608. return roots
  609. @classmethod
  610. def clear_cache(cls):
  611. """Reset cache for reals and complexes.
  612. The intervals used to approximate a root instance are updated
  613. as needed. When a request is made to see the intervals, the
  614. most current values are shown. `clear_cache` will reset all
  615. CRootOf instances back to their original state.
  616. See Also
  617. ========
  618. _reset
  619. """
  620. global _reals_cache, _complexes_cache
  621. _reals_cache = _pure_key_dict()
  622. _complexes_cache = _pure_key_dict()
  623. def _get_interval(self):
  624. """Internal function for retrieving isolation interval from cache. """
  625. self._ensure_reals_init()
  626. if self.is_real:
  627. return _reals_cache[self.poly][self.index]
  628. else:
  629. reals_count = len(_reals_cache[self.poly])
  630. self._ensure_complexes_init()
  631. return _complexes_cache[self.poly][self.index - reals_count]
  632. def _set_interval(self, interval):
  633. """Internal function for updating isolation interval in cache. """
  634. self._ensure_reals_init()
  635. if self.is_real:
  636. _reals_cache[self.poly][self.index] = interval
  637. else:
  638. reals_count = len(_reals_cache[self.poly])
  639. self._ensure_complexes_init()
  640. _complexes_cache[self.poly][self.index - reals_count] = interval
    def _eval_subs(self, old, new):
        # don't allow subs to change anything: the root has no free
        # symbols (see ``free_symbols``), so it is returned unchanged
        return self
  644. def _eval_conjugate(self):
  645. if self.is_real:
  646. return self
  647. expr, i = self.args
  648. return self.func(expr, i + (1 if self._get_interval().conj else -1))
    def eval_approx(self, n, return_mpmath=False):
        """Evaluate this complex root to the given precision.

        This uses secant method and root bounds are used to both
        generate an initial guess and to check that the root
        returned is valid. If ever the method converges outside the
        root bounds, the bounds will be made smaller and updated.

        Parameters
        ==========

        n : int
            Number of decimal digits (converted with ``dps_to_prec``).
        return_mpmath : bool
            If True, return the raw mpmath value instead of a SymPy
            expression.

        Returns
        =======

        ``Float + I*Float`` at the requested precision, or the mpmath
        value when ``return_mpmath`` is true. The final (possibly refined)
        interval is written back to the cache in either case.
        """
        prec = dps_to_prec(n)
        with workprec(prec):
            g = self.poly.gen
            if not g.is_Symbol:
                # NOTE(review): after ``d *= I`` the lambdify argument is the
                # product ``I*d``, which does not look like a valid lambdify
                # argument. ``__new__`` rejects non-Symbol generators, so this
                # branch may be unreachable -- confirm before relying on it.
                d = Dummy('x')
                if self.is_imaginary:
                    d *= I
                func = lambdify(d, self.expr.subs(g, d))
            else:
                expr = self.expr
                if self.is_imaginary:
                    # search along the imaginary axis: p(I*y) as a function
                    # of y, matched against the y-bounds below
                    expr = self.expr.subs(g, I*g)
                func = lambdify(g, expr)
            interval = self._get_interval()
            while True:
                if self.is_real:
                    a = mpf(str(interval.a))
                    b = mpf(str(interval.b))
                    if a == b:
                        # degenerate interval: root known exactly
                        root = a
                        break
                    x0 = mpf(str(interval.center))
                    x1 = x0 + mpf(str(interval.dx))/4
                elif self.is_imaginary:
                    a = mpf(str(interval.ay))
                    b = mpf(str(interval.by))
                    if a == b:
                        root = mpc(mpf('0'), a)
                        break
                    x0 = mpf(str(interval.center[1]))
                    x1 = x0 + mpf(str(interval.dy))/4
                else:
                    ax = mpf(str(interval.ax))
                    bx = mpf(str(interval.bx))
                    ay = mpf(str(interval.ay))
                    by = mpf(str(interval.by))
                    if ax == bx and ay == by:
                        root = mpc(ax, ay)
                        break
                    x0 = mpc(*map(str, interval.center))
                    x1 = x0 + mpc(*map(str, (interval.dx, interval.dy)))/4
                try:
                    # without a tolerance, this will return when (to within
                    # the given precision) x_i == x_{i-1}
                    root = findroot(func, (x0, x1))
                    # If the (real or complex) root is not in the 'interval',
                    # then keep refining the interval. This happens if findroot
                    # accidentally finds a different root outside of this
                    # interval because our initial estimate 'x0' was not close
                    # enough. It is also possible that the secant method will
                    # get trapped by a max/min in the interval; the root
                    # verification by findroot will raise a ValueError in this
                    # case and the interval will then be tightened -- and
                    # eventually the root will be found.
                    #
                    # It is also possible that findroot will not have any
                    # successful iterations to process (in which case it
                    # will fail to initialize a variable that is tested
                    # after the iterations and raise an UnboundLocalError).
                    if self.is_real or self.is_imaginary:
                        # NOTE: parses as ``not (bool(root.imag) == self.is_real)``.
                        # NOTE(review): for an imaginary root (is_real False)
                        # this accepts only a result with nonzero imaginary
                        # part, although the search variable is the real ``y``
                        # of ``I*y``; confirm this branch can terminate here
                        # rather than only via the ``a == b`` exit above.
                        if not bool(root.imag) == self.is_real and (
                                a <= root <= b):
                            if self.is_imaginary:
                                root = mpc(mpf('0'), root.real)
                            break
                    elif (ax <= root.real <= bx and ay <= root.imag <= by):
                        break
                except (UnboundLocalError, ValueError):
                    pass
                interval = interval.refine()
        # update the interval so we at least (for this precision or
        # less) don't have much work to do to recompute the root
        self._set_interval(interval)
        if return_mpmath:
            return root
        return (Float._new(root.real._mpf_, prec) +
                I*Float._new(root.imag._mpf_, prec))
  733. def _eval_evalf(self, prec, **kwargs):
  734. """Evaluate this complex root to the given precision."""
  735. # all kwargs are ignored
  736. return self.eval_rational(n=prec_to_dps(prec))._evalf(prec)
  737. def eval_rational(self, dx=None, dy=None, n=15):
  738. """
  739. Return a Rational approximation of ``self`` that has real
  740. and imaginary component approximations that are within ``dx``
  741. and ``dy`` of the true values, respectively. Alternatively,
  742. ``n`` digits of precision can be specified.
  743. The interval is refined with bisection and is sure to
  744. converge. The root bounds are updated when the refinement
  745. is complete so recalculation at the same or lesser precision
  746. will not have to repeat the refinement and should be much
  747. faster.
  748. The following example first obtains Rational approximation to
  749. 1e-8 accuracy for all roots of the 4-th order Legendre
  750. polynomial. Since the roots are all less than 1, this will
  751. ensure the decimal representation of the approximation will be
  752. correct (including rounding) to 6 digits:
  753. >>> from sympy import legendre_poly, Symbol
  754. >>> x = Symbol("x")
  755. >>> p = legendre_poly(4, x, polys=True)
  756. >>> r = p.real_roots()[-1]
  757. >>> r.eval_rational(10**-8).n(6)
  758. 0.861136
  759. It is not necessary to a two-step calculation, however: the
  760. decimal representation can be computed directly:
  761. >>> r.evalf(17)
  762. 0.86113631159405258
  763. """
  764. dy = dy or dx
  765. if dx:
  766. rtol = None
  767. dx = dx if isinstance(dx, Rational) else Rational(str(dx))
  768. dy = dy if isinstance(dy, Rational) else Rational(str(dy))
  769. else:
  770. # 5 binary (or 2 decimal) digits are needed to ensure that
  771. # a given digit is correctly rounded
  772. # prec_to_dps(dps_to_prec(n) + 5) - n <= 2 (tested for
  773. # n in range(1000000)
  774. rtol = S(10)**-(n + 2) # +2 for guard digits
  775. interval = self._get_interval()
  776. while True:
  777. if self.is_real:
  778. if rtol:
  779. dx = abs(interval.center*rtol)
  780. interval = interval.refine_size(dx=dx)
  781. c = interval.center
  782. real = Rational(c)
  783. imag = S.Zero
  784. if not rtol or interval.dx < abs(c*rtol):
  785. break
  786. elif self.is_imaginary:
  787. if rtol:
  788. dy = abs(interval.center[1]*rtol)
  789. dx = 1
  790. interval = interval.refine_size(dx=dx, dy=dy)
  791. c = interval.center[1]
  792. imag = Rational(c)
  793. real = S.Zero
  794. if not rtol or interval.dy < abs(c*rtol):
  795. break
  796. else:
  797. if rtol:
  798. dx = abs(interval.center[0]*rtol)
  799. dy = abs(interval.center[1]*rtol)
  800. interval = interval.refine_size(dx, dy)
  801. c = interval.center
  802. real, imag = map(Rational, c)
  803. if not rtol or (
  804. interval.dx < abs(c[0]*rtol) and
  805. interval.dy < abs(c[1]*rtol)):
  806. break
  807. # update the interval so we at least (for this precision or
  808. # less) don't have much work to do to recompute the root
  809. self._set_interval(interval)
  810. return real + I*imag
# Short alias for ComplexRootOf.
CRootOf = ComplexRootOf
@dispatch(ComplexRootOf, ComplexRootOf)
def _eval_is_eq(lhs, rhs): # noqa:F811
    """Decide equality of two ComplexRootOf instances using ``==``."""
    # if we use is_eq to check here, we get infinite recursion
    return lhs == rhs
  816. @dispatch(ComplexRootOf, Basic) # type:ignore
  817. def _eval_is_eq(lhs, rhs): # noqa:F811
  818. # CRootOf represents a Root, so if rhs is that root, it should set
  819. # the expression to zero *and* it should be in the interval of the
  820. # CRootOf instance. It must also be a number that agrees with the
  821. # is_real value of the CRootOf instance.
  822. if not rhs.is_number:
  823. return None
  824. if not rhs.is_finite:
  825. return False
  826. z = lhs.expr.subs(lhs.expr.free_symbols.pop(), rhs).is_zero
  827. if z is False: # all roots will make z True but we don't know
  828. # whether this is the right root if z is True
  829. return False
  830. o = rhs.is_real, rhs.is_imaginary
  831. s = lhs.is_real, lhs.is_imaginary
  832. assert None not in s # this is part of initial refinement
  833. if o != s and None not in o:
  834. return False
  835. re, im = rhs.as_real_imag()
  836. if lhs.is_real:
  837. if im:
  838. return False
  839. i = lhs._get_interval()
  840. a, b = [Rational(str(_)) for _ in (i.a, i.b)]
  841. return sympify(a <= rhs and rhs <= b)
  842. i = lhs._get_interval()
  843. r1, r2, i1, i2 = [Rational(str(j)) for j in (
  844. i.ax, i.bx, i.ay, i.by)]
  845. return is_le(r1, re) and is_le(re,r2) and is_le(i1,im) and is_le(im,i2)
  846. @public
  847. class RootSum(Expr):
  848. """Represents a sum of all roots of a univariate polynomial. """
  849. __slots__ = ('poly', 'fun', 'auto')
  850. def __new__(cls, expr, func=None, x=None, auto=True, quadratic=False):
  851. """Construct a new ``RootSum`` instance of roots of a polynomial."""
  852. coeff, poly = cls._transform(expr, x)
  853. if not poly.is_univariate:
  854. raise MultivariatePolynomialError(
  855. "only univariate polynomials are allowed")
  856. if func is None:
  857. func = Lambda(poly.gen, poly.gen)
  858. else:
  859. is_func = getattr(func, 'is_Function', False)
  860. if is_func and 1 in func.nargs:
  861. if not isinstance(func, Lambda):
  862. func = Lambda(poly.gen, func(poly.gen))
  863. else:
  864. raise ValueError(
  865. "expected a univariate function, got %s" % func)
  866. var, expr = func.variables[0], func.expr
  867. if coeff is not S.One:
  868. expr = expr.subs(var, coeff*var)
  869. deg = poly.degree()
  870. if not expr.has(var):
  871. return deg*expr
  872. if expr.is_Add:
  873. add_const, expr = expr.as_independent(var)
  874. else:
  875. add_const = S.Zero
  876. if expr.is_Mul:
  877. mul_const, expr = expr.as_independent(var)
  878. else:
  879. mul_const = S.One
  880. func = Lambda(var, expr)
  881. rational = cls._is_func_rational(poly, func)
  882. factors, terms = _pure_factors(poly), []
  883. for poly, k in factors:
  884. if poly.is_linear:
  885. term = func(roots_linear(poly)[0])
  886. elif quadratic and poly.is_quadratic:
  887. term = sum(map(func, roots_quadratic(poly)))
  888. else:
  889. if not rational or not auto:
  890. term = cls._new(poly, func, auto)
  891. else:
  892. term = cls._rational_case(poly, func)
  893. terms.append(k*term)
  894. return mul_const*Add(*terms) + deg*add_const
  895. @classmethod
  896. def _new(cls, poly, func, auto=True):
  897. """Construct new raw ``RootSum`` instance. """
  898. obj = Expr.__new__(cls)
  899. obj.poly = poly
  900. obj.fun = func
  901. obj.auto = auto
  902. return obj
  903. @classmethod
  904. def new(cls, poly, func, auto=True):
  905. """Construct new ``RootSum`` instance. """
  906. if not func.expr.has(*func.variables):
  907. return func.expr
  908. rational = cls._is_func_rational(poly, func)
  909. if not rational or not auto:
  910. return cls._new(poly, func, auto)
  911. else:
  912. return cls._rational_case(poly, func)
  913. @classmethod
  914. def _transform(cls, expr, x):
  915. """Transform an expression to a polynomial. """
  916. poly = PurePoly(expr, x, greedy=False)
  917. return preprocess_roots(poly)
  918. @classmethod
  919. def _is_func_rational(cls, poly, func):
  920. """Check if a lambda is a rational function. """
  921. var, expr = func.variables[0], func.expr
  922. return expr.is_rational_function(var)
  923. @classmethod
  924. def _rational_case(cls, poly, func):
  925. """Handle the rational function case. """
  926. roots = symbols('r:%d' % poly.degree())
  927. var, expr = func.variables[0], func.expr
  928. f = sum(expr.subs(var, r) for r in roots)
  929. p, q = together(f).as_numer_denom()
  930. domain = QQ[roots]
  931. p = p.expand()
  932. q = q.expand()
  933. try:
  934. p = Poly(p, domain=domain, expand=False)
  935. except GeneratorsNeeded:
  936. p, p_coeff = None, (p,)
  937. else:
  938. p_monom, p_coeff = zip(*p.terms())
  939. try:
  940. q = Poly(q, domain=domain, expand=False)
  941. except GeneratorsNeeded:
  942. q, q_coeff = None, (q,)
  943. else:
  944. q_monom, q_coeff = zip(*q.terms())
  945. coeffs, mapping = symmetrize(p_coeff + q_coeff, formal=True)
  946. formulas, values = viete(poly, roots), []
  947. for (sym, _), (_, val) in zip(mapping, formulas):
  948. values.append((sym, val))
  949. for i, (coeff, _) in enumerate(coeffs):
  950. coeffs[i] = coeff.subs(values)
  951. n = len(p_coeff)
  952. p_coeff = coeffs[:n]
  953. q_coeff = coeffs[n:]
  954. if p is not None:
  955. p = Poly(dict(zip(p_monom, p_coeff)), *p.gens).as_expr()
  956. else:
  957. (p,) = p_coeff
  958. if q is not None:
  959. q = Poly(dict(zip(q_monom, q_coeff)), *q.gens).as_expr()
  960. else:
  961. (q,) = q_coeff
  962. return factor(p/q)
  963. def _hashable_content(self):
  964. return (self.poly, self.fun)
  965. @property
  966. def expr(self):
  967. return self.poly.as_expr()
  968. @property
  969. def args(self):
  970. return (self.expr, self.fun, self.poly.gen)
  971. @property
  972. def free_symbols(self):
  973. return self.poly.free_symbols | self.fun.free_symbols
  974. @property
  975. def is_commutative(self):
  976. return True
  977. def doit(self, **hints):
  978. if not hints.get('roots', True):
  979. return self
  980. _roots = roots(self.poly, multiple=True)
  981. if len(_roots) < self.poly.degree():
  982. return self
  983. else:
  984. return Add(*[self.fun(r) for r in _roots])
  985. def _eval_evalf(self, prec):
  986. try:
  987. _roots = self.poly.nroots(n=prec_to_dps(prec))
  988. except (DomainError, PolynomialError):
  989. return self
  990. else:
  991. return Add(*[self.fun(r) for r in _roots])
  992. def _eval_derivative(self, x):
  993. var, expr = self.fun.args
  994. func = Lambda(var, expr.diff(x))
  995. return self.new(self.poly, func, self.auto)