# unboxing.py
  1. from typing import List, Tuple
  2. from torchgen.api import cpp
  3. from torchgen.api.types import Binding, CppSignatureGroup, CType
  4. from torchgen.model import (
  5. Argument,
  6. BaseTy,
  7. BaseType,
  8. ListType,
  9. NativeFunction,
  10. OptionalType,
  11. Type,
  12. )
# This file generates the code for unboxing wrappers, i.e., the glue logic that unboxes a boxed operator call and
# converts the ivalues on the stack into the correct arguments for the unboxed kernel, based on the corresponding
# JIT schema. This codegen is an alternative to the existing C++ metaprogramming approach for producing unboxing
# wrappers, but it does the work statically. The generated wrappers are useful when a fixed set of operators known
# at compile time has to be registered, because they save time in the runtime initialization phase.
#
# Here's an example of how the codegen works:
  20. #
  21. # - Function Schema (source of truth)
  22. #
  23. # aten::empty.names(int[] size, *, Dimname[]? names,
  24. # ScalarType? dtype=None, Layout? layout=None,
  25. # Device? device=None, bool? pin_memory=None,
  26. # MemoryFormat? memory_format=None) -> Tensor
  27. # - Argument Conversion
  28. # Generates C++ code to convert an ivalue (from stack) to its underlying C++ type.
  29. # - int[] size
  30. # ```cpp
  31. # const c10::List<c10::IValue> size_list_in = (std::move(peek(stack, 0, 7))).toList();
  32. #
  33. # std::vector<int64_t> size_vec;
  34. # for (c10::IValue size_elem: size_list_in) {
  35. # int64_t size_base = size_elem.to<int64_t>();
  36. # size_vec.push_back(size_base);
  37. # }
  38. # at::ArrayRef<int64_t> size_list_out(size_vec);
  39. # ~~~~~~~~~~~~~ <-- The converted argument from ivalues in the stack.
  40. # Will be passed to unboxed kernel.
  41. # ```
  42. # - Dimname[]? names
  43. # ```cpp
  44. # c10::optional<c10::IValue> names_opt = (std::move(peek(stack, 1, 7))).toOptional<c10::IValue>();
  45. # c10::optional<at::ArrayRef<at::Dimname>> names_opt_out;
  46. # if (names_opt.has_value()) {
  47. # ~~~~~~~~~~~ <-- Unwrapping optional shell
  48. # const c10::IValue names_opt_in = names_opt.value();
  49. # const c10::List<c10::IValue> names_list_in = names_opt_in.toList();
  50. #
  51. # std::vector<at::Dimname> names_vec;
  52. # for (c10::IValue names_elem: names_list_in) {
  53. # ~~~~~~~~~~~~~~~~~~~~~~~~~ <-- Unrolling list, then convert elements one by one.
  54. # at::Dimname names_base = names_elem.to<at::Dimname>();
  55. # names_vec.push_back(names_base);
  56. # }
  57. # at::ArrayRef<at::Dimname> names_list_out(names_vec);
  58. #
  59. # names_opt_out = c10::optional<at::ArrayRef<at::Dimname>>(names_list_out);
  60. # } else {
  61. # names_opt_out = c10::optional<at::ArrayRef<at::Dimname>>();
  62. # }
  63. # ```
  64. # - ScalarType? dtype (similarly for the rest of the arguments)
  65. # ```cpp
  66. # c10::optional<c10::IValue> dtype_opt = (std::move(peek(stack, 2, 7))).toOptional<c10::IValue>();
  67. # c10::optional<at::ScalarType> dtype_opt_out;
  68. # if (dtype_opt.has_value()) {
  69. # const c10::IValue dtype_opt_in = dtype_opt.value();
  70. # at::ScalarType dtype_base = dtype_opt_in.to<at::ScalarType>();
  71. # ~~~~~~~~~~~~~~~~~~~~ <-- For base types, convert ivalue to it
  72. # directly using ".to<T>()" API.
  73. # dtype_opt_out = c10::optional<at::ScalarType>(dtype_base);
  74. # } else {
  75. # dtype_opt_out = c10::optional<at::ScalarType>();
  76. # }
  77. # ```
  78. #
  79. # - Unboxed Kernel Call
  80. # ```cpp
  81. # auto result_ = torch::empty(
  82. # size_list_out,
  83. # names_opt_out,
  84. # options,
  85. # memory_format_opt_out
  86. # );
  87. # ```
  88. #
  89. # - Push Result Back to Stack
  90. # ```cpp
  91. # drop(stack, 7);
  92. # pack(stack, std::move(result_));
  93. # ```
