test_arrow_compat.py 3.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124
  1. import pytest
  2. from pandas.compat.pyarrow import pa_version_under10p0
  3. from pandas.core.dtypes.dtypes import PeriodDtype
  4. import pandas as pd
  5. import pandas._testing as tm
  6. from pandas.core.arrays import (
  7. PeriodArray,
  8. period_array,
  9. )
  10. pa = pytest.importorskip("pyarrow", minversion="1.0.1")
  11. def test_arrow_extension_type():
  12. from pandas.core.arrays.arrow.extension_types import ArrowPeriodType
  13. p1 = ArrowPeriodType("D")
  14. p2 = ArrowPeriodType("D")
  15. p3 = ArrowPeriodType("M")
  16. assert p1.freq == "D"
  17. assert p1 == p2
  18. assert p1 != p3
  19. assert hash(p1) == hash(p2)
  20. assert hash(p1) != hash(p3)
  21. @pytest.mark.xfail(not pa_version_under10p0, reason="Wrong behavior with pyarrow 10")
  22. @pytest.mark.parametrize(
  23. "data, freq",
  24. [
  25. (pd.date_range("2017", periods=3), "D"),
  26. (pd.date_range("2017", periods=3, freq="A"), "A-DEC"),
  27. ],
  28. )
  29. def test_arrow_array(data, freq):
  30. from pandas.core.arrays.arrow.extension_types import ArrowPeriodType
  31. periods = period_array(data, freq=freq)
  32. result = pa.array(periods)
  33. assert isinstance(result.type, ArrowPeriodType)
  34. assert result.type.freq == freq
  35. expected = pa.array(periods.asi8, type="int64")
  36. assert result.storage.equals(expected)
  37. # convert to its storage type
  38. result = pa.array(periods, type=pa.int64())
  39. assert result.equals(expected)
  40. # unsupported conversions
  41. msg = "Not supported to convert PeriodArray to 'double' type"
  42. with pytest.raises(TypeError, match=msg):
  43. pa.array(periods, type="float64")
  44. with pytest.raises(TypeError, match="different 'freq'"):
  45. pa.array(periods, type=ArrowPeriodType("T"))
  46. def test_arrow_array_missing():
  47. from pandas.core.arrays.arrow.extension_types import ArrowPeriodType
  48. arr = PeriodArray([1, 2, 3], freq="D")
  49. arr[1] = pd.NaT
  50. result = pa.array(arr)
  51. assert isinstance(result.type, ArrowPeriodType)
  52. assert result.type.freq == "D"
  53. expected = pa.array([1, None, 3], type="int64")
  54. assert result.storage.equals(expected)
  55. def test_arrow_table_roundtrip():
  56. from pandas.core.arrays.arrow.extension_types import ArrowPeriodType
  57. arr = PeriodArray([1, 2, 3], freq="D")
  58. arr[1] = pd.NaT
  59. df = pd.DataFrame({"a": arr})
  60. table = pa.table(df)
  61. assert isinstance(table.field("a").type, ArrowPeriodType)
  62. result = table.to_pandas()
  63. assert isinstance(result["a"].dtype, PeriodDtype)
  64. tm.assert_frame_equal(result, df)
  65. table2 = pa.concat_tables([table, table])
  66. result = table2.to_pandas()
  67. expected = pd.concat([df, df], ignore_index=True)
  68. tm.assert_frame_equal(result, expected)
  69. def test_arrow_load_from_zero_chunks():
  70. # GH-41040
  71. from pandas.core.arrays.arrow.extension_types import ArrowPeriodType
  72. arr = PeriodArray([], freq="D")
  73. df = pd.DataFrame({"a": arr})
  74. table = pa.table(df)
  75. assert isinstance(table.field("a").type, ArrowPeriodType)
  76. table = pa.table(
  77. [pa.chunked_array([], type=table.column(0).type)], schema=table.schema
  78. )
  79. result = table.to_pandas()
  80. assert isinstance(result["a"].dtype, PeriodDtype)
  81. tm.assert_frame_equal(result, df)
  82. def test_arrow_table_roundtrip_without_metadata():
  83. arr = PeriodArray([1, 2, 3], freq="H")
  84. arr[1] = pd.NaT
  85. df = pd.DataFrame({"a": arr})
  86. table = pa.table(df)
  87. # remove the metadata
  88. table = table.replace_schema_metadata()
  89. assert table.schema.metadata is None
  90. result = table.to_pandas()
  91. assert isinstance(result["a"].dtype, PeriodDtype)
  92. tm.assert_frame_equal(result, df)