# compile_utils.py
import torch
import torch.fx as fx
from torch.utils._pytree import tree_flatten

# Shorthand for the ATen operator namespace used throughout this module.
aten = torch.ops.aten
  5. def get_aten_target(node):
  6. if hasattr(node.target, 'overloadpacket'):
  7. return node.target.overloadpacket
  8. return node.target
# Operators whose results depend on RNG state.  Two syntactically identical
# calls to any of these produce different values, so CSE must never merge them.
rand_ops = [aten.dropout, aten._fused_dropout, aten._standard_gamma,
            aten.bernoulli, aten.multinomial, aten.native_dropout,
            aten.normal, aten.poisson, aten.binomial, aten.rrelu,
            aten.rand_like, aten.rand, aten.randint, aten.randn, aten.randperm]
  13. # return a new copy of torch.fx.graph.Graph with CSE applied to the input graph
  14. def fx_graph_cse(fx_g: torch.fx.graph.Graph):
  15. new_graph = fx.Graph()
  16. env = {} # map from node in the old graph to node in the new graph
  17. hash_env = {} # map from hash to a node in the new graph
  18. token_map = {} # map from hash to token
  19. for n in fx_g.nodes:
  20. # The placeholder, output, and get_attr nodes are copied to the new grpah without change
  21. # do not CSE away random operations
  22. if n.op == 'placeholder' or n.op == 'output' or n.op == 'get_attr' or get_aten_target(n) in rand_ops:
  23. new_node = new_graph.node_copy(n, lambda x: env[x])
  24. env[n] = new_node
  25. else: # n.op == 'call_function', should never see n.op == 'call_module' or 'call_method'
  26. # substitute args and kwargs memebrs to their mapping in env if exists
  27. # specs can be used to reconstruct nested list/dictionaries
  28. def substitute(arg_list):
  29. arg_list, spec = tree_flatten(arg_list)
  30. for i in range(len(arg_list)):
  31. v = arg_list[i]
  32. if isinstance(v, torch.fx.node.Node) and v in env:
  33. arg_list[i] = env[v]
  34. return tuple(arg_list), spec
  35. args, args_spec = substitute(n.args)
  36. kwargs, kwargs_spec = substitute(n.kwargs)
  37. # each token corresponds to a unique node
  38. # nodes with the same token can be substituted
  39. token = {"target": n.target, "args": args, "args_spec": args_spec,
  40. "kwargs": kwargs, "kwargs_spec": kwargs_spec}
  41. # hash substituted args to a number, do not hash specs because specs are not hashable
  42. hash_arg = hash((args, kwargs))
  43. hash_val = (n.target, hash_arg)
  44. # check if a node has a substitute and can be eliminated
  45. hash_val_in_hash_env = hash_val in hash_env
  46. if hash_val_in_hash_env and token_map[hash_val] == token:
  47. env[n] = hash_env[hash_val]
  48. continue
  49. new_node = new_graph.node_copy(n, lambda x: env[x])
  50. env[n] = new_node
  51. if not hash_val_in_hash_env:
  52. hash_env[hash_val] = new_node
  53. token_map[hash_val] = token
  54. return new_graph
  55. def strip_overloads(gm):
  56. """
  57. Modifies the target of graph nodes in :attr:`gm` to strip overloads.
  58. Args:
  59. gm(fx.GraphModule): The input Fx graph module to be modified
  60. """
  61. for node in gm.graph.nodes:
  62. if isinstance(node.target, torch._ops.OpOverload):
  63. node.target = node.target.overloadpacket
  64. gm.recompile()
  65. def get_placeholders(graph):
  66. return list(filter(lambda x: x.op == 'placeholder', graph.nodes))
  67. def get_outputs(graph):
  68. for node in graph.nodes:
  69. if node.op == 'output':
  70. return tree_flatten(node.args[0])[0]
  71. raise AssertionError("No output node found")