context.py 3.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115
  1. import contextlib
  2. import functools
  3. from typing import Callable, Dict, Iterator, Optional, TypeVar, Union
  4. import torchgen.local as local
  5. from torchgen.model import (
  6. BackendIndex,
  7. DispatchKey,
  8. NativeFunction,
  9. NativeFunctionsGroup,
  10. NativeFunctionsViewGroup,
  11. )
  12. from torchgen.utils import context, S, T
# Helper functions for defining generators on things in the model

# F is the type of the primary argument accepted by the decorators in this
# module. It is a *constrained* TypeVar: a decorated function is generic over
# exactly one of these choices (a single NativeFunction, a group, a view
# group, or one of the listed unions).
F = TypeVar(
    "F",
    NativeFunction,
    NativeFunctionsGroup,
    NativeFunctionsViewGroup,
    Union[NativeFunction, NativeFunctionsGroup],
    Union[NativeFunction, NativeFunctionsViewGroup],
)

# F2 is the type of the extra second argument forwarded by
# with_native_function_and; it is constrained to the listed types.
F2 = TypeVar(
    "F2",
    NativeFunction,
    NativeFunctionsGroup,
    Optional[NativeFunction],
    bool,
    str,
)
  30. @contextlib.contextmanager
  31. def native_function_manager(
  32. g: Union[NativeFunctionsGroup, NativeFunctionsViewGroup, NativeFunction]
  33. ) -> Iterator[None]:
  34. if isinstance(g, NativeFunctionsGroup):
  35. # By default, we associate all errors with structured native functions
  36. # with the out variant. In some cases, it might be better to have
  37. # a more specific place to hang things; if so, use
  38. # native_function_manager again on the inside
  39. f = g.out
  40. elif isinstance(g, NativeFunctionsViewGroup):
  41. # We associate errors with the view operator
  42. f = g.view
  43. else:
  44. f = g
  45. with context(lambda: f"in native_functions.yaml line {f.loc}:\n {f.func}"):
  46. with local.parametrize(
  47. use_const_ref_for_mutable_tensors=f.use_const_ref_for_mutable_tensors,
  48. use_ilistref_for_tensor_lists=f.part_of_structured_group,
  49. ):
  50. yield
  51. # Given a function that operates on NativeFunction, wrap it into a new function
  52. # that sets some appropriate context managers for that native function.
  53. # YOU MUST WRAP FUNCTIONS IN THIS for calls to api modules to be sound
  54. # (you will get an error if we try to access the local variables without having
  55. # set them).
  56. def with_native_function(func: Callable[[F], T]) -> Callable[[F], T]:
  57. @functools.wraps(func)
  58. def wrapper(f: F) -> T:
  59. with native_function_manager(f):
  60. return func(f)
  61. return wrapper
  62. def with_native_function_and(func: Callable[[F, F2], T]) -> Callable[[F, F2], T]:
  63. @functools.wraps(func)
  64. def wrapper(f: F, f2: F2) -> T:
  65. # The first native_function is assumed to be the one with the appropriate context.
  66. with native_function_manager(f):
  67. return func(f, f2)
  68. return wrapper
  69. def method_with_native_function(func: Callable[[S, F], T]) -> Callable[[S, F], T]:
  70. @functools.wraps(func)
  71. def wrapper(slf: S, f: F) -> T:
  72. with native_function_manager(f):
  73. return func(slf, f)
  74. return wrapper
  75. # Convenience decorator for functions that explicitly take in a BackendIndex,
  76. # instead of indirectly taking one in as a closure
  77. def with_native_function_and_index(
  78. func: Callable[[F, BackendIndex], T]
  79. ) -> Callable[[F, BackendIndex], T]:
  80. @functools.wraps(func)
  81. def wrapper(f: F, backend_index: BackendIndex) -> T:
  82. with native_function_manager(f):
  83. return func(f, backend_index)
  84. return wrapper
  85. # Convenience decorator for functions that explicitly take in a Dict of BackendIndices
  86. def with_native_function_and_indices(
  87. func: Callable[[F, Dict[DispatchKey, BackendIndex]], T]
  88. ) -> Callable[[F, Dict[DispatchKey, BackendIndex]], T]:
  89. @functools.wraps(func)
  90. def wrapper(f: F, backend_indices: Dict[DispatchKey, BackendIndex]) -> T:
  91. with native_function_manager(f):
  92. return func(f, backend_indices)
  93. return wrapper