test_assert_series_equal.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424
  1. import numpy as np
  2. import pytest
  3. import pandas as pd
  4. from pandas import (
  5. Categorical,
  6. DataFrame,
  7. Series,
  8. )
  9. import pandas._testing as tm
  10. def _assert_series_equal_both(a, b, **kwargs):
  11. """
  12. Check that two Series equal.
  13. This check is performed commutatively.
  14. Parameters
  15. ----------
  16. a : Series
  17. The first Series to compare.
  18. b : Series
  19. The second Series to compare.
  20. kwargs : dict
  21. The arguments passed to `tm.assert_series_equal`.
  22. """
  23. tm.assert_series_equal(a, b, **kwargs)
  24. tm.assert_series_equal(b, a, **kwargs)
  25. def _assert_not_series_equal(a, b, **kwargs):
  26. """
  27. Check that two Series are not equal.
  28. Parameters
  29. ----------
  30. a : Series
  31. The first Series to compare.
  32. b : Series
  33. The second Series to compare.
  34. kwargs : dict
  35. The arguments passed to `tm.assert_series_equal`.
  36. """
  37. try:
  38. tm.assert_series_equal(a, b, **kwargs)
  39. msg = "The two Series were equal when they shouldn't have been"
  40. pytest.fail(msg=msg)
  41. except AssertionError:
  42. pass
  43. def _assert_not_series_equal_both(a, b, **kwargs):
  44. """
  45. Check that two Series are not equal.
  46. This check is performed commutatively.
  47. Parameters
  48. ----------
  49. a : Series
  50. The first Series to compare.
  51. b : Series
  52. The second Series to compare.
  53. kwargs : dict
  54. The arguments passed to `tm.assert_series_equal`.
  55. """
  56. _assert_not_series_equal(a, b, **kwargs)
  57. _assert_not_series_equal(b, a, **kwargs)
  58. @pytest.mark.parametrize("data", [range(3), list("abc"), list("áàä")])
  59. def test_series_equal(data):
  60. _assert_series_equal_both(Series(data), Series(data))
  61. @pytest.mark.parametrize(
  62. "data1,data2",
  63. [
  64. (range(3), range(1, 4)),
  65. (list("abc"), list("xyz")),
  66. (list("áàä"), list("éèë")),
  67. (list("áàä"), list(b"aaa")),
  68. (range(3), range(4)),
  69. ],
  70. )
  71. def test_series_not_equal_value_mismatch(data1, data2):
  72. _assert_not_series_equal_both(Series(data1), Series(data2))
  73. @pytest.mark.parametrize(
  74. "kwargs",
  75. [
  76. {"dtype": "float64"}, # dtype mismatch
  77. {"index": [1, 2, 4]}, # index mismatch
  78. {"name": "foo"}, # name mismatch
  79. ],
  80. )
  81. def test_series_not_equal_metadata_mismatch(kwargs):
  82. data = range(3)
  83. s1 = Series(data)
  84. s2 = Series(data, **kwargs)
  85. _assert_not_series_equal_both(s1, s2)
  86. @pytest.mark.parametrize("data1,data2", [(0.12345, 0.12346), (0.1235, 0.1236)])
  87. @pytest.mark.parametrize("dtype", ["float32", "float64", "Float32"])
  88. @pytest.mark.parametrize("decimals", [0, 1, 2, 3, 5, 10])
  89. def test_less_precise(data1, data2, dtype, decimals):
  90. rtol = 10**-decimals
  91. s1 = Series([data1], dtype=dtype)
  92. s2 = Series([data2], dtype=dtype)
  93. if decimals in (5, 10) or (decimals >= 3 and abs(data1 - data2) >= 0.0005):
  94. msg = "Series values are different"
  95. with pytest.raises(AssertionError, match=msg):
  96. tm.assert_series_equal(s1, s2, rtol=rtol)
  97. else:
  98. _assert_series_equal_both(s1, s2, rtol=rtol)
  99. @pytest.mark.parametrize(
  100. "s1,s2,msg",
  101. [
  102. # Index
  103. (
  104. Series(["l1", "l2"], index=[1, 2]),
  105. Series(["l1", "l2"], index=[1.0, 2.0]),
  106. "Series\\.index are different",
  107. ),
  108. # MultiIndex
  109. (
  110. DataFrame.from_records(
  111. {"a": [1, 2], "b": [2.1, 1.5], "c": ["l1", "l2"]}, index=["a", "b"]
  112. ).c,
  113. DataFrame.from_records(
  114. {"a": [1.0, 2.0], "b": [2.1, 1.5], "c": ["l1", "l2"]}, index=["a", "b"]
  115. ).c,
  116. "MultiIndex level \\[0\\] are different",
  117. ),
  118. ],
  119. )
  120. def test_series_equal_index_dtype(s1, s2, msg, check_index_type):
  121. kwargs = {"check_index_type": check_index_type}
  122. if check_index_type:
  123. with pytest.raises(AssertionError, match=msg):
  124. tm.assert_series_equal(s1, s2, **kwargs)
  125. else:
  126. tm.assert_series_equal(s1, s2, **kwargs)
  127. @pytest.mark.parametrize("check_like", [True, False])
  128. def test_series_equal_order_mismatch(check_like):
  129. s1 = Series([1, 2, 3], index=["a", "b", "c"])
  130. s2 = Series([3, 2, 1], index=["c", "b", "a"])
  131. if not check_like: # Do not ignore index ordering.
  132. with pytest.raises(AssertionError, match="Series.index are different"):
  133. tm.assert_series_equal(s1, s2, check_like=check_like)
  134. else:
  135. _assert_series_equal_both(s1, s2, check_like=check_like)
  136. @pytest.mark.parametrize("check_index", [True, False])
  137. def test_series_equal_index_mismatch(check_index):
  138. s1 = Series([1, 2, 3], index=["a", "b", "c"])
  139. s2 = Series([1, 2, 3], index=["c", "b", "a"])
  140. if check_index: # Do not ignore index.
  141. with pytest.raises(AssertionError, match="Series.index are different"):
  142. tm.assert_series_equal(s1, s2, check_index=check_index)
  143. else:
  144. _assert_series_equal_both(s1, s2, check_index=check_index)
  145. def test_series_invalid_param_combination():
  146. left = Series(dtype=object)
  147. right = Series(dtype=object)
  148. with pytest.raises(
  149. ValueError, match="check_like must be False if check_index is False"
  150. ):
  151. tm.assert_series_equal(left, right, check_index=False, check_like=True)
  152. def test_series_equal_length_mismatch(rtol):
  153. msg = """Series are different
  154. Series length are different
  155. \\[left\\]: 3, RangeIndex\\(start=0, stop=3, step=1\\)
  156. \\[right\\]: 4, RangeIndex\\(start=0, stop=4, step=1\\)"""
  157. s1 = Series([1, 2, 3])
  158. s2 = Series([1, 2, 3, 4])
  159. with pytest.raises(AssertionError, match=msg):
  160. tm.assert_series_equal(s1, s2, rtol=rtol)
  161. def test_series_equal_numeric_values_mismatch(rtol):
  162. msg = """Series are different
  163. Series values are different \\(33\\.33333 %\\)
  164. \\[index\\]: \\[0, 1, 2\\]
  165. \\[left\\]: \\[1, 2, 3\\]
  166. \\[right\\]: \\[1, 2, 4\\]"""
  167. s1 = Series([1, 2, 3])
  168. s2 = Series([1, 2, 4])
  169. with pytest.raises(AssertionError, match=msg):
  170. tm.assert_series_equal(s1, s2, rtol=rtol)
  171. def test_series_equal_categorical_values_mismatch(rtol):
  172. msg = """Series are different
  173. Series values are different \\(66\\.66667 %\\)
  174. \\[index\\]: \\[0, 1, 2\\]
  175. \\[left\\]: \\['a', 'b', 'c'\\]
  176. Categories \\(3, object\\): \\['a', 'b', 'c'\\]
  177. \\[right\\]: \\['a', 'c', 'b'\\]
  178. Categories \\(3, object\\): \\['a', 'b', 'c'\\]"""
  179. s1 = Series(Categorical(["a", "b", "c"]))
  180. s2 = Series(Categorical(["a", "c", "b"]))
  181. with pytest.raises(AssertionError, match=msg):
  182. tm.assert_series_equal(s1, s2, rtol=rtol)
  183. def test_series_equal_datetime_values_mismatch(rtol):
  184. msg = """Series are different
  185. Series values are different \\(100.0 %\\)
  186. \\[index\\]: \\[0, 1, 2\\]
  187. \\[left\\]: \\[1514764800000000000, 1514851200000000000, 1514937600000000000\\]
  188. \\[right\\]: \\[1549065600000000000, 1549152000000000000, 1549238400000000000\\]"""
  189. s1 = Series(pd.date_range("2018-01-01", periods=3, freq="D"))
  190. s2 = Series(pd.date_range("2019-02-02", periods=3, freq="D"))
  191. with pytest.raises(AssertionError, match=msg):
  192. tm.assert_series_equal(s1, s2, rtol=rtol)
  193. def test_series_equal_categorical_mismatch(check_categorical):
  194. msg = """Attributes of Series are different
  195. Attribute "dtype" are different
  196. \\[left\\]: CategoricalDtype\\(categories=\\['a', 'b'\\], ordered=False\\)
  197. \\[right\\]: CategoricalDtype\\(categories=\\['a', 'b', 'c'\\], \
  198. ordered=False\\)"""
  199. s1 = Series(Categorical(["a", "b"]))
  200. s2 = Series(Categorical(["a", "b"], categories=list("abc")))
  201. if check_categorical:
  202. with pytest.raises(AssertionError, match=msg):
  203. tm.assert_series_equal(s1, s2, check_categorical=check_categorical)
  204. else:
  205. _assert_series_equal_both(s1, s2, check_categorical=check_categorical)
  206. def test_assert_series_equal_extension_dtype_mismatch():
  207. # https://github.com/pandas-dev/pandas/issues/32747
  208. left = Series(pd.array([1, 2, 3], dtype="Int64"))
  209. right = left.astype(int)
  210. msg = """Attributes of Series are different
  211. Attribute "dtype" are different
  212. \\[left\\]: Int64
  213. \\[right\\]: int[32|64]"""
  214. tm.assert_series_equal(left, right, check_dtype=False)
  215. with pytest.raises(AssertionError, match=msg):
  216. tm.assert_series_equal(left, right, check_dtype=True)
  217. def test_assert_series_equal_interval_dtype_mismatch():
  218. # https://github.com/pandas-dev/pandas/issues/32747
  219. left = Series([pd.Interval(0, 1)], dtype="interval")
  220. right = left.astype(object)
  221. msg = """Attributes of Series are different
  222. Attribute "dtype" are different
  223. \\[left\\]: interval\\[int64, right\\]
  224. \\[right\\]: object"""
  225. tm.assert_series_equal(left, right, check_dtype=False)
  226. with pytest.raises(AssertionError, match=msg):
  227. tm.assert_series_equal(left, right, check_dtype=True)
  228. def test_series_equal_series_type():
  229. class MySeries(Series):
  230. pass
  231. s1 = Series([1, 2])
  232. s2 = Series([1, 2])
  233. s3 = MySeries([1, 2])
  234. tm.assert_series_equal(s1, s2, check_series_type=False)
  235. tm.assert_series_equal(s1, s2, check_series_type=True)
  236. tm.assert_series_equal(s1, s3, check_series_type=False)
  237. tm.assert_series_equal(s3, s1, check_series_type=False)
  238. with pytest.raises(AssertionError, match="Series classes are different"):
  239. tm.assert_series_equal(s1, s3, check_series_type=True)
  240. with pytest.raises(AssertionError, match="Series classes are different"):
  241. tm.assert_series_equal(s3, s1, check_series_type=True)
  242. def test_series_equal_exact_for_nonnumeric():
  243. # https://github.com/pandas-dev/pandas/issues/35446
  244. s1 = Series(["a", "b"])
  245. s2 = Series(["a", "b"])
  246. s3 = Series(["b", "a"])
  247. tm.assert_series_equal(s1, s2, check_exact=True)
  248. tm.assert_series_equal(s2, s1, check_exact=True)
  249. msg = """Series are different
  250. Series values are different \\(100\\.0 %\\)
  251. \\[index\\]: \\[0, 1\\]
  252. \\[left\\]: \\[a, b\\]
  253. \\[right\\]: \\[b, a\\]"""
  254. with pytest.raises(AssertionError, match=msg):
  255. tm.assert_series_equal(s1, s3, check_exact=True)
  256. msg = """Series are different
  257. Series values are different \\(100\\.0 %\\)
  258. \\[index\\]: \\[0, 1\\]
  259. \\[left\\]: \\[b, a\\]
  260. \\[right\\]: \\[a, b\\]"""
  261. with pytest.raises(AssertionError, match=msg):
  262. tm.assert_series_equal(s3, s1, check_exact=True)
  263. @pytest.mark.parametrize("right_dtype", ["Int32", "int64"])
  264. def test_assert_series_equal_ignore_extension_dtype_mismatch(right_dtype):
  265. # https://github.com/pandas-dev/pandas/issues/35715
  266. left = Series([1, 2, 3], dtype="Int64")
  267. right = Series([1, 2, 3], dtype=right_dtype)
  268. tm.assert_series_equal(left, right, check_dtype=False)
  269. def test_allows_duplicate_labels():
  270. left = Series([1])
  271. right = Series([1]).set_flags(allows_duplicate_labels=False)
  272. tm.assert_series_equal(left, left)
  273. tm.assert_series_equal(right, right)
  274. tm.assert_series_equal(left, right, check_flags=False)
  275. tm.assert_series_equal(right, left, check_flags=False)
  276. with pytest.raises(AssertionError, match="<Flags"):
  277. tm.assert_series_equal(left, right)
  278. with pytest.raises(AssertionError, match="<Flags"):
  279. tm.assert_series_equal(left, right)
  280. def test_assert_series_equal_identical_na(nulls_fixture):
  281. ser = Series([nulls_fixture])
  282. tm.assert_series_equal(ser, ser.copy())
  283. # while we're here do Index too
  284. idx = pd.Index(ser)
  285. tm.assert_index_equal(idx, idx.copy(deep=True))
  286. def test_identical_nested_series_is_equal():
  287. # GH#22400
  288. x = Series(
  289. [
  290. 0,
  291. 0.0131142231938,
  292. 1.77774652865e-05,
  293. np.array([0.4722720840328748, 0.4216929783681722]),
  294. ]
  295. )
  296. y = Series(
  297. [
  298. 0,
  299. 0.0131142231938,
  300. 1.77774652865e-05,
  301. np.array([0.4722720840328748, 0.4216929783681722]),
  302. ]
  303. )
  304. # These two arrays should be equal, nesting could cause issue
  305. tm.assert_series_equal(x, x)
  306. tm.assert_series_equal(x, x, check_exact=True)
  307. tm.assert_series_equal(x, y)
  308. tm.assert_series_equal(x, y, check_exact=True)
  309. @pytest.mark.parametrize("dtype", ["datetime64", "timedelta64"])
  310. def test_check_dtype_false_different_reso(dtype):
  311. # GH 52449
  312. ser_s = Series([1000213, 2131232, 21312331]).astype(f"{dtype}[s]")
  313. ser_ms = ser_s.astype(f"{dtype}[ms]")
  314. with pytest.raises(AssertionError, match="Attributes of Series are different"):
  315. tm.assert_series_equal(ser_s, ser_ms)
  316. tm.assert_series_equal(ser_ms, ser_s, check_dtype=False)
  317. ser_ms -= Series([1, 1, 1]).astype(f"{dtype}[ms]")
  318. with pytest.raises(AssertionError, match="Series are different"):
  319. tm.assert_series_equal(ser_s, ser_ms)
  320. with pytest.raises(AssertionError, match="Series are different"):
  321. tm.assert_series_equal(ser_s, ser_ms, check_dtype=False)