api.py
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import abc
import functools
import json
import os
import signal
import socket
import time
import traceback
import warnings
from contextlib import closing
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import torch.distributed.elastic.rendezvous as rdzv
import torch.distributed.elastic.utils.store as store_util
from torch.distributed import Store
from torch.distributed.elastic.events import Event, EventSource, record
from torch.distributed.elastic.metrics import prof, put_metric
from torch.distributed.elastic.multiprocessing import (
    ProcessFailure,
    SignalException,
    Std,
)
from torch.distributed.elastic.utils.logging import get_logger

__all__ = ['WorkerSpec', 'Worker', 'WorkerState', 'WorkerGroup', 'RunResult', 'ElasticAgent', 'SimpleElasticAgent']

_TERMINAL_STATE_SYNC_ID = "torchelastic/agent/terminal_state"

DEFAULT_ROLE = "default"

log = get_logger()


@dataclass
class WorkerSpec:
    """
    Contains blueprint information about a particular type of worker.
    For a given role, there must only exist a single worker spec.
    The worker spec is expected to be homogeneous across all nodes (machines),
    that is, each node runs the same number of workers for a particular spec.

    Args:
        role: user-defined role for the workers with this spec
        local_world_size: number of local workers to run
        fn: (deprecated, use entrypoint instead)
        entrypoint: worker function or command
        args: arguments to pass to ``entrypoint``
        rdzv_handler: handles rdzv for this set of workers
        max_restarts: number of max retries for the workers
        monitor_interval: monitor status of workers every ``n`` seconds
        master_port: fixed port to run the c10d store on rank 0;
                     if not specified, a random free port is chosen
        master_addr: fixed master_addr to run the c10d store on rank 0;
                     if not specified, the hostname of the agent on rank 0 is used
        local_addr: address of the local node, if any; used as the master address
                    on the rank 0 agent when ``master_addr`` is not set
        redirects: redirect std streams to a file,
                   selectively redirect for a particular
                   local rank by passing a map
        tee: tees the specified std stream(s) to console + file,
             selectively tee for a particular local rank by passing a map,
             takes precedence over ``redirects`` settings.
    """

    role: str
    local_world_size: int

    rdzv_handler: rdzv.RendezvousHandler
    fn: Optional[Callable] = None
    # TODO @kiuk - make entrypoint a required field
    entrypoint: Union[Callable, str, None] = None
    args: Tuple = ()
    max_restarts: int = 3
    monitor_interval: float = 30.0
    master_port: Optional[int] = None
    master_addr: Optional[str] = None
    local_addr: Optional[str] = None
    redirects: Union[Std, Dict[int, Std]] = Std.NONE
    tee: Union[Std, Dict[int, Std]] = Std.NONE

    def __post_init__(self):
        assert self.local_world_size > 0
        assert self.monitor_interval > 0

        if self.fn:
            warnings.warn(
                "WorkerSpec.fn will be deprecated,"
                " please use WorkerSpec.entrypoint instead",
                category=DeprecationWarning,
            )
            self.entrypoint = self.fn
        assert self.entrypoint

    def get_entrypoint_name(self):
        """
        If the entrypoint is a function (e.g. ``Callable``) returns its ``__qualname__``,
        else if the entrypoint is a binary (e.g. ``str``), returns the binary name.
        """
        if isinstance(self.entrypoint, str):
            return os.path.basename(self.entrypoint)
        else:
            assert self.entrypoint is not None
            return self.entrypoint.__qualname__
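

# A minimal construction sketch (comments only, not executed on import). It shows
# how a ``WorkerSpec`` for a function entrypoint might be put together; the
# rendezvous handler below is a hypothetical placeholder that would normally come
# from one of the rendezvous backends in ``torch.distributed.elastic.rendezvous``.
#
#   def trainer(a, b):
#       return a + b
#
#   spec = WorkerSpec(
#       role="trainer",
#       local_world_size=4,              # 4 workers per agent/node
#       entrypoint=trainer,
#       args=(1, 2),
#       rdzv_handler=my_rdzv_handler,    # hypothetical handler instance
#       max_restarts=3,
#       monitor_interval=30.0,
#   )
#   assert spec.get_entrypoint_name() == "trainer"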


class Worker:
    """
    Represents a worker instance. Contrast this with ``WorkerSpec`` that
    represents the specifications of a worker. A ``Worker`` is created from
    a ``WorkerSpec``. A ``Worker`` is to a ``WorkerSpec`` as an object is to
    a class.

    The ``id`` of the worker is interpreted
    by the specific implementation of ``ElasticAgent``. For a local
    agent, it could be the ``pid (int)`` of the worker, for a remote
    agent it could be encoded as ``host:port (string)``.

    Args:
        id (Any): uniquely identifies a worker (interpreted by the agent)
        local_rank (int): local rank of the worker
        global_rank (int): global rank of the worker
        role_rank (int): rank of the worker across all workers that have the same role
        world_size (int): number of workers (globally)
        role_world_size (int): number of workers that have the same role
    """

    __slots__ = [
        "id",
        "local_rank",
        "global_rank",
        "role_rank",
        "world_size",
        "role_world_size",
    ]

    def __init__(
        self,
        local_rank: int,
        global_rank: int = -1,
        role_rank: int = -1,
        world_size: int = -1,
        role_world_size: int = -1,
    ):
        # unique identifier for this worker
        self.id: Any = None

        # rank of the worker among workers with the same role being monitored
        # by the same ``agent`` instance.
        self.local_rank: int = local_rank

        # rank of the worker among all the workers across all roles
        # across all ``agent`` instances.
        # Global rank is not stable between re-rendezvous.
        self.global_rank: int = global_rank

        # rank of the worker among all the workers with the same role
        # across all ``agent`` instances.
        # Role rank is not stable between re-rendezvous.
        self.role_rank: int = role_rank

        # total number of workers (globally). Due to elasticity
        # the world size may change between re-rendezvous.
        self.world_size: int = world_size

        # total number of workers that share the same role. Due to elasticity
        # the role world size may change between re-rendezvous.
        self.role_world_size: int = role_world_size

    def __str__(self):
        return (
            f"local_rank={self.local_rank},global_rank={self.global_rank}"
            f",role_rank={self.role_rank},world_size={self.world_size}"
            f",role_world_size={self.role_world_size}"
        )

    def __repr__(self):
        return str(self)


class WorkerState(str, Enum):
    """
    State of the ``WorkerGroup``. Workers in a worker group change state as a unit.
    If a single worker in a worker group fails the entire set is considered
    failed::

      UNKNOWN - agent lost track of worker group state, unrecoverable
      INIT - worker group object created not yet started
      HEALTHY - workers running and healthy
      UNHEALTHY - workers running and unhealthy
      STOPPED - workers stopped (interrupted) by the agent
      SUCCEEDED - workers finished running (exit 0)
      FAILED - workers failed to successfully finish (exit !0)

    A worker group starts from an initial ``INIT`` state,
    then progresses to ``HEALTHY`` or ``UNHEALTHY`` states,
    and finally reaches a terminal ``SUCCEEDED`` or ``FAILED`` state.

    Worker groups can be interrupted and temporarily put into ``STOPPED`` state
    by the agent. Workers in ``STOPPED`` state are scheduled to be restarted
    in the near future by the agent. Some examples of workers being put into
    ``STOPPED`` state are:

    1. Worker group failure|unhealthy observed
    2. Membership change detected

    When an action (start, stop, rdzv, retry, etc.) on a worker group fails
    and results in the action being partially applied to the worker group,
    the state will be ``UNKNOWN``. Typically this happens on uncaught/unhandled
    exceptions during state change events on the agent. The agent is not
    expected to recover worker groups in ``UNKNOWN`` state and is better off
    self-terminating and allowing the job manager to retry the node.
    """

    UNKNOWN = "UNKNOWN"
    INIT = "INIT"
    HEALTHY = "HEALTHY"
    UNHEALTHY = "UNHEALTHY"
    STOPPED = "STOPPED"
    SUCCEEDED = "SUCCEEDED"
    FAILED = "FAILED"

    @staticmethod
    def is_running(state: "WorkerState") -> bool:
        """
        Returns:
            True if the worker state represents workers still running
            (e.g. that the process exists but not necessarily healthy).
        """
        return state in {WorkerState.HEALTHY, WorkerState.UNHEALTHY}


class WorkerGroup:
    """
    Represents the set of ``Worker`` instances for the given ``WorkerSpec``
    managed by ``ElasticAgent``. Whether the worker group contains cross
    instance workers or not depends on the implementation of the agent.
    """

    __slots__ = ["spec", "workers", "store", "group_rank", "group_world_size", "state"]

    def __init__(self, spec: WorkerSpec):
        self.spec = spec
        self.workers = [Worker(local_rank=i) for i in range(self.spec.local_world_size)]

        # assigned after rdzv
        self.store = None
        self.group_rank = None
        self.group_world_size = None

        self.state = WorkerState.INIT


class _RoleInstanceInfo:
    """
    The class is used by the agent to exchange information with other agents.
    The information is used to determine the rank of the workers that the agent
    manages in heterogeneous environments, where different agents can have
    a different number of workers.
    """

    __slots__ = ["role", "rank", "local_world_size"]

    def __init__(self, role: str, rank: int, local_world_size: int):
        r"""
        Args:
            role (str): user-defined role for the workers with this spec
            rank (int): the rank of the agent
            local_world_size (int): number of local workers to run
        """

        self.role = role
        self.rank = rank
        self.local_world_size = local_world_size

    def serialize(self) -> bytes:
        dict_data = {
            "role": self.role,
            "rank": self.rank,
            "local_world_size": self.local_world_size,
        }
        return json.dumps(dict_data).encode(encoding="UTF-8")

    @staticmethod
    def deserialize(data: bytes):
        dict_data = json.loads(data.decode(encoding="UTF-8"))
        return _RoleInstanceInfo(
            dict_data["role"], dict_data["rank"], dict_data["local_world_size"]
        )

    @staticmethod
    def compare(obj1, obj2) -> int:
        if obj1.role == obj2.role:
            return obj1.rank - obj2.rank
        elif obj1.role > obj2.role:
            return 1
        else:
            return -1

    @staticmethod
    def find_role_boundaries(roles_infos: List, role: str) -> Tuple[int, int]:
        start_idx, end_idx = -1, -1
        for idx, role_info in enumerate(roles_infos):
            if role_info.role == role:
                if start_idx == -1:
                    start_idx = idx
                end_idx = idx
        return (start_idx, end_idx)
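

# Comment-only illustration of how agents use _RoleInstanceInfo: each agent
# serializes its own info, all infos are gathered via the store, then sorted by
# (role, rank) with ``compare`` before ranks are assigned.
#
#   info = _RoleInstanceInfo(role="trainer", rank=1, local_world_size=8)
#   blob = info.serialize()                      # JSON-encoded bytes
#   same = _RoleInstanceInfo.deserialize(blob)   # round-trips the three fields
#
#   a = _RoleInstanceInfo("ps", 0, 1)
#   b = _RoleInstanceInfo("trainer", 0, 8)
#   _RoleInstanceInfo.compare(a, b)              # negative since "ps" < "trainer"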


@dataclass
class RunResult:
    """
    Results returned by the worker executions. Run results follow an "all-or-nothing" policy
    where the run is successful if and only if ALL local workers managed by this agent
    complete successfully.

    If the result is successful (e.g. ``is_failed() = False``) then the ``return_values``
    field contains the outputs (return values) of the workers managed by THIS agent mapped
    by their GLOBAL ranks. That is, ``result.return_values[0]`` is the return value of
    global rank 0.

    .. note:: ``return_values`` are only meaningful when the worker entrypoint
              is a function. Workers specified as a binary entrypoint do not canonically
              have a return value and the ``return_values`` field is meaningless and
              may be empty.

    If ``is_failed()`` returns ``True`` then the ``failures`` field contains the
    failure information, again, mapped by the GLOBAL rank of the worker that failed.

    The keys in ``return_values`` and ``failures`` are mutually exclusive, that is,
    a worker's final state can only be one of: succeeded, failed. Workers intentionally
    terminated by the agent according to the agent's restart policy are not represented
    in either ``return_values`` or ``failures``.
    """

    state: WorkerState
    return_values: Dict[int, Any] = field(default_factory=dict)
    failures: Dict[int, ProcessFailure] = field(default_factory=dict)

    def is_failed(self) -> bool:
        return self.state == WorkerState.FAILED
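

# Comment-only sketch of consuming a RunResult returned by ``ElasticAgent.run()``.
# Keys are GLOBAL worker ranks (see the docstring above); ``agent`` is assumed to
# be some concrete ElasticAgent instance.
#
#   result = agent.run()
#   if result.is_failed():
#       for global_rank, failure in result.failures.items():
#           print(global_rank, failure.exitcode)
#   else:
#       # only meaningful for function entrypoints
#       print(result.return_values.get(0))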


def _get_socket_with_port() -> socket.socket:
    """
    Returns a free port on localhost that is "reserved" by binding a temporary
    socket on it. Close the socket before passing the port to the entity
    that requires it. Usage example

    ::

        sock = _get_socket_with_port()
        with closing(sock):
            port = sock.getsockname()[1]
        sock.close()
        # there is still a race-condition that some other process
        # may grab this port before func() runs
        func(port)
    """

    addrs = socket.getaddrinfo(
        host="localhost", port=None, family=socket.AF_UNSPEC, type=socket.SOCK_STREAM
    )
    for addr in addrs:
        family, type, proto, _, _ = addr
        s = socket.socket(family, type, proto)
        try:
            s.bind(("localhost", 0))
            s.listen(0)
            return s
        except OSError as e:
            s.close()
            log.info("Socket creation attempt failed.", exc_info=e)
    raise RuntimeError("Failed to create a socket")


def _get_fq_hostname() -> str:
    return socket.getfqdn(socket.gethostname())


class ElasticAgent(abc.ABC):
    """
    Agent process responsible for managing one or more worker processes.
    The worker processes are assumed to be regular distributed PyTorch scripts.
    When the worker process is created by the agent, the agent provides the
    necessary information for the worker processes to properly initialize
    a torch process group.

    The exact deployment topology and ratio of agent-to-worker is dependent
    on the specific implementation of the agent and the user's job placement
    preferences. For instance, to run a distributed training job on GPU with
    8 trainers (one per GPU) one can:

    1. Use 8 x single GPU instances, place an agent per instance, managing
       1 worker per agent.
    2. Use 4 x double GPU instances, place an agent per instance, managing
       2 workers per agent.
    3. Use 2 x quad GPU instances, place an agent per instance, managing
       4 workers per agent.
    4. Use 1 x 8 GPU instance, place an agent per instance, managing
       8 workers per agent.

    Usage
    ::

        group_result = agent.run()
        if group_result.is_failed():
            # workers failed
            failure = group_result.failures[0]
            log.exception(f"worker 0 failed with exit code: {failure.exitcode}")
        else:
            return group_result.return_values[0]  # return rank 0's results
    """

    @abc.abstractmethod
    def run(self, role: str = DEFAULT_ROLE) -> RunResult:
        """
        Runs the agent, retrying the worker group on failures up to
        ``max_restarts``.

        Returns:
            The result of the execution, containing the return values or
            failure details for each worker mapped by the worker's global rank.

        Raises:
            Exception - any other failures NOT related to worker process
        """
        raise NotImplementedError()

    @abc.abstractmethod
    def get_worker_group(self, role: str = DEFAULT_ROLE) -> WorkerGroup:
        """
        Returns:
            The ``WorkerGroup`` for the given ``role``.
            Note that the worker group is a mutable object and hence in a
            multi-threaded/process environment it may change state.
            Implementors are encouraged (but not required) to return
            a defensive read-only copy.
        """
        raise NotImplementedError()


class SimpleElasticAgent(ElasticAgent):
    """
    An ``ElasticAgent`` that manages workers (``WorkerGroup``)
    for a single ``WorkerSpec`` (e.g. one particular type of worker role).
    """

    def __init__(self, spec: WorkerSpec, exit_barrier_timeout: float = 300):
        self._worker_group = WorkerGroup(spec)
        self._remaining_restarts = self._worker_group.spec.max_restarts
        self._store = None
        self._exit_barrier_timeout = exit_barrier_timeout
        self._total_execution_time = 0

    def get_worker_group(self, role: str = DEFAULT_ROLE) -> WorkerGroup:
        return self._worker_group

    @abc.abstractmethod
    def _start_workers(self, worker_group: WorkerGroup) -> Dict[int, Any]:
        r"""
        Starts ``worker_group.spec.local_world_size`` workers according to the
        worker spec for the worker group.

        Returns a map of ``local_rank`` to worker ``id``.
        """
        raise NotImplementedError()

    @abc.abstractmethod
    def _stop_workers(self, worker_group: WorkerGroup) -> None:
        r"""
        Stops all workers in the given worker group. Implementors
        must deal with workers in all states defined by ``WorkerState``.
        That is, implementations must gracefully handle stopping non-existent
        workers, unhealthy (stuck) workers, etc.
        """
        raise NotImplementedError()

    @abc.abstractmethod
    def _monitor_workers(self, worker_group: WorkerGroup) -> RunResult:
        r"""
        Checks on the workers for the ``worker_group`` and returns
        the new state of the worker group.
        """
        raise NotImplementedError()

    @abc.abstractmethod
    def _shutdown(self, death_sig: signal.Signals = signal.SIGTERM) -> None:
        """
        Cleans up any resources that were allocated during the agent's work.

        Args:
            death_sig: Signal to send to the child process, SIGTERM is default
        """
        raise NotImplementedError()

    @staticmethod
    def _set_master_addr_port(
        store: Store,
        master_addr: Optional[str],
        master_port: Optional[int],
        local_addr: Optional[str],
    ):
        if master_port is None:
            sock = _get_socket_with_port()
            with closing(sock):
                master_port = sock.getsockname()[1]

        if master_addr is None:
            # If the user specified an address for the local node, use it as the
            # master address; otherwise fall back to this node's FQDN.
            if local_addr:
                master_addr = local_addr
            else:
                master_addr = _get_fq_hostname()

        store.set("MASTER_ADDR", master_addr.encode(encoding="UTF-8"))
        store.set("MASTER_PORT", str(master_port).encode(encoding="UTF-8"))

    @staticmethod
    def _get_master_addr_port(store: Store) -> Tuple[str, int]:
        master_addr = store.get("MASTER_ADDR").decode(encoding="UTF-8")
        master_port = int(store.get("MASTER_PORT").decode(encoding="UTF-8"))
        return (master_addr, master_port)

    # pyre-fixme[56]: Pyre was not able to infer the type of the decorator
    #  `torch.distributed.elastic.metrics.prof`.
    @prof
    def _rendezvous(self, worker_group: WorkerGroup) -> None:
        r"""
        Runs rendezvous for the workers specified by the worker spec.
        Assigns workers a new global rank and world size.
        Updates the rendezvous store for the worker group.
        """

        spec = worker_group.spec

        store, group_rank, group_world_size = spec.rdzv_handler.next_rendezvous()
        self._store = store

        workers = self._assign_worker_ranks(store, group_rank, group_world_size, spec)
        worker_group.workers = workers
        worker_group.store = store
        worker_group.group_rank = group_rank
        worker_group.group_world_size = group_world_size

        if group_rank == 0:
            self._set_master_addr_port(
                store,
                spec.master_addr,
                spec.master_port,
                spec.local_addr,
            )

        master_addr, master_port = self._get_master_addr_port(store)
        restart_count = spec.max_restarts - self._remaining_restarts

        log.info(
            f"[{spec.role}] Rendezvous complete for workers. Result:\n"
            f"  restart_count={restart_count}\n"
            f"  master_addr={master_addr}\n"
            f"  master_port={master_port}\n"
            f"  group_rank={group_rank}\n"
            f"  group_world_size={group_world_size}\n"
            f"  local_ranks={[worker.local_rank for worker in workers]}\n"
            f"  role_ranks={[worker.role_rank for worker in workers]}\n"
            f"  global_ranks={[worker.global_rank for worker in workers]}\n"
            f"  role_world_sizes={[worker.role_world_size for worker in workers]}\n"
            f"  global_world_sizes={[worker.world_size for worker in workers]}\n"
        )

    def _get_ranks(
        self,
        role_infos: List[_RoleInstanceInfo],
        role_idx: int,
        start_idx: int = 0,
        end_idx: int = -1,
    ) -> Tuple[int, List[int]]:
        if end_idx == -1:
            end_idx = len(role_infos)
        prefix_sum = 0
        total_sum = 0
        for idx in range(start_idx, end_idx):
            if role_idx > idx:
                prefix_sum += role_infos[idx].local_world_size
            total_sum += role_infos[idx].local_world_size
        return (
            total_sum,
            list(range(prefix_sum, prefix_sum + role_infos[role_idx].local_world_size)),
        )
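
    # Worked example (comment only) of the arithmetic in ``_get_ranks``: suppose
    # three agents report local_world_size = [2, 3, 2], indexed by group rank.
    # For the agent at role_idx=1, ``prefix_sum`` adds up the workers of agents
    # with a smaller index (2) and ``total_sum`` adds up all workers in
    # [start_idx, end_idx) (7), so the call returns (7, [2, 3, 4]): a world size
    # of 7 and global ranks 2..4 for that agent's 3 workers.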

    # pyre-fixme[56]: Pyre was not able to infer the type of the decorator
    #  `torch.distributed.elastic.metrics.prof`.
    @prof
    def _assign_worker_ranks(
        self, store, group_rank: int, group_world_size: int, spec: WorkerSpec
    ) -> List[Worker]:
        """
        Determines proper ranks for worker processes. The rank assignment
        is done according to the following algorithm:

        1. Each agent writes its configuration (group_rank, group_world_size,
           num_workers) to the common store.
        2. Each agent retrieves the configuration for all agents
           and performs a two-level sort using role and rank.
        3. Determine the global rank: the global rank of the workers for the current
           agent is the offset of the infos array up to the group_rank of the agent.
           The offset is computed as the sum of local_world_size of all agents that
           have a rank less than the group_rank. The workers would have the ranks:
           [offset, offset+local_world_size).
        4. Determine the role rank: the role rank is determined using the algorithm
           in point 3, with the exception that the offset starts from the first
           agent that has the same role as the current one and has the minimum group rank.
        """

        role_infos = self._share_and_gather(store, group_rank, group_world_size, spec)
        my_role_info = role_infos[group_rank]
        worker_world_size, worker_global_ranks = self._get_ranks(role_infos, group_rank)
        role_infos = sorted(
            role_infos, key=functools.cmp_to_key(_RoleInstanceInfo.compare)
        )
        role_start_idx, role_end_idx = _RoleInstanceInfo.find_role_boundaries(
            role_infos, my_role_info.role
        )
        role_pos = next(
            idx
            for idx, role_info in enumerate(role_infos)
            if _RoleInstanceInfo.compare(role_info, my_role_info) == 0
        )
        role_world_size, role_ranks = self._get_ranks(
            role_infos, role_pos, role_start_idx, role_end_idx + 1
        )
        workers = []
        for ind in range(spec.local_world_size):
            worker = Worker(
                local_rank=ind,
                global_rank=worker_global_ranks[ind],
                role_rank=role_ranks[ind],
                world_size=worker_world_size,
                role_world_size=role_world_size,
            )
            workers.append(worker)
        return workers

    def _share_and_gather(
        self, store, group_rank: int, group_world_size: int, spec: WorkerSpec
    ) -> List:
        agent_role_info = _RoleInstanceInfo(
            spec.role, group_rank, spec.local_world_size
        )
        key_prefix = "torchelastic/role_info"
        agent_config_enc = agent_role_info.serialize()
        role_infos_bytes = store_util.synchronize(
            store, agent_config_enc, group_rank, group_world_size, key_prefix
        )
        role_infos = [
            _RoleInstanceInfo.deserialize(role_info_bytes)
            for role_info_bytes in role_infos_bytes
        ]
        return role_infos

    # pyre-fixme[56]: Pyre was not able to infer the type of the decorator
    #  `torch.distributed.elastic.metrics.prof`.
    @prof
    def _initialize_workers(self, worker_group: WorkerGroup) -> None:
        r"""
        Starts a fresh set of workers for the worker_group.
        Essentially a rendezvous followed by a start_workers.

        The caller should first call ``_stop_workers()`` to stop running workers
        prior to calling this method.

        Optimistically sets the state of the worker group that
        just started as ``HEALTHY`` and delegates the actual monitoring
        of state to the ``_monitor_workers()`` method.
        """
        role = worker_group.spec.role
        log.info(f"[{role}] Rendezvous'ing worker group")

        # TODO after stopping workers, wait at least monitor_interval*2 for
        # workers on different nodes to fail on a collective op before waiting
        # on the rdzv barrier, this way we ensure that nodes enter rdzv
        # at around the same time and reduce false positive rdzv timeout errors
        self._rendezvous(worker_group)

        log.info(f"[{role}] Starting worker group")
        worker_ids = self._start_workers(worker_group)
        for local_rank, w_id in worker_ids.items():
            worker = worker_group.workers[local_rank]
            worker.id = w_id

        worker_group.state = WorkerState.HEALTHY

    # pyre-fixme[56]: Pyre was not able to infer the type of the decorator
    #  `torch.distributed.elastic.metrics.prof`.
    @prof
    def _restart_workers(self, worker_group: WorkerGroup) -> None:
        """
        Restarts (stops, rendezvous, starts) all local workers in the group.
        """

        role = worker_group.spec.role
        log.info(f"[{role}] Stopping worker group")
        self._stop_workers(worker_group)
        worker_group.state = WorkerState.STOPPED
        self._initialize_workers(worker_group)

    # pyre-fixme[56]: Pyre was not able to infer the type of the decorator
    #  `torch.distributed.elastic.metrics.prof`.
    @prof
    def run(self, role: str = DEFAULT_ROLE) -> RunResult:
        start_time = time.monotonic()
        shutdown_called: bool = False
        try:
            result = self._invoke_run(role)
            self._total_execution_time = int(time.monotonic() - start_time)
            self._record_metrics(result)
            self._record_worker_events(result)
            return result
        except SignalException as e:
            log.warning(f"Received {e.sigval} death signal, shutting down workers")
            self._shutdown(e.sigval)
            shutdown_called = True
            raise
        finally:
            if not shutdown_called:
                self._shutdown()
            # record the execution time in case there were any exceptions during run.
            self._total_execution_time = int(time.monotonic() - start_time)

    def get_event_failed(self) -> Event:
        return self._construct_event(
            state="FAILED",
            source=EventSource.AGENT,
            raw_error=traceback.format_exc(),
        )

    def get_event_succeeded(self) -> Event:
        return self._construct_event(
            state="SUCCEEDED",
            source=EventSource.AGENT,
        )

    def _record_worker_events(self, result: RunResult) -> None:
        for worker in self._worker_group.workers:
            failure = result.failures.get(worker.global_rank)
            state: str = self._get_worker_state(worker, result)
            raw_error = json.dumps(failure.error_file_data) if failure else None
            record(self._construct_event(state, EventSource.WORKER, worker, raw_error))

    def _get_worker_state(self, worker: Worker, result: RunResult) -> str:
        failure = result.failures.get(worker.global_rank)
        if result.state in {WorkerState.UNHEALTHY, WorkerState.FAILED} and not failure:
            # The worker got terminated by the torchelastic agent via SIGTERM signal
            return "TERMINATED"
        elif failure or worker.global_rank in result.return_values:
            return result.state.value
        else:
            raise ValueError(f"Unknown worker: {worker.global_rank}")

    def _construct_event(
        self,
        state: str,
        source: EventSource,
        worker: Optional[Worker] = None,
        raw_error: Optional[str] = None,
    ) -> Event:
        wg = self._worker_group
        spec = wg.spec
        md = {
            "group_world_size": wg.group_world_size,
            "entry_point": spec.get_entrypoint_name(),
        }
        if worker:
            md["local_rank"] = (worker.local_rank,)
            md["role_rank"] = (worker.role_rank,)
            md["role_world_size"] = (worker.role_world_size,)
            global_rank = worker.global_rank
            worker_id = str(worker.id)
        else:
            global_rank = None
            worker_id = None
        md_str = json.dumps(md)
        metadata = {
            "run_id": spec.rdzv_handler.get_run_id(),
            "global_rank": global_rank,
            "group_rank": wg.group_rank,
            "worker_id": worker_id,
            "role": spec.role,
            "hostname": _get_fq_hostname(),
            "state": state,
            "total_run_time": self._total_execution_time,
            "rdzv_backend": spec.rdzv_handler.get_backend(),
            "raw_error": raw_error,
            "metadata": md_str,
            "agent_restarts": spec.max_restarts - self._remaining_restarts,
        }
        return Event(
            f"torchelastic.worker.status.{state}", source=source, metadata=metadata
        )

    def _record_metrics(self, group_results: RunResult):
        is_failed = group_results.is_failed()
        self._record_flakiness_metric(is_failed)
        spec = self._worker_group.spec
        restarts_happened = self._remaining_restarts != spec.max_restarts
        put_metric(f"workers.{spec.role}.run_total", 1)
        self._record_metric_with_condition(
            "run_success_with_retries", not is_failed and restarts_happened
        )
        self._record_metric_with_condition(
            "run_success_no_retries", not is_failed and not restarts_happened
        )
        self._record_metric_with_condition(
            "run_failed_with_retries", is_failed and restarts_happened
        )
        self._record_metric_with_condition(
            "run_failed_no_retries", is_failed and not restarts_happened
        )

    def _record_metric_with_condition(self, metric_name, condition):
        spec = self._worker_group.spec
        if condition:
            put_metric(f"workers.{spec.role}.{metric_name}", 1)
        else:
            put_metric(f"workers.{spec.role}.{metric_name}", 0)

    def _record_flakiness_metric(self, is_failed: bool = False):
        if is_failed:
            flakiness = 100.0
        else:
            spec = self._worker_group.spec
            flakiness = 100.0 - 100.0 * (self._remaining_restarts + 1) / (
                spec.max_restarts + 1
            )
        spec = self._worker_group.spec

        put_metric(f"workers.{spec.role}.flakiness", int(flakiness))
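
    # Example of the flakiness formula above (comment only): with max_restarts=3
    # and one restart consumed (remaining_restarts=2), flakiness is
    # 100 - 100 * (2 + 1) / (3 + 1) = 25.0. The metric approaches 100 as more of
    # the restart budget is used, and is pinned at 100.0 when the run failed.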

    def _invoke_run(self, role: str = DEFAULT_ROLE) -> RunResult:
        # NOTE: currently only works for a single role

        spec = self._worker_group.spec
        role = spec.role

        log.info(
            f"[{role}] starting workers for entrypoint: {spec.get_entrypoint_name()}"
        )

        self._initialize_workers(self._worker_group)
        monitor_interval = spec.monitor_interval
        rdzv_handler = spec.rdzv_handler

        while True:
            assert self._worker_group.state != WorkerState.INIT
            time.sleep(monitor_interval)
            run_result = self._monitor_workers(self._worker_group)
            state = run_result.state
            self._worker_group.state = state

            put_metric(f"workers.{role}.remaining_restarts", self._remaining_restarts)
            put_metric(f"workers.{role}.{state.name.lower()}", 1)

            if state == WorkerState.SUCCEEDED:
                log.info(
                    f"[{role}] worker group successfully finished."
                    f" Waiting {self._exit_barrier_timeout} seconds for other agents to finish."
                )
                self._exit_barrier()
                return run_result
            elif state in {WorkerState.UNHEALTHY, WorkerState.FAILED}:
                if self._remaining_restarts > 0:
                    log.info(
                        f"[{role}] Worker group {state.name}. "
                        f"{self._remaining_restarts}/{spec.max_restarts} attempts left;"
                        f" will restart worker group"
                    )
                    self._remaining_restarts -= 1
                    self._restart_workers(self._worker_group)
                else:
                    self._stop_workers(self._worker_group)
                    self._worker_group.state = WorkerState.FAILED
                    self._exit_barrier()
                    return run_result
            elif state == WorkerState.HEALTHY:
                # membership changes do not count as retries
                num_nodes_waiting = rdzv_handler.num_nodes_waiting()
                group_rank = self._worker_group.group_rank
                if num_nodes_waiting > 0:
                    log.info(
                        f"[{role}] Detected {num_nodes_waiting} "
                        f"new nodes from group_rank={group_rank}; "
                        f"will restart worker group"
                    )
                    self._restart_workers(self._worker_group)
            else:
                raise Exception(f"[{role}] Worker group in {state.name} state")

    def _exit_barrier(self):
        """
        Wait for ``exit_barrier_timeout`` seconds for all agents to finish
        executing their local workers (either successfully or not). This
        acts as a safety guard against user scripts that terminate at different
        times. This barrier keeps the agent process alive until all workers finish.
        """
        log.info(
            f"Local worker group finished ({self._worker_group.state}). "
            f"Waiting {self._exit_barrier_timeout} seconds for other agents to finish"
        )
        start = time.time()
        try:
            store_util.barrier(
                self._store,
                self._worker_group.group_rank,
                self._worker_group.group_world_size,
                key_prefix=_TERMINAL_STATE_SYNC_ID,
                barrier_timeout=self._exit_barrier_timeout,
            )
            log.info(
                f"Done waiting for other agents. Elapsed: {time.time() - start} seconds"
            )
        except SignalException as e:
            log.warning(f"Got termination signal: {e.sigval}")
            raise
        except Exception:
            log.exception(
                f"Error waiting on exit barrier. Elapsed: {time.time() - start} seconds"
            )