12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849 |
- import json
- import logging
- import os
- from typing import Optional
- import numpy as np
- from google.protobuf import struct_pb2
- from tensorboard.compat.proto.summary_pb2 import HistogramProto
- from tensorboard.compat.proto.summary_pb2 import Summary
- from tensorboard.compat.proto.summary_pb2 import SummaryMetadata
- from tensorboard.compat.proto.tensor_pb2 import TensorProto
- from tensorboard.compat.proto.tensor_shape_pb2 import TensorShapeProto
- from tensorboard.plugins.custom_scalar import layout_pb2
- from tensorboard.plugins.pr_curve.plugin_data_pb2 import PrCurvePluginData
- from tensorboard.plugins.text.plugin_data_pb2 import TextPluginData
- from ._convert_np import make_np
- from ._utils import _prepare_video, convert_to_HWC
__all__ = [
    "hparams",
    "scalar",
    "histogram_raw",
    "histogram",
    "make_histogram",
    "image",
    "image_boxes",
    "draw_boxes",
    "make_image",
    "video",
    "make_video",
    "audio",
    "custom_scalars",
    "text",
    "pr_curve_raw",
    "pr_curve",
    "compute_curve",
    "mesh",
]

# Module-level logger; the functions below use it for non-fatal warnings.
logger = logging.getLogger(__name__)
def _calc_scale_factor(tensor):
    """Return the multiplier that maps ``tensor`` onto the ``[0, 255]`` range.

    ``uint8`` data is taken to already be in ``[0, 255]`` (factor 1); any other
    dtype is taken to be in ``[0, 1]`` (factor 255). Non-ndarray inputs are
    converted with their ``.numpy()`` method first.
    """
    array = tensor if isinstance(tensor, np.ndarray) else tensor.numpy()
    if array.dtype == np.uint8:
        return 1
    return 255
def _draw_single_box(
    image,
    xmin,
    ymin,
    xmax,
    ymax,
    display_str,
    color="black",
    color_text="black",
    thickness=2,
):
    """Draw one box outline, plus an optional text label, onto a PIL image.

    The image is drawn on in place via ``ImageDraw`` and is also returned.

    Args:
        image: PIL image to draw on.
        xmin, ymin, xmax, ymax: Box corners in pixel coordinates.
        display_str: Optional label. When truthy, it is drawn inside the box at
            its bottom-left corner, on a filled rectangle of ``color``.
        color: Colour of the outline and of the label background.
        color_text: Colour of the label text.
        thickness: Outline width in pixels.

    Returns:
        The same ``image`` object.
    """
    from PIL import ImageDraw, ImageFont

    font = ImageFont.load_default()
    draw = ImageDraw.Draw(image)
    (left, right, top, bottom) = (xmin, xmax, ymin, ymax)
    draw.line(
        [(left, top), (left, bottom), (right, bottom), (right, top), (left, top)],
        width=thickness,
        fill=color,
    )
    if display_str:
        text_bottom = bottom
        # ``getsize`` was removed from Pillow's font classes in Pillow 10, so
        # prefer ``getbbox``. Only very old Pillow releases lack ``getbbox``.
        if hasattr(font, "getbbox"):
            box_left, box_top, box_right, box_bottom = font.getbbox(display_str)
            text_width = box_right - box_left
            text_height = box_bottom - box_top
        else:
            text_width, text_height = font.getsize(display_str)
        margin = np.ceil(0.05 * text_height)
        # Background rectangle for the label, sitting just above the box's
        # bottom edge.
        draw.rectangle(
            [
                (left, text_bottom - text_height - 2 * margin),
                (left + text_width, text_bottom),
            ],
            fill=color,
        )
        draw.text(
            (left + margin, text_bottom - text_height - margin),
            display_str,
            fill=color_text,
            font=font,
        )
    return image