# Separator used when the C++ lines of a nested conversion are spliced into an enclosing template
# (via `connector.join(...)`): every line after the first starts on a new line, indented by one tab.
connector = "\n\t"
  95. # Return unboxing function name for a NativeFunction
  96. def name(f: NativeFunction) -> str:
  97. return f.func.name.unambiguous_name()
  98. # Convert all the arguments in a NativeFunction to C++ code
  99. def convert_arguments(f: NativeFunction) -> Tuple[List[Binding], List[str]]:
  100. # we need the 'self' argument so method needs to be False
  101. args = (
  102. CppSignatureGroup.from_native_function(f, method=False)
  103. .most_faithful_signature()
  104. .arguments()
  105. )
  106. code_list = [
  107. f"c10::IValue {args[i].name} = std::move(peek(stack, {i}, {len(args)}));"
  108. for i in range(len(args))
  109. ] + [""]
  110. binding_list = []
  111. for i, arg in enumerate(args):
  112. # expecting only Argument
  113. if not isinstance(arg.argument, Argument):
  114. raise Exception(
  115. f"Unexpected argument type, expecting `Argument` but got {arg}"
  116. )
  117. argument: Argument = arg.argument
  118. unboxed_name, _, code, decl = argumenttype_ivalue_convert(
  119. argument.type,
  120. argument.name,
  121. mutable=argument.is_write,
  122. )
  123. code_list.extend(decl)
  124. code_list.extend(code)
  125. binding_list.append(arg.with_name(unboxed_name))
  126. return binding_list, code_list
  127. # Takes in the type, name and mutability corresponding to an argument, and generates a tuple of:
  128. # (1) the C++ code necessary to unbox the argument
  129. # (2) A Binding corresponding to the newly created unboxed variable, including variable name and its CType
  130. def argumenttype_ivalue_convert(
  131. t: Type, arg_name: str, *, mutable: bool = False
  132. ) -> Tuple[str, CType, List[str], List[str]]:
  133. # Unboxing is for mobile, which doesn't care about SymInts
  134. ctype = cpp.argumenttype_type(
  135. t=t, mutable=mutable, binds=arg_name, symint=False
  136. ).type
  137. if isinstance(t, BaseType):
  138. out_name = f"{arg_name}_base"
  139. code, decl = _gen_code_base_type(
  140. arg_name=arg_name, out_name=out_name, ctype=ctype
  141. )
  142. elif isinstance(t, OptionalType):
  143. out_name = f"{arg_name}_opt_out"
  144. code, decl = _gen_code_optional_type(
  145. arg_name=arg_name,
  146. out_name=out_name,
  147. t=t,
  148. ctype=ctype,
  149. )
  150. elif isinstance(t, ListType):
  151. out_name = f"{arg_name}_list_out"
  152. code, decl = _gen_code_list_type(
  153. arg_name=arg_name,
  154. out_name=out_name,
  155. t=t,
  156. ctype=ctype,
  157. )
  158. else:
  159. raise Exception(f"Cannot handle type {t}. arg_name: {arg_name}")
  160. return out_name, ctype, code, decl
  161. def _gen_code_base_type(
  162. arg_name: str, out_name: str, ctype: CType
  163. ) -> Tuple[List[str], List[str]]:
  164. return [
  165. f"{ctype.cpp_type(strip_ref=True)} {out_name} = {arg_name}.to<{ctype.cpp_type(strip_ref=True)}>();"
  166. ], []
  167. def _gen_code_optional_type(
  168. arg_name: str, out_name: str, t: OptionalType, ctype: CType
  169. ) -> Tuple[List[str], List[str]]:
  170. in_name = f"{arg_name}_opt_in"
  171. res_name, _, res_code, decl = argumenttype_ivalue_convert(t.elem, in_name)
  172. return (
  173. f"""
  174. c10::optional<c10::IValue> {arg_name}_opt = {arg_name}.toOptional<c10::IValue>();
  175. {ctype.cpp_type(strip_ref=True)} {out_name};
  176. if ({arg_name}_opt.has_value()) {{
  177. const c10::IValue {in_name} = {arg_name}_opt.value();
  178. {connector.join(res_code)}
  179. {out_name} = {ctype.cpp_type(strip_ref=True)}({res_name});
  180. }} else {{
  181. {out_name} = {ctype.cpp_type(strip_ref=True)}();
  182. }}
  183. """.split(
  184. "\n"
  185. ),
  186. decl,
  187. )
  188. def _gen_code_list_type(
  189. arg_name: str, out_name: str, t: ListType, ctype: CType
  190. ) -> Tuple[List[str], List[str]]:
  191. in_name = f"{arg_name}_list_in"
  192. elem_name = f"{arg_name}_elem"
  193. code = [f"const c10::List<c10::IValue> {in_name} = {arg_name}.toList();"]
  194. res_name, res_ctype, res_code, decl = argumenttype_ivalue_convert(t.elem, elem_name)
  195. # handle list type with size, e.g., bool[4]
  196. if isinstance(t.elem, BaseType) and t.elem.name == BaseTy.bool and t.size:
  197. code.extend(
  198. f"""
  199. {ctype.cpp_type(strip_ref=True)} {out_name} = as_array<{res_ctype.cpp_type(strip_ref=True)}, {t.size}>({in_name});
  200. """.split(
  201. "\n"
  202. )
  203. )
  204. # we have to use c10::List for optional element. e.g., Tensor?[] -> c10::List<c10::optional<at::Tensor>>
  205. elif isinstance(t.elem, OptionalType):
  206. code.extend(
  207. f"""
  208. {ctype.cpp_type(strip_ref=True)} {out_name};
  209. for (c10::IValue {elem_name}: {in_name}) {{
  210. {connector.join(res_code)}
  211. {out_name}.push_back({res_name});
  212. }}
  213. """.split(
  214. "\n"
  215. )
  216. )
  217. else:
  218. # use ArrayRef as default.
  219. vec_name = arg_name + "_vec"
  220. # need to bring vector instantiation out of scope so that ArrayRef has valid data
  221. decl.append(f"std::vector<{res_ctype.cpp_type(strip_ref=True)}> {vec_name};")
  222. code.extend(
  223. f"""
  224. for (c10::IValue {elem_name}: {in_name}) {{
  225. {connector.join(res_code)}
  226. {vec_name}.push_back({res_name});
  227. }}
  228. {ctype.cpp_type(strip_ref=True)} {out_name}({vec_name});
  229. """.split(
  230. "\n"
  231. )
  232. )
  233. return code, decl