123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237 |
- import importlib
- from abc import ABC, abstractmethod
- from pickle import ( # type: ignore[attr-defined] # type: ignore[attr-defined]
- _getattribute,
- _Pickler,
- whichmodule as _pickle_whichmodule,
- )
- from types import ModuleType
- from typing import Any, Dict, List, Optional, Tuple
- from ._mangling import demangle, get_mangle_prefix, is_mangled
# Public names exported by a star-import of this module.
__all__ = [
    "ObjNotFoundError",
    "ObjMismatchError",
    "Importer",
    "OrderedImporter",
]
class ObjNotFoundError(Exception):
    """Signals that an importer could not locate an object under the name it searched for."""
class ObjMismatchError(Exception):
    """Signals that an importer resolved the name to a different object than the one the user supplied."""
def _resolve_qualname(root: Any, qualname: str) -> Any:
    """Walk the dotted ``qualname`` starting at ``root`` and return the final attribute.

    Implemented locally instead of calling the private ``pickle._getattribute``
    helper, so the lookup does not depend on that helper's argument and return
    conventions, which are not part of pickle's public API.

    Args:
        root: Object (normally a module) the lookup starts from.
        qualname: Dotted attribute path, e.g. ``"Outer.Inner.method"``.

    Returns:
        The object found at the end of the path.

    Raises:
        AttributeError: a path component is missing, or the path contains
            ``<locals>`` (such objects cannot be reached by attribute access).
    """
    current = root
    for part in qualname.split("."):
        if part == "<locals>":
            raise AttributeError(f"Can't get local attribute {qualname!r}")
        current = getattr(current, part)
    return current


class Importer(ABC):
    """Represents an environment to import modules from.

    By default, you can figure out what module an object belongs to by checking
    __module__ and importing the result using __import__ or importlib.import_module.

    torch.package introduces module importers other than the default one.
    Each PackageImporter introduces a new namespace. Potentially a single
    name (e.g. 'foo.bar') is present in multiple namespaces.

    It supports two main operations:
        import_module: module_name -> module object
        get_name: object -> (parent module name, name of obj within module)

    The guarantee is that the following round-trip will succeed or throw an
    ObjNotFoundError/ObjMismatchError.
        module_name, obj_name = env.get_name(obj)
        module = env.import_module(module_name)
        obj2 = getattr(module, obj_name)
        assert obj is obj2
    """

    # Modules known to this environment, keyed by module name. `whichmodule`
    # searches this mapping; entries whose value is None are skipped.
    modules: Dict[str, ModuleType]

    @abstractmethod
    def import_module(self, module_name: str) -> ModuleType:
        """Import `module_name` from this environment.

        The contract is the same as for importlib.import_module.
        """
        pass

    def get_name(self, obj: Any, name: Optional[str] = None) -> Tuple[str, str]:
        """Given an object, return a name that can be used to retrieve the
        object from this environment.

        Args:
            obj: An object to get the module-environment-relative name for.
            name: If set, use this name instead of looking up __name__ or __qualname__ on `obj`.
                This is only here to match how Pickler handles __reduce__ functions that return a string,
                don't use otherwise.

        Returns:
            A tuple (parent_module_name, attr_name) that can be used to retrieve `obj` from this environment.
            Use it like:
                mod = importer.import_module(parent_module_name)
                obj = getattr(mod, attr_name)

        Raises:
            ObjNotFoundError: we couldn't retrieve `obj` by name.
            ObjMismatchError: we found a different object with the same name as `obj`.
        """
        # Compare against None explicitly: evaluating the truthiness of an
        # arbitrary object may call a user-defined __bool__/__len__, which can
        # raise or report False for an object that still has a usable __reduce__.
        if (
            name is None
            and obj is not None
            and _Pickler.dispatch.get(type(obj)) is None
        ):
            # Honor the string return variant of __reduce__, which will give us
            # a global name to search for in this environment.
            # TODO: I guess we should do copyreg too?
            reduce = getattr(obj, "__reduce__", None)
            if reduce is not None:
                try:
                    rv = reduce()
                    if isinstance(rv, str):
                        name = rv
                except Exception:
                    # Best effort only: fall back to __qualname__/__name__ below.
                    pass
        if name is None:
            name = getattr(obj, "__qualname__", None)
        if name is None:
            name = obj.__name__

        orig_module_name = self.whichmodule(obj, name)
        # Demangle the module name before importing. If this obj came out of a
        # PackageImporter, `__module__` will be mangled. See mangling.md for
        # details.
        module_name = demangle(orig_module_name)

        # Check that this name will indeed return the correct object
        try:
            module = self.import_module(module_name)
            obj2 = _resolve_qualname(module, name)
        except (ImportError, KeyError, AttributeError):
            raise ObjNotFoundError(
                f"{obj} was not found as {module_name}.{name}"
            ) from None

        if obj is obj2:
            return module_name, name

        # The name resolved to a different object: build a message describing
        # where each of the two objects comes from.
        def get_obj_info(obj):
            assert name is not None
            module_name = self.whichmodule(obj, name)
            is_mangled_ = is_mangled(module_name)
            location = (
                get_mangle_prefix(module_name)
                if is_mangled_
                else "the current Python environment"
            )
            importer_name = (
                f"the importer for {get_mangle_prefix(module_name)}"
                if is_mangled_
                else "'sys_importer'"
            )
            return module_name, location, importer_name

        obj_module_name, obj_location, obj_importer_name = get_obj_info(obj)
        obj2_module_name, obj2_location, obj2_importer_name = get_obj_info(obj2)
        msg = (
            f"\n\nThe object provided is from '{obj_module_name}', "
            f"which is coming from {obj_location}."
            f"\nHowever, when we import '{obj2_module_name}', it's coming from {obj2_location}."
            "\nTo fix this, make sure this 'PackageExporter's importer lists "
            f"{obj_importer_name} before {obj2_importer_name}."
        )
        raise ObjMismatchError(msg)

    def whichmodule(self, obj: Any, name: str) -> str:
        """Find the module name an object belongs to.

        This should be considered internal for end-users, but developers of
        an importer can override it to customize the behavior.

        Taken from pickle.py, but modified to exclude the search into sys.modules

        Args:
            obj: The object whose owning module is wanted.
            name: Dotted attribute path under which `obj` is expected to be found.

        Returns:
            `obj.__module__` when it is set; otherwise the name of the first
            entry in `self.modules` where `name` resolves to `obj`; otherwise
            "__main__".
        """
        module_name = getattr(obj, "__module__", None)
        if module_name is not None:
            return module_name

        # Protect the iteration by using a copy of self.modules against dynamic
        # modules that trigger imports of other modules upon calls to getattr.
        for module_name, module in self.modules.copy().items():
            if (
                module_name == "__main__"
                or module_name == "__mp_main__"  # bpo-42406
                or module is None
            ):
                continue
            try:
                if _resolve_qualname(module, name) is obj:
                    return module_name
            except AttributeError:
                pass

        return "__main__"
