graph_drawer.py

import hashlib
import torch
import torch.fx
from typing import Dict, Any, TYPE_CHECKING
from torch.fx.node import _get_qualified_name, _format_arg
from torch.fx.passes.shape_prop import TensorMetadata
from torch.fx._compatibility import compatibility
from itertools import chain

__all__ = ['FxGraphDrawer']

try:
    import pydot
    HAS_PYDOT = True
except ImportError:
    HAS_PYDOT = False

_COLOR_MAP = {
    "placeholder": '"AliceBlue"',
    "call_module": "LemonChiffon1",
    "get_param": "Yellow2",
    "get_attr": "LightGrey",
    "output": "PowderBlue",
}

_HASH_COLOR_MAP = [
    "CadetBlue1",
    "Coral",
    "DarkOliveGreen1",
    "DarkSeaGreen1",
    "GhostWhite",
    "Khaki1",
    "LavenderBlush1",
    "LightSkyBlue",
    "MistyRose1",
    "MistyRose2",
    "PaleTurquoise2",
    "PeachPuff1",
    "Salmon",
    "Thistle1",
    "Thistle3",
    "Wheat1",
]
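# Ops not listed in _COLOR_MAP are assigned a deterministic color from
# _HASH_COLOR_MAP above by hashing the node's target name
# (see FxGraphDrawer._get_node_style below).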

_WEIGHT_TEMPLATE = {
    "shape": "record",
    "fillcolor": "Salmon",
    "style": '"filled,rounded"',
    "fontcolor": "#000000",
}

if HAS_PYDOT:
    @compatibility(is_backward_compatible=False)
    class FxGraphDrawer:
        """
        Visualize a torch.fx.Graph with graphviz

        Basic usage (note `create_svg()` returns bytes, so open the file in
        binary mode):
            g = FxGraphDrawer(symbolic_traced, "resnet18")
            with open("a.svg", "wb") as f:
                f.write(g.get_dot_graph().create_svg())
        """
        def __init__(
            self,
            graph_module: torch.fx.GraphModule,
            name: str,
            ignore_getattr: bool = False,
            ignore_parameters_and_buffers: bool = False,
            skip_node_names_in_args: bool = True,
        ):
            self._name = name
            self._dot_graphs = {
                name: self._to_dot(
                    graph_module, name, ignore_getattr, ignore_parameters_and_buffers, skip_node_names_in_args
                )
            }

            for node in graph_module.graph.nodes:
                if node.op != "call_module":
                    continue

                leaf_node = self._get_leaf_node(graph_module, node)
                if not isinstance(leaf_node, torch.fx.GraphModule):
                    continue

                self._dot_graphs[f"{name}_{node.target}"] = self._to_dot(
                    leaf_node,
                    f"{name}_{node.target}",
                    ignore_getattr,
                    ignore_parameters_and_buffers,
                    skip_node_names_in_args,
                )

        def get_dot_graph(self, submod_name=None) -> pydot.Dot:
            if submod_name is None:
                return self.get_main_dot_graph()
            else:
                return self.get_submod_dot_graph(submod_name)

        def get_main_dot_graph(self) -> pydot.Dot:
            return self._dot_graphs[self._name]

        def get_submod_dot_graph(self, submod_name) -> pydot.Dot:
            return self._dot_graphs[f"{self._name}_{submod_name}"]

        def get_all_dot_graphs(self) -> Dict[str, pydot.Dot]:
            return self._dot_graphs
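
        # Note: submodule graphs created in __init__ are keyed as
        # f"{name}_{node.target}"; get_submod_dot_graph relies on that naming
        # convention to look them up.
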
        def _get_node_style(self, node: torch.fx.Node) -> Dict[str, str]:
            template = {
                "shape": "record",
                "fillcolor": "#CAFFE3",
                "style": '"filled,rounded"',
                "fontcolor": "#000000",
            }
            if node.op in _COLOR_MAP:
                template["fillcolor"] = _COLOR_MAP[node.op]
            else:
                # Use a random color for each node; based on its name so it's stable.
                target_name = node._pretty_print_target(node.target)
                target_hash = int(hashlib.md5(target_name.encode()).hexdigest()[:8], 16)
                template["fillcolor"] = _HASH_COLOR_MAP[target_hash % len(_HASH_COLOR_MAP)]
            return template

        def _get_leaf_node(
            self, module: torch.nn.Module, node: torch.fx.Node
        ) -> torch.nn.Module:
            py_obj = module
            assert isinstance(node.target, str)
            atoms = node.target.split(".")
            for atom in atoms:
                if not hasattr(py_obj, atom):
                    raise RuntimeError(
                        str(py_obj) + " does not have attribute " + atom + "!"
                    )
                py_obj = getattr(py_obj, atom)
            return py_obj

        def _typename(self, target: Any) -> str:
            if isinstance(target, torch.nn.Module):
                ret = torch.typename(target)
            elif isinstance(target, str):
                ret = target
            else:
                ret = _get_qualified_name(target)

            # Escape "{" and "}" to prevent dot files like:
            # https://gist.github.com/SungMinCho/1a017aab662c75d805c5954d62c5aabc
            # which triggers `Error: bad label format (...)` from dot
            return ret.replace("{", r"\{").replace("}", r"\}")

        def _get_node_label(
            self,
            module: torch.fx.GraphModule,
            node: torch.fx.Node,
            skip_node_names_in_args: bool,
        ) -> str:
            def _get_str_for_args_kwargs(arg):
                if isinstance(arg, tuple):
                    prefix, suffix = r"|args=(\l", r",\n)\l"
                    arg_strs_list = [_format_arg(a, max_list_len=8) for a in arg]
                elif isinstance(arg, dict):
                    prefix, suffix = r"|kwargs={\l", r",\n}\l"
                    arg_strs_list = [
                        f"{k}: {_format_arg(v, max_list_len=8)}"
                        for k, v in arg.items()
                    ]
                else:  # Fall back to nothing in unexpected case.
                    return ""

                # Strip out node names if requested.
                if skip_node_names_in_args:
                    arg_strs_list = [a for a in arg_strs_list if "%" not in a]
                if len(arg_strs_list) == 0:
                    return ""
                arg_strs = prefix + r",\n".join(arg_strs_list) + suffix
                return arg_strs.replace("{", r"\{").replace("}", r"\}")
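
            # Build a graphviz "record"-shaped label: "|" separates fields,
            # "\l" left-justifies the preceding text, and the whole label is
            # wrapped in "{...}".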
            label = "{" + f"name=%{node.name}|op_code={node.op}\n"

            if node.op == "call_module":
                leaf_module = self._get_leaf_node(module, node)
                label += r"\n" + self._typename(leaf_module) + r"\n|"
                extra = ""
                if hasattr(leaf_module, "__constants__"):
                    extra = r"\n".join(
                        [f"{c}: {getattr(leaf_module, c)}" for c in leaf_module.__constants__]  # type: ignore[union-attr]
                    )
                label += extra + r"\n"
            else:
                label += f"|target={self._typename(node.target)}" + r"\n"

            if len(node.args) > 0:
                label += _get_str_for_args_kwargs(node.args)
            if len(node.kwargs) > 0:
                label += _get_str_for_args_kwargs(node.kwargs)
            label += f"|num_users={len(node.users)}" + r"\n"

            tensor_meta = node.meta.get('tensor_meta')
            label += self._tensor_meta_to_label(tensor_meta)

            return label + "}"

        def _tensor_meta_to_label(self, tm) -> str:
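            # tensor_meta may be a single TensorMetadata or a nested
            # list/tuple/dict of them (e.g. for multi-output nodes); recurse
            # and concatenate one label fragment per tensor.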
            if tm is None:
                return ""
            elif isinstance(tm, TensorMetadata):
                return self._stringify_tensor_meta(tm)
            elif isinstance(tm, list):
                result = ""
                for item in tm:
                    result += self._tensor_meta_to_label(item)
                return result
            elif isinstance(tm, dict):
                result = ""
                for k, v in tm.items():
                    result += self._tensor_meta_to_label(v)
                return result
            elif isinstance(tm, tuple):
                result = ""
                for item in tm:
                    result += self._tensor_meta_to_label(item)
                return result
            else:
                raise RuntimeError(f"Unsupported tensor meta type {type(tm)}")

        def _stringify_tensor_meta(self, tm: TensorMetadata) -> str:
            result = ""
            if not hasattr(tm, "dtype"):
                print("tm", tm)
            result += "|" + "dtype" + "=" + str(tm.dtype) + r"\n"
            result += "|" + "shape" + "=" + str(tuple(tm.shape)) + r"\n"
            result += "|" + "requires_grad" + "=" + str(tm.requires_grad) + r"\n"
            result += "|" + "stride" + "=" + str(tm.stride) + r"\n"
            if tm.is_quantized:
                assert tm.qparams is not None
                assert "qscheme" in tm.qparams
                qscheme = tm.qparams["qscheme"]
                if qscheme in {
                    torch.per_tensor_affine,
                    torch.per_tensor_symmetric,
                }:
                    result += "|" + "q_scale" + "=" + str(tm.qparams["scale"]) + r"\n"
                    result += "|" + "q_zero_point" + "=" + str(tm.qparams["zero_point"]) + r"\n"
                elif qscheme in {
                    torch.per_channel_affine,
                    torch.per_channel_symmetric,
                    torch.per_channel_affine_float_qparams,
                }:
                    result += "|" + "q_per_channel_scale" + "=" + str(tm.qparams["scale"]) + r"\n"
                    result += "|" + "q_per_channel_zero_point" + "=" + str(tm.qparams["zero_point"]) + r"\n"
                    result += "|" + "q_per_channel_axis" + "=" + str(tm.qparams["axis"]) + r"\n"
                else:
                    raise RuntimeError(f"Unsupported qscheme: {qscheme}")
                result += "|" + "qscheme" + "=" + str(tm.qparams["qscheme"]) + r"\n"
            return result

        def _get_tensor_label(self, t: torch.Tensor) -> str:
            return str(t.dtype) + str(list(t.shape)) + r"\n"

        def _to_dot(
            self,
            graph_module: torch.fx.GraphModule,
            name: str,
            ignore_getattr: bool,
            ignore_parameters_and_buffers: bool,
            skip_node_names_in_args: bool,
        ) -> pydot.Dot:
  240. """
  241. Actual interface to visualize a fx.Graph. Note that it takes in the GraphModule instead of the Graph.
  242. If ignore_parameters_and_buffers is True, the parameters and buffers
  243. created with the module will not be added as nodes and edges.
  244. """
  245. dot_graph = pydot.Dot(name, rankdir="TB")
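
            # First pass: emit one dot node per fx node, plus weight nodes for
            # the parameters/buffers of leaf call_module targets.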
            for node in graph_module.graph.nodes:
                if ignore_getattr and node.op == "get_attr":
                    continue

                style = self._get_node_style(node)
                dot_node = pydot.Node(
                    node.name, label=self._get_node_label(graph_module, node, skip_node_names_in_args), **style
                )
                dot_graph.add_node(dot_node)

                def get_module_params_or_buffers():
                    for pname, ptensor in chain(
                        leaf_module.named_parameters(), leaf_module.named_buffers()
                    ):
                        pname1 = node.name + "." + pname
                        # Parenthesize the conditional so that both parameters
                        # and buffers get the full "name|op_code=get_..." label.
                        label1 = (
                            pname1 + "|op_code=get_"
                            + ("parameter" if isinstance(ptensor, torch.nn.Parameter) else "buffer")
                            + r"\l"
                        )
                        dot_w_node = pydot.Node(
                            pname1,
                            label="{" + label1 + self._get_tensor_label(ptensor) + "}",
                            **_WEIGHT_TEMPLATE,
                        )
                        dot_graph.add_node(dot_w_node)
                        dot_graph.add_edge(pydot.Edge(pname1, node.name))

                if node.op == "call_module":
                    leaf_module = self._get_leaf_node(graph_module, node)
                    if not ignore_parameters_and_buffers and not isinstance(leaf_module, torch.fx.GraphModule):
                        get_module_params_or_buffers()
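
            # Second pass: with every node emitted, add a data-flow edge from
            # each node to each of its users.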
            for node in graph_module.graph.nodes:
                if ignore_getattr and node.op == "get_attr":
                    continue

                for user in node.users:
                    dot_graph.add_edge(pydot.Edge(node.name, user.name))

            return dot_graph

else:
    if not TYPE_CHECKING:
        @compatibility(is_backward_compatible=False)
        class FxGraphDrawer:
            def __init__(self, graph_module: torch.fx.GraphModule, name: str, ignore_getattr: bool = False):
                raise RuntimeError('FXGraphDrawer requires the pydot package to be installed. Please install '
                                   'pydot through your favorite Python package manager.')
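

# A minimal, self-contained usage sketch (not part of the original module): it
# assumes pydot and Graphviz are installed, and `DemoModule` is a hypothetical
# example model. `Dot.create_svg()` returns bytes, hence the binary write mode.
if __name__ == "__main__":
    import torch.nn as nn

    class DemoModule(nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = nn.Linear(4, 4)

        def forward(self, x):
            return torch.relu(self.linear(x))

    # Symbolically trace the module, then render its graph to an SVG file.
    traced = torch.fx.symbolic_trace(DemoModule())
    drawer = FxGraphDrawer(traced, "demo")
    with open("demo.svg", "wb") as f:
        f.write(drawer.get_dot_graph().create_svg())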