rprop.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324
  1. import torch
  2. from torch import Tensor
  3. from .optimizer import (Optimizer, _use_grad_for_differentiable, _default_to_fused_or_foreach,
  4. _differentiable_doc, _foreach_doc, _maximize_doc)
  5. from typing import List, Optional
  6. from torch.utils._foreach_utils import _group_tensors_by_device_and_dtype
# Public API of this module: the Rprop optimizer class and its functional form.
__all__ = ["Rprop", "rprop"]
  8. class Rprop(Optimizer):
  9. def __init__(
  10. self,
  11. params,
  12. lr=1e-2,
  13. etas=(0.5, 1.2),
  14. step_sizes=(1e-6, 50),
  15. *,
  16. foreach: Optional[bool] = None,
  17. maximize: bool = False,
  18. differentiable: bool = False,
  19. ):
  20. if not 0.0 <= lr:
  21. raise ValueError("Invalid learning rate: {}".format(lr))
  22. if not 0.0 < etas[0] < 1.0 < etas[1]:
  23. raise ValueError("Invalid eta values: {}, {}".format(etas[0], etas[1]))
  24. defaults = dict(
  25. lr=lr,
  26. etas=etas,
  27. step_sizes=step_sizes,
  28. foreach=foreach,
  29. maximize=maximize,
  30. differentiable=differentiable,
  31. )
  32. super().__init__(params, defaults)
  33. def __setstate__(self, state):
  34. super().__setstate__(state)
  35. for group in self.param_groups:
  36. group.setdefault("foreach", None)
  37. group.setdefault("maximize", False)
  38. group.setdefault("differentiable", False)
  39. def _init_group(self, group, params, grads, prevs, step_sizes):
  40. for p in group["params"]:
  41. if p.grad is None:
  42. continue
  43. params.append(p)
  44. grad = p.grad
  45. if grad.is_sparse:
  46. raise RuntimeError("Rprop does not support sparse gradients")
  47. grads.append(grad)
  48. state = self.state[p]
  49. # State initialization
  50. if len(state) == 0:
  51. state["step"] = 0
  52. state["prev"] = torch.zeros_like(
  53. p, memory_format=torch.preserve_format
  54. )
  55. if p.dtype.is_complex:
  56. # Complex Number should be as if they are two independent real numbers.
  57. # Hence the step_size shouldn't be zero for imaginary part.
  58. state["step_size"] = (
  59. grad.new()
  60. .resize_as_(grad)
  61. .fill_(complex(group["lr"], group["lr"]))
  62. )
  63. else:
  64. state["step_size"] = (
  65. grad.new().resize_as_(grad).fill_(group["lr"])
  66. )
  67. prevs.append(state["prev"])
  68. step_sizes.append(state["step_size"])
  69. state["step"] += 1
  70. @_use_grad_for_differentiable
  71. def step(self, closure=None):
  72. """Performs a single optimization step.
  73. Args:
  74. closure (Callable, optional): A closure that reevaluates the model
  75. and returns the loss.
  76. """
  77. loss = None
  78. if closure is not None:
  79. with torch.enable_grad():
  80. loss = closure()
  81. for group in self.param_groups:
  82. params = []
  83. grads = []
  84. prevs = []
  85. step_sizes = []
  86. etaminus, etaplus = group["etas"]
  87. step_size_min, step_size_max = group["step_sizes"]
  88. foreach = group["foreach"]
  89. maximize = group["maximize"]
  90. self._init_group(group, params, grads, prevs, step_sizes)
  91. rprop(
  92. params,
  93. grads,
  94. prevs,
  95. step_sizes,
  96. step_size_min=step_size_min,
  97. step_size_max=step_size_max,
  98. etaminus=etaminus,
  99. etaplus=etaplus,
  100. foreach=foreach,
  101. maximize=maximize,
  102. differentiable=group["differentiable"],
  103. )
  104. return loss
  105. Rprop.__doc__ = r"""Implements the resilient backpropagation algorithm.
  106. .. math::
  107. \begin{aligned}
  108. &\rule{110mm}{0.4pt} \\
  109. &\textbf{input} : \theta_0 \in \mathbf{R}^d \text{ (params)},f(\theta)
  110. \text{ (objective)}, \\
  111. &\hspace{13mm} \eta_{+/-} \text{ (etaplus, etaminus)}, \Gamma_{max/min}
  112. \text{ (step sizes)} \\
  113. &\textbf{initialize} : g^0_{prev} \leftarrow 0,
  114. \: \eta_0 \leftarrow \text{lr (learning rate)} \\
  115. &\rule{110mm}{0.4pt} \\
  116. &\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do} \\
  117. &\hspace{5mm}g_t \leftarrow \nabla_{\theta} f_t (\theta_{t-1}) \\
  118. &\hspace{5mm} \textbf{for} \text{ } i = 0, 1, \ldots, d-1 \: \mathbf{do} \\
  119. &\hspace{10mm} \textbf{if} \: g^i_{prev} g^i_t > 0 \\
  120. &\hspace{15mm} \eta^i_t \leftarrow \mathrm{min}(\eta^i_{t-1} \eta_{+},
  121. \Gamma_{max}) \\
  122. &\hspace{10mm} \textbf{else if} \: g^i_{prev} g^i_t < 0 \\
  123. &\hspace{15mm} \eta^i_t \leftarrow \mathrm{max}(\eta^i_{t-1} \eta_{-},
  124. \Gamma_{min}) \\
  125. &\hspace{15mm} g^i_t \leftarrow 0 \\
  126. &\hspace{10mm} \textbf{else} \: \\
  127. &\hspace{15mm} \eta^i_t \leftarrow \eta^i_{t-1} \\
  128. &\hspace{5mm}\theta_t \leftarrow \theta_{t-1}- \eta_t \mathrm{sign}(g_t) \\
  129. &\hspace{5mm}g_{prev} \leftarrow g_t \\
  130. &\rule{110mm}{0.4pt} \\[-1.ex]
  131. &\bf{return} \: \theta_t \\[-1.ex]
  132. &\rule{110mm}{0.4pt} \\[-1.ex]
  133. \end{aligned}
  134. For further details regarding the algorithm we refer to the paper
  135. `A Direct Adaptive Method for Faster Backpropagation Learning: The RPROP Algorithm
  136. <http://citeseerx.ist.psu.edu/viewdoc/summary?doi=10.1.1.21.1417>`_.
  137. """ + r"""
  138. Args:
  139. params (iterable): iterable of parameters to optimize or dicts defining
  140. parameter groups
  141. lr (float, optional): learning rate (default: 1e-2)
  142. etas (Tuple[float, float], optional): pair of (etaminus, etaplus), that
  143. are multiplicative increase and decrease factors
  144. (default: (0.5, 1.2))
  145. step_sizes (Tuple[float, float], optional): a pair of minimal and
  146. maximal allowed step sizes (default: (1e-6, 50))
  147. {foreach}
  148. {maximize}
  149. {differentiable}
  150. """.format(foreach=_foreach_doc, maximize=_maximize_doc, differentiable=_differentiable_doc)
