test_take.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306
  1. from datetime import datetime
  2. import re
  3. import numpy as np
  4. import pytest
  5. from pandas._libs import iNaT
  6. import pandas._testing as tm
  7. import pandas.core.algorithms as algos
  8. @pytest.fixture(
  9. params=[
  10. (np.int8, np.int16(127), np.int8),
  11. (np.int8, np.int16(128), np.int16),
  12. (np.int32, 1, np.int32),
  13. (np.int32, 2.0, np.float64),
  14. (np.int32, 3.0 + 4.0j, np.complex128),
  15. (np.int32, True, np.object_),
  16. (np.int32, "", np.object_),
  17. (np.float64, 1, np.float64),
  18. (np.float64, 2.0, np.float64),
  19. (np.float64, 3.0 + 4.0j, np.complex128),
  20. (np.float64, True, np.object_),
  21. (np.float64, "", np.object_),
  22. (np.complex128, 1, np.complex128),
  23. (np.complex128, 2.0, np.complex128),
  24. (np.complex128, 3.0 + 4.0j, np.complex128),
  25. (np.complex128, True, np.object_),
  26. (np.complex128, "", np.object_),
  27. (np.bool_, 1, np.object_),
  28. (np.bool_, 2.0, np.object_),
  29. (np.bool_, 3.0 + 4.0j, np.object_),
  30. (np.bool_, True, np.bool_),
  31. (np.bool_, "", np.object_),
  32. ]
  33. )
  34. def dtype_fill_out_dtype(request):
  35. return request.param
  36. class TestTake:
  37. # Standard incompatible fill error.
  38. fill_error = re.compile("Incompatible type for fill_value")
  39. def test_1d_fill_nonna(self, dtype_fill_out_dtype):
  40. dtype, fill_value, out_dtype = dtype_fill_out_dtype
  41. data = np.random.randint(0, 2, 4).astype(dtype)
  42. indexer = [2, 1, 0, -1]
  43. result = algos.take_nd(data, indexer, fill_value=fill_value)
  44. assert (result[[0, 1, 2]] == data[[2, 1, 0]]).all()
  45. assert result[3] == fill_value
  46. assert result.dtype == out_dtype
  47. indexer = [2, 1, 0, 1]
  48. result = algos.take_nd(data, indexer, fill_value=fill_value)
  49. assert (result[[0, 1, 2, 3]] == data[indexer]).all()
  50. assert result.dtype == dtype
  51. def test_2d_fill_nonna(self, dtype_fill_out_dtype):
  52. dtype, fill_value, out_dtype = dtype_fill_out_dtype
  53. data = np.random.randint(0, 2, (5, 3)).astype(dtype)
  54. indexer = [2, 1, 0, -1]
  55. result = algos.take_nd(data, indexer, axis=0, fill_value=fill_value)
  56. assert (result[[0, 1, 2], :] == data[[2, 1, 0], :]).all()
  57. assert (result[3, :] == fill_value).all()
  58. assert result.dtype == out_dtype
  59. result = algos.take_nd(data, indexer, axis=1, fill_value=fill_value)
  60. assert (result[:, [0, 1, 2]] == data[:, [2, 1, 0]]).all()
  61. assert (result[:, 3] == fill_value).all()
  62. assert result.dtype == out_dtype
  63. indexer = [2, 1, 0, 1]
  64. result = algos.take_nd(data, indexer, axis=0, fill_value=fill_value)
  65. assert (result[[0, 1, 2, 3], :] == data[indexer, :]).all()
  66. assert result.dtype == dtype
  67. result = algos.take_nd(data, indexer, axis=1, fill_value=fill_value)
  68. assert (result[:, [0, 1, 2, 3]] == data[:, indexer]).all()
  69. assert result.dtype == dtype
  70. def test_3d_fill_nonna(self, dtype_fill_out_dtype):
  71. dtype, fill_value, out_dtype = dtype_fill_out_dtype
  72. data = np.random.randint(0, 2, (5, 4, 3)).astype(dtype)
  73. indexer = [2, 1, 0, -1]
  74. result = algos.take_nd(data, indexer, axis=0, fill_value=fill_value)
  75. assert (result[[0, 1, 2], :, :] == data[[2, 1, 0], :, :]).all()
  76. assert (result[3, :, :] == fill_value).all()
  77. assert result.dtype == out_dtype
  78. result = algos.take_nd(data, indexer, axis=1, fill_value=fill_value)
  79. assert (result[:, [0, 1, 2], :] == data[:, [2, 1, 0], :]).all()
  80. assert (result[:, 3, :] == fill_value).all()
  81. assert result.dtype == out_dtype
  82. result = algos.take_nd(data, indexer, axis=2, fill_value=fill_value)
  83. assert (result[:, :, [0, 1, 2]] == data[:, :, [2, 1, 0]]).all()
  84. assert (result[:, :, 3] == fill_value).all()
  85. assert result.dtype == out_dtype
  86. indexer = [2, 1, 0, 1]
  87. result = algos.take_nd(data, indexer, axis=0, fill_value=fill_value)
  88. assert (result[[0, 1, 2, 3], :, :] == data[indexer, :, :]).all()
  89. assert result.dtype == dtype
  90. result = algos.take_nd(data, indexer, axis=1, fill_value=fill_value)
  91. assert (result[:, [0, 1, 2, 3], :] == data[:, indexer, :]).all()
  92. assert result.dtype == dtype
  93. result = algos.take_nd(data, indexer, axis=2, fill_value=fill_value)
  94. assert (result[:, :, [0, 1, 2, 3]] == data[:, :, indexer]).all()
  95. assert result.dtype == dtype
  96. def test_1d_other_dtypes(self):
  97. arr = np.random.randn(10).astype(np.float32)
  98. indexer = [1, 2, 3, -1]
  99. result = algos.take_nd(arr, indexer)
  100. expected = arr.take(indexer)
  101. expected[-1] = np.nan
  102. tm.assert_almost_equal(result, expected)
  103. def test_2d_other_dtypes(self):
  104. arr = np.random.randn(10, 5).astype(np.float32)
  105. indexer = [1, 2, 3, -1]
  106. # axis=0
  107. result = algos.take_nd(arr, indexer, axis=0)
  108. expected = arr.take(indexer, axis=0)
  109. expected[-1] = np.nan
  110. tm.assert_almost_equal(result, expected)
  111. # axis=1
  112. result = algos.take_nd(arr, indexer, axis=1)
  113. expected = arr.take(indexer, axis=1)
  114. expected[:, -1] = np.nan
  115. tm.assert_almost_equal(result, expected)
  116. def test_1d_bool(self):
  117. arr = np.array([0, 1, 0], dtype=bool)
  118. result = algos.take_nd(arr, [0, 2, 2, 1])
  119. expected = arr.take([0, 2, 2, 1])
  120. tm.assert_numpy_array_equal(result, expected)
  121. result = algos.take_nd(arr, [0, 2, -1])
  122. assert result.dtype == np.object_
  123. def test_2d_bool(self):
  124. arr = np.array([[0, 1, 0], [1, 0, 1], [0, 1, 1]], dtype=bool)
  125. result = algos.take_nd(arr, [0, 2, 2, 1])
  126. expected = arr.take([0, 2, 2, 1], axis=0)
  127. tm.assert_numpy_array_equal(result, expected)
  128. result = algos.take_nd(arr, [0, 2, 2, 1], axis=1)
  129. expected = arr.take([0, 2, 2, 1], axis=1)
  130. tm.assert_numpy_array_equal(result, expected)
  131. result = algos.take_nd(arr, [0, 2, -1])
  132. assert result.dtype == np.object_
  133. def test_2d_float32(self):
  134. arr = np.random.randn(4, 3).astype(np.float32)
  135. indexer = [0, 2, -1, 1, -1]
  136. # axis=0
  137. result = algos.take_nd(arr, indexer, axis=0)
  138. expected = arr.take(indexer, axis=0)
  139. expected[[2, 4], :] = np.nan
  140. tm.assert_almost_equal(result, expected)
  141. # axis=1
  142. result = algos.take_nd(arr, indexer, axis=1)
  143. expected = arr.take(indexer, axis=1)
  144. expected[:, [2, 4]] = np.nan
  145. tm.assert_almost_equal(result, expected)
  146. def test_2d_datetime64(self):
  147. # 2005/01/01 - 2006/01/01
  148. arr = np.random.randint(11_045_376, 11_360_736, (5, 3)) * 100_000_000_000
  149. arr = arr.view(dtype="datetime64[ns]")
  150. indexer = [0, 2, -1, 1, -1]
  151. # axis=0
  152. result = algos.take_nd(arr, indexer, axis=0)
  153. expected = arr.take(indexer, axis=0)
  154. expected.view(np.int64)[[2, 4], :] = iNaT
  155. tm.assert_almost_equal(result, expected)
  156. result = algos.take_nd(arr, indexer, axis=0, fill_value=datetime(2007, 1, 1))
  157. expected = arr.take(indexer, axis=0)
  158. expected[[2, 4], :] = datetime(2007, 1, 1)
  159. tm.assert_almost_equal(result, expected)
  160. # axis=1
  161. result = algos.take_nd(arr, indexer, axis=1)
  162. expected = arr.take(indexer, axis=1)
  163. expected.view(np.int64)[:, [2, 4]] = iNaT
  164. tm.assert_almost_equal(result, expected)
  165. result = algos.take_nd(arr, indexer, axis=1, fill_value=datetime(2007, 1, 1))
  166. expected = arr.take(indexer, axis=1)
  167. expected[:, [2, 4]] = datetime(2007, 1, 1)
  168. tm.assert_almost_equal(result, expected)
  169. def test_take_axis_0(self):
  170. arr = np.arange(12).reshape(4, 3)
  171. result = algos.take(arr, [0, -1])
  172. expected = np.array([[0, 1, 2], [9, 10, 11]])
  173. tm.assert_numpy_array_equal(result, expected)
  174. # allow_fill=True
  175. result = algos.take(arr, [0, -1], allow_fill=True, fill_value=0)
  176. expected = np.array([[0, 1, 2], [0, 0, 0]])
  177. tm.assert_numpy_array_equal(result, expected)
  178. def test_take_axis_1(self):
  179. arr = np.arange(12).reshape(4, 3)
  180. result = algos.take(arr, [0, -1], axis=1)
  181. expected = np.array([[0, 2], [3, 5], [6, 8], [9, 11]])
  182. tm.assert_numpy_array_equal(result, expected)
  183. # allow_fill=True
  184. result = algos.take(arr, [0, -1], axis=1, allow_fill=True, fill_value=0)
  185. expected = np.array([[0, 0], [3, 0], [6, 0], [9, 0]])
  186. tm.assert_numpy_array_equal(result, expected)
  187. # GH#26976 make sure we validate along the correct axis
  188. with pytest.raises(IndexError, match="indices are out-of-bounds"):
  189. algos.take(arr, [0, 3], axis=1, allow_fill=True, fill_value=0)
  190. def test_take_non_hashable_fill_value(self):
  191. arr = np.array([1, 2, 3])
  192. indexer = np.array([1, -1])
  193. with pytest.raises(ValueError, match="fill_value must be a scalar"):
  194. algos.take(arr, indexer, allow_fill=True, fill_value=[1])
  195. # with object dtype it is allowed
  196. arr = np.array([1, 2, 3], dtype=object)
  197. result = algos.take(arr, indexer, allow_fill=True, fill_value=[1])
  198. expected = np.array([2, [1]], dtype=object)
  199. tm.assert_numpy_array_equal(result, expected)
  200. class TestExtensionTake:
  201. # The take method found in pd.api.extensions
  202. def test_bounds_check_large(self):
  203. arr = np.array([1, 2])
  204. msg = "indices are out-of-bounds"
  205. with pytest.raises(IndexError, match=msg):
  206. algos.take(arr, [2, 3], allow_fill=True)
  207. msg = "index 2 is out of bounds for( axis 0 with)? size 2"
  208. with pytest.raises(IndexError, match=msg):
  209. algos.take(arr, [2, 3], allow_fill=False)
  210. def test_bounds_check_small(self):
  211. arr = np.array([1, 2, 3], dtype=np.int64)
  212. indexer = [0, -1, -2]
  213. msg = r"'indices' contains values less than allowed \(-2 < -1\)"
  214. with pytest.raises(ValueError, match=msg):
  215. algos.take(arr, indexer, allow_fill=True)
  216. result = algos.take(arr, indexer)
  217. expected = np.array([1, 3, 2], dtype=np.int64)
  218. tm.assert_numpy_array_equal(result, expected)
  219. @pytest.mark.parametrize("allow_fill", [True, False])
  220. def test_take_empty(self, allow_fill):
  221. arr = np.array([], dtype=np.int64)
  222. # empty take is ok
  223. result = algos.take(arr, [], allow_fill=allow_fill)
  224. tm.assert_numpy_array_equal(arr, result)
  225. msg = "|".join(
  226. [
  227. "cannot do a non-empty take from an empty axes.",
  228. "indices are out-of-bounds",
  229. ]
  230. )
  231. with pytest.raises(IndexError, match=msg):
  232. algos.take(arr, [0], allow_fill=allow_fill)
  233. def test_take_na_empty(self):
  234. result = algos.take(np.array([]), [-1, -1], allow_fill=True, fill_value=0.0)
  235. expected = np.array([0.0, 0.0])
  236. tm.assert_numpy_array_equal(result, expected)
  237. def test_take_coerces_list(self):
  238. arr = [1, 2, 3]
  239. result = algos.take(arr, [0, 0])
  240. expected = np.array([1, 1])
  241. tm.assert_numpy_array_equal(result, expected)