_testing.py 2.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990
  1. import numpy as np
  2. import matplotlib as mpl
  3. from matplotlib.colors import to_rgb, to_rgba
  4. from numpy.testing import assert_array_equal
  5. USE_PROPS = [
  6. "alpha",
  7. "edgecolor",
  8. "facecolor",
  9. "fill",
  10. "hatch",
  11. "height",
  12. "linestyle",
  13. "linewidth",
  14. "paths",
  15. "xy",
  16. "xydata",
  17. "sizes",
  18. "zorder",
  19. ]
  20. def assert_artists_equal(list1, list2):
  21. assert len(list1) == len(list2)
  22. for a1, a2 in zip(list1, list2):
  23. assert a1.__class__ == a2.__class__
  24. prop1 = a1.properties()
  25. prop2 = a2.properties()
  26. for key in USE_PROPS:
  27. if key not in prop1:
  28. continue
  29. v1 = prop1[key]
  30. v2 = prop2[key]
  31. if key == "paths":
  32. for p1, p2 in zip(v1, v2):
  33. assert_array_equal(p1.vertices, p2.vertices)
  34. assert_array_equal(p1.codes, p2.codes)
  35. elif key == "color":
  36. v1 = mpl.colors.to_rgba(v1)
  37. v2 = mpl.colors.to_rgba(v2)
  38. assert v1 == v2
  39. elif isinstance(v1, np.ndarray):
  40. assert_array_equal(v1, v2)
  41. else:
  42. assert v1 == v2
  43. def assert_legends_equal(leg1, leg2):
  44. assert leg1.get_title().get_text() == leg2.get_title().get_text()
  45. for t1, t2 in zip(leg1.get_texts(), leg2.get_texts()):
  46. assert t1.get_text() == t2.get_text()
  47. assert_artists_equal(
  48. leg1.get_patches(), leg2.get_patches(),
  49. )
  50. assert_artists_equal(
  51. leg1.get_lines(), leg2.get_lines(),
  52. )
  53. def assert_plots_equal(ax1, ax2, labels=True):
  54. assert_artists_equal(ax1.patches, ax2.patches)
  55. assert_artists_equal(ax1.lines, ax2.lines)
  56. assert_artists_equal(ax1.collections, ax2.collections)
  57. if labels:
  58. assert ax1.get_xlabel() == ax2.get_xlabel()
  59. assert ax1.get_ylabel() == ax2.get_ylabel()
  60. def assert_colors_equal(a, b, check_alpha=True):
  61. def handle_array(x):
  62. if isinstance(x, np.ndarray):
  63. if x.ndim > 1:
  64. x = np.unique(x, axis=0).squeeze()
  65. if x.ndim > 1:
  66. raise ValueError("Color arrays must be 1 dimensional")
  67. return x
  68. a = handle_array(a)
  69. b = handle_array(b)
  70. f = to_rgba if check_alpha else to_rgb
  71. assert f(a) == f(b)