cpp.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460
  1. from typing import List, Optional, Sequence, Set, Union
  2. from torchgen import local
  3. from torchgen.api.types import (
  4. ArgName,
  5. ArrayCType,
  6. ArrayRefCType,
  7. BaseCType,
  8. BaseTypeToCppMapping,
  9. Binding,
  10. boolT,
  11. ConstRefCType,
  12. CType,
  13. dimnameListT,
  14. intArrayRefT,
  15. iTensorListRefT,
  16. ListCType,
  17. longT,
  18. MutRefCType,
  19. NamedCType,
  20. OptionalCType,
  21. optionalIntArrayRefT,
  22. optionalSymIntArrayRefT,
  23. scalarT,
  24. SpecialArgName,
  25. symIntArrayRefT,
  26. SymIntT,
  27. tensorListT,
  28. tensorOptionsT,
  29. tensorT,
  30. TupleCType,
  31. VectorCType,
  32. voidT,
  33. )
  34. from torchgen.model import (
  35. Argument,
  36. Arguments,
  37. BaseTy,
  38. BaseType,
  39. FunctionSchema,
  40. ListType,
  41. NativeFunction,
  42. OptionalType,
  43. Return,
  44. SelfArgument,
  45. TensorOptionsArguments,
  46. Type,
  47. )
  48. from torchgen.utils import assert_never
  49. # This file describes the translation of JIT schema to the public C++
  50. # API, which is what people use when they call functions like at::add.
  51. #
  52. # Prominent characteristics of the C++ API:
  53. #
  54. # - dtype, layout, device and pin_memory are collected into
  55. # a single C++ type TensorOptions (the native functions API
  56. # also has this, but tensor options is really most relevant
  57. # for the C++ API; it makes calling kwarg factory functions
  58. # pleasant)
  59. #
  60. # - defaulting lives here (in fact, the dispatcher is completely
  61. # oblivious of defaults!)
  62. #
  63. # BTW: policy on name collisions: we try not to have types with
  64. # collisions, but functions are fair game to collide
  65. def name(
  66. func: FunctionSchema,
  67. *,
  68. faithful_name_for_out_overloads: bool = False,
  69. symint_overload: bool = False,
  70. ) -> str:
  71. name = str(func.name.name)
  72. if symint_overload:
  73. name += "_symint"
  74. if func.is_out_fn():
  75. if faithful_name_for_out_overloads:
  76. name += "_outf"
  77. else:
  78. name += "_out"
  79. return name
  80. # Translation of "value types" in JIT schema to C++ API type. Value
  81. # types look the same no matter if they are argument types or return
  82. # types. Returns None if the type in question is not a value type.
  83. def valuetype_type(
  84. t: Type,
  85. *,
  86. binds: ArgName,
  87. remove_non_owning_ref_types: bool = False,
  88. symint: bool = False,
  89. ) -> Optional[NamedCType]:
  90. if isinstance(t, BaseType):
  91. if t.name == BaseTy.Tensor or t.name == BaseTy.Scalar:
  92. return None
  93. elif str(t) == "SymInt":
  94. if symint:
  95. return NamedCType(binds, BaseCType(SymIntT))
  96. else:
  97. return NamedCType(binds, BaseCType(longT))
  98. if remove_non_owning_ref_types:
  99. if t.name == BaseTy.str:
  100. raise AssertionError(
  101. "string ref->value conversion: not implemented yet"
  102. )
  103. # All other BaseType currently map directly to BaseCppTypes.
  104. return NamedCType(binds, BaseCType(BaseTypeToCppMapping[t.name]))
  105. elif isinstance(t, OptionalType):
  106. elem = valuetype_type(t.elem, binds=binds, symint=symint)
  107. if elem is None:
  108. return None
  109. return NamedCType(binds, OptionalCType(elem.type))
  110. elif isinstance(t, ListType):
  111. if str(t.elem) == "bool":
  112. assert t.size is not None
  113. return NamedCType(binds, ArrayCType(BaseCType(boolT), t.size))
  114. else:
  115. return None
  116. else:
  117. raise AssertionError(f"unrecognized type {repr(t)}")
  118. # Translation of types occuring in JIT arguments to a C++ argument type.
  119. # If remove_non_owning_ref_types is set, we'll guarantee that the outputed CType is not a non-owning reference type.
  120. # For example, we'll return std::vector<int> instead of IntArrayRef.
  121. # See Note [translation from C++ reference to value types]
  122. def argumenttype_type(
  123. t: Type,
  124. *,
  125. mutable: bool,
  126. binds: ArgName,
  127. remove_non_owning_ref_types: bool = False,
  128. symint: bool = False,
  129. ) -> NamedCType:
  130. # If it's a value type, do the value type translation
  131. r = valuetype_type(
  132. t,
  133. binds=binds,
  134. symint=symint,
  135. remove_non_owning_ref_types=remove_non_owning_ref_types,
  136. )
  137. if r is not None:
  138. return r
  139. if isinstance(t, BaseType):
  140. if t.name == BaseTy.Tensor:
  141. if mutable and not local.use_const_ref_for_mutable_tensors():
  142. return NamedCType(binds, MutRefCType(BaseCType(tensorT)))
  143. else:
  144. return NamedCType(binds, ConstRefCType(BaseCType(tensorT)))
  145. elif t.name == BaseTy.Scalar:
  146. return NamedCType(binds, ConstRefCType(BaseCType(scalarT)))
  147. else:
  148. raise AssertionError(f"base type should have been value type {t}")
  149. elif isinstance(t, OptionalType):
  150. if str(t.elem) == "Tensor":
  151. if mutable and not local.use_const_ref_for_mutable_tensors():
  152. return NamedCType(
  153. binds, MutRefCType(BaseCType(tensorT))
  154. ) # TODO: fix this discrepancy
  155. else:
  156. return NamedCType(
  157. binds, ConstRefCType(OptionalCType(BaseCType(tensorT)))
  158. )
  159. elif str(t.elem) == "Scalar":
  160. return NamedCType(binds, ConstRefCType(OptionalCType(BaseCType(scalarT))))
  161. elif isinstance(t.elem, ListType) and str(t.elem.elem) == "int":
  162. return NamedCType(binds, BaseCType(optionalIntArrayRefT))
  163. elif isinstance(t.elem, ListType) and str(t.elem.elem) == "SymInt":
  164. if symint:
  165. return NamedCType(binds, BaseCType(optionalSymIntArrayRefT))
  166. else:
  167. return NamedCType(binds, BaseCType(optionalIntArrayRefT))
  168. elem = argumenttype_type(t.elem, mutable=mutable, binds=binds, symint=symint)
  169. return NamedCType(binds, OptionalCType(elem.type))
  170. elif isinstance(t, ListType):
  171. # TODO: remove these special cases, ArrayRef fallthrough works fine
  172. if str(t.elem) == "int":
  173. if remove_non_owning_ref_types:
  174. return NamedCType(binds, VectorCType(BaseCType(longT)))
  175. else:
  176. return NamedCType(binds, BaseCType(intArrayRefT))
  177. if str(t.elem) == "SymInt":
  178. if remove_non_owning_ref_types:
  179. if symint:
  180. return NamedCType(binds, VectorCType(BaseCType(SymIntT)))
  181. else:
  182. return NamedCType(binds, VectorCType(BaseCType(longT)))
  183. else:
  184. if symint:
  185. return NamedCType(binds, BaseCType(symIntArrayRefT))
  186. else:
  187. return NamedCType(binds, BaseCType(intArrayRefT))
  188. if str(t.elem) == "Tensor":
  189. if local.use_ilistref_for_tensor_lists():
  190. return NamedCType(binds, ConstRefCType(BaseCType(iTensorListRefT)))
  191. else:
  192. return NamedCType(binds, BaseCType(tensorListT))
  193. elif str(t.elem) == "Scalar":
  194. return NamedCType(binds, ArrayRefCType(BaseCType(scalarT)))
  195. elif str(t.elem) == "Dimname":
  196. return NamedCType(binds, BaseCType(dimnameListT))
  197. elif str(t.elem) == "Tensor?":
  198. return NamedCType(
  199. binds, ConstRefCType(ListCType(OptionalCType(BaseCType(tensorT))))
  200. )
  201. elem = argumenttype_type(t.elem, mutable=mutable, binds=binds, symint=symint)
  202. return NamedCType(binds, ArrayRefCType(elem.type))
  203. else:
  204. raise AssertionError(f"unrecognized type {repr(t)}")
  205. # Translate a JIT argument into its C++ type
  206. def argument_type(a: Argument, *, binds: ArgName, symint: bool = False) -> NamedCType:
  207. return argumenttype_type(a.type, mutable=a.is_write, symint=symint, binds=binds)
  208. # Translation of a (non-multi) return type from JIT to C++
  209. # N.B: returntype_type returns a CType, not a NamedCType.
  210. # This is mostly because of the mismatch between return types and return names.
  211. # e.g. a function with a return type of 'void' has 0 return names,
  212. # and a function with a return type of 'std::tuple' has >1 return name.
  213. def returntype_type(t: Type, *, mutable: bool, symint: bool = False) -> CType:
  214. # placeholder is ignored
  215. r = valuetype_type(t, binds="__placeholder__", symint=symint)
  216. if r is not None:
  217. return r.type
  218. if isinstance(t, BaseType):
  219. if t.name == BaseTy.Tensor:
  220. if mutable:
  221. if local.use_const_ref_for_mutable_tensors():
  222. return ConstRefCType(BaseCType(tensorT))
  223. else:
  224. return MutRefCType(BaseCType(tensorT))
  225. else:
  226. # Note [Tensor Copy Returns]
  227. # Currently, we use "Argument.is_write" to determine
  228. # whether or not Tensor return types should be copies or references.
  229. # If that ever changes, take a look at other locations of this note!
  230. return BaseCType(tensorT)
  231. elif t.name == BaseTy.Scalar:
  232. return BaseCType(scalarT)
  233. elif isinstance(t, ListType):
  234. assert (
  235. not mutable
  236. ), "Native functions should never return a mutable tensor list. They should return void."
  237. elem = returntype_type(t.elem, mutable=False, symint=symint)
  238. assert t.size is None, f"fixed size list returns not supported: {t}"
  239. return VectorCType(elem)
  240. raise AssertionError(f"unrecognized return type {t}")
  241. # Translation of a single return to its C++ type
  242. def return_type(r: Return, *, symint: bool = False) -> CType:
  243. return returntype_type(r.type, mutable=r.is_write, symint=symint)
  244. # Translation of a full (possibly multi) return from JIT to its C++ type
  245. def returns_type(rs: Sequence[Return], *, symint: bool = False) -> CType:
  246. if len(rs) == 0:
  247. return BaseCType(voidT)
  248. elif len(rs) == 1:
  249. return return_type(rs[0], symint=symint)
  250. else:
  251. return TupleCType([return_type(r, symint=symint) for r in rs])
  252. def return_names(f: NativeFunction, *, fallback_name: str = "result") -> Sequence[str]:
  253. returns: List[str] = []
  254. for i, r in enumerate(f.func.returns):
  255. # If we have an inplace function, the return argument is
  256. # implicitly named self.
  257. # TODO: Consider incorporating this into the data model
  258. if f.func.name.name.inplace:
  259. assert i == 0, "illegal inplace function with multiple returns"
  260. name = "self"
  261. # If we are out function, the name is the name of the
  262. # corresponding output function (r.name will get recorded
  263. # in field_name later.)
  264. elif f.func.is_out_fn():
  265. name = f.func.arguments.out[i].name
  266. # If the return argument is explicitly named...
  267. elif r.name:
  268. name_conflict = any(
  269. r.name == a.name for a in f.func.schema_order_arguments()
  270. )
  271. if name_conflict and not f.func.is_out_fn():
  272. name = f"{r.name}_return"
  273. else:
  274. name = r.name
  275. # If there is no explicit name and no fallback name was passed in, we just name the output result,
  276. # unless it's a multi-return, in which case it's result0,
  277. # result1, etc (zero-indexed)
  278. else:
  279. name = fallback_name if len(f.func.returns) == 1 else f"{fallback_name}{i}"
  280. returns.append(name)
  281. return returns
  282. JIT_TO_CPP_DEFAULT = {
  283. "False": "false",
  284. "True": "true",
  285. "None": "c10::nullopt", # UGH this one is type directed
  286. "Mean": "at::Reduction::Mean",
  287. "[]": "{}",
  288. "contiguous_format": "MemoryFormat::Contiguous",
  289. "long": "at::kLong",
  290. }
  291. # Convert a JIT default into C++ expression representing the default
  292. def default_expr(d: str, t: Type, *, symint: bool) -> str:
  293. if d == "None" and str(t) == "Tensor?":
  294. return "{}"
  295. if isinstance(t, BaseType) and t.name is BaseTy.str:
  296. # Schema allows single quotes but C++ needs double
  297. if len(d) >= 2 and d[0] == "'" and d[-1] == "'":
  298. s = ""
  299. i = 1
  300. while i + 1 < len(d):
  301. if d[i] != "\\":
  302. if d[i] == '"':
  303. s += '\\"'
  304. else:
  305. s += d[i]
  306. i += 1
  307. else:
  308. if d[i + 1] == "'":
  309. s += "'"
  310. else:
  311. s += d[i : i + 2]
  312. i += 2
  313. return f'"{s}"'
  314. if isinstance(t, OptionalType):
  315. if d == "None":
  316. return "c10::nullopt"
  317. return default_expr(d, t.elem, symint=symint)
  318. if isinstance(t, ListType):
  319. if d.startswith("[") and d.endswith("]"):
  320. return "{" + d[1:-1] + "}"
  321. elif symint and d.isdigit() and str(t.elem) == "SymInt":
  322. return f"c10::SymInt({d})"
  323. elif t.size is None:
  324. # NOTE: Sized lists can have scalar defaults
  325. raise ValueError(f"Expected a list default '[...]' but found: '{d}'")
  326. return JIT_TO_CPP_DEFAULT.get(d, d)
  327. # Convert an argument into its C++ API form
  328. def argument(
  329. a: Union[Argument, TensorOptionsArguments, SelfArgument],
  330. *,
  331. cpp_no_default_args: Set[str],
  332. method: bool,
  333. faithful: bool,
  334. symint: bool = False,
  335. has_tensor_options: bool,
  336. ) -> List[Binding]:
  337. def sub_argument(
  338. a: Union[Argument, TensorOptionsArguments, SelfArgument]
  339. ) -> List[Binding]:
  340. return argument(
  341. a,
  342. cpp_no_default_args=cpp_no_default_args,
  343. method=method,
  344. faithful=faithful,
  345. symint=symint,
  346. has_tensor_options=has_tensor_options,
  347. )
  348. if isinstance(a, Argument):
  349. binds: ArgName
  350. if a.name == "memory_format" and has_tensor_options:
  351. binds = SpecialArgName.possibly_redundant_memory_format
  352. else:
  353. binds = a.name
  354. default: Optional[str] = None
  355. if a.name not in cpp_no_default_args and a.default is not None:
  356. default = default_expr(a.default, a.type, symint=symint)
  357. return [
  358. Binding(
  359. nctype=argument_type(a, binds=binds, symint=symint),
  360. name=a.name,
  361. default=default,
  362. argument=a,
  363. )
  364. ]
  365. elif isinstance(a, TensorOptionsArguments):
  366. if faithful:
  367. return (
  368. sub_argument(a.dtype)
  369. + sub_argument(a.layout)
  370. + sub_argument(a.device)
  371. + sub_argument(a.pin_memory)
  372. )
  373. else:
  374. default = None
  375. # Enforced by NativeFunction.__post_init__
  376. assert "options" not in cpp_no_default_args
  377. if all(x.default == "None" for x in a.all()):
  378. default = "{}"
  379. elif a.dtype.default == "long":
  380. default = "at::kLong" # TODO: this is wrong
  381. return [
  382. Binding(
  383. nctype=NamedCType("options", BaseCType(tensorOptionsT)),
  384. name="options",
  385. default=default,
  386. argument=a,
  387. )
  388. ]
  389. elif isinstance(a, SelfArgument):
  390. if method:
  391. # Caller is responsible for installing implicit this in context!
  392. return []
  393. else:
  394. return sub_argument(a.argument)
  395. else:
  396. assert_never(a)
  397. def arguments(
  398. arguments: Arguments,
  399. *,
  400. faithful: bool,
  401. symint: bool = False,
  402. method: bool,
  403. cpp_no_default_args: Set[str],
  404. ) -> List[Binding]:
  405. args: List[Union[Argument, TensorOptionsArguments, SelfArgument]] = []
  406. if faithful:
  407. args.extend(arguments.non_out)
  408. args.extend(arguments.out)
  409. else:
  410. args.extend(arguments.out)
  411. args.extend(arguments.non_out)
  412. return [
  413. r.no_default() if faithful else r
  414. for a in args
  415. for r in argument(
  416. a,
  417. faithful=faithful,
  418. symint=symint,
  419. method=method,
  420. has_tensor_options=arguments.tensor_options is not None,
  421. cpp_no_default_args=cpp_no_default_args,
  422. )
  423. ]