123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121 |
import functools
import logging
from typing import Optional

import torch
from torch._dynamo import eval_frame
from torch._dynamo.utils import counters
from torch._functorch.aot_autograd import aot_module_simplified
from torch._subclasses import FakeTensor
from torch.utils._python_dispatch import _disable_current_modes
# Logger for this module, named after the module itself (``__name__``).
log = logging.getLogger(name=__name__)
def aot_autograd(**kwargs):
    """Build a TorchDynamo backend that compiles graphs with AOT Autograd.

    ``kwargs`` are forwarded to ``aot_module_simplified`` (for example
    ``fw_compiler``, ``bw_compiler``, ``partition_fn``, ``decompositions``).
    When ``bw_compiler`` is absent or falsy, ``fw_compiler`` is used for the
    backward pass as well. ``decompositions`` may be passed as a zero-argument
    callable; it is resolved on the first compilation and cached.

    Returns:
        A ``compiler_fn(gm, example_inputs)`` callable for use as a backend.
    """

    def compiler_fn(gm: torch.fx.GraphModule, example_inputs):
        import functorch.compile

        # Hack to get around circular import problems with
        # aot_eager_decomp_partition: the decomposition table may arrive as a
        # thunk. Resolve it once and cache the result in the shared kwargs.
        if callable(kwargs.get("decompositions")):
            kwargs["decompositions"] = kwargs["decompositions"]()

        # TODO: stop monkeypatching here (without even cleaning up, UGH!)
        functorch.compile.config.use_functionalize = True
        functorch.compile.config.use_fake_tensor = True

        counters["aot_autograd"]["total"] += 1

        # Use a per-call copy so the backward-compiler wrapper does not leak
        # into the kwargs shared by every invocation of this backend.
        # Otherwise each compiled graph would wrap the previous wrapper.
        aot_kwargs = dict(kwargs)
        bw_compiler = aot_kwargs.get("bw_compiler") or aot_kwargs["fw_compiler"]

        def _wrapped_bw_compiler(*bw_args, **bw_kwargs):
            # Stop TorchDynamo from trying to compile our generated backwards
            # pass: disable both the compiler call and the callable it returns.
            return eval_frame.disable(
                eval_frame.disable(bw_compiler)(*bw_args, **bw_kwargs)
            )

        aot_kwargs["bw_compiler"] = _wrapped_bw_compiler

        from torch._inductor.debug import enable_aot_logging

        try:
            # NB: NOT cloned!
            with enable_aot_logging():
                cg = aot_module_simplified(gm, example_inputs, **aot_kwargs)
        except Exception:
            counters["aot_autograd"]["not_ok"] += 1
            raise
        counters["aot_autograd"]["ok"] += 1
        return eval_frame.disable(cg)

    return compiler_fn
def mem_efficient_fusion_kwargs(use_decomps):
    """Return AOT Autograd kwargs mirroring ``memory_efficient_fusion()``.

    Both the forward and backward graphs are compiled with ``ts_compile``,
    and ``min_cut_rematerialization_partition`` is used as the partitioner.

    Args:
        use_decomps: When truthy, also include ``default_decompositions``
            under the ``"decompositions"`` key.

    Returns:
        A fresh dict of keyword arguments.
    """
    from functorch.compile import (
        default_decompositions,
        min_cut_rematerialization_partition,
        ts_compile,
    )

    settings = dict(
        fw_compiler=ts_compile,
        bw_compiler=ts_compile,
        partition_fn=min_cut_rematerialization_partition,
    )
    if use_decomps:
        settings.update(decompositions=default_decompositions)
    return settings
def fake_tensor_unsupported(fn):
    """
    Decorator for backends that need real inputs. We swap out fake
    tensors for zero tensors.

    Every ``FakeTensor`` in ``inputs`` is replaced by a real, zero-filled
    tensor with the same size, stride, dtype, device and ``requires_grad``.
    Any other input is passed through unchanged. Both the conversion and the
    call to ``fn`` run inside ``_disable_current_modes()``.

    Args:
        fn: A backend callable ``fn(model, inputs, **kwargs)``.

    Returns:
        A wrapper with the same call signature that invokes ``fn`` with the
        converted inputs (as a list).
    """

    def _concrete(dim):
        # With dynamic shapes a FakeTensor may mix symbolic (SymInt) and
        # static (plain int) dimensions. Only SymInts carry a ``.node`` whose
        # size hint must be looked up; ints are used as-is.
        if isinstance(dim, torch.SymInt):
            return dim.node.shape_env.size_hint(dim.node.expr)
        return dim

    def defake(x):
        if not isinstance(x, FakeTensor):
            return x
        if x._has_symbolic_sizes_strides:
            size = [_concrete(s) for s in x.size()]
            stride = [_concrete(s) for s in x.stride()]
        else:
            size = x.size()
            stride = x.stride()
        y = torch.empty_strided(
            size,
            stride,
            dtype=x.dtype,
            device=x.device,
        )
        # Zero-fill BEFORE enabling requires_grad: an in-place op on a leaf
        # tensor that requires grad raises under grad mode.
        y.zero_()
        return y.requires_grad_(x.requires_grad)

    @functools.wraps(fn)
    def wrapper(model, inputs, **kwargs):
        with _disable_current_modes():
            inputs = list(map(defake, inputs))
            return fn(model, inputs, **kwargs)

    return wrapper
def device_from_inputs(example_inputs) -> Optional[torch.device]:
    """Return the ``device`` of the first input that has one.

    Args:
        example_inputs: Iterable of example inputs. Items without a
            ``device`` attribute (e.g. Python scalars) are skipped.

    Returns:
        The first ``.device`` found, or ``None`` when no input has one.
    """
    for x in example_inputs:
        if hasattr(x, "device"):
            return x.device
    return None
def dtype_from_inputs(example_inputs) -> Optional[torch.dtype]:
    """Return the ``dtype`` of the first input that has one.

    Args:
        example_inputs: Iterable of example inputs. Items without a
            ``dtype`` attribute (e.g. Python scalars) are skipped.

    Returns:
        The first ``.dtype`` found, or ``None`` when no input has one.
    """
    for x in example_inputs:
        if hasattr(x, "dtype"):
            return x.dtype
    return None
|