_methods.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297
"""
Array methods which are called by both the C code for the method
and the Python code for the NumPy-namespace function.
"""
  5. import warnings
  6. from contextlib import nullcontext
  7. from numpy.core import multiarray as mu
  8. from numpy.core import umath as um
  9. from numpy.core.multiarray import asanyarray
  10. from numpy.core import numerictypes as nt
  11. from numpy.core import _exceptions
  12. from numpy.core._ufunc_config import _no_nep50_warning
  13. from numpy._globals import _NoValue
  14. from numpy.compat import pickle, os_fspath
  # Cached ufunc reduce methods -- save those O(100) nanoseconds!
# Each alias is the bound ``reduce`` method of a core ufunc, looked up once
# at import time so the reduction wrappers skip the repeated attribute access.

# Extrema.
umr_maximum = um.maximum.reduce
umr_minimum = um.minimum.reduce
# Arithmetic.
umr_sum = um.add.reduce
umr_prod = um.multiply.reduce
# Logical.
umr_any = um.logical_or.reduce
umr_all = um.logical_and.reduce
  # Complex dtype -> dtype of one (real or imaginary) component.
# _var() uses this table to view built-in complex data as trailing
# (2,)-shaped float pairs on its fast path.
_complex_to_float = {
    nt.dtype(nt.csingle): nt.dtype(nt.single),
    nt.dtype(nt.cdouble): nt.dtype(nt.double),
}
# Register clongdouble only when longdouble is a distinct type.  Where the
# two are the same (the original note calls out Windows), the cdouble -> double
# entry above is left as the one in effect.
if nt.dtype(nt.longdouble) != nt.dtype(nt.double):
    _complex_to_float[nt.dtype(nt.clongdouble)] = nt.dtype(nt.longdouble)
  # Reduction wrappers pass arguments to ufunc.reduce positionally on purpose.
# Keyword parsing is comparatively slow, and avoiding it saves about
# 15%-20% for very small reductions.
def _amax(a, axis=None, out=None, keepdims=False,
          initial=_NoValue, where=True):
    """Maximum of ``a`` along ``axis`` using ``umr_maximum``.

    There is no ``dtype`` parameter; its positional slot in the reduce call
    is filled with ``None``.
    """
    return umr_maximum(a, axis, None, out, keepdims, initial, where)
  # Minimum reduction wrapper.
def _amin(a, axis=None, out=None, keepdims=False,
          initial=_NoValue, where=True):
    """Minimum of ``a`` along ``axis`` using ``umr_minimum``.

    Arguments are forwarded positionally (``dtype`` slot fixed to ``None``)
    to avoid the cost of keyword parsing on small reductions.
    """
    return umr_minimum(a, axis, None, out, keepdims, initial, where)
  # Sum reduction wrapper.
def _sum(a, axis=None, dtype=None, out=None, keepdims=False,
         initial=_NoValue, where=True):
    """Sum of ``a`` along ``axis`` using ``umr_sum``.

    ``dtype``, ``out``, ``keepdims``, ``initial`` and ``where`` are forwarded
    positionally to avoid the cost of keyword parsing.
    """
    return umr_sum(a, axis, dtype, out, keepdims, initial, where)
  # Product reduction wrapper.
def _prod(a, axis=None, dtype=None, out=None, keepdims=False,
          initial=_NoValue, where=True):
    """Product of ``a`` along ``axis`` using ``umr_prod``.

    All options are forwarded positionally to avoid the cost of keyword
    parsing.
    """
    return umr_prod(a, axis, dtype, out, keepdims, initial, where)
  # Logical-or reduction wrapper.
def _any(a, axis=None, dtype=None, out=None, keepdims=False, *, where=True):
    """Logical-or reduction of ``a`` along ``axis`` using ``umr_any``."""
    # Keyword arguments are comparatively slow to parse, so ``where`` is only
    # passed (by keyword) when the caller supplied something other than True.
    if where is not True:
        return umr_any(a, axis, dtype, out, keepdims, where=where)
    return umr_any(a, axis, dtype, out, keepdims)
  # Logical-and reduction wrapper.
def _all(a, axis=None, dtype=None, out=None, keepdims=False, *, where=True):
    """Logical-and reduction of ``a`` along ``axis`` using ``umr_all``."""
    # Keyword arguments are comparatively slow to parse, so ``where`` is only
    # passed (by keyword) when the caller supplied something other than True.
    if where is not True:
        return umr_all(a, axis, dtype, out, keepdims, where=where)
    return umr_all(a, axis, dtype, out, keepdims)
  # Element counting used by the mean/variance reductions.