def hparams(hparam_dict=None, metric_dict=None, hparam_domain_discrete=None):
    """Output the three `Summary` protocol buffers needed by the hparams plugin.

    * `Experiment` keeps the metadata of an experiment: the hyperparameter
      infos (name, type, optional discrete domain) and the metric names.
    * `SessionStartInfo` keeps the key-value pairs of the hyperparameters.
    * `SessionEndInfo` describes the status of the session; this function
      always records ``STATUS_SUCCESS``.

    Args:
        hparam_dict: A dictionary mapping hyperparameter names to values of
            type ``bool``, ``int``, ``float``, ``str`` or ``torch.Tensor``.
            Entries whose value is ``None`` are skipped.
        metric_dict: A dictionary whose keys are the metric names. Its values
            are not read here.
        hparam_domain_discrete: (Optional[Dict[str, List[Any]]]) A dictionary
            mapping hyperparameter names to all discrete values they can hold.
            Every key must appear in ``hparam_dict``, and every listed value
            must be an instance of ``type(hparam_dict[key])``.

    Returns:
        The `Summary` protobufs for Experiment, SessionStartInfo and
        SessionEndInfo, in that order.

    Raises:
        TypeError: If ``hparam_dict``, ``metric_dict`` or
            ``hparam_domain_discrete`` is not a dictionary, or if a discrete
            domain is not a list of values of the hyperparameter's type.
        ValueError: If a hyperparameter value has an unsupported type.
    """
    import torch
    from tensorboard.plugins.hparams.api_pb2 import (
        DataType,
        Experiment,
        HParamInfo,
        MetricInfo,
        MetricName,
        Status,
    )
    from tensorboard.plugins.hparams.metadata import (
        EXPERIMENT_TAG,
        PLUGIN_DATA_VERSION,
        PLUGIN_NAME,
        SESSION_END_INFO_TAG,
        SESSION_START_INFO_TAG,
    )
    from tensorboard.plugins.hparams.plugin_data_pb2 import (
        HParamsPluginData,
        SessionEndInfo,
        SessionStartInfo,
    )

    # TODO: expose other HParamInfo / MetricInfo / Experiment fields
    # (display_name, description, domain_interval, dataset_type, ...) later.

    if not isinstance(hparam_dict, dict):
        logger.warning("parameter: hparam_dict should be a dictionary, nothing logged.")
        raise TypeError(
            "parameter: hparam_dict should be a dictionary, nothing logged."
        )
    if not isinstance(metric_dict, dict):
        logger.warning("parameter: metric_dict should be a dictionary, nothing logged.")
        raise TypeError(
            "parameter: metric_dict should be a dictionary, nothing logged."
        )

    hparam_domain_discrete = hparam_domain_discrete or {}
    if not isinstance(hparam_domain_discrete, dict):
        raise TypeError(
            "parameter: hparam_domain_discrete should be a dictionary, nothing logged."
        )
    for k, v in hparam_domain_discrete.items():
        if (
            k not in hparam_dict
            or not isinstance(v, list)
            or not all(isinstance(d, type(hparam_dict[k])) for d in v)
        ):
            raise TypeError(
                "parameter: hparam_domain_discrete[{}] should be a list of same type as "
                "hparam_dict[{}].".format(k, k)
            )

    # (python types, struct_pb2.Value field, HParamInfo data type).
    # ``bool`` MUST be listed before ``(int, float)``: bool is a subclass of
    # int, so testing the numeric entry first would record booleans as
    # FLOAT64 numbers and make the bool entry unreachable.
    value_kinds = (
        (bool, "bool_value", "DATA_TYPE_BOOL"),
        ((int, float), "number_value", "DATA_TYPE_FLOAT64"),
        (str, "string_value", "DATA_TYPE_STRING"),
    )

    hps = []
    session_start = SessionStartInfo()
    for k, v in hparam_dict.items():
        if v is None:
            continue
        if isinstance(v, torch.Tensor):
            # reshape(-1) so 0-d (scalar) tensors work as well as 1-element
            # ones; the first element is used as the value.
            session_start.hparams[k].number_value = make_np(v).reshape(-1)[0]
            hps.append(HParamInfo(name=k, type=DataType.Value("DATA_TYPE_FLOAT64")))
            continue
        for types, field, data_type in value_kinds:
            if isinstance(v, types):
                break
        else:
            raise ValueError(
                "value should be one of int, float, str, bool, or torch.Tensor"
            )
        setattr(session_start.hparams[k], field, v)
        domain_discrete = None
        if k in hparam_domain_discrete:
            domain_discrete = struct_pb2.ListValue(
                values=[
                    struct_pb2.Value(**{field: d}) for d in hparam_domain_discrete[k]
                ]
            )
        hps.append(
            HParamInfo(
                name=k,
                type=DataType.Value(data_type),
                domain_discrete=domain_discrete,
            )
        )

    def _plugin_summary(tag, **plugin_fields):
        # Wrap one HParamsPluginData payload in a tagged, metadata-only Summary.
        content = HParamsPluginData(version=PLUGIN_DATA_VERSION, **plugin_fields)
        smd = SummaryMetadata(
            plugin_data=SummaryMetadata.PluginData(
                plugin_name=PLUGIN_NAME, content=content.SerializeToString()
            )
        )
        return Summary(value=[Summary.Value(tag=tag, metadata=smd)])

    ssi = _plugin_summary(SESSION_START_INFO_TAG, session_start_info=session_start)
    mts = [MetricInfo(name=MetricName(tag=k)) for k in metric_dict]
    exp = _plugin_summary(
        EXPERIMENT_TAG, experiment=Experiment(hparam_infos=hps, metric_infos=mts)
    )
    sei = _plugin_summary(
        SESSION_END_INFO_TAG,
        session_end_info=SessionEndInfo(status=Status.Value("STATUS_SUCCESS")),
    )
    return exp, ssi, sei
