123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404 |
- import collections
- import contextlib
- import cProfile
- import functools
- import itertools
- import logging
- import os.path
- import pstats
- import shutil
- import subprocess
- import sys
- from typing import Any, List
- from unittest.mock import patch
- from functorch.compile import (
- config as functorch_config,
- draw_graph,
- get_aot_graph_name,
- get_graph_being_compiled,
- )
- import torch
- from torch import fx as fx
- from torch._dynamo import config as dynamo_config
- from torch._dynamo.debug_utils import save_graph_repro, wrap_compiler_debug
- from torch._dynamo.utils import get_debug_dir, init_logging
- from torch.fx.graph_module import GraphModule
- from torch.fx.passes.shape_prop import TensorMetadata
- from torch.fx.passes.tools_common import legalize_graph
- from . import config, ir # noqa: F811, this is needed
- from .scheduler import (
- BaseSchedulerNode,
- FusedSchedulerNode,
- NopKernelSchedulerNode,
- OutputNode,
- SchedulerNode,
- )
- from .virtualized import V
- log = logging.getLogger(__name__)
@functools.lru_cache(None)
def has_dot():
    """
    Report whether the Graphviz ``dot`` executable can be found on PATH.

    The lookup uses ``shutil.which`` instead of spawning the external
    ``which`` program.  Spawning ``which`` raises ``FileNotFoundError`` (which
    is not a ``subprocess.SubprocessError``) on systems where that helper
    itself is missing, so the check would crash instead of answering False.

    Returns:
        True if ``dot`` resolves to an executable, False otherwise.  The
        result is cached, so PATH is only probed on the first call.
    """
    return shutil.which("dot") is not None
def draw_buffers(nodes, print_graph=False, fname=None):
    """
    Draw a graph in fname.svg.

    Args:
        nodes: list of SchedulerNode objects; converted to an FX graph with
            create_fx_from_snodes().
        print_graph: when true, the FX graph is also printed.
        fname: name handed to draw_graph(); when None the value of
            get_graph_being_compiled() is used instead.

    If has_dot() reports that graphviz is unavailable, a warning is logged and
    nothing is drawn.
    """
    if not has_dot():
        log.warning("draw_buffers() requires `graphviz` package")
        return

    output_name = get_graph_being_compiled() if fname is None else fname
    fx_graph = create_fx_from_snodes(nodes)

    for fx_node in fx_graph.nodes:
        # Only nodes tagged by create_fx_from_snodes carry fusion metadata.
        if "fusion_meta" in fx_node.meta:
            fusion_group = fx_node.meta["fusion_meta"].group
            if isinstance(fusion_group, tuple):
                fusion_group = fusion_group[1]

            # gather meta data
            # NOTE(review): fx_node comes from an fx.Graph, so this isinstance
            # test presumably never matches and the dtype stays None -- confirm.
            buffer_dtype = (
                fx_node.data.dtype if isinstance(fx_node, ir.ComputedBuffer) else None
            )

            # The group is passed as the first TensorMetadata field.
            fx_node.meta["tensor_meta"] = TensorMetadata(
                fusion_group, buffer_dtype, None, None, None, None, None
            )

    if print_graph:
        print(fx_graph)

    module = GraphModule({}, fx_graph)
    legalize_graph(module)
    module.graph.lint()
    draw_graph(module, output_name, clear_meta=False)
