test_steinertree.py 6.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191
  1. import pytest
  2. import networkx as nx
  3. from networkx.algorithms.approximation.steinertree import metric_closure, steiner_tree
  4. from networkx.utils import edges_equal
  5. class TestSteinerTree:
  6. @classmethod
  7. def setup_class(cls):
  8. G1 = nx.Graph()
  9. G1.add_edge(1, 2, weight=10)
  10. G1.add_edge(2, 3, weight=10)
  11. G1.add_edge(3, 4, weight=10)
  12. G1.add_edge(4, 5, weight=10)
  13. G1.add_edge(5, 6, weight=10)
  14. G1.add_edge(2, 7, weight=1)
  15. G1.add_edge(7, 5, weight=1)
  16. G2 = nx.Graph()
  17. G2.add_edge(0, 5, weight=6)
  18. G2.add_edge(1, 2, weight=2)
  19. G2.add_edge(1, 5, weight=3)
  20. G2.add_edge(2, 4, weight=4)
  21. G2.add_edge(3, 5, weight=5)
  22. G2.add_edge(4, 5, weight=1)
  23. G3 = nx.Graph()
  24. G3.add_edge(1, 2, weight=8)
  25. G3.add_edge(1, 9, weight=3)
  26. G3.add_edge(1, 8, weight=6)
  27. G3.add_edge(1, 10, weight=2)
  28. G3.add_edge(1, 14, weight=3)
  29. G3.add_edge(2, 3, weight=6)
  30. G3.add_edge(3, 4, weight=3)
  31. G3.add_edge(3, 10, weight=2)
  32. G3.add_edge(3, 11, weight=1)
  33. G3.add_edge(4, 5, weight=1)
  34. G3.add_edge(4, 11, weight=1)
  35. G3.add_edge(5, 6, weight=4)
  36. G3.add_edge(5, 11, weight=2)
  37. G3.add_edge(5, 12, weight=1)
  38. G3.add_edge(5, 13, weight=3)
  39. G3.add_edge(6, 7, weight=2)
  40. G3.add_edge(6, 12, weight=3)
  41. G3.add_edge(6, 13, weight=1)
  42. G3.add_edge(7, 8, weight=3)
  43. G3.add_edge(7, 9, weight=3)
  44. G3.add_edge(7, 11, weight=5)
  45. G3.add_edge(7, 13, weight=2)
  46. G3.add_edge(7, 14, weight=4)
  47. G3.add_edge(8, 9, weight=2)
  48. G3.add_edge(9, 14, weight=1)
  49. G3.add_edge(10, 11, weight=2)
  50. G3.add_edge(10, 14, weight=1)
  51. G3.add_edge(11, 12, weight=1)
  52. G3.add_edge(11, 14, weight=7)
  53. G3.add_edge(12, 14, weight=3)
  54. G3.add_edge(12, 15, weight=1)
  55. G3.add_edge(13, 14, weight=4)
  56. G3.add_edge(13, 15, weight=1)
  57. G3.add_edge(14, 15, weight=2)
  58. cls.G1 = G1
  59. cls.G2 = G2
  60. cls.G3 = G3
  61. cls.G1_term_nodes = [1, 2, 3, 4, 5]
  62. cls.G2_term_nodes = [0, 2, 3]
  63. cls.G3_term_nodes = [1, 3, 5, 6, 8, 10, 11, 12, 13]
  64. cls.methods = ["kou", "mehlhorn"]
  65. def test_connected_metric_closure(self):
  66. G = self.G1.copy()
  67. G.add_node(100)
  68. pytest.raises(nx.NetworkXError, metric_closure, G)
  69. def test_metric_closure(self):
  70. M = metric_closure(self.G1)
  71. mc = [
  72. (1, 2, {"distance": 10, "path": [1, 2]}),
  73. (1, 3, {"distance": 20, "path": [1, 2, 3]}),
  74. (1, 4, {"distance": 22, "path": [1, 2, 7, 5, 4]}),
  75. (1, 5, {"distance": 12, "path": [1, 2, 7, 5]}),
  76. (1, 6, {"distance": 22, "path": [1, 2, 7, 5, 6]}),
  77. (1, 7, {"distance": 11, "path": [1, 2, 7]}),
  78. (2, 3, {"distance": 10, "path": [2, 3]}),
  79. (2, 4, {"distance": 12, "path": [2, 7, 5, 4]}),
  80. (2, 5, {"distance": 2, "path": [2, 7, 5]}),
  81. (2, 6, {"distance": 12, "path": [2, 7, 5, 6]}),
  82. (2, 7, {"distance": 1, "path": [2, 7]}),
  83. (3, 4, {"distance": 10, "path": [3, 4]}),
  84. (3, 5, {"distance": 12, "path": [3, 2, 7, 5]}),
  85. (3, 6, {"distance": 22, "path": [3, 2, 7, 5, 6]}),
  86. (3, 7, {"distance": 11, "path": [3, 2, 7]}),
  87. (4, 5, {"distance": 10, "path": [4, 5]}),
  88. (4, 6, {"distance": 20, "path": [4, 5, 6]}),
  89. (4, 7, {"distance": 11, "path": [4, 5, 7]}),
  90. (5, 6, {"distance": 10, "path": [5, 6]}),
  91. (5, 7, {"distance": 1, "path": [5, 7]}),
  92. (6, 7, {"distance": 11, "path": [6, 5, 7]}),
  93. ]
  94. assert edges_equal(list(M.edges(data=True)), mc)
  95. def test_steiner_tree(self):
  96. valid_steiner_trees = [
  97. [
  98. [
  99. (1, 2, {"weight": 10}),
  100. (2, 3, {"weight": 10}),
  101. (2, 7, {"weight": 1}),
  102. (3, 4, {"weight": 10}),
  103. (5, 7, {"weight": 1}),
  104. ],
  105. [
  106. (1, 2, {"weight": 10}),
  107. (2, 7, {"weight": 1}),
  108. (3, 4, {"weight": 10}),
  109. (4, 5, {"weight": 10}),
  110. (5, 7, {"weight": 1}),
  111. ],
  112. [
  113. (1, 2, {"weight": 10}),
  114. (2, 3, {"weight": 10}),
  115. (2, 7, {"weight": 1}),
  116. (4, 5, {"weight": 10}),
  117. (5, 7, {"weight": 1}),
  118. ],
  119. ],
  120. [
  121. [
  122. (0, 5, {"weight": 6}),
  123. (1, 2, {"weight": 2}),
  124. (1, 5, {"weight": 3}),
  125. (3, 5, {"weight": 5}),
  126. ],
  127. [
  128. (0, 5, {"weight": 6}),
  129. (4, 2, {"weight": 4}),
  130. (4, 5, {"weight": 1}),
  131. (3, 5, {"weight": 5}),
  132. ],
  133. ],
  134. [
  135. [
  136. (1, 10, {"weight": 2}),
  137. (3, 10, {"weight": 2}),
  138. (3, 11, {"weight": 1}),
  139. (5, 12, {"weight": 1}),
  140. (6, 13, {"weight": 1}),
  141. (8, 9, {"weight": 2}),
  142. (9, 14, {"weight": 1}),
  143. (10, 14, {"weight": 1}),
  144. (11, 12, {"weight": 1}),
  145. (12, 15, {"weight": 1}),
  146. (13, 15, {"weight": 1}),
  147. ]
  148. ],
  149. ]
  150. for method in self.methods:
  151. for G, term_nodes, valid_trees in zip(
  152. [self.G1, self.G2, self.G3],
  153. [self.G1_term_nodes, self.G2_term_nodes, self.G3_term_nodes],
  154. valid_steiner_trees,
  155. ):
  156. S = steiner_tree(G, term_nodes, method=method)
  157. assert any(
  158. edges_equal(list(S.edges(data=True)), valid_tree)
  159. for valid_tree in valid_trees
  160. )
  161. def test_multigraph_steiner_tree(self):
  162. G = nx.MultiGraph()
  163. G.add_edges_from(
  164. [
  165. (1, 2, 0, {"weight": 1}),
  166. (2, 3, 0, {"weight": 999}),
  167. (2, 3, 1, {"weight": 1}),
  168. (3, 4, 0, {"weight": 1}),
  169. (3, 5, 0, {"weight": 1}),
  170. ]
  171. )
  172. terminal_nodes = [2, 4, 5]
  173. expected_edges = [
  174. (2, 3, 1, {"weight": 1}), # edge with key 1 has lower weight
  175. (3, 4, 0, {"weight": 1}),
  176. (3, 5, 0, {"weight": 1}),
  177. ]
  178. for method in self.methods:
  179. S = steiner_tree(G, terminal_nodes, method=method)
  180. assert edges_equal(S.edges(data=True, keys=True), expected_edges)