const_fold.py

import re
from typing import Callable, Dict, Optional, Set, Union

import torch.fx
from torch.fx.node import map_arg
from torch.fx.passes.split_module import split_module

__all__ = ['FoldedGraphModule', 'get_unique_attr_name_in_module', 'split_const_subgraphs']

class FoldedGraphModule(torch.fx.GraphModule):
    """
    FoldedGraphModule is a GraphModule which also contains another
    `const_subgraph_module` representing a subgraph which has all const attr
    inputs and which can be run once before running the main standard
    `graph`. `fx_const_folded_attrs_name` is the name of the attr on which the
    folded output(s) of the const subgraph are stored (a single Parameter, or a
    ParameterList when the subgraph has multiple outputs).
    """

    def __init__(
        self,
        root: torch.nn.Module,
        graph: torch.fx.Graph,
        const_subgraph: Optional[torch.fx.Graph] = None,
        fx_const_folded_attrs_name: Optional[str] = None,
        device_for_folded_attrs: str = "cuda",
    ):
        super().__init__(root, graph)
        self.const_subgraph_module = (
            None
            if const_subgraph is None
            else torch.fx.GraphModule(root, const_subgraph)
        )
        self.has_folding_been_run = False
        self.fx_const_folded_attrs_name = fx_const_folded_attrs_name
        self.device_for_folded_attrs = device_for_folded_attrs

    def __call__(self, *args, **kwargs):
        if not self.has_folding_been_run:
            self.run_folding()
        return super().__call__(*args, **kwargs)

    def run_folding(self):
        # If there's no const subgraph module or attr output names to use, return
        # early as there is no const folding to perform.
        if (
            self.const_subgraph_module is None
            or self.fx_const_folded_attrs_name is None
        ):
            return

        assert not self.has_folding_been_run
        self.has_folding_been_run = True

        # Actually run const folding subgraph. Note that single attr const fold
        # subgraphs output a single Tensor while multiple outputs are returned as
        # Tuple[Tensor,].
        folded_attrs = self.const_subgraph_module()

        def _create_param(i):
            return torch.nn.Parameter(
                i
                if not isinstance(i, int)
                else torch.Tensor([i]).to(device=self.device_for_folded_attrs),
                requires_grad=i.requires_grad if isinstance(i, torch.Tensor) else False,
            )

        params = (
            torch.nn.ParameterList([_create_param(i) for i in folded_attrs])
            if isinstance(folded_attrs, tuple)
            else _create_param(folded_attrs)
        )
        setattr(self, self.fx_const_folded_attrs_name, params)


def _inline_module(gm: torch.fx.GraphModule, inline_mod_name: str):
    """
    Given `gm` and some graph module which is called with target name `inline_mod_name`,
    this helper will inline all of the nodes from that called graph module into `gm`.
    """
    # Fetch the inner graph module that we want to inline inside `gm`.
    inline_mod = dict(gm.named_modules())[inline_mod_name]
    assert isinstance(inline_mod, torch.fx.GraphModule)
    call_mod_node_to_replace = None
    for node in gm.graph.nodes:
        if node.op == "call_module" and node.target == inline_mod_name:
            call_mod_node_to_replace = node
            break
    assert call_mod_node_to_replace is not None

    # Now actually do the swap. Note that we have to keep track of new nodes that are
    # copied into `gm` -- we do this via replacement_mapping.
    call_mod_args = call_mod_node_to_replace.args
    replacement_mapping: Dict[torch.fx.Node, torch.fx.Node] = {}
    ph_count = 0

    def replacement_fn(node):
        new_node = replacement_mapping[node]
        new_node.meta = node.meta.copy()
        return new_node

    for inline_node in inline_mod.graph.nodes:
        if inline_node.op == "placeholder":
            replacement_mapping[inline_node] = call_mod_args[ph_count]
            ph_count += 1
            continue

        if inline_node.op == "output":
            outputs = inline_node.args[0]
            output_replacements = map_arg(outputs, replacement_fn)
            call_mod_node_to_replace.replace_all_uses_with(output_replacements)
            continue

        with gm.graph.inserting_before(call_mod_node_to_replace):
            new_node = gm.graph.node_copy(inline_node, replacement_fn)
            replacement_mapping[inline_node] = new_node

    gm.graph.eliminate_dead_code()
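
# Illustrative sketch of what `_inline_module` does (node/target names below are hypothetical):
# if `gm` contains `%y = call_module[target=sub](args = (%x,))` and `sub`'s graph is
#     %p   : placeholder
#     %out : call_function[target=operator.add](args = (%p, 1))
#     return out
# then after inlining, `gm` gains `%out = call_function[target=operator.add](args = (%x, 1))`,
# every user of `%y` is redirected to `%out`, and the now-unused call_module node is removed
# by `eliminate_dead_code()`.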


def get_unique_attr_name_in_module(mod_traced: torch.fx.GraphModule, name: str) -> str:
    """
    Make sure the name is unique (in a module) and can represent an attr.
    """
    # Delete all characters that are illegal in a Python identifier.
    name = re.sub("[^0-9a-zA-Z_]+", "_", name)
    if name[0].isdigit():
        name = f"_{name}"
    # Now make sure it is in fact unique to the module by incrementing suffix value.
    while hasattr(mod_traced, name):
        match = re.match(r"(.*)_(\d+)$", name)
        if match is None:
            name = name + "_1"
        else:
            base, num = match.group(1, 2)
            name = f"{base}_{int(num) + 1}"

    return name
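
# Illustrative examples for get_unique_attr_name_in_module (module contents are hypothetical):
#   get_unique_attr_name_in_module(mod, "1st weight")  -> "_1st_weight"
#   get_unique_attr_name_in_module(mod, "foo")         -> "foo_1" if `mod` already has "foo",
#                                                         "foo_2" if it also has "foo_1", etc.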


def split_const_subgraphs(
    module: Union[torch.nn.Module, torch.fx.GraphModule],
    skip_folding_node_fn: Optional[Callable[[torch.fx.Node], bool]] = None,
    device_for_folded_attrs: str = "cpu",
) -> FoldedGraphModule:
    """
    Looks through `module` for any nodes that have all constant attribute inputs
    and separates them out into their own constant subgraph, and returns a
    FoldedGraphModule which runs that constant subgraph on the first run to set
    attributes on the module prior to running the non-constant portion of the
    graph.
    """
    if not isinstance(module, torch.fx.GraphModule):
        mod_traced = torch.fx.symbolic_trace(module)
    else:
        mod_traced = module

    # Build up a list of const_nodes, defined as nodes that are themselves
    # get_attrs, or have all get_attr or other constant node inputs.
    const_nodes: Set[torch.fx.Node] = set()
    found_const_folding = False
    for node in mod_traced.graph.nodes:
        # Skip over placeholders/outputs because they can't be const folded and
        # we don't want to add tags to them.
        if node.op in {"placeholder", "output"}:
            continue

        # If the node itself is constant, or all of its inputs are constant,
        # then tag it as constant.
        if node.op != "get_attr" and not set(node.all_input_nodes).issubset(
            const_nodes
        ):
            continue

        # If a skip-folding callback was provided and it says to skip this node, skip it.
        if skip_folding_node_fn and skip_folding_node_fn(node):
            continue

        # Skip folding side-effectful functions.
        if node.is_impure():
            continue

        # Must be a constant foldable node at this point.
        const_nodes.add(node)
        if node.op != "get_attr":
            found_const_folding = True

    # If we did not find any const folding then return early without a const fold subgraph.
    if not found_const_folding:
        return FoldedGraphModule(mod_traced, mod_traced.graph)

    # Partition the module into two: submod_0 for constant folding subgraph, and
    # submod_1 for the rest.
    def mod_partition(node: torch.fx.Node):
        return 0 if node in const_nodes else 1

    split = split_module(mod_traced, module, mod_partition)

    const_gm, non_const_gm = split.submod_0, split.submod_1
    const_mod_name, non_const_mod_name = "submod_0", "submod_1"

    # The module that a call_module node refers to gets copied to submodules during split.
    # The path to the module also gets inlined, i.e. mod.a.b -> mod_a_b. Here we need to
    # attach inlined modules to `split` as it's the owning module now.
    for node in non_const_gm.graph.nodes:
        if node.op == "call_module":
            setattr(split, node.target, getattr(non_const_gm, node.target))
    for node in const_gm.graph.nodes:
        if node.op == "call_module":
            setattr(split, node.target, getattr(const_gm, node.target))

    # split_module currently does not use get_attrs for attrs. Instead it passes
    # them in as args from the parent module, which used get_attrs. Here we set
    # them as get_attrs inside const_gm, allowing for running folding without
    # somehow a priori knowing the attrs that should be passed as args. We can
    # unconditionally do this for all placeholders because we know all
    # placeholders to const_gm must be constants accessible via get_attr.
    call_const_gm_args = None
    for node in split.graph.nodes:
        if node.op == "call_module":
            if node.target == const_mod_name:
                call_const_gm_args = node.args
                break
    assert call_const_gm_args is not None

    # Here we do the actual replacement of placeholders to get_attrs. Note that here we
    # set the const_gm.graph into a new root_const_gm with split as the root module,
    # because we are fetching attributes directly from the root module, instead of
    # fetching them from const_gm. Example: The const_gm must have some format like:
    #    graph():
    #        %inp : [#users=1] = placeholder[target=const_inp]
    #        %add : [#users=1] = call_function[target=operator.add](args = (%inp, %inp), kwargs = {})
    #        return add
    # We replace that with the following, which does not have any placeholders:
    #    graph():
    #        %inp_1 : [#users=1] = get_attr[target=const_inp]
    #        %add : [#users=1] = call_function[target=operator.add](args = (%inp_1, %inp_1), kwargs = {})
    #        return add
    root_const_gm = torch.fx.GraphModule(split, const_gm.graph)
    for node in root_const_gm.graph.nodes:
        if node.op == "output":
            multiple_outputs = isinstance(node.args[0], tuple)
            continue
        if node.op != "placeholder":
            continue
        in_node = next(n for n in call_const_gm_args if n.name == node.target)
        assert in_node.op == "get_attr"
        with root_const_gm.graph.inserting_before(node):
            new_node = root_const_gm.graph.get_attr(in_node.target)
            new_node.meta = node.meta.copy()
        node.replace_all_uses_with(new_node)
        root_const_gm.graph.erase_node(node)
    assert "multiple_outputs" in locals()

    # Now find the call to const_gm inside split, and replace it with a getattr to the
    # folded tensor(s) that result from constant folding. Note that we don't need to
    # worry about whether this is one or more tensors because the original graph
    # correctly uses getitem to extract individual tensors if there are multiple folded.
    fx_const_folded_attrs_name = get_unique_attr_name_in_module(
        split, "_FX_CONST_FOLDED_ATTRS"
    )
    setattr(
        split,
        fx_const_folded_attrs_name,
        torch.nn.ParameterList() if multiple_outputs else torch.nn.Parameter(),
    )
    for node in split.graph.nodes:
        if node.op == "call_module" and node.target == const_mod_name:
            with node.graph.inserting_before(node):
                folded_attrs = node.graph.get_attr(fx_const_folded_attrs_name)
                folded_attrs.meta = node.meta.copy()
            node.replace_all_uses_with(folded_attrs)
            break

    split.graph.eliminate_dead_code()

    # Finally, inline the non-constant submod into the split submod. This is so that the
    # original caller who may have passed in a graph module will get back out a graph
    # module whose graph is traced to the same granularity.
    _inline_module(split, non_const_mod_name)

    return FoldedGraphModule(
        split,
        split.graph,
        root_const_gm.graph,
        fx_const_folded_attrs_name,
        device_for_folded_attrs,
    )
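

# Minimal usage sketch (illustrative only; the example module and shapes below are
# hypothetical and not part of this file's API). `split_const_subgraphs` traces the module,
# extracts `self.weight + 1` into the const subgraph, folds it into an attribute on the
# first call, and then runs only the non-constant portion of the graph.
if __name__ == "__main__":
    class _Example(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.weight = torch.nn.Parameter(torch.randn(4))

        def forward(self, x):
            # `self.weight + 1` depends only on a constant attr, so it gets const folded.
            return x + (self.weight + 1)

    folded = split_const_subgraphs(_Example())
    print(folded(torch.randn(4)))  # first call runs folding, then the main graph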