import torch
from functools import reduce
from .optimizer import Optimizer

__all__ = ['LBFGS']


def _cubic_interpolate(x1, f1, g1, x2, f2, g2, bounds=None):
    # ported from https://github.com/torch/optim/blob/master/polyinterp.lua
    # Compute bounds of interpolation area
    if bounds is not None:
        xmin_bound, xmax_bound = bounds
    else:
        xmin_bound, xmax_bound = (x1, x2) if x1 <= x2 else (x2, x1)

    # Code for most common case: cubic interpolation of 2 points
    #   w/ function and derivative values for both
    # Solution in this case (where x2 is the farthest point):
    #   d1 = g1 + g2 - 3*(f1-f2)/(x1-x2);
    #   d2 = sqrt(d1^2 - g1*g2);
    #   min_pos = x2 - (x2 - x1)*((g2 + d2 - d1)/(g2 - g1 + 2*d2));
    #   t_new = min(max(min_pos,xmin_bound),xmax_bound);
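    # For example, with x1 = -1 (f1 = 1, g1 = -2) and x2 = 1 (f2 = 1, g2 = 2),
    # i.e. samples of f(x) = x**2, these formulas give d1 = 0, d2 = 2 and
    # min_pos = 0, the true minimizer. Note that g1/g2 are typically passed in
    # as 0-dim tensors (the `gtd` values computed in `_strong_wolfe`), which is
    # why `d2_square.sqrt()` below is a tensor method call.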
    d1 = g1 + g2 - 3 * (f1 - f2) / (x1 - x2)
    d2_square = d1**2 - g1 * g2
    if d2_square >= 0:
        d2 = d2_square.sqrt()
        if x1 <= x2:
            min_pos = x2 - (x2 - x1) * ((g2 + d2 - d1) / (g2 - g1 + 2 * d2))
        else:
            min_pos = x1 - (x1 - x2) * ((g1 + d2 - d1) / (g1 - g2 + 2 * d2))
        return min(max(min_pos, xmin_bound), xmax_bound)
    else:
        return (xmin_bound + xmax_bound) / 2.


def _strong_wolfe(obj_func,
                  x,
                  t,
                  d,
                  f,
                  g,
                  gtd,
                  c1=1e-4,
                  c2=0.9,
                  tolerance_change=1e-9,
                  max_ls=25):
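    """Line search satisfying the strong Wolfe conditions.

    ``obj_func(x, t, d)`` must evaluate the objective at ``x + t * d`` and
    return ``(loss, flat_grad)``; ``f``, ``g`` and ``gtd`` are the loss, flat
    gradient and directional derivative at the starting point, ``t`` is the
    initial step size and ``d`` the search direction. Returns
    ``(f_new, g_new, t, ls_func_evals)``.
    """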
    # ported from https://github.com/torch/optim/blob/master/lswolfe.lua
    d_norm = d.abs().max()
    g = g.clone(memory_format=torch.contiguous_format)
    # evaluate objective and gradient using initial step
    f_new, g_new = obj_func(x, t, d)
    ls_func_evals = 1
    gtd_new = g_new.dot(d)

    # bracket an interval containing a point satisfying the Wolfe criteria
    t_prev, f_prev, g_prev, gtd_prev = 0, f, g, gtd
    done = False
    ls_iter = 0
    while ls_iter < max_ls:
        # check conditions
        if f_new > (f + c1 * t * gtd) or (ls_iter > 1 and f_new >= f_prev):
            bracket = [t_prev, t]
            bracket_f = [f_prev, f_new]
            bracket_g = [g_prev, g_new.clone(memory_format=torch.contiguous_format)]
            bracket_gtd = [gtd_prev, gtd_new]
            break

        if abs(gtd_new) <= -c2 * gtd:
            bracket = [t]
            bracket_f = [f_new]
            bracket_g = [g_new]
            done = True
            break

        if gtd_new >= 0:
            bracket = [t_prev, t]
            bracket_f = [f_prev, f_new]
            bracket_g = [g_prev, g_new.clone(memory_format=torch.contiguous_format)]
            bracket_gtd = [gtd_prev, gtd_new]
            break

        # interpolate
        min_step = t + 0.01 * (t - t_prev)
        max_step = t * 10
        tmp = t
        t = _cubic_interpolate(
            t_prev,
            f_prev,
            gtd_prev,
            t,
            f_new,
            gtd_new,
            bounds=(min_step, max_step))

        # next step
        t_prev = tmp
        f_prev = f_new
        g_prev = g_new.clone(memory_format=torch.contiguous_format)
        gtd_prev = gtd_new
        f_new, g_new = obj_func(x, t, d)
        ls_func_evals += 1
        gtd_new = g_new.dot(d)
        ls_iter += 1

    # reached max number of iterations?
    if ls_iter == max_ls:
        bracket = [0, t]
        bracket_f = [f, f_new]
        bracket_g = [g, g_new]

    # zoom phase: we now have a point satisfying the criteria, or
    # a bracket around it. We refine the bracket until we find the
    # exact point satisfying the criteria
    insuf_progress = False
    # find high and low points in bracket
    low_pos, high_pos = (0, 1) if bracket_f[0] <= bracket_f[-1] else (1, 0)
    while not done and ls_iter < max_ls:
        # line-search bracket is too small to continue
        if abs(bracket[1] - bracket[0]) * d_norm < tolerance_change:
            break

        # compute new trial value
        t = _cubic_interpolate(bracket[0], bracket_f[0], bracket_gtd[0],
                               bracket[1], bracket_f[1], bracket_gtd[1])

        # test that we are making sufficient progress:
        # if `t` is too close to a boundary, we mark that we are making
        # insufficient progress, and if
        #   + we made insufficient progress in the last step, or
        #   + `t` lies at one of the boundaries,
        # we move `t` to a position `0.1 * (bracket width)` away from the
        # nearest boundary point.
        eps = 0.1 * (max(bracket) - min(bracket))
        if min(max(bracket) - t, t - min(bracket)) < eps:
            # interpolation close to boundary
            if insuf_progress or t >= max(bracket) or t <= min(bracket):
                # evaluate at 0.1 away from boundary
                if abs(t - max(bracket)) < abs(t - min(bracket)):
                    t = max(bracket) - eps
                else:
                    t = min(bracket) + eps
                insuf_progress = False
            else:
                insuf_progress = True
        else:
            insuf_progress = False

        # Evaluate new point
        f_new, g_new = obj_func(x, t, d)
        ls_func_evals += 1
        gtd_new = g_new.dot(d)
        ls_iter += 1

        if f_new > (f + c1 * t * gtd) or f_new >= bracket_f[low_pos]:
            # Armijo condition not satisfied or not lower than lowest point
            bracket[high_pos] = t
            bracket_f[high_pos] = f_new
            bracket_g[high_pos] = g_new.clone(memory_format=torch.contiguous_format)
            bracket_gtd[high_pos] = gtd_new
            low_pos, high_pos = (0, 1) if bracket_f[0] <= bracket_f[1] else (1, 0)
        else:
            if abs(gtd_new) <= -c2 * gtd:
                # Wolfe conditions satisfied
                done = True
            elif gtd_new * (bracket[high_pos] - bracket[low_pos]) >= 0:
                # old high becomes new low
                bracket[high_pos] = bracket[low_pos]
                bracket_f[high_pos] = bracket_f[low_pos]
                bracket_g[high_pos] = bracket_g[low_pos]
                bracket_gtd[high_pos] = bracket_gtd[low_pos]

            # new point becomes new low
            bracket[low_pos] = t
            bracket_f[low_pos] = f_new
            bracket_g[low_pos] = g_new.clone(memory_format=torch.contiguous_format)
            bracket_gtd[low_pos] = gtd_new

    # return stuff
    t = bracket[low_pos]
    f_new = bracket_f[low_pos]
    g_new = bracket_g[low_pos]
    return f_new, g_new, t, ls_func_evals


class LBFGS(Optimizer):
    """Implements L-BFGS algorithm, heavily inspired by `minFunc
    <https://www.cs.ubc.ca/~schmidtm/Software/minFunc.html>`_.

    .. warning::
        This optimizer doesn't support per-parameter options and parameter
        groups (there can be only one).

    .. warning::
        Right now all parameters have to be on a single device. This will be
        improved in the future.

    .. note::
        This is a very memory intensive optimizer (it requires additional
        ``param_bytes * (history_size + 1)`` bytes). If it doesn't fit in memory
        try reducing the history size, or use a different algorithm.

    Args:
        lr (float): learning rate (default: 1)
        max_iter (int): maximal number of iterations per optimization step
            (default: 20)
        max_eval (int): maximal number of function evaluations per optimization
            step (default: max_iter * 1.25).
        tolerance_grad (float): termination tolerance on first order optimality
            (default: 1e-7).
        tolerance_change (float): termination tolerance on function
            value/parameter changes (default: 1e-9).
        history_size (int): update history size (default: 100).
        line_search_fn (str): either 'strong_wolfe' or None (default: None).
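
    Example (illustrative sketch; assumes ``model``, ``input``, ``target`` and
    ``loss_fn`` are defined elsewhere)::

        optimizer = LBFGS(model.parameters(), lr=1, line_search_fn='strong_wolfe')

        def closure():
            optimizer.zero_grad()
            loss = loss_fn(model(input), target)
            loss.backward()
            return loss

        optimizer.step(closure)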
- """

    def __init__(self,
                 params,
                 lr=1,
                 max_iter=20,
                 max_eval=None,
                 tolerance_grad=1e-7,
                 tolerance_change=1e-9,
                 history_size=100,
                 line_search_fn=None):
        if max_eval is None:
            max_eval = max_iter * 5 // 4
        defaults = dict(
            lr=lr,
            max_iter=max_iter,
            max_eval=max_eval,
            tolerance_grad=tolerance_grad,
            tolerance_change=tolerance_change,
            history_size=history_size,
            line_search_fn=line_search_fn)
        super().__init__(params, defaults)

        if len(self.param_groups) != 1:
            raise ValueError("LBFGS doesn't support per-parameter options "
                             "(parameter groups)")

        self._params = self.param_groups[0]['params']
        self._numel_cache = None

    def _numel(self):
        if self._numel_cache is None:
            self._numel_cache = reduce(lambda total, p: total + p.numel(), self._params, 0)
        return self._numel_cache

    def _gather_flat_grad(self):
        views = []
        for p in self._params:
            if p.grad is None:
                view = p.new(p.numel()).zero_()
            elif p.grad.is_sparse:
                view = p.grad.to_dense().view(-1)
            else:
                view = p.grad.view(-1)
            views.append(view)
        return torch.cat(views, 0)

    def _add_grad(self, step_size, update):
        offset = 0
        for p in self._params:
            numel = p.numel()
            # view as to avoid deprecated pointwise semantics
            p.add_(update[offset:offset + numel].view_as(p), alpha=step_size)
            offset += numel
        assert offset == self._numel()

    def _clone_param(self):
        return [p.clone(memory_format=torch.contiguous_format) for p in self._params]

    def _set_param(self, params_data):
        for p, pdata in zip(self._params, params_data):
            p.copy_(pdata)

    def _directional_evaluate(self, closure, x, t, d):
        self._add_grad(t, d)
        loss = float(closure())
        flat_grad = self._gather_flat_grad()
        self._set_param(x)
        return loss, flat_grad

    @torch.no_grad()
    def step(self, closure):
        """Performs a single optimization step.

        Args:
            closure (Callable): A closure that reevaluates the model
                and returns the loss.
        """
        assert len(self.param_groups) == 1

        # Make sure the closure is always called with grad enabled
        closure = torch.enable_grad()(closure)

        group = self.param_groups[0]
        lr = group['lr']
        max_iter = group['max_iter']
        max_eval = group['max_eval']
        tolerance_grad = group['tolerance_grad']
        tolerance_change = group['tolerance_change']
        line_search_fn = group['line_search_fn']
        history_size = group['history_size']

        # NOTE: LBFGS has only global state, but we register it as state for
        # the first param, because this helps with casting in load_state_dict
        state = self.state[self._params[0]]
        state.setdefault('func_evals', 0)
        state.setdefault('n_iter', 0)

        # evaluate initial f(x) and df/dx
        orig_loss = closure()
        loss = float(orig_loss)
        current_evals = 1
        state['func_evals'] += 1

        flat_grad = self._gather_flat_grad()
        opt_cond = flat_grad.abs().max() <= tolerance_grad

        # optimal condition
        if opt_cond:
            return orig_loss

        # tensors cached in state (for tracing)
        d = state.get('d')
        t = state.get('t')
        old_dirs = state.get('old_dirs')
        old_stps = state.get('old_stps')
        ro = state.get('ro')
        H_diag = state.get('H_diag')
        prev_flat_grad = state.get('prev_flat_grad')
        prev_loss = state.get('prev_loss')

        n_iter = 0
        # optimize for a max of max_iter iterations
        while n_iter < max_iter:
            # keep track of nb of iterations
            n_iter += 1
            state['n_iter'] += 1

            ############################################################
            # compute gradient descent direction
            ############################################################
            if state['n_iter'] == 1:
                d = flat_grad.neg()
                old_dirs = []
                old_stps = []
                ro = []
                H_diag = 1
            else:
                # do lbfgs update (update memory)
                y = flat_grad.sub(prev_flat_grad)
                s = d.mul(t)
                ys = y.dot(s)  # y*s
                if ys > 1e-10:
                    # updating memory
                    if len(old_dirs) == history_size:
                        # shift history by one (limited-memory)
                        old_dirs.pop(0)
                        old_stps.pop(0)
                        ro.pop(0)

                    # store new direction/step
                    old_dirs.append(y)
                    old_stps.append(s)
                    ro.append(1. / ys)

                    # update scale of initial Hessian approximation
                    H_diag = ys / y.dot(y)  # (y*y)

                # compute the approximate (L-BFGS) inverse Hessian
                # multiplied by the gradient
                num_old = len(old_dirs)

                if 'al' not in state:
                    state['al'] = [None] * history_size
                al = state['al']

                # iteration in L-BFGS loop collapsed to use just one buffer
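                # (two-loop recursion) the backward loop below computes
                # al[i] = ro[i] * s_i^T q and subtracts al[i] * y_i from q;
                # after scaling by H_diag, the forward loop adds
                # (al[i] - be_i) * s_i, so that d ends up as -H * grad,
                # the L-BFGS search direction.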
                q = flat_grad.neg()
                for i in range(num_old - 1, -1, -1):
                    al[i] = old_stps[i].dot(q) * ro[i]
                    q.add_(old_dirs[i], alpha=-al[i])

                # multiply by initial Hessian
                # r/d is the final direction
                d = r = torch.mul(q, H_diag)
                for i in range(num_old):
                    be_i = old_dirs[i].dot(r) * ro[i]
                    r.add_(old_stps[i], alpha=al[i] - be_i)

            if prev_flat_grad is None:
                prev_flat_grad = flat_grad.clone(memory_format=torch.contiguous_format)
            else:
                prev_flat_grad.copy_(flat_grad)
            prev_loss = loss

            ############################################################
            # compute step length
            ############################################################
            # reset initial guess for step size
            if state['n_iter'] == 1:
                t = min(1., 1. / flat_grad.abs().sum()) * lr
            else:
                t = lr

            # directional derivative
            gtd = flat_grad.dot(d)  # g * d

            # directional derivative is below tolerance
            if gtd > -tolerance_change:
                break

            # optional line search: user function
            ls_func_evals = 0
            if line_search_fn is not None:
                # perform line search, using user function
                if line_search_fn != "strong_wolfe":
                    raise RuntimeError("only 'strong_wolfe' is supported")
                else:
                    x_init = self._clone_param()

                    def obj_func(x, t, d):
                        return self._directional_evaluate(closure, x, t, d)

                    loss, flat_grad, t, ls_func_evals = _strong_wolfe(
                        obj_func, x_init, t, d, loss, flat_grad, gtd)
                self._add_grad(t, d)
                opt_cond = flat_grad.abs().max() <= tolerance_grad
            else:
                # no line search, simply move with fixed-step
                self._add_grad(t, d)
                if n_iter != max_iter:
                    # re-evaluate function only if not in last iteration
                    # the reason we do this: in a stochastic setting,
                    # there is no point in re-evaluating the function here
                    with torch.enable_grad():
                        loss = float(closure())
                    flat_grad = self._gather_flat_grad()
                    opt_cond = flat_grad.abs().max() <= tolerance_grad
                    ls_func_evals = 1
            # update func eval
            current_evals += ls_func_evals
            state['func_evals'] += ls_func_evals

            ############################################################
            # check conditions
            ############################################################
            if n_iter == max_iter:
                break

            if current_evals >= max_eval:
                break

            # optimal condition
            if opt_cond:
                break

            # lack of progress
            if d.mul(t).abs().max() <= tolerance_change:
                break

            if abs(loss - prev_loss) < tolerance_change:
                break

        state['d'] = d
        state['t'] = t
        state['old_dirs'] = old_dirs
        state['old_stps'] = old_stps
        state['ro'] = ro
        state['H_diag'] = H_diag
        state['prev_flat_grad'] = prev_flat_grad
        state['prev_loss'] = prev_loss

        return orig_loss