test_equals.py 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143
  1. from contextlib import nullcontext
  2. import copy
  3. import numpy as np
  4. import pytest
  5. from pandas._libs.missing import is_matching_na
  6. from pandas.compat.numpy import np_version_gte1p25
  7. from pandas.core.dtypes.common import is_float
  8. from pandas import (
  9. Index,
  10. MultiIndex,
  11. Series,
  12. )
  13. import pandas._testing as tm
  14. @pytest.mark.parametrize(
  15. "arr, idx",
  16. [
  17. ([1, 2, 3, 4], [0, 2, 1, 3]),
  18. ([1, np.nan, 3, np.nan], [0, 2, 1, 3]),
  19. (
  20. [1, np.nan, 3, np.nan],
  21. MultiIndex.from_tuples([(0, "a"), (1, "b"), (2, "c"), (3, "c")]),
  22. ),
  23. ],
  24. )
  25. def test_equals(arr, idx):
  26. s1 = Series(arr, index=idx)
  27. s2 = s1.copy()
  28. assert s1.equals(s2)
  29. s1[1] = 9
  30. assert not s1.equals(s2)
  31. @pytest.mark.parametrize(
  32. "val", [1, 1.1, 1 + 1j, True, "abc", [1, 2], (1, 2), {1, 2}, {"a": 1}, None]
  33. )
  34. def test_equals_list_array(val):
  35. # GH20676 Verify equals operator for list of Numpy arrays
  36. arr = np.array([1, 2])
  37. s1 = Series([arr, arr])
  38. s2 = s1.copy()
  39. assert s1.equals(s2)
  40. s1[1] = val
  41. cm = (
  42. tm.assert_produces_warning(FutureWarning, check_stacklevel=False)
  43. if isinstance(val, str) and not np_version_gte1p25
  44. else nullcontext()
  45. )
  46. with cm:
  47. assert not s1.equals(s2)
  48. def test_equals_false_negative():
  49. # GH8437 Verify false negative behavior of equals function for dtype object
  50. arr = [False, np.nan]
  51. s1 = Series(arr)
  52. s2 = s1.copy()
  53. s3 = Series(index=range(2), dtype=object)
  54. s4 = s3.copy()
  55. s5 = s3.copy()
  56. s6 = s3.copy()
  57. s3[:-1] = s4[:-1] = s5[0] = s6[0] = False
  58. assert s1.equals(s1)
  59. assert s1.equals(s2)
  60. assert s1.equals(s3)
  61. assert s1.equals(s4)
  62. assert s1.equals(s5)
  63. assert s5.equals(s6)
  64. def test_equals_matching_nas():
  65. # matching but not identical NAs
  66. left = Series([np.datetime64("NaT")], dtype=object)
  67. right = Series([np.datetime64("NaT")], dtype=object)
  68. assert left.equals(right)
  69. assert Index(left).equals(Index(right))
  70. assert left.array.equals(right.array)
  71. left = Series([np.timedelta64("NaT")], dtype=object)
  72. right = Series([np.timedelta64("NaT")], dtype=object)
  73. assert left.equals(right)
  74. assert Index(left).equals(Index(right))
  75. assert left.array.equals(right.array)
  76. left = Series([np.float64("NaN")], dtype=object)
  77. right = Series([np.float64("NaN")], dtype=object)
  78. assert left.equals(right)
  79. assert Index(left, dtype=left.dtype).equals(Index(right, dtype=right.dtype))
  80. assert left.array.equals(right.array)
  81. def test_equals_mismatched_nas(nulls_fixture, nulls_fixture2):
  82. # GH#39650
  83. left = nulls_fixture
  84. right = nulls_fixture2
  85. if hasattr(right, "copy"):
  86. right = right.copy()
  87. else:
  88. right = copy.copy(right)
  89. ser = Series([left], dtype=object)
  90. ser2 = Series([right], dtype=object)
  91. if is_matching_na(left, right):
  92. assert ser.equals(ser2)
  93. elif (left is None and is_float(right)) or (right is None and is_float(left)):
  94. assert ser.equals(ser2)
  95. else:
  96. assert not ser.equals(ser2)
  97. def test_equals_none_vs_nan():
  98. # GH#39650
  99. ser = Series([1, None], dtype=object)
  100. ser2 = Series([1, np.nan], dtype=object)
  101. assert ser.equals(ser2)
  102. assert Index(ser, dtype=ser.dtype).equals(Index(ser2, dtype=ser2.dtype))
  103. assert ser.array.equals(ser2.array)
  104. def test_equals_None_vs_float():
  105. # GH#44190
  106. left = Series([-np.inf, np.nan, -1.0, 0.0, 1.0, 10 / 3, np.inf], dtype=object)
  107. right = Series([None] * len(left))
  108. # these series were found to be equal due to a bug, check that they are correctly
  109. # found to not equal
  110. assert not left.equals(right)
  111. assert not right.equals(left)
  112. assert not left.to_frame().equals(right.to_frame())
  113. assert not right.to_frame().equals(left.to_frame())
  114. assert not Index(left, dtype="object").equals(Index(right, dtype="object"))
  115. assert not Index(right, dtype="object").equals(Index(left, dtype="object"))