from typing import Optional
import multiprocessing
import multiprocessing.connection
import signal
import sys
import warnings

from . import _prctl_pr_set_pdeathsig  # type: ignore[attr-defined]
  8. class ProcessException(Exception):
  9. __slots__ = ["error_index", "error_pid"]
  10. def __init__(self, msg: str, error_index: int, pid: int):
  11. super().__init__(msg)
  12. self.msg = msg
  13. self.error_index = error_index
  14. self.pid = pid
  15. def __reduce__(self):
  16. return type(self), (self.msg, self.error_index, self.pid)
  17. class ProcessRaisedException(ProcessException):
  18. """
  19. Exception is thrown when the process failed due to exception
  20. raised by the code.
  21. """
  22. def __init__(
  23. self,
  24. msg: str,
  25. error_index: int,
  26. error_pid: int,
  27. ):
  28. super().__init__(msg, error_index, error_pid)
  29. class ProcessExitedException(ProcessException):
  30. """
  31. Exception is thrown when the process failed due to signal
  32. or exited with a specific code.
  33. """
  34. __slots__ = ["exit_code"]
  35. def __init__(
  36. self, msg: str, error_index: int, error_pid: int,
  37. exit_code: int, signal_name: Optional[str] = None
  38. ):
  39. super().__init__(msg, error_index, error_pid)
  40. self.exit_code = exit_code
  41. self.signal_name = signal_name
  42. def __reduce__(self):
  43. return (
  44. type(self),
  45. (self.msg, self.error_index, self.pid, self.exit_code, self.signal_name),
  46. )
  47. def _wrap(fn, i, args, error_queue):
  48. # prctl(2) is a Linux specific system call.
  49. # On other systems the following function call has no effect.
  50. # This is set to ensure that non-daemonic child processes can
  51. # terminate if their parent terminates before they do.
  52. _prctl_pr_set_pdeathsig(signal.SIGINT)
  53. try:
  54. fn(i, *args)
  55. except KeyboardInterrupt:
  56. pass # SIGINT; Killed by parent, do nothing
  57. except Exception:
  58. # Propagate exception to parent process, keeping original traceback
  59. import traceback
  60. error_queue.put(traceback.format_exc())
  61. sys.exit(1)
  62. class ProcessContext:
  63. def __init__(self, processes, error_queues):
  64. self.error_queues = error_queues
  65. self.processes = processes
  66. self.sentinels = {
  67. process.sentinel: index
  68. for index, process in enumerate(processes)
  69. }
  70. def pids(self):
  71. return [int(process.pid) for process in self.processes]
  72. def join(self, timeout=None):
  73. r"""
  74. Tries to join one or more processes in this spawn context.
  75. If one of them exited with a non-zero exit status, this function
  76. kills the remaining processes and raises an exception with the cause
  77. of the first process exiting.
  78. Returns ``True`` if all processes have been joined successfully,
  79. ``False`` if there are more processes that need to be joined.
  80. Args:
  81. timeout (float): Wait this long before giving up on waiting.
  82. """
  83. # Ensure this function can be called even when we're done.
  84. if len(self.sentinels) == 0:
  85. return True
  86. # Wait for any process to fail or all of them to succeed.
  87. ready = multiprocessing.connection.wait(
  88. self.sentinels.keys(),
  89. timeout=timeout,
  90. )
  91. error_index = None
  92. for sentinel in ready:
  93. index = self.sentinels.pop(sentinel)
  94. process = self.processes[index]
  95. process.join()
  96. if process.exitcode != 0:
  97. error_index = index
  98. break
  99. # Return if there was no error.
  100. if error_index is None:
  101. # Return whether or not all processes have been joined.
  102. return len(self.sentinels) == 0
  103. # Assume failure. Terminate processes that are still alive.
  104. for process in self.processes:
  105. if process.is_alive():
  106. process.terminate()
  107. process.join()
  108. # There won't be an error on the queue if the process crashed.
  109. failed_process = self.processes[error_index]
  110. if self.error_queues[error_index].empty():
  111. exitcode = self.processes[error_index].exitcode
  112. if exitcode < 0:
  113. name = signal.Signals(-exitcode).name
  114. raise ProcessExitedException(
  115. "process %d terminated with signal %s" %
  116. (error_index, name),
  117. error_index=error_index,
  118. error_pid=failed_process.pid,
  119. exit_code=exitcode,
  120. signal_name=name
  121. )
  122. else:
  123. raise ProcessExitedException(
  124. "process %d terminated with exit code %d" %
  125. (error_index, exitcode),
  126. error_index=error_index,
  127. error_pid=failed_process.pid,
  128. exit_code=exitcode
  129. )
  130. original_trace = self.error_queues[error_index].get()
  131. msg = "\n\n-- Process %d terminated with the following error:\n" % error_index
  132. msg += original_trace
  133. raise ProcessRaisedException(msg, error_index, failed_process.pid)
  134. class SpawnContext(ProcessContext):
  135. def __init__(self, processes, error_queues):
  136. warnings.warn('SpawnContext is renamed to ProcessContext since 1.4 release.')
  137. super().__init__(processes, error_queues)
