# debug.py

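# Debugging and tracing utilities for TorchInductor: graphviz rendering of
# scheduler nodes, AOT Autograd log capture, and the DebugContext /
# DebugFormatter machinery that writes trace artifacts when
# TORCH_COMPILE_DEBUG / config.trace is enabled.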
import collections
import contextlib
import cProfile
import functools
import itertools
import logging
import os.path
import pstats
import shutil
import subprocess
import sys
from typing import Any, List
from unittest.mock import patch

from functorch.compile import (
    config as functorch_config,
    draw_graph,
    get_aot_graph_name,
    get_graph_being_compiled,
)

import torch
from torch import fx as fx
from torch._dynamo import config as dynamo_config
from torch._dynamo.debug_utils import save_graph_repro, wrap_compiler_debug
from torch._dynamo.utils import get_debug_dir, init_logging
from torch.fx.graph_module import GraphModule
from torch.fx.passes.shape_prop import TensorMetadata
from torch.fx.passes.tools_common import legalize_graph

from . import config, ir  # noqa: F811, this is needed
from .scheduler import (
    BaseSchedulerNode,
    FusedSchedulerNode,
    NopKernelSchedulerNode,
    OutputNode,
    SchedulerNode,
)
from .virtualized import V

log = logging.getLogger(__name__)


@functools.lru_cache(None)
def has_dot() -> bool:
    try:
        subprocess.check_output(["which", "dot"], stderr=subprocess.PIPE)
        return True
    except subprocess.SubprocessError:
        return False


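# draw_buffers() is reached from DebugFormatter.graph_diagram() below when
# config.trace.graph_diagram is enabled; it renders the scheduler's fusion
# groups as an SVG via functorch's draw_graph().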
def draw_buffers(nodes, print_graph=False, fname=None):
    """
    Draw a graph in fname.svg.
    nodes is a list of SchedulerNode objects.
    """
    if not has_dot():
        log.warning("draw_buffers() requires `graphviz` package")
        return

    if fname is None:
        fname = get_graph_being_compiled()

    graph = create_fx_from_snodes(nodes)

    for node in graph.nodes:
        if "fusion_meta" not in node.meta:
            continue
        group = node.meta["fusion_meta"].group
        if isinstance(group, tuple):
            group = group[1]

        # gather meta data
        dtype = None
        if isinstance(node, ir.ComputedBuffer):
            dtype = node.data.dtype

        metadata = TensorMetadata(group, dtype, None, None, None, None, None)
        node.meta["tensor_meta"] = metadata

    if print_graph:
        print(graph)

    gm = GraphModule({}, graph)
    legalize_graph(gm)
    gm.graph.lint()
    draw_graph(gm, fname, clear_meta=False)


def create_fx_from_snodes(snodes: List[BaseSchedulerNode]) -> fx.Graph:
    """
    Creates a FX Graph from a list of SchedulerNode objects.
    """

    def get_fake_func(name):
        def func1(*args):
            return 0

        func1.__name__ = name
        return func1

    FusionMeta = collections.namedtuple("FusionMeta", ["group", "snodes", "type"])

    func_dict = {
        s: get_fake_func(s) for s in ["extern", "nop", "compute", "fused", "template"]
    }
    buf_to_fx_node = {}
    graph = torch.fx.Graph()
    first_node = None

    outputs = []
    group: Any = None
    # create call_function node for each Buffer and Kernel
    for snode in snodes:
        if snode.is_extern():
            node_type = "extern"
            group = node_type
        elif snode.is_template():
            node_type = "template"
            group = node_type
        elif isinstance(snode, NopKernelSchedulerNode):
            node_type = "nop"
            group = node_type
        elif isinstance(snode, SchedulerNode):
            node_type = "compute"
            group = snode.group
        elif isinstance(snode, FusedSchedulerNode):
            node_type = "fused"
            group = snode.group
        else:
            raise RuntimeError("Unknown node type")
        node_func = func_dict[node_type]
        fx_node = graph.call_function(node_func, args=(), kwargs=None)

        def in_output(snode):
            if isinstance(snode, FusedSchedulerNode):
                return any(in_output(x) for x in snode.snodes)
            return any(isinstance(user.node, OutputNode) for user in snode.users)

        if in_output(snode):
            outputs.append(fx_node)
        name = snode.get_name()
        fx_node.name = name

        fx_node.meta["fusion_meta"] = FusionMeta(group, [snode], node_type)

        if isinstance(snode, FusedSchedulerNode):
            for x in snode.snodes:
                buf_to_fx_node[x.get_name()] = fx_node
        buf_to_fx_node[name] = fx_node

        if first_node is None:
            first_node = fx_node

    # create edges between nodes
    for snode in snodes:
        name = snode.get_name()
        deps = snode.read_writes.reads

        fx_node = buf_to_fx_node[name]
        new_args = []
        for dep in deps:
            if dep.name in buf_to_fx_node:
                dep_node = buf_to_fx_node[dep.name]
            else:
                with graph.inserting_before(first_node):
                    dep_node = graph.placeholder(dep.name)
                    buf_to_fx_node[dep.name] = dep_node
            new_args.append(dep_node)

        fx_node.args = tuple(new_args)

    graph.output(outputs[0] if len(outputs) == 1 else tuple(outputs))
    return graph


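# Context manager that raises the functorch log level to DEBUG, mirrors graphs
# to stdout when the functorch debug_graphs/debug_joint flags are set, and,
# under TORCH_COMPILE_DEBUG, also writes a per-graph aot_<name>_debug.log file
# under <debug_dir>/aot_torchinductor.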
@contextlib.contextmanager
def enable_aot_logging():
    compile_debug = bool(os.environ.get("TORCH_COMPILE_DEBUG", False))
    debug_graphs = functorch_config.debug_graphs
    debug_joint_graphs = functorch_config.debug_joint

    import torch._functorch.aot_autograd

    log = logging.getLogger(torch._functorch.aot_autograd.__name__)

    stack = contextlib.ExitStack()
    stack.enter_context(patch("functorch.compile.config.log_level", logging.DEBUG))

    # if user has specified they want to see graphs via either env var
    # add stream to std out
    if debug_graphs or debug_joint_graphs:
        stdout_handler = logging.StreamHandler(sys.stdout)
        log.addHandler(stdout_handler)
        stack.callback(lambda: log.removeHandler(stdout_handler))

    if not compile_debug:
        try:
            yield
        finally:
            stack.close()
        return

    # Enable all graphs to be logged to a file by setting the flags to True
    # and the log level of the file logger to DEBUG
    stack.enter_context(patch("functorch.compile.config.debug_partitioner", True))
    stack.enter_context(patch("functorch.compile.config.debug_graphs", True))
    stack.enter_context(patch("functorch.compile.config.debug_joint", True))

    path = os.path.join(get_debug_dir(), "aot_torchinductor")
    if not os.path.exists(path):
        os.makedirs(path)

    fh = logging.FileHandler(
        os.path.join(
            path,
            f"aot_{get_aot_graph_name()}_debug.log",
        )
    )
    fh.setLevel(logging.DEBUG)
    fh.setFormatter(
        logging.Formatter("[%(filename)s:%(lineno)d %(levelname)s] %(message)s")
    )
    log.addHandler(fh)
    try:
        yield
    finally:
        log.removeHandler(fh)
        stack.close()


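# DebugContext manages the per-compilation debug directory: it can capture
# torch._inductor debug/info logs and a cProfile trace, tars and uploads the
# directory on exit when config.trace.upload_tar is set, and dispatches the
# DebugFormatter hooks below (fx_graph, ir_pre_fusion, output_code, ...) via
# __getattr__ for every enabled config.trace flag.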
class DebugContext:
    _counter = itertools.count()

    @staticmethod
    def wrap(fn):
        @functools.wraps(fn)
        def inner(*args, **kwargs):
            with DebugContext():
                return fn(*args, **kwargs)

        return wrap_compiler_debug(inner, compiler_name="inductor")

    @staticmethod
    def create_debug_dir(folder_name):
        for n in DebugContext._counter:
            dirname = os.path.join(
                get_debug_dir(),
                "aot_torchinductor",
                f"{folder_name}.{n}",
            )
            if not os.path.exists(dirname):
                os.makedirs(dirname)
                return dirname

    def __init__(self):
        self._prof = None
        self._path = None
        self._stack = contextlib.ExitStack()

    def rename(self, new_path: str):
        if not self._path:
            return
        assert new_path.endswith(".debug"), new_path
        if os.path.exists(new_path):
            shutil.rmtree(new_path)
        try:
            os.rename(self._path, new_path)
            self._path = new_path
        except OSError:
            # other OSes might have trouble renaming a dir with open files
            pass

    def fopen(self, filename):
        assert self._path
        return open(os.path.join(self._path, filename), "w")

    def filename(self, suffix):
        return os.path.join(self._path, suffix)

    def upload_tar(self):
        if config.trace.upload_tar is not None:
            import tarfile

            assert self._path
            tar_file = os.path.join(
                self._path, f"{os.path.basename(self._path)}.tar.gz"
            )
            with tarfile.open(tar_file, "w:gz") as tar:
                tar.add(self._path, arcname=os.path.basename(self._path))
            config.trace.upload_tar(tar_file)

    def __enter__(self):
        log = logging.getLogger("torch._inductor")
        if not log.handlers:
            init_logging()

        if config.debug:

            def reset_log_level(level):
                dynamo_config.log_level = level

            self._stack.callback(reset_log_level, dynamo_config.log_level)
            dynamo_config.log_level = logging.DEBUG

        self._stack.enter_context(V.set_debug_handler(self))

        if not config.trace.enabled:
            return

        self._path = self.create_debug_dir(get_aot_graph_name())

        if config.trace.debug_log:
            self._setup_log_capture("debug.log", logging.DEBUG)
        if config.trace.info_log:
            self._setup_log_capture("info.log", logging.INFO)
        if config.trace.compile_profile:
            self._prof = cProfile.Profile()
            self._prof.enable()

    def _setup_log_capture(self, filename, level):
        log = logging.getLogger("torch._inductor")
        fd = self._stack.enter_context(self.fopen(filename))
        ch = logging.StreamHandler(fd)
        ch.setLevel(level)
        ch.setFormatter(
            logging.Formatter("[%(filename)s:%(lineno)d %(levelname)s] %(message)s")
        )
        log.addHandler(ch)
        log.setLevel(min(log.level, level))
        self._stack.callback(log.removeHandler, ch)

    def __exit__(self, exc_type, exc_val, exc_tb):
        if self._prof:
            self._prof.disable()
            self._save_profile_data()

        if self._path:
            self.upload_tar()
            log.warning("%s debug trace: %s", get_graph_being_compiled(), self._path)
        self._stack.close()

    def _save_profile_data(self):
        self._prof.dump_stats(self.filename("compile.prof"))
        with self.fopen("compile.stats") as fd:
            stats = pstats.Stats(self._prof, stream=fd)
            stats.strip_dirs()
            stats.sort_stats("cumtime")
            stats.print_stats(100)
            stats.sort_stats("tottime")
            stats.print_stats(100)

    def __getattr__(self, name):
        if config.trace.enabled and getattr(config.trace, name):
            try:
                return getattr(DebugFormatter(self), name)
            except Exception:
                log.warning("Ignoring exception in debug code", exc_info=True)
        else:

            def ignored(*args, **kwargs):
                pass

            return ignored


SchedulerNodeList = List[Any]


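# DebugFormatter writes the individual trace artifacts (fx_graph_runnable.py,
# fx_graph_readable.py, ir_pre_fusion.txt, graph_diagram.svg, output_code.py,
# ...) into the DebugContext's directory via its fopen/filename helpers.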
class DebugFormatter:
    def __init__(self, handler):
        self.fopen = handler.fopen
        self.filename = handler.filename
        self.handler = handler

    def fx_graph(self, gm: torch.fx.GraphModule, inputs: List[torch.Tensor]):
        with self.fopen("fx_graph_runnable.py") as fd:
            save_graph_repro(fd, gm, inputs, "inductor")

        with self.fopen("fx_graph_readable.py") as fd:
            fd.write(gm.print_readable(print_output=False))

    def fx_graph_transformed(
        self, gm: torch.fx.GraphModule, inputs: List[torch.Tensor]
    ):
        with self.fopen("fx_graph_transformed.py") as fd:
            fd.write(gm.print_readable(print_output=False))

    def ir_pre_fusion(self, nodes: SchedulerNodeList):
        self._write_ir("ir_pre_fusion.txt", nodes)

    def ir_post_fusion(self, nodes: SchedulerNodeList):
        self._write_ir("ir_post_fusion.txt", nodes)

    def _write_ir(self, filename: str, nodes: SchedulerNodeList):
        with self.fopen(filename) as fd:
            for node in nodes:
                fd.write(node.debug_str())
                fd.write("\n\n\n")

    def graph_diagram(self, nodes: SchedulerNodeList):
        draw_buffers(nodes, fname=self.filename("graph_diagram.svg"))

    def output_code(self, filename):
        shutil.copy(filename, self.filename("output_code.py"))