test_interval_tree.py 7.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209
  1. from itertools import permutations
  2. import numpy as np
  3. import pytest
  4. from pandas._libs.interval import IntervalTree
  5. from pandas.compat import IS64
  6. import pandas._testing as tm
  def skipif_32bit(param):
      """
      Wrap a parametrize value so that it is skipped when ``IS64`` is false.

      Used in this module for the leaf_size parameters affected by GH 23440.
      """
      return pytest.param(
          param,
          marks=pytest.mark.skipif(
              not IS64, reason="GH 23440: int type mismatch on 32bit"
          ),
      )
  @pytest.fixture(params=["int64", "float64", "uint64"])
  def dtype(request):
      """Fixture providing each of the dtype names listed in ``params``."""
      dtype_name = request.param
      return dtype_name
  @pytest.fixture(params=[*map(skipif_32bit, (1, 2)), 10])
  def leaf_size(request):
      """
      Fixture supplying the IntervalTree ``leaf_size`` argument; consumed by
      the ``tree`` fixture. The values 1 and 2 are wrapped by ``skipif_32bit``.
      """
      size = request.param
      return size
  @pytest.fixture(
      params=[
          *(np.arange(5, dtype=name) for name in ("int64", "uint64", "float64")),
          np.array([0, 1, 2, 3, 4, np.nan], dtype="float64"),
      ]
  )
  def tree(request, leaf_size):
      """
      Fixture building an IntervalTree whose right bounds are the left bounds
      plus 2, for each parametrized array of left bounds and each leaf_size.
      """
      left_bounds = request.param
      right_bounds = left_bounds + 2
      return IntervalTree(left_bounds, right_bounds, leaf_size=leaf_size)
  class TestIntervalTree:
      """Tests for IntervalTree indexing, overlap detection and construction."""

      def test_get_indexer(self, tree):
          """get_indexer returns positions, -1 for a miss, and raises on ambiguity."""
          targets = np.array([1.0, 5.5, 6.5])
          tm.assert_numpy_array_equal(
              tree.get_indexer(targets), np.array([0, 4, -1], dtype="intp")
          )

          msg = "'indexer does not intersect a unique set of intervals'"
          with pytest.raises(KeyError, match=msg):
              tree.get_indexer(np.array([3.0]))

      @pytest.mark.parametrize(
          "dtype, target_value, target_dtype",
          [("int64", 2**63 + 1, "uint64"), ("uint64", -1, "int64")],
      )
      def test_get_indexer_overflow(self, dtype, target_value, target_dtype):
          """A target outside the tree dtype's range is reported as not found."""
          left = np.array([0, 1], dtype=dtype)
          right = np.array([1, 2], dtype=dtype)
          tree = IntervalTree(left, right)

          target = np.array([target_value], dtype=target_dtype)
          tm.assert_numpy_array_equal(
              tree.get_indexer(target), np.array([-1], dtype="intp")
          )

      def test_get_indexer_non_unique(self, tree):
          """get_indexer_non_unique returns all matches plus the missing positions."""
          indexer, missing = tree.get_indexer_non_unique(np.array([1.0, 2.0, 6.5]))

          # first target: a single match
          tm.assert_numpy_array_equal(indexer[:1], np.array([0], dtype="intp"))

          # second target: two matches, compared after sorting
          tm.assert_numpy_array_equal(
              np.sort(indexer[1:3]), np.array([0, 1], dtype="intp")
          )

          # third target: no match
          tm.assert_numpy_array_equal(
              np.sort(indexer[3:]), np.array([-1], dtype="intp")
          )

          tm.assert_numpy_array_equal(missing, np.array([2], dtype="intp"))

      @pytest.mark.parametrize(
          "dtype, target_value, target_dtype",
          [("int64", 2**63 + 1, "uint64"), ("uint64", -1, "int64")],
      )
      def test_get_indexer_non_unique_overflow(self, dtype, target_value, target_dtype):
          """An out-of-range target is reported as -1 and listed as missing."""
          left = np.array([0, 2], dtype=dtype)
          right = np.array([1, 3], dtype=dtype)
          tree = IntervalTree(left, right)

          target = np.array([target_value], dtype=target_dtype)
          found, not_matched = tree.get_indexer_non_unique(target)

          tm.assert_numpy_array_equal(found, np.array([-1], dtype="intp"))
          tm.assert_numpy_array_equal(not_matched, np.array([0], dtype="intp"))

      def test_duplicates(self, dtype):
          """Three identical intervals: get_indexer raises, non_unique returns all."""
          starts = np.array([0, 0, 0], dtype=dtype)
          tree = IntervalTree(starts, starts + 1)

          msg = "'indexer does not intersect a unique set of intervals'"
          with pytest.raises(KeyError, match=msg):
              tree.get_indexer(np.array([0.5]))

          indexer, missing = tree.get_indexer_non_unique(np.array([0.5]))
          tm.assert_numpy_array_equal(
              np.sort(indexer), np.array([0, 1, 2], dtype="intp")
          )
          tm.assert_numpy_array_equal(missing, np.array([], dtype="intp"))

      @pytest.mark.parametrize(
          "leaf_size", [skipif_32bit(1), skipif_32bit(10), skipif_32bit(100), 10000]
      )
      def test_get_indexer_closed(self, closed, leaf_size):
          """Endpoint targets match only when the corresponding side is closed."""
          x = np.arange(1000, dtype="float64")
          found = x.astype("intp")
          not_found = np.full(1000, -1, dtype="intp")

          tree = IntervalTree(x, x + 0.5, closed=closed, leaf_size=leaf_size)

          # (offset added to x, whether a target at that offset should match)
          cases = (
              (0.25, True),
              (0.0, tree.closed_left),
              (0.5, tree.closed_right),
          )
          for offset, should_match in cases:
              expected = found if should_match else not_found
              tm.assert_numpy_array_equal(expected, tree.get_indexer(x + offset))

      @pytest.mark.parametrize(
          "left, right, expected",
          [
              (np.array([0, 1, 4], dtype="int64"), np.array([2, 3, 5]), True),
              (np.array([0, 1, 2], dtype="int64"), np.array([5, 4, 3]), True),
              (np.array([0, 1, np.nan]), np.array([5, 4, np.nan]), True),
              (np.array([0, 2, 4], dtype="int64"), np.array([1, 3, 5]), False),
              (np.array([0, 2, np.nan]), np.array([1, 3, np.nan]), False),
          ],
      )
      @pytest.mark.parametrize("order", [list(p) for p in permutations(range(3))])
      def test_is_overlapping(self, closed, order, left, right, expected):
          """is_overlapping matches the expected flag for every input ordering."""
          # GH 23309
          tree = IntervalTree(left[order], right[order], closed=closed)
          assert tree.is_overlapping is expected

      @pytest.mark.parametrize("order", [list(p) for p in permutations(range(3))])
      def test_is_overlapping_endpoints(self, closed, order):
          """Intervals sharing endpoints overlap only when closed == "both"."""
          # GH 23309
          left = np.arange(3, dtype="int64")
          right = np.arange(1, 4)
          tree = IntervalTree(left[order], right[order], closed=closed)
          assert tree.is_overlapping is (closed == "both")

      @pytest.mark.parametrize(
          "left, right",
          [
              (np.array([], dtype="int64"), np.array([], dtype="int64")),
              (np.array([0], dtype="int64"), np.array([1], dtype="int64")),
              (np.array([np.nan]), np.array([np.nan])),
              (np.array([np.nan] * 3), np.array([np.nan] * 3)),
          ],
      )
      def test_is_overlapping_trivial(self, closed, left, right):
          """Empty, single-interval and all-NaN inputs are never overlapping."""
          # GH 23309
          assert IntervalTree(left, right, closed=closed).is_overlapping is False

      @pytest.mark.skipif(not IS64, reason="GH 23440")
      def test_construction_overflow(self):
          """The root pivot is computed without overflowing near the int64 max."""
          # GH 25485
          int64_max = np.iinfo(np.int64).max
          left = np.arange(101, dtype="int64")
          right = [int64_max] * 101
          tree = IntervalTree(left, right)

          # pivot should be average of left/right medians
          assert tree.root.pivot == (50 + int64_max) / 2

      @pytest.mark.xfail(not IS64, reason="GH 23440")
      @pytest.mark.parametrize(
          "left, right, expected",
          [
              ([-np.inf, 1.0], [1.0, 2.0], 0.0),
              ([-np.inf, -2.0], [-2.0, -1.0], -2.0),
              ([-2.0, -1.0], [-1.0, np.inf], 0.0),
              ([1.0, 2.0], [2.0, np.inf], 2.0),
          ],
      )
      def test_inf_bound_infinite_recursion(self, left, right, expected):
          """Trees with infinite bounds build and give the expected root pivot."""
          # GH 46658
          # the two-element bound lists are repeated 101 times each
          tree = IntervalTree(left * 101, right * 101)
          assert tree.root.pivot == expected