_distributed_c10d.pyi 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458
  1. from datetime import timedelta
  2. from enum import Enum
  3. from typing import Any, Dict, List, Optional, overload, Tuple, Union
  4. from torch import Tensor
  5. from torch.futures import Future
  6. # This module is defined in torch/csrc/distributed/c10d/init.cpp
  7. _DEFAULT_FIRST_BUCKET_BYTES: int
  8. _DEFAULT_NO_TIMEOUT: timedelta
  9. _DEFAULT_PG_TIMEOUT: timedelta
  10. class BuiltinCommHookType(Enum):
  11. ALLREDUCE = ...
  12. FP16_COMPRESS = ...
  13. def _register_comm_hook(reducer: Reducer, state: Any, comm_hook: Any): ...
  14. def _register_builtin_comm_hook(
  15. reducer: Reducer, comm_hook_type: BuiltinCommHookType
  16. ): ...
  17. class GradBucket:
  18. def index(self) -> int: ...
  19. def buffer(self) -> Tensor: ...
  20. def gradients(self) -> List[Tensor]: ...
  21. def is_last(self) -> bool: ...
  22. def set_buffer(self, tensor: Tensor) -> None: ...
  23. def parameters(self) -> List[Tensor]: ...
  24. class Reducer:
  25. def __init__(
  26. self,
  27. params: List[Tensor],
  28. bucket_indices: List[List[int]],
  29. per_bucket_size_limits: List[int],
  30. process_group: ProcessGroup,
  31. expect_sparse_gradients: List[bool] = ...,
  32. bucket_bytes_cap: int = ..., # kDefaultBucketBytesCap in reducer.hpp
  33. find_unused_parameters: bool = ...,
  34. gradient_as_bucket_view: bool = ...,
  35. param_to_name_mapping: Dict[int, str] = ...,
  36. first_bucket_types_cap: int = ..., # kDefaultFirstBucketBytes in reducer.hpp
  37. ): ...
  38. def prepare_for_forward(self) -> None: ...
  39. def prepare_for_backward(self, output: List[Tensor]) -> None: ...
  40. def get_backward_stats(self) -> List[int]: ...
  41. def _install_post_backward_futures(self, futures: List[Future]) -> None: ...
  42. def _rebuild_buckets(self) -> bool: ...
  43. def _get_zeros_like_grad_buckets(self) -> List[GradBucket]: ...
  44. def _push_all_rebuilt_params(self) -> None: ...
  45. def _set_forward_pass_work_handle(
  46. self, work: Work, use_static_world_size: bool
  47. ): ...
  48. def _get_local_used_map(self) -> Tensor: ...
  49. def _set_ddp_runtime_logging_sample_rate(self, sample_rate: int) -> None: ...
  50. def _set_static_graph(self) -> None: ...
  51. def _run_comm_hook(self, bucket: GradBucket) -> Future: ...
  52. def set_logger(self, logger: Logger) -> None: ...
  53. class DDPLoggingData:
  54. strs_map: Dict[str, str]
  55. ints_map: Dict[str, int]
  56. class Logger:
  57. def __init__(self, reducer: Reducer): ...
  58. def set_construction_data_and_log(
  59. self,
  60. module_name: str,
  61. device_ids: List[int],
  62. output_device: int,
  63. broadcast_buffers: bool,
  64. has_sync_bn: bool,
  65. static_graph: bool,
  66. ): ...
  67. def set_runtime_stats_and_log(self) -> None: ...
  68. def set_error_and_log(self, error: str) -> None: ...
  69. def _get_ddp_logging_data(self) -> DDPLoggingData: ...
  70. def _set_comm_hook_name(self, comm_hook: str) -> None: ...
  71. def _set_uneven_input_join(self) -> None: ...
  72. def _set_static_graph(self) -> None: ...
  73. def get_debug_level(): ...
  74. def set_debug_level(): ...
  75. def set_debug_level_from_env(): ...
  76. class DebugLevel(Enum):
  77. OFF = ...
  78. INFO = ...
  79. DETAIL = ...
  80. class ReduceOp:
  81. def __init__(self, op: "RedOpType"): ...
  82. SUM = ...
  83. PRODUCT = ...
  84. MIN = ...
  85. MAX = ...
  86. BAND = ...
  87. BOR = ...
  88. BXOR = ...
  89. PREMUL_SUM = ...
  90. UNUSED = ...
  91. class RedOpType(Enum): ...
  92. class BroadcastOptions:
  93. rootRank: int
  94. rootTensor: int
  95. timeout: timedelta
  96. class AllreduceOptions:
  97. reduceOp: ReduceOp
  98. timeout: timedelta
  99. class AllreduceCoalescedOptions(AllreduceOptions): ...
  100. class ReduceOptions:
  101. reduceOp: ReduceOp
  102. rootRank: int
  103. rootTensor: int
  104. timeout: timedelta
  105. class AllGatherOptions:
  106. timeout: timedelta
  107. class GatherOptions:
  108. rootRank: int
  109. timeout: timedelta
  110. class ScatterOptions:
  111. rootRank: int
  112. timeout: timedelta
  113. class ReduceScatterOptions:
  114. reduceOp: ReduceOp
  115. timeout: timedelta
  116. class BarrierOptions:
  117. device_ids: List[int]
  118. timeout: timedelta
  119. class AllToAllOptions:
  120. timeout: timedelta
  121. class Store:
  122. def set(self, key: str, value: str): ...
  123. def get(self, key: str) -> bytes: ...
  124. def add(self, key: str, value: int) -> int: ...
  125. def compare_set(
  126. self, key: str, expected_value: str, desired_value: str
  127. ) -> bytes: ...
  128. def delete_key(self, key: str) -> bool: ...
  129. def num_keys(self) -> int: ...
  130. def set_timeout(self, timeout: timedelta): ...
  131. @overload
  132. def wait(self, keys: List[str]): ...
  133. @overload
  134. def wait(self, keys: List[str], timeout: timedelta): ...
  135. class FileStore(Store):
  136. def __init__(self, path: str, numWorkers: int = ...): ...
  137. class HashStore(Store):
  138. def __init__(self): ...
  139. class TCPStore(Store):
  140. def __init__(
  141. self,
  142. host_name: str,
  143. port: int,
  144. world_size: Optional[int] = ...,
  145. is_master: bool = ...,
  146. timeout: timedelta = ...,
  147. wait_for_workers: bool = ...,
  148. multi_tenant: bool = ...,
  149. ): ...
  150. @property
  151. def host(self) -> str: ...
  152. @property
  153. def port(self) -> int: ...
  154. class PrefixStore(Store):
  155. def __init__(self, prefix: str, store: Store): ...
  156. @property
  157. def underlying_store(self) -> Store: ...
  158. class Work:
  159. def is_completed(self) -> bool: ...
  160. def is_success(self) -> bool: ...
  161. def exception(self) -> Any: ...
  162. def wait(self, timeout: timedelta = _DEFAULT_NO_TIMEOUT) -> bool: ...
  163. def source_rank(self) -> int: ...
  164. def _source_rank(self) -> int: ...
  165. def result(self) -> List[Tensor]: ...
  166. def synchronize(self): ...
  167. ...
  168. class ProcessGroup:
  169. class Options: ...
  170. def __init__(self): ...
  171. def rank(self) -> int: ...
  172. def size(self) -> int: ...
  173. @overload
  174. def broadcast(
  175. self,
  176. tensors: List[Tensor],
  177. opts=BroadcastOptions(),
  178. ) -> Work: ...
  179. @overload
  180. def broadcast(
  181. self,
  182. tensor: Tensor,
  183. root: int,
  184. ) -> Work: ...
  185. @overload
  186. def allreduce(
  187. self,
  188. tensors: List[Tensor],
  189. opts: AllreduceOptions = AllreduceOptions(),
  190. ) -> Work: ...
  191. @overload
  192. def allreduce(
  193. self,
  194. tensors: List[Tensor],
  195. op=ReduceOp.SUM,
  196. ) -> Work: ...
  197. @overload
  198. def allreduce(
  199. self,
  200. tensor: Tensor,
  201. op=ReduceOp.SUM,
  202. ) -> Work: ...
  203. def allreduce_coalesced(
  204. self,
  205. tensors: List[Tensor],
  206. opts=AllreduceCoalescedOptions(),
  207. ) -> Work: ...
  208. @overload
  209. def reduce(
  210. self,
  211. tensors: List[Tensor],
  212. opts=ReduceOptions(),
  213. ) -> Work: ...
  214. @overload
  215. def reduce(
  216. self,
  217. tensor: Tensor,
  218. root: int,
  219. op=ReduceOp.SUM,
  220. ) -> Work: ...
  221. @overload
  222. def allgather(
  223. self,
  224. output_tensors: List[List[Tensor]],
  225. input_tensors: List[Tensor],
  226. opts=AllGatherOptions(),
  227. ) -> Work: ...
  228. @overload
  229. def allgather(
  230. self,
  231. output_tensors: List[Tensor],
  232. input_tensor: Tensor,
  233. ) -> Work: ...
  234. def _allgather_base(
  235. self,
  236. output: Tensor,
  237. input: Tensor,
  238. opts=AllGatherOptions(),
  239. ) -> Work: ...
  240. def allgather_coalesced(
  241. self,
  242. output_lists: List[List[Tensor]],
  243. input_list: List[Tensor],
  244. opts=AllGatherOptions(),
  245. ) -> Work: ...
  246. @overload
  247. def gather(
  248. self,
  249. output_tensors: List[List[Tensor]],
  250. input_tensors: List[Tensor],
  251. opts=GatherOptions(),
  252. ) -> Work: ...
  253. @overload
  254. def gather(
  255. self,
  256. output_tensors: List[Tensor],
  257. input_tensor: Tensor,
  258. root: int,
  259. ) -> Work: ...
  260. @overload
  261. def scatter(
  262. self,
  263. output_tensors: List[Tensor],
  264. input_tensors: List[List[Tensor]],
  265. opts=ScatterOptions(),
  266. ) -> Work: ...
  267. @overload
  268. def scatter(
  269. self,
  270. output_tensor: Tensor,
  271. input_tensors: List[Tensor],
  272. root: int,
  273. ) -> Work: ...
  274. @overload
  275. def reduce_scatter(
  276. self,
  277. output_tensors: List[Tensor],
  278. input_tensors: List[List[Tensor]],
  279. opts=ReduceScatterOptions(),
  280. ) -> Work: ...
  281. @overload
  282. def reduce_scatter(
  283. self,
  284. output_tensors: Tensor,
  285. input_tensor: List[Tensor],
  286. ) -> Work: ...
  287. def _reduce_scatter_base(
  288. self,
  289. outputTensor: Tensor,
  290. inputTensor: Tensor,
  291. ) -> Work: ...
  292. @overload
  293. def alltoall_base(
  294. self,
  295. output_tensor: Tensor,
  296. input_tensor: Tensor,
  297. output_split_sizes: List[int],
  298. input_split_sizes: List[int],
  299. opts=AllToAllOptions(),
  300. ) -> Work: ...
  301. @overload
  302. def alltoall_base(
  303. self,
  304. output: Tensor,
  305. input: Tensor,
  306. output_split_sizes: List[int],
  307. input_split_sizes: List[int],
  308. ) -> Work: ...
  309. @overload
  310. def alltoall(
  311. self,
  312. output_tensor: List[Tensor],
  313. input_tensor: List[Tensor],
  314. opts=AllToAllOptions(),
  315. ) -> Work: ...
  316. @overload
  317. def alltoall(
  318. self,
  319. output: List[Tensor],
  320. input: List[Tensor],
  321. ) -> Work: ...
  322. def send(
  323. self,
  324. tensors: List[Tensor],
  325. dstRank: int,
  326. tag: int,
  327. ) -> Work: ...
  328. def recv(
  329. self,
  330. tensors: List[Tensor],
  331. srcRank: int,
  332. tag: int,
  333. ) -> Work: ...
  334. def recv_anysource(self, tensors: List[Tensor], tag: int) -> Work: ...
  335. def barrier(self, opts=BarrierOptions()) -> Work: ...
  336. class ProcessGroupRoundRobin(ProcessGroup): ...
  337. def _round_robin_process_groups(
  338. process_groups: List[ProcessGroup],
  339. ) -> ProcessGroupRoundRobin: ...
  340. class ProcessGroupGloo(ProcessGroup):
  341. class Device: ...
  342. class Options: ...
  343. def __init__(
  344. self,
  345. store: Store,
  346. rank: int,
  347. size: int,
  348. timeout: timedelta,
  349. ): ...
  350. @staticmethod
  351. def create_device(hostname=str(), interface=str()) -> Device: ...
  352. ...
  353. @staticmethod
  354. def create_default_device() -> Device: ...
  355. ...
  356. class _ProcessGroupWrapper(ProcessGroup):
  357. def __init__(self, pg: ProcessGroup, gloo_pg: ProcessGroupGloo): ...
  358. wrapped_pg: ProcessGroup
  359. class ProcessGroupNCCL(ProcessGroup):
  360. class Options: ...
  361. def __init__(
  362. self,
  363. store: Store,
  364. rank: int,
  365. size: int,
  366. timeout: timedelta,
  367. ): ...
  368. @staticmethod
  369. def _group_start() -> None: ...
  370. @staticmethod
  371. def _group_end() -> None: ...
  372. ...
  373. class ProcessGroupUCC(ProcessGroup):
  374. def __init__(
  375. self,
  376. store: Store,
  377. rank: int,
  378. size: int,
  379. timeout: timedelta,
  380. ): ...
  381. class ProcessGroupMPI(ProcessGroup):
  382. def __init__(
  383. self,
  384. rank: int,
  385. size: int,
  386. pgComm: int,
  387. ): ...
  388. @staticmethod
  389. def create(ranks: List[int]) -> ProcessGroupMPI: ...
  390. def _compute_bucket_assignment_by_size(
  391. tensors: List[Tensor],
  392. bucket_size_limits: List[int],
  393. expect_sparse_gradient: List[bool] = ...,
  394. tensor_indices: List[int] = ...,
  395. ) -> Tuple[List[List[int]], List[int]]: ...
  396. def _broadcast_coalesced(
  397. process_group: ProcessGroup,
  398. tensors: List[Tensor],
  399. buffer_size: int,
  400. src: int,
  401. ): ...
  402. def _test_python_store(store: Store): ...
  403. def _verify_params_across_processes(
  404. process_group: ProcessGroup,
  405. params: List[Tensor],
  406. logger: Optional[Logger],
  407. ): ...
  408. def _make_nccl_premul_sum(factor: Union[float, List[Tensor]]) -> ReduceOp: ...
  409. class Backend:
  410. def __init__(
  411. self,
  412. rank: int,
  413. size: int,
  414. ): ...