12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140114111421143114411451146114711481149115011511152115311541155115611571158115911601161116211631164116511661167116811691170117111721173117411751176117711781179118011811182118311841185118611871188118911901191119211931194119511961197119811991200120112021203 |
- import collections
- import importlib.machinery
- import io
- import linecache
- import pickletools
- import platform
- import types
- from collections import defaultdict, OrderedDict
- from dataclasses import dataclass
- from enum import Enum
- from importlib.machinery import SourceFileLoader
- from pathlib import Path
- from typing import (
- Any,
- BinaryIO,
- Callable,
- cast,
- DefaultDict,
- Dict,
- List,
- Optional,
- Sequence,
- Set,
- Union,
- )
- import torch
- from torch.serialization import location_tag, normalize_storage_type
- from torch.types import Storage
- from torch.utils.hooks import RemovableHandle
- from ._digraph import DiGraph
- from ._importlib import _normalize_path
- from ._mangling import demangle, is_mangled
- from ._package_pickler import create_pickler
- from ._stdlib import is_stdlib_module
- from .find_file_dependencies import find_files_source_depends_on
- from .glob_group import GlobGroup, GlobPattern
- from .importer import Importer, OrderedImporter, sys_importer
# Public API of this module; every other name here is an implementation detail.
__all__ = [
    "PackagingErrorReason", "EmptyMatchError", "PackagingError", "PackageExporter",
]

# Module-level flag. It is not read anywhere in the visible part of this file;
# presumably consulted by TorchScript-serialization code elsewhere — TODO confirm.
_gate_torchscript_serialization = True

# Signature of extern/mock/intern hooks: called with the exporter and the name
# of the module that matched the corresponding pattern.
ActionHook = Callable[["PackageExporter", str], None]
class _ModuleProviderAction(Enum):
    """The action :class:`PackageExporter` takes for a particular module.

    The first four members are chosen through user-registered patterns (see
    :meth:`PackageExporter.extern` and friends); the last two are internal
    special cases the exporter assigns by itself.
    """

    INTERN = 1
    EXTERN = 2
    MOCK = 3
    DENY = 4

    # Assigned to a `_mock` module found while re-packaging code that was itself
    # exported with mocks. The exporter writes its own `_mock` module exactly
    # once, so the inherited copy is ignored.
    REPACKAGED_MOCK_MODULE = 5

    # Assigned to `torch_package_importer`, the placeholder module that
    # PackageImporter injects so packaged code can reach it. It must never be
    # re-exported.
    SKIP = 6
class PackagingErrorReason(Enum):
    """Why a dependency could not be packaged.

    Each member's value is the human-readable explanation that
    :class:`PackagingError` prints when it groups the failing modules.
    """

    IS_EXTENSION_MODULE = (
        "Module is a C extension module. torch.package supports Python modules only."
    )
    NO_DUNDER_FILE = "Module had no __file__ defined."
    SOURCE_FILE_NOT_FOUND = (
        "Module had a __file__, but we could not find it in your filesystem."
    )
    DEPENDENCY_RESOLUTION_FAILED = "Dependency resolution failed."
    NO_ACTION = (
        "Module did not match against any action pattern. Extern, mock, or intern it."
    )
    DENIED = "Module was denied by a pattern."
    MOCKED_BUT_STILL_USED = (
        "Module was mocked out, but is still being used in the package. "
        "Please intern or extern the mocked modules if objects are supposed to be in "
        "the package."
    )

    def __repr__(self):
        # Compact form such as ``<PackagingErrorReason.DENIED>`` that omits the
        # (long) message value.
        return f"<{type(self).__name__}.{self.name}>"
@dataclass
class _PatternInfo:
    """Per-pattern bookkeeping kept by :class:`PackageExporter`.

    Records which action to apply to modules matching a pattern, whether the
    user allowed the pattern to go unmatched, and whether it has matched yet.
    """

    # Action applied to every module that matches the pattern.
    action: _ModuleProviderAction
    # The ``allow_empty`` value supplied when the pattern was registered.
    allow_empty: bool
    # Set to True once some module matches the pattern during packaging.
    was_matched: bool

    def __init__(self, action, allow_empty):
        # Hand-written __init__: callers supply only action/allow_empty, and
        # every pattern starts out unmatched.
        self.action, self.allow_empty, self.was_matched = action, allow_empty, False
class EmptyMatchError(Exception):
    """Raised when a mock or extern pattern registered with ``allow_empty=False``
    is never matched by any module during packaging.
    """
class PackagingError(Exception):
    """Raised when there is an issue with exporting a package.

    ``PackageExporter`` gathers every error recorded in its dependency graph
    and reports them together in one exception message, grouped by
    :class:`PackagingErrorReason`.

    Attributes:
        dependency_graph: the graph the errors were collected from, kept so
            tooling can inspect it after the failure.
    """

    def __init__(self, dependency_graph: DiGraph, debug=False):
        """Build the error message from ``dependency_graph``.

        Args:
            dependency_graph: graph whose node attributes may carry ``error``
                (a :class:`PackagingErrorReason`) and ``error_context``.
            debug: if True, append one dependency path leading to each broken
                module (from ``dependency_graph.first_path``).
        """
        # Group erroring modules by reason, in graph iteration order.
        broken: Dict[PackagingErrorReason, List[str]] = defaultdict(list)
        for module_name, attrs in dependency_graph.nodes.items():
            error = attrs.get("error")
            if error is None:
                continue
            if error == PackagingErrorReason.NO_ACTION:
                # A module that matched no pattern must not carry an action.
                assert "action" not in attrs
            broken[error].append(module_name)

        message = io.StringIO()
        message.write("\n")

        for reason, module_names in broken.items():
            message.write(f"* {reason.value}\n")
            for module_name in module_names:
                message.write(f" {module_name}\n")

                # Print additional context if it's provided.
                error_context = dependency_graph.nodes[module_name].get("error_context")
                if error_context is not None:
                    message.write(f" Context: {error_context}\n")
                if module_name in _DISALLOWED_MODULES:
                    message.write(
                        " Note: While we usually use modules in the python standard library "
                        f"from the local environment, `{module_name}` has a lot of system "
                        "level access and therefore can pose a security risk. We heavily "
                        f"recommend removing `{module_name}` from your packaged code. However, if that "
                        "is not possible, add it to the extern list by calling "
                        f'PackageExporter.extern("`{module_name}`")\n'
                    )
                if debug:
                    module_path = dependency_graph.first_path(module_name)
                    # Terminate the line so the next module entry starts on
                    # its own line instead of being appended to this path.
                    message.write(
                        f" A path to {module_name}: {' -> '.join(module_path)}\n"
                    )

        if not debug:
            message.write("\n")
            message.write(
                "Set debug=True when invoking PackageExporter for a visualization of where "
                "broken modules are coming from!\n"
            )

        # Save the dependency graph so that tooling can get at it.
        self.dependency_graph = dependency_graph
        super().__init__(message.getvalue())
