# numba_.py (2.8 KB)
  1. """Common utilities for Numba operations"""
  2. from __future__ import annotations
  3. import types
  4. from typing import (
  5. TYPE_CHECKING,
  6. Callable,
  7. )
  8. import numpy as np
  9. from pandas.compat._optional import import_optional_dependency
  10. from pandas.errors import NumbaUtilError
  11. GLOBAL_USE_NUMBA: bool = False
  12. def maybe_use_numba(engine: str | None) -> bool:
  13. """Signal whether to use numba routines."""
  14. return engine == "numba" or (engine is None and GLOBAL_USE_NUMBA)
  15. def set_use_numba(enable: bool = False) -> None:
  16. global GLOBAL_USE_NUMBA
  17. if enable:
  18. import_optional_dependency("numba")
  19. GLOBAL_USE_NUMBA = enable
  20. def get_jit_arguments(
  21. engine_kwargs: dict[str, bool] | None = None, kwargs: dict | None = None
  22. ) -> dict[str, bool]:
  23. """
  24. Return arguments to pass to numba.JIT, falling back on pandas default JIT settings.
  25. Parameters
  26. ----------
  27. engine_kwargs : dict, default None
  28. user passed keyword arguments for numba.JIT
  29. kwargs : dict, default None
  30. user passed keyword arguments to pass into the JITed function
  31. Returns
  32. -------
  33. dict[str, bool]
  34. nopython, nogil, parallel
  35. Raises
  36. ------
  37. NumbaUtilError
  38. """
  39. if engine_kwargs is None:
  40. engine_kwargs = {}
  41. nopython = engine_kwargs.get("nopython", True)
  42. if kwargs and nopython:
  43. raise NumbaUtilError(
  44. "numba does not support kwargs with nopython=True: "
  45. "https://github.com/numba/numba/issues/2916"
  46. )
  47. nogil = engine_kwargs.get("nogil", False)
  48. parallel = engine_kwargs.get("parallel", False)
  49. return {"nopython": nopython, "nogil": nogil, "parallel": parallel}
  50. def jit_user_function(
  51. func: Callable, nopython: bool, nogil: bool, parallel: bool
  52. ) -> Callable:
  53. """
  54. JIT the user's function given the configurable arguments.
  55. Parameters
  56. ----------
  57. func : function
  58. user defined function
  59. nopython : bool
  60. nopython parameter for numba.JIT
  61. nogil : bool
  62. nogil parameter for numba.JIT
  63. parallel : bool
  64. parallel parameter for numba.JIT
  65. Returns
  66. -------
  67. function
  68. Numba JITed function
  69. """
  70. if TYPE_CHECKING:
  71. import numba
  72. else:
  73. numba = import_optional_dependency("numba")
  74. if numba.extending.is_jitted(func):
  75. # Don't jit a user passed jitted function
  76. numba_func = func
  77. else:
  78. @numba.generated_jit(nopython=nopython, nogil=nogil, parallel=parallel)
  79. def numba_func(data, *_args):
  80. if getattr(np, func.__name__, False) is func or isinstance(
  81. func, types.BuiltinFunctionType
  82. ):
  83. jf = func
  84. else:
  85. jf = numba.jit(func, nopython=nopython, nogil=nogil)
  86. def impl(data, *_args):
  87. return jf(data, *_args)
  88. return impl
  89. return numba_func