"""
This module contains tooling to compare weights and activations
across models. Example usage::

    import copy
    import torch
    import torch.ao.quantization.quantize_fx as quantize_fx
    import torch.ao.ns._numeric_suite_fx as ns

    m = torch.nn.Sequential(torch.nn.Conv2d(1, 1, 1)).eval()
    mp = quantize_fx.prepare_fx(m, {'': torch.ao.quantization.default_qconfig})
    # We convert a copy because we need the original prepared model
    # to be available for comparisons, and `quantize_fx.convert_fx` is inplace.
    mq = quantize_fx.convert_fx(copy.deepcopy(mp))

    #
    # Comparing weights
    #

    # extract weight pairs
    weight_comparison = ns.extract_weights('a', mp, 'b', mq)

    # add SQNR for each comparison, inplace
    ns.extend_logger_results_with_comparison(
        weight_comparison, 'a', 'b', torch.ao.ns.fx.utils.compute_sqnr,
        'sqnr')

    # weight_comparison contains the weights from `mp` and `mq` stored
    # in pairs, and can be used for further analysis.

    #
    # Comparing activations, with error propagation
    #

    # add loggers
    mp_ns, mq_ns = ns.add_loggers(
        'a', copy.deepcopy(mp),
        'b', copy.deepcopy(mq),
        ns.OutputLogger)

    # send an example datum to capture intermediate activations
    datum = torch.randn(1, 1, 1, 1)
    mp_ns(datum)
    mq_ns(datum)

    # extract intermediate activations
    act_comparison = ns.extract_logger_info(
        mp_ns, mq_ns, ns.OutputLogger, 'b')

    # add SQNR for each comparison, inplace
    ns.extend_logger_results_with_comparison(
        act_comparison, 'a', 'b', torch.ao.ns.fx.utils.compute_sqnr,
        'sqnr')

    # act_comparison contains the activations from `mp_ns` and `mq_ns` stored
    # in pairs, and can be used for further analysis.

    #
    # Comparing activations, without error propagation
    #

    # create shadow model
    mp_shadows_mq = ns.add_shadow_loggers(
        'a', copy.deepcopy(mp),
        'b', copy.deepcopy(mq),
        ns.OutputLogger)

    # send an example datum to capture intermediate activations
    datum = torch.randn(1, 1, 1, 1)
    mp_shadows_mq(datum)

    # extract intermediate activations
    shadow_act_comparison = ns.extract_shadow_logger_info(
        mp_shadows_mq, ns.OutputLogger, 'b')

    # add SQNR for each comparison, inplace
    ns.extend_logger_results_with_comparison(
        shadow_act_comparison, 'a', 'b', torch.ao.ns.fx.utils.compute_sqnr,
        'sqnr')

    # shadow_act_comparison contains the activations from `mp_ns` and `mq_ns` stored
    # in pairs, and can be used for further analysis.

"""

import collections

import torch
import torch.nn as nn
import torch.ao.quantization.quantize_fx as quantize_fx
from torch.fx import GraphModule
from torch.fx.graph import Node
from torch.ao.ns.fx.mappings import (
    get_base_name_to_sets_of_related_ops,
)
from torch.ao.ns.fx.graph_matcher import (
    get_matching_subgraph_pairs,
    get_type_a_related_to_b,
)

from .fx.weight_utils import (
    extract_weight_from_node,
)

from .fx.graph_passes import (
    add_loggers_to_model,
    create_a_shadows_b,
)

from .fx.utils import (
    rekey_logger_info_on_node_name_of_model,
    maybe_add_missing_fqns,
    get_target_type_str,
)

from .fx.ns_types import (
    NSSingleResultValuesType,
    NSResultsType,
    NSNodeTargetType,
)
from torch.ao.quantization.backend_config.utils import get_fusion_pattern_to_root_node_getter
from torch.ao.quantization.backend_config import BackendConfig
from torch.ao.quantization.fx.match_utils import _find_matches
from torch.ao.quantization.fx.graph_module import _get_observed_graph_module_attr
from torch.ao.quantization.fx.qconfig_mapping_utils import _generate_node_name_to_qconfig
from torch.ao.quantization.fx.quantize_handler import _get_pattern_to_quantize_handlers
from torch.ao.quantization.qconfig import QConfigAny
from torch.ao.quantization import QConfigMapping
from torch.ao.ns.fx.n_shadows_utils import (
    OutputProp,
    _get_dedup_subgraphs,
    SHADOW_WRAPPER_NODE_NAME_PREFIX,
    group_results_by_subgraph,
    create_results_comparison,
    print_n_shadows_summary,
    create_n_transformed_and_logged_copies_of_subgraph,
    create_add_loggers_graph,
    extract_weight_comparison,
)
from torch.ao.ns.fx.qconfig_multi_mapping import QConfigMultiMapping

