test_common.py 2.0 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980
  1. import numpy as np
  2. import pytest
  3. from pandas.core.dtypes import dtypes
  4. from pandas.core.dtypes.common import is_extension_array_dtype
  5. import pandas as pd
  6. import pandas._testing as tm
  7. from pandas.core.arrays import ExtensionArray
  class DummyDtype(dtypes.ExtensionDtype):
      """Empty ``ExtensionDtype`` subclass used as a fixture by the tests in this file.

      It adds nothing on top of the base class; it only provides a distinct
      dtype type to instantiate and to check against with ``isinstance``.
      """
  class DummyArray(ExtensionArray):
      """Minimal ``ExtensionArray`` subclass that simply wraps the ``data`` it is given.

      Only the pieces exercised by the tests in this file are implemented:
      construction, conversion to a NumPy array, ``dtype`` and ``astype``.
      """

      def __init__(self, data) -> None:
          # Stored as-is; no validation or copying is performed.
          self.data = data

      def __array__(self, dtype=None, copy=None):
          """Return the wrapped data for NumPy's array-conversion protocol.

          Parameters
          ----------
          dtype : optional
              Requested dtype. Only applied here when an explicit copy is
              requested; otherwise the wrapped data is returned unchanged and
              any cast is left to the caller (NumPy).
          copy : bool, optional
              When truthy, a fresh copy of the data is returned. ``None`` or
              ``False`` return the wrapped object itself.

          Both parameters are optional so that ``np.asarray(arr)`` (which
          invokes ``__array__`` without arguments) and NumPy 2 (which may pass
          a ``copy`` keyword) both work.
          """
          if copy:
              return np.array(self.data, dtype=dtype, copy=True)
          return self.data

      @property
      def dtype(self):
          """A new ``DummyDtype`` instance on every access."""
          return DummyDtype()

      def astype(self, dtype, copy=True):
          """Convert to ``dtype``.

          For a ``DummyDtype`` target the result is a ``DummyArray``: a new
          wrapper around the same ``data`` when ``copy`` is truthy, otherwise
          ``self``. Any other target dtype yields a NumPy array.
          """
          # we don't support anything but a single dtype
          if isinstance(dtype, DummyDtype):
              return type(self)(self.data) if copy else self

          if not copy:
              # ``np.array(..., copy=False)`` means "never copy" under NumPy 2
              # and raises when a cast is required; ``np.asarray`` copies only
              # if needed, which is the intended meaning of ``copy=False`` here.
              return np.asarray(self, dtype=dtype)
          return np.array(self, dtype=dtype, copy=True)
  class TestExtensionArrayDtype:
      """Checks of ``is_extension_array_dtype`` on extension and NumPy-backed inputs."""

      @pytest.mark.parametrize(
          "values",
          [
              pd.Categorical([]),
              pd.Categorical([]).dtype,
              pd.Series(pd.Categorical([])),
              DummyDtype(),
              DummyArray(np.array([1, 2])),
          ],
      )
      def test_is_extension_array_dtype(self, values):
          """Categorical objects and the dummy dtype/array must be recognised."""
          recognised = is_extension_array_dtype(values)
          assert recognised

      @pytest.mark.parametrize(
          "values",
          [
              np.array([]),
              pd.Series(np.array([])),
          ],
      )
      def test_is_not_extension_array_dtype(self, values):
          """An empty ndarray and a Series built from one must be rejected."""
          recognised = is_extension_array_dtype(values)
          assert not recognised
  def test_astype():
      """``astype`` to object yields the same ndarray for both dtype spellings."""
      arr = DummyArray(np.array([1, 2, 3]))
      expected = np.array([1, 2, 3], dtype=object)

      # First the type itself, then its string alias, in the same order as before.
      for target in (object, "object"):
          converted = arr.astype(target)
          tm.assert_numpy_array_equal(converted, expected)
  def test_astype_no_copy():
      """``astype`` to the array's own dtype returns ``self`` only when ``copy=False``."""
      arr = DummyArray(np.array([1, 2, 3], dtype=np.int64))

      uncopied = arr.astype(arr.dtype, copy=False)
      assert uncopied is arr

      # Default ``copy=True`` must hand back a different object.
      copied = arr.astype(arr.dtype)
      assert copied is not arr
  @pytest.mark.parametrize(
      "dtype",
      [
          dtypes.CategoricalDtype(),
          dtypes.IntervalDtype(),
      ],
  )
  def test_is_extension_array_dtype(dtype):
      """Built-in pandas dtypes are ``ExtensionDtype`` instances and are detected as such."""
      is_ext_instance = isinstance(dtype, dtypes.ExtensionDtype)
      assert is_ext_instance

      detected = is_extension_array_dtype(dtype)
      assert detected