dist_utils.py

import re
import sys
import time
from functools import partial, wraps
from typing import Tuple

import torch.distributed as dist
import torch.distributed.rpc as rpc
from torch.distributed.rpc import _rref_context_get_debug_info
from torch.testing._internal.common_utils import FILE_SCHEMA, TEST_WITH_TSAN


if not dist.is_available():
    print("c10d not available, skipping tests", file=sys.stderr)
    sys.exit(0)


INIT_METHOD_TEMPLATE = FILE_SCHEMA + "{file_name}"


def dist_init(
    old_test_method=None,
    setup_rpc: bool = True,
    clean_shutdown: bool = True,
    faulty_messages=None,
    messages_to_delay=None,
):
    """
    We use this decorator for setting up and tearing down state since
    MultiProcessTestCase runs each `test*` method in a separate process and
    each process just runs the `test*` method without actually calling
    'setUp' and 'tearDown' methods of unittest.

    Note: pass the string representation of MessageTypes that should be used
    with the faulty agent's send function. By default, all retriable messages
    ("RREF_FORK_REQUEST", "RREF_CHILD_ACCEPT", "RREF_USER_DELETE",
    "CLEANUP_AUTOGRAD_CONTEXT_REQ") will use the faulty send (this default is
    set from faulty_rpc_agent_test_fixture.py).
    """
    # If we use dist_init without arguments (ex: @dist_init), old_test_method is
    # appropriately set and we return the wrapper appropriately. On the other
    # hand if dist_init has arguments (ex: @dist_init(clean_shutdown=False)),
    # old_test_method is None and we return a functools.partial which is the real
    # decorator that is used and as a result we recursively call dist_init with
    # old_test_method and the rest of the arguments appropriately set.
    if old_test_method is None:
        return partial(
            dist_init,
            setup_rpc=setup_rpc,
            clean_shutdown=clean_shutdown,
            faulty_messages=faulty_messages,
            messages_to_delay=messages_to_delay,
        )

    @wraps(old_test_method)
    def new_test_method(self, *arg, **kwargs):
        # Setting _ignore_rref_leak to make sure OwnerRRefs are properly deleted
        # in tests.
        import torch.distributed.rpc.api as api

        api._ignore_rref_leak = False
        self.worker_id = self.rank
        self.setup_fault_injection(faulty_messages, messages_to_delay)

        rpc_backend_options = self.rpc_backend_options
        if setup_rpc:
            if TEST_WITH_TSAN:
                # TSAN runs much slower.
                rpc_backend_options.rpc_timeout = rpc.constants.DEFAULT_RPC_TIMEOUT_SEC * 5
                rpc.constants.DEFAULT_SHUTDOWN_TIMEOUT = 60

            rpc.init_rpc(
                name="worker%d" % self.rank,
                backend=self.rpc_backend,
                rank=self.rank,
                world_size=self.world_size,
                rpc_backend_options=rpc_backend_options,
            )

        return_value = old_test_method(self, *arg, **kwargs)

        if setup_rpc:
            rpc.shutdown(graceful=clean_shutdown)

        return return_value

    return new_test_method
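
# A minimal usage sketch for `dist_init` (illustrative only; `MyRpcTest` and
# its base fixture are hypothetical names). The decorator expects a
# MultiProcessTestCase-style class that provides `rank`, `world_size`,
# `rpc_backend`, `rpc_backend_options`, and `setup_fault_injection`:
#
#     class MyRpcTest(RpcAgentTestFixture):
#         @dist_init
#         def test_ping(self):
#             dst = worker_name((self.rank + 1) % self.world_size)
#             self.assertIsNone(rpc.rpc_sync(dst, noop, args=()))
#
#         @dist_init(clean_shutdown=False)
#         def test_ungraceful_shutdown(self):
#             pass  # rpc.shutdown(graceful=False) runs after the test body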


def noop() -> None:
    pass


def wait_until_node_failure(rank: int, expected_error_regex: str = ".*") -> str:
    """
    Loops until an RPC to the given rank fails. This is used to
    indicate that the node has failed in unit tests.

    Args:
        rank (int): Rank of the node expected to fail.
        expected_error_regex (optional, str): Regex of the expected exception
            message. Useful to ensure a specific failure occurs, not just any.
    """
    while True:
        try:
            rpc.rpc_sync("worker{}".format(rank), noop, args=())
            time.sleep(0.1)
        except Exception as e:
            if re.search(pattern=expected_error_regex, string=str(e)):
                return str(e)
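
# A minimal sketch (the destination rank and regex are illustrative): a test
# that has deliberately brought down a peer can block here until the failure
# becomes observable, then assert on the captured error message:
#
#     error = wait_until_node_failure(dst_rank, expected_error_regex=".*")
#     self.assertNotEqual(error, "")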


def wait_until_pending_futures_and_users_flushed(timeout: int = 20) -> None:
    """
    The RRef protocol holds forkIds of rrefs in a map until those forks are
    confirmed by the owner. The message confirming the fork may arrive after
    our tests check whether this map is empty, which leads to failures and
    flaky tests. to_here also does not guarantee that we have finished
    processing the owner's confirmation message for the RRef. This function
    loops until the map is empty, which means the messages have been received
    and processed. Call this function before asserting that the map returned
    by _rref_context_get_debug_info is empty.
    """
    start = time.time()
    while True:
        debug_info = _rref_context_get_debug_info()
        num_pending_futures = int(debug_info["num_pending_futures"])
        num_pending_users = int(debug_info["num_pending_users"])
        if num_pending_futures == 0 and num_pending_users == 0:
            break
        time.sleep(0.1)
        if time.time() - start > timeout:
            raise ValueError(
                "Timed out waiting to flush pending futures and users, "
                "had {} pending futures and {} pending users".format(
                    num_pending_futures, num_pending_users
                )
            )


def get_num_owners_and_forks() -> Tuple[str, str]:
    """
    Retrieves the number of OwnerRRefs and forks on this node from
    _rref_context_get_debug_info.
    """
    rref_dbg_info = _rref_context_get_debug_info()
    num_owners = rref_dbg_info["num_owner_rrefs"]
    num_forks = rref_dbg_info["num_forks"]
    return num_owners, num_forks
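
# A minimal sketch of the intended call order (the expected counts are
# illustrative): flush pending RRef confirmations first, then assert on the
# owner/fork counts, which the debug map reports as strings:
#
#     wait_until_pending_futures_and_users_flushed()
#     num_owners, num_forks = get_num_owners_and_forks()
#     assert (int(num_owners), int(num_forks)) == (0, 0)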


def wait_until_owners_and_forks_on_rank(
    num_owners: int, num_forks: int, rank: int, timeout: int = 20
) -> None:
    """
    Waits until timeout for num_forks and num_owners to exist on the rank. Used
    to ensure proper deletion of RRefs in tests.
    """
    start = time.time()
    while True:
        num_owners_on_rank, num_forks_on_rank = rpc.rpc_sync(
            worker_name(rank), get_num_owners_and_forks, args=(), timeout=5
        )
        num_owners_on_rank = int(num_owners_on_rank)
        num_forks_on_rank = int(num_forks_on_rank)
        if num_owners_on_rank == num_owners and num_forks_on_rank == num_forks:
            return
        time.sleep(1)
        if time.time() - start > timeout:
            raise ValueError(
                "Timed out waiting {} sec for {} owners and {} forks on rank,"
                " had {} owners and {} forks".format(
                    timeout,
                    num_owners,
                    num_forks,
                    num_owners_on_rank,
                    num_forks_on_rank,
                )
            )
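
# A minimal sketch (`rref` and `owner_rank` are illustrative): after deleting
# the last user of an RRef, wait for the owner's bookkeeping to drain before
# asserting that nothing leaked:
#
#     del rref
#     wait_until_owners_and_forks_on_rank(num_owners=0, num_forks=0, rank=owner_rank)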


def initialize_pg(init_method, rank: int, world_size: int) -> None:
    # This is for tests using `dist.barrier`.
    if not dist.is_initialized():
        dist.init_process_group(
            backend="gloo",
            init_method=init_method,
            rank=rank,
            world_size=world_size,
        )
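
# A minimal sketch (assumes a MultiProcessTestCase fixture whose
# `self.file_name` backs the file:// rendezvous): set up a gloo group once,
# then synchronize workers with a collective barrier:
#
#     initialize_pg(
#         INIT_METHOD_TEMPLATE.format(file_name=self.file_name),
#         self.rank,
#         self.world_size,
#     )
#     dist.barrier()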


def worker_name(rank: int) -> str:
    return "worker{}".format(rank)


def get_function_event(function_events, partial_event_name):
    """
    Returns the first event whose name contains partial_event_name in the
    provided function_events. These function_events should come from the
    profiler, i.e. the `function_events` of a completed
    torch.autograd.profiler.profile run.

    Args:
        function_events: function_events returned by the profiler.
        partial_event_name (str): partial key that the event was profiled with.
    """
    event = [event for event in function_events if partial_event_name in event.name][0]
    return event
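
# A minimal sketch (the profiled op and partial name are illustrative):
# profile a region, then look up the first matching event by partial name;
# note that the [0] lookup raises IndexError if nothing matched:
#
#     with torch.autograd.profiler.profile() as prof:
#         torch.add(torch.ones(1), torch.ones(1))
#     event = get_function_event(prof.function_events, "aten::add")
#     print(event.name, event.count)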