test_ops.py 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119
  1. """Tests for Interval-Interval operations, such as overlaps, contains, etc."""
  2. import pytest
  3. from pandas import (
  4. Interval,
  5. Timedelta,
  6. Timestamp,
  7. )
  8. @pytest.fixture(
  9. params=[
  10. (Timedelta("0 days"), Timedelta("1 day")),
  11. (Timestamp("2018-01-01"), Timedelta("1 day")),
  12. (0, 1),
  13. ],
  14. ids=lambda x: type(x[0]).__name__,
  15. )
  16. def start_shift(request):
  17. """
  18. Fixture for generating intervals of types from a start value and a shift
  19. value that can be added to start to generate an endpoint
  20. """
  21. return request.param
  22. class TestOverlaps:
  23. def test_overlaps_self(self, start_shift, closed):
  24. start, shift = start_shift
  25. interval = Interval(start, start + shift, closed)
  26. assert interval.overlaps(interval)
  27. def test_overlaps_nested(self, start_shift, closed, other_closed):
  28. start, shift = start_shift
  29. interval1 = Interval(start, start + 3 * shift, other_closed)
  30. interval2 = Interval(start + shift, start + 2 * shift, closed)
  31. # nested intervals should always overlap
  32. assert interval1.overlaps(interval2)
  33. def test_overlaps_disjoint(self, start_shift, closed, other_closed):
  34. start, shift = start_shift
  35. interval1 = Interval(start, start + shift, other_closed)
  36. interval2 = Interval(start + 2 * shift, start + 3 * shift, closed)
  37. # disjoint intervals should never overlap
  38. assert not interval1.overlaps(interval2)
  39. def test_overlaps_endpoint(self, start_shift, closed, other_closed):
  40. start, shift = start_shift
  41. interval1 = Interval(start, start + shift, other_closed)
  42. interval2 = Interval(start + shift, start + 2 * shift, closed)
  43. # overlap if shared endpoint is closed for both (overlap at a point)
  44. result = interval1.overlaps(interval2)
  45. expected = interval1.closed_right and interval2.closed_left
  46. assert result == expected
  47. @pytest.mark.parametrize(
  48. "other",
  49. [10, True, "foo", Timedelta("1 day"), Timestamp("2018-01-01")],
  50. ids=lambda x: type(x).__name__,
  51. )
  52. def test_overlaps_invalid_type(self, other):
  53. interval = Interval(0, 1)
  54. msg = f"`other` must be an Interval, got {type(other).__name__}"
  55. with pytest.raises(TypeError, match=msg):
  56. interval.overlaps(other)
  57. class TestContains:
  58. def test_contains_interval(self, inclusive_endpoints_fixture):
  59. interval1 = Interval(0, 1, "both")
  60. interval2 = Interval(0, 1, inclusive_endpoints_fixture)
  61. assert interval1 in interval1
  62. assert interval2 in interval2
  63. assert interval2 in interval1
  64. assert interval1 not in interval2 or inclusive_endpoints_fixture == "both"
  65. def test_contains_infinite_length(self):
  66. interval1 = Interval(0, 1, "both")
  67. interval2 = Interval(float("-inf"), float("inf"), "neither")
  68. assert interval1 in interval2
  69. assert interval2 not in interval1
  70. def test_contains_zero_length(self):
  71. interval1 = Interval(0, 1, "both")
  72. interval2 = Interval(-1, -1, "both")
  73. interval3 = Interval(0.5, 0.5, "both")
  74. assert interval2 not in interval1
  75. assert interval3 in interval1
  76. assert interval2 not in interval3 and interval3 not in interval2
  77. assert interval1 not in interval2 and interval1 not in interval3
  78. @pytest.mark.parametrize(
  79. "type1",
  80. [
  81. (0, 1),
  82. (Timestamp(2000, 1, 1, 0), Timestamp(2000, 1, 1, 1)),
  83. (Timedelta("0h"), Timedelta("1h")),
  84. ],
  85. )
  86. @pytest.mark.parametrize(
  87. "type2",
  88. [
  89. (0, 1),
  90. (Timestamp(2000, 1, 1, 0), Timestamp(2000, 1, 1, 1)),
  91. (Timedelta("0h"), Timedelta("1h")),
  92. ],
  93. )
  94. def test_contains_mixed_types(self, type1, type2):
  95. interval1 = Interval(*type1)
  96. interval2 = Interval(*type2)
  97. if type1 == type2:
  98. assert interval1 in interval2
  99. else:
  100. msg = "^'<=' not supported between instances of"
  101. with pytest.raises(TypeError, match=msg):
  102. interval1 in interval2