|
- from __future__ import annotations
- import functools
- from typing import (
- TYPE_CHECKING,
- Any,
- Callable,
- )
- import numpy as np
- from pandas._typing import Scalar
- from pandas.compat._optional import import_optional_dependency
- from pandas.core.util.numba_ import jit_user_function
@functools.lru_cache(maxsize=None)
def generate_numba_apply_func(
    func: Callable[..., Scalar],
    nopython: bool,
    nogil: bool,
    parallel: bool,
):
    """
    Build a numba-compiled rolling apply kernel around a user function.

    The user's function is jitted first; the returned kernel then calls the
    jitted function once per window. The ``nopython``, ``nogil`` and
    ``parallel`` flags are forwarded to both the user's function and the
    kernel itself.

    Parameters
    ----------
    func : function
        function to be applied to each window and will be JITed
    nopython : bool
        nopython to be passed into numba.jit
    nogil : bool
        nogil to be passed into numba.jit
    parallel : bool
        parallel to be passed into numba.jit

    Returns
    -------
    Numba function
        Kernel taking ``(values, begin, end, minimum_periods, *args)`` and
        returning an array with one entry per ``begin``/``end`` pair.
    """
    numba_func = jit_user_function(func, nopython, nogil, parallel)
    if TYPE_CHECKING:
        import numba
    else:
        numba = import_optional_dependency("numba")

    @numba.jit(nopython=nopython, nogil=nogil, parallel=parallel)
    def roll_apply(
        values: np.ndarray,
        begin: np.ndarray,
        end: np.ndarray,
        minimum_periods: int,
        *args: Any,
    ) -> np.ndarray:
        n_windows = len(begin)
        out = np.empty(n_windows)
        for idx in numba.prange(n_windows):
            chunk = values[begin[idx] : end[idx]]
            # Number of non-NaN entries in this window.
            n_valid = len(chunk) - np.sum(np.isnan(chunk))
            if n_valid < minimum_periods:
                # Too few observations: the window yields NaN.
                out[idx] = np.nan
            else:
                out[idx] = numba_func(chunk, *args)
        return out

    return roll_apply
@functools.lru_cache(maxsize=None)
def generate_numba_ewm_func(
    nopython: bool,
    nogil: bool,
    parallel: bool,
    com: float,
    adjust: bool,
    ignore_na: bool,
    deltas: tuple,
    normalize: bool,
):
    """
    Generate a numba jitted ewm mean or sum function specified by values
    from engine_kwargs.

    Parameters
    ----------
    nopython : bool
        nopython to be passed into numba.jit
    nogil : bool
        nogil to be passed into numba.jit
    parallel : bool
        parallel to be passed into numba.jit
    com : float
        Decay parameter; the smoothing factor is ``alpha = 1 / (1 + com)``
        and the per-step decay factor is ``1 - alpha``.
    adjust : bool
        If True, a new observation is weighted by 1 and the accumulated old
        weight grows by that amount; if False, a new observation is weighted
        by ``alpha`` and the old weight is reset to 1 after each observation.
        The old weight is only tracked when ``normalize`` is True.
    ignore_na : bool
        If True, NaN entries do not decay the running state; if False, the
        decay is applied at every step, including NaN entries.
    deltas : tuple
        Exponents applied to the decay factor between consecutive values;
        ``deltas[k]`` is used when stepping to ``values[k + 1]``. Only read
        when ``normalize`` is True.
    normalize : bool
        If True, the running value is divided by the total weight (weighted
        mean); if False, a decayed running sum is accumulated instead.

    Returns
    -------
    Numba function
        Kernel taking ``(values, begin, end, minimum_periods)`` and returning
        an array of ``len(values)`` entries; each ``begin[i]:end[i]`` slice is
        processed independently and written back to the same slice.
    """
    if TYPE_CHECKING:
        import numba
    else:
        numba = import_optional_dependency("numba")

    @numba.jit(nopython=nopython, nogil=nogil, parallel=parallel)
    def ewm(
        values: np.ndarray,
        begin: np.ndarray,
        end: np.ndarray,
        minimum_periods: int,
    ) -> np.ndarray:
        result = np.empty(len(values))
        alpha = 1.0 / (1.0 + com)
        old_wt_factor = 1.0 - alpha
        new_wt = 1.0 if adjust else alpha

        for i in numba.prange(len(begin)):
            start = begin[i]
            stop = end[i]
            window = values[start:stop]
            sub_result = np.empty(len(window))

            # NOTE: the first element is read unconditionally, so every
            # window is assumed to contain at least one value.
            weighted = window[0]
            nobs = int(not np.isnan(weighted))
            sub_result[0] = weighted if nobs >= minimum_periods else np.nan
            old_wt = 1.0

            for j in range(1, len(window)):
                cur = window[j]
                is_observation = not np.isnan(cur)
                nobs += is_observation
                if not np.isnan(weighted):
                    if is_observation or not ignore_na:
                        if normalize:
                            # note that len(deltas) = len(vals) - 1 and deltas[i]
                            # is to be used in conjunction with vals[i+1]
                            old_wt *= old_wt_factor ** deltas[start + j - 1]
                        else:
                            weighted = old_wt_factor * weighted
                        if is_observation:
                            if normalize:
                                # avoid numerical errors on constant series
                                if weighted != cur:
                                    weighted = old_wt * weighted + new_wt * cur
                                    # normalize is already known to be True
                                    # here, so divide without re-testing it.
                                    weighted = weighted / (old_wt + new_wt)
                                if adjust:
                                    old_wt += new_wt
                                else:
                                    old_wt = 1.0
                            else:
                                weighted += cur
                elif is_observation:
                    # First non-NaN value seen in this window seeds the state.
                    weighted = cur

                sub_result[j] = weighted if nobs >= minimum_periods else np.nan

            result[start:stop] = sub_result

        return result

    return ewm
