123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115 |
- import contextlib
- import functools
- from typing import Callable, Dict, Iterator, Optional, TypeVar, Union
- import torchgen.local as local
- from torchgen.model import (
- BackendIndex,
- DispatchKey,
- NativeFunction,
- NativeFunctionsGroup,
- NativeFunctionsViewGroup,
- )
- from torchgen.utils import context, S, T
- # Helper functions for defining generators on things in the model
# Constrained type variable for the primary argument ``f`` taken by the
# decorators in this module (the value that is handed to
# native_function_manager): a single native function, one of the two group
# kinds, or a union of a native function with either group kind.
F = TypeVar(
    "F",
    NativeFunction,
    NativeFunctionsGroup,
    NativeFunctionsViewGroup,
    Union[NativeFunction, NativeFunctionsGroup],
    Union[NativeFunction, NativeFunctionsViewGroup],
)
# Constrained type variable for the secondary argument ``f2`` forwarded by
# with_native_function_and.  It does not take part in establishing the
# context; it is simply passed through to the wrapped callable.
F2 = TypeVar(
    "F2",
    NativeFunction,
    NativeFunctionsGroup,
    Optional[NativeFunction],
    bool,
    str,
)
@contextlib.contextmanager
def native_function_manager(
    g: Union[NativeFunctionsGroup, NativeFunctionsViewGroup, NativeFunction]
) -> Iterator[None]:
    """Run the managed body with context derived from a native function.

    A representative NativeFunction is selected from ``g``; the body then
    executes inside ``context(...)`` (given a callable producing a message
    that names the function's ``loc`` and ``func``) and inside
    ``local.parametrize(...)`` configured from that function's attributes.

    Args:
        g: A native function, or a group from which one is selected.

    Yields:
        None, exactly once, while both context managers are active.
    """
    if isinstance(g, NativeFunctionsGroup):
        # Errors for structured native functions are, by default, hung on the
        # out variant.  When a more specific location is preferable, nest
        # another native_function_manager inside this one.
        f = g.out
    elif isinstance(g, NativeFunctionsViewGroup):
        # For view groups, errors are hung on the view operator.
        f = g.view
    else:
        f = g

    def describe_location() -> str:
        # Invoked by `context` whenever it needs the message; building the
        # text is deferred until then.
        return f"in native_functions.yaml line {f.loc}:\n {f.func}"

    # A single multi-item `with` behaves like the nested form: the
    # parametrize arguments are only evaluated once `context` is entered.
    with context(describe_location), local.parametrize(
        use_const_ref_for_mutable_tensors=f.use_const_ref_for_mutable_tensors,
        use_ilistref_for_tensor_lists=f.part_of_structured_group,
    ):
        yield
- # Given a function that operates on NativeFunction, wrap it into a new function
- # that sets some appropriate context managers for that native function.
- # YOU MUST WRAP FUNCTIONS IN THIS for calls to api modules to be sound
- # (you will get an error if we try to access the local variables without having
- # set them).
def with_native_function(func: Callable[[F], T]) -> Callable[[F], T]:
    """Wrap ``func`` so every call runs inside ``native_function_manager(f)``.

    Args:
        func: Callable taking one native function (or group) argument.

    Returns:
        A callable with the same signature as ``func``; its metadata is
        copied from ``func`` by ``functools.wraps``.
    """

    @functools.wraps(func)
    def managed_call(f: F) -> T:
        # Establish the per-function context before delegating.
        with native_function_manager(f):
            outcome = func(f)
        return outcome

    return managed_call
def with_native_function_and(func: Callable[[F, F2], T]) -> Callable[[F, F2], T]:
    """Wrap a two-argument ``func`` so calls run inside ``native_function_manager(f)``.

    Args:
        func: Callable taking a native function (or group) ``f`` followed by
            a second argument ``f2``.

    Returns:
        A callable with the same signature as ``func``; its metadata is
        copied from ``func`` by ``functools.wraps``.
    """

    @functools.wraps(func)
    def managed_call(f: F, f2: F2) -> T:
        # Only the first argument is assumed to carry the appropriate
        # context; the second one is forwarded untouched.
        with native_function_manager(f):
            outcome = func(f, f2)
        return outcome

    return managed_call
def method_with_native_function(func: Callable[[S, F], T]) -> Callable[[S, F], T]:
    """Method flavour of the native-function wrapper.

    Args:
        func: Callable taking a receiver ``slf`` followed by a native
            function (or group) ``f``.

    Returns:
        A callable with the same signature as ``func`` that runs it inside
        ``native_function_manager(f)``; its metadata is copied from ``func``
        by ``functools.wraps``.
    """

    @functools.wraps(func)
    def managed_method(slf: S, f: F) -> T:
        # The context is derived from ``f``; the receiver is passed through.
        with native_function_manager(f):
            outcome = func(slf, f)
        return outcome

    return managed_method
- # Convenience decorator for functions that explicitly take in a BackendIndex,
- # instead of indirectly taking one in as a closure
def with_native_function_and_index(
    func: Callable[[F, BackendIndex], T]
) -> Callable[[F, BackendIndex], T]:
    """Wrap ``func(f, backend_index)`` so calls run inside ``native_function_manager(f)``.

    Args:
        func: Callable taking a native function (or group) ``f`` and an
            explicit ``BackendIndex``.

    Returns:
        A callable with the same signature as ``func``; its metadata is
        copied from ``func`` by ``functools.wraps``.
    """

    @functools.wraps(func)
    def managed_call(f: F, backend_index: BackendIndex) -> T:
        # The index is forwarded as-is; the context comes from ``f`` alone.
        with native_function_manager(f):
            outcome = func(f, backend_index)
        return outcome

    return managed_call
- # Convenience decorator for functions that explicitly take in a Dict of BackendIndices
def with_native_function_and_indices(
    func: Callable[[F, Dict[DispatchKey, BackendIndex]], T]
) -> Callable[[F, Dict[DispatchKey, BackendIndex]], T]:
    """Wrap ``func(f, backend_indices)`` so calls run inside ``native_function_manager(f)``.

    Args:
        func: Callable taking a native function (or group) ``f`` and a dict
            mapping ``DispatchKey`` to ``BackendIndex``.

    Returns:
        A callable with the same signature as ``func``; its metadata is
        copied from ``func`` by ``functools.wraps``.
    """

    @functools.wraps(func)
    def managed_call(f: F, backend_indices: Dict[DispatchKey, BackendIndex]) -> T:
        # The mapping is forwarded as-is; the context comes from ``f`` alone.
        with native_function_manager(f):
            outcome = func(f, backend_indices)
        return outcome

    return managed_call
|