_trustregion.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301
  1. """Trust-region optimization."""
  2. import math
  3. import warnings
  4. import numpy as np
  5. import scipy.linalg
  6. from ._optimize import (_check_unknown_options, _status_message,
  7. OptimizeResult, _prepare_scalar_function)
  8. from scipy.optimize._hessian_update_strategy import HessianUpdateStrategy
  9. from scipy.optimize._differentiable_functions import FD_METHODS
  10. __all__ = []
  11. def _wrap_function(function, args):
  12. # wraps a minimizer function to count number of evaluations
  13. # and to easily provide an args kwd.
  14. ncalls = [0]
  15. if function is None:
  16. return ncalls, None
  17. def function_wrapper(x, *wrapper_args):
  18. ncalls[0] += 1
  19. # A copy of x is sent to the user function (gh13740)
  20. return function(np.copy(x), *(wrapper_args + args))
  21. return ncalls, function_wrapper
  22. class BaseQuadraticSubproblem:
  23. """
  24. Base/abstract class defining the quadratic model for trust-region
  25. minimization. Child classes must implement the ``solve`` method.
  26. Values of the objective function, Jacobian and Hessian (if provided) at
  27. the current iterate ``x`` are evaluated on demand and then stored as
  28. attributes ``fun``, ``jac``, ``hess``.
  29. """
  30. def __init__(self, x, fun, jac, hess=None, hessp=None):
  31. self._x = x
  32. self._f = None
  33. self._g = None
  34. self._h = None
  35. self._g_mag = None
  36. self._cauchy_point = None
  37. self._newton_point = None
  38. self._fun = fun
  39. self._jac = jac
  40. self._hess = hess
  41. self._hessp = hessp
  42. def __call__(self, p):
  43. return self.fun + np.dot(self.jac, p) + 0.5 * np.dot(p, self.hessp(p))
  44. @property
  45. def fun(self):
  46. """Value of objective function at current iteration."""
  47. if self._f is None:
  48. self._f = self._fun(self._x)
  49. return self._f
  50. @property
  51. def jac(self):
  52. """Value of Jacobian of objective function at current iteration."""
  53. if self._g is None:
  54. self._g = self._jac(self._x)
  55. return self._g
  56. @property
  57. def hess(self):
  58. """Value of Hessian of objective function at current iteration."""
  59. if self._h is None:
  60. self._h = self._hess(self._x)
  61. return self._h
  62. def hessp(self, p):
  63. if self._hessp is not None:
  64. return self._hessp(self._x, p)
  65. else:
  66. return np.dot(self.hess, p)
  67. @property
  68. def jac_mag(self):
  69. """Magnitude of jacobian of objective function at current iteration."""
  70. if self._g_mag is None:
  71. self._g_mag = scipy.linalg.norm(self.jac)
  72. return self._g_mag
  73. def get_boundaries_intersections(self, z, d, trust_radius):
  74. """
  75. Solve the scalar quadratic equation ||z + t d|| == trust_radius.
  76. This is like a line-sphere intersection.
  77. Return the two values of t, sorted from low to high.
  78. """
  79. a = np.dot(d, d)
  80. b = 2 * np.dot(z, d)
  81. c = np.dot(z, z) - trust_radius**2
  82. sqrt_discriminant = math.sqrt(b*b - 4*a*c)
  83. # The following calculation is mathematically
  84. # equivalent to:
  85. # ta = (-b - sqrt_discriminant) / (2*a)
  86. # tb = (-b + sqrt_discriminant) / (2*a)
  87. # but produce smaller round off errors.
  88. # Look at Matrix Computation p.97
  89. # for a better justification.
  90. aux = b + math.copysign(sqrt_discriminant, b)
  91. ta = -aux / (2*a)
  92. tb = -2*c / aux
  93. return sorted([ta, tb])
  94. def solve(self, trust_radius):
  95. raise NotImplementedError('The solve method should be implemented by '
  96. 'the child class')
  97. def _minimize_trust_region(fun, x0, args=(), jac=None, hess=None, hessp=None,
  98. subproblem=None, initial_trust_radius=1.0,
  99. max_trust_radius=1000.0, eta=0.15, gtol=1e-4,
  100. maxiter=None, disp=False, return_all=False,
  101. callback=None, inexact=True, **unknown_options):
  102. """
  103. Minimization of scalar function of one or more variables using a
  104. trust-region algorithm.
  105. Options for the trust-region algorithm are:
  106. initial_trust_radius : float
  107. Initial trust radius.
  108. max_trust_radius : float
  109. Never propose steps that are longer than this value.
  110. eta : float
  111. Trust region related acceptance stringency for proposed steps.
  112. gtol : float
  113. Gradient norm must be less than `gtol`
  114. before successful termination.
  115. maxiter : int
  116. Maximum number of iterations to perform.
  117. disp : bool
  118. If True, print convergence message.
  119. inexact : bool
  120. Accuracy to solve subproblems. If True requires less nonlinear
  121. iterations, but more vector products. Only effective for method
  122. trust-krylov.
  123. This function is called by the `minimize` function.
  124. It is not supposed to be called directly.
  125. """
  126. _check_unknown_options(unknown_options)
  127. if jac is None:
  128. raise ValueError('Jacobian is currently required for trust-region '
  129. 'methods')
  130. if hess is None and hessp is None:
  131. raise ValueError('Either the Hessian or the Hessian-vector product '
  132. 'is currently required for trust-region methods')
  133. if subproblem is None:
  134. raise ValueError('A subproblem solving strategy is required for '
  135. 'trust-region methods')
  136. if not (0 <= eta < 0.25):
  137. raise Exception('invalid acceptance stringency')
  138. if max_trust_radius <= 0:
  139. raise Exception('the max trust radius must be positive')
  140. if initial_trust_radius <= 0:
  141. raise ValueError('the initial trust radius must be positive')
  142. if initial_trust_radius >= max_trust_radius:
  143. raise ValueError('the initial trust radius must be less than the '
  144. 'max trust radius')
  145. # force the initial guess into a nice format
  146. x0 = np.asarray(x0).flatten()
  147. # A ScalarFunction representing the problem. This caches calls to fun, jac,
  148. # hess.
  149. sf = _prepare_scalar_function(fun, x0, jac=jac, hess=hess, args=args)
  150. fun = sf.fun
  151. jac = sf.grad
  152. if callable(hess):
  153. hess = sf.hess
  154. elif callable(hessp):
  155. # this elif statement must come before examining whether hess
  156. # is estimated by FD methods or a HessianUpdateStrategy
  157. pass
  158. elif (hess in FD_METHODS or isinstance(hess, HessianUpdateStrategy)):
  159. # If the Hessian is being estimated by finite differences or a
  160. # Hessian update strategy then ScalarFunction.hess returns a
  161. # LinearOperator or a HessianUpdateStrategy. This enables the
  162. # calculation/creation of a hessp. BUT you only want to do this
  163. # if the user *hasn't* provided a callable(hessp) function.
  164. hess = None
  165. def hessp(x, p, *args):
  166. return sf.hess(x).dot(p)
  167. else:
  168. raise ValueError('Either the Hessian or the Hessian-vector product '
  169. 'is currently required for trust-region methods')
  170. # ScalarFunction doesn't represent hessp
  171. nhessp, hessp = _wrap_function(hessp, args)
  172. # limit the number of iterations
  173. if maxiter is None:
  174. maxiter = len(x0)*200
  175. # init the search status
  176. warnflag = 0
  177. # initialize the search
  178. trust_radius = initial_trust_radius
  179. x = x0
  180. if return_all:
  181. allvecs = [x]
  182. m = subproblem(x, fun, jac, hess, hessp)
  183. k = 0
  184. # search for the function min
  185. # do not even start if the gradient is small enough
  186. while m.jac_mag >= gtol:
  187. # Solve the sub-problem.
  188. # This gives us the proposed step relative to the current position
  189. # and it tells us whether the proposed step
  190. # has reached the trust region boundary or not.
  191. try:
  192. p, hits_boundary = m.solve(trust_radius)
  193. except np.linalg.LinAlgError:
  194. warnflag = 3
  195. break
  196. # calculate the predicted value at the proposed point
  197. predicted_value = m(p)
  198. # define the local approximation at the proposed point
  199. x_proposed = x + p
  200. m_proposed = subproblem(x_proposed, fun, jac, hess, hessp)
  201. # evaluate the ratio defined in equation (4.4)
  202. actual_reduction = m.fun - m_proposed.fun
  203. predicted_reduction = m.fun - predicted_value
  204. if predicted_reduction <= 0:
  205. warnflag = 2
  206. break
  207. rho = actual_reduction / predicted_reduction
  208. # update the trust radius according to the actual/predicted ratio
  209. if rho < 0.25:
  210. trust_radius *= 0.25
  211. elif rho > 0.75 and hits_boundary:
  212. trust_radius = min(2*trust_radius, max_trust_radius)
  213. # if the ratio is high enough then accept the proposed step
  214. if rho > eta:
  215. x = x_proposed
  216. m = m_proposed
  217. # append the best guess, call back, increment the iteration count
  218. if return_all:
  219. allvecs.append(np.copy(x))
  220. if callback is not None:
  221. callback(np.copy(x))
  222. k += 1
  223. # check if the gradient is small enough to stop
  224. if m.jac_mag < gtol:
  225. warnflag = 0
  226. break
  227. # check if we have looked at enough iterations
  228. if k >= maxiter:
  229. warnflag = 1
  230. break
  231. # print some stuff if requested
  232. status_messages = (
  233. _status_message['success'],
  234. _status_message['maxiter'],
  235. 'A bad approximation caused failure to predict improvement.',
  236. 'A linalg error occurred, such as a non-psd Hessian.',
  237. )
  238. if disp:
  239. if warnflag == 0:
  240. print(status_messages[warnflag])
  241. else:
  242. warnings.warn(status_messages[warnflag], RuntimeWarning, 3)
  243. print(" Current function value: %f" % m.fun)
  244. print(" Iterations: %d" % k)
  245. print(" Function evaluations: %d" % sf.nfev)
  246. print(" Gradient evaluations: %d" % sf.ngev)
  247. print(" Hessian evaluations: %d" % (sf.nhev + nhessp[0]))
  248. result = OptimizeResult(x=x, success=(warnflag == 0), status=warnflag,
  249. fun=m.fun, jac=m.jac, nfev=sf.nfev, njev=sf.ngev,
  250. nhev=sf.nhev + nhessp[0], nit=k,
  251. message=status_messages[warnflag])
  252. if hess is not None:
  253. result['hess'] = m.hess
  254. if return_all:
  255. result['allvecs'] = allvecs
  256. return result