# test_where.py: tests for DataFrame.where / DataFrame.mask
from datetime import datetime

from hypothesis import given
import numpy as np
import pytest

from pandas.core.dtypes.common import is_scalar

import pandas as pd
from pandas import (
    DataFrame,
    DatetimeIndex,
    Index,
    Series,
    StringDtype,
    Timestamp,
    date_range,
    isna,
)
import pandas._testing as tm
from pandas._testing._hypothesis import OPTIONAL_ONE_OF_ALL
  19. @pytest.fixture(params=["default", "float_string", "mixed_float", "mixed_int"])
  20. def where_frame(request, float_string_frame, mixed_float_frame, mixed_int_frame):
  21. if request.param == "default":
  22. return DataFrame(np.random.randn(5, 3), columns=["A", "B", "C"])
  23. if request.param == "float_string":
  24. return float_string_frame
  25. if request.param == "mixed_float":
  26. return mixed_float_frame
  27. if request.param == "mixed_int":
  28. return mixed_int_frame
  29. def _safe_add(df):
  30. # only add to the numeric items
  31. def is_ok(s):
  32. return (
  33. issubclass(s.dtype.type, (np.integer, np.floating)) and s.dtype != "uint8"
  34. )
  35. return DataFrame(dict((c, s + 1) if is_ok(s) else (c, s) for c, s in df.items()))
  36. class TestDataFrameIndexingWhere:
  37. def test_where_get(self, where_frame, float_string_frame):
  38. def _check_get(df, cond, check_dtypes=True):
  39. other1 = _safe_add(df)
  40. rs = df.where(cond, other1)
  41. rs2 = df.where(cond.values, other1)
  42. for k, v in rs.items():
  43. exp = Series(np.where(cond[k], df[k], other1[k]), index=v.index)
  44. tm.assert_series_equal(v, exp, check_names=False)
  45. tm.assert_frame_equal(rs, rs2)
  46. # dtypes
  47. if check_dtypes:
  48. assert (rs.dtypes == df.dtypes).all()
  49. # check getting
  50. df = where_frame
  51. if df is float_string_frame:
  52. msg = "'>' not supported between instances of 'str' and 'int'"
  53. with pytest.raises(TypeError, match=msg):
  54. df > 0
  55. return
  56. cond = df > 0
  57. _check_get(df, cond)
  58. def test_where_upcasting(self):
  59. # upcasting case (GH # 2794)
  60. df = DataFrame(
  61. {
  62. c: Series([1] * 3, dtype=c)
  63. for c in ["float32", "float64", "int32", "int64"]
  64. }
  65. )
  66. df.iloc[1, :] = 0
  67. result = df.dtypes
  68. expected = Series(
  69. [
  70. np.dtype("float32"),
  71. np.dtype("float64"),
  72. np.dtype("int32"),
  73. np.dtype("int64"),
  74. ],
  75. index=["float32", "float64", "int32", "int64"],
  76. )
  77. # when we don't preserve boolean casts
  78. #
  79. # expected = Series({ 'float32' : 1, 'float64' : 3 })
  80. tm.assert_series_equal(result, expected)
  81. def test_where_alignment(self, where_frame, float_string_frame):
  82. # aligning
  83. def _check_align(df, cond, other, check_dtypes=True):
  84. rs = df.where(cond, other)
  85. for i, k in enumerate(rs.columns):
  86. result = rs[k]
  87. d = df[k].values
  88. c = cond[k].reindex(df[k].index).fillna(False).values
  89. if is_scalar(other):
  90. o = other
  91. else:
  92. if isinstance(other, np.ndarray):
  93. o = Series(other[:, i], index=result.index).values
  94. else:
  95. o = other[k].values
  96. new_values = d if c.all() else np.where(c, d, o)
  97. expected = Series(new_values, index=result.index, name=k)
  98. # since we can't always have the correct numpy dtype
  99. # as numpy doesn't know how to downcast, don't check
  100. tm.assert_series_equal(result, expected, check_dtype=False)
  101. # dtypes
  102. # can't check dtype when other is an ndarray
  103. if check_dtypes and not isinstance(other, np.ndarray):
  104. assert (rs.dtypes == df.dtypes).all()
  105. df = where_frame
  106. if df is float_string_frame:
  107. msg = "'>' not supported between instances of 'str' and 'int'"
  108. with pytest.raises(TypeError, match=msg):
  109. df > 0
  110. return
  111. # other is a frame
  112. cond = (df > 0)[1:]
  113. _check_align(df, cond, _safe_add(df))
  114. # check other is ndarray
  115. cond = df > 0
  116. _check_align(df, cond, (_safe_add(df).values))
  117. # integers are upcast, so don't check the dtypes
  118. cond = df > 0
  119. check_dtypes = all(not issubclass(s.type, np.integer) for s in df.dtypes)
  120. _check_align(df, cond, np.nan, check_dtypes=check_dtypes)
  121. def test_where_invalid(self):
  122. # invalid conditions
  123. df = DataFrame(np.random.randn(5, 3), columns=["A", "B", "C"])
  124. cond = df > 0
  125. err1 = (df + 1).values[0:2, :]
  126. msg = "other must be the same shape as self when an ndarray"
  127. with pytest.raises(ValueError, match=msg):
  128. df.where(cond, err1)
  129. err2 = cond.iloc[:2, :].values
  130. other1 = _safe_add(df)
  131. msg = "Array conditional must be same shape as self"
  132. with pytest.raises(ValueError, match=msg):
  133. df.where(err2, other1)
  134. with pytest.raises(ValueError, match=msg):
  135. df.mask(True)
  136. with pytest.raises(ValueError, match=msg):
  137. df.mask(0)
  138. def test_where_set(self, where_frame, float_string_frame):
  139. # where inplace
  140. def _check_set(df, cond, check_dtypes=True):
  141. dfi = df.copy()
  142. econd = cond.reindex_like(df).fillna(True)
  143. expected = dfi.mask(~econd)
  144. return_value = dfi.where(cond, np.nan, inplace=True)
  145. assert return_value is None
  146. tm.assert_frame_equal(dfi, expected)
  147. # dtypes (and confirm upcasts)x
  148. if check_dtypes:
  149. for k, v in df.dtypes.items():
  150. if issubclass(v.type, np.integer) and not cond[k].all():
  151. v = np.dtype("float64")
  152. assert dfi[k].dtype == v
  153. df = where_frame
  154. if df is float_string_frame:
  155. msg = "'>' not supported between instances of 'str' and 'int'"
  156. with pytest.raises(TypeError, match=msg):
  157. df > 0
  158. return
  159. cond = df > 0
  160. _check_set(df, cond)
  161. cond = df >= 0
  162. _check_set(df, cond)
  163. # aligning
  164. cond = (df >= 0)[1:]
  165. _check_set(df, cond)
  166. def test_where_series_slicing(self):
  167. # GH 10218
  168. # test DataFrame.where with Series slicing
  169. df = DataFrame({"a": range(3), "b": range(4, 7)})
  170. result = df.where(df["a"] == 1)
  171. expected = df[df["a"] == 1].reindex(df.index)
  172. tm.assert_frame_equal(result, expected)
  173. @pytest.mark.parametrize("klass", [list, tuple, np.array])
  174. def test_where_array_like(self, klass):
  175. # see gh-15414
  176. df = DataFrame({"a": [1, 2, 3]})
  177. cond = [[False], [True], [True]]
  178. expected = DataFrame({"a": [np.nan, 2, 3]})
  179. result = df.where(klass(cond))
  180. tm.assert_frame_equal(result, expected)
  181. df["b"] = 2
  182. expected["b"] = [2, np.nan, 2]
  183. cond = [[False, True], [True, False], [True, True]]
  184. result = df.where(klass(cond))
  185. tm.assert_frame_equal(result, expected)
  186. @pytest.mark.parametrize(
  187. "cond",
  188. [
  189. [[1], [0], [1]],
  190. Series([[2], [5], [7]]),
  191. DataFrame({"a": [2, 5, 7]}),
  192. [["True"], ["False"], ["True"]],
  193. [[Timestamp("2017-01-01")], [pd.NaT], [Timestamp("2017-01-02")]],
  194. ],
  195. )
  196. def test_where_invalid_input_single(self, cond):
  197. # see gh-15414: only boolean arrays accepted
  198. df = DataFrame({"a": [1, 2, 3]})
  199. msg = "Boolean array expected for the condition"
  200. with pytest.raises(ValueError, match=msg):
  201. df.where(cond)
  202. @pytest.mark.parametrize(
  203. "cond",
  204. [
  205. [[0, 1], [1, 0], [1, 1]],
  206. Series([[0, 2], [5, 0], [4, 7]]),
  207. [["False", "True"], ["True", "False"], ["True", "True"]],
  208. DataFrame({"a": [2, 5, 7], "b": [4, 8, 9]}),
  209. [
  210. [pd.NaT, Timestamp("2017-01-01")],
  211. [Timestamp("2017-01-02"), pd.NaT],
  212. [Timestamp("2017-01-03"), Timestamp("2017-01-03")],
  213. ],
  214. ],
  215. )
  216. def test_where_invalid_input_multiple(self, cond):
  217. # see gh-15414: only boolean arrays accepted
  218. df = DataFrame({"a": [1, 2, 3], "b": [2, 2, 2]})
  219. msg = "Boolean array expected for the condition"
  220. with pytest.raises(ValueError, match=msg):
  221. df.where(cond)
  222. def test_where_dataframe_col_match(self):
  223. df = DataFrame([[1, 2, 3], [4, 5, 6]])
  224. cond = DataFrame([[True, False, True], [False, False, True]])
  225. result = df.where(cond)
  226. expected = DataFrame([[1.0, np.nan, 3], [np.nan, np.nan, 6]])
  227. tm.assert_frame_equal(result, expected)
  228. # this *does* align, though has no matching columns
  229. cond.columns = ["a", "b", "c"]
  230. result = df.where(cond)
  231. expected = DataFrame(np.nan, index=df.index, columns=df.columns)
  232. tm.assert_frame_equal(result, expected)
  233. def test_where_ndframe_align(self):
  234. msg = "Array conditional must be same shape as self"
  235. df = DataFrame([[1, 2, 3], [4, 5, 6]])
  236. cond = [True]
  237. with pytest.raises(ValueError, match=msg):
  238. df.where(cond)
  239. expected = DataFrame([[1, 2, 3], [np.nan, np.nan, np.nan]])
  240. out = df.where(Series(cond))
  241. tm.assert_frame_equal(out, expected)
  242. cond = np.array([False, True, False, True])
  243. with pytest.raises(ValueError, match=msg):
  244. df.where(cond)
  245. expected = DataFrame([[np.nan, np.nan, np.nan], [4, 5, 6]])
  246. out = df.where(Series(cond))
  247. tm.assert_frame_equal(out, expected)
  248. def test_where_bug(self):
  249. # see gh-2793
  250. df = DataFrame(
  251. {"a": [1.0, 2.0, 3.0, 4.0], "b": [4.0, 3.0, 2.0, 1.0]}, dtype="float64"
  252. )
  253. expected = DataFrame(
  254. {"a": [np.nan, np.nan, 3.0, 4.0], "b": [4.0, 3.0, np.nan, np.nan]},
  255. dtype="float64",
  256. )
  257. result = df.where(df > 2, np.nan)
  258. tm.assert_frame_equal(result, expected)
  259. result = df.copy()
  260. return_value = result.where(result > 2, np.nan, inplace=True)
  261. assert return_value is None
  262. tm.assert_frame_equal(result, expected)
  263. def test_where_bug_mixed(self, any_signed_int_numpy_dtype):
  264. # see gh-2793
  265. df = DataFrame(
  266. {
  267. "a": np.array([1, 2, 3, 4], dtype=any_signed_int_numpy_dtype),
  268. "b": np.array([4.0, 3.0, 2.0, 1.0], dtype="float64"),
  269. }
  270. )
  271. expected = DataFrame(
  272. {"a": [np.nan, np.nan, 3.0, 4.0], "b": [4.0, 3.0, np.nan, np.nan]},
  273. dtype="float64",
  274. )
  275. result = df.where(df > 2, np.nan)
  276. tm.assert_frame_equal(result, expected)
  277. result = df.copy()
  278. return_value = result.where(result > 2, np.nan, inplace=True)
  279. assert return_value is None
  280. tm.assert_frame_equal(result, expected)
  281. def test_where_bug_transposition(self):
  282. # see gh-7506
  283. a = DataFrame({0: [1, 2], 1: [3, 4], 2: [5, 6]})
  284. b = DataFrame({0: [np.nan, 8], 1: [9, np.nan], 2: [np.nan, np.nan]})
  285. do_not_replace = b.isna() | (a > b)
  286. expected = a.copy()
  287. expected[~do_not_replace] = b
  288. result = a.where(do_not_replace, b)
  289. tm.assert_frame_equal(result, expected)
  290. a = DataFrame({0: [4, 6], 1: [1, 0]})
  291. b = DataFrame({0: [np.nan, 3], 1: [3, np.nan]})
  292. do_not_replace = b.isna() | (a > b)
  293. expected = a.copy()
  294. expected[~do_not_replace] = b
  295. result = a.where(do_not_replace, b)
  296. tm.assert_frame_equal(result, expected)
  297. def test_where_datetime(self):
  298. # GH 3311
  299. df = DataFrame(
  300. {
  301. "A": date_range("20130102", periods=5),
  302. "B": date_range("20130104", periods=5),
  303. "C": np.random.randn(5),
  304. }
  305. )
  306. stamp = datetime(2013, 1, 3)
  307. msg = "'>' not supported between instances of 'float' and 'datetime.datetime'"
  308. with pytest.raises(TypeError, match=msg):
  309. df > stamp
  310. result = df[df.iloc[:, :-1] > stamp]
  311. expected = df.copy()
  312. expected.loc[[0, 1], "A"] = np.nan
  313. expected.loc[:, "C"] = np.nan
  314. tm.assert_frame_equal(result, expected)
  315. def test_where_none(self):
  316. # GH 4667
  317. # setting with None changes dtype
  318. df = DataFrame({"series": Series(range(10))}).astype(float)
  319. df[df > 7] = None
  320. expected = DataFrame(
  321. {"series": Series([0, 1, 2, 3, 4, 5, 6, 7, np.nan, np.nan])}
  322. )
  323. tm.assert_frame_equal(df, expected)
  324. # GH 7656
  325. df = DataFrame(
  326. [
  327. {"A": 1, "B": np.nan, "C": "Test"},
  328. {"A": np.nan, "B": "Test", "C": np.nan},
  329. ]
  330. )
  331. msg = "boolean setting on mixed-type"
  332. with pytest.raises(TypeError, match=msg):
  333. df.where(~isna(df), None, inplace=True)
  334. def test_where_empty_df_and_empty_cond_having_non_bool_dtypes(self):
  335. # see gh-21947
  336. df = DataFrame(columns=["a"])
  337. cond = df
  338. assert (cond.dtypes == object).all()
  339. result = df.where(cond)
  340. tm.assert_frame_equal(result, df)
  341. def test_where_align(self):
  342. def create():
  343. df = DataFrame(np.random.randn(10, 3))
  344. df.iloc[3:5, 0] = np.nan
  345. df.iloc[4:6, 1] = np.nan
  346. df.iloc[5:8, 2] = np.nan
  347. return df
  348. # series
  349. df = create()
  350. expected = df.fillna(df.mean())
  351. result = df.where(pd.notna(df), df.mean(), axis="columns")
  352. tm.assert_frame_equal(result, expected)
  353. return_value = df.where(pd.notna(df), df.mean(), inplace=True, axis="columns")
  354. assert return_value is None
  355. tm.assert_frame_equal(df, expected)
  356. df = create().fillna(0)
  357. expected = df.apply(lambda x, y: x.where(x > 0, y), y=df[0])
  358. result = df.where(df > 0, df[0], axis="index")
  359. tm.assert_frame_equal(result, expected)
  360. result = df.where(df > 0, df[0], axis="rows")
  361. tm.assert_frame_equal(result, expected)
  362. # frame
  363. df = create()
  364. expected = df.fillna(1)
  365. result = df.where(
  366. pd.notna(df), DataFrame(1, index=df.index, columns=df.columns)
  367. )
  368. tm.assert_frame_equal(result, expected)
  369. def test_where_complex(self):
  370. # GH 6345
  371. expected = DataFrame([[1 + 1j, 2], [np.nan, 4 + 1j]], columns=["a", "b"])
  372. df = DataFrame([[1 + 1j, 2], [5 + 1j, 4 + 1j]], columns=["a", "b"])
  373. df[df.abs() >= 5] = np.nan
  374. tm.assert_frame_equal(df, expected)
  375. def test_where_axis(self):
  376. # GH 9736
  377. df = DataFrame(np.random.randn(2, 2))
  378. mask = DataFrame([[False, False], [False, False]])
  379. ser = Series([0, 1])
  380. expected = DataFrame([[0, 0], [1, 1]], dtype="float64")
  381. result = df.where(mask, ser, axis="index")
  382. tm.assert_frame_equal(result, expected)
  383. result = df.copy()
  384. return_value = result.where(mask, ser, axis="index", inplace=True)
  385. assert return_value is None
  386. tm.assert_frame_equal(result, expected)
  387. expected = DataFrame([[0, 1], [0, 1]], dtype="float64")
  388. result = df.where(mask, ser, axis="columns")
  389. tm.assert_frame_equal(result, expected)
  390. result = df.copy()
  391. return_value = result.where(mask, ser, axis="columns", inplace=True)
  392. assert return_value is None
  393. tm.assert_frame_equal(result, expected)
  394. def test_where_axis_with_upcast(self):
  395. # Upcast needed
  396. df = DataFrame([[1, 2], [3, 4]], dtype="int64")
  397. mask = DataFrame([[False, False], [False, False]])
  398. ser = Series([0, np.nan])
  399. expected = DataFrame([[0, 0], [np.nan, np.nan]], dtype="float64")
  400. result = df.where(mask, ser, axis="index")
  401. tm.assert_frame_equal(result, expected)
  402. result = df.copy()
  403. return_value = result.where(mask, ser, axis="index", inplace=True)
  404. assert return_value is None
  405. tm.assert_frame_equal(result, expected)
  406. expected = DataFrame([[0, np.nan], [0, np.nan]])
  407. result = df.where(mask, ser, axis="columns")
  408. tm.assert_frame_equal(result, expected)
  409. expected = DataFrame(
  410. {
  411. 0: np.array([0, 0], dtype="int64"),
  412. 1: np.array([np.nan, np.nan], dtype="float64"),
  413. }
  414. )
  415. result = df.copy()
  416. return_value = result.where(mask, ser, axis="columns", inplace=True)
  417. assert return_value is None
  418. tm.assert_frame_equal(result, expected)
  419. def test_where_axis_multiple_dtypes(self):
  420. # Multiple dtypes (=> multiple Blocks)
  421. df = pd.concat(
  422. [
  423. DataFrame(np.random.randn(10, 2)),
  424. DataFrame(np.random.randint(0, 10, size=(10, 2)), dtype="int64"),
  425. ],
  426. ignore_index=True,
  427. axis=1,
  428. )
  429. mask = DataFrame(False, columns=df.columns, index=df.index)
  430. s1 = Series(1, index=df.columns)
  431. s2 = Series(2, index=df.index)
  432. result = df.where(mask, s1, axis="columns")
  433. expected = DataFrame(1.0, columns=df.columns, index=df.index)
  434. expected[2] = expected[2].astype("int64")
  435. expected[3] = expected[3].astype("int64")
  436. tm.assert_frame_equal(result, expected)
  437. result = df.copy()
  438. return_value = result.where(mask, s1, axis="columns", inplace=True)
  439. assert return_value is None
  440. tm.assert_frame_equal(result, expected)
  441. result = df.where(mask, s2, axis="index")
  442. expected = DataFrame(2.0, columns=df.columns, index=df.index)
  443. expected[2] = expected[2].astype("int64")
  444. expected[3] = expected[3].astype("int64")
  445. tm.assert_frame_equal(result, expected)
  446. result = df.copy()
  447. return_value = result.where(mask, s2, axis="index", inplace=True)
  448. assert return_value is None
  449. tm.assert_frame_equal(result, expected)
  450. # DataFrame vs DataFrame
  451. d1 = df.copy().drop(1, axis=0)
  452. # Explicit cast to avoid implicit cast when setting value to np.nan
  453. expected = df.copy().astype("float")
  454. expected.loc[1, :] = np.nan
  455. result = df.where(mask, d1)
  456. tm.assert_frame_equal(result, expected)
  457. result = df.where(mask, d1, axis="index")
  458. tm.assert_frame_equal(result, expected)
  459. result = df.copy()
  460. return_value = result.where(mask, d1, inplace=True)
  461. assert return_value is None
  462. tm.assert_frame_equal(result, expected)
  463. result = df.copy()
  464. return_value = result.where(mask, d1, inplace=True, axis="index")
  465. assert return_value is None
  466. tm.assert_frame_equal(result, expected)
  467. d2 = df.copy().drop(1, axis=1)
  468. expected = df.copy()
  469. expected.loc[:, 1] = np.nan
  470. result = df.where(mask, d2)
  471. tm.assert_frame_equal(result, expected)
  472. result = df.where(mask, d2, axis="columns")
  473. tm.assert_frame_equal(result, expected)
  474. result = df.copy()
  475. return_value = result.where(mask, d2, inplace=True)
  476. assert return_value is None
  477. tm.assert_frame_equal(result, expected)
  478. result = df.copy()
  479. return_value = result.where(mask, d2, inplace=True, axis="columns")
  480. assert return_value is None
  481. tm.assert_frame_equal(result, expected)
  482. def test_where_callable(self):
  483. # GH 12533
  484. df = DataFrame([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
  485. result = df.where(lambda x: x > 4, lambda x: x + 1)
  486. exp = DataFrame([[2, 3, 4], [5, 5, 6], [7, 8, 9]])
  487. tm.assert_frame_equal(result, exp)
  488. tm.assert_frame_equal(result, df.where(df > 4, df + 1))
  489. # return ndarray and scalar
  490. result = df.where(lambda x: (x % 2 == 0).values, lambda x: 99)
  491. exp = DataFrame([[99, 2, 99], [4, 99, 6], [99, 8, 99]])
  492. tm.assert_frame_equal(result, exp)
  493. tm.assert_frame_equal(result, df.where(df % 2 == 0, 99))
  494. # chain
  495. result = (df + 2).where(lambda x: x > 8, lambda x: x + 10)
  496. exp = DataFrame([[13, 14, 15], [16, 17, 18], [9, 10, 11]])
  497. tm.assert_frame_equal(result, exp)
  498. tm.assert_frame_equal(result, (df + 2).where((df + 2) > 8, (df + 2) + 10))
  499. def test_where_tz_values(self, tz_naive_fixture, frame_or_series):
  500. obj1 = DataFrame(
  501. DatetimeIndex(["20150101", "20150102", "20150103"], tz=tz_naive_fixture),
  502. columns=["date"],
  503. )
  504. obj2 = DataFrame(
  505. DatetimeIndex(["20150103", "20150104", "20150105"], tz=tz_naive_fixture),
  506. columns=["date"],
  507. )
  508. mask = DataFrame([True, True, False], columns=["date"])
  509. exp = DataFrame(
  510. DatetimeIndex(["20150101", "20150102", "20150105"], tz=tz_naive_fixture),
  511. columns=["date"],
  512. )
  513. if frame_or_series is Series:
  514. obj1 = obj1["date"]
  515. obj2 = obj2["date"]
  516. mask = mask["date"]
  517. exp = exp["date"]
  518. result = obj1.where(mask, obj2)
  519. tm.assert_equal(exp, result)
  520. def test_df_where_change_dtype(self):
  521. # GH#16979
  522. df = DataFrame(np.arange(2 * 3).reshape(2, 3), columns=list("ABC"))
  523. mask = np.array([[True, False, False], [False, False, True]])
  524. result = df.where(mask)
  525. expected = DataFrame(
  526. [[0, np.nan, np.nan], [np.nan, np.nan, 5]], columns=list("ABC")
  527. )
  528. tm.assert_frame_equal(result, expected)
  529. @pytest.mark.parametrize("kwargs", [{}, {"other": None}])
  530. def test_df_where_with_category(self, kwargs):
  531. # GH#16979
  532. data = np.arange(2 * 3, dtype=np.int64).reshape(2, 3)
  533. df = DataFrame(data, columns=list("ABC"))
  534. mask = np.array([[True, False, False], [False, False, True]])
  535. # change type to category
  536. df.A = df.A.astype("category")
  537. df.B = df.B.astype("category")
  538. df.C = df.C.astype("category")
  539. result = df.where(mask, **kwargs)
  540. A = pd.Categorical([0, np.nan], categories=[0, 3])
  541. B = pd.Categorical([np.nan, np.nan], categories=[1, 4])
  542. C = pd.Categorical([np.nan, 5], categories=[2, 5])
  543. expected = DataFrame({"A": A, "B": B, "C": C})
  544. tm.assert_frame_equal(result, expected)
  545. # Check Series.where while we're here
  546. result = df.A.where(mask[:, 0], **kwargs)
  547. expected = Series(A, name="A")
  548. tm.assert_series_equal(result, expected)
  549. def test_where_categorical_filtering(self):
  550. # GH#22609 Verify filtering operations on DataFrames with categorical Series
  551. df = DataFrame(data=[[0, 0], [1, 1]], columns=["a", "b"])
  552. df["b"] = df["b"].astype("category")
  553. result = df.where(df["a"] > 0)
  554. # Explicitly cast to 'float' to avoid implicit cast when setting np.nan
  555. expected = df.copy().astype({"a": "float"})
  556. expected.loc[0, :] = np.nan
  557. tm.assert_equal(result, expected)
  558. def test_where_ea_other(self):
  559. # GH#38729/GH#38742
  560. df = DataFrame({"A": [1, 2, 3], "B": [4, 5, 6]})
  561. arr = pd.array([7, pd.NA, 9])
  562. ser = Series(arr)
  563. mask = np.ones(df.shape, dtype=bool)
  564. mask[1, :] = False
  565. # TODO: ideally we would get Int64 instead of object
  566. result = df.where(mask, ser, axis=0)
  567. expected = DataFrame({"A": [1, pd.NA, 3], "B": [4, pd.NA, 6]}).astype(object)
  568. tm.assert_frame_equal(result, expected)
  569. ser2 = Series(arr[:2], index=["A", "B"])
  570. expected = DataFrame({"A": [1, 7, 3], "B": [4, pd.NA, 6]})
  571. expected["B"] = expected["B"].astype(object)
  572. result = df.where(mask, ser2, axis=1)
  573. tm.assert_frame_equal(result, expected)
  574. def test_where_interval_noop(self):
  575. # GH#44181
  576. df = DataFrame([pd.Interval(0, 0)])
  577. res = df.where(df.notna())
  578. tm.assert_frame_equal(res, df)
  579. ser = df[0]
  580. res = ser.where(ser.notna())
  581. tm.assert_series_equal(res, ser)
  582. def test_where_interval_fullop_downcast(self, frame_or_series):
  583. # GH#45768
  584. obj = frame_or_series([pd.Interval(0, 0)] * 2)
  585. other = frame_or_series([1.0, 2.0])
  586. res = obj.where(~obj.notna(), other)
  587. # since all entries are being changed, we will downcast result
  588. # from object to ints (not floats)
  589. tm.assert_equal(res, other.astype(np.int64))
  590. # unlike where, Block.putmask does not downcast
  591. obj.mask(obj.notna(), other, inplace=True)
  592. tm.assert_equal(obj, other.astype(object))
  593. @pytest.mark.parametrize(
  594. "dtype",
  595. [
  596. "timedelta64[ns]",
  597. "datetime64[ns]",
  598. "datetime64[ns, Asia/Tokyo]",
  599. "Period[D]",
  600. ],
  601. )
  602. def test_where_datetimelike_noop(self, dtype):
  603. # GH#45135, analogue to GH#44181 for Period don't raise on no-op
  604. # For td64/dt64/dt64tz we already don't raise, but also are
  605. # checking that we don't unnecessarily upcast to object.
  606. ser = Series(np.arange(3) * 10**9, dtype=np.int64).view(dtype)
  607. df = ser.to_frame()
  608. mask = np.array([False, False, False])
  609. res = ser.where(~mask, "foo")
  610. tm.assert_series_equal(res, ser)
  611. mask2 = mask.reshape(-1, 1)
  612. res2 = df.where(~mask2, "foo")
  613. tm.assert_frame_equal(res2, df)
  614. res3 = ser.mask(mask, "foo")
  615. tm.assert_series_equal(res3, ser)
  616. res4 = df.mask(mask2, "foo")
  617. tm.assert_frame_equal(res4, df)
  618. # opposite case where we are replacing *all* values -> we downcast
  619. # from object dtype # GH#45768
  620. res5 = df.where(mask2, 4)
  621. expected = DataFrame(4, index=df.index, columns=df.columns)
  622. tm.assert_frame_equal(res5, expected)
  623. # unlike where, Block.putmask does not downcast
  624. df.mask(~mask2, 4, inplace=True)
  625. tm.assert_frame_equal(df, expected.astype(object))
  626. def test_where_int_downcasting_deprecated():
  627. # GH#44597
  628. arr = np.arange(6).astype(np.int16).reshape(3, 2)
  629. df = DataFrame(arr)
  630. mask = np.zeros(arr.shape, dtype=bool)
  631. mask[:, 0] = True
  632. res = df.where(mask, 2**17)
  633. expected = DataFrame({0: arr[:, 0], 1: np.array([2**17] * 3, dtype=np.int32)})
  634. tm.assert_frame_equal(res, expected)
  635. def test_where_copies_with_noop(frame_or_series):
  636. # GH-39595
  637. result = frame_or_series([1, 2, 3, 4])
  638. expected = result.copy()
  639. col = result[0] if frame_or_series is DataFrame else result
  640. where_res = result.where(col < 5)
  641. where_res *= 2
  642. tm.assert_equal(result, expected)
  643. where_res = result.where(col > 5, [1, 2, 3, 4])
  644. where_res *= 2
  645. tm.assert_equal(result, expected)
  646. def test_where_string_dtype(frame_or_series):
  647. # GH40824
  648. obj = frame_or_series(
  649. ["a", "b", "c", "d"], index=["id1", "id2", "id3", "id4"], dtype=StringDtype()
  650. )
  651. filtered_obj = frame_or_series(
  652. ["b", "c"], index=["id2", "id3"], dtype=StringDtype()
  653. )
  654. filter_ser = Series([False, True, True, False])
  655. result = obj.where(filter_ser, filtered_obj)
  656. expected = frame_or_series(
  657. [pd.NA, "b", "c", pd.NA],
  658. index=["id1", "id2", "id3", "id4"],
  659. dtype=StringDtype(),
  660. )
  661. tm.assert_equal(result, expected)
  662. result = obj.mask(~filter_ser, filtered_obj)
  663. tm.assert_equal(result, expected)
  664. obj.mask(~filter_ser, filtered_obj, inplace=True)
  665. tm.assert_equal(result, expected)
  666. def test_where_bool_comparison():
  667. # GH 10336
  668. df_mask = DataFrame(
  669. {"AAA": [True] * 4, "BBB": [False] * 4, "CCC": [True, False, True, False]}
  670. )
  671. result = df_mask.where(df_mask == False) # noqa:E712
  672. expected = DataFrame(
  673. {
  674. "AAA": np.array([np.nan] * 4, dtype=object),
  675. "BBB": [False] * 4,
  676. "CCC": [np.nan, False, np.nan, False],
  677. }
  678. )
  679. tm.assert_frame_equal(result, expected)
  680. def test_where_none_nan_coerce():
  681. # GH 15613
  682. expected = DataFrame(
  683. {
  684. "A": [Timestamp("20130101"), pd.NaT, Timestamp("20130103")],
  685. "B": [1, 2, np.nan],
  686. }
  687. )
  688. result = expected.where(expected.notnull(), None)
  689. tm.assert_frame_equal(result, expected)
  690. def test_where_duplicate_axes_mixed_dtypes():
  691. # GH 25399, verify manually masking is not affected anymore by dtype of column for
  692. # duplicate axes.
  693. result = DataFrame(data=[[0, np.nan]], columns=Index(["A", "A"]))
  694. index, columns = result.axes
  695. mask = DataFrame(data=[[True, True]], columns=columns, index=index)
  696. a = result.astype(object).where(mask)
  697. b = result.astype("f8").where(mask)
  698. c = result.T.where(mask.T).T
  699. d = result.where(mask) # used to fail with "cannot reindex from a duplicate axis"
  700. tm.assert_frame_equal(a.astype("f8"), b.astype("f8"))
  701. tm.assert_frame_equal(b.astype("f8"), c.astype("f8"))
  702. tm.assert_frame_equal(c.astype("f8"), d.astype("f8"))
  703. def test_where_columns_casting():
  704. # GH 42295
  705. df = DataFrame({"a": [1.0, 2.0], "b": [3, np.nan]})
  706. expected = df.copy()
  707. result = df.where(pd.notnull(df), None)
  708. # make sure dtypes don't change
  709. tm.assert_frame_equal(expected, result)
  710. @pytest.mark.parametrize("as_cat", [True, False])
  711. def test_where_period_invalid_na(frame_or_series, as_cat, request):
  712. # GH#44697
  713. idx = pd.period_range("2016-01-01", periods=3, freq="D")
  714. if as_cat:
  715. idx = idx.astype("category")
  716. obj = frame_or_series(idx)
  717. # NA value that we should *not* cast to Period dtype
  718. tdnat = pd.NaT.to_numpy("m8[ns]")
  719. mask = np.array([True, True, False], ndmin=obj.ndim).T
  720. if as_cat:
  721. msg = (
  722. r"Cannot setitem on a Categorical with a new category \(NaT\), "
  723. "set the categories first"
  724. )
  725. else:
  726. msg = "value should be a 'Period'"
  727. if as_cat:
  728. with pytest.raises(TypeError, match=msg):
  729. obj.where(mask, tdnat)
  730. with pytest.raises(TypeError, match=msg):
  731. obj.mask(mask, tdnat)
  732. with pytest.raises(TypeError, match=msg):
  733. obj.mask(mask, tdnat, inplace=True)
  734. else:
  735. # With PeriodDtype, ser[i] = tdnat coerces instead of raising,
  736. # so for consistency, ser[mask] = tdnat must as well
  737. expected = obj.astype(object).where(mask, tdnat)
  738. result = obj.where(mask, tdnat)
  739. tm.assert_equal(result, expected)
  740. expected = obj.astype(object).mask(mask, tdnat)
  741. result = obj.mask(mask, tdnat)
  742. tm.assert_equal(result, expected)
  743. obj.mask(mask, tdnat, inplace=True)
  744. tm.assert_equal(obj, expected)
  745. def test_where_nullable_invalid_na(frame_or_series, any_numeric_ea_dtype):
  746. # GH#44697
  747. arr = pd.array([1, 2, 3], dtype=any_numeric_ea_dtype)
  748. obj = frame_or_series(arr)
  749. mask = np.array([True, True, False], ndmin=obj.ndim).T
  750. msg = r"Invalid value '.*' for dtype (U?Int|Float)\d{1,2}"
  751. for null in tm.NP_NAT_OBJECTS + [pd.NaT]:
  752. # NaT is an NA value that we should *not* cast to pd.NA dtype
  753. with pytest.raises(TypeError, match=msg):
  754. obj.where(mask, null)
  755. with pytest.raises(TypeError, match=msg):
  756. obj.mask(mask, null)
  757. @given(data=OPTIONAL_ONE_OF_ALL)
  758. def test_where_inplace_casting(data):
  759. # GH 22051
  760. df = DataFrame({"a": data})
  761. df_copy = df.where(pd.notnull(df), None).copy()
  762. df.where(pd.notnull(df), None, inplace=True)
  763. tm.assert_equal(df, df_copy)
  764. def test_where_downcast_to_td64():
  765. ser = Series([1, 2, 3])
  766. mask = np.array([False, False, False])
  767. td = pd.Timedelta(days=1)
  768. res = ser.where(mask, td)
  769. expected = Series([td, td, td], dtype="m8[ns]")
  770. tm.assert_series_equal(res, expected)
  771. def _check_where_equivalences(df, mask, other, expected):
  772. # similar to tests.series.indexing.test_setitem.SetitemCastingEquivalences
  773. # but with DataFrame in mind and less fleshed-out
  774. res = df.where(mask, other)
  775. tm.assert_frame_equal(res, expected)
  776. res = df.mask(~mask, other)
  777. tm.assert_frame_equal(res, expected)
  778. # Note: frame.mask(~mask, other, inplace=True) takes some more work bc
  779. # Block.putmask does *not* downcast. The change to 'expected' here
  780. # is specific to the cases in test_where_dt64_2d.
  781. df = df.copy()
  782. df.mask(~mask, other, inplace=True)
  783. if not mask.all():
  784. # with mask.all(), Block.putmask is a no-op, so does not downcast
  785. expected = expected.copy()
  786. expected["A"] = expected["A"].astype(object)
  787. tm.assert_frame_equal(df, expected)
  788. def test_where_dt64_2d():
  789. dti = date_range("2016-01-01", periods=6)
  790. dta = dti._data.reshape(3, 2)
  791. other = dta - dta[0, 0]
  792. df = DataFrame(dta, columns=["A", "B"])
  793. mask = np.asarray(df.isna()).copy()
  794. mask[:, 1] = True
  795. # setting all of one column, none of the other
  796. expected = DataFrame({"A": other[:, 0], "B": dta[:, 1]})
  797. _check_where_equivalences(df, mask, other, expected)
  798. # setting part of one column, none of the other
  799. mask[1, 0] = True
  800. expected = DataFrame(
  801. {
  802. "A": np.array([other[0, 0], dta[1, 0], other[2, 0]], dtype=object),
  803. "B": dta[:, 1],
  804. }
  805. )
  806. _check_where_equivalences(df, mask, other, expected)
  807. # setting nothing in either column
  808. mask[:] = True
  809. expected = df
  810. _check_where_equivalences(df, mask, other, expected)
  811. def test_where_producing_ea_cond_for_np_dtype():
  812. # GH#44014
  813. df = DataFrame({"a": Series([1, pd.NA, 2], dtype="Int64"), "b": [1, 2, 3]})
  814. result = df.where(lambda x: x.apply(lambda y: y > 1, axis=1))
  815. expected = DataFrame(
  816. {"a": Series([pd.NA, pd.NA, 2], dtype="Int64"), "b": [np.nan, 2, 3]}
  817. )
  818. tm.assert_frame_equal(result, expected)
  819. @pytest.mark.parametrize(
  820. "replacement", [0.001, True, "snake", None, datetime(2022, 5, 4)]
  821. )
  822. def test_where_int_overflow(replacement):
  823. # GH 31687
  824. df = DataFrame([[1.0, 2e25, "nine"], [np.nan, 0.1, None]])
  825. result = df.where(pd.notnull(df), replacement)
  826. expected = DataFrame([[1.0, 2e25, "nine"], [replacement, 0.1, replacement]])
  827. tm.assert_frame_equal(result, expected)
  828. def test_where_inplace_no_other():
  829. # GH#51685
  830. df = DataFrame({"a": [1, 2], "b": ["x", "y"]})
  831. cond = DataFrame({"a": [True, False], "b": [False, True]})
  832. df.where(cond, inplace=True)
  833. expected = DataFrame({"a": [1, np.nan], "b": [np.nan, "y"]})
  834. tm.assert_frame_equal(df, expected)