branchings.py 35 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048
  1. """
  2. Algorithms for finding optimum branchings and spanning arborescences.
  3. This implementation is based on:
  4. J. Edmonds, Optimum branchings, J. Res. Natl. Bur. Standards 71B (1967),
  5. 233–240. URL: http://archive.org/details/jresv71Bn4p233
  6. """
# TODO: Implement method from Gabow, Galil, Spencer and Tarjan:
  8. #
  9. # @article{
  10. # year={1986},
  11. # issn={0209-9683},
  12. # journal={Combinatorica},
  13. # volume={6},
  14. # number={2},
  15. # doi={10.1007/BF02579168},
  16. # title={Efficient algorithms for finding minimum spanning trees in
  17. # undirected and directed graphs},
  18. # url={https://doi.org/10.1007/BF02579168},
  19. # publisher={Springer-Verlag},
  20. # keywords={68 B 15; 68 C 05},
  21. # author={Gabow, Harold N. and Galil, Zvi and Spencer, Thomas and Tarjan,
  22. # Robert E.},
  23. # pages={109-122},
  24. # language={English}
  25. # }
  26. import string
  27. from dataclasses import dataclass, field
  28. from enum import Enum
  29. from operator import itemgetter
  30. from queue import PriorityQueue
  31. import networkx as nx
  32. from networkx.utils import py_random_state
  33. from .recognition import is_arborescence, is_branching
# Names exported by this module.
__all__ = [
    "branching_weight",
    "greedy_branching",
    "maximum_branching",
    "minimum_branching",
    "maximum_spanning_arborescence",
    "minimum_spanning_arborescence",
    "ArborescenceIterator",
    "Edmonds",
]

# Accepted values for the `kind` argument; `greedy_branching` and
# `Edmonds._init` raise NetworkXException for anything else.
KINDS = {"max", "min"}

# Maps user-facing style names onto the two canonical styles.
# NOTE(review): not referenced in the visible part of this module -- confirm
# whether it is still used elsewhere before relying on it.
STYLES = {
    "branching": "branching",
    "arborescence": "arborescence",
    "spanning arborescence": "arborescence",
}

