# _distributed_rpc.pyi

  1. from typing import Any, Dict, List, Optional, Tuple, Union, overload
  2. from datetime import timedelta
  3. import enum
  4. import torch
  5. from torch.types import Device
  6. from . import Future
  7. from ._autograd import ProfilerEvent
  8. from ._distributed_c10d import ProcessGroup, Store
  9. from ._profiler import ActiveProfilerType, ProfilerConfig, ProfilerState
# This module is defined in torch/csrc/distributed/rpc/init.cpp
  11. _DEFAULT_INIT_METHOD: str
  12. _DEFAULT_NUM_WORKER_THREADS: int
  13. _UNSET_RPC_TIMEOUT: float
  14. _DEFAULT_RPC_TIMEOUT_SEC: float
  15. class RpcBackendOptions:
  16. rpc_timeout: float
  17. init_method: str
  18. def __init__(
  19. self,
  20. rpc_timeout: float = _DEFAULT_RPC_TIMEOUT_SEC,
  21. init_method: str = _DEFAULT_INIT_METHOD,
  22. ): ...
  23. class WorkerInfo:
  24. def __init__(self, name: str, worker_id: int): ...
  25. @property
  26. def name(self) -> str: ...
  27. @property
  28. def id(self) -> int: ...
  29. def __eq__(self, other: object) -> bool: ...
  30. def __repr__(self) -> str: ...
  31. class RpcAgent:
  32. def join(self, shutdown: bool = False, timeout: float = 0): ...
  33. def sync(self): ...
  34. def shutdown(self): ...
  35. @overload
  36. def get_worker_info(self) -> WorkerInfo: ...
  37. @overload
  38. def get_worker_info(self, workerName: str) -> WorkerInfo: ...
  39. def get_worker_infos(self) -> List[WorkerInfo]: ...
  40. def _get_device_map(self, dst: WorkerInfo) -> Dict[torch.device, torch.device]: ...
  41. def get_debug_info(self) -> Dict[str, str]: ...
  42. def get_metrics(self) -> Dict[str, str]: ...
  43. class PyRRef:
  44. def __init__(self, value: Any, type_hint: Any = None): ...
  45. def is_owner(self) -> bool: ...
  46. def confirmed_by_owner(self) -> bool: ...
  47. def owner(self) -> WorkerInfo: ...
  48. def owner_name(self) -> str: ...
  49. def to_here(self, timeout: float = _UNSET_RPC_TIMEOUT) -> Any: ...
  50. def local_value(self) -> Any: ...
  51. def rpc_sync(self, timeout: float = _UNSET_RPC_TIMEOUT) -> Any: ...
  52. def rpc_async(self, timeout: float = _UNSET_RPC_TIMEOUT) -> Any: ...
  53. def remote(self, timeout: float = _UNSET_RPC_TIMEOUT) -> Any: ...
  54. def _serialize(self) -> Tuple: ...
  55. @staticmethod
  56. def _deserialize(tp: Tuple) -> 'PyRRef': ...
  57. def _get_type(self) -> Any: ...
  58. def _get_future(self) -> Future: ...
  59. def _get_profiling_future(self) -> Future: ...
  60. def _set_profiling_future(self, profilingFuture: Future): ...
  61. def __repr__(self) -> str: ...
  62. ...
  63. class _TensorPipeRpcBackendOptionsBase(RpcBackendOptions):
  64. num_worker_threads: int
  65. device_maps: Dict[str, Dict[torch.device, torch.device]]
  66. devices: List[torch.device]
  67. def __init__(
  68. self,
  69. num_worker_threads: int,
  70. _transports: Optional[List],
  71. _channels: Optional[List],
  72. rpc_timeout: float = _DEFAULT_RPC_TIMEOUT_SEC,
  73. init_method: str = _DEFAULT_INIT_METHOD,
  74. device_maps: Dict[str, Dict[torch.device, torch.device]] = {},
  75. devices: List[torch.device] = list()): ...
  76. def _set_device_map(self, to: str, device_map: Dict[torch.device, torch.device]): ...
  77. class TensorPipeAgent(RpcAgent):
  78. def __init__(
  79. self,
  80. store: Store,
  81. name: str,
  82. worker_id: int,
  83. world_size: Optional[int],
  84. opts: _TensorPipeRpcBackendOptionsBase,
  85. reverse_device_maps: Dict[str, Dict[torch.device, torch.device]],
  86. devices: List[torch.device],
  87. ): ...
  88. def join(self, shutdown: bool = False, timeout: float = 0): ...
  89. def shutdown(self): ...
  90. @overload
  91. def get_worker_info(self) -> WorkerInfo: ...
  92. @overload
  93. def get_worker_info(self, workerName: str) -> WorkerInfo: ...
  94. @overload
  95. def get_worker_info(self, id: int) -> WorkerInfo: ...
  96. def get_worker_infos(self) -> List[WorkerInfo]: ...
  97. def _get_device_map(self, dst: WorkerInfo) -> Dict[torch.device, torch.device]: ...
  98. def _update_group_membership(
  99. self,
  100. worker_info: WorkerInfo,
  101. my_devices: List[torch.device],
  102. reverse_device_map: Dict[str, Dict[torch.device, torch.device]],
  103. is_join: bool): ...
  104. def _get_backend_options(self) -> _TensorPipeRpcBackendOptionsBase: ...
  105. @property
  106. def is_static_group(self) -> bool: ...
  107. @property
  108. def store(self) -> Store: ...
  109. def _is_current_rpc_agent_set() -> bool: ...
  110. def _get_current_rpc_agent()-> RpcAgent: ...
  111. def _set_and_start_rpc_agent(agent: RpcAgent): ...
  112. def _reset_current_rpc_agent(): ...
  113. def _delete_all_user_and_unforked_owner_rrefs(timeout: timedelta = ...): ...
  114. def _destroy_rref_context(ignoreRRefLeak: bool): ...
  115. def _rref_context_get_debug_info() -> Dict[str, str]: ...
  116. def _cleanup_python_rpc_handler(): ...
  117. def _invoke_rpc_builtin(
  118. dst: WorkerInfo,
  119. opName: str,
  120. rpcTimeoutSeconds: float,
  121. *args: Any,
  122. **kwargs: Any
  123. ): ...
  124. def _invoke_rpc_python_udf(
  125. dst: WorkerInfo,
  126. pickledPythonUDF: str,
  127. tensors: List[torch.Tensor],
  128. rpcTimeoutSeconds: float,
  129. isAsyncExecution: bool
  130. ): ...
  131. def _invoke_rpc_torchscript(
  132. dstWorkerName: str,
  133. qualifiedNameStr: str,
  134. argsTuple: Tuple,
  135. kwargsDict: Dict,
  136. rpcTimeoutSeconds: float,
  137. isAsyncExecution: bool,
  138. ): ...
  139. def _invoke_remote_builtin(
  140. dst: WorkerInfo,
  141. opName: str,
  142. rpcTimeoutSeconds: float,
  143. *args: Any,
  144. **kwargs: Any
  145. ): ...
  146. def _invoke_remote_python_udf(
  147. dst: WorkerInfo,
  148. pickledPythonUDF: str,
  149. tensors: List[torch.Tensor],
  150. rpcTimeoutSeconds: float,
  151. isAsyncExecution: bool,
  152. ): ...
  153. def _invoke_remote_torchscript(
  154. dstWorkerName: WorkerInfo,
  155. qualifiedNameStr: str,
  156. rpcTimeoutSeconds: float,
  157. isAsyncExecution: bool,
  158. *args: Any,
  159. **kwargs: Any
  160. ): ...
  161. def get_rpc_timeout() -> float: ...
  162. def enable_gil_profiling(flag: bool): ...
  163. def _set_rpc_timeout(rpcTimeoutSeconds: float): ...
  164. class RemoteProfilerManager:
  165. @staticmethod
  166. def set_current_profiling_key(key: str): ...
  167. def _enable_server_process_global_profiler(new_config: ProfilerConfig): ...
  168. def _disable_server_process_global_profiler() -> List[List[List[ProfilerEvent]]]: ...
  169. def _set_profiler_node_id(default_node_id: int): ...
  170. def _enable_jit_rref_pickle(): ...
  171. def _disable_jit_rref_pickle(): ...