nadam.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331
  1. import torch
  2. from torch import Tensor
  3. from .optimizer import (Optimizer, _use_grad_for_differentiable, _get_value, _dispatch_sqrt, _stack_if_compiling,
  4. _differentiable_doc, _foreach_doc, _default_to_fused_or_foreach)
  5. from typing import List, Optional
  6. from torch.utils._foreach_utils import _group_tensors_by_device_and_dtype
  7. __all__ = ['NAdam', 'nadam']
  8. class NAdam(Optimizer):
  9. def __init__(self, params, lr=2e-3, betas=(0.9, 0.999), eps=1e-8,
  10. weight_decay=0, momentum_decay=4e-3, *, foreach: Optional[bool] = None,
  11. differentiable: bool = False):
  12. if not 0.0 <= lr:
  13. raise ValueError("Invalid learning rate: {}".format(lr))
  14. if not 0.0 <= eps:
  15. raise ValueError("Invalid epsilon value: {}".format(eps))
  16. if not 0.0 <= betas[0] < 1.0:
  17. raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
  18. if not 0.0 <= betas[1] < 1.0:
  19. raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
  20. if not 0.0 <= weight_decay:
  21. raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
  22. if not 0.0 <= momentum_decay:
  23. raise ValueError("Invalid momentum_decay value: {}".format(momentum_decay))
  24. defaults = dict(lr=lr, betas=betas, eps=eps,
  25. weight_decay=weight_decay, momentum_decay=momentum_decay,
  26. foreach=foreach, differentiable=differentiable)
  27. super().__init__(params, defaults)
  28. def __setstate__(self, state):
  29. super().__setstate__(state)
  30. for group in self.param_groups:
  31. group.setdefault('foreach', None)
  32. group.setdefault('differentiable', False)
  33. state_values = list(self.state.values())
  34. step_is_tensor = (len(state_values) != 0) and torch.is_tensor(state_values[0]['step'])
  35. if not step_is_tensor:
  36. for s in state_values:
  37. s['step'] = torch.tensor(float(s['step']))
  38. mu_product_is_tensor = (len(state_values) != 0) and torch.is_tensor(state_values[0]['mu_product'])
  39. if not mu_product_is_tensor:
  40. for s in state_values:
  41. s['mu_product'] = torch.tensor(s['mu_product'])
  42. def _init_group(self, group, params_with_grad, grads, exp_avgs, exp_avg_sqs, mu_products, state_steps):
  43. for p in group['params']:
  44. if p.grad is not None:
  45. params_with_grad.append(p)
  46. if p.grad.is_sparse:
  47. raise RuntimeError('NAdam does not support sparse gradients')
  48. grads.append(p.grad)
  49. state = self.state[p]
  50. # Lazy state initialization
  51. if len(state) == 0:
  52. state['step'] = torch.tensor(0.)
  53. state['mu_product'] = torch.tensor(1.)
  54. # Exponential moving average of gradient values
  55. state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)
  56. # Exponential moving average of squared gradient values
  57. state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
  58. exp_avgs.append(state['exp_avg'])
  59. exp_avg_sqs.append(state['exp_avg_sq'])
  60. mu_products.append(state['mu_product'])
  61. state_steps.append(state['step'])
  62. @_use_grad_for_differentiable
  63. def step(self, closure=None):
  64. """Performs a single optimization step.
  65. Args:
  66. closure (Callable, optional): A closure that reevaluates the model
  67. and returns the loss.
  68. """
  69. loss = None
  70. if closure is not None:
  71. with torch.enable_grad():
  72. loss = closure()
  73. for group in self.param_groups:
  74. params_with_grad = []
  75. grads = []
  76. exp_avgs = []
  77. exp_avg_sqs = []
  78. mu_products = []
  79. state_steps = []
  80. beta1, beta2 = group['betas']
  81. self._init_group(group, params_with_grad, grads, exp_avgs, exp_avg_sqs, mu_products, state_steps)
  82. nadam(params_with_grad,
  83. grads,
  84. exp_avgs,
  85. exp_avg_sqs,
  86. mu_products,
  87. state_steps,
  88. beta1=beta1,
  89. beta2=beta2,
  90. lr=group['lr'],
  91. weight_decay=group['weight_decay'],
  92. momentum_decay=group['momentum_decay'],
  93. eps=group['eps'],
  94. foreach=group['foreach'],
  95. differentiable=group['differentiable'])
  96. return loss
  97. NAdam.__doc__ = r"""Implements NAdam algorithm.
  98. .. math::
  99. \begin{aligned}
  100. &\rule{110mm}{0.4pt} \\
  101. &\textbf{input} : \gamma_t \text{ (lr)}, \: \beta_1,\beta_2 \text{ (betas)},
  102. \: \theta_0 \text{ (params)}, \: f(\theta) \text{ (objective)} \\
  103. &\hspace{13mm} \: \lambda \text{ (weight decay)}, \:\psi \text{ (momentum decay)} \\
  104. &\textbf{initialize} : m_0 \leftarrow 0 \text{ ( first moment)},
  105. v_0 \leftarrow 0 \text{ ( second moment)} \\[-1.ex]
  106. &\rule{110mm}{0.4pt} \\
  107. &\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do} \\
  108. &\hspace{5mm}g_t \leftarrow \nabla_{\theta} f_t (\theta_{t-1}) \\
  109. &\hspace{5mm}if \: \lambda \neq 0 \\
  110. &\hspace{10mm} g_t \leftarrow g_t + \lambda \theta_{t-1} \\
  111. &\hspace{5mm} \mu_t \leftarrow \beta_1 \big(1 - \frac{1}{2} 0.96^{t \psi} \big) \\
  112. &\hspace{5mm} \mu_{t+1} \leftarrow \beta_1 \big(1 - \frac{1}{2} 0.96^{(t+1)\psi}\big)\\
  113. &\hspace{5mm}m_t \leftarrow \beta_1 m_{t-1} + (1 - \beta_1) g_t \\
  114. &\hspace{5mm}v_t \leftarrow \beta_2 v_{t-1} + (1-\beta_2) g^2_t \\
  115. &\hspace{5mm}\widehat{m_t} \leftarrow \mu_{t+1} m_t/(1-\prod_{i=1}^{t+1}\mu_i)\\[-1.ex]
  116. & \hspace{11mm} + (1-\mu_t) g_t /(1-\prod_{i=1}^{t} \mu_{i}) \\
  117. &\hspace{5mm}\widehat{v_t} \leftarrow v_t/\big(1-\beta_2^t \big) \\
  118. &\hspace{5mm}\theta_t \leftarrow \theta_{t-1} - \gamma \widehat{m_t}/
  119. \big(\sqrt{\widehat{v_t}} + \epsilon \big) \\
  120. &\rule{110mm}{0.4pt} \\[-1.ex]
  121. &\bf{return} \: \theta_t \\[-1.ex]
  122. &\rule{110mm}{0.4pt} \\[-1.ex]
  123. \end{aligned}
  124. For further details regarding the algorithm we refer to `Incorporating Nesterov Momentum into Adam`_.
  125. """ + r"""
  126. Args:
  127. params (iterable): iterable of parameters to optimize or dicts defining
  128. parameter groups
  129. lr (float, optional): learning rate (default: 2e-3)
  130. betas (Tuple[float, float], optional): coefficients used for computing
  131. running averages of gradient and its square (default: (0.9, 0.999))
  132. eps (float, optional): term added to the denominator to improve
  133. numerical stability (default: 1e-8)
  134. weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
  135. momentum_decay (float, optional): momentum momentum_decay (default: 4e-3)
  136. {foreach}
  137. {differentiable}
  138. .. _Incorporating Nesterov Momentum into Adam:
  139. https://openreview.net/forum?id=OM0jvwB8jIp57ZJjtNEZ
  140. """.format(foreach=_foreach_doc, differentiable=_differentiable_doc)
