# symbolic_convert.py
import collections
import dataclasses
import dis
import functools
import importlib
import inspect
import itertools
import logging
import operator
import sys
import traceback
import types
import typing
import weakref
from collections.abc import Sized
from typing import Any, Callable, Dict, List, NamedTuple, Optional, Set, Tuple
from unittest.mock import patch

import torch
from torch._guards import Checkpointable

from . import (
    allowed_functions,
    config,
    exc,
    logging as torchdynamo_logging,
    side_effects,
    skipfiles,
    variables,
)
from .allowed_functions import is_allowed, is_builtin_callable, is_builtin_constant
from .bytecode_analysis import JUMP_OPNAMES, livevars_analysis
from .bytecode_transformation import (
    cleaned_instructions,
    create_instruction,
    create_jump_absolute,
    Instruction,
    is_generator,
    unique_id,
)
from .codegen import PyCodegen
from .exc import BackendCompilerFailed, unimplemented, Unsupported
from .guards import GuardBuilder
from .output_graph import GraphCompileReason, OutputGraph, OutputGraphState
from .replay_record import DummyModule, ExecutionRecorder
from .resume_execution import ContinueExecutionCache, ReenterWith
from .source import (
    AttrSource,
    GetItemSource,
    GlobalSource,
    GlobalWeakRefSource,
    LocalInputSource,
    LocalSource,
)
from .utils import counters, graph_break_dup_warning_checker, istype, proxy_args_kwargs
from .variables.base import MutableLocal, typestr, VariableTracker
from .variables.builder import VariableBuilder, wrap_fx_proxy
from .variables.builtin import BuiltinVariable
from .variables.constant import ConstantVariable, EnumVariable
from .variables.dicts import ConstDictVariable
from .variables.functions import (
    BaseUserFunctionVariable,
    NestedUserFunctionVariable,
    UserFunctionVariable,
    UserMethodVariable,
)
from .variables.lists import (
    BaseListVariable,
    ListIteratorVariable,
    ListVariable,
    SliceVariable,
    TupleVariable,
)
from .variables.misc import (
    ClosureVariable,
    ContextWrappingVariable,
    GetAttrVariable,
    GradModeVariable,
    PythonModuleVariable,
    UnknownVariable,
    WithExitFunctionVariable,
)
from .variables.nn_module import NNModuleVariable
from .variables.tensor import (
    supported_const_comparison_ops,
    supported_tensor_comparison_ops,
    SymNodeVariable,
    TensorVariable,
)
from .variables.torch import TorchVariable
from .variables.user_defined import UserDefinedObjectVariable, UserDefinedVariable

log = logging.getLogger(__name__)


@functools.lru_cache(None)
def _step_logger():
    return torchdynamo_logging.get_step_logger(log)


@dataclasses.dataclass
class BlockStackEntry:
    target: Instruction
    stack_index: Optional[int] = None
    with_context: ContextWrappingVariable = None

    def can_restore(self):
        return self.with_context is not None

    def resume_fn(self):
        assert self.stack_index is not None
        if self.with_context and self.with_context.target_values:
            return ReenterWith(self.stack_index, tuple(self.with_context.target_values))
        else:
            return ReenterWith(self.stack_index)

    def exit(self, tx):
        return self.with_context.exit(tx)


class InstructionTranslatorGraphState(NamedTuple):
    output: OutputGraphState
    symbolic_locals: Dict[str, VariableTracker]
    stack: List[VariableTracker]
    block_stack: List[BlockStackEntry]
    instruction_pointer: Optional[int]
    current_instruction: Instruction
    next_instruction: Optional[Instruction]
    lineno: int

    def diff(self, other: "InstructionTranslatorGraphState") -> Optional[str]:
        for k in self._fields:
            if k == "output":
                return self.output.diff(other.output, prefix=f"{k}.")
            sv = getattr(self, k)
            ov = getattr(other, k)
            if sv != ov:
                return f"{k} mismatch: {sv} != {ov}"
        return None


def stack_op(fn: typing.Callable[..., object]):
    nargs = len(inspect.signature(fn).parameters)
    fn_var = BuiltinVariable(fn)

    @functools.wraps(fn)
    def impl(self: "InstructionTranslatorBase", inst: Instruction):
        self.push(fn_var.call_function(self, self.popn(nargs), {}))

    return impl
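

# Note: stack_op adapts a plain Python function into a bytecode handler: it
# pops the function's arity off the symbolic stack and pushes the result of
# calling it through BuiltinVariable. The BINARY_*/INPLACE_* tables near the
# end of InstructionTranslatorBase are built this way, e.g.
#     BINARY_ADD = stack_op(operator.add)
# pops two VariableTrackers and pushes their symbolic sum.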


def _detect_and_normalize_assert_statement(
    self: "InstructionTranslatorBase",
    truth_fn: typing.Callable[[object], bool],
    push: bool,
):
    # Detect if this jump instruction is part of an assert statement and
    # normalize the assert by pushing a dummy error message when none is given.
    #
    # Python 3.9+ assertions have the following bytecode format:
    #   18 POP_JUMP_IF_TRUE        28
    #   20 LOAD_ASSERTION_ERROR
    #   22 LOAD_CONST               3 ('Assert message')  -> optional instruction
    #   24 CALL_FUNCTION            1                     -> optional instruction
    #   26 RAISE_VARARGS
    #
    # Python 3.8 assertions have the following bytecode format:
    #   18 POP_JUMP_IF_TRUE        28
    #   20 LOAD_GLOBAL              0 (AssertionError)
    #   22 LOAD_CONST               3 ('Assert message')  -> optional instruction
    #   24 CALL_FUNCTION            1                     -> optional instruction
    #   26 RAISE_VARARGS            1
    if (truth_fn is not operator.truth) or push:
        return False

    assert isinstance(self.instruction_pointer, int)
    current_instruction_pointer = self.instruction_pointer
    inst = self.instructions[current_instruction_pointer]
    # Detect LOAD_ASSERTION_ERROR (3.9+) or LOAD_GLOBAL AssertionError (3.8)
    if sys.version_info < (3, 9):
        if inst.opname != "LOAD_GLOBAL" or inst.argval != "AssertionError":
            return False
    else:
        if inst.opname != "LOAD_ASSERTION_ERROR":
            return False

    current_instruction_pointer += 1
    if current_instruction_pointer >= len(self.instructions):
        return False
    inst = self.instructions[current_instruction_pointer]

    has_error_msg = False
    # Detect RAISE_VARARGS or LOAD_CONST
    if inst.opname == "LOAD_CONST":
        if not isinstance(inst.argval, str):
            return False
        self.LOAD_CONST(inst)
        has_error_msg = True

        # if it is LOAD_CONST, it must be followed by CALL_FUNCTION
        current_instruction_pointer += 1
        if current_instruction_pointer >= len(self.instructions):
            return False
        inst = self.instructions[current_instruction_pointer]
        if inst.opname != "CALL_FUNCTION":
            return False

        # CALL_FUNCTION should be followed by RAISE_VARARGS
        current_instruction_pointer += 1
        if current_instruction_pointer >= len(self.instructions):
            return False
        inst = self.instructions[current_instruction_pointer]

    if inst.opname != "RAISE_VARARGS":
        return False

    if not has_error_msg:
        # Push a dummy value in place of the error message
        self.push(ConstantVariable("assertion error"))

    return True
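

# Note: with config.rewrite_assert_with_torch_assert enabled, generic_jump
# (below) uses this detection to rewrite, conceptually,
#     assert cond, "message"
# into a graph node equivalent to
#     torch._assert(cond, "message")
# and then jumps past the original LOAD_ASSERTION_ERROR/RAISE_VARARGS sequence.
# When no message is present, the dummy string "assertion error" is pushed so
# both shapes reach torch._assert with two arguments.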


def generic_jump(truth_fn: typing.Callable[[object], bool], push: bool):
    def inner(self: "InstructionTranslatorBase", inst: Instruction):
        value: VariableTracker = self.pop()
        self.output.guards.update(value.guards)
        if (
            config.rewrite_assert_with_torch_assert
            and _detect_and_normalize_assert_statement(self, truth_fn, push)
        ):
            error_msg: VariableTracker = self.pop()
            self.output.guards.update(error_msg.guards)
            # Skip over things like `assert True`
            if value.is_python_constant() and bool(value.as_python_constant()):
                self.jump(inst)
                return

            # Manually insert torch._assert instead of a Python assert and jump
            # over the assert-related instructions, as we don't need them anymore.
            self.output.create_proxy(
                "call_function",
                torch._assert,
                *proxy_args_kwargs((value, error_msg), {}),
            )
            self.jump(inst)
            return

        if value.is_python_constant():
            if truth_fn(value.as_python_constant()):
                push and self.push(value)
                self.jump(inst)
        elif (
            isinstance(value, (TensorVariable)) and self.should_compile_partial_graph()
        ):
            # compile a partial subgraph prefix then jump into user code
            if self.has_backedge():
                msg = (
                    "Skipping frame because there is a graph break in a for/while loop"
                )
                log.debug(msg)
                raise exc.SkipFrame(msg)

            self.push(value)
            log.debug("generic_jump triggered compile")
            self.output.compile_subgraph(
                self,
                reason=GraphCompileReason(
                    f"generic_jump {typestr(value)}", [self.frame_summary()]
                ),
            )
            self.pop()

            if_next = self.create_call_resume_at(self.next_instruction)
            push and self.push(value)
            if_jump = self.create_call_resume_at(inst.target)

            self.output.add_output_instructions(
                [(create_instruction(inst.opname, target=if_jump[0]))]
                + if_next
                + if_jump
            )
        elif isinstance(value, NNModuleVariable):
            # Equivalent of "self.nn_module is not None"
            if truth_fn(value):
                push and self.push(value)
                self.jump(inst)
        elif isinstance(value, UserDefinedObjectVariable):
            x = value.var_getattr(self, "__bool__")
            # if __bool__ is a function
            if isinstance(x, UserMethodVariable):
                state = self.copy_graphstate()
                result = x.call_function(self, [], {})
                if isinstance(result, ConstantVariable) and isinstance(
                    result.value, bool
                ):
                    self.output.guards.update(result.guards)
                    if truth_fn(result.value):
                        push and self.push(value)
                        self.jump(inst)
                else:
                    # roll back to the state before the __bool__ inline
                    self.restore_graphstate(state)
                    unimplemented(
                        "generic_jump on UserDefined with __bool__ returning non-constant"
                    )
            # if __bool__ is a non-function or is missing on the user-defined object
            else:
                if truth_fn(True):
                    push and self.push(value)
                    self.jump(inst)
        elif not isinstance(value, TensorVariable) and value.has_unpack_var_sequence(
            self
        ):
            if truth_fn(len(value.unpack_var_sequence(self))):
                push and self.push(value)
                self.jump(inst)
        elif isinstance(value, SymNodeVariable):
            eval_result = value.evaluate_expr(self.output)
            if truth_fn(eval_result):
                push and self.push(value)
                self.jump(inst)
        else:
            unimplemented(f"generic_jump {typestr(value)}")

    return inner


explain = False


def break_graph_if_unsupported(*, push):
    def decorator(inner_fn):
        @functools.wraps(inner_fn)
        def wrapper(self: "InstructionTranslatorBase", inst: Instruction):
            state = self.copy_graphstate()
            reason = None
            try:
                return inner_fn(self, inst)
            except Unsupported as excp:
                if self.has_backedge() and self.should_compile_partial_graph():
                    msg = "Skipping frame because there is a graph break in a for/while loop"
                    log.debug(msg)
                    raise exc.SkipFrame(msg) from excp

                if not self.should_compile_partial_graph():
                    raise

                log.debug("break_graph_if_unsupported triggered compile", exc_info=True)

                user_stack = [self.frame_summary()] + list(reversed(excp.real_stack))
                user_stack_formatted = "".join(traceback.format_list(user_stack))
                frame_loc = (user_stack[-1].filename, user_stack[-1].lineno)
                # torch._dynamo.explain() formats this a little nicer, and presents a
                # slightly more actionable user code pointer
                if (
                    config.print_graph_breaks
                    and not explain
                    and graph_break_dup_warning_checker.add(frame_loc)
                ):
                    log.warning(
                        f"Graph break: {excp} from user code at {user_stack_formatted}"
                    )

                excp.remove_from_stats()
                excp.add_to_stats("graph_break")
                reason = GraphCompileReason(excp.msg, user_stack)
            self.restore_graphstate(state)
            self.output.compile_subgraph(self, reason=reason)
            self.popn(push - dis.stack_effect(inst.opcode, inst.arg))
            for _ in range(push):
                self.push(UnknownVariable())
            resume_call_insts = self.create_call_resume_at(self.next_instruction)
            # Check if there is a block stack entry with GradModeVariable, and if
            # so, wrap the instruction causing the graph break inside a try..finally
            # block. See more details at
            # https://github.com/pytorch/torchdynamo/issues/207
            cleanup = []
            if len(self.block_stack) == 1 and isinstance(
                self.block_stack[0].with_context, GradModeVariable
            ):
                ctx_variable = self.block_stack[0].with_context

                cg = PyCodegen(self)
                setup_finally, cleanup = ctx_variable.reconstruct(
                    cg, resume_call_insts[0]
                )
                self.output.add_output_instructions(setup_finally)

            self.output.add_output_instructions([inst])

            # Add the cleanup instructions from the try..finally block
            self.output.add_output_instructions(cleanup)
            self.output.add_output_instructions(
                resume_call_insts,
            )

        return wrapper

    return decorator
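

# Note: break_graph_if_unsupported wraps opcode handlers that may fail with
# Unsupported partway through an instruction, e.g.
#     @break_graph_if_unsupported(push=1)
#     def CALL_FUNCTION(self, inst): ...
# On failure it compiles the traced prefix as a subgraph, emits the original
# instruction to run eagerly, pushes `push` UnknownVariable placeholders for
# the instruction's results, and resumes tracing in a continuation function.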


def is_none(x):
    return x is None


def is_not_none(x):
    return x is not None


class InstructionTranslatorBase(Checkpointable[InstructionTranslatorGraphState]):
    output: OutputGraph
    symbolic_locals: Dict[str, VariableTracker]
    symbolic_globals: Dict[str, VariableTracker]
    stack: List[VariableTracker]
    instruction_pointer: Optional[int]
    current_instruction: Instruction
    next_instruction: Optional[Instruction]
    block_stack: List[BlockStackEntry]
    lineno: int
    mutated_closure_cell_contents: Set[str]
    checkpoint: Optional[Tuple[Instruction, InstructionTranslatorGraphState]]
    random_calls: List[
        Tuple[Callable[..., object], Tuple[object, ...], Dict[str, object]]
    ]

    def has_backedge(self):
        cur_offset = self.current_instruction.offset
        assert self.instruction_pointer is not None
        for inst in self.instructions[self.instruction_pointer :]:
            if inst.opname in JUMP_OPNAMES:
                jump_offset = inst.argval
                if jump_offset < cur_offset:
                    return True
        return False
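
    # Note: a backedge (a jump to an offset lower than the current instruction)
    # indicates a for/while loop. A graph break inside such a loop would force
    # a recompile on every iteration, so callers raise SkipFrame instead of
    # compiling a partial graph (see generic_jump and break_graph_if_unsupported).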

    def cell_and_freevars(self):
        if not hasattr(self, "_cell_and_freevars"):
            self._cell_and_freevars = tuple(
                self.code_options["co_cellvars"] or []
            ) + tuple(self.code_options["co_freevars"] or [])
        return self._cell_and_freevars

    def prune_dead_locals(self):
        reads = livevars_analysis(self.instructions, self.current_instruction)
        # implicit use by super()
        # reads = reads | {"__class__"}
        # output variables?
        reads = reads | set(self.cell_and_freevars())
        self.symbolic_locals = collections.OrderedDict(
            [(k, v) for k, v in self.symbolic_locals.items() if k in reads]
        )
        self.output.side_effects.prune_dead_object_new(self)

    def call_function(
        self,
        fn: VariableTracker,
        args: List[VariableTracker],
        kwargs: Dict[str, VariableTracker],
    ):
        assert isinstance(fn, VariableTracker)
        assert isinstance(args, list)
        assert isinstance(kwargs, dict)
        assert all(
            isinstance(x, VariableTracker)
            for x in itertools.chain(args, kwargs.values())
        )
        self.push(fn.call_function(self, args, kwargs))

    def update_locals_and_stack(self, oldvar: VariableTracker, newvar: VariableTracker):
        def repl(v: VariableTracker):
            if v.mutable_local is oldvar.mutable_local:
                return newvar
            return v

        def skip(v: VariableTracker):
            return oldvar.mutable_local not in v.recursively_contains

        cache: Dict[int, Tuple[object, object]] = dict()
        self.output.side_effects.apply(repl, cache, skip_fn=skip)
        self.stack = [
            VariableTracker.apply(repl, x, cache, skip_fn=skip) for x in self.stack
        ]
        for k, x in self.symbolic_locals.items():
            self.symbolic_locals[k] = VariableTracker.apply(
                repl, x, cache, skip_fn=skip
            )

    def replace_all(self, oldvar: VariableTracker, newvar: VariableTracker):
        if isinstance(oldvar.mutable_local, side_effects.MutableSideEffects):
            newvar = self.output.side_effects.mutation(oldvar, newvar)
        else:
            assert isinstance(oldvar.mutable_local, variables.base.MutableLocal)
            newvar = newvar.clone(mutable_local=variables.base.MutableLocal())
        self.update_locals_and_stack(oldvar, newvar)
        return newvar

    def inline_user_function_return(self, fn, args, kwargs):
        """
        Call a user-defined function by inlining it, rolling back on failure.
        """
        state = self.copy_graphstate()
        try:
            result = InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
            self.output.guards.update(fn.guards)
            return result
        except Exception:
            self.restore_graphstate(state)
            raise

    def step(self):
        """Process exactly one instruction; return False if we should exit"""
        assert isinstance(self.instruction_pointer, int)
        inst = self.instructions[self.instruction_pointer]
        self.current_instruction = inst
        self.instruction_pointer += 1
        if self.instruction_pointer < len(self.instructions):
            self.next_instruction = self.instructions[self.instruction_pointer]
        else:
            self.instruction_pointer = None
            self.next_instruction = None
        if inst.starts_line and self.lineno != inst.starts_line:
            self.lineno = inst.starts_line
            log.debug(f"TRACE starts_line {self.f_code.co_filename}:{self.lineno}")

        if len(self.stack) == 0 and self.should_compile_partial_graph():
            self.checkpoint = inst, self.copy_graphstate()

        log.debug(f"TRACE {inst.opname} {inst.argval} {self.stack}")
        try:
            if not hasattr(self, inst.opname):
                unimplemented(f"missing: {inst.opname}")
            getattr(self, inst.opname)(inst)
            return inst.opname != "RETURN_VALUE"
        except BackendCompilerFailed:
            raise
        except Unsupported as exc:
            exc.real_stack.append(self.frame_summary())
            if self.empty_checkpoint():
                raise
            log.debug("step triggered compile", exc_info=True)
        except Exception as exc:
            real_stack = getattr(exc, "real_stack", [])
            real_stack.append(self.frame_summary())
            exc.real_stack = real_stack  # type: ignore[attr-defined]
            raise

        # generate code from checkpoint
        assert not self.output.output_instructions
        assert self.checkpoint is not None
        continue_inst, state = self.checkpoint
        self.restore_graphstate(state)
        self.output.compile_subgraph(
            self,
            partial_convert=True,
            reason=GraphCompileReason("step_unsupported", [self.frame_summary()]),
        )
        self.output.add_output_instructions(
            [create_jump_absolute(continue_inst)] + self.instructions
        )

    def run(self):
        try:
            self.output.push_tx(self)
            while (
                self.instruction_pointer is not None
                and not self.output.should_exit
                and self.step()
            ):
                pass
        except BackendCompilerFailed:
            raise
        except Exception as e:
            if config.replay_record_enabled:
                e.exec_record = self.exec_recorder.get_record()  # type: ignore[attr-defined]
            raise
        finally:
            self.output.pop_tx()
            # Clean up the OutputGraph to delete the held tensors. We perform the
            # cleanup only for InstructionTranslator and not
            # InliningInstructionTranslator. The InliningInstructionTranslator
            # mutates the output object and is restored to its original state if
            # there was an exception.
            if isinstance(self, InstructionTranslator):
                self.output.cleanup()

    def push(self, val: Optional[VariableTracker]):
        assert val is None or isinstance(
            val, VariableTracker
        ), f"push expects VariableTracker, got {typestr(val)}"
        self.stack.append(val)

    def push_many(self, vals: List[VariableTracker]):
        for val in vals:
            self.push(val)

    def pop(self) -> VariableTracker:
        return self.stack.pop()

    def popn(self, n: int) -> List[VariableTracker]:
        assert n >= 0
        return list(reversed([self.pop() for _ in range(n)]))
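
    # Note: popn returns values in stack order (bottom to top), so for a binary
    # opcode
    #     left, right = self.popn(2)
    # leaves `right` as the old top-of-stack, matching CPython operand order.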

    def LOAD_FAST(self, inst):
        name = inst.argval

        if name in self.f_locals and config.replay_record_enabled:
            self.exec_recorder.add_local_var(name, self.f_locals[name])

        if name.startswith(".") and name not in self.symbolic_locals:
            # This happens in dict/list comprehensions
            name = name.replace(".", "implicit")
        assert name not in self.cell_and_freevars()
        if name not in self.symbolic_locals:
            unimplemented("undefined LOAD_FAST")
        self.push(self.symbolic_locals[name])
        if name.startswith("___stack"):
            self.symbolic_locals.pop(name)

    def LOAD_DEREF(self, inst):
        assert inst.argval in self.cell_and_freevars()

        if inst.argval in self.f_locals and config.replay_record_enabled:
            self.exec_recorder.add_local_var(inst.argval, self.f_locals[inst.argval])

        if inst.argval not in self.symbolic_locals:
            unimplemented(f"undefined LOAD_DEREF {inst.argval}")
        self.push(self.symbolic_locals[inst.argval])

    def STORE_FAST(self, inst):
        self.symbolic_locals[inst.argval] = self.pop()

    def DELETE_FAST(self, inst):
        del self.symbolic_locals[inst.argval]

    STORE_DEREF = STORE_FAST

    def LOAD_CLOSURE(self, inst):
        self.push(ClosureVariable(name=inst.argval))

    def LOAD_CONST(self, inst):
        self.push(ConstantVariable(value=inst.argval))

    def get_global_source(self, name):
        if self.output.root_globals is self.f_globals:
            source = GlobalSource(name)
        else:
            if "__name__" in self.f_globals:
                source = AttrSource(
                    self.import_source(self.f_globals["__name__"]), name
                )
            else:
                mangled_name = f"___unnamed_scope_{id(self.f_globals)}"
                if mangled_name not in self.output.root_globals:
                    self.output.install_global(mangled_name, self.f_globals)
                source = GetItemSource(GlobalSource(mangled_name), name)
        return source

    def LOAD_GLOBAL(self, inst):
        name = inst.argval

        if config.replay_record_enabled:
            if name in self.f_globals:
                self.exec_recorder.add_global_var(name, self.f_globals[name])
            else:
                assert name in self.f_builtins
                self.exec_recorder.builtins[name] = self.f_builtins[name]

        if name in self.symbolic_globals:
            variable = self.output.side_effects[self.symbolic_globals[name]]
            self.push(self.output.side_effects.load_global(variable, name))
            return

        try:
            value = self.f_globals[name]
        except KeyError:
            return self.load_builtin(inst)

        source = self.get_global_source(name)
        self.push(VariableBuilder(self, source)(value))

    def STORE_GLOBAL(self, inst):
        value = self.pop()
        name = inst.argval
        source = self.get_global_source(name)
        if name not in self.symbolic_globals:
            self.symbolic_globals[name] = object()  # sentinel object
        variable = self.output.side_effects.track_global_existing(
            source, self.symbolic_globals[name]
        )
        self.output.side_effects.store_global(variable, name, value)

    def import_source(self, module_name):
        """Create an alias to a module for use in guards"""
        if "torch_package" in module_name:
            value = torch.package.package_importer._package_imported_modules[
                module_name
            ]
            alias = (
                module_name.replace(">", "_").replace("<", "_").replace(".", "_dot_")
            )
        else:
            value = importlib.import_module(module_name)
            alias = f"__import_{module_name.replace('.', '_dot_')}"
        f_globals = self.output.root_globals
        assert alias not in f_globals or f_globals[alias] is value
        f_globals[alias] = value
        self.output.update_co_names(alias)
        return GlobalSource(alias)

    def resolve_name(self, name, package, level):
        """
        Copied from the CPython implementation of __import__.
        Resolve a relative module name to an absolute one.
        https://github.com/python/cpython/blob/5a094f0255eea1db58fb2cf14c200971e64ec36e/Lib/importlib/_bootstrap.py#L902
        """
        bits = package.rsplit(".", level - 1)
        if len(bits) < level:
            raise ImportError("attempted relative import beyond top-level package")
        base = bits[0]
        return "{}.{}".format(base, name) if name else base
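
    # Worked example: in a module whose __package__ is "pkg.sub", the statement
    # `from ..mod import x` compiles to IMPORT_NAME with level=2, so
    #     resolve_name("mod", package="pkg.sub", level=2)
    # rsplits the package once into ["pkg", "sub"], keeps base "pkg", and
    # returns "pkg.mod".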

    def calc_package(self):
        """
        Copied from the CPython implementation of __import__.
        https://github.com/python/cpython/blob/5a094f0255eea1db58fb2cf14c200971e64ec36e/Lib/importlib/_bootstrap.py#L1090
        """
        package = self.f_globals.get("__package__")
        spec = self.f_globals.get("__spec__")
        if package is not None:
            if spec is not None and package != spec.parent:
                log.warning(
                    "__package__ != __spec__.parent "
                    f"({package!r} != {spec.parent!r})",
                    ImportWarning,
                    stacklevel=3,
                )  # type: ignore[call-arg]
            return package
        elif spec is not None:
            return spec.parent
        else:
            log.warning(
                "can't resolve package from __spec__ or __package__, "
                "falling back on __name__ and __path__",
                ImportWarning,
                stacklevel=3,
            )  # type: ignore[call-arg]
            package = self.f_globals["__name__"]
            if "__path__" not in self.f_globals:
                package = package.rpartition(".")[0]
        return package

    def IMPORT_NAME(self, inst):
        level, fromlist = self.popn(2)
        level = level.as_python_constant()
        fromlist = fromlist.as_python_constant()
        module_name = inst.argval

        # Are we replaying? If so, load the recorded module.
        recorded_name = (
            f"{ExecutionRecorder.LOCAL_MOD_PREFIX}_{level}_{fromlist}_{module_name}"
        )
        if recorded_name in self.f_globals:
            value = self.f_globals[recorded_name]
            source = GlobalSource(recorded_name)
        else:
            value = __import__(
                module_name,
                fromlist=fromlist,
                level=level,
                globals=self.f_globals,
            )

            if level != 0:
                pkg = self.calc_package()
                module_name = self.resolve_name(module_name, pkg, level)

            # For __import__, when the name variable is of the form package.module,
            # normally, the top-level package (the name up till the first dot) is
            # returned, not the module named by module_name. However, when a
            # non-empty fromlist argument is given, the module named by name is
            # returned. Therefore, we set the source correctly here.
            if not fromlist:
                top_level_module_name = module_name.partition(".")[0]
                source = self.import_source(top_level_module_name)
            else:
                source = self.import_source(module_name)

        if config.replay_record_enabled:
            self.exec_recorder.add_local_mod(recorded_name, value)

        if is_allowed(value):
            self.push(TorchVariable(value, source=source))
        elif istype(value, (types.ModuleType, DummyModule)):
            self.push(PythonModuleVariable(value, source=source))
        else:
            unimplemented(f"IMPORT_NAME {typestr(value)}")

    def IMPORT_FROM(self, inst):
        self.DUP_TOP(inst)
        self.LOAD_ATTR(inst)

    def load_builtin(self, inst):
        assert inst.argval in self.f_builtins
        val = self.f_builtins[inst.argval]

        if callable(val):
            assert is_builtin_callable(val)
            self.push(VariableBuilder(self, GlobalSource(inst.argval))(val))
        else:
            assert is_builtin_constant(val)
            self.push(ConstantVariable(value=val))

    def jump(self, inst):
        self.instruction_pointer = self.indexof[id(inst.target)]

    JUMP_FORWARD = jump
    JUMP_ABSOLUTE = jump

    POP_JUMP_IF_FALSE = generic_jump(operator.not_, False)
    POP_JUMP_IF_TRUE = generic_jump(operator.truth, False)
    JUMP_IF_FALSE_OR_POP = generic_jump(operator.not_, True)
    JUMP_IF_TRUE_OR_POP = generic_jump(operator.truth, True)
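
    # Note: the four conditional-jump flavors all route through generic_jump:
    # POP_JUMP_IF_* always pops the tested value (push=False), while
    # JUMP_IF_*_OR_POP keeps the value on the stack only when the jump is taken
    # (push=True); operator.not_ vs operator.truth selects the branch polarity.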

    def SETUP_LOOP(self, inst):
        # only exists in python<=3.7
        self.block_stack.append(BlockStackEntry(inst.target))

    def SETUP_EXCEPT(self, inst):
        # only exists in python<=3.7
        self.block_stack.append(BlockStackEntry(inst.target))

    def POP_BLOCK(self, inst):
        self.block_stack.pop()

    def SETUP_WITH(self, inst):
        ctx = self.pop()
        if not isinstance(ctx, ContextWrappingVariable):
            unimplemented(f"SETUP_WITH {ctx}")
        self.output.guards.update(ctx.guards)

        if isinstance(self, InstructionTranslator):
            self.block_stack.append(BlockStackEntry(inst.target, len(self.stack), ctx))
        else:
            # can't restore this while inlining
            self.block_stack.append(BlockStackEntry(inst.target))
        self.push(
            WithExitFunctionVariable(
                ctx,
                inst.target,
                **VariableTracker.propagate(ctx),
            )
        )
        self.push(ctx.enter(self))

    def SETUP_FINALLY(self, inst):
        self.block_stack.append(BlockStackEntry(inst.target))

    def BEGIN_FINALLY(self, inst):
        self.push(None)

    def WITH_CLEANUP_START(self, inst):
        exit, exc = self.popn(2)
        assert exc is None
        self.push(exc)
        self.push(exit.call_function(self, [ConstantVariable(None)] * 3, {}))

    def WITH_CLEANUP_FINISH(self, inst):
        self.popn(2)
        self.push(None)

    def END_FINALLY(self, inst):
        tos = self.pop()
        assert tos is None

    def FOR_ITER(self, inst):
        it = self.pop()
        if isinstance(it, ListIteratorVariable):
            self.output.guards.update(it.guards)
            try:
                val, next_iter = it.next_variables()
                self.replace_all(it, next_iter)
                self.push(next_iter)
                self.push(val)
            except StopIteration:
                self.jump(inst)
        else:
            unimplemented(f"FOR_ITER {typestr(it)}")

    def COMPARE_OP(self, inst):
        left, right = self.popn(2)
        left = left.as_specialized(self)
        right = right.as_specialized(self)
        options = VariableTracker.propagate([left, right])
        op = inst.argval
        supported_any = dict(
            itertools.chain(
                supported_tensor_comparison_ops.items(),
                supported_const_comparison_ops.items(),
            )
        )
        if (
            isinstance(
                left,
                (
                    TensorVariable,
                    SymNodeVariable,
                    NNModuleVariable,
                    BaseListVariable,
                    UserDefinedVariable,
                    BaseUserFunctionVariable,
                    ConstDictVariable,
                ),
            )
            and isinstance(right, ConstantVariable)
            and right.value is None
            and op in supported_const_comparison_ops
        ):
            # <non-None> is None
            self.push(
                ConstantVariable(
                    supported_const_comparison_ops[op](object(), right.value), **options
                )
            )
        elif (
            left.is_python_constant()
            and right.is_python_constant()
            and op in supported_any
        ):
            # constant fold
            self.push(
                ConstantVariable(
                    supported_any[op](
                        left.as_python_constant(), right.as_python_constant()
                    ),
                    **options,
                )
            )
        elif op in ("in", "not in"):
            self.push(right.call_method(self, "__contains__", [left], {}))
            if op == "not in":
                self.UNARY_NOT(inst)
        else:
            self.push(
                BuiltinVariable(supported_any[op], **options).call_function(
                    self, [left, right], {}
                )
            )

    def GET_ITER(self, inst):
        self.call_function(BuiltinVariable(iter), [self.pop()], {})

    @break_graph_if_unsupported(push=1)
    def CALL_FUNCTION(self, inst):
        args = self.popn(inst.argval)
        fn = self.pop()
        self.call_function(fn, args, {})

    @break_graph_if_unsupported(push=1)
    def CALL_FUNCTION_EX(self, inst):
        if inst.argval == 0:
            kwargsvars = ConstDictVariable({}, dict)
            argsvars = self.pop()
        elif inst.argval == 1:
            kwargsvars = self.pop()
            argsvars = self.pop()
        else:
            unimplemented("CALL_FUNCTION_EX")
        fn = self.pop()
        self.output.guards.update(argsvars.guards)
        self.output.guards.update(kwargsvars.guards)

        if (
            isinstance(fn, GetAttrVariable)
            and isinstance(fn.obj, TensorVariable)
            and fn.name == "view"
            and isinstance(argsvars, (ConstantVariable, TensorVariable))
        ):
            # Hack to handle a special case in some BERT models. Converts
            # x.view(*shape) into x.view(shape), which is correct for view()
            # but not generally. See test_transpose_for_scores().
            argsvars = TupleVariable([argsvars])

        if not isinstance(
            argsvars, BaseListVariable
        ) and argsvars.has_unpack_var_sequence(self):
            argsvars = TupleVariable(argsvars.unpack_var_sequence(self))

        if not isinstance(argsvars, BaseListVariable) or not isinstance(
            kwargsvars, ConstDictVariable
        ):
            unimplemented(f"non-static call {typestr(argsvars)} {typestr(kwargsvars)}")

        self.call_function(fn, argsvars.items, kwargsvars.items)

    @break_graph_if_unsupported(push=1)
    def CALL_FUNCTION_KW(self, inst):
        argnames = self.pop()
        args = self.popn(inst.argval)
        fn = self.pop()
        assert isinstance(argnames, ConstantVariable)
        argnames = argnames.value
        args, kwargs_list = args[: -len(argnames)], args[-len(argnames) :]
        kwargs = dict(zip(argnames, kwargs_list))
        assert len(kwargs) == len(argnames)
        self.call_function(fn, args, kwargs)
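
    # Note: for CALL_FUNCTION_KW, CPython pushes all argument values followed by
    # a constant tuple of keyword names, so e.g. f(a, b, c=1) arrives with
    # inst.argval == 3 and argnames == ("c",); the handler above zips the
    # trailing len(argnames) values back into a kwargs dict.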

    def LOAD_METHOD(self, inst):
        self.LOAD_ATTR(inst)
        self.push(self.pop())
        self.push(None)

    def CALL_METHOD(self, inst):
        args = self.popn(inst.argval)
        dummy = self.pop()
        assert dummy is None
        fn = self.pop()
        self.call_function(fn, args, {})

    def LOAD_ATTR(self, inst):
        obj = self.pop()
        result = BuiltinVariable(getattr).call_function(
            self, [obj, ConstantVariable(inst.argval)], {}
        )
        self.push(result)

    def STORE_ATTR(self, inst):
        prior = self.copy_graphstate()
        val, obj = self.popn(2)

        if isinstance(obj, NNModuleVariable):
            # We don't allow side effects during export
            # https://github.com/pytorch/torchdynamo/issues/1475
            assert (
                not self.export
            ), f"Mutating module attribute {inst.argval} during export."

        try:
            self.output.guards.update(
                BuiltinVariable(setattr)
                .call_function(self, [obj, ConstantVariable(inst.argval), val], {})
                .guards
            )
            return
        except Unsupported as e:
            if not self.should_compile_partial_graph():
                raise
            log.debug("STORE_ATTR triggered compile", exc_info=True)
            e.remove_from_stats()
            e.add_to_stats("graph_break")
        self.restore_graphstate(prior)

        # break the graph
        self.output.compile_subgraph(
            self, reason=GraphCompileReason("store_attr", [self.frame_summary()])
        )
        self.output.add_output_instructions([inst])
        self.popn(2)
        self.output.add_output_instructions(
            self.create_call_resume_at(self.next_instruction)
        )

    def create_call_resume_at(self, offset):
        raise AssertionError(
            f"create_call_resume_at not overridden by subclass {type(self)}"
        )

    def should_compile_partial_graph(self) -> bool:
        raise AssertionError(
            f"should_compile_partial_graph not overridden by subclass {type(self)}"
        )

    @break_graph_if_unsupported(push=0)
    def STORE_SUBSCR(self, inst):
        val, obj, key = self.popn(3)
        result = obj.call_method(self, "__setitem__", [key, val], {})
        # no result is pushed, so we need to lift the guards to the global set
        self.output.guards.update(result.guards)

    def BUILD_TUPLE(self, inst):
        items = self.popn(inst.argval)
        options = VariableTracker.propagate(items)
        self.push(TupleVariable(items, **options))

    def BUILD_SLICE(self, inst):
        items = self.popn(inst.argval)
        options = VariableTracker.propagate(items)
        self.push(
            SliceVariable(
                [x.as_specialized(self) for x in items],
                **options,
            )
        )

    def BUILD_LIST(self, inst):
        items = self.popn(inst.argval)
        options = VariableTracker.propagate(items)
        self.push(ListVariable(items, mutable_local=MutableLocal(), **options))

    def BUILD_LIST_UNPACK(self, inst, cls=ListVariable):
        seqs = self.popn(inst.argval)
        options = VariableTracker.propagate(seqs)
        items = list()
        for seq in seqs:
            try:
                items.extend(seq.unpack_var_sequence(self))
            except NotImplementedError:
                unimplemented(f"BUILD_LIST_UNPACK {seq}")
        self.push(cls(items, mutable_local=MutableLocal(), **options))

    def BUILD_TUPLE_UNPACK(self, inst):
        self.BUILD_LIST_UNPACK(inst, cls=TupleVariable)

    BUILD_TUPLE_UNPACK_WITH_CALL = BUILD_TUPLE_UNPACK

    def BUILD_MAP(self, inst):
        items = self.popn(inst.argval * 2)
        options = VariableTracker.propagate(items)
        result = dict()
        for k, v in zip(items[::2], items[1::2]):
            assert isinstance(k, (ConstantVariable, EnumVariable)) or (
                isinstance(k, TensorVariable) and k.specialized_value is not None
            )
            result[ConstDictVariable.get_key(k)] = v

        assert len(result) == len(items) / 2
        self.push(
            ConstDictVariable(result, dict, mutable_local=MutableLocal(), **options)
        )

    def BUILD_CONST_KEY_MAP(self, inst):
        keys = self.pop()
        values = self.popn(inst.argval)
        options = VariableTracker.propagate([keys] + values)
        assert isinstance(keys, ConstantVariable)
        keys = keys.value
        assert istype(keys, tuple)
        assert len(keys) == len(values)
        self.push(
            ConstDictVariable(
                dict(zip(keys, values)),
                dict,
                mutable_local=MutableLocal(),
                **options,
            )
        )

    def MAP_ADD(self, inst):
        k, v = self.popn(2)
        assert inst.argval > 0
        obj = self.stack[-inst.arg]
        assert isinstance(obj, ConstDictVariable)
        assert obj.mutable_local
        items = dict(obj.items)
        items[k.as_python_constant()] = v
        self.replace_all(
            obj,
            ConstDictVariable(
                items,
                obj.user_cls,
                **VariableTracker.propagate([obj, k, v]),
            ),
        )

    def LIST_APPEND(self, inst):
        v = self.pop()
        assert inst.argval > 0
        obj = self.stack[-inst.arg]
        assert isinstance(obj, ListVariable)
        assert obj.mutable_local

        # only copy if the new obj contains other mutables
        new_rec_contains = obj.recursively_contains
        if v.recursively_contains or v.mutable_local:
            new_rec_contains = obj.recursively_contains.union(v.recursively_contains)
            if v.mutable_local:
                new_rec_contains.add(v.mutable_local)

        self.replace_all(
            obj,
            ListVariable(
                obj.items + [v],
                recursively_contains=new_rec_contains,
                regen_guards=False,
                **VariableTracker.propagate([obj, v]),
            ),
        )

    def MAKE_FUNCTION(self, inst):
        flags = inst.arg
        old_stack = list(self.stack)
        fn_name = self.pop()
        code = self.pop()
        defaults = None
        closure = None
        annotations = None
        kwdefaults = None

        if flags & 0x08:
            closure = self.pop()
        if flags & 0x04:
            annotations = self.pop()
        if flags & 0x02:
            kwdefaults = self.pop()
        if flags & 0x01:
            defaults = self.pop()

        options = VariableTracker.propagate(old_stack[len(self.stack) :])
        self.push(
            NestedUserFunctionVariable(
                fn_name,
                code,
                self.f_globals,
                defaults,
                kwdefaults,
                annotations,
                closure,
                closure_scope=self,
                **options,
            )
        )

    def UNPACK_SEQUENCE(self, inst):
        seq = self.pop()
        if isinstance(seq, BaseListVariable):
            self.output.guards.update(seq.guards)
            val = seq.unpack_var_sequence(self)
        elif seq.is_python_constant() and isinstance(seq, ConstantVariable):
            val = seq.unpack_var_sequence(self)
        elif isinstance(seq, TensorVariable):
            val = seq.unpack_var_sequence(self, idxes=range(inst.argval))
        elif isinstance(seq, GetAttrVariable) and isinstance(seq.obj, TensorVariable):
            # x, y = a.shape
            proxy = getattr(seq.obj.as_proxy(), seq.name)
            options = VariableTracker.propagate(self)
            val = [wrap_fx_proxy(self, proxy[i], **options) for i in range(inst.argval)]
        else:
            unimplemented(f"UNPACK_SEQUENCE {seq}")
        assert len(val) == inst.argval
        for i in reversed(val):
            self.push(i)

    def UNPACK_EX(self, inst):
        assert 0 <= inst.argval <= 0xFFFF
        prefix = inst.argval & 0xFF  # low byte
        suffix = inst.argval >> 8  # high byte
        seq = self.pop()
        options = VariableTracker.propagate(seq)
        if seq.has_unpack_var_sequence(self):
            vals = list(seq.unpack_var_sequence(self))
            assert len(vals) >= prefix + suffix
            vals_prefix = vals[:prefix]
            vals_list = vals[prefix : len(vals) - suffix]
            vals_suffix = vals[len(vals) - suffix :]
            for item in reversed(vals_suffix):
                self.push(item.add_options(options))
            self.push(TupleVariable(vals_list, **options))
            for item in reversed(vals_prefix):
                self.push(item.add_options(options))
        else:
            unimplemented(f"UNPACK_EX {seq}")
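
    # Note: UNPACK_EX packs the counts around the starred target into one oparg:
    # the low byte counts targets before the star and the high byte counts
    # targets after it, so `a, *b, c = seq` arrives as argval == 0x0101 and the
    # middle slice is pushed above as a TupleVariable between the fixed parts.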

    def NOP(self, inst):
        pass

    def POP_TOP(self, inst):
        self.pop()

    def ROT_TWO(self, inst):
        a = self.pop()
        b = self.pop()
        self.push(a)
        self.push(b)

    def ROT_THREE(self, inst):
        a = self.pop()
        b = self.pop()
        c = self.pop()
        self.push(a)
        self.push(c)
        self.push(b)

    def ROT_FOUR(self, inst):
        a = self.pop()
        b = self.pop()
        c = self.pop()
        d = self.pop()
        self.push(a)
        self.push(d)
        self.push(c)
        self.push(b)

    def DUP_TOP(self, inst):
        a = self.pop()
        self.push(a)
        self.push(a)

    def DUP_TOP_TWO(self, inst):
        a = self.pop()
        b = self.pop()
        self.push(b)
        self.push(a)
        self.push(b)
        self.push(a)

    def FORMAT_VALUE(self, inst):
        flags = inst.arg
        if (flags & 0x04) == 0x04:
            fmt_spec = self.pop()
        else:
            fmt_spec = ConstantVariable("")

        value = self.pop()
        if isinstance(value, SymNodeVariable):
            value = ConstantVariable(str(value.sym_num))
        if (flags & 0x03) == 0x01:
            value = BuiltinVariable(str).call_function(self, [value], {})
        elif (flags & 0x03) == 0x02:
            value = BuiltinVariable(repr).call_function(self, [value], {})
        elif (flags & 0x03) == 0x03:
            value = BuiltinVariable(ascii).call_function(self, [value], {})

        fmt_var = ConstantVariable(
            "{:" + fmt_spec.as_python_constant() + "}"
        ).add_options(fmt_spec)

        self.call_function(BuiltinVariable(str.format), [fmt_var, value], {})
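
    # Note: FORMAT_VALUE's low two flag bits select the conversion (0x00 none,
    # 0x01 str(), 0x02 repr(), 0x03 ascii()) and bit 0x04 means a format spec
    # was pushed, so e.g. f"{x!r:>10}" sets flags 0x02 | 0x04 and is handled
    # above as a symbolic str.format("{:>10}", repr(x)) call.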

    def BUILD_STRING(self, inst):
        result = ""
        for _ in range(inst.arg):
            str_var = self.pop()
            assert isinstance(str_var, ConstantVariable)
            result = str_var.value + result
        self.push(ConstantVariable(value=result))

    def IS_OP(self, inst):
        assert inst.argval == 0 or inst.argval == 1
        if inst.argval == 0:
            new_argval = "is"
        else:
            new_argval = "is not"
        new_inst = create_instruction("COMPARE_OP", argval=new_argval)
        self.COMPARE_OP(new_inst)

    def CONTAINS_OP(self, inst):
        assert inst.argval == 0 or inst.argval == 1
        left, right = self.popn(2)
        op = inst.argval
        self.push(right.call_method(self, "__contains__", [left], {}))
        if op == 1:
            self.UNARY_NOT(inst)

    def LIST_EXTEND(self, inst):
        v = self.pop()
        assert inst.argval > 0
        obj = self.stack[-inst.arg]
        assert isinstance(obj, ListVariable)
        assert obj.mutable_local
        obj.call_method(self, "extend", [v], {})

    def LIST_TO_TUPLE(self, inst):
        self.push(BuiltinVariable(tuple).call_function(self, [self.pop()], {}))

    def DICT_MERGE(self, inst):
        v = self.pop()
        assert inst.argval > 0
        obj = self.stack[-inst.arg]
        assert isinstance(obj, ConstDictVariable)
        assert obj.mutable_local
        obj.call_method(self, "update", [v], {})

    def GEN_START(self, inst):
        self.pop()

    def GET_LEN(self, inst):
        tos = self.stack[-1]
        if tos.is_python_constant():
            self.push(ConstantVariable(len(tos.as_python_constant())))
        else:
            self.push(tos.call_method(self, "__len__", [], {}))

    def MATCH_MAPPING(self, inst):
        tos = self.stack[-1]
        assert isinstance(tos, ConstDictVariable)
        if isinstance(tos.items, collections.abc.Mapping):
            self.push(ConstantVariable(True))
        else:
            self.push(ConstantVariable(False))

    def MATCH_SEQUENCE(self, inst):
        tos = self.stack[-1]
        assert tos.is_python_constant()
        tos_value = tos.as_python_constant()
        if isinstance(tos_value, collections.abc.Sequence) and not isinstance(
            tos_value, (str, bytes, bytearray)
        ):
            self.push(ConstantVariable(True))
        else:
            self.push(ConstantVariable(False))

    def MATCH_KEYS(self, inst):
        tos = self.stack[-1]
        assert tos.is_python_constant()
        keys = tos.as_python_constant()
        tos1 = self.stack[-2]
        assert isinstance(tos1, ConstDictVariable)
        match_obj = tos1.items
        if all(key in match_obj for key in keys):
            self.push(TupleVariable([match_obj[key] for key in keys]))
            self.push(ConstantVariable(True))
        else:
            self.push(ConstantVariable(None))
            self.push(ConstantVariable(False))

    UNARY_POSITIVE = stack_op(operator.pos)
    UNARY_NEGATIVE = stack_op(operator.neg)
    UNARY_NOT = stack_op(operator.not_)
    UNARY_INVERT = stack_op(operator.invert)

    BINARY_POWER = stack_op(operator.pow)
    BINARY_MULTIPLY = stack_op(operator.mul)
    BINARY_MATRIX_MULTIPLY = stack_op(operator.matmul)
    BINARY_FLOOR_DIVIDE = stack_op(operator.floordiv)
    BINARY_TRUE_DIVIDE = stack_op(operator.truediv)
    BINARY_MODULO = stack_op(operator.mod)
    BINARY_REMAINDER = stack_op(operator.mod)
    BINARY_ADD = stack_op(operator.add)
    BINARY_SUBTRACT = stack_op(operator.sub)
    BINARY_SUBSCR = break_graph_if_unsupported(push=1)(stack_op(operator.getitem))
    BINARY_LSHIFT = stack_op(operator.lshift)
    BINARY_RSHIFT = stack_op(operator.rshift)
    BINARY_AND = stack_op(operator.and_)
    BINARY_OR = stack_op(operator.or_)
    BINARY_XOR = stack_op(operator.xor)

    INPLACE_POWER = stack_op(operator.ipow)
    INPLACE_MULTIPLY = stack_op(operator.imul)
    INPLACE_MATRIX_MULTIPLY = stack_op(operator.imatmul)
    INPLACE_FLOOR_DIVIDE = stack_op(operator.ifloordiv)
    INPLACE_TRUE_DIVIDE = stack_op(operator.itruediv)
    INPLACE_MODULO = stack_op(operator.imod)
    INPLACE_REMAINDER = stack_op(operator.imod)
    INPLACE_ADD = stack_op(operator.iadd)
    INPLACE_SUBTRACT = stack_op(operator.isub)
    INPLACE_LSHIFT = stack_op(operator.ilshift)
    INPLACE_RSHIFT = stack_op(operator.irshift)
    INPLACE_AND = stack_op(operator.iand)
    INPLACE_XOR = stack_op(operator.ixor)
    INPLACE_OR = stack_op(operator.ior)

    # Python 3.11 opcodes
    # note: opcodes implemented only as `pass` are intentional no-ops
    def RESUME(self, inst):
        pass
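
    # Note: Python 3.11 collapses the individual BINARY_*/INPLACE_* opcodes into
    # a single BINARY_OP whose oparg indexes dis._nb_ops, where entries look
    # like ("NB_ADD", "+") or ("NB_INPLACE_ADD", "+="); BINARY_OP below strips
    # the "NB_" prefix and dispatches to the matching legacy handler above.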

    def BINARY_OP(self, inst):
        if sys.version_info >= (3, 11):
            opname = dis._nb_ops[inst.arg][0][3:]
            if opname.startswith("INPLACE"):
                return getattr(self, "INPLACE_" + opname[8:])(inst)
            return getattr(self, "BINARY_" + opname)(inst)
        else:
            unimplemented("BINARY_OP requires Python 3.11+")

    def COPY(self, inst):
        self.push(self.stack[-inst.arg])

    def SWAP(self, inst):
        self.stack[-1], self.stack[-inst.arg] = self.stack[-inst.arg], self.stack[-1]

    JUMP_BACKWARD = jump
    JUMP_BACKWARD_NO_INTERRUPT = jump

    POP_JUMP_FORWARD_IF_TRUE = generic_jump(operator.truth, False)
    POP_JUMP_BACKWARD_IF_TRUE = generic_jump(operator.truth, False)
    POP_JUMP_FORWARD_IF_FALSE = generic_jump(operator.not_, False)
    POP_JUMP_BACKWARD_IF_FALSE = generic_jump(operator.not_, False)
    POP_JUMP_FORWARD_IF_NOT_NONE = generic_jump(is_not_none, False)
    POP_JUMP_BACKWARD_IF_NOT_NONE = generic_jump(is_not_none, False)
    POP_JUMP_FORWARD_IF_NONE = generic_jump(is_none, False)
    POP_JUMP_BACKWARD_IF_NONE = generic_jump(is_none, False)
  1300. def CACHE(self, inst):
  1301. pass
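
    # Speculation support: copy_graphstate()/restore_graphstate() snapshot and
    # roll back all mutable tracing state (output graph, symbolic locals, data
    # stack, block stack, instruction pointer) so the translator can back out
    # of a partially traced region, e.g. when an instruction is unsupported.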
    def copy_graphstate(self) -> InstructionTranslatorGraphState:
        """Create a checkpoint of the current state by copying everything"""
        return InstructionTranslatorGraphState(
            self.output.copy_graphstate(),
            collections.OrderedDict(self.symbolic_locals),
            list(self.stack),
            list(self.block_stack),
            self.instruction_pointer,
            self.current_instruction,
            self.next_instruction,
            self.lineno,
        )

    def restore_graphstate(self, state: InstructionTranslatorGraphState):
        """Restore a checkpoint created by self.copy_graphstate()"""
        (
            output_state,
            self.symbolic_locals,
            self.stack,
            self.block_stack,
            self.instruction_pointer,
            self.current_instruction,
            self.next_instruction,
            self.lineno,
        ) = state
        self.output.restore_graphstate(output_state)

    def empty_checkpoint(self):
        if self.checkpoint is None:
            return True
        output_graphstate = self.checkpoint[1][0]
        graphstate = self.checkpoint[1][1:]
        state = (*output_graphstate, *graphstate)
        for obj in state:
            if isinstance(obj, Sized):
                if len(obj) != 0:
                    return False
        return True

    def format_frame_summary(self, additional_stack_frames=None):
        if additional_stack_frames is None:
            additional_stack_frames = []
        return "".join(
            traceback.format_list(
                [self.frame_summary()] + list(reversed(additional_stack_frames))
            )
        )

    def frame_summary(self):
        return traceback.FrameSummary(
            getattr(self.f_code, "co_filename", "<unknown>"),
            self.lineno,
            getattr(self.f_code, "co_name", "<unknown>"),
            lookup_line=False,
        )
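
    # Installs a global weakref to the given dict key object and guards on
    # WEAKREF_ALIVE; the intent appears to be to force recompilation if the
    # referent has been garbage collected by the next call.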
    def store_dict_key(self, name, value):
        self.output.guards.add(
            GlobalWeakRefSource(name).make_guard(GuardBuilder.WEAKREF_ALIVE)
        )
        if name not in self.output.root_globals:
            self.output.install_global(name, weakref.ref(value))

    @property
    def fake_mode(self):
        return self._fake_mode

    def find_symbolic_locals_name(self, tensor_variable):
        for key, value in self.symbolic_locals.items():
            if value is tensor_variable:
                return key
        return None

    def __init__(
        self,
        output: OutputGraph,
        instructions: List[Instruction],
        f_locals: Dict[str, Any],
        f_globals: Dict[str, Any],
        f_builtins: Dict[str, Any],
        code_options: Dict[str, Any],
        symbolic_locals: Dict[str, VariableTracker],
        symbolic_globals: Dict[str, VariableTracker],
        f_code: types.CodeType,
        export: bool,
    ):
        super().__init__()

        # Mutable state checkpointed by copy_graphstate()
        self.output = output
        self.symbolic_locals = symbolic_locals
        self.symbolic_globals = symbolic_globals
        self.stack = []
        self.instruction_pointer = 0
        self.current_instruction = create_instruction("NOP")
        self.next_instruction = None
        self.block_stack = []
        self.lineno = code_options["co_firstlineno"]

        # Properties of the input/output code
        self.instructions: List[Instruction] = instructions
        self.indexof: Dict[int, int] = {id(i): n for n, i in enumerate(instructions)}
        # f_locals is needed for recording accessed locals for replay
        self.f_locals: Dict[str, Any] = f_locals
        self.f_globals: Dict[str, Any] = f_globals
        self.f_builtins: Dict[str, Any] = f_builtins
        self.code_options: Dict[str, Any] = code_options
        self.f_code: types.CodeType = f_code

        # Execution record for replaying errors
        self.exec_recorder = ExecutionRecorder(code=f_code, code_options=code_options)
        # Stack of modules being parsed; the current nn.Module is at the end
        # of the ordered dict
        self.nn_module_stack: Dict[str, str] = {}
        # Flag to indicate whether tracing is used for export
        self.export = export

        self._fake_mode = output.tracing_context.fake_mode
        self.checkpoint = None
        self.random_calls = []
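
        # Note (assumption): in CPython 3.10, generator/coroutine code objects
        # begin with a GEN_START instruction that pops one value, so the stack
        # is seeded with a dummy None for code objects carrying those flags.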
        if sys.version_info >= (3, 10):
            from .resume_execution import (
                CO_ASYNC_GENERATOR,
                CO_COROUTINE,
                CO_GENERATOR,
                CO_ITERABLE_COROUTINE,
            )

            if f_code.co_flags & (
                CO_GENERATOR | CO_COROUTINE | CO_ITERABLE_COROUTINE | CO_ASYNC_GENERATOR
            ):
                self.push(BuiltinVariable(None))


class InstructionTranslator(InstructionTranslatorBase):
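    """Translate the bytecode of the top-level frame being compiled."""
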
    def __init__(
        self,
        instructions: List[Instruction],
        f_code,
        f_locals,
        f_globals,
        f_builtins,
        code_options,
        compiler_fn,
        one_graph,
        export,
        mutated_closure_cell_contents: Set[str],
    ):
        super().__init__(
            output=OutputGraph(f_globals, code_options, compiler_fn, self),
            instructions=instructions,
            f_locals=f_locals,
            f_globals=f_globals,
            f_builtins=f_builtins,
            code_options=code_options,
            symbolic_locals=collections.OrderedDict(),  # set below
            # A global var is inserted only after a STORE_GLOBAL happens to it
            symbolic_globals=collections.OrderedDict(),
            f_code=f_code,
            export=export,
        )
        self.one_graph: bool = one_graph
        self.export = export
        self.mutated_closure_cell_contents = mutated_closure_cell_contents
        if self.export:
            assert (
                self.one_graph
            ), "Export without one graph - something has gone wrong."

        vars = list(code_options["co_varnames"])
        vars.extend(x for x in self.cell_and_freevars() if x not in vars)

        self.symbolic_locals = collections.OrderedDict(
            (
                k,
                VariableBuilder(
                    self,
                    LocalInputSource(k, code_options["co_varnames"].index(k))
                    if k in code_options["co_varnames"]
                    else LocalSource(k),
                )(f_locals[k]),
            )
            for k in vars
            if k in f_locals
        )

        # symbolic_locals contains the mapping from the original f_locals to
        # their VariableTracker objects. During the variable-building phase,
        # each object also picks up its associated guards. At the end, we
        # accumulate these guards.
        #
        # One way of handling these guards would be to accumulate all of them
        # right now. However, many f_locals may never be used in the frame,
        # which would unnecessarily increase guard execution overhead.
        # Therefore, we update output.guards selectively as we run the Python
        # bytecode instruction by instruction.
        #
        # An exception is list/dict variables. Guards related to these
        # variables use indexed access, like TENSOR_MATCH on args[0], and if
        # args is not used in this frame we would miss a LIST_LENGTH check
        # such as len(args) == 2. Missing the LIST_LENGTH check causes
        # problems on the next invocation: if args is not a list, args[0] is a
        # runtime error. Therefore, we recursively add guards for list/dict
        # variables here.
        for val in self.symbolic_locals.values():
            if isinstance(
                val, (ListIteratorVariable, BaseListVariable, ConstDictVariable)
            ):
                local_guards = VariableTracker.propagate(val)["guards"]
                index_guards = [
                    guard
                    for guard in local_guards
                    if guard.create_fn
                    in (
                        GuardBuilder.LIST_LENGTH,
                        GuardBuilder.DICT_KEYS,
                        GuardBuilder.ODICT_KEYS,
                        GuardBuilder.TUPLE_ITERATOR_LEN,
                    )
                ]
                self.output.guards.update(index_guards)
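
        # Record the ids of free variable contents so match_nested_cell() can
        # later identify these cells when inlining nested functions.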
        self._freevars_ids = dict()
        for name in self.code_options["co_freevars"]:
            if name in f_locals:
                self._freevars_ids[name] = id(f_locals[name])

    def run(self):
        _step_logger()(logging.INFO, f"torchdynamo start tracing {self.f_code.co_name}")
        super().run()

    def match_nested_cell(self, name, cell):
        """Match a cell in this method to one in a function we are inlining"""
        value = cell.cell_contents
        # TODO(jansel): check the id of the cell rather than the contents
        if id(value) != self._freevars_ids.get(name):
            return None
        return self.symbolic_locals[name]

    def should_compile_partial_graph(self):
        return all(b.can_restore() for b in self.block_stack) and not self.one_graph
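
    # Builds the bytecode that resumes execution after a graph break: a
    # generated continuation function (roughly __resume_at_<offset>(*stack,
    # *argnames)) is installed as a global, loaded, called with the live stack
    # and locals, and its result returned. "<offset>" here is illustrative;
    # the real name comes from unique_id() below.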
    def create_call_resume_at(self, inst):
        self.instruction_pointer = None

        if inst.opname == "RETURN_VALUE":
            return [create_instruction("RETURN_VALUE")]

        reads = livevars_analysis(self.instructions, inst)
        argnames = tuple(
            k
            for k in self.symbolic_locals.keys()
            if k in reads and k not in self.cell_and_freevars()
        )
        nargs = len(self.stack) + len(argnames)

        name = unique_id(f"__resume_at_{inst.offset}")

        new_code: types.CodeType = ContinueExecutionCache.lookup(
            self.f_code,
            self.lineno,
            inst.offset,
            len(self.stack),
            argnames,
            tuple(b.resume_fn() for b in self.block_stack),
        )

        cg = PyCodegen(self)

        if new_code.co_freevars:
            cg.make_function_with_closure(name, new_code, len(self.stack))
        else:
            self.output.install_global(
                name, types.FunctionType(new_code, self.f_globals, name)
            )
            cg.extend_output(cg.load_function_name(name, len(self.stack)))

        cg.extend_output([cg.create_load(k) for k in argnames])
        cg.extend_output(
            [
                create_instruction("CALL_FUNCTION", nargs),
                create_instruction("RETURN_VALUE"),
            ]
        )
        return cg.get_instructions()

    def RETURN_VALUE(self, inst):
        if self.output.count_calls() == 0:
            raise exc.SkipFrame("because no content in function call")
        self.instruction_pointer = None
        _step_logger()(
            logging.INFO,
            f"torchdynamo done tracing {self.f_code.co_name} (RETURN_VALUE)",
        )
        log.debug("RETURN_VALUE triggered compile")
        self.output.compile_subgraph(
            self, reason=GraphCompileReason("return_value", [self.frame_summary()])
        )
        self.output.add_output_instructions([create_instruction("RETURN_VALUE")])


class InliningInstructionTranslator(InstructionTranslatorBase):
    """Trace and inline a called method"""

    symbolic_result: Optional[TensorVariable]
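
    # Redirecting the "unimplemented" counter to "inline_call" means failures
    # inside an inlined function are tallied separately from top-level ones.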
    @classmethod
    def inline_call(cls, parent, func, args, kwargs):
        with patch.dict(counters, {"unimplemented": counters["inline_call"]}):
            return cls.inline_call_(parent, func, args, kwargs)

    @staticmethod
    def inline_call_(parent, func, args, kwargs):
        assert isinstance(
            func,
            (UserFunctionVariable, NestedUserFunctionVariable),
        )
        if func.has_self():
            unimplemented("inline with __self__")

        if func.get_name() == "patched_init":
            unimplemented("Patched init cannot be inlined.")

        try:
            if id(func.get_function()) in allowed_functions._disallowed_function_ids:
                unimplemented(f"inlining disallowed: {func.get_function()}")
        except NotImplementedError:
            pass  # closures

        if skipfiles.check(
            func.get_filename()
        ) and not skipfiles.is_torch_inline_allowed(func.get_filename()):
            unimplemented(
                f"inline in skipfiles: {func.fn.__qualname__} | {func.get_name()} {func.get_filename()}"
            )

        try:
            sub_locals, closure_cells = func.bind_args(parent, args, kwargs)
        except TypeError as e:
            log.warning(
                f"{func.get_filename()} {func.get_function()} {args} {kwargs} {e}"
            )
            unimplemented("arg mismatch inlining")

        for v in itertools.chain(sub_locals.values(), closure_cells.values()):
            if not isinstance(v, VariableTracker):
                unimplemented(f"unconverted arg {v}")

        code: types.CodeType = func.get_code()
        if code.co_name in ("__setitem__", "__setattr__"):
            unimplemented(f"inline {code.co_name}")

        log.debug(f"INLINING {code} \n {dis.Bytecode(code).dis()} \n")

        tracer: InliningInstructionTranslator
        if is_generator(code):
            tracer = InliningGeneratorInstructionTranslator(
                parent, code, sub_locals, parent.symbolic_globals, closure_cells, func
            )
        else:
            tracer = InliningInstructionTranslator(
                parent, code, sub_locals, parent.symbolic_globals, closure_cells, func
            )

        try:
            tracer.run()
        except exc.SkipFrame as e:
            msg = f"SKIPPED INLINING {code}: {e}"
            log.debug(msg)
            raise Unsupported(msg) from e
        except Exception:
            log.debug(f"FAILED INLINING {code}")
            raise
        assert tracer.symbolic_result is not None
        func.export_freevars(parent, tracer)

        if tracer.f_globals is parent.f_globals:
            # Merge symbolic_globals back if parent and child are in the same namespace
            parent.symbolic_globals.update(tracer.symbolic_globals)

        log.debug(f"DONE INLINING {code}")

        if is_generator(code):
            assert isinstance(tracer, InliningGeneratorInstructionTranslator)
            assert tracer.symbolic_result.as_python_constant() is None
            return ListIteratorVariable(
                tracer.generated_items,
                mutable_local=MutableLocal(),
                **VariableTracker.propagate(tracer.symbolic_result),
            )
        else:
            return tracer.symbolic_result

    def __init__(
        self,
        parent: InstructionTranslatorBase,
        code: types.CodeType,
        symbolic_locals: Dict[str, VariableTracker],
        symbolic_globals: Dict[str, VariableTracker],
        closure_cells: Dict[str, VariableTracker],
        funcvar: BaseUserFunctionVariable,
    ):
        f_globals = funcvar.get_globals()
        f_builtins = f_globals["__builtins__"]
        if not isinstance(f_builtins, dict):
            f_builtins = f_builtins.__dict__
        super().__init__(
            output=parent.output,
            f_locals={},
            f_globals=f_globals,
            f_builtins=f_builtins,
            symbolic_locals=symbolic_locals,
            symbolic_globals=symbolic_globals,
            instructions=cleaned_instructions(code),
            code_options={k: getattr(code, k) for k in dir(code)},
            f_code=code,
            export=parent.export,
        )
        self.parent = parent
        self.symbolic_result = None
        self.closure_cells = closure_cells
        self.nn_module_stack = parent.nn_module_stack.copy()

    @property
    def fake_mode(self):
        return self.parent.fake_mode

    def STORE_DEREF(self, inst):
        if inst.argval in self.closure_cells:
            cell = self.closure_cells[inst.argval]
            val = self.pop()
            if isinstance(cell, ClosureVariable):
                self.output.root_tx.symbolic_locals[cell.name] = val
            else:
                self.output.side_effects.store_cell(cell, val)
        else:
            maybe_cell = self.symbolic_locals.get(inst.argval)
            if isinstance(
                maybe_cell,
                variables.NewCellVariable,
            ):
                self.output.side_effects.store_cell(
                    self.symbolic_locals[inst.argval], self.pop()
                )
            else:
                if (
                    maybe_cell is not None
                    and maybe_cell.source.name()
                    not in self.parent.mutated_closure_cell_contents
                ):
                    # Why is the source name here unique?
                    # mutated_closure_cell_contents is a per-frame
                    # concept, and sources identify, e.g., particular
                    # locals from the frame. If you had two locals,
                    # they'll get different source names, and therefore
                    # differ here.
                    self.parent.mutated_closure_cell_contents.add(
                        maybe_cell.source.name()
                    )
                    raise exc.RestartAnalysis()
                unimplemented("write to __closure__ while inlining")

    def LOAD_DEREF(self, inst):
        if inst.argval in self.closure_cells:
            cell = self.closure_cells[inst.argval]
            if isinstance(cell, ClosureVariable):
                self.push(self.output.root_tx.symbolic_locals[cell.name])
            else:
                self.push(self.output.side_effects.load_cell(cell))
        else:
            maybe_sym_local = self.symbolic_locals.get(inst.argval, None)
            if isinstance(maybe_sym_local, variables.NewCellVariable):
                self.push(self.output.side_effects.load_cell(maybe_sym_local))
            else:
                super().LOAD_DEREF(inst)

    def LOAD_CLOSURE(self, inst):
        assert inst.argval in self.cell_and_freevars()
        self.push(self.closure_cells[inst.argval])

    def replace_all(self, oldvar: VariableTracker, newvar: VariableTracker):
        newvar = super().replace_all(oldvar, newvar)
        # recursively check and update parent's locals and stack in case oldvar is from parent
        translator: InstructionTranslatorBase = self
        while hasattr(translator, "parent"):
            translator = translator.parent  # type: ignore[attr-defined]
            translator.update_locals_and_stack(oldvar, newvar)
        return newvar

    def should_compile_partial_graph(self):
        return False  # inlining functions is all-or-nothing

    def create_call_resume_at(self, inst):
        unimplemented("can't resume while inlining")

    def RETURN_VALUE(self, inst):
        self.symbolic_result = self.pop()
        self.instruction_pointer = None


class InliningGeneratorInstructionTranslator(InliningInstructionTranslator):
    generated_items: List[VariableTracker]

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.generated_items = []
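
    # Note (assumption): when a generator is resumed via next() rather than
    # send(), the yield expression evaluates to None inside the generator, so
    # pushing a constant None likely matches plain-iteration semantics.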
    def YIELD_VALUE(self, inst: Instruction):
        self.generated_items.append(self.pop())
        # TODO(jansel): figure out why this is needed, it isn't in the docs for YIELD_VALUE
        self.push(ConstantVariable(None))