apply_optimizer_in_backward.py

from typing import Any, Dict, Iterable, List, no_type_check, Type

import torch

__all__: List[str] = []

# WeakTensorKeyDictionary stores the relevant metadata for a Tensor/Parameter
# without changing its lifetime.
# NOTE: An alternative is to add the metadata as an attribute on the tensor,
# but that would serialize the metadata whenever the Tensor is serialized.
param_to_optim_hook_handle_map = torch.utils.weak.WeakTensorKeyDictionary()
param_to_acc_grad_map = torch.utils.weak.WeakTensorKeyDictionary()
@no_type_check
def _apply_optimizer_in_backward(
    optimizer_class: Type[torch.optim.Optimizer],
    params: Iterable[torch.nn.Parameter],
    optimizer_kwargs: Dict[str, Any],
) -> None:
    """
    Upon ``backward()``, parameters will fire the corresponding optimizer.

    Note - gradients for these parameters will be set to None after ``backward()``.
    This means that any other optimizer not registered via this function will be a
    no-op for these parameters.

    Args:
        optimizer_class (Type[torch.optim.Optimizer]): Optimizer to apply to parameter
        params (Iterable[nn.Parameter]): parameters to apply optimizer state to
        optimizer_kwargs (Dict[str, Any]): kwargs to pass to optimizer constructor

    Example::
        params_generator = model.parameters()
        param_1 = next(params_generator)
        remainder_params = list(params_generator)

        apply_optimizer_in_backward(torch.optim.SGD, [param_1], {"lr": .02})
        apply_optimizer_in_backward(torch.optim.Adam, remainder_params, {"lr": .04})

        model(...).sum().backward()  # after backward, parameters will already
        # have their registered optimizer applied.
    """
    @no_type_check
    def _apply_optimizer_in_backward_to_param(param: torch.nn.Parameter) -> None:
        # view_as creates a node in the autograd graph that gives us access to the
        # parameter's AccumulateGrad autograd function object. We register a
        # hook on this object to fire the optimizer when the gradient for
        # this parameter is ready (has been accumulated into the .grad field).

        # Don't create a new acc_grad if we already have one,
        # i.e. for shared parameters or when attaching multiple optimizers to a param.
        if param not in param_to_acc_grad_map:
            param_to_acc_grad_map[param] = param.view_as(param).grad_fn.next_functions[0][0]

        optimizer = optimizer_class([param], **optimizer_kwargs)

        if not hasattr(param, "_in_backward_optimizers"):
            param._in_backward_optimizers = []  # type: ignore[attr-defined]
            # TODO: investigate whether we really need these attributes.
            param._optimizer_classes = []  # type: ignore[attr-defined]
            param._optimizer_kwargs = []  # type: ignore[attr-defined]

        param._in_backward_optimizers.append(optimizer)  # type: ignore[attr-defined]
        param._optimizer_classes.append(optimizer_class)  # type: ignore[attr-defined]
        param._optimizer_kwargs.append(optimizer_kwargs)  # type: ignore[attr-defined]

        def optimizer_hook(*_unused) -> None:
            # Step every optimizer registered for this parameter, then free the
            # gradient so any other optimizer sees it as None.
            for opt in param._in_backward_optimizers:  # type: ignore[attr-defined]
                opt.step()
            param.grad = None

        handle = param_to_acc_grad_map[param].register_hook(optimizer_hook)  # type: ignore[attr-defined]
        if param not in param_to_optim_hook_handle_map:
            param_to_optim_hook_handle_map[param] = []
        param_to_optim_hook_handle_map[param].append(handle)

    for param in params:
        _apply_optimizer_in_backward_to_param(param)
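

# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the original module). It shows one way
# to drive `_apply_optimizer_in_backward` on a toy nn.Linear model and observe
# that after `backward()` the registered optimizers have already stepped and
# `.grad` has been cleared. The model, shapes, and learning rates below are
# made up for the example.
if __name__ == "__main__":
    import torch.nn as nn

    model = nn.Linear(4, 2)

    # Run SGD in backward for the weight and Adam in backward for the bias.
    _apply_optimizer_in_backward(torch.optim.SGD, [model.weight], {"lr": 0.02})
    _apply_optimizer_in_backward(torch.optim.Adam, [model.bias], {"lr": 0.04})

    weight_before = model.weight.detach().clone()
    model(torch.randn(8, 4)).sum().backward()

    # The AccumulateGrad hooks stepped each optimizer and freed the gradients,
    # so the parameters are already updated and .grad is None.
    assert model.weight.grad is None and model.bias.grad is None
    print("weight changed in backward:", not torch.equal(weight_before, model.weight.detach()))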