test_to_xarray.py 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128
  1. import numpy as np
  2. import pytest
  3. import pandas.util._test_decorators as td
  4. from pandas import (
  5. Categorical,
  6. DataFrame,
  7. MultiIndex,
  8. Series,
  9. date_range,
  10. )
  11. import pandas._testing as tm
  12. @td.skip_if_no("xarray")
  13. class TestDataFrameToXArray:
  14. @pytest.fixture
  15. def df(self):
  16. return DataFrame(
  17. {
  18. "a": list("abc"),
  19. "b": list(range(1, 4)),
  20. "c": np.arange(3, 6).astype("u1"),
  21. "d": np.arange(4.0, 7.0, dtype="float64"),
  22. "e": [True, False, True],
  23. "f": Categorical(list("abc")),
  24. "g": date_range("20130101", periods=3),
  25. "h": date_range("20130101", periods=3, tz="US/Eastern"),
  26. }
  27. )
  28. def test_to_xarray_index_types(self, index_flat, df):
  29. index = index_flat
  30. # MultiIndex is tested in test_to_xarray_with_multiindex
  31. if len(index) == 0:
  32. pytest.skip("Test doesn't make sense for empty index")
  33. from xarray import Dataset
  34. df.index = index[:3]
  35. df.index.name = "foo"
  36. df.columns.name = "bar"
  37. result = df.to_xarray()
  38. assert result.dims["foo"] == 3
  39. assert len(result.coords) == 1
  40. assert len(result.data_vars) == 8
  41. tm.assert_almost_equal(list(result.coords.keys()), ["foo"])
  42. assert isinstance(result, Dataset)
  43. # idempotency
  44. # datetimes w/tz are preserved
  45. # column names are lost
  46. expected = df.copy()
  47. expected["f"] = expected["f"].astype(object)
  48. expected.columns.name = None
  49. tm.assert_frame_equal(result.to_dataframe(), expected)
  50. def test_to_xarray_empty(self, df):
  51. from xarray import Dataset
  52. df.index.name = "foo"
  53. result = df[0:0].to_xarray()
  54. assert result.dims["foo"] == 0
  55. assert isinstance(result, Dataset)
  56. def test_to_xarray_with_multiindex(self, df):
  57. from xarray import Dataset
  58. # MultiIndex
  59. df.index = MultiIndex.from_product([["a"], range(3)], names=["one", "two"])
  60. result = df.to_xarray()
  61. assert result.dims["one"] == 1
  62. assert result.dims["two"] == 3
  63. assert len(result.coords) == 2
  64. assert len(result.data_vars) == 8
  65. tm.assert_almost_equal(list(result.coords.keys()), ["one", "two"])
  66. assert isinstance(result, Dataset)
  67. result = result.to_dataframe()
  68. expected = df.copy()
  69. expected["f"] = expected["f"].astype(object)
  70. expected.columns.name = None
  71. tm.assert_frame_equal(result, expected)
  72. @td.skip_if_no("xarray")
  73. class TestSeriesToXArray:
  74. def test_to_xarray_index_types(self, index_flat):
  75. index = index_flat
  76. # MultiIndex is tested in test_to_xarray_with_multiindex
  77. from xarray import DataArray
  78. ser = Series(range(len(index)), index=index, dtype="int64")
  79. ser.index.name = "foo"
  80. result = ser.to_xarray()
  81. repr(result)
  82. assert len(result) == len(index)
  83. assert len(result.coords) == 1
  84. tm.assert_almost_equal(list(result.coords.keys()), ["foo"])
  85. assert isinstance(result, DataArray)
  86. # idempotency
  87. tm.assert_series_equal(result.to_series(), ser)
  88. def test_to_xarray_empty(self):
  89. from xarray import DataArray
  90. ser = Series([], dtype=object)
  91. ser.index.name = "foo"
  92. result = ser.to_xarray()
  93. assert len(result) == 0
  94. assert len(result.coords) == 1
  95. tm.assert_almost_equal(list(result.coords.keys()), ["foo"])
  96. assert isinstance(result, DataArray)
  97. def test_to_xarray_with_multiindex(self):
  98. from xarray import DataArray
  99. mi = MultiIndex.from_product([["a", "b"], range(3)], names=["one", "two"])
  100. ser = Series(range(6), dtype="int64", index=mi)
  101. result = ser.to_xarray()
  102. assert len(result) == 2
  103. tm.assert_almost_equal(list(result.coords.keys()), ["one", "two"])
  104. assert isinstance(result, DataArray)
  105. res = result.to_series()
  106. tm.assert_series_equal(res, ser)