test_zeros.py 28 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770
  1. import pytest
  2. from math import sqrt, exp, sin, cos
  3. from functools import lru_cache
  4. from numpy.testing import (assert_warns, assert_,
  5. assert_allclose,
  6. assert_equal,
  7. assert_array_equal,
  8. suppress_warnings)
  9. import numpy as np
  10. from numpy import finfo, power, nan, isclose
  11. from scipy.optimize import _zeros_py as zeros, newton, root_scalar
  12. from scipy._lib._util import getfullargspec_no_self as _getfullargspec
  13. # Import testing parameters
  14. from scipy.optimize._tstutils import get_tests, functions as tstutils_functions, fstrings as tstutils_fstrings
  15. TOL = 4*np.finfo(float).eps # tolerance
  16. _FLOAT_EPS = finfo(float).eps
  17. # A few test functions used frequently:
  18. # # A simple quadratic, (x-1)^2 - 1
  19. def f1(x):
  20. return x ** 2 - 2 * x - 1
  21. def f1_1(x):
  22. return 2 * x - 2
  23. def f1_2(x):
  24. return 2.0 + 0 * x
  25. def f1_and_p_and_pp(x):
  26. return f1(x), f1_1(x), f1_2(x)
  27. # Simple transcendental function
  28. def f2(x):
  29. return exp(x) - cos(x)
  30. def f2_1(x):
  31. return exp(x) + sin(x)
  32. def f2_2(x):
  33. return exp(x) + cos(x)
  34. # lru cached function
  35. @lru_cache()
  36. def f_lrucached(x):
  37. return x
  38. class TestBasic:
  39. def run_check_by_name(self, name, smoothness=0, **kwargs):
  40. a = .5
  41. b = sqrt(3)
  42. xtol = 4*np.finfo(float).eps
  43. rtol = 4*np.finfo(float).eps
  44. for function, fname in zip(tstutils_functions, tstutils_fstrings):
  45. if smoothness > 0 and fname in ['f4', 'f5', 'f6']:
  46. continue
  47. r = root_scalar(function, method=name, bracket=[a, b], x0=a,
  48. xtol=xtol, rtol=rtol, **kwargs)
  49. zero = r.root
  50. assert_(r.converged)
  51. assert_allclose(zero, 1.0, atol=xtol, rtol=rtol,
  52. err_msg='method %s, function %s' % (name, fname))
  53. def run_check(self, method, name):
  54. a = .5
  55. b = sqrt(3)
  56. xtol = 4 * _FLOAT_EPS
  57. rtol = 4 * _FLOAT_EPS
  58. for function, fname in zip(tstutils_functions, tstutils_fstrings):
  59. zero, r = method(function, a, b, xtol=xtol, rtol=rtol,
  60. full_output=True)
  61. assert_(r.converged)
  62. assert_allclose(zero, 1.0, atol=xtol, rtol=rtol,
  63. err_msg='method %s, function %s' % (name, fname))
  64. def run_check_lru_cached(self, method, name):
  65. # check that https://github.com/scipy/scipy/issues/10846 is fixed
  66. a = -1
  67. b = 1
  68. zero, r = method(f_lrucached, a, b, full_output=True)
  69. assert_(r.converged)
  70. assert_allclose(zero, 0,
  71. err_msg='method %s, function %s' % (name, 'f_lrucached'))
  72. def _run_one_test(self, tc, method, sig_args_keys=None,
  73. sig_kwargs_keys=None, **kwargs):
  74. method_args = []
  75. for k in sig_args_keys or []:
  76. if k not in tc:
  77. # If a,b not present use x0, x1. Similarly for f and func
  78. k = {'a': 'x0', 'b': 'x1', 'func': 'f'}.get(k, k)
  79. method_args.append(tc[k])
  80. method_kwargs = dict(**kwargs)
  81. method_kwargs.update({'full_output': True, 'disp': False})
  82. for k in sig_kwargs_keys or []:
  83. method_kwargs[k] = tc[k]
  84. root = tc.get('root')
  85. func_args = tc.get('args', ())
  86. try:
  87. r, rr = method(*method_args, args=func_args, **method_kwargs)
  88. return root, rr, tc
  89. except Exception:
  90. return root, zeros.RootResults(nan, -1, -1, zeros._EVALUEERR), tc
  91. def run_tests(self, tests, method, name,
  92. xtol=4 * _FLOAT_EPS, rtol=4 * _FLOAT_EPS,
  93. known_fail=None, **kwargs):
  94. r"""Run test-cases using the specified method and the supplied signature.
  95. Extract the arguments for the method call from the test case
  96. dictionary using the supplied keys for the method's signature."""
  97. # The methods have one of two base signatures:
  98. # (f, a, b, **kwargs) # newton
  99. # (func, x0, **kwargs) # bisect/brentq/...
  100. sig = _getfullargspec(method) # FullArgSpec with args, varargs, varkw, defaults, ...
  101. assert_(not sig.kwonlyargs)
  102. nDefaults = len(sig.defaults)
  103. nRequired = len(sig.args) - nDefaults
  104. sig_args_keys = sig.args[:nRequired]
  105. sig_kwargs_keys = []
  106. if name in ['secant', 'newton', 'halley']:
  107. if name in ['newton', 'halley']:
  108. sig_kwargs_keys.append('fprime')
  109. if name in ['halley']:
  110. sig_kwargs_keys.append('fprime2')
  111. kwargs['tol'] = xtol
  112. else:
  113. kwargs['xtol'] = xtol
  114. kwargs['rtol'] = rtol
  115. results = [list(self._run_one_test(
  116. tc, method, sig_args_keys=sig_args_keys,
  117. sig_kwargs_keys=sig_kwargs_keys, **kwargs)) for tc in tests]
  118. # results= [[true root, full output, tc], ...]
  119. known_fail = known_fail or []
  120. notcvgd = [elt for elt in results if not elt[1].converged]
  121. notcvgd = [elt for elt in notcvgd if elt[-1]['ID'] not in known_fail]
  122. notcvged_IDS = [elt[-1]['ID'] for elt in notcvgd]
  123. assert_equal([len(notcvged_IDS), notcvged_IDS], [0, []])
  124. # The usable xtol and rtol depend on the test
  125. tols = {'xtol': 4 * _FLOAT_EPS, 'rtol': 4 * _FLOAT_EPS}
  126. tols.update(**kwargs)
  127. rtol = tols['rtol']
  128. atol = tols.get('tol', tols['xtol'])
  129. cvgd = [elt for elt in results if elt[1].converged]
  130. approx = [elt[1].root for elt in cvgd]
  131. correct = [elt[0] for elt in cvgd]
  132. notclose = [[a] + elt for a, c, elt in zip(approx, correct, cvgd) if
  133. not isclose(a, c, rtol=rtol, atol=atol)
  134. and elt[-1]['ID'] not in known_fail]
  135. # Evaluate the function and see if is 0 at the purported root
  136. fvs = [tc['f'](aroot, *(tc['args'])) for aroot, c, fullout, tc in notclose]
  137. notclose = [[fv] + elt for fv, elt in zip(fvs, notclose) if fv != 0]
  138. assert_equal([notclose, len(notclose)], [[], 0])
  139. def run_collection(self, collection, method, name, smoothness=None,
  140. known_fail=None,
  141. xtol=4 * _FLOAT_EPS, rtol=4 * _FLOAT_EPS,
  142. **kwargs):
  143. r"""Run a collection of tests using the specified method.
  144. The name is used to determine some optional arguments."""
  145. tests = get_tests(collection, smoothness=smoothness)
  146. self.run_tests(tests, method, name, xtol=xtol, rtol=rtol,
  147. known_fail=known_fail, **kwargs)
  148. def test_bisect(self):
  149. self.run_check(zeros.bisect, 'bisect')
  150. self.run_check_lru_cached(zeros.bisect, 'bisect')
  151. self.run_check_by_name('bisect')
  152. self.run_collection('aps', zeros.bisect, 'bisect', smoothness=1)
  153. def test_ridder(self):
  154. self.run_check(zeros.ridder, 'ridder')
  155. self.run_check_lru_cached(zeros.ridder, 'ridder')
  156. self.run_check_by_name('ridder')
  157. self.run_collection('aps', zeros.ridder, 'ridder', smoothness=1)
  158. def test_brentq(self):
  159. self.run_check(zeros.brentq, 'brentq')
  160. self.run_check_lru_cached(zeros.brentq, 'brentq')
  161. self.run_check_by_name('brentq')
  162. # Brentq/h needs a lower tolerance to be specified
  163. self.run_collection('aps', zeros.brentq, 'brentq', smoothness=1,
  164. xtol=1e-14, rtol=1e-14)
  165. def test_brenth(self):
  166. self.run_check(zeros.brenth, 'brenth')
  167. self.run_check_lru_cached(zeros.brenth, 'brenth')
  168. self.run_check_by_name('brenth')
  169. self.run_collection('aps', zeros.brenth, 'brenth', smoothness=1,
  170. xtol=1e-14, rtol=1e-14)
  171. def test_toms748(self):
  172. self.run_check(zeros.toms748, 'toms748')
  173. self.run_check_lru_cached(zeros.toms748, 'toms748')
  174. self.run_check_by_name('toms748')
  175. self.run_collection('aps', zeros.toms748, 'toms748', smoothness=1)
  176. def test_newton_collections(self):
  177. known_fail = ['aps.13.00']
  178. known_fail += ['aps.12.05', 'aps.12.17'] # fails under Windows Py27
  179. for collection in ['aps', 'complex']:
  180. self.run_collection(collection, zeros.newton, 'newton',
  181. smoothness=2, known_fail=known_fail)
  182. def test_halley_collections(self):
  183. known_fail = ['aps.12.06', 'aps.12.07', 'aps.12.08', 'aps.12.09',
  184. 'aps.12.10', 'aps.12.11', 'aps.12.12', 'aps.12.13',
  185. 'aps.12.14', 'aps.12.15', 'aps.12.16', 'aps.12.17',
  186. 'aps.12.18', 'aps.13.00']
  187. for collection in ['aps', 'complex']:
  188. self.run_collection(collection, zeros.newton, 'halley',
  189. smoothness=2, known_fail=known_fail)
  190. @staticmethod
  191. def f1(x):
  192. return x**2 - 2*x - 1 # == (x-1)**2 - 2
  193. @staticmethod
  194. def f1_1(x):
  195. return 2*x - 2
  196. @staticmethod
  197. def f1_2(x):
  198. return 2.0 + 0*x
  199. @staticmethod
  200. def f2(x):
  201. return exp(x) - cos(x)
  202. @staticmethod
  203. def f2_1(x):
  204. return exp(x) + sin(x)
  205. @staticmethod
  206. def f2_2(x):
  207. return exp(x) + cos(x)
  208. def test_newton(self):
  209. for f, f_1, f_2 in [(self.f1, self.f1_1, self.f1_2),
  210. (self.f2, self.f2_1, self.f2_2)]:
  211. x = zeros.newton(f, 3, tol=1e-6)
  212. assert_allclose(f(x), 0, atol=1e-6)
  213. x = zeros.newton(f, 3, x1=5, tol=1e-6) # secant, x0 and x1
  214. assert_allclose(f(x), 0, atol=1e-6)
  215. x = zeros.newton(f, 3, fprime=f_1, tol=1e-6) # newton
  216. assert_allclose(f(x), 0, atol=1e-6)
  217. x = zeros.newton(f, 3, fprime=f_1, fprime2=f_2, tol=1e-6) # halley
  218. assert_allclose(f(x), 0, atol=1e-6)
  219. def test_newton_by_name(self):
  220. r"""Invoke newton through root_scalar()"""
  221. for f, f_1, f_2 in [(f1, f1_1, f1_2), (f2, f2_1, f2_2)]:
  222. r = root_scalar(f, method='newton', x0=3, fprime=f_1, xtol=1e-6)
  223. assert_allclose(f(r.root), 0, atol=1e-6)
  224. def test_secant_by_name(self):
  225. r"""Invoke secant through root_scalar()"""
  226. for f, f_1, f_2 in [(f1, f1_1, f1_2), (f2, f2_1, f2_2)]:
  227. r = root_scalar(f, method='secant', x0=3, x1=2, xtol=1e-6)
  228. assert_allclose(f(r.root), 0, atol=1e-6)
  229. r = root_scalar(f, method='secant', x0=3, x1=5, xtol=1e-6)
  230. assert_allclose(f(r.root), 0, atol=1e-6)
  231. def test_halley_by_name(self):
  232. r"""Invoke halley through root_scalar()"""
  233. for f, f_1, f_2 in [(f1, f1_1, f1_2), (f2, f2_1, f2_2)]:
  234. r = root_scalar(f, method='halley', x0=3,
  235. fprime=f_1, fprime2=f_2, xtol=1e-6)
  236. assert_allclose(f(r.root), 0, atol=1e-6)
  237. def test_root_scalar_fail(self):
  238. with pytest.raises(ValueError):
  239. root_scalar(f1, method='secant', x0=3, xtol=1e-6) # no x1
  240. with pytest.raises(ValueError):
  241. root_scalar(f1, method='newton', x0=3, xtol=1e-6) # no fprime
  242. with pytest.raises(ValueError):
  243. root_scalar(f1, method='halley', fprime=f1_1, x0=3, xtol=1e-6) # no fprime2
  244. with pytest.raises(ValueError):
  245. root_scalar(f1, method='halley', fprime2=f1_2, x0=3, xtol=1e-6) # no fprime
  246. def test_array_newton(self):
  247. """test newton with array"""
  248. def f1(x, *a):
  249. b = a[0] + x * a[3]
  250. return a[1] - a[2] * (np.exp(b / a[5]) - 1.0) - b / a[4] - x
  251. def f1_1(x, *a):
  252. b = a[3] / a[5]
  253. return -a[2] * np.exp(a[0] / a[5] + x * b) * b - a[3] / a[4] - 1
  254. def f1_2(x, *a):
  255. b = a[3] / a[5]
  256. return -a[2] * np.exp(a[0] / a[5] + x * b) * b**2
  257. a0 = np.array([
  258. 5.32725221, 5.48673747, 5.49539973,
  259. 5.36387202, 4.80237316, 1.43764452,
  260. 5.23063958, 5.46094772, 5.50512718,
  261. 5.42046290
  262. ])
  263. a1 = (np.sin(range(10)) + 1.0) * 7.0
  264. args = (a0, a1, 1e-09, 0.004, 10, 0.27456)
  265. x0 = [7.0] * 10
  266. x = zeros.newton(f1, x0, f1_1, args)
  267. x_expected = (
  268. 6.17264965, 11.7702805, 12.2219954,
  269. 7.11017681, 1.18151293, 0.143707955,
  270. 4.31928228, 10.5419107, 12.7552490,
  271. 8.91225749
  272. )
  273. assert_allclose(x, x_expected)
  274. # test halley's
  275. x = zeros.newton(f1, x0, f1_1, args, fprime2=f1_2)
  276. assert_allclose(x, x_expected)
  277. # test secant
  278. x = zeros.newton(f1, x0, args=args)
  279. assert_allclose(x, x_expected)
  280. def test_array_newton_complex(self):
  281. def f(x):
  282. return x + 1+1j
  283. def fprime(x):
  284. return 1.0
  285. t = np.full(4, 1j)
  286. x = zeros.newton(f, t, fprime=fprime)
  287. assert_allclose(f(x), 0.)
  288. # should work even if x0 is not complex
  289. t = np.ones(4)
  290. x = zeros.newton(f, t, fprime=fprime)
  291. assert_allclose(f(x), 0.)
  292. x = zeros.newton(f, t)
  293. assert_allclose(f(x), 0.)
  294. def test_array_secant_active_zero_der(self):
  295. """test secant doesn't continue to iterate zero derivatives"""
  296. x = zeros.newton(lambda x, *a: x*x - a[0], x0=[4.123, 5],
  297. args=[np.array([17, 25])])
  298. assert_allclose(x, (4.123105625617661, 5.0))
  299. def test_array_newton_integers(self):
  300. # test secant with float
  301. x = zeros.newton(lambda y, z: z - y ** 2, [4.0] * 2,
  302. args=([15.0, 17.0],))
  303. assert_allclose(x, (3.872983346207417, 4.123105625617661))
  304. # test integer becomes float
  305. x = zeros.newton(lambda y, z: z - y ** 2, [4] * 2, args=([15, 17],))
  306. assert_allclose(x, (3.872983346207417, 4.123105625617661))
  307. def test_array_newton_zero_der_failures(self):
  308. # test derivative zero warning
  309. assert_warns(RuntimeWarning, zeros.newton,
  310. lambda y: y**2 - 2, [0., 0.], lambda y: 2 * y)
  311. # test failures and zero_der
  312. with pytest.warns(RuntimeWarning):
  313. results = zeros.newton(lambda y: y**2 - 2, [0., 0.],
  314. lambda y: 2*y, full_output=True)
  315. assert_allclose(results.root, 0)
  316. assert results.zero_der.all()
  317. assert not results.converged.any()
  318. def test_newton_combined(self):
  319. f1 = lambda x: x**2 - 2*x - 1
  320. f1_1 = lambda x: 2*x - 2
  321. f1_2 = lambda x: 2.0 + 0*x
  322. def f1_and_p_and_pp(x):
  323. return x**2 - 2*x-1, 2*x-2, 2.0
  324. sol0 = root_scalar(f1, method='newton', x0=3, fprime=f1_1)
  325. sol = root_scalar(f1_and_p_and_pp, method='newton', x0=3, fprime=True)
  326. assert_allclose(sol0.root, sol.root, atol=1e-8)
  327. assert_equal(2*sol.function_calls, sol0.function_calls)
  328. sol0 = root_scalar(f1, method='halley', x0=3, fprime=f1_1, fprime2=f1_2)
  329. sol = root_scalar(f1_and_p_and_pp, method='halley', x0=3, fprime2=True)
  330. assert_allclose(sol0.root, sol.root, atol=1e-8)
  331. assert_equal(3*sol.function_calls, sol0.function_calls)
  332. def test_newton_full_output(self):
  333. # Test the full_output capability, both when converging and not.
  334. # Use simple polynomials, to avoid hitting platform dependencies
  335. # (e.g., exp & trig) in number of iterations
  336. x0 = 3
  337. expected_counts = [(6, 7), (5, 10), (3, 9)]
  338. for derivs in range(3):
  339. kwargs = {'tol': 1e-6, 'full_output': True, }
  340. for k, v in [['fprime', self.f1_1], ['fprime2', self.f1_2]][:derivs]:
  341. kwargs[k] = v
  342. x, r = zeros.newton(self.f1, x0, disp=False, **kwargs)
  343. assert_(r.converged)
  344. assert_equal(x, r.root)
  345. assert_equal((r.iterations, r.function_calls), expected_counts[derivs])
  346. if derivs == 0:
  347. assert r.function_calls <= r.iterations + 1
  348. else:
  349. assert_equal(r.function_calls, (derivs + 1) * r.iterations)
  350. # Now repeat, allowing one fewer iteration to force convergence failure
  351. iters = r.iterations - 1
  352. x, r = zeros.newton(self.f1, x0, maxiter=iters, disp=False, **kwargs)
  353. assert_(not r.converged)
  354. assert_equal(x, r.root)
  355. assert_equal(r.iterations, iters)
  356. if derivs == 1:
  357. # Check that the correct Exception is raised and
  358. # validate the start of the message.
  359. with pytest.raises(
  360. RuntimeError,
  361. match='Failed to converge after %d iterations, value is .*' % (iters)):
  362. x, r = zeros.newton(self.f1, x0, maxiter=iters, disp=True, **kwargs)
  363. def test_deriv_zero_warning(self):
  364. func = lambda x: x**2 - 2.0
  365. dfunc = lambda x: 2*x
  366. assert_warns(RuntimeWarning, zeros.newton, func, 0.0, dfunc, disp=False)
  367. with pytest.raises(RuntimeError, match='Derivative was zero'):
  368. zeros.newton(func, 0.0, dfunc)
  369. def test_newton_does_not_modify_x0(self):
  370. # https://github.com/scipy/scipy/issues/9964
  371. x0 = np.array([0.1, 3])
  372. x0_copy = x0.copy() # Copy to test for equality.
  373. newton(np.sin, x0, np.cos)
  374. assert_array_equal(x0, x0_copy)
  375. def test_maxiter_int_check(self):
  376. for method in [zeros.bisect, zeros.newton, zeros.ridder, zeros.brentq,
  377. zeros.brenth, zeros.toms748]:
  378. with pytest.raises(TypeError,
  379. match="'float' object cannot be interpreted as an integer"):
  380. method(f1, 0.0, 1.0, maxiter=72.45)
  381. def test_gh_5555():
  382. root = 0.1
  383. def f(x):
  384. return x - root
  385. methods = [zeros.bisect, zeros.ridder]
  386. xtol = rtol = TOL
  387. for method in methods:
  388. res = method(f, -1e8, 1e7, xtol=xtol, rtol=rtol)
  389. assert_allclose(root, res, atol=xtol, rtol=rtol,
  390. err_msg='method %s' % method.__name__)
  391. def test_gh_5557():
  392. # Show that without the changes in 5557 brentq and brenth might
  393. # only achieve a tolerance of 2*(xtol + rtol*|res|).
  394. # f linearly interpolates (0, -0.1), (0.5, -0.1), and (1,
  395. # 0.4). The important parts are that |f(0)| < |f(1)| (so that
  396. # brent takes 0 as the initial guess), |f(0)| < atol (so that
  397. # brent accepts 0 as the root), and that the exact root of f lies
  398. # more than atol away from 0 (so that brent doesn't achieve the
  399. # desired tolerance).
  400. def f(x):
  401. if x < 0.5:
  402. return -0.1
  403. else:
  404. return x - 0.6
  405. atol = 0.51
  406. rtol = 4 * _FLOAT_EPS
  407. methods = [zeros.brentq, zeros.brenth]
  408. for method in methods:
  409. res = method(f, 0, 1, xtol=atol, rtol=rtol)
  410. assert_allclose(0.6, res, atol=atol, rtol=rtol)
  411. def test_brent_underflow_in_root_bracketing():
  412. # Tetsing if an interval [a,b] brackets a zero of a function
  413. # by checking f(a)*f(b) < 0 is not reliable when the product
  414. # underflows/overflows. (reported in issue# 13737)
  415. underflow_scenario = (-450.0, -350.0, -400.0)
  416. overflow_scenario = (350.0, 450.0, 400.0)
  417. for a, b, root in [underflow_scenario, overflow_scenario]:
  418. c = np.exp(root)
  419. for method in [zeros.brenth, zeros.brentq]:
  420. res = method(lambda x: np.exp(x)-c, a, b)
  421. assert_allclose(root, res)
  422. class TestRootResults:
  423. def test_repr(self):
  424. r = zeros.RootResults(root=1.0,
  425. iterations=44,
  426. function_calls=46,
  427. flag=0)
  428. expected_repr = (" converged: True\n flag: 'converged'"
  429. "\n function_calls: 46\n iterations: 44\n"
  430. " root: 1.0")
  431. assert_equal(repr(r), expected_repr)
  432. def test_complex_halley():
  433. """Test Halley's works with complex roots"""
  434. def f(x, *a):
  435. return a[0] * x**2 + a[1] * x + a[2]
  436. def f_1(x, *a):
  437. return 2 * a[0] * x + a[1]
  438. def f_2(x, *a):
  439. retval = 2 * a[0]
  440. try:
  441. size = len(x)
  442. except TypeError:
  443. return retval
  444. else:
  445. return [retval] * size
  446. z = complex(1.0, 2.0)
  447. coeffs = (2.0, 3.0, 4.0)
  448. y = zeros.newton(f, z, args=coeffs, fprime=f_1, fprime2=f_2, tol=1e-6)
  449. # (-0.75000000000000078+1.1989578808281789j)
  450. assert_allclose(f(y, *coeffs), 0, atol=1e-6)
  451. z = [z] * 10
  452. coeffs = (2.0, 3.0, 4.0)
  453. y = zeros.newton(f, z, args=coeffs, fprime=f_1, fprime2=f_2, tol=1e-6)
  454. assert_allclose(f(y, *coeffs), 0, atol=1e-6)
  455. def test_zero_der_nz_dp():
  456. """Test secant method with a non-zero dp, but an infinite newton step"""
  457. # pick a symmetrical functions and choose a point on the side that with dx
  458. # makes a secant that is a flat line with zero slope, EG: f = (x - 100)**2,
  459. # which has a root at x = 100 and is symmetrical around the line x = 100
  460. # we have to pick a really big number so that it is consistently true
  461. # now find a point on each side so that the secant has a zero slope
  462. dx = np.finfo(float).eps ** 0.33
  463. # 100 - p0 = p1 - 100 = p0 * (1 + dx) + dx - 100
  464. # -> 200 = p0 * (2 + dx) + dx
  465. p0 = (200.0 - dx) / (2.0 + dx)
  466. with suppress_warnings() as sup:
  467. sup.filter(RuntimeWarning, "RMS of")
  468. x = zeros.newton(lambda y: (y - 100.0)**2, x0=[p0] * 10)
  469. assert_allclose(x, [100] * 10)
  470. # test scalar cases too
  471. p0 = (2.0 - 1e-4) / (2.0 + 1e-4)
  472. with suppress_warnings() as sup:
  473. sup.filter(RuntimeWarning, "Tolerance of")
  474. x = zeros.newton(lambda y: (y - 1.0) ** 2, x0=p0, disp=False)
  475. assert_allclose(x, 1)
  476. with pytest.raises(RuntimeError, match='Tolerance of'):
  477. x = zeros.newton(lambda y: (y - 1.0) ** 2, x0=p0, disp=True)
  478. p0 = (-2.0 + 1e-4) / (2.0 + 1e-4)
  479. with suppress_warnings() as sup:
  480. sup.filter(RuntimeWarning, "Tolerance of")
  481. x = zeros.newton(lambda y: (y + 1.0) ** 2, x0=p0, disp=False)
  482. assert_allclose(x, -1)
  483. with pytest.raises(RuntimeError, match='Tolerance of'):
  484. x = zeros.newton(lambda y: (y + 1.0) ** 2, x0=p0, disp=True)
  485. def test_array_newton_failures():
  486. """Test that array newton fails as expected"""
  487. # p = 0.68 # [MPa]
  488. # dp = -0.068 * 1e6 # [Pa]
  489. # T = 323 # [K]
  490. diameter = 0.10 # [m]
  491. # L = 100 # [m]
  492. roughness = 0.00015 # [m]
  493. rho = 988.1 # [kg/m**3]
  494. mu = 5.4790e-04 # [Pa*s]
  495. u = 2.488 # [m/s]
  496. reynolds_number = rho * u * diameter / mu # Reynolds number
  497. def colebrook_eqn(darcy_friction, re, dia):
  498. return (1 / np.sqrt(darcy_friction) +
  499. 2 * np.log10(roughness / 3.7 / dia +
  500. 2.51 / re / np.sqrt(darcy_friction)))
  501. # only some failures
  502. with pytest.warns(RuntimeWarning):
  503. result = zeros.newton(
  504. colebrook_eqn, x0=[0.01, 0.2, 0.02223, 0.3], maxiter=2,
  505. args=[reynolds_number, diameter], full_output=True
  506. )
  507. assert not result.converged.all()
  508. # they all fail
  509. with pytest.raises(RuntimeError):
  510. result = zeros.newton(
  511. colebrook_eqn, x0=[0.01] * 2, maxiter=2,
  512. args=[reynolds_number, diameter], full_output=True
  513. )
  514. # this test should **not** raise a RuntimeWarning
  515. def test_gh8904_zeroder_at_root_fails():
  516. """Test that Newton or Halley don't warn if zero derivative at root"""
  517. # a function that has a zero derivative at it's root
  518. def f_zeroder_root(x):
  519. return x**3 - x**2
  520. # should work with secant
  521. r = zeros.newton(f_zeroder_root, x0=0)
  522. assert_allclose(r, 0, atol=zeros._xtol, rtol=zeros._rtol)
  523. # test again with array
  524. r = zeros.newton(f_zeroder_root, x0=[0]*10)
  525. assert_allclose(r, 0, atol=zeros._xtol, rtol=zeros._rtol)
  526. # 1st derivative
  527. def fder(x):
  528. return 3 * x**2 - 2 * x
  529. # 2nd derivative
  530. def fder2(x):
  531. return 6*x - 2
  532. # should work with newton and halley
  533. r = zeros.newton(f_zeroder_root, x0=0, fprime=fder)
  534. assert_allclose(r, 0, atol=zeros._xtol, rtol=zeros._rtol)
  535. r = zeros.newton(f_zeroder_root, x0=0, fprime=fder,
  536. fprime2=fder2)
  537. assert_allclose(r, 0, atol=zeros._xtol, rtol=zeros._rtol)
  538. # test again with array
  539. r = zeros.newton(f_zeroder_root, x0=[0]*10, fprime=fder)
  540. assert_allclose(r, 0, atol=zeros._xtol, rtol=zeros._rtol)
  541. r = zeros.newton(f_zeroder_root, x0=[0]*10, fprime=fder,
  542. fprime2=fder2)
  543. assert_allclose(r, 0, atol=zeros._xtol, rtol=zeros._rtol)
  544. # also test that if a root is found we do not raise RuntimeWarning even if
  545. # the derivative is zero, EG: at x = 0.5, then fval = -0.125 and
  546. # fder = -0.25 so the next guess is 0.5 - (-0.125/-0.5) = 0 which is the
  547. # root, but if the solver continued with that guess, then it will calculate
  548. # a zero derivative, so it should return the root w/o RuntimeWarning
  549. r = zeros.newton(f_zeroder_root, x0=0.5, fprime=fder)
  550. assert_allclose(r, 0, atol=zeros._xtol, rtol=zeros._rtol)
  551. # test again with array
  552. r = zeros.newton(f_zeroder_root, x0=[0.5]*10, fprime=fder)
  553. assert_allclose(r, 0, atol=zeros._xtol, rtol=zeros._rtol)
  554. # doesn't apply to halley
  555. def test_gh_8881():
  556. r"""Test that Halley's method realizes that the 2nd order adjustment
  557. is too big and drops off to the 1st order adjustment."""
  558. n = 9
  559. def f(x):
  560. return power(x, 1.0/n) - power(n, 1.0/n)
  561. def fp(x):
  562. return power(x, (1.0-n)/n)/n
  563. def fpp(x):
  564. return power(x, (1.0-2*n)/n) * (1.0/n) * (1.0-n)/n
  565. x0 = 0.1
  566. # The root is at x=9.
  567. # The function has positive slope, x0 < root.
  568. # Newton succeeds in 8 iterations
  569. rt, r = newton(f, x0, fprime=fp, full_output=True)
  570. assert r.converged
  571. # Before the Issue 8881/PR 8882, halley would send x in the wrong direction.
  572. # Check that it now succeeds.
  573. rt, r = newton(f, x0, fprime=fp, fprime2=fpp, full_output=True)
  574. assert r.converged
  575. def test_gh_9608_preserve_array_shape():
  576. """
  577. Test that shape is preserved for array inputs even if fprime or fprime2 is
  578. scalar
  579. """
  580. def f(x):
  581. return x**2
  582. def fp(x):
  583. return 2 * x
  584. def fpp(x):
  585. return 2
  586. x0 = np.array([-2], dtype=np.float32)
  587. rt, r = newton(f, x0, fprime=fp, fprime2=fpp, full_output=True)
  588. assert r.converged
  589. x0_array = np.array([-2, -3], dtype=np.float32)
  590. # This next invocation should fail
  591. with pytest.raises(IndexError):
  592. result = zeros.newton(
  593. f, x0_array, fprime=fp, fprime2=fpp, full_output=True
  594. )
  595. def fpp_array(x):
  596. return np.full(np.shape(x), 2, dtype=np.float32)
  597. result = zeros.newton(
  598. f, x0_array, fprime=fp, fprime2=fpp_array, full_output=True
  599. )
  600. assert result.converged.all()
  601. @pytest.mark.parametrize(
  602. "maximum_iterations,flag_expected",
  603. [(10, zeros.CONVERR), (100, zeros.CONVERGED)])
  604. def test_gh9254_flag_if_maxiter_exceeded(maximum_iterations, flag_expected):
  605. """
  606. Test that if the maximum iterations is exceeded that the flag is not
  607. converged.
  608. """
  609. result = zeros.brentq(
  610. lambda x: ((1.2*x - 2.3)*x + 3.4)*x - 4.5,
  611. -30, 30, (), 1e-6, 1e-6, maximum_iterations,
  612. full_output=True, disp=False)
  613. assert result[1].flag == flag_expected
  614. if flag_expected == zeros.CONVERR:
  615. # didn't converge because exceeded maximum iterations
  616. assert result[1].iterations == maximum_iterations
  617. elif flag_expected == zeros.CONVERGED:
  618. # converged before maximum iterations
  619. assert result[1].iterations < maximum_iterations
  620. def test_gh9551_raise_error_if_disp_true():
  621. """Test that if disp is true then zero derivative raises RuntimeError"""
  622. def f(x):
  623. return x*x + 1
  624. def f_p(x):
  625. return 2*x
  626. assert_warns(RuntimeWarning, zeros.newton, f, 1.0, f_p, disp=False)
  627. with pytest.raises(
  628. RuntimeError,
  629. match=r'^Derivative was zero\. Failed to converge after \d+ iterations, value is [+-]?\d*\.\d+\.$'):
  630. zeros.newton(f, 1.0, f_p)
  631. root = zeros.newton(f, complex(10.0, 10.0), f_p)
  632. assert_allclose(root, complex(0.0, 1.0))