# adamw.py
import torch
from torch import Tensor
from .optimizer import (Optimizer, _use_grad_for_differentiable, _get_value, _dispatch_sqrt,
                        _stack_if_compiling, _capturable_doc, _differentiable_doc, _foreach_doc,
                        _fused_doc, _maximize_doc, _default_to_fused_or_foreach)
from typing import List, Optional
from torch.utils._foreach_utils import _group_tensors_by_device_and_dtype

__all__ = ["AdamW", "adamw"]


class AdamW(Optimizer):
    def __init__(
        self,
        params,
        lr=1e-3,
        betas=(0.9, 0.999),
        eps=1e-8,
        weight_decay=1e-2,
        amsgrad=False,
        *,
        maximize: bool = False,
        foreach: Optional[bool] = None,
        capturable: bool = False,
        differentiable: bool = False,
        fused: Optional[bool] = None,
    ):
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
        if not 0.0 <= weight_decay:
            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
        defaults = dict(
            lr=lr,
            betas=betas,
            eps=eps,
            weight_decay=weight_decay,
            amsgrad=amsgrad,
            foreach=foreach,
            maximize=maximize,
            capturable=capturable,
            differentiable=differentiable,
            fused=fused,
        )
        super().__init__(params, defaults)

        if fused:
            if differentiable:
                raise RuntimeError("`fused` does not support `differentiable`")
            self._step_supports_amp_scaling = True
            # TODO(crcrpar): [low prec params & their higher prec copy]
            # Support AMP with FP16/BF16 model params which would need
            # higher prec copy of params to do update math in higher prec to
            # alleviate the loss of information.
            if not all(
                p.is_cuda and torch.is_floating_point(p)
                for pg in self.param_groups for p in pg['params']
            ):
                raise RuntimeError("`fused=True` requires all the params to be CUDA, floating point Tensor")
            if foreach:
                raise RuntimeError("`fused` and `foreach` cannot be `True` together.")

    def __setstate__(self, state):
        super().__setstate__(state)
        for group in self.param_groups:
            group.setdefault("amsgrad", False)
            group.setdefault("maximize", False)
            group.setdefault("foreach", None)
            group.setdefault("capturable", False)
            group.setdefault("differentiable", False)
            group.setdefault("fused", None)
        state_values = list(self.state.values())
        step_is_tensor = (len(state_values) != 0) and torch.is_tensor(
            state_values[0]["step"]
        )
        if not step_is_tensor:
            for s in state_values:
                s["step"] = torch.tensor(float(s["step"]))

    def _init_group(
        self,
        group,
        params_with_grad,
        grads,
        amsgrad,
        exp_avgs,
        exp_avg_sqs,
        max_exp_avg_sqs,
        state_steps,
    ):
        for p in group["params"]:
            if p.grad is None:
                continue
            params_with_grad.append(p)
            if p.grad.is_sparse:
                raise RuntimeError("AdamW does not support sparse gradients")
            grads.append(p.grad)
            state = self.state[p]

            # State initialization
            if len(state) == 0:
                state["step"] = (
                    torch.zeros((1,), dtype=torch.float, device=p.device)
                    if group["capturable"] or group["fused"]
                    else torch.tensor(0.0)
                )
                # Exponential moving average of gradient values
                state["exp_avg"] = torch.zeros_like(
                    p, memory_format=torch.preserve_format
                )
                # Exponential moving average of squared gradient values
                state["exp_avg_sq"] = torch.zeros_like(
                    p, memory_format=torch.preserve_format
                )
                if amsgrad:
                    # Maintains max of all exp. moving avg. of sq. grad. values
                    state["max_exp_avg_sq"] = torch.zeros_like(
                        p, memory_format=torch.preserve_format
                    )

            exp_avgs.append(state["exp_avg"])
            exp_avg_sqs.append(state["exp_avg_sq"])

            if amsgrad:
                max_exp_avg_sqs.append(state["max_exp_avg_sq"])

            state_steps.append(state["step"])
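
    # After initialization, the per-parameter state built above looks like this
    # (sketch, added for illustration only):
    #     self.state[p] == {
    #         "step": Tensor,            # (1,) tensor on p.device when capturable/fused, else 0-dim CPU tensor
    #         "exp_avg": Tensor,         # first moment m_t, same shape as p
    #         "exp_avg_sq": Tensor,      # second moment v_t, same shape as p
    #         "max_exp_avg_sq": Tensor,  # only present when amsgrad=True
    #     }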

    @_use_grad_for_differentiable
    def step(self, closure=None):
        """Performs a single optimization step.

        Args:
            closure (Callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        self._cuda_graph_capture_health_check()

        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            params_with_grad = []
            grads = []
            exp_avgs = []
            exp_avg_sqs = []
            max_exp_avg_sqs = []
            state_steps = []
            amsgrad = group["amsgrad"]
            beta1, beta2 = group["betas"]

            self._init_group(
                group,
                params_with_grad,
                grads,
                amsgrad,
                exp_avgs,
                exp_avg_sqs,
                max_exp_avg_sqs,
                state_steps,
            )

            adamw(
                params_with_grad,
                grads,
                exp_avgs,
                exp_avg_sqs,
                max_exp_avg_sqs,
                state_steps,
                amsgrad=amsgrad,
                beta1=beta1,
                beta2=beta2,
                lr=group["lr"],
                weight_decay=group["weight_decay"],
                eps=group["eps"],
                maximize=group["maximize"],
                foreach=group["foreach"],
                capturable=group["capturable"],
                differentiable=group["differentiable"],
                fused=group["fused"],
                grad_scale=getattr(self, "grad_scale", None),
                found_inf=getattr(self, "found_inf", None),
            )

        return loss


AdamW.__doc__ = r"""Implements AdamW algorithm.

    .. math::
       \begin{aligned}
            &\rule{110mm}{0.4pt} \\
            &\textbf{input} : \gamma \text{(lr)}, \: \beta_1, \beta_2
                \text{(betas)}, \: \theta_0 \text{(params)}, \: f(\theta) \text{(objective)},
                \: \epsilon \text{ (epsilon)} \\
            &\hspace{13mm} \lambda \text{(weight decay)}, \: \textit{amsgrad},
                \: \textit{maximize} \\
            &\textbf{initialize} : m_0 \leftarrow 0 \text{ (first moment)}, v_0 \leftarrow 0
                \text{ (second moment)}, \: \widehat{v_0}^{max}\leftarrow 0 \\[-1.ex]
            &\rule{110mm}{0.4pt} \\
            &\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do} \\
            &\hspace{5mm}\textbf{if} \: \textit{maximize}: \\
            &\hspace{10mm}g_t \leftarrow -\nabla_{\theta} f_t (\theta_{t-1}) \\
            &\hspace{5mm}\textbf{else} \\
            &\hspace{10mm}g_t \leftarrow \nabla_{\theta} f_t (\theta_{t-1}) \\
            &\hspace{5mm} \theta_t \leftarrow \theta_{t-1} - \gamma \lambda \theta_{t-1} \\
            &\hspace{5mm}m_t \leftarrow \beta_1 m_{t-1} + (1 - \beta_1) g_t \\
            &\hspace{5mm}v_t \leftarrow \beta_2 v_{t-1} + (1-\beta_2) g^2_t \\
            &\hspace{5mm}\widehat{m_t} \leftarrow m_t/\big(1-\beta_1^t \big) \\
            &\hspace{5mm}\widehat{v_t} \leftarrow v_t/\big(1-\beta_2^t \big) \\
            &\hspace{5mm}\textbf{if} \: amsgrad \\
            &\hspace{10mm}\widehat{v_t}^{max} \leftarrow \mathrm{max}(\widehat{v_t}^{max},
                \widehat{v_t}) \\
            &\hspace{10mm}\theta_t \leftarrow \theta_t - \gamma \widehat{m_t}/
                \big(\sqrt{\widehat{v_t}^{max}} + \epsilon \big) \\
            &\hspace{5mm}\textbf{else} \\
            &\hspace{10mm}\theta_t \leftarrow \theta_t - \gamma \widehat{m_t}/
                \big(\sqrt{\widehat{v_t}} + \epsilon \big) \\
            &\rule{110mm}{0.4pt} \\[-1.ex]
            &\bf{return} \: \theta_t \\[-1.ex]
            &\rule{110mm}{0.4pt} \\[-1.ex]
       \end{aligned}

    For further details regarding the algorithm we refer to `Decoupled Weight Decay Regularization`_.
    """ + r"""
    Args:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, optional): learning rate (default: 1e-3)
        betas (Tuple[float, float], optional): coefficients used for computing
            running averages of gradient and its square (default: (0.9, 0.999))
        eps (float, optional): term added to the denominator to improve
            numerical stability (default: 1e-8)
        weight_decay (float, optional): weight decay coefficient (default: 1e-2)
        amsgrad (bool, optional): whether to use the AMSGrad variant of this
            algorithm from the paper `On the Convergence of Adam and Beyond`_
            (default: False)
        {maximize}
        {foreach}
        {capturable}
        {differentiable}
        {fused}
    .. _Decoupled Weight Decay Regularization:
        https://arxiv.org/abs/1711.05101
    .. _On the Convergence of Adam and Beyond:
        https://openreview.net/forum?id=ryQu7f-RZ

    """.format(maximize=_maximize_doc,
               foreach=_foreach_doc,
               fused=_fused_doc,
               capturable=_capturable_doc,
               differentiable=_differentiable_doc)
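
# Illustrative usage (a minimal sketch, not part of this module): AdamW decouples the
# weight decay from the gradient-based update, multiplying the parameters by
# ``1 - lr * weight_decay`` before the Adam step instead of adding the decay to the
# gradients.
#
#     >>> model = torch.nn.Linear(10, 2)
#     >>> optimizer = AdamW(model.parameters(), lr=1e-3, weight_decay=1e-2)
#     >>> loss = model(torch.randn(8, 10)).sum()
#     >>> loss.backward()
#     >>> optimizer.step()
#     >>> optimizer.zero_grad()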


def adamw(
    params: List[Tensor],
    grads: List[Tensor],
    exp_avgs: List[Tensor],
    exp_avg_sqs: List[Tensor],
    max_exp_avg_sqs: List[Tensor],
    state_steps: List[Tensor],
    # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627
    # setting this as kwarg for now as functional API is compiled by torch/distributed/optim
    foreach: Optional[bool] = None,
    capturable: bool = False,
    differentiable: bool = False,
    fused: Optional[bool] = None,
    grad_scale: Optional[Tensor] = None,
    found_inf: Optional[Tensor] = None,
    *,
    amsgrad: bool,
    beta1: float,
    beta2: float,
    lr: float,
    weight_decay: float,
    eps: float,
    maximize: bool,
):
    r"""Functional API that performs AdamW algorithm computation.

    See :class:`~torch.optim.AdamW` for details.
    """
    if not all(isinstance(t, torch.Tensor) for t in state_steps):
        raise RuntimeError(
            "API has changed, `state_steps` argument must contain a list of singleton tensors"
        )

    # Respect when the user inputs False/True for foreach or fused. We only want to change
    # the default when neither have been user-specified. Note that we default to foreach
    # and pass False to use_fused. This is not a mistake--we want to give the fused impl
    # bake-in time before making it the default, even if it is typically faster.
    if fused is None and foreach is None:
        _, foreach = _default_to_fused_or_foreach(params, differentiable, use_fused=False)
    if fused is None:
        fused = False
    if foreach is None:
        foreach = False

    if foreach and torch.jit.is_scripting():
        raise RuntimeError("torch.jit.script not supported with foreach optimizers")
    if fused and torch.jit.is_scripting():
        raise RuntimeError("torch.jit.script not supported with fused optimizers")

    if fused and not torch.jit.is_scripting():
        func = _fused_adamw
    elif foreach and not torch.jit.is_scripting():
        func = _multi_tensor_adamw
    else:
        func = _single_tensor_adamw

    func(
        params,
        grads,
        exp_avgs,
        exp_avg_sqs,
        max_exp_avg_sqs,
        state_steps,
        amsgrad=amsgrad,
        beta1=beta1,
        beta2=beta2,
        lr=lr,
        weight_decay=weight_decay,
        eps=eps,
        maximize=maximize,
        capturable=capturable,
        differentiable=differentiable,
        grad_scale=grad_scale,
        found_inf=found_inf,
    )
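
# Sketch of calling the functional entry point directly (illustrative only; the tensors
# below are made up for the example). ``state_steps`` must hold singleton tensors, and
# the hyperparameters are keyword-only:
#
#     >>> p, g = torch.randn(3), torch.randn(3)
#     >>> exp_avg, exp_avg_sq = torch.zeros(3), torch.zeros(3)
#     >>> step = torch.tensor(0.0)
#     >>> adamw([p], [g], [exp_avg], [exp_avg_sq], [], [step],
#     ...       amsgrad=False, beta1=0.9, beta2=0.999, lr=1e-3,
#     ...       weight_decay=1e-2, eps=1e-8, maximize=False)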


def _single_tensor_adamw(
    params: List[Tensor],
    grads: List[Tensor],
    exp_avgs: List[Tensor],
    exp_avg_sqs: List[Tensor],
    max_exp_avg_sqs: List[Tensor],
    state_steps: List[Tensor],
    grad_scale: Optional[Tensor],
    found_inf: Optional[Tensor],
    *,
    amsgrad: bool,
    beta1: float,
    beta2: float,
    lr: float,
    weight_decay: float,
    eps: float,
    maximize: bool,
    capturable: bool,
    differentiable: bool,
):
    assert grad_scale is None and found_inf is None

    for i, param in enumerate(params):
        grad = grads[i] if not maximize else -grads[i]
        exp_avg = exp_avgs[i]
        exp_avg_sq = exp_avg_sqs[i]
        step_t = state_steps[i]

        if capturable:
            assert (
                param.is_cuda and step_t.is_cuda
            ), "If capturable=True, params and state_steps must be CUDA tensors."

        if torch.is_complex(param):
            grad = torch.view_as_real(grad)
            exp_avg = torch.view_as_real(exp_avg)
            exp_avg_sq = torch.view_as_real(exp_avg_sq)
            param = torch.view_as_real(param)

        # update step
        step_t += 1

        # Perform stepweight decay
        param.mul_(1 - lr * weight_decay)

        # Decay the first and second moment running average coefficient
        exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
        exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

        if capturable or differentiable:
            step = step_t

            # 1 - beta1 ** step can't be captured in a CUDA graph, even if step is a CUDA tensor
            # (incurs "RuntimeError: CUDA error: operation not permitted when stream is capturing")
            bias_correction1 = 1 - torch.pow(beta1, step)
            bias_correction2 = 1 - torch.pow(beta2, step)

            step_size = lr / bias_correction1
            step_size_neg = step_size.neg()

            bias_correction2_sqrt = bias_correction2.sqrt()

            if amsgrad:
                # Maintains the maximum of all 2nd moment running avg. till now
                if differentiable:
                    max_exp_avg_sqs_i = max_exp_avg_sqs[i].clone()
                else:
                    max_exp_avg_sqs_i = max_exp_avg_sqs[i]
                max_exp_avg_sqs[i].copy_(torch.maximum(max_exp_avg_sqs_i, exp_avg_sq))

                # Uses the max. for normalizing running avg. of gradient
                # Folds in (admittedly ugly) 1-elem step_size math here to avoid extra param-set-sized read+write
                # (can't fold it into addcdiv_ below because addcdiv_ requires value is a Number, not a Tensor)
                denom = (
                    max_exp_avg_sqs[i].sqrt() / (bias_correction2_sqrt * step_size_neg)
                ).add_(eps / step_size_neg)
            else:
                denom = (
                    exp_avg_sq.sqrt() / (bias_correction2_sqrt * step_size_neg)
                ).add_(eps / step_size_neg)

            param.addcdiv_(exp_avg, denom)
        else:
            step = _get_value(step_t)

            bias_correction1 = 1 - beta1 ** step
            bias_correction2 = 1 - beta2 ** step

            step_size = lr / bias_correction1

            bias_correction2_sqrt = _dispatch_sqrt(bias_correction2)

            if amsgrad:
                # Maintains the maximum of all 2nd moment running avg. till now
                torch.maximum(max_exp_avg_sqs[i], exp_avg_sq, out=max_exp_avg_sqs[i])

                # Use the max. for normalizing running avg. of gradient
                denom = (max_exp_avg_sqs[i].sqrt() / bias_correction2_sqrt).add_(eps)
            else:
                denom = (exp_avg_sq.sqrt() / bias_correction2_sqrt).add_(eps)

            param.addcdiv_(exp_avg, denom, value=-step_size)
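
# Note on the capturable/differentiable branches in `_single_tensor_adamw` above and
# `_multi_tensor_adamw` below (explanatory comment, added for clarity): the conventional
# update is
#     param -= step_size * exp_avg / (exp_avg_sq.sqrt() / bias_correction2_sqrt + eps)
# but there `step_size` is a 1-element tensor and `addcdiv_` only accepts a Python number
# for `value`, so the step size is folded into the denominator instead:
#     param += exp_avg / (exp_avg_sq.sqrt() / (bias_correction2_sqrt * step_size_neg)
#                         + eps / step_size_neg)
# which is the same update, with `denom` pre-divided by `step_size_neg = -step_size`.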


def _multi_tensor_adamw(
    params: List[Tensor],
    grads: List[Tensor],
    exp_avgs: List[Tensor],
    exp_avg_sqs: List[Tensor],
    max_exp_avg_sqs: List[Tensor],
    state_steps: List[Tensor],
    grad_scale: Optional[Tensor],
    found_inf: Optional[Tensor],
    *,
    amsgrad: bool,
    beta1: float,
    beta2: float,
    lr: float,
    weight_decay: float,
    eps: float,
    maximize: bool,
    capturable: bool,
    differentiable: bool,
):
    if len(params) == 0:
        return

    if capturable:
        assert all(
            p.is_cuda and step.is_cuda for p, step in zip(params, state_steps)
        ), "If capturable=True, params and state_steps must be CUDA tensors."

    assert not differentiable, "_foreach ops don't support autograd"

    assert grad_scale is None and found_inf is None

    grouped_tensors = _group_tensors_by_device_and_dtype([
        params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps])
    for (device_params, device_grads, device_exp_avgs, device_exp_avg_sqs,
         device_max_exp_avg_sqs, device_state_steps) in grouped_tensors.values():
        if maximize:
            device_grads = torch._foreach_neg(tuple(device_grads))  # type: ignore[assignment]

        device_grads = [torch.view_as_real(x) if torch.is_complex(x) else x for x in device_grads]
        device_exp_avgs = [torch.view_as_real(x) if torch.is_complex(x) else x for x in device_exp_avgs]
        device_exp_avg_sqs = [
            torch.view_as_real(x) if torch.is_complex(x) else x for x in device_exp_avg_sqs
        ]
        device_params = [torch.view_as_real(x) if torch.is_complex(x) else x for x in device_params]

        # update steps
        torch._foreach_add_(device_state_steps, 1)

        # Perform stepweight decay
        torch._foreach_mul_(device_params, 1 - lr * weight_decay)

        # Decay the first and second moment running average coefficient
        torch._foreach_mul_(device_exp_avgs, beta1)
        torch._foreach_add_(device_exp_avgs, device_grads, alpha=1 - beta1)

        torch._foreach_mul_(device_exp_avg_sqs, beta2)
        torch._foreach_addcmul_(device_exp_avg_sqs, device_grads, device_grads, 1 - beta2)

        if capturable:
            # TODO: use foreach_pow if/when foreach_pow is added
            bias_correction1 = [torch.pow(beta1, step) for step in device_state_steps]
            bias_correction2 = [torch.pow(beta2, step) for step in device_state_steps]
            # foreach_sub doesn't allow a scalar as the first arg
            torch._foreach_sub_(bias_correction1, 1)
            torch._foreach_sub_(bias_correction2, 1)
            torch._foreach_neg_(bias_correction1)
            torch._foreach_neg_(bias_correction2)

            # foreach_div doesn't allow a scalar as the first arg
            step_size = torch._foreach_div(bias_correction1, lr)
            torch._foreach_reciprocal_(step_size)
            torch._foreach_neg_(step_size)

            bias_correction2_sqrt = torch._foreach_sqrt(bias_correction2)

            if amsgrad:
                # Maintains the maximum of all 2nd moment running avg. till now
                torch._foreach_maximum_(device_max_exp_avg_sqs, device_exp_avg_sqs)

                # Use the max. for normalizing running avg. of gradient
                max_exp_avg_sq_sqrt = torch._foreach_sqrt(device_max_exp_avg_sqs)

                # Folds in (admittedly ugly) 1-elem step_size math here to avoid extra param-set-sized read+write
                # (can't fold it into addcdiv_ below because addcdiv_ requires value is a Number, not a Tensor)
                torch._foreach_div_(
                    max_exp_avg_sq_sqrt,
                    torch._foreach_mul(bias_correction2_sqrt, step_size),
                )
                eps_over_step_size = torch._foreach_div(step_size, eps)
                torch._foreach_reciprocal_(eps_over_step_size)
                denom = torch._foreach_add(max_exp_avg_sq_sqrt, eps_over_step_size)
            else:
                exp_avg_sq_sqrt = torch._foreach_sqrt(device_exp_avg_sqs)
                torch._foreach_div_(
                    exp_avg_sq_sqrt, torch._foreach_mul(bias_correction2_sqrt, step_size)
                )
                eps_over_step_size = torch._foreach_div(step_size, eps)
                torch._foreach_reciprocal_(eps_over_step_size)
                denom = torch._foreach_add(exp_avg_sq_sqrt, eps_over_step_size)

            torch._foreach_addcdiv_(device_params, device_exp_avgs, denom)
        else:
            bias_correction1 = [1 - beta1 ** _get_value(step) for step in device_state_steps]
            bias_correction2 = [1 - beta2 ** _get_value(step) for step in device_state_steps]

            step_size = _stack_if_compiling([(lr / bc) * -1 for bc in bias_correction1])

            bias_correction2_sqrt = [_dispatch_sqrt(bc) for bc in bias_correction2]

            if amsgrad:
                # Maintains the maximum of all 2nd moment running avg. till now
                torch._foreach_maximum_(device_max_exp_avg_sqs, device_exp_avg_sqs)

                # Use the max. for normalizing running avg. of gradient
                max_exp_avg_sq_sqrt = torch._foreach_sqrt(device_max_exp_avg_sqs)
                torch._foreach_div_(max_exp_avg_sq_sqrt, bias_correction2_sqrt)
                denom = torch._foreach_add(max_exp_avg_sq_sqrt, eps)
            else:
                exp_avg_sq_sqrt = torch._foreach_sqrt(device_exp_avg_sqs)
                torch._foreach_div_(exp_avg_sq_sqrt, bias_correction2_sqrt)
                denom = torch._foreach_add(exp_avg_sq_sqrt, eps)

            torch._foreach_addcdiv_(device_params, device_exp_avgs, denom, step_size)


def _fused_adamw(
    params: List[Tensor],
    grads: List[Tensor],
    exp_avgs: List[Tensor],
    exp_avg_sqs: List[Tensor],
    max_exp_avg_sqs: List[Tensor],
    state_steps: List[Tensor],
    grad_scale: Optional[Tensor],
    found_inf: Optional[Tensor],
    *,
    amsgrad: bool,
    beta1: float,
    beta2: float,
    lr: float,
    weight_decay: float,
    eps: float,
    maximize: bool,
    capturable: bool,  # Needed for consistency.
    differentiable: bool,
) -> None:
    if differentiable:
        raise RuntimeError("_fused_adamw is not differentiable")
    grad_scale_dict = {grad_scale.device: grad_scale} if grad_scale is not None else None
    found_inf_dict = {found_inf.device: found_inf} if found_inf is not None else None
    grouped_tensors = _group_tensors_by_device_and_dtype([params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps])
    for (device, dtype) in grouped_tensors:
        (
            device_params,
            device_grads,
            device_exp_avgs,
            device_exp_avg_sqs,
            device_max_exp_avg_sqs,
            device_state_steps,
        ) = grouped_tensors[(device, dtype)]
        if grad_scale is not None and found_inf is not None:
            if device not in grad_scale_dict:
                grad_scale_dict[device] = grad_scale.to(device, non_blocking=True)
            if device not in found_inf_dict:
                found_inf_dict[device] = found_inf.to(device, non_blocking=True)
            device_grad_scale = grad_scale_dict[device]
            device_found_inf = found_inf_dict[device]
        else:
            device_grad_scale = None
            device_found_inf = None
        torch._foreach_add_(device_state_steps, 1)
        torch._fused_adamw_(
            device_params,
            device_grads,
            device_exp_avgs,
            device_exp_avg_sqs,
            device_max_exp_avg_sqs,
            device_state_steps,
            amsgrad=amsgrad,
            lr=lr,
            beta1=beta1,
            beta2=beta2,
            weight_decay=weight_decay,
            eps=eps,
            maximize=maximize,
            grad_scale=device_grad_scale,
            found_inf=device_found_inf,
        )
        if device_found_inf is not None:
            torch._foreach_sub_(device_state_steps, [device_found_inf] * len(device_state_steps))
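
# When driven by a gradient scaler (a sketch of the assumed AMP integration: AdamW sets
# `_step_supports_amp_scaling` when fused=True, and `step` forwards `grad_scale` /
# `found_inf` when the scaler provides them), steps with non-finite gradients are skipped
# by the fused kernel and the `_foreach_sub_` above rolls the step counters back so that
# skipped steps do not advance the bias correction.
#
#     >>> scaler = torch.cuda.amp.GradScaler()
#     >>> optimizer = AdamW(model.parameters(), lr=1e-3, fused=True)  # model assumed to be on CUDA
#     >>> scaler.scale(loss).backward()
#     >>> scaler.step(optimizer)
#     >>> scaler.update()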