_distributed_rpc_testing.pyi 1.1 KB

1234567891011121314151617181920212223242526272829303132333435363738
  1. import torch
  2. from ._distributed_c10d import ProcessGroup, Store
  3. from ._distributed_rpc import (
  4. _TensorPipeRpcBackendOptionsBase,
  5. TensorPipeAgent,
  6. WorkerInfo,
  7. )
  8. from typing import List, Dict, overload
  9. from datetime import timedelta
  10. # This module is defined in torch/csrc/distributed/rpc/testing/init.cpp
  11. class FaultyTensorPipeRpcBackendOptions(_TensorPipeRpcBackendOptionsBase):
  12. def __init__(
  13. self,
  14. num_worker_threads: int,
  15. rpc_timeout: float,
  16. init_method: str,
  17. messages_to_fail: List[str],
  18. messages_to_delay: Dict[str, float],
  19. num_fail_sends: int,
  20. ): ...
  21. num_send_recv_threads: int
  22. messages_to_fail: List[str]
  23. messages_to_delay: Dict[str, float]
  24. num_fail_sends: int
  25. class FaultyTensorPipeAgent(TensorPipeAgent):
  26. def __init__(
  27. self,
  28. store: Store,
  29. name: str,
  30. rank: int,
  31. world_size: int,
  32. options: FaultyTensorPipeRpcBackendOptions,
  33. reverse_device_maps: Dict[str, Dict[torch.device, torch.device]],
  34. devices: List[torch.device],
  35. ): ...