online.py 3.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118
  1. from __future__ import annotations
  2. from typing import TYPE_CHECKING
  3. import numpy as np
  4. from pandas.compat._optional import import_optional_dependency
  5. def generate_online_numba_ewma_func(
  6. nopython: bool,
  7. nogil: bool,
  8. parallel: bool,
  9. ):
  10. """
  11. Generate a numba jitted groupby ewma function specified by values
  12. from engine_kwargs.
  13. Parameters
  14. ----------
  15. nopython : bool
  16. nopython to be passed into numba.jit
  17. nogil : bool
  18. nogil to be passed into numba.jit
  19. parallel : bool
  20. parallel to be passed into numba.jit
  21. Returns
  22. -------
  23. Numba function
  24. """
  25. if TYPE_CHECKING:
  26. import numba
  27. else:
  28. numba = import_optional_dependency("numba")
  29. @numba.jit(nopython=nopython, nogil=nogil, parallel=parallel)
  30. def online_ewma(
  31. values: np.ndarray,
  32. deltas: np.ndarray,
  33. minimum_periods: int,
  34. old_wt_factor: float,
  35. new_wt: float,
  36. old_wt: np.ndarray,
  37. adjust: bool,
  38. ignore_na: bool,
  39. ):
  40. """
  41. Compute online exponentially weighted mean per column over 2D values.
  42. Takes the first observation as is, then computes the subsequent
  43. exponentially weighted mean accounting minimum periods.
  44. """
  45. result = np.empty(values.shape)
  46. weighted_avg = values[0]
  47. nobs = (~np.isnan(weighted_avg)).astype(np.int64)
  48. result[0] = np.where(nobs >= minimum_periods, weighted_avg, np.nan)
  49. for i in range(1, len(values)):
  50. cur = values[i]
  51. is_observations = ~np.isnan(cur)
  52. nobs += is_observations.astype(np.int64)
  53. for j in numba.prange(len(cur)):
  54. if not np.isnan(weighted_avg[j]):
  55. if is_observations[j] or not ignore_na:
  56. # note that len(deltas) = len(vals) - 1 and deltas[i] is to be
  57. # used in conjunction with vals[i+1]
  58. old_wt[j] *= old_wt_factor ** deltas[j - 1]
  59. if is_observations[j]:
  60. # avoid numerical errors on constant series
  61. if weighted_avg[j] != cur[j]:
  62. weighted_avg[j] = (
  63. (old_wt[j] * weighted_avg[j]) + (new_wt * cur[j])
  64. ) / (old_wt[j] + new_wt)
  65. if adjust:
  66. old_wt[j] += new_wt
  67. else:
  68. old_wt[j] = 1.0
  69. elif is_observations[j]:
  70. weighted_avg[j] = cur[j]
  71. result[i] = np.where(nobs >= minimum_periods, weighted_avg, np.nan)
  72. return result, old_wt
  73. return online_ewma
  74. class EWMMeanState:
  75. def __init__(self, com, adjust, ignore_na, axis, shape) -> None:
  76. alpha = 1.0 / (1.0 + com)
  77. self.axis = axis
  78. self.shape = shape
  79. self.adjust = adjust
  80. self.ignore_na = ignore_na
  81. self.new_wt = 1.0 if adjust else alpha
  82. self.old_wt_factor = 1.0 - alpha
  83. self.old_wt = np.ones(self.shape[self.axis - 1])
  84. self.last_ewm = None
  85. def run_ewm(self, weighted_avg, deltas, min_periods, ewm_func):
  86. result, old_wt = ewm_func(
  87. weighted_avg,
  88. deltas,
  89. min_periods,
  90. self.old_wt_factor,
  91. self.new_wt,
  92. self.old_wt,
  93. self.adjust,
  94. self.ignore_na,
  95. )
  96. self.old_wt = old_wt
  97. self.last_ewm = result[-1]
  98. return result
  99. def reset(self) -> None:
  100. self.old_wt = np.ones(self.shape[self.axis - 1])
  101. self.last_ewm = None