optimizer.py 23 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537
  1. from collections import OrderedDict, defaultdict, abc as container_abcs
  2. import torch
  3. from copy import deepcopy
  4. from itertools import chain
  5. import warnings
  6. import functools
  7. import math
  8. from typing import Callable, Dict, List, Tuple
  9. import torch.utils.hooks as hooks
  10. from torch.utils.hooks import RemovableHandle
  11. from torch._utils import is_compiling
  # Names exported by ``from <module> import *``.
  __all__ = [
      'Optimizer',
      'register_optimizer_step_pre_hook',
      'register_optimizer_step_post_hook',
  ]

  # Process-wide step hooks shared by every optimizer instance, keyed by the id
  # of the RemovableHandle returned at registration time.
  _global_optimizer_pre_hooks: Dict[int, Callable] = OrderedDict()
  _global_optimizer_post_hooks: Dict[int, Callable] = OrderedDict()

  # Exact tensor types (matched with ``type(p) in ...``, so subclasses do not
  # count) that are eligible for the foreach / fused defaults.
  _foreach_supported_types = [torch.Tensor, torch.nn.parameter.Parameter]
  class _RequiredParameter:
      """Sentinel type marking an optimizer option that has no default value."""

      def __repr__(self):
          return "<required parameter>"


  # Single shared instance; compared by identity (``default is required``).
  required = _RequiredParameter()
  21. def _use_grad_for_differentiable(func):
  22. def _use_grad(self, *args, **kwargs):
  23. prev_grad = torch.is_grad_enabled()
  24. try:
  25. torch.set_grad_enabled(self.defaults['differentiable'])
  26. ret = func(self, *args, **kwargs)
  27. finally:
  28. torch.set_grad_enabled(prev_grad)
  29. return ret
  30. return _use_grad
  31. def _get_value(x):
  32. # item is significantly faster than a cpu tensor in eager mode
  33. if not torch.jit.is_scripting() and is_compiling():
  34. return x
  35. else:
  36. return x.item()
  37. def _stack_if_compiling(x):
  38. if not torch.jit.is_scripting() and is_compiling():
  39. return torch.stack(x)
  40. else:
  41. return x
  42. def _dispatch_sqrt(x: float): # float annotation is needed because of torchscript type inference
  43. if not torch.jit.is_scripting() and isinstance(x, torch.Tensor):
  44. return x.sqrt()
  45. else:
  46. return math.sqrt(x)
  47. # For any optimizer with a faster implementation, we attempt to default to the
  48. # fastest + stablest whenever possible. For foreach, the requirements are to have
  49. # native params all on CUDA. For fused, there's currently the additional requirement
  50. # that the tensors' dtypes must be floating point. Neither alternative supports
  51. # torch.jit.script nor differentiable, so we fall back to the single tensor
  52. # implementation in those cases.
  53. def _default_to_fused_or_foreach(params: List[torch.Tensor],
  54. differentiable: bool,
  55. use_fused: bool = False) -> Tuple[bool, bool]:
  56. if torch.jit.is_scripting() or differentiable:
  57. return False, False
  58. fused = use_fused and all(
  59. p is None or (type(p) in _foreach_supported_types and p.is_cuda and torch.is_floating_point(p)) for p in params
  60. )
  61. foreach = not fused and all(
  62. p is None or (type(p) in _foreach_supported_types and p.is_cuda) for p in params
  63. )
  64. return fused, foreach
  # NOTE(review): the leading "NN." markers and flattened indentation below look
  # like paste residue. The text is kept verbatim because whitespace inside these
  # raw string literals is runtime data and cannot be recovered from this view.
  65. # Common doc strings among optimizers
  # Shared description of the ``foreach`` option (presumably interpolated into
  # individual optimizer docstrings; not referenced in this section - confirm).
  66. _foreach_doc = r"""foreach (bool, optional): whether foreach implementation of optimizer
  67. is used. If unspecified by the user (so foreach is None), we will try to use
  68. foreach over the for-loop implementation on CUDA, since it is usually
  69. significantly more performant. (default: None)"""
  # Shared description of the ``fused`` option, including the note on how the
  # foreach / fused flags interact when left unspecified.
  70. _fused_doc = r"""fused (bool, optional): whether the fused implementation (CUDA only) is used.
  71. Currently, `torch.float64`, `torch.float32`, `torch.float16`, and `torch.bfloat16`
  72. are supported. (default: None)
  73. .. note:: The foreach and fused implementations are typically faster than the for-loop,
  74. single-tensor implementation. Thus, if the user has not specified BOTH flags
  75. (i.e., when foreach = fused = None), we will attempt defaulting to the foreach
  76. implementation when the tensors are all on CUDA. For example, if the user specifies
  77. True for fused but nothing for foreach, we will run the fused implementation. If
  78. the user specifies False for foreach but nothing for fused (or False for fused but
  79. nothing for foreach), we will run the for-loop implementation. If the user specifies
  80. True for both foreach and fused, we will prioritize fused over foreach, as it is
  81. typically faster. We attempt to use the fastest, so the hierarchy goes fused ->
  82. foreach -> for-loop. HOWEVER, since the fused implementation is relatively new,
  83. we want to give it sufficient bake-in time, so we default to foreach and NOT
  84. fused when the user has not specified either flag."""
  # Shared description of the ``capturable`` option (CUDA graph capture).
  85. _capturable_doc = r"""capturable (bool, optional): whether this instance is safe to
  86. capture in a CUDA graph. Passing True can impair ungraphed performance,
  87. so if you don't intend to graph capture this instance, leave it False
  88. (default: False)"""
  # Shared description of the ``differentiable`` option (autograd through step()).
  89. _differentiable_doc = r"""differentiable (bool, optional): whether autograd should
  90. occur through the optimizer step in training. Otherwise, the step()
  91. function runs in a torch.no_grad() context. Setting to True can impair
  92. performance, so leave it False if you don't intend to run autograd
  93. through this instance (default: False)"""
  # Shared description of the ``maximize`` option.
  94. _maximize_doc = r"""maximize (bool, optional): maximize the params based on the
  95. objective, instead of minimizing (default: False)"""
  def register_optimizer_step_pre_hook(hook: Callable[..., None]) -> RemovableHandle:
      r"""Register a pre hook that runs before ``step()`` of every optimizer.

      The hook should have the following signature::

          hook(optimizer, args, kwargs) -> None or modified args and kwargs

      Args:
          hook (Callable): A user defined hook which is registered on all optimizers.

      Returns:
          :class:`torch.utils.hooks.RemovableHandle`:
              a handle that can be used to remove the added hook by calling
              ``handle.remove()``
      """
      registration = hooks.RemovableHandle(_global_optimizer_pre_hooks)
      # The handle id is the registry key, so ``remove()`` can find the entry.
      _global_optimizer_pre_hooks[registration.id] = hook
      return registration
  def register_optimizer_step_post_hook(hook: Callable[..., None]) -> RemovableHandle:
      r"""Register a post hook that runs after ``step()`` of every optimizer.

      The hook should have the following signature::

          hook(optimizer, args, kwargs) -> None

      Args:
          hook (Callable): A user defined hook which is registered on all optimizers.

      Returns:
          :class:`torch.utils.hooks.RemovableHandle`:
              a handle that can be used to remove the added hook by calling
              ``handle.remove()``
      """
      registration = hooks.RemovableHandle(_global_optimizer_post_hooks)
      # The handle id is the registry key, so ``remove()`` can find the entry.
      _global_optimizer_post_hooks[registration.id] = hook
      return registration
  class Optimizer:
      r"""Base class for all optimizers.

      .. warning::
          Parameters need to be specified as collections that have a deterministic
          ordering that is consistent between runs. Examples of objects that don't
          satisfy those properties are sets and iterators over values of dictionaries.

      Args:
          params (iterable): an iterable of :class:`torch.Tensor` s or
              :class:`dict` s. Specifies what Tensors should be optimized.
          defaults: (dict): a dict containing default values of optimization
              options (used when a parameter group doesn't specify them).
      """

      def __init__(self, params, defaults):
          torch._C._log_api_usage_once("python.optimizer")
          self.defaults = defaults
          self._optimizer_step_pre_hooks: Dict[int, Callable] = OrderedDict()
          self._optimizer_step_post_hooks: Dict[int, Callable] = OrderedDict()

          self._patch_step_function()

          if isinstance(params, torch.Tensor):
              raise TypeError("params argument given to the optimizer should be "
                              "an iterable of Tensors or dicts, but got " +
                              torch.typename(params))

          self.state = defaultdict(dict)
          self.param_groups = []

          param_groups = list(params)
          if len(param_groups) == 0:
              raise ValueError("optimizer got an empty parameter list")
          # A flat iterable of tensors is treated as one single parameter group.
          if not isinstance(param_groups[0], dict):
              param_groups = [{'params': param_groups}]

          for param_group in param_groups:
              self.add_param_group(param_group)

          # Allows _cuda_graph_capture_health_check to rig a poor man's TORCH_WARN_ONCE in python,
          # which I don't think exists
          # https://github.com/pytorch/pytorch/issues/72948
          self._warned_capturable_if_run_uncaptured = True

      def __getstate__(self):
          """Return the picklable state: defaults, per-parameter state and param groups.

          Step hooks are intentionally not part of the pickled state.
          """
          return {
              'defaults': self.defaults,
              'state': self.state,
              'param_groups': self.param_groups,
          }

      def __setstate__(self, state):
          """Restore attributes from ``state`` and recreate what is not pickled."""
          self.__dict__.update(state)
          if '_optimizer_step_pre_hooks' not in self.__dict__:
              self._optimizer_step_pre_hooks = OrderedDict()
          if '_optimizer_step_post_hooks' not in self.__dict__:
              self._optimizer_step_post_hooks = OrderedDict()
          self._patch_step_function()  # To support multiprocessing pickle/unpickle
          self.defaults.setdefault('differentiable', False)

      def __repr__(self):
          format_string = self.__class__.__name__ + ' ('
          for i, group in enumerate(self.param_groups):
              format_string += '\n'
              format_string += 'Parameter Group {0}\n'.format(i)
              for key in sorted(group.keys()):
                  if key != 'params':
                      # NOTE(review): leading whitespace of this literal may have been
                      # collapsed in the pasted source - confirm against the canonical file.
                      format_string += ' {0}: {1}\n'.format(key, group[key])
          format_string += ')'
          return format_string

      # Currently needed by Adam and AdamW
      def _cuda_graph_capture_health_check(self):
          """Validate the ``capturable`` setting against the current CUDA capture state.

          Raises:
              RuntimeError: if the current stream is capturing while some param
                  group has ``capturable`` set to a falsy value.
          """
          if torch.has_cuda and torch.cuda.is_available():
              capturing = torch.cuda.is_current_stream_capturing()

              if capturing and not all(group['capturable'] for group in self.param_groups):
                  raise RuntimeError("Attempting CUDA graph capture of step() for an instance of " +
                                     self.__class__.__name__ +
                                     " but param_groups' capturable is False.")

              # Warn at most once per instance when capturable is on but unused.
              if (
                  (not getattr(self, "_warned_capturable_if_run_uncaptured", False))
                  and all(group['capturable'] for group in self.param_groups)
                  and (not capturing)
              ):
                  warnings.warn(
                      "This instance was constructed with capturable=True or some of all the param_groups came with capturable=True, "
                      "but step() is running without CUDA graph capture. If you never intend to graph-capture this "
                      "instance, capturable=True can impair performance, and you should set capturable=False."
                  )
                  self._warned_capturable_if_run_uncaptured = True

      def _optimizer_step_code(self):
          """Entry point for the PyTorch profiler (``torch.profiler``).

          When python tracing is enabled the profiler will hook into this
          function at the CPython level to inspect the optimizer's parameters and
          param groups. It is called after `step()` since many optimizers
          lazily initialize state.

          This is a workaround due to lack of a proper step hook on the optimizer,
          and will be removed if it exists.
          """
          pass

      @staticmethod
      def profile_hook_step(func):
          """Wrap ``func`` (a ``step`` implementation) with profiling and step hooks.

          The wrapper records a profiler range, runs global then per-instance
          pre hooks (each may replace ``args``/``kwargs``), calls ``func``, then
          runs per-instance and global post hooks.

          Raises (from the wrapper):
              RuntimeError: if a pre hook returns something other than ``None``
                  or a 2-tuple ``(new_args, new_kwargs)``.
          """
          @functools.wraps(func)
          def wrapper(*args, **kwargs):
              self, *_ = args
              profile_name = "Optimizer.step#{}.step".format(self.__class__.__name__)
              with torch.autograd.profiler.record_function(profile_name):
                  # call optimizer step pre hooks
                  for pre_hook in chain(_global_optimizer_pre_hooks.values(), self._optimizer_step_pre_hooks.values()):
                      result = pre_hook(self, args, kwargs)
                      if result is not None:
                          if isinstance(result, tuple) and len(result) == 2:
                              args, kwargs = result
                          else:
                              # Name the offending hook, not the wrapped step function.
                              raise RuntimeError(f"{pre_hook} must return None or a tuple of (new_args, new_kwargs), "
                                                 f"but got {result}.")

                  out = func(*args, **kwargs)
                  self._optimizer_step_code()

                  # call optimizer step post hooks
                  for post_hook in chain(self._optimizer_step_post_hooks.values(), _global_optimizer_post_hooks.values()):
                      post_hook(self, args, kwargs)

                  return out

          return wrapper

      def _patch_step_function(self):
          """Install the profiling/hook wrapper on the class's ``step`` once."""
          self._zero_grad_profile_name = "Optimizer.zero_grad#{}.zero_grad".format(self.__class__.__name__)
          # ``hooked`` marks an already wrapped ``step`` so it is not wrapped twice.
          hooked = getattr(self.__class__.step, "hooked", None)
          if not hooked:
              self.__class__.step = self.profile_hook_step(self.__class__.step)
              self.__class__.step.hooked = True

      def register_step_pre_hook(self, hook: Callable[..., None]) -> RemovableHandle:
          r"""Register an optimizer step pre hook which will be called before
          optimizer step. It should have the following signature::

              hook(optimizer, args, kwargs) -> None or modified args and kwargs

          The ``optimizer`` argument is the optimizer instance being used. If
          args and kwargs are modified by the pre-hook, then the transformed
          values are returned as a tuple containing the new_args and new_kwargs.

          Args:
              hook (Callable): The user defined hook to be registered.

          Returns:
              :class:`torch.utils.hooks.RemovableHandle`:
                  a handle that can be used to remove the added hook by calling
                  ``handle.remove()``
          """
          handle = hooks.RemovableHandle(self._optimizer_step_pre_hooks)
          self._optimizer_step_pre_hooks[handle.id] = hook
          return handle

      def register_step_post_hook(self, hook: Callable[..., None]) -> RemovableHandle:
          r"""Register an optimizer step post hook which will be called after optimizer step.
          It should have the following signature::

              hook(optimizer, args, kwargs) -> None

          The ``optimizer`` argument is the optimizer instance being used.

          Args:
              hook (Callable): The user defined hook to be registered.

          Returns:
              :class:`torch.utils.hooks.RemovableHandle`:
                  a handle that can be used to remove the added hook by calling
                  ``handle.remove()``
          """
          handle = hooks.RemovableHandle(self._optimizer_step_post_hooks)
          self._optimizer_step_post_hooks[handle.id] = hook
          return handle

      def state_dict(self):
          r"""Returns the state of the optimizer as a :class:`dict`.

          It contains two entries:

          * state - a dict holding current optimization state. Its content
              differs between optimizer classes.
          * param_groups - a list containing all parameter groups where each
              parameter group is a dict
          """
          # Save order indices instead of Tensors
          param_mappings = {}
          start_index = 0

          def pack_group(group):
              nonlocal start_index
              packed = {k: v for k, v in group.items() if k != 'params'}
              # Indices continue across groups; a tensor keeps its first index.
              param_mappings.update({id(p): i for i, p in enumerate(group['params'], start_index)
                                     if id(p) not in param_mappings})
              packed['params'] = [param_mappings[id(p)] for p in group['params']]
              start_index += len(packed['params'])
              return packed
          param_groups = [pack_group(g) for g in self.param_groups]
          # Remap state to use order indices as keys
          packed_state = {(param_mappings[id(k)] if isinstance(k, torch.Tensor) else k): v
                          for k, v in self.state.items()}
          return {
              'state': packed_state,
              'param_groups': param_groups,
          }

      def load_state_dict(self, state_dict):
          r"""Loads the optimizer state.

          Args:
              state_dict (dict): optimizer state. Should be an object returned
                  from a call to :meth:`state_dict`.

          Raises:
              ValueError: if the number of parameter groups, or the size of any
                  group, differs from this optimizer's.
          """
          # deepcopy, to be consistent with module API
          state_dict = deepcopy(state_dict)
          # Validate the state_dict
          groups = self.param_groups
          saved_groups = state_dict['param_groups']

          if len(groups) != len(saved_groups):
              raise ValueError("loaded state dict has a different number of "
                               "parameter groups")
          param_lens = (len(g['params']) for g in groups)
          saved_lens = (len(g['params']) for g in saved_groups)
          if any(p_len != s_len for p_len, s_len in zip(param_lens, saved_lens)):
              raise ValueError("loaded state dict contains a parameter group "
                               "that doesn't match the size of optimizer's group")

          # Update the state
          id_map = {old_id: p for old_id, p in
                    zip(chain.from_iterable((g['params'] for g in saved_groups)),
                        chain.from_iterable((g['params'] for g in groups)))}

          def cast(param, value, key=None):
              r"""Make a deep copy of value, casting all tensors to device of param."""
              if isinstance(value, torch.Tensor):
                  # Floating-point types are a bit special here. They are the only ones
                  # that are assumed to always match the type of params.
                  # Make sure state['step'] is not casted https://github.com/pytorch/pytorch/issues/74424
                  if (key != "step"):
                      if param.is_floating_point():
                          value = value.to(param.dtype)
                      value = value.to(param.device)
                  return value
              elif isinstance(value, dict):
                  return {k: cast(param, v, key=k) for k, v in value.items()}
              elif isinstance(value, (str, bytes)):
                  # str is Iterable, but ``str(<generator>)`` yields the generator's
                  # repr instead of the text; text/bytes hold no tensors, so keep as is.
                  return value
              elif isinstance(value, container_abcs.Iterable):
                  return type(value)(cast(param, v) for v in value)
              else:
                  return value

          # Copy state assigned to params (and cast tensors to appropriate types).
          # State that is not assigned to params is copied as is (needed for
          # backward compatibility).
          state = defaultdict(dict)
          for k, v in state_dict['state'].items():
              if k in id_map:
                  param = id_map[k]
                  state[param] = cast(param, v)
              else:
                  state[k] = v

          # Update parameter groups, setting their 'params' value
          def update_group(group, new_group):
              new_group['params'] = group['params']
              return new_group
          param_groups = [
              update_group(g, ng) for g, ng in zip(groups, saved_groups)]
          self.__setstate__({'state': state, 'param_groups': param_groups})

      def zero_grad(self, set_to_none: bool = True):
          r"""Sets the gradients of all optimized :class:`torch.Tensor` s to zero.

          Args:
              set_to_none (bool): instead of setting to zero, set the grads to None.
                  This will in general have lower memory footprint, and can modestly improve performance.
                  However, it changes certain behaviors. For example:
                  1. When the user tries to access a gradient and perform manual ops on it,
                  a None attribute or a Tensor full of 0s will behave differently.
                  2. If the user requests ``zero_grad(set_to_none=True)`` followed by a backward pass, ``.grad``\ s
                  are guaranteed to be None for params that did not receive a gradient.
                  3. ``torch.optim`` optimizers have a different behavior if the gradient is 0 or None
                  (in one case it does the step with a gradient of 0 and in the other it skips
                  the step altogether).
          """
          foreach = self.defaults.get('foreach', False)

          # Instances restored without this attribute get it (re)created here.
          if not hasattr(self, "_zero_grad_profile_name"):
              self._patch_step_function()
          if foreach:
              # Dense grads are bucketed by device and dtype for one batched zero_ each.
              per_device_and_dtype_grads = defaultdict(lambda: defaultdict(list))
          with torch.autograd.profiler.record_function(self._zero_grad_profile_name):
              for group in self.param_groups:
                  for p in group['params']:
                      if p.grad is not None:
                          if set_to_none:
                              p.grad = None
                          else:
                              if p.grad.grad_fn is not None:
                                  p.grad.detach_()
                              else:
                                  p.grad.requires_grad_(False)
                              if (not foreach or p.grad.is_sparse):
                                  p.grad.zero_()
                              else:
                                  per_device_and_dtype_grads[p.grad.device][p.grad.dtype].append(p.grad)
              if foreach:
                  for _, per_dtype_grads in per_device_and_dtype_grads.items():
                      for grads in per_dtype_grads.values():
                          torch._foreach_zero_(grads)

      def step(self, closure):
          r"""Performs a single optimization step (parameter update).

          Args:
              closure (Callable): A closure that reevaluates the model and
                  returns the loss. Optional for most optimizers.

          .. note::
              Unless otherwise specified, this function should not modify the
              ``.grad`` field of the parameters.
          """
          raise NotImplementedError

      def add_param_group(self, param_group):
          r"""Add a param group to the :class:`Optimizer` s `param_groups`.

          This can be useful when fine tuning a pre-trained network as frozen layers can be made
          trainable and added to the :class:`Optimizer` as training progresses.

          Args:
              param_group (dict): Specifies what Tensors should be optimized along with group
                  specific optimization options.

          Raises:
              TypeError: if the params are given as a set or contain a non-Tensor.
              ValueError: if a param is a non-leaf Tensor (unless differentiable),
                  a required option is missing, or a param already belongs to
                  another group.
          """
          assert isinstance(param_group, dict), "param group must be a dict"

          # Normalize 'params' to a list, in place on the caller's dict.
          params = param_group['params']
          if isinstance(params, torch.Tensor):
              param_group['params'] = [params]
          elif isinstance(params, set):
              raise TypeError('optimizer parameters need to be organized in ordered collections, but '
                              'the ordering of tensors in sets will change between runs. Please use a list instead.')
          else:
              param_group['params'] = list(params)

          for param in param_group['params']:
              if not isinstance(param, torch.Tensor):
                  raise TypeError("optimizer can only optimize Tensors, "
                                  "but one of the params is " + torch.typename(param))
              if not self.defaults.get('differentiable', None) and not (param.is_leaf or param.retains_grad):
                  raise ValueError("can't optimize a non-leaf Tensor")

          # Fill in options the group did not specify from the optimizer defaults.
          for name, default in self.defaults.items():
              if default is required and name not in param_group:
                  raise ValueError("parameter group didn't specify a value of required optimization parameter " +
                                   name)
              else:
                  param_group.setdefault(name, default)

          params = param_group['params']
          if len(params) != len(set(params)):
              warnings.warn("optimizer contains a parameter group with duplicate parameters; "
                            "in future, this will cause an error; "
                            "see github.com/pytorch/pytorch/issues/40967 for more information", stacklevel=3)

          # A parameter may belong to at most one group.
          param_set = set()
          for group in self.param_groups:
              param_set.update(set(group['params']))

          if not param_set.isdisjoint(set(param_group['params'])):
              raise ValueError("some parameters appear in more than one parameter group")

          self.param_groups.append(param_group)