def create_fx_from_snodes(snodes: List[BaseSchedulerNode]) -> fx.Graph:
    """
    Creates a FX Graph from a list of SchedulerNode objects.

    Every scheduler node becomes one ``call_function`` FX node whose target is
    a placeholder function named after the node kind ("extern", "template",
    "nop", "compute" or "fused").  Each FX node is named after the scheduler
    node and carries a ``FusionMeta(group, snodes, type)`` tuple in
    ``meta["fusion_meta"]``.  Reads of buffers that are not produced by any
    node in ``snodes`` become graph placeholders.

    Args:
        snodes: scheduler nodes to convert.

    Returns:
        The populated ``torch.fx.Graph``.  Its output is the single FX node
        that feeds an OutputNode, or a tuple of such nodes otherwise.

    Raises:
        RuntimeError: if a node is none of the recognised scheduler node kinds.
    """

    def get_fake_func(name):
        # Stand-in call target; only its __name__ is meaningful.
        def func1(*args):
            return 0

        func1.__name__ = name
        return func1

    def in_output(snode):
        # A fused node counts as an output if any of its members does.
        if isinstance(snode, FusedSchedulerNode):
            return any(in_output(x) for x in snode.snodes)
        return any(isinstance(user.node, OutputNode) for user in snode.users)

    FusionMeta = collections.namedtuple("FusionMeta", ["group", "snodes", "type"])

    # One fake target per node kind assigned below.  "template" must be listed
    # here as well, otherwise template nodes fail the func_dict lookup.
    func_dict = {
        s: get_fake_func(s)
        for s in ["extern", "template", "nop", "compute", "fused"]
    }
    buf_to_fx_node = {}
    graph = torch.fx.Graph()
    first_node = None

    outputs = []
    group: Any = None
    # create call_function node for each Buffer and Kernel
    for snode in snodes:
        if snode.is_extern():
            node_type = "extern"
            group = node_type
        elif snode.is_template():
            node_type = "template"
            group = node_type
        elif isinstance(snode, NopKernelSchedulerNode):
            node_type = "nop"
            group = node_type
        elif isinstance(snode, SchedulerNode):
            node_type = "compute"
            group = snode.group
        elif isinstance(snode, FusedSchedulerNode):
            node_type = "fused"
            group = snode.group
        else:
            raise RuntimeError("Unknown node type")

        node_func = func_dict[node_type]
        fx_node = graph.call_function(node_func, args=(), kwargs=None)

        if in_output(snode):
            outputs.append(fx_node)

        name = snode.get_name()
        fx_node.name = name

        fx_node.meta["fusion_meta"] = FusionMeta(group, [snode], node_type)

        # Buffers inside a fused node all resolve to the fused FX node.
        if isinstance(snode, FusedSchedulerNode):
            for x in snode.snodes:
                buf_to_fx_node[x.get_name()] = fx_node
        buf_to_fx_node[name] = fx_node

        if first_node is None:
            first_node = fx_node

    # create edges between nodes
    for snode in snodes:
        name = snode.get_name()
        deps = snode.read_writes.reads

        fx_node = buf_to_fx_node[name]
        new_args = []
        for dep in deps:
            if dep.name in buf_to_fx_node:
                dep_node = buf_to_fx_node[dep.name]
            else:
                # Unknown producer: model it as a graph input placed ahead of
                # the first call_function node.
                with graph.inserting_before(first_node):
                    dep_node = graph.placeholder(dep.name)
                    buf_to_fx_node[dep.name] = dep_node
            new_args.append(dep_node)

        fx_node.args = tuple(new_args)

    graph.output(outputs[0] if len(outputs) == 1 else tuple(outputs))
    return graph
@contextlib.contextmanager
def enable_aot_logging():
    """
    Context manager that turns on AOT autograd debug logging.

    While active, ``functorch.compile.config.log_level`` is patched to DEBUG.
    If ``debug_graphs`` or ``debug_joint`` is already set in the functorch
    config, the AOT autograd logger additionally streams to stdout.

    When the ``TORCH_COMPILE_DEBUG`` environment variable is enabled,
    ``debug_partitioner``, ``debug_graphs`` and ``debug_joint`` are patched to
    True and the logger also writes to
    ``<debug dir>/aot_torchinductor/aot_<graph name>_debug.log``.

    ``TORCH_COMPILE_DEBUG`` counts as disabled when it is unset, empty, or one
    of "0", "false", "no", "off" (case-insensitive); any other value enables
    it.  A plain ``bool()`` of the raw string would treat "0" as enabled.

    Everything is registered on one ExitStack, so patches and handlers are
    undone (and the log file closed) on exit, including when setup itself
    fails part-way through.
    """
    env_value = os.environ.get("TORCH_COMPILE_DEBUG", "")
    compile_debug = env_value.strip().lower() not in ("", "0", "false", "no", "off")

    debug_graphs = functorch_config.debug_graphs
    debug_joint_graphs = functorch_config.debug_joint

    import torch._functorch.aot_autograd

    aot_log = logging.getLogger(torch._functorch.aot_autograd.__name__)

    with contextlib.ExitStack() as stack:
        stack.enter_context(
            patch("functorch.compile.config.log_level", logging.DEBUG)
        )

        # if user has specified they want to see graphs via either env var
        # add stream to std out
        if debug_graphs or debug_joint_graphs:
            stdout_handler = logging.StreamHandler(sys.stdout)
            aot_log.addHandler(stdout_handler)
            stack.callback(aot_log.removeHandler, stdout_handler)

        if compile_debug:
            # Enable all graphs to be logged to a file by setting the flags to
            # True and the log level of the file logger to DEBUG
            stack.enter_context(
                patch("functorch.compile.config.debug_partitioner", True)
            )
            stack.enter_context(patch("functorch.compile.config.debug_graphs", True))
            stack.enter_context(patch("functorch.compile.config.debug_joint", True))

            path = os.path.join(get_debug_dir(), "aot_torchinductor")
            # exist_ok avoids failing if the directory appears concurrently.
            os.makedirs(path, exist_ok=True)

            fh = logging.FileHandler(
                os.path.join(
                    path,
                    f"aot_{get_aot_graph_name()}_debug.log",
                )
            )
            # Registered before removeHandler so that, with LIFO unwinding, the
            # handler is detached first and its file closed afterwards.
            stack.callback(fh.close)
            fh.setLevel(logging.DEBUG)
            fh.setFormatter(
                logging.Formatter(
                    "[%(filename)s:%(lineno)d %(levelname)s] %(message)s"
                )
            )
            aot_log.addHandler(fh)
            stack.callback(aot_log.removeHandler, fh)

        yield