- class PackageExporter:
- """Exporters allow you to write packages of code, pickled Python data, and
- arbitrary binary and text resources into a self-contained package.
- Imports can load this code in a hermetic way, such that code is loaded
- from the package rather than the normal Python import system. This allows
- for the packaging of PyTorch model code and data so that it can be run
- on a server or used in the future for transfer learning.
- The code contained in packages is copied file-by-file from the original
- source when it is created, and the file format is a specially organized
- zip file. Future users of the package can unzip the package, and edit the code
- in order to perform custom modifications to it.
- The importer for packages ensures that code in the module can only be loaded from
- within the package, except for modules explicitly listed as external using :meth:`extern`.
- The file ``extern_modules`` in the zip archive lists all the modules that a package externally depends on.
- This prevents "implicit" dependencies where the package runs locally because it is importing
- a locally-installed package, but then fails when the package is copied to another machine.
- When source code is added to the package, the exporter can optionally scan it
- for further code dependencies (``dependencies=True``). It looks for import statements,
- resolves relative references to qualified module names, and performs an action specified by the user
- (See: :meth:`extern`, :meth:`mock`, and :meth:`intern`).
- """
- """A importer that will be searched in order to find the modules referenced by other modules or by
- pickled objects. The default module environment just uses sys_importer, which searches the Python environment.
- """
- importer: Importer
def __init__(
    self,
    f: Union[str, Path, BinaryIO],
    importer: Union[Importer, Sequence[Importer]] = sys_importer,
    debug: bool = False,
):
    """
    Create an exporter.

    Args:
        f: Where to export to: a ``str``/``Path`` filename, or a writable binary I/O object.
        importer: A single Importer used to find modules, or a sequence of importers
            that is wrapped in an ``OrderedImporter`` and searched in order.
        debug: If True, PackagingErrors include a path to each broken module.

    Raises:
        TypeError: if ``importer`` is neither an Importer nor a sequence.
    """
    torch._C._log_api_usage_once("torch.package.PackageExporter")
    self.debug = debug

    # Paths go to the writer as plain strings; anything else is treated as a
    # binary buffer and remembered on ``self.buffer``.
    if not isinstance(f, (Path, str)):
        self.buffer = f
    else:
        f = str(f)
        self.buffer: Optional[BinaryIO] = None

    self.zip_file = torch._C.PyTorchFileWriter(f)
    self.zip_file.set_min_version(6)
    self._written_files: Set[str] = set()

    self.serialized_reduces: Dict[int, Any] = {}

    # Graph of every module and pickle added to the package:
    # - a node is a module name, or a pickle name shaped like '<foo.obj.pkl>';
    # - a directed edge (u, v) means u depends on v;
    # - node attributes carry the metadata used to write the entry out.
    self.dependency_graph = DiGraph()
    self.script_module_serializer = torch._C.ScriptModuleSerializer(self.zip_file)
    self.storage_context = self.script_module_serializer.storage_context()

    # OrderedDicts because RemovableHandle keys into them; conceptually each is
    # OrderedDict[int, Callable[[PackageExporter, str], None]].
    self._extern_hooks: OrderedDict = OrderedDict()
    self._mock_hooks: OrderedDict = OrderedDict()
    self._intern_hooks: OrderedDict = OrderedDict()

    if isinstance(importer, Importer):
        self.importer = importer
    elif isinstance(importer, collections.abc.Sequence):
        self.importer = OrderedImporter(*importer)
    else:
        raise TypeError(
            "importer arg should be an Importer or a sequence of Importers, "
            f"got {type(importer)} instead."
        )

    self.patterns: Dict[GlobGroup, _PatternInfo] = {}
    self._unique_id = 0
def save_source_file(
    self, module_name: str, file_or_directory: str, dependencies=True
):
    """Add the local file or directory ``file_or_directory`` to the package as the code
    for ``module_name``.

    Args:
        module_name (str): e.g. ``"my_package.my_subpackage"``; the saved code provides this module.
        file_or_directory (str): path to a single source file or to a directory. For a
            directory, every ``*.py`` file beneath it is saved recursively, with module names
            derived from the file's path relative to the directory; files named
            ``__init__.py`` are saved as packages. A single file is saved as a package
            iff its name is ``__init__.py``.
        dependencies (bool, optional): If ``True``, we scan the source for dependencies.
    """
    path = Path(file_or_directory)
    if not path.is_dir():
        self.save_source_string(
            module_name,
            _read_file(file_or_directory),
            path.name == "__init__.py",
            dependencies,
        )
        return

    module_path = module_name.replace(".", "/")
    # (submodule_name, source, is_package) for every file in the tree. Sorted so
    # the order modules are recorded does not depend on directory-listing order.
    to_save = []
    for filename in sorted(path.glob("**/*.py")):
        archivename = module_path + "/" + filename.relative_to(path).as_posix()
        is_package = filename.name == "__init__.py"
        if is_package:
            submodule_name = archivename[: -len("/__init__.py")].replace("/", ".")
        else:
            submodule_name = archivename[: -len(".py")].replace("/", ".")
        to_save.append((submodule_name, _read_file(str(filename)), is_package))

    # Phase 1: record every module in the tree as provided *before* resolving any
    # dependencies. Otherwise, resolving an earlier file's imports could reach a
    # sibling that is not recorded yet; add_dependency would then pattern-match it
    # (and may tag it with an error) instead of seeing it is supplied right here.
    for submodule_name, src, is_package in to_save:
        self.save_source_string(submodule_name, src, is_package, dependencies=False)

    # Phase 2: scan each file for dependencies. Re-saving re-applies the same node
    # attributes recorded in phase 1, then resolves that file's dependencies.
    if dependencies:
        for submodule_name, src, is_package in to_save:
            self.save_source_string(submodule_name, src, is_package, dependencies=True)
