numba_.py 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179
  1. """Common utilities for Numba operations with groupby ops"""
  2. from __future__ import annotations
  3. import functools
  4. import inspect
  5. from typing import (
  6. TYPE_CHECKING,
  7. Any,
  8. Callable,
  9. )
  10. import numpy as np
  11. from pandas._typing import Scalar
  12. from pandas.compat._optional import import_optional_dependency
  13. from pandas.core.util.numba_ import (
  14. NumbaUtilError,
  15. jit_user_function,
  16. )
  17. def validate_udf(func: Callable) -> None:
  18. """
  19. Validate user defined function for ops when using Numba with groupby ops.
  20. The first signature arguments should include:
  21. def f(values, index, ...):
  22. ...
  23. Parameters
  24. ----------
  25. func : function, default False
  26. user defined function
  27. Returns
  28. -------
  29. None
  30. Raises
  31. ------
  32. NumbaUtilError
  33. """
  34. if not callable(func):
  35. raise NotImplementedError(
  36. "Numba engine can only be used with a single function."
  37. )
  38. udf_signature = list(inspect.signature(func).parameters.keys())
  39. expected_args = ["values", "index"]
  40. min_number_args = len(expected_args)
  41. if (
  42. len(udf_signature) < min_number_args
  43. or udf_signature[:min_number_args] != expected_args
  44. ):
  45. raise NumbaUtilError(
  46. f"The first {min_number_args} arguments to {func.__name__} must be "
  47. f"{expected_args}"
  48. )
  49. @functools.lru_cache(maxsize=None)
  50. def generate_numba_agg_func(
  51. func: Callable[..., Scalar],
  52. nopython: bool,
  53. nogil: bool,
  54. parallel: bool,
  55. ) -> Callable[[np.ndarray, np.ndarray, np.ndarray, np.ndarray, int, Any], np.ndarray]:
  56. """
  57. Generate a numba jitted agg function specified by values from engine_kwargs.
  58. 1. jit the user's function
  59. 2. Return a groupby agg function with the jitted function inline
  60. Configurations specified in engine_kwargs apply to both the user's
  61. function _AND_ the groupby evaluation loop.
  62. Parameters
  63. ----------
  64. func : function
  65. function to be applied to each group and will be JITed
  66. nopython : bool
  67. nopython to be passed into numba.jit
  68. nogil : bool
  69. nogil to be passed into numba.jit
  70. parallel : bool
  71. parallel to be passed into numba.jit
  72. Returns
  73. -------
  74. Numba function
  75. """
  76. numba_func = jit_user_function(func, nopython, nogil, parallel)
  77. if TYPE_CHECKING:
  78. import numba
  79. else:
  80. numba = import_optional_dependency("numba")
  81. @numba.jit(nopython=nopython, nogil=nogil, parallel=parallel)
  82. def group_agg(
  83. values: np.ndarray,
  84. index: np.ndarray,
  85. begin: np.ndarray,
  86. end: np.ndarray,
  87. num_columns: int,
  88. *args: Any,
  89. ) -> np.ndarray:
  90. assert len(begin) == len(end)
  91. num_groups = len(begin)
  92. result = np.empty((num_groups, num_columns))
  93. for i in numba.prange(num_groups):
  94. group_index = index[begin[i] : end[i]]
  95. for j in numba.prange(num_columns):
  96. group = values[begin[i] : end[i], j]
  97. result[i, j] = numba_func(group, group_index, *args)
  98. return result
  99. return group_agg
  100. @functools.lru_cache(maxsize=None)
  101. def generate_numba_transform_func(
  102. func: Callable[..., np.ndarray],
  103. nopython: bool,
  104. nogil: bool,
  105. parallel: bool,
  106. ) -> Callable[[np.ndarray, np.ndarray, np.ndarray, np.ndarray, int, Any], np.ndarray]:
  107. """
  108. Generate a numba jitted transform function specified by values from engine_kwargs.
  109. 1. jit the user's function
  110. 2. Return a groupby transform function with the jitted function inline
  111. Configurations specified in engine_kwargs apply to both the user's
  112. function _AND_ the groupby evaluation loop.
  113. Parameters
  114. ----------
  115. func : function
  116. function to be applied to each window and will be JITed
  117. nopython : bool
  118. nopython to be passed into numba.jit
  119. nogil : bool
  120. nogil to be passed into numba.jit
  121. parallel : bool
  122. parallel to be passed into numba.jit
  123. Returns
  124. -------
  125. Numba function
  126. """
  127. numba_func = jit_user_function(func, nopython, nogil, parallel)
  128. if TYPE_CHECKING:
  129. import numba
  130. else:
  131. numba = import_optional_dependency("numba")
  132. @numba.jit(nopython=nopython, nogil=nogil, parallel=parallel)
  133. def group_transform(
  134. values: np.ndarray,
  135. index: np.ndarray,
  136. begin: np.ndarray,
  137. end: np.ndarray,
  138. num_columns: int,
  139. *args: Any,
  140. ) -> np.ndarray:
  141. assert len(begin) == len(end)
  142. num_groups = len(begin)
  143. result = np.empty((len(values), num_columns))
  144. for i in numba.prange(num_groups):
  145. group_index = index[begin[i] : end[i]]
  146. for j in numba.prange(num_columns):
  147. group = values[begin[i] : end[i], j]
  148. result[begin[i] : end[i], j] = numba_func(group, group_index, *args)
  149. return result
  150. return group_transform