  1. """
  2. torch.multiprocessing is a wrapper around the native :mod:`multiprocessing`
  3. module. It registers custom reducers, that use shared memory to provide shared
  4. views on the same data in different processes. Once the tensor/storage is moved
  5. to shared_memory (see :func:`~torch.Tensor.share_memory_`), it will be possible
  6. to send it to other processes without making any copies.
  7. The API is 100% compatible with the original module - it's enough to change
  8. ``import multiprocessing`` to ``import torch.multiprocessing`` to have all the
  9. tensors sent through the queues or shared via other mechanisms, moved to shared
  10. memory.
  11. Because of the similarity of APIs we do not document most of this package
  12. contents, and we recommend referring to very good docs of the original module.
  13. """
import torch
import sys
from .reductions import init_reductions
import multiprocessing

__all__ = ['set_sharing_strategy', 'get_sharing_strategy',
           'get_all_sharing_strategies']


from multiprocessing import *  # noqa: F403


__all__ += multiprocessing.__all__  # type: ignore[attr-defined]


# This call adds a Linux specific prctl(2) wrapper function to this module.
# See https://github.com/pytorch/pytorch/pull/14391 for more information.
torch._C._multiprocessing_init()
  25. """Add helper function to spawn N processes and wait for completion of any of
  26. them. This depends `mp.get_context` which was added in Python 3.4."""
  27. from .spawn import spawn, SpawnContext, start_processes, ProcessContext, \
  28. ProcessRaisedException, ProcessExitedException
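# A hedged sketch of the spawn helper in use (the ``_run`` function, its
# arguments, and the process count are illustrative, not part of this module):
#
#     import torch.multiprocessing as mp
#
#     def _run(rank, world_size):
#         print(f"worker {rank} of {world_size}")
#
#     if __name__ == '__main__':
#         # Starts ``nprocs`` processes running ``_run(rank, *args)`` and, with
#         # the default ``join=True``, blocks until all of them exit, raising
#         # ProcessRaisedException / ProcessExitedException on failure.
#         mp.spawn(_run, args=(4,), nprocs=4)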
if sys.platform == 'darwin' or sys.platform == 'win32':
    _sharing_strategy = 'file_system'
    _all_sharing_strategies = {'file_system'}
else:
    _sharing_strategy = 'file_descriptor'
    _all_sharing_strategies = {'file_descriptor', 'file_system'}


def set_sharing_strategy(new_strategy):
    """Sets the strategy for sharing CPU tensors.

    Args:
        new_strategy (str): Name of the selected strategy. Should be one of
            the values returned by :func:`get_all_sharing_strategies()`.
    """
    global _sharing_strategy
    assert new_strategy in _all_sharing_strategies
    _sharing_strategy = new_strategy


def get_sharing_strategy():
    """Returns the current strategy for sharing CPU tensors."""
    return _sharing_strategy


def get_all_sharing_strategies():
    """Returns a set of sharing strategies supported on the current system."""
    return _all_sharing_strategies
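# Sketch of switching strategies at runtime, assuming a workload that prefers
# the portable 'file_system' strategy over the Linux default 'file_descriptor':
#
#     import torch.multiprocessing as mp
#
#     # 'file_system' is available on every platform; 'file_descriptor' is not.
#     if 'file_system' in mp.get_all_sharing_strategies():
#         mp.set_sharing_strategy('file_system')
#     assert mp.get_sharing_strategy() == 'file_system'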
init_reductions()