# test_feather.py (9.0 KB)
  1. """ test feather-format compat """
  2. import numpy as np
  3. import pytest
  4. import pandas as pd
  5. import pandas._testing as tm
  6. from pandas.core.arrays import (
  7. ArrowStringArray,
  8. StringArray,
  9. )
  10. from pandas.io.feather_format import read_feather, to_feather # isort:skip
  11. pyarrow = pytest.importorskip("pyarrow", minversion="1.0.1")
  12. @pytest.mark.single_cpu
  13. class TestFeather:
  14. def check_error_on_write(self, df, exc, err_msg):
  15. # check that we are raising the exception
  16. # on writing
  17. with pytest.raises(exc, match=err_msg):
  18. with tm.ensure_clean() as path:
  19. to_feather(df, path)
  20. def check_external_error_on_write(self, df):
  21. # check that we are raising the exception
  22. # on writing
  23. with tm.external_error_raised(Exception):
  24. with tm.ensure_clean() as path:
  25. to_feather(df, path)
  26. def check_round_trip(self, df, expected=None, write_kwargs={}, **read_kwargs):
  27. if expected is None:
  28. expected = df
  29. with tm.ensure_clean() as path:
  30. to_feather(df, path, **write_kwargs)
  31. result = read_feather(path, **read_kwargs)
  32. tm.assert_frame_equal(result, expected)
  33. def test_error(self):
  34. msg = "feather only support IO with DataFrames"
  35. for obj in [
  36. pd.Series([1, 2, 3]),
  37. 1,
  38. "foo",
  39. pd.Timestamp("20130101"),
  40. np.array([1, 2, 3]),
  41. ]:
  42. self.check_error_on_write(obj, ValueError, msg)
  43. def test_basic(self):
  44. df = pd.DataFrame(
  45. {
  46. "string": list("abc"),
  47. "int": list(range(1, 4)),
  48. "uint": np.arange(3, 6).astype("u1"),
  49. "float": np.arange(4.0, 7.0, dtype="float64"),
  50. "float_with_null": [1.0, np.nan, 3],
  51. "bool": [True, False, True],
  52. "bool_with_null": [True, np.nan, False],
  53. "cat": pd.Categorical(list("abc")),
  54. "dt": pd.DatetimeIndex(
  55. list(pd.date_range("20130101", periods=3)), freq=None
  56. ),
  57. "dttz": pd.DatetimeIndex(
  58. list(pd.date_range("20130101", periods=3, tz="US/Eastern")),
  59. freq=None,
  60. ),
  61. "dt_with_null": [
  62. pd.Timestamp("20130101"),
  63. pd.NaT,
  64. pd.Timestamp("20130103"),
  65. ],
  66. "dtns": pd.DatetimeIndex(
  67. list(pd.date_range("20130101", periods=3, freq="ns")), freq=None
  68. ),
  69. }
  70. )
  71. df["periods"] = pd.period_range("2013", freq="M", periods=3)
  72. df["timedeltas"] = pd.timedelta_range("1 day", periods=3)
  73. df["intervals"] = pd.interval_range(0, 3, 3)
  74. assert df.dttz.dtype.tz.zone == "US/Eastern"
  75. self.check_round_trip(df)
  76. def test_duplicate_columns(self):
  77. # https://github.com/wesm/feather/issues/53
  78. # not currently able to handle duplicate columns
  79. df = pd.DataFrame(np.arange(12).reshape(4, 3), columns=list("aaa")).copy()
  80. self.check_external_error_on_write(df)
  81. def test_stringify_columns(self):
  82. df = pd.DataFrame(np.arange(12).reshape(4, 3)).copy()
  83. msg = "feather must have string column names"
  84. self.check_error_on_write(df, ValueError, msg)
  85. def test_read_columns(self):
  86. # GH 24025
  87. df = pd.DataFrame(
  88. {
  89. "col1": list("abc"),
  90. "col2": list(range(1, 4)),
  91. "col3": list("xyz"),
  92. "col4": list(range(4, 7)),
  93. }
  94. )
  95. columns = ["col1", "col3"]
  96. self.check_round_trip(df, expected=df[columns], columns=columns)
  97. def test_read_columns_different_order(self):
  98. # GH 33878
  99. df = pd.DataFrame({"A": [1, 2], "B": ["x", "y"], "C": [True, False]})
  100. expected = df[["B", "A"]]
  101. self.check_round_trip(df, expected, columns=["B", "A"])
  102. def test_unsupported_other(self):
  103. # mixed python objects
  104. df = pd.DataFrame({"a": ["a", 1, 2.0]})
  105. self.check_external_error_on_write(df)
  106. def test_rw_use_threads(self):
  107. df = pd.DataFrame({"A": np.arange(100000)})
  108. self.check_round_trip(df, use_threads=True)
  109. self.check_round_trip(df, use_threads=False)
  110. def test_write_with_index(self):
  111. df = pd.DataFrame({"A": [1, 2, 3]})
  112. self.check_round_trip(df)
  113. msg = (
  114. r"feather does not support serializing .* for the index; "
  115. r"you can \.reset_index\(\) to make the index into column\(s\)"
  116. )
  117. # non-default index
  118. for index in [
  119. [2, 3, 4],
  120. pd.date_range("20130101", periods=3),
  121. list("abc"),
  122. [1, 3, 4],
  123. pd.MultiIndex.from_tuples([("a", 1), ("a", 2), ("b", 1)]),
  124. ]:
  125. df.index = index
  126. self.check_error_on_write(df, ValueError, msg)
  127. # index with meta-data
  128. df.index = [0, 1, 2]
  129. df.index.name = "foo"
  130. msg = "feather does not serialize index meta-data on a default index"
  131. self.check_error_on_write(df, ValueError, msg)
  132. # column multi-index
  133. df.index = [0, 1, 2]
  134. df.columns = pd.MultiIndex.from_tuples([("a", 1)])
  135. msg = "feather must have string column names"
  136. self.check_error_on_write(df, ValueError, msg)
  137. def test_path_pathlib(self):
  138. df = tm.makeDataFrame().reset_index()
  139. result = tm.round_trip_pathlib(df.to_feather, read_feather)
  140. tm.assert_frame_equal(df, result)
  141. def test_path_localpath(self):
  142. df = tm.makeDataFrame().reset_index()
  143. result = tm.round_trip_localpath(df.to_feather, read_feather)
  144. tm.assert_frame_equal(df, result)
  145. def test_passthrough_keywords(self):
  146. df = tm.makeDataFrame().reset_index()
  147. self.check_round_trip(df, write_kwargs={"version": 1})
  148. @pytest.mark.network
  149. @tm.network(
  150. url=(
  151. "https://raw.githubusercontent.com/pandas-dev/pandas/main/"
  152. "pandas/tests/io/data/feather/feather-0_3_1.feather"
  153. ),
  154. check_before_test=True,
  155. )
  156. def test_http_path(self, feather_file):
  157. # GH 29055
  158. url = (
  159. "https://raw.githubusercontent.com/pandas-dev/pandas/main/"
  160. "pandas/tests/io/data/feather/feather-0_3_1.feather"
  161. )
  162. expected = read_feather(feather_file)
  163. res = read_feather(url)
  164. tm.assert_frame_equal(expected, res)
  165. def test_read_feather_dtype_backend(self, string_storage, dtype_backend):
  166. # GH#50765
  167. pa = pytest.importorskip("pyarrow")
  168. df = pd.DataFrame(
  169. {
  170. "a": pd.Series([1, np.nan, 3], dtype="Int64"),
  171. "b": pd.Series([1, 2, 3], dtype="Int64"),
  172. "c": pd.Series([1.5, np.nan, 2.5], dtype="Float64"),
  173. "d": pd.Series([1.5, 2.0, 2.5], dtype="Float64"),
  174. "e": [True, False, None],
  175. "f": [True, False, True],
  176. "g": ["a", "b", "c"],
  177. "h": ["a", "b", None],
  178. }
  179. )
  180. if string_storage == "python":
  181. string_array = StringArray(np.array(["a", "b", "c"], dtype=np.object_))
  182. string_array_na = StringArray(np.array(["a", "b", pd.NA], dtype=np.object_))
  183. else:
  184. string_array = ArrowStringArray(pa.array(["a", "b", "c"]))
  185. string_array_na = ArrowStringArray(pa.array(["a", "b", None]))
  186. with tm.ensure_clean() as path:
  187. to_feather(df, path)
  188. with pd.option_context("mode.string_storage", string_storage):
  189. result = read_feather(path, dtype_backend=dtype_backend)
  190. expected = pd.DataFrame(
  191. {
  192. "a": pd.Series([1, np.nan, 3], dtype="Int64"),
  193. "b": pd.Series([1, 2, 3], dtype="Int64"),
  194. "c": pd.Series([1.5, np.nan, 2.5], dtype="Float64"),
  195. "d": pd.Series([1.5, 2.0, 2.5], dtype="Float64"),
  196. "e": pd.Series([True, False, pd.NA], dtype="boolean"),
  197. "f": pd.Series([True, False, True], dtype="boolean"),
  198. "g": string_array,
  199. "h": string_array_na,
  200. }
  201. )
  202. if dtype_backend == "pyarrow":
  203. from pandas.arrays import ArrowExtensionArray
  204. expected = pd.DataFrame(
  205. {
  206. col: ArrowExtensionArray(pa.array(expected[col], from_pandas=True))
  207. for col in expected.columns
  208. }
  209. )
  210. tm.assert_frame_equal(result, expected)
  211. def test_invalid_dtype_backend(self):
  212. msg = (
  213. "dtype_backend numpy is invalid, only 'numpy_nullable' and "
  214. "'pyarrow' are allowed."
  215. )
  216. df = pd.DataFrame({"int": list(range(1, 4))})
  217. with tm.ensure_clean("tmp.feather") as path:
  218. df.to_feather(path)
  219. with pytest.raises(ValueError, match=msg):
  220. read_feather(path, dtype_backend="numpy")