# test_interval_new.py (7.8 KB)
  1. import re
  2. import numpy as np
  3. import pytest
  4. from pandas.compat import IS64
  5. from pandas import (
  6. Index,
  7. Interval,
  8. IntervalIndex,
  9. Series,
  10. )
  11. import pandas._testing as tm
  12. class TestIntervalIndex:
  13. @pytest.fixture
  14. def series_with_interval_index(self):
  15. return Series(np.arange(5), IntervalIndex.from_breaks(np.arange(6)))
  16. def test_loc_with_interval(self, series_with_interval_index, indexer_sl):
  17. # loc with single label / list of labels:
  18. # - Intervals: only exact matches
  19. # - scalars: those that contain it
  20. ser = series_with_interval_index.copy()
  21. expected = 0
  22. result = indexer_sl(ser)[Interval(0, 1)]
  23. assert result == expected
  24. expected = ser.iloc[3:5]
  25. result = indexer_sl(ser)[[Interval(3, 4), Interval(4, 5)]]
  26. tm.assert_series_equal(expected, result)
  27. # missing or not exact
  28. with pytest.raises(KeyError, match=re.escape("Interval(3, 5, closed='left')")):
  29. indexer_sl(ser)[Interval(3, 5, closed="left")]
  30. with pytest.raises(KeyError, match=re.escape("Interval(3, 5, closed='right')")):
  31. indexer_sl(ser)[Interval(3, 5)]
  32. with pytest.raises(
  33. KeyError, match=re.escape("Interval(-2, 0, closed='right')")
  34. ):
  35. indexer_sl(ser)[Interval(-2, 0)]
  36. with pytest.raises(KeyError, match=re.escape("Interval(5, 6, closed='right')")):
  37. indexer_sl(ser)[Interval(5, 6)]
  38. def test_loc_with_scalar(self, series_with_interval_index, indexer_sl):
  39. # loc with single label / list of labels:
  40. # - Intervals: only exact matches
  41. # - scalars: those that contain it
  42. ser = series_with_interval_index.copy()
  43. assert indexer_sl(ser)[1] == 0
  44. assert indexer_sl(ser)[1.5] == 1
  45. assert indexer_sl(ser)[2] == 1
  46. expected = ser.iloc[1:4]
  47. tm.assert_series_equal(expected, indexer_sl(ser)[[1.5, 2.5, 3.5]])
  48. tm.assert_series_equal(expected, indexer_sl(ser)[[2, 3, 4]])
  49. tm.assert_series_equal(expected, indexer_sl(ser)[[1.5, 3, 4]])
  50. expected = ser.iloc[[1, 1, 2, 1]]
  51. tm.assert_series_equal(expected, indexer_sl(ser)[[1.5, 2, 2.5, 1.5]])
  52. expected = ser.iloc[2:5]
  53. tm.assert_series_equal(expected, indexer_sl(ser)[ser >= 2])
  54. def test_loc_with_slices(self, series_with_interval_index, indexer_sl):
  55. # loc with slices:
  56. # - Interval objects: only works with exact matches
  57. # - scalars: only works for non-overlapping, monotonic intervals,
  58. # and start/stop select location based on the interval that
  59. # contains them:
  60. # (slice_loc(start, stop) == (idx.get_loc(start), idx.get_loc(stop))
  61. ser = series_with_interval_index.copy()
  62. # slice of interval
  63. expected = ser.iloc[:3]
  64. result = indexer_sl(ser)[Interval(0, 1) : Interval(2, 3)]
  65. tm.assert_series_equal(expected, result)
  66. expected = ser.iloc[3:]
  67. result = indexer_sl(ser)[Interval(3, 4) :]
  68. tm.assert_series_equal(expected, result)
  69. msg = "Interval objects are not currently supported"
  70. with pytest.raises(NotImplementedError, match=msg):
  71. indexer_sl(ser)[Interval(3, 6) :]
  72. with pytest.raises(NotImplementedError, match=msg):
  73. indexer_sl(ser)[Interval(3, 4, closed="left") :]
  74. def test_slice_step_ne1(self, series_with_interval_index):
  75. # GH#31658 slice of scalar with step != 1
  76. ser = series_with_interval_index.copy()
  77. expected = ser.iloc[0:4:2]
  78. result = ser[0:4:2]
  79. tm.assert_series_equal(result, expected)
  80. result2 = ser[0:4][::2]
  81. tm.assert_series_equal(result2, expected)
  82. def test_slice_float_start_stop(self, series_with_interval_index):
  83. # GH#31658 slicing with integers is positional, with floats is not
  84. # supported
  85. ser = series_with_interval_index.copy()
  86. msg = "label-based slicing with step!=1 is not supported for IntervalIndex"
  87. with pytest.raises(ValueError, match=msg):
  88. ser[1.5:9.5:2]
  89. def test_slice_interval_step(self, series_with_interval_index):
  90. # GH#31658 allows for integer step!=1, not Interval step
  91. ser = series_with_interval_index.copy()
  92. msg = "label-based slicing with step!=1 is not supported for IntervalIndex"
  93. with pytest.raises(ValueError, match=msg):
  94. ser[0 : 4 : Interval(0, 1)]
  95. def test_loc_with_overlap(self, indexer_sl):
  96. idx = IntervalIndex.from_tuples([(1, 5), (3, 7)])
  97. ser = Series(range(len(idx)), index=idx)
  98. # scalar
  99. expected = ser
  100. result = indexer_sl(ser)[4]
  101. tm.assert_series_equal(expected, result)
  102. result = indexer_sl(ser)[[4]]
  103. tm.assert_series_equal(expected, result)
  104. # interval
  105. expected = 0
  106. result = indexer_sl(ser)[Interval(1, 5)]
  107. result == expected
  108. expected = ser
  109. result = indexer_sl(ser)[[Interval(1, 5), Interval(3, 7)]]
  110. tm.assert_series_equal(expected, result)
  111. with pytest.raises(KeyError, match=re.escape("Interval(3, 5, closed='right')")):
  112. indexer_sl(ser)[Interval(3, 5)]
  113. msg = r"None of \[\[Interval\(3, 5, closed='right'\)\]\]"
  114. with pytest.raises(KeyError, match=msg):
  115. indexer_sl(ser)[[Interval(3, 5)]]
  116. # slices with interval (only exact matches)
  117. expected = ser
  118. result = indexer_sl(ser)[Interval(1, 5) : Interval(3, 7)]
  119. tm.assert_series_equal(expected, result)
  120. msg = (
  121. "'can only get slices from an IntervalIndex if bounds are "
  122. "non-overlapping and all monotonic increasing or decreasing'"
  123. )
  124. with pytest.raises(KeyError, match=msg):
  125. indexer_sl(ser)[Interval(1, 6) : Interval(3, 8)]
  126. if indexer_sl is tm.loc:
  127. # slices with scalar raise for overlapping intervals
  128. # TODO KeyError is the appropriate error?
  129. with pytest.raises(KeyError, match=msg):
  130. ser.loc[1:4]
  131. def test_non_unique(self, indexer_sl):
  132. idx = IntervalIndex.from_tuples([(1, 3), (3, 7)])
  133. ser = Series(range(len(idx)), index=idx)
  134. result = indexer_sl(ser)[Interval(1, 3)]
  135. assert result == 0
  136. result = indexer_sl(ser)[[Interval(1, 3)]]
  137. expected = ser.iloc[0:1]
  138. tm.assert_series_equal(expected, result)
  139. def test_non_unique_moar(self, indexer_sl):
  140. idx = IntervalIndex.from_tuples([(1, 3), (1, 3), (3, 7)])
  141. ser = Series(range(len(idx)), index=idx)
  142. expected = ser.iloc[[0, 1]]
  143. result = indexer_sl(ser)[Interval(1, 3)]
  144. tm.assert_series_equal(expected, result)
  145. expected = ser
  146. result = indexer_sl(ser)[Interval(1, 3) :]
  147. tm.assert_series_equal(expected, result)
  148. expected = ser.iloc[[0, 1]]
  149. result = indexer_sl(ser)[[Interval(1, 3)]]
  150. tm.assert_series_equal(expected, result)
  151. def test_loc_getitem_missing_key_error_message(
  152. self, frame_or_series, series_with_interval_index
  153. ):
  154. # GH#27365
  155. ser = series_with_interval_index.copy()
  156. obj = frame_or_series(ser)
  157. with pytest.raises(KeyError, match=r"\[6\]"):
  158. obj.loc[[4, 5, 6]]
  159. @pytest.mark.xfail(not IS64, reason="GH 23440")
  160. @pytest.mark.parametrize(
  161. "intervals",
  162. [
  163. ([Interval(-np.inf, 0.0), Interval(0.0, 1.0)]),
  164. ([Interval(-np.inf, -2.0), Interval(-2.0, -1.0)]),
  165. ([Interval(-1.0, 0.0), Interval(0.0, np.inf)]),
  166. ([Interval(1.0, 2.0), Interval(2.0, np.inf)]),
  167. ],
  168. )
  169. def test_repeating_interval_index_with_infs(intervals):
  170. # GH 46658
  171. interval_index = Index(intervals * 51)
  172. expected = np.arange(1, 102, 2, dtype=np.intp)
  173. result = interval_index.get_indexer_for([intervals[1]])
  174. tm.assert_equal(result, expected)