123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182 |
- from typing import Any, Dict, List, Optional, Tuple, Union, overload
- from datetime import timedelta
- import enum
- import torch
- from torch.types import Device
- from . import Future
- from ._autograd import ProfilerEvent
- from ._distributed_c10d import ProcessGroup, Store
- from ._profiler import ActiveProfilerType, ProfilerConfig, ProfilerState
# This module is defined in torch/csrc/distributed/rpc/init.cpp
# Module-level constants exported by the native module. Only their types are
# declared here; the actual values are supplied by the native side.
_DEFAULT_INIT_METHOD: str  # default `init_method` of RpcBackendOptions / _TensorPipeRpcBackendOptionsBase
_DEFAULT_NUM_WORKER_THREADS: int  # presumably the default worker-thread count for backend options — confirm
_UNSET_RPC_TIMEOUT: float  # default `timeout` of PyRRef.to_here/rpc_sync/rpc_async/remote; presumably a sentinel meaning "use the agent-wide timeout" — confirm
_DEFAULT_RPC_TIMEOUT_SEC: float  # default `rpc_timeout` of the backend options; the name suggests seconds
# Options common to RPC backends; _TensorPipeRpcBackendOptionsBase in this
# module subclasses it.
class RpcBackendOptions:
    rpc_timeout: float  # RPC timeout; unit presumably seconds (cf. _DEFAULT_RPC_TIMEOUT_SEC) — confirm
    init_method: str  # presumably a rendezvous/initialization URL string — confirm
    def __init__(
        self,
        rpc_timeout: float = _DEFAULT_RPC_TIMEOUT_SEC,
        init_method: str = _DEFAULT_INIT_METHOD,
    ): ...
# Describes an RPC worker: its name and integer id.
class WorkerInfo:
    def __init__(self, name: str, worker_id: int): ...
    # Read-only: no setter is declared for `name`.
    @property
    def name(self) -> str: ...
    # Read-only integer id; note the constructor argument is named `worker_id`.
    @property
    def id(self) -> int: ...
    def __eq__(self, other: object) -> bool: ...
    def __repr__(self) -> str: ...
# Base RPC agent interface; TensorPipeAgent in this module subclasses it.
class RpcAgent:
    # `timeout` defaults to 0; what 0 means (e.g. "no limit") is not visible
    # from this stub — confirm against the native binding.
    def join(self, shutdown: bool = False, timeout: float = 0): ...
    def sync(self): ...
    def shutdown(self): ...
    # No-argument form presumably returns this agent's own WorkerInfo — confirm.
    @overload
    def get_worker_info(self) -> WorkerInfo: ...
    # Look up a peer by worker name.
    @overload
    def get_worker_info(self, workerName: str) -> WorkerInfo: ...
    def get_worker_infos(self) -> List[WorkerInfo]: ...
    # Device mapping used when sending to `dst`; presumably local -> remote
    # device — confirm direction.
    def _get_device_map(self, dst: WorkerInfo) -> Dict[torch.device, torch.device]: ...
    def get_debug_info(self) -> Dict[str, str]: ...
    def get_metrics(self) -> Dict[str, str]: ...
# Python-facing remote reference; owned by the worker reported by `owner()`.
class PyRRef:
    def __init__(self, value: Any, type_hint: Any = None): ...
    def is_owner(self) -> bool: ...
    def confirmed_by_owner(self) -> bool: ...
    def owner(self) -> WorkerInfo: ...
    def owner_name(self) -> str: ...
    # The timeout-taking methods below all default to _UNSET_RPC_TIMEOUT.
    def to_here(self, timeout: float = _UNSET_RPC_TIMEOUT) -> Any: ...
    def local_value(self) -> Any: ...
    def rpc_sync(self, timeout: float = _UNSET_RPC_TIMEOUT) -> Any: ...
    def rpc_async(self, timeout: float = _UNSET_RPC_TIMEOUT) -> Any: ...
    def remote(self, timeout: float = _UNSET_RPC_TIMEOUT) -> Any: ...
    # Presumably a round-trip pair: _deserialize rebuilds a PyRRef from the
    # tuple produced by _serialize — confirm the tuple layout natively.
    def _serialize(self) -> Tuple: ...
    @staticmethod
    def _deserialize(tp: Tuple) -> 'PyRRef': ...
    def _get_type(self) -> Any: ...
    def _get_future(self) -> Future: ...
    def _get_profiling_future(self) -> Future: ...
    def _set_profiling_future(self, profilingFuture: Future): ...
    def __repr__(self) -> str: ...
    ...  # NOTE(review): bare no-op statement; presumably marks the stub as incomplete — confirm before removing.
