ewm.py 32 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798991001011021031041051061071081091101111121131141151161171181191201211221231241251261271281291301311321331341351361371381391401411421431441451461471481491501511521531541551561571581591601611621631641651661671681691701711721731741751761771781791801811821831841851861871881891901911921931941951961971981992002012022032042052062072082092102112122132142152162172182192202212222232242252262272282292302312322332342352362372382392402412422432442452462472482492502512522532542552562572582592602612622632642652662672682692702712722732742752762772782792802812822832842852862872882892902912922932942952962972982993003013023033043053063073083093103113123133143153163173183193203213223233243253263273283293303313323333343353363373383393403413423433443453463473483493503513523533543553563573583593603613623633643653663673683693703713723733743753763773783793803813823833843853863873883893903913923933943953963973983994004014024034044054064074084094104114124134144154164174184194204214224234244254264274284294304314324334344354364374384394404414424434444454464474484494504514524534544554564574584594604614624634644654664674684694704714724734744754764774784794804814824834844854864874884894904914924934944954964974984995005015025035045055065075085095105115125135145155165175185195205215225235245255265275285295305315325335345355365375385395405415425435445455465475485495505515525535545555565575585595605615625635645655665675685695705715725735745755765775785795805815825835845855865875885895905915925935945955965975985996006016026036046056066076086096106116126136146156166176186196206216226236246256266276286296306316326336346356366376386396406416426436446456466476486496506516526536546556566576586596606616626636646656666676686696706716726736746756766776786796806816826836846856866876886896906916926936946956966976986997007017027037047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012
  1. from __future__ import annotations
  2. import datetime
  3. from functools import partial
  4. from textwrap import dedent
  5. from typing import TYPE_CHECKING
  6. import numpy as np
  7. from pandas._libs.tslibs import Timedelta
  8. import pandas._libs.window.aggregations as window_aggregations
  9. from pandas._typing import (
  10. Axis,
  11. TimedeltaConvertibleTypes,
  12. )
  13. if TYPE_CHECKING:
  14. from pandas import DataFrame, Series
  15. from pandas.core.generic import NDFrame
  16. from pandas.util._decorators import doc
  17. from pandas.core.dtypes.common import (
  18. is_datetime64_ns_dtype,
  19. is_numeric_dtype,
  20. )
  21. from pandas.core.dtypes.missing import isna
  22. from pandas.core import common
  23. from pandas.core.indexers.objects import (
  24. BaseIndexer,
  25. ExponentialMovingWindowIndexer,
  26. GroupbyIndexer,
  27. )
  28. from pandas.core.util.numba_ import (
  29. get_jit_arguments,
  30. maybe_use_numba,
  31. )
  32. from pandas.core.window.common import zsqrt
  33. from pandas.core.window.doc import (
  34. _shared_docs,
  35. create_section_header,
  36. kwargs_numeric_only,
  37. numba_notes,
  38. template_header,
  39. template_returns,
  40. template_see_also,
  41. window_agg_numba_parameters,
  42. )
  43. from pandas.core.window.numba_ import (
  44. generate_numba_ewm_func,
  45. generate_numba_ewm_table_func,
  46. )
  47. from pandas.core.window.online import (
  48. EWMMeanState,
  49. generate_online_numba_ewma_func,
  50. )
  51. from pandas.core.window.rolling import (
  52. BaseWindow,
  53. BaseWindowGroupby,
  54. )
  55. def get_center_of_mass(
  56. comass: float | None,
  57. span: float | None,
  58. halflife: float | None,
  59. alpha: float | None,
  60. ) -> float:
  61. valid_count = common.count_not_none(comass, span, halflife, alpha)
  62. if valid_count > 1:
  63. raise ValueError("comass, span, halflife, and alpha are mutually exclusive")
  64. # Convert to center of mass; domain checks ensure 0 < alpha <= 1
  65. if comass is not None:
  66. if comass < 0:
  67. raise ValueError("comass must satisfy: comass >= 0")
  68. elif span is not None:
  69. if span < 1:
  70. raise ValueError("span must satisfy: span >= 1")
  71. comass = (span - 1) / 2
  72. elif halflife is not None:
  73. if halflife <= 0:
  74. raise ValueError("halflife must satisfy: halflife > 0")
  75. decay = 1 - np.exp(np.log(0.5) / halflife)
  76. comass = 1 / decay - 1
  77. elif alpha is not None:
  78. if alpha <= 0 or alpha > 1:
  79. raise ValueError("alpha must satisfy: 0 < alpha <= 1")
  80. comass = (1 - alpha) / alpha
  81. else:
  82. raise ValueError("Must pass one of comass, span, halflife, or alpha")
  83. return float(comass)
  84. def _calculate_deltas(
  85. times: np.ndarray | NDFrame,
  86. halflife: float | TimedeltaConvertibleTypes | None,
  87. ) -> np.ndarray:
  88. """
  89. Return the diff of the times divided by the half-life. These values are used in
  90. the calculation of the ewm mean.
  91. Parameters
  92. ----------
  93. times : np.ndarray, Series
  94. Times corresponding to the observations. Must be monotonically increasing
  95. and ``datetime64[ns]`` dtype.
  96. halflife : float, str, timedelta, optional
  97. Half-life specifying the decay
  98. Returns
  99. -------
  100. np.ndarray
  101. Diff of the times divided by the half-life
  102. """
  103. _times = np.asarray(times.view(np.int64), dtype=np.float64)
  104. # TODO: generalize to non-nano?
  105. _halflife = float(Timedelta(halflife).as_unit("ns")._value)
  106. return np.diff(_times) / _halflife
  107. class ExponentialMovingWindow(BaseWindow):
  108. r"""
  109. Provide exponentially weighted (EW) calculations.
  110. Exactly one of ``com``, ``span``, ``halflife``, or ``alpha`` must be
  111. provided if ``times`` is not provided. If ``times`` is provided,
  112. ``halflife`` and one of ``com``, ``span`` or ``alpha`` may be provided.
  113. Parameters
  114. ----------
  115. com : float, optional
  116. Specify decay in terms of center of mass
  117. :math:`\alpha = 1 / (1 + com)`, for :math:`com \geq 0`.
  118. span : float, optional
  119. Specify decay in terms of span
  120. :math:`\alpha = 2 / (span + 1)`, for :math:`span \geq 1`.
  121. halflife : float, str, timedelta, optional
  122. Specify decay in terms of half-life
  123. :math:`\alpha = 1 - \exp\left(-\ln(2) / halflife\right)`, for
  124. :math:`halflife > 0`.
  125. If ``times`` is specified, a timedelta convertible unit over which an
  126. observation decays to half its value. Only applicable to ``mean()``,
  127. and halflife value will not apply to the other functions.
  128. .. versionadded:: 1.1.0
  129. alpha : float, optional
  130. Specify smoothing factor :math:`\alpha` directly
  131. :math:`0 < \alpha \leq 1`.
  132. min_periods : int, default 0
  133. Minimum number of observations in window required to have a value;
  134. otherwise, result is ``np.nan``.
  135. adjust : bool, default True
  136. Divide by decaying adjustment factor in beginning periods to account
  137. for imbalance in relative weightings (viewing EWMA as a moving average).
  138. - When ``adjust=True`` (default), the EW function is calculated using weights
  139. :math:`w_i = (1 - \alpha)^i`. For example, the EW moving average of the series
  140. [:math:`x_0, x_1, ..., x_t`] would be:
  141. .. math::
  142. y_t = \frac{x_t + (1 - \alpha)x_{t-1} + (1 - \alpha)^2 x_{t-2} + ... + (1 -
  143. \alpha)^t x_0}{1 + (1 - \alpha) + (1 - \alpha)^2 + ... + (1 - \alpha)^t}
  144. - When ``adjust=False``, the exponentially weighted function is calculated
  145. recursively:
  146. .. math::
  147. \begin{split}
  148. y_0 &= x_0\\
  149. y_t &= (1 - \alpha) y_{t-1} + \alpha x_t,
  150. \end{split}
  151. ignore_na : bool, default False
  152. Ignore missing values when calculating weights.
  153. - When ``ignore_na=False`` (default), weights are based on absolute positions.
  154. For example, the weights of :math:`x_0` and :math:`x_2` used in calculating
  155. the final weighted average of [:math:`x_0`, None, :math:`x_2`] are
  156. :math:`(1-\alpha)^2` and :math:`1` if ``adjust=True``, and
  157. :math:`(1-\alpha)^2` and :math:`\alpha` if ``adjust=False``.
  158. - When ``ignore_na=True``, weights are based
  159. on relative positions. For example, the weights of :math:`x_0` and :math:`x_2`
  160. used in calculating the final weighted average of
  161. [:math:`x_0`, None, :math:`x_2`] are :math:`1-\alpha` and :math:`1` if
  162. ``adjust=True``, and :math:`1-\alpha` and :math:`\alpha` if ``adjust=False``.
  163. axis : {0, 1}, default 0
  164. If ``0`` or ``'index'``, calculate across the rows.
  165. If ``1`` or ``'columns'``, calculate across the columns.
  166. For `Series` this parameter is unused and defaults to 0.
  167. times : np.ndarray, Series, default None
  168. .. versionadded:: 1.1.0
  169. Only applicable to ``mean()``.
  170. Times corresponding to the observations. Must be monotonically increasing and
  171. ``datetime64[ns]`` dtype.
  172. If 1-D array like, a sequence with the same shape as the observations.
  173. method : str {'single', 'table'}, default 'single'
  174. .. versionadded:: 1.4.0
  175. Execute the rolling operation per single column or row (``'single'``)
  176. or over the entire object (``'table'``).
  177. This argument is only implemented when specifying ``engine='numba'``
  178. in the method call.
  179. Only applicable to ``mean()``
  180. Returns
  181. -------
  182. ``ExponentialMovingWindow`` subclass
  183. See Also
  184. --------
  185. rolling : Provides rolling window calculations.
  186. expanding : Provides expanding transformations.
  187. Notes
  188. -----
  189. See :ref:`Windowing Operations <window.exponentially_weighted>`
  190. for further usage details and examples.
  191. Examples
  192. --------
  193. >>> df = pd.DataFrame({'B': [0, 1, 2, np.nan, 4]})
  194. >>> df
  195. B
  196. 0 0.0
  197. 1 1.0
  198. 2 2.0
  199. 3 NaN
  200. 4 4.0
  201. >>> df.ewm(com=0.5).mean()
  202. B
  203. 0 0.000000
  204. 1 0.750000
  205. 2 1.615385
  206. 3 1.615385
  207. 4 3.670213
  208. >>> df.ewm(alpha=2 / 3).mean()
  209. B
  210. 0 0.000000
  211. 1 0.750000
  212. 2 1.615385
  213. 3 1.615385
  214. 4 3.670213
  215. **adjust**
  216. >>> df.ewm(com=0.5, adjust=True).mean()
  217. B
  218. 0 0.000000
  219. 1 0.750000
  220. 2 1.615385
  221. 3 1.615385
  222. 4 3.670213
  223. >>> df.ewm(com=0.5, adjust=False).mean()
  224. B
  225. 0 0.000000
  226. 1 0.666667
  227. 2 1.555556
  228. 3 1.555556
  229. 4 3.650794
  230. **ignore_na**
  231. >>> df.ewm(com=0.5, ignore_na=True).mean()
  232. B
  233. 0 0.000000
  234. 1 0.750000
  235. 2 1.615385
  236. 3 1.615385
  237. 4 3.225000
  238. >>> df.ewm(com=0.5, ignore_na=False).mean()
  239. B
  240. 0 0.000000
  241. 1 0.750000
  242. 2 1.615385
  243. 3 1.615385
  244. 4 3.670213
  245. **times**
  246. Exponentially weighted mean with weights calculated with a timedelta ``halflife``
  247. relative to ``times``.
  248. >>> times = ['2020-01-01', '2020-01-03', '2020-01-10', '2020-01-15', '2020-01-17']
  249. >>> df.ewm(halflife='4 days', times=pd.DatetimeIndex(times)).mean()
  250. B
  251. 0 0.000000
  252. 1 0.585786
  253. 2 1.523889
  254. 3 1.523889
  255. 4 3.233686
  256. """
  257. _attributes = [
  258. "com",
  259. "span",
  260. "halflife",
  261. "alpha",
  262. "min_periods",
  263. "adjust",
  264. "ignore_na",
  265. "axis",
  266. "times",
  267. "method",
  268. ]
  269. def __init__(
  270. self,
  271. obj: NDFrame,
  272. com: float | None = None,
  273. span: float | None = None,
  274. halflife: float | TimedeltaConvertibleTypes | None = None,
  275. alpha: float | None = None,
  276. min_periods: int | None = 0,
  277. adjust: bool = True,
  278. ignore_na: bool = False,
  279. axis: Axis = 0,
  280. times: np.ndarray | NDFrame | None = None,
  281. method: str = "single",
  282. *,
  283. selection=None,
  284. ) -> None:
  285. super().__init__(
  286. obj=obj,
  287. min_periods=1 if min_periods is None else max(int(min_periods), 1),
  288. on=None,
  289. center=False,
  290. closed=None,
  291. method=method,
  292. axis=axis,
  293. selection=selection,
  294. )
  295. self.com = com
  296. self.span = span
  297. self.halflife = halflife
  298. self.alpha = alpha
  299. self.adjust = adjust
  300. self.ignore_na = ignore_na
  301. self.times = times
  302. if self.times is not None:
  303. if not self.adjust:
  304. raise NotImplementedError("times is not supported with adjust=False.")
  305. if not is_datetime64_ns_dtype(self.times):
  306. raise ValueError("times must be datetime64[ns] dtype.")
  307. if len(self.times) != len(obj):
  308. raise ValueError("times must be the same length as the object.")
  309. if not isinstance(self.halflife, (str, datetime.timedelta, np.timedelta64)):
  310. raise ValueError("halflife must be a timedelta convertible object")
  311. if isna(self.times).any():
  312. raise ValueError("Cannot convert NaT values to integer")
  313. self._deltas = _calculate_deltas(self.times, self.halflife)
  314. # Halflife is no longer applicable when calculating COM
  315. # But allow COM to still be calculated if the user passes other decay args
  316. if common.count_not_none(self.com, self.span, self.alpha) > 0:
  317. self._com = get_center_of_mass(self.com, self.span, None, self.alpha)
  318. else:
  319. self._com = 1.0
  320. else:
  321. if self.halflife is not None and isinstance(
  322. self.halflife, (str, datetime.timedelta, np.timedelta64)
  323. ):
  324. raise ValueError(
  325. "halflife can only be a timedelta convertible argument if "
  326. "times is not None."
  327. )
  328. # Without times, points are equally spaced
  329. self._deltas = np.ones(
  330. max(self.obj.shape[self.axis] - 1, 0), dtype=np.float64
  331. )
  332. self._com = get_center_of_mass(
  333. # error: Argument 3 to "get_center_of_mass" has incompatible type
  334. # "Union[float, Any, None, timedelta64, signedinteger[_64Bit]]";
  335. # expected "Optional[float]"
  336. self.com,
  337. self.span,
  338. self.halflife, # type: ignore[arg-type]
  339. self.alpha,
  340. )
  341. def _check_window_bounds(
  342. self, start: np.ndarray, end: np.ndarray, num_vals: int
  343. ) -> None:
  344. # emw algorithms are iterative with each point
  345. # ExponentialMovingWindowIndexer "bounds" are the entire window
  346. pass
  347. def _get_window_indexer(self) -> BaseIndexer:
  348. """
  349. Return an indexer class that will compute the window start and end bounds
  350. """
  351. return ExponentialMovingWindowIndexer()
  352. def online(
  353. self, engine: str = "numba", engine_kwargs=None
  354. ) -> OnlineExponentialMovingWindow:
  355. """
  356. Return an ``OnlineExponentialMovingWindow`` object to calculate
  357. exponentially moving window aggregations in an online method.
  358. .. versionadded:: 1.3.0
  359. Parameters
  360. ----------
  361. engine: str, default ``'numba'``
  362. Execution engine to calculate online aggregations.
  363. Applies to all supported aggregation methods.
  364. engine_kwargs : dict, default None
  365. Applies to all supported aggregation methods.
  366. * For ``'numba'`` engine, the engine can accept ``nopython``, ``nogil``
  367. and ``parallel`` dictionary keys. The values must either be ``True`` or
  368. ``False``. The default ``engine_kwargs`` for the ``'numba'`` engine is
  369. ``{{'nopython': True, 'nogil': False, 'parallel': False}}`` and will be
  370. applied to the function
  371. Returns
  372. -------
  373. OnlineExponentialMovingWindow
  374. """
  375. return OnlineExponentialMovingWindow(
  376. obj=self.obj,
  377. com=self.com,
  378. span=self.span,
  379. halflife=self.halflife,
  380. alpha=self.alpha,
  381. min_periods=self.min_periods,
  382. adjust=self.adjust,
  383. ignore_na=self.ignore_na,
  384. axis=self.axis,
  385. times=self.times,
  386. engine=engine,
  387. engine_kwargs=engine_kwargs,
  388. selection=self._selection,
  389. )
  390. @doc(
  391. _shared_docs["aggregate"],
  392. see_also=dedent(
  393. """
  394. See Also
  395. --------
  396. pandas.DataFrame.rolling.aggregate
  397. """
  398. ),
  399. examples=dedent(
  400. """
  401. Examples
  402. --------
  403. >>> df = pd.DataFrame({"A": [1, 2, 3], "B": [4, 5, 6], "C": [7, 8, 9]})
  404. >>> df
  405. A B C
  406. 0 1 4 7
  407. 1 2 5 8
  408. 2 3 6 9
  409. >>> df.ewm(alpha=0.5).mean()
  410. A B C
  411. 0 1.000000 4.000000 7.000000
  412. 1 1.666667 4.666667 7.666667
  413. 2 2.428571 5.428571 8.428571
  414. """
  415. ),
  416. klass="Series/Dataframe",
  417. axis="",
  418. )
  419. def aggregate(self, func, *args, **kwargs):
  420. return super().aggregate(func, *args, **kwargs)
  421. agg = aggregate
  422. @doc(
  423. template_header,
  424. create_section_header("Parameters"),
  425. kwargs_numeric_only,
  426. window_agg_numba_parameters(),
  427. create_section_header("Returns"),
  428. template_returns,
  429. create_section_header("See Also"),
  430. template_see_also,
  431. create_section_header("Notes"),
  432. numba_notes.replace("\n", "", 1),
  433. window_method="ewm",
  434. aggregation_description="(exponential weighted moment) mean",
  435. agg_method="mean",
  436. )
  437. def mean(
  438. self,
  439. numeric_only: bool = False,
  440. engine=None,
  441. engine_kwargs=None,
  442. ):
  443. if maybe_use_numba(engine):
  444. if self.method == "single":
  445. func = generate_numba_ewm_func
  446. else:
  447. func = generate_numba_ewm_table_func
  448. ewm_func = func(
  449. **get_jit_arguments(engine_kwargs),
  450. com=self._com,
  451. adjust=self.adjust,
  452. ignore_na=self.ignore_na,
  453. deltas=tuple(self._deltas),
  454. normalize=True,
  455. )
  456. return self._apply(ewm_func, name="mean")
  457. elif engine in ("cython", None):
  458. if engine_kwargs is not None:
  459. raise ValueError("cython engine does not accept engine_kwargs")
  460. deltas = None if self.times is None else self._deltas
  461. window_func = partial(
  462. window_aggregations.ewm,
  463. com=self._com,
  464. adjust=self.adjust,
  465. ignore_na=self.ignore_na,
  466. deltas=deltas,
  467. normalize=True,
  468. )
  469. return self._apply(window_func, name="mean", numeric_only=numeric_only)
  470. else:
  471. raise ValueError("engine must be either 'numba' or 'cython'")
  472. @doc(
  473. template_header,
  474. create_section_header("Parameters"),
  475. kwargs_numeric_only,
  476. window_agg_numba_parameters(),
  477. create_section_header("Returns"),
  478. template_returns,
  479. create_section_header("See Also"),
  480. template_see_also,
  481. create_section_header("Notes"),
  482. numba_notes.replace("\n", "", 1),
  483. window_method="ewm",
  484. aggregation_description="(exponential weighted moment) sum",
  485. agg_method="sum",
  486. )
  487. def sum(
  488. self,
  489. numeric_only: bool = False,
  490. engine=None,
  491. engine_kwargs=None,
  492. ):
  493. if not self.adjust:
  494. raise NotImplementedError("sum is not implemented with adjust=False")
  495. if maybe_use_numba(engine):
  496. if self.method == "single":
  497. func = generate_numba_ewm_func
  498. else:
  499. func = generate_numba_ewm_table_func
  500. ewm_func = func(
  501. **get_jit_arguments(engine_kwargs),
  502. com=self._com,
  503. adjust=self.adjust,
  504. ignore_na=self.ignore_na,
  505. deltas=tuple(self._deltas),
  506. normalize=False,
  507. )
  508. return self._apply(ewm_func, name="sum")
  509. elif engine in ("cython", None):
  510. if engine_kwargs is not None:
  511. raise ValueError("cython engine does not accept engine_kwargs")
  512. deltas = None if self.times is None else self._deltas
  513. window_func = partial(
  514. window_aggregations.ewm,
  515. com=self._com,
  516. adjust=self.adjust,
  517. ignore_na=self.ignore_na,
  518. deltas=deltas,
  519. normalize=False,
  520. )
  521. return self._apply(window_func, name="sum", numeric_only=numeric_only)
  522. else:
  523. raise ValueError("engine must be either 'numba' or 'cython'")
  524. @doc(
  525. template_header,
  526. create_section_header("Parameters"),
  527. dedent(
  528. """
  529. bias : bool, default False
  530. Use a standard estimation bias correction.
  531. """
  532. ).replace("\n", "", 1),
  533. kwargs_numeric_only,
  534. create_section_header("Returns"),
  535. template_returns,
  536. create_section_header("See Also"),
  537. template_see_also[:-1],
  538. window_method="ewm",
  539. aggregation_description="(exponential weighted moment) standard deviation",
  540. agg_method="std",
  541. )
  542. def std(self, bias: bool = False, numeric_only: bool = False):
  543. if (
  544. numeric_only
  545. and self._selected_obj.ndim == 1
  546. and not is_numeric_dtype(self._selected_obj.dtype)
  547. ):
  548. # Raise directly so error message says std instead of var
  549. raise NotImplementedError(
  550. f"{type(self).__name__}.std does not implement numeric_only"
  551. )
  552. return zsqrt(self.var(bias=bias, numeric_only=numeric_only))
  553. @doc(
  554. template_header,
  555. create_section_header("Parameters"),
  556. dedent(
  557. """
  558. bias : bool, default False
  559. Use a standard estimation bias correction.
  560. """
  561. ).replace("\n", "", 1),
  562. kwargs_numeric_only,
  563. create_section_header("Returns"),
  564. template_returns,
  565. create_section_header("See Also"),
  566. template_see_also[:-1],
  567. window_method="ewm",
  568. aggregation_description="(exponential weighted moment) variance",
  569. agg_method="var",
  570. )
  571. def var(self, bias: bool = False, numeric_only: bool = False):
  572. window_func = window_aggregations.ewmcov
  573. wfunc = partial(
  574. window_func,
  575. com=self._com,
  576. adjust=self.adjust,
  577. ignore_na=self.ignore_na,
  578. bias=bias,
  579. )
  580. def var_func(values, begin, end, min_periods):
  581. return wfunc(values, begin, end, min_periods, values)
  582. return self._apply(var_func, name="var", numeric_only=numeric_only)
  583. @doc(
  584. template_header,
  585. create_section_header("Parameters"),
  586. dedent(
  587. """
  588. other : Series or DataFrame , optional
  589. If not supplied then will default to self and produce pairwise
  590. output.
  591. pairwise : bool, default None
  592. If False then only matching columns between self and other will be
  593. used and the output will be a DataFrame.
  594. If True then all pairwise combinations will be calculated and the
  595. output will be a MultiIndex DataFrame in the case of DataFrame
  596. inputs. In the case of missing elements, only complete pairwise
  597. observations will be used.
  598. bias : bool, default False
  599. Use a standard estimation bias correction.
  600. """
  601. ).replace("\n", "", 1),
  602. kwargs_numeric_only,
  603. create_section_header("Returns"),
  604. template_returns,
  605. create_section_header("See Also"),
  606. template_see_also[:-1],
  607. window_method="ewm",
  608. aggregation_description="(exponential weighted moment) sample covariance",
  609. agg_method="cov",
  610. )
  611. def cov(
  612. self,
  613. other: DataFrame | Series | None = None,
  614. pairwise: bool | None = None,
  615. bias: bool = False,
  616. numeric_only: bool = False,
  617. ):
  618. from pandas import Series
  619. self._validate_numeric_only("cov", numeric_only)
  620. def cov_func(x, y):
  621. x_array = self._prep_values(x)
  622. y_array = self._prep_values(y)
  623. window_indexer = self._get_window_indexer()
  624. min_periods = (
  625. self.min_periods
  626. if self.min_periods is not None
  627. else window_indexer.window_size
  628. )
  629. start, end = window_indexer.get_window_bounds(
  630. num_values=len(x_array),
  631. min_periods=min_periods,
  632. center=self.center,
  633. closed=self.closed,
  634. step=self.step,
  635. )
  636. result = window_aggregations.ewmcov(
  637. x_array,
  638. start,
  639. end,
  640. # error: Argument 4 to "ewmcov" has incompatible type
  641. # "Optional[int]"; expected "int"
  642. self.min_periods, # type: ignore[arg-type]
  643. y_array,
  644. self._com,
  645. self.adjust,
  646. self.ignore_na,
  647. bias,
  648. )
  649. return Series(result, index=x.index, name=x.name, copy=False)
  650. return self._apply_pairwise(
  651. self._selected_obj, other, pairwise, cov_func, numeric_only
  652. )
  653. @doc(
  654. template_header,
  655. create_section_header("Parameters"),
  656. dedent(
  657. """
  658. other : Series or DataFrame, optional
  659. If not supplied then will default to self and produce pairwise
  660. output.
  661. pairwise : bool, default None
  662. If False then only matching columns between self and other will be
  663. used and the output will be a DataFrame.
  664. If True then all pairwise combinations will be calculated and the
  665. output will be a MultiIndex DataFrame in the case of DataFrame
  666. inputs. In the case of missing elements, only complete pairwise
  667. observations will be used.
  668. """
  669. ).replace("\n", "", 1),
  670. kwargs_numeric_only,
  671. create_section_header("Returns"),
  672. template_returns,
  673. create_section_header("See Also"),
  674. template_see_also[:-1],
  675. window_method="ewm",
  676. aggregation_description="(exponential weighted moment) sample correlation",
  677. agg_method="corr",
  678. )
  679. def corr(
  680. self,
  681. other: DataFrame | Series | None = None,
  682. pairwise: bool | None = None,
  683. numeric_only: bool = False,
  684. ):
  685. from pandas import Series
  686. self._validate_numeric_only("corr", numeric_only)
  687. def cov_func(x, y):
  688. x_array = self._prep_values(x)
  689. y_array = self._prep_values(y)
  690. window_indexer = self._get_window_indexer()
  691. min_periods = (
  692. self.min_periods
  693. if self.min_periods is not None
  694. else window_indexer.window_size
  695. )
  696. start, end = window_indexer.get_window_bounds(
  697. num_values=len(x_array),
  698. min_periods=min_periods,
  699. center=self.center,
  700. closed=self.closed,
  701. step=self.step,
  702. )
  703. def _cov(X, Y):
  704. return window_aggregations.ewmcov(
  705. X,
  706. start,
  707. end,
  708. min_periods,
  709. Y,
  710. self._com,
  711. self.adjust,
  712. self.ignore_na,
  713. True,
  714. )
  715. with np.errstate(all="ignore"):
  716. cov = _cov(x_array, y_array)
  717. x_var = _cov(x_array, x_array)
  718. y_var = _cov(y_array, y_array)
  719. result = cov / zsqrt(x_var * y_var)
  720. return Series(result, index=x.index, name=x.name, copy=False)
  721. return self._apply_pairwise(
  722. self._selected_obj, other, pairwise, cov_func, numeric_only
  723. )
  724. class ExponentialMovingWindowGroupby(BaseWindowGroupby, ExponentialMovingWindow):
  725. """
  726. Provide an exponential moving window groupby implementation.
  727. """
  728. _attributes = ExponentialMovingWindow._attributes + BaseWindowGroupby._attributes
  729. def __init__(self, obj, *args, _grouper=None, **kwargs) -> None:
  730. super().__init__(obj, *args, _grouper=_grouper, **kwargs)
  731. if not obj.empty and self.times is not None:
  732. # sort the times and recalculate the deltas according to the groups
  733. groupby_order = np.concatenate(list(self._grouper.indices.values()))
  734. self._deltas = _calculate_deltas(
  735. self.times.take(groupby_order),
  736. self.halflife,
  737. )
  738. def _get_window_indexer(self) -> GroupbyIndexer:
  739. """
  740. Return an indexer class that will compute the window start and end bounds
  741. Returns
  742. -------
  743. GroupbyIndexer
  744. """
  745. window_indexer = GroupbyIndexer(
  746. groupby_indices=self._grouper.indices,
  747. window_indexer=ExponentialMovingWindowIndexer,
  748. )
  749. return window_indexer
  750. class OnlineExponentialMovingWindow(ExponentialMovingWindow):
  751. def __init__(
  752. self,
  753. obj: NDFrame,
  754. com: float | None = None,
  755. span: float | None = None,
  756. halflife: float | TimedeltaConvertibleTypes | None = None,
  757. alpha: float | None = None,
  758. min_periods: int | None = 0,
  759. adjust: bool = True,
  760. ignore_na: bool = False,
  761. axis: Axis = 0,
  762. times: np.ndarray | NDFrame | None = None,
  763. engine: str = "numba",
  764. engine_kwargs: dict[str, bool] | None = None,
  765. *,
  766. selection=None,
  767. ) -> None:
  768. if times is not None:
  769. raise NotImplementedError(
  770. "times is not implemented with online operations."
  771. )
  772. super().__init__(
  773. obj=obj,
  774. com=com,
  775. span=span,
  776. halflife=halflife,
  777. alpha=alpha,
  778. min_periods=min_periods,
  779. adjust=adjust,
  780. ignore_na=ignore_na,
  781. axis=axis,
  782. times=times,
  783. selection=selection,
  784. )
  785. self._mean = EWMMeanState(
  786. self._com, self.adjust, self.ignore_na, self.axis, obj.shape
  787. )
  788. if maybe_use_numba(engine):
  789. self.engine = engine
  790. self.engine_kwargs = engine_kwargs
  791. else:
  792. raise ValueError("'numba' is the only supported engine")
  793. def reset(self) -> None:
  794. """
  795. Reset the state captured by `update` calls.
  796. """
  797. self._mean.reset()
  798. def aggregate(self, func, *args, **kwargs):
  799. raise NotImplementedError("aggregate is not implemented.")
  800. def std(self, bias: bool = False, *args, **kwargs):
  801. raise NotImplementedError("std is not implemented.")
  802. def corr(
  803. self,
  804. other: DataFrame | Series | None = None,
  805. pairwise: bool | None = None,
  806. numeric_only: bool = False,
  807. ):
  808. raise NotImplementedError("corr is not implemented.")
  809. def cov(
  810. self,
  811. other: DataFrame | Series | None = None,
  812. pairwise: bool | None = None,
  813. bias: bool = False,
  814. numeric_only: bool = False,
  815. ):
  816. raise NotImplementedError("cov is not implemented.")
  817. def var(self, bias: bool = False, numeric_only: bool = False):
  818. raise NotImplementedError("var is not implemented.")
  819. def mean(self, *args, update=None, update_times=None, **kwargs):
  820. """
  821. Calculate an online exponentially weighted mean.
  822. Parameters
  823. ----------
  824. update: DataFrame or Series, default None
  825. New values to continue calculating the
  826. exponentially weighted mean from the last values and weights.
  827. Values should be float64 dtype.
  828. ``update`` needs to be ``None`` the first time the
  829. exponentially weighted mean is calculated.
  830. update_times: Series or 1-D np.ndarray, default None
  831. New times to continue calculating the
  832. exponentially weighted mean from the last values and weights.
  833. If ``None``, values are assumed to be evenly spaced
  834. in time.
  835. This feature is currently unsupported.
  836. Returns
  837. -------
  838. DataFrame or Series
  839. Examples
  840. --------
  841. >>> df = pd.DataFrame({"a": range(5), "b": range(5, 10)})
  842. >>> online_ewm = df.head(2).ewm(0.5).online()
  843. >>> online_ewm.mean()
  844. a b
  845. 0 0.00 5.00
  846. 1 0.75 5.75
  847. >>> online_ewm.mean(update=df.tail(3))
  848. a b
  849. 2 1.615385 6.615385
  850. 3 2.550000 7.550000
  851. 4 3.520661 8.520661
  852. >>> online_ewm.reset()
  853. >>> online_ewm.mean()
  854. a b
  855. 0 0.00 5.00
  856. 1 0.75 5.75
  857. """
  858. result_kwargs = {}
  859. is_frame = self._selected_obj.ndim == 2
  860. if update_times is not None:
  861. raise NotImplementedError("update_times is not implemented.")
  862. update_deltas = np.ones(
  863. max(self._selected_obj.shape[self.axis - 1] - 1, 0), dtype=np.float64
  864. )
  865. if update is not None:
  866. if self._mean.last_ewm is None:
  867. raise ValueError(
  868. "Must call mean with update=None first before passing update"
  869. )
  870. result_from = 1
  871. result_kwargs["index"] = update.index
  872. if is_frame:
  873. last_value = self._mean.last_ewm[np.newaxis, :]
  874. result_kwargs["columns"] = update.columns
  875. else:
  876. last_value = self._mean.last_ewm
  877. result_kwargs["name"] = update.name
  878. np_array = np.concatenate((last_value, update.to_numpy()))
  879. else:
  880. result_from = 0
  881. result_kwargs["index"] = self._selected_obj.index
  882. if is_frame:
  883. result_kwargs["columns"] = self._selected_obj.columns
  884. else:
  885. result_kwargs["name"] = self._selected_obj.name
  886. np_array = self._selected_obj.astype(np.float64).to_numpy()
  887. ewma_func = generate_online_numba_ewma_func(
  888. **get_jit_arguments(self.engine_kwargs)
  889. )
  890. result = self._mean.run_ewm(
  891. np_array if is_frame else np_array[:, np.newaxis],
  892. update_deltas,
  893. self.min_periods,
  894. ewma_func,
  895. )
  896. if not is_frame:
  897. result = result.squeeze()
  898. result = result[result_from:]
  899. result = self._selected_obj._constructor(result, **result_kwargs)
  900. return result