# n_shadows_utils.py
import torch
import torch.fx
from torch.fx import (
    Node,
    GraphModule,
    Graph,
)

from torch.ao.ns.fx.utils import (
    # TODO(future PR): make this work correctly for methods
    get_target_type_str,
    get_normalized_nth_input,
)
from torch.ao.ns.fx.ns_types import (
    NSSingleResultValuesType,
    NSResultsType,
)
from torch.ao.ns.fx.graph_passes import _maybe_get_fqn
from torch.ao.quantization import QConfigMapping
from torch.ao.quantization.qconfig import QConfigAny
from torch.ao.quantization.utils import getattr_from_fqn
from torch.ao.quantization.fx.match_utils import _MatchResult
from torch.utils._pytree import tree_map

import collections
import copy
import operator
from typing import List, Dict, Set, Tuple, Callable, Any, Optional

SHADOW_NODE_NAME_PREFIX = 'shadow'
SHADOW_WRAPPER_NODE_NAME_PREFIX = 'shadow_wrapper'

# TODO(future PR): reuse existing mapping instead of creating a new one
BINARY_FUNCTIONS = {
    torch.add,
    torch.Tensor.add,
    operator.add,
    torch.mul,
    torch.Tensor.mul,
    operator.mul,
}

def _get_attr_name(subgraph_idx, subgraph_candidate_idx):
    return f"{SHADOW_NODE_NAME_PREFIX}_{subgraph_idx}_{subgraph_candidate_idx}"

def _get_attr_wrapper_name(subgraph_idx, subgraph_candidate_idx):
    return f"{SHADOW_WRAPPER_NODE_NAME_PREFIX}_{subgraph_idx}_{subgraph_candidate_idx}"

class OutputProp:
    """
    Output propagation (modeled from shape propagation).

    Given a GraphModule and an example input, saves the output flowing
    through each node on `node.traced_result`.

    Code based on the example from
    https://pytorch.org/docs/stable/fx.html#the-interpreter-pattern
    """
    def __init__(self, mod):
        self.mod = mod
        self.graph = mod.graph
        self.modules = dict(self.mod.named_modules())

    def propagate(self, *args):
        args_iter = iter(args)
        env: Dict[str, Node] = {}

        def load_arg(a):
            return torch.fx.graph.map_arg(a, lambda n: env[n.name])

        def fetch_attr(target: str):
            target_atoms = target.split('.')
            attr_itr = self.mod
            for i, atom in enumerate(target_atoms):
                if not hasattr(attr_itr, atom):
                    raise RuntimeError(f"Node referenced nonexistent target {'.'.join(target_atoms[:i])}")
                attr_itr = getattr(attr_itr, atom)
            return attr_itr

        for node in self.graph.nodes:
            if node.op == 'placeholder':
                result = next(args_iter)
            elif node.op == 'get_attr':
                result = fetch_attr(node.target)
            elif node.op == 'call_function':
                result = node.target(*load_arg(node.args), **load_arg(node.kwargs))
            elif node.op == 'call_method':
                self_obj, *args = load_arg(node.args)
                kwargs = load_arg(node.kwargs)
                result = getattr(self_obj, node.target)(*args, **kwargs)
            elif node.op == 'call_module':
                result = self.modules[node.target](*load_arg(node.args), **load_arg(node.kwargs))

            if isinstance(result, torch.Tensor):
                node.traced_result = result

            env[node.name] = result

        return None
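
# A minimal usage sketch for OutputProp (`MyToyModel` is a hypothetical
# module, not part of this file):
#
#   mod = torch.fx.symbolic_trace(MyToyModel())
#   OutputProp(mod).propagate(torch.randn(1, 4))
#   for node in mod.graph.nodes:
#       if hasattr(node, 'traced_result'):
#           print(node.name, node.traced_result.shape)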

def _get_dedup_subgraphs(
    matches: Dict[str, _MatchResult]
) -> Dict[str, List[Node]]:
    # the original matches variable is unique by node, make it unique by subgraph
    # instead
    seen_nodes = set()
    subgraphs_dedup = {}

    # Dict items are not reversible until Python 3.8, so we hack it
    # to be compatible with previous Python versions
    # TODO(future PR): try reversed(list(matches.items()))
    matches_items_reversed: List[Tuple[str, _MatchResult]] = []
    for name, cur_match in matches.items():
        matches_items_reversed.insert(0, (name, cur_match))

    # Note: the order is important. `matches` currently provides the matches
    # in reverse order. We would like to process the matches in non-reverse
    # order, so that we can create an intuitive naming scheme, such as
    # naming the first op's submodules `shadow_0_0` through `shadow_0_(n-1)`
    for name, cur_match in matches_items_reversed:  # type: ignore[call-overload]
        was_seen = False
        for node_or_tuple in cur_match[1]:

            # Cur_match[1] has an unusual type. It says that it's a `List[Node]`,
            # but it is really not. Furthermore, the contents of this field
            # can change from match results of multiple nodes of the same pattern
            #
            # For example, for conv -> bn -> relu, we see
            # match_results = {
            #     'conv': (relu, [(bn, conv), relu], ...),
            #     'bn': (relu, [(bn, conv), relu], ...),
            #     'relu': (relu, [(bn, conv), relu], ...),
            # }
            #
            # Ideally we should clean up the `find_matches` function to make
            # this more intuitive. For the purposes of this prototype, we hack
            # around it.

            if isinstance(node_or_tuple, Node):
                if node_or_tuple in seen_nodes:
                    was_seen = True
                seen_nodes.add(node_or_tuple)
            else:
                assert isinstance(node_or_tuple, tuple)
                for node in node_or_tuple:
                    assert isinstance(node, Node)
                    if node in seen_nodes:
                        was_seen = True
                    seen_nodes.add(node)

        if was_seen:
            continue

        # Start with the unusual type, convert it to [op_0, ..., op_n]
        list_of_nodes = []

        if len(cur_match[1]) == 1:
            list_of_nodes = cur_match[1]
        else:
            assert len(cur_match[1]) == 2
            # either (a, b), or ((a, b), c) or (c, (a, b))
            # cannot make any assumptions on order, not clear what the
            # _find_matches function is doing to populate this
            # TODO(future PR): make this code less confusing, see discussion
            # in https://github.com/pytorch/pytorch/pull/80521/files#r975918836

            def _order_nodes(node_a, node_b, node_c) -> List[Node]:
                nodes = [node_a, node_b, node_c]
                first_node = None
                mid_node = None
                last_node = None
                for n in nodes:
                    prev_n = n.args[0]
                    next_n = list(n.users)[0]
                    if prev_n not in nodes:
                        first_node = n
                    elif next_n not in nodes:
                        last_node = n
                    else:
                        mid_node = n
                assert first_node is not None and mid_node is not None and \
                    last_node is not None
                assert mid_node.args[0] is first_node
                assert last_node.args[0] is mid_node
                return [last_node, mid_node, first_node]

            if isinstance(cur_match[1][0], Node) and isinstance(cur_match[1][1], Node):
                # (a, b)
                list_of_nodes = cur_match[1]
            elif isinstance(cur_match[1][0], tuple):
                # ((a, b), c)
                node_a, node_b = cur_match[1][0]
                node_c = cur_match[1][1]
                list_of_nodes = _order_nodes(node_a, node_b, node_c)
            elif isinstance(cur_match[1][1], tuple):
                # (a, (b, c))
                node_a, node_b = cur_match[1][1]
                node_c = cur_match[1][0]
                list_of_nodes = _order_nodes(node_a, node_b, node_c)

        # [node_n, ..., node_0], note that the order is reversed
        # to make it chronological for simple subgraphs
        list_of_nodes.reverse()
        subgraphs_dedup[name] = list_of_nodes

    return subgraphs_dedup
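
# For the conv -> bn -> relu example above, all three match entries describe
# the same subgraph, so only one of them survives deduplication, e.g.:
#
#   subgraphs_dedup = {
#       'conv': [conv, bn, relu],
#   }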

def _get_logger_for_subgraph(
    model: GraphModule,
    first_node: Node,
    last_node: Node,
    subgraph_idx: int,
    subgraph_candidate_idx: int,
    qconfig_str: str,
    logger_cls: Callable,
    fqn: Optional[str],
) -> torch.nn.Module:
    """
    Given a model and a linear subgraph starting from `first_node` and
    ending with `last_node`, creates a logger for the end of this
    subgraph.
    """
    if fqn is None:
        fqn = ''
    logger_mod_orig = logger_cls(
        first_node.name,  # ref_node_name
        last_node.name,  # prev_node_name
        f'subgraph_{subgraph_idx}_{subgraph_candidate_idx}',  # model_name
        'model',  # ref_name
        get_target_type_str(last_node, model),  # prev_node_target_type
        get_target_type_str(first_node, model),  # ref_node_target_type
        NSSingleResultValuesType.NODE_OUTPUT.value,  # results_type
        0,  # index_within_arg
        0,  # index_of_arg
        fqn,  # fqn
        qconfig_str,
    )
    # Usually we expect the user to add loggers, then calibrate, then convert,
    # and then populate loggers. This is why the loggers start disabled.
    # TODO(future PR): reconsider the design to make this more intuitive.
    logger_mod_orig.enabled = False
    return logger_mod_orig

def create_submodule_from_subgraph(
    model: torch.nn.Module,
    first_node: Node,
    last_node: Node,
) -> GraphModule:
    """
    Input: a model, and a linear subgraph within the model from first_node to
    last_node.

    Output: a new submodule containing a copy of the subgraph, with the inputs
    to the first node becoming the inputs to the submodule, and all other
    nodes in the subgraph being copied.

    Example inputs:

    `model`: a module with graph

      x0 -> op1 -> x1 -> op2 -> x2
             |
            arg1

    `first_node`: op1
    `last_node`: op2

    Example output: a new module with graph

      input1 -> op1_copy -> x1 -> op2_copy -> output1
                   |
                  arg1
    """
    #
    # create a blank GraphModule with an empty graph
    #

    class M(torch.nn.Module):
        def forward(self, x):
            pass

    m = M()
    gm = torch.fx.symbolic_trace(m)
    g = gm.graph
    for node in reversed(gm.graph.nodes):
        g.erase_node(node)

    #
    # modify the graph to have a copy of our subgraph
    #

    cur_node_orig = first_node
    cur_args_orig = cur_node_orig.args
    cur_kwargs_orig = cur_node_orig.kwargs

    cur_name_idx = 0

    iteration_limit = 100
    cur_iteration = 0

    while True:
        if cur_node_orig is first_node:
            # we are at the first node, we need to set up graph inputs
            # TODO(future): some graphs could have placeholders which are unrelated
            # to the first node, need to handle this
            cur_args_copy = []
            cur_kwargs_copy = {}
            seen_names: Set[str] = set()
            old_name_to_new_node: Dict[str, Node] = {}

            def _add_placeholder(
                g: Graph, node: Node, seen_names, old_name_to_new_node
            ):
                # note: for graphs starting with patterns such as `y = x + x`, we
                # need to ensure we do not add multiple placeholders with the
                # same name
                counter = 0
                while node.name + '_' + str(counter) in seen_names:
                    counter += 1
                cur_name = node.name + '_' + str(counter)
                seen_names.add(cur_name)
                placeholder = g.placeholder(cur_name)
                old_name_to_new_node[node.name] = placeholder
                return placeholder

            for arg in cur_node_orig.args:
                if isinstance(arg, Node):
                    p = _add_placeholder(
                        g, arg, seen_names, old_name_to_new_node)
                    cur_args_copy.append(p)
                elif isinstance(arg, (list, tuple)):
                    new_arg = []
                    for inner_arg in arg:
                        if isinstance(inner_arg, Node):
                            new_arg.append(_add_placeholder(
                                g, inner_arg, seen_names, old_name_to_new_node))
                        else:
                            new_arg.append(inner_arg)
                    cur_args_copy.append(new_arg)
                else:
                    cur_args_copy.append(arg)

            # TODO(future PR): handle non-normalized kwargs
            for kwarg_name, kwarg in cur_node_orig.kwargs.items():
                if isinstance(kwarg, Node):
                    cur_kwargs_copy[kwarg_name] = _add_placeholder(
                        g, kwarg, seen_names, old_name_to_new_node)
                elif isinstance(kwarg, (list, tuple)):
                    new_kwarg = []
                    for inner_kwarg in kwarg:
                        p = _add_placeholder(
                            g, inner_kwarg, seen_names, old_name_to_new_node)
                        new_kwarg.append(p)
                    cur_kwargs_copy[kwarg_name] = new_kwarg
                else:
                    cur_kwargs_copy[kwarg_name] = kwarg

            cur_args_copy = tuple(cur_args_copy)  # type: ignore[assignment]
        else:
            # we are not at first node, first arg is from the previous node,
            # and all other args are copied

            # the current implementation is simplistic and cannot handle
            # ops with two or more arguments which need to be passed from
            # the previous op, so we assert them out
            assert cur_node_orig.target not in BINARY_FUNCTIONS

            # at this point in the code, cur_node_copy is pointing to the copy
            # of the previous node
            # TODO(future PR): this is not handling complicated graphs correctly, need to
            # look at actual relationships instead of assuming sequential graph
            # TODO(future PR): this is ignoring kwargs, will need to support kwargs
            # for any fusion pattern which has them for a node that is not the
            # first node.
            cur_args_copy = [cur_node_copy]  # type: ignore[has-type]

            if len(cur_node_orig.args) > 1:
                for arg in cur_node_orig.args[1:]:
                    if isinstance(arg, torch.nn.Parameter):
                        new_arg = arg.clone().detach()  # type: ignore[assignment]
                        mod_name = f"mod_{cur_name_idx}"
                        cur_name_idx += 1
                        setattr(gm, mod_name, new_arg)
                        # note: placeholders must be created on the graph `g`;
                        # `GraphModule` has no `placeholder` method
                        new_arg_placeholder = g.placeholder(mod_name)
                        cur_args_copy.append(new_arg_placeholder)
                    elif isinstance(arg, (float, int, torch.dtype)):
                        cur_args_copy.append(arg)
                    else:
                        raise AssertionError(f'arg of type {type(arg)} not handled yet')
            cur_args_copy = tuple(cur_args_copy)  # type: ignore[assignment]

        # copy the node
        if cur_node_orig.op == 'call_module':
            orig_mod = getattr_from_fqn(model, cur_node_orig.target)  # type: ignore[arg-type]
            orig_mod_copy = copy.deepcopy(orig_mod)
            mod_name = f"mod_{cur_name_idx}"
            setattr(gm, mod_name, orig_mod_copy)
            cur_name_idx += 1
            cur_node_copy = g.call_module(mod_name, cur_args_copy, cur_kwargs_copy)
        elif cur_node_orig.op == 'call_function':
            cur_node_copy = g.call_function(
                cur_node_orig.target, cur_args_copy, cur_kwargs_copy)
        elif cur_node_orig.op == 'call_method':
            cur_node_copy = g.call_method(
                cur_node_orig.target, cur_args_copy, cur_kwargs_copy)
        else:
            raise AssertionError(f'{cur_node_orig.op} not supported yet')

        if cur_node_orig is last_node:
            break

        # go to next node
        assert len(cur_node_orig.users.keys()) == 1, \
            f'{cur_node_orig} has more than 1 user, not supported yet'
        cur_node_orig = list(cur_node_orig.users.keys())[0]
        cur_args_orig = cur_node_orig.args
        cur_kwargs_orig = cur_node_orig.kwargs

        cur_iteration += 1
        if cur_iteration > iteration_limit:
            raise AssertionError('iteration limit exceeded')

    # set up outputs
    g.output(cur_node_copy)

    gm.recompile()
    return gm
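
# A minimal usage sketch (`mod`, `n1`, and `n2` are hypothetical: `mod` is a
# traced GraphModule and n1 -> n2 is a linear subgraph within it):
#
#   submodule = create_submodule_from_subgraph(mod, n1, n2)
#   print(submodule.graph)  # placeholders -> n1_copy -> n2_copy -> output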

def create_one_transformed_and_logged_copy_of_subgraph(
    mt: GraphModule,
    subgraph_idx: int,
    subgraph_candidate_idx: int,
    first_node: Node,
    last_node: Node,
    fqn: Optional[str],
    list_of_node_name_to_qconfig: List[Dict[str, QConfigAny]],
    example_inputs: Any,
    last_added_shadow_node_list: List[Optional[Node]],
    custom_prepare_fn: Optional[Callable] = None,
    custom_prepare_kwargs: Optional[Dict[str, Any]] = None,
) -> None:
  384. """
  385. Given a subgraph in `mt` and a subgraph candidate idx, inserts the
  386. subgraph candidate copy and instruments it with loggers.
  387. If subgraph_candidate_idx is 0, this is the baseline fp32 subgraph and we just
  388. add a logger to the end.
  389. If subgraph_candidate_idx is not 0, we create a copy of the subgraph and
  390. prepare it with `prepare_fx`.
  391. """
  392. # TODO(future PR): move logger classes to utils to remove circular dependency
  393. from torch.ao.ns._numeric_suite_fx import OutputLogger, OutputComparisonLogger
  394. if subgraph_candidate_idx == 0:
  395. # idx = 0 is the floating point (original) version of the subgraph
  396. # We keep the subgraph as is, and add a logger at the end
  397. qconfig_str = ''
  398. logger_mod_orig = _get_logger_for_subgraph(
  399. mt, first_node, last_node, subgraph_idx, subgraph_candidate_idx,
  400. qconfig_str, OutputLogger, fqn)
  401. attr_name = _get_attr_name(subgraph_idx, subgraph_candidate_idx)
  402. assert not hasattr(mt, attr_name)
  403. setattr(mt, attr_name, logger_mod_orig)
  404. with mt.graph.inserting_after(last_node):
  405. new_node = mt.graph.call_module(attr_name, args=(last_node,), kwargs={})
  406. last_added_shadow_node_list[0] = new_node
  407. else:
  408. # idx > 0 means we have a candidate qconfig to try, so we need
  409. # to make a copy of the subgraph, feed it with the right inputs,
  410. # and add a logger at the end
  411. # get the qconfig
  412. # subtract one because the first candidate is the floating point
  413. # version of the subgraph
  414. node_name_to_qconfig = \
  415. list_of_node_name_to_qconfig[subgraph_candidate_idx - 1]
  416. qconfig = node_name_to_qconfig[first_node.name]
  417. # if no quantization is requested, skip
  418. # TODO(future PR): deduplicate equivalent qconfigs that come from
  419. # different qconfig mapping objects
  420. if qconfig is None:
  421. return
  422. qconfig_mapping = QConfigMapping().set_global(qconfig)
  423. # create a copy of the submodule, wrapped in a separate module
  424. orig_mod_copy_wrapped = create_submodule_from_subgraph(
  425. mt, first_node, last_node)
  426. # add a call to prepare_fx on the wrapper module
  427. if custom_prepare_fn is None:
  428. orig_mod_copy_wrapped = torch.ao.quantization.quantize_fx.prepare_fx(
  429. orig_mod_copy_wrapped, qconfig_mapping, example_inputs=example_inputs)
  430. else:
  431. if custom_prepare_kwargs is None:
  432. custom_prepare_kwargs = {}
  433. for kwarg_name in ["example_inputs", "prepare_custom_config", "qconfig_mapping"]:
  434. assert kwarg_name not in custom_prepare_kwargs, f"cannot specify {kwarg_name} in custom_prepare_kwargs"
  435. prepare_kwargs: Dict[str, Any] = {
  436. "example_inputs": example_inputs,
  437. "qconfig_mapping": qconfig_mapping
  438. }
  439. prepare_kwargs.update(custom_prepare_kwargs)
  440. orig_mod_copy_wrapped = custom_prepare_fn(
  441. orig_mod_copy_wrapped,
  442. **prepare_kwargs)
  443. # attach the wrapper to the model
  444. attr_name = _get_attr_wrapper_name(subgraph_idx, subgraph_candidate_idx)
  445. assert not hasattr(mt, attr_name)
  446. setattr(mt, attr_name, orig_mod_copy_wrapped)
  447. # add a call to the wrapper module from the parent graph
  448. insert_after_node = last_added_shadow_node_list[0]
  449. with mt.graph.inserting_after(insert_after_node):
  450. # TODO(future PR): handle fusion patterns where non-first nodes
  451. # need inputs
  452. # pass in all node args and kwargs
  453. new_args = []
  454. for arg in first_node.args:
  455. if isinstance(arg, Node):
  456. new_args.append(arg)
  457. elif isinstance(arg, (list, tuple)) and len(arg) and isinstance(arg[0], Node):
  458. for inner_arg in arg:
  459. if isinstance(inner_arg, Node):
  460. new_args.append(inner_arg)
  461. new_kwargs = {}
  462. for name, old_kwarg in first_node.kwargs.items():
  463. if isinstance(old_kwarg, Node):
  464. new_kwargs[name] = old_kwarg
  465. elif isinstance(old_kwarg, (list, tuple)) and len(old_kwarg):
  466. for inner_old_kwarg in old_kwarg:
  467. # TODO(future PR): clarify why we are adding kwargs to args
  468. new_args.append(inner_old_kwarg)
  469. new_args = tuple(new_args) # type: ignore[assignment]
  470. new_node = mt.graph.call_module(
  471. attr_name, args=new_args, kwargs=new_kwargs)
  472. # add a logger to parent graph to observe the shadow wrapper
  473. logger_mod_orig = _get_logger_for_subgraph(
  474. mt, first_node, last_node, subgraph_idx, subgraph_candidate_idx,
  475. str(qconfig), OutputComparisonLogger, fqn)
  476. attr_name = _get_attr_name(subgraph_idx, subgraph_candidate_idx)
  477. assert not hasattr(mt, attr_name)
  478. setattr(mt, attr_name, logger_mod_orig)
  479. with mt.graph.inserting_after(new_node):
  480. logger = mt.graph.call_module(attr_name, args=(new_node, last_node), kwargs={})
  481. last_added_shadow_node_list[0] = logger
  482. mt.recompile()
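
# After this function runs for a subgraph with one qconfig, the parent graph
# contains a pattern like the following (names are illustrative, see also the
# example graph in `extract_weight_comparison` below):
#
#   linear = torch._C._nn.linear(x, w1, b1)
#   shadow_0_0 = self.shadow_0_0(linear)
#   shadow_wrapper_0_1 = self.shadow_wrapper_0_1(x, w1, b1)
#   shadow_0_1 = self.shadow_0_1(shadow_wrapper_0_1, linear)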

def create_n_transformed_and_logged_copies_of_subgraph(
    mt: GraphModule,
    subgraph_idx: int,
    match_name: str,
    nodes_in_this_subgraph: List[Any],
    qconfig_mappings: List[QConfigMapping],
    list_of_node_name_to_qconfig: List[Dict[str, QConfigAny]],
    custom_prepare_fn: Optional[Callable] = None,
    custom_prepare_kwargs: Optional[Dict[str, Any]] = None,
) -> None:
  493. """
  494. Given a model `mt` and a subgraph_idx, creates the needed copies
  495. of the subgraph for all qconfigs, and instruments them with loggers.
  496. """
  497. # for now, assume that
  498. # 1. the first node has one input
  499. # 2. the last node has one output
  500. # for now, ignore all subgraphs that contain non-nodes (tuples, etc)
  501. # TODO(future PR): implement this
  502. if any(
  503. not isinstance(node, Node)
  504. for node in nodes_in_this_subgraph
  505. ):
  506. return
  507. first_node = nodes_in_this_subgraph[0]
  508. last_node = nodes_in_this_subgraph[-1]
  509. # We used output propagation to populate example values on each
  510. # node. Use the example values from the previous node as the input
  511. # to the current node.
  512. prev_node = get_normalized_nth_input(first_node, mt, 0)
  513. if isinstance(prev_node, list):
  514. example_inputs = [x.traced_result for x in prev_node]
  515. elif isinstance(prev_node, tuple):
        # materialize the example inputs as a tuple; a bare generator
        # expression could only be consumed once
        example_inputs = tuple(x.traced_result for x in prev_node)  # type: ignore[assignment]
    else:
        # currently some customer models do not have a traced_result in
        # every node, so we have to guard for this case since we cannot
        # quantize without an example input
        # TODO(future PR): add a test case for this once we have an easy
        # repro, see https://github.com/pytorch/pytorch/pull/80521/files#r975940489
        # for additional context
        if hasattr(prev_node, 'traced_result'):
            example_inputs = (prev_node.traced_result,)  # type: ignore[attr-defined, assignment]
        else:
            print(
                'unable to get example input for node ' +
                f'{first_node.format_node()}, skipping')
            return

    # If there are no quantization configs for this subgraph, skip adding
    # loggers. This reduces memory usage for models where not all layers are
    # quantized.
    # TODO(future): consider making this configurable
    found_at_least_one_qconfig = False
    for subgraph_candidate_idx in range(len(qconfig_mappings) + 1):
        if subgraph_candidate_idx == 0:
            # fp32 baseline does not need a qconfig
            continue

        # a. we have N shadows, so len(qconfig_mappings) is N
        # b. we will have the fp32 layer + N shadows, so overall number of
        #    (original_op) + (*shadows) will be N+1
        # c. since `subgraph_candidate_idx` represents (b), we need
        #    to subtract 1 to query from (a)
        node_name_to_qconfig = \
            list_of_node_name_to_qconfig[subgraph_candidate_idx - 1]
        qconfig = node_name_to_qconfig[first_node.name]
        if qconfig is not None:
            found_at_least_one_qconfig = True
            break
    if not found_at_least_one_qconfig:
        print('unable to find at least one qconfig for node ' +
              f'{first_node.format_node()}, skipping')
        return

    fqn = _maybe_get_fqn(first_node, mt)

    # We want the results to contain the subgraphs in natural order,
    # and the graph to also contain shadow wrappers and shadow loggers
    # in natural order.
    # If we just iterate in reverse, the graph will be in natural
    # order but the eventual results will be in reverse order.
    # So, we keep track of the last shadow logger we added and
    # always insert after it.
    last_added_shadow_node_list: List[Optional[Node]] = [None]
    for subgraph_candidate_idx in range(len(qconfig_mappings) + 1):
        create_one_transformed_and_logged_copy_of_subgraph(
            mt, subgraph_idx, subgraph_candidate_idx, first_node,
            last_node, fqn, list_of_node_name_to_qconfig,
            example_inputs, last_added_shadow_node_list, custom_prepare_fn,
            custom_prepare_kwargs)
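
# The helpers above are typically driven by the n-shadows entry points in
# torch.ao.ns._numeric_suite_fx. A rough sketch of the intended flow, with a
# hypothetical model `m`, calibration data, and qconfig multi mapping:
#
#   msp = prepare_n_shadows_model(m, example_inputs, qconfig_multi_mapping, backend_config)
#   for data in calibration_data:
#       msp(data)
#   msq = convert_n_shadows_model(msp)
#   loggers_set_enabled(msq, True)
#   msq(*example_inputs)
#   results = extract_results_n_shadows_model(msq)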

def create_add_loggers_graph(
    model: GraphModule,
    subgraphs_dedup: Dict[str, List[Node]],
    qconfig_mapping: QConfigMapping,
    node_name_to_qconfig: Dict[str, QConfigAny],
) -> None:
    """
    Given a model, a model graph partition (currently a set of matched
    subgraphs) and instructions how to transform each subgraph
    (currently quantizing it according to qconfig_mapping), modifies
    the model graph to create an alternate path through the original graph,
    with each of the subgraphs quantized. This is useful to compare
    propagation error of a transformation such as quantization.

    For example, given layer op0 and op1, there are four cases when handling op1:
    1. op0 and op1 quantized
    2. op0 and op1 unquantized
    3. op0 quantized, op1 unquantized
    4. op0 unquantized, op1 quantized

    Example input, case 1:

    .. code::

      x0_0 -> op0_0 -> x1_0 -> log -----> op1_0 -> x2_0 -> log
       \                        \          \                 \      # noqa: W605
         -> op0_1 -> x1_1 ----> clog       op1_1 -> x2_1 ----> clog

    Example output, case 1:

    .. code::

      x0_0 -> op0_0 -> x1_0 -> log -----> op1_0 -> x2_0 -> log
       \                        \                            \      # noqa: W605
         -> op0_1 -> x1_1 ----> clog -> op1_1 -> x2_1 ----> clog
    """
    # TODO(future PR): move logger classes to utils to remove circular dependency
    from torch.ao.ns._numeric_suite_fx import OutputLogger, OutputComparisonLogger

    def _get_subgraph_containing_node(node, subgraphs_dedup):
        for name, subgraph in subgraphs_dedup.items():
            if node in subgraph:
                return subgraph
        return None
    # First, we need to create shadow branches, going from
    #
    #   x0 -> op0 -> x1 -> ...
    #
    # to
    #
    #   x0 -> op0_0 -> x1_0 -> log -> ...
    #    \                      \
    #      -> op0_1 -> x1_1 -> clog
    #
    # Later, the outputs of each shadow will be rerouted to calculate
    # propagation error.

    # Note: we cannot iterate over matched subgraphs because some nodes
    # may not be matched. So, we iterate over nodes in the graph, and
    # associate them to matched subgraphs if possible.
    nodes_to_skip = set()
    # for each subgraph, save a mapping from first node of subgraph
    # to first and last node of the shadow of this subgraph
    orig_first_node_to_shadow_in_node = {}
    orig_first_node_to_shadow_out_node = {}
    # need to record original list because we will mutate the graph as we go
    orig_nodes = list(model.graph.nodes)  # type: ignore[union-attr, arg-type]
    cur_subgraph_idx = 0
    for n in orig_nodes:
        if n.op in ('placeholder', 'get_attr', 'output') or n in nodes_to_skip:
            continue

        maybe_subgraph = _get_subgraph_containing_node(n, subgraphs_dedup)
        insert_submodule_copy = False
        if maybe_subgraph is not None:
            first_node, last_node = maybe_subgraph[0], maybe_subgraph[-1]
            for node_to_skip in maybe_subgraph:
                nodes_to_skip.add(node_to_skip)
            qconfig = node_name_to_qconfig[first_node.name]
            if qconfig is not None:
                insert_submodule_copy = True
        else:
            first_node, last_node = n, n

        if insert_submodule_copy:
            match_name = first_node.name
            create_n_transformed_and_logged_copies_of_subgraph(
                model, cur_subgraph_idx, match_name, maybe_subgraph,
                [qconfig_mapping], [node_name_to_qconfig],
                None, None
            )
            # find the created shadow module and record it so we
            # can find it easily in step 2
            expected_shadow_target = f"shadow_wrapper_{cur_subgraph_idx}_1"
            new_shadow_mod = None
            for maybe_shadow_mod in model.graph.nodes:
                if maybe_shadow_mod.op == 'call_module' and \
                        maybe_shadow_mod.target == expected_shadow_target:
                    new_shadow_mod = maybe_shadow_mod
                    break
            assert new_shadow_mod is not None
            orig_first_node_to_shadow_in_node[first_node] = new_shadow_mod
            orig_first_node_to_shadow_out_node[first_node] = new_shadow_mod

        else:
            # create a copy of the subgraph by only copying FX nodes
            # but not copying any parameters, to minimize memory usage
            subgraph_to_use = maybe_subgraph if maybe_subgraph is not None \
                else [first_node]

            # add a regular logger after last_node
            qconfig_str = ''
            subgraph_candidate_idx = 0
            fqn = _maybe_get_fqn(first_node, model)
            logger_mod_orig = _get_logger_for_subgraph(
                model, first_node, last_node, cur_subgraph_idx,
                subgraph_candidate_idx, qconfig_str, OutputLogger, fqn)
            attr_name = _get_attr_name(cur_subgraph_idx, subgraph_candidate_idx)
            assert not hasattr(model, attr_name)
            setattr(model, attr_name, logger_mod_orig)
            insertion_point = last_node
            with model.graph.inserting_after(insertion_point):
                logger = model.graph.call_module(
                    attr_name, args=(last_node,), kwargs={})
                insertion_point = logger

            # create a copy of the subgraph
            cur_node_orig = first_node
            cur_node_copy = None
            first_node_copy = None
            while cur_node_orig in subgraph_to_use:
                # TODO(future PR): make this support all possible args/kwargs
                if cur_node_orig is first_node:
                    new_args = cur_node_orig.args
                    new_kwargs = cur_node_orig.kwargs
                else:
                    first_arg_for_copy = cur_node_copy
                    new_args = tuple([first_arg_for_copy, *cur_node_orig.args[1:]])  # noqa: C409
                    new_kwargs = cur_node_orig.kwargs

                # make a copy of cur_node_orig
                with model.graph.inserting_after(insertion_point):
                    cur_node_copy = model.graph.create_node(
                        cur_node_orig.op,
                        cur_node_orig.target,
                        new_args,
                        new_kwargs,
                        # cur_node_orig.name,  # TODO(future PR): set name explicitly
                    )
                    if first_node_copy is None:
                        first_node_copy = cur_node_copy

                # since now only linear subgraphs are supported, all nodes
                # except the last one must have only one user
                if cur_node_orig != last_node:
                    assert len(cur_node_orig.users.keys()) == 1
                    cur_node_orig = list(cur_node_orig.users.keys())[0]
                    assert not cur_node_orig.name.startswith(SHADOW_NODE_NAME_PREFIX)
                    insertion_point = cur_node_copy
                else:
                    # we just copied `last_node`; stop here so the loop terminates
                    break

            # add a comparison logger after last_node's copy
            subgraph_candidate_idx = 1
            logger_mod_orig = _get_logger_for_subgraph(
                model, first_node, last_node, cur_subgraph_idx,
                subgraph_candidate_idx, qconfig_str, OutputComparisonLogger, fqn)
            attr_name = _get_attr_name(cur_subgraph_idx, subgraph_candidate_idx)
            assert not hasattr(model, attr_name)
            setattr(model, attr_name, logger_mod_orig)
            with model.graph.inserting_after(insertion_point):
                logger = model.graph.call_module(
                    attr_name, args=(cur_node_copy, last_node), kwargs={})

            # save the final node so we can use it in step 2
            orig_first_node_to_shadow_in_node[first_node] = first_node_copy
            orig_first_node_to_shadow_out_node[first_node] = cur_node_copy

        cur_subgraph_idx += 1

    model.recompile()

    # Now, we go from
    #
    #   x0 -> op0_0 -> x1_0 -> log -> x1 -> op1_0 -> ...
    #    \                      \       \
    #      -> op0_1 -> x1_1 -> clog      -> op1_1 -> ...
    #
    # to
    #
    #   x0 -> op0_0 -> x1_0 -> log --> x1_0 -> op1_0 -> ...
    #    \                      \
    #      -> op0_1 -> x1_1 -> clog -> x1_1 -> op1_1 -> ...
    #
    # sample values of key internal variables for the example above:
    #
    #   orig_first_node_to_shadow_in_node = {op0_0: op0_1, op1_0: op1_1}
    #   orig_first_node_to_shadow_out_node = {op0_0: op0_1, op1_0: op1_1}
    #
    # note: for subgraphs with more than one node, in_node will be different
    # compared to out_node

    nodes_to_skip = set()
    for n in orig_nodes:
        if n.op in ('placeholder', 'get_attr', 'output') or n in nodes_to_skip:
            continue

        maybe_subgraph = _get_subgraph_containing_node(n, subgraphs_dedup)
        if maybe_subgraph is not None:
            first_node, last_node = maybe_subgraph[0], maybe_subgraph[-1]
            for node_to_skip in maybe_subgraph:
                nodes_to_skip.add(node_to_skip)
        else:
            first_node, last_node = n, n

        def maybe_remap_node_to_shadow(node):
            """
            If unshadowed `node` has a shadow version, return that. If not,
            return `node`.
            """
            if not isinstance(node, Node):
                # handle scalars
                return node

            if node.op in ('placeholder', 'get_attr'):
                return node

            # Find the shadowed version of this arg from the previous
            # subgraph. For this, we need to:
            # 1. navigate to the first node of the previous subgraph
            # 2. get the output of the shadow wrapper which has (1) as an input

            # For now, assume the arg is in matched subgraphs. In the
            # future we may have to handle the case where this is not true.
            prev_subgraph = _get_subgraph_containing_node(
                node, subgraphs_dedup)
            if prev_subgraph is None:
                prev_subgraph = [node]
            prev_first_node = prev_subgraph[0]
            prev_shadow_output = \
                orig_first_node_to_shadow_out_node[prev_first_node]
            return prev_shadow_output

        cur_shadow_input = \
            orig_first_node_to_shadow_in_node[first_node]
        assert cur_shadow_input is not None
        cur_shadow_input.args = tree_map(
            maybe_remap_node_to_shadow, cur_shadow_input.args)
        cur_shadow_input.kwargs = tree_map(
            maybe_remap_node_to_shadow, cur_shadow_input.kwargs)

    model.recompile()
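
# A rough usage sketch (`m` is a hypothetical traced GraphModule; `matches`,
# `qconfig_mapping`, and `node_name_to_qconfig` would normally come from the
# FX quantization matching and prepare passes):
#
#   subgraphs_dedup = _get_dedup_subgraphs(matches)
#   create_add_loggers_graph(m, subgraphs_dedup, qconfig_mapping, node_name_to_qconfig)
#   # `m` now computes both the original and the quantized path, with
#   # OutputLogger / OutputComparisonLogger instances recording both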

def _get_weight_info_from_shadow_wrapper(shadow_wrapper: torch.nn.Module):
    # input: shadow wrapper module
    # output if shadow wrapper module has a weighted op:
    #   (quantize_fn, (quantize_fn_args))
    # output if shadow wrapper module doesn't have a weighted op:
    #   None

    # For now, assume that the weight is the second input
    # to the shadow module. If that changes, we can fix it later.
    placeholders_seen = 0
    for shadow_n in shadow_wrapper.graph.nodes:  # type: ignore[union-attr]
        if shadow_n.op != 'placeholder':
            continue

        placeholders_seen += 1
        if placeholders_seen != 2:
            continue

        # the subgraph looks like
        #
        #   _input_scale_1 = self._input_scale_1
        #   _input_zero_point_1 = self._input_zero_point_1
        #   quantize_per_channel = torch.quantize_per_channel(
        #       w2_0, _input_scale_1, _input_zero_point_1,
        #       0, torch.qint8)
        #
        # we have `w2_0`, and are navigating this subgraph
        # to get `_input_scale_1` and `_input_zero_point_1`

        assert len(shadow_n.users) == 1
        quant_node = list(shadow_n.users.keys())[0]
        new_args: Any = None
        if quant_node.target == torch.quantize_per_channel:
            _weight, scale_node, zp_node, axis, dtype = quant_node.args
            scale_val = getattr_from_fqn(
                shadow_wrapper, scale_node.target)
            zp_val = getattr_from_fqn(
                shadow_wrapper, zp_node.target)
            new_args = (scale_val, zp_val, axis, dtype)
        else:
            assert quant_node.target == torch.quantize_per_tensor
            _weight, scale_node, zp_node, dtype = quant_node.args
            scale_val = getattr_from_fqn(
                shadow_wrapper, scale_node.target)
            zp_val = getattr_from_fqn(
                shadow_wrapper, zp_node.target)
            new_args = (scale_val, zp_val, dtype)
        return (quant_node.target, new_args)

    return None
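
# The returned pair can be applied to the fp32 weight to reproduce the shadow
# wrapper's quantization of it, e.g. (with a hypothetical weight tensor `w`):
#
#   quant_fn, quant_fn_args = _get_weight_info_from_shadow_wrapper(shadow_wrapper)
#   w_q = quant_fn(w, *quant_fn_args)
#
# This is exactly how `extract_weight_comparison` below uses it.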

def extract_weight_comparison(m: GraphModule) -> NSResultsType:

    # example graph:
    #
    #   w1 = self.w1
    #   b1 = self.b1
    #   linear = torch._C._nn.linear(x, w1, b1)
    #   shadow_0_0 = self.shadow_0_0(linear)
    #   shadow_wrapper_0_1 = self.shadow_wrapper_0_1(x, w1, b1)
    #   shadow_0_1 = self.shadow_0_1(shadow_wrapper_0_1, linear)
    #
    # algorithm:
    # 1. for each call_function node matching our allowlist:
    # 2.   if corresponding shadow wrapper exists, extract the weight pair
    #
    # Note: this is not super robust, but that's ok because this is
    # just for legacy customers who depend on the previous two-model version
    # of this API. TBD if we need to make this robust.
    # Note: modules are not supported, since existing customers only
    # use functions.

    # TODO(future PR): move this to config
    weighted_ops = {
        torch.nn.functional.linear,
    }

    results: NSResultsType = {
        'model': {NSSingleResultValuesType.WEIGHT.value: {}}
    }

    for n in m.graph.nodes:  # type: ignore[union-attr]
        if not (n.op == 'call_function' and n.target in weighted_ops):
            continue

        # Check if we have a corresponding shadow wrapper
        # TODO(future PR, if needed): support kwargs
        # TODO(future PR, if needed): support multiple shadow users
        first_arg = n.args[0]
        shadow_wrapper_node = None
        for user in first_arg.users:
            # TODO(before land): fix string match
            if user.op == 'call_module' and \
                    user.target.startswith('shadow_wrapper'):
                shadow_wrapper_node = user
                break

        if shadow_wrapper_node is None:
            continue

        shadow_wrapper = getattr_from_fqn(
            m, shadow_wrapper_node.target)  # type: ignore[arg-type]
        weight_info = _get_weight_info_from_shadow_wrapper(
            shadow_wrapper)
        if weight_info is None:
            continue

        # get weight
        w_node = n.args[1]
        w_obj = getattr_from_fqn(m, w_node.target).detach()

        # get a quantized version of weight
        quant_fn, quant_fn_args_except_first = weight_info
        new_args = (w_obj, *quant_fn_args_except_first)
        w_obj_q = quant_fn(*new_args)

        # add a comparison
        ref_node_name = n.name
        prev_node_name = n.name
        ref_node_type = get_target_type_str(n, m)
        prev_node_type = ref_node_type
        fqn = None
        if hasattr(m, '_node_name_to_scope'):
            fqn = m._node_name_to_scope[n.name][0]  # type: ignore[index]
        comparison = torch.ao.ns.fx.utils.compute_sqnr(w_obj, w_obj_q)
        result_fp32 = {
            'res_type': NSSingleResultValuesType.WEIGHT.value,
            'values': [w_obj],
            'prev_node_name': prev_node_name,
            'prev_node_target_type': prev_node_type,
            'ref_node_name': ref_node_name,
            'ref_node_target_type': ref_node_type,
            'index_within_arg': 0,
            'index_of_arg': 0,
            'fqn': fqn,
            'qconfig_str': '',
            'comparisons': [comparison],
            'comparison_fn_name': 'sqnr',
        }
        result_q = {
            'res_type': NSSingleResultValuesType.WEIGHT.value,
            'values': [w_obj_q],
            'prev_node_name': prev_node_name,
            'prev_node_target_type': prev_node_type,
            'ref_node_name': ref_node_name,
            'ref_node_target_type': ref_node_type,
            'index_within_arg': 0,
            'index_of_arg': 0,
            'fqn': fqn,
            'qconfig_str': '',
            'comparisons': [comparison],
            'comparison_fn_name': 'sqnr',
        }

        # go from subgraph_n_1 to subgraph_n_0
        _1, _2, node_idx, _3 = shadow_wrapper_node.target.split('_')
        name_fp32 = f"subgraph_{node_idx}_0"
        name_q = f"subgraph_{node_idx}_1"

        results['model'][NSSingleResultValuesType.WEIGHT.value][name_fp32] = \
            [result_fp32]
        results['model'][NSSingleResultValuesType.WEIGHT.value][name_q] = \
            [result_q]

    return results

# TODO(future PR): redesign this to make it easier to consume outputs
def group_results_by_subgraph(results: NSResultsType) -> Any:
    """
    Creates a comparison of results

    Input:

    {
      'model': {
        'node_output': {
          'subgraph_0_0': [
            'values': [torch.tensor(...), ...], ...
            'ref_node_name': ...,
            'ref_node_target_type': ...,
            'qconfig_str': ...,
            'comparisons': [], ...
            'comparison_fn_name': '',
            'fqn': '...',
          ],
          'subgraph_0_1': [
            'values': [torch.tensor(...), ...], ...
            'ref_node_name': ...,
            'ref_node_target_type': ...,
            'qconfig_str': ...,
            'comparisons': [torch.tensor(...), ...], ...
            'comparison_fn_name': '...',
            'fqn': '...',
          ],
          ...
        },
      },
    }

    Output:

    {
      'subgraph_0': {
        '0': {
          'ref_node_name': '...',
          'ref_node_target_type': ...,
          'values': [torch.tensor(...), ...],
          'qconfig_str': None,
          'comparisons': [torch.tensor(...), ...], ...
          'comparison_fn_name': '...',
          'fqn': '...',
        },
        '1': {
          'ref_node_name': '...',
          'ref_node_target_type': ...,
          'values': [torch.tensor(...), ...],
          'qconfig_str': '...',
          'comparisons': [torch.tensor(...), ...], ...
          'comparison_fn_name': '...',
          'fqn': '...',
        },
      },
    }
    """
    subgraph_name_to_subgraph_results: Any = collections.defaultdict(dict)

    # node_output or weight
    key_to_use = list(results['model'].keys())[0]

    for subgraph_name_with_idx, subgraph_candidate_results in \
            results['model'][key_to_use].items():

        # convert from `subgraph_m_n` to `subgraph_m` and `n`
        subgraph_str, subgraph_idx, subgraph_candidate_idx = \
            subgraph_name_with_idx.split('_')
        subgraph_name = f'{subgraph_str}_{subgraph_idx}'

        subgraph_results = {
            'ref_node_name': subgraph_candidate_results[0]['ref_node_name'],
            'ref_node_target_type': subgraph_candidate_results[0]['ref_node_target_type'],
            'fqn': subgraph_candidate_results[0]['fqn'],
            'values': subgraph_candidate_results[0]['values'],
            'qconfig_str': subgraph_candidate_results[0]['qconfig_str'],
            'comparisons': subgraph_candidate_results[0]['comparisons'],
            'comparison_fn_name': subgraph_candidate_results[0]['comparison_fn_name'],
        }

        subgraph_name_to_subgraph_results[subgraph_name][subgraph_candidate_idx] = \
            subgraph_results

    return dict(subgraph_name_to_subgraph_results)

# TODO(future PR): redesign this to make it easier to consume outputs
def create_results_comparison(
    results_grouped,
) -> Any:
    """
    Input:

    {
      'subgraph_0': {
        '0': {
          'ref_node_name': '...',
          'ref_node_target_type': ...,
          'values': [torch.tensor(...), ...],
          'qconfig_str': '',
          'comparisons': [],
          'comparison_fn_name': '',
          'fqn': '...',
        },
        '1': {
          'ref_node_name': '...',
          'ref_node_target_type': ...,
          'values': [torch.tensor(...), ...],
          'qconfig_str': '...',
          'comparisons': [torch.tensor(...), ...],
          'comparison_fn_name': 'sqnr',
          'fqn': '...',
        },
      },
    }

    Output:

    {
      'subgraph_0': {
        'ref_node_name': '...',
        'ref_node_target_type': '...',
        'fqn': '...',
        'candidates': {
          '1': {
            'qconfig_str': ...,
            'comparison_fn_name': 'sqnr',
            'cmp_raw': [..., ...],
            'cmp_mean': ...,
          },
          ...,
        },
      },
    }
    """
    results_comparison = {}

    for subgraph_name, subgraph_results in results_grouped.items():

        candidates = {}

        for subgraph_inner_name, subgraph_inner_result in subgraph_results.items():
            # skip comparing baseline to baseline
            if subgraph_inner_name == '0':
                continue

            # we expect the comparisons to be precalculated from
            # calibration, so we just fetch them here
            cmp_raw = subgraph_inner_result['comparisons']
            cmp_raw_tensor = torch.stack(cmp_raw)

            candidates[subgraph_inner_name] = {
                'qconfig_str': subgraph_inner_result['qconfig_str'],
                'comparison_fn_name': subgraph_inner_result['comparison_fn_name'],
                'cmp_raw': cmp_raw_tensor,
                'cmp_mean': torch.mean(cmp_raw_tensor),
            }

        results_comparison[subgraph_name] = {
            'ref_node_name': subgraph_results['0']['ref_node_name'],
            'ref_node_target_type': subgraph_results['0']['ref_node_target_type'],
            'fqn': subgraph_results['0']['fqn'],
            'candidates': candidates,
        }

    return results_comparison

# TODO(future PR): redesign this to make it easier to consume outputs
def print_n_shadows_summary(
    results_comparison,
) -> None:
    """
    Input:

    {
      'subgraph_0': {
        'ref_node_name': 'linear1',
        'ref_node_target_type': '...',
        'fqn': '...',
        'candidates': {
          '1': {
            'qconfig_str': ...,
            'comparison_fn_name': ...,
            'cmp_raw': [45.0, 55.0],
            'cmp_mean': 50.0,
          },
          ...,
        },
      },
    }

    Prints:

    node_name | node_type | fqn | 0    | 1    | ...
    linear1   | ...       | ... | 45.0 | 50.0 | ...
    """
    try:
        from tabulate import tabulate
    except ImportError:
        print("`print_n_shadows_summary` relies on the library `tabulate`, "
              "which could not be found on this machine. Run `pip "
              "install tabulate` to install the library.")
        return
    results = []
    for subgraph_name, subgraph_data in results_comparison.items():
        mean_all_candidates = [
            candidate['cmp_mean']
            for candidate_name, candidate in subgraph_data['candidates'].items()
        ]
        data_row = [
            subgraph_data['ref_node_name'],
            subgraph_data['ref_node_target_type'],
            subgraph_data['fqn'],
            *mean_all_candidates,
        ]
        results.append(data_row)
    # each data row is [name, type, fqn, *candidate_means], so the number of
    # candidate columns is the row length minus the three metadata columns
    max_candidate_idx_len = -1
    for data_row in results:
        max_candidate_idx_len = max(max_candidate_idx_len, len(data_row) - 3)
    candidate_idx_headers = [str(x) for x in range(max_candidate_idx_len)]
    headers = ['node_name', 'node_type', 'fqn', *candidate_idx_headers]
    print(tabulate(results, headers=headers))
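
# A rough sketch of how the reporting helpers above compose (`results` would
# come from extracting logger data from a converted n-shadows model, e.g. via
# extract_results_n_shadows_model in torch.ao.ns._numeric_suite_fx):
#
#   results_grouped = group_results_by_subgraph(results)
#   results_comparison = create_results_comparison(results_grouped)
#   print_n_shadows_summary(results_comparison)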