123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634 |
- import torch
- from torch import Tensor
- from .optimizer import (Optimizer, _use_grad_for_differentiable, _get_value, _dispatch_sqrt,
- _stack_if_compiling, _capturable_doc, _differentiable_doc, _foreach_doc,
- _fused_doc, _maximize_doc, _default_to_fused_or_foreach)
- from typing import List, Optional
- from torch.utils._foreach_utils import _group_tensors_by_device_and_dtype
# Public names exported by ``from <this module> import *``: the optimizer class
# and its functional counterpart.
__all__ = ["AdamW", "adamw"]
class AdamW(Optimizer):
    # The user-facing docstring (algorithm + arguments) is assigned to
    # ``AdamW.__doc__`` after the class body; the per-step math is performed by
    # the functional ``adamw``.

    def __init__(
        self,
        params,
        lr=1e-3,
        betas=(0.9, 0.999),
        eps=1e-8,
        weight_decay=1e-2,
        amsgrad=False,
        *,
        maximize: bool = False,
        foreach: Optional[bool] = None,
        capturable: bool = False,
        differentiable: bool = False,
        fused: Optional[bool] = None,
    ):
        # Hyper-parameter validation, checked in a fixed order (first failure wins).
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
        if not 0.0 <= weight_decay:
            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))

        super().__init__(
            params,
            {
                "lr": lr,
                "betas": betas,
                "eps": eps,
                "weight_decay": weight_decay,
                "amsgrad": amsgrad,
                "foreach": foreach,
                "maximize": maximize,
                "capturable": capturable,
                "differentiable": differentiable,
                "fused": fused,
            },
        )

        if not fused:
            return

        # Everything below only concerns the fused implementation.
        if differentiable:
            raise RuntimeError("`fused` does not support `differentiable`")
        # Lets an AMP gradient scaler hand grad_scale / found_inf to step().
        self._step_supports_amp_scaling = True
        # TODO(crcrpar): [low prec params & their higher prec copy]
        # Support AMP with FP16/BF16 model params which would need
        # higher prec copy of params to do update math in higher prec to
        # alleviate the loss of information.
        if any(
            not (p.is_cuda and torch.is_floating_point(p))
            for pg in self.param_groups
            for p in pg["params"]
        ):
            raise RuntimeError("`fused=True` requires all the params to be CUDA, floating point Tensor")
        if foreach:
            raise RuntimeError("`fused` and `foreach` cannot be `True` together.")

    def __setstate__(self, state):
        super().__setstate__(state)
        # Back-fill options that may be absent from optimizers pickled by older code.
        for group in self.param_groups:
            for option, fallback in (
                ("amsgrad", False),
                ("maximize", False),
                ("foreach", None),
                ("capturable", False),
                ("differentiable", False),
                ("fused", None),
            ):
                group.setdefault(option, fallback)

        # ``step`` may have been stored as a plain number; only the first
        # per-parameter state is inspected to decide whether to convert.
        per_param_states = list(self.state.values())
        if per_param_states and torch.is_tensor(per_param_states[0]["step"]):
            return
        for param_state in per_param_states:
            param_state["step"] = torch.tensor(float(param_state["step"]))

    def _init_group(
        self,
        group,
        params_with_grad,
        grads,
        amsgrad,
        exp_avgs,
        exp_avg_sqs,
        max_exp_avg_sqs,
        state_steps,
    ):
        """Collect the tensors of ``group`` that take part in this step.

        Appends (in place) to the given output lists for every parameter that
        has a gradient, lazily creating its optimizer state on first use.

        Raises:
            RuntimeError: if a gradient is sparse.
        """
        for param in group["params"]:
            grad = param.grad
            if grad is None:
                continue
            params_with_grad.append(param)
            if grad.is_sparse:
                raise RuntimeError("AdamW does not support sparse gradients")
            grads.append(grad)

            state = self.state[param]
            if not state:
                # capturable / fused keep the step counter as a 1-element float
                # tensor on the parameter's device; otherwise a CPU scalar tensor.
                step_on_device = group["capturable"] or group["fused"]
                state["step"] = (
                    torch.zeros((1,), dtype=torch.float, device=param.device)
                    if step_on_device
                    else torch.tensor(0.0)
                )
                # First and second moment running averages.
                state["exp_avg"] = torch.zeros_like(param, memory_format=torch.preserve_format)
                state["exp_avg_sq"] = torch.zeros_like(param, memory_format=torch.preserve_format)
                if amsgrad:
                    # Running maximum of the second moment (AMSGrad only).
                    state["max_exp_avg_sq"] = torch.zeros_like(
                        param, memory_format=torch.preserve_format
                    )

            exp_avgs.append(state["exp_avg"])
            exp_avg_sqs.append(state["exp_avg_sq"])
            if amsgrad:
                max_exp_avg_sqs.append(state["max_exp_avg_sq"])
            state_steps.append(state["step"])

    @_use_grad_for_differentiable
    def step(self, closure=None):
        """Perform a single optimization step.

        Args:
            closure (Callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        self._cuda_graph_capture_health_check()

        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            (
                params_with_grad,
                grads,
                exp_avgs,
                exp_avg_sqs,
                max_exp_avg_sqs,
                state_steps,
            ) = ([] for _ in range(6))
            amsgrad = group["amsgrad"]
            beta1, beta2 = group["betas"]

            self._init_group(
                group,
                params_with_grad,
                grads,
                amsgrad,
                exp_avgs,
                exp_avg_sqs,
                max_exp_avg_sqs,
                state_steps,
            )

            adamw(
                params_with_grad,
                grads,
                exp_avgs,
                exp_avg_sqs,
                max_exp_avg_sqs,
                state_steps,
                amsgrad=amsgrad,
                beta1=beta1,
                beta2=beta2,
                lr=group["lr"],
                weight_decay=group["weight_decay"],
                eps=group["eps"],
                maximize=group["maximize"],
                foreach=group["foreach"],
                capturable=group["capturable"],
                differentiable=group["differentiable"],
                fused=group["fused"],
                # Set externally (e.g. by an AMP gradient scaler) when present.
                grad_scale=getattr(self, "grad_scale", None),
                found_inf=getattr(self, "found_inf", None),
            )

        return loss
# The docstring is built from two raw strings: the first holds LaTeX (whose
# braces must not be touched by str.format), the second is formatted with the
# shared option docs from the optimizer module.  The AMSGrad lines describe what
# the implementation computes: the running max is taken over the raw second
# moment v_t, and bias correction is applied afterwards.
AdamW.__doc__ = r"""Implements AdamW algorithm.

    .. math::
       \begin{aligned}
            &\rule{110mm}{0.4pt}                                                                 \\
            &\textbf{input}      : \gamma \text{(lr)}, \: \beta_1, \beta_2
                \text{(betas)}, \: \theta_0 \text{(params)}, \: f(\theta) \text{(objective)},
                \: \epsilon \text{ (epsilon)}                                                    \\
            &\hspace{13mm}      \lambda \text{(weight decay)},  \: \textit{amsgrad},
                \: \textit{maximize}                                                             \\
            &\textbf{initialize} : m_0 \leftarrow 0 \text{ (first moment)}, v_0 \leftarrow 0
                \text{ (second moment)}, \: v_0^{max}\leftarrow 0                        \\[-1.ex]
            &\rule{110mm}{0.4pt}                                                                 \\
            &\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do}                         \\
            &\hspace{5mm}\textbf{if} \: \textit{maximize}:                                       \\
            &\hspace{10mm}g_t           \leftarrow   -\nabla_{\theta} f_t (\theta_{t-1})         \\
            &\hspace{5mm}\textbf{else}                                                           \\
            &\hspace{10mm}g_t           \leftarrow   \nabla_{\theta} f_t (\theta_{t-1})          \\
            &\hspace{5mm} \theta_t \leftarrow \theta_{t-1} - \gamma \lambda \theta_{t-1}         \\
            &\hspace{5mm}m_t           \leftarrow   \beta_1 m_{t-1} + (1 - \beta_1) g_t          \\
            &\hspace{5mm}v_t           \leftarrow   \beta_2 v_{t-1} + (1-\beta_2) g^2_t          \\
            &\hspace{5mm}\widehat{m_t} \leftarrow   m_t/\big(1-\beta_1^t \big)                   \\
            &\hspace{5mm}\textbf{if} \: \textit{amsgrad}                                         \\
            &\hspace{10mm}v_t^{max} \leftarrow \mathrm{max}(v_{t-1}^{max}, v_t)                  \\
            &\hspace{10mm}\widehat{v_t} \leftarrow v_t^{max}/\big(1-\beta_2^t \big)              \\
            &\hspace{5mm}\textbf{else}                                                           \\
            &\hspace{10mm}\widehat{v_t} \leftarrow   v_t/\big(1-\beta_2^t \big)                  \\
            &\hspace{5mm}\theta_t \leftarrow \theta_t - \gamma \widehat{m_t}/
                \big(\sqrt{\widehat{v_t}} + \epsilon \big)                                       \\
            &\rule{110mm}{0.4pt}                                                          \\[-1.ex]
            &\bf{return} \:  \theta_t                                                     \\[-1.ex]
            &\rule{110mm}{0.4pt}                                                          \\[-1.ex]
       \end{aligned}

    For further details regarding the algorithm we refer to `Decoupled Weight Decay Regularization`_.
    """ + r"""
    Args:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, optional): learning rate (default: 1e-3)
        betas (Tuple[float, float], optional): coefficients used for computing
            running averages of gradient and its square (default: (0.9, 0.999))
        eps (float, optional): term added to the denominator to improve
            numerical stability (default: 1e-8)
        weight_decay (float, optional): weight decay coefficient (default: 1e-2)
        amsgrad (bool, optional): whether to use the AMSGrad variant of this
            algorithm from the paper `On the Convergence of Adam and Beyond`_
            (default: False)
        {maximize}
        {foreach}
        {capturable}
        {differentiable}
        {fused}

    .. _Decoupled Weight Decay Regularization:
        https://arxiv.org/abs/1711.05101
    .. _On the Convergence of Adam and Beyond:
        https://openreview.net/forum?id=ryQu7f-RZ

    """.format(maximize=_maximize_doc,
               foreach=_foreach_doc,
               fused=_fused_doc,
               capturable=_capturable_doc,
               differentiable=_differentiable_doc)
def adamw(
    params: List[Tensor],
    grads: List[Tensor],
    exp_avgs: List[Tensor],
    exp_avg_sqs: List[Tensor],
    max_exp_avg_sqs: List[Tensor],
    state_steps: List[Tensor],
    # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627
    # setting this as kwarg for now as functional API is compiled by torch/distributed/optim
    foreach: Optional[bool] = None,
    capturable: bool = False,
    differentiable: bool = False,
    fused: Optional[bool] = None,
    grad_scale: Optional[Tensor] = None,
    found_inf: Optional[Tensor] = None,
    *,
    amsgrad: bool,
    beta1: float,
    beta2: float,
    lr: float,
    weight_decay: float,
    eps: float,
    maximize: bool,
):
    r"""Functional API that performs AdamW algorithm computation.

    See :class:`~torch.optim.AdamW` for details.

    Picks an implementation and forwards every argument to it unchanged:
    ``_fused_adamw`` when ``fused`` is truthy, else ``_multi_tensor_adamw``
    when ``foreach`` is truthy, else ``_single_tensor_adamw``. When neither
    ``fused`` nor ``foreach`` is given, ``foreach`` is chosen by
    ``_default_to_fused_or_foreach`` (fused is never auto-selected).

    Raises:
        RuntimeError: if ``state_steps`` holds a non-tensor entry, or if
            ``foreach`` / ``fused`` is requested while running under
            TorchScript.
    """

    # Older callers passed Python numbers as steps; the implementations below
    # require tensors.
    if not all(isinstance(t, torch.Tensor) for t in state_steps):
        raise RuntimeError(
            "API has changed, `state_steps` argument must contain a list of singleton tensors"
        )

    # Respect when the user inputs False/True for foreach or fused. We only want to change
    # the default when neither have been user-specified. Note that we default to foreach
    # and pass False to use_fused. This is not a mistake--we want to give the fused impl
    # bake-in time before making it the default, even if it is typically faster.
    if fused is None and foreach is None:
        _, foreach = _default_to_fused_or_foreach(params, differentiable, use_fused=False)
    if fused is None:
        fused = False
    if foreach is None:
        foreach = False

    if foreach and torch.jit.is_scripting():
        raise RuntimeError("torch.jit.script not supported with foreach optimizers")
    if fused and torch.jit.is_scripting():
        raise RuntimeError("torch.jit.script not supported with fused optimizers")

    # The ``not torch.jit.is_scripting()`` terms mean a scripted call can only
    # ever reach the single-tensor implementation.
    if fused and not torch.jit.is_scripting():
        func = _fused_adamw
    elif foreach and not torch.jit.is_scripting():
        func = _multi_tensor_adamw
    else:
        func = _single_tensor_adamw

    func(
        params,
        grads,
        exp_avgs,
        exp_avg_sqs,
        max_exp_avg_sqs,
        state_steps,
        amsgrad=amsgrad,
        beta1=beta1,
        beta2=beta2,
        lr=lr,
        weight_decay=weight_decay,
        eps=eps,
        maximize=maximize,
        capturable=capturable,
        differentiable=differentiable,
        grad_scale=grad_scale,
        found_inf=found_inf,
    )
- def _single_tensor_adamw(
- params: List[Tensor],
- grads: List[Tensor],
- exp_avgs: List[Tensor],
- exp_avg_sqs: List[Tensor],
- max_exp_avg_sqs: List[Tensor],
- state_steps: List[Tensor],
- grad_scale: Optional[Tensor],
- found_inf: Optional[Tensor],
- *,
- amsgrad: bool,
- beta1: float,
- beta2: float,
- lr: float,
- weight_decay: float,
- eps: float,
- maximize: bool,
- capturable: bool,
- differentiable: bool,
- ):
- assert grad_scale is None and found_inf is None
- for i, param in enumerate(params):
- grad = grads[i] if not maximize else -grads[i]
- exp_avg = exp_avgs[i]
- exp_avg_sq = exp_avg_sqs[i]
- step_t = state_steps[i]
- if capturable:
- assert (
- param.is_cuda and step_t.is_cuda
- ), "If capturable=True, params and state_steps must be CUDA tensors."
- if torch.is_complex(param):
- grad = torch.view_as_real(grad)
- exp_avg = torch.view_as_real(exp_avg)
- exp_avg_sq = torch.view_as_real(exp_avg_sq)
- param = torch.view_as_real(param)
- # update step
- step_t += 1
- # Perform stepweight decay
- param.mul_(1 - lr * weight_decay)
- # Decay the first and second moment running average coefficient
- exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
- exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
- if capturable or differentiable:
- step = step_t
- # 1 - beta1 ** step can't be captured in a CUDA graph, even if step is a CUDA tensor
- # (incurs "RuntimeError: CUDA error: operation not permitted when stream is capturing")
- bias_correction1 = 1 - torch.pow(beta1, step)
- bias_correction2 = 1 - torch.pow(beta2, step)
- step_size = lr / bias_correction1
- step_size_neg = step_size.neg()
- bias_correction2_sqrt = bias_correction2.sqrt()
- if amsgrad:
- # Maintains the maximum of all 2nd moment running avg. till now
- if differentiable:
- max_exp_avg_sqs_i = max_exp_avg_sqs[i].clone()
- else:
- max_exp_avg_sqs_i = max_exp_avg_sqs[i]
- max_exp_avg_sqs[i].copy_(torch.maximum(max_exp_avg_sqs_i, exp_avg_sq))
- # Uses the max. for normalizing running avg. of gradient
- # Folds in (admittedly ugly) 1-elem step_size math here to avoid extra param-set-sized read+write
- # (can't fold it into addcdiv_ below because addcdiv_ requires value is a Number, not a Tensor)
- denom = (
- max_exp_avg_sqs[i].sqrt() / (bias_correction2_sqrt * step_size_neg)
- ).add_(eps / step_size_neg)
- else:
- denom = (
- exp_avg_sq.sqrt() / (bias_correction2_sqrt * step_size_neg)
- ).add_(eps / step_size_neg)
- param.addcdiv_(exp_avg, denom)
- else:
- step = _get_value(step_t)
- bias_correction1 = 1 - beta1 ** step
- bias_correction2 = 1 - beta2 ** step
- step_size = lr / bias_correction1
- bias_correction2_sqrt = _dispatch_sqrt(bias_correction2)
- if amsgrad:
- # Maintains the maximum of all 2nd moment running avg. till now
- torch.maximum(max_exp_avg_sqs[i], exp_avg_sq, out=max_exp_avg_sqs[i])
- # Use the max. for normalizing running avg. of gradient
- denom = (max_exp_avg_sqs[i].sqrt() / bias_correction2_sqrt).add_(eps)
- else:
- denom = (exp_avg_sq.sqrt() / bias_correction2_sqrt).add_(eps)
- param.addcdiv_(exp_avg, denom, value=-step_size)
def _multi_tensor_adamw(
    params: List[Tensor],
    grads: List[Tensor],
    exp_avgs: List[Tensor],
    exp_avg_sqs: List[Tensor],
    max_exp_avg_sqs: List[Tensor],
    state_steps: List[Tensor],
    grad_scale: Optional[Tensor],
    found_inf: Optional[Tensor],
    *,
    amsgrad: bool,
    beta1: float,
    beta2: float,
    lr: float,
    weight_decay: float,
    eps: float,
    maximize: bool,
    capturable: bool,
    differentiable: bool,
):
    r"""AdamW update using ``torch._foreach_*`` ops, one pass per device/dtype group.

    Tensors are grouped with ``_group_tensors_by_device_and_dtype`` and every
    group's params, moments, running max (when ``amsgrad``) and steps are
    updated in place. Complex tensors, including the AMSGrad running max, are
    processed through their ``torch.view_as_real`` views.

    Not usable with autograd (``differentiable`` must be False), and
    ``grad_scale`` / ``found_inf`` must be ``None``.
    """
    if len(params) == 0:
        return

    if capturable:
        assert all(
            p.is_cuda and step.is_cuda for p, step in zip(params, state_steps)
        ), "If capturable=True, params and state_steps must be CUDA tensors."

    assert not differentiable, "_foreach ops don't support autograd"

    assert grad_scale is None and found_inf is None

    grouped_tensors = _group_tensors_by_device_and_dtype([
        params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps])
    for (device_params, device_grads, device_exp_avgs, device_exp_avg_sqs,
         device_max_exp_avg_sqs, device_state_steps) in grouped_tensors.values():
        if maximize:
            device_grads = torch._foreach_neg(tuple(device_grads))  # type: ignore[assignment]

        # Complex tensors are updated through their real views.
        device_grads = [torch.view_as_real(x) if torch.is_complex(x) else x for x in device_grads]
        device_exp_avgs = [torch.view_as_real(x) if torch.is_complex(x) else x for x in device_exp_avgs]
        device_exp_avg_sqs = [
            torch.view_as_real(x) if torch.is_complex(x) else x for x in device_exp_avg_sqs
        ]
        if amsgrad:
            # The running max needs the same treatment, otherwise _foreach_maximum_
            # below pairs complex tensors with the real views in device_exp_avg_sqs.
            device_max_exp_avg_sqs = [
                torch.view_as_real(x) if torch.is_complex(x) else x for x in device_max_exp_avg_sqs
            ]
        device_params = [torch.view_as_real(x) if torch.is_complex(x) else x for x in device_params]

        # update steps
        torch._foreach_add_(device_state_steps, 1)

        # Perform stepweight decay: theta <- theta * (1 - lr * weight_decay)
        torch._foreach_mul_(device_params, 1 - lr * weight_decay)

        # Decay the first and second moment running average coefficient
        torch._foreach_mul_(device_exp_avgs, beta1)
        torch._foreach_add_(device_exp_avgs, device_grads, alpha=1 - beta1)

        torch._foreach_mul_(device_exp_avg_sqs, beta2)
        torch._foreach_addcmul_(device_exp_avg_sqs, device_grads, device_grads, 1 - beta2)

        if capturable:
            # TODO: use foreach_pow if/when foreach_pow is added
            bias_correction1 = [torch.pow(beta1, step) for step in device_state_steps]
            bias_correction2 = [torch.pow(beta2, step) for step in device_state_steps]
            # foreach_sub doesn't allow a scalar as the first arg
            torch._foreach_sub_(bias_correction1, 1)
            torch._foreach_sub_(bias_correction2, 1)
            torch._foreach_neg_(bias_correction1)
            torch._foreach_neg_(bias_correction2)

            # foreach_div doesn't allow a scalar as the first arg
            # step_size ends up as -lr / bias_correction1
            step_size = torch._foreach_div(bias_correction1, lr)
            torch._foreach_reciprocal_(step_size)
            torch._foreach_neg_(step_size)

            bias_correction2_sqrt = torch._foreach_sqrt(bias_correction2)

            if amsgrad:
                # Maintains the maximum of all 2nd moment running avg. till now
                torch._foreach_maximum_(device_max_exp_avg_sqs, device_exp_avg_sqs)

                # Use the max. for normalizing running avg. of gradient
                max_exp_avg_sq_sqrt = torch._foreach_sqrt(device_max_exp_avg_sqs)
                # Folds in (admittedly ugly) 1-elem step_size math here to avoid extra param-set-sized read+write
                # (can't fold it into addcdiv_ below because addcdiv_ requires value is a Number, not a Tensor)
                torch._foreach_div_(
                    max_exp_avg_sq_sqrt,
                    torch._foreach_mul(bias_correction2_sqrt, step_size),
                )
                eps_over_step_size = torch._foreach_div(step_size, eps)
                torch._foreach_reciprocal_(eps_over_step_size)
                denom = torch._foreach_add(max_exp_avg_sq_sqrt, eps_over_step_size)
            else:
                exp_avg_sq_sqrt = torch._foreach_sqrt(device_exp_avg_sqs)
                torch._foreach_div_(
                    exp_avg_sq_sqrt, torch._foreach_mul(bias_correction2_sqrt, step_size)
                )
                eps_over_step_size = torch._foreach_div(step_size, eps)
                torch._foreach_reciprocal_(eps_over_step_size)
                denom = torch._foreach_add(exp_avg_sq_sqrt, eps_over_step_size)

            # denom already carries the (negative) step size.
            torch._foreach_addcdiv_(device_params, device_exp_avgs, denom)
        else:
            bias_correction1 = [1 - beta1 ** _get_value(step) for step in device_state_steps]
            bias_correction2 = [1 - beta2 ** _get_value(step) for step in device_state_steps]

            step_size = _stack_if_compiling([(lr / bc) * -1 for bc in bias_correction1])

            bias_correction2_sqrt = [_dispatch_sqrt(bc) for bc in bias_correction2]

            if amsgrad:
                # Maintains the maximum of all 2nd moment running avg. till now
                torch._foreach_maximum_(device_max_exp_avg_sqs, device_exp_avg_sqs)

                # Use the max. for normalizing running avg. of gradient
                max_exp_avg_sq_sqrt = torch._foreach_sqrt(device_max_exp_avg_sqs)
                torch._foreach_div_(max_exp_avg_sq_sqrt, bias_correction2_sqrt)
                denom = torch._foreach_add(max_exp_avg_sq_sqrt, eps)
            else:
                exp_avg_sq_sqrt = torch._foreach_sqrt(device_exp_avg_sqs)
                torch._foreach_div_(exp_avg_sq_sqrt, bias_correction2_sqrt)
                denom = torch._foreach_add(exp_avg_sq_sqrt, eps)

            torch._foreach_addcdiv_(device_params, device_exp_avgs, denom, step_size)
def _fused_adamw(
    params: List[Tensor],
    grads: List[Tensor],
    exp_avgs: List[Tensor],
    exp_avg_sqs: List[Tensor],
    max_exp_avg_sqs: List[Tensor],
    state_steps: List[Tensor],
    grad_scale: Optional[Tensor],
    found_inf: Optional[Tensor],
    *,
    amsgrad: bool,
    beta1: float,
    beta2: float,
    lr: float,
    weight_decay: float,
    eps: float,
    maximize: bool,
    capturable: bool,  # Needed for consistency.
    differentiable: bool,
) -> None:
    r"""AdamW update via ``torch._fused_adamw_``, one call per device/dtype group.

    ``grad_scale`` and ``found_inf`` are optional and handled independently.
    Each is copied to a group's device at most once and cached. ``found_inf``
    must be honoured even when ``grad_scale`` is ``None``; a gradient scaler
    may pass only ``found_inf`` once the grads were already unscaled. When
    ``found_inf`` is given, the step increment is rolled back afterwards by
    subtracting ``found_inf`` from the group's step counters.

    Raises:
        RuntimeError: if ``differentiable`` is set.
    """
    if differentiable:
        raise RuntimeError("_fused_adamw is not differentiable")

    # Per-device caches, seeded with each tensor on its own device.
    grad_scale_dict = {grad_scale.device: grad_scale} if grad_scale is not None else {}
    found_inf_dict = {found_inf.device: found_inf} if found_inf is not None else {}

    grouped_tensors = _group_tensors_by_device_and_dtype([params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps])
    for (device, _), (
        device_params,
        device_grads,
        device_exp_avgs,
        device_exp_avg_sqs,
        device_max_exp_avg_sqs,
        device_state_steps,
    ) in grouped_tensors.items():
        device_grad_scale = None
        device_found_inf = None
        if grad_scale is not None:
            if device not in grad_scale_dict:
                grad_scale_dict[device] = grad_scale.to(device, non_blocking=True)
            device_grad_scale = grad_scale_dict[device]
        if found_inf is not None:
            # Key on the device (the old code tested the tensor itself, so the
            # cache never hit and a fresh copy was made for every group).
            if device not in found_inf_dict:
                found_inf_dict[device] = found_inf.to(device, non_blocking=True)
            device_found_inf = found_inf_dict[device]

        torch._foreach_add_(device_state_steps, 1)
        torch._fused_adamw_(
            device_params,
            device_grads,
            device_exp_avgs,
            device_exp_avg_sqs,
            device_max_exp_avg_sqs,
            device_state_steps,
            amsgrad=amsgrad,
            lr=lr,
            beta1=beta1,
            beta2=beta2,
            weight_decay=weight_decay,
            eps=eps,
            maximize=maximize,
            grad_scale=device_grad_scale,
            found_inf=device_found_inf,
        )
        if device_found_inf is not None:
            # Undo the step increment for groups whose step was skipped.
            torch._foreach_sub_(device_state_steps, [device_found_inf] * len(device_state_steps))
|