@functools.lru_cache(maxsize=None)
def generate_numba_table_func(
    func: Callable[..., np.ndarray],
    nopython: bool,
    nogil: bool,
    parallel: bool,
):
    """
    Build a numba-compiled kernel that applies a window function table-wise.

    ``func`` receives an array of M window rows x N columns and must return
    a 1 x N array. Func is intended to operate row-wise, but the result will
    be transposed for axis=1.

    The user's function is jitted first; the returned kernel then calls the
    jitted function once per window.

    Parameters
    ----------
    func : function
        function to be applied to each window and will be JITed
    nopython : bool
        nopython to be passed into numba.jit
    nogil : bool
        nogil to be passed into numba.jit
    parallel : bool
        parallel to be passed into numba.jit

    Returns
    -------
    Numba function
        Kernel taking ``(values, begin, end, minimum_periods, *args)`` and
        returning an array of shape ``(len(begin), values.shape[1])``.
    """
    numba_func = jit_user_function(func, nopython, nogil, parallel)
    if TYPE_CHECKING:
        import numba
    else:
        numba = import_optional_dependency("numba")

    @numba.jit(nopython=nopython, nogil=nogil, parallel=parallel)
    def roll_table(
        values: np.ndarray,
        begin: np.ndarray,
        end: np.ndarray,
        minimum_periods: int,
        *args: Any,
    ):
        n_windows = len(begin)
        out = np.empty((n_windows, values.shape[1]))
        # Per-window, per-column flag: does the column have enough non-NaN
        # observations in this window?
        enough_obs = np.empty(out.shape)
        for row in numba.prange(n_windows):
            chunk = values[begin[row] : end[row]]
            nan_per_column = np.sum(np.isnan(chunk), axis=0)
            window_result = numba_func(chunk, *args)
            enough_obs[row, :] = len(chunk) - nan_per_column >= minimum_periods
            out[row, :] = window_result
        # Blank out every cell that did not reach minimum_periods.
        return np.where(enough_obs, out, np.nan)

    return roll_table
# This function will no longer be needed once numba supports
# axis for all np.nan* agg functions
# https://github.com/numba/numba/issues/1269
@functools.lru_cache(maxsize=None)
def generate_manual_numpy_nan_agg_with_axis(nan_func):
    """
    Wrap a 1D aggregation so that it reduces each column of a 2D table.

    Parameters
    ----------
    nan_func : function
        Aggregation called on one column (``table[:, i]``) at a time; it is
        invoked from inside a nopython-compiled loop, so it presumably has
        to be callable from numba (e.g. an ``np.nan*`` reduction).

    Returns
    -------
    Numba function
        Function taking a 2D ``table`` and returning a 1D array holding one
        aggregated value per column. It is always compiled with
        ``nopython=True, nogil=True, parallel=True``.
    """
    if TYPE_CHECKING:
        import numba
    else:
        numba = import_optional_dependency("numba")

    @numba.jit(nopython=True, nogil=True, parallel=True)
    def nan_agg_with_axis(table):
        n_columns = table.shape[1]
        out = np.empty(n_columns)
        for col in numba.prange(n_columns):
            out[col] = nan_func(table[:, col])
        return out

    return nan_agg_with_axis
