quantile.py 6.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224
  1. from __future__ import annotations
  2. import numpy as np
  3. from pandas._typing import (
  4. ArrayLike,
  5. Scalar,
  6. npt,
  7. )
  8. from pandas.compat.numpy import np_percentile_argname
  9. from pandas.core.dtypes.missing import (
  10. isna,
  11. na_value_for_dtype,
  12. )
  13. def quantile_compat(
  14. values: ArrayLike, qs: npt.NDArray[np.float64], interpolation: str
  15. ) -> ArrayLike:
  16. """
  17. Compute the quantiles of the given values for each quantile in `qs`.
  18. Parameters
  19. ----------
  20. values : np.ndarray or ExtensionArray
  21. qs : np.ndarray[float64]
  22. interpolation : str
  23. Returns
  24. -------
  25. np.ndarray or ExtensionArray
  26. """
  27. if isinstance(values, np.ndarray):
  28. fill_value = na_value_for_dtype(values.dtype, compat=False)
  29. mask = isna(values)
  30. return quantile_with_mask(values, mask, fill_value, qs, interpolation)
  31. else:
  32. return values._quantile(qs, interpolation)
  33. def quantile_with_mask(
  34. values: np.ndarray,
  35. mask: npt.NDArray[np.bool_],
  36. fill_value,
  37. qs: npt.NDArray[np.float64],
  38. interpolation: str,
  39. ) -> np.ndarray:
  40. """
  41. Compute the quantiles of the given values for each quantile in `qs`.
  42. Parameters
  43. ----------
  44. values : np.ndarray
  45. For ExtensionArray, this is _values_for_factorize()[0]
  46. mask : np.ndarray[bool]
  47. mask = isna(values)
  48. For ExtensionArray, this is computed before calling _value_for_factorize
  49. fill_value : Scalar
  50. The value to interpret fill NA entries with
  51. For ExtensionArray, this is _values_for_factorize()[1]
  52. qs : np.ndarray[float64]
  53. interpolation : str
  54. Type of interpolation
  55. Returns
  56. -------
  57. np.ndarray
  58. Notes
  59. -----
  60. Assumes values is already 2D. For ExtensionArray this means np.atleast_2d
  61. has been called on _values_for_factorize()[0]
  62. Quantile is computed along axis=1.
  63. """
  64. assert values.shape == mask.shape
  65. if values.ndim == 1:
  66. # unsqueeze, operate, re-squeeze
  67. values = np.atleast_2d(values)
  68. mask = np.atleast_2d(mask)
  69. res_values = quantile_with_mask(values, mask, fill_value, qs, interpolation)
  70. return res_values[0]
  71. assert values.ndim == 2
  72. is_empty = values.shape[1] == 0
  73. if is_empty:
  74. # create the array of na_values
  75. # 2d len(values) * len(qs)
  76. flat = np.array([fill_value] * len(qs))
  77. result = np.repeat(flat, len(values)).reshape(len(values), len(qs))
  78. else:
  79. result = _nanpercentile(
  80. values,
  81. qs * 100.0,
  82. na_value=fill_value,
  83. mask=mask,
  84. interpolation=interpolation,
  85. )
  86. result = np.array(result, copy=False)
  87. result = result.T
  88. return result
  89. def _nanpercentile_1d(
  90. values: np.ndarray,
  91. mask: npt.NDArray[np.bool_],
  92. qs: npt.NDArray[np.float64],
  93. na_value: Scalar,
  94. interpolation: str,
  95. ) -> Scalar | np.ndarray:
  96. """
  97. Wrapper for np.percentile that skips missing values, specialized to
  98. 1-dimensional case.
  99. Parameters
  100. ----------
  101. values : array over which to find quantiles
  102. mask : ndarray[bool]
  103. locations in values that should be considered missing
  104. qs : np.ndarray[float64] of quantile indices to find
  105. na_value : scalar
  106. value to return for empty or all-null values
  107. interpolation : str
  108. Returns
  109. -------
  110. quantiles : scalar or array
  111. """
  112. # mask is Union[ExtensionArray, ndarray]
  113. values = values[~mask]
  114. if len(values) == 0:
  115. # Can't pass dtype=values.dtype here bc we might have na_value=np.nan
  116. # with values.dtype=int64 see test_quantile_empty
  117. # equiv: 'np.array([na_value] * len(qs))' but much faster
  118. return np.full(len(qs), na_value)
  119. return np.percentile(
  120. values,
  121. qs,
  122. # error: No overload variant of "percentile" matches argument
  123. # types "ndarray[Any, Any]", "ndarray[Any, dtype[floating[_64Bit]]]"
  124. # , "Dict[str, str]" [call-overload]
  125. **{np_percentile_argname: interpolation}, # type: ignore[call-overload]
  126. )
  127. def _nanpercentile(
  128. values: np.ndarray,
  129. qs: npt.NDArray[np.float64],
  130. *,
  131. na_value,
  132. mask: npt.NDArray[np.bool_],
  133. interpolation: str,
  134. ):
  135. """
  136. Wrapper for np.percentile that skips missing values.
  137. Parameters
  138. ----------
  139. values : np.ndarray[ndim=2] over which to find quantiles
  140. qs : np.ndarray[float64] of quantile indices to find
  141. na_value : scalar
  142. value to return for empty or all-null values
  143. mask : np.ndarray[bool]
  144. locations in values that should be considered missing
  145. interpolation : str
  146. Returns
  147. -------
  148. quantiles : scalar or array
  149. """
  150. if values.dtype.kind in ["m", "M"]:
  151. # need to cast to integer to avoid rounding errors in numpy
  152. result = _nanpercentile(
  153. values.view("i8"),
  154. qs=qs,
  155. na_value=na_value.view("i8"),
  156. mask=mask,
  157. interpolation=interpolation,
  158. )
  159. # Note: we have to do `astype` and not view because in general we
  160. # have float result at this point, not i8
  161. return result.astype(values.dtype)
  162. if mask.any():
  163. # Caller is responsible for ensuring mask shape match
  164. assert mask.shape == values.shape
  165. result = [
  166. _nanpercentile_1d(val, m, qs, na_value, interpolation=interpolation)
  167. for (val, m) in zip(list(values), list(mask))
  168. ]
  169. if values.dtype.kind == "f":
  170. # preserve itemsize
  171. result = np.array(result, dtype=values.dtype, copy=False).T
  172. else:
  173. result = np.array(result, copy=False).T
  174. if (
  175. result.dtype != values.dtype
  176. and not mask.all()
  177. and (result == result.astype(values.dtype, copy=False)).all()
  178. ):
  179. # mask.all() will never get cast back to int
  180. # e.g. values id integer dtype and result is floating dtype,
  181. # only cast back to integer dtype if result values are all-integer.
  182. result = result.astype(values.dtype, copy=False)
  183. return result
  184. else:
  185. return np.percentile(
  186. values,
  187. qs,
  188. axis=1,
  189. # error: No overload variant of "percentile" matches argument types
  190. # "ndarray[Any, Any]", "ndarray[Any, dtype[floating[_64Bit]]]",
  191. # "int", "Dict[str, str]" [call-overload]
  192. **{np_percentile_argname: interpolation}, # type: ignore[call-overload]
  193. )