test_interval.py 5.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174
  1. import numpy as np
  2. import pytest
  3. import pandas as pd
  4. from pandas import (
  5. DataFrame,
  6. IntervalIndex,
  7. Series,
  8. )
  9. import pandas._testing as tm
  10. class TestIntervalIndex:
  11. @pytest.fixture
  12. def series_with_interval_index(self):
  13. return Series(np.arange(5), IntervalIndex.from_breaks(np.arange(6)))
  14. def test_getitem_with_scalar(self, series_with_interval_index, indexer_sl):
  15. ser = series_with_interval_index.copy()
  16. expected = ser.iloc[:3]
  17. tm.assert_series_equal(expected, indexer_sl(ser)[:3])
  18. tm.assert_series_equal(expected, indexer_sl(ser)[:2.5])
  19. tm.assert_series_equal(expected, indexer_sl(ser)[0.1:2.5])
  20. if indexer_sl is tm.loc:
  21. tm.assert_series_equal(expected, ser.loc[-1:3])
  22. expected = ser.iloc[1:4]
  23. tm.assert_series_equal(expected, indexer_sl(ser)[[1.5, 2.5, 3.5]])
  24. tm.assert_series_equal(expected, indexer_sl(ser)[[2, 3, 4]])
  25. tm.assert_series_equal(expected, indexer_sl(ser)[[1.5, 3, 4]])
  26. expected = ser.iloc[2:5]
  27. tm.assert_series_equal(expected, indexer_sl(ser)[ser >= 2])
  28. @pytest.mark.parametrize("direction", ["increasing", "decreasing"])
  29. def test_getitem_nonoverlapping_monotonic(self, direction, closed, indexer_sl):
  30. tpls = [(0, 1), (2, 3), (4, 5)]
  31. if direction == "decreasing":
  32. tpls = tpls[::-1]
  33. idx = IntervalIndex.from_tuples(tpls, closed=closed)
  34. ser = Series(list("abc"), idx)
  35. for key, expected in zip(idx.left, ser):
  36. if idx.closed_left:
  37. assert indexer_sl(ser)[key] == expected
  38. else:
  39. with pytest.raises(KeyError, match=str(key)):
  40. indexer_sl(ser)[key]
  41. for key, expected in zip(idx.right, ser):
  42. if idx.closed_right:
  43. assert indexer_sl(ser)[key] == expected
  44. else:
  45. with pytest.raises(KeyError, match=str(key)):
  46. indexer_sl(ser)[key]
  47. for key, expected in zip(idx.mid, ser):
  48. assert indexer_sl(ser)[key] == expected
  49. def test_getitem_non_matching(self, series_with_interval_index, indexer_sl):
  50. ser = series_with_interval_index.copy()
  51. # this is a departure from our current
  52. # indexing scheme, but simpler
  53. with pytest.raises(KeyError, match=r"\[-1\] not in index"):
  54. indexer_sl(ser)[[-1, 3, 4, 5]]
  55. with pytest.raises(KeyError, match=r"\[-1\] not in index"):
  56. indexer_sl(ser)[[-1, 3]]
  57. @pytest.mark.slow
  58. def test_loc_getitem_large_series(self):
  59. ser = Series(
  60. np.arange(1000000), index=IntervalIndex.from_breaks(np.arange(1000001))
  61. )
  62. result1 = ser.loc[:80000]
  63. result2 = ser.loc[0:80000]
  64. result3 = ser.loc[0:80000:1]
  65. tm.assert_series_equal(result1, result2)
  66. tm.assert_series_equal(result1, result3)
  67. def test_loc_getitem_frame(self):
  68. # CategoricalIndex with IntervalIndex categories
  69. df = DataFrame({"A": range(10)})
  70. ser = pd.cut(df.A, 5)
  71. df["B"] = ser
  72. df = df.set_index("B")
  73. result = df.loc[4]
  74. expected = df.iloc[4:6]
  75. tm.assert_frame_equal(result, expected)
  76. with pytest.raises(KeyError, match="10"):
  77. df.loc[10]
  78. # single list-like
  79. result = df.loc[[4]]
  80. expected = df.iloc[4:6]
  81. tm.assert_frame_equal(result, expected)
  82. # non-unique
  83. result = df.loc[[4, 5]]
  84. expected = df.take([4, 5, 4, 5])
  85. tm.assert_frame_equal(result, expected)
  86. with pytest.raises(KeyError, match=r"None of \[\[10\]\] are"):
  87. df.loc[[10]]
  88. # partial missing
  89. with pytest.raises(KeyError, match=r"\[10\] not in index"):
  90. df.loc[[10, 4]]
  91. def test_getitem_interval_with_nans(self, frame_or_series, indexer_sl):
  92. # GH#41831
  93. index = IntervalIndex([np.nan, np.nan])
  94. key = index[:-1]
  95. obj = frame_or_series(range(2), index=index)
  96. if frame_or_series is DataFrame and indexer_sl is tm.setitem:
  97. obj = obj.T
  98. result = indexer_sl(obj)[key]
  99. expected = obj
  100. tm.assert_equal(result, expected)
  101. class TestIntervalIndexInsideMultiIndex:
  102. def test_mi_intervalindex_slicing_with_scalar(self):
  103. # GH#27456
  104. ii = IntervalIndex.from_arrays(
  105. [0, 1, 10, 11, 0, 1, 10, 11], [1, 2, 11, 12, 1, 2, 11, 12], name="MP"
  106. )
  107. idx = pd.MultiIndex.from_arrays(
  108. [
  109. pd.Index(["FC", "FC", "FC", "FC", "OWNER", "OWNER", "OWNER", "OWNER"]),
  110. pd.Index(
  111. ["RID1", "RID1", "RID2", "RID2", "RID1", "RID1", "RID2", "RID2"]
  112. ),
  113. ii,
  114. ]
  115. )
  116. idx.names = ["Item", "RID", "MP"]
  117. df = DataFrame({"value": [1, 2, 3, 4, 5, 6, 7, 8]})
  118. df.index = idx
  119. query_df = DataFrame(
  120. {
  121. "Item": ["FC", "OWNER", "FC", "OWNER", "OWNER"],
  122. "RID": ["RID1", "RID1", "RID1", "RID2", "RID2"],
  123. "MP": [0.2, 1.5, 1.6, 11.1, 10.9],
  124. }
  125. )
  126. query_df = query_df.sort_index()
  127. idx = pd.MultiIndex.from_arrays([query_df.Item, query_df.RID, query_df.MP])
  128. query_df.index = idx
  129. result = df.value.loc[query_df.index]
  130. # the IntervalIndex level is indexed with floats, which map to
  131. # the intervals containing them. Matching the behavior we would get
  132. # with _only_ an IntervalIndex, we get an IntervalIndex level back.
  133. sliced_level = ii.take([0, 1, 1, 3, 2])
  134. expected_index = pd.MultiIndex.from_arrays(
  135. [idx.get_level_values(0), idx.get_level_values(1), sliced_level]
  136. )
  137. expected = Series([1, 6, 2, 8, 7], index=expected_index, name="value")
  138. tm.assert_series_equal(result, expected)