debug.py 805 B

12345678910111213141516171819202122232425262728293031
  1. import torch.fx as fx
  2. def set_trace(gm: fx.GraphModule) -> fx.GraphModule:
  3. """
  4. Sets a breakpoint in `gm`'s generated python code. It drops into pdb when
  5. `gm` gets run.
  6. Args:
  7. gm: graph module to insert breakpoint. It is then recompiled for it to
  8. take effect.
  9. Returns:
  10. the `gm` with breakpoint inserted.
  11. """
  12. def insert_pdb(body):
  13. return ["import pdb; pdb.set_trace()\n", *body]
  14. with gm.graph.on_generate_code(
  15. make_transformer=lambda cur_transform: (
  16. # new code transformer to register
  17. lambda body: (
  18. insert_pdb(
  19. cur_transform(body) if cur_transform
  20. else body
  21. )
  22. )
  23. )
  24. ):
  25. gm.recompile()
  26. return gm