array.py 8.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287
  1. from __future__ import annotations
  2. import decimal
  3. import numbers
  4. import random
  5. import sys
  6. import numpy as np
  7. from pandas._typing import type_t
  8. from pandas.core.dtypes.base import ExtensionDtype
  9. from pandas.core.dtypes.common import (
  10. is_dtype_equal,
  11. is_float,
  12. pandas_dtype,
  13. )
  14. import pandas as pd
  15. from pandas.api.extensions import (
  16. no_default,
  17. register_extension_dtype,
  18. )
  19. from pandas.api.types import (
  20. is_list_like,
  21. is_scalar,
  22. )
  23. from pandas.core import arraylike
  24. from pandas.core.arraylike import OpsMixin
  25. from pandas.core.arrays import (
  26. ExtensionArray,
  27. ExtensionScalarOpsMixin,
  28. )
  29. from pandas.core.indexers import check_array_indexer
  30. @register_extension_dtype
  31. class DecimalDtype(ExtensionDtype):
  32. type = decimal.Decimal
  33. name = "decimal"
  34. na_value = decimal.Decimal("NaN")
  35. _metadata = ("context",)
  36. def __init__(self, context=None) -> None:
  37. self.context = context or decimal.getcontext()
  38. def __repr__(self) -> str:
  39. return f"DecimalDtype(context={self.context})"
  40. @classmethod
  41. def construct_array_type(cls) -> type_t[DecimalArray]:
  42. """
  43. Return the array type associated with this dtype.
  44. Returns
  45. -------
  46. type
  47. """
  48. return DecimalArray
  49. @property
  50. def _is_numeric(self) -> bool:
  51. return True
  52. class DecimalArray(OpsMixin, ExtensionScalarOpsMixin, ExtensionArray):
  53. __array_priority__ = 1000
  54. def __init__(self, values, dtype=None, copy=False, context=None) -> None:
  55. for i, val in enumerate(values):
  56. if is_float(val):
  57. if np.isnan(val):
  58. values[i] = DecimalDtype.na_value
  59. else:
  60. values[i] = DecimalDtype.type(val)
  61. elif not isinstance(val, decimal.Decimal):
  62. raise TypeError("All values must be of type " + str(decimal.Decimal))
  63. values = np.asarray(values, dtype=object)
  64. self._data = values
  65. # Some aliases for common attribute names to ensure pandas supports
  66. # these
  67. self._items = self.data = self._data
  68. # those aliases are currently not working due to assumptions
  69. # in internal code (GH-20735)
  70. # self._values = self.values = self.data
  71. self._dtype = DecimalDtype(context)
  72. @property
  73. def dtype(self):
  74. return self._dtype
  75. @classmethod
  76. def _from_sequence(cls, scalars, dtype=None, copy=False):
  77. return cls(scalars)
  78. @classmethod
  79. def _from_sequence_of_strings(cls, strings, dtype=None, copy=False):
  80. return cls._from_sequence([decimal.Decimal(x) for x in strings], dtype, copy)
  81. @classmethod
  82. def _from_factorized(cls, values, original):
  83. return cls(values)
  84. _HANDLED_TYPES = (decimal.Decimal, numbers.Number, np.ndarray)
  85. def to_numpy(
  86. self,
  87. dtype=None,
  88. copy: bool = False,
  89. na_value: object = no_default,
  90. decimals=None,
  91. ) -> np.ndarray:
  92. result = np.asarray(self, dtype=dtype)
  93. if decimals is not None:
  94. result = np.asarray([round(x, decimals) for x in result])
  95. return result
  96. def __array_ufunc__(self, ufunc: np.ufunc, method: str, *inputs, **kwargs):
  97. #
  98. if not all(
  99. isinstance(t, self._HANDLED_TYPES + (DecimalArray,)) for t in inputs
  100. ):
  101. return NotImplemented
  102. result = arraylike.maybe_dispatch_ufunc_to_dunder_op(
  103. self, ufunc, method, *inputs, **kwargs
  104. )
  105. if result is not NotImplemented:
  106. # e.g. test_array_ufunc_series_scalar_other
  107. return result
  108. if "out" in kwargs:
  109. return arraylike.dispatch_ufunc_with_out(
  110. self, ufunc, method, *inputs, **kwargs
  111. )
  112. inputs = tuple(x._data if isinstance(x, DecimalArray) else x for x in inputs)
  113. result = getattr(ufunc, method)(*inputs, **kwargs)
  114. if method == "reduce":
  115. result = arraylike.dispatch_reduction_ufunc(
  116. self, ufunc, method, *inputs, **kwargs
  117. )
  118. if result is not NotImplemented:
  119. return result
  120. def reconstruct(x):
  121. if isinstance(x, (decimal.Decimal, numbers.Number)):
  122. return x
  123. else:
  124. return DecimalArray._from_sequence(x)
  125. if ufunc.nout > 1:
  126. return tuple(reconstruct(x) for x in result)
  127. else:
  128. return reconstruct(result)
  129. def __getitem__(self, item):
  130. if isinstance(item, numbers.Integral):
  131. return self._data[item]
  132. else:
  133. # array, slice.
  134. item = pd.api.indexers.check_array_indexer(self, item)
  135. return type(self)(self._data[item])
  136. def take(self, indexer, allow_fill=False, fill_value=None):
  137. from pandas.api.extensions import take
  138. data = self._data
  139. if allow_fill and fill_value is None:
  140. fill_value = self.dtype.na_value
  141. result = take(data, indexer, fill_value=fill_value, allow_fill=allow_fill)
  142. return self._from_sequence(result)
  143. def copy(self):
  144. return type(self)(self._data.copy(), dtype=self.dtype)
  145. def astype(self, dtype, copy=True):
  146. if is_dtype_equal(dtype, self._dtype):
  147. if not copy:
  148. return self
  149. dtype = pandas_dtype(dtype)
  150. if isinstance(dtype, type(self.dtype)):
  151. return type(self)(self._data, copy=copy, context=dtype.context)
  152. return super().astype(dtype, copy=copy)
  153. def __setitem__(self, key, value):
  154. if is_list_like(value):
  155. if is_scalar(key):
  156. raise ValueError("setting an array element with a sequence.")
  157. value = [decimal.Decimal(v) for v in value]
  158. else:
  159. value = decimal.Decimal(value)
  160. key = check_array_indexer(self, key)
  161. self._data[key] = value
  162. def __len__(self) -> int:
  163. return len(self._data)
  164. def __contains__(self, item) -> bool | np.bool_:
  165. if not isinstance(item, decimal.Decimal):
  166. return False
  167. elif item.is_nan():
  168. return self.isna().any()
  169. else:
  170. return super().__contains__(item)
  171. @property
  172. def nbytes(self) -> int:
  173. n = len(self)
  174. if n:
  175. return n * sys.getsizeof(self[0])
  176. return 0
  177. def isna(self):
  178. return np.array([x.is_nan() for x in self._data], dtype=bool)
  179. @property
  180. def _na_value(self):
  181. return decimal.Decimal("NaN")
  182. def _formatter(self, boxed=False):
  183. if boxed:
  184. return "Decimal: {}".format
  185. return repr
  186. @classmethod
  187. def _concat_same_type(cls, to_concat):
  188. return cls(np.concatenate([x._data for x in to_concat]))
  189. def _reduce(self, name: str, *, skipna: bool = True, **kwargs):
  190. if skipna:
  191. # If we don't have any NAs, we can ignore skipna
  192. if self.isna().any():
  193. other = self[~self.isna()]
  194. return other._reduce(name, **kwargs)
  195. if name == "sum" and len(self) == 0:
  196. # GH#29630 avoid returning int 0 or np.bool_(False) on old numpy
  197. return decimal.Decimal(0)
  198. try:
  199. op = getattr(self.data, name)
  200. except AttributeError as err:
  201. raise NotImplementedError(
  202. f"decimal does not support the {name} operation"
  203. ) from err
  204. return op(axis=0)
  205. def _cmp_method(self, other, op):
  206. # For use with OpsMixin
  207. def convert_values(param):
  208. if isinstance(param, ExtensionArray) or is_list_like(param):
  209. ovalues = param
  210. else:
  211. # Assume it's an object
  212. ovalues = [param] * len(self)
  213. return ovalues
  214. lvalues = self
  215. rvalues = convert_values(other)
  216. # If the operator is not defined for the underlying objects,
  217. # a TypeError should be raised
  218. res = [op(a, b) for (a, b) in zip(lvalues, rvalues)]
  219. return np.asarray(res, dtype=bool)
  220. def value_counts(self, dropna: bool = True):
  221. from pandas.core.algorithms import value_counts
  222. return value_counts(self.to_numpy(), dropna=dropna)
  223. def to_decimal(values, context=None):
  224. return DecimalArray([decimal.Decimal(x) for x in values], context=context)
  225. def make_data():
  226. return [decimal.Decimal(random.random()) for _ in range(100)]
# Module-level call: runs once when this module is imported.
# NOTE(review): judging by the name, this presumably installs the arithmetic
# operator methods supplied through ExtensionScalarOpsMixin -- confirm against
# the pandas extension-array documentation.
DecimalArray._add_arithmetic_ops()