def rprop(
    params: List[Tensor],
    grads: List[Tensor],
    prevs: List[Tensor],
    step_sizes: List[Tensor],
    # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627
    # setting this as kwarg for now as functional API is compiled by torch/distributed/optim
    foreach: Optional[bool] = None,
    maximize: bool = False,
    differentiable: bool = False,
    *,
    step_size_min: float,
    step_size_max: float,
    etaminus: float,
    etaplus: float,
):
    r"""Functional API that performs rprop algorithm computation.

    See :class:`~torch.optim.Rprop` for details.

    Dispatches to the multi-tensor (``_foreach``) or the single-tensor
    implementation; both update ``params``, ``prevs`` and ``step_sizes`` in place.

    Args:
        params: parameters to update.
        grads: gradients, one per entry of ``params``.
        prevs: per-parameter buffers holding the previous step's gradient
            (with sign-flipped entries zeroed); overwritten with this step's.
        step_sizes: per-parameter element-wise step sizes.
        foreach: select the ``_foreach`` implementation. ``None`` lets
            ``_default_to_fused_or_foreach`` choose.
        maximize: maximize the objective instead of minimizing it.
        differentiable: forwarded to the selected implementation.
        step_size_min: lower clamp bound for the step sizes.
        step_size_max: upper clamp bound for the step sizes.
        etaminus: step-size multiplier used where the gradient sign flipped.
        etaplus: step-size multiplier used where the gradient sign persisted.

    Raises:
        RuntimeError: if the foreach implementation is selected while running
            under TorchScript.
    """

    # Resolve the implementation when the caller did not choose one.
    if foreach is None:
        _, foreach = _default_to_fused_or_foreach(params, differentiable, use_fused=False)

    if foreach and torch.jit.is_scripting():
        raise RuntimeError("torch.jit.script not supported with foreach optimizers")

    # NOTE(review): the `not torch.jit.is_scripting()` test is redundant at runtime
    # (the raise above already covers it); presumably it lets TorchScript treat the
    # foreach branch as dead code -- keep it as written.
    if foreach and not torch.jit.is_scripting():
        func = _multi_tensor_rprop
    else:
        func = _single_tensor_rprop

    func(
        params,
        grads,
        prevs,
        step_sizes,
        step_size_min=step_size_min,
        step_size_max=step_size_max,
        etaminus=etaminus,
        etaplus=etaplus,
        maximize=maximize,
        differentiable=differentiable,
    )
  190. def _single_tensor_rprop(
  191. params: List[Tensor],
  192. grads: List[Tensor],
  193. prevs: List[Tensor],
  194. step_sizes: List[Tensor],
  195. *,
  196. step_size_min: float,
  197. step_size_max: float,
  198. etaminus: float,
  199. etaplus: float,
  200. maximize: bool,
  201. differentiable: bool,
  202. ):
  203. for i, param in enumerate(params):
  204. grad = grads[i]
  205. grad = grad if not maximize else -grad
  206. prev = prevs[i]
  207. step_size = step_sizes[i]
  208. if torch.is_complex(param):
  209. grad = torch.view_as_real(grad)
  210. prev = torch.view_as_real(prev)
  211. param = torch.view_as_real(param)
  212. step_size = torch.view_as_real(step_size)
  213. if differentiable:
  214. sign = grad.mul(prev.clone()).sign()
  215. else:
  216. sign = grad.mul(prev).sign()
  217. sign[sign.gt(0)] = etaplus
  218. sign[sign.lt(0)] = etaminus
  219. sign[sign.eq(0)] = 1
  220. # update stepsizes with step size updates
  221. step_size.mul_(sign).clamp_(step_size_min, step_size_max)
  222. # for dir<0, dfdx=0
  223. # for dir>=0 dfdx=dfdx
  224. grad = grad.clone(memory_format=torch.preserve_format)
  225. grad[sign.eq(etaminus)] = 0
  226. # update parameters
  227. param.addcmul_(grad.sign(), step_size, value=-1)
  228. prev.copy_(grad)
  229. def _multi_tensor_rprop(
  230. params: List[Tensor],
  231. grads: List[Tensor],
  232. prevs: List[Tensor],
  233. step_sizes: List[Tensor],
  234. *,
  235. step_size_min: float,
  236. step_size_max: float,
  237. etaminus: float,
  238. etaplus: float,
  239. maximize: bool,
  240. differentiable: bool,
  241. ):
  242. if len(params) == 0:
  243. return
  244. assert not differentiable, "_foreach ops don't support autograd"
  245. grouped_tensors = _group_tensors_by_device_and_dtype([params, grads, prevs, step_sizes])
  246. for grouped_params, grouped_grads, grouped_prevs, grouped_step_sizes in grouped_tensors.values():
  247. # Handle complex params
  248. def _view_complex_as_real(tensor_list):
  249. return [
  250. torch.view_as_real(t) if torch.is_complex(t) else t for t in tensor_list
  251. ]
  252. grouped_grads = _view_complex_as_real(grouped_grads)
  253. grouped_prevs = _view_complex_as_real(grouped_prevs)
  254. grouped_params = _view_complex_as_real(grouped_params)
  255. grouped_step_sizes = _view_complex_as_real(grouped_step_sizes)
  256. if maximize:
  257. grouped_grads = torch._foreach_neg(grouped_grads)
  258. signs = torch._foreach_mul(grouped_grads, grouped_prevs)
  259. signs = [s.sign() for s in signs]
  260. for sign in signs:
  261. sign[sign.gt(0)] = etaplus
  262. sign[sign.lt(0)] = etaminus
  263. sign[sign.eq(0)] = 1
  264. # update stepsizes with step size updates
  265. torch._foreach_mul_(grouped_step_sizes, signs)
  266. for step_size in grouped_step_sizes:
  267. step_size.clamp_(step_size_min, step_size_max)
  268. # for dir<0, dfdx=0
  269. # for dir>=0 dfdx=dfdx
  270. grouped_grads = list(grouped_grads)
  271. for i in range(len(grouped_grads)):
  272. grouped_grads[i] = grouped_grads[i].clone(memory_format=torch.preserve_format)
  273. grouped_grads[i][signs[i].eq(etaminus)] = 0
  274. # update parameters
  275. grad_signs = [grad.sign() for grad in grouped_grads]
  276. torch._foreach_addcmul_(grouped_params, grad_signs, grouped_step_sizes, value=-1)
  277. for i in range(len(grouped_prevs)):
  278. grouped_prevs[i].copy_(grouped_grads[i])