function.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391
"""
For compatibility with numpy libraries, pandas functions or methods have to
accept '*args' and '**kwargs' parameters to accommodate numpy arguments that
are not actually used or respected in the pandas implementation.

To ensure that users do not abuse these parameters, validation is performed in
'validators.py' to make sure that any extra parameters passed correspond ONLY
to those in the numpy signature. Part of that validation includes whether or
not the user attempted to pass in non-default values for these extraneous
parameters. As we want to discourage users from relying on these parameters
when calling the pandas implementation, we want them to pass in only the
default values for these parameters.

This module provides a set of commonly used default arguments for functions and
methods that are spread throughout the codebase. This module will make it
easier to adjust to future upstream changes in the analogous numpy signatures.
"""
  16. from __future__ import annotations
  17. from typing import (
  18. Any,
  19. TypeVar,
  20. cast,
  21. overload,
  22. )
  23. from numpy import ndarray
  24. from pandas._libs.lib import (
  25. is_bool,
  26. is_integer,
  27. )
  28. from pandas._typing import (
  29. Axis,
  30. AxisInt,
  31. )
  32. from pandas.errors import UnsupportedFunctionCall
  33. from pandas.util._validators import (
  34. validate_args,
  35. validate_args_and_kwargs,
  36. validate_kwargs,
  37. )
  38. AxisNoneT = TypeVar("AxisNoneT", Axis, None)
  39. class CompatValidator:
  40. def __init__(
  41. self,
  42. defaults,
  43. fname=None,
  44. method: str | None = None,
  45. max_fname_arg_count=None,
  46. ) -> None:
  47. self.fname = fname
  48. self.method = method
  49. self.defaults = defaults
  50. self.max_fname_arg_count = max_fname_arg_count
  51. def __call__(
  52. self,
  53. args,
  54. kwargs,
  55. fname=None,
  56. max_fname_arg_count=None,
  57. method: str | None = None,
  58. ) -> None:
  59. if args or kwargs:
  60. fname = self.fname if fname is None else fname
  61. max_fname_arg_count = (
  62. self.max_fname_arg_count
  63. if max_fname_arg_count is None
  64. else max_fname_arg_count
  65. )
  66. method = self.method if method is None else method
  67. if method == "args":
  68. validate_args(fname, args, max_fname_arg_count, self.defaults)
  69. elif method == "kwargs":
  70. validate_kwargs(fname, kwargs, self.defaults)
  71. elif method == "both":
  72. validate_args_and_kwargs(
  73. fname, args, kwargs, max_fname_arg_count, self.defaults
  74. )
  75. else:
  76. raise ValueError(f"invalid validation method '{method}'")
  77. ARGMINMAX_DEFAULTS = {"out": None}
  78. validate_argmin = CompatValidator(
  79. ARGMINMAX_DEFAULTS, fname="argmin", method="both", max_fname_arg_count=1
  80. )
  81. validate_argmax = CompatValidator(
  82. ARGMINMAX_DEFAULTS, fname="argmax", method="both", max_fname_arg_count=1
  83. )
  84. def process_skipna(skipna: bool | ndarray | None, args) -> tuple[bool, Any]:
  85. if isinstance(skipna, ndarray) or skipna is None:
  86. args = (skipna,) + args
  87. skipna = True
  88. return skipna, args
  89. def validate_argmin_with_skipna(skipna: bool | ndarray | None, args, kwargs) -> bool:
  90. """
  91. If 'Series.argmin' is called via the 'numpy' library, the third parameter
  92. in its signature is 'out', which takes either an ndarray or 'None', so
  93. check if the 'skipna' parameter is either an instance of ndarray or is
  94. None, since 'skipna' itself should be a boolean
  95. """
  96. skipna, args = process_skipna(skipna, args)
  97. validate_argmin(args, kwargs)
  98. return skipna
  99. def validate_argmax_with_skipna(skipna: bool | ndarray | None, args, kwargs) -> bool:
  100. """
  101. If 'Series.argmax' is called via the 'numpy' library, the third parameter
  102. in its signature is 'out', which takes either an ndarray or 'None', so
  103. check if the 'skipna' parameter is either an instance of ndarray or is
  104. None, since 'skipna' itself should be a boolean
  105. """
  106. skipna, args = process_skipna(skipna, args)
  107. validate_argmax(args, kwargs)
  108. return skipna
  109. ARGSORT_DEFAULTS: dict[str, int | str | None] = {}
  110. ARGSORT_DEFAULTS["axis"] = -1
  111. ARGSORT_DEFAULTS["kind"] = "quicksort"
  112. ARGSORT_DEFAULTS["order"] = None
  113. ARGSORT_DEFAULTS["kind"] = None
  114. validate_argsort = CompatValidator(
  115. ARGSORT_DEFAULTS, fname="argsort", max_fname_arg_count=0, method="both"
  116. )
  117. # two different signatures of argsort, this second validation for when the
  118. # `kind` param is supported
  119. ARGSORT_DEFAULTS_KIND: dict[str, int | None] = {}
  120. ARGSORT_DEFAULTS_KIND["axis"] = -1
  121. ARGSORT_DEFAULTS_KIND["order"] = None
  122. validate_argsort_kind = CompatValidator(
  123. ARGSORT_DEFAULTS_KIND, fname="argsort", max_fname_arg_count=0, method="both"
  124. )
  125. def validate_argsort_with_ascending(ascending: bool | int | None, args, kwargs) -> bool:
  126. """
  127. If 'Categorical.argsort' is called via the 'numpy' library, the first
  128. parameter in its signature is 'axis', which takes either an integer or
  129. 'None', so check if the 'ascending' parameter has either integer type or is
  130. None, since 'ascending' itself should be a boolean
  131. """
  132. if is_integer(ascending) or ascending is None:
  133. args = (ascending,) + args
  134. ascending = True
  135. validate_argsort_kind(args, kwargs, max_fname_arg_count=3)
  136. ascending = cast(bool, ascending)
  137. return ascending
  138. CLIP_DEFAULTS: dict[str, Any] = {"out": None}
  139. validate_clip = CompatValidator(
  140. CLIP_DEFAULTS, fname="clip", method="both", max_fname_arg_count=3
  141. )
  142. @overload
  143. def validate_clip_with_axis(axis: ndarray, args, kwargs) -> None:
  144. ...
  145. @overload
  146. def validate_clip_with_axis(axis: AxisNoneT, args, kwargs) -> AxisNoneT:
  147. ...
  148. def validate_clip_with_axis(
  149. axis: ndarray | AxisNoneT, args, kwargs
  150. ) -> AxisNoneT | None:
  151. """
  152. If 'NDFrame.clip' is called via the numpy library, the third parameter in
  153. its signature is 'out', which can takes an ndarray, so check if the 'axis'
  154. parameter is an instance of ndarray, since 'axis' itself should either be
  155. an integer or None
  156. """
  157. if isinstance(axis, ndarray):
  158. args = (axis,) + args
  159. # error: Incompatible types in assignment (expression has type "None",
  160. # variable has type "Union[ndarray[Any, Any], str, int]")
  161. axis = None # type: ignore[assignment]
  162. validate_clip(args, kwargs)
  163. # error: Incompatible return value type (got "Union[ndarray[Any, Any],
  164. # str, int]", expected "Union[str, int, None]")
  165. return axis # type: ignore[return-value]
  166. CUM_FUNC_DEFAULTS: dict[str, Any] = {}
  167. CUM_FUNC_DEFAULTS["dtype"] = None
  168. CUM_FUNC_DEFAULTS["out"] = None
  169. validate_cum_func = CompatValidator(
  170. CUM_FUNC_DEFAULTS, method="both", max_fname_arg_count=1
  171. )
  172. validate_cumsum = CompatValidator(
  173. CUM_FUNC_DEFAULTS, fname="cumsum", method="both", max_fname_arg_count=1
  174. )
  175. def validate_cum_func_with_skipna(skipna, args, kwargs, name) -> bool:
  176. """
  177. If this function is called via the 'numpy' library, the third parameter in
  178. its signature is 'dtype', which takes either a 'numpy' dtype or 'None', so
  179. check if the 'skipna' parameter is a boolean or not
  180. """
  181. if not is_bool(skipna):
  182. args = (skipna,) + args
  183. skipna = True
  184. validate_cum_func(args, kwargs, fname=name)
  185. return skipna
  186. ALLANY_DEFAULTS: dict[str, bool | None] = {}
  187. ALLANY_DEFAULTS["dtype"] = None
  188. ALLANY_DEFAULTS["out"] = None
  189. ALLANY_DEFAULTS["keepdims"] = False
  190. ALLANY_DEFAULTS["axis"] = None
  191. validate_all = CompatValidator(
  192. ALLANY_DEFAULTS, fname="all", method="both", max_fname_arg_count=1
  193. )
  194. validate_any = CompatValidator(
  195. ALLANY_DEFAULTS, fname="any", method="both", max_fname_arg_count=1
  196. )
  197. LOGICAL_FUNC_DEFAULTS = {"out": None, "keepdims": False}
  198. validate_logical_func = CompatValidator(LOGICAL_FUNC_DEFAULTS, method="kwargs")
  199. MINMAX_DEFAULTS = {"axis": None, "out": None, "keepdims": False}
  200. validate_min = CompatValidator(
  201. MINMAX_DEFAULTS, fname="min", method="both", max_fname_arg_count=1
  202. )
  203. validate_max = CompatValidator(
  204. MINMAX_DEFAULTS, fname="max", method="both", max_fname_arg_count=1
  205. )
  206. RESHAPE_DEFAULTS: dict[str, str] = {"order": "C"}
  207. validate_reshape = CompatValidator(
  208. RESHAPE_DEFAULTS, fname="reshape", method="both", max_fname_arg_count=1
  209. )
  210. REPEAT_DEFAULTS: dict[str, Any] = {"axis": None}
  211. validate_repeat = CompatValidator(
  212. REPEAT_DEFAULTS, fname="repeat", method="both", max_fname_arg_count=1
  213. )
  214. ROUND_DEFAULTS: dict[str, Any] = {"out": None}
  215. validate_round = CompatValidator(
  216. ROUND_DEFAULTS, fname="round", method="both", max_fname_arg_count=1
  217. )
  218. SORT_DEFAULTS: dict[str, int | str | None] = {}
  219. SORT_DEFAULTS["axis"] = -1
  220. SORT_DEFAULTS["kind"] = "quicksort"
  221. SORT_DEFAULTS["order"] = None
  222. validate_sort = CompatValidator(SORT_DEFAULTS, fname="sort", method="kwargs")
  223. STAT_FUNC_DEFAULTS: dict[str, Any | None] = {}
  224. STAT_FUNC_DEFAULTS["dtype"] = None
  225. STAT_FUNC_DEFAULTS["out"] = None
  226. SUM_DEFAULTS = STAT_FUNC_DEFAULTS.copy()
  227. SUM_DEFAULTS["axis"] = None
  228. SUM_DEFAULTS["keepdims"] = False
  229. SUM_DEFAULTS["initial"] = None
  230. PROD_DEFAULTS = STAT_FUNC_DEFAULTS.copy()
  231. PROD_DEFAULTS["axis"] = None
  232. PROD_DEFAULTS["keepdims"] = False
  233. PROD_DEFAULTS["initial"] = None
  234. MEDIAN_DEFAULTS = STAT_FUNC_DEFAULTS.copy()
  235. MEDIAN_DEFAULTS["overwrite_input"] = False
  236. MEDIAN_DEFAULTS["keepdims"] = False
  237. STAT_FUNC_DEFAULTS["keepdims"] = False
  238. validate_stat_func = CompatValidator(STAT_FUNC_DEFAULTS, method="kwargs")
  239. validate_sum = CompatValidator(
  240. SUM_DEFAULTS, fname="sum", method="both", max_fname_arg_count=1
  241. )
  242. validate_prod = CompatValidator(
  243. PROD_DEFAULTS, fname="prod", method="both", max_fname_arg_count=1
  244. )
  245. validate_mean = CompatValidator(
  246. STAT_FUNC_DEFAULTS, fname="mean", method="both", max_fname_arg_count=1
  247. )
  248. validate_median = CompatValidator(
  249. MEDIAN_DEFAULTS, fname="median", method="both", max_fname_arg_count=1
  250. )
  251. STAT_DDOF_FUNC_DEFAULTS: dict[str, bool | None] = {}
  252. STAT_DDOF_FUNC_DEFAULTS["dtype"] = None
  253. STAT_DDOF_FUNC_DEFAULTS["out"] = None
  254. STAT_DDOF_FUNC_DEFAULTS["keepdims"] = False
  255. validate_stat_ddof_func = CompatValidator(STAT_DDOF_FUNC_DEFAULTS, method="kwargs")
  256. TAKE_DEFAULTS: dict[str, str | None] = {}
  257. TAKE_DEFAULTS["out"] = None
  258. TAKE_DEFAULTS["mode"] = "raise"
  259. validate_take = CompatValidator(TAKE_DEFAULTS, fname="take", method="kwargs")
  260. def validate_take_with_convert(convert: ndarray | bool | None, args, kwargs) -> bool:
  261. """
  262. If this function is called via the 'numpy' library, the third parameter in
  263. its signature is 'axis', which takes either an ndarray or 'None', so check
  264. if the 'convert' parameter is either an instance of ndarray or is None
  265. """
  266. if isinstance(convert, ndarray) or convert is None:
  267. args = (convert,) + args
  268. convert = True
  269. validate_take(args, kwargs, max_fname_arg_count=3, method="both")
  270. return convert
  271. TRANSPOSE_DEFAULTS = {"axes": None}
  272. validate_transpose = CompatValidator(
  273. TRANSPOSE_DEFAULTS, fname="transpose", method="both", max_fname_arg_count=0
  274. )
  275. def validate_groupby_func(name, args, kwargs, allowed=None) -> None:
  276. """
  277. 'args' and 'kwargs' should be empty, except for allowed kwargs because all
  278. of their necessary parameters are explicitly listed in the function
  279. signature
  280. """
  281. if allowed is None:
  282. allowed = []
  283. kwargs = set(kwargs) - set(allowed)
  284. if len(args) + len(kwargs) > 0:
  285. raise UnsupportedFunctionCall(
  286. "numpy operations are not valid with groupby. "
  287. f"Use .groupby(...).{name}() instead"
  288. )
  289. RESAMPLER_NUMPY_OPS = ("min", "max", "sum", "prod", "mean", "std", "var")
  290. def validate_resampler_func(method: str, args, kwargs) -> None:
  291. """
  292. 'args' and 'kwargs' should be empty because all of their necessary
  293. parameters are explicitly listed in the function signature
  294. """
  295. if len(args) + len(kwargs) > 0:
  296. if method in RESAMPLER_NUMPY_OPS:
  297. raise UnsupportedFunctionCall(
  298. "numpy operations are not valid with resample. "
  299. f"Use .resample(...).{method}() instead"
  300. )
  301. raise TypeError("too many arguments passed in")
  302. def validate_minmax_axis(axis: AxisInt | None, ndim: int = 1) -> None:
  303. """
  304. Ensure that the axis argument passed to min, max, argmin, or argmax is zero
  305. or None, as otherwise it will be incorrectly ignored.
  306. Parameters
  307. ----------
  308. axis : int or None
  309. ndim : int, default 1
  310. Raises
  311. ------
  312. ValueError
  313. """
  314. if axis is None:
  315. return
  316. if axis >= ndim or (axis < 0 and ndim + axis < 0):
  317. raise ValueError(f"`axis` must be fewer than the number of dimensions ({ndim})")