# graph_passes.py

import torch
from torch.fx import GraphModule, map_arg
from torch.fx.graph import Graph, Node
from torch.ao.quantization.fx.utils import get_new_attr_name_with_prefix

from .utils import (
    get_node_first_input_and_output_type,
    getattr_from_fqn,
    NodeInputOrOutputType,
    return_first_non_observer_node,
    get_number_of_non_param_args,
    get_target_type_str,
    get_arg_indices_of_inputs_to_log,
    get_node_input_qparams,
    op_type_supports_shadowing,
    get_normalized_nth_input,
)

from .ns_types import (
    NSSingleResultValuesType,
    NSSubgraph,
    NSNodeTargetType,
)

from torch.ao.ns.fx.mappings import (
    get_node_type_to_io_type_map,
)
from torch.ao.quantization.observer import _is_activation_post_process

from typing import Dict, Tuple, Callable, List, Any, Union, Optional, Set

def _maybe_get_fqn(node: Node, gm: GraphModule) -> Optional[str]:
    fqn = None
    if hasattr(gm, '_node_name_to_scope'):
        # fqn on observers is not present, because they do not
        # exist when the fqns are created during tracing. If this is
        # an observer, get the fqn of the node being observed.
        node_to_use_for_fqn = node
        if node.op == 'call_module':
            assert isinstance(node.target, str)
            module = getattr_from_fqn(gm, node.target)
            if _is_activation_post_process(module):
                node_to_use_for_fqn = get_normalized_nth_input(node, gm, 0)
        fqn = gm._node_name_to_scope[node_to_use_for_fqn.name][0]  # type: ignore[index]
    return fqn  # type: ignore[return-value]
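
# Note: `_node_name_to_scope` is populated by the FX quantization tracer and,
# as used above, maps a node name to a `(module_fqn, module_type)` tuple. A
# hypothetical entry could look like
#   'linear1_post_act' -> ('block.linear1', torch.nn.Linear)
# which is why `_maybe_get_fqn` keeps only element `[0]`.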

def _insert_logger_after_node(
    node: Node,
    gm: GraphModule,
    logger_cls: Callable,
    logger_node_name_suffix: str,
    ref_node_name: str,
    model_name: str,
    ref_name: str,
    ref_node_target_type: str,
    results_type: str,
    index_within_arg: int,
    index_of_arg: int,
    fqn: Optional[str],
) -> Node:
    """
    Given a starting graph of

      prev_node -> node -> next_node

    This function creates a new logger_cls obj and adds it
    after node, resulting in

      prev_node -> node -> logger_obj -> next_node
    """
    # create new name
    logger_node_name = \
        get_new_attr_name_with_prefix(node.name + logger_node_name_suffix)(gm)
    target_type = get_target_type_str(node, gm)
    # create the logger object
    logger_obj = logger_cls(
        ref_node_name, node.name, model_name, ref_name, target_type,
        ref_node_target_type,
        results_type, index_within_arg, index_of_arg, fqn)
    # attach the logger object to the parent module
    setattr(gm, logger_node_name, logger_obj)
    logger_node = node.graph.create_node(
        'call_module', logger_node_name, (node,), {})
    return logger_node
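
# A minimal sketch of a compatible `logger_cls`, assuming only the constructor
# and forward signatures implied by the call site above. The real loggers
# (e.g. `OutputLogger` in `torch.ao.ns._numeric_suite_fx`) record richer
# metadata; this hypothetical class is illustrative and not used by this file.
class _SketchLogger(torch.nn.Module):
    def __init__(self, ref_node_name, prev_node_name, model_name, ref_name,
                 prev_node_target_type, ref_node_target_type, results_type,
                 index_within_arg, index_of_arg, fqn):
        super().__init__()
        # `ref_node_name` may be filled in later, see the input logger
        # fixup in `create_a_shadows_b` below
        self.ref_node_name = ref_node_name
        self.results_type = results_type
        self.stats: List[torch.Tensor] = []

    def forward(self, x):
        # record a detached copy and return the input unchanged, so the
        # inserted `call_module` node is a pass-through in the graph
        self.stats.append(x.detach())
        return x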

def add_loggers_to_model(
    gm: GraphModule,
    node_to_instrument_inputs_to_ref_node_name: Dict[Node, Tuple[str, str]],
    node_to_instrument_outputs_to_ref_node_name: Dict[Node, Tuple[str, str]],
    logger_cls: Callable,
    model_name: str,
) -> GraphModule:
    """
    Takes the graph of gm and adds loggers to the inputs and/or outputs of
    each node present in the two instrumentation dicts. Returns a GraphModule
    with the new graph.
    """

    new_graph = Graph()
    env: Dict[str, Any] = {}
    modules = dict(gm.named_modules())

    def load_arg(a):
        return map_arg(a, lambda node: env[node.name])

    for node in gm.graph.nodes:
        if node.op == 'output':
            new_graph.output(map_arg(get_normalized_nth_input(node, gm, 0), load_arg))
            continue

        if (
            (node in node_to_instrument_inputs_to_ref_node_name) or
            (node in node_to_instrument_outputs_to_ref_node_name)
        ):
            fqn = _maybe_get_fqn(node, gm)

            if node in node_to_instrument_inputs_to_ref_node_name:
                ref_name, ref_node_type = node_to_instrument_inputs_to_ref_node_name[node]
                # Ops such as add and mul are special because either
                # one or two of the first two arguments can be tensors,
                # and if one argument is a tensor it can be first or
                # second (x + 1 versus 1 + x).
                arg_indices_to_log = get_arg_indices_of_inputs_to_log(node)
                for node_arg_idx in arg_indices_to_log:
                    node_arg = get_normalized_nth_input(node, gm, node_arg_idx)
                    if type(node_arg) == Node:
                        # create a single input logger
                        prev_node = env[node_arg.name]
                        env[node_arg.name] = _insert_logger_after_node(
                            prev_node, gm, logger_cls, '_ns_logger_', node.name,
                            model_name, ref_name, ref_node_type,
                            NSSingleResultValuesType.NODE_INPUT.value,
                            index_within_arg=0, index_of_arg=node_arg_idx,
                            fqn=fqn)
                    elif type(node_arg) == torch.fx.immutable_collections.immutable_list:
                        # create N input loggers, one for each node
                        for arg_idx, arg in enumerate(node_arg):  # type: ignore[var-annotated, arg-type]
                            prev_node = env[arg.name]
                            env[prev_node.name] = _insert_logger_after_node(
                                prev_node, gm, logger_cls, '_ns_logger_', node.name,
                                model_name, ref_name, ref_node_type,
                                NSSingleResultValuesType.NODE_INPUT.value,
                                index_within_arg=arg_idx, index_of_arg=node_arg_idx,
                                fqn=fqn)
                    else:
                        pass

            # ensure env is populated with base node
            # Note: runs for both inputs and outputs
            env[node.name] = new_graph.node_copy(node, load_arg)

            if node in node_to_instrument_outputs_to_ref_node_name:
                ref_name, ref_node_type = node_to_instrument_outputs_to_ref_node_name[node]
                # add the logger after the base node
                env[node.name] = _insert_logger_after_node(
                    env[node.name], gm, logger_cls, '_ns_logger_', node.name,
                    model_name, ref_name, ref_node_type,
                    NSSingleResultValuesType.NODE_OUTPUT.value,
                    index_within_arg=0, index_of_arg=0, fqn=fqn)

        else:
            env[node.name] = new_graph.node_copy(node, load_arg)

    new_gm = GraphModule(gm, new_graph)
    return new_gm
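
# An illustrative, hypothetical driver for the pass above. In practice this
# pass is reached through the higher-level APIs in
# `torch.ao.ns._numeric_suite_fx` (e.g. `add_loggers`) rather than called
# directly.
def _example_add_loggers_usage() -> GraphModule:
    mod = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.ReLU())
    gm = torch.fx.symbolic_trace(mod)
    # log the output of every call_module node, using the node's own name
    # as the ref_name
    node_to_instrument_outputs = {
        node: (node.name, get_target_type_str(node, gm))
        for node in gm.graph.nodes
        if node.op == 'call_module'
    }
    return add_loggers_to_model(
        gm, {}, node_to_instrument_outputs, _SketchLogger,
        model_name='model_a')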

def _insert_quantize_per_tensor_node(
    prev_node_c: Node,
    node_a: Node,
    gm_b: GraphModule,
    graph_c: Graph,
    scale: Union[torch.Tensor, float],
    zero_point: Union[torch.Tensor, int],
    dtype_cast_name: str,
) -> Node:
    # copy scale
    scale_node_name = \
        get_new_attr_name_with_prefix(
            node_a.name + '_input_scale_')(gm_b)
    setattr(gm_b, scale_node_name, scale)
    scale_node = graph_c.create_node(
        'get_attr', scale_node_name, (), {}, scale_node_name)
    # copy zero_point
    zero_point_node_name = \
        get_new_attr_name_with_prefix(
            node_a.name + '_input_zero_point_')(gm_b)
    setattr(gm_b, zero_point_node_name, zero_point)
    zero_point_node = graph_c.create_node(
        'get_attr', zero_point_node_name, (), {}, zero_point_node_name)
    # create the quantize_per_tensor call
    return graph_c.create_node(
        'call_function', torch.quantize_per_tensor,
        (prev_node_c, scale_node, zero_point_node, torch.quint8), {},
        dtype_cast_name)
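
# In eager terms, the `call_function` node created above computes the
# equivalent of the following sketch (with made-up qparams):
#
#   x = torch.randn(2, 2)
#   xq = torch.quantize_per_tensor(x, scale=0.1, zero_point=128, dtype=torch.quint8)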

def _insert_dtype_cast_after_node(
    node_a: Node,
    node_c: Node,
    prev_node_c: Union[Node, List[Node]],
    gm_a: GraphModule,
    gm_b: GraphModule,
    graph_c: Graph,
    node_name_prefix: str,
    logger_cls: Callable,
    node_type_to_io_type_map: Dict[str, Set[NSNodeTargetType]],
) -> Union[Node, List[Node]]:
    """
    Given a starting graph C (derived from graph B) of

    ... -> prev_node_c -> node_c -> ...

    And a corresponding related node_a, inserts the correct dtype
    cast node after prev_node_c to cast into the dtype expected
    by node_a, resulting in:

                          dtype_cast
                        /
    ... -> prev_node_c -> node_c -> ...

    For example, if node_c is an int8 op and node_a is an fp32 op, this function
    will insert a dequant.
    """
    dtype_cast_op = None
    dtype_cast_mod_cls = None
    dtype_cast_method = None
    dtype_cast_method_dtype = None
    dtype_cast_scale = None
    dtype_cast_zero_point = None
    node_input_type_a, _node_output_type_a = \
        get_node_first_input_and_output_type(
            node_a, gm_a, logger_cls, node_type_to_io_type_map)
    node_input_type_c, _node_output_type_c = \
        get_node_first_input_and_output_type(
            node_c, gm_b, logger_cls, node_type_to_io_type_map)

    if (
        (node_input_type_a == NodeInputOrOutputType.FP32 and
         node_input_type_c == NodeInputOrOutputType.INT8) or
        (node_input_type_a == NodeInputOrOutputType.FP32 and
         node_input_type_c == NodeInputOrOutputType.FP16) or
        # TODO(future PR): determine the actual dtype of node_c,
        # the current code only works because dequantize works with
        # multiple input dtypes.
        (node_input_type_a == NodeInputOrOutputType.FP32 and
         node_input_type_c == NodeInputOrOutputType.FP32_OR_INT8)
    ):
        dtype_cast_op = torch.dequantize
    elif (
        node_input_type_a == node_input_type_c and
        node_input_type_a != NodeInputOrOutputType.UNKNOWN
    ):
        dtype_cast_mod_cls = torch.nn.Identity
    elif (
        node_input_type_a == NodeInputOrOutputType.INT8 and
        node_input_type_c == NodeInputOrOutputType.FP32
    ):
        # int8 shadows fp32, the dtype cast needs to quantize to int8
        # with the right qparams.
        node_a_input_qparams = get_node_input_qparams(
            node_a, gm_a, node_type_to_io_type_map)
        if node_a_input_qparams is not None:
            dtype_cast_op = torch.quantize_per_tensor  # type: ignore[assignment]
            dtype_cast_scale, dtype_cast_zero_point = node_a_input_qparams
    elif (
        node_input_type_a == NodeInputOrOutputType.FP16 and
        node_input_type_c == NodeInputOrOutputType.FP32
    ):
        dtype_cast_method = 'to'
        dtype_cast_method_dtype = torch.float16
    else:
        raise AssertionError(
            f"dtype cast from {node_input_type_c} {node_c.format_node()} to " +
            f"{node_input_type_a} {node_a.format_node()} needs to be implemented")

    if isinstance(prev_node_c, Node):
        new_dtype_cast_name = \
            get_new_attr_name_with_prefix(node_name_prefix)(gm_b)
        if dtype_cast_op:
            if dtype_cast_scale is not None and dtype_cast_zero_point is not None:
                return _insert_quantize_per_tensor_node(
                    prev_node_c, node_a, gm_b, graph_c, dtype_cast_scale,
                    dtype_cast_zero_point, new_dtype_cast_name)
            else:
                return graph_c.create_node(
                    'call_function', dtype_cast_op, (prev_node_c,), {},
                    new_dtype_cast_name)
        elif dtype_cast_method:
            return graph_c.create_node(
                'call_method', dtype_cast_method,
                (prev_node_c, dtype_cast_method_dtype), {}, new_dtype_cast_name)
        else:
            assert dtype_cast_mod_cls
            dtype_cast_mod = dtype_cast_mod_cls()
            setattr(gm_b, new_dtype_cast_name, dtype_cast_mod)
            return graph_c.create_node(
                'call_module', new_dtype_cast_name, (prev_node_c,), {},
                new_dtype_cast_name)
    elif isinstance(prev_node_c, list):
        results = []
        for prev_node_c_inner in prev_node_c:
            new_dtype_cast_name = \
                get_new_attr_name_with_prefix(node_name_prefix)(gm_b)
            if dtype_cast_op:
                # TODO(future PR): add handling for quantize_per_tensor
                new_dtype_cast_node = graph_c.create_node(
                    'call_function', dtype_cast_op, (prev_node_c_inner,), {},
                    new_dtype_cast_name)
                results.append(new_dtype_cast_node)
            else:
                assert dtype_cast_mod_cls
                dtype_cast_mod = dtype_cast_mod_cls()
                setattr(gm_b, new_dtype_cast_name, dtype_cast_mod)
                new_dtype_cast_node = graph_c.create_node(
                    'call_module', new_dtype_cast_name, (prev_node_c_inner,), {},
                    new_dtype_cast_name)
                results.append(new_dtype_cast_node)
        return results
    else:
        raise AssertionError(f"type {type(prev_node_c)} is not handled")
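
# For the common case above where fp32 shadows int8, the inserted cast
# reduces, in eager terms, to a plain dequantize (a sketch with made-up
# qparams):
#
#   xq = torch.quantize_per_tensor(torch.randn(2, 2), 0.1, 128, torch.quint8)
#   x_fp32 = torch.dequantize(xq)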

# TODO(future PR): look into using copy_node API instead
def _copy_node_from_a_to_c(
    node_a: Node,
    gm_a: GraphModule,
    gm_b: GraphModule,
    graph_c: Graph,
) -> Node:
    """
    Simple copy of node_a to graph_c.
    """
    if node_a.op == 'get_attr':
        node_a_copy_name = \
            get_new_attr_name_with_prefix(node_a.name + '_shadow_copy_')(gm_b)
        node_a_obj = getattr_from_fqn(gm_a, node_a.target)  # type: ignore[arg-type]
        if torch.is_tensor(node_a_obj):
            node_a_obj = node_a_obj.detach()
        setattr(gm_b, node_a_copy_name, node_a_obj)
        node_a_copy = graph_c.create_node(
            node_a.op, node_a_copy_name, (), {}, node_a_copy_name)
        return node_a_copy
    elif node_a.op == 'call_method':
        assert node_a.target in ('dequantize', 'to'), \
            f"target {node_a.target} is not implemented"
        if node_a.target == 'dequantize':
            arg_copy = _copy_node_from_a_to_c(
                get_normalized_nth_input(node_a, gm_a, 0),
                gm_a, gm_b, graph_c)  # type: ignore[arg-type]
            node_a_copy_name = \
                get_new_attr_name_with_prefix(node_a.name + '_shadow_copy_')(gm_b)
            node_a_copy = graph_c.create_node(
                node_a.op, node_a.target, (arg_copy,), {}, node_a_copy_name)
            return node_a_copy
        else:  # to
            arg_copy = _copy_node_from_a_to_c(
                get_normalized_nth_input(node_a, gm_a, 0), gm_a, gm_b, graph_c)  # type: ignore[arg-type]
            node_a_copy_name = \
                get_new_attr_name_with_prefix(node_a.name + '_shadow_copy_')(gm_b)
            node_a_copy = graph_c.create_node(
                node_a.op, node_a.target,
                (arg_copy, get_normalized_nth_input(node_a, gm_a, 1)),
                {}, node_a_copy_name)
            return node_a_copy
    else:
        raise AssertionError(
            f"handling of node {node_a.format_node()} with op {node_a.op} is not implemented")

def _can_insert_copy_of_subgraph_a(
    subgraph_a: NSSubgraph,
    gm_a: GraphModule,
    num_non_param_args_node_a: int,
) -> bool:
    """
    This function returns `False` if the input subgraph cannot be copied by
    `_insert_copy_of_subgraph_a_after_input_node_c`. This usually means
    that there is a corner case logic for which copy is not yet implemented.
    """
    # populate the list of nodes we need to check
    nodes = []
    cur_node = subgraph_a.end_node
    while cur_node != subgraph_a.start_node:
        nodes.append(cur_node)
        cur_node = get_normalized_nth_input(cur_node, gm_a, 0)  # type: ignore[assignment]
    nodes.append(cur_node)
    nodes.reverse()

    def _can_insert(node_a_arg, gm_a):
        if isinstance(node_a_arg, Node):
            arg_a = return_first_non_observer_node(node_a_arg, gm_a)
            if arg_a.op == 'call_method':
                return arg_a.target in ('dequantize', 'to')
            elif arg_a.op == 'get_attr':
                return True
            else:
                return False
        elif isinstance(node_a_arg, (list, tuple)):
            for el in node_a_arg:
                if not isinstance(el, Node):
                    return False
            return True

    # For each node, check if we handle the copy behavior. This follows the
    # logic in `_insert_copy_of_subgraph_a_after_input_node_c`.
    for node_a in nodes:

        local_num_non_param_args_node_a = num_non_param_args_node_a \
            if node_a is nodes[0] else 1

        norm_args_kwargs = node_a.normalized_arguments(
            gm_a, normalize_to_only_use_kwargs=True)
        if norm_args_kwargs is not None:
            norm_args, norm_kwargs = norm_args_kwargs
        else:
            norm_args, norm_kwargs = node_a.args, node_a.kwargs

        cur_idx = 0

        while cur_idx < len(norm_args):
            if cur_idx == 0:
                pass
            elif cur_idx == 1 and local_num_non_param_args_node_a == 2:
                pass
            else:
                if not _can_insert(norm_args[cur_idx], gm_a):
                    return False
            cur_idx += 1

        for kwarg_name, kwarg_val in norm_kwargs.items():
            # stitch the inputs from base graph
            if cur_idx == 0:
                pass
            elif cur_idx == 1 and local_num_non_param_args_node_a == 2:
                pass
            else:
                if not _can_insert(kwarg_val, gm_a):
                    return False
            cur_idx += 1

    return True

def _insert_copy_of_subgraph_a_after_input_node_c(
    input_node_c: Union[Node, List[Node]],
    input_node_c_2: Optional[Union[Node, List[Node]]],
    subgraph_a: NSSubgraph,
    gm_a: GraphModule,
    gm_b: GraphModule,
    node_name_prefix: str,
) -> Node:
    """
    Inserts a copy of the nodes of subgraph_a (from gm_a) into the graph of C,
    starting at input_node_c (and input_node_c_2, if provided), and returns
    the last copied node.
    """
    if isinstance(input_node_c, Node):
        graph_c = input_node_c.graph
    else:
        assert isinstance(input_node_c, list)
        graph_c = input_node_c[0].graph

    # create a sequential list of the subgraph's nodes from start to end,
    # because we need to add the nodes to graph C in non-reverse order
    nodes_of_a = [subgraph_a.end_node]
    cur_node = subgraph_a.end_node
    while cur_node != subgraph_a.start_node:
        cur_node = get_normalized_nth_input(cur_node, gm_a, 0)  # type: ignore[assignment]
        nodes_of_a.insert(0, cur_node)

    # go through nodes of a in order, and insert them into the graph of c
    # sequentially
    cur_node_a = nodes_of_a[0]
    cur_node_c = _insert_copy_of_node_a_after_input_node_c(
        input_node_c,
        input_node_c_2,
        cur_node_a,
        gm_a,
        gm_b,
        node_name_prefix)
    for cur_idx_a in range(1, len(nodes_of_a)):
        cur_node_a = nodes_of_a[cur_idx_a]
        prev_node_c = cur_node_c  # previous added node is the input to next node
        cur_node_c = _insert_copy_of_node_a_after_input_node_c(
            prev_node_c,
            # TODO(future PR): enable multiple inputs for nodes which are not at the start of the subgraph
            None,
            cur_node_a,
            gm_a,
            gm_b,
            node_name_prefix)
    # return the last inserted node
    return cur_node_c

def _insert_copy_of_node_a_after_input_node_c(
    input_node_c: Union[Node, List[Node]],
    input_node_c_2: Optional[Union[Node, List[Node]]],
    node_a: Node,
    gm_a: GraphModule,
    gm_b: GraphModule,
    node_name_prefix: str,
) -> Node:
    """
    Assume that node_a from graph_a has
      args (input, (input2)?, arg1, ...), and
      kwargs {kw0: kwarg0, ...}

    Note: input2 is optional. If it is None, we assume that the op
    has a single non-param input. If it is specified, we assume that the op
    has two non-param inputs.

    Copies the underlying values of arg1..argn and kwarg0..kwargn into gm_b,
    and creates the corresponding nodes in graph_c. Note: observers are ignored,
    so if an arg is an observer we navigate up until we find a non-observer parent.

    If node_a is a call_module, points the module pointed to by node_a to gm_b.

    Creates the copy of node_a in graph_c, with input as the first arg,
    and all other args and kwargs pointing to the copies of the objects
    in gm_b created above.

    An example in pictures:

    graph A:
    ========

    input --------------> node_a
                         /  /  /
    (input_2)? ---------/  /  /
                          /  /
    weight -> weight_obs -/  /
                            /
    bias ------------------/

    graph C (derived from B):
    =========================

    input_node_c ------> node_a_copy
                        /  /  /
    (input_node_c_2)? -/  /  /
                         /  /
    weight_copy --------/  /
                          /
    bias_copy -----------/
    """
    if isinstance(input_node_c, Node):
        graph_c = input_node_c.graph
    else:
        assert isinstance(input_node_c, list)
        graph_c = input_node_c[0].graph

    norm_args_kwargs = node_a.normalized_arguments(
        gm_a, normalize_to_only_use_kwargs=True)
    if norm_args_kwargs is not None:
        norm_args, norm_kwargs = norm_args_kwargs
    else:
        norm_args, norm_kwargs = node_a.args, node_a.kwargs

    new_args = []
    new_kwargs = {}

    def _copy_arg(arg):
        # copy the other inputs from the other graph
        if isinstance(arg, Node):
            arg = return_first_non_observer_node(arg, gm_a)
            arg = _copy_node_from_a_to_c(arg, gm_a, gm_b, graph_c)
            return arg
        elif isinstance(arg, (int, float, torch.dtype)):
            return arg
        elif isinstance(arg, (list, tuple)):
            for el in arg:
                assert not isinstance(el, Node), \
                    "handling of Node inside list is not implemented"
            return arg
        else:
            raise AssertionError(
                f"handling for arg of type {type(arg)} is not implemented")

    cur_idx = 0

    while cur_idx < len(norm_args):
        if cur_idx == 0:
            new_arg = input_node_c
        elif cur_idx == 1 and input_node_c_2 is not None:
            new_arg = input_node_c_2
        else:
            new_arg = _copy_arg(norm_args[cur_idx])
        new_args.append(new_arg)
        cur_idx += 1

    for kwarg_name, kwarg_val in norm_kwargs.items():
        # stitch the inputs from base graph
        if cur_idx == 0:
            new_kwargs[kwarg_name] = input_node_c
        elif cur_idx == 1 and input_node_c_2 is not None:
            new_kwargs[kwarg_name] = input_node_c_2
        else:
            new_kwargs[kwarg_name] = _copy_arg(kwarg_val)
        cur_idx += 1

    new_args = tuple(new_args)  # type: ignore[assignment]

    node_a_shadows_c_name = \
        get_new_attr_name_with_prefix(node_name_prefix)(gm_b)

    if node_a.op == 'call_module':
        # if target is a module, we point to the module from gm_b
        new_mod_copy_name = \
            get_new_attr_name_with_prefix(node_name_prefix)(gm_b)
        # fetch the corresponding module from gm_a
        assert isinstance(node_a.target, str)
        mod_a = getattr_from_fqn(gm_a, node_a.target)
        setattr(gm_b, new_mod_copy_name, mod_a)
        node_a_shadows_c = graph_c.create_node(
            node_a.op, new_mod_copy_name, new_args,
            new_kwargs, node_a_shadows_c_name)
        return node_a_shadows_c
    else:
        assert node_a.op in ('call_function', 'call_method')
        node_a_shadows_c = graph_c.create_node(
            node_a.op, node_a.target, new_args,
            new_kwargs, node_a_shadows_c_name)
        return node_a_shadows_c
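
# Note on `Node.normalized_arguments(..., normalize_to_only_use_kwargs=True)`,
# used above and in `_can_insert_copy_of_subgraph_a`: when the target's
# signature is known, it returns an `ArgsKwargsPair` with every argument moved
# into kwargs; when it is not, it returns None, which is why both call sites
# fall back to the raw `node.args` / `node.kwargs`. For example (a sketch), a
# `torch.add(x, y)` node would normalize to kwargs of the form
#   {'input': x, 'other': y}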

def create_a_shadows_b(
    name_a: str,
    gm_a: GraphModule,
    name_b: str,
    gm_b: GraphModule,
    matched_subgraph_pairs: Dict[str, Tuple[NSSubgraph, NSSubgraph]],
    logger_cls: Callable,
    should_log_inputs: bool,
    node_type_to_io_type_map: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
) -> GraphModule:
    """
    Creates a new GraphModule consisting of the graph of C, with the meaningful
    nodes of A shadowing the corresponding nodes of B. For example,

    Graph A:
    a0 -> op0_fp32 -> a1 -> op1_fp32 -> a2

    Graph B:
    b0 -> op0_int8 -> b1 -> op1_int8 -> b2

    matched_subgraph_pairs: {'op0': (op0_fp32, op0_int8), 'op1': (op1_fp32, op1_int8)}

    Graph C (A shadows B):

       / dequant0 -> op0_fp32 -> logger_a_0  / dequant_1 -> op1_fp32 -> logger_a_1
      /                                     /
    b0 -------------> op0_int8 -> logger_b_0 --------------> op1_int8 -> logger_b_1

    In a nutshell, this function does the following for each node pair:
    * copies the necessary attributes and modules from gm_a to gm_b,
      keeping names unique
    * adds a dtype cast op (dequant, quant, etc)
    * adds a copy of node_a in gm_b's graph
    * adds loggers to the outputs of node_a and node_b
    """

    if node_type_to_io_type_map is None:
        node_type_to_io_type_map = get_node_type_to_io_type_map()

    # graph_c is the graph created from copying the nodes of graph_b and inserting
    # the shadows with the nodes copied from graph_a
    graph_c = Graph()
    env_c: Dict[str, Any] = {}
    modules = dict(gm_b.named_modules())

    def load_arg(a):
        return map_arg(a, lambda node: env_c[node.name])

    start_node_b_to_matched_subgraph_a_and_name = {}
    end_node_b_to_matched_subgraph_a_and_name = {}
    for match_name, match in matched_subgraph_pairs.items():
        subgraph_a, subgraph_b = match
        ref_node_type_a = get_target_type_str(subgraph_a.base_op_node, gm_a)
        ref_node_type_b = get_target_type_str(subgraph_b.base_op_node, gm_b)
        start_node_b_to_matched_subgraph_a_and_name[subgraph_b.start_node] = \
            (subgraph_a, match_name, ref_node_type_a, ref_node_type_b)
        end_node_b_to_matched_subgraph_a_and_name[subgraph_b.end_node] = \
            (subgraph_a, match_name, ref_node_type_a, ref_node_type_b)

    for node_b in gm_b.graph.nodes:
        if node_b.op == 'output':
            graph_c.output(map_arg(node_b.args[0], load_arg))
            continue

        # calculate the flags to determine what to do with this node
        node_b_is_start_node = node_b in start_node_b_to_matched_subgraph_a_and_name
        node_b_is_end_node = node_b in end_node_b_to_matched_subgraph_a_and_name

        if (node_b_is_start_node or node_b_is_end_node):

            if node_b_is_start_node:
                subgraph_a, ref_name, ref_node_type_a, ref_node_type_b = \
                    start_node_b_to_matched_subgraph_a_and_name[node_b]
            else:
                assert node_b_is_end_node
                subgraph_a, ref_name, ref_node_type_a, ref_node_type_b = \
                    end_node_b_to_matched_subgraph_a_and_name[node_b]

            all_op_types_support_shadowing = (
                op_type_supports_shadowing(subgraph_a.start_node) and
                op_type_supports_shadowing(node_b)
            )
            if not all_op_types_support_shadowing:
                print(
                    f'skipping shadow loggers for node_b: {get_target_type_str(node_b, gm_b)}' +
                    f', start_node_a: {get_target_type_str(subgraph_a.start_node, gm_a)}' +
                    ', unsupported')
                env_c[node_b.name] = graph_c.node_copy(node_b, load_arg)
                continue

            # For both start_node and end_node verify that we know how to do
            # the dtype cast. If we do not, skip.
            node_input_type_a, node_output_type_a = \
                get_node_first_input_and_output_type(
                    subgraph_a.start_node, gm_a, logger_cls,
                    node_type_to_io_type_map)
            node_input_type_b, node_output_type_b = \
                get_node_first_input_and_output_type(
                    node_b, gm_b, logger_cls,
                    node_type_to_io_type_map)
            node_io_types_known_a_and_b = (
                node_input_type_a != NodeInputOrOutputType.UNKNOWN and
                node_output_type_a != NodeInputOrOutputType.UNKNOWN and
                node_input_type_b != NodeInputOrOutputType.UNKNOWN and
                node_output_type_b != NodeInputOrOutputType.UNKNOWN
            )
            if not node_io_types_known_a_and_b:
                print(
                    f'skipping shadow loggers for node_b: {get_target_type_str(node_b, gm_b)}' +
                    f', start_node_a: {get_target_type_str(subgraph_a.start_node, gm_a)}' +
                    ', unknown dtype cast')
                env_c[node_b.name] = graph_c.node_copy(node_b, load_arg)
                continue

            # If we are shadowing from fp32 to int8, we need to insert
            # a quantize_per_tensor call with qparams from the previous node.
            # Only do this if we are able to infer these qparams from the graph.
            if (
                node_input_type_a == NodeInputOrOutputType.INT8 and
                node_input_type_b == NodeInputOrOutputType.FP32
            ):
                node_a_input_qparams = get_node_input_qparams(
                    subgraph_a.start_node, gm_a, node_type_to_io_type_map)
                if not node_a_input_qparams:
                    print(
                        f'skipping shadow loggers for node_b: {get_target_type_str(node_b, gm_b)}' +
                        f', start_node_a: {get_target_type_str(subgraph_a.start_node, gm_a)}' +
                        ', unknown input qparams')
                    env_c[node_b.name] = graph_c.node_copy(node_b, load_arg)
                    continue

            num_non_param_args_node_a = \
                get_number_of_non_param_args(subgraph_a.start_node, gm_a)
            if not _can_insert_copy_of_subgraph_a(subgraph_a, gm_a, num_non_param_args_node_a):
                print(
                    f'skipping shadow loggers for node_b: {get_target_type_str(node_b, gm_b)}' +
                    f', start_node_a: {get_target_type_str(subgraph_a.start_node, gm_a)}' +
                    ', unhandled logic in subgraph copy')
                env_c[node_b.name] = graph_c.node_copy(node_b, load_arg)
                continue

            fqn_base_a = _maybe_get_fqn(subgraph_a.base_op_node, gm_a)
            fqn_base_b = _maybe_get_fqn(subgraph_b.base_op_node, gm_b)

            if node_b_is_start_node:

                # if necessary, log the input of node_c
                if should_log_inputs:
                    prev_node_b = get_normalized_nth_input(node_b, gm_b, 0)
                    if isinstance(prev_node_b, Node):
                        prev_node_c = env_c[prev_node_b.name]
                        env_c[prev_node_c.name] = _insert_logger_after_node(
                            prev_node_c, gm_b, logger_cls, '_ns_logger_b_inp_',
                            node_b.name, name_b, ref_name, ref_node_type_b,
                            NSSingleResultValuesType.NODE_INPUT.value,
                            index_within_arg=0, index_of_arg=0,
                            fqn=fqn_base_b)
                    elif isinstance(prev_node_b, list):
                        # first, save the prev_node instances, because they
                        # will be overwritten in the env after the first logger
                        # is added
                        prev_node_c_list = [env_c[arg.name] for arg in prev_node_b]

                        for arg_idx, arg in enumerate(prev_node_b):
                            prev_node_c = prev_node_c_list[arg_idx]
                            env_c[prev_node_c.name] = _insert_logger_after_node(
                                prev_node_c, gm_b, logger_cls, '_ns_logger_b_inp_',
                                node_b.name, name_b, ref_name, ref_node_type_b,
                                NSSingleResultValuesType.NODE_INPUT.value,
                                index_within_arg=arg_idx, index_of_arg=0,
                                fqn=fqn_base_b)
                    else:
                        # logging of inputs which are not Nodes or lists is not
                        # supported yet
                        raise AssertionError(f"type {type(prev_node_b)} is not handled yet")

                # subgraph so far:
                #
                # (prev_node_c)+ -> (logger_c_input)?

            # Note: this if statement is always True, spelling it out to
            # clarify code intent.
            if node_b_is_start_node or node_b_is_end_node:
                # ensure env_c is populated with base node
                env_c[node_b.name] = graph_c.node_copy(node_b, load_arg)
                node_c = env_c[node_b.name]

                # after this point,
                #
                # node_a is the original node from graph_a, with parent module gm_a
                # node_b is the original node from graph_b, with parent module gm_b
                # node_c is the copy of node_b in graph_c
                #
                # subgraph so far:
                #
                # (prev_node_c)+ -> (logger_c_input)? -> node_start_c

                if node_b_is_start_node:

                    # cast dtype from the dtype of node_c's input to the dtype of
                    # node_a's input (dequant, etc)
                    # prev_node_c = node_c.args[0]
                    prev_node_c = get_normalized_nth_input(node_c, gm_b, 0)
                    if should_log_inputs:
                        # skip the input logger when inserting a dtype cast
                        if isinstance(prev_node_c, Node):
                            prev_node_c = get_normalized_nth_input(node_c, gm_b, 0)
                        elif isinstance(prev_node_c, list):
                            prev_node_c = [get_normalized_nth_input(arg, gm_b, 0) for arg in prev_node_c]
                    dtype_cast_node = _insert_dtype_cast_after_node(
                        subgraph_a.start_node, node_c, prev_node_c, gm_a, gm_b, graph_c,
                        node_b.name + '_dtype_cast_', logger_cls,
                        node_type_to_io_type_map)
                    # note: not inserting to env_c because all nodes which use the dtype
                    # casts are copied from graph_a
                    #
                    # subgraph so far:
                    #
                    #           (dtype_cast_node)+
                    #          /
                    # (prev_node_c)+ -> (logger_c_input)? -> node_start_c

                    # if input logging is enabled, log the input to the subgraph
                    if should_log_inputs:
                        # TODO: explain this
                        ref_node_name = ''
                        if isinstance(dtype_cast_node, Node):
                            dtype_cast_node = _insert_logger_after_node(
                                dtype_cast_node, gm_b, logger_cls, '_ns_logger_a_inp_',
                                ref_node_name, name_a, ref_name, ref_node_type_a,
                                NSSingleResultValuesType.NODE_INPUT.value,
                                index_within_arg=0, index_of_arg=0,
                                fqn=fqn_base_a)
                            input_logger: Union[Node, List[Node]] = dtype_cast_node
                        else:
                            assert isinstance(dtype_cast_node, list)
                            new_loggers = []
                            for dtype_cast_idx, dtype_cast_node_inner in enumerate(dtype_cast_node):
                                dtype_cast_logger = _insert_logger_after_node(
                                    dtype_cast_node_inner, gm_b, logger_cls, '_ns_logger_a_inp_',
                                    ref_node_name, name_a, ref_name, ref_node_type_a,
                                    NSSingleResultValuesType.NODE_INPUT.value,
                                    index_within_arg=dtype_cast_idx,
                                    index_of_arg=0,
                                    fqn=fqn_base_a)
                                new_loggers.append(dtype_cast_logger)
                            dtype_cast_node = new_loggers
                            input_logger = dtype_cast_node
                        # subgraph so far:
                        #
                        #       (dtype_cast_node)+ -> (logger_a_input)?
                        #      /
                        # prev_node_c -> (logger_c_input)? -> node_start_c

                    # hook up the new mod_a copy to be in the graph, receiving the
                    # same inputs as mod_b does, with dtype cast to match a.
                    # Some ops, such as LSTMs, have two non-param inputs. If we have
                    # such an op, pass the second param as well. Note: dtype casting
                    # for the second param is not implemented yet, it can be added
                    # later if there is a use case.
                    node_c_second_non_param_arg = None
                    num_non_param_args_node_a = get_number_of_non_param_args(subgraph_a.start_node, gm_a)
                    if num_non_param_args_node_a == 2:
                        # node_c_second_non_param_arg = node_c.args[1]
                        node_c_second_non_param_arg = get_normalized_nth_input(node_c, gm_b, 1)
                    node_a_shadows_c = _insert_copy_of_subgraph_a_after_input_node_c(
                        dtype_cast_node, node_c_second_non_param_arg,
                        subgraph_a, gm_a, gm_b, node_c.name + '_shadow_copy_')
                    env_c[node_a_shadows_c.name] = node_a_shadows_c
                    # subgraph so far:
                    #
                    #       dtype_cast_node -> (logger_a_input)? -> subgraph_a_copy(args/kwargs not shown)
                    #      /
                    # (prev_node_c)+ -> (logger_c_input)? -> node_start_c

                    if should_log_inputs:
                        # When we created the input logger, we left the ref_node_name
                        # as an empty string, because the subgraph copy did not exist
                        # yet. Now that the subgraph copy exists, we modify this name
                        # to its true value.
                        # Note: the alternative to this is to create the input logger
                        # after creating the subgraph, which is slightly more
                        # complicated. This is the lesser of two evils.
                        # input_logger = env_c[dtype_cast_node.name]
                        # Find the first node in the subgraph
                        cur_node = node_a_shadows_c
                        while get_normalized_nth_input(cur_node, gm_b, 0) != input_logger:
                            cur_node = get_normalized_nth_input(cur_node, gm_b, 0)  # type: ignore[assignment]
                        if isinstance(input_logger, Node):
                            input_logger_mod = getattr(gm_b, input_logger.name)
                            input_logger_mod.ref_node_name = cur_node.name
                        else:
                            assert isinstance(input_logger, list)
                            for input_logger_inner in input_logger:
                                input_logger_mod = getattr(gm_b, input_logger_inner.name)
                                input_logger_mod.ref_node_name = cur_node.name

                    # hook up a logger to the mod_a copy
                    env_c[node_a_shadows_c.name] = _insert_logger_after_node(
                        env_c[node_a_shadows_c.name], gm_b, logger_cls, '_ns_logger_a_',
                        node_a_shadows_c.name, name_a, ref_name, ref_node_type_a,
                        NSSingleResultValuesType.NODE_OUTPUT.value,
                        index_within_arg=0, index_of_arg=0,
                        fqn=fqn_base_a)
                    # subgraph so far:
                    #
                    #       dtype_cast_node -> (logger_a_input)? -> subgraph_a_copy -> logger_a
                    #      /
                    # (prev_node_c)+ -> (logger_c_input)? -> node_start_c

                if node_b_is_end_node:

                    # hook up a logger to the mod_b copy
                    env_c[node_b.name] = _insert_logger_after_node(
                        env_c[node_b.name], gm_b, logger_cls, '_ns_logger_b_',
                        node_b.name, name_b, ref_name, ref_node_type_b,
                        NSSingleResultValuesType.NODE_OUTPUT.value,
                        index_within_arg=0, index_of_arg=0,
                        fqn=fqn_base_b)
                    # subgraph so far:
                    #
                    #       dtype_cast_node -> (logger_a_input)? -> subgraph_a_copy -> logger_a
                    #      /
                    # (prev_node_c+) -> (logger_c_input)? -> node_start_c -> ... -> node_end_c -> logger_c
                    #
                    # Note: node_start_c may be the same node as node_end_c, or they
                    # may have nodes in between.

        else:
            env_c[node_b.name] = graph_c.node_copy(node_b, load_arg)

    gm_c = GraphModule(gm_b, graph_c)
    return gm_c
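
# An illustrative, hypothetical end-to-end driver for this pass. In practice,
# `add_shadow_loggers` in `torch.ao.ns._numeric_suite_fx` computes
# `matched_subgraph_pairs` with the graph matcher and then calls
# `create_a_shadows_b`; this sketch assumes the matched pairs are already
# available.
def _example_create_a_shadows_b_usage(
    gm_a: GraphModule,
    gm_b: GraphModule,
    matched_subgraph_pairs: Dict[str, Tuple[NSSubgraph, NSSubgraph]],
) -> GraphModule:
    return create_a_shadows_b(
        'fp32_model', gm_a, 'int8_model', gm_b,
        matched_subgraph_pairs, _SketchLogger,
        should_log_inputs=False)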