__init__.py

import os
import sys
from enum import Enum

import torch


def is_available() -> bool:
    """
    Returns ``True`` if the distributed package is available. Otherwise,
    ``torch.distributed`` does not expose any other APIs. Currently,
    ``torch.distributed`` is available on Linux, MacOS and Windows. Set
    ``USE_DISTRIBUTED=1`` to enable it when building PyTorch from source.
    Currently, the default value is ``USE_DISTRIBUTED=1`` for Linux and Windows,
    ``USE_DISTRIBUTED=0`` for MacOS.
    """
    return hasattr(torch._C, "_c10d_init")
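
# Illustrative usage sketch (not part of the original module): downstream code
# is expected to gate distributed-only paths on this check, e.g.:
#
#     import torch.distributed as dist
#
#     if dist.is_available() and dist.is_initialized():
#         dist.barrier()  # only reached when a process group is set up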

if is_available() and not torch._C._c10d_init():
    raise RuntimeError("Failed to initialize torch.distributed")

# Custom Runtime Errors thrown from the distributed package
DistBackendError = torch._C._DistBackendError

if is_available():
    from torch._C._distributed_c10d import (
        Store,
        FileStore,
        TCPStore,
        ProcessGroup,
        Backend as _Backend,
        PrefixStore,
        Reducer,
        Logger,
        BuiltinCommHookType,
        GradBucket,
        Work as _Work,
        _DEFAULT_FIRST_BUCKET_BYTES,
        _register_comm_hook,
        _register_builtin_comm_hook,
        _broadcast_coalesced,
        _compute_bucket_assignment_by_size,
        _verify_params_across_processes,
        _test_python_store,
        DebugLevel,
        get_debug_level,
        set_debug_level,
        set_debug_level_from_env,
        _make_nccl_premul_sum,
    )
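    # Illustrative sketch (not part of the original module): these C++ bindings
    # back the public store/rendezvous APIs. Assuming a writable /tmp path, a
    # minimal single-process setup using the file:// handler could look like:
    #
    #     import torch.distributed as dist
    #
    #     dist.init_process_group(
    #         backend="gloo",
    #         init_method="file:///tmp/pg_example",  # hypothetical store file
    #         rank=0,
    #         world_size=1,
    #     )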
    if sys.platform != "win32":
        from torch._C._distributed_c10d import (
            HashStore,
            _round_robin_process_groups,
        )

    from .distributed_c10d import *  # noqa: F403

    # Variables prefixed with underscore are not auto imported
    # See the comment in `distributed_c10d.py` above `_backend` on why we expose
    # this.
    from .distributed_c10d import (
        _backend,
        _all_gather_base,
        _reduce_scatter_base,
        _create_process_group_wrapper,
        _rank_not_in_group,
        _c10d_error_logger,
    )

    from .rendezvous import (
        rendezvous,
        _create_store_from_options,
        register_rendezvous_handler,
    )
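    # Illustrative sketch (not part of the original module): `rendezvous` is a
    # generator that yields (store, rank, world_size) triples, e.g. with the
    # file:// handler and a hypothetical path:
    #
    #     store, rank, world_size = next(
    #         rendezvous("file:///tmp/rdzv_example", rank=0, world_size=1)
    #     )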
    from .remote_device import _remote_device

    set_debug_level_from_env()
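    # Illustrative note (not part of the original module): this call reads the
    # TORCH_DISTRIBUTED_DEBUG environment variable (OFF, INFO, or DETAIL), so
    # debug output can be enabled without code changes, e.g.:
    #
    #     TORCH_DISTRIBUTED_DEBUG=DETAIL python train.py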
else:
    # This stub is sufficient to get
    #   python test/test_public_bindings.py -k test_correct_module_names
    # working even when USE_DISTRIBUTED=0. Feel free to add more
    # stubs as necessary.
    # We cannot define stubs directly because they confuse pyre.
    class _ProcessGroupStub:
        pass

    sys.modules["torch.distributed"].ProcessGroup = _ProcessGroupStub  # type: ignore[attr-defined]