# eval_frame.py

import contextlib
import functools
import inspect
import logging
import os
import sys
import textwrap
import threading
import traceback
import types
import warnings
from enum import Enum
from typing import Optional, Tuple, TYPE_CHECKING, Union
from unittest.mock import patch

import torch
import torch.utils._pytree as pytree
from torch.fx.experimental.proxy_tensor import make_fx
from torch.fx.graph import _PyTreeCodeGen, _PyTreeInfo
from torch.nn.parallel.distributed import DistributedDataParallel

from .backends.registry import CompilerFn, lookup_backend
from .hooks import Hooks

if TYPE_CHECKING:
    from torch._C._dynamo.eval_frame import (  # noqa: F401
        reset_code,
        set_eval_frame,
        set_guard_error_hook,
        set_guard_fail_hook,
        skip_code,
        unsupported,
    )
else:
    for name in dir(torch._C._dynamo.eval_frame):
        if name.startswith("__"):
            continue
        globals()[name] = getattr(torch._C._dynamo.eval_frame, name)

from . import config, convert_frame, skipfiles, utils
from .exc import ResetRequired
from .mutation_guard import install_generation_tagging_init
from .types import DynamoCallback
from .utils import compile_times

log = logging.getLogger(__name__)

from torch.fx.experimental import proxy_tensor

always_optimize_code_objects = utils.ExactWeakKeyDictionary()
null_context = contextlib.nullcontext


# See https://github.com/python/typing/pull/240
class Unset(Enum):
    token = 0


unset = Unset.token

compile_lock = threading.RLock()
most_recent_backend: Optional[CompilerFn] = None


class OptimizedModule(torch.nn.Module):
    """
    Wraps the original nn.Module object and later patches its
    forward method to the optimized self.forward method.
    """

    def __init__(self, mod, dynamo_ctx):
        super().__init__()
        # Installs the params/buffers
        self._orig_mod = mod
        self.dynamo_ctx = dynamo_ctx

    def __getattr__(self, name):
        if name == "_orig_mod":
            return self._modules["_orig_mod"]
        return getattr(self._orig_mod, name)

    def forward(self, *args, **kwargs):
        return self.dynamo_ctx(self._orig_mod.forward)(*args, **kwargs)
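
# Illustrative usage sketch (not executed here; the Linear module and the "eager"
# backend below are example choices, not part of this file): optimizing an
# nn.Module wraps it in OptimizedModule, which compiles forward() and proxies
# other attribute access to the original module.
#
#   mod = torch.nn.Linear(4, 4)
#   opt_mod = torch._dynamo.optimize("eager")(mod)   # returns an OptimizedModule
#   opt_mod(torch.randn(2, 4))                       # runs the compiled forward
#   assert opt_mod.weight is mod.weight              # attributes come from _orig_mod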


def remove_from_cache(f):
    """
    Make sure f.__code__ is not cached to force a recompile
    """
    if isinstance(f, types.CodeType):
        reset_code(f)
    elif hasattr(f, "__code__"):
        reset_code(f.__code__)
    elif hasattr(getattr(f, "forward", None), "__code__"):
        reset_code(f.forward.__code__)
    else:
        from . import reset

        reset()
        log.warning("could not determine __code__ for %s", f)


def nothing():
    pass


def innermost_fn(fn):
    """
    In case of nesting of _TorchDynamoContext calls, find the innermost
    function. TorchDynamo caches on fn.__code__ object, so it's necessary to find
    the innermost function to pass on to optimize, run, disable etc.
    """
    unaltered_fn = fn
    while hasattr(unaltered_fn, "_torchdynamo_orig_callable"):
        unaltered_fn = unaltered_fn._torchdynamo_orig_callable
        assert callable(unaltered_fn)
    return unaltered_fn


@contextlib.contextmanager
def enable_dynamic(enable: bool = True):
    if not enable:
        yield
        return
    with config.patch(dynamic_shapes=True, specialize_int_float=False):
        yield


class _TorchDynamoContext:
    def __init__(
        self,
        callback: DynamoCallback,
        on_enter=nothing,
        backend_ctx_ctor=null_context,
        patch_fn=nothing,
        first_ctx=False,
        *,
        dynamic=False,
    ):
        super().__init__()
        assert callable(callback) or callback is False or callback is None
        self.callback: DynamoCallback = callback
        self.prior: Union[Unset, DynamoCallback] = unset
        self.on_enter = on_enter
        self.extra_ctx_ctor = backend_ctx_ctor
        self.first_ctx = first_ctx
        self.dynamic = dynamic
        patch_fn()

    def __enter__(self):
        if config.raise_on_ctx_manager_usage:
            raise RuntimeError(
                "torch._dynamo.optimize(...) is used with a context manager. "
                "Please refer to https://github.com/pytorch/torchdynamo#usage-example "
                "to use torch._dynamo.optimize(...) as an annotation/decorator. "
            )
        self.on_enter()
        self.prior = set_eval_frame(self.callback)
        self.backend_ctx = self.extra_ctx_ctor()
        self.backend_ctx.__enter__()
        self.dynamic_ctx = enable_dynamic(self.dynamic)
        self.dynamic_ctx.__enter__()

    def __exit__(self, exc_type, exc_val, exc_tb):
        assert self.prior is not unset
        set_eval_frame(self.prior)
        self.prior = unset
        # TODO: This is totally not the right way to chain contexts manually
        self.dynamic_ctx.__exit__(exc_type, exc_val, exc_tb)
        self.backend_ctx.__exit__(exc_type, exc_val, exc_tb)

    def __call__(self, fn):
        fn = innermost_fn(fn)
        # Optimize the forward method of torch.nn.Module object
        if isinstance(fn, torch.nn.Module):
            mod = fn
            new_mod = OptimizedModule(mod, self)
            # Save the function pointer to find the original callable while nesting
            # of decorators.
            new_mod._torchdynamo_orig_callable = mod.forward
            return new_mod

        assert callable(fn)
        callback = self.callback
        on_enter = self.on_enter
        backend_ctx_ctor = self.extra_ctx_ctor

        @functools.wraps(fn)
        def _fn(*args, **kwargs):
            if (
                not isinstance(self, DisableContext)
                and torch.fx._symbolic_trace.is_fx_tracing()
            ):
                if config.error_on_nested_fx_trace:
                    raise RuntimeError(
                        "Detected that you are using FX to symbolically trace "
                        "a dynamo-optimized function. This is not supported at the moment."
                    )
                else:
                    return fn(*args, **kwargs)

            on_enter()
            prior = set_eval_frame(callback)
            backend_ctx = backend_ctx_ctor()
            backend_ctx.__enter__()
            dynamic_ctx = enable_dynamic(self.dynamic)
            dynamic_ctx.__enter__()
            try:
                return fn(*args, **kwargs)
            finally:
                set_eval_frame(prior)
                dynamic_ctx.__exit__(None, None, None)
                backend_ctx.__exit__(None, None, None)

        # hooks to properly handle inlining
        if isinstance(self, DisableContext):
            _fn._torchdynamo_disable = True  # type: ignore[attr-defined]
        else:
            _fn._torchdynamo_inline = fn  # type: ignore[attr-defined]

        # Save the function pointer to find the original callable while nesting
        # of decorators.
        _fn._torchdynamo_orig_callable = fn  # type: ignore[attr-defined]

        # If the function is called using torch._dynamo.optimize decorator, we
        # should prevent any type of skipping.
        if callback not in (None, False):
            if not hasattr(fn, "__code__"):
                raise RuntimeError(
                    textwrap.dedent(
                        """
                        torch._dynamo.optimize is called on a non-function object.
                        If this is a callable class, please wrap the relevant code into a function and optimize the
                        wrapper function.

                        >> class CallableClass:
                        >>     def __init__(self):
                        >>         super().__init__()
                        >>         self.relu = torch.nn.ReLU()
                        >>
                        >>     def __call__(self, x):
                        >>         return self.relu(torch.sin(x))
                        >>
                        >>     def print_hello(self):
                        >>         print("Hello world")
                        >>
                        >> mod = CallableClass()

                        If you want to optimize the __call__ function and other code, wrap that up in a function

                        >> def wrapper_fn(x):
                        >>     y = mod(x)
                        >>     return y.sum()

                        and then optimize the wrapper_fn

                        >> opt_wrapper_fn = torch._dynamo.optimize(backend)(wrapper_fn)
                        """
                    )
                )
            always_optimize_code_objects[fn.__code__] = True

        return _fn


class OptimizeContext(_TorchDynamoContext):
    @staticmethod
    def _different_backend(old, new):
        return not (old == new or old is None)

    def __init__(self, callback, backend_ctx_ctor, first_ctx=False, *, dynamic=False):
        def on_enter():
            global most_recent_backend
            if OptimizeContext._different_backend(most_recent_backend, compiler_fn):
                if config.raise_on_backend_change:
                    raise ResetRequired()
                else:
                    warnings.warn(
                        "changing options to `torch.compile()` may require "
                        "calling `torch._dynamo.reset()` to take effect"
                    )
            most_recent_backend = compiler_fn
            install_generation_tagging_init()

        compiler_fn = innermost_fn(callback)
        super().__init__(
            callback=callback,
            on_enter=on_enter,
            backend_ctx_ctor=backend_ctx_ctor,
            patch_fn=TorchPatcher.patch,
            first_ctx=first_ctx,
            dynamic=dynamic,
        )


class RunOnlyContext(_TorchDynamoContext):
    def __init__(self):
        super().__init__(callback=False)


class DisableContext(_TorchDynamoContext):
    def __init__(self):
        super().__init__(callback=None)


def catch_errors_wrapper(callback, hooks: Hooks):
    @functools.wraps(callback)
    def catch_errors(frame, cache_size):
        if (
            frame.f_lasti >= 0
            or skipfiles.check(frame.f_code.co_filename)
            or config.disable
        ):
            log.debug(f"skipping {frame.f_code.co_name} {frame.f_code.co_filename}")
            return None
        if frame.f_code.co_filename == "<string>" and frame.f_code.co_name == "__new__":
            # namedtuple constructor
            return None
        if config.optimize_ddp:
            ddp_module = DistributedDataParallel._get_active_ddp_module()
            if ddp_module:
                with compile_lock:
                    from torch._dynamo.backends.distributed import DDPOptimizer

                    ddp_optimizer = DDPOptimizer(
                        bucket_bytes_cap=ddp_module.bucket_bytes_cap,
                        backend_compile_fn=callback._torchdynamo_orig_callable,
                    )
                    hijacked_callback = convert_frame.convert_frame(
                        ddp_optimizer.compile_fn,
                        hooks=hooks,
                    )
                    return hijacked_callback(frame, cache_size, hooks)

        with compile_lock:
            return callback(frame, cache_size, hooks)

    catch_errors._torchdynamo_orig_callable = callback  # type: ignore[attr-defined]
    return catch_errors


def _optimize_catch_errors(
    compile_fn, hooks: Hooks, backend_ctx_ctor=null_context, dynamic=False
):
    return OptimizeContext(
        catch_errors_wrapper(compile_fn, hooks),
        backend_ctx_ctor=backend_ctx_ctor,
        first_ctx=True,
        dynamic=dynamic,
    )


def get_compiler_fn(compiler_fn):
    from .debug_utils import wrap_backend_debug

    if hasattr(compiler_fn, "compiler_name"):
        compiler_str = compiler_fn.compiler_name
    elif isinstance(compiler_fn, str):
        compiler_str = compiler_fn
    else:
        compiler_str = None
    compiler_fn = lookup_backend(compiler_fn)
    return wrap_backend_debug(compiler_fn, compiler_str)


class _NullDecorator(contextlib.nullcontext):  # type: ignore[type-arg]
    def __call__(self, fn):
        assert callable(fn)
        return fn


def check_if_dynamo_supported():
    if sys.platform == "win32":
        raise RuntimeError("Windows not yet supported for torch.compile")
    if sys.version_info >= (3, 11):
        raise RuntimeError("Python 3.11+ not yet supported for torch.compile")


def optimize(
    backend="inductor",
    *,
    nopython=False,
    guard_export_fn=None,
    guard_fail_fn=None,
    disable=False,
    dynamic=False,
):
    """
    The main entrypoint of TorchDynamo. Do graph capture and call
    backend() to optimize extracted graphs.

    Args:
        backend: One of two things:
            - Either, a function/callable taking a torch.fx.GraphModule and
              example_inputs and returning a python callable that runs the
              graph faster.
              One can also provide additional context for the backend, like
              torch.jit.fuser("fuser2"), by setting the backend_ctx_ctor attribute.
              See AOTAutogradMemoryEfficientFusionWithContext for the usage.
            - Or, a string backend name in `torch._dynamo.list_backends()`
        nopython: If True, graph breaks will be errors and there will
            be a single whole-program graph.
        disable: If True, turn this decorator into a no-op
        dynamic: If True, turn on dynamic shapes support

    Example Usage::

        @torch._dynamo.optimize()
        def toy_example(a, b):
            ...
    """
    check_if_dynamo_supported()
    # Note: The hooks object could be global instead of passed around, *however* that would make
    # for a confusing API usage and plumbing story wherein we nest multiple .optimize calls.
    # There is some prior art around this, w/r/t nesting backend calls being enforced to be the same
    # compiler; however, that feels onerous for callbacks and hooks, and it feels better to give our users an
    # easier-to-understand UX at the cost of a little more plumbing on our end.
    hooks = Hooks(guard_export_fn=guard_export_fn, guard_fail_fn=guard_fail_fn)
    torch._C._log_api_usage_once("torch._dynamo.optimize")
    if disable or os.environ.get("TORCHDYNAMO_DISABLE", "") == "1":
        return _NullDecorator()

    backend = get_compiler_fn(backend)

    # Find if backend has any extra context manager
    backend_ctx_ctor = getattr(backend, "backend_ctx_ctor", null_context)

    if nopython:
        return optimize_assert(
            backend,
            dynamic=dynamic,
            hooks=hooks,
        )
    return _optimize_catch_errors(
        convert_frame.convert_frame(backend, hooks=hooks),
        hooks,
        backend_ctx_ctor,
        dynamic=dynamic,
    )
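
# Illustrative usage sketch (not executed here; toy_example and its inputs are
# made-up examples): with nopython=False a data-dependent branch produces a
# graph break instead of a hard error.
#
#   @torch._dynamo.optimize("eager", nopython=False)
#   def toy_example(a, b):
#       x = a / (torch.abs(a) + 1)
#       if b.sum() < 0:  # data-dependent control flow -> graph break here
#           b = b * -1
#       return x * b
#
#   toy_example(torch.randn(10), torch.randn(10))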


# TODO(voz): Consider making "explain" output alongside a run / part of a run
@patch("torch._dynamo.symbolic_convert.explain", True)
def explain(f, *args, **kwargs):
    # TODO(voz): Do we want a decorator for this?
    from . import reset

    reset()

    out_guards = []
    graphs = []
    ops_per_graph = []
    op_count = 0
    break_reasons = []

    def dynamo_graph_accumulating_compiler(gm: torch.fx.GraphModule, example_inputs):
        nonlocal graphs
        nonlocal op_count
        nonlocal ops_per_graph
        graphs.append(gm)
        ops = []
        for node in gm.graph.nodes:
            if node.op == "call_function":
                ops.append(node.target)
        op_count += len(ops)
        ops_per_graph.append(ops)
        if gm.compile_subgraph_reason is not None:
            break_reasons.append(gm.compile_subgraph_reason)
        return gm.forward

    def guard_export_print(guards):
        nonlocal out_guards
        out_guards.append(guards)

    with patch(f"{__name__}.most_recent_backend", None):
        opt_f = optimize(
            dynamo_graph_accumulating_compiler,
            nopython=False,
            guard_export_fn=guard_export_print,
        )(f)
        # TODO(voz): We may have instances of `f` that mutate inputs, we should track side effects and reject.
        opt_f(*args, **kwargs)

    graph_count = len(graphs)

    # For the explanation summary, dedupe reasons by the innermost stack frame.
    deduped_reasons = {}
    for reason in break_reasons:
        innermost_frame = reason.user_stack[-1]
        # __repr__ uniquely identifies a FrameSummary so we can use it for deduping
        deduped_reasons[repr(innermost_frame)] = reason

    formatted_list = ""
    for idx, break_reason in enumerate(deduped_reasons.values()):
        formatted_stack = "".join(traceback.format_list(break_reason.user_stack))
        msg = f"{break_reason.reason}\n{formatted_stack}"
        formatted_list += f"{idx + 1}. {msg} \n"

    explanation = f"Dynamo produced {graph_count} graphs "
    explanation += f"with {graph_count - 1} graph break and {op_count} ops"
    explanation_verbose = explanation
    explanation_verbose += f"\n Break reasons: \n\n{formatted_list}"
    explanation_verbose += compile_times()

    # TODO(voz): Do we want a decorator for this?
    reset()
    return (
        explanation,
        out_guards,
        graphs,
        ops_per_graph,
        break_reasons,
        explanation_verbose,
    )
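
# Illustrative usage sketch (not executed here; reuses the hypothetical
# toy_example from the optimize() sketch above):
#
#   (explanation, out_guards, graphs, ops_per_graph, break_reasons,
#    explanation_verbose) = torch._dynamo.explain(
#       toy_example, torch.randn(10), torch.randn(10)
#   )
#   print(explanation_verbose)  # graph/op counts plus formatted break reasons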


def export(
    f, *args, aten_graph=False, decomposition_table=None, tracing_mode="real", **kwargs
):
    check_if_dynamo_supported()
    torch._C._log_api_usage_once("torch._dynamo.export")
    if decomposition_table is not None or tracing_mode != "real":
        assert (
            aten_graph
        ), "Specifying a decomposition_table or tracing mode is illegal without setting aten_graph=True"
    f = innermost_fn(f)

    graph = None
    out_guards = None
    graph_captured_input = None
    graph_captured_result: Optional[Tuple[torch.Tensor, ...]] = None

    def produce_matching(source_args, candidate_args):
        matched_elements_positions = []
        dict_of_source_args = dict()
        for i in range(0, len(source_args)):
            element_id = id(source_args[i])
            dict_of_source_args[element_id] = i

        for i in range(0, len(candidate_args)):
            arg = candidate_args[i]
            # 1-element tensor arg can be unspec int/float
            if isinstance(arg, torch.Tensor) and torch.numel(arg) == 1:
                if id(arg) in dict_of_source_args:
                    matched_elements_positions.append(dict_of_source_args[id(arg)])
                elif id(arg.item()) in dict_of_source_args:
                    matched_elements_positions.append(
                        dict_of_source_args[id(arg.item())]
                    )
                else:
                    raise AssertionError(
                        "Dynamo input/output is not consistent with traced input/output"
                    )
            else:
                assert (
                    id(arg) in dict_of_source_args
                ), "Dynamo input and output is a strict subset of traced input/output"
                matched_elements_positions.append(dict_of_source_args[id(arg)])

        return matched_elements_positions

    def guard_export_print(guards):
        nonlocal out_guards
        assert out_guards is None, "whole graph export entails exactly one guard export"
        out_guards = guards

    def dynamo_normalization_capturing_compiler(
        gm: torch.fx.GraphModule, example_inputs
    ):
        nonlocal graph
        assert graph is None, "whole graph export entails exactly one graph"
        graph = gm

        def result_capturing_wrapper(*graph_inputs):
            nonlocal graph_captured_result
            nonlocal graph_captured_input
            graph_captured_input = graph_inputs
            assert graph is not None
            graph_captured_result = graph(*graph_inputs)
            return graph_captured_result

        return result_capturing_wrapper

    flat_args, in_spec = pytree.tree_flatten((args, kwargs))

    remove_from_cache(f)
    with patch(f"{__name__}.most_recent_backend", None):
        opt_f = optimize_assert(
            dynamo_normalization_capturing_compiler,
            hooks=Hooks(guard_export_fn=guard_export_print, guard_fail_fn=None),
            export=True,
            dynamic=(tracing_mode == "symbolic"),
        )(f)
        # TODO(voz): We may have instances of `f` that mutate inputs, we should track side effects and reject.
        result_traced = opt_f(*args, **kwargs)
    remove_from_cache(f)

    assert graph is not None, "whole graph export entails exactly one call"
    assert out_guards is not None, "whole graph export entails exactly one guard export"
    matched_input_elements_positions = produce_matching(flat_args, graph_captured_input)

    flat_results_traced, out_spec_traced = pytree.tree_flatten(result_traced)

    assert graph_captured_result is not None
    flat_both = list(graph_captured_result) + flat_args
    matched_output_elements_positions = produce_matching(flat_both, flat_results_traced)

    class ChangeInputOutputSignature(torch.fx.interpreter.Transformer):
        def __init__(
            self,
            m,
        ):
            super().__init__(m)
            arg_len = len(flat_args)
            self.new_args = [
                super(ChangeInputOutputSignature, self).placeholder(f"arg{i}", (), {})
                for i in range(0, arg_len)
            ]
            self.old_args_gen = (
                self.new_args[i] for i in matched_input_elements_positions
            )

        def placeholder(self, target, args, kwargs):
            arg = next(self.old_args_gen)
            if "val" in self.current_node.meta:
                arg.node.meta["val"] = self.current_node.meta["val"]
            if "tensor_dict" in self.current_node.meta:
                arg.node.meta["tensor_dict"] = self.current_node.meta["tensor_dict"]
            return arg

        def output(self, target, args, kwargs):
            dynamo_result_flat = args[0]
            lookup = [*dynamo_result_flat, *self.new_args]
            new_result_flat = [lookup[i] for i in matched_output_elements_positions]
            return super().output(target, (new_result_flat,), {})

        def run_node(self, n):
            self.current_node = n
            return super().run_node(n)

    if aten_graph:
        # Running graph with interpreter is needed for propagating the stack_trace
        def graph_with_interpreter(*args):
            with torch.fx.traceback.preserve_node_meta():
                return torch.fx.Interpreter(graph).run(*args)

        graph = make_fx(
            graph_with_interpreter,
            decomposition_table=decomposition_table,
            tracing_mode=tracing_mode,
            _allow_non_fake_inputs=True,
        )(*graph_captured_input)

    new_graph = ChangeInputOutputSignature(
        graph,
    ).transform()

    # Make the dynamo graph have the same input/output spec as the user code
    input_strs = [f"orig_arg_{i}" for i in range(len(args))] + list(kwargs.keys())

    new_graph.graph._codegen = _PyTreeCodeGen(
        _PyTreeInfo(
            input_strs,
            in_spec,
            out_spec_traced,
        )
    )

    new_graph.recompile()
    return (new_graph, out_guards)
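
# Illustrative usage sketch (not executed here; fn and its inputs are made-up
# examples): export() returns a single whole-program fx.GraphModule plus the
# guards collected during tracing.
#
#   def fn(x, y):
#       return (x + y).relu()
#
#   graph_module, guards = torch._dynamo.export(fn, torch.randn(3), torch.randn(3))
#   print(graph_module.code)                       # captured FX graph
#   graph_module(torch.randn(3), torch.randn(3))   # same input/output spec as fn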


def assume_constant_result(fn):
    fn._dynamo_marked_constant = True
    return fn
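
# Illustrative usage sketch (not executed here; get_mode is a hypothetical helper):
# marking a function this way signals Dynamo to treat its result as a constant
# during tracing rather than tracing into it.
#
#   @torch._dynamo.assume_constant_result
#   def get_mode():
#       return "fast"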


def optimize_assert(backend, *, hooks=Hooks(None, None), export=False, dynamic=False):
    """
    The same as `torch._dynamo.optimize(backend, nopython=True)`
    """
    backend = get_compiler_fn(backend)

    # Find if backend has any extra context manager
    backend_ctx_ctor = getattr(backend, "backend_ctx_ctor", null_context)

    return _optimize_catch_errors(
        convert_frame.convert_frame_assert(backend, export=export),
        hooks,
        backend_ctx_ctor,
        dynamic=dynamic,
    )


def run(fn=None):
    """Don't do any dynamic compiles, just use prior optimizations"""
    if fn is not None:
        fn = innermost_fn(fn)
        assert callable(fn)
        return RunOnlyContext()(fn)
    return RunOnlyContext()
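
# Illustrative usage sketch (not executed here; fn and x are made-up examples):
# run() reuses previously compiled code and never triggers new compilation.
#
#   opt_fn = torch._dynamo.optimize("eager")(fn)
#   opt_fn(x)                          # compiles and caches
#   run_only = torch._dynamo.run(fn)
#   run_only(x)                        # uses cached compiled code only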


def disable(fn=None):
    """Decorator and context manager to disable TorchDynamo"""
    if fn is not None:
        fn = innermost_fn(fn)
        assert callable(fn)
        return DisableContext()(fn)
    return DisableContext()
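
# Illustrative usage sketch (not executed here; debug_helper is a hypothetical
# function): a disabled function runs under the normal Python interpreter
# instead of being traced by TorchDynamo.
#
#   @torch._dynamo.disable
#   def debug_helper(x):
#       print(x.shape)  # side-effectful code Dynamo should not trace
#       return x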


def skip(fn=None):
    """
    Skip frames associated with the function code, but still process recursively
    invoked frames
    """
    if fn is None:
        return skip
    fn = innermost_fn(fn)
    assert callable(fn)
    skip_code(fn.__code__)
    fn._torchdynamo_disable = True
    return fn
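
# Illustrative usage sketch (not executed here; outer and inner are hypothetical
# functions): only the decorated frame is skipped; frames it invokes can still
# be processed.
#
#   @skip
#   def outer(x):
#       return inner(x)  # inner() itself may still be compiled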


class TorchPatcher:
    @staticmethod
    @functools.lru_cache(None)
    def patch():
        # Disable TorchDynamo on some torch.* compiler-generated frames
        torch.jit.trace = disable(torch.jit.trace)
        torch.jit.trace_module = disable(torch.jit.trace_module)
        torch.jit._get_trace_graph = disable(torch.jit._get_trace_graph)

        # symbolic_trace creates new frames. We disable Dynamo on such frames
        torch.fx._symbolic_trace.Tracer.trace = disable(
            torch.fx._symbolic_trace.Tracer.trace
        )

        torch.onnx.export_to_pretty_string = disable(torch.onnx.export_to_pretty_string)

        torch.distributions.Distribution.set_default_validate_args(False)

        proxy_tensor.dispatch_trace = disable(proxy_tensor.dispatch_trace)

        optimizers = [
            opt
            for opt in torch.optim.__dict__.values()
            if inspect.isclass(opt) and issubclass(opt, torch.optim.Optimizer)
        ]

        # disable dynamo for the wrapper that helps give dynamo hints about entering DDP
        if hasattr(DistributedDataParallel, "_inside_ddp_forward"):
            DistributedDataParallel._inside_ddp_forward = skip(
                DistributedDataParallel._inside_ddp_forward
            )

        from ..optim import adagrad, adam, adamax, adamw, asgd, nadam, sgd

        for opt_mod in adagrad, adam, adamax, adamw, asgd, nadam, sgd:
            multi_tensor_fn_name = f"_multi_tensor_{opt_mod.__name__.split('.')[-1]}"
            if hasattr(opt_mod, multi_tensor_fn_name):
                setattr(
                    opt_mod,
                    multi_tensor_fn_name,
                    disable(getattr(opt_mod, multi_tensor_fn_name)),
                )

        excluded_opts = {torch.optim.SparseAdam, torch.optim.RAdam, torch.optim.LBFGS}
        for opt in optimizers:
            if opt in excluded_opts:
                opt.step = disable(opt.step)

            opt._cuda_graph_capture_health_check = disable(
                opt._cuda_graph_capture_health_check
            )
            opt.zero_grad = disable(opt.zero_grad)

            if hasattr(opt, "_init_group"):
                opt._init_group = disable(opt._init_group)

            # disable any currently set hooks
            # Note: we only want to disable the profiling hook,
            # which is the *last* hook applied; we want to keep the no_grad hook
            hooked = getattr(opt.step, "hooked", False)
            if hooked:
                unwrapped_step = getattr(opt.step, "__wrapped__", None)
                if unwrapped_step:
                    opt.step = unwrapped_step

            # disable future hooking
            opt.step.hooked = True

    @staticmethod
    def suppress_torch_distributed_warnings(fn):
        def inner_fn(*args, **kwargs):
            warnings.filterwarnings(
                "ignore", category=UserWarning, module="torch.distributed"
            )
            return fn(*args, **kwargs)

        return inner_fn