union_find.py 3.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106
"""
Union-find data structure.
"""
  4. from networkx.utils import groups
  5. class UnionFind:
  6. """Union-find data structure.
  7. Each unionFind instance X maintains a family of disjoint sets of
  8. hashable objects, supporting the following two methods:
  9. - X[item] returns a name for the set containing the given item.
  10. Each set is named by an arbitrarily-chosen one of its members; as
  11. long as the set remains unchanged it will keep the same name. If
  12. the item is not yet part of a set in X, a new singleton set is
  13. created for it.
  14. - X.union(item1, item2, ...) merges the sets containing each item
  15. into a single larger set. If any item is not yet part of a set
  16. in X, it is added to X as one of the members of the merged set.
  17. Union-find data structure. Based on Josiah Carlson's code,
  18. https://code.activestate.com/recipes/215912/
  19. with significant additional changes by D. Eppstein.
  20. http://www.ics.uci.edu/~eppstein/PADS/UnionFind.py
  21. """
  22. def __init__(self, elements=None):
  23. """Create a new empty union-find structure.
  24. If *elements* is an iterable, this structure will be initialized
  25. with the discrete partition on the given set of elements.
  26. """
  27. if elements is None:
  28. elements = ()
  29. self.parents = {}
  30. self.weights = {}
  31. for x in elements:
  32. self.weights[x] = 1
  33. self.parents[x] = x
  34. def __getitem__(self, object):
  35. """Find and return the name of the set containing the object."""
  36. # check for previously unknown object
  37. if object not in self.parents:
  38. self.parents[object] = object
  39. self.weights[object] = 1
  40. return object
  41. # find path of objects leading to the root
  42. path = []
  43. root = self.parents[object]
  44. while root != object:
  45. path.append(object)
  46. object = root
  47. root = self.parents[object]
  48. # compress the path and return
  49. for ancestor in path:
  50. self.parents[ancestor] = root
  51. return root
  52. def __iter__(self):
  53. """Iterate through all items ever found or unioned by this structure."""
  54. return iter(self.parents)
  55. def to_sets(self):
  56. """Iterates over the sets stored in this structure.
  57. For example::
  58. >>> partition = UnionFind("xyz")
  59. >>> sorted(map(sorted, partition.to_sets()))
  60. [['x'], ['y'], ['z']]
  61. >>> partition.union("x", "y")
  62. >>> sorted(map(sorted, partition.to_sets()))
  63. [['x', 'y'], ['z']]
  64. """
  65. # Ensure fully pruned paths
  66. for x in self.parents:
  67. _ = self[x] # Evaluated for side-effect only
  68. yield from groups(self.parents).values()
  69. def union(self, *objects):
  70. """Find the sets containing the objects and merge them all."""
  71. # Find the heaviest root according to its weight.
  72. roots = iter(
  73. sorted(
  74. {self[x] for x in objects}, key=lambda r: self.weights[r], reverse=True
  75. )
  76. )
  77. try:
  78. root = next(roots)
  79. except StopIteration:
  80. return
  81. for r in roots:
  82. self.weights[root] += self.weights[r]
  83. self.parents[r] = root