extension.py 5.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192
  1. """
  2. Shared methods for Index subclasses backed by ExtensionArray.
  3. """
  4. from __future__ import annotations
  5. from typing import (
  6. TYPE_CHECKING,
  7. Callable,
  8. TypeVar,
  9. )
  10. import numpy as np
  11. from pandas._typing import (
  12. ArrayLike,
  13. npt,
  14. )
  15. from pandas.util._decorators import (
  16. cache_readonly,
  17. doc,
  18. )
  19. from pandas.core.dtypes.generic import ABCDataFrame
  20. from pandas.core.indexes.base import Index
  21. if TYPE_CHECKING:
  22. from pandas.core.arrays import IntervalArray
  23. from pandas.core.arrays._mixins import NDArrayBackedExtensionArray
# Type variable bound to NDArrayBackedExtensionIndex (forward reference, the
# class is defined further down in this module).  It is not referenced by any
# code visible in this part of the file.
_T = TypeVar("_T", bound="NDArrayBackedExtensionIndex")
# Type variable bound to ExtensionIndex; used in the signature of the
# ``inherit_names`` class decorator so the decorated class keeps its own type.
_ExtensionIndexT = TypeVar("_ExtensionIndexT", bound="ExtensionIndex")
  26. def _inherit_from_data(
  27. name: str, delegate: type, cache: bool = False, wrap: bool = False
  28. ):
  29. """
  30. Make an alias for a method of the underlying ExtensionArray.
  31. Parameters
  32. ----------
  33. name : str
  34. Name of an attribute the class should inherit from its EA parent.
  35. delegate : class
  36. cache : bool, default False
  37. Whether to convert wrapped properties into cache_readonly
  38. wrap : bool, default False
  39. Whether to wrap the inherited result in an Index.
  40. Returns
  41. -------
  42. attribute, method, property, or cache_readonly
  43. """
  44. attr = getattr(delegate, name)
  45. if isinstance(attr, property) or type(attr).__name__ == "getset_descriptor":
  46. # getset_descriptor i.e. property defined in cython class
  47. if cache:
  48. def cached(self):
  49. return getattr(self._data, name)
  50. cached.__name__ = name
  51. cached.__doc__ = attr.__doc__
  52. method = cache_readonly(cached)
  53. else:
  54. def fget(self):
  55. result = getattr(self._data, name)
  56. if wrap:
  57. if isinstance(result, type(self._data)):
  58. return type(self)._simple_new(result, name=self.name)
  59. elif isinstance(result, ABCDataFrame):
  60. return result.set_index(self)
  61. return Index(result, name=self.name)
  62. return result
  63. def fset(self, value) -> None:
  64. setattr(self._data, name, value)
  65. fget.__name__ = name
  66. fget.__doc__ = attr.__doc__
  67. method = property(fget, fset)
  68. elif not callable(attr):
  69. # just a normal attribute, no wrapping
  70. method = attr
  71. else:
  72. # error: Incompatible redefinition (redefinition with type "Callable[[Any,
  73. # VarArg(Any), KwArg(Any)], Any]", original type "property")
  74. def method(self, *args, **kwargs): # type: ignore[misc]
  75. if "inplace" in kwargs:
  76. raise ValueError(f"cannot use inplace with {type(self).__name__}")
  77. result = attr(self._data, *args, **kwargs)
  78. if wrap:
  79. if isinstance(result, type(self._data)):
  80. return type(self)._simple_new(result, name=self.name)
  81. elif isinstance(result, ABCDataFrame):
  82. return result.set_index(self)
  83. return Index(result, name=self.name)
  84. return result
  85. # error: "property" has no attribute "__name__"
  86. method.__name__ = name # type: ignore[attr-defined]
  87. method.__doc__ = attr.__doc__
  88. return method
  89. def inherit_names(
  90. names: list[str], delegate: type, cache: bool = False, wrap: bool = False
  91. ) -> Callable[[type[_ExtensionIndexT]], type[_ExtensionIndexT]]:
  92. """
  93. Class decorator to pin attributes from an ExtensionArray to a Index subclass.
  94. Parameters
  95. ----------
  96. names : List[str]
  97. delegate : class
  98. cache : bool, default False
  99. wrap : bool, default False
  100. Whether to wrap the inherited result in an Index.
  101. """
  102. def wrapper(cls: type[_ExtensionIndexT]) -> type[_ExtensionIndexT]:
  103. for name in names:
  104. meth = _inherit_from_data(name, delegate, cache=cache, wrap=wrap)
  105. setattr(cls, name, meth)
  106. return cls
  107. return wrapper
  108. class ExtensionIndex(Index):
  109. """
  110. Index subclass for indexes backed by ExtensionArray.
  111. """
  112. # The base class already passes through to _data:
  113. # size, __len__, dtype
  114. _data: IntervalArray | NDArrayBackedExtensionArray
  115. # ---------------------------------------------------------------------
  116. def _validate_fill_value(self, value):
  117. """
  118. Convert value to be insertable to underlying array.
  119. """
  120. return self._data._validate_setitem_value(value)
  121. @doc(Index.map)
  122. def map(self, mapper, na_action=None):
  123. # Try to run function on index first, and then on elements of index
  124. # Especially important for group-by functionality
  125. try:
  126. result = mapper(self)
  127. # Try to use this result if we can
  128. if isinstance(result, np.ndarray):
  129. result = Index(result)
  130. if not isinstance(result, Index):
  131. raise TypeError("The map function must return an Index object")
  132. return result
  133. except Exception:
  134. return self.astype(object).map(mapper)
  135. @cache_readonly
  136. def _isnan(self) -> npt.NDArray[np.bool_]:
  137. # error: Incompatible return value type (got "ExtensionArray", expected
  138. # "ndarray")
  139. return self._data.isna() # type: ignore[return-value]
  140. class NDArrayBackedExtensionIndex(ExtensionIndex):
  141. """
  142. Index subclass for indexes backed by NDArrayBackedExtensionArray.
  143. """
  144. _data: NDArrayBackedExtensionArray
  145. def _get_engine_target(self) -> np.ndarray:
  146. return self._data._ndarray
  147. def _from_join_target(self, result: np.ndarray) -> ArrayLike:
  148. assert result.dtype == self._data._ndarray.dtype
  149. return self._data._from_backing_data(result)