import copy
import logging
import os
import pickle
import random
from contextlib import contextmanager
from functools import partial
from typing import Callable, Optional, Tuple, Union

import torch
import torch.fx as fx
import torch.nn as nn
import torch.utils._pytree as pytree
from torch import SymInt
from torch._decomp import get_decompositions
from torch.fx.experimental.symbolic_shapes import bind_symbols

from .aot_autograd import aot_function, aot_module, make_boxed_compiler
from .compile_utils import strip_overloads
from .partitioners import (
    default_partition,
    draw_graph,
    min_cut_rematerialization_partition,
)

log = logging.getLogger(__name__)


# These canonicalizations are needed here (and not as decompositions) because the
# ops we canonicalize to (e.g. aten.to) are CompositeImplicitAutograd.
def _canonicalize(fx_g):
    for node in fx_g.graph.nodes:
        if node.target == torch.ops.aten._to_copy:
            node.target = torch.ops.aten.to
    fx_g.recompile()
    return fx_g


@contextmanager
def _disable_jit_autocast():
    # Temporarily disable TorchScript autocast; restore the previous setting on exit.
    old_jit_autocast_flag = torch._C._jit_set_autocast_mode(False)
    try:
        yield
    finally:
        torch._C._jit_set_autocast_mode(old_jit_autocast_flag)


@make_boxed_compiler
def ts_compile(fx_g: fx.GraphModule, inps) -> Callable:
    """
    Compiles the :attr:`fx_g` with the TorchScript compiler.

    .. warning::
        This API is experimental and likely to change.

    Args:
        fx_g(fx.GraphModule): The input FX graph module to be compiled.

    Returns:
        Torch scripted model.
    """
    with _disable_jit_autocast():
        strip_overloads(fx_g)

        for node in fx_g.graph.nodes:
            if (
                node.target == torch.ops.aten._to_copy
                and len(node.args) == 1
                and len(node.kwargs) == 1
                and "dtype" in node.kwargs
            ):
                node.target = torch.ops.aten.to

        for node in fx_g.graph.nodes:
            new_kwargs = {}
            for k, v in node.kwargs.items():
                if isinstance(v, torch.device):
                    v = v.type
                new_kwargs[k] = v
            node.kwargs = new_kwargs

        fx_g.graph.lint()
        fx_g.recompile()

        f = torch.jit.script(fx_g)

        torch._C._jit_pass_remove_mutation(f.graph)

        f = torch.jit.freeze(f.eval())
        f = torch.jit.optimize_for_inference(f)
        if not any(isinstance(t, torch._subclasses.FakeTensor) for t in inps):
            f(*inps)
    return f
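

# Example (illustrative sketch, not part of the library API; `fn` and the input are
# hypothetical). `ts_compile` is a boxed compiler, so it can be passed directly as
# the forward/backward compiler of `aot_function`:
#
#   def fn(x):
#       return torch.sin(x).cos()
#
#   compiled_fn = aot_function(fn, fw_compiler=ts_compile, bw_compiler=ts_compile)
#   out = compiled_fn(torch.randn(8, requires_grad=True))
#   out.sum().backward()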


def _draw_graph_compile(fx_g, _, name, clear_meta=True):
    print(fx_g.code)
    draw_graph(fx_g, name, clear_meta=clear_meta)
    return fx_g


def draw_graph_compile(name):
    return make_boxed_compiler(partial(_draw_graph_compile, name=name))


@make_boxed_compiler
def nop(fx_g: fx.GraphModule, _) -> Callable:
    """
    Returns the :attr:`fx_g` Fx graph module as it is. This is a no-op compiler
    and can be used to check accuracy.

    .. warning::
        This API is experimental and likely to change.

    """
    return fx_g
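

# Example (illustrative sketch; `fn` and `x` are hypothetical). Because `nop`
# returns the traced graph unchanged, comparing it against eager execution helps
# separate tracing/partitioning issues from backend compiler issues:
#
#   ref = fn(x)
#   res = aot_function(fn, nop)(x)
#   torch.testing.assert_close(ref, res)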


class DebugInterpreter(fx.Interpreter):
    def run(self, *args):
        self.symbol_mapping = bind_symbols(self.module, *args)
        return super().run(*args)

    def run_node(self, n):
        import sympy

        def subst_symint(ni):
            if not isinstance(ni, SymInt):
                return ni
            r = sympy.expand(ni.node.expr.xreplace(self.symbol_mapping))
            assert len(r.free_symbols) == 0, r
            return int(r)

        def subst_symint_tuple(nis):
            return tuple(subst_symint(ni) for ni in nis)

        def check_significant_strides(a, b):
            if subst_symint(a.numel()) > 0:
                for idx in range(a.ndim):
                    if (
                        subst_symint(a.stride(idx)) != b.stride(idx)
                        and subst_symint(a.size(idx)) > 1
                    ):
                        return False
            return True

        def check(nv, rv, desc):
            assert callable(desc)
            assert nv.dtype == rv.dtype, f"{desc()}: {nv.dtype} != {rv.dtype}"
            assert subst_symint_tuple(nv.size()) == rv.size(), (
                f"{desc()}: {nv.size()} aka {subst_symint_tuple(nv.size())} != {rv.size()}"
            )
            same_strides = check_significant_strides(nv, rv)
            assert same_strides, (
                f"{desc()}: {nv.stride()} aka {subst_symint_tuple(nv.stride())} != {rv.stride()}"
            )

        r = super().run_node(n)
        if 'val' in n.meta:
            n_vals, n_spec = pytree.tree_flatten(n.meta['val'])
            r_vals, r_spec = pytree.tree_flatten(r)
            # TODO: There is some sort of problem where we record that an
            # operator returned a tuple/list, and then later it turns out the
            # real version of the operator returned a list/tuple. Need to
            # figure out what's actually going on here; the error itself is
            # harmless enough as we only getitem out the outputs.
            # assert n_spec == r_spec, f"{n_spec} != {r_spec}"
            assert len(n_vals) == len(r_vals), f"{len(n_vals)} != {len(r_vals)}"
            for i, nv, rv in zip(range(len(n_vals)), n_vals, r_vals):
                if not isinstance(rv, torch.Tensor):
                    continue
                check(nv, rv, lambda: f"output {i} where {self.symbol_mapping}")
        return r


@make_boxed_compiler
def debug_nop(fx_g: fx.GraphModule, _) -> Callable:
    """
    Returns a (slow) interpreter over the FX graph module that also checks
    various debugging properties (e.g., that tracing strides matched real
    strides).
    """
    return DebugInterpreter(fx_g).run
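

# Example (illustrative sketch; `fn` and `x` are hypothetical). Using `debug_nop`
# as the forward compiler runs the graph through `DebugInterpreter`, asserting
# that the dtypes, sizes, and strides recorded during tracing match the real
# outputs:
#
#   checked_fn = aot_function(fn, debug_nop)
#   checked_fn(x)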


@make_boxed_compiler
def simple_ts_compile(fx_g, _):
    strip_overloads(fx_g)
    f = torch.jit.script(fx_g)
    f = torch.jit.freeze(f.eval())
    return f


def nnc_jit(f, static_argnums=None):
    return aot_function(f, simple_ts_compile, static_argnums=static_argnums)


aten = torch.ops.aten
default_decompositions = {
    aten.detach,
    aten.gelu_backward,
    aten.leaky_relu_backward,
    aten.sigmoid_backward,
    aten.threshold_backward,
    aten.hardtanh_backward,
    aten.hardsigmoid_backward,
    aten.hardswish_backward,
    aten.tanh_backward,
    aten.silu_backward,
    aten.elu_backward,
    aten.cudnn_batch_norm,
    aten.cudnn_batch_norm_backward,
    aten.masked_fill.Scalar,
    aten.masked_fill.Tensor,
    aten.elu,
    aten.leaky_relu,
    aten.hardtanh,
    aten.hardswish,
    aten.hardsigmoid,
    aten.conj_physical,
    aten.is_same_size,
}

default_decompositions = get_decompositions(default_decompositions)


@make_boxed_compiler
def print_compile(fx_g, _):
    print(fx_g.code)
    return fx_g


def memory_efficient_fusion(
    fn: Union[Callable, nn.Module],
    static_argnums: Optional[Tuple[int]] = None,
    **kwargs,
):
    """
    Wrapper function over :func:`aot_function` and :func:`aot_module` to perform
    memory efficient fusion. It uses the
    :func:`min_cut_rematerialization_partition` partitioner to perform efficient
    recomputation. It uses NVFuser to compile the generated forward and backward
    graphs.

    .. warning::
        This API is experimental and likely to change.

    Args:
        fn (Union[Callable, nn.Module]): A Python function or a ``nn.Module``
            that takes one or more arguments. Must return one or more Tensors.
        static_argnums (Optional[Tuple[int]]): An optional tuple of ints to mark
            the arguments of the function as static.
        **kwargs: Any other overrides you want to make to the settings.

    Returns:
        Returns a ``Callable`` or ``nn.Module`` that retains the eager behavior
        of the original :attr:`fn`, but whose forward and backward graphs have
        gone through recomputation optimizations, and the graphs have been
        compiled with nvfuser.

    """
    config = {
        "fw_compiler": ts_compile,
        "bw_compiler": ts_compile,
        "partition_fn": min_cut_rematerialization_partition,
        "decompositions": default_decompositions,
        "static_argnums": static_argnums,
    }
    config.update(kwargs)
    if isinstance(fn, torch.nn.Module):
        return aot_module(fn, **config)
    else:
        return aot_function(fn, **config)
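

# Example (illustrative sketch; `MyModule` is a hypothetical nn.Module and the call
# assumes a CUDA build where nvfuser is available through the TorchScript backend):
#
#   mod = MyModule().cuda()
#   fused_mod = memory_efficient_fusion(mod)
#   out = fused_mod(torch.randn(8, 128, device="cuda"))
#   out.sum().backward()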


def debug_compile(fx_g, inps):
    fx_g.to_folder("foo")
    print(
        f"""
##############################################################
# To minimize FX graph, copy and paste the below and run it #
##############################################################

import torch
import torch.fx as fx
from functorch.compile import minifier, check_nvfuser_subprocess, check_nvfuser_correctness_subprocess

inps = {[(i.shape, i.dtype) for i in inps]}
inps = [torch.ones(shape, dtype=dtype, device='cuda') for (shape, dtype) in inps]
from foo import FxModule
mod = FxModule().cuda()

with torch.jit.fuser("fuser2"):
    # check_nvfuser_subprocess can be replaced with check_nvfuser_correctness_subprocess
    minifier(fx.symbolic_trace(mod), inps, check_nvfuser_subprocess)
"""
    )
    from foo import FxModule

    FxModule().cuda()(*inps)

    return ts_compile(fx_g, inps)


graph_index = 0


def get_inputs(input_data_path):
    """
    Return random inputs matching the inputs meta generated by _save_fx_default.
    """
    inputs = []
    with open(input_data_path, "rb") as f:
        inputs_meta = pickle.load(f)
        for meta in inputs_meta:
            if len(meta) == 1:
                type = meta[0]
                input = type(random.random())
            else:
                type, shape, stride, dtype, device = meta
                if dtype in {
                    torch.int,
                    torch.int32,
                    torch.int64,
                    torch.bool,
                    torch.uint8,
                    int,
                    float,
                }:
                    input = torch.randint(0, 1, shape, dtype=dtype, device=device)
                else:
                    input = torch.rand(shape, dtype=dtype, device=device)
            inputs.append(input)
    return inputs
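

# Example (illustrative sketch; the path is hypothetical and follows the layout
# documented in `_save_fx_default` below):
#
#   inps = get_inputs("dumps/mymodel/mymodel_forward_0/mymodel_forward_0.input")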


def _save_fx_default(current_name, folder_name, dump_example_input, gm, example_inputs):
    """
    The forward, backward, and joint computation graphs will be stored in
    {folder_name}/{current_name}/{current_name}_forward_{graph_index},
    {folder_name}/{current_name}/{current_name}_backward_{graph_index}, and
    {folder_name}/{current_name}/{current_name}_joint_{graph_index} respectively.
    The input shapes of the graphs will be stored in the .input files.
    These files can be loaded with pickle; each is a list of entries of the form
    (type, shape, stride, dtype, device).
    In the case of type = int or float, the entry is just (type,).
    For a joint graph input, it is a nested list [[], []]
    where the two inner lists have the same format.
    If dump_example_input is True, example_inputs will be stored in a .pt file.
    Since each function might produce multiple graphs,
    graph_index is used to distinguish the different graphs.
    """
    from functorch.compile import aot_module_simplified

    def get_input_meta(args):
        input_meta = []
        if len(args) > 0 and isinstance(args[0], tuple):  # joint input
            input_meta += get_input_meta(args[0])
            input_meta += get_input_meta(args[1])
            return input_meta
        for arg in args:
            if type(arg) == int or type(arg) == float:
                input_meta.append((type(arg),))
            else:
                input_meta.append(
                    (type(arg), arg.shape, arg.stride(), arg.dtype, arg.device)
                )
        return input_meta

    def graph_saver_helper(gm_to_save, args, type_name):
        global graph_index
        if len(gm_to_save.graph.nodes) == 0:
            log.log(
                logging.WARNING,
                f"No nodes in graph {current_name}_{type_name}_{graph_index}.",
            )
            return

        gm = copy.deepcopy(gm_to_save)
        gm.graph.set_codegen(torch.fx.graph.CodeGen())  # remove codegen
        gm.recompile()

        input_meta = get_input_meta(args)

        if not os.path.exists(f"{folder_name}/{current_name}"):
            os.makedirs(f"{folder_name}/{current_name}")
        gm.to_folder(
            f"{folder_name}/{current_name}/{current_name}_{type_name}_{graph_index}"
        )
        pickle.dump(
            input_meta,
            open(
                f"{folder_name}/{current_name}/{current_name}_{type_name}_{graph_index}/{current_name}_{type_name}_{graph_index}.input",  # noqa: B950
                "wb",
            ),
        )  # noqa: E501
        if dump_example_input:
            torch.save(
                args,
                f"{folder_name}/{current_name}/{current_name}_{type_name}_{graph_index}/{current_name}_{type_name}_{graph_index}.pt",  # noqa: B950
            )  # noqa: E501

    def graph_saver_forward(gm, fw_args):
        graph_saver_helper(gm, fw_args, "forward")
        return gm

    def graph_saver_backward(gm, bw_args):
        graph_saver_helper(gm, bw_args, "backward")
        global graph_index
        graph_index += 1
        return gm

    def graph_saver_joint(gm, joint_args):
        graph_saver_helper(gm, joint_args, "joint")
        return default_partition(gm, joint_args)

    return aot_module_simplified(
        gm,
        example_inputs,
        fw_compiler=graph_saver_forward,
        bw_compiler=graph_saver_backward,
        partition_fn=graph_saver_joint,
        decompositions=default_decompositions,
    )
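

# Example (illustrative sketch; folder and module names are hypothetical and follow
# the layout described in `_save_fx_default`). A dumped graph can be reloaded and
# re-run against randomly generated inputs:
#
#   from dumps.mymodel.mymodel_forward_0 import FxModule
#   mod = FxModule()
#   inps = get_inputs("dumps/mymodel/mymodel_forward_0/mymodel_forward_0.input")
#   mod(*inps)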


# WARNING: This isn't tested anywhere!!
def graph_dumper_aot(current_name, folder_name, dump_example_input=False):
    """
    Dump the forward, backward, and joint computation graphs.

    Example usage:

        save_fx_func = graph_dumper_aot(current_name, folder_name, dump_example_input=False)

        optimize_ctx = torchdynamo.optimize(
            save_fx_func
        )

        with torch.enable_grad():
            with optimize_ctx:
                result = forward_and_backward_pass(model, example_inputs)
    """
    global graph_index
    graph_index = 0
    return partial(_save_fx_default, current_name, folder_name, dump_example_input)