# sgd.py
import torch
from torch import Tensor
from .optimizer import (Optimizer, required, _use_grad_for_differentiable, _default_to_fused_or_foreach,
                        _differentiable_doc, _foreach_doc, _maximize_doc)
from typing import List, Optional
from torch.utils._foreach_utils import _group_tensors_by_device_and_dtype

__all__ = ['SGD', 'sgd']


class SGD(Optimizer):
    def __init__(self, params, lr=required, momentum=0, dampening=0,
                 weight_decay=0, nesterov=False, *, maximize: bool = False, foreach: Optional[bool] = None,
                 differentiable: bool = False):
        if lr is not required and lr < 0.0:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if momentum < 0.0:
            raise ValueError("Invalid momentum value: {}".format(momentum))
        if weight_decay < 0.0:
            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))

        defaults = dict(lr=lr, momentum=momentum, dampening=dampening,
                        weight_decay=weight_decay, nesterov=nesterov,
                        maximize=maximize, foreach=foreach,
                        differentiable=differentiable)
        if nesterov and (momentum <= 0 or dampening != 0):
            raise ValueError("Nesterov momentum requires a momentum and zero dampening")
        super().__init__(params, defaults)
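
    # __setstate__ backfills option keys that may be missing from state dicts saved
    # by versions of this optimizer that predate 'maximize', 'foreach', or
    # 'differentiable', so loading older checkpoints keeps working.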
    def __setstate__(self, state):
        super().__setstate__(state)
        for group in self.param_groups:
            group.setdefault('nesterov', False)
            group.setdefault('maximize', False)
            group.setdefault('foreach', None)
            group.setdefault('differentiable', False)
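
    # _init_group collects, for one param group, the parameters that actually have
    # gradients, their gradients, and any existing momentum buffers, and reports
    # whether any gradient is sparse so the update paths can avoid ops that do not
    # support sparse tensors.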
    def _init_group(self, group, params_with_grad, d_p_list, momentum_buffer_list):
        has_sparse_grad = False

        for p in group['params']:
            if p.grad is not None:
                params_with_grad.append(p)
                d_p_list.append(p.grad)
                if p.grad.is_sparse:
                    has_sparse_grad = True

                state = self.state[p]
                if 'momentum_buffer' not in state:
                    momentum_buffer_list.append(None)
                else:
                    momentum_buffer_list.append(state['momentum_buffer'])

        return has_sparse_grad
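
    # _use_grad_for_differentiable runs the step under no_grad() in the common case,
    # but keeps autograd enabled when the optimizer was constructed with
    # differentiable=True so the update itself can be backpropagated through.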
    @_use_grad_for_differentiable
    def step(self, closure=None):
        """Performs a single optimization step.

        Args:
            closure (Callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            params_with_grad = []
            d_p_list = []
            momentum_buffer_list = []

            has_sparse_grad = self._init_group(group, params_with_grad, d_p_list, momentum_buffer_list)

            sgd(params_with_grad,
                d_p_list,
                momentum_buffer_list,
                weight_decay=group['weight_decay'],
                momentum=group['momentum'],
                lr=group['lr'],
                dampening=group['dampening'],
                nesterov=group['nesterov'],
                maximize=group['maximize'],
                has_sparse_grad=has_sparse_grad,
                foreach=group['foreach'])

            # update momentum_buffers in state
            for p, momentum_buffer in zip(params_with_grad, momentum_buffer_list):
                state = self.state[p]
                state['momentum_buffer'] = momentum_buffer

        return loss


SGD.__doc__ = r"""\
    Implements stochastic gradient descent (optionally with momentum).

    .. math::
       \begin{aligned}
            &\rule{110mm}{0.4pt} \\
            &\textbf{input} : \gamma \text{ (lr)}, \: \theta_0 \text{ (params)}, \: f(\theta)
                \text{ (objective)}, \: \lambda \text{ (weight decay)}, \\
            &\hspace{13mm} \:\mu \text{ (momentum)}, \:\tau \text{ (dampening)},
            \:\textit{ nesterov,}\:\textit{ maximize} \\[-1.ex]
            &\rule{110mm}{0.4pt} \\
            &\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do} \\
            &\hspace{5mm}g_t \leftarrow \nabla_{\theta} f_t (\theta_{t-1}) \\
            &\hspace{5mm}\textbf{if} \: \lambda \neq 0 \\
            &\hspace{10mm} g_t \leftarrow g_t + \lambda \theta_{t-1} \\
            &\hspace{5mm}\textbf{if} \: \mu \neq 0 \\
            &\hspace{10mm}\textbf{if} \: t > 1 \\
            &\hspace{15mm} \textbf{b}_t \leftarrow \mu \textbf{b}_{t-1} + (1-\tau) g_t \\
            &\hspace{10mm}\textbf{else} \\
            &\hspace{15mm} \textbf{b}_t \leftarrow g_t \\
            &\hspace{10mm}\textbf{if} \: \textit{nesterov} \\
            &\hspace{15mm} g_t \leftarrow g_{t} + \mu \textbf{b}_t \\
            &\hspace{10mm}\textbf{else} \\[-1.ex]
            &\hspace{15mm} g_t \leftarrow \textbf{b}_t \\
            &\hspace{5mm}\textbf{if} \: \textit{maximize} \\
            &\hspace{10mm}\theta_t \leftarrow \theta_{t-1} + \gamma g_t \\[-1.ex]
            &\hspace{5mm}\textbf{else} \\[-1.ex]
            &\hspace{10mm}\theta_t \leftarrow \theta_{t-1} - \gamma g_t \\[-1.ex]
            &\rule{110mm}{0.4pt} \\[-1.ex]
            &\bf{return} \: \theta_t \\[-1.ex]
            &\rule{110mm}{0.4pt} \\[-1.ex]
       \end{aligned}

    Nesterov momentum is based on the formula from
    `On the importance of initialization and momentum in deep learning`__.
    """ + r"""
    Args:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float): learning rate
        momentum (float, optional): momentum factor (default: 0)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
        dampening (float, optional): dampening for momentum (default: 0)
        nesterov (bool, optional): enables Nesterov momentum (default: False)
        {maximize}
        {foreach}
        {differentiable}
    """.format(maximize=_maximize_doc, foreach=_foreach_doc, differentiable=_differentiable_doc) + r"""

    Example:
        >>> # xdoctest: +SKIP
        >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
        >>> optimizer.zero_grad()
        >>> loss_fn(model(input), target).backward()
        >>> optimizer.step()

    __ http://www.cs.toronto.edu/%7Ehinton/absps/momentum.pdf

    .. note::
        The implementation of SGD with Momentum/Nesterov subtly differs from
        Sutskever et al. and implementations in some other frameworks.

        Considering the specific case of Momentum, the update can be written as

        .. math::
            \begin{aligned}
                v_{t+1} & = \mu * v_{t} + g_{t+1}, \\
                p_{t+1} & = p_{t} - \text{lr} * v_{t+1},
            \end{aligned}

        where :math:`p`, :math:`g`, :math:`v` and :math:`\mu` denote the
        parameters, gradient, velocity, and momentum respectively.

        This is in contrast to Sutskever et al. and
        other frameworks which employ an update of the form

        .. math::
            \begin{aligned}
                v_{t+1} & = \mu * v_{t} + \text{lr} * g_{t+1}, \\
                p_{t+1} & = p_{t} - v_{t+1}.
            \end{aligned}

        The Nesterov version is analogously modified.

        Moreover, the initial value of the momentum buffer is set to the
        gradient value at the first step. This is in contrast to some other
        frameworks that initialize it to all zeros.
    """


def sgd(params: List[Tensor],
        d_p_list: List[Tensor],
        momentum_buffer_list: List[Optional[Tensor]],
        # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627
        # setting this as kwarg for now as functional API is compiled by torch/distributed/optim
        has_sparse_grad: bool = None,
        foreach: Optional[bool] = None,
        *,
        weight_decay: float,
        momentum: float,
        lr: float,
        dampening: float,
        nesterov: bool,
        maximize: bool):
    r"""Functional API that performs SGD algorithm computation.

    See :class:`~torch.optim.SGD` for details.
    """
    if foreach is None:
        # why must we be explicit about an if statement for torch.jit.is_scripting here?
        # because JIT can't handle Optionals nor fancy conditionals when scripting
        if not torch.jit.is_scripting():
            _, foreach = _default_to_fused_or_foreach(params, differentiable=False, use_fused=False)
        else:
            foreach = False

    if foreach and torch.jit.is_scripting():
        raise RuntimeError('torch.jit.script not supported with foreach optimizers')

    if foreach and not torch.jit.is_scripting():
        func = _multi_tensor_sgd
    else:
        func = _single_tensor_sgd

    func(params,
         d_p_list,
         momentum_buffer_list,
         weight_decay=weight_decay,
         momentum=momentum,
         lr=lr,
         dampening=dampening,
         nesterov=nesterov,
         has_sparse_grad=has_sparse_grad,
         maximize=maximize)
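

# _single_tensor_sgd is the reference implementation: a plain Python loop that
# applies the update one parameter at a time, which also accommodates sparse gradients.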
def _single_tensor_sgd(params: List[Tensor],
                       d_p_list: List[Tensor],
                       momentum_buffer_list: List[Optional[Tensor]],
                       *,
                       weight_decay: float,
                       momentum: float,
                       lr: float,
                       dampening: float,
                       nesterov: bool,
                       maximize: bool,
                       has_sparse_grad: bool):
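    # Per-parameter update: negate the gradient when maximizing, apply L2 weight
    # decay, fold the result into the momentum buffer, then take the step.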
    for i, param in enumerate(params):
        d_p = d_p_list[i] if not maximize else -d_p_list[i]

        if weight_decay != 0:
            d_p = d_p.add(param, alpha=weight_decay)

        if momentum != 0:
            buf = momentum_buffer_list[i]

            if buf is None:
                buf = torch.clone(d_p).detach()
                momentum_buffer_list[i] = buf
            else:
                buf.mul_(momentum).add_(d_p, alpha=1 - dampening)

            if nesterov:
                d_p = d_p.add(buf, alpha=momentum)
            else:
                d_p = buf

        param.add_(d_p, alpha=-lr)
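

# _multi_tensor_sgd is the foreach fast path: parameters are grouped by device and
# dtype so each group can be updated with fused torch._foreach_* calls instead of a
# Python-level loop over individual tensors.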
def _multi_tensor_sgd(params: List[Tensor],
                      grads: List[Tensor],
                      momentum_buffer_list: List[Optional[Tensor]],
                      *,
                      weight_decay: float,
                      momentum: float,
                      lr: float,
                      dampening: float,
                      nesterov: bool,
                      maximize: bool,
                      has_sparse_grad: bool):
    if len(params) == 0:
        return
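
    # Group tensors by (device, dtype); with_indices=True also returns each tensor's
    # position in the original lists so newly created momentum buffers can be written
    # back into momentum_buffer_list (and from there into the optimizer state).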
    grouped_tensors = _group_tensors_by_device_and_dtype([params, grads, momentum_buffer_list], with_indices=True)
    for device_params, device_grads, device_momentum_buffer_list, indices in grouped_tensors.values():
        device_has_sparse_grad = any(grad.is_sparse for grad in device_grads)

        if maximize:
            device_grads = torch._foreach_neg(tuple(device_grads))  # type: ignore[assignment]

        if weight_decay != 0:
            device_grads = torch._foreach_add(device_grads, device_params, alpha=weight_decay)
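
        # Momentum: if every parameter in this group already has a buffer, update all
        # buffers at once with fused foreach ops; otherwise fall back to a per-tensor
        # loop that also initializes the missing buffers from the current gradients.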
        if momentum != 0:
            bufs = []

            all_states_with_momentum_buffer = True
            for i in range(len(device_momentum_buffer_list)):
                if device_momentum_buffer_list[i] is None:
                    all_states_with_momentum_buffer = False
                    break
                else:
                    bufs.append(device_momentum_buffer_list[i])

            if all_states_with_momentum_buffer:
                torch._foreach_mul_(bufs, momentum)
                torch._foreach_add_(bufs, device_grads, alpha=1 - dampening)
            else:
                bufs = []
                for i in range(len(device_momentum_buffer_list)):
                    if device_momentum_buffer_list[i] is None:
                        buf = device_momentum_buffer_list[i] = momentum_buffer_list[indices[i]] = \
                            torch.clone(device_grads[i]).detach()
                    else:
                        buf = device_momentum_buffer_list[i]
                        buf.mul_(momentum).add_(device_grads[i], alpha=1 - dampening)

                    bufs.append(buf)

            if nesterov:
                torch._foreach_add_(device_grads, bufs, alpha=momentum)
            else:
                device_grads = bufs

        if not device_has_sparse_grad:
            torch._foreach_add_(device_params, device_grads, alpha=-lr)
        else:
            # foreach APIs don't support sparse
            for i in range(len(device_params)):
                device_params[i].add_(device_grads[i], alpha=-lr)