api.py 25 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719
  1. #!/usr/bin/env python3
  2. # Copyright (c) Facebook, Inc. and its affiliates.
  3. # All rights reserved.
  4. #
  5. # This source code is licensed under the BSD-style license found in the
  6. # LICENSE file in the root directory of this source tree.
  7. import abc
  8. import logging
  9. import os
  10. import re
  11. import signal
  12. import subprocess
  13. import sys
  14. import time
  15. from contextlib import nullcontext
  16. from dataclasses import dataclass, field
  17. from enum import IntFlag
  18. from multiprocessing import synchronize
  19. from types import FrameType
  20. from typing import Any, Callable, Dict, Optional, Set, Tuple, Union
  21. import torch.multiprocessing as mp
  22. from torch.distributed.elastic.multiprocessing.errors import ProcessFailure, record
  23. from torch.distributed.elastic.multiprocessing.redirects import (
  24. redirect_stderr,
  25. redirect_stdout,
  26. )
  27. from torch.distributed.elastic.multiprocessing.tail_log import TailLog
# Platform flags. They select the termination/kill signals used below and
# disable stdout/stderr redirection in ``get_std_cm`` on Windows and macOS.
IS_WINDOWS = sys.platform == "win32"
IS_MACOS = sys.platform == "darwin"
# Module-level logger shared by every process context in this file.
log = logging.getLogger(__name__)
# Public names exported by ``from <module> import *``.
__all__ = ["SignalException", "Std", "to_map", "RunProcsResult", "PContext", "get_std_cm", "MultiprocessContext",
           "SubprocessHandler", "SubprocessContext"]
  33. class SignalException(Exception):
  34. """
  35. Exception is raised inside the torchelastic agent process by the termination handler
  36. if the death signal got received by the process.
  37. """
  38. def __init__(self, msg: str, sigval: signal.Signals) -> None:
  39. super().__init__(msg)
  40. self.sigval = sigval
  41. def _terminate_process_handler(signum: int, frame: Optional[FrameType]) -> None:
  42. """Termination handler that raises exceptions on the main process.
  43. When the process receives death signal(SIGTERM, SIGINT), this termination handler will
  44. be invoked. It raises the ``SignalException`` exception that should be processed by the
  45. user code. Python does not terminate process after the termination handler is finished,
  46. so the exception should not be silently ignored, otherwise the process will never
  47. be terminated.
  48. """
  49. sigval = signal.Signals(signum)
  50. raise SignalException(f"Process {os.getpid()} got signal: {sigval}", sigval=sigval)
  51. def _get_kill_signal() -> signal.Signals:
  52. """
  53. Get the kill signal. SIGKILL for unix, CTRL_C_EVENT for windows.
  54. """
  55. if IS_WINDOWS:
  56. return signal.CTRL_C_EVENT # type: ignore[attr-defined] # noqa: F821
  57. else:
  58. return signal.SIGKILL
  59. def _get_default_signal() -> signal.Signals:
  60. """
  61. Get the default termination signal. SIGTERM for unix, CTRL_C_EVENT for windows.
  62. """
  63. if IS_WINDOWS:
  64. return signal.CTRL_C_EVENT # type: ignore[attr-defined] # noqa: F821
  65. else:
  66. return signal.SIGTERM
  67. def _validate_full_rank(d: Dict[int, Any], nprocs: int, what: str):
  68. actual_keys = set(d.keys())
  69. expected_keys = set(range(nprocs))
  70. if actual_keys != expected_keys:
  71. raise RuntimeError(
  72. f"{what}, local rank mapping mismatch,"
  73. f" expected: {expected_keys}, actual: {actual_keys}"
  74. )
  75. _MAPPING_REGEX = r"^(\d:[0123],)*(\d:[0123])$"
  76. _VALUE_REGEX = r"^[0123]$"
  77. class Std(IntFlag):
  78. NONE = 0
  79. OUT = 1
  80. ERR = 2
  81. ALL = OUT | ERR
  82. @classmethod
  83. def from_str(cls, vm: str) -> Union["Std", Dict[int, "Std"]]:
  84. """
  85. Example:
  86. ::
  87. from_str("0") -> Std.NONE
  88. from_str("1") -> Std.OUT
  89. from_str("0:3,1:0,2:1,3:2") -> {0: Std.ALL, 1: Std.NONE, 2: Std.OUT, 3: Std.ERR}
  90. Any other input raises an exception
  91. """
  92. def to_std(v: str) -> Std: # type: ignore[return]
  93. s = Std(int(v))
  94. if s in Std:
  95. return s
  96. # return None -> should NEVER reach here since we regex check input
  97. if re.match(_VALUE_REGEX, vm): # vm is a number (e.g. 0)
  98. return to_std(vm)
  99. elif re.match(_MAPPING_REGEX, vm): # vm is a mapping (e.g. 0:1,1:2)
  100. d: Dict[int, Std] = {}
  101. for m in vm.split(","):
  102. i, v = m.split(":")
  103. d[int(i)] = to_std(v)
  104. return d
  105. else:
  106. raise ValueError(
  107. f"{vm} does not match: <{_VALUE_REGEX}> or <{_MAPPING_REGEX}>"
  108. )
  109. def to_map(
  110. val_or_map: Union[Std, Dict[int, Std]], local_world_size: int
  111. ) -> Dict[int, Std]:
  112. """
  113. Certain APIs take redirect settings either as a single value (e.g. apply to all
  114. local ranks) or as an explicit user-provided mapping. This method is a convenience
  115. method that converts a value or mapping into a mapping.
  116. Example:
  117. ::
  118. to_map(Std.OUT, local_world_size=2) # returns: {0: Std.OUT, 1: Std.OUT}
  119. to_map({1: Std.OUT}, local_world_size=2) # returns: {0: Std.NONE, 1: Std.OUT}
  120. to_map({0: Std.OUT, 1: Std.OUT}, local_world_size=2) # returns: {0: Std.OUT, 1: Std.OUT}
  121. """
  122. if isinstance(val_or_map, Std):
  123. return {i: val_or_map for i in range(local_world_size)}
  124. else:
  125. map = {}
  126. for i in range(local_world_size):
  127. map[i] = val_or_map.get(i, Std.NONE)
  128. return map
  129. @dataclass
  130. class RunProcsResult:
  131. """
  132. Results of a completed run of processes started with ``start_processes()``.
  133. Returned by ``PContext``.
  134. Note the following:
  135. 1. All fields are mapped by local rank
  136. 2. ``return_values`` - only populated for functions (not the binaries).
  137. 3. ``stdouts`` - path to stdout.log (empty string if no redirect)
  138. 4. ``stderrs`` - path to stderr.log (empty string if no redirect)
  139. """
  140. return_values: Dict[int, Any] = field(default_factory=dict)
  141. failures: Dict[int, ProcessFailure] = field(default_factory=dict)
  142. stdouts: Dict[int, str] = field(default_factory=dict)
  143. stderrs: Dict[int, str] = field(default_factory=dict)
  144. def is_failed(self) -> bool:
  145. return len(self.failures) > 0
  146. class PContext(abc.ABC):
  147. """
  148. The base class that standardizes operations over a set of processes
  149. that are launched via different mechanisms. The name ``PContext``
  150. is intentional to disambiguate with ``torch.multiprocessing.ProcessContext``.
  151. .. warning:: stdouts and stderrs should ALWAYS be a superset of
  152. tee_stdouts and tee_stderrs (respectively) this is b/c
  153. tee is implemented as a redirect + tail -f <stdout/stderr.log>
  154. """
  155. def __init__(
  156. self,
  157. name: str,
  158. entrypoint: Union[Callable, str],
  159. args: Dict[int, Tuple],
  160. envs: Dict[int, Dict[str, str]],
  161. stdouts: Dict[int, str],
  162. stderrs: Dict[int, str],
  163. tee_stdouts: Dict[int, str],
  164. tee_stderrs: Dict[int, str],
  165. error_files: Dict[int, str],
  166. ):
  167. self.name = name
  168. # validate that all mappings have the same number of keys and
  169. # all local ranks are accounted for
  170. nprocs = len(args)
  171. _validate_full_rank(stdouts, nprocs, "stdouts")
  172. _validate_full_rank(stderrs, nprocs, "stderrs")
  173. self.entrypoint = entrypoint
  174. self.args = args
  175. self.envs = envs
  176. self.stdouts = stdouts
  177. self.stderrs = stderrs
  178. self.error_files = error_files
  179. self.nprocs = nprocs
  180. self._stdout_tail = TailLog(name, tee_stdouts, sys.stdout)
  181. self._stderr_tail = TailLog(name, tee_stderrs, sys.stderr)
  182. def start(self) -> None:
  183. """
  184. Start processes using parameters defined in the constructor.
  185. """
  186. signal.signal(signal.SIGTERM, _terminate_process_handler)
  187. signal.signal(signal.SIGINT, _terminate_process_handler)
  188. if not IS_WINDOWS:
  189. signal.signal(signal.SIGHUP, _terminate_process_handler)
  190. signal.signal(signal.SIGQUIT, _terminate_process_handler)
  191. self._start()
  192. self._stdout_tail.start()
  193. self._stderr_tail.start()
  194. @abc.abstractmethod
  195. def _start(self) -> None:
  196. """
  197. Start processes using strategy defined in a particular context.
  198. """
  199. raise NotImplementedError()
  200. @abc.abstractmethod
  201. def _poll(self) -> Optional[RunProcsResult]:
  202. """
  203. Polls the run status of the processes running under this context.
  204. This method follows an "all-or-nothing" policy and returns
  205. a ``RunProcessResults`` object if either all processes complete
  206. successfully or any process fails. Returns ``None`` if
  207. all processes are still running.
  208. """
  209. raise NotImplementedError()
  210. def wait(self, timeout: float = -1, period: float = 1) -> Optional[RunProcsResult]:
  211. """
  212. Waits for the specified ``timeout`` seconds, polling every ``period`` seconds
  213. for the processes to be done. Returns ``None`` if the processes are still running
  214. on timeout expiry. Negative timeout values are interpreted as "wait-forever".
  215. A timeout value of zero simply queries the status of the processes (e.g. equivalent
  216. to a poll).
  217. ..note: Multiprocesing library registers SIGTERM and SIGINT signal handlers that raise
  218. ``SignalException`` when the signals received. It is up to the consumer of the code
  219. to properly handle the exception. It is important not to swallow the exception otherwise
  220. the process would not terminate. Example of the typical workflow can be:
  221. .. code-block:: python
  222. pc = start_processes(...)
  223. try:
  224. pc.wait(1)
  225. .. do some other work
  226. except SignalException as e:
  227. pc.shutdown(e.sigval, timeout=30)
  228. If SIGTERM or SIGINT occurs, the code above will try to shutdown child processes by propagating
  229. received signal. If child processes will not terminate in the timeout time, the process will send
  230. the SIGKILL.
  231. """
  232. if timeout == 0:
  233. return self._poll()
  234. if timeout < 0:
  235. timeout = sys.maxsize
  236. expiry = time.time() + timeout
  237. while time.time() < expiry:
  238. pr = self._poll()
  239. if pr:
  240. return pr
  241. time.sleep(period)
  242. return None
  243. @abc.abstractmethod
  244. def pids(self) -> Dict[int, int]:
  245. """
  246. Returns pids of processes mapped by their respective local_ranks
  247. """
  248. raise NotImplementedError()
  249. @abc.abstractmethod
  250. def _close(self, death_sig: signal.Signals, timeout: int = 30) -> None:
  251. r"""
  252. Terminates all processes managed by this context and cleans up any
  253. meta resources (e.g. redirect, error_file files).
  254. """
  255. raise NotImplementedError()
  256. def close(
  257. self, death_sig: Optional[signal.Signals] = None, timeout: int = 30
  258. ) -> None:
  259. r"""
  260. Terminates all processes managed by this context and cleans up any
  261. meta resources (e.g. redirect, error_file files).
  262. Args:
  263. death_sig: Death signal to terminate porcesses.
  264. timeout: Time to wait for processes to finish, if process is
  265. still alive after this time, it will be terminated via SIGKILL.
  266. """
  267. if not death_sig:
  268. death_sig = _get_default_signal()
  269. self._close(death_sig=death_sig, timeout=timeout)
  270. if self._stdout_tail:
  271. self._stdout_tail.stop()
  272. if self._stderr_tail:
  273. self._stderr_tail.stop()
  274. def get_std_cm(std_rd: str, redirect_fn):
  275. if IS_WINDOWS or IS_MACOS or not std_rd:
  276. return nullcontext()
  277. else:
  278. return redirect_fn(std_rd)
  279. def _wrap(
  280. local_rank: int,
  281. fn: Callable,
  282. args: Dict[int, Tuple],
  283. envs: Dict[int, Dict[str, str]],
  284. stdout_redirects: Dict[int, str], # redirect file for stdout (to console if None)
  285. stderr_redirects: Dict[int, str], # redirect file for stderr (to console if None)
  286. ret_vals: Dict[int, mp.SimpleQueue],
  287. queue_finished_reading_event: synchronize.Event,
  288. ) -> None:
  289. # get the per-rank params up front so we fail fast if no mapping is found
  290. args_ = args[local_rank]
  291. env_ = envs[local_rank]
  292. ret_val_ = ret_vals[local_rank]
  293. stdout_rd = stdout_redirects[local_rank]
  294. stderr_rd = stderr_redirects[local_rank]
  295. stdout_cm = get_std_cm(stdout_rd, redirect_stdout)
  296. stderr_cm = get_std_cm(stderr_rd, redirect_stderr)
  297. for k, v in env_.items():
  298. os.environ[k] = v
  299. with stdout_cm, stderr_cm:
  300. ret = record(fn)(*args_)
  301. ret_val_.put(ret)
  302. queue_finished_reading_event.wait()
