import collections
import dataclasses
import dis
import functools
import importlib
import inspect
import itertools
import logging
import operator
import sys
import traceback
import types
import typing
import weakref
from collections.abc import Sized
from typing import Any, Callable, Dict, List, NamedTuple, Optional, Set, Tuple
from unittest.mock import patch

import torch
from torch._guards import Checkpointable

from . import (
    allowed_functions,
    config,
    exc,
    logging as torchdynamo_logging,
    side_effects,
    skipfiles,
    variables,
)
from .allowed_functions import is_allowed, is_builtin_callable, is_builtin_constant
from .bytecode_analysis import JUMP_OPNAMES, livevars_analysis
from .bytecode_transformation import (
    cleaned_instructions,
    create_instruction,
    create_jump_absolute,
    Instruction,
    is_generator,
    unique_id,
)
from .codegen import PyCodegen
from .exc import BackendCompilerFailed, unimplemented, Unsupported
from .guards import GuardBuilder
from .output_graph import GraphCompileReason, OutputGraph, OutputGraphState
from .replay_record import DummyModule, ExecutionRecorder
from .resume_execution import ContinueExecutionCache, ReenterWith
from .source import (
    AttrSource,
    GetItemSource,
    GlobalSource,
    GlobalWeakRefSource,
    LocalInputSource,
    LocalSource,
)
from .utils import counters, graph_break_dup_warning_checker, istype, proxy_args_kwargs
from .variables.base import MutableLocal, typestr, VariableTracker
from .variables.builder import VariableBuilder, wrap_fx_proxy
from .variables.builtin import BuiltinVariable
from .variables.constant import ConstantVariable, EnumVariable
from .variables.dicts import ConstDictVariable
from .variables.functions import (
    BaseUserFunctionVariable,
    NestedUserFunctionVariable,
    UserFunctionVariable,
    UserMethodVariable,
)
from .variables.lists import (
    BaseListVariable,
    ListIteratorVariable,
    ListVariable,
    SliceVariable,
    TupleVariable,
)
from .variables.misc import (
    ClosureVariable,
    ContextWrappingVariable,
    GetAttrVariable,
    GradModeVariable,
    PythonModuleVariable,
    UnknownVariable,
    WithExitFunctionVariable,
)
from .variables.nn_module import NNModuleVariable
from .variables.tensor import (
    supported_const_comparison_ops,
    supported_tensor_comparison_ops,
    SymNodeVariable,
    TensorVariable,
)
from .variables.torch import TorchVariable
from .variables.user_defined import UserDefinedObjectVariable, UserDefinedVariable

log = logging.getLogger(__name__)


@functools.lru_cache(None)
def _step_logger():
    return torchdynamo_logging.get_step_logger(log)


@dataclasses.dataclass
class BlockStackEntry:
    target: Instruction
    stack_index: Optional[int] = None
    with_context: ContextWrappingVariable = None

    def can_restore(self):
        return self.with_context is not None

    def resume_fn(self):
        assert self.stack_index is not None
        if self.with_context and self.with_context.target_values:
            return ReenterWith(self.stack_index, tuple(self.with_context.target_values))
        else:
            return ReenterWith(self.stack_index)

    def exit(self, tx):
        return self.with_context.exit(tx)


class InstructionTranslatorGraphState(NamedTuple):
    output: OutputGraphState
    symbolic_locals: Dict[str, VariableTracker]
    stack: List[VariableTracker]
    block_stack: List[BlockStackEntry]
    instruction_pointer: Optional[int]
    current_instruction: Instruction
    next_instruction: Optional[Instruction]
    lineno: int

    def diff(self, other: "InstructionTranslatorGraphState") -> Optional[str]:
        for k in self._fields:
            if k == "output":
                return self.output.diff(other.output, prefix=f"{k}.")
            sv = getattr(self, k)
            ov = getattr(other, k)
            if sv != ov:
                return f"{k} mismatch: {sv} != {ov}"
        return None

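# Helper for building opcode handlers out of plain Python operators: stack_op
# pops as many operands as `fn` takes, routes the call through BuiltinVariable
# so it is traced, and pushes the result. For example (illustrative),
# BINARY_ADD further below is defined as stack_op(operator.add).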
def stack_op(fn: typing.Callable[..., object]):
    nargs = len(inspect.signature(fn).parameters)
    fn_var = BuiltinVariable(fn)

    @functools.wraps(fn)
    def impl(self: "InstructionTranslatorBase", inst: Instruction):
        self.push(fn_var.call_function(self, self.popn(nargs), {}))

    return impl


def _detect_and_normalize_assert_statement(
    self: "InstructionTranslatorBase",
    truth_fn: typing.Callable[[object], bool],
    push: bool,
):
    # Detect if this jump instruction is an assert and normalize the assert
    # by pushing a dummy error message when none is given.
    #
    # Python 3.9 assertions have the following format:
    # 18 POP_JUMP_IF_TRUE       28
    # 20 LOAD_ASSERTION_ERROR
    # 22 LOAD_CONST              3 ('Assert message') -> optional instruction
    # 24 CALL_FUNCTION           1                    -> optional instruction
    # 26 RAISE_VARARGS
    #
    # Python 3.8 assertions have the following format:
    # 18 POP_JUMP_IF_TRUE       28
    # 20 LOAD_GLOBAL             0 (Assertion type)
    # 22 LOAD_CONST              3 ('Assert message') -> optional instruction
    # 24 CALL_FUNCTION           1                    -> optional instruction
    # 26 RAISE_VARARGS           1
    if (truth_fn is not operator.truth) or push:
        return False

    assert isinstance(self.instruction_pointer, int)
    current_instruction_pointer = self.instruction_pointer
    inst = self.instructions[current_instruction_pointer]
    # Detect LOAD_ASSERTION_ERROR or LOAD_GLOBAL 0
    if sys.version_info < (3, 9):
        if inst.opname != "LOAD_GLOBAL" or inst.argval != "AssertionError":
            return False
    else:
        if inst.opname != "LOAD_ASSERTION_ERROR":
            return False

    current_instruction_pointer += 1

    if current_instruction_pointer >= len(self.instructions):
        return False

    inst = self.instructions[current_instruction_pointer]
    has_error_msg = False
    # Detect RAISE_VARARGS or LOAD_CONST
    if inst.opname == "LOAD_CONST":
        if not isinstance(inst.argval, str):
            return False
        self.LOAD_CONST(inst)
        has_error_msg = True

        # if it is LOAD_CONST, it must be followed by CALL_FUNCTION
        current_instruction_pointer += 1
        if current_instruction_pointer >= len(self.instructions):
            return False
        inst = self.instructions[current_instruction_pointer]
        if inst.opname != "CALL_FUNCTION":
            return False

        # CALL_FUNCTION should be followed by RAISE_VARARGS
        current_instruction_pointer += 1
        if current_instruction_pointer >= len(self.instructions):
            return False
        inst = self.instructions[current_instruction_pointer]

    if inst.opname != "RAISE_VARARGS":
        return False

    if not has_error_msg:
        # Push dummy value instead of error message
        self.push(ConstantVariable("assertion error"))

    return True

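# Illustrative example of the pattern matched above: on Python 3.9,
# `assert x, "failed"` compiles to POP_JUMP_IF_TRUE / LOAD_ASSERTION_ERROR /
# LOAD_CONST 'failed' / CALL_FUNCTION 1 / RAISE_VARARGS 1, while a bare
# `assert x` omits the two optional instructions.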
def generic_jump(truth_fn: typing.Callable[[object], bool], push: bool):
    def inner(self: "InstructionTranslatorBase", inst: Instruction):
        value: VariableTracker = self.pop()
        self.output.guards.update(value.guards)
        if (
            config.rewrite_assert_with_torch_assert
            and _detect_and_normalize_assert_statement(self, truth_fn, push)
        ):
            error_msg: VariableTracker = self.pop()
            self.output.guards.update(error_msg.guards)
            # Skip over things like `assert True`
            if value.is_python_constant() and bool(value.as_python_constant()):
                self.jump(inst)
                return

            # Manually insert torch._assert instead of python assert and jump over
            # assert related instructions as we don't need them anymore.
            self.output.create_proxy(
                "call_function",
                torch._assert,
                *proxy_args_kwargs((value, error_msg), {}),
            )
            self.jump(inst)
            return

        if value.is_python_constant():
            if truth_fn(value.as_python_constant()):
                push and self.push(value)
                self.jump(inst)
        elif (
            isinstance(value, (TensorVariable)) and self.should_compile_partial_graph()
        ):
            # compile a partial subgraph prefix then jump into user code
            if self.has_backedge():
                msg = (
                    "Skipping frame because there is a graph break in a for/while loop"
                )
                log.debug(msg)
                raise exc.SkipFrame(msg)

            self.push(value)
            log.debug("generic_jump triggered compile")
            self.output.compile_subgraph(
                self,
                reason=GraphCompileReason(
                    f"generic_jump {typestr(value)}", [self.frame_summary()]
                ),
            )
            self.pop()

            if_next = self.create_call_resume_at(self.next_instruction)
            push and self.push(value)
            if_jump = self.create_call_resume_at(inst.target)

            self.output.add_output_instructions(
                [(create_instruction(inst.opname, target=if_jump[0]))]
                + if_next
                + if_jump
            )
        elif isinstance(value, NNModuleVariable):
            # Equivalent of "self.nn_module is not None"
            if truth_fn(value):
                push and self.push(value)
                self.jump(inst)
        elif isinstance(value, UserDefinedObjectVariable):
            x = value.var_getattr(self, "__bool__")
            # __bool__ is a function
            if isinstance(x, UserMethodVariable):
                state = self.copy_graphstate()
                result = x.call_function(self, [], {})
                if isinstance(result, ConstantVariable) and isinstance(
                    result.value, bool
                ):
                    self.output.guards.update(result.guards)
                    if truth_fn(result.value):
                        push and self.push(value)
                        self.jump(inst)
                else:
                    # rollback to the state before the __bool__ inline
                    self.restore_graphstate(state)
                    unimplemented(
                        "generic_jump on UserDefined with __bool__ returning non-constant"
                    )
            # __bool__ is not a function or does not exist on the user-defined object
            else:
                if truth_fn(True):
                    push and self.push(value)
                    self.jump(inst)
        elif not isinstance(value, TensorVariable) and value.has_unpack_var_sequence(
            self
        ):
            if truth_fn(len(value.unpack_var_sequence(self))):
                push and self.push(value)
                self.jump(inst)
        elif isinstance(value, SymNodeVariable):
            eval_result = value.evaluate_expr(self.output)
            if truth_fn(eval_result):
                push and self.push(value)
                self.jump(inst)
        else:
            unimplemented(f"generic_jump {typestr(value)}")

    return inner


explain = False

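# Decorator used on opcode handlers that may raise Unsupported partway
# through: it checkpoints the translator state before running the handler,
# and on failure rolls back, compiles the traced prefix as a partial graph,
# then emits bytecode that re-runs the offending instruction eagerly and
# resumes the rest of the frame.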
def break_graph_if_unsupported(*, push):
    def decorator(inner_fn):
        @functools.wraps(inner_fn)
        def wrapper(self: "InstructionTranslatorBase", inst: Instruction):
            state = self.copy_graphstate()
            reason = None
            try:
                return inner_fn(self, inst)
            except Unsupported as excp:
                if self.has_backedge() and self.should_compile_partial_graph():
                    msg = "Skipping frame because there is a graph break in a for/while loop"
                    log.debug(msg)
                    raise exc.SkipFrame(msg) from excp

                if not self.should_compile_partial_graph():
                    raise

                log.debug("break_graph_if_unsupported triggered compile", exc_info=True)

                user_stack = [self.frame_summary()] + list(reversed(excp.real_stack))
                user_stack_formatted = "".join(traceback.format_list(user_stack))
                frame_loc = (user_stack[-1].filename, user_stack[-1].lineno)
                # torch._dynamo.explain() formats this a little nicer, and presents a slightly
                # more actionable user code pointer
                if (
                    config.print_graph_breaks
                    and not explain
                    and graph_break_dup_warning_checker.add(frame_loc)
                ):
                    log.warning(
                        f"Graph break: {excp} from user code at {user_stack_formatted}"
                    )

                excp.remove_from_stats()
                excp.add_to_stats("graph_break")
                reason = GraphCompileReason(excp.msg, user_stack)
            self.restore_graphstate(state)
            self.output.compile_subgraph(self, reason=reason)
            self.popn(push - dis.stack_effect(inst.opcode, inst.arg))

            for _ in range(push):
                self.push(UnknownVariable())

            resume_call_insts = self.create_call_resume_at(self.next_instruction)
            # Check if there is a block stack entry with GradModeVariable. And
            # wrap the instruction causing the graph break inside a try..finally
            # block. See more details at
            # https://github.com/pytorch/torchdynamo/issues/207
            cleanup = []
            if len(self.block_stack) == 1 and isinstance(
                self.block_stack[0].with_context, GradModeVariable
            ):
                ctx_variable = self.block_stack[0].with_context

                cg = PyCodegen(self)
                setup_finally, cleanup = ctx_variable.reconstruct(
                    cg, resume_call_insts[0]
                )
                self.output.add_output_instructions(setup_finally)

            self.output.add_output_instructions([inst])

            # Add the cleanup instructions from try..finally block
            self.output.add_output_instructions(cleanup)
            self.output.add_output_instructions(
                resume_call_insts,
            )

        return wrapper

    return decorator


def is_none(x):
    return x is None


def is_not_none(x):
    return x is not None


class InstructionTranslatorBase(Checkpointable[InstructionTranslatorGraphState]):
    output: OutputGraph
    symbolic_locals: Dict[str, VariableTracker]
    symbolic_globals: Dict[str, VariableTracker]
    stack: List[VariableTracker]
    instruction_pointer: Optional[int]
    current_instruction: Instruction
    next_instruction: Optional[Instruction]
    block_stack: List[BlockStackEntry]
    lineno: int
    mutated_closure_cell_contents: Set[str]
    checkpoint: Optional[Tuple[Instruction, InstructionTranslatorGraphState]]
    random_calls: List[
        Tuple[Callable[..., object], Tuple[object, ...], Dict[str, object]]
    ]

    def has_backedge(self):
        cur_offset = self.current_instruction.offset
        assert self.instruction_pointer is not None
        for inst in self.instructions[self.instruction_pointer :]:
            if inst.opname in JUMP_OPNAMES:
                jump_offset = inst.argval
                if jump_offset < cur_offset:
                    return True
        return False

    def cell_and_freevars(self):
        if not hasattr(self, "_cell_and_freevars"):
            self._cell_and_freevars = tuple(
                self.code_options["co_cellvars"] or []
            ) + tuple(self.code_options["co_freevars"] or [])
        return self._cell_and_freevars

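    # Use bytecode liveness analysis to discard symbolic locals that can never
    # be read again, so dead values do not become graph inputs or get guarded.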
    def prune_dead_locals(self):
        reads = livevars_analysis(self.instructions, self.current_instruction)
        # implicit use by super()
        # reads = reads | {"__class__"}
        # output variables?
        reads = reads | set(self.cell_and_freevars())
        self.symbolic_locals = collections.OrderedDict(
            [(k, v) for k, v in self.symbolic_locals.items() if k in reads]
        )
        self.output.side_effects.prune_dead_object_new(self)

    def call_function(
        self,
        fn: VariableTracker,
        args: List[VariableTracker],
        kwargs: Dict[str, VariableTracker],
    ):
        assert isinstance(fn, VariableTracker)
        assert isinstance(args, list)
        assert isinstance(kwargs, dict)
        assert all(
            isinstance(x, VariableTracker)
            for x in itertools.chain(args, kwargs.values())
        )
        self.push(fn.call_function(self, args, kwargs))

    def update_locals_and_stack(self, oldvar: VariableTracker, newvar: VariableTracker):
        def repl(v: VariableTracker):
            if v.mutable_local is oldvar.mutable_local:
                return newvar
            return v

        def skip(v: VariableTracker):
            return oldvar.mutable_local not in v.recursively_contains

        cache: Dict[int, Tuple[object, object]] = dict()
        self.output.side_effects.apply(repl, cache, skip_fn=skip)
        self.stack = [
            VariableTracker.apply(repl, x, cache, skip_fn=skip) for x in self.stack
        ]
        for k, x in self.symbolic_locals.items():
            self.symbolic_locals[k] = VariableTracker.apply(
                repl, x, cache, skip_fn=skip
            )

    def replace_all(self, oldvar: VariableTracker, newvar: VariableTracker):
        if isinstance(oldvar.mutable_local, side_effects.MutableSideEffects):
            newvar = self.output.side_effects.mutation(oldvar, newvar)
        else:
            assert isinstance(oldvar.mutable_local, variables.base.MutableLocal)
            newvar = newvar.clone(mutable_local=variables.base.MutableLocal())
        self.update_locals_and_stack(oldvar, newvar)
        return newvar

    def inline_user_function_return(self, fn, args, kwargs):
        """
        Call a user-defined function by inlining it.
        """
        state = self.copy_graphstate()
        try:
            result = InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
            self.output.guards.update(fn.guards)
            return result
        except Exception:
            self.restore_graphstate(state)
            raise

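    # step() is the heart of the interpreter loop: it dispatches the current
    # instruction to the handler method named after its opcode. On an
    # Unsupported error with a usable checkpoint, it rolls back and compiles
    # the prefix traced up to that checkpoint instead of failing the frame.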
    def step(self):
        """Process exactly one instruction, return False if we should exit"""
        assert isinstance(self.instruction_pointer, int)
        inst = self.instructions[self.instruction_pointer]
        self.current_instruction = inst
        self.instruction_pointer += 1
        if self.instruction_pointer < len(self.instructions):
            self.next_instruction = self.instructions[self.instruction_pointer]
        else:
            self.instruction_pointer = None
            self.next_instruction = None
        if inst.starts_line and self.lineno != inst.starts_line:
            self.lineno = inst.starts_line
            log.debug(f"TRACE starts_line {self.f_code.co_filename}:{self.lineno}")

        if len(self.stack) == 0 and self.should_compile_partial_graph():
            self.checkpoint = inst, self.copy_graphstate()

        log.debug(f"TRACE {inst.opname} {inst.argval} {self.stack}")
        try:
            if not hasattr(self, inst.opname):
                unimplemented(f"missing: {inst.opname}")
            getattr(self, inst.opname)(inst)
            return inst.opname != "RETURN_VALUE"
        except BackendCompilerFailed:
            raise
        except Unsupported as exc:
            exc.real_stack.append(self.frame_summary())
            if self.empty_checkpoint():
                raise
            log.debug("step triggered compile", exc_info=True)
        except Exception as exc:
            real_stack = getattr(exc, "real_stack", [])
            real_stack.append(self.frame_summary())
            exc.real_stack = real_stack  # type: ignore[attr-defined]
            raise

        # generate code from checkpoint
        assert not self.output.output_instructions
        assert self.checkpoint is not None
        continue_inst, state = self.checkpoint
        self.restore_graphstate(state)
        self.output.compile_subgraph(
            self,
            partial_convert=True,
            reason=GraphCompileReason("step_unsupported", [self.frame_summary()]),
        )
        self.output.add_output_instructions(
            [create_jump_absolute(continue_inst)] + self.instructions
        )

    def run(self):
        try:
            self.output.push_tx(self)
            while (
                self.instruction_pointer is not None
                and not self.output.should_exit
                and self.step()
            ):
                pass
        except BackendCompilerFailed:
            raise
        except Exception as e:
            if config.replay_record_enabled:
                e.exec_record = self.exec_recorder.get_record()  # type: ignore[attr-defined]
            raise
        finally:
            self.output.pop_tx()
            # Cleanup the outputGraph to delete the held tensors. We perform the
            # cleanup only for InstructionTranslator and not
            # InliningInstructionTranslator. The InliningInstructionTranslator
            # mutates the output object and is restored to original state if
            # there was an exception.
            if isinstance(self, InstructionTranslator):
                self.output.cleanup()

    def push(self, val: Optional[VariableTracker]):
        assert val is None or isinstance(
            val, VariableTracker
        ), f"push expects VariableTracker, got {typestr(val)}"
        self.stack.append(val)

    def push_many(self, vals: List[VariableTracker]):
        for val in vals:
            self.push(val)

    def pop(self) -> VariableTracker:
        return self.stack.pop()

    def popn(self, n: int) -> List[VariableTracker]:
        assert n >= 0
        return list(reversed([self.pop() for _ in range(n)]))

    def LOAD_FAST(self, inst):
        name = inst.argval

        if name in self.f_locals and config.replay_record_enabled:
            self.exec_recorder.add_local_var(name, self.f_locals[name])

        if name.startswith(".") and name not in self.symbolic_locals:
            # This happens in dict/list comprehensions
            name = name.replace(".", "implicit")
        assert name not in self.cell_and_freevars()
        if name not in self.symbolic_locals:
            unimplemented("undefined LOAD_FAST")
        self.push(self.symbolic_locals[name])
        if name.startswith("___stack"):
            self.symbolic_locals.pop(name)

    def LOAD_DEREF(self, inst):
        assert inst.argval in self.cell_and_freevars()

        if inst.argval in self.f_locals and config.replay_record_enabled:
            self.exec_recorder.add_local_var(inst.argval, self.f_locals[inst.argval])

        if inst.argval not in self.symbolic_locals:
            unimplemented(f"undefined LOAD_DEREF {inst.argval}")
        self.push(self.symbolic_locals[inst.argval])

    def STORE_FAST(self, inst):
        self.symbolic_locals[inst.argval] = self.pop()

    def DELETE_FAST(self, inst):
        del self.symbolic_locals[inst.argval]

    STORE_DEREF = STORE_FAST

    def LOAD_CLOSURE(self, inst):
        self.push(ClosureVariable(name=inst.argval))

    def LOAD_CONST(self, inst):
        self.push(ConstantVariable(value=inst.argval))

    def get_global_source(self, name):
        if self.output.root_globals is self.f_globals:
            source = GlobalSource(name)
        else:
            if "__name__" in self.f_globals:
                source = AttrSource(
                    self.import_source(self.f_globals["__name__"]), name
                )
            else:
                mangled_name = f"___unnamed_scope_{id(self.f_globals)}"
                if mangled_name not in self.output.root_globals:
                    self.output.install_global(mangled_name, self.f_globals)
                source = GetItemSource(GlobalSource(mangled_name), name)
        return source

    def LOAD_GLOBAL(self, inst):
        name = inst.argval

        if config.replay_record_enabled:
            if name in self.f_globals:
                self.exec_recorder.add_global_var(name, self.f_globals[name])
            else:
                assert name in self.f_builtins
                self.exec_recorder.builtins[name] = self.f_builtins[name]

        if name in self.symbolic_globals:
            variable = self.output.side_effects[self.symbolic_globals[name]]
            self.push(self.output.side_effects.load_global(variable, name))
            return

        try:
            value = self.f_globals[name]
        except KeyError:
            return self.load_builtin(inst)

        source = self.get_global_source(name)
        self.push(VariableBuilder(self, source)(value))

    def STORE_GLOBAL(self, inst):
        value = self.pop()
        name = inst.argval
        source = self.get_global_source(name)
        if name not in self.symbolic_globals:
            self.symbolic_globals[name] = object()  # sentinel object
        variable = self.output.side_effects.track_global_existing(
            source, self.symbolic_globals[name]
        )
        self.output.side_effects.store_global(variable, name, value)

    def import_source(self, module_name):
        """Create an alias to a module for use in guards"""
        if "torch_package" in module_name:
            value = torch.package.package_importer._package_imported_modules[
                module_name
            ]
            alias = (
                module_name.replace(">", "_").replace("<", "_").replace(".", "_dot_")
            )
        else:
            value = importlib.import_module(module_name)
            alias = f"__import_{module_name.replace('.', '_dot_')}"
        f_globals = self.output.root_globals
        assert alias not in f_globals or f_globals[alias] is value
        f_globals[alias] = value
        self.output.update_co_names(alias)
        return GlobalSource(alias)

    def resolve_name(self, name, package, level):
        """
        Copied from the Cpython implementation of __import__
        Resolve a relative module name to an absolute one.
        https://github.com/python/cpython/blob/5a094f0255eea1db58fb2cf14c200971e64ec36e/Lib/importlib/_bootstrap.py#L902
        """
        bits = package.rsplit(".", level - 1)
        if len(bits) < level:
            raise ImportError("attempted relative import beyond top-level package")
        base = bits[0]
        return "{}.{}".format(base, name) if name else base

    def calc_package(self):
        """
        Copied from the Cpython implementation of __import__
        https://github.com/python/cpython/blob/5a094f0255eea1db58fb2cf14c200971e64ec36e/Lib/importlib/_bootstrap.py#L1090
        """
        package = self.f_globals.get("__package__")
        spec = self.f_globals.get("__spec__")
        if package is not None:
            if spec is not None and package != spec.parent:
                log.warning(
                    "__package__ != __spec__.parent "
                    f"({package!r} != {spec.parent!r})",
                    ImportWarning,
                    stacklevel=3,
                )  # type: ignore[call-arg]
            return package
        elif spec is not None:
            return spec.parent
        else:
            log.warning(
                "can't resolve package from __spec__ or __package__, "
                "falling back on __name__ and __path__",
                ImportWarning,
                stacklevel=3,
            )  # type: ignore[call-arg]
            package = self.f_globals["__name__"]
            if "__path__" not in self.f_globals:
                package = package.rpartition(".")[0]
        return package

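    # Illustrative example of the relative-import handling below: for
    # `from .sibling import mod` in a module whose package is "a.b", the
    # bytecode carries level=1 and argval "sibling", so
    # resolve_name("sibling", "a.b", 1) yields "a.b.sibling".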
    def IMPORT_NAME(self, inst):
        level, fromlist = self.popn(2)
        level = level.as_python_constant()
        fromlist = fromlist.as_python_constant()
        module_name = inst.argval

        # Are we replaying? if so, load recorded module
        recorded_name = (
            f"{ExecutionRecorder.LOCAL_MOD_PREFIX}_{level}_{fromlist}_{module_name}"
        )
        if recorded_name in self.f_globals:
            value = self.f_globals[recorded_name]
            source = GlobalSource(recorded_name)
        else:
            value = __import__(
                module_name,
                fromlist=fromlist,
                level=level,
                globals=self.f_globals,
            )

            if level != 0:
                pkg = self.calc_package()
                module_name = self.resolve_name(module_name, pkg, level)

            # For __import__, when the name variable is of the form package.module,
            # normally, the top-level package (the name up till the first dot) is
            # returned, not the module named by module_name. However, when a
            # non-empty fromlist argument is given, the module named by name is
            # returned. Therefore, we set the source correctly here.
            if not fromlist:
                top_level_module_name = module_name.partition(".")[0]
                source = self.import_source(top_level_module_name)
            else:
                source = self.import_source(module_name)

        if config.replay_record_enabled:
            self.exec_recorder.add_local_mod(recorded_name, value)

        if is_allowed(value):
            self.push(TorchVariable(value, source=source))
        elif istype(value, (types.ModuleType, DummyModule)):
            self.push(PythonModuleVariable(value, source=source))
        else:
            unimplemented(f"IMPORT_NAME {typestr(value)}")

    def IMPORT_FROM(self, inst):
        self.DUP_TOP(inst)
        self.LOAD_ATTR(inst)

    def load_builtin(self, inst):
        assert inst.argval in self.f_builtins
        val = self.f_builtins[inst.argval]

        if callable(val):
            assert is_builtin_callable(val)
            self.push(VariableBuilder(self, GlobalSource(inst.argval))(val))
        else:
            assert is_builtin_constant(val)
            self.push(ConstantVariable(value=val))

    def jump(self, inst):
        self.instruction_pointer = self.indexof[id(inst.target)]

    JUMP_FORWARD = jump
    JUMP_ABSOLUTE = jump

    POP_JUMP_IF_FALSE = generic_jump(operator.not_, False)
    POP_JUMP_IF_TRUE = generic_jump(operator.truth, False)
    JUMP_IF_FALSE_OR_POP = generic_jump(operator.not_, True)
    JUMP_IF_TRUE_OR_POP = generic_jump(operator.truth, True)

    def SETUP_LOOP(self, inst):
        # only exists in python<=3.7
        self.block_stack.append(BlockStackEntry(inst.target))

    def SETUP_EXCEPT(self, inst):
        # only exists in python<=3.7
        self.block_stack.append(BlockStackEntry(inst.target))

    def POP_BLOCK(self, inst):
        self.block_stack.pop()

    def SETUP_WITH(self, inst):
        ctx = self.pop()
        if not isinstance(ctx, ContextWrappingVariable):
            unimplemented(f"SETUP_WITH {ctx}")
        self.output.guards.update(ctx.guards)

        if isinstance(self, InstructionTranslator):
            self.block_stack.append(BlockStackEntry(inst.target, len(self.stack), ctx))
        else:
            # can't restore this while inlining
            self.block_stack.append(BlockStackEntry(inst.target))
        self.push(
            WithExitFunctionVariable(
                ctx,
                inst.target,
                **VariableTracker.propagate(ctx),
            )
        )
        self.push(ctx.enter(self))

    def SETUP_FINALLY(self, inst):
        self.block_stack.append(BlockStackEntry(inst.target))

    def BEGIN_FINALLY(self, inst):
        self.push(None)

    def WITH_CLEANUP_START(self, inst):
        exit, exc = self.popn(2)
        assert exc is None
        self.push(exc)
        self.push(exit.call_function(self, [ConstantVariable(None)] * 3, {}))

    def WITH_CLEANUP_FINISH(self, inst):
        self.popn(2)
        self.push(None)

    def END_FINALLY(self, inst):
        tos = self.pop()
        assert tos is None

    def FOR_ITER(self, inst):
        it = self.pop()
        if isinstance(it, ListIteratorVariable):
            self.output.guards.update(it.guards)
            try:
                val, next_iter = it.next_variables()
                self.replace_all(it, next_iter)
                self.push(next_iter)
                self.push(val)
            except StopIteration:
                self.jump(inst)
        else:
            unimplemented(f"FOR_ITER {typestr(it)}")

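    # COMPARE_OP constant-folds statically decidable comparisons; e.g.
    # (illustrative) `x is None` where x is a TensorVariable folds to False
    # without emitting a graph node, and two Python constants fold directly.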
    def COMPARE_OP(self, inst):
        left, right = self.popn(2)
        left = left.as_specialized(self)
        right = right.as_specialized(self)
        options = VariableTracker.propagate([left, right])
        op = inst.argval
        supported_any = dict(
            itertools.chain(
                supported_tensor_comparison_ops.items(),
                supported_const_comparison_ops.items(),
            )
        )
        if (
            isinstance(
                left,
                (
                    TensorVariable,
                    SymNodeVariable,
                    NNModuleVariable,
                    BaseListVariable,
                    UserDefinedVariable,
                    BaseUserFunctionVariable,
                    ConstDictVariable,
                ),
            )
            and isinstance(right, ConstantVariable)
            and right.value is None
            and op in supported_const_comparison_ops
        ):
            # is None
            self.push(
                ConstantVariable(
                    supported_const_comparison_ops[op](object(), right.value), **options
                )
            )
        elif (
            left.is_python_constant()
            and right.is_python_constant()
            and op in supported_any
        ):
            # constant fold
            self.push(
                ConstantVariable(
                    supported_any[op](
                        left.as_python_constant(), right.as_python_constant()
                    ),
                    **options,
                )
            )
        elif op in ("in", "not in"):
            self.push(right.call_method(self, "__contains__", [left], {}))
            if op == "not in":
                self.UNARY_NOT(inst)
        else:
            self.push(
                BuiltinVariable(supported_any[op], **options).call_function(
                    self, [left, right], {}
                )
            )

    def GET_ITER(self, inst):
        self.call_function(BuiltinVariable(iter), [self.pop()], {})

    @break_graph_if_unsupported(push=1)
    def CALL_FUNCTION(self, inst):
        args = self.popn(inst.argval)
        fn = self.pop()
        self.call_function(fn, args, {})

    @break_graph_if_unsupported(push=1)
    def CALL_FUNCTION_EX(self, inst):
        if inst.argval == 0:
            kwargsvars = ConstDictVariable({}, dict)
            argsvars = self.pop()
        elif inst.argval == 1:
            kwargsvars = self.pop()
            argsvars = self.pop()
        else:
            unimplemented("CALL_FUNCTION_EX")
        fn = self.pop()
        self.output.guards.update(argsvars.guards)
        self.output.guards.update(kwargsvars.guards)

        if (
            isinstance(fn, GetAttrVariable)
            and isinstance(fn.obj, TensorVariable)
            and fn.name == "view"
            and isinstance(argsvars, (ConstantVariable, TensorVariable))
        ):
            # Hack to handle special case in some bert models. Converts
            # x.view(*shape) into x.view(shape), which is correct for view()
            # but not generally. See test_transpose_for_scores().
            argsvars = TupleVariable([argsvars])

        if not isinstance(
            argsvars, BaseListVariable
        ) and argsvars.has_unpack_var_sequence(self):
            argsvars = TupleVariable(argsvars.unpack_var_sequence(self))

        if not isinstance(argsvars, BaseListVariable) or not isinstance(
            kwargsvars, ConstDictVariable
        ):
            unimplemented(f"non-static call {typestr(argsvars)} {typestr(kwargsvars)}")

        self.call_function(fn, argsvars.items, kwargsvars.items)

    @break_graph_if_unsupported(push=1)
    def CALL_FUNCTION_KW(self, inst):
        argnames = self.pop()
        args = self.popn(inst.argval)
        fn = self.pop()
        assert isinstance(argnames, ConstantVariable)
        argnames = argnames.value
        args, kwargs_list = args[: -len(argnames)], args[-len(argnames) :]
        kwargs = dict(zip(argnames, kwargs_list))
        assert len(kwargs) == len(argnames)
        self.call_function(fn, args, kwargs)

    def LOAD_METHOD(self, inst):
        self.LOAD_ATTR(inst)
        self.push(self.pop())
        self.push(None)

    def CALL_METHOD(self, inst):
        args = self.popn(inst.argval)
        dummy = self.pop()
        assert dummy is None
        fn = self.pop()
        self.call_function(fn, args, {})

    def LOAD_ATTR(self, inst):
        obj = self.pop()
        result = BuiltinVariable(getattr).call_function(
            self, [obj, ConstantVariable(inst.argval)], {}
        )
        self.push(result)

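    # Attribute stores are routed through the setattr builtin so they become
    # tracked side effects; if that is unsupported, the traced prefix is
    # compiled, the STORE_ATTR runs eagerly, and execution resumes after it.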
    def STORE_ATTR(self, inst):
        prior = self.copy_graphstate()
        val, obj = self.popn(2)

        if isinstance(obj, NNModuleVariable):
            # We don't allow side effects during export
            # https://github.com/pytorch/torchdynamo/issues/1475
            assert (
                not self.export
            ), f"Mutating module attribute {inst.argval} during export."

        try:
            self.output.guards.update(
                BuiltinVariable(setattr)
                .call_function(self, [obj, ConstantVariable(inst.argval), val], {})
                .guards
            )
            return
        except Unsupported as e:
            if not self.should_compile_partial_graph():
                raise
            log.debug("STORE_ATTR triggered compile", exc_info=True)
            e.remove_from_stats()
            e.add_to_stats("graph_break")
        self.restore_graphstate(prior)

        # break the graph
        self.output.compile_subgraph(
            self, reason=GraphCompileReason("store_attr", [self.frame_summary()])
        )
        self.output.add_output_instructions([inst])
        self.popn(2)
        self.output.add_output_instructions(
            self.create_call_resume_at(self.next_instruction)
        )

    def create_call_resume_at(self, offset):
        raise AssertionError(
            f"create_call_resume_at not overridden by subclass {type(self)}"
        )

    def should_compile_partial_graph(self) -> bool:
        raise AssertionError(
            f"should_compile_partial_graph not overridden by subclass {type(self)}"
        )

    @break_graph_if_unsupported(push=0)
    def STORE_SUBSCR(self, inst):
        val, obj, key = self.popn(3)
        result = obj.call_method(self, "__setitem__", [key, val], {})
        # no result is pushed, so need to lift the guards to global
        self.output.guards.update(result.guards)

    def BUILD_TUPLE(self, inst):
        items = self.popn(inst.argval)
        options = VariableTracker.propagate(items)
        self.push(TupleVariable(items, **options))

    def BUILD_SLICE(self, inst):
        items = self.popn(inst.argval)
        options = VariableTracker.propagate(items)
        self.push(
            SliceVariable(
                [x.as_specialized(self) for x in items],
                **options,
            )
        )

    def BUILD_LIST(self, inst):
        items = self.popn(inst.argval)
        options = VariableTracker.propagate(items)
        self.push(ListVariable(items, mutable_local=MutableLocal(), **options))

    def BUILD_LIST_UNPACK(self, inst, cls=ListVariable):
        seqs = self.popn(inst.argval)
        options = VariableTracker.propagate(seqs)
        items = list()
        for seq in seqs:
            try:
                items.extend(seq.unpack_var_sequence(self))
            except NotImplementedError:
                unimplemented(f"BUILD_LIST_UNPACK {seq}")
        self.push(cls(items, mutable_local=MutableLocal(), **options))

    def BUILD_TUPLE_UNPACK(self, inst):
        self.BUILD_LIST_UNPACK(inst, cls=TupleVariable)

    BUILD_TUPLE_UNPACK_WITH_CALL = BUILD_TUPLE_UNPACK

    def BUILD_MAP(self, inst):
        items = self.popn(inst.argval * 2)
        options = VariableTracker.propagate(items)
        result = dict()
        for k, v in zip(items[::2], items[1::2]):
            assert isinstance(k, (ConstantVariable, EnumVariable)) or (
                isinstance(k, TensorVariable) and k.specialized_value is not None
            )
            result[ConstDictVariable.get_key(k)] = v
        assert len(result) == len(items) / 2
        self.push(
            ConstDictVariable(result, dict, mutable_local=MutableLocal(), **options)
        )

    def BUILD_CONST_KEY_MAP(self, inst):
        keys = self.pop()
        values = self.popn(inst.argval)
        options = VariableTracker.propagate([keys] + values)
        assert isinstance(keys, ConstantVariable)
        keys = keys.value
        assert istype(keys, tuple)
        assert len(keys) == len(values)
        self.push(
            ConstDictVariable(
                dict(zip(keys, values)),
                dict,
                mutable_local=MutableLocal(),
                **options,
            )
        )

    def MAP_ADD(self, inst):
        k, v = self.popn(2)
        assert inst.argval > 0
        obj = self.stack[-inst.arg]
        assert isinstance(obj, ConstDictVariable)
        assert obj.mutable_local
        items = dict(obj.items)
        items[k.as_python_constant()] = v
        self.replace_all(
            obj,
            ConstDictVariable(
                items,
                obj.user_cls,
                **VariableTracker.propagate([obj, k, v]),
            ),
        )

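    # VariableTrackers are treated as immutable: in-place ops like LIST_APPEND
    # build a new tracker and swap it into the stack, locals, and side-effect
    # state via replace_all instead of mutating the old one.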
    def LIST_APPEND(self, inst):
        v = self.pop()
        assert inst.argval > 0
        obj = self.stack[-inst.arg]
        assert isinstance(obj, ListVariable)
        assert obj.mutable_local
        # only copy if the new obj contains other mutables
        new_rec_contains = obj.recursively_contains
        if v.recursively_contains or v.mutable_local:
            new_rec_contains = obj.recursively_contains.union(v.recursively_contains)
            if v.mutable_local:
                new_rec_contains.add(v.mutable_local)
        self.replace_all(
            obj,
            ListVariable(
                obj.items + [v],
                recursively_contains=new_rec_contains,
                regen_guards=False,
                **VariableTracker.propagate([obj, v]),
            ),
        )

    def MAKE_FUNCTION(self, inst):
        flags = inst.arg
        old_stack = list(self.stack)
        fn_name = self.pop()
        code = self.pop()
        defaults = None
        closure = None
        annotations = None
        kwdefaults = None

        if flags & 0x08:
            closure = self.pop()
        if flags & 0x04:
            annotations = self.pop()
        if flags & 0x02:
            kwdefaults = self.pop()
        if flags & 0x01:
            defaults = self.pop()

        options = VariableTracker.propagate(old_stack[len(self.stack) :])
        self.push(
            NestedUserFunctionVariable(
                fn_name,
                code,
                self.f_globals,
                defaults,
                kwdefaults,
                annotations,
                closure,
                closure_scope=self,
                **options,
            )
        )

    def UNPACK_SEQUENCE(self, inst):
        seq = self.pop()
        if isinstance(seq, BaseListVariable):
            self.output.guards.update(seq.guards)
            val = seq.unpack_var_sequence(self)
        elif seq.is_python_constant() and isinstance(seq, ConstantVariable):
            val = seq.unpack_var_sequence(self)
        elif isinstance(seq, TensorVariable):
            val = seq.unpack_var_sequence(self, idxes=range(inst.argval))
        elif isinstance(seq, GetAttrVariable) and isinstance(seq.obj, TensorVariable):
            # x, y = a.shape
            proxy = getattr(seq.obj.as_proxy(), seq.name)
            options = VariableTracker.propagate(self)
            val = [wrap_fx_proxy(self, proxy[i], **options) for i in range(inst.argval)]
        else:
            unimplemented(f"UNPACK_SEQUENCE {seq}")
        assert len(val) == inst.argval
        for i in reversed(val):
            self.push(i)

    def UNPACK_EX(self, inst):
        assert 0 <= inst.argval <= 0xFFFF
        prefix = inst.argval & 0xFF  # low byte
        suffix = inst.argval >> 8  # high byte
        seq = self.pop()
        options = VariableTracker.propagate(seq)
        if seq.has_unpack_var_sequence(self):
            vals = list(seq.unpack_var_sequence(self))
            assert len(vals) >= prefix + suffix
            vals_prefix = vals[:prefix]
            vals_list = vals[prefix : len(vals) - suffix]
            vals_suffix = vals[len(vals) - suffix :]
            for item in reversed(vals_suffix):
                self.push(item.add_options(options))
            self.push(TupleVariable(vals_list, **options))
            for item in reversed(vals_prefix):
                self.push(item.add_options(options))
        else:
            unimplemented(f"UNPACK_EX {seq}")

    def NOP(self, inst):
        pass

    def POP_TOP(self, inst):
        self.pop()

    def ROT_TWO(self, inst):
        a = self.pop()
        b = self.pop()
        self.push(a)
        self.push(b)

    def ROT_THREE(self, inst):
        a = self.pop()
        b = self.pop()
        c = self.pop()
        self.push(a)
        self.push(c)
        self.push(b)

    def ROT_FOUR(self, inst):
        a = self.pop()
        b = self.pop()
        c = self.pop()
        d = self.pop()
        self.push(a)
        self.push(d)
        self.push(c)
        self.push(b)

    def DUP_TOP(self, inst):
        a = self.pop()
        self.push(a)
        self.push(a)

    def DUP_TOP_TWO(self, inst):
        a = self.pop()
        b = self.pop()
        self.push(b)
        self.push(a)
        self.push(b)
        self.push(a)

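    # FORMAT_VALUE flag layout (per the CPython dis docs): the low two bits
    # select the conversion (0x00 none, 0x01 !s, 0x02 !r, 0x03 !a) and bit
    # 0x04 means a format spec is on the stack; e.g. (illustrative)
    # f"{x!r:>10}" uses flags 0x02 | 0x04.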
    def FORMAT_VALUE(self, inst):
        flags = inst.arg
        if (flags & 0x04) == 0x04:
            fmt_spec = self.pop()
        else:
            fmt_spec = ConstantVariable("")

        value = self.pop()
        if isinstance(value, SymNodeVariable):
            value = ConstantVariable(str(value.sym_num))
        if (flags & 0x03) == 0x01:
            value = BuiltinVariable(str).call_function(self, [value], {})
        elif (flags & 0x03) == 0x02:
            value = BuiltinVariable(repr).call_function(self, [value], {})
        elif (flags & 0x03) == 0x03:
            value = BuiltinVariable(ascii).call_function(self, [value], {})

        fmt_var = ConstantVariable(
            "{:" + fmt_spec.as_python_constant() + "}"
        ).add_options(fmt_spec)

        self.call_function(BuiltinVariable(str.format), [fmt_var, value], {})

    def BUILD_STRING(self, inst):
        result = ""
        for _ in range(inst.arg):
            str_var = self.pop()
            assert isinstance(str_var, ConstantVariable)
            result = str_var.value + result
        self.push(ConstantVariable(value=result))

    def IS_OP(self, inst):
        assert inst.argval == 0 or inst.argval == 1
        if inst.argval == 0:
            new_argval = "is"
        else:
            new_argval = "is not"
        new_inst = create_instruction("COMPARE_OP", argval=new_argval)
        self.COMPARE_OP(new_inst)

    def CONTAINS_OP(self, inst):
        assert inst.argval == 0 or inst.argval == 1
        left, right = self.popn(2)
        op = inst.argval
        self.push(right.call_method(self, "__contains__", [left], {}))
        if op == 1:
            self.UNARY_NOT(inst)

    def LIST_EXTEND(self, inst):
        v = self.pop()
        assert inst.argval > 0
        obj = self.stack[-inst.arg]
        assert isinstance(obj, ListVariable)
        assert obj.mutable_local
        obj.call_method(self, "extend", [v], {})

    def LIST_TO_TUPLE(self, inst):
        self.push(BuiltinVariable(tuple).call_function(self, [self.pop()], {}))

    def DICT_MERGE(self, inst):
        v = self.pop()
        assert inst.argval > 0
        obj = self.stack[-inst.arg]
        assert isinstance(obj, ConstDictVariable)
        assert obj.mutable_local
        obj.call_method(self, "update", [v], {})

    def GEN_START(self, inst):
        self.pop()

    def GET_LEN(self, inst):
        tos = self.stack[-1]
        if tos.is_python_constant():
            self.push(ConstantVariable(len(tos.as_python_constant())))
        else:
            self.push(tos.call_method(self, "__len__", [], {}))

    def MATCH_MAPPING(self, inst):
        tos = self.stack[-1]
        assert isinstance(tos, ConstDictVariable)
        if isinstance(tos.items, collections.abc.Mapping):
            self.push(ConstantVariable(True))
        else:
            self.push(ConstantVariable(False))

    def MATCH_SEQUENCE(self, inst):
        tos = self.stack[-1]
        assert tos.is_python_constant()
        tos_value = tos.as_python_constant()
        if isinstance(tos_value, collections.abc.Sequence) and not isinstance(
            tos_value, (str, bytes, bytearray)
        ):
            self.push(ConstantVariable(True))
        else:
            self.push(ConstantVariable(False))

    def MATCH_KEYS(self, inst):
        tos = self.stack[-1]
        assert tos.is_python_constant()
        keys = tos.as_python_constant()
        tos1 = self.stack[-2]
        assert isinstance(tos1, ConstDictVariable)
        match_obj = tos1.items
        if all(key in match_obj for key in keys):
            self.push(TupleVariable([match_obj[key] for key in keys]))
            self.push(ConstantVariable(True))
        else:
            self.push(ConstantVariable(None))
            self.push(ConstantVariable(False))

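    # The unary/binary/in-place opcode handlers below all share the
    # module-level stack_op helper, dispatching through BuiltinVariable on the
    # matching function from the `operator` module.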
    UNARY_POSITIVE = stack_op(operator.pos)
    UNARY_NEGATIVE = stack_op(operator.neg)
    UNARY_NOT = stack_op(operator.not_)
    UNARY_INVERT = stack_op(operator.invert)

    BINARY_POWER = stack_op(operator.pow)
    BINARY_MULTIPLY = stack_op(operator.mul)
    BINARY_MATRIX_MULTIPLY = stack_op(operator.matmul)
    BINARY_FLOOR_DIVIDE = stack_op(operator.floordiv)
    BINARY_TRUE_DIVIDE = stack_op(operator.truediv)
    BINARY_MODULO = stack_op(operator.mod)
    BINARY_REMAINDER = stack_op(operator.mod)
    BINARY_ADD = stack_op(operator.add)
    BINARY_SUBTRACT = stack_op(operator.sub)
    BINARY_SUBSCR = break_graph_if_unsupported(push=1)(stack_op(operator.getitem))
    BINARY_LSHIFT = stack_op(operator.lshift)
    BINARY_RSHIFT = stack_op(operator.rshift)
    BINARY_AND = stack_op(operator.and_)
    BINARY_OR = stack_op(operator.or_)
    BINARY_XOR = stack_op(operator.xor)

    INPLACE_POWER = stack_op(operator.ipow)
    INPLACE_MULTIPLY = stack_op(operator.imul)
    INPLACE_MATRIX_MULTIPLY = stack_op(operator.imatmul)
    INPLACE_FLOOR_DIVIDE = stack_op(operator.ifloordiv)
    INPLACE_TRUE_DIVIDE = stack_op(operator.itruediv)
    INPLACE_MODULO = stack_op(operator.imod)
    INPLACE_REMAINDER = stack_op(operator.imod)
    INPLACE_ADD = stack_op(operator.iadd)
    INPLACE_SUBTRACT = stack_op(operator.isub)
    INPLACE_LSHIFT = stack_op(operator.ilshift)
    INPLACE_RSHIFT = stack_op(operator.irshift)
    INPLACE_AND = stack_op(operator.iand)
    INPLACE_XOR = stack_op(operator.ixor)
    INPLACE_OR = stack_op(operator.ior)

    # 3.11 opcodes
    # note: opcodes implemented with `pass` are intentional no-ops
    def RESUME(self, inst):
        pass

    def BINARY_OP(self, inst):
        if sys.version_info >= (3, 11):
            opname = dis._nb_ops[inst.arg][0][3:]
            if opname.startswith("INPLACE"):
                return getattr(self, "INPLACE_" + opname[8:])(inst)
            return getattr(self, "BINARY_" + opname)(inst)
        else:
            unimplemented("BINARY_OP requires Python 3.11+")

    def COPY(self, inst):
        self.push(self.stack[-inst.arg])

    def SWAP(self, inst):
        self.stack[-1], self.stack[-inst.arg] = self.stack[-inst.arg], self.stack[-1]

    JUMP_BACKWARD = jump
    JUMP_BACKWARD_NO_INTERRUPT = jump

    POP_JUMP_FORWARD_IF_TRUE = generic_jump(operator.truth, False)
    POP_JUMP_BACKWARD_IF_TRUE = generic_jump(operator.truth, False)
    POP_JUMP_FORWARD_IF_FALSE = generic_jump(operator.not_, False)
    POP_JUMP_BACKWARD_IF_FALSE = generic_jump(operator.not_, False)
    POP_JUMP_FORWARD_IF_NOT_NONE = generic_jump(is_not_none, False)
    POP_JUMP_BACKWARD_IF_NOT_NONE = generic_jump(is_not_none, False)
    POP_JUMP_FORWARD_IF_NONE = generic_jump(is_none, False)
    POP_JUMP_BACKWARD_IF_NONE = generic_jump(is_none, False)

    def CACHE(self, inst):
        pass

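    # Checkpointable support: copy_graphstate() snapshots everything the
    # translator mutates and restore_graphstate() rolls back to a snapshot,
    # e.g. after a failed __bool__ inline or before compiling a partial graph.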
    def copy_graphstate(self) -> InstructionTranslatorGraphState:
        """Create a checkpoint of the current state by copying everything"""
        return InstructionTranslatorGraphState(
            self.output.copy_graphstate(),
            collections.OrderedDict(self.symbolic_locals),
            list(self.stack),
            list(self.block_stack),
            self.instruction_pointer,
            self.current_instruction,
            self.next_instruction,
            self.lineno,
        )

    def restore_graphstate(self, state: InstructionTranslatorGraphState):
        """Restore a checkpoint created by self.copy_graphstate()"""
        (
            output_state,
            self.symbolic_locals,
            self.stack,
            self.block_stack,
            self.instruction_pointer,
            self.current_instruction,
            self.next_instruction,
            self.lineno,
        ) = state
        self.output.restore_graphstate(output_state)

    def empty_checkpoint(self):
        if self.checkpoint is None:
            return True
        output_graphstate = self.checkpoint[1][0]
        graphstate = self.checkpoint[1][1:]
        state = (*output_graphstate, *graphstate)
        for obj in state:
            if isinstance(obj, Sized):
                if len(obj) != 0:
                    return False
        return True

    def format_frame_summary(self, additional_stack_frames=None):
        if additional_stack_frames is None:
            additional_stack_frames = []
        return "".join(
            traceback.format_list(
                [self.frame_summary()] + list(reversed(additional_stack_frames))
            )
        )

    def frame_summary(self):
        return traceback.FrameSummary(
            getattr(self.f_code, "co_filename", ""),
            self.lineno,
            getattr(self.f_code, "co_name", ""),
            lookup_line=False,
        )

    def store_dict_key(self, name, value):
        self.output.guards.add(
            GlobalWeakRefSource(name).make_guard(GuardBuilder.WEAKREF_ALIVE)
        )
        if name not in self.output.root_globals:
            self.output.install_global(name, weakref.ref(value))

    @property
    def fake_mode(self):
        return self._fake_mode

    def find_symbolic_locals_name(self, tensor_variable):
        for key, value in self.symbolic_locals.items():
            if value is tensor_variable:
                return key
        return None

    def __init__(
        self,
        output: OutputGraph,
        instructions: List[Instruction],
        f_locals: Dict[str, Any],
        f_globals: Dict[str, Any],
        f_builtins: Dict[str, Any],
        code_options: Dict[str, Any],
        symbolic_locals: Dict[str, VariableTracker],
        symbolic_globals: Dict[str, VariableTracker],
        f_code: types.CodeType,
        export: bool,
    ):
        super().__init__()

        # Mutable state checkpointed by copy_graphstate()
        self.output = output
        self.symbolic_locals = symbolic_locals
        self.symbolic_globals = symbolic_globals
        self.stack = []
        self.instruction_pointer = 0
        self.current_instruction = create_instruction("NOP")
        self.next_instruction = None
        self.block_stack = []
        self.lineno = code_options["co_firstlineno"]

        # Properties of the input/output code
        self.instructions: List[Instruction] = instructions
        self.indexof: Dict[int, int] = {id(i): n for n, i in enumerate(instructions)}
        self.f_locals: Dict[
            str, Any
        ] = f_locals  # needed for recording accessed locals for replay
        self.f_globals: Dict[str, Any] = f_globals
        self.f_builtins: Dict[str, Any] = f_builtins
        self.code_options: Dict[str, Any] = code_options
        self.f_code: types.CodeType = f_code

        # Execution record for replaying errors
        self.exec_recorder = ExecutionRecorder(code=f_code, code_options=code_options)
        # Stack of module being parsed, current nn.module is at the end of ordered dict
        self.nn_module_stack: Dict[str, str] = {}
        # Flag to indicate whether tracing is used for export.
        self.export = export

        self._fake_mode = output.tracing_context.fake_mode

        self.checkpoint = None
        self.random_calls = []

        if sys.version_info >= (3, 10):
            from .resume_execution import (
                CO_ASYNC_GENERATOR,
                CO_COROUTINE,
                CO_GENERATOR,
                CO_ITERABLE_COROUTINE,
            )

            if f_code.co_flags & (
                CO_GENERATOR | CO_COROUTINE | CO_ITERABLE_COROUTINE | CO_ASYNC_GENERATOR
            ):
                self.push(BuiltinVariable(None))

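# InstructionTranslator traces the root frame and may compile partial graphs
# with generated resume functions; InliningInstructionTranslator (below)
# traces callees all-or-nothing into the parent's OutputGraph.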
class InstructionTranslator(InstructionTranslatorBase):
    def __init__(
        self,
        instructions: List[Instruction],
        f_code,
        f_locals,
        f_globals,
        f_builtins,
        code_options,
        compiler_fn,
        one_graph,
        export,
        mutated_closure_cell_contents: Set[str],
    ):
        super().__init__(
            output=OutputGraph(f_globals, code_options, compiler_fn, self),
            instructions=instructions,
            f_locals=f_locals,
            f_globals=f_globals,
            f_builtins=f_builtins,
            code_options=code_options,
            symbolic_locals=collections.OrderedDict(),  # set below
            # A global var is inserted only after a STORE_GLOBAL happens to it
            symbolic_globals=collections.OrderedDict(),
            f_code=f_code,
            export=export,
        )
        self.one_graph: bool = one_graph
        self.export = export
        self.mutated_closure_cell_contents = mutated_closure_cell_contents
        if self.export:
            assert (
                self.one_graph
            ), "Export without one graph - something has gone wrong."

        vars = list(code_options["co_varnames"])
        vars.extend(x for x in self.cell_and_freevars() if x not in vars)

        self.symbolic_locals = collections.OrderedDict(
            (
                k,
                VariableBuilder(
                    self,
                    LocalInputSource(k, code_options["co_varnames"].index(k))
                    if k in code_options["co_varnames"]
                    else LocalSource(k),
                )(f_locals[k]),
            )
            for k in vars
            if k in f_locals
        )

        # symbolic_locals contains the mapping from original f_locals to the
        # Variable objects. During the Variable building phase, each object also
        # has its associated guards. At the end, we will accumulate these
        # guards.
        #
        # One way of handling these guards is to just accumulate all of them
        # right now. However, many f_locals might not be used in the frame and
        # thus can unnecessarily increase guard execution overhead. Therefore,
        # we selectively update output.guards as we run the Python Bytecode
        # instruction by instruction.
        #
        # An exception here is list/dict variables. Guards related to these
        # variables have indexed access, like Tensor_match on args[0], and if
        # args is not used in this frame, we will miss a LIST_LENGTH check like
        # len(args) == 2. Missing the LIST_LENGTH check causes problems for the
        # next invocation when args is not a list, and args[0] is a runtime
        # error. Therefore, we recursively add guards for list/dict variables here.
        for val in self.symbolic_locals.values():
            if isinstance(
                val, (ListIteratorVariable, BaseListVariable, ConstDictVariable)
            ):
                local_guards = VariableTracker.propagate(val)["guards"]
                index_guards = [
                    guard
                    for guard in local_guards
                    if guard.create_fn
                    in (
                        GuardBuilder.LIST_LENGTH,
                        GuardBuilder.DICT_KEYS,
                        GuardBuilder.ODICT_KEYS,
                        GuardBuilder.TUPLE_ITERATOR_LEN,
                    )
                ]
                self.output.guards.update(index_guards)

        self._freevars_ids = dict()
        for name in self.code_options["co_freevars"]:
            if name in f_locals:
                self._freevars_ids[name] = id(f_locals[name])

    def run(self):
        _step_logger()(logging.INFO, f"torchdynamo start tracing {self.f_code.co_name}")
        super().run()

    def match_nested_cell(self, name, cell):
        """Match a cell in this method to one in a function we are inlining"""
        value = cell.cell_contents
        # TODO(jansel): check the id of the cell rather than the contents
        if id(value) != self._freevars_ids.get(name):
            return None
        return self.symbolic_locals[name]

    def should_compile_partial_graph(self):
        return all(b.can_restore() for b in self.block_stack) and not self.one_graph

    def create_call_resume_at(self, inst):
        self.instruction_pointer = None

        if inst.opname == "RETURN_VALUE":
            return [create_instruction("RETURN_VALUE")]

        reads = livevars_analysis(self.instructions, inst)
        argnames = tuple(
            k
            for k in self.symbolic_locals.keys()
            if k in reads and k not in self.cell_and_freevars()
        )
        nargs = len(self.stack) + len(argnames)

        name = unique_id(f"__resume_at_{inst.offset}")

        new_code: types.CodeType = ContinueExecutionCache.lookup(
            self.f_code,
            self.lineno,
            inst.offset,
            len(self.stack),
            argnames,
            tuple(b.resume_fn() for b in self.block_stack),
        )

        cg = PyCodegen(self)

        if new_code.co_freevars:
            cg.make_function_with_closure(name, new_code, len(self.stack))
        else:
            self.output.install_global(
                name, types.FunctionType(new_code, self.f_globals, name)
            )
            cg.extend_output(cg.load_function_name(name, len(self.stack)))

        cg.extend_output([cg.create_load(k) for k in argnames])
        cg.extend_output(
            [
                create_instruction("CALL_FUNCTION", nargs),
                create_instruction("RETURN_VALUE"),
            ]
        )
        return cg.get_instructions()

    def RETURN_VALUE(self, inst):
        if self.output.count_calls() == 0:
            raise exc.SkipFrame("because no content in function call")
        self.instruction_pointer = None
        _step_logger()(
            logging.INFO,
            f"torchdynamo done tracing {self.f_code.co_name} (RETURN_VALUE)",
        )
        log.debug("RETURN_VALUE triggered compile")
        self.output.compile_subgraph(
            self, reason=GraphCompileReason("return_value", [self.frame_summary()])
        )
        self.output.add_output_instructions([create_instruction("RETURN_VALUE")])


class InliningInstructionTranslator(InstructionTranslatorBase):
    """Trace and inline a called method"""

    symbolic_result: Optional[TensorVariable]

    @classmethod
    def inline_call(cls, parent, func, args, kwargs):
        with patch.dict(counters, {"unimplemented": counters["inline_call"]}):
            return cls.inline_call_(parent, func, args, kwargs)

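    # inline_call_ first validates that the callee is safe to inline (no
    # __self__, not disallowed, not in skipfiles), binds the arguments, then
    # runs a fresh tracer over the callee's bytecode sharing the parent's
    # OutputGraph.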
    @staticmethod
    def inline_call_(parent, func, args, kwargs):
        assert isinstance(
            func,
            (UserFunctionVariable, NestedUserFunctionVariable),
        )
        if func.has_self():
            unimplemented("inline with __self__")

        if func.get_name() == "patched_init":
            unimplemented("Patched init cannot be inlined.")

        try:
            if id(func.get_function()) in allowed_functions._disallowed_function_ids:
                unimplemented(f"inlining disallowed: {func.get_function()}")
        except NotImplementedError:
            pass  # closures

        if skipfiles.check(
            func.get_filename()
        ) and not skipfiles.is_torch_inline_allowed(func.get_filename()):
            unimplemented(
                f"inline in skipfiles: {func.fn.__qualname__} | "
                f"{func.get_name()} {func.get_filename()}"
            )

        try:
            sub_locals, closure_cells = func.bind_args(parent, args, kwargs)
        except TypeError as e:
            log.warning(
                f"{func.get_filename()} {func.get_function()} {args} {kwargs} {e}"
            )
            unimplemented("arg mismatch inlining")

        for v in itertools.chain(sub_locals.values(), closure_cells.values()):
            if not isinstance(v, VariableTracker):
                unimplemented(f"unconverted arg {v}")

        code: types.CodeType = func.get_code()
        if code.co_name in ("__setitem__", "__setattr__"):
            unimplemented(f"inline {code.co_name}")

        log.debug(f"INLINING {code} \n {dis.Bytecode(code).dis()} \n")

        tracer: InliningInstructionTranslator
        if is_generator(code):
            tracer = InliningGeneratorInstructionTranslator(
                parent, code, sub_locals, parent.symbolic_globals, closure_cells, func
            )
        else:
            tracer = InliningInstructionTranslator(
                parent, code, sub_locals, parent.symbolic_globals, closure_cells, func
            )

        try:
            tracer.run()
        except exc.SkipFrame as e:
            msg = f"SKIPPED INLINING {code}: {e}"
            log.debug(msg)
            raise Unsupported(msg) from e
        except Exception as e:
            log.debug(f"FAILED INLINING {code}")
            raise
        assert tracer.symbolic_result is not None
        func.export_freevars(parent, tracer)

        if tracer.f_globals is parent.f_globals:
            # Merge symbolic_globals back if parent and child are in the same namespace
            parent.symbolic_globals.update(tracer.symbolic_globals)

        log.debug(f"DONE INLINING {code}")

        if is_generator(code):
            assert isinstance(tracer, InliningGeneratorInstructionTranslator)
            assert tracer.symbolic_result.as_python_constant() is None
            return ListIteratorVariable(
                tracer.generated_items,
                mutable_local=MutableLocal(),
                **VariableTracker.propagate(tracer.symbolic_result),
            )
        else:
            return tracer.symbolic_result

    def __init__(
        self,
        parent: InstructionTranslatorBase,
        code: types.CodeType,
        symbolic_locals: Dict[str, VariableTracker],
        symbolic_globals: Dict[str, VariableTracker],
        closure_cells: Dict[str, VariableTracker],
        funcvar: BaseUserFunctionVariable,
    ):
        f_globals = funcvar.get_globals()
        f_builtins = f_globals["__builtins__"]
        if not isinstance(f_builtins, dict):
            f_builtins = f_builtins.__dict__
        super().__init__(
            output=parent.output,
            f_locals={},
            f_globals=f_globals,
            f_builtins=f_builtins,
            symbolic_locals=symbolic_locals,
            symbolic_globals=symbolic_globals,
            instructions=cleaned_instructions(code),
            code_options={k: getattr(code, k) for k in dir(code)},
            f_code=code,
            export=parent.export,
        )
        self.parent = parent
        self.symbolic_result = None
        self.closure_cells = closure_cells
        self.nn_module_stack = parent.nn_module_stack.copy()

    @property
    def fake_mode(self):
        return self.parent.fake_mode

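    # Writing to a closure cell while inlining is only supported for cells
    # created during tracing; mutating a pre-existing cell records it in
    # mutated_closure_cell_contents and restarts analysis so the root tracer
    # can handle it.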
    def STORE_DEREF(self, inst):
        if inst.argval in self.closure_cells:
            cell = self.closure_cells[inst.argval]
            val = self.pop()
            if isinstance(cell, ClosureVariable):
                self.output.root_tx.symbolic_locals[cell.name] = val
            else:
                self.output.side_effects.store_cell(cell, val)
        else:
            maybe_cell = self.symbolic_locals.get(inst.argval)
            if isinstance(
                maybe_cell,
                variables.NewCellVariable,
            ):
                self.output.side_effects.store_cell(
                    self.symbolic_locals[inst.argval], self.pop()
                )
            else:
                if (
                    maybe_cell is not None
                    and maybe_cell.source.name()
                    not in self.parent.mutated_closure_cell_contents
                ):
                    # Why is the source name here unique?
                    # mutated_closure_cell_contents is a per-frame
                    # concept, and sources identify, e.g., particular
                    # locals from the frame. If you had two locals,
                    # they'll get different source names, and therefore
                    # differ here.
                    self.parent.mutated_closure_cell_contents.add(
                        maybe_cell.source.name()
                    )
                    raise exc.RestartAnalysis()
                unimplemented("write to __closure__ while inlining")

    def LOAD_DEREF(self, inst):
        if inst.argval in self.closure_cells:
            cell = self.closure_cells[inst.argval]
            if isinstance(cell, ClosureVariable):
                self.push(self.output.root_tx.symbolic_locals[cell.name])
            else:
                self.push(self.output.side_effects.load_cell(cell))
        else:
            maybe_sym_local = self.symbolic_locals.get(inst.argval, None)
            if isinstance(maybe_sym_local, variables.NewCellVariable):
                self.push(self.output.side_effects.load_cell(maybe_sym_local))
            else:
                super().LOAD_DEREF(inst)

    def LOAD_CLOSURE(self, inst):
        assert inst.argval in self.cell_and_freevars()
        self.push(self.closure_cells[inst.argval])

    def replace_all(self, oldvar: VariableTracker, newvar: VariableTracker):
        newvar = super().replace_all(oldvar, newvar)
        # recursively check and update parent's locals and stack in case oldvar is from parent
        translator: InstructionTranslatorBase = self
        while hasattr(translator, "parent"):
            translator = translator.parent  # type: ignore[attr-defined]
            translator.update_locals_and_stack(oldvar, newvar)
        return newvar

    def should_compile_partial_graph(self):
        return False  # inlining functions is all-or-nothing

    def create_call_resume_at(self, offset):
        unimplemented("can't resume while inlining")

    def RETURN_VALUE(self, inst):
        self.symbolic_result = self.pop()
        self.instruction_pointer = None


class InliningGeneratorInstructionTranslator(InliningInstructionTranslator):
    generated_items: List[VariableTracker]

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.generated_items = []

    def YIELD_VALUE(self, inst: Instruction):
        self.generated_items.append(self.pop())
        # TODO(jansel): figure out why this is needed, it isn't in the docs for YIELD_VALUE
        self.push(ConstantVariable(None))