class DebugContext:
    """
    Context manager installed around an inductor compilation.

    On entry it registers itself as the debug handler on ``V`` and, when
    ``config.trace.enabled`` is set, creates a per-graph debug directory,
    optionally captures "torch._inductor" log output into files and starts a
    cProfile profiler.  On exit it writes profile data, optionally uploads a
    tarball of the directory and undoes everything registered on its
    ExitStack.

    Attribute lookups that do not resolve on the instance are forwarded to a
    DebugFormatter when the matching ``config.trace`` flag is set and resolve
    to a do-nothing callable otherwise (see ``__getattr__``).
    """

    # Shared across instances; supplies the numeric suffix of debug directories.
    _counter = itertools.count()

    @staticmethod
    def wrap(fn):
        """Return ``fn`` wrapped so each call runs inside a fresh DebugContext."""

        @functools.wraps(fn)
        def inner(*args, **kwargs):
            with DebugContext():
                return fn(*args, **kwargs)

        return wrap_compiler_debug(inner, compiler_name="inductor")

    @staticmethod
    def create_debug_dir(folder_name):
        """
        Create and return ``<debug dir>/aot_torchinductor/<folder_name>.<n>``.

        ``n`` is drawn from the shared counter until a name is found that does
        not exist yet.
        """
        for n in DebugContext._counter:
            dirname = os.path.join(
                get_debug_dir(),
                "aot_torchinductor",
                f"{folder_name}.{n}",
            )
            # Create first and react to the failure: checking for existence
            # beforehand leaves a window in which another creator can win.
            try:
                os.makedirs(dirname)
            except FileExistsError:
                continue
            return dirname

    def __init__(self):
        self._prof = None  # cProfile.Profile while compile profiling is active
        self._path = None  # debug directory; stays None when tracing is disabled
        self._stack = contextlib.ExitStack()  # cleanup actions run by __exit__

    def rename(self, new_path: str):
        """
        Move the debug directory to ``new_path`` (which must end in ".debug").

        Does nothing when no debug directory exists.  An existing ``new_path``
        is removed first.  A failing ``os.rename`` is ignored and the old path
        is kept.
        """
        if not self._path:
            return
        assert new_path.endswith(".debug"), new_path
        if os.path.exists(new_path):
            shutil.rmtree(new_path)
        try:
            os.rename(self._path, new_path)
            self._path = new_path
        except OSError:
            # other OS might have troubling renaming dir with open files
            pass

    def fopen(self, filename):
        """Open ``filename`` inside the debug directory for writing."""
        assert self._path
        return open(os.path.join(self._path, filename), "w")

    def filename(self, suffix):
        """Return the path of ``suffix`` inside the debug directory."""
        assert self._path
        return os.path.join(self._path, suffix)

    def upload_tar(self):
        """
        If ``config.trace.upload_tar`` is set, archive the debug directory as
        ``<dir>/<basename>.tar.gz`` and pass the archive path to that callable.
        """
        if config.trace.upload_tar is not None:
            import tarfile

            assert self._path
            tar_file = os.path.join(
                self._path, f"{os.path.basename(self._path)}.tar.gz"
            )
            with tarfile.open(tar_file, "w:gz") as tar:
                tar.add(self._path, arcname=os.path.basename(self._path))
            config.trace.upload_tar(tar_file)

    def __enter__(self):
        log = logging.getLogger("torch._inductor")
        if not log.handlers:
            init_logging()

        if config.debug:

            def reset_log_level(level):
                dynamo_config.log_level = level

            # Restore the previous dynamo log level when the context exits.
            self._stack.callback(reset_log_level, dynamo_config.log_level)
            dynamo_config.log_level = logging.DEBUG

        self._stack.enter_context(V.set_debug_handler(self))

        if not config.trace.enabled:
            return

        self._path = self.create_debug_dir(get_aot_graph_name())

        if config.trace.debug_log:
            self._setup_log_capture("debug.log", logging.DEBUG)
        if config.trace.info_log:
            self._setup_log_capture("info.log", logging.INFO)
        if config.trace.compile_profile:
            self._prof = cProfile.Profile()
            self._prof.enable()

    def _setup_log_capture(self, filename, level):
        """
        Mirror "torch._inductor" log records of at least ``level`` into
        ``filename`` in the debug directory until the context exits.
        """
        log = logging.getLogger("torch._inductor")
        fd = self._stack.enter_context(self.fopen(filename))
        ch = logging.StreamHandler(fd)
        ch.setLevel(level)
        ch.setFormatter(
            logging.Formatter("[%(filename)s:%(lineno)d %(levelname)s] %(message)s")
        )
        log.addHandler(ch)
        log.setLevel(min(log.level, level))
        self._stack.callback(log.removeHandler, ch)

    def __exit__(self, exc_type, exc_val, exc_tb):
        # The stack holds the log handlers, open log files and the debug
        # handler registration; it must be unwound even if writing the profile
        # or uploading the tarball fails.
        try:
            if self._prof:
                self._prof.disable()
                self._save_profile_data()

            if self._path:
                self.upload_tar()
                log.warning(
                    "%s debug trace: %s", get_graph_being_compiled(), self._path
                )
        finally:
            self._stack.close()

    def _save_profile_data(self):
        """
        Dump raw profiler stats to ``compile.prof`` and write the top 100
        entries by cumulative and by total time to ``compile.stats``.
        """
        self._prof.dump_stats(self.filename("compile.prof"))
        with self.fopen("compile.stats") as fd:
            stats = pstats.Stats(self._prof, stream=fd)
            stats.strip_dirs()
            stats.sort_stats("cumtime")
            stats.print_stats(100)
            stats.sort_stats("tottime")
            stats.print_stats(100)

    def __getattr__(self, name):
        """
        Resolve debug hooks that are not real attributes of this object.

        Returns the DebugFormatter method called ``name`` when tracing is
        enabled and ``config.trace.<name>`` is truthy.  In every other case,
        including when looking the method up fails, a callable that accepts
        anything and does nothing is returned, so callers can always invoke
        the result.
        """

        def ignored(*args, **kwargs):
            pass

        if config.trace.enabled and getattr(config.trace, name):
            try:
                return getattr(DebugFormatter(self), name)
            except Exception:
                log.warning("Ignoring exception in debug code", exc_info=True)
        # Also reached after the warning above: falling through without a
        # return would hand the caller None, which it would then try to call.
        return ignored
