123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169 |
- import torch
- from torch.fx.node import Node
- from torch.fx._symbolic_trace import symbolic_trace
- from torch.fx.passes.tools_common import legalize_graph
- import itertools
- import operator
- from typing import Dict, List
def split_result_tensors(
    result: torch.Tensor, inputs: List[torch.Tensor]
) -> List[torch.Tensor]:
    """
    Slice the output of a merged matmul back into per-input results.

    A free function used by the merge_matmul graph transformation below:
    after several matmuls sharing a right-hand side are fused into one, this
    splits the fused result row-wise so each original matmul gets its chunk.

    Arguments:
        result: The merged matmul result tensor.
        inputs: The list of input tensors that were merged into one for the
            matmul.

    Returns:
        The matmul results for each input tensor, in the order of ``inputs``.
    """
    # While fx is tracing, ``result`` is a Proxy and ``x.shape[0]`` would not
    # be a concrete int, so fall back to dummy split sizes of zero; the real
    # sizes are only needed when the transformed graph actually executes.
    tracing = isinstance(result, torch.fx.Proxy)
    split_sizes = [0 if tracing else tensor.shape[0] for tensor in inputs]
    return torch.split(result, split_sizes)
def may_depend_on(a: Node, b: Node, search_depth: int = 6):
    """
    Conservatively determine whether one node depends on another in a
    torch.fx.Graph.

    Arguments:
        a: The node that may have a dependency on b.
        b: The node that a may have a dependency on.
        search_depth: In the case of an indirect dependency, search up to
            this many producer hops upstream of ``a``. If no answer is found
            within the budget, conservatively assume a dependency exists.

    Returns:
        True if a may depend on b, False if it definitely does not.
    """
    # A node is defined to depend on itself.
    if a == b:
        return True

    producers = a.all_input_nodes

    # A node with no producers cannot depend on anything.
    if not producers:
        return False

    # Budget exhausted without a definite answer: conservatively report a
    # possible dependency.
    if search_depth == 0:
        return True

    # Otherwise recurse into every producer of ``a``.
    return any(may_depend_on(producer, b, search_depth - 1) for producer in producers)
def are_nodes_independent(nodes: List[Node]):
    """
    Check if all of the given nodes are pairwise-data independent.

    Arguments:
        nodes: The nodes to check for data dependencies.

    Returns:
        True if no pair of nodes has a data dependency on the other (i.e.
        all nodes are pairwise independent), False otherwise.
    """
    # For each unordered pair, check dependence in both directions, since
    # may_depend_on is not symmetric.
    for i, j in itertools.combinations(nodes, 2):
        if may_depend_on(i, j) or may_depend_on(j, i):
            return False

    return True
def merge_matmul(in_mod: torch.nn.Module) -> torch.fx.GraphModule:
    """
    A graph transformation that merges matrix multiplication operations that share the same right-hand
    side operand into one large matrix multiplication.

         ----         ----                    ---------
        M|  A |      T|  B |       ----     M| A @ C  |
         ----    ,    ----    @  K|  C |  =  |--------|
          K            K          ----      T| B @ C  |
                                    R        ---------
                                                 R

    Arguments:
        in_mod: The module to transform.

    Returns:
        A new GraphModule with eligible matmuls merged. ``in_mod`` itself is
        not modified; the transformation operates on a symbolic trace of it.
    """
    gm = symbolic_trace(in_mod)

    rhs_users: Dict[Node, List[Node]] = {}
    lhs_users: Dict[Node, List[Node]] = {}

    # Populate rhs_users and lhs_users - maps from LHS/RHS matrix multiply operands to
    # the matmul of which they are the LHS/RHS.
    for node in gm.graph.nodes:
        # Only direct calls to torch.matmul are candidates for merging.
        if node.op != "call_function" or node.target is not torch.matmul:
            continue

        lhs, rhs = node.args

        # TODO: Properly handle aliasing caused by get_attr. For now,
        # use the attribute name as the operand if the node is a
        # get_attr.
        lhs = lhs.target if lhs.op == "get_attr" else lhs
        rhs = rhs.target if rhs.op == "get_attr" else rhs

        lhs_users.setdefault(lhs, []).append(node)
        rhs_users.setdefault(rhs, []).append(node)

    for rhs, mms in rhs_users.items():
        # There must be at least two matmuls for a merge to make sense.
        if len(mms) < 2:
            continue

        # All matmuls must not depend on each other directly or indirectly
        # in order for the merge to be possible.
        if not are_nodes_independent(mms):
            continue

        lhs_vals = [mm.args[0] for mm in mms]

        # Merge the matmul.
        # Collect a list of LHS operands and the single RHS operand.
        # Operands recorded by attribute name (see TODO above) are turned
        # back into get_attr nodes here.
        lhs = [gm.graph.get_attr(l) if isinstance(l, str) else l for l in lhs_vals]
        rhs = gm.graph.get_attr(rhs) if isinstance(rhs, str) else rhs

        # Concatenate all the LHS operands.
        merge_mm_cat = gm.graph.call_function(torch.cat, (lhs,), {})

        # Multiply the concatenated LHS operands with the one RHS. This will produce
        # the same results as all the individual matmuls involving rhs in the original graph,
        # but they will all be concatenated together.
        merge_mm = gm.graph.call_function(torch.matmul, (merge_mm_cat, rhs,), {})

        # Split the result of the merged matmul using the shapes of the LHS operands
        # to ascertain how large each chunk should be.
        merge_mm_split = gm.graph.call_function(
            split_result_tensors, (merge_mm, lhs), {}
        )
        # One getitem node per original matmul, picking its chunk out of the split.
        merge_mm_res = [
            gm.graph.call_function(operator.getitem, (merge_mm_split, out), {})
            for out in range(len(lhs))
        ]

        # Replace all uses of the original, unmerged matmuls with the equivalent split chunk from the merged matmul.
        for old, new in zip(mms, merge_mm_res):
            old.replace_all_uses_with(new)
            gm.graph.erase_node(old)

    # All of the new nodes created above were inserted at the end, so we need to sort
    # the nodes topologically to make sure all definitions precede uses.
    legalize_graph(gm)

    gm.recompile()
    gm.graph.lint()
    return gm
|