from typing import Type

from torch import optim

from .functional_adadelta import _FunctionalAdadelta
from .functional_adagrad import _FunctionalAdagrad
from .functional_adam import _FunctionalAdam
from .functional_adamax import _FunctionalAdamax
from .functional_adamw import _FunctionalAdamW
from .functional_rmsprop import _FunctionalRMSprop
from .functional_rprop import _FunctionalRprop
from .functional_sgd import _FunctionalSGD

# Dict that maps a user-passed-in optimizer class to its functional
# counterpart already defined inside the distributed.optim package.
# This is so that we can hide the functional optimizer from the user
# while still providing the same API.
functional_optim_map = {
    optim.Adagrad: _FunctionalAdagrad,
    optim.Adam: _FunctionalAdam,
    optim.AdamW: _FunctionalAdamW,
    optim.SGD: _FunctionalSGD,
    optim.Adadelta: _FunctionalAdadelta,
    optim.RMSprop: _FunctionalRMSprop,
    optim.Rprop: _FunctionalRprop,
    optim.Adamax: _FunctionalAdamax,
}
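
# A minimal lookup sketch (doctest-style comments, not executed at import
# time): the map is keyed by the stock optimizer class and yields its
# functional counterpart.
#
#   >>> functional_optim_map[optim.Adam] is _FunctionalAdam
#   True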


def register_functional_optim(key, optim):
    """
    Interface to insert a new functional optimizer into ``functional_optim_map``.

    ``key`` and ``optim`` are user defined; the key need not be a
    :class:`torch.optim.Optimizer` subclass (e.g. it can be a string key
    for a custom optimizer).

    Example::
        >>> # import the new functional optimizer
        >>> # xdoctest: +SKIP
        >>> from xyz import fn_optimizer
        >>> from torch.distributed.optim.utils import register_functional_optim
        >>> fn_optim_key = "XYZ_optim"
        >>> register_functional_optim(fn_optim_key, fn_optimizer)
    """
    # Do not overwrite an existing registration for the same key.
    if key not in functional_optim_map:
        functional_optim_map[key] = optim


def as_functional_optim(optim_cls: Type, *args, **kwargs):
    try:
        functional_cls = functional_optim_map[optim_cls]
    except KeyError as e:
        raise ValueError(
            f"Optimizer {optim_cls} does not have a functional counterpart!"
        ) from e
    return _create_functional_optim(functional_cls, *args, **kwargs)
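

# Usage sketch (hedged; assumes torch is importable): constructor arguments
# are forwarded to the functional twin of a stock optimizer class. Classes
# missing from the map (e.g. optim.LBFGS) raise a ValueError.
#
#   >>> fn_sgd = as_functional_optim(optim.SGD, lr=0.1)
#   >>> isinstance(fn_sgd, _FunctionalSGD)
#   True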


def _create_functional_optim(functional_optim_cls: Type, *args, **kwargs):
    # Construct with an empty parameter list: parameters and gradients are
    # supplied to the functional optimizer at step time, so the constructor
    # must be allowed to see no params up front.
    return functional_optim_cls(
        [],
        *args,
        **kwargs,
        _allow_empty_param_list=True,
    )
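

# Sketch of the intended flow (assumption: distributed callers feed
# parameters and gradients at step time, e.g. via a per-parameter
# ``step_param`` call, rather than at construction):
#
#   >>> fn_adam = as_functional_optim(optim.Adam, lr=1e-3)
#   >>> # later, apply a gradient to one parameter at a time:
#   >>> # fn_adam.step_param(param, grad)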