def _count_reduce_items(arr, axis, keepdims=False, where=True):
    """Return how many elements of ``arr`` a reduction over ``axis`` covers.

    Without a mask (``where is True``) the count is the product of the
    lengths of the reduced axes (all axes when ``axis`` is None), returned
    as an ``intp`` scalar; ``keepdims`` is ignored in that case.  With a
    mask, ``where`` is broadcast to ``arr.shape`` and summed along ``axis``
    with ``intp`` accumulation (honouring ``keepdims``), so the count may
    differ per output element.
    """
    if where is not True:
        # TODO: Optimize case when `where` is broadcast along a non-reduction
        # axis and full sum is more excessive than needed.
        # guarded to protect circular imports
        from numpy.lib.stride_tricks import broadcast_to
        mask = broadcast_to(where, arr.shape)
        # Count the True entries of the (possibly broadcast) mask.
        return umr_sum(mask, axis, nt.intp, None, keepdims)

    # Fast path: the count depends only on which axes are reduced.
    if axis is None:
        axes = range(arr.ndim)
    elif isinstance(axis, tuple):
        axes = axis
    else:
        axes = (axis,)
    count = 1
    for ax in axes:
        # normalize_axis_index resolves negative axes against arr.ndim.
        count *= arr.shape[mu.normalize_axis_index(ax, arr.ndim)]
    return nt.intp(count)
  # Numpy 1.17.0, 2019-02-24
# Various clip behavior deprecations, marked with _clip_dep as a prefix.
def _clip_dep_is_scalar_nan(a):
    """Report whether ``a`` is a 0-d value that ``um.isnan`` flags as NaN.

    Inputs with ``ndim != 0`` give ``False``, as does a 0-d value whose type
    ``um.isnan`` rejects with ``TypeError``.  Otherwise the result of
    ``um.isnan`` is returned as-is (a NumPy boolean, not a Python ``bool``).
    """
    # guarded to protect circular imports
    from numpy.core.fromnumeric import ndim
    if ndim(a) == 0:
        try:
            return um.isnan(a)
        except TypeError:
            return False
    return False
  # Byte-order probe for the clip deprecations.
def _clip_dep_is_byte_swapped(a):
    """Return True iff ``a`` is an ndarray whose dtype is not native-endian.

    Any non-ndarray input (including ``None``) yields ``False``.
    """
    return isinstance(a, mu.ndarray) and not a.dtype.isnative
  # Casting fallback for the clip deprecations.
def _clip_dep_invoke_with_casting(ufunc, *args, out=None, casting=None, **kwargs):
    """Call ``ufunc(*args, out=out, **kwargs)`` with a deprecated casting fallback.

    An explicit ``casting`` is passed straight through.  Without one, the
    ufunc first runs with its default casting.  If that raises
    ``_UFuncOutputCastingError``, a ``DeprecationWarning`` naming the source
    and target dtypes is issued and the call is retried with
    ``casting="unsafe"``.
    """
    if casting is not None:
        return ufunc(*args, out=out, casting=casting, **kwargs)

    try:
        return ufunc(*args, out=out, **kwargs)
    except _exceptions._UFuncOutputCastingError as exc:
        # Numpy 1.17.0, 2019-02-24
        message = (
            "Converting the output of clip from {!r} to {!r} is deprecated. "
            "Pass `casting=\"unsafe\"` explicitly to silence this warning, or "
            "correct the type of the variables.".format(exc.from_, exc.to)
        )
        warnings.warn(message, DeprecationWarning, stacklevel=2)
    # Retry outside the except clause so any new error is not chained to
    # the casting failure.
    return ufunc(*args, out=out, casting="unsafe", **kwargs)
  # clip, including the NumPy 1.17 deprecation handling.
