"""
This file contains TorchDynamo backends intended for debugging uses.
"""

import functools
from importlib import import_module

from functorch.compile import min_cut_rematerialization_partition, nop

import torch
from torch._functorch.compilers import ts_compile

from .common import aot_autograd
from .registry import register_debug_backend as register_backend


@register_backend
def eager(gm, fake_tensor_inputs):
    """Return ``gm`` unchanged.

    ``fake_tensor_inputs`` is accepted but not used by this backend.
    """
    return gm


@register_backend(name="ts")
def torchscript(gm, fake_tensor_inputs):
    """Return ``torch.jit.script(gm)``; registered under the name ``"ts"``.

    ``fake_tensor_inputs`` is accepted but not used by this backend.
    """
    return torch.jit.script(gm)


# Useful for debugging purposes.
# aot_eager uses the AOT Autograd backend with the nop compiler. It is helpful
# in debugging.
aot_eager = aot_autograd(fw_compiler=nop)
register_backend(name="aot_eager", compiler_fn=aot_eager)


# Uses TorchInductor AOT Autograd decomps and partitioner to isolate aot vs
# inductor problems.
# aot_eager_decomp_partition just replaces the inductor compiler with nop to
# help isolate inductor vs aot_eager errors.
aot_eager_decomp_partition = aot_autograd(
    # these are taken from memory_efficient_fusion()
    fw_compiler=nop,
    bw_compiler=nop,
    # NB: lambda here is to delay import of inductor
    decompositions=lambda: import_module(
        "torch._inductor.compile_fx"
    ).select_decomp_table(),
    partition_fn=functools.partial(
        min_cut_rematerialization_partition, compiler="inductor"
    ),
)
register_backend(
    name="aot_eager_decomp_partition", compiler_fn=aot_eager_decomp_partition
)


# AOT Autograd with torchscript backend. Default partitioner.
# aot_ts uses the torchscript backend. We can use this with both nnc and
# nvfuser by using the relevant fuser with torch.jit.fuser(...)
aot_ts = aot_autograd(fw_compiler=ts_compile)
register_backend(name="aot_ts", compiler_fn=aot_ts)