_functional.py

  1. r"""Functional interface"""
  2. import math
  3. from torch import Tensor
  4. from typing import List
  5. from .adadelta import adadelta # type: ignore[attr-defined] # noqa: F401
  6. from .adagrad import adagrad, _make_sparse # type: ignore[attr-defined] # noqa: F401
  7. from .adam import adam # type: ignore[attr-defined] # noqa: F401
  8. from .adamw import adamw # type: ignore[attr-defined] # noqa: F401
  9. from .adamax import adamax # type: ignore[attr-defined] # noqa: F401
  10. from .asgd import asgd # type: ignore[attr-defined] # noqa: F401
  11. from .nadam import nadam # type: ignore[attr-defined] # noqa: F401
  12. from .radam import radam # type: ignore[attr-defined] # noqa: F401
  13. from .rmsprop import rmsprop # type: ignore[attr-defined] # noqa: F401
  14. from .rprop import rprop # type: ignore[attr-defined] # noqa: F401
  15. from .sgd import sgd # type: ignore[attr-defined] # noqa: F401
# TODO: use foreach API in optim._functional to do all the computation

def sparse_adam(params: List[Tensor],
                grads: List[Tensor],
                exp_avgs: List[Tensor],
                exp_avg_sqs: List[Tensor],
                state_steps: List[int],
                *,
                eps: float,
                beta1: float,
                beta2: float,
                lr: float,
                maximize: bool):
    r"""Functional API that performs Sparse Adam algorithm computation.

    See :class:`~torch.optim.SparseAdam` for details.
    """
    for i, param in enumerate(params):
        grad = grads[i]
        grad = grad if not maximize else -grad
        grad = grad.coalesce()  # the update is non-linear so indices must be unique
        grad_indices = grad._indices()
        grad_values = grad._values()
        if grad_values.numel() == 0:
            # Skip update for empty grad
            continue
        size = grad.size()

        exp_avg = exp_avgs[i]
        exp_avg_sq = exp_avg_sqs[i]
        step = state_steps[i]
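
        # make_sparse rebuilds a sparse tensor that shares grad's indices, so the
        # dense state buffers and the parameter are updated only at the rows that
        # actually received a gradient.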
        def make_sparse(values):
            constructor = grad.new
            if grad_indices.dim() == 0 or values.dim() == 0:
                return constructor().resize_as_(grad)
            return constructor(grad_indices, values, size)
        # Decay the first and second moment running average coefficient
        #      old <- b * old + (1 - b) * new
        # <==> old += (1 - b) * (new - old)
        old_exp_avg_values = exp_avg.sparse_mask(grad)._values()
        exp_avg_update_values = grad_values.sub(old_exp_avg_values).mul_(1 - beta1)
        exp_avg.add_(make_sparse(exp_avg_update_values))
        old_exp_avg_sq_values = exp_avg_sq.sparse_mask(grad)._values()
        exp_avg_sq_update_values = grad_values.pow(2).sub_(old_exp_avg_sq_values).mul_(1 - beta2)
        exp_avg_sq.add_(make_sparse(exp_avg_sq_update_values))
        # Dense addition again is intended, avoiding another sparse_mask
        numer = exp_avg_update_values.add_(old_exp_avg_values)
        exp_avg_sq_update_values.add_(old_exp_avg_sq_values)
        denom = exp_avg_sq_update_values.sqrt_().add_(eps)
        del exp_avg_update_values, exp_avg_sq_update_values

        bias_correction1 = 1 - beta1 ** step
        bias_correction2 = 1 - beta2 ** step
        step_size = lr * math.sqrt(bias_correction2) / bias_correction1
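
        # Putting the pieces together, the line below applies, at grad's indices,
        #   param <- param - lr * sqrt(1 - beta2**step) / (1 - beta1**step)
        #                       * exp_avg / (sqrt(exp_avg_sq) + eps)
        # i.e. the standard Adam step with eps added to the uncorrected
        # second-moment root.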
        param.add_(make_sparse(-step_size * numer.div_(denom)))
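

# Minimal usage sketch: it drives sparse_adam directly with a single
# embedding-style parameter whose sparse gradient touches only rows 1 and 3.
# The shapes, hyperparameter values, and step count are illustrative
# assumptions, not part of any optimizer's state handling.
if __name__ == "__main__":
    import torch

    param = torch.zeros(10, 4)
    grad = torch.sparse_coo_tensor(
        torch.tensor([[1, 3]]),  # indices: sparse_dim=1, nnz=2
        torch.randn(2, 4),       # values: one row per touched index
        param.shape,
    )
    exp_avg = torch.zeros_like(param)
    exp_avg_sq = torch.zeros_like(param)

    sparse_adam(
        [param], [grad], [exp_avg], [exp_avg_sq], [1],
        eps=1e-8, beta1=0.9, beta2=0.999, lr=1e-3, maximize=False,
    )
    print(param[1], param[3])  # only the touched rows have moved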