def _clip(a, min=None, max=None, out=None, *, casting=None, **kwargs):
    """Clip ``a`` to the given bounds; at least one bound is required.

    The work is done by ``um.minimum`` (only ``max`` given), ``um.maximum``
    (only ``min`` given) or ``um.clip`` (both given), called through
    ``_clip_dep_invoke_with_casting``.  Unless ``a`` or ``out`` is a
    byte-swapped ndarray, a scalar NaN bound is replaced by the matching
    infinity and a ``DeprecationWarning`` is issued.

    Raises
    ------
    ValueError
        If both ``min`` and ``max`` are None.
    """
    if min is None and max is None:
        raise ValueError("One of max or min must be given")

    # Numpy 1.17.0, 2019-02-24
    # This deprecation probably incurs a substantial slowdown for small arrays,
    # it will be good to get rid of it.
    if not (_clip_dep_is_byte_swapped(a) or _clip_dep_is_byte_swapped(out)):
        nan_lower = _clip_dep_is_scalar_nan(min)
        if nan_lower:
            min = -float('inf')
        nan_upper = _clip_dep_is_scalar_nan(max)
        if nan_upper:
            max = float('inf')
        if nan_lower or nan_upper:
            warnings.warn(
                "Passing `np.nan` to mean no clipping in np.clip has always "
                "been unreliable, and is now deprecated. "
                "In future, this will always return nan, like it already does "
                "when min or max are arrays that contain nan. "
                "To skip a bound, pass either None or an np.inf of an "
                "appropriate sign.",
                DeprecationWarning,
                stacklevel=2
            )

    # Pick the ufunc and its bound arguments from which bounds are present.
    if min is None:
        ufunc, bounds = um.minimum, (max,)
    elif max is None:
        ufunc, bounds = um.maximum, (min,)
    else:
        ufunc, bounds = um.clip, (min, max)
    return _clip_dep_invoke_with_casting(
        ufunc, a, *bounds, out=out, casting=casting, **kwargs)
  # Mean reduction.
def _mean(a, axis=None, dtype=None, out=None, keepdims=False, *, where=True):
    """Arithmetic mean of ``a`` along ``axis``.

    Sums with ``umr_sum`` (honouring ``dtype``, ``out``, ``keepdims`` and
    ``where``) and divides by the number of contributing elements from
    ``_count_reduce_items``.  Without an explicit ``dtype``, bool/int/uint
    input is accumulated in float64 and float16 input in float32.  A
    float32 result from float16 input is converted back to float16 unless
    ``out`` was given.  Emits ``RuntimeWarning("Mean of empty slice.")``
    when a reduction covers zero elements.
    """
    arr = asanyarray(a)
    count = _count_reduce_items(arr, axis, keepdims=keepdims, where=where)

    if where is True:
        empty = count == 0
    else:
        empty = umr_any(count == 0, axis=None)
    if empty:
        warnings.warn("Mean of empty slice.", RuntimeWarning, stacklevel=2)

    # Cast bool, unsigned int, and int to float64 by default; float16 is
    # accumulated in float32.
    from_half = False
    if dtype is None:
        scalar_type = arr.dtype.type
        if issubclass(scalar_type, (nt.integer, nt.bool_)):
            dtype = mu.dtype('f8')
        elif issubclass(scalar_type, nt.float16):
            dtype = mu.dtype('f4')
            from_half = True

    total = umr_sum(arr, axis, dtype, out, keepdims, where=where)

    if isinstance(total, mu.ndarray):
        # Divide in place; 'unsafe' lets the result stay in total's dtype.
        with _no_nep50_warning():
            total = um.true_divide(total, count, out=total, casting='unsafe',
                                   subok=False)
        # A caller-supplied ``out`` keeps its own dtype.
        if from_half and out is None:
            total = arr.dtype.type(total)
        return total
    if hasattr(total, 'dtype'):
        # NumPy scalar: keep a NumPy scalar type (float16 for half input).
        to_type = arr.dtype.type if from_half else total.dtype.type
        return to_type(total / count)
    return total / count
  # Variance reduction.
