import math
import torch
from torch import Tensor
from .optimizer import (Optimizer, _use_grad_for_differentiable, _get_value, _dispatch_sqrt, _stack_if_compiling,
                        _default_to_fused_or_foreach, _differentiable_doc, _foreach_doc)
from typing import List, Optional
from torch.utils._foreach_utils import _group_tensors_by_device_and_dtype

__all__ = ["RAdam", "radam"]


class RAdam(Optimizer):
    def __init__(
        self,
        params,
        lr=1e-3,
        betas=(0.9, 0.999),
        eps=1e-8,
        weight_decay=0,
        *,
        foreach: Optional[bool] = None,
        differentiable: bool = False,
    ):
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
        if not 0.0 <= weight_decay:
            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
        defaults = dict(
            lr=lr,
            betas=betas,
            eps=eps,
            weight_decay=weight_decay,
            foreach=foreach,
            differentiable=differentiable,
        )
        super().__init__(params, defaults)

    def __setstate__(self, state):
        super().__setstate__(state)
        for group in self.param_groups:
            group.setdefault("foreach", None)
            group.setdefault("differentiable", False)
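        # Note (added for clarity): state loaded from an older checkpoint may store `step`
        # as a plain Python number; convert it to a singleton tensor so the tensor-based
        # step bookkeeping below keeps working.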
        state_values = list(self.state.values())
        step_is_tensor = (len(state_values) != 0) and torch.is_tensor(
            state_values[0]["step"]
        )
        if not step_is_tensor:
            for s in state_values:
                s["step"] = torch.tensor(float(s["step"]))

    def _init_group(self, group, params_with_grad, grads, exp_avgs, exp_avg_sqs, state_steps):
        for p in group["params"]:
            if p.grad is not None:
                params_with_grad.append(p)
                if p.grad.is_sparse:
                    raise RuntimeError("RAdam does not support sparse gradients")
                grads.append(p.grad)

                state = self.state[p]
                # Lazy state initialization
                if len(state) == 0:
                    state["step"] = torch.tensor(0.0)
                    # Exponential moving average of gradient values
                    state["exp_avg"] = torch.zeros_like(
                        p, memory_format=torch.preserve_format
                    )
                    # Exponential moving average of squared gradient values
                    state["exp_avg_sq"] = torch.zeros_like(
                        p, memory_format=torch.preserve_format
                    )

                exp_avgs.append(state["exp_avg"])
                exp_avg_sqs.append(state["exp_avg_sq"])
                state_steps.append(state["step"])

    @_use_grad_for_differentiable
    def step(self, closure=None):
        """Performs a single optimization step.

        Args:
            closure (Callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            params_with_grad = []
            grads = []
            exp_avgs = []
            exp_avg_sqs = []
            state_steps = []
            beta1, beta2 = group["betas"]

            self._init_group(group, params_with_grad, grads, exp_avgs, exp_avg_sqs, state_steps)

            radam(
                params_with_grad,
                grads,
                exp_avgs,
                exp_avg_sqs,
                state_steps,
                beta1=beta1,
                beta2=beta2,
                lr=group["lr"],
                weight_decay=group["weight_decay"],
                eps=group["eps"],
                foreach=group["foreach"],
                differentiable=group["differentiable"],
            )

        return loss


RAdam.__doc__ = r"""Implements RAdam algorithm.

    .. math::
       \begin{aligned}
            &\rule{110mm}{0.4pt} \\
            &\textbf{input}      : \gamma \text{ (lr)}, \: \beta_1, \beta_2
                \text{ (betas)}, \: \theta_0 \text{ (params)}, \: f(\theta) \text{ (objective)}, \:
                \lambda \text{ (weightdecay)}, \\
            &\hspace{13mm} \epsilon \text{ (epsilon)} \\
            &\textbf{initialize} : m_0 \leftarrow 0 \text{ ( first moment)},
                v_0 \leftarrow 0 \text{ ( second moment)}, \\
            &\hspace{18mm} \rho_{\infty} \leftarrow 2/(1-\beta_2) - 1 \\[-1.ex]
            &\rule{110mm}{0.4pt} \\
            &\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do} \\
            &\hspace{6mm}g_t \leftarrow \nabla_{\theta} f_t (\theta_{t-1}) \\
            &\hspace{5mm}\textbf{if} \: \lambda \neq 0 \\
            &\hspace{10mm} g_t \leftarrow g_t + \lambda \theta_{t-1} \\
            &\hspace{6mm}m_t \leftarrow \beta_1 m_{t-1} + (1 - \beta_1) g_t \\
            &\hspace{6mm}v_t \leftarrow \beta_2 v_{t-1} + (1-\beta_2) g^2_t \\
            &\hspace{6mm}\widehat{m_t} \leftarrow m_t/\big(1-\beta_1^t \big) \\
            &\hspace{6mm}\rho_t \leftarrow \rho_{\infty} -
                2 t \beta^t_2 /\big(1-\beta_2^t \big) \\[0.1ex]
            &\hspace{6mm}\textbf{if} \: \rho_t > 5 \\
            &\hspace{12mm} l_t \leftarrow \frac{\sqrt{ (1-\beta^t_2) }}{ \sqrt{v_t} + \epsilon } \\
            &\hspace{12mm} r_t \leftarrow
      \sqrt{\frac{(\rho_t-4)(\rho_t-2)\rho_{\infty}}{(\rho_{\infty}-4)(\rho_{\infty}-2) \rho_t}} \\
            &\hspace{12mm}\theta_t \leftarrow \theta_{t-1} - \gamma \widehat{m_t} r_t l_t \\
            &\hspace{6mm}\textbf{else} \\
            &\hspace{12mm}\theta_t \leftarrow \theta_{t-1} - \gamma \widehat{m_t} \\
            &\rule{110mm}{0.4pt} \\[-1.ex]
            &\bf{return} \: \theta_t \\[-1.ex]
            &\rule{110mm}{0.4pt} \\[-1.ex]
       \end{aligned}

    For further details regarding the algorithm we refer to
    `On the variance of the adaptive learning rate and beyond`_.

    This implementation uses the same weight_decay behaviour as Adam (where weight_decay is applied
    to the gradient) and not the one from AdamW (where weight_decay is applied to the update). This
    is different from the `author's implementation`_.
    """ + r"""
    Args:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, optional): learning rate (default: 1e-3)
        betas (Tuple[float, float], optional): coefficients used for computing
            running averages of gradient and its square (default: (0.9, 0.999))
        eps (float, optional): term added to the denominator to improve
            numerical stability (default: 1e-8)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
        {foreach}
        {differentiable}

    .. _On the variance of the adaptive learning rate and beyond:
        https://arxiv.org/abs/1908.03265
    .. _author's implementation:
        https://github.com/LiyuanLucasLiu/RAdam

    """.format(foreach=_foreach_doc, differentiable=_differentiable_doc)
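
# Minimal usage sketch (illustrative only, not part of the module): RAdam follows the
# standard torch.optim interface, so a training step looks like it does for any other
# optimizer. The model, loss, and `data_loader` below are assumed placeholders.
#
#     import torch
#     from torch.optim import RAdam
#
#     model = torch.nn.Linear(10, 1)
#     optimizer = RAdam(model.parameters(), lr=1e-3, betas=(0.9, 0.999), weight_decay=0)
#
#     for inputs, targets in data_loader:  # any iterable of (input, target) batches
#         optimizer.zero_grad()
#         loss = torch.nn.functional.mse_loss(model(inputs), targets)
#         loss.backward()
#         optimizer.step()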


def radam(
    params: List[Tensor],
    grads: List[Tensor],
    exp_avgs: List[Tensor],
    exp_avg_sqs: List[Tensor],
    state_steps: List[Tensor],
    # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627
    # setting this as kwarg for now as functional API is compiled by torch/distributed/optim
    foreach: Optional[bool] = None,
    differentiable: bool = False,
    *,
    beta1: float,
    beta2: float,
    lr: float,
    weight_decay: float,
    eps: float,
):
    r"""Functional API that performs RAdam algorithm computation.

    See :class:`~torch.optim.RAdam` for details.
    """
    if not all(isinstance(t, torch.Tensor) for t in state_steps):
        raise RuntimeError(
            "API has changed, `state_steps` argument must contain a list of singleton tensors"
        )

    if foreach is None:
        _, foreach = _default_to_fused_or_foreach(params, differentiable, use_fused=False)

    if foreach and torch.jit.is_scripting():
        raise RuntimeError("torch.jit.script not supported with foreach optimizers")

    if foreach and not torch.jit.is_scripting():
        func = _multi_tensor_radam
    else:
        func = _single_tensor_radam

    func(
        params,
        grads,
        exp_avgs,
        exp_avg_sqs,
        state_steps,
        beta1=beta1,
        beta2=beta2,
        lr=lr,
        weight_decay=weight_decay,
        eps=eps,
        differentiable=differentiable,
    )
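
# Sketch of calling the functional API directly (normally RAdam.step() assembles these lists
# for you). All tensors below are assumed placeholders; note that `state_steps` must be a
# list of singleton tensors, mirroring the check at the top of radam().
#
#     param = torch.zeros(3)
#     grad = torch.ones(3)
#     exp_avg = torch.zeros(3)
#     exp_avg_sq = torch.zeros(3)
#     step = torch.tensor(0.0)
#
#     radam([param], [grad], [exp_avg], [exp_avg_sq], [step],
#           beta1=0.9, beta2=0.999, lr=1e-3, weight_decay=0.0, eps=1e-8)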


def _single_tensor_radam(
    params: List[Tensor],
    grads: List[Tensor],
    exp_avgs: List[Tensor],
    exp_avg_sqs: List[Tensor],
    state_steps: List[Tensor],
    *,
    beta1: float,
    beta2: float,
    lr: float,
    weight_decay: float,
    eps: float,
    differentiable: bool,
):
    for i, param in enumerate(params):
        grad = grads[i]
        exp_avg = exp_avgs[i]
        exp_avg_sq = exp_avg_sqs[i]
        step_t = state_steps[i]

        # update step
        step_t += 1
        step = _get_value(step_t)

        bias_correction1 = 1 - beta1 ** step
        bias_correction2 = 1 - beta2 ** step

        if weight_decay != 0:
            grad = grad.add(param, alpha=weight_decay)

        # Decay the first and second moment running average coefficient
        exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
        exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

        # correcting bias for the first moving moment
        bias_corrected_exp_avg = exp_avg / bias_correction1

        # maximum length of the approximated SMA
        rho_inf = 2 / (1 - beta2) - 1
        # compute the length of the approximated SMA
        rho_t = rho_inf - 2 * step * (beta2 ** step) / bias_correction2

        if rho_t > 5.0:
            # Compute the variance rectification term and update parameters accordingly
            rect = math.sqrt(
                (rho_t - 4)
                * (rho_t - 2)
                * rho_inf
                / ((rho_inf - 4) * (rho_inf - 2) * rho_t)
            )

            exp_avg_sq_sqrt = exp_avg_sq.sqrt()
            if differentiable:
                exp_avg_sq_sqrt = exp_avg_sq_sqrt.add(eps)
            else:
                exp_avg_sq_sqrt = exp_avg_sq_sqrt.add_(eps)

            adaptive_lr = math.sqrt(bias_correction2) / exp_avg_sq_sqrt

            param.add_(bias_corrected_exp_avg * lr * adaptive_lr * rect, alpha=-1.0)
        else:
            param.add_(bias_corrected_exp_avg * lr, alpha=-1.0)
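
# Worked example of the rho_t schedule above (illustrative arithmetic only, not used by the code):
# with the default beta2 = 0.999, rho_inf = 2 / (1 - 0.999) - 1 = 1999, and
# rho_t = rho_inf - 2 * t * beta2**t / (1 - beta2**t) grows roughly like t for small t
# (rho_1 ~= 1.0, rho_5 ~= 5.0, rho_6 ~= 6.0). Since rectification requires rho_t > 5,
# the first few steps take the unrectified branch (a plain bias-corrected momentum update)
# before the variance-rectified adaptive step kicks in.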


def _multi_tensor_radam(
    params: List[Tensor],
    grads: List[Tensor],
    exp_avgs: List[Tensor],
    exp_avg_sqs: List[Tensor],
    state_steps: List[Tensor],
    *,
    beta1: float,
    beta2: float,
    lr: float,
    weight_decay: float,
    eps: float,
    differentiable: bool,
):
    if len(params) == 0:
        return

    assert not differentiable, "_foreach ops don't support autograd"

    grouped_tensors = _group_tensors_by_device_and_dtype([params, grads, exp_avgs, exp_avg_sqs, state_steps])
    for grouped_params, grouped_grads, grouped_exp_avgs, grouped_exp_avg_sqs, grouped_state_steps in grouped_tensors.values():
        # Update steps
        torch._foreach_add_(grouped_state_steps, 1)

        # maximum length of the approximated SMA
        rho_inf = 2 / (1 - beta2) - 1
        # compute the length of the approximated SMA
        rho_t_list = [rho_inf - 2 * _get_value(step) * (beta2 ** _get_value(step)) /
                      (1 - beta2 ** _get_value(step)) for step in grouped_state_steps]

        bias_correction1 = [1 - beta1 ** _get_value(step) for step in grouped_state_steps]
        bias_correction2 = [1 - beta2 ** _get_value(step) for step in grouped_state_steps]

        if weight_decay != 0:
            grouped_grads = torch._foreach_add(grouped_grads, grouped_params, alpha=weight_decay)

        # Decay the first and second moment running average coefficient
        torch._foreach_mul_(grouped_exp_avgs, beta1)
        torch._foreach_add_(grouped_exp_avgs, grouped_grads, alpha=1 - beta1)

        torch._foreach_mul_(grouped_exp_avg_sqs, beta2)
        torch._foreach_addcmul_(grouped_exp_avg_sqs, grouped_grads, grouped_grads, 1 - beta2)

        rect = [
            _dispatch_sqrt(
                (rho_t - 4)
                * (rho_t - 2)
                * rho_inf
                / ((rho_inf - 4) * (rho_inf - 2) * rho_t)
            )
            if rho_t > 5
            else 0
            for rho_t in rho_t_list
        ]
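        # Each parameter receives two fused updates below: a rectified adaptive step scaled by
        # `rect` (zero wherever rho_t <= 5) and a plain bias-corrected momentum step scaled by
        # `unrectified` (zero wherever the rectified step applies), so exactly one of the two
        # contributes for any given parameter.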
        unrectified = [0 if r > 0 else 1.0 for r in rect]

        exp_avg_sq_sqrt = torch._foreach_sqrt(grouped_exp_avg_sqs)
        torch._foreach_add_(exp_avg_sq_sqrt, eps)
        bias_correction_sqrt = [_dispatch_sqrt(bc) for bc in bias_correction2]
        denom = torch._foreach_div(exp_avg_sq_sqrt, bias_correction_sqrt)
        step_size = _stack_if_compiling([(lr * r / bc) * -1 for r, bc in zip(rect, bias_correction1)])
        torch._foreach_addcdiv_(grouped_params, grouped_exp_avgs, denom, step_size)

        denom = [torch.ones_like(exp_av, memory_format=torch.preserve_format) for exp_av in grouped_exp_avgs]
        step_size = _stack_if_compiling([(lr * r / bc) * -1 for r, bc in zip(unrectified, bias_correction1)])
        torch._foreach_addcdiv_(grouped_params, grouped_exp_avgs, denom, step_size)