def scalar(name, tensor, collections=None, new_style=False, double_precision=False):
    """Output a `Summary` protocol buffer containing a single scalar value.

    Args:
        name: Tag of the summary value; also the series name in TensorBoard.
        tensor: Anything ``make_np`` accepts that squeezes down to exactly one
            element.
        collections: Unused; kept for API compatibility.
        new_style: If True, store the value in a ``TensorProto`` tagged for the
            "scalars" plugin. Otherwise, store it in the legacy
            ``simple_value`` field.
        double_precision: With ``new_style``, store the value as ``DT_DOUBLE``
            instead of ``DT_FLOAT``.

    Returns:
        A `Summary` protobuf with one value.

    Raises:
        AssertionError: If ``tensor`` does not contain exactly one element.
    """
    arr = make_np(tensor).squeeze()
    assert arr.ndim == 0, (
        "Tensor should contain one element (0 dimensions). "
        f"Was given size: {arr.size} and {arr.ndim} dimensions."
    )
    # A Python float is double precision, so no precision is lost here.
    value = float(arr)

    if not new_style:
        return Summary(value=[Summary.Value(tag=name, simple_value=value)])

    if double_precision:
        tensor_proto = TensorProto(double_val=[value], dtype="DT_DOUBLE")
    else:
        tensor_proto = TensorProto(float_val=[value], dtype="DT_FLOAT")
    metadata = SummaryMetadata(
        plugin_data=SummaryMetadata.PluginData(plugin_name="scalars")
    )
    summary_value = Summary.Value(tag=name, tensor=tensor_proto, metadata=metadata)
    return Summary(value=[summary_value])
def histogram_raw(name, min, max, num, sum, sum_squares, bucket_limits, bucket_counts):
    # pylint: disable=line-too-long
    """Output a `Summary` protocol buffer with a precomputed histogram.

    Every statistic is passed through unchanged into a `HistogramProto`.

    Args:
        name: Tag of the summary value; also the series name in TensorBoard.
        min: Minimum value (float or int).
        max: Maximum value (float or int).
        num: Number of values (int).
        sum: Sum of all values (float or int).
        sum_squares: Sum of squares of all values (float or int).
        bucket_limits: Upper limit of each bucket.
        bucket_counts: Number of values in each bucket.

    Returns:
        A `Summary` protobuf whose single value carries the histogram.
    """
    histo = HistogramProto(
        min=min,
        max=max,
        num=num,
        sum=sum,
        sum_squares=sum_squares,
        bucket_limit=bucket_limits,
        bucket=bucket_counts,
    )
    summary_value = Summary.Value(tag=name, histo=histo)
    return Summary(value=[summary_value])
def histogram(name, values, bins, max_bins=None):
    # pylint: disable=line-too-long
    """Output a `Summary` protocol buffer with a histogram of ``values``.

    Args:
        name: Tag of the summary value; also the series name in TensorBoard.
        values: Anything ``make_np`` accepts, of any shape. It is cast to float
            before binning.
        bins: Binning specification forwarded to ``make_histogram``.
        max_bins: Optional cap on the bucket count, forwarded to
            ``make_histogram``.

    Returns:
        A `Summary` protobuf whose single value carries the histogram.
    """
    float_values = make_np(values).astype(float)
    histo = make_histogram(float_values, bins, max_bins)
    return Summary(value=[Summary.Value(tag=name, histo=histo)])