def get_unique_id(self) -> str:
    """Return a fresh id string; each value is handed out at most once per exporter."""
    current = self._unique_id
    self._unique_id = current + 1
    return str(current)
def _get_dependencies(
    self, src: str, module_name: str, is_package: bool
) -> List[str]:
    """Return the modules that ``src`` directly depends on.

    Dependencies are found by scanning the source for import-like statements.

    Arguments:
        src: The Python source code to analyze.
        module_name: The name of the module that ``src`` belongs to.
        is_package: Whether this module should be treated as a package.
            See :py:meth:`save_source_string` for more info.

    Returns:
        The direct dependencies, without duplicates, in first-seen order. Only
        names that ``_module_exists`` accepts are included. If scanning raises,
        the module's graph node gets a DEPENDENCY_RESOLUTION_FAILED error and
        an empty list is returned.
    """
    # Relative imports resolve against the package: the module itself when it is
    # a package, otherwise the name with its last dotted component removed.
    if is_package:
        package_name = module_name
    else:
        package_name = module_name.rsplit(".", maxsplit=1)[0]

    try:
        dep_pairs = find_files_source_depends_on(src, package_name)
    except Exception as e:
        self.dependency_graph.add_node(
            module_name,
            error=PackagingErrorReason.DEPENDENCY_RESOLUTION_FAILED,
            error_context=str(e),
        )
        return []

    # dict keys: de-duplication that keeps first-seen order.
    found = {}
    for dep_module_name, dep_module_obj in dep_pairs:
        # For `from pack import sub`: if `sub` is itself a module, depend on
        # `pack.sub` only, so `pack` (and its dependencies) is not pulled in.
        # If `sub` is just an object, fall through and depend on `pack`.
        if dep_module_obj is not None:
            candidate = f"{dep_module_name}.{dep_module_obj}"
            if self._module_exists(candidate):
                found[candidate] = True
                continue
        if self._module_exists(dep_module_name):
            found[dep_module_name] = True
    return list(found)
def save_source_string(
    self,
    module_name: str,
    src: str,
    is_package: bool = False,
    dependencies: bool = True,
):
    """Record ``src`` as the source code for ``module_name`` in the exported package.

    Args:
        module_name (str): e.g. ``my_package.my_subpackage``; the saved code provides this module.
        src (str): The Python source code to save for this module.
        is_package (bool, optional): If ``True``, the module is treated as a package, which may
            have submodules (e.g. ``my_package.my_subpackage.my_subsubpackage``) and hold
            resources. Defaults to ``False``.
        dependencies (bool, optional): If ``True``, we scan the source for dependencies.
    """
    # The module is marked as provided and interned; the text itself is stored on
    # the graph node rather than written to the archive here.
    self.dependency_graph.add_node(
        module_name,
        source=src,
        is_package=is_package,
        provided=True,
        action=_ModuleProviderAction.INTERN,
    )
    if not dependencies:
        return
    for dep in self._get_dependencies(src, module_name, is_package):
        self.dependency_graph.add_edge(module_name, dep)
        self.add_dependency(dep)
def _write_source_string(
    self,
    module_name: str,
    src: str,
    is_package: bool = False,
):
    """Write ``src`` into the zip archive as the code for ``module_name``.

    Packages are stored as ``a/b/__init__.py``, plain modules as ``a/b.py``.
    Arguments are otherwise the same as for :meth:`save_source_string`.
    """
    stem = module_name.replace(".", "/")
    filename = f"{stem}/__init__.py" if is_package else f"{stem}.py"
    self._write(filename, src)
def _import_module(self, module_name: str):
    """Import ``module_name`` through ``self.importer`` and return the module.

    Raises:
        ModuleNotFoundError: re-raised unchanged for ordinary names. For mangled
            names the error is replaced (with its context suppressed) by one
            explaining that the originating PackageImporter must be in
            ``self.importer``.
    """
    try:
        return self.importer.import_module(module_name)
    except ModuleNotFoundError:
        if is_mangled(module_name):
            raise ModuleNotFoundError(
                f"Module not found: '{module_name}'. Make sure the PackageImporter that "
                "created this module is present in `self.importer`"
            ) from None
        raise
def _module_exists(self, module_name: str) -> bool:
    """Report whether :meth:`_import_module` can import ``module_name``.

    This is a best-effort probe: any exception raised by the import counts as
    "does not exist".
    """
    try:
        self._import_module(module_name)
    except Exception:
        return False
    return True
def _get_source_of_module(self, module: types.ModuleType) -> Optional[str]:
    """Return the Python source text of ``module``, or None if it has no ``.py`` file.

    The filename comes from the module spec's ``SourceFileLoader`` when there is
    one, otherwise from ``module.__file__``. The text is read through
    ``linecache`` using the module's globals.
    """
    filename = None
    loader = getattr(getattr(module, "__spec__", None), "loader", None)
    if isinstance(loader, SourceFileLoader):
        try:
            filename = loader.get_filename(module.__name__)
        except ImportError:
            pass
    if filename is None:
        filename = getattr(module, "__file__", None)
    # Only plain .py files are readable as source (not extension modules etc.).
    if not (isinstance(filename, str) and filename.endswith(".py")):
        return None
    return "".join(linecache.getlines(filename, module.__dict__))
