kernighan_lin.py 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137
  1. """Functions for computing the Kernighan–Lin bipartition algorithm."""
  2. from itertools import count
  3. import networkx as nx
  4. from networkx.algorithms.community.community_utils import is_partition
  5. from networkx.utils import BinaryHeap, not_implemented_for, py_random_state
  6. __all__ = ["kernighan_lin_bisection"]
  7. def _kernighan_lin_sweep(edges, side):
  8. """
  9. This is a modified form of Kernighan-Lin, which moves single nodes at a
  10. time, alternating between sides to keep the bisection balanced. We keep
  11. two min-heaps of swap costs to make optimal-next-move selection fast.
  12. """
  13. costs0, costs1 = costs = BinaryHeap(), BinaryHeap()
  14. for u, side_u, edges_u in zip(count(), side, edges):
  15. cost_u = sum(w if side[v] else -w for v, w in edges_u)
  16. costs[side_u].insert(u, cost_u if side_u else -cost_u)
  17. def _update_costs(costs_x, x):
  18. for y, w in edges[x]:
  19. costs_y = costs[side[y]]
  20. cost_y = costs_y.get(y)
  21. if cost_y is not None:
  22. cost_y += 2 * (-w if costs_x is costs_y else w)
  23. costs_y.insert(y, cost_y, True)
  24. i = 0
  25. totcost = 0
  26. while costs0 and costs1:
  27. u, cost_u = costs0.pop()
  28. _update_costs(costs0, u)
  29. v, cost_v = costs1.pop()
  30. _update_costs(costs1, v)
  31. totcost += cost_u + cost_v
  32. i += 1
  33. yield totcost, i, (u, v)
  34. @py_random_state(4)
  35. @not_implemented_for("directed")
  36. def kernighan_lin_bisection(G, partition=None, max_iter=10, weight="weight", seed=None):
  37. """Partition a graph into two blocks using the Kernighan–Lin
  38. algorithm.
  39. This algorithm partitions a network into two sets by iteratively
  40. swapping pairs of nodes to reduce the edge cut between the two sets. The
  41. pairs are chosen according to a modified form of Kernighan-Lin, which
  42. moves node individually, alternating between sides to keep the bisection
  43. balanced.
  44. Parameters
  45. ----------
  46. G : graph
  47. partition : tuple
  48. Pair of iterables containing an initial partition. If not
  49. specified, a random balanced partition is used.
  50. max_iter : int
  51. Maximum number of times to attempt swaps to find an
  52. improvemement before giving up.
  53. weight : key
  54. Edge data key to use as weight. If None, the weights are all
  55. set to one.
  56. seed : integer, random_state, or None (default)
  57. Indicator of random number generation state.
  58. See :ref:`Randomness<randomness>`.
  59. Only used if partition is None
  60. Returns
  61. -------
  62. partition : tuple
  63. A pair of sets of nodes representing the bipartition.
  64. Raises
  65. ------
  66. NetworkXError
  67. If partition is not a valid partition of the nodes of the graph.
  68. References
  69. ----------
  70. .. [1] Kernighan, B. W.; Lin, Shen (1970).
  71. "An efficient heuristic procedure for partitioning graphs."
  72. *Bell Systems Technical Journal* 49: 291--307.
  73. Oxford University Press 2011.
  74. """
  75. n = len(G)
  76. labels = list(G)
  77. seed.shuffle(labels)
  78. index = {v: i for i, v in enumerate(labels)}
  79. if partition is None:
  80. side = [0] * (n // 2) + [1] * ((n + 1) // 2)
  81. else:
  82. try:
  83. A, B = partition
  84. except (TypeError, ValueError) as err:
  85. raise nx.NetworkXError("partition must be two sets") from err
  86. if not is_partition(G, (A, B)):
  87. raise nx.NetworkXError("partition invalid")
  88. side = [0] * n
  89. for a in A:
  90. side[index[a]] = 1
  91. if G.is_multigraph():
  92. edges = [
  93. [
  94. (index[u], sum(e.get(weight, 1) for e in d.values()))
  95. for u, d in G[v].items()
  96. ]
  97. for v in labels
  98. ]
  99. else:
  100. edges = [
  101. [(index[u], e.get(weight, 1)) for u, e in G[v].items()] for v in labels
  102. ]
  103. for i in range(max_iter):
  104. costs = list(_kernighan_lin_sweep(edges, side))
  105. min_cost, min_i, _ = min(costs)
  106. if min_cost >= 0:
  107. break
  108. for _, _, (u, v) in costs[:min_i]:
  109. side[u] = 1
  110. side[v] = 0
  111. A = {u for u, s in zip(labels, side) if s == 0}
  112. B = {u for u, s in zip(labels, side) if s == 1}
  113. return A, B