test_sparsifiers.py 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137
  1. """Unit tests for the sparsifier computation functions."""
  2. import pytest
  3. import networkx as nx
  4. from networkx.utils import py_random_state
# Fixed seed shared by the random graph generators, the random weight
# assignment and ``nx.spanner`` below, so every test run builds the same
# graphs and spanners.
_seed = 2
  6. def _test_spanner(G, spanner, stretch, weight=None):
  7. """Test whether a spanner is valid.
  8. This function tests whether the given spanner is a subgraph of the
  9. given graph G with the same node set. It also tests for all shortest
  10. paths whether they adhere to the given stretch.
  11. Parameters
  12. ----------
  13. G : NetworkX graph
  14. The original graph for which the spanner was constructed.
  15. spanner : NetworkX graph
  16. The spanner to be tested.
  17. stretch : float
  18. The proclaimed stretch of the spanner.
  19. weight : object
  20. The edge attribute to use as distance.
  21. """
  22. # check node set
  23. assert set(G.nodes()) == set(spanner.nodes())
  24. # check edge set and weights
  25. for u, v in spanner.edges():
  26. assert G.has_edge(u, v)
  27. if weight:
  28. assert spanner[u][v][weight] == G[u][v][weight]
  29. # check connectivity and stretch
  30. original_length = dict(nx.shortest_path_length(G, weight=weight))
  31. spanner_length = dict(nx.shortest_path_length(spanner, weight=weight))
  32. for u in G.nodes():
  33. for v in G.nodes():
  34. if u in original_length and v in original_length[u]:
  35. assert spanner_length[u][v] <= stretch * original_length[u][v]
  36. @py_random_state(1)
  37. def _assign_random_weights(G, seed=None):
  38. """Assigns random weights to the edges of a graph.
  39. Parameters
  40. ----------
  41. G : NetworkX graph
  42. The original graph for which the spanner was constructed.
  43. seed : integer, random_state, or None (default)
  44. Indicator of random number generation state.
  45. See :ref:`Randomness<randomness>`.
  46. """
  47. for u, v in G.edges():
  48. G[u][v]["weight"] = seed.random()
  49. def test_spanner_trivial():
  50. """Test a trivial spanner with stretch 1."""
  51. G = nx.complete_graph(20)
  52. spanner = nx.spanner(G, 1, seed=_seed)
  53. for u, v in G.edges:
  54. assert spanner.has_edge(u, v)
  55. def test_spanner_unweighted_complete_graph():
  56. """Test spanner construction on a complete unweighted graph."""
  57. G = nx.complete_graph(20)
  58. spanner = nx.spanner(G, 4, seed=_seed)
  59. _test_spanner(G, spanner, 4)
  60. spanner = nx.spanner(G, 10, seed=_seed)
  61. _test_spanner(G, spanner, 10)
  62. def test_spanner_weighted_complete_graph():
  63. """Test spanner construction on a complete weighted graph."""
  64. G = nx.complete_graph(20)
  65. _assign_random_weights(G, seed=_seed)
  66. spanner = nx.spanner(G, 4, weight="weight", seed=_seed)
  67. _test_spanner(G, spanner, 4, weight="weight")
  68. spanner = nx.spanner(G, 10, weight="weight", seed=_seed)
  69. _test_spanner(G, spanner, 10, weight="weight")
  70. def test_spanner_unweighted_gnp_graph():
  71. """Test spanner construction on an unweighted gnp graph."""
  72. G = nx.gnp_random_graph(20, 0.4, seed=_seed)
  73. spanner = nx.spanner(G, 4, seed=_seed)
  74. _test_spanner(G, spanner, 4)
  75. spanner = nx.spanner(G, 10, seed=_seed)
  76. _test_spanner(G, spanner, 10)
  77. def test_spanner_weighted_gnp_graph():
  78. """Test spanner construction on an weighted gnp graph."""
  79. G = nx.gnp_random_graph(20, 0.4, seed=_seed)
  80. _assign_random_weights(G, seed=_seed)
  81. spanner = nx.spanner(G, 4, weight="weight", seed=_seed)
  82. _test_spanner(G, spanner, 4, weight="weight")
  83. spanner = nx.spanner(G, 10, weight="weight", seed=_seed)
  84. _test_spanner(G, spanner, 10, weight="weight")
  85. def test_spanner_unweighted_disconnected_graph():
  86. """Test spanner construction on a disconnected graph."""
  87. G = nx.disjoint_union(nx.complete_graph(10), nx.complete_graph(10))
  88. spanner = nx.spanner(G, 4, seed=_seed)
  89. _test_spanner(G, spanner, 4)
  90. spanner = nx.spanner(G, 10, seed=_seed)
  91. _test_spanner(G, spanner, 10)
  92. def test_spanner_invalid_stretch():
  93. """Check whether an invalid stretch is caught."""
  94. with pytest.raises(ValueError):
  95. G = nx.empty_graph()
  96. nx.spanner(G, 0)