numba_.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349
  1. from __future__ import annotations
  2. import functools
  3. from typing import (
  4. TYPE_CHECKING,
  5. Any,
  6. Callable,
  7. )
  8. import numpy as np
  9. from pandas._typing import Scalar
  10. from pandas.compat._optional import import_optional_dependency
  11. from pandas.core.util.numba_ import jit_user_function
  12. @functools.lru_cache(maxsize=None)
  13. def generate_numba_apply_func(
  14. func: Callable[..., Scalar],
  15. nopython: bool,
  16. nogil: bool,
  17. parallel: bool,
  18. ):
  19. """
  20. Generate a numba jitted apply function specified by values from engine_kwargs.
  21. 1. jit the user's function
  22. 2. Return a rolling apply function with the jitted function inline
  23. Configurations specified in engine_kwargs apply to both the user's
  24. function _AND_ the rolling apply function.
  25. Parameters
  26. ----------
  27. func : function
  28. function to be applied to each window and will be JITed
  29. nopython : bool
  30. nopython to be passed into numba.jit
  31. nogil : bool
  32. nogil to be passed into numba.jit
  33. parallel : bool
  34. parallel to be passed into numba.jit
  35. Returns
  36. -------
  37. Numba function
  38. """
  39. numba_func = jit_user_function(func, nopython, nogil, parallel)
  40. if TYPE_CHECKING:
  41. import numba
  42. else:
  43. numba = import_optional_dependency("numba")
  44. @numba.jit(nopython=nopython, nogil=nogil, parallel=parallel)
  45. def roll_apply(
  46. values: np.ndarray,
  47. begin: np.ndarray,
  48. end: np.ndarray,
  49. minimum_periods: int,
  50. *args: Any,
  51. ) -> np.ndarray:
  52. result = np.empty(len(begin))
  53. for i in numba.prange(len(result)):
  54. start = begin[i]
  55. stop = end[i]
  56. window = values[start:stop]
  57. count_nan = np.sum(np.isnan(window))
  58. if len(window) - count_nan >= minimum_periods:
  59. result[i] = numba_func(window, *args)
  60. else:
  61. result[i] = np.nan
  62. return result
  63. return roll_apply
  64. @functools.lru_cache(maxsize=None)
  65. def generate_numba_ewm_func(
  66. nopython: bool,
  67. nogil: bool,
  68. parallel: bool,
  69. com: float,
  70. adjust: bool,
  71. ignore_na: bool,
  72. deltas: tuple,
  73. normalize: bool,
  74. ):
  75. """
  76. Generate a numba jitted ewm mean or sum function specified by values
  77. from engine_kwargs.
  78. Parameters
  79. ----------
  80. nopython : bool
  81. nopython to be passed into numba.jit
  82. nogil : bool
  83. nogil to be passed into numba.jit
  84. parallel : bool
  85. parallel to be passed into numba.jit
  86. com : float
  87. adjust : bool
  88. ignore_na : bool
  89. deltas : tuple
  90. normalize : bool
  91. Returns
  92. -------
  93. Numba function
  94. """
  95. if TYPE_CHECKING:
  96. import numba
  97. else:
  98. numba = import_optional_dependency("numba")
  99. @numba.jit(nopython=nopython, nogil=nogil, parallel=parallel)
  100. def ewm(
  101. values: np.ndarray,
  102. begin: np.ndarray,
  103. end: np.ndarray,
  104. minimum_periods: int,
  105. ) -> np.ndarray:
  106. result = np.empty(len(values))
  107. alpha = 1.0 / (1.0 + com)
  108. old_wt_factor = 1.0 - alpha
  109. new_wt = 1.0 if adjust else alpha
  110. for i in numba.prange(len(begin)):
  111. start = begin[i]
  112. stop = end[i]
  113. window = values[start:stop]
  114. sub_result = np.empty(len(window))
  115. weighted = window[0]
  116. nobs = int(not np.isnan(weighted))
  117. sub_result[0] = weighted if nobs >= minimum_periods else np.nan
  118. old_wt = 1.0
  119. for j in range(1, len(window)):
  120. cur = window[j]
  121. is_observation = not np.isnan(cur)
  122. nobs += is_observation
  123. if not np.isnan(weighted):
  124. if is_observation or not ignore_na:
  125. if normalize:
  126. # note that len(deltas) = len(vals) - 1 and deltas[i]
  127. # is to be used in conjunction with vals[i+1]
  128. old_wt *= old_wt_factor ** deltas[start + j - 1]
  129. else:
  130. weighted = old_wt_factor * weighted
  131. if is_observation:
  132. if normalize:
  133. # avoid numerical errors on constant series
  134. if weighted != cur:
  135. weighted = old_wt * weighted + new_wt * cur
  136. if normalize:
  137. weighted = weighted / (old_wt + new_wt)
  138. if adjust:
  139. old_wt += new_wt
  140. else:
  141. old_wt = 1.0
  142. else:
  143. weighted += cur
  144. elif is_observation:
  145. weighted = cur
  146. sub_result[j] = weighted if nobs >= minimum_periods else np.nan
  147. result[start:stop] = sub_result
  148. return result
  149. return ewm
  150. @functools.lru_cache(maxsize=None)
  151. def generate_numba_table_func(
  152. func: Callable[..., np.ndarray],
  153. nopython: bool,
  154. nogil: bool,
  155. parallel: bool,
  156. ):
  157. """
  158. Generate a numba jitted function to apply window calculations table-wise.
  159. Func will be passed a M window size x N number of columns array, and
  160. must return a 1 x N number of columns array. Func is intended to operate
  161. row-wise, but the result will be transposed for axis=1.
  162. 1. jit the user's function
  163. 2. Return a rolling apply function with the jitted function inline
  164. Parameters
  165. ----------
  166. func : function
  167. function to be applied to each window and will be JITed
  168. nopython : bool
  169. nopython to be passed into numba.jit
  170. nogil : bool
  171. nogil to be passed into numba.jit
  172. parallel : bool
  173. parallel to be passed into numba.jit
  174. Returns
  175. -------
  176. Numba function
  177. """
  178. numba_func = jit_user_function(func, nopython, nogil, parallel)
  179. if TYPE_CHECKING:
  180. import numba
  181. else:
  182. numba = import_optional_dependency("numba")
  183. @numba.jit(nopython=nopython, nogil=nogil, parallel=parallel)
  184. def roll_table(
  185. values: np.ndarray,
  186. begin: np.ndarray,
  187. end: np.ndarray,
  188. minimum_periods: int,
  189. *args: Any,
  190. ):
  191. result = np.empty((len(begin), values.shape[1]))
  192. min_periods_mask = np.empty(result.shape)
  193. for i in numba.prange(len(result)):
  194. start = begin[i]
  195. stop = end[i]
  196. window = values[start:stop]
  197. count_nan = np.sum(np.isnan(window), axis=0)
  198. sub_result = numba_func(window, *args)
  199. nan_mask = len(window) - count_nan >= minimum_periods
  200. min_periods_mask[i, :] = nan_mask
  201. result[i, :] = sub_result
  202. result = np.where(min_periods_mask, result, np.nan)
  203. return result
  204. return roll_table
  205. # This function will no longer be needed once numba supports
  206. # axis for all np.nan* agg functions
  207. # https://github.com/numba/numba/issues/1269
  208. @functools.lru_cache(maxsize=None)
  209. def generate_manual_numpy_nan_agg_with_axis(nan_func):
  210. if TYPE_CHECKING:
  211. import numba
  212. else:
  213. numba = import_optional_dependency("numba")
  214. @numba.jit(nopython=True, nogil=True, parallel=True)
  215. def nan_agg_with_axis(table):
  216. result = np.empty(table.shape[1])
  217. for i in numba.prange(table.shape[1]):
  218. partition = table[:, i]
  219. result[i] = nan_func(partition)
  220. return result
  221. return nan_agg_with_axis
  222. @functools.lru_cache(maxsize=None)
  223. def generate_numba_ewm_table_func(
  224. nopython: bool,
  225. nogil: bool,
  226. parallel: bool,
  227. com: float,
  228. adjust: bool,
  229. ignore_na: bool,
  230. deltas: tuple,
  231. normalize: bool,
  232. ):
  233. """
  234. Generate a numba jitted ewm mean or sum function applied table wise specified
  235. by values from engine_kwargs.
  236. Parameters
  237. ----------
  238. nopython : bool
  239. nopython to be passed into numba.jit
  240. nogil : bool
  241. nogil to be passed into numba.jit
  242. parallel : bool
  243. parallel to be passed into numba.jit
  244. com : float
  245. adjust : bool
  246. ignore_na : bool
  247. deltas : tuple
  248. normalize: bool
  249. Returns
  250. -------
  251. Numba function
  252. """
  253. if TYPE_CHECKING:
  254. import numba
  255. else:
  256. numba = import_optional_dependency("numba")
  257. @numba.jit(nopython=nopython, nogil=nogil, parallel=parallel)
  258. def ewm_table(
  259. values: np.ndarray,
  260. begin: np.ndarray,
  261. end: np.ndarray,
  262. minimum_periods: int,
  263. ) -> np.ndarray:
  264. alpha = 1.0 / (1.0 + com)
  265. old_wt_factor = 1.0 - alpha
  266. new_wt = 1.0 if adjust else alpha
  267. old_wt = np.ones(values.shape[1])
  268. result = np.empty(values.shape)
  269. weighted = values[0].copy()
  270. nobs = (~np.isnan(weighted)).astype(np.int64)
  271. result[0] = np.where(nobs >= minimum_periods, weighted, np.nan)
  272. for i in range(1, len(values)):
  273. cur = values[i]
  274. is_observations = ~np.isnan(cur)
  275. nobs += is_observations.astype(np.int64)
  276. for j in numba.prange(len(cur)):
  277. if not np.isnan(weighted[j]):
  278. if is_observations[j] or not ignore_na:
  279. if normalize:
  280. # note that len(deltas) = len(vals) - 1 and deltas[i]
  281. # is to be used in conjunction with vals[i+1]
  282. old_wt[j] *= old_wt_factor ** deltas[i - 1]
  283. else:
  284. weighted[j] = old_wt_factor * weighted[j]
  285. if is_observations[j]:
  286. if normalize:
  287. # avoid numerical errors on constant series
  288. if weighted[j] != cur[j]:
  289. weighted[j] = (
  290. old_wt[j] * weighted[j] + new_wt * cur[j]
  291. )
  292. if normalize:
  293. weighted[j] = weighted[j] / (old_wt[j] + new_wt)
  294. if adjust:
  295. old_wt[j] += new_wt
  296. else:
  297. old_wt[j] = 1.0
  298. else:
  299. weighted[j] += cur[j]
  300. elif is_observations[j]:
  301. weighted[j] = cur[j]
  302. result[i] = np.where(nobs >= minimum_periods, weighted, np.nan)
  303. return result
  304. return ewm_table