# string_arrow.py (13 KB)

# Source listing: 412 lines (the fused line-number gutter has been removed).
  1. from __future__ import annotations
  2. import re
  3. from typing import (
  4. Callable,
  5. Union,
  6. )
  7. import numpy as np
  8. from pandas._libs import (
  9. lib,
  10. missing as libmissing,
  11. )
  12. from pandas._typing import (
  13. Dtype,
  14. Scalar,
  15. npt,
  16. )
  17. from pandas.compat import pa_version_under7p0
  18. from pandas.core.dtypes.common import (
  19. is_bool_dtype,
  20. is_dtype_equal,
  21. is_integer_dtype,
  22. is_object_dtype,
  23. is_scalar,
  24. is_string_dtype,
  25. pandas_dtype,
  26. )
  27. from pandas.core.dtypes.missing import isna
  28. from pandas.core.arrays.arrow import ArrowExtensionArray
  29. from pandas.core.arrays.boolean import BooleanDtype
  30. from pandas.core.arrays.integer import Int64Dtype
  31. from pandas.core.arrays.numeric import NumericDtype
  32. from pandas.core.arrays.string_ import (
  33. BaseStringArray,
  34. StringDtype,
  35. )
  36. from pandas.core.strings.object_array import ObjectStringArrayMixin
  37. if not pa_version_under7p0:
  38. import pyarrow as pa
  39. import pyarrow.compute as pc
  40. from pandas.core.arrays.arrow._arrow_utils import fallback_performancewarning
  41. ArrowStringScalarOrNAT = Union[str, libmissing.NAType]
  42. def _chk_pyarrow_available() -> None:
  43. if pa_version_under7p0:
  44. msg = "pyarrow>=7.0.0 is required for PyArrow backed ArrowExtensionArray."
  45. raise ImportError(msg)
  46. # TODO: Inherit directly from BaseStringArrayMethods. Currently we inherit from
  47. # ObjectStringArrayMixin because we want to have the object-dtype based methods as
  48. # fallback for the ones that pyarrow doesn't yet support
  49. class ArrowStringArray(ObjectStringArrayMixin, ArrowExtensionArray, BaseStringArray):
  50. """
  51. Extension array for string data in a ``pyarrow.ChunkedArray``.
  52. .. versionadded:: 1.2.0
  53. .. warning::
  54. ArrowStringArray is considered experimental. The implementation and
  55. parts of the API may change without warning.
  56. Parameters
  57. ----------
  58. values : pyarrow.Array or pyarrow.ChunkedArray
  59. The array of data.
  60. Attributes
  61. ----------
  62. None
  63. Methods
  64. -------
  65. None
  66. See Also
  67. --------
  68. :func:`pandas.array`
  69. The recommended function for creating a ArrowStringArray.
  70. Series.str
  71. The string methods are available on Series backed by
  72. a ArrowStringArray.
  73. Notes
  74. -----
  75. ArrowStringArray returns a BooleanArray for comparison methods.
  76. Examples
  77. --------
  78. >>> pd.array(['This is', 'some text', None, 'data.'], dtype="string[pyarrow]")
  79. <ArrowStringArray>
  80. ['This is', 'some text', <NA>, 'data.']
  81. Length: 4, dtype: string
  82. """
  83. # error: Incompatible types in assignment (expression has type "StringDtype",
  84. # base class "ArrowExtensionArray" defined the type as "ArrowDtype")
  85. _dtype: StringDtype # type: ignore[assignment]
  86. def __init__(self, values) -> None:
  87. super().__init__(values)
  88. self._dtype = StringDtype(storage="pyarrow")
  89. if not pa.types.is_string(self._data.type):
  90. raise ValueError(
  91. "ArrowStringArray requires a PyArrow (chunked) array of string type"
  92. )
  93. def __len__(self) -> int:
  94. """
  95. Length of this array.
  96. Returns
  97. -------
  98. length : int
  99. """
  100. return len(self._data)
  101. @classmethod
  102. def _from_sequence(cls, scalars, dtype: Dtype | None = None, copy: bool = False):
  103. from pandas.core.arrays.masked import BaseMaskedArray
  104. _chk_pyarrow_available()
  105. if dtype and not (isinstance(dtype, str) and dtype == "string"):
  106. dtype = pandas_dtype(dtype)
  107. assert isinstance(dtype, StringDtype) and dtype.storage == "pyarrow"
  108. if isinstance(scalars, BaseMaskedArray):
  109. # avoid costly conversion to object dtype in ensure_string_array and
  110. # numerical issues with Float32Dtype
  111. na_values = scalars._mask
  112. result = scalars._data
  113. result = lib.ensure_string_array(result, copy=copy, convert_na_value=False)
  114. return cls(pa.array(result, mask=na_values, type=pa.string()))
  115. elif isinstance(scalars, (pa.Array, pa.ChunkedArray)):
  116. return cls(pc.cast(scalars, pa.string()))
  117. # convert non-na-likes to str
  118. result = lib.ensure_string_array(scalars, copy=copy)
  119. return cls(pa.array(result, type=pa.string(), from_pandas=True))
  120. @classmethod
  121. def _from_sequence_of_strings(
  122. cls, strings, dtype: Dtype | None = None, copy: bool = False
  123. ):
  124. return cls._from_sequence(strings, dtype=dtype, copy=copy)
  125. @property
  126. def dtype(self) -> StringDtype: # type: ignore[override]
  127. """
  128. An instance of 'string[pyarrow]'.
  129. """
  130. return self._dtype
  131. def insert(self, loc: int, item) -> ArrowStringArray:
  132. if not isinstance(item, str) and item is not libmissing.NA:
  133. raise TypeError("Scalar must be NA or str")
  134. return super().insert(loc, item)
  135. def _maybe_convert_setitem_value(self, value):
  136. """Maybe convert value to be pyarrow compatible."""
  137. if is_scalar(value):
  138. if isna(value):
  139. value = None
  140. elif not isinstance(value, str):
  141. raise TypeError("Scalar must be NA or str")
  142. else:
  143. value = np.array(value, dtype=object, copy=True)
  144. value[isna(value)] = None
  145. for v in value:
  146. if not (v is None or isinstance(v, str)):
  147. raise TypeError("Scalar must be NA or str")
  148. return super()._maybe_convert_setitem_value(value)
  149. def isin(self, values) -> npt.NDArray[np.bool_]:
  150. value_set = [
  151. pa_scalar.as_py()
  152. for pa_scalar in [pa.scalar(value, from_pandas=True) for value in values]
  153. if pa_scalar.type in (pa.string(), pa.null())
  154. ]
  155. # short-circuit to return all False array.
  156. if not len(value_set):
  157. return np.zeros(len(self), dtype=bool)
  158. result = pc.is_in(self._data, value_set=pa.array(value_set))
  159. # pyarrow 2.0.0 returned nulls, so we explicily specify dtype to convert nulls
  160. # to False
  161. return np.array(result, dtype=np.bool_)
  162. def astype(self, dtype, copy: bool = True):
  163. dtype = pandas_dtype(dtype)
  164. if is_dtype_equal(dtype, self.dtype):
  165. if copy:
  166. return self.copy()
  167. return self
  168. elif isinstance(dtype, NumericDtype):
  169. data = self._data.cast(pa.from_numpy_dtype(dtype.numpy_dtype))
  170. return dtype.__from_arrow__(data)
  171. elif isinstance(dtype, np.dtype) and np.issubdtype(dtype, np.floating):
  172. return self.to_numpy(dtype=dtype, na_value=np.nan)
  173. return super().astype(dtype, copy=copy)
  174. # ------------------------------------------------------------------------
  175. # String methods interface
  176. # error: Incompatible types in assignment (expression has type "NAType",
  177. # base class "ObjectStringArrayMixin" defined the type as "float")
  178. _str_na_value = libmissing.NA # type: ignore[assignment]
  179. def _str_map(
  180. self, f, na_value=None, dtype: Dtype | None = None, convert: bool = True
  181. ):
  182. # TODO: de-duplicate with StringArray method. This method is moreless copy and
  183. # paste.
  184. from pandas.arrays import (
  185. BooleanArray,
  186. IntegerArray,
  187. )
  188. if dtype is None:
  189. dtype = self.dtype
  190. if na_value is None:
  191. na_value = self.dtype.na_value
  192. mask = isna(self)
  193. arr = np.asarray(self)
  194. if is_integer_dtype(dtype) or is_bool_dtype(dtype):
  195. constructor: type[IntegerArray] | type[BooleanArray]
  196. if is_integer_dtype(dtype):
  197. constructor = IntegerArray
  198. else:
  199. constructor = BooleanArray
  200. na_value_is_na = isna(na_value)
  201. if na_value_is_na:
  202. na_value = 1
  203. result = lib.map_infer_mask(
  204. arr,
  205. f,
  206. mask.view("uint8"),
  207. convert=False,
  208. na_value=na_value,
  209. # error: Argument 1 to "dtype" has incompatible type
  210. # "Union[ExtensionDtype, str, dtype[Any], Type[object]]"; expected
  211. # "Type[object]"
  212. dtype=np.dtype(dtype), # type: ignore[arg-type]
  213. )
  214. if not na_value_is_na:
  215. mask[:] = False
  216. return constructor(result, mask)
  217. elif is_string_dtype(dtype) and not is_object_dtype(dtype):
  218. # i.e. StringDtype
  219. result = lib.map_infer_mask(
  220. arr, f, mask.view("uint8"), convert=False, na_value=na_value
  221. )
  222. result = pa.array(result, mask=mask, type=pa.string(), from_pandas=True)
  223. return type(self)(result)
  224. else:
  225. # This is when the result type is object. We reach this when
  226. # -> We know the result type is truly object (e.g. .encode returns bytes
  227. # or .findall returns a list).
  228. # -> We don't know the result type. E.g. `.get` can return anything.
  229. return lib.map_infer_mask(arr, f, mask.view("uint8"))
  230. def _str_contains(
  231. self, pat, case: bool = True, flags: int = 0, na=np.nan, regex: bool = True
  232. ):
  233. if flags:
  234. fallback_performancewarning()
  235. return super()._str_contains(pat, case, flags, na, regex)
  236. if regex:
  237. if case is False:
  238. fallback_performancewarning()
  239. return super()._str_contains(pat, case, flags, na, regex)
  240. else:
  241. result = pc.match_substring_regex(self._data, pat)
  242. else:
  243. if case:
  244. result = pc.match_substring(self._data, pat)
  245. else:
  246. result = pc.match_substring(pc.utf8_upper(self._data), pat.upper())
  247. result = BooleanDtype().__from_arrow__(result)
  248. if not isna(na):
  249. result[isna(result)] = bool(na)
  250. return result
  251. def _str_startswith(self, pat: str, na=None):
  252. pat = f"^{re.escape(pat)}"
  253. return self._str_contains(pat, na=na, regex=True)
  254. def _str_endswith(self, pat: str, na=None):
  255. pat = f"{re.escape(pat)}$"
  256. return self._str_contains(pat, na=na, regex=True)
  257. def _str_replace(
  258. self,
  259. pat: str | re.Pattern,
  260. repl: str | Callable,
  261. n: int = -1,
  262. case: bool = True,
  263. flags: int = 0,
  264. regex: bool = True,
  265. ):
  266. if isinstance(pat, re.Pattern) or callable(repl) or not case or flags:
  267. fallback_performancewarning()
  268. return super()._str_replace(pat, repl, n, case, flags, regex)
  269. func = pc.replace_substring_regex if regex else pc.replace_substring
  270. result = func(self._data, pattern=pat, replacement=repl, max_replacements=n)
  271. return type(self)(result)
  272. def _str_match(
  273. self, pat: str, case: bool = True, flags: int = 0, na: Scalar | None = None
  274. ):
  275. if not pat.startswith("^"):
  276. pat = f"^{pat}"
  277. return self._str_contains(pat, case, flags, na, regex=True)
  278. def _str_fullmatch(
  279. self, pat, case: bool = True, flags: int = 0, na: Scalar | None = None
  280. ):
  281. if not pat.endswith("$") or pat.endswith("//$"):
  282. pat = f"{pat}$"
  283. return self._str_match(pat, case, flags, na)
  284. def _str_isalnum(self):
  285. result = pc.utf8_is_alnum(self._data)
  286. return BooleanDtype().__from_arrow__(result)
  287. def _str_isalpha(self):
  288. result = pc.utf8_is_alpha(self._data)
  289. return BooleanDtype().__from_arrow__(result)
  290. def _str_isdecimal(self):
  291. result = pc.utf8_is_decimal(self._data)
  292. return BooleanDtype().__from_arrow__(result)
  293. def _str_isdigit(self):
  294. result = pc.utf8_is_digit(self._data)
  295. return BooleanDtype().__from_arrow__(result)
  296. def _str_islower(self):
  297. result = pc.utf8_is_lower(self._data)
  298. return BooleanDtype().__from_arrow__(result)
  299. def _str_isnumeric(self):
  300. result = pc.utf8_is_numeric(self._data)
  301. return BooleanDtype().__from_arrow__(result)
  302. def _str_isspace(self):
  303. result = pc.utf8_is_space(self._data)
  304. return BooleanDtype().__from_arrow__(result)
  305. def _str_istitle(self):
  306. result = pc.utf8_is_title(self._data)
  307. return BooleanDtype().__from_arrow__(result)
  308. def _str_isupper(self):
  309. result = pc.utf8_is_upper(self._data)
  310. return BooleanDtype().__from_arrow__(result)
  311. def _str_len(self):
  312. result = pc.utf8_length(self._data)
  313. return Int64Dtype().__from_arrow__(result)
  314. def _str_lower(self):
  315. return type(self)(pc.utf8_lower(self._data))
  316. def _str_upper(self):
  317. return type(self)(pc.utf8_upper(self._data))
  318. def _str_strip(self, to_strip=None):
  319. if to_strip is None:
  320. result = pc.utf8_trim_whitespace(self._data)
  321. else:
  322. result = pc.utf8_trim(self._data, characters=to_strip)
  323. return type(self)(result)
  324. def _str_lstrip(self, to_strip=None):
  325. if to_strip is None:
  326. result = pc.utf8_ltrim_whitespace(self._data)
  327. else:
  328. result = pc.utf8_ltrim(self._data, characters=to_strip)
  329. return type(self)(result)
  330. def _str_rstrip(self, to_strip=None):
  331. if to_strip is None:
  332. result = pc.utf8_rtrim_whitespace(self._data)
  333. else:
  334. result = pc.utf8_rtrim(self._data, characters=to_strip)
  335. return type(self)(result)