import functools
import logging

import torch
from torch._dynamo import eval_frame
from torch._dynamo.utils import counters
from torch._functorch.aot_autograd import aot_module_simplified
from torch._subclasses import FakeTensor
from torch.utils._python_dispatch import _disable_current_modes

log = logging.getLogger(__name__)


def aot_autograd(**kwargs):
    """Build a TorchDynamo backend that compiles captured graphs with AOTAutograd,
    forwarding fw_compiler/bw_compiler/partition_fn/decompositions via **kwargs."""

    def compiler_fn(gm: torch.fx.GraphModule, example_inputs):
        import functorch.compile

        # Hack to get around circular import problems with aot_eager_decomp_partition
        if callable(kwargs.get("decompositions")):
            kwargs["decompositions"] = kwargs["decompositions"]()

        # TODO: stop monkeypatching here (without even cleaning up, UGH!)
        functorch.compile.config.use_functionalize = True
        functorch.compile.config.use_fake_tensor = True

        counters["aot_autograd"]["total"] += 1
        use_fallback = False

        if use_fallback:
            log.debug("Unable to use AOT Autograd because graph has mutation")
            counters["aot_autograd"]["not_ok"] += 1
            return gm

        # OK attempt to compile

        def _wrapped_bw_compiler(*args, **kwargs):
            # stop TorchDynamo from trying to compile our generated backwards pass
            return eval_frame.disable(eval_frame.disable(bw_compiler)(*args, **kwargs))

        bw_compiler = kwargs.get("bw_compiler") or kwargs["fw_compiler"]
        kwargs["bw_compiler"] = _wrapped_bw_compiler

        from torch._inductor.debug import enable_aot_logging

        try:
            # NB: NOT cloned!
            with enable_aot_logging():
                cg = aot_module_simplified(gm, example_inputs, **kwargs)
                counters["aot_autograd"]["ok"] += 1
                return eval_frame.disable(cg)
        except Exception:
            counters["aot_autograd"]["not_ok"] += 1
            raise

    return compiler_fn


def mem_efficient_fusion_kwargs(use_decomps):
    from functorch.compile import (
        default_decompositions,
        min_cut_rematerialization_partition,
        ts_compile,
    )

    kwargs = {
        # these are taken from memory_efficient_fusion()
        "fw_compiler": ts_compile,
        "bw_compiler": ts_compile,
        "partition_fn": min_cut_rematerialization_partition,
    }

    if use_decomps:
        kwargs["decompositions"] = default_decompositions

    return kwargs


def fake_tensor_unsupported(fn):
    """
    Decorator for backends that need real inputs.  We swap out fake
    tensors for zero tensors.
    """

    def defake(x):
        if not isinstance(x, FakeTensor):
            return x
        if x._has_symbolic_sizes_strides:
            # Resolve symbolic sizes/strides to concrete hints from the shape env
            size = [s.node.shape_env.size_hint(s.node.expr) for s in x.size()]
            stride = [s.node.shape_env.size_hint(s.node.expr) for s in x.stride()]
        else:
            size = x.size()
            stride = x.stride()
        y = torch.empty_strided(
            size,
            stride,
            dtype=x.dtype,
            device=x.device,
            requires_grad=x.requires_grad,
        )
        y.zero_()
        return y

    @functools.wraps(fn)
    def wrapper(model, inputs, **kwargs):
        with _disable_current_modes():
            inputs = list(map(defake, inputs))
            return fn(model, inputs, **kwargs)

    return wrapper


def device_from_inputs(example_inputs) -> torch.device:
    # Return the device of the first example input that has one
    for x in example_inputs:
        if hasattr(x, "device"):
            return x.device


def dtype_from_inputs(example_inputs) -> torch.dtype:
    # Return the dtype of the first example input that has one
    for x in example_inputs:
        if hasattr(x, "dtype"):
            return x.dtype
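

# --- Usage sketch (illustrative, not part of the original module) ---
# A minimal example of how a custom TorchDynamo backend might combine these
# helpers: `fake_tensor_unsupported` swaps fake example inputs for real zero
# tensors before the backend sees them, and `aot_autograd` wraps a per-graph
# compiler into a backend callable. The names `toy_backend`, `toy_fw_compiler`,
# and `toy_aot_backend` are hypothetical; `make_boxed_func` (from
# functorch.compile) adapts a plain callable to AOTAutograd's boxed calling
# convention.
from functorch.compile import make_boxed_func


# A backend that cannot handle FakeTensor inputs: the decorator replaces the
# fake example inputs with real zero tensors before calling it.
@fake_tensor_unsupported
def toy_backend(gm: torch.fx.GraphModule, example_inputs):
    return gm.forward


# A per-graph forward compiler wrapped with AOTAutograd; the resulting callable
# can be passed as `backend=` to torch.compile / torch._dynamo.optimize.
def toy_fw_compiler(gm: torch.fx.GraphModule, example_inputs):
    return make_boxed_func(gm.forward)


toy_aot_backend = aot_autograd(fw_compiler=toy_fw_compiler)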