test_spanning_tree.py 2.1 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465
  1. """Test the minimum spanning tree function"""
  2. import numpy as np
  3. from numpy.testing import assert_
  4. import numpy.testing as npt
  5. from scipy.sparse import csr_matrix
  6. from scipy.sparse.csgraph import minimum_spanning_tree
  7. def test_minimum_spanning_tree():
  8. # Create a graph with two connected components.
  9. graph = [[0,1,0,0,0],
  10. [1,0,0,0,0],
  11. [0,0,0,8,5],
  12. [0,0,8,0,1],
  13. [0,0,5,1,0]]
  14. graph = np.asarray(graph)
  15. # Create the expected spanning tree.
  16. expected = [[0,1,0,0,0],
  17. [0,0,0,0,0],
  18. [0,0,0,0,5],
  19. [0,0,0,0,1],
  20. [0,0,0,0,0]]
  21. expected = np.asarray(expected)
  22. # Ensure minimum spanning tree code gives this expected output.
  23. csgraph = csr_matrix(graph)
  24. mintree = minimum_spanning_tree(csgraph)
  25. npt.assert_array_equal(mintree.toarray(), expected,
  26. 'Incorrect spanning tree found.')
  27. # Ensure that the original graph was not modified.
  28. npt.assert_array_equal(csgraph.toarray(), graph,
  29. 'Original graph was modified.')
  30. # Now let the algorithm modify the csgraph in place.
  31. mintree = minimum_spanning_tree(csgraph, overwrite=True)
  32. npt.assert_array_equal(mintree.toarray(), expected,
  33. 'Graph was not properly modified to contain MST.')
  34. np.random.seed(1234)
  35. for N in (5, 10, 15, 20):
  36. # Create a random graph.
  37. graph = 3 + np.random.random((N, N))
  38. csgraph = csr_matrix(graph)
  39. # The spanning tree has at most N - 1 edges.
  40. mintree = minimum_spanning_tree(csgraph)
  41. assert_(mintree.nnz < N)
  42. # Set the sub diagonal to 1 to create a known spanning tree.
  43. idx = np.arange(N-1)
  44. graph[idx,idx+1] = 1
  45. csgraph = csr_matrix(graph)
  46. mintree = minimum_spanning_tree(csgraph)
  47. # We expect to see this pattern in the spanning tree and otherwise
  48. # have this zero.
  49. expected = np.zeros((N, N))
  50. expected[idx, idx+1] = 1
  51. npt.assert_array_equal(mintree.toarray(), expected,
  52. 'Incorrect spanning tree found.')