distributed_utils.py 1.9 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364
  1. from contextlib import contextmanager
  2. from datetime import timedelta
  3. from functools import (
  4. partial,
  5. wraps,
  6. )
  7. import torch.distributed as dist
  8. import torch.distributed.distributed_c10d as c10d
  class MockProcessGroup(dist.ProcessGroup):
      """Minimal fake c10d process group for tests.
  
      It forwards ``rank`` and ``world`` to the base ``ProcessGroup``
      constructor and reports a fixed backend name; it defines no other
      behavior of its own.
      """
  
      def __init__(self, rank, world):
          # Let the C++ base class store rank / world size.
          super(MockProcessGroup, self).__init__(rank, world)
  
      def getBackendName(self):
          """Return the name this group's backend is registered under."""
          # NOTE(review): presumably queried by c10d for logging/dispatch;
          # must stay in sync with the register_backend name.
          backend_name = "mock_process_group"
          return backend_name
  def create_mock_pg(prefix_store, rank, world_size, timeout):
      """Backend factory for the ``mock_process_group`` backend.
  
      Takes the (store, rank, world_size, timeout) arguments that c10d hands
      to a registered backend creator. Only ``rank`` and ``world_size`` are
      used; ``prefix_store`` and ``timeout`` are accepted and ignored.
      """
      fake_group = MockProcessGroup(rank, world_size)
      return fake_group
  
  
  # Register at import time so that
  # init_process_group(backend="mock_process_group") resolves to the factory.
  dist.Backend.register_backend('mock_process_group', create_mock_pg)
  def mock_init_dist(rank, world_size):
      """Initialize the default c10d process group with the mock backend.
  
      !!! WARNING !!!
      Kids don't try this at home: this is a cute pile of hacks that depends
      on a small mountain of c10d internals.
  
      Args:
          rank: rank the fake default group is initialized with.
          world_size: world size the fake default group is initialized with.
  
      Raises:
          AssertionError: if a default process group is already initialized
              (only when assertions are enabled, i.e. not under ``python -O``).
      """
      assert not dist.is_initialized()
  
      fake_store = dist.HashStore()
  
      # Trick _store_based_barrier into believing everyone else already
      # checked in: pre-seed the barrier counter with the other
      # ``world_size - 1`` ranks. The ":0" suffix is the group index.
      # NOTE(review): the key format and the index mirror c10d internals and
      # may not hold for every PyTorch version -- confirm when upgrading.
      barrier_key = f"{c10d.STORE_BASED_BARRIER_PREFIX}:0"
      fake_store.add(barrier_key, world_size - 1)
  
      # The "mock_process_group" backend must already be registered with
      # dist.Backend for this call to succeed.
      dist.init_process_group(
          backend="mock_process_group",
          rank=rank,
          world_size=world_size,
          store=fake_store,
          group_name="fake",
          timeout=timedelta(seconds=1),
      )
  @contextmanager
  def with_dist(rank=0, world_size=2):
      """Run the ``with`` body inside a fake c10d default process group.
  
      Context manager that initializes c10d with a fake process group via
      ``mock_init_dist`` and destroys it on exit, including when the body
      raises. If initialization itself fails, nothing is torn down, because
      the init call happens before the ``try``.
      """
      mock_init_dist(world_size=world_size, rank=rank)
      try:
          yield None
      finally:
          # Teardown runs on normal exit and on exceptions alike.
          dist.destroy_process_group()
  def with_fake_comms(func=None, rank=0, world_size=2):
      """Decorate a callable so it runs inside a fake c10d process group.
  
      Function wrapper that inits a fake process group designed for testing.
      Right now only querying for world size is available.
  
      Usable bare (``@with_fake_comms``) or with arguments
      (``@with_fake_comms(rank=1, world_size=4)``).
  
      Args:
          func: the callable to wrap; ``None`` when the decorator is used
              with arguments only.
          rank: rank passed to ``with_dist`` for the fake group.
          world_size: world size passed to ``with_dist`` for the fake group.
  
      Returns:
          The wrapped callable, which returns whatever ``func`` returns. When
          ``func`` is ``None``, returns a decorator awaiting ``func``.
      """
      if func is None:
          # Called as @with_fake_comms(...): bind the options, wait for func.
          return partial(with_fake_comms, rank=rank, world_size=world_size)
  
      @wraps(func)
      def wrapper(*args, **kwargs):
          # Plain *args pass-through: works for methods (self is args[0])
          # and for free functions alike.
          with with_dist(rank=rank, world_size=world_size):
              # Propagate the wrapped callable's result instead of dropping it.
              return func(*args, **kwargs)
  
      return wrapper