_onnx_graph.py 1.9 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364
  1. from tensorboard.compat.proto.graph_pb2 import GraphDef
  2. from tensorboard.compat.proto.node_def_pb2 import NodeDef
  3. from tensorboard.compat.proto.versions_pb2 import VersionDef
  4. from tensorboard.compat.proto.attr_value_pb2 import AttrValue
  5. from tensorboard.compat.proto.tensor_shape_pb2 import TensorShapeProto
  6. def load_onnx_graph(fname):
  7. import onnx
  8. m = onnx.load(fname)
  9. g = m.graph
  10. return parse(g)
  11. def parse(graph):
  12. nodes_proto = []
  13. nodes = []
  14. import itertools
  15. for node in itertools.chain(graph.input, graph.output):
  16. nodes_proto.append(node)
  17. for node in nodes_proto:
  18. print(node.name)
  19. shapeproto = TensorShapeProto(
  20. dim=[
  21. TensorShapeProto.Dim(size=d.dim_value)
  22. for d in node.type.tensor_type.shape.dim
  23. ]
  24. )
  25. nodes.append(
  26. NodeDef(
  27. name=node.name.encode(encoding="utf_8"),
  28. op="Variable",
  29. input=[],
  30. attr={
  31. "dtype": AttrValue(type=node.type.tensor_type.elem_type),
  32. "shape": AttrValue(shape=shapeproto),
  33. },
  34. )
  35. )
  36. for node in graph.node:
  37. _attr = []
  38. for s in node.attribute:
  39. _attr.append(" = ".join([str(f[1]) for f in s.ListFields()]))
  40. attr = ", ".join(_attr).encode(encoding="utf_8")
  41. print(node.output[0])
  42. nodes.append(
  43. NodeDef(
  44. name=node.output[0].encode(encoding="utf_8"),
  45. op=node.op_type,
  46. input=node.input,
  47. attr={"parameters": AttrValue(s=attr)},
  48. )
  49. )
  50. # two pass token replacement, appends opname to object id
  51. mapping = {}
  52. for node in nodes:
  53. mapping[node.name] = node.op + "_" + node.name
  54. return GraphDef(node=nodes, versions=VersionDef(producer=22))