utils.py

from typing import Type

from torch import optim

from .functional_adadelta import _FunctionalAdadelta
from .functional_adagrad import _FunctionalAdagrad
from .functional_adam import _FunctionalAdam
from .functional_adamax import _FunctionalAdamax
from .functional_adamw import _FunctionalAdamW
from .functional_rmsprop import _FunctionalRMSprop
from .functional_rprop import _FunctionalRprop
from .functional_sgd import _FunctionalSGD

# Dict that maps a user-passed optimizer class to its functional counterpart,
# for the optimizers already defined inside the distributed.optim package.
# This lets us hide the functional optimizers from users while still
# providing the same API.
functional_optim_map = {
    optim.Adagrad: _FunctionalAdagrad,
    optim.Adam: _FunctionalAdam,
    optim.AdamW: _FunctionalAdamW,
    optim.SGD: _FunctionalSGD,
    optim.Adadelta: _FunctionalAdadelta,
    optim.RMSprop: _FunctionalRMSprop,
    optim.Rprop: _FunctionalRprop,
    optim.Adamax: _FunctionalAdamax,
}
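
# Usage sketch (illustrative only, not part of this module's API): resolving
# the functional counterpart of a stock optimizer class through the map above.
#
# >>> functional_cls = functional_optim_map[optim.Adam]
# >>> functional_cls is _FunctionalAdam
# True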


def register_functional_optim(key, optim):
    """
    Interface to insert a new functional optimizer into ``functional_optim_map``.

    ``key`` and ``optim`` are user defined; neither needs to be a
    :class:`torch.optim.Optimizer` (e.g. for custom optimizers).

    Example::
        >>> # import the new functional optimizer
        >>> # xdoctest: +SKIP
        >>> from xyz import fn_optimizer
        >>> from torch.distributed.optim.utils import register_functional_optim
        >>> fn_optim_key = "XYZ_optim"
        >>> register_functional_optim(fn_optim_key, fn_optimizer)
    """
    # Do not overwrite an existing registration for this key.
    if key not in functional_optim_map:
        functional_optim_map[key] = optim
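
# A fuller registration sketch (``_FunctionalFooOptim`` is a hypothetical
# class, not part of this module): any key/value pair can be registered, as
# long as downstream code knows how to construct and step the optimizer.
#
# >>> class _FunctionalFooOptim:  # hypothetical custom functional optimizer
# ...     def __init__(self, params, lr=1e-3, _allow_empty_param_list=False):
# ...         self.lr = lr
# >>> register_functional_optim("foo_optim", _FunctionalFooOptim)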


def as_functional_optim(optim_cls: Type, *args, **kwargs):
    """Create the functional counterpart of ``optim_cls``, constructed with
    the given hyperparameters; raises if no counterpart is registered."""
    try:
        functional_cls = functional_optim_map[optim_cls]
    except KeyError as e:
        raise ValueError(
            f"Optimizer {optim_cls} does not have a functional counterpart!"
        ) from e

    return _create_functional_optim(functional_cls, *args, **kwargs)
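
# A minimal conversion sketch (hyperparameters are illustrative; they are
# forwarded unchanged to the functional constructor):
#
# >>> fn_sgd = as_functional_optim(optim.SGD, 1e-2, momentum=0.9)  # -> _FunctionalSGD
# >>> # fn_sgd is built with an empty parameter list; distributed wrappers
# >>> # feed it parameters and gradients explicitly at step time.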


def _create_functional_optim(functional_optim_cls: Type, *args, **kwargs):
    # Construct with an empty parameter list; callers supply parameters and
    # gradients explicitly at step time, so empty lists must be allowed.
    return functional_optim_cls(
        [],
        *args,
        **kwargs,
        _allow_empty_param_list=True,
    )