_statistical_functions.py 3.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115
  1. from __future__ import annotations
  2. from ._dtypes import (
  3. _floating_dtypes,
  4. _numeric_dtypes,
  5. )
  6. from ._array_object import Array
  7. from ._creation_functions import asarray
  8. from ._dtypes import float32, float64
  9. from typing import TYPE_CHECKING, Optional, Tuple, Union
  10. if TYPE_CHECKING:
  11. from ._typing import Dtype
  12. import numpy as np
  13. def max(
  14. x: Array,
  15. /,
  16. *,
  17. axis: Optional[Union[int, Tuple[int, ...]]] = None,
  18. keepdims: bool = False,
  19. ) -> Array:
  20. if x.dtype not in _numeric_dtypes:
  21. raise TypeError("Only numeric dtypes are allowed in max")
  22. return Array._new(np.max(x._array, axis=axis, keepdims=keepdims))
  23. def mean(
  24. x: Array,
  25. /,
  26. *,
  27. axis: Optional[Union[int, Tuple[int, ...]]] = None,
  28. keepdims: bool = False,
  29. ) -> Array:
  30. if x.dtype not in _floating_dtypes:
  31. raise TypeError("Only floating-point dtypes are allowed in mean")
  32. return Array._new(np.mean(x._array, axis=axis, keepdims=keepdims))
  33. def min(
  34. x: Array,
  35. /,
  36. *,
  37. axis: Optional[Union[int, Tuple[int, ...]]] = None,
  38. keepdims: bool = False,
  39. ) -> Array:
  40. if x.dtype not in _numeric_dtypes:
  41. raise TypeError("Only numeric dtypes are allowed in min")
  42. return Array._new(np.min(x._array, axis=axis, keepdims=keepdims))
  43. def prod(
  44. x: Array,
  45. /,
  46. *,
  47. axis: Optional[Union[int, Tuple[int, ...]]] = None,
  48. dtype: Optional[Dtype] = None,
  49. keepdims: bool = False,
  50. ) -> Array:
  51. if x.dtype not in _numeric_dtypes:
  52. raise TypeError("Only numeric dtypes are allowed in prod")
  53. # Note: sum() and prod() always upcast float32 to float64 for dtype=None
  54. # We need to do so here before computing the product to avoid overflow
  55. if dtype is None and x.dtype == float32:
  56. dtype = float64
  57. return Array._new(np.prod(x._array, dtype=dtype, axis=axis, keepdims=keepdims))
  58. def std(
  59. x: Array,
  60. /,
  61. *,
  62. axis: Optional[Union[int, Tuple[int, ...]]] = None,
  63. correction: Union[int, float] = 0.0,
  64. keepdims: bool = False,
  65. ) -> Array:
  66. # Note: the keyword argument correction is different here
  67. if x.dtype not in _floating_dtypes:
  68. raise TypeError("Only floating-point dtypes are allowed in std")
  69. return Array._new(np.std(x._array, axis=axis, ddof=correction, keepdims=keepdims))
  70. def sum(
  71. x: Array,
  72. /,
  73. *,
  74. axis: Optional[Union[int, Tuple[int, ...]]] = None,
  75. dtype: Optional[Dtype] = None,
  76. keepdims: bool = False,
  77. ) -> Array:
  78. if x.dtype not in _numeric_dtypes:
  79. raise TypeError("Only numeric dtypes are allowed in sum")
  80. # Note: sum() and prod() always upcast integers to (u)int64 and float32 to
  81. # float64 for dtype=None. `np.sum` does that too for integers, but not for
  82. # float32, so we need to special-case it here
  83. if dtype is None and x.dtype == float32:
  84. dtype = float64
  85. return Array._new(np.sum(x._array, axis=axis, dtype=dtype, keepdims=keepdims))
  86. def var(
  87. x: Array,
  88. /,
  89. *,
  90. axis: Optional[Union[int, Tuple[int, ...]]] = None,
  91. correction: Union[int, float] = 0.0,
  92. keepdims: bool = False,
  93. ) -> Array:
  94. # Note: the keyword argument correction is different here
  95. if x.dtype not in _floating_dtypes:
  96. raise TypeError("Only floating-point dtypes are allowed in var")
  97. return Array._new(np.var(x._array, axis=axis, ddof=correction, keepdims=keepdims))