12345678910111213141516171819202122232425262728293031323334353637 |
- from contextlib import contextmanager
- from typing import cast
- import logging
- from . import api
- from . import TensorPipeAgent
- logger = logging.getLogger(__name__)
@contextmanager
def _group_membership_management(store, name, is_join):
    """Serialize a rank's join/leave operation through a token kept in ``store``.

    The context manager repeatedly tries to swap the shared
    ``"RpcGroupManagementToken"`` entry from ``""`` to a token derived from
    ``name`` and the operation kind. Once the swap succeeds, the wrapped
    ``with`` body runs; afterwards the shared entry is reset to ``""`` and a
    key named after this caller's token is set to ``"Done"``.

    Args:
        store: Object providing ``compare_set(key, expected, desired)``
            (returning bytes), ``set(key, value)`` and ``wait(keys)``.
            Presumably a ``torch.distributed`` store; confirm against callers.
        name: Identifier embedded in this caller's token.
        is_join: ``True`` for a join operation, ``False`` for a leave.

    Raises:
        RuntimeError: Re-raised from ``store.wait`` (after logging) when
            waiting on the current token holder fails.
    """
    token_key = "RpcGroupManagementToken"
    join_or_leave = "join" if is_join else "leave"
    my_token = f"Token_for_{name}_{join_or_leave}"

    # Acquire: loop until compare_set reports that our token is the stored one.
    while True:
        returned = store.compare_set(token_key, "", my_token).decode()
        if returned == my_token:
            break
        # Someone else holds the token; block on the key named after their
        # token, which the holder sets to "Done" when it finishes.
        try:
            store.wait([returned])
        except RuntimeError:
            logger.error(
                "Group membership token %s timed out waiting for %s to be released.",
                my_token,
                returned,
            )
            raise

    try:
        # Yield to the code this context manager wraps (the critical section).
        yield
    finally:
        # Always release, even if the wrapped body raised; otherwise the token
        # would stay held and every other rank would block until it timed out.
        store.set(token_key, "")
        # Others wait for this key to be set before they retry.
        store.set(my_token, "Done")
def _update_group_membership(worker_info, my_devices, reverse_device_map, is_join):
    """Forward a group-membership update to the current RPC agent.

    Looks up the agent through ``api._get_current_rpc_agent()``, treats it as
    a ``TensorPipeAgent`` for type-checking purposes, and calls its
    ``_update_group_membership`` method with the arguments unchanged.

    Args:
        worker_info: Passed through to the agent as-is.
        my_devices: Passed through to the agent as-is.
        reverse_device_map: Passed through to the agent as-is.
        is_join: Passed through to the agent as-is.

    Returns:
        Whatever the agent's ``_update_group_membership`` returns.
    """
    current_agent = api._get_current_rpc_agent()
    # typing.cast only informs the type checker; it performs no runtime check.
    tensorpipe_agent = cast(TensorPipeAgent, current_agent)
    return tensorpipe_agent._update_group_membership(
        worker_info,
        my_devices,
        reverse_device_map,
        is_join,
    )
|