123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153 |
- from typing import List, Optional, Sequence, Union
- from torchgen import local
- from torchgen.api import cpp
- from torchgen.api.types import (
- ArgName,
- BaseCType,
- Binding,
- boolT,
- ConstRefCType,
- CType,
- deviceT,
- layoutT,
- ListCType,
- MutRefCType,
- NamedCType,
- OptionalCType,
- scalarT,
- scalarTypeT,
- tensorT,
- )
- from torchgen.model import (
- Argument,
- FunctionSchema,
- Return,
- SelfArgument,
- TensorOptionsArguments,
- Type,
- )
- from torchgen.utils import assert_never
# This file describes the translation of JIT schema to the native functions API.
# This looks a lot like the C++ API (which makes historical sense, because the
# idea was you wrote native functions to implement functions in the C++ API),
# but over time we have evolved the C++ API without actually changing our
# native:: kernels. The intention is to make native API and dispatcher API
# line up as closely as possible, since this results in the least overhead
# (no translation is needed from dispatcher API to native API).
#
# NB: this is SymInt-aware: depending on the dispatch entry, you will get
# either the non-SymInt variant or the SymInt variant.
def name(func: FunctionSchema) -> str:
    """Return the identifier of the native:: kernel for ``func``.

    The identifier is built from underscore-separated pieces:
    the base operator name, then ``out`` if ``func.is_out_fn()`` is true,
    then the overload name if it is non-empty.
    """
    pieces = [str(func.name.name)]
    # TODO: delete this! Out variants get an extra "_out" suffix here.
    if func.is_out_fn():
        pieces.append("out")
    overload = func.name.overload_name
    if overload:
        pieces.append(f"{overload}")
    return "_".join(pieces)
def argumenttype_type(
    t: Type, *, mutable: bool, binds: ArgName, symint: bool
) -> NamedCType:
    """Translate a JIT schema type into its native-API C++ type bound to ``binds``.

    Only ``Tensor?``, ``Tensor?[]``, ``Scalar`` and ``Scalar?`` get a
    native-specific spelling here; every other type is delegated to
    ``cpp.argumenttype_type`` with the same arguments.
    """
    spelled = str(t)

    if spelled == "Tensor?":
        optional_tensor: OptionalCType = OptionalCType(BaseCType(tensorT))
        # A mutable optional tensor is taken by mutable reference unless
        # local.use_const_ref_for_mutable_tensors() says to use const refs.
        wants_mut_ref = mutable and not local.use_const_ref_for_mutable_tensors()
        ref_type = (
            MutRefCType(optional_tensor)
            if wants_mut_ref
            else ConstRefCType(optional_tensor)
        )
        return NamedCType(binds, ref_type)

    if spelled == "Tensor?[]":
        # Always a const reference to a list of optional tensors.
        return NamedCType(
            binds, ConstRefCType(ListCType(OptionalCType(BaseCType(tensorT))))
        )

    if spelled == "Scalar":
        return NamedCType(binds, ConstRefCType(BaseCType(scalarT)))

    if spelled == "Scalar?":
        return NamedCType(binds, ConstRefCType(OptionalCType(BaseCType(scalarT))))

    # Everything else is spelled the same as in the C++ API.
    return cpp.argumenttype_type(t, mutable=mutable, binds=binds, symint=symint)
def returns_type(rs: Sequence[Return], *, symint: bool) -> CType:
    """Return the C++ return type for the schema returns ``rs``.

    The native API has no special spelling for returns, so this forwards
    directly to ``cpp.returns_type``.
    """
    ctype = cpp.returns_type(rs, symint=symint)
    return ctype
def argument_type(a: Argument, *, binds: ArgName, symint: bool) -> NamedCType:
    """Return the native-API C++ type of argument ``a``, bound to ``binds``.

    The argument is treated as mutable exactly when ``a.is_write`` is true.
    """
    is_mutable = a.is_write
    return argumenttype_type(a.type, mutable=is_mutable, binds=binds, symint=symint)
def argument(
    a: Union[Argument, SelfArgument, TensorOptionsArguments],
    *,
    is_out: bool,
    symint: bool,
) -> List[Binding]:
    """Expand one schema argument into the native-API bindings it produces.

    * ``Argument`` -> a single binding named after the argument.
    * ``SelfArgument`` -> whatever its wrapped ``a.argument`` produces.
    * ``TensorOptionsArguments`` -> four bindings, in this order: ``dtype``,
      ``layout``, ``device``, ``pin_memory``, each of an optional C++ type.

    Args:
        a: the schema argument to translate.
        is_out: whether the owning function is an out= variant; when true,
            no default values are attached to the bindings.
        symint: forwarded to the C++ API helpers.

    Returns:
        The list of bindings for ``a``.
    """
    # Ideally, we NEVER default native functions. However, there are a number
    # of functions that call native:: directly and rely on the defaulting
    # existing. So for BC, we generate defaults for non-out variants (but not
    # for out variants, where it is impossible to generate an appropriate
    # default)
    should_default = not is_out

    if isinstance(a, Argument):
        default: Optional[str] = (
            cpp.default_expr(a.default, a.type, symint=symint)
            if should_default and a.default is not None
            else None
        )
        return [
            Binding(
                nctype=argument_type(a, binds=a.name, symint=symint),
                name=a.name,
                default=default,
                argument=a,
            )
        ]

    if isinstance(a, SelfArgument):
        # Erase SelfArgument from the distinction
        return argument(a.argument, is_out=is_out, symint=symint)

    if isinstance(a, TensorOptionsArguments):
        options_default = "{}" if should_default else None
        # (binding name, element C type) in the order the bindings are emitted.
        option_fields = (
            ("dtype", scalarTypeT),
            ("layout", layoutT),
            ("device", deviceT),
            ("pin_memory", boolT),
        )
        # TODO: Not sure why the arguments assigned here are for
        # TensorOptionsArguments and not the constituent pieces. It seems
        # to matter
        return [
            Binding(
                nctype=NamedCType(opt_name, OptionalCType(BaseCType(opt_elem))),
                name=opt_name,
                default=options_default,
                argument=a,
            )
            for opt_name, opt_elem in option_fields
        ]

    assert_never(a)
def arguments(func: FunctionSchema, *, symint: bool) -> List[Binding]:
    """Return the native-API bindings for every argument of ``func``.

    Non-out arguments come first, followed by out arguments. Each one is
    expanded with ``argument``, passing whether ``func`` is an out= variant.
    """
    ordered: List[Union[Argument, TensorOptionsArguments, SelfArgument]] = [
        *func.arguments.non_out,
        *func.arguments.out,
    ]
    is_out = func.is_out_fn()
    bindings: List[Binding] = []
    for schema_arg in ordered:
        bindings.extend(argument(schema_arg, symint=symint, is_out=is_out))
    return bindings
|