net_min_base.py

import logging
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional, Tuple

import torch
import torch.fx
from torch.fx._compatibility import compatibility
from torch.fx.node import map_arg

from .shape_prop import ShapeProp
from .split_utils import split_by_tags
from .tools_common import (
    CALLABLE_NODE_OPS,
    FxNetAccFusionsFinder,
    Names,
    NodeList,
    NodeSet,
    TensorOrTensors,
    Tensors,
)

__all__ = [
    "FxNetMinimizerBadModuleError",
    "FxNetMinimizerRunFuncError",
    "FxNetMinimizerResultMismatchError",
]

_LOGGER = logging.getLogger(__name__)

@compatibility(is_backward_compatible=False)
class FxNetMinimizerBadModuleError(Exception):
    """
    Raised if we fail to split out a minimize module.
    """

    pass


@compatibility(is_backward_compatible=False)
class FxNetMinimizerRunFuncError(Exception):
    """
    Raised if an error occurs while running the run_a or run_b function.
    """

    pass


@compatibility(is_backward_compatible=False)
class FxNetMinimizerResultMismatchError(Exception):
    """
    Raised if the compare function reports that the results mismatch.
    """

    pass

@dataclass
class _MinimizerSettingBase:
    """
    Args:
        `accumulate_error`: Instead of feeding run_a's inputs to both converted
            modules for verification, use the previous outputs of each converted
            module as inputs, so that errors accumulate.

        `traverse_method`: "sequential", "binary", or "accumulate".
            Determines how the nodes in the FX module are traversed.

        `find_all`: If true, the minimizer will go through the entire model and
            return all problematic nodes.

        `return_intermediate`: If true, when using the `run_nodes()` function to
            run the model, intermediate results of all the ops will be returned
            as output.
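
    Example (illustrative; the non-default values below are arbitrary, not
    recommendations)::

        settings = _MinimizerSettingBase(traverse_method="binary", find_all=True)
        print(settings)
        # FX Minimizer Settings:
        #     accumulate_error: False
        #     traverse_method: binary
        #     find_all: True
        #     return_intermediate: False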
  55. """

    accumulate_error: bool = False
    traverse_method: str = "sequential"
    find_all: bool = False
    return_intermediate: bool = False

    def __str__(self):
        settings_str = "FX Minimizer Settings:\n"
        for k, v in vars(self).items():
            settings_str += f"\t{k}: {v}\n"
        return settings_str


class _MinimizerBase:
    """
    This class is used to automatically find problematic nodes in a model. It
    takes an FX graph module and generates submodules while traversing the
    graph. Two functions, `run_a` and `run_b`, are then used to run the same
    submodule, and a function `compare_fn` is used to compare the results.

    Currently we provide two ways to traverse the graph and generate submodules:

        1. Sequential traversal: this traverses the graph node by node and
           generates one submodule per node.

        2. Binary searching: this does a binary-search-style traversal on the
           graph.

    For internal users, a guide can be found here: https://fb.quip.com/HDtuAgiKGfkP.
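
    Example (an illustrative sketch, not part of this module; ``MyMinimizer``,
    ``traced_module``, ``sample_input``, and the 0.1 tolerance are all
    hypothetical)::

        class MyMinimizer(_MinimizerBase):
            def run_a(self, mod, inputs):
                # Reference run, e.g. the eager-mode module.
                return mod(*inputs)

            def run_b(self, mod, inputs):
                # Run under test, e.g. the lowered/converted module.
                return mod(*inputs)

        def compare_fn(a, b, names):
            # Assumes single-tensor outputs; returns (numeric_error, is_match).
            err = (a - b).abs().max().item()
            return err, err < 0.1

        minimizer = MyMinimizer(
            traced_module, sample_input, compare_fn, _MinimizerSettingBase()
        )
        culprits = minimizer.minimize()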
  76. """

    def __init__(
        self,
        module: torch.fx.GraphModule,
        sample_input: Tensors,
        compare_fn: Callable[
            [TensorOrTensors, TensorOrTensors, Names], Tuple[float, bool]
        ],
        settings: _MinimizerSettingBase,
    ):
        assert isinstance(module, torch.fx.GraphModule)

        self.module = module
        self.sample_input = sample_input
        self.compare_fn = compare_fn
        self.settings = settings

        # Stores outputs of run_a function
        self.a_outputs: Dict[str, Any] = {}

        # Stores outputs of run_b function
        self.b_outputs: Dict[str, Any] = {}

        # Stores the results of compare_fn
        self.results: Dict[Any, Any] = {}

        # Stores the report for the runs
        self.reports: List[List[str]] = []

        # Current iteration
        self.iteration: int = 0

        callable_nodes = {
            node for node in self.module.graph.nodes if node.op in CALLABLE_NODE_OPS
        }
        ShapeProp(self.module).propagate(*self.sample_input)
        self.fusions = FxNetAccFusionsFinder(self.module, callable_nodes)()

        # Check that the number of inputs in sample_input matches the number
        # of placeholders.
        placeholders = [
            node.name for node in self.module.graph.nodes if node.op == "placeholder"
        ]
        assert len(placeholders) == len(self.sample_input)

        # Store sample_input
        for i, name in enumerate(placeholders):
            self.a_outputs[name] = sample_input[i]
            self.b_outputs[name] = sample_input[i]

    def run_a(self, mod: torch.fx.GraphModule, inputs: Tensors) -> TensorOrTensors:
        """
        Run `mod` with `inputs` and generate output. The output will be compared
        with the output of run_b().
        """
        raise RuntimeError("run_a() is not implemented.")

    def run_b(self, mod: torch.fx.GraphModule, inputs: Tensors) -> TensorOrTensors:
        """
        Run `mod` with `inputs` and generate output. The output will be compared
        with the output of run_a().
        """
        raise RuntimeError("run_b() is not implemented.")

    def _store_outputs(
        self,
        a_result: TensorOrTensors,
        b_result: TensorOrTensors,
        submodule: torch.fx.GraphModule,
    ):
        """
        Store the outputs of self.run_a() and self.run_b() into self.a_outputs and
        self.b_outputs, so that we can use them when executing subsequent nodes
        that take those outputs as inputs.

        Args:
            a_result: Output of self.run_a(). Could be a tensor or tensors.
            b_result: Output of self.run_b(). Could be a tensor or tensors.
            submodule: The module that generates a_result and b_result.
        """
        output_node = next(
            node for node in submodule.graph.nodes if node.op == "output"
        )

        # Only one output
        if isinstance(output_node.args[0], torch.fx.Node):
            self.a_outputs[output_node.args[0].name] = a_result
            self.b_outputs[output_node.args[0].name] = b_result
        # Multiple outputs
        else:
            for i, arg in enumerate(output_node.args[0]):
                self.a_outputs[arg.name] = a_result[i]
                self.b_outputs[arg.name] = b_result[i]

    def _get_submod_inputs(
        self, main_module: torch.fx.GraphModule, submod_path: str
    ) -> Tuple[Tensors, Tensors]:
        """
        Try to get the submodule inputs from stored outputs. If they are not
        found, run the main module with a forward pre-hook to capture them.

        If accumulate_error is False, use a_input for both run_a() and run_b();
        otherwise use a_input for run_a() and b_input for run_b().

        Args:
            main_module: Top-level fx module.
            submod_path: Path to the submodule we want to run and compare results.

        Returns:
            a_input: List of tensor(s) that will be used by run_a() as submodule inputs.
            b_input: List of tensor(s) that will be used by run_b() as submodule inputs.
        """
        a_input = []
        b_input = []
        submodule = getattr(main_module, submod_path)
        placeholders = [
            node.name for node in submodule.graph.nodes if node.op == "placeholder"
        ]

        # If all placeholders can be found in stored outputs, use the stored
        # outputs as inputs. Otherwise, run the main module with a forward
        # pre-hook to capture the inputs to the submodule.
        if set(placeholders) <= self.a_outputs.keys():
            for name in placeholders:
                a_input.append(self.a_outputs[name])
                b_input.append(self.b_outputs[name])
        else:
            if self.settings.accumulate_error:
                print(f"Can't find previous stored outputs named {placeholders}!")

            def get_inputs(self: torch.nn.Module, inputs: Any):
                nonlocal a_input
                a_input = inputs

            # Use a forward pre-hook to get the inputs to the submodule
            handle = submodule.register_forward_pre_hook(get_inputs)
            main_module(*self.sample_input)
            handle.remove()

            b_input = a_input

        if not self.settings.accumulate_error:
            return a_input, a_input

        return a_input, b_input

    def _tag_nodes(self, selected_nodes: NodeSet):
        """
        Tag the selected nodes with the tag "minimize". Nodes with the same tag
        will be split into the same submodule afterwards.

        Args:
            selected_nodes: Nodes that we want to minimize. We will tag those
                nodes with "minimize", all preceding nodes with "main_0", and
                all following nodes with "main_1". See the example below.
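
        Example (illustrative; the node names are hypothetical)::

            # For a linear graph  x -> linear_1 -> relu_1 -> linear_2 -> output,
            # selected_nodes = {relu_1} produces the tags:
            #   linear_1.tag == "main_0"    # precedes the selected node
            #   relu_1.tag   == "minimize"  # the selected node
            #   linear_2.tag == "main_1"    # consumes a minimize/main_1 input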
  204. """
        for node in self.module.graph.nodes:
            if node.op not in CALLABLE_NODE_OPS:
                continue

            if node in selected_nodes:
                node.tag = "minimize"
            elif any(
                n.tag in {"minimize", "main_1"}
                for n in node.all_input_nodes
                if n.op in CALLABLE_NODE_OPS
            ):
                node.tag = "main_1"
            else:
                node.tag = "main_0"

    def _build_submodule(self, nodes: NodeSet) -> Tuple[torch.fx.GraphModule, str]:
        """
        Split self.module so that one submodule consists of `nodes`, and only `nodes`.

        Args:
            nodes: Nodes that we want to include in the minimize submodule.

        Returns:
            split_module (torch.fx.GraphModule): the module after the split.
            submodule_name (str): the name of the submodule that consists of `nodes`.
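
        Example (illustrative; ``cur_nodes`` stands for any node set collected
        by a traversal)::

            split, name = minimizer._build_submodule(cur_nodes)
            submodule = getattr(split, name)  # the child holding the "minimize" nodes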
  226. """
        # Color provided nodes
        self._tag_nodes(nodes)

        # Split module based on coloring
        split_module = split_by_tags(self.module, ["main_0", "minimize", "main_1"])

        # Find submodule containing colored nodes
        submodule_name: str = ""
        for child_name, _ in split_module.named_children():
            # Skip submodules we're not interested in at the moment
            if "minimize" not in child_name:
                continue

            if submodule_name == "":
                submodule_name = child_name
            else:
                raise FxNetMinimizerBadModuleError(
                    f"Expected only one minimize submodule with nodes {nodes}"
                )

        if submodule_name == "":
            raise FxNetMinimizerBadModuleError(
                f"Minimize submodule was not found with nodes {nodes}"
            )

        return split_module, submodule_name

    def _run_and_compare(
        self, split_module: torch.fx.GraphModule, submod_name: str, output_names: Names
    ):
        """
        Run the submodule in `split_module` that has name `submod_name`
        using `self.run_a` and `self.run_b` and compare their results.

        Args:
            split_module: Main module that contains the minimize submodule.
            submod_name: Name of the minimize submodule.
            output_names: Names of the nodes we want to output. If None, we
                will use the original output.
        """
        submodule = getattr(split_module, submod_name)
        a_input, b_input = self._get_submod_inputs(split_module, submod_name)

        if len(self.reports) == 0:
            self.reports.append([])
            self.iteration = 1

        report = self.reports[self.iteration - 1]
        report.append("Run and compare ...")

        if output_names:
            output_nodes: NodeList = []
            for node in submodule.graph.nodes:
                if node.op == "output":
                    submodule.graph.erase_node(node)

                if node.name in output_names:
                    output_nodes.append(node)

            submodule.graph.output(
                output_nodes[0] if len(output_nodes) == 1 else tuple(output_nodes)
            )
            submodule.graph.lint()
            submodule.recompile()

        # Use the names of the args of the output node as the key to store the
        # comparison result.
        for node in submodule.graph.nodes:
            if node.op == "output":
                result_key = map_arg(node.args, lambda x: x.name)

        a_result = self.run_a(submodule, a_input)
        b_result = self.run_b(submodule, b_input)
        self._store_outputs(a_result, b_result, submodule)

        # Compare results
        names: Names = output_names
        if output_names is None:
            names = [str(v) for v in result_key]

        numeric_result, bool_result = self.compare_fn(a_result, b_result, names)
        self.results[result_key] = numeric_result
        report.append(f"Numerical accuracy = {numeric_result}")

        if not bool_result:
            report.append(f"Result mismatch for {result_key}")
            raise FxNetMinimizerResultMismatchError(f"Result mismatch for {result_key}")

    def _binary_search_impl(
        self, all_nodes: NodeList, start_idx: int, end_idx: int
    ) -> NodeSet:
        """
        Recursive binary search implementation.
        """
        nodes: NodeList = all_nodes[start_idx:end_idx]

        report: List[str] = []
        self.reports.append(report)
        self.iteration += 1
        report.append(f"Binary search iteration {self.iteration}.")
        report.append(
            f"From node index {start_idx} to {end_idx - 1}. "
            f"Size of the interested node list is {len(nodes)}"
        )

        cur_nodes: NodeSet = set(nodes)
        for node in nodes:
            if node in self.fusions:
                cur_nodes.update(self.fusions[node])

        try:
            split_module, submod_name = self._build_submodule(cur_nodes)
            self._run_and_compare(split_module, submod_name, [])
        except (FxNetMinimizerRunFuncError, FxNetMinimizerResultMismatchError):
            if len(nodes) == 1:
                report.append(
                    "This is the last node in the sub-module. "
                    f"Search in the current branch is successful with culprit = {cur_nodes}."
                )
                self.print_report(report)
                return cur_nodes

            report.append(
                "Proceed to split and lower the halves of the current "
                "sub-module individually."
            )
            self.print_report(report)

            mid = len(nodes) // 2
            culprits = self._binary_search_impl(all_nodes, start_idx, start_idx + mid)

            if len(culprits) != 0 and not self.settings.find_all:
                return culprits

            # Accumulate culprits from both halves; overwriting here would drop
            # the first half's findings when `find_all` is set.
            culprits.update(
                self._binary_search_impl(all_nodes, start_idx + mid, end_idx)
            )

            if len(culprits) == 0:
                report.append(
                    "Further split and lowering found no errors. "
                    f"Unable to minimize the submodule with list of nodes: {nodes}"
                )
                self.print_report(report)

            return culprits
        else:
            report.append("No discrepancy found.")
            self.print_report(report)
            return set()

    def _binary_traverse(self, nodes: NodeList) -> NodeSet:
        """
        Binary search on `nodes` for culprits.
        """
        return self._binary_search_impl(nodes, 0, len(nodes))

    def _sequential_traverse(self, nodes: NodeList) -> NodeSet:
        """
        Traverse `nodes` one by one and determine if any of them is a culprit.
        """
        culprits: NodeSet = set()

        for node in nodes:
            report: List[str] = []
            self.reports.append(report)
            self.iteration += 1
            report.append(f"Sequential traverse iteration {self.iteration}.")
            report.append(f"Visit node: {node.name}")
            _LOGGER.info("Visit node: %s", node.name)

            cur_nodes: NodeSet = {node}
            if node in self.fusions:
                cur_nodes = self.fusions[node]

            try:
                split_module, submod_name = self._build_submodule(cur_nodes)
                self._run_and_compare(split_module, submod_name, [node.name])
                self.print_report(report)
            except FxNetMinimizerResultMismatchError:
                culprits.add(node)
                report.append(f"Found culprit from numeric error: {node}")
                self.print_report(report)
                if not self.settings.find_all:
                    return culprits
            except FxNetMinimizerRunFuncError:
                culprits.update(cur_nodes)
                report.append(f"Found culprit from run error: {node}")
                self.print_report(report)
                if not self.settings.find_all:
                    return culprits

        return culprits

    def _accumulate_traverse(self, nodes: NodeList) -> NodeSet:
        culprits: NodeSet = set()
        nodes_to_run: NodeSet = set()

        # find_all is not supported for accumulate traversal because all the
        # ops run on NNPI. So we return after the first op that raises an error.
        if self.settings.find_all:
            print("'Find All' mode is not supported in accumulate traversal.")
            return culprits

        for node in nodes:
            report: List[str] = []
            self.reports.append(report)
            self.iteration += 1
            report.append(f"Accumulate traverse iteration {self.iteration}.")

            nodes_to_run.add(node)

            node_name = node.name
            if node_name is not None and isinstance(node_name, tuple):
                node_name = node_name[0]
            assert node_name is not None and isinstance(
                node_name, str
            ), f"minimize: node_name: {node_name}"

            report.append(f"Add node: {node_name}")

            try:
                split_module, submod_name = self._build_submodule(nodes_to_run)
                self._run_and_compare(split_module, submod_name, [node_name])
                self.print_report(report)
            except (FxNetMinimizerResultMismatchError, FxNetMinimizerRunFuncError):
                culprits.add(node)
                report.append(f"Found culprit {node}")
                self.print_report(report)
                return culprits

        return culprits

    def _collect_nodes(self, start: Optional[str], end: Optional[str]) -> NodeList:
        """
        Collect the nodes in the model that lie between the node named `start`
        and the node named `end`. These two nodes are also included.
        """
        nodes: NodeList = []
        add_node = start is None

        for node in self.module.graph.nodes:
            if node.op not in CALLABLE_NODE_OPS:
                continue

            if node.name == start:
                add_node = True

            if add_node:
                nodes.append(node)

            if node.name == end:
                break

        return nodes

    def run_nodes(self, start: Optional[str] = None, end: Optional[str] = None):
        """
        Run part of the model from the `start` node to the `end` node. If `start`
        is None, then we start from the beginning of the model. If `end` is None,
        then we stop at the end of the model.

        Args:
            start: The name of the node which is the first node of the submodule
                we want to run. If set to None, then we'll start with the first
                node of the model.
            end: The name of the node which is the last node of the submodule we
                want to run. If set to None, we'll end with the last node of the
                model.
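
        Example (illustrative; ``minimizer`` is the hypothetical instance from
        the class docstring, and the node names are made up)::

            # Run only the slice of the model between two named nodes and, with
            # return_intermediate set, output every op's intermediate result.
            minimizer.settings.return_intermediate = True
            minimizer.run_nodes(start="linear_1", end="relu_2")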
  444. """
        nodes = self._collect_nodes(start, end)
        cur_nodes = set(nodes)

        for node in nodes:
            if node in self.fusions:
                cur_nodes.update(self.fusions[node])

        output_names = []
        if self.settings.return_intermediate:
            output_names = [node.name for node in nodes]

        try:
            split_module, submod_name = self._build_submodule(cur_nodes)
            self._run_and_compare(split_module, submod_name, output_names)
        except (
            FxNetMinimizerRunFuncError,
            FxNetMinimizerResultMismatchError,
        ) as e:
            print(e)

    def print_report(self, report: List[str]):
        for i, line in enumerate(report):
            if i > 0:
                print(" . " + line)
            else:
                print(line)

    def print_reports(self):
        for report in self.reports:
            self.print_report(report)

    def minimize(
        self, start: Optional[str] = None, end: Optional[str] = None
    ) -> NodeSet:
        """
        Minimize the model from the node named `start` to the node named `end`,
        based on self.settings. Find the culprits that cause
        FxNetMinimizerRunFuncError or FxNetMinimizerResultMismatchError errors.

        Args:
            start: The name of the node where we want to start minimizing. If set
                to None, then we'll start with the first node of the model.
            end: The name of the node where we want to terminate minimizing. If
                set to None, we'll end with the last node of the model.

        Returns:
            nodes: A set of nodes that cause FxNetMinimizerRunFuncError or
                FxNetMinimizerResultMismatchError errors during minimizing.
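
        Example (illustrative; ``MyMinimizer`` is the hypothetical subclass
        sketched in the class docstring)::

            minimizer = MyMinimizer(
                traced_module, sample_input, compare_fn,
                _MinimizerSettingBase(traverse_method="binary"),
            )
            culprits = minimizer.minimize()  # search the whole model
            minimizer.print_reports()        # dump per-iteration reports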
  485. """
        print(self.settings)
        print(self.module.graph)

        nodes = self._collect_nodes(start, end)

        if self.settings.traverse_method == "sequential":
            return self._sequential_traverse(nodes)

        if self.settings.traverse_method == "binary":
            return self._binary_traverse(nodes)

        if self.settings.traverse_method == "accumulate":
            return self._accumulate_traverse(nodes)

        raise RuntimeError(f"Unknown traverse method {self.settings.traverse_method}!")