test_setops.py 7.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202
  1. import numpy as np
  2. import pytest
  3. from pandas import (
  4. Index,
  5. IntervalIndex,
  6. Timestamp,
  7. interval_range,
  8. )
  9. import pandas._testing as tm
  10. def monotonic_index(start, end, dtype="int64", closed="right"):
  11. return IntervalIndex.from_breaks(np.arange(start, end, dtype=dtype), closed=closed)
  12. def empty_index(dtype="int64", closed="right"):
  13. return IntervalIndex(np.array([], dtype=dtype), closed=closed)
  14. class TestIntervalIndex:
  15. def test_union(self, closed, sort):
  16. index = monotonic_index(0, 11, closed=closed)
  17. other = monotonic_index(5, 13, closed=closed)
  18. expected = monotonic_index(0, 13, closed=closed)
  19. result = index[::-1].union(other, sort=sort)
  20. if sort is None:
  21. tm.assert_index_equal(result, expected)
  22. assert tm.equalContents(result, expected)
  23. result = other[::-1].union(index, sort=sort)
  24. if sort is None:
  25. tm.assert_index_equal(result, expected)
  26. assert tm.equalContents(result, expected)
  27. tm.assert_index_equal(index.union(index, sort=sort), index)
  28. tm.assert_index_equal(index.union(index[:1], sort=sort), index)
  29. def test_union_empty_result(self, closed, sort):
  30. # GH 19101: empty result, same dtype
  31. index = empty_index(dtype="int64", closed=closed)
  32. result = index.union(index, sort=sort)
  33. tm.assert_index_equal(result, index)
  34. # GH 19101: empty result, different numeric dtypes -> common dtype is f8
  35. other = empty_index(dtype="float64", closed=closed)
  36. result = index.union(other, sort=sort)
  37. expected = other
  38. tm.assert_index_equal(result, expected)
  39. other = index.union(index, sort=sort)
  40. tm.assert_index_equal(result, expected)
  41. other = empty_index(dtype="uint64", closed=closed)
  42. result = index.union(other, sort=sort)
  43. tm.assert_index_equal(result, expected)
  44. result = other.union(index, sort=sort)
  45. tm.assert_index_equal(result, expected)
  46. def test_intersection(self, closed, sort):
  47. index = monotonic_index(0, 11, closed=closed)
  48. other = monotonic_index(5, 13, closed=closed)
  49. expected = monotonic_index(5, 11, closed=closed)
  50. result = index[::-1].intersection(other, sort=sort)
  51. if sort is None:
  52. tm.assert_index_equal(result, expected)
  53. assert tm.equalContents(result, expected)
  54. result = other[::-1].intersection(index, sort=sort)
  55. if sort is None:
  56. tm.assert_index_equal(result, expected)
  57. assert tm.equalContents(result, expected)
  58. tm.assert_index_equal(index.intersection(index, sort=sort), index)
  59. # GH 26225: nested intervals
  60. index = IntervalIndex.from_tuples([(1, 2), (1, 3), (1, 4), (0, 2)])
  61. other = IntervalIndex.from_tuples([(1, 2), (1, 3)])
  62. expected = IntervalIndex.from_tuples([(1, 2), (1, 3)])
  63. result = index.intersection(other)
  64. tm.assert_index_equal(result, expected)
  65. # GH 26225
  66. index = IntervalIndex.from_tuples([(0, 3), (0, 2)])
  67. other = IntervalIndex.from_tuples([(0, 2), (1, 3)])
  68. expected = IntervalIndex.from_tuples([(0, 2)])
  69. result = index.intersection(other)
  70. tm.assert_index_equal(result, expected)
  71. # GH 26225: duplicate nan element
  72. index = IntervalIndex([np.nan, np.nan])
  73. other = IntervalIndex([np.nan])
  74. expected = IntervalIndex([np.nan])
  75. result = index.intersection(other)
  76. tm.assert_index_equal(result, expected)
  77. def test_intersection_empty_result(self, closed, sort):
  78. index = monotonic_index(0, 11, closed=closed)
  79. # GH 19101: empty result, same dtype
  80. other = monotonic_index(300, 314, closed=closed)
  81. expected = empty_index(dtype="int64", closed=closed)
  82. result = index.intersection(other, sort=sort)
  83. tm.assert_index_equal(result, expected)
  84. # GH 19101: empty result, different numeric dtypes -> common dtype is float64
  85. other = monotonic_index(300, 314, dtype="float64", closed=closed)
  86. result = index.intersection(other, sort=sort)
  87. expected = other[:0]
  88. tm.assert_index_equal(result, expected)
  89. other = monotonic_index(300, 314, dtype="uint64", closed=closed)
  90. result = index.intersection(other, sort=sort)
  91. tm.assert_index_equal(result, expected)
  92. def test_intersection_duplicates(self):
  93. # GH#38743
  94. index = IntervalIndex.from_tuples([(1, 2), (1, 2), (2, 3), (3, 4)])
  95. other = IntervalIndex.from_tuples([(1, 2), (2, 3)])
  96. expected = IntervalIndex.from_tuples([(1, 2), (2, 3)])
  97. result = index.intersection(other)
  98. tm.assert_index_equal(result, expected)
  99. def test_difference(self, closed, sort):
  100. index = IntervalIndex.from_arrays([1, 0, 3, 2], [1, 2, 3, 4], closed=closed)
  101. result = index.difference(index[:1], sort=sort)
  102. expected = index[1:]
  103. if sort is None:
  104. expected = expected.sort_values()
  105. tm.assert_index_equal(result, expected)
  106. # GH 19101: empty result, same dtype
  107. result = index.difference(index, sort=sort)
  108. expected = empty_index(dtype="int64", closed=closed)
  109. tm.assert_index_equal(result, expected)
  110. # GH 19101: empty result, different dtypes
  111. other = IntervalIndex.from_arrays(
  112. index.left.astype("float64"), index.right, closed=closed
  113. )
  114. result = index.difference(other, sort=sort)
  115. tm.assert_index_equal(result, expected)
  116. def test_symmetric_difference(self, closed, sort):
  117. index = monotonic_index(0, 11, closed=closed)
  118. result = index[1:].symmetric_difference(index[:-1], sort=sort)
  119. expected = IntervalIndex([index[0], index[-1]])
  120. if sort is None:
  121. tm.assert_index_equal(result, expected)
  122. assert tm.equalContents(result, expected)
  123. # GH 19101: empty result, same dtype
  124. result = index.symmetric_difference(index, sort=sort)
  125. expected = empty_index(dtype="int64", closed=closed)
  126. if sort is None:
  127. tm.assert_index_equal(result, expected)
  128. assert tm.equalContents(result, expected)
  129. # GH 19101: empty result, different dtypes
  130. other = IntervalIndex.from_arrays(
  131. index.left.astype("float64"), index.right, closed=closed
  132. )
  133. result = index.symmetric_difference(other, sort=sort)
  134. expected = empty_index(dtype="float64", closed=closed)
  135. tm.assert_index_equal(result, expected)
  136. @pytest.mark.filterwarnings("ignore:'<' not supported between:RuntimeWarning")
  137. @pytest.mark.parametrize(
  138. "op_name", ["union", "intersection", "difference", "symmetric_difference"]
  139. )
  140. def test_set_incompatible_types(self, closed, op_name, sort):
  141. index = monotonic_index(0, 11, closed=closed)
  142. set_op = getattr(index, op_name)
  143. # TODO: standardize return type of non-union setops type(self vs other)
  144. # non-IntervalIndex
  145. if op_name == "difference":
  146. expected = index
  147. else:
  148. expected = getattr(index.astype("O"), op_name)(Index([1, 2, 3]))
  149. result = set_op(Index([1, 2, 3]), sort=sort)
  150. tm.assert_index_equal(result, expected)
  151. # mixed closed -> cast to object
  152. for other_closed in {"right", "left", "both", "neither"} - {closed}:
  153. other = monotonic_index(0, 11, closed=other_closed)
  154. expected = getattr(index.astype(object), op_name)(other, sort=sort)
  155. if op_name == "difference":
  156. expected = index
  157. result = set_op(other, sort=sort)
  158. tm.assert_index_equal(result, expected)
  159. # GH 19016: incompatible dtypes -> cast to object
  160. other = interval_range(Timestamp("20180101"), periods=9, closed=closed)
  161. expected = getattr(index.astype(object), op_name)(other, sort=sort)
  162. if op_name == "difference":
  163. expected = index
  164. result = set_op(other, sort=sort)
  165. tm.assert_index_equal(result, expected)