_proto_graph.py 1.6 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950
  1. from typing import Optional
  2. from tensorboard.compat.proto.node_def_pb2 import NodeDef
  3. from tensorboard.compat.proto.attr_value_pb2 import AttrValue
  4. from tensorboard.compat.proto.tensor_shape_pb2 import TensorShapeProto
  5. def attr_value_proto(dtype, shape, s):
  6. """Creates a dict of objects matching
  7. https://github.com/tensorflow/tensorboard/blob/master/tensorboard/compat/proto/attr_value.proto
  8. specifically designed for a NodeDef. The values have been
  9. reverse engineered from standard TensorBoard logged data.
  10. """
  11. attr = {}
  12. if s is not None:
  13. attr["attr"] = AttrValue(s=s.encode(encoding="utf_8"))
  14. if shape is not None:
  15. shapeproto = tensor_shape_proto(shape)
  16. attr["_output_shapes"] = AttrValue(list=AttrValue.ListValue(shape=[shapeproto]))
  17. return attr
  18. def tensor_shape_proto(outputsize):
  19. """Creates an object matching
  20. https://github.com/tensorflow/tensorboard/blob/master/tensorboard/compat/proto/tensor_shape.proto
  21. """
  22. return TensorShapeProto(dim=[TensorShapeProto.Dim(size=d) for d in outputsize])
  23. def node_proto(
  24. name,
  25. op="UnSpecified",
  26. input=None,
  27. dtype=None,
  28. shape: Optional[tuple] = None,
  29. outputsize=None,
  30. attributes="",
  31. ):
  32. """Creates an object matching
  33. https://github.com/tensorflow/tensorboard/blob/master/tensorboard/compat/proto/node_def.proto
  34. """
  35. if input is None:
  36. input = []
  37. if not isinstance(input, list):
  38. input = [input]
  39. return NodeDef(
  40. name=name.encode(encoding="utf_8"),
  41. op=op,
  42. input=input,
  43. attr=attr_value_proto(dtype, outputsize, attributes),
  44. )