# TensorPipe-specific backend options layered on RpcBackendOptions.
class _TensorPipeRpcBackendOptionsBase(RpcBackendOptions):
    num_worker_threads: int
    # worker name -> {local device: device}; direction presumably local -> remote — confirm
    device_maps: Dict[str, Dict[torch.device, torch.device]]
    devices: List[torch.device]
    # If this file is a .pyi stub (as the `...` bodies suggest), the `{}` /
    # `list()` defaults below are never evaluated; they only advertise that the
    # defaults are empty, so the usual shared-mutable-default pitfall does not apply.
    def __init__(
        self,
        num_worker_threads: int,
        _transports: Optional[List],  # element type not declared in this stub
        _channels: Optional[List],  # element type not declared in this stub
        rpc_timeout: float = _DEFAULT_RPC_TIMEOUT_SEC,
        init_method: str = _DEFAULT_INIT_METHOD,
        device_maps: Dict[str, Dict[torch.device, torch.device]] = {},
        devices: List[torch.device] = list()): ...
    # Set the device map for destination worker `to`.
    def _set_device_map(self, to: str, device_map: Dict[torch.device, torch.device]): ...
# RpcAgent implementation backed by TensorPipe, configured from a Store and
# _TensorPipeRpcBackendOptionsBase.
class TensorPipeAgent(RpcAgent):
    # `world_size` is Optional; presumably None for a dynamic (non-static)
    # group, cf. `is_static_group` — confirm.
    def __init__(
        self,
        store: Store,
        name: str,
        worker_id: int,
        world_size: Optional[int],
        opts: _TensorPipeRpcBackendOptionsBase,
        reverse_device_maps: Dict[str, Dict[torch.device, torch.device]],
        devices: List[torch.device],
    ): ...
    def join(self, shutdown: bool = False, timeout: float = 0): ...
    def shutdown(self): ...
    @overload
    def get_worker_info(self) -> WorkerInfo: ...
    @overload
    def get_worker_info(self, workerName: str) -> WorkerInfo: ...
    # Extra overload over RpcAgent: look up by integer id. The parameter name
    # `id` shadows the builtin but is the public keyword name; keep it.
    @overload
    def get_worker_info(self, id: int) -> WorkerInfo: ...
    def get_worker_infos(self) -> List[WorkerInfo]: ...
    def _get_device_map(self, dst: WorkerInfo) -> Dict[torch.device, torch.device]: ...
    # `is_join` selects between a worker joining and leaving the group
    # (inferred from the name — confirm).
    def _update_group_membership(
        self,
        worker_info: WorkerInfo,
        my_devices: List[torch.device],
        reverse_device_map: Dict[str, Dict[torch.device, torch.device]],
        is_join: bool): ...
    def _get_backend_options(self) -> _TensorPipeRpcBackendOptionsBase: ...
    @property
    def is_static_group(self) -> bool: ...
    @property
    def store(self) -> Store: ...
