import torch
from torch import Tensor
from .optimizer import (Optimizer, _use_grad_for_differentiable, _get_value,
                        _default_to_fused_or_foreach, _differentiable_doc,
                        _foreach_doc, _maximize_doc)
from torch.utils._foreach_utils import _group_tensors_by_device_and_dtype
from typing import List, Optional

__all__ = ["Adagrad", "adagrad"]


class Adagrad(Optimizer):
    def __init__(
        self,
        params,
        lr=1e-2,
        lr_decay=0,
        weight_decay=0,
        initial_accumulator_value=0,
        eps=1e-10,
        foreach: Optional[bool] = None,
        *,
        maximize: bool = False,
        differentiable: bool = False,
    ):
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= lr_decay:
            raise ValueError("Invalid lr_decay value: {}".format(lr_decay))
        if not 0.0 <= weight_decay:
            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
        if not 0.0 <= initial_accumulator_value:
            raise ValueError(
                "Invalid initial_accumulator_value value: {}".format(
                    initial_accumulator_value
                )
            )
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))

        defaults = dict(
            lr=lr,
            lr_decay=lr_decay,
            eps=eps,
            weight_decay=weight_decay,
            initial_accumulator_value=initial_accumulator_value,
            foreach=foreach,
            maximize=maximize,
            differentiable=differentiable,
        )
        super().__init__(params, defaults)
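        # Eagerly initialize per-parameter state: a scalar step counter stored as a
        # tensor, and "sum", the running sum of squared gradients, filled with
        # initial_accumulator_value (as a complex value for complex parameters).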
        for group in self.param_groups:
            for p in group["params"]:
                state = self.state[p]
                state["step"] = torch.tensor(0.0)
                init_value = (
                    complex(initial_accumulator_value, initial_accumulator_value)
                    if torch.is_complex(p)
                    else initial_accumulator_value
                )
                state["sum"] = torch.full_like(
                    p, init_value, memory_format=torch.preserve_format
                )

    def __setstate__(self, state):
        super().__setstate__(state)
        for group in self.param_groups:
            group.setdefault("foreach", None)
            group.setdefault("maximize", False)
            group.setdefault("differentiable", False)

        # Older checkpoints stored "step" as a Python number; convert it to a tensor.
        state_values = list(self.state.values())
        step_is_tensor = (len(state_values) != 0) and torch.is_tensor(
            state_values[0]["step"]
        )
        if not step_is_tensor:
            for s in state_values:
                s["step"] = torch.tensor(float(s["step"]))

    def share_memory(self):
        for group in self.param_groups:
            for p in group["params"]:
                state = self.state[p]
                state["sum"].share_memory_()

    def _init_group(self, group, params_with_grad, grads, state_sums, state_steps):
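        """Gather parameters that have gradients, along with their grads and state,
        into the provided lists; return True if any gradient is sparse."""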
        has_sparse_grad = False
        for p in group["params"]:
            if p.grad is not None:
                if p.grad.is_sparse:
                    has_sparse_grad = True
                params_with_grad.append(p)
                grads.append(p.grad)
                state = self.state[p]
                state_sums.append(state["sum"])
                state_steps.append(state["step"])
        return has_sparse_grad

    @_use_grad_for_differentiable
    def step(self, closure=None):
        """Performs a single optimization step.

        Args:
            closure (Callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            params_with_grad = []
            grads = []
            state_sums = []
            state_steps = []

            has_sparse_grad = self._init_group(group, params_with_grad, grads, state_sums, state_steps)

            adagrad(
                params_with_grad,
                grads,
                state_sums,
                state_steps,
                lr=group["lr"],
                weight_decay=group["weight_decay"],
                lr_decay=group["lr_decay"],
                eps=group["eps"],
                has_sparse_grad=has_sparse_grad,
                foreach=group["foreach"],
                maximize=group["maximize"],
                differentiable=group["differentiable"],
            )

        return loss


Adagrad.__doc__ = r"""Implements Adagrad algorithm.

    .. math::
       \begin{aligned}
            &\rule{110mm}{0.4pt}                                                                 \\
            &\textbf{input}      : \gamma \text{ (lr)}, \: \theta_0 \text{ (params)}, \: f(\theta)
                \text{ (objective)}, \: \lambda \text{ (weight decay)},                          \\
            &\hspace{12mm}    \tau \text{ (initial accumulator value)}, \: \eta\text{ (lr decay)}\\
            &\textbf{initialize} : state\_sum_0 \leftarrow \tau                           \\[-1.ex]
            &\rule{110mm}{0.4pt}                                                                 \\
            &\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do}                         \\
            &\hspace{5mm}g_t \leftarrow \nabla_{\theta} f_t (\theta_{t-1})                       \\
            &\hspace{5mm} \tilde{\gamma} \leftarrow \gamma / (1 + (t-1) \eta)                    \\
            &\hspace{5mm} \textbf{if} \: \lambda \neq 0                                          \\
            &\hspace{10mm} g_t \leftarrow g_t + \lambda \theta_{t-1}                             \\
            &\hspace{5mm}state\_sum_t \leftarrow state\_sum_{t-1} + g^2_t                        \\
            &\hspace{5mm}\theta_t \leftarrow
                \theta_{t-1} - \tilde{\gamma} \frac{g_t}{\sqrt{state\_sum_t} + \epsilon}         \\
            &\rule{110mm}{0.4pt}                                                          \\[-1.ex]
            &\bf{return} \: \theta_t                                                      \\[-1.ex]
            &\rule{110mm}{0.4pt}                                                          \\[-1.ex]
       \end{aligned}

    For further details regarding the algorithm we refer to `Adaptive Subgradient Methods for Online Learning
    and Stochastic Optimization`_.
- """ + r"""
- Args:
- params (iterable): iterable of parameters to optimize or dicts defining
- parameter groups
- lr (float, optional): learning rate (default: 1e-2)
- lr_decay (float, optional): learning rate decay (default: 0)
- weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
- eps (float, optional): term added to the denominator to improve
- numerical stability (default: 1e-10)
- {foreach}
- {maximize}
- {differentiable}
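
    A minimal usage sketch (illustrative; ``model``, ``input``, ``target`` and
    ``loss_fn`` below are assumed to be defined by the caller):

    Example::

        >>> # model, input, target and loss_fn are placeholders, not defined here
        >>> optimizer = torch.optim.Adagrad(model.parameters(), lr=1e-2)
        >>> optimizer.zero_grad()
        >>> loss_fn(model(input), target).backward()
        >>> optimizer.step()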

    .. _Adaptive Subgradient Methods for Online Learning and Stochastic
        Optimization: http://jmlr.org/papers/v12/duchi11a.html

    """.format(foreach=_foreach_doc, maximize=_maximize_doc, differentiable=_differentiable_doc)


def adagrad(
    params: List[Tensor],
    grads: List[Tensor],
    state_sums: List[Tensor],
    state_steps: List[Tensor],
    # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627
    # setting these as kwargs for now as functional API is compiled by torch/distributed/optim
    has_sparse_grad: Optional[bool] = None,
    foreach: Optional[bool] = None,
    differentiable: bool = False,
    *,
    lr: float,
    weight_decay: float,
    lr_decay: float,
    eps: float,
    maximize: bool,
):
    r"""Functional API that performs Adagrad algorithm computation.

    See :class:`~torch.optim.Adagrad` for details.
    """
    if not all(isinstance(t, torch.Tensor) for t in state_steps):
        raise RuntimeError(
            "API has changed, `state_steps` argument must contain a list of singleton tensors"
        )

    if foreach is None:
        _, foreach = _default_to_fused_or_foreach(params, differentiable, use_fused=False)

    if foreach and torch.jit.is_scripting():
        raise RuntimeError("torch.jit.script not supported with foreach optimizers")

    if foreach and not torch.jit.is_scripting():
        func = _multi_tensor_adagrad
    else:
        func = _single_tensor_adagrad

    func(
        params,
        grads,
        state_sums,
        state_steps,
        lr=lr,
        weight_decay=weight_decay,
        lr_decay=lr_decay,
        eps=eps,
        has_sparse_grad=has_sparse_grad,
        maximize=maximize,
        differentiable=differentiable,
    )
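

# Illustrative sketch (added; not part of the upstream module): the functional API can
# be driven directly with hand-built state, mirroring what ``Adagrad.step()`` does for
# each parameter group. ``param``, ``state_sum`` and ``step`` below are hypothetical
# names used only for this example.
#
#   param = torch.randn(4)
#   param.grad = torch.randn(4)
#   state_sum = torch.zeros_like(param)
#   step = torch.tensor(0.0)
#   adagrad([param], [param.grad], [state_sum], [step],
#           lr=1e-2, weight_decay=0.0, lr_decay=0.0, eps=1e-10, maximize=False)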


def _make_sparse(grad, grad_indices, values):
    size = grad.size()
    if grad_indices.numel() == 0 or values.numel() == 0:
        return torch.empty_like(grad)
    return torch.sparse_coo_tensor(grad_indices, values, size)


def _single_tensor_adagrad(
    params: List[Tensor],
    grads: List[Tensor],
    state_sums: List[Tensor],
    state_steps: List[Tensor],
    *,
    lr: float,
    weight_decay: float,
    lr_decay: float,
    eps: float,
    has_sparse_grad: bool,
    maximize: bool,
    differentiable: bool,
):
    for (param, grad, state_sum, step_t) in zip(params, grads, state_sums, state_steps):
        # update step
        step_t += 1
        step = _get_value(step_t)
        grad = grad if not maximize else -grad

        if weight_decay != 0:
            if grad.is_sparse:
                raise RuntimeError(
                    "weight_decay option is not compatible with sparse gradients"
                )
            grad = grad.add(param, alpha=weight_decay)

        clr = lr / (1 + (step - 1) * lr_decay)

        if grad.is_sparse:
            grad = grad.coalesce()  # the update is non-linear so indices must be unique
            grad_indices = grad._indices()
            grad_values = grad._values()

            state_sum.add_(_make_sparse(grad, grad_indices, grad_values.pow(2)))
            std = state_sum.sparse_mask(grad)
            std_values = std._values().sqrt_().add_(eps)
            param.add_(
                _make_sparse(grad, grad_indices, grad_values / std_values), alpha=-clr
            )
        else:
            is_complex = torch.is_complex(param)
            if is_complex:
                grad = torch.view_as_real(grad)
                state_sum = torch.view_as_real(state_sum)
                param = torch.view_as_real(param)
            state_sum.addcmul_(grad, grad, value=1)
            if differentiable:
                std = state_sum.sqrt() + eps
            else:
                std = state_sum.sqrt().add_(eps)
            param.addcdiv_(grad, std, value=-clr)
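            # The two in-place ops above implement the dense Adagrad update
            #   state_sum += grad * grad
            #   param -= clr * grad / (sqrt(state_sum) + eps)
            # with clr = lr / (1 + (step - 1) * lr_decay).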
            if is_complex:
                param = torch.view_as_complex(param)
                state_sum = torch.view_as_complex(state_sum)


def _multi_tensor_adagrad(
    params: List[Tensor],
    grads: List[Tensor],
    state_sums: List[Tensor],
    state_steps: List[Tensor],
    *,
    lr: float,
    weight_decay: float,
    lr_decay: float,
    eps: float,
    has_sparse_grad: bool,
    maximize: bool,
    differentiable: bool,
):
    assert not differentiable, "_foreach ops don't support autograd"

    # Foreach functions will throw errors if given empty lists
    if len(params) == 0:
        return

    grouped_tensorlists = _group_tensors_by_device_and_dtype([params, grads, state_sums, state_steps])
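    # Tensors are bucketed by device and dtype so that each bucket can be updated
    # with batched _foreach kernels in a single pass below.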
    for device_params, device_grads, device_state_sums, device_state_steps in grouped_tensorlists.values():
        if maximize:
            device_grads = torch._foreach_neg(device_grads)

        device_has_sparse_grad = any(grad.is_sparse for grad in device_grads)
        if device_has_sparse_grad:
            return _single_tensor_adagrad(
                device_params,
                device_grads,
                device_state_sums,
                device_state_steps,
                lr=lr,
                weight_decay=weight_decay,
                lr_decay=lr_decay,
                eps=eps,
                has_sparse_grad=True,
                maximize=False,
                differentiable=differentiable,
            )

        # Update steps
        torch._foreach_add_(device_state_steps, 1)

        if weight_decay != 0:
            device_grads = torch._foreach_add(device_grads, device_params, alpha=weight_decay)

        minus_clr = [-lr / (1 + (_get_value(step) - 1) * lr_decay) for step in device_state_steps]

        device_grads = [torch.view_as_real(x) if torch.is_complex(x) else x for x in device_grads]
        device_state_sums = [
            torch.view_as_real(x) if torch.is_complex(x) else x for x in device_state_sums
        ]
        torch._foreach_addcmul_(device_state_sums, device_grads, device_grads, value=1)

        std = torch._foreach_add(torch._foreach_sqrt(device_state_sums), eps)
        toAdd = torch._foreach_div(torch._foreach_mul(device_grads, minus_clr), std)
        toAdd = [
            torch.view_as_complex(x) if torch.is_complex(device_params[i]) else x
            for i, x in enumerate(toAdd)
        ]
        torch._foreach_add_(device_params, toAdd)
        device_state_sums = [
            torch.view_as_complex(x) if torch.is_complex(device_params[i]) else x
            for i, x in enumerate(device_state_sums)
        ]