output_graph.py

import collections
import copy
import functools
import itertools
import logging
import operator
import re
import traceback
from dataclasses import dataclass
from typing import Any, Dict, List, NamedTuple, Optional, OrderedDict, Set, Union

import torch.nn
from torch import fx
from torch._guards import (
    Checkpointable,
    Guard,
    GuardsCheckpointState,
    tracing,
    TracingContext,
)
from torch.fx.experimental.symbolic_shapes import ShapeEnv

from . import config, logging as torchdynamo_logging, variables
from .backends.registry import CompiledFn, CompilerFn
from .bytecode_transformation import create_instruction, Instruction, unique_id
from .codegen import PyCodegen
from .exc import BackendCompilerFailed, unimplemented
from .guards import GuardBuilder
from .mutation_guard import is_dynamic_nn_module
from .side_effects import SideEffects
from .source import (
    ConstantSource,
    is_constant_source,
    LocalInputSource,
    LocalSource,
    ShapeEnvSource,
)
from .utils import (
    assert_no_fake_params_or_buffers,
    checkpoint_params,
    CleanupHook,
    clone_inputs,
    count_calls,
    counters,
    dynamo_timed,
    format_graph_tabular,
    same,
)
from .variables.base import VariableTracker
from .variables.builder import GraphArg, TrackedFake, VariableBuilder, wrap_fx_proxy
from .variables.nn_module import NNModuleVariable
from .variables.tensor import (
    SymNodeVariable,
    TensorVariable,
    UnspecializedPythonVariable,
)
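
# This module defines OutputGraph, the fx.Tracer that accumulates the fx.Graph
# Dynamo builds while symbolically executing user bytecode, together with the
# machinery for compiling that graph via the configured backend and emitting
# the bytecode that calls the compiled function and restores live variables.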

log = logging.getLogger(__name__)


class OutputGraphState(NamedTuple):
    graphargs: List[GraphArg]
    tracked_fakes: List[TrackedFake]
    guard_state: GuardsCheckpointState
    nn_modules: Optional[Dict[str, torch.nn.Module]]
    side_effects: SideEffects
    timestamp: int

    def diff(self, other: "OutputGraphState", *, prefix: str = "") -> Optional[str]:
        for k in self._fields:
            if k == "guard_state":
                r = self.guard_state.diff(other.guard_state)
                if r is not None:
                    return r
                continue
            elif k == "side_effects":
                r = self.side_effects.diff(other.side_effects)
                if r is not None:
                    return r
                continue

            sv = getattr(self, k)
            ov = getattr(other, k)
            if sv != ov:
                return f"{prefix}{k} mismatch: {sv} != {ov}"
        return None

    # Back compat .guards api
    @property
    def guards(self):
        return self.guard_state.dynamo_guards


@functools.lru_cache(None)
def _step_logger():
    return torchdynamo_logging.get_step_logger(log)


@dataclass
class GraphCompileReason:
    """Stores why a given output graph was compiled; i.e. what caused the graph break."""

    reason: str
    user_stack: List[traceback.FrameSummary]


def _get_gen_rand_values_fn(random_calls):
    def _gen_rand_values():
        return [fn(*args, **kwargs) for fn, args, kwargs in random_calls]

    return _gen_rand_values


class FakeRootModule(torch.nn.Module):
    """Trick the constructor of fx.GraphModule"""

    def __init__(self, nn_modules: Dict[str, torch.nn.Module]):
        super().__init__()
        for k, v in nn_modules.items():
            setattr(self, k, v)

    def __repr__(self):
        return "FakeRootModule(...)"
class WrapperBackend:
    def __init__(self, backend: CompilerFn, original_example_inputs):
        self.backend: CompilerFn = backend
        self.original_example_inputs = original_example_inputs

    @property
    def example_inputs(self):
        return clone_inputs(self.original_example_inputs)

    def __call__(self, gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
        self.restore = checkpoint_params(gm)
        self.gm = gm
        copy_gm = copy.deepcopy(self.gm)
        self.candidate = self.backend(copy_gm, self.original_example_inputs)

        if self.candidate is None or self.candidate is self.gm.forward:
            return self.gm.forward

        if not config.verify_correctness:
            return self.candidate

        # if verify_correctness=True
        try:
            correct = self.gm.forward(*self.example_inputs)
            result = self.candidate(*self.example_inputs)

            # TODO: replace `same` function with the one in testing
            if same(correct, result):
                return self.candidate

            raise RuntimeError(f"incorrect results of backend {self}")
            return self.gm.forward
        except Exception:
            log.exception("error in verify_correctness")
            raise
        finally:
            self.restore()


class OutputGraph(fx.Tracer, Checkpointable[OutputGraphState]):
    """
    Wrapper class to hold outputs of InstructionTranslator. Mainly the
    generated fx.Graph.
    """

    def __init__(
        self,
        f_globals: Dict[str, Any],
        code_options: Dict[str, Any],
        compiler_fn: CompilerFn,
        root_tx,
    ):
        super().__init__()
        self.graph = torch.fx.Graph()
        self.graphargs: List[GraphArg] = []
        fake_mode = torch._subclasses.FakeTensorMode(
            shape_env=ShapeEnv() if config.dynamic_shapes else None,
        )
        self.tracing_context: TracingContext = TracingContext(fake_mode)
        if config.dynamic_shapes:
            # Register a SHAPE_ENV guard to make sure we set up shape guards
            # that show up in ShapeEnv
            self.guards.add(ShapeEnvSource().make_guard(GuardBuilder.SHAPE_ENV))
        # tracked_fakes says where any tensor that was wrapped to fake came
        # from.  It is similar to GraphArg, in that all GraphArgs will get
        # added to TrackedFakes, but TrackedFakes also contains GraphArgs
        # that got pruned, and things like Tensor attributes which aren't
        # explicit graph inputs.  Used by shape guard creation.
        self.tracked_fakes: List[TrackedFake] = []
        # Although we prune unused graphargs before sending graphs to
        # compilers, we may have legitimately triggered shape guards
        # on "unused" inputs that we must keep track of.  So after
        # remove_unused_graphargs is called, orig_graphargs and
        # graphargs no longer alias; orig_graphargs is the original
        # graphargs, and graphargs is the pruned list.  Guard creation
        # should use original graphargs.
        self.orig_graphargs: List[GraphArg] = self.graphargs
        self.nn_modules: Optional[Dict[str, torch.nn.Module]] = dict()
        self.side_effects = SideEffects()
        self.code_options = dict(code_options)
        self.output_instructions: List[Instruction] = []
        # used to track nodes that are added between calls of copy_graphstate
        # and restore_graphstate
        self.timestamp = 0
        # Node => computed real value (see utils.get_real_value)
        self.real_value_cache: Dict[fx.Node, torch.Tensor] = {}

        # Not checkpointed
        self.compiler_fn: CompilerFn = compiler_fn
        self.root_globals = f_globals
        self.root_tx = root_tx
        from torch._dynamo.symbolic_convert import InstructionTranslatorBase

        self._current_tx: List[InstructionTranslatorBase] = []
        self.cleanups: List[CleanupHook] = []
        self.should_exit = False
        self.random_values_var = None
        self.initial_random_state = ()
        self.unspec_variable_map: Dict[str, UnspecializedPythonVariable] = {}
        # Maps the source arg position to the grapharg position
        self.pos_to_arg: Dict[int, int] = {}
        # Enables creating unique node names by tracking
        # all current placeholder node names
        self.name_to_input: OrderedDict[
            str, Optional[fx.Proxy]
        ] = collections.OrderedDict()

    @property
    def output(self):
        return self

    @property
    def fake_mode(self):
        return self.root_tx.fake_mode

    @property
    def shape_env(self):
        return self.tracing_context.fake_mode.shape_env

    @property
    def guards(self) -> Set[Guard]:
        return self.tracing_context.guards_context.dynamo_guards

    def push_tx(self, tx):
        self._current_tx.append(tx)

    def pop_tx(self):
        return self._current_tx.pop()

    @property
    def current_tx(self):
        return self.root_tx if not self._current_tx else self._current_tx[-1]

    def copy_graphstate(self) -> OutputGraphState:
        """Create a checkpoint of the current state by copying everything"""
        assert self.nn_modules is not None
        guards_graph_state = self.tracing_context.guards_context.copy_graphstate()
        state = OutputGraphState(
            list(self.graphargs),
            list(self.tracked_fakes),
            guards_graph_state,
            dict(self.nn_modules),
            self.side_effects.clone(),
            self.timestamp,
        )
        self.timestamp += 1
        return state

    def restore_graphstate(self, state: OutputGraphState):
        """Restore a checkpoint created by self.copy_graphstate()"""
        (
            self.graphargs,
            self.tracked_fakes,
            guards_state,
            self.nn_modules,
            self.side_effects,
            self.timestamp,
        ) = state
        self.tracing_context.guards_context.restore_graphstate(guards_state)
        # FX deepcopy doesn't work for a partially created graph, so just remove new nodes
        removed_nodes = 0
        for node in reversed(list(self.graph.nodes)):
            if node.meta["creation_timestamp"] > self.timestamp:
                # Erasing the node alone does not remove its meta information,
                # so remove the stored example tensor explicitly
                if "example_value" in node.meta:
                    del node.meta["example_value"]
                self.remove_node(node)
                self.real_value_cache.pop(node, None)
                removed_nodes += 1
        log.debug(f"restore_graphstate: removed {removed_nodes} nodes")

    def add_grapharg(self, arg: GraphArg):
        curr_pos = len(self.graphargs)
        self.graphargs.append(arg)
        if isinstance(arg.source, LocalInputSource):
            self.pos_to_arg[arg.source.pos] = curr_pos

    def count_calls(self):
        return count_calls(self.graph)

    def get_submodule(self, keys):
        assert keys
        obj = self.nn_modules
        for k in keys.split("."):
            if isinstance(obj, dict):
                obj = obj[k]
            else:
                obj = getattr(obj, k)
        return obj

    def create_graph_input(self, name, type_expr=None):
        # ensure the placeholder name is unique
        if name in self.name_to_input:
            for i in itertools.count():
                if f"{name}_{i}" not in self.name_to_input:
                    name = f"{name}_{i}"
                    break

        if self.name_to_input:
            prev_name = next(reversed(self.name_to_input))
            ctx = self.graph.inserting_after(self.name_to_input[prev_name])
        else:
            ctx = self.graph.inserting_before(None)
        with ctx:
            proxy = self.create_proxy("placeholder", name, (), {}, type_expr=type_expr)
            self.name_to_input[name] = proxy.node
            return proxy

    def new_var(self, name="tmp"):
        existing = set(self.code_options["co_varnames"])
        for i in itertools.count():
            var = f"___{name}_{i}"
            if var not in existing:
                self.code_options["co_varnames"] = self.code_options["co_varnames"] + (
                    var,
                )
                return var

    def update_co_names(self, name):
        """Ensure self.code_options.co_names contains name"""
        if name not in self.code_options["co_names"]:
            self.code_options["co_names"] = tuple(self.code_options["co_names"]) + (
                name,
            )

    def register_attr_or_module(
        self,
        target: Union[torch.nn.Module, torch.Tensor, Any],
        *names,
        **options,
    ):
        if is_dynamic_nn_module(target):
            return variables.UnspecializedNNModuleVariable(target, **options)

        options = dict(options)
        options["guards"] = set(options.get("guards", []))
        assert "source" in options
        source = options["source"]
        if isinstance(target, torch.Tensor):
            if not is_constant_source(source):
                options["guards"].add(source.make_guard(GuardBuilder.TENSOR_MATCH))

            def wrap_name(module_key):
                return wrap_fx_proxy(
                    self.root_tx,
                    self.create_proxy("get_attr", module_key, tuple(), {}),
                    example_value=target,
                    **options,
                )

        elif isinstance(target, torch.nn.Module):
            assert isinstance(target, torch.nn.Module)
            options["guards"].add(source.make_guard(GuardBuilder.NN_MODULE))

            def wrap_name(module_key):
                return NNModuleVariable(type(target), module_key, **options)

        elif isinstance(target, (torch.SymInt, torch.SymFloat)):
            # HACKY CODE REGION BEGIN
            # WE ARE PIGGYBACKING ON EXISTING INFRA TO REGISTER ATTRS
            # This ultimately gets written to self.nn_modules, which is unfortunate
            # Attrs that are tensors and symints and such need to be migrated to
            # have their own storage
            # alas, this is like this for now

            def wrap_name(module_key):
                return SymNodeVariable.create(
                    self,
                    self.create_proxy("get_attr", module_key, tuple(), {}),
                    sym_num=target,
                    **options,
                )

            # HACKY CODE REGION END
        else:

            def wrap_name(module_key):
                self.output.update_co_names(module_key)
                self.root_globals[module_key] = target
                return VariableBuilder(self, ConstantSource(source_name=module_key))(
                    target
                )

        assert self.nn_modules is not None
        for k, v in self.nn_modules.items():
            if v is target:
                # it already exists
                return wrap_name(k)

        # create a new unique name
        name = "_".join(map(str, names))
        # e.g. replace abc.xyz[123].qkv with abc.xyz_123.qkv
        name = re.sub(r"\[(\d+)\]", r"_\g<1>", name)
        # e.g. replace abc.xyz_123.qkv with abc_xyz_123_qkv
        name = re.sub(r"[^a-zA-Z0-9]", "_", name)

        if not name or not name[0].isalpha():
            name = "sub" + name
        base = name
        for i in itertools.count():
            if name not in self.nn_modules:
                self.nn_modules[name] = target
                return wrap_name(name)
            name = f"{base}_{i}"

        raise AssertionError("unreachable")

    def compile_subgraph(
        self, tx, partial_convert=False, reason: Optional[GraphCompileReason] = None
    ):
        """
        Generate a subgraph to continue execution on user code.
        Automatically restore live variables.
        """
        from .eval_frame import disable

        self.partial_convert = partial_convert
        self.compile_subgraph_reason = reason

        log.debug(f"COMPILING GRAPH due to {reason}")

        if not all(block.can_restore() for block in tx.block_stack):
            unimplemented("compile_subgraph with block_depth != 0")

        for block in reversed(tx.block_stack):
            block.exit(tx)

        tx.prune_dead_locals()
        stack_values = list(tx.stack)
        assert self.nn_modules is not None
        root = FakeRootModule(self.nn_modules)

        # Add all the local vars to the "stack" so we restore them at the end
        restore_vars = []
        val_to_names: OrderedDict[
            VariableTracker, List[str]
        ] = collections.OrderedDict()
        if stack_values:
            val_to_names[stack_values[-1]] = list()
        for k, v in tx.symbolic_locals.items():
            if isinstance(v.source, LocalSource) and v.source.name() == k:
                continue  # no need to restore initial state
            if v not in val_to_names:
                val_to_names[v] = list()
            val_to_names[v].append(k)
        for v in val_to_names.keys():
            restore_vars.extend(val_to_names[v])
            stack_values.extend([v] * len(val_to_names[v]))

        # to handle random calls
        if len(tx.random_calls) > 0:
            random_calls_instructions = []
            self.random_values_var = self.new_var("random_values")
            rand_fn_name = unique_id("__gen_rand_values")
            rand_fn = disable(_get_gen_rand_values_fn(tx.random_calls))
            self.install_global(rand_fn_name, rand_fn)
            codegen = PyCodegen(tx, root)
            random_calls_instructions.extend(
                [
                    codegen.create_load_global("random", add=True),
                    codegen.create_load_attr("setstate"),
                    codegen.create_load_const(tx.output.initial_random_state),
                    create_instruction("CALL_FUNCTION", 1),
                ]
            )
            random_calls_instructions.extend(codegen.load_function_name(rand_fn_name))
            random_calls_instructions.extend(
                [
                    create_instruction("CALL_FUNCTION", 0),
                    codegen.create_store(tx.output.random_values_var),
                ]
            )
            self.add_output_instructions(random_calls_instructions)

        if (
            stack_values
            and all(
                not isinstance(v, UnspecializedPythonVariable) for v in stack_values
            )
            and all(isinstance(x, TensorVariable) for x in stack_values)
            and len(set(stack_values)) == len(stack_values)
            and self.side_effects.is_empty()
        ):
            # optimization to generate better code in a common case
            self.add_output_instructions(
                self.compile_and_call_fx_graph(tx, list(reversed(stack_values)), root)
                + [create_instruction("UNPACK_SEQUENCE", len(stack_values))]
            )
        else:
            graph_output_var = self.new_var("graph_out")
            pass1 = PyCodegen(tx, root, graph_output_var)
            self.side_effects.codegen_save_tempvars(pass1)
            pass1.foreach(stack_values)
            self.side_effects.codegen_update_mutated(pass1)

            # one more time now that we have established tempvars
            pass2 = PyCodegen(
                tx,
                root,
                graph_output_var,
                tempvars={val: None for val, count in pass1.uses.items() if count > 1},
            )
            self.side_effects.codegen_save_tempvars(pass2)
            pass2.foreach(stack_values)
            self.side_effects.codegen_update_mutated(pass2)

            output = []
            if count_calls(self.graph) != 0 or len(pass2.graph_outputs) != 0:
                output.extend(
                    self.compile_and_call_fx_graph(tx, pass2.graph_output_vars(), root)
                )

                if len(pass2.graph_outputs) != 0:
                    output.append(pass2.create_store(graph_output_var))
                else:
                    output.append(create_instruction("POP_TOP"))
            self.add_output_instructions(output + pass2.get_instructions())

        # restore all the live local vars
        self.add_output_instructions(
            [PyCodegen(tx).create_store(var) for var in reversed(restore_vars)]
        )

    def compile_and_call_fx_graph(self, tx, rv, root):
        """
        Generate code from self.graph and return the Instruction()s to
        call that generated code.
        """
        from .eval_frame import disable

        assert isinstance(rv, list)
        assert isinstance(root, FakeRootModule)
        for output in rv:
            self.guards.update(output.guards)

        self.create_node(
            "output", "output", (self.create_arg(tuple(x.as_proxy() for x in rv)),), {}
        )
        self.remove_unused_graphargs()
        ncalls = count_calls(self.graph)
        counters["stats"]["calls_captured"] += ncalls
        counters["stats"]["fusions_possible"] += ncalls - 1

        # free a bit of memory
        for node in self.graph.nodes:
            if "example_value" in node.meta:
                del node.meta["example_value"]
        self.real_value_cache.clear()

        gm = fx.GraphModule(root, self.graph)
        gm.recompile()
        gm.compile_subgraph_reason = self.compile_subgraph_reason
        name = unique_id("__compiled_fn")

        assert_no_fake_params_or_buffers(gm)
        with tracing(self.tracing_context):
            compiled_fn = self.call_user_compiler(gm)
        compiled_fn = disable(compiled_fn)

        counters["stats"]["unique_graphs"] += 1
        self.install_global(name, compiled_fn)

        try:
            # the call to tabulate can cause a lot of memory to be allocated
            if config.log_level <= logging.INFO and config.output_code:
                graph_str = (
                    gm.print_readable()
                    if config.output_graph_code
                    else format_graph_tabular(gm.graph)
                )
                log.log(
                    logging.INFO,
                    f"TRACED GRAPH\n {name} {gm.forward.__code__.co_filename} {graph_str}\n",
                )
        except ImportError:
            log.warning(
                "Unable to print graph: `format_graph_tabular` relies on the library `tabulate`, "
                "which could not be found on this machine. Run `pip "
                "install tabulate` to install the library."
            )

        cg = PyCodegen(tx)
        cg.make_call_generated_code(name)
        return cg.get_instructions()
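
    # call_user_compiler hands the traced fx.GraphModule to self.compiler_fn.
    # A backend is simply a callable taking (gm, example_inputs) and returning
    # a callable; a minimal sketch (illustrative only, not part of this module):
    #
    #     def my_backend(gm: torch.fx.GraphModule, example_inputs):
    #         gm.graph.print_tabular()  # inspect the captured graph
    #         return gm.forward         # fall back to eager execution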
    @dynamo_timed(phase_name="backend_compile")
    def call_user_compiler(self, gm: fx.GraphModule) -> CompiledFn:
        tot = 0
        for node in gm.graph.nodes:
            if node.op in ("call_function", "call_method", "call_module"):
                tot += 1
        torch._dynamo.utils.increment_op_count(tot)
        try:
            name = (
                self.compiler_fn.__name__
                if hasattr(self.compiler_fn, "__name__")
                else ""
            )
            _step_logger()(logging.INFO, f"calling compiler function {name}")
            compiler_fn = self.compiler_fn
            # WrapperBackend needs real inputs, for now, to verify correctness
            if config.verify_correctness:
                compiler_fn = WrapperBackend(compiler_fn, self.example_inputs())

            # NOTE: [Real Tensors in Accuracy Evaluation]
            #
            # Today, tensors are passed to backends as fake at compile time. See the .fake_example_inputs()
            # call to compiler_fn below. At runtime, backends use real tensors.
            #
            # This should be a strong invariant we hold across all backends,
            # and generally, it is. However, for accuracy evaluation, we need real tensors at compile time,
            # for now, due to the unfortunate setup described below.
            #
            # We invoke comparison as a backend in two different ways:
            #
            # (1) Less bad, but still worth rewriting: WrapperBackend above, which takes
            # real inputs for its ctor. See config.verify_correctness above.
            #
            # (2) More bad, and very worth rewriting: the minifier installs accuracy comparison as
            # a true backend, and therefore needs to be compiled with real inputs. This is made trickier
            # by the fact that the minifier will spawn new processes during minification. As such, we have
            # created a global flag, MINIFIER_SPAWNED, that should be set IF AND ONLY IF this run was spawned
            # as part of accuracy minification. This flag is not a contract, and ideally will not be here long.
            #
            # The longer term PoR is to:
            # (A) Rewrite the minifier accuracy evaluation and verify_correctness code to share the same
            # correctness and accuracy logic, so as not to have two different ways of doing the same thing.
            #
            # (B) Refactor the minifier accuracy backend to do its comparison fully at runtime, so as not
            # to need to pass real tensors to it at compile time.
            is_top_level_minifying = (
                config.repro_after is not None and config.repro_level == 4
            )
            if torch._dynamo.debug_utils.MINIFIER_SPAWNED or is_top_level_minifying:
                compiled_fn = compiler_fn(gm, self.example_inputs())
            elif config.DO_NOT_USE_legacy_non_fake_example_inputs:
                compiled_fn = compiler_fn(gm, self.example_inputs())
            else:
                compiled_fn = compiler_fn(gm, self.fake_example_inputs())
            _step_logger()(logging.INFO, f"done compiler function {name}")
            assert callable(compiled_fn), "compiler_fn did not return callable"
        except Exception as e:
            compiled_fn = gm.forward
            raise BackendCompilerFailed(self.compiler_fn, e) from e
        return compiled_fn

    def fake_example_inputs(self) -> List[torch.Tensor]:
        result = []
        for arg in self.graphargs:
            example = arg.get_fake_examples()
            if example is not None:
                result.extend(example)
            else:
                # Fallback, in case fake_tensor was not set
                # Particularly for graph args that are not tensors
                result.extend(arg.get_examples())
        return result

    def example_inputs(self) -> List[torch.Tensor]:
        result = []
        for arg in self.graphargs:
            result.extend(arg.get_examples())
        return result
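
    # remove_unused_graphargs prunes dead get_attr/getitem nodes, then counts
    # how many users each placeholder has and drops the placeholders (and
    # their GraphArgs) that ended up with zero uses.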
    def remove_unused_graphargs(self) -> None:
        for node in reversed(list(self.graph.nodes)):
            if len(list(node.users)) == 0:
                if node.op == "get_attr":
                    self.remove_node(node)
                elif node.op == "call_function" and node.target is operator.getitem:
                    self.remove_node(node)

        expanded_graphargs = []
        for arg in self.graphargs:
            expanded_graphargs.extend([arg] * len(arg))
            arg.uses = 0

        for node, arg in zip(self.graph.nodes, expanded_graphargs):
            assert node.op == "placeholder"
            arg.uses += len(node.users)

        for node, arg in list(zip(self.graph.nodes, expanded_graphargs)):
            if arg.uses == 0:
                log.debug(f"REMOVE UNUSED GRAPHARG {arg.source.name()}")
                if "example_value" in node.meta:
                    del node.meta["example_value"]
                self.remove_node(node)
                self.real_value_cache.pop(node, None)

        self.graphargs = [arg for arg in self.graphargs if arg.uses > 0]

    def add_output_instructions(self, prefix: List[Instruction]) -> None:
        """
        We call this on the creation of a new compiled subgraph that is inserted
        before user code.
        """
        self.output_instructions.extend(prefix)
        self.should_exit = True

    def install_global(self, name, value) -> None:
        self.cleanups.append(CleanupHook.create(self.root_globals, name, value))

    def cleanup(self) -> None:
        # There is a reference cycle between tracer and OutputGraph, causing
        # some of the tensor objects to be held alive for longer than necessary.
        self.root_tx = None

        # Note: the generated fx graph will hold a reference to the nn_modules,
        # so depending on the backend they may not be released
        self.nn_modules = None

        # Cleanup graphargs
        for graph_arg in self.graphargs:
            graph_arg.erase()

        for node in self.graph.nodes:
            if "example_value" in node.meta:
                del node.meta["example_value"]
        self.real_value_cache.clear()
        self.name_to_input.clear()
        self.side_effects.keepalive = []
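
    # create_proxy is overridden to attach debugging metadata to each new fx
    # node: the current nn_module_stack, the source function for call nodes,
    # and a formatted stack trace of the translator frames that produced it.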
    def create_proxy(
        self,
        kind,
        target,
        args,
        kwargs,
        name=None,
        type_expr=None,
        proxy_factory_fn=None,
    ):
        rv = super().create_proxy(
            kind, target, args, kwargs, name, type_expr, proxy_factory_fn
        )

        # append stack trace to fx node
        tx = self.current_tx

        nn_module_stack = tx.nn_module_stack
        if nn_module_stack:
            rv.node.meta["nn_module_stack"] = nn_module_stack.copy()

        if kind in {"call_function", "call_method"}:
            rv.node.meta["source_fn"] = target

        frame_summaries: List[traceback.FrameSummary] = []
        while tx:
            frame_summaries.append(tx.frame_summary())
            tx = getattr(tx, "parent", None)

        # official from_list stub doesn't have new-style type
        msgs = traceback.StackSummary.from_list(frame_summaries).format()  # type: ignore[arg-type]
        rv.node.stack_trace = " | ".join(msgs)

        return rv

    def create_node(self, *args, **kwargs):
        node = super().create_node(*args, **kwargs)
        node.meta["creation_timestamp"] = self.timestamp
        return node

    # Note: we did not override erase_node since
    # we call self.graph.erase_node elsewhere
    def remove_node(self, node):
        self.graph.erase_node(node)
        self.name_to_input.pop(node.name, None)