run.py 29 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798
  1. #!/usr/bin/env python3
  2. # Copyright (c) Facebook, Inc. and its affiliates.
  3. # All rights reserved.
  4. #
  5. # This source code is licensed under the BSD-style license found in the
  6. # LICENSE file in the root directory of this source tree.
  7. """
  8. ``torchrun`` provides a superset of the functionality of ``torch.distributed.launch``
  9. with the following additional functionalities:
  10. 1. Worker failures are handled gracefully by restarting all workers.
  11. 2. Worker ``RANK`` and ``WORLD_SIZE`` are assigned automatically.
  12. 3. Number of nodes is allowed to change between minimum and maximum sizes (elasticity).
  13. .. note:: ``torchrun`` is a python
  14. `console script <https://packaging.python.org/en/latest/specifications/entry-points/#use-for-scripts>`_
  15. to the main module
  16. `torch.distributed.run <https://github.com/pytorch/pytorch/blob/master/torch/distributed/run.py>`_
  17. declared in the ``entry_points`` configuration in
  18. `setup.py <https://github.com/pytorch/pytorch/blob/master/setup.py>`_.
  19. It is equivalent to invoking ``python -m torch.distributed.run``.
  20. Transitioning from torch.distributed.launch to torchrun
  21. ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
  22. ``torchrun`` supports the same arguments as ``torch.distributed.launch`` **except**
  23. for ``--use-env`` which is now deprecated. To migrate from ``torch.distributed.launch``
  24. to ``torchrun`` follow these steps:
  25. 1. If your training script already reads ``local_rank`` from the ``LOCAL_RANK`` environment variable,
  26. then you simply need to omit the ``--use-env`` flag, e.g.:
  27. +--------------------------------------------------------------------+--------------------------------------------+
  28. | ``torch.distributed.launch`` | ``torchrun`` |
  29. +====================================================================+============================================+
  30. | | |
  31. | .. code-block:: shell-session | .. code-block:: shell-session |
  32. | | |
  33. | $ python -m torch.distributed.launch --use-env train_script.py | $ torchrun train_script.py |
  34. | | |
  35. +--------------------------------------------------------------------+--------------------------------------------+
  36. 2. If your training script reads local rank from a ``--local-rank`` cmd argument,
  37. change your training script to read from the ``LOCAL_RANK`` environment variable as
  38. demonstrated by the following code snippet:
  39. +-------------------------------------------------------+----------------------------------------------------+
  40. | ``torch.distributed.launch`` | ``torchrun`` |
  41. +=======================================================+====================================================+
  42. | | |
  43. | .. code-block:: python | .. code-block:: python |
  44. | | |
  45. | | |
  46. | import argparse | import os |
  47. | parser = argparse.ArgumentParser() | local_rank = int(os.environ["LOCAL_RANK"]) |
  48. | parser.add_argument("--local-rank", type=int) | |
  49. | args = parser.parse_args() | |
  50. | | |
  51. | local_rank = args.local_rank | |
  52. | | |
  53. +-------------------------------------------------------+----------------------------------------------------+
  54. The aforementioned changes suffice to migrate from ``torch.distributed.launch`` to ``torchrun``.
  55. To take advantage of new features such as elasticity, fault-tolerance, and error reporting of ``torchrun``
  56. please refer to:
  57. * :ref:`elastic_train_script` for more information on authoring training scripts that are ``torchrun`` compliant.
  58. * the rest of this page for more information on the features of ``torchrun``.
  59. Usage
  60. --------
  61. Single-node multi-worker
  62. ++++++++++++++++++++++++++++++
  63. ::
  64. torchrun
  65. --standalone
  66. --nnodes=1
  67. --nproc-per-node=$NUM_TRAINERS
  68. YOUR_TRAINING_SCRIPT.py (--arg1 ... train script args...)
  69. Stacked single-node multi-worker
  70. +++++++++++++++++++++++++++++++++++
  71. To run multiple instances (separate jobs) of single-node, multi-worker on the
  72. same host, we need to make sure that each instance (job) is
  73. setup on different ports to avoid port conflicts (or worse, two jobs being merged
  74. as a single job). To do this you have to run with ``--rdzv-backend=c10d``
  75. and specify a different port by setting ``--rdzv-endpoint=localhost:$PORT_k``.
  76. For ``--nnodes=1``, it's often convenient to let ``torchrun`` pick a free random
  77. port automatically instead of manually assigning different ports for each run.
  78. ::
  79. torchrun
  80. --rdzv-backend=c10d
  81. --rdzv-endpoint=localhost:0
  82. --nnodes=1
  83. --nproc-per-node=$NUM_TRAINERS
  84. YOUR_TRAINING_SCRIPT.py (--arg1 ... train script args...)
  85. Fault tolerant (fixed sized number of workers, no elasticity, tolerates 3 failures)
  86. ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
  87. ::
  88. torchrun
  89. --nnodes=$NUM_NODES
  90. --nproc-per-node=$NUM_TRAINERS
  91. --max-restarts=3
  92. --rdzv-id=$JOB_ID
  93. --rdzv-backend=c10d
  94. --rdzv-endpoint=$HOST_NODE_ADDR
  95. YOUR_TRAINING_SCRIPT.py (--arg1 ... train script args...)
  96. ``HOST_NODE_ADDR``, in form <host>[:<port>] (e.g. node1.example.com:29400), specifies the node and
  97. the port on which the C10d rendezvous backend should be instantiated and hosted. It can be any
  98. node in your training cluster, but ideally you should pick a node that has a high bandwidth.
  99. .. note::
  100. If no port number is specified ``HOST_NODE_ADDR`` defaults to 29400.
  101. Elastic (``min=1``, ``max=4``, tolerates up to 3 membership changes or failures)
  102. +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
  103. ::
  104. torchrun
  105. --nnodes=1:4
  106. --nproc-per-node=$NUM_TRAINERS
  107. --max-restarts=3
  108. --rdzv-id=$JOB_ID
  109. --rdzv-backend=c10d
  110. --rdzv-endpoint=$HOST_NODE_ADDR
  111. YOUR_TRAINING_SCRIPT.py (--arg1 ... train script args...)
  112. ``HOST_NODE_ADDR``, in form <host>[:<port>] (e.g. node1.example.com:29400), specifies the node and
  113. the port on which the C10d rendezvous backend should be instantiated and hosted. It can be any
  114. node in your training cluster, but ideally you should pick a node that has a high bandwidth.
  115. .. note::
  116. If no port number is specified ``HOST_NODE_ADDR`` defaults to 29400.
  117. Note on rendezvous backend
  118. ------------------------------
  119. For multi-node training you need to specify:
  120. 1. ``--rdzv-id``: A unique job id (shared by all nodes participating in the job)
  121. 2. ``--rdzv-backend``: An implementation of
  122. :py:class:`torch.distributed.elastic.rendezvous.RendezvousHandler`
  123. 3. ``--rdzv-endpoint``: The endpoint where the rendezvous backend is running; usually in form
  124. ``host:port``.
  125. Currently ``c10d`` (recommended), ``etcd-v2``, and ``etcd`` (legacy) rendezvous backends are
  126. supported out of the box. To use ``etcd-v2`` or ``etcd``, set up an etcd server with the ``v2`` API
  127. enabled (e.g. ``--enable-v2``).
  128. .. warning::
  129. ``etcd-v2`` and ``etcd`` rendezvous use etcd API v2. You MUST enable the v2 API on the etcd
  130. server. Our tests use etcd v3.4.3.
  131. .. warning::
  132. For etcd-based rendezvous we recommend using ``etcd-v2`` over ``etcd`` which is functionally
  133. equivalent, but uses a revised implementation. ``etcd`` is in maintenance mode and will be
  134. removed in a future version.
  135. Definitions
  136. --------------
  137. 1. ``Node`` - A physical instance or a container; maps to the unit that the job manager works with.
  138. 2. ``Worker`` - A worker in the context of distributed training.
  139. 3. ``WorkerGroup`` - The set of workers that execute the same function (e.g. trainers).
  140. 4. ``LocalWorkerGroup`` - A subset of the workers in the worker group running on the same node.
  141. 5. ``RANK`` - The rank of the worker within a worker group.
  142. 6. ``WORLD_SIZE`` - The total number of workers in a worker group.
  143. 7. ``LOCAL_RANK`` - The rank of the worker within a local worker group.
  144. 8. ``LOCAL_WORLD_SIZE`` - The size of the local worker group.
  145. 9. ``rdzv_id`` - A user-defined id that uniquely identifies the worker group for a job. This id is
  146. used by each node to join as a member of a particular worker group.
  147. 10. ``rdzv_backend`` - The backend of the rendezvous (e.g. ``c10d``). This is typically a strongly
  148. consistent key-value store.
  149. 11. ``rdzv_endpoint`` - The rendezvous backend endpoint; usually in form ``<host>:<port>``.
  150. A ``Node`` runs ``LOCAL_WORLD_SIZE`` workers which comprise a ``LocalWorkerGroup``. The union of
  151. all ``LocalWorkerGroups`` in the nodes in the job comprise the ``WorkerGroup``.
  152. Environment Variables
  153. ----------------------
  154. The following environment variables are made available to you in your script:
  155. 1. ``LOCAL_RANK`` - The local rank.
  156. 2. ``RANK`` - The global rank.
  157. 3. ``GROUP_RANK`` - The rank of the worker group. A number between 0 and ``max_nnodes``. When
  158. running a single worker group per node, this is the rank of the node.
  159. 4. ``ROLE_RANK`` - The rank of the worker across all the workers that have the same role. The role
  160. of the worker is specified in the ``WorkerSpec``.
  161. 5. ``LOCAL_WORLD_SIZE`` - The local world size (i.e. number of workers running locally); equal to
  162. ``--nproc-per-node`` specified on ``torchrun``.
  163. 6. ``WORLD_SIZE`` - The world size (total number of workers in the job).
  164. 7. ``ROLE_WORLD_SIZE`` - The total number of workers that were launched with the same role specified
  165. in ``WorkerSpec``.
  166. 8. ``MASTER_ADDR`` - The FQDN of the host that is running worker with rank 0; used to initialize
  167. the Torch Distributed backend.
  168. 9. ``MASTER_PORT`` - The port on the ``MASTER_ADDR`` that can be used to host the C10d TCP store.
  169. 10. ``TORCHELASTIC_RESTART_COUNT`` - The number of worker group restarts so far.
  170. 11. ``TORCHELASTIC_MAX_RESTARTS`` - The configured maximum number of restarts.
  171. 12. ``TORCHELASTIC_RUN_ID`` - Equal to the rendezvous ``run_id`` (e.g. unique job id).
  172. 13. ``PYTHON_EXEC`` - System executable override. If provided, the python user script will
  173. use the value of ``PYTHON_EXEC`` as executable. The `sys.executable` is used by default.
  174. Deployment
  175. ------------
  176. 1. (Not needed for the C10d backend) Start the rendezvous backend server and get the endpoint (to be
  177. passed as ``--rdzv-endpoint`` to the launcher script)
  178. 2. Single-node multi-worker: Start the launcher on the host to start the agent process which
  179. creates and monitors a local worker group.
  180. 3. Multi-node multi-worker: Start the launcher with the same arguments on all the nodes
  181. participating in training.
  182. When using a job/cluster manager the entry point command to the multi-node job should be this
  183. launcher.
  184. Failure Modes
  185. ---------------
  186. 1. Worker failure: For a training job with ``n`` workers, if ``k<=n`` workers fail, all workers
  187. are stopped and restarted up to ``max_restarts``.
  188. 2. Agent failure: An agent failure results in a local worker group failure. It is up to the job
  189. manager to fail the entire job (gang semantics) or attempt to replace the node. Both behaviors
  190. are supported by the agent.
  191. 3. Node failure: Same as agent failure.
  192. Membership Changes
  193. --------------------
  194. 1. Node departure (scale-down): The agent is notified of the departure, all existing workers are
  195. stopped, a new ``WorkerGroup`` is formed, and all workers are started with a new ``RANK`` and
  196. ``WORLD_SIZE``.
  197. 2. Node arrival (scale-up): The new node is admitted to the job, all existing workers are stopped,
  198. a new ``WorkerGroup`` is formed, and all workers are started with a new ``RANK`` and
  199. ``WORLD_SIZE``.
  200. Important Notices
  201. --------------------
  202. 1. This utility and multi-process distributed (single-node or
  203. multi-node) GPU training currently only achieves the best performance using
  204. the NCCL distributed backend. Thus NCCL backend is the recommended backend to
  205. use for GPU training.
  206. 2. The environment variables necessary to initialize a Torch process group are provided to you by
  207. this module; there is no need for you to pass ``RANK`` manually. To initialize a process group in your
  208. training script, simply run:
  209. ::
  210. >>> # xdoctest: +SKIP("stub")
  211. >>> import torch.distributed as dist
  212. >>> dist.init_process_group(backend="gloo|nccl")
  213. 3. In your training program, you can either use regular distributed functions
  214. or use :func:`torch.nn.parallel.DistributedDataParallel` module. If your
  215. training program uses GPUs for training and you would like to use
  216. :func:`torch.nn.parallel.DistributedDataParallel` module,
  217. here is how to configure it.
  218. ::
  219. local_rank = int(os.environ["LOCAL_RANK"])
  220. model = torch.nn.parallel.DistributedDataParallel(model,
  221. device_ids=[local_rank],
  222. output_device=local_rank)
  223. Please ensure that ``device_ids`` argument is set to be the only GPU device id
  224. that your code will be operating on. This is generally the local rank of the
  225. process. In other words, the ``device_ids`` needs to be ``[int(os.environ["LOCAL_RANK"])]``,
  226. and ``output_device`` needs to be ``int(os.environ["LOCAL_RANK"])`` in order to use this
  227. utility.
  228. 4. On failures or membership changes ALL surviving workers are killed immediately. Make sure to
  229. checkpoint your progress. The frequency of checkpoints should depend on your job's tolerance
  230. for lost work.
  231. 5. This module only supports homogeneous ``LOCAL_WORLD_SIZE``. That is, it is assumed that all
  232. nodes run the same number of local workers (per role).
  233. 6. ``RANK`` is NOT stable. Between restarts, the local workers on a node can be assigned a
  234. different range of ranks than before. NEVER hard code any assumptions about the stable-ness of
  235. ranks or some correlation between ``RANK`` and ``LOCAL_RANK``.
  236. 7. When using elasticity (``min_size!=max_size``) DO NOT hard code assumptions about
  237. ``WORLD_SIZE`` as the world size can change as nodes are allowed to leave and join.
  238. 8. It is recommended for your script to have the following structure:
  239. ::
  240. def main():
  241. load_checkpoint(checkpoint_path)
  242. initialize()
  243. train()
  244. def train():
  245. for batch in iter(dataset):
  246. train_step(batch)
  247. if should_checkpoint:
  248. save_checkpoint(checkpoint_path)
  249. 9. (Recommended) On worker errors, this tool will summarize the details of the error
  250. (e.g. time, rank, host, pid, traceback, etc). On each node, the first error (by timestamp)
  251. is heuristically reported as the "Root Cause" error. To get tracebacks as part of this
  252. error summary print out, you must decorate your main entrypoint function in your
  253. training script as shown in the example below. If not decorated, then the summary
  254. will not include the traceback of the exception and will only contain the exitcode.
  255. For details on torchelastic error handling see: https://pytorch.org/docs/stable/elastic/errors.html
  256. ::
  257. from torch.distributed.elastic.multiprocessing.errors import record
  258. @record
  259. def main():
  260. # do train
  261. pass
  262. if __name__ == "__main__":
  263. main()
  264. """
  265. import logging
  266. import os
  267. import sys
  268. import uuid
  269. from argparse import REMAINDER, ArgumentParser
  270. from typing import Callable, List, Tuple, Union
  271. import torch
  272. from torch.distributed.argparse_util import check_env, env
  273. from torch.distributed.elastic.multiprocessing import Std
  274. from torch.distributed.elastic.multiprocessing.errors import record
  275. from torch.distributed.elastic.rendezvous.utils import _parse_rendezvous_config
  276. from torch.distributed.elastic.utils import macros
  277. from torch.distributed.elastic.utils.logging import get_logger
  278. from torch.distributed.launcher.api import LaunchConfig, elastic_launch
