et_cpp.py 13 KB


  1. from typing import List, Optional, Sequence, Set, Union
  2. from torchgen import local
  3. from torchgen.api.types import (
  4. ArgName,
  5. ArrayCType,
  6. BaseCType,
  7. Binding,
  8. ConstRefCType,
  9. CType,
  10. MutRefCType,
  11. NamedCType,
  12. SpecialArgName,
  13. TupleCType,
  14. VectorCType,
  15. voidT,
  16. )
  17. from torchgen.model import (
  18. Argument,
  19. Arguments,
  20. BaseTy,
  21. BaseType,
  22. ListType,
  23. NativeFunction,
  24. OptionalType,
  25. Return,
  26. SelfArgument,
  27. TensorOptionsArguments,
  28. Type,
  29. )
  30. from torchgen.utils import assert_never
  31. from .types import (
  32. ArrayRefCType,
  33. BaseTypeToCppMapping,
  34. OptionalCType,
  35. scalarT,
  36. tensorListT,
  37. tensorT,
  38. )
"""
This file describes the translation of JIT schema to the public C++ API, which is what people use when they call
functions like at::add. It also serves as the native function API, i.e. the signature of kernels,
since in Executorch CppSignature is the same as NativeSignature.

Differences between this file and torchgen.api.cpp:
- Executorch doesn't support TensorOptions; however, this file still keeps the related logic for compatibility
  with torchgen.api.cpp, so that we can support things like ATen mode (running ATen kernels in Executorch).
- Executorch doesn't support Dimname.
- The Executorch runtime doesn't support SymInt, so SymInt is treated as int.
"""
  49. # Translation of "value types" in JIT schema to C++ API type. Value
  50. # types look the same no matter if they are argument types or return
  51. # types. Returns None if the type in question is not a value type.
  52. def valuetype_type(
  53. t: Type,
  54. *,
  55. binds: ArgName,
  56. remove_non_owning_ref_types: bool = False,
  57. ) -> Optional[NamedCType]:
  58. if isinstance(t, BaseType):
  59. if t.name == BaseTy.Tensor or t.name == BaseTy.Scalar:
  60. return None
  61. # For SymInt we simply treat it as int.
  62. elif str(t) == "SymInt":
  63. return NamedCType(binds, BaseCType(BaseTypeToCppMapping[BaseTy.int]))
  64. if remove_non_owning_ref_types:
  65. if t.name == BaseTy.str:
  66. raise AssertionError(
  67. "string ref->value conversion: not implemented yet"
  68. )
  69. # All other BaseType currently map directly to BaseCppTypes.
  70. return NamedCType(binds, BaseCType(BaseTypeToCppMapping[t.name]))
  71. elif isinstance(t, OptionalType):
  72. elem = valuetype_type(t.elem, binds=binds)
  73. if elem is None:
  74. return None
  75. return NamedCType(binds, OptionalCType(elem.type))
  76. elif isinstance(t, ListType):
  77. if str(t.elem) == "bool":
  78. assert t.size is not None
  79. return NamedCType(
  80. binds, ArrayCType(BaseCType(BaseTypeToCppMapping[BaseTy.bool]), t.size)
  81. )
  82. else:
  83. return None
  84. else:
  85. raise AssertionError(f"unrecognized type {repr(t)}")
  86. # Translation of types occuring in JIT arguments to a C++ argument type.
  87. # If remove_non_owning_ref_types is set, we'll guarantee that the outputed CType is not a non-owning reference type.
  88. # For example, we'll return std::vector<int> instead of IntArrayRef.
  89. # See Note [translation from C++ reference to value types]
  90. def argumenttype_type(
  91. t: Type,
  92. *,
  93. mutable: bool,
  94. binds: ArgName,
  95. remove_non_owning_ref_types: bool = False,
  96. ) -> NamedCType:
  97. # If it's a value type, do the value type translation
  98. r = valuetype_type(
  99. t,
  100. binds=binds,
  101. remove_non_owning_ref_types=remove_non_owning_ref_types,
  102. )
  103. if r is not None:
  104. return r
  105. if isinstance(t, BaseType):
  106. if t.name == BaseTy.Tensor:
  107. if mutable and not local.use_const_ref_for_mutable_tensors():
  108. return NamedCType(binds, MutRefCType(BaseCType(tensorT)))
  109. else:
  110. return NamedCType(binds, ConstRefCType(BaseCType(tensorT)))
  111. elif t.name == BaseTy.Scalar:
  112. return NamedCType(binds, ConstRefCType(BaseCType(scalarT)))
  113. else:
  114. raise AssertionError(f"base type should have been value type {t}")
  115. elif isinstance(t, OptionalType):
  116. if str(t.elem) == "Tensor":
  117. if mutable and not local.use_const_ref_for_mutable_tensors():
  118. return NamedCType(
  119. binds, MutRefCType(BaseCType(tensorT))
  120. ) # TODO: fix this discrepancy
  121. else:
  122. return NamedCType(
  123. binds, ConstRefCType(OptionalCType(BaseCType(tensorT)))
  124. )
  125. elif str(t.elem) == "Scalar":
  126. return NamedCType(binds, ConstRefCType(OptionalCType(BaseCType(scalarT))))
  127. elem = argumenttype_type(t.elem, mutable=mutable, binds=binds)
  128. return NamedCType(binds, OptionalCType(elem.type))
  129. elif isinstance(t, ListType):
  130. # TODO: keeping these special cases for Tensor[] and Tensor?[] so that we can hookup with ATen kernels.
  131. if str(t.elem) == "Tensor":
  132. return NamedCType(binds, BaseCType(tensorListT))
  133. elif str(t.elem) == "Dimname":
  134. raise NotImplementedError("Executorch doesn't support Dimname")
  135. elif str(t.elem) == "Tensor?":
  136. return NamedCType(binds, ArrayRefCType(OptionalCType(BaseCType(tensorT))))
  137. elem = argumenttype_type(t.elem, mutable=mutable, binds=binds)
  138. return NamedCType(binds, ArrayRefCType(elem.type))
  139. else:
  140. raise AssertionError(f"unrecognized type {repr(t)}")
  141. # Translate a JIT argument into its C++ type
  142. def argument_type(a: Argument, *, binds: ArgName) -> NamedCType:
  143. return argumenttype_type(a.type, mutable=a.is_write, binds=binds)
  144. # Translation of a (non-multi) return type from JIT to C++
  145. # N.B: returntype_type returns a CType, not a NamedCType.
  146. # This is mostly because of the mismatch between return types and return names.
  147. # e.g. a function with a return type of 'void' has 0 return names,
  148. # and a function with a return type of 'std::tuple' has >1 return name.
  149. def returntype_type(t: Type, *, mutable: bool) -> CType:
  150. # placeholder is ignored
  151. r = valuetype_type(t, binds="__placeholder__")
  152. if r is not None:
  153. return r.type
  154. if isinstance(t, BaseType):
  155. if t.name == BaseTy.Tensor:
  156. if mutable:
  157. if local.use_const_ref_for_mutable_tensors():
  158. return ConstRefCType(BaseCType(tensorT))
  159. else:
  160. return MutRefCType(BaseCType(tensorT))
  161. else:
  162. # Note [Tensor Copy Returns]
  163. # Currently, we use "Argument.is_write" to determine
  164. # whether or not Tensor return types should be copies or references.
  165. # If that ever changes, take a look at other locations of this note!
  166. return BaseCType(tensorT)
  167. elif t.name == BaseTy.Scalar:
  168. return BaseCType(scalarT)
  169. elif isinstance(t, ListType):
  170. assert (
  171. not mutable
  172. ), "Native functions should never return a mutable tensor list. They should return void."
  173. elem = returntype_type(t.elem, mutable=False)
  174. assert t.size is None, f"fixed size list returns not supported: {t}"
  175. return VectorCType(elem)
  176. raise AssertionError(f"unrecognized return type {t}")
  177. # Translation of a single return to its C++ type
  178. def return_type(r: Return) -> CType:
  179. return returntype_type(r.type, mutable=r.is_write)
  180. # Translation of a full (possibly multi) return from JIT to its C++ type
  181. def returns_type(rs: Sequence[Return]) -> CType:
  182. if len(rs) == 0:
  183. return BaseCType(voidT)
  184. elif len(rs) == 1:
  185. return return_type(rs[0])
  186. else:
  187. return TupleCType([return_type(r) for r in rs])
  188. def return_names(f: NativeFunction, *, fallback_name: str = "result") -> Sequence[str]:
  189. returns: List[str] = []
  190. for i, r in enumerate(f.func.returns):
  191. # If we have an inplace function, the return argument is
  192. # implicitly named self.
  193. # TODO: Consider incorporating this into the data model
  194. if f.func.name.name.inplace:
  195. assert i == 0, "illegal inplace function with multiple returns"
  196. name = "self"
  197. # If we are out function, the name is the name of the
  198. # corresponding output function (r.name will get recorded
  199. # in field_name later.)
  200. elif f.func.is_out_fn():
  201. name = f.func.arguments.out[i].name
  202. # If the return argument is explicitly named...
  203. elif r.name:
  204. name_conflict = any(
  205. r.name == a.name for a in f.func.schema_order_arguments()
  206. )
  207. if name_conflict and not f.func.is_out_fn():
  208. name = f"{r.name}_return"
  209. else:
  210. name = r.name
  211. # If there is no explicit name and no fallback name was passed in, we just name the output result,
  212. # unless it's a multi-return, in which case it's result0,
  213. # result1, etc (zero-indexed)
  214. else:
  215. name = fallback_name if len(f.func.returns) == 1 else f"{fallback_name}{i}"
  216. returns.append(name)
  217. return returns
  218. JIT_TO_CPP_DEFAULT = {
  219. "False": "false",
  220. "True": "true",
  221. "None": "torch::executorch::nullopt", # UGH this one is type directed
  222. "[]": "{}",
  223. "contiguous_format": "torch::executorch::MemoryFormat::Contiguous",
  224. "long": "torch::executorch::kLong",
  225. }