# Note: [start_processes]
# mp.start_processes handles both start_method='spawn' and 'fork'. It is meant to be
# a more general API than mp.spawn. Currently we only document mp.spawn, as it is the
# CUDA-compatible start_method. However, in environments like IPython notebooks,
# 'fork' works better than 'spawn'. Every helper function created for mp.spawn is
# general enough that backends such as XLA can reuse them in Colab notebooks as well.
# For now we only add the API itself; documenting it can be considered later as needed.
  146. def start_processes(fn, args=(), nprocs=1, join=True, daemon=False, start_method='spawn'):
  147. mp = multiprocessing.get_context(start_method)
  148. error_queues = []
  149. processes = []
  150. for i in range(nprocs):
  151. error_queue = mp.SimpleQueue()
  152. process = mp.Process(
  153. target=_wrap,
  154. args=(fn, i, args, error_queue),
  155. daemon=daemon,
  156. )
  157. process.start()
  158. error_queues.append(error_queue)
  159. processes.append(process)
  160. context = ProcessContext(processes, error_queues)
  161. if not join:
  162. return context
  163. # Loop on join until it returns True or raises an exception.
  164. while not context.join():
  165. pass
  166. def spawn(fn, args=(), nprocs=1, join=True, daemon=False, start_method='spawn'):
  167. r"""Spawns ``nprocs`` processes that run ``fn`` with ``args``.
  168. If one of the processes exits with a non-zero exit status, the
  169. remaining processes are killed and an exception is raised with the
  170. cause of termination. In the case an exception was caught in the
  171. child process, it is forwarded and its traceback is included in
  172. the exception raised in the parent process.
  173. Args:
  174. fn (function): Function is called as the entrypoint of the
  175. spawned process. This function must be defined at the top
  176. level of a module so it can be pickled and spawned. This
  177. is a requirement imposed by multiprocessing.
  178. The function is called as ``fn(i, *args)``, where ``i`` is
  179. the process index and ``args`` is the passed through tuple
  180. of arguments.
  181. args (tuple): Arguments passed to ``fn``.
  182. nprocs (int): Number of processes to spawn.
  183. join (bool): Perform a blocking join on all processes.
  184. daemon (bool): The spawned processes' daemon flag. If set to True,
  185. daemonic processes will be created.
  186. start_method (str): (deprecated) this method will always use ``spawn``
  187. as the start method. To use a different start method
  188. use ``start_processes()``.
  189. Returns:
  190. None if ``join`` is ``True``,
  191. :class:`~ProcessContext` if ``join`` is ``False``
  192. """
  193. if start_method != 'spawn':
  194. msg = ('This method only supports start_method=spawn (got: %s).\n'
  195. 'To use a different start_method use:\n\t\t'
  196. ' torch.multiprocessing.start_processes(...)' % start_method)
  197. warnings.warn(msg)
  198. return start_processes(fn, args, nprocs, join, daemon, start_method='spawn')