def make_histogram(values, bins, max_bins=None):
    """Convert values into a histogram proto using logic from histogram.cc.

    Args:
        values: Numpy array of any shape; it is flattened before binning.
        bins: Passed straight to ``np.histogram`` as its ``bins`` argument.
        max_bins: Optional upper bound on the number of buckets. When
            ``np.histogram`` produces more buckets than this, adjacent buckets
            are merged in equal-sized groups so that at most ``max_bins``
            remain.

    Returns:
        A `HistogramProto` with min/max/num/sum/sum_squares of ``values`` and
        bucket limits/counts trimmed to the populated range.

    Raises:
        ValueError: If ``values`` is empty, or if the trimmed histogram is
            unexpectedly empty.
    """
    if values.size == 0:
        raise ValueError("The input has no element.")
    values = values.reshape(-1)
    counts, limits = np.histogram(values, bins=bins)
    num_bins = len(counts)
    if max_bins is not None and num_bins > max_bins:
        # Ceiling division, so merging groups of ``subsampling`` buckets leaves
        # at most ``max_bins`` buckets. Floor division (num_bins // max_bins)
        # did not honour the cap: e.g. 5 buckets with max_bins=3 gave a group
        # size of 1, so all 5 were kept.
        subsampling = -(-num_bins // max_bins)
        subsampling_remainder = num_bins % subsampling
        if subsampling_remainder != 0:
            # Zero-pad so the counts split evenly into groups of ``subsampling``.
            counts = np.pad(
                counts,
                pad_width=[[0, subsampling - subsampling_remainder]],
                mode="constant",
                constant_values=0,
            )
        counts = counts.reshape(-1, subsampling).sum(axis=-1)
        # Keep the left edge of every merged group, plus the overall right edge.
        new_limits = np.empty((counts.size + 1,), limits.dtype)
        new_limits[:-1] = limits[:-1:subsampling]
        new_limits[-1] = limits[-1]
        limits = new_limits

    # Find the first and the last bin defining the support of the histogram:
    # ``cum_counts`` counts non-empty buckets seen so far, so ``start`` is the
    # first non-empty bucket and ``end`` is one past the last non-empty bucket.
    cum_counts = np.cumsum(np.greater(counts, 0))
    start, end = np.searchsorted(cum_counts, [0, cum_counts[-1] - 1], side="right")
    start = int(start)
    end = int(end) + 1
    del cum_counts

    # TensorBoard only includes the right bin limits. To still have the leftmost
    # limit included, we include an empty bin on the left. If start == 0, we add
    # an explicit empty bin; otherwise we include the (empty) bin just left of
    # the first non-empty one.
    counts = (
        counts[start - 1 : end] if start > 0 else np.concatenate([[0], counts[:end]])
    )
    limits = limits[start : end + 1]

    if counts.size == 0 or limits.size == 0:
        raise ValueError("The histogram is empty, please file a bug report.")

    sum_sq = values.dot(values)
    return HistogramProto(
        min=values.min(),
        max=values.max(),
        num=len(values),
        sum=values.sum(),
        sum_squares=sum_sq,
        bucket_limit=limits.tolist(),
        bucket=counts.tolist(),
    )
def image(tag, tensor, rescale=1, dataformats="NCHW"):
    """Output a `Summary` protocol buffer containing one PNG-encoded image.

    The input is converted to ``(H, W, C)`` with ``convert_to_HWC`` according to
    ``dataformats``. It is then scaled to ``[0, 255]``, clipped, cast to
    ``uint8`` and encoded by ``make_image``.

    Args:
        tag: Tag of the summary value; also the series name in TensorBoard.
        tensor: Image data (anything ``make_np`` accepts), laid out as
            described by ``dataformats``. The dtype, not the value range,
            decides the scaling: ``uint8`` data is taken to be in ``[0, 255]``
            and used as is; any other dtype is taken to be in ``[0, 1]`` and
            multiplied by 255. Out-of-range values are clipped.
        rescale: Factor applied to the image height and width before encoding.
        dataformats: Layout string understood by ``convert_to_HWC`` (default
            ``"NCHW"``). How a batch dimension is combined into a single image
            is decided by that helper (assumed to tile it; confirm there).

    Returns:
        A `Summary` protobuf with one value whose ``image`` field holds the
        encoded image.
    """
    array = convert_to_HWC(make_np(tensor), dataformats)
    # Do not assume that the user passes values in [0, 255]; the dtype decides.
    scale_factor = _calc_scale_factor(array)
    array = (array.astype(np.float32) * scale_factor).clip(0, 255).astype(np.uint8)
    encoded = make_image(array, rescale=rescale)
    return Summary(value=[Summary.Value(tag=tag, image=encoded)])
def image_boxes(
    tag, tensor_image, tensor_boxes, rescale=1, dataformats="CHW", labels=None
):
    """Output a `Summary` protocol buffer with an image annotated with boxes.

    Args:
        tag: Tag of the summary value.
        tensor_image: Image data laid out per ``dataformats`` (default
            ``"CHW"``). ``uint8`` data is used as is; other dtypes are
            multiplied by 255. The result is clipped to ``[0, 255]``.
        tensor_boxes: Boxes forwarded to ``make_image`` as ``rois``.
        rescale: Factor applied to the image size before encoding.
        dataformats: Layout string understood by ``convert_to_HWC``.
        labels: Optional per-box label strings.

    Returns:
        A `Summary` protobuf whose single value carries the encoded image.
    """
    hwc = convert_to_HWC(make_np(tensor_image), dataformats)
    boxes = make_np(tensor_boxes)
    scaled = hwc.astype(np.float32) * _calc_scale_factor(hwc)
    pixels = scaled.clip(0, 255).astype(np.uint8)
    encoded = make_image(pixels, rescale=rescale, rois=boxes, labels=labels)
    return Summary(value=[Summary.Value(tag=tag, image=encoded)])
def draw_boxes(disp_image, boxes, labels=None):
    """Draw every box in ``boxes`` onto ``disp_image`` in red.

    Args:
        disp_image: PIL image to draw on.
        boxes: 2-D array-like; row ``i`` holds box ``i`` in xyxy format
            (``xmin, ymin, xmax, ymax`` in its first four columns).
        labels: Optional sequence of label strings indexed like ``boxes``.

    Returns:
        The image returned by the last ``_draw_single_box`` call, or
        ``disp_image`` unchanged when there are no boxes.
    """
    for idx in range(boxes.shape[0]):
        xmin, ymin, xmax, ymax = (boxes[idx, col] for col in range(4))
        label = labels[idx] if labels is not None else None
        disp_image = _draw_single_box(
            disp_image, xmin, ymin, xmax, ymax, display_str=label, color="Red"
        )
    return disp_image
def make_image(tensor, rescale=1, rois=None, labels=None):
    """Convert a numpy representation of an image to an Image protobuf.

    Args:
        tensor: ``(H, W, C)`` array accepted by ``PIL.Image.fromarray``
            (``uint8`` as produced by the callers in this module).
        rescale: Factor applied to both height and width before encoding.
        rois: Optional boxes forwarded to ``draw_boxes``; they are drawn before
            resizing.
        labels: Optional per-box label strings forwarded to ``draw_boxes``.

    Returns:
        A `Summary.Image` holding the PNG bytes. Its height/width fields match
        the encoded (rescaled) image.
    """
    import io

    from PIL import Image

    height, width, channel = tensor.shape
    scaled_height = int(height * rescale)
    scaled_width = int(width * rescale)
    image = Image.fromarray(tensor)
    if rois is not None:
        image = draw_boxes(image, rois, labels=labels)
    try:
        resample = Image.Resampling.LANCZOS  # Pillow >= 9.1
    except AttributeError:
        resample = Image.ANTIALIAS  # older Pillow; removed in Pillow 10
    image = image.resize((scaled_width, scaled_height), resample)
    with io.BytesIO() as output:
        image.save(output, format="PNG")
        image_string = output.getvalue()
    # Report the dimensions of the image actually encoded. The pre-resize size
    # would disagree with the PNG whenever rescale != 1.
    return Summary.Image(
        height=scaled_height,
        width=scaled_width,
        colorspace=channel,
        encoded_image_string=image_string,
    )
def video(tag, tensor, fps=4):
    """Output a `Summary` protocol buffer whose image field is an animated GIF.

    Args:
        tag: Tag of the summary value.
        tensor: Video data (anything ``make_np`` accepts), prepared by
            ``_prepare_video``. ``uint8`` data is used as is; other dtypes are
            multiplied by 255. The result is clipped to ``[0, 255]``.
        fps: Frames per second of the encoded GIF.

    Returns:
        A `Summary` protobuf carrying the result of ``make_video``.
    """
    frames = _prepare_video(make_np(tensor))
    # uint8 input is already in [0, 255], so it is not rescaled by 255.
    factor = _calc_scale_factor(frames)
    frames = np.clip(frames.astype(np.float32) * factor, 0, 255).astype(np.uint8)
    gif = make_video(frames, fps)
    return Summary(value=[Summary.Value(tag=tag, image=gif)])
def make_video(tensor, fps):
    """Encode a ``(T, H, W, C)`` frame array as an animated GIF `Summary.Image`.

    Encoding uses moviepy through a temporary file. If moviepy, or its clip
    classes, cannot be imported, a message is printed and ``None`` is returned.

    Args:
        tensor: 4-D array of frames (``T, H, W, C``), typically ``uint8``.
        fps: Frames per second of the GIF.

    Returns:
        A `Summary.Image` with the GIF bytes, or ``None`` if moviepy is
        unavailable.
    """
    try:
        import moviepy
    except ImportError:
        print("add_video needs package moviepy")
        return
    try:
        from moviepy import editor as mpy
    except ImportError:
        # moviepy >= 2.0 removed ``moviepy.editor`` and exposes the clip
        # classes at the package top level instead.
        mpy = moviepy
        if not hasattr(mpy, "ImageSequenceClip"):
            print(
                "moviepy is installed, but can't import moviepy.editor.",
                "Some packages could be missing [imageio, requests]",
            )
            return
    import tempfile

    _, h, w, c = tensor.shape

    # encode sequence of images into gif string
    clip = mpy.ImageSequenceClip(list(tensor), fps=fps)

    # mkstemp + close: reserve a unique path without keeping a file handle
    # open while moviepy writes to it.
    fd, filename = tempfile.mkstemp(suffix=".gif")
    os.close(fd)
    try:
        # write_gif's keyword arguments changed across moviepy releases; try
        # each known signature in turn.
        write_gif_kwargs = (
            {"verbose": False, "logger": None},  # moviepy 1.x
            {"verbose": False, "progress_bar": False},  # older releases
            {"verbose": False},  # releases with neither logger nor progress_bar
            {"logger": None},  # moviepy >= 2.0 dropped ``verbose``
        )
        last_error = None
        for kwargs in write_gif_kwargs:
            try:
                clip.write_gif(filename, **kwargs)
            except TypeError as err:
                last_error = err
            else:
                break
        else:
            raise last_error

        with open(filename, "rb") as f:
            tensor_string = f.read()
    finally:
        # Remove the temp file even if encoding failed.
        try:
            os.remove(filename)
        except OSError:
            logger.warning("The temporary file used by moviepy cannot be deleted.")

    return Summary.Image(
        height=h, width=w, colorspace=c, encoded_image_string=tensor_string
    )
def audio(tag, tensor, sample_rate=44100):
    """Output a `Summary` protocol buffer with a mono 16-bit PCM WAV clip.

    Args:
        tag: Tag of the summary value.
        tensor: Audio samples (anything ``make_np`` accepts). After squeezing,
            it must be 1-D. Amplitudes are expected in ``[-1, 1]``; values
            outside that range are clipped and a warning is printed.
        sample_rate: Sample rate in Hz, written to the WAV header and the proto.

    Returns:
        A `Summary` protobuf whose value carries the WAV-encoded audio.
    """
    import io
    import wave

    samples = make_np(tensor).squeeze()
    if abs(samples).max() > 1:
        print("warning: audio amplitude out of range, auto clipped.")
        samples = samples.clip(-1, 1)
    assert samples.ndim == 1, "input tensor should be 1 dimensional."
    # Map [-1, 1] onto the int16 range as little-endian 16-bit samples.
    pcm = (samples * np.iinfo(np.int16).max).astype("<i2")

    with io.BytesIO() as buffer:
        with wave.open(buffer, "wb") as wav_file:
            wav_file.setnchannels(1)
            wav_file.setsampwidth(2)
            wav_file.setframerate(sample_rate)
            wav_file.writeframes(pcm.data)
        # Read the bytes only after the wave writer has been closed.
        wav_bytes = buffer.getvalue()

    clip = Summary.Audio(
        sample_rate=sample_rate,
        num_channels=1,
        length_frames=pcm.shape[-1],
        encoded_audio_string=wav_bytes,
        content_type="audio/wav",
    )
    return Summary(value=[Summary.Value(tag=tag, audio=clip)])
def custom_scalars(layout):
    """Build the custom-scalars plugin layout `Summary`.

    Args:
        layout: Dict mapping category title -> dict mapping chart title ->
            chart spec. ``spec[0]`` is the chart type and ``spec[1]`` is the
            list of tags. Type ``"Margin"`` requires exactly three tags
            (value, lower, upper); any other type produces a multiline chart
            over the tags.

    Returns:
        A `Summary` with tag ``custom_scalars__config__`` whose string tensor
        holds the serialized `Layout`.
    """

    def _build_chart(title, spec):
        # Translate one chart spec into a layout_pb2.Chart.
        tags = spec[1]
        if spec[0] == "Margin":
            assert len(tags) == 3
            series = layout_pb2.MarginChartContent.Series(
                value=tags[0], lower=tags[1], upper=tags[2]
            )
            margin = layout_pb2.MarginChartContent(series=[series])
            return layout_pb2.Chart(title=title, margin=margin)
        multiline = layout_pb2.MultilineChartContent(tag=tags)
        return layout_pb2.Chart(title=title, multiline=multiline)

    categories = [
        layout_pb2.Category(
            title=category_title,
            chart=[_build_chart(title, spec) for title, spec in chart_specs.items()],
        )
        for category_title, chart_specs in layout.items()
    ]
    layout_proto = layout_pb2.Layout(category=categories)

    metadata = SummaryMetadata(
        plugin_data=SummaryMetadata.PluginData(plugin_name="custom_scalars")
    )
    config_tensor = TensorProto(
        dtype="DT_STRING",
        string_val=[layout_proto.SerializeToString()],
        tensor_shape=TensorShapeProto(),
    )
    summary_value = Summary.Value(
        tag="custom_scalars__config__", tensor=config_tensor, metadata=metadata
    )
    return Summary(value=[summary_value])
def text(tag, text):
    """Output a `Summary` holding one UTF-8 string for the text plugin.

    Args:
        tag: Base tag; the summary value is tagged ``tag + "/text_summary"``.
        text: The string to log.

    Returns:
        A `Summary` protobuf with a 1-element ``DT_STRING`` tensor.
    """
    plugin_content = TextPluginData(version=0).SerializeToString()
    metadata = SummaryMetadata(
        plugin_data=SummaryMetadata.PluginData(
            plugin_name="text", content=plugin_content
        )
    )
    string_tensor = TensorProto(
        dtype="DT_STRING",
        string_val=[text.encode(encoding="utf_8")],
        tensor_shape=TensorShapeProto(dim=[TensorShapeProto.Dim(size=1)]),
    )
    summary_value = Summary.Value(
        tag=tag + "/text_summary", metadata=metadata, tensor=string_tensor
    )
    return Summary(value=[summary_value])
def pr_curve_raw(
    tag, tp, fp, tn, fn, precision, recall, num_thresholds=127, weights=None
):
    """Output a PR-curve `Summary` from precomputed counts and rates.

    Args:
        tag: Tag of the summary value.
        tp, fp, tn, fn: True/false positive/negative counts (presumably one
            entry per threshold).
        precision, recall: Precision and recall, aligned with the counts.
        num_thresholds: Stored in the plugin metadata; values above 127 are
            capped at 127.
        weights: Unused; accepted for signature parity with ``pr_curve``.

    Returns:
        A `Summary` whose ``DT_FLOAT`` tensor stacks the six inputs as rows.
    """
    # Values above 127 were observed to break protobuf, so cap them.
    num_thresholds = min(num_thresholds, 127)
    stacked = np.stack((tp, fp, tn, fn, precision, recall))
    plugin_content = PrCurvePluginData(
        version=0, num_thresholds=num_thresholds
    ).SerializeToString()
    metadata = SummaryMetadata(
        plugin_data=SummaryMetadata.PluginData(
            plugin_name="pr_curves", content=plugin_content
        )
    )
    shape = TensorShapeProto(
        dim=[
            TensorShapeProto.Dim(size=stacked.shape[0]),
            TensorShapeProto.Dim(size=stacked.shape[1]),
        ]
    )
    curve_tensor = TensorProto(
        dtype="DT_FLOAT", float_val=stacked.reshape(-1).tolist(), tensor_shape=shape
    )
    return Summary(value=[Summary.Value(tag=tag, metadata=metadata, tensor=curve_tensor)])
def pr_curve(tag, labels, predictions, num_thresholds=127, weights=None):
    """Output a PR-curve `Summary` computed from labels and predictions.

    Args:
        tag: Tag of the summary value.
        labels: Ground-truth labels, forwarded to ``compute_curve``.
        predictions: Prediction scores, forwarded to ``compute_curve``.
        num_thresholds: Number of thresholds; values above 127 are capped at 127.
        weights: Optional example weights, forwarded to ``compute_curve``.

    Returns:
        A `Summary` whose ``DT_FLOAT`` tensor holds the ``compute_curve`` output.
    """
    # Values above 127 were observed to break protobuf, so cap them.
    if num_thresholds > 127:
        num_thresholds = 127
    curve = compute_curve(
        labels, predictions, num_thresholds=num_thresholds, weights=weights
    )
    plugin_content = PrCurvePluginData(
        version=0, num_thresholds=num_thresholds
    ).SerializeToString()
    metadata = SummaryMetadata(
        plugin_data=SummaryMetadata.PluginData(
            plugin_name="pr_curves", content=plugin_content
        )
    )
    rows = TensorShapeProto.Dim(size=curve.shape[0])
    cols = TensorShapeProto.Dim(size=curve.shape[1])
    curve_tensor = TensorProto(
        dtype="DT_FLOAT",
        float_val=curve.reshape(-1).tolist(),
        tensor_shape=TensorShapeProto(dim=[rows, cols]),
    )
    summary_value = Summary.Value(tag=tag, metadata=metadata, tensor=curve_tensor)
    return Summary(value=[summary_value])
# https://github.com/tensorflow/tensorboard/blob/master/tensorboard/plugins/pr_curve/summary.py
def compute_curve(labels, predictions, num_thresholds=None, weights=None):
    """Compute precision-recall curve data (numpy port of TensorBoard's logic).

    Args:
        labels: Array-like of binary labels (1 = positive, 0 = negative). It is
            converted with ``np.asarray`` and cast to float64.
        predictions: Array-like of scores, presumably in ``[0, 1]``. Each score
            is bucketed as ``floor(p * (num_thresholds - 1))``.
        num_thresholds: Number of thresholds (buckets). ``None`` means 127,
            the default used by ``pr_curve``.
        weights: Optional scalar or array of example weights, broadcast against
            ``labels``. ``None`` means 1.0.

    Returns:
        A float array of shape ``(6, num_thresholds)`` whose rows are
        ``tp, fp, tn, fn, precision, recall``.
    """
    _MINIMUM_COUNT = 1e-7
    if num_thresholds is None:
        num_thresholds = 127
    # np.asarray so Python lists work: ``list * int`` would repeat the list
    # rather than scale it, and lists/tensors lack ``.astype``.
    labels = np.asarray(labels)
    predictions = np.asarray(predictions)
    weights = 1.0 if weights is None else np.asarray(weights)

    # Compute bins of true positives and false positives.
    bucket_indices = np.int32(np.floor(predictions * (num_thresholds - 1)))
    float_labels = labels.astype(np.float64)
    histogram_range = (0, num_thresholds - 1)
    tp_buckets, _ = np.histogram(
        bucket_indices,
        bins=num_thresholds,
        range=histogram_range,
        weights=float_labels * weights,
    )
    fp_buckets, _ = np.histogram(
        bucket_indices,
        bins=num_thresholds,
        range=histogram_range,
        weights=(1.0 - float_labels) * weights,
    )

    # Obtain the reverse cumulative sum: counts at or above each threshold.
    tp = np.cumsum(tp_buckets[::-1])[::-1]
    fp = np.cumsum(fp_buckets[::-1])[::-1]
    # tp[0] / fp[0] are the total positives / negatives.
    tn = fp[0] - fp
    fn = tp[0] - tp
    precision = tp / np.maximum(_MINIMUM_COUNT, tp + fp)
    recall = tp / np.maximum(_MINIMUM_COUNT, tp + fn)
    return np.stack((tp, fp, tn, fn, precision, recall))
def _get_tensor_summary(
    name, display_name, description, tensor, content_type, components, json_config
):
    """Creates a tensor summary with summary metadata.

    Args:
      name: Uniquely identifiable name of the summary op. Could be replaced by
        combination of name and type to make it unique even outside of this
        summary.
      display_name: Will be used as the display name in TensorBoard.
        Defaults to `name`.
      description: A longform readable description of the summary data. Markdown
        is supported.
      tensor: Tensor to display in summary. It is converted with
        ``torch.as_tensor`` and written as ``DT_FLOAT`` with its full shape.
        The mesh plugin presumably expects ``[batch, num_points, 3]``; confirm
        against the plugin.
      content_type: Type of content inside the Tensor.
      components: Bitmask representing present parts (vertices, colors, etc.) that
        belong to the summary.
      json_config: A string, JSON-serialized dictionary of ThreeJS classes
        configuration.

    Returns:
      Tensor summary with metadata.
    """
    import torch
    from tensorboard.plugins.mesh import metadata

    tensor = torch.as_tensor(tensor)
    tensor_metadata = metadata.create_summary_metadata(
        name,
        display_name,
        content_type,
        components,
        tensor.shape,
        description,
        json_config=json_config,
    )
    # Record every dimension. Hard-coding three dims raised IndexError for
    # lower ranks and recorded a shape inconsistent with float_val for higher
    # ranks.
    tensor_proto = TensorProto(
        dtype="DT_FLOAT",
        float_val=tensor.reshape(-1).tolist(),
        tensor_shape=TensorShapeProto(
            dim=[TensorShapeProto.Dim(size=size) for size in tensor.shape]
        ),
    )
    return Summary.Value(
        tag=metadata.get_instance_name(name, content_type),
        tensor=tensor_proto,
        metadata=tensor_metadata,
    )
def _get_json_config(config_dict):
    """Return ``config_dict`` as a key-sorted JSON string, or ``"{}"`` for None."""
    if config_dict is None:
        return "{}"
    return json.dumps(config_dict, sort_keys=True)
# https://github.com/tensorflow/tensorboard/blob/master/tensorboard/plugins/mesh/summary.py
def mesh(
    tag, vertices, colors, faces, config_dict, display_name=None, description=None
):
    """Outputs a merged `Summary` protocol buffer with a mesh/point cloud.

    Args:
      tag: A name for this summary operation.
      vertices: Tensor of shape `[dim_1, ..., dim_n, 3]` representing the 3D
        coordinates of vertices.
      colors: Tensor of shape `[dim_1, ..., dim_n, 3]` containing colors for each
        vertex, or None.
      faces: Tensor of shape `[dim_1, ..., dim_n, 3]` containing indices of
        vertices within each triangle, or None.
      config_dict: Dictionary with ThreeJS classes names and configuration, or
        None for an empty configuration.
      display_name: If set, will be used as the display name in TensorBoard.
      description: A longform readable description of the summary data. Markdown
        is supported.

    Returns:
      Merged summary for mesh/point cloud representation: one value per
      non-None input, in the order vertices, faces, colors.
    """
    from tensorboard.plugins.mesh import metadata
    from tensorboard.plugins.mesh.plugin_data_pb2 import MeshPluginData

    json_config = _get_json_config(config_dict)
    candidates = (
        (vertices, MeshPluginData.VERTEX),
        (faces, MeshPluginData.FACE),
        (colors, MeshPluginData.COLOR),
    )
    present = [(data, kind) for data, kind in candidates if data is not None]
    components = metadata.get_components_bitmask([kind for _, kind in present])
    summaries = [
        _get_tensor_summary(
            tag, display_name, description, data, kind, components, json_config
        )
        for data, kind in present
    ]
    return Summary(value=summaries)
|