import logging
import warnings
from copy import deepcopy
from typing import Any, Collection, Dict, List, Mapping, Optional, Union

import torch
import torch.nn as nn
from torch import optim
from torch.distributed._shard.sharded_tensor import ShardedTensor
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

__all__: List[str] = []

logger = logging.getLogger(__name__)


class _NamedOptimizer(optim.Optimizer):
    """
    ``_NamedOptimizer`` takes a dict of parameters and exposes ``state_dict`` by
    parameter key, replacing the original numeric keys of the wrapped optimizer
    with fully qualified name (FQN) strings. Users initialize it the same way they
    initialize a PyTorch optimizer; the only difference is that they also need to
    pass in the FQN of each parameter.

    Args:
        named_parameters (Mapping[str, Union[torch.Tensor, ShardedTensor]]):
            Mapping from FQN to parameter.
        optimizer_class (optim.Optimizer):
            The class of optimizer to instantiate.
        param_groups (Collection[Mapping[str, Any]]):
            ``param_groups`` to pass to the optimizer if specified.
            The key of the inner map needs to be FQNs.
            Default: None
        module (nn.Module): the module whose parameters are updated
            by the optimizer.
        args: arguments to pass to the optimizer constructor.
        kwargs: arguments to pass to the optimizer constructor.

    Example::
        >>> # xdoctest: +SKIP("distributed")
        >>> from torch import optim
        >>> from torch.distributed.optim import _NamedOptimizer
        >>>
        >>> # Define the named optimizer.
        >>> m = Model(...)
        >>> named_optim = _NamedOptimizer(m.named_parameters(), optim.SGD)
        >>> # Forward pass + backward pass.
        >>> named_optim.step()
        >>> ...
        >>> # Calling state_dict on the named optimizer returns an FQN-keyed state_dict.
        >>> named_optim.state_dict()
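        >>>
        >>> # A minimal sketch of passing explicit ``param_groups``; the submodule
        >>> # name (``m.net``) and learning rates below are illustrative only.
        >>> groups = [
        ...     {"params": [m.net.weight], "lr": 0.1},
        ...     {"params": [m.net.bias], "lr": 0.01},
        ... ]
        >>> named_optim = _NamedOptimizer(m.named_parameters(), optim.SGD, groups, m)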

    Warning: This API is still in development and subject to change.

    TODO: Add tutorial for _NamedOptimizer.
    TODO: Add documentation in the docstring for the public attributes
          like self.param_groups and self.named_parameters.
    """

    def __init__(
        self,
        named_parameters: Mapping[str, Union[torch.Tensor, ShardedTensor]],
        optimizer_class: optim.Optimizer,
        param_groups: Optional[Collection[Mapping[str, Any]]] = None,
        module: Optional[nn.Module] = None,
        *args,
        **kwargs,
    ) -> None:
        torch._C._log_api_usage_once("torch.distributed.optim._NamedOptimizer")
        self.param_groups: Collection[Mapping[str, Any]] = param_groups  # type: ignore[assignment]
        self._param_groups_check()
        self.named_parameters = dict(named_parameters)
        params_for_optimizer = (
            self.named_parameters.values() if param_groups is None else param_groups
        )
        self._optimizer = optimizer_class(  # type: ignore[operator]
            params_for_optimizer,
            *args,
            **kwargs,
        )
        self.module = module
        if param_groups is None:
            self.ordered_param_keys = list(self.named_parameters.keys())
        else:
            warnings.warn(
                "Since param_groups are passed in, the optimizer is initialized "
                "from param_groups rather than from all parameters of the module."
            )
            param_to_key = {param: key for key, param in self.named_parameters.items()}  # type: ignore[misc, has-type]
            ordered_param_keys = []
            for group in param_groups:
                for param in group["params"]:
                    if param not in param_to_key:
                        raise ValueError(
                            f"Expected param {param} from param_groups to be found in named_parameters, but it is missing."
                        )
                    ordered_param_keys.append(param_to_key[param])
            self.ordered_param_keys = ordered_param_keys
        # Update param_groups from the wrapped optimizer.
        self.param_groups = self._optimizer.param_groups

    def _param_groups_check(self):
        if self.param_groups is not None:
            for param_group in self.param_groups:
                assert isinstance(param_group, dict), "param group must be a dict"
                assert "params" in param_group, "param group must contain key params"
                params = param_group["params"]
                if isinstance(params, torch.Tensor):
                    params = [params]
                params = list(params)
                for param in params:
                    if not isinstance(param, torch.Tensor):
                        raise TypeError(
                            "optimizer can only optimize Tensors, "
                            "but one of the params is " + torch.typename(param)
                        )
                param_group["params"] = params

    def state_dict(self) -> Dict[str, Any]:
        """
        Return the ``state_dict`` of the optimizer. Instead of using numbers to index
        parameters, we use the module's fully qualified name (FQN) as the key.
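
        Example::
            >>> # xdoctest: +SKIP("distributed")
            >>> # A sketch of the returned structure; the FQNs ("net.weight",
            >>> # "net.bias") and the hyperparameters are illustrative only.
            >>> named_optim.state_dict()
            >>> # {"state": {"net.weight": {...}, "net.bias": {...}},
            >>> #  "param_groups": [{"params": ["net.bias", "net.weight"], "lr": 0.1}]}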
- """
- state_dict = self._optimizer.state_dict()
- param_groups = state_dict["param_groups"]
- ret_state = {
- self.ordered_param_keys[st_key]: state_val
- for st_key, state_val in state_dict["state"].items()
- }
- ret_groups = []
- for group in param_groups:
- param_keys = []
- for param in group["params"]:
- param_keys.append(self.ordered_param_keys[param])
- ret_group = {"params": sorted(param_keys)}
- for k, v in group.items():
- if k != "params":
- ret_group[k] = deepcopy(v)
- ret_groups.append(ret_group)
- return self._post_state_dict({"state": ret_state, "param_groups": ret_groups})

    def step(self, closure: Any = None) -> None:
        """
        Performs a single optimization step.

        This will call :meth:`torch.optim.Optimizer.step` on the wrapped
        optimizer.
        """
        self._optimizer.step(closure=closure)

    def load_state_dict(self, state_dict: Mapping[str, Any]) -> None:
        """
        Define the default behavior to load a ``state_dict`` for ``_NamedOptimizer``.

        Sample Code::
            >>> # xdoctest: +SKIP("distributed")
            >>> my_model = MyModule()
            >>> optimizer = _NamedOptimizer(my_model.named_parameters(), Adagrad)
            >>> ...
            >>> optim_state_dict = optimizer.state_dict()
            >>> ...
            >>> ...
            >>> optimizer.load_state_dict(optim_state_dict)
            >>> ...

        Args:
            state_dict (Dict[str, Any]) : A ``state_dict`` to load into the optimizer.
                Note that this state dict update is performed in place.

        .. note:: PyTorch uses lazy init to initialize the optimizer state, so it is
            possible that no optimizer state exists when ``load_state_dict`` is called.
            ``_NamedOptimizer`` is stricter: ``load_state_dict`` may only be called
            after the state has been initialized (see :meth:`init_state`), which makes
            it possible to validate the optimizer ``state_dict`` being loaded.
- """
- new_state_dict = self._optimizer.state_dict()
- state_dict = self._pre_load_state_dict(state_dict)
- state = state_dict["state"]
- new_state = new_state_dict["state"]
- if len(new_state) == 0:
- raise ValueError(
- "Expects the optim to be initialized before load but found not initialized."
- )
- for idx, param_key in enumerate(self.ordered_param_keys):
- # When the conditional training is performed, not all parameters are updated in the optim.
- if param_key not in state.keys():
- continue
- if len(state[param_key]) != len(new_state[idx]):
- raise ValueError(
- f"Expects equal length as {len(new_state[idx])} for parameter {param_key} but found: {len(state[param_key])}"
- )
- # Iterate through all optimizer states.
- for state_key, state_val in new_state[idx].items():
- if state_key not in state[param_key]:
- raise ValueError(
- f"Expects state {state_key} for parameter {param_key} but not found."
- )
- src_state_val = state[param_key][state_key]
- if isinstance(state_val, ShardedTensor):
- assert isinstance(src_state_val, ShardedTensor)
- num_shards = len(state_val.local_shards())
- num_new_shards = len(src_state_val.local_shards())
- if num_shards != num_new_shards:
- raise ValueError(
- f"Expects equal number of shards as {num_new_shards} but found {num_shards} for {param_key}/{state_key}"
- )
- for shard, src_shard in zip(
- state_val.local_shards(), src_state_val.local_shards()
- ):
- shard.tensor.detach().copy_(src_shard.tensor)
- elif isinstance(state_val, torch.Tensor):
- assert isinstance(src_state_val, torch.Tensor)
- state_val.detach().copy_(src_state_val)
- else:
- new_state[idx][state_key] = deepcopy(src_state_val)
- # Load param_groups of state_dict
- src_param_groups = state_dict["param_groups"]
- new_param_groups = new_state_dict["param_groups"]
- src_group_map = {}
- for group in src_param_groups:
- param_keys = []
- for param_key in group["params"]:
- param_keys.append(param_key)
- src_group_map[_gen_param_group_key(param_keys)] = group
- new_group_map = {}
- for new_group in new_param_groups:
- param_keys = []
- for param_key in new_group["params"]:
- param_keys.append(self.ordered_param_keys[param_key]) # type: ignore[call-overload]
- new_group_map[_gen_param_group_key(param_keys)] = new_group
- for group_key, new_group in new_group_map.items():
- # When not all parameters are used in training or receive gradient, aka., not all parameters
- # would be in the param_group. Thus we skip the group_key here.
- if group_key not in src_group_map:
- continue
- src_group = src_group_map[group_key]
- if len(src_group) != len(new_group):
- raise ValueError(
- f"Expects equal param_group size as {len(new_group)} for group {group_key} but found {len(src_group)}."
- )
- for k in src_group:
- if k not in new_group:
- raise ValueError(
- f"Expects group key {k} to be in group {group_key} in `state_dict` but is missing."
- )
- if k != "params":
- new_group[k] = deepcopy(src_group[k])
- self._optimizer.load_state_dict(new_state_dict)

    def add_param_group(self, param_group: Mapping[str, Any]) -> None:
        """
        Add a param group to the ``param_groups`` of the :class:`_NamedOptimizer`.

        Warning: This API is still in development and subject to change.
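
        Example::
            >>> # xdoctest: +SKIP("distributed")
            >>> # A sketch; assumes ``named_optim`` was constructed with explicit
            >>> # ``param_groups`` that did not yet include ``m.head.weight``.
            >>> # The submodule name and learning rate are illustrative only.
            >>> named_optim.add_param_group({"params": [m.head.weight], "lr": 0.01})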
- """
- assert isinstance(param_group, dict), "param group must be a dict"
- params = param_group["params"]
- if isinstance(params, torch.Tensor):
- param_group["params"] = [params]
- else:
- param_group["params"] = list(params)
- param_to_key = {param: key for key, param in self.named_parameters.items()} # type: ignore[misc, has-type]
- for param in param_group["params"]:
- if param not in param_to_key:
- raise ValueError("some parameters are not in the module")
- self.ordered_param_keys.append(param_to_key[param])
- self._optimizer.add_param_group(param_group)
- # Update param_groups from optimizer.
- self.param_groups = self._optimizer.param_groups

    def init_state(self) -> None:
        """
        Run a dummy optimizer step to initialize the optimizer state, since most
        optimizers create their state lazily.

        This allows optimizer state to be loaded in place from a checkpoint.
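
        Example::
            >>> # xdoctest: +SKIP("distributed")
            >>> # A sketch of the intended checkpoint-loading flow; ``my_model`` and
            >>> # ``saved_state_dict`` are assumed to exist and are illustrative only.
            >>> named_optim = _NamedOptimizer(my_model.named_parameters(), Adagrad)
            >>> named_optim.init_state()  # materialize the lazily created state
            >>> named_optim.load_state_dict(saved_state_dict)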
- """
- for _, param in self.named_parameters.items():
- if param.requires_grad:
- t = torch.zeros_like(param)
- param.grad = torch.autograd.Variable(t)
- # Calling ``step`` will load the initial state for optimizer states.
- self.step(closure=None)

    def _pre_load_state_dict(self, state_dict) -> Dict[str, Any]:
        # Give FSDP a chance to transform the FQN-keyed ``state_dict`` before loading.
        if isinstance(self.module, FSDP):
            return FSDP.load_optim_state_dict_pre_hook(
                self.module, self._optimizer, state_dict
            )
        return state_dict

    def _post_state_dict(self, state_dict) -> Dict[str, Any]:
        # Give FSDP a chance to post-process the produced ``state_dict``.
        if isinstance(self.module, FSDP):
            FSDP.optim_state_dict_post_hook(self.module, self._optimizer, state_dict)
        return state_dict


def _gen_param_group_key(param_keys: List[str]) -> str:
    """
    Concatenate all param keys as a unique identifier for one param group.
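
    Example::
        >>> # A sketch of the key format; the parameter names are illustrative.
        >>> _gen_param_group_key(["net.bias", "net.weight"])
        'net.bias/net.weight'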
- """
- return "/".join(sorted(param_keys))