test_setops.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493
  1. from datetime import (
  2. datetime,
  3. timedelta,
  4. )
  5. from hypothesis import (
  6. assume,
  7. given,
  8. strategies as st,
  9. )
  10. import numpy as np
  11. import pytest
  12. from pandas import (
  13. Index,
  14. RangeIndex,
  15. )
  16. import pandas._testing as tm
  17. class TestRangeIndexSetOps:
  18. @pytest.mark.parametrize("dtype", [None, "int64", "uint64"])
  19. def test_intersection_mismatched_dtype(self, dtype):
  20. # check that we cast to float, not object
  21. index = RangeIndex(start=0, stop=20, step=2, name="foo")
  22. index = Index(index, dtype=dtype)
  23. flt = index.astype(np.float64)
  24. # bc index.equals(flt), we go through fastpath and get RangeIndex back
  25. result = index.intersection(flt)
  26. tm.assert_index_equal(result, index, exact=True)
  27. result = flt.intersection(index)
  28. tm.assert_index_equal(result, flt, exact=True)
  29. # neither empty, not-equals
  30. result = index.intersection(flt[1:])
  31. tm.assert_index_equal(result, flt[1:], exact=True)
  32. result = flt[1:].intersection(index)
  33. tm.assert_index_equal(result, flt[1:], exact=True)
  34. # empty other
  35. result = index.intersection(flt[:0])
  36. tm.assert_index_equal(result, flt[:0], exact=True)
  37. result = flt[:0].intersection(index)
  38. tm.assert_index_equal(result, flt[:0], exact=True)
  39. def test_intersection_empty(self, sort, names):
  40. # name retention on empty intersections
  41. index = RangeIndex(start=0, stop=20, step=2, name=names[0])
  42. # empty other
  43. result = index.intersection(index[:0].rename(names[1]), sort=sort)
  44. tm.assert_index_equal(result, index[:0].rename(names[2]), exact=True)
  45. # empty self
  46. result = index[:0].intersection(index.rename(names[1]), sort=sort)
  47. tm.assert_index_equal(result, index[:0].rename(names[2]), exact=True)
  48. def test_intersection(self, sort):
  49. # intersect with Index with dtype int64
  50. index = RangeIndex(start=0, stop=20, step=2)
  51. other = Index(np.arange(1, 6))
  52. result = index.intersection(other, sort=sort)
  53. expected = Index(np.sort(np.intersect1d(index.values, other.values)))
  54. tm.assert_index_equal(result, expected)
  55. result = other.intersection(index, sort=sort)
  56. expected = Index(
  57. np.sort(np.asarray(np.intersect1d(index.values, other.values)))
  58. )
  59. tm.assert_index_equal(result, expected)
  60. # intersect with increasing RangeIndex
  61. other = RangeIndex(1, 6)
  62. result = index.intersection(other, sort=sort)
  63. expected = Index(np.sort(np.intersect1d(index.values, other.values)))
  64. tm.assert_index_equal(result, expected, exact="equiv")
  65. # intersect with decreasing RangeIndex
  66. other = RangeIndex(5, 0, -1)
  67. result = index.intersection(other, sort=sort)
  68. expected = Index(np.sort(np.intersect1d(index.values, other.values)))
  69. tm.assert_index_equal(result, expected, exact="equiv")
  70. # reversed (GH 17296)
  71. result = other.intersection(index, sort=sort)
  72. tm.assert_index_equal(result, expected, exact="equiv")
  73. # GH 17296: intersect two decreasing RangeIndexes
  74. first = RangeIndex(10, -2, -2)
  75. other = RangeIndex(5, -4, -1)
  76. expected = first.astype(int).intersection(other.astype(int), sort=sort)
  77. result = first.intersection(other, sort=sort).astype(int)
  78. tm.assert_index_equal(result, expected)
  79. # reversed
  80. result = other.intersection(first, sort=sort).astype(int)
  81. tm.assert_index_equal(result, expected)
  82. index = RangeIndex(5, name="foo")
  83. # intersect of non-overlapping indices
  84. other = RangeIndex(5, 10, 1, name="foo")
  85. result = index.intersection(other, sort=sort)
  86. expected = RangeIndex(0, 0, 1, name="foo")
  87. tm.assert_index_equal(result, expected)
  88. other = RangeIndex(-1, -5, -1)
  89. result = index.intersection(other, sort=sort)
  90. expected = RangeIndex(0, 0, 1)
  91. tm.assert_index_equal(result, expected)
  92. # intersection of empty indices
  93. other = RangeIndex(0, 0, 1)
  94. result = index.intersection(other, sort=sort)
  95. expected = RangeIndex(0, 0, 1)
  96. tm.assert_index_equal(result, expected)
  97. result = other.intersection(index, sort=sort)
  98. tm.assert_index_equal(result, expected)
  99. def test_intersection_non_overlapping_gcd(self, sort, names):
  100. # intersection of non-overlapping values based on start value and gcd
  101. index = RangeIndex(1, 10, 2, name=names[0])
  102. other = RangeIndex(0, 10, 4, name=names[1])
  103. result = index.intersection(other, sort=sort)
  104. expected = RangeIndex(0, 0, 1, name=names[2])
  105. tm.assert_index_equal(result, expected)
  106. def test_union_noncomparable(self, sort):
  107. # corner case, Index with non-int64 dtype
  108. index = RangeIndex(start=0, stop=20, step=2)
  109. other = Index([datetime.now() + timedelta(i) for i in range(4)], dtype=object)
  110. result = index.union(other, sort=sort)
  111. expected = Index(np.concatenate((index, other)))
  112. tm.assert_index_equal(result, expected)
  113. result = other.union(index, sort=sort)
  114. expected = Index(np.concatenate((other, index)))
  115. tm.assert_index_equal(result, expected)
  116. @pytest.mark.parametrize(
  117. "idx1, idx2, expected_sorted, expected_notsorted",
  118. [
  119. (
  120. RangeIndex(0, 10, 1),
  121. RangeIndex(0, 10, 1),
  122. RangeIndex(0, 10, 1),
  123. RangeIndex(0, 10, 1),
  124. ),
  125. (
  126. RangeIndex(0, 10, 1),
  127. RangeIndex(5, 20, 1),
  128. RangeIndex(0, 20, 1),
  129. RangeIndex(0, 20, 1),
  130. ),
  131. (
  132. RangeIndex(0, 10, 1),
  133. RangeIndex(10, 20, 1),
  134. RangeIndex(0, 20, 1),
  135. RangeIndex(0, 20, 1),
  136. ),
  137. (
  138. RangeIndex(0, -10, -1),
  139. RangeIndex(0, -10, -1),
  140. RangeIndex(0, -10, -1),
  141. RangeIndex(0, -10, -1),
  142. ),
  143. (
  144. RangeIndex(0, -10, -1),
  145. RangeIndex(-10, -20, -1),
  146. RangeIndex(-19, 1, 1),
  147. RangeIndex(0, -20, -1),
  148. ),
  149. (
  150. RangeIndex(0, 10, 2),
  151. RangeIndex(1, 10, 2),
  152. RangeIndex(0, 10, 1),
  153. Index(list(range(0, 10, 2)) + list(range(1, 10, 2))),
  154. ),
  155. (
  156. RangeIndex(0, 11, 2),
  157. RangeIndex(1, 12, 2),
  158. RangeIndex(0, 12, 1),
  159. Index(list(range(0, 11, 2)) + list(range(1, 12, 2))),
  160. ),
  161. (
  162. RangeIndex(0, 21, 4),
  163. RangeIndex(-2, 24, 4),
  164. RangeIndex(-2, 24, 2),
  165. Index(list(range(0, 21, 4)) + list(range(-2, 24, 4))),
  166. ),
  167. (
  168. RangeIndex(0, -20, -2),
  169. RangeIndex(-1, -21, -2),
  170. RangeIndex(-19, 1, 1),
  171. Index(list(range(0, -20, -2)) + list(range(-1, -21, -2))),
  172. ),
  173. (
  174. RangeIndex(0, 100, 5),
  175. RangeIndex(0, 100, 20),
  176. RangeIndex(0, 100, 5),
  177. RangeIndex(0, 100, 5),
  178. ),
  179. (
  180. RangeIndex(0, -100, -5),
  181. RangeIndex(5, -100, -20),
  182. RangeIndex(-95, 10, 5),
  183. Index(list(range(0, -100, -5)) + [5]),
  184. ),
  185. (
  186. RangeIndex(0, -11, -1),
  187. RangeIndex(1, -12, -4),
  188. RangeIndex(-11, 2, 1),
  189. Index(list(range(0, -11, -1)) + [1, -11]),
  190. ),
  191. (RangeIndex(0), RangeIndex(0), RangeIndex(0), RangeIndex(0)),
  192. (
  193. RangeIndex(0, -10, -2),
  194. RangeIndex(0),
  195. RangeIndex(0, -10, -2),
  196. RangeIndex(0, -10, -2),
  197. ),
  198. (
  199. RangeIndex(0, 100, 2),
  200. RangeIndex(100, 150, 200),
  201. RangeIndex(0, 102, 2),
  202. RangeIndex(0, 102, 2),
  203. ),
  204. (
  205. RangeIndex(0, -100, -2),
  206. RangeIndex(-100, 50, 102),
  207. RangeIndex(-100, 4, 2),
  208. Index(list(range(0, -100, -2)) + [-100, 2]),
  209. ),
  210. (
  211. RangeIndex(0, -100, -1),
  212. RangeIndex(0, -50, -3),
  213. RangeIndex(-99, 1, 1),
  214. RangeIndex(0, -100, -1),
  215. ),
  216. (
  217. RangeIndex(0, 1, 1),
  218. RangeIndex(5, 6, 10),
  219. RangeIndex(0, 6, 5),
  220. RangeIndex(0, 10, 5),
  221. ),
  222. (
  223. RangeIndex(0, 10, 5),
  224. RangeIndex(-5, -6, -20),
  225. RangeIndex(-5, 10, 5),
  226. Index([0, 5, -5]),
  227. ),
  228. (
  229. RangeIndex(0, 3, 1),
  230. RangeIndex(4, 5, 1),
  231. Index([0, 1, 2, 4]),
  232. Index([0, 1, 2, 4]),
  233. ),
  234. (
  235. RangeIndex(0, 10, 1),
  236. Index([], dtype=np.int64),
  237. RangeIndex(0, 10, 1),
  238. RangeIndex(0, 10, 1),
  239. ),
  240. (
  241. RangeIndex(0),
  242. Index([1, 5, 6]),
  243. Index([1, 5, 6]),
  244. Index([1, 5, 6]),
  245. ),
  246. # GH 43885
  247. (
  248. RangeIndex(0, 10),
  249. RangeIndex(0, 5),
  250. RangeIndex(0, 10),
  251. RangeIndex(0, 10),
  252. ),
  253. ],
  254. ids=lambda x: repr(x) if isinstance(x, RangeIndex) else x,
  255. )
  256. def test_union_sorted(self, idx1, idx2, expected_sorted, expected_notsorted):
  257. res1 = idx1.union(idx2, sort=None)
  258. tm.assert_index_equal(res1, expected_sorted, exact=True)
  259. res1 = idx1.union(idx2, sort=False)
  260. tm.assert_index_equal(res1, expected_notsorted, exact=True)
  261. res2 = idx2.union(idx1, sort=None)
  262. res3 = Index(idx1._values, name=idx1.name).union(idx2, sort=None)
  263. tm.assert_index_equal(res2, expected_sorted, exact=True)
  264. tm.assert_index_equal(res3, expected_sorted, exact="equiv")
  265. def test_union_same_step_misaligned(self):
  266. # GH#44019
  267. left = RangeIndex(range(0, 20, 4))
  268. right = RangeIndex(range(1, 21, 4))
  269. result = left.union(right)
  270. expected = Index([0, 1, 4, 5, 8, 9, 12, 13, 16, 17])
  271. tm.assert_index_equal(result, expected, exact=True)
  272. def test_difference(self):
  273. # GH#12034 Cases where we operate against another RangeIndex and may
  274. # get back another RangeIndex
  275. obj = RangeIndex.from_range(range(1, 10), name="foo")
  276. result = obj.difference(obj)
  277. expected = RangeIndex.from_range(range(0), name="foo")
  278. tm.assert_index_equal(result, expected, exact=True)
  279. result = obj.difference(expected.rename("bar"))
  280. tm.assert_index_equal(result, obj.rename(None), exact=True)
  281. result = obj.difference(obj[:3])
  282. tm.assert_index_equal(result, obj[3:], exact=True)
  283. result = obj.difference(obj[-3:])
  284. tm.assert_index_equal(result, obj[:-3], exact=True)
  285. # Flipping the step of 'other' doesn't affect the result, but
  286. # flipping the stepof 'self' does when sort=None
  287. result = obj[::-1].difference(obj[-3:])
  288. tm.assert_index_equal(result, obj[:-3], exact=True)
  289. result = obj[::-1].difference(obj[-3:], sort=False)
  290. tm.assert_index_equal(result, obj[:-3][::-1], exact=True)
  291. result = obj[::-1].difference(obj[-3:][::-1])
  292. tm.assert_index_equal(result, obj[:-3], exact=True)
  293. result = obj[::-1].difference(obj[-3:][::-1], sort=False)
  294. tm.assert_index_equal(result, obj[:-3][::-1], exact=True)
  295. result = obj.difference(obj[2:6])
  296. expected = Index([1, 2, 7, 8, 9], name="foo")
  297. tm.assert_index_equal(result, expected, exact=True)
  298. def test_difference_sort(self):
  299. # GH#44085 ensure we respect the sort keyword
  300. idx = Index(range(4))[::-1]
  301. other = Index(range(3, 4))
  302. result = idx.difference(other)
  303. expected = Index(range(3))
  304. tm.assert_index_equal(result, expected, exact=True)
  305. result = idx.difference(other, sort=False)
  306. expected = expected[::-1]
  307. tm.assert_index_equal(result, expected, exact=True)
  308. # case where the intersection is empty
  309. other = range(10, 12)
  310. result = idx.difference(other, sort=None)
  311. expected = idx[::-1]
  312. tm.assert_index_equal(result, expected, exact=True)
  313. def test_difference_mismatched_step(self):
  314. obj = RangeIndex.from_range(range(1, 10), name="foo")
  315. result = obj.difference(obj[::2])
  316. expected = obj[1::2]
  317. tm.assert_index_equal(result, expected, exact=True)
  318. result = obj[::-1].difference(obj[::2], sort=False)
  319. tm.assert_index_equal(result, expected[::-1], exact=True)
  320. result = obj.difference(obj[1::2])
  321. expected = obj[::2]
  322. tm.assert_index_equal(result, expected, exact=True)
  323. result = obj[::-1].difference(obj[1::2], sort=False)
  324. tm.assert_index_equal(result, expected[::-1], exact=True)
  325. def test_difference_interior_overlap_endpoints_preserved(self):
  326. left = RangeIndex(range(4))
  327. right = RangeIndex(range(1, 3))
  328. result = left.difference(right)
  329. expected = RangeIndex(0, 4, 3)
  330. assert expected.tolist() == [0, 3]
  331. tm.assert_index_equal(result, expected, exact=True)
  332. def test_difference_endpoints_overlap_interior_preserved(self):
  333. left = RangeIndex(-8, 20, 7)
  334. right = RangeIndex(13, -9, -3)
  335. result = left.difference(right)
  336. expected = RangeIndex(-1, 13, 7)
  337. assert expected.tolist() == [-1, 6]
  338. tm.assert_index_equal(result, expected, exact=True)
  339. def test_difference_interior_non_preserving(self):
  340. # case with intersection of length 1 but RangeIndex is not preserved
  341. idx = Index(range(10))
  342. other = idx[3:4]
  343. result = idx.difference(other)
  344. expected = Index([0, 1, 2, 4, 5, 6, 7, 8, 9])
  345. tm.assert_index_equal(result, expected, exact=True)
  346. # case with other.step / self.step > 2
  347. other = idx[::3]
  348. result = idx.difference(other)
  349. expected = Index([1, 2, 4, 5, 7, 8])
  350. tm.assert_index_equal(result, expected, exact=True)
  351. # cases with only reaching one end of left
  352. obj = Index(range(20))
  353. other = obj[:10:2]
  354. result = obj.difference(other)
  355. expected = Index([1, 3, 5, 7, 9] + list(range(10, 20)))
  356. tm.assert_index_equal(result, expected, exact=True)
  357. other = obj[1:11:2]
  358. result = obj.difference(other)
  359. expected = Index([0, 2, 4, 6, 8, 10] + list(range(11, 20)))
  360. tm.assert_index_equal(result, expected, exact=True)
  361. def test_symmetric_difference(self):
  362. # GH#12034 Cases where we operate against another RangeIndex and may
  363. # get back another RangeIndex
  364. left = RangeIndex.from_range(range(1, 10), name="foo")
  365. result = left.symmetric_difference(left)
  366. expected = RangeIndex.from_range(range(0), name="foo")
  367. tm.assert_index_equal(result, expected)
  368. result = left.symmetric_difference(expected.rename("bar"))
  369. tm.assert_index_equal(result, left.rename(None))
  370. result = left[:-2].symmetric_difference(left[2:])
  371. expected = Index([1, 2, 8, 9], name="foo")
  372. tm.assert_index_equal(result, expected, exact=True)
  373. right = RangeIndex.from_range(range(10, 15))
  374. result = left.symmetric_difference(right)
  375. expected = RangeIndex.from_range(range(1, 15))
  376. tm.assert_index_equal(result, expected)
  377. result = left.symmetric_difference(right[1:])
  378. expected = Index([1, 2, 3, 4, 5, 6, 7, 8, 9, 11, 12, 13, 14])
  379. tm.assert_index_equal(result, expected, exact=True)
  380. def assert_range_or_not_is_rangelike(index):
  381. """
  382. Check that we either have a RangeIndex or that this index *cannot*
  383. be represented as a RangeIndex.
  384. """
  385. if not isinstance(index, RangeIndex) and len(index) > 0:
  386. diff = index[:-1] - index[1:]
  387. assert not (diff == diff[0]).all()
  388. @given(
  389. st.integers(-20, 20),
  390. st.integers(-20, 20),
  391. st.integers(-20, 20),
  392. st.integers(-20, 20),
  393. st.integers(-20, 20),
  394. st.integers(-20, 20),
  395. )
  396. def test_range_difference(start1, stop1, step1, start2, stop2, step2):
  397. # test that
  398. # a) we match Index[int64].difference and
  399. # b) we return RangeIndex whenever it is possible to do so.
  400. assume(step1 != 0)
  401. assume(step2 != 0)
  402. left = RangeIndex(start1, stop1, step1)
  403. right = RangeIndex(start2, stop2, step2)
  404. result = left.difference(right, sort=None)
  405. assert_range_or_not_is_rangelike(result)
  406. left_int64 = Index(left.to_numpy())
  407. right_int64 = Index(right.to_numpy())
  408. alt = left_int64.difference(right_int64, sort=None)
  409. tm.assert_index_equal(result, alt, exact="equiv")
  410. result = left.difference(right, sort=False)
  411. assert_range_or_not_is_rangelike(result)
  412. alt = left_int64.difference(right_int64, sort=False)
  413. tm.assert_index_equal(result, alt, exact="equiv")