from typing import Dict, Tuple, Callable, List, Optional, Set, Any, Type
  117. RNNReturnType = Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]
  118. class OutputLogger(nn.Module):
  119. """
  120. Base class for capturing intermediate values.
  121. """
  122. stats: List[torch.Tensor]
  123. stats_rnn: List[RNNReturnType]
  124. # Mark as impure so that calls to it will not be removed during DCE.
  125. _is_impure = True
  126. def __init__(
  127. self,
  128. ref_node_name: str,
  129. prev_node_name: str,
  130. model_name: str,
  131. ref_name: str,
  132. prev_node_target_type: str,
  133. ref_node_target_type: str,
  134. results_type: str,
  135. index_within_arg: int,
  136. index_of_arg: int,
  137. fqn: Optional[str],
  138. qconfig_str: Optional[str] = '',
  139. ):
  140. super().__init__()
  141. self.stats: List[torch.Tensor] = []
  142. self.stats_rnn: List[RNNReturnType] = []
  143. # name of the node which was responsible for adding this logger
  144. # Note:
  145. # - if we are logging node outputs, this is the same as prev_node_name
  146. # - if we are logging node inputs, this is the name of the node
  147. # whose input this logger is logging.
  148. #
  149. # example, where logger1 is logging input of op1 and logger2 is logging
  150. # the output of op1:
  151. #
  152. # x1 -> logger1 -> op1 -> logger2 -> x2
  153. #
  154. # in this example,
  155. # - logger1's prev_node_name is x1 and ref_node_name is op1
  156. # - logger2's prev_node_name is op1 and ref_node_name is op1
  157. self.ref_node_name = ref_node_name
  158. # name of the node whose output this Logger is capturing
  159. self.prev_node_name = prev_node_name
  160. # name of the model from which the node originated from
  161. self.model_name = model_name
  162. # reference name, used to match loggers from separate models
  163. # to each other
  164. self.ref_name = ref_name
  165. # type of the target of the node whose output this logger is logging
  166. self.prev_node_target_type = prev_node_target_type
  167. # type of the target of the node which was responsible for adding this
  168. # logger
  169. self.ref_node_target_type = ref_node_target_type
  170. # what kind of values are inside of stats
  171. self.results_type = results_type
  172. # index of this node within the arg of the input/output node
  173. # for example, in cat([x1, x2, x3], dim=0), x2 would have index_within_arg == 1
  174. self.index_within_arg = index_within_arg
  175. # index of this node within the args of the input/output node
  176. # for example, in add(x1, x2), x2 would have index_of_arg == 1
  177. self.index_of_arg = index_of_arg
  178. # fully qualified name
  179. self.fqn = fqn
  180. # if loggers are added before prepare_fx, but we do not want
  181. # collect results of calibration, only results after convert_fx
  182. # so, we add a flag to control whether this logger collects data
  183. self.enabled = True
  184. # string representation of qconfig
  185. self.qconfig_str = qconfig_str
  186. # this can be turned off to reduce memory usage during calibration
  187. self.save_activations = True
  188. # Note: cannot annotate the type of x because TorchScript does not support
  189. # the Union type.
  190. def forward(self, x):
  191. """
  192. """ # blank docblock to make autodoc happy
  193. # TODO(future PR): consider designing this better, as the difference
  194. # between these two flags is subtle and not obvious.
  195. if not self.enabled:
  196. return x
  197. if not self.save_activations:
  198. return x
  199. # TODO(future PR): consider refactoring this to better reuse the parent
  200. # class
  201. if isinstance(x, torch.Tensor):
  202. self.stats.append(x.detach())
  203. elif isinstance(x, tuple) and len(x) == 2 and len(x[1]) == 2:
  204. new_res = (x[0].detach(), (x[1][0].detach(), x[1][1].detach()))
  205. self.stats_rnn.append(new_res)
  206. return x
  207. def __repr__(self):
  208. clean_dict = {
  209. k: v
  210. for k, v in self.__dict__.items()
  211. # skip nn.Module keys
  212. if (k != 'training') and not k.startswith('_')
  213. }
  214. return f"OutputLogger({clean_dict})"
  215. class OutputComparisonLogger(OutputLogger):
  216. """
  217. Same as OutputLogger, but also requires the original activation
  218. in order to calculate the comparison at calibration time
  219. """
  220. def __init__(self, *args, **kwargs):
  221. super().__init__(*args, **kwargs)
  222. # TODO(future PR): make the comparison function configurable
  223. self.comparison_fn = torch.ao.ns.fx.utils.compute_sqnr
  224. self.comparison_fn_name = 'sqnr'
  225. # precalculated comparisons of logger output versus reference
  226. self.comparisons = []
  227. # precalculated comparisons function
  228. def forward(self, x, x_ref):
  229. """
  230. """ # blank docblock to make autodoc happy
  231. if not self.enabled:
  232. return x
  233. assert isinstance(x, torch.Tensor), 'non-tensor inputs not yet supported'
  234. if self.save_activations:
  235. # save the activation, for debugging
  236. self.stats.append(x.detach())
  237. # save the comparison
  238. self.comparisons.append(self.comparison_fn(x, x_ref))
  239. return x
  240. def __repr__(self):
  241. clean_dict = {
  242. k: v
  243. for k, v in self.__dict__.items()
  244. # skip nn.Module keys
  245. if (k != 'training') and not k.startswith('_')
  246. }
  247. return f"OutputComparisonLogger({clean_dict})"
  248. class NSTracer(quantize_fx.QuantizationTracer):
  249. """
  250. Just like a regular FX quantization tracer, but treats observers and fake_quantize
  251. modules as leaf modules.
  252. """
  253. def is_leaf_module(self, m: torch.nn.Module, module_qualified_name : str) -> bool:
  254. """
  255. """ # blank docblock to make autodoc happy
  256. if isinstance(m, torch.ao.quantization.ObserverBase):
  257. return True
  258. elif isinstance(m, torch.ao.quantization.FakeQuantizeBase):
  259. return True
  260. return super().is_leaf_module(m, module_qualified_name)
  261. def _extract_weights_one_model(
  262. model_name: str,
  263. model: GraphModule,
  264. nodes_and_names_to_instrument: List[Tuple[Node, str]],
  265. results: NSResultsType,
  266. op_to_type_to_weight_extraction_fn: Optional[Dict[str, Dict[Callable, Callable]]] = None,
  267. ) -> None:
  268. torch._C._log_api_usage_once("quantization_api._numeric_suite_fx._extract_weights_one_model")
  269. for node, ref_name in nodes_and_names_to_instrument:
  270. res_type = NSSingleResultValuesType.WEIGHT.value
  271. extracted_weight = extract_weight_from_node(
  272. node, model, op_to_type_to_weight_extraction_fn)
  273. if extracted_weight:
  274. if ref_name not in results:
  275. results[ref_name] = {res_type: {}}
  276. results[ref_name][res_type][model_name] = [extracted_weight]
  277. def _extract_weights_impl(
  278. model_name_a: str,
  279. gm_a: GraphModule,
  280. model_name_b: str,
  281. gm_b: GraphModule,
  282. base_name_to_sets_of_related_ops: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
  283. unmatchable_types_map: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
  284. op_to_type_to_weight_extraction_fn: Optional[Dict[str, Dict[Callable, Callable]]] = None,
  285. ) -> NSResultsType:
  286. torch._C._log_api_usage_once("quantization_api._numeric_suite_fx._extract_weights_impl")
  287. matched_subgraph_pairs = get_matching_subgraph_pairs(
  288. gm_a, gm_b, base_name_to_sets_of_related_ops,
  289. unmatchable_types_map)
  290. # split the subgraph pairs into one data structure for each model
  291. nodes_and_names_to_instrument_a: List[Tuple[Node, str]] = []
  292. nodes_and_names_to_instrument_b: List[Tuple[Node, str]] = []
  293. for match_name, match in matched_subgraph_pairs.items():
  294. subgraph_a, subgraph_b = match
  295. nodes_and_names_to_instrument_a.append((subgraph_a.base_op_node, match_name))
  296. nodes_and_names_to_instrument_b.append((subgraph_b.base_op_node, match_name))
  297. # populate the results, one model at a time
  298. results: NSResultsType = {}
  299. _extract_weights_one_model(
  300. model_name_a, gm_a, nodes_and_names_to_instrument_a, results,
  301. op_to_type_to_weight_extraction_fn)
  302. _extract_weights_one_model(
  303. model_name_b, gm_b, nodes_and_names_to_instrument_b, results,
  304. op_to_type_to_weight_extraction_fn)
  305. # fill in missing fqn entries
  306. maybe_add_missing_fqns(results)
  307. # rekey on names of nodes in gm_b
  308. results = rekey_logger_info_on_node_name_of_model(results, model_name_b)
  309. return results
  310. def extract_weights(
  311. model_name_a: str,
  312. model_a: nn.Module,
  313. model_name_b: str,
  314. model_b: nn.Module,
  315. base_name_to_sets_of_related_ops: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
  316. unmatchable_types_map: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
  317. op_to_type_to_weight_extraction_fn: Optional[Dict[str, Dict[Callable, Callable]]] = None,
  318. ) -> NSResultsType:
  319. """
  320. Extract weights from model A and model B, and return a comparison.
  321. Args:
  322. model_name_a: string name of model A to use in results
  323. model_a: model A
  324. model_name_b: string name of model B to use in results
  325. model_b: model B
  326. base_name_to_sets_of_related_ops: optional override of subgraph base nodes, subject to change
  327. unmatchable_types_map: optional override of unmatchable types, subject to change
  328. op_to_type_to_weight_extraction_fn: optional override of function which extracts weight
  329. from a type, subject to change
  330. Return:
  331. NSResultsType, containing the weight comparisons
  332. """
  333. torch._C._log_api_usage_once("quantization_api._numeric_suite_fx.extract_weights")
  334. if base_name_to_sets_of_related_ops is None:
  335. base_name_to_sets_of_related_ops = \
  336. get_base_name_to_sets_of_related_ops()
  337. type_a_related_to_b = \
  338. get_type_a_related_to_b(base_name_to_sets_of_related_ops)
  339. # TODO(future PR): expose these
  340. skipped_module_names: List[str] = []
  341. skipped_module_classes: List[Callable] = []
  342. tracer_a = NSTracer(skipped_module_names, skipped_module_classes)
  343. tracer_b = NSTracer(skipped_module_names, skipped_module_classes)
  344. gm_a = GraphModule(model_a, tracer_a.trace(model_a))
  345. maybe_model_a_node_name_to_scope = _get_observed_graph_module_attr(model_a, 'node_name_to_scope')
  346. if maybe_model_a_node_name_to_scope is not None:
  347. gm_a._node_name_to_scope = maybe_model_a_node_name_to_scope
  348. gm_b = GraphModule(model_b, tracer_b.trace(model_b))
  349. maybe_model_b_node_name_to_scope = _get_observed_graph_module_attr(model_b, 'node_name_to_scope')
  350. if maybe_model_b_node_name_to_scope is not None:
  351. gm_b._node_name_to_scope = maybe_model_b_node_name_to_scope
  352. return _extract_weights_impl(
  353. model_name_a, gm_a, model_name_b, gm_b, base_name_to_sets_of_related_ops,
  354. unmatchable_types_map, op_to_type_to_weight_extraction_fn)
  355. def _add_loggers_one_model(
  356. model_name: str,
  357. model: GraphModule,
  358. nodes_and_names_to_instrument_inputs: List[Tuple[Node, str, str]],
  359. nodes_and_names_to_instrument_outputs: List[Tuple[Node, str, str]],
  360. logger_cls: Callable,
  361. ) -> nn.Module:
  362. torch._C._log_api_usage_once("quantization_api._numeric_suite_fx._add_loggers_one_model")
  363. # TODO(future PR): do not observe nodes we do not care
  364. # about (both fp32, denylist, etc)
  365. node_to_instrument_inputs_to_ref_name: Dict[Node, Tuple[str, str]] = {}
  366. node_to_instrument_outputs_to_ref_name: Dict[Node, Tuple[str, str]] = {}
  367. for node, ref_name, ref_node_type in nodes_and_names_to_instrument_inputs:
  368. node_to_instrument_inputs_to_ref_name[node] = (ref_name, ref_node_type)
  369. for node, ref_name, ref_node_type in nodes_and_names_to_instrument_outputs:
  370. node_to_instrument_outputs_to_ref_name[node] = (ref_name, ref_node_type)
  371. model = add_loggers_to_model(
  372. model, node_to_instrument_inputs_to_ref_name,
  373. node_to_instrument_outputs_to_ref_name, logger_cls, model_name)
  374. return model
  375. def _add_loggers_impl(
  376. name_a: str,
  377. gm_a: GraphModule,
  378. name_b: str,
  379. gm_b: GraphModule,
  380. logger_cls: Callable,
  381. should_log_inputs: bool,
  382. base_name_to_sets_of_related_ops: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
  383. unmatchable_types_map: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
  384. ) -> Tuple[nn.Module, nn.Module]:
  385. torch._C._log_api_usage_once("quantization_api._numeric_suite_fx._add_loggers_impl")
  386. matched_subgraph_pairs = get_matching_subgraph_pairs(
  387. gm_a, gm_b,
  388. base_name_to_sets_of_related_ops, unmatchable_types_map)
  389. nodes_and_names_to_instrument_inputs_a = []
  390. nodes_and_names_to_instrument_inputs_b = []
  391. nodes_and_names_to_instrument_outputs_a = []
  392. nodes_and_names_to_instrument_outputs_b = []
  393. for match_name, (subgraph_a, subgraph_b) in matched_subgraph_pairs.items():
  394. ref_node_type_a = get_target_type_str(subgraph_a.base_op_node, gm_a)
  395. ref_node_type_b = get_target_type_str(subgraph_b.base_op_node, gm_b)
  396. # Note: for matching inputs we use start_node, such as observing
  397. # the input of linear in linear-relu
  398. if should_log_inputs:
  399. nodes_and_names_to_instrument_inputs_a.append(
  400. (subgraph_a.start_node, match_name, ref_node_type_a))
  401. nodes_and_names_to_instrument_inputs_b.append(
  402. (subgraph_b.start_node, match_name, ref_node_type_b))
  403. # Note: for matching activations we always use end_node,
  404. # such as observing the output of relu in linear-relu
  405. nodes_and_names_to_instrument_outputs_a.append(
  406. (subgraph_a.end_node, match_name, ref_node_type_a))
  407. nodes_and_names_to_instrument_outputs_b.append(
  408. (subgraph_b.end_node, match_name, ref_node_type_b))
  409. new_model_a = _add_loggers_one_model(
  410. name_a, gm_a, nodes_and_names_to_instrument_inputs_a,
  411. nodes_and_names_to_instrument_outputs_a, logger_cls)
  412. new_model_b = _add_loggers_one_model(
  413. name_b, gm_b, nodes_and_names_to_instrument_inputs_b,
  414. nodes_and_names_to_instrument_outputs_b, logger_cls)
  415. return (new_model_a, new_model_b)
  416. def add_loggers(
  417. name_a: str,
  418. model_a: nn.Module,
  419. name_b: str,
  420. model_b: nn.Module,
  421. logger_cls: Callable,
  422. should_log_inputs : bool = False,
  423. base_name_to_sets_of_related_ops: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
  424. unmatchable_types_map: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
  425. ) -> Tuple[nn.Module, nn.Module]:
  426. """
  427. Instrument model A and model B with loggers.
  428. Args:
  429. name_a: string name of model A to use in results
  430. model_a: model A
  431. name_b: string name of model B to use in results
  432. model_b: model B
  433. logger_cls: class of Logger to use
  434. base_name_to_sets_of_related_ops: optional override of subgraph base nodes, subject to change
  435. unmatchable_types_map: optional override of unmatchable types, subject to change
  436. Return:
  437. Returns a tuple of (model_a_with_loggers, model_b_with_loggers). Modifies both models inplace.
  438. """
  439. torch._C._log_api_usage_once("quantization_api._numeric_suite_fx.add_loggers")
  440. # TODO(future PR): expose these
  441. skipped_module_names: List[str] = []
  442. skipped_module_classes: List[Callable] = []
  443. tracer_a = NSTracer(skipped_module_names, skipped_module_classes)
  444. tracer_b = NSTracer(skipped_module_names, skipped_module_classes)
  445. gm_a = GraphModule(model_a, tracer_a.trace(model_a))
  446. maybe_model_a_node_name_to_scope = _get_observed_graph_module_attr(model_a, 'node_name_to_scope')
  447. if maybe_model_a_node_name_to_scope is not None:
  448. gm_a._node_name_to_scope = maybe_model_a_node_name_to_scope
  449. gm_b = GraphModule(model_b, tracer_b.trace(model_b))
  450. maybe_model_b_node_name_to_scope = _get_observed_graph_module_attr(model_b, 'node_name_to_scope')
  451. if maybe_model_b_node_name_to_scope is not None:
  452. gm_b._node_name_to_scope = maybe_model_b_node_name_to_scope
  453. return _add_loggers_impl(
  454. name_a, gm_a, name_b, gm_b, logger_cls,
  455. should_log_inputs=should_log_inputs,
  456. base_name_to_sets_of_related_ops=base_name_to_sets_of_related_ops,
  457. unmatchable_types_map=unmatchable_types_map)
  458. def _extract_logger_info_one_model(
  459. model: nn.Module,
  460. results: NSResultsType,
  461. logger_cls: Callable,
  462. ) -> None:
  463. torch._C._log_api_usage_once("quantization_api._numeric_suite_fx._extract_logger_info_one_model")
  464. for gm_name, mod in model.named_modules():
  465. # TODO(future PR): better check when scripted
  466. is_logger = (
  467. isinstance(mod, logger_cls) # type: ignore[arg-type]
  468. or (
  469. isinstance(mod, torch.jit.RecursiveScriptModule)
  470. and mod.original_name == 'OutputLogger'
  471. )
  472. )
  473. if is_logger:
  474. key = mod.ref_name
  475. if key not in results:
  476. results[key] = {}
  477. assert mod.model_name not in results[key], \
  478. f"{mod.model_name} is already present in results"
  479. if mod.results_type not in results[key]:
  480. results[key][mod.results_type] = {}
  481. if mod.model_name not in results[key][mod.results_type]:
  482. results[key][mod.results_type][mod.model_name] = []
  483. stats_to_use = mod.stats
  484. if len(mod.stats_rnn) > 0:
  485. stats_to_use = mod.stats_rnn
  486. data = {
  487. 'type': mod.results_type,
  488. 'values': stats_to_use,
  489. 'ref_node_name': mod.ref_node_name,
  490. 'ref_node_target_type': mod.ref_node_target_type,
  491. 'prev_node_name': mod.prev_node_name,
  492. 'prev_node_target_type': mod.prev_node_target_type,
  493. 'index_within_arg': mod.index_within_arg,
  494. 'index_of_arg': mod.index_of_arg,
  495. 'fqn': mod.fqn,
  496. 'qconfig_str': mod.qconfig_str,
  497. }
  498. if hasattr(mod, 'comparisons'):
  499. data['comparisons'] = mod.comparisons
  500. data['comparison_fn_name'] = mod.comparison_fn_name
  501. else:
  502. data['comparisons'] = []
  503. data['comparison_fn_name'] = ''
  504. results[key][mod.results_type][mod.model_name].append(data)
  505. # ensure the list stays sorted
  506. results[key][mod.results_type][mod.model_name].sort(
  507. key=lambda res:
  508. f"{res['index_of_arg']}:{res['index_within_arg']}"
  509. )
  510. # TODO(future PR): align on naming
  511. # this is equivalent of just the comparison extraction part of `ns.compare_model_outputs`
  512. def extract_logger_info(
  513. model_a: nn.Module,
  514. model_b: nn.Module,
  515. logger_cls: Callable,
  516. model_name_to_use_for_layer_names: str,
  517. ) -> NSResultsType:
  518. """
  519. Traverse all loggers in `model_a` and `model_b`, and extract the logged
  520. information.
  521. Args:
  522. model_a: model A
  523. model_b: model B
  524. logger_cls: class of Logger to use
  525. model_name_to_use_for_layer_names: string name of model to use for
  526. layer names in the output
  527. Return:
  528. NSResultsType, containing the logged comparisons
  529. """
  530. torch._C._log_api_usage_once("quantization_api._numeric_suite_fx.extract_logger_info")
  531. results: NSResultsType = {}
  532. for model in (model_a, model_b):
  533. _extract_logger_info_one_model(model, results, logger_cls)
  534. # fill in missing fqn entries
  535. maybe_add_missing_fqns(results)
  536. # rekey on the name of model b
  537. results = rekey_logger_info_on_node_name_of_model(
  538. results, model_name_to_use_for_layer_names)
  539. return results
  540. def _add_shadow_loggers_impl(
  541. name_a: str,
  542. gm_a: GraphModule,
  543. name_b: str,
  544. gm_b: GraphModule,
  545. logger_cls: Callable,
  546. should_log_inputs: bool,
  547. base_name_to_sets_of_related_ops: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
  548. node_type_to_io_type_map: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
  549. unmatchable_types_map: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
  550. ) -> nn.Module:
  551. torch._C._log_api_usage_once("quantization_api._numeric_suite_fx._add_shadow_loggers_impl")
  552. matched_subgraph_pairs = get_matching_subgraph_pairs(
  553. gm_a, gm_b, base_name_to_sets_of_related_ops,
  554. unmatchable_types_map)
  555. gm_a_shadows_b = create_a_shadows_b(
  556. name_a, gm_a, name_b, gm_b, matched_subgraph_pairs, logger_cls,
  557. should_log_inputs=should_log_inputs,
  558. node_type_to_io_type_map=node_type_to_io_type_map)
  559. return gm_a_shadows_b
  560. def add_shadow_loggers(
  561. name_a: str,
  562. model_a: nn.Module,
  563. name_b: str,
  564. model_b: nn.Module,
  565. logger_cls: Callable,
  566. should_log_inputs: bool = False,
  567. base_name_to_sets_of_related_ops: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
  568. node_type_to_io_type_map: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
  569. unmatchable_types_map: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
  570. ) -> nn.Module:
  571. """
  572. Instrument model A and model B with shadow loggers.
  573. Args:
  574. name_a: string name of model A to use in results
  575. model_a: model A
  576. name_b: string name of model B to use in results
  577. model_b: model B
  578. logger_cls: class of Logger to use
  579. should_log_inputs: whether to log inputs
  580. base_name_to_sets_of_related_ops: optional override of subgraph base nodes, subject to change
  581. unmatchable_types_map: optional override of unmatchable types, subject to change
  582. """
  583. torch._C._log_api_usage_once("quantization_api._numeric_suite_fx.add_shadow_loggers")
  584. # TODO(future PR): expose these
  585. skipped_module_names: List[str] = []
  586. skipped_module_classes: List[Callable] = []
  587. tracer_a = NSTracer(skipped_module_names, skipped_module_classes)
  588. tracer_b = NSTracer(skipped_module_names, skipped_module_classes)
  589. gm_a = GraphModule(model_a, tracer_a.trace(model_a))
  590. maybe_model_a_node_name_to_scope = _get_observed_graph_module_attr(model_a, 'node_name_to_scope')
  591. if maybe_model_a_node_name_to_scope is not None:
  592. gm_a._node_name_to_scope = maybe_model_a_node_name_to_scope
  593. gm_b = GraphModule(model_b, tracer_b.trace(model_b))
  594. maybe_model_b_node_name_to_scope = _get_observed_graph_module_attr(model_b, 'node_name_to_scope')
  595. if maybe_model_b_node_name_to_scope is not None:
  596. gm_b._node_name_to_scope = maybe_model_b_node_name_to_scope
  597. return _add_shadow_loggers_impl(
  598. name_a, gm_a, name_b, gm_b, logger_cls,
  599. should_log_inputs=should_log_inputs,
  600. base_name_to_sets_of_related_ops=base_name_to_sets_of_related_ops,
  601. node_type_to_io_type_map=node_type_to_io_type_map,
  602. unmatchable_types_map=unmatchable_types_map)
  603. def extract_shadow_logger_info(
  604. model_a_shadows_b: nn.Module,
  605. logger_cls: Callable,
  606. model_name_to_use_for_layer_names: str,
  607. ) -> NSResultsType:
  608. """
  609. Traverse all loggers in a shadow model, and extract the logged
  610. information.
  611. Args:
  612. model_a_shadows_b: shadow model
  613. logger_cls: class of Logger to use
  614. model_name_to_use_for_layer_names: string name of model to use for
  615. layer names in the output
  616. Return:
  617. NSResultsType, containing the logged comparisons
  618. """
  619. torch._C._log_api_usage_once("quantization_api._numeric_suite_fx.extract_shadow_logger_info")
  620. results: NSResultsType = collections.defaultdict(dict)
  621. _extract_logger_info_one_model(model_a_shadows_b, results, logger_cls)
  622. # fill in missing fqn entries
  623. maybe_add_missing_fqns(results)
  624. # rekey on the name of model b
  625. results = rekey_logger_info_on_node_name_of_model(
  626. results, model_name_to_use_for_layer_names)
  627. return dict(results)
  628. def extend_logger_results_with_comparison(
  629. results: NSResultsType,
  630. model_name_1: str,
  631. model_name_2: str,
  632. comparison_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
  633. comparison_name: str,
  634. ) -> None:
  635. """
  636. Compares the logged values from `model_name_2` against the corresponding
  637. values in `model_name_1`, using `comparison_fn`. Records the result
  638. in `model_name_2`'s results under `comparison_name`. Modifies `results` inplace.
  639. Args:
  640. results: the result data structure from `extract_logger_info` or
  641. `extract_shadow_logger_info`.
  642. model_name_1: string name of model 1
  643. model_name_2: string name of model 2
  644. comparison_fn: function to compare two Tensors
  645. comparison_name: string name of model to use for
  646. layer names in the output
  647. """
  648. for _, results_type_to_results in results.items():
  649. for _, model_name_to_results in results_type_to_results.items():
  650. assert model_name_1 in model_name_to_results, \
  651. f"{model_name_1} not found in results"
  652. assert model_name_2 in model_name_to_results, \
  653. f"{model_name_2} not found in results"
  654. results_1 = model_name_to_results[model_name_1]
  655. results_2 = model_name_to_results[model_name_2]
  656. for result_2 in results_2:
  657. index_within_arg_2 = result_2['index_within_arg']
  658. index_of_arg_2 = result_2['index_of_arg']
  659. # find corresponding result_1
  660. result_1 = None
  661. for cur_result_1 in results_1:
  662. index_within_arg_1 = cur_result_1['index_within_arg']
  663. index_of_arg_1 = cur_result_1['index_of_arg']
  664. if (
  665. (index_within_arg_1 == index_within_arg_2) and
  666. (index_of_arg_1 == index_of_arg_2)
  667. ):
  668. result_1 = cur_result_1
  669. break
  670. assert result_1 is not None
  671. values_1 = result_1['values']
  672. values_2 = result_2['values']
  673. result_2[comparison_name] = []
  674. for value_1, value_2 in zip(values_1, values_2):
  675. comparison_result = comparison_fn(value_1, value_2)
  676. result_2[comparison_name].append(comparison_result)
  677. def prepare_n_shadows_model(
  678. model: torch.nn.Module,
  679. example_inputs: Any,
  680. qconfig_multi_mapping: QConfigMultiMapping,
  681. backend_config: BackendConfig,
  682. custom_prepare_fn: Optional[Callable] = None,
  683. custom_prepare_kwargs: Optional[Dict[str, Any]] = None,
  684. custom_tracer: Any = None,
  685. ) -> GraphModule:
  686. """
  687. Given a model with a graph with M ops such as
  688. args_kwargs_m -> op_m -> output_m
  689. And a set of N qconfigs for each op, creates a new model, with
  690. each of the subgraph of `op_m` transformed into
  691. .. code::
  692. |---------> op_m_n -> log_m_n
  693. | /
  694. args_kwargs_m ---------> op_m -> log_m_0
  695. Where op_m_n is op_m wrapped in a submodule and transformed with
  696. qconfig_n, and its inner graph looks like
  697. .. code::
  698. args_m -------- op_m_prepared_with_qconfig_n -> out_m_n
  699. /
  700. kwargs_m ---
  701. This is useful for testing different quantization of multiple layers in
  702. a single pass through the model.
  703. High level TODOs for future PRs:
  704. * figure out a better way to name the output structure
  705. * return a results data structure instead of printing it out
  706. * add examples to docblocks
  707. """
  708. if custom_tracer is None:
  709. tracer = quantize_fx.QuantizationTracer([], [])
  710. else:
  711. tracer = custom_tracer
  712. mt = torch.fx.GraphModule(model, tracer.trace(model))
  713. # this is necessary to ensure logger FQNs get populated
  714. mt._node_name_to_scope = tracer.node_name_to_scope
  715. # run example input propagation, we need this to call prepare_fx on
  716. # individual subgraphs
  717. output_prop = OutputProp(mt)
  718. output_prop.propagate(*example_inputs)
  719. # Find the set of subgraphs in the original graph which we need to
  720. # consider.
  721. modules = dict(mt.named_modules(remove_duplicate=False))
  722. patterns = _get_pattern_to_quantize_handlers(backend_config)
  723. root_node_getter_mapping = \
  724. get_fusion_pattern_to_root_node_getter(backend_config)
  725. standalone_module_names: List[str] = []
  726. standalone_module_classes: List[Type] = []
  727. custom_module_classes: List[Type] = []
  728. matches = _find_matches(
  729. mt.graph, modules, patterns, root_node_getter_mapping,
  730. standalone_module_names, standalone_module_classes, custom_module_classes)
  731. subgraphs_dedup: Dict[str, List[Node]] = \
  732. _get_dedup_subgraphs(matches)
  733. # generate node to qconfig for each subgraph
  734. # TODO(future PR): deduplicate repeating entries
  735. list_of_node_name_to_qconfig: List[Dict[str, QConfigAny]] = []
  736. for qconfig_mapping in qconfig_multi_mapping.qconfig_mappings_list:
  737. node_name_to_qconfig = _generate_node_name_to_qconfig(
  738. mt, modules, mt.graph, qconfig_mapping, tracer.node_name_to_scope)
  739. list_of_node_name_to_qconfig.append(node_name_to_qconfig)
  740. # For each region in the model, do the following:
  741. # For each qconfig for that region, do the following:
  742. # 1. create a copy of the region wrapped in a module
  743. # 2. pass original args, original kwargs, and expected output to module
  744. # 3. add an output comparison logger and hook it up to compare
  745. # actual output to expected output
  746. # 4. run `prepare_fx` on the module
  747. for (subgraph_idx, (match_name, nodes_in_this_subgraph)) in \
  748. enumerate(subgraphs_dedup.items()):
  749. create_n_transformed_and_logged_copies_of_subgraph(
  750. mt, subgraph_idx, match_name, nodes_in_this_subgraph,
  751. qconfig_multi_mapping.qconfig_mappings_list, list_of_node_name_to_qconfig,
  752. custom_prepare_fn, custom_prepare_kwargs
  753. )
  754. return mt
  755. # TODO(future PR): we should rethink the names of all the PNP APIs
  756. def _prepare_n_shadows_add_loggers_model(
  757. model: torch.nn.Module,
  758. example_inputs: Any,
  759. qconfig_mapping: QConfigMapping,
  760. backend_config: BackendConfig,
  761. ) -> torch.nn.Module:
  762. """
  763. Note: this API is not recommended for wide usage, it is only
  764. provided for customers who need to migrate from the `add_loggers`
  765. API.
  766. This creates a model which provides logging for the following
  767. problem: if we quantize `model` with `qconfig_mapping` and feed
  768. the same input through both models, log the comparisons of
  769. corresponding intermediate layers.
  770. The problem is solved with a single model. Specifically, we
  771. partition `model` into N subgraphs, create a copy of each relevant
  772. subgraph, wrap it in a module, apply the quantization API to that
  773. module, and hook up loggers to measure the comparisons.
  774. Example starting graph:
  775. x0 -> op0 -> x1 -> op1 -> x2
  776. Example config: quantize op0 to int8, do nothing to op1.
  777. The following graph will be created:
  778. .. code::
  779. x0_0 -> op0_0 -> x1_0 -> log -----> op1_0 -> x2_0 -> log
  780. \ \ \ # noqa: W605
  781. ---> op0_1 -> x1_1 ----> clog -> op1_0 -> x2_1 ----> clog
  782. Where op0_0 is op0, op0_1 is op0 wrapped in a submodule and quantized
  783. to int8, op1_0 is op1 (appearing in the graph twice), log is a logger,
  784. and clog is a comparison logger.
  785. """
  786. tracer = quantize_fx.QuantizationTracer([], [])
  787. mt = torch.fx.GraphModule(model, tracer.trace(model))
  788. # this is necessary to ensure logger FQNs get populated
  789. mt._node_name_to_scope = tracer.node_name_to_scope
  790. # run example input propagation, we need this to call prepare_fx on
  791. # individual subgraphs
  792. output_prop = OutputProp(mt)
  793. output_prop.propagate(*example_inputs)
  794. # Find the set of subgraphs in the original graph which we need to
  795. # consider.
  796. modules = dict(mt.named_modules(remove_duplicate=False))
  797. patterns = _get_pattern_to_quantize_handlers(backend_config)
  798. root_node_getter_mapping = \
  799. get_fusion_pattern_to_root_node_getter(backend_config)
  800. standalone_module_names: List[str] = []
  801. standalone_module_classes: List[Type] = []
  802. custom_module_classes: List[Type] = []
  803. matches = _find_matches(
  804. mt.graph, modules, patterns, root_node_getter_mapping,
  805. standalone_module_names, standalone_module_classes, custom_module_classes)
  806. subgraphs_dedup: Dict[str, List[Node]] = \
  807. _get_dedup_subgraphs(matches)
  808. # generate node to qconfig for each subgraph
  809. node_name_to_qconfig = _generate_node_name_to_qconfig(
  810. mt, modules, mt.graph, qconfig_mapping, tracer.node_name_to_scope)
  811. # Now, mutate the graph to be the add_loggers graph with propagation
  812. # error.
  813. create_add_loggers_graph(
  814. mt, subgraphs_dedup, qconfig_mapping, node_name_to_qconfig)
  815. return mt
  816. # TODO(future PR): we should rethink the names of all the PNP APIs
  817. def _n_shadows_compare_weights(
  818. model: torch.nn.Module,
  819. example_inputs: Any,
  820. qconfig_mapping: QConfigMapping,
  821. backend_config: BackendConfig,
  822. ) -> NSResultsType:
  823. """
  824. Note: this API is not recommended for wide usage, it is only
  825. provided for customers who need to migrate from the `add_loggers`
  826. API.
  827. """
  828. qconfig_multi_mapping = \
  829. QConfigMultiMapping.from_list_qconfig_mapping([qconfig_mapping])
  830. mp = prepare_n_shadows_model(
  831. model, example_inputs, qconfig_multi_mapping, backend_config)
  832. # passing inputs through the model is necessary to populate
  833. # observers which observe weights with real values
  834. mp(*example_inputs)
  835. mq = convert_n_shadows_model(mp)
  836. weight_comparison = extract_weight_comparison(mq)
  837. return weight_comparison
  838. # TODO(future PR): consider aligning API signature with other similar quantization
  839. # functions (enable_fake_quant, etc)
  840. def loggers_set_enabled(model: torch.nn.Module, enabled: bool) -> None:
  841. """
  842. Sets the `enabled` setting on a `model`'s loggers
  843. """
  844. for name, child in model.named_modules():
  845. if isinstance(child, OutputLogger):
  846. child.enabled = enabled
  847. # TODO(future PR): consider aligning API signature with other similar quantization
  848. # functions (enable_fake_quant, etc)
  849. def loggers_set_save_activations(
  850. model: torch.nn.Module,
  851. save_activations: bool,
  852. ) -> None:
  853. """
  854. Sets the `save_activations` setting on a `model`'s loggers
  855. """
  856. for name, child in model.named_modules():
  857. if isinstance(child, OutputLogger):
  858. child.save_activations = save_activations
  859. def convert_n_shadows_model(
  860. model: GraphModule,
  861. custom_convert_fn: Optional[Callable] = None,
  862. custom_convert_kwargs: Optional[Dict[str, Any]] = None
  863. ) -> GraphModule:
  864. """
  865. Given a model from `prepare_n_shadows_model`, runs `convert_fx`
  866. on each shadow submodule.
  867. """
  868. for node in model.graph.nodes:
  869. # TODO(future PR): consider matching in a safer way than
  870. # node name string match
  871. if node.name.startswith(SHADOW_WRAPPER_NODE_NAME_PREFIX):
  872. orig_mod = getattr(model, node.name)
  873. if custom_convert_fn is None:
  874. converted_mod = torch.ao.quantization.quantize_fx.convert_fx(
  875. orig_mod)
  876. else:
  877. if custom_convert_kwargs is None:
  878. custom_convert_kwargs = {}
  879. converted_mod = custom_convert_fn(orig_mod, **custom_convert_kwargs)
  880. setattr(model, node.name, converted_mod)
  881. return model
  882. def extract_results_n_shadows_model(model: torch.nn.Module) -> NSResultsType:
  883. """
  884. Extracts logger results from `model`.
  885. """
  886. results: NSResultsType = {}
  887. _extract_logger_info_one_model(model, results, OutputLogger)
  888. return results
  889. def print_comparisons_n_shadows_model(results: NSResultsType) -> None:
  890. """
  891. Prints a summary of extracted `results`.
  892. """
  893. results_grouped = group_results_by_subgraph(results)
  894. results_comparison = create_results_comparison(results_grouped)
  895. print_n_shadows_summary(results_comparison)