test_set_axis.py 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144
  1. import numpy as np
  2. import pytest
  3. from pandas import (
  4. DataFrame,
  5. Series,
  6. )
  7. import pandas._testing as tm
  8. class SharedSetAxisTests:
  9. @pytest.fixture
  10. def obj(self):
  11. raise NotImplementedError("Implemented by subclasses")
  12. def test_set_axis(self, obj):
  13. # GH14636; this tests setting index for both Series and DataFrame
  14. new_index = list("abcd")[: len(obj)]
  15. expected = obj.copy()
  16. expected.index = new_index
  17. result = obj.set_axis(new_index, axis=0)
  18. tm.assert_equal(expected, result)
  19. def test_set_axis_copy(self, obj, using_copy_on_write):
  20. # Test copy keyword GH#47932
  21. new_index = list("abcd")[: len(obj)]
  22. orig = obj.iloc[:]
  23. expected = obj.copy()
  24. expected.index = new_index
  25. result = obj.set_axis(new_index, axis=0, copy=True)
  26. tm.assert_equal(expected, result)
  27. assert result is not obj
  28. # check we DID make a copy
  29. if not using_copy_on_write:
  30. if obj.ndim == 1:
  31. assert not tm.shares_memory(result, obj)
  32. else:
  33. assert not any(
  34. tm.shares_memory(result.iloc[:, i], obj.iloc[:, i])
  35. for i in range(obj.shape[1])
  36. )
  37. result = obj.set_axis(new_index, axis=0, copy=False)
  38. tm.assert_equal(expected, result)
  39. assert result is not obj
  40. # check we did NOT make a copy
  41. if obj.ndim == 1:
  42. assert tm.shares_memory(result, obj)
  43. else:
  44. assert all(
  45. tm.shares_memory(result.iloc[:, i], obj.iloc[:, i])
  46. for i in range(obj.shape[1])
  47. )
  48. # copy defaults to True
  49. result = obj.set_axis(new_index, axis=0)
  50. tm.assert_equal(expected, result)
  51. assert result is not obj
  52. if using_copy_on_write:
  53. # check we DID NOT make a copy
  54. if obj.ndim == 1:
  55. assert tm.shares_memory(result, obj)
  56. else:
  57. assert any(
  58. tm.shares_memory(result.iloc[:, i], obj.iloc[:, i])
  59. for i in range(obj.shape[1])
  60. )
  61. else:
  62. # check we DID make a copy
  63. if obj.ndim == 1:
  64. assert not tm.shares_memory(result, obj)
  65. else:
  66. assert not any(
  67. tm.shares_memory(result.iloc[:, i], obj.iloc[:, i])
  68. for i in range(obj.shape[1])
  69. )
  70. res = obj.set_axis(new_index, copy=False)
  71. tm.assert_equal(expected, res)
  72. # check we did NOT make a copy
  73. if res.ndim == 1:
  74. assert tm.shares_memory(res, orig)
  75. else:
  76. assert all(
  77. tm.shares_memory(res.iloc[:, i], orig.iloc[:, i])
  78. for i in range(res.shape[1])
  79. )
  80. def test_set_axis_unnamed_kwarg_warns(self, obj):
  81. # omitting the "axis" parameter
  82. new_index = list("abcd")[: len(obj)]
  83. expected = obj.copy()
  84. expected.index = new_index
  85. result = obj.set_axis(new_index)
  86. tm.assert_equal(result, expected)
  87. @pytest.mark.parametrize("axis", [3, "foo"])
  88. def test_set_axis_invalid_axis_name(self, axis, obj):
  89. # wrong values for the "axis" parameter
  90. with pytest.raises(ValueError, match="No axis named"):
  91. obj.set_axis(list("abc"), axis=axis)
  92. def test_set_axis_setattr_index_not_collection(self, obj):
  93. # wrong type
  94. msg = (
  95. r"Index\(\.\.\.\) must be called with a collection of some "
  96. r"kind, None was passed"
  97. )
  98. with pytest.raises(TypeError, match=msg):
  99. obj.index = None
  100. def test_set_axis_setattr_index_wrong_length(self, obj):
  101. # wrong length
  102. msg = (
  103. f"Length mismatch: Expected axis has {len(obj)} elements, "
  104. f"new values have {len(obj)-1} elements"
  105. )
  106. with pytest.raises(ValueError, match=msg):
  107. obj.index = np.arange(len(obj) - 1)
  108. if obj.ndim == 2:
  109. with pytest.raises(ValueError, match="Length mismatch"):
  110. obj.columns = obj.columns[::2]
  111. class TestDataFrameSetAxis(SharedSetAxisTests):
  112. @pytest.fixture
  113. def obj(self):
  114. df = DataFrame(
  115. {"A": [1.1, 2.2, 3.3], "B": [5.0, 6.1, 7.2], "C": [4.4, 5.5, 6.6]},
  116. index=[2010, 2011, 2012],
  117. )
  118. return df
  119. class TestSeriesSetAxis(SharedSetAxisTests):
  120. @pytest.fixture
  121. def obj(self):
  122. ser = Series(np.arange(4), index=[1, 3, 5, 7], dtype="int64")
  123. return ser