test_truncate.py 4.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150
  1. import numpy as np
  2. import pytest
  3. import pandas as pd
  4. from pandas import (
  5. DataFrame,
  6. DatetimeIndex,
  7. Index,
  8. Series,
  9. date_range,
  10. )
  11. import pandas._testing as tm
  12. class TestDataFrameTruncate:
  def test_truncate(self, datetime_frame, frame_or_series):
      """Exercise ``truncate`` with no bounds, both bounds, and a single bound.

      The object under test keeps every third row of ``datetime_frame``, so
      bounds are tried both with labels that remain in the thinned index and
      with labels that were dropped from it.
      """
      obj = tm.get_obj(datetime_frame[::3], frame_or_series)

      full_index = datetime_frame.index
      # Positions 3 and 6 survive the ``[::3]`` thinning; 2 and 7 do not.
      start = full_index[3]
      end = full_index[6]
      start_missing = full_index[2]
      end_missing = full_index[7]

      # neither specified
      tm.assert_equal(obj.truncate(), obj)

      # both specified
      both = obj[1:3]
      tm.assert_equal(obj.truncate(start, end), both)
      tm.assert_equal(obj.truncate(start_missing, end_missing), both)

      # start specified
      tail = obj[1:]
      for lower in (start, start_missing):
          tm.assert_equal(obj.truncate(before=lower), tail)

      # end specified
      head = obj[:3]
      for upper in (end, end_missing):
          tm.assert_equal(obj.truncate(after=upper), head)

      # corner case, empty series/frame returned
      step = obj.index.freq
      first_label = obj.index[0]
      last_label = obj.index[-1]
      assert len(obj.truncate(after=first_label - step)) == 0
      assert len(obj.truncate(before=last_label + step)) == 0

      # ``before`` later than ``after`` must be rejected.
      msg = "Truncate: 2000-01-06 00:00:00 must be after 2000-02-04 00:00:00"
      with pytest.raises(ValueError, match=msg):
          obj.truncate(before=last_label - step, after=first_label + step)
  def test_truncate_nonsortedindex(self, frame_or_series):
      """``truncate`` must raise ``ValueError`` for an index that is not sorted."""
      # GH#17935
      labels = [5, 3, 2, 9, 0]
      frame = DataFrame({"A": ["a", "b", "c", "d", "e"]}, index=labels)
      unsorted = tm.get_obj(frame, frame_or_series)

      with pytest.raises(ValueError, match="truncate requires a sorted index"):
          unsorted.truncate(before=3, after=9)
  def test_sort_values_nonsortedindex(self):
      """``truncate`` must raise after ``sort_values`` reorders a datetime index.

      The frame is built on a weekly ``date_range`` and then sorted by the
      values in column ``"A"``; truncating the reordered frame is expected to
      raise ``ValueError``.
      """
      # Seeded generator instead of the global ``np.random`` state, so the
      # data (and therefore the row order after sorting) is reproducible.
      rng = np.random.default_rng(2)
      weeks = date_range("2011-01-01", "2012-01-01", freq="W")
      ts = DataFrame(
          {
              "A": rng.standard_normal(len(weeks)),
              "B": rng.standard_normal(len(weeks)),
          },
          index=weeks,
      )
      decreasing = ts.sort_values("A", ascending=False)

      msg = "truncate requires a sorted index"
      with pytest.raises(ValueError, match=msg):
          decreasing.truncate(before="2011-11", after="2011-12")
  def test_truncate_nonsortedindex_axis1(self):
      """``truncate`` along ``axis=1`` must raise when the columns are unsorted."""
      # GH#17935
      # Seeded generator instead of the global ``np.random`` state, so the
      # frame contents are reproducible between runs.
      rng = np.random.default_rng(2)
      column_labels = [3, 20, 2, 0]
      df = DataFrame(
          {label: rng.standard_normal(5) for label in column_labels},
          columns=column_labels,
      )

      msg = "truncate requires a sorted index"
      with pytest.raises(ValueError, match=msg):
          df.truncate(before=2, after=20, axis=1)
  @pytest.mark.parametrize(
      "before, after, indices",
      [(1, 2, [2, 1]), (None, 2, [2, 1, 0]), (1, None, [3, 2, 1])],
  )
  @pytest.mark.parametrize("dtyp", [*tm.ALL_REAL_NUMPY_DTYPES, "datetime64[ns]"])
  def test_truncate_decreasing_index(
      self, before, after, indices, dtyp, frame_or_series
  ):
      """``truncate`` on a decreasing index returns the rows selected by ``indices``.

      ``before``/``after`` are the bounds passed to ``truncate``; ``indices``
      lists the labels expected to remain, in index order.
      """
      # https://github.com/pandas-dev/pandas/issues/33756
      descending = Index([3, 2, 1, 0], dtype=dtyp)

      if isinstance(descending, DatetimeIndex):
          # Convert the integer bounds and labels so they match the index.
          if before is not None:
              before = pd.Timestamp(before)
          if after is not None:
              after = pd.Timestamp(after)
          indices = list(map(pd.Timestamp, indices))

      obj = frame_or_series(range(len(descending)), index=descending)
      expected = obj.loc[indices]
      result = obj.truncate(before=before, after=after)
      tm.assert_equal(result, expected)
  98. def test_truncate_multiindex(self, frame_or_series):
  99. # GH 34564
  100. mi = pd.MultiIndex.from_product([[1, 2, 3, 4], ["A", "B"]], names=["L1", "L2"])
  101. s1 = DataFrame(range(mi.shape[0]), index=mi, columns=["col"])
  102. s1 = tm.get_obj(s1, frame_or_series)
  103. result = s1.truncate(before=2, after=3)
  104. df = DataFrame.from_dict(
  105. {"L1": [2, 2, 3, 3], "L2": ["A", "B", "A", "B"], "col": [2, 3, 4, 5]}
  106. )
  107. expected = df.set_index(["L1", "L2"])
  108. expected = tm.get_obj(expected, frame_or_series)
  109. tm.assert_equal(result, expected)
  110. def test_truncate_index_only_one_unique_value(self, frame_or_series):
  111. # GH 42365
  112. obj = Series(0, index=date_range("2021-06-30", "2021-06-30")).repeat(5)
  113. if frame_or_series is DataFrame:
  114. obj = obj.to_frame(name="a")
  115. truncated = obj.truncate("2021-06-28", "2021-07-01")
  116. tm.assert_equal(truncated, obj)