123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140114111421143114411451146114711481149115011511152115311541155115611571158115911601161116211631164116511661167116811691170117111721173117411751176117711781179118011811182118311841185118611871188118911901191119211931194119511961197119811991200120112021203120412051206120712081209121012111212121312141215121612171218121912201221122212231224122512261227122812291230123112321233123412351236123712381239124012411242124312441245124612471248124912501251125212531254125512561257125812591260126112621263126412651266126712681269127012711272127312741275127612771278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177717781779178017811782178317841785178617871788178917901791179217931794179517961797179817991800180118021803180418051806180718081809181018111812181318141815181618171818181918201821182218231824182518261827182818291830183118321833183418351836183718381839184018411842184318441845184618471848184918501851185218531854185518561857185818591860186118621863186418651866186718681869187018711872187318741875187618771878187918801881188218831884188518861887188818891890189118921893189418951896189718981899190019011902190319041905190619071908190919101911191219131914191519161917191819191920192119221923192419251926192719281929193019311932193319341935193619371938193919401941194219431944194519461947194819491950195119521953195419551956195719581959196019611962196319641965196619671968196919701971197219731974197519761977197819791980198119821983198419851986198719881989199019911992199319941995199619971998199920002001200220032004200520062007 |
- import collections
- import dataclasses
- import dis
- import functools
- import importlib
- import inspect
- import itertools
- import logging
- import operator
- import sys
- import traceback
- import types
- import typing
- import weakref
- from collections.abc import Sized
- from typing import Any, Callable, Dict, List, NamedTuple, Optional, Set, Tuple
- from unittest.mock import patch
- import torch
- from torch._guards import Checkpointable
- from . import (
- allowed_functions,
- config,
- exc,
- logging as torchdynamo_logging,
- side_effects,
- skipfiles,
- variables,
- )
- from .allowed_functions import is_allowed, is_builtin_callable, is_builtin_constant
- from .bytecode_analysis import JUMP_OPNAMES, livevars_analysis
- from .bytecode_transformation import (
- cleaned_instructions,
- create_instruction,
- create_jump_absolute,
- Instruction,
- is_generator,
- unique_id,
- )
- from .codegen import PyCodegen
- from .exc import BackendCompilerFailed, unimplemented, Unsupported
- from .guards import GuardBuilder
- from .output_graph import GraphCompileReason, OutputGraph, OutputGraphState
- from .replay_record import DummyModule, ExecutionRecorder
- from .resume_execution import ContinueExecutionCache, ReenterWith
- from .source import (
- AttrSource,
- GetItemSource,
- GlobalSource,
- GlobalWeakRefSource,
- LocalInputSource,
- LocalSource,
- )
- from .utils import counters, graph_break_dup_warning_checker, istype, proxy_args_kwargs
- from .variables.base import MutableLocal, typestr, VariableTracker
- from .variables.builder import VariableBuilder, wrap_fx_proxy
- from .variables.builtin import BuiltinVariable
- from .variables.constant import ConstantVariable, EnumVariable
- from .variables.dicts import ConstDictVariable
- from .variables.functions import (
- BaseUserFunctionVariable,
- NestedUserFunctionVariable,
- UserFunctionVariable,
- UserMethodVariable,
- )
- from .variables.lists import (
- BaseListVariable,
- ListIteratorVariable,
- ListVariable,
- SliceVariable,
- TupleVariable,
- )
- from .variables.misc import (
- ClosureVariable,
- ContextWrappingVariable,
- GetAttrVariable,
- GradModeVariable,
- PythonModuleVariable,
- UnknownVariable,
- WithExitFunctionVariable,
- )
- from .variables.nn_module import NNModuleVariable
- from .variables.tensor import (
- supported_const_comparison_ops,
- supported_tensor_comparison_ops,
- SymNodeVariable,
- TensorVariable,
- )
- from .variables.torch import TorchVariable
- from .variables.user_defined import UserDefinedObjectVariable, UserDefinedVariable
- log = logging.getLogger(__name__)
- @functools.lru_cache(None)
- def _step_logger():
- return torchdynamo_logging.get_step_logger(log)
- @dataclasses.dataclass
- class BlockStackEntry:
- target: Instruction
- stack_index: Optional[int] = None
- with_context: ContextWrappingVariable = None
- def can_restore(self):
- return self.with_context is not None
- def resume_fn(self):
- assert self.stack_index is not None
- if self.with_context and self.with_context.target_values:
- return ReenterWith(self.stack_index, tuple(self.with_context.target_values))
- else:
- return ReenterWith(self.stack_index)
- def exit(self, tx):
- return self.with_context.exit(tx)
- class InstructionTranslatorGraphState(NamedTuple):
- output: OutputGraphState
- symbolic_locals: Dict[str, VariableTracker]
- stack: List[VariableTracker]
- block_stack: List[BlockStackEntry]
- instruction_pointer: Optional[int]
- current_instruction: Instruction
- next_instruction: Optional[Instruction]
- lineno: int
- def diff(self, other: "InstructionTranslatorGraphState") -> Optional[str]:
- for k in self._fields:
- if k == "output":
- return self.output.diff(other.output, prefix=f"{k}.")
- sv = getattr(self, k)
- ov = getattr(other, k)
- if sv != ov:
- return f"{k} mismatch: {sv} != {ov}"
- return None
- def stack_op(fn: typing.Callable[..., object]):
- nargs = len(inspect.signature(fn).parameters)
- fn_var = BuiltinVariable(fn)
- @functools.wraps(fn)
- def impl(self: "InstructionTranslatorBase", inst: Instruction):
- self.push(fn_var.call_function(self, self.popn(nargs), {}))
- return impl
- def _detect_and_normalize_assert_statement(
- self: "InstructionTranslatorBase",
- truth_fn: typing.Callable[[object], bool],
- push: bool,
- ):
- # Detect if this jump instruction is assert and normalize the assert
- # by pushing dummy error message when nothing is given.
- #
- # Python 3.9 assertion is in following format:
- # 18 POP_JUMP_IF_TRUE 28
- # 20 LOAD_ASSERTION_ERROR
- # 22 LOAD_CONST 3 ('Assert message') -> optional instruction
- # 24 CALL_FUNCTION 1 -> optional instruction
- # 26 RAISE_VARARGS
- #
- # Python 3.8 assertion is in following format:
- # 18 POP_JUMP_IF_TRUE 28
- # 20 LOAD_GLOBAL 0 (Assertion type)
- # 22 LOAD_CONST 3 ('Assert message') -> optional instruction
- # 24 CALL_FUNCTION 1 -> optional instruction
- # 26 RAISE_VARARGS 1
- if (truth_fn is not operator.truth) or push:
- return False
- assert isinstance(self.instruction_pointer, int)
- current_instruction_pointer = self.instruction_pointer
- inst = self.instructions[current_instruction_pointer]
- # Detect LOAD_ASSERTION_ERROR or LOAD_GLOBAL 0
- if sys.version_info < (3, 9):
- if inst.opname != "LOAD_GLOBAL" or inst.argval != "AssertionError":
- return False
- else:
- if inst.opname != "LOAD_ASSERTION_ERROR":
- return False
- current_instruction_pointer += 1
- if current_instruction_pointer >= len(self.instructions):
- return False
- inst = self.instructions[current_instruction_pointer]
- has_error_msg = False
- # DETECT RAISE_VARARGS or LOAD CONST
- if inst.opname == "LOAD_CONST":
- if not isinstance(inst.argval, str):
- return False
- self.LOAD_CONST(inst)
- has_error_msg = True
- # if it is LOAD_CONSTANT, it must be followed by CALL_FUNCTION
- current_instruction_pointer += 1
- if current_instruction_pointer >= len(self.instructions):
- return False
- inst = self.instructions[current_instruction_pointer]
- if inst.opname != "CALL_FUNCTION":
- return False
- # CALL_FUNCTION should be followed by RAISE_VARARGS
- current_instruction_pointer += 1
- if current_instruction_pointer >= len(self.instructions):
- return False
- inst = self.instructions[current_instruction_pointer]
- if inst.opname != "RAISE_VARARGS":
- return False
- if not has_error_msg:
- # Push dummy value instead of error message
- self.push(ConstantVariable("assertion error"))
- return True
- def generic_jump(truth_fn: typing.Callable[[object], bool], push: bool):
- def inner(self: "InstructionTranslatorBase", inst: Instruction):
- value: VariableTracker = self.pop()
- self.output.guards.update(value.guards)
- if (
- config.rewrite_assert_with_torch_assert
- and _detect_and_normalize_assert_statement(self, truth_fn, push)
- ):
- error_msg: VariableTracker = self.pop()
- self.output.guards.update(error_msg.guards)
- # Skip over things like `assert True`
- if value.is_python_constant() and bool(value.as_python_constant()):
- self.jump(inst)
- return
- # Manually insert torch._assert instead of python assert and jump over
- # assert related instructions as we don't need them anymore.
- self.output.create_proxy(
- "call_function",
- torch._assert,
- *proxy_args_kwargs((value, error_msg), {}),
- )
- self.jump(inst)
- return
- if value.is_python_constant():
- if truth_fn(value.as_python_constant()):
- push and self.push(value)
- self.jump(inst)
- elif (
- isinstance(value, (TensorVariable)) and self.should_compile_partial_graph()
- ):
- # compile a partial subgraph prefix then jump into user code
- if self.has_backedge():
- msg = (
- "Skipping frame because there is a graph break in a for/while loop"
- )
- log.debug(msg)
- raise exc.SkipFrame(msg)
- self.push(value)
- log.debug("generic_jump triggered compile")
- self.output.compile_subgraph(
- self,
- reason=GraphCompileReason(
- f"generic_jump {typestr(value)}", [self.frame_summary()]
- ),
- )
- self.pop()
- if_next = self.create_call_resume_at(self.next_instruction)
- push and self.push(value)
- if_jump = self.create_call_resume_at(inst.target)
- self.output.add_output_instructions(
- [(create_instruction(inst.opname, target=if_jump[0]))]
- + if_next
- + if_jump
- )
- elif isinstance(value, NNModuleVariable):
- # Equivant of "self.nn_module is not None"
- if truth_fn(value):
- push and self.push(value)
- self.jump(inst)
- elif isinstance(value, UserDefinedObjectVariable):
- x = value.var_getattr(self, "__bool__")
- # __bool__ is function
- if isinstance(x, UserMethodVariable):
- state = self.copy_graphstate()
- result = x.call_function(self, [], {})
- if isinstance(result, ConstantVariable) and isinstance(
- result.value, bool
- ):
- self.output.guards.update(result.guards)
- if truth_fn(result.value):
- push and self.push(value)
- self.jump(inst)
- else:
- # rollback to the state before the __bool__ inline
- self.restore_graphstate(state)
- unimplemented(
- "generic_jump on UserDefined with __bool__ returning non-constant"
- )
- # __bool__ is non-function or not existed in the user defined object
- else:
- if truth_fn(True):
- push and self.push(value)
- self.jump(inst)
- elif not isinstance(value, TensorVariable) and value.has_unpack_var_sequence(
- self
- ):
- if truth_fn(len(value.unpack_var_sequence(self))):
- push and self.push(value)
- self.jump(inst)
- elif isinstance(value, SymNodeVariable):
- eval_result = value.evaluate_expr(self.output)
- if truth_fn(eval_result):
- push and self.push(value)
- self.jump(inst)
- else:
- unimplemented(f"generic_jump {typestr(value)}")
- return inner
- explain = False
- def break_graph_if_unsupported(*, push):
- def decorator(inner_fn):
- @functools.wraps(inner_fn)
- def wrapper(self: "InstructionTranslatorBase", inst: Instruction):
- state = self.copy_graphstate()
- reason = None
- try:
- return inner_fn(self, inst)
- except Unsupported as excp:
- if self.has_backedge() and self.should_compile_partial_graph():
- msg = "Skipping frame because there is a graph break in a for/while loop"
- log.debug(msg)
- raise exc.SkipFrame(msg) from excp
- if not self.should_compile_partial_graph():
- raise
- log.debug("break_graph_if_unsupported triggered compile", exc_info=True)
- user_stack = [self.frame_summary()] + list(reversed(excp.real_stack))
- user_stack_formatted = "".join(traceback.format_list(user_stack))
- frame_loc = (user_stack[-1].filename, user_stack[-1].lineno)
- # torch._dynamo.explain() formats this a little nicer, and presents a slightly
- # more actionable user code pointer
- if (
- config.print_graph_breaks
- and not explain
- and graph_break_dup_warning_checker.add(frame_loc)
- ):
- log.warning(
- f"Graph break: {excp} from user code at {user_stack_formatted}"
- )
- excp.remove_from_stats()
- excp.add_to_stats("graph_break")
- reason = GraphCompileReason(excp.msg, user_stack)
- self.restore_graphstate(state)
- self.output.compile_subgraph(self, reason=reason)
- self.popn(push - dis.stack_effect(inst.opcode, inst.arg))
- for _ in range(push):
- self.push(UnknownVariable())
- resume_call_insts = self.create_call_resume_at(self.next_instruction)
- # Check if there is a block stack entry with GradModeVariable. And
- # wrap the instruction causing the graph break inside a try..finally
- # block. See more details at
- # https://github.com/pytorch/torchdynamo/issues/207
- cleanup = []
- if len(self.block_stack) == 1 and isinstance(
- self.block_stack[0].with_context, GradModeVariable
- ):
- ctx_variable = self.block_stack[0].with_context
- cg = PyCodegen(self)
- setup_finally, cleanup = ctx_variable.reconstruct(
- cg, resume_call_insts[0]
- )
- self.output.add_output_instructions(setup_finally)
- self.output.add_output_instructions([inst])
- # Add the cleanup instructions from try..finally block
- self.output.add_output_instructions(cleanup)
- self.output.add_output_instructions(
- resume_call_insts,
- )
- return wrapper
- return decorator
- def is_none(x):
- return x is None
- def is_not_none(x):
- return x is not None
- class InstructionTranslatorBase(Checkpointable[InstructionTranslatorGraphState]):
- output: OutputGraph
- symbolic_locals: Dict[str, VariableTracker]
- symbolic_globals: Dict[str, VariableTracker]
- stack: List[VariableTracker]
- instruction_pointer: Optional[int]
- current_instruction: Instruction
- next_instruction: Optional[Instruction]
- block_stack: List[BlockStackEntry]
- lineno: int
- mutated_closure_cell_contents: Set[str]
- checkpoint: Optional[Tuple[Instruction, InstructionTranslatorGraphState]]
- random_calls: List[
- Tuple[Callable[..., object], Tuple[object, ...], Dict[str, object]]
- ]
- def has_backedge(self):
- cur_offset = self.current_instruction.offset
- assert self.instruction_pointer is not None
- for inst in self.instructions[self.instruction_pointer :]:
- if inst.opname in JUMP_OPNAMES:
- jump_offset = inst.argval
- if jump_offset < cur_offset:
- return True
- return False
- def cell_and_freevars(self):
- if not hasattr(self, "_cell_and_freevars"):
- self._cell_and_freevars = tuple(
- self.code_options["co_cellvars"] or []
- ) + tuple(self.code_options["co_freevars"] or [])
- return self._cell_and_freevars
- def prune_dead_locals(self):
- reads = livevars_analysis(self.instructions, self.current_instruction)
- # implicit use by super()
- # reads = reads | {"__class__"}
- # output variables?
- reads = reads | set(self.cell_and_freevars())
- self.symbolic_locals = collections.OrderedDict(
- [(k, v) for k, v in self.symbolic_locals.items() if k in reads]
- )
- self.output.side_effects.prune_dead_object_new(self)
- def call_function(
- self,
- fn: VariableTracker,
- args: List[VariableTracker],
- kwargs: Dict[str, VariableTracker],
- ):
- assert isinstance(fn, VariableTracker)
- assert isinstance(args, list)
- assert isinstance(kwargs, dict)
- assert all(
- isinstance(x, VariableTracker)
- for x in itertools.chain(args, kwargs.values())
- )
- self.push(fn.call_function(self, args, kwargs))
- def update_locals_and_stack(self, oldvar: VariableTracker, newvar: VariableTracker):
- def repl(v: VariableTracker):
- if v.mutable_local is oldvar.mutable_local:
- return newvar
- return v
- def skip(v: VariableTracker):
- return oldvar.mutable_local not in v.recursively_contains
- cache: Dict[int, Tuple[object, object]] = dict()
- self.output.side_effects.apply(repl, cache, skip_fn=skip)
- self.stack = [
- VariableTracker.apply(repl, x, cache, skip_fn=skip) for x in self.stack
- ]
- for k, x in self.symbolic_locals.items():
- self.symbolic_locals[k] = VariableTracker.apply(
- repl, x, cache, skip_fn=skip
- )
- def replace_all(self, oldvar: VariableTracker, newvar: VariableTracker):
- if isinstance(oldvar.mutable_local, side_effects.MutableSideEffects):
- newvar = self.output.side_effects.mutation(oldvar, newvar)
- else:
- assert isinstance(oldvar.mutable_local, variables.base.MutableLocal)
- newvar = newvar.clone(mutable_local=variables.base.MutableLocal())
- self.update_locals_and_stack(oldvar, newvar)
- return newvar
- def inline_user_function_return(self, fn, args, kwargs):
- """
- A call to some user defined function by inlining it.
- """
- state = self.copy_graphstate()
- try:
- result = InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
- self.output.guards.update(fn.guards)
- return result
- except Exception:
- self.restore_graphstate(state)
- raise
- def step(self):
- """Process exactly one instruction, return False we should exit"""
- assert isinstance(self.instruction_pointer, int)
- inst = self.instructions[self.instruction_pointer]
- self.current_instruction = inst
- self.instruction_pointer += 1
- if self.instruction_pointer < len(self.instructions):
- self.next_instruction = self.instructions[self.instruction_pointer]
- else:
- self.instruction_pointer = None
- self.next_instruction = None
- if inst.starts_line and self.lineno != inst.starts_line:
- self.lineno = inst.starts_line
- log.debug(f"TRACE starts_line {self.f_code.co_filename}:{self.lineno}")
- if len(self.stack) == 0 and self.should_compile_partial_graph():
- self.checkpoint = inst, self.copy_graphstate()
- log.debug(f"TRACE {inst.opname} {inst.argval} {self.stack}")
- try:
- if not hasattr(self, inst.opname):
- unimplemented(f"missing: {inst.opname}")
- getattr(self, inst.opname)(inst)
- return inst.opname != "RETURN_VALUE"
- except BackendCompilerFailed:
- raise
- except Unsupported as exc:
- exc.real_stack.append(self.frame_summary())
- if self.empty_checkpoint():
- raise
- log.debug("step triggered compile", exc_info=True)
- except Exception as exc:
- real_stack = getattr(exc, "real_stack", [])
- real_stack.append(self.frame_summary())
- exc.real_stack = real_stack # type: ignore[attr-defined]
- raise
- # generate code from checkpoint
- assert not self.output.output_instructions
- assert self.checkpoint is not None
- continue_inst, state = self.checkpoint
- self.restore_graphstate(state)
- self.output.compile_subgraph(
- self,
- partial_convert=True,
- reason=GraphCompileReason("step_unsupported", [self.frame_summary()]),
- )
- self.output.add_output_instructions(
- [create_jump_absolute(continue_inst)] + self.instructions
- )
- def run(self):
- try:
- self.output.push_tx(self)
- while (
- self.instruction_pointer is not None
- and not self.output.should_exit
- and self.step()
- ):
- pass
- except BackendCompilerFailed:
- raise
- except Exception as e:
- if config.replay_record_enabled:
- e.exec_record = self.exec_recorder.get_record() # type: ignore[attr-defined]
- raise
- finally:
- self.output.pop_tx()
- # Cleanup the outputGraph to delete the held tensors. We perform the
- # cleanup only for InstructionTranslator and not
- # InliningInstructionTranslator. The InliningInstructionTranslator
- # mutates the output object and is restored to original state if
- # there was an exception.
- if isinstance(self, InstructionTranslator):
- self.output.cleanup()
- def push(self, val: Optional[VariableTracker]):
- assert val is None or isinstance(
- val, VariableTracker
- ), f"push expects VariableTracker, got {typestr(val)}"
- self.stack.append(val)
- def push_many(self, vals: List[VariableTracker]):
- for val in vals:
- self.push(val)
- def pop(self) -> VariableTracker:
- return self.stack.pop()
- def popn(self, n: int) -> List[VariableTracker]:
- assert n >= 0
- return list(reversed([self.pop() for _ in range(n)]))
- def LOAD_FAST(self, inst):
- name = inst.argval
- if name in self.f_locals and config.replay_record_enabled:
- self.exec_recorder.add_local_var(name, self.f_locals[name])
- if name.startswith(".") and name not in self.symbolic_locals:
- # This happens in dict/list comprehensions
- name = name.replace(".", "implicit")
- assert name not in self.cell_and_freevars()
- if name not in self.symbolic_locals:
- unimplemented("undefined LOAD_FAST")
- self.push(self.symbolic_locals[name])
- if name.startswith("___stack"):
- self.symbolic_locals.pop(name)
- def LOAD_DEREF(self, inst):
- assert inst.argval in self.cell_and_freevars()
- if inst.argval in self.f_locals and config.replay_record_enabled:
- self.exec_recorder.add_local_var(inst.argval, self.f_locals[inst.argval])
- if inst.argval not in self.symbolic_locals:
- unimplemented(f"undefined LOAD_DEREF {inst.argval}")
- self.push(self.symbolic_locals[inst.argval])
- def STORE_FAST(self, inst):
- self.symbolic_locals[inst.argval] = self.pop()
- def DELETE_FAST(self, inst):
- del self.symbolic_locals[inst.argval]
- STORE_DEREF = STORE_FAST
- def LOAD_CLOSURE(self, inst):
- self.push(ClosureVariable(name=inst.argval))
- def LOAD_CONST(self, inst):
- self.push(ConstantVariable(value=inst.argval))
- def get_global_source(self, name):
- if self.output.root_globals is self.f_globals:
- source = GlobalSource(name)
- else:
- if "__name__" in self.f_globals:
- source = AttrSource(
- self.import_source(self.f_globals["__name__"]), name
- )
- else:
- mangled_name = f"___unnamed_scope_{id(self.f_globals)}"
- if mangled_name not in self.output.root_globals:
- self.output.install_global(mangled_name, self.f_globals)
- source = GetItemSource(GlobalSource(mangled_name), name)
- return source
- def LOAD_GLOBAL(self, inst):
- name = inst.argval
- if config.replay_record_enabled:
- if name in self.f_globals:
- self.exec_recorder.add_global_var(name, self.f_globals[name])
- else:
- assert name in self.f_builtins
- self.exec_recorder.builtins[name] = self.f_builtins[name]
- if name in self.symbolic_globals:
- variable = self.output.side_effects[self.symbolic_globals[name]]
- self.push(self.output.side_effects.load_global(variable, name))
- return
- try:
- value = self.f_globals[name]
- except KeyError:
- return self.load_builtin(inst)
- source = self.get_global_source(name)
- self.push(VariableBuilder(self, source)(value))
- def STORE_GLOBAL(self, inst):
- value = self.pop()
- name = inst.argval
- source = self.get_global_source(name)
- if name not in self.symbolic_globals:
- self.symbolic_globals[name] = object() # sentinel object
- variable = self.output.side_effects.track_global_existing(
- source, self.symbolic_globals[name]
- )
- self.output.side_effects.store_global(variable, name, value)
- def import_source(self, module_name):
- """Create an alias to a module for use in guards"""
- if "torch_package" in module_name:
- value = torch.package.package_importer._package_imported_modules[
- module_name
- ]
- alias = (
- module_name.replace(">", "_").replace("<", "_").replace(".", "_dot_")
- )
- else:
- value = importlib.import_module(module_name)
- alias = f"__import_{module_name.replace('.', '_dot_')}"
- f_globals = self.output.root_globals
- assert alias not in f_globals or f_globals[alias] is value
- f_globals[alias] = value
- self.output.update_co_names(alias)
- return GlobalSource(alias)
- def resolve_name(self, name, package, level):
- """
- Copied from the Cpython implementation of __import__
- Resolve a relative module name to an absolute one.
- https://github.com/python/cpython/blob/5a094f0255eea1db58fb2cf14c200971e64ec36e/Lib/importlib/_bootstrap.py#L902
- """
- bits = package.rsplit(".", level - 1)
- if len(bits) < level:
- raise ImportError("attempted relative import beyond top-level package")
- base = bits[0]
- return "{}.{}".format(base, name) if name else base
- def calc_package(self):
- """
- Copied from the Cpython implementation of __import__
- https://github.com/python/cpython/blob/5a094f0255eea1db58fb2cf14c200971e64ec36e/Lib/importlib/_bootstrap.py#L1090
- """
- package = self.f_globals.get("__package__")
- spec = self.f_globals.get("__spec__")
- if package is not None:
- if spec is not None and package != spec.parent:
- log.warning(
- "__package__ != __spec__.parent "
- f"({package!r} != {spec.parent!r})",
- ImportWarning,
- stacklevel=3,
- ) # type: ignore[call-arg]
- return package
- elif spec is not None:
- return spec.parent
- else:
- log.warning(
- "can't resolve package from __spec__ or __package__, "
- "falling back on __name__ and __path__",
- ImportWarning,
- stacklevel=3,
- ) # type: ignore[call-arg]
- package = self.f_globals["__name__"]
- if "__path__" not in self.f_globals:
- package = package.rpartition(".")[0]
- return package
- def IMPORT_NAME(self, inst):
- level, fromlist = self.popn(2)
- level = level.as_python_constant()
- fromlist = fromlist.as_python_constant()
- module_name = inst.argval
- # Are we replaying? if so, load recorded module
- recorded_name = (
- f"{ExecutionRecorder.LOCAL_MOD_PREFIX}_{level}_{fromlist}_{module_name}"
- )
- if recorded_name in self.f_globals:
- value = self.f_globals[recorded_name]
- source = GlobalSource(recorded_name)
- else:
- value = __import__(
- module_name,
- fromlist=fromlist,
- level=level,
- globals=self.f_globals,
- )
- if level != 0:
- pkg = self.calc_package()
- module_name = self.resolve_name(module_name, pkg, level)
- # For __import__, when the name variable is of the form package.module,
- # normally, the top-level package (the name up till the first dot) is
- # returned, not the module named by module_name. However, when a
- # non-empty fromlist argument is given, the module named by name is
- # returned. Therefore, we set the source correctly here.
- if not fromlist:
- top_level_module_name = module_name.partition(".")[0]
- source = self.import_source(top_level_module_name)
- else:
- source = self.import_source(module_name)
- if config.replay_record_enabled:
- self.exec_recorder.add_local_mod(recorded_name, value)
- if is_allowed(value):
- self.push(TorchVariable(value, source=source))
- elif istype(value, (types.ModuleType, DummyModule)):
- self.push(PythonModuleVariable(value, source=source))
- else:
- unimplemented(f"IMPORT_NAME {typestr(value)}")
- def IMPORT_FROM(self, inst):
- self.DUP_TOP(inst)
- self.LOAD_ATTR(inst)
- def load_builtin(self, inst):
- assert inst.argval in self.f_builtins
- val = self.f_builtins[inst.argval]
- if callable(val):
- assert is_builtin_callable(val)
- self.push(VariableBuilder(self, GlobalSource(inst.argval))(val))
- else:
- assert is_builtin_constant(val)
- self.push(ConstantVariable(value=val))
- def jump(self, inst):
- self.instruction_pointer = self.indexof[id(inst.target)]
- JUMP_FORWARD = jump
- JUMP_ABSOLUTE = jump
- POP_JUMP_IF_FALSE = generic_jump(operator.not_, False)
- POP_JUMP_IF_TRUE = generic_jump(operator.truth, False)
- JUMP_IF_FALSE_OR_POP = generic_jump(operator.not_, True)
- JUMP_IF_TRUE_OR_POP = generic_jump(operator.truth, True)
- def SETUP_LOOP(self, inst):
- # only exists in python<=3.7
- self.block_stack.append(BlockStackEntry(inst.target))
- def SETUP_EXCEPT(self, inst):
- # only exists in python<=3.7
- self.block_stack.append(BlockStackEntry(inst.target))
- def POP_BLOCK(self, inst):
- self.block_stack.pop()
- def SETUP_WITH(self, inst):
- ctx = self.pop()
- if not isinstance(ctx, ContextWrappingVariable):
- unimplemented(f"SETUP_WITH {ctx}")
- self.output.guards.update(ctx.guards)
- if isinstance(self, InstructionTranslator):
- self.block_stack.append(BlockStackEntry(inst.target, len(self.stack), ctx))
- else:
- # can't restore this while inlining
- self.block_stack.append(BlockStackEntry(inst.target))
- self.push(
- WithExitFunctionVariable(
- ctx,
- inst.target,
- **VariableTracker.propagate(ctx),
- )
- )
- self.push(ctx.enter(self))
- def SETUP_FINALLY(self, inst):
- self.block_stack.append(BlockStackEntry(inst.target))
- def BEGIN_FINALLY(self, inst):
- self.push(None)
- def WITH_CLEANUP_START(self, inst):
- exit, exc = self.popn(2)
- assert exc is None
- self.push(exc)
- self.push(exit.call_function(self, [ConstantVariable(None)] * 3, {}))
- def WITH_CLEANUP_FINISH(self, inst):
- self.popn(2)
- self.push(None)
- def END_FINALLY(self, inst):
- tos = self.pop()
- assert tos is None
- def FOR_ITER(self, inst):
- it = self.pop()
- if isinstance(it, ListIteratorVariable):
- self.output.guards.update(it.guards)
- try:
- val, next_iter = it.next_variables()
- self.replace_all(it, next_iter)
- self.push(next_iter)
- self.push(val)
- except StopIteration:
- self.jump(inst)
- else:
- unimplemented(f"FOR_ITER {typestr(it)}")
- def COMPARE_OP(self, inst):
- left, right = self.popn(2)
- left = left.as_specialized(self)
- right = right.as_specialized(self)
- options = VariableTracker.propagate([left, right])
- op = inst.argval
- supported_any = dict(
- itertools.chain(
- supported_tensor_comparison_ops.items(),
- supported_const_comparison_ops.items(),
- )
- )
- if (
- isinstance(
- left,
- (
- TensorVariable,
- SymNodeVariable,
- NNModuleVariable,
- BaseListVariable,
- UserDefinedVariable,
- BaseUserFunctionVariable,
- ConstDictVariable,
- ),
- )
- and isinstance(right, ConstantVariable)
- and right.value is None
- and op in supported_const_comparison_ops
- ):
- # <non-None> is None
- self.push(
- ConstantVariable(
- supported_const_comparison_ops[op](object(), right.value), **options
- )
- )
- elif (
- left.is_python_constant()
- and right.is_python_constant()
- and op in supported_any
- ):
- # constant fold
- self.push(
- ConstantVariable(
- supported_any[op](
- left.as_python_constant(), right.as_python_constant()
- ),
- **options,
- )
- )
- elif op in ("in", "not in"):
- self.push(right.call_method(self, "__contains__", [left], {}))
- if op == "not in":
- self.UNARY_NOT(inst)
- else:
- self.push(
- BuiltinVariable(supported_any[op], **options).call_function(
- self, [left, right], {}
- )
- )
- def GET_ITER(self, inst):
- self.call_function(BuiltinVariable(iter), [self.pop()], {})
- @break_graph_if_unsupported(push=1)
- def CALL_FUNCTION(self, inst):
- args = self.popn(inst.argval)
- fn = self.pop()
- self.call_function(fn, args, {})
- @break_graph_if_unsupported(push=1)
- def CALL_FUNCTION_EX(self, inst):
- if inst.argval == 0:
- kwargsvars = ConstDictVariable({}, dict)
- argsvars = self.pop()
- elif inst.argval == 1:
- kwargsvars = self.pop()
- argsvars = self.pop()
- else:
- unimplemented("CALL_FUNCTION_EX")
- fn = self.pop()
- self.output.guards.update(argsvars.guards)
- self.output.guards.update(kwargsvars.guards)
- if (
- isinstance(fn, GetAttrVariable)
- and isinstance(fn.obj, TensorVariable)
- and fn.name == "view"
- and isinstance(argsvars, (ConstantVariable, TensorVariable))
- ):
- # Hack to handle special case in some bert models. Converts
- # x.view(*shape) into x.view(shape), which is correct for view()
- # but not generally. See test_transpose_for_scores().
- argsvars = TupleVariable([argsvars])
- if not isinstance(
- argsvars, BaseListVariable
- ) and argsvars.has_unpack_var_sequence(self):
- argsvars = TupleVariable(argsvars.unpack_var_sequence(self))
- if not isinstance(argsvars, BaseListVariable) or not isinstance(
- kwargsvars, ConstDictVariable
- ):
- unimplemented(f"non-static call {typestr(argsvars)} {typestr(kwargsvars)}")
- self.call_function(fn, argsvars.items, kwargsvars.items)
- @break_graph_if_unsupported(push=1)
- def CALL_FUNCTION_KW(self, inst):
- argnames = self.pop()
- args = self.popn(inst.argval)
- fn = self.pop()
- assert isinstance(argnames, ConstantVariable)
- argnames = argnames.value
- args, kwargs_list = args[: -len(argnames)], args[-len(argnames) :]
- kwargs = dict(zip(argnames, kwargs_list))
- assert len(kwargs) == len(argnames)
- self.call_function(fn, args, kwargs)
- def LOAD_METHOD(self, inst):
- self.LOAD_ATTR(inst)
- self.push(self.pop())
- self.push(None)
- def CALL_METHOD(self, inst):
- args = self.popn(inst.argval)
- dummy = self.pop()
- assert dummy is None
- fn = self.pop()
- self.call_function(fn, args, {})
- def LOAD_ATTR(self, inst):
- obj = self.pop()
- result = BuiltinVariable(getattr).call_function(
- self, [obj, ConstantVariable(inst.argval)], {}
- )
- self.push(result)
- def STORE_ATTR(self, inst):
- prior = self.copy_graphstate()
- val, obj = self.popn(2)
- if isinstance(obj, NNModuleVariable):
- # We don't allow side effects during export
- # https://github.com/pytorch/torchdynamo/issues/1475
- assert (
- not self.export
- ), f"Mutating module attribute {inst.argval} during export."
- try:
- self.output.guards.update(
- BuiltinVariable(setattr)
- .call_function(self, [obj, ConstantVariable(inst.argval), val], {})
- .guards
- )
- return
- except Unsupported as e:
- if not self.should_compile_partial_graph():
- raise
- log.debug("STORE_ATTR triggered compile", exc_info=True)
- e.remove_from_stats()
- e.add_to_stats("graph_break")
- self.restore_graphstate(prior)
- # break the graph
- self.output.compile_subgraph(
- self, reason=GraphCompileReason("store_attr", [self.frame_summary()])
- )
- self.output.add_output_instructions([inst])
- self.popn(2)
- self.output.add_output_instructions(
- self.create_call_resume_at(self.next_instruction)
- )
- def create_call_resume_at(self, offset):
- raise AssertionError(
- f"create_call_resume_at not overridden by subclass {type(self)}"
- )
- def should_compile_partial_graph(self) -> bool:
- raise AssertionError(
- f"should_compile_partial_graph not overridden by subclass {type(self)}"
- )
- @break_graph_if_unsupported(push=0)
- def STORE_SUBSCR(self, inst):
- val, obj, key = self.popn(3)
- result = obj.call_method(self, "__setitem__", [key, val], {})
- # no result is pushed, so need to lift the guards to global
- self.output.guards.update(result.guards)
- def BUILD_TUPLE(self, inst):
- items = self.popn(inst.argval)
- options = VariableTracker.propagate(items)
- self.push(TupleVariable(items, **options))
- def BUILD_SLICE(self, inst):
- items = self.popn(inst.argval)
- options = VariableTracker.propagate(items)
- self.push(
- SliceVariable(
- [x.as_specialized(self) for x in items],
- **options,
- )
- )
- def BUILD_LIST(self, inst):
- items = self.popn(inst.argval)
- options = VariableTracker.propagate(items)
- self.push(ListVariable(items, mutable_local=MutableLocal(), **options))
- def BUILD_LIST_UNPACK(self, inst, cls=ListVariable):
- seqs = self.popn(inst.argval)
- options = VariableTracker.propagate(seqs)
- items = list()
- for seq in seqs:
- try:
- items.extend(seq.unpack_var_sequence(self))
- except NotImplementedError:
- unimplemented(f"BUILD_LIST_UNPACK {seq}")
- self.push(cls(items, mutable_local=MutableLocal(), **options))
- def BUILD_TUPLE_UNPACK(self, inst):
- self.BUILD_LIST_UNPACK(inst, cls=TupleVariable)
- BUILD_TUPLE_UNPACK_WITH_CALL = BUILD_TUPLE_UNPACK
- def BUILD_MAP(self, inst):
- items = self.popn(inst.argval * 2)
- options = VariableTracker.propagate(items)
- result = dict()
- for k, v in zip(items[::2], items[1::2]):
- assert isinstance(k, (ConstantVariable, EnumVariable)) or (
- isinstance(k, TensorVariable) and k.specialized_value is not None
- )
- result[ConstDictVariable.get_key(k)] = v
- assert len(result) == len(items) / 2
- self.push(
- ConstDictVariable(result, dict, mutable_local=MutableLocal(), **options)
- )
- def BUILD_CONST_KEY_MAP(self, inst):
- keys = self.pop()
- values = self.popn(inst.argval)
- options = VariableTracker.propagate([keys] + values)
- assert isinstance(keys, ConstantVariable)
- keys = keys.value
- assert istype(keys, tuple)
- assert len(keys) == len(values)
- self.push(
- ConstDictVariable(
- dict(zip(keys, values)),
- dict,
- mutable_local=MutableLocal(),
- **options,
- )
- )
- def MAP_ADD(self, inst):
- k, v = self.popn(2)
- assert inst.argval > 0
- obj = self.stack[-inst.arg]
- assert isinstance(obj, ConstDictVariable)
- assert obj.mutable_local
- items = dict(obj.items)
- items[k.as_python_constant()] = v
- self.replace_all(
- obj,
- ConstDictVariable(
- items,
- obj.user_cls,
- **VariableTracker.propagate([obj, k, v]),
- ),
- )
- def LIST_APPEND(self, inst):
- v = self.pop()
- assert inst.argval > 0
- obj = self.stack[-inst.arg]
- assert isinstance(obj, ListVariable)
- assert obj.mutable_local
- # only copy if the new obj contains other mutables
- new_rec_contains = obj.recursively_contains
- if v.recursively_contains or v.mutable_local:
- new_rec_contains = obj.recursively_contains.union(v.recursively_contains)
- if v.mutable_local:
- new_rec_contains.add(v.mutable_local)
- self.replace_all(
- obj,
- ListVariable(
- obj.items + [v],
- recursively_contains=new_rec_contains,
- regen_guards=False,
- **VariableTracker.propagate([obj, v]),
- ),
- )
- def MAKE_FUNCTION(self, inst):
- flags = inst.arg
- old_stack = list(self.stack)
- fn_name = self.pop()
- code = self.pop()
- defaults = None
- closure = None
- annotations = None
- kwdefaults = None
- if flags & 0x08:
- closure = self.pop()
- if flags & 0x04:
- annotations = self.pop()
- if flags & 0x02:
- kwdefaults = self.pop()
- if flags & 0x01:
- defaults = self.pop()
- options = VariableTracker.propagate(old_stack[len(self.stack) :])
- self.push(
- NestedUserFunctionVariable(
- fn_name,
- code,
- self.f_globals,
- defaults,
- kwdefaults,
- annotations,
- closure,
- closure_scope=self,
- **options,
- )
- )
- def UNPACK_SEQUENCE(self, inst):
- seq = self.pop()
- if isinstance(seq, BaseListVariable):
- self.output.guards.update(seq.guards)
- val = seq.unpack_var_sequence(self)
- elif seq.is_python_constant() and isinstance(seq, ConstantVariable):
- val = seq.unpack_var_sequence(self)
- elif isinstance(seq, TensorVariable):
- val = seq.unpack_var_sequence(self, idxes=range(inst.argval))
- elif isinstance(seq, GetAttrVariable) and isinstance(seq.obj, TensorVariable):
- # x, y = a.shape
- proxy = getattr(seq.obj.as_proxy(), seq.name)
- options = VariableTracker.propagate(self)
- val = [wrap_fx_proxy(self, proxy[i], **options) for i in range(inst.argval)]
- else:
- unimplemented(f"UNPACK_SEQUENCE {seq}")
- assert len(val) == inst.argval
- for i in reversed(val):
- self.push(i)
- def UNPACK_EX(self, inst):
- assert 0 <= inst.argval <= 0xFFFF
- prefix = inst.argval & 0xFF # low byte
- suffix = inst.argval >> 8 # high byte
- seq = self.pop()
- options = VariableTracker.propagate(seq)
- if seq.has_unpack_var_sequence(self):
- vals = list(seq.unpack_var_sequence(self))
- assert len(vals) >= prefix + suffix
- vals_prefix = vals[:prefix]
- vals_list = vals[prefix : len(vals) - suffix]
- vals_suffix = vals[len(vals) - suffix :]
- for item in reversed(vals_suffix):
- self.push(item.add_options(options))
- self.push(TupleVariable(vals_list, **options))
- for item in reversed(vals_prefix):
- self.push(item.add_options(options))
- else:
- unimplemented(f"UNPACK_EX {seq}")
- def NOP(self, inst):
- pass
- def POP_TOP(self, inst):
- self.pop()
- def ROT_TWO(self, inst):
- a = self.pop()
- b = self.pop()
- self.push(a)
- self.push(b)
- def ROT_THREE(self, inst):
- a = self.pop()
- b = self.pop()
- c = self.pop()
- self.push(a)
- self.push(c)
- self.push(b)
- def ROT_FOUR(self, inst):
- a = self.pop()
- b = self.pop()
- c = self.pop()
- d = self.pop()
- self.push(a)
- self.push(d)
- self.push(c)
- self.push(b)
- def DUP_TOP(self, inst):
- a = self.pop()
- self.push(a)
- self.push(a)
- def DUP_TOP_TWO(self, inst):
- a = self.pop()
- b = self.pop()
- self.push(b)
- self.push(a)
- self.push(b)
- self.push(a)
- def FORMAT_VALUE(self, inst):
- flags = inst.arg
- if (flags & 0x04) == 0x04:
- fmt_spec = self.pop()
- else:
- fmt_spec = ConstantVariable("")
- value = self.pop()
- if isinstance(value, SymNodeVariable):
- value = ConstantVariable(str(value.sym_num))
- if (flags & 0x03) == 0x01:
- value = BuiltinVariable(str).call_function(self, [value], {})
- elif (flags & 0x03) == 0x02:
- value = BuiltinVariable(repr).call_function(self, [value], {})
- elif (flags & 0x03) == 0x03:
- value = BuiltinVariable(ascii).call_function(self, [value], {})
- fmt_var = ConstantVariable(
- "{:" + fmt_spec.as_python_constant() + "}"
- ).add_options(fmt_spec)
- self.call_function(BuiltinVariable(str.format), [fmt_var, value], {})
- def BUILD_STRING(self, inst):
- result = ""
- for _ in range(inst.arg):
- str_var = self.pop()
- assert isinstance(str_var, ConstantVariable)
- result = str_var.value + result
- self.push(ConstantVariable(value=result))
- def IS_OP(self, inst):
- assert inst.argval == 0 or inst.argval == 1
- if inst.argval == 0:
- new_argval = "is"
- else:
- new_argval = "is not"
- new_inst = create_instruction("COMPARE_OP", argval=new_argval)
- self.COMPARE_OP(new_inst)
- def CONTAINS_OP(self, inst):
- assert inst.argval == 0 or inst.argval == 1
- left, right = self.popn(2)
- op = inst.argval
- self.push(right.call_method(self, "__contains__", [left], {}))
- if op == 1:
- self.UNARY_NOT(inst)
- def LIST_EXTEND(self, inst):
- v = self.pop()
- assert inst.argval > 0
- obj = self.stack[-inst.arg]
- assert isinstance(obj, ListVariable)
- assert obj.mutable_local
- obj.call_method(self, "extend", [v], {})
- def LIST_TO_TUPLE(self, inst):
- self.push(BuiltinVariable(tuple).call_function(self, [self.pop()], {}))
- def DICT_MERGE(self, inst):
- v = self.pop()
- assert inst.argval > 0
- obj = self.stack[-inst.arg]
- assert isinstance(obj, ConstDictVariable)
- assert obj.mutable_local
- obj.call_method(self, "update", [v], {})
- def GEN_START(self, inst):
- self.pop()
- def GET_LEN(self, inst):
- tos = self.stack[-1]
- if tos.is_python_constant():
- self.push(ConstantVariable(len(tos.as_python_constant())))
- else:
- self.push(tos.call_method(self, "__len__", [], {}))
- def MATCH_MAPPING(self, inst):
- tos = self.stack[-1]
- assert isinstance(tos, ConstDictVariable)
- if isinstance(tos.items, collections.abc.Mapping):
- self.push(ConstantVariable(True))
- else:
- self.push(ConstantVariable(False))
- def MATCH_SEQUENCE(self, inst):
- tos = self.stack[-1]
- assert tos.is_python_constant()
- tos_value = tos.as_python_constant()
- if isinstance(tos_value, collections.abc.Sequence) and not isinstance(
- tos_value, (str, bytes, bytearray)
- ):
- self.push(ConstantVariable(True))
- else:
- self.push(ConstantVariable(False))
- def MATCH_KEYS(self, inst):
- tos = self.stack[-1]
- assert tos.is_python_constant()
- keys = tos.as_python_constant()
- tos1 = self.stack[-2]
- assert isinstance(tos1, ConstDictVariable)
- match_obj = tos1.items
- if all(key in match_obj for key in keys):
- self.push(TupleVariable([match_obj[key] for key in keys]))
- self.push(ConstantVariable(True))
- else:
- self.push(ConstantVariable(None))
- self.push(ConstantVariable(False))
- UNARY_POSITIVE = stack_op(operator.pos)
- UNARY_NEGATIVE = stack_op(operator.neg)
- UNARY_NOT = stack_op(operator.not_)
- UNARY_INVERT = stack_op(operator.invert)
- BINARY_POWER = stack_op(operator.pow)
- BINARY_MULTIPLY = stack_op(operator.mul)
- BINARY_MATRIX_MULTIPLY = stack_op(operator.matmul)
- BINARY_FLOOR_DIVIDE = stack_op(operator.floordiv)
- BINARY_TRUE_DIVIDE = stack_op(operator.truediv)
- BINARY_MODULO = stack_op(operator.mod)
- BINARY_REMAINDER = stack_op(operator.mod)
- BINARY_ADD = stack_op(operator.add)
- BINARY_SUBTRACT = stack_op(operator.sub)
- BINARY_SUBSCR = break_graph_if_unsupported(push=1)(stack_op(operator.getitem))
- BINARY_LSHIFT = stack_op(operator.lshift)
- BINARY_RSHIFT = stack_op(operator.rshift)
- BINARY_AND = stack_op(operator.and_)
- BINARY_OR = stack_op(operator.or_)
- BINARY_XOR = stack_op(operator.xor)
- INPLACE_POWER = stack_op(operator.ipow)
- INPLACE_MULTIPLY = stack_op(operator.imul)
- INPLACE_MATRIX_MULTIPLY = stack_op(operator.imatmul)
- INPLACE_FLOOR_DIVIDE = stack_op(operator.ifloordiv)
- INPLACE_TRUE_DIVIDE = stack_op(operator.itruediv)
- INPLACE_MODULO = stack_op(operator.imod)
- INPLACE_REMAINDER = stack_op(operator.imod)
- INPLACE_ADD = stack_op(operator.iadd)
- INPLACE_SUBTRACT = stack_op(operator.isub)
- INPLACE_LSHIFT = stack_op(operator.ilshift)
- INPLACE_RSHIFT = stack_op(operator.irshift)
- INPLACE_AND = stack_op(operator.iand)
- INPLACE_XOR = stack_op(operator.ixor)
- INPLACE_OR = stack_op(operator.ior)
- # 3.11 opcodes
- # note: passed opcodes are intentional
- def RESUME(self, inst):
- pass
- def BINARY_OP(self, inst):
- if sys.version_info >= (3, 11):
- opname = dis._nb_ops[inst.arg][0][3:]
- if opname.startswith("INPLACE"):
- return getattr(self, "INPLACE_" + opname[8:])(inst)
- return getattr(self, "BINARY_" + opname)(inst)
- else:
- unimplemented("BINARY_OP requires Python 3.11+")
- def COPY(self, inst):
- self.push(self.stack[-inst.arg])
- def SWAP(self, inst):
- self.stack[-1], self.stack[-inst.arg] = self.stack[-inst.arg], self.stack[-1]
- JUMP_BACKWARD = jump
- JUMP_BACKWARD_NO_INTERRUPT = jump
- POP_JUMP_FORWARD_IF_TRUE = generic_jump(operator.truth, False)
- POP_JUMP_BACKWARD_IF_TRUE = generic_jump(operator.truth, False)
- POP_JUMP_FORWARD_IF_FALSE = generic_jump(operator.not_, False)
- POP_JUMP_BACKWARD_IF_FALSE = generic_jump(operator.not_, False)
- POP_JUMP_FORWARD_IF_NOT_NONE = generic_jump(is_not_none, False)
- POP_JUMP_BACKWARD_IF_NOT_NONE = generic_jump(is_not_none, False)
- POP_JUMP_FORWARD_IF_NONE = generic_jump(is_none, False)
- POP_JUMP_BACKWARD_IF_NONE = generic_jump(is_none, False)
- def CACHE(self, inst):
- pass
- def copy_graphstate(self) -> InstructionTranslatorGraphState:
- """Create a checkpoint of the current state by copying everything"""
- return InstructionTranslatorGraphState(
- self.output.copy_graphstate(),
- collections.OrderedDict(self.symbolic_locals),
- list(self.stack),
- list(self.block_stack),
- self.instruction_pointer,
- self.current_instruction,
- self.next_instruction,
- self.lineno,
- )
- def restore_graphstate(self, state: InstructionTranslatorGraphState):
- """Restore a checkpoint created by self.copy_graphstate()"""
- (
- output_state,
- self.symbolic_locals,
- self.stack,
- self.block_stack,
- self.instruction_pointer,
- self.current_instruction,
- self.next_instruction,
- self.lineno,
- ) = state
- self.output.restore_graphstate(output_state)
- def empty_checkpoint(self):
- if self.checkpoint is None:
- return True
- output_graphstate = self.checkpoint[1][0]
- graphstate = self.checkpoint[1][1:]
- state = (*output_graphstate, *graphstate)
- for obj in state:
- if isinstance(obj, Sized):
- if len(obj) != 0:
- return False
- return True
- def format_frame_summary(self, additional_stack_frames=None):
- if additional_stack_frames is None:
- additional_stack_frames = []
- return "".join(
- traceback.format_list(
- ([self.frame_summary()] + list(reversed(additional_stack_frames)))
- )
- )
- def frame_summary(self):
- return traceback.FrameSummary(
- getattr(self.f_code, "co_filename", "<unknown>"),
- self.lineno,
- getattr(self.f_code, "co_name", "<unknown>"),
- lookup_line=False,
- )
- def store_dict_key(self, name, value):
- self.output.guards.add(
- GlobalWeakRefSource(name).make_guard(GuardBuilder.WEAKREF_ALIVE)
- )
- if name not in self.output.root_globals:
- self.output.install_global(name, weakref.ref(value))
- @property
- def fake_mode(self):
- return self._fake_mode
- def find_symbolic_locals_name(self, tensor_variable):
- for key, value in self.symbolic_locals.items():
- if value is tensor_variable:
- return key
- return None
- def __init__(
- self,
- output: OutputGraph,
- instructions: List[Instruction],
- f_locals: Dict[str, Any],
- f_globals: Dict[str, Any],
- f_builtins: Dict[str, Any],
- code_options: Dict[str, Any],
- symbolic_locals: Dict[str, VariableTracker],
- symbolic_globals: Dict[str, VariableTracker],
- f_code: types.CodeType,
- export: bool,
- ):
- super().__init__()
- # Mutable state checkpointed by copy_graphstate()
- self.output = output
- self.symbolic_locals = symbolic_locals
- self.symbolic_globals = symbolic_globals
- self.stack = []
- self.instruction_pointer = 0
- self.current_instruction = create_instruction("NOP")
- self.next_instruction = None
- self.block_stack = []
- self.lineno = code_options["co_firstlineno"]
- # Properties of the input/output code
- self.instructions: List[Instruction] = instructions
- self.indexof: Dict[int, int] = {id(i): n for n, i in enumerate(instructions)}
- self.f_locals: Dict[
- str, Any
- ] = f_locals # needed for recording accessed locals for replay
- self.f_globals: Dict[str, Any] = f_globals
- self.f_builtins: Dict[str, Any] = f_builtins
- self.code_options: Dict[str, Any] = code_options
- self.f_code: types.CodeType = f_code
- # Execution record for replaying errors
- self.exec_recorder = ExecutionRecorder(code=f_code, code_options=code_options)
- # Stack of module being parsed, current nn.module is at the end of ordered dict
- self.nn_module_stack: Dict[str, str] = {}
- # Flag to indicate whether tracing is used for export.
- self.export = export
- self._fake_mode = output.tracing_context.fake_mode
- self.checkpoint = None
- self.random_calls = []
- if sys.version_info >= (3, 10):
- from .resume_execution import (
- CO_ASYNC_GENERATOR,
- CO_COROUTINE,
- CO_GENERATOR,
- CO_ITERABLE_COROUTINE,
- )
- if f_code.co_flags & (
- CO_GENERATOR | CO_COROUTINE | CO_ITERABLE_COROUTINE | CO_ASYNC_GENERATOR
- ):
- self.push(BuiltinVariable(None))
- class InstructionTranslator(InstructionTranslatorBase):
- def __init__(
- self,
- instructions: List[Instruction],
- f_code,
- f_locals,
- f_globals,
- f_builtins,
- code_options,
- compiler_fn,
- one_graph,
- export,
- mutated_closure_cell_contents: Set[str],
- ):
- super().__init__(
- output=OutputGraph(f_globals, code_options, compiler_fn, self),
- instructions=instructions,
- f_locals=f_locals,
- f_globals=f_globals,
- f_builtins=f_builtins,
- code_options=code_options,
- symbolic_locals=collections.OrderedDict(), # set below
- # A global var is inserted only after a STORE_GLOBAL happens to it
- symbolic_globals=collections.OrderedDict(),
- f_code=f_code,
- export=export,
- )
- self.one_graph: bool = one_graph
- self.export = export
- self.mutated_closure_cell_contents = mutated_closure_cell_contents
- if self.export:
- assert (
- self.one_graph
- ), "Export without one graph - something has gone wrong."
- vars = list(code_options["co_varnames"])
- vars.extend(x for x in self.cell_and_freevars() if x not in vars)
- self.symbolic_locals = collections.OrderedDict(
- (
- k,
- VariableBuilder(
- self,
- LocalInputSource(k, code_options["co_varnames"].index(k))
- if k in code_options["co_varnames"]
- else LocalSource((k)),
- )(f_locals[k]),
- )
- for k in vars
- if k in f_locals
- )
- # symbolic_locals contains the mapping from original f_locals to the
- # Variable objects. During the Variable building phase, each object also
- # has its associated guards. At the end, we will accumulate these
- # guards.
- #
- # One way of handling these guards is to just accumulate all of them
- # right now. However, many f_locals might not be used in the frame and
- # thus can unnecessarily increase guard execution overhead. Therefore,
- # we selectively update output.guards as we run the Python Bytecode
- # instruction by instruction.
- #
- # An exception here is list/dict variables. Guards related to these
- # variables have indexed access, like Tensor_match on args[0], and if
- # args is not used in this frame, we will miss a LIST_LENGTH check like
- # len(args) == 2. Missing the LIST_LENGTH check causes problem for the
- # next invocation when args is not a list, and args[0] is a runtime
- # error. Therefore, we recursively add guards for list/dict variable here.
- for val in self.symbolic_locals.values():
- if isinstance(
- val, (ListIteratorVariable, BaseListVariable, ConstDictVariable)
- ):
- local_guards = VariableTracker.propagate(val)["guards"]
- index_guards = [
- guard
- for guard in local_guards
- if guard.create_fn
- in (
- GuardBuilder.LIST_LENGTH,
- GuardBuilder.DICT_KEYS,
- GuardBuilder.ODICT_KEYS,
- GuardBuilder.TUPLE_ITERATOR_LEN,
- )
- ]
- self.output.guards.update(index_guards)
- self._freevars_ids = dict()
- for name in self.code_options["co_freevars"]:
- if name in f_locals:
- self._freevars_ids[name] = id(f_locals[name])
- def run(self):
- _step_logger()(logging.INFO, f"torchdynamo start tracing {self.f_code.co_name}")
- super().run()
- def match_nested_cell(self, name, cell):
- """Match a cell in this method to one in a function we are inlining"""
- value = cell.cell_contents
- # TODO(jansel): check the id of the cell rather than the contents
- if id(value) != self._freevars_ids.get(name):
- return None
- return self.symbolic_locals[name]
- def should_compile_partial_graph(self):
- return all(b.can_restore() for b in self.block_stack) and not self.one_graph
- def create_call_resume_at(self, inst):
- self.instruction_pointer = None
- if inst.opname == "RETURN_VALUE":
- return [create_instruction("RETURN_VALUE")]
- reads = livevars_analysis(self.instructions, inst)
- argnames = tuple(
- k
- for k in self.symbolic_locals.keys()
- if k in reads and k not in self.cell_and_freevars()
- )
- nargs = len(self.stack) + len(argnames)
- name = unique_id(f"__resume_at_{inst.offset}")
- new_code: types.CodeType = ContinueExecutionCache.lookup(
- self.f_code,
- self.lineno,
- inst.offset,
- len(self.stack),
- argnames,
- tuple(b.resume_fn() for b in self.block_stack),
- )
- cg = PyCodegen(self)
- if new_code.co_freevars:
- cg.make_function_with_closure(name, new_code, len(self.stack))
- else:
- self.output.install_global(
- name, types.FunctionType(new_code, self.f_globals, name)
- )
- cg.extend_output(cg.load_function_name(name, len(self.stack)))
- cg.extend_output([cg.create_load(k) for k in argnames])
- cg.extend_output(
- [
- create_instruction("CALL_FUNCTION", nargs),
- create_instruction("RETURN_VALUE"),
- ]
- )
- return cg.get_instructions()
- def RETURN_VALUE(self, inst):
- if self.output.count_calls() == 0:
- raise exc.SkipFrame("because no content in function call")
- self.instruction_pointer = None
- _step_logger()(
- logging.INFO,
- f"torchdynamo done tracing {self.f_code.co_name} (RETURN_VALUE)",
- )
- log.debug("RETURN_VALUE triggered compile")
- self.output.compile_subgraph(
- self, reason=GraphCompileReason("return_value", [self.frame_summary()])
- )
- self.output.add_output_instructions([create_instruction("RETURN_VALUE")])
- class InliningInstructionTranslator(InstructionTranslatorBase):
- """Trace and inline a called method"""
- symbolic_result: Optional[TensorVariable]
- @classmethod
- def inline_call(cls, parent, func, args, kwargs):
- with patch.dict(counters, {"unimplemented": counters["inline_call"]}):
- return cls.inline_call_(parent, func, args, kwargs)
- @staticmethod
- def inline_call_(parent, func, args, kwargs):
- assert isinstance(
- func,
- (UserFunctionVariable, NestedUserFunctionVariable),
- )
- if func.has_self():
- unimplemented("inline with __self__")
- if func.get_name() == "patched_init":
- unimplemented("Patched init cannot be inlined.")
- try:
- if id(func.get_function()) in allowed_functions._disallowed_function_ids:
- unimplemented(f"inlining disallowed: {func.get_function()}")
- except NotImplementedError:
- pass # closures
- if skipfiles.check(
- func.get_filename()
- ) and not skipfiles.is_torch_inline_allowed(func.get_filename()):
- unimplemented(
- f"inline in skipfiles: {func.fn.__qualname__} | {func.get_name()} {func.get_filename()}"
- )
- try:
- sub_locals, closure_cells = func.bind_args(parent, args, kwargs)
- except TypeError as e:
- log.warning(
- f"{func.get_filename()} {func.get_function()} {args} {kwargs} {e}"
- )
- unimplemented("arg mismatch inlining")
- for v in itertools.chain(sub_locals.values(), closure_cells.values()):
- if not isinstance(v, VariableTracker):
- unimplemented(f"unconverted arg {v}")
- code: types.CodeType = func.get_code()
- if code.co_name in ("__setitem__", "__setattr__"):
- unimplemented(f"inline {code.co_name}")
- log.debug(f"INLINING {code} \n {dis.Bytecode(code).dis()} \n")
- tracer: InliningInstructionTranslator
- if is_generator(code):
- tracer = InliningGeneratorInstructionTranslator(
- parent, code, sub_locals, parent.symbolic_globals, closure_cells, func
- )
- else:
- tracer = InliningInstructionTranslator(
- parent, code, sub_locals, parent.symbolic_globals, closure_cells, func
- )
- try:
- tracer.run()
- except exc.SkipFrame as e:
- msg = f"SKIPPED INLINING {code}: {e}"
- log.debug(msg)
- raise Unsupported(msg) from e
- except Exception as e:
- log.debug(f"FAILED INLINING {code}")
- raise
- assert tracer.symbolic_result is not None
- func.export_freevars(parent, tracer)
- if tracer.f_globals is parent.f_globals:
- # Merge symbolic_globals back if parent and child are in the same namespace
- parent.symbolic_globals.update(tracer.symbolic_globals)
- log.debug(f"DONE INLINING {code}")
- if is_generator(code):
- assert isinstance(tracer, InliningGeneratorInstructionTranslator)
- assert tracer.symbolic_result.as_python_constant() is None
- return ListIteratorVariable(
- tracer.generated_items,
- mutable_local=MutableLocal(),
- **VariableTracker.propagate(tracer.symbolic_result),
- )
- else:
- return tracer.symbolic_result
- def __init__(
- self,
- parent: InstructionTranslatorBase,
- code: types.CodeType,
- symbolic_locals: Dict[str, VariableTracker],
- symbolic_globals: Dict[str, VariableTracker],
- closure_cells: Dict[str, VariableTracker],
- funcvar: BaseUserFunctionVariable,
- ):
- f_globals = funcvar.get_globals()
- f_builtins = f_globals["__builtins__"]
- if not isinstance(f_builtins, dict):
- f_builtins = f_builtins.__dict__
- super().__init__(
- output=parent.output,
- f_locals={},
- f_globals=f_globals,
- f_builtins=f_builtins,
- symbolic_locals=symbolic_locals,
- symbolic_globals=symbolic_globals,
- instructions=cleaned_instructions(code),
- code_options={k: getattr(code, k) for k in dir(code)},
- f_code=code,
- export=parent.export,
- )
- self.parent = parent
- self.symbolic_result = None
- self.closure_cells = closure_cells
- self.nn_module_stack = parent.nn_module_stack.copy()
- @property
- def fake_mode(self):
- return self.parent.fake_mode
- def STORE_DEREF(self, inst):
- if inst.argval in self.closure_cells:
- cell = self.closure_cells[inst.argval]
- val = self.pop()
- if isinstance(cell, ClosureVariable):
- self.output.root_tx.symbolic_locals[cell.name] = val
- else:
- self.output.side_effects.store_cell(cell, val)
- else:
- maybe_cell = self.symbolic_locals.get(inst.argval)
- if isinstance(
- maybe_cell,
- variables.NewCellVariable,
- ):
- self.output.side_effects.store_cell(
- self.symbolic_locals[inst.argval], self.pop()
- )
- else:
- if (
- maybe_cell is not None
- and maybe_cell.source.name()
- not in self.parent.mutated_closure_cell_contents
- ):
- # Why is the source name here unique?
- # mutated_closure_cell_contents is a per-frame
- # concept, and sources identify, e.g., particular
- # locals from the frame. If you had two locals,
- # they'll get different source names, and therefore
- # differ here.
- self.parent.mutated_closure_cell_contents.add(
- maybe_cell.source.name()
- )
- raise exc.RestartAnalysis()
- unimplemented("write to __closure__ while inlining")
- def LOAD_DEREF(self, inst):
- if inst.argval in self.closure_cells:
- cell = self.closure_cells[inst.argval]
- if isinstance(cell, ClosureVariable):
- self.push(self.output.root_tx.symbolic_locals[cell.name])
- else:
- self.push(self.output.side_effects.load_cell(cell))
- else:
- maybe_sym_local = self.symbolic_locals.get(inst.argval, None)
- if isinstance(maybe_sym_local, variables.NewCellVariable):
- self.push(self.output.side_effects.load_cell(maybe_sym_local))
- else:
- super().LOAD_DEREF(inst)
- def LOAD_CLOSURE(self, inst):
- assert inst.argval in self.cell_and_freevars()
- self.push(self.closure_cells[inst.argval])
- def replace_all(self, oldvar: VariableTracker, newvar: VariableTracker):
- newvar = super().replace_all(oldvar, newvar)
- # recursively check and update parent's locals and stack in case oldvar is from parent
- translator: InstructionTranslatorBase = self
- while hasattr(translator, "parent"):
- translator = translator.parent # type: ignore[attr-defined]
- translator.update_locals_and_stack(oldvar, newvar)
- return newvar
- def should_compile_partial_graph(self):
- return False # inlining functions is all-or-nothing
- def create_call_resume_at(self, offset):
- unimplemented("cant resume while inlining")
- def RETURN_VALUE(self, inst):
- self.symbolic_result = self.pop()
- self.instruction_pointer = None
- class InliningGeneratorInstructionTranslator(InliningInstructionTranslator):
- generated_items: List[VariableTracker]
- def __init__(self, *args, **kwargs):
- super().__init__(*args, **kwargs)
- self.generated_items = []
- def YIELD_VALUE(self, inst: Instruction):
- self.generated_items.append(self.pop())
- # TODO(jansel): figure out why this is needed, it isn't in the docs for YIELD_VALUE
- self.push(ConstantVariable(None))
|