class MultiprocessContext(PContext):
    """
    ``PContext`` holding worker processes invoked as a function.

    Workers are launched with ``torch.multiprocessing.start_processes`` and run
    ``_wrap`` around the entrypoint. Each worker puts its return value on a
    per-rank queue and then blocks on a shared event; ``_poll`` drains the
    queues and sets the event once a value from every rank has been collected.
    """

    def __init__(
        self,
        name: str,
        entrypoint: Callable,
        args: Dict[int, Tuple],
        envs: Dict[int, Dict[str, str]],
        stdouts: Dict[int, str],
        stderrs: Dict[int, str],
        tee_stdouts: Dict[int, str],
        tee_stderrs: Dict[int, str],
        error_files: Dict[int, str],
        start_method: str,
    ):
        """
        Takes the same arguments as ``PContext`` plus ``start_method``, the
        multiprocessing start method used for the workers, the return-value
        queues and the synchronization event.
        """
        super().__init__(
            name,
            entrypoint,
            args,
            envs,
            stdouts,
            stderrs,
            tee_stdouts,
            tee_stderrs,
            error_files,
        )

        self.start_method = start_method
        # each ret_val queue will always contain a single element.
        self._ret_vals = {
            local_rank: mp.get_context(self.start_method).SimpleQueue()
            for local_rank in range(self.nprocs)
        }

        # return values drained from the queues so far, keyed by local rank;
        # see comments in ``_poll()`` for what this is
        self._return_values: Dict[int, Any] = {}
        # process context returned by ``mp.start_processes``; ``None`` until started
        self._pc: Optional[mp.ProcessContext] = None
        # Note: set method should ONLY be invoked for the use case when all processes finished
        # successfully. If any process died on event.wait() calling set() method will deadlock.
        self._worker_finished_event = mp.get_context(self.start_method).Event()

    def _start(self) -> None:
        """
        Launch ``nprocs`` worker processes without joining them.

        Raises:
            ValueError: if the process context has already been created.
        """
        if self._pc:
            raise ValueError(
                "The process context already initialized."
                " Most likely the start method got called twice."
            )
        self._pc = mp.start_processes(
            fn=_wrap,
            args=(
                self.entrypoint,
                self.args,
                self.envs,
                self.stdouts,
                self.stderrs,
                self._ret_vals,
                self._worker_finished_event,
            ),
            nprocs=self.nprocs,
            join=False,
            daemon=False,
            start_method=self.start_method,
        )

    def _is_done(self) -> bool:
        """Return ``True`` once a return value was collected from every local rank."""
        return len(self._return_values) == self.nprocs

    def _poll(self) -> Optional[RunProcsResult]:
        """
        Check the workers once.

        Returns:
            ``None`` while return values are still outstanding; a
            ``RunProcsResult`` with all return values once every worker has
            delivered one; or a ``RunProcsResult`` holding a single failure
            (the one reported by the raised exception) if joining raised.
        """
        assert self._pc is not None  # assertion for mypy type checker

        try:
            # torch.mp.ProcessContext Throws an Exception if some/all of
            # worker processes failed
            # timeout < 0 checks worker status and return immediately
            # Join will never return success since we use synchronize.Event to wait
            # for all processes to finish.
            self._pc.join(-1)

            # IMPORTANT: we use multiprocessing.Queue to carry worker return values
            # back to the parent, the worker process will wait before terminating
            # until all the buffered items are fed by the feeder thread to the underlying
            # pipe. Hence to prevent deadlocks on large return values,
            # we opportunistically try queue.get on each join call
            # See: https://docs.python.org/2/library/multiprocessing.html#all-platforms
            for local_rank in range(0, self.nprocs):
                return_queue = self._ret_vals[local_rank]
                if not return_queue.empty():
                    # save the return values temporarily into a member var
                    self._return_values[local_rank] = return_queue.get()

            if self._is_done():
                # we should ALWAYS have ALL the return values when all the processes are done
                self._worker_finished_event.set()
                # Wait until all processes are finished. At this point workers finished executing
                # user function
                self._pc.join()
                _validate_full_rank(
                    self._return_values, self.nprocs, "return_value queue"
                )
                self.close()
                return RunProcsResult(
                    return_values=self._return_values,
                    stdouts=self.stdouts,
                    stderrs=self.stderrs,
                )
            else:
                return None
        except (mp.ProcessRaisedException, mp.ProcessExitedException) as e:
            # the exception identifies the failed worker by its index
            failed_local_rank = e.error_index

            # entrypoint for MultiprocessContext will always be a Callable
            fn_name = self.entrypoint.__qualname__  # type: ignore[union-attr]
            failed_proc = self._pc.processes[failed_local_rank]
            error_filepath = self.error_files[failed_local_rank]

            log.error(
                f"failed (exitcode: {failed_proc.exitcode})"
                f" local_rank: {failed_local_rank} (pid: {e.pid})"
                f" of fn: {fn_name} (start_method: {self.start_method})",
                exc_info=True,
            )

            # terminate the remaining workers before reporting the failure
            self.close()
            return RunProcsResult(
                failures={
                    failed_local_rank: ProcessFailure(
                        local_rank=failed_local_rank,
                        pid=e.pid,
                        exitcode=failed_proc.exitcode,
                        error_file=error_filepath,
                    )
                },
                stdouts=self.stdouts,
                stderrs=self.stderrs,
            )

    def pids(self) -> Dict[int, int]:
        """Return worker pids keyed by their position in ``self._pc.pids()``."""
        assert self._pc is not None  # assertion for mypy type checking
        return {local_rank: pid for local_rank, pid in enumerate(self._pc.pids())}

    def _close(self, death_sig: signal.Signals, timeout: int = 30) -> None:
        """
        Stop all workers: send ``death_sig`` to those still alive, join them
        for up to ``timeout`` seconds in total, then send the kill signal to
        any that are still alive and join every process.
        """
        if not self._pc:
            return
        for proc in self._pc.processes:
            if proc.is_alive():
                log.warning(f"Closing process {proc.pid} via signal {death_sig.name}")
                try:
                    os.kill(proc.pid, death_sig)
                except ProcessLookupError:
                    # If the process exited because of some reason,
                    # `ProcessLookupError` will be raised, it is safe to ignore it.
                    pass
        # ``timeout`` is a shared budget across all processes, not per process
        end = time.monotonic() + timeout
        for proc in self._pc.processes:
            time_to_wait = end - time.monotonic()
            if time_to_wait <= 0:
                break
            proc.join(time_to_wait)
        for proc in self._pc.processes:
            if proc.is_alive():
                log.warning(
                    f"Unable to shutdown process {proc.pid} via {death_sig}, forcefully exiting via {_get_kill_signal()}"
                )
                try:
                    os.kill(proc.pid, _get_kill_signal())
                except ProcessLookupError:
                    # If the process exited because of some reason,
                    # `ProcessLookupError` will be raised, it is safe to ignore it.
                    pass
            proc.join()
  462. class SubprocessHandler:
  463. """
  464. Convenience wrapper around python's ``subprocess.Popen``. Keeps track of
  465. meta-objects associated to the process (e.g. stdout and stderr redirect fds).
  466. """
  467. def __init__(
  468. self,
  469. entrypoint: str,
  470. args: Tuple,
  471. env: Dict[str, str],
  472. stdout: str,
  473. stderr: str,
  474. ):
  475. self._stdout = open(stdout, "w") if stdout else None
  476. self._stderr = open(stderr, "w") if stderr else None
  477. # inherit parent environment vars
  478. env_vars = os.environ.copy()
  479. env_vars.update(env)
  480. args_str = (entrypoint, *[str(e) for e in args])
  481. self.proc: subprocess.Popen = self._popen(args_str, env_vars)
  482. def _popen(self, args: Tuple, env: Dict[str, str]) -> subprocess.Popen:
  483. return subprocess.Popen(
  484. # pyre-fixme[6]: Expected `Union[typing.Sequence[Union[_PathLike[bytes],
  485. # _PathLike[str], bytes, str]], bytes, str]` for 1st param but got
  486. # `Tuple[str, *Tuple[Any, ...]]`.
  487. args=args,
  488. env=env,
  489. stdout=self._stdout,
  490. stderr=self._stderr,
  491. )
  492. def close(self, death_sig: Optional[signal.Signals] = None) -> None:
  493. if not death_sig:
  494. death_sig = _get_default_signal()
  495. self.proc.send_signal(death_sig)
  496. if self._stdout:
  497. self._stdout.close()
  498. if self._stderr:
  499. self._stderr.close()
  500. class SubprocessContext(PContext):
  501. """
  502. ``PContext`` holding worker processes invoked as a binary.
  503. """
  504. def __init__(
  505. self,
  506. name: str,
  507. entrypoint: str,
  508. args: Dict[int, Tuple],
  509. envs: Dict[int, Dict[str, str]],
  510. stdouts: Dict[int, str],
  511. stderrs: Dict[int, str],
  512. tee_stdouts: Dict[int, str],
  513. tee_stderrs: Dict[int, str],
  514. error_files: Dict[int, str],
  515. ):
  516. super().__init__(
  517. name,
  518. entrypoint,
  519. args,
  520. envs,
  521. stdouts,
  522. stderrs,
  523. tee_stdouts,
  524. tee_stderrs,
  525. error_files,
  526. )
  527. # state vector; _vdone[local_rank] -> is local_rank finished or not
  528. self._running_local_ranks: Set[int] = set(range(self.nprocs))
  529. self._failures: Dict[int, ProcessFailure] = {}
  530. self.subprocess_handlers: Dict[int, SubprocessHandler] = {}
  531. def _start(self):
  532. if self.subprocess_handlers:
  533. raise ValueError(
  534. "The subprocess handlers already initialized. Most likely the start method got called twice."
  535. )
  536. self.subprocess_handlers = {
  537. local_rank: SubprocessHandler(
  538. entrypoint=self.entrypoint, # type: ignore[arg-type] # entrypoint is always a str
  539. args=self.args[local_rank],
  540. env=self.envs[local_rank],
  541. stdout=self.stdouts[local_rank],
  542. stderr=self.stderrs[local_rank],
  543. )
  544. for local_rank in range(self.nprocs)
  545. }
  546. def _poll(self) -> Optional[RunProcsResult]:
  547. done_local_ranks = set()
  548. for local_rank in self._running_local_ranks:
  549. handler = self.subprocess_handlers[local_rank]
  550. exitcode = handler.proc.poll()
  551. if exitcode is not None:
  552. done_local_ranks.add(local_rank)
  553. if exitcode != 0: # failed or signaled
  554. self._failures[local_rank] = ProcessFailure(
  555. local_rank=local_rank,
  556. pid=handler.proc.pid,
  557. exitcode=exitcode,
  558. error_file=self.error_files[local_rank],
  559. )
  560. # else: --> succeeded; nothing to do
  561. self._running_local_ranks.difference_update(done_local_ranks)
  562. # if ALL procs are finished or ANY have failed
  563. if not self._running_local_ranks or self._failures:
  564. self.close() # terminate all running procs
  565. result = RunProcsResult(
  566. failures=self._failures,
  567. stdouts=self.stdouts,
  568. stderrs=self.stderrs,
  569. )
  570. if result.is_failed():
  571. first_failure = min(result.failures.values(), key=lambda f: f.timestamp)
  572. log.error(
  573. f"failed (exitcode: {first_failure.exitcode})"
  574. f" local_rank: {first_failure.local_rank} (pid: {first_failure.pid})"
  575. f" of binary: {self.entrypoint}"
  576. )
  577. else:
  578. # Populate return with dummy values. This provides consistency with MultiprocessingHandler
  579. result.return_values = {
  580. local_rank: None for local_rank in range(self.nprocs)
  581. }
  582. return result
  583. else: # there are no failures and procs still running
  584. return None
  585. def pids(self) -> Dict[int, int]:
  586. return {
  587. local_rank: sh.proc.pid
  588. for local_rank, sh in self.subprocess_handlers.items()
  589. }
  590. def _close(self, death_sig: signal.Signals, timeout: int = 30) -> None:
  591. if not self.subprocess_handlers:
  592. return
  593. for handler in self.subprocess_handlers.values():
  594. if handler.proc.poll() is None:
  595. log.warning(
  596. f"Sending process {handler.proc.pid} closing signal {death_sig.name}"
  597. )
  598. handler.close(death_sig=death_sig)
  599. end = time.monotonic() + timeout
  600. for handler in self.subprocess_handlers.values():
  601. time_to_wait = end - time.monotonic()
  602. if time_to_wait <= 0:
  603. break
  604. try:
  605. handler.proc.wait(time_to_wait)
  606. except subprocess.TimeoutExpired:
  607. # Ignore the timeout expired exception, since
  608. # the child process will be forcefully terminated via SIGKILL
  609. pass
  610. for handler in self.subprocess_handlers.values():
  611. if handler.proc.poll() is None:
  612. log.warning(
  613. f"Unable to shutdown process {handler.proc.pid} via {death_sig}, forcefully exiting via {_get_kill_signal()}"
  614. )
  615. handler.close(death_sig=_get_kill_signal())
  616. handler.proc.wait()