# test_where.py
  1. import numpy as np
  2. import pytest
  3. from pandas.core.dtypes.common import is_integer
  4. import pandas as pd
  5. from pandas import (
  6. Series,
  7. Timestamp,
  8. date_range,
  9. isna,
  10. )
  11. import pandas._testing as tm
  12. def test_where_unsafe_int(any_signed_int_numpy_dtype):
  13. s = Series(np.arange(10), dtype=any_signed_int_numpy_dtype)
  14. mask = s < 5
  15. s[mask] = range(2, 7)
  16. expected = Series(
  17. list(range(2, 7)) + list(range(5, 10)),
  18. dtype=any_signed_int_numpy_dtype,
  19. )
  20. tm.assert_series_equal(s, expected)
  21. def test_where_unsafe_float(float_numpy_dtype):
  22. s = Series(np.arange(10), dtype=float_numpy_dtype)
  23. mask = s < 5
  24. s[mask] = range(2, 7)
  25. data = list(range(2, 7)) + list(range(5, 10))
  26. expected = Series(data, dtype=float_numpy_dtype)
  27. tm.assert_series_equal(s, expected)
  28. @pytest.mark.parametrize(
  29. "dtype,expected_dtype",
  30. [
  31. (np.int8, np.float64),
  32. (np.int16, np.float64),
  33. (np.int32, np.float64),
  34. (np.int64, np.float64),
  35. (np.float32, np.float32),
  36. (np.float64, np.float64),
  37. ],
  38. )
  39. def test_where_unsafe_upcast(dtype, expected_dtype):
  40. # see gh-9743
  41. s = Series(np.arange(10), dtype=dtype)
  42. values = [2.5, 3.5, 4.5, 5.5, 6.5]
  43. mask = s < 5
  44. expected = Series(values + list(range(5, 10)), dtype=expected_dtype)
  45. s[mask] = values
  46. tm.assert_series_equal(s, expected)
  47. def test_where_unsafe():
  48. # see gh-9731
  49. s = Series(np.arange(10), dtype="int64")
  50. values = [2.5, 3.5, 4.5, 5.5]
  51. mask = s > 5
  52. expected = Series(list(range(6)) + values, dtype="float64")
  53. s[mask] = values
  54. tm.assert_series_equal(s, expected)
  55. # see gh-3235
  56. s = Series(np.arange(10), dtype="int64")
  57. mask = s < 5
  58. s[mask] = range(2, 7)
  59. expected = Series(list(range(2, 7)) + list(range(5, 10)), dtype="int64")
  60. tm.assert_series_equal(s, expected)
  61. assert s.dtype == expected.dtype
  62. s = Series(np.arange(10), dtype="int64")
  63. mask = s > 5
  64. s[mask] = [0] * 4
  65. expected = Series([0, 1, 2, 3, 4, 5] + [0] * 4, dtype="int64")
  66. tm.assert_series_equal(s, expected)
  67. s = Series(np.arange(10))
  68. mask = s > 5
  69. msg = "cannot set using a list-like indexer with a different length than the value"
  70. with pytest.raises(ValueError, match=msg):
  71. s[mask] = [5, 4, 3, 2, 1]
  72. with pytest.raises(ValueError, match=msg):
  73. s[mask] = [0] * 5
  74. # dtype changes
  75. s = Series([1, 2, 3, 4])
  76. result = s.where(s > 2, np.nan)
  77. expected = Series([np.nan, np.nan, 3, 4])
  78. tm.assert_series_equal(result, expected)
  79. # GH 4667
  80. # setting with None changes dtype
  81. s = Series(range(10)).astype(float)
  82. s[8] = None
  83. result = s[8]
  84. assert isna(result)
  85. s = Series(range(10)).astype(float)
  86. s[s > 8] = None
  87. result = s[isna(s)]
  88. expected = Series(np.nan, index=[9])
  89. tm.assert_series_equal(result, expected)
  90. def test_where():
  91. s = Series(np.random.randn(5))
  92. cond = s > 0
  93. rs = s.where(cond).dropna()
  94. rs2 = s[cond]
  95. tm.assert_series_equal(rs, rs2)
  96. rs = s.where(cond, -s)
  97. tm.assert_series_equal(rs, s.abs())
  98. rs = s.where(cond)
  99. assert s.shape == rs.shape
  100. assert rs is not s
  101. # test alignment
  102. cond = Series([True, False, False, True, False], index=s.index)
  103. s2 = -(s.abs())
  104. expected = s2[cond].reindex(s2.index[:3]).reindex(s2.index)
  105. rs = s2.where(cond[:3])
  106. tm.assert_series_equal(rs, expected)
  107. expected = s2.abs()
  108. expected.iloc[0] = s2[0]
  109. rs = s2.where(cond[:3], -s2)
  110. tm.assert_series_equal(rs, expected)
  111. def test_where_error():
  112. s = Series(np.random.randn(5))
  113. cond = s > 0
  114. msg = "Array conditional must be same shape as self"
  115. with pytest.raises(ValueError, match=msg):
  116. s.where(1)
  117. with pytest.raises(ValueError, match=msg):
  118. s.where(cond[:3].values, -s)
  119. # GH 2745
  120. s = Series([1, 2])
  121. s[[True, False]] = [0, 1]
  122. expected = Series([0, 2])
  123. tm.assert_series_equal(s, expected)
  124. # failures
  125. msg = "cannot set using a list-like indexer with a different length than the value"
  126. with pytest.raises(ValueError, match=msg):
  127. s[[True, False]] = [0, 2, 3]
  128. with pytest.raises(ValueError, match=msg):
  129. s[[True, False]] = []
  130. @pytest.mark.parametrize("klass", [list, tuple, np.array, Series])
  131. def test_where_array_like(klass):
  132. # see gh-15414
  133. s = Series([1, 2, 3])
  134. cond = [False, True, True]
  135. expected = Series([np.nan, 2, 3])
  136. result = s.where(klass(cond))
  137. tm.assert_series_equal(result, expected)
  138. @pytest.mark.parametrize(
  139. "cond",
  140. [
  141. [1, 0, 1],
  142. Series([2, 5, 7]),
  143. ["True", "False", "True"],
  144. [Timestamp("2017-01-01"), pd.NaT, Timestamp("2017-01-02")],
  145. ],
  146. )
  147. def test_where_invalid_input(cond):
  148. # see gh-15414: only boolean arrays accepted
  149. s = Series([1, 2, 3])
  150. msg = "Boolean array expected for the condition"
  151. with pytest.raises(ValueError, match=msg):
  152. s.where(cond)
  153. msg = "Array conditional must be same shape as self"
  154. with pytest.raises(ValueError, match=msg):
  155. s.where([True])
  156. def test_where_ndframe_align():
  157. msg = "Array conditional must be same shape as self"
  158. s = Series([1, 2, 3])
  159. cond = [True]
  160. with pytest.raises(ValueError, match=msg):
  161. s.where(cond)
  162. expected = Series([1, np.nan, np.nan])
  163. out = s.where(Series(cond))
  164. tm.assert_series_equal(out, expected)
  165. cond = np.array([False, True, False, True])
  166. with pytest.raises(ValueError, match=msg):
  167. s.where(cond)
  168. expected = Series([np.nan, 2, np.nan])
  169. out = s.where(Series(cond))
  170. tm.assert_series_equal(out, expected)
  171. def test_where_setitem_invalid():
  172. # GH 2702
  173. # make sure correct exceptions are raised on invalid list assignment
  174. msg = (
  175. lambda x: f"cannot set using a {x} indexer with a "
  176. "different length than the value"
  177. )
  178. # slice
  179. s = Series(list("abc"))
  180. with pytest.raises(ValueError, match=msg("slice")):
  181. s[0:3] = list(range(27))
  182. s[0:3] = list(range(3))
  183. expected = Series([0, 1, 2])
  184. tm.assert_series_equal(s.astype(np.int64), expected)
  185. # slice with step
  186. s = Series(list("abcdef"))
  187. with pytest.raises(ValueError, match=msg("slice")):
  188. s[0:4:2] = list(range(27))
  189. s = Series(list("abcdef"))
  190. s[0:4:2] = list(range(2))
  191. expected = Series([0, "b", 1, "d", "e", "f"])
  192. tm.assert_series_equal(s, expected)
  193. # neg slices
  194. s = Series(list("abcdef"))
  195. with pytest.raises(ValueError, match=msg("slice")):
  196. s[:-1] = list(range(27))
  197. s[-3:-1] = list(range(2))
  198. expected = Series(["a", "b", "c", 0, 1, "f"])
  199. tm.assert_series_equal(s, expected)
  200. # list
  201. s = Series(list("abc"))
  202. with pytest.raises(ValueError, match=msg("list-like")):
  203. s[[0, 1, 2]] = list(range(27))
  204. s = Series(list("abc"))
  205. with pytest.raises(ValueError, match=msg("list-like")):
  206. s[[0, 1, 2]] = list(range(2))
  207. # scalar
  208. s = Series(list("abc"))
  209. s[0] = list(range(10))
  210. expected = Series([list(range(10)), "b", "c"])
  211. tm.assert_series_equal(s, expected)
  212. @pytest.mark.parametrize("size", range(2, 6))
  213. @pytest.mark.parametrize(
  214. "mask", [[True, False, False, False, False], [True, False], [False]]
  215. )
  216. @pytest.mark.parametrize(
  217. "item", [2.0, np.nan, np.finfo(float).max, np.finfo(float).min]
  218. )
  219. # Test numpy arrays, lists and tuples as the input to be
  220. # broadcast
  221. @pytest.mark.parametrize(
  222. "box", [lambda x: np.array([x]), lambda x: [x], lambda x: (x,)]
  223. )
  224. def test_broadcast(size, mask, item, box):
  225. # GH#8801, GH#4195
  226. selection = np.resize(mask, size)
  227. data = np.arange(size, dtype=float)
  228. # Construct the expected series by taking the source
  229. # data or item based on the selection
  230. expected = Series(
  231. [item if use_item else data[i] for i, use_item in enumerate(selection)]
  232. )
  233. s = Series(data)
  234. s[selection] = item
  235. tm.assert_series_equal(s, expected)
  236. s = Series(data)
  237. result = s.where(~selection, box(item))
  238. tm.assert_series_equal(result, expected)
  239. s = Series(data)
  240. result = s.mask(selection, box(item))
  241. tm.assert_series_equal(result, expected)
  242. def test_where_inplace():
  243. s = Series(np.random.randn(5))
  244. cond = s > 0
  245. rs = s.copy()
  246. rs.where(cond, inplace=True)
  247. tm.assert_series_equal(rs.dropna(), s[cond])
  248. tm.assert_series_equal(rs, s.where(cond))
  249. rs = s.copy()
  250. rs.where(cond, -s, inplace=True)
  251. tm.assert_series_equal(rs, s.where(cond, -s))
  252. def test_where_dups():
  253. # GH 4550
  254. # where crashes with dups in index
  255. s1 = Series(list(range(3)))
  256. s2 = Series(list(range(3)))
  257. comb = pd.concat([s1, s2])
  258. result = comb.where(comb < 2)
  259. expected = Series([0, 1, np.nan, 0, 1, np.nan], index=[0, 1, 2, 0, 1, 2])
  260. tm.assert_series_equal(result, expected)
  261. # GH 4548
  262. # inplace updating not working with dups
  263. comb[comb < 1] = 5
  264. expected = Series([5, 1, 2, 5, 1, 2], index=[0, 1, 2, 0, 1, 2])
  265. tm.assert_series_equal(comb, expected)
  266. comb[comb < 2] += 10
  267. expected = Series([5, 11, 2, 5, 11, 2], index=[0, 1, 2, 0, 1, 2])
  268. tm.assert_series_equal(comb, expected)
  269. def test_where_numeric_with_string():
  270. # GH 9280
  271. s = Series([1, 2, 3])
  272. w = s.where(s > 1, "X")
  273. assert not is_integer(w[0])
  274. assert is_integer(w[1])
  275. assert is_integer(w[2])
  276. assert isinstance(w[0], str)
  277. assert w.dtype == "object"
  278. w = s.where(s > 1, ["X", "Y", "Z"])
  279. assert not is_integer(w[0])
  280. assert is_integer(w[1])
  281. assert is_integer(w[2])
  282. assert isinstance(w[0], str)
  283. assert w.dtype == "object"
  284. w = s.where(s > 1, np.array(["X", "Y", "Z"]))
  285. assert not is_integer(w[0])
  286. assert is_integer(w[1])
  287. assert is_integer(w[2])
  288. assert isinstance(w[0], str)
  289. assert w.dtype == "object"
  290. @pytest.mark.parametrize("dtype", ["timedelta64[ns]", "datetime64[ns]"])
  291. def test_where_datetimelike_coerce(dtype):
  292. ser = Series([1, 2], dtype=dtype)
  293. expected = Series([10, 10])
  294. mask = np.array([False, False])
  295. rs = ser.where(mask, [10, 10])
  296. tm.assert_series_equal(rs, expected)
  297. rs = ser.where(mask, 10)
  298. tm.assert_series_equal(rs, expected)
  299. rs = ser.where(mask, 10.0)
  300. tm.assert_series_equal(rs, expected)
  301. rs = ser.where(mask, [10.0, 10.0])
  302. tm.assert_series_equal(rs, expected)
  303. rs = ser.where(mask, [10.0, np.nan])
  304. expected = Series([10, None], dtype="object")
  305. tm.assert_series_equal(rs, expected)
  306. def test_where_datetimetz():
  307. # GH 15701
  308. timestamps = ["2016-12-31 12:00:04+00:00", "2016-12-31 12:00:04.010000+00:00"]
  309. ser = Series([Timestamp(t) for t in timestamps], dtype="datetime64[ns, UTC]")
  310. rs = ser.where(Series([False, True]))
  311. expected = Series([pd.NaT, ser[1]], dtype="datetime64[ns, UTC]")
  312. tm.assert_series_equal(rs, expected)
  313. def test_where_sparse():
  314. # GH#17198 make sure we dont get an AttributeError for sp_index
  315. ser = Series(pd.arrays.SparseArray([1, 2]))
  316. result = ser.where(ser >= 2, 0)
  317. expected = Series(pd.arrays.SparseArray([0, 2]))
  318. tm.assert_series_equal(result, expected)
  319. def test_where_empty_series_and_empty_cond_having_non_bool_dtypes():
  320. # https://github.com/pandas-dev/pandas/issues/34592
  321. ser = Series([], dtype=float)
  322. result = ser.where([])
  323. tm.assert_series_equal(result, ser)
  324. def test_where_categorical(frame_or_series):
  325. # https://github.com/pandas-dev/pandas/issues/18888
  326. exp = frame_or_series(
  327. pd.Categorical(["A", "A", "B", "B", np.nan], categories=["A", "B", "C"]),
  328. dtype="category",
  329. )
  330. df = frame_or_series(["A", "A", "B", "B", "C"], dtype="category")
  331. res = df.where(df != "C")
  332. tm.assert_equal(exp, res)
  333. def test_where_datetimelike_categorical(tz_naive_fixture):
  334. # GH#37682
  335. tz = tz_naive_fixture
  336. dr = date_range("2001-01-01", periods=3, tz=tz)._with_freq(None)
  337. lvals = pd.DatetimeIndex([dr[0], dr[1], pd.NaT])
  338. rvals = pd.Categorical([dr[0], pd.NaT, dr[2]])
  339. mask = np.array([True, True, False])
  340. # DatetimeIndex.where
  341. res = lvals.where(mask, rvals)
  342. tm.assert_index_equal(res, dr)
  343. # DatetimeArray.where
  344. res = lvals._data._where(mask, rvals)
  345. tm.assert_datetime_array_equal(res, dr._data)
  346. # Series.where
  347. res = Series(lvals).where(mask, rvals)
  348. tm.assert_series_equal(res, Series(dr))
  349. # DataFrame.where
  350. res = pd.DataFrame(lvals).where(mask[:, None], pd.DataFrame(rvals))
  351. tm.assert_frame_equal(res, pd.DataFrame(dr))