1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556 |
"""
This file contains TorchDynamo backends intended for debugging uses.

All backends below are registered through ``register_debug_backend``
(imported here under the alias ``register_backend``).
"""
# NOTE: the string above must be the first statement of the module to become
# the module docstring (``__doc__``); placed after the imports it would just be
# an unused expression statement.

import functools
from importlib import import_module

from functorch.compile import min_cut_rematerialization_partition, nop

import torch
from torch._functorch.compilers import ts_compile

from .common import aot_autograd
from .registry import register_debug_backend as register_backend
@register_backend
def eager(gm, fake_tensor_inputs):
    """Debugging backend that performs no compilation at all.

    Args:
        gm: the module handed to the backend (presumably the FX GraphModule
            captured by TorchDynamo — confirm against the registry contract).
        fake_tensor_inputs: example inputs supplied by the caller; ignored.

    Returns:
        ``gm`` itself, unmodified.
    """
    # The inputs are part of the backend calling convention but unused here.
    del fake_tensor_inputs
    return gm
@register_backend(name="ts")
def torchscript(gm, fake_tensor_inputs):
    """Debugging backend registered as ``"ts"``: compile ``gm`` with TorchScript.

    Args:
        gm: the module to script (presumably the FX GraphModule captured by
            TorchDynamo).
        fake_tensor_inputs: ignored; accepted only to match the backend
            calling convention.

    Returns:
        The scripted module produced by ``torch.jit.script(gm)``.
    """
    scripted_module = torch.jit.script(gm)
    return scripted_module
# aot_eager: AOT Autograd driven by the no-op ``nop`` compiler as the forward
# compiler. No ``bw_compiler`` is passed, so the backward compiler is whatever
# ``aot_autograd`` chooses by default (presumably the forward one — confirm in
# ``.common``). Handy for debugging AOT Autograd itself, without a real codegen.
aot_eager = aot_autograd(fw_compiler=nop)
register_backend(compiler_fn=aot_eager, name="aot_eager")
# aot_eager_decomp_partition: reuse TorchInductor's AOT Autograd decompositions
# and min-cut partitioner, but swap the inductor compiler for ``nop``. Comparing
# it with aot_eager and inductor helps tell apart AOT-Autograd problems from
# inductor-codegen problems.


def _select_inductor_decomp_table():
    """Return TorchInductor's decomposition table.

    Passed as a callable (not a table) so that importing this module does not
    import ``torch._inductor``; the import is deferred until this is called.
    """
    compile_fx = import_module("torch._inductor.compile_fx")
    return compile_fx.select_decomp_table()


# Compiler/partitioner choices mirror those of memory_efficient_fusion().
aot_eager_decomp_partition = aot_autograd(
    fw_compiler=nop,
    bw_compiler=nop,
    decompositions=_select_inductor_decomp_table,
    partition_fn=functools.partial(
        min_cut_rematerialization_partition, compiler="inductor"
    ),
)
register_backend(
    compiler_fn=aot_eager_decomp_partition, name="aot_eager_decomp_partition"
)
# aot_ts: AOT Autograd with the TorchScript compiler (``ts_compile``) as the
# forward compiler and the default partitioner. The fuser (e.g. NNC or nvFuser)
# can be selected by running under the matching ``torch.jit.fuser(...)``
# context.
aot_ts = aot_autograd(fw_compiler=ts_compile)
register_backend(compiler_fn=aot_ts, name="aot_ts")
|