_package_pickler.py 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118
  1. """isort:skip_file"""
  2. from pickle import ( # type: ignore[attr-defined]
  3. _compat_pickle,
  4. _extension_registry,
  5. _getattribute,
  6. _Pickler,
  7. EXT1,
  8. EXT2,
  9. EXT4,
  10. GLOBAL,
  11. Pickler,
  12. PicklingError,
  13. STACK_GLOBAL,
  14. )
  15. from struct import pack
  16. from types import FunctionType
  17. from .importer import Importer, ObjMismatchError, ObjNotFoundError, sys_importer
  18. class PackagePickler(_Pickler):
  19. """Package-aware pickler.
  20. This behaves the same as a normal pickler, except it uses an `Importer`
  21. to find objects and modules to save.
  22. """
  23. def __init__(self, importer: Importer, *args, **kwargs):
  24. self.importer = importer
  25. super().__init__(*args, **kwargs)
  26. # Make sure the dispatch table copied from _Pickler is up-to-date.
  27. # Previous issues have been encountered where a library (e.g. dill)
  28. # mutate _Pickler.dispatch, PackagePickler makes a copy when this lib
  29. # is imported, then the offending library removes its dispatch entries,
  30. # leaving PackagePickler with a stale dispatch table that may cause
  31. # unwanted behavior.
  32. self.dispatch = _Pickler.dispatch.copy() # type: ignore[misc]
  33. self.dispatch[FunctionType] = PackagePickler.save_global # type: ignore[assignment]
  34. def save_global(self, obj, name=None):
  35. # unfortunately the pickler code is factored in a way that
  36. # forces us to copy/paste this function. The only change is marked
  37. # CHANGED below.
  38. write = self.write # type: ignore[attr-defined]
  39. memo = self.memo # type: ignore[attr-defined]
  40. # CHANGED: import module from module environment instead of __import__
  41. try:
  42. module_name, name = self.importer.get_name(obj, name)
  43. except (ObjNotFoundError, ObjMismatchError) as err:
  44. raise PicklingError(f"Can't pickle {obj}: {str(err)}") from None
  45. module = self.importer.import_module(module_name)
  46. _, parent = _getattribute(module, name)
  47. # END CHANGED
  48. if self.proto >= 2: # type: ignore[attr-defined]
  49. code = _extension_registry.get((module_name, name))
  50. if code:
  51. assert code > 0
  52. if code <= 0xFF:
  53. write(EXT1 + pack("<B", code))
  54. elif code <= 0xFFFF:
  55. write(EXT2 + pack("<H", code))
  56. else:
  57. write(EXT4 + pack("<i", code))
  58. return
  59. lastname = name.rpartition(".")[2]
  60. if parent is module:
  61. name = lastname
  62. # Non-ASCII identifiers are supported only with protocols >= 3.
  63. if self.proto >= 4: # type: ignore[attr-defined]
  64. self.save(module_name) # type: ignore[attr-defined]
  65. self.save(name) # type: ignore[attr-defined]
  66. write(STACK_GLOBAL)
  67. elif parent is not module:
  68. self.save_reduce(getattr, (parent, lastname)) # type: ignore[attr-defined]
  69. elif self.proto >= 3: # type: ignore[attr-defined]
  70. write(
  71. GLOBAL
  72. + bytes(module_name, "utf-8")
  73. + b"\n"
  74. + bytes(name, "utf-8")
  75. + b"\n"
  76. )
  77. else:
  78. if self.fix_imports: # type: ignore[attr-defined]
  79. r_name_mapping = _compat_pickle.REVERSE_NAME_MAPPING
  80. r_import_mapping = _compat_pickle.REVERSE_IMPORT_MAPPING
  81. if (module_name, name) in r_name_mapping:
  82. module_name, name = r_name_mapping[(module_name, name)]
  83. elif module_name in r_import_mapping:
  84. module_name = r_import_mapping[module_name]
  85. try:
  86. write(
  87. GLOBAL
  88. + bytes(module_name, "ascii")
  89. + b"\n"
  90. + bytes(name, "ascii")
  91. + b"\n"
  92. )
  93. except UnicodeEncodeError:
  94. raise PicklingError(
  95. "can't pickle global identifier '%s.%s' using "
  96. "pickle protocol %i" % (module, name, self.proto) # type: ignore[attr-defined]
  97. ) from None
  98. self.memoize(obj) # type: ignore[attr-defined]
  99. def create_pickler(data_buf, importer, protocol=4):
  100. if importer is sys_importer:
  101. # if we are using the normal import library system, then
  102. # we can use the C implementation of pickle which is faster
  103. return Pickler(data_buf, protocol=protocol)
  104. else:
  105. return PackagePickler(importer, data_buf, protocol=protocol)