# _distributed_autograd.pyi
  1. import torch
  2. from typing import Dict, List, Set, Any
  3. # This module is defined in torch/csrc/distributed/autograd/init.cpp
  4. class DistAutogradContext:
  5. def _context_id(self) -> int: ...
  6. def _recv_functions(self) -> Dict[int, Any]: ...
  7. def _send_functions(self) -> Dict[int, Any]: ...
  8. def _known_worker_ids(self) -> Set[int]: ...
  9. def _new_context() -> DistAutogradContext: ...
  10. def _release_context(context_id: int) -> None: ...
  11. def _get_max_id() -> int: ...
  12. def _is_valid_context(worker_id: int) -> bool: ...
  13. def _retrieve_context(context_id: int) -> DistAutogradContext: ...
  14. def _current_context() -> DistAutogradContext: ...
  15. def _init(worker_id: int) -> None: ...
  16. def _get_debug_info() -> Dict[str, str]: ...
  17. def backward(
  18. context_id: int,
  19. roots: List[torch.Tensor],
  20. retain_graph = False
  21. ) -> None: ...
  22. def get_gradients(context_id: int) -> Dict[torch.Tensor, torch.Tensor]: ...