# Convert a JIT schema default value into a C++ expression for that default.
def default_expr(d: str, t: Type) -> str:
    """Translate the JIT default ``d`` of an argument of type ``t`` into C++ source.

    Cases are checked in this order:

    * ``None`` for a ``Tensor?`` argument becomes ``{}``.
    * For ``str``, a single-quoted default is re-emitted as a double-quoted C++
      string literal (escape handling is described at the loop below). A
      ``str`` default without surrounding single quotes falls through to the
      final table lookup.
    * For an optional type, ``None`` becomes ``torch::executor::nullopt``; any
      other default is translated recursively against the element type.
    * For a list type, a ``[...]`` default becomes a ``{...}`` brace
      initializer with the element text copied verbatim. An unsized list whose
      default is not bracketed raises. A sized list may have a scalar default,
      which falls through to the final lookup.
    * Everything else is looked up in ``JIT_TO_CPP_DEFAULT`` and returned
      unchanged when absent.

    Raises:
        ValueError: if ``t`` is an unsized list type and ``d`` is not ``[...]``.
    """
    if d == "None" and str(t) == "Tensor?":
        return "{}"
    if isinstance(t, BaseType) and t.name is BaseTy.str:
        # Schema allows single quotes but C++ needs double
        if len(d) >= 2 and d[0] == "'" and d[-1] == "'":
            # Scan the characters strictly between the outer quotes:
            #   - a literal '"' must be escaped inside a C++ double-quoted string;
            #   - the schema escape \' is emitted as a bare ' (no escape needed);
            #   - any other backslash escape is copied through unchanged.
            # NOTE(review): a backslash right before the closing quote consumes
            # that quote as an escaped ' (e.g. 'a\' yields "a'").
            s = ""
            i = 1
            while i + 1 < len(d):
                if d[i] != "\\":
                    if d[i] == '"':
                        s += '\\"'
                    else:
                        s += d[i]
                    i += 1
                else:
                    if d[i + 1] == "'":
                        s += "'"
                    else:
                        s += d[i : i + 2]
                    i += 2
            return f'"{s}"'
    if isinstance(t, OptionalType):
        if d == "None":
            # Keep this spelling in sync with the "None" entry of JIT_TO_CPP_DEFAULT.
            return "torch::executor::nullopt"
        return default_expr(d, t.elem)
    if isinstance(t, ListType):
        if d.startswith("[") and d.endswith("]"):
            # NOTE(review): element spellings are not translated, so a default
            # such as [True] would become {True}.
            return "{" + d[1:-1] + "}"
        elif t.size is None:
            # NOTE: Sized lists can have scalar defaults
            raise ValueError(f"Expected a list default '[...]' but found: '{d}'")
    return JIT_TO_CPP_DEFAULT.get(d, d)
  260. # Convert an argument into its C++ API form
  261. def argument(
  262. a: Union[Argument, TensorOptionsArguments, SelfArgument],
  263. *,
  264. cpp_no_default_args: Set[str],
  265. method: bool,
  266. faithful: bool,
  267. has_tensor_options: bool,
  268. ) -> List[Binding]:
  269. def sub_argument(
  270. a: Union[Argument, TensorOptionsArguments, SelfArgument]
  271. ) -> List[Binding]:
  272. return argument(
  273. a,
  274. cpp_no_default_args=cpp_no_default_args,
  275. method=method,
  276. faithful=faithful,
  277. has_tensor_options=has_tensor_options,
  278. )
  279. if isinstance(a, Argument):
  280. binds: ArgName
  281. if a.name == "memory_format" and has_tensor_options:
  282. binds = SpecialArgName.possibly_redundant_memory_format
  283. else:
  284. binds = a.name
  285. default: Optional[str] = None
  286. if a.name not in cpp_no_default_args and a.default is not None:
  287. default = default_expr(a.default, a.type)
  288. return [
  289. Binding(
  290. nctype=argument_type(a, binds=binds),
  291. name=a.name,
  292. default=default,
  293. argument=a,
  294. )
  295. ]
  296. elif isinstance(a, TensorOptionsArguments):
  297. raise NotImplementedError("Need to implement type resolution for TensorOptions")
  298. elif isinstance(a, SelfArgument):
  299. if method:
  300. # Caller is responsible for installing implicit this in context!
  301. return []
  302. else:
  303. return sub_argument(a.argument)
  304. else:
  305. assert_never(a)
  306. def arguments(
  307. arguments: Arguments,
  308. *,
  309. faithful: bool,
  310. method: bool,
  311. cpp_no_default_args: Set[str],
  312. ) -> List[Binding]:
  313. args: List[Union[Argument, TensorOptionsArguments, SelfArgument]] = []
  314. if faithful:
  315. args.extend(arguments.non_out)
  316. args.extend(arguments.out)
  317. else:
  318. args.extend(arguments.out)
  319. args.extend(arguments.non_out)
  320. return [
  321. r.no_default() if faithful else r
  322. for a in args
  323. for r in argument(
  324. a,
  325. faithful=faithful,
  326. method=method,
  327. has_tensor_options=arguments.tensor_options is not None,
  328. cpp_no_default_args=cpp_no_default_args,
  329. )
  330. ]