text.py 2.2 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576
  1. from __future__ import annotations
  2. from collections import defaultdict
  3. from dataclasses import dataclass
  4. import numpy as np
  5. import matplotlib as mpl
  6. from matplotlib.transforms import ScaledTranslation
  7. from seaborn._marks.base import (
  8. Mark,
  9. Mappable,
  10. MappableFloat,
  11. MappableString,
  12. MappableColor,
  13. resolve_properties,
  14. resolve_color,
  15. document_properties,
  16. )
  @document_properties
  @dataclass
  class Text(Mark):
      """
      A textual mark to annotate or represent data values.

      Examples
      --------
      .. include:: ../docstrings/objects.Text.rst

      """
      # Default label string; a "text" column in the data, when present,
      # takes precedence row by row (see _plot).
      text: MappableString = Mappable("")
      color: MappableColor = Mappable("k")
      alpha: MappableFloat = Mappable(1)
      fontsize: MappableFloat = Mappable(rc="font.size")
      halign: MappableString = Mappable("center")
      valign: MappableString = Mappable("center_baseline")
      # Distance, in points, to nudge the text away from its anchor (see _plot).
      offset: MappableFloat = Mappable(4)

      def _plot(self, split_gen, scales, orient):
          """
          Draw one text artist per data row and extend each Axes' data limits.

          Parameters
          ----------
          split_gen : callable
              Called with no arguments. Yields ``(keys, data, ax)`` triples, where
              ``data`` has ``x`` and ``y`` columns and may have a ``text`` column.
              ``data`` is presumably a pandas DataFrame, since it supports
              ``to_dict("records")``.
          scales
              Passed through unchanged to ``resolve_properties`` and
              ``resolve_color``.
          orient
              Not used by this mark.
          """
          ax_data = defaultdict(list)

          for keys, data, ax in split_gen():

              vals = resolve_properties(self, keys, scales)
              # Empty prefix means the mark's own color property. Alpha is never
              # passed to the artist separately, so it is presumably folded into
              # this color by resolve_color.
              color = resolve_color(self, keys, "", scales)

              halign = vals["halign"]
              valign = vals["valign"]
              fontsize = vals["fontsize"]
              default_text = vals["text"]
              # Points -> inches (72 points per inch), the unit that
              # dpi_scale_trans maps to display pixels.
              offset = vals["offset"] / 72

              # Nudge the text away from its anchor in the direction the text
              # extends. For example, right-aligned text sits left of the anchor,
              # so it is shifted further left. Centered alignments get no shift.
              offset_trans = ScaledTranslation(
                  {"right": -offset, "left": +offset}.get(halign, 0),
                  {"top": -offset, "bottom": +offset, "baseline": +offset}.get(valign, 0),
                  ax.figure.dpi_scale_trans,
              )
              # The transform is identical for every row in this group. Compose it
              # once and share it, instead of creating (and registering with
              # transData) a separate composite transform for each label.
              transform = ax.transData + offset_trans

              for row in data.to_dict("records"):
                  artist = mpl.text.Text(
                      x=row["x"],
                      y=row["y"],
                      text=str(row.get("text", default_text)),
                      color=color,
                      fontsize=fontsize,
                      horizontalalignment=halign,
                      verticalalignment=valign,
                      transform=transform,
                      **self.artist_kws,
                  )
                  ax.add_artist(artist)
                  ax_data[ax].append([row["x"], row["y"]])

          # Artists added through add_artist do not contribute to the Axes'
          # dataLim automatically. Register the anchor positions explicitly so
          # autoscaling takes the labels into account.
          for ax, ax_vals in ax_data.items():
              ax.update_datalim(np.array(ax_vals))