# pickle_compat.py
"""
Support pre-0.12 series pickle compatibility.
"""
from __future__ import annotations

import contextlib
import copy
import io
import pickle as pkl
from typing import Generator

import numpy as np

from pandas._libs.arrays import NDArrayBacked
from pandas._libs.tslibs import BaseOffset

from pandas import Index
from pandas.core.arrays import (
    DatetimeArray,
    PeriodArray,
    TimedeltaArray,
)
from pandas.core.internals import BlockManager
  20. def load_reduce(self):
  21. stack = self.stack
  22. args = stack.pop()
  23. func = stack[-1]
  24. try:
  25. stack[-1] = func(*args)
  26. return
  27. except TypeError as err:
  28. # If we have a deprecated function,
  29. # try to replace and try again.
  30. msg = "_reconstruct: First argument must be a sub-type of ndarray"
  31. if msg in str(err):
  32. try:
  33. cls = args[0]
  34. stack[-1] = object.__new__(cls)
  35. return
  36. except TypeError:
  37. pass
  38. elif args and isinstance(args[0], type) and issubclass(args[0], BaseOffset):
  39. # TypeError: object.__new__(Day) is not safe, use Day.__new__()
  40. cls = args[0]
  41. stack[-1] = cls.__new__(*args)
  42. return
  43. elif args and issubclass(args[0], PeriodArray):
  44. cls = args[0]
  45. stack[-1] = NDArrayBacked.__new__(*args)
  46. return
  47. raise
# If classes are moved, provide compat here.
# Maps (old_module, old_name) -> (new_module, new_name) so that pickles
# written before the pandas refactors referenced below (numbers are GH
# issue/PR ids) can still resolve their classes.
_class_locations_map = {
    ("pandas.core.sparse.array", "SparseArray"): ("pandas.core.arrays", "SparseArray"),
    # 15477
    ("pandas.core.base", "FrozenNDArray"): ("numpy", "ndarray"),
    ("pandas.core.indexes.frozen", "FrozenNDArray"): ("numpy", "ndarray"),
    ("pandas.core.base", "FrozenList"): ("pandas.core.indexes.frozen", "FrozenList"),
    # 10890
    ("pandas.core.series", "TimeSeries"): ("pandas.core.series", "Series"),
    ("pandas.sparse.series", "SparseTimeSeries"): (
        "pandas.core.sparse.series",
        "SparseSeries",
    ),
    # 12588, extensions moving
    ("pandas._sparse", "BlockIndex"): ("pandas._libs.sparse", "BlockIndex"),
    ("pandas.tslib", "Timestamp"): ("pandas._libs.tslib", "Timestamp"),
    # 18543 moving period
    ("pandas._period", "Period"): ("pandas._libs.tslibs.period", "Period"),
    ("pandas._libs.period", "Period"): ("pandas._libs.tslibs.period", "Period"),
    # 18014 moved __nat_unpickle from _libs.tslib-->_libs.tslibs.nattype
    ("pandas.tslib", "__nat_unpickle"): (
        "pandas._libs.tslibs.nattype",
        "__nat_unpickle",
    ),
    ("pandas._libs.tslib", "__nat_unpickle"): (
        "pandas._libs.tslibs.nattype",
        "__nat_unpickle",
    ),
    # 15998 top-level dirs moving
    ("pandas.sparse.array", "SparseArray"): (
        "pandas.core.arrays.sparse",
        "SparseArray",
    ),
    ("pandas.indexes.base", "_new_Index"): ("pandas.core.indexes.base", "_new_Index"),
    ("pandas.indexes.base", "Index"): ("pandas.core.indexes.base", "Index"),
    ("pandas.indexes.numeric", "Int64Index"): (
        "pandas.core.indexes.base",
        "Index",  # updated in 50775
    ),
    ("pandas.indexes.range", "RangeIndex"): ("pandas.core.indexes.range", "RangeIndex"),
    ("pandas.indexes.multi", "MultiIndex"): ("pandas.core.indexes.multi", "MultiIndex"),
    ("pandas.tseries.index", "_new_DatetimeIndex"): (
        "pandas.core.indexes.datetimes",
        "_new_DatetimeIndex",
    ),
    ("pandas.tseries.index", "DatetimeIndex"): (
        "pandas.core.indexes.datetimes",
        "DatetimeIndex",
    ),
    ("pandas.tseries.period", "PeriodIndex"): (
        "pandas.core.indexes.period",
        "PeriodIndex",
    ),
    # 19269, arrays moving
    ("pandas.core.categorical", "Categorical"): ("pandas.core.arrays", "Categorical"),
    # 19939, add timedeltaindex, float64index compat from 15998 move
    ("pandas.tseries.tdi", "TimedeltaIndex"): (
        "pandas.core.indexes.timedeltas",
        "TimedeltaIndex",
    ),
    ("pandas.indexes.numeric", "Float64Index"): (
        "pandas.core.indexes.base",
        "Index",  # updated in 50775
    ),
    # 50775, remove Int64Index, UInt64Index & Float64Index from codebase
    ("pandas.core.indexes.numeric", "Int64Index"): (
        "pandas.core.indexes.base",
        "Index",
    ),
    ("pandas.core.indexes.numeric", "UInt64Index"): (
        "pandas.core.indexes.base",
        "Index",
    ),
    ("pandas.core.indexes.numeric", "Float64Index"): (
        "pandas.core.indexes.base",
        "Index",
    ),
}
# our Unpickler sub-class to override methods and some dispatcher
# functions for compat and uses a non-public class of the pickle module.
class Unpickler(pkl._Unpickler):
    def find_class(self, module, name):
        # override superclass: remap moved classes via _class_locations_map
        # before delegating the actual lookup to the base implementation.
        key = (module, name)
        module, name = _class_locations_map.get(key, key)
        return super().find_class(module, name)


