test_setops.py 8.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259
  1. from datetime import datetime
  2. import numpy as np
  3. import pytest
  4. import pandas as pd
  5. from pandas import (
  6. Index,
  7. Series,
  8. )
  9. import pandas._testing as tm
  10. from pandas.core.algorithms import safe_sort
  11. class TestIndexSetOps:
  12. @pytest.mark.parametrize(
  13. "method", ["union", "intersection", "difference", "symmetric_difference"]
  14. )
  15. def test_setops_sort_validation(self, method):
  16. idx1 = Index(["a", "b"])
  17. idx2 = Index(["b", "c"])
  18. with pytest.raises(ValueError, match="The 'sort' keyword only takes"):
  19. getattr(idx1, method)(idx2, sort=2)
  20. # sort=True is supported as of GH#??
  21. getattr(idx1, method)(idx2, sort=True)
  22. def test_setops_preserve_object_dtype(self):
  23. idx = Index([1, 2, 3], dtype=object)
  24. result = idx.intersection(idx[1:])
  25. expected = idx[1:]
  26. tm.assert_index_equal(result, expected)
  27. # if other is not monotonic increasing, intersection goes through
  28. # a different route
  29. result = idx.intersection(idx[1:][::-1])
  30. tm.assert_index_equal(result, expected)
  31. result = idx._union(idx[1:], sort=None)
  32. expected = idx
  33. tm.assert_numpy_array_equal(result, expected.values)
  34. result = idx.union(idx[1:], sort=None)
  35. tm.assert_index_equal(result, expected)
  36. # if other is not monotonic increasing, _union goes through
  37. # a different route
  38. result = idx._union(idx[1:][::-1], sort=None)
  39. tm.assert_numpy_array_equal(result, expected.values)
  40. result = idx.union(idx[1:][::-1], sort=None)
  41. tm.assert_index_equal(result, expected)
  42. def test_union_base(self):
  43. index = Index([0, "a", 1, "b", 2, "c"])
  44. first = index[3:]
  45. second = index[:5]
  46. result = first.union(second)
  47. expected = Index([0, 1, 2, "a", "b", "c"])
  48. tm.assert_index_equal(result, expected)
  49. @pytest.mark.parametrize("klass", [np.array, Series, list])
  50. def test_union_different_type_base(self, klass):
  51. # GH 10149
  52. index = Index([0, "a", 1, "b", 2, "c"])
  53. first = index[3:]
  54. second = index[:5]
  55. result = first.union(klass(second.values))
  56. assert tm.equalContents(result, index)
  57. def test_union_sort_other_incomparable(self):
  58. # https://github.com/pandas-dev/pandas/issues/24959
  59. idx = Index([1, pd.Timestamp("2000")])
  60. # default (sort=None)
  61. with tm.assert_produces_warning(RuntimeWarning):
  62. result = idx.union(idx[:1])
  63. tm.assert_index_equal(result, idx)
  64. # sort=None
  65. with tm.assert_produces_warning(RuntimeWarning):
  66. result = idx.union(idx[:1], sort=None)
  67. tm.assert_index_equal(result, idx)
  68. # sort=False
  69. result = idx.union(idx[:1], sort=False)
  70. tm.assert_index_equal(result, idx)
  71. def test_union_sort_other_incomparable_true(self):
  72. idx = Index([1, pd.Timestamp("2000")])
  73. with pytest.raises(TypeError, match=".*"):
  74. idx.union(idx[:1], sort=True)
  75. def test_intersection_equal_sort_true(self):
  76. idx = Index(["c", "a", "b"])
  77. sorted_ = Index(["a", "b", "c"])
  78. tm.assert_index_equal(idx.intersection(idx, sort=True), sorted_)
  79. def test_intersection_base(self, sort):
  80. # (same results for py2 and py3 but sortedness not tested elsewhere)
  81. index = Index([0, "a", 1, "b", 2, "c"])
  82. first = index[:5]
  83. second = index[:3]
  84. expected = Index([0, 1, "a"]) if sort is None else Index([0, "a", 1])
  85. result = first.intersection(second, sort=sort)
  86. tm.assert_index_equal(result, expected)
  87. @pytest.mark.parametrize("klass", [np.array, Series, list])
  88. def test_intersection_different_type_base(self, klass, sort):
  89. # GH 10149
  90. index = Index([0, "a", 1, "b", 2, "c"])
  91. first = index[:5]
  92. second = index[:3]
  93. result = first.intersection(klass(second.values), sort=sort)
  94. assert tm.equalContents(result, second)
  95. def test_intersection_nosort(self):
  96. result = Index(["c", "b", "a"]).intersection(["b", "a"])
  97. expected = Index(["b", "a"])
  98. tm.assert_index_equal(result, expected)
  99. def test_intersection_equal_sort(self):
  100. idx = Index(["c", "a", "b"])
  101. tm.assert_index_equal(idx.intersection(idx, sort=False), idx)
  102. tm.assert_index_equal(idx.intersection(idx, sort=None), idx)
  103. def test_intersection_str_dates(self, sort):
  104. dt_dates = [datetime(2012, 2, 9), datetime(2012, 2, 22)]
  105. i1 = Index(dt_dates, dtype=object)
  106. i2 = Index(["aa"], dtype=object)
  107. result = i2.intersection(i1, sort=sort)
  108. assert len(result) == 0
  109. @pytest.mark.parametrize(
  110. "index2,expected_arr",
  111. [(Index(["B", "D"]), ["B"]), (Index(["B", "D", "A"]), ["A", "B"])],
  112. )
  113. def test_intersection_non_monotonic_non_unique(self, index2, expected_arr, sort):
  114. # non-monotonic non-unique
  115. index1 = Index(["A", "B", "A", "C"])
  116. expected = Index(expected_arr, dtype="object")
  117. result = index1.intersection(index2, sort=sort)
  118. if sort is None:
  119. expected = expected.sort_values()
  120. tm.assert_index_equal(result, expected)
  121. def test_difference_base(self, sort):
  122. # (same results for py2 and py3 but sortedness not tested elsewhere)
  123. index = Index([0, "a", 1, "b", 2, "c"])
  124. first = index[:4]
  125. second = index[3:]
  126. result = first.difference(second, sort)
  127. expected = Index([0, "a", 1])
  128. if sort is None:
  129. expected = Index(safe_sort(expected))
  130. tm.assert_index_equal(result, expected)
  131. def test_symmetric_difference(self):
  132. # (same results for py2 and py3 but sortedness not tested elsewhere)
  133. index = Index([0, "a", 1, "b", 2, "c"])
  134. first = index[:4]
  135. second = index[3:]
  136. result = first.symmetric_difference(second)
  137. expected = Index([0, 1, 2, "a", "c"])
  138. tm.assert_index_equal(result, expected)
  139. @pytest.mark.parametrize(
  140. "method,expected,sort",
  141. [
  142. (
  143. "intersection",
  144. np.array(
  145. [(1, "A"), (2, "A"), (1, "B"), (2, "B")],
  146. dtype=[("num", int), ("let", "a1")],
  147. ),
  148. False,
  149. ),
  150. (
  151. "intersection",
  152. np.array(
  153. [(1, "A"), (1, "B"), (2, "A"), (2, "B")],
  154. dtype=[("num", int), ("let", "a1")],
  155. ),
  156. None,
  157. ),
  158. (
  159. "union",
  160. np.array(
  161. [(1, "A"), (1, "B"), (1, "C"), (2, "A"), (2, "B"), (2, "C")],
  162. dtype=[("num", int), ("let", "a1")],
  163. ),
  164. None,
  165. ),
  166. ],
  167. )
  168. def test_tuple_union_bug(self, method, expected, sort):
  169. index1 = Index(
  170. np.array(
  171. [(1, "A"), (2, "A"), (1, "B"), (2, "B")],
  172. dtype=[("num", int), ("let", "a1")],
  173. )
  174. )
  175. index2 = Index(
  176. np.array(
  177. [(1, "A"), (2, "A"), (1, "B"), (2, "B"), (1, "C"), (2, "C")],
  178. dtype=[("num", int), ("let", "a1")],
  179. )
  180. )
  181. result = getattr(index1, method)(index2, sort=sort)
  182. assert result.ndim == 1
  183. expected = Index(expected)
  184. tm.assert_index_equal(result, expected)
  185. @pytest.mark.parametrize("first_list", [["b", "a"], []])
  186. @pytest.mark.parametrize("second_list", [["a", "b"], []])
  187. @pytest.mark.parametrize(
  188. "first_name, second_name, expected_name",
  189. [("A", "B", None), (None, "B", None), ("A", None, None)],
  190. )
  191. def test_union_name_preservation(
  192. self, first_list, second_list, first_name, second_name, expected_name, sort
  193. ):
  194. first = Index(first_list, name=first_name)
  195. second = Index(second_list, name=second_name)
  196. union = first.union(second, sort=sort)
  197. vals = set(first_list).union(second_list)
  198. if sort is None and len(first_list) > 0 and len(second_list) > 0:
  199. expected = Index(sorted(vals), name=expected_name)
  200. tm.assert_index_equal(union, expected)
  201. else:
  202. expected = Index(vals, name=expected_name)
  203. tm.equalContents(union, expected)
  204. @pytest.mark.parametrize(
  205. "diff_type, expected",
  206. [["difference", [1, "B"]], ["symmetric_difference", [1, 2, "B", "C"]]],
  207. )
  208. def test_difference_object_type(self, diff_type, expected):
  209. # GH 13432
  210. idx1 = Index([0, 1, "A", "B"])
  211. idx2 = Index([0, 2, "A", "C"])
  212. result = getattr(idx1, diff_type)(idx2)
  213. expected = Index(expected)
  214. tm.assert_index_equal(result, expected)