14 KB

  1. import functools
  2. import itertools
  3. import logging
  4. import os
  5. import types
  6. import weakref
  7. from typing import Dict, Optional, Set
  8. import torch
  9. from torch.fx.graph_module import _forward_from_src as original_forward_from_src
  10. from . import config, exc
  11. from .allowed_functions import is_allowed
  12. from .backends.registry import CompilerFn
  13. from .bytecode_analysis import remove_dead_code, remove_pointless_jumps
  14. from .bytecode_transformation import is_generator, transform_code_object
  15. from .eval_frame import always_optimize_code_objects, skip_code, TorchPatcher
  16. from .exc import (
  17. augment_exc_message,
  18. BackendCompilerFailed,
  19. format_error_msg,
  20. InternalTorchDynamoError,
  21. TorchRuntimeError,
  22. unimplemented,
  23. Unsupported,
  24. )
  25. from .guards import CheckFunctionManager, GuardedCode
  26. from .hooks import Hooks
  27. from .output_graph import OutputGraph
  28. from .replay_record import ExecutionRecord
  29. from .symbolic_convert import InstructionTranslator
  30. from .utils import (
  31. CleanupManager,
  32. counters,
  33. dynamo_timed,
  34. format_bytecode,
  35. gen_record_file_name,
  36. guard_failures,
  37. increment_frame,
  38. init_logging,
  39. is_namedtuple,
  40. istype,
  41. orig_code_map,
  42. troubleshooting_url,
  43. write_record_to_file,
  44. )
  45. log = logging.getLogger(__name__)
  46. class Tracker:
  47. def __init__(self):
  48. self.seen = []
  49. self.seen_ids = set()
  50. def add(self, strong_obj):
  51. idx = id(strong_obj)
  52. if idx not in self.seen_ids:
  53. obj = weakref.ref(strong_obj, lambda _: self.seen_ids.remove(idx))
  54. self.seen.append(obj)
  55. self.seen_ids.add(idx)
  56. def __contains__(self, item):
  57. return id(item) in self.seen_ids
  58. def clear(self):
  59. self.seen.clear()
  60. self.seen_ids.clear()
