package_importer.py 30 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752
import builtins
import importlib
import importlib.machinery
import importlib.util
import inspect
import io
import linecache
import os.path
import types
from contextlib import contextmanager
from pathlib import Path
from typing import Any, BinaryIO, Callable, cast, Dict, Iterable, List, Optional, Union
from weakref import WeakValueDictionary

import torch
from torch.serialization import _get_restore_location, _maybe_decode_ascii

from ._directory_reader import DirectoryReader
from ._importlib import (
    _calc___package__,
    _normalize_line_endings,
    _normalize_path,
    _resolve_name,
    _sanity_check,
)
from ._mangling import demangle, PackageMangler
from ._package_unpickler import PackageUnpickler
from .file_structure_representation import _create_directory_from_file_list, Directory
from .glob_group import GlobPattern
from .importer import Importer
  28. __all__ = ["PackageImporter"]
  29. # This is a list of imports that are implicitly allowed even if they haven't
  30. # been marked as extern. This is to work around the fact that Torch implicitly
  31. # depends on numpy and package can't track it.
  32. # https://github.com/pytorch/MultiPy/issues/46
  33. IMPLICIT_IMPORT_ALLOWLIST: Iterable[str] = [
  34. "numpy",
  35. "numpy.core",
  36. "numpy.core._multiarray_umath",
  37. # FX GraphModule might depend on builtins module and users usually
  38. # don't extern builtins. Here we import it here by default.
  39. "builtins",
  40. ]
  41. class PackageImporter(Importer):
  42. """Importers allow you to load code written to packages by :class:`PackageExporter`.
  43. Code is loaded in a hermetic way, using files from the package
  44. rather than the normal python import system. This allows
  45. for the packaging of PyTorch model code and data so that it can be run
  46. on a server or used in the future for transfer learning.
  47. The importer for packages ensures that code in the module can only be loaded from
  48. within the package, except for modules explicitly listed as external during export.
  49. The file ``extern_modules`` in the zip archive lists all the modules that a package externally depends on.
  50. This prevents "implicit" dependencies where the package runs locally because it is importing
  51. a locally-installed package, but then fails when the package is copied to another machine.
  52. """
  53. """The dictionary of already loaded modules from this package, equivalent to ``sys.modules`` but
  54. local to this importer.
  55. """
  56. modules: Dict[str, types.ModuleType]
  57. def __init__(
  58. self,
  59. file_or_buffer: Union[str, torch._C.PyTorchFileReader, Path, BinaryIO],
  60. module_allowed: Callable[[str], bool] = lambda module_name: True,
  61. ):
  62. """Open ``file_or_buffer`` for importing. This checks that the imported package only requires modules
  63. allowed by ``module_allowed``
  64. Args:
  65. file_or_buffer: a file-like object (has to implement :meth:`read`, :meth:`readline`, :meth:`tell`, and :meth:`seek`),
  66. a string, or an ``os.PathLike`` object containing a filename.
  67. module_allowed (Callable[[str], bool], optional): A method to determine if a externally provided module
  68. should be allowed. Can be used to ensure packages loaded do not depend on modules that the server
  69. does not support. Defaults to allowing anything.
  70. Raises:
  71. ImportError: If the package will use a disallowed module.
  72. """
  73. torch._C._log_api_usage_once("torch.package.PackageImporter")
  74. self.zip_reader: Any
  75. if isinstance(file_or_buffer, torch._C.PyTorchFileReader):
  76. self.filename = "<pytorch_file_reader>"
  77. self.zip_reader = file_or_buffer
  78. elif isinstance(file_or_buffer, (Path, str)):
  79. self.filename = str(file_or_buffer)
  80. if not os.path.isdir(self.filename):
  81. self.zip_reader = torch._C.PyTorchFileReader(self.filename)
  82. else:
  83. self.zip_reader = DirectoryReader(self.filename)
  84. else:
  85. self.filename = "<binary>"
  86. self.zip_reader = torch._C.PyTorchFileReader(file_or_buffer)
  87. self.root = _PackageNode(None)
  88. self.modules = {}
  89. self.extern_modules = self._read_extern()
  90. for extern_module in self.extern_modules:
  91. if not module_allowed(extern_module):
  92. raise ImportError(
  93. f"package '{file_or_buffer}' needs the external module '{extern_module}' "
  94. f"but that module has been disallowed"
  95. )
  96. self._add_extern(extern_module)
  97. for fname in self.zip_reader.get_all_records():
  98. self._add_file(fname)
  99. self.patched_builtins = builtins.__dict__.copy()
  100. self.patched_builtins["__import__"] = self.__import__
  101. # Allow packaged modules to reference their PackageImporter
  102. self.modules["torch_package_importer"] = self # type: ignore[assignment]
  103. self._mangler = PackageMangler()
  104. # used for reduce deserializaiton
  105. self.storage_context: Any = None
  106. self.last_map_location = None
  107. # used for torch.serialization._load
  108. self.Unpickler = lambda *args, **kwargs: PackageUnpickler(self, *args, **kwargs)
  109. def import_module(self, name: str, package=None):
  110. """Load a module from the package if it hasn't already been loaded, and then return
  111. the module. Modules are loaded locally
  112. to the importer and will appear in ``self.modules`` rather than ``sys.modules``.
  113. Args:
  114. name (str): Fully qualified name of the module to load.
  115. package ([type], optional): Unused, but present to match the signature of importlib.import_module. Defaults to ``None``.
  116. Returns:
  117. types.ModuleType: The (possibly already) loaded module.
  118. """
  119. # We should always be able to support importing modules from this package.
  120. # This is to support something like:
  121. # obj = importer.load_pickle(...)
  122. # importer.import_module(obj.__module__) <- this string will be mangled
  123. #
  124. # Note that _mangler.demangle will not demangle any module names
  125. # produced by a different PackageImporter instance.
  126. name = self._mangler.demangle(name)
  127. return self._gcd_import(name)
  128. def load_binary(self, package: str, resource: str) -> bytes:
  129. """Load raw bytes.
  130. Args:
  131. package (str): The name of module package (e.g. ``"my_package.my_subpackage"``).
  132. resource (str): The unique name for the resource.
  133. Returns:
  134. bytes: The loaded data.
  135. """
  136. path = self._zipfile_path(package, resource)
  137. return self.zip_reader.get_record(path)
  138. def load_text(
  139. self,
  140. package: str,
  141. resource: str,
  142. encoding: str = "utf-8",
  143. errors: str = "strict",
  144. ) -> str:
  145. """Load a string.
  146. Args:
  147. package (str): The name of module package (e.g. ``"my_package.my_subpackage"``).
  148. resource (str): The unique name for the resource.
  149. encoding (str, optional): Passed to ``decode``. Defaults to ``'utf-8'``.
  150. errors (str, optional): Passed to ``decode``. Defaults to ``'strict'``.
  151. Returns:
  152. str: The loaded text.
  153. """
  154. data = self.load_binary(package, resource)
  155. return data.decode(encoding, errors)
  156. def load_pickle(self, package: str, resource: str, map_location=None) -> Any:
  157. """Unpickles the resource from the package, loading any modules that are needed to construct the objects
  158. using :meth:`import_module`.
  159. Args:
  160. package (str): The name of module package (e.g. ``"my_package.my_subpackage"``).
  161. resource (str): The unique name for the resource.
  162. map_location: Passed to `torch.load` to determine how tensors are mapped to devices. Defaults to ``None``.
  163. Returns:
  164. Any: The unpickled object.
  165. """
  166. pickle_file = self._zipfile_path(package, resource)
  167. restore_location = _get_restore_location(map_location)
  168. loaded_storages = {}
  169. loaded_reduces = {}
  170. storage_context = torch._C.DeserializationStorageContext()
  171. def load_tensor(dtype, size, key, location, restore_location):
  172. name = f"{key}.storage"
  173. if storage_context.has_storage(name):
  174. storage = storage_context.get_storage(name, dtype)._typed_storage()
  175. else:
  176. tensor = self.zip_reader.get_storage_from_record(
  177. ".data/" + name, size, dtype
  178. )
  179. if isinstance(self.zip_reader, torch._C.PyTorchFileReader):
  180. storage_context.add_storage(name, tensor)
  181. storage = tensor._typed_storage()
  182. loaded_storages[key] = restore_location(storage, location)
  183. def persistent_load(saved_id):
  184. assert isinstance(saved_id, tuple)
  185. typename = _maybe_decode_ascii(saved_id[0])
  186. data = saved_id[1:]
  187. if typename == "storage":
  188. storage_type, key, location, size = data
  189. dtype = storage_type.dtype
  190. if key not in loaded_storages:
  191. load_tensor(
  192. dtype,
  193. size,
  194. key,
  195. _maybe_decode_ascii(location),
  196. restore_location,
  197. )
  198. storage = loaded_storages[key]
  199. # TODO: Once we decide to break serialization FC, we can
  200. # stop wrapping with TypedStorage
  201. return torch.storage.TypedStorage(
  202. wrap_storage=storage._untyped_storage, dtype=dtype, _internal=True
  203. )
  204. elif typename == "reduce_package":
  205. # to fix BC breaking change, objects on this load path
  206. # will be loaded multiple times erroneously
  207. if len(data) == 2:
  208. func, args = data
  209. return func(self, *args)
  210. reduce_id, func, args = data
  211. if reduce_id not in loaded_reduces:
  212. loaded_reduces[reduce_id] = func(self, *args)
  213. return loaded_reduces[reduce_id]
  214. else:
  215. f"Unknown typename for persistent_load, expected 'storage' or 'reduce_package' but got '{typename}'"
  216. # Load the data (which may in turn use `persistent_load` to load tensors)
  217. data_file = io.BytesIO(self.zip_reader.get_record(pickle_file))
  218. unpickler = self.Unpickler(data_file)
  219. unpickler.persistent_load = persistent_load # type: ignore[assignment]
  220. @contextmanager
  221. def set_deserialization_context():
  222. # to let reduce_package access deserializaiton context
  223. self.storage_context = storage_context
  224. self.last_map_location = map_location
  225. try:
  226. yield
  227. finally:
  228. self.storage_context = None
  229. self.last_map_location = None
  230. with set_deserialization_context():
  231. result = unpickler.load()
  232. # TODO from zdevito:
  233. # This stateful weird function will need to be removed in our efforts
  234. # to unify the format. It has a race condition if multiple python
  235. # threads try to read independent files
  236. torch._utils._validate_loaded_sparse_tensors()
  237. return result
  238. def id(self):
  239. """
  240. Returns internal identifier that torch.package uses to distinguish :class:`PackageImporter` instances.
  241. Looks like::
  242. <torch_package_0>
  243. """
  244. return self._mangler.parent_name()
  245. def file_structure(
  246. self, *, include: "GlobPattern" = "**", exclude: "GlobPattern" = ()
  247. ) -> Directory:
  248. """Returns a file structure representation of package's zipfile.
  249. Args:
  250. include (Union[List[str], str]): An optional string e.g. ``"my_package.my_subpackage"``, or optional list of strings
  251. for the names of the files to be included in the zipfile representation. This can also be
  252. a glob-style pattern, as described in :meth:`PackageExporter.mock`
  253. exclude (Union[List[str], str]): An optional pattern that excludes files whose name match the pattern.
  254. Returns:
  255. :class:`Directory`
  256. """
  257. return _create_directory_from_file_list(
  258. self.filename, self.zip_reader.get_all_records(), include, exclude
  259. )
  260. def python_version(self):
  261. """Returns the version of python that was used to create this package.
  262. Note: this function is experimental and not Forward Compatible. The plan is to move this into a lock
  263. file later on.
  264. Returns:
  265. :class:`Optional[str]` a python version e.g. 3.8.9 or None if no version was stored with this package
  266. """
  267. python_version_path = ".data/python_version"
  268. return (
  269. self.zip_reader.get_record(python_version_path).decode("utf-8").strip()
  270. if self.zip_reader.has_record(python_version_path)
  271. else None
  272. )
  273. def _read_extern(self):
  274. return (
  275. self.zip_reader.get_record(".data/extern_modules")
  276. .decode("utf-8")
  277. .splitlines(keepends=False)
  278. )
  279. def _make_module(
  280. self, name: str, filename: Optional[str], is_package: bool, parent: str
  281. ):
  282. mangled_filename = self._mangler.mangle(filename) if filename else None
  283. spec = importlib.machinery.ModuleSpec(
  284. name,
  285. self, # type: ignore[arg-type]
  286. origin="<package_importer>",
  287. is_package=is_package,
  288. )
  289. module = importlib.util.module_from_spec(spec)
  290. self.modules[name] = module
  291. module.__name__ = self._mangler.mangle(name)
  292. ns = module.__dict__
  293. ns["__spec__"] = spec
  294. ns["__loader__"] = self
  295. ns["__file__"] = mangled_filename
  296. ns["__cached__"] = None
  297. ns["__builtins__"] = self.patched_builtins
  298. ns["__torch_package__"] = True
  299. # Add this module to our private global registry. It should be unique due to mangling.
  300. assert module.__name__ not in _package_imported_modules
  301. _package_imported_modules[module.__name__] = module
  302. # pre-emptively install on the parent to prevent IMPORT_FROM from trying to
  303. # access sys.modules
  304. self._install_on_parent(parent, name, module)
  305. if filename is not None:
  306. assert mangled_filename is not None
  307. # pre-emptively install the source in `linecache` so that stack traces,
  308. # `inspect`, etc. work.
  309. assert filename not in linecache.cache # type: ignore[attr-defined]
  310. linecache.lazycache(mangled_filename, ns)
  311. code = self._compile_source(filename, mangled_filename)
  312. exec(code, ns)
  313. return module
  314. def _load_module(self, name: str, parent: str):
  315. cur: _PathNode = self.root
  316. for atom in name.split("."):
  317. if not isinstance(cur, _PackageNode) or atom not in cur.children:
  318. if name in IMPLICIT_IMPORT_ALLOWLIST:
  319. module = self.modules[name] = importlib.import_module(name)
  320. return module
  321. raise ModuleNotFoundError(
  322. f'No module named "{name}" in self-contained archive "{self.filename}"'
  323. f" and the module is also not in the list of allowed external modules: {self.extern_modules}",
  324. name=name,
  325. )
  326. cur = cur.children[atom]
  327. if isinstance(cur, _ExternNode):
  328. module = self.modules[name] = importlib.import_module(name)
  329. return module
  330. return self._make_module(name, cur.source_file, isinstance(cur, _PackageNode), parent) # type: ignore[attr-defined]
  331. def _compile_source(self, fullpath: str, mangled_filename: str):
  332. source = self.zip_reader.get_record(fullpath)
  333. source = _normalize_line_endings(source)
  334. return compile(source, mangled_filename, "exec", dont_inherit=True)
  335. # note: named `get_source` so that linecache can find the source
  336. # when this is the __loader__ of a module.
  337. def get_source(self, module_name) -> str:
  338. # linecache calls `get_source` with the `module.__name__` as the argument, so we must demangle it here.
  339. module = self.import_module(demangle(module_name))
  340. return self.zip_reader.get_record(demangle(module.__file__)).decode("utf-8")
  341. # note: named `get_resource_reader` so that importlib.resources can find it.
  342. # This is otherwise considered an internal method.
  343. def get_resource_reader(self, fullname):
  344. try:
  345. package = self._get_package(fullname)
  346. except ImportError:
  347. return None
  348. if package.__loader__ is not self:
  349. return None
  350. return _PackageResourceReader(self, fullname)
  351. def _install_on_parent(self, parent: str, name: str, module: types.ModuleType):
  352. if not parent:
  353. return
  354. # Set the module as an attribute on its parent.
  355. parent_module = self.modules[parent]
  356. if parent_module.__loader__ is self:
  357. setattr(parent_module, name.rpartition(".")[2], module)
  358. # note: copied from cpython's import code, with call to create module replaced with _make_module
  359. def _do_find_and_load(self, name):
  360. path = None
  361. parent = name.rpartition(".")[0]
  362. module_name_no_parent = name.rpartition(".")[-1]
  363. if parent:
  364. if parent not in self.modules:
  365. self._gcd_import(parent)
  366. # Crazy side-effects!
  367. if name in self.modules:
  368. return self.modules[name]
  369. parent_module = self.modules[parent]
  370. try:
  371. path = parent_module.__path__ # type: ignore[attr-defined]
  372. except AttributeError:
  373. # when we attempt to import a package only containing pybinded files,
  374. # the parent directory isn't always a package as defined by python,
  375. # so we search if the package is actually there or not before calling the error.
  376. if isinstance(
  377. parent_module.__loader__,
  378. importlib.machinery.ExtensionFileLoader,
  379. ):
  380. if name not in self.extern_modules:
  381. msg = (
  382. _ERR_MSG
  383. + "; {!r} is a c extension module which was not externed. C extension modules \
  384. need to be externed by the PackageExporter in order to be used as we do not support interning them.}."
  385. ).format(name, name)
  386. raise ModuleNotFoundError(msg, name=name) from None
  387. if not isinstance(
  388. parent_module.__dict__.get(module_name_no_parent),
  389. types.ModuleType,
  390. ):
  391. msg = (
  392. _ERR_MSG
  393. + "; {!r} is a c extension package which does not contain {!r}."
  394. ).format(name, parent, name)
  395. raise ModuleNotFoundError(msg, name=name) from None
  396. else:
  397. msg = (_ERR_MSG + "; {!r} is not a package").format(name, parent)
  398. raise ModuleNotFoundError(msg, name=name) from None
  399. module = self._load_module(name, parent)
  400. self._install_on_parent(parent, name, module)
  401. return module
  402. # note: copied from cpython's import code
  403. def _find_and_load(self, name):
  404. module = self.modules.get(name, _NEEDS_LOADING)
  405. if module is _NEEDS_LOADING:
  406. return self._do_find_and_load(name)
  407. if module is None:
  408. message = "import of {} halted; " "None in sys.modules".format(name)
  409. raise ModuleNotFoundError(message, name=name)
  410. # To handle https://github.com/pytorch/pytorch/issues/57490, where std's
  411. # creation of fake submodules via the hacking of sys.modules is not import
  412. # friendly
  413. if name == "os":
  414. self.modules["os.path"] = cast(Any, module).path
  415. elif name == "typing":
  416. self.modules["typing.io"] = cast(Any, module).io
  417. self.modules["typing.re"] = cast(Any, module).re
  418. return module
  419. def _gcd_import(self, name, package=None, level=0):
  420. """Import and return the module based on its name, the package the call is
  421. being made from, and the level adjustment.
  422. This function represents the greatest common denominator of functionality
  423. between import_module and __import__. This includes setting __package__ if
  424. the loader did not.
  425. """
  426. _sanity_check(name, package, level)
  427. if level > 0:
  428. name = _resolve_name(name, package, level)
  429. return self._find_and_load(name)
  430. # note: copied from cpython's import code
  431. def _handle_fromlist(self, module, fromlist, *, recursive=False):
  432. """Figure out what __import__ should return.
  433. The import_ parameter is a callable which takes the name of module to
  434. import. It is required to decouple the function from assuming importlib's
  435. import implementation is desired.
  436. """
  437. module_name = demangle(module.__name__)
  438. # The hell that is fromlist ...
  439. # If a package was imported, try to import stuff from fromlist.
  440. if hasattr(module, "__path__"):
  441. for x in fromlist:
  442. if not isinstance(x, str):
  443. if recursive:
  444. where = module_name + ".__all__"
  445. else:
  446. where = "``from list''"
  447. raise TypeError(
  448. f"Item in {where} must be str, " f"not {type(x).__name__}"
  449. )
  450. elif x == "*":
  451. if not recursive and hasattr(module, "__all__"):
  452. self._handle_fromlist(module, module.__all__, recursive=True)
  453. elif not hasattr(module, x):
  454. from_name = "{}.{}".format(module_name, x)
  455. try:
  456. self._gcd_import(from_name)
  457. except ModuleNotFoundError as exc:
  458. # Backwards-compatibility dictates we ignore failed
  459. # imports triggered by fromlist for modules that don't
  460. # exist.
  461. if (
  462. exc.name == from_name
  463. and self.modules.get(from_name, _NEEDS_LOADING) is not None
  464. ):
  465. continue
  466. raise
  467. return module
  468. def __import__(self, name, globals=None, locals=None, fromlist=(), level=0):
  469. if level == 0:
  470. module = self._gcd_import(name)
  471. else:
  472. globals_ = globals if globals is not None else {}
  473. package = _calc___package__(globals_)
  474. module = self._gcd_import(name, package, level)
  475. if not fromlist:
  476. # Return up to the first dot in 'name'. This is complicated by the fact
  477. # that 'name' may be relative.
  478. if level == 0:
  479. return self._gcd_import(name.partition(".")[0])
  480. elif not name:
  481. return module
  482. else:
  483. # Figure out where to slice the module's name up to the first dot
  484. # in 'name'.
  485. cut_off = len(name) - len(name.partition(".")[0])
  486. # Slice end needs to be positive to alleviate need to special-case
  487. # when ``'.' not in name``.
  488. module_name = demangle(module.__name__)
  489. return self.modules[module_name[: len(module_name) - cut_off]]
  490. else:
  491. return self._handle_fromlist(module, fromlist)
  492. def _get_package(self, package):
  493. """Take a package name or module object and return the module.
  494. If a name, the module is imported. If the passed or imported module
  495. object is not a package, raise an exception.
  496. """
  497. if hasattr(package, "__spec__"):
  498. if package.__spec__.submodule_search_locations is None:
  499. raise TypeError("{!r} is not a package".format(package.__spec__.name))
  500. else:
  501. return package
  502. else:
  503. module = self.import_module(package)
  504. if module.__spec__.submodule_search_locations is None:
  505. raise TypeError("{!r} is not a package".format(package))
  506. else:
  507. return module
  508. def _zipfile_path(self, package, resource=None):
  509. package = self._get_package(package)
  510. assert package.__loader__ is self
  511. name = demangle(package.__name__)
  512. if resource is not None:
  513. resource = _normalize_path(resource)
  514. return f"{name.replace('.', '/')}/{resource}"
  515. else:
  516. return f"{name.replace('.', '/')}"
  517. def _get_or_create_package(
  518. self, atoms: List[str]
  519. ) -> "Union[_PackageNode, _ExternNode]":
  520. cur = self.root
  521. for i, atom in enumerate(atoms):
  522. node = cur.children.get(atom, None)
  523. if node is None:
  524. node = cur.children[atom] = _PackageNode(None)
  525. if isinstance(node, _ExternNode):
  526. return node
  527. if isinstance(node, _ModuleNode):
  528. name = ".".join(atoms[:i])
  529. raise ImportError(
  530. f"inconsistent module structure. module {name} is not a package, but has submodules"
  531. )
  532. assert isinstance(node, _PackageNode)
  533. cur = node
  534. return cur
  535. def _add_file(self, filename: str):
  536. """Assembles a Python module out of the given file. Will ignore files in the .data directory.
  537. Args:
  538. filename (str): the name of the file inside of the package archive to be added
  539. """
  540. *prefix, last = filename.split("/")
  541. if len(prefix) > 1 and prefix[0] == ".data":
  542. return
  543. package = self._get_or_create_package(prefix)
  544. if isinstance(package, _ExternNode):
  545. raise ImportError(
  546. f"inconsistent module structure. package contains a module file {filename}"
  547. f" that is a subpackage of a module marked external."
  548. )
  549. if last == "__init__.py":
  550. package.source_file = filename
  551. elif last.endswith(".py"):
  552. package_name = last[: -len(".py")]
  553. package.children[package_name] = _ModuleNode(filename)
  554. def _add_extern(self, extern_name: str):
  555. *prefix, last = extern_name.split(".")
  556. package = self._get_or_create_package(prefix)
  557. if isinstance(package, _ExternNode):
  558. return # the shorter extern covers this extern case
  559. package.children[last] = _ExternNode()
  560. _NEEDS_LOADING = object()
  561. _ERR_MSG_PREFIX = "No module named "
  562. _ERR_MSG = _ERR_MSG_PREFIX + "{!r}"
  563. class _PathNode:
  564. pass
  565. class _PackageNode(_PathNode):
  566. def __init__(self, source_file: Optional[str]):
  567. self.source_file = source_file
  568. self.children: Dict[str, _PathNode] = {}
  569. class _ModuleNode(_PathNode):
  570. __slots__ = ["source_file"]
  571. def __init__(self, source_file: str):
  572. self.source_file = source_file
  573. class _ExternNode(_PathNode):
  574. pass
  575. # A private global registry of all modules that have been package-imported.
  576. _package_imported_modules: WeakValueDictionary = WeakValueDictionary()
  577. # `inspect` by default only looks in `sys.modules` to find source files for classes.
  578. # Patch it to check our private registry of package-imported modules as well.
  579. _orig_getfile = inspect.getfile
  580. def _patched_getfile(object):
  581. if inspect.isclass(object):
  582. if object.__module__ in _package_imported_modules:
  583. return _package_imported_modules[object.__module__].__file__
  584. return _orig_getfile(object)
  585. inspect.getfile = _patched_getfile
  586. class _PackageResourceReader:
  587. """Private class used to support PackageImporter.get_resource_reader().
  588. Confirms to the importlib.abc.ResourceReader interface. Allowed to access
  589. the innards of PackageImporter.
  590. """
  591. def __init__(self, importer, fullname):
  592. self.importer = importer
  593. self.fullname = fullname
  594. def open_resource(self, resource):
  595. from io import BytesIO
  596. return BytesIO(self.importer.load_binary(self.fullname, resource))
  597. def resource_path(self, resource):
  598. # The contract for resource_path is that it either returns a concrete
  599. # file system path or raises FileNotFoundError.
  600. if isinstance(
  601. self.importer.zip_reader, DirectoryReader
  602. ) and self.importer.zip_reader.has_record(
  603. os.path.join(self.fullname, resource)
  604. ):
  605. return os.path.join(
  606. self.importer.zip_reader.directory, self.fullname, resource
  607. )
  608. raise FileNotFoundError
  609. def is_resource(self, name):
  610. path = self.importer._zipfile_path(self.fullname, name)
  611. return self.importer.zip_reader.has_record(path)
  612. def contents(self):
  613. from pathlib import Path
  614. filename = self.fullname.replace(".", "/")
  615. fullname_path = Path(self.importer._zipfile_path(self.fullname))
  616. files = self.importer.zip_reader.get_all_records()
  617. subdirs_seen = set()
  618. for filename in files:
  619. try:
  620. relative = Path(filename).relative_to(fullname_path)
  621. except ValueError:
  622. continue
  623. # If the path of the file (which is relative to the top of the zip
  624. # namespace), relative to the package given when the resource
  625. # reader was created, has a parent, then it's a name in a
  626. # subdirectory and thus we skip it.
  627. parent_name = relative.parent.name
  628. if len(parent_name) == 0:
  629. yield relative.name
  630. elif parent_name not in subdirs_seen:
  631. subdirs_seen.add(parent_name)
  632. yield parent_name