array.py 2.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384
  1. """
  2. Test extension array that has custom attribute information (not stored on the dtype).
  3. """
  4. from __future__ import annotations
  5. import numbers
  6. import numpy as np
  7. from pandas._typing import type_t
  8. from pandas.core.dtypes.base import ExtensionDtype
  9. import pandas as pd
  10. from pandas.core.arrays import ExtensionArray
  11. class FloatAttrDtype(ExtensionDtype):
  12. type = float
  13. name = "float_attr"
  14. na_value = np.nan
  15. @classmethod
  16. def construct_array_type(cls) -> type_t[FloatAttrArray]:
  17. """
  18. Return the array type associated with this dtype.
  19. Returns
  20. -------
  21. type
  22. """
  23. return FloatAttrArray
  24. class FloatAttrArray(ExtensionArray):
  25. dtype = FloatAttrDtype()
  26. __array_priority__ = 1000
  27. def __init__(self, values, attr=None) -> None:
  28. if not isinstance(values, np.ndarray):
  29. raise TypeError("Need to pass a numpy array of float64 dtype as values")
  30. if not values.dtype == "float64":
  31. raise TypeError("Need to pass a numpy array of float64 dtype as values")
  32. self.data = values
  33. self.attr = attr
  34. @classmethod
  35. def _from_sequence(cls, scalars, dtype=None, copy=False):
  36. data = np.array(scalars, dtype="float64", copy=copy)
  37. return cls(data)
  38. def __getitem__(self, item):
  39. if isinstance(item, numbers.Integral):
  40. return self.data[item]
  41. else:
  42. # slice, list-like, mask
  43. item = pd.api.indexers.check_array_indexer(self, item)
  44. return type(self)(self.data[item], self.attr)
  45. def __len__(self) -> int:
  46. return len(self.data)
  47. def isna(self):
  48. return np.isnan(self.data)
  49. def take(self, indexer, allow_fill=False, fill_value=None):
  50. from pandas.api.extensions import take
  51. data = self.data
  52. if allow_fill and fill_value is None:
  53. fill_value = self.dtype.na_value
  54. result = take(data, indexer, fill_value=fill_value, allow_fill=allow_fill)
  55. return type(self)(result, self.attr)
  56. def copy(self):
  57. return type(self)(self.data.copy(), self.attr)
  58. @classmethod
  59. def _concat_same_type(cls, to_concat):
  60. data = np.concatenate([x.data for x in to_concat])
  61. attr = to_concat[0].attr if len(to_concat) else None
  62. return cls(data, attr)