node_classification.py 6.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218
  1. """ This module provides functions for the node classification problem.
  2. The functions in this module are not imported
  3. into the top level `networkx` namespace.
  4. You can access these functions by importing
  5. the `networkx.algorithms.node_classification` module,
  6. then accessing the functions as attributes of `node_classification`.
  7. For example:
  8. >>> from networkx.algorithms import node_classification
  9. >>> G = nx.path_graph(4)
  10. >>> G.edges()
  11. EdgeView([(0, 1), (1, 2), (2, 3)])
  12. >>> G.nodes[0]["label"] = "A"
  13. >>> G.nodes[3]["label"] = "B"
  14. >>> node_classification.harmonic_function(G)
  15. ['A', 'A', 'B', 'B']
  16. References
  17. ----------
  18. Zhu, X., Ghahramani, Z., & Lafferty, J. (2003, August).
  19. Semi-supervised learning using gaussian fields and harmonic functions.
  20. In ICML (Vol. 3, pp. 912-919).
  21. """
  22. import networkx as nx
  23. __all__ = ["harmonic_function", "local_and_global_consistency"]
  @nx.utils.not_implemented_for("directed")
  def harmonic_function(G, max_iter=30, label_name="label"):
      """Node classification by Harmonic function

      Predicts a label for every node with the Harmonic function algorithm
      by Zhu et al. Labeled nodes are held at a one-hot score for their
      label, while the score of every other node is repeatedly replaced by
      the degree-normalized sum of its neighbors' scores. Each node is
      finally assigned the label with the highest score.

      Parameters
      ----------
      G : NetworkX Graph
      max_iter : int
          Number of propagation iterations to run. The loop always runs
          this many times; there is no early stopping on convergence.
      label_name : string
          Name of the node attribute holding the target labels to predict

      Returns
      -------
      predicted : list
          List of length ``len(G)`` with the predicted labels for each node.

      Raises
      ------
      NetworkXError
          If no nodes in `G` have attribute `label_name`.

      Examples
      --------
      >>> from networkx.algorithms import node_classification
      >>> G = nx.path_graph(4)
      >>> G.nodes[0]["label"] = "A"
      >>> G.nodes[3]["label"] = "B"
      >>> G.nodes(data=True)
      NodeDataView({0: {'label': 'A'}, 1: {}, 2: {}, 3: {'label': 'B'}})
      >>> G.edges()
      EdgeView([(0, 1), (1, 2), (2, 3)])
      >>> predicted = node_classification.harmonic_function(G)
      >>> predicted
      ['A', 'A', 'B', 'B']

      References
      ----------
      Zhu, X., Ghahramani, Z., & Lafferty, J. (2003, August).
      Semi-supervised learning using gaussian fields and harmonic functions.
      In ICML (Vol. 3, pp. 912-919).
      """
      import numpy as np
      import scipy as sp
      import scipy.sparse  # call as sp.sparse

      adjacency = nx.to_scipy_sparse_array(G)
      labeled_pairs, classes = _get_label_info(G, label_name)

      if labeled_pairs.shape[0] == 0:
          raise nx.NetworkXError(
              f"No node on the input graph is labeled by '{label_name}'."
          )

      n_nodes = adjacency.shape[0]
      n_labels = classes.shape[0]
      # Column 0: positions of the labeled nodes; column 1: their label IDs.
      labeled_rows = labeled_pairs[:, 0]
      label_ids = labeled_pairs[:, 1]

      # Propagation matrix: adjacency scaled by the inverse degree.
      degree = adjacency.sum(axis=0)
      degree[degree == 0] = 1  # isolated nodes: avoid dividing by zero
      inv_degree = sp.sparse.csr_array(sp.sparse.diags((1.0 / degree), offsets=0))
      propagation = (inv_degree @ adjacency).tolil()
      # Labeled nodes receive nothing from their neighbors; their score comes
      # solely from the base matrix added at every step.
      propagation[labeled_rows] = 0

      # Base matrix: one-hot rows for the labeled nodes, zeros elsewhere.
      base = np.zeros((n_nodes, n_labels))
      base[labeled_rows, label_ids] = 1

      scores = np.zeros((n_nodes, n_labels))
      for _ in range(max_iter):
          scores = (propagation @ scores) + base

      best_label_ids = np.argmax(scores, axis=1)
      return classes[best_label_ids].tolist()
  @nx.utils.not_implemented_for("directed")
  def local_and_global_consistency(G, alpha=0.99, max_iter=30, label_name="label"):
      """Node classification by Local and Global Consistency

      Predicts a label for every node with the Local and global consistency
      algorithm by Zhou et al. Starting from all-zero scores ``F``, each
      iteration computes ``F = alpha * S @ F + B``, where ``S`` is the
      adjacency matrix scaled on both sides by the inverse square root of
      the degrees and ``B`` holds ``1 - alpha`` at each (labeled node, label)
      entry. Each node is finally assigned the label with the highest score.

      Parameters
      ----------
      G : NetworkX Graph
      alpha : float
          Clamping factor: weight of the propagated scores at each step;
          the known labels are injected with weight ``1 - alpha``.
      max_iter : int
          Number of propagation iterations to run. The loop always runs
          this many times; there is no early stopping on convergence.
      label_name : string
          Name of the node attribute holding the target labels to predict

      Returns
      -------
      predicted : list
          List of length ``len(G)`` with the predicted labels for each node.

      Raises
      ------
      NetworkXError
          If no nodes in `G` have attribute `label_name`.

      Examples
      --------
      >>> from networkx.algorithms import node_classification
      >>> G = nx.path_graph(4)
      >>> G.nodes[0]["label"] = "A"
      >>> G.nodes[3]["label"] = "B"
      >>> G.nodes(data=True)
      NodeDataView({0: {'label': 'A'}, 1: {}, 2: {}, 3: {'label': 'B'}})
      >>> G.edges()
      EdgeView([(0, 1), (1, 2), (2, 3)])
      >>> predicted = node_classification.local_and_global_consistency(G)
      >>> predicted
      ['A', 'A', 'B', 'B']

      References
      ----------
      Zhou, D., Bousquet, O., Lal, T. N., Weston, J., & Schölkopf, B. (2004).
      Learning with local and global consistency.
      Advances in neural information processing systems, 16(16), 321-328.
      """
      import numpy as np
      import scipy as sp
      import scipy.sparse  # call as sp.sparse

      adjacency = nx.to_scipy_sparse_array(G)
      labeled_pairs, classes = _get_label_info(G, label_name)

      if labeled_pairs.shape[0] == 0:
          raise nx.NetworkXError(
              f"No node on the input graph is labeled by '{label_name}'."
          )

      n_nodes = adjacency.shape[0]
      n_labels = classes.shape[0]
      # Column 0: positions of the labeled nodes; column 1: their label IDs.
      labeled_rows = labeled_pairs[:, 0]
      label_ids = labeled_pairs[:, 1]

      # Propagation matrix: adjacency scaled on both sides by the square root
      # of the inverse degree, then damped by alpha.
      degree = adjacency.sum(axis=0)
      degree[degree == 0] = 1  # isolated nodes: avoid dividing by zero
      inv_degree = sp.sparse.csr_array(sp.sparse.diags((1.0 / degree), offsets=0))
      inv_sqrt_degree = np.sqrt(inv_degree)
      propagation = alpha * ((inv_sqrt_degree @ adjacency) @ inv_sqrt_degree)

      # Base matrix: the known labels, re-injected at every step with
      # weight (1 - alpha).
      base = np.zeros((n_nodes, n_labels))
      base[labeled_rows, label_ids] = 1 - alpha

      scores = np.zeros((n_nodes, n_labels))
      for _ in range(max_iter):
          scores = (propagation @ scores) + base

      best_label_ids = np.argmax(scores, axis=1)
      return classes[best_label_ids].tolist()
  151. def _get_label_info(G, label_name):
  152. """Get and return information of labels from the input graph
  153. Parameters
  154. ----------
  155. G : Network X graph
  156. label_name : string
  157. Name of the target label
  158. Returns
  159. ----------
  160. labels : numpy array, shape = [n_labeled_samples, 2]
  161. Array of pairs of labeled node ID and label ID
  162. label_dict : numpy array, shape = [n_classes]
  163. Array of labels
  164. i-th element contains the label corresponding label ID `i`
  165. """
  166. import numpy as np
  167. labels = []
  168. label_to_id = {}
  169. lid = 0
  170. for i, n in enumerate(G.nodes(data=True)):
  171. if label_name in n[1]:
  172. label = n[1][label_name]
  173. if label not in label_to_id:
  174. label_to_id[label] = lid
  175. lid += 1
  176. labels.append([i, label_to_id[label]])
  177. labels = np.array(labels)
  178. label_dict = np.array(
  179. [label for label, _ in sorted(label_to_id.items(), key=lambda x: x[1])]
  180. )
  181. return (labels, label_dict)