123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783 |
- import torch
- import torch.nn as nn
- import torch.overrides
- from torch.nn.modules.module import _addindent
- from torch.package import PackageImporter, PackageExporter
- import linecache
- from typing import Type, Dict, List, Any, Union, Optional, Set
- from .graph import Graph, _PyTreeCodeGen, _is_from_torch, _custom_builtins, PythonCode
- from ._compatibility import compatibility
- from torch.package import Importer, sys_importer
- import copy
- import itertools
- import sys
- import traceback
- from pathlib import Path
- import os
- import warnings
- __all__ = ["reduce_graph_module", "reduce_package_graph_module", "reduce_deploy_graph_module", "GraphModule"]
- _USER_PRESERVED_ATTRIBUTES_KEY = "_user_preserved_attributes"
- # Normal exec loses the source code, however we can work with
- # the linecache module to recover it.
- # Using _exec_with_source will add it to our local cache
- # and then tools like TorchScript will be able to get source info.
- class _EvalCacheLoader:
- def __init__(self):
- self.eval_cache = {}
- self.next_id = 0
- def cache(self, src: str, globals: Dict[str, Any]):
- """Store the source in a private cache, and add a lazy entry in linecache
- that allows the source to be retrieved by 'filename'.
- Args:
- src (str): The module source to cache
- globals (dict): The module globals
- Returns:
- str: The cache key (and dummy filename) generated for src.
- """
- key = self._get_key()
- self.eval_cache[key] = src
- # Don't mutate globals so that this loader is only used
- # to populate linecache, and doesn't interact with other modules
- # that might check `__loader__`
- globals_copy = globals.copy()
- globals_copy['__file__'] = key
- globals_copy['__name__'] = key
- globals_copy['__loader__'] = self
- linecache.lazycache(key, globals_copy)
- return key
- # Part of the loader protocol (PEP 302)
- # linecache will use this method when trying to find source code
- def get_source(self, module_name) -> Optional[str]:
- if module_name in self.eval_cache:
- return self.eval_cache[module_name]
- return None
- def _get_key(self):
- key = f'<eval_with_key>.{self.next_id}'
- self.next_id += 1
- return key
- _loader = _EvalCacheLoader()
- def _exec_with_source(src: str, globals: Dict[str, Any]):
- key = _loader.cache(src, globals)
- exec(compile(src, key, 'exec'), globals)
- def _forward_from_src(src: str, globals: Dict[str, Any]):
- # avoid mutating the passed in dict
- globals_copy = globals.copy()
- _exec_with_source(src, globals_copy)
- forward_fn = globals_copy['forward']
- del globals_copy['forward']
- return forward_fn
- def _format_import_statement(name: str, obj: Any, importer: Importer) -> str:
- if name in _custom_builtins:
- return _custom_builtins[name].import_str
- if _is_from_torch(name):
- return 'import torch'
- module_name, attr_name = importer.get_name(obj)
- return f'from {module_name} import {attr_name} as {name}'
- def _format_import_block(globals: Dict[str, Any], importer: Importer):
- import_strs: Set[str] = set()
- for name, obj in globals.items():
- import_strs.add(_format_import_statement(name, obj, importer))
- return '\n'.join(import_strs)
- @compatibility(is_backward_compatible=True)
- def reduce_graph_module(body: Dict[Any, Any], import_block: str) -> torch.nn.Module:
- # BC: attribute name was changed from `code` to `_code` to facilitate
- # making `code` into a property and adding a docstring to it
- fn_src = body.get('_code') or body['code']
- forward = _forward_from_src(import_block + fn_src, {})
- return _deserialize_graph_module(forward, body)
- @compatibility(is_backward_compatible=True)
- def reduce_package_graph_module(
- importer: PackageImporter, body: Dict[Any, Any], generated_module_name: str
- ) -> torch.nn.Module:
- forward = importer.import_module(generated_module_name).forward
- return _deserialize_graph_module(forward, body)
- @compatibility(is_backward_compatible=True)
- def reduce_deploy_graph_module(
- importer: PackageImporter, body: Dict[Any, Any], import_block: str
- ) -> torch.nn.Module:
- ns = {}
- ns["__builtins__"] = importer.patched_builtins
- fn_src = body.get('_code')
- assert fn_src is not None
- forward = _forward_from_src(import_block + fn_src, ns)
- return _deserialize_graph_module(forward, body)
- def _deserialize_graph_module(forward, body: Dict[Any, Any]) -> torch.nn.Module:
- """
- Deserialize a GraphModule given the dictionary of the original module,
- using the code to reconstruct the graph. We delete the actual graph before
- saving the dictionary so that changes to the in-memory graph format do not
- get serialized.
- """
- # We create a dummy class here because symbolic_trace pulls the forward()
- # function off of the class, rather than the instance
- class CodeOnlyModule(torch.nn.Module):
- def __init__(self, body):
- super().__init__()
- self.__dict__ = body
- # Try to retrieve the forward source in a backward-compatible way
- CodeOnlyModule.forward = forward
- tracer_cls = body.get('_tracer_cls')
- if tracer_cls is None:
- from ._symbolic_trace import Tracer
- tracer_cls = Tracer
- graphmodule_cls_name = body.get('_graphmodule_cls_name', 'GraphModule')
- # This is a workaround for a mypy linter issue related to
- # passing base class as an argument - https://github.com/python/mypy/issues/5865.
- cls_tracer : Any = tracer_cls
- class KeepModules(cls_tracer):
- # we shouldn't trace into any of the submodules,
- # because they were not traced in the original GraphModule
- def is_leaf_module(self, _: torch.nn.Module, __: str) -> bool:
- return True
- com = CodeOnlyModule(body)
- tracer_extras = body.get('_tracer_extras', {})
- graph = KeepModules().trace(com, **tracer_extras)
- # Manually set Tracer class on the reconstructed Graph, to avoid
- # referencing the private local subclass KeepModules.
- graph._tracer_cls = tracer_cls
- gm = GraphModule(com, graph, class_name=graphmodule_cls_name)
- # The GraphModule constructor only retains attributes referenced by the graph.
- # In this case, our goal is return a GraphModule as close to identical as the one
- # put into the package. If any additional attributes were present in body,
- # we should keep them.
- for k, v in body.items():
- if not hasattr(gm, k):
- setattr(gm, k, v)
- return gm
- # copy an attribute value with qualified name 'target' from 'from_module' to 'to_module'
- # This installs empty Modules where none exist yet if they are subpaths of target
- def _copy_attr(from_module: torch.nn.Module, to_module: torch.nn.Module, target: str):
- *prefix, field = target.split('.')
- for item in prefix:
- f = getattr(from_module, item)
- t = getattr(to_module, item, None)
- if f is t:
- # we have already installed one of its parents
- # (e.g. target = root.linear.weight, but we have already installed root.linear)
- # once we install a parent, we no longer need to copy the children
- # since all the needed properties will already be present
- return
- if t is None:
- t = torch.nn.Module()
- setattr(to_module, item, t)
- from_module, to_module = f, t
- orig = getattr(from_module, field)
- # If it is a tensor and not a parameter attribute of a module, it should be a named buffer.
- # So, we register it as a named buffer in the target module.
- if isinstance(orig, torch.Tensor) and not isinstance(orig, torch.nn.Parameter):
- to_module.register_buffer(field, orig)
- else:
- setattr(to_module, field, orig)
- # Assign attribute 'from_obj' to the qualified name 'target' on 'to_module
- # This installs empty Modules where none exist yet if they are subpaths of target
- def _assign_attr(from_obj: Any, to_module: torch.nn.Module, target: str):
- *prefix, field = target.split('.')
- for item in prefix:
- t = getattr(to_module, item, None)
- if t is None:
- t = torch.nn.Module()
- setattr(to_module, item, t)
- to_module = t
- # If it is a tensor and not a parameter attribute of a module, it should be a named buffer.
- # So, we register it as a named buffer in the target module.
- if isinstance(from_obj, torch.Tensor) and not isinstance(from_obj, torch.nn.Parameter):
- to_module.register_buffer(field, from_obj)
- else:
- setattr(to_module, field, from_obj)
- class _WrappedCall:
- def __init__(self, cls, cls_call):
- self.cls = cls
- self.cls_call = cls_call
- # Previously, if an error occurred when valid
- # symbolically-traced code was run with an invalid input, the
- # user would see the source of the error as coming from
- # `File "<eval_with_key_N">`, where N is some number. We use
- # this function to generate a more informative error message. We
- # return the traceback itself, a message explaining that the
- # error occurred in a traced Module's generated forward
- # function, and five lines of context surrounding the faulty
- # line
- @staticmethod
- def _generate_error_message(frame_summary: traceback.FrameSummary) -> str:
- # auxiliary variables (for readability)
- err_lineno = frame_summary.lineno
- assert err_lineno is not None
- line = frame_summary.line
- assert line is not None
- err_line_len = len(line)
- all_src_lines = linecache.getlines(frame_summary.filename)
- # constituent substrings of the error message
- tb_repr = traceback.format_exc()
- custom_msg = ("Call using an FX-traced Module, "
- f"line {err_lineno} of the traced Module's "
- "generated forward function:")
- before_err = "".join(all_src_lines[err_lineno - 2 : err_lineno])
- marker = "~" * err_line_len + "~~~ <--- HERE"
- err_and_after_err = "\n".join(all_src_lines[err_lineno : err_lineno + 2])
- # joined message
- return "\n".join([tb_repr, custom_msg, before_err, marker, err_and_after_err])
- def __call__(self, obj, *args, **kwargs):
- try:
- if self.cls_call is not None:
- return self.cls_call(obj, *args, **kwargs)
- else:
- return super(self.cls, obj).__call__(*args, **kwargs) # type: ignore[misc]
- except Exception as e:
- assert e.__traceback__
- topmost_framesummary: traceback.FrameSummary = \
- traceback.StackSummary.extract(traceback.walk_tb(e.__traceback__))[-1] # type: ignore[arg-type]
- if "eval_with_key" in topmost_framesummary.filename:
- print(_WrappedCall._generate_error_message(topmost_framesummary),
- file=sys.stderr)
- raise e.with_traceback(None)
- else:
- raise e
- @compatibility(is_backward_compatible=True)
- class GraphModule(torch.nn.Module):
- """
- GraphModule is an nn.Module generated from an fx.Graph. Graphmodule has a
- ``graph`` attribute, as well as ``code`` and ``forward`` attributes generated
- from that ``graph``.
- .. warning::
- When ``graph`` is reassigned, ``code`` and ``forward`` will be automatically
- regenerated. However, if you edit the contents of the ``graph`` without reassigning
- the ``graph`` attribute itself, you must call ``recompile()`` to update the generated
- code.
- """
- def __new__(cls: 'Type[GraphModule]', *args, **kwargs):
- # each instance of a graph module needs its own forward method
- # so create a new singleton class for each instance.
- # it is a subclass of the user-defined class, the only difference
- # is an extra layer to install the forward method
- # address issue described at https://github.com/pytorch/pytorch/issues/63883
- # in other words, traverse class hierarchy to fix the redundant class definition problem
- for t in cls.__mro__:
- c = t.__qualname__.split('.')[-1]
- if c != 'GraphModuleImpl':
- cls = t
- break
- class GraphModuleImpl(cls): # type: ignore[misc, valid-type]
- pass
- return super().__new__(GraphModuleImpl)
- @compatibility(is_backward_compatible=True)
- def __init__(self,
- root: Union[torch.nn.Module, Dict[str, Any]],
- graph: Graph,
- class_name: str = 'GraphModule'):
- """
- Construct a GraphModule.
- Args:
- root (Union[torch.nn.Module, Dict[str, Any]):
- ``root`` can either be an nn.Module instance or a Dict mapping strings to any attribute type.
- In the case that ``root`` is a Module, any references to Module-based objects (via qualified
- name) in the Graph's Nodes' ``target`` field will be copied over from the respective place
- within ``root``'s Module hierarchy into the GraphModule's module hierarchy.
- In the case that ``root`` is a dict, the qualified name found in a Node's ``target`` will be
- looked up directly in the dict's keys. The object mapped to by the Dict will be copied
- over into the appropriate place within the GraphModule's module hierarchy.
- graph (Graph): ``graph`` contains the nodes this GraphModule should use for code generation
- class_name (str): ``name`` denotes the name of this GraphModule for debugging purposes. If it's unset, all
- error messages will report as originating from ``GraphModule``. It may be helpful to set this
- to ``root``'s original name or a name that makes sense within the context of your transform.
- """
- super().__init__()
- self.__class__.__name__ = class_name
- if isinstance(root, torch.nn.Module):
- if hasattr(root, 'training'):
- self.training = root.training
- for node in graph.nodes:
- if node.op in ['get_attr', 'call_module']:
- assert isinstance(node.target, str)
- _copy_attr(root, self, node.target)
- elif isinstance(root, dict):
- targets_to_copy = []
- for node in graph.nodes:
- if node.op in ['get_attr', 'call_module']:
- assert isinstance(node.target, str)
- if node.target not in root:
- raise RuntimeError('Node ' + str(node) + ' referenced target ' + node.target +
- ' but that target was not provided in ``root``!')
- targets_to_copy.append(node.target)
- # Sort targets in ascending order of the # of atoms.
- # This will ensure that less deeply nested attributes are assigned
- # before more deeply nested attributes. For example, foo.bar
- # will be assigned before foo.bar.baz. Otherwise, we might assign
- # the user-provided ``foo.bar`` and wipe out the previously-assigned
- # ``foo.bar.baz``
- targets_to_copy.sort(key=lambda t: t.count('.'))
- for target_to_copy in targets_to_copy:
- _assign_attr(root[target_to_copy], self, target_to_copy)
- else:
- raise RuntimeError('Unsupported type ' + str(root) + ' passed for root!')
- self.graph = graph
- # Store the Tracer class responsible for creating a Graph separately as part of the
- # GraphModule state, except when the Tracer is defined in a local namespace.
- # Locally defined Tracers are not pickleable. This is needed because torch.package will
- # serialize a GraphModule without retaining the Graph, and needs to use the correct Tracer
- # to re-create the Graph during deserialization.
- self._tracer_cls = None
- if self.graph._tracer_cls and '<locals>' not in self.graph._tracer_cls.__qualname__:
- self._tracer_cls = self.graph._tracer_cls
- self._tracer_extras = {}
- if self.graph._tracer_extras:
- self._tracer_extras = self.graph._tracer_extras
- # Dictionary to store metadata
- self.meta : Dict[str, Any] = {}
- # TorchScript breaks trying to compile the graph setter because of the
- # continued string literal. Issue here: https://github.com/pytorch/pytorch/issues/44842
- #
- # Shouldn't be an issue since these methods shouldn't be used in TorchScript anyway
- __jit_unused_properties__ = ['graph']
- @property
- def graph(self) -> Graph:
- """
- Return the ``Graph`` underlying this ``GraphModule``
- """
- return self._graph
- @graph.setter
- def graph(self, g : Graph) -> None:
- """
- Set the underlying ``Graph`` for this ``GraphModule``. This will internally
- recompile the ``GraphModule`` so that the generated ``forward()`` function
- corresponds to ``g``
- """
- assert isinstance(g, Graph), f'Expected a Graph instance, but got {type(g)}'
- self._graph = g
- g.owning_module = self
- self.recompile()
- @compatibility(is_backward_compatible=False)
- def to_folder(self, folder: Union[str, os.PathLike], module_name : str = "FxModule"):
- """Dumps out module to ``folder`` with ``module_name`` so that it can be
- imported with ``from <folder> import <module_name>``
- Args:
- folder (Union[str, os.PathLike]): The folder to write the code out to
- module_name (str): Top-level name to use for the ``Module`` while
- writing out the code
- """
- folder = Path(folder)
- Path(folder).mkdir(exist_ok=True)
- torch.save(self.state_dict(), folder / 'state_dict.pt')
- tab = " " * 4
- custom_builtins = '\n'.join([v.import_str for v in _custom_builtins.values()])
- model_str = f"""
- import torch
- {custom_builtins}
- from torch.nn import *
- class {module_name}(torch.nn.Module):
- def __init__(self):
- super().__init__()
- """
- def _gen_model_repr(module_name: str, module: torch.nn.Module) -> Optional[str]:
- safe_reprs = [nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d]
- if type(module) in safe_reprs:
- return f"{module.__repr__()}"
- else:
- return None
- blobified_modules = []
- for module_name, module in self.named_children():
- module_str = _gen_model_repr(module_name, module)
- if module_str is None:
- module_file = folder / f'{module_name}.pt'
- torch.save(module, module_file)
- blobified_modules.append(module_name)
- module_repr = module.__repr__().replace('\r', ' ').replace('\n', ' ')
- module_str = f"torch.load(r'{module_file}') # {module_repr}"
- model_str += f"{tab*2}self.{module_name} = {module_str}\n"
- for buffer_name, buffer in self._buffers.items():
- if buffer is None:
- continue
- model_str += f"{tab*2}self.register_buffer('{buffer_name}', torch.empty({list(buffer.shape)}, dtype={buffer.dtype}))\n"
- for param_name, param in self._parameters.items():
- if param is None:
- continue
- model_str += f"{tab*2}self.{param_name} = torch.nn.Parameter(torch.empty({list(param.shape)}, dtype={param.dtype}))\n"
- model_str += f"{tab*2}self.load_state_dict(torch.load(r'{folder}/state_dict.pt'))\n"
- model_str += f"{_addindent(self.code, 4)}\n"
- module_file = folder / 'module.py'
- module_file.write_text(model_str)
- init_file = folder / '__init__.py'
- init_file.write_text('from .module import *')
- if len(blobified_modules) > 0:
- warnings.warn("Was not able to save the following children modules as reprs -"
- f"saved as pickled files instead: {blobified_modules}")
- @compatibility(is_backward_compatible=True)
- def add_submodule(self, target: str, m: torch.nn.Module) -> bool:
- """
- Adds the given submodule to ``self``.
- This installs empty Modules where none exist yet if they are
- subpaths of ``target``.
- Args:
- target: The fully-qualified string name of the new submodule
- (See example in ``nn.Module.get_submodule`` for how to
- specify a fully-qualified string.)
- m: The submodule itself; the actual object we want to
- install in the current Module
- Return:
- bool: Whether or not the submodule could be inserted. For
- this method to return True, each object in the chain
- denoted by ``target`` must either a) not exist yet,
- or b) reference an ``nn.Module`` (not a parameter or
- other attribute)
- """
- *prefix, field = target.split('.')
- mod: torch.nn.Module = self
- for item in prefix:
- submod = getattr(mod, item, None)
- if submod is None:
- submod = torch.nn.Module()
- setattr(mod, item, submod)
- if not isinstance(submod, torch.nn.Module):
- return False
- mod = submod
- mod.add_module(field, m)
- return True
- @compatibility(is_backward_compatible=True)
- def delete_submodule(self, target: str) -> bool:
- """
- Deletes the given submodule from ``self``.
- The module will not be deleted if ``target`` is not a valid
- target.
- Args:
- target: The fully-qualified string name of the new submodule
- (See example in ``nn.Module.get_submodule`` for how to
- specify a fully-qualified string.)
- Returns:
- bool: Whether or not the target string referenced a
- submodule we want to delete. A return value of ``False``
- means that the ``target`` was not a valid reference to
- a submodule.
- """
- atoms = target.split(".")
- path, target_submod = atoms[:-1], atoms[-1]
- mod: torch.nn.Module = self
- # Get the parent module
- for item in path:
- if not hasattr(mod, item):
- return False
- mod = getattr(mod, item)
- if not isinstance(mod, torch.nn.Module):
- return False
- if not hasattr(mod, target_submod):
- return False
- if not isinstance(getattr(mod, target_submod), torch.nn.Module):
- return False
- delattr(mod, target_submod)
- return True
- @compatibility(is_backward_compatible=True)
- def delete_all_unused_submodules(self) -> None:
- """
- Deletes all unused submodules from ``self``.
- A Module is considered "used" if any one of the following is
- true:
- 1. It has children that are used
- 2. Its forward is called directly via a ``call_module`` node
- 3. It has a non-Module attribute that is used from a
- ``get_attr`` node
- This method can be called to clean up an ``nn.Module`` without
- manually calling ``delete_submodule`` on each unused submodule.
- """
- used: List[str] = []
- for node in self.graph.nodes:
- if node.op == "call_module" or node.op == "get_attr":
- # A list of strings representing the different parts
- # of the path. For exmaple, `foo.bar.baz` gives us
- # ["foo", "bar", "baz"]
- fullpath = node.target.split(".")
- # If we're looking at multiple parts of a path, join
- # join them with a dot. Otherwise, return that single
- # element without doing anything to it.
- def join_fn(x: str, y: str) -> str:
- return '.'.join([x, y] if y else [x])
- # Progressively collect all the names of intermediate
- # modules. For example, if we have the target
- # `foo.bar.baz`, we'll add `foo`, `foo.bar`, and
- # `foo.bar.baz` to the list.
- for path in itertools.accumulate(fullpath, join_fn):
- used.append(path)
- # For a `call_module` node, also register all recursive submodules
- # as used
- if node.op == "call_module":
- try:
- submod = self.get_submodule(node.target)
- for submod_name, _ in submod.named_modules():
- if submod_name != '':
- used.append('.'.join([node.target, submod_name]))
- except AttributeError:
- # Node referenced nonexistent submodule, don't need to
- # worry about GCing anything
- pass
- to_delete = [name for name, _ in self.named_modules()
- if name not in used]
- for name in to_delete:
- self.delete_submodule(name)
- @property
- def code(self) -> str:
- """
- Return the Python code generated from the ``Graph`` underlying this
- ``GraphModule``.
- """
- if not hasattr(self, '_code'):
- raise RuntimeError('Code has not been generated! Please report a bug to PyTorch')
- return self._code
- @compatibility(is_backward_compatible=True)
- def recompile(self) -> PythonCode:
- """
- Recompile this GraphModule from its ``graph`` attribute. This should be
- called after editing the contained ``graph``, otherwise the generated
- code of this ``GraphModule`` will be out of date.
- """
- if isinstance(self._graph._codegen, _PyTreeCodeGen):
- self._in_spec = self._graph._codegen.pytree_info.in_spec
- self._out_spec = self._graph._codegen.pytree_info.out_spec
- python_code = self._graph.python_code(root_module='self')
- self._code = python_code.src
- cls = type(self)
- cls.forward = _forward_from_src(self._code, python_code.globals)
- # Determine whether this class explicitly defines a __call__ implementation
- # to wrap. If it does, save it in order to have wrapped_call invoke it.
- # If it does not, wrapped_call can use a dynamic call to super() instead.
- # In most cases, super().__call__ should be torch.nn.Module.__call__.
- # We do not want to hold a reference to Module.__call__ here; doing so will
- # bypass patching of torch.nn.Module.__call__ done while symbolic tracing.
- cls_call = cls.__call__ if "__call__" in vars(cls) else None
- if '_wrapped_call' not in vars(cls):
- cls._wrapped_call = _WrappedCall(cls, cls_call) # type: ignore[attr-defined]
- def call_wrapped(self, *args, **kwargs):
- return self._wrapped_call(self, *args, **kwargs)
- cls.__call__ = call_wrapped
- return python_code
- # Passing Tracer as argument allows subclasses extending fx.GraphModule
- # define their own Tracer (extending fx.Tracer).
- def __reduce_deploy__(self, importer: Importer):
- dict_without_graph = self.__dict__.copy()
- dict_without_graph['_graphmodule_cls_name'] = self.__class__.__name__
- del dict_without_graph['_graph']
- python_code = self.recompile()
- import_block = _format_import_block(python_code.globals, importer)
- return (reduce_deploy_graph_module, (dict_without_graph, import_block))
- def __reduce_package__(self, exporter: PackageExporter):
- dict_without_graph = self.__dict__.copy()
- dict_without_graph['_graphmodule_cls_name'] = self.__class__.__name__
- del dict_without_graph['_graph']
- generated_module_name = f'fx-generated._{exporter.get_unique_id()}'
- python_code = self.recompile()
- import_block = _format_import_block(python_code.globals, exporter.importer)
- module_code = import_block + self.code
- exporter.save_source_string(generated_module_name, module_code)
- return (reduce_package_graph_module, (dict_without_graph, generated_module_name))
- def __reduce__(self):
- """
- Serialization of GraphModule. We serialize only the generated code, not
- the underlying ``Graph``. This is because ``Graph`` does not have on-disk
- backward-compatibility guarantees, whereas Python source code does.
- On the deserialization side, we symbolically trace through the generated
- code to regenerate the underlying ``Graph``
- """
- dict_without_graph = self.__dict__.copy()
- python_code = self.recompile()
- import_block = _format_import_block(python_code.globals, sys_importer)
- del dict_without_graph['_graph']
- return (reduce_graph_module, (dict_without_graph, import_block))
- # because __reduce__ is defined for serialization,
- # we need to define deepcopy otherwise it will call __reduce__
- # and cause symbolic tracing to occur every time we try to copy the object
- def __deepcopy__(self, memo):
- res = type(self).__new__(type(self))
- memo[id(self)] = res
- fake_mod = torch.nn.Module()
- fake_mod.__dict__ = copy.deepcopy(self.__dict__, memo)
- GraphModule.__init__(res, fake_mod, fake_mod.__dict__['_graph'])
- # hooks are lost during `GraphModule.__init__`, so we need to copy over
- # them explicitly, note right now we are only copying state_dict related
- # hooks, to reduce bc-related issues, we can copy forward/backward related
- # hooks in the future as well if needed
- extra_preserved_attrs = [
- "_state_dict_hooks",
- "_load_state_dict_pre_hooks",
- "_load_state_dict_post_hooks"
- ]
- for attr in extra_preserved_attrs:
- if attr in self.__dict__:
- setattr(res, attr, copy.deepcopy(self.__dict__[attr], memo))
- res.meta = copy.deepcopy(getattr(self, 'meta', {}), memo)
- if _USER_PRESERVED_ATTRIBUTES_KEY in res.meta:
- for attr_name, attr in res.meta[_USER_PRESERVED_ATTRIBUTES_KEY].items():
- setattr(res, attr_name, attr)
- return res
- def __copy__(self):
- res = GraphModule(self, self.graph)
- res.meta = getattr(self, 'meta', {})
- return res
- @compatibility(is_backward_compatible=False)
- def print_readable(self, print_output=True):
- """
- Return the Python code generated for current GraphModule and its children GraphModules
- """
- verbose_python_code = self._graph.python_code(root_module='self', verbose=True)
- module_code = verbose_python_code.src
- module_code = module_code.lstrip('\n')
- module_code = f"class {self._get_name()}(torch.nn.Module):\n" + module_code
- module_code = _addindent(module_code, 4)
- submodule_code_list = [""]
- for submodule in self.children():
- if isinstance(submodule, GraphModule):
- submodule_code_list.append(submodule.print_readable(print_output=False))
- submodule_code = "\n".join(submodule_code_list)
- submodule_code = _addindent(submodule_code, 4)
- output = module_code + submodule_code
- if print_output:
- print(module_code + submodule_code)
- return output
- def __str__(self) -> str:
- orig_str = super().__str__()
- print_readable_reminder = "# To see more debug info, please use `graph_module.print_readable()`"
- return '\n'.join([orig_str, self._code, print_readable_reminder])
- def _replicate_for_data_parallel(self):
- new_gm = self.__copy__()
- new_gm._is_replica = True
- return new_gm
- # workarounds for issues in __torch_function__
- # WAR for __torch_function__ not handling tensor lists,
- # fix is in https://github.com/pytorch/pytorch/pull/34725
- # orig_cat = torch.cat
- # def patched_cat(*args, **kwargs):
- # tensors = args[0]
- # for t in tensors:
- # if isinstance(t, Proxy):
- # return t.__torch_function__(patched_cat, (), args, kwargs)
- # return orig_cat(*args, **kwargs)
- # patched_cat.__module__ = 'torch'
- # patched_cat.__name__ = 'cat'
- # torch.cat = patched_cat
|