from typing import Any, Dict, List, Optional, Tuple, Union

from torchgen.api.types import (
    BaseCppType,
    BaseCType,
    boolT,
    CType,
    deviceT,
    doubleT,
    layoutT,
    ListCType,
    longT,
    memoryFormatT,
    NamedCType,
    OptionalCType,
    scalarT,
    scalarTypeT,
    stringT,
    SymIntT,
    VectorCType,
)
from torchgen.model import (
    Argument,
    BaseTy,
    BaseType,
    FunctionSchema,
    ListType,
    OperatorName,
    OptionalType,
    Return,
    TensorOptionsArguments,
    Type,
)

_valueT: Optional[BaseCppType] = None


# A ValueT is an IR type which represents the computation of a Tensor. In other
# words, a PyTorch user will do operations on lazy tensors, and each output lazy
# tensor internally tracks a ValueT representing the IR node that would have
# actually produced the value of this tensor for real.
#
# This is configurable because different lazy tensor backends (LTC vs XLA) will
# have different IR representations. (Though, arguably, after unification they
# shouldn't!)
def getValueT() -> BaseCppType:
    global _valueT
    if not _valueT:
        raise NotImplementedError(
            "The value type needs to be set with setValueT() in run_gen_lazy_tensor()"
        )
    return _valueT


def setValueT(val: BaseCppType) -> None:
    global _valueT
    _valueT = val
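

# Illustrative usage (a sketch, not a prescribed API flow): a backend's
# run_gen_lazy_tensor() is expected to register its IR value type before any
# codegen that calls getValueT() runs. For example, a hypothetical LTC-style
# backend might do:
#
#     setValueT(BaseCppType("torch::lazy", "Value"))
#     getValueT()  # -> the registered BaseCppType; raises NotImplementedError if never set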


# This is a bad hack. I need to refactor the data model to represent each arg in the schema as an object,
# making it easier to represent special properties of an arg.
tensorListValueT = BaseCppType("torch::lazy", "Value")


def process_ir_type(
    typ: Type, properties: "LazyIrProperties", *, symint: bool
) -> Union[BaseCType, VectorCType, OptionalCType, ListCType]:
    """
    This function takes a type from NativeFunctions and converts it for use with
    lazy tensor codegen.

    Type conversion for lazy currently consists of
        (1) changing at::Tensors into lazy::Values, which wrap lazy::Nodes and are
            how Lazy IR represents tensors. There is special handling for
            tensor-like containers such as Optional[Tensor] and List[Tensor].
        (2) wrapping everything in a BaseCType
        (3) making cpp-reference types into cpp-value types (e.g. vector instead of IntArrayRef)

    This handling is incomplete: there are assertions in places where it is
    expected that more types will need to be added as the codegen is used with
    more operators.
    """
    if isinstance(typ, BaseType):
        if typ.name == BaseTy.Tensor:
            return BaseCType(getValueT())
        elif typ.name == BaseTy.Scalar:
            if properties.TreatScalarsAsConstants:
                return BaseCType(scalarT)
            # at::Scalar has special handling,
            # and is wrapped in a lazy::Value just like at::Tensor
            return BaseCType(getValueT())
        elif typ.name == BaseTy.ScalarType:
            return BaseCType(scalarTypeT)
        elif typ.name == BaseTy.int:
            return BaseCType(longT)
        elif typ.name == BaseTy.SymInt:
            if symint:
                return BaseCType(getValueT())
            else:
                return BaseCType(longT)
        elif typ.name == BaseTy.bool:
            return BaseCType(boolT)
        elif typ.name == BaseTy.float:
            return BaseCType(doubleT)
        elif typ.name == BaseTy.str:
            return BaseCType(stringT)
        elif typ.name == BaseTy.Device:
            return BaseCType(deviceT)
        elif typ.name == BaseTy.Layout:
            return BaseCType(layoutT)
        elif typ.name == BaseTy.MemoryFormat:
            return BaseCType(memoryFormatT)
        else:
            raise AssertionError(f"TODO add support for type {repr(typ)}")
    elif isinstance(typ, OptionalType):
        return OptionalCType(process_ir_type(typ.elem, properties, symint=symint))
    elif isinstance(typ, ListType):
        if str(typ.elem) == "Tensor?":
            # TODO(whc) is this actually correct? or should it use a Vector like above
            return ListCType(OptionalCType(BaseCType(getValueT())))
        elif str(typ.elem) == "Tensor":
            # this is a TensorList which comes in from GetTensorList as a Value
            return BaseCType(tensorListValueT)
        elif typ.elem == BaseType(BaseTy.SymInt):
            # TODO: return a value type. The problem here is analogous to
            # the problem with tensorListValueT: if you have SymInt[] you
            # cannot conveniently save the list of Value directly, as nodes
            # expect to save values as a vector for ALL arguments. So you
            # need a separate IR node that represents all of the size nodes
            # assembled into a list. I'm not an LTC dev so I don't want to
            # figure it out right now. Y'all figure it out...
            return VectorCType(BaseCType(longT))
        else:
            return VectorCType(process_ir_type(typ.elem, properties, symint=symint))
    else:
        raise AssertionError(f"unrecognized type {repr(typ)}")
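

# Illustrative sketch of the conversion above (comments only; `props` stands
# for some LazyIrProperties instance, and a backend must already have called
# setValueT()). Results follow directly from the branches of process_ir_type:
#
#     process_ir_type(BaseType(BaseTy.Tensor), props, symint=False)
#         # -> BaseCType(getValueT())
#     process_ir_type(OptionalType(BaseType(BaseTy.Scalar)), props, symint=False)
#         # -> OptionalCType(BaseCType(getValueT())), unless
#         #    props.TreatScalarsAsConstants, in which case the Scalar stays scalarT
#     process_ir_type(ListType(BaseType(BaseTy.int), None), props, symint=False)
#         # -> VectorCType(BaseCType(longT)), i.e. a cpp-value type, not IntArrayRef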


# TODO: Determining this based off of CType is bad; this should be computed
# from Type directly; then the same logic as process_ir_type can be used
#
# Invariant: passed typ should be an *owning* CType (e.g., we will report
# that ArrayRef<Value> is NOT a value type)
def isValueType(typ: CType, properties: "Optional[LazyIrProperties]" = None) -> bool:
    """
    Given a type, determine if it is a Value-like type. This is equivalent to
    being Tensor-like, but assumes the type has already been transformed.
    """
    if isinstance(typ, BaseCType):
        # I am regretting my naming conventions, but now we are wrapping at::Scalar in
        # a lazy Value, while preserving other 'Scalar' types as scalars in the IR
        treat_scalars_as_constants = properties and properties.TreatScalarsAsConstants
        return (
            typ.type == getValueT()
            or (typ.type == scalarT and not treat_scalars_as_constants)
            or typ.type == SymIntT
        )
    elif typ == VectorCType(BaseCType(SymIntT)):
        # TODO: report True for this
        return False
    elif isinstance(typ, (OptionalCType, ListCType, VectorCType)):
        return isValueType(typ.elem, properties)
    return False
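

# Illustrative expectations (comments only; results follow from the logic
# above, assuming getValueT() returns the backend's registered Value type):
#
#     isValueType(BaseCType(getValueT()))                 # True
#     isValueType(OptionalCType(BaseCType(getValueT())))  # True (recurses into elem)
#     isValueType(VectorCType(BaseCType(longT)))          # False
#     isValueType(VectorCType(BaseCType(SymIntT)))        # False (see TODO above)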


def isSymIntType(typ: Type) -> bool:
    return isinstance(typ, BaseType) and typ.name == BaseTy.SymInt


def isWrappedScalarType(typ: Type) -> bool:
    """
    Given a type, determine if it is a c10::Scalar which we will wrap in a lazy Value.
    Since we literally change the type from scalarT to valueT, information is lost.
    This function helps build a list of wrapped scalars to save that information.
    """
    if isinstance(typ, BaseType):
        # I am regretting my naming conventions, but now we are wrapping at::Scalar in
        # a lazy Value, while preserving other 'Scalar' types as scalars in the IR
        return typ.name == BaseTy.Scalar
    elif isinstance(typ, (OptionalType, ListType)):
        return isWrappedScalarType(typ.elem)
    return False


# TODO: dedupe with Type.is_generator_like
def isGeneratorType(typ: Type) -> bool:
    if isinstance(typ, BaseType):
        return typ.name == BaseTy.Generator
    elif isinstance(typ, OptionalType):
        return isGeneratorType(typ.elem)
    return False


# This class caches a few derived properties computed from an Argument
# and LazyIrProperties
class LazyArgument:
    name: str
    orig_type: Type
    is_optional: bool
    lazy_type_: Optional[CType]
    is_wrapped_scalar: bool
    is_generator: bool
    # TODO: this is a lie; it is false for symint lists
    is_symint_or_list: bool

    # Whether we are treating this argument as symint or not
    symint: bool

    # true if this argument is or contains a lazy IR value
    is_lazy_value: bool

    def __init__(self, arg: Argument, properties: "LazyIrProperties", *, symint: bool):
        self.name = arg.name
        self.orig_type = arg.type
        self.symint = symint
        self.is_optional = isinstance(arg.type, OptionalType)
        self.is_generator = isGeneratorType(arg.type)
        if self.is_generator:
            assert (
                self.is_optional
            ), "We expect all generators are optional since currently they are"
            # There is no handling for generators in TorchScript IR (or XLA),
            # so we fall back to eager if the (optional) generator has a value;
            # otherwise it's null and safe to exclude from lazy IR.
            self.lazy_type_ = None
        else:
            self.lazy_type_ = process_ir_type(arg.type, properties, symint=symint)
        self.is_wrapped_scalar = isWrappedScalarType(arg.type)
        self.is_symint_or_list = symint and (
            isSymIntType(arg.type)
            or (isinstance(arg.type, OptionalType) and isSymIntType(arg.type.elem))
            # TODO: lists of symints are not currently treated as value types
            # or (isinstance(arg.type, ListType) and isSymIntType(arg.type.elem))
        )

        self.is_lazy_value = not self.is_generator and isValueType(
            self.lazy_type, properties
        )

    @property
    def lazy_type(self) -> CType:
        assert (
            self.lazy_type_ is not None
        ), f"Attempted to access lazy_type for invalid argument {self.name}"
        return self.lazy_type_
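

# Illustrative construction (a sketch; it assumes Argument.parse accepts
# native_functions.yaml argument syntax, and that a backend has called
# setValueT()):
#
#     arg = Argument.parse("Tensor self")
#     lazy_arg = LazyArgument(
#         arg, LazyIrProperties("ShapePrecompute", "Lower", "CanBeReused"), symint=False
#     )
#     lazy_arg.lazy_type      # BaseCType(getValueT())
#     lazy_arg.is_lazy_value  # True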


class LazyIrProperties:
    """Collection of properties for an IR node

    The property groups are listed below. Each group is mutually
    exclusive, meaning that only one property from each group can be True
    at any one time. The properties can be accessed as if they were normal
    attributes. The mutual exclusivity is automatically handled.
    """

    Properties: Tuple[Tuple[str, ...], ...] = (
        (
            "ShapePrecompute",  # Assume shape has been precomputed
            "ShapeCompute",  # Need to compute the shape on construction
            "ShapeCache",  # Utilize the shape cache to defer computation
        ),
        (
            "Lower",  # Codegen full lower function
            "LowerDeclOnly",  # Codegen only lower function declaration
        ),
        (
            "CanBeReused",  # Codegen full reuse function
            "CanBeReusedDeclOnly",  # Codegen only reuse function declaration
        ),
        (
            "CreateFn",  # Codegen full create function
            "CreateFnDeclOnly",  # Codegen only create function declaration
        ),
        (
            "TreatScalarsAsConstants",  # Treat Scalars as constants instead of handling like values
        ),
    )

    def __init__(self, *default_properties: str):
        properties: Dict[Tuple[str, ...], Optional[str]] = {
            p: None for p in LazyIrProperties.Properties
        }
        self.__dict__["properties"] = properties

        for p in default_properties:
            setattr(self, p, True)

    def __getattr__(self, key: str) -> Any:
        properties = self.__dict__["properties"]
        for values in LazyIrProperties.Properties:
            if key in values:
                return properties[values] == key

        return self.__getattribute__(key)

    def __setattr__(self, key: str, value: Any) -> Any:
        properties = self.__dict__["properties"]
        for values in LazyIrProperties.Properties:
            if key in values:
                properties[values] = key if value else None
                return value

        raise KeyError(f"Invalid property: {key}")
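

# Illustrative behavior of the mutual-exclusivity bookkeeping above: each
# property group stores at most one active name, so enabling one member of a
# group implicitly disables its siblings.
#
#     props = LazyIrProperties("ShapePrecompute", "Lower")
#     props.Lower          # True
#     props.LowerDeclOnly  # False
#     props.LowerDeclOnly = True
#     props.Lower          # False (same group, so it was displaced)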


# Inspired by a FunctionSchema object, a LazyIrSchema holds the schema of a Lazy IR node.
# Unlike a FunctionSchema, it has no round-trippable string form (relating to the YAML),
# but carries type information from a native FunctionSchema modified for use with IR nodes,
# and preserving original argument names.
#
# TODO: This is not idiomatic with how other torchgen APIs transform on schema.
class LazyIrSchema:
    # The name of the operator this function schema describes.
    name: "OperatorName"

    positional_args: Tuple[LazyArgument, ...]
    keyword_args: Tuple[LazyArgument, ...]

    # TODO: Need to handle collisions with argument names at some point
    returns: Tuple["Return", ...]

    # if this schema has a Generator arg, list its orig ctype/name but don't
    # build a LazyArgument since lazy IR doesn't support it
    generator_arg: Optional[NamedCType] = None

    # original function schema
    func: FunctionSchema

    # Whether we are code-genning for SymInt or not
    symint: bool

    properties: LazyIrProperties = LazyIrProperties(
        # default properties
        "ShapePrecompute",
        "Lower",
        "CanBeReused",
    )
    opkind: Optional[str] = None

    def __init__(
        self,
        func: FunctionSchema,
        properties: Optional[LazyIrProperties] = None,
        *,
        symint: bool,
    ):
        if properties:
            self.properties = properties

        self.func = func
        self.symint = symint
        positional_args: List[LazyArgument] = []
        for arg_field in ["pre_self_positional", "self_arg", "post_self_positional"]:
            if arg_field == "self_arg" and func.arguments.self_arg is not None:
                arg = func.arguments.self_arg.argument
                positional_args.append(
                    LazyArgument(arg, self.properties, symint=symint)
                )
            elif getattr(func.arguments, arg_field) is not None:
                positional_args.extend(
                    LazyArgument(arg, self.properties, symint=symint)
                    for arg in getattr(func.arguments, arg_field)
                )
        self.positional_args = tuple(positional_args)

        keyword_args: List[LazyArgument] = []
        for arg_field in [
            "pre_tensor_options_kwarg_only",
            "tensor_options",
            "post_tensor_options_kwarg_only",
            "out",
        ]:
            curr_args = getattr(func.arguments, arg_field)
            if curr_args is not None:
                if isinstance(curr_args, TensorOptionsArguments):
                    curr_args = curr_args.all()
                for arg in curr_args:
                    if isGeneratorType(arg.type):
                        assert (
                            self.generator_arg is None
                        ), "We expect there is only one generator arg"
                        self.generator_arg = NamedCType(arg.name, arg.type)
                keyword_args.extend(
                    LazyArgument(arg, self.properties, symint=symint)
                    for arg in curr_args
                )
        self.keyword_args = tuple(keyword_args)
        self.name = func.name
        self.returns = func.returns

    @property
    def node_name(self) -> str:
        """
        Return the camel-case version of the op in the node.

        Note: This function also appends any `overload_name` in the operation.
        For example, if the op is `bitwise_and.Tensor`, the returned name
        will be `BitwiseAndTensor`.
        """
        op_name = f"{self.name.name}_{self.name.overload_name}".lower()
        return "".join(word.capitalize() or "" for word in op_name.split("_"))

    @property
    def aten_name(self) -> str:
        return str(self.name.name)

    @property
    def base_name(self) -> str:
        return f"{self.name.name.base}"

    def filtered_args(
        self,
        positional: bool = True,
        keyword: bool = True,
        values: bool = True,
        scalars: bool = True,
        generator: bool = False,
    ) -> List[LazyArgument]:
        # This function maintains the sorted order of arguments but provides different filtered views.
        # Some parts of the code care about kwargs vs args (TS lowerings),
        # while other parts care about whether they need to wrap the arg in a lazy value or leave it alone.
        # Generators are special-cased: they are needed for fallback/shape-inference, but are not supported
        # in TS lowerings and therefore also omitted from lazy IR.
        args: List[LazyArgument] = []
        if positional:
            args.extend(self.positional_args)
        if keyword:
            args.extend(self.keyword_args)

        if values and scalars and generator:
            return args
        elif values and scalars:
            return [a for a in args if not a.is_generator]
        elif values:
            return [a for a in args if a.is_lazy_value]
        elif scalars:
            return [
                a
                for a in args
                if not a.is_lazy_value and (generator or not a.is_generator)
            ]

        return []

    @property
    def positional_values(self) -> List[LazyArgument]:
        return self.filtered_args(
            positional=True, keyword=False, values=True, scalars=False
        )

    @property
    def positional_scalars(self) -> List[LazyArgument]:
        return self.filtered_args(
            positional=True, keyword=False, values=False, scalars=True
        )

    @property
    def keyword_values(self) -> List[LazyArgument]:
        return self.filtered_args(
            positional=False, keyword=True, values=True, scalars=False
        )

    @property
    def keyword_scalars(self) -> List[LazyArgument]:
        return self.filtered_args(
            positional=False, keyword=True, values=False, scalars=True
        )
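

# Illustrative end-to-end sketch (comments only, since running this requires a
# backend to have called setValueT() first). FunctionSchema.parse comes from
# torchgen.model; the expected results follow from the filtering logic above:
#
#     func = FunctionSchema.parse(
#         "add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor"
#     )
#     schema = LazyIrSchema(func, symint=False)
#     schema.node_name                             # "AddTensor"
#     [a.name for a in schema.positional_values]   # ["self", "other"]
#     [a.name for a in schema.keyword_values]      # ["alpha"] — Scalars are wrapped
#                                                  # as Values unless
#                                                  # TreatScalarsAsConstants is set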