steinertree.py 7.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218
  1. from itertools import chain
  2. import networkx as nx
  3. from networkx.utils import not_implemented_for, pairwise
  4. __all__ = ["metric_closure", "steiner_tree"]
  5. @not_implemented_for("directed")
  6. def metric_closure(G, weight="weight"):
  7. """Return the metric closure of a graph.
  8. The metric closure of a graph *G* is the complete graph in which each edge
  9. is weighted by the shortest path distance between the nodes in *G* .
  10. Parameters
  11. ----------
  12. G : NetworkX graph
  13. Returns
  14. -------
  15. NetworkX graph
  16. Metric closure of the graph `G`.
  17. """
  18. M = nx.Graph()
  19. Gnodes = set(G)
  20. # check for connected graph while processing first node
  21. all_paths_iter = nx.all_pairs_dijkstra(G, weight=weight)
  22. u, (distance, path) = next(all_paths_iter)
  23. if Gnodes - set(distance):
  24. msg = "G is not a connected graph. metric_closure is not defined."
  25. raise nx.NetworkXError(msg)
  26. Gnodes.remove(u)
  27. for v in Gnodes:
  28. M.add_edge(u, v, distance=distance[v], path=path[v])
  29. # first node done -- now process the rest
  30. for u, (distance, path) in all_paths_iter:
  31. Gnodes.remove(u)
  32. for v in Gnodes:
  33. M.add_edge(u, v, distance=distance[v], path=path[v])
  34. return M
  35. def _mehlhorn_steiner_tree(G, terminal_nodes, weight):
  36. paths = nx.multi_source_dijkstra_path(G, terminal_nodes)
  37. d_1 = {}
  38. s = {}
  39. for v in G.nodes():
  40. s[v] = paths[v][0]
  41. d_1[(v, s[v])] = len(paths[v]) - 1
  42. # G1-G4 names match those from the Mehlhorn 1988 paper.
  43. G_1_prime = nx.Graph()
  44. for u, v, data in G.edges(data=True):
  45. su, sv = s[u], s[v]
  46. weight_here = d_1[(u, su)] + data.get(weight, 1) + d_1[(v, sv)]
  47. if not G_1_prime.has_edge(su, sv):
  48. G_1_prime.add_edge(su, sv, weight=weight_here)
  49. else:
  50. new_weight = min(weight_here, G_1_prime[su][sv][weight])
  51. G_1_prime.add_edge(su, sv, weight=new_weight)
  52. G_2 = nx.minimum_spanning_edges(G_1_prime, data=True)
  53. G_3 = nx.Graph()
  54. for u, v, d in G_2:
  55. path = nx.shortest_path(G, u, v, weight)
  56. for n1, n2 in pairwise(path):
  57. G_3.add_edge(n1, n2)
  58. G_3_mst = list(nx.minimum_spanning_edges(G_3, data=False))
  59. if G.is_multigraph():
  60. G_3_mst = (
  61. (u, v, min(G[u][v], key=lambda k: G[u][v][k][weight])) for u, v in G_3_mst
  62. )
  63. G_4 = G.edge_subgraph(G_3_mst).copy()
  64. _remove_nonterminal_leaves(G_4, terminal_nodes)
  65. return G_4.edges()
  66. def _kou_steiner_tree(G, terminal_nodes, weight):
  67. # H is the subgraph induced by terminal_nodes in the metric closure M of G.
  68. M = metric_closure(G, weight=weight)
  69. H = M.subgraph(terminal_nodes)
  70. # Use the 'distance' attribute of each edge provided by M.
  71. mst_edges = nx.minimum_spanning_edges(H, weight="distance", data=True)
  72. # Create an iterator over each edge in each shortest path; repeats are okay
  73. mst_all_edges = chain.from_iterable(pairwise(d["path"]) for u, v, d in mst_edges)
  74. if G.is_multigraph():
  75. mst_all_edges = (
  76. (u, v, min(G[u][v], key=lambda k: G[u][v][k][weight]))
  77. for u, v in mst_all_edges
  78. )
  79. # Find the MST again, over this new set of edges
  80. G_S = G.edge_subgraph(mst_all_edges)
  81. T_S = nx.minimum_spanning_edges(G_S, weight="weight", data=False)
  82. # Leaf nodes that are not terminal might still remain; remove them here
  83. T_H = G.edge_subgraph(T_S).copy()
  84. _remove_nonterminal_leaves(T_H, terminal_nodes)
  85. return T_H.edges()
  86. def _remove_nonterminal_leaves(G, terminals):
  87. terminals_set = set(terminals)
  88. for n in list(G.nodes):
  89. if n not in terminals_set and G.degree(n) == 1:
  90. G.remove_node(n)
  91. ALGORITHMS = {
  92. "kou": _kou_steiner_tree,
  93. "mehlhorn": _mehlhorn_steiner_tree,
  94. }
  95. @not_implemented_for("directed")
  96. def steiner_tree(G, terminal_nodes, weight="weight", method=None):
  97. r"""Return an approximation to the minimum Steiner tree of a graph.
  98. The minimum Steiner tree of `G` w.r.t a set of `terminal_nodes` (also *S*)
  99. is a tree within `G` that spans those nodes and has minimum size (sum of
  100. edge weights) among all such trees.
  101. The approximation algorithm is specified with the `method` keyword
  102. argument. All three available algorithms produce a tree whose weight is
  103. within a ``(2 - (2 / l))`` factor of the weight of the optimal Steiner tree,
  104. where ``l`` is the minimum number of leaf nodes across all possible Steiner
  105. trees.
  106. * ``"kou"`` [2]_ (runtime $O(|S| |V|^2)$) computes the minimum spanning tree of
  107. the subgraph of the metric closure of *G* induced by the terminal nodes,
  108. where the metric closure of *G* is the complete graph in which each edge is
  109. weighted by the shortest path distance between the nodes in *G*.
  110. * ``"mehlhorn"`` [3]_ (runtime $O(|E|+|V|\log|V|)$) modifies Kou et al.'s
  111. algorithm, beginning by finding the closest terminal node for each
  112. non-terminal. This data is used to create a complete graph containing only
  113. the terminal nodes, in which edge is weighted with the shortest path
  114. distance between them. The algorithm then proceeds in the same way as Kou
  115. et al..
  116. Parameters
  117. ----------
  118. G : NetworkX graph
  119. terminal_nodes : list
  120. A list of terminal nodes for which minimum steiner tree is
  121. to be found.
  122. weight : string (default = 'weight')
  123. Use the edge attribute specified by this string as the edge weight.
  124. Any edge attribute not present defaults to 1.
  125. method : string, optional (default = 'kou')
  126. The algorithm to use to approximate the Steiner tree.
  127. Supported options: 'kou', 'mehlhorn'.
  128. Other inputs produce a ValueError.
  129. Returns
  130. -------
  131. NetworkX graph
  132. Approximation to the minimum steiner tree of `G` induced by
  133. `terminal_nodes` .
  134. Notes
  135. -----
  136. For multigraphs, the edge between two nodes with minimum weight is the
  137. edge put into the Steiner tree.
  138. References
  139. ----------
  140. .. [1] Steiner_tree_problem on Wikipedia.
  141. https://en.wikipedia.org/wiki/Steiner_tree_problem
  142. .. [2] Kou, L., G. Markowsky, and L. Berman. 1981.
  143. ‘A Fast Algorithm for Steiner Trees’.
  144. Acta Informatica 15 (2): 141–45.
  145. https://doi.org/10.1007/BF00288961.
  146. .. [3] Mehlhorn, Kurt. 1988.
  147. ‘A Faster Approximation Algorithm for the Steiner Problem in Graphs’.
  148. Information Processing Letters 27 (3): 125–28.
  149. https://doi.org/10.1016/0020-0190(88)90066-X.
  150. """
  151. if method is None:
  152. import warnings
  153. msg = (
  154. "steiner_tree will change default method from 'kou' to 'mehlhorn'"
  155. "in version 3.2.\nSet the `method` kwarg to remove this warning."
  156. )
  157. warnings.warn(msg, FutureWarning, stacklevel=4)
  158. method = "kou"
  159. try:
  160. algo = ALGORITHMS[method]
  161. except KeyError as e:
  162. msg = f"{method} is not a valid choice for an algorithm."
  163. raise ValueError(msg) from e
  164. edges = algo(G, terminal_nodes, weight)
  165. # For multigraph we should add the minimal weight edge keys
  166. if G.is_multigraph():
  167. edges = (
  168. (u, v, min(G[u][v], key=lambda k: G[u][v][k][weight])) for u, v in edges
  169. )
  170. T = G.edge_subgraph(edges)
  171. return T