def add_dependency(self, module_name: str, dependencies=True):
    """Add ``module_name`` to the dependency graph according to the user's patterns.

    Modules already marked ``provided`` are left alone. Two internal special
    names and implicitly-externable modules are handled without consulting
    patterns. Otherwise the first pattern in ``self.patterns`` that matches
    decides the action; if none matches, a NO_ACTION error is recorded.
    """
    graph = self.dependency_graph
    if module_name in graph and graph.nodes[module_name].get("provided") is True:
        return

    # `torch_package_importer` is the module PackageImporter injects so packaged
    # code can reference its importer, and `_mock` may come from a previously
    # mocked package. Neither is re-exported as-is.
    special_actions = {
        "torch_package_importer": _ModuleProviderAction.SKIP,
        "_mock": _ModuleProviderAction.REPACKAGED_MOCK_MODULE,
    }
    if module_name in special_actions:
        graph.add_node(module_name, action=special_actions[module_name], provided=True)
        return

    if self._can_implicitly_extern(module_name):
        graph.add_node(module_name, action=_ModuleProviderAction.EXTERN, provided=True)
        return

    # First matching pattern, in the insertion order of self.patterns, wins.
    match = next(
        (info for pattern, info in self.patterns.items() if pattern.matches(module_name)),
        None,
    )
    if match is None:
        # No patterns have matched. Explicitly add this as an error.
        graph.add_node(module_name, error=PackagingErrorReason.NO_ACTION)
        return

    match.was_matched = True
    graph.add_node(module_name, action=match.action, provided=True)
    if match.action == _ModuleProviderAction.DENY:
        # Requiring a denied module just adds an error to the graph.
        graph.add_node(module_name, error=PackagingErrorReason.DENIED)
    elif match.action == _ModuleProviderAction.INTERN:
        # Interned modules bring their own dependencies along.
        self._intern_module(module_name, dependencies)
def save_module(self, module_name: str, dependencies=True):
    """Save the code for ``module_name`` into the package.

    The module object is found through the configured importers, and its source
    is located from the module object itself (see ``_get_source_of_module``).

    Args:
        module_name (str): e.g. ``my_package.my_subpackage``; code will be saved to provide
            this module.
        dependencies (bool, optional): If ``True``, we scan the source for dependencies.

    Raises:
        TypeError: if ``module_name`` is not a string (e.g. a module object was passed).
    """
    if isinstance(module_name, str):
        self._intern_module(module_name, dependencies)
        return
    raise TypeError(
        "save_module() expects a string input, did you perhaps mean to pass `__name__`?"
    )
def _intern_module(
    self,
    module_name: str,
    dependencies: bool,
):
    """Add the module to the dependency graph as an interned module.

    Also records the metadata needed to write it to the zipfile at serialization
    time. If no source can be found, the node is still added but carries an
    ``error`` describing why (no ``__file__``, extension module, or missing file).
    """
    module_obj = self._import_module(module_name)
    # If the import succeeded, the name was either not mangled, or mangled and
    # recognised by one of the importers. Either way it is safe to demangle now
    # so the mangled form never ends up in the package.
    module_name = demangle(module_name)

    is_package = hasattr(module_obj, "__path__")
    source = self._get_source_of_module(module_obj)

    if source is not None:
        self.dependency_graph.add_node(
            module_name,
            action=_ModuleProviderAction.INTERN,
            is_package=is_package,
            source=source,
            provided=True,
        )
        if dependencies:
            for dep in self._get_dependencies(source, module_name, is_package):
                self.dependency_graph.add_edge(module_name, dep)
                self.add_dependency(dep)
        return

    # No source: classify the failure and record the module as broken.
    filename = getattr(module_obj, "__file__", None)
    error_context = None
    if filename is None:
        packaging_error = PackagingErrorReason.NO_DUNDER_FILE
    elif filename.endswith(tuple(importlib.machinery.EXTENSION_SUFFIXES)):
        packaging_error = PackagingErrorReason.IS_EXTENSION_MODULE
    else:
        packaging_error = PackagingErrorReason.SOURCE_FILE_NOT_FOUND
        error_context = f"filename: {filename}"
    self.dependency_graph.add_node(
        module_name,
        action=_ModuleProviderAction.INTERN,
        is_package=is_package,
        error=packaging_error,
        error_context=error_context,
        provided=True,
    )
