# convert_frame.py

import functools
import itertools
import logging
import os
import types
import weakref
from typing import Dict, Optional, Set

import torch
from torch.fx.graph_module import _forward_from_src as original_forward_from_src

from . import config, exc
from .allowed_functions import is_allowed
from .backends.registry import CompilerFn
from .bytecode_analysis import remove_dead_code, remove_pointless_jumps
from .bytecode_transformation import is_generator, transform_code_object
from .eval_frame import always_optimize_code_objects, skip_code, TorchPatcher
from .exc import (
    augment_exc_message,
    BackendCompilerFailed,
    format_error_msg,
    InternalTorchDynamoError,
    TorchRuntimeError,
    unimplemented,
    Unsupported,
)
from .guards import CheckFunctionManager, GuardedCode
from .hooks import Hooks
from .output_graph import OutputGraph
from .replay_record import ExecutionRecord
from .symbolic_convert import InstructionTranslator
from .utils import (
    CleanupManager,
    counters,
    dynamo_timed,
    format_bytecode,
    gen_record_file_name,
    guard_failures,
    increment_frame,
    init_logging,
    is_namedtuple,
    istype,
    orig_code_map,
    troubleshooting_url,
    write_record_to_file,
)

log = logging.getLogger(__name__)


class Tracker:
    def __init__(self):
        self.seen = []
        self.seen_ids = set()

    def add(self, strong_obj):
        idx = id(strong_obj)
        if idx not in self.seen_ids:
            obj = weakref.ref(strong_obj, lambda _: self.seen_ids.remove(idx))
            self.seen.append(obj)
            self.seen_ids.add(idx)

    def __contains__(self, item):
        return id(item) in self.seen_ids

    def clear(self):
        self.seen.clear()
        self.seen_ids.clear()
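
# Tracker keeps weak references to objects keyed by id(); the weakref callback drops
# an id once its object is garbage collected, so these sets never keep code objects
# alive. input_codes records every frame handed to dynamo, while output_codes records
# the bytecode dynamo itself generated, so we never try to re-convert our own output.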
input_codes = Tracker()
output_codes = Tracker()

initial_grad_state = None


@functools.wraps(original_forward_from_src)
def fx_forward_from_src_skip_result(*args, **kwargs):
    # We monkey patch FX to prevent an infinite loop of trying to convert
    # our own generated code.
    result: types.FunctionType = original_forward_from_src(*args, **kwargs)
    skip_code(result.__code__)
    return result


def wrap_convert_context(fn):
    """
    Wrapper that, around each call to fn:
        1) saves/restores the torch random state
        2) saves/restores the torch.is_grad_enabled() state
        3) monkey patches torch.fx.graph_module._forward_from_src
    """

    @functools.wraps(fn)
    def _fn(*args, **kwargs):
        prior_grad_mode = torch.is_grad_enabled()
        rng_state = torch.random.get_rng_state()
        if torch.cuda.is_available():
            cuda_rng_state = torch.cuda.get_rng_state()
        prior_fwd_from_src = torch.fx.graph_module._forward_from_src
        torch.fx.graph_module._forward_from_src = fx_forward_from_src_skip_result
        try:
            return fn(*args, **kwargs)
        finally:
            torch._C._set_grad_enabled(prior_grad_mode)
            torch.random.set_rng_state(rng_state)
            if torch.cuda.is_available():
                torch.cuda.set_rng_state(cuda_rng_state)
            torch.fx.graph_module._forward_from_src = prior_fwd_from_src

    _fn._torchdynamo_orig_callable = fn  # type: ignore[attr-defined]
    return _fn
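
# The save/restore above makes a compile attempt invisible to the program being
# traced: RNG draws, grad-mode changes, and the _forward_from_src patch are all
# rolled back in the finally block before control returns to the caller.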


@TorchPatcher.suppress_torch_distributed_warnings
def has_tensor_in_frame(frame):
    """Check if the frame has torch.* related bits"""
    # Check if the function was decorated using torch._dynamo.optimize
    if frame.f_code in always_optimize_code_objects:
        return True

    # Check if there is a global import of torch.*
    for co_name in frame.f_code.co_names:
        if co_name in frame.f_globals:
            if is_allowed(frame.f_globals[co_name]):
                return True

    seen_ids: Dict[int, bool] = dict()

    def has_tensor(obj):
        """Recursively check if the obj has a tensor"""
        obj_id = id(obj)
        if obj_id in seen_ids:
            return seen_ids[obj_id]
        seen_ids[obj_id] = False

        if isinstance(obj, (torch.Tensor, torch.nn.Module)):
            seen_ids[obj_id] = True
            return seen_ids[obj_id]
        elif istype(obj, (list, tuple)):
            seen_ids[obj_id] = any([has_tensor(v) for v in obj])
            return seen_ids[obj_id]
        elif istype(obj, dict):
            # Some packages like pytest can be updated during runtime. So, make a
            # copy of values to avoid issues like "RuntimeError: dictionary
            # changed size during iteration"
            values = list(obj.values())
            seen_ids[obj_id] = any([has_tensor(v) for v in values])
            return seen_ids[obj_id]
        elif istype(obj, (str, int, float, type(None), bool)):
            seen_ids[obj_id] = False
            return seen_ids[obj_id]
        elif is_namedtuple(obj):
            seen_ids[obj_id] = any([has_tensor(getattr(obj, v)) for v in obj._fields])
            return seen_ids[obj_id]
        else:
            # if config.debug:
            #     print(
            #         f"Assuming that object of type {type(obj)} does not have a tensor"
            #     )
            return False

    # Check if the passed arguments are of type Tensor
    for value in frame.f_locals.values():
        if has_tensor(value):
            return True

    log.debug(
        f"skipping because no torch.* {frame.f_code.co_name} \
            {frame.f_code.co_filename} {frame.f_code.co_firstlineno}"
    )

    return False
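
# has_tensor_in_frame is the cheap "worth compiling?" filter: a frame whose locals
# are only plain Python values and whose globals reference nothing torch-related is
# skipped outright, while a frame holding a torch.Tensor or nn.Module (possibly
# nested inside a list/dict/namedtuple) proceeds to _compile below.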


def exception_handler(e, code, frame=None):
    record_filename = None
    if hasattr(e, "exec_record"):
        record_filename = gen_record_file_name(e, code)
        write_record_to_file(record_filename, e.exec_record)
        e.record_filename = record_filename

    augment_exc_message(e)
    # Only log the exception if we are going to suppress it;
    # if we aren't suppressing it, a higher-level except block will handle it.
    if config.suppress_errors:
        log.error(format_error_msg(e, code, record_filename, frame))
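
# If the exception carried an exec_record, the file written above can be handed to
# replay() at the bottom of this module to re-run the failing compile (with the
# eager debug backend) outside the original program.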


def convert_frame_assert(
    compiler_fn: CompilerFn,
    one_graph: bool = True,
    export: bool = False,
):
    """Fully convert a frame into an FX graph"""
    init_logging()

    def _convert_frame_assert(frame: types.FrameType, cache_size: int, hooks: Hooks):
        increment_frame()
        code = frame.f_code
        input_codes.add(code)
        if code in output_codes:
            return None
        if (
            os.environ.get("TORCHDYNAMO_DEBUG_FUNCTION")
            and os.environ.get("TORCHDYNAMO_DEBUG_FUNCTION") != code.co_name
        ):
            return None
        if code.co_name == "<genexpr>" and code.co_filename.endswith(
            ("transformers/file_utils.py", "transformers/utils/generic.py")
        ):
            # not needed, but cleans up torchbench error stats
            return None
        if code.co_name == "__setattr__":
            # setattr could be tricky to handle generally, but it is also not
            # likely useful to compile - skip the whole frame
            return None

        # Check if the frame was generated by an exec builtin call
        # TODO - Running an exec-generated frame seems to propagate f_globals to the
        # next frames.
        if code.co_name == "<module>" and code.co_filename == "<string>":
            return None

        if (
            code.co_name == "<lambda>"
            and code.co_filename == "<string>"
            and not bool(frame.f_builtins)
        ):
            # namedtuple subclass constructor. Empty builtins cause issues with
            # the len keyword in the LIST_LEN guard.
            return None

        if is_generator(code):
            unimplemented("generator")

        if cache_size >= config.cache_size_limit:

            def format_func_info(code):
                return f"'{code.co_name}' ({code.co_filename}:{code.co_firstlineno})"

            def format_guard_failures(code):
                # For the common case, it's sufficient to see just the most recent
                # failure. We could add a verbose mode if needed.
                return f"{str(guard_failures[code][-1])}"

            assert code in guard_failures, "TODO(whc) any other recompile reasons?"
            log.warning(
                f"torch._dynamo hit config.cache_size_limit ({config.cache_size_limit})\n"
                + f" function: {format_func_info(code)}\n"
                + f" reasons: {format_guard_failures(code)}\n"
                + f"to diagnose recompilation issues, see {troubleshooting_url}."
            )
            unimplemented("cache_size_limit reached")

        if not has_tensor_in_frame(frame):
            return None

        global initial_grad_state
        initial_grad_state = torch.is_grad_enabled()

        return _compile(
            frame.f_code,
            frame.f_globals,
            frame.f_locals,
            frame.f_builtins,
            compiler_fn,
            one_graph,
            export,
            hooks,
            frame,
        )

    _convert_frame_assert._torchdynamo_orig_callable = compiler_fn  # type: ignore[attr-defined]
    return wrap_convert_context(_convert_frame_assert)
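
# Illustrative usage sketch (the names below are placeholders, not part of this file):
#
#   def my_compiler(gm: torch.fx.GraphModule, example_inputs):
#       return gm.forward  # trivial backend: just run the captured graph eagerly
#
#   callback = convert_frame_assert(my_compiler, one_graph=True)
#
# eval_frame installs `callback`; it is invoked as callback(frame, cache_size, hooks)
# and returns a GuardedCode, or None to leave the frame to run in regular CPython.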


@dynamo_timed(phase_name="entire_frame_compile")
def _compile(
    code: types.CodeType,
    globals: Dict[str, object],
    locals: Dict[str, object],
    builtins: Dict[str, object],
    compiler_fn: CompilerFn,
    one_graph: bool,
    export: bool,
    hooks: Hooks,
    frame: Optional[types.FrameType] = None,
) -> Optional[GuardedCode]:
    output: Optional[OutputGraph] = None

    # This is shared across restarts
    mutated_closure_cell_contents: Set[str] = set()

    # from .utils import print_once; print_once(code.co_filename)

    def transform(instructions, code_options):
        nonlocal output
        tracer = InstructionTranslator(
            instructions,
            code,
            locals,
            globals,
            builtins,
            code_options,
            compiler_fn,
            one_graph,
            export,
            mutated_closure_cell_contents,
        )
        tracer.run()
        output = tracer.output
        assert output is not None
        assert output.output_instructions
        instructions[:] = output.output_instructions
        code_options.update(output.code_options)

        if config.dead_code_elimination:
            instructions[:] = remove_pointless_jumps(remove_dead_code(instructions))

    try:
        for attempt in itertools.count():
            try:
                out_code = transform_code_object(code, transform)
                orig_code_map[out_code] = code
                break
            except exc.RestartAnalysis:
                log.debug("Restarting analysis ...")
                if attempt > 100:
                    unimplemented("100+ RestartAnalysis() calls")
            except exc.SkipFrame as e:
                log.debug(
                    f"Skipping frame {e} {code.co_name} \
                    {code.co_filename} {code.co_firstlineno}"
                )
                if one_graph:
                    log.debug("No graph captured with one_graph=True")
                return None
        output_codes.add(out_code)

        if config.output_code:
            log.info(
                format_bytecode(
                    "ORIGINAL BYTECODE",
                    code.co_name,
                    code.co_filename,
                    code.co_firstlineno,
                    code,
                ),
            )
            log.info(
                format_bytecode(
                    "MODIFIED BYTECODE",
                    code.co_name,
                    code.co_filename,
                    code.co_firstlineno,
                    out_code,
                ),
            )

        assert output is not None
        assert output.guards is not None
        CleanupManager.instance[out_code] = output.cleanups
        check_fn = CheckFunctionManager(
            output,
            locals,
            globals,
            hooks.guard_fail_fn if hooks else None,
        )

        guarded_code = GuardedCode(out_code, check_fn.check_fn)

        if config.output_code:
            guard_str = "GUARDS:\n"
            guard_str += "\n".join(
                [f" - {str(guard)}" for guard in sorted(output.guards)]
            )
            log.info(guard_str)

        if hooks.guard_export_fn is not None:
            hooks.guard_export_fn(output.guards)

        return guarded_code
    except (
        Unsupported,
        TorchRuntimeError,
        BackendCompilerFailed,
        AssertionError,
    ) as e:
        exception_handler(e, code, frame)
        raise
    except Exception as e:
        exception_handler(e, code, frame)
        raise InternalTorchDynamoError() from e
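
# _compile in short: transform_code_object replays `code` through
# InstructionTranslator (restarting on RestartAnalysis, giving up after 100 tries),
# the rewritten bytecode is registered in output_codes and orig_code_map, and the
# guards collected while tracing are compiled by CheckFunctionManager into the
# check_fn paired with the bytecode in GuardedCode, so the cached code is only
# reused while those guards still hold.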


def convert_frame(compiler_fn: CompilerFn, hooks: Hooks):
    """Try to convert a frame into an FX graph; on error, leave the frame unmodified"""
    inner_convert = convert_frame_assert(compiler_fn, one_graph=False)

    def _convert_frame(frame: types.FrameType, cache_size: int, hooks: Hooks):
        counters["frames"]["total"] += 1
        try:
            result = inner_convert(frame, cache_size, hooks)
            counters["frames"]["ok"] += 1
            return result
        except (NotImplementedError, Unsupported):
            log.info("converting frame raised unsupported, leaving it unconverted")
        except Exception:
            if not config.suppress_errors:
                raise
            log.info("converting frame raised error, suppressing error")
        return None

    _convert_frame._torchdynamo_orig_callable = compiler_fn  # type: ignore[attr-defined]
    return _convert_frame
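
# convert_frame is the tolerant sibling of convert_frame_assert: NotImplementedError
# and Unsupported fall back to leaving the frame unconverted, and with
# config.suppress_errors set, any other exception is swallowed as well.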


# TODO mlazos: add support for same args, or record them
def replay(filename):
    from .backends.debugging import eager

    original_replay_val = config.replay_record_enabled
    config.replay_record_enabled = False
    init_logging()
    with open(filename, "rb") as in_file:
        record = ExecutionRecord.load(in_file)
    record.globals = {
        k: v for k, v in itertools.chain(record.globals.items(), globals().items())
    }

    try:
        _compile(
            record.code,
            record.globals,
            record.locals,
            record.builtins,
            compiler_fn=eager,
            one_graph=False,
            export=False,
            hooks=Hooks(),
            frame=None,
        )
    except Exception:
        pass
    finally:
        config.replay_record_enabled = original_replay_val
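
# Illustrative replay invocation (the path below is a placeholder; pass a record
# file such as the one exception_handler writes above):
#
#   replay("/tmp/dynamo_records/example.rec")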