# Copy the dispatch table so patches below affect only this subclass,
# not the shared pkl._Unpickler table.
Unpickler.dispatch = copy.copy(Unpickler.dispatch)
# Route REDUCE opcodes through the compat-aware load_reduce.
Unpickler.dispatch[pkl.REDUCE[0]] = load_reduce
  136. def load_newobj(self) -> None:
  137. args = self.stack.pop()
  138. cls = self.stack[-1]
  139. # compat
  140. if issubclass(cls, Index):
  141. obj = object.__new__(cls)
  142. elif issubclass(cls, DatetimeArray) and not args:
  143. arr = np.array([], dtype="M8[ns]")
  144. obj = cls.__new__(cls, arr, arr.dtype)
  145. elif issubclass(cls, TimedeltaArray) and not args:
  146. arr = np.array([], dtype="m8[ns]")
  147. obj = cls.__new__(cls, arr, arr.dtype)
  148. elif cls is BlockManager and not args:
  149. obj = cls.__new__(cls, (), [], False)
  150. else:
  151. obj = cls.__new__(cls, *args)
  152. self.stack[-1] = obj
# Route NEWOBJ opcodes through the compat-aware load_newobj.
Unpickler.dispatch[pkl.NEWOBJ[0]] = load_newobj
  154. def load_newobj_ex(self) -> None:
  155. kwargs = self.stack.pop()
  156. args = self.stack.pop()
  157. cls = self.stack.pop()
  158. # compat
  159. if issubclass(cls, Index):
  160. obj = object.__new__(cls)
  161. else:
  162. obj = cls.__new__(cls, *args, **kwargs)
  163. self.append(obj)
try:
    # The NEWOBJ_EX opcode exists only for pickle protocol >= 4; on
    # interpreters/dispatch tables that lack it, skip the patch.
    Unpickler.dispatch[pkl.NEWOBJ_EX[0]] = load_newobj_ex
except (AttributeError, KeyError):
    pass
  168. def load(fh, encoding: str | None = None, is_verbose: bool = False):
  169. """
  170. Load a pickle, with a provided encoding,
  171. Parameters
  172. ----------
  173. fh : a filelike object
  174. encoding : an optional encoding
  175. is_verbose : show exception output
  176. """
  177. try:
  178. fh.seek(0)
  179. if encoding is not None:
  180. up = Unpickler(fh, encoding=encoding)
  181. else:
  182. up = Unpickler(fh)
  183. # "Unpickler" has no attribute "is_verbose" [attr-defined]
  184. up.is_verbose = is_verbose # type: ignore[attr-defined]
  185. return up.load()
  186. except (ValueError, TypeError):
  187. raise
  188. def loads(
  189. bytes_object: bytes,
  190. *,
  191. fix_imports: bool = True,
  192. encoding: str = "ASCII",
  193. errors: str = "strict",
  194. ):
  195. """
  196. Analogous to pickle._loads.
  197. """
  198. fd = io.BytesIO(bytes_object)
  199. return Unpickler(
  200. fd, fix_imports=fix_imports, encoding=encoding, errors=errors
  201. ).load()
  202. @contextlib.contextmanager
  203. def patch_pickle() -> Generator[None, None, None]:
  204. """
  205. Temporarily patch pickle to use our unpickler.
  206. """
  207. orig_loads = pkl.loads
  208. try:
  209. setattr(pkl, "loads", loads)
  210. yield
  211. finally:
  212. setattr(pkl, "loads", orig_loads)