graph.py 63 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140114111421143114411451146114711481149115011511152115311541155115611571158115911601161116211631164116511661167116811691170117111721173117411751176117711781179118011811182118311841185118611871188118911901191119211931194119511961197119811991200120112021203120412051206120712081209121012111212121312141215121612171218121912201221122212231224122512261227122812291230123112321233123412351236123712381239124012411242124312441245124612471248124912501251125212531254125512561257125812591260126112621263126412651266126712681269127012711272127312741275127612771278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563
  1. from collections import defaultdict
  2. from .node import Node, Argument, Target, map_arg, _type_repr, _get_qualified_name
  3. import torch.utils._pytree as pytree
  4. from . import _pytree as fx_pytree
  5. from ._compatibility import compatibility
  6. import contextlib
  7. from typing import TYPE_CHECKING, Callable, Any, List, Dict, NamedTuple, Optional, Tuple, Set, FrozenSet, Type
  8. from dataclasses import dataclass
  9. from contextlib import contextmanager
  10. import copy
  11. import torch
  12. import keyword
  13. import re
  14. import builtins
  15. import math
  16. import warnings
  17. import inspect
  18. __all__ = ["PythonCode", "CodeGen", "Graph"]
  19. if TYPE_CHECKING:
  20. from .graph_module import GraphModule # noqa: F401
  21. from ._symbolic_trace import Tracer # noqa: F401
  22. # Mapping of builtins to their `typing` equivalent.
  23. _origin_type_map = {
  24. list: List,
  25. dict: Dict,
  26. set: Set,
  27. frozenset: FrozenSet,
  28. tuple: Tuple,
  29. }
  30. # Signature for functions thattransforms the body (`list[str]`) of the
  31. # generated code
  32. TransformCodeFunc = Callable[[List[str]], List[str]]
  33. class _CustomBuiltin(NamedTuple):
  34. """Additional objs that we add to every graph's globals.
  35. The repr() for some standard library objects is not valid Python code without
  36. an import. For common objects of this sort, we bundle them in the globals of
  37. every FX graph.
  38. """
  39. # How to import this object from the standard library.
  40. import_str: str
  41. # The actual object, produced from that import string.
  42. obj: Any
  43. _custom_builtins: Dict[str, _CustomBuiltin] = {}
  44. def _register_custom_builtin(name: str, import_str: str, obj: Any):
  45. _custom_builtins[name] = _CustomBuiltin(import_str, obj)
  46. _register_custom_builtin('inf', 'from math import inf', math.inf)
  47. _register_custom_builtin('nan', 'from math import nan', math.nan)
  48. _register_custom_builtin('NoneType', 'NoneType = type(None)', type(None))
  49. _register_custom_builtin('torch', 'import torch', torch)
  50. _register_custom_builtin('device', 'from torch import device', torch.device)
  51. _register_custom_builtin('fx_pytree', 'import torch.fx._pytree as fx_pytree', fx_pytree)
  52. _register_custom_builtin('pytree', 'import torch.utils._pytree as pytree', pytree)
  53. def _is_magic(x: str) -> bool:
  54. return x.startswith('__') and x.endswith('__')
  55. def _snake_case(s: str) -> str:
  56. """
  57. Transforms the given string ``s`` to a Python-style variable name
  58. Examples:
  59. ``mod.snake_case`` -> ``mod.snake_case``
  60. ``mod.pascalCase``-> ``mod.pascal_case``
  61. ``mod.ALL_CAPS`` -> ``mod.all_caps``
  62. """
  63. chars = []
  64. prev_lower = False
  65. for c in s:
  66. if prev_lower and c.isupper():
  67. chars.append('_')
  68. chars.append(c.lower())
  69. prev_lower = c.islower()
  70. return ''.join(chars)
  71. def _is_from_torch(obj: Any) -> bool:
  72. module_name = getattr(obj, '__module__', None)
  73. if module_name is not None:
  74. base_module = module_name.partition('.')[0]
  75. return (
  76. base_module == 'torch' and
  77. not module_name.startswith("torch._dynamo.") and
  78. not module_name.startswith("torch._inductor.")
  79. )
  80. name = getattr(obj, '__name__', None)
  81. # exclude torch because torch.torch.torch.torch works. idk mang
  82. if name is not None and name != 'torch':
  83. for guess in [torch, torch.nn.functional]:
  84. if getattr(guess, name, None) is obj:
  85. return True
  86. return False
  87. class _Namespace:
  88. """A context for associating names uniquely with objects.
  89. The following invariants are enforced:
  90. - Each object gets a single name.
  91. - Each name is unique within a given namespace.
  92. - Names generated do not shadow builtins, unless the object is indeed that builtin.
  93. """
  94. def __init__(self):
  95. self._obj_to_name: Dict[Any, str] = {}
  96. self._unassociated_names = set()
  97. self._used_names: Set[str] = set()
  98. self._base_count: Dict[str, int] = defaultdict(int)
  99. self._illegal_char_regex = re.compile('[^0-9a-zA-Z_]+')
  100. self._name_suffix_regex = re.compile(r"(.*)_(\d+)$")
  101. def create_name(self, candidate: str, obj: Optional[Any]) -> str:
  102. """Create a unique name.
  103. Arguments:
  104. candidate: used as the basis for the unique name, relevant to the user.
  105. obj: If not None, an object that will be associated with the unique name.
  106. """
  107. if obj is not None and obj in self._obj_to_name:
  108. return self._obj_to_name[obj]
  109. # delete all characters that are illegal in a Python identifier
  110. candidate = self._illegal_char_regex.sub('_', candidate)
  111. if not candidate:
  112. candidate = '_unnamed'
  113. if candidate[0].isdigit():
  114. candidate = f'_{candidate}'
  115. match = self._name_suffix_regex.match(candidate)
  116. if match is None:
  117. base = candidate
  118. num = None
  119. else:
  120. base, num_str = match.group(1, 2)
  121. num = int(num_str)
  122. candidate = base if num is None else f'{base}_{num}'
  123. if not num:
  124. num = self._base_count[base]
  125. while candidate in self._used_names or self._is_illegal_name(candidate, obj):
  126. num += 1
  127. candidate = f'{base}_{num}'
  128. self._used_names.add(candidate)
  129. self._base_count[base] = num
  130. if obj is None:
  131. self._unassociated_names.add(candidate)
  132. else:
  133. self._obj_to_name[obj] = candidate
  134. return candidate
  135. def associate_name_with_obj(self, name: str, obj: Any):
  136. """Associate a unique name with an object.
  137. Neither `name` nor `obj` should be associated already.
  138. """
  139. assert obj not in self._obj_to_name
  140. assert name in self._unassociated_names
  141. self._obj_to_name[obj] = name
  142. self._unassociated_names.remove(name)
  143. def _is_illegal_name(self, name: str, obj: Any) -> bool:
  144. # 1. keywords are never allowed as names.
  145. if name in keyword.kwlist:
  146. return True
  147. # 2. Can't shadow a builtin name, unless you *are* that builtin.
  148. if name in builtins.__dict__:
  149. return obj is not builtins.__dict__[name]
  150. # 3. Can't shadow our custom builtins either
  151. if name in _custom_builtins:
  152. return obj is not _custom_builtins[name].obj
  153. return False
  154. dtype_abbrs = {
  155. torch.bfloat16: 'bf16',
  156. torch.float64: 'f64',
  157. torch.float32: 'f32',
  158. torch.float16: 'f16',
  159. torch.complex32: 'c32',
  160. torch.complex64: 'c64',
  161. torch.complex128: 'c128',
  162. torch.int8: 'i8',
  163. torch.int16: 'i16',
  164. torch.int32: 'i32',
  165. torch.int64: 'i64',
  166. torch.bool: 'b8',
  167. torch.uint8: 'u8',
  168. }
  169. @compatibility(is_backward_compatible=True)
  170. @dataclass
  171. class PythonCode:
  172. """
  173. Represents all the information necessary to exec or save a graph as Python code.
  174. """
  175. # Python source code for the forward function definition.
  176. src: str
  177. # Values in global scope during exection of `src_def`.
  178. globals: Dict[str, Any]
  179. def _format_target(base: str, target: str) -> str:
  180. elems = target.split('.')
  181. r = base
  182. for e in elems:
  183. if not e.isidentifier():
  184. r = f'getattr({r}, "{e}")'
  185. else:
  186. r = f'{r}.{e}'
  187. return r
  188. class _InsertPoint:
  189. def __init__(self, graph, new_insert):
  190. self.graph = graph
  191. self.orig_insert, graph._insert = graph._insert, new_insert
  192. def __enter__(self):
  193. pass
  194. def __exit__(self, type, value, tb):
  195. self.graph._insert = self.orig_insert
  196. class _node_list:
  197. def __init__(self, graph: 'Graph', direction: str = '_next'):
  198. assert direction in ['_next', '_prev']
  199. self.graph = graph
  200. self.direction = direction
  201. def __len__(self):
  202. return self.graph._len
  203. def __iter__(self):
  204. root, direction = self.graph._root, self.direction
  205. cur = getattr(root, direction)
  206. while cur is not root:
  207. if not cur._erased:
  208. yield cur
  209. cur = getattr(cur, direction)
  210. def __reversed__(self):
  211. return _node_list(self.graph, '_next' if self.direction == '_prev' else '_prev')
  212. class _PyTreeInfo(NamedTuple):
  213. """
  214. Contains extra info stored when we're using Pytrees
  215. """
  216. orig_args: List[str]
  217. in_spec: pytree.TreeSpec
  218. out_spec: Optional[pytree.TreeSpec]
  219. @compatibility(is_backward_compatible=False)
  220. class CodeGen:
  221. def __init__(self):
  222. self._body_transformer: Optional[TransformCodeFunc] = None
  223. def gen_fn_def(self, free_vars: List[str], maybe_return_annotation: str) -> str:
  224. """
  225. Given the free variables and a return annotation, generates the beginning of the FX function.
  226. By default, `gen_fn_def(['a', 'b'], '') == 'def forward(a, b):'`
  227. """
  228. # If the original function didn't have self as its first argument, we
  229. # would have added it.
  230. if len(free_vars) == 0 or free_vars[0] != 'self':
  231. free_vars.insert(0, 'self')
  232. return f"def forward({', '.join(free_vars)}){maybe_return_annotation}:"
  233. def generate_output(self, output_args: Argument) -> str:
  234. """
  235. Given the output arguments, generates the return statement of the FX function.
  236. Note: The returned statement should not be indented.
  237. """
  238. return f'return {repr(output_args)}'
  239. def process_inputs(self, *args: Any) -> Any:
  240. """
  241. Transforms the inputs so that the graph can take them as arguments, as
  242. non-default codegen may result in the inputs to the function being
  243. different from the inputs to the graph.
  244. If the graph was directly runnable, this invariant should hold true
  245. `f.graph.process_outputs(f.graph(*f.graph.process_inputs(*inputs))) == f(*inputs)`
  246. """
  247. return args
  248. def process_outputs(self, outputs: Any) -> Any:
  249. """
  250. Transforms the outputs of the graph to be identical to the codegen.
  251. See ``process_inputs`` for more details.
  252. """
  253. return outputs
  254. def additional_globals(self) -> List[Tuple[str, Any]]:
  255. """
  256. If your codegen uses extra global values, add tuples of (identifier,reference to the value) here.
  257. For example, return ['List', typing.List] if you need ``List`` in the global context.
  258. """
  259. return []
  260. def _gen_python_code(self, nodes, root_module: str, namespace: _Namespace, *, verbose: bool = False) -> PythonCode:
  261. free_vars: List[str] = []
  262. body: List[str] = []
  263. globals_: Dict[str, Any] = {}
  264. wrapped_fns: Dict[str, None] = {}
  265. # Wrap string in list to pass by reference
  266. maybe_return_annotation : List[str] = ['']
  267. def add_global(name_hint: str, obj: Any):
  268. """Add an obj to be tracked as a global.
  269. We call this for names that reference objects external to the
  270. Graph, like functions or types.
  271. Returns: the global name that should be used to reference 'obj' in generated source.
  272. """
  273. if _is_from_torch(obj) and obj != torch.device: # to support registering torch.device
  274. # HACK: workaround for how torch custom ops are registered. We
  275. # can't import them like normal modules so they must retain their
  276. # fully qualified name.
  277. return _get_qualified_name(obj)
  278. # normalize the name hint to get a proper identifier
  279. global_name = namespace.create_name(name_hint, obj)
  280. if global_name in globals_:
  281. assert globals_[global_name] is obj
  282. return global_name
  283. globals_[global_name] = obj
  284. return global_name
  285. # Pre-fill the globals table with registered builtins.
  286. for name, (_, obj) in _custom_builtins.items():
  287. add_global(name, obj)
  288. def type_repr(o : Any):
  289. if o == ():
  290. # Empty tuple is used for empty tuple type annotation Tuple[()]
  291. return '()'
  292. typename = _type_repr(o)
  293. if hasattr(o, '__origin__'):
  294. # This is a generic type, e.g. typing.List[torch.Tensor]
  295. origin_type = _origin_type_map.get(o.__origin__, o.__origin__)
  296. origin_typename = add_global(_type_repr(origin_type), origin_type)
  297. if hasattr(o, '__args__'):
  298. # Assign global names for each of the inner type variables.
  299. args = [type_repr(arg) for arg in o.__args__]
  300. if len(args) == 0:
  301. # Bare type, such as `typing.Tuple` with no subscript
  302. # This code-path used in Python < 3.9
  303. return origin_typename
  304. return f'{origin_typename}[{",".join(args)}]'
  305. else:
  306. # Bare type, such as `typing.Tuple` with no subscript
  307. # This code-path used in Python 3.9+
  308. return origin_typename
  309. # Common case: this is a regular module name like 'foo.bar.baz'
  310. return add_global(typename, o)
  311. def _format_args(args: Tuple[Argument, ...], kwargs: Dict[str, Argument]) -> str:
  312. def _get_repr(arg):
  313. # Handle NamedTuples (if it has `_fields`) via add_global.
  314. if isinstance(arg, tuple) and hasattr(arg, '_fields'):
  315. qualified_name = _get_qualified_name(type(arg))
  316. global_name = add_global(qualified_name, type(arg))
  317. return f"{global_name}{repr(tuple(arg))}"
  318. return repr(arg)
  319. args_s = ', '.join(_get_repr(a) for a in args)
  320. kwargs_s = ', '.join(f'{k} = {_get_repr(v)}' for k, v in kwargs.items())
  321. if args_s and kwargs_s:
  322. return f'{args_s}, {kwargs_s}'
  323. return args_s or kwargs_s
  324. # Run through reverse nodes and record the first instance of a use
  325. # of a given node. This represents the *last* use of the node in the
  326. # execution order of the program, which we will use to free unused
  327. # values
  328. node_to_last_use : Dict[Node, Node] = {}
  329. user_to_last_uses : Dict[Node, List[Node]] = {}
  330. def register_last_uses(n : Node, user : Node):
  331. if n not in node_to_last_use:
  332. node_to_last_use[n] = user
  333. user_to_last_uses.setdefault(user, []).append(n)
  334. for node in reversed(nodes):
  335. map_arg(node.args, lambda n: register_last_uses(n, node))
  336. map_arg(node.kwargs, lambda n: register_last_uses(n, node))
  337. def delete_unused_values(user : Node):
  338. """
  339. Delete values after their last use. This ensures that values that are
  340. not used in the remainder of the code are freed and the memory usage
  341. of the code is optimal.
  342. """
  343. if user.op == 'placeholder':
  344. return
  345. if user.op == 'output':
  346. body.append('\n')
  347. return
  348. nodes_to_delete = user_to_last_uses.get(user, [])
  349. if len(nodes_to_delete):
  350. to_delete_str = ' = '.join([repr(n) for n in nodes_to_delete] + ['None'])
  351. body.append(f'; {to_delete_str}\n')
  352. else:
  353. body.append('\n')
  354. prev_stacktrace = None
  355. def append_stacktrace_summary(node : Node):
  356. """
  357. Append a summary of the stacktrace to the generated code. This is
  358. useful for debugging.
  359. """
  360. nonlocal prev_stacktrace
  361. pattern = re.compile(r"^File \"(.+)\", line (\d+), in (.+)$")
  362. if node.op not in {'placeholder', 'output'}:
  363. if node.stack_trace:
  364. if node.stack_trace != prev_stacktrace:
  365. prev_stacktrace = node.stack_trace
  366. lines = node.stack_trace.strip().split('\n')
  367. idx = 0
  368. while idx < len(lines):
  369. line = lines[idx].strip()
  370. if line.startswith('File '):
  371. break
  372. idx += 1
  373. summary_lines = []
  374. if idx + 1 < len(lines):
  375. matches = pattern.match(lines[idx].strip())
  376. if matches:
  377. file = matches.group(1)
  378. lineno = matches.group(2)
  379. lineage = f'File: {file}:{lineno}'
  380. summary_lines.append(lineage)
  381. code = f"code: {lines[idx + 1].strip()}"
  382. summary_lines.append(code)
  383. summary_str = ', '.join(summary_lines)
  384. body.append(f'\n# {summary_str}\n')
  385. elif prev_stacktrace != "":
  386. prev_stacktrace = ""
  387. body.append('\n# No stacktrace found for following nodes\n')
  388. def stringify_shape(shape : torch.Size) -> str:
  389. return f"[{', '.join(str(x) for x in shape)}]"
  390. def emit_node(node : Node):
  391. maybe_type_annotation = '' if node.type is None else f' : {type_repr(node.type)}'
  392. if verbose:
  393. # override annotation with more detailed information
  394. from torch._subclasses.fake_tensor import FakeTensor
  395. from torch.fx.experimental.proxy_tensor import py_sym_types
  396. from torch.fx.passes.shape_prop import TensorMetadata
  397. meta_val = node.meta.get('val', node.meta.get('tensor_meta', None))
  398. if isinstance(meta_val, FakeTensor):
  399. maybe_type_annotation = f': {dtype_abbrs[meta_val.dtype]}{stringify_shape(meta_val.shape)}'
  400. elif isinstance(meta_val, py_sym_types):
  401. maybe_type_annotation = f': Sym({meta_val})'
  402. elif isinstance(meta_val, TensorMetadata):
  403. maybe_type_annotation = f': {dtype_abbrs[meta_val.dtype]}{stringify_shape(meta_val.shape)}'
  404. if node.op == 'placeholder':
  405. assert isinstance(node.target, str)
  406. maybe_default_arg = '' if not node.args else f' = {repr(node.args[0])}'
  407. free_vars.append(f'{node.target}{maybe_type_annotation}{maybe_default_arg}')
  408. raw_name = node.target.replace('*', '')
  409. if raw_name != repr(node):
  410. body.append(f'{repr(node)} = {raw_name}\n')
  411. return
  412. elif node.op == 'call_method':
  413. assert isinstance(node.target, str)
  414. body.append(
  415. f'{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.target)}'
  416. f'({_format_args(node.args[1:], node.kwargs)})')
  417. return
  418. elif node.op == 'call_function':
  419. assert callable(node.target)
  420. # pretty print operators
  421. if getattr(node.target, "__module__", "") == '_operator' and node.target.__name__ in magic_methods:
  422. assert isinstance(node.args, tuple)
  423. body.append(f'{repr(node)}{maybe_type_annotation} = '
  424. f'{magic_methods[node.target.__name__].format(*(repr(a) for a in node.args))}')
  425. return
  426. # pretty print inplace operators; required for jit.script to work properly
  427. # not currently supported in normal FX graphs, but generated by torchdynamo
  428. if getattr(node.target, "__module__", "") == '_operator' and node.target.__name__ in inplace_methods:
  429. body.append(f'{inplace_methods[node.target.__name__].format(*(repr(a) for a in node.args))}; '
  430. f'{repr(node)}{maybe_type_annotation} = {repr(node.args[0])}')
  431. return
  432. qualified_name = _get_qualified_name(node.target)
  433. global_name = add_global(qualified_name, node.target)
  434. # special case for getattr: node.args could be 2-argument or 3-argument
  435. # 2-argument: attribute access; 3-argument: fall through to attrib function call with default value
  436. if global_name == 'getattr' and \
  437. isinstance(node.args, tuple) and \
  438. isinstance(node.args[1], str) and \
  439. node.args[1].isidentifier() and \
  440. len(node.args) == 2:
  441. body.append(f'{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.args[1])}')
  442. return
  443. body.append(f'{repr(node)}{maybe_type_annotation} = {global_name}({_format_args(node.args, node.kwargs)})')
  444. if node.meta.get('is_wrapped', False):
  445. wrapped_fns.setdefault(global_name)
  446. return
  447. elif node.op == 'call_module':
  448. assert isinstance(node.target, str)
  449. body.append(f'{repr(node)}{maybe_type_annotation} = '
  450. f'{_format_target(root_module, node.target)}({_format_args(node.args, node.kwargs)})')
  451. return
  452. elif node.op == 'get_attr':
  453. assert isinstance(node.target, str)
  454. body.append(f'{repr(node)}{maybe_type_annotation} = {_format_target(root_module, node.target)}')
  455. return
  456. elif node.op == 'output':
  457. if node.type is not None:
  458. maybe_return_annotation[0] = f" -> {type_repr(node.type)}"
  459. body.append(self.generate_output(node.args[0]))
  460. return
  461. raise NotImplementedError(f'node: {node.op} {node.target}')
  462. for node in nodes:
  463. # NOTE: emit_node does not emit a string with newline. It depends
  464. # on delete_unused_values to append one
  465. if verbose:
  466. append_stacktrace_summary(node)
  467. emit_node(node)
  468. delete_unused_values(node)
  469. if len(body) == 0:
  470. # If the Graph has no non-placeholder nodes, no lines for the body
  471. # have been emitted. To continue to have valid Python code, emit a
  472. # single pass statement
  473. body.append('pass\n')
  474. if len(wrapped_fns) > 0:
  475. wrap_name = add_global('wrap', torch.fx.wrap)
  476. wrap_stmts = '\n'.join([f'{wrap_name}("{name}")' for name in wrapped_fns])
  477. else:
  478. wrap_stmts = ''
  479. if self._body_transformer:
  480. body = self._body_transformer(body)
  481. for name, value in self.additional_globals():
  482. add_global(name, value)
  483. prologue = self.gen_fn_def(free_vars, maybe_return_annotation[0])
  484. code = ''.join(body).lstrip('\n')
  485. code = '\n'.join(' ' + line for line in code.split('\n'))
  486. fn_code = f"""
  487. {wrap_stmts}
  488. {prologue}
  489. {code}"""
  490. return PythonCode(fn_code, globals_)
  491. # Ideally, we'd like to refactor all of the pytree logic into this codegen
  492. # class. Unfortunately, there are 3 areas we currently need extra logic in FX.
  493. # 1. In the initial symbolic trace, the pytree logic is tied up with `concrete_args`.
  494. # 2. In the FX graph, we need to access 2 attributes - in_spec and out_spec.
  495. # Since we can't access .graph within the FX forward, we need to copy the attribute to the module.
  496. # 3. We currently can't register the pytree imports with `add_global` - not sure why.
  497. class _PyTreeCodeGen(CodeGen):
  498. def __init__(self, pytree_info: _PyTreeInfo):
  499. super().__init__()
  500. self.pytree_info: _PyTreeInfo = pytree_info
  501. def process_inputs(self, *inputs: Any) -> Any:
  502. flat_args, _ = pytree.tree_flatten(inputs)
  503. return flat_args
  504. def process_outputs(self, out: Any) -> Any:
  505. if self.pytree_info is None:
  506. return out
  507. if not isinstance(out, list):
  508. out = [out]
  509. assert(self.pytree_info.out_spec is not None)
  510. return pytree.tree_unflatten(out, self.pytree_info.out_spec)
  511. def gen_fn_def(self, free_vars, maybe_return_annotation):
  512. # Given a user function/model:
  513. # myargs = [myargs0, myargs1]
  514. # mykwargs = {'mykwargs0': ..., 'mykwargs1': ...}
  515. # def forward(self, mypos, *myargs, mykey=None, **mykwargs):
  516. #
  517. # The generated code flattens all keywords into positional arguments for `forward()`
  518. # e.g forward(self, mypos, myargs0, myargs1, mykey, mykwargs0, mykwargs1):
  519. #
  520. # Within `forward`, `tree_flatten_spec``still parses args and kwargs separately
  521. # e.g. tree_flatten_spec(([mypos, myargs0, myargs1],
  522. # {'mykey':mykey, 'mykwargs0':mykwargs0, 'mykwargs1':mykwargs1}),
  523. # self._in_spec)
  524. #
  525. # If the user function/model does not have keywords, the dict is suppressed from tree_flatten_spec
  526. # e.g. tree_flatten_spec([mypos, myargs0, myargs1]), self._in_spec)
  527. if self.pytree_info is None:
  528. return super().gen_fn_def(free_vars, maybe_return_annotation)
  529. fn_args = self.pytree_info.orig_args
  530. has_orig_self = (fn_args[0] == 'self') if len(fn_args) > 0 else False
  531. if has_orig_self:
  532. free_vars.insert(0, 'self')
  533. fn_definition = super().gen_fn_def(fn_args[:], maybe_return_annotation)
  534. if len(free_vars) > 0: # pytree has placeholders in it
  535. # when kwargs is present, in_spec is tuple(args, kwargs)
  536. has_args_kwargs_tuple = self.pytree_info.in_spec.type == tuple and \
  537. len(self.pytree_info.in_spec.children_specs) == 2 and \
  538. self.pytree_info.in_spec.children_specs[0].type == tuple and \
  539. self.pytree_info.in_spec.children_specs[1].type == dict
  540. fn_kwargs = '{}'
  541. fn_signature = f"[{', '.join(fn_args)}], self._in_spec"
  542. if has_args_kwargs_tuple:
  543. count_args = len(self.pytree_info.in_spec.children_specs[0].children_specs)
  544. fn_args = self.pytree_info.orig_args[:count_args]
  545. fn_kwargs = '{' + ', '.join(f"'{k}':{v}" for k, v in zip(
  546. self.pytree_info.in_spec.children_specs[1].context,
  547. self.pytree_info.orig_args[count_args:])) + '}'
  548. fn_signature = f"([{', '.join(fn_args)}], {fn_kwargs}), self._in_spec"
  549. fn_definition += f"""
  550. {', '.join(free_vars)}, = fx_pytree.tree_flatten_spec({fn_signature})"""
  551. return fn_definition
  552. def generate_output(self, output_args):
  553. if self.pytree_info:
  554. return f'return pytree.tree_unflatten({repr(output_args)}, self._out_spec)'
  555. else:
  556. return super().generate_output(output_args)
  557. @compatibility(is_backward_compatible=True)
  558. class Graph:
  559. """
  560. ``Graph`` is the main data structure used in the FX Intermediate Representation.
  561. It consists of a series of ``Node`` s, each representing callsites (or other
  562. syntactic constructs). The list of ``Node`` s, taken together, constitute a
  563. valid Python function.
  564. For example, the following code
  565. .. code-block:: python
  566. import torch
  567. import torch.fx
  568. class MyModule(torch.nn.Module):
  569. def __init__(self):
  570. super().__init__()
  571. self.param = torch.nn.Parameter(torch.rand(3, 4))
  572. self.linear = torch.nn.Linear(4, 5)
  573. def forward(self, x):
  574. return torch.topk(torch.sum(self.linear(x + self.linear.weight).relu(), dim=-1), 3)
  575. m = MyModule()
  576. gm = torch.fx.symbolic_trace(m)
  577. Will produce the following Graph::
  578. print(gm.graph)
  579. .. code-block:: text
  580. graph(x):
  581. %linear_weight : [#users=1] = self.linear.weight
  582. %add_1 : [#users=1] = call_function[target=operator.add](args = (%x, %linear_weight), kwargs = {})
  583. %linear_1 : [#users=1] = call_module[target=linear](args = (%add_1,), kwargs = {})
  584. %relu_1 : [#users=1] = call_method[target=relu](args = (%linear_1,), kwargs = {})
  585. %sum_1 : [#users=1] = call_function[target=torch.sum](args = (%relu_1,), kwargs = {dim: -1})
  586. %topk_1 : [#users=1] = call_function[target=torch.topk](args = (%sum_1, 3), kwargs = {})
  587. return topk_1
  588. For the semantics of operations represented in the ``Graph``, please see :class:`Node`.
  589. """
  590. @compatibility(is_backward_compatible=True)
  591. def __init__(self, owning_module: Optional["GraphModule"] = None, tracer_cls: Optional[Type["Tracer"]] = None,
  592. tracer_extras: Optional[Dict[str, Any]] = None):
  593. """
  594. Construct an empty Graph.
  595. """
  596. self._root : Node = Node(self, '', 'root', '', (), {})
  597. self._used_names : Dict[str, int] = {} # base name -> number
  598. self._insert = self._root.prepend
  599. self._len = 0
  600. self._graph_namespace = _Namespace()
  601. self._owning_module = owning_module
  602. self._tracer_cls = tracer_cls
  603. self._tracer_extras = tracer_extras
  604. self._codegen = CodeGen()
  605. @property
  606. def owning_module(self):
  607. return self._owning_module
  608. @owning_module.setter
  609. def owning_module(self, mod: Optional["GraphModule"]):
  610. self._owning_module = mod
  611. @property
  612. def nodes(self) -> _node_list:
  613. """
  614. Get the list of Nodes that constitute this Graph.
  615. Note that this ``Node`` list representation is a doubly-linked list. Mutations
  616. during iteration (e.g. delete a Node, add a Node) are safe.
  617. Returns:
  618. A doubly-linked list of Nodes. Note that ``reversed`` can be called on
  619. this list to switch iteration order.
  620. """
  621. return _node_list(self)
  622. @compatibility(is_backward_compatible=True)
  623. def graph_copy(self, g : 'Graph', val_map : Dict[Node, Node], return_output_node=False) -> 'Optional[Argument]':
  624. """
  625. Copy all nodes from a given graph into ``self``.
  626. Args:
  627. g (Graph): The source graph from which to copy Nodes.
  628. val_map (Dict[Node, Node]): a dictionary that will be populated with a mapping
  629. from nodes in ``g`` to nodes in ``self``. Note that ``val_map`` can be passed
  630. in with values in it already to override copying of certain values.
  631. Returns:
  632. The value in ``self`` that is now equivalent to the output value in ``g``,
  633. if ``g`` had an ``output`` node. ``None`` otherwise.
  634. """
  635. for node in g.nodes:
  636. if node in val_map:
  637. continue
  638. if node.op == 'output':
  639. rv = map_arg(node.args[0], lambda n: val_map[n])
  640. return rv if not return_output_node else (rv, node)
  641. val_map[node] = self.node_copy(node, lambda n : val_map[n])
  642. return None
  643. def __deepcopy__(self, memo=None) -> 'Graph':
  644. """
  645. Explicitly implement __deepcopy__ to prevent excessive recursion depth
  646. from the default implementation. This uses graph_copy to copy the nodes
  647. in an iterative way, rather than recursive. It also populates the
  648. memoization table to prevent unnecessary copies (e.g. references to
  649. nodes or other parts of the Graph from a custom GraphModule implementation.
  650. """
  651. memo = memo if memo else {}
  652. g = Graph(tracer_cls=self._tracer_cls)
  653. output_vals = g.graph_copy(self, val_map=memo, return_output_node=True)
  654. g._codegen = copy.deepcopy(self._codegen)
  655. assert isinstance(output_vals, tuple)
  656. output_val, old_output_val = output_vals
  657. g.output(output_val, type_expr=getattr(old_output_val, 'type', None))
  658. return g
  659. @compatibility(is_backward_compatible=True)
  660. def create_node(self, op: str, target: 'Target',
  661. args: Optional[Tuple['Argument', ...]] = None,
  662. kwargs: Optional[Dict[str, 'Argument']] = None,
  663. name: Optional[str] = None,
  664. type_expr: Optional[Any] = None) -> Node:
  665. """
  666. Create a ``Node`` and add it to the ``Graph`` at the current insert-point.
  667. Note that the current insert-point can be set via :meth:`Graph.inserting_before`
  668. and :meth:`Graph.inserting_after`.
  669. Args:
  670. op (str): the opcode for this Node. One of 'call_function', 'call_method', 'get_attr',
  671. 'call_module', 'placeholder', or 'output'. The semantics of these opcodes are
  672. described in the ``Graph`` docstring.
  673. args (Optional[Tuple[Argument, ...]]): is a tuple of arguments to this node.
  674. kwargs (Optional[Dict[str, Argument]]): the kwargs of this Node
  675. name (Optional[str]): an optional string name for the ``Node``.
  676. This will influence the name of the value assigned to in the
  677. Python generated code.
  678. type_expr (Optional[Any]): an optional type annotation representing the
  679. Python type the output of this node will have.
  680. Returns:
  681. The newly-created and inserted node.
  682. """
  683. assert op in ('call_function', 'call_method', 'get_attr', 'call_module', 'placeholder', 'output')
  684. args = () if args is None else args
  685. kwargs = {} if kwargs is None else kwargs
  686. assert isinstance(args, tuple), "args must be a tuple"
  687. assert isinstance(kwargs, dict), "kwargs must be a dict"
  688. candidate = name if name is not None else self._target_to_str(target)
  689. name = self._graph_namespace.create_name(candidate, None)
  690. n = Node(self, name, op, target, args, kwargs, type_expr)
  691. self._graph_namespace.associate_name_with_obj(name, n)
  692. self._insert(n)
  693. self._len += 1
  694. return n
  695. @compatibility(is_backward_compatible=False)
  696. def process_inputs(self, *args):
  697. """
  698. Processes args so that they can be passed to the FX graph.
  699. """
  700. return self._codegen.process_inputs(*args)
  701. @compatibility(is_backward_compatible=False)
  702. def process_outputs(self, out):
  703. return self._codegen.process_outputs(out)
  704. @compatibility(is_backward_compatible=True)
  705. def erase_node(self, to_erase : Node) -> None:
  706. """
  707. Erases a ``Node`` from the ``Graph``. Throws an exception if
  708. there are still users of that node in the ``Graph``.
  709. Args:
  710. to_erase (Node): The ``Node`` to erase from the ``Graph``.
  711. """
  712. if len(to_erase.users) > 0:
  713. raise RuntimeError(f'Tried to erase Node {to_erase} but it still had {len(to_erase.users)} '
  714. f'users in the graph: {to_erase.users}!')
  715. to_erase._remove_from_list()
  716. to_erase._erased = True # iterators may retain handles to erased nodes
  717. self._len -= 1
  718. # Null out this Node's argument nodes so that the Nodes referred to
  719. # can update their ``users`` accordingly
  720. new_args = map_arg(to_erase.args, lambda n: None)
  721. assert isinstance(new_args, tuple)
  722. to_erase.args = new_args
  723. new_kwargs = map_arg(to_erase.kwargs, lambda n: None)
  724. assert isinstance(new_kwargs, dict)
  725. to_erase.kwargs = new_kwargs
  726. @compatibility(is_backward_compatible=True)
  727. def inserting_before(self, n: Optional[Node] = None):
  728. """Set the point at which create_node and companion methods will insert into the graph.
  729. When used within a 'with' statement, this will temporary set the insert point and
  730. then restore it when the with statement exits::
  731. with g.inserting_before(n):
  732. ... # inserting before node n
  733. ... # insert point restored to what it was previously
  734. g.inserting_before(n) # set the insert point permanently
  735. Args:
  736. n (Optional[Node]): The node before which to insert. If None this will insert before
  737. the beginning of the entire graph.
  738. Returns:
  739. A resource manager that will restore the insert point on ``__exit__``.
  740. """
  741. if n is None:
  742. return self.inserting_after(self._root)
  743. assert n.graph == self, "Node to insert before is not in graph."
  744. return _InsertPoint(self, n.prepend)
  745. @compatibility(is_backward_compatible=True)
  746. def inserting_after(self, n: Optional[Node] = None):
  747. """Set the point at which create_node and companion methods will insert into the graph.
  748. When used within a 'with' statement, this will temporary set the insert point and
  749. then restore it when the with statement exits::
  750. with g.inserting_after(n):
  751. ... # inserting after node n
  752. ... # insert point restored to what it was previously
  753. g.inserting_after(n) # set the insert point permanently
  754. Args:
  755. n (Optional[Node]): The node before which to insert. If None this will insert after
  756. the beginning of the entire graph.
  757. Returns:
  758. A resource manager that will restore the insert point on ``__exit__``.
  759. """
  760. if n is None:
  761. return self.inserting_before(self._root)
  762. assert n.graph == self, "Node to insert after is not in graph."
  763. return _InsertPoint(self, n.append)
  764. @compatibility(is_backward_compatible=True)
  765. def placeholder(self, name: str, type_expr: Optional[Any] = None,
  766. default_value : Any = inspect.Signature.empty) -> Node:
  767. """
  768. Insert a ``placeholder`` node into the Graph. A ``placeholder`` represents
  769. a function input.
  770. Args:
  771. name (str): A name for the input value. This corresponds to the name
  772. of the positional argument to the function this ``Graph`` represents.
  773. type_expr (Optional[Any]): an optional type annotation representing the
  774. Python type the output of this node will have. This is needed in some
  775. cases for proper code generation (e.g. when the function is used
  776. subsequently in TorchScript compilation).
  777. default_value (Any): The default value this function argument should take
  778. on. NOTE: to allow for `None` as a default value, `inspect.Signature.empty`
  779. should be passed as this argument to specify that the parameter does _not_
  780. have a default value.
  781. .. note::
  782. The same insertion point and type expression rules apply for this method
  783. as ``Graph.create_node``.
  784. """
  785. args = () if default_value is inspect.Signature.empty else (default_value,)
  786. return self.create_node('placeholder', name, args=args, type_expr=type_expr)
  787. @compatibility(is_backward_compatible=True)
  788. def get_attr(self, qualified_name: str, type_expr: Optional[Any] = None) -> Node:
  789. """
  790. Insert a ``get_attr`` node into the Graph. A ``get_attr`` ``Node`` represents the
  791. fetch of an attribute from the ``Module`` hierarchy.
  792. Args:
  793. qualified_name (str): the fully-qualified name of the attribute to be retrieved.
  794. For example, if the traced Module has a submodule named ``foo``, which has a
  795. submodule named ``bar``, which has an attribute named ``baz``, the qualified
  796. name ``foo.bar.baz`` should be passed as ``qualified_name``.
  797. type_expr (Optional[Any]): an optional type annotation representing the
  798. Python type the output of this node will have.
  799. Returns:
  800. The newly-created and inserted ``get_attr`` node.
  801. .. note::
  802. The same insertion point and type expression rules apply for this method
  803. as ``Graph.create_node``.
  804. """
  805. def _get_attr_reference_exists(mod: torch.nn.Module, qualified_name: str) -> bool:
  806. module_path, _, name = qualified_name.rpartition(".")
  807. try:
  808. submod: torch.nn.Module = mod.get_submodule(module_path)
  809. except AttributeError:
  810. warnings.warn(f"Failed to fetch module {module_path}!")
  811. return False
  812. if not hasattr(submod, name):
  813. return False
  814. res = getattr(submod, name)
  815. if (not isinstance(res, torch.nn.Module)
  816. and not isinstance(res, torch.nn.Parameter)
  817. and name not in submod._buffers):
  818. return False
  819. return True
  820. if (self.owning_module and
  821. not _get_attr_reference_exists(self.owning_module, qualified_name)):
  822. warnings.warn("Attempted to insert a get_attr Node with no "
  823. "underlying reference in the owning "
  824. "GraphModule! Call "
  825. "GraphModule.add_submodule to add the "
  826. "necessary submodule, "
  827. "GraphModule.add_parameter to add the "
  828. "necessary Parameter, or "
  829. "nn.Module.register_buffer to add the "
  830. "necessary buffer", stacklevel=2)
  831. return self.create_node('get_attr', qualified_name, type_expr=type_expr)
  832. @compatibility(is_backward_compatible=True)
  833. def call_module(self,
  834. module_name: str,
  835. args: Optional[Tuple['Argument', ...]] = None,
  836. kwargs: Optional[Dict[str, 'Argument']] = None,
  837. type_expr: Optional[Any] = None) -> Node:
  838. """
  839. Insert a ``call_module`` ``Node`` into the ``Graph``. A ``call_module`` node
  840. represents a call to the forward() function of a ``Module`` in the ``Module``
  841. hierarchy.
  842. Args:
  843. module_name (str): The qualified name of the ``Module`` in the ``Module``
  844. hierarchy to be called. For example, if the traced ``Module`` has a
  845. submodule named ``foo``, which has a submodule named ``bar``, the
  846. qualified name ``foo.bar`` should be passed as ``module_name`` to
  847. call that module.
  848. args (Optional[Tuple[Argument, ...]]): The positional arguments to be passed
  849. to the called method. Note that this should *not* include a ``self`` argument.
  850. kwargs (Optional[Dict[str, Argument]]): The keyword arguments to be passed
  851. to the called method
  852. type_expr (Optional[Any]): an optional type annotation representing the
  853. Python type the output of this node will have.
  854. Returns:
  855. The newly-created and inserted ``call_module`` node.
  856. .. note::
  857. The same insertion point and type expression rules apply for this method
  858. as :meth:`Graph.create_node`.
  859. """
  860. if (self.owning_module and
  861. self.owning_module.get_submodule(module_name) is None):
  862. warnings.warn("Attempted to insert a call_module Node with "
  863. "no underlying reference in the owning "
  864. "GraphModule! Call "
  865. "GraphModule.add_submodule to add the "
  866. "necessary submodule")
  867. return self.create_node('call_module', module_name, args, kwargs, type_expr=type_expr)
  868. @compatibility(is_backward_compatible=True)
  869. def call_method(self,
  870. method_name: str,
  871. args: Optional[Tuple['Argument', ...]] = None,
  872. kwargs: Optional[Dict[str, 'Argument']] = None,
  873. type_expr: Optional[Any] = None) -> Node:
  874. """
  875. Insert a ``call_method`` ``Node`` into the ``Graph``. A ``call_method`` node
  876. represents a call to a given method on the 0th element of ``args``.
  877. Args:
  878. method_name (str): The name of the method to apply to the self argument.
  879. For example, if args[0] is a ``Node`` representing a ``Tensor``,
  880. then to call ``relu()`` on that ``Tensor``, pass ``relu`` to ``method_name``.
  881. args (Optional[Tuple[Argument, ...]]): The positional arguments to be passed
  882. to the called method. Note that this *should* include a ``self`` argument.
  883. kwargs (Optional[Dict[str, Argument]]): The keyword arguments to be passed
  884. to the called method
  885. type_expr (Optional[Any]): an optional type annotation representing the
  886. Python type the output of this node will have.
  887. Returns:
  888. The newly created and inserted ``call_method`` node.
  889. .. note::
  890. The same insertion point and type expression rules apply for this method
  891. as :meth:`Graph.create_node`.
  892. """
  893. return self.create_node('call_method', method_name, args, kwargs, type_expr=type_expr)
  894. @compatibility(is_backward_compatible=True)
  895. def call_function(self,
  896. the_function: Callable[..., Any],
  897. args: Optional[Tuple['Argument', ...]] = None,
  898. kwargs: Optional[Dict[str, 'Argument']] = None,
  899. type_expr: Optional[Any] = None) -> Node:
  900. """
  901. Insert a ``call_function`` ``Node`` into the ``Graph``. A ``call_function`` node
  902. represents a call to a Python callable, specified by ``the_function``.
  903. Args:
  904. the_function (Callable[..., Any]): The function to be called. Can be any PyTorch
  905. operator, Python function, or member of the ``builtins`` or ``operator``
  906. namespaces.
  907. args (Optional[Tuple[Argument, ...]]): The positional arguments to be passed
  908. to the called function.
  909. kwargs (Optional[Dict[str, Argument]]): The keyword arguments to be passed
  910. to the called function
  911. type_expr (Optional[Any]): an optional type annotation representing the
  912. Python type the output of this node will have.
  913. Returns:
  914. The newly created and inserted ``call_function`` node.
  915. .. note::
  916. The same insertion point and type expression rules apply for this method
  917. as :meth:`Graph.create_node`.
  918. """
  919. return self.create_node('call_function', the_function, args, kwargs, type_expr=type_expr)
  920. @compatibility(is_backward_compatible=True)
  921. def node_copy(self, node: Node, arg_transform: Callable[[Node], 'Argument'] = lambda x: x) -> Node:
  922. """
  923. Copy a node from one graph into another. ``arg_transform`` needs to transform arguments from
  924. the graph of node to the graph of self. Example::
  925. # Copying all the nodes in `g` into `new_graph`
  926. g : torch.fx.Graph = ...
  927. new_graph = torch.fx.graph()
  928. value_remap = {}
  929. for node in g.nodes:
  930. value_remap[node] = new_graph.node_copy(node, lambda n : value_remap[n])
  931. Args:
  932. node (Node): The node to copy into ``self``.
  933. arg_transform (Callable[[Node], Argument]): A function that transforms
  934. ``Node`` arguments in node's ``args`` and ``kwargs`` into the
  935. equivalent argument in ``self``. In the simplest case, this should
  936. retrieve a value out of a table mapping Nodes in the original
  937. graph to ``self``.
  938. """
  939. args = map_arg(node.args, arg_transform)
  940. kwargs = map_arg(node.kwargs, arg_transform)
  941. assert isinstance(args, tuple)
  942. assert isinstance(kwargs, dict)
  943. result_node = self.create_node(node.op, node.target, args, kwargs, node.name, node.type)
  944. result_node.meta = copy.copy(node.meta)
  945. return result_node
  946. @compatibility(is_backward_compatible=True)
  947. def output(self, result: 'Argument', type_expr: Optional[Any] = None):
  948. """
  949. Insert an ``output`` ``Node`` into the ``Graph``. An ``output`` node represents
  950. a ``return`` statement in Python code. ``result`` is the value that should
  951. be returned.
  952. Args:
  953. result (Argument): The value to be returned.
  954. type_expr (Optional[Any]): an optional type annotation representing the
  955. Python type the output of this node will have.
  956. .. note::
  957. The same insertion point and type expression rules apply for this method
  958. as ``Graph.create_node``.
  959. """
  960. return self.create_node(op='output', target='output', args=(result,), type_expr=type_expr)
  961. def _target_to_str(self, target : Target) -> str:
  962. if callable(target):
  963. op = target.__name__
  964. else:
  965. assert isinstance(target, str)
  966. op = target
  967. if _is_magic(op):
  968. op = op[2:-2]
  969. op = _snake_case(op)
  970. return op
  971. @compatibility(is_backward_compatible=True)
  972. def python_code(self, root_module: str, *, verbose: bool = False) -> PythonCode:
  973. """
  974. Turn this ``Graph`` into valid Python code.
  975. Args:
  976. root_module (str): The name of the root module on which to look-up
  977. qualified name targets. This is usually 'self'.
  978. Returns:
  979. A PythonCode object, consisting of two fields:
  980. src: the Python source code representing the object
  981. globals: a dictionary of global names in `src` -> the objects that they reference.
  982. """
  983. # NOTE: [Graph Namespaces]
  984. #
  985. # There are two types of symbols in generated Python source code:
  986. # locals and globals.
  987. # Locals are locally defined by the output of a node in the Graph.
  988. # Globals are references to external objects, like functions or types.
  989. #
  990. # When generating Python code, we need to make sure to name things
  991. # appropriately. In particular:
  992. # - All names should be unique, to avoid weird shadowing bugs.
  993. # - These names need to be consistent, e.g. a object should always be
  994. # referenced by the same name.
  995. #
  996. # To do this, we create a new namespace just for this source. All names
  997. # that get printed must come from this namespace.
  998. #
  999. # Why can't we re-use node.name? Because it was generated within the
  1000. # namespace `self._graph_namespace`. In order to provide uniqueness
  1001. # over both locals (node.name) *and* globals, we create a completely
  1002. # new namespace to put all identifiers in.
  1003. namespace = _Namespace()
  1004. # Override Node's repr to generate a valid name within our namespace.
  1005. # Since repr() is designed to produce a valid Python expression, it
  1006. # makes sense to re-use it. This way, it's easy to print something like
  1007. # Tuple[Node, Node] by simply calling repr() on it. Node's __repr__ is
  1008. # implemented cooperatively to allow this.
  1009. def node_repr(n: Node):
  1010. return namespace.create_name(n.name, n)
  1011. @contextmanager
  1012. def override_node_repr(graph: Graph):
  1013. orig_repr_fns = {}
  1014. for node in graph.nodes:
  1015. orig_repr_fns[node] = node._repr_fn
  1016. node._repr_fn = node_repr
  1017. try:
  1018. yield None
  1019. finally:
  1020. # restore the original repr functions
  1021. for node in graph.nodes:
  1022. node._repr_fn = orig_repr_fns[node]
  1023. with override_node_repr(self):
  1024. return self._python_code(root_module, namespace, verbose=verbose)
  1025. def _python_code(self, root_module: str, namespace: _Namespace, *, verbose: bool = False) -> PythonCode:
  1026. return self._codegen._gen_python_code(self.nodes, root_module, namespace, verbose=verbose)
  1027. def __str__(self) -> str:
  1028. """
  1029. Return a human-readable (not machine-readable) string representation
  1030. of this Graph
  1031. """
  1032. placeholder_names : List[str] = []
  1033. # This is a one-element array just so ``format_node`` can modify the closed
  1034. # over value
  1035. maybe_return_typename : List[str] = ['']
  1036. node_strs = [node.format_node(placeholder_names) for node in self.nodes]
  1037. param_str = ', '.join(placeholder_names)
  1038. s = f'graph({param_str}){maybe_return_typename[0]}:'
  1039. for node_str in node_strs:
  1040. if node_str:
  1041. s += '\n ' + node_str
  1042. return s
  1043. @compatibility(is_backward_compatible=True)
  1044. def print_tabular(self):
  1045. """
  1046. Prints the intermediate representation of the graph in tabular
  1047. format. Note that this API requires the ``tabulate`` module to be
  1048. installed.
  1049. """
  1050. try:
  1051. from tabulate import tabulate
  1052. except ImportError:
  1053. print("`print_tabular` relies on the library `tabulate`, "
  1054. "which could not be found on this machine. Run `pip "
  1055. "install tabulate` to install the library.")
  1056. node_specs = [[n.op, n.name, n.target, n.args, n.kwargs]
  1057. for n in self.nodes]
  1058. print(tabulate(node_specs,
  1059. headers=['opcode', 'name', 'target', 'args', 'kwargs']))
  1060. @compatibility(is_backward_compatible=True)
  1061. def lint(self):
  1062. """
  1063. Runs various checks on this Graph to make sure it is well-formed. In
  1064. particular:
  1065. - Checks Nodes have correct ownership (owned by this graph)
  1066. - Checks Nodes appear in topological order
  1067. - If this Graph has an owning GraphModule, checks that targets
  1068. exist in that GraphModule
  1069. """
  1070. # Check topo order
  1071. def check_arg(arg : Node, n : Optional[Node] = None) -> None:
  1072. context_str = f' of Node \'{n}\' ' if n else ' '
  1073. if arg.graph is not self:
  1074. raise RuntimeError(f'Argument \'{arg}\'{context_str}does not belong to this Graph, '
  1075. f'but was used as an argument! If you are copying nodes from another graph, make '
  1076. f'sure to use ``arg_transform`` on node_copy() to remap values\n{self}')
  1077. if arg not in seen_values:
  1078. raise RuntimeError(f'Argument \'{arg}\'{context_str}was used before it has been '
  1079. f'defined! Please check that Nodes in the graph are topologically ordered\n{self}')
  1080. seen_names : Set[str] = set()
  1081. seen_values : Set[Node] = set()
  1082. for node in self.nodes:
  1083. if node.op not in ['placeholder', 'call_method', 'call_module', 'call_function', 'get_attr', 'output']:
  1084. raise RuntimeError(f'Node {node} had unknown opcode {node.op}!')
  1085. if node.graph is not self:
  1086. raise RuntimeError(f'Node \'{node}\' does not belong to this Graph!')
  1087. map_arg(node.args, lambda arg: check_arg(arg, node))
  1088. map_arg(node.kwargs, lambda arg: check_arg(arg, node))
  1089. seen_values.add(node)
  1090. if node.name in seen_names:
  1091. raise RuntimeError(f'Node redefined name {node.name}!')
  1092. seen_names.add(node.name)
  1093. # Check targets are legit
  1094. if self.owning_module:
  1095. for node in self.nodes:
  1096. if node.op == 'call_function':
  1097. if not callable(node.target):
  1098. raise ValueError(f'Node {node} target {node.target} has type {torch.typename(node.target)} but '
  1099. 'a Callable is expected')
  1100. else:
  1101. if not isinstance(node.target, str):
  1102. raise ValueError(f'Node {node} target {node.target} has type {torch.typename(node.target)} but '
  1103. 'a str is expected')
  1104. if node.op in ['get_attr', 'call_module']:
  1105. target_atoms = node.target.split('.')
  1106. m_itr = self.owning_module
  1107. for i, atom in enumerate(target_atoms):
  1108. new_m_itr = getattr(m_itr, atom, None)
  1109. seen_qualname = '.'.join(target_atoms[:i])
  1110. if new_m_itr is None:
  1111. raise RuntimeError(f'Node {node} target {node.target} references nonexistent attribute '
  1112. f'{atom} of {seen_qualname}')
  1113. if (node.op == "call_module"
  1114. and not isinstance(new_m_itr, torch.nn.Module)):
  1115. raise RuntimeError(f'Node {node} target {node.target} {atom} of {seen_qualname} does '
  1116. 'not reference an nn.Module')
  1117. elif (node.op == "get_attr"
  1118. and not isinstance(new_m_itr, torch.nn.Module)
  1119. and not isinstance(new_m_itr, torch.nn.Parameter)
  1120. and atom not in m_itr._buffers):
  1121. warnings.warn(f'Node {node} target {node.target} {atom} of {seen_qualname} does '
  1122. 'not reference an nn.Module, nn.Parameter, or buffer, which is '
  1123. 'what \'get_attr\' Nodes typically target')
  1124. else:
  1125. m_itr = new_m_itr
  1126. @compatibility(is_backward_compatible=True)
  1127. def eliminate_dead_code(self):
  1128. """
  1129. Remove all dead code from the graph, based on each node's number of
  1130. users, and whether the nodes have any side effects. The graph must be
  1131. topologically sorted before calling.
  1132. Returns:
  1133. bool: Whether the graph was changed as a result of the pass.
  1134. Example:
  1135. Before dead code is eliminated, `a` from `a = x + 1` below has no users
  1136. and thus can be eliminated from the graph without having an effect.
  1137. .. code-block:: python
  1138. def forward(self, x):
  1139. a = x + 1
  1140. return x + self.attr_1
  1141. After dead code is eliminated, `a = x + 1` has been removed, and the rest
  1142. of `forward` remains.
  1143. .. code-block:: python
  1144. def forward(self, x):
  1145. return x + self.attr_1
  1146. .. warning::
  1147. Dead code elimination has some heuristics to avoid removing
  1148. side-effectful nodes (see Node.is_impure) but in general coverage
  1149. is very bad, so you should assume that this method is not sound
  1150. to call unless you know that your FX graph consists entirely
  1151. of functional operations.
  1152. """
  1153. # Lint the graph first to make sure its topologically sorted, otherwise
  1154. # DCE below will not behave as expected.
  1155. self.lint()
  1156. # Reverse iterate so that when we remove a node, any nodes used as an
  1157. # input to that node have an updated user count that no longer reflects
  1158. # the removed node.
  1159. changed = False
  1160. for node in reversed(self.nodes):
  1161. if not node.is_impure() and len(node.users) == 0:
  1162. self.erase_node(node)
  1163. changed = True
  1164. return changed
  1165. @compatibility(is_backward_compatible=False)
  1166. def set_codegen(self, codegen: CodeGen):
  1167. self._codegen = codegen
  1168. @compatibility(is_backward_compatible=False)
  1169. def on_generate_code(
  1170. self,
  1171. make_transformer: Callable[[Optional[TransformCodeFunc]], TransformCodeFunc]
  1172. ):
  1173. """Register a transformer function when python code is generated
  1174. Args:
  1175. make_transformer (Callable[[Optional[TransformCodeFunc]], TransformCodeFunc]):
  1176. a function that returns a code transformer to be registered.
  1177. This function is called by `on_generate_code` to obtain the
  1178. code transformer.
  1179. This function is also given as its input the currently
  1180. registered code transformer (or None if nothing is registered),
  1181. in case it is not desirable to overwrite it. This is useful to
  1182. chain code transformers together.
  1183. Returns:
  1184. a context manager that when used in a `with` statement, to automatically
  1185. restore the previously registered code transformer.
  1186. Example:
  1187. .. code-block:: python
  1188. gm: fx.GraphModule = ...
  1189. # This is a code transformer we want to register. This code
  1190. # transformer prepends a pdb import and trace statement at the very
  1191. # beginning of the generated torch.fx code to allow for manual
  1192. # debugging with the PDB library.
  1193. def insert_pdb(body):
  1194. return ["import pdb; pdb.set_trace()\\n", *body]
  1195. # Registers `insert_pdb`, and overwrites the current registered
  1196. # code transformer (given by `_` to the lambda):
  1197. gm.graph.on_generate_code(
  1198. lambda _: insert_pdb
  1199. )
  1200. # Or alternatively, registers a code transformer which first
  1201. # runs `body` through existing registered transformer, then
  1202. # through `insert_pdb`:
  1203. gm.graph.on_generate_code(
  1204. lambda current_trans: (
  1205. lambda body: insert_pdb(
  1206. current_trans(body) if current_trans
  1207. else body
  1208. )
  1209. )
  1210. )
  1211. gm.recompile()
  1212. gm(*inputs) # drops into pdb
  1213. This function can also be used as a context manager, with the benefit to
  1214. automatically restores the previously registered code transformer:
  1215. .. code-block:: python
  1216. # ... continue from previous example
  1217. with gm.graph.on_generate_code(lambda _: insert_pdb):
  1218. # do more stuff with `gm`...
  1219. gm.recompile()
  1220. gm(*inputs) # drops into pdb
  1221. # now previous code transformer is restored (but `gm`'s code with pdb
  1222. # remains - that means you can run `gm` with pdb here too, until you
  1223. # run next `recompile()`).
  1224. """
  1225. on_gen_code_old = self._codegen._body_transformer
  1226. self._codegen._body_transformer = make_transformer(on_gen_code_old)
  1227. @contextlib.contextmanager
  1228. def on_generate_code_context_manager():
  1229. try:
  1230. yield
  1231. finally:
  1232. self._codegen._body_transformer = on_gen_code_old
  1233. return on_generate_code_context_manager()
  1234. reflectable_magic_methods = {
  1235. 'add': '{} + {}',
  1236. 'sub': '{} - {}',
  1237. 'mul': '{} * {}',
  1238. 'floordiv': '{} // {}',
  1239. 'truediv': '{} / {}',
  1240. 'div': '{} / {}',
  1241. 'mod': '{} % {}',
  1242. 'pow': '{} ** {}',
  1243. 'lshift': '{} << {}',
  1244. 'rshift': '{} >> {}',
  1245. 'and_': '{} & {}',
  1246. 'or_': '{} | {}',
  1247. 'xor': '{} ^ {}',
  1248. 'getitem': '{}[{}]',
  1249. 'matmul': '{} @ {}',
  1250. }
  1251. magic_methods = dict({
  1252. 'eq': '{} == {}',
  1253. 'ne': '{} != {}',
  1254. 'lt': '{} < {}',
  1255. 'gt': '{} > {}',
  1256. 'le': '{} <= {}',
  1257. 'ge': '{} >= {}',
  1258. 'pos': '+{}',
  1259. 'neg': '-{}',
  1260. 'invert': '~{}'}, **reflectable_magic_methods)
  1261. inplace_methods = {
  1262. 'iadd': '{} += {}',
  1263. 'iand': '{} &= {}',
  1264. 'ifloordiv': '{} //= {}',
  1265. 'ilshift': '{} <<= {}',
  1266. 'imod': '{} %= {}',
  1267. 'imul': '{} *= {}',
  1268. 'imatmul': '{} @= {}',
  1269. 'ior': '{} |= {}',
  1270. 'ipow': '{} **= {}',
  1271. 'irshift': '{} >>= {}',
  1272. 'isub': '{} -= {}',
  1273. 'itruediv': '{} /= {}',
  1274. 'ixor': '{} ^= {}',
  1275. 'setitem': '{}[{}] = {}',
  1276. }