functional_adam.py 6.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191
  1. from typing import Dict, List, Optional, Tuple
  2. import torch
  3. import torch.optim._functional as F
  4. from torch import Tensor
# Intentionally empty: this module exposes no public names
# (_FunctionalAdam is meant for distributed optimizer internals only).
__all__: List[str] = []
# Define a TorchScript-compatible functional Adam optimizer,
# i.e. an optimizer that is used in a functional way.
# Instead of using `param.grad` when updating parameters,
# we explicitly allow the distributed optimizer to pass gradients to
# the `step` function. In this way, we can separate the gradients
# and parameters and allow a multithreaded trainer to update the
# parameters without data races from accumulating into the same .grad.
# NOTE: This should only be used by distributed optimizer internals
# and is not meant to be exposed to the user.
  15. @torch.jit.script
  16. class _FunctionalAdam:
  17. def __init__(
  18. self,
  19. params: List[Tensor],
  20. lr: float = 1e-3,
  21. betas: Tuple[float, float] = (0.9, 0.999),
  22. eps: float = 1e-8,
  23. weight_decay: float = 0.0,
  24. amsgrad: bool = False,
  25. maximize: bool = False,
  26. foreach: bool = False,
  27. fused: bool = False,
  28. _allow_empty_param_list: bool = False,
  29. ):
  30. if not 0.0 <= lr:
  31. raise ValueError("Invalid learning rate: {}".format(lr))
  32. if not 0.0 <= eps:
  33. raise ValueError("Invalid epsilon value: {}".format(eps))
  34. if not 0.0 <= betas[0] < 1.0:
  35. raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
  36. if not 0.0 <= betas[1] < 1.0:
  37. raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
  38. if not 0.0 <= weight_decay:
  39. raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
  40. self.defaults = {
  41. "lr": lr,
  42. "eps": eps,
  43. "beta1": betas[0],
  44. "beta2": betas[1],
  45. "weight_decay": weight_decay,
  46. }
  47. self.amsgrad = amsgrad
  48. self.maximize = maximize
  49. self.foreach = foreach
  50. self.fused = fused
  51. self.state = torch.jit.annotate(Dict[torch.Tensor, Dict[str, torch.Tensor]], {})
  52. if len(params) == 0 and not _allow_empty_param_list:
  53. raise ValueError("optimizer got an empty parameter list")
  54. # NOTE: we only have one param_group and don't allow user to add additional
  55. # param group as it's not a common use case.
  56. self.param_group = {"params": params}
  57. def step_param(self, param: Tensor, grad: Optional[Tensor]):
  58. """
  59. Similar to step, but operates on a single parameter and optionally a
  60. gradient tensor.
  61. """
  62. params_with_grad = []
  63. grads = []
  64. exp_avgs = []
  65. exp_avg_sqs = []
  66. max_exp_avg_sqs = []
  67. state_steps: List[Tensor] = []
  68. if grad is not None:
  69. params_with_grad.append(param)
  70. grads.append(grad)
  71. if param not in self.state:
  72. self.state[param] = {}
  73. state = self.state[param]
  74. state["step"] = torch.tensor(0.0)
  75. state["exp_avg"] = torch.zeros_like(
  76. param, memory_format=torch.preserve_format
  77. )
  78. state["exp_avg_sq"] = torch.zeros_like(
  79. param, memory_format=torch.preserve_format
  80. )
  81. if self.amsgrad:
  82. state["max_exp_avg_sq"] = torch.zeros_like(
  83. param, memory_format=torch.preserve_format
  84. )
  85. state = self.state[param]
  86. exp_avgs.append(state["exp_avg"])
  87. exp_avg_sqs.append(state["exp_avg_sq"])
  88. if self.amsgrad:
  89. max_exp_avg_sqs.append(state["max_exp_avg_sq"])
  90. state_steps.append(state["step"])
  91. with torch.no_grad():
  92. F.adam(
  93. params_with_grad,
  94. grads,
  95. exp_avgs,
  96. exp_avg_sqs,
  97. max_exp_avg_sqs,
  98. state_steps,
  99. amsgrad=self.amsgrad,
  100. maximize=self.maximize,
  101. beta1=self.defaults["beta1"],
  102. beta2=self.defaults["beta2"],
  103. lr=self.defaults["lr"],
  104. weight_decay=self.defaults["weight_decay"],
  105. eps=self.defaults["eps"],
  106. foreach=self.foreach,
  107. fused=self.fused,
  108. grad_scale=None,
  109. found_inf=None,
  110. )
  111. def step(self, gradients: List[Optional[Tensor]]):
  112. params = self.param_group["params"]
  113. params_with_grad = []
  114. grads = []
  115. exp_avgs = []
  116. exp_avg_sqs = []
  117. max_exp_avg_sqs = []
  118. state_steps: List[Tensor] = []
  119. if len(params) != len(gradients):
  120. raise ValueError(
  121. "the gradients passed in does not equal to the size of the parameters!"
  122. + f"Params length: {len(params)}. "
  123. + f"Gradients length: {len(gradients)}"
  124. )
  125. for param, gradient in zip(self.param_group["params"], gradients):
  126. if gradient is not None:
  127. params_with_grad.append(param)
  128. grads.append(gradient)
  129. # Lazy state initialization
  130. if param not in self.state:
  131. self.state[param] = {}
  132. state = self.state[param]
  133. state["step"] = torch.tensor(0.0)
  134. # Exponential moving average of gradient values
  135. state["exp_avg"] = torch.zeros_like(
  136. param, memory_format=torch.preserve_format
  137. )
  138. # Exponential moving average of squared gradient values
  139. state["exp_avg_sq"] = torch.zeros_like(
  140. param, memory_format=torch.preserve_format
  141. )
  142. if self.amsgrad:
  143. # Maintains max of all exp. moving avg. of sq. grad. values
  144. state["max_exp_avg_sq"] = torch.zeros_like(
  145. param, memory_format=torch.preserve_format
  146. )
  147. state = self.state[param]
  148. exp_avgs.append(state["exp_avg"])
  149. exp_avg_sqs.append(state["exp_avg_sq"])
  150. if self.amsgrad:
  151. max_exp_avg_sqs.append(state["max_exp_avg_sq"])
  152. state_steps.append(state["step"])
  153. with torch.no_grad():
  154. F.adam(
  155. params_with_grad,
  156. grads,
  157. exp_avgs,
  158. exp_avg_sqs,
  159. max_exp_avg_sqs,
  160. state_steps,
  161. amsgrad=self.amsgrad,
  162. maximize=self.maximize,
  163. beta1=self.defaults["beta1"],
  164. beta2=self.defaults["beta2"],
  165. lr=self.defaults["lr"],
  166. weight_decay=self.defaults["weight_decay"],
  167. eps=self.defaults["eps"],
  168. foreach=self.foreach,
  169. fused=self.fused,
  170. grad_scale=None,
  171. found_inf=None,
  172. )