weak.py 8.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263
  1. import weakref
  2. from weakref import ref
  3. from _weakrefset import _IterationGuard # type: ignore[attr-defined]
  4. from collections.abc import MutableMapping, Mapping
  5. from typing import Dict
  6. import collections.abc as _collections_abc
  7. __all__ = ['WeakIdRef', 'WeakIdKeyDictionary', 'WeakTensorKeyDictionary']
  8. # This file defines a variant of WeakKeyDictionary that overrides the hashing
  9. # behavior of the key to use object identity, rather than the builtin
  10. # __eq__/__hash__ functions. This is useful for Tensor weak keys, as their
  11. # __eq__ implementation return a Tensor (elementwise equality), which means
  12. # you can't use them directly with the WeakKeyDictionary in standard library.
  13. #
  14. # Our implementation strategy is to create a wrapper weak key object, which we
  15. # use as a key in a stock Python dictionary. This is similar to how weakref
  16. # implements WeakKeyDictionary, but instead of using weakref.ref as the
  17. # wrapper, we use a custom wrapper that has different __eq__ and __hash__
  18. # behavior. Note that we subsequently store this weak key directly in an
  19. # ORDINARY dictionary, since the newly constructed WeakIdKey's only use would
  20. # be a dictionary so it would have no strong references. Ensuring that
  21. # only live WeakIdKeys are in the map is handled by putting finalizers on the
  22. # original key object.
  23. # It is simpler to implement this with composition, but if we want to
  24. # directly reuse the callback mechanism on weakref, we need the weakref
  25. # and the key to be exactly the same object. Reusing the callback mechanism
  26. # minimizes the divergence between our implementation and Lib/weakref.py
  27. #
  28. # NB: Prefer using this when working with weakrefs of Tensors; e.g., do
  29. # WeakIdRef(tensor) rather than weakref.ref(tensor); it handles a number of
  30. # easy to get wrong cases transparently for you.
  31. class WeakIdRef(weakref.ref):
  32. __slots__ = ['_id']
  33. def __init__(self, key, callback=None):
  34. # Unlike stock weakref, which preserves hash semantics of the
  35. # original object but lazily defers hash calls until the first
  36. # time the user attempts to hash the weakref, we can eagerly
  37. # cache the id of the key as we know this is definitely the hash
  38. # method
  39. self._id = id(key)
  40. super().__init__(key, callback)
  41. def __call__(self):
  42. r = super().__call__()
  43. # Special logic for Tensor PyObject resurrection
  44. if hasattr(r, '_fix_weakref'):
  45. r._fix_weakref() # type: ignore[union-attr]
  46. return r
  47. def __hash__(self):
  48. return self._id
  49. def __eq__(self, other):
  50. # An attractive but wrong alternate implementation is to only test if
  51. # the stored _ids match. This can lead to an ABA problem if you have:
  52. #
  53. # a1 = A()
  54. # w1 = WeakIdRef(a)
  55. # del a1
  56. # a2 = A() # suppose it gets the same ID as a1
  57. # w2 = WeakIdRef(a2)
  58. # print(w1 == w2)
  59. #
  60. # This should be False, as a1 and a2 are unrelated (and a1 is
  61. # dead anyway)
  62. a = self()
  63. b = other()
  64. if a is not None and b is not None:
  65. return a is b
  66. return self is other
  67. # This is directly adapted from cpython/Lib/weakref.py
  68. class WeakIdKeyDictionary(MutableMapping):
  69. data: Dict[WeakIdRef, object]
  70. def __init__(self, dict=None):
  71. self.data = {}
  72. def remove(k, selfref=ref(self)):
  73. self = selfref()
  74. if self is not None:
  75. if self._iterating:
  76. self._pending_removals.append(k)
  77. else:
  78. try:
  79. del self.data[k]
  80. except KeyError:
  81. pass
  82. self._remove = remove
  83. # A list of dead weakrefs (keys to be removed)
  84. self._pending_removals = []
  85. self._iterating = set()
  86. self._dirty_len = False
  87. if dict is not None:
  88. self.update(dict)
  89. def _commit_removals(self):
  90. # NOTE: We don't need to call this method before mutating the dict,
  91. # because a dead weakref never compares equal to a live weakref,
  92. # even if they happened to refer to equal objects.
  93. # However, it means keys may already have been removed.
  94. pop = self._pending_removals.pop
  95. d = self.data
  96. while True:
  97. try:
  98. key = pop()
  99. except IndexError:
  100. return
  101. try:
  102. del d[key]
  103. except KeyError:
  104. pass
  105. def _scrub_removals(self):
  106. d = self.data
  107. self._pending_removals = [k for k in self._pending_removals if k in d]
  108. self._dirty_len = False
  109. def __delitem__(self, key):
  110. self._dirty_len = True
  111. del self.data[WeakIdRef(key)] # CHANGED
  112. def __getitem__(self, key):
  113. return self.data[WeakIdRef(key)] # CHANGED
  114. def __len__(self):
  115. if self._dirty_len and self._pending_removals:
  116. # self._pending_removals may still contain keys which were
  117. # explicitly removed, we have to scrub them (see issue #21173).
  118. self._scrub_removals()
  119. return len(self.data) - len(self._pending_removals)
  120. def __repr__(self):
  121. return "<%s at %#x>" % (self.__class__.__name__, id(self))
  122. def __setitem__(self, key, value):
  123. self.data[WeakIdRef(key, self._remove)] = value # CHANGED
  124. def copy(self):
  125. new = WeakIdKeyDictionary()
  126. with _IterationGuard(self):
  127. for key, value in self.data.items():
  128. o = key()
  129. if o is not None:
  130. new[o] = value
  131. return new
  132. __copy__ = copy
  133. def __deepcopy__(self, memo):
  134. from copy import deepcopy
  135. new = self.__class__()
  136. with _IterationGuard(self):
  137. for key, value in self.data.items():
  138. o = key()
  139. if o is not None:
  140. new[o] = deepcopy(value, memo)
  141. return new
  142. def get(self, key, default=None):
  143. return self.data.get(WeakIdRef(key), default) # CHANGED
  144. def __contains__(self, key):
  145. try:
  146. wr = WeakIdRef(key)
  147. except TypeError:
  148. return False
  149. return wr in self.data
  150. def items(self):
  151. with _IterationGuard(self):
  152. for wr, value in self.data.items():
  153. key = wr()
  154. if key is not None:
  155. yield key, value
  156. def keys(self):
  157. with _IterationGuard(self):
  158. for wr in self.data:
  159. obj = wr()
  160. if obj is not None:
  161. yield obj
  162. __iter__ = keys
  163. def values(self):
  164. with _IterationGuard(self):
  165. for wr, value in self.data.items():
  166. if wr() is not None:
  167. yield value
  168. def keyrefs(self):
  169. """Return a list of weak references to the keys.
  170. The references are not guaranteed to be 'live' at the time
  171. they are used, so the result of calling the references needs
  172. to be checked before being used. This can be used to avoid
  173. creating references that will cause the garbage collector to
  174. keep the keys around longer than needed.
  175. """
  176. return list(self.data)
  177. def popitem(self):
  178. self._dirty_len = True
  179. while True:
  180. key, value = self.data.popitem()
  181. o = key()
  182. if o is not None:
  183. return o, value
  184. def pop(self, key, *args):
  185. self._dirty_len = True
  186. return self.data.pop(WeakIdRef(key), *args) # CHANGED
  187. def setdefault(self, key, default=None):
  188. return self.data.setdefault(WeakIdRef(key, self._remove), default) # CHANGED
  189. def update(self, dict=None, **kwargs):
  190. d = self.data
  191. if dict is not None:
  192. if not hasattr(dict, "items"):
  193. dict = type({})(dict)
  194. for key, value in dict.items():
  195. d[WeakIdRef(key, self._remove)] = value # CHANGED
  196. if len(kwargs):
  197. self.update(kwargs)
  198. def __ior__(self, other):
  199. self.update(other)
  200. return self
  201. def __or__(self, other):
  202. if isinstance(other, _collections_abc.Mapping):
  203. c = self.copy()
  204. c.update(other)
  205. return c
  206. return NotImplemented
  207. def __ror__(self, other):
  208. if isinstance(other, _collections_abc.Mapping):
  209. c = self.__class__()
  210. c.update(other)
  211. c.update(self)
  212. return c
  213. return NotImplemented
  214. # Default Mapping equality will tests keys for equality, but
  215. # we want to test ids for equality
  216. def __eq__(self, other):
  217. if not isinstance(other, Mapping):
  218. return NotImplemented
  219. return {id(k): v for k, v in self.items()} == {id(k): v for k, v in other.items()}
  220. # Convenience alias
  221. WeakTensorKeyDictionary = WeakIdKeyDictionary