123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118 |
- """isort:skip_file"""
- from pickle import ( # type: ignore[attr-defined]
- _compat_pickle,
- _extension_registry,
- _getattribute,
- _Pickler,
- EXT1,
- EXT2,
- EXT4,
- GLOBAL,
- Pickler,
- PicklingError,
- STACK_GLOBAL,
- )
- from struct import pack
- from types import FunctionType
- from .importer import Importer, ObjMismatchError, ObjNotFoundError, sys_importer
class PackagePickler(_Pickler):
    """Package-aware pickler.

    This behaves the same as a normal pickler, except it uses an `Importer`
    to find objects and modules to save.
    """

    def __init__(self, importer: Importer, *args, **kwargs):
        """Create a pickler that resolves globals through ``importer``.

        Args:
            importer: The `Importer` used to name and import the module that
                owns each global being pickled.
            *args, **kwargs: Forwarded unchanged to ``pickle._Pickler``
                (e.g. the output buffer and ``protocol``).
        """
        self.importer = importer
        super().__init__(*args, **kwargs)

        # Make sure the dispatch table copied from _Pickler is up-to-date.
        # Issues have been encountered where a library (e.g. dill) mutates
        # _Pickler.dispatch, a copy of that table is taken while the library's
        # entries are present, and the library later removes its entries --
        # leaving the copy stale and causing unwanted behavior. Copying here,
        # per instance, keeps the table in sync with _Pickler's current state.
        self.dispatch = _Pickler.dispatch.copy()  # type: ignore[misc]
        # Route plain functions through the importer-aware save_global below.
        self.dispatch[FunctionType] = PackagePickler.save_global  # type: ignore[assignment]

    @staticmethod
    def _resolve_parent(module, dotted_name):
        """Return the object owning the last component of ``dotted_name``.

        Walks ``dotted_name`` (e.g. ``"Outer.Inner.method"``) attribute by
        attribute, starting at ``module``. Every component, including the last,
        must resolve, otherwise ``AttributeError`` propagates. For an undotted
        name the parent is ``module`` itself.

        This replaces ``pickle._getattribute``. That helper is private, so its
        signature and return shape are not guaranteed across Python versions.
        """
        parent = obj = module
        for part in dotted_name.split("."):
            parent, obj = obj, getattr(obj, part)
        return parent

    def save_global(self, obj, name=None):
        """Write a by-reference record for the global ``obj``.

        This is a copy of ``pickle._Pickler.save_global``. Unfortunately the
        pickler code is factored in a way that forces us to copy/paste the
        whole function. The only functional change is marked CHANGED below:
        the owning module is named and imported through ``self.importer``
        instead of ``__import__``.

        Args:
            obj: The object to pickle by reference.
            name: Optional qualified name for ``obj``; passed to
                ``self.importer.get_name``.

        Raises:
            PicklingError: if ``self.importer.get_name`` raises
                ``ObjNotFoundError`` or ``ObjMismatchError``, or if the
                identifier cannot be ASCII-encoded for protocols below 3.
            AttributeError: if ``name`` does not resolve inside the module
                returned by ``self.importer.import_module``.
        """
        write = self.write  # type: ignore[attr-defined]

        # CHANGED: import module from module environment instead of __import__
        try:
            module_name, name = self.importer.get_name(obj, name)
        except (ObjNotFoundError, ObjMismatchError) as err:
            raise PicklingError(f"Can't pickle {obj}: {err}") from None

        module = self.importer.import_module(module_name)
        parent = self._resolve_parent(module, name)
        # END CHANGED

        if self.proto >= 2:  # type: ignore[attr-defined]
            # Globals present in pickle's extension registry are written as a
            # compact integer code (EXT1/EXT2/EXT4 depending on its size) and
            # are not memoized.
            code = _extension_registry.get((module_name, name))
            if code:
                assert code > 0
                if code <= 0xFF:
                    write(EXT1 + pack("<B", code))
                elif code <= 0xFFFF:
                    write(EXT2 + pack("<H", code))
                else:
                    write(EXT4 + pack("<i", code))
                return

        lastname = name.rpartition(".")[2]
        if parent is module:
            # Direct attribute of the module: refer to it by its bare name.
            name = lastname
        # Non-ASCII identifiers are supported only with protocols >= 3.
        if self.proto >= 4:  # type: ignore[attr-defined]
            self.save(module_name)  # type: ignore[attr-defined]
            self.save(name)  # type: ignore[attr-defined]
            write(STACK_GLOBAL)
        elif parent is not module:
            # Below protocol 4 a nested attribute is written as
            # getattr(parent, lastname) rather than as a dotted GLOBAL name.
            self.save_reduce(getattr, (parent, lastname))  # type: ignore[attr-defined]
        elif self.proto >= 3:  # type: ignore[attr-defined]
            write(
                GLOBAL
                + bytes(module_name, "utf-8")
                + b"\n"
                + bytes(name, "utf-8")
                + b"\n"
            )
        else:
            if self.fix_imports:  # type: ignore[attr-defined]
                # Map Python 3 module/attribute names back to the Python 2
                # names recorded in pickle's _compat_pickle tables.
                r_name_mapping = _compat_pickle.REVERSE_NAME_MAPPING
                r_import_mapping = _compat_pickle.REVERSE_IMPORT_MAPPING
                if (module_name, name) in r_name_mapping:
                    module_name, name = r_name_mapping[(module_name, name)]
                elif module_name in r_import_mapping:
                    module_name = r_import_mapping[module_name]
            try:
                write(
                    GLOBAL
                    + bytes(module_name, "ascii")
                    + b"\n"
                    + bytes(name, "ascii")
                    + b"\n"
                )
            except UnicodeEncodeError:
                # Report the module *name* being encoded, not the module object.
                raise PicklingError(
                    "can't pickle global identifier '%s.%s' using "
                    "pickle protocol %i" % (module_name, name, self.proto)  # type: ignore[attr-defined]
                ) from None

        self.memoize(obj)  # type: ignore[attr-defined]
def create_pickler(data_buf, importer, protocol=4):
    """Build the pickler that should write ``data_buf`` at ``protocol``.

    A custom ``importer`` requires the importer-aware ``PackagePickler``.
    When ``importer`` is the normal ``sys_importer``, the stock
    ``pickle.Pickler`` is returned instead.
    """
    if importer is not sys_importer:
        return PackagePickler(importer, data_buf, protocol=protocol)
    # Normal import system in use: the C implementation of pickle is faster.
    return Pickler(data_buf, protocol=protocol)
|