from collections import OrderedDict, defaultdict, abc as container_abcs
import torch
from copy import deepcopy
from itertools import chain
import warnings
import functools
import math
from typing import Callable, Dict, List, Tuple
import torch.utils.hooks as hooks
from torch.utils.hooks import RemovableHandle
from torch._utils import is_compiling

__all__ = ['Optimizer', 'register_optimizer_step_pre_hook', 'register_optimizer_step_post_hook']

_global_optimizer_pre_hooks: Dict[int, Callable] = OrderedDict()
_global_optimizer_post_hooks: Dict[int, Callable] = OrderedDict()
_foreach_supported_types = [torch.Tensor, torch.nn.parameter.Parameter]


class _RequiredParameter:
    """Singleton class representing a required parameter for an Optimizer."""
    def __repr__(self):
        return "<required parameter>"

required = _RequiredParameter()


def _use_grad_for_differentiable(func):
    def _use_grad(self, *args, **kwargs):
        prev_grad = torch.is_grad_enabled()
        try:
            torch.set_grad_enabled(self.defaults['differentiable'])
            ret = func(self, *args, **kwargs)
        finally:
            torch.set_grad_enabled(prev_grad)
        return ret
    return _use_grad


def _get_value(x):
    # item is significantly faster than a cpu tensor in eager mode
    if not torch.jit.is_scripting() and is_compiling():
        return x
    else:
        return x.item()


def _stack_if_compiling(x):
    if not torch.jit.is_scripting() and is_compiling():
        return torch.stack(x)
    else:
        return x


def _dispatch_sqrt(x: float):  # float annotation is needed because of torchscript type inference
    if not torch.jit.is_scripting() and isinstance(x, torch.Tensor):
        return x.sqrt()
    else:
        return math.sqrt(x)


# For any optimizer with a faster implementation, we attempt to default to the
# fastest + stablest whenever possible. For foreach, the requirements are to have
# native params all on CUDA. For fused, there's currently the additional requirement
# that the tensors' dtypes must be floating point. Neither alternative supports
# torch.jit.script nor differentiable, so we fall back to the single tensor
# implementation in those cases.
def _default_to_fused_or_foreach(params: List[torch.Tensor],
                                 differentiable: bool,
                                 use_fused: bool = False) -> Tuple[bool, bool]:
    if torch.jit.is_scripting() or differentiable:
        return False, False
    fused = use_fused and all(
        p is None or (type(p) in _foreach_supported_types and p.is_cuda and torch.is_floating_point(p)) for p in params
    )
    foreach = not fused and all(
        p is None or (type(p) in _foreach_supported_types and p.is_cuda) for p in params
    )
    return fused, foreach
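

# Illustrative sketch of the heuristic above (not part of this module's behavior):
# assuming a CUDA machine, CUDA tensors enable the foreach path while CPU tensors
# fall back to the for-loop path. ``params`` below is a toy list made up for the example.
#
#     >>> params = [torch.randn(2, 2, device="cuda") for _ in range(3)]
#     >>> _default_to_fused_or_foreach(params, differentiable=False, use_fused=False)
#     (False, True)
#     >>> _default_to_fused_or_foreach([torch.randn(2, 2)], differentiable=False)
#     (False, False)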

# Common doc strings among optimizers
_foreach_doc = r"""foreach (bool, optional): whether foreach implementation of optimizer
            is used. If unspecified by the user (so foreach is None), we will try to use
            foreach over the for-loop implementation on CUDA, since it is usually
            significantly more performant. (default: None)"""

_fused_doc = r"""fused (bool, optional): whether the fused implementation (CUDA only) is used.
            Currently, `torch.float64`, `torch.float32`, `torch.float16`, and `torch.bfloat16`
            are supported. (default: None)

    .. note:: The foreach and fused implementations are typically faster than the for-loop,
              single-tensor implementation. Thus, if the user has not specified BOTH flags
              (i.e., when foreach = fused = None), we will attempt defaulting to the foreach
              implementation when the tensors are all on CUDA. For example, if the user specifies
              True for fused but nothing for foreach, we will run the fused implementation. If
              the user specifies False for foreach but nothing for fused (or False for fused but
              nothing for foreach), we will run the for-loop implementation. If the user specifies
              True for both foreach and fused, we will prioritize fused over foreach, as it is
              typically faster. We attempt to use the fastest, so the hierarchy goes fused ->
              foreach -> for-loop. HOWEVER, since the fused implementation is relatively new,
              we want to give it sufficient bake-in time, so we default to foreach and NOT
              fused when the user has not specified either flag."""

_capturable_doc = r"""capturable (bool, optional): whether this instance is safe to
            capture in a CUDA graph. Passing True can impair ungraphed performance,
            so if you don't intend to graph capture this instance, leave it False
            (default: False)"""

_differentiable_doc = r"""differentiable (bool, optional): whether autograd should
            occur through the optimizer step in training. Otherwise, the step()
            function runs in a torch.no_grad() context. Setting to True can impair
            performance, so leave it False if you don't intend to run autograd
            through this instance (default: False)"""

_maximize_doc = r"""maximize (bool, optional): maximize the params based on the
            objective, instead of minimizing (default: False)"""
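

# Illustrative sketch of how these flags surface on a concrete optimizer such as
# torch.optim.Adam (the toy model below is an assumption for the example, not part
# of this file):
#
#     >>> model = torch.nn.Linear(4, 2)
#     >>> # leave foreach/fused as None to let the defaulting logic pick the path
#     >>> opt = torch.optim.Adam(model.parameters(), lr=1e-3, foreach=None, maximize=False)
#     >>> # opting into the fused CUDA kernels explicitly would require CUDA params:
#     >>> # opt = torch.optim.Adam(model.cuda().parameters(), lr=1e-3, fused=True)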


def register_optimizer_step_pre_hook(hook: Callable[..., None]) -> RemovableHandle:
    r"""Register a pre hook common to all optimizers. The hook should have the following
    signature::

        hook(optimizer, args, kwargs) -> None or modified args and kwargs

    Args:
        hook (Callable): A user defined hook which is registered on all optimizers.

    Returns:
        :class:`torch.utils.hooks.RemovableHandle`:
            a handle that can be used to remove the added hook by calling
            ``handle.remove()``
    """
    handle = hooks.RemovableHandle(_global_optimizer_pre_hooks)
    _global_optimizer_pre_hooks[handle.id] = hook
    return handle


def register_optimizer_step_post_hook(hook: Callable[..., None]) -> RemovableHandle:
    r"""Register a post hook common to all optimizers. The hook should have the following
    signature::

        hook(optimizer, args, kwargs) -> None

    Args:
        hook (Callable): A user defined hook which is registered on all optimizers.

    Returns:
        :class:`torch.utils.hooks.RemovableHandle`:
            a handle that can be used to remove the added hook by calling
            ``handle.remove()``
    """
    handle = hooks.RemovableHandle(_global_optimizer_post_hooks)
    _global_optimizer_post_hooks[handle.id] = hook
    return handle
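

# Illustrative sketch of the global hooks above (the SGD instance and the toy
# parameter are assumptions for the example, not part of this module):
#
#     >>> def log_step(optimizer, args, kwargs):
#     ...     print(f"about to step {type(optimizer).__name__}")
#     >>> pre_handle = register_optimizer_step_pre_hook(log_step)
#     >>> post_handle = register_optimizer_step_post_hook(
#     ...     lambda optimizer, args, kwargs: print("stepped"))
#     >>> p = torch.zeros(3, requires_grad=True)
#     >>> p.grad = torch.ones_like(p)
#     >>> torch.optim.SGD([p], lr=0.1).step()  # both hooks fire around this call
#     >>> pre_handle.remove(); post_handle.remove()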


class Optimizer:
    r"""Base class for all optimizers.

    .. warning::
        Parameters need to be specified as collections that have a deterministic
        ordering that is consistent between runs. Examples of objects that don't
        satisfy those properties are sets and iterators over values of dictionaries.

    Args:
        params (iterable): an iterable of :class:`torch.Tensor` s or
            :class:`dict` s. Specifies what Tensors should be optimized.
        defaults (dict): a dict containing default values of optimization
            options (used when a parameter group doesn't specify them).
    """

    def __init__(self, params, defaults):
        torch._C._log_api_usage_once("python.optimizer")
        self.defaults = defaults
        self._optimizer_step_pre_hooks: Dict[int, Callable] = OrderedDict()
        self._optimizer_step_post_hooks: Dict[int, Callable] = OrderedDict()

        self._patch_step_function()

        if isinstance(params, torch.Tensor):
            raise TypeError("params argument given to the optimizer should be "
                            "an iterable of Tensors or dicts, but got " +
                            torch.typename(params))

        self.state = defaultdict(dict)
        self.param_groups = []

        param_groups = list(params)
        if len(param_groups) == 0:
            raise ValueError("optimizer got an empty parameter list")
        if not isinstance(param_groups[0], dict):
            param_groups = [{'params': param_groups}]

        for param_group in param_groups:
            self.add_param_group(param_group)

        # Allows _cuda_graph_capture_health_check to rig a poor man's TORCH_WARN_ONCE in python,
        # which I don't think exists
        # https://github.com/pytorch/pytorch/issues/72948
        self._warned_capturable_if_run_uncaptured = True

    def __getstate__(self):
        return {
            'defaults': self.defaults,
            'state': self.state,
            'param_groups': self.param_groups,
        }

    def __setstate__(self, state):
        self.__dict__.update(state)
        if '_optimizer_step_pre_hooks' not in self.__dict__:
            self._optimizer_step_pre_hooks = OrderedDict()
        if '_optimizer_step_post_hooks' not in self.__dict__:
            self._optimizer_step_post_hooks = OrderedDict()
        self._patch_step_function()  # To support multiprocessing pickle/unpickle
        self.defaults.setdefault('differentiable', False)

    def __repr__(self):
        format_string = self.__class__.__name__ + ' ('
        for i, group in enumerate(self.param_groups):
            format_string += '\n'
            format_string += 'Parameter Group {0}\n'.format(i)
            for key in sorted(group.keys()):
                if key != 'params':
                    format_string += ' {0}: {1}\n'.format(key, group[key])
        format_string += ')'
        return format_string

    # Currently needed by Adam and AdamW
    def _cuda_graph_capture_health_check(self):
        if torch.has_cuda and torch.cuda.is_available():
            capturing = torch.cuda.is_current_stream_capturing()

            if capturing and not all(group['capturable'] for group in self.param_groups):
                raise RuntimeError("Attempting CUDA graph capture of step() for an instance of " +
                                   self.__class__.__name__ +
                                   " but param_groups' capturable is False.")

            if (
                (not getattr(self, "_warned_capturable_if_run_uncaptured", False))
                and all(group['capturable'] for group in self.param_groups)
                and (not capturing)
            ):
                warnings.warn(
                    "This instance was constructed with capturable=True or some of the param_groups came with capturable=True, "
                    "but step() is running without CUDA graph capture. If you never intend to graph-capture this "
                    "instance, capturable=True can impair performance, and you should set capturable=False."
                )
                self._warned_capturable_if_run_uncaptured = True

    def _optimizer_step_code(self):
        """Entry point for `torch.profiler`.

        When python tracing is enabled the profiler will hook into this
        function at the CPython level to inspect the optimizer's parameters and
        param groups. It is called after `step()` since many optimizers
        lazily initialize state.

        This is a workaround due to lack of a proper step hook on the optimizer,
        and will be removed once such a hook exists.
        """
        pass

    @staticmethod
    def profile_hook_step(func):

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            self, *_ = args
            profile_name = "Optimizer.step#{}.step".format(self.__class__.__name__)
            with torch.autograd.profiler.record_function(profile_name):
                # call optimizer step pre hooks
                for pre_hook in chain(_global_optimizer_pre_hooks.values(), self._optimizer_step_pre_hooks.values()):
                    result = pre_hook(self, args, kwargs)
                    if result is not None:
                        if isinstance(result, tuple) and len(result) == 2:
                            args, kwargs = result
                        else:
                            raise RuntimeError(f"{func} must return None or a tuple of (new_args, new_kwargs), "
                                               f"but got {result}.")

                out = func(*args, **kwargs)
                self._optimizer_step_code()

                # call optimizer step post hooks
                for post_hook in chain(self._optimizer_step_post_hooks.values(), _global_optimizer_post_hooks.values()):
                    post_hook(self, args, kwargs)

                return out

        return wrapper

    def _patch_step_function(self):
        self._zero_grad_profile_name = "Optimizer.zero_grad#{}.zero_grad".format(self.__class__.__name__)
        hooked = getattr(self.__class__.step, "hooked", None)
        if not hooked:
            self.__class__.step = self.profile_hook_step(self.__class__.step)
            self.__class__.step.hooked = True

    def register_step_pre_hook(self, hook: Callable[..., None]) -> RemovableHandle:
        r"""Register an optimizer step pre hook which will be called before
        optimizer step. It should have the following signature::

            hook(optimizer, args, kwargs) -> None or modified args and kwargs

        The ``optimizer`` argument is the optimizer instance being used. If
        args and kwargs are modified by the pre-hook, then the transformed
        values are returned as a tuple containing the new_args and new_kwargs.

        Args:
            hook (Callable): The user defined hook to be registered.

        Returns:
            :class:`torch.utils.hooks.RemovableHandle`:
                a handle that can be used to remove the added hook by calling
                ``handle.remove()``
        """
        handle = hooks.RemovableHandle(self._optimizer_step_pre_hooks)
        self._optimizer_step_pre_hooks[handle.id] = hook
        return handle

    def register_step_post_hook(self, hook: Callable[..., None]) -> RemovableHandle:
        r"""Register an optimizer step post hook which will be called after optimizer step.
        It should have the following signature::

            hook(optimizer, args, kwargs) -> None

        The ``optimizer`` argument is the optimizer instance being used.

        Args:
            hook (Callable): The user defined hook to be registered.

        Returns:
            :class:`torch.utils.hooks.RemovableHandle`:
                a handle that can be used to remove the added hook by calling
                ``handle.remove()``
        """
        handle = hooks.RemovableHandle(self._optimizer_step_post_hooks)
        self._optimizer_step_post_hooks[handle.id] = hook
        return handle
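
    # Illustrative sketch of per-instance hooks (the counter and the SGD instance
    # below are assumptions for the example, not part of this class):
    #
    #     >>> p = torch.zeros(2, requires_grad=True)
    #     >>> opt = torch.optim.SGD([p], lr=0.1)
    #     >>> step_count = {"n": 0}
    #     >>> def count_steps(optimizer, args, kwargs):
    #     ...     step_count["n"] += 1
    #     >>> handle = opt.register_step_post_hook(count_steps)
    #     >>> p.grad = torch.ones_like(p)
    #     >>> opt.step()
    #     >>> step_count["n"]
    #     1
    #     >>> handle.remove()  # detach the hook once it is no longer needed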

    def state_dict(self):
        r"""Returns the state of the optimizer as a :class:`dict`.

        It contains two entries:

        * state - a dict holding current optimization state. Its content
            differs between optimizer classes.
        * param_groups - a list containing all parameter groups where each
            parameter group is a dict
        """
        # Save order indices instead of Tensors
        param_mappings = {}
        start_index = 0

        def pack_group(group):
            nonlocal start_index
            packed = {k: v for k, v in group.items() if k != 'params'}
            param_mappings.update({id(p): i for i, p in enumerate(group['params'], start_index)
                                   if id(p) not in param_mappings})
            packed['params'] = [param_mappings[id(p)] for p in group['params']]
            start_index += len(packed['params'])
            return packed

        param_groups = [pack_group(g) for g in self.param_groups]
        # Remap state to use order indices as keys
        packed_state = {(param_mappings[id(k)] if isinstance(k, torch.Tensor) else k): v
                        for k, v in self.state.items()}
        return {
            'state': packed_state,
            'param_groups': param_groups,
        }
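
    # Illustrative sketch of the packed structure (the SGD instance below is an
    # assumption for the example; the contents of 'state' depend on the optimizer
    # and are empty until the first step):
    #
    #     >>> p = torch.zeros(2, requires_grad=True)
    #     >>> opt = torch.optim.SGD([p], lr=0.1, momentum=0.9)
    #     >>> sd = opt.state_dict()
    #     >>> sorted(sd.keys())
    #     ['param_groups', 'state']
    #     >>> sd['param_groups'][0]['params']  # parameters are stored as indices
    #     [0]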

    def load_state_dict(self, state_dict):
        r"""Loads the optimizer state.

        Args:
            state_dict (dict): optimizer state. Should be an object returned
                from a call to :meth:`state_dict`.
        """
        # deepcopy, to be consistent with module API
        state_dict = deepcopy(state_dict)
        # Validate the state_dict
        groups = self.param_groups
        saved_groups = state_dict['param_groups']

        if len(groups) != len(saved_groups):
            raise ValueError("loaded state dict has a different number of "
                             "parameter groups")
        param_lens = (len(g['params']) for g in groups)
        saved_lens = (len(g['params']) for g in saved_groups)
        if any(p_len != s_len for p_len, s_len in zip(param_lens, saved_lens)):
            raise ValueError("loaded state dict contains a parameter group "
                             "that doesn't match the size of optimizer's group")

        # Update the state
        id_map = {old_id: p for old_id, p in
                  zip(chain.from_iterable((g['params'] for g in saved_groups)),
                      chain.from_iterable((g['params'] for g in groups)))}

        def cast(param, value, key=None):
            r"""Make a deep copy of value, casting all tensors to device of param."""
            if isinstance(value, torch.Tensor):
                # Floating-point types are a bit special here. They are the only ones
                # that are assumed to always match the type of params.
                # Make sure state['step'] is not casted https://github.com/pytorch/pytorch/issues/74424
                if (key != "step"):
                    if param.is_floating_point():
                        value = value.to(param.dtype)
                    value = value.to(param.device)
                return value
            elif isinstance(value, dict):
                return {k: cast(param, v, key=k) for k, v in value.items()}
            elif isinstance(value, container_abcs.Iterable):
                return type(value)(cast(param, v) for v in value)
            else:
                return value

        # Copy state assigned to params (and cast tensors to appropriate types).
        # State that is not assigned to params is copied as is (needed for
        # backward compatibility).
        state = defaultdict(dict)
        for k, v in state_dict['state'].items():
            if k in id_map:
                param = id_map[k]
                state[param] = cast(param, v)
            else:
                state[k] = v

        # Update parameter groups, setting their 'params' value
        def update_group(group, new_group):
            new_group['params'] = group['params']
            return new_group
        param_groups = [
            update_group(g, ng) for g, ng in zip(groups, saved_groups)]
        self.__setstate__({'state': state, 'param_groups': param_groups})
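
    # Illustrative checkpoint round trip (the file name, model, and optimizer below
    # are assumptions for the example): the loaded state is re-keyed onto this
    # instance's params and tensors are cast/moved to match them.
    #
    #     >>> model = torch.nn.Linear(3, 1)
    #     >>> opt = torch.optim.Adam(model.parameters(), lr=1e-3)
    #     >>> torch.save(opt.state_dict(), "opt.pt")
    #     >>> fresh = torch.optim.Adam(model.parameters(), lr=1e-3)
    #     >>> fresh.load_state_dict(torch.load("opt.pt"))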

    def zero_grad(self, set_to_none: bool = True):
        r"""Sets the gradients of all optimized :class:`torch.Tensor` s to zero.

        Args:
            set_to_none (bool): instead of setting to zero, set the grads to None.
                This will in general have lower memory footprint, and can modestly improve performance.
                However, it changes certain behaviors. For example:
                1. When the user tries to access a gradient and perform manual ops on it,
                a None attribute or a Tensor full of 0s will behave differently.
                2. If the user requests ``zero_grad(set_to_none=True)`` followed by a backward pass, ``.grad``\ s
                are guaranteed to be None for params that did not receive a gradient.
                3. ``torch.optim`` optimizers have a different behavior if the gradient is 0 or None
                (in one case it does the step with a gradient of 0 and in the other it skips
                the step altogether).
        """
        foreach = self.defaults.get('foreach', False)

        if not hasattr(self, "_zero_grad_profile_name"):
            self._patch_step_function()
        if foreach:
            per_device_and_dtype_grads = defaultdict(lambda: defaultdict(list))
        with torch.autograd.profiler.record_function(self._zero_grad_profile_name):
            for group in self.param_groups:
                for p in group['params']:
                    if p.grad is not None:
                        if set_to_none:
                            p.grad = None
                        else:
                            if p.grad.grad_fn is not None:
                                p.grad.detach_()
                            else:
                                p.grad.requires_grad_(False)
                            if (not foreach or p.grad.is_sparse):
                                p.grad.zero_()
                            else:
                                per_device_and_dtype_grads[p.grad.device][p.grad.dtype].append(p.grad)
            if foreach:
                for _, per_dtype_grads in per_device_and_dtype_grads.items():
                    for grads in per_dtype_grads.values():
                        torch._foreach_zero_(grads)
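
    # Illustrative sketch of the two modes (the toy tensor and optimizer below are
    # assumptions for the example):
    #
    #     >>> p = torch.zeros(2, requires_grad=True)
    #     >>> opt = torch.optim.SGD([p], lr=0.1)
    #     >>> p.grad = torch.ones_like(p)
    #     >>> opt.zero_grad()                   # default: set_to_none=True
    #     >>> p.grad is None
    #     True
    #     >>> p.grad = torch.ones_like(p)
    #     >>> opt.zero_grad(set_to_none=False)  # keep the tensor, fill it with zeros
    #     >>> p.grad
    #     tensor([0., 0.])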

    def step(self, closure):
        r"""Performs a single optimization step (parameter update).

        Args:
            closure (Callable): A closure that reevaluates the model and
                returns the loss. Optional for most optimizers.

        .. note::
            Unless otherwise specified, this function should not modify the
            ``.grad`` field of the parameters.
        """
        raise NotImplementedError
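
    # Illustrative closure-based training step (the model, data, and loss below are
    # assumptions for the example; optimizers such as LBFGS require the closure):
    #
    #     >>> model = torch.nn.Linear(4, 1)
    #     >>> x, y = torch.randn(8, 4), torch.randn(8, 1)
    #     >>> opt = torch.optim.LBFGS(model.parameters(), lr=0.1)
    #     >>> def closure():
    #     ...     opt.zero_grad()
    #     ...     loss = torch.nn.functional.mse_loss(model(x), y)
    #     ...     loss.backward()
    #     ...     return loss
    #     >>> opt.step(closure)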

    def add_param_group(self, param_group):
        r"""Add a param group to the :class:`Optimizer` s `param_groups`.

        This can be useful when fine tuning a pre-trained network as frozen layers can be made
        trainable and added to the :class:`Optimizer` as training progresses.

        Args:
            param_group (dict): Specifies what Tensors should be optimized along with group
                specific optimization options.
        """
        assert isinstance(param_group, dict), "param group must be a dict"

        params = param_group['params']
        if isinstance(params, torch.Tensor):
            param_group['params'] = [params]
        elif isinstance(params, set):
            raise TypeError('optimizer parameters need to be organized in ordered collections, but '
                            'the ordering of tensors in sets will change between runs. Please use a list instead.')
        else:
            param_group['params'] = list(params)

        for param in param_group['params']:
            if not isinstance(param, torch.Tensor):
                raise TypeError("optimizer can only optimize Tensors, "
                                "but one of the params is " + torch.typename(param))
            if not self.defaults.get('differentiable', None) and not (param.is_leaf or param.retains_grad):
                raise ValueError("can't optimize a non-leaf Tensor")

        for name, default in self.defaults.items():
            if default is required and name not in param_group:
                raise ValueError("parameter group didn't specify a value of required optimization parameter " +
                                 name)
            else:
                param_group.setdefault(name, default)

        params = param_group['params']
        if len(params) != len(set(params)):
            warnings.warn("optimizer contains a parameter group with duplicate parameters; "
                          "in future, this will cause an error; "
                          "see github.com/pytorch/pytorch/issues/40967 for more information", stacklevel=3)

        param_set = set()
        for group in self.param_groups:
            param_set.update(set(group['params']))

        if not param_set.isdisjoint(set(param_group['params'])):
            raise ValueError("some parameters appear in more than one parameter group")

        self.param_groups.append(param_group)
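
    # Illustrative sketch of growing an optimizer during fine-tuning (the backbone
    # and head modules below are assumptions for the example): a layer left out at
    # construction time can be added later with its own hyperparameters.
    #
    #     >>> head = torch.nn.Linear(8, 2)
    #     >>> backbone = torch.nn.Linear(16, 8)
    #     >>> opt = torch.optim.SGD(head.parameters(), lr=1e-2)
    #     >>> opt.add_param_group({'params': backbone.parameters(), 'lr': 1e-4})
    #     >>> len(opt.param_groups)
    #     2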
|