debugging.py 1.7 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556
  1. import functools
  2. from importlib import import_module
  3. from functorch.compile import min_cut_rematerialization_partition, nop
  4. import torch
  5. from torch._functorch.compilers import ts_compile
  6. from .common import aot_autograd
  7. from .registry import register_debug_backend as register_backend
  8. """
  9. This file contains TorchDynamo backends intended for debugging uses.
  10. """
  11. @register_backend
  12. def eager(gm, fake_tensor_inputs):
  13. return gm
  14. @register_backend(name="ts")
  15. def torchscript(gm, fake_tensor_inputs):
  16. return torch.jit.script(gm)
  17. # Useful for debugging purpose
  18. # aot_eager uses AOT Autograd backend with nop compiler. It is helpful in debugging.
  19. aot_eager = aot_autograd(fw_compiler=nop)
  20. register_backend(name="aot_eager", compiler_fn=aot_eager)
  21. # Uses TorchInductor AOT Autograd decomps and partitioner to isolate aot vs
  22. # inductor problems.
  23. # aot_eager_decomp_partition just replaces the inductor compiler with nop to help
  24. # isolate inductor vs aot_eager errors
  25. aot_eager_decomp_partition = aot_autograd(
  26. # these are taken from memory_efficient_fusion()
  27. fw_compiler=nop,
  28. bw_compiler=nop,
  29. # NB: lambda here is to delay import of inductor
  30. decompositions=lambda: import_module(
  31. "torch._inductor.compile_fx"
  32. ).select_decomp_table(),
  33. partition_fn=functools.partial(
  34. min_cut_rematerialization_partition, compiler="inductor"
  35. ),
  36. )
  37. register_backend(
  38. name="aot_eager_decomp_partition", compiler_fn=aot_eager_decomp_partition
  39. )
  40. # AOT Autograd with torchscript backend. Default partitioner.
  41. # aot_ts uses torchscript backend. We can use this with both nnc and nvfuser
  42. # by using the relevant fuser with torch.jit.fuser(...)
  43. aot_ts = aot_autograd(fw_compiler=ts_compile)
  44. register_backend(name="aot_ts", compiler_fn=aot_ts)