partitioner_utils.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319
  1. from enum import Enum
  2. from typing import NamedTuple, Dict, List, Set
  3. from torch.fx.node import Node, map_arg
  4. class Partition:
  5. """Partition class contains all the information about an individual partition.
  6. It also provides necessary methods for manipulation the partition.
  7. """
  8. def __init__(self, partition_id: int) -> None:
  9. self.nodes: Set[Node] = set()
  10. self.partition_id = partition_id
  11. self.parents: Set["Partition"] = set()
  12. self.children: Set["Partition"] = set()
  13. self.bfs_level: int = -1
  14. self.used_mem_bytes: int = 0
  15. self.logical_device_ids: List[int] = []
  16. def __str__(self):
  17. return str(self.partition_id)
  18. def recalculate_mem_size(self):
  19. self.used_mem_bytes = 0
  20. for node in self.nodes:
  21. self.used_mem_bytes += get_extra_size_of(node, self.nodes)
  22. def add_node(self, node):
  23. input_nodes: Dict[Node, None] = {}
  24. map_arg(node.args, lambda n: input_nodes.setdefault(n))
  25. map_arg(node.kwargs, lambda n: input_nodes.setdefault(n))
  26. # Add current node's input nodes if they are placeholder or constants
  27. for n in input_nodes:
  28. if n.op in {"placeholder", "get_attr"}:
  29. self.nodes.add(n)
  30. self.nodes.add(node)
  31. self.recalculate_mem_size()
  32. def remove_node(self, node):
  33. # Remove a node only if the node is in the partition
  34. if node in self.nodes:
  35. self.nodes.remove(node)
  36. # Collect the node's input nodes
  37. input_nodes: Dict[Node, None] = {}
  38. map_arg(node.args, lambda n: input_nodes.setdefault(n))
  39. map_arg(node.kwargs, lambda n: input_nodes.setdefault(n))
  40. # Check if an input node is a placeholder or get_attr,
  41. # and this input node is not used by some other nodes in this partition,
  42. # the remove this input node
  43. for input_node in input_nodes:
  44. if all(
  45. [n not in self.nodes for n in input_node.users]
  46. ) and input_node.op in {"placeholder", "get_attr"}:
  47. self.nodes.remove(input_node)
  48. self.recalculate_mem_size()
  49. class Device(NamedTuple):
  50. name: str
  51. available_mem_bytes: int
  52. logical_id: int
  53. class NodeLatency(NamedTuple):
  54. # Latency due to the memory bandwidth
  55. mem_latency_sec: float
  56. # Latency due to the computation
  57. computer_latency_sec: float
  58. class PartitionLatency(NamedTuple):
  59. # Sum of all nodes' memory latency on the critical path
  60. mem_latency_sec: float
  61. # Sum of all nodes' compute latency on the critical path
  62. computer_latency_sec: float
  63. # Latency of the critical path
  64. overall_latency_sec: float
  65. class PartitionMode(Enum):
  66. size_based = 0
  67. sparse_nn = 1
  68. cost_aware = 2
  69. kl_based = 3
  70. aot_based = 4
  71. class PartitionerConfig(NamedTuple):
  72. devices: List[Device]
  73. mode: PartitionMode = PartitionMode.size_based
  74. transfer_rate_bytes_per_sec: float = 0.0
  75. node_to_latency_mapping: Dict[Node, NodeLatency] = {}
  76. node_to_partition_mapping: Dict[Node, int] = {}
  77. partition_to_logical_device_mapping: Dict[int, List[int]] = {}
  78. # Saturate host by replicating partitions to the remaining idle devices.
  79. saturate_host: bool = False
  80. def get_extra_size_of(node: Node, nodes: Set[Node]) -> int:
  81. """Given a node and a set of nodes,
  82. this function return the extra size that needed
  83. if this node is included in this set.
  84. """
  85. # Find all its input nodes
  86. input_nodes: Dict[Node, None] = {}
  87. map_arg(node.args, lambda n: input_nodes.setdefault(n))
  88. map_arg(node.kwargs, lambda n: input_nodes.setdefault(n))
  89. # Calculate total size of related nodes
  90. total_size_of_input_nodes = 0
  91. for n in input_nodes:
  92. # Make sure this node hasn't been in this set yet
  93. if n not in nodes:
  94. size_bytes = getattr(n, "size_bytes", None)
  95. if size_bytes:
  96. total_size_of_input_nodes += size_bytes.output_size
  97. else:
  98. raise RuntimeError("node has no size_bytes attr")
  99. # Don't forget the op node itself
  100. size_bytes = getattr(node, "size_bytes", None)
  101. if size_bytes:
  102. total_size_of_input_nodes += size_bytes.total_size
  103. else:
  104. raise RuntimeError("node has no size_bytes attr")
  105. return total_size_of_input_nodes
  106. def get_latency_of_one_partition(
  107. partition: Partition, node_to_latency_mapping: Dict[Node, NodeLatency]
  108. ) -> PartitionLatency:
  109. """Given a partiton and its nodes' latency, return a PartitionLatency for this partition"""
  110. def get_top_nodes(partition: Partition) -> List[Node]:
  111. """Given a partition, return a list of nodes on the top bfs level"""
  112. top_nodes: List[Node] = []
  113. for node in partition.nodes:
  114. # Skip placeholder and get_attr nodes
  115. if node.op in {"placeholder", "get_attr"}:
  116. continue
  117. input_nodes: Dict[Node, None] = {}
  118. map_arg(node.args, lambda n: input_nodes.setdefault(n))
  119. map_arg(node.kwargs, lambda n: input_nodes.setdefault(n))
  120. # If a node has no input nodes in this partition,
  121. # or its input nodes in this partition are placeholders and get_attrs
  122. # this node is on the top bfs level in this partition
  123. if not any(
  124. [
  125. n in partition.nodes and n.op not in {"placeholder", "get_attr"}
  126. for n in input_nodes
  127. ]
  128. ):
  129. top_nodes.append(node)
  130. return top_nodes
  131. def dfs_helper(node: Node, partition_latency) -> PartitionLatency:
  132. """Given a top node of a partition, this function returns
  133. the latency of the critical path in the partition
  134. """
  135. node_latency = node_to_latency_mapping[node]
  136. # Calculate the current overall latency of the partition
  137. overall_latency_sec = partition_latency.overall_latency_sec + max(
  138. node_latency.computer_latency_sec, node_latency.mem_latency_sec
  139. )
  140. # Update the mem latency of this path
  141. mem_latency_sec = (
  142. partition_latency.mem_latency_sec + node_latency.mem_latency_sec
  143. )
  144. # Update the compute latency of this path
  145. computer_latency_sec = (
  146. partition_latency.computer_latency_sec + node_latency.computer_latency_sec
  147. )
  148. # Get all users of this node that are in this partition
  149. users = set(node.users).intersection(partition.nodes)
  150. if users:
  151. max_latency = PartitionLatency(
  152. mem_latency_sec=0.0, computer_latency_sec=0.0, overall_latency_sec=0.0
  153. )
  154. for n in users:
  155. # Get new partition latency recursively
  156. new_partition_latency = dfs_helper(
  157. n,
  158. PartitionLatency(
  159. mem_latency_sec, computer_latency_sec, overall_latency_sec
  160. ),
  161. )
  162. if (
  163. new_partition_latency.overall_latency_sec
  164. > max_latency.overall_latency_sec
  165. ):
  166. max_latency = new_partition_latency
  167. return max_latency
  168. # If there is no user, the node is at bottom of the partition
  169. return PartitionLatency(
  170. mem_latency_sec, computer_latency_sec, overall_latency_sec
  171. )
  172. # Main part starts
  173. # Get all top level nodes of this partition
  174. top_nodes = get_top_nodes(partition)
  175. critical_path_latency = PartitionLatency(
  176. mem_latency_sec=0.0, computer_latency_sec=0.0, overall_latency_sec=0.0
  177. )
  178. # Go through all top nodes and find the largest latency (critical pass latency)
  179. for node in top_nodes:
  180. partition_latency = dfs_helper(
  181. node,
  182. PartitionLatency(
  183. mem_latency_sec=0.0, computer_latency_sec=0.0, overall_latency_sec=0.0
  184. ),
  185. )
  186. if (
  187. partition_latency.overall_latency_sec
  188. > critical_path_latency.overall_latency_sec
  189. ):
  190. critical_path_latency = partition_latency
  191. return critical_path_latency
  192. def get_partition_to_latency_mapping(
  193. partitions: List[Partition], node_to_latency_mapping: Dict[Node, NodeLatency]
  194. ) -> Dict[Partition, PartitionLatency]:
  195. """Given all the partitions and node_to_latency_mapping dictionary,
  196. return a mapping dictionary of each partition to its overall latency
  197. """
  198. partition_to_latency_mapping: Dict[Partition, PartitionLatency] = {}
  199. # Go through each partition and get its latency
  200. for partition in partitions:
  201. partition_latency = get_latency_of_one_partition(
  202. partition, node_to_latency_mapping
  203. )
  204. partition_to_latency_mapping[partition] = partition_latency
  205. return partition_to_latency_mapping
  206. def get_comm_latency_between(
  207. parent_partition: Partition,
  208. child_partition: Partition,
  209. transfer_rate_bytes_per_sec: float,
  210. ):
  211. """Given two partitions (parent and child),
  212. calculate the communication latency between the two.
  213. """
  214. # If two partitions are on the same device, the comm latency is 0.
  215. if (
  216. parent_partition.logical_device_ids != []
  217. and child_partition.logical_device_ids != []
  218. and parent_partition.logical_device_ids == child_partition.logical_device_ids
  219. ):
  220. return 0.0
  221. # Keep tracking the communication size between parent and child
  222. comm_size = 0
  223. # Keep tracking all the counted node
  224. visited_nodes = set()
  225. # Go through all nodes in the child partition
  226. # If a node has input nodes from the parent partition,
  227. # the output size of those input nodes will be counted
  228. # and added to comm_size
  229. for node in child_partition.nodes:
  230. input_nodes: Dict[Node, None] = {}
  231. map_arg(node.args, lambda n: input_nodes.setdefault(n))
  232. map_arg(node.kwargs, lambda n: input_nodes.setdefault(n))
  233. for n in input_nodes:
  234. if n in parent_partition.nodes and n not in visited_nodes:
  235. size_bytes = getattr(n, "size_bytes", None)
  236. if size_bytes is not None:
  237. comm_size += size_bytes.output_size
  238. visited_nodes.add(n)
  239. return comm_size / transfer_rate_bytes_per_sec
  240. def get_latency_of_partitioned_graph(
  241. partitions: List[Partition],
  242. partition_to_latency_mapping: Dict[Partition, PartitionLatency],
  243. transfer_rate_bytes_per_sec: float,
  244. ):
  245. """Given all paritions in a graph, find the critical path among all partitions
  246. and return its latency as the latency of the whole graph
  247. """
  248. def dfs_helper(partition: Partition, latency_so_far_sec: float) -> float:
  249. """This function helps to recursively get the latency of a path of partitions"""
  250. # Update latency by adding current partition's latency
  251. latency_so_far_sec += partition_to_latency_mapping[
  252. partition
  253. ].overall_latency_sec
  254. children = partition.children
  255. if partition.children:
  256. max_latency_sec = 0.0
  257. for child in partition.children:
  258. # Calculate latency between
  259. comm_latency_sec = get_comm_latency_between(
  260. partition, child, transfer_rate_bytes_per_sec
  261. )
  262. new_latency_sec = dfs_helper(
  263. child, latency_so_far_sec + comm_latency_sec
  264. )
  265. if new_latency_sec > max_latency_sec:
  266. max_latency_sec = new_latency_sec
  267. return max_latency_sec
  268. return latency_so_far_sec
  269. def get_top_partitions(partitions: List[Partition]) -> List[Partition]:
  270. """This function is to return all the partitions without parents
  271. as the starting points of all the paths
  272. """
  273. top_partitions = []
  274. for partition in partitions:
  275. # If a partition has no parents, then it is a top partition
  276. if len(partition.parents) == 0:
  277. top_partitions.append(partition)
  278. return top_partitions
  279. top_partitions = get_top_partitions(partitions)
  280. critical_path_latency_sec = 0.0
  281. for partition in top_partitions:
  282. latency_sec = dfs_helper(partition, 0.0)
  283. if latency_sec > critical_path_latency_sec:
  284. critical_path_latency_sec = latency_sec
  285. return critical_path_latency_sec