_direct_py.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279
  1. from __future__ import annotations
  2. from typing import (
  3. Any, Callable, Iterable, Optional, Tuple, TYPE_CHECKING, Union
  4. )
  5. import numpy as np
  6. from scipy.optimize import OptimizeResult
  7. from ._constraints import old_bound_to_new, Bounds
  8. from ._direct import direct as _direct # type: ignore
  9. if TYPE_CHECKING:
  10. import numpy.typing as npt
  11. from _typeshed import NoneType
  12. __all__ = ['direct']
  13. ERROR_MESSAGES = (
  14. "Number of function evaluations done is larger than maxfun={}",
  15. "Number of iterations is larger than maxiter={}",
  16. "u[i] < l[i] for some i",
  17. "maxfun is too large",
  18. "Initialization failed",
  19. "There was an error in the creation of the sample points",
  20. "An error occured while the function was sampled",
  21. "Maximum number of levels has been reached.",
  22. "Forced stop",
  23. "Invalid arguments",
  24. "Out of memory",
  25. )
  26. SUCCESS_MESSAGES = (
  27. ("The best function value found is within a relative error={} "
  28. "of the (known) global optimum f_min"),
  29. ("The volume of the hyperrectangle containing the lowest function value "
  30. "found is below vol_tol={}"),
  31. ("The side length measure of the hyperrectangle containing the lowest "
  32. "function value found is below len_tol={}"),
  33. )
  34. def direct(
  35. func: Callable[[npt.ArrayLike, Tuple[Any]], float],
  36. bounds: Union[Iterable, Bounds],
  37. *,
  38. args: tuple = (),
  39. eps: float = 1e-4,
  40. maxfun: Union[int, None] = None,
  41. maxiter: int = 1000,
  42. locally_biased: bool = True,
  43. f_min: float = -np.inf,
  44. f_min_rtol: float = 1e-4,
  45. vol_tol: float = 1e-16,
  46. len_tol: float = 1e-6,
  47. callback: Optional[Callable[[npt.ArrayLike], NoneType]] = None
  48. ) -> OptimizeResult:
  49. """
  50. Finds the global minimum of a function using the
  51. DIRECT algorithm.
  52. Parameters
  53. ----------
  54. func : callable
  55. The objective function to be minimized.
  56. ``func(x, *args) -> float``
  57. where ``x`` is an 1-D array with shape (n,) and ``args`` is a tuple of
  58. the fixed parameters needed to completely specify the function.
  59. bounds : sequence or `Bounds`
  60. Bounds for variables. There are two ways to specify the bounds:
  61. 1. Instance of `Bounds` class.
  62. 2. ``(min, max)`` pairs for each element in ``x``.
  63. args : tuple, optional
  64. Any additional fixed parameters needed to
  65. completely specify the objective function.
  66. eps : float, optional
  67. Minimal required difference of the objective function values
  68. between the current best hyperrectangle and the next potentially
  69. optimal hyperrectangle to be divided. In consequence, `eps` serves as a
  70. tradeoff between local and global search: the smaller, the more local
  71. the search becomes. Default is 1e-4.
  72. maxfun : int or None, optional
  73. Approximate upper bound on objective function evaluations.
  74. If `None`, will be automatically set to ``1000 * N`` where ``N``
  75. represents the number of dimensions. Will be capped if necessary to
  76. limit DIRECT's RAM usage to app. 1GiB. This will only occur for very
  77. high dimensional problems and excessive `max_fun`. Default is `None`.
  78. maxiter : int, optional
  79. Maximum number of iterations. Default is 1000.
  80. locally_biased : bool, optional
  81. If `True` (default), use the locally biased variant of the
  82. algorithm known as DIRECT_L. If `False`, use the original unbiased
  83. DIRECT algorithm. For hard problems with many local minima,
  84. `False` is recommended.
  85. f_min : float, optional
  86. Function value of the global optimum. Set this value only if the
  87. global optimum is known. Default is ``-np.inf``, so that this
  88. termination criterion is deactivated.
  89. f_min_rtol : float, optional
  90. Terminate the optimization once the relative error between the
  91. current best minimum `f` and the supplied global minimum `f_min`
  92. is smaller than `f_min_rtol`. This parameter is only used if
  93. `f_min` is also set. Must lie between 0 and 1. Default is 1e-4.
  94. vol_tol : float, optional
  95. Terminate the optimization once the volume of the hyperrectangle
  96. containing the lowest function value is smaller than `vol_tol`
  97. of the complete search space. Must lie between 0 and 1.
  98. Default is 1e-16.
  99. len_tol : float, optional
  100. If `locally_biased=True`, terminate the optimization once half of
  101. the normalized maximal side length of the hyperrectangle containing
  102. the lowest function value is smaller than `len_tol`.
  103. If `locally_biased=False`, terminate the optimization once half of
  104. the normalized diagonal of the hyperrectangle containing the lowest
  105. function value is smaller than `len_tol`. Must lie between 0 and 1.
  106. Default is 1e-6.
  107. callback : callable, optional
  108. A callback function with signature ``callback(xk)`` where ``xk``
  109. represents the best function value found so far.
  110. Returns
  111. -------
  112. res : OptimizeResult
  113. The optimization result represented as a ``OptimizeResult`` object.
  114. Important attributes are: ``x`` the solution array, ``success`` a
  115. Boolean flag indicating if the optimizer exited successfully and
  116. ``message`` which describes the cause of the termination. See
  117. `OptimizeResult` for a description of other attributes.
  118. Notes
  119. -----
  120. DIviding RECTangles (DIRECT) is a deterministic global
  121. optimization algorithm capable of minimizing a black box function with
  122. its variables subject to lower and upper bound constraints by sampling
  123. potential solutions in the search space [1]_. The algorithm starts by
  124. normalising the search space to an n-dimensional unit hypercube.
  125. It samples the function at the center of this hypercube and at 2n
  126. (n is the number of variables) more points, 2 in each coordinate
  127. direction. Using these function values, DIRECT then divides the
  128. domain into hyperrectangles, each having exactly one of the sampling
  129. points as its center. In each iteration, DIRECT chooses, using the `eps`
  130. parameter which defaults to 1e-4, some of the existing hyperrectangles
  131. to be further divided. This division process continues until either the
  132. maximum number of iterations or maximum function evaluations allowed
  133. are exceeded, or the hyperrectangle containing the minimal value found
  134. so far becomes small enough. If `f_min` is specified, the optimization
  135. will stop once this function value is reached within a relative tolerance.
  136. The locally biased variant of DIRECT (originally called DIRECT_L) [2]_ is
  137. used by default. It makes the search more locally biased and more
  138. efficient for cases with only a few local minima.
  139. A note about termination criteria: `vol_tol` refers to the volume of the
  140. hyperrectangle containing the lowest function value found so far. This
  141. volume decreases exponentially with increasing dimensionality of the
  142. problem. Therefore `vol_tol` should be decreased to avoid premature
  143. termination of the algorithm for higher dimensions. This does not hold
  144. for `len_tol`: it refers either to half of the maximal side length
  145. (for ``locally_biased=True``) or half of the diagonal of the
  146. hyperrectangle (for ``locally_biased=False``).
  147. This code is based on the DIRECT 2.0.4 Fortran code by Gablonsky et al. at
  148. https://ctk.math.ncsu.edu/SOFTWARE/DIRECTv204.tar.gz .
  149. This original version was initially converted via f2c and then cleaned up
  150. and reorganized by Steven G. Johnson, August 2007, for the NLopt project.
  151. The `direct` function wraps the C implementation.
  152. .. versionadded:: 1.9.0
  153. References
  154. ----------
  155. .. [1] Jones, D.R., Perttunen, C.D. & Stuckman, B.E. Lipschitzian
  156. optimization without the Lipschitz constant. J Optim Theory Appl
  157. 79, 157-181 (1993).
  158. .. [2] Gablonsky, J., Kelley, C. A Locally-Biased form of the DIRECT
  159. Algorithm. Journal of Global Optimization 21, 27-37 (2001).
  160. Examples
  161. --------
  162. The following example is a 2-D problem with four local minima: minimizing
  163. the Styblinski-Tang function
  164. (https://en.wikipedia.org/wiki/Test_functions_for_optimization).
  165. >>> from scipy.optimize import direct, Bounds
  166. >>> def styblinski_tang(pos):
  167. ... x, y = pos
  168. ... return 0.5 * (x**4 - 16*x**2 + 5*x + y**4 - 16*y**2 + 5*y)
  169. >>> bounds = Bounds([-4., -4.], [4., 4.])
  170. >>> result = direct(styblinski_tang, bounds)
  171. >>> result.x, result.fun, result.nfev
  172. array([-2.90321597, -2.90321597]), -78.3323279095383, 2011
  173. The correct global minimum was found but with a huge number of function
  174. evaluations (2011). Loosening the termination tolerances `vol_tol` and
  175. `len_tol` can be used to stop DIRECT earlier.
  176. >>> result = direct(styblinski_tang, bounds, len_tol=1e-3)
  177. >>> result.x, result.fun, result.nfev
  178. array([-2.9044353, -2.9044353]), -78.33230330754142, 207
  179. """
  180. # convert bounds to new Bounds class if necessary
  181. if not isinstance(bounds, Bounds):
  182. if isinstance(bounds, list) or isinstance(bounds, tuple):
  183. lb, ub = old_bound_to_new(bounds)
  184. bounds = Bounds(lb, ub)
  185. else:
  186. message = ("bounds must be a sequence or "
  187. "instance of Bounds class")
  188. raise ValueError(message)
  189. lb = np.ascontiguousarray(bounds.lb, dtype=np.float64)
  190. ub = np.ascontiguousarray(bounds.ub, dtype=np.float64)
  191. # validate bounds
  192. # check that lower bounds are smaller than upper bounds
  193. if not np.all(lb < ub):
  194. raise ValueError('Bounds are not consistent min < max')
  195. # check for infs
  196. if (np.any(np.isinf(lb)) or np.any(np.isinf(ub))):
  197. raise ValueError("Bounds must not be inf.")
  198. # validate tolerances
  199. if (vol_tol < 0 or vol_tol > 1):
  200. raise ValueError("vol_tol must be between 0 and 1.")
  201. if (len_tol < 0 or len_tol > 1):
  202. raise ValueError("len_tol must be between 0 and 1.")
  203. if (f_min_rtol < 0 or f_min_rtol > 1):
  204. raise ValueError("f_min_rtol must be between 0 and 1.")
  205. # validate maxfun and maxiter
  206. if maxfun is None:
  207. maxfun = 1000 * lb.shape[0]
  208. if not isinstance(maxfun, int):
  209. raise ValueError("maxfun must be of type int.")
  210. if maxfun < 0:
  211. raise ValueError("maxfun must be > 0.")
  212. if not isinstance(maxiter, int):
  213. raise ValueError("maxiter must be of type int.")
  214. if maxiter < 0:
  215. raise ValueError("maxiter must be > 0.")
  216. # validate boolean parameters
  217. if not isinstance(locally_biased, bool):
  218. raise ValueError("locally_biased must be True or False.")
  219. def _func_wrap(x, args=None):
  220. x = np.asarray(x)
  221. if args is None:
  222. f = func(x)
  223. else:
  224. f = func(x, *args)
  225. # always return a float
  226. return np.asarray(f).item()
  227. # TODO: fix disp argument
  228. x, fun, ret_code, nfev, nit = _direct(
  229. _func_wrap,
  230. np.asarray(lb), np.asarray(ub),
  231. args,
  232. False, eps, maxfun, maxiter,
  233. locally_biased,
  234. f_min, f_min_rtol,
  235. vol_tol, len_tol, callback
  236. )
  237. format_val = (maxfun, maxiter, f_min_rtol, vol_tol, len_tol)
  238. if ret_code > 2:
  239. message = SUCCESS_MESSAGES[ret_code - 3].format(
  240. format_val[ret_code - 1])
  241. elif 0 < ret_code <= 2:
  242. message = ERROR_MESSAGES[ret_code - 1].format(format_val[ret_code - 1])
  243. elif 0 > ret_code > -100:
  244. message = ERROR_MESSAGES[abs(ret_code) + 1]
  245. else:
  246. message = ERROR_MESSAGES[ret_code + 99]
  247. return OptimizeResult(x=np.asarray(x), fun=fun, status=ret_code,
  248. success=ret_code > 2, message=message,
  249. nfev=nfev, nit=nit)