#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
  7. """
  8. Library that launches and manages ``n`` copies of worker subprocesses
  9. either specified by a function or a binary.
  10. For functions, it uses ``torch.multiprocessing`` (and therefore python
  11. ``multiprocessing``) to spawn/fork worker processes. For binaries it uses python
  12. ``subprocessing.Popen`` to create worker processes.
  13. Usage 1: Launching two trainers as a function
  14. ::
  15. from torch.distributed.elastic.multiprocessing import Std, start_processes
  16. def trainer(a, b, c):
  17. pass # train
  18. # runs two trainers
  19. # LOCAL_RANK=0 trainer(1,2,3)
  20. # LOCAL_RANK=1 trainer(4,5,6)
  21. ctx = start_processes(
  22. name="trainer",
  23. entrypoint=trainer,
  24. args={0: (1,2,3), 1: (4,5,6)},
  25. envs={0: {"LOCAL_RANK": 0}, 1: {"LOCAL_RANK": 1}},
  26. log_dir="/tmp/foobar",
  27. redirects=Std.ALL, # write all worker stdout/stderr to a log file
  28. tee={0: Std.ERR}, # tee only local rank 0's stderr to console
  29. )
  30. # waits for all copies of trainer to finish
  31. ctx.wait()
  32. Usage 2: Launching 2 echo workers as a binary
  33. ::
  34. # same as invoking
  35. # echo hello
  36. # echo world > stdout.log
  37. ctx = start_processes(
  38. name="echo"
  39. entrypoint="echo",
  40. log_dir="/tmp/foobar",
  41. args={0: "hello", 1: "world"},
  42. redirects={1: Std.OUT},
  43. )

Just like ``torch.multiprocessing``, the return value of the function
:func:`start_processes` is a process context (:class:`api.PContext`). If a function
was launched, a :class:`api.MultiprocessContext` is returned and if a binary
was launched a :class:`api.SubprocessContext` is returned. Both are specific
implementations of the parent :class:`api.PContext` class.
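
As a rough sketch (names below mirror the usage above), the object returned by
``ctx.wait()`` is a :class:`api.RunProcsResult` that can be inspected per local rank::

 result = ctx.wait()
 if result.is_failed():
     # map of local rank -> ProcessFailure
     print(result.failures)
 else:
     # for function entrypoints, map of local rank -> return value
     print(result.return_values)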
  49. """
import os
from typing import Callable, Dict, Tuple, Union

from torch.distributed.elastic.multiprocessing.api import (  # noqa: F401
    MultiprocessContext,
    PContext,
    ProcessFailure,
    RunProcsResult,
    Std,
    SignalException,
    SubprocessContext,
    _validate_full_rank,
    to_map,
)
from torch.distributed.elastic.utils.logging import get_logger

log = get_logger()

def start_processes(
    name: str,
    entrypoint: Union[Callable, str],
    args: Dict[int, Tuple],
    envs: Dict[int, Dict[str, str]],
    log_dir: str,
    start_method: str = "spawn",
    redirects: Union[Std, Dict[int, Std]] = Std.NONE,
    tee: Union[Std, Dict[int, Std]] = Std.NONE,
) -> PContext:
  75. """
  76. Starts ``n`` copies of ``entrypoint`` processes with the provided options.
  77. ``entrypoint`` is either a ``Callable`` (function) or a ``str`` (binary).
  78. The number of copies is determined by the number of entries for ``args`` and
  79. ``envs`` arguments, which need to have the same key set.
  80. ``args`` and ``env`` parameters are the arguments and environment variables
  81. to pass down to the entrypoint mapped by the replica index (local rank).
  82. All local ranks must be accounted for.
  83. That is, the keyset should be ``{0,1,...,(nprocs-1)}``.

    .. note:: When the ``entrypoint`` is a binary (``str``), ``args`` can only be strings.
              If any other type is given, then it is cast to its string representation
              (e.g. ``str(arg1)``). Furthermore, a binary failure will only write
              an ``error.json`` error file if the main function is annotated with
              ``torch.distributed.elastic.multiprocessing.errors.record``. For function launches,
              this is done by default and there is no need to manually annotate
              with the ``@record`` annotation.
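
    As a rough sketch, annotating a python binary so that failures produce an
    ``error.json`` could look like the following (the script name ``my_binary.py``
    is illustrative)::

     # my_binary.py
     from torch.distributed.elastic.multiprocessing.errors import record

     @record
     def main():
         pass  # worker logic

     if __name__ == "__main__":
         main()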

    ``redirects`` and ``tee`` are bitmasks specifying which std stream(s) to redirect
    to a log file in the ``log_dir``. Valid mask values are defined in ``Std``.
    To redirect/tee only certain local ranks, pass ``redirects`` as a map with the key as
    the local rank to specify the redirect behavior for.
    Any missing local ranks will default to ``Std.NONE``.

    ``tee`` acts like the unix "tee" command in that it redirects + prints to console.
    To avoid worker stdout/stderr from printing to console, use the ``redirects`` parameter.
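
    For instance, a per-rank sketch (values are illustrative)::

     # redirect both streams of rank 0 and only stdout of rank 1 to files;
     # ranks that are not listed default to Std.NONE
     redirects = {0: Std.ALL, 1: Std.OUT}
     # additionally print rank 0's stderr to the console
     tee = {0: Std.ERR}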

    For each process, the ``log_dir`` will contain:

    #. ``{local_rank}/error.json``: if the process failed, a file with the error info
    #. ``{local_rank}/stdout.log``: if ``redirect & STDOUT == STDOUT``
    #. ``{local_rank}/stderr.log``: if ``redirect & STDERR == STDERR``

    .. note:: It is expected that the ``log_dir`` exists, is empty, and is a directory.
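
    A fresh, empty directory can be created with the standard library, for
    example (the prefix is arbitrary)::

     import tempfile
     log_dir = tempfile.mkdtemp(prefix="torchelastic_")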

    Example:
    ::

     log_dir = "/tmp/test"

     # ok; two copies of foo: foo("bar0"), foo("bar1")
     start_processes(
        name="trainer",
        entrypoint=foo,
        args={0: ("bar0",), 1: ("bar1",)},
        envs={0: {}, 1: {}},
        log_dir=log_dir
     )

     # invalid; envs missing for local rank 1
     start_processes(
        name="trainer",
        entrypoint=foo,
        args={0: ("bar0",), 1: ("bar1",)},
        envs={0: {}},
        log_dir=log_dir
     )

     # ok; two copies of /usr/bin/touch: touch file1, touch file2
     start_processes(
        name="trainer",
        entrypoint="/usr/bin/touch",
        args={0: ("file1",), 1: ("file2",)},
        envs={0: {}, 1: {}},
        log_dir=log_dir
     )

     # caution; arguments are cast to strings; the example below runs:
     # echo "1" "2" "3" and echo "[1, 2, 3]"
     start_processes(
        name="trainer",
        entrypoint="/usr/bin/echo",
        args={0: (1, 2, 3), 1: ([1, 2, 3],)},
        envs={0: {}, 1: {}},
        log_dir=log_dir
     )

    Args:
        name: a human readable short name that describes what the processes are
              (used as header when tee'ing stdout/stderr outputs)
        entrypoint: either a ``Callable`` (function) or ``cmd`` (binary)
        args: arguments to each replica
        envs: env vars to each replica
        log_dir: directory used to write log files
        start_method: multiprocessing start method (spawn, fork, forkserver)
                      ignored for binaries
        redirects: which std streams to redirect to a log file
        tee: which std streams to redirect + print to console
    """

    # listdir raises FileNotFound or NotADirectoryError so no need to check manually
    if log_dir != os.devnull and os.listdir(log_dir):
        raise RuntimeError(
            f"log_dir: {log_dir} is not empty, please provide an empty log_dir"
        )

    nprocs = len(args)
    _validate_full_rank(args, nprocs, "args")
    _validate_full_rank(envs, nprocs, "envs")

    # create subdirs for each local rank in the logs_dir
    # logs_dir
    #       |- 0
    #          |- error.json
    #          |- stdout.log
    #          |- stderr.log
    #       |- ...
    #       |- (nprocs-1)
    redirs = to_map(redirects, nprocs)
    ts = to_map(tee, nprocs)

    # to tee stdout/stderr we first redirect into a file
    # then tail -f stdout.log/stderr.log so add tee settings to redirects
    for local_rank, tee_std in ts.items():
        redirect_std = redirs[local_rank]
        redirs[local_rank] = redirect_std | tee_std
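    # at this point (illustrative values) redirects={0: Std.OUT} with tee={0: Std.ERR}
    # has become redirs[0] == Std.ALL, since Std is a bitmask and OUT | ERR == ALL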

    stdouts = {local_rank: "" for local_rank in range(nprocs)}
    stderrs = {local_rank: "" for local_rank in range(nprocs)}
    tee_stdouts: Dict[int, str] = {}
    tee_stderrs: Dict[int, str] = {}
    error_files = {}

    for local_rank in range(nprocs):
        if log_dir == os.devnull:
            tee_stdouts[local_rank] = os.devnull
            tee_stderrs[local_rank] = os.devnull
            error_files[local_rank] = os.devnull
            envs[local_rank]["TORCHELASTIC_ERROR_FILE"] = ""
        else:
            clogdir = os.path.join(log_dir, str(local_rank))
            os.mkdir(clogdir)

            rd = redirs[local_rank]
            if (rd & Std.OUT) == Std.OUT:
                stdouts[local_rank] = os.path.join(clogdir, "stdout.log")
            if (rd & Std.ERR) == Std.ERR:
                stderrs[local_rank] = os.path.join(clogdir, "stderr.log")

            t = ts[local_rank]
            if t & Std.OUT == Std.OUT:
                tee_stdouts[local_rank] = stdouts[local_rank]
            if t & Std.ERR == Std.ERR:
                tee_stderrs[local_rank] = stderrs[local_rank]

            error_file = os.path.join(clogdir, "error.json")
            error_files[local_rank] = error_file
            log.info(f"Setting worker{local_rank} reply file to: {error_file}")
            envs[local_rank]["TORCHELASTIC_ERROR_FILE"] = error_file

    context: PContext
    if isinstance(entrypoint, str):
        context = SubprocessContext(
            name=name,
            entrypoint=entrypoint,
            args=args,
            envs=envs,
            stdouts=stdouts,
            stderrs=stderrs,
            tee_stdouts=tee_stdouts,
            tee_stderrs=tee_stderrs,
            error_files=error_files,
        )
    else:
        context = MultiprocessContext(
            name=name,
            entrypoint=entrypoint,
            args=args,
            envs=envs,
            stdouts=stdouts,
            stderrs=stderrs,
            tee_stdouts=tee_stdouts,
            tee_stderrs=tee_stderrs,
            error_files=error_files,
            start_method=start_method,
        )

    try:
        context.start()
        return context
    except Exception:
        context.close()
        raise