def save_pickle(
    self,
    package: str,
    resource: str,
    obj: Any,
    dependencies: bool = True,
    pickle_protocol: int = 3,
):
    """Save a python object to the archive using pickle.

    Equivalent to :func:`torch.save` but saving into the archive rather than a
    stand-alone file. Standard pickle does not save the code, only the objects.
    If ``dependencies`` is true, this method also scans the pickled data for the
    modules needed to reconstruct the objects and saves the relevant code.

    To save an object where ``type(obj).__name__`` is ``my_module.MyObject``,
    ``my_module.MyObject`` must resolve to the object's class according to the
    ``importer`` order. When saving objects that were previously packaged, the
    importer's ``import_module`` method must be present in the ``importer`` list.

    Args:
        package (str): The name of module package this resource should go in (e.g. ``"my_package.my_subpackage"``).
        resource (str): A unique name for the resource, used to identify it to load.
        obj (Any): The object to save, must be picklable.
        dependencies (bool, optional): If ``True``, we scan the source for dependencies.
        pickle_protocol (int, optional): Pickle protocol to use; only 3 (default) and 4 are supported.
    """
    assert pickle_protocol in (3, 4), "torch.package only supports pickle protocols 3 and 4"

    filename = self._filename(package, resource)
    # Write the pickle data for `obj`
    data_buf = io.BytesIO()
    pickler = create_pickler(data_buf, self.importer, protocol=pickle_protocol)
    pickler.persistent_id = self._persistent_id
    pickler.dump(obj)
    data_value = data_buf.getvalue()
    # module name -> names of objects from that (mocked) module found in the pickle
    mocked_modules = defaultdict(list)

    name_in_dependency_graph = f"<{package}.{resource}>"
    self.dependency_graph.add_node(
        name_in_dependency_graph,
        action=_ModuleProviderAction.INTERN,
        provided=True,
        is_pickle=True,
    )

    def _check_mocked_error(module: Optional[str], field: Optional[str]):
        """Record ``field`` under ``module`` in ``mocked_modules`` if ``module`` is mocked.

        Implicitly-externable modules are ignored. Otherwise the first
        user-defined pattern that matches the module decides; only a MOCK
        action records the field.
        """
        assert isinstance(module, str)
        assert isinstance(field, str)
        if self._can_implicitly_extern(module):
            return
        for pattern, pattern_info in self.patterns.items():
            if pattern.matches(module):
                if pattern_info.action == _ModuleProviderAction.MOCK:
                    mocked_modules[module].append(field)
                return

    if dependencies:
        all_dependencies = []
        module = None
        field = None
        # Approximate memo table: index -> string, so that BINGET-style opcodes
        # can be mapped back to the module/field names they refer to.
        memo: Dict[int, str] = {}
        memo_count = 0
        for opcode, arg, pos in pickletools.genops(data_value):
            if pickle_protocol == 4:
                if opcode.name == "SHORT_BINUNICODE" or opcode.name == "BINUNICODE8":
                    assert isinstance(arg, str)
                    module = field
                    field = arg
                    memo[memo_count] = arg
                elif (
                    opcode.name == "LONG_BINGET"
                    or opcode.name == "BINGET"
                    or opcode.name == "GET"
                ):
                    assert isinstance(arg, int)
                    module = field
                    field = memo.get(arg, None)
                elif opcode.name == "MEMOIZE":
                    memo_count += 1
                elif opcode.name == "STACK_GLOBAL":
                    if module is None:
                        # No module name preceded this global; nothing to record.
                        continue
                    assert isinstance(module, str)
                    if module not in all_dependencies:
                        all_dependencies.append(module)
                    _check_mocked_error(module, field)
            elif pickle_protocol == 3 and opcode.name == "GLOBAL":  # a global reference
                assert isinstance(arg, str)
                module, field = arg.split(" ")
                if module not in all_dependencies:
                    all_dependencies.append(module)
                _check_mocked_error(module, field)

        for module_name in all_dependencies:
            self.dependency_graph.add_edge(name_in_dependency_graph, module_name)

            # If an object comes from a mocked module, record an error instead of
            # a dependency; it is reported with the exporter's other errors.
            # Checked per dependency (module_name), not via the leftover `module`
            # from the opcode scan, which only holds the last global seen.
            if module_name in mocked_modules:
                fields = mocked_modules[module_name]
                self.dependency_graph.add_node(
                    module_name,
                    action=_ModuleProviderAction.MOCK,
                    error=PackagingErrorReason.MOCKED_BUT_STILL_USED,
                    error_context=f"Object(s) '{fields}' from module `{module_name}` was mocked out during packaging "
                    f"but is being used in resource - `{resource}` in package `{package}`. ",
                    provided=True,
                )
            else:
                self.add_dependency(module_name)

    self._write(filename, data_value)
def save_text(self, package: str, resource: str, text: str):
    """Save text to the package, UTF-8 encoded.

    Args:
        package (str): The name of module package this resource should go in (e.g. ``"my_package.my_subpackage"``).
        resource (str): A unique name for the resource, used to identify it to load.
        text (str): The contents to save.

    Returns:
        Whatever :meth:`save_binary` returns.
    """
    encoded = text.encode("utf-8")
    return self.save_binary(package, resource, encoded)
def save_binary(self, package, resource, binary: bytes):
    """Save raw bytes to the package.

    Args:
        package (str): The name of module package this resource should go in (e.g. ``"my_package.my_subpackage"``).
        resource (str): A unique name for the resource, used to identify it to load.
        binary (bytes): The data to save.
    """
    self._write(self._filename(package, resource), binary)
def register_extern_hook(self, hook: ActionHook) -> RemovableHandle:
    """Register a hook to run whenever a module matches an :meth:`extern` pattern.

    Expected signature::

        hook(exporter: PackageExporter, module_name: str) -> None

    Hooks run in registration order.

    Returns:
        :class:`torch.utils.hooks.RemovableHandle`:
            Call ``handle.remove()`` to unregister the hook.
    """
    hooks = self._extern_hooks
    handle = RemovableHandle(hooks)
    hooks[handle.id] = hook
    return handle
def register_mock_hook(self, hook: ActionHook) -> RemovableHandle:
    """Register a hook to run whenever a module matches a :meth:`mock` pattern.

    Expected signature::

        hook(exporter: PackageExporter, module_name: str) -> None

    Hooks run in registration order.

    Returns:
        :class:`torch.utils.hooks.RemovableHandle`:
            Call ``handle.remove()`` to unregister the hook.
    """
    hooks = self._mock_hooks
    handle = RemovableHandle(hooks)
    hooks[handle.id] = hook
    return handle
def register_intern_hook(self, hook: ActionHook) -> RemovableHandle:
    """Register a hook to run whenever a module matches an :meth:`intern` pattern.

    Expected signature::

        hook(exporter: PackageExporter, module_name: str) -> None

    Hooks run in registration order.

    Returns:
        :class:`torch.utils.hooks.RemovableHandle`:
            Call ``handle.remove()`` to unregister the hook.
    """
    hooks = self._intern_hooks
    handle = RemovableHandle(hooks)
    hooks[handle.id] = hook
    return handle
def intern(
    self,
    include: "GlobPattern",
    *,
    exclude: "GlobPattern" = (),
    allow_empty: bool = True,
):
    """Specify modules that should be packaged.

    A module must match some ``intern`` pattern to be included in the package
    and have its dependencies processed recursively.

    Args:
        include (Union[List[str], str]): A string e.g. ``"my_package.my_subpackage"``, or list of strings
            naming the modules to be interned. Glob-style patterns are accepted, as described in :meth:`mock`.
        exclude (Union[List[str], str]): An optional pattern that excludes some modules matched by ``include``.
        allow_empty (bool): Whether the modules named by this call may go unmatched during packaging.
            If the pattern is added with ``allow_empty=False`` and :meth:`close` is called (explicitly or via
            ``__exit__``) before any module matches it, an exception is thrown. With ``allow_empty=True``,
            no such exception is thrown.
    """
    info = _PatternInfo(_ModuleProviderAction.INTERN, allow_empty)
    self.patterns[GlobGroup(include, exclude=exclude)] = info
