import torch
from torch import Tensor
from .optimizer import (Optimizer, _use_grad_for_differentiable, _get_value, _default_to_fused_or_foreach,
                        _differentiable_doc, _foreach_doc, _maximize_doc)
from torch._utils import is_compiling
from torch.utils._foreach_utils import _group_tensors_by_device_and_dtype
from typing import List, Optional

__all__ = ["ASGD", "asgd"]


def _to_tensor(x):
    if not isinstance(x, torch.Tensor):
        return torch.tensor(x)
    return x


class ASGD(Optimizer):
    def __init__(
        self,
        params,
        lr=1e-2,
        lambd=1e-4,
        alpha=0.75,
        t0=1e6,
        weight_decay=0,
        foreach: Optional[bool] = None,
        maximize: bool = False,
        differentiable: bool = False,
    ):
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= weight_decay:
            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))

        defaults = dict(
            lr=lr,
            lambd=lambd,
            alpha=alpha,
            t0=t0,
            weight_decay=weight_decay,
            foreach=foreach,
            maximize=maximize,
            differentiable=differentiable,
        )
        super().__init__(params, defaults)
    def __setstate__(self, state):
        super().__setstate__(state)
        for group in self.param_groups:
            group.setdefault("foreach", None)
            group.setdefault("maximize", False)
            group.setdefault("differentiable", False)
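        # Loaded state dicts may store "step", "eta", and "mu" as Python scalars
        # (e.g. when saved by an older version); convert them to tensors so the
        # rest of the code can treat all state entries uniformly.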
        state_values = list(self.state.values())
        step_is_tensor = (len(state_values) != 0) and torch.is_tensor(
            state_values[0]["step"]
        )
        if not step_is_tensor:
            for s in state_values:
                s["step"] = torch.tensor(float(s["step"]))
        eta_is_tensor = (len(state_values) != 0) and torch.is_tensor(
            state_values[0]["eta"]
        )
        if not eta_is_tensor:
            for s in state_values:
                s["eta"] = torch.tensor(s["eta"])
        mu_is_tensor = (len(state_values) != 0) and torch.is_tensor(
            state_values[0]["mu"]
        )
        if not mu_is_tensor:
            for s in state_values:
                s["mu"] = torch.tensor(float(s["mu"]))
    def _init_group(self, group, params_with_grad, grads, mus, axs, etas, state_steps):
        for p in group["params"]:
            if p.grad is not None:
                params_with_grad.append(p)
                if p.grad.is_sparse:
                    raise RuntimeError("ASGD does not support sparse gradients")
                grads.append(p.grad)

                state = self.state[p]
                # State initialization
                if len(state) == 0:
                    state["step"] = torch.tensor(0.0)
                    state["eta"] = torch.tensor(group["lr"])
                    state["mu"] = torch.tensor(1.0)
                    state["ax"] = torch.zeros_like(
                        p, memory_format=torch.preserve_format
                    )

                mus.append(state["mu"])
                axs.append(state["ax"])
                etas.append(state["eta"])
                state_steps.append(state["step"])
    @_use_grad_for_differentiable
    def step(self, closure=None):
        """Performs a single optimization step.

        Args:
            closure (Callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
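        # Sketch of how a closure is typically supplied; the names `optimizer`,
        # `model`, `loss_fn`, `input`, and `target` are illustrative and come from
        # the caller's scope:
        #
        #     def closure():
        #         optimizer.zero_grad()
        #         loss = loss_fn(model(input), target)
        #         loss.backward()
        #         return loss
        #
        #     optimizer.step(closure)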
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            params_with_grad = []
            grads = []
            mus = []
            axs = []
            etas = []
            state_steps = []

            self._init_group(group, params_with_grad, grads, mus, axs, etas, state_steps)

            asgd(
                params_with_grad,
                grads,
                axs,
                mus,
                etas,
                state_steps,
                lambd=group["lambd"],
                lr=group["lr"],
                t0=group["t0"],
                alpha=group["alpha"],
                weight_decay=group["weight_decay"],
                foreach=group["foreach"],
                maximize=group["maximize"],
                differentiable=group["differentiable"],
            )

        return loss

ASGD.__doc__ = r"""Implements Averaged Stochastic Gradient Descent.

    It has been proposed in `Acceleration of stochastic approximation by
    averaging`_.

    Args:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, optional): learning rate (default: 1e-2)
        lambd (float, optional): decay term (default: 1e-4)
        alpha (float, optional): power for eta update (default: 0.75)
        t0 (float, optional): point at which to start averaging (default: 1e6)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
        {foreach}
        {maximize}
        {differentiable}

    .. _Acceleration of stochastic approximation by averaging:
        https://dl.acm.org/citation.cfm?id=131098

    """.format(foreach=_foreach_doc, maximize=_maximize_doc, differentiable=_differentiable_doc)
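
# Minimal usage sketch, assuming a standard ``torch.nn`` model, a ``loss_fn``, and
# a ``data`` iterable of (input, target) pairs (names are illustrative only):
#
#     optimizer = ASGD(model.parameters(), lr=1e-2, t0=1e6)
#     for input, target in data:
#         optimizer.zero_grad()
#         loss_fn(model(input), target).backward()
#         optimizer.step()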


def asgd(
    params: List[Tensor],
    grads: List[Tensor],
    axs: List[Tensor],
    mus: List[Tensor],
    etas: List[Tensor],
    state_steps: List[Tensor],
    # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627
    # setting this as kwarg for now as functional API is compiled by torch/distributed/optim
    foreach: Optional[bool] = None,
    maximize: bool = False,
    differentiable: bool = False,
    *,
    lambd: float,
    lr: float,
    t0: float,
    alpha: float,
    weight_decay: float,
):
    r"""Functional API that performs asgd algorithm computation.

    See :class:`~torch.optim.ASGD` for details.
    """
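    # Sketch of a direct call to this functional API (assumes `params`, `grads`,
    # `axs`, `mus`, `etas`, and `state_steps` are pre-gathered, equal-length lists
    # of tensors, as built by ASGD._init_group above):
    #
    #     asgd(params, grads, axs, mus, etas, state_steps,
    #          lambd=1e-4, lr=1e-2, t0=1e6, alpha=0.75, weight_decay=0.0)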
    if foreach is None:
        _, foreach = _default_to_fused_or_foreach(params, differentiable, use_fused=False)

    if foreach and torch.jit.is_scripting():
        raise RuntimeError("torch.jit.script not supported with foreach optimizers")

    if foreach and not torch.jit.is_scripting():
        func = _multi_tensor_asgd
    else:
        func = _single_tensor_asgd

    func(
        params,
        grads,
        axs,
        mus,
        etas,
        state_steps,
        lambd=lambd,
        lr=lr,
        t0=t0,
        alpha=alpha,
        weight_decay=weight_decay,
        maximize=maximize,
        differentiable=differentiable,
    )


def _single_tensor_asgd(
    params: List[Tensor],
    grads: List[Tensor],
    axs: List[Tensor],
    mus: List[Tensor],
    etas: List[Tensor],
    state_steps: List[Tensor],
    *,
    lambd: float,
    lr: float,
    t0: float,
    alpha: float,
    weight_decay: float,
    maximize: bool,
    differentiable: bool,
):
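    # Per-parameter ASGD update with gradient g at step t (as implemented below):
    #   param <- param * (1 - lambd * eta)            # decay term
    #   param <- param - eta * g                      # gradient step
    #   ax    <- ax + mu * (param - ax)               # running average of iterates
    #   eta   <- lr / (1 + lambd * lr * t) ** alpha
    #   mu    <- 1 / max(1, t - t0)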
    for i, param in enumerate(params):
        grad = grads[i]
        grad = grad if not maximize else -grad
        mu = mus[i]
        ax = axs[i]
        eta = etas[i]
        step_t = state_steps[i]

        if torch.is_complex(param):
            grad = torch.view_as_real(grad)
            param = torch.view_as_real(param)
            ax = torch.view_as_real(ax)

        # update step
        step_t += 1
        step = _get_value(step_t)

        if weight_decay != 0:
            grad = grad.add(param, alpha=weight_decay)

        eta_value = _get_value(eta)
        # decay term
        param.mul_(1 - lambd * eta_value)

        # update parameter
        param.add_(grad, alpha=-eta_value)

        # averaging
        if is_compiling() or mu.item() != 1:
            ax.add_(param.sub(ax).mul(mu))
        else:
            ax.copy_(param)

        new_eta = _to_tensor(lr / ((1 + lambd * lr * step) ** alpha))
        eta.copy_(new_eta)
        new_mu = _to_tensor(1 / max(1, step - t0))
        mu.copy_(new_mu)


def _multi_tensor_asgd(
    params: List[Tensor],
    grads: List[Tensor],
    axs: List[Tensor],
    mus: List[Tensor],
    etas: List[Tensor],
    state_steps: List[Tensor],
    *,
    lambd: float,
    lr: float,
    t0: float,
    alpha: float,
    weight_decay: float,
    maximize: bool,
    differentiable: bool,
):
    if len(params) == 0:
        return

    assert not differentiable, "_foreach ops don't support autograd"
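    # Group tensors by device and dtype so each group can be updated with fused
    # torch._foreach_* kernels; within a group, a single eta value (read from the
    # first entry) is applied to every parameter.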
    grouped_tensors = _group_tensors_by_device_and_dtype([params, grads, axs, mus, etas, state_steps])
    for (grouped_params, grouped_grads, grouped_axs, grouped_mus,
         grouped_etas, grouped_state_steps) in grouped_tensors.values():
        if maximize:
            grouped_grads = torch._foreach_neg(grouped_grads)

        def _view_complex_as_real(tensor_list):
            return [
                torch.view_as_real(t) if torch.is_complex(t) else t for t in tensor_list
            ]

        grouped_grads = _view_complex_as_real(grouped_grads)
        grouped_params = _view_complex_as_real(grouped_params)
        grouped_axs = _view_complex_as_real(grouped_axs)

        # update step
        torch._foreach_add_(grouped_state_steps, 1)

        if weight_decay != 0:
            grouped_grads = torch._foreach_add(grouped_grads, grouped_params, alpha=weight_decay)

        # decay term
        eta = _get_value(grouped_etas[0])
        torch._foreach_mul_(grouped_params, 1 - lambd * eta)

        # update parameter
        torch._foreach_add_(grouped_params, grouped_grads, alpha=-eta)

        # averaging
        for i in range(len(grouped_axs)):
            if is_compiling() or grouped_mus[i].item() != 1:
                grouped_axs[i].add_(grouped_params[i].sub(grouped_axs[i]).mul(grouped_mus[i]))
            else:
                grouped_axs[i].copy_(grouped_params[i])

        # update eta and mu; parenthesization matches _single_tensor_asgd:
        # eta = lr / (1 + lambd * lr * step) ** alpha
        for i in range(len(grouped_mus)):
            new_eta = _to_tensor(
                lr / ((1 + lambd * lr * _get_value(grouped_state_steps[i])) ** alpha)
            )
            grouped_etas[i].copy_(new_eta)
            new_mu = _to_tensor(1 / max(1, _get_value(grouped_state_steps[i]) - t0))
            grouped_mus[i].copy_(new_mu)