def _var(a, axis=None, dtype=None, out=None, ddof=0, keepdims=False, *,
         where=True):
    """Variance of ``a`` along ``axis`` with ``ddof`` delta degrees of freedom.

    The mean is taken with ``keepdims=True`` so it broadcasts back against
    the data.  The squared magnitudes of the deviations are then summed
    (into ``out`` when given) and divided by ``max(count - ddof, 0)``, where
    ``count`` comes from ``_count_reduce_items``; with a ``where`` mask it
    can differ per output element.  Without an explicit ``dtype``,
    bool/int/uint input is promoted to float64.  Emits
    ``RuntimeWarning("Degrees of freedom <= 0 for slice")`` when ``ddof`` is
    at least the count.
    """
    arr = asanyarray(a)
    count = _count_reduce_items(arr, axis, keepdims=keepdims, where=where)

    # Make this warning show up on top.
    if where is True:
        no_dof = ddof >= count
    else:
        no_dof = umr_any(ddof >= count, axis=None)
    if no_dof:
        warnings.warn("Degrees of freedom <= 0 for slice", RuntimeWarning,
                      stacklevel=2)

    # Cast bool, unsigned int, and int to float64 by default
    if dtype is None and issubclass(arr.dtype.type, (nt.integer, nt.bool_)):
        dtype = mu.dtype('f8')

    # Mean with the reduced axes kept (length 1).  If dtype is not of
    # inexact type, the mean will not be either.
    arrmean = umr_sum(arr, axis, dtype, keepdims=True, where=where)
    # The divisor must match arrmean's shape so the in-place division cannot
    # change the shape of its output: a 0-d count (no mask) is used as-is,
    # a per-element masked count is reshaped.
    div = count if count.ndim == 0 else count.reshape(arrmean.shape)
    if isinstance(arrmean, mu.ndarray):
        with _no_nep50_warning():
            arrmean = um.true_divide(arrmean, div, out=arrmean,
                                     casting='unsafe', subok=False)
    elif hasattr(arrmean, "dtype"):
        arrmean = arrmean.dtype.type(arrmean / count)
    else:
        arrmean = arrmean / count

    # Squared magnitudes of the deviations.  x may not be inexact, and
    # asanyarray makes sure it is an array (not a scalar) so the in-place
    # ufunc calls below are valid.
    x = asanyarray(arr - arrmean)
    if issubclass(arr.dtype.type, (nt.floating, nt.integer)):
        x = um.multiply(x, x, out=x)
    elif x.dtype in _complex_to_float:
        # Built-in complex: view each value as a (real, imag) float pair,
        # square both in place and add them, giving |x|**2.
        pairs = x.view(dtype=(_complex_to_float[x.dtype], (2,)))
        um.multiply(pairs, pairs, out=pairs)
        x = um.add(pairs[..., 0], pairs[..., 1], out=x.real).real
    else:
        # Most general case; includes object arrays containing imaginary
        # numbers and complex types with non-native byteorder.
        x = um.multiply(x, um.conjugate(x), out=x).real

    ret = umr_sum(x, axis, dtype, out, keepdims=keepdims, where=where)

    # Divide by the degrees of freedom, clipped so they are never negative.
    dof = um.maximum(count - ddof, 0)
    if isinstance(ret, mu.ndarray):
        with _no_nep50_warning():
            return um.true_divide(ret, dof, out=ret, casting='unsafe',
                                  subok=False)
    if hasattr(ret, 'dtype'):
        return ret.dtype.type(ret / dof)
    return ret / dof
  # Standard deviation (square root of the variance).
def _std(a, axis=None, dtype=None, out=None, ddof=0, keepdims=False, *,
         where=True):
    """Standard deviation: the square root of ``_var`` called with the same
    arguments."""
    variance = _var(a, axis=axis, dtype=dtype, out=out, ddof=ddof,
                    keepdims=keepdims, where=where)
    if isinstance(variance, mu.ndarray):
        # Take the root in place, so an ``out`` array receives the result.
        return um.sqrt(variance, out=variance)
    if hasattr(variance, 'dtype'):
        # NumPy scalar: keep its own scalar type.
        return variance.dtype.type(um.sqrt(variance))
    return um.sqrt(variance)
  # Peak-to-peak (max - min) reduction.
def _ptp(a, axis=None, out=None, keepdims=False):
    """Peak-to-peak range ``max(a) - min(a)`` along ``axis``.

    The maximum is reduced into ``out`` (when given).  The minimum is
    reduced into a fresh result and subtracted, with the difference also
    written to ``out``.
    """
    highest = umr_maximum(a, axis, None, out, keepdims)
    lowest = umr_minimum(a, axis, None, None, keepdims)
    return um.subtract(highest, lowest, out)
  # Pickle to a file or file-like object.
def _dump(self, file, protocol=2):
    """Pickle ``self`` into ``file`` using ``protocol`` (2 by default).

    If ``file`` has a ``write`` attribute it is used directly and left open.
    Otherwise it is treated as a path (passed through ``os_fspath``),
    opened in ``"wb"`` mode and closed when done.
    """
    # nullcontext lets a caller-owned stream share the ``with`` below
    # without being closed by it.
    target = (nullcontext(file) if hasattr(file, 'write')
              else open(os_fspath(file), "wb"))
    with target as stream:
        pickle.dump(self, stream, protocol=protocol)
  # Pickle to bytes.
def _dumps(self, protocol=2):
    """Return ``self`` pickled to a ``bytes`` object using ``protocol``
    (2 by default)."""
    return pickle.dumps(self, protocol=protocol)