12345678910111213141516171819202122232425262728293031 |
import torch.fx as fx
def set_trace(gm: fx.GraphModule) -> fx.GraphModule:
    """
    Set a breakpoint in `gm`'s generated python code.

    An ``import pdb; pdb.set_trace()`` line is placed at the front of the
    generated code body, so execution drops into pdb when `gm` gets run.

    Args:
        gm: graph module to insert the breakpoint into. It is recompiled
            here for the breakpoint to take effect.

    Returns:
        the same `gm`, with the breakpoint inserted.
    """
    def insert_pdb(body):
        # The breakpoint goes ahead of every line already in the body.
        return ["import pdb; pdb.set_trace()\n", *body]

    def make_transformer(cur_transform):
        # `cur_transform` is the code transformer registered so far, if any.
        # Chain onto it rather than replace it, so earlier transformations
        # are still applied before the breakpoint is added.
        def transform(body):
            # Compare against None explicitly: a registered callable must
            # be honoured even if it happens to evaluate as falsy.
            if cur_transform is not None:
                body = cur_transform(body)
            return insert_pdb(body)

        return transform

    # The transformer is registered through the context manager returned by
    # `on_generate_code`; recompiling inside the `with` block is what bakes
    # the breakpoint into `gm`'s generated code.
    with gm.graph.on_generate_code(make_transformer=make_transformer):
        gm.recompile()

    return gm
|