def mock(
    self,
    include: "GlobPattern",
    *,
    exclude: "GlobPattern" = (),
    allow_empty: bool = True,
):
    """Replace some required modules with a mock implementation. Mocked modules will return a fake
    object for any attribute accessed from them. Because we copy file-by-file, the dependency resolution
    will sometimes find files that are imported by model files but whose functionality is never used
    (e.g. custom serialization code or training helpers).
    Use this function to mock this functionality out without having to modify the original code.

    Args:
        include (Union[List[str], str]): A string e.g. ``"my_package.my_subpackage"``, or list of strings
            for the names of the modules to be mocked out. Strings can also be a glob-style pattern
            string that may match multiple modules. Any required dependencies that match this pattern
            string will be mocked out automatically.

            Examples:
                ``'torch.**'`` -- matches ``torch`` and all submodules of torch, e.g. ``'torch.nn'``
                and ``'torch.nn.functional'``

                ``'torch.*'`` -- matches ``'torch.nn'`` or ``'torch.functional'``, but not
                ``'torch.nn.functional'``

        exclude (Union[List[str], str]): An optional pattern that excludes some patterns that match the
            include string. e.g. ``include='torch.**', exclude='torch.foo'`` will mock all torch packages
            except ``'torch.foo'``. Default: ``()`` (an empty pattern).
        allow_empty (bool): An optional flag that specifies whether the mock implementation(s) specified by
            this call to the :meth:`mock` method must be matched to some module during packaging. If a mock
            is added with ``allow_empty=False``, and :meth:`close` is called (either explicitly or via
            ``__exit__``) and the mock has not been matched to a module used by the package being exported,
            an exception is thrown. If ``allow_empty=True``, no such exception is thrown.
    """
    self.patterns[GlobGroup(include, exclude=exclude)] = _PatternInfo(
        _ModuleProviderAction.MOCK, allow_empty
    )
def extern(
    self,
    include: "GlobPattern",
    *,
    exclude: "GlobPattern" = (),
    allow_empty: bool = True,
):
    """Mark modules matching ``include`` as external to the package.

    Dependency discovery does not save externed modules into the package.
    When the package is loaded, the importer takes them straight from the
    standard import system. Their code must therefore also be available in
    the process that loads the package.

    Args:
        include (Union[List[str], str]): A module name such as
            ``"my_package.my_subpackage"``, or a list of names, of the modules
            to extern. Glob-style patterns are accepted, as described in :meth:`mock`.
        exclude (Union[List[str], str]): An optional pattern that removes some of
            the modules ``include`` would otherwise match.
        allow_empty (bool): When ``False``, :meth:`close` (called explicitly or
            via ``__exit__``) throws if no module has matched this pattern.
            When ``True`` (the default), an unmatched pattern is accepted.
    """
    extern_info = _PatternInfo(_ModuleProviderAction.EXTERN, allow_empty)
    extern_group = GlobGroup(include, exclude=exclude)
    self.patterns[extern_group] = extern_info
def deny(self, include: "GlobPattern", *, exclude: "GlobPattern" = ()):
    """Blocklist modules whose names match the given glob patterns from the list of modules the package can import.

    If a dependency on any matching packages is found, a :class:`PackagingError` is raised.

    Args:
        include (Union[List[str], str]): A string e.g. ``"my_package.my_subpackage"``, or list of strings
            for the names of the modules to be denied. This can also be a glob-style pattern, as described
            in :meth:`mock`.
        exclude (Union[List[str], str]): An optional pattern that excludes some patterns that match the
            include string.
    """
    # Deny patterns are never required to match anything, hence allow_empty=True.
    self.patterns[GlobGroup(include, exclude=exclude)] = _PatternInfo(
        _ModuleProviderAction.DENY, allow_empty=True
    )
def _persistent_id(self, obj):
    """Return the pickle persistent id for ``obj``, or ``None`` to pickle it normally.

    Storages (typed or untyped) have their bytes written once to
    ``.data/<storage_id>.storage`` in the archive. They are referenced by a
    ``("storage", storage_type, storage_id, location, storage_numel)`` tuple.

    An object defining ``__reduce_package__`` has that method called at most
    once per ``id(obj)``. The resulting ``("reduce_package", id(obj), ...)``
    tuple is cached in ``self.serialized_reduces`` and returned on later calls.

    Raises:
        RuntimeError: if ``obj`` passes ``torch.is_storage`` but is neither a
            ``TypedStorage`` nor an ``UntypedStorage``.
        Exception: if ``obj`` defines ``__reduce_package__``, is a
            ``torch.jit.RecursiveScriptModule``, and the module-level
            ``_gate_torchscript_serialization`` flag is truthy.
    """
    if torch.is_storage(obj) or isinstance(obj, torch.storage.TypedStorage):
        storage: Storage
        if isinstance(obj, torch.storage.TypedStorage):
            # TODO: Once we decide to break serialization FC, we can
            # remove this case
            untyped_storage = obj._untyped_storage
            storage_type_str = obj.pickle_storage_type()
            storage_type = getattr(torch, storage_type_str)
            storage = cast(Storage, untyped_storage)
            # Typed storages record obj.size() as the size field.
            storage_numel = obj.size()
        elif isinstance(obj, torch.UntypedStorage):
            untyped_storage = obj
            storage = cast(Storage, untyped_storage)
            storage_type = normalize_storage_type(type(storage))
            # Untyped storages record their byte count as the size field.
            storage_numel = storage.nbytes()
        else:
            raise RuntimeError(f"storage type not recognized: {type(obj)}")
        location = location_tag(storage)
        # serialize storage if not already written
        # Presence must be queried before get_or_add_storage registers it.
        storage_present = self.storage_context.has_storage(storage)
        storage_id = self.storage_context.get_or_add_storage(storage)
        if not storage_present:
            # Non-CPU storages are copied to CPU before their bytes are written.
            if storage.device.type != "cpu":
                storage = storage.cpu()
            num_bytes = storage.nbytes()
            self.zip_file.write_record(
                f".data/{storage_id}.storage", storage.data_ptr(), num_bytes
            )
        return ("storage", storage_type, storage_id, location, storage_numel)
    if hasattr(obj, "__reduce_package__"):
        if _gate_torchscript_serialization and isinstance(
            obj, torch.jit.RecursiveScriptModule
        ):
            raise Exception(
                "Serializing ScriptModules directly into a package is a beta feature. "
                "To use, set global "
                "`torch.package.package_exporter._gate_torchscript_serialization` to `False`."
            )
        # Reduce each object only once; repeated references reuse the cached tuple.
        if self.serialized_reduces.get(id(obj)) is None:
            self.serialized_reduces[id(obj)] = (
                "reduce_package",
                id(obj),
                *obj.__reduce_package__(self),
            )
        return self.serialized_reduces[id(obj)]
    # Anything else is pickled by value.
    return None
