test_node_classification.py 4.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140
  1. import pytest
  2. pytest.importorskip("numpy")
  3. pytest.importorskip("scipy")
  4. import networkx as nx
  5. from networkx.algorithms import node_classification
  class TestHarmonicFunction:
      """Tests for ``node_classification.harmonic_function``.

      In the error-path tests the graph is built *outside* the
      ``pytest.raises`` block so that only the call under test can satisfy
      the expected exception; an error raised during setup would otherwise
      make the test pass for the wrong reason.
      """

      def test_path_graph(self):
          """Interior nodes of a 4-node path get the label of the nearer end."""
          label_name = "label"
          G = nx.path_graph(4)
          G.nodes[0][label_name] = "A"
          G.nodes[3][label_name] = "B"
          predicted = node_classification.harmonic_function(G, label_name=label_name)
          assert [predicted[node] for node in G] == ["A", "A", "B", "B"]

      def test_no_labels(self):
          """A graph with no labeled node is rejected with ``NetworkXError``."""
          G = nx.path_graph(4)
          with pytest.raises(nx.NetworkXError):
              node_classification.harmonic_function(G)

      def test_no_nodes(self):
          """An empty graph is rejected with ``NetworkXError``."""
          G = nx.Graph()
          with pytest.raises(nx.NetworkXError):
              node_classification.harmonic_function(G)

      def test_no_edges(self):
          """A graph with nodes but no edges is rejected with ``NetworkXError``."""
          G = nx.Graph()
          G.add_nodes_from([1, 2])
          with pytest.raises(nx.NetworkXError):
              node_classification.harmonic_function(G)

      def test_digraph(self):
          """A directed graph is rejected with ``NetworkXNotImplemented``."""
          label_name = "label"
          G = nx.DiGraph()
          G.add_edges_from([(0, 1), (1, 2), (2, 3)])
          G.nodes[0][label_name] = "A"
          G.nodes[3][label_name] = "B"
          with pytest.raises(nx.NetworkXNotImplemented):
              node_classification.harmonic_function(G)

      def test_one_labeled_node(self):
          """With a single labeled node, every node is predicted that label."""
          label_name = "label"
          G = nx.path_graph(4)
          G.nodes[0][label_name] = "A"
          predicted = node_classification.harmonic_function(G, label_name=label_name)
          assert [predicted[node] for node in G] == ["A", "A", "A", "A"]

      def test_nodes_all_labeled(self):
          """When every node is already labeled, predictions equal the labels."""
          label_name = "club"
          G = nx.karate_club_graph()
          predicted = node_classification.harmonic_function(G, label_name=label_name)
          # Karate-club nodes are the integers 0..n-1, so a node id doubles as
          # its index into ``predicted``.
          for node in G:
              assert predicted[node] == G.nodes[node][label_name]

      def test_labeled_nodes_are_not_changed(self):
          """Nodes that keep their label are predicted with that same label."""
          label_name = "club"
          G = nx.karate_club_graph()
          label_removed = set(range(8))
          for node in label_removed:
              del G.nodes[node][label_name]
          predicted = node_classification.harmonic_function(G, label_name=label_name)
          for node in G:
              if node in label_removed:
                  # Its label was deleted above; there is nothing to compare.
                  continue
              assert predicted[node] == G.nodes[node][label_name]
  class TestLocalAndGlobalConsistency:
      """Tests for ``node_classification.local_and_global_consistency``.

      In the error-path tests the graph is built *outside* the
      ``pytest.raises`` block so that only the call under test can satisfy
      the expected exception; an error raised during setup would otherwise
      make the test pass for the wrong reason.
      """

      def test_path_graph(self):
          """Interior nodes of a 4-node path get the label of the nearer end."""
          label_name = "label"
          G = nx.path_graph(4)
          G.nodes[0][label_name] = "A"
          G.nodes[3][label_name] = "B"
          predicted = node_classification.local_and_global_consistency(
              G, label_name=label_name
          )
          assert [predicted[node] for node in G] == ["A", "A", "B", "B"]

      def test_no_labels(self):
          """A graph with no labeled node is rejected with ``NetworkXError``."""
          G = nx.path_graph(4)
          with pytest.raises(nx.NetworkXError):
              node_classification.local_and_global_consistency(G)

      def test_no_nodes(self):
          """An empty graph is rejected with ``NetworkXError``."""
          G = nx.Graph()
          with pytest.raises(nx.NetworkXError):
              node_classification.local_and_global_consistency(G)

      def test_no_edges(self):
          """A graph with nodes but no edges is rejected with ``NetworkXError``."""
          G = nx.Graph()
          G.add_nodes_from([1, 2])
          with pytest.raises(nx.NetworkXError):
              node_classification.local_and_global_consistency(G)

      def test_digraph(self):
          """A directed graph is rejected with ``NetworkXNotImplemented``."""
          label_name = "label"
          G = nx.DiGraph()
          G.add_edges_from([(0, 1), (1, 2), (2, 3)])
          G.nodes[0][label_name] = "A"
          G.nodes[3][label_name] = "B"
          with pytest.raises(nx.NetworkXNotImplemented):
              # Exercise the function this class is about; the previous version
              # called ``harmonic_function`` here by mistake.
              node_classification.local_and_global_consistency(G)

      def test_one_labeled_node(self):
          """With a single labeled node, every node is predicted that label."""
          label_name = "label"
          G = nx.path_graph(4)
          G.nodes[0][label_name] = "A"
          predicted = node_classification.local_and_global_consistency(
              G, label_name=label_name
          )
          assert [predicted[node] for node in G] == ["A", "A", "A", "A"]

      def test_nodes_all_labeled(self):
          """With ``alpha=0`` and every node labeled, predictions equal the labels."""
          label_name = "club"
          G = nx.karate_club_graph()
          predicted = node_classification.local_and_global_consistency(
              G, alpha=0, label_name=label_name
          )
          # Karate-club nodes are the integers 0..n-1, so a node id doubles as
          # its index into ``predicted``.
          for node in G:
              assert predicted[node] == G.nodes[node][label_name]