common.py 2.8 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283
  1. from torch.nn import Module
  2. from torch.fx.graph_module import GraphModule
  3. from torch.fx.graph import Graph
  4. from torch.fx.passes.utils.matcher_utils import SubgraphMatcher
  5. from torch.fx._compatibility import compatibility
  6. __all__ = ['HolderModule', 'lift_subgraph_as_module', 'compare_graphs']
  7. @compatibility(is_backward_compatible=False)
  8. class HolderModule(Module):
  9. """
  10. HolderModule is used to copy all the attributes from original module to submodules
  11. that uses the attributes
  12. """
  13. def __init__(self, d):
  14. super().__init__()
  15. for k, v in d.items():
  16. self.add_module(k, v)
  17. @compatibility(is_backward_compatible=False)
  18. def lift_subgraph_as_module(gm: GraphModule, subgraph: Graph, class_name: str = 'GraphModule') -> GraphModule:
  19. """
  20. Create a GraphModule for subgraph, which copies the necessory attributes from the original parent graph_module.
  21. Args:
  22. gm (GraphModule): parent graph module
  23. subgraph (Graph): a valid subgraph that contains copied nodes from the parent graph
  24. class_name (str): name for the submodule
  25. """
  26. # Loop through all module calls (call_module) and param fetches (get_attr)
  27. # in this component, creating HolderModules as necessary to match the path.
  28. # e.g. if in the original module there's a get_attr node fetches "conv.weight".
  29. # We create a HolderModule as root -> add a HolderModule named "conv" ->
  30. # make "weight" a attribute of "conv" HolderModule and point to conv.weight in
  31. # the original module.
  32. submodule = HolderModule({})
  33. for n in subgraph.nodes:
  34. if n.op not in ("call_module", "get_attr"):
  35. continue
  36. target = n.target
  37. assert isinstance(target, str)
  38. target_name_parts = target.split(".")
  39. curr = submodule
  40. orig_gm = gm
  41. for name in target_name_parts[:-1]:
  42. if not hasattr(curr, name):
  43. curr.add_module(name, HolderModule({}))
  44. curr = getattr(curr, name)
  45. orig_gm = getattr(orig_gm, name)
  46. leaf_node_name = target_name_parts[-1]
  47. leaf_node = getattr(orig_gm, leaf_node_name)
  48. # Relies on custom __setattr__ magic.
  49. setattr(curr, leaf_node_name, leaf_node)
  50. return GraphModule(submodule, subgraph, class_name)
  51. @compatibility(is_backward_compatible=False)
  52. def compare_graphs(left: Graph, right: Graph) -> bool:
  53. """
  54. Return True if two graphs are identical, i.e they
  55. - have the same number of outputs in the same order
  56. - have the same number of inputs in the same order
  57. - have the same set of nodes, and identical connectivity
  58. """
  59. matcher = SubgraphMatcher(left, match_output=True, match_placeholder=True)
  60. matches = matcher.match(right)
  61. return len(matches) > 0