test_string_arrow.py 7.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216
  1. import pickle
  2. import re
  3. import numpy as np
  4. import pytest
  5. from pandas.compat import pa_version_under7p0
  6. import pandas as pd
  7. import pandas._testing as tm
  8. from pandas.core.arrays.string_ import (
  9. StringArray,
  10. StringDtype,
  11. )
  12. from pandas.core.arrays.string_arrow import ArrowStringArray
# Reusable skip marker for the tests below that build pyarrow-backed string
# arrays: the decorated test is skipped whenever ``pa_version_under7p0`` is true.
skip_if_no_pyarrow = pytest.mark.skipif(
    pa_version_under7p0,
    reason="pyarrow>=7.0.0 is required for PyArrow backed StringArray",
)
  17. @skip_if_no_pyarrow
  18. def test_eq_all_na():
  19. a = pd.array([pd.NA, pd.NA], dtype=StringDtype("pyarrow"))
  20. result = a == a
  21. expected = pd.array([pd.NA, pd.NA], dtype="boolean[pyarrow]")
  22. tm.assert_extension_array_equal(result, expected)
  23. def test_config(string_storage):
  24. with pd.option_context("string_storage", string_storage):
  25. assert StringDtype().storage == string_storage
  26. result = pd.array(["a", "b"])
  27. assert result.dtype.storage == string_storage
  28. expected = (
  29. StringDtype(string_storage).construct_array_type()._from_sequence(["a", "b"])
  30. )
  31. tm.assert_equal(result, expected)
  32. def test_config_bad_storage_raises():
  33. msg = re.escape("Value must be one of python|pyarrow")
  34. with pytest.raises(ValueError, match=msg):
  35. pd.options.mode.string_storage = "foo"
  36. @skip_if_no_pyarrow
  37. @pytest.mark.parametrize("chunked", [True, False])
  38. @pytest.mark.parametrize("array", ["numpy", "pyarrow"])
  39. def test_constructor_not_string_type_raises(array, chunked):
  40. import pyarrow as pa
  41. array = pa if array == "pyarrow" else np
  42. arr = array.array([1, 2, 3])
  43. if chunked:
  44. if array is np:
  45. pytest.skip("chunked not applicable to numpy array")
  46. arr = pa.chunked_array(arr)
  47. if array is np:
  48. msg = "Unsupported type '<class 'numpy.ndarray'>' for ArrowExtensionArray"
  49. else:
  50. msg = re.escape(
  51. "ArrowStringArray requires a PyArrow (chunked) array of string type"
  52. )
  53. with pytest.raises(ValueError, match=msg):
  54. ArrowStringArray(arr)
  55. @skip_if_no_pyarrow
  56. def test_from_sequence_wrong_dtype_raises():
  57. with pd.option_context("string_storage", "python"):
  58. ArrowStringArray._from_sequence(["a", None, "c"], dtype="string")
  59. with pd.option_context("string_storage", "pyarrow"):
  60. ArrowStringArray._from_sequence(["a", None, "c"], dtype="string")
  61. with pytest.raises(AssertionError, match=None):
  62. ArrowStringArray._from_sequence(["a", None, "c"], dtype="string[python]")
  63. ArrowStringArray._from_sequence(["a", None, "c"], dtype="string[pyarrow]")
  64. with pytest.raises(AssertionError, match=None):
  65. with pd.option_context("string_storage", "python"):
  66. ArrowStringArray._from_sequence(["a", None, "c"], dtype=StringDtype())
  67. with pd.option_context("string_storage", "pyarrow"):
  68. ArrowStringArray._from_sequence(["a", None, "c"], dtype=StringDtype())
  69. with pytest.raises(AssertionError, match=None):
  70. ArrowStringArray._from_sequence(["a", None, "c"], dtype=StringDtype("python"))
  71. ArrowStringArray._from_sequence(["a", None, "c"], dtype=StringDtype("pyarrow"))
  72. with pd.option_context("string_storage", "python"):
  73. StringArray._from_sequence(["a", None, "c"], dtype="string")
  74. with pd.option_context("string_storage", "pyarrow"):
  75. StringArray._from_sequence(["a", None, "c"], dtype="string")
  76. StringArray._from_sequence(["a", None, "c"], dtype="string[python]")
  77. with pytest.raises(AssertionError, match=None):
  78. StringArray._from_sequence(["a", None, "c"], dtype="string[pyarrow]")
  79. with pd.option_context("string_storage", "python"):
  80. StringArray._from_sequence(["a", None, "c"], dtype=StringDtype())
  81. with pytest.raises(AssertionError, match=None):
  82. with pd.option_context("string_storage", "pyarrow"):
  83. StringArray._from_sequence(["a", None, "c"], dtype=StringDtype())
  84. StringArray._from_sequence(["a", None, "c"], dtype=StringDtype("python"))
  85. with pytest.raises(AssertionError, match=None):
  86. StringArray._from_sequence(["a", None, "c"], dtype=StringDtype("pyarrow"))
  87. @pytest.mark.skipif(
  88. not pa_version_under7p0,
  89. reason="pyarrow is installed",
  90. )
  91. def test_pyarrow_not_installed_raises():
  92. msg = re.escape("pyarrow>=7.0.0 is required for PyArrow backed")
  93. with pytest.raises(ImportError, match=msg):
  94. StringDtype(storage="pyarrow")
  95. with pytest.raises(ImportError, match=msg):
  96. ArrowStringArray([])
  97. with pytest.raises(ImportError, match=msg):
  98. ArrowStringArray._from_sequence(["a", None, "b"])
  99. @skip_if_no_pyarrow
  100. @pytest.mark.parametrize("multiple_chunks", [False, True])
  101. @pytest.mark.parametrize(
  102. "key, value, expected",
  103. [
  104. (-1, "XX", ["a", "b", "c", "d", "XX"]),
  105. (1, "XX", ["a", "XX", "c", "d", "e"]),
  106. (1, None, ["a", None, "c", "d", "e"]),
  107. (1, pd.NA, ["a", None, "c", "d", "e"]),
  108. ([1, 3], "XX", ["a", "XX", "c", "XX", "e"]),
  109. ([1, 3], ["XX", "YY"], ["a", "XX", "c", "YY", "e"]),
  110. ([1, 3], ["XX", None], ["a", "XX", "c", None, "e"]),
  111. ([1, 3], ["XX", pd.NA], ["a", "XX", "c", None, "e"]),
  112. ([0, -1], ["XX", "YY"], ["XX", "b", "c", "d", "YY"]),
  113. ([-1, 0], ["XX", "YY"], ["YY", "b", "c", "d", "XX"]),
  114. (slice(3, None), "XX", ["a", "b", "c", "XX", "XX"]),
  115. (slice(2, 4), ["XX", "YY"], ["a", "b", "XX", "YY", "e"]),
  116. (slice(3, 1, -1), ["XX", "YY"], ["a", "b", "YY", "XX", "e"]),
  117. (slice(None), "XX", ["XX", "XX", "XX", "XX", "XX"]),
  118. ([False, True, False, True, False], ["XX", "YY"], ["a", "XX", "c", "YY", "e"]),
  119. ],
  120. )
  121. def test_setitem(multiple_chunks, key, value, expected):
  122. import pyarrow as pa
  123. result = pa.array(list("abcde"))
  124. expected = pa.array(expected)
  125. if multiple_chunks:
  126. result = pa.chunked_array([result[:3], result[3:]])
  127. expected = pa.chunked_array([expected[:3], expected[3:]])
  128. result = ArrowStringArray(result)
  129. expected = ArrowStringArray(expected)
  130. result[key] = value
  131. tm.assert_equal(result, expected)
  132. @skip_if_no_pyarrow
  133. def test_setitem_invalid_indexer_raises():
  134. import pyarrow as pa
  135. arr = ArrowStringArray(pa.array(list("abcde")))
  136. with pytest.raises(IndexError, match=None):
  137. arr[5] = "foo"
  138. with pytest.raises(IndexError, match=None):
  139. arr[-6] = "foo"
  140. with pytest.raises(IndexError, match=None):
  141. arr[[0, 5]] = "foo"
  142. with pytest.raises(IndexError, match=None):
  143. arr[[0, -6]] = "foo"
  144. with pytest.raises(IndexError, match=None):
  145. arr[[True, True, False]] = "foo"
  146. with pytest.raises(ValueError, match=None):
  147. arr[[0, 1]] = ["foo", "bar", "baz"]
  148. @skip_if_no_pyarrow
  149. def test_pickle_roundtrip():
  150. # GH 42600
  151. expected = pd.Series(range(10), dtype="string[pyarrow]")
  152. expected_sliced = expected.head(2)
  153. full_pickled = pickle.dumps(expected)
  154. sliced_pickled = pickle.dumps(expected_sliced)
  155. assert len(full_pickled) > len(sliced_pickled)
  156. result = pickle.loads(full_pickled)
  157. tm.assert_series_equal(result, expected)
  158. result_sliced = pickle.loads(sliced_pickled)
  159. tm.assert_series_equal(result_sliced, expected_sliced)