test_arrow.py 92 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222222322242225222622272228222922302231223222332234223522362237223822392240224122422243224422452246224722482249225022512252225322542255225622572258225922602261226222632264226522662267226822692270227122722273227422752276227
72278227922802281228222832284228522862287228822892290229122922293229422952296229722982299230023012302230323042305230623072308230923102311231223132314231523162317231823192320232123222323232423252326232723282329233023312332233323342335233623372338233923402341234223432344234523462347234823492350235123522353235423552356235723582359236023612362236323642365236623672368236923702371237223732374237523762377237823792380238123822383238423852386238723882389239023912392239323942395239623972398239924002401240224032404240524062407240824092410241124122413241424152416241724182419242024212422242324242425242624272428242924302431243224332434243524362437243824392440244124422443244424452446244724482449245024512452245324542455245624572458245924602461246224632464246524662467246824692470247124722473247424752476247724782479248024812482248324842485248624872488248924902491249224932494249524962497249824992500250125022503250425052506250725082509251025112512251325142515251625172518251925202521252225232524252525262527252825292530253125322533253425352536253725382539254025412542254325442545254625472548254925502551255225532554255525562557255825592560256125622563256425652566256725682569257025712572257325742575257625772578257925802581258225832584258525862587258825892590259125922593259425952596259725982599260026012602260326042605260626072608260926102611261226132614261526162617261826192620262126222623262426252626262726282629263026312632263326342635263626372638263926402641264226432644264526462647264826492650265126522653265426552656265726582659266026612662266326642665266626672668266926702671267226732674267526762677267826792680268126822683268426852686268726882689269026912692269326942695269626972698269927002701270227032704270527062707270827092710271127122713271427152716271727182719272027212722272327242725272627272728272927302731273227332734273527362737273827392740274127422743274427452746274727482749
  1. """
  2. This file contains a minimal set of tests for compliance with the extension
  3. array interface test suite, and should contain no other tests.
  4. The test suite for the full functionality of the array is located in
  5. `pandas/tests/arrays/`.
  6. The tests in this file are inherited from the BaseExtensionTests, and only
  7. minimal tweaks should be applied to get the tests passing (by overwriting a
  8. parent method).
  9. Additional tests should either be added to one of the BaseExtensionTests
  10. classes (if they are relevant for the extension interface for all dtypes), or
  11. be added to the array-specific tests in `pandas/tests/arrays/`.
  12. """
  13. from datetime import (
  14. date,
  15. datetime,
  16. time,
  17. timedelta,
  18. )
  19. from decimal import Decimal
  20. from io import (
  21. BytesIO,
  22. StringIO,
  23. )
  24. import operator
  25. import pickle
  26. import re
  27. import numpy as np
  28. import pytest
  29. from pandas._libs import lib
  30. from pandas.compat import (
  31. PY311,
  32. is_ci_environment,
  33. is_platform_windows,
  34. pa_version_under7p0,
  35. pa_version_under8p0,
  36. pa_version_under9p0,
  37. pa_version_under11p0,
  38. )
  39. from pandas.errors import PerformanceWarning
  40. from pandas.core.dtypes.common import is_any_int_dtype
  41. from pandas.core.dtypes.dtypes import CategoricalDtypeType
  42. import pandas as pd
  43. import pandas._testing as tm
  44. from pandas.api.extensions import no_default
  45. from pandas.api.types import (
  46. is_bool_dtype,
  47. is_float_dtype,
  48. is_integer_dtype,
  49. is_numeric_dtype,
  50. is_signed_integer_dtype,
  51. is_string_dtype,
  52. is_unsigned_integer_dtype,
  53. )
  54. from pandas.tests.extension import base
  55. pa = pytest.importorskip("pyarrow", minversion="7.0.0")
  56. from pandas.core.arrays.arrow.array import ArrowExtensionArray
  57. from pandas.core.arrays.arrow.dtype import ArrowDtype # isort:skip
@pytest.fixture(params=tm.ALL_PYARROW_DTYPES, ids=str)
def dtype(request):
    """Parametrized fixture wrapping each supported pyarrow type in an ArrowDtype."""
    return ArrowDtype(pyarrow_dtype=request.param)
@pytest.fixture
def data(dtype):
    """
    Length-100 array of the fixture's dtype with missing values at
    positions 8 and 97 (pattern: 8 valid, NA, 88 valid, NA, 2 valid).

    Each supported pyarrow type family gets representative scalar values;
    unsupported families raise NotImplementedError.
    """
    pa_dtype = dtype.pyarrow_dtype
    if pa.types.is_boolean(pa_dtype):
        data = [True, False] * 4 + [None] + [True, False] * 44 + [None] + [True, False]
    elif pa.types.is_floating(pa_dtype):
        data = [1.0, 0.0] * 4 + [None] + [-2.0, -1.0] * 44 + [None] + [0.5, 99.5]
    elif pa.types.is_signed_integer(pa_dtype):
        data = [1, 0] * 4 + [None] + [-2, -1] * 44 + [None] + [1, 99]
    elif pa.types.is_unsigned_integer(pa_dtype):
        # Same as the signed case but without negative values.
        data = [1, 0] * 4 + [None] + [2, 1] * 44 + [None] + [1, 99]
    elif pa.types.is_decimal(pa_dtype):
        data = (
            [Decimal("1"), Decimal("0.0")] * 4
            + [None]
            + [Decimal("-2.0"), Decimal("-1.0")] * 44
            + [None]
            + [Decimal("0.5"), Decimal("33.123")]
        )
    elif pa.types.is_date(pa_dtype):
        data = (
            [date(2022, 1, 1), date(1999, 12, 31)] * 4
            + [None]
            + [date(2022, 1, 1), date(2022, 1, 1)] * 44
            + [None]
            + [date(1999, 12, 31), date(1999, 12, 31)]
        )
    elif pa.types.is_timestamp(pa_dtype):
        data = (
            [datetime(2020, 1, 1, 1, 1, 1, 1), datetime(1999, 1, 1, 1, 1, 1, 1)] * 4
            + [None]
            + [datetime(2020, 1, 1, 1), datetime(1999, 1, 1, 1)] * 44
            + [None]
            + [datetime(2020, 1, 1), datetime(1999, 1, 1)]
        )
    elif pa.types.is_duration(pa_dtype):
        data = (
            [timedelta(1), timedelta(1, 1)] * 4
            + [None]
            + [timedelta(-1), timedelta(0)] * 44
            + [None]
            + [timedelta(-10), timedelta(10)]
        )
    elif pa.types.is_time(pa_dtype):
        data = (
            [time(12, 0), time(0, 12)] * 4
            + [None]
            + [time(0, 0), time(1, 1)] * 44
            + [None]
            + [time(0, 5), time(5, 0)]
        )
    elif pa.types.is_string(pa_dtype):
        data = ["a", "b"] * 4 + [None] + ["1", "2"] * 44 + [None] + ["!", ">"]
    elif pa.types.is_binary(pa_dtype):
        data = [b"a", b"b"] * 4 + [None] + [b"1", b"2"] * 44 + [None] + [b"!", b">"]
    else:
        raise NotImplementedError
    return pd.array(data, dtype=dtype)
@pytest.fixture
def data_missing(data):
    """Length-2 array with [NA, Valid]"""
    # Reuse the first valid element of ``data`` so the dtype matches exactly.
    return type(data)._from_sequence([None, data[0]], dtype=data.dtype)
  123. @pytest.fixture(params=["data", "data_missing"])
  124. def all_data(request, data, data_missing):
  125. """Parametrized fixture returning 'data' or 'data_missing' integer arrays.
  126. Used to test dtype conversion with and without missing values.
  127. """
  128. if request.param == "data":
  129. return data
  130. elif request.param == "data_missing":
  131. return data_missing
@pytest.fixture
def data_for_grouping(dtype):
    """
    Data for factorization, grouping, and unique tests.

    Expected to be like [B, B, NA, NA, A, A, B, C]
    Where A < B < C and NA is missing
    """
    pa_dtype = dtype.pyarrow_dtype
    if pa.types.is_boolean(pa_dtype):
        # Booleans only have two distinct values, so C coincides with B here;
        # several grouping tests xfail for this dtype because of that.
        A = False
        B = True
        C = True
    elif pa.types.is_floating(pa_dtype):
        A = -1.1
        B = 0.0
        C = 1.1
    elif pa.types.is_signed_integer(pa_dtype):
        A = -1
        B = 0
        C = 1
    elif pa.types.is_unsigned_integer(pa_dtype):
        A = 0
        B = 1
        C = 10
    elif pa.types.is_date(pa_dtype):
        A = date(1999, 12, 31)
        B = date(2010, 1, 1)
        C = date(2022, 1, 1)
    elif pa.types.is_timestamp(pa_dtype):
        A = datetime(1999, 1, 1, 1, 1, 1, 1)
        B = datetime(2020, 1, 1)
        C = datetime(2020, 1, 1, 1)
    elif pa.types.is_duration(pa_dtype):
        A = timedelta(-1)
        B = timedelta(0)
        C = timedelta(1, 4)
    elif pa.types.is_time(pa_dtype):
        A = time(0, 0)
        B = time(0, 12)
        C = time(12, 12)
    elif pa.types.is_string(pa_dtype):
        A = "a"
        B = "b"
        C = "c"
    elif pa.types.is_binary(pa_dtype):
        A = b"a"
        B = b"b"
        C = b"c"
    elif pa.types.is_decimal(pa_dtype):
        A = Decimal("-1.1")
        B = Decimal("0.0")
        C = Decimal("1.1")
    else:
        raise NotImplementedError
    return pd.array([B, B, None, None, A, A, B, C], dtype=dtype)
@pytest.fixture
def data_for_sorting(data_for_grouping):
    """
    Length-3 array with a known sort order.

    This should be three items [B, C, A] with
    A < B < C
    """
    # Indices 0, 7, 4 of data_for_grouping are B, C, A respectively.
    return type(data_for_grouping)._from_sequence(
        [data_for_grouping[0], data_for_grouping[7], data_for_grouping[4]],
        dtype=data_for_grouping.dtype,
    )
@pytest.fixture
def data_missing_for_sorting(data_for_grouping):
    """
    Length-3 array with a known sort order.

    This should be three items [B, NA, A] with
    A < B and NA missing.
    """
    # Indices 0, 2, 4 of data_for_grouping are B, NA, A respectively.
    return type(data_for_grouping)._from_sequence(
        [data_for_grouping[0], data_for_grouping[2], data_for_grouping[4]],
        dtype=data_for_grouping.dtype,
    )
  209. @pytest.fixture
  210. def data_for_twos(data):
  211. """Length-100 array in which all the elements are two."""
  212. pa_dtype = data.dtype.pyarrow_dtype
  213. if pa.types.is_integer(pa_dtype) or pa.types.is_floating(pa_dtype):
  214. return pd.array([2] * 100, dtype=data.dtype)
  215. # tests will be xfailed where 2 is not a valid scalar for pa_dtype
  216. return data
@pytest.fixture
def na_value():
    """The scalar missing value for this type. Default 'None'"""
    return pd.NA
class TestBaseCasting(base.BaseCastingTests):
    """Casting tests; binary dtypes xfail because str() decodes bytes."""

    def test_astype_str(self, data, request):
        pa_dtype = data.dtype.pyarrow_dtype
        if pa.types.is_binary(pa_dtype):
            request.node.add_marker(
                pytest.mark.xfail(
                    reason=f"For {pa_dtype} .astype(str) decodes.",
                )
            )
        super().test_astype_str(data)
class TestConstructors(base.BaseConstructorsTests):
    """Constructor tests, including round-trips through pyarrow arrays."""

    def test_from_dtype(self, data, request):
        pa_dtype = data.dtype.pyarrow_dtype
        if pa.types.is_string(pa_dtype) or pa.types.is_decimal(pa_dtype):
            if pa.types.is_string(pa_dtype):
                # pd.array(..., "string[pyarrow]") resolves to StringDtype,
                # not ArrowDtype, so the base test's dtype check fails.
                reason = "ArrowDtype(pa.string()) != StringDtype('pyarrow')"
            else:
                reason = f"pyarrow.type_for_alias cannot infer {pa_dtype}"
            request.node.add_marker(
                pytest.mark.xfail(
                    reason=reason,
                )
            )
        super().test_from_dtype(data)

    def test_from_sequence_pa_array(self, data):
        # https://github.com/pandas-dev/pandas/pull/47034#discussion_r955500784
        # data._data = pa.ChunkedArray
        result = type(data)._from_sequence(data._data)
        tm.assert_extension_array_equal(result, data)
        assert isinstance(result._data, pa.ChunkedArray)

        # A single (combined) pa.Array must also round-trip, still ending
        # up stored as a ChunkedArray.
        result = type(data)._from_sequence(data._data.combine_chunks())
        tm.assert_extension_array_equal(result, data)
        assert isinstance(result._data, pa.ChunkedArray)

    def test_from_sequence_pa_array_notimplemented(self, request):
        # month_day_nano_interval has no string-parsing support.
        with pytest.raises(NotImplementedError, match="Converting strings to"):
            ArrowExtensionArray._from_sequence_of_strings(
                ["12-1"], dtype=pa.month_day_nano_interval()
            )

    def test_from_sequence_of_strings_pa_array(self, data, request):
        pa_dtype = data.dtype.pyarrow_dtype
        if pa.types.is_time64(pa_dtype) and pa_dtype.equals("time64[ns]") and not PY311:
            request.node.add_marker(
                pytest.mark.xfail(
                    reason="Nanosecond time parsing not supported.",
                )
            )
        elif pa_version_under11p0 and (
            pa.types.is_duration(pa_dtype) or pa.types.is_decimal(pa_dtype)
        ):
            request.node.add_marker(
                pytest.mark.xfail(
                    raises=pa.ArrowNotImplementedError,
                    reason=f"pyarrow doesn't support parsing {pa_dtype}",
                )
            )
        elif pa.types.is_timestamp(pa_dtype) and pa_dtype.tz is not None:
            # Windows CI lacks the Arrow timezone database.
            if is_platform_windows() and is_ci_environment():
                request.node.add_marker(
                    pytest.mark.xfail(
                        raises=pa.ArrowInvalid,
                        reason=(
                            "TODO: Set ARROW_TIMEZONE_DATABASE environment variable "
                            "on CI to path to the tzdata for pyarrow."
                        ),
                    )
                )
        # Round-trip: cast to string, then parse back, for both chunked
        # and combined (single-chunk) inputs.
        pa_array = data._data.cast(pa.string())
        result = type(data)._from_sequence_of_strings(pa_array, dtype=data.dtype)
        tm.assert_extension_array_equal(result, data)

        pa_array = pa_array.combine_chunks()
        result = type(data)._from_sequence_of_strings(pa_array, dtype=data.dtype)
        tm.assert_extension_array_equal(result, data)
