_digraph.py 5.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173
  1. from collections import deque
  2. from typing import List, Set
  3. class DiGraph:
  4. """Really simple unweighted directed graph data structure to track dependencies.
  5. The API is pretty much the same as networkx so if you add something just
  6. copy their API.
  7. """
  8. def __init__(self):
  9. # Dict of node -> dict of arbitrary attributes
  10. self._node = {}
  11. # Nested dict of node -> successor node -> nothing.
  12. # (didn't implement edge data)
  13. self._succ = {}
  14. # Nested dict of node -> predecessor node -> nothing.
  15. self._pred = {}
  16. # Keep track of the order in which nodes are added to
  17. # the graph.
  18. self._node_order = {}
  19. self._insertion_idx = 0
  20. def add_node(self, n, **kwargs):
  21. """Add a node to the graph.
  22. Args:
  23. n: the node. Can we any object that is a valid dict key.
  24. **kwargs: any attributes you want to attach to the node.
  25. """
  26. if n not in self._node:
  27. self._node[n] = kwargs
  28. self._succ[n] = {}
  29. self._pred[n] = {}
  30. self._node_order[n] = self._insertion_idx
  31. self._insertion_idx += 1
  32. else:
  33. self._node[n].update(kwargs)
  34. def add_edge(self, u, v):
  35. """Add an edge to graph between nodes ``u`` and ``v``
  36. ``u`` and ``v`` will be created if they do not already exist.
  37. """
  38. # add nodes
  39. self.add_node(u)
  40. self.add_node(v)
  41. # add the edge
  42. self._succ[u][v] = True
  43. self._pred[v][u] = True
  44. def successors(self, n):
  45. """Returns an iterator over successor nodes of n."""
  46. try:
  47. return iter(self._succ[n])
  48. except KeyError as e:
  49. raise ValueError(f"The node {n} is not in the digraph.") from e
  50. def predecessors(self, n):
  51. """Returns an iterator over predecessors nodes of n."""
  52. try:
  53. return iter(self._pred[n])
  54. except KeyError as e:
  55. raise ValueError(f"The node {n} is not in the digraph.") from e
  56. @property
  57. def edges(self):
  58. """Returns an iterator over all edges (u, v) in the graph"""
  59. for n, successors in self._succ.items():
  60. for succ in successors:
  61. yield n, succ
  62. @property
  63. def nodes(self):
  64. """Returns a dictionary of all nodes to their attributes."""
  65. return self._node
  66. def __iter__(self):
  67. """Iterate over the nodes."""
  68. return iter(self._node)
  69. def __contains__(self, n):
  70. """Returns True if ``n`` is a node in the graph, False otherwise."""
  71. try:
  72. return n in self._node
  73. except TypeError:
  74. return False
  75. def forward_transitive_closure(self, src: str) -> Set[str]:
  76. """Returns a set of nodes that are reachable from src"""
  77. result = set(src)
  78. working_set = deque(src)
  79. while len(working_set) > 0:
  80. cur = working_set.popleft()
  81. for n in self.successors(cur):
  82. if n not in result:
  83. result.add(n)
  84. working_set.append(n)
  85. return result
  86. def backward_transitive_closure(self, src: str) -> Set[str]:
  87. """Returns a set of nodes that are reachable from src in reverse direction"""
  88. result = set(src)
  89. working_set = deque(src)
  90. while len(working_set) > 0:
  91. cur = working_set.popleft()
  92. for n in self.predecessors(cur):
  93. if n not in result:
  94. result.add(n)
  95. working_set.append(n)
  96. return result
  97. def all_paths(self, src: str, dst: str):
  98. """Returns a subgraph rooted at src that shows all the paths to dst."""
  99. result_graph = DiGraph()
  100. # First compute forward transitive closure of src (all things reachable from src).
  101. forward_reachable_from_src = self.forward_transitive_closure(src)
  102. if dst not in forward_reachable_from_src:
  103. return result_graph
  104. # Second walk the reverse dependencies of dst, adding each node to
  105. # the output graph iff it is also present in forward_reachable_from_src.
  106. # we don't use backward_transitive_closures for optimization purposes
  107. working_set = deque(dst)
  108. while len(working_set) > 0:
  109. cur = working_set.popleft()
  110. for n in self.predecessors(cur):
  111. if n in forward_reachable_from_src:
  112. result_graph.add_edge(n, cur)
  113. # only explore further if its reachable from src
  114. working_set.append(n)
  115. return result_graph.to_dot()
  116. def first_path(self, dst: str) -> List[str]:
  117. """Returns a list of nodes that show the first path that resulted in dst being added to the graph."""
  118. path = []
  119. while dst:
  120. path.append(dst)
  121. candidates = self._pred[dst].keys()
  122. dst, min_idx = "", None
  123. for candidate in candidates:
  124. idx = self._node_order.get(candidate, None)
  125. if idx is None:
  126. break
  127. if min_idx is None or idx < min_idx:
  128. min_idx = idx
  129. dst = candidate
  130. return list(reversed(path))
  131. def to_dot(self) -> str:
  132. """Returns the dot representation of the graph.
  133. Returns:
  134. A dot representation of the graph.
  135. """
  136. edges = "\n".join(f'"{f}" -> "{t}";' for f, t in self.edges)
  137. return f"""\
  138. digraph G {{
  139. rankdir = LR;
  140. node [shape=box];
  141. {edges}
  142. }}
  143. """