graph_module.py 32 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783
  1. import torch
  2. import torch.nn as nn
  3. import torch.overrides
  4. from torch.nn.modules.module import _addindent
  5. from torch.package import PackageImporter, PackageExporter
  6. import linecache
  7. from typing import Type, Dict, List, Any, Union, Optional, Set
  8. from .graph import Graph, _PyTreeCodeGen, _is_from_torch, _custom_builtins, PythonCode
  9. from ._compatibility import compatibility
  10. from torch.package import Importer, sys_importer
  11. import copy
  12. import itertools
  13. import sys
  14. import traceback
  15. from pathlib import Path
  16. import os
  17. import warnings
  18. __all__ = ["reduce_graph_module", "reduce_package_graph_module", "reduce_deploy_graph_module", "GraphModule"]
  19. _USER_PRESERVED_ATTRIBUTES_KEY = "_user_preserved_attributes"
  20. # Normal exec loses the source code, however we can work with
  21. # the linecache module to recover it.
  22. # Using _exec_with_source will add it to our local cache
  23. # and then tools like TorchScript will be able to get source info.
  24. class _EvalCacheLoader:
  25. def __init__(self):
  26. self.eval_cache = {}
  27. self.next_id = 0
  28. def cache(self, src: str, globals: Dict[str, Any]):
  29. """Store the source in a private cache, and add a lazy entry in linecache
  30. that allows the source to be retrieved by 'filename'.
  31. Args:
  32. src (str): The module source to cache
  33. globals (dict): The module globals
  34. Returns:
  35. str: The cache key (and dummy filename) generated for src.
  36. """
  37. key = self._get_key()
  38. self.eval_cache[key] = src
  39. # Don't mutate globals so that this loader is only used
  40. # to populate linecache, and doesn't interact with other modules
  41. # that might check `__loader__`
  42. globals_copy = globals.copy()
  43. globals_copy['__file__'] = key
  44. globals_copy['__name__'] = key
  45. globals_copy['__loader__'] = self
  46. linecache.lazycache(key, globals_copy)
  47. return key
  48. # Part of the loader protocol (PEP 302)
  49. # linecache will use this method when trying to find source code
  50. def get_source(self, module_name) -> Optional[str]:
  51. if module_name in self.eval_cache:
  52. return self.eval_cache[module_name]
  53. return None
  54. def _get_key(self):
  55. key = f'<eval_with_key>.{self.next_id}'
  56. self.next_id += 1
  57. return key
  58. _loader = _EvalCacheLoader()
  59. def _exec_with_source(src: str, globals: Dict[str, Any]):
  60. key = _loader.cache(src, globals)
  61. exec(compile(src, key, 'exec'), globals)
  62. def _forward_from_src(src: str, globals: Dict[str, Any]):
  63. # avoid mutating the passed in dict
  64. globals_copy = globals.copy()
  65. _exec_with_source(src, globals_copy)
  66. forward_fn = globals_copy['forward']
  67. del globals_copy['forward']
  68. return forward_fn
  69. def _format_import_statement(name: str, obj: Any, importer: Importer) -> str:
  70. if name in _custom_builtins:
  71. return _custom_builtins[name].import_str
  72. if _is_from_torch(name):
  73. return 'import torch'
  74. module_name, attr_name = importer.get_name(obj)
  75. return f'from {module_name} import {attr_name} as {name}'
  76. def _format_import_block(globals: Dict[str, Any], importer: Importer):
  77. import_strs: Set[str] = set()
  78. for name, obj in globals.items():
  79. import_strs.add(_format_import_statement(name, obj, importer))
  80. return '\n'.join(import_strs)
  81. @compatibility(is_backward_compatible=True)
  82. def reduce_graph_module(body: Dict[Any, Any], import_block: str) -> torch.nn.Module:
  83. # BC: attribute name was changed from `code` to `_code` to facilitate
  84. # making `code` into a property and adding a docstring to it
  85. fn_src = body.get('_code') or body['code']
  86. forward = _forward_from_src(import_block + fn_src, {})
  87. return _deserialize_graph_module(forward, body)
  88. @compatibility(is_backward_compatible=True)
  89. def reduce_package_graph_module(
  90. importer: PackageImporter, body: Dict[Any, Any], generated_module_name: str
  91. ) -> torch.nn.Module:
  92. forward = importer.import_module(generated_module_name).forward
  93. return _deserialize_graph_module(forward, body)
  94. @compatibility(is_backward_compatible=True)
  95. def reduce_deploy_graph_module(
  96. importer: PackageImporter, body: Dict[Any, Any], import_block: str
  97. ) -> torch.nn.Module:
  98. ns = {}
  99. ns["__builtins__"] = importer.patched_builtins
  100. fn_src = body.get('_code')
  101. assert fn_src is not None
  102. forward = _forward_from_src(import_block + fn_src, ns)
  103. return _deserialize_graph_module(forward, body)
  104. def _deserialize_graph_module(forward, body: Dict[Any, Any]) -> torch.nn.Module:
  105. """
  106. Deserialize a GraphModule given the dictionary of the original module,
  107. using the code to reconstruct the graph. We delete the actual graph before
  108. saving the dictionary so that changes to the in-memory graph format do not
  109. get serialized.
  110. """
  111. # We create a dummy class here because symbolic_trace pulls the forward()
  112. # function off of the class, rather than the instance
  113. class CodeOnlyModule(torch.nn.Module):
  114. def __init__(self, body):
  115. super().__init__()
  116. self.__dict__ = body
  117. # Try to retrieve the forward source in a backward-compatible way
  118. CodeOnlyModule.forward = forward
  119. tracer_cls = body.get('_tracer_cls')
  120. if tracer_cls is None:
  121. from ._symbolic_trace import Tracer
  122. tracer_cls = Tracer
  123. graphmodule_cls_name = body.get('_graphmodule_cls_name', 'GraphModule')
  124. # This is a workaround for a mypy linter issue related to
  125. # passing base class as an argument - https://github.com/python/mypy/issues/5865.
  126. cls_tracer : Any = tracer_cls
  127. class KeepModules(cls_tracer):
  128. # we shouldn't trace into any of the submodules,
  129. # because they were not traced in the original GraphModule
  130. def is_leaf_module(self, _: torch.nn.Module, __: str) -> bool:
  131. return True
  132. com = CodeOnlyModule(body)
  133. tracer_extras = body.get('_tracer_extras', {})
  134. graph = KeepModules().trace(com, **tracer_extras)
  135. # Manually set Tracer class on the reconstructed Graph, to avoid
  136. # referencing the private local subclass KeepModules.
  137. graph._tracer_cls = tracer_cls
  138. gm = GraphModule(com, graph, class_name=graphmodule_cls_name)
  139. # The GraphModule constructor only retains attributes referenced by the graph.
  140. # In this case, our goal is return a GraphModule as close to identical as the one
  141. # put into the package. If any additional attributes were present in body,
  142. # we should keep them.
  143. for k, v in body.items():
  144. if not hasattr(gm, k):
  145. setattr(gm, k, v)
  146. return gm
  147. # copy an attribute value with qualified name 'target' from 'from_module' to 'to_module'
  148. # This installs empty Modules where none exist yet if they are subpaths of target
  149. def _copy_attr(from_module: torch.nn.Module, to_module: torch.nn.Module, target: str):
  150. *prefix, field = target.split('.')
  151. for item in prefix:
  152. f = getattr(from_module, item)
  153. t = getattr(to_module, item, None)
  154. if f is t:
  155. # we have already installed one of its parents
  156. # (e.g. target = root.linear.weight, but we have already installed root.linear)
  157. # once we install a parent, we no longer need to copy the children
  158. # since all the needed properties will already be present
  159. return
  160. if t is None:
  161. t = torch.nn.Module()
  162. setattr(to_module, item, t)
  163. from_module, to_module = f, t
  164. orig = getattr(from_module, field)
  165. # If it is a tensor and not a parameter attribute of a module, it should be a named buffer.
  166. # So, we register it as a named buffer in the target module.
  167. if isinstance(orig, torch.Tensor) and not isinstance(orig, torch.nn.Parameter):
  168. to_module.register_buffer(field, orig)
  169. else:
  170. setattr(to_module, field, orig)
  171. # Assign attribute 'from_obj' to the qualified name 'target' on 'to_module
  172. # This installs empty Modules where none exist yet if they are subpaths of target
  173. def _assign_attr(from_obj: Any, to_module: torch.nn.Module, target: str):
  174. *prefix, field = target.split('.')
  175. for item in prefix:
  176. t = getattr(to_module, item, None)
  177. if t is None:
  178. t = torch.nn.Module()
  179. setattr(to_module, item, t)
  180. to_module = t
  181. # If it is a tensor and not a parameter attribute of a module, it should be a named buffer.
  182. # So, we register it as a named buffer in the target module.
  183. if isinstance(from_obj, torch.Tensor) and not isinstance(from_obj, torch.nn.Parameter):
  184. to_module.register_buffer(field, from_obj)
  185. else:
  186. setattr(to_module, field, from_obj)
  187. class _WrappedCall:
  188. def __init__(self, cls, cls_call):
  189. self.cls = cls
  190. self.cls_call = cls_call
  191. # Previously, if an error occurred when valid
  192. # symbolically-traced code was run with an invalid input, the
  193. # user would see the source of the error as coming from
  194. # `File "<eval_with_key_N">`, where N is some number. We use
  195. # this function to generate a more informative error message. We
  196. # return the traceback itself, a message explaining that the
  197. # error occurred in a traced Module's generated forward
  198. # function, and five lines of context surrounding the faulty
  199. # line
  200. @staticmethod
  201. def _generate_error_message(frame_summary: traceback.FrameSummary) -> str:
  202. # auxiliary variables (for readability)
  203. err_lineno = frame_summary.lineno
  204. assert err_lineno is not None
  205. line = frame_summary.line
  206. assert line is not None
  207. err_line_len = len(line)
  208. all_src_lines = linecache.getlines(frame_summary.filename)
  209. # constituent substrings of the error message
  210. tb_repr = traceback.format_exc()
  211. custom_msg = ("Call using an FX-traced Module, "
  212. f"line {err_lineno} of the traced Module's "
  213. "generated forward function:")
  214. before_err = "".join(all_src_lines[err_lineno - 2 : err_lineno])
  215. marker = "~" * err_line_len + "~~~ <--- HERE"
  216. err_and_after_err = "\n".join(all_src_lines[err_lineno : err_lineno + 2])
  217. # joined message
  218. return "\n".join([tb_repr, custom_msg, before_err, marker, err_and_after_err])
  219. def __call__(self, obj, *args, **kwargs):
  220. try:
  221. if self.cls_call is not None:
  222. return self.cls_call(obj, *args, **kwargs)
  223. else:
  224. return super(self.cls, obj).__call__(*args, **kwargs) # type: ignore[misc]
  225. except Exception as e:
  226. assert e.__traceback__
  227. topmost_framesummary: traceback.FrameSummary = \
  228. traceback.StackSummary.extract(traceback.walk_tb(e.__traceback__))[-1] # type: ignore[arg-type]
  229. if "eval_with_key" in topmost_framesummary.filename:
  230. print(_WrappedCall._generate_error_message(topmost_framesummary),
  231. file=sys.stderr)
  232. raise e.with_traceback(None)
  233. else:
  234. raise e
  235. @compatibility(is_backward_compatible=True)
  236. class GraphModule(torch.nn.Module):
  237. """
  238. GraphModule is an nn.Module generated from an fx.Graph. Graphmodule has a
  239. ``graph`` attribute, as well as ``code`` and ``forward`` attributes generated
  240. from that ``graph``.
  241. .. warning::
  242. When ``graph`` is reassigned, ``code`` and ``forward`` will be automatically
  243. regenerated. However, if you edit the contents of the ``graph`` without reassigning
  244. the ``graph`` attribute itself, you must call ``recompile()`` to update the generated
  245. code.
  246. """
  247. def __new__(cls: 'Type[GraphModule]', *args, **kwargs):
  248. # each instance of a graph module needs its own forward method
  249. # so create a new singleton class for each instance.
  250. # it is a subclass of the user-defined class, the only difference
  251. # is an extra layer to install the forward method
  252. # address issue described at https://github.com/pytorch/pytorch/issues/63883
  253. # in other words, traverse class hierarchy to fix the redundant class definition problem
  254. for t in cls.__mro__:
  255. c = t.__qualname__.split('.')[-1]
  256. if c != 'GraphModuleImpl':
  257. cls = t
  258. break
  259. class GraphModuleImpl(cls): # type: ignore[misc, valid-type]
  260. pass
  261. return super().__new__(GraphModuleImpl)
  262. @compatibility(is_backward_compatible=True)
  263. def __init__(self,
  264. root: Union[torch.nn.Module, Dict[str, Any]],
  265. graph: Graph,
  266. class_name: str = 'GraphModule'):
  267. """
  268. Construct a GraphModule.
  269. Args:
  270. root (Union[torch.nn.Module, Dict[str, Any]):
  271. ``root`` can either be an nn.Module instance or a Dict mapping strings to any attribute type.
  272. In the case that ``root`` is a Module, any references to Module-based objects (via qualified
  273. name) in the Graph's Nodes' ``target`` field will be copied over from the respective place
  274. within ``root``'s Module hierarchy into the GraphModule's module hierarchy.
  275. In the case that ``root`` is a dict, the qualified name found in a Node's ``target`` will be
  276. looked up directly in the dict's keys. The object mapped to by the Dict will be copied
  277. over into the appropriate place within the GraphModule's module hierarchy.
  278. graph (Graph): ``graph`` contains the nodes this GraphModule should use for code generation
  279. class_name (str): ``name`` denotes the name of this GraphModule for debugging purposes. If it's unset, all
  280. error messages will report as originating from ``GraphModule``. It may be helpful to set this
  281. to ``root``'s original name or a name that makes sense within the context of your transform.
  282. """
  283. super().__init__()
  284. self.__class__.__name__ = class_name
  285. if isinstance(root, torch.nn.Module):
  286. if hasattr(root, 'training'):
  287. self.training = root.training
  288. for node in graph.nodes:
  289. if node.op in ['get_attr', 'call_module']:
  290. assert isinstance(node.target, str)
  291. _copy_attr(root, self, node.target)
  292. elif isinstance(root, dict):
  293. targets_to_copy = []
  294. for node in graph.nodes:
  295. if node.op in ['get_attr', 'call_module']:
  296. assert isinstance(node.target, str)
  297. if node.target not in root:
  298. raise RuntimeError('Node ' + str(node) + ' referenced target ' + node.target +
  299. ' but that target was not provided in ``root``!')
  300. targets_to_copy.append(node.target)
  301. # Sort targets in ascending order of the # of atoms.
  302. # This will ensure that less deeply nested attributes are assigned
  303. # before more deeply nested attributes. For example, foo.bar
  304. # will be assigned before foo.bar.baz. Otherwise, we might assign
  305. # the user-provided ``foo.bar`` and wipe out the previously-assigned
  306. # ``foo.bar.baz``
  307. targets_to_copy.sort(key=lambda t: t.count('.'))
  308. for target_to_copy in targets_to_copy:
  309. _assign_attr(root[target_to_copy], self, target_to_copy)
  310. else:
  311. raise RuntimeError('Unsupported type ' + str(root) + ' passed for root!')
  312. self.graph = graph
  313. # Store the Tracer class responsible for creating a Graph separately as part of the
  314. # GraphModule state, except when the Tracer is defined in a local namespace.
  315. # Locally defined Tracers are not pickleable. This is needed because torch.package will
  316. # serialize a GraphModule without retaining the Graph, and needs to use the correct Tracer
  317. # to re-create the Graph during deserialization.
  318. self._tracer_cls = None
  319. if self.graph._tracer_cls and '<locals>' not in self.graph._tracer_cls.__qualname__:
  320. self._tracer_cls = self.graph._tracer_cls
  321. self._tracer_extras = {}
  322. if self.graph._tracer_extras:
  323. self._tracer_extras = self.graph._tracer_extras
  324. # Dictionary to store metadata
  325. self.meta : Dict[str, Any] = {}
  326. # TorchScript breaks trying to compile the graph setter because of the
  327. # continued string literal. Issue here: https://github.com/pytorch/pytorch/issues/44842
  328. #
  329. # Shouldn't be an issue since these methods shouldn't be used in TorchScript anyway
  330. __jit_unused_properties__ = ['graph']
  331. @property
  332. def graph(self) -> Graph:
  333. """
  334. Return the ``Graph`` underlying this ``GraphModule``
  335. """
  336. return self._graph
  337. @graph.setter
  338. def graph(self, g : Graph) -> None:
  339. """
  340. Set the underlying ``Graph`` for this ``GraphModule``. This will internally
  341. recompile the ``GraphModule`` so that the generated ``forward()`` function
  342. corresponds to ``g``
  343. """
  344. assert isinstance(g, Graph), f'Expected a Graph instance, but got {type(g)}'
  345. self._graph = g
  346. g.owning_module = self
  347. self.recompile()
  348. @compatibility(is_backward_compatible=False)
  349. def to_folder(self, folder: Union[str, os.PathLike], module_name : str = "FxModule"):
  350. """Dumps out module to ``folder`` with ``module_name`` so that it can be
  351. imported with ``from <folder> import <module_name>``
  352. Args:
  353. folder (Union[str, os.PathLike]): The folder to write the code out to
  354. module_name (str): Top-level name to use for the ``Module`` while
  355. writing out the code
  356. """
  357. folder = Path(folder)
  358. Path(folder).mkdir(exist_ok=True)
  359. torch.save(self.state_dict(), folder / 'state_dict.pt')
  360. tab = " " * 4
  361. custom_builtins = '\n'.join([v.import_str for v in _custom_builtins.values()])
  362. model_str = f"""
  363. import torch
  364. {custom_builtins}
  365. from torch.nn import *
  366. class {module_name}(torch.nn.Module):
  367. def __init__(self):
  368. super().__init__()
  369. """
  370. def _gen_model_repr(module_name: str, module: torch.nn.Module) -> Optional[str]:
  371. safe_reprs = [nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d]
  372. if type(module) in safe_reprs:
  373. return f"{module.__repr__()}"
  374. else:
  375. return None
  376. blobified_modules = []
  377. for module_name, module in self.named_children():
  378. module_str = _gen_model_repr(module_name, module)
  379. if module_str is None:
  380. module_file = folder / f'{module_name}.pt'
  381. torch.save(module, module_file)
  382. blobified_modules.append(module_name)
  383. module_repr = module.__repr__().replace('\r', ' ').replace('\n', ' ')
  384. module_str = f"torch.load(r'{module_file}') # {module_repr}"
  385. model_str += f"{tab*2}self.{module_name} = {module_str}\n"
  386. for buffer_name, buffer in self._buffers.items():
  387. if buffer is None:
  388. continue
  389. model_str += f"{tab*2}self.register_buffer('{buffer_name}', torch.empty({list(buffer.shape)}, dtype={buffer.dtype}))\n"
  390. for param_name, param in self._parameters.items():
  391. if param is None:
  392. continue
  393. model_str += f"{tab*2}self.{param_name} = torch.nn.Parameter(torch.empty({list(param.shape)}, dtype={param.dtype}))\n"
  394. model_str += f"{tab*2}self.load_state_dict(torch.load(r'{folder}/state_dict.pt'))\n"
  395. model_str += f"{_addindent(self.code, 4)}\n"
  396. module_file = folder / 'module.py'
  397. module_file.write_text(model_str)
  398. init_file = folder / '__init__.py'
  399. init_file.write_text('from .module import *')
  400. if len(blobified_modules) > 0:
  401. warnings.warn("Was not able to save the following children modules as reprs -"
  402. f"saved as pickled files instead: {blobified_modules}")
  403. @compatibility(is_backward_compatible=True)
  404. def add_submodule(self, target: str, m: torch.nn.Module) -> bool:
  405. """
  406. Adds the given submodule to ``self``.
  407. This installs empty Modules where none exist yet if they are
  408. subpaths of ``target``.
  409. Args:
  410. target: The fully-qualified string name of the new submodule
  411. (See example in ``nn.Module.get_submodule`` for how to
  412. specify a fully-qualified string.)
  413. m: The submodule itself; the actual object we want to
  414. install in the current Module
  415. Return:
  416. bool: Whether or not the submodule could be inserted. For
  417. this method to return True, each object in the chain
  418. denoted by ``target`` must either a) not exist yet,
  419. or b) reference an ``nn.Module`` (not a parameter or
  420. other attribute)
  421. """
  422. *prefix, field = target.split('.')
  423. mod: torch.nn.Module = self
  424. for item in prefix:
  425. submod = getattr(mod, item, None)
  426. if submod is None:
  427. submod = torch.nn.Module()
  428. setattr(mod, item, submod)
  429. if not isinstance(submod, torch.nn.Module):
  430. return False
  431. mod = submod
  432. mod.add_module(field, m)
  433. return True
  434. @compatibility(is_backward_compatible=True)
  435. def delete_submodule(self, target: str) -> bool:
  436. """
  437. Deletes the given submodule from ``self``.
  438. The module will not be deleted if ``target`` is not a valid
  439. target.
  440. Args:
  441. target: The fully-qualified string name of the new submodule
  442. (See example in ``nn.Module.get_submodule`` for how to
  443. specify a fully-qualified string.)
  444. Returns:
  445. bool: Whether or not the target string referenced a
  446. submodule we want to delete. A return value of ``False``
  447. means that the ``target`` was not a valid reference to
  448. a submodule.
  449. """
  450. atoms = target.split(".")
  451. path, target_submod = atoms[:-1], atoms[-1]
  452. mod: torch.nn.Module = self
  453. # Get the parent module
  454. for item in path:
  455. if not hasattr(mod, item):
  456. return False
  457. mod = getattr(mod, item)
  458. if not isinstance(mod, torch.nn.Module):
  459. return False
  460. if not hasattr(mod, target_submod):
  461. return False
  462. if not isinstance(getattr(mod, target_submod), torch.nn.Module):
  463. return False
  464. delattr(mod, target_submod)
  465. return True
  466. @compatibility(is_backward_compatible=True)
  467. def delete_all_unused_submodules(self) -> None:
  468. """
  469. Deletes all unused submodules from ``self``.
  470. A Module is considered "used" if any one of the following is
  471. true:
  472. 1. It has children that are used
  473. 2. Its forward is called directly via a ``call_module`` node
  474. 3. It has a non-Module attribute that is used from a
  475. ``get_attr`` node
  476. This method can be called to clean up an ``nn.Module`` without
  477. manually calling ``delete_submodule`` on each unused submodule.
  478. """
  479. used: List[str] = []
  480. for node in self.graph.nodes:
  481. if node.op == "call_module" or node.op == "get_attr":
  482. # A list of strings representing the different parts
  483. # of the path. For exmaple, `foo.bar.baz` gives us
  484. # ["foo", "bar", "baz"]
  485. fullpath = node.target.split(".")
  486. # If we're looking at multiple parts of a path, join
  487. # join them with a dot. Otherwise, return that single
  488. # element without doing anything to it.
  489. def join_fn(x: str, y: str) -> str:
  490. return '.'.join([x, y] if y else [x])
  491. # Progressively collect all the names of intermediate
  492. # modules. For example, if we have the target
  493. # `foo.bar.baz`, we'll add `foo`, `foo.bar`, and
  494. # `foo.bar.baz` to the list.
  495. for path in itertools.accumulate(fullpath, join_fn):
  496. used.append(path)
  497. # For a `call_module` node, also register all recursive submodules
  498. # as used
  499. if node.op == "call_module":
  500. try:
  501. submod = self.get_submodule(node.target)
  502. for submod_name, _ in submod.named_modules():
  503. if submod_name != '':
  504. used.append('.'.join([node.target, submod_name]))
  505. except AttributeError:
  506. # Node referenced nonexistent submodule, don't need to
  507. # worry about GCing anything
  508. pass
  509. to_delete = [name for name, _ in self.named_modules()
  510. if name not in used]
  511. for name in to_delete:
  512. self.delete_submodule(name)
  513. @property
  514. def code(self) -> str:
  515. """
  516. Return the Python code generated from the ``Graph`` underlying this
  517. ``GraphModule``.
  518. """
  519. if not hasattr(self, '_code'):
  520. raise RuntimeError('Code has not been generated! Please report a bug to PyTorch')
  521. return self._code
  522. @compatibility(is_backward_compatible=True)
  523. def recompile(self) -> PythonCode:
  524. """
  525. Recompile this GraphModule from its ``graph`` attribute. This should be
  526. called after editing the contained ``graph``, otherwise the generated
  527. code of this ``GraphModule`` will be out of date.
  528. """
  529. if isinstance(self._graph._codegen, _PyTreeCodeGen):
  530. self._in_spec = self._graph._codegen.pytree_info.in_spec
  531. self._out_spec = self._graph._codegen.pytree_info.out_spec
  532. python_code = self._graph.python_code(root_module='self')
  533. self._code = python_code.src
  534. cls = type(self)
  535. cls.forward = _forward_from_src(self._code, python_code.globals)
  536. # Determine whether this class explicitly defines a __call__ implementation
  537. # to wrap. If it does, save it in order to have wrapped_call invoke it.
  538. # If it does not, wrapped_call can use a dynamic call to super() instead.
  539. # In most cases, super().__call__ should be torch.nn.Module.__call__.
  540. # We do not want to hold a reference to Module.__call__ here; doing so will
  541. # bypass patching of torch.nn.Module.__call__ done while symbolic tracing.
  542. cls_call = cls.__call__ if "__call__" in vars(cls) else None
  543. if '_wrapped_call' not in vars(cls):
  544. cls._wrapped_call = _WrappedCall(cls, cls_call) # type: ignore[attr-defined]
  545. def call_wrapped(self, *args, **kwargs):
  546. return self._wrapped_call(self, *args, **kwargs)
  547. cls.__call__ = call_wrapped
  548. return python_code
  549. # Passing Tracer as argument allows subclasses extending fx.GraphModule
  550. # define their own Tracer (extending fx.Tracer).
  551. def __reduce_deploy__(self, importer: Importer):
  552. dict_without_graph = self.__dict__.copy()
  553. dict_without_graph['_graphmodule_cls_name'] = self.__class__.__name__
  554. del dict_without_graph['_graph']
  555. python_code = self.recompile()
  556. import_block = _format_import_block(python_code.globals, importer)
  557. return (reduce_deploy_graph_module, (dict_without_graph, import_block))
  558. def __reduce_package__(self, exporter: PackageExporter):
  559. dict_without_graph = self.__dict__.copy()
  560. dict_without_graph['_graphmodule_cls_name'] = self.__class__.__name__
  561. del dict_without_graph['_graph']
  562. generated_module_name = f'fx-generated._{exporter.get_unique_id()}'
  563. python_code = self.recompile()
  564. import_block = _format_import_block(python_code.globals, exporter.importer)
  565. module_code = import_block + self.code
  566. exporter.save_source_string(generated_module_name, module_code)
  567. return (reduce_package_graph_module, (dict_without_graph, generated_module_name))
  568. def __reduce__(self):
  569. """
  570. Serialization of GraphModule. We serialize only the generated code, not
  571. the underlying ``Graph``. This is because ``Graph`` does not have on-disk
  572. backward-compatibility guarantees, whereas Python source code does.
  573. On the deserialization side, we symbolically trace through the generated
  574. code to regenerate the underlying ``Graph``
  575. """
  576. dict_without_graph = self.__dict__.copy()
  577. python_code = self.recompile()
  578. import_block = _format_import_block(python_code.globals, sys_importer)
  579. del dict_without_graph['_graph']
  580. return (reduce_graph_module, (dict_without_graph, import_block))
  581. # because __reduce__ is defined for serialization,
  582. # we need to define deepcopy otherwise it will call __reduce__
  583. # and cause symbolic tracing to occur every time we try to copy the object
  584. def __deepcopy__(self, memo):
  585. res = type(self).__new__(type(self))
  586. memo[id(self)] = res
  587. fake_mod = torch.nn.Module()
  588. fake_mod.__dict__ = copy.deepcopy(self.__dict__, memo)
  589. GraphModule.__init__(res, fake_mod, fake_mod.__dict__['_graph'])
  590. # hooks are lost during `GraphModule.__init__`, so we need to copy over
  591. # them explicitly, note right now we are only copying state_dict related
  592. # hooks, to reduce bc-related issues, we can copy forward/backward related
  593. # hooks in the future as well if needed
  594. extra_preserved_attrs = [
  595. "_state_dict_hooks",
  596. "_load_state_dict_pre_hooks",
  597. "_load_state_dict_post_hooks"
  598. ]
  599. for attr in extra_preserved_attrs:
  600. if attr in self.__dict__:
  601. setattr(res, attr, copy.deepcopy(self.__dict__[attr], memo))
  602. res.meta = copy.deepcopy(getattr(self, 'meta', {}), memo)
  603. if _USER_PRESERVED_ATTRIBUTES_KEY in res.meta:
  604. for attr_name, attr in res.meta[_USER_PRESERVED_ATTRIBUTES_KEY].items():
  605. setattr(res, attr_name, attr)
  606. return res
  607. def __copy__(self):
  608. res = GraphModule(self, self.graph)
  609. res.meta = getattr(self, 'meta', {})
  610. return res
  611. @compatibility(is_backward_compatible=False)
  612. def print_readable(self, print_output=True):
  613. """
  614. Return the Python code generated for current GraphModule and its children GraphModules
  615. """
  616. verbose_python_code = self._graph.python_code(root_module='self', verbose=True)
  617. module_code = verbose_python_code.src
  618. module_code = module_code.lstrip('\n')
  619. module_code = f"class {self._get_name()}(torch.nn.Module):\n" + module_code
  620. module_code = _addindent(module_code, 4)
  621. submodule_code_list = [""]
  622. for submodule in self.children():
  623. if isinstance(submodule, GraphModule):
  624. submodule_code_list.append(submodule.print_readable(print_output=False))
  625. submodule_code = "\n".join(submodule_code_list)
  626. submodule_code = _addindent(submodule_code, 4)
  627. output = module_code + submodule_code
  628. if print_output:
  629. print(module_code + submodule_code)
  630. return output
  631. def __str__(self) -> str:
  632. orig_str = super().__str__()
  633. print_readable_reminder = "# To see more debug info, please use `graph_module.print_readable()`"
  634. return '\n'.join([orig_str, self._code, print_readable_reminder])
  635. def _replicate_for_data_parallel(self):
  636. new_gm = self.__copy__()
  637. new_gm._is_replica = True
  638. return new_gm
  639. # workarounds for issues in __torch_function__
  640. # WAR for __torch_function__ not handling tensor lists,
  641. # fix is in https://github.com/pytorch/pytorch/pull/34725
  642. # orig_cat = torch.cat
  643. # def patched_cat(*args, **kwargs):
  644. # tensors = args[0]
  645. # for t in tensors:
  646. # if isinstance(t, Proxy):
  647. # return t.__torch_function__(patched_cat, (), args, kwargs)
  648. # return orig_cat(*args, **kwargs)
  649. # patched_cat.__module__ = 'torch'
  650. # patched_cat.__name__ = 'cat'
  651. # torch.cat = patched_cat