@functools.lru_cache(maxsize=None)
def generate_numba_ewm_table_func(
    nopython: bool,
    nogil: bool,
    parallel: bool,
    com: float,
    adjust: bool,
    ignore_na: bool,
    deltas: tuple,
    normalize: bool,
):
    """
    Generate a numba jitted ewm mean or sum function applied table wise specified
    by values from engine_kwargs.

    Parameters
    ----------
    nopython : bool
        nopython to be passed into numba.jit
    nogil : bool
        nogil to be passed into numba.jit
    parallel : bool
        parallel to be passed into numba.jit
    com : float
        Decay parameter; the smoothing factor is ``alpha = 1 / (1 + com)``
        and the per-step decay factor is ``1 - alpha``.
    adjust : bool
        If True, a new observation is weighted by 1 and the accumulated old
        weight grows by that amount; if False, a new observation is weighted
        by ``alpha`` and the old weight is reset to 1 after each observation.
        The old weight is only tracked when ``normalize`` is True.
    ignore_na : bool
        If True, NaN entries do not decay the running state; if False, the
        decay is applied at every row, including NaN entries.
    deltas : tuple
        Exponents applied to the decay factor between consecutive rows;
        ``deltas[k]`` is used when stepping to ``values[k + 1]``. Only read
        when ``normalize`` is True.
    normalize : bool
        If True, the running value is divided by the total weight (weighted
        mean); if False, a decayed running sum is accumulated instead.

    Returns
    -------
    Numba function
        Kernel taking ``(values, begin, end, minimum_periods)`` and returning
        an array with the same shape as the 2D ``values``. ``begin`` and
        ``end`` are accepted but never read by the kernel.
    """
    if TYPE_CHECKING:
        import numba
    else:
        numba = import_optional_dependency("numba")

    @numba.jit(nopython=nopython, nogil=nogil, parallel=parallel)
    def ewm_table(
        values: np.ndarray,
        begin: np.ndarray,
        end: np.ndarray,
        minimum_periods: int,
    ) -> np.ndarray:
        alpha = 1.0 / (1.0 + com)
        old_wt_factor = 1.0 - alpha
        new_wt = 1.0 if adjust else alpha
        # One running old-weight per column.
        old_wt = np.ones(values.shape[1])

        result = np.empty(values.shape)
        # NOTE: the first row is read unconditionally, so ``values`` is
        # assumed to contain at least one row.
        weighted = values[0].copy()
        nobs = (~np.isnan(weighted)).astype(np.int64)
        result[0] = np.where(nobs >= minimum_periods, weighted, np.nan)
        for i in range(1, len(values)):
            cur = values[i]
            is_observations = ~np.isnan(cur)
            nobs += is_observations.astype(np.int64)
            # Columns are independent of each other, so they are updated in
            # a prange loop; rows must be visited in order.
            for j in numba.prange(len(cur)):
                if not np.isnan(weighted[j]):
                    if is_observations[j] or not ignore_na:
                        if normalize:
                            # note that len(deltas) = len(vals) - 1 and deltas[i]
                            # is to be used in conjunction with vals[i+1]
                            old_wt[j] *= old_wt_factor ** deltas[i - 1]
                        else:
                            weighted[j] = old_wt_factor * weighted[j]
                        if is_observations[j]:
                            if normalize:
                                # avoid numerical errors on constant series
                                if weighted[j] != cur[j]:
                                    weighted[j] = (
                                        old_wt[j] * weighted[j] + new_wt * cur[j]
                                    )
                                    # normalize is already known to be True
                                    # here, so divide without re-testing it.
                                    weighted[j] = weighted[j] / (old_wt[j] + new_wt)
                                if adjust:
                                    old_wt[j] += new_wt
                                else:
                                    old_wt[j] = 1.0
                            else:
                                weighted[j] += cur[j]
                elif is_observations[j]:
                    # First non-NaN value seen in this column seeds the state.
                    weighted[j] = cur[j]

            result[i] = np.where(nobs >= minimum_periods, weighted, np.nan)

        return result

    return ewm_table
|