adagrad.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370
  1. import torch
  2. from torch import Tensor
  3. from .optimizer import (Optimizer, _use_grad_for_differentiable, _get_value,
  4. _default_to_fused_or_foreach, _differentiable_doc, _foreach_doc, _maximize_doc)
  5. from torch.utils._foreach_utils import _group_tensors_by_device_and_dtype
  6. from typing import List, Optional
  7. __all__ = ["Adagrad", "adagrad"]
  8. class Adagrad(Optimizer):
  9. def __init__(
  10. self,
  11. params,
  12. lr=1e-2,
  13. lr_decay=0,
  14. weight_decay=0,
  15. initial_accumulator_value=0,
  16. eps=1e-10,
  17. foreach: Optional[bool] = None,
  18. *,
  19. maximize: bool = False,
  20. differentiable: bool = False,
  21. ):
  22. if not 0.0 <= lr:
  23. raise ValueError("Invalid learning rate: {}".format(lr))
  24. if not 0.0 <= lr_decay:
  25. raise ValueError("Invalid lr_decay value: {}".format(lr_decay))
  26. if not 0.0 <= weight_decay:
  27. raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
  28. if not 0.0 <= initial_accumulator_value:
  29. raise ValueError(
  30. "Invalid initial_accumulator_value value: {}".format(
  31. initial_accumulator_value
  32. )
  33. )
  34. if not 0.0 <= eps:
  35. raise ValueError("Invalid epsilon value: {}".format(eps))
  36. defaults = dict(
  37. lr=lr,
  38. lr_decay=lr_decay,
  39. eps=eps,
  40. weight_decay=weight_decay,
  41. initial_accumulator_value=initial_accumulator_value,
  42. foreach=foreach,
  43. maximize=maximize,
  44. differentiable=differentiable,
  45. )
  46. super().__init__(params, defaults)
  47. for group in self.param_groups:
  48. for p in group["params"]:
  49. state = self.state[p]
  50. state["step"] = torch.tensor(0.0)
  51. init_value = (
  52. complex(initial_accumulator_value, initial_accumulator_value)
  53. if torch.is_complex(p)
  54. else initial_accumulator_value
  55. )
  56. state["sum"] = torch.full_like(
  57. p, init_value, memory_format=torch.preserve_format
  58. )
  59. def __setstate__(self, state):
  60. super().__setstate__(state)
  61. for group in self.param_groups:
  62. group.setdefault("foreach", None)
  63. group.setdefault("maximize", False)
  64. group.setdefault("differentiable", False)
  65. state_values = list(self.state.values())
  66. step_is_tensor = (len(state_values) != 0) and torch.is_tensor(
  67. state_values[0]["step"]
  68. )
  69. if not step_is_tensor:
  70. for s in state_values:
  71. s["step"] = torch.tensor(float(s["step"]))
  72. def share_memory(self):
  73. for group in self.param_groups:
  74. for p in group["params"]:
  75. state = self.state[p]
  76. state["sum"].share_memory_()
  77. def _init_group(self, group, params_with_grad, grads, state_sums, state_steps):
  78. has_sparse_grad = False
  79. for p in group["params"]:
  80. if p.grad is not None:
  81. if p.grad.is_sparse:
  82. has_sparse_grad = True
  83. params_with_grad.append(p)
  84. grads.append(p.grad)
  85. state = self.state[p]
  86. state_sums.append(state["sum"])
  87. state_steps.append(state["step"])
  88. return has_sparse_grad
  89. @_use_grad_for_differentiable
  90. def step(self, closure=None):
  91. """Performs a single optimization step.
  92. Args:
  93. closure (Callable, optional): A closure that reevaluates the model
  94. and returns the loss.
  95. """
  96. loss = None
  97. if closure is not None:
  98. with torch.enable_grad():
  99. loss = closure()
  100. for group in self.param_groups:
  101. params_with_grad = []
  102. grads = []
  103. state_sums = []
  104. state_steps = []
  105. has_sparse_grad = self._init_group(group, params_with_grad, grads, state_sums, state_steps)
  106. adagrad(
  107. params_with_grad,
  108. grads,
  109. state_sums,
  110. state_steps,
  111. lr=group["lr"],
  112. weight_decay=group["weight_decay"],
  113. lr_decay=group["lr_decay"],
  114. eps=group["eps"],
  115. has_sparse_grad=has_sparse_grad,
  116. foreach=group["foreach"],
  117. maximize=group["maximize"],
  118. differentiable=group["differentiable"],
  119. )
  120. return loss
  121. Adagrad.__doc__ = r"""Implements Adagrad algorithm.
  122. .. math::
  123. \begin{aligned}
  124. &\rule{110mm}{0.4pt} \\
  125. &\textbf{input} : \gamma \text{ (lr)}, \: \theta_0 \text{ (params)}, \: f(\theta)
  126. \text{ (objective)}, \: \lambda \text{ (weight decay)}, \\
  127. &\hspace{12mm} \tau \text{ (initial accumulator value)}, \: \eta\text{ (lr decay)}\\
  128. &\textbf{initialize} : state\_sum_0 \leftarrow 0 \\[-1.ex]
  129. &\rule{110mm}{0.4pt} \\
  130. &\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do} \\
  131. &\hspace{5mm}g_t \leftarrow \nabla_{\theta} f_t (\theta_{t-1}) \\
  132. &\hspace{5mm} \tilde{\gamma} \leftarrow \gamma / (1 +(t-1) \eta) \\
  133. &\hspace{5mm} \textbf{if} \: \lambda \neq 0 \\
  134. &\hspace{10mm} g_t \leftarrow g_t + \lambda \theta_{t-1} \\
  135. &\hspace{5mm}state\_sum_t \leftarrow state\_sum_{t-1} + g^2_t \\
  136. &\hspace{5mm}\theta_t \leftarrow
  137. \theta_{t-1}- \tilde{\gamma} \frac{g_t}{\sqrt{state\_sum_t}+\epsilon} \\
  138. &\rule{110mm}{0.4pt} \\[-1.ex]
  139. &\bf{return} \: \theta_t \\[-1.ex]
  140. &\rule{110mm}{0.4pt} \\[-1.ex]
  141. \end{aligned}
  142. For further details regarding the algorithm we refer to `Adaptive Subgradient Methods for Online Learning
  143. and Stochastic Optimization`_.
  144. """ + r"""
  145. Args:
  146. params (iterable): iterable of parameters to optimize or dicts defining
  147. parameter groups
  148. lr (float, optional): learning rate (default: 1e-2)
  149. lr_decay (float, optional): learning rate decay (default: 0)
  150. weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
  151. eps (float, optional): term added to the denominator to improve
  152. numerical stability (default: 1e-10)
  153. {foreach}
  154. {maximize}
  155. {differentiable}
  156. .. _Adaptive Subgradient Methods for Online Learning and Stochastic
  157. Optimization: http://jmlr.org/papers/v12/duchi11a.html
  158. """.format(foreach=_foreach_doc, maximize=_maximize_doc, differentiable=_differentiable_doc)
  159. def adagrad(
  160. params: List[Tensor],
  161. grads: List[Tensor],
  162. state_sums: List[Tensor],
  163. state_steps: List[Tensor],
  164. # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627
  165. # setting these as kwargs for now as functional API is compiled by torch/distributed/optim
  166. has_sparse_grad: bool = None,
  167. foreach: Optional[bool] = None,
  168. differentiable: bool = False,
  169. *,
  170. lr: float,
  171. weight_decay: float,
  172. lr_decay: float,
  173. eps: float,
  174. maximize: bool,
  175. ):
  176. r"""Functional API that performs Adagrad algorithm computation.
  177. See :class:`~torch.optim.Adagrad` for details.
  178. """
  179. if not all(isinstance(t, torch.Tensor) for t in state_steps):
  180. raise RuntimeError(
  181. "API has changed, `state_steps` argument must contain a list of singleton tensors"
  182. )
  183. if foreach is None:
  184. _, foreach = _default_to_fused_or_foreach(params, differentiable, use_fused=False)
  185. if foreach and torch.jit.is_scripting():
  186. raise RuntimeError("torch.jit.script not supported with foreach optimizers")
  187. if foreach and not torch.jit.is_scripting():
  188. func = _multi_tensor_adagrad
  189. else:
  190. func = _single_tensor_adagrad
  191. func(
  192. params,
  193. grads,
  194. state_sums,
  195. state_steps,
  196. lr=lr,
  197. weight_decay=weight_decay,
  198. lr_decay=lr_decay,
  199. eps=eps,
  200. has_sparse_grad=has_sparse_grad,
  201. maximize=maximize,
  202. differentiable=differentiable,
  203. )
  204. def _make_sparse(grad, grad_indices, values):
  205. size = grad.size()
  206. if grad_indices.numel() == 0 or values.numel() == 0:
  207. return torch.empty_like(grad)
  208. return torch.sparse_coo_tensor(grad_indices, values, size)
  209. def _single_tensor_adagrad(
  210. params: List[Tensor],
  211. grads: List[Tensor],
  212. state_sums: List[Tensor],
  213. state_steps: List[Tensor],
  214. *,
  215. lr: float,
  216. weight_decay: float,
  217. lr_decay: float,
  218. eps: float,
  219. has_sparse_grad: bool,
  220. maximize: bool,
  221. differentiable: bool,
  222. ):
  223. for (param, grad, state_sum, step_t) in zip(params, grads, state_sums, state_steps):
  224. # update step
  225. step_t += 1
  226. step = _get_value(step_t)
  227. grad = grad if not maximize else -grad
  228. if weight_decay != 0:
  229. if grad.is_sparse:
  230. raise RuntimeError(
  231. "weight_decay option is not compatible with sparse gradients"
  232. )
  233. grad = grad.add(param, alpha=weight_decay)
  234. clr = lr / (1 + (step - 1) * lr_decay)
  235. if grad.is_sparse:
  236. grad = grad.coalesce() # the update is non-linear so indices must be unique
  237. grad_indices = grad._indices()
  238. grad_values = grad._values()
  239. state_sum.add_(_make_sparse(grad, grad_indices, grad_values.pow(2)))
  240. std = state_sum.sparse_mask(grad)
  241. std_values = std._values().sqrt_().add_(eps)
  242. param.add_(
  243. _make_sparse(grad, grad_indices, grad_values / std_values), alpha=-clr
  244. )
  245. else:
  246. is_complex = torch.is_complex(param)
  247. if is_complex:
  248. grad = torch.view_as_real(grad)
  249. state_sum = torch.view_as_real(state_sum)
  250. param = torch.view_as_real(param)
  251. state_sum.addcmul_(grad, grad, value=1)
  252. if differentiable:
  253. std = state_sum.sqrt() + eps
  254. else:
  255. std = state_sum.sqrt().add_(eps)
  256. param.addcdiv_(grad, std, value=-clr)
  257. if is_complex:
  258. param = torch.view_as_complex(param)
  259. state_sum = torch.view_as_complex(state_sum)
  260. def _multi_tensor_adagrad(
  261. params: List[Tensor],
  262. grads: List[Tensor],
  263. state_sums: List[Tensor],
  264. state_steps: List[Tensor],
  265. *,
  266. lr: float,
  267. weight_decay: float,
  268. lr_decay: float,
  269. eps: float,
  270. has_sparse_grad: bool,
  271. maximize: bool,
  272. differentiable: bool,
  273. ):
  274. assert not differentiable, "_foreach ops don't support autograd"
  275. # Foreach functions will throw errors if given empty lists
  276. if len(params) == 0:
  277. return
  278. grouped_tensorlists = _group_tensors_by_device_and_dtype([params, grads, state_sums, state_steps])
  279. for device_params, device_grads, device_state_sums, device_state_steps in grouped_tensorlists.values():
  280. if maximize:
  281. device_grads = torch._foreach_neg(device_grads)
  282. device_has_sparse_grad = any(grad.is_sparse for grad in device_grads)
  283. if device_has_sparse_grad:
  284. return _single_tensor_adagrad(
  285. device_params,
  286. device_grads,
  287. device_state_sums,
  288. device_state_steps,
  289. lr=lr,
  290. weight_decay=weight_decay,
  291. lr_decay=lr_decay,
  292. eps=eps,
  293. has_sparse_grad=True,
  294. maximize=False,
  295. differentiable=differentiable,
  296. )
  297. # Update steps
  298. torch._foreach_add_(device_state_steps, 1)
  299. if weight_decay != 0:
  300. device_grads = torch._foreach_add(device_grads, device_params, alpha=weight_decay)
  301. minus_clr = [-lr / (1 + (step - 1) * lr_decay) for step in device_state_steps]
  302. device_grads = [torch.view_as_real(x) if torch.is_complex(x) else x for x in device_grads]
  303. device_state_sums = [
  304. torch.view_as_real(x) if torch.is_complex(x) else x for x in device_state_sums
  305. ]
  306. torch._foreach_addcmul_(device_state_sums, device_grads, device_grads, value=1)
  307. std = torch._foreach_add(torch._foreach_sqrt(device_state_sums), eps)
  308. toAdd = torch._foreach_div(torch._foreach_mul(device_grads, minus_clr), std)
  309. toAdd = [
  310. torch.view_as_complex(x) if torch.is_complex(device_params[i]) else x
  311. for i, x in enumerate(toAdd)
  312. ]
  313. torch._foreach_add_(device_params, toAdd)
  314. device_state_sums = [
  315. torch.view_as_complex(x) if torch.is_complex(device_params[i]) else x
  316. for i, x in enumerate(device_state_sums)
  317. ]