builtin.py

import functools
import inspect
import itertools
import logging
import math
import operator
import types
from typing import Dict, List

import torch
from torch import sym_float, sym_int

from .. import config, variables
from ..allowed_functions import is_allowed
from ..exc import unimplemented, Unsupported
from ..guards import GuardBuilder
from ..replay_record import DummyModule
from ..source import AttrSource, is_constant_source, SuperSource, TypeSource
from ..utils import (
    check_constant_args,
    check_unspec_python_args,
    istype,
    proxy_args_kwargs,
    specialize_args_kwargs,
)
from .base import MutableLocal, typestr, VariableTracker
from .constant import ConstantVariable
from .dicts import ConstDictVariable
from .lists import BaseListVariable, ListVariable, TupleVariable
from .tensor import FakeItemVariable, SymNodeVariable, UnspecializedPythonVariable
from .user_defined import UserDefinedVariable

log = logging.getLogger(__name__)
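
# BuiltinVariable models a call to a Python builtin or an `operator` module
# function encountered while Dynamo traces user code. Roughly, a call such as
# `len(xs)` is handled as
#     BuiltinVariable(len).call_function(tx, [xs_variable], {})
# which then either constant-folds, inserts an FX node, or dispatches to one of
# the `call_<name>` helpers defined below. (Illustrative sketch of the call
# pattern, not an exact transcript of the tracer's internals.)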

class BuiltinVariable(VariableTracker):
    @staticmethod
    @functools.lru_cache(None)
    def _constant_fold_functions():
        fns = {
            abs,
            all,
            any,
            bool,
            callable,
            chr,
            dict,
            divmod,
            float,
            int,
            len,
            list,
            max,
            min,
            ord,
            pow,
            repr,
            round,
            set,
            str,
            str.format,
            sum,
            tuple,
            type,
            operator.pos,
            operator.neg,
            operator.not_,
            operator.invert,
            operator.pow,
            operator.mul,
            operator.matmul,
            operator.floordiv,
            operator.truediv,
            operator.mod,
            operator.add,
            operator.sub,
            operator.getitem,
            operator.lshift,
            operator.rshift,
            operator.and_,
            operator.or_,
            operator.xor,
            operator.ipow,
            operator.imul,
            operator.imatmul,
            operator.ifloordiv,
            operator.itruediv,
            operator.imod,
            operator.iadd,
            operator.isub,
            operator.ilshift,
            operator.irshift,
            operator.iand,
            operator.ixor,
            operator.ior,
            operator.index,
        }
        fns.update(x for x in math.__dict__.values() if isinstance(x, type(math.sqrt)))
        return fns

    def can_constant_fold_through(self):
        return self.fn in self._constant_fold_functions()

    @staticmethod
    @functools.lru_cache(None)
    def _fx_graph_functions():
        fns = {
            operator.pos,
            operator.neg,
            operator.not_,
            operator.invert,
            operator.pow,
            operator.mul,
            operator.matmul,
            operator.floordiv,
            operator.truediv,
            operator.mod,
            operator.add,
            operator.sub,
            operator.getitem,
            operator.lshift,
            operator.rshift,
            operator.and_,
            operator.or_,
            operator.xor,
            operator.ipow,
            operator.imul,
            operator.imatmul,
            operator.ifloordiv,
            operator.itruediv,
            operator.imod,
            operator.iadd,
            operator.isub,
            operator.ilshift,
            operator.irshift,
            operator.iand,
            operator.ixor,
            operator.ior,
        }
        return fns

    @staticmethod
    @functools.lru_cache(None)
    def _reversible_binops():
        # function -> (forward magic method name, reverse magic method name)
        fns = {
            operator.add: ("__add__", "__radd__"),
            operator.sub: ("__sub__", "__rsub__"),
            operator.mul: ("__mul__", "__rmul__"),
            operator.truediv: ("__truediv__", "__rtruediv__"),
            operator.floordiv: ("__floordiv__", "__rfloordiv__"),
            operator.mod: ("__mod__", "__rmod__"),
            pow: ("__pow__", "__rpow__"),
            operator.pow: ("__pow__", "__rpow__"),
            # Don't support these for now, since the corresponding reverse magic methods
            # aren't defined on SymInt / SymFloat.
            # operator.matmul: ("__matmul__", "__rmatmul__"),
            # divmod: ("__divmod__", "__rdivmod__"),
            # operator.lshift: ("__lshift__", "__rlshift__"),
            # operator.rshift: ("__rshift__", "__rrshift__"),
            # operator.and_: ("__and__", "__rand__"),
            # operator.or_: ("__or__", "__ror__"),
            # operator.xor: ("__xor__", "__rxor__"),
        }
        return fns

    @staticmethod
    @functools.lru_cache(None)
    def _inplace_binops():
        fns = {
            operator.ipow: "__ipow__",
            operator.imul: "__imul__",
            operator.imatmul: "__imatmul__",
            operator.ifloordiv: "__ifloordiv__",
            operator.itruediv: "__itruediv__",
            operator.imod: "__imod__",
            operator.iadd: "__iadd__",
            operator.iconcat: "__iconcat__",
            operator.isub: "__isub__",
            operator.ilshift: "__ilshift__",
            operator.irshift: "__irshift__",
            operator.iand: "__iand__",
            operator.ixor: "__ixor__",
            operator.ior: "__ior__",
        }
        return fns
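
    # The handler table built below is keyed by the operator function and holds
    # (type-pair, handler) entries that are tried in order. For example,
    #     BuiltinVariable(operator.add).call_function(
    #         tx, [TupleVariable(...), TupleVariable(...)], {}
    #     )
    # matches the (TupleVariable, TupleVariable) entry and concatenates the
    # items, while a SymNodeVariable operand falls through to dynamic_handler,
    # which records the op as a call_function node in the FX graph.
    # (Illustrative summary of the dispatch below, not additional behavior.)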

    @staticmethod
    @functools.lru_cache(None)
    def _binop_handlers():
        # Multiple dispatch mechanism defining custom binop behavior for certain type
        # combinations. Handlers are attempted in order, and will be used if the type checks
        # match. They are expected to have the signature:
        #   fn(tx, arg0: VariableTracker, arg1: VariableTracker, options) -> VariableTracker
        # Override table contains: op_fn -> [list of handlers]
        op_handlers = {}
        for (op, magic_method_names) in itertools.chain(
            BuiltinVariable._inplace_binops().items(),
            BuiltinVariable._reversible_binops().items(),
        ):
            handlers = []

            # User-defined args (highest precedence)
            if isinstance(magic_method_names, tuple):
                # Reversible binary ops have forward / backward magic methods
                forward_name, reverse_name = magic_method_names

                def user_defined_handler(
                    tx,
                    a,
                    b,
                    options,
                    forward_name=forward_name,
                    reverse_name=reverse_name,
                ):
                    # Manually handle reversing logic if needed (e.g. call __radd__)

                    # TODO: If we expand this to handle tensor args, we need to manually
                    # handle cases like this:
                    #
                    #   class A(int):
                    #       def __radd__(self, other):
                    #           print("woof")
                    #   torch.randn(3) + A(3)
                    #
                    # In this example, A.__radd__() is not called -> nothing is printed, because
                    # Tensor.__add__ only does a subtype test against int, ignoring the subclass.
                    # To be fully correct, we should not call A.__radd__() here, and there may be
                    # other cases to reason about and add exceptions for.
                    if isinstance(a, UserDefinedVariable):
                        return a.call_method(tx, forward_name, [b], {})
                    else:
                        return b.call_method(tx, reverse_name, [a], {})

            else:
                forward_name = magic_method_names

                def user_defined_handler(tx, a, b, options, forward_name=forward_name):
                    return a.call_method(tx, forward_name, [b], {})

            handlers.append(
                ((UserDefinedVariable, VariableTracker), user_defined_handler)
            )
            handlers.append(
                ((VariableTracker, UserDefinedVariable), user_defined_handler)
            )

            # Dynamic shape args
            def dynamic_handler(tx, a, b, options, fn=op):
                from .builder import wrap_fx_proxy

                return wrap_fx_proxy(
                    tx,
                    tx.output.create_proxy(
                        "call_function", fn, *proxy_args_kwargs([a, b], {})
                    ),
                    **options,
                )

            handlers.append(((SymNodeVariable, VariableTracker), dynamic_handler))
            handlers.append(((VariableTracker, SymNodeVariable), dynamic_handler))

            op_handlers[op] = handlers

        # Special cases - lower precedence but still prefer these over constant folding

        # List-like addition (e.g. [1, 2] + [3, 4])
        def tuple_add_handler(tx, a, b, options):
            return TupleVariable(a.items + list(b.unpack_var_sequence(tx)), **options)

        list_like_addition_handlers = [
            # NB: Prefer the tuple-specific logic over base logic because of
            # some SizeVariable weirdness. Specifically, the tuple-specific logic
            # drops the subclass type (e.g. SizeVariable) and returns TupleVariables.
            (
                (TupleVariable, TupleVariable),
                tuple_add_handler,
            ),
            (
                (TupleVariable, ConstantVariable),
                tuple_add_handler,
            ),
            (
                (ConstantVariable, TupleVariable),
                lambda tx, a, b, options: TupleVariable(
                    list(a.unpack_var_sequence(tx)) + b.items, **options
                ),
            ),
            (
                (BaseListVariable, BaseListVariable),
                lambda tx, a, b, options: type(a)(a.items + b.items, **options),
            ),
        ]
        op_handlers[operator.add].extend(list_like_addition_handlers)

        def list_iadd_handler(tx, a, b, options):
            if not a.mutable_local or not b.has_unpack_var_sequence(tx):
                # Handler doesn't apply
                return None

            return tx.replace_all(
                a,
                ListVariable(
                    list(a.items) + list(b.unpack_var_sequence(tx)),
                    regen_guards=False,
                    **options,
                ),
            )

        list_like_iadd_handlers = [
            (
                (ListVariable, VariableTracker),
                list_iadd_handler,
            ),
            (
                (TupleVariable, TupleVariable),
                tuple_add_handler,
            ),
            (
                (TupleVariable, ConstantVariable),
                tuple_add_handler,
            ),
        ]
        op_handlers[operator.iadd].extend(list_like_iadd_handlers)

        # List-like expansion (e.g. [1, 2, 3] * 3)
        def expand_list_like(tx, lst, const, options):
            return lst.__class__(
                items=lst.items * const.as_python_constant(),
                mutable_local=MutableLocal(),
                **options,
            )

        list_like_expansion_handlers = [
            ((ListVariable, ConstantVariable), expand_list_like),
            ((TupleVariable, ConstantVariable), expand_list_like),
            (
                (ConstantVariable, ListVariable),
                lambda tx, a, b, options: expand_list_like(tx, b, a, options),
            ),
            (
                (ConstantVariable, TupleVariable),
                lambda tx, a, b, options: expand_list_like(tx, b, a, options),
            ),
        ]
        op_handlers[operator.mul].extend(list_like_expansion_handlers)

        return op_handlers

    @staticmethod
    def _find_binop_handler(op, a, b):
        handlers = BuiltinVariable._binop_handlers()
        if op not in handlers:
            return None

        # Return first handler that matches the type checks
        for ((type1, type2), handler) in handlers[op]:
            if isinstance(a, type1) and isinstance(b, type2):
                return handler

        return None

    def can_insert_in_graph(self):
        return self.fn in self._fx_graph_functions()

    def __init__(self, fn, **kwargs):
        super().__init__(**kwargs)
        self.fn = fn

    def __str__(self):
        if self.fn is None:
            name = "None"
        else:
            name = self.fn.__name__

        return f"{self.__class__.__name__}({name})"

    def python_type(self):
        return type(self.fn)

    def as_python_constant(self):
        return self.fn

    def reconstruct(self, codegen):
        name = self.fn.__name__
        assert self.fn.__module__ == "builtins"
        assert name not in codegen.tx.f_globals, "shadowed global"
        return [codegen.create_load_global(name, add=True)]

    def constant_args(self, *args, **kwargs):
        return check_constant_args(args, kwargs)

    def tensor_args(self, *args, **kwargs):
        return any(
            isinstance(i, variables.TensorVariable)
            for i in itertools.chain(args, kwargs.values())
        ) and not any(
            isinstance(i, variables.GetAttrVariable)
            for i in itertools.chain(args, kwargs.values())
        )

    def unspec_python_args(self, *args, **kwargs):
        return check_unspec_python_args(args, kwargs)

    @staticmethod
    def unwrap_unspec_args_kwargs(args, kwargs):
        unwrapped_args = []
        unwrapped_kwargs = {}
        for x in args:
            if isinstance(
                x,
                (variables.UnspecializedPythonVariable,),
            ):
                unwrapped_args.append(x.raw_value)
            else:
                unwrapped_args.append(x.as_python_constant())
        for k, v in kwargs.items():
            if isinstance(
                v,
                (variables.UnspecializedPythonVariable,),
            ):
                unwrapped_kwargs.update({k: v.raw_value})
            else:
                unwrapped_kwargs.update({k: v.as_python_constant()})
        return unwrapped_args, unwrapped_kwargs
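
    # call_function() is the main entry point and tries, in order:
    #   1. graph insertion for tensor args when `fn` is in _fx_graph_functions(),
    #   2. int() / float() on a SymNodeVariable via sym_int / sym_float,
    #   3. the binop handler table for reversible / in-place binary ops,
    #   4. a type-specific `call_<fn.__name__>` helper if one is defined,
    #   5. constant folding when all args are constant (or unspec python) values.
    # A stage that returns None simply falls through to the next one.
    # (Summary of the dispatch below.)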

    def call_function(
        self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
    ) -> "VariableTracker":
        from .builder import wrap_fx_proxy, wrap_fx_proxy_cls

        constant_args = check_constant_args(args, kwargs)
        tensor_args = self.tensor_args(*args, **kwargs)
        unspec_python_args = self.unspec_python_args(*args, **kwargs)
        options = VariableTracker.propagate(self, args, kwargs.values())
        has_constant_handler = self.can_constant_fold_through() and (
            constant_args or unspec_python_args
        )
        assert isinstance(args, (list, tuple))
        assert isinstance(kwargs, dict)

        if (
            self.fn is operator.getitem
            and len(args) == 2
            and isinstance(args[1], variables.TensorVariable)
            and args[1].dtype == torch.bool
            and not config.dynamic_shapes
        ):
            unimplemented("dynamic Tensor.__getitem__(bool[])")

        # args[0] is list and args[1] is unspec
        if self.fn is operator.getitem and not isinstance(
            args[0], variables.TensorVariable
        ):
            tensor_args = False
            args, kwargs = specialize_args_kwargs(tx, args, kwargs)

        if (
            self.can_insert_in_graph()
            and tensor_args
            and not (
                self.fn is operator.getitem
                and isinstance(args[0], ConstDictVariable)
                and isinstance(args[1], variables.TensorVariable)
            )
        ):
            try:
                fn = self.fn
                if self.fn is operator.iadd and isinstance(
                    args[0], variables.ConstantVariable
                ):
                    # Work around weird bug in hf_T5
                    fn, args = operator.add, [args[1], args[0]]

                if self.fn is operator.not_:
                    fn = torch.logical_not

                proxy = tx.output.create_proxy(
                    "call_function",
                    fn,
                    *proxy_args_kwargs(args, kwargs),
                )
                if any([isinstance(arg, FakeItemVariable) for arg in args]):
                    return wrap_fx_proxy_cls(
                        FakeItemVariable,
                        tx,
                        proxy,
                        **options,
                    )
                elif self.unspec_python_args(*args, **kwargs):
                    _args, _kwargs = self.unwrap_unspec_args_kwargs(args, kwargs)
                    raw_value = self.fn(*_args, **_kwargs)

                    need_unwrap = any(
                        x.need_unwrap
                        for x in itertools.chain(args, kwargs.values())
                        if isinstance(x, variables.UnspecializedPythonVariable)
                    )

                    return wrap_fx_proxy_cls(
                        UnspecializedPythonVariable,
                        tx,
                        proxy,
                        raw_value=raw_value,
                        need_unwrap=need_unwrap,
                        **options,
                    )
                elif all(isinstance(x, SymNodeVariable) for x in args):
                    return SymNodeVariable.create(tx, proxy, None, **options)
                else:
                    # Work around for vision_maskrcnn due to precision difference
                    # specialize the dividend when float divide by tensor
                    if self.fn is operator.truediv and isinstance(
                        args[0], variables.UnspecializedPythonVariable
                    ):
                        args[0] = args[0].convert_to_constant(tx)
                    return wrap_fx_proxy(tx, proxy, **options)

            except NotImplementedError:
                unimplemented(f"partial tensor op: {self} {args} {kwargs}")

        # Handle cases like int(torch.seed())
        # Also handle sym_float to sym_int cases
        if self.fn in (int, float) and isinstance(args[0], SymNodeVariable):
            fn_ = sym_int if self.fn is int else sym_float
            out = wrap_fx_proxy(
                tx=tx,
                proxy=tx.output.create_proxy(
                    "call_function",
                    fn_,
                    (args[0].as_proxy(),),
                    {},
                ),
                **options,
            )
            return out

        # Handle binary ops (e.g. __add__ / __radd__, __iadd__, etc.)
        # NB: Tensor args are handled above and not here
        if self.fn in self._reversible_binops() or self.fn in self._inplace_binops():
            assert len(kwargs) == 0 and len(args) == 2

            # Try to find a handler for the arg types; otherwise, fall through to constant handler
            binop_handler = BuiltinVariable._find_binop_handler(
                self.fn, args[0], args[1]
            )
            if binop_handler:
                res = binop_handler(tx, args[0], args[1], options)
                if res is not None:
                    return res

        handler = getattr(self, f"call_{self.fn.__name__}", None)

        if handler:
            try:
                inspect.signature(handler).bind(tx, *args, **kwargs)
            except TypeError as exc:
                if not has_constant_handler:
                    log.warning(
                        f"incorrect arg count {handler} {exc} and no constant handler"
                    )
                handler = None

        if handler:
            try:
                result = handler(tx, *args, **kwargs)
                if result is not None:
                    return result.add_options(options)
            except Unsupported as exc:
                if not has_constant_handler:
                    raise
                # Actually, we will handle this just fine
                exc.remove_from_stats()

        if has_constant_handler:
            args, kwargs = specialize_args_kwargs(tx, args, kwargs)
            # constant fold
            return variables.ConstantVariable(
                self.as_python_constant()(
                    *[x.as_python_constant() for x in args],
                    **{k: v.as_python_constant() for k, v in kwargs.items()},
                ),
                **options,
            )

        return super().call_function(tx, args, kwargs)
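
    # min() / max() handling: with a tensor operand the builtin is rewritten in
    # terms of torch ops, e.g. (sketch, assuming `t` is a traced tensor and the
    # other operand is a plain Python constant or another tensor):
    #   min(t, 0.5)  -> torch.clamp(t, max=0.5)
    #   max(t, 0.5)  -> torch.clamp(t, min=0.5)
    #   min(t, u)    -> torch.minimum(t, u)   # non-constant second operand
    # Constant-only operands fold directly to a ConstantVariable, and symbolic
    # ints/floats are recorded as a call_function node instead.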

    def _call_min_max(self, tx, *args):
        if len(args) == 1 and args[0].has_unpack_var_sequence(tx):
            # expand iterable
            items = args[0].unpack_var_sequence(tx)
            return self._call_min_max_seq(tx, items)
        elif len(args) == 2:
            return self._call_min_max_binary(tx, args[0], args[1])
        elif len(args) > 2:
            return self._call_min_max_seq(tx, args)

    def _call_min_max_seq(self, tx, items):
        assert len(items) > 0
        if len(items) == 1:
            return items[0]

        return functools.reduce(
            functools.partial(self._call_min_max_binary, tx), items
        )

    def _call_min_max_binary(self, tx, a, b):
        if self.tensor_args(a, b):
            if not isinstance(a, variables.TensorVariable):
                a, b = b, a
            assert isinstance(a, variables.TensorVariable)

            # result of an item call is a scalar convert to a tensor
            if isinstance(a, FakeItemVariable):
                a = variables.TorchVariable(torch.tensor).call_function(tx, [a], {})

            # Dynamic input does not get resolved, rather, gets stored as call_function
            if isinstance(a, SymNodeVariable):
                from .builder import wrap_fx_proxy

                return wrap_fx_proxy(
                    tx=tx,
                    proxy=tx.output.create_proxy(
                        "call_function",
                        self.fn,
                        *proxy_args_kwargs([a, b], {}),
                    ),
                    **VariableTracker.propagate(self, [a, b]),
                )

            # convert min/max to torch ops
            if b.is_python_constant():
                kwargs = {"min": b} if (self.fn is max) else {"max": b}
                result = variables.TorchVariable(torch.clamp).call_function(
                    tx, [a], kwargs
                )
            else:
                fn = {max: torch.maximum, min: torch.minimum}[self.fn]
                result = variables.TorchVariable(fn).call_function(tx, [a, b], {})

            # return unspec if both a, b are unspec or const
            if all(
                isinstance(
                    i,
                    (
                        variables.UnspecializedPythonVariable,
                        variables.ConstantVariable,
                    ),
                )
                for i in [a, b]
            ):
                if any([isinstance(val, FakeItemVariable) for val in [a, b]]):
                    return variables.FakeItemVariable.from_tensor_variable(result)

                if b.is_python_constant():
                    raw_b = b.as_python_constant()
                else:
                    raw_b = b.raw_value
                if self.fn is max:
                    raw_res = max(a.raw_value, raw_b)
                else:
                    raw_res = min(a.raw_value, raw_b)

                need_unwrap = any(
                    x.need_unwrap
                    for x in [a, b]
                    if isinstance(x, variables.UnspecializedPythonVariable)
                )
                return variables.UnspecializedPythonVariable.from_tensor_variable(
                    result, raw_res, need_unwrap
                )
            # otherwise return tensor
            else:
                return result
        elif isinstance(a, variables.ConstantVariable) and isinstance(
            b, variables.ConstantVariable
        ):
            if self.fn is max:
                return variables.ConstantVariable(max(a.value, b.value))
            else:
                return variables.ConstantVariable(min(a.value, b.value))
        elif isinstance(a, SymNodeVariable) or isinstance(b, SymNodeVariable):
            proxy = tx.output.create_proxy(
                "call_function", self.fn, *proxy_args_kwargs([a, b], {})
            )
            return SymNodeVariable.create(tx, proxy, None)
        else:
            unimplemented(f"unsupported min / max over args {str(a)}, {str(b)}")

    call_min = _call_min_max
    call_max = _call_min_max

    def call_range(self, tx, *args):
        if self.unspec_python_args(*args) or self.constant_args(*args):
            args, _ = specialize_args_kwargs(tx, args, {})
            return variables.RangeVariable(args)
        elif self._dynamic_args(*args):

            def guard_if_dyn(arg):
                if isinstance(arg, SymNodeVariable):
                    return arg.evaluate_expr(tx.output)
                elif isinstance(arg, ConstantVariable):
                    return arg.as_python_constant()
                return arg

            args = [variables.ConstantVariable(guard_if_dyn(arg)) for arg in args]
            return variables.RangeVariable(args)
        # None no-ops this handler and lets the driving function proceed
        return None

    def _dynamic_args(self, *args, **kwargs):
        return any([isinstance(x, SymNodeVariable) for x in args]) or any(
            [isinstance(x, SymNodeVariable) for x in kwargs.values()]
        )

    def call_slice(self, tx, *args):
        return variables.SliceVariable(args)

    def _dyn_proxy(self, tx, *args, **kwargs):
        from .builder import wrap_fx_proxy

        options = VariableTracker.propagate(self, args, kwargs.values())
        return wrap_fx_proxy(
            tx,
            tx.output.create_proxy(
                "call_function", self.fn, *proxy_args_kwargs(args, kwargs)
            ),
            **options,
        )
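
    # iter() / tuple() / list() share one handler: with no argument they build an
    # empty (mutable) sequence of the appropriate VariableTracker class, and with
    # an unpackable argument they materialize its elements, adding a LIST_LENGTH
    # guard on the originating source when it is not a constant source.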

    def _call_iter_tuple_list(self, tx, obj=None, *args, **kwargs):
        if self._dynamic_args(*args, **kwargs):
            return self._dyn_proxy(tx, *args, **kwargs)
        cls = variables.BaseListVariable.cls_for(self.fn)
        if obj is None:
            return cls(
                [],
                mutable_local=MutableLocal(),
            )
        elif obj.has_unpack_var_sequence(tx):
            guards = set()
            if obj.source and not is_constant_source(obj.source):
                guards.add(obj.source.make_guard(GuardBuilder.LIST_LENGTH))
            return cls(
                list(obj.unpack_var_sequence(tx)),
                mutable_local=MutableLocal(),
                guards=guards,
            ).add_options(self, obj)

    call_iter = _call_iter_tuple_list
    call_tuple = _call_iter_tuple_list
    call_list = _call_iter_tuple_list

    def call_dict(self, tx, arg):
        if isinstance(arg, variables.ConstDictVariable):
            return arg.clone(mutable_local=MutableLocal())

    def call_zip(self, tx, *args):
        options = VariableTracker.propagate(self, args)
        if all(x.has_unpack_var_sequence(tx) for x in args):
            items = [
                variables.TupleVariable(list(item), **options)
                for item in zip(*[arg.unpack_var_sequence(tx) for arg in args])
            ]
            return variables.TupleVariable(items, **options)

    def call_enumerate(self, tx, *args):
        options = VariableTracker.propagate(self, args)
        if len(args) == 1:
            start = 0
        else:
            assert len(args) == 2
            assert isinstance(args[1], variables.ConstantVariable)
            start = args[1].as_python_constant()
        if args[0].has_unpack_var_sequence(tx):
            items = [
                variables.TupleVariable(
                    [variables.ConstantVariable(idx, **options), var],
                    **options,
                )
                for idx, var in enumerate(args[0].unpack_var_sequence(tx), start)
            ]
            return variables.TupleVariable(items, **options)

    def call_len(self, tx, *args, **kwargs):
        return args[0].call_method(tx, "__len__", args[1:], kwargs)

    def call_getitem(self, tx, *args, **kwargs):
        if self.unspec_python_args(*args, **kwargs):
            args, kwargs = specialize_args_kwargs(tx, args, kwargs)
        return args[0].call_method(tx, "__getitem__", args[1:], kwargs)
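
    # isinstance() is resolved at trace time: tensor dtypes are checked directly,
    # a user-defined __instancecheck__ is invoked on the real object, and anything
    # else is answered from the argument's python_type(). The result is baked in
    # as a ConstantVariable, so the answer is specialized for this trace.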

    def call_isinstance(self, tx, arg, isinstance_type):
        arg_type = arg.python_type()
        isinstance_type = isinstance_type.as_python_constant()

        if isinstance(arg, variables.TensorVariable) and arg.dtype is not None:
            return variables.ConstantVariable(arg.call_isinstance(isinstance_type))
        # UserDefinedObject with C extensions can have torch.Tensor attributes,
        # so break graph.
        if isinstance(arg, variables.UserDefinedObjectVariable) and isinstance(
            arg.value, types.MemberDescriptorType
        ):
            unimplemented(
                f"isinstance called on UserDefinedClass {arg} {isinstance_type}"
            )
        # handle __instancecheck__ defined in user class
        if (
            isinstance(arg, variables.UserDefinedObjectVariable)
            and "__instancecheck__" in isinstance_type.__class__.__dict__
        ):
            return variables.ConstantVariable(
                isinstance_type.__class__.__instancecheck__(isinstance_type, arg.value)
            )

        try:
            val = issubclass(arg_type, isinstance_type)
        except TypeError:
            val = arg_type is isinstance_type
        return variables.ConstantVariable(val)

    def call_super(self, tx, a, b):
        source = (
            None
            if a.source is None or b.source is None
            else SuperSource(a.source, b.source)
        )
        return variables.SuperVariable(a, b, source=source)

    def call_next(self, tx, arg):
        if isinstance(arg, variables.ListIteratorVariable):
            val, next_iter = arg.next_variables()
            tx.replace_all(arg, next_iter)
            return val
        elif isinstance(arg, variables.BaseListVariable):
            return arg.items[0].add_options(self, arg)

    def call_hasattr(self, tx, obj, attr):
        if attr.is_python_constant():
            name = attr.as_python_constant()
            return obj.call_hasattr(tx, name).add_options(self, obj, attr)

    def call_map(self, tx, fn, seq):
        if seq.has_unpack_var_sequence(tx):
            items = [fn.call_function(tx, [x], {}) for x in seq.unpack_var_sequence(tx)]
            return variables.TupleVariable(items).add_options(self, fn, seq)

    def call_sum(self, tx, seq, **kwargs):
        # Special case for sum on tuple of floats and ints
        if (
            isinstance(seq, (variables.ListVariable, variables.TupleVariable))
            and all(
                [
                    isinstance(x, variables.ConstantVariable)
                    and isinstance(x.value, (int, float))
                    for x in seq.items
                ]
            )
            and not kwargs
        ):
            new_list = [x.value for x in seq.items]
            return variables.ConstantVariable(sum(new_list))
        if seq.has_unpack_var_sequence(tx):
            start = kwargs.pop(
                "start", variables.ConstantVariable(0)
            ).as_python_constant()
            assert not kwargs
            items = seq.unpack_var_sequence(tx)[start:]
            return BuiltinVariable(functools.reduce).call_function(
                tx,
                [
                    BuiltinVariable(operator.add),
                    variables.TupleVariable(items),
                    variables.ConstantVariable(0).add_options(self, seq),
                ],
                {},
            )

    def call_reduce(self, tx, function, iterable, initializer=None):
        if iterable.has_unpack_var_sequence(tx):
            items = iterable.unpack_var_sequence(tx)
            if initializer is None:
                value, items = items[0], items[1:]
            else:
                value = initializer
            for element in items:
                value = function.call_function(tx, [value, element], {})
            return value
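
    # getattr() support: pending attribute mutations recorded in side_effects are
    # replayed first, a `default` argument is honored via call_hasattr, and the
    # remaining cases dispatch on the object's variable type (NN modules, tensors,
    # torch namespaces, python modules, user functions, ...), falling back to a
    # GetAttrVariable when the attribute cannot be resolved symbolically.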

    def call_getattr(
        self, tx, obj: VariableTracker, name_var: VariableTracker, default=None
    ):
        from . import (
            ConstantVariable,
            GetAttrVariable,
            PythonModuleVariable,
            TorchVariable,
            UserFunctionVariable,
        )
        from .builder import VariableBuilder

        options = VariableTracker.propagate(self, obj, name_var)
        guards = options["guards"]
        name = name_var.as_python_constant()

        if not name_var.is_python_constant():
            unimplemented("non-const getattr() name")

        if tx.output.side_effects.is_attribute_mutation(obj):
            try:
                # re-read a pending side effect?
                return tx.output.side_effects.load_attr(obj, name).add_options(options)
            except KeyError:
                pass

        if default is not None:
            hasattr_var = self.call_hasattr(tx, obj, name_var)
            guards.update(hasattr_var.guards)
            assert hasattr_var.as_python_constant() in (True, False)
            if not hasattr_var.as_python_constant():
                return default.add_guards(guards)

        if obj.source:
            source = AttrSource(obj.source, name)
            options["source"] = source
        else:
            source = None

        if isinstance(obj, variables.NNModuleVariable):
            return obj.var_getattr(tx, name).add_options(options)
        elif isinstance(obj, variables.TensorVariable) and name == "grad":
            if source:
                # We are going to be raising this tensor as grapharg. So, ensure
                # that we have real grad value instead of fake tensor value.
                # Walk through the inputs of the subgraph and find if we already
                # have the original tensor stored in the graphargs.
                for grapharg in tx.output.graphargs:
                    if grapharg.source == source.base:
                        example_value = grapharg.example.grad
                        return VariableBuilder(tx, source)(example_value).add_options(
                            options
                        )
                unimplemented("tensor grad")
            else:
                unimplemented("tensor grad")
        elif isinstance(
            obj,
            (
                variables.TensorVariable,
                variables.NamedTupleVariable,
                variables.ConstantVariable,
                variables.UserDefinedClassVariable,
                variables.UserDefinedObjectVariable,
            ),
        ):
            try:
                return (
                    obj.var_getattr(tx, name).clone(source=source).add_options(options)
                )
            except NotImplementedError:
                return GetAttrVariable(obj, name, **options)
        elif isinstance(obj, TorchVariable):
            member = getattr(obj.value, name)
            if is_allowed(member):
                return TorchVariable(member, **options)
            elif ConstantVariable.is_literal(member):
                return ConstantVariable(member, **options)
            else:
                return VariableBuilder(tx, source)(member).add_guards(guards)
        elif isinstance(obj, (PythonModuleVariable, DummyModule)):
            member = obj.value.__dict__[name]

            if config.replay_record_enabled:
                tx.exec_recorder.record_module_access(obj.value, name, member)

            return VariableBuilder(tx, source)(member).add_guards(guards)
        elif istype(obj, UserFunctionVariable) and name in ("__name__", "__module__"):
            return ConstantVariable(
                getattr(obj.fn, name), **VariableTracker.propagate(obj)
            )
        else:
            try:
                return (
                    obj.var_getattr(tx, name).clone(source=source).add_options(options)
                )
            except NotImplementedError:
                return GetAttrVariable(obj, name, **options)
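
    # setattr() is only handled where Dynamo can track the mutation (side_effects
    # attribute mutation, BlackHole / DataClass variables); attribute writes on
    # other user-defined objects trigger a graph break, and writes on NN modules
    # convert the module to be handled as unspecialized.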

    def call_setattr(
        self, tx, obj: VariableTracker, name_var: VariableTracker, val: VariableTracker
    ):
        if isinstance(obj, (variables.BlackHoleVariable, variables.DataClassVariable)):
            return obj.call_method(tx, "__setattr__", [name_var, val], {})
        elif (
            tx.output.side_effects.is_attribute_mutation(obj)
            and name_var.is_python_constant()
        ):
            tx.output.side_effects.store_attr(obj, name_var.as_python_constant(), val)
            return val.add_options(self, obj, name_var)
        elif isinstance(obj, variables.UserDefinedObjectVariable):
            unimplemented(
                f"setattr(UserDefinedObjectVariable) {type(obj.value).__setattr__}"
            )
        elif isinstance(obj, variables.NNModuleVariable):
            obj.convert_to_unspecialized(tx)

    def call_type(self, tx, obj: VariableTracker):
        from .builder import VariableBuilder

        try:
            py_type = obj.python_type()
        except NotImplementedError:
            py_type = None

        if istype(obj, variables.TupleVariable):
            return BuiltinVariable(py_type).add_options(self, obj)

        if py_type is not None and obj.source:
            return VariableBuilder(tx, TypeSource(obj.source))(py_type).add_options(
                self, obj
            )

        unimplemented(f"type({obj})")

    def call_reversed(self, tx, obj: VariableTracker):
        if obj.has_unpack_var_sequence(tx):
            items = list(reversed(obj.unpack_var_sequence(tx)))
            return variables.TupleVariable(
                items, **VariableTracker.propagate(self, obj)
            )

    def call_chain(self, tx, *args):
        if all(obj.has_unpack_var_sequence(tx) for obj in args):
            items = []
            for obj in args:
                items.extend(obj.unpack_var_sequence(tx))
            return variables.TupleVariable(
                items, **VariableTracker.propagate(self, *args)
            )

    def call_islice(self, tx, iterable, *args):
        if iterable.has_unpack_var_sequence(tx) and all(
            x.is_python_constant() for x in args
        ):
            const_args = [x.as_python_constant() for x in args]
            items = iterable.unpack_var_sequence(tx)
            items = list(itertools.islice(items, *const_args))
            return variables.TupleVariable(
                items, **VariableTracker.propagate(self, iterable, *args)
            )

    def call_id(self, tx, *args):
        if len(args) > 0 and isinstance(args[0], variables.NNModuleVariable):
            nn_mod_variable = args[0]
            mod = tx.output.get_submodule(nn_mod_variable.module_key)
            return variables.ConstantVariable(id(mod))
        else:
            unimplemented(f"call_id with args {args}")

    def _comparison(self, tx, left, right):
        """
        Used to implement comparison operators for different types.
        For example, list1 < list2 is implemented differently from tensor1 < tensor2
        """
        from . import (
            BaseListVariable,
            ConstantVariable,
            TensorVariable,
            UserFunctionVariable,
        )
        from .lists import SizeVariable
        from .tensor import (
            supported_const_comparison_ops,
            supported_tensor_comparison_ops,
        )

        op = self.fn

        def _unimplemented():
            unimplemented(f"comparison {typestr(left)} {op} {typestr(right)}")

        if isinstance(left, UserFunctionVariable):
            if op not in supported_const_comparison_ops.values():
                _unimplemented()
            if not isinstance(right, UserFunctionVariable):
                _unimplemented()
            return ConstantVariable(op(left.fn, right.fn))

        # Note, we have a rare BaseListVariable subtype mismatch with valid comparison
        # x = torch.randn([3, 3])
        # x.size() == (3, 3) # True
        # (3, 3) == x.size() # True
        if isinstance(left, (SizeVariable, TupleVariable)) and isinstance(
            right, (TupleVariable, SizeVariable)
        ):
            return BaseListVariable.list_compare(tx, op, left, right)

        if isinstance(left, BaseListVariable):
            if not type(left) == type(right):  # Mismatch in BaseListVariable subclasses
                _unimplemented()
            return BaseListVariable.list_compare(tx, op, left, right)

        if isinstance(left, TensorVariable):
            from .builder import wrap_fx_proxy

            if op not in supported_tensor_comparison_ops.values():
                _unimplemented()
            return wrap_fx_proxy(
                tx,
                op(left.as_proxy(), right.as_proxy()),
            )

        if isinstance(left, SymNodeVariable) or isinstance(right, SymNodeVariable):
            if op not in supported_tensor_comparison_ops.values():
                _unimplemented()
            return SymNodeVariable.create(
                tx,
                op(left.as_proxy(), right.as_proxy()),
                sym_num=None,
            )

        _unimplemented()

    # and_ is a constant fold function, so we only get here if constant fold is not valid
    def call_and_(self, tx, a, b):
        if isinstance(a, SymNodeVariable) and isinstance(b, SymNodeVariable):
            return SymNodeVariable.create(
                tx,
                tx.output.create_proxy(
                    "call_function", operator.and_, *proxy_args_kwargs([a, b], {})
                ),
                sym_num=None,
            )
        # None no-ops this handler and lets the driving function proceed
        return None

    def call_not_(self, tx, a):
        if isinstance(a, SymNodeVariable):
            return SymNodeVariable.create(
                tx,
                tx.output.create_proxy(
                    "call_function", operator.not_, *proxy_args_kwargs([a], {})
                ),
                sym_num=None,
            )
        return None

    call_eq = _comparison
    call_gt = _comparison
    call_lt = _comparison
    call_ge = _comparison
    call_le = _comparison
    call_ne = _comparison
    call_is_ = _comparison
    call_is_not = _comparison