unboxing.py 7.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213
  1. from dataclasses import dataclass
  2. from typing import Callable, List, Sequence, Tuple
  3. from torchgen.api.types import Binding, CType, NamedCType
  4. from torchgen.model import (
  5. Argument,
  6. BaseTy,
  7. BaseType,
  8. ListType,
  9. NativeFunction,
  10. OptionalType,
  11. Type,
  12. )
  13. connector = "\n\t"
  14. # Return unboxing function name for a NativeFunction
  15. def name(f: NativeFunction) -> str:
  16. return f.func.name.unambiguous_name()
  17. @dataclass(frozen=True)
  18. class Unboxing:
  19. """
  20. Takes a sequence of Bindings and unbox EValues to these Bindings. Return generated code that performs correct unboxing.
  21. A sample generated code:
  22. // aten::mul.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
  23. void mul_out(EValue** stack) {
  24. EValue& self = *stack[0];
  25. EValue& other = *stack[1];
  26. EValue& out = *stack[2];
  27. const torch::executor::Tensor & self_base = self.to<torch::executor::Tensor>();
  28. const torch::executor::Tensor & other_base = other.to<torch::executor::Tensor>();
  29. torch::executor::Tensor & out_base = out.to<torch::executor::Tensor>();
  30. EXECUTORCH_SCOPE_PROF("native_call_mul.out");
  31. torch::executor::mul_outf(self_base, other_base, out_base);
  32. }
  33. """
  34. # this is a callable that converts a JIT argument, into its C++ type.
  35. # Translates (type, mutability, binds) to NamedCType. E.g., torchgen.api.cpp.argumenttype_type.
  36. argument_type_gen: Callable[
  37. ...,
  38. NamedCType,
  39. ]
  40. # Convert all the arguments in a NativeFunction to C++ code
  41. def convert_arguments(
  42. self, args: Sequence[Binding]
  43. ) -> Tuple[List[Binding], List[str]]:
  44. code_list = [f"EValue& {args[i].name} = *stack[{i}];" for i in range(len(args))]
  45. binding_list = []
  46. for arg in args:
  47. # expecting only Argument
  48. if not isinstance(arg.argument, Argument):
  49. raise Exception(
  50. f"Unexpected argument type, expecting `Argument` but got {arg}"
  51. )
  52. argument: Argument = arg.argument
  53. unboxed_name, _, code, decl = self.argumenttype_evalue_convert(
  54. argument.type, argument.name, mutable=argument.is_write
  55. )
  56. code_list.extend(decl)
  57. code_list.extend(code)
  58. binding_list.append(arg.with_name(unboxed_name))
  59. return binding_list, code_list
  60. def argumenttype_evalue_convert(
  61. self, t: Type, arg_name: str, *, mutable: bool = False
  62. ) -> Tuple[str, CType, List[str], List[str]]:
  63. """
  64. Takes in the type, name and mutability corresponding to an argument, and generates a tuple of:
  65. (1) the C++ code necessary to unbox the argument
  66. (2) A Binding corresponding to the newly created unboxed variable, including variable name and its CType
  67. :param t: a `Type` of an argument
  68. :param arg_name: argument name
  69. :param mutable: boolean for whether this argument type is mutable
  70. :return: unboxed result
  71. """
  72. ctype = self.argument_type_gen(t, mutable=mutable, binds=arg_name).type
  73. if isinstance(t, BaseType):
  74. out_name = f"{arg_name}_base"
  75. code, decl = self._gen_code_base_type(
  76. arg_name=arg_name, out_name=out_name, ctype=ctype
  77. )
  78. elif isinstance(t, OptionalType):
  79. out_name = f"{arg_name}_opt_out"
  80. code, decl = self._gen_code_optional_type(
  81. arg_name=arg_name, out_name=out_name, t=t, ctype=ctype
  82. )
  83. elif isinstance(t, ListType):
  84. out_name = f"{arg_name}_list_out"
  85. code, decl = self._gen_code_list_type(
  86. arg_name=arg_name, out_name=out_name, t=t, ctype=ctype
  87. )
  88. else:
  89. raise Exception(f"Cannot handle type {t}. arg_name: {arg_name}")
  90. return out_name, ctype, code, decl
  91. def _gen_code_base_type(
  92. self, arg_name: str, out_name: str, ctype: CType
  93. ) -> Tuple[List[str], List[str]]:
  94. return [
  95. f"{ctype.cpp_type()} {out_name} = {arg_name}.to<{ctype.cpp_type(strip_ref=True)}>();"
  96. ], []
  97. def _gen_code_optional_type(
  98. self, arg_name: str, out_name: str, t: OptionalType, ctype: CType
  99. ) -> Tuple[List[str], List[str]]:
  100. in_name = f"{arg_name}_opt_in"
  101. res_name, base_type, res_code, decl = self.argumenttype_evalue_convert(
  102. t.elem, in_name
  103. )
  104. return (
  105. f"""
  106. {ctype.cpp_type(strip_ref=True)} {out_name} = {arg_name}.toOptional<{base_type.cpp_type(strip_ref=True)}>();
  107. """.split(
  108. "\n"
  109. ),
  110. decl,
  111. )
  112. def _gen_code_list_type(
  113. self, arg_name: str, out_name: str, t: ListType, ctype: CType
  114. ) -> Tuple[List[str], List[str]]:
  115. in_name = f"{arg_name}_list_in"
  116. elem_name = f"{arg_name}_elem"
  117. code = []
  118. res_name, res_ctype, res_code, decl = self.argumenttype_evalue_convert(
  119. t.elem, elem_name
  120. )
  121. if isinstance(t.elem, BaseType) and t.elem.name == BaseTy.Tensor:
  122. code.extend(
  123. f"""
  124. {ctype.cpp_type(strip_ref=True)} {out_name} = {arg_name}.toTensorList();
  125. """.split(
  126. "\n"
  127. )
  128. )
  129. elif isinstance(t.elem, BaseType) and (
  130. t.elem.name == BaseTy.int or t.elem.name == BaseTy.SymInt
  131. ):
  132. code.extend(
  133. f"""
  134. {ctype.cpp_type(strip_ref=True)} {out_name} = {arg_name}.toIntList();
  135. """.split(
  136. "\n"
  137. )
  138. )
  139. elif isinstance(t.elem, BaseType) and t.elem.name == BaseTy.float:
  140. code.extend(
  141. f"""
  142. {ctype.cpp_type(strip_ref=True)} {out_name} = {arg_name}.toDoubleList();
  143. """.split(
  144. "\n"
  145. )
  146. )
  147. elif isinstance(t.elem, BaseType) and t.elem.name == BaseTy.bool:
  148. # handle list type with size, e.g., bool[4]
  149. code.extend(
  150. f"""
  151. {ctype.cpp_type(strip_ref=True)} {out_name} = {arg_name}.toBoolList();
  152. """.split(
  153. "\n"
  154. )
  155. )
  156. # pytorch codegen:
  157. # we have to use c10::List for optional element. e.g., Tensor?[] -> c10::List<c10::optional<at::Tensor>>
  158. elif (
  159. isinstance(t.elem, OptionalType)
  160. and isinstance(t.elem.elem, BaseType)
  161. and t.elem.elem.name == BaseTy.Tensor
  162. ):
  163. code.extend(
  164. f"""
  165. #ifdef USE_ATEN_LIB
  166. at::ArrayRef<c10::optional<at::Tensor>> {in_name} = {arg_name}.toListOptionalTensor();
  167. c10::List<c10::optional<at::Tensor>> {out_name};
  168. for (auto {elem_name}: {in_name}) {{
  169. {out_name}.push_back({elem_name});
  170. }}
  171. #else
  172. torch::executor::ArrayRef<torch::executor::optional<torch::executor::Tensor>> {out_name} = {arg_name}.toListOptionalTensor();
  173. #endif
  174. """.split(
  175. "\n"
  176. )
  177. )
  178. else:
  179. # use ArrayRef as default.
  180. vec_name = arg_name + "_vec"
  181. # need to bring vector instantiation out of scope so that ArrayRef has valid data
  182. decl.append(
  183. f"std::vector<{res_ctype.cpp_type(strip_ref=True)}> {vec_name};"
  184. )
  185. code.extend(
  186. f"""
  187. for (EValue {elem_name}: {in_name}) {{
  188. {connector.join(res_code)}
  189. {vec_name}.push_back({res_name});
  190. }}
  191. {ctype.cpp_type(strip_ref=True)} {out_name}({vec_name});
  192. """.split(
  193. "\n"
  194. )
  195. )
  196. return code, decl