SchedulerNodeList = List[Any]


class DebugFormatter:
    """
    Writes individual debug artifacts through a handler object.

    The handler must provide ``fopen(name)`` (returns a writable file used as
    a context manager) and ``filename(suffix)`` (returns a path); both are
    cached on the formatter.
    """

    def __init__(self, handler):
        self.handler = handler
        self.filename = handler.filename
        self.fopen = handler.fopen

    def fx_graph(self, gm: torch.fx.GraphModule, inputs: List[torch.Tensor]):
        """Write a runnable repro and a readable dump of ``gm``."""
        with self.fopen("fx_graph_runnable.py") as runnable:
            save_graph_repro(runnable, gm, inputs, "inductor")

        with self.fopen("fx_graph_readable.py") as readable:
            text = gm.print_readable(print_output=False)
            readable.write(text)

    def fx_graph_transformed(
        self, gm: torch.fx.GraphModule, inputs: List[torch.Tensor]
    ):
        """Write a readable dump of ``gm`` to fx_graph_transformed.py."""
        with self.fopen("fx_graph_transformed.py") as transformed:
            text = gm.print_readable(print_output=False)
            transformed.write(text)

    def ir_pre_fusion(self, nodes: SchedulerNodeList):
        """Write the debug strings of ``nodes`` to ir_pre_fusion.txt."""
        self._write_ir("ir_pre_fusion.txt", nodes)

    def ir_post_fusion(self, nodes: SchedulerNodeList):
        """Write the debug strings of ``nodes`` to ir_post_fusion.txt."""
        self._write_ir("ir_post_fusion.txt", nodes)

    def _write_ir(self, filename: str, nodes: SchedulerNodeList):
        """Write each node's ``debug_str()`` followed by three newlines."""
        separator = "\n\n\n"
        with self.fopen(filename) as stream:
            for entry in nodes:
                stream.writelines((entry.debug_str(), separator))

    def graph_diagram(self, nodes: SchedulerNodeList):
        """Draw ``nodes`` with draw_buffers, targeting graph_diagram.svg."""
        target = self.filename("graph_diagram.svg")
        draw_buffers(nodes, fname=target)

    def output_code(self, filename):
        """Copy the file at ``filename`` to output_code.py."""
        destination = self.filename("output_code.py")
        shutil.copy(filename, destination)