test_mst.py 23 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671
  1. """Unit tests for the :mod:`networkx.algorithms.tree.mst` module."""
  2. import pytest
  3. import networkx as nx
  4. from networkx.utils import edges_equal, nodes_equal
  5. def test_unknown_algorithm():
  6. with pytest.raises(ValueError):
  7. nx.minimum_spanning_tree(nx.Graph(), algorithm="random")
  8. class MinimumSpanningTreeTestBase:
  9. """Base class for test classes for minimum spanning tree algorithms.
  10. This class contains some common tests that will be inherited by
  11. subclasses. Each subclass must have a class attribute
  12. :data:`algorithm` that is a string representing the algorithm to
  13. run, as described under the ``algorithm`` keyword argument for the
  14. :func:`networkx.minimum_spanning_edges` function. Subclasses can
  15. then implement any algorithm-specific tests.
  16. """
  17. def setup_method(self, method):
  18. """Creates an example graph and stores the expected minimum and
  19. maximum spanning tree edges.
  20. """
  21. # This stores the class attribute `algorithm` in an instance attribute.
  22. self.algo = self.algorithm
  23. # This example graph comes from Wikipedia:
  24. # https://en.wikipedia.org/wiki/Kruskal's_algorithm
  25. edges = [
  26. (0, 1, 7),
  27. (0, 3, 5),
  28. (1, 2, 8),
  29. (1, 3, 9),
  30. (1, 4, 7),
  31. (2, 4, 5),
  32. (3, 4, 15),
  33. (3, 5, 6),
  34. (4, 5, 8),
  35. (4, 6, 9),
  36. (5, 6, 11),
  37. ]
  38. self.G = nx.Graph()
  39. self.G.add_weighted_edges_from(edges)
  40. self.minimum_spanning_edgelist = [
  41. (0, 1, {"weight": 7}),
  42. (0, 3, {"weight": 5}),
  43. (1, 4, {"weight": 7}),
  44. (2, 4, {"weight": 5}),
  45. (3, 5, {"weight": 6}),
  46. (4, 6, {"weight": 9}),
  47. ]
  48. self.maximum_spanning_edgelist = [
  49. (0, 1, {"weight": 7}),
  50. (1, 2, {"weight": 8}),
  51. (1, 3, {"weight": 9}),
  52. (3, 4, {"weight": 15}),
  53. (4, 6, {"weight": 9}),
  54. (5, 6, {"weight": 11}),
  55. ]
  56. def test_minimum_edges(self):
  57. edges = nx.minimum_spanning_edges(self.G, algorithm=self.algo)
  58. # Edges from the spanning edges functions don't come in sorted
  59. # orientation, so we need to sort each edge individually.
  60. actual = sorted((min(u, v), max(u, v), d) for u, v, d in edges)
  61. assert edges_equal(actual, self.minimum_spanning_edgelist)
  62. def test_maximum_edges(self):
  63. edges = nx.maximum_spanning_edges(self.G, algorithm=self.algo)
  64. # Edges from the spanning edges functions don't come in sorted
  65. # orientation, so we need to sort each edge individually.
  66. actual = sorted((min(u, v), max(u, v), d) for u, v, d in edges)
  67. assert edges_equal(actual, self.maximum_spanning_edgelist)
  68. def test_without_data(self):
  69. edges = nx.minimum_spanning_edges(self.G, algorithm=self.algo, data=False)
  70. # Edges from the spanning edges functions don't come in sorted
  71. # orientation, so we need to sort each edge individually.
  72. actual = sorted((min(u, v), max(u, v)) for u, v in edges)
  73. expected = [(u, v) for u, v, d in self.minimum_spanning_edgelist]
  74. assert edges_equal(actual, expected)
  75. def test_nan_weights(self):
  76. # Edge weights NaN never appear in the spanning tree. see #2164
  77. G = self.G
  78. G.add_edge(0, 12, weight=float("nan"))
  79. edges = nx.minimum_spanning_edges(
  80. G, algorithm=self.algo, data=False, ignore_nan=True
  81. )
  82. actual = sorted((min(u, v), max(u, v)) for u, v in edges)
  83. expected = [(u, v) for u, v, d in self.minimum_spanning_edgelist]
  84. assert edges_equal(actual, expected)
  85. # Now test for raising exception
  86. edges = nx.minimum_spanning_edges(
  87. G, algorithm=self.algo, data=False, ignore_nan=False
  88. )
  89. with pytest.raises(ValueError):
  90. list(edges)
  91. # test default for ignore_nan as False
  92. edges = nx.minimum_spanning_edges(G, algorithm=self.algo, data=False)
  93. with pytest.raises(ValueError):
  94. list(edges)
  95. def test_nan_weights_order(self):
  96. # now try again with a nan edge at the beginning of G.nodes
  97. edges = [
  98. (0, 1, 7),
  99. (0, 3, 5),
  100. (1, 2, 8),
  101. (1, 3, 9),
  102. (1, 4, 7),
  103. (2, 4, 5),
  104. (3, 4, 15),
  105. (3, 5, 6),
  106. (4, 5, 8),
  107. (4, 6, 9),
  108. (5, 6, 11),
  109. ]
  110. G = nx.Graph()
  111. G.add_weighted_edges_from([(u + 1, v + 1, wt) for u, v, wt in edges])
  112. G.add_edge(0, 7, weight=float("nan"))
  113. edges = nx.minimum_spanning_edges(
  114. G, algorithm=self.algo, data=False, ignore_nan=True
  115. )
  116. actual = sorted((min(u, v), max(u, v)) for u, v in edges)
  117. shift = [(u + 1, v + 1) for u, v, d in self.minimum_spanning_edgelist]
  118. assert edges_equal(actual, shift)
  119. def test_isolated_node(self):
  120. # now try again with an isolated node
  121. edges = [
  122. (0, 1, 7),
  123. (0, 3, 5),
  124. (1, 2, 8),
  125. (1, 3, 9),
  126. (1, 4, 7),
  127. (2, 4, 5),
  128. (3, 4, 15),
  129. (3, 5, 6),
  130. (4, 5, 8),
  131. (4, 6, 9),
  132. (5, 6, 11),
  133. ]
  134. G = nx.Graph()
  135. G.add_weighted_edges_from([(u + 1, v + 1, wt) for u, v, wt in edges])
  136. G.add_node(0)
  137. edges = nx.minimum_spanning_edges(
  138. G, algorithm=self.algo, data=False, ignore_nan=True
  139. )
  140. actual = sorted((min(u, v), max(u, v)) for u, v in edges)
  141. shift = [(u + 1, v + 1) for u, v, d in self.minimum_spanning_edgelist]
  142. assert edges_equal(actual, shift)
  143. def test_minimum_tree(self):
  144. T = nx.minimum_spanning_tree(self.G, algorithm=self.algo)
  145. actual = sorted(T.edges(data=True))
  146. assert edges_equal(actual, self.minimum_spanning_edgelist)
  147. def test_maximum_tree(self):
  148. T = nx.maximum_spanning_tree(self.G, algorithm=self.algo)
  149. actual = sorted(T.edges(data=True))
  150. assert edges_equal(actual, self.maximum_spanning_edgelist)
  151. def test_disconnected(self):
  152. G = nx.Graph([(0, 1, {"weight": 1}), (2, 3, {"weight": 2})])
  153. T = nx.minimum_spanning_tree(G, algorithm=self.algo)
  154. assert nodes_equal(list(T), list(range(4)))
  155. assert edges_equal(list(T.edges()), [(0, 1), (2, 3)])
  156. def test_empty_graph(self):
  157. G = nx.empty_graph(3)
  158. T = nx.minimum_spanning_tree(G, algorithm=self.algo)
  159. assert nodes_equal(sorted(T), list(range(3)))
  160. assert T.number_of_edges() == 0
  161. def test_attributes(self):
  162. G = nx.Graph()
  163. G.add_edge(1, 2, weight=1, color="red", distance=7)
  164. G.add_edge(2, 3, weight=1, color="green", distance=2)
  165. G.add_edge(1, 3, weight=10, color="blue", distance=1)
  166. G.graph["foo"] = "bar"
  167. T = nx.minimum_spanning_tree(G, algorithm=self.algo)
  168. assert T.graph == G.graph
  169. assert nodes_equal(T, G)
  170. for u, v in T.edges():
  171. assert T.adj[u][v] == G.adj[u][v]
  172. def test_weight_attribute(self):
  173. G = nx.Graph()
  174. G.add_edge(0, 1, weight=1, distance=7)
  175. G.add_edge(0, 2, weight=30, distance=1)
  176. G.add_edge(1, 2, weight=1, distance=1)
  177. G.add_node(3)
  178. T = nx.minimum_spanning_tree(G, algorithm=self.algo, weight="distance")
  179. assert nodes_equal(sorted(T), list(range(4)))
  180. assert edges_equal(sorted(T.edges()), [(0, 2), (1, 2)])
  181. T = nx.maximum_spanning_tree(G, algorithm=self.algo, weight="distance")
  182. assert nodes_equal(sorted(T), list(range(4)))
  183. assert edges_equal(sorted(T.edges()), [(0, 1), (0, 2)])
  184. class TestBoruvka(MinimumSpanningTreeTestBase):
  185. """Unit tests for computing a minimum (or maximum) spanning tree
  186. using Borůvka's algorithm.
  187. """
  188. algorithm = "boruvka"
  189. def test_unicode_name(self):
  190. """Tests that using a Unicode string can correctly indicate
  191. Borůvka's algorithm.
  192. """
  193. edges = nx.minimum_spanning_edges(self.G, algorithm="borůvka")
  194. # Edges from the spanning edges functions don't come in sorted
  195. # orientation, so we need to sort each edge individually.
  196. actual = sorted((min(u, v), max(u, v), d) for u, v, d in edges)
  197. assert edges_equal(actual, self.minimum_spanning_edgelist)
  198. class MultigraphMSTTestBase(MinimumSpanningTreeTestBase):
  199. # Abstract class
  200. def test_multigraph_keys_min(self):
  201. """Tests that the minimum spanning edges of a multigraph
  202. preserves edge keys.
  203. """
  204. G = nx.MultiGraph()
  205. G.add_edge(0, 1, key="a", weight=2)
  206. G.add_edge(0, 1, key="b", weight=1)
  207. min_edges = nx.minimum_spanning_edges
  208. mst_edges = min_edges(G, algorithm=self.algo, data=False)
  209. assert edges_equal([(0, 1, "b")], list(mst_edges))
  210. def test_multigraph_keys_max(self):
  211. """Tests that the maximum spanning edges of a multigraph
  212. preserves edge keys.
  213. """
  214. G = nx.MultiGraph()
  215. G.add_edge(0, 1, key="a", weight=2)
  216. G.add_edge(0, 1, key="b", weight=1)
  217. max_edges = nx.maximum_spanning_edges
  218. mst_edges = max_edges(G, algorithm=self.algo, data=False)
  219. assert edges_equal([(0, 1, "a")], list(mst_edges))
  220. class TestKruskal(MultigraphMSTTestBase):
  221. """Unit tests for computing a minimum (or maximum) spanning tree
  222. using Kruskal's algorithm.
  223. """
  224. algorithm = "kruskal"
  225. def test_key_data_bool(self):
  226. """Tests that the keys and data values are included in
  227. MST edges based on whether keys and data parameters are
  228. true or false"""
  229. G = nx.MultiGraph()
  230. G.add_edge(1, 2, key=1, weight=2)
  231. G.add_edge(1, 2, key=2, weight=3)
  232. G.add_edge(3, 2, key=1, weight=2)
  233. G.add_edge(3, 1, key=1, weight=4)
  234. # keys are included and data is not included
  235. mst_edges = nx.minimum_spanning_edges(
  236. G, algorithm=self.algo, keys=True, data=False
  237. )
  238. assert edges_equal([(1, 2, 1), (2, 3, 1)], list(mst_edges))
  239. # keys are not included and data is included
  240. mst_edges = nx.minimum_spanning_edges(
  241. G, algorithm=self.algo, keys=False, data=True
  242. )
  243. assert edges_equal(
  244. [(1, 2, {"weight": 2}), (2, 3, {"weight": 2})], list(mst_edges)
  245. )
  246. # both keys and data are not included
  247. mst_edges = nx.minimum_spanning_edges(
  248. G, algorithm=self.algo, keys=False, data=False
  249. )
  250. assert edges_equal([(1, 2), (2, 3)], list(mst_edges))
  251. class TestPrim(MultigraphMSTTestBase):
  252. """Unit tests for computing a minimum (or maximum) spanning tree
  253. using Prim's algorithm.
  254. """
  255. algorithm = "prim"
  256. def test_ignore_nan(self):
  257. """Tests that the edges with NaN weights are ignored or
  258. raise an Error based on ignore_nan is true or false"""
  259. H = nx.MultiGraph()
  260. H.add_edge(1, 2, key=1, weight=float("nan"))
  261. H.add_edge(1, 2, key=2, weight=3)
  262. H.add_edge(3, 2, key=1, weight=2)
  263. H.add_edge(3, 1, key=1, weight=4)
  264. # NaN weight edges are ignored when ignore_nan=True
  265. mst_edges = nx.minimum_spanning_edges(H, algorithm=self.algo, ignore_nan=True)
  266. assert edges_equal(
  267. [(1, 2, 2, {"weight": 3}), (2, 3, 1, {"weight": 2})], list(mst_edges)
  268. )
  269. # NaN weight edges raise Error when ignore_nan=False
  270. with pytest.raises(ValueError):
  271. list(nx.minimum_spanning_edges(H, algorithm=self.algo, ignore_nan=False))
  272. def test_multigraph_keys_tree(self):
  273. G = nx.MultiGraph()
  274. G.add_edge(0, 1, key="a", weight=2)
  275. G.add_edge(0, 1, key="b", weight=1)
  276. T = nx.minimum_spanning_tree(G, algorithm=self.algo)
  277. assert edges_equal([(0, 1, 1)], list(T.edges(data="weight")))
  278. def test_multigraph_keys_tree_max(self):
  279. G = nx.MultiGraph()
  280. G.add_edge(0, 1, key="a", weight=2)
  281. G.add_edge(0, 1, key="b", weight=1)
  282. T = nx.maximum_spanning_tree(G, algorithm=self.algo)
  283. assert edges_equal([(0, 1, 2)], list(T.edges(data="weight")))
  284. class TestSpanningTreeIterator:
  285. """
  286. Tests the spanning tree iterator on the example graph in the 2005 Sörensen
  287. and Janssens paper An Algorithm to Generate all Spanning Trees of a Graph in
  288. Order of Increasing Cost
  289. """
  290. def setup_method(self):
  291. # Original Graph
  292. edges = [(0, 1, 5), (1, 2, 4), (1, 4, 6), (2, 3, 5), (2, 4, 7), (3, 4, 3)]
  293. self.G = nx.Graph()
  294. self.G.add_weighted_edges_from(edges)
  295. # List of lists of spanning trees in increasing order
  296. self.spanning_trees = [
  297. # 1, MST, cost = 17
  298. [
  299. (0, 1, {"weight": 5}),
  300. (1, 2, {"weight": 4}),
  301. (2, 3, {"weight": 5}),
  302. (3, 4, {"weight": 3}),
  303. ],
  304. # 2, cost = 18
  305. [
  306. (0, 1, {"weight": 5}),
  307. (1, 2, {"weight": 4}),
  308. (1, 4, {"weight": 6}),
  309. (3, 4, {"weight": 3}),
  310. ],
  311. # 3, cost = 19
  312. [
  313. (0, 1, {"weight": 5}),
  314. (1, 4, {"weight": 6}),
  315. (2, 3, {"weight": 5}),
  316. (3, 4, {"weight": 3}),
  317. ],
  318. # 4, cost = 19
  319. [
  320. (0, 1, {"weight": 5}),
  321. (1, 2, {"weight": 4}),
  322. (2, 4, {"weight": 7}),
  323. (3, 4, {"weight": 3}),
  324. ],
  325. # 5, cost = 20
  326. [
  327. (0, 1, {"weight": 5}),
  328. (1, 2, {"weight": 4}),
  329. (1, 4, {"weight": 6}),
  330. (2, 3, {"weight": 5}),
  331. ],
  332. # 6, cost = 21
  333. [
  334. (0, 1, {"weight": 5}),
  335. (1, 4, {"weight": 6}),
  336. (2, 4, {"weight": 7}),
  337. (3, 4, {"weight": 3}),
  338. ],
  339. # 7, cost = 21
  340. [
  341. (0, 1, {"weight": 5}),
  342. (1, 2, {"weight": 4}),
  343. (2, 3, {"weight": 5}),
  344. (2, 4, {"weight": 7}),
  345. ],
  346. # 8, cost = 23
  347. [
  348. (0, 1, {"weight": 5}),
  349. (1, 4, {"weight": 6}),
  350. (2, 3, {"weight": 5}),
  351. (2, 4, {"weight": 7}),
  352. ],
  353. ]
  354. def test_minimum_spanning_tree_iterator(self):
  355. """
  356. Tests that the spanning trees are correctly returned in increasing order
  357. """
  358. tree_index = 0
  359. for tree in nx.SpanningTreeIterator(self.G):
  360. actual = sorted(tree.edges(data=True))
  361. assert edges_equal(actual, self.spanning_trees[tree_index])
  362. tree_index += 1
  363. def test_maximum_spanning_tree_iterator(self):
  364. """
  365. Tests that the spanning trees are correctly returned in decreasing order
  366. """
  367. tree_index = 7
  368. for tree in nx.SpanningTreeIterator(self.G, minimum=False):
  369. actual = sorted(tree.edges(data=True))
  370. assert edges_equal(actual, self.spanning_trees[tree_index])
  371. tree_index -= 1
  372. def test_random_spanning_tree_multiplicative_small():
  373. """
  374. Using a fixed seed, sample one tree for repeatability.
  375. """
  376. from math import exp
  377. pytest.importorskip("scipy")
  378. gamma = {
  379. (0, 1): -0.6383,
  380. (0, 2): -0.6827,
  381. (0, 5): 0,
  382. (1, 2): -1.0781,
  383. (1, 4): 0,
  384. (2, 3): 0,
  385. (5, 3): -0.2820,
  386. (5, 4): -0.3327,
  387. (4, 3): -0.9927,
  388. }
  389. # The undirected support of gamma
  390. G = nx.Graph()
  391. for u, v in gamma:
  392. G.add_edge(u, v, lambda_key=exp(gamma[(u, v)]))
  393. solution_edges = [(2, 3), (3, 4), (0, 5), (5, 4), (4, 1)]
  394. solution = nx.Graph()
  395. solution.add_edges_from(solution_edges)
  396. sampled_tree = nx.random_spanning_tree(G, "lambda_key", seed=42)
  397. assert nx.utils.edges_equal(solution.edges, sampled_tree.edges)
  398. @pytest.mark.slow
  399. def test_random_spanning_tree_multiplicative_large():
  400. """
  401. Sample many trees from the distribution created in the last test
  402. """
  403. from math import exp
  404. from random import Random
  405. pytest.importorskip("numpy")
  406. stats = pytest.importorskip("scipy.stats")
  407. gamma = {
  408. (0, 1): -0.6383,
  409. (0, 2): -0.6827,
  410. (0, 5): 0,
  411. (1, 2): -1.0781,
  412. (1, 4): 0,
  413. (2, 3): 0,
  414. (5, 3): -0.2820,
  415. (5, 4): -0.3327,
  416. (4, 3): -0.9927,
  417. }
  418. # The undirected support of gamma
  419. G = nx.Graph()
  420. for u, v in gamma:
  421. G.add_edge(u, v, lambda_key=exp(gamma[(u, v)]))
  422. # Find the multiplicative weight for each tree.
  423. total_weight = 0
  424. tree_expected = {}
  425. for t in nx.SpanningTreeIterator(G):
  426. # Find the multiplicative weight of the spanning tree
  427. weight = 1
  428. for u, v, d in t.edges(data="lambda_key"):
  429. weight *= d
  430. tree_expected[t] = weight
  431. total_weight += weight
  432. # Assert that every tree has an entry in the expected distribution
  433. assert len(tree_expected) == 75
  434. # Set the sample size and then calculate the expected number of times we
  435. # expect to see each tree. This test uses a near minimum sample size where
  436. # the most unlikely tree has an expected frequency of 5.15.
  437. # (Minimum required is 5)
  438. #
  439. # Here we also initialize the tree_actual dict so that we know the keys
  440. # match between the two. We will later take advantage of the fact that since
  441. # python 3.7 dict order is guaranteed so the expected and actual data will
  442. # have the same order.
  443. sample_size = 1200
  444. tree_actual = {}
  445. for t in tree_expected:
  446. tree_expected[t] = (tree_expected[t] / total_weight) * sample_size
  447. tree_actual[t] = 0
  448. # Sample the spanning trees
  449. #
  450. # Assert that they are actually trees and record which of the 75 trees we
  451. # have sampled.
  452. #
  453. # For repeatability, we want to take advantage of the decorators in NetworkX
  454. # to randomly sample the same sample each time. However, if we pass in a
  455. # constant seed to sample_spanning_tree we will get the same tree each time.
  456. # Instead, we can create our own random number generator with a fixed seed
  457. # and pass those into sample_spanning_tree.
  458. rng = Random(37)
  459. for _ in range(sample_size):
  460. sampled_tree = nx.random_spanning_tree(G, "lambda_key", seed=rng)
  461. assert nx.is_tree(sampled_tree)
  462. for t in tree_expected:
  463. if nx.utils.edges_equal(t.edges, sampled_tree.edges):
  464. tree_actual[t] += 1
  465. break
  466. # Conduct a Chi squared test to see if the actual distribution matches the
  467. # expected one at an alpha = 0.05 significance level.
  468. #
  469. # H_0: The distribution of trees in tree_actual matches the normalized product
  470. # of the edge weights in the tree.
  471. #
  472. # H_a: The distribution of trees in tree_actual follows some other
  473. # distribution of spanning trees.
  474. _, p = stats.chisquare(list(tree_actual.values()), list(tree_expected.values()))
  475. # Assert that p is greater than the significance level so that we do not
  476. # reject the null hypothesis
  477. assert not p < 0.05
  478. def test_random_spanning_tree_additive_small():
  479. """
  480. Sample a single spanning tree from the additive method.
  481. """
  482. pytest.importorskip("scipy")
  483. edges = {
  484. (0, 1): 1,
  485. (0, 2): 1,
  486. (0, 5): 3,
  487. (1, 2): 2,
  488. (1, 4): 3,
  489. (2, 3): 3,
  490. (5, 3): 4,
  491. (5, 4): 5,
  492. (4, 3): 4,
  493. }
  494. # Build the graph
  495. G = nx.Graph()
  496. for u, v in edges:
  497. G.add_edge(u, v, weight=edges[(u, v)])
  498. solution_edges = [(0, 2), (1, 2), (2, 3), (3, 4), (3, 5)]
  499. solution = nx.Graph()
  500. solution.add_edges_from(solution_edges)
  501. sampled_tree = nx.random_spanning_tree(
  502. G, weight="weight", multiplicative=False, seed=37
  503. )
  504. assert nx.utils.edges_equal(solution.edges, sampled_tree.edges)
  505. @pytest.mark.slow
  506. def test_random_spanning_tree_additive_large():
  507. """
  508. Sample many spanning trees from the additive method.
  509. """
  510. from random import Random
  511. pytest.importorskip("numpy")
  512. stats = pytest.importorskip("scipy.stats")
  513. edges = {
  514. (0, 1): 1,
  515. (0, 2): 1,
  516. (0, 5): 3,
  517. (1, 2): 2,
  518. (1, 4): 3,
  519. (2, 3): 3,
  520. (5, 3): 4,
  521. (5, 4): 5,
  522. (4, 3): 4,
  523. }
  524. # Build the graph
  525. G = nx.Graph()
  526. for u, v in edges:
  527. G.add_edge(u, v, weight=edges[(u, v)])
  528. # Find the additive weight for each tree.
  529. total_weight = 0
  530. tree_expected = {}
  531. for t in nx.SpanningTreeIterator(G):
  532. # Find the multiplicative weight of the spanning tree
  533. weight = 0
  534. for u, v, d in t.edges(data="weight"):
  535. weight += d
  536. tree_expected[t] = weight
  537. total_weight += weight
  538. # Assert that every tree has an entry in the expected distribution
  539. assert len(tree_expected) == 75
  540. # Set the sample size and then calculate the expected number of times we
  541. # expect to see each tree. This test uses a near minimum sample size where
  542. # the most unlikely tree has an expected frequency of 5.07.
  543. # (Minimum required is 5)
  544. #
  545. # Here we also initialize the tree_actual dict so that we know the keys
  546. # match between the two. We will later take advantage of the fact that since
  547. # python 3.7 dict order is guaranteed so the expected and actual data will
  548. # have the same order.
  549. sample_size = 500
  550. tree_actual = {}
  551. for t in tree_expected:
  552. tree_expected[t] = (tree_expected[t] / total_weight) * sample_size
  553. tree_actual[t] = 0
  554. # Sample the spanning trees
  555. #
  556. # Assert that they are actually trees and record which of the 75 trees we
  557. # have sampled.
  558. #
  559. # For repeatability, we want to take advantage of the decorators in NetworkX
  560. # to randomly sample the same sample each time. However, if we pass in a
  561. # constant seed to sample_spanning_tree we will get the same tree each time.
  562. # Instead, we can create our own random number generator with a fixed seed
  563. # and pass those into sample_spanning_tree.
  564. rng = Random(37)
  565. for _ in range(sample_size):
  566. sampled_tree = nx.random_spanning_tree(
  567. G, "weight", multiplicative=False, seed=rng
  568. )
  569. assert nx.is_tree(sampled_tree)
  570. for t in tree_expected:
  571. if nx.utils.edges_equal(t.edges, sampled_tree.edges):
  572. tree_actual[t] += 1
  573. break
  574. # Conduct a Chi squared test to see if the actual distribution matches the
  575. # expected one at an alpha = 0.05 significance level.
  576. #
  577. # H_0: The distribution of trees in tree_actual matches the normalized product
  578. # of the edge weights in the tree.
  579. #
  580. # H_a: The distribution of trees in tree_actual follows some other
  581. # distribution of spanning trees.
  582. _, p = stats.chisquare(list(tree_actual.values()), list(tree_expected.values()))
  583. # Assert that p is greater than the significance level so that we do not
  584. # reject the null hypothesis
  585. assert not p < 0.05