# Module-level logger obtained from torch elastic's ``get_logger`` helper; the
# launcher functions below report warnings and info messages through it.
log = get_logger()
  280. def get_args_parser() -> ArgumentParser:
  281. """Helper function parsing the command line options."""
  282. parser = ArgumentParser(description="Torch Distributed Elastic Training Launcher")
  283. #
  284. # Worker/node size related arguments.
  285. #
  286. parser.add_argument(
  287. "--nnodes",
  288. action=env,
  289. type=str,
  290. default="1:1",
  291. help="Number of nodes, or the range of nodes in form <minimum_nodes>:<maximum_nodes>.",
  292. )
  293. parser.add_argument(
  294. "--nproc-per-node",
  295. "--nproc_per_node",
  296. action=env,
  297. type=str,
  298. default="1",
  299. help="Number of workers per node; supported values: [auto, cpu, gpu, int].",
  300. )
  301. #
  302. # Rendezvous related arguments
  303. #
  304. parser.add_argument(
  305. "--rdzv-backend",
  306. "--rdzv_backend",
  307. action=env,
  308. type=str,
  309. default="static",
  310. help="Rendezvous backend.",
  311. )
  312. parser.add_argument(
  313. "--rdzv-endpoint",
  314. "--rdzv_endpoint",
  315. action=env,
  316. type=str,
  317. default="",
  318. help="Rendezvous backend endpoint; usually in form <host>:<port>.",
  319. )
  320. parser.add_argument(
  321. "--rdzv-id",
  322. "--rdzv_id",
  323. action=env,
  324. type=str,
  325. default="none",
  326. help="User-defined group id.",
  327. )
  328. parser.add_argument(
  329. "--rdzv-conf",
  330. "--rdzv_conf",
  331. action=env,
  332. type=str,
  333. default="",
  334. help="Additional rendezvous configuration (<key1>=<value1>,<key2>=<value2>,...).",
  335. )
  336. parser.add_argument(
  337. "--standalone",
  338. action=check_env,
  339. help="Start a local standalone rendezvous backend that is represented by a C10d TCP store "
  340. "on port 29400. Useful when launching single-node, multi-worker job. If specified "
  341. "--rdzv-backend, --rdzv-endpoint, --rdzv-id are auto-assigned; any explicitly set values "
  342. "are ignored.",
  343. )
  344. #
  345. # User-code launch related arguments.
  346. #
  347. parser.add_argument(
  348. "--max-restarts",
  349. "--max_restarts",
  350. action=env,
  351. type=int,
  352. default=0,
  353. help="Maximum number of worker group restarts before failing.",
  354. )
  355. parser.add_argument(
  356. "--monitor-interval",
  357. "--monitor_interval",
  358. action=env,
  359. type=float,
  360. default=5,
  361. help="Interval, in seconds, to monitor the state of workers.",
  362. )
  363. parser.add_argument(
  364. "--start-method",
  365. "--start_method",
  366. action=env,
  367. type=str,
  368. default="spawn",
  369. choices=["spawn", "fork", "forkserver"],
  370. help="Multiprocessing start method to use when creating workers.",
  371. )
  372. parser.add_argument(
  373. "--role",
  374. action=env,
  375. type=str,
  376. default="default",
  377. help="User-defined role for the workers.",
  378. )
  379. parser.add_argument(
  380. "-m",
  381. "--module",
  382. action=check_env,
  383. help="Change each process to interpret the launch script as a Python module, executing "
  384. "with the same behavior as 'python -m'.",
  385. )
  386. parser.add_argument(
  387. "--no-python",
  388. "--no_python",
  389. action=check_env,
  390. help="Skip prepending the training script with 'python' - just execute it directly. Useful "
  391. "when the script is not a Python script.",
  392. )
  393. parser.add_argument(
  394. "--run-path",
  395. "--run_path",
  396. action=check_env,
  397. help="Run the training script with runpy.run_path in the same interpreter."
  398. " Script must be provided as an abs path (e.g. /abs/path/script.py)."
  399. " Takes precedence over --no-python.",
  400. )
  401. parser.add_argument(
  402. "--log-dir",
  403. "--log_dir",
  404. action=env,
  405. type=str,
  406. default=None,
  407. help="Base directory to use for log files (e.g. /var/log/torch/elastic). The same "
  408. "directory is re-used for multiple runs (a unique job-level sub-directory is created with "
  409. "rdzv_id as the prefix).",
  410. )
  411. parser.add_argument(
  412. "-r",
  413. "--redirects",
  414. action=env,
  415. type=str,
  416. default="0",
  417. help="Redirect std streams into a log file in the log directory (e.g. [-r 3] redirects "
  418. "both stdout+stderr for all workers, [-r 0:1,1:2] redirects stdout for local rank 0 and "
  419. "stderr for local rank 1).",
  420. )
  421. parser.add_argument(
  422. "-t",
  423. "--tee",
  424. action=env,
  425. type=str,
  426. default="0",
  427. help="Tee std streams into a log file and also to console (see --redirects for format).",
  428. )
  429. #
  430. # Backwards compatible parameters with caffe2.distributed.launch.
  431. #
  432. parser.add_argument(
  433. "--node-rank",
  434. "--node_rank",
  435. type=int,
  436. action=env,
  437. default=0,
  438. help="Rank of the node for multi-node distributed training.",
  439. )
  440. parser.add_argument(
  441. "--master-addr",
  442. "--master_addr",
  443. default="127.0.0.1",
  444. type=str,
  445. action=env,
  446. help="Address of the master node (rank 0) that only used for static rendezvous. It should "
  447. "be either the IP address or the hostname of rank 0. For single node multi-proc training "
  448. "the --master-addr can simply be 127.0.0.1; IPv6 should have the pattern "
  449. "`[0:0:0:0:0:0:0:1]`.",
  450. )
  451. parser.add_argument(
  452. "--master-port",
  453. "--master_port",
  454. default=29500,
  455. type=int,
  456. action=env,
  457. help="Port on the master node (rank 0) to be used for communication during distributed "
  458. "training. It is only used for static rendezvous.",
  459. )
  460. parser.add_argument(
  461. "--local-addr",
  462. "--local_addr",
  463. default=None,
  464. type=str,
  465. action=env,
  466. help="Address of the local node. If specified, will use the given address for connection. "
  467. "Else, will look up the local node address instead. Else, it will be default to local "
  468. "machine's FQDN.",
  469. )
  470. #
  471. # Positional arguments.
  472. #
  473. parser.add_argument(
  474. "training_script",
  475. type=str,
  476. help="Full path to the (single GPU) training program/script to be launched in parallel, "
  477. "followed by all the arguments for the training script.",
  478. )
  479. # Rest from the training program.
  480. parser.add_argument("training_script_args", nargs=REMAINDER)
  481. return parser
  482. def parse_args(args):
  483. parser = get_args_parser()
  484. return parser.parse_args(args)
  485. def parse_min_max_nnodes(nnodes: str):
  486. arr = nnodes.split(":")
  487. if len(arr) == 1:
  488. min_nodes = max_nodes = int(arr[0])
  489. elif len(arr) == 2:
  490. min_nodes = int(arr[0])
  491. max_nodes = int(arr[1])
  492. else:
  493. raise RuntimeError(f'nnodes={nnodes} is not in "MIN:MAX" format')
  494. return min_nodes, max_nodes
  495. def determine_local_world_size(nproc_per_node: str):
  496. try:
  497. logging.info(f"Using nproc_per_node={nproc_per_node}.")
  498. return int(nproc_per_node)
  499. except ValueError as e:
  500. if nproc_per_node == "cpu":
  501. num_proc = os.cpu_count()
  502. device_type = "cpu"
  503. elif nproc_per_node == "gpu":
  504. if not torch.cuda.is_available():
  505. raise ValueError("Cuda is not available.") from e
  506. device_type = "gpu"
  507. num_proc = torch.cuda.device_count()
  508. elif nproc_per_node == "auto":
  509. if torch.cuda.is_available():
  510. num_proc = torch.cuda.device_count()
  511. device_type = "gpu"
  512. else:
  513. num_proc = os.cpu_count()
  514. device_type = "cpu"
  515. else:
  516. raise ValueError(f"Unsupported nproc_per_node value: {nproc_per_node}") from e
  517. log.info(
  518. f"Using nproc_per_node={nproc_per_node},"
  519. f" seting to {num_proc} since the instance "
  520. f"has {os.cpu_count()} {device_type}"
  521. )
  522. return num_proc
  523. def get_rdzv_endpoint(args):
  524. if args.rdzv_backend == "static" and not args.rdzv_endpoint:
  525. return f"{args.master_addr}:{args.master_port}"
  526. return args.rdzv_endpoint
  527. def get_use_env(args) -> bool:
  528. """
  529. Retrieves ``use_env`` from the args.
  530. ``use_env`` is a legacy argument, if ``use_env`` is False, the
  531. ``--node-rank`` argument will be transferred to all worker processes.
  532. ``use_env`` is only used by the ``torch.distributed.launch`` and will
  533. be deprecated in future releases.
  534. """
  535. if not hasattr(args, "use_env"):
  536. return True
  537. return args.use_env
  538. def config_from_args(args) -> Tuple[LaunchConfig, Union[Callable, str], List[str]]:
  539. # If ``args`` not passed, defaults to ``sys.argv[:1]``
  540. min_nodes, max_nodes = parse_min_max_nnodes(args.nnodes)
  541. assert 0 < min_nodes <= max_nodes
  542. assert args.max_restarts >= 0
  543. if hasattr(args, "master_addr") and args.rdzv_backend != "static":
  544. log.warning(
  545. "master_addr is only used for static rdzv_backend and when rdzv_endpoint "
  546. "is not specified."
  547. )
  548. nproc_per_node = determine_local_world_size(args.nproc_per_node)
  549. if "OMP_NUM_THREADS" not in os.environ and nproc_per_node > 1:
  550. omp_num_threads = 1
  551. log.warning(
  552. f"\n*****************************************\n"
  553. f"Setting OMP_NUM_THREADS environment variable for each process to be "
  554. f"{omp_num_threads} in default, to avoid your system being overloaded, "
  555. f"please further tune the variable for optimal performance in "
  556. f"your application as needed. \n"
  557. f"*****************************************"
  558. )
  559. # This env variable will be passed down to the subprocesses
  560. os.environ["OMP_NUM_THREADS"] = str(omp_num_threads)
  561. rdzv_configs = _parse_rendezvous_config(args.rdzv_conf)
  562. if args.rdzv_backend == "static":
  563. rdzv_configs["rank"] = args.node_rank
  564. rdzv_endpoint = get_rdzv_endpoint(args)
  565. config = LaunchConfig(
  566. min_nodes=min_nodes,
  567. max_nodes=max_nodes,
  568. nproc_per_node=nproc_per_node,
  569. run_id=args.rdzv_id,
  570. role=args.role,
  571. rdzv_endpoint=rdzv_endpoint,
  572. rdzv_backend=args.rdzv_backend,
  573. rdzv_configs=rdzv_configs,
  574. max_restarts=args.max_restarts,
  575. monitor_interval=args.monitor_interval,
  576. start_method=args.start_method,
  577. redirects=Std.from_str(args.redirects),
  578. tee=Std.from_str(args.tee),
  579. log_dir=args.log_dir,
  580. local_addr=args.local_addr,
  581. )
  582. with_python = not args.no_python
  583. cmd: Union[Callable, str]
  584. cmd_args = []
  585. use_env = get_use_env(args)
  586. if args.run_path:
  587. cmd = run_script_path
  588. cmd_args.append(args.training_script)
  589. else:
  590. if with_python:
  591. cmd = os.getenv("PYTHON_EXEC", sys.executable)
  592. cmd_args.append("-u")
  593. if args.module:
  594. cmd_args.append("-m")
  595. cmd_args.append(args.training_script)
  596. else:
  597. if args.module:
  598. raise ValueError(
  599. "Don't use both the '--no-python' flag"
  600. " and the '--module' flag at the same time."
  601. )
  602. cmd = args.training_script
  603. if not use_env:
  604. cmd_args.append(f"--local-rank={macros.local_rank}")
  605. cmd_args.extend(args.training_script_args)
  606. return config, cmd, cmd_args
  607. def run_script_path(training_script: str, *training_script_args: str):
  608. """
  609. Runs the provided `training_script` from within this interpreter.
  610. Usage: `script_as_function("/abs/path/to/script.py", "--arg1", "val1")`
  611. """
  612. import runpy
  613. import sys
  614. sys.argv = [training_script] + [*training_script_args]
  615. runpy.run_path(sys.argv[0], run_name="__main__")
  616. def run(args):
  617. if args.standalone:
  618. args.rdzv_backend = "c10d"
  619. args.rdzv_endpoint = "localhost:29400"
  620. args.rdzv_id = str(uuid.uuid4())
  621. log.info(
  622. f"\n**************************************\n"
  623. f"Rendezvous info:\n"
  624. f"--rdzv-backend={args.rdzv_backend} "
  625. f"--rdzv-endpoint={args.rdzv_endpoint} "
  626. f"--rdzv-id={args.rdzv_id}\n"
  627. f"**************************************\n"
  628. )
  629. config, cmd, cmd_args = config_from_args(args)
  630. elastic_launch(
  631. config=config,
  632. entrypoint=cmd,
  633. )(*cmd_args)
  634. @record
  635. def main(args=None):
  636. args = parse_args(args)
  637. run(args)
  638. if __name__ == "__main__":
  639. main()