test_interval.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423
  1. import numpy as np
  2. import pytest
  3. import pandas.util._test_decorators as td
  4. import pandas as pd
  5. from pandas import (
  6. Index,
  7. Interval,
  8. IntervalIndex,
  9. Timedelta,
  10. Timestamp,
  11. date_range,
  12. timedelta_range,
  13. )
  14. import pandas._testing as tm
  15. from pandas.core.arrays import IntervalArray
  def _left_dtype_id(left_right):
      # Name each parametrization after the dtype of its left endpoints.
      return str(left_right[0].dtype)


  @pytest.fixture(
      params=[
          (Index([0, 2, 4]), Index([1, 3, 5])),
          (Index([0.0, 1.0, 2.0]), Index([1.0, 2.0, 3.0])),
          (timedelta_range("0 days", periods=3), timedelta_range("1 day", periods=3)),
          (date_range("20170101", periods=3), date_range("20170102", periods=3)),
          (
              date_range("20170101", periods=3, tz="US/Eastern"),
              date_range("20170102", periods=3, tz="US/Eastern"),
          ),
      ],
      ids=_left_dtype_id,
  )
  def left_right_dtypes(request):
      """
      Provide a ``(left, right)`` pair of endpoint indexes for building an
      IntervalArray.

      One pair is supplied per subtype exercised here: int64, float64,
      timedelta64, naive datetime64 and US/Eastern-localized datetime64.
      """
      left_right = request.param
      return left_right
  class TestAttributes:
      @pytest.mark.parametrize(
          "left, right",
          [
              (0, 1),
              (Timedelta("0 days"), Timedelta("1 day")),
              (Timestamp("2018-01-01"), Timestamp("2018-01-02")),
              (
                  Timestamp("2018-01-01", tz="US/Eastern"),
                  Timestamp("2018-01-02", tz="US/Eastern"),
              ),
          ],
      )
      @pytest.mark.parametrize("constructor", [IntervalArray, IntervalIndex])
      def test_is_empty(self, constructor, left, right, closed):
          """Check ``is_empty`` for a zero-width interval, a proper one and NaN."""
          # GH27219
          zero_width = (left, left)
          proper = (left, right)
          obj = constructor.from_tuples([zero_width, proper, np.nan], closed=closed)

          # The zero-width interval is expected to be empty for every ``closed``
          # except "both"; the proper interval and the NaN entry never are.
          expected = np.array([closed != "both", False, False])
          tm.assert_numpy_array_equal(obj.is_empty, expected)
  class TestMethods:
      # Message expected when shift() gets a fill_value that is neither an
      # Interval nor an NA value the array accepts.
      _SHIFT_FILL_MSG = "can only insert Interval objects and NA into an IntervalArray"

      @pytest.mark.parametrize("new_closed", ["left", "right", "both", "neither"])
      def test_set_closed(self, closed, new_closed):
          """set_closed should match building the array with the new side directly."""
          # GH 21670
          breaks = range(10)
          result = IntervalArray.from_breaks(breaks, closed=closed).set_closed(new_closed)
          tm.assert_extension_array_equal(
              result, IntervalArray.from_breaks(breaks, closed=new_closed)
          )

      @pytest.mark.parametrize(
          "other",
          [
              Interval(0, 1, closed="right"),
              IntervalArray.from_breaks([1, 2, 3, 4], closed="right"),
          ],
      )
      def test_where_raises(self, other):
          """A mismatched ``closed`` raises on the array but coerces on the Series."""
          # GH#45768 The IntervalArray methods raises; the Series method coerces
          ser = pd.Series(IntervalArray.from_breaks([1, 2, 3, 4], closed="left"))
          mask = np.array([True, False, True])

          with pytest.raises(ValueError, match="'value.closed' is 'right', expected 'left'."):
              ser.array._where(mask, other)

          result = ser.where(mask, other=other)
          tm.assert_series_equal(result, ser.astype(object).where(mask, other))

      def test_shift(self):
          """Default shift on an int-backed array, plus a rejected NaT fill value."""
          # https://github.com/pandas-dev/pandas/issues/31495, GH#22428, GH#31502
          arr = IntervalArray.from_breaks([1, 2, 3])
          # int -> float: the shifted-in slot is missing, so endpoints become float.
          tm.assert_interval_array_equal(
              arr.shift(),
              IntervalArray.from_tuples([(np.nan, np.nan), (1.0, 2.0)]),
          )
          with pytest.raises(TypeError, match=self._SHIFT_FILL_MSG):
              arr.shift(1, fill_value=pd.NaT)

      def test_shift_datetime(self):
          """Forward/backward shifts on a datetime-backed array match take()."""
          # GH#31502, GH#31504
          arr = IntervalArray.from_breaks(date_range("2000", periods=4))
          # Each case pairs a shift amount with the take() positions whose
          # result (with allow_fill=True) the shift must reproduce.
          cases = [(2, [-1, -1, 0]), (-1, [1, 2, -1])]
          for periods, positions in cases:
              tm.assert_interval_array_equal(
                  arr.shift(periods), arr.take(positions, allow_fill=True)
              )
          with pytest.raises(TypeError, match=self._SHIFT_FILL_MSG):
              arr.shift(1, fill_value=np.timedelta64("NaT", "ns"))
  101. class TestSetitem:
  102. def test_set_na(self, left_right_dtypes):
  103. left, right = left_right_dtypes
  104. left = left.copy(deep=True)
  105. right = right.copy(deep=True)
  106. result = IntervalArray.from_arrays(left, right)
  107. if result.dtype.subtype.kind not in ["m", "M"]:
  108. msg = "'value' should be an interval type, got <.*NaTType'> instead."
  109. with pytest.raises(TypeError, match=msg):
  110. result[0] = pd.NaT
  111. if result.dtype.subtype.kind in ["i", "u"]:
  112. msg = "Cannot set float NaN to integer-backed IntervalArray"
  113. # GH#45484 TypeError, not ValueError, matches what we get with
  114. # non-NA un-holdable value.
  115. with pytest.raises(TypeError, match=msg):
  116. result[0] = np.NaN
  117. return
  118. result[0] = np.nan
  119. expected_left = Index([left._na_value] + list(left[1:]))
  120. expected_right = Index([right._na_value] + list(right[1:]))
  121. expected = IntervalArray.from_arrays(expected_left, expected_right)
  122. tm.assert_extension_array_equal(result, expected)
  123. def test_setitem_mismatched_closed(self):
  124. arr = IntervalArray.from_breaks(range(4))
  125. orig = arr.copy()
  126. other = arr.set_closed("both")
  127. msg = "'value.closed' is 'both', expected 'right'"
  128. with pytest.raises(ValueError, match=msg):
  129. arr[0] = other[0]
  130. with pytest.raises(ValueError, match=msg):
  131. arr[:1] = other[:1]
  132. with pytest.raises(ValueError, match=msg):
  133. arr[:0] = other[:0]
  134. with pytest.raises(ValueError, match=msg):
  135. arr[:] = other[::-1]
  136. with pytest.raises(ValueError, match=msg):
  137. arr[:] = list(other[::-1])
  138. with pytest.raises(ValueError, match=msg):
  139. arr[:] = other[::-1].astype(object)
  140. with pytest.raises(ValueError, match=msg):
  141. arr[:] = other[::-1].astype("category")
  142. # empty list should be no-op
  143. arr[:0] = []
  144. tm.assert_interval_array_equal(arr, orig)
  def test_repr():
      """The IntervalArray repr shows the class, the intervals, then length and dtype."""
      # GH 25022
      expected = (
          "<IntervalArray>\n"
          "[(0, 1], (1, 2]]\n"
          "Length: 2, dtype: interval[int64, right]"
      )
      assert repr(IntervalArray.from_tuples([(0, 1), (1, 2)])) == expected
  155. class TestReductions:
  156. def test_min_max_invalid_axis(self, left_right_dtypes):
  157. left, right = left_right_dtypes
  158. left = left.copy(deep=True)
  159. right = right.copy(deep=True)
  160. arr = IntervalArray.from_arrays(left, right)
  161. msg = "`axis` must be fewer than the number of dimensions"
  162. for axis in [-2, 1]:
  163. with pytest.raises(ValueError, match=msg):
  164. arr.min(axis=axis)
  165. with pytest.raises(ValueError, match=msg):
  166. arr.max(axis=axis)
  167. msg = "'>=' not supported between"
  168. with pytest.raises(TypeError, match=msg):
  169. arr.min(axis="foo")
  170. with pytest.raises(TypeError, match=msg):
  171. arr.max(axis="foo")
  172. def test_min_max(self, left_right_dtypes, index_or_series_or_array):
  173. # GH#44746
  174. left, right = left_right_dtypes
  175. left = left.copy(deep=True)
  176. right = right.copy(deep=True)
  177. arr = IntervalArray.from_arrays(left, right)
  178. # The expected results below are only valid if monotonic
  179. assert left.is_monotonic_increasing
  180. assert Index(arr).is_monotonic_increasing
  181. MIN = arr[0]
  182. MAX = arr[-1]
  183. indexer = np.arange(len(arr))
  184. np.random.shuffle(indexer)
  185. arr = arr.take(indexer)
  186. arr_na = arr.insert(2, np.nan)
  187. arr = index_or_series_or_array(arr)
  188. arr_na = index_or_series_or_array(arr_na)
  189. for skipna in [True, False]:
  190. res = arr.min(skipna=skipna)
  191. assert res == MIN
  192. assert type(res) == type(MIN)
  193. res = arr.max(skipna=skipna)
  194. assert res == MAX
  195. assert type(res) == type(MAX)
  196. res = arr_na.min(skipna=False)
  197. assert np.isnan(res)
  198. res = arr_na.max(skipna=False)
  199. assert np.isnan(res)
  200. res = arr_na.min(skipna=True)
  201. assert res == MIN
  202. assert type(res) == type(MIN)
  203. res = arr_na.max(skipna=True)
  204. assert res == MAX
  205. assert type(res) == type(MAX)
  # ----------------------------------------------------------------------------
  # Arrow interaction
  # Shared marker: tests decorated with it are skipped when pyarrow is unavailable.
  pyarrow_skip = td.skip_if_no(package="pyarrow")
  @pyarrow_skip
  def test_arrow_extension_type():
      """ArrowIntervalType equality and hashing follow subtype and closed side."""
      import pyarrow as pa

      from pandas.core.arrays.arrow.extension_types import ArrowIntervalType

      left_one = ArrowIntervalType(pa.int64(), "left")
      left_two = ArrowIntervalType(pa.int64(), "left")
      right_one = ArrowIntervalType(pa.int64(), "right")

      assert left_one.closed == "left"
      # same subtype and closed side -> equal
      assert left_one == left_two
      # differing closed side -> not equal
      assert left_one != right_one
      assert hash(left_one) == hash(left_two)
      assert hash(left_one) != hash(right_one)
  @pyarrow_skip
  def test_arrow_array():
      """Converting an int64 IntervalArray to Arrow gives an ArrowIntervalType struct."""
      import pyarrow as pa

      from pandas.core.arrays.arrow.extension_types import ArrowIntervalType

      intervals = pd.interval_range(1, 5, freq=1).array
      result = pa.array(intervals)

      assert isinstance(result.type, ArrowIntervalType)
      assert result.type.closed == intervals.closed
      assert result.type.subtype == pa.int64()

      endpoint_values = [("left", [1, 2, 3, 4]), ("right", [2, 3, 4, 5])]
      for name, values in endpoint_values:
          assert result.storage.field(name).equals(pa.array(values, type="int64"))

      expected = pa.array([{"left": i, "right": i + 1} for i in range(1, 5)])
      assert result.storage.equals(expected)

      # requesting the struct storage type directly also works
      as_storage = pa.array(intervals, type=expected.type)
      assert as_storage.equals(expected)

      # any other target type is refused
      with pytest.raises(TypeError, match="Not supported to convert IntervalArray"):
          pa.array(intervals, type="float64")
      with pytest.raises(TypeError, match="Not supported to convert IntervalArray"):
          pa.array(intervals, type=ArrowIntervalType(pa.float64(), "left"))
  @pyarrow_skip
  def test_arrow_array_missing():
      """A missing interval becomes null both in the child fields and the struct."""
      import pyarrow as pa

      from pandas.core.arrays.arrow.extension_types import ArrowIntervalType

      arr = IntervalArray.from_breaks([0.0, 1.0, 2.0, 3.0])
      arr[1] = None

      result = pa.array(arr)
      assert isinstance(result.type, ArrowIntervalType)
      assert result.type.closed == arr.closed
      assert result.type.subtype == pa.float64()

      # fields have missing values (not NaN)
      expected_fields = {
          "left": pa.array([0.0, None, 2.0], type="float64"),
          "right": pa.array([1.0, None, 3.0], type="float64"),
      }
      for name, expected_field in expected_fields.items():
          assert result.storage.field(name).equals(expected_field)

      # structarray itself also has missing values on the array level
      rows = [
          {"left": 0.0, "right": 1.0},
          {"left": None, "right": None},
          {"left": 2.0, "right": 3.0},
      ]
      null_mask = np.array([False, True, False])
      assert result.storage.equals(pa.StructArray.from_pandas(rows, mask=null_mask))
  @pyarrow_skip
  @pytest.mark.parametrize(
      "breaks",
      [[0.0, 1.0, 2.0, 3.0], date_range("2017", periods=4, freq="D")],
      ids=["float", "datetime64[ns]"],
  )
  def test_arrow_table_roundtrip(breaks):
      """A DataFrame with an interval column survives pandas -> Arrow -> pandas."""
      import pyarrow as pa

      from pandas.core.arrays.arrow.extension_types import ArrowIntervalType

      arr = IntervalArray.from_breaks(breaks)
      arr[1] = None
      df = pd.DataFrame({"a": arr})

      table = pa.table(df)
      assert isinstance(table.field("a").type, ArrowIntervalType)

      roundtripped = table.to_pandas()
      assert isinstance(roundtripped["a"].dtype, pd.IntervalDtype)
      tm.assert_frame_equal(roundtripped, df)

      # a table assembled with concat_tables converts back as well
      doubled = pa.concat_tables([table, table])
      expected = pd.concat([df, df], ignore_index=True)
      tm.assert_frame_equal(doubled.to_pandas(), expected)

      # GH-41040: a column built from an empty list of chunks
      empty_table = pa.table(
          [pa.chunked_array([], type=table.column(0).type)], schema=table.schema
      )
      tm.assert_frame_equal(empty_table.to_pandas(), expected[0:0])
  @pyarrow_skip
  @pytest.mark.parametrize(
      "breaks",
      [[0.0, 1.0, 2.0, 3.0], date_range("2017", periods=4, freq="D")],
      ids=["float", "datetime64[ns]"],
  )
  def test_arrow_table_roundtrip_without_metadata(breaks):
      """The interval dtype is recovered even after the schema metadata is dropped."""
      import pyarrow as pa

      arr = IntervalArray.from_breaks(breaks)
      arr[1] = None
      df = pd.DataFrame({"a": arr})

      # remove the metadata
      stripped = pa.table(df).replace_schema_metadata()
      assert stripped.schema.metadata is None

      result = stripped.to_pandas()
      assert isinstance(result["a"].dtype, pd.IntervalDtype)
      tm.assert_frame_equal(result, df)
  @pyarrow_skip
  def test_from_arrow_from_raw_struct_array():
      # in case pyarrow lost the Interval extension type (eg on parquet roundtrip
      # with datetime64[ns] subtype, see GH-45881), still allow conversion
      # from arrow to IntervalArray
      import pyarrow as pa

      raw = pa.array([{"left": 0, "right": 1}, {"left": 1, "right": 2}])
      dtype = pd.IntervalDtype(np.dtype("int64"), closed="neither")
      expected = IntervalArray.from_breaks(
          np.array([0, 1, 2], dtype="int64"), closed="neither"
      )

      # both a plain struct array and a chunked wrapper of it must convert
      for source in (raw, pa.chunked_array([raw])):
          tm.assert_extension_array_equal(dtype.__from_arrow__(source), expected)
  @pytest.mark.parametrize("timezone", ["UTC", "US/Pacific", "GMT"])
  def test_interval_index_subtype(timezone, inclusive_endpoints_fixture):
      """String endpoints cast through a tz-aware interval dtype match localized dates."""
      # GH 46999
      closed = inclusive_endpoints_fixture
      dtype = f"interval[datetime64[ns, {timezone}], {inclusive_endpoints_fixture}]"

      result = IntervalIndex.from_arrays(
          ["2022-01-01", "2022-01-02"],
          ["2022-01-02", "2022-01-03"],
          closed=closed,
          dtype=dtype,
      )

      dates = date_range("2022", periods=3, tz=timezone)
      expected = IntervalIndex.from_arrays(dates[:-1], dates[1:], closed=closed)
      tm.assert_index_equal(result, expected)