# lukes.py
"""Lukes Algorithm for exact optimal weighted tree partitioning."""

from copy import deepcopy
from functools import lru_cache
from random import choice

import networkx as nx
from networkx.utils import not_implemented_for

__all__ = ["lukes_partitioning"]

# Attribute key / value written on every edge of the working copy when
# lukes_partitioning is called with edge_weight=None.
D_EDGE_W = "weight"
D_EDGE_VALUE = 1.0
# Attribute key / value written on every node of the working copy when
# lukes_partitioning is called with node_weight=None.
D_NODE_W = "weight"
D_NODE_VALUE = 1
# Node attribute of the working tree that stores, per node, a dict mapping a
# cluster weight to the best partition found for that weight (key 0 holds the
# overall best partition for the node).
PKEY = "partitions"
# Size passed to lru_cache for the cluster value / weight helpers.
CLUSTER_EVAL_CACHE_SIZE = 2048
  14. def _split_n_from(n, min_size_of_first_part):
  15. # splits j in two parts of which the first is at least
  16. # the second argument
  17. assert n >= min_size_of_first_part
  18. for p1 in range(min_size_of_first_part, n + 1):
  19. yield p1, n - p1
  20. def lukes_partitioning(G, max_size, node_weight=None, edge_weight=None):
  21. """Optimal partitioning of a weighted tree using the Lukes algorithm.
  22. This algorithm partitions a connected, acyclic graph featuring integer
  23. node weights and float edge weights. The resulting clusters are such
  24. that the total weight of the nodes in each cluster does not exceed
  25. max_size and that the weight of the edges that are cut by the partition
  26. is minimum. The algorithm is based on LUKES[1].
  27. Parameters
  28. ----------
  29. G : graph
  30. max_size : int
  31. Maximum weight a partition can have in terms of sum of
  32. node_weight for all nodes in the partition
  33. edge_weight : key
  34. Edge data key to use as weight. If None, the weights are all
  35. set to one.
  36. node_weight : key
  37. Node data key to use as weight. If None, the weights are all
  38. set to one. The data must be int.
  39. Returns
  40. -------
  41. partition : list
  42. A list of sets of nodes representing the clusters of the
  43. partition.
  44. Raises
  45. ------
  46. NotATree
  47. If G is not a tree.
  48. TypeError
  49. If any of the values of node_weight is not int.
  50. References
  51. ----------
  52. .. Lukes, J. A. (1974).
  53. "Efficient Algorithm for the Partitioning of Trees."
  54. IBM Journal of Research and Development, 18(3), 217–224.
  55. """
  56. # First sanity check and tree preparation
  57. if not nx.is_tree(G):
  58. raise nx.NotATree("lukes_partitioning works only on trees")
  59. else:
  60. if nx.is_directed(G):
  61. root = [n for n, d in G.in_degree() if d == 0]
  62. assert len(root) == 1
  63. root = root[0]
  64. t_G = deepcopy(G)
  65. else:
  66. root = choice(list(G.nodes))
  67. # this has the desirable side effect of not inheriting attributes
  68. t_G = nx.dfs_tree(G, root)
  69. # Since we do not want to screw up the original graph,
  70. # if we have a blank attribute, we make a deepcopy
  71. if edge_weight is None or node_weight is None:
  72. safe_G = deepcopy(G)
  73. if edge_weight is None:
  74. nx.set_edge_attributes(safe_G, D_EDGE_VALUE, D_EDGE_W)
  75. edge_weight = D_EDGE_W
  76. if node_weight is None:
  77. nx.set_node_attributes(safe_G, D_NODE_VALUE, D_NODE_W)
  78. node_weight = D_NODE_W
  79. else:
  80. safe_G = G
  81. # Second sanity check
  82. # The values of node_weight MUST BE int.
  83. # I cannot see any room for duck typing without incurring serious
  84. # danger of subtle bugs.
  85. all_n_attr = nx.get_node_attributes(safe_G, node_weight).values()
  86. for x in all_n_attr:
  87. if not isinstance(x, int):
  88. raise TypeError(
  89. "lukes_partitioning needs integer "
  90. f"values for node_weight ({node_weight})"
  91. )
  92. # SUBROUTINES -----------------------
  93. # these functions are defined here for two reasons:
  94. # - brevity: we can leverage global "safe_G"
  95. # - caching: signatures are hashable
  96. @not_implemented_for("undirected")
  97. # this is intended to be called only on t_G
  98. def _leaves(gr):
  99. for x in gr.nodes:
  100. if not nx.descendants(gr, x):
  101. yield x
  102. @not_implemented_for("undirected")
  103. def _a_parent_of_leaves_only(gr):
  104. tleaves = set(_leaves(gr))
  105. for n in set(gr.nodes) - tleaves:
  106. if all(x in tleaves for x in nx.descendants(gr, n)):
  107. return n
  108. @lru_cache(CLUSTER_EVAL_CACHE_SIZE)
  109. def _value_of_cluster(cluster):
  110. valid_edges = [e for e in safe_G.edges if e[0] in cluster and e[1] in cluster]
  111. return sum(safe_G.edges[e][edge_weight] for e in valid_edges)
  112. def _value_of_partition(partition):
  113. return sum(_value_of_cluster(frozenset(c)) for c in partition)
  114. @lru_cache(CLUSTER_EVAL_CACHE_SIZE)
  115. def _weight_of_cluster(cluster):
  116. return sum(safe_G.nodes[n][node_weight] for n in cluster)
  117. def _pivot(partition, node):
  118. ccx = [c for c in partition if node in c]
  119. assert len(ccx) == 1
  120. return ccx[0]
  121. def _concatenate_or_merge(partition_1, partition_2, x, i, ref_weight):
  122. ccx = _pivot(partition_1, x)
  123. cci = _pivot(partition_2, i)
  124. merged_xi = ccx.union(cci)
  125. # We first check if we can do the merge.
  126. # If so, we do the actual calculations, otherwise we concatenate
  127. if _weight_of_cluster(frozenset(merged_xi)) <= ref_weight:
  128. cp1 = list(filter(lambda x: x != ccx, partition_1))
  129. cp2 = list(filter(lambda x: x != cci, partition_2))
  130. option_2 = [merged_xi] + cp1 + cp2
  131. return option_2, _value_of_partition(option_2)
  132. else:
  133. option_1 = partition_1 + partition_2
  134. return option_1, _value_of_partition(option_1)
  135. # INITIALIZATION -----------------------
  136. leaves = set(_leaves(t_G))
  137. for lv in leaves:
  138. t_G.nodes[lv][PKEY] = {}
  139. slot = safe_G.nodes[lv][node_weight]
  140. t_G.nodes[lv][PKEY][slot] = [{lv}]
  141. t_G.nodes[lv][PKEY][0] = [{lv}]
  142. for inner in [x for x in t_G.nodes if x not in leaves]:
  143. t_G.nodes[inner][PKEY] = {}
  144. slot = safe_G.nodes[inner][node_weight]
  145. t_G.nodes[inner][PKEY][slot] = [{inner}]
  146. # CORE ALGORITHM -----------------------
  147. while True:
  148. x_node = _a_parent_of_leaves_only(t_G)
  149. weight_of_x = safe_G.nodes[x_node][node_weight]
  150. best_value = 0
  151. best_partition = None
  152. bp_buffer = {}
  153. x_descendants = nx.descendants(t_G, x_node)
  154. for i_node in x_descendants:
  155. for j in range(weight_of_x, max_size + 1):
  156. for a, b in _split_n_from(j, weight_of_x):
  157. if (
  158. a not in t_G.nodes[x_node][PKEY].keys()
  159. or b not in t_G.nodes[i_node][PKEY].keys()
  160. ):
  161. # it's not possible to form this particular weight sum
  162. continue
  163. part1 = t_G.nodes[x_node][PKEY][a]
  164. part2 = t_G.nodes[i_node][PKEY][b]
  165. part, value = _concatenate_or_merge(part1, part2, x_node, i_node, j)
  166. if j not in bp_buffer.keys() or bp_buffer[j][1] < value:
  167. # we annotate in the buffer the best partition for j
  168. bp_buffer[j] = part, value
  169. # we also keep track of the overall best partition
  170. if best_value <= value:
  171. best_value = value
  172. best_partition = part
  173. # as illustrated in Lukes, once we finished a child, we can
  174. # discharge the partitions we found into the graph
  175. # (the key phrase is make all x == x')
  176. # so that they are used by the subsequent children
  177. for w, (best_part_for_vl, vl) in bp_buffer.items():
  178. t_G.nodes[x_node][PKEY][w] = best_part_for_vl
  179. bp_buffer.clear()
  180. # the absolute best partition for this node
  181. # across all weights has to be stored at 0
  182. t_G.nodes[x_node][PKEY][0] = best_partition
  183. t_G.remove_nodes_from(x_descendants)
  184. if x_node == root:
  185. # the 0-labeled partition of root
  186. # is the optimal one for the whole tree
  187. return t_G.nodes[root][PKEY][0]