class TestGetitemTests(base.BaseGetitemTests):
    """Inherit all getitem tests unchanged."""

    pass
class TestBaseAccumulateTests(base.BaseAccumulateTests):
    """Accumulation (cumsum/cumprod/...) tests with per-dtype skips/xfails."""

    def check_accumulate(self, ser, op_name, skipna):
        result = getattr(ser, op_name)(skipna=skipna)

        if ser.dtype.kind == "m":
            # Just check that we match the integer behavior.
            ser = ser.astype("int64[pyarrow]")
            result = result.astype("int64[pyarrow]")

        # Compare against the masked Float64 implementation as the reference.
        result = result.astype("Float64")
        expected = getattr(ser.astype("Float64"), op_name)(skipna=skipna)
        self.assert_series_equal(result, expected, check_dtype=False)

    @pytest.mark.parametrize("skipna", [True, False])
    def test_accumulate_series_raises(self, data, all_numeric_accumulations, skipna):
        pa_type = data.dtype.pyarrow_dtype
        if (
            (
                pa.types.is_integer(pa_type)
                or pa.types.is_floating(pa_type)
                or pa.types.is_duration(pa_type)
            )
            and all_numeric_accumulations == "cumsum"
            and not pa_version_under9p0
        ):
            pytest.skip("These work, are tested by test_accumulate_series.")

        op_name = all_numeric_accumulations
        ser = pd.Series(data)

        with pytest.raises(NotImplementedError):
            getattr(ser, op_name)(skipna=skipna)

    @pytest.mark.parametrize("skipna", [True, False])
    def test_accumulate_series(self, data, all_numeric_accumulations, skipna, request):
        pa_type = data.dtype.pyarrow_dtype
        op_name = all_numeric_accumulations
        ser = pd.Series(data)

        # Combinations that are expected to raise are covered by
        # test_accumulate_series_raises; skip them here.
        do_skip = False
        if pa.types.is_string(pa_type) or pa.types.is_binary(pa_type):
            if op_name in ["cumsum", "cumprod"]:
                do_skip = True
        elif pa.types.is_temporal(pa_type) and not pa.types.is_duration(pa_type):
            if op_name in ["cumsum", "cumprod"]:
                do_skip = True
        elif pa.types.is_duration(pa_type):
            if op_name == "cumprod":
                do_skip = True

        if do_skip:
            pytest.skip(
                "These should *not* work, we test in test_accumulate_series_raises "
                "that these correctly raise."
            )

        if all_numeric_accumulations != "cumsum" or pa_version_under9p0:
            if request.config.option.skip_slow:
                # equivalent to marking these cases with @pytest.mark.slow,
                # these xfails take a long time to run because pytest
                # renders the exception messages even when not showing them
                pytest.skip("pyarrow xfail slow")

            request.node.add_marker(
                pytest.mark.xfail(
                    reason=f"{all_numeric_accumulations} not implemented",
                    raises=NotImplementedError,
                )
            )
        elif all_numeric_accumulations == "cumsum" and (
            pa.types.is_boolean(pa_type) or pa.types.is_decimal(pa_type)
        ):
            request.node.add_marker(
                pytest.mark.xfail(
                    reason=f"{all_numeric_accumulations} not implemented for {pa_type}",
                    raises=NotImplementedError,
                )
            )

        self.check_accumulate(ser, op_name, skipna)
class TestBaseNumericReduce(base.BaseNumericReduceTests):
    """Numeric reduction tests, compared against the masked Float64 results."""

    def check_reduce(self, ser, op_name, skipna):
        pa_dtype = ser.dtype.pyarrow_dtype
        # "count" does not take a skipna keyword.
        if op_name == "count":
            result = getattr(ser, op_name)()
        else:
            result = getattr(ser, op_name)(skipna=skipna)
        if pa.types.is_boolean(pa_dtype):
            # Can't convert if ser contains NA
            pytest.skip(
                "pandas boolean data with NA does not fully support all reductions"
            )
        elif pa.types.is_integer(pa_dtype) or pa.types.is_floating(pa_dtype):
            # Use the masked Float64 implementation as the expected reference.
            ser = ser.astype("Float64")
        if op_name == "count":
            expected = getattr(ser, op_name)()
        else:
            expected = getattr(ser, op_name)(skipna=skipna)
        tm.assert_almost_equal(result, expected)

    @pytest.mark.parametrize("skipna", [True, False])
    def test_reduce_series(self, data, all_numeric_reductions, skipna, request):
        pa_dtype = data.dtype.pyarrow_dtype
        opname = all_numeric_reductions

        ser = pd.Series(data)

        should_work = True
        if pa.types.is_temporal(pa_dtype) and opname in [
            "sum",
            "var",
            "skew",
            "kurt",
            "prod",
        ]:
            if pa.types.is_duration(pa_dtype) and opname in ["sum"]:
                # summing timedeltas is one case that *is* well-defined
                pass
            else:
                should_work = False
        elif (
            pa.types.is_string(pa_dtype) or pa.types.is_binary(pa_dtype)
        ) and opname in [
            "sum",
            "mean",
            "median",
            "prod",
            "std",
            "sem",
            "var",
            "skew",
            "kurt",
        ]:
            should_work = False

        if not should_work:
            # matching the non-pyarrow versions, these operations *should* not
            # work for these dtypes
            msg = f"does not support reduction '{opname}'"
            with pytest.raises(TypeError, match=msg):
                getattr(ser, opname)(skipna=skipna)

            return

        xfail_mark = pytest.mark.xfail(
            raises=TypeError,
            reason=(
                f"{all_numeric_reductions} is not implemented in "
                f"pyarrow={pa.__version__} for {pa_dtype}"
            ),
        )
        if all_numeric_reductions in {"skew", "kurt"}:
            request.node.add_marker(xfail_mark)
        elif (
            all_numeric_reductions in {"var", "std", "median"}
            and pa_version_under7p0
            and pa.types.is_decimal(pa_dtype)
        ):
            request.node.add_marker(xfail_mark)
        elif all_numeric_reductions == "sem" and pa_version_under8p0:
            request.node.add_marker(xfail_mark)
        elif pa.types.is_boolean(pa_dtype) and all_numeric_reductions in {
            "sem",
            "std",
            "var",
            "median",
        }:
            request.node.add_marker(xfail_mark)
        super().test_reduce_series(data, all_numeric_reductions, skipna)

    @pytest.mark.parametrize("typ", ["int64", "uint64", "float64"])
    def test_median_not_approximate(self, typ):
        # GH 52679
        result = pd.Series([1, 2], dtype=f"{typ}[pyarrow]").median()
        assert result == 1.5
class TestBaseBooleanReduce(base.BaseBooleanReduceTests):
    """Boolean reduction (any/all) tests."""

    @pytest.mark.parametrize("skipna", [True, False])
    def test_reduce_series(
        self, data, all_boolean_reductions, skipna, na_value, request
    ):
        pa_dtype = data.dtype.pyarrow_dtype
        xfail_mark = pytest.mark.xfail(
            raises=TypeError,
            reason=(
                f"{all_boolean_reductions} is not implemented in "
                f"pyarrow={pa.__version__} for {pa_dtype}"
            ),
        )
        if pa.types.is_string(pa_dtype) or pa.types.is_binary(pa_dtype):
            # We *might* want to make this behave like the non-pyarrow cases,
            # but have not yet decided.
            request.node.add_marker(xfail_mark)

        op_name = all_boolean_reductions
        ser = pd.Series(data)

        if pa.types.is_temporal(pa_dtype) and not pa.types.is_duration(pa_dtype):
            # xref GH#34479 we support this in our non-pyarrow datetime64 dtypes,
            # but it isn't obvious we _should_. For now, we keep the pyarrow
            # behavior which does not support this.
            with pytest.raises(TypeError, match="does not support reduction"):
                getattr(ser, op_name)(skipna=skipna)

            return

        result = getattr(ser, op_name)(skipna=skipna)
        # The fixture data always contains both truthy and falsy values,
        # so "any" is True and "all" is False.
        assert result is (op_name == "any")
class TestBaseGroupby(base.BaseGroupbyTests):
    """Groupby tests; boolean dtypes xfail where 3 distinct values are needed."""

    def test_groupby_extension_no_sort(self, data_for_grouping, request):
        pa_dtype = data_for_grouping.dtype.pyarrow_dtype
        if pa.types.is_boolean(pa_dtype):
            request.node.add_marker(
                pytest.mark.xfail(
                    reason=f"{pa_dtype} only has 2 unique possible values",
                )
            )
        super().test_groupby_extension_no_sort(data_for_grouping)

    def test_groupby_extension_transform(self, data_for_grouping, request):
        pa_dtype = data_for_grouping.dtype.pyarrow_dtype
        if pa.types.is_boolean(pa_dtype):
            request.node.add_marker(
                pytest.mark.xfail(
                    reason=f"{pa_dtype} only has 2 unique possible values",
                )
            )
        super().test_groupby_extension_transform(data_for_grouping)

    @pytest.mark.parametrize("as_index", [True, False])
    def test_groupby_extension_agg(self, as_index, data_for_grouping, request):
        pa_dtype = data_for_grouping.dtype.pyarrow_dtype
        if pa.types.is_boolean(pa_dtype):
            request.node.add_marker(
                pytest.mark.xfail(
                    raises=ValueError,
                    reason=f"{pa_dtype} only has 2 unique possible values",
                )
            )
        super().test_groupby_extension_agg(as_index, data_for_grouping)

    def test_in_numeric_groupby(self, data_for_grouping):
        if is_string_dtype(data_for_grouping.dtype):
            # String columns are dropped by numeric_only=True and raise
            # for a plain sum(); the base test expects numeric behavior.
            df = pd.DataFrame(
                {
                    "A": [1, 1, 2, 2, 3, 3, 1, 4],
                    "B": data_for_grouping,
                    "C": [1, 1, 1, 1, 1, 1, 1, 1],
                }
            )

            expected = pd.Index(["C"])
            with pytest.raises(TypeError, match="does not support"):
                df.groupby("A").sum().columns
            result = df.groupby("A").sum(numeric_only=True).columns
            tm.assert_index_equal(result, expected)
        else:
            super().test_in_numeric_groupby(data_for_grouping)
class TestBaseDtype(base.BaseDtypeTests):
    """Dtype interface tests with string/decimal-specific overrides."""

    def test_check_dtype(self, data, request):
        pa_dtype = data.dtype.pyarrow_dtype
        if pa.types.is_decimal(pa_dtype) and pa_version_under8p0:
            request.node.add_marker(
                pytest.mark.xfail(
                    raises=ValueError,
                    reason="decimal string repr affects numpy comparison",
                )
            )
        super().test_check_dtype(data)

    def test_construct_from_string_own_name(self, dtype, request):
        pa_dtype = dtype.pyarrow_dtype
        if pa.types.is_decimal(pa_dtype):
            request.node.add_marker(
                pytest.mark.xfail(
                    raises=NotImplementedError,
                    reason=f"pyarrow.type_for_alias cannot infer {pa_dtype}",
                )
            )

        if pa.types.is_string(pa_dtype):
            # We still support StringDtype('pyarrow') over ArrowDtype(pa.string())
            msg = r"string\[pyarrow\] should be constructed by StringDtype"
            with pytest.raises(TypeError, match=msg):
                dtype.construct_from_string(dtype.name)

            return

        super().test_construct_from_string_own_name(dtype)

    def test_is_dtype_from_name(self, dtype, request):
        pa_dtype = dtype.pyarrow_dtype
        if pa.types.is_string(pa_dtype):
            # We still support StringDtype('pyarrow') over ArrowDtype(pa.string())
            assert not type(dtype).is_dtype(dtype.name)
        else:
            if pa.types.is_decimal(pa_dtype):
                request.node.add_marker(
                    pytest.mark.xfail(
                        raises=NotImplementedError,
                        reason=f"pyarrow.type_for_alias cannot infer {pa_dtype}",
                    )
                )
            super().test_is_dtype_from_name(dtype)

    def test_construct_from_string_another_type_raises(self, dtype):
        # Names without the "[pyarrow]" suffix must be rejected.
        msg = r"'another_type' must end with '\[pyarrow\]'"
        with pytest.raises(TypeError, match=msg):
            type(dtype).construct_from_string("another_type")

    def test_get_common_dtype(self, dtype, request):
        pa_dtype = dtype.pyarrow_dtype
        if (
            pa.types.is_date(pa_dtype)
            or pa.types.is_time(pa_dtype)
            or (
                pa.types.is_timestamp(pa_dtype)
                and (pa_dtype.unit != "ns" or pa_dtype.tz is not None)
            )
            or (pa.types.is_duration(pa_dtype) and pa_dtype.unit != "ns")
            or pa.types.is_binary(pa_dtype)
            or pa.types.is_decimal(pa_dtype)
        ):
            request.node.add_marker(
                pytest.mark.xfail(
                    reason=(
                        f"{pa_dtype} does not have associated numpy "
                        f"dtype findable by find_common_type"
                    )
                )
            )
        super().test_get_common_dtype(dtype)

    def test_is_not_string_type(self, dtype):
        pa_dtype = dtype.pyarrow_dtype
        if pa.types.is_string(pa_dtype):
            # Arrow string dtype *is* a string dtype, unlike the base expectation.
            assert is_string_dtype(dtype)
        else:
            super().test_is_not_string_type(dtype)
