# test_array.py (14 KB)

  1. import datetime
  2. import decimal
  3. import numpy as np
  4. import pytest
  5. import pytz
  6. from pandas.core.dtypes.base import _registry as registry
  7. import pandas as pd
  8. import pandas._testing as tm
  9. from pandas.api.extensions import register_extension_dtype
  10. from pandas.arrays import (
  11. BooleanArray,
  12. DatetimeArray,
  13. FloatingArray,
  14. IntegerArray,
  15. IntervalArray,
  16. SparseArray,
  17. TimedeltaArray,
  18. )
  19. from pandas.core.arrays import (
  20. PandasArray,
  21. period_array,
  22. )
  23. from pandas.tests.extension.decimal import (
  24. DecimalArray,
  25. DecimalDtype,
  26. to_decimal,
  27. )
  28. @pytest.mark.parametrize(
  29. "data, dtype, expected",
  30. [
  31. # Basic NumPy defaults.
  32. ([1, 2], None, IntegerArray._from_sequence([1, 2])),
  33. ([1, 2], object, PandasArray(np.array([1, 2], dtype=object))),
  34. (
  35. [1, 2],
  36. np.dtype("float32"),
  37. PandasArray(np.array([1.0, 2.0], dtype=np.dtype("float32"))),
  38. ),
  39. (np.array([1, 2], dtype="int64"), None, IntegerArray._from_sequence([1, 2])),
  40. (
  41. np.array([1.0, 2.0], dtype="float64"),
  42. None,
  43. FloatingArray._from_sequence([1.0, 2.0]),
  44. ),
  45. # String alias passes through to NumPy
  46. ([1, 2], "float32", PandasArray(np.array([1, 2], dtype="float32"))),
  47. ([1, 2], "int64", PandasArray(np.array([1, 2], dtype=np.int64))),
  48. # GH#44715 FloatingArray does not support float16, so fall back to PandasArray
  49. (
  50. np.array([1, 2], dtype=np.float16),
  51. None,
  52. PandasArray(np.array([1, 2], dtype=np.float16)),
  53. ),
  54. # idempotency with e.g. pd.array(pd.array([1, 2], dtype="int64"))
  55. (
  56. PandasArray(np.array([1, 2], dtype=np.int32)),
  57. None,
  58. PandasArray(np.array([1, 2], dtype=np.int32)),
  59. ),
  60. # Period alias
  61. (
  62. [pd.Period("2000", "D"), pd.Period("2001", "D")],
  63. "Period[D]",
  64. period_array(["2000", "2001"], freq="D"),
  65. ),
  66. # Period dtype
  67. (
  68. [pd.Period("2000", "D")],
  69. pd.PeriodDtype("D"),
  70. period_array(["2000"], freq="D"),
  71. ),
  72. # Datetime (naive)
  73. (
  74. [1, 2],
  75. np.dtype("datetime64[ns]"),
  76. DatetimeArray._from_sequence(np.array([1, 2], dtype="datetime64[ns]")),
  77. ),
  78. (
  79. np.array([1, 2], dtype="datetime64[ns]"),
  80. None,
  81. DatetimeArray._from_sequence(np.array([1, 2], dtype="datetime64[ns]")),
  82. ),
  83. (
  84. pd.DatetimeIndex(["2000", "2001"]),
  85. np.dtype("datetime64[ns]"),
  86. DatetimeArray._from_sequence(["2000", "2001"]),
  87. ),
  88. (
  89. pd.DatetimeIndex(["2000", "2001"]),
  90. None,
  91. DatetimeArray._from_sequence(["2000", "2001"]),
  92. ),
  93. (
  94. ["2000", "2001"],
  95. np.dtype("datetime64[ns]"),
  96. DatetimeArray._from_sequence(["2000", "2001"]),
  97. ),
  98. # Datetime (tz-aware)
  99. (
  100. ["2000", "2001"],
  101. pd.DatetimeTZDtype(tz="CET"),
  102. DatetimeArray._from_sequence(
  103. ["2000", "2001"], dtype=pd.DatetimeTZDtype(tz="CET")
  104. ),
  105. ),
  106. # Timedelta
  107. (
  108. ["1H", "2H"],
  109. np.dtype("timedelta64[ns]"),
  110. TimedeltaArray._from_sequence(["1H", "2H"]),
  111. ),
  112. (
  113. pd.TimedeltaIndex(["1H", "2H"]),
  114. np.dtype("timedelta64[ns]"),
  115. TimedeltaArray._from_sequence(["1H", "2H"]),
  116. ),
  117. (
  118. pd.TimedeltaIndex(["1H", "2H"]),
  119. None,
  120. TimedeltaArray._from_sequence(["1H", "2H"]),
  121. ),
  122. (
  123. # preserve non-nano, i.e. don't cast to PandasArray
  124. TimedeltaArray._simple_new(
  125. np.arange(5, dtype=np.int64).view("m8[s]"), dtype=np.dtype("m8[s]")
  126. ),
  127. None,
  128. TimedeltaArray._simple_new(
  129. np.arange(5, dtype=np.int64).view("m8[s]"), dtype=np.dtype("m8[s]")
  130. ),
  131. ),
  132. (
  133. # preserve non-nano, i.e. don't cast to PandasArray
  134. TimedeltaArray._simple_new(
  135. np.arange(5, dtype=np.int64).view("m8[s]"), dtype=np.dtype("m8[s]")
  136. ),
  137. np.dtype("m8[s]"),
  138. TimedeltaArray._simple_new(
  139. np.arange(5, dtype=np.int64).view("m8[s]"), dtype=np.dtype("m8[s]")
  140. ),
  141. ),
  142. # Category
  143. (["a", "b"], "category", pd.Categorical(["a", "b"])),
  144. (
  145. ["a", "b"],
  146. pd.CategoricalDtype(None, ordered=True),
  147. pd.Categorical(["a", "b"], ordered=True),
  148. ),
  149. # Interval
  150. (
  151. [pd.Interval(1, 2), pd.Interval(3, 4)],
  152. "interval",
  153. IntervalArray.from_tuples([(1, 2), (3, 4)]),
  154. ),
  155. # Sparse
  156. ([0, 1], "Sparse[int64]", SparseArray([0, 1], dtype="int64")),
  157. # IntegerNA
  158. ([1, None], "Int16", pd.array([1, None], dtype="Int16")),
  159. (pd.Series([1, 2]), None, PandasArray(np.array([1, 2], dtype=np.int64))),
  160. # String
  161. (
  162. ["a", None],
  163. "string",
  164. pd.StringDtype().construct_array_type()._from_sequence(["a", None]),
  165. ),
  166. (
  167. ["a", None],
  168. pd.StringDtype(),
  169. pd.StringDtype().construct_array_type()._from_sequence(["a", None]),
  170. ),
  171. # Boolean
  172. ([True, None], "boolean", BooleanArray._from_sequence([True, None])),
  173. ([True, None], pd.BooleanDtype(), BooleanArray._from_sequence([True, None])),
  174. # Index
  175. (pd.Index([1, 2]), None, PandasArray(np.array([1, 2], dtype=np.int64))),
  176. # Series[EA] returns the EA
  177. (
  178. pd.Series(pd.Categorical(["a", "b"], categories=["a", "b", "c"])),
  179. None,
  180. pd.Categorical(["a", "b"], categories=["a", "b", "c"]),
  181. ),
  182. # "3rd party" EAs work
  183. ([decimal.Decimal(0), decimal.Decimal(1)], "decimal", to_decimal([0, 1])),
  184. # pass an ExtensionArray, but a different dtype
  185. (
  186. period_array(["2000", "2001"], freq="D"),
  187. "category",
  188. pd.Categorical([pd.Period("2000", "D"), pd.Period("2001", "D")]),
  189. ),
  190. ],
  191. )
  192. def test_array(data, dtype, expected):
  193. result = pd.array(data, dtype=dtype)
  194. tm.assert_equal(result, expected)
  195. def test_array_copy():
  196. a = np.array([1, 2])
  197. # default is to copy
  198. b = pd.array(a, dtype=a.dtype)
  199. assert not tm.shares_memory(a, b)
  200. # copy=True
  201. b = pd.array(a, dtype=a.dtype, copy=True)
  202. assert not tm.shares_memory(a, b)
  203. # copy=False
  204. b = pd.array(a, dtype=a.dtype, copy=False)
  205. assert tm.shares_memory(a, b)
  206. cet = pytz.timezone("CET")
  207. @pytest.mark.parametrize(
  208. "data, expected",
  209. [
  210. # period
  211. (
  212. [pd.Period("2000", "D"), pd.Period("2001", "D")],
  213. period_array(["2000", "2001"], freq="D"),
  214. ),
  215. # interval
  216. ([pd.Interval(0, 1), pd.Interval(1, 2)], IntervalArray.from_breaks([0, 1, 2])),
  217. # datetime
  218. (
  219. [pd.Timestamp("2000"), pd.Timestamp("2001")],
  220. DatetimeArray._from_sequence(["2000", "2001"]),
  221. ),
  222. (
  223. [datetime.datetime(2000, 1, 1), datetime.datetime(2001, 1, 1)],
  224. DatetimeArray._from_sequence(["2000", "2001"]),
  225. ),
  226. (
  227. np.array([1, 2], dtype="M8[ns]"),
  228. DatetimeArray(np.array([1, 2], dtype="M8[ns]")),
  229. ),
  230. (
  231. np.array([1, 2], dtype="M8[us]"),
  232. DatetimeArray._simple_new(
  233. np.array([1, 2], dtype="M8[us]"), dtype=np.dtype("M8[us]")
  234. ),
  235. ),
  236. # datetimetz
  237. (
  238. [pd.Timestamp("2000", tz="CET"), pd.Timestamp("2001", tz="CET")],
  239. DatetimeArray._from_sequence(
  240. ["2000", "2001"], dtype=pd.DatetimeTZDtype(tz="CET")
  241. ),
  242. ),
  243. (
  244. [
  245. datetime.datetime(2000, 1, 1, tzinfo=cet),
  246. datetime.datetime(2001, 1, 1, tzinfo=cet),
  247. ],
  248. DatetimeArray._from_sequence(
  249. ["2000", "2001"], dtype=pd.DatetimeTZDtype(tz=cet)
  250. ),
  251. ),
  252. # timedelta
  253. (
  254. [pd.Timedelta("1H"), pd.Timedelta("2H")],
  255. TimedeltaArray._from_sequence(["1H", "2H"]),
  256. ),
  257. (
  258. np.array([1, 2], dtype="m8[ns]"),
  259. TimedeltaArray(np.array([1, 2], dtype="m8[ns]")),
  260. ),
  261. (
  262. np.array([1, 2], dtype="m8[us]"),
  263. TimedeltaArray(np.array([1, 2], dtype="m8[us]")),
  264. ),
  265. # integer
  266. ([1, 2], IntegerArray._from_sequence([1, 2])),
  267. ([1, None], IntegerArray._from_sequence([1, None])),
  268. ([1, pd.NA], IntegerArray._from_sequence([1, pd.NA])),
  269. ([1, np.nan], IntegerArray._from_sequence([1, np.nan])),
  270. # float
  271. ([0.1, 0.2], FloatingArray._from_sequence([0.1, 0.2])),
  272. ([0.1, None], FloatingArray._from_sequence([0.1, pd.NA])),
  273. ([0.1, np.nan], FloatingArray._from_sequence([0.1, pd.NA])),
  274. ([0.1, pd.NA], FloatingArray._from_sequence([0.1, pd.NA])),
  275. # integer-like float
  276. ([1.0, 2.0], FloatingArray._from_sequence([1.0, 2.0])),
  277. ([1.0, None], FloatingArray._from_sequence([1.0, pd.NA])),
  278. ([1.0, np.nan], FloatingArray._from_sequence([1.0, pd.NA])),
  279. ([1.0, pd.NA], FloatingArray._from_sequence([1.0, pd.NA])),
  280. # mixed-integer-float
  281. ([1, 2.0], FloatingArray._from_sequence([1.0, 2.0])),
  282. ([1, np.nan, 2.0], FloatingArray._from_sequence([1.0, None, 2.0])),
  283. # string
  284. (
  285. ["a", "b"],
  286. pd.StringDtype().construct_array_type()._from_sequence(["a", "b"]),
  287. ),
  288. (
  289. ["a", None],
  290. pd.StringDtype().construct_array_type()._from_sequence(["a", None]),
  291. ),
  292. # Boolean
  293. ([True, False], BooleanArray._from_sequence([True, False])),
  294. ([True, None], BooleanArray._from_sequence([True, None])),
  295. ],
  296. )
  297. def test_array_inference(data, expected):
  298. result = pd.array(data)
  299. tm.assert_equal(result, expected)
  300. @pytest.mark.parametrize(
  301. "data",
  302. [
  303. # mix of frequencies
  304. [pd.Period("2000", "D"), pd.Period("2001", "A")],
  305. # mix of closed
  306. [pd.Interval(0, 1, closed="left"), pd.Interval(1, 2, closed="right")],
  307. # Mix of timezones
  308. [pd.Timestamp("2000", tz="CET"), pd.Timestamp("2000", tz="UTC")],
  309. # Mix of tz-aware and tz-naive
  310. [pd.Timestamp("2000", tz="CET"), pd.Timestamp("2000")],
  311. np.array([pd.Timestamp("2000"), pd.Timestamp("2000", tz="CET")]),
  312. ],
  313. )
  314. def test_array_inference_fails(data):
  315. result = pd.array(data)
  316. expected = PandasArray(np.array(data, dtype=object))
  317. tm.assert_extension_array_equal(result, expected)
  318. @pytest.mark.parametrize("data", [np.array(0)])
  319. def test_nd_raises(data):
  320. with pytest.raises(ValueError, match="PandasArray must be 1-dimensional"):
  321. pd.array(data, dtype="int64")
  322. def test_scalar_raises():
  323. with pytest.raises(ValueError, match="Cannot pass scalar '1'"):
  324. pd.array(1)
  325. def test_dataframe_raises():
  326. # GH#51167 don't accidentally cast to StringArray by doing inference on columns
  327. df = pd.DataFrame([[1, 2], [3, 4]], columns=["A", "B"])
  328. msg = "Cannot pass DataFrame to 'pandas.array'"
  329. with pytest.raises(TypeError, match=msg):
  330. pd.array(df)
  331. def test_bounds_check():
  332. # GH21796
  333. with pytest.raises(
  334. TypeError, match=r"cannot safely cast non-equivalent int(32|64) to uint16"
  335. ):
  336. pd.array([-1, 2, 3], dtype="UInt16")
  337. # ---------------------------------------------------------------------------
  338. # A couple dummy classes to ensure that Series and Indexes are unboxed before
  339. # getting to the EA classes.
  340. @register_extension_dtype
  341. class DecimalDtype2(DecimalDtype):
  342. name = "decimal2"
  343. @classmethod
  344. def construct_array_type(cls):
  345. """
  346. Return the array type associated with this dtype.
  347. Returns
  348. -------
  349. type
  350. """
  351. return DecimalArray2
  352. class DecimalArray2(DecimalArray):
  353. @classmethod
  354. def _from_sequence(cls, scalars, dtype=None, copy=False):
  355. if isinstance(scalars, (pd.Series, pd.Index)):
  356. raise TypeError("scalars should not be of type pd.Series or pd.Index")
  357. return super()._from_sequence(scalars, dtype=dtype, copy=copy)
  358. def test_array_unboxes(index_or_series):
  359. box = index_or_series
  360. data = box([decimal.Decimal("1"), decimal.Decimal("2")])
  361. # make sure it works
  362. with pytest.raises(
  363. TypeError, match="scalars should not be of type pd.Series or pd.Index"
  364. ):
  365. DecimalArray2._from_sequence(data)
  366. result = pd.array(data, dtype="decimal2")
  367. expected = DecimalArray2._from_sequence(data.values)
  368. tm.assert_equal(result, expected)
  369. @pytest.fixture
  370. def registry_without_decimal():
  371. """Fixture yielding 'registry' with no DecimalDtype entries"""
  372. idx = registry.dtypes.index(DecimalDtype)
  373. registry.dtypes.pop(idx)
  374. yield
  375. registry.dtypes.append(DecimalDtype)
  376. def test_array_not_registered(registry_without_decimal):
  377. # check we aren't on it
  378. assert registry.find("decimal") is None
  379. data = [decimal.Decimal("1"), decimal.Decimal("2")]
  380. result = pd.array(data, dtype=DecimalDtype)
  381. expected = DecimalArray._from_sequence(data)
  382. tm.assert_equal(result, expected)
  383. def test_array_to_numpy_na():
  384. # GH#40638
  385. arr = pd.array([pd.NA, 1], dtype="string")
  386. result = arr.to_numpy(na_value=True, dtype=bool)
  387. expected = np.array([True, True])
  388. tm.assert_numpy_array_equal(result, expected)