test_linesearch.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312
  1. """
  2. Tests for line search routines
  3. """
  4. from numpy.testing import (assert_equal, assert_array_almost_equal,
  5. assert_array_almost_equal_nulp, assert_warns,
  6. suppress_warnings)
  7. import scipy.optimize._linesearch as ls
  8. from scipy.optimize._linesearch import LineSearchWarning
  9. import numpy as np
  def assert_wolfe(s, phi, derphi, c1=1e-4, c2=0.9, err_msg=""):
      """
      Verify that the step ``s`` satisfies the strong Wolfe conditions.

      Parameters
      ----------
      s : scalar
          Step length under test.
      phi, derphi : callable
          One-dimensional objective and its derivative, both taking a step.
      c1, c2 : float, optional
          Constants of the sufficient-decrease and curvature conditions.
      err_msg : str, optional
          Text appended to the failure report.

      Raises
      ------
      AssertionError
          When either condition is violated; the message starts with
          ``"Wolfe 1 failed: "`` or ``"Wolfe 2 failed: "``.
      """
      # The order phi(s), phi(0), phi'(0), phi'(s) is kept fixed so that
      # callables with side effects (e.g. call counters) see one sequence.
      value_at_s = phi(s)
      value_at_0 = phi(0)
      slope_at_0 = derphi(0)
      slope_at_s = derphi(s)

      report = (
          "s = %s; phi(0) = %s; phi(s) = %s; phi'(0) = %s; phi'(s) = %s; %s"
          % (s, value_at_0, value_at_s, slope_at_0, slope_at_s, err_msg)
      )

      # Condition 1: sufficient decrease relative to the initial slope.
      decrease_ok = value_at_s <= value_at_0 + c1*s*slope_at_0
      assert decrease_ok, "Wolfe 1 failed: " + report

      # Condition 2 (strong form): the slope magnitude has shrunk enough.
      curvature_ok = abs(slope_at_s) <= abs(c2*slope_at_0)
      assert curvature_ok, "Wolfe 2 failed: " + report
  def assert_armijo(s, phi, c1=1e-4, err_msg=""):
      """
      Verify the decrease check ``phi(s) <= (1 - c1*s) * phi(0)``.

      Parameters
      ----------
      s : scalar
          Step length under test.
      phi : callable
          One-dimensional objective taking a step.
      c1 : float, optional
          Decrease constant.
      err_msg : str, optional
          Text appended to the failure report.

      Raises
      ------
      AssertionError
          When the inequality does not hold.
      """
      value_at_s = phi(s)
      value_at_0 = phi(0)
      report = "s = %s; phi(0) = %s; phi(s) = %s; %s" % (
          s, value_at_0, value_at_s, err_msg)
      upper_bound = (1 - c1*s)*value_at_0
      assert value_at_s <= upper_bound, report
  def assert_line_wolfe(x, p, s, f, fprime, **kw):
      """
      Check the strong Wolfe conditions for step ``s`` along ``x + s*p``.

      ``f`` and ``fprime`` are restricted to the line through ``x`` in
      direction ``p``; remaining keyword arguments go to `assert_wolfe`.
      """
      def phi(step):
          # Objective restricted to the search line.
          return f(x + p*step)

      def derphi(step):
          # Directional derivative: gradient at the trial point dotted with p.
          return np.dot(fprime(x + p*step), p)

      assert_wolfe(s, phi=phi, derphi=derphi, **kw)
  def assert_line_armijo(x, p, s, f, **kw):
      """
      Check the Armijo-style decrease for step ``s`` along ``x + s*p``.

      ``f`` is restricted to the line through ``x`` in direction ``p``;
      remaining keyword arguments go to `assert_armijo`.
      """
      def phi(step):
          # Objective restricted to the search line.
          return f(x + p*step)

      assert_armijo(s, phi=phi, **kw)
  def assert_fp_equal(x, y, err_msg="", nulp=50):
      """
      Assert that ``x`` and ``y`` agree up to floating-point rounding error.

      The comparison is delegated to ``assert_array_almost_equal_nulp`` with
      a tolerance of ``nulp`` units in the last place.  On failure a new
      ``AssertionError`` is raised whose text is the original message
      followed by ``err_msg`` on its own line; the original error is kept as
      the cause.
      """
      try:
          assert_array_almost_equal_nulp(x, y, nulp)
      except AssertionError as original:
          detail = "%s\n%s" % (original, err_msg)
          raise AssertionError(detail) from original
  class TestLineSearch:
      """
      Exercise the scalar and n-dimensional line searches exposed by ``ls``.

      The objectives are the ``_scalar_func_*`` and ``_line_func_*`` methods.
      Each returns a ``(value, derivative)`` pair and increments
      ``self.fcount``, which lets the tests compare the number of evaluations
      against the counts reported by the search routines.
      """

      # -- scalar functions; must have dphi(0.) < 0

      def _scalar_func_1(self, s):
          """Return ``-s - s**3 + s**4`` and its derivative in ``s``."""
          self.fcount += 1
          p = -s - s**3 + s**4
          dp = -1 - 3*s**2 + 4*s**3
          return p, dp

      def _scalar_func_2(self, s):
          """Return ``exp(-4*s) + s**2`` and its derivative in ``s``."""
          self.fcount += 1
          p = np.exp(-4*s) + s**2
          dp = -4*np.exp(-4*s) + 2*s
          return p, dp

      def _scalar_func_3(self, s):
          """Return ``-sin(10*s)`` and its derivative in ``s``."""
          self.fcount += 1
          p = -np.sin(10*s)
          dp = -10*np.cos(10*s)
          return p, dp

      # -- n-d functions

      def _line_func_1(self, x):
          """Return ``x . x`` and its gradient ``2*x``."""
          self.fcount += 1
          f = np.dot(x, x)
          df = 2*x
          return f, df

      def _line_func_2(self, x):
          """Return ``x . (A x) + 1`` and its gradient ``(A + A.T) x``."""
          self.fcount += 1
          f = np.dot(x, np.dot(self.A, x)) + 1
          df = np.dot(self.A + self.A.T, x)
          return f, df

      # --

      def setup_method(self):
          """
          Reset the evaluation counter, register objectives, seed the RNG.

          ``scalar_funcs`` and ``line_funcs`` become lists of
          ``(name, value_fn, derivative_fn)`` tuples ordered by method name.
          """
          self.scalar_funcs = []
          self.line_funcs = []
          self.N = 20
          self.fcount = 0

          def bind_index(func, idx):
              # Remember Python's closure semantics!  Going through a helper
              # gives each lambda its own ``func``/``idx`` instead of sharing
              # the loop variables below.
              return lambda *a, **kw: func(*a, **kw)[idx]

          for name in sorted(dir(self)):
              if name.startswith('_scalar_func_'):
                  registry = self.scalar_funcs
              elif name.startswith('_line_func_'):
                  registry = self.line_funcs
              else:
                  continue
              value = getattr(self, name)
              # Both wrappers call the full objective, so every value *or*
              # derivative request costs exactly one ``fcount`` increment.
              registry.append(
                  (name, bind_index(value, 0), bind_index(value, 1)))

          np.random.seed(1234)
          self.A = np.random.randn(self.N, self.N)

      def scalar_iter(self):
          """Yield ``(name, phi, derphi, old_phi0)``, three per scalar objective."""
          for name, phi, derphi in self.scalar_funcs:
              for old_phi0 in np.random.randn(3):
                  yield name, phi, derphi, old_phi0

      def line_iter(self):
          """
          Yield ``(name, f, fprime, x, p, old_fv)`` for the n-d objectives.

          Nine random ``(x, p)`` pairs are produced per objective; pairs where
          ``p`` is not a descent direction at ``x`` are redrawn.
          """
          for name, f, fprime in self.line_funcs:
              k = 0
              while k < 9:
                  x = np.random.randn(self.N)
                  p = np.random.randn(self.N)
                  if np.dot(p, fprime(x)) >= 0:
                      # always pick a descent direction
                      continue
                  k += 1
                  old_fv = float(np.random.randn())
                  yield name, f, fprime, x, p, old_fv

      # -- Generic scalar searches

      def test_scalar_search_wolfe1(self):
          """scalar_search_wolfe1 returns consistent values and a Wolfe step."""
          c = 0
          for name, phi, derphi, old_phi0 in self.scalar_iter():
              c += 1
              s, phi1, phi0 = ls.scalar_search_wolfe1(phi, derphi, phi(0),
                                                      old_phi0, derphi(0))
              assert_fp_equal(phi0, phi(0), name)
              assert_fp_equal(phi1, phi(s), name)
              assert_wolfe(s, phi, derphi, err_msg=name)

          assert c > 3  # check that the iterator really works...

      def test_scalar_search_wolfe2(self):
          """scalar_search_wolfe2 returns consistent values and a Wolfe step."""
          for name, phi, derphi, old_phi0 in self.scalar_iter():
              s, phi1, phi0, derphi1 = ls.scalar_search_wolfe2(
                  phi, derphi, phi(0), old_phi0, derphi(0))
              assert_fp_equal(phi0, phi(0), name)
              assert_fp_equal(phi1, phi(s), name)
              if derphi1 is not None:
                  assert_fp_equal(derphi1, derphi(s), name)
              assert_wolfe(s, phi, derphi, err_msg="%s %g" % (name, old_phi0))

      def test_scalar_search_wolfe2_with_low_amax(self):
          """A tiny ``amax`` makes the search warn and return no step."""
          def phi(alpha):
              return (alpha - 5) ** 2

          def derphi(alpha):
              return 2 * (alpha - 5)

          s, _, _, _ = assert_warns(LineSearchWarning,
                                    ls.scalar_search_wolfe2, phi, derphi,
                                    amax=0.001)
          assert s is None

      def test_scalar_search_wolfe2_regression(self):
          # Regression test for gh-12157
          # This phi has its minimum at alpha=4/3 ~ 1.333.
          def phi(alpha):
              if alpha < 1:
                  return - 3*np.pi/2 * (alpha - 1)
              else:
                  return np.cos(3*np.pi/2 * alpha - np.pi)

          def derphi(alpha):
              if alpha < 1:
                  return - 3*np.pi/2
              else:
                  return - 3*np.pi/2 * np.sin(3*np.pi/2 * alpha - np.pi)

          s, _, _, _ = ls.scalar_search_wolfe2(phi, derphi)
          # Without the fix in gh-13073, the scalar_search_wolfe2
          # returned s=2.0 instead.
          assert s < 1.5

      def test_scalar_search_armijo(self):
          """scalar_search_armijo returns ``phi(s)`` and an acceptable step."""
          for name, phi, derphi, old_phi0 in self.scalar_iter():
              s, phi1 = ls.scalar_search_armijo(phi, phi(0), derphi(0))
              assert_fp_equal(phi1, phi(s), name)
              assert_armijo(s, phi, err_msg="%s %g" % (name, old_phi0))

      # -- Generic line searches

      def test_line_search_wolfe1(self):
          """line_search_wolfe1 reports correct counts, values and steps."""
          c = 0
          smax = 100
          for name, f, fprime, x, p, old_f in self.line_iter():
              f0 = f(x)
              g0 = fprime(x)
              # Count only the evaluations made by the search itself.
              self.fcount = 0
              s, fc, gc, fv, ofv, gv = ls.line_search_wolfe1(f, fprime, x, p,
                                                             g0, f0, old_f,
                                                             amax=smax)
              assert_equal(self.fcount, fc+gc)
              assert_fp_equal(ofv, f(x))
              if s is None:
                  continue
              assert_fp_equal(fv, f(x + s*p))
              assert_array_almost_equal(gv, fprime(x + s*p), decimal=14)
              if s < smax:
                  c += 1
                  assert_line_wolfe(x, p, s, f, fprime, err_msg=name)

          assert c > 3  # check that the iterator really works...

      def test_line_search_wolfe2(self):
          """line_search_wolfe2 reports correct counts, values and steps."""
          c = 0
          smax = 512
          for name, f, fprime, x, p, old_f in self.line_iter():
              f0 = f(x)
              g0 = fprime(x)
              # Count only the evaluations made by the search itself.
              self.fcount = 0
              with suppress_warnings() as sup:
                  sup.filter(LineSearchWarning,
                             "The line search algorithm could not find a solution")
                  sup.filter(LineSearchWarning,
                             "The line search algorithm did not converge")
                  s, fc, gc, fv, ofv, gv = ls.line_search_wolfe2(f, fprime, x, p,
                                                                 g0, f0, old_f,
                                                                 amax=smax)
              assert_equal(self.fcount, fc+gc)
              assert_fp_equal(ofv, f(x))
              if s is None:
                  # The search failed (its warning is suppressed above), so
                  # there is no step to evaluate; skip the remaining checks
                  # exactly as test_line_search_wolfe1 does instead of
                  # failing with a TypeError on ``x + s*p``.
                  continue
              assert_fp_equal(fv, f(x + s*p))
              if gv is not None:
                  assert_array_almost_equal(gv, fprime(x + s*p), decimal=14)
              if s < smax:
                  c += 1
                  assert_line_wolfe(x, p, s, f, fprime, err_msg=name)

          assert c > 3  # check that the iterator really works...

      def test_line_search_wolfe2_bounds(self):
          # See gh-7475

          # For this f and p, starting at a point on axis 0, the strong Wolfe
          # condition 2 is met if and only if the step length s satisfies
          # |x + s| <= c2 * |x|
          f = lambda x: np.dot(x, x)
          fp = lambda x: 2 * x
          p = np.array([1, 0])

          # Smallest s satisfying strong Wolfe conditions for these arguments
          # is 30
          x = -60 * p
          c2 = 0.5

          s, _, _, _, _, _ = ls.line_search_wolfe2(f, fp, x, p, amax=30, c2=c2)
          assert_line_wolfe(x, p, s, f, fp)

          s, _, _, _, _, _ = assert_warns(LineSearchWarning,
                                          ls.line_search_wolfe2, f, fp, x, p,
                                          amax=29, c2=c2)
          assert s is None

          # s=30 will only be tried on the 6th iteration, so this won't converge
          assert_warns(LineSearchWarning, ls.line_search_wolfe2, f, fp, x, p,
                       c2=c2, maxiter=5)

      def test_line_search_armijo(self):
          """line_search_armijo reports correct counts, values and steps."""
          c = 0
          for name, f, fprime, x, p, old_f in self.line_iter():
              f0 = f(x)
              g0 = fprime(x)
              # Count only the evaluations made by the search itself.
              self.fcount = 0
              s, fc, fv = ls.line_search_armijo(f, x, p, g0, f0)
              c += 1
              assert_equal(self.fcount, fc)
              assert_fp_equal(fv, f(x + s*p))
              assert_line_armijo(x, p, s, f, err_msg=name)
          assert c >= 9

      # -- More specific tests

      def test_armijo_terminate_1(self):
          # Armijo should evaluate the function only once if the trial step
          # is already suitable
          count = [0]

          def phi(s):
              count[0] += 1
              return -s + 0.01*s**2

          s, phi1 = ls.scalar_search_armijo(phi, phi(0), -1, alpha0=1)
          assert_equal(s, 1)
          # One call for the phi(0) argument plus one inside the search.
          assert_equal(count[0], 2)
          assert_armijo(s, phi)

      def test_wolfe_terminate(self):
          # wolfe1 and wolfe2 should also evaluate the function only a few
          # times if the trial step is already suitable

          # ``count`` is bound inside the loop below; the closures look it up
          # at call time, so each search starts from a fresh counter.
          def phi(s):
              count[0] += 1
              return -s + 0.05*s**2

          def derphi(s):
              count[0] += 1
              return -1 + 0.05*2*s

          for func in [ls.scalar_search_wolfe1, ls.scalar_search_wolfe2]:
              count = [0]
              r = func(phi, derphi, phi(0), None, derphi(0))
              assert r[0] is not None, (r, func)
              # Two calls come from the phi(0)/derphi(0) arguments above.
              assert count[0] <= 2 + 2, (count, func)
              assert_wolfe(r[0], phi, derphi, err_msg=str(func))