class TestBaseIndex(base.BaseIndexTests):
    # Inherit the shared Index test suite unchanged; no Arrow-specific
    # overrides are needed here.
    pass
class TestBaseInterface(base.BaseInterfaceTests):
    @pytest.mark.xfail(
        reason="GH 45419: pyarrow.ChunkedArray does not support views.", run=False
    )
    def test_view(self, data):
        # ChunkedArray-backed data cannot provide views, so the inherited
        # test is expected to fail and is not even executed (run=False).
        super().test_view(data)
class TestBaseMissing(base.BaseMissingTests):
    def test_fillna_no_op_returns_copy(self, data):
        # Filling with a value already present must still return a new
        # (but equal) array, never the same object.
        data = data[~data.isna()]
        valid = data[0]
        result = data.fillna(valid)
        assert result is not data
        self.assert_extension_array_equal(result, data)
        # backfill on Arrow-backed data emits a PerformanceWarning
        with tm.assert_produces_warning(PerformanceWarning):
            result = data.fillna(method="backfill")
            assert result is not data
            self.assert_extension_array_equal(result, data)

    def test_fillna_series_method(self, data_missing, fillna_method):
        # method-based fills warn for Arrow-backed data; a plain value
        # fill (fillna_method is None) does not
        with tm.maybe_produces_warning(
            PerformanceWarning, fillna_method is not None, check_stacklevel=False
        ):
            super().test_fillna_series_method(data_missing, fillna_method)
class TestBasePrinting(base.BasePrintingTests):
    # The inherited repr/printing tests pass as-is for Arrow-backed data.
    pass
class TestBaseReshaping(base.BaseReshapingTests):
    @pytest.mark.xfail(
        reason="GH 45419: pyarrow.ChunkedArray does not support views", run=False
    )
    def test_transpose(self, data):
        # transpose relies on views, which ChunkedArray cannot provide
        super().test_transpose(data)
class TestBaseSetitem(base.BaseSetitemTests):
    @pytest.mark.xfail(
        reason="GH 45419: pyarrow.ChunkedArray does not support views", run=False
    )
    def test_setitem_preserves_views(self, data):
        # view preservation cannot hold without view support
        super().test_setitem_preserves_views(data)
class TestBaseParsing(base.BaseParsingTests):
    @pytest.mark.parametrize("dtype_backend", ["pyarrow", no_default])
    @pytest.mark.parametrize("engine", ["c", "python"])
    def test_EA_types(self, engine, data, dtype_backend, request):
        # Round-trip the data through CSV and assert the frame survives.
        # Several dtype/engine combinations are known-broken and xfailed.
        pa_dtype = data.dtype.pyarrow_dtype
        if pa.types.is_decimal(pa_dtype):
            request.node.add_marker(
                pytest.mark.xfail(
                    raises=NotImplementedError,
                    reason=f"Parameterized types {pa_dtype} not supported.",
                )
            )
        elif pa.types.is_timestamp(pa_dtype) and pa_dtype.unit in ("us", "ns"):
            request.node.add_marker(
                pytest.mark.xfail(
                    raises=ValueError,
                    reason="https://github.com/pandas-dev/pandas/issues/49767",
                )
            )
        elif pa.types.is_binary(pa_dtype):
            request.node.add_marker(
                pytest.mark.xfail(reason="CSV parsers don't correctly handle binary")
            )
        elif (
            pa.types.is_duration(pa_dtype)
            and dtype_backend == "pyarrow"
            and engine == "python"
        ):
            request.node.add_marker(
                pytest.mark.xfail(
                    raises=TypeError,
                    reason="Invalid type for timedelta scalar: NAType",
                )
            )
        df = pd.DataFrame({"with_dtype": pd.Series(data, dtype=str(data.dtype))})
        # NOTE(review): na_rep is documented as a str; passing np.nan relies
        # on implicit string conversion -- confirm this is intentional.
        csv_output = df.to_csv(index=False, na_rep=np.nan)
        if pa.types.is_binary(pa_dtype):
            # NOTE(review): to_csv returned a str, so BytesIO(str) would
            # raise TypeError -- this path is only reached under the
            # binary xfail above, so it never actually runs to completion.
            csv_output = BytesIO(csv_output)
        else:
            csv_output = StringIO(csv_output)
        result = pd.read_csv(
            csv_output,
            dtype={"with_dtype": str(data.dtype)},
            engine=engine,
            dtype_backend=dtype_backend,
        )
        expected = df
        self.assert_frame_equal(result, expected)
  685. class TestBaseUnaryOps(base.BaseUnaryOpsTests):
  686. def test_invert(self, data, request):
  687. pa_dtype = data.dtype.pyarrow_dtype
  688. if not pa.types.is_boolean(pa_dtype):
  689. request.node.add_marker(
  690. pytest.mark.xfail(
  691. raises=pa.ArrowNotImplementedError,
  692. reason=f"pyarrow.compute.invert does support {pa_dtype}",
  693. )
  694. )
  695. super().test_invert(data)
class TestBaseMethods(base.BaseMethodsTests):
    """Method tests (diff, value_counts, argreduce, ...) with Arrow xfails."""

    @pytest.mark.parametrize("periods", [1, -2])
    def test_diff(self, data, periods, request):
        # diff on unsigned ints with periods=1 can produce a negative
        # intermediate, which pyarrow's checked arithmetic rejects
        pa_dtype = data.dtype.pyarrow_dtype
        if pa.types.is_unsigned_integer(pa_dtype) and periods == 1:
            request.node.add_marker(
                pytest.mark.xfail(
                    raises=pa.ArrowInvalid,
                    reason=(
                        f"diff with {pa_dtype} and periods={periods} will overflow"
                    ),
                )
            )
        super().test_diff(data, periods)

    def test_value_counts_returns_pyarrow_int64(self, data):
        # GH 51462: counts come back Arrow-backed int64, not numpy int64
        data = data[:10]
        result = data.value_counts()
        assert result.dtype == ArrowDtype(pa.int64())

    def test_value_counts_with_normalize(self, data, request):
        # normalize=True divides by the number of non-NA uniques and
        # always produces double[pyarrow] proportions
        data = data[:10].unique()
        values = np.array(data[~data.isna()])
        ser = pd.Series(data, dtype=data.dtype)
        result = ser.value_counts(normalize=True).sort_index()
        expected = pd.Series(
            [1 / len(values)] * len(values), index=result.index, name="proportion"
        )
        expected = expected.astype("double[pyarrow]")
        self.assert_series_equal(result, expected)

    def test_argmin_argmax(
        self, data_for_sorting, data_missing_for_sorting, na_value, request
    ):
        pa_dtype = data_for_sorting.dtype.pyarrow_dtype
        if pa.types.is_boolean(pa_dtype):
            # booleans cannot supply the 3 distinct sortable values the
            # base fixture requires
            request.node.add_marker(
                pytest.mark.xfail(
                    reason=f"{pa_dtype} only has 2 unique possible values",
                )
            )
        elif pa.types.is_decimal(pa_dtype) and pa_version_under7p0:
            # older pyarrow lacks the min/max kernels for decimals
            request.node.add_marker(
                pytest.mark.xfail(
                    reason=f"No pyarrow kernel for {pa_dtype}",
                    raises=pa.ArrowNotImplementedError,
                )
            )
        super().test_argmin_argmax(data_for_sorting, data_missing_for_sorting, na_value)

    @pytest.mark.parametrize(
        "op_name, skipna, expected",
        [
            ("idxmax", True, 0),
            ("idxmin", True, 2),
            ("argmax", True, 0),
            ("argmin", True, 2),
            ("idxmax", False, np.nan),
            ("idxmin", False, np.nan),
            ("argmax", False, -1),
            ("argmin", False, -1),
        ],
    )
    def test_argreduce_series(
        self, data_missing_for_sorting, op_name, skipna, expected, request
    ):
        pa_dtype = data_missing_for_sorting.dtype.pyarrow_dtype
        if pa.types.is_decimal(pa_dtype) and pa_version_under7p0 and skipna:
            # older pyarrow lacks the reduction kernels for decimals
            request.node.add_marker(
                pytest.mark.xfail(
                    reason=f"No pyarrow kernel for {pa_dtype}",
                    raises=pa.ArrowNotImplementedError,
                )
            )
        super().test_argreduce_series(
            data_missing_for_sorting, op_name, skipna, expected
        )

    def test_factorize(self, data_for_grouping, request):
        pa_dtype = data_for_grouping.dtype.pyarrow_dtype
        if pa.types.is_boolean(pa_dtype):
            # the grouping fixture needs more than 2 distinct values
            request.node.add_marker(
                pytest.mark.xfail(
                    reason=f"{pa_dtype} only has 2 unique possible values",
                )
            )
        super().test_factorize(data_for_grouping)

    # comparison results are Arrow-backed booleans, not numpy bools
    _combine_le_expected_dtype = "bool[pyarrow]"

    def test_combine_add(self, data_repeated, request):
        pa_dtype = next(data_repeated(1)).dtype.pyarrow_dtype
        if pa.types.is_duration(pa_dtype):
            # TODO: this fails on the scalar addition constructing 'expected'
            # but not in the actual 'combine' call, so may be salvage-able
            mark = pytest.mark.xfail(
                raises=TypeError,
                reason=f"{pa_dtype} cannot be added to {pa_dtype}",
            )
            request.node.add_marker(mark)
            super().test_combine_add(data_repeated)
        elif pa.types.is_temporal(pa_dtype):
            # analogous to datetime64, these cannot be added
            orig_data1, orig_data2 = data_repeated(2)
            s1 = pd.Series(orig_data1)
            s2 = pd.Series(orig_data2)
            with pytest.raises(TypeError):
                s1.combine(s2, lambda x1, x2: x1 + x2)
        else:
            super().test_combine_add(data_repeated)

    def test_searchsorted(self, data_for_sorting, as_series, request):
        pa_dtype = data_for_sorting.dtype.pyarrow_dtype
        if pa.types.is_boolean(pa_dtype):
            # the sorting fixture needs 3 distinct values
            request.node.add_marker(
                pytest.mark.xfail(
                    reason=f"{pa_dtype} only has 2 unique possible values",
                )
            )
        super().test_searchsorted(data_for_sorting, as_series)

    def test_basic_equals(self, data):
        # https://github.com/pandas-dev/pandas/issues/34660
        assert pd.Series(data).equals(pd.Series(data))
