123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319 |
- from enum import Enum
- from typing import NamedTuple, Dict, List, Set
- from torch.fx.node import Node, map_arg
class Partition:
    """Partition class contains all the information about an individual partition.
    It also provides necessary methods for manipulating the partition.
    """

    def __init__(self, partition_id: int) -> None:
        """Create an empty partition identified by ``partition_id``."""
        # Nodes that belong to this partition (op nodes plus the
        # placeholder/get_attr nodes pulled in by add_node).
        self.nodes: Set[Node] = set()
        self.partition_id = partition_id
        # Neighbouring partitions; maintained by callers, not by this class.
        self.parents: Set["Partition"] = set()
        self.children: Set["Partition"] = set()
        # -1 presumably means "not assigned yet" -- confirm against the partitioner.
        self.bfs_level: int = -1
        # Cached memory footprint, refreshed by recalculate_mem_size().
        self.used_mem_bytes: int = 0
        self.logical_device_ids: List[int] = []

    def __str__(self):
        """Return the partition id as a string."""
        return str(self.partition_id)

    def recalculate_mem_size(self):
        """Recompute ``used_mem_bytes`` from scratch.

        The result is the sum of ``get_extra_size_of(node, self.nodes)``
        over every node currently in the partition.
        """
        self.used_mem_bytes = 0
        for node in self.nodes:
            self.used_mem_bytes += get_extra_size_of(node, self.nodes)

    def add_node(self, node):
        """Add ``node`` to the partition and refresh the memory size.

        Any direct input of ``node`` that is a placeholder or get_attr node
        is added to the partition as well.
        """
        input_nodes: Dict[Node, None] = {}
        map_arg(node.args, lambda n: input_nodes.setdefault(n))
        map_arg(node.kwargs, lambda n: input_nodes.setdefault(n))
        # Add current node's input nodes if they are placeholders or constants
        for n in input_nodes:
            if n.op in {"placeholder", "get_attr"}:
                self.nodes.add(n)
        self.nodes.add(node)
        self.recalculate_mem_size()

    def remove_node(self, node):
        """Remove ``node`` from the partition and refresh the memory size.

        Placeholder/get_attr inputs of ``node`` that no remaining node in the
        partition uses are dropped too. Nothing happens if ``node`` is not in
        the partition.
        """
        # Remove a node only if the node is in the partition
        if node not in self.nodes:
            return
        self.nodes.remove(node)
        # Collect the node's input nodes
        input_nodes: Dict[Node, None] = {}
        map_arg(node.args, lambda n: input_nodes.setdefault(n))
        map_arg(node.kwargs, lambda n: input_nodes.setdefault(n))
        # If an input node is a placeholder or get_attr and it is not used
        # by any other node in this partition, then remove this input node.
        for input_node in input_nodes:
            if input_node.op in {"placeholder", "get_attr"} and all(
                n not in self.nodes for n in input_node.users
            ):
                # discard, not remove: the input may never have been added
                # (e.g. the node was inserted into self.nodes directly),
                # and remove() would raise KeyError in that case.
                self.nodes.discard(input_node)
        self.recalculate_mem_size()
class Device(NamedTuple):
    """Immutable description of one device that partitions can be placed on."""

    # Device name.
    name: str
    # Memory available on the device, in bytes (per the field name).
    available_mem_bytes: int
    # Logical id of the device; presumably the value stored in
    # Partition.logical_device_ids -- confirm against the partitioner.
    logical_id: int
class NodeLatency(NamedTuple):
    """Latency of a single node, split into memory and compute parts (seconds)."""

    # Latency due to the memory bandwidth
    mem_latency_sec: float
    # Latency due to the computation
    computer_latency_sec: float
class PartitionLatency(NamedTuple):
    """Latency of a partition, measured along its critical path (seconds)."""

    # Sum of all nodes' memory latency on the critical path
    mem_latency_sec: float
    # Sum of all nodes' compute latency on the critical path
    computer_latency_sec: float
    # Latency of the critical path
    overall_latency_sec: float
class PartitionMode(Enum):
    """Closed set of partitioning modes selectable via PartitionerConfig.mode."""

    # The strategy behind each member is implemented outside this chunk;
    # only the names and integer values are defined here.
    size_based = 0
    sparse_nn = 1
    cost_aware = 2
    kl_based = 3
    aot_based = 4
