fx_utils.py 1.0 KB

1234567891011121314151617181920212223242526272829
  1. import torch
  2. # Check the pattern: (nn.module, F.function/torch.Tensor.method) matched.
  3. # Works for length 2 patterns with 1 module and 1 function/method.
  4. def matches_module_function_pattern(pattern, node, modules):
  5. if len(node.args) == 0:
  6. return False
  7. if not isinstance(node.args[0], torch.fx.Node) or not isinstance(
  8. node, torch.fx.Node
  9. ):
  10. return False
  11. # the first node is call_module
  12. if node.args[0].op != "call_module":
  13. return False
  14. if not isinstance(node.args[0].target, str):
  15. return False
  16. if node.args[0].target not in modules:
  17. return False
  18. if type(modules[node.args[0].target]) is not pattern[0]:
  19. return False
  20. # the second node is call_function or call_method
  21. if node.op != "call_function" and node.op != "call_method":
  22. return False
  23. if node.target != pattern[1]:
  24. return False
  25. # make sure node.args[0] output is only used by current node.
  26. if len(node.args[0].users) > 1:
  27. return False
  28. return True