class TestBaseArithmeticOps(base.BaseArithmeticOpsTests):
    """Arithmetic-op tests with Arrow-specific expected exceptions/xfails."""

    # pyarrow provides no divmod kernel
    divmod_exc = NotImplementedError

    @classmethod
    def assert_equal(cls, left, right, **kwargs):
        # Compare as usual, except decimals: their result precision can
        # differ from the input precision, so compare as floats instead.
        if isinstance(left, pd.DataFrame):
            left_pa_type = left.iloc[:, 0].dtype.pyarrow_dtype
            right_pa_type = right.iloc[:, 0].dtype.pyarrow_dtype
        else:
            left_pa_type = left.dtype.pyarrow_dtype
            right_pa_type = right.dtype.pyarrow_dtype
        if pa.types.is_decimal(left_pa_type) or pa.types.is_decimal(right_pa_type):
            # decimal precision can resize in the result type depending on data
            # just compare the float values
            left = left.astype("float[pyarrow]")
            right = right.astype("float[pyarrow]")
        tm.assert_equal(left, right, **kwargs)

    def get_op_from_name(self, op_name):
        # Map the op name to a callable, overriding the reversed division
        # ops with numpy versions that don't raise on division by zero.
        short_opname = op_name.strip("_")
        if short_opname == "rtruediv":
            # use the numpy version that won't raise on division by zero
            return lambda x, y: np.divide(y, x)
        elif short_opname == "rfloordiv":
            return lambda x, y: np.floor_divide(y, x)
        return tm.get_op_from_name(op_name)

    def _patch_combine(self, obj, other, op):
        # BaseOpsUtil._combine can upcast expected dtype
        # (because it generates expected on python scalars)
        # while ArrowExtensionArray maintains original type
        expected = base.BaseArithmeticOpsTests._combine(self, obj, other, op)
        was_frame = False
        if isinstance(expected, pd.DataFrame):
            was_frame = True
            expected_data = expected.iloc[:, 0]
            original_dtype = obj.iloc[:, 0].dtype
        else:
            expected_data = expected
            original_dtype = obj.dtype
        pa_expected = pa.array(expected_data._values)
        if pa.types.is_duration(pa_expected.type):
            # pyarrow sees sequence of datetime/timedelta objects and defaults
            # to "us" but the non-pointwise op retains unit
            unit = original_dtype.pyarrow_dtype.unit
            if type(other) in [datetime, timedelta] and unit in ["s", "ms"]:
                # pydatetime/pytimedelta objects have microsecond reso, so we
                # take the higher reso of the original and microsecond. Note
                # this matches what we would do with DatetimeArray/TimedeltaArray
                unit = "us"
            pa_expected = pa_expected.cast(f"duration[{unit}]")
        else:
            pa_expected = pa_expected.cast(original_dtype.pyarrow_dtype)
        pd_expected = type(expected_data._values)(pa_expected)
        if was_frame:
            expected = pd.DataFrame(
                pd_expected, index=expected.index, columns=expected.columns
            )
        else:
            expected = pd.Series(pd_expected)
        return expected

    def _is_temporal_supported(self, opname, pa_dtype):
        # On pyarrow >= 8.0: duration supports add, and all temporal types
        # support subtract. (and/or precedence: 'and' binds tighter here.)
        return not pa_version_under8p0 and (
            opname in ("__add__", "__radd__")
            and pa.types.is_duration(pa_dtype)
            or opname in ("__sub__", "__rsub__")
            and pa.types.is_temporal(pa_dtype)
        )

    def _get_scalar_exception(self, opname, pa_dtype):
        # Return the exception expected for `opname` against a scalar,
        # or None when the op should succeed.
        arrow_temporal_supported = self._is_temporal_supported(opname, pa_dtype)
        if opname in {
            "__mod__",
            "__rmod__",
        }:
            exc = NotImplementedError
        elif arrow_temporal_supported:
            exc = None
        elif opname in ["__add__", "__radd__"] and (
            pa.types.is_string(pa_dtype) or pa.types.is_binary(pa_dtype)
        ):
            exc = None
        elif not (
            pa.types.is_floating(pa_dtype)
            or pa.types.is_integer(pa_dtype)
            or pa.types.is_decimal(pa_dtype)
        ):
            exc = pa.ArrowNotImplementedError
        else:
            exc = None
        return exc

    def _get_arith_xfail_marker(self, opname, pa_dtype):
        # Return an xfail marker for known pandas/pyarrow divergences,
        # or None when the op is expected to pass.
        mark = None
        arrow_temporal_supported = self._is_temporal_supported(opname, pa_dtype)
        if (
            opname == "__rpow__"
            and (
                pa.types.is_floating(pa_dtype)
                or pa.types.is_integer(pa_dtype)
                or pa.types.is_decimal(pa_dtype)
            )
            and not pa_version_under7p0
        ):
            mark = pytest.mark.xfail(
                reason=(
                    f"GH#29997: 1**pandas.NA == 1 while 1**pyarrow.NA == NULL "
                    f"for {pa_dtype}"
                )
            )
        elif arrow_temporal_supported:
            mark = pytest.mark.xfail(
                raises=TypeError,
                reason=(
                    f"{opname} not supported between"
                    f"pd.NA and {pa_dtype} Python scalar"
                ),
            )
        elif (
            opname == "__rfloordiv__"
            and (pa.types.is_integer(pa_dtype) or pa.types.is_decimal(pa_dtype))
            and not pa_version_under7p0
        ):
            mark = pytest.mark.xfail(
                raises=pa.ArrowInvalid,
                reason="divide by 0",
            )
        elif (
            opname == "__rtruediv__"
            and pa.types.is_decimal(pa_dtype)
            and not pa_version_under7p0
        ):
            mark = pytest.mark.xfail(
                raises=pa.ArrowInvalid,
                reason="divide by 0",
            )
        elif (
            opname == "__pow__"
            and pa.types.is_decimal(pa_dtype)
            and pa_version_under7p0
        ):
            mark = pytest.mark.xfail(
                raises=pa.ArrowInvalid,
                reason="Invalid decimal function: power_checked",
            )
        return mark

    def test_arith_series_with_scalar(
        self, data, all_arithmetic_operators, request, monkeypatch
    ):
        pa_dtype = data.dtype.pyarrow_dtype
        if all_arithmetic_operators == "__rmod__" and (
            pa.types.is_string(pa_dtype) or pa.types.is_binary(pa_dtype)
        ):
            pytest.skip("Skip testing Python string formatting")
        self.series_scalar_exc = self._get_scalar_exception(
            all_arithmetic_operators, pa_dtype
        )
        mark = self._get_arith_xfail_marker(all_arithmetic_operators, pa_dtype)
        if mark is not None:
            request.node.add_marker(mark)
        if (
            (
                all_arithmetic_operators == "__floordiv__"
                and pa.types.is_integer(pa_dtype)
            )
            or pa.types.is_duration(pa_dtype)
            or pa.types.is_timestamp(pa_dtype)
        ):
            # BaseOpsUtil._combine always returns int64, while ArrowExtensionArray does
            # not upcast
            monkeypatch.setattr(TestBaseArithmeticOps, "_combine", self._patch_combine)
        super().test_arith_series_with_scalar(data, all_arithmetic_operators)

    def test_arith_frame_with_scalar(
        self, data, all_arithmetic_operators, request, monkeypatch
    ):
        pa_dtype = data.dtype.pyarrow_dtype
        if all_arithmetic_operators == "__rmod__" and (
            pa.types.is_string(pa_dtype) or pa.types.is_binary(pa_dtype)
        ):
            pytest.skip("Skip testing Python string formatting")
        self.frame_scalar_exc = self._get_scalar_exception(
            all_arithmetic_operators, pa_dtype
        )
        mark = self._get_arith_xfail_marker(all_arithmetic_operators, pa_dtype)
        if mark is not None:
            request.node.add_marker(mark)
        if (
            (
                all_arithmetic_operators == "__floordiv__"
                and pa.types.is_integer(pa_dtype)
            )
            or pa.types.is_duration(pa_dtype)
            or pa.types.is_timestamp(pa_dtype)
        ):
            # BaseOpsUtil._combine always returns int64, while ArrowExtensionArray does
            # not upcast
            monkeypatch.setattr(TestBaseArithmeticOps, "_combine", self._patch_combine)
        super().test_arith_frame_with_scalar(data, all_arithmetic_operators)

    def test_arith_series_with_array(
        self, data, all_arithmetic_operators, request, monkeypatch
    ):
        pa_dtype = data.dtype.pyarrow_dtype
        self.series_array_exc = self._get_scalar_exception(
            all_arithmetic_operators, pa_dtype
        )
        if (
            all_arithmetic_operators
            in (
                "__sub__",
                "__rsub__",
            )
            and pa.types.is_unsigned_integer(pa_dtype)
            and not pa_version_under7p0
        ):
            # unsigned subtraction can go negative, which checked kernels reject
            request.node.add_marker(
                pytest.mark.xfail(
                    raises=pa.ArrowInvalid,
                    reason=(
                        f"Implemented pyarrow.compute.subtract_checked "
                        f"which raises on overflow for {pa_dtype}"
                    ),
                )
            )
        mark = self._get_arith_xfail_marker(all_arithmetic_operators, pa_dtype)
        if mark is not None:
            request.node.add_marker(mark)
        op_name = all_arithmetic_operators
        ser = pd.Series(data)
        # pd.Series([ser.iloc[0]] * len(ser)) may not return ArrowExtensionArray
        # since ser.iloc[0] is a python scalar
        other = pd.Series(pd.array([ser.iloc[0]] * len(ser), dtype=data.dtype))
        if (
            pa.types.is_floating(pa_dtype)
            or (
                pa.types.is_integer(pa_dtype)
                and all_arithmetic_operators not in ["__truediv__", "__rtruediv__"]
            )
            or pa.types.is_duration(pa_dtype)
            or pa.types.is_timestamp(pa_dtype)
        ):
            monkeypatch.setattr(TestBaseArithmeticOps, "_combine", self._patch_combine)
        self.check_opname(ser, op_name, other, exc=self.series_array_exc)

    def test_add_series_with_extension_array(self, data, request):
        pa_dtype = data.dtype.pyarrow_dtype
        if pa.types.is_temporal(pa_dtype) and not pa.types.is_duration(pa_dtype):
            # i.e. timestamp, date, time, but not timedelta; these *should*
            # raise when trying to add
            ser = pd.Series(data)
            if pa_version_under7p0:
                msg = "Function add_checked has no kernel matching input types"
            else:
                msg = "Function 'add_checked' has no kernel matching input types"
            with pytest.raises(NotImplementedError, match=msg):
                # TODO: this is a pa.lib.ArrowNotImplementedError, might
                # be better to reraise a TypeError; more consistent with
                # non-pyarrow cases
                ser + data
            return
        if (pa_version_under8p0 and pa.types.is_duration(pa_dtype)) or (
            pa.types.is_boolean(pa_dtype)
        ):
            request.node.add_marker(
                pytest.mark.xfail(
                    raises=NotImplementedError,
                    reason=f"add_checked not implemented for {pa_dtype}",
                )
            )
        elif pa_dtype.equals("int8"):
            # int8 + int8 overflows for the fixture data
            request.node.add_marker(
                pytest.mark.xfail(
                    raises=pa.ArrowInvalid,
                    reason=f"raises on overflow for {pa_dtype}",
                )
            )
        super().test_add_series_with_extension_array(data)
class TestBaseComparisonOps(base.BaseComparisonOpsTests):
    def test_compare_array(self, data, comparison_op, na_value):
        # Compare a Series against an equal-dtype array of its first value.
        ser = pd.Series(data)
        # pd.Series([ser.iloc[0]] * len(ser)) may not return ArrowExtensionArray
        # since ser.iloc[0] is a python scalar
        other = pd.Series(pd.array([ser.iloc[0]] * len(ser), dtype=data.dtype))
        if comparison_op.__name__ in ["eq", "ne"]:
            # comparison should match point-wise comparisons
            result = comparison_op(ser, other)
            # Series.combine does not calculate the NA mask correctly
            # when comparing over an array
            assert result[8] is na_value
            assert result[97] is na_value
            expected = ser.combine(other, comparison_op)
            expected[8] = na_value
            expected[97] = na_value
            self.assert_series_equal(result, expected)
        else:
            # ordering comparisons may legitimately raise for some dtypes;
            # when they don't, result must match the point-wise combine
            exc = None
            try:
                result = comparison_op(ser, other)
            except Exception as err:
                exc = err
            if exc is None:
                # Didn't error, then should match point-wise behavior
                expected = ser.combine(other, comparison_op)
                self.assert_series_equal(result, expected)
            else:
                with pytest.raises(type(exc)):
                    ser.combine(other, comparison_op)

    def test_invalid_other_comp(self, data, comparison_op):
        # GH 48833
        with pytest.raises(
            NotImplementedError, match=".* not implemented for <class 'object'>"
        ):
            comparison_op(data, object())

    @pytest.mark.parametrize("masked_dtype", ["boolean", "Int64", "Float64"])
    def test_comp_masked_numpy(self, masked_dtype, comparison_op):
        # GH 52625: comparing Arrow-backed vs masked data yields bool[pyarrow]
        data = [1, 0, None]
        ser_masked = pd.Series(data, dtype=masked_dtype)
        ser_pa = pd.Series(data, dtype=f"{masked_dtype.lower()}[pyarrow]")
        result = comparison_op(ser_pa, ser_masked)
        if comparison_op in [operator.lt, operator.gt, operator.ne]:
            exp = [False, False, None]
        else:
            exp = [True, True, None]
        expected = pd.Series(exp, dtype=ArrowDtype(pa.bool_()))
        tm.assert_series_equal(result, expected)
class TestLogicalOps:
    """Various Series and DataFrame logical ops methods.

    Verifies Kleene (three-valued) logic for |, & and ^ on
    boolean[pyarrow] Series, including scalar operands and the
    no-inplace-mutation guarantee.
    """

    def test_kleene_or(self):
        a = pd.Series([True] * 3 + [False] * 3 + [None] * 3, dtype="boolean[pyarrow]")
        b = pd.Series([True, False, None] * 3, dtype="boolean[pyarrow]")
        result = a | b
        expected = pd.Series(
            [True, True, True, True, False, None, True, None, None],
            dtype="boolean[pyarrow]",
        )
        tm.assert_series_equal(result, expected)
        # OR is commutative
        result = b | a
        tm.assert_series_equal(result, expected)
        # ensure we haven't mutated anything inplace
        tm.assert_series_equal(
            a,
            pd.Series([True] * 3 + [False] * 3 + [None] * 3, dtype="boolean[pyarrow]"),
        )
        tm.assert_series_equal(
            b, pd.Series([True, False, None] * 3, dtype="boolean[pyarrow]")
        )

    @pytest.mark.parametrize(
        "other, expected",
        [
            (None, [True, None, None]),
            (pd.NA, [True, None, None]),
            (True, [True, True, True]),
            (np.bool_(True), [True, True, True]),
            (False, [True, False, None]),
            (np.bool_(False), [True, False, None]),
        ],
    )
    def test_kleene_or_scalar(self, other, expected):
        a = pd.Series([True, False, None], dtype="boolean[pyarrow]")
        result = a | other
        expected = pd.Series(expected, dtype="boolean[pyarrow]")
        tm.assert_series_equal(result, expected)
        # reflected op must agree
        result = other | a
        tm.assert_series_equal(result, expected)
        # ensure we haven't mutated anything inplace
        tm.assert_series_equal(
            a, pd.Series([True, False, None], dtype="boolean[pyarrow]")
        )

    def test_kleene_and(self):
        a = pd.Series([True] * 3 + [False] * 3 + [None] * 3, dtype="boolean[pyarrow]")
        b = pd.Series([True, False, None] * 3, dtype="boolean[pyarrow]")
        result = a & b
        expected = pd.Series(
            [True, False, None, False, False, False, None, False, None],
            dtype="boolean[pyarrow]",
        )
        tm.assert_series_equal(result, expected)
        # AND is commutative
        result = b & a
        tm.assert_series_equal(result, expected)
        # ensure we haven't mutated anything inplace
        tm.assert_series_equal(
            a,
            pd.Series([True] * 3 + [False] * 3 + [None] * 3, dtype="boolean[pyarrow]"),
        )
        tm.assert_series_equal(
            b, pd.Series([True, False, None] * 3, dtype="boolean[pyarrow]")
        )

    @pytest.mark.parametrize(
        "other, expected",
        [
            (None, [None, False, None]),
            (pd.NA, [None, False, None]),
            (True, [True, False, None]),
            (False, [False, False, False]),
            (np.bool_(True), [True, False, None]),
            (np.bool_(False), [False, False, False]),
        ],
    )
    def test_kleene_and_scalar(self, other, expected):
        a = pd.Series([True, False, None], dtype="boolean[pyarrow]")
        result = a & other
        expected = pd.Series(expected, dtype="boolean[pyarrow]")
        tm.assert_series_equal(result, expected)
        # reflected op must agree
        result = other & a
        tm.assert_series_equal(result, expected)
        # ensure we haven't mutated anything inplace
        tm.assert_series_equal(
            a, pd.Series([True, False, None], dtype="boolean[pyarrow]")
        )

    def test_kleene_xor(self):
        a = pd.Series([True] * 3 + [False] * 3 + [None] * 3, dtype="boolean[pyarrow]")
        b = pd.Series([True, False, None] * 3, dtype="boolean[pyarrow]")
        result = a ^ b
        expected = pd.Series(
            [False, True, None, True, False, None, None, None, None],
            dtype="boolean[pyarrow]",
        )
        tm.assert_series_equal(result, expected)
        # XOR is commutative
        result = b ^ a
        tm.assert_series_equal(result, expected)
        # ensure we haven't mutated anything inplace
        tm.assert_series_equal(
            a,
            pd.Series([True] * 3 + [False] * 3 + [None] * 3, dtype="boolean[pyarrow]"),
        )
        tm.assert_series_equal(
            b, pd.Series([True, False, None] * 3, dtype="boolean[pyarrow]")
        )

    @pytest.mark.parametrize(
        "other, expected",
        [
            (None, [None, None, None]),
            (pd.NA, [None, None, None]),
            (True, [False, True, None]),
            (np.bool_(True), [False, True, None]),
            (np.bool_(False), [True, False, None]),
        ],
    )
    def test_kleene_xor_scalar(self, other, expected):
        a = pd.Series([True, False, None], dtype="boolean[pyarrow]")
        result = a ^ other
        expected = pd.Series(expected, dtype="boolean[pyarrow]")
        tm.assert_series_equal(result, expected)
        # reflected op must agree
        result = other ^ a
        tm.assert_series_equal(result, expected)
        # ensure we haven't mutated anything inplace
        tm.assert_series_equal(
            a, pd.Series([True, False, None], dtype="boolean[pyarrow]")
        )

    @pytest.mark.parametrize(
        "op, exp",
        [
            ["__and__", True],
            ["__or__", True],
            ["__xor__", False],
        ],
    )
    def test_logical_masked_numpy(self, op, exp):
        # GH 52625: logical ops between Arrow-backed and masked boolean
        # Series produce bool[pyarrow] results
        data = [True, False, None]
        ser_masked = pd.Series(data, dtype="boolean")
        ser_pa = pd.Series(data, dtype="boolean[pyarrow]")
        result = getattr(ser_pa, op)(ser_masked)
        expected = pd.Series([exp, False, None], dtype=ArrowDtype(pa.bool_()))
        tm.assert_series_equal(result, expected)
  1271. def test_arrowdtype_construct_from_string_type_with_unsupported_parameters():
  1272. with pytest.raises(NotImplementedError, match="Passing pyarrow type"):
  1273. ArrowDtype.construct_from_string("not_a_real_dype[s, tz=UTC][pyarrow]")
  1274. # but as of GH#50689, timestamptz is supported
  1275. dtype = ArrowDtype.construct_from_string("timestamp[s, tz=UTC][pyarrow]")
  1276. expected = ArrowDtype(pa.timestamp("s", "UTC"))
  1277. assert dtype == expected
  1278. with pytest.raises(NotImplementedError, match="Passing pyarrow type"):
  1279. ArrowDtype.construct_from_string("decimal(7, 2)[pyarrow]")
  1280. def test_arrowdtype_construct_from_string_type_only_one_pyarrow():
  1281. # GH#51225
  1282. invalid = "int64[pyarrow]foobar[pyarrow]"
  1283. msg = (
  1284. r"Passing pyarrow type specific parameters \(\[pyarrow\]\) in the "
  1285. r"string is not supported\."
  1286. )
  1287. with pytest.raises(NotImplementedError, match=msg):
  1288. pd.Series(range(3), dtype=invalid)
