- r"""Functional interface"""
- import math
- from torch import Tensor
- from typing import List
- from .adadelta import adadelta # type: ignore[attr-defined] # noqa: F401
- from .adagrad import adagrad, _make_sparse # type: ignore[attr-defined] # noqa: F401
- from .adam import adam # type: ignore[attr-defined] # noqa: F401
- from .adamw import adamw # type: ignore[attr-defined] # noqa: F401
- from .adamax import adamax # type: ignore[attr-defined] # noqa: F401
- from .asgd import asgd # type: ignore[attr-defined] # noqa: F401
- from .nadam import nadam # type: ignore[attr-defined] # noqa: F401
- from .radam import radam # type: ignore[attr-defined] # noqa: F401
- from .rmsprop import rmsprop # type: ignore[attr-defined] # noqa: F401
- from .rprop import rprop # type: ignore[attr-defined] # noqa: F401
- from .sgd import sgd # type: ignore[attr-defined] # noqa: F401
# TODO: use foreach API in optim._functional to do all the computation


def sparse_adam(params: List[Tensor],
                grads: List[Tensor],
                exp_avgs: List[Tensor],
                exp_avg_sqs: List[Tensor],
                state_steps: List[int],
                *,
                eps: float,
                beta1: float,
                beta2: float,
                lr: float,
                maximize: bool):
    r"""Functional API that performs Sparse Adam algorithm computation.

    See :class:`~torch.optim.SparseAdam` for details.
    """
    for i, param in enumerate(params):
        grad = grads[i]
        grad = grad if not maximize else -grad
        grad = grad.coalesce()  # the update is non-linear so indices must be unique
        grad_indices = grad._indices()
        grad_values = grad._values()
        if grad_values.numel() == 0:
            # Skip update for empty grad
            continue
        size = grad.size()

        exp_avg = exp_avgs[i]
        exp_avg_sq = exp_avg_sqs[i]
        step = state_steps[i]

        def make_sparse(values):
            constructor = grad.new
            if grad_indices.dim() == 0 or values.dim() == 0:
                return constructor().resize_as_(grad)
            return constructor(grad_indices, values, size)

        # Decay the first and second moment running average coefficient
        #      old <- b * old + (1 - b) * new
        # <==> old += (1 - b) * (new - old)
        old_exp_avg_values = exp_avg.sparse_mask(grad)._values()
        exp_avg_update_values = grad_values.sub(old_exp_avg_values).mul_(1 - beta1)
        exp_avg.add_(make_sparse(exp_avg_update_values))
        old_exp_avg_sq_values = exp_avg_sq.sparse_mask(grad)._values()
        exp_avg_sq_update_values = grad_values.pow(2).sub_(old_exp_avg_sq_values).mul_(1 - beta2)
        exp_avg_sq.add_(make_sparse(exp_avg_sq_update_values))

        # Dense addition again is intended, avoiding another sparse_mask
        numer = exp_avg_update_values.add_(old_exp_avg_values)
        exp_avg_sq_update_values.add_(old_exp_avg_sq_values)
        denom = exp_avg_sq_update_values.sqrt_().add_(eps)
        del exp_avg_update_values, exp_avg_sq_update_values

        bias_correction1 = 1 - beta1 ** step
        bias_correction2 = 1 - beta2 ** step
        step_size = lr * math.sqrt(bias_correction2) / bias_correction1

        param.add_(make_sparse(-step_size * numer.div_(denom)))
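

# Usage sketch (illustrative only, not part of the module's public API):
# a minimal example of driving the functional ``sparse_adam`` directly with
# one dense parameter and a sparse COO gradient.  The tensors and
# hyperparameter values below are assumptions chosen for demonstration; in
# normal use :class:`torch.optim.SparseAdam` builds the state tensors and
# calls this function for you.
if __name__ == "__main__":
    import torch

    param = torch.zeros(4, 3)
    dense_grad = torch.zeros(4, 3)
    dense_grad[1] = 1.0              # only one row carries a gradient
    grad = dense_grad.to_sparse()    # sparse_adam expects sparse COO gradients

    exp_avg = torch.zeros_like(param)     # first-moment state (dense)
    exp_avg_sq = torch.zeros_like(param)  # second-moment state (dense)

    sparse_adam([param], [grad], [exp_avg], [exp_avg_sq], [1],
                eps=1e-8, beta1=0.9, beta2=0.999, lr=1e-2, maximize=False)

    print(param[1])  # only the row with a nonzero gradient was updated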