_utils.py 1.5 KB

12345678910111213141516171819202122232425262728293031323334353637
  1. from contextlib import contextmanager
  2. from typing import cast
  3. import logging
  4. from . import api
  5. from . import TensorPipeAgent
  6. logger = logging.getLogger(__name__)
  7. @contextmanager
  8. def _group_membership_management(store, name, is_join):
  9. token_key = "RpcGroupManagementToken"
  10. join_or_leave = "join" if is_join else "leave"
  11. my_token = f"Token_for_{name}_{join_or_leave}"
  12. while True:
  13. # Retrieve token from store to signal start of rank join/leave critical section
  14. returned = store.compare_set(token_key, "", my_token).decode()
  15. if returned == my_token:
  16. # Yield to the function this context manager wraps
  17. yield
  18. # Finished, now exit and release token
  19. # Update from store to signal end of rank join/leave critical section
  20. store.set(token_key, "")
  21. # Other will wait for this token to be set before they execute
  22. store.set(my_token, "Done")
  23. break
  24. else:
  25. # Store will wait for the token to be released
  26. try:
  27. store.wait([returned])
  28. except RuntimeError:
  29. logger.error(f"Group membership token {my_token} timed out waiting for {returned} to be released.")
  30. raise
  31. def _update_group_membership(worker_info, my_devices, reverse_device_map, is_join):
  32. agent = cast(TensorPipeAgent, api._get_current_rpc_agent())
  33. ret = agent._update_group_membership(worker_info, my_devices, reverse_device_map, is_join)
  34. return ret