
r"""
``torch.distributed.launch`` is a module that spawns multiple distributed
training processes on each of the training nodes.

.. warning::

    This module is going to be deprecated in favor of :ref:`torchrun <launcher-api>`.

The utility can be used for single-node distributed training, in which one or
more processes per node will be spawned. The utility can be used for either
CPU training or GPU training. If the utility is used for GPU training,
each distributed process will be operating on a single GPU, which can
improve single-node training performance. It can also be used in
multi-node distributed training, by spawning multiple processes on each node,
to improve multi-node distributed training performance as well.
This is especially beneficial for systems with multiple InfiniBand
interfaces that have direct-GPU support, since all of them can be utilized for
aggregated communication bandwidth.

In both single-node and multi-node distributed
training, this utility will launch the given number of processes per node
(``--nproc-per-node``). If used for GPU training, this number needs to be less
than or equal to the number of GPUs on the current system (``nproc_per_node``),
and each process will be operating on a single GPU from *GPU 0 to
GPU (nproc_per_node - 1)*.

**How to use this module:**

1. Single-Node multi-process distributed training

::

    python -m torch.distributed.launch --nproc-per-node=NUM_GPUS_YOU_HAVE
               YOUR_TRAINING_SCRIPT.py (--arg1 --arg2 --arg3 and all other
               arguments of your training script)

2. Multi-Node multi-process distributed training: (e.g. two nodes)

Node 1: *(IP: 192.168.1.1, and has a free port: 1234)*

::

    python -m torch.distributed.launch --nproc-per-node=NUM_GPUS_YOU_HAVE
               --nnodes=2 --node-rank=0 --master-addr="192.168.1.1"
               --master-port=1234 YOUR_TRAINING_SCRIPT.py (--arg1 --arg2 --arg3
               and all other arguments of your training script)

Node 2:

::

    python -m torch.distributed.launch --nproc-per-node=NUM_GPUS_YOU_HAVE
               --nnodes=2 --node-rank=1 --master-addr="192.168.1.1"
               --master-port=1234 YOUR_TRAINING_SCRIPT.py (--arg1 --arg2 --arg3
               and all other arguments of your training script)

3. To look up what optional arguments this module offers:

::

    python -m torch.distributed.launch --help
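
The commands above differ between nodes only in ``--node-rank``, so they can
be assembled programmatically. The following is a minimal sketch, not part of
this module: ``build_launch_cmd`` is a hypothetical helper, and it emits only
the flags that appear in the examples above.

```python
import sys


def build_launch_cmd(script, script_args, nproc_per_node,
                     nnodes=1, node_rank=0, master_addr=None, master_port=None):
    # Hypothetical helper, not part of torch.distributed.launch.
    cmd = [sys.executable, "-m", "torch.distributed.launch",
           f"--nproc-per-node={nproc_per_node}"]
    if nnodes > 1:
        # Multi-node runs need a rendezvous address and port shared by all nodes.
        if master_addr is None or master_port is None:
            raise ValueError("multi-node launches need master_addr and master_port")
        cmd += [f"--nnodes={nnodes}", f"--node-rank={node_rank}",
                f"--master-addr={master_addr}", f"--master-port={master_port}"]
    # The training script and its own arguments always come last.
    return cmd + [script, *script_args]


# The two-node example: identical commands except for --node-rank.
node0 = build_launch_cmd("YOUR_TRAINING_SCRIPT.py", ["--arg1"], nproc_per_node=4,
                         nnodes=2, node_rank=0,
                         master_addr="192.168.1.1", master_port=1234)
node1 = build_launch_cmd("YOUR_TRAINING_SCRIPT.py", ["--arg1"], nproc_per_node=4,
                         nnodes=2, node_rank=1,
                         master_addr="192.168.1.1", master_port=1234)
print(" ".join(node0[1:]))
print(" ".join(node1[1:]))
```

Each list could then be passed to ``subprocess.run(cmd, check=True)`` on the
corresponding node.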

**Important Notices:**

1. This utility and multi-process distributed (single-node or
multi-node) GPU training currently achieve the best performance only with
the NCCL distributed backend. NCCL is therefore the recommended backend
for GPU training.

2. In your training program, you must parse the command-line argument
``--local-rank=LOCAL_PROCESS_RANK``, which will be provided by this module.
If your training program uses GPUs, you should ensure that your code only
runs on the GPU device of LOCAL_PROCESS_RANK. This can be done by:

Parsing the local_rank argument

::

    >>> # xdoctest: +SKIP
    >>> import argparse
    >>> parser = argparse.ArgumentParser()
    >>> parser.add_argument("--local-rank", type=int)
    >>> args = parser.parse_args()

Set your device to local rank using either

::

    >>> torch.cuda.set_device(args.local_rank)  # before your code runs

or

::

    >>> with torch.cuda.device(args.local_rank):
    >>>     # your code to run
    >>>     ...

3. In your training program, you should call the following function
at the beginning to start the distributed backend. It is strongly recommended
that ``init_method=env://``. Other init methods (e.g. ``tcp://``) may work,
but ``env://`` is the one that is officially supported by this module.

::

    >>> torch.distributed.init_process_group(backend='YOUR BACKEND',
    >>>                                      init_method='env://')

4. In your training program, you can either use regular distributed functions
or use the :func:`torch.nn.parallel.DistributedDataParallel` module. If your
training program uses GPUs for training and you would like to use the
:func:`torch.nn.parallel.DistributedDataParallel` module,
here is how to configure it.

::

    >>> model = torch.nn.parallel.DistributedDataParallel(model,
    >>>                                                   device_ids=[args.local_rank],
    >>>                                                   output_device=args.local_rank)

Please ensure that the ``device_ids`` argument is set to the only GPU device id
that your code will be operating on. This is generally the local rank of the
process. In other words, ``device_ids`` needs to be ``[args.local_rank]``,
and ``output_device`` needs to be ``args.local_rank`` in order to use this
utility.

5. Another way to pass ``local_rank`` to the subprocesses is via the environment
variable ``LOCAL_RANK``. This behavior is enabled when you launch the script with
the ``--use-env`` flag. You must adjust the subprocess example above to replace
``args.local_rank`` with ``int(os.environ['LOCAL_RANK'])`` (environment values
are strings); the launcher will not pass ``--local-rank`` when you specify this flag.
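
A training script that has to work with both conventions (the ``--local-rank``
argument and the ``LOCAL_RANK`` environment variable) could resolve the value
as sketched below. ``resolve_local_rank`` is a hypothetical helper, not part of
this module, and preferring the command-line argument over the environment
variable is an assumption made for illustration.

```python
import argparse
import os


def resolve_local_rank(argv=None, environ=None):
    # Hypothetical helper: prefer --local-rank, fall back to LOCAL_RANK.
    environ = os.environ if environ is None else environ
    parser = argparse.ArgumentParser()
    parser.add_argument("--local-rank", "--local_rank", type=int, default=None)
    # parse_known_args leaves the script's other arguments untouched.
    args, _unknown = parser.parse_known_args(argv)
    if args.local_rank is not None:
        return args.local_rank
    if "LOCAL_RANK" in environ:
        # Environment values are strings, so convert explicitly.
        return int(environ["LOCAL_RANK"])
    raise RuntimeError("no --local-rank argument and no LOCAL_RANK variable")


# Launched without --use-env: the rank arrives as an argument.
print(resolve_local_rank(["--local-rank=1"], environ={}))
# Launched with --use-env (or torchrun): the rank arrives in the environment.
print(resolve_local_rank([], environ={"LOCAL_RANK": "3"}))
```

The returned integer can be used wherever the examples above use
``args.local_rank``.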

.. warning::

    ``local_rank`` is NOT globally unique: it is only unique per process
    on a machine. Thus, don't use it to decide if you should, e.g.,
    write to a networked filesystem. See
    https://github.com/pytorch/pytorch/issues/12042 for an example of
    how things can go wrong if you don't do this correctly.
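
To illustrate the warning, the sketch below simulates two nodes with two
processes each and counts how many processes would write under each rule.
``is_global_writer`` is a hypothetical helper; it assumes the launcher exports
the global rank in the ``RANK`` environment variable, and treating a missing
``RANK`` as rank 0 (a non-distributed run) is also an assumption.

```python
import os


def is_global_writer(environ=None):
    # Hypothetical helper: true only for the process whose global RANK is 0.
    environ = os.environ if environ is None else environ
    return int(environ.get("RANK", "0")) == 0


# Simulated environments for 2 nodes x 2 processes per node.
procs = [{"RANK": str(node * 2 + local), "LOCAL_RANK": str(local)}
         for node in range(2) for local in range(2)]

# Checking LOCAL_RANK selects one process per node, so several writers.
writers_by_local = [p for p in procs if p["LOCAL_RANK"] == "0"]
# Checking the global RANK selects exactly one process overall.
writers_by_global = [p for p in procs if is_global_writer(p)]
print(len(writers_by_local), len(writers_by_global))
```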
"""

import logging
import warnings

from torch.distributed.run import get_args_parser, run


logger = logging.getLogger(__name__)


def parse_args(args):
    parser = get_args_parser()
    parser.add_argument(
        "--use-env",
        "--use_env",
        default=False,
        action="store_true",
        help="Use environment variable to pass "
        "'local rank'. For legacy reasons, the default value is False. "
        "If set to True, the script will not pass "
        "--local-rank as argument, and will instead set LOCAL_RANK.",
    )
    return parser.parse_args(args)


def launch(args):
    if args.no_python and not args.use_env:
        raise ValueError(
            "When using the '--no-python' flag,"
            " you must also set the '--use-env' flag."
        )
    run(args)


def main(args=None):
    warnings.warn(
        "The module torch.distributed.launch is deprecated\n"
        "and will be removed in future. Use torchrun.\n"
        "Note that --use-env is set by default in torchrun.\n"
        "If your script expects `--local-rank` argument to be set, please\n"
        "change it to read from `os.environ['LOCAL_RANK']` instead. See \n"
        "https://pytorch.org/docs/stable/distributed.html#launch-utility for \n"
        "further instructions\n",
        FutureWarning,
    )
    args = parse_args(args)
    launch(args)


if __name__ == "__main__":
    main()