# _tensorboard_vis.py
  1. import time
  2. from collections import defaultdict
  3. from functools import partial
  4. from typing import DefaultDict
  5. import torch
  6. # Unfortunately it doesn't seem as if there was any way to get TensorBoard to do
  7. # anything without having TF installed, and so this file has a hard dependency on it
  8. # as well. It really is a debugging tool, so it doesn't matter.
  9. try:
  10. from tensorflow.core.util import event_pb2
  11. from tensorflow.core.framework import graph_pb2
  12. from tensorflow.python.summary.writer.writer import FileWriter
  13. except ImportError:
  14. raise ImportError("TensorBoard visualization of GraphExecutors requires having "
  15. "TensorFlow installed") from None
  16. def dump_tensorboard_summary(graph_executor, logdir):
  17. with FileWriter(logdir) as w:
  18. pb_graph = visualize(graph_executor)
  19. evt = event_pb2.Event(wall_time=time.time(), graph_def=pb_graph.SerializeToString())
  20. w.add_event(evt)
  21. def visualize(graph, name_prefix='', pb_graph=None, executors_it=None):
  22. """Visualizes an independent graph, or a graph executor."""
  23. value_map = {}
  24. pb_graph = pb_graph or graph_pb2.GraphDef()
  25. if isinstance(graph, torch._C.GraphExecutorState):
  26. visualize_graph_executor(graph, name_prefix, pb_graph,
  27. partial(visualize, pb_graph=pb_graph))
  28. return pb_graph
  29. # Set up an input node
  30. input_node = pb_graph.node.add(op='input', name=name_prefix + 'input')
  31. for i, value in enumerate(graph.param_node().outputs()):
  32. value_map[value.unique()] = name_prefix + 'input:' + str(i)
  33. visualize_rec(graph, value_map, name_prefix, pb_graph, executors_it)
  34. # Gather all outputs
  35. return_node = pb_graph.node.add(op='output', name=name_prefix + 'output')
  36. for value in graph.return_node().inputs():
  37. return_node.input.append(value_map[value.unique()])
  38. return pb_graph
  39. def visualize_graph_executor(state, name_prefix, pb_graph, inline_graph):
  40. """Appends the state of a given GraphExecutor to the graph protobuf.
  41. Args:
  42. state (GraphExecutor or GraphExecutorState): GraphExecutor to display.
  43. name_prefix (str): Name prefix of the containing subgraph.
  44. pb_graph (GraphDef): graph to append to.
  45. inline_graph (Callable): a function that handles setting up a value_map,
  46. so that some graphs in here can be inlined. This is necessary, because
  47. this will simply be `visualize` for the top-level GraphExecutor,
  48. or `inline_graph` for all nested ones.
  49. The signature should look like (Graph, name_prefix) -> ().
  50. It will be called exactly once.
  51. The strategy is to embed all different configurations as independent subgraphs,
  52. while inlining the original graph as the one that actually produces the values.
  53. """
  54. if state.autograd_fallback_graph is not None:
  55. visualize(graph=state.autograd_fallback_graph,
  56. name_prefix=name_prefix + 'autograd_fallback/',
  57. pb_graph=pb_graph,
  58. executors_it=iter(state.autograd_fallback.executors()))
  59. for i, (arg_spec, plan) in enumerate(state.execution_plans.items()):
  60. subgraph_name = name_prefix + 'plan{}/'.format(i)
  61. # Create a disconnected node that will keep information regarding the input
  62. # types of this trace. This is unfortunately a bit too verbose to be included
  63. # in the subgraph name.
  64. input_kinds = pb_graph.node.add(op='INPUT_KIND', name=subgraph_name)
  65. input_kinds.attr['inputs'].s = repr(arg_spec).encode('ascii')
  66. visualize(plan.graph, subgraph_name, pb_graph, iter(plan.code.executors()))
  67. # Show gradient as an independent subgraph of this plan
  68. if plan.grad_executor is not None:
  69. grad_subgraph_name = subgraph_name + 'grad/'
  70. visualize(plan.grad_executor, grad_subgraph_name, pb_graph)
  71. return inline_graph(state.graph, name_prefix + 'original/')
  72. def visualize_rec(graph, value_map, name_prefix, pb_graph, executors_it=None):
  73. """Recursive part of visualize (basically skips setting up the input and output nodes)."""
  74. def inline_graph(subgraph, name, node):
  75. rec_value_map = {inp.unique(): value_map[val.unique()]
  76. for inp, val in zip(subgraph.inputs(), node.inputs())}
  77. visualize_rec(graph=subgraph,
  78. value_map=rec_value_map,
  79. name_prefix=name,
  80. pb_graph=pb_graph)
  81. for out, val in zip(subgraph.outputs(), node.outputs()):
  82. value_map[val.unique()] = rec_value_map[out.unique()]
  83. op_id_counter: DefaultDict[str, int] = defaultdict(int)
  84. def name_for(node):
  85. kind = node.kind()[node.kind().index('::') + 2:]
  86. op_id_counter[kind] += 1
  87. return kind, name_prefix + kind + '_' + str(op_id_counter[kind])
  88. def add_fusion_group(node):
  89. op, name = name_for(node)
  90. inline_graph(node.g('Subgraph'), name + '/', node)
  91. def add_graph_executor(node):
  92. op, name = name_for(node)
  93. if executors_it is None:
  94. add_node(node)
  95. else:
  96. ge = next(executors_it)
  97. visualize_graph_executor(ge, name + '/', pb_graph,
  98. partial(inline_graph, node=node))
  99. def add_node(node):
  100. if node.kind() == 'prim::FusionGroup':
  101. return add_fusion_group(node)
  102. elif node.kind() == 'prim::GraphExecutor':
  103. return add_graph_executor(node)
  104. op, name = name_for(node)
  105. pb_node = pb_graph.node.add(op=op, name=name)
  106. for value in node.inputs():
  107. pb_node.input.append(value_map[value.unique()])
  108. # TODO: handle attrs
  109. for i, value in enumerate(node.outputs()):
  110. value_map[value.unique()] = name + ':' + str(i)
  111. for node in graph.nodes():
  112. add_node(node)