operations.py 3.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106
  1. """Operations on trees."""
  2. from functools import partial
  3. from itertools import accumulate, chain
  4. import networkx as nx
  5. __all__ = ["join"]
  6. def join(rooted_trees, label_attribute=None):
  7. """Returns a new rooted tree with a root node joined with the roots
  8. of each of the given rooted trees.
  9. Parameters
  10. ----------
  11. rooted_trees : list
  12. A list of pairs in which each left element is a NetworkX graph
  13. object representing a tree and each right element is the root
  14. node of that tree. The nodes of these trees will be relabeled to
  15. integers.
  16. label_attribute : str
  17. If provided, the old node labels will be stored in the new tree
  18. under this node attribute. If not provided, the node attribute
  19. ``'_old'`` will store the original label of the node in the
  20. rooted trees given in the input.
  21. Returns
  22. -------
  23. NetworkX graph
  24. The rooted tree whose subtrees are the given rooted trees. The
  25. new root node is labeled 0. Each non-root node has an attribute,
  26. as described under the keyword argument ``label_attribute``,
  27. that indicates the label of the original node in the input tree.
  28. Notes
  29. -----
  30. Graph, edge, and node attributes are propagated from the given
  31. rooted trees to the created tree. If there are any overlapping graph
  32. attributes, those from later trees will overwrite those from earlier
  33. trees in the tuple of positional arguments.
  34. Examples
  35. --------
  36. Join two full balanced binary trees of height *h* to get a full
  37. balanced binary tree of depth *h* + 1::
  38. >>> h = 4
  39. >>> left = nx.balanced_tree(2, h)
  40. >>> right = nx.balanced_tree(2, h)
  41. >>> joined_tree = nx.join([(left, 0), (right, 0)])
  42. >>> nx.is_isomorphic(joined_tree, nx.balanced_tree(2, h + 1))
  43. True
  44. """
  45. if len(rooted_trees) == 0:
  46. return nx.empty_graph(1)
  47. # Unzip the zipped list of (tree, root) pairs.
  48. trees, roots = zip(*rooted_trees)
  49. # The join of the trees has the same type as the type of the first
  50. # tree.
  51. R = type(trees[0])()
  52. # Relabel the nodes so that their union is the integers starting at 1.
  53. if label_attribute is None:
  54. label_attribute = "_old"
  55. relabel = partial(
  56. nx.convert_node_labels_to_integers, label_attribute=label_attribute
  57. )
  58. lengths = (len(tree) for tree in trees[:-1])
  59. first_labels = chain([0], accumulate(lengths))
  60. trees = [
  61. relabel(tree, first_label=first_label + 1)
  62. for tree, first_label in zip(trees, first_labels)
  63. ]
  64. # Get the relabeled roots.
  65. roots = [
  66. next(v for v, d in tree.nodes(data=True) if d.get("_old") == root)
  67. for tree, root in zip(trees, roots)
  68. ]
  69. # Remove the old node labels.
  70. for tree in trees:
  71. for v in tree:
  72. tree.nodes[v].pop("_old")
  73. # Add all sets of nodes and edges, with data.
  74. nodes = (tree.nodes(data=True) for tree in trees)
  75. edges = (tree.edges(data=True) for tree in trees)
  76. R.add_nodes_from(chain.from_iterable(nodes))
  77. R.add_edges_from(chain.from_iterable(edges))
  78. # Add graph attributes; later attributes take precedent over earlier
  79. # attributes.
  80. for tree in trees:
  81. R.graph.update(tree.graph)
  82. # Finally, join the subtrees at the root. We know 0 is unused by the
  83. # way we relabeled the subtrees.
  84. R.add_node(0)
  85. R.add_edges_from((0, root) for root in roots)
  86. return R