123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173 |
- from collections import deque
- from typing import List, Set
class DiGraph:
    """Really simple unweighted directed graph data structure to track dependencies.

    The API is pretty much the same as networkx so if you add something just
    copy their API.
    """

    def __init__(self):
        # Dict of node -> dict of arbitrary attributes
        self._node = {}
        # Nested dict of node -> successor node -> True.
        # (didn't implement edge data)
        self._succ = {}
        # Nested dict of node -> predecessor node -> True.
        self._pred = {}

        # Keep track of the order in which nodes are added to
        # the graph (node -> insertion index); used by first_path.
        self._node_order = {}
        self._insertion_idx = 0

    def add_node(self, n, **kwargs):
        """Add a node to the graph.

        Args:
            n: the node. Can be any object that is a valid dict key.
            **kwargs: any attributes you want to attach to the node. If the
                node already exists, these are merged into its attributes.
        """
        if n not in self._node:
            self._node[n] = kwargs
            self._succ[n] = {}
            self._pred[n] = {}
            self._node_order[n] = self._insertion_idx
            self._insertion_idx += 1
        else:
            self._node[n].update(kwargs)

    def add_edge(self, u, v):
        """Add an edge to graph between nodes ``u`` and ``v``.

        ``u`` and ``v`` will be created if they do not already exist
        (``u`` is inserted before ``v``).
        """
        # add nodes
        self.add_node(u)
        self.add_node(v)

        # add the edge
        self._succ[u][v] = True
        self._pred[v][u] = True

    def successors(self, n):
        """Returns an iterator over successor nodes of n.

        Raises:
            ValueError: if ``n`` is not a node of the graph.
        """
        try:
            return iter(self._succ[n])
        except KeyError as e:
            raise ValueError(f"The node {n} is not in the digraph.") from e

    def predecessors(self, n):
        """Returns an iterator over predecessor nodes of n.

        Raises:
            ValueError: if ``n`` is not a node of the graph.
        """
        try:
            return iter(self._pred[n])
        except KeyError as e:
            raise ValueError(f"The node {n} is not in the digraph.") from e

    @property
    def edges(self):
        """Returns an iterator over all edges (u, v) in the graph"""
        for n, successors in self._succ.items():
            for succ in successors:
                yield n, succ

    @property
    def nodes(self):
        """Returns a dictionary of all nodes to their attributes."""
        return self._node

    def __iter__(self):
        """Iterate over the nodes."""
        return iter(self._node)

    def __contains__(self, n):
        """Returns True if ``n`` is a node in the graph, False otherwise."""
        try:
            return n in self._node
        except TypeError:
            # Unhashable objects can never be nodes.
            return False

    def _reachable(self, src, neighbors) -> Set[str]:
        """Breadth-first set of nodes reachable from ``src`` (inclusive) via ``neighbors``."""
        # Wrap src in a container: set(src)/deque(src) would split a string
        # node name into its individual characters.
        result = {src}
        working_set = deque([src])
        while working_set:
            cur = working_set.popleft()
            for n in neighbors(cur):
                if n not in result:
                    result.add(n)
                    working_set.append(n)
        return result

    def forward_transitive_closure(self, src: str) -> Set[str]:
        """Returns a set of nodes that are reachable from src (including src).

        Raises:
            ValueError: if ``src`` is not a node of the graph.
        """
        return self._reachable(src, self.successors)

    def backward_transitive_closure(self, src: str) -> Set[str]:
        """Returns a set of nodes that are reachable from src in reverse direction (including src).

        Raises:
            ValueError: if ``src`` is not a node of the graph.
        """
        return self._reachable(src, self.predecessors)

    def all_paths(self, src: str, dst: str) -> str:
        """Returns the dot representation of the subgraph of all paths from src to dst.

        The subgraph holds every edge that lies on some path from ``src`` to
        ``dst``. If ``dst`` is not reachable from ``src``, the dot
        representation of an empty graph is returned.

        Raises:
            ValueError: if ``src`` is not a node of the graph.
        """
        result_graph = DiGraph()
        # First compute forward transitive closure of src (all things reachable from src).
        forward_reachable_from_src = self.forward_transitive_closure(src)

        if dst not in forward_reachable_from_src:
            return result_graph.to_dot()

        # Second walk the reverse dependencies of dst, adding each edge to
        # the output graph iff its source is also in forward_reachable_from_src.
        # We don't use backward_transitive_closure for optimization purposes.
        working_set = deque([dst])
        # Each node is expanded once; without this the walk revisits nodes
        # repeatedly and never terminates when the region contains a cycle.
        visited = {dst}
        while working_set:
            cur = working_set.popleft()
            for n in self.predecessors(cur):
                if n in forward_reachable_from_src:
                    result_graph.add_edge(n, cur)
                    # only explore further if it's reachable from src
                    if n not in visited:
                        visited.add(n)
                        working_set.append(n)

        return result_graph.to_dot()

    def first_path(self, dst: str) -> List[str]:
        """Returns a list of nodes that show the first path that resulted in dst being added to the graph.

        Starting at ``dst``, repeatedly steps to the predecessor with the
        smallest insertion index. The walk stops at a node with no
        predecessors, or when that predecessor is already on the path, which
        prevents looping forever on cycles. The returned path ends with ``dst``.

        Raises:
            KeyError: if ``dst`` is not a node of the graph.
        """
        path = [dst]
        on_path = {dst}
        cur = dst
        while True:
            preds = self._pred[cur]
            if not preds:
                break
            cur = min(preds, key=self._node_order.__getitem__)
            if cur in on_path:
                break
            path.append(cur)
            on_path.add(cur)
        path.reverse()
        return path

    def to_dot(self) -> str:
        """Returns the dot representation of the graph.

        Only edges are emitted; nodes without any edges do not appear.

        Returns:
            A dot representation of the graph.
        """
        edges = "\n".join(f'"{f}" -> "{t}";' for f, t in self.edges)
        return f"""\
digraph G {{
rankdir = LR;
node [shape=box];
{edges}
}}
"""
|