# Accessors for the module's "current" RpcAgent.
def _is_current_rpc_agent_set() -> bool: ...
def _get_current_rpc_agent() -> RpcAgent: ...
def _set_and_start_rpc_agent(agent: RpcAgent): ...
def _reset_current_rpc_agent(): ...
# RRef-context and Python-handler helpers (names suggest shutdown/cleanup use).
# The real default of `timeout` is not spelled out in this stub (`...`).
def _delete_all_user_and_unforked_owner_rrefs(timeout: timedelta = ...): ...
def _destroy_rref_context(ignoreRRefLeak: bool): ...
def _rref_context_get_debug_info() -> Dict[str, str]: ...
def _cleanup_python_rpc_handler(): ...
# Invoke builtin operator `opName` on worker `dst`; `*args` / `**kwargs` are
# forwarded to the operator. Annotated to return the torch._C.Future that the
# native binding produces for the reply, so callers' `.wait()` / `.then()` are
# type-checked instead of falling back to an implicit Any.
def _invoke_rpc_builtin(
    dst: WorkerInfo,
    opName: str,
    rpcTimeoutSeconds: float,
    *args: Any,
    **kwargs: Any
) -> 'Future': ...
# Run a pickled Python UDF on worker `dst` and return a Future for its result.
# The payload is typed str | bytes: the binding receives a C++ std::string,
# which pybind11 also fills from `bytes`, and pickled payloads are bytes.
# `tensors` travel alongside the pickled payload.
def _invoke_rpc_python_udf(
    dst: WorkerInfo,
    pickledPythonUDF: Union[str, bytes],
    tensors: List[torch.Tensor],
    rpcTimeoutSeconds: float,
    isAsyncExecution: bool
) -> 'Future': ...
# Call the TorchScript function `qualifiedNameStr` on the worker named
# `dstWorkerName`. Unlike the builtin/UDF variants, the destination is passed
# as a name string, not a WorkerInfo, and args/kwargs as a tuple/dict.
# Annotated to return the torch._C.Future produced by the native binding.
def _invoke_rpc_torchscript(
    dstWorkerName: str,
    qualifiedNameStr: str,
    argsTuple: Tuple,
    kwargsDict: Dict,
    rpcTimeoutSeconds: float,
    isAsyncExecution: bool,
) -> 'Future': ...
# Create a remote value by running builtin operator `opName` on worker `dst`;
# `*args` / `**kwargs` are forwarded to the operator. Annotated to return the
# PyRRef the native binding creates, instead of an implicit Any.
def _invoke_remote_builtin(
    dst: WorkerInfo,
    opName: str,
    rpcTimeoutSeconds: float,
    *args: Any,
    **kwargs: Any
) -> 'PyRRef': ...
# Create a remote value by running a pickled Python UDF on worker `dst`.
# The payload is typed str | bytes: the binding receives a C++ std::string,
# which pybind11 also fills from `bytes`, and pickled payloads are bytes.
# Annotated to return the PyRRef the native binding creates.
def _invoke_remote_python_udf(
    dst: WorkerInfo,
    pickledPythonUDF: Union[str, bytes],
    tensors: List[torch.Tensor],
    rpcTimeoutSeconds: float,
    isAsyncExecution: bool,
) -> 'PyRRef': ...
# Create a remote value by running TorchScript function `qualifiedNameStr` on
# the worker named `dstWorkerName`; `*args` / `**kwargs` are forwarded.
# The destination is a worker *name* (str), matching _invoke_rpc_torchscript.
# The previous `WorkerInfo` annotation contradicted the parameter name and the
# native binding, which takes a string. Returns the PyRRef for the result.
def _invoke_remote_torchscript(
    dstWorkerName: str,
    qualifiedNameStr: str,
    rpcTimeoutSeconds: float,
    isAsyncExecution: bool,
    *args: Any,
    **kwargs: Any
) -> 'PyRRef': ...
# Read the module-wide RPC timeout (a float; the setter's argument name says
# seconds).
def get_rpc_timeout() -> float: ...
# Toggle GIL profiling on or off.
def enable_gil_profiling(flag: bool): ...
# Private setter paired with get_rpc_timeout.
def _set_rpc_timeout(rpcTimeoutSeconds: float): ...
# Exposes only a static setter in this stub.
class RemoteProfilerManager:
    # Set the current profiling key (a string); no instance is needed.
    @staticmethod
    def set_current_profiling_key(key: str): ...
# Enable/disable the server-process-wide profiler. Disabling returns events
# nested three lists deep; what each level groups is not visible here —
# confirm against the native binding.
def _enable_server_process_global_profiler(new_config: ProfilerConfig): ...
def _disable_server_process_global_profiler() -> List[List[List[ProfilerEvent]]]: ...
def _set_profiler_node_id(default_node_id: int): ...
# Toggle JIT pickling of RRefs (enable/disable pair).
def _enable_jit_rref_pickle(): ...
def _disable_jit_rref_pickle(): ...
|