12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364 |
- from contextlib import contextmanager
- from datetime import timedelta
- from functools import (
- partial,
- wraps,
- )
- import torch.distributed as dist
- import torch.distributed.distributed_c10d as c10d
- class MockProcessGroup(dist.ProcessGroup):
- def __init__(self, rank, world):
- super().__init__(rank, world)
- def getBackendName(self):
- return "mock_process_group"
- def create_mock_pg(prefix_store, rank, world_size, timeout):
- return MockProcessGroup(rank, world_size)
- dist.Backend.register_backend('mock_process_group', create_mock_pg)
- def mock_init_dist(rank, world_size):
- # !!! WARNING !!!
- # Kids don't try this at home, this is a cute pile of hacks that
- # depends on a small mountain of c10d internals
- assert not dist.is_initialized()
- store = dist.HashStore()
- # Trick _store_based_barrier into believing everyone else already checked-in
- # Zero is the group index
- store.add(f"{c10d.STORE_BASED_BARRIER_PREFIX}:0", world_size - 1)
- dist.init_process_group(
- backend="mock_process_group",
- rank=rank,
- world_size=world_size,
- store=store,
- group_name="fake",
- timeout=timedelta(seconds=1))
- @contextmanager
- def with_dist(rank=0, world_size=2):
- """
- Context manager that initializer c10d with a fake process group.
- """
- mock_init_dist(rank=rank, world_size=world_size)
- try:
- yield
- finally:
- dist.destroy_process_group()
- def with_fake_comms(func=None, rank=0, world_size=2):
- """
- Function wrapper that inits a fake process group designed for testing.
- Right now only querying for world size is available
- """
- if func is None:
- return partial(with_fake_comms, rank=rank, world_size=world_size)
- @wraps(func)
- def wrapper(self, *args, **kwargs):
- with with_dist(rank, world_size):
- func(self, *args, **kwargs)
- return wrapper
|