from typing import List, Union, Mapping, Dict, Any

import torch.optim as optim
from torch import Tensor
from torch.distributed._shard.sharded_tensor import ShardedTensor


class ShardedOptimizer(optim.Optimizer):
    def __init__(
        self,
        named_params: Mapping[str, Union[Tensor, ShardedTensor]],
        optimizer_class,
        *optimizer_args,
        **optimizer_kwargs
    ):
- """
- ShardedOptimizer collects all tensors and local shard tensors of
- ShardedTensor, then use these tensors as ``params`` for optimizers
- Args:
- named_params (Dict[str, Union[Tensor, ShardedTensor]]) : a Dict
- of parameters, where key is the parameter key, value is either
- Tensor or ShardedTensor parameter.
- optimizer_class (torch.optim.Optimizer): the Optimizer to use
- locally, i.e. torch.optim.SGD, torch.optim.Adagrad, etc.
- *optimizer_args: the arguments to initialize the optimizer.
- **optimizer_kwargs: the key-word arguments to initialize the optimizer.
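
        Example (illustrative sketch; ``named_params`` is assumed to already map
        parameter names to ``Tensor``/``ShardedTensor`` objects, and ``loss`` to be
        a loss computed from the corresponding model)::

            >>> import torch
            >>> optim = ShardedOptimizer(named_params, torch.optim.SGD, lr=0.01)
            >>> loss.backward()
            >>> optim.step()
            >>> optim.zero_grad()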
- """
        tensors: List[Tensor] = []
        for value in named_params.values():
            if isinstance(value, ShardedTensor):
                # ShardedTensor: hand the wrapped optimizer only the local shards.
                for local_shard in value.local_shards():
                    tensors.append(local_shard.tensor)
            else:
                # Regular Tensor: pass it through unchanged.
                tensors.append(value)

        self.named_params = named_params
        self._optim = optimizer_class(tensors, *optimizer_args, **optimizer_kwargs)
        self.param_groups = self._optim.param_groups
        self.state = self._optim.state

    def zero_grad(self, set_to_none: bool = True):  # type: ignore[override]
        r"""Sets the gradients of all optimized :class:`torch.Tensor` s to zero.

        Args:
            set_to_none (bool): instead of setting to zero, set the grads to None.
                This will in general have lower memory footprint, and can modestly
                improve performance. However, it changes certain behaviors. For example:
                1. When the user tries to access a gradient and perform manual ops on it,
                a None attribute or a Tensor full of 0s will behave differently.
                2. If the user requests ``zero_grad(set_to_none=True)`` followed by a
                backward pass, ``.grad``\ s are guaranteed to be None for params that
                did not receive a gradient.
                3. ``torch.optim`` optimizers have a different behavior if the gradient
                is 0 or None (in one case it does the step with a gradient of 0 and in
                the other it skips the step altogether).
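
        Example (illustrative; assumes ``p`` is a plain Tensor parameter managed by
        this optimizer that received a gradient in the preceding backward pass)::

            >>> loss.backward()
            >>> optim.zero_grad(set_to_none=True)
            >>> p.grad is None   # the gradient is released rather than zeroed
            True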
- """
- self._optim.zero_grad(set_to_none)

    def step(self, closure=None):
        r"""Performs a single optimization step (parameter update).

        Args:
            closure (Callable): A closure that reevaluates the model and
                returns the loss. Optional for most optimizers.

        .. note::
            Unless otherwise specified, this function should not modify the
            ``.grad`` field of the parameters.
- """
- self._optim.step(closure)

    def state_dict(self) -> Dict[str, Any]:
        """
        Returned state and param_groups will contain parameter keys
        instead of the parameter indices used by torch.optim.Optimizer.
        This allows for advanced functionality like optimizer re-sharding to be implemented.
        """
        # TODO: implement state_dict
        raise NotImplementedError("ShardedOptimizer state_dict not implemented yet!")

    def load_state_dict(self, state_dict: Mapping[str, Any]):
        r"""Loads the ShardedOptimizer state.

        Args:
            state_dict (dict): ShardedOptimizer state. Should be an object returned
                from a call to :meth:`state_dict`.
        """
        # TODO: implement load_state_dict
        raise NotImplementedError("ShardedOptimizer load_state_dict not implemented yet!")

    def add_param_group(self, param_group: Any):
        r"""Add a new param group.
        """
        # TODO: implement add_param_group
        raise NotImplementedError("ShardedOptimizer add_param_group not implemented yet!")