from typing import Any, Dict, Iterable, List, no_type_check, Type

import torch

__all__: List[str] = []

# WeakTensorKeyDictionary to store relevant metadata for the Tensor/Parameter
# without changing its lifetime.
# NOTE: An alternative is to add the metadata as an attribute of the tensor,
# but that would serialize the metadata whenever the Tensor is serialized.
param_to_optim_hook_handle_map = torch.utils.weak.WeakTensorKeyDictionary()
param_to_acc_grad_map = torch.utils.weak.WeakTensorKeyDictionary()
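# param_to_optim_hook_handle_map: parameter -> list of hook handles registered
# on its AccumulateGrad node (one appended per applied optimizer).
# param_to_acc_grad_map: parameter -> cached AccumulateGrad node, so shared
# parameters and repeated calls reuse the same autograd node.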


@no_type_check
def _apply_optimizer_in_backward(
    optimizer_class: Type[torch.optim.Optimizer],
    params: Iterable[torch.nn.Parameter],
    optimizer_kwargs: Dict[str, Any],
) -> None:
- """
- Upon ``backward()``, parameters will fire the corresponding optimizer.
- Note - gradients for these parameters will be set to None after ``backward()``.
- This means that any other (non applied) optimizer over this parameter will be
- a no-op.
- Args:
- optimizer_class: (Type[torch.optim.Optimizer]): Optimizer to apply to parameter
- params: (Iterator[nn.Parameter]): parameters to apply optimizer state to
- optimizer_kwargs: (Dict[str, Any]): kwargs to pass to optimizer constructor
- Example::
- params_generator = model.parameters()
- param_1 = next(params_generator)
- remainder_params = list(params_generator)
- apply_optimizer_in_backward(torch.optim.SGD, [param_1], {"lr": .02})
- apply_optimizer_in_backward(torch.optim.Adam, remainder_params, {"lr": .04})
- model(...).sum().backward() # after backward, parameters will already
- # have their registered optimizer applied.
- """
    @no_type_check
    def _apply_optimizer_in_backward_to_param(param: torch.nn.Parameter) -> None:
        # view_as creates a node in the autograd graph that allows us access to the
        # parameter's AccumulateGrad autograd function object. We register a
        # hook on this object to fire the optimizer when the gradient for
        # this parameter is ready (has been accumulated into the .grad field).

        # Don't create a new acc_grad if we already have one,
        # i.e. for shared parameters or attaching multiple optimizers to a param.
        if param not in param_to_acc_grad_map:
            param_to_acc_grad_map[param] = param.view_as(param).grad_fn.next_functions[0][0]

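        # Each parameter gets its own single-parameter optimizer instance, so it
        # can be stepped as soon as that parameter's gradient is ready, without
        # waiting for the rest of the model's gradients.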
        optimizer = optimizer_class([param], **optimizer_kwargs)

        if not hasattr(param, "_in_backward_optimizers"):
            param._in_backward_optimizers = []  # type: ignore[attr-defined]
            # TODO: investigate whether we really need these attributes.
            param._optimizer_classes = []  # type: ignore[attr-defined]
            param._optimizer_kwargs = []  # type: ignore[attr-defined]

        param._in_backward_optimizers.append(optimizer)  # type: ignore[attr-defined]
        param._optimizer_classes.append(optimizer_class)  # type: ignore[attr-defined]
        param._optimizer_kwargs.append(optimizer_kwargs)  # type: ignore[attr-defined]

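        # The hook below runs once this parameter's gradient has been accumulated:
        # it steps every optimizer registered for the parameter and then clears
        # .grad, so a subsequent optimizer.step() outside of backward sees no
        # gradient for this parameter and leaves it untouched.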
        def optimizer_hook(*_unused) -> None:
            for opt in param._in_backward_optimizers:  # type: ignore[attr-defined]
                opt.step()

            param.grad = None

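        # Register the hook on the parameter's AccumulateGrad node and keep the
        # returned handle in a module-level map so the hook can later be
        # de-registered if needed.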
        handle = param_to_acc_grad_map[param].register_hook(optimizer_hook)  # type: ignore[attr-defined]
        if param not in param_to_optim_hook_handle_map:
            param_to_optim_hook_handle_map[param] = []
        param_to_optim_hook_handle_map[param].append(handle)

    for param in params:
        _apply_optimizer_in_backward_to_param(param)