import collections
import copy
import functools
import itertools
import logging
import operator
import re
import traceback
from dataclasses import dataclass
from typing import Any, Dict, List, NamedTuple, Optional, OrderedDict, Set, Union

import torch.nn
from torch import fx
from torch._guards import (
    Checkpointable,
    Guard,
    GuardsCheckpointState,
    tracing,
    TracingContext,
)
from torch.fx.experimental.symbolic_shapes import ShapeEnv

from . import config, logging as torchdynamo_logging, variables
from .backends.registry import CompiledFn, CompilerFn
from .bytecode_transformation import create_instruction, Instruction, unique_id
from .codegen import PyCodegen
from .exc import BackendCompilerFailed, unimplemented
from .guards import GuardBuilder
from .mutation_guard import is_dynamic_nn_module
from .side_effects import SideEffects
from .source import (
    ConstantSource,
    is_constant_source,
    LocalInputSource,
    LocalSource,
    ShapeEnvSource,
)
from .utils import (
    assert_no_fake_params_or_buffers,
    checkpoint_params,
    CleanupHook,
    clone_inputs,
    count_calls,
    counters,
    dynamo_timed,
    format_graph_tabular,
    same,
)
from .variables.base import VariableTracker
from .variables.builder import GraphArg, TrackedFake, VariableBuilder, wrap_fx_proxy
from .variables.nn_module import NNModuleVariable
from .variables.tensor import (
    SymNodeVariable,
    TensorVariable,
    UnspecializedPythonVariable,
)

log = logging.getLogger(__name__)

class OutputGraphState(NamedTuple):
    graphargs: List[GraphArg]
    tracked_fakes: List[TrackedFake]
    guard_state: GuardsCheckpointState
    nn_modules: Optional[Dict[str, torch.nn.Module]]
    side_effects: SideEffects
    timestamp: int

    def diff(self, other: "OutputGraphState", *, prefix: str = "") -> Optional[str]:
        for k in self._fields:
            if k == "guard_state":
                r = self.guard_state.diff(other.guard_state)
                if r is not None:
                    return r
                continue
            elif k == "side_effects":
                r = self.side_effects.diff(other.side_effects)
                if r is not None:
                    return r
                continue
            sv = getattr(self, k)
            ov = getattr(other, k)
            if sv != ov:
                return f"{prefix}{k} mismatch: {sv} != {ov}"
        return None

    # Back compat .guards api
    @property
    def guards(self):
        return self.guard_state.dynamo_guards

@functools.lru_cache(None)
def _step_logger():
    return torchdynamo_logging.get_step_logger(log)


@dataclass
class GraphCompileReason:
    """Stores why a given output graph was compiled; i.e. what caused the graph break."""

    reason: str
    user_stack: List[traceback.FrameSummary]


def _get_gen_rand_values_fn(random_calls):
    def _gen_rand_values():
        return [fn(*args, **kwargs) for fn, args, kwargs in random_calls]

    return _gen_rand_values

class FakeRootModule(torch.nn.Module):
    """Trick the constructor of fx.GraphModule"""

    def __init__(self, nn_modules: Dict[str, torch.nn.Module]):
        super().__init__()
        for k, v in nn_modules.items():
            setattr(self, k, v)

    def __repr__(self):
        return "FakeRootModule(...)"

class WrapperBackend:
    def __init__(self, backend: CompilerFn, original_example_inputs):
        self.backend: CompilerFn = backend
        self.original_example_inputs = original_example_inputs

    @property
    def example_inputs(self):
        return clone_inputs(self.original_example_inputs)

    def __call__(self, gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
        self.restore = checkpoint_params(gm)
        self.gm = gm
        copy_gm = copy.deepcopy(self.gm)
        self.candidate = self.backend(copy_gm, self.original_example_inputs)

        if self.candidate is None or self.candidate is self.gm.forward:
            return self.gm.forward

        if not config.verify_correctness:
            return self.candidate

        # if verify_correctness=True
        try:
            correct = self.gm.forward(*self.example_inputs)
            result = self.candidate(*self.example_inputs)

            # TODO: replace `same` function with the one in testing
            if same(correct, result):
                return self.candidate

            raise RuntimeError(f"incorrect results of backend {self}")
        except Exception:
            log.exception("error in verify_correctness")
            raise
        finally:
            self.restore()

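# Hedged usage sketch (`my_compiler` is a hypothetical user backend, not
# defined in this file). With config.verify_correctness set, Dynamo wraps the
# backend so each compiled candidate is checked against eager results:
#
#     backend = WrapperBackend(my_compiler, original_example_inputs=inputs)
#     compiled = backend(gm, inputs)  # raises if outputs diverge from gm.forward
#
# Inputs are cloned on every access via the example_inputs property, so the
# correctness check cannot mutate the caller's tensors.
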
class OutputGraph(fx.Tracer, Checkpointable[OutputGraphState]):
    """
    Wrapper class to hold outputs of InstructionTranslator.  Mainly the
    generated fx.Graph.
    """

    def __init__(
        self,
        f_globals: Dict[str, Any],
        code_options: Dict[str, Any],
        compiler_fn: CompilerFn,
        root_tx,
    ):
        super().__init__()
        self.graph = torch.fx.Graph()
        self.graphargs: List[GraphArg] = []
        fake_mode = torch._subclasses.FakeTensorMode(
            shape_env=ShapeEnv() if config.dynamic_shapes else None,
        )
        self.tracing_context: TracingContext = TracingContext(fake_mode)
        if config.dynamic_shapes:
            # Register a SHAPE_ENV guard to make sure we set up shape guards
            # that show up in ShapeEnv
            self.guards.add(ShapeEnvSource().make_guard(GuardBuilder.SHAPE_ENV))

        # tracked_fakes says where any tensor that was wrapped to fake came
        # from.  It is similar to GraphArg, in that all GraphArgs will get
        # added to TrackedFakes, but TrackedFakes also contains GraphArgs
        # that got pruned, and things like Tensor attributes which aren't
        # explicit graph inputs.  Used by shape guards.
        self.tracked_fakes: List[TrackedFake] = []
        # Although we prune unused graphargs before sending graphs to
        # compilers, we may have legitimately triggered shape guards
        # on "unused" inputs that we must keep track of.  So after
        # remove_unused_graphargs is called, orig_graphargs and
        # graphargs no longer alias; orig_graphargs is the original
        # graphargs, and graphargs is the pruned list.  Guard creation
        # should use orig_graphargs.
        self.orig_graphargs: List[GraphArg] = self.graphargs
        self.nn_modules: Optional[Dict[str, torch.nn.Module]] = dict()
        self.side_effects = SideEffects()
        self.code_options = dict(code_options)
        self.output_instructions: List[Instruction] = []
        # used to track nodes that are added between calls of copy_graphstate
        # and restore_graphstate
        self.timestamp = 0
        # Node => computed real value (see utils.get_real_value)
        self.real_value_cache: Dict[fx.Node, torch.Tensor] = {}

        # Not checkpointed
        self.compiler_fn: CompilerFn = compiler_fn
        self.root_globals = f_globals
        self.root_tx = root_tx
        from torch._dynamo.symbolic_convert import InstructionTranslatorBase

        self._current_tx: List[InstructionTranslatorBase] = []
        self.cleanups: List[CleanupHook] = []
        self.should_exit = False
        self.random_values_var = None
        self.initial_random_state = ()
        self.unspec_variable_map: Dict[str, UnspecializedPythonVariable] = {}
        # Maps the source arg position to the grapharg position
        self.pos_to_arg: Dict[int, int] = {}

        # Enables creating unique node names by tracking
        # all current placeholder node names
        self.name_to_input: OrderedDict[
            str, Optional[fx.Proxy]
        ] = collections.OrderedDict()

    @property
    def output(self):
        return self

    @property
    def fake_mode(self):
        return self.root_tx.fake_mode

    @property
    def shape_env(self):
        return self.tracing_context.fake_mode.shape_env

    @property
    def guards(self) -> Set[Guard]:
        return self.tracing_context.guards_context.dynamo_guards

    def push_tx(self, tx):
        self._current_tx.append(tx)

    def pop_tx(self):
        return self._current_tx.pop()

    @property
    def current_tx(self):
        return self.root_tx if not self._current_tx else self._current_tx[-1]

    def copy_graphstate(self) -> OutputGraphState:
        """Create a checkpoint of the current state by copying everything"""
        assert self.nn_modules is not None
        guards_graph_state = self.tracing_context.guards_context.copy_graphstate()
        state = OutputGraphState(
            list(self.graphargs),
            list(self.tracked_fakes),
            guards_graph_state,
            dict(self.nn_modules),
            self.side_effects.clone(),
            self.timestamp,
        )
        self.timestamp += 1
        return state

    def restore_graphstate(self, state: OutputGraphState):
        """Restore a checkpoint created by self.copy_graphstate()"""
        (
            self.graphargs,
            self.tracked_fakes,
            guards_state,
            self.nn_modules,
            self.side_effects,
            self.timestamp,
        ) = state
        self.tracing_context.guards_context.restore_graphstate(guards_state)
        # FX deepcopy doesn't work for a partially created graph, so just remove new nodes
        removed_nodes = 0
        for node in reversed(list(self.graph.nodes)):
            if node.meta["creation_timestamp"] > self.timestamp:
                # Erasing the node alone does not remove its meta information,
                # so drop the cached example tensor explicitly.
                if "example_value" in node.meta:
                    del node.meta["example_value"]
                self.remove_node(node)
                self.real_value_cache.pop(node, None)
                removed_nodes += 1
        log.debug(f"restore_graphstate: removed {removed_nodes} nodes")

    def add_grapharg(self, arg: GraphArg):
        curr_pos = len(self.graphargs)
        self.graphargs.append(arg)
        if isinstance(arg.source, LocalInputSource):
            self.pos_to_arg[arg.source.pos] = curr_pos

    def count_calls(self):
        return count_calls(self.graph)

    def get_submodule(self, keys):
        assert keys
        obj = self.nn_modules
        for k in keys.split("."):
            if isinstance(obj, dict):
                obj = obj[k]
            else:
                obj = getattr(obj, k)
        return obj

    def create_graph_input(self, name, type_expr=None):
        # unique
        if name in self.name_to_input:
            for i in itertools.count():
                if f"{name}_{i}" not in self.name_to_input:
                    name = f"{name}_{i}"
                    break

        if self.name_to_input:
            prev_name = next(reversed(self.name_to_input))
            ctx = self.graph.inserting_after(self.name_to_input[prev_name])
        else:
            ctx = self.graph.inserting_before(None)
        with ctx:
            proxy = self.create_proxy("placeholder", name, (), {}, type_expr=type_expr)
            self.name_to_input[name] = proxy.node
            return proxy

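    # Illustrative behavior of the uniquing loop above (assuming a fresh graph
    # `g`):
    #
    #     g.create_graph_input("x")  # placeholder named "x"
    #     g.create_graph_input("x")  # placeholder named "x_0"
    #     g.create_graph_input("x")  # placeholder named "x_1"
    #
    # Each new placeholder is inserted after the most recent one, keeping all
    # graph inputs contiguous at the front of the fx.Graph.
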
    def new_var(self, name="tmp"):
        existing = set(self.code_options["co_varnames"])
        for i in itertools.count():
            var = f"___{name}_{i}"
            if var not in existing:
                self.code_options["co_varnames"] = self.code_options["co_varnames"] + (
                    var,
                )
                return var

    def update_co_names(self, name):
        """Ensure self.code_options.co_names contains name"""
        if name not in self.code_options["co_names"]:
            self.code_options["co_names"] = tuple(self.code_options["co_names"]) + (
                name,
            )

    def register_attr_or_module(
        self,
        target: Union[torch.nn.Module, torch.Tensor, Any],
        *names,
        **options,
    ):
        if is_dynamic_nn_module(target):
            return variables.UnspecializedNNModuleVariable(target, **options)

        options = dict(options)
        options["guards"] = set(options.get("guards", []))
        assert "source" in options
        source = options["source"]
        if isinstance(target, torch.Tensor):
            if not is_constant_source(source):
                options["guards"].add(source.make_guard(GuardBuilder.TENSOR_MATCH))

            def wrap_name(module_key):
                return wrap_fx_proxy(
                    self.root_tx,
                    self.create_proxy("get_attr", module_key, tuple(), {}),
                    example_value=target,
                    **options,
                )

        elif isinstance(target, torch.nn.Module):
            assert isinstance(target, torch.nn.Module)
            options["guards"].add(source.make_guard(GuardBuilder.NN_MODULE))

            def wrap_name(module_key):
                return NNModuleVariable(type(target), module_key, **options)

        elif isinstance(target, (torch.SymInt, torch.SymFloat)):
            # HACKY CODE REGION BEGIN
            # WE ARE PIGGYBACKING ON EXISTING INFRA TO REGISTER ATTRS
            # This ultimately gets written to self.nn_modules, which is unfortunate
            # Attrs that are tensors and symints and such need to be migrated to
            # have their own storage
            # alas, this is like this for now

            def wrap_name(module_key):
                return SymNodeVariable.create(
                    self,
                    self.create_proxy("get_attr", module_key, tuple(), {}),
                    sym_num=target,
                    **options,
                )

            # HACKY CODE REGION END
        else:

            def wrap_name(module_key):
                self.output.update_co_names(module_key)
                self.root_globals[module_key] = target
                return VariableBuilder(self, ConstantSource(source_name=module_key))(
                    target
                )

        assert self.nn_modules is not None
        for k, v in self.nn_modules.items():
            if v is target:
                # it already exists
                return wrap_name(k)

        # create a new unique name
        name = "_".join(map(str, names))
        # e.g. replace abc.xyz[123].qkv with abc.xyz_123.qkv
        name = re.sub(r"\[(\d+)\]", r"_\g<1>", name)
        # e.g. replace abc.xyz_123.qkv with abc_xyz_123_qkv
        name = re.sub(r"[^a-zA-Z0-9]", "_", name)

        if not name or not name[0].isalpha():
            name = "sub" + name
        base = name
        for i in itertools.count():
            if name not in self.nn_modules:
                self.nn_modules[name] = target
                return wrap_name(name)
            name = f"{base}_{i}"

        raise AssertionError("unreachable")

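    # Illustrative trace of the name mangling above:
    #
    #     "abc.xyz[123].qkv"  -->  "abc.xyz_123.qkv"   (index rewrite)
    #     "abc.xyz_123.qkv"   -->  "abc_xyz_123_qkv"   (non-alphanumeric rewrite)
    #
    # and a name that is empty or starts with a non-letter gets a "sub" prefix,
    # so the result is always a valid attribute name on the fake root module.
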
    def compile_subgraph(
        self, tx, partial_convert=False, reason: Optional[GraphCompileReason] = None
    ):
        """
        Generate a subgraph to continue execution on user code.
        Automatically restore live variables.
        """
        from .eval_frame import disable

        self.partial_convert = partial_convert
        self.compile_subgraph_reason = reason

        log.debug(f"COMPILING GRAPH due to {reason}")

        if not all(block.can_restore() for block in tx.block_stack):
            unimplemented("compile_subgraph with block_depth != 0")

        for block in reversed(tx.block_stack):
            block.exit(tx)

        tx.prune_dead_locals()
        stack_values = list(tx.stack)
        assert self.nn_modules is not None
        root = FakeRootModule(self.nn_modules)

        # Add all the local vars to the "stack" so we can restore them at the end
        restore_vars = []
        val_to_names: OrderedDict[
            VariableTracker, List[str]
        ] = collections.OrderedDict()
        if stack_values:
            val_to_names[stack_values[-1]] = list()
        for k, v in tx.symbolic_locals.items():
            if isinstance(v.source, LocalSource) and v.source.name() == k:
                continue  # no need to restore initial state
            if v not in val_to_names:
                val_to_names[v] = list()
            val_to_names[v].append(k)
        for v in val_to_names.keys():
            restore_vars.extend(val_to_names[v])
            stack_values.extend([v] * len(val_to_names[v]))

        # to handle random calls
        if len(tx.random_calls) > 0:
            random_calls_instructions = []
            self.random_values_var = self.new_var("random_values")
            rand_fn_name = unique_id("__gen_rand_values")
            rand_fn = disable(_get_gen_rand_values_fn(tx.random_calls))
            self.install_global(rand_fn_name, rand_fn)
            codegen = PyCodegen(tx, root)
            random_calls_instructions.extend(
                [
                    codegen.create_load_global("random", add=True),
                    codegen.create_load_attr("setstate"),
                    codegen.create_load_const(tx.output.initial_random_state),
                    create_instruction("CALL_FUNCTION", 1),
                ]
            )
            random_calls_instructions.extend(codegen.load_function_name(rand_fn_name))
            random_calls_instructions.extend(
                [
                    create_instruction("CALL_FUNCTION", 0),
                    codegen.create_store(tx.output.random_values_var),
                ]
            )
            self.add_output_instructions(random_calls_instructions)

        if (
            stack_values
            and all(
                not isinstance(v, UnspecializedPythonVariable) for v in stack_values
            )
            and all(isinstance(x, TensorVariable) for x in stack_values)
            and len(set(stack_values)) == len(stack_values)
            and self.side_effects.is_empty()
        ):
            # optimization to generate better code in a common case
            self.add_output_instructions(
                self.compile_and_call_fx_graph(tx, list(reversed(stack_values)), root)
                + [create_instruction("UNPACK_SEQUENCE", len(stack_values))]
            )
        else:
            graph_output_var = self.new_var("graph_out")
            pass1 = PyCodegen(tx, root, graph_output_var)
            self.side_effects.codegen_save_tempvars(pass1)
            pass1.foreach(stack_values)
            self.side_effects.codegen_update_mutated(pass1)

            # one more time now that we have established tempvars
            pass2 = PyCodegen(
                tx,
                root,
                graph_output_var,
                tempvars={val: None for val, count in pass1.uses.items() if count > 1},
            )
            self.side_effects.codegen_save_tempvars(pass2)
            pass2.foreach(stack_values)
            self.side_effects.codegen_update_mutated(pass2)

            output = []
            if count_calls(self.graph) != 0 or len(pass2.graph_outputs) != 0:
                output.extend(
                    self.compile_and_call_fx_graph(tx, pass2.graph_output_vars(), root)
                )

                if len(pass2.graph_outputs) != 0:
                    output.append(pass2.create_store(graph_output_var))
                else:
                    output.append(create_instruction("POP_TOP"))
            self.add_output_instructions(output + pass2.get_instructions())

        # restore all the live local vars
        self.add_output_instructions(
            [PyCodegen(tx).create_store(var) for var in reversed(restore_vars)]
        )

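    # NOTE on the two-pass codegen above (explanatory, added for clarity):
    # pass1 runs codegen once purely to count how often each value is used
    # (pass1.uses); pass2 then re-runs codegen with every value used more than
    # once pinned as a tempvar, so it is materialized once and reloaded instead
    # of being recomputed.  A hedged mental model of the pinning step:
    #
    #     tempvars = {v: None for v, n in pass1.uses.items() if n > 1}
    #
    # which is exactly the dict passed to the second PyCodegen above.
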
    def compile_and_call_fx_graph(self, tx, rv, root):
        """
        Generate code from self.graph and return the Instruction()s to
        call that generated code.
        """
        from .eval_frame import disable

        assert isinstance(rv, list)
        assert isinstance(root, FakeRootModule)
        for output in rv:
            self.guards.update(output.guards)

        self.create_node(
            "output", "output", (self.create_arg(tuple(x.as_proxy() for x in rv)),), {}
        )
        self.remove_unused_graphargs()
        ncalls = count_calls(self.graph)
        counters["stats"]["calls_captured"] += ncalls
        counters["stats"]["fusions_possible"] += ncalls - 1

        # free a bit of memory
        for node in self.graph.nodes:
            if "example_value" in node.meta:
                del node.meta["example_value"]
        self.real_value_cache.clear()

        gm = fx.GraphModule(root, self.graph)
        gm.recompile()
        gm.compile_subgraph_reason = self.compile_subgraph_reason
        name = unique_id("__compiled_fn")

        assert_no_fake_params_or_buffers(gm)
        with tracing(self.tracing_context):
            compiled_fn = self.call_user_compiler(gm)
        compiled_fn = disable(compiled_fn)

        counters["stats"]["unique_graphs"] += 1
        self.install_global(name, compiled_fn)

        try:
            # the call to tabulate can cause a lot of memory to be allocated
            if config.log_level <= logging.INFO and config.output_code:
                graph_str = (
                    gm.print_readable()
                    if config.output_graph_code
                    else format_graph_tabular(gm.graph)
                )
                log.log(
                    logging.INFO,
                    f"TRACED GRAPH\n {name} {gm.forward.__code__.co_filename} {graph_str}\n",
                )
        except ImportError:
            log.warning(
                "Unable to print graph: `format_graph_tabular` relies on the library `tabulate`, "
                "which could not be found on this machine. Run `pip "
                "install tabulate` to install the library."
            )

        cg = PyCodegen(tx)
        cg.make_call_generated_code(name)
        return cg.get_instructions()

    @dynamo_timed(phase_name="backend_compile")
    def call_user_compiler(self, gm: fx.GraphModule) -> CompiledFn:
        tot = 0
        for node in gm.graph.nodes:
            if node.op in ("call_function", "call_method", "call_module"):
                tot += 1
        torch._dynamo.utils.increment_op_count(tot)
        try:
            name = (
                self.compiler_fn.__name__
                if hasattr(self.compiler_fn, "__name__")
                else ""
            )
            _step_logger()(logging.INFO, f"calling compiler function {name}")
            compiler_fn = self.compiler_fn
            # WrapperBackend needs real inputs, for now, to verify correctness
            if config.verify_correctness:
                compiler_fn = WrapperBackend(compiler_fn, self.example_inputs())

            # NOTE: [Real Tensors in Accuracy Evaluation]
            #
            # Today, tensors are passed to backends as fake at compile time.  See
            # the .fake_example_inputs() call to compiler_fn below.  At runtime,
            # backends use real tensors.
            #
            # This should be a strong invariant we hold across all backends, and
            # generally, it is.  However, for accuracy evaluation, we need real
            # tensors at compile time, for now, due to the unfortunate setup
            # described below.
            #
            # We invoke comparison as a backend in two different ways:
            #
            # (1) Less bad, but still worth rewriting: WrapperBackend above, which
            # takes real inputs for its ctor.  See config.verify_correctness above.
            #
            # (2) More bad, and very worth rewriting: the minifier installs accuracy
            # comparison as a true backend, and therefore needs to be compiled with
            # real inputs.  This is made trickier by the fact that the minifier will
            # spawn new processes during minification.  As such, we have created a
            # global flag, MINIFIER_SPAWNED, that should be set IF AND ONLY IF this
            # run was spawned as part of accuracy minification.  This flag is not a
            # contract, and ideally will not be here long.
            #
            # The longer-term plan of record is to:
            # (A) Rewrite the minifier accuracy evaluation and verify_correctness
            # code to share the same correctness and accuracy logic, so as not to
            # have two different ways of doing the same thing.
            #
            # (B) Refactor the minifier accuracy backend to do its comparison fully
            # at runtime, so as not to need to pass real tensors to it at compile
            # time.
            is_top_level_minifying = (
                config.repro_after is not None and config.repro_level == 4
            )
            if torch._dynamo.debug_utils.MINIFIER_SPAWNED or is_top_level_minifying:
                compiled_fn = compiler_fn(gm, self.example_inputs())
            elif config.DO_NOT_USE_legacy_non_fake_example_inputs:
                compiled_fn = compiler_fn(gm, self.example_inputs())
            else:
                compiled_fn = compiler_fn(gm, self.fake_example_inputs())
            _step_logger()(logging.INFO, f"done compiler function {name}")
            assert callable(compiled_fn), "compiler_fn did not return callable"
        except Exception as e:
            raise BackendCompilerFailed(self.compiler_fn, e) from e
        return compiled_fn

    def fake_example_inputs(self) -> List[torch.Tensor]:
        result = []
        for arg in self.graphargs:
            example = arg.get_fake_examples()
            if example is not None:
                result.extend(example)
            else:
                # Fallback, in case fake_tensor was not set
                # Particularly for graph args that are not tensors
                result.extend(arg.get_examples())
        return result

    def example_inputs(self) -> List[torch.Tensor]:
        result = []
        for arg in self.graphargs:
            result.extend(arg.get_examples())
        return result

    def remove_unused_graphargs(self) -> None:
        for node in reversed(list(self.graph.nodes)):
            if len(list(node.users)) == 0:
                if node.op == "get_attr":
                    self.remove_node(node)
                elif node.op == "call_function" and node.target is operator.getitem:
                    self.remove_node(node)

        expanded_graphargs = []
        for arg in self.graphargs:
            expanded_graphargs.extend([arg] * len(arg))
            arg.uses = 0

        for node, arg in zip(self.graph.nodes, expanded_graphargs):
            assert node.op == "placeholder"
            arg.uses += len(node.users)

        for node, arg in list(zip(self.graph.nodes, expanded_graphargs)):
            if arg.uses == 0:
                log.debug(f"REMOVE UNUSED GRAPHARG {arg.source.name()}")
                if "example_value" in node.meta:
                    del node.meta["example_value"]
                self.remove_node(node)
                self.real_value_cache.pop(node, None)

        self.graphargs = [arg for arg in self.graphargs if arg.uses > 0]

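    # Explanatory note (added): placeholders and graphargs are kept in lockstep,
    # with each GraphArg expanding to len(arg) placeholders.  A hedged sketch of
    # the invariant the zips above rely on, once unused pairs are dropped:
    #
    #     placeholders = [n for n in self.graph.nodes if n.op == "placeholder"]
    #     assert len(placeholders) == sum(len(arg) for arg in self.graphargs)
    #
    # orig_graphargs (set in __init__) still aliases the pre-pruning list, so
    # guard creation can see inputs that were removed here.
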
    def add_output_instructions(self, prefix: List[Instruction]) -> None:
        """
        We call this on the creation of a new compiled subgraph that is inserted
        before user code.
        """
        self.output_instructions.extend(prefix)
        self.should_exit = True

    def install_global(self, name, value) -> None:
        self.cleanups.append(CleanupHook.create(self.root_globals, name, value))

    def cleanup(self) -> None:
        # There is a reference cycle between tracer and OutputGraph, causing
        # some of the tensor objects to be held alive for longer than necessary.
        self.root_tx = None

        # Note: the generated fx graph will hold a reference to the nn_modules,
        # so depending on the backend they may not be released
        self.nn_modules = None

        # Cleanup graphargs
        for graph_arg in self.graphargs:
            graph_arg.erase()

        for node in self.graph.nodes:
            if "example_value" in node.meta:
                del node.meta["example_value"]
        self.real_value_cache.clear()
        self.name_to_input.clear()
        self.side_effects.keepalive = []

    def create_proxy(
        self,
        kind,
        target,
        args,
        kwargs,
        name=None,
        type_expr=None,
        proxy_factory_fn=None,
    ):
        rv = super().create_proxy(
            kind, target, args, kwargs, name, type_expr, proxy_factory_fn
        )

        # append stack trace to fx node
        tx = self.current_tx

        nn_module_stack = tx.nn_module_stack
        if nn_module_stack:
            rv.node.meta["nn_module_stack"] = nn_module_stack.copy()

        if kind in {"call_function", "call_method"}:
            rv.node.meta["source_fn"] = target

        frame_summaries: List[traceback.FrameSummary] = []
        while tx:
            frame_summaries.append(tx.frame_summary())
            tx = getattr(tx, "parent", None)
        # official from_list stub doesn't have new-style type
        msgs = traceback.StackSummary.from_list(frame_summaries).format()  # type: ignore[arg-type]

        rv.node.stack_trace = " | ".join(msgs)

        return rv

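    # Illustrative shape of the metadata attached above: for a frame `fn` that
    # inlined a call to `inner`, rv.node.stack_trace ends up as a single
    # " | "-joined string, roughly
    #
    #     '  File "user.py", line 10, in fn\n... |   File "user.py", line 3, in inner\n...'
    #
    # so downstream tooling can split on " | " to recover per-frame summaries.
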
    def create_node(self, *args, **kwargs):
        node = super().create_node(*args, **kwargs)
        node.meta["creation_timestamp"] = self.timestamp
        return node

    # Note: we did not override erase_node since
    # we call self.graph.erase_node elsewhere
    def remove_node(self, node):
        self.graph.erase_node(node)
        self.name_to_input.pop(node.name, None)