from typing import Any, Dict, Iterable, List, no_type_check, Type

import torch

__all__: List[str] = []

# WeakTensorKeyDictionary to store relevant metadata for the Tensor/Parameter
# without changing its lifetime.
# NOTE: An alternative is to add the metadata as an attribute to the tensor,
# but that would serialize the metadata whenever the Tensor is serialized.
param_to_optim_hook_handle_map = torch.utils.weak.WeakTensorKeyDictionary()
param_to_acc_grad_map = torch.utils.weak.WeakTensorKeyDictionary()


@no_type_check
def _apply_optimizer_in_backward(
    optimizer_class: Type[torch.optim.Optimizer],
    params: Iterable[torch.nn.Parameter],
    optimizer_kwargs: Dict[str, Any],
) -> None:
    """
    Upon ``backward()``, the given parameters will fire their corresponding
    optimizer.

    Note: gradients for these parameters will be set to None after
    ``backward()``, so any other (non-applied) optimizer over these parameters
    will be a no-op.

    Args:
        optimizer_class (Type[torch.optim.Optimizer]): Optimizer to apply to each parameter
        params (Iterable[torch.nn.Parameter]): parameters to apply the optimizer to
        optimizer_kwargs (Dict[str, Any]): kwargs to pass to the optimizer constructor

    Example::
        params_generator = model.parameters()
        param_1 = next(params_generator)
        remainder_params = list(params_generator)

        _apply_optimizer_in_backward(torch.optim.SGD, [param_1], {"lr": .02})
        _apply_optimizer_in_backward(torch.optim.Adam, remainder_params, {"lr": .04})

        model(...).sum().backward()  # after backward, parameters will already
                                     # have their registered optimizer(s) applied.
    """

    @no_type_check
    def _apply_optimizer_in_backward_to_param(param: torch.nn.Parameter) -> None:
        # view_as creates a node in the autograd graph that gives us access to
        # the parameter's AccumulateGrad autograd function object. We register
        # a hook on this object to fire the optimizer when the gradient for
        # this parameter is ready (i.e. has been accumulated into its .grad
        # field).

        # Don't create a new acc_grad if we already have one, e.g. for shared
        # parameters or when attaching multiple optimizers to a parameter.
        if param not in param_to_acc_grad_map:
            param_to_acc_grad_map[param] = param.view_as(param).grad_fn.next_functions[0][0]

        optimizer = optimizer_class([param], **optimizer_kwargs)

        if not hasattr(param, "_in_backward_optimizers"):
            param._in_backward_optimizers = []  # type: ignore[attr-defined]
            # TODO: investigate whether we really need these attributes.
            param._optimizer_classes = []  # type: ignore[attr-defined]
            param._optimizer_kwargs = []  # type: ignore[attr-defined]

        param._in_backward_optimizers.append(optimizer)  # type: ignore[attr-defined]
        param._optimizer_classes.append(optimizer_class)  # type: ignore[attr-defined]
        param._optimizer_kwargs.append(optimizer_kwargs)  # type: ignore[attr-defined]

        def optimizer_hook(*_unused) -> None:
            # Step every optimizer registered for this parameter, then free the
            # gradient so that no other optimizer can consume it.
            for opt in param._in_backward_optimizers:  # type: ignore[attr-defined]
                opt.step()

            param.grad = None

        handle = param_to_acc_grad_map[param].register_hook(optimizer_hook)  # type: ignore[attr-defined]
        if param not in param_to_optim_hook_handle_map:
            param_to_optim_hook_handle_map[param] = []
        param_to_optim_hook_handle_map[param].append(handle)

    for param in params:
        _apply_optimizer_in_backward_to_param(param)
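

# A minimal, self-contained usage sketch (the toy ``torch.nn.Linear`` model and
# tensor shapes below are illustrative assumptions, not part of this module):
# each parameter's optimizer steps as soon as its gradient has been accumulated
# during ``backward()``, after which ``.grad`` is set back to None.
if __name__ == "__main__":
    model = torch.nn.Linear(8, 4)  # hypothetical toy model for illustration
    _apply_optimizer_in_backward(torch.optim.SGD, model.parameters(), {"lr": 0.02})

    loss = model(torch.randn(2, 8)).sum()
    loss.backward()  # optimizer_hook fires per parameter during backward

    # Gradients were consumed and cleared by the in-backward optimizers.
    assert all(p.grad is None for p in model.parameters())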