# test_apply.py

  1. import warnings
  2. import numpy as np
  3. import pytest
  4. from pandas import (
  5. DataFrame,
  6. Index,
  7. MultiIndex,
  8. Series,
  9. Timestamp,
  10. concat,
  11. date_range,
  12. isna,
  13. notna,
  14. )
  15. import pandas._testing as tm
  16. from pandas.tseries import offsets
  17. def f(x):
  18. # suppress warnings about empty slices, as we are deliberately testing
  19. # with a 0-length Series
  20. with warnings.catch_warnings():
  21. warnings.filterwarnings(
  22. "ignore",
  23. message=".*(empty slice|0 for slice).*",
  24. category=RuntimeWarning,
  25. )
  26. return x[np.isfinite(x)].mean()
  27. @pytest.mark.parametrize("bad_raw", [None, 1, 0])
  28. def test_rolling_apply_invalid_raw(bad_raw):
  29. with pytest.raises(ValueError, match="raw parameter must be `True` or `False`"):
  30. Series(range(3)).rolling(1).apply(len, raw=bad_raw)
  31. def test_rolling_apply_out_of_bounds(engine_and_raw):
  32. # gh-1850
  33. engine, raw = engine_and_raw
  34. vals = Series([1, 2, 3, 4])
  35. result = vals.rolling(10).apply(np.sum, engine=engine, raw=raw)
  36. assert result.isna().all()
  37. result = vals.rolling(10, min_periods=1).apply(np.sum, engine=engine, raw=raw)
  38. expected = Series([1, 3, 6, 10], dtype=float)
  39. tm.assert_almost_equal(result, expected)
  40. @pytest.mark.parametrize("window", [2, "2s"])
  41. def test_rolling_apply_with_pandas_objects(window):
  42. # 5071
  43. df = DataFrame(
  44. {"A": np.random.randn(5), "B": np.random.randint(0, 10, size=5)},
  45. index=date_range("20130101", periods=5, freq="s"),
  46. )
  47. # we have an equal spaced timeseries index
  48. # so simulate removing the first period
  49. def f(x):
  50. if x.index[0] == df.index[0]:
  51. return np.nan
  52. return x.iloc[-1]
  53. result = df.rolling(window).apply(f, raw=False)
  54. expected = df.iloc[2:].reindex_like(df)
  55. tm.assert_frame_equal(result, expected)
  56. with tm.external_error_raised(AttributeError):
  57. df.rolling(window).apply(f, raw=True)
  58. def test_rolling_apply(engine_and_raw, step):
  59. engine, raw = engine_and_raw
  60. expected = Series([], dtype="float64")
  61. result = expected.rolling(10, step=step).apply(
  62. lambda x: x.mean(), engine=engine, raw=raw
  63. )
  64. tm.assert_series_equal(result, expected)
  65. # gh-8080
  66. s = Series([None, None, None])
  67. result = s.rolling(2, min_periods=0, step=step).apply(
  68. lambda x: len(x), engine=engine, raw=raw
  69. )
  70. expected = Series([1.0, 2.0, 2.0])[::step]
  71. tm.assert_series_equal(result, expected)
  72. result = s.rolling(2, min_periods=0, step=step).apply(len, engine=engine, raw=raw)
  73. tm.assert_series_equal(result, expected)
  74. def test_all_apply(engine_and_raw):
  75. engine, raw = engine_and_raw
  76. df = (
  77. DataFrame(
  78. {"A": date_range("20130101", periods=5, freq="s"), "B": range(5)}
  79. ).set_index("A")
  80. * 2
  81. )
  82. er = df.rolling(window=1)
  83. r = df.rolling(window="1s")
  84. result = r.apply(lambda x: 1, engine=engine, raw=raw)
  85. expected = er.apply(lambda x: 1, engine=engine, raw=raw)
  86. tm.assert_frame_equal(result, expected)
  87. def test_ragged_apply(engine_and_raw):
  88. engine, raw = engine_and_raw
  89. df = DataFrame({"B": range(5)})
  90. df.index = [
  91. Timestamp("20130101 09:00:00"),
  92. Timestamp("20130101 09:00:02"),
  93. Timestamp("20130101 09:00:03"),
  94. Timestamp("20130101 09:00:05"),
  95. Timestamp("20130101 09:00:06"),
  96. ]
  97. f = lambda x: 1
  98. result = df.rolling(window="1s", min_periods=1).apply(f, engine=engine, raw=raw)
  99. expected = df.copy()
  100. expected["B"] = 1.0
  101. tm.assert_frame_equal(result, expected)
  102. result = df.rolling(window="2s", min_periods=1).apply(f, engine=engine, raw=raw)
  103. expected = df.copy()
  104. expected["B"] = 1.0
  105. tm.assert_frame_equal(result, expected)
  106. result = df.rolling(window="5s", min_periods=1).apply(f, engine=engine, raw=raw)
  107. expected = df.copy()
  108. expected["B"] = 1.0
  109. tm.assert_frame_equal(result, expected)
  110. def test_invalid_engine():
  111. with pytest.raises(ValueError, match="engine must be either 'numba' or 'cython'"):
  112. Series(range(1)).rolling(1).apply(lambda x: x, engine="foo")
  113. def test_invalid_engine_kwargs_cython():
  114. with pytest.raises(ValueError, match="cython engine does not accept engine_kwargs"):
  115. Series(range(1)).rolling(1).apply(
  116. lambda x: x, engine="cython", engine_kwargs={"nopython": False}
  117. )
  118. def test_invalid_raw_numba():
  119. with pytest.raises(
  120. ValueError, match="raw must be `True` when using the numba engine"
  121. ):
  122. Series(range(1)).rolling(1).apply(lambda x: x, raw=False, engine="numba")
  123. @pytest.mark.parametrize("args_kwargs", [[None, {"par": 10}], [(10,), None]])
  124. def test_rolling_apply_args_kwargs(args_kwargs):
  125. # GH 33433
  126. def numpysum(x, par):
  127. return np.sum(x + par)
  128. df = DataFrame({"gr": [1, 1], "a": [1, 2]})
  129. idx = Index(["gr", "a"])
  130. expected = DataFrame([[11.0, 11.0], [11.0, 12.0]], columns=idx)
  131. result = df.rolling(1).apply(numpysum, args=args_kwargs[0], kwargs=args_kwargs[1])
  132. tm.assert_frame_equal(result, expected)
  133. midx = MultiIndex.from_tuples([(1, 0), (1, 1)], names=["gr", None])
  134. expected = Series([11.0, 12.0], index=midx, name="a")
  135. gb_rolling = df.groupby("gr")["a"].rolling(1)
  136. result = gb_rolling.apply(numpysum, args=args_kwargs[0], kwargs=args_kwargs[1])
  137. tm.assert_series_equal(result, expected)
  138. def test_nans(raw):
  139. obj = Series(np.random.randn(50))
  140. obj[:10] = np.NaN
  141. obj[-10:] = np.NaN
  142. result = obj.rolling(50, min_periods=30).apply(f, raw=raw)
  143. tm.assert_almost_equal(result.iloc[-1], np.mean(obj[10:-10]))
  144. # min_periods is working correctly
  145. result = obj.rolling(20, min_periods=15).apply(f, raw=raw)
  146. assert isna(result.iloc[23])
  147. assert not isna(result.iloc[24])
  148. assert not isna(result.iloc[-6])
  149. assert isna(result.iloc[-5])
  150. obj2 = Series(np.random.randn(20))
  151. result = obj2.rolling(10, min_periods=5).apply(f, raw=raw)
  152. assert isna(result.iloc[3])
  153. assert notna(result.iloc[4])
  154. result0 = obj.rolling(20, min_periods=0).apply(f, raw=raw)
  155. result1 = obj.rolling(20, min_periods=1).apply(f, raw=raw)
  156. tm.assert_almost_equal(result0, result1)
  157. def test_center(raw):
  158. obj = Series(np.random.randn(50))
  159. obj[:10] = np.NaN
  160. obj[-10:] = np.NaN
  161. result = obj.rolling(20, min_periods=15, center=True).apply(f, raw=raw)
  162. expected = (
  163. concat([obj, Series([np.NaN] * 9)])
  164. .rolling(20, min_periods=15)
  165. .apply(f, raw=raw)
  166. .iloc[9:]
  167. .reset_index(drop=True)
  168. )
  169. tm.assert_series_equal(result, expected)
  170. def test_series(raw, series):
  171. result = series.rolling(50).apply(f, raw=raw)
  172. assert isinstance(result, Series)
  173. tm.assert_almost_equal(result.iloc[-1], np.mean(series[-50:]))
  174. def test_frame(raw, frame):
  175. result = frame.rolling(50).apply(f, raw=raw)
  176. assert isinstance(result, DataFrame)
  177. tm.assert_series_equal(
  178. result.iloc[-1, :],
  179. frame.iloc[-50:, :].apply(np.mean, axis=0, raw=raw),
  180. check_names=False,
  181. )
  182. def test_time_rule_series(raw, series):
  183. win = 25
  184. minp = 10
  185. ser = series[::2].resample("B").mean()
  186. series_result = ser.rolling(window=win, min_periods=minp).apply(f, raw=raw)
  187. last_date = series_result.index[-1]
  188. prev_date = last_date - 24 * offsets.BDay()
  189. trunc_series = series[::2].truncate(prev_date, last_date)
  190. tm.assert_almost_equal(series_result[-1], np.mean(trunc_series))
  191. def test_time_rule_frame(raw, frame):
  192. win = 25
  193. minp = 10
  194. frm = frame[::2].resample("B").mean()
  195. frame_result = frm.rolling(window=win, min_periods=minp).apply(f, raw=raw)
  196. last_date = frame_result.index[-1]
  197. prev_date = last_date - 24 * offsets.BDay()
  198. trunc_frame = frame[::2].truncate(prev_date, last_date)
  199. tm.assert_series_equal(
  200. frame_result.xs(last_date),
  201. trunc_frame.apply(np.mean, raw=raw),
  202. check_names=False,
  203. )
  204. @pytest.mark.parametrize("minp", [0, 99, 100])
  205. def test_min_periods(raw, series, minp, step):
  206. result = series.rolling(len(series) + 1, min_periods=minp, step=step).apply(
  207. f, raw=raw
  208. )
  209. expected = series.rolling(len(series), min_periods=minp, step=step).apply(
  210. f, raw=raw
  211. )
  212. nan_mask = isna(result)
  213. tm.assert_series_equal(nan_mask, isna(expected))
  214. nan_mask = ~nan_mask
  215. tm.assert_almost_equal(result[nan_mask], expected[nan_mask])
  216. def test_center_reindex_series(raw, series):
  217. # shifter index
  218. s = [f"x{x:d}" for x in range(12)]
  219. minp = 10
  220. series_xp = (
  221. series.reindex(list(series.index) + s)
  222. .rolling(window=25, min_periods=minp)
  223. .apply(f, raw=raw)
  224. .shift(-12)
  225. .reindex(series.index)
  226. )
  227. series_rs = series.rolling(window=25, min_periods=minp, center=True).apply(
  228. f, raw=raw
  229. )
  230. tm.assert_series_equal(series_xp, series_rs)
  231. def test_center_reindex_frame(raw, frame):
  232. # shifter index
  233. s = [f"x{x:d}" for x in range(12)]
  234. minp = 10
  235. frame_xp = (
  236. frame.reindex(list(frame.index) + s)
  237. .rolling(window=25, min_periods=minp)
  238. .apply(f, raw=raw)
  239. .shift(-12)
  240. .reindex(frame.index)
  241. )
  242. frame_rs = frame.rolling(window=25, min_periods=minp, center=True).apply(f, raw=raw)
  243. tm.assert_frame_equal(frame_xp, frame_rs)
  244. def test_axis1(raw):
  245. # GH 45912
  246. df = DataFrame([1, 2])
  247. result = df.rolling(window=1, axis=1).apply(np.sum, raw=raw)
  248. expected = DataFrame([1.0, 2.0])
  249. tm.assert_frame_equal(result, expected)