class PartitionerConfig(NamedTuple):
    """Immutable bundle of settings handed to the partitioner.

    Only ``devices`` is required; every other field has a default.
    """

    # Devices that partitions may be assigned to.
    devices: List[Device]
    # Partitioning mode; defaults to size_based.
    mode: PartitionMode = PartitionMode.size_based
    # Transfer rate in bytes per second. The default of 0.0 presumably means
    # "not provided" -- confirm before using it as a divisor.
    transfer_rate_bytes_per_sec: float = 0.0
    # NOTE(review): the three `{}` defaults below are created once at class
    # definition time, so every config that relies on the default shares the
    # same dict object; do not mutate them in place.
    node_to_latency_mapping: Dict[Node, NodeLatency] = {}
    node_to_partition_mapping: Dict[Node, int] = {}
    partition_to_logical_device_mapping: Dict[int, List[int]] = {}
    # Saturate host by replicating partitions to the remaining idle devices.
    saturate_host: bool = False
def get_extra_size_of(node: Node, nodes: Set[Node]) -> int:
    """Return the extra bytes needed if ``node`` is included in ``nodes``.

    The result is ``node.size_bytes.total_size`` plus the ``output_size`` of
    every direct input of ``node`` that is not already a member of ``nodes``.

    Raises:
        RuntimeError: if ``node`` or one of the counted inputs has no
            (truthy) ``size_bytes`` attribute.
    """

    def _size_info_of(target):
        # Fetch the size record attached to a node, or fail loudly.
        info = getattr(target, "size_bytes", None)
        if not info:
            raise RuntimeError("node has no size_bytes attr")
        return info

    # Gather the distinct direct inputs, from positional and keyword args.
    producers: Dict[Node, None] = {}
    for arguments in (node.args, node.kwargs):
        map_arg(arguments, lambda n: producers.setdefault(n))

    extra_bytes = 0
    for producer in producers:
        # Inputs already in the set are accounted for elsewhere.
        if producer in nodes:
            continue
        extra_bytes += _size_info_of(producer).output_size

    # Don't forget the op node itself
    extra_bytes += _size_info_of(node).total_size
    return extra_bytes
def get_latency_of_one_partition(
    partition: Partition, node_to_latency_mapping: Dict[Node, NodeLatency]
) -> PartitionLatency:
    """Given a partition and its nodes' latency, return a PartitionLatency for this partition.

    The result describes the path with the largest overall latency, where
    each node contributes max(compute latency, memory latency) to the
    overall figure. An all-zero PartitionLatency is returned when no path
    has a positive overall latency.
    """
    io_ops = {"placeholder", "get_attr"}
    zero_latency = PartitionLatency(
        mem_latency_sec=0.0, computer_latency_sec=0.0, overall_latency_sec=0.0
    )

    def get_top_nodes(partition: Partition) -> List[Node]:
        """Given a partition, return a list of nodes on the top bfs level"""
        entry_nodes: List[Node] = []
        for candidate in partition.nodes:
            # Placeholders and get_attrs are never path starting points.
            if candidate.op in io_ops:
                continue
            producers: Dict[Node, None] = {}
            for arguments in (candidate.args, candidate.kwargs):
                map_arg(arguments, lambda n: producers.setdefault(n))
            # A node is on the top level when each of its inputs is either
            # outside this partition or a placeholder/get_attr.
            if all(
                producer not in partition.nodes or producer.op in io_ops
                for producer in producers
            ):
                entry_nodes.append(candidate)
        return entry_nodes

    def dfs_helper(node: Node, partition_latency) -> PartitionLatency:
        """Return the latency of the slowest path that continues through
        ``node``, given the latency accumulated before reaching it.
        """
        own_latency = node_to_latency_mapping[node]
        # Extend the running totals of this path with the current node.
        path_latency = PartitionLatency(
            mem_latency_sec=partition_latency.mem_latency_sec
            + own_latency.mem_latency_sec,
            computer_latency_sec=partition_latency.computer_latency_sec
            + own_latency.computer_latency_sec,
            overall_latency_sec=partition_latency.overall_latency_sec
            + max(own_latency.computer_latency_sec, own_latency.mem_latency_sec),
        )
        # Only users that live in this partition continue the path.
        consumers = set(node.users).intersection(partition.nodes)
        if not consumers:
            # No user in the partition: the node is at the bottom.
            return path_latency
        slowest = zero_latency
        for consumer in consumers:
            candidate = dfs_helper(consumer, path_latency)
            if candidate.overall_latency_sec > slowest.overall_latency_sec:
                slowest = candidate
        return slowest

    # Main part starts: try every top-level node and keep the slowest
    # path (the critical path latency).
    critical_path_latency = zero_latency
    for top_node in get_top_nodes(partition):
        candidate = dfs_helper(top_node, zero_latency)
        if candidate.overall_latency_sec > critical_path_latency.overall_latency_sec:
            critical_path_latency = candidate
    return critical_path_latency
