# _autograd.pyi (3.0 KB)
  1. from typing import List, Set, Callable, Any, Union, Optional
  2. from enum import Enum
  3. import torch
  4. from ._profiler import _ProfilerEvent, ActiveProfilerType, ProfilerActivity, ProfilerConfig
  5. # Defined in tools/autograd/init.cpp
  6. class DeviceType(Enum):
  7. CPU = ...
  8. CUDA = ...
  9. MKLDNN = ...
  10. OPENGL = ...
  11. OPENCL = ...
  12. IDEEP = ...
  13. HIP = ...
  14. FPGA = ...
  15. ORT = ...
  16. XLA = ...
  17. MPS = ...
  18. HPU = ...
  19. Meta = ...
  20. Vulkan = ...
  21. Metal = ...
  22. ...
  23. class ProfilerEvent:
  24. def cpu_elapsed_us(self, other: ProfilerEvent) -> float: ...
  25. def cpu_memory_usage(self) -> int: ...
  26. def cuda_elapsed_us(self, other: ProfilerEvent) -> float: ...
  27. def cuda_memory_usage(self) -> int: ...
  28. def device(self) -> int: ...
  29. def handle(self) -> int: ...
  30. def has_cuda(self) -> bool: ...
  31. def is_remote(self) -> bool: ...
  32. def kind(self) -> int: ...
  33. def name(self) -> str: ...
  34. def node_id(self) -> int: ...
  35. def sequence_nr(self) -> int: ...
  36. def shapes(self) -> List[List[int]]: ...
  37. def thread_id(self) -> int: ...
  38. def flops(self) -> float: ...
  39. def is_async(self) -> bool: ...
  40. ...
  41. class _KinetoEvent:
  42. def name(self) -> str: ...
  43. def device_index(self) -> int: ...
  44. def start_us(self) -> int: ...
  45. def duration_us(self) -> int: ...
  46. def is_async(self) -> bool: ...
  47. def linked_correlation_id(self) -> int: ...
  48. ...
  49. class _ProfilerResult:
  50. def events(self) -> List[_KinetoEvent]: ...
  51. def legacy_events(self) -> List[List[ProfilerEvent]]: ...
  52. def save(self, path: str) -> None: ...
  53. def experimental_event_tree(self) -> List[_ProfilerEvent]: ...
  54. class SavedTensor:
  55. ...
  56. def _enable_profiler(config: ProfilerConfig, activities: Set[ProfilerActivity]) -> None: ...
  57. def _prepare_profiler(config: ProfilerConfig, activities: Set[ProfilerActivity]) -> None: ...
  58. def _disable_profiler() -> _ProfilerResult: ...
  59. def _profiler_enabled() -> bool: ...
  60. def _add_metadata_json(key: str, value: str) -> None: ...
  61. def _kineto_step() -> None: ...
  62. def kineto_available() -> bool: ...
  63. def _record_function_with_args_enter(name: str, args: List[Any]) -> torch.Tensor: ...
  64. def _record_function_with_args_exit(handle: torch.Tensor) -> None: ...
  65. def _supported_activities() -> Set[ProfilerActivity]: ...
  66. def _enable_record_function(enable: bool) -> None: ...
  67. def _set_empty_test_observer(is_global: bool, sampling_prob: float) -> None: ...
  68. def _push_saved_tensors_default_hooks(pack_hook: Callable, unpack_hook: Callable) -> None: ...
  69. def _pop_saved_tensors_default_hooks() -> None: ...
  70. def _unsafe_set_version_counter(t: torch.Tensor, prev_version: int) -> None: ...
  71. def _enable_profiler_legacy(config: ProfilerConfig) -> None: ...
  72. def _disable_profiler_legacy() -> List[List[ProfilerEvent]]: ...
  73. def _profiler_type() -> ActiveProfilerType: ...
  74. def _saved_tensors_hooks_enable() -> None: ...
  75. def _saved_tensors_hooks_disable(message: str) -> None: ...
  76. def _saved_tensors_hooks_get_disabled_error_message() -> Optional[str]: ...