lists.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536
  1. import functools
  2. import operator
  3. from typing import Dict, List, Optional
  4. import torch
  5. import torch.fx
  6. from .. import config, variables
  7. from ..bytecode_transformation import create_instruction
  8. from ..exc import unimplemented
  9. from ..source import GetItemSource
  10. from ..utils import namedtuple_fields, proxy_args_kwargs
  11. from .base import MutableLocal, VariableTracker
  12. from .constant import ConstantVariable
  class BaseListVariable(VariableTracker):
      """Shared base for trackers that model list-like Python values.

      ``self.items`` holds one ``VariableTracker`` per element.  Subclasses
      supply ``python_type()`` so the generic helpers below can rebuild a
      concrete container of the matching Python type.
      """

      @staticmethod
      def cls_for(obj):
          """Return the tracker class used for the Python type ``obj``.

          Raises:
              KeyError: if ``obj`` is not one of the supported types.
          """
          return {
              iter: ListIteratorVariable,
              list: ListVariable,
              slice: SliceVariable,
              torch.Size: SizeVariable,
              tuple: TupleVariable,
          }[obj]

      def __init__(
          self,
          items: List[VariableTracker],
          recursively_contains=None,
          regen_guards=True,
          **kwargs,
      ):
          """Store ``items`` and, unless disabled, merge their guards.

          Args:
              items: a ``list`` whose elements are all ``VariableTracker``s.
              recursively_contains: forwarded unchanged to the base class.
              regen_guards: when true, the guards gathered from ``items`` via
                  ``VariableTracker.propagate`` are added to ``self.guards``.
                  Callers that already passed those guards in set it to False.
              **kwargs: forwarded to ``VariableTracker.__init__``.
          """
          super().__init__(recursively_contains=recursively_contains, **kwargs)
          assert isinstance(items, list)
          assert all(isinstance(x, VariableTracker) for x in items)

          # Sometimes, we know that we have passed in the guards from the items in the list
          if regen_guards:
              self.guards.update(VariableTracker.propagate(items)["guards"])

          self.items: List[VariableTracker] = items

      def _as_proxy(self):
          """Return a plain list with each element's ``as_proxy()`` result."""
          return [x.as_proxy() for x in self.items]

      def as_python_constant(self):
          """Build a real container of ``python_type()`` from constant items."""
          return self.python_type()([x.as_python_constant() for x in self.items])

      def as_proxy(self):
          """Build a ``python_type()`` container holding the element proxies.

          ``torch.Size`` cannot be handled here, so ``SizeVariable`` provides its
          own ``as_proxy``.
          """
          # python_type() returns a Python type, never a tracker class, so the
          # guard has to compare against torch.Size (comparing against
          # SizeVariable could never fail).
          assert self.python_type() is not torch.Size
          return self.python_type()(self._as_proxy())

      def getitem_const(self, arg: VariableTracker):
          """Index this container with a constant ``int`` or ``slice``.

          Args:
              arg: tracker whose ``as_python_constant()`` is the index.

          Returns:
              For a slice, a clone of ``self`` over the selected items (with a
              ``GetItemSource`` when ``self`` has a source, and a fresh
              ``MutableLocal`` when ``self`` is mutable); for an int, the
              selected item.  Options from ``arg`` and ``self`` are added.
          """
          index = arg.as_python_constant()
          if not isinstance(index, slice):
              assert isinstance(index, int)
              return self.items[index].add_options(arg, self)

          clone_kwargs = {
              "items": self.items[index],
              "mutable_local": MutableLocal() if self.mutable_local else None,
          }
          if self.source is not None:
              clone_kwargs["source"] = GetItemSource(self.source, index)
          return self.clone(**clone_kwargs).add_options(arg, self)

      def unpack_var_sequence(self, tx):
          """Return the items, each with this container's options added."""
          return [x.add_options(self) for x in self.items]

      def call_method(
          self,
          tx,
          name,
          args: "List[VariableTracker]",
          kwargs: "Dict[str, VariableTracker]",
      ) -> "VariableTracker":
          """Handle ``__getitem__`` and constant ``__contains__``.

          Any other method name is delegated to the base class.
          """
          options = VariableTracker.propagate(self, args, kwargs.values())
          if name == "__getitem__":
              assert not kwargs and len(args) == 1
              return self.getitem_const(args[0])
          elif (
              name == "__contains__"
              and len(args) == 1
              and args[0].is_python_constant()
              and all(x.is_python_constant() for x in self.items)
          ):
              # Only decidable here when both the needle and every element
              # are Python constants.
              assert not kwargs
              search = args[0].as_python_constant()
              result = any(x.as_python_constant() == search for x in self.items)
              return variables.ConstantVariable(result, **options)

          return super().call_method(tx, name, args, kwargs)

      @staticmethod
      def list_compare(tx, op, left, right):
          """Evaluate ``left <op> right`` for ``operator.eq`` / ``operator.ne``.

          Any other operator is reported through ``unimplemented``.
          """
          from .builtin import BuiltinVariable

          eq_result = BaseListVariable.list_eq(tx, left, right)
          if op is operator.eq:
              return eq_result
          elif op is operator.ne:
              return BuiltinVariable(operator.not_).call_function(tx, [eq_result], {})
          else:
              unimplemented(f"list_compare {left} {op} {right}")

      @staticmethod
      def list_eq(tx, left, right):
          """Element-wise equality of two list-like trackers.

          Returns a constant ``False`` when the lengths differ or as soon as one
          pair compares constant-false, a constant ``True`` when both are empty,
          and otherwise the ``operator.and_`` fold of the pairwise comparisons.
          """
          from .builtin import BuiltinVariable

          options = VariableTracker.propagate(left, right)
          # Most list-like variables implement comparison ops the same way,
          # so they can re-use this helper.
          # There are quirks though, like how `tuple([2]) == torch.Size([2])`,
          # but `tuple([2]) != list([2])`
          if len(left.items) != len(right.items):
              return ConstantVariable(False, **options)
          if len(left.items) == 0:
              return ConstantVariable(True, **options)

          # Generic list comparison works by iterating over left aka self and
          # right the compared-to list.  If we hit here, their lengths are the
          # same and non-zero, so we iterate over the zipped list items.
          comps = []
          for lhs, rhs in zip(left.items, right.items):
              comp = BuiltinVariable(operator.eq).call_function(tx, [lhs, rhs], {})
              if comp.is_python_constant() and not comp.as_python_constant():
                  # early exit in false case
                  return comp.add_options(options)
              comps.append(comp)

          return functools.reduce(
              lambda a, b: BuiltinVariable(operator.and_).call_function(tx, [a, b], {}),
              comps,
          ).add_options(options)
  class RangeVariable(BaseListVariable):
      """Tracks a ``range`` object as the three trackers ``[start, stop, step]``."""

      def __init__(self, items, **kwargs):
          """Normalize one to three ``range`` arguments.

          A missing ``start`` becomes a constant 0 and a missing ``step`` a
          constant 1.  Any other argument count raises ``AssertionError``.
          """
          start = variables.ConstantVariable(0)
          stop = None
          step = variables.ConstantVariable(1)

          arg_count = len(items)
          if arg_count not in (1, 2, 3):
              raise AssertionError()
          if arg_count == 1:
              (stop,) = items
          elif arg_count == 2:
              start, stop = items
          else:
              start, stop, step = items

          assert stop is not None
          super().__init__([start, stop, step], **kwargs)

      def python_type(self):
          """Return the Python type being modelled."""
          return range

      def as_python_constant(self):
          """Materialize a real ``range`` from the constant start/stop/step."""
          bounds = [x.as_python_constant() for x in self.items]
          return range(*bounds)

      def as_proxy(self):
          """Build a ``range`` from the proxies of start/stop/step."""
          range_cls = self.python_type()
          return range_cls(*self._as_proxy())

      def unpack_var_sequence(self, tx):
          """Expand the range into one constant tracker per produced value."""
          unpacked = []
          for value in self.as_python_constant():
              unpacked.append(variables.ConstantVariable(value).add_options(self))
          return unpacked

      def reconstruct(self, codegen):
          """Emit bytecode that loads ``range`` and calls it with three args."""
          assert "range" not in codegen.tx.f_globals
          codegen.append_output(codegen.create_load_python_module(range))
          codegen.foreach(self.items)
          return [create_instruction("CALL_FUNCTION", 3)]

      def var_getattr(self, tx, name):
          """Return the tracker for ``start``, ``stop`` or ``step``.

          Any other attribute name is reported through ``unimplemented``.
          """
          attr_names = ["start", "stop", "step"]
          if name not in attr_names:
              unimplemented(f"range.{name}")
          position = attr_names.index(name)
          return self.items[position].add_options(self)
  class ListVariable(BaseListVariable):
      """Tracks a Python ``list``, including in-place mutation of local lists."""

      def python_type(self):
          """Return the Python type being modelled."""
          return list

      def reconstruct(self, codegen):
          """Emit bytecode that rebuilds the list from its items."""
          codegen.foreach(self.items)
          return [create_instruction("BUILD_LIST", len(self.items))]

      def call_method(
          self,
          tx,
          name,
          args: "List[VariableTracker]",
          kwargs: "Dict[str, VariableTracker]",
      ) -> "VariableTracker":
          """Model the mutating list methods on lists that have a mutable_local.

          Handled names: ``append``, ``extend``, ``insert``, ``pop``, ``clear``
          and ``__setitem__`` (constant key only).  Each one builds a new
          ``ListVariable`` and swaps it in with ``tx.replace_all``.  Everything
          else, and any call on a list without ``mutable_local``, is delegated
          to ``BaseListVariable.call_method``.

          NOTE(review): ``append`` returns a ``None`` constant and ``pop`` the
          removed item, but ``extend``/``insert``/``clear``/``__setitem__``
          return whatever ``tx.replace_all`` returns -- confirm that is intended.
          """
          options = VariableTracker.propagate(self, args, kwargs.values())

          # Every specialised branch below requires a locally mutable list.
          if not self.mutable_local:
              return super().call_method(tx, name, args, kwargs)

          if name == "append":
              assert not kwargs
              (new_item,) = args
              contained = self.recursively_contains.union(
                  new_item.recursively_contains
              )
              if new_item.mutable_local is not None:
                  contained.add(new_item.mutable_local)
              tx.replace_all(
                  self,
                  ListVariable(
                      self.items + [new_item],
                      recursively_contains=contained,
                      regen_guards=False,
                      **options,
                  ),
              )
              return ConstantVariable(None)

          if name == "extend" and args and args[0].has_unpack_var_sequence(tx):
              assert not kwargs
              (other,) = args
              return tx.replace_all(
                  self,
                  ListVariable(
                      list(self.items) + list(other.unpack_var_sequence(tx)),
                      regen_guards=False,
                      **options,
                  ),
              )

          if name == "insert":
              assert not kwargs
              position, new_item = args
              updated = list(self.items)
              updated.insert(position.as_python_constant(), new_item)
              return tx.replace_all(
                  self,
                  ListVariable(updated, regen_guards=False, **options),
              )

          if name == "pop":
              assert not kwargs
              updated = list(self.items)
              # Zero or one constant index, exactly as list.pop accepts.
              removed = updated.pop(*[a.as_python_constant() for a in args])
              tx.replace_all(
                  self,
                  ListVariable(updated, regen_guards=False, **options),
              )
              return removed

          if name == "clear":
              assert not kwargs and not args
              return tx.replace_all(
                  self,
                  ListVariable([], regen_guards=False, **options),
              )

          if name == "__setitem__" and args and args[0].is_python_constant():
              assert not kwargs
              key, value = args
              updated = list(self.items)
              if isinstance(key, SliceVariable):
                  # Slice assignment splices in the elements of ``value``.
                  updated[key.as_python_constant()] = list(value.items)
              else:
                  updated[key.as_python_constant()] = value
              replacement = ListVariable(updated, regen_guards=False, **options)
              return tx.replace_all(self, replacement)

          return super().call_method(tx, name, args, kwargs)
  class TupleVariable(BaseListVariable):
      """Tracks a Python ``tuple``."""

      def python_type(self):
          """Return the Python type being modelled."""
          return tuple

      def reconstruct(self, codegen):
          """Emit bytecode that rebuilds the tuple from its items."""
          codegen.foreach(self.items)
          item_count = len(self.items)
          return [create_instruction("BUILD_TUPLE", item_count)]

      def call_method(
          self,
          tx,
          name,
          args: "List[VariableTracker]",
          kwargs: "Dict[str, VariableTracker]",
      ) -> "VariableTracker":
          """Forward every method call unchanged to ``BaseListVariable``."""
          result = super().call_method(tx, name, args, kwargs)
          return result
  class SizeVariable(TupleVariable):
      """torch.Size(...)"""

      def __init__(
          self,
          items: List[VariableTracker],
          proxy: Optional[torch.fx.Proxy] = None,
          **kwargs,
      ):
          """Store the dimension trackers and an optional pre-built proxy.

          Args:
              items: one tracker per dimension.
              proxy: if given, ``as_proxy()`` returns it directly instead of
                  constructing a new FX node.
              **kwargs: forwarded to ``TupleVariable.__init__``.
          """
          self.proxy = proxy
          super().__init__(items, **kwargs)

      def python_type(self):
          """Return the Python type being modelled."""
          return torch.Size

      def as_proxy(self):
          """Return an FX representation of this size.

          Returns the stored proxy when there is one; a concrete ``torch.Size``
          when no element is an FX ``Proxy``; otherwise a new ``call_function``
          proxy for ``torch.Size`` carrying an ``example_value``.
          """
          if self.proxy is not None:
              return self.proxy

          # torch.Size needs special handling. Normally, we pun a list-like
          # container to directly contain Proxy/Node objects from FX, and FX
          # knows to look inside containers (via map_aggregate). But torch.Size
          # is weird; although it subclasses from tuple, it doesn't allow
          # members which aren't int-like (rejecting Proxy and Node), so the
          # usual torch.Size([proxy0, proxy1]) representation is unavailable.
          #
          # To work around this, a torch.Size is represented as a single proxy
          # that would have been constructed by taking the constituent proxies
          # as arguments. The same trick works for any construct that needs a
          # proxy but cannot be represented directly as an aggregate.

          # Look for a proxy. If there are none, do the legacy behavior
          proxies = self._as_proxy()
          tracer = next(
              (p.tracer for p in proxies if isinstance(p, torch.fx.Proxy)),
              None,
          )
          if tracer is None:
              return torch.Size(proxies)

          proxy = tracer.create_proxy("call_function", torch.Size, (proxies,), {})
          # A size may mix FX proxies with plain ints; only proxies carry a
          # node with an example value, other entries are used as they are.
          proxy.node.meta["example_value"] = torch.Size(
              [
                  p.node.meta["example_value"] if isinstance(p, torch.fx.Proxy) else p
                  for p in proxies
              ]
          )
          return proxy

      def reconstruct(self, codegen):
          """Emit bytecode equivalent to ``torch.Size((item0, item1, ...))``."""
          codegen.load_import_from("torch", "Size")
          codegen.foreach(self.items)
          build_torch_size = [
              create_instruction("BUILD_TUPLE", len(self.items)),
              create_instruction("CALL_FUNCTION", 1),
          ]
          return build_torch_size

      def unpack_var_sequence(self, tx):
          """Return the dimension trackers with this variable's options added."""
          return [x.add_options(self) for x in self.items]

      def call_method(
          self,
          tx,
          name,
          args: "List[VariableTracker]",
          kwargs: "Dict[str, VariableTracker]",
      ) -> "VariableTracker":
          """Handle ``__getitem__``; delegate everything else to the parent.

          With ``config.dynamic_shapes`` enabled, indexing goes through
          ``get_item_dyn``; otherwise through ``getitem_const``.
          """
          if name == "__getitem__":
              assert not kwargs and len(args) == 1
              if config.dynamic_shapes:
                  out = self.get_item_dyn(tx, args[0])
              else:
                  out = self.getitem_const(args[0])
              return out
          return super().call_method(tx, name, args, kwargs)

      def get_item_dyn(self, tx, arg: VariableTracker):
          """Index with a constant ``int`` or ``slice``, recording slices in FX.

          Args:
              tx: the translator; its ``output.create_proxy`` records the slice.
              arg: tracker whose ``as_python_constant()`` is the index.

          Returns:
              For a slice, a new ``SizeVariable`` over the selected items backed
              by a ``call_function`` proxy; for an int, the selected item.
          """
          index = arg.as_python_constant()
          if isinstance(index, slice):

              def _dynamo_get_item_lambda(target, index):
                  return torch.Size.__getitem__(target, index)

              parent_proxy = self.as_proxy()
              proxy = tx.output.create_proxy(
                  "call_function",
                  _dynamo_get_item_lambda,
                  *proxy_args_kwargs([self, arg], {}),
              )
              items = self.items[index]

              # Mirror the indexing into example_value for downstream correctness
              proxy.node.meta["example_value"] = parent_proxy.node.meta["example_value"][
                  index
              ]
              return SizeVariable(items, proxy=proxy).add_options(arg, self)
          else:
              assert isinstance(index, int)
              return self.items[index].add_options(arg, self)
  class ShapeVariable(TupleVariable):
      """Marker subclass for the result of ``tensor.shape(...)``.

      Adds no behavior of its own; it exists so that a shape can be told apart
      from an ordinary constant ``TupleVariable``.
      """
  class NamedTupleVariable(TupleVariable):
      """Tracks an instance of a named-tuple-like class ``tuple_cls``."""

      def __init__(self, items, tuple_cls, **kwargs):
          """Store the field trackers and the concrete tuple class.

          Args:
              items: one tracker per field, in field order.
              tuple_cls: the class that ``python_type()`` reports.
              **kwargs: forwarded to ``TupleVariable.__init__``.
          """
          super().__init__(items, **kwargs)
          self.tuple_cls = tuple_cls

      def python_type(self):
          """Return the concrete tuple class being modelled."""
          return self.tuple_cls

      def as_python_constant(self):
          """Instantiate ``tuple_cls`` with the constant field values."""
          return self.python_type()(*[x.as_python_constant() for x in self.items])

      def as_proxy(self):
          """Build a ``tuple_cls`` instance holding the field proxies.

          The inherited implementation calls ``tuple_cls(list_of_proxies)``,
          which passes the whole list as a single positional argument and so
          fails for named tuples that do not have exactly one field.  Use the
          same constructor choice as ``reconstruct``: ``_make`` when the class
          has it, otherwise the class itself, each given one iterable.
          """
          create_fn = getattr(self.tuple_cls, "_make", self.tuple_cls)
          return create_fn(self._as_proxy())

      def reconstruct(self, codegen):
          """Emit bytecode that calls ``_make`` (or the class) on a tuple of items."""
          create_fn = getattr(self.tuple_cls, "_make", self.tuple_cls)
          codegen.append_output(codegen._create_load_const(create_fn))
          codegen.foreach(self.items)
          return [
              create_instruction("BUILD_TUPLE", len(self.items)),
              create_instruction("CALL_FUNCTION", 1),
          ]

      def var_getattr(self, tx, name):
          """Return the tracker for field ``name``.

          Names that are not fields are reported through ``unimplemented``.
          """
          fields = namedtuple_fields(self.tuple_cls)
          if name not in fields:
              unimplemented(f"NamedTupleVariable.{name}")
          return self.items[fields.index(name)].add_options(self)

      def call_hasattr(self, tx, name: str) -> "VariableTracker":
          """Return a constant telling whether ``name`` is one of the fields."""
          options = VariableTracker.propagate(self)
          fields = namedtuple_fields(self.tuple_cls)
          return variables.ConstantVariable(name in fields, **options)
  class SliceVariable(BaseListVariable):
      """Tracks a ``slice`` object as the three trackers ``[start, stop, step]``."""

      def __init__(self, items, **kwargs):
          """Normalize one to three ``slice`` arguments.

          Missing components default to a constant ``None``.  Any other
          argument count raises ``AssertionError``.  A ``TensorVariable`` as
          start or stop is reported through ``unimplemented``.
          """
          # One shared None-constant tracker fills every missing component.
          none_var = variables.ConstantVariable(None)
          start = stop = step = none_var

          arg_count = len(items)
          if arg_count not in (1, 2, 3):
              raise AssertionError()
          if arg_count == 1:
              (stop,) = items
          elif arg_count == 2:
              start, stop = items
          else:
              start, stop, step = items

          tensor_cls = variables.TensorVariable
          if isinstance(start, tensor_cls) or isinstance(stop, tensor_cls):
              unimplemented("Dynamic slicing on data-dependent value is not supported")

          super().__init__([start, stop, step], **kwargs)

      def as_proxy(self):
          """Build a ``slice`` from the proxies of start/stop/step."""
          return slice(*self._as_proxy())

      def python_type(self):
          """Return the Python type being modelled."""
          return slice

      def as_python_constant(self):
          """Materialize a real ``slice`` from the constant components."""
          components = [x.as_python_constant() for x in self.items]
          return slice(*components)

      def reconstruct(self, codegen):
          """Emit bytecode that rebuilds the slice from its components."""
          codegen.foreach(self.items)
          component_count = len(self.items)
          return [create_instruction("BUILD_SLICE", component_count)]

      def var_getattr(self, tx, name):
          """Return the tracker for ``start``, ``stop`` or ``step``.

          Any other attribute name is reported through ``unimplemented``.
          """
          attr_names = ["start", "stop", "step"]
          if name not in attr_names:
              unimplemented(f"slice.{name}")
          position = attr_names.index(name)
          return self.items[position].add_options(self)
  class ListIteratorVariable(VariableTracker):
      """Tracks an iterator over a list of trackers plus the current position."""

      def __init__(self, items, index: int = 0, recursively_contains=None, **kwargs):
          """Store the underlying items and the position of the next element.

          Args:
              items: a ``list`` of element trackers.
              index: position of the element that will be produced next.
              recursively_contains: forwarded unchanged to the base class.
              **kwargs: forwarded to ``VariableTracker.__init__``.
          """
          super().__init__(recursively_contains=recursively_contains, **kwargs)
          assert isinstance(items, list)
          # The per-element isinstance check is intentionally skipped because it
          # slows things down too much:
          # https://github.com/pytorch/pytorch/pull/87533#issuecomment-1287574492
          self.items = items
          self.index = index

      def next_variables(self):
          """Return ``(next_item, advanced_iterator)``.

          Raises:
              StopIteration: when the iterator is exhausted.
          """
          assert self.mutable_local
          if self.index >= len(self.items):
              raise StopIteration()

          current = self.items[self.index].add_options(self)
          # The advanced iterator is a new tracker; this one is left unchanged.
          advanced = ListIteratorVariable(
              self.items,
              self.index + 1,
              mutable_local=MutableLocal(),
              recursively_contains=self.recursively_contains,
              **VariableTracker.propagate([self]),
          )
          return current, advanced

      def as_python_constant(self):
          """Return a real iterator over the constant items.

          Raises:
              NotImplementedError: if the iterator has already been advanced.
          """
          if self.index > 0:
              raise NotImplementedError()
          constants = [x.as_python_constant() for x in self.items]
          return iter(constants)

      def unpack_var_sequence(self, tx):
          """Return the not-yet-consumed items with this variable's options."""
          pending = self.items[self.index :]
          return [x.add_options(self) for x in pending]

      def reconstruct(self, codegen):
          """Emit bytecode building an iterator over the remaining items."""
          remaining_items = self.items[self.index :]
          codegen.foreach(remaining_items)
          remaining_count = len(remaining_items)
          return [
              create_instruction("BUILD_TUPLE", remaining_count),
              create_instruction("GET_ITER"),
          ]