summary.py 29 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849
  1. import json
  2. import logging
  3. import os
  4. from typing import Optional
  5. import numpy as np
  6. from google.protobuf import struct_pb2
  7. from tensorboard.compat.proto.summary_pb2 import HistogramProto
  8. from tensorboard.compat.proto.summary_pb2 import Summary
  9. from tensorboard.compat.proto.summary_pb2 import SummaryMetadata
  10. from tensorboard.compat.proto.tensor_pb2 import TensorProto
  11. from tensorboard.compat.proto.tensor_shape_pb2 import TensorShapeProto
  12. from tensorboard.plugins.custom_scalar import layout_pb2
  13. from tensorboard.plugins.pr_curve.plugin_data_pb2 import PrCurvePluginData
  14. from tensorboard.plugins.text.plugin_data_pb2 import TextPluginData
  15. from ._convert_np import make_np
  16. from ._utils import _prepare_video, convert_to_HWC
# Names exported by ``from <this module> import *``.
__all__ = ['hparams', 'scalar', 'histogram_raw', 'histogram', 'make_histogram', 'image', 'image_boxes', 'draw_boxes',
           'make_image', 'video', 'make_video', 'audio', 'custom_scalars', 'text', 'pr_curve_raw', 'pr_curve', 'compute_curve',
           'mesh']

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
  21. def _calc_scale_factor(tensor):
  22. converted = tensor.numpy() if not isinstance(tensor, np.ndarray) else tensor
  23. return 1 if converted.dtype == np.uint8 else 255
  24. def _draw_single_box(
  25. image,
  26. xmin,
  27. ymin,
  28. xmax,
  29. ymax,
  30. display_str,
  31. color="black",
  32. color_text="black",
  33. thickness=2,
  34. ):
  35. from PIL import ImageDraw, ImageFont
  36. font = ImageFont.load_default()
  37. draw = ImageDraw.Draw(image)
  38. (left, right, top, bottom) = (xmin, xmax, ymin, ymax)
  39. draw.line(
  40. [(left, top), (left, bottom), (right, bottom), (right, top), (left, top)],
  41. width=thickness,
  42. fill=color,
  43. )
  44. if display_str:
  45. text_bottom = bottom
  46. # Reverse list and print from bottom to top.
  47. text_width, text_height = font.getsize(display_str)
  48. margin = np.ceil(0.05 * text_height)
  49. draw.rectangle(
  50. [
  51. (left, text_bottom - text_height - 2 * margin),
  52. (left + text_width, text_bottom),
  53. ],
  54. fill=color,
  55. )
  56. draw.text(
  57. (left + margin, text_bottom - text_height - margin),
  58. display_str,
  59. fill=color_text,
  60. font=font,
  61. )
  62. return image
  63. def hparams(hparam_dict=None, metric_dict=None, hparam_domain_discrete=None):
  64. """Outputs three `Summary` protocol buffers needed by hparams plugin.
  65. `Experiment` keeps the metadata of an experiment, such as the name of the
  66. hyperparameters and the name of the metrics.
  67. `SessionStartInfo` keeps key-value pairs of the hyperparameters
  68. `SessionEndInfo` describes status of the experiment e.g. STATUS_SUCCESS
  69. Args:
  70. hparam_dict: A dictionary that contains names of the hyperparameters
  71. and their values.
  72. metric_dict: A dictionary that contains names of the metrics
  73. and their values.
  74. hparam_domain_discrete: (Optional[Dict[str, List[Any]]]) A dictionary that
  75. contains names of the hyperparameters and all discrete values they can hold
  76. Returns:
  77. The `Summary` protobufs for Experiment, SessionStartInfo and
  78. SessionEndInfo
  79. """
  80. import torch
  81. from tensorboard.plugins.hparams.api_pb2 import (
  82. Experiment,
  83. HParamInfo,
  84. MetricInfo,
  85. MetricName,
  86. Status,
  87. DataType,
  88. )
  89. from tensorboard.plugins.hparams.metadata import (
  90. PLUGIN_NAME,
  91. PLUGIN_DATA_VERSION,
  92. EXPERIMENT_TAG,
  93. SESSION_START_INFO_TAG,
  94. SESSION_END_INFO_TAG,
  95. )
  96. from tensorboard.plugins.hparams.plugin_data_pb2 import (
  97. HParamsPluginData,
  98. SessionEndInfo,
  99. SessionStartInfo,
  100. )
  101. # TODO: expose other parameters in the future.
  102. # hp = HParamInfo(name='lr',display_name='learning rate',
  103. # type=DataType.DATA_TYPE_FLOAT64, domain_interval=Interval(min_value=10,
  104. # max_value=100))
  105. # mt = MetricInfo(name=MetricName(tag='accuracy'), display_name='accuracy',
  106. # description='', dataset_type=DatasetType.DATASET_VALIDATION)
  107. # exp = Experiment(name='123', description='456', time_created_secs=100.0,
  108. # hparam_infos=[hp], metric_infos=[mt], user='tw')
  109. if not isinstance(hparam_dict, dict):
  110. logger.warning("parameter: hparam_dict should be a dictionary, nothing logged.")
  111. raise TypeError(
  112. "parameter: hparam_dict should be a dictionary, nothing logged."
  113. )
  114. if not isinstance(metric_dict, dict):
  115. logger.warning("parameter: metric_dict should be a dictionary, nothing logged.")
  116. raise TypeError(
  117. "parameter: metric_dict should be a dictionary, nothing logged."
  118. )
  119. hparam_domain_discrete = hparam_domain_discrete or {}
  120. if not isinstance(hparam_domain_discrete, dict):
  121. raise TypeError(
  122. "parameter: hparam_domain_discrete should be a dictionary, nothing logged."
  123. )
  124. for k, v in hparam_domain_discrete.items():
  125. if (
  126. k not in hparam_dict
  127. or not isinstance(v, list)
  128. or not all(isinstance(d, type(hparam_dict[k])) for d in v)
  129. ):
  130. raise TypeError(
  131. "parameter: hparam_domain_discrete[{}] should be a list of same type as "
  132. "hparam_dict[{}].".format(k, k)
  133. )
  134. hps = []
  135. ssi = SessionStartInfo()
  136. for k, v in hparam_dict.items():
  137. if v is None:
  138. continue
  139. if isinstance(v, (int, float)):
  140. ssi.hparams[k].number_value = v
  141. if k in hparam_domain_discrete:
  142. domain_discrete: Optional[struct_pb2.ListValue] = struct_pb2.ListValue(
  143. values=[
  144. struct_pb2.Value(number_value=d)
  145. for d in hparam_domain_discrete[k]
  146. ]
  147. )
  148. else:
  149. domain_discrete = None
  150. hps.append(
  151. HParamInfo(
  152. name=k,
  153. type=DataType.Value("DATA_TYPE_FLOAT64"),
  154. domain_discrete=domain_discrete,
  155. )
  156. )
  157. continue
  158. if isinstance(v, str):
  159. ssi.hparams[k].string_value = v
  160. if k in hparam_domain_discrete:
  161. domain_discrete = struct_pb2.ListValue(
  162. values=[
  163. struct_pb2.Value(string_value=d)
  164. for d in hparam_domain_discrete[k]
  165. ]
  166. )
  167. else:
  168. domain_discrete = None
  169. hps.append(
  170. HParamInfo(
  171. name=k,
  172. type=DataType.Value("DATA_TYPE_STRING"),
  173. domain_discrete=domain_discrete,
  174. )
  175. )
  176. continue
  177. if isinstance(v, bool):
  178. ssi.hparams[k].bool_value = v
  179. if k in hparam_domain_discrete:
  180. domain_discrete = struct_pb2.ListValue(
  181. values=[
  182. struct_pb2.Value(bool_value=d)
  183. for d in hparam_domain_discrete[k]
  184. ]
  185. )
  186. else:
  187. domain_discrete = None
  188. hps.append(
  189. HParamInfo(
  190. name=k,
  191. type=DataType.Value("DATA_TYPE_BOOL"),
  192. domain_discrete=domain_discrete,
  193. )
  194. )
  195. continue
  196. if isinstance(v, torch.Tensor):
  197. v = make_np(v)[0]
  198. ssi.hparams[k].number_value = v
  199. hps.append(HParamInfo(name=k, type=DataType.Value("DATA_TYPE_FLOAT64")))
  200. continue
  201. raise ValueError(
  202. "value should be one of int, float, str, bool, or torch.Tensor"
  203. )
  204. content = HParamsPluginData(session_start_info=ssi, version=PLUGIN_DATA_VERSION)
  205. smd = SummaryMetadata(
  206. plugin_data=SummaryMetadata.PluginData(
  207. plugin_name=PLUGIN_NAME, content=content.SerializeToString()
  208. )
  209. )
  210. ssi = Summary(value=[Summary.Value(tag=SESSION_START_INFO_TAG, metadata=smd)])
  211. mts = [MetricInfo(name=MetricName(tag=k)) for k in metric_dict.keys()]
  212. exp = Experiment(hparam_infos=hps, metric_infos=mts)
  213. content = HParamsPluginData(experiment=exp, version=PLUGIN_DATA_VERSION)
  214. smd = SummaryMetadata(
  215. plugin_data=SummaryMetadata.PluginData(
  216. plugin_name=PLUGIN_NAME, content=content.SerializeToString()
  217. )
  218. )
  219. exp = Summary(value=[Summary.Value(tag=EXPERIMENT_TAG, metadata=smd)])
  220. sei = SessionEndInfo(status=Status.Value("STATUS_SUCCESS"))
  221. content = HParamsPluginData(session_end_info=sei, version=PLUGIN_DATA_VERSION)
  222. smd = SummaryMetadata(
  223. plugin_data=SummaryMetadata.PluginData(
  224. plugin_name=PLUGIN_NAME, content=content.SerializeToString()
  225. )
  226. )
  227. sei = Summary(value=[Summary.Value(tag=SESSION_END_INFO_TAG, metadata=smd)])
  228. return exp, ssi, sei
  229. def scalar(name, tensor, collections=None, new_style=False, double_precision=False):
  230. """Outputs a `Summary` protocol buffer containing a single scalar value.
  231. The generated Summary has a Tensor.proto containing the input Tensor.
  232. Args:
  233. name: A name for the generated node. Will also serve as the series name in
  234. TensorBoard.
  235. tensor: A real numeric Tensor containing a single value.
  236. collections: Optional list of graph collections keys. The new summary op is
  237. added to these collections. Defaults to `[GraphKeys.SUMMARIES]`.
  238. new_style: Whether to use new style (tensor field) or old style (simple_value
  239. field). New style could lead to faster data loading.
  240. Returns:
  241. A scalar `Tensor` of type `string`. Which contains a `Summary` protobuf.
  242. Raises:
  243. ValueError: If tensor has the wrong shape or type.
  244. """
  245. tensor = make_np(tensor).squeeze()
  246. assert (
  247. tensor.ndim == 0
  248. ), f"Tensor should contain one element (0 dimensions). Was given size: {tensor.size} and {tensor.ndim} dimensions."
  249. # python float is double precision in numpy
  250. scalar = float(tensor)
  251. if new_style:
  252. tensor_proto = TensorProto(float_val=[scalar], dtype="DT_FLOAT")
  253. if double_precision:
  254. tensor_proto = TensorProto(double_val=[scalar], dtype="DT_DOUBLE")
  255. plugin_data = SummaryMetadata.PluginData(plugin_name="scalars")
  256. smd = SummaryMetadata(plugin_data=plugin_data)
  257. return Summary(
  258. value=[
  259. Summary.Value(
  260. tag=name,
  261. tensor=tensor_proto,
  262. metadata=smd,
  263. )
  264. ]
  265. )
  266. else:
  267. return Summary(value=[Summary.Value(tag=name, simple_value=scalar)])
  268. def histogram_raw(name, min, max, num, sum, sum_squares, bucket_limits, bucket_counts):
  269. # pylint: disable=line-too-long
  270. """Outputs a `Summary` protocol buffer with a histogram.
  271. The generated
  272. [`Summary`](https://www.tensorflow.org/code/tensorflow/core/framework/summary.proto)
  273. has one summary value containing a histogram for `values`.
  274. Args:
  275. name: A name for the generated node. Will also serve as a series name in
  276. TensorBoard.
  277. min: A float or int min value
  278. max: A float or int max value
  279. num: Int number of values
  280. sum: Float or int sum of all values
  281. sum_squares: Float or int sum of squares for all values
  282. bucket_limits: A numeric `Tensor` with upper value per bucket
  283. bucket_counts: A numeric `Tensor` with number of values per bucket
  284. Returns:
  285. A scalar `Tensor` of type `string`. The serialized `Summary` protocol
  286. buffer.
  287. """
  288. hist = HistogramProto(
  289. min=min,
  290. max=max,
  291. num=num,
  292. sum=sum,
  293. sum_squares=sum_squares,
  294. bucket_limit=bucket_limits,
  295. bucket=bucket_counts,
  296. )
  297. return Summary(value=[Summary.Value(tag=name, histo=hist)])
  298. def histogram(name, values, bins, max_bins=None):
  299. # pylint: disable=line-too-long
  300. """Outputs a `Summary` protocol buffer with a histogram.
  301. The generated
  302. [`Summary`](https://www.tensorflow.org/code/tensorflow/core/framework/summary.proto)
  303. has one summary value containing a histogram for `values`.
  304. This op reports an `InvalidArgument` error if any value is not finite.
  305. Args:
  306. name: A name for the generated node. Will also serve as a series name in
  307. TensorBoard.
  308. values: A real numeric `Tensor`. Any shape. Values to use to
  309. build the histogram.
  310. Returns:
  311. A scalar `Tensor` of type `string`. The serialized `Summary` protocol
  312. buffer.
  313. """
  314. values = make_np(values)
  315. hist = make_histogram(values.astype(float), bins, max_bins)
  316. return Summary(value=[Summary.Value(tag=name, histo=hist)])
  317. def make_histogram(values, bins, max_bins=None):
  318. """Convert values into a histogram proto using logic from histogram.cc."""
  319. if values.size == 0:
  320. raise ValueError("The input has no element.")
  321. values = values.reshape(-1)
  322. counts, limits = np.histogram(values, bins=bins)
  323. num_bins = len(counts)
  324. if max_bins is not None and num_bins > max_bins:
  325. subsampling = num_bins // max_bins
  326. subsampling_remainder = num_bins % subsampling
  327. if subsampling_remainder != 0:
  328. counts = np.pad(
  329. counts,
  330. pad_width=[[0, subsampling - subsampling_remainder]],
  331. mode="constant",
  332. constant_values=0,
  333. )
  334. counts = counts.reshape(-1, subsampling).sum(axis=-1)
  335. new_limits = np.empty((counts.size + 1,), limits.dtype)
  336. new_limits[:-1] = limits[:-1:subsampling]
  337. new_limits[-1] = limits[-1]
  338. limits = new_limits
  339. # Find the first and the last bin defining the support of the histogram:
  340. cum_counts = np.cumsum(np.greater(counts, 0))
  341. start, end = np.searchsorted(cum_counts, [0, cum_counts[-1] - 1], side="right")
  342. start = int(start)
  343. end = int(end) + 1
  344. del cum_counts
  345. # TensorBoard only includes the right bin limits. To still have the leftmost limit
  346. # included, we include an empty bin left.
  347. # If start == 0, we need to add an empty one left, otherwise we can just include the bin left to the
  348. # first nonzero-count bin:
  349. counts = (
  350. counts[start - 1 : end] if start > 0 else np.concatenate([[0], counts[:end]])
  351. )
  352. limits = limits[start : end + 1]
  353. if counts.size == 0 or limits.size == 0:
  354. raise ValueError("The histogram is empty, please file a bug report.")
  355. sum_sq = values.dot(values)
  356. return HistogramProto(
  357. min=values.min(),
  358. max=values.max(),
  359. num=len(values),
  360. sum=values.sum(),
  361. sum_squares=sum_sq,
  362. bucket_limit=limits.tolist(),
  363. bucket=counts.tolist(),
  364. )
  365. def image(tag, tensor, rescale=1, dataformats="NCHW"):
  366. """Outputs a `Summary` protocol buffer with images.
  367. The summary has up to `max_images` summary values containing images. The
  368. images are built from `tensor` which must be 3-D with shape `[height, width,
  369. channels]` and where `channels` can be:
  370. * 1: `tensor` is interpreted as Grayscale.
  371. * 3: `tensor` is interpreted as RGB.
  372. * 4: `tensor` is interpreted as RGBA.
  373. The `name` in the outputted Summary.Value protobufs is generated based on the
  374. name, with a suffix depending on the max_outputs setting:
  375. * If `max_outputs` is 1, the summary value tag is '*name*/image'.
  376. * If `max_outputs` is greater than 1, the summary value tags are
  377. generated sequentially as '*name*/image/0', '*name*/image/1', etc.
  378. Args:
  379. tag: A name for the generated node. Will also serve as a series name in
  380. TensorBoard.
  381. tensor: A 3-D `uint8` or `float32` `Tensor` of shape `[height, width,
  382. channels]` where `channels` is 1, 3, or 4.
  383. 'tensor' can either have values in [0, 1] (float32) or [0, 255] (uint8).
  384. The image() function will scale the image values to [0, 255] by applying
  385. a scale factor of either 1 (uint8) or 255 (float32). Out-of-range values
  386. will be clipped.
  387. Returns:
  388. A scalar `Tensor` of type `string`. The serialized `Summary` protocol
  389. buffer.
  390. """
  391. tensor = make_np(tensor)
  392. tensor = convert_to_HWC(tensor, dataformats)
  393. # Do not assume that user passes in values in [0, 255], use data type to detect
  394. scale_factor = _calc_scale_factor(tensor)
  395. tensor = tensor.astype(np.float32)
  396. tensor = (tensor * scale_factor).clip(0, 255).astype(np.uint8)
  397. image = make_image(tensor, rescale=rescale)
  398. return Summary(value=[Summary.Value(tag=tag, image=image)])
  399. def image_boxes(
  400. tag, tensor_image, tensor_boxes, rescale=1, dataformats="CHW", labels=None
  401. ):
  402. """Outputs a `Summary` protocol buffer with images."""
  403. tensor_image = make_np(tensor_image)
  404. tensor_image = convert_to_HWC(tensor_image, dataformats)
  405. tensor_boxes = make_np(tensor_boxes)
  406. tensor_image = tensor_image.astype(np.float32) * _calc_scale_factor(tensor_image)
  407. image = make_image(
  408. tensor_image.clip(0, 255).astype(np.uint8), rescale=rescale, rois=tensor_boxes, labels=labels
  409. )
  410. return Summary(value=[Summary.Value(tag=tag, image=image)])
  411. def draw_boxes(disp_image, boxes, labels=None):
  412. # xyxy format
  413. num_boxes = boxes.shape[0]
  414. list_gt = range(num_boxes)
  415. for i in list_gt:
  416. disp_image = _draw_single_box(
  417. disp_image,
  418. boxes[i, 0],
  419. boxes[i, 1],
  420. boxes[i, 2],
  421. boxes[i, 3],
  422. display_str=None if labels is None else labels[i],
  423. color="Red",
  424. )
  425. return disp_image
  426. def make_image(tensor, rescale=1, rois=None, labels=None):
  427. """Convert a numpy representation of an image to Image protobuf"""
  428. from PIL import Image
  429. height, width, channel = tensor.shape
  430. scaled_height = int(height * rescale)
  431. scaled_width = int(width * rescale)
  432. image = Image.fromarray(tensor)
  433. if rois is not None:
  434. image = draw_boxes(image, rois, labels=labels)
  435. try:
  436. ANTIALIAS = Image.Resampling.LANCZOS
  437. except AttributeError:
  438. ANTIALIAS = Image.ANTIALIAS
  439. image = image.resize((scaled_width, scaled_height), ANTIALIAS)
  440. import io
  441. output = io.BytesIO()
  442. image.save(output, format="PNG")
  443. image_string = output.getvalue()
  444. output.close()
  445. return Summary.Image(
  446. height=height,
  447. width=width,
  448. colorspace=channel,
  449. encoded_image_string=image_string,
  450. )
  451. def video(tag, tensor, fps=4):
  452. tensor = make_np(tensor)
  453. tensor = _prepare_video(tensor)
  454. # If user passes in uint8, then we don't need to rescale by 255
  455. scale_factor = _calc_scale_factor(tensor)
  456. tensor = tensor.astype(np.float32)
  457. tensor = (tensor * scale_factor).clip(0, 255).astype(np.uint8)
  458. video = make_video(tensor, fps)
  459. return Summary(value=[Summary.Value(tag=tag, image=video)])
  460. def make_video(tensor, fps):
  461. try:
  462. import moviepy # noqa: F401
  463. except ImportError:
  464. print("add_video needs package moviepy")
  465. return
  466. try:
  467. from moviepy import editor as mpy
  468. except ImportError:
  469. print(
  470. "moviepy is installed, but can't import moviepy.editor.",
  471. "Some packages could be missing [imageio, requests]",
  472. )
  473. return
  474. import tempfile
  475. t, h, w, c = tensor.shape
  476. # encode sequence of images into gif string
  477. clip = mpy.ImageSequenceClip(list(tensor), fps=fps)
  478. filename = tempfile.NamedTemporaryFile(suffix=".gif", delete=False).name
  479. try: # newer version of moviepy use logger instead of progress_bar argument.
  480. clip.write_gif(filename, verbose=False, logger=None)
  481. except TypeError:
  482. try: # older version of moviepy does not support progress_bar argument.
  483. clip.write_gif(filename, verbose=False, progress_bar=False)
  484. except TypeError:
  485. clip.write_gif(filename, verbose=False)
  486. with open(filename, "rb") as f:
  487. tensor_string = f.read()
  488. try:
  489. os.remove(filename)
  490. except OSError:
  491. logger.warning("The temporary file used by moviepy cannot be deleted.")
  492. return Summary.Image(
  493. height=h, width=w, colorspace=c, encoded_image_string=tensor_string
  494. )
  495. def audio(tag, tensor, sample_rate=44100):
  496. array = make_np(tensor)
  497. array = array.squeeze()
  498. if abs(array).max() > 1:
  499. print("warning: audio amplitude out of range, auto clipped.")
  500. array = array.clip(-1, 1)
  501. assert array.ndim == 1, "input tensor should be 1 dimensional."
  502. array = (array * np.iinfo(np.int16).max).astype("<i2")
  503. import io
  504. import wave
  505. fio = io.BytesIO()
  506. with wave.open(fio, "wb") as wave_write:
  507. wave_write.setnchannels(1)
  508. wave_write.setsampwidth(2)
  509. wave_write.setframerate(sample_rate)
  510. wave_write.writeframes(array.data)
  511. audio_string = fio.getvalue()
  512. fio.close()
  513. audio = Summary.Audio(
  514. sample_rate=sample_rate,
  515. num_channels=1,
  516. length_frames=array.shape[-1],
  517. encoded_audio_string=audio_string,
  518. content_type="audio/wav",
  519. )
  520. return Summary(value=[Summary.Value(tag=tag, audio=audio)])
  521. def custom_scalars(layout):
  522. categories = []
  523. for k, v in layout.items():
  524. charts = []
  525. for chart_name, chart_meatadata in v.items():
  526. tags = chart_meatadata[1]
  527. if chart_meatadata[0] == "Margin":
  528. assert len(tags) == 3
  529. mgcc = layout_pb2.MarginChartContent(
  530. series=[
  531. layout_pb2.MarginChartContent.Series(
  532. value=tags[0], lower=tags[1], upper=tags[2]
  533. )
  534. ]
  535. )
  536. chart = layout_pb2.Chart(title=chart_name, margin=mgcc)
  537. else:
  538. mlcc = layout_pb2.MultilineChartContent(tag=tags)
  539. chart = layout_pb2.Chart(title=chart_name, multiline=mlcc)
  540. charts.append(chart)
  541. categories.append(layout_pb2.Category(title=k, chart=charts))
  542. layout = layout_pb2.Layout(category=categories)
  543. plugin_data = SummaryMetadata.PluginData(plugin_name="custom_scalars")
  544. smd = SummaryMetadata(plugin_data=plugin_data)
  545. tensor = TensorProto(
  546. dtype="DT_STRING",
  547. string_val=[layout.SerializeToString()],
  548. tensor_shape=TensorShapeProto(),
  549. )
  550. return Summary(
  551. value=[
  552. Summary.Value(tag="custom_scalars__config__", tensor=tensor, metadata=smd)
  553. ]
  554. )
  555. def text(tag, text):
  556. plugin_data = SummaryMetadata.PluginData(
  557. plugin_name="text", content=TextPluginData(version=0).SerializeToString()
  558. )
  559. smd = SummaryMetadata(plugin_data=plugin_data)
  560. tensor = TensorProto(
  561. dtype="DT_STRING",
  562. string_val=[text.encode(encoding="utf_8")],
  563. tensor_shape=TensorShapeProto(dim=[TensorShapeProto.Dim(size=1)]),
  564. )
  565. return Summary(
  566. value=[Summary.Value(tag=tag + "/text_summary", metadata=smd, tensor=tensor)]
  567. )
  568. def pr_curve_raw(
  569. tag, tp, fp, tn, fn, precision, recall, num_thresholds=127, weights=None
  570. ):
  571. if num_thresholds > 127: # weird, value > 127 breaks protobuf
  572. num_thresholds = 127
  573. data = np.stack((tp, fp, tn, fn, precision, recall))
  574. pr_curve_plugin_data = PrCurvePluginData(
  575. version=0, num_thresholds=num_thresholds
  576. ).SerializeToString()
  577. plugin_data = SummaryMetadata.PluginData(
  578. plugin_name="pr_curves", content=pr_curve_plugin_data
  579. )
  580. smd = SummaryMetadata(plugin_data=plugin_data)
  581. tensor = TensorProto(
  582. dtype="DT_FLOAT",
  583. float_val=data.reshape(-1).tolist(),
  584. tensor_shape=TensorShapeProto(
  585. dim=[
  586. TensorShapeProto.Dim(size=data.shape[0]),
  587. TensorShapeProto.Dim(size=data.shape[1]),
  588. ]
  589. ),
  590. )
  591. return Summary(value=[Summary.Value(tag=tag, metadata=smd, tensor=tensor)])
  592. def pr_curve(tag, labels, predictions, num_thresholds=127, weights=None):
  593. # weird, value > 127 breaks protobuf
  594. num_thresholds = min(num_thresholds, 127)
  595. data = compute_curve(
  596. labels, predictions, num_thresholds=num_thresholds, weights=weights
  597. )
  598. pr_curve_plugin_data = PrCurvePluginData(
  599. version=0, num_thresholds=num_thresholds
  600. ).SerializeToString()
  601. plugin_data = SummaryMetadata.PluginData(
  602. plugin_name="pr_curves", content=pr_curve_plugin_data
  603. )
  604. smd = SummaryMetadata(plugin_data=plugin_data)
  605. tensor = TensorProto(
  606. dtype="DT_FLOAT",
  607. float_val=data.reshape(-1).tolist(),
  608. tensor_shape=TensorShapeProto(
  609. dim=[
  610. TensorShapeProto.Dim(size=data.shape[0]),
  611. TensorShapeProto.Dim(size=data.shape[1]),
  612. ]
  613. ),
  614. )
  615. return Summary(value=[Summary.Value(tag=tag, metadata=smd, tensor=tensor)])
  616. # https://github.com/tensorflow/tensorboard/blob/master/tensorboard/plugins/pr_curve/summary.py
  617. def compute_curve(labels, predictions, num_thresholds=None, weights=None):
  618. _MINIMUM_COUNT = 1e-7
  619. if weights is None:
  620. weights = 1.0
  621. # Compute bins of true positives and false positives.
  622. bucket_indices = np.int32(np.floor(predictions * (num_thresholds - 1)))
  623. float_labels = labels.astype(np.float64)
  624. histogram_range = (0, num_thresholds - 1)
  625. tp_buckets, _ = np.histogram(
  626. bucket_indices,
  627. bins=num_thresholds,
  628. range=histogram_range,
  629. weights=float_labels * weights,
  630. )
  631. fp_buckets, _ = np.histogram(
  632. bucket_indices,
  633. bins=num_thresholds,
  634. range=histogram_range,
  635. weights=(1.0 - float_labels) * weights,
  636. )
  637. # Obtain the reverse cumulative sum.
  638. tp = np.cumsum(tp_buckets[::-1])[::-1]
  639. fp = np.cumsum(fp_buckets[::-1])[::-1]
  640. tn = fp[0] - fp
  641. fn = tp[0] - tp
  642. precision = tp / np.maximum(_MINIMUM_COUNT, tp + fp)
  643. recall = tp / np.maximum(_MINIMUM_COUNT, tp + fn)
  644. return np.stack((tp, fp, tn, fn, precision, recall))
  645. def _get_tensor_summary(
  646. name, display_name, description, tensor, content_type, components, json_config
  647. ):
  648. """Creates a tensor summary with summary metadata.
  649. Args:
  650. name: Uniquely identifiable name of the summary op. Could be replaced by
  651. combination of name and type to make it unique even outside of this
  652. summary.
  653. display_name: Will be used as the display name in TensorBoard.
  654. Defaults to `name`.
  655. description: A longform readable description of the summary data. Markdown
  656. is supported.
  657. tensor: Tensor to display in summary.
  658. content_type: Type of content inside the Tensor.
  659. components: Bitmask representing present parts (vertices, colors, etc.) that
  660. belong to the summary.
  661. json_config: A string, JSON-serialized dictionary of ThreeJS classes
  662. configuration.
  663. Returns:
  664. Tensor summary with metadata.
  665. """
  666. import torch
  667. from tensorboard.plugins.mesh import metadata
  668. tensor = torch.as_tensor(tensor)
  669. tensor_metadata = metadata.create_summary_metadata(
  670. name,
  671. display_name,
  672. content_type,
  673. components,
  674. tensor.shape,
  675. description,
  676. json_config=json_config,
  677. )
  678. tensor = TensorProto(
  679. dtype="DT_FLOAT",
  680. float_val=tensor.reshape(-1).tolist(),
  681. tensor_shape=TensorShapeProto(
  682. dim=[
  683. TensorShapeProto.Dim(size=tensor.shape[0]),
  684. TensorShapeProto.Dim(size=tensor.shape[1]),
  685. TensorShapeProto.Dim(size=tensor.shape[2]),
  686. ]
  687. ),
  688. )
  689. tensor_summary = Summary.Value(
  690. tag=metadata.get_instance_name(name, content_type),
  691. tensor=tensor,
  692. metadata=tensor_metadata,
  693. )
  694. return tensor_summary
  695. def _get_json_config(config_dict):
  696. """Parses and returns JSON string from python dictionary."""
  697. json_config = "{}"
  698. if config_dict is not None:
  699. json_config = json.dumps(config_dict, sort_keys=True)
  700. return json_config
  701. # https://github.com/tensorflow/tensorboard/blob/master/tensorboard/plugins/mesh/summary.py
  702. def mesh(
  703. tag, vertices, colors, faces, config_dict, display_name=None, description=None
  704. ):
  705. """Outputs a merged `Summary` protocol buffer with a mesh/point cloud.
  706. Args:
  707. tag: A name for this summary operation.
  708. vertices: Tensor of shape `[dim_1, ..., dim_n, 3]` representing the 3D
  709. coordinates of vertices.
  710. faces: Tensor of shape `[dim_1, ..., dim_n, 3]` containing indices of
  711. vertices within each triangle.
  712. colors: Tensor of shape `[dim_1, ..., dim_n, 3]` containing colors for each
  713. vertex.
  714. display_name: If set, will be used as the display name in TensorBoard.
  715. Defaults to `name`.
  716. description: A longform readable description of the summary data. Markdown
  717. is supported.
  718. config_dict: Dictionary with ThreeJS classes names and configuration.
  719. Returns:
  720. Merged summary for mesh/point cloud representation.
  721. """
  722. from tensorboard.plugins.mesh.plugin_data_pb2 import MeshPluginData
  723. from tensorboard.plugins.mesh import metadata
  724. json_config = _get_json_config(config_dict)
  725. summaries = []
  726. tensors = [
  727. (vertices, MeshPluginData.VERTEX),
  728. (faces, MeshPluginData.FACE),
  729. (colors, MeshPluginData.COLOR),
  730. ]
  731. tensors = [tensor for tensor in tensors if tensor[0] is not None]
  732. components = metadata.get_components_bitmask(
  733. [content_type for (tensor, content_type) in tensors]
  734. )
  735. for tensor, content_type in tensors:
  736. summaries.append(
  737. _get_tensor_summary(
  738. tag,
  739. display_name,
  740. description,
  741. tensor,
  742. content_type,
  743. components,
  744. json_config,
  745. )
  746. )
  747. return Summary(value=summaries)