_root_scalar.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502
"""
Unified interfaces to root finding algorithms for real or complex
scalar functions.

Functions
---------
- root_scalar : find a root of a scalar function.
"""
  8. import numpy as np
  9. from . import _zeros_py as optzeros
  10. __all__ = ['root_scalar']
  11. ROOT_SCALAR_METHODS = ['bisect', 'brentq', 'brenth', 'ridder', 'toms748',
  12. 'newton', 'secant', 'halley']
  13. class MemoizeDer:
  14. """Decorator that caches the value and derivative(s) of function each
  15. time it is called.
  16. This is a simplistic memoizer that calls and caches a single value
  17. of `f(x, *args)`.
  18. It assumes that `args` does not change between invocations.
  19. It supports the use case of a root-finder where `args` is fixed,
  20. `x` changes, and only rarely, if at all, does x assume the same value
  21. more than once."""
  22. def __init__(self, fun):
  23. self.fun = fun
  24. self.vals = None
  25. self.x = None
  26. self.n_calls = 0
  27. def __call__(self, x, *args):
  28. r"""Calculate f or use cached value if available"""
  29. # Derivative may be requested before the function itself, always check
  30. if self.vals is None or x != self.x:
  31. fg = self.fun(x, *args)
  32. self.x = x
  33. self.n_calls += 1
  34. self.vals = fg[:]
  35. return self.vals[0]
  36. def fprime(self, x, *args):
  37. r"""Calculate f' or use a cached value if available"""
  38. if self.vals is None or x != self.x:
  39. self(x, *args)
  40. return self.vals[1]
  41. def fprime2(self, x, *args):
  42. r"""Calculate f'' or use a cached value if available"""
  43. if self.vals is None or x != self.x:
  44. self(x, *args)
  45. return self.vals[2]
  46. def ncalls(self):
  47. return self.n_calls
  48. def root_scalar(f, args=(), method=None, bracket=None,
  49. fprime=None, fprime2=None,
  50. x0=None, x1=None,
  51. xtol=None, rtol=None, maxiter=None,
  52. options=None):
  53. """
  54. Find a root of a scalar function.
  55. Parameters
  56. ----------
  57. f : callable
  58. A function to find a root of.
  59. args : tuple, optional
  60. Extra arguments passed to the objective function and its derivative(s).
  61. method : str, optional
  62. Type of solver. Should be one of
  63. - 'bisect' :ref:`(see here) <optimize.root_scalar-bisect>`
  64. - 'brentq' :ref:`(see here) <optimize.root_scalar-brentq>`
  65. - 'brenth' :ref:`(see here) <optimize.root_scalar-brenth>`
  66. - 'ridder' :ref:`(see here) <optimize.root_scalar-ridder>`
  67. - 'toms748' :ref:`(see here) <optimize.root_scalar-toms748>`
  68. - 'newton' :ref:`(see here) <optimize.root_scalar-newton>`
  69. - 'secant' :ref:`(see here) <optimize.root_scalar-secant>`
  70. - 'halley' :ref:`(see here) <optimize.root_scalar-halley>`
  71. bracket: A sequence of 2 floats, optional
  72. An interval bracketing a root. `f(x, *args)` must have different
  73. signs at the two endpoints.
  74. x0 : float, optional
  75. Initial guess.
  76. x1 : float, optional
  77. A second guess.
  78. fprime : bool or callable, optional
  79. If `fprime` is a boolean and is True, `f` is assumed to return the
  80. value of the objective function and of the derivative.
  81. `fprime` can also be a callable returning the derivative of `f`. In
  82. this case, it must accept the same arguments as `f`.
  83. fprime2 : bool or callable, optional
  84. If `fprime2` is a boolean and is True, `f` is assumed to return the
  85. value of the objective function and of the
  86. first and second derivatives.
  87. `fprime2` can also be a callable returning the second derivative of `f`.
  88. In this case, it must accept the same arguments as `f`.
  89. xtol : float, optional
  90. Tolerance (absolute) for termination.
  91. rtol : float, optional
  92. Tolerance (relative) for termination.
  93. maxiter : int, optional
  94. Maximum number of iterations.
  95. options : dict, optional
  96. A dictionary of solver options. E.g., ``k``, see
  97. :obj:`show_options()` for details.
  98. Returns
  99. -------
  100. sol : RootResults
  101. The solution represented as a ``RootResults`` object.
  102. Important attributes are: ``root`` the solution , ``converged`` a
  103. boolean flag indicating if the algorithm exited successfully and
  104. ``flag`` which describes the cause of the termination. See
  105. `RootResults` for a description of other attributes.
  106. See also
  107. --------
  108. show_options : Additional options accepted by the solvers
  109. root : Find a root of a vector function.
  110. Notes
  111. -----
  112. This section describes the available solvers that can be selected by the
  113. 'method' parameter.
  114. The default is to use the best method available for the situation
  115. presented.
  116. If a bracket is provided, it may use one of the bracketing methods.
  117. If a derivative and an initial value are specified, it may
  118. select one of the derivative-based methods.
  119. If no method is judged applicable, it will raise an Exception.
  120. Arguments for each method are as follows (x=required, o=optional).
  121. +-----------------------------------------------+---+------+---------+----+----+--------+---------+------+------+---------+---------+
  122. | method | f | args | bracket | x0 | x1 | fprime | fprime2 | xtol | rtol | maxiter | options |
  123. +===============================================+===+======+=========+====+====+========+=========+======+======+=========+=========+
  124. | :ref:`bisect <optimize.root_scalar-bisect>` | x | o | x | | | | | o | o | o | o |
  125. +-----------------------------------------------+---+------+---------+----+----+--------+---------+------+------+---------+---------+
  126. | :ref:`brentq <optimize.root_scalar-brentq>` | x | o | x | | | | | o | o | o | o |
  127. +-----------------------------------------------+---+------+---------+----+----+--------+---------+------+------+---------+---------+
  128. | :ref:`brenth <optimize.root_scalar-brenth>` | x | o | x | | | | | o | o | o | o |
  129. +-----------------------------------------------+---+------+---------+----+----+--------+---------+------+------+---------+---------+
  130. | :ref:`ridder <optimize.root_scalar-ridder>` | x | o | x | | | | | o | o | o | o |
  131. +-----------------------------------------------+---+------+---------+----+----+--------+---------+------+------+---------+---------+
  132. | :ref:`toms748 <optimize.root_scalar-toms748>` | x | o | x | | | | | o | o | o | o |
  133. +-----------------------------------------------+---+------+---------+----+----+--------+---------+------+------+---------+---------+
  134. | :ref:`newton <optimize.root_scalar-newton>` | x | o | | x | | x | | o | o | o | o |
  135. +-----------------------------------------------+---+------+---------+----+----+--------+---------+------+------+---------+---------+
  136. | :ref:`secant <optimize.root_scalar-secant>` | x | o | | x | x | | | o | o | o | o |
  137. +-----------------------------------------------+---+------+---------+----+----+--------+---------+------+------+---------+---------+
  138. | :ref:`halley <optimize.root_scalar-halley>` | x | o | | x | | x | x | o | o | o | o |
  139. +-----------------------------------------------+---+------+---------+----+----+--------+---------+------+------+---------+---------+
  140. Examples
  141. --------
  142. Find the root of a simple cubic
  143. >>> from scipy import optimize
  144. >>> def f(x):
  145. ... return (x**3 - 1) # only one real root at x = 1
  146. >>> def fprime(x):
  147. ... return 3*x**2
  148. The `brentq` method takes as input a bracket
  149. >>> sol = optimize.root_scalar(f, bracket=[0, 3], method='brentq')
  150. >>> sol.root, sol.iterations, sol.function_calls
  151. (1.0, 10, 11)
  152. The `newton` method takes as input a single point and uses the
  153. derivative(s).
  154. >>> sol = optimize.root_scalar(f, x0=0.2, fprime=fprime, method='newton')
  155. >>> sol.root, sol.iterations, sol.function_calls
  156. (1.0, 11, 22)
  157. The function can provide the value and derivative(s) in a single call.
  158. >>> def f_p_pp(x):
  159. ... return (x**3 - 1), 3*x**2, 6*x
  160. >>> sol = optimize.root_scalar(
  161. ... f_p_pp, x0=0.2, fprime=True, method='newton'
  162. ... )
  163. >>> sol.root, sol.iterations, sol.function_calls
  164. (1.0, 11, 11)
  165. >>> sol = optimize.root_scalar(
  166. ... f_p_pp, x0=0.2, fprime=True, fprime2=True, method='halley'
  167. ... )
  168. >>> sol.root, sol.iterations, sol.function_calls
  169. (1.0, 7, 8)
  170. """
  171. if not isinstance(args, tuple):
  172. args = (args,)
  173. if options is None:
  174. options = {}
  175. # fun also returns the derivative(s)
  176. is_memoized = False
  177. if fprime2 is not None and not callable(fprime2):
  178. if bool(fprime2):
  179. f = MemoizeDer(f)
  180. is_memoized = True
  181. fprime2 = f.fprime2
  182. fprime = f.fprime
  183. else:
  184. fprime2 = None
  185. if fprime is not None and not callable(fprime):
  186. if bool(fprime):
  187. f = MemoizeDer(f)
  188. is_memoized = True
  189. fprime = f.fprime
  190. else:
  191. fprime = None
  192. # respect solver-specific default tolerances - only pass in if actually set
  193. kwargs = {}
  194. for k in ['xtol', 'rtol', 'maxiter']:
  195. v = locals().get(k)
  196. if v is not None:
  197. kwargs[k] = v
  198. # Set any solver-specific options
  199. if options:
  200. kwargs.update(options)
  201. # Always request full_output from the underlying method as _root_scalar
  202. # always returns a RootResults object
  203. kwargs.update(full_output=True, disp=False)
  204. # Pick a method if not specified.
  205. # Use the "best" method available for the situation.
  206. if not method:
  207. if bracket:
  208. method = 'brentq'
  209. elif x0 is not None:
  210. if fprime:
  211. if fprime2:
  212. method = 'halley'
  213. else:
  214. method = 'newton'
  215. else:
  216. method = 'secant'
  217. if not method:
  218. raise ValueError('Unable to select a solver as neither bracket '
  219. 'nor starting point provided.')
  220. meth = method.lower()
  221. map2underlying = {'halley': 'newton', 'secant': 'newton'}
  222. try:
  223. methodc = getattr(optzeros, map2underlying.get(meth, meth))
  224. except AttributeError as e:
  225. raise ValueError('Unknown solver %s' % meth) from e
  226. if meth in ['bisect', 'ridder', 'brentq', 'brenth', 'toms748']:
  227. if not isinstance(bracket, (list, tuple, np.ndarray)):
  228. raise ValueError('Bracket needed for %s' % method)
  229. a, b = bracket[:2]
  230. r, sol = methodc(f, a, b, args=args, **kwargs)
  231. elif meth in ['secant']:
  232. if x0 is None:
  233. raise ValueError('x0 must not be None for %s' % method)
  234. if x1 is None:
  235. raise ValueError('x1 must not be None for %s' % method)
  236. if 'xtol' in kwargs:
  237. kwargs['tol'] = kwargs.pop('xtol')
  238. r, sol = methodc(f, x0, args=args, fprime=None, fprime2=None,
  239. x1=x1, **kwargs)
  240. elif meth in ['newton']:
  241. if x0 is None:
  242. raise ValueError('x0 must not be None for %s' % method)
  243. if not fprime:
  244. raise ValueError('fprime must be specified for %s' % method)
  245. if 'xtol' in kwargs:
  246. kwargs['tol'] = kwargs.pop('xtol')
  247. r, sol = methodc(f, x0, args=args, fprime=fprime, fprime2=None,
  248. **kwargs)
  249. elif meth in ['halley']:
  250. if x0 is None:
  251. raise ValueError('x0 must not be None for %s' % method)
  252. if not fprime:
  253. raise ValueError('fprime must be specified for %s' % method)
  254. if not fprime2:
  255. raise ValueError('fprime2 must be specified for %s' % method)
  256. if 'xtol' in kwargs:
  257. kwargs['tol'] = kwargs.pop('xtol')
  258. r, sol = methodc(f, x0, args=args, fprime=fprime, fprime2=fprime2, **kwargs)
  259. else:
  260. raise ValueError('Unknown solver %s' % method)
  261. if is_memoized:
  262. # Replace the function_calls count with the memoized count.
  263. # Avoids double and triple-counting.
  264. n_calls = f.n_calls
  265. sol.function_calls = n_calls
  266. return sol
  267. def _root_scalar_brentq_doc():
  268. r"""
  269. Options
  270. -------
  271. args : tuple, optional
  272. Extra arguments passed to the objective function.
  273. bracket: A sequence of 2 floats, optional
  274. An interval bracketing a root. `f(x, *args)` must have different
  275. signs at the two endpoints.
  276. xtol : float, optional
  277. Tolerance (absolute) for termination.
  278. rtol : float, optional
  279. Tolerance (relative) for termination.
  280. maxiter : int, optional
  281. Maximum number of iterations.
  282. options: dict, optional
  283. Specifies any method-specific options not covered above
  284. """
  285. pass
  286. def _root_scalar_brenth_doc():
  287. r"""
  288. Options
  289. -------
  290. args : tuple, optional
  291. Extra arguments passed to the objective function.
  292. bracket: A sequence of 2 floats, optional
  293. An interval bracketing a root. `f(x, *args)` must have different
  294. signs at the two endpoints.
  295. xtol : float, optional
  296. Tolerance (absolute) for termination.
  297. rtol : float, optional
  298. Tolerance (relative) for termination.
  299. maxiter : int, optional
  300. Maximum number of iterations.
  301. options: dict, optional
  302. Specifies any method-specific options not covered above.
  303. """
  304. pass
  305. def _root_scalar_toms748_doc():
  306. r"""
  307. Options
  308. -------
  309. args : tuple, optional
  310. Extra arguments passed to the objective function.
  311. bracket: A sequence of 2 floats, optional
  312. An interval bracketing a root. `f(x, *args)` must have different
  313. signs at the two endpoints.
  314. xtol : float, optional
  315. Tolerance (absolute) for termination.
  316. rtol : float, optional
  317. Tolerance (relative) for termination.
  318. maxiter : int, optional
  319. Maximum number of iterations.
  320. options: dict, optional
  321. Specifies any method-specific options not covered above.
  322. """
  323. pass
  324. def _root_scalar_secant_doc():
  325. r"""
  326. Options
  327. -------
  328. args : tuple, optional
  329. Extra arguments passed to the objective function.
  330. xtol : float, optional
  331. Tolerance (absolute) for termination.
  332. rtol : float, optional
  333. Tolerance (relative) for termination.
  334. maxiter : int, optional
  335. Maximum number of iterations.
  336. x0 : float, required
  337. Initial guess.
  338. x1 : float, required
  339. A second guess.
  340. options: dict, optional
  341. Specifies any method-specific options not covered above.
  342. """
  343. pass
  344. def _root_scalar_newton_doc():
  345. r"""
  346. Options
  347. -------
  348. args : tuple, optional
  349. Extra arguments passed to the objective function and its derivative.
  350. xtol : float, optional
  351. Tolerance (absolute) for termination.
  352. rtol : float, optional
  353. Tolerance (relative) for termination.
  354. maxiter : int, optional
  355. Maximum number of iterations.
  356. x0 : float, required
  357. Initial guess.
  358. fprime : bool or callable, optional
  359. If `fprime` is a boolean and is True, `f` is assumed to return the
  360. value of derivative along with the objective function.
  361. `fprime` can also be a callable returning the derivative of `f`. In
  362. this case, it must accept the same arguments as `f`.
  363. options: dict, optional
  364. Specifies any method-specific options not covered above.
  365. """
  366. pass
  367. def _root_scalar_halley_doc():
  368. r"""
  369. Options
  370. -------
  371. args : tuple, optional
  372. Extra arguments passed to the objective function and its derivatives.
  373. xtol : float, optional
  374. Tolerance (absolute) for termination.
  375. rtol : float, optional
  376. Tolerance (relative) for termination.
  377. maxiter : int, optional
  378. Maximum number of iterations.
  379. x0 : float, required
  380. Initial guess.
  381. fprime : bool or callable, required
  382. If `fprime` is a boolean and is True, `f` is assumed to return the
  383. value of derivative along with the objective function.
  384. `fprime` can also be a callable returning the derivative of `f`. In
  385. this case, it must accept the same arguments as `f`.
  386. fprime2 : bool or callable, required
  387. If `fprime2` is a boolean and is True, `f` is assumed to return the
  388. value of 1st and 2nd derivatives along with the objective function.
  389. `fprime2` can also be a callable returning the 2nd derivative of `f`.
  390. In this case, it must accept the same arguments as `f`.
  391. options: dict, optional
  392. Specifies any method-specific options not covered above.
  393. """
  394. pass
  395. def _root_scalar_ridder_doc():
  396. r"""
  397. Options
  398. -------
  399. args : tuple, optional
  400. Extra arguments passed to the objective function.
  401. bracket: A sequence of 2 floats, optional
  402. An interval bracketing a root. `f(x, *args)` must have different
  403. signs at the two endpoints.
  404. xtol : float, optional
  405. Tolerance (absolute) for termination.
  406. rtol : float, optional
  407. Tolerance (relative) for termination.
  408. maxiter : int, optional
  409. Maximum number of iterations.
  410. options: dict, optional
  411. Specifies any method-specific options not covered above.
  412. """
  413. pass
  414. def _root_scalar_bisect_doc():
  415. r"""
  416. Options
  417. -------
  418. args : tuple, optional
  419. Extra arguments passed to the objective function.
  420. bracket: A sequence of 2 floats, optional
  421. An interval bracketing a root. `f(x, *args)` must have different
  422. signs at the two endpoints.
  423. xtol : float, optional
  424. Tolerance (absolute) for termination.
  425. rtol : float, optional
  426. Tolerance (relative) for termination.
  427. maxiter : int, optional
  428. Maximum number of iterations.
  429. options: dict, optional
  430. Specifies any method-specific options not covered above.
  431. """
  432. pass