123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228 |
- """
- Disjoint set data structure
- """
class DisjointSet:
    """ Disjoint set data structure for incremental connectivity queries.

    .. versionadded:: 1.6.0

    Attributes
    ----------
    n_subsets : int
        The number of subsets.

    Methods
    -------
    add
    merge
    connected
    subset
    subset_size
    subsets
    __getitem__

    Notes
    -----
    This class implements the disjoint set [1]_, also known as the *union-find*
    or *merge-find* data structure. The *find* operation (implemented in
    `__getitem__`) implements the *path halving* variant. The *merge* method
    implements the *merge by size* variant.

    Elements are looked up the same way dictionary keys are (by hash and
    equality). When a query passes an object that is equal to a stored
    element but is a different object (e.g. ``True`` for ``1``), the methods
    that return elements always return the stored element.

    References
    ----------
    .. [1] https://en.wikipedia.org/wiki/Disjoint-set_data_structure

    Examples
    --------
    >>> from scipy.cluster.hierarchy import DisjointSet

    Initialize a disjoint set:

    >>> disjoint_set = DisjointSet([1, 2, 3, 'a', 'b'])

    Merge some subsets:

    >>> disjoint_set.merge(1, 2)
    True
    >>> disjoint_set.merge(3, 'a')
    True
    >>> disjoint_set.merge('a', 'b')
    True
    >>> disjoint_set.merge('b', 'b')
    False

    Find root elements:

    >>> disjoint_set[2]
    1
    >>> disjoint_set['b']
    3

    Test connectivity:

    >>> disjoint_set.connected(1, 2)
    True
    >>> disjoint_set.connected(1, 'b')
    False

    List elements in disjoint set:

    >>> list(disjoint_set)
    [1, 2, 3, 'a', 'b']

    Get the subset containing 'a':

    >>> disjoint_set.subset('a')
    {'a', 3, 'b'}

    Get the size of the subset containing 'a':

    >>> disjoint_set.subset_size('a')
    3

    Get all subsets in the disjoint set:

    >>> disjoint_set.subsets()
    [{1, 2}, {'a', 3, 'b'}]
    """

    def __init__(self, elements=None):
        self.n_subsets = 0
        # Subset sizes keyed by element; only kept up to date for roots,
        # since `merge` only adds into the surviving root's entry.
        self._sizes = {}
        # Parent pointers of the union-find forest; a root is its own parent.
        self._parents = {}
        # _nbrs is a circular linked list which links connected elements.
        self._nbrs = {}
        # _indices tracks the element insertion order in `__iter__`.
        self._indices = {}
        if elements is not None:
            for x in elements:
                self.add(x)

    def __iter__(self):
        """Returns an iterator of the elements in the disjoint set.

        Elements are ordered by insertion order.
        """
        return iter(self._indices)

    def __len__(self):
        return len(self._indices)

    def __contains__(self, x):
        return x in self._indices

    def __getitem__(self, x):
        """Find the root element of `x`.

        Parameters
        ----------
        x : hashable object
            Input element.

        Returns
        -------
        root : hashable object
            Root element of `x`, as stored in the disjoint set (even when `x`
            is a different object that merely compares equal to an element).

        Raises
        ------
        KeyError
            If `x` is not an element of the disjoint set.
        """
        if x not in self._indices:
            raise KeyError(x)

        # find by "path halving": while walking towards the root, re-point
        # every visited node at its grandparent.
        parents = self._parents
        indices = self._indices
        # Compare insertion indices rather than the objects themselves, so the
        # test is dict-key identity and not the elements' own `==`.
        while indices[x] != indices[parents[x]]:
            parents[x] = parents[parents[x]]
            x = parents[x]
        # If `x` was already a root on entry it is still the caller's object;
        # a root is its own parent, so `parents[x]` is the stored element.
        return parents[x]

    def add(self, x):
        """Add element `x` to disjoint set.

        Does nothing if `x` is already an element.
        """
        if x in self._indices:
            return

        self._sizes[x] = 1
        self._parents[x] = x
        self._nbrs[x] = x
        self._indices[x] = len(self._indices)
        self.n_subsets += 1

    def merge(self, x, y):
        """Merge the subsets of `x` and `y`.

        The smaller subset (the child) is merged into the larger subset (the
        parent). If the subsets are of equal size, the root element which was
        first inserted into the disjoint set is selected as the parent.

        Parameters
        ----------
        x, y : hashable object
            Elements to merge.

        Returns
        -------
        merged : bool
            True if `x` and `y` were in disjoint sets, False otherwise.

        Raises
        ------
        KeyError
            If `x` or `y` is not an element of the disjoint set.
        """
        xr = self[x]
        yr = self[y]
        if self._indices[xr] == self._indices[yr]:
            return False

        sizes = self._sizes
        # Make `xr` the larger root; on a size tie keep the earlier-inserted one.
        if (sizes[xr], self._indices[yr]) < (sizes[yr], self._indices[xr]):
            xr, yr = yr, xr
        self._parents[yr] = xr
        self._sizes[xr] += self._sizes[yr]
        # Swapping successors splices the two circular lists into one ring.
        self._nbrs[xr], self._nbrs[yr] = self._nbrs[yr], self._nbrs[xr]
        self.n_subsets -= 1
        return True

    def connected(self, x, y):
        """Test whether `x` and `y` are in the same subset.

        Parameters
        ----------
        x, y : hashable object
            Elements to test.

        Returns
        -------
        result : bool
            True if `x` and `y` are in the same set, False otherwise.

        Raises
        ------
        KeyError
            If `x` or `y` is not an element of the disjoint set.
        """
        return self._indices[self[x]] == self._indices[self[y]]

    def subset(self, x):
        """Get the subset containing `x`.

        Parameters
        ----------
        x : hashable object
            Input element.

        Returns
        -------
        result : set
            Subset containing `x`, made of the elements as stored in the
            disjoint set.

        Raises
        ------
        KeyError
            If `x` is not an element of the disjoint set.
        """
        if x not in self._indices:
            raise KeyError(x)

        nbrs = self._nbrs
        indices = self._indices
        start = indices[x]
        # Walk the ring starting at the successor of `x`; the walk ends on the
        # stored copy of `x`, so only stored elements end up in the result.
        nxt = nbrs[x]
        result = {nxt}
        while indices[nxt] != start:
            nxt = nbrs[nxt]
            result.add(nxt)
        return result

    def subset_size(self, x):
        """Get the size of the subset containing `x`.

        This is equivalent to ``len(self.subset(x))`` but reads the size kept
        by the merge-by-size bookkeeping instead of building the subset.

        Parameters
        ----------
        x : hashable object
            Input element.

        Returns
        -------
        result : int
            Size of the subset containing `x`.

        Raises
        ------
        KeyError
            If `x` is not an element of the disjoint set.
        """
        # Sizes are only accurate for roots, so look up via the root.
        return self._sizes[self[x]]

    def subsets(self):
        """Get all the subsets in the disjoint set.

        Returns
        -------
        result : list
            Subsets in the disjoint set, ordered by the insertion order of
            each subset's first-inserted element.
        """
        result = []
        visited = set()
        for x in self:
            if x not in visited:
                xset = self.subset(x)
                visited.update(xset)
                result.append(xset)
        return result
|