# lbfgs.py

import torch
from functools import reduce
from .optimizer import Optimizer

__all__ = ['LBFGS']


def _cubic_interpolate(x1, f1, g1, x2, f2, g2, bounds=None):
    # ported from https://github.com/torch/optim/blob/master/polyinterp.lua
    # Compute bounds of interpolation area
    if bounds is not None:
        xmin_bound, xmax_bound = bounds
    else:
        xmin_bound, xmax_bound = (x1, x2) if x1 <= x2 else (x2, x1)

    # Code for most common case: cubic interpolation of 2 points
    #   w/ function and derivative values for both
    # Solution in this case (where x2 is the farthest point):
    #   d1 = g1 + g2 - 3*(f1-f2)/(x1-x2);
    #   d2 = sqrt(d1^2 - g1*g2);
    #   min_pos = x2 - (x2 - x1)*((g2 + d2 - d1)/(g2 - g1 + 2*d2));
    #   t_new = min(max(min_pos,xmin_bound),xmax_bound);
    d1 = g1 + g2 - 3 * (f1 - f2) / (x1 - x2)
    d2_square = d1**2 - g1 * g2
    if d2_square >= 0:
        d2 = d2_square.sqrt()
        if x1 <= x2:
            min_pos = x2 - (x2 - x1) * ((g2 + d2 - d1) / (g2 - g1 + 2 * d2))
        else:
            min_pos = x1 - (x1 - x2) * ((g1 + d2 - d1) / (g1 - g2 + 2 * d2))
        return min(max(min_pos, xmin_bound), xmax_bound)
    else:
        return (xmin_bound + xmax_bound) / 2.
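
# `_strong_wolfe` below performs a strong Wolfe line search along direction d
# (c1, c2 as in Nocedal & Wright, "Numerical Optimization"). It enforces:
#   sufficient decrease: f(x + t*d) <= f(x) + c1 * t * g^T d
#   curvature:           |grad f(x + t*d)^T d| <= c2 * |g^T d|
# A bracketing phase first encloses an acceptable step, then a zoom phase
# refines the bracket with cubic interpolation.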
def _strong_wolfe(obj_func,
                  x,
                  t,
                  d,
                  f,
                  g,
                  gtd,
                  c1=1e-4,
                  c2=0.9,
                  tolerance_change=1e-9,
                  max_ls=25):
    # ported from https://github.com/torch/optim/blob/master/lswolfe.lua
    d_norm = d.abs().max()
    g = g.clone(memory_format=torch.contiguous_format)
    # evaluate objective and gradient using initial step
    f_new, g_new = obj_func(x, t, d)
    ls_func_evals = 1
    gtd_new = g_new.dot(d)

    # bracket an interval containing a point satisfying the Wolfe criteria
    t_prev, f_prev, g_prev, gtd_prev = 0, f, g, gtd
    done = False
    ls_iter = 0
    while ls_iter < max_ls:
        # check conditions
        if f_new > (f + c1 * t * gtd) or (ls_iter > 1 and f_new >= f_prev):
            bracket = [t_prev, t]
            bracket_f = [f_prev, f_new]
            bracket_g = [g_prev, g_new.clone(memory_format=torch.contiguous_format)]
            bracket_gtd = [gtd_prev, gtd_new]
            break

        if abs(gtd_new) <= -c2 * gtd:
            bracket = [t]
            bracket_f = [f_new]
            bracket_g = [g_new]
            done = True
            break

        if gtd_new >= 0:
            bracket = [t_prev, t]
            bracket_f = [f_prev, f_new]
            bracket_g = [g_prev, g_new.clone(memory_format=torch.contiguous_format)]
            bracket_gtd = [gtd_prev, gtd_new]
            break

        # interpolate
        min_step = t + 0.01 * (t - t_prev)
        max_step = t * 10
        tmp = t
        t = _cubic_interpolate(
            t_prev,
            f_prev,
            gtd_prev,
            t,
            f_new,
            gtd_new,
            bounds=(min_step, max_step))

        # next step
        t_prev = tmp
        f_prev = f_new
        g_prev = g_new.clone(memory_format=torch.contiguous_format)
        gtd_prev = gtd_new
        f_new, g_new = obj_func(x, t, d)
        ls_func_evals += 1
        gtd_new = g_new.dot(d)
        ls_iter += 1

    # reached max number of iterations?
    if ls_iter == max_ls:
        bracket = [0, t]
        bracket_f = [f, f_new]
        bracket_g = [g, g_new]

    # zoom phase: we now have a point satisfying the criteria, or
    # a bracket around it. We refine the bracket until we find the
    # exact point satisfying the criteria
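    # `low_pos`/`high_pos` track which bracket endpoint currently has the
    # lower/higher function value; the loop below exits once the strong Wolfe
    # conditions hold at a trial point, the bracket (scaled by the direction
    # norm) shrinks below `tolerance_change`, or `max_ls` iterations are used.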
    insuf_progress = False
    # find high and low points in bracket
    low_pos, high_pos = (0, 1) if bracket_f[0] <= bracket_f[-1] else (1, 0)
    while not done and ls_iter < max_ls:
        # line-search bracket is too small
        if abs(bracket[1] - bracket[0]) * d_norm < tolerance_change:
            break

        # compute new trial value
        t = _cubic_interpolate(bracket[0], bracket_f[0], bracket_gtd[0],
                               bracket[1], bracket_f[1], bracket_gtd[1])

        # test that we are making sufficient progress:
        # if `t` is too close to a boundary, we mark that we are making
        # insufficient progress, and if
        #   + we already made insufficient progress in the last step, or
        #   + `t` lies at one of the boundaries,
        # we move `t` to a position `0.1 * len(bracket)` away from the
        # nearest boundary point.
        eps = 0.1 * (max(bracket) - min(bracket))
        if min(max(bracket) - t, t - min(bracket)) < eps:
            # interpolation close to boundary
            if insuf_progress or t >= max(bracket) or t <= min(bracket):
                # evaluate at 0.1 away from boundary
                if abs(t - max(bracket)) < abs(t - min(bracket)):
                    t = max(bracket) - eps
                else:
                    t = min(bracket) + eps
                insuf_progress = False
            else:
                insuf_progress = True
        else:
            insuf_progress = False
        # Evaluate new point
        f_new, g_new = obj_func(x, t, d)
        ls_func_evals += 1
        gtd_new = g_new.dot(d)
        ls_iter += 1

        if f_new > (f + c1 * t * gtd) or f_new >= bracket_f[low_pos]:
            # Armijo condition not satisfied or not lower than lowest point
            bracket[high_pos] = t
            bracket_f[high_pos] = f_new
            bracket_g[high_pos] = g_new.clone(memory_format=torch.contiguous_format)
            bracket_gtd[high_pos] = gtd_new
            low_pos, high_pos = (0, 1) if bracket_f[0] <= bracket_f[1] else (1, 0)
        else:
            if abs(gtd_new) <= -c2 * gtd:
                # Wolfe conditions satisfied
                done = True
            elif gtd_new * (bracket[high_pos] - bracket[low_pos]) >= 0:
                # old high becomes new low
                bracket[high_pos] = bracket[low_pos]
                bracket_f[high_pos] = bracket_f[low_pos]
                bracket_g[high_pos] = bracket_g[low_pos]
                bracket_gtd[high_pos] = bracket_gtd[low_pos]

            # new point becomes new low
            bracket[low_pos] = t
            bracket_f[low_pos] = f_new
            bracket_g[low_pos] = g_new.clone(memory_format=torch.contiguous_format)
            bracket_gtd[low_pos] = gtd_new

    # return the best point found (the low end of the bracket)
    t = bracket[low_pos]
    f_new = bracket_f[low_pos]
    g_new = bracket_g[low_pos]
    return f_new, g_new, t, ls_func_evals

class LBFGS(Optimizer):
    """Implements L-BFGS algorithm, heavily inspired by `minFunc
    <https://www.cs.ubc.ca/~schmidtm/Software/minFunc.html>`_.

    .. warning::
        This optimizer doesn't support per-parameter options and parameter
        groups (there can be only one).

    .. warning::
        Right now all parameters have to be on a single device. This will be
        improved in the future.

    .. note::
        This is a very memory intensive optimizer (it requires additional
        ``param_bytes * (history_size + 1)`` bytes). If it doesn't fit in memory
        try reducing the history size, or use a different algorithm.

    Args:
        lr (float): learning rate (default: 1)
        max_iter (int): maximal number of iterations per optimization step
            (default: 20)
        max_eval (int): maximal number of function evaluations per optimization
            step (default: max_iter * 1.25).
        tolerance_grad (float): termination tolerance on first order optimality
            (default: 1e-7).
        tolerance_change (float): termination tolerance on function
            value/parameter changes (default: 1e-9).
        history_size (int): update history size (default: 100).
        line_search_fn (str): either 'strong_wolfe' or None (default: None).
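
    Example (illustrative sketch; ``model``, ``loss_fn``, ``input`` and
    ``target`` stand in for the user's own objects)::

        optimizer = torch.optim.LBFGS(model.parameters(),
                                      history_size=10,
                                      max_iter=4,
                                      line_search_fn="strong_wolfe")

        def closure():
            # the closure must zero the gradients, recompute the loss and
            # call backward() so that step() can re-evaluate the objective
            optimizer.zero_grad()
            loss = loss_fn(model(input), target)
            loss.backward()
            return loss

        optimizer.step(closure)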
    """

    def __init__(self,
                 params,
                 lr=1,
                 max_iter=20,
                 max_eval=None,
                 tolerance_grad=1e-7,
                 tolerance_change=1e-9,
                 history_size=100,
                 line_search_fn=None):
        if max_eval is None:
            max_eval = max_iter * 5 // 4
        defaults = dict(
            lr=lr,
            max_iter=max_iter,
            max_eval=max_eval,
            tolerance_grad=tolerance_grad,
            tolerance_change=tolerance_change,
            history_size=history_size,
            line_search_fn=line_search_fn)
        super().__init__(params, defaults)

        if len(self.param_groups) != 1:
            raise ValueError("LBFGS doesn't support per-parameter options "
                             "(parameter groups)")

        self._params = self.param_groups[0]['params']
        self._numel_cache = None
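
    # L-BFGS operates on a single flattened view of all parameters: the
    # helpers below count elements, gather gradients into one flat vector,
    # and scatter a flat update back onto the individual parameter tensors.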
    def _numel(self):
        if self._numel_cache is None:
            self._numel_cache = reduce(lambda total, p: total + p.numel(), self._params, 0)
        return self._numel_cache

    def _gather_flat_grad(self):
        views = []
        for p in self._params:
            if p.grad is None:
                view = p.new(p.numel()).zero_()
            elif p.grad.is_sparse:
                view = p.grad.to_dense().view(-1)
            else:
                view = p.grad.view(-1)
            views.append(view)
        return torch.cat(views, 0)

    def _add_grad(self, step_size, update):
        offset = 0
        for p in self._params:
            numel = p.numel()
            # view as to avoid deprecated pointwise semantics
            p.add_(update[offset:offset + numel].view_as(p), alpha=step_size)
            offset += numel
        assert offset == self._numel()

    def _clone_param(self):
        return [p.clone(memory_format=torch.contiguous_format) for p in self._params]

    def _set_param(self, params_data):
        for p, pdata in zip(self._params, params_data):
            p.copy_(pdata)
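
    # Evaluates the closure at x + t*d by temporarily stepping the live
    # parameters, gathers the flat gradient, then restores the saved
    # parameters x; used as the objective for the strong-Wolfe line search.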
    def _directional_evaluate(self, closure, x, t, d):
        self._add_grad(t, d)
        loss = float(closure())
        flat_grad = self._gather_flat_grad()
        self._set_param(x)
        return loss, flat_grad

    @torch.no_grad()
    def step(self, closure):
        """Performs a single optimization step.

        Args:
            closure (Callable): A closure that reevaluates the model
                and returns the loss.
        """
        assert len(self.param_groups) == 1

        # Make sure the closure is always called with grad enabled
        closure = torch.enable_grad()(closure)

        group = self.param_groups[0]
        lr = group['lr']
        max_iter = group['max_iter']
        max_eval = group['max_eval']
        tolerance_grad = group['tolerance_grad']
        tolerance_change = group['tolerance_change']
        line_search_fn = group['line_search_fn']
        history_size = group['history_size']

        # NOTE: LBFGS has only global state, but we register it as state for
        # the first param, because this helps with casting in load_state_dict
        state = self.state[self._params[0]]
        state.setdefault('func_evals', 0)
        state.setdefault('n_iter', 0)

        # evaluate initial f(x) and df/dx
        orig_loss = closure()
        loss = float(orig_loss)
        current_evals = 1
        state['func_evals'] += 1

        flat_grad = self._gather_flat_grad()
        opt_cond = flat_grad.abs().max() <= tolerance_grad

        # optimal condition
        if opt_cond:
            return orig_loss

        # tensors cached in state (for tracing)
        d = state.get('d')
        t = state.get('t')
        old_dirs = state.get('old_dirs')
        old_stps = state.get('old_stps')
        ro = state.get('ro')
        H_diag = state.get('H_diag')
        prev_flat_grad = state.get('prev_flat_grad')
        prev_loss = state.get('prev_loss')

        n_iter = 0
        # optimize for a max of max_iter iterations
        while n_iter < max_iter:
            # keep track of nb of iterations
            n_iter += 1
            state['n_iter'] += 1

            ############################################################
            # compute gradient descent direction
            ############################################################
            if state['n_iter'] == 1:
                d = flat_grad.neg()
                old_dirs = []
                old_stps = []
                ro = []
                H_diag = 1
            else:
                # do lbfgs update (update memory)
                y = flat_grad.sub(prev_flat_grad)
                s = d.mul(t)
                ys = y.dot(s)  # y*s
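                # accept the update only when the curvature condition
                # y^T s > 0 holds (up to a small threshold), which keeps the
                # implicit inverse Hessian approximation positive definite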
                if ys > 1e-10:
                    # updating memory
                    if len(old_dirs) == history_size:
                        # shift history by one (limited-memory)
                        old_dirs.pop(0)
                        old_stps.pop(0)
                        ro.pop(0)

                    # store new direction/step
                    old_dirs.append(y)
                    old_stps.append(s)
                    ro.append(1. / ys)

                    # update scale of initial Hessian approximation
                    H_diag = ys / y.dot(y)  # (y*y)

                # compute the approximate (L-BFGS) inverse Hessian
                # multiplied by the gradient
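                # (standard two-loop recursion: the backward pass folds the
                # history coefficients al[i] into q, the forward pass scales
                # by H_diag and applies the corrections, giving d ~= -H * g)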
                num_old = len(old_dirs)

                if 'al' not in state:
                    state['al'] = [None] * history_size
                al = state['al']

                # iteration in L-BFGS loop collapsed to use just one buffer
                q = flat_grad.neg()
                for i in range(num_old - 1, -1, -1):
                    al[i] = old_stps[i].dot(q) * ro[i]
                    q.add_(old_dirs[i], alpha=-al[i])

                # multiply by initial Hessian
                # r/d is the final direction
                d = r = torch.mul(q, H_diag)
                for i in range(num_old):
                    be_i = old_dirs[i].dot(r) * ro[i]
                    r.add_(old_stps[i], alpha=al[i] - be_i)

            if prev_flat_grad is None:
                prev_flat_grad = flat_grad.clone(memory_format=torch.contiguous_format)
            else:
                prev_flat_grad.copy_(flat_grad)
            prev_loss = loss

            ############################################################
            # compute step length
            ############################################################
            # reset initial guess for step size
            if state['n_iter'] == 1:
                t = min(1., 1. / flat_grad.abs().sum()) * lr
            else:
                t = lr

            # directional derivative
            gtd = flat_grad.dot(d)  # g * d

            # directional derivative is below tolerance
            if gtd > -tolerance_change:
                break

            # optional line search: user function
            ls_func_evals = 0
            if line_search_fn is not None:
                # perform line search, using user function
                if line_search_fn != "strong_wolfe":
                    raise RuntimeError("only 'strong_wolfe' is supported")
                else:
                    x_init = self._clone_param()

                    def obj_func(x, t, d):
                        return self._directional_evaluate(closure, x, t, d)

                    loss, flat_grad, t, ls_func_evals = _strong_wolfe(
                        obj_func, x_init, t, d, loss, flat_grad, gtd)
                self._add_grad(t, d)
                opt_cond = flat_grad.abs().max() <= tolerance_grad
            else:
                # no line search, simply move with fixed-step
                self._add_grad(t, d)
                if n_iter != max_iter:
                    # re-evaluate function only if not in last iteration
                    # the reason we do this: in a stochastic setting, there is
                    # no point re-evaluating the function here
                    with torch.enable_grad():
                        loss = float(closure())
                    flat_grad = self._gather_flat_grad()
                    opt_cond = flat_grad.abs().max() <= tolerance_grad
                    ls_func_evals = 1
            # update func eval
            current_evals += ls_func_evals
            state['func_evals'] += ls_func_evals

            ############################################################
            # check conditions
            ############################################################
            if n_iter == max_iter:
                break

            if current_evals >= max_eval:
                break

            # optimal condition
            if opt_cond:
                break

            # lack of progress
            if d.mul(t).abs().max() <= tolerance_change:
                break

            if abs(loss - prev_loss) < tolerance_change:
                break

        state['d'] = d
        state['t'] = t
        state['old_dirs'] = old_dirs
        state['old_stps'] = old_stps
        state['ro'] = ro
        state['H_diag'] = H_diag
        state['prev_flat_grad'] = prev_flat_grad
        state['prev_loss'] = prev_loss

        return orig_loss