test_interval.py 4.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188
  1. """
  2. This file contains a minimal set of tests for compliance with the extension
  3. array interface test suite, and should contain no other tests.
  4. The test suite for the full functionality of the array is located in
  5. `pandas/tests/arrays/`.
  6. The tests in this file are inherited from the BaseExtensionTests, and only
  7. minimal tweaks should be applied to get the tests passing (by overwriting a
  8. parent method).
  9. Additional tests should either be added to one of the BaseExtensionTests
  10. classes (if they are relevant for the extension interface for all dtypes), or
  11. be added to the array-specific tests in `pandas/tests/arrays/`.
  12. """
  13. import numpy as np
  14. import pytest
  15. from pandas.core.dtypes.dtypes import IntervalDtype
  16. from pandas import (
  17. Interval,
  18. Series,
  19. )
  20. from pandas.core.arrays import IntervalArray
  21. from pandas.tests.extension import base
  22. def make_data():
  23. N = 100
  24. left_array = np.random.uniform(size=N).cumsum()
  25. right_array = left_array + np.random.uniform(size=N)
  26. return [Interval(left, right) for left, right in zip(left_array, right_array)]
  27. @pytest.fixture
  28. def dtype():
  29. return IntervalDtype()
  30. @pytest.fixture
  31. def data():
  32. """Length-100 PeriodArray for semantics test."""
  33. return IntervalArray(make_data())
  34. @pytest.fixture
  35. def data_missing():
  36. """Length 2 array with [NA, Valid]"""
  37. return IntervalArray.from_tuples([None, (0, 1)])
  38. @pytest.fixture
  39. def data_for_sorting():
  40. return IntervalArray.from_tuples([(1, 2), (2, 3), (0, 1)])
  41. @pytest.fixture
  42. def data_missing_for_sorting():
  43. return IntervalArray.from_tuples([(1, 2), None, (0, 1)])
  44. @pytest.fixture
  45. def na_value():
  46. return np.nan
  47. @pytest.fixture
  48. def data_for_grouping():
  49. a = (0, 1)
  50. b = (1, 2)
  51. c = (2, 3)
  52. return IntervalArray.from_tuples([b, b, None, None, a, a, b, c])
  53. class BaseInterval:
  54. pass
  55. class TestDtype(BaseInterval, base.BaseDtypeTests):
  56. pass
  57. class TestCasting(BaseInterval, base.BaseCastingTests):
  58. pass
  59. class TestConstructors(BaseInterval, base.BaseConstructorsTests):
  60. pass
  61. class TestGetitem(BaseInterval, base.BaseGetitemTests):
  62. pass
  63. class TestIndex(base.BaseIndexTests):
  64. pass
  65. class TestGrouping(BaseInterval, base.BaseGroupbyTests):
  66. pass
  67. class TestInterface(BaseInterval, base.BaseInterfaceTests):
  68. pass
  69. class TestReduce(base.BaseNoReduceTests):
  70. @pytest.mark.parametrize("skipna", [True, False])
  71. def test_reduce_series_numeric(self, data, all_numeric_reductions, skipna):
  72. op_name = all_numeric_reductions
  73. ser = Series(data)
  74. if op_name in ["min", "max"]:
  75. # IntervalArray *does* implement these
  76. assert getattr(ser, op_name)(skipna=skipna) in data
  77. assert getattr(data, op_name)(skipna=skipna) in data
  78. return
  79. super().test_reduce_series_numeric(data, all_numeric_reductions, skipna)
  80. class TestMethods(BaseInterval, base.BaseMethodsTests):
  81. @pytest.mark.xfail(reason="addition is not defined for intervals")
  82. def test_combine_add(self, data_repeated):
  83. super().test_combine_add(data_repeated)
  84. @pytest.mark.xfail(
  85. reason="Raises with incorrect message bc it disallows *all* listlikes "
  86. "instead of just wrong-length listlikes"
  87. )
  88. def test_fillna_length_mismatch(self, data_missing):
  89. super().test_fillna_length_mismatch(data_missing)
  90. class TestMissing(BaseInterval, base.BaseMissingTests):
  91. # Index.fillna only accepts scalar `value`, so we have to xfail all
  92. # non-scalar fill tests.
  93. unsupported_fill = pytest.mark.xfail(
  94. reason="Unsupported fillna option for Interval."
  95. )
  96. @unsupported_fill
  97. def test_fillna_limit_pad(self):
  98. super().test_fillna_limit_pad()
  99. @unsupported_fill
  100. def test_fillna_series_method(self):
  101. super().test_fillna_series_method()
  102. @unsupported_fill
  103. def test_fillna_limit_backfill(self):
  104. super().test_fillna_limit_backfill()
  105. @unsupported_fill
  106. def test_fillna_no_op_returns_copy(self):
  107. super().test_fillna_no_op_returns_copy()
  108. @unsupported_fill
  109. def test_fillna_series(self):
  110. super().test_fillna_series()
  111. def test_fillna_non_scalar_raises(self, data_missing):
  112. msg = "can only insert Interval objects and NA into an IntervalArray"
  113. with pytest.raises(TypeError, match=msg):
  114. data_missing.fillna([1, 1])
  115. class TestReshaping(BaseInterval, base.BaseReshapingTests):
  116. pass
  117. class TestSetitem(BaseInterval, base.BaseSetitemTests):
  118. pass
  119. class TestPrinting(BaseInterval, base.BasePrintingTests):
  120. @pytest.mark.xfail(reason="Interval has custom repr")
  121. def test_array_repr(self, data, size):
  122. super().test_array_repr()
  123. class TestParsing(BaseInterval, base.BaseParsingTests):
  124. @pytest.mark.parametrize("engine", ["c", "python"])
  125. def test_EA_types(self, engine, data):
  126. expected_msg = r".*must implement _from_sequence_of_strings.*"
  127. with pytest.raises(NotImplementedError, match=expected_msg):
  128. super().test_EA_types(engine, data)