@pytest.mark.parametrize(
    "interpolation", ["linear", "lower", "higher", "nearest", "midpoint"]
)
@pytest.mark.parametrize("quantile", [0.5, [0.5, 0.5]])
def test_quantile(data, interpolation, quantile, request):
    # Quantile over a constant 3-element series; raises for non-numeric,
    # non-temporal Arrow dtypes.
    pa_dtype = data.dtype.pyarrow_dtype
    data = data.take([0, 0, 0])
    ser = pd.Series(data)
    if (
        pa.types.is_string(pa_dtype)
        or pa.types.is_binary(pa_dtype)
        or pa.types.is_boolean(pa_dtype)
    ):
        # For string, bytes, and bool, we don't *expect* to have quantile work
        # Note this matches the non-pyarrow behavior
        if pa_version_under7p0:
            msg = r"Function quantile has no kernel matching input types \(.*\)"
        else:
            msg = r"Function 'quantile' has no kernel matching input types \(.*\)"
        with pytest.raises(pa.ArrowNotImplementedError, match=msg):
            ser.quantile(q=quantile, interpolation=interpolation)
        return
    if (
        pa.types.is_integer(pa_dtype)
        or pa.types.is_floating(pa_dtype)
        or (pa.types.is_decimal(pa_dtype) and not pa_version_under7p0)
    ):
        pass
    elif pa.types.is_temporal(data._data.type):
        pass
    else:
        request.node.add_marker(
            pytest.mark.xfail(
                raises=pa.ArrowNotImplementedError,
                reason=f"quantile not supported by pyarrow for {pa_dtype}",
            )
        )
    # NOTE(review): this take/Series pair repeats the setup above and looks
    # redundant -- confirm whether it can be removed.
    data = data.take([0, 0, 0])
    ser = pd.Series(data)
    result = ser.quantile(q=quantile, interpolation=interpolation)
    if pa.types.is_timestamp(pa_dtype) and interpolation not in ["lower", "higher"]:
        # rounding error will make the check below fail
        # (e.g. '2020-01-01 01:01:01.000001' vs '2020-01-01 01:01:01.000001024'),
        # so we'll check for now that we match the numpy analogue
        if pa_dtype.tz:
            pd_dtype = f"M8[{pa_dtype.unit}, {pa_dtype.tz}]"
        else:
            pd_dtype = f"M8[{pa_dtype.unit}]"
        ser_np = ser.astype(pd_dtype)
        expected = ser_np.quantile(q=quantile, interpolation=interpolation)
        if quantile == 0.5:
            if pa_dtype.unit == "us":
                expected = expected.to_pydatetime(warn=False)
            assert result == expected
        else:
            if pa_dtype.unit == "us":
                expected = expected.dt.floor("us")
            tm.assert_series_equal(result, expected.astype(data.dtype))
        return
    if quantile == 0.5:
        # scalar quantile of a constant series is the constant itself
        assert result == data[0]
    else:
        # Just check the values
        expected = pd.Series(data.take([0, 0]), index=[0.5, 0.5])
        if (
            pa.types.is_integer(pa_dtype)
            or pa.types.is_floating(pa_dtype)
            or pa.types.is_decimal(pa_dtype)
        ):
            expected = expected.astype("float64[pyarrow]")
            result = result.astype("float64[pyarrow]")
        tm.assert_series_equal(result, expected)
  1361. @pytest.mark.parametrize(
  1362. "take_idx, exp_idx",
  1363. [[[0, 0, 2, 2, 4, 4], [0, 4]], [[0, 0, 0, 2, 4, 4], [0]]],
  1364. ids=["multi_mode", "single_mode"],
  1365. )
  1366. def test_mode_dropna_true(data_for_grouping, take_idx, exp_idx):
  1367. data = data_for_grouping.take(take_idx)
  1368. ser = pd.Series(data)
  1369. result = ser.mode(dropna=True)
  1370. expected = pd.Series(data_for_grouping.take(exp_idx))
  1371. tm.assert_series_equal(result, expected)
  1372. def test_mode_dropna_false_mode_na(data):
  1373. # GH 50982
  1374. more_nans = pd.Series([None, None, data[0]], dtype=data.dtype)
  1375. result = more_nans.mode(dropna=False)
  1376. expected = pd.Series([None], dtype=data.dtype)
  1377. tm.assert_series_equal(result, expected)
  1378. expected = pd.Series([None, data[0]], dtype=data.dtype)
  1379. result = expected.mode(dropna=False)
  1380. tm.assert_series_equal(result, expected)
@pytest.mark.parametrize(
    "arrow_dtype, expected_type",
    [
        [pa.binary(), bytes],
        [pa.binary(16), bytes],
        [pa.large_binary(), bytes],
        [pa.large_string(), str],
        [pa.list_(pa.int64()), list],
        [pa.large_list(pa.int64()), list],
        [pa.map_(pa.string(), pa.int64()), list],
        [pa.struct([("f1", pa.int8()), ("f2", pa.string())]), dict],
        [pa.dictionary(pa.int64(), pa.int64()), CategoricalDtypeType],
    ],
)
def test_arrow_dtype_type(arrow_dtype, expected_type):
    # GH 51845
    # TODO: Redundant with test_getitem_scalar once arrow_dtype exists in data fixture
    # ArrowDtype.type maps each pyarrow type to its python scalar type
    assert ArrowDtype(arrow_dtype).type == expected_type
  1399. def test_is_bool_dtype():
  1400. # GH 22667
  1401. data = ArrowExtensionArray(pa.array([True, False, True]))
  1402. assert is_bool_dtype(data)
  1403. assert pd.core.common.is_bool_indexer(data)
  1404. s = pd.Series(range(len(data)))
  1405. result = s[data]
  1406. expected = s[np.asarray(data)]
  1407. tm.assert_series_equal(result, expected)
def test_is_numeric_dtype(data):
    # GH 50563
    # Numeric introspection covers integer, floating and decimal arrow types.
    pa_type = data.dtype.pyarrow_dtype
    if (
        pa.types.is_floating(pa_type)
        or pa.types.is_integer(pa_type)
        or pa.types.is_decimal(pa_type)
    ):
        assert is_numeric_dtype(data)
    else:
        assert not is_numeric_dtype(data)


def test_is_integer_dtype(data):
    # GH 50667
    pa_type = data.dtype.pyarrow_dtype
    if pa.types.is_integer(pa_type):
        assert is_integer_dtype(data)
    else:
        assert not is_integer_dtype(data)


def test_is_any_integer_dtype(data):
    # GH 50667
    pa_type = data.dtype.pyarrow_dtype
    if pa.types.is_integer(pa_type):
        assert is_any_int_dtype(data)
    else:
        assert not is_any_int_dtype(data)


def test_is_signed_integer_dtype(data):
    pa_type = data.dtype.pyarrow_dtype
    if pa.types.is_signed_integer(pa_type):
        assert is_signed_integer_dtype(data)
    else:
        assert not is_signed_integer_dtype(data)


def test_is_unsigned_integer_dtype(data):
    pa_type = data.dtype.pyarrow_dtype
    if pa.types.is_unsigned_integer(pa_type):
        assert is_unsigned_integer_dtype(data)
    else:
        assert not is_unsigned_integer_dtype(data)


def test_is_float_dtype(data):
    pa_type = data.dtype.pyarrow_dtype
    if pa.types.is_floating(pa_type):
        assert is_float_dtype(data)
    else:
        assert not is_float_dtype(data)
def test_pickle_roundtrip(data):
    # GH 42600
    # Pickling a sliced Series should serialize only the sliced data
    # (smaller payload) and both full and sliced objects must round-trip.
    expected = pd.Series(data)
    expected_sliced = expected.head(2)
    full_pickled = pickle.dumps(expected)
    sliced_pickled = pickle.dumps(expected_sliced)

    assert len(full_pickled) > len(sliced_pickled)

    result = pickle.loads(full_pickled)
    tm.assert_series_equal(result, expected)

    result_sliced = pickle.loads(sliced_pickled)
    tm.assert_series_equal(result_sliced, expected_sliced)
def test_astype_from_non_pyarrow(data):
    # GH49795
    # Casting a non-pyarrow ExtensionArray to an ArrowDtype round-trips
    # the values into a pyarrow-backed array.
    pd_array = data._data.to_pandas().array
    result = pd_array.astype(data.dtype)
    assert not isinstance(pd_array.dtype, ArrowDtype)
    assert isinstance(result.dtype, ArrowDtype)
    tm.assert_extension_array_equal(result, data)


def test_astype_float_from_non_pyarrow_str():
    # GH50430
    # Numeric strings in an object Series cast cleanly to float64[pyarrow].
    ser = pd.Series(["1.0"])
    result = ser.astype("float64[pyarrow]")
    expected = pd.Series([1.0], dtype="float64[pyarrow]")
    tm.assert_series_equal(result, expected)
def test_to_numpy_with_defaults(data):
    # GH49973
    result = data.to_numpy()

    pa_type = data._data.type
    if pa.types.is_duration(pa_type) or pa.types.is_timestamp(pa_type):
        # temporal types convert elementwise to Python scalars
        expected = np.array(list(data))
    else:
        expected = np.array(data._data)

    if data._hasna:
        # NAs force an object-dtype result holding pd.NA
        expected = expected.astype(object)
        expected[pd.isna(data)] = pd.NA

    tm.assert_numpy_array_equal(result, expected)


def test_to_numpy_int_with_na():
    # GH51227: ensure to_numpy does not convert int to float
    data = [1, None]
    arr = pd.array(data, dtype="int64[pyarrow]")
    result = arr.to_numpy()
    expected = np.array([1, pd.NA], dtype=object)
    assert isinstance(result[0], int)
    tm.assert_numpy_array_equal(result, expected)


@pytest.mark.parametrize("na_val, exp", [(lib.no_default, np.nan), (1, 1)])
def test_to_numpy_null_array(na_val, exp):
    # GH#52443
    # null[pyarrow] arrays honor an explicit na_value when converting.
    arr = pd.array([pd.NA, pd.NA], dtype="null[pyarrow]")
    result = arr.to_numpy(dtype="float64", na_value=na_val)
    expected = np.array([exp] * 2, dtype="float64")
    tm.assert_numpy_array_equal(result, expected)


def test_to_numpy_null_array_no_dtype():
    # GH#52443
    # Without a requested dtype, null[pyarrow] converts to object with pd.NA.
    arr = pd.array([pd.NA, pd.NA], dtype="null[pyarrow]")
    result = arr.to_numpy(dtype=None)
    expected = np.array([pd.NA] * 2, dtype="object")
    tm.assert_numpy_array_equal(result, expected)
