12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364 |
- from tensorboard.compat.proto.graph_pb2 import GraphDef
- from tensorboard.compat.proto.node_def_pb2 import NodeDef
- from tensorboard.compat.proto.versions_pb2 import VersionDef
- from tensorboard.compat.proto.attr_value_pb2 import AttrValue
- from tensorboard.compat.proto.tensor_shape_pb2 import TensorShapeProto
def load_onnx_graph(fname):
    """Load a serialized ONNX model and convert its graph for TensorBoard.

    Args:
        fname: Location of the ONNX model, passed unchanged to ``onnx.load``.

    Returns:
        The ``GraphDef`` that :func:`parse` builds from the model's ``graph``.
    """
    # Function-scope import: onnx is only needed once this function is called.
    import onnx

    return parse(onnx.load(fname).graph)
# ONNX ``TensorProto.DataType`` value -> TensorFlow ``DataType`` value.
# The two enums number their members differently (e.g. ONNX INT32 == 6 while
# TF DT_INT32 == 3), so an ONNX elem_type cannot be stored in ``AttrValue.type``
# verbatim. Anything not listed maps to DT_INVALID (0).
_ONNX_TO_TF_DTYPE = {
    0: 0,    # UNDEFINED  -> DT_INVALID
    1: 1,    # FLOAT      -> DT_FLOAT
    2: 4,    # UINT8      -> DT_UINT8
    3: 6,    # INT8       -> DT_INT8
    4: 17,   # UINT16     -> DT_UINT16
    5: 5,    # INT16      -> DT_INT16
    6: 3,    # INT32      -> DT_INT32
    7: 9,    # INT64      -> DT_INT64
    8: 7,    # STRING     -> DT_STRING
    9: 10,   # BOOL       -> DT_BOOL
    10: 19,  # FLOAT16    -> DT_HALF
    11: 2,   # DOUBLE     -> DT_DOUBLE
    12: 22,  # UINT32     -> DT_UINT32
    13: 23,  # UINT64     -> DT_UINT64
    14: 8,   # COMPLEX64  -> DT_COMPLEX64
    15: 18,  # COMPLEX128 -> DT_COMPLEX128
    16: 14,  # BFLOAT16   -> DT_BFLOAT16
}


def _onnx_dim_size(dim):
    """Return the static size of an ONNX shape dimension, or -1 if it is not fixed."""
    # dim_value and dim_param live in a oneof. A symbolic dimension such as
    # "batch" leaves dim_value at its default 0, which would otherwise be
    # displayed as a zero-length axis. TensorFlow shapes use -1 for "unknown".
    return dim.dim_value if dim.HasField("dim_value") else -1


def _value_info_to_nodedef(value_info):
    """Build a placeholder ``Variable`` NodeDef for an ONNX graph input/output."""
    tensor_type = value_info.type.tensor_type
    shape = TensorShapeProto(
        dim=[TensorShapeProto.Dim(size=_onnx_dim_size(d)) for d in tensor_type.shape.dim]
    )
    return NodeDef(
        name=value_info.name,
        op="Variable",
        input=[],
        attr={
            "dtype": AttrValue(type=_ONNX_TO_TF_DTYPE.get(tensor_type.elem_type, 0)),
            "shape": AttrValue(shape=shape),
        },
    )


def _format_onnx_attributes(node):
    """Render all attributes of an ONNX node as one UTF-8 byte string for display."""
    # Each attribute becomes its set field values joined by " = " (typically
    # "name = value"); attributes are separated by ", ".
    rendered = [
        " = ".join(str(value) for _, value in attribute.ListFields())
        for attribute in node.attribute
    ]
    return ", ".join(rendered).encode(encoding="utf_8")


def parse(graph):
    """Convert an ONNX graph into a TensorBoard ``GraphDef``.

    Graph inputs and outputs become ``Variable`` nodes carrying ``dtype`` and
    ``shape`` attributes. Every operator node becomes a NodeDef named after its
    first output, with its ONNX attributes rendered into a ``parameters`` string.

    Args:
        graph: The ONNX graph to convert (``load_onnx_graph`` passes
            ``onnx.load(...).graph``).

    Returns:
        A ``GraphDef`` containing the converted nodes.
    """
    import itertools

    nodes = [
        _value_info_to_nodedef(value_info)
        for value_info in itertools.chain(graph.input, graph.output)
    ]
    # NOTE(review): a graph output gets a Variable node with the same name as
    # the operator node that produces it, so the GraphDef can contain duplicate
    # names. This was pre-existing behavior and is kept; confirm how TensorBoard
    # should display these.

    # Pass 1: choose a name for every operator node. Secondary outputs must be
    # referenced by consumers as "<node name>:<output index>" (GraphDef input
    # syntax), because only the first output becomes a node name.
    secondary_outputs = {}
    node_names = []
    for index, node in enumerate(graph.node):
        primary = node.output[0] if node.output else ""
        name = primary or node.name or f"{node.op_type}_{index}"
        for out_index, out_name in enumerate(node.output[1:], start=1):
            if out_name:  # "" marks an omitted optional output
                secondary_outputs[out_name] = f"{name}:{out_index}"
        node_names.append(name)

    # Pass 2: emit the operator nodes, rewriting inputs that refer to secondary
    # outputs and dropping "" entries (omitted optional inputs).
    for name, node in zip(node_names, graph.node):
        nodes.append(
            NodeDef(
                name=name,
                op=node.op_type,
                input=[secondary_outputs.get(inp, inp) for inp in node.input if inp],
                attr={"parameters": AttrValue(s=_format_onnx_attributes(node))},
            )
        )
    return GraphDef(node=nodes, versions=VersionDef(producer=22))
|