def __enter__(self):
    """Enter a ``with`` block; the exporter itself is bound by ``as``."""
    exporter = self
    return exporter
def __exit__(self, exc_type, exc_value, traceback):
    """Finish packaging when the ``with`` block exits.

    On a clean exit the whole package is written via :meth:`close`. If the
    block raised, only the zip is finalized, so the open buffer is left in a
    valid state. Returning ``None`` lets the caller's exception keep
    propagating.
    """
    if exc_type is None:
        self.close()
        return
    # An exception is in flight: skip packaging entirely and do the bare
    # minimum to leave the buffer valid.
    self._finalize_zip()
def _write(self, filename, str_or_bytes):
    """Write one record named ``filename`` into the archive.

    ``str`` payloads are encoded as UTF-8 before writing. Any other payload
    is passed to ``write_record`` unchanged, together with its ``len()``.

    Raises:
        AssertionError: if ``filename`` was already written to this archive,
            or if it is a mangled name (a ``torch.package``'d module).
    """
    if filename in self._written_files:
        raise AssertionError(
            f"Tried to write file '{filename}', but it already exists in this archive. "
            "Please file a bug."
        )
    if is_mangled(filename):
        raise AssertionError(
            f"Tried to save a torch.package'd module as '{filename}'. "
            "Directly saving torch.package'd modules is not allowed."
        )
    # Record the name only once the write is known to be legal. Otherwise a
    # rejected write would make a later attempt fail with a misleading
    # "already exists" error.
    self._written_files.add(filename)
    if isinstance(str_or_bytes, str):
        str_or_bytes = str_or_bytes.encode("utf-8")
    self.zip_file.write_record(filename, str_or_bytes, len(str_or_bytes))
def _validate_dependency_graph(self):
    """Raise if the dependency graph cannot be packaged as-is.

    Raises:
        PackagingError: if any node carries an ``"error"`` attribute recorded
            during dependency analysis. The error is built from the whole
            graph, so it is raised once no matter how many nodes failed.
        EmptyMatchError: if a pattern registered with ``allow_empty=False``
            was never matched.
    """
    # 1. Check the graph for any errors inserted during dependency analysis.
    if any("error" in attrs for attrs in self.dependency_graph.nodes.values()):
        raise PackagingError(self.dependency_graph, debug=self.debug)
    # 2. Check that all patterns for which allow_empty=False have been matched at least once.
    for pattern, pattern_info in self.patterns.items():
        if not pattern_info.allow_empty and not pattern_info.was_matched:
            raise EmptyMatchError(
                f"Exporter did not match any modules to {pattern}, which was marked as allow_empty=False"
            )
def _write_mock_file(self):
    """Add the ``_mock`` support module to the archive unless it is already there."""
    if "_mock.py" in self._written_files:
        return
    mock_path = str(Path(__file__).parent / "_mock.py")
    mock_source = _read_file(mock_path)
    self._write_source_string("_mock", mock_source, is_package=False)
def _execute_dependency_graph(self):
    """Takes a finalized dependency graph describing how to package all
    modules and executes it, writing to the ZIP archive.

    The graph is validated first, so nothing is written for an invalid graph.
    Each node's ``"action"`` attribute then selects what happens:

    * EXTERN: extern hooks run, and the module name is collected into
      ``.data/extern_modules`` (one name per line, newline-terminated).
    * MOCK: mock hooks run, the ``_mock`` support file is written (only once),
      and the module's source is replaced by ``_MOCK_IMPL``.
    * INTERN: intern hooks run, then the node's recorded ``"source"`` is
      written. Nodes produced by ``save_pickle`` (``is_pickle``) are skipped.
    * REPACKAGED_MOCK_MODULE: only the ``_mock`` support file is written.
    * SKIP: nothing is written.

    Raises:
        PackagingError, EmptyMatchError: propagated from
            :meth:`_validate_dependency_graph`.
        AssertionError: if an INTERN node lacks a ``"provided"`` attribute, or
            a node carries an unrecognized action.
    """
    self._validate_dependency_graph()
    extern_modules = []
    for module_name, attrs in self.dependency_graph.nodes.items():
        action = attrs["action"]
        if action == _ModuleProviderAction.EXTERN:
            for hook in self._extern_hooks.values():
                hook(self, module_name)
            extern_modules.append(module_name)
        elif action == _ModuleProviderAction.MOCK:
            for hook in self._mock_hooks.values():
                hook(self, module_name)
            self._write_mock_file()
            # The mock keeps the package-ness (presence of __path__) of the
            # module returned by _import_module (presumably the real module).
            is_package = hasattr(self._import_module(module_name), "__path__")
            self._write_source_string(module_name, _MOCK_IMPL, is_package)
        elif action == _ModuleProviderAction.INTERN:
            for hook in self._intern_hooks.values():
                hook(self, module_name)
            # The node in the dependency graph contains metadata that tells us
            # how to intern the module.
            if "provided" not in attrs:
                raise AssertionError(
                    f"Module was marked `intern` but not provided: {module_name}"
                )
            if attrs.get("is_pickle") is True:
                # This node came from save_pickle, we don't need to write any source for it.
                continue
            is_package = attrs["is_package"]
            source = attrs["source"]
            self._write_source_string(module_name, source, is_package)
        elif action == _ModuleProviderAction.REPACKAGED_MOCK_MODULE:
            self._write_mock_file()
        elif action == _ModuleProviderAction.SKIP:
            continue
        else:
            raise AssertionError(
                f"Invalid action: {module_name}, {action}. Please report a bug to PyTorch."
            )
    # Always written, even when empty (then it contains a single newline).
    extern_file_contents = "\n".join(extern_modules) + "\n"
    self._write(".data/extern_modules", extern_file_contents)
