from typing import Optional
import multiprocessing
import multiprocessing.connection
import signal
import sys
import warnings

from . import _prctl_pr_set_pdeathsig  # type: ignore[attr-defined]


class ProcessException(Exception):
    __slots__ = ["error_index", "pid"]

    def __init__(self, msg: str, error_index: int, pid: int):
        super().__init__(msg)
        self.msg = msg
        self.error_index = error_index
        self.pid = pid

    def __reduce__(self):
        # Makes the exception picklable so it can cross process boundaries.
        return type(self), (self.msg, self.error_index, self.pid)


class ProcessRaisedException(ProcessException):
    """
    Exception raised when a process failed due to an exception
    raised by the code it ran.
    """

    def __init__(
        self,
        msg: str,
        error_index: int,
        error_pid: int,
    ):
        super().__init__(msg, error_index, error_pid)


class ProcessExitedException(ProcessException):
    """
    Exception raised when a process failed due to a signal
    or exited with a specific code.
    """

    __slots__ = ["exit_code"]

    def __init__(
        self, msg: str, error_index: int, error_pid: int,
        exit_code: int, signal_name: Optional[str] = None
    ):
        super().__init__(msg, error_index, error_pid)
        self.exit_code = exit_code
        self.signal_name = signal_name

    def __reduce__(self):
        return (
            type(self),
            (self.msg, self.error_index, self.pid, self.exit_code, self.signal_name),
        )
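

# A minimal sketch (illustrative only, not part of this module) of why
# ``__reduce__`` is defined above: it lets these exceptions survive a pickle
# round-trip, which is what happens when they cross process boundaries.
#
#     import pickle
#     err = ProcessExitedException("worker died", 0, 1234, exit_code=1)
#     clone = pickle.loads(pickle.dumps(err))
#     assert clone.exit_code == 1 and clone.pid == 1234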


def _wrap(fn, i, args, error_queue):
    # prctl(2) is a Linux specific system call.
    # On other systems the following function call has no effect.
    # This is set to ensure that non-daemonic child processes can
    # terminate if their parent terminates before they do.
    _prctl_pr_set_pdeathsig(signal.SIGINT)

    try:
        fn(i, *args)
    except KeyboardInterrupt:
        pass  # SIGINT; killed by parent, do nothing
    except Exception:
        # Propagate exception to parent process, keeping original traceback
        import traceback
        error_queue.put(traceback.format_exc())
        sys.exit(1)


class ProcessContext:
    def __init__(self, processes, error_queues):
        self.error_queues = error_queues
        self.processes = processes
        self.sentinels = {
            process.sentinel: index
            for index, process in enumerate(processes)
        }

    def pids(self):
        return [int(process.pid) for process in self.processes]

    def join(self, timeout=None):
        r"""
        Tries to join one or more processes in this spawn context.
        If one of them exited with a non-zero exit status, this function
        kills the remaining processes and raises an exception with the cause
        of the first process exiting.

        Returns ``True`` if all processes have been joined successfully,
        ``False`` if there are more processes that need to be joined.

        Args:
            timeout (float): Wait this long before giving up on waiting.
        """
        # Ensure this function can be called even when we're done.
        if len(self.sentinels) == 0:
            return True

        # Wait for any process to fail or all of them to succeed.
        ready = multiprocessing.connection.wait(
            self.sentinels.keys(),
            timeout=timeout,
        )

        error_index = None
        for sentinel in ready:
            index = self.sentinels.pop(sentinel)
            process = self.processes[index]
            process.join()
            if process.exitcode != 0:
                error_index = index
                break

        # Return if there was no error.
        if error_index is None:
            # Return whether or not all processes have been joined.
            return len(self.sentinels) == 0

        # Assume failure. Terminate processes that are still alive.
        for process in self.processes:
            if process.is_alive():
                process.terminate()
            process.join()

        # There won't be an error on the queue if the process crashed.
        failed_process = self.processes[error_index]
        if self.error_queues[error_index].empty():
            exitcode = self.processes[error_index].exitcode
            if exitcode < 0:
                # A negative exit code means the process was killed by a signal.
                name = signal.Signals(-exitcode).name
                raise ProcessExitedException(
                    "process %d terminated with signal %s" %
                    (error_index, name),
                    error_index=error_index,
                    error_pid=failed_process.pid,
                    exit_code=exitcode,
                    signal_name=name
                )
            else:
                raise ProcessExitedException(
                    "process %d terminated with exit code %d" %
                    (error_index, exitcode),
                    error_index=error_index,
                    error_pid=failed_process.pid,
                    exit_code=exitcode
                )

        original_trace = self.error_queues[error_index].get()
        msg = "\n\n-- Process %d terminated with the following error:\n" % error_index
        msg += original_trace
        raise ProcessRaisedException(msg, error_index, failed_process.pid)
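

# A minimal sketch (illustrative only) of the non-blocking pattern that
# ``join`` supports; ``worker`` and the polling interval are assumptions,
# not part of this module:
#
#     context = start_processes(worker, nprocs=4, join=False)
#     while not context.join(timeout=1.0):
#         pass  # do other work between polls; join() raises on failure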


class SpawnContext(ProcessContext):
    def __init__(self, processes, error_queues):
        warnings.warn('SpawnContext was renamed to ProcessContext in the 1.4 release.')
        super().__init__(processes, error_queues)


# Note: [start_processes]
# mp.start_processes handles both start_method='spawn' and 'fork'. It's supposed to be a
# more generalized API than mp.spawn. Currently we only document mp.spawn as it's the
# CUDA compatible start_method. However, in environments like IPython notebooks, 'fork'
# works better than 'spawn'. Every helper function we created for mp.spawn is indeed
# general enough, and backends like XLA can reuse them in Colab notebooks as well.
# For now we only add the API; we can consider adding it to the documentation as
# needed in the future.
def start_processes(fn, args=(), nprocs=1, join=True, daemon=False, start_method='spawn'):
    mp = multiprocessing.get_context(start_method)
    error_queues = []
    processes = []
    for i in range(nprocs):
        error_queue = mp.SimpleQueue()
        process = mp.Process(
            target=_wrap,
            args=(fn, i, args, error_queue),
            daemon=daemon,
        )
        process.start()
        error_queues.append(error_queue)
        processes.append(process)

    context = ProcessContext(processes, error_queues)
    if not join:
        return context

    # Loop on join until it returns True or raises an exception.
    while not context.join():
        pass
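

# A minimal sketch (illustrative only) of using start_processes with the
# 'fork' start method, e.g. in an IPython notebook; ``worker`` is an
# assumption, not part of this module. Each function is invoked as
# ``fn(i, *args)``:
#
#     def worker(i, total):
#         print(f"process {i} of {total}")
#
#     start_processes(worker, args=(4,), nprocs=4, start_method='fork')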


def spawn(fn, args=(), nprocs=1, join=True, daemon=False, start_method='spawn'):
    r"""Spawns ``nprocs`` processes that run ``fn`` with ``args``.

    If one of the processes exits with a non-zero exit status, the
    remaining processes are killed and an exception is raised with the
    cause of termination. If an exception was caught in a child
    process, it is forwarded and its traceback is included in the
    exception raised in the parent process.

    Args:
        fn (function): Function is called as the entrypoint of the
            spawned process. This function must be defined at the top
            level of a module so it can be pickled and spawned. This
            is a requirement imposed by multiprocessing.

            The function is called as ``fn(i, *args)``, where ``i`` is
            the process index and ``args`` is the passed through tuple
            of arguments.

        args (tuple): Arguments passed to ``fn``.
        nprocs (int): Number of processes to spawn.
        join (bool): Perform a blocking join on all processes.
        daemon (bool): The spawned processes' daemon flag. If set to True,
            daemonic processes will be created.
        start_method (str): (deprecated) this method will always use ``spawn``
            as the start method. To use a different start method
            use ``start_processes()``.

    Returns:
        None if ``join`` is ``True``,
        :class:`~ProcessContext` if ``join`` is ``False``
    """
    if start_method != 'spawn':
        msg = ('This method only supports start_method=spawn (got: %s).\n'
               'To use a different start_method use:\n\t\t'
               ' torch.multiprocessing.start_processes(...)' % start_method)
        warnings.warn(msg)
    return start_processes(fn, args, nprocs, join, daemon, start_method='spawn')
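

# A minimal usage sketch (illustrative only); ``_demo`` is an assumption,
# not part of this module. Each worker receives its index as the first
# argument, so ``_demo`` runs as ``_demo(0, 'hello')`` and ``_demo(1, 'hello')``:
#
#     def _demo(i, msg):  # must live at module top level so it can be pickled
#         print(f"worker {i}: {msg}")
#
#     if __name__ == '__main__':
#         spawn(_demo, args=('hello',), nprocs=2)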