_ir_utils.py 616 B

123456789101112131415161718
  1. import torch
  2. from typing import Union
  3. class _InsertPoint:
  4. def __init__(self, insert_point_graph: torch._C.Graph, insert_point: Union[torch._C.Node, torch._C.Block]):
  5. self.insert_point = insert_point
  6. self.g = insert_point_graph
  7. self.guard = None
  8. def __enter__(self):
  9. self.prev_insert_point = self.g.insertPoint()
  10. self.g.setInsertPoint(self.insert_point)
  11. def __exit__(self, *args):
  12. self.g.setInsertPoint(self.prev_insert_point)
  13. def insert_point_guard(self, insert_point: Union[torch._C.Node, torch._C.Block]):
  14. return _InsertPoint(self, insert_point)