123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213 |
- from dataclasses import dataclass
- from typing import Callable, List, Sequence, Tuple
- from torchgen.api.types import Binding, CType, NamedCType
- from torchgen.model import (
- Argument,
- BaseTy,
- BaseType,
- ListType,
- NativeFunction,
- OptionalType,
- Type,
- )
# Separator used when splicing a list of generated C++ statements into a
# multi-line template: each statement starts on a new, tab-indented line.
connector: str = "\n" "\t"
def name(f: NativeFunction) -> str:
    """Return the name of the generated unboxing wrapper for ``f``.

    This is the schema's unambiguous operator name; for example the wrapper
    generated for ``aten::mul.out`` is called ``mul_out``.
    """
    operator_name = f.func.name
    return operator_name.unambiguous_name()
@dataclass(frozen=True)
class Unboxing:
    """
    Generate C++ code that unboxes ``EValue`` stack entries into the typed C++
    variables described by a sequence of Bindings.

    Sample generated code:

        // aten::mul.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
        void mul_out(EValue** stack) {
            EValue& self = *stack[0];
            EValue& other = *stack[1];
            EValue& out = *stack[2];
            const torch::executor::Tensor & self_base = self.to<torch::executor::Tensor>();
            const torch::executor::Tensor & other_base = other.to<torch::executor::Tensor>();
            torch::executor::Tensor & out_base = out.to<torch::executor::Tensor>();
            EXECUTORCH_SCOPE_PROF("native_call_mul.out");
            torch::executor::mul_outf(self_base, other_base, out_base);
        }
    """

    # Callable that converts a JIT argument type into its C++ type. It is
    # invoked as ``argument_type_gen(t, mutable=..., binds=...)`` and must return
    # a NamedCType, e.g. torchgen.api.cpp.argumenttype_type.
    argument_type_gen: Callable[
        ...,
        NamedCType,
    ]

    def convert_arguments(
        self, args: Sequence[Binding]
    ) -> Tuple[List[Binding], List[str]]:
        """Generate unboxing code for all arguments of a NativeFunction.

        The generated code first binds ``*stack[i]`` to an ``EValue&`` named
        after the i-th binding. For each argument, it then emits that
        argument's declarations followed by its conversion statements.

        :param args: bindings for the function's arguments, in stack order
        :return: ``(bindings, code)``. ``bindings`` are the input bindings
            renamed to the unboxed C++ variables. ``code`` is the list of
            generated C++ lines.
        :raises Exception: if a binding does not wrap an ``Argument``
        """
        code_list: List[str] = [
            f"EValue& {arg.name} = *stack[{i}];" for i, arg in enumerate(args)
        ]
        binding_list: List[Binding] = []
        for arg in args:
            # Only plain schema Arguments can be unboxed from the stack.
            if not isinstance(arg.argument, Argument):
                raise Exception(
                    f"Unexpected argument type, expecting `Argument` but got {arg}"
                )
            argument: Argument = arg.argument
            unboxed_name, _, code, decl = self.argumenttype_evalue_convert(
                argument.type, argument.name, mutable=argument.is_write
            )
            # Declarations (e.g. backing vectors for list arguments) must come
            # before the code that fills and references them.
            code_list.extend(decl)
            code_list.extend(code)
            binding_list.append(arg.with_name(unboxed_name))
        return binding_list, code_list

    def argumenttype_evalue_convert(
        self, t: Type, arg_name: str, *, mutable: bool = False
    ) -> Tuple[str, CType, List[str], List[str]]:
        """Generate C++ code that unboxes the EValue ``arg_name`` as type ``t``.

        :param t: the schema `Type` of the argument
        :param arg_name: name of the C++ ``EValue`` variable holding the argument
        :param mutable: whether the argument is written to (``convert_arguments``
            passes ``Argument.is_write``)
        :return: a tuple ``(out_name, ctype, code, decl)`` where
            - ``out_name`` is the name of the unboxed C++ variable,
            - ``ctype`` is its CType as produced by ``argument_type_gen``,
            - ``code`` is the list of C++ lines performing the conversion, and
            - ``decl`` is a list of extra declarations that must be emitted
              before ``code``.
        :raises Exception: if ``t`` is not a BaseType, OptionalType or ListType
        """
        ctype = self.argument_type_gen(t, mutable=mutable, binds=arg_name).type
        if isinstance(t, BaseType):
            out_name = f"{arg_name}_base"
            code, decl = self._gen_code_base_type(
                arg_name=arg_name, out_name=out_name, ctype=ctype
            )
        elif isinstance(t, OptionalType):
            out_name = f"{arg_name}_opt_out"
            code, decl = self._gen_code_optional_type(
                arg_name=arg_name, out_name=out_name, t=t, ctype=ctype
            )
        elif isinstance(t, ListType):
            out_name = f"{arg_name}_list_out"
            code, decl = self._gen_code_list_type(
                arg_name=arg_name, out_name=out_name, t=t, ctype=ctype
            )
        else:
            raise Exception(f"Cannot handle type {t}. arg_name: {arg_name}")
        return out_name, ctype, code, decl

    def _gen_code_base_type(
        self, arg_name: str, out_name: str, ctype: CType
    ) -> Tuple[List[str], List[str]]:
        """Unbox a non-container value with ``EValue::to<T>()``.

        :return: ``(code, decl)``; no extra declarations are needed.
        """
        return [
            f"{ctype.cpp_type()} {out_name} = {arg_name}.to<{ctype.cpp_type(strip_ref=True)}>();"
        ], []

    def _gen_code_optional_type(
        self, arg_name: str, out_name: str, t: OptionalType, ctype: CType
    ) -> Tuple[List[str], List[str]]:
        """Unbox an optional value with ``EValue::toOptional<T>()``.

        :return: ``(code, decl)``. ``decl`` is forwarded from the element-type
            conversion.
        """
        in_name = f"{arg_name}_opt_in"
        # The element type is converted only to learn its C++ type (and any
        # declarations it requires); its own conversion code is not emitted.
        _, elem_ctype, _, decl = self.argumenttype_evalue_convert(t.elem, in_name)
        return (
            [
                f"{ctype.cpp_type(strip_ref=True)} {out_name} = "
                f"{arg_name}.toOptional<{elem_ctype.cpp_type(strip_ref=True)}>();"
            ],
            decl,
        )

    def _gen_code_list_type(
        self, arg_name: str, out_name: str, t: ListType, ctype: CType
    ) -> Tuple[List[str], List[str]]:
        """Unbox a list value.

        Lists of Tensor, int/SymInt, float and bool use the matching EValue
        list accessor. ``Tensor?[]`` has dedicated ATen and portable code.
        Any other element type falls back to an element-by-element copy into
        a ``std::vector``.

        :return: ``(code, decl)``.
        """
        in_name = f"{arg_name}_list_in"
        elem_name = f"{arg_name}_elem"
        # Always convert the element type first: the fallback path uses its
        # result, and every path forwards its declarations.
        res_name, res_ctype, res_code, decl = self.argumenttype_evalue_convert(
            t.elem, elem_name
        )
        list_ctype = ctype.cpp_type(strip_ref=True)
        elem = t.elem

        # Element types that EValue hands back directly as a list. Fixed-size
        # lists (e.g. bool[4]) take the same path, since the size is not inspected.
        direct_accessors = {
            BaseTy.Tensor: "toTensorList",
            BaseTy.int: "toIntList",
            BaseTy.SymInt: "toIntList",
            BaseTy.float: "toDoubleList",
            BaseTy.bool: "toBoolList",
        }
        if isinstance(elem, BaseType) and elem.name in direct_accessors:
            accessor = direct_accessors[elem.name]
            return [f"{list_ctype} {out_name} = {arg_name}.{accessor}();"], decl

        # Like PyTorch codegen, ATen kernels need a c10::List for optional
        # elements, e.g. Tensor?[] -> c10::List<c10::optional<at::Tensor>>,
        # so the ArrayRef is copied element by element. The portable build
        # uses the ArrayRef directly.
        if (
            isinstance(elem, OptionalType)
            and isinstance(elem.elem, BaseType)
            and elem.elem.name == BaseTy.Tensor
        ):
            code = [
                "#ifdef USE_ATEN_LIB",
                f"at::ArrayRef<c10::optional<at::Tensor>> {in_name} = {arg_name}.toListOptionalTensor();",
                f"c10::List<c10::optional<at::Tensor>> {out_name};",
                f"for (auto {elem_name}: {in_name}) {{",
                f"    {out_name}.push_back({elem_name});",
                "}",
                "#else",
                f"torch::executor::ArrayRef<torch::executor::optional<torch::executor::Tensor>> {out_name} = {arg_name}.toListOptionalTensor();",
                "#endif",
            ]
            return code, decl

        # Fallback: convert each element into a std::vector and build the list
        # type (an ArrayRef by default) from it. The vector goes into `decl`,
        # which callers emit ahead of `code`. This keeps the vector declared
        # outside the loop, so the ArrayRef built from it points at valid data.
        # NOTE(review): nothing on this path declares `{arg_name}_list_in`
        # (in_name) in the generated C++, so the emitted loop refers to an
        # undefined name if this branch is reached. TODO: confirm how the EValue
        # list should be obtained here.
        vec_name = f"{arg_name}_vec"
        decl.append(f"std::vector<{res_ctype.cpp_type(strip_ref=True)}> {vec_name};")
        code = [
            f"for (EValue {elem_name}: {in_name}) {{",
            f"    {connector.join(res_code)}",
            f"    {vec_name}.push_back({res_name});",
            "}",
            f"{list_ctype} {out_name}({vec_name});",
        ]
        return code, decl
|