test_comparisons.py 9.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313
  1. from datetime import (
  2. datetime,
  3. timedelta,
  4. )
  5. import operator
  6. import numpy as np
  7. import pytest
  8. from pandas import Timestamp
  9. import pandas._testing as tm
  10. class TestTimestampComparison:
  11. def test_compare_non_nano_dt64(self):
  12. # don't raise when converting dt64 to Timestamp in __richcmp__
  13. dt = np.datetime64("1066-10-14")
  14. ts = Timestamp(dt)
  15. assert dt == ts
  16. def test_comparison_dt64_ndarray(self):
  17. ts = Timestamp("2021-01-01")
  18. ts2 = Timestamp("2019-04-05")
  19. arr = np.array([[ts.asm8, ts2.asm8]], dtype="M8[ns]")
  20. result = ts == arr
  21. expected = np.array([[True, False]], dtype=bool)
  22. tm.assert_numpy_array_equal(result, expected)
  23. result = arr == ts
  24. tm.assert_numpy_array_equal(result, expected)
  25. result = ts != arr
  26. tm.assert_numpy_array_equal(result, ~expected)
  27. result = arr != ts
  28. tm.assert_numpy_array_equal(result, ~expected)
  29. result = ts2 < arr
  30. tm.assert_numpy_array_equal(result, expected)
  31. result = arr < ts2
  32. tm.assert_numpy_array_equal(result, np.array([[False, False]], dtype=bool))
  33. result = ts2 <= arr
  34. tm.assert_numpy_array_equal(result, np.array([[True, True]], dtype=bool))
  35. result = arr <= ts2
  36. tm.assert_numpy_array_equal(result, ~expected)
  37. result = ts >= arr
  38. tm.assert_numpy_array_equal(result, np.array([[True, True]], dtype=bool))
  39. result = arr >= ts
  40. tm.assert_numpy_array_equal(result, np.array([[True, False]], dtype=bool))
  41. @pytest.mark.parametrize("reverse", [True, False])
  42. def test_comparison_dt64_ndarray_tzaware(self, reverse, comparison_op):
  43. ts = Timestamp("2021-01-01 00:00:00.00000", tz="UTC")
  44. arr = np.array([ts.asm8, ts.asm8], dtype="M8[ns]")
  45. left, right = ts, arr
  46. if reverse:
  47. left, right = arr, ts
  48. if comparison_op is operator.eq:
  49. expected = np.array([False, False], dtype=bool)
  50. result = comparison_op(left, right)
  51. tm.assert_numpy_array_equal(result, expected)
  52. elif comparison_op is operator.ne:
  53. expected = np.array([True, True], dtype=bool)
  54. result = comparison_op(left, right)
  55. tm.assert_numpy_array_equal(result, expected)
  56. else:
  57. msg = "Cannot compare tz-naive and tz-aware timestamps"
  58. with pytest.raises(TypeError, match=msg):
  59. comparison_op(left, right)
  60. def test_comparison_object_array(self):
  61. # GH#15183
  62. ts = Timestamp("2011-01-03 00:00:00-0500", tz="US/Eastern")
  63. other = Timestamp("2011-01-01 00:00:00-0500", tz="US/Eastern")
  64. naive = Timestamp("2011-01-01 00:00:00")
  65. arr = np.array([other, ts], dtype=object)
  66. res = arr == ts
  67. expected = np.array([False, True], dtype=bool)
  68. assert (res == expected).all()
  69. # 2D case
  70. arr = np.array([[other, ts], [ts, other]], dtype=object)
  71. res = arr != ts
  72. expected = np.array([[True, False], [False, True]], dtype=bool)
  73. assert res.shape == expected.shape
  74. assert (res == expected).all()
  75. # tzaware mismatch
  76. arr = np.array([naive], dtype=object)
  77. msg = "Cannot compare tz-naive and tz-aware timestamps"
  78. with pytest.raises(TypeError, match=msg):
  79. arr < ts
  80. def test_comparison(self):
  81. # 5-18-2012 00:00:00.000
  82. stamp = 1337299200000000000
  83. val = Timestamp(stamp)
  84. assert val == val
  85. assert not val != val
  86. assert not val < val
  87. assert val <= val
  88. assert not val > val
  89. assert val >= val
  90. other = datetime(2012, 5, 18)
  91. assert val == other
  92. assert not val != other
  93. assert not val < other
  94. assert val <= other
  95. assert not val > other
  96. assert val >= other
  97. other = Timestamp(stamp + 100)
  98. assert val != other
  99. assert val != other
  100. assert val < other
  101. assert val <= other
  102. assert other > val
  103. assert other >= val
  104. def test_compare_invalid(self):
  105. # GH#8058
  106. val = Timestamp("20130101 12:01:02")
  107. assert not val == "foo"
  108. assert not val == 10.0
  109. assert not val == 1
  110. assert not val == []
  111. assert not val == {"foo": 1}
  112. assert not val == np.float64(1)
  113. assert not val == np.int64(1)
  114. assert val != "foo"
  115. assert val != 10.0
  116. assert val != 1
  117. assert val != []
  118. assert val != {"foo": 1}
  119. assert val != np.float64(1)
  120. assert val != np.int64(1)
  121. @pytest.mark.parametrize("tz", [None, "US/Pacific"])
  122. def test_compare_date(self, tz):
  123. # GH#36131 comparing Timestamp with date object is deprecated
  124. ts = Timestamp("2021-01-01 00:00:00.00000", tz=tz)
  125. dt = ts.to_pydatetime().date()
  126. # in 2.0 we disallow comparing pydate objects with Timestamps,
  127. # following the stdlib datetime behavior.
  128. msg = "Cannot compare Timestamp with datetime.date"
  129. for left, right in [(ts, dt), (dt, ts)]:
  130. assert not left == right
  131. assert left != right
  132. with pytest.raises(TypeError, match=msg):
  133. left < right
  134. with pytest.raises(TypeError, match=msg):
  135. left <= right
  136. with pytest.raises(TypeError, match=msg):
  137. left > right
  138. with pytest.raises(TypeError, match=msg):
  139. left >= right
  140. def test_cant_compare_tz_naive_w_aware(self, utc_fixture):
  141. # see GH#1404
  142. a = Timestamp("3/12/2012")
  143. b = Timestamp("3/12/2012", tz=utc_fixture)
  144. msg = "Cannot compare tz-naive and tz-aware timestamps"
  145. assert not a == b
  146. assert a != b
  147. with pytest.raises(TypeError, match=msg):
  148. a < b
  149. with pytest.raises(TypeError, match=msg):
  150. a <= b
  151. with pytest.raises(TypeError, match=msg):
  152. a > b
  153. with pytest.raises(TypeError, match=msg):
  154. a >= b
  155. assert not b == a
  156. assert b != a
  157. with pytest.raises(TypeError, match=msg):
  158. b < a
  159. with pytest.raises(TypeError, match=msg):
  160. b <= a
  161. with pytest.raises(TypeError, match=msg):
  162. b > a
  163. with pytest.raises(TypeError, match=msg):
  164. b >= a
  165. assert not a == b.to_pydatetime()
  166. assert not a.to_pydatetime() == b
  167. def test_timestamp_compare_scalars(self):
  168. # case where ndim == 0
  169. lhs = np.datetime64(datetime(2013, 12, 6))
  170. rhs = Timestamp("now")
  171. nat = Timestamp("nat")
  172. ops = {"gt": "lt", "lt": "gt", "ge": "le", "le": "ge", "eq": "eq", "ne": "ne"}
  173. for left, right in ops.items():
  174. left_f = getattr(operator, left)
  175. right_f = getattr(operator, right)
  176. expected = left_f(lhs, rhs)
  177. result = right_f(rhs, lhs)
  178. assert result == expected
  179. expected = left_f(rhs, nat)
  180. result = right_f(nat, rhs)
  181. assert result == expected
  182. def test_timestamp_compare_with_early_datetime(self):
  183. # e.g. datetime.min
  184. stamp = Timestamp("2012-01-01")
  185. assert not stamp == datetime.min
  186. assert not stamp == datetime(1600, 1, 1)
  187. assert not stamp == datetime(2700, 1, 1)
  188. assert stamp != datetime.min
  189. assert stamp != datetime(1600, 1, 1)
  190. assert stamp != datetime(2700, 1, 1)
  191. assert stamp > datetime(1600, 1, 1)
  192. assert stamp >= datetime(1600, 1, 1)
  193. assert stamp < datetime(2700, 1, 1)
  194. assert stamp <= datetime(2700, 1, 1)
  195. other = Timestamp.min.to_pydatetime(warn=False)
  196. assert other - timedelta(microseconds=1) < Timestamp.min
  197. def test_timestamp_compare_oob_dt64(self):
  198. us = np.timedelta64(1, "us")
  199. other = np.datetime64(Timestamp.min).astype("M8[us]")
  200. # This may change if the implementation bound is dropped to match
  201. # DatetimeArray/DatetimeIndex GH#24124
  202. assert Timestamp.min > other
  203. # Note: numpy gets the reversed comparison wrong
  204. other = np.datetime64(Timestamp.max).astype("M8[us]")
  205. assert Timestamp.max > other # not actually OOB
  206. assert other < Timestamp.max
  207. assert Timestamp.max < other + us
  208. # Note: numpy gets the reversed comparison wrong
  209. # GH-42794
  210. other = datetime(9999, 9, 9)
  211. assert Timestamp.min < other
  212. assert other > Timestamp.min
  213. assert Timestamp.max < other
  214. assert other > Timestamp.max
  215. other = datetime(1, 1, 1)
  216. assert Timestamp.max > other
  217. assert other < Timestamp.max
  218. assert Timestamp.min > other
  219. assert other < Timestamp.min
  220. def test_compare_zerodim_array(self, fixed_now_ts):
  221. # GH#26916
  222. ts = fixed_now_ts
  223. dt64 = np.datetime64("2016-01-01", "ns")
  224. arr = np.array(dt64)
  225. assert arr.ndim == 0
  226. result = arr < ts
  227. assert result is np.bool_(True)
  228. result = arr > ts
  229. assert result is np.bool_(False)
  230. def test_rich_comparison_with_unsupported_type():
  231. # Comparisons with unsupported objects should return NotImplemented
  232. # (it previously raised TypeError, see #24011)
  233. class Inf:
  234. def __lt__(self, o):
  235. return False
  236. def __le__(self, o):
  237. return isinstance(o, Inf)
  238. def __gt__(self, o):
  239. return not isinstance(o, Inf)
  240. def __ge__(self, o):
  241. return True
  242. def __eq__(self, other) -> bool:
  243. return isinstance(other, Inf)
  244. inf = Inf()
  245. timestamp = Timestamp("2018-11-30")
  246. for left, right in [(inf, timestamp), (timestamp, inf)]:
  247. assert left > right or left < right
  248. assert left >= right or left <= right
  249. assert not left == right # pylint: disable=unneeded-not
  250. assert left != right