dynamic_rendezvous.py 40 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253
  1. # Copyright (c) Facebook, Inc. and its affiliates.
  2. # All rights reserved.
  3. #
  4. # This source code is licensed under the BSD-style license found in the
  5. # LICENSE file in the root directory of this source tree.
  6. import inspect
  7. import logging
  8. import os
  9. import pickle
  10. import socket
  11. import threading
  12. import time
  13. import weakref
  14. from abc import ABC, abstractmethod
  15. from dataclasses import dataclass
  16. from datetime import datetime, timedelta
  17. from enum import Enum
  18. from typing import Any, Callable, Dict, List, Optional, Set, Tuple, cast
  19. from torch.distributed import PrefixStore, Store
  20. from torch.distributed.elastic.events import (
  21. NodeState,
  22. construct_and_record_rdzv_event,
  23. )
  24. from .api import (
  25. RendezvousClosedError,
  26. RendezvousError,
  27. RendezvousHandler,
  28. RendezvousParameters,
  29. RendezvousStateError,
  30. RendezvousTimeoutError,
  31. )
  32. from .utils import _delay, _PeriodicTimer
  33. __all__ = ['RendezvousBackend', 'RendezvousTimeout', 'RendezvousSettings', 'DynamicRendezvousHandler', 'create_handler']
  34. log = logging.getLogger(__name__)
  35. def get_method_name(depth=2):
  36. if len(inspect.stack()) > depth:
  37. return inspect.stack()[depth].function
  38. return "no_method_name"
  39. Token = Any
  40. """Represents an opaque fencing token used by the rendezvous backend."""
  41. class RendezvousBackend(ABC):
  42. """Represents a backend that holds the rendezvous state."""
  43. @property
  44. @abstractmethod
  45. def name(self) -> str:
  46. """Gets the name of the backend."""
  47. @abstractmethod
  48. def get_state(self) -> Optional[Tuple[bytes, Token]]:
  49. """Gets the rendezvous state.
  50. Returns:
  51. A tuple of the encoded rendezvous state and its fencing token or
  52. ``None`` if no state is found in the backend.
  53. Raises:
  54. RendezvousConnectionError:
  55. The connection to the backend has failed.
  56. RendezvousStateError:
  57. The rendezvous state is corrupt.
  58. """
  59. @abstractmethod
  60. def set_state(
  61. self, state: bytes, token: Optional[Token] = None
  62. ) -> Optional[Tuple[bytes, Token, bool]]:
  63. """Sets the rendezvous state.
  64. The new rendezvous state is set conditionally:
  65. - If the specified ``token`` matches the fencing token stored in the
  66. backend, the state will be updated. The new state will be returned
  67. to the caller along with its fencing token.
  68. - If the specified ``token`` does not match the fencing token stored
  69. in the backend, the state won't be updated; instead the existing
  70. state along with its fencing token will be returned to the caller.
  71. - If the specified ``token`` is ``None``, the new state will be set
  72. only if there is no existing state in the backend. Either the new
  73. state or the existing state along with its fencing token will be
  74. returned to the caller.
  75. Args:
  76. state:
  77. The encoded rendezvous state.
  78. token:
  79. An optional fencing token that was retrieved by a previous call
  80. to :py:meth:`get_state` or ``set_state()``.
  81. Returns:
  82. A tuple of the serialized rendezvous state, its fencing token, and
  83. a boolean value indicating whether our set attempt succeeded.
  84. Raises:
  85. RendezvousConnectionError:
  86. The connection to the backend has failed.
  87. RendezvousStateError:
  88. The rendezvous state is corrupt.
  89. """
  90. class RendezvousTimeout:
  91. """Holds the timeout configuration of a rendezvous.
  92. Args:
  93. join:
  94. The time within which the rendezvous is expected to complete.
  95. last_call:
  96. An additional wait amount before completing the rendezvous once the
  97. rendezvous has the minimum number of required participants.
  98. close:
  99. The time within which the rendezvous is expected to close after a
  100. call to :py:meth:`RendezvousHandler.set_closed` or
  101. :py:meth:`RendezvousHandler.shutdown`.
  102. keep_alive:
  103. The time within which a keep-alive heartbeat is expected to
  104. complete.
  105. """
  106. _ZERO = timedelta(0)
  107. _DEFAULT_TIMEOUTS = {
  108. "join": timedelta(seconds=600),
  109. "last_call": timedelta(seconds=30),
  110. "close": timedelta(seconds=30),
  111. "heartbeat": timedelta(seconds=5),
  112. }
  113. _join: timedelta
  114. _last_call: timedelta
  115. _close: timedelta
  116. _heartbeat: timedelta
  117. def __init__(
  118. self,
  119. join: Optional[timedelta] = None,
  120. last_call: Optional[timedelta] = None,
  121. close: Optional[timedelta] = None,
  122. heartbeat: Optional[timedelta] = None,
  123. ) -> None:
  124. self._set_timeouts(join=join, last_call=last_call, close=close, heartbeat=heartbeat)
  125. @property
  126. def join(self) -> timedelta:
  127. """Gets the join timeout."""
  128. return self._join
  129. @property
  130. def last_call(self) -> timedelta:
  131. """Gets the last call timeout."""
  132. return self._last_call
  133. @property
  134. def close(self) -> timedelta:
  135. """Gets the close timeout."""
  136. return self._close
  137. @property
  138. def heartbeat(self) -> timedelta:
  139. """Gets the keep-alive heartbeat timeout."""
  140. return self._heartbeat
  141. def _set_timeouts(self, **timeouts: Optional[timedelta]):
  142. for name, timeout in timeouts.items():
  143. if timeout is None:
  144. timeout = self._DEFAULT_TIMEOUTS[name]
  145. if timeout <= self._ZERO:
  146. raise ValueError(f"The {name} timeout ({timeout}) must be positive.")
  147. setattr(self, "_" + name, timeout)
  148. @dataclass(repr=False, eq=False, frozen=True)
  149. class RendezvousSettings:
  150. """Holds the settings of the rendezvous.
  151. Attributes:
  152. run_id:
  153. The run id of the rendezvous.
  154. min_nodes:
  155. The minimum number of nodes to admit to the rendezvous.
  156. max_nodes:
  157. The maximum number of nodes to admit to the rendezvous.
  158. timeout:
  159. The timeout configuration of the rendezvous.
  160. keep_alive_interval:
  161. The amount of time a node waits before sending a heartbeat to keep
  162. it alive in the rendezvous.
  163. keep_alive_max_attempt:
  164. The maximum number of failed heartbeat attempts after which a node
  165. is considered dead.
  166. """
  167. run_id: str
  168. min_nodes: int
  169. max_nodes: int
  170. timeout: RendezvousTimeout
  171. keep_alive_interval: timedelta
  172. keep_alive_max_attempt: int
  173. @dataclass(eq=True, order=True, frozen=True)
  174. class _NodeDesc:
  175. """Describes a node in the rendezvous.
  176. Attributes:
  177. addr:
  178. The FQDN of the node or user specified local node address.
  179. pid:
  180. The id of the process in which the rendezvous handler runs.
  181. local_id:
  182. A process-wide unique id.
  183. """
  184. addr: str
  185. pid: int
  186. local_id: int
  187. def __repr__(self) -> str:
  188. return f"{self.addr}_{self.pid}_{self.local_id}"
  189. class _NodeDescGenerator:
  190. """Generates node descriptors.
  191. A node descriptor is a combination of an FQDN, a process id, and an auto-
  192. incremented integer that uniquely identifies a node in the rendezvous.
  193. """
  194. _lock: threading.Lock
  195. _local_id: int
  196. def __init__(self) -> None:
  197. self._lock = threading.Lock()
  198. # An integer that is incremented with each call to generate().
  199. self._local_id = 0
  200. def generate(self, local_addr: Optional[str] = None) -> _NodeDesc:
  201. # This method can be called by multiple threads concurrently; therefore,
  202. # we must increment the integer atomically.
  203. with self._lock:
  204. local_id = self._local_id
  205. self._local_id += 1
  206. return _NodeDesc(local_addr or socket.getfqdn(), os.getpid(), local_id)
  207. class _RendezvousState:
  208. """Holds the state of a rendezvous.
  209. Attributes:
  210. round:
  211. The current round of the rendezvous.
  212. complete:
  213. A boolean value indicating whether the current round of the
  214. rendezvous is complete.
  215. deadline:
  216. The time at which the current round of the rendezvous will be
  217. considered complete if it is still waiting for nodes to join.
  218. closed:
  219. A boolean value indicating whether the rendezvous is closed.
  220. participants:
  221. A dictionary of the participants and their corresponding ranks.
  222. wait_list:
  223. A set of nodes that are waiting to participate in the next round of
  224. the rendezvous.
  225. last_heartbeats:
  226. A dictionary containing each node's last heartbeat time.
  227. """
  228. round: int
  229. complete: bool
  230. deadline: Optional[datetime]
  231. closed: bool
  232. participants: Dict[_NodeDesc, int]
  233. wait_list: Set[_NodeDesc]
  234. last_heartbeats: Dict[_NodeDesc, datetime]
  235. def __init__(self) -> None:
  236. self.round = 0
  237. self.complete = False
  238. self.deadline = None
  239. self.closed = False
  240. self.participants = {}
  241. self.wait_list = set()
  242. self.last_heartbeats = {}
  243. def _remove_participant_epilogue(state: _RendezvousState, settings: RendezvousSettings) -> None:
  244. if state.complete:
  245. # If we do not have any participants left, move to the next round.
  246. if not state.participants:
  247. state.complete = False
  248. state.round += 1
  249. else:
  250. if len(state.participants) < settings.min_nodes:
  251. state.deadline = None
  252. class _RendezvousStateHolder(ABC):
  253. """Holds the shared rendezvous state synced with other nodes."""
  254. @property
  255. @abstractmethod
  256. def state(self) -> _RendezvousState:
  257. """Gets the local state."""
  258. @abstractmethod
  259. def sync(self) -> Optional[bool]:
  260. """Reads or writes the latest state.
  261. Returns:
  262. A boolean value indicating whether the local state, in case marked
  263. as dirty, was successfully synced with other nodes.
  264. """
  265. @abstractmethod
  266. def mark_dirty(self) -> None:
  267. """Marks the local state as dirty."""
  268. class _BackendRendezvousStateHolder(_RendezvousStateHolder):
  269. """Holds the rendezvous state synced with other nodes via a backend.
  270. Args:
  271. backend:
  272. The rendezvous backend to use.
  273. settings:
  274. The rendezvous settings.
  275. cache_duration:
  276. The amount of time, in seconds, to cache the last rendezvous state
  277. before requesting it from the backend again.
  278. """
  279. _backend: RendezvousBackend
  280. _state: _RendezvousState
  281. _settings: RendezvousSettings
  282. _cache_duration: int
  283. _token: Token
  284. _dirty: bool
  285. _last_sync_time: float
  286. _dead_nodes: List[_NodeDesc]
  287. def __init__(
  288. self,
  289. backend: RendezvousBackend,
  290. settings: RendezvousSettings,
  291. cache_duration: int = 1,
  292. ) -> None:
  293. self._backend = backend
  294. self._state = _RendezvousState()
  295. self._settings = settings
  296. self._cache_duration = cache_duration
  297. self._token = None
  298. self._dirty = False
  299. self._last_sync_time = -1
  300. self._dead_nodes = []
  301. def _record(self, message: str, node_state: NodeState = NodeState.RUNNING):
  302. construct_and_record_rdzv_event(
  303. name=f"{self.__class__.__name__}.{get_method_name()}",
  304. run_id=self._settings.run_id,
  305. message=message,
  306. node_state=node_state,
  307. )
  308. @property
  309. def state(self) -> _RendezvousState:
  310. """See base class."""
  311. return self._state
  312. def sync(self) -> Optional[bool]:
  313. """See base class."""
  314. state_bits: Optional[bytes] = None
  315. token = None
  316. has_set: Optional[bool]
  317. if self._dirty:
  318. has_set = False
  319. state_bits = pickle.dumps(self._state)
  320. set_response = self._backend.set_state(state_bits, self._token)
  321. if set_response is not None:
  322. state_bits, token, has_set = set_response
  323. else:
  324. has_set = None
  325. if self._cache_duration > 0:
  326. # Avoid overloading the backend if we are asked to retrieve the
  327. # state repeatedly. Try to serve the cached state.
  328. if self._last_sync_time >= max(time.monotonic() - self._cache_duration, 0):
  329. return None
  330. get_response = self._backend.get_state()
  331. if get_response is not None:
  332. state_bits, token = get_response
  333. if state_bits is not None:
  334. try:
  335. self._state = pickle.loads(state_bits)
  336. except pickle.PickleError as exc:
  337. raise RendezvousStateError(
  338. "The rendezvous state is corrupt. See inner exception for details."
  339. ) from exc
  340. else:
  341. self._state = _RendezvousState()
  342. if has_set and self._dead_nodes and log.isEnabledFor(logging.DEBUG):
  343. node_list = ", ".join(f"'{dead_node}'" for dead_node in self._dead_nodes)
  344. msg = (
  345. f"As part of the sync operation the node(s) {node_list} have been removed from the "
  346. f"rendezvous '{self._settings.run_id}' since they had no heartbeat."
  347. )
  348. self._record(message=msg)
  349. log.debug(msg)
  350. self._token = token
  351. self._dirty = False
  352. self._last_sync_time = time.monotonic()
  353. self._sanitize()
  354. return has_set
  355. def _sanitize(self) -> None:
  356. state = self._state
  357. expire_time = datetime.utcnow() - (
  358. self._settings.keep_alive_interval * self._settings.keep_alive_max_attempt
  359. )
  360. # Filter out the dead nodes.
  361. self._dead_nodes = [
  362. node
  363. for node, last_heartbeat in state.last_heartbeats.items()
  364. if last_heartbeat < expire_time
  365. ]
  366. participant_removed = False
  367. for dead_node in self._dead_nodes:
  368. del state.last_heartbeats[dead_node]
  369. try:
  370. del state.participants[dead_node]
  371. participant_removed = True
  372. except KeyError:
  373. pass
  374. try:
  375. state.wait_list.remove(dead_node)
  376. except KeyError:
  377. pass
  378. if participant_removed:
  379. # Common epilogue shared with the _remove_from_participants()
  380. # function of _DistributedRendezvousOpExecutor.
  381. _remove_participant_epilogue(state, self._settings)
  382. def mark_dirty(self) -> None:
  383. """See base class.
  384. If the local rendezvous state is dirty, the next sync call will try to
  385. write the changes back to the backend. However this attempt might fail
  386. if another node, which had the same state, also made changes and wrote
  387. them before us.
  388. """
  389. self._dirty = True
  390. class _Action(Enum):
  391. """Specifies the possible actions based on the state of the rendezvous."""
  392. KEEP_ALIVE = 1
  393. ADD_TO_PARTICIPANTS = 2
  394. ADD_TO_WAIT_LIST = 3
  395. REMOVE_FROM_PARTICIPANTS = 4
  396. REMOVE_FROM_WAIT_LIST = 5
  397. MARK_RENDEZVOUS_COMPLETE = 6
  398. MARK_RENDEZVOUS_CLOSED = 7
  399. SYNC = 8
  400. ERROR_CLOSED = 9
  401. ERROR_TIMEOUT = 10
  402. FINISH = 11
  403. class _RendezvousContext:
  404. """Holds the context of the rendezvous.
  405. Attributes:
  406. node:
  407. The node descriptor associated with the current rendezvous handler
  408. instance.
  409. state:
  410. The current state of the rendezvous.
  411. settings:
  412. The rendezvous settings.
  413. """
  414. node: _NodeDesc
  415. state: _RendezvousState
  416. settings: RendezvousSettings
  417. def __init__(
  418. self, node: _NodeDesc, state: _RendezvousState, settings: RendezvousSettings
  419. ) -> None:
  420. self.node = node
  421. self.state = state
  422. self.settings = settings
  423. class _RendezvousOpExecutor(ABC):
  424. """Executes rendezvous operations."""
  425. @abstractmethod
  426. def run(
  427. self,
  428. state_handler: Callable[[_RendezvousContext, float], _Action],
  429. deadline: float,
  430. ) -> None:
  431. """Executes a rendezvous operation.
  432. An operation is run inside a state machine and is expected to transition
  433. the rendezvous from one state to another.
  434. Args:
  435. state_handler:
  436. A callable that is expected to return the next state transition
  437. action based on the current state of the rendezvous.
  438. deadline:
  439. The time, in seconds, at which the operation will be considered
  440. timed-out.
  441. """
  442. class _DistributedRendezvousOpExecutor(_RendezvousOpExecutor):
  443. """Executes rendezvous operations using a shared state.
  444. Args:
  445. node:
  446. The node descriptor associated with the current rendezvous handler
  447. instance.
  448. state_holder:
  449. The ``RendezvousStateHolder`` to use to sync the rendezvous state
  450. with other nodes.
  451. settings:
  452. The rendezvous settings.
  453. """
  454. _node: _NodeDesc
  455. _state: _RendezvousState
  456. _state_holder: _RendezvousStateHolder
  457. _settings: RendezvousSettings
  458. def __init__(
  459. self,
  460. node: _NodeDesc,
  461. state_holder: _RendezvousStateHolder,
  462. settings: RendezvousSettings,
  463. ) -> None:
  464. self._node = node
  465. self._state_holder = state_holder
  466. self._settings = settings
  467. def _record(self, message: str, node_state: NodeState = NodeState.RUNNING) -> None:
  468. construct_and_record_rdzv_event(
  469. name=f"{self.__class__.__name__}.{get_method_name()}",
  470. run_id=self._settings.run_id,
  471. message=message,
  472. node_state=node_state,
  473. hostname=self._node.addr,
  474. pid=self._node.pid,
  475. local_id=self._node.local_id,
  476. )
  477. def run(
  478. self,
  479. state_handler: Callable[[_RendezvousContext, float], _Action],
  480. deadline: float,
  481. ) -> None:
  482. """See base class."""
  483. action = None
  484. while action != _Action.FINISH:
  485. # Reads or writes the latest rendezvous state shared by all nodes in
  486. # the rendezvous. Note that our local changes might get overridden
  487. # by another node if that node synced its changes before us.
  488. has_set = self._state_holder.sync()
  489. if has_set is not None:
  490. if has_set:
  491. msg = (
  492. f"The node '{self._node}' has successfully synced its local changes with "
  493. f"other nodes in the rendezvous '{self._settings.run_id}'."
  494. )
  495. else:
  496. msg = (
  497. f"The node '{self._node}' has a stale state and failed to sync its local "
  498. f"changes with other nodes in the rendezvous '{self._settings.run_id}'."
  499. )
  500. self._record(message=msg)
  501. log.debug(msg)
  502. self._state = self._state_holder.state
  503. ctx = _RendezvousContext(self._node, self._state, self._settings)
  504. # Determine the next action to take based on the current state of
  505. # the rendezvous.
  506. action = state_handler(ctx, deadline)
  507. if action == _Action.FINISH:
  508. continue
  509. if action == _Action.ERROR_CLOSED:
  510. raise RendezvousClosedError()
  511. if action == _Action.ERROR_TIMEOUT:
  512. raise RendezvousTimeoutError()
  513. if action == _Action.SYNC:
  514. # Delay the execution by one second to avoid overloading the
  515. # backend if we are asked to poll for state changes.
  516. _delay(seconds=1)
  517. else:
  518. if action == _Action.KEEP_ALIVE:
  519. self._keep_alive()
  520. elif action == _Action.ADD_TO_PARTICIPANTS:
  521. self._add_to_participants()
  522. elif action == _Action.ADD_TO_WAIT_LIST:
  523. self._add_to_wait_list()
  524. elif action == _Action.REMOVE_FROM_PARTICIPANTS:
  525. self._remove_from_participants()
  526. elif action == _Action.REMOVE_FROM_WAIT_LIST:
  527. self._remove_from_wait_list()
  528. elif action == _Action.MARK_RENDEZVOUS_COMPLETE:
  529. self._mark_rendezvous_complete()
  530. elif action == _Action.MARK_RENDEZVOUS_CLOSED:
  531. self._mark_rendezvous_closed()
  532. # Attempt to sync our changes back to other nodes.
  533. self._state_holder.mark_dirty()
  534. def _keep_alive(self) -> None:
  535. msg = (
  536. f"The node '{self._node}' updated its keep-alive heartbeat time for the rendezvous "
  537. f"'{self._settings.run_id}'. Pending sync."
  538. )
  539. self._record(message=msg)
  540. log.debug(msg)
  541. self._state.last_heartbeats[self._node] = datetime.utcnow()
  542. def _add_to_participants(self) -> None:
  543. msg = (
  544. f"The node '{self._node}' added itself to the participants of round "
  545. f"{self._state.round} of the rendezvous '{self._settings.run_id}'. Pending sync."
  546. )
  547. self._record(message=msg)
  548. log.debug(msg)
  549. state = self._state
  550. try:
  551. state.wait_list.remove(self._node)
  552. except KeyError:
  553. pass
  554. # The ranks of the participants will be set once the rendezvous is
  555. # complete.
  556. state.participants[self._node] = 0
  557. self._keep_alive()
  558. if len(state.participants) == self._settings.min_nodes:
  559. state.deadline = datetime.utcnow() + self._settings.timeout.last_call
  560. if len(state.participants) == self._settings.max_nodes:
  561. self._mark_rendezvous_complete()
  562. def _add_to_wait_list(self) -> None:
  563. msg = (
  564. f"The node '{self._node}' added itself to the wait list of round "
  565. f"{self._state.round + 1} of the rendezvous '{self._settings.run_id}'. Pending sync."
  566. )
  567. self._record(message=msg)
  568. log.debug(msg)
  569. self._state.wait_list.add(self._node)
  570. self._keep_alive()
  571. def _remove_from_participants(self) -> None:
  572. msg = (
  573. f"The node '{self._node}' removed itself from the participants of round "
  574. f"{self._state.round} of the rendezvous '{self._settings.run_id}'. Pending sync."
  575. )
  576. self._record(message=msg)
  577. log.debug(msg)
  578. state = self._state
  579. del state.participants[self._node]
  580. del state.last_heartbeats[self._node]
  581. # Common epilogue shared with the sanitizer() function of
  582. # _BackendRendezvousStateHolder.
  583. _remove_participant_epilogue(state, self._settings)
  584. def _remove_from_wait_list(self) -> None:
  585. msg = (
  586. f"The node '{self._node}' removed itself from the wait list of round "
  587. f"{self._state.round + 1} of the rendezvous '{self._settings.run_id}'. Pending sync."
  588. )
  589. self._record(message=msg)
  590. log.debug(msg)
  591. self._state.wait_list.remove(self._node)
  592. del self._state.last_heartbeats[self._node]
  593. def _mark_rendezvous_complete(self) -> None:
  594. msg = (
  595. f"The node '{self._node}' marked round {self._state.round} of the rendezvous "
  596. f"'{self._settings.run_id}' as complete. Pending sync."
  597. )
  598. self._record(message=msg, node_state=NodeState.SUCCEEDED)
  599. log.debug(msg)
  600. state = self._state
  601. state.complete = True
  602. state.deadline = None
  603. # Assign the ranks.
  604. for rank, node in enumerate(sorted(state.participants)):
  605. state.participants[node] = rank
  606. def _mark_rendezvous_closed(self) -> None:
  607. msg = (
  608. f"The node '{self._node}' marked the rendezvous '{self._settings.run_id}' as closed. "
  609. "Pending sync."
  610. )
  611. self._record(message=msg, node_state=NodeState.SUCCEEDED)
  612. log.debug(msg)
  613. self._state.closed = True
  614. def _should_keep_alive(ctx: _RendezvousContext) -> bool:
  615. """Determines whether a keep-alive heartbeat should be sent."""
  616. try:
  617. last_heartbeat = ctx.state.last_heartbeats[ctx.node]
  618. except KeyError:
  619. return False
  620. return last_heartbeat <= datetime.utcnow() - ctx.settings.keep_alive_interval
  621. class _RendezvousExitOp:
  622. """Represents a rendezvous exit operation."""
  623. def __call__(self, ctx: _RendezvousContext, deadline: float) -> _Action:
  624. if ctx.node in ctx.state.participants:
  625. if time.monotonic() > deadline:
  626. return _Action.ERROR_TIMEOUT
  627. return _Action.REMOVE_FROM_PARTICIPANTS
  628. return _Action.FINISH
  629. class _RendezvousJoinOp:
  630. """Represents a rendezvous join operation."""
  631. def __call__(self, ctx: _RendezvousContext, deadline: float) -> _Action:
  632. state = ctx.state
  633. # A closed rendezvous means that it no longer accepts new nodes.
  634. if state.closed:
  635. return _Action.ERROR_CLOSED
  636. is_participant = ctx.node in state.participants
  637. # If we are part of the rendezvous and it is already complete there is
  638. # no further action to take.
  639. if state.complete and is_participant:
  640. return _Action.FINISH
  641. now = time.monotonic()
  642. if now > deadline:
  643. rollback_period = 5 # 5 seconds
  644. # If we still have time to rollback (a short period on top of the
  645. # operation deadline), try to remove ourself from the rendezvous.
  646. # It is okay if we can't though as our keep-alive will eventually
  647. # expire.
  648. if now <= deadline + rollback_period:
  649. # If we are part of the rendezvous, it means we couldn't find
  650. # enough participants to complete it on time.
  651. if is_participant:
  652. return _Action.REMOVE_FROM_PARTICIPANTS
  653. # If we are in the wait list, it means we couldn't wait till the
  654. # next round of the rendezvous.
  655. if ctx.node in state.wait_list:
  656. return _Action.REMOVE_FROM_WAIT_LIST
  657. return _Action.ERROR_TIMEOUT
  658. if state.complete:
  659. # If we are here, it means we are not part of the rendezvous. In
  660. # case the rendezvous has capacity for additional participants add
  661. # ourself to the wait list for the next round.
  662. if len(state.participants) < ctx.settings.max_nodes:
  663. if ctx.node not in state.wait_list:
  664. return _Action.ADD_TO_WAIT_LIST
  665. elif is_participant:
  666. # If the rendezvous has enough number of participants including us,
  667. # check whether we have passed the rendezvous deadline. If yes,
  668. # complete it.
  669. if len(state.participants) >= ctx.settings.min_nodes:
  670. if cast(datetime, state.deadline) < datetime.utcnow():
  671. return _Action.MARK_RENDEZVOUS_COMPLETE
  672. else:
  673. # The rendezvous is not complete yet and we are not part of it. Try
  674. # to join.
  675. return _Action.ADD_TO_PARTICIPANTS
  676. if _should_keep_alive(ctx):
  677. return _Action.KEEP_ALIVE
  678. # At this point either the rendezvous is not complete, but we are part
  679. # of it, which means we have to wait for other participants to join; or
  680. # the rendezvous is complete, but we are not part of it, which means we
  681. # have to wait for the next round.
  682. return _Action.SYNC
  683. class _RendezvousCloseOp:
  684. """Represents a rendezvous close operation."""
  685. def __call__(self, ctx: _RendezvousContext, deadline: float) -> _Action:
  686. if ctx.state.closed:
  687. return _Action.FINISH
  688. if time.monotonic() > deadline:
  689. return _Action.ERROR_TIMEOUT
  690. return _Action.MARK_RENDEZVOUS_CLOSED
  691. class _RendezvousKeepAliveOp:
  692. """Represents a rendezvous keep-alive update operation."""
  693. def __call__(self, ctx: _RendezvousContext, deadline: float) -> _Action:
  694. if _should_keep_alive(ctx):
  695. if time.monotonic() > deadline:
  696. return _Action.ERROR_TIMEOUT
  697. return _Action.KEEP_ALIVE
  698. return _Action.FINISH
  699. class DynamicRendezvousHandler(RendezvousHandler):
  700. """Represents a handler that sets up a rendezvous among a set of nodes."""
  701. # Static
  702. _node_desc_generator = _NodeDescGenerator()
  703. _this_node: _NodeDesc
  704. _settings: RendezvousSettings
  705. _backend_name: str
  706. _store: Store
  707. _state_holder: _RendezvousStateHolder
  708. _op_executor: _RendezvousOpExecutor
  709. _heartbeat_lock: threading.Lock
  710. _keep_alive_timer: Optional[_PeriodicTimer]
  711. @classmethod
  712. def from_backend(
  713. cls,
  714. run_id: str,
  715. store: Store,
  716. backend: RendezvousBackend,
  717. min_nodes: int,
  718. max_nodes: int,
  719. local_addr: Optional[str] = None,
  720. timeout: Optional[RendezvousTimeout] = None,
  721. ):
  722. """Creates a new :py:class:`DynamicRendezvousHandler`.
  723. Args:
  724. run_id:
  725. The run id of the rendezvous.
  726. store:
  727. The C10d store to return as part of the rendezvous.
  728. backend:
  729. The backend to use to hold the rendezvous state.
  730. min_nodes:
  731. The minimum number of nodes to admit to the rendezvous.
  732. max_nodes:
  733. The maximum number of nodes to admit to the rendezvous.
  734. local_addr:
  735. The local node adress.
  736. timeout:
  737. The timeout configuration of the rendezvous.
  738. """
  739. # We associate each handler instance with a unique node descriptor.
  740. node = cls._node_desc_generator.generate(local_addr)
  741. settings = RendezvousSettings(
  742. run_id,
  743. min_nodes,
  744. max_nodes,
  745. timeout or RendezvousTimeout(),
  746. keep_alive_interval=timedelta(seconds=5),
  747. keep_alive_max_attempt=3,
  748. )
  749. state_holder = _BackendRendezvousStateHolder(backend, settings)
  750. return cls(node, settings, backend.name, store, state_holder)
  751. def __init__(
  752. self,
  753. node: _NodeDesc,
  754. settings: RendezvousSettings,
  755. backend_name: str,
  756. store: Store,
  757. state_holder: _RendezvousStateHolder,
  758. ) -> None:
  759. if not settings.run_id:
  760. raise ValueError("The run id must be a non-empty string.")
  761. if settings.min_nodes < 1:
  762. raise ValueError(
  763. f"The minimum number of nodes ({settings.min_nodes}) must be greater than zero."
  764. )
  765. if settings.max_nodes < settings.min_nodes:
  766. raise ValueError(
  767. f"The maximum number of nodes ({settings.max_nodes}) must be greater than or equal "
  768. f"to the minimum number of nodes ({settings.min_nodes})."
  769. )
  770. self._this_node = node
  771. self._settings = settings
  772. self._backend_name = backend_name
  773. self._store = store
  774. self._state_holder = state_holder
  775. self._op_executor = _DistributedRendezvousOpExecutor(
  776. self._this_node, self._state_holder, self._settings
  777. )
  778. self._heartbeat_lock = threading.Lock()
  779. self._keep_alive_timer = None
  780. def _record(
  781. self,
  782. message: str,
  783. node_state: NodeState = NodeState.RUNNING,
  784. rank: Optional[int] = None,
  785. ) -> None:
  786. construct_and_record_rdzv_event(
  787. name=f"{self.__class__.__name__}.{get_method_name()}",
  788. run_id=self._settings.run_id,
  789. message=message,
  790. node_state=node_state,
  791. hostname=self._this_node.addr,
  792. pid=self._this_node.pid,
  793. local_id=self._this_node.local_id,
  794. rank=rank,
  795. )
  796. @property
  797. def settings(self) -> RendezvousSettings:
  798. """Gets the settings of the rendezvous."""
  799. return self._settings
  800. def get_backend(self) -> str:
  801. """See base class."""
  802. return self._backend_name
  803. def next_rendezvous(self) -> Tuple[Store, int, int]:
  804. """See base class."""
  805. msg = (
  806. f"The node '{self._this_node}' attempts to join the next round of the rendezvous "
  807. f"'{self._settings.run_id}'."
  808. )
  809. self._record(message=msg)
  810. log.info(msg)
  811. try:
  812. self._stop_heartbeats()
  813. # Delay the execution for a small random amount of time if this is our
  814. # first run. This will slightly skew the rendezvous attempts across the
  815. # nodes and reduce the load on the backend.
  816. if self._state_holder.state.round == 0:
  817. _delay(seconds=(0, 0.3))
  818. exit_op = _RendezvousExitOp()
  819. join_op = _RendezvousJoinOp()
  820. deadline = self._get_deadline(self._settings.timeout.join)
  821. self._op_executor.run(exit_op, deadline)
  822. self._op_executor.run(join_op, deadline)
  823. self._start_heartbeats()
  824. rank, world_size = self._get_world()
  825. store = self._get_store()
  826. except Exception as e:
  827. self._record(
  828. message=f"{type(e).__name__}: {str(e)}",
  829. node_state=NodeState.FAILED,
  830. )
  831. raise
  832. msg = (
  833. f"The node '{self._this_node}' has joined round {self._state_holder.state.round} of "
  834. f"the rendezvous '{self._settings.run_id}' as rank {rank} in a world of size "
  835. f"{world_size}."
  836. )
  837. self._record(message=msg, rank=rank)
  838. log.info(msg)
  839. return store, rank, world_size
  840. def is_closed(self) -> bool:
  841. """See base class."""
  842. try:
  843. with self._heartbeat_lock:
  844. self._state_holder.sync()
  845. return self._state_holder.state.closed
  846. except Exception as e:
  847. self._record(
  848. message=f"{type(e).__name__}: {str(e)}",
  849. node_state=NodeState.FAILED,
  850. )
  851. raise
  852. def set_closed(self) -> None:
  853. """See base class."""
  854. try:
  855. with self._heartbeat_lock:
  856. self._close()
  857. except Exception as e:
  858. self._record(
  859. message=f"{type(e).__name__}: {str(e)}",
  860. node_state=NodeState.FAILED,
  861. )
  862. raise
  863. def num_nodes_waiting(self) -> int:
  864. """See base class."""
  865. try:
  866. with self._heartbeat_lock:
  867. self._state_holder.sync()
  868. return len(self._state_holder.state.wait_list)
  869. except Exception as e:
  870. self._record(
  871. message=f"{type(e).__name__}: {str(e)}",
  872. node_state=NodeState.FAILED,
  873. )
  874. raise
  875. def get_run_id(self) -> str:
  876. """See base class."""
  877. return self._settings.run_id
  878. def shutdown(self) -> bool:
  879. """See base class."""
  880. self._stop_heartbeats()
  881. try:
  882. self._close()
  883. return True
  884. except RendezvousError as ex:
  885. msg = (
  886. f"The node '{self._this_node}' has failed to shutdown the rendezvous "
  887. f"'{self._settings.run_id}' due to an error of type {type(ex).__name__}."
  888. )
  889. self._record(message=msg, node_state=NodeState.FAILED)
  890. log.warning(msg)
  891. return False
  892. except Exception as e:
  893. self._record(
  894. message=f"{type(e).__name__}: {str(e)}",
  895. node_state=NodeState.FAILED,
  896. )
  897. raise
  898. def _close(self) -> None:
  899. op = _RendezvousCloseOp()
  900. deadline = self._get_deadline(self._settings.timeout.close)
  901. self._op_executor.run(op, deadline)
  902. msg = f"The node '{self._this_node}' has closed the rendezvous '{self._settings.run_id}'."
  903. self._record(message=msg, node_state=NodeState.SUCCEEDED)
  904. log.info(msg)
  905. @staticmethod
  906. def _keep_alive_weak(weak_self) -> None:
  907. self = weak_self()
  908. if self is not None:
  909. self._keep_alive()
  910. def _keep_alive(self) -> None:
  911. self._heartbeat_lock.acquire()
  912. op = _RendezvousKeepAliveOp()
  913. deadline = self._get_deadline(self._settings.timeout.heartbeat)
  914. try:
  915. self._op_executor.run(op, deadline)
  916. msg = (
  917. f"The node '{self._this_node}' has sent a keep-alive heartbeat to the rendezvous "
  918. f"'{self._settings.run_id}'."
  919. )
  920. self._record(message=msg)
  921. log.debug(msg)
  922. except RendezvousError as ex:
  923. msg = (
  924. f"The node '{self._this_node}' has failed to send a keep-alive heartbeat to the "
  925. f"rendezvous '{self._settings.run_id}' due to an error of type {type(ex).__name__}."
  926. )
  927. self._record(message=msg, node_state=NodeState.FAILED)
  928. log.warning(msg)
  929. finally:
  930. self._heartbeat_lock.release()
  931. def _start_heartbeats(self) -> None:
  932. self._keep_alive_timer = _PeriodicTimer(
  933. self._settings.keep_alive_interval, self._keep_alive_weak, weakref.ref(self)
  934. )
  935. self._keep_alive_timer.set_name(f"RendezvousKeepAliveTimer_{self._this_node.local_id}")
  936. self._keep_alive_timer.start()
  937. def _stop_heartbeats(self) -> None:
  938. if self._keep_alive_timer is None:
  939. return
  940. self._keep_alive_timer.cancel()
  941. def _get_world(self) -> Tuple[int, int]:
  942. state = self._state_holder.state
  943. return state.participants[self._this_node], len(state.participants)
  944. def _get_store(self) -> Store:
  945. key_prefix = f"torch.rendezvous.{self._settings.run_id}.{self._state_holder.state.round}"
  946. return PrefixStore(key_prefix, self._store)
  947. def _get_deadline(self, timeout: timedelta) -> float:
  948. return time.monotonic() + timeout.total_seconds()
  949. def _get_timeout(params: RendezvousParameters, key: str) -> Optional[timedelta]:
  950. timeout = params.get_as_int(key + "_timeout")
  951. if timeout is None:
  952. return None
  953. return timedelta(seconds=timeout)
  954. def create_handler(
  955. store: Store, backend: RendezvousBackend, params: RendezvousParameters
  956. ) -> DynamicRendezvousHandler:
  957. """Creates a new :py:class:`DynamicRendezvousHandler` from the specified
  958. parameters.
  959. Args:
  960. store:
  961. The C10d store to return as part of the rendezvous.
  962. backend:
  963. The backend to use to hold the rendezvous state.
  964. +-------------------+------------------------------------------------------+
  965. | Parameter | Description |
  966. +===================+======================================================+
  967. | join_timeout | The total time, in seconds, within which the |
  968. | | rendezvous is expected to complete. Defaults to 600 |
  969. | | seconds. |
  970. +-------------------+------------------------------------------------------+
  971. | last_call_timeout | An additional wait amount, in seconds, before |
  972. | | completing the rendezvous once the minimum number of |
  973. | | nodes has been reached. Defaults to 30 seconds. |
  974. +-------------------+------------------------------------------------------+
  975. | close_timeout | The time, in seconds, within which the rendezvous is |
  976. | | expected to close after a call to |
  977. | | :py:meth:`RendezvousHandler.set_closed` or |
  978. | | :py:meth:`RendezvousHandler.shutdown`. Defaults to |
  979. | | 30 seconds. |
  980. +-------------------+------------------------------------------------------+
  981. """
  982. try:
  983. timeout = RendezvousTimeout(
  984. _get_timeout(params, "join"),
  985. _get_timeout(params, "last_call"),
  986. _get_timeout(params, "close"),
  987. )
  988. return DynamicRendezvousHandler.from_backend(
  989. params.run_id,
  990. store,
  991. backend,
  992. params.min_nodes,
  993. params.max_nodes,
  994. params.local_addr,
  995. timeout,
  996. )
  997. except Exception as e:
  998. construct_and_record_rdzv_event(
  999. message=f"{type(e).__name__}: {str(e)}",
  1000. run_id=params.run_id,
  1001. node_state=NodeState.FAILED,
  1002. )
  1003. raise