signal_handling.py 3.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172
r"""Signal handling for multiprocessing data loading.

NOTE [ Signal handling in multiprocessing data loading ]

In cases like DataLoader, if a worker process dies due to a bus error or segfault,
or simply hangs, the main process will hang waiting for data. This is difficult
to avoid on the PyTorch side, as it can be caused by limited shared memory (shm)
or by other libraries that users call in the workers. In this file and
`DataLoader.cpp`, we make a best effort to provide an error message to users
when such unfortunate events happen.

When a _BaseDataLoaderIter starts worker processes, their pids are registered in a
map defined in `DataLoader.cpp`: id(_BaseDataLoaderIter) => Collection[ Worker pids ]
via `_set_worker_pids`.

When an error happens in a worker process, the main process receives a SIGCHLD,
and Python will eventually call the handler registered below
(in `_set_SIGCHLD_handler`). In the handler, the `_error_if_any_worker_fails`
call checks all registered worker pids and raises a proper error message to
prevent the main process from hanging while waiting for data from the workers.

Additionally, at the beginning of each worker's `_utils.worker._worker_loop`,
`_set_worker_signal_handlers` is called to register critical signal handlers
(e.g., for SIGSEGV, SIGBUS, SIGFPE, SIGTERM) in C, which just print an error
message to stderr before triggering the default handler. So a message will also
be printed from the worker process when it is killed by such signals.

See NOTE [ Data Loader Multiprocessing Shutdown Logic ] for the reasoning behind
this signal handling design and the other mechanisms we implement to make our
multiprocessing data loading robust to errors.
"""
  26. import signal
  27. import threading
  28. from . import IS_WINDOWS
# Some of the following imported functions are not used in this file, but are
# imported here so that other modules can use them as
# `_utils.signal_handling.XXXXX`.
  31. from torch._C import _set_worker_pids, _remove_worker_pids # noqa: F401
  32. from torch._C import _error_if_any_worker_fails, _set_worker_signal_handlers # noqa: F401
  33. _SIGCHLD_handler_set = False
  34. r"""Whether SIGCHLD handler is set for DataLoader worker failures. Only one
  35. handler needs to be set for all DataLoaders in a process."""
  36. def _set_SIGCHLD_handler():
  37. # Windows doesn't support SIGCHLD handler
  38. if IS_WINDOWS:
  39. return
  40. # can't set signal in child threads
  41. if not isinstance(threading.current_thread(), threading._MainThread): # type: ignore[attr-defined]
  42. return
  43. global _SIGCHLD_handler_set
  44. if _SIGCHLD_handler_set:
  45. return
  46. previous_handler = signal.getsignal(signal.SIGCHLD)
  47. if not callable(previous_handler):
  48. # This doesn't catch default handler, but SIGCHLD default handler is a
  49. # no-op.
  50. previous_handler = None
  51. def handler(signum, frame):
  52. # This following call uses `waitid` with WNOHANG from C side. Therefore,
  53. # Python can still get and update the process status successfully.
  54. _error_if_any_worker_fails()
  55. if previous_handler is not None:
  56. assert callable(previous_handler)
  57. previous_handler(signum, frame)
  58. signal.signal(signal.SIGCHLD, handler)
  59. _SIGCHLD_handler_set = True