# merge_matmul.py
  1. import torch
  2. from torch.fx.node import Node
  3. from torch.fx._symbolic_trace import symbolic_trace
  4. from torch.fx.passes.tools_common import legalize_graph
  5. import itertools
  6. import operator
  7. from typing import Dict, List
  8. def split_result_tensors(result: torch.Tensor, inputs: List[torch.Tensor]) -> List[torch.Tensor]:
  9. """
  10. A free function for use in the merge_matmul graph transformation below that
  11. splits the output from a merged matmul into the individual results for each
  12. input tensor.
  13. Arguments:
  14. result: The merged matmul result tensor.
  15. inputs: The list of inputs that were merged into one for the matmul.
  16. Returns:
  17. List of matmul results for each input tensor.
  18. """
  19. # When fx tracer is running, x.shape[0] will be torch.fx.Attribute but we
  20. # need an int even when tracing
  21. if isinstance(result, torch.fx.Proxy):
  22. splits = [0] * len(inputs)
  23. else:
  24. splits = [x.shape[0] for x in inputs]
  25. return torch.split(result, splits)
  26. def may_depend_on(a: Node, b: Node, search_depth: int = 6):
  27. """
  28. Determine if one node depends on another in a torch.fx.Graph.
  29. Arguments:
  30. a: The node that may have a dependency on b.
  31. b: The node that a may have a dependency on.
  32. search_depth: In the case of an indirect dependency, this function
  33. searches upto this many nodes away in search of a
  34. data dependency. If none is found, the function
  35. makes the conservative assumption that there is a
  36. dependency.
  37. Returns:
  38. True if a may depend on b, False if it definitely does not.
  39. """
  40. # Equivalence is defined as dependence.
  41. if a == b:
  42. return True
  43. # If a has no inputs, it cannot depend on b.
  44. if len(a.all_input_nodes) == 0:
  45. return False
  46. # If the search depth has been exhausted and no conclusion has been
  47. # reached, assume that there is a data dependency.
  48. if search_depth == 0:
  49. return True
  50. # Recursively check all inputs of a.
  51. for inp in a.all_input_nodes:
  52. if may_depend_on(inp, b, search_depth - 1):
  53. return True
  54. return False
  55. def are_nodes_independent(nodes: List[Node]):
  56. """
  57. Check if all of the given nodes are pairwise-data independent.
  58. Arguments:
  59. nodes: The nodes to check for data dependencies.
  60. Returns:
  61. True if any pair in nodes has a data dependency.
  62. """
  63. # For each pair in nodes:
  64. for i, j in itertools.combinations(nodes, 2):
  65. if may_depend_on(i, j) or may_depend_on(j, i):
  66. return False
  67. return True
def merge_matmul(in_mod: torch.nn.Module):
    """
    A graph transformation that merges matrix multiplication operations that
    share the same right-hand side operand into one large matrix
    multiplication.

    Conceptually, for LHS operands A (M x K) and B (T x K) that share the RHS
    operand C (K x R), the two products A @ C and B @ C are computed as the
    single product cat([A, B]) @ C; the (M + T) x R result is then split back
    into the individual A @ C and B @ C chunks.

    Arguments:
        in_mod: The module to transform. It is symbolically traced first, so
                it must be traceable by torch.fx.

    Returns:
        A new torch.fx.GraphModule with eligible matmuls merged; the input
        module itself is not mutated.
    """
    gm = symbolic_trace(in_mod)

    rhs_users: Dict[Node, List[Node]] = {}
    lhs_users: Dict[Node, List[Node]] = {}

    # Populate rhs_users and lhs_users - maps from LHS/RHS matrix multiply operands to
    # the matmul of which they are the LHS/RHS.
    for node in gm.graph.nodes:
        if node.op != "call_function" or node.target is not torch.matmul:
            continue

        lhs, rhs = node.args

        # TODO: Properly handle aliasing caused by get_attr. For now,
        # use the attribute name as the operand if the node is a
        # get_attr.
        lhs = lhs.target if lhs.op == "get_attr" else lhs
        rhs = rhs.target if rhs.op == "get_attr" else rhs

        lhs_users.setdefault(lhs, []).append(node)
        rhs_users.setdefault(rhs, []).append(node)

    for rhs, mms in rhs_users.items():
        # There must be at least two matmuls for a merge to make sense.
        if len(mms) < 2:
            continue

        # All matmuls must not depend on each other directly or indirectly
        # in order for the merge to be possible.
        if not are_nodes_independent(mms):
            continue

        lhs_vals = [mm.args[0] for mm in mms]

        # Merge the matmul.
        # Collect a list of LHS operands and the single RHS operand.
        # String operands came from the get_attr aliasing workaround above and
        # are turned back into get_attr nodes here.
        lhs = [gm.graph.get_attr(l) if isinstance(l, str) else l for l in lhs_vals]
        rhs = gm.graph.get_attr(rhs) if isinstance(rhs, str) else rhs

        # Concatenate all the LHS operands.
        merge_mm_cat = gm.graph.call_function(torch.cat, (lhs,), {})

        # Multiply the concatenated LHS operands with the one RHS. This will produce
        # the same results as all the individual matmuls involving rhs in the original graph,
        # but they will all be concatenated together.
        merge_mm = gm.graph.call_function(torch.matmul, (merge_mm_cat, rhs,), {})

        # Split the result of the merged matmul using the shapes of the LHS operands
        # to ascertain how large each chunk should be.
        merge_mm_split = gm.graph.call_function(
            split_result_tensors, (merge_mm, lhs), {}
        )
        merge_mm_res = [
            gm.graph.call_function(operator.getitem, (merge_mm_split, out), {})
            for out in range(len(lhs))
        ]

        # Replace all uses of the original, unmerged matmuls with the equivalent split chunk from the merged matmul.
        for old, new in zip(mms, merge_mm_res):
            old.replace_all_uses_with(new)
            gm.graph.erase_node(old)

    # All of the new nodes created above were inserted at the end, so we need to sort
    # the nodes topologically to make sure all definitions precede uses.
    legalize_graph(gm)

    gm.recompile()
    gm.graph.lint()
    return gm