def test_setitem_null_slice(data):
    # GH50248
    # Assigning through a full slice with a scalar, a reversed array,
    # or a plain list must all broadcast/copy correctly.
    orig = data.copy()

    result = orig.copy()
    result[:] = data[0]
    expected = ArrowExtensionArray(
        pa.array([data[0]] * len(data), type=data._data.type)
    )
    tm.assert_extension_array_equal(result, expected)

    result = orig.copy()
    result[:] = data[::-1]
    expected = data[::-1]
    tm.assert_extension_array_equal(result, expected)

    result = orig.copy()
    result[:] = data.tolist()
    expected = data
    tm.assert_extension_array_equal(result, expected)


def test_setitem_invalid_dtype(data):
    # GH50248
    # Assigning a value incompatible with the arrow type raises; the
    # exception type/message depends on the underlying pyarrow type.
    pa_type = data._data.type
    if pa.types.is_string(pa_type) or pa.types.is_binary(pa_type):
        fill_value = 123
        err = TypeError
        msg = "Invalid value '123' for dtype"
    elif (
        pa.types.is_integer(pa_type)
        or pa.types.is_floating(pa_type)
        or pa.types.is_boolean(pa_type)
    ):
        fill_value = "foo"
        err = pa.ArrowInvalid
        msg = "Could not convert"
    else:
        fill_value = "foo"
        err = TypeError
        msg = "Invalid value 'foo' for dtype"
    with pytest.raises(err, match=msg):
        data[:] = fill_value
@pytest.mark.skipif(pa_version_under8p0, reason="returns object with 7.0")
def test_from_arrow_respecting_given_dtype():
    # to_pandas(types_mapper=...) must honor the requested ArrowDtype,
    # here safely upcasting date32 -> date64.
    date_array = pa.array(
        [pd.Timestamp("2019-12-31"), pd.Timestamp("2019-12-31")], type=pa.date32()
    )
    result = date_array.to_pandas(
        types_mapper={pa.date32(): ArrowDtype(pa.date64())}.get
    )
    expected = pd.Series(
        [pd.Timestamp("2019-12-31"), pd.Timestamp("2019-12-31")],
        dtype=ArrowDtype(pa.date64()),
    )
    tm.assert_series_equal(result, expected)


@pytest.mark.skipif(pa_version_under8p0, reason="doesn't raise with 7")
def test_from_arrow_respecting_given_dtype_unsafe():
    # An unsafe cast requested via types_mapper (float -> int with
    # truncation) must raise rather than silently lose data.
    array = pa.array([1.5, 2.5], type=pa.float64())
    with pytest.raises(pa.ArrowInvalid, match="Float value 1.5 was truncated"):
        array.to_pandas(types_mapper={pa.float64(): ArrowDtype(pa.int64())}.get)
  1564. def test_round():
  1565. dtype = "float64[pyarrow]"
  1566. ser = pd.Series([0.0, 1.23, 2.56, pd.NA], dtype=dtype)
  1567. result = ser.round(1)
  1568. expected = pd.Series([0.0, 1.2, 2.6, pd.NA], dtype=dtype)
  1569. tm.assert_series_equal(result, expected)
  1570. ser = pd.Series([123.4, pd.NA, 56.78], dtype=dtype)
  1571. result = ser.round(-1)
  1572. expected = pd.Series([120.0, pd.NA, 60.0], dtype=dtype)
  1573. tm.assert_series_equal(result, expected)
def test_searchsorted_with_na_raises(data_for_sorting, as_series):
    # GH50447
    # searchsorted requires a sorted array; the presence of NA makes a
    # total order impossible, so it must raise.
    b, c, a = data_for_sorting
    arr = data_for_sorting.take([2, 0, 1])  # to get [a, b, c]
    arr[-1] = pd.NA
    if as_series:
        arr = pd.Series(arr)

    msg = (
        "searchsorted requires array to be sorted, "
        "which is impossible with NAs present."
    )
    with pytest.raises(ValueError, match=msg):
        arr.searchsorted(b)


def test_sort_values_dictionary():
    # Sorting by a dictionary-encoded arrow column must not change an
    # already-sorted frame.
    df = pd.DataFrame(
        {
            "a": pd.Series(
                ["x", "y"], dtype=ArrowDtype(pa.dictionary(pa.int32(), pa.string()))
            ),
            "b": [1, 2],
        },
    )
    expected = df.copy()
    result = df.sort_values(by=["a", "b"])
    tm.assert_frame_equal(result, expected)
@pytest.mark.parametrize("pat", ["abc", "a[a-z]{2}"])
def test_str_count(pat):
    # str.count works with both literal and regex patterns; pyarrow's
    # count kernel returns int32.
    ser = pd.Series(["abc", None], dtype=ArrowDtype(pa.string()))
    result = ser.str.count(pat)
    expected = pd.Series([1, None], dtype=ArrowDtype(pa.int32()))
    tm.assert_series_equal(result, expected)


def test_str_count_flags_unsupported():
    # regex flags are not supported by the arrow-backed implementation
    ser = pd.Series(["abc", None], dtype=ArrowDtype(pa.string()))
    with pytest.raises(NotImplementedError, match="count not"):
        ser.str.count("abc", flags=1)


@pytest.mark.parametrize(
    "side, str_func", [["left", "rjust"], ["right", "ljust"], ["both", "center"]]
)
def test_str_pad(side, str_func):
    # str.pad matches the behavior of the corresponding str builtin
    # (rjust/ljust/center) for each padding side.
    ser = pd.Series(["a", None], dtype=ArrowDtype(pa.string()))
    result = ser.str.pad(width=3, side=side, fillchar="x")
    expected = pd.Series(
        [getattr("a", str_func)(3, "x"), None], dtype=ArrowDtype(pa.string())
    )
    tm.assert_series_equal(result, expected)


def test_str_pad_invalid_side():
    ser = pd.Series(["a", None], dtype=ArrowDtype(pa.string()))
    with pytest.raises(ValueError, match="Invalid side: foo"):
        ser.str.pad(3, "foo", "x")
@pytest.mark.parametrize(
    "pat, case, na, exp",
    [
        ["ab", False, None, [True, None]],
        ["Ab", True, None, [False, None]],
        ["ab", False, True, [True, True]],
        ["a[a-z]{1}", False, None, [True, None]],
        ["A[a-z]{1}", True, None, [False, None]],
    ],
)
def test_str_contains(pat, case, na, exp):
    # str.contains with case sensitivity, NA fill value, and literal
    # vs regex matching.
    ser = pd.Series(["abc", None], dtype=ArrowDtype(pa.string()))
    result = ser.str.contains(pat, case=case, na=na, regex=regex)
    expected = pd.Series(exp, dtype=ArrowDtype(pa.bool_()))
    tm.assert_series_equal(result, expected)


def test_str_contains_flags_unsupported():
    ser = pd.Series(["abc", None], dtype=ArrowDtype(pa.string()))
    with pytest.raises(NotImplementedError, match="contains not"):
        ser.str.contains("a", flags=1)


@pytest.mark.parametrize(
    "side, pat, na, exp",
    [
        ["startswith", "ab", None, [True, None]],
        ["startswith", "b", False, [False, False]],
        ["endswith", "b", True, [False, True]],
        ["endswith", "bc", None, [True, None]],
    ],
)
def test_str_start_ends_with(side, pat, na, exp):
    # startswith/endswith honor the na fill value for missing entries
    ser = pd.Series(["abc", None], dtype=ArrowDtype(pa.string()))
    result = getattr(ser.str, side)(pat, na=na)
    expected = pd.Series(exp, dtype=ArrowDtype(pa.bool_()))
    tm.assert_series_equal(result, expected)
@pytest.mark.parametrize(
    "arg_name, arg",
    [["pat", re.compile("b")], ["repl", str], ["case", False], ["flags", 1]],
)
def test_str_replace_unsupported(arg_name, arg):
    # compiled patterns, callable repl, case=False and flags are all
    # unsupported by the arrow-backed replace
    ser = pd.Series(["abc", None], dtype=ArrowDtype(pa.string()))
    kwargs = {"pat": "b", "repl": "x", "regex": True}
    kwargs[arg_name] = arg
    with pytest.raises(NotImplementedError, match="replace is not supported"):
        ser.str.replace(**kwargs)


@pytest.mark.parametrize(
    "pat, repl, n, regex, exp",
    [
        ["a", "x", -1, False, ["xbxc", None]],
        ["a", "x", 1, False, ["xbac", None]],
        ["[a-b]", "x", -1, True, ["xxxc", None]],
    ],
)
def test_str_replace(pat, repl, n, regex, exp):
    # str.replace: literal and regex patterns, limited (n=1) and
    # unlimited (n=-1) replacement counts.
    ser = pd.Series(["abac", None], dtype=ArrowDtype(pa.string()))
    result = ser.str.replace(pat, repl, n=n, regex=regex)
    expected = pd.Series(exp, dtype=ArrowDtype(pa.string()))
    tm.assert_series_equal(result, expected)


def test_str_repeat_unsupported():
    # per-element repeat counts (sequence argument) are not supported
    ser = pd.Series(["abc", None], dtype=ArrowDtype(pa.string()))
    with pytest.raises(NotImplementedError, match="repeat is not"):
        ser.str.repeat([1, 2])


@pytest.mark.xfail(
    pa_version_under7p0,
    reason="Unsupported for pyarrow < 7",
    raises=NotImplementedError,
)
def test_str_repeat():
    ser = pd.Series(["abc", None], dtype=ArrowDtype(pa.string()))
    result = ser.str.repeat(2)
    expected = pd.Series(["abcabc", None], dtype=ArrowDtype(pa.string()))
    tm.assert_series_equal(result, expected)
@pytest.mark.parametrize(
    "pat, case, na, exp",
    [
        ["ab", False, None, [True, None]],
        ["Ab", True, None, [False, None]],
        ["bc", True, None, [False, None]],
        ["ab", False, True, [True, True]],
        ["a[a-z]{1}", False, None, [True, None]],
        ["A[a-z]{1}", True, None, [False, None]],
    ],
)
def test_str_match(pat, case, na, exp):
    # str.match anchors at the start of the string only (prefix match)
    ser = pd.Series(["abc", None], dtype=ArrowDtype(pa.string()))
    result = ser.str.match(pat, case=case, na=na)
    expected = pd.Series(exp, dtype=ArrowDtype(pa.bool_()))
    tm.assert_series_equal(result, expected)
  1709. @pytest.mark.parametrize(
  1710. "pat, case, na, exp",
  1711. [
  1712. ["abc", False, None, [True, None]],
  1713. ["Abc", True, None, [False, None]],
  1714. ["bc", True, None, [False, None]],
  1715. ["ab", False, True, [True, True]],
  1716. ["a[a-z]{2}", False, None, [True, None]],
  1717. ["A[a-z]{1}", True, None, [False, None]],
  1718. ],
  1719. )
  1720. def test_str_fullmatch(pat, case, na, exp):
  1721. ser = pd.Series(["abc", None], dtype=ArrowDtype(pa.string()))
  1722. result = ser.str.match(pat, case=case, na=na)
  1723. expected = pd.Series(exp, dtype=ArrowDtype(pa.bool_()))
  1724. tm.assert_series_equal(result, expected)
@pytest.mark.parametrize(
    "sub, start, end, exp, exp_typ",
    [["ab", 0, None, [0, None], pa.int32()], ["bc", 1, 3, [2, None], pa.int64()]],
)
def test_str_find(sub, start, end, exp, exp_typ):
    # str.find returns the match position; the result integer width
    # depends on which pyarrow kernel handles the start/end combination.
    ser = pd.Series(["abc", None], dtype=ArrowDtype(pa.string()))
    result = ser.str.find(sub, start=start, end=end)
    expected = pd.Series(exp, dtype=ArrowDtype(exp_typ))
    tm.assert_series_equal(result, expected)


def test_str_find_notimplemented():
    # a nonzero start without an end is not supported
    ser = pd.Series(["abc", None], dtype=ArrowDtype(pa.string()))
    with pytest.raises(NotImplementedError, match="find not implemented"):
        ser.str.find("ab", start=1)
@pytest.mark.parametrize(
    "i, exp",
    [
        [1, ["b", "e", None]],
        [-1, ["c", "e", None]],
        [2, ["c", None, None]],
        [-3, ["a", None, None]],
        [4, [None, None, None]],
    ],
)
def test_str_get(i, exp):
    # str.get supports negative indices and yields None for out-of-range
    ser = pd.Series(["abc", "de", None], dtype=ArrowDtype(pa.string()))
    result = ser.str.get(i)
    expected = pd.Series(exp, dtype=ArrowDtype(pa.string()))
    tm.assert_series_equal(result, expected)


@pytest.mark.xfail(
    reason="TODO: StringMethods._validate should support Arrow list types",
    raises=AttributeError,
)
def test_str_join():
    ser = pd.Series(ArrowExtensionArray(pa.array([list("abc"), list("123"), None])))
    result = ser.str.join("=")
    expected = pd.Series(["a=b=c", "1=2=3", None], dtype=ArrowDtype(pa.string()))
    tm.assert_series_equal(result, expected)
@pytest.mark.parametrize(
    "start, stop, step, exp",
    [
        [None, 2, None, ["ab", None]],
        [None, 2, 1, ["ab", None]],
        [1, 3, 1, ["bc", None]],
    ],
)
def test_str_slice(start, stop, step, exp):
    # str.slice mirrors Python slicing semantics on each element
    ser = pd.Series(["abcd", None], dtype=ArrowDtype(pa.string()))
    result = ser.str.slice(start, stop, step)
    expected = pd.Series(exp, dtype=ArrowDtype(pa.string()))
    tm.assert_series_equal(result, expected)


@pytest.mark.parametrize(
    "start, stop, repl, exp",
    [
        [1, 2, "x", ["axcd", None]],
        [None, 2, "x", ["xcd", None]],
        [None, 2, None, ["cd", None]],
    ],
)
def test_str_slice_replace(start, stop, repl, exp):
    # repl=None deletes the sliced span instead of replacing it
    ser = pd.Series(["abcd", None], dtype=ArrowDtype(pa.string()))
    result = ser.str.slice_replace(start, stop, repl)
    expected = pd.Series(exp, dtype=ArrowDtype(pa.string()))
    tm.assert_series_equal(result, expected)
