import time
from collections import defaultdict
from functools import partial
from typing import DefaultDict

import torch
# Unfortunately there doesn't seem to be any way to get TensorBoard to do
# anything without TensorFlow installed, so this file has a hard dependency
# on it as well. Since this is purely a debugging tool, that's acceptable.
try:
    from tensorflow.core.util import event_pb2
    from tensorflow.core.framework import graph_pb2
    from tensorflow.python.summary.writer.writer import FileWriter
except ImportError:
    raise ImportError("TensorBoard visualization of GraphExecutors requires having "
                      "TensorFlow installed") from None


def dump_tensorboard_summary(graph_executor, logdir):
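    """Writes a visualization of `graph_executor` as a TensorBoard event file.

    `graph_executor` may be either a `torch._C.Graph` or a
    `torch._C.GraphExecutorState` (see `visualize` below). The event file is
    created under `logdir` and can be viewed with `tensorboard --logdir ...`.
    """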
    with FileWriter(logdir) as w:
        pb_graph = visualize(graph_executor)
        evt = event_pb2.Event(wall_time=time.time(), graph_def=pb_graph.SerializeToString())
        w.add_event(evt)


def visualize(graph, name_prefix='', pb_graph=None, executors_it=None):
    """Visualizes an independent graph, or a graph executor."""
    value_map = {}
    pb_graph = pb_graph or graph_pb2.GraphDef()

    if isinstance(graph, torch._C.GraphExecutorState):
        visualize_graph_executor(graph, name_prefix, pb_graph,
                                 partial(visualize, pb_graph=pb_graph))
        return pb_graph

    # Set up an input node
    pb_graph.node.add(op='input', name=name_prefix + 'input')
    for i, value in enumerate(graph.param_node().outputs()):
        value_map[value.unique()] = name_prefix + 'input:' + str(i)

    visualize_rec(graph, value_map, name_prefix, pb_graph, executors_it)

    # Gather all outputs
    return_node = pb_graph.node.add(op='output', name=name_prefix + 'output')
    for value in graph.return_node().inputs():
        return_node.input.append(value_map[value.unique()])

    return pb_graph


def visualize_graph_executor(state, name_prefix, pb_graph, inline_graph):
    """Appends the state of a given GraphExecutor to the graph protobuf.

    Args:
        state (GraphExecutorState): GraphExecutor state to display.
        name_prefix (str): Name prefix of the containing subgraph.
        pb_graph (GraphDef): graph to append to.
        inline_graph (Callable): a function that inlines a graph into pb_graph
            while setting up the caller's value_map. It is `visualize` for the
            top-level GraphExecutor, and the enclosing `inline_graph` closure
            for all nested ones. The signature should look like
            (Graph, name_prefix) -> (). It will be called exactly once.

    The strategy is to embed every optimized configuration as an independent
    subgraph ('plan0/', 'plan1/', ...), while inlining the original graph
    (as 'original/') as the one that actually produces the values.
    """
    if state.autograd_fallback_graph is not None:
        visualize(graph=state.autograd_fallback_graph,
                  name_prefix=name_prefix + 'autograd_fallback/',
                  pb_graph=pb_graph,
                  executors_it=iter(state.autograd_fallback.executors()))
    for i, (arg_spec, plan) in enumerate(state.execution_plans.items()):
        subgraph_name = name_prefix + 'plan{}/'.format(i)
        # Create a disconnected node that will keep information regarding the input
        # types of this trace. This is unfortunately a bit too verbose to be included
        # in the subgraph name.
        input_kinds = pb_graph.node.add(op='INPUT_KIND', name=subgraph_name)
        input_kinds.attr['inputs'].s = repr(arg_spec).encode('ascii')
        visualize(plan.graph, subgraph_name, pb_graph, iter(plan.code.executors()))
        # Show gradient as an independent subgraph of this plan
        if plan.grad_executor is not None:
            grad_subgraph_name = subgraph_name + 'grad/'
            visualize(plan.grad_executor, grad_subgraph_name, pb_graph)
    return inline_graph(state.graph, name_prefix + 'original/')


def visualize_rec(graph, value_map, name_prefix, pb_graph, executors_it=None):
    """Recursive part of visualize (basically skips setting up the input and output nodes)."""
    def inline_graph(subgraph, name, node):
        # Map the subgraph's inputs onto the names already assigned to the
        # enclosing node's inputs, so the inlined nodes connect seamlessly.
        rec_value_map = {inp.unique(): value_map[val.unique()]
                         for inp, val in zip(subgraph.inputs(), node.inputs())}
        visualize_rec(graph=subgraph,
                      value_map=rec_value_map,
                      name_prefix=name,
                      pb_graph=pb_graph)
        for out, val in zip(subgraph.outputs(), node.outputs()):
            value_map[val.unique()] = rec_value_map[out.unique()]

    op_id_counter: DefaultDict[str, int] = defaultdict(int)

    def name_for(node):
        kind = node.kind()[node.kind().index('::') + 2:]
        op_id_counter[kind] += 1
        return kind, name_prefix + kind + '_' + str(op_id_counter[kind])

    def add_fusion_group(node):
        op, name = name_for(node)
        inline_graph(node.g('Subgraph'), name + '/', node)

    def add_graph_executor(node):
        op, name = name_for(node)
        if executors_it is None:
            add_node(node)
        else:
            ge = next(executors_it)
            visualize_graph_executor(ge, name + '/', pb_graph,
                                     partial(inline_graph, node=node))

    def add_node(node):
        if node.kind() == 'prim::FusionGroup':
            return add_fusion_group(node)
        elif node.kind() == 'prim::GraphExecutor':
            return add_graph_executor(node)
        op, name = name_for(node)
        pb_node = pb_graph.node.add(op=op, name=name)
        for value in node.inputs():
            pb_node.input.append(value_map[value.unique()])
        # TODO: handle attrs
        for i, value in enumerate(node.outputs()):
            value_map[value.unique()] = name + ':' + str(i)

    for node in graph.nodes():
        add_node(node)
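

if __name__ == '__main__':
    # Minimal usage sketch for this debugging tool. It assumes a PyTorch build
    # where scripted functions expose `get_debug_state()` (returning a
    # torch._C.GraphExecutorState); the toy function and the default logdir
    # below are illustrative only, not part of this module's API.
    import sys

    @torch.jit.script
    def toy(x, y):
        return x + y * 2

    # Run the function once so the executor records at least one execution plan.
    toy(torch.rand(3), torch.rand(3))

    logdir = sys.argv[1] if len(sys.argv) > 1 else '/tmp/graph_executor_vis'
    dump_tensorboard_summary(toy.get_debug_state(), logdir)
    print('Wrote TensorBoard event file to', logdir)
    # Inspect the result with: tensorboard --logdir /tmp/graph_executor_vis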