r"""Contains definitions of the methods used by the _BaseDataLoaderIter workers.

These **need** to be in global scope since Py2 doesn't support serializing
static methods.
"""

import torch
import random
import os
import queue
from dataclasses import dataclass
from torch._utils import ExceptionWrapper
from typing import Optional, Union, TYPE_CHECKING
from . import signal_handling, MP_STATUS_CHECK_INTERVAL, IS_WINDOWS, HAS_NUMPY

if TYPE_CHECKING:
    from torch.utils.data import Dataset

if IS_WINDOWS:
    import ctypes
    from ctypes.wintypes import DWORD, BOOL, HANDLE

    # On Windows, the parent ID of the worker process remains unchanged when the manager process
    # is gone, and the only way to check it through OS is to let the worker have a process handle
    # of the manager and ask if the process status has changed.
    class ManagerWatchdog:
        def __init__(self):
            self.manager_pid = os.getppid()

            # mypy cannot detect this code is windows only
            self.kernel32 = ctypes.WinDLL('kernel32', use_last_error=True)  # type: ignore[attr-defined]
            self.kernel32.OpenProcess.argtypes = (DWORD, BOOL, DWORD)
            self.kernel32.OpenProcess.restype = HANDLE
            self.kernel32.WaitForSingleObject.argtypes = (HANDLE, DWORD)
            self.kernel32.WaitForSingleObject.restype = DWORD

            # Value obtained from https://msdn.microsoft.com/en-us/library/ms684880.aspx
            SYNCHRONIZE = 0x00100000
            self.manager_handle = self.kernel32.OpenProcess(SYNCHRONIZE, 0, self.manager_pid)

            if not self.manager_handle:
                raise ctypes.WinError(ctypes.get_last_error())  # type: ignore[attr-defined]

            self.manager_dead = False

        def is_alive(self):
            if not self.manager_dead:
                # Value obtained from https://msdn.microsoft.com/en-us/library/windows/desktop/ms687032.aspx
                self.manager_dead = self.kernel32.WaitForSingleObject(self.manager_handle, 0) == 0
            return not self.manager_dead
else:
    class ManagerWatchdog:  # type: ignore[no-redef]
        def __init__(self):
            self.manager_pid = os.getppid()
            self.manager_dead = False

        def is_alive(self):
            if not self.manager_dead:
                self.manager_dead = os.getppid() != self.manager_pid
            return not self.manager_dead


_worker_info = None


class WorkerInfo:
    id: int
    num_workers: int
    seed: int
    dataset: 'Dataset'
    __initialized = False

    def __init__(self, **kwargs):
        for k, v in kwargs.items():
            setattr(self, k, v)
        self.__keys = tuple(kwargs.keys())
        self.__initialized = True

    def __setattr__(self, key, val):
        if self.__initialized:
            raise RuntimeError("Cannot assign attributes to {} objects".format(self.__class__.__name__))
        return super().__setattr__(key, val)

    def __repr__(self):
        items = []
        for k in self.__keys:
            items.append('{}={}'.format(k, getattr(self, k)))
        return '{}({})'.format(self.__class__.__name__, ', '.join(items))


def get_worker_info() -> Optional[WorkerInfo]:
    r"""Returns the information about the current
    :class:`~torch.utils.data.DataLoader` iterator worker process.

    When called in a worker, this returns an object guaranteed to have the
    following attributes:

    * :attr:`id`: the current worker id.
    * :attr:`num_workers`: the total number of workers.
    * :attr:`seed`: the random seed set for the current worker. This value is
      determined by main process RNG and the worker id. See
      :class:`~torch.utils.data.DataLoader`'s documentation for more details.
    * :attr:`dataset`: the copy of the dataset object in **this** process. Note
      that this will be a different object in a different process than the one
      in the main process.

    When called in the main process, this returns ``None``.

    .. note::
       When used in a :attr:`worker_init_fn` passed over to
       :class:`~torch.utils.data.DataLoader`, this method can be useful to
       set up each worker process differently, for instance, using ``worker_id``
       to configure the ``dataset`` object to only read a specific fraction of a
       sharded dataset, or use ``seed`` to seed other libraries used in dataset
       code.
    """
    return _worker_info
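

# Illustrative sketch of the note above (not part of this module): a
# ``worker_init_fn`` passed to ``DataLoader`` can use ``get_worker_info()`` to
# shard an iterable dataset per worker. The ``start``/``end`` attributes below
# are hypothetical attributes of a user-defined IterableDataset.
#
#     def worker_init_fn(worker_id):
#         info = get_worker_info()
#         per_worker = (info.dataset.end - info.dataset.start) // info.num_workers
#         info.dataset.start += worker_id * per_worker
#         if worker_id != info.num_workers - 1:
#             # every worker but the last reads exactly `per_worker` items
#             info.dataset.end = info.dataset.start + per_worker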


r"""Dummy class used to signal the end of an IterableDataset"""
@dataclass(frozen=True)
class _IterableDatasetStopIteration:
    worker_id: int


r"""Dummy class used to resume the fetching when worker reuse is enabled"""
@dataclass(frozen=True)
class _ResumeIteration:
    seed: Optional[int] = None


# The function `_generate_state` is adapted from `numpy.random.SeedSequence`
# from https://github.com/numpy/numpy/blob/main/numpy/random/bit_generator.pyx
# It's MIT licensed, here is the copyright:

# Copyright (c) 2015 Melissa E. O'Neill
# Copyright (c) 2019 NumPy Developers
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.


# This function generates an array of int32 as the seed for
# `numpy.random`, in order to prevent state collision due to same
# seed and algorithm for `numpy.random` and `random` modules.
# TODO: Implement `SeedSequence` like object for `torch.random`
def _generate_state(base_seed, worker_id):
    INIT_A = 0x43b0d7e5
    MULT_A = 0x931e8875
    INIT_B = 0x8b51f9dd
    MULT_B = 0x58f38ded
    MIX_MULT_L = 0xca01f9dd
    MIX_MULT_R = 0x4973f715
    XSHIFT = 4 * 8 // 2
    MASK32 = 0xFFFFFFFF

    entropy = [worker_id, base_seed & MASK32, base_seed >> 32, 0]
    pool = [0] * 4

    hash_const_A = INIT_A

    def hash(value):
        nonlocal hash_const_A
        value = (value ^ hash_const_A) & MASK32
        hash_const_A = (hash_const_A * MULT_A) & MASK32
        value = (value * hash_const_A) & MASK32
        value = (value ^ (value >> XSHIFT)) & MASK32
        return value

    def mix(x, y):
        result_x = (MIX_MULT_L * x) & MASK32
        result_y = (MIX_MULT_R * y) & MASK32
        result = (result_x - result_y) & MASK32
        result = (result ^ (result >> XSHIFT)) & MASK32
        return result

    # Add in the entropy to the pool.
    for i in range(len(pool)):
        pool[i] = hash(entropy[i])

    # Mix all bits together so late bits can affect earlier bits.
    for i_src in range(len(pool)):
        for i_dst in range(len(pool)):
            if i_src != i_dst:
                pool[i_dst] = mix(pool[i_dst], hash(pool[i_src]))

    hash_const_B = INIT_B
    state = []
    for i_dst in range(4):
        data_val = pool[i_dst]
        data_val = (data_val ^ hash_const_B) & MASK32
        hash_const_B = (hash_const_B * MULT_B) & MASK32
        data_val = (data_val * hash_const_B) & MASK32
        data_val = (data_val ^ (data_val >> XSHIFT)) & MASK32
        state.append(data_val)
    return state
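

# Illustrative sketch of the intended consumption (this mirrors the worker
# setup in `_worker_loop` below; the literal values here are hypothetical):
#
#     import numpy as np
#     np.random.seed(_generate_state(base_seed=12345, worker_id=0))
#
# `base_seed` comes from the main process RNG, so each worker derives a
# distinct 4-word int32 state and `numpy.random` does not collide with the
# state of the `random` module.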


def _worker_loop(dataset_kind, dataset, index_queue, data_queue, done_event,
                 auto_collation, collate_fn, drop_last, base_seed, init_fn, worker_id,
                 num_workers, persistent_workers, shared_seed):
    # See NOTE [ Data Loader Multiprocessing Shutdown Logic ] for details on the
    # logic of this function.

    try:
        # Initialize C side signal handlers for SIGBUS and SIGSEGV. Python signal
        # module's handlers are executed after Python returns from C low-level
        # handlers, likely when the same fatal signal had already happened
        # again.
        # https://docs.python.org/3/library/signal.html#execution-of-python-signal-handlers
        signal_handling._set_worker_signal_handlers()

        torch.set_num_threads(1)
        seed = base_seed + worker_id
        random.seed(seed)
        torch.manual_seed(seed)
        if HAS_NUMPY:
            np_seed = _generate_state(base_seed, worker_id)
            import numpy as np
            np.random.seed(np_seed)

        from torch.utils.data import IterDataPipe
        from torch.utils.data.graph_settings import apply_random_seed

        shared_rng = torch.Generator()
        if isinstance(dataset, IterDataPipe):
            assert shared_seed is not None
            shared_rng.manual_seed(shared_seed)
            dataset = apply_random_seed(dataset, shared_rng)

        global _worker_info
        _worker_info = WorkerInfo(id=worker_id, num_workers=num_workers,
                                  seed=seed, dataset=dataset)

        from torch.utils.data import _DatasetKind

        init_exception = None

        try:
            if init_fn is not None:
                init_fn(worker_id)

            fetcher = _DatasetKind.create_fetcher(dataset_kind, dataset, auto_collation, collate_fn, drop_last)
        except Exception:
            init_exception = ExceptionWrapper(
                where="in DataLoader worker process {}".format(worker_id))

        # When using Iterable mode, some workers can exit earlier than others due
        # to the IterableDataset behaving differently for different workers.
        # When that happens, an `_IterableDatasetStopIteration` object is
        # sent over to the main process with the ID of this worker, so that the
        # main process won't send more tasks to this worker, and will send
        # `None` to this worker to properly exit it.
        #
        # Note that we cannot set `done_event` from a worker as it is shared
        # among all processes. Instead, we set the `iteration_end` flag to
        # signify that the iterator is exhausted. When either `done_event` or
        # `iteration_end` is set, we skip all processing steps and just wait for
        # `None`.
        iteration_end = False

        watchdog = ManagerWatchdog()

        while watchdog.is_alive():
            try:
                r = index_queue.get(timeout=MP_STATUS_CHECK_INTERVAL)
            except queue.Empty:
                continue
            if isinstance(r, _ResumeIteration):
                # Acknowledge the main process
                data_queue.put((r, None))
                iteration_end = False

                if isinstance(dataset, IterDataPipe):
                    assert r.seed is not None
                    shared_rng.manual_seed(r.seed)
                    dataset = apply_random_seed(dataset, shared_rng)

                # Recreate the fetcher for worker-reuse policy
                fetcher = _DatasetKind.create_fetcher(
                    dataset_kind, dataset, auto_collation, collate_fn, drop_last)
                continue
            elif r is None:
                # Received the final signal
                assert done_event.is_set() or iteration_end
                break
            elif done_event.is_set() or iteration_end:
                # `done_event` is set, but the final signal (`None`) hasn't
                # arrived yet. Keep looping until it does, skipping the
                # processing steps.
                continue
            idx, index = r
            data: Union[_IterableDatasetStopIteration, ExceptionWrapper]
            if init_exception is not None:
                data = init_exception
                init_exception = None
            else:
                try:
                    data = fetcher.fetch(index)
                except Exception as e:
                    if isinstance(e, StopIteration) and dataset_kind == _DatasetKind.Iterable:
                        data = _IterableDatasetStopIteration(worker_id)
                        # Set `iteration_end`
                        #   (1) to save future `next(...)` calls, and
                        #   (2) to avoid sending multiple `_IterableDatasetStopIteration`s.
                        iteration_end = True
                    else:
                        # It is important that we don't store exc_info in a variable.
                        # `ExceptionWrapper` does the correct thing.
                        # See NOTE [ Python Traceback Reference Cycle Problem ]
                        data = ExceptionWrapper(
                            where="in DataLoader worker process {}".format(worker_id))
            data_queue.put((idx, data))
            del data, idx, index, r  # save memory
    except KeyboardInterrupt:
        # Main process will raise KeyboardInterrupt anyways.
        pass
    if done_event.is_set():
        data_queue.cancel_join_thread()
        data_queue.close()