# __init__.py -- torch.distributed.autograd
import sys

import torch


def is_available():
    # The distributed autograd bindings only exist if PyTorch was built with
    # distributed support.
    return hasattr(torch._C, "_dist_autograd_init")


if is_available() and not torch._C._dist_autograd_init():
    raise RuntimeError("Failed to initialize torch.distributed.autograd")

if is_available():
    from torch._C._distributed_autograd import (
        get_gradients,
        backward,
        _init,
        _new_context,
        _release_context,
        _get_max_id,
        _is_valid_context,
        _retrieve_context,
        _current_context,
        _get_debug_info,
        DistAutogradContext,
    )


class context:
    '''
    Context object to wrap forward and backward passes when using
    distributed autograd. The ``context_id`` generated in the ``with``
    statement is required to uniquely identify a distributed backward pass
    on all workers. Each worker stores metadata associated with this
    ``context_id``, which is required to correctly execute a distributed
    autograd pass.

    Example::
        >>> # xdoctest: +SKIP
        >>> import torch.distributed.autograd as dist_autograd
        >>> import torch.distributed.rpc as rpc
        >>> with dist_autograd.context() as context_id:
        >>>     t1 = torch.rand((3, 3), requires_grad=True)
        >>>     t2 = torch.rand((3, 3), requires_grad=True)
        >>>     loss = rpc.rpc_sync("worker1", torch.add, args=(t1, t2)).sum()
        >>>     dist_autograd.backward(context_id, [loss])
    '''

    def __enter__(self):
        # Create a new distributed autograd context and hand its id back to
        # the ``with`` block.
        self.autograd_context = _new_context()
        return self.autograd_context._context_id()

    def __exit__(self, type, value, traceback):
        # Release the context (and the metadata accumulated under it) on exit.
        _release_context(self.autograd_context._context_id())
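

# Usage sketch (editor's addition, not part of the module): shows how the
# ``context`` manager above is typically combined with ``backward`` and
# ``get_gradients``. The worker name "worker1" and the surrounding RPC setup
# are assumptions; this would run on a caller whose RPC group was already
# initialized via torch.distributed.rpc.init_rpc.
#
#     import torch
#     import torch.distributed.autograd as dist_autograd
#     import torch.distributed.rpc as rpc
#
#     with dist_autograd.context() as context_id:
#         t1 = torch.rand((3, 3), requires_grad=True)
#         t2 = torch.rand((3, 3), requires_grad=True)
#         # Run part of the forward pass on a remote worker over RPC.
#         loss = rpc.rpc_sync("worker1", torch.add, args=(t1, t2)).sum()
#         # Kick off the distributed backward pass for this context.
#         dist_autograd.backward(context_id, [loss])
#         # Gradients accumulate per context instead of in ``.grad``;
#         # retrieve them as a dict mapping Tensor -> gradient.
#         grads = dist_autograd.get_gradients(context_id)
#         assert t1 in grads and t2 in grads
#
# Keeping gradients per ``context_id`` (rather than in ``.grad``) is what lets
# several distributed backward passes run concurrently without interfering.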