heaps.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340
  1. """
  2. Min-heaps.
  3. """
  4. from heapq import heappop, heappush
  5. from itertools import count
  6. import networkx as nx
  7. __all__ = ["MinHeap", "PairingHeap", "BinaryHeap"]
  8. class MinHeap:
  9. """Base class for min-heaps.
  10. A MinHeap stores a collection of key-value pairs ordered by their values.
  11. It supports querying the minimum pair, inserting a new pair, decreasing the
  12. value in an existing pair and deleting the minimum pair.
  13. """
  14. class _Item:
  15. """Used by subclassess to represent a key-value pair."""
  16. __slots__ = ("key", "value")
  17. def __init__(self, key, value):
  18. self.key = key
  19. self.value = value
  20. def __repr__(self):
  21. return repr((self.key, self.value))
  22. def __init__(self):
  23. """Initialize a new min-heap."""
  24. self._dict = {}
  25. def min(self):
  26. """Query the minimum key-value pair.
  27. Returns
  28. -------
  29. key, value : tuple
  30. The key-value pair with the minimum value in the heap.
  31. Raises
  32. ------
  33. NetworkXError
  34. If the heap is empty.
  35. """
  36. raise NotImplementedError
  37. def pop(self):
  38. """Delete the minimum pair in the heap.
  39. Returns
  40. -------
  41. key, value : tuple
  42. The key-value pair with the minimum value in the heap.
  43. Raises
  44. ------
  45. NetworkXError
  46. If the heap is empty.
  47. """
  48. raise NotImplementedError
  49. def get(self, key, default=None):
  50. """Returns the value associated with a key.
  51. Parameters
  52. ----------
  53. key : hashable object
  54. The key to be looked up.
  55. default : object
  56. Default value to return if the key is not present in the heap.
  57. Default value: None.
  58. Returns
  59. -------
  60. value : object.
  61. The value associated with the key.
  62. """
  63. raise NotImplementedError
  64. def insert(self, key, value, allow_increase=False):
  65. """Insert a new key-value pair or modify the value in an existing
  66. pair.
  67. Parameters
  68. ----------
  69. key : hashable object
  70. The key.
  71. value : object comparable with existing values.
  72. The value.
  73. allow_increase : bool
  74. Whether the value is allowed to increase. If False, attempts to
  75. increase an existing value have no effect. Default value: False.
  76. Returns
  77. -------
  78. decreased : bool
  79. True if a pair is inserted or the existing value is decreased.
  80. """
  81. raise NotImplementedError
  82. def __nonzero__(self):
  83. """Returns whether the heap if empty."""
  84. return bool(self._dict)
  85. def __bool__(self):
  86. """Returns whether the heap if empty."""
  87. return bool(self._dict)
  88. def __len__(self):
  89. """Returns the number of key-value pairs in the heap."""
  90. return len(self._dict)
  91. def __contains__(self, key):
  92. """Returns whether a key exists in the heap.
  93. Parameters
  94. ----------
  95. key : any hashable object.
  96. The key to be looked up.
  97. """
  98. return key in self._dict
  99. class PairingHeap(MinHeap):
  100. """A pairing heap."""
  101. class _Node(MinHeap._Item):
  102. """A node in a pairing heap.
  103. A tree in a pairing heap is stored using the left-child, right-sibling
  104. representation.
  105. """
  106. __slots__ = ("left", "next", "prev", "parent")
  107. def __init__(self, key, value):
  108. super().__init__(key, value)
  109. # The leftmost child.
  110. self.left = None
  111. # The next sibling.
  112. self.next = None
  113. # The previous sibling.
  114. self.prev = None
  115. # The parent.
  116. self.parent = None
  117. def __init__(self):
  118. """Initialize a pairing heap."""
  119. super().__init__()
  120. self._root = None
  121. def min(self):
  122. if self._root is None:
  123. raise nx.NetworkXError("heap is empty.")
  124. return (self._root.key, self._root.value)
  125. def pop(self):
  126. if self._root is None:
  127. raise nx.NetworkXError("heap is empty.")
  128. min_node = self._root
  129. self._root = self._merge_children(self._root)
  130. del self._dict[min_node.key]
  131. return (min_node.key, min_node.value)
  132. def get(self, key, default=None):
  133. node = self._dict.get(key)
  134. return node.value if node is not None else default
  135. def insert(self, key, value, allow_increase=False):
  136. node = self._dict.get(key)
  137. root = self._root
  138. if node is not None:
  139. if value < node.value:
  140. node.value = value
  141. if node is not root and value < node.parent.value:
  142. self._cut(node)
  143. self._root = self._link(root, node)
  144. return True
  145. elif allow_increase and value > node.value:
  146. node.value = value
  147. child = self._merge_children(node)
  148. # Nonstandard step: Link the merged subtree with the root. See
  149. # below for the standard step.
  150. if child is not None:
  151. self._root = self._link(self._root, child)
  152. # Standard step: Perform a decrease followed by a pop as if the
  153. # value were the smallest in the heap. Then insert the new
  154. # value into the heap.
  155. # if node is not root:
  156. # self._cut(node)
  157. # if child is not None:
  158. # root = self._link(root, child)
  159. # self._root = self._link(root, node)
  160. # else:
  161. # self._root = (self._link(node, child)
  162. # if child is not None else node)
  163. return False
  164. else:
  165. # Insert a new key.
  166. node = self._Node(key, value)
  167. self._dict[key] = node
  168. self._root = self._link(root, node) if root is not None else node
  169. return True
  170. def _link(self, root, other):
  171. """Link two nodes, making the one with the smaller value the parent of
  172. the other.
  173. """
  174. if other.value < root.value:
  175. root, other = other, root
  176. next = root.left
  177. other.next = next
  178. if next is not None:
  179. next.prev = other
  180. other.prev = None
  181. root.left = other
  182. other.parent = root
  183. return root
  184. def _merge_children(self, root):
  185. """Merge the subtrees of the root using the standard two-pass method.
  186. The resulting subtree is detached from the root.
  187. """
  188. node = root.left
  189. root.left = None
  190. if node is not None:
  191. link = self._link
  192. # Pass 1: Merge pairs of consecutive subtrees from left to right.
  193. # At the end of the pass, only the prev pointers of the resulting
  194. # subtrees have meaningful values. The other pointers will be fixed
  195. # in pass 2.
  196. prev = None
  197. while True:
  198. next = node.next
  199. if next is None:
  200. node.prev = prev
  201. break
  202. next_next = next.next
  203. node = link(node, next)
  204. node.prev = prev
  205. prev = node
  206. if next_next is None:
  207. break
  208. node = next_next
  209. # Pass 2: Successively merge the subtrees produced by pass 1 from
  210. # right to left with the rightmost one.
  211. prev = node.prev
  212. while prev is not None:
  213. prev_prev = prev.prev
  214. node = link(prev, node)
  215. prev = prev_prev
  216. # Now node can become the new root. Its has no parent nor siblings.
  217. node.prev = None
  218. node.next = None
  219. node.parent = None
  220. return node
  221. def _cut(self, node):
  222. """Cut a node from its parent."""
  223. prev = node.prev
  224. next = node.next
  225. if prev is not None:
  226. prev.next = next
  227. else:
  228. node.parent.left = next
  229. node.prev = None
  230. if next is not None:
  231. next.prev = prev
  232. node.next = None
  233. node.parent = None
  234. class BinaryHeap(MinHeap):
  235. """A binary heap."""
  236. def __init__(self):
  237. """Initialize a binary heap."""
  238. super().__init__()
  239. self._heap = []
  240. self._count = count()
  241. def min(self):
  242. dict = self._dict
  243. if not dict:
  244. raise nx.NetworkXError("heap is empty")
  245. heap = self._heap
  246. pop = heappop
  247. # Repeatedly remove stale key-value pairs until a up-to-date one is
  248. # met.
  249. while True:
  250. value, _, key = heap[0]
  251. if key in dict and value == dict[key]:
  252. break
  253. pop(heap)
  254. return (key, value)
  255. def pop(self):
  256. dict = self._dict
  257. if not dict:
  258. raise nx.NetworkXError("heap is empty")
  259. heap = self._heap
  260. pop = heappop
  261. # Repeatedly remove stale key-value pairs until a up-to-date one is
  262. # met.
  263. while True:
  264. value, _, key = heap[0]
  265. pop(heap)
  266. if key in dict and value == dict[key]:
  267. break
  268. del dict[key]
  269. return (key, value)
  270. def get(self, key, default=None):
  271. return self._dict.get(key, default)
  272. def insert(self, key, value, allow_increase=False):
  273. dict = self._dict
  274. if key in dict:
  275. old_value = dict[key]
  276. if value < old_value or (allow_increase and value > old_value):
  277. # Since there is no way to efficiently obtain the location of a
  278. # key-value pair in the heap, insert a new pair even if ones
  279. # with the same key may already be present. Deem the old ones
  280. # as stale and skip them when the minimum pair is queried.
  281. dict[key] = value
  282. heappush(self._heap, (value, next(self._count), key))
  283. return value < old_value
  284. return False
  285. else:
  286. dict[key] = value
  287. heappush(self._heap, (value, next(self._count), key))
  288. return True