# partitioner.py
import itertools
import logging
from collections import deque
from copy import copy
from typing import Deque, Dict, Iterable, List, Optional, Sequence, Set

from torch.fx.graph_module import GraphModule
from torch.fx.node import Node, _get_qualified_name
from torch.fx.passes.operator_support import OperatorSupportBase
from torch.fx.passes.utils.fuser_utils import fuse_by_partitions
# Module-level logger for partitioning diagnostics; WARNING keeps the debug
# tracing in this file silent unless explicitly enabled.
# NOTE(review): calling setLevel on a library logger overrides user logging
# configuration — confirm this is intended.
logger = logging.getLogger(__name__)
logger.setLevel(logging.WARNING)
  12. class Partition:
  13. def __init__(self, id: int = None, nodes: Iterable[Node] = None):
  14. self.id = id
  15. self.nodes: Set[Node] = set(nodes) if nodes is not None else set()
  16. def __repr__(self) -> str:
  17. return str(self.nodes)
  18. def add_node(self, node: Node):
  19. self.nodes.add(node)
  20. def remove_node(self, node: Node):
  21. self.nodes.remove(node)
  22. def size(self):
  23. return len(self.nodes)
  24. class CapabilityBasedPartitioner:
  25. def __init__(self,
  26. graph_module: GraphModule,
  27. operator_support: OperatorSupportBase,
  28. allows_single_node_partition: bool = False,
  29. non_compute_ops: Optional[Sequence[str]] = None,
  30. allowed_single_node_partition_ops: Optional[Sequence[str]] = None,
  31. ) -> None:
  32. self.graph_module = graph_module
  33. self.operator_support = operator_support
  34. self.allows_single_node_partition = allows_single_node_partition
  35. self.non_compute_ops = non_compute_ops if non_compute_ops is not None else []
  36. self.allowed_single_node_partition_ops = (
  37. allowed_single_node_partition_ops
  38. if allowed_single_node_partition_ops is not None
  39. else []
  40. )
  41. def __is_node_supported(self, node: Node) -> bool:
  42. return (
  43. self.operator_support.is_node_supported(dict(self.graph_module.named_modules()), node)
  44. )
  45. def propose_partitions(self) -> List[Partition]:
  46. # assumptions: nodes in candidate list is sorted in topological order
  47. assignment: Dict[Node, int] = {} # maping from node to partition_id
  48. partitions_by_id: Dict[int, Partition] = {} # mapping from partition_id to partition
  49. new_partition_id = itertools.count()
  50. # try to merge partition other_id into partition self_id
  51. # merge only happens if the end graph doesn't contain cyclic dependency
  52. # returns `True` when merge happens, `False` otherwise.
  53. def maybe_merge_partition(self_id: int, other_id: int):
  54. # merged_nodes is the union of nodes in two partition to-be-merged
  55. merged_nodes = copy(partitions_by_id[self_id].nodes)
  56. merged_nodes.update(partitions_by_id[other_id].nodes)
  57. # Note it's ok to use `set` here, since we are only query if a node
  58. # has been visited. We are NEVER going to iterate on nodes inside
  59. # the set.
  60. visited: Set[Node] = set()
  61. def dfs_iter_find_cycle(root_node):
  62. stack : Deque[Node] = deque()
  63. stack.append(root_node)
  64. while stack:
  65. node = stack.pop()
  66. if node in visited:
  67. continue
  68. if node in merged_nodes:
  69. return True # found cycle, return
  70. # branching on hitting partition or not
  71. if node in assignment:
  72. # Since partition is not merged in the graph yet, when we
  73. # hit a node in a partition through DFS, we need to
  74. # traverse all nodes in the partition to properly reflect
  75. # dependencies after the fusion
  76. for p_node in partitions_by_id[assignment[node]].nodes:
  77. for user_node in p_node.users:
  78. if user_node not in partitions_by_id[assignment[node]].nodes:
  79. stack.append(user_node)
  80. else:
  81. for user_node in node.users:
  82. stack.append(user_node)
  83. visited.add(node)
  84. return False
  85. # check if merge would create cyclic dependency.
  86. for node in merged_nodes:
  87. for user_node in node.users:
  88. if user_node not in merged_nodes and dfs_iter_find_cycle(user_node):
  89. # return false indicating cyclic dependency found and
  90. # merge is aborted
  91. return False
  92. # no cyclic dependency found, move forward with the merge
  93. # updating partition nodes
  94. partitions_by_id[self_id].nodes = merged_nodes
  95. # updating assignment map
  96. for node in partitions_by_id[other_id].nodes:
  97. assignment[node] = self_id
  98. # delete other partition
  99. del partitions_by_id[other_id]
  100. return True
  101. def merge_single_node(node: Node, id: Optional[int]):
  102. if node in assignment:
  103. partitions_by_id[assignment[node]].remove_node(node)
  104. if id is None:
  105. assignment.pop(node)
  106. elif id not in partitions_by_id:
  107. assignment[node] = id
  108. partitions_by_id[id] = Partition(id=id, nodes=[node])
  109. else:
  110. assignment[node] = id
  111. partitions_by_id[id].add_node(node)
  112. logger.debug("Proposing partitions...")
  113. for node in reversed(self.graph_module.graph.nodes):
  114. # use Dict as an ordered set to ensure deterministic partitioning result, don't care value
  115. merge_candidates: Dict[int, None] = {}
  116. # Note a limited horizontal fusion is enabled:
  117. # when `node` is not supported, the code below attempts to fuse consumer of `node`.
  118. #
  119. # I don't see a need to add a knob to disable horizontal fusion yet, we can short-cut
  120. # the fusion by adding an `else` block here to skip horizontal fusion.
  121. if self.__is_node_supported(node) and node not in assignment:
  122. partition_id = next(new_partition_id)
  123. merge_single_node(node, partition_id)
  124. merge_candidates[partition_id] = None
  125. for user_node in node.users:
  126. if user_node in assignment:
  127. merge_candidates[assignment[user_node]] = None
  128. merge_candidates_list = list(merge_candidates.keys())
  129. if len(merge_candidates_list) > 1:
  130. self_id = merge_candidates_list[0]
  131. for other_id in merge_candidates_list[1:]:
  132. # note: merge partition `other_id` into partition `self_id` if
  133. # it doesn't create cyclic depenency in the graph, otherwise,
  134. # this is a no-op
  135. maybe_merge_partition(self_id, other_id)
  136. # post processing to re-assign "getitem" nodes into upstream partition
  137. logger.debug("Reassigning getitem nodes to its producer node's partition...")
  138. nodes_reassignment: Dict[Node, int] = {}
  139. for node in self.graph_module.graph.nodes:
  140. is_tuple_output = True
  141. for user in node.users:
  142. if user.op != "call_function" or \
  143. _get_qualified_name(user.target) != "_operator.getitem": # type: ignore[arg-type]
  144. is_tuple_output = False
  145. break
  146. # node has tuple outputs, re-assign all following getitem node into node's partition
  147. if is_tuple_output:
  148. id = assignment.get(node, None) # type: ignore[arg-type]
  149. for user in node.users:
  150. if assignment.get(user, None) != id: # type: ignore[arg-type]
  151. nodes_reassignment[user] = id # type: ignore[assignment]
  152. for node, id in nodes_reassignment.items():
  153. merge_single_node(node, id)
  154. # filter out single node partitions
  155. if not self.allows_single_node_partition:
  156. logger.debug("Filtering out single node partitions...")
  157. default_non_compute_ops = {"torch.ops.aten.view", "_operator.getitem"}
  158. non_compute_ops = default_non_compute_ops.union(set(self.non_compute_ops))
  159. partitions_to_remove: List[int] = []
  160. for id, partition in partitions_by_id.items():
  161. compute_node_count = 0
  162. for node in partition.nodes:
  163. if node.op == "call_function" and \
  164. _get_qualified_name(node.target) not in non_compute_ops: # type: ignore[arg-type]
  165. compute_node_count += 1
  166. if node.op == "call_function" and \
  167. _get_qualified_name(node.target) in self.allowed_single_node_partition_ops:
  168. compute_node_count += 1
  169. if compute_node_count <= 1:
  170. partitions_to_remove.append(id)
  171. for id in partitions_to_remove:
  172. del partitions_by_id[id]
  173. logger.debug("Partitions proposed:")
  174. for id, partition in partitions_by_id.items():
  175. logger.debug(f"partition #{id}", [node.name for node in partition.nodes])
  176. return list(partitions_by_id.values())
  177. def fuse_partitions(self, partitions: List[Partition]) -> GraphModule:
  178. logger.debug("Fusing partitions...")
  179. # fuse_by_partitions expects partitions in List[List[Node]]: [ [node0, node1], [node2, node3] ]
  180. return fuse_by_partitions(self.graph_module, [list(partition.nodes) for partition in partitions])
  181. # remove non-compute-ops that sits at the boundary of a partition.
  182. def remove_bookend_non_compute_ops(self, partitions: List[Partition]):
  183. non_compute_ops = set(self.non_compute_ops)
  184. def is_non_compute_node(node: Node):
  185. return node.op == "call_function" and \
  186. _get_qualified_name(node.target) in non_compute_ops # type: ignore[arg-type]
  187. # cache transparent nodes
  188. transparent_input_nodes: Dict[Node, bool] = {}
  189. transparent_output_nodes: Dict[Node, bool] = {}
  190. def is_transparent_input_node(node: Node, partition: Set[Node], removed_nodes: Set[Node]):
  191. if node.op == "placeholder" or (node not in partition) or (node in removed_nodes):
  192. return True
  193. if node in transparent_input_nodes:
  194. return transparent_input_nodes[node]
  195. if is_non_compute_node(node):
  196. for input_n in node.all_input_nodes:
  197. if not is_transparent_input_node(input_n, partition, removed_nodes):
  198. transparent_input_nodes[node] = False
  199. return False
  200. transparent_input_nodes[node] = True
  201. return True
  202. transparent_input_nodes[node] = False
  203. return False
  204. def is_transparent_output_node(node: Node, partition: Set[Node], removed_nodes: Set[Node]):
  205. if node.op == "placeholder" or (node not in partition) or (node in removed_nodes):
  206. return True
  207. if node in transparent_output_nodes:
  208. return transparent_output_nodes[node]
  209. if is_non_compute_node(node):
  210. for output_n in node.users:
  211. if not is_transparent_output_node(output_n, partition, removed_nodes):
  212. transparent_output_nodes[node] = False
  213. return False
  214. transparent_output_nodes[node] = True
  215. return True
  216. transparent_output_nodes[node] = False
  217. return False
  218. for partition in partitions:
  219. # Note it's ok to use `set` here, since we are only query if a node
  220. # has been removed. We are NEVER going to iterate on nodes inside
  221. # the set.
  222. remove_node: Set[Node] = set()
  223. for node in partition.nodes:
  224. if is_non_compute_node(node) and \
  225. (is_transparent_input_node(node, partition.nodes, remove_node) or
  226. is_transparent_output_node(node, partition.nodes, remove_node)):
  227. remove_node.add(node)
  228. if len(remove_node) != 0:
  229. partition.nodes = partition.nodes - remove_node
  230. def partition_and_fuse(self) -> GraphModule:
  231. partitions = self.propose_partitions()
  232. fused_gm = self.fuse_partitions(partitions)
  233. return fused_gm