optimizer.py

import logging
from collections import defaultdict
from threading import Lock
from typing import List, Optional

import torch
import torch.distributed.autograd as dist_autograd
import torch.distributed.rpc as rpc
import torch.jit as jit
import torch.nn as nn
from torch import Tensor
from torch.distributed.rpc import RRef

from .utils import functional_optim_map

__all__ = ["DistributedOptimizer"]

logger = logging.getLogger(__name__)


# XXX: we define a _ScriptLocalOptimizer here to explicitly
# compile the FunctionalOptimizer class into TorchScript.
# This is because a ScriptClass instance still lives in
# Python unless you explicitly compile it as an attribute
# of a ScriptModule or pass it to a ScriptFunction.
# _ScriptLocalOptimizerInterface serves as a common
# interface type for optimizer ScriptModules.
#
# TODO (wanchaol): remove this once we add TorchScript
# class reference semantics
@jit.interface
class _ScriptLocalOptimizerInterface:
    def step(self, autograd_ctx_id: int) -> None:
        pass
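

# The interface above gives the scripted optimizer a TorchScript-visible type:
# _new_script_local_optimizer passes it as the type argument of rpc.RRef, and
# _script_local_optimizer_step annotates its argument as
# RRef[_ScriptLocalOptimizerInterface] so that local_value().step() can be
# called from TorchScript.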


class _ScriptLocalOptimizer(nn.Module):
    # TorchScript does not support multithreaded concurrent compilation.
    # request_callback might invoke concurrent compilation, so we
    # serialize compilation with a lock.
    compile_lock = Lock()

    def __init__(self, optim_cls, local_params_rref, *args, **kwargs):
        super().__init__()
        # Dereference the parameter RRefs; the parameters live on this worker.
        self._local_params = [rref.local_value() for rref in local_params_rref]
        self.optim = optim_cls(self._local_params, *args, **kwargs)

    @jit.export
    def step(self, autograd_ctx_id: int):
        # Gradients accumulated for this distributed autograd context.
        all_local_grads = dist_autograd.get_gradients(autograd_ctx_id)
        # apply functional optimizer step with a list of gradients
        grads: List[Optional[Tensor]] = [
            all_local_grads[p] if p in all_local_grads else None
            for p in self._local_params
        ]

        self.optim.step(grads)
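

# Note: _ScriptLocalOptimizer expects optim_cls to be a functional optimizer,
# i.e. one whose step() takes an explicit list of gradients instead of reading
# param.grad. That is what allows the whole update to be compiled with
# TorchScript so it is not blocked by the GIL (see the DistributedOptimizer
# docstring below).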


# TODO (wanchaol): remove/merge this with ScriptLocalOptimizer once
# we have converted all to functional optimizer in distributed.optim
class _LocalOptimizer:
    # Ideally we would only need to share a lock for instances of
    # _LocalOptimizer that deal with the same parameters. We are
    # making a simplifying assumption here that if there is more
    # than one instance of _LocalOptimizer per worker, they will
    # be optimizing the same parameters (e.g. each data parallel
    # trainer will create its own instance of _LocalOptimizer but
    # they will all optimize the same parameters on each worker)
    global_lock = Lock()

    def __init__(self, optim_cls, local_params_rref, *args, **kwargs):
        self._local_params = [rref.local_value() for rref in local_params_rref]
        self.optim = optim_cls(self._local_params, *args, **kwargs)

    def step(self, autograd_ctx_id):
        all_local_grads = dist_autograd.get_gradients(autograd_ctx_id)

        with _LocalOptimizer.global_lock:
            for param, grad in all_local_grads.items():
                param.grad = grad
            self.optim.step()
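

# The two helpers below run on the worker that owns the parameters, invoked via
# rpc_async from DistributedOptimizer: _new_local_optimizer constructs a
# _LocalOptimizer there and returns an RRef to it, and _local_optimizer_step
# dereferences that RRef and runs one step for the given autograd context.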
def _new_local_optimizer(optim_cls, local_params_rref, *args, **kwargs):
    return rpc.RRef(_LocalOptimizer(optim_cls, local_params_rref, *args, **kwargs))


def _local_optimizer_step(local_optim_rref, autograd_ctx_id):
    local_optim = local_optim_rref.local_value()
    local_optim.step(autograd_ctx_id)


# new/step functions combined with _ScriptLocalOptimizer to provide a GIL-free optimizer
def _new_script_local_optimizer(optim_cls, local_params_rref, *args, **kwargs):
    optim = _ScriptLocalOptimizer(optim_cls, local_params_rref, *args, **kwargs)
    with _ScriptLocalOptimizer.compile_lock:
        script_optim = jit.script(optim)
    return rpc.RRef(script_optim, _ScriptLocalOptimizerInterface)


@jit.script
def _script_local_optimizer_step(
    local_optim_rref: RRef[_ScriptLocalOptimizerInterface], autograd_ctx_id: int
) -> None:
    local_optim = local_optim_rref.local_value()
    local_optim.step(autograd_ctx_id)
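

# Wait on a list of RPC futures, collecting every result; if any future raised,
# the last exception is re-raised only after all futures have completed.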
def _wait_for_all(rpc_futs):
    # TODO: improve error propagation
    exception = None
    results = []
    for fut in rpc_futs:
        try:
            results.append(fut.wait())
        except Exception as e:
            results.append(e)
            exception = e
    if exception is not None:
        raise exception
    return results


class DistributedOptimizer:
    """
    DistributedOptimizer takes remote references to parameters scattered
    across workers and applies the given optimizer locally for each parameter.

    This class uses :meth:`~torch.distributed.autograd.get_gradients` in order
    to retrieve the gradients for specific parameters.

    Concurrent calls to
    :meth:`~torch.distributed.optim.DistributedOptimizer.step`,
    either from the same or different clients, will
    be serialized on each worker -- as each worker's optimizer can only work
    on one set of gradients at a time. However, there is no guarantee that
    the full forward-backward-optimizer sequence will execute for one client
    at a time. This means that the gradients being applied may not correspond
    to the latest forward pass executed on a given worker. Also, there is no
    guaranteed ordering across workers.

    `DistributedOptimizer` creates the local optimizer with TorchScript enabled
    by default, so that optimizer updates are not blocked by the Python Global
    Interpreter Lock (GIL) in the case of multithreaded training (e.g. Distributed
    Model Parallel). This feature is currently enabled for most optimizers. You
    can also follow `the recipe`__ in PyTorch tutorials to enable TorchScript support
    for your own custom optimizers.

    Args:
        optimizer_class (optim.Optimizer): the class of optimizer to
            instantiate on each worker.
        params_rref (list[RRef]): list of RRefs to local or remote parameters
            to optimize.
        args: arguments to pass to the optimizer constructor on each worker.
        kwargs: arguments to pass to the optimizer constructor on each worker.

    Example::
        >>> # xdoctest: +SKIP("distributed")
        >>> import torch.distributed.autograd as dist_autograd
        >>> import torch.distributed.rpc as rpc
        >>> from torch import optim
        >>> from torch.distributed.optim import DistributedOptimizer
        >>>
        >>> with dist_autograd.context() as context_id:
        >>>     # Forward pass.
        >>>     rref1 = rpc.remote("worker1", torch.add, args=(torch.ones(2), 3))
        >>>     rref2 = rpc.remote("worker1", torch.add, args=(torch.ones(2), 1))
        >>>     loss = rref1.to_here() + rref2.to_here()
        >>>
        >>>     # Backward pass.
        >>>     dist_autograd.backward(context_id, [loss.sum()])
        >>>
        >>>     # Optimizer.
        >>>     dist_optim = DistributedOptimizer(
        >>>         optim.SGD,
        >>>         [rref1, rref2],
        >>>         lr=0.05,
        >>>     )
        >>>     dist_optim.step(context_id)

    __ https://github.com/pytorch/tutorials/pull/1465
    """

    def __init__(self, optimizer_class, params_rref, *args, **kwargs):
        torch._C._log_api_usage_once("torch.distributed.optim.DistributedOptimizer")
        # Group the parameter RRefs by the worker that owns them.
        per_worker_params_rref = defaultdict(list)
        for param in params_rref:
            per_worker_params_rref[param.owner()].append(param)

        # Prefer the TorchScript-compatible functional optimizer when one is
        # registered for this optimizer class and TorchScript is enabled.
        if optimizer_class in functional_optim_map and jit._state._enabled:
            optim_ctor = functional_optim_map.get(optimizer_class)
        else:
            optim_ctor = optimizer_class
        self.is_functional_optim = optim_ctor != optimizer_class

        if self.is_functional_optim:
            optimizer_new_func = _new_script_local_optimizer
        else:
            logger.warning(
                f"Creating the optimizer {optimizer_class} without TorchScript support, "
                "this might result in slow computation time in a multithreaded environment "
                "(i.e. Distributed Model Parallel training on CPU) due to Python's "
                "Global Interpreter Lock (GIL). Please file an issue if you need this "
                "optimizer in TorchScript."
            )
            optimizer_new_func = _new_local_optimizer

        # Create one local optimizer per worker, in parallel, and wait for all
        # of them to be constructed.
        remote_optim_futs = []
        for worker, param_rrefs in per_worker_params_rref.items():
            remote_optim_rref_fut = rpc.rpc_async(
                worker,
                optimizer_new_func,
                args=(optim_ctor, param_rrefs) + args,
                kwargs=kwargs,
            )
            remote_optim_futs.append(remote_optim_rref_fut)

        self.remote_optimizers = _wait_for_all(remote_optim_futs)

    def step(self, context_id):
        """
        Performs a single optimization step.

        This will call :meth:`torch.optim.Optimizer.step` on each worker
        containing parameters to be optimized, and will block until all workers
        return. The provided ``context_id`` will be used to retrieve the
        corresponding :class:`~torch.distributed.autograd.context` that
        contains the gradients that should be applied to the parameters.

        Args:
            context_id: the autograd context id for which we should run the
                optimizer step.
        """
        # Validate the autograd context before dispatching any work.
        dist_autograd._is_valid_context(context_id)

        if self.is_functional_optim:
            optimizer_step_func = _script_local_optimizer_step
        else:
            optimizer_step_func = _local_optimizer_step

        # Ask each worker that owns an optimizer to run a step for this
        # context, then block until every worker has finished.
        rpc_futs = []
        for optimizer in self.remote_optimizers:
            rpc_futs.append(
                rpc.rpc_async(
                    optimizer.owner(),
                    optimizer_step_func,
                    args=(optimizer, context_id),
                )
            )
        _wait_for_all(rpc_futs)