# Code objects of every frame handed to _convert_frame_assert.
input_codes = Tracker()
# Code objects produced by _compile; a frame whose code is in here is not
# converted again.
output_codes = Tracker()
# torch.is_grad_enabled() as sampled by _convert_frame_assert right before it
# calls _compile; None until the first such call.
initial_grad_state = None
  64. @functools.wraps(original_forward_from_src)
  65. def fx_forward_from_src_skip_result(*args, **kwargs):
  66. # we monkey patch FX to prevent infinite loop of trying to convert
  67. # our generated code
  68. result: types.FunctionType = original_forward_from_src(*args, **kwargs)
  69. skip_code(result.__code__)
  70. return result
  71. def wrap_convert_context(fn):
  72. """
  73. Context manager to:
  74. 1) Save/restore torch random state
  75. 2) Save/restore torch.is_grad_enabled() state
  76. 3) Monkey patch torch.fx.graph_module._forward_from_src
  77. """
  78. @functools.wraps(fn)
  79. def _fn(*args, **kwargs):
  80. prior_grad_mode = torch.is_grad_enabled()
  81. rng_state = torch.random.get_rng_state()
  82. if torch.cuda.is_available():
  83. cuda_rng_state = torch.cuda.get_rng_state()
  84. prior_fwd_from_src = torch.fx.graph_module._forward_from_src
  85. torch.fx.graph_module._forward_from_src = fx_forward_from_src_skip_result
  86. try:
  87. return fn(*args, **kwargs)
  88. finally:
  89. torch._C._set_grad_enabled(prior_grad_mode)
  90. torch.random.set_rng_state(rng_state)
  91. if torch.cuda.is_available():
  92. torch.cuda.set_rng_state(cuda_rng_state)
  93. torch.fx.graph_module._forward_from_src = prior_fwd_from_src
  94. _fn._torchdynamo_orig_callable = fn # type: ignore[attr-defined]
  95. return _fn
  96. @TorchPatcher.suppress_torch_distributed_warnings
  97. def has_tensor_in_frame(frame):
  98. """Check if the frame has torch.* related bits"""
  99. # Check if the function was decorated using torch._dynamo.optimize
  100. if frame.f_code in always_optimize_code_objects:
  101. return True
  102. # Check if there is global import of torch.*
  103. for co_name in frame.f_code.co_names:
  104. if co_name in frame.f_globals:
  105. if is_allowed(frame.f_globals[co_name]):
  106. return True
  107. seen_ids: Dict[int, bool] = dict()
  108. def has_tensor(obj):
  109. """Recursively check if the obj has a tensor"""
  110. obj_id = id(obj)
  111. if obj_id in seen_ids:
  112. return seen_ids[obj_id]
  113. seen_ids[obj_id] = False
  114. if isinstance(obj, (torch.Tensor, torch.nn.Module)):
  115. seen_ids[obj_id] = True
  116. return seen_ids[obj_id]
  117. elif istype(obj, (list, tuple)):
  118. seen_ids[obj_id] = any([has_tensor(v) for v in obj])
  119. return seen_ids[obj_id]
  120. elif istype(obj, dict):
  121. # Some packages like pytest can be updated during runtime. So, make a
  122. # copy of values to avoid issues like "RuntimeError: dictionary
  123. # changed size during iteration"
  124. values = list(obj.values())
  125. seen_ids[obj_id] = any([has_tensor(v) for v in values])
  126. return seen_ids[obj_id]
  127. elif istype(obj, (str, int, float, type(None), bool)):
  128. seen_ids[obj_id] = False
  129. return seen_ids[obj_id]
  130. elif is_namedtuple(obj):
  131. seen_ids[obj_id] = any([has_tensor(getattr(obj, v)) for v in obj._fields])
  132. return seen_ids[obj_id]
  133. else:
  134. # if config.debug:
  135. # print(
  136. # f"Assuming that object of type {type(obj)} does not have a tensor"
  137. # )
  138. return False
  139. # Check if the passed arguments are of type Tensor
  140. for value in frame.f_locals.values():
  141. if has_tensor(value):
  142. return True
  143. log.debug(
  144. f"skipping because no torch.* {frame.f_code.co_name} \
  145. {frame.f_code.co_filename} {frame.f_code.co_firstlineno}"
  146. )
  147. return False
  148. def exception_handler(e, code, frame=None):
  149. record_filename = None
  150. if hasattr(e, "exec_record"):
  151. record_filename = gen_record_file_name(e, code)
  152. write_record_to_file(record_filename, e.exec_record)
  153. e.record_filename = record_filename
  154. augment_exc_message(e)
  155. # Only log the exception if we are going to suppress it
  156. # if aren't suppressing it, a higher level except block will handle it
  157. if config.suppress_errors:
  158. log.error(format_error_msg(e, code, record_filename, frame))
  159. def convert_frame_assert(
  160. compiler_fn: CompilerFn,
  161. one_graph: bool = True,
  162. export: bool = False,
  163. ):
  164. """Fully convert a frame into an FX graph"""
  165. init_logging()
  166. def _convert_frame_assert(frame: types.FrameType, cache_size: int, hooks: Hooks):
  167. increment_frame()
  168. code = frame.f_code
  169. input_codes.add(code)
  170. if code in output_codes:
  171. return None
  172. if (
  173. os.environ.get("TORCHDYNAMO_DEBUG_FUNCTION")
  174. and os.environ.get("TORCHDYNAMO_DEBUG_FUNCTION") != code.co_name
  175. ):
  176. return None
  177. if code.co_name == "<genexpr>" and code.co_filename.endswith(
  178. ("transformers/", "transformers/utils/")
  179. ):
  180. # not needed, but cleans up torchbench error stats
  181. return None
  182. if code.co_name == "__setattr__":
  183. # setattr could be tricky to handle generally,
  184. # but also not likely useful to compile- skip the whole frame
  185. return None
  186. # Check if the frame is generated by an exec builtin call
  187. # TODO - Running exec generated frame seems propagates f_globals to the
  188. # next frames.
  189. if code.co_name == "<module>" and code.co_filename == "<string>":
  190. return None
  191. if (
  192. code.co_name == "<lambda>"
  193. and code.co_filename == "<string>"
  194. and not bool(frame.f_builtins)
  195. ):
  196. # namedtuple subclass constructor. Empty builtins cause issue with
  197. # len keyword in LIST_LEN guard.
  198. return None
  199. if is_generator(code):
  200. unimplemented("generator")
  201. if cache_size >= config.cache_size_limit:
  202. def format_func_info(code):
  203. return f"'{code.co_name}' ({code.co_filename}:{code.co_firstlineno})"
  204. def format_guard_failures(code):
  205. # For the common case, it's sufficient to see just the most recent failure.
  206. # We could add a verbose mode if needed
  207. return f"{str(guard_failures[code][-1])}"
  208. assert code in guard_failures, "TODO(whc) any other recompile reasons?"
  209. log.warning(
  210. f"torch._dynamo hit config.cache_size_limit ({config.cache_size_limit})\n"
  211. + f" function: {format_func_info(code)}\n"
  212. + f" reasons: {format_guard_failures(code)}\n"
  213. + f"to diagnose recompilation issues, see {troubleshooting_url}."
  214. )
  215. unimplemented("cache_size_limit reached")
  216. if not has_tensor_in_frame(frame):
  217. return None
  218. global initial_grad_state
  219. initial_grad_state = torch.is_grad_enabled()
  220. return _compile(
  221. frame.f_code,
  222. frame.f_globals,
  223. frame.f_locals,
  224. frame.f_builtins,
  225. compiler_fn,
  226. one_graph,
  227. export,
  228. hooks,
  229. frame,
  230. )
  231. _convert_frame_assert._torchdynamo_orig_callable = compiler_fn # type: ignore[attr-defined]
  232. return wrap_convert_context(_convert_frame_assert)
  233. @dynamo_timed(phase_name="entire_frame_compile")
  234. def _compile(
  235. code: types.CodeType,
  236. globals: Dict[str, object],
  237. locals: Dict[str, object],
  238. builtins: Dict[str, object],
  239. compiler_fn: CompilerFn,
  240. one_graph: bool,
  241. export: bool,
  242. hooks: Hooks,
  243. frame: Optional[types.FrameType] = None,
  244. ) -> Optional[GuardedCode]:
  245. output: Optional[OutputGraph] = None
  246. # This is shared across restarts
  247. mutated_closure_cell_contents: Set[str] = set()
  248. # from .utils import print_once; print_once(code.co_filename)
  249. def transform(instructions, code_options):
  250. nonlocal output
  251. tracer = InstructionTranslator(
  252. instructions,
  253. code,
  254. locals,
  255. globals,
  256. builtins,
  257. code_options,
  258. compiler_fn,
  259. one_graph,
  260. export,
  261. mutated_closure_cell_contents,
  262. )
  264. output = tracer.output
  265. assert output is not None
  266. assert output.output_instructions
  267. instructions[:] = output.output_instructions
  268. code_options.update(output.code_options)
  269. if config.dead_code_elimination:
  270. instructions[:] = remove_pointless_jumps(remove_dead_code(instructions))
  271. try:
  272. for attempt in itertools.count():
  273. try:
  274. out_code = transform_code_object(code, transform)
  275. orig_code_map[out_code] = code
  276. break
  277. except exc.RestartAnalysis:
  278. log.debug("Restarting analysis ...")
  279. if attempt > 100:
  280. unimplemented("100+ RestartAnalysis() calls")
  281. except exc.SkipFrame as e:
  282. log.debug(
  283. f"Skipping frame {e} {code.co_name} \
  284. {code.co_filename} {code.co_firstlineno}"
  285. )
  286. if one_graph:
  287. log.debug("No graph captured with one_graph=True")
  288. return None
  289. output_codes.add(out_code)
  290. if config.output_code:
  292. format_bytecode(
  294. code.co_name,
  295. code.co_filename,
  296. code.co_firstlineno,
  297. code,
  298. ),
  299. )
  301. format_bytecode(
  303. code.co_name,
  304. code.co_filename,
  305. code.co_firstlineno,
  306. out_code,
  307. ),
  308. )
  309. assert output is not None
  310. assert output.guards is not None
  311. CleanupManager.instance[out_code] = output.cleanups
  312. check_fn = CheckFunctionManager(
  313. output,
  314. locals,
  315. globals,
  316. hooks.guard_fail_fn if hooks else None,
  317. )
  318. guarded_code = GuardedCode(out_code, check_fn.check_fn)
  319. if config.output_code:
  320. guard_str = "GUARDS:\n"
  321. guard_str += "\n".join(
  322. [f" - {str(guard)}" for guard in sorted(output.guards)]
  323. )
  325. if hooks.guard_export_fn is not None:
  326. hooks.guard_export_fn(output.guards)
  327. return guarded_code
  328. except (
  329. Unsupported,
  330. TorchRuntimeError,
  331. BackendCompilerFailed,
  332. AssertionError,
  333. ) as e:
  334. exception_handler(e, code, frame)
  335. raise
  336. except Exception as e:
  337. exception_handler(e, code, frame)
  338. raise InternalTorchDynamoError() from e
  339. def convert_frame(compiler_fn: CompilerFn, hooks: Hooks):
  340. """Try to convert a frame into an FX graph, if error leave frame unmodified"""
  341. inner_convert = convert_frame_assert(compiler_fn, one_graph=False)
  342. def _convert_frame(frame: types.FrameType, cache_size: int, hooks: Hooks):
  343. counters["frames"]["total"] += 1
  344. try:
  345. result = inner_convert(frame, cache_size, hooks)
  346. counters["frames"]["ok"] += 1
  347. return result
  348. except (NotImplementedError, Unsupported):
  349."converting frame raised unsupported, leaving it unconverted")
  350. except Exception:
  351. if not config.suppress_errors:
  352. raise
  353."converting frame raised error, suppressing error")
  354. return None
  355. _convert_frame._torchdynamo_orig_callable = compiler_fn # type: ignore[attr-defined]
  356. return _convert_frame
  357. # TODO mlazos: add support for same args, or record them
  358. def replay(filename):
  359. from .backends.debugging import eager
  360. original_replay_val = config.replay_record_enabled
  361. config.replay_record_enabled = False
  362. init_logging()
  363. with open(filename, "rb") as in_file:
  364. record = ExecutionRecord.load(in_file)
  365. record.globals = {
  366. k: v for k, v in itertools.chain(record.globals.items(), globals().items())
  367. }
  368. try:
  369. _compile(
  370. record.code,
  371. record.globals,
  372. record.locals,
  373. record.builtins,
  374. compiler_fn=eager,
  375. one_graph=False,
  376. export=False,
  377. hooks=Hooks(),
  378. frame=None,
  379. )
  380. except Exception:
  381. pass
  382. finally:
  383. config.replay_record_enabled = original_replay_val