test_assert_extension_array_equal.py 3.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113
  1. import numpy as np
  2. import pytest
  3. from pandas import array
  4. import pandas._testing as tm
  5. from pandas.core.arrays.sparse import SparseArray
  6. @pytest.mark.parametrize(
  7. "kwargs",
  8. [
  9. {}, # Default is check_exact=False
  10. {"check_exact": False},
  11. {"check_exact": True},
  12. ],
  13. )
  14. def test_assert_extension_array_equal_not_exact(kwargs):
  15. # see gh-23709
  16. arr1 = SparseArray([-0.17387645482451206, 0.3414148016424936])
  17. arr2 = SparseArray([-0.17387645482451206, 0.3414148016424937])
  18. if kwargs.get("check_exact", False):
  19. msg = """\
  20. ExtensionArray are different
  21. ExtensionArray values are different \\(50\\.0 %\\)
  22. \\[left\\]: \\[-0\\.17387645482.*, 0\\.341414801642.*\\]
  23. \\[right\\]: \\[-0\\.17387645482.*, 0\\.341414801642.*\\]"""
  24. with pytest.raises(AssertionError, match=msg):
  25. tm.assert_extension_array_equal(arr1, arr2, **kwargs)
  26. else:
  27. tm.assert_extension_array_equal(arr1, arr2, **kwargs)
  28. @pytest.mark.parametrize("decimals", range(10))
  29. def test_assert_extension_array_equal_less_precise(decimals):
  30. rtol = 0.5 * 10**-decimals
  31. arr1 = SparseArray([0.5, 0.123456])
  32. arr2 = SparseArray([0.5, 0.123457])
  33. if decimals >= 5:
  34. msg = """\
  35. ExtensionArray are different
  36. ExtensionArray values are different \\(50\\.0 %\\)
  37. \\[left\\]: \\[0\\.5, 0\\.123456\\]
  38. \\[right\\]: \\[0\\.5, 0\\.123457\\]"""
  39. with pytest.raises(AssertionError, match=msg):
  40. tm.assert_extension_array_equal(arr1, arr2, rtol=rtol)
  41. else:
  42. tm.assert_extension_array_equal(arr1, arr2, rtol=rtol)
  43. def test_assert_extension_array_equal_dtype_mismatch(check_dtype):
  44. end = 5
  45. kwargs = {"check_dtype": check_dtype}
  46. arr1 = SparseArray(np.arange(end, dtype="int64"))
  47. arr2 = SparseArray(np.arange(end, dtype="int32"))
  48. if check_dtype:
  49. msg = """\
  50. ExtensionArray are different
  51. Attribute "dtype" are different
  52. \\[left\\]: Sparse\\[int64, 0\\]
  53. \\[right\\]: Sparse\\[int32, 0\\]"""
  54. with pytest.raises(AssertionError, match=msg):
  55. tm.assert_extension_array_equal(arr1, arr2, **kwargs)
  56. else:
  57. tm.assert_extension_array_equal(arr1, arr2, **kwargs)
  58. def test_assert_extension_array_equal_missing_values():
  59. arr1 = SparseArray([np.nan, 1, 2, np.nan])
  60. arr2 = SparseArray([np.nan, 1, 2, 3])
  61. msg = """\
  62. ExtensionArray NA mask are different
  63. ExtensionArray NA mask values are different \\(25\\.0 %\\)
  64. \\[left\\]: \\[True, False, False, True\\]
  65. \\[right\\]: \\[True, False, False, False\\]"""
  66. with pytest.raises(AssertionError, match=msg):
  67. tm.assert_extension_array_equal(arr1, arr2)
  68. @pytest.mark.parametrize("side", ["left", "right"])
  69. def test_assert_extension_array_equal_non_extension_array(side):
  70. numpy_array = np.arange(5)
  71. extension_array = SparseArray(numpy_array)
  72. msg = f"{side} is not an ExtensionArray"
  73. args = (
  74. (numpy_array, extension_array)
  75. if side == "left"
  76. else (extension_array, numpy_array)
  77. )
  78. with pytest.raises(AssertionError, match=msg):
  79. tm.assert_extension_array_equal(*args)
  80. @pytest.mark.parametrize("right_dtype", ["Int32", "int64"])
  81. def test_assert_extension_array_equal_ignore_dtype_mismatch(right_dtype):
  82. # https://github.com/pandas-dev/pandas/issues/35715
  83. left = array([1, 2, 3], dtype="Int64")
  84. right = array([1, 2, 3], dtype=right_dtype)
  85. tm.assert_extension_array_equal(left, right, check_dtype=False)