test_equals.py 6.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180
  1. """
  2. Tests shared for DatetimeIndex/TimedeltaIndex/PeriodIndex
  3. """
  4. from datetime import (
  5. datetime,
  6. timedelta,
  7. )
  8. import numpy as np
  9. import pytest
  10. import pandas as pd
  11. from pandas import (
  12. CategoricalIndex,
  13. DatetimeIndex,
  14. Index,
  15. PeriodIndex,
  16. TimedeltaIndex,
  17. date_range,
  18. period_range,
  19. )
  20. import pandas._testing as tm
  21. class EqualsTests:
  22. def test_not_equals_numeric(self, index):
  23. assert not index.equals(Index(index.asi8))
  24. assert not index.equals(Index(index.asi8.astype("u8")))
  25. assert not index.equals(Index(index.asi8).astype("f8"))
  26. def test_equals(self, index):
  27. assert index.equals(index)
  28. assert index.equals(index.astype(object))
  29. assert index.equals(CategoricalIndex(index))
  30. assert index.equals(CategoricalIndex(index.astype(object)))
  31. def test_not_equals_non_arraylike(self, index):
  32. assert not index.equals(list(index))
  33. def test_not_equals_strings(self, index):
  34. other = Index([str(x) for x in index], dtype=object)
  35. assert not index.equals(other)
  36. assert not index.equals(CategoricalIndex(other))
  37. def test_not_equals_misc_strs(self, index):
  38. other = Index(list("abc"))
  39. assert not index.equals(other)
  40. class TestPeriodIndexEquals(EqualsTests):
  41. @pytest.fixture
  42. def index(self):
  43. return period_range("2013-01-01", periods=5, freq="D")
  44. # TODO: de-duplicate with other test_equals2 methods
  45. @pytest.mark.parametrize("freq", ["D", "M"])
  46. def test_equals2(self, freq):
  47. # GH#13107
  48. idx = PeriodIndex(["2011-01-01", "2011-01-02", "NaT"], freq=freq)
  49. assert idx.equals(idx)
  50. assert idx.equals(idx.copy())
  51. assert idx.equals(idx.astype(object))
  52. assert idx.astype(object).equals(idx)
  53. assert idx.astype(object).equals(idx.astype(object))
  54. assert not idx.equals(list(idx))
  55. assert not idx.equals(pd.Series(idx))
  56. idx2 = PeriodIndex(["2011-01-01", "2011-01-02", "NaT"], freq="H")
  57. assert not idx.equals(idx2)
  58. assert not idx.equals(idx2.copy())
  59. assert not idx.equals(idx2.astype(object))
  60. assert not idx.astype(object).equals(idx2)
  61. assert not idx.equals(list(idx2))
  62. assert not idx.equals(pd.Series(idx2))
  63. # same internal, different tz
  64. idx3 = PeriodIndex._simple_new(
  65. idx._values._simple_new(idx._values.asi8, freq="H")
  66. )
  67. tm.assert_numpy_array_equal(idx.asi8, idx3.asi8)
  68. assert not idx.equals(idx3)
  69. assert not idx.equals(idx3.copy())
  70. assert not idx.equals(idx3.astype(object))
  71. assert not idx.astype(object).equals(idx3)
  72. assert not idx.equals(list(idx3))
  73. assert not idx.equals(pd.Series(idx3))
  74. class TestDatetimeIndexEquals(EqualsTests):
  75. @pytest.fixture
  76. def index(self):
  77. return date_range("2013-01-01", periods=5)
  78. def test_equals2(self):
  79. # GH#13107
  80. idx = DatetimeIndex(["2011-01-01", "2011-01-02", "NaT"])
  81. assert idx.equals(idx)
  82. assert idx.equals(idx.copy())
  83. assert idx.equals(idx.astype(object))
  84. assert idx.astype(object).equals(idx)
  85. assert idx.astype(object).equals(idx.astype(object))
  86. assert not idx.equals(list(idx))
  87. assert not idx.equals(pd.Series(idx))
  88. idx2 = DatetimeIndex(["2011-01-01", "2011-01-02", "NaT"], tz="US/Pacific")
  89. assert not idx.equals(idx2)
  90. assert not idx.equals(idx2.copy())
  91. assert not idx.equals(idx2.astype(object))
  92. assert not idx.astype(object).equals(idx2)
  93. assert not idx.equals(list(idx2))
  94. assert not idx.equals(pd.Series(idx2))
  95. # same internal, different tz
  96. idx3 = DatetimeIndex(idx.asi8, tz="US/Pacific")
  97. tm.assert_numpy_array_equal(idx.asi8, idx3.asi8)
  98. assert not idx.equals(idx3)
  99. assert not idx.equals(idx3.copy())
  100. assert not idx.equals(idx3.astype(object))
  101. assert not idx.astype(object).equals(idx3)
  102. assert not idx.equals(list(idx3))
  103. assert not idx.equals(pd.Series(idx3))
  104. # check that we do not raise when comparing with OutOfBounds objects
  105. oob = Index([datetime(2500, 1, 1)] * 3, dtype=object)
  106. assert not idx.equals(oob)
  107. assert not idx2.equals(oob)
  108. assert not idx3.equals(oob)
  109. # check that we do not raise when comparing with OutOfBounds dt64
  110. oob2 = oob.map(np.datetime64)
  111. assert not idx.equals(oob2)
  112. assert not idx2.equals(oob2)
  113. assert not idx3.equals(oob2)
  114. @pytest.mark.parametrize("freq", ["B", "C"])
  115. def test_not_equals_bday(self, freq):
  116. rng = date_range("2009-01-01", "2010-01-01", freq=freq)
  117. assert not rng.equals(list(rng))
  118. class TestTimedeltaIndexEquals(EqualsTests):
  119. @pytest.fixture
  120. def index(self):
  121. return tm.makeTimedeltaIndex(10)
  122. def test_equals2(self):
  123. # GH#13107
  124. idx = TimedeltaIndex(["1 days", "2 days", "NaT"])
  125. assert idx.equals(idx)
  126. assert idx.equals(idx.copy())
  127. assert idx.equals(idx.astype(object))
  128. assert idx.astype(object).equals(idx)
  129. assert idx.astype(object).equals(idx.astype(object))
  130. assert not idx.equals(list(idx))
  131. assert not idx.equals(pd.Series(idx))
  132. idx2 = TimedeltaIndex(["2 days", "1 days", "NaT"])
  133. assert not idx.equals(idx2)
  134. assert not idx.equals(idx2.copy())
  135. assert not idx.equals(idx2.astype(object))
  136. assert not idx.astype(object).equals(idx2)
  137. assert not idx.astype(object).equals(idx2.astype(object))
  138. assert not idx.equals(list(idx2))
  139. assert not idx.equals(pd.Series(idx2))
  140. # Check that we dont raise OverflowError on comparisons outside the
  141. # implementation range GH#28532
  142. oob = Index([timedelta(days=10**6)] * 3, dtype=object)
  143. assert not idx.equals(oob)
  144. assert not idx2.equals(oob)
  145. oob2 = Index([np.timedelta64(x) for x in oob], dtype=object)
  146. assert (oob == oob2).all()
  147. assert not idx.equals(oob2)
  148. assert not idx2.equals(oob2)
  149. oob3 = oob.map(np.timedelta64)
  150. assert (oob3 == oob).all()
  151. assert not idx.equals(oob3)
  152. assert not idx2.equals(oob3)