from typing import List, Union

from torchgen.api import cpp
from torchgen.api.types import (
    ArgName,
    ArrayRefCType,
    BaseCType,
    Binding,
    ConstRefCType,
    dimnameListT,
    intArrayRefT,
    iOptTensorListRefT,
    iTensorListRefT,
    NamedCType,
    OptionalCType,
    optionalIntArrayRefT,
    optionalScalarRefT,
    optionalTensorRefT,
    scalarT,
    tensorT,
)
from torchgen.model import (
    Argument,
    BaseTy,
    BaseType,
    ListType,
    NativeFunctionsGroup,
    OptionalType,
    SelfArgument,
    TensorOptionsArguments,
    Type,
)
from torchgen.utils import assert_never

# This file describes the translation of JIT schema to the structured functions API.
# It is similar to the native API, but a number of historical problems with the
# native API have been fixed.


# Translation of types occurring in JIT arguments to a C++ argument type.
# NB: For now, mutable doesn't do anything; but it could if we make
# some more nominal types.
def argumenttype_type(t: Type, *, mutable: bool, binds: ArgName) -> NamedCType:
    # If it's a value type, do the value type translation.
    # NB: structured kernels ALWAYS have symint off, since they involve actual
    # kernels that require real ints.  The one exception is the
    # CompositeExplicitAutograd kernel and the meta function (which could
    # hypothetically be SymInt), but for simplicity we plan for these to just
    # be handled in Python.
    r = cpp.valuetype_type(t, symint=False, binds=binds)
    if r is not None:
        return r

    if isinstance(t, BaseType):
        if t.name == BaseTy.Tensor:
            return NamedCType(binds, ConstRefCType(BaseCType(tensorT)))
        elif t.name == BaseTy.Scalar:
            return NamedCType(binds, ConstRefCType(BaseCType(scalarT)))
        else:
            raise AssertionError(f"base type should have been value type {t}")
    elif isinstance(t, OptionalType):
        if t.elem == BaseType(BaseTy.Tensor):
            return NamedCType(binds, BaseCType(optionalTensorRefT))
        elif t.elem == BaseType(BaseTy.Scalar):
            return NamedCType(binds, BaseCType(optionalScalarRefT))
        elif isinstance(t.elem, ListType) and str(t.elem.elem) == "int":
            return NamedCType(binds, BaseCType(optionalIntArrayRefT))
        elem = argumenttype_type(t.elem, mutable=mutable, binds=binds)
        return NamedCType(binds, OptionalCType(elem.type))
    elif isinstance(t, ListType):
        if t.elem == BaseType(BaseTy.Tensor):
            return NamedCType(binds, ConstRefCType(BaseCType(iTensorListRefT)))
        elif t.elem == OptionalType(BaseType(BaseTy.Tensor)):
            return NamedCType(binds, BaseCType(iOptTensorListRefT))
        # TODO: delete these special cases; see torchgen.api.cpp--these
        # must be changed in tandem, but there are problems; see
        # https://github.com/pytorch/pytorch/pull/51485
        elif str(t.elem) == "int":
            return NamedCType(binds, BaseCType(intArrayRefT))
        elif str(t.elem) == "Dimname":
            return NamedCType(binds, BaseCType(dimnameListT))
        elem = argumenttype_type(t.elem, mutable=mutable, binds=binds)
        return NamedCType(binds, ArrayRefCType(elem.type))
    else:
        raise AssertionError(f"unrecognized type {repr(t)}")


def argument_type(a: Argument, *, binds: ArgName) -> NamedCType:
    return argumenttype_type(a.type, mutable=a.is_write, binds=binds)


# returns_type intentionally omitted, because structured kernels never "return";
# instead, they always indirectly report their outputs (in the case of a meta
# function, by calling set_output; in the case of an impl function, by writing
# directly into the provided out argument).


# Structured kernels are never defaulted.
def argument(a: Union[Argument, SelfArgument, TensorOptionsArguments]) -> List[Binding]:
    if isinstance(a, Argument):
        return [
            Binding(
                nctype=argument_type(a, binds=a.name),
                name=a.name,
                default=None,
                argument=a,
            )
        ]
    elif isinstance(a, SelfArgument):
        return argument(a.argument)
    elif isinstance(a, TensorOptionsArguments):
        raise AssertionError("structured kernels don't support TensorOptions yet")
    else:
        assert_never(a)
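
# A minimal usage sketch (illustrative only; it assumes Argument.parse, which
# torchgen.model provides for parsing a single schema argument string):
#
#   >>> from torchgen.model import Argument
#   >>> a = Argument.parse("Tensor self")
#   >>> [b.decl() for b in argument(a)]
#   ['const at::Tensor & self']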


def impl_arguments(g: NativeFunctionsGroup) -> List[Binding]:
    args: List[Union[Argument, TensorOptionsArguments, SelfArgument]] = []

    if g.out.precomputed:
        # A list of parameters for the impl function, with certain parameters
        # replaced by their precomputed counterparts as specified in
        # native_functions.yaml.
        non_out_args_replaced: List[
            Union[Argument, TensorOptionsArguments, SelfArgument]
        ] = []
        for a in g.out.func.arguments.non_out:
            if isinstance(a, Argument) and a.name in g.out.precomputed.replace:
                # If a is in precomputed.replace, append the parameters
                # that should replace it onto non_out_args_replaced.
                for replacement in g.out.precomputed.replace[a.name]:
                    non_out_args_replaced.append(replacement)
            else:
                # If not, push a as it is.
                non_out_args_replaced.append(a)

        args.extend(non_out_args_replaced)
        # g.out.precomputed.add is the list of parameters that are added,
        # without replacement, after the non-out args and just before the out args.
        args.extend(g.out.precomputed.add)
    else:
        args.extend(g.out.func.arguments.non_out)

    args.extend(g.out.func.arguments.out)
    return [r for arg in args for r in argument(arg)]
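
# Illustratively, precomputed elements are declared on the out variant in
# native_functions.yaml with entries along these lines (the names here are
# assumptions for the sake of example; consult native_functions.yaml for the
# authoritative syntax):
#
#   precomputed:
#   - kernel_size -> int kH, int kW
#
# Given such a declaration, impl_arguments drops `kernel_size` from the impl
# signature and splices in `int64_t kH, int64_t kW` in its place, ahead of
# the out arguments.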


def meta_arguments(g: NativeFunctionsGroup) -> List[Binding]:
    args: List[Union[Argument, TensorOptionsArguments, SelfArgument]] = []
    args.extend(g.functional.func.arguments.non_out)
    return [r for arg in args for r in argument(arg)]
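
# E.g. (illustrative): for a structured group built around `add`, the meta
# function draws its bindings from the functional schema
# `add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor`,
# i.e. `self`, `other`, and `alpha`; the out argument never appears here.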


def out_arguments(g: NativeFunctionsGroup) -> List[Binding]:
    args: List[Union[Argument, TensorOptionsArguments, SelfArgument]] = []
    args.extend(g.out.func.arguments.out)
    return [r for arg in args for r in argument(arg)]
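
# Note that because argumenttype_type ignores mutability, out tensors bind as
# `const at::Tensor &` just like inputs; at::Tensor's const-ness is shallow,
# so a structured kernel can still write into the output's storage.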