def _write_python_version(self):
    """Record the interpreter version that built the package in ``.data/python_version``."""
    version = platform.python_version()
    self._write(".data/python_version", version)
def close(self):
    """Write the package to the filesystem. Any calls after :meth:`close` are now invalid.

    Prefer the context-manager form, which calls this automatically::

        with PackageExporter("file.zip") as e:
            ...
    """
    # Module sources and metadata are written first...
    self._execute_dependency_graph()
    self._write_python_version()
    # ...then TorchScript files...
    self.script_module_serializer.write_files()
    # ...and finally the archive is closed.
    self._finalize_zip()
def _finalize_zip(self):
    """Leave the zipfile closed but valid; called at the very end of packaging."""
    # Releasing the writer is what closes the archive here (presumably its
    # destructor finalizes the zip — confirm against the writer type).
    del self.zip_file
    buffer = self.buffer
    if buffer:
        buffer.flush()
def _filename(self, package, resource):
    """Return the archive path of ``resource`` inside the dotted ``package``."""
    package_dir = package.replace(".", "/")
    return f"{package_dir}/{_normalize_path(resource)}"
def _can_implicitly_extern(self, module_name: str):
    """Return whether ``module_name`` may be externed without an explicit pattern.

    Anything under the top-level ``torch`` package qualifies. So does a
    standard-library top-level module, unless it is listed in
    ``_DISALLOWED_MODULES``.
    """
    top_level, _, _ = module_name.partition(".")
    if top_level == "torch":
        return True
    return top_level not in _DISALLOWED_MODULES and is_stdlib_module(top_level)
def dependency_graph_string(self) -> str:
    """Render the package's dependency graph as a DOT ``digraph`` string.

    Returns:
        The dependency graph of the package, as produced by its ``to_dot()``.
    """
    graph = self.dependency_graph
    return graph.to_dot()
def _nodes_with_action_type(
    self, action: Optional[_ModuleProviderAction]
) -> List[str]:
    """Return the sorted names of graph nodes whose ``"action"`` equals ``action``.

    A node with no ``"action"`` attribute is treated as having ``None``.
    Nodes recorded with an ``"is_pickle"`` attribute are left out.
    """
    return sorted(
        name
        for name, node_dict in self.dependency_graph.nodes.items()
        if node_dict.get("action", None) == action and "is_pickle" not in node_dict
    )
def externed_modules(self) -> List[str]:
    """List the modules this package will extern.

    Returns:
        The sorted names of modules whose action is ``EXTERN``.
    """
    extern_action = _ModuleProviderAction.EXTERN
    return self._nodes_with_action_type(extern_action)
def interned_modules(self) -> List[str]:
    """List the modules this package will intern.

    Returns:
        The sorted names of modules whose action is ``INTERN``.
    """
    intern_action = _ModuleProviderAction.INTERN
    return self._nodes_with_action_type(intern_action)
def mocked_modules(self) -> List[str]:
    """List the modules this package will mock.

    Returns:
        The sorted names of modules whose action is ``MOCK``.
    """
    mock_action = _ModuleProviderAction.MOCK
    return self._nodes_with_action_type(mock_action)
def denied_modules(self) -> List[str]:
    """List the modules this package denies.

    Returns:
        The sorted names of modules whose action is ``DENY``.
    """
    deny_action = _ModuleProviderAction.DENY
    return self._nodes_with_action_type(deny_action)
def get_rdeps(self, module_name: str) -> List[str]:
    """Return the modules that depend on ``module_name``.

    Returns:
        The names of ``module_name``'s predecessors in the dependency graph,
        or an empty list if the graph has no predecessor entry for it.
    """
    predecessors = self.dependency_graph._pred
    if module_name not in predecessors.keys():
        return []
    return list(predecessors[module_name].keys())
def all_paths(self, src: str, dst: str) -> str:
    """Describe, in DOT, the subgraph made of every path from ``src`` to ``dst``.

    Returns:
        A DOT string (https://graphviz.org/doc/info/lang.html) for the subgraph
        containing all paths from src to dst.
    """
    graph = self.dependency_graph
    return graph.all_paths(src, dst)
# even though these are in the standard library, we do not allow them to be
# automatically externed since they offer a lot of system level access
# (_can_implicitly_extern checks a module's top-level package name against this list).
_DISALLOWED_MODULES = ["sys", "io"]
# Source written in place of every mocked module. Its module-level
# ``__getattr__`` answers any attribute lookup on the mocked module with a
# ``MockedObject`` named after the fully-qualified attribute. The function
# body must be indented, or the generated module is not valid Python.
_MOCK_IMPL = """\
from _mock import MockedObject
def __getattr__(attr: str):
    return MockedObject(__name__ + '.' + attr, _suppress_err=True)
"""
def _read_file(filename: str) -> str:
    """Return the contents of ``filename``, read as raw bytes and decoded as UTF-8."""
    with open(filename, "rb") as handle:
        return handle.read().decode("utf-8")
|