@pytest.mark.parametrize(
    "value, method, exp",
    [
        ["a1c", "isalnum", True],
        ["!|,", "isalnum", False],
        ["aaa", "isalpha", True],
        ["!!!", "isalpha", False],
        ["٠", "isdecimal", True],
        ["~!", "isdecimal", False],
        ["2", "isdigit", True],
        ["~", "isdigit", False],
        ["aaa", "islower", True],
        ["aaA", "islower", False],
        ["123", "isnumeric", True],
        ["11I", "isnumeric", False],
        [" ", "isspace", True],
        ["", "isspace", False],
        ["The That", "istitle", True],
        ["the That", "istitle", False],
        ["AAA", "isupper", True],
        ["AAc", "isupper", False],
    ],
)
def test_str_is_functions(value, method, exp):
    # one positive and one negative example per str.is* predicate;
    # None always propagates as None.
    ser = pd.Series([value, None], dtype=ArrowDtype(pa.string()))
    result = getattr(ser.str, method)()
    expected = pd.Series([exp, None], dtype=ArrowDtype(pa.bool_()))
    tm.assert_series_equal(result, expected)


@pytest.mark.parametrize(
    "method, exp",
    [
        ["capitalize", "Abc def"],
        ["title", "Abc Def"],
        ["swapcase", "AbC Def"],
        ["lower", "abc def"],
        ["upper", "ABC DEF"],
        ["casefold", "abc def"],
    ],
)
def test_str_transform_functions(method, exp):
    # case-transform methods on arrow-backed strings
    ser = pd.Series(["aBc dEF", None], dtype=ArrowDtype(pa.string()))
    result = getattr(ser.str, method)()
    expected = pd.Series([exp, None], dtype=ArrowDtype(pa.string()))
    tm.assert_series_equal(result, expected)


def test_str_len():
    # pyarrow's utf8_length kernel returns int32
    ser = pd.Series(["abcd", None], dtype=ArrowDtype(pa.string()))
    result = ser.str.len()
    expected = pd.Series([4, None], dtype=ArrowDtype(pa.int32()))
    tm.assert_series_equal(result, expected)
@pytest.mark.parametrize(
    "method, to_strip, val",
    [
        ["strip", None, " abc "],
        ["strip", "x", "xabcx"],
        ["lstrip", None, " abc"],
        ["lstrip", "x", "xabc"],
        ["rstrip", None, "abc "],
        ["rstrip", "x", "abcx"],
    ],
)
def test_str_strip(method, to_strip, val):
    # strip/lstrip/rstrip with default whitespace and explicit characters
    ser = pd.Series([val, None], dtype=ArrowDtype(pa.string()))
    result = getattr(ser.str, method)(to_strip=to_strip)
    expected = pd.Series(["abc", None], dtype=ArrowDtype(pa.string()))
    tm.assert_series_equal(result, expected)


@pytest.mark.parametrize("val", ["abc123", "abc"])
def test_str_removesuffix(val):
    # removing a suffix that may or may not be present
    ser = pd.Series([val, None], dtype=ArrowDtype(pa.string()))
    result = ser.str.removesuffix("123")
    expected = pd.Series(["abc", None], dtype=ArrowDtype(pa.string()))
    tm.assert_series_equal(result, expected)


@pytest.mark.parametrize("val", ["123abc", "abc"])
def test_str_removeprefix(val):
    # removing a prefix that may or may not be present
    ser = pd.Series([val, None], dtype=ArrowDtype(pa.string()))
    result = ser.str.removeprefix("123")
    expected = pd.Series(["abc", None], dtype=ArrowDtype(pa.string()))
    tm.assert_series_equal(result, expected)
@pytest.mark.parametrize("errors", ["ignore", "strict"])
@pytest.mark.parametrize(
    "encoding, exp",
    [
        ["utf8", b"abc"],
        ["utf32", b"\xff\xfe\x00\x00a\x00\x00\x00b\x00\x00\x00c\x00\x00\x00"],
    ],
)
def test_str_encode(errors, encoding, exp):
    # str.encode yields binary-dtype results; utf32 includes a BOM
    ser = pd.Series(["abc", None], dtype=ArrowDtype(pa.string()))
    result = ser.str.encode(encoding, errors)
    expected = pd.Series([exp, None], dtype=ArrowDtype(pa.binary()))
    tm.assert_series_equal(result, expected)


@pytest.mark.parametrize("flags", [0, 1])
def test_str_findall(flags):
    # findall returns a list-typed column of all matches per element
    ser = pd.Series(["abc", "efg", None], dtype=ArrowDtype(pa.string()))
    result = ser.str.findall("b", flags=flags)
    expected = pd.Series([["b"], [], None], dtype=ArrowDtype(pa.list_(pa.string())))
    tm.assert_series_equal(result, expected)
@pytest.mark.parametrize("method", ["index", "rindex"])
@pytest.mark.parametrize(
    "start, end",
    [
        [0, None],
        [1, 4],
    ],
)
def test_str_r_index(method, start, end):
    # index/rindex return the match position and raise when absent
    ser = pd.Series(["abcba", None], dtype=ArrowDtype(pa.string()))
    result = getattr(ser.str, method)("c", start, end)
    expected = pd.Series([2, None], dtype=ArrowDtype(pa.int64()))
    tm.assert_series_equal(result, expected)

    with pytest.raises(ValueError, match="substring not found"):
        getattr(ser.str, method)("foo", start, end)


@pytest.mark.parametrize("form", ["NFC", "NFKC"])
def test_str_normalize(form):
    # plain ASCII is already normalized under any unicode form
    ser = pd.Series(["abc", None], dtype=ArrowDtype(pa.string()))
    result = ser.str.normalize(form)
    expected = ser.copy()
    tm.assert_series_equal(result, expected)


@pytest.mark.parametrize(
    "start, end",
    [
        [0, None],
        [1, 4],
    ],
)
def test_str_rfind(start, end):
    # rfind returns -1 (not an error) when the substring is absent
    ser = pd.Series(["abcba", "foo", None], dtype=ArrowDtype(pa.string()))
    result = ser.str.rfind("c", start, end)
    expected = pd.Series([2, -1, None], dtype=ArrowDtype(pa.int64()))
    tm.assert_series_equal(result, expected)
def test_str_translate():
    # translate maps code points (97 == "a") to replacement strings
    ser = pd.Series(["abcba", None], dtype=ArrowDtype(pa.string()))
    result = ser.str.translate({97: "b"})
    expected = pd.Series(["bbcbb", None], dtype=ArrowDtype(pa.string()))
    tm.assert_series_equal(result, expected)


def test_str_wrap():
    # wrap inserts newlines at the requested width
    ser = pd.Series(["abcba", None], dtype=ArrowDtype(pa.string()))
    result = ser.str.wrap(3)
    expected = pd.Series(["abc\nba", None], dtype=ArrowDtype(pa.string()))
    tm.assert_series_equal(result, expected)


def test_get_dummies():
    # get_dummies splits on "|" and one-hot encodes; a None row is all False
    ser = pd.Series(["a|b", None, "a|c"], dtype=ArrowDtype(pa.string()))
    result = ser.str.get_dummies()
    expected = pd.DataFrame(
        [[True, True, False], [False, False, False], [True, False, True]],
        dtype=ArrowDtype(pa.bool_()),
        columns=["a", "b", "c"],
    )
    tm.assert_frame_equal(result, expected)
def test_str_partition():
    # partition/rpartition split around the first/last separator; with
    # expand=False a list-typed Series is returned instead of a DataFrame.
    ser = pd.Series(["abcba", None], dtype=ArrowDtype(pa.string()))
    result = ser.str.partition("b")
    expected = pd.DataFrame(
        [["a", "b", "cba"], [None, None, None]], dtype=ArrowDtype(pa.string())
    )
    tm.assert_frame_equal(result, expected)

    result = ser.str.partition("b", expand=False)
    expected = pd.Series(ArrowExtensionArray(pa.array([["a", "b", "cba"], None])))
    tm.assert_series_equal(result, expected)

    result = ser.str.rpartition("b")
    expected = pd.DataFrame(
        [["abc", "b", "a"], [None, None, None]], dtype=ArrowDtype(pa.string())
    )
    tm.assert_frame_equal(result, expected)

    result = ser.str.rpartition("b", expand=False)
    expected = pd.Series(ArrowExtensionArray(pa.array([["abc", "b", "a"], None])))
    tm.assert_series_equal(result, expected)
def test_str_split():
    # GH 52401
    # split: unlimited, limited (n=1), regex, and expand=True variants.
    ser = pd.Series(["a1cbcb", "a2cbcb", None], dtype=ArrowDtype(pa.string()))
    result = ser.str.split("c")
    expected = pd.Series(
        ArrowExtensionArray(pa.array([["a1", "b", "b"], ["a2", "b", "b"], None]))
    )
    tm.assert_series_equal(result, expected)

    result = ser.str.split("c", n=1)
    expected = pd.Series(
        ArrowExtensionArray(pa.array([["a1", "bcb"], ["a2", "bcb"], None]))
    )
    tm.assert_series_equal(result, expected)

    result = ser.str.split("[1-2]", regex=True)
    expected = pd.Series(
        ArrowExtensionArray(pa.array([["a", "cbcb"], ["a", "cbcb"], None]))
    )
    tm.assert_series_equal(result, expected)

    result = ser.str.split("[1-2]", regex=True, expand=True)
    expected = pd.DataFrame(
        {
            0: ArrowExtensionArray(pa.array(["a", "a", None])),
            1: ArrowExtensionArray(pa.array(["cbcb", "cbcb", None])),
        }
    )
    tm.assert_frame_equal(result, expected)

    result = ser.str.split("1", expand=True)
    expected = pd.DataFrame(
        {
            0: ArrowExtensionArray(pa.array(["a", "a2cbcb", None])),
            1: ArrowExtensionArray(pa.array(["cbcb", None, None])),
        }
    )
    tm.assert_frame_equal(result, expected)


def test_str_rsplit():
    # GH 52401
    # rsplit splits from the right; n=1 keeps the leftmost part intact.
    ser = pd.Series(["a1cbcb", "a2cbcb", None], dtype=ArrowDtype(pa.string()))
    result = ser.str.rsplit("c")
    expected = pd.Series(
        ArrowExtensionArray(pa.array([["a1", "b", "b"], ["a2", "b", "b"], None]))
    )
    tm.assert_series_equal(result, expected)

    result = ser.str.rsplit("c", n=1)
    expected = pd.Series(
        ArrowExtensionArray(pa.array([["a1cb", "b"], ["a2cb", "b"], None]))
    )
    tm.assert_series_equal(result, expected)

    result = ser.str.rsplit("c", n=1, expand=True)
    expected = pd.DataFrame(
        {
            0: ArrowExtensionArray(pa.array(["a1cb", "a2cb", None])),
            1: ArrowExtensionArray(pa.array(["b", "b", None])),
        }
    )
    tm.assert_frame_equal(result, expected)

    result = ser.str.rsplit("1", expand=True)
    expected = pd.DataFrame(
        {
            0: ArrowExtensionArray(pa.array(["a", "a2cbcb", None])),
            1: ArrowExtensionArray(pa.array(["cbcb", None, None])),
        }
    )
    tm.assert_frame_equal(result, expected)
def test_str_unsupported_extract():
    # str.extract is not implemented for ArrowDtype-backed strings
    ser = pd.Series(["abc", None], dtype=ArrowDtype(pa.string()))
    with pytest.raises(
        NotImplementedError, match="str.extract not supported with pd.ArrowDtype"
    ):
        ser.str.extract(r"[ab](\d)")


@pytest.mark.parametrize("unit", ["ns", "us", "ms", "s"])
def test_duration_from_strings_with_nat(unit):
    # GH51175
    # "NaT" strings parse to null for every duration unit
    strings = ["1000", "NaT"]
    pa_type = pa.duration(unit)
    result = ArrowExtensionArray._from_sequence_of_strings(strings, dtype=pa_type)
    expected = ArrowExtensionArray(pa.array([1000, None], type=pa_type))
    tm.assert_extension_array_equal(result, expected)
def test_unsupported_dt(data):
    # the .dt accessor is only available for temporal arrow types
    pa_dtype = data.dtype.pyarrow_dtype
    if not pa.types.is_temporal(pa_dtype):
        with pytest.raises(
            AttributeError, match="Can only use .dt accessor with datetimelike values"
        ):
            pd.Series(data).dt
@pytest.mark.parametrize(
    "prop, expected",
    [
        ["year", 2023],
        ["day", 2],
        ["day_of_week", 0],
        ["dayofweek", 0],
        ["weekday", 0],
        ["day_of_year", 2],
        ["dayofyear", 2],
        ["hour", 3],
        ["minute", 4],
        pytest.param(
            "is_leap_year",
            False,
            marks=pytest.mark.xfail(
                pa_version_under8p0,
                raises=NotImplementedError,
                reason="is_leap_year not implemented for pyarrow < 8.0",
            ),
        ),
        ["microsecond", 5],
        ["month", 1],
        ["nanosecond", 6],
        ["quarter", 1],
        ["second", 7],
        ["date", date(2023, 1, 2)],
        ["time", time(3, 4, 7, 5)],
    ],
)
def test_dt_properties(prop, expected):
    # each .dt property is checked against a single fully-specified
    # timestamp (2023-01-02 03:04:07.000005006); None propagates as None.
    ser = pd.Series(
        [
            pd.Timestamp(
                year=2023,
                month=1,
                day=2,
                hour=3,
                minute=4,
                second=7,
                microsecond=5,
                nanosecond=6,
            ),
            None,
        ],
        dtype=ArrowDtype(pa.timestamp("ns")),
    )
    result = getattr(ser.dt, prop)
    exp_type = None
    if isinstance(expected, date):
        exp_type = pa.date32()
    elif isinstance(expected, time):
        exp_type = pa.time64("ns")
    expected = pd.Series(ArrowExtensionArray(pa.array([expected, None], type=exp_type)))
    tm.assert_series_equal(result, expected)