def nadam(params: List[Tensor],
          grads: List[Tensor],
          exp_avgs: List[Tensor],
          exp_avg_sqs: List[Tensor],
          mu_products: List[Tensor],
          state_steps: List[Tensor],
          # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627
          # setting this as kwarg for now as functional API is compiled by torch/distributed/optim
          foreach: Optional[bool] = None,
          differentiable: bool = False,
          *,
          beta1: float,
          beta2: float,
          lr: float,
          weight_decay: float,
          momentum_decay: float,
          eps: float):
    r"""Functional API that performs NAdam algorithm computation.

    See :class:`~torch.optim.NAdam` for details.

    The six list arguments are parallel (entry ``i`` of each belongs to
    ``params[i]``) and are forwarded unchanged, together with the scalar
    hyper-parameters, to the selected implementation.

    Args:
        foreach: ``True`` selects ``_multi_tensor_nadam`` and ``False`` selects
            ``_single_tensor_nadam``. ``None`` lets ``_default_to_fused_or_foreach``
            choose; it is passed ``params`` and ``differentiable``.
        differentiable: forwarded to the chosen implementation.

    Raises:
        RuntimeError: if any entry of ``state_steps`` or ``mu_products`` is not a
            tensor, or if ``foreach`` resolves to true under ``torch.jit.script``.
    """
    # The implementations update these counters in place (e.g. `step_t += 1`,
    # `mu_product *= mu`), so they must be tensors; the message says older
    # callers could pass non-tensor values.
    if not all(isinstance(t, torch.Tensor) for t in state_steps):
        raise RuntimeError("API has changed, `state_steps` argument must contain a list of singleton tensors")
    if not all(isinstance(t, torch.Tensor) for t in mu_products):
        raise RuntimeError("API has changed, `mu_products` argument must contain a list of singleton tensors")
    if foreach is None:
        _, foreach = _default_to_fused_or_foreach(params, differentiable, use_fused=False)
    if foreach and torch.jit.is_scripting():
        raise RuntimeError('torch.jit.script not supported with foreach optimizers')
    # NOTE(review): `torch.jit.is_scripting()` stays inline in this condition,
    # presumably so TorchScript can skip compiling the foreach branch; do not
    # hoist it into a variable.
    if foreach and not torch.jit.is_scripting():
        func = _multi_tensor_nadam
    else:
        func = _single_tensor_nadam
    func(params,
         grads,
         exp_avgs,
         exp_avg_sqs,
         mu_products,
         state_steps,
         beta1=beta1,
         beta2=beta2,
         lr=lr,
         weight_decay=weight_decay,
         momentum_decay=momentum_decay,
         eps=eps,
         differentiable=differentiable)
  186. def _single_tensor_nadam(params: List[Tensor],
  187. grads: List[Tensor],
  188. exp_avgs: List[Tensor],
  189. exp_avg_sqs: List[Tensor],
  190. mu_products: List[Tensor],
  191. state_steps: List[Tensor],
  192. *,
  193. beta1: float,
  194. beta2: float,
  195. lr: float,
  196. weight_decay: float,
  197. momentum_decay: float,
  198. eps: float,
  199. differentiable: bool):
  200. for i, param in enumerate(params):
  201. grad = grads[i]
  202. exp_avg = exp_avgs[i]
  203. exp_avg_sq = exp_avg_sqs[i]
  204. mu_product = mu_products[i]
  205. step_t = state_steps[i]
  206. # update step
  207. step_t += 1
  208. step = _get_value(step_t)
  209. bias_correction2 = 1 - beta2 ** step
  210. if weight_decay != 0:
  211. grad = grad.add(param, alpha=weight_decay)
  212. # calculate the momentum cache \mu^{t} and \mu^{t+1}
  213. mu = beta1 * (1. - 0.5 * (0.96 ** (step * momentum_decay)))
  214. mu_next = beta1 * (1. - 0.5 * (0.96 ** ((step + 1) * momentum_decay)))
  215. # update mu_product
  216. mu_product *= mu
  217. # decay the first and second moment running average coefficient
  218. exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
  219. exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
  220. denom = exp_avg_sq.div(bias_correction2).sqrt()
  221. if differentiable:
  222. denom = denom.add(eps)
  223. # Make autograd track the operations
  224. # by updating the grad and exp_avg directly and not using the
  225. # scalar "value" argument of addcdiv.
  226. mu_product_next = mu_product * mu_next
  227. grad = grad * (-lr * (1. - mu) / (1. - mu_product))
  228. exp_avg = grad * (-lr * (1. - mu_next) / (1. - mu_product_next))
  229. param.addcdiv_(grad, denom)
  230. param.addcdiv_(exp_avg, denom)
  231. else:
  232. mu_product_next = _get_value(mu_product) * mu_next
  233. denom.add_(eps)
  234. param.addcdiv_(grad, denom, value=(-lr * (1. - mu) / (1. - _get_value(mu_product))))
  235. param.addcdiv_(exp_avg, denom, value=(-lr * mu_next) / (1. - mu_product_next))
  236. def _multi_tensor_nadam(params: List[Tensor],
  237. grads: List[Tensor],
  238. exp_avgs: List[Tensor],
  239. exp_avg_sqs: List[Tensor],
  240. mu_products: List[Tensor],
  241. state_steps: List[Tensor],
  242. *,
  243. beta1: float,
  244. beta2: float,
  245. lr: float,
  246. weight_decay: float,
  247. momentum_decay: float,
  248. eps: float,
  249. differentiable: bool):
  250. if len(params) == 0:
  251. return
  252. assert not differentiable, "_foreach ops don't support autograd"
  253. grouped_tensors = _group_tensors_by_device_and_dtype([params, grads, exp_avgs, exp_avg_sqs,
  254. mu_products, state_steps])
  255. for (grouped_params, grouped_grads, grouped_exp_avgs,
  256. grouped_exp_avg_sqs, grouped_mu_products, grouped_state_steps) in grouped_tensors.values():
  257. # update steps
  258. torch._foreach_add_(grouped_state_steps, 1)
  259. bias_correction2 = [1 - beta2 ** _get_value(step) for step in grouped_state_steps]
  260. mus = [beta1 * (1. - 0.5 * (0.96 ** (_get_value(step) * momentum_decay))) for step in grouped_state_steps]
  261. mu_nexts = [beta1 * (1. - 0.5 * (0.96 ** ((_get_value(step) + 1) * momentum_decay)))
  262. for step in grouped_state_steps]
  263. # update mu_products
  264. torch._foreach_mul_(grouped_mu_products, mus)
  265. if weight_decay != 0:
  266. grouped_grads = torch._foreach_add(grouped_grads, grouped_params, alpha=weight_decay)
  267. # Decay the first and second moment running average coefficient
  268. torch._foreach_mul_(grouped_exp_avgs, beta1)
  269. torch._foreach_add_(grouped_exp_avgs, grouped_grads, alpha=1 - beta1)
  270. torch._foreach_mul_(grouped_exp_avg_sqs, beta2)
  271. torch._foreach_addcmul_(grouped_exp_avg_sqs, grouped_grads, grouped_grads, 1 - beta2)
  272. exp_avg_sq_sqrt = torch._foreach_sqrt(grouped_exp_avg_sqs)
  273. bias_correction_sqrt = [_dispatch_sqrt(bc) for bc in bias_correction2]
  274. torch._foreach_div_(exp_avg_sq_sqrt, bias_correction_sqrt)
  275. denom = torch._foreach_add(exp_avg_sq_sqrt, eps)
  276. step_size_grads = _stack_if_compiling([(lr * (1. - mu) / (1. - _get_value(mu_product))) * -1
  277. for mu_product, mu in zip(grouped_mu_products, mus)])
  278. step_size_expavg = _stack_if_compiling([(lr * mu_next / (1. - _get_value(mu_product) * mu_next)) * -1
  279. for mu_product, mu_next in zip(grouped_mu_products, mu_nexts)])
  280. torch._foreach_addcdiv_(grouped_params, grouped_grads, denom, step_size_grads)
  281. torch._foreach_addcdiv_(grouped_params, grouped_exp_avgs, denom, step_size_expavg)