def get_partition_to_latency_mapping(
    partitions: List[Partition], node_to_latency_mapping: Dict[Node, NodeLatency]
) -> Dict[Partition, PartitionLatency]:
    """Map every partition in ``partitions`` to its PartitionLatency.

    Each latency is computed by ``get_latency_of_one_partition`` with the
    given ``node_to_latency_mapping``.
    """
    return {
        current: get_latency_of_one_partition(current, node_to_latency_mapping)
        for current in partitions
    }
def get_comm_latency_between(
    parent_partition: Partition,
    child_partition: Partition,
    transfer_rate_bytes_per_sec: float,
) -> float:
    """Given two partitions (parent and child),
    calculate the communication latency between the two.

    The latency is the total ``size_bytes.output_size`` of the parent
    partition's nodes that are consumed by nodes of the child partition
    (each counted once), divided by ``transfer_rate_bytes_per_sec``.

    Returns:
        0.0 when both partitions have the same non-empty
        ``logical_device_ids`` or when no bytes cross the boundary;
        otherwise the transfer time in seconds.

    Raises:
        ZeroDivisionError: if data has to be transferred but
            ``transfer_rate_bytes_per_sec`` is zero.
    """
    # If two partitions are on the same device, the comm latency is 0.
    if (
        parent_partition.logical_device_ids != []
        and child_partition.logical_device_ids != []
        and parent_partition.logical_device_ids == child_partition.logical_device_ids
    ):
        return 0.0
    # Keep tracking the communication size between parent and child
    comm_size = 0
    # Keep tracking all the counted nodes so a parent node feeding several
    # child nodes is only counted once
    visited_nodes = set()
    # Go through all nodes in the child partition.
    # If a node has input nodes from the parent partition,
    # the output size of those input nodes is added to comm_size.
    for node in child_partition.nodes:
        input_nodes: Dict[Node, None] = {}
        map_arg(node.args, lambda n: input_nodes.setdefault(n))
        map_arg(node.kwargs, lambda n: input_nodes.setdefault(n))
        for n in input_nodes:
            if n in parent_partition.nodes and n not in visited_nodes:
                size_bytes = getattr(n, "size_bytes", None)
                # Nodes without size information contribute nothing.
                if size_bytes is not None:
                    comm_size += size_bytes.output_size
                visited_nodes.add(n)
    # Nothing to transfer: the latency is zero whatever the transfer rate
    # is. This avoids dividing by an unset (0.0) rate.
    if comm_size == 0:
        return 0.0
    if transfer_rate_bytes_per_sec == 0:
        raise ZeroDivisionError(
            "transfer_rate_bytes_per_sec must be non-zero to compute the "
            "communication latency between partitions"
        )
    return comm_size / transfer_rate_bytes_per_sec
def get_latency_of_partitioned_graph(
    partitions: List[Partition],
    partition_to_latency_mapping: Dict[Partition, PartitionLatency],
    transfer_rate_bytes_per_sec: float,
) -> float:
    """Given all partitions in a graph, find the critical path among all partitions
    and return its latency as the latency of the whole graph.

    A path starts at a partition without parents and follows ``children``
    links; its latency is the sum of the partitions' ``overall_latency_sec``
    plus the communication latency of every parent-to-child hop.

    Returns:
        The largest path latency in seconds, or 0.0 if there is no
        parent-less partition (including when ``partitions`` is empty).
    """

    def dfs_helper(partition: Partition, latency_so_far_sec: float) -> float:
        """This function helps to recursively get the latency of a path of partitions"""
        # Update latency by adding current partition's latency
        latency_so_far_sec += partition_to_latency_mapping[
            partition
        ].overall_latency_sec
        if not partition.children:
            # Bottom of the path: nothing more to add.
            return latency_so_far_sec
        max_latency_sec = 0.0
        for child in partition.children:
            # Cost of moving data from this partition to the child
            comm_latency_sec = get_comm_latency_between(
                partition, child, transfer_rate_bytes_per_sec
            )
            max_latency_sec = max(
                max_latency_sec,
                dfs_helper(child, latency_so_far_sec + comm_latency_sec),
            )
        return max_latency_sec

    def get_top_partitions(partitions: List[Partition]) -> List[Partition]:
        """This function is to return all the partitions without parents
        as the starting points of all the paths
        """
        # If a partition has no parents, then it is a top partition
        return [partition for partition in partitions if len(partition.parents) == 0]

    critical_path_latency_sec = 0.0
    for partition in get_top_partitions(partitions):
        critical_path_latency_sec = max(
            critical_path_latency_sec, dfs_helper(partition, 0.0)
        )
    return critical_path_latency_sec
|