_disjoint_set.py 5.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228
"""
Disjoint set data structure.

Provides `DisjointSet`, a union-find structure over hashable elements that
supports incremental merging of subsets and connectivity queries.
"""
  4. class DisjointSet:
  5. """ Disjoint set data structure for incremental connectivity queries.
  6. .. versionadded:: 1.6.0
  7. Attributes
  8. ----------
  9. n_subsets : int
  10. The number of subsets.
  11. Methods
  12. -------
  13. add
  14. merge
  15. connected
  16. subset
  17. subsets
  18. __getitem__
  19. Notes
  20. -----
  21. This class implements the disjoint set [1]_, also known as the *union-find*
  22. or *merge-find* data structure. The *find* operation (implemented in
  23. `__getitem__`) implements the *path halving* variant. The *merge* method
  24. implements the *merge by size* variant.
  25. References
  26. ----------
  27. .. [1] https://en.wikipedia.org/wiki/Disjoint-set_data_structure
  28. Examples
  29. --------
  30. >>> from scipy.cluster.hierarchy import DisjointSet
  31. Initialize a disjoint set:
  32. >>> disjoint_set = DisjointSet([1, 2, 3, 'a', 'b'])
  33. Merge some subsets:
  34. >>> disjoint_set.merge(1, 2)
  35. True
  36. >>> disjoint_set.merge(3, 'a')
  37. True
  38. >>> disjoint_set.merge('a', 'b')
  39. True
  40. >>> disjoint_set.merge('b', 'b')
  41. False
  42. Find root elements:
  43. >>> disjoint_set[2]
  44. 1
  45. >>> disjoint_set['b']
  46. 3
  47. Test connectivity:
  48. >>> disjoint_set.connected(1, 2)
  49. True
  50. >>> disjoint_set.connected(1, 'b')
  51. False
  52. List elements in disjoint set:
  53. >>> list(disjoint_set)
  54. [1, 2, 3, 'a', 'b']
  55. Get the subset containing 'a':
  56. >>> disjoint_set.subset('a')
  57. {'a', 3, 'b'}
  58. Get all subsets in the disjoint set:
  59. >>> disjoint_set.subsets()
  60. [{1, 2}, {'a', 3, 'b'}]
  61. """
  62. def __init__(self, elements=None):
  63. self.n_subsets = 0
  64. self._sizes = {}
  65. self._parents = {}
  66. # _nbrs is a circular linked list which links connected elements.
  67. self._nbrs = {}
  68. # _indices tracks the element insertion order in `__iter__`.
  69. self._indices = {}
  70. if elements is not None:
  71. for x in elements:
  72. self.add(x)
  73. def __iter__(self):
  74. """Returns an iterator of the elements in the disjoint set.
  75. Elements are ordered by insertion order.
  76. """
  77. return iter(self._indices)
  78. def __len__(self):
  79. return len(self._indices)
  80. def __contains__(self, x):
  81. return x in self._indices
  82. def __getitem__(self, x):
  83. """Find the root element of `x`.
  84. Parameters
  85. ----------
  86. x : hashable object
  87. Input element.
  88. Returns
  89. -------
  90. root : hashable object
  91. Root element of `x`.
  92. """
  93. if x not in self._indices:
  94. raise KeyError(x)
  95. # find by "path halving"
  96. parents = self._parents
  97. while self._indices[x] != self._indices[parents[x]]:
  98. parents[x] = parents[parents[x]]
  99. x = parents[x]
  100. return x
  101. def add(self, x):
  102. """Add element `x` to disjoint set
  103. """
  104. if x in self._indices:
  105. return
  106. self._sizes[x] = 1
  107. self._parents[x] = x
  108. self._nbrs[x] = x
  109. self._indices[x] = len(self._indices)
  110. self.n_subsets += 1
  111. def merge(self, x, y):
  112. """Merge the subsets of `x` and `y`.
  113. The smaller subset (the child) is merged into the larger subset (the
  114. parent). If the subsets are of equal size, the root element which was
  115. first inserted into the disjoint set is selected as the parent.
  116. Parameters
  117. ----------
  118. x, y : hashable object
  119. Elements to merge.
  120. Returns
  121. -------
  122. merged : bool
  123. True if `x` and `y` were in disjoint sets, False otherwise.
  124. """
  125. xr = self[x]
  126. yr = self[y]
  127. if self._indices[xr] == self._indices[yr]:
  128. return False
  129. sizes = self._sizes
  130. if (sizes[xr], self._indices[yr]) < (sizes[yr], self._indices[xr]):
  131. xr, yr = yr, xr
  132. self._parents[yr] = xr
  133. self._sizes[xr] += self._sizes[yr]
  134. self._nbrs[xr], self._nbrs[yr] = self._nbrs[yr], self._nbrs[xr]
  135. self.n_subsets -= 1
  136. return True
  137. def connected(self, x, y):
  138. """Test whether `x` and `y` are in the same subset.
  139. Parameters
  140. ----------
  141. x, y : hashable object
  142. Elements to test.
  143. Returns
  144. -------
  145. result : bool
  146. True if `x` and `y` are in the same set, False otherwise.
  147. """
  148. return self._indices[self[x]] == self._indices[self[y]]
  149. def subset(self, x):
  150. """Get the subset containing `x`.
  151. Parameters
  152. ----------
  153. x : hashable object
  154. Input element.
  155. Returns
  156. -------
  157. result : set
  158. Subset containing `x`.
  159. """
  160. if x not in self._indices:
  161. raise KeyError(x)
  162. result = [x]
  163. nxt = self._nbrs[x]
  164. while self._indices[nxt] != self._indices[x]:
  165. result.append(nxt)
  166. nxt = self._nbrs[nxt]
  167. return set(result)
  168. def subsets(self):
  169. """Get all the subsets in the disjoint set.
  170. Returns
  171. -------
  172. result : list
  173. Subsets in the disjoint set.
  174. """
  175. result = []
  176. visited = set()
  177. for x in self:
  178. if x not in visited:
  179. xset = self.subset(x)
  180. visited.update(xset)
  181. result.append(xset)
  182. return result