@pytest.mark.parametrize("unit", ["us", "ns"])
def test_dt_time_preserve_unit(unit):
    # .dt.time keeps the timestamp's unit in the resulting time64 type
    ser = pd.Series(
        [datetime(year=2023, month=1, day=2, hour=3), None],
        dtype=ArrowDtype(pa.timestamp(unit)),
    )
    result = ser.dt.time
    expected = pd.Series(
        ArrowExtensionArray(pa.array([time(3, 0), None], type=pa.time64(unit)))
    )
    tm.assert_series_equal(result, expected)


@pytest.mark.parametrize("tz", [None, "UTC", "US/Pacific"])
def test_dt_tz(tz):
    # .dt.tz reflects the timezone stored in the arrow timestamp type
    ser = pd.Series(
        [datetime(year=2023, month=1, day=2, hour=3), None],
        dtype=ArrowDtype(pa.timestamp("ns", tz=tz)),
    )
    result = ser.dt.tz
    assert result == tz
def test_dt_isocalendar():
    # isocalendar returns a (year, week, day) frame; the null row maps
    # to zeros in the int64[pyarrow] output.
    ser = pd.Series(
        [datetime(year=2023, month=1, day=2, hour=3), None],
        dtype=ArrowDtype(pa.timestamp("ns")),
    )
    result = ser.dt.isocalendar()
    expected = pd.DataFrame(
        [[2023, 1, 1], [0, 0, 0]],
        columns=["year", "week", "day"],
        dtype="int64[pyarrow]",
    )
    tm.assert_frame_equal(result, expected)
def test_dt_strftime(request):
    # pyarrow's strftime needs the system tzdata, which is missing on
    # Windows CI machines — expected failure there.
    if is_platform_windows() and is_ci_environment():
        request.node.add_marker(
            pytest.mark.xfail(
                raises=pa.ArrowInvalid,
                reason=(
                    "TODO: Set ARROW_TIMEZONE_DATABASE environment variable "
                    "on CI to path to the tzdata for pyarrow."
                ),
            )
        )
    ser = pd.Series(
        [datetime(year=2023, month=1, day=2, hour=3), None],
        dtype=ArrowDtype(pa.timestamp("ns")),
    )
    result = ser.dt.strftime("%Y-%m-%dT%H:%M:%S")
    expected = pd.Series(
        ["2023-01-02T03:00:00.000000000", None], dtype=ArrowDtype(pa.string())
    )
    tm.assert_series_equal(result, expected)
@pytest.mark.parametrize("method", ["ceil", "floor", "round"])
def test_dt_roundlike_tz_options_not_supported(method):
    # ambiguous/nonexistent tz handling is not supported by the arrow backend
    ser = pd.Series(
        [datetime(year=2023, month=1, day=2, hour=3), None],
        dtype=ArrowDtype(pa.timestamp("ns")),
    )
    with pytest.raises(NotImplementedError, match="ambiguous is not supported."):
        getattr(ser.dt, method)("1H", ambiguous="NaT")
    with pytest.raises(NotImplementedError, match="nonexistent is not supported."):
        getattr(ser.dt, method)("1H", nonexistent="NaT")


@pytest.mark.parametrize("method", ["ceil", "floor", "round"])
def test_dt_roundlike_unsupported_freq(method):
    # unsupported (business-day) and missing frequencies raise ValueError
    ser = pd.Series(
        [datetime(year=2023, month=1, day=2, hour=3), None],
        dtype=ArrowDtype(pa.timestamp("ns")),
    )
    with pytest.raises(ValueError, match="freq='1B' is not supported"):
        getattr(ser.dt, method)("1B")
    with pytest.raises(ValueError, match="Must specify a valid frequency: None"):
        getattr(ser.dt, method)(None)
@pytest.mark.xfail(
    pa_version_under7p0, reason="Methods not supported for pyarrow < 7.0"
)
@pytest.mark.parametrize("freq", ["D", "H", "T", "S", "L", "U", "N"])
@pytest.mark.parametrize("method", ["ceil", "floor", "round"])
def test_dt_ceil_year_floor(freq, method):
    # the arrow-backed ceil/floor/round must agree with the numpy-backed
    # datetime64[ns] implementation for every supported frequency
    ser = pd.Series(
        [datetime(year=2023, month=1, day=1), None],
    )
    pa_dtype = ArrowDtype(pa.timestamp("ns"))
    expected = getattr(ser.dt, method)(f"1{freq}").astype(pa_dtype)
    result = getattr(ser.astype(pa_dtype).dt, method)(f"1{freq}")
    tm.assert_series_equal(result, expected)
  2177. def test_dt_to_pydatetime():
  2178. # GH 51859
  2179. data = [datetime(2022, 1, 1), datetime(2023, 1, 1)]
  2180. ser = pd.Series(data, dtype=ArrowDtype(pa.timestamp("ns")))
  2181. result = ser.dt.to_pydatetime()
  2182. expected = np.array(data, dtype=object)
  2183. tm.assert_numpy_array_equal(result, expected)
  2184. assert all(type(res) is datetime for res in result)
  2185. expected = ser.astype("datetime64[ns]").dt.to_pydatetime()
  2186. tm.assert_numpy_array_equal(result, expected)
  2187. @pytest.mark.parametrize("date_type", [32, 64])
  2188. def test_dt_to_pydatetime_date_error(date_type):
  2189. # GH 52812
  2190. ser = pd.Series(
  2191. [date(2022, 12, 31)],
  2192. dtype=ArrowDtype(getattr(pa, f"date{date_type}")()),
  2193. )
  2194. with pytest.raises(ValueError, match="to_pydatetime cannot be called with"):
  2195. ser.dt.to_pydatetime()
  2196. def test_dt_tz_localize_unsupported_tz_options():
  2197. ser = pd.Series(
  2198. [datetime(year=2023, month=1, day=2, hour=3), None],
  2199. dtype=ArrowDtype(pa.timestamp("ns")),
  2200. )
  2201. with pytest.raises(NotImplementedError, match="ambiguous='NaT' is not supported"):
  2202. ser.dt.tz_localize("UTC", ambiguous="NaT")
  2203. with pytest.raises(NotImplementedError, match="nonexistent='NaT' is not supported"):
  2204. ser.dt.tz_localize("UTC", nonexistent="NaT")
  2205. def test_dt_tz_localize_none():
  2206. ser = pd.Series(
  2207. [datetime(year=2023, month=1, day=2, hour=3), None],
  2208. dtype=ArrowDtype(pa.timestamp("ns", tz="US/Pacific")),
  2209. )
  2210. result = ser.dt.tz_localize(None)
  2211. expected = pd.Series(
  2212. [datetime(year=2023, month=1, day=2, hour=3), None],
  2213. dtype=ArrowDtype(pa.timestamp("ns")),
  2214. )
  2215. tm.assert_series_equal(result, expected)
  2216. @pytest.mark.parametrize("unit", ["us", "ns"])
  2217. def test_dt_tz_localize(unit, request):
  2218. if is_platform_windows() and is_ci_environment():
  2219. request.node.add_marker(
  2220. pytest.mark.xfail(
  2221. raises=pa.ArrowInvalid,
  2222. reason=(
  2223. "TODO: Set ARROW_TIMEZONE_DATABASE environment variable "
  2224. "on CI to path to the tzdata for pyarrow."
  2225. ),
  2226. )
  2227. )
  2228. ser = pd.Series(
  2229. [datetime(year=2023, month=1, day=2, hour=3), None],
  2230. dtype=ArrowDtype(pa.timestamp(unit)),
  2231. )
  2232. result = ser.dt.tz_localize("US/Pacific")
  2233. exp_data = pa.array(
  2234. [datetime(year=2023, month=1, day=2, hour=3), None], type=pa.timestamp(unit)
  2235. )
  2236. exp_data = pa.compute.assume_timezone(exp_data, "US/Pacific")
  2237. expected = pd.Series(ArrowExtensionArray(exp_data))
  2238. tm.assert_series_equal(result, expected)
  2239. @pytest.mark.parametrize(
  2240. "nonexistent, exp_date",
  2241. [
  2242. ["shift_forward", datetime(year=2023, month=3, day=12, hour=3)],
  2243. ["shift_backward", pd.Timestamp("2023-03-12 01:59:59.999999999")],
  2244. ],
  2245. )
  2246. def test_dt_tz_localize_nonexistent(nonexistent, exp_date, request):
  2247. if is_platform_windows() and is_ci_environment():
  2248. request.node.add_marker(
  2249. pytest.mark.xfail(
  2250. raises=pa.ArrowInvalid,
  2251. reason=(
  2252. "TODO: Set ARROW_TIMEZONE_DATABASE environment variable "
  2253. "on CI to path to the tzdata for pyarrow."
  2254. ),
  2255. )
  2256. )
  2257. ser = pd.Series(
  2258. [datetime(year=2023, month=3, day=12, hour=2, minute=30), None],
  2259. dtype=ArrowDtype(pa.timestamp("ns")),
  2260. )
  2261. result = ser.dt.tz_localize("US/Pacific", nonexistent=nonexistent)
  2262. exp_data = pa.array([exp_date, None], type=pa.timestamp("ns"))
  2263. exp_data = pa.compute.assume_timezone(exp_data, "US/Pacific")
  2264. expected = pd.Series(ArrowExtensionArray(exp_data))
  2265. tm.assert_series_equal(result, expected)
  2266. @pytest.mark.parametrize("skipna", [True, False])
  2267. def test_boolean_reduce_series_all_null(all_boolean_reductions, skipna):
  2268. # GH51624
  2269. ser = pd.Series([None], dtype="float64[pyarrow]")
  2270. result = getattr(ser, all_boolean_reductions)(skipna=skipna)
  2271. if skipna:
  2272. expected = all_boolean_reductions == "all"
  2273. else:
  2274. expected = pd.NA
  2275. assert result is expected
  2276. @pytest.mark.parametrize("dtype", ["string", "string[pyarrow]"])
  2277. def test_series_from_string_array(dtype):
  2278. arr = pa.array("the quick brown fox".split())
  2279. ser = pd.Series(arr, dtype=dtype)
  2280. expected = pd.Series(ArrowExtensionArray(arr), dtype=dtype)
  2281. tm.assert_series_equal(ser, expected)
  2282. def test_setitem_boolean_replace_with_mask_segfault():
  2283. # GH#52059
  2284. N = 145_000
  2285. arr = ArrowExtensionArray(pa.chunked_array([np.ones((N,), dtype=np.bool_)]))
  2286. expected = arr.copy()
  2287. arr[np.zeros((N,), dtype=np.bool_)] = False
  2288. assert arr._data == expected._data
  2289. @pytest.mark.parametrize(
  2290. "data, arrow_dtype",
  2291. [
  2292. ([b"a", b"b"], pa.large_binary()),
  2293. (["a", "b"], pa.large_string()),
  2294. ],
  2295. )
  2296. def test_conversion_large_dtypes_from_numpy_array(data, arrow_dtype):
  2297. dtype = ArrowDtype(arrow_dtype)
  2298. result = pd.array(np.array(data), dtype=dtype)
  2299. expected = pd.array(data, dtype=dtype)
  2300. tm.assert_extension_array_equal(result, expected)
  2301. @pytest.mark.parametrize("pa_type", tm.ALL_INT_PYARROW_DTYPES + tm.FLOAT_PYARROW_DTYPES)
  2302. def test_describe_numeric_data(pa_type):
  2303. # GH 52470
  2304. data = pd.Series([1, 2, 3], dtype=ArrowDtype(pa_type))
  2305. result = data.describe()
  2306. expected = pd.Series(
  2307. [3, 2, 1, 1, 1.5, 2.0, 2.5, 3],
  2308. dtype=ArrowDtype(pa.float64()),
  2309. index=["count", "mean", "std", "min", "25%", "50%", "75%", "max"],
  2310. )
  2311. tm.assert_series_equal(result, expected)
  2312. @pytest.mark.parametrize("pa_type", tm.TIMEDELTA_PYARROW_DTYPES)
  2313. def test_describe_timedelta_data(pa_type):
  2314. # GH53001
  2315. data = pd.Series(range(1, 10), dtype=ArrowDtype(pa_type))
  2316. result = data.describe()
  2317. expected = pd.Series(
  2318. [9] + pd.to_timedelta([5, 2, 1, 3, 5, 7, 9], unit=pa_type.unit).tolist(),
  2319. dtype=object,
  2320. index=["count", "mean", "std", "min", "25%", "50%", "75%", "max"],
  2321. )
  2322. tm.assert_series_equal(result, expected)
  2323. @pytest.mark.parametrize("pa_type", tm.DATETIME_PYARROW_DTYPES)
  2324. def test_describe_datetime_data(pa_type):
  2325. # GH53001
  2326. data = pd.Series(range(1, 10), dtype=ArrowDtype(pa_type))
  2327. result = data.describe()
  2328. expected = pd.Series(
  2329. [9]
  2330. + [
  2331. pd.Timestamp(v, tz=pa_type.tz, unit=pa_type.unit)
  2332. for v in [5, 1, 3, 5, 7, 9]
  2333. ],
  2334. dtype=object,
  2335. index=["count", "mean", "min", "25%", "50%", "75%", "max"],
  2336. )
  2337. tm.assert_series_equal(result, expected)
  2338. @pytest.mark.xfail(
  2339. pa_version_under8p0,
  2340. reason="Function 'add_checked' has no kernel matching input types",
  2341. raises=pa.ArrowNotImplementedError,
  2342. )
  2343. def test_duration_overflow_from_ndarray_containing_nat():
  2344. # GH52843
  2345. data_ts = pd.to_datetime([1, None])
  2346. data_td = pd.to_timedelta([1, None])
  2347. ser_ts = pd.Series(data_ts, dtype=ArrowDtype(pa.timestamp("ns")))
  2348. ser_td = pd.Series(data_td, dtype=ArrowDtype(pa.duration("ns")))
  2349. result = ser_ts + ser_td
  2350. expected = pd.Series([2, None], dtype=ArrowDtype(pa.timestamp("ns")))
  2351. tm.assert_series_equal(result, expected)