adam.py 25 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574
  1. from typing import List, Optional
  2. import torch
  3. from torch import Tensor
  4. from .optimizer import (Optimizer, _use_grad_for_differentiable, _get_value, _stack_if_compiling,
  5. _dispatch_sqrt, _default_to_fused_or_foreach, _capturable_doc,
  6. _differentiable_doc, _foreach_doc, _fused_doc, _maximize_doc)
  7. from torch.utils._foreach_utils import _group_tensors_by_device_and_dtype
  8. __all__ = ['Adam', 'adam']
  9. class Adam(Optimizer):
  10. def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
  11. weight_decay=0, amsgrad=False, *, foreach: Optional[bool] = None,
  12. maximize: bool = False, capturable: bool = False,
  13. differentiable: bool = False, fused: Optional[bool] = None):
  14. if not 0.0 <= lr:
  15. raise ValueError("Invalid learning rate: {}".format(lr))
  16. if not 0.0 <= eps:
  17. raise ValueError("Invalid epsilon value: {}".format(eps))
  18. if not 0.0 <= betas[0] < 1.0:
  19. raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
  20. if not 0.0 <= betas[1] < 1.0:
  21. raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
  22. if not 0.0 <= weight_decay:
  23. raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
  24. defaults = dict(lr=lr, betas=betas, eps=eps,
  25. weight_decay=weight_decay, amsgrad=amsgrad,
  26. maximize=maximize, foreach=foreach, capturable=capturable,
  27. differentiable=differentiable, fused=fused)
  28. super().__init__(params, defaults)
  29. if fused:
  30. if differentiable:
  31. raise RuntimeError("`fused` does not support `differentiable`")
  32. self._step_supports_amp_scaling = True
  33. # TODO(crcrpar): [low prec params & their higher prec copy]
  34. # Suppor AMP with FP16/BF16 model params which would need
  35. # higher prec copy of params to do update math in higher prec to
  36. # alleviate the loss of information.
  37. if not all(
  38. p.is_cuda and torch.is_floating_point(p)
  39. for pg in self.param_groups for p in pg['params']
  40. ):
  41. raise RuntimeError("`fused=True` requires all the params to be CUDA, floating point Tensor")
  42. if foreach:
  43. raise RuntimeError("`fused` and `foreach` cannot be `True` together.")
  44. def __setstate__(self, state):
  45. super().__setstate__(state)
  46. for group in self.param_groups:
  47. group.setdefault('amsgrad', False)
  48. group.setdefault('maximize', False)
  49. group.setdefault('foreach', None)
  50. group.setdefault('capturable', False)
  51. group.setdefault('differentiable', False)
  52. group.setdefault('fused', None)
  53. state_values = list(self.state.values())
  54. step_is_tensor = (len(state_values) != 0) and torch.is_tensor(state_values[0]['step'])
  55. if not step_is_tensor:
  56. for s in state_values:
  57. s['step'] = torch.tensor(float(s['step']))
  58. def _init_group(
  59. self,
  60. group,
  61. params_with_grad,
  62. grads,
  63. exp_avgs,
  64. exp_avg_sqs,
  65. max_exp_avg_sqs,
  66. state_steps
  67. ):
  68. for p in group['params']:
  69. if p.grad is not None:
  70. params_with_grad.append(p)
  71. if p.grad.is_sparse:
  72. raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')
  73. grads.append(p.grad)
  74. state = self.state[p]
  75. # Lazy state initialization
  76. if len(state) == 0:
  77. state['step'] = (
  78. torch.zeros((1,), dtype=torch.float, device=p.device)
  79. if group['capturable'] or group['fused']
  80. else torch.tensor(0.)
  81. )
  82. # Exponential moving average of gradient values
  83. state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)
  84. # Exponential moving average of squared gradient values
  85. state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
  86. if group['amsgrad']:
  87. # Maintains max of all exp. moving avg. of sq. grad. values
  88. state['max_exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
  89. exp_avgs.append(state['exp_avg'])
  90. exp_avg_sqs.append(state['exp_avg_sq'])
  91. if group['amsgrad']:
  92. max_exp_avg_sqs.append(state['max_exp_avg_sq'])
  93. if group['differentiable'] and state['step'].requires_grad:
  94. raise RuntimeError('`requires_grad` is not supported for `step` in differentiable mode')
  95. state_steps.append(state['step'])
  96. @_use_grad_for_differentiable
  97. def step(self, closure=None):
  98. """Performs a single optimization step.
  99. Args:
  100. closure (Callable, optional): A closure that reevaluates the model
  101. and returns the loss.
  102. """
  103. self._cuda_graph_capture_health_check()
  104. loss = None
  105. if closure is not None:
  106. with torch.enable_grad():
  107. loss = closure()
  108. for group in self.param_groups:
  109. params_with_grad = []
  110. grads = []
  111. exp_avgs = []
  112. exp_avg_sqs = []
  113. max_exp_avg_sqs = []
  114. state_steps = []
  115. beta1, beta2 = group['betas']
  116. self._init_group(
  117. group,
  118. params_with_grad,
  119. grads,
  120. exp_avgs,
  121. exp_avg_sqs,
  122. max_exp_avg_sqs,
  123. state_steps)
  124. adam(
  125. params_with_grad,
  126. grads,
  127. exp_avgs,
  128. exp_avg_sqs,
  129. max_exp_avg_sqs,
  130. state_steps,
  131. amsgrad=group['amsgrad'],
  132. beta1=beta1,
  133. beta2=beta2,
  134. lr=group['lr'],
  135. weight_decay=group['weight_decay'],
  136. eps=group['eps'],
  137. maximize=group['maximize'],
  138. foreach=group['foreach'],
  139. capturable=group['capturable'],
  140. differentiable=group['differentiable'],
  141. fused=group['fused'],
  142. grad_scale=getattr(self, "grad_scale", None),
  143. found_inf=getattr(self, "found_inf", None),
  144. )
  145. return loss
# The class docstring is assembled from two raw strings joined with ``+``.
# ``.format`` binds tighter than ``+``, so it is applied only to the second
# string (the ``Args:`` section with the ``{foreach}`` ... ``{fused}``
# placeholders). The first string holds the LaTeX pseudo-code, whose literal
# braces would otherwise be parsed as format fields.
# NOTE(review): this listing shows the string without blank lines or
# indentation; reST directives such as ``.. math::`` normally need both --
# confirm against the canonical file before relying on the rendered output.
Adam.__doc__ = r"""Implements Adam algorithm.
.. math::
\begin{aligned}
&\rule{110mm}{0.4pt} \\
&\textbf{input} : \gamma \text{ (lr)}, \beta_1, \beta_2
\text{ (betas)},\theta_0 \text{ (params)},f(\theta) \text{ (objective)} \\
&\hspace{13mm} \lambda \text{ (weight decay)}, \: \textit{amsgrad},
\:\textit{maximize} \\
&\textbf{initialize} : m_0 \leftarrow 0 \text{ ( first moment)},
v_0\leftarrow 0 \text{ (second moment)},\: \widehat{v_0}^{max}\leftarrow 0\\[-1.ex]
&\rule{110mm}{0.4pt} \\
&\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do} \\
&\hspace{5mm}\textbf{if} \: \textit{maximize}: \\
&\hspace{10mm}g_t \leftarrow -\nabla_{\theta} f_t (\theta_{t-1}) \\
&\hspace{5mm}\textbf{else} \\
&\hspace{10mm}g_t \leftarrow \nabla_{\theta} f_t (\theta_{t-1}) \\
&\hspace{5mm}\textbf{if} \: \lambda \neq 0 \\
&\hspace{10mm} g_t \leftarrow g_t + \lambda \theta_{t-1} \\
&\hspace{5mm}m_t \leftarrow \beta_1 m_{t-1} + (1 - \beta_1) g_t \\
&\hspace{5mm}v_t \leftarrow \beta_2 v_{t-1} + (1-\beta_2) g^2_t \\
&\hspace{5mm}\widehat{m_t} \leftarrow m_t/\big(1-\beta_1^t \big) \\
&\hspace{5mm}\widehat{v_t} \leftarrow v_t/\big(1-\beta_2^t \big) \\
&\hspace{5mm}\textbf{if} \: amsgrad \\
&\hspace{10mm}\widehat{v_t}^{max} \leftarrow \mathrm{max}(\widehat{v_t}^{max},
\widehat{v_t}) \\
&\hspace{10mm}\theta_t \leftarrow \theta_{t-1} - \gamma \widehat{m_t}/
\big(\sqrt{\widehat{v_t}^{max}} + \epsilon \big) \\
&\hspace{5mm}\textbf{else} \\
&\hspace{10mm}\theta_t \leftarrow \theta_{t-1} - \gamma \widehat{m_t}/
\big(\sqrt{\widehat{v_t}} + \epsilon \big) \\
&\rule{110mm}{0.4pt} \\[-1.ex]
&\bf{return} \: \theta_t \\[-1.ex]
&\rule{110mm}{0.4pt} \\[-1.ex]
\end{aligned}
For further details regarding the algorithm we refer to `Adam: A Method for Stochastic Optimization`_.
""" + r"""
Args:
params (iterable): iterable of parameters to optimize or dicts defining
parameter groups
lr (float, optional): learning rate (default: 1e-3)
betas (Tuple[float, float], optional): coefficients used for computing
running averages of gradient and its square (default: (0.9, 0.999))
eps (float, optional): term added to the denominator to improve
numerical stability (default: 1e-8)
weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
amsgrad (bool, optional): whether to use the AMSGrad variant of this
algorithm from the paper `On the Convergence of Adam and Beyond`_
(default: False)
{foreach}
{maximize}
{capturable}
{differentiable}
{fused}
.. _Adam\: A Method for Stochastic Optimization:
https://arxiv.org/abs/1412.6980
.. _On the Convergence of Adam and Beyond:
https://openreview.net/forum?id=ryQu7f-RZ
""".format(foreach=_foreach_doc, maximize=_maximize_doc, capturable=_capturable_doc,
           differentiable=_differentiable_doc, fused=_fused_doc)
  205. def adam(params: List[Tensor],
  206. grads: List[Tensor],
  207. exp_avgs: List[Tensor],
  208. exp_avg_sqs: List[Tensor],
  209. max_exp_avg_sqs: List[Tensor],
  210. state_steps: List[Tensor],
  211. # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627
  212. # setting this as kwarg for now as functional API is compiled by torch/distributed/optim
  213. foreach: Optional[bool] = None,
  214. capturable: bool = False,
  215. differentiable: bool = False,
  216. fused: Optional[bool] = None,
  217. grad_scale: Optional[Tensor] = None,
  218. found_inf: Optional[Tensor] = None,
  219. *,
  220. amsgrad: bool,
  221. beta1: float,
  222. beta2: float,
  223. lr: float,
  224. weight_decay: float,
  225. eps: float,
  226. maximize: bool):
  227. r"""Functional API that performs Adam algorithm computation.
  228. See :class:`~torch.optim.Adam` for details.
  229. """
  230. # Respect when the user inputs False/True for foreach or fused. We only want to change
  231. # the default when neither have been user-specified. Note that we default to foreach
  232. # and pass False to use_fused. This is not a mistake--we want to give the fused impl
  233. # bake-in time before making it the default, even if it is typically faster.
  234. if fused is None and foreach is None:
  235. _, foreach = _default_to_fused_or_foreach(params, differentiable, use_fused=False)
  236. if fused is None:
  237. fused = False
  238. if foreach is None:
  239. foreach = False
  240. if not all(isinstance(t, torch.Tensor) for t in state_steps):
  241. raise RuntimeError("API has changed, `state_steps` argument must contain a list of singleton tensors")
  242. if foreach and torch.jit.is_scripting():
  243. raise RuntimeError('torch.jit.script not supported with foreach optimizers')
  244. if fused and not torch.jit.is_scripting():
  245. func = _fused_adam
  246. elif foreach and not torch.jit.is_scripting():
  247. func = _multi_tensor_adam
  248. else:
  249. func = _single_tensor_adam
  250. func(params,
  251. grads,
  252. exp_avgs,
  253. exp_avg_sqs,
  254. max_exp_avg_sqs,
  255. state_steps,
  256. amsgrad=amsgrad,
  257. beta1=beta1,
  258. beta2=beta2,
  259. lr=lr,
  260. weight_decay=weight_decay,
  261. eps=eps,
  262. maximize=maximize,
  263. capturable=capturable,
  264. differentiable=differentiable,
  265. grad_scale=grad_scale,
  266. found_inf=found_inf)
  267. def _single_tensor_adam(params: List[Tensor],
  268. grads: List[Tensor],
  269. exp_avgs: List[Tensor],
  270. exp_avg_sqs: List[Tensor],
  271. max_exp_avg_sqs: List[Tensor],
  272. state_steps: List[Tensor],
  273. grad_scale: Optional[Tensor],
  274. found_inf: Optional[Tensor],
  275. *,
  276. amsgrad: bool,
  277. beta1: float,
  278. beta2: float,
  279. lr: float,
  280. weight_decay: float,
  281. eps: float,
  282. maximize: bool,
  283. capturable: bool,
  284. differentiable: bool):
  285. assert grad_scale is None and found_inf is None
  286. for i, param in enumerate(params):
  287. grad = grads[i] if not maximize else -grads[i]
  288. exp_avg = exp_avgs[i]
  289. exp_avg_sq = exp_avg_sqs[i]
  290. step_t = state_steps[i]
  291. if capturable:
  292. assert param.is_cuda and step_t.is_cuda, "If capturable=True, params and state_steps must be CUDA tensors."
  293. # update step
  294. step_t += 1
  295. if weight_decay != 0:
  296. grad = grad.add(param, alpha=weight_decay)
  297. if torch.is_complex(param):
  298. grad = torch.view_as_real(grad)
  299. exp_avg = torch.view_as_real(exp_avg)
  300. exp_avg_sq = torch.view_as_real(exp_avg_sq)
  301. param = torch.view_as_real(param)
  302. # Decay the first and second moment running average coefficient
  303. exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
  304. exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj(), value=1 - beta2)
  305. if capturable or differentiable:
  306. step = step_t
  307. # 1 - beta1 ** step can't be captured in a CUDA graph, even if step is a CUDA tensor
  308. # (incurs "RuntimeError: CUDA error: operation not permitted when stream is capturing")
  309. bias_correction1 = 1 - torch.pow(beta1, step)
  310. bias_correction2 = 1 - torch.pow(beta2, step)
  311. step_size = lr / bias_correction1
  312. step_size_neg = step_size.neg()
  313. bias_correction2_sqrt = bias_correction2.sqrt()
  314. if amsgrad:
  315. # Maintains the maximum of all 2nd moment running avg. till now
  316. if differentiable:
  317. max_exp_avg_sqs_i = max_exp_avg_sqs[i].clone()
  318. else:
  319. max_exp_avg_sqs_i = max_exp_avg_sqs[i]
  320. max_exp_avg_sqs[i].copy_(torch.maximum(max_exp_avg_sqs_i, exp_avg_sq))
  321. # Uses the max. for normalizing running avg. of gradient
  322. # Folds in (admittedly ugly) 1-elem step_size math here to avoid extra param-set-sized read+write
  323. # (can't fold it into addcdiv_ below because addcdiv_ requires value is a Number, not a Tensor)
  324. denom = (max_exp_avg_sqs[i].sqrt() / (bias_correction2_sqrt * step_size_neg)).add_(eps / step_size_neg)
  325. else:
  326. denom = (exp_avg_sq.sqrt() / (bias_correction2_sqrt * step_size_neg)).add_(eps / step_size_neg)
  327. param.addcdiv_(exp_avg, denom)
  328. else:
  329. step = _get_value(step_t)
  330. bias_correction1 = 1 - beta1 ** step
  331. bias_correction2 = 1 - beta2 ** step
  332. step_size = lr / bias_correction1
  333. bias_correction2_sqrt = _dispatch_sqrt(bias_correction2)
  334. if amsgrad:
  335. # Maintains the maximum of all 2nd moment running avg. till now
  336. torch.maximum(max_exp_avg_sqs[i], exp_avg_sq, out=max_exp_avg_sqs[i])
  337. # Use the max. for normalizing running avg. of gradient
  338. denom = (max_exp_avg_sqs[i].sqrt() / bias_correction2_sqrt).add_(eps)
  339. else:
  340. denom = (exp_avg_sq.sqrt() / bias_correction2_sqrt).add_(eps)
  341. param.addcdiv_(exp_avg, denom, value=-step_size)
  342. def _multi_tensor_adam(params: List[Tensor],
  343. grads: List[Tensor],
  344. exp_avgs: List[Tensor],
  345. exp_avg_sqs: List[Tensor],
  346. max_exp_avg_sqs: List[Tensor],
  347. state_steps: List[Tensor],
  348. grad_scale: Optional[Tensor],
  349. found_inf: Optional[Tensor],
  350. *,
  351. amsgrad: bool,
  352. beta1: float,
  353. beta2: float,
  354. lr: float,
  355. weight_decay: float,
  356. eps: float,
  357. maximize: bool,
  358. capturable: bool,
  359. differentiable: bool):
  360. if len(params) == 0:
  361. return
  362. if capturable:
  363. assert all(p.is_cuda and step.is_cuda for p, step in zip(params, state_steps)), \
  364. "If capturable=True, params and state_steps must be CUDA tensors."
  365. assert grad_scale is None and found_inf is None
  366. assert not differentiable, "_foreach ops don't support autograd"
  367. grouped_tensors = _group_tensors_by_device_and_dtype([params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps])
  368. for (device_params, device_grads, device_exp_avgs, device_exp_avg_sqs,
  369. device_max_exp_avg_sqs, device_state_steps) in grouped_tensors.values():
  370. if maximize:
  371. device_grads = torch._foreach_neg(tuple(device_grads)) # type: ignore[assignment]
  372. # Handle complex parameters
  373. device_grads = [torch.view_as_real(x) if torch.is_complex(x) else x for x in device_grads]
  374. device_exp_avgs = [torch.view_as_real(x) if torch.is_complex(x) else x for x in device_exp_avgs]
  375. device_exp_avg_sqs = [torch.view_as_real(x) if torch.is_complex(x) else x for x in device_exp_avg_sqs]
  376. params_ = [torch.view_as_real(x) if torch.is_complex(x) else x for x in device_params]
  377. # update steps
  378. torch._foreach_add_(device_state_steps, 1)
  379. if weight_decay != 0:
  380. device_grads = torch._foreach_add(device_grads, device_params, alpha=weight_decay)
  381. # Decay the first and second moment running average coefficient
  382. torch._foreach_mul_(device_exp_avgs, beta1)
  383. torch._foreach_add_(device_exp_avgs, device_grads, alpha=1 - beta1)
  384. torch._foreach_mul_(device_exp_avg_sqs, beta2)
  385. torch._foreach_addcmul_(device_exp_avg_sqs, device_grads, device_grads, 1 - beta2)
  386. if capturable:
  387. # TODO: use foreach_pow if/when foreach_pow is added
  388. bias_correction1 = [torch.pow(beta1, step) for step in device_state_steps]
  389. bias_correction2 = [torch.pow(beta2, step) for step in device_state_steps]
  390. # foreach_sub doesn't allow a scalar as the first arg
  391. torch._foreach_sub_(bias_correction1, 1)
  392. torch._foreach_sub_(bias_correction2, 1)
  393. torch._foreach_neg_(bias_correction1)
  394. torch._foreach_neg_(bias_correction2)
  395. # foreach_div doesn't allow a scalar as the first arg
  396. step_size = torch._foreach_div(bias_correction1, lr)
  397. torch._foreach_reciprocal_(step_size)
  398. torch._foreach_neg_(step_size)
  399. bias_correction2_sqrt = torch._foreach_sqrt(bias_correction2)
  400. if amsgrad:
  401. # Maintains the maximum of all 2nd moment running avg. till now
  402. torch._foreach_maximum_(device_max_exp_avg_sqs, device_exp_avg_sqs) # type: ignore[assignment]
  403. # Use the max. for normalizing running avg. of gradient
  404. max_exp_avg_sq_sqrt = torch._foreach_sqrt(device_max_exp_avg_sqs)
  405. # Folds in (admittedly ugly) 1-elem step_size math here to avoid extra param-set-sized read+write
  406. # (can't fold it into addcdiv_ below because addcdiv_ requires value is a Number, not a Tensor)
  407. torch._foreach_div_(max_exp_avg_sq_sqrt, torch._foreach_mul(bias_correction2_sqrt, step_size))
  408. eps_over_step_size = torch._foreach_div(step_size, eps)
  409. torch._foreach_reciprocal_(eps_over_step_size)
  410. denom = torch._foreach_add(max_exp_avg_sq_sqrt, eps_over_step_size)
  411. else:
  412. exp_avg_sq_sqrt = torch._foreach_sqrt(device_exp_avg_sqs)
  413. torch._foreach_div_(exp_avg_sq_sqrt, torch._foreach_mul(bias_correction2_sqrt, step_size))
  414. eps_over_step_size = torch._foreach_div(step_size, eps)
  415. torch._foreach_reciprocal_(eps_over_step_size)
  416. denom = torch._foreach_add(exp_avg_sq_sqrt, eps_over_step_size)
  417. torch._foreach_addcdiv_(params_, device_exp_avgs, denom)
  418. else:
  419. bias_correction1 = [1 - beta1 ** _get_value(step) for step in device_state_steps]
  420. bias_correction2 = [1 - beta2 ** _get_value(step) for step in device_state_steps]
  421. step_size = _stack_if_compiling([(lr / bc) * -1 for bc in bias_correction1])
  422. bias_correction2_sqrt = [_dispatch_sqrt(bc) for bc in bias_correction2]
  423. if amsgrad:
  424. # Maintains the maximum of all 2nd moment running avg. till now
  425. torch._foreach_maximum_(device_max_exp_avg_sqs, device_exp_avg_sqs)
  426. # Use the max. for normalizing running avg. of gradient
  427. max_exp_avg_sq_sqrt = torch._foreach_sqrt(device_max_exp_avg_sqs)
  428. torch._foreach_div_(max_exp_avg_sq_sqrt, bias_correction2_sqrt)
  429. denom = torch._foreach_add(max_exp_avg_sq_sqrt, eps)
  430. else:
  431. exp_avg_sq_sqrt = torch._foreach_sqrt(device_exp_avg_sqs)
  432. torch._foreach_div_(exp_avg_sq_sqrt, bias_correction2_sqrt)
  433. denom = torch._foreach_add(exp_avg_sq_sqrt, eps)
  434. torch._foreach_addcdiv_(params_, device_exp_avgs, denom, step_size)
  435. def _fused_adam(
  436. params: List[Tensor],
  437. grads: List[Tensor],
  438. exp_avgs: List[Tensor],
  439. exp_avg_sqs: List[Tensor],
  440. max_exp_avg_sqs: List[Tensor],
  441. state_steps: List[Tensor],
  442. grad_scale: Optional[Tensor],
  443. found_inf: Optional[Tensor],
  444. *,
  445. amsgrad: bool,
  446. beta1: float,
  447. beta2: float,
  448. lr: float,
  449. weight_decay: float,
  450. eps: float,
  451. maximize: bool,
  452. capturable: bool, # Needed for consistency.
  453. differentiable: bool,
  454. ) -> None:
  455. grouped_tensors = _group_tensors_by_device_and_dtype([params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps])
  456. grad_scale_dict = {grad_scale.device: grad_scale} if grad_scale is not None else None
  457. found_inf_dict = {found_inf.device: found_inf} if found_inf is not None else None
  458. grouped_tensors = _group_tensors_by_device_and_dtype([params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps])
  459. for (device, dtype) in grouped_tensors:
  460. (
  461. device_params,
  462. device_grads,
  463. device_exp_avgs,
  464. device_exp_avg_sqs,
  465. device_max_exp_avg_sqs,
  466. device_state_steps,
  467. ) = grouped_tensors[(device, dtype)]
  468. if grad_scale is not None and found_inf is not None:
  469. if device not in grad_scale_dict:
  470. grad_scale_dict[device] = grad_scale.to(device, non_blocking=True)
  471. if found_inf not in found_inf_dict:
  472. found_inf_dict[device] = found_inf.to(device, non_blocking=True)
  473. device_grad_scale = grad_scale_dict[device]
  474. device_found_inf = found_inf_dict[device]
  475. else:
  476. device_grad_scale = None
  477. device_found_inf = None
  478. torch._foreach_add_(device_state_steps, 1)
  479. torch._fused_adam_(
  480. device_params,
  481. device_grads,
  482. device_exp_avgs,
  483. device_exp_avg_sqs,
  484. device_max_exp_avg_sqs,
  485. device_state_steps,
  486. amsgrad=amsgrad,
  487. lr=lr,
  488. beta1=beta1,
  489. beta2=beta2,
  490. weight_decay=weight_decay,
  491. eps=eps,
  492. maximize=maximize,
  493. grad_scale=device_grad_scale,
  494. found_inf=device_found_inf,
  495. )
  496. if device_found_inf is not None:
  497. torch._foreach_sub_(device_state_steps, [device_found_inf] * len(device_state_steps))