# test_ewm.py
  1. import numpy as np
  2. import pytest
  3. from pandas import (
  4. DataFrame,
  5. DatetimeIndex,
  6. Series,
  7. date_range,
  8. )
  9. import pandas._testing as tm
  10. def test_doc_string():
  11. df = DataFrame({"B": [0, 1, 2, np.nan, 4]})
  12. df
  13. df.ewm(com=0.5).mean()
  14. def test_constructor(frame_or_series):
  15. c = frame_or_series(range(5)).ewm
  16. # valid
  17. c(com=0.5)
  18. c(span=1.5)
  19. c(alpha=0.5)
  20. c(halflife=0.75)
  21. c(com=0.5, span=None)
  22. c(alpha=0.5, com=None)
  23. c(halflife=0.75, alpha=None)
  24. # not valid: mutually exclusive
  25. msg = "comass, span, halflife, and alpha are mutually exclusive"
  26. with pytest.raises(ValueError, match=msg):
  27. c(com=0.5, alpha=0.5)
  28. with pytest.raises(ValueError, match=msg):
  29. c(span=1.5, halflife=0.75)
  30. with pytest.raises(ValueError, match=msg):
  31. c(alpha=0.5, span=1.5)
  32. # not valid: com < 0
  33. msg = "comass must satisfy: comass >= 0"
  34. with pytest.raises(ValueError, match=msg):
  35. c(com=-0.5)
  36. # not valid: span < 1
  37. msg = "span must satisfy: span >= 1"
  38. with pytest.raises(ValueError, match=msg):
  39. c(span=0.5)
  40. # not valid: halflife <= 0
  41. msg = "halflife must satisfy: halflife > 0"
  42. with pytest.raises(ValueError, match=msg):
  43. c(halflife=0)
  44. # not valid: alpha <= 0 or alpha > 1
  45. msg = "alpha must satisfy: 0 < alpha <= 1"
  46. for alpha in (-0.5, 1.5):
  47. with pytest.raises(ValueError, match=msg):
  48. c(alpha=alpha)
  49. def test_ewma_times_not_datetime_type():
  50. msg = r"times must be datetime64\[ns\] dtype."
  51. with pytest.raises(ValueError, match=msg):
  52. Series(range(5)).ewm(times=np.arange(5))
  53. def test_ewma_times_not_same_length():
  54. msg = "times must be the same length as the object."
  55. with pytest.raises(ValueError, match=msg):
  56. Series(range(5)).ewm(times=np.arange(4).astype("datetime64[ns]"))
  57. def test_ewma_halflife_not_correct_type():
  58. msg = "halflife must be a timedelta convertible object"
  59. with pytest.raises(ValueError, match=msg):
  60. Series(range(5)).ewm(halflife=1, times=np.arange(5).astype("datetime64[ns]"))
  61. def test_ewma_halflife_without_times(halflife_with_times):
  62. msg = "halflife can only be a timedelta convertible argument if times is not None."
  63. with pytest.raises(ValueError, match=msg):
  64. Series(range(5)).ewm(halflife=halflife_with_times)
  65. @pytest.mark.parametrize(
  66. "times",
  67. [
  68. np.arange(10).astype("datetime64[D]").astype("datetime64[ns]"),
  69. date_range("2000", freq="D", periods=10),
  70. date_range("2000", freq="D", periods=10).tz_localize("UTC"),
  71. ],
  72. )
  73. @pytest.mark.parametrize("min_periods", [0, 2])
  74. def test_ewma_with_times_equal_spacing(halflife_with_times, times, min_periods):
  75. halflife = halflife_with_times
  76. data = np.arange(10.0)
  77. data[::2] = np.nan
  78. df = DataFrame({"A": data})
  79. result = df.ewm(halflife=halflife, min_periods=min_periods, times=times).mean()
  80. expected = df.ewm(halflife=1.0, min_periods=min_periods).mean()
  81. tm.assert_frame_equal(result, expected)
  82. def test_ewma_with_times_variable_spacing(tz_aware_fixture):
  83. tz = tz_aware_fixture
  84. halflife = "23 days"
  85. times = DatetimeIndex(
  86. ["2020-01-01", "2020-01-10T00:04:05", "2020-02-23T05:00:23"]
  87. ).tz_localize(tz)
  88. data = np.arange(3)
  89. df = DataFrame(data)
  90. result = df.ewm(halflife=halflife, times=times).mean()
  91. expected = DataFrame([0.0, 0.5674161888241773, 1.545239952073459])
  92. tm.assert_frame_equal(result, expected)
  93. def test_ewm_with_nat_raises(halflife_with_times):
  94. # GH#38535
  95. ser = Series(range(1))
  96. times = DatetimeIndex(["NaT"])
  97. with pytest.raises(ValueError, match="Cannot convert NaT values to integer"):
  98. ser.ewm(com=0.1, halflife=halflife_with_times, times=times)
  99. def test_ewm_with_times_getitem(halflife_with_times):
  100. # GH 40164
  101. halflife = halflife_with_times
  102. data = np.arange(10.0)
  103. data[::2] = np.nan
  104. times = date_range("2000", freq="D", periods=10)
  105. df = DataFrame({"A": data, "B": data})
  106. result = df.ewm(halflife=halflife, times=times)["A"].mean()
  107. expected = df.ewm(halflife=1.0)["A"].mean()
  108. tm.assert_series_equal(result, expected)
  109. @pytest.mark.parametrize("arg", ["com", "halflife", "span", "alpha"])
  110. def test_ewm_getitem_attributes_retained(arg, adjust, ignore_na):
  111. # GH 40164
  112. kwargs = {arg: 1, "adjust": adjust, "ignore_na": ignore_na}
  113. ewm = DataFrame({"A": range(1), "B": range(1)}).ewm(**kwargs)
  114. expected = {attr: getattr(ewm, attr) for attr in ewm._attributes}
  115. ewm_slice = ewm["A"]
  116. result = {attr: getattr(ewm, attr) for attr in ewm_slice._attributes}
  117. assert result == expected
  118. def test_ewma_times_adjust_false_raises():
  119. # GH 40098
  120. with pytest.raises(
  121. NotImplementedError, match="times is not supported with adjust=False."
  122. ):
  123. Series(range(1)).ewm(
  124. 0.1, adjust=False, times=date_range("2000", freq="D", periods=1)
  125. )
  126. @pytest.mark.parametrize(
  127. "func, expected",
  128. [
  129. [
  130. "mean",
  131. DataFrame(
  132. {
  133. 0: range(5),
  134. 1: range(4, 9),
  135. 2: [7.428571, 9, 10.571429, 12.142857, 13.714286],
  136. },
  137. dtype=float,
  138. ),
  139. ],
  140. [
  141. "std",
  142. DataFrame(
  143. {
  144. 0: [np.nan] * 5,
  145. 1: [4.242641] * 5,
  146. 2: [4.6291, 5.196152, 5.781745, 6.380775, 6.989788],
  147. }
  148. ),
  149. ],
  150. [
  151. "var",
  152. DataFrame(
  153. {
  154. 0: [np.nan] * 5,
  155. 1: [18.0] * 5,
  156. 2: [21.428571, 27, 33.428571, 40.714286, 48.857143],
  157. }
  158. ),
  159. ],
  160. ],
  161. )
  162. def test_float_dtype_ewma(func, expected, float_numpy_dtype):
  163. # GH#42452
  164. df = DataFrame(
  165. {0: range(5), 1: range(6, 11), 2: range(10, 20, 2)}, dtype=float_numpy_dtype
  166. )
  167. e = df.ewm(alpha=0.5, axis=1)
  168. result = getattr(e, func)()
  169. tm.assert_frame_equal(result, expected)
  170. def test_times_string_col_raises():
  171. # GH 43265
  172. df = DataFrame(
  173. {"A": np.arange(10.0), "time_col": date_range("2000", freq="D", periods=10)}
  174. )
  175. with pytest.raises(ValueError, match="times must be datetime64"):
  176. df.ewm(halflife="1 day", min_periods=0, times="time_col")
  177. def test_ewm_sum_adjust_false_notimplemented():
  178. data = Series(range(1)).ewm(com=1, adjust=False)
  179. with pytest.raises(NotImplementedError, match="sum is not"):
  180. data.sum()
  181. @pytest.mark.parametrize(
  182. "expected_data, ignore",
  183. [[[10.0, 5.0, 2.5, 11.25], False], [[10.0, 5.0, 5.0, 12.5], True]],
  184. )
  185. def test_ewm_sum(expected_data, ignore):
  186. # xref from Numbagg tests
  187. # https://github.com/numbagg/numbagg/blob/v0.2.1/numbagg/test/test_moving.py#L50
  188. data = Series([10, 0, np.nan, 10])
  189. result = data.ewm(alpha=0.5, ignore_na=ignore).sum()
  190. expected = Series(expected_data)
  191. tm.assert_series_equal(result, expected)
  192. def test_ewma_adjust():
  193. vals = Series(np.zeros(1000))
  194. vals[5] = 1
  195. result = vals.ewm(span=100, adjust=False).mean().sum()
  196. assert np.abs(result - 1) < 1e-2
  197. def test_ewma_cases(adjust, ignore_na):
  198. # try adjust/ignore_na args matrix
  199. s = Series([1.0, 2.0, 4.0, 8.0])
  200. if adjust:
  201. expected = Series([1.0, 1.6, 2.736842, 4.923077])
  202. else:
  203. expected = Series([1.0, 1.333333, 2.222222, 4.148148])
  204. result = s.ewm(com=2.0, adjust=adjust, ignore_na=ignore_na).mean()
  205. tm.assert_series_equal(result, expected)
  206. def test_ewma_nan_handling():
  207. s = Series([1.0] + [np.nan] * 5 + [1.0])
  208. result = s.ewm(com=5).mean()
  209. tm.assert_series_equal(result, Series([1.0] * len(s)))
  210. s = Series([np.nan] * 2 + [1.0] + [np.nan] * 2 + [1.0])
  211. result = s.ewm(com=5).mean()
  212. tm.assert_series_equal(result, Series([np.nan] * 2 + [1.0] * 4))
  213. @pytest.mark.parametrize(
  214. "s, adjust, ignore_na, w",
  215. [
  216. (
  217. Series([np.nan, 1.0, 101.0]),
  218. True,
  219. False,
  220. [np.nan, (1.0 - (1.0 / (1.0 + 2.0))), 1.0],
  221. ),
  222. (
  223. Series([np.nan, 1.0, 101.0]),
  224. True,
  225. True,
  226. [np.nan, (1.0 - (1.0 / (1.0 + 2.0))), 1.0],
  227. ),
  228. (
  229. Series([np.nan, 1.0, 101.0]),
  230. False,
  231. False,
  232. [np.nan, (1.0 - (1.0 / (1.0 + 2.0))), (1.0 / (1.0 + 2.0))],
  233. ),
  234. (
  235. Series([np.nan, 1.0, 101.0]),
  236. False,
  237. True,
  238. [np.nan, (1.0 - (1.0 / (1.0 + 2.0))), (1.0 / (1.0 + 2.0))],
  239. ),
  240. (
  241. Series([1.0, np.nan, 101.0]),
  242. True,
  243. False,
  244. [(1.0 - (1.0 / (1.0 + 2.0))) ** 2, np.nan, 1.0],
  245. ),
  246. (
  247. Series([1.0, np.nan, 101.0]),
  248. True,
  249. True,
  250. [(1.0 - (1.0 / (1.0 + 2.0))), np.nan, 1.0],
  251. ),
  252. (
  253. Series([1.0, np.nan, 101.0]),
  254. False,
  255. False,
  256. [(1.0 - (1.0 / (1.0 + 2.0))) ** 2, np.nan, (1.0 / (1.0 + 2.0))],
  257. ),
  258. (
  259. Series([1.0, np.nan, 101.0]),
  260. False,
  261. True,
  262. [(1.0 - (1.0 / (1.0 + 2.0))), np.nan, (1.0 / (1.0 + 2.0))],
  263. ),
  264. (
  265. Series([np.nan, 1.0, np.nan, np.nan, 101.0, np.nan]),
  266. True,
  267. False,
  268. [np.nan, (1.0 - (1.0 / (1.0 + 2.0))) ** 3, np.nan, np.nan, 1.0, np.nan],
  269. ),
  270. (
  271. Series([np.nan, 1.0, np.nan, np.nan, 101.0, np.nan]),
  272. True,
  273. True,
  274. [np.nan, (1.0 - (1.0 / (1.0 + 2.0))), np.nan, np.nan, 1.0, np.nan],
  275. ),
  276. (
  277. Series([np.nan, 1.0, np.nan, np.nan, 101.0, np.nan]),
  278. False,
  279. False,
  280. [
  281. np.nan,
  282. (1.0 - (1.0 / (1.0 + 2.0))) ** 3,
  283. np.nan,
  284. np.nan,
  285. (1.0 / (1.0 + 2.0)),
  286. np.nan,
  287. ],
  288. ),
  289. (
  290. Series([np.nan, 1.0, np.nan, np.nan, 101.0, np.nan]),
  291. False,
  292. True,
  293. [
  294. np.nan,
  295. (1.0 - (1.0 / (1.0 + 2.0))),
  296. np.nan,
  297. np.nan,
  298. (1.0 / (1.0 + 2.0)),
  299. np.nan,
  300. ],
  301. ),
  302. (
  303. Series([1.0, np.nan, 101.0, 50.0]),
  304. True,
  305. False,
  306. [
  307. (1.0 - (1.0 / (1.0 + 2.0))) ** 3,
  308. np.nan,
  309. (1.0 - (1.0 / (1.0 + 2.0))),
  310. 1.0,
  311. ],
  312. ),
  313. (
  314. Series([1.0, np.nan, 101.0, 50.0]),
  315. True,
  316. True,
  317. [
  318. (1.0 - (1.0 / (1.0 + 2.0))) ** 2,
  319. np.nan,
  320. (1.0 - (1.0 / (1.0 + 2.0))),
  321. 1.0,
  322. ],
  323. ),
  324. (
  325. Series([1.0, np.nan, 101.0, 50.0]),
  326. False,
  327. False,
  328. [
  329. (1.0 - (1.0 / (1.0 + 2.0))) ** 3,
  330. np.nan,
  331. (1.0 - (1.0 / (1.0 + 2.0))) * (1.0 / (1.0 + 2.0)),
  332. (1.0 / (1.0 + 2.0))
  333. * ((1.0 - (1.0 / (1.0 + 2.0))) ** 2 + (1.0 / (1.0 + 2.0))),
  334. ],
  335. ),
  336. (
  337. Series([1.0, np.nan, 101.0, 50.0]),
  338. False,
  339. True,
  340. [
  341. (1.0 - (1.0 / (1.0 + 2.0))) ** 2,
  342. np.nan,
  343. (1.0 - (1.0 / (1.0 + 2.0))) * (1.0 / (1.0 + 2.0)),
  344. (1.0 / (1.0 + 2.0)),
  345. ],
  346. ),
  347. ],
  348. )
  349. def test_ewma_nan_handling_cases(s, adjust, ignore_na, w):
  350. # GH 7603
  351. expected = (s.multiply(w).cumsum() / Series(w).cumsum()).fillna(method="ffill")
  352. result = s.ewm(com=2.0, adjust=adjust, ignore_na=ignore_na).mean()
  353. tm.assert_series_equal(result, expected)
  354. if ignore_na is False:
  355. # check that ignore_na defaults to False
  356. result = s.ewm(com=2.0, adjust=adjust).mean()
  357. tm.assert_series_equal(result, expected)
  358. def test_ewm_alpha():
  359. # GH 10789
  360. arr = np.random.randn(100)
  361. locs = np.arange(20, 40)
  362. arr[locs] = np.NaN
  363. s = Series(arr)
  364. a = s.ewm(alpha=0.61722699889169674).mean()
  365. b = s.ewm(com=0.62014947789973052).mean()
  366. c = s.ewm(span=2.240298955799461).mean()
  367. d = s.ewm(halflife=0.721792864318).mean()
  368. tm.assert_series_equal(a, b)
  369. tm.assert_series_equal(a, c)
  370. tm.assert_series_equal(a, d)
  371. def test_ewm_domain_checks():
  372. # GH 12492
  373. arr = np.random.randn(100)
  374. locs = np.arange(20, 40)
  375. arr[locs] = np.NaN
  376. s = Series(arr)
  377. msg = "comass must satisfy: comass >= 0"
  378. with pytest.raises(ValueError, match=msg):
  379. s.ewm(com=-0.1)
  380. s.ewm(com=0.0)
  381. s.ewm(com=0.1)
  382. msg = "span must satisfy: span >= 1"
  383. with pytest.raises(ValueError, match=msg):
  384. s.ewm(span=-0.1)
  385. with pytest.raises(ValueError, match=msg):
  386. s.ewm(span=0.0)
  387. with pytest.raises(ValueError, match=msg):
  388. s.ewm(span=0.9)
  389. s.ewm(span=1.0)
  390. s.ewm(span=1.1)
  391. msg = "halflife must satisfy: halflife > 0"
  392. with pytest.raises(ValueError, match=msg):
  393. s.ewm(halflife=-0.1)
  394. with pytest.raises(ValueError, match=msg):
  395. s.ewm(halflife=0.0)
  396. s.ewm(halflife=0.1)
  397. msg = "alpha must satisfy: 0 < alpha <= 1"
  398. with pytest.raises(ValueError, match=msg):
  399. s.ewm(alpha=-0.1)
  400. with pytest.raises(ValueError, match=msg):
  401. s.ewm(alpha=0.0)
  402. s.ewm(alpha=0.1)
  403. s.ewm(alpha=1.0)
  404. with pytest.raises(ValueError, match=msg):
  405. s.ewm(alpha=1.1)
  406. @pytest.mark.parametrize("method", ["mean", "std", "var"])
  407. def test_ew_empty_series(method):
  408. vals = Series([], dtype=np.float64)
  409. ewm = vals.ewm(3)
  410. result = getattr(ewm, method)()
  411. tm.assert_almost_equal(result, vals)
  412. @pytest.mark.parametrize("min_periods", [0, 1])
  413. @pytest.mark.parametrize("name", ["mean", "var", "std"])
  414. def test_ew_min_periods(min_periods, name):
  415. # excluding NaNs correctly
  416. arr = np.random.randn(50)
  417. arr[:10] = np.NaN
  418. arr[-10:] = np.NaN
  419. s = Series(arr)
  420. # check min_periods
  421. # GH 7898
  422. result = getattr(s.ewm(com=50, min_periods=2), name)()
  423. assert result[:11].isna().all()
  424. assert not result[11:].isna().any()
  425. result = getattr(s.ewm(com=50, min_periods=min_periods), name)()
  426. if name == "mean":
  427. assert result[:10].isna().all()
  428. assert not result[10:].isna().any()
  429. else:
  430. # ewm.std, ewm.var (with bias=False) require at least
  431. # two values
  432. assert result[:11].isna().all()
  433. assert not result[11:].isna().any()
  434. # check series of length 0
  435. result = getattr(Series(dtype=object).ewm(com=50, min_periods=min_periods), name)()
  436. tm.assert_series_equal(result, Series(dtype="float64"))
  437. # check series of length 1
  438. result = getattr(Series([1.0]).ewm(50, min_periods=min_periods), name)()
  439. if name == "mean":
  440. tm.assert_series_equal(result, Series([1.0]))
  441. else:
  442. # ewm.std, ewm.var with bias=False require at least
  443. # two values
  444. tm.assert_series_equal(result, Series([np.NaN]))
  445. # pass in ints
  446. result2 = getattr(Series(np.arange(50)).ewm(span=10), name)()
  447. assert result2.dtype == np.float_
  448. @pytest.mark.parametrize("name", ["cov", "corr"])
  449. def test_ewm_corr_cov(name):
  450. A = Series(np.random.randn(50), index=range(50))
  451. B = A[2:] + np.random.randn(48)
  452. A[:10] = np.NaN
  453. B.iloc[-10:] = np.NaN
  454. result = getattr(A.ewm(com=20, min_periods=5), name)(B)
  455. assert np.isnan(result.values[:14]).all()
  456. assert not np.isnan(result.values[14:]).any()
  457. @pytest.mark.parametrize("min_periods", [0, 1, 2])
  458. @pytest.mark.parametrize("name", ["cov", "corr"])
  459. def test_ewm_corr_cov_min_periods(name, min_periods):
  460. # GH 7898
  461. A = Series(np.random.randn(50), index=range(50))
  462. B = A[2:] + np.random.randn(48)
  463. A[:10] = np.NaN
  464. B.iloc[-10:] = np.NaN
  465. result = getattr(A.ewm(com=20, min_periods=min_periods), name)(B)
  466. # binary functions (ewmcov, ewmcorr) with bias=False require at
  467. # least two values
  468. assert np.isnan(result.values[:11]).all()
  469. assert not np.isnan(result.values[11:]).any()
  470. # check series of length 0
  471. empty = Series([], dtype=np.float64)
  472. result = getattr(empty.ewm(com=50, min_periods=min_periods), name)(empty)
  473. tm.assert_series_equal(result, empty)
  474. # check series of length 1
  475. result = getattr(Series([1.0]).ewm(com=50, min_periods=min_periods), name)(
  476. Series([1.0])
  477. )
  478. tm.assert_series_equal(result, Series([np.NaN]))
  479. @pytest.mark.parametrize("name", ["cov", "corr"])
  480. def test_different_input_array_raise_exception(name):
  481. A = Series(np.random.randn(50), index=range(50))
  482. A[:10] = np.NaN
  483. msg = "other must be a DataFrame or Series"
  484. # exception raised is Exception
  485. with pytest.raises(ValueError, match=msg):
  486. getattr(A.ewm(com=20, min_periods=5), name)(np.random.randn(50))
  487. @pytest.mark.parametrize("name", ["var", "std", "mean"])
  488. def test_ewma_series(series, name):
  489. series_result = getattr(series.ewm(com=10), name)()
  490. assert isinstance(series_result, Series)
  491. @pytest.mark.parametrize("name", ["var", "std", "mean"])
  492. def test_ewma_frame(frame, name):
  493. frame_result = getattr(frame.ewm(com=10), name)()
  494. assert isinstance(frame_result, DataFrame)
  495. def test_ewma_span_com_args(series):
  496. A = series.ewm(com=9.5).mean()
  497. B = series.ewm(span=20).mean()
  498. tm.assert_almost_equal(A, B)
  499. msg = "comass, span, halflife, and alpha are mutually exclusive"
  500. with pytest.raises(ValueError, match=msg):
  501. series.ewm(com=9.5, span=20)
  502. msg = "Must pass one of comass, span, halflife, or alpha"
  503. with pytest.raises(ValueError, match=msg):
  504. series.ewm().mean()
  505. def test_ewma_halflife_arg(series):
  506. A = series.ewm(com=13.932726172912965).mean()
  507. B = series.ewm(halflife=10.0).mean()
  508. tm.assert_almost_equal(A, B)
  509. msg = "comass, span, halflife, and alpha are mutually exclusive"
  510. with pytest.raises(ValueError, match=msg):
  511. series.ewm(span=20, halflife=50)
  512. with pytest.raises(ValueError, match=msg):
  513. series.ewm(com=9.5, halflife=50)
  514. with pytest.raises(ValueError, match=msg):
  515. series.ewm(com=9.5, span=20, halflife=50)
  516. msg = "Must pass one of comass, span, halflife, or alpha"
  517. with pytest.raises(ValueError, match=msg):
  518. series.ewm()
  519. def test_ewm_alpha_arg(series):
  520. # GH 10789
  521. s = series
  522. msg = "Must pass one of comass, span, halflife, or alpha"
  523. with pytest.raises(ValueError, match=msg):
  524. s.ewm()
  525. msg = "comass, span, halflife, and alpha are mutually exclusive"
  526. with pytest.raises(ValueError, match=msg):
  527. s.ewm(com=10.0, alpha=0.5)
  528. with pytest.raises(ValueError, match=msg):
  529. s.ewm(span=10.0, alpha=0.5)
  530. with pytest.raises(ValueError, match=msg):
  531. s.ewm(halflife=10.0, alpha=0.5)
  532. @pytest.mark.parametrize("func", ["cov", "corr"])
  533. def test_ewm_pairwise_cov_corr(func, frame):
  534. result = getattr(frame.ewm(span=10, min_periods=5), func)()
  535. result = result.loc[(slice(None), 1), 5]
  536. result.index = result.index.droplevel(1)
  537. expected = getattr(frame[1].ewm(span=10, min_periods=5), func)(frame[5])
  538. tm.assert_series_equal(result, expected, check_names=False)
  539. def test_numeric_only_frame(arithmetic_win_operators, numeric_only):
  540. # GH#46560
  541. kernel = arithmetic_win_operators
  542. df = DataFrame({"a": [1], "b": 2, "c": 3})
  543. df["c"] = df["c"].astype(object)
  544. ewm = df.ewm(span=2, min_periods=1)
  545. op = getattr(ewm, kernel, None)
  546. if op is not None:
  547. result = op(numeric_only=numeric_only)
  548. columns = ["a", "b"] if numeric_only else ["a", "b", "c"]
  549. expected = df[columns].agg([kernel]).reset_index(drop=True).astype(float)
  550. assert list(expected.columns) == columns
  551. tm.assert_frame_equal(result, expected)
  552. @pytest.mark.parametrize("kernel", ["corr", "cov"])
  553. @pytest.mark.parametrize("use_arg", [True, False])
  554. def test_numeric_only_corr_cov_frame(kernel, numeric_only, use_arg):
  555. # GH#46560
  556. df = DataFrame({"a": [1, 2, 3], "b": 2, "c": 3})
  557. df["c"] = df["c"].astype(object)
  558. arg = (df,) if use_arg else ()
  559. ewm = df.ewm(span=2, min_periods=1)
  560. op = getattr(ewm, kernel)
  561. result = op(*arg, numeric_only=numeric_only)
  562. # Compare result to op using float dtypes, dropping c when numeric_only is True
  563. columns = ["a", "b"] if numeric_only else ["a", "b", "c"]
  564. df2 = df[columns].astype(float)
  565. arg2 = (df2,) if use_arg else ()
  566. ewm2 = df2.ewm(span=2, min_periods=1)
  567. op2 = getattr(ewm2, kernel)
  568. expected = op2(*arg2, numeric_only=numeric_only)
  569. tm.assert_frame_equal(result, expected)
  570. @pytest.mark.parametrize("dtype", [int, object])
  571. def test_numeric_only_series(arithmetic_win_operators, numeric_only, dtype):
  572. # GH#46560
  573. kernel = arithmetic_win_operators
  574. ser = Series([1], dtype=dtype)
  575. ewm = ser.ewm(span=2, min_periods=1)
  576. op = getattr(ewm, kernel, None)
  577. if op is None:
  578. # Nothing to test
  579. return
  580. if numeric_only and dtype is object:
  581. msg = f"ExponentialMovingWindow.{kernel} does not implement numeric_only"
  582. with pytest.raises(NotImplementedError, match=msg):
  583. op(numeric_only=numeric_only)
  584. else:
  585. result = op(numeric_only=numeric_only)
  586. expected = ser.agg([kernel]).reset_index(drop=True).astype(float)
  587. tm.assert_series_equal(result, expected)
  588. @pytest.mark.parametrize("kernel", ["corr", "cov"])
  589. @pytest.mark.parametrize("use_arg", [True, False])
  590. @pytest.mark.parametrize("dtype", [int, object])
  591. def test_numeric_only_corr_cov_series(kernel, use_arg, numeric_only, dtype):
  592. # GH#46560
  593. ser = Series([1, 2, 3], dtype=dtype)
  594. arg = (ser,) if use_arg else ()
  595. ewm = ser.ewm(span=2, min_periods=1)
  596. op = getattr(ewm, kernel)
  597. if numeric_only and dtype is object:
  598. msg = f"ExponentialMovingWindow.{kernel} does not implement numeric_only"
  599. with pytest.raises(NotImplementedError, match=msg):
  600. op(*arg, numeric_only=numeric_only)
  601. else:
  602. result = op(*arg, numeric_only=numeric_only)
  603. ser2 = ser.astype(float)
  604. arg2 = (ser2,) if use_arg else ()
  605. ewm2 = ser2.ewm(span=2, min_periods=1)
  606. op2 = getattr(ewm2, kernel)
  607. expected = op2(*arg2, numeric_only=numeric_only)
  608. tm.assert_series_equal(result, expected)