class _SysImporter(Importer):
    """Importer that resolves modules the way plain Python does."""

    def import_module(self, module_name: str):
        """Import `module_name` via `importlib.import_module` and return the module."""
        loaded = importlib.import_module(module_name)
        return loaded

    def whichmodule(self, obj: Any, name: str) -> str:
        """Return the owning module name as computed by pickle's `whichmodule`."""
        owner = _pickle_whichmodule(obj, name)
        return owner


# Module-level instance of the default importer.
sys_importer = _SysImporter()
class OrderedImporter(Importer):
    """Compound importer that consults a list of importers one after another.

    Whichever importer in the list produces a result first "wins".
    """

    def __init__(self, *args):
        """Store the given importers, in the order they will be tried."""
        self._importers: List[Importer] = [*args]

    def _is_torchpackage_dummy(self, module):
        """Returns true iff this module is an empty PackageNode in a torch.package.

        If you intern `a.b` but never use `a` in your code, then `a` will be an
        empty module with no source. This can break cases where we are trying to
        re-package an object after adding a real dependency on `a`, since
        OrderedImporter will resolve `a` to the dummy package and stop there.

        See: https://github.com/pytorch/pytorch/pull/71520#issuecomment-1029603769
        """
        comes_from_package = getattr(module, "__torch_package__", False)
        if comes_from_package and hasattr(module, "__path__"):
            # A package with no `__file__`, or with `__file__` set to None,
            # is treated as the dummy.
            return getattr(module, "__file__", None) is None
        return False

    def import_module(self, module_name: str) -> ModuleType:
        """Return `module_name` from the first importer that yields a non-dummy module.

        Raises:
            TypeError: an entry in the importer list is not an `Importer`.
            ModuleNotFoundError: no importer produced a usable module. The most
                recent error raised by an importer is re-raised when there is one.
        """
        most_recent_failure = None
        for candidate in self._importers:
            if not isinstance(candidate, Importer):
                raise TypeError(
                    f"{candidate} is not a Importer. "
                    "All importers in OrderedImporter must inherit from Importer."
                )
            try:
                found = candidate.import_module(module_name)
                # Dummy packages are passed over so a later importer can
                # supply the real module.
                if not self._is_torchpackage_dummy(found):
                    return found
            except ModuleNotFoundError as err:
                most_recent_failure = err

        if most_recent_failure is None:
            raise ModuleNotFoundError(module_name)
        raise most_recent_failure

    def whichmodule(self, obj: Any, name: str) -> str:
        """Return the first answer other than "__main__" given by the importers, else "__main__"."""
        for candidate in self._importers:
            owner = candidate.whichmodule(obj, name)
            if owner != "__main__":
                return owner
        return "__main__"
|