1234567891011121314151617181920212223242526272829 |
- import torch
def matches_module_function_pattern(pattern, node, modules) -> bool:
    """Return True if ``node`` applies ``pattern[1]`` to the output of a ``pattern[0]`` module.

    Matches a length-2 FX pattern ``(nn.Module type, F.function / torch.Tensor.method)``:

    * ``node.args[0]`` must be a ``call_module`` node whose target names a
      submodule in ``modules`` of exactly type ``pattern[0]``. Subclasses do
      not match, because the check uses ``type(...) is``, not ``isinstance``.
    * ``node`` itself must be a ``call_function`` or ``call_method`` node
      whose target equals ``pattern[1]``.
    * The module's output must have at most one user, i.e. only ``node``
      consumes it.

    Args:
        pattern: Two-element sequence ``(module_type, function_or_method_target)``.
        node: Candidate FX node for the second element of the pattern.
        modules: Mapping from submodule name to module instance. Presumably
            ``dict(graph_module.named_modules())``; confirm against callers.

    Returns:
        True if every condition above holds. Otherwise False, including when
        ``node`` is not a ``torch.fx.Node`` or has no positional args.
    """
    # Validate the node type before touching ``node.args``. Otherwise a
    # non-Node input raises AttributeError instead of simply not matching.
    if not isinstance(node, torch.fx.Node) or len(node.args) == 0:
        return False

    producer = node.args[0]

    # The first node of the pattern must be a call_module ...
    if not isinstance(producer, torch.fx.Node) or producer.op != "call_module":
        return False
    module_name = producer.target
    if not isinstance(module_name, str) or module_name not in modules:
        return False
    if type(modules[module_name]) is not pattern[0]:
        return False

    # ... and the second node must be a call_function or call_method on pattern[1].
    if node.op not in ("call_function", "call_method"):
        return False
    if node.target != pattern[1]:
        return False

    # Make sure the module's output is only used by the current node.
    return len(producer.users) <= 1
|