_functorch.pyi 2.4 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970
  1. from torch import Tensor
  2. from enum import Enum
  3. from typing import Optional, Tuple
  4. # Defined in torch/csrc/functorch/init.cpp
  5. def _set_dynamic_layer_keys_included(included: bool) -> None: ...
  6. def get_unwrapped(tensor: Tensor) -> Tensor: ...
  7. def is_batchedtensor(tensor: Tensor) -> bool: ...
  8. def is_functionaltensor(tensor: Tensor) -> bool: ...
  9. def is_functorch_wrapped_tensor(tensor: Tensor) -> bool: ...
  10. def is_gradtrackingtensor(tensor: Tensor) -> bool: ...
  11. def maybe_get_bdim(tensor: Tensor) -> int: ...
  12. def maybe_get_level(tensor: Tensor) -> int: ...
  13. def unwrap_if_dead(tensor: Tensor) -> Tensor: ...
  14. def _unwrap_for_grad(tensor: Tensor, level: int) -> Tensor: ...
  15. def _wrap_for_grad(tensor: Tensor, level: int) -> Tensor: ...
  16. def _unwrap_batched(tensor: Tensor, level: int) -> Tuple[Tensor, Optional[int]]: ...
  17. def current_level() -> int: ...
  18. def _add_batch_dim(tensor: Tensor, bdim: int, level: int) -> Tensor: ...
  19. def set_single_level_autograd_function_allowed(allowed: bool) -> None: ...
  20. def get_single_level_autograd_function_allowed() -> bool: ...
  21. # Defined in aten/src/ATen/functorch/Interpreter.h
  22. class TransformType(Enum):
  23. Torch: TransformType = ...
  24. Vmap: TransformType = ...
  25. Grad: TransformType = ...
  26. Jvp: TransformType = ...
  27. Functionalize: TransformType = ...
  28. class RandomnessType(Enum):
  29. Error: TransformType = ...
  30. Same: TransformType = ...
  31. Different: TransformType = ...
  32. class CInterpreter:
  33. def key(self) -> TransformType: ...
  34. def level(self) -> int: ...
  35. class CGradInterpreterPtr:
  36. def __init__(self, interpreter: CInterpreter): ...
  37. def lift(self, Tensor) -> Tensor: ...
  38. def prevGradMode(self) -> bool: ...
  39. class CJvpInterpreterPtr:
  40. def __init__(self, interpreter: CInterpreter): ...
  41. def lift(self, Tensor) -> Tensor: ...
  42. def prevFwdGradMode(self) -> bool: ...
  43. class CFunctionalizeInterpreterPtr:
  44. def __init__(self, interpreter: CInterpreter): ...
  45. def key(self) -> TransformType: ...
  46. def level(self) -> int: ...
  47. def functionalizeAddBackViews(self) -> bool: ...
  48. class CVmapInterpreterPtr:
  49. def __init__(self, interpreter: CInterpreter): ...
  50. def key(self) -> TransformType: ...
  51. def level(self) -> int: ...
  52. def batchSize(self) -> int: ...
  53. def randomness(self) -> RandomnessType: ...
  54. class DynamicLayer:
  55. pass
  56. def peek_interpreter_stack() -> CInterpreter: ...
  57. def pop_dynamic_layer_stack() -> DynamicLayer: ...
  58. def push_dynamic_layer_stack(dl: DynamicLayer) -> int: ...