api.py

from typing import List, Union, Mapping, Dict, Any

import torch.optim as optim
from torch import Tensor
from torch.distributed._shard.sharded_tensor import ShardedTensor


class ShardedOptimizer(optim.Optimizer):
    def __init__(
        self,
        named_params: Mapping[str, Union[Tensor, ShardedTensor]],
        optimizer_class,
        *optimizer_args,
        **optimizer_kwargs
    ):
        """
        ShardedOptimizer collects all plain tensors and the local shards of
        ShardedTensor parameters, then uses these tensors as ``params`` for
        the wrapped optimizer.

        Args:
            named_params (Dict[str, Union[Tensor, ShardedTensor]]): a Dict of
                parameters, where the key is the parameter key and the value is
                either a Tensor or a ShardedTensor parameter.
            optimizer_class (torch.optim.Optimizer): the Optimizer class to use
                locally, e.g. torch.optim.SGD, torch.optim.Adagrad, etc.
            *optimizer_args: the arguments to initialize the optimizer.
            **optimizer_kwargs: the key-word arguments to initialize the optimizer.
        """
        tensors: List[Tensor] = []
        for value in named_params.values():
            if isinstance(value, ShardedTensor):
                # Only the local shards of a ShardedTensor live on this rank,
                # so those are the tensors the local optimizer updates.
                for local_shard in value.local_shards():
                    tensors.append(local_shard.tensor)
            else:
                tensors.append(value)

        self.named_params = named_params
        self._optim = optimizer_class(tensors, *optimizer_args, **optimizer_kwargs)
        self.param_groups = self._optim.param_groups
        self.state = self._optim.state
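
    # A minimal construction sketch (illustrative, not part of this module): the
    # mapping can come from any helper that yields ShardedTensor parameters, or
    # be built by hand:
    #
    #     named_params = {"fc.weight": sharded_weight, "fc.bias": bias_tensor}
    #     sharded_optim = ShardedOptimizer(named_params, torch.optim.SGD, lr=0.1)
    #
    # Each rank then optimizes only its own local shards plus any plain tensors.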

    def zero_grad(self, set_to_none: bool = True):  # type: ignore[override]
        r"""Sets the gradients of all optimized :class:`torch.Tensor` s to zero.

        Args:
            set_to_none (bool): instead of setting to zero, set the grads to None.
                This will in general have a lower memory footprint, and can modestly
                improve performance. However, it changes certain behaviors. For example:
                1. When the user tries to access a gradient and perform manual ops on it,
                   a None attribute or a Tensor full of 0s will behave differently.
                2. If the user requests ``zero_grad(set_to_none=True)`` followed by a
                   backward pass, ``.grad``\ s are guaranteed to be None for params that
                   did not receive a gradient.
                3. ``torch.optim`` optimizers have a different behavior if the gradient is
                   0 or None (in one case it does the step with a gradient of 0 and in the
                   other it skips the step altogether).
        """
        self._optim.zero_grad(set_to_none)
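
    # Sketch of the ``set_to_none`` distinction (illustrative only):
    #
    #     sharded_optim.zero_grad(set_to_none=True)   # .grad becomes None
    #     sharded_optim.zero_grad(set_to_none=False)  # existing .grad is zeroed in place
    #
    # Any code that reads or mutates ``.grad`` directly must handle the None case.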

    def step(self, closure=None):
        r"""Performs a single optimization step (parameter update).

        Args:
            closure (Callable): A closure that reevaluates the model and
                returns the loss. Optional for most optimizers.

        .. note::
            Unless otherwise specified, this function should not modify the
            ``.grad`` field of the parameters.
        """
        self._optim.step(closure)
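
    # Typical training-loop usage (a sketch; ``sharded_model`` and ``loss_fn``
    # are assumed to exist and to produce grads on the collected tensors):
    #
    #     for inputs, target in loader:
    #         sharded_optim.zero_grad()
    #         loss = loss_fn(sharded_model(inputs), target)
    #         loss.backward()        # autograd fills .grad on the local shards
    #         sharded_optim.step()   # delegates to the wrapped local optimizer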

    def state_dict(self) -> Dict[str, Any]:
        """
        The returned ``state`` and ``param_groups`` will contain parameter keys
        instead of parameter indices as in torch.optim.Optimizer.
        This allows advanced functionality like optimizer re-sharding to be implemented.
        """
        # TODO: implement state_dict
        raise NotImplementedError("ShardedOptimizer state_dict not implemented yet!")
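
    # Intended shape of the result (hypothetical values, since state_dict is not
    # implemented yet). A plain torch.optim state dict keys entries by index:
    #
    #     {"state": {0: {...}}, "param_groups": [{"params": [0, 1], ...}]}
    #
    # whereas this class is meant to key them by parameter name:
    #
    #     {"state": {"fc.weight": {...}},
    #      "param_groups": [{"params": ["fc.weight", "fc.bias"], ...}]}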

    def load_state_dict(self, state_dict: Mapping[str, Any]):
        r"""Loads the ShardedOptimizer state.

        Args:
            state_dict (dict): ShardedOptimizer state. Should be an object returned
                from a call to :meth:`state_dict`.
        """
        # TODO: implement load_state_dict
        raise NotImplementedError("ShardedOptimizer load_state_dict not implemented yet!")

    def add_param_group(self, param_group: Any):
        r"""Add a new param group"""
        # TODO: implement add_param_group
        raise NotImplementedError("ShardedOptimizer add_param_group not implemented yet!")
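

# End-to-end usage sketch (illustrative assumptions only: it presumes an already
# initialized process group and a module whose sharded weights are exposed as
# ShardedTensors; ``build_sharded_module`` and ``collect_named_params`` are
# hypothetical placeholders, not APIs of this package):
#
#     import torch
#     import torch.distributed as dist
#
#     dist.init_process_group("gloo", rank=rank, world_size=world_size)
#     module = build_sharded_module()                # hypothetical
#     named_params = collect_named_params(module)    # hypothetical, yields the
#                                                    # Mapping expected above
#     sharded_optim = ShardedOptimizer(named_params, torch.optim.Adagrad, lr=0.01)
#
#     sharded_optim.zero_grad()
#     module(inputs).sum().backward()
#     sharded_optim.step()    # each rank updates only its local shards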