# cse_pass.py — common-subexpression-elimination pass for torch.fx graphs
  1. from typing import Dict, Tuple, Any
  2. import torch
  3. from torch.fx.passes.infra.pass_base import PassBase, PassResult
  4. from torch.utils._pytree import tree_flatten
  5. from torch.fx import GraphModule, Graph
  6. from torch.fx import Node
  7. aten = torch.ops.aten
  8. # stateful ops are banned from CSE
  9. rand_ops = {aten.dropout, aten._fused_dropout, aten._standard_gamma, aten.bernoulli, aten.multinomial, aten.native_dropout, aten.normal, aten.poisson, aten.binomial, aten.rrelu, aten.rand_like, aten.rand, aten.randint, aten.randn, aten.randperm} # noqa: E501
  10. inplace_ops = {aten.add_, aten.sub_, aten.mul_, aten.div_, aten.pow_, aten.lerp_, aten.relu_, aten.sigmoid_, aten.tanh_} # noqa: E501
  11. @torch.fx._compatibility.compatibility(is_backward_compatible=False)
  12. def get_CSE_banned_ops():
  13. return rand_ops.union(inplace_ops)
@torch.fx._compatibility.compatibility(is_backward_compatible=False)
class CSEPass(PassBase):
    def __init__(self, banned_ops=None):
        """
        This version of CSE Pass aims to be dialect agnostic, and it's
        implemented purely based on the connectivity between fx.Node.

        For functional dialects, the user only needs to specify the random
        ops in the ban list.

        Warning: CSE Pass cannot be safely applied on a FX graph in
        non-functional dialects. If your dialect contains stateful
        operators, please customize ``banned_ops``.

        Args:
            banned_ops: collection of node targets (compared at the
                overload-packet level for aten ops) that must never be
                eliminated by CSE. Defaults to an empty set.
        """
        if banned_ops is None:
            banned_ops = set()
        self.banned_ops = banned_ops
        super().__init__()

    def call(self, graph_module: GraphModule) -> PassResult:
        """
        Return a new copy of torch.fx.GraphModule with CSE applied to the input graph.

        Example usage:

            from torch.fx.experimental.proxy_tensor import make_fx
            def f(a):
                b = a * a
                c = a * a
                return b+c
            p = CSEPass()
            traced_graph = make_fx(f)(torch.tensor(1))
            print(traced_graph)
            result = p(traced_graph)
            print(result.graph_module)
        """
        def get_aten_target(node):
            # Normalize an aten OpOverload to its overload packet so that
            # membership tests against banned_ops match every overload.
            if hasattr(node.target, 'overloadpacket'):
                return node.target.overloadpacket
            return node.target

        modified = False
        new_graph = Graph()
        env: Dict[Node, Node] = {}  # map from node in the old graph to node in the new graph
        hash_env: Dict[Tuple[torch._ops.OpOverload, int], Node] = {}  # map from hash to a node in the new graph
        token_map: Dict[Tuple[torch._ops.OpOverload, int], Dict[str, Any]] = {}  # map from hash to token
        for n in graph_module.graph.nodes:
            # The placeholder, output, and get_attr nodes are copied to the new graph without change
            # do not CSE away random operations
            if n.op == 'placeholder' or n.op == 'output' or n.op == 'get_attr' or get_aten_target(n) in self.banned_ops:
                new_node = new_graph.node_copy(n, lambda x: env[x])
                env[n] = new_node
            else:  # n.op == 'call_function', should never see n.op == 'call_module' or 'call_method'
                # substitute args and kwargs members to their mapping in env if exists
                # specs can be used to reconstruct nested list/dictionaries
                def substitute(arg_list):
                    arg_list, spec = tree_flatten(arg_list)
                    for i in range(len(arg_list)):
                        v = arg_list[i]
                        if isinstance(v, Node) and v in env:
                            arg_list[i] = env[v]
                    return tuple(arg_list), spec

                args, args_spec = substitute(n.args)
                kwargs, kwargs_spec = substitute(n.kwargs)

                # each token corresponds to a unique node
                # nodes with the same token can be substituted
                token = {"target": n.target, "args": args, "args_spec": args_spec,
                         "kwargs": kwargs, "kwargs_spec": kwargs_spec}

                # hash substituted args to a number, do not hash specs because specs are not hashable
                hash_arg = hash((args, kwargs))
                hash_val = (n.target, hash_arg)

                # check if a node has a substitute and can be eliminated;
                # the full token comparison guards against hash collisions
                hash_val_in_hash_env = hash_val in hash_env
                if hash_val_in_hash_env and token_map[hash_val] == token:
                    modified = True  # substitution happens and the graph is modified
                    env[n] = hash_env[hash_val]
                    continue

                new_node = new_graph.node_copy(n, lambda x: env[x])
                env[n] = new_node
                if not hash_val_in_hash_env:
                    # only the first node seen for a hash claims the slot; on a
                    # collision with a different token the slot is left as-is,
                    # so the colliding node is copied but never used as a source
                    hash_env[hash_val] = new_node
                    token_map[hash_val] = token

        csed_gm = GraphModule(graph_module, new_graph)
        return PassResult(csed_gm, modified)