# Sentinel used as the starting value when searching for a maximum (-INF)
# or minimum (INF) edge weight.
INF = float("inf")
  51. @py_random_state(1)
  52. def random_string(L=15, seed=None):
  53. return "".join([seed.choice(string.ascii_letters) for n in range(L)])
  54. def _min_weight(weight):
  55. return -weight
  56. def _max_weight(weight):
  57. return weight
  58. def branching_weight(G, attr="weight", default=1):
  59. """
  60. Returns the total weight of a branching.
  61. You must access this function through the networkx.algorithms.tree module.
  62. Parameters
  63. ----------
  64. G : DiGraph
  65. The directed graph.
  66. attr : str
  67. The attribute to use as weights. If None, then each edge will be
  68. treated equally with a weight of 1.
  69. default : float
  70. When `attr` is not None, then if an edge does not have that attribute,
  71. `default` specifies what value it should take.
  72. Returns
  73. -------
  74. weight: int or float
  75. The total weight of the branching.
  76. Examples
  77. --------
  78. >>> G = nx.DiGraph()
  79. >>> G.add_weighted_edges_from([(0, 1, 2), (1, 2, 4), (2, 3, 3), (3, 4, 2)])
  80. >>> nx.tree.branching_weight(G)
  81. 11
  82. """
  83. return sum(edge[2].get(attr, default) for edge in G.edges(data=True))
  84. @py_random_state(4)
  85. def greedy_branching(G, attr="weight", default=1, kind="max", seed=None):
  86. """
  87. Returns a branching obtained through a greedy algorithm.
  88. This algorithm is wrong, and cannot give a proper optimal branching.
  89. However, we include it for pedagogical reasons, as it can be helpful to
  90. see what its outputs are.
  91. The output is a branching, and possibly, a spanning arborescence. However,
  92. it is not guaranteed to be optimal in either case.
  93. Parameters
  94. ----------
  95. G : DiGraph
  96. The directed graph to scan.
  97. attr : str
  98. The attribute to use as weights. If None, then each edge will be
  99. treated equally with a weight of 1.
  100. default : float
  101. When `attr` is not None, then if an edge does not have that attribute,
  102. `default` specifies what value it should take.
  103. kind : str
  104. The type of optimum to search for: 'min' or 'max' greedy branching.
  105. seed : integer, random_state, or None (default)
  106. Indicator of random number generation state.
  107. See :ref:`Randomness<randomness>`.
  108. Returns
  109. -------
  110. B : directed graph
  111. The greedily obtained branching.
  112. """
  113. if kind not in KINDS:
  114. raise nx.NetworkXException("Unknown value for `kind`.")
  115. if kind == "min":
  116. reverse = False
  117. else:
  118. reverse = True
  119. if attr is None:
  120. # Generate a random string the graph probably won't have.
  121. attr = random_string(seed=seed)
  122. edges = [(u, v, data.get(attr, default)) for (u, v, data) in G.edges(data=True)]
  123. # We sort by weight, but also by nodes to normalize behavior across runs.
  124. try:
  125. edges.sort(key=itemgetter(2, 0, 1), reverse=reverse)
  126. except TypeError:
  127. # This will fail in Python 3.x if the nodes are of varying types.
  128. # In that case, we use the arbitrary order.
  129. edges.sort(key=itemgetter(2), reverse=reverse)
  130. # The branching begins with a forest of no edges.
  131. B = nx.DiGraph()
  132. B.add_nodes_from(G)
  133. # Now we add edges greedily so long we maintain the branching.
  134. uf = nx.utils.UnionFind()
  135. for i, (u, v, w) in enumerate(edges):
  136. if uf[u] == uf[v]:
  137. # Adding this edge would form a directed cycle.
  138. continue
  139. elif B.in_degree(v) == 1:
  140. # The edge would increase the degree to be greater than one.
  141. continue
  142. else:
  143. # If attr was None, then don't insert weights...
  144. data = {}
  145. if attr is not None:
  146. data[attr] = w
  147. B.add_edge(u, v, **data)
  148. uf.union(u, v)
  149. return B
  150. class MultiDiGraph_EdgeKey(nx.MultiDiGraph):
  151. """
  152. MultiDiGraph which assigns unique keys to every edge.
  153. Adds a dictionary edge_index which maps edge keys to (u, v, data) tuples.
  154. This is not a complete implementation. For Edmonds algorithm, we only use
  155. add_node and add_edge, so that is all that is implemented here. During
  156. additions, any specified keys are ignored---this means that you also
  157. cannot update edge attributes through add_node and add_edge.
  158. Why do we need this? Edmonds algorithm requires that we track edges, even
  159. as we change the head and tail of an edge, and even changing the weight
  160. of edges. We must reliably track edges across graph mutations.
  161. """
  162. def __init__(self, incoming_graph_data=None, **attr):
  163. cls = super()
  164. cls.__init__(incoming_graph_data=incoming_graph_data, **attr)
  165. self._cls = cls
  166. self.edge_index = {}
  167. def remove_node(self, n):
  168. keys = set()
  169. for keydict in self.pred[n].values():
  170. keys.update(keydict)
  171. for keydict in self.succ[n].values():
  172. keys.update(keydict)
  173. for key in keys:
  174. del self.edge_index[key]
  175. self._cls.remove_node(n)
  176. def remove_nodes_from(self, nbunch):
  177. for n in nbunch:
  178. self.remove_node(n)
  179. def add_edge(self, u_for_edge, v_for_edge, key_for_edge, **attr):
  180. """
  181. Key is now required.
  182. """
  183. u, v, key = u_for_edge, v_for_edge, key_for_edge
  184. if key in self.edge_index:
  185. uu, vv, _ = self.edge_index[key]
  186. if (u != uu) or (v != vv):
  187. raise Exception(f"Key {key!r} is already in use.")
  188. self._cls.add_edge(u, v, key, **attr)
  189. self.edge_index[key] = (u, v, self.succ[u][v][key])
  190. def add_edges_from(self, ebunch_to_add, **attr):
  191. for u, v, k, d in ebunch_to_add:
  192. self.add_edge(u, v, k, **d)
  193. def remove_edge_with_key(self, key):
  194. try:
  195. u, v, _ = self.edge_index[key]
  196. except KeyError as err:
  197. raise KeyError(f"Invalid edge key {key!r}") from err
  198. else:
  199. del self.edge_index[key]
  200. self._cls.remove_edge(u, v, key)
  201. def remove_edges_from(self, ebunch):
  202. raise NotImplementedError
  203. def get_path(G, u, v):
  204. """
  205. Returns the edge keys of the unique path between u and v.
  206. This is not a generic function. G must be a branching and an instance of
  207. MultiDiGraph_EdgeKey.
  208. """
  209. nodes = nx.shortest_path(G, u, v)
  210. # We are guaranteed that there is only one edge connected every node
  211. # in the shortest path.
  212. def first_key(i, vv):
  213. # Needed for 2.x/3.x compatibilitity
  214. keys = G[nodes[i]][vv].keys()
  215. # Normalize behavior
  216. keys = list(keys)
  217. return keys[0]
  218. edges = [first_key(i, vv) for i, vv in enumerate(nodes[1:])]
  219. return nodes, edges
  220. class Edmonds:
  221. """
  222. Edmonds algorithm [1]_ for finding optimal branchings and spanning
  223. arborescences.
  224. This algorithm can find both minimum and maximum spanning arborescences and
  225. branchings.
  226. Notes
  227. -----
  228. While this algorithm can find a minimum branching, since it isn't required
  229. to be spanning, the minimum branching is always from the set of negative
  230. weight edges which is most likely the empty set for most graphs.
  231. References
  232. ----------
  233. .. [1] J. Edmonds, Optimum Branchings, Journal of Research of the National
  234. Bureau of Standards, 1967, Vol. 71B, p.233-240,
  235. https://archive.org/details/jresv71Bn4p233
  236. """
  237. def __init__(self, G, seed=None):
  238. self.G_original = G
  239. # Need to fix this. We need the whole tree.
  240. self.store = True
  241. # The final answer.
  242. self.edges = []
  243. # Since we will be creating graphs with new nodes, we need to make
  244. # sure that our node names do not conflict with the real node names.
  245. self.template = random_string(seed=seed) + "_{0}"
  246. def _init(self, attr, default, kind, style, preserve_attrs, seed, partition):
  247. if kind not in KINDS:
  248. raise nx.NetworkXException("Unknown value for `kind`.")
  249. # Store inputs.
  250. self.attr = attr
  251. self.default = default
  252. self.kind = kind
  253. self.style = style
  254. # Determine how we are going to transform the weights.
  255. if kind == "min":
  256. self.trans = trans = _min_weight
  257. else:
  258. self.trans = trans = _max_weight
  259. if attr is None:
  260. # Generate a random attr the graph probably won't have.
  261. attr = random_string(seed=seed)
  262. # This is the actual attribute used by the algorithm.
  263. self._attr = attr
  264. # This attribute is used to store whether a particular edge is still
  265. # a candidate. We generate a random attr to remove clashes with
  266. # preserved edges
  267. self.candidate_attr = "candidate_" + random_string(seed=seed)
  268. # The object we manipulate at each step is a multidigraph.
  269. self.G = G = MultiDiGraph_EdgeKey()
  270. for key, (u, v, data) in enumerate(self.G_original.edges(data=True)):
  271. d = {attr: trans(data.get(attr, default))}
  272. if data.get(partition) is not None:
  273. d[partition] = data.get(partition)
  274. if preserve_attrs:
  275. for d_k, d_v in data.items():
  276. if d_k != attr:
  277. d[d_k] = d_v
  278. G.add_edge(u, v, key, **d)
  279. self.level = 0
  280. # These are the "buckets" from the paper.
  281. #
  282. # As in the paper, G^i are modified versions of the original graph.
  283. # D^i and E^i are nodes and edges of the maximal edges that are
  284. # consistent with G^i. These are dashed edges in figures A-F of the
  285. # paper. In this implementation, we store D^i and E^i together as a
  286. # graph B^i. So we will have strictly more B^i than the paper does.
  287. self.B = MultiDiGraph_EdgeKey()
  288. self.B.edge_index = {}
  289. self.graphs = [] # G^i
  290. self.branchings = [] # B^i
  291. self.uf = nx.utils.UnionFind()
  292. # A list of lists of edge indexes. Each list is a circuit for graph G^i.
  293. # Note the edge list will not, in general, be a circuit in graph G^0.
  294. self.circuits = []
  295. # Stores the index of the minimum edge in the circuit found in G^i
  296. # and B^i. The ordering of the edges seems to preserve the weight
  297. # ordering from G^0. So even if the circuit does not form a circuit
  298. # in G^0, it is still true that the minimum edge of the circuit in
  299. # G^i is still the minimum edge in circuit G^0 (despite their weights
  300. # being different).
  301. self.minedge_circuit = []
  302. def find_optimum(
  303. self,
  304. attr="weight",
  305. default=1,
  306. kind="max",
  307. style="branching",
  308. preserve_attrs=False,
  309. partition=None,
  310. seed=None,
  311. ):
  312. """
  313. Returns a branching from G.
  314. Parameters
  315. ----------
  316. attr : str
  317. The edge attribute used to in determining optimality.
  318. default : float
  319. The value of the edge attribute used if an edge does not have
  320. the attribute `attr`.
  321. kind : {'min', 'max'}
  322. The type of optimum to search for, either 'min' or 'max'.
  323. style : {'branching', 'arborescence'}
  324. If 'branching', then an optimal branching is found. If `style` is
  325. 'arborescence', then a branching is found, such that if the
  326. branching is also an arborescence, then the branching is an
  327. optimal spanning arborescences. A given graph G need not have
  328. an optimal spanning arborescence.
  329. preserve_attrs : bool
  330. If True, preserve the other edge attributes of the original
  331. graph (that are not the one passed to `attr`)
  332. partition : str
  333. The edge attribute holding edge partition data. Used in the
  334. spanning arborescence iterator.
  335. seed : integer, random_state, or None (default)
  336. Indicator of random number generation state.
  337. See :ref:`Randomness<randomness>`.
  338. Returns
  339. -------
  340. H : (multi)digraph
  341. The branching.
  342. """
  343. self._init(attr, default, kind, style, preserve_attrs, seed, partition)
  344. uf = self.uf
  345. # This enormous while loop could use some refactoring...
  346. G, B = self.G, self.B
  347. D = set()
  348. nodes = iter(list(G.nodes()))
  349. attr = self._attr
  350. G_pred = G.pred
  351. def desired_edge(v):
  352. """
  353. Find the edge directed toward v with maximal weight.
  354. If an edge partition exists in this graph, return the included edge
  355. if it exists and no not return any excluded edges. There can only
  356. be one included edge for each vertex otherwise the edge partition is
  357. empty.
  358. """
  359. edge = None
  360. weight = -INF
  361. for u, _, key, data in G.in_edges(v, data=True, keys=True):
  362. # Skip excluded edges
  363. if data.get(partition) == nx.EdgePartition.EXCLUDED:
  364. continue
  365. new_weight = data[attr]
  366. # Return the included edge
  367. if data.get(partition) == nx.EdgePartition.INCLUDED:
  368. weight = new_weight
  369. edge = (u, v, key, new_weight, data)
  370. return edge, weight
  371. # Find the best open edge
  372. if new_weight > weight:
  373. weight = new_weight
  374. edge = (u, v, key, new_weight, data)
  375. return edge, weight
  376. while True:
  377. # (I1): Choose a node v in G^i not in D^i.
  378. try:
  379. v = next(nodes)
  380. except StopIteration:
  381. # If there are no more new nodes to consider, then we *should*
  382. # meet the break condition (b) from the paper:
  383. # (b) every node of G^i is in D^i and E^i is a branching
  384. # Construction guarantees that it's a branching.
  385. assert len(G) == len(B)
  386. if len(B):
  387. assert is_branching(B)
  388. if self.store:
  389. self.graphs.append(G.copy())
  390. self.branchings.append(B.copy())
  391. # Add these to keep the lengths equal. Element i is the
  392. # circuit at level i that was merged to form branching i+1.
  393. # There is no circuit for the last level.
  394. self.circuits.append([])
  395. self.minedge_circuit.append(None)
  396. break
  397. else:
  398. if v in D:
  399. # print("v in D", v)
  400. continue
  401. # Put v into bucket D^i.
  402. # print(f"Adding node {v}")
  403. D.add(v)
  404. B.add_node(v)
  405. edge, weight = desired_edge(v)
  406. # print(f"Max edge is {edge!r}")
  407. if edge is None:
  408. # If there is no edge, continue with a new node at (I1).
  409. continue
  410. else:
  411. # Determine if adding the edge to E^i would mean its no longer
  412. # a branching. Presently, v has indegree 0 in B---it is a root.
  413. u = edge[0]
  414. if uf[u] == uf[v]:
  415. # Then adding the edge will create a circuit. Then B
  416. # contains a unique path P from v to u. So condition (a)
  417. # from the paper does hold. We need to store the circuit
  418. # for future reference.
  419. Q_nodes, Q_edges = get_path(B, v, u)
  420. Q_edges.append(edge[2]) # Edge key
  421. else:
  422. # Then B with the edge is still a branching and condition
  423. # (a) from the paper does not hold.
  424. Q_nodes, Q_edges = None, None
  425. # Conditions for adding the edge.
  426. # If weight < 0, then it cannot help in finding a maximum branching.
  427. if self.style == "branching" and weight <= 0:
  428. acceptable = False
  429. else:
  430. acceptable = True
  431. # print(f"Edge is acceptable: {acceptable}")
  432. if acceptable:
  433. dd = {attr: weight}
  434. if edge[4].get(partition) is not None:
  435. dd[partition] = edge[4].get(partition)
  436. B.add_edge(u, v, edge[2], **dd)
  437. G[u][v][edge[2]][self.candidate_attr] = True
  438. uf.union(u, v)
  439. if Q_edges is not None:
  440. # print("Edge introduced a simple cycle:")
  441. # print(Q_nodes, Q_edges)
  442. # Move to method
  443. # Previous meaning of u and v is no longer important.
  444. # Apply (I2).
  445. # Get the edge in the cycle with the minimum weight.
  446. # Also, save the incoming weights for each node.
  447. minweight = INF
  448. minedge = None
  449. Q_incoming_weight = {}
  450. for edge_key in Q_edges:
  451. u, v, data = B.edge_index[edge_key]
  452. # We cannot remove an included edges, even if it is
  453. # the minimum edge in the circuit
  454. w = data[attr]
  455. Q_incoming_weight[v] = w
  456. if data.get(partition) == nx.EdgePartition.INCLUDED:
  457. continue
  458. if w < minweight:
  459. minweight = w
  460. minedge = edge_key
  461. self.circuits.append(Q_edges)
  462. self.minedge_circuit.append(minedge)
  463. if self.store:
  464. self.graphs.append(G.copy())
  465. # Always need the branching with circuits.
  466. self.branchings.append(B.copy())
  467. # Now we mutate it.
  468. new_node = self.template.format(self.level)
  469. # print(minweight, minedge, Q_incoming_weight)
  470. G.add_node(new_node)
  471. new_edges = []
  472. for u, v, key, data in G.edges(data=True, keys=True):
  473. if u in Q_incoming_weight:
  474. if v in Q_incoming_weight:
  475. # Circuit edge, do nothing for now.
  476. # Eventually delete it.
  477. continue
  478. else:
  479. # Outgoing edge. Make it from new node
  480. dd = data.copy()
  481. new_edges.append((new_node, v, key, dd))
  482. else:
  483. if v in Q_incoming_weight:
  484. # Incoming edge. Change its weight
  485. w = data[attr]
  486. w += minweight - Q_incoming_weight[v]
  487. dd = data.copy()
  488. dd[attr] = w
  489. new_edges.append((u, new_node, key, dd))
  490. else:
  491. # Outside edge. No modification necessary.
  492. continue
  493. G.remove_nodes_from(Q_nodes)
  494. B.remove_nodes_from(Q_nodes)
  495. D.difference_update(set(Q_nodes))
  496. for u, v, key, data in new_edges:
  497. G.add_edge(u, v, key, **data)
  498. if self.candidate_attr in data:
  499. del data[self.candidate_attr]
  500. B.add_edge(u, v, key, **data)
  501. uf.union(u, v)
  502. nodes = iter(list(G.nodes()))
  503. self.level += 1
  504. # (I3) Branch construction.
  505. # print(self.level)
  506. H = self.G_original.__class__()
  507. def is_root(G, u, edgekeys):
  508. """
  509. Returns True if `u` is a root node in G.
  510. Node `u` will be a root node if its in-degree, restricted to the
  511. specified edges, is equal to 0.
  512. """
  513. if u not in G:
  514. # print(G.nodes(), u)
  515. raise Exception(f"{u!r} not in G")
  516. for v in G.pred[u]:
  517. for edgekey in G.pred[u][v]:
  518. if edgekey in edgekeys:
  519. return False, edgekey
  520. else:
  521. return True, None
  522. # Start with the branching edges in the last level.
  523. edges = set(self.branchings[self.level].edge_index)
  524. while self.level > 0:
  525. self.level -= 1
  526. # The current level is i, and we start counting from 0.
  527. # We need the node at level i+1 that results from merging a circuit
  528. # at level i. randomname_0 is the first merged node and this
  529. # happens at level 1. That is, randomname_0 is a node at level 1
  530. # that results from merging a circuit at level 0.
  531. merged_node = self.template.format(self.level)
  532. # The circuit at level i that was merged as a node the graph
  533. # at level i+1.
  534. circuit = self.circuits[self.level]
  535. # print
  536. # print(merged_node, self.level, circuit)
  537. # print("before", edges)
  538. # Note, we ask if it is a root in the full graph, not the branching.
  539. # The branching alone doesn't have all the edges.
  540. isroot, edgekey = is_root(self.graphs[self.level + 1], merged_node, edges)
  541. edges.update(circuit)
  542. if isroot:
  543. minedge = self.minedge_circuit[self.level]
  544. if minedge is None:
  545. raise Exception
  546. # Remove the edge in the cycle with minimum weight.
  547. edges.remove(minedge)
  548. else:
  549. # We have identified an edge at next higher level that
  550. # transitions into the merged node at the level. That edge
  551. # transitions to some corresponding node at the current level.
  552. # We want to remove an edge from the cycle that transitions
  553. # into the corresponding node.
  554. # print("edgekey is: ", edgekey)
  555. # print("circuit is: ", circuit)
  556. # The branching at level i
  557. G = self.graphs[self.level]
  558. # print(G.edge_index)
  559. target = G.edge_index[edgekey][1]
  560. for edgekey in circuit:
  561. u, v, data = G.edge_index[edgekey]
  562. if v == target:
  563. break
  564. else:
  565. raise Exception("Couldn't find edge incoming to merged node.")
  566. # print(f"not a root. removing {edgekey}")
  567. edges.remove(edgekey)
  568. self.edges = edges
  569. H.add_nodes_from(self.G_original)
  570. for edgekey in edges:
  571. u, v, d = self.graphs[0].edge_index[edgekey]
  572. dd = {self.attr: self.trans(d[self.attr])}
  573. # Optionally, preserve the other edge attributes of the original
  574. # graph
  575. if preserve_attrs:
  576. for key, value in d.items():
  577. if key not in [self.attr, self.candidate_attr]:
  578. dd[key] = value
  579. # TODO: make this preserve the key.
  580. H.add_edge(u, v, **dd)
  581. return H
  582. def maximum_branching(
  583. G, attr="weight", default=1, preserve_attrs=False, partition=None
  584. ):
  585. ed = Edmonds(G)
  586. B = ed.find_optimum(
  587. attr,
  588. default,
  589. kind="max",
  590. style="branching",
  591. preserve_attrs=preserve_attrs,
  592. partition=partition,
  593. )
  594. return B
  595. def minimum_branching(
  596. G, attr="weight", default=1, preserve_attrs=False, partition=None
  597. ):
  598. ed = Edmonds(G)
  599. B = ed.find_optimum(
  600. attr,
  601. default,
  602. kind="min",
  603. style="branching",
  604. preserve_attrs=preserve_attrs,
  605. partition=partition,
  606. )
  607. return B
  608. def maximum_spanning_arborescence(
  609. G, attr="weight", default=1, preserve_attrs=False, partition=None
  610. ):
  611. ed = Edmonds(G)
  612. B = ed.find_optimum(
  613. attr,
  614. default,
  615. kind="max",
  616. style="arborescence",
  617. preserve_attrs=preserve_attrs,
  618. partition=partition,
  619. )
  620. if not is_arborescence(B):
  621. msg = "No maximum spanning arborescence in G."
  622. raise nx.exception.NetworkXException(msg)
  623. return B
  624. def minimum_spanning_arborescence(
  625. G, attr="weight", default=1, preserve_attrs=False, partition=None
  626. ):
  627. ed = Edmonds(G)
  628. B = ed.find_optimum(
  629. attr,
  630. default,
  631. kind="min",
  632. style="arborescence",
  633. preserve_attrs=preserve_attrs,
  634. partition=partition,
  635. )
  636. if not is_arborescence(B):
  637. msg = "No minimum spanning arborescence in G."
  638. raise nx.exception.NetworkXException(msg)
  639. return B
  640. docstring_branching = """
  641. Returns a {kind} {style} from G.
  642. Parameters
  643. ----------
  644. G : (multi)digraph-like
  645. The graph to be searched.
  646. attr : str
  647. The edge attribute used to in determining optimality.
  648. default : float
  649. The value of the edge attribute used if an edge does not have
  650. the attribute `attr`.
  651. preserve_attrs : bool
  652. If True, preserve the other attributes of the original graph (that are not
  653. passed to `attr`)
  654. partition : str
  655. The key for the edge attribute containing the partition
  656. data on the graph. Edges can be included, excluded or open using the
  657. `EdgePartition` enum.
  658. Returns
  659. -------
  660. B : (multi)digraph-like
  661. A {kind} {style}.
  662. """
  663. docstring_arborescence = (
  664. docstring_branching
  665. + """
  666. Raises
  667. ------
  668. NetworkXException
  669. If the graph does not contain a {kind} {style}.
  670. """
  671. )
  672. maximum_branching.__doc__ = docstring_branching.format(
  673. kind="maximum", style="branching"
  674. )
  675. minimum_branching.__doc__ = docstring_branching.format(
  676. kind="minimum", style="branching"
  677. )
  678. maximum_spanning_arborescence.__doc__ = docstring_arborescence.format(
  679. kind="maximum", style="spanning arborescence"
  680. )
  681. minimum_spanning_arborescence.__doc__ = docstring_arborescence.format(
  682. kind="minimum", style="spanning arborescence"
  683. )
  684. class ArborescenceIterator:
  685. """
  686. Iterate over all spanning arborescences of a graph in either increasing or
  687. decreasing cost.
  688. Notes
  689. -----
  690. This iterator uses the partition scheme from [1]_ (included edges,
  691. excluded edges and open edges). It generates minimum spanning
  692. arborescences using a modified Edmonds' Algorithm which respects the
  693. partition of edges. For arborescences with the same weight, ties are
  694. broken arbitrarily.
  695. References
  696. ----------
  697. .. [1] G.K. Janssens, K. Sörensen, An algorithm to generate all spanning
  698. trees in order of increasing cost, Pesquisa Operacional, 2005-08,
  699. Vol. 25 (2), p. 219-229,
  700. https://www.scielo.br/j/pope/a/XHswBwRwJyrfL88dmMwYNWp/?lang=en
  701. """
  702. @dataclass(order=True)
  703. class Partition:
  704. """
  705. This dataclass represents a partition and stores a dict with the edge
  706. data and the weight of the minimum spanning arborescence of the
  707. partition dict.
  708. """
  709. mst_weight: float
  710. partition_dict: dict = field(compare=False)
  711. def __copy__(self):
  712. return ArborescenceIterator.Partition(
  713. self.mst_weight, self.partition_dict.copy()
  714. )
  715. def __init__(self, G, weight="weight", minimum=True, init_partition=None):
  716. """
  717. Initialize the iterator
  718. Parameters
  719. ----------
  720. G : nx.DiGraph
  721. The directed graph which we need to iterate trees over
  722. weight : String, default = "weight"
  723. The edge attribute used to store the weight of the edge
  724. minimum : bool, default = True
  725. Return the trees in increasing order while true and decreasing order
  726. while false.
  727. init_partition : tuple, default = None
  728. In the case that certain edges have to be included or excluded from
  729. the arborescences, `init_partition` should be in the form
  730. `(included_edges, excluded_edges)` where each edges is a
  731. `(u, v)`-tuple inside an iterable such as a list or set.
  732. """
  733. self.G = G.copy()
  734. self.weight = weight
  735. self.minimum = minimum
  736. self.method = (
  737. minimum_spanning_arborescence if minimum else maximum_spanning_arborescence
  738. )
  739. # Randomly create a key for an edge attribute to hold the partition data
  740. self.partition_key = (
  741. "ArborescenceIterators super secret partition attribute name"
  742. )
  743. if init_partition is not None:
  744. partition_dict = {}
  745. for e in init_partition[0]:
  746. partition_dict[e] = nx.EdgePartition.INCLUDED
  747. for e in init_partition[1]:
  748. partition_dict[e] = nx.EdgePartition.EXCLUDED
  749. self.init_partition = ArborescenceIterator.Partition(0, partition_dict)
  750. else:
  751. self.init_partition = None
  752. def __iter__(self):
  753. """
  754. Returns
  755. -------
  756. ArborescenceIterator
  757. The iterator object for this graph
  758. """
  759. self.partition_queue = PriorityQueue()
  760. self._clear_partition(self.G)
  761. # Write the initial partition if it exists.
  762. if self.init_partition is not None:
  763. self._write_partition(self.init_partition)
  764. mst_weight = self.method(
  765. self.G,
  766. self.weight,
  767. partition=self.partition_key,
  768. preserve_attrs=True,
  769. ).size(weight=self.weight)
  770. self.partition_queue.put(
  771. self.Partition(
  772. mst_weight if self.minimum else -mst_weight,
  773. {}
  774. if self.init_partition is None
  775. else self.init_partition.partition_dict,
  776. )
  777. )
  778. return self
  779. def __next__(self):
  780. """
  781. Returns
  782. -------
  783. (multi)Graph
  784. The spanning tree of next greatest weight, which ties broken
  785. arbitrarily.
  786. """
  787. if self.partition_queue.empty():
  788. del self.G, self.partition_queue
  789. raise StopIteration
  790. partition = self.partition_queue.get()
  791. self._write_partition(partition)
  792. next_arborescence = self.method(
  793. self.G,
  794. self.weight,
  795. partition=self.partition_key,
  796. preserve_attrs=True,
  797. )
  798. self._partition(partition, next_arborescence)
  799. self._clear_partition(next_arborescence)
  800. return next_arborescence
  def _partition(self, partition, partition_arborescence):
      """
      Split a partition into sub-partitions and queue the feasible ones.

      The edges of `partition_arborescence` that are not fixed by
      `partition` are visited one at a time. For each such edge a candidate
      is formed that excludes it while including every open edge visited
      before it. If an arborescence can be computed for the candidate, a
      copy of the candidate is put on the priority queue with that
      arborescence's weight; otherwise the candidate is dropped.

      Parameters
      ----------
      partition : Partition
          The Partition instance used to generate the current spanning
          arborescence.
      partition_arborescence : nx.Graph
          The spanning arborescence computed for the input partition.
      """
      fixed_edges = partition.partition_dict
      # Both branches start from the constraints of the input partition.
      excluded_branch = self.Partition(0, fixed_edges.copy())
      included_branch = self.Partition(0, fixed_edges.copy())

      for edge in partition_arborescence.edges:
          if edge in fixed_edges:
              # Already included by the input partition; nothing to split on.
              continue

          # This is an open edge: branch on excluding versus including it.
          excluded_branch.partition_dict[edge] = nx.EdgePartition.EXCLUDED
          included_branch.partition_dict[edge] = nx.EdgePartition.INCLUDED

          self._write_partition(excluded_branch)
          try:
              candidate = self.method(
                  self.G,
                  self.weight,
                  partition=self.partition_key,
                  preserve_attrs=True,
              )
              candidate_weight = candidate.size(weight=self.weight)
              if self.minimum:
                  excluded_branch.mst_weight = candidate_weight
              else:
                  excluded_branch.mst_weight = -candidate_weight
              self.partition_queue.put(excluded_branch.__copy__())
          except nx.NetworkXException:
              # No arborescence could be computed for this candidate, so it
              # is not queued.
              pass

          # The next candidate keeps this edge included.
          excluded_branch.partition_dict = included_branch.partition_dict.copy()
  def _write_partition(self, partition):
      """
      Store a partition on the edges of the working graph.

      Every edge gets its state from `partition`, or OPEN when the partition
      does not mention it. Afterwards, for each node with exactly one
      included incoming edge, all of its other incoming edges are marked as
      excluded so that if that vertex is merged during Edmonds' algorithm
      another of its incoming edges cannot still be picked.

      Parameters
      ----------
      partition : Partition
          A Partition dataclass describing a partition on the edges of the
          graph.
      """
      key = self.partition_key
      states = partition.partition_dict
      included = nx.EdgePartition.INCLUDED
      excluded = nx.EdgePartition.EXCLUDED

      for tail, head, attrs in self.G.edges(data=True):
          attrs[key] = states.get((tail, head), nx.EdgePartition.OPEN)

      for node in self.G:
          incoming = [
              attrs.get(key)
              for _, _, attrs in self.G.in_edges(nbunch=node, data=True)
          ]
          n_included = sum(1 for state in incoming if state == included)
          n_excluded = sum(1 for state in incoming if state == excluded)

          # Only act when there is exactly one included edge and some other
          # incoming edge is not excluded yet.
          if n_included != 1 or n_excluded == self.G.in_degree(node) - 1:
              continue
          for _, _, attrs in self.G.in_edges(nbunch=node, data=True):
              if attrs.get(key) != included:
                  attrs[key] = excluded
  867. def _clear_partition(self, G):
  868. """
  869. Removes partition data from the graph
  870. """
  871. for u, v, d in G.edges(data=True):
  872. if self.partition_key in d:
  873. del d[self.partition_key]