import collections
import contextlib
import cProfile
import functools
import itertools
import logging
import os.path
import pstats
import shutil
import subprocess
import sys
from typing import Any, List
from unittest.mock import patch

from functorch.compile import (
    config as functorch_config,
    draw_graph,
    get_aot_graph_name,
    get_graph_being_compiled,
)

import torch
from torch import fx as fx
from torch._dynamo import config as dynamo_config
from torch._dynamo.debug_utils import save_graph_repro, wrap_compiler_debug
from torch._dynamo.utils import get_debug_dir, init_logging
from torch.fx.graph_module import GraphModule
from torch.fx.passes.shape_prop import TensorMetadata
from torch.fx.passes.tools_common import legalize_graph

from . import config, ir  # noqa: F811, this is needed
from .scheduler import (
    BaseSchedulerNode,
    FusedSchedulerNode,
    NopKernelSchedulerNode,
    OutputNode,
    SchedulerNode,
)
from .virtualized import V

log = logging.getLogger(__name__)
 

@functools.lru_cache(None)
def has_dot() -> bool:
    """Return True if the graphviz `dot` binary is available on PATH."""
    try:
        subprocess.check_output(["which", "dot"], stderr=subprocess.PIPE)
        return True
    except subprocess.SubprocessError:
        return False
 

def draw_buffers(nodes, print_graph=False, fname=None):
    """
    Draw a graph in fname.svg.
    nodes is a list of SchedulerNode objects.
    """
    if not has_dot():
        log.warning("draw_buffers() requires `graphviz` package")
        return

    if fname is None:
        fname = get_graph_being_compiled()

    graph = create_fx_from_snodes(nodes)

    for node in graph.nodes:
        if "fusion_meta" not in node.meta:
            continue
        group = node.meta["fusion_meta"].group
        if isinstance(group, tuple):
            group = group[1]

        # gather meta data
        dtype = None
        if isinstance(node, ir.ComputedBuffer):
            dtype = node.data.dtype

        metadata = TensorMetadata(group, dtype, None, None, None, None, None)
        node.meta["tensor_meta"] = metadata

    if print_graph:
        print(graph)

    gm = GraphModule({}, graph)
    legalize_graph(gm)
    gm.graph.lint()
    draw_graph(gm, fname, clear_meta=False)
 

def create_fx_from_snodes(snodes: List[BaseSchedulerNode]) -> fx.Graph:
    """
    Creates a FX Graph from a list of SchedulerNode objects.
    """

    def get_fake_func(name):
        def func1(*args):
            return 0

        func1.__name__ = name
        return func1

    FusionMeta = collections.namedtuple("FusionMeta", ["group", "snodes", "type"])

    func_dict = {
        s: get_fake_func(s)
        for s in ["extern", "nop", "compute", "fused", "template"]
    }
    buf_to_fx_node = {}
    graph = torch.fx.Graph()
    first_node = None

    outputs = []
    group: Any = None
    # create call_function node for each Buffer and Kernel
    for snode in snodes:
        if snode.is_extern():
            node_type = "extern"
            group = node_type
        elif snode.is_template():
            node_type = "template"
            group = node_type
        elif isinstance(snode, NopKernelSchedulerNode):
            node_type = "nop"
            group = node_type
        elif isinstance(snode, SchedulerNode):
            node_type = "compute"
            group = snode.group
        elif isinstance(snode, FusedSchedulerNode):
            node_type = "fused"
            group = snode.group
        else:
            raise RuntimeError("Unknown node type")
        node_func = func_dict[node_type]
        fx_node = graph.call_function(node_func, args=(), kwargs=None)

        def in_output(snode):
            if isinstance(snode, FusedSchedulerNode):
                return any(in_output(x) for x in snode.snodes)
            return any(isinstance(user.node, OutputNode) for user in snode.users)

        if in_output(snode):
            outputs.append(fx_node)
        name = snode.get_name()
        fx_node.name = name

        fx_node.meta["fusion_meta"] = FusionMeta(group, [snode], node_type)

        if isinstance(snode, FusedSchedulerNode):
            for x in snode.snodes:
                buf_to_fx_node[x.get_name()] = fx_node
        buf_to_fx_node[name] = fx_node

        if first_node is None:
            first_node = fx_node

    # create edges between nodes
    for snode in snodes:
        name = snode.get_name()
        deps = snode.read_writes.reads

        fx_node = buf_to_fx_node[name]
        new_args = []
        for dep in deps:
            if dep.name in buf_to_fx_node:
                dep_node = buf_to_fx_node[dep.name]
            else:
                with graph.inserting_before(first_node):
                    dep_node = graph.placeholder(dep.name)
                    buf_to_fx_node[dep.name] = dep_node
            new_args.append(dep_node)

        fx_node.args = tuple(new_args)

    graph.output(outputs[0] if len(outputs) == 1 else tuple(outputs))
    return graph
 

@contextlib.contextmanager
def enable_aot_logging():
    compile_debug = bool(os.environ.get("TORCH_COMPILE_DEBUG", False))
    debug_graphs = functorch_config.debug_graphs
    debug_joint_graphs = functorch_config.debug_joint

    import torch._functorch.aot_autograd

    log = logging.getLogger(torch._functorch.aot_autograd.__name__)

    stack = contextlib.ExitStack()
    stack.enter_context(patch("functorch.compile.config.log_level", logging.DEBUG))

    # if the user has asked to see graphs via either debug flag,
    # add a stream handler on stdout
    if debug_graphs or debug_joint_graphs:
        stdout_handler = logging.StreamHandler(sys.stdout)
        log.addHandler(stdout_handler)
        stack.callback(lambda: log.removeHandler(stdout_handler))

    if not compile_debug:
        try:
            yield
        finally:
            stack.close()
        return

    # Enable all graphs to be logged to a file by setting the flags to True
    # and the log level of the file logger to DEBUG
    stack.enter_context(patch("functorch.compile.config.debug_partitioner", True))
    stack.enter_context(patch("functorch.compile.config.debug_graphs", True))
    stack.enter_context(patch("functorch.compile.config.debug_joint", True))
    path = os.path.join(get_debug_dir(), "aot_torchinductor")
    if not os.path.exists(path):
        os.makedirs(path)

    fh = logging.FileHandler(
        os.path.join(
            path,
            f"aot_{get_aot_graph_name()}_debug.log",
        )
    )
    fh.setLevel(logging.DEBUG)
    fh.setFormatter(
        logging.Formatter("[%(filename)s:%(lineno)d %(levelname)s] %(message)s")
    )
    log.addHandler(fh)
    try:
        yield
    finally:
        log.removeHandler(fh)
        stack.close()
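
# A minimal sketch of how this context manager is typically activated (the
# script name below is hypothetical; only TORCH_COMPILE_DEBUG is read above):
#
#   TORCH_COMPILE_DEBUG=1 python my_script.py
#
# With the variable set, the functorch debug flags are patched on and the AOT
# graphs are appended to aot_<graph_name>_debug.log under
# <debug_dir>/aot_torchinductor/.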
 

class DebugContext:
    """Collects per-compilation debug artifacts (logs, profiles, IR dumps) into one directory."""

    _counter = itertools.count()

    @staticmethod
    def wrap(fn):
        @functools.wraps(fn)
        def inner(*args, **kwargs):
            with DebugContext():
                return fn(*args, **kwargs)

        return wrap_compiler_debug(inner, compiler_name="inductor")

    @staticmethod
    def create_debug_dir(folder_name):
        for n in DebugContext._counter:
            dirname = os.path.join(
                get_debug_dir(),
                "aot_torchinductor",
                f"{folder_name}.{n}",
            )
            if not os.path.exists(dirname):
                os.makedirs(dirname)
                return dirname
 
    def __init__(self):
        self._prof = None
        self._path = None
        self._stack = contextlib.ExitStack()

    def rename(self, new_path: str):
        if not self._path:
            return
        assert new_path.endswith(".debug"), new_path
        if os.path.exists(new_path):
            shutil.rmtree(new_path)
        try:
            os.rename(self._path, new_path)
            self._path = new_path
        except OSError:
            # other OSes might have trouble renaming a dir with open files
            pass

    def fopen(self, filename):
        assert self._path
        return open(os.path.join(self._path, filename), "w")

    def filename(self, suffix):
        return os.path.join(self._path, suffix)

    def upload_tar(self):
        if config.trace.upload_tar is not None:
            import tarfile

            assert self._path
            tar_file = os.path.join(
                self._path, f"{os.path.basename(self._path)}.tar.gz"
            )
            with tarfile.open(tar_file, "w:gz") as tar:
                tar.add(self._path, arcname=os.path.basename(self._path))
            config.trace.upload_tar(tar_file)
 
    def __enter__(self):
        log = logging.getLogger("torch._inductor")
        if not log.handlers:
            init_logging()

        if config.debug:

            def reset_log_level(level):
                dynamo_config.log_level = level

            self._stack.callback(reset_log_level, dynamo_config.log_level)
            dynamo_config.log_level = logging.DEBUG

        self._stack.enter_context(V.set_debug_handler(self))

        if not config.trace.enabled:
            return

        self._path = self.create_debug_dir(get_aot_graph_name())

        if config.trace.debug_log:
            self._setup_log_capture("debug.log", logging.DEBUG)
        if config.trace.info_log:
            self._setup_log_capture("info.log", logging.INFO)
        if config.trace.compile_profile:
            self._prof = cProfile.Profile()
            self._prof.enable()

    def _setup_log_capture(self, filename, level):
        log = logging.getLogger("torch._inductor")
        fd = self._stack.enter_context(self.fopen(filename))
        ch = logging.StreamHandler(fd)
        ch.setLevel(level)
        ch.setFormatter(
            logging.Formatter("[%(filename)s:%(lineno)d %(levelname)s] %(message)s")
        )
        log.addHandler(ch)
        log.setLevel(min(log.level, level))
        self._stack.callback(log.removeHandler, ch)
 
    def __exit__(self, exc_type, exc_val, exc_tb):
        if self._prof:
            self._prof.disable()
            self._save_profile_data()

        if self._path:
            self.upload_tar()
            log.warning("%s debug trace: %s", get_graph_being_compiled(), self._path)
        self._stack.close()

    def _save_profile_data(self):
        self._prof.dump_stats(self.filename("compile.prof"))
        with self.fopen("compile.stats") as fd:
            stats = pstats.Stats(self._prof, stream=fd)
            stats.strip_dirs()
            stats.sort_stats("cumtime")
            stats.print_stats(100)
            stats.sort_stats("tottime")
            stats.print_stats(100)

    def __getattr__(self, name):
        if config.trace.enabled and getattr(config.trace, name):
            try:
                return getattr(DebugFormatter(self), name)
            except Exception:
                log.warning("Ignoring exception in debug code", exc_info=True)
        else:

            def ignored(*args, **kwargs):
                pass

            return ignored
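
# Usage sketch for DebugContext.wrap (the wrapped function name is an
# assumption; any inductor compile entry point can be decorated this way):
#
#   @DebugContext.wrap
#   def compile_fx_inner(gm, example_inputs, **kwargs):
#       ...
#
# Each compiled graph then gets its own numbered directory under
# <debug_dir>/aot_torchinductor/, holding whichever config.trace artifacts
# are enabled (debug.log, info.log, compile.prof, IR dumps, ...).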
 

SchedulerNodeList = List[Any]
 

class DebugFormatter:
    """Writes the individual config.trace artifacts on behalf of a DebugContext."""

    def __init__(self, handler):
        self.fopen = handler.fopen
        self.filename = handler.filename
        self.handler = handler

    def fx_graph(self, gm: torch.fx.GraphModule, inputs: List[torch.Tensor]):
        with self.fopen("fx_graph_runnable.py") as fd:
            save_graph_repro(fd, gm, inputs, "inductor")

        with self.fopen("fx_graph_readable.py") as fd:
            fd.write(gm.print_readable(print_output=False))

    def fx_graph_transformed(
        self, gm: torch.fx.GraphModule, inputs: List[torch.Tensor]
    ):
        with self.fopen("fx_graph_transformed.py") as fd:
            fd.write(gm.print_readable(print_output=False))

    def ir_pre_fusion(self, nodes: SchedulerNodeList):
        self._write_ir("ir_pre_fusion.txt", nodes)

    def ir_post_fusion(self, nodes: SchedulerNodeList):
        self._write_ir("ir_post_fusion.txt", nodes)

    def _write_ir(self, filename: str, nodes: SchedulerNodeList):
        with self.fopen(filename) as fd:
            for node in nodes:
                fd.write(node.debug_str())
                fd.write("\n\n\n")

    def graph_diagram(self, nodes: SchedulerNodeList):
        draw_buffers(nodes, fname=self.filename("graph_diagram.svg"))

    def output_code(self, filename):
        shutil.copy(filename, self.filename("output_code.py"))
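
# Illustration of how the DebugFormatter hooks are reached (not verbatim call
# sites; the method names mirror the config.trace.* switches checked in
# DebugContext.__getattr__, and V.debug refers to the handler installed via
# V.set_debug_handler above):
#
#   with DebugContext():
#       V.debug.fx_graph(gm, example_inputs)   # fx_graph_runnable.py / fx_graph_readable.py
#       V.debug.ir_pre_fusion(snodes)          # ir_pre_fusion.txt
#       V.debug.ir_post_fusion(snodes)         # ir_post_fusion.txt
#       V.debug.graph_diagram(snodes)          # graph_diagram.svg (requires `dot`)
#       V.debug.output_code(code_filename)     # output_code.py
#
# Every call is a silent no-op unless config.trace.enabled and the matching
# flag are both set.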
 
 