test_shift.py 22 KB


  1. import numpy as np
  2. import pytest
  3. import pandas.util._test_decorators as td
  4. import pandas as pd
  5. from pandas import (
  6. CategoricalIndex,
  7. DataFrame,
  8. Index,
  9. NaT,
  10. Series,
  11. date_range,
  12. offsets,
  13. )
  14. import pandas._testing as tm
  15. class TestDataFrameShift:
  @pytest.mark.parametrize(
      "input_data, output_data",
      [(np.empty(shape=(0,)), []), (np.ones(shape=(2,)), [np.nan, 1.0])],
  )
  def test_shift_non_writable_array(self, input_data, output_data, frame_or_series):
      """GH21049: data backed by a read-only numpy array can be shifted.

      The parametrized arrays are created once when the module is collected and
      are reused by every ``frame_or_series`` variant. Work on a private copy
      rather than flipping the write flag on that shared object.
      """
      input_data = input_data.copy()
      input_data.setflags(write=False)
      result = frame_or_series(input_data).shift(1)
      if frame_or_series is not Series:
          # need to explicitly specify columns in the empty case
          expected = frame_or_series(
              output_data,
              index=range(len(output_data)),
              columns=range(1),
              dtype="float64",
          )
      else:
          expected = frame_or_series(output_data, dtype="float64")
      tm.assert_equal(result, expected)
  35. def test_shift_mismatched_freq(self, frame_or_series):
  36. ts = frame_or_series(
  37. np.random.randn(5), index=date_range("1/1/2000", periods=5, freq="H")
  38. )
  39. result = ts.shift(1, freq="5T")
  40. exp_index = ts.index.shift(1, freq="5T")
  41. tm.assert_index_equal(result.index, exp_index)
  42. # GH#1063, multiple of same base
  43. result = ts.shift(1, freq="4H")
  44. exp_index = ts.index + offsets.Hour(4)
  45. tm.assert_index_equal(result.index, exp_index)
  @pytest.mark.parametrize(
      "obj",
      [
          Series([np.arange(5)]),
          # "h", not the deprecated "H": these params are built at import
          # time, so an invalid alias would break collection of the module
          date_range("1/1/2011", periods=24, freq="h"),
          Series(range(5), index=date_range("2017", periods=5)),
      ],
  )
  @pytest.mark.parametrize("shift_size", [0, 1, 2])
  def test_shift_always_copy(self, obj, shift_size, frame_or_series):
      """GH#22397: ``shift`` never returns the calling object itself.

      ``obj`` is a Series or a DatetimeIndex. For the DataFrame variant it is
      converted with ``to_frame`` first.
      """
      if frame_or_series is not Series:
          obj = obj.to_frame()
      assert obj.shift(shift_size) is not obj
  def test_shift_object_non_scalar_fill(self):
      """A non-scalar ``fill_value`` is rejected unless the data is object dtype."""
      numeric = Series(range(3))
      err = "fill_value must be a scalar"
      with pytest.raises(ValueError, match=err):
          numeric.shift(1, fill_value=[])
      as_frame = numeric.to_frame()
      with pytest.raises(ValueError, match=err):
          as_frame.shift(1, fill_value=np.arange(3))

      # object dtype accepts an arbitrary Python object as the fill
      as_object = numeric.astype(object)
      filled_ser = as_object.shift(1, fill_value={})
      assert filled_ser[0] == {}
      filled_df = as_object.to_frame().shift(1, fill_value={})
      assert filled_df.iloc[0, 0] == {}
  74. def test_shift_int(self, datetime_frame, frame_or_series):
  75. ts = tm.get_obj(datetime_frame, frame_or_series).astype(int)
  76. shifted = ts.shift(1)
  77. expected = ts.astype(float).shift(1)
  78. tm.assert_equal(shifted, expected)
  @pytest.mark.parametrize("dtype", ["int32", "int64"])
  def test_shift_32bit_take(self, frame_or_series, dtype):
      """GH#8129: ``periods`` may be a numpy integer scalar (32-bit taking)."""
      dti = date_range("2000-01-01", periods=5)
      values = np.arange(5, dtype=dtype)
      # values[1] is the numpy scalar 1 of the parametrized dtype
      periods = values[1]
      shifted = frame_or_series(values, index=dti).shift(periods=periods)
      tm.assert_equal(shifted, frame_or_series([np.nan, 0, 1, 2, 3], index=dti))
  @pytest.mark.parametrize("periods", [1, 2, 3, 4])
  def test_shift_preserve_freqstr(self, periods, frame_or_series):
      """GH#21275: a ``freq`` shift of an hourly index keeps the hourly freq.

      Uses the "h"/"2h" aliases. The upper-case "H" spelling is deprecated
      since pandas 2.2.
      """
      obj = frame_or_series(
          range(periods),
          index=date_range("2016-1-1 00:00:00", periods=periods, freq="h"),
      )
      result = obj.shift(1, "2h")
      expected = frame_or_series(
          range(periods),
          index=date_range("2016-1-1 02:00:00", periods=periods, freq="h"),
      )
      tm.assert_equal(result, expected)
  103. def test_shift_dst(self, frame_or_series):
  104. # GH#13926
  105. dates = date_range("2016-11-06", freq="H", periods=10, tz="US/Eastern")
  106. obj = frame_or_series(dates)
  107. res = obj.shift(0)
  108. tm.assert_equal(res, obj)
  109. assert tm.get_dtype(res) == "datetime64[ns, US/Eastern]"
  110. res = obj.shift(1)
  111. exp_vals = [NaT] + dates.astype(object).values.tolist()[:9]
  112. exp = frame_or_series(exp_vals)
  113. tm.assert_equal(res, exp)
  114. assert tm.get_dtype(res) == "datetime64[ns, US/Eastern]"
  115. res = obj.shift(-2)
  116. exp_vals = dates.astype(object).values.tolist()[2:] + [NaT, NaT]
  117. exp = frame_or_series(exp_vals)
  118. tm.assert_equal(res, exp)
  119. assert tm.get_dtype(res) == "datetime64[ns, US/Eastern]"
  @pytest.mark.parametrize("ex", [10, -10, 20, -20])
  def test_shift_dst_beyond(self, frame_or_series, ex):
      """GH#13926: a shift of at least the length gives all-NaT, tz kept.

      Uses the "h" alias; upper-case "H" is deprecated since pandas 2.2.
      """
      dates = date_range("2016-11-06", freq="h", periods=10, tz="US/Eastern")
      obj = frame_or_series(dates)
      res = obj.shift(ex)
      exp = frame_or_series([NaT] * 10, dtype="datetime64[ns, US/Eastern]")
      tm.assert_equal(res, exp)
      assert tm.get_dtype(res) == "datetime64[ns, US/Eastern]"
  129. def test_shift_by_zero(self, datetime_frame, frame_or_series):
  130. # shift by 0
  131. obj = tm.get_obj(datetime_frame, frame_or_series)
  132. unshifted = obj.shift(0)
  133. tm.assert_equal(unshifted, obj)
  134. def test_shift(self, datetime_frame):
  135. # naive shift
  136. ser = datetime_frame["A"]
  137. shifted = datetime_frame.shift(5)
  138. tm.assert_index_equal(shifted.index, datetime_frame.index)
  139. shifted_ser = ser.shift(5)
  140. tm.assert_series_equal(shifted["A"], shifted_ser)
  141. shifted = datetime_frame.shift(-5)
  142. tm.assert_index_equal(shifted.index, datetime_frame.index)
  143. shifted_ser = ser.shift(-5)
  144. tm.assert_series_equal(shifted["A"], shifted_ser)
  145. unshifted = datetime_frame.shift(5).shift(-5)
  146. tm.assert_numpy_array_equal(
  147. unshifted.dropna().values, datetime_frame.values[:-5]
  148. )
  149. unshifted_ser = ser.shift(5).shift(-5)
  150. tm.assert_numpy_array_equal(unshifted_ser.dropna().values, ser.values[:-5])
  151. def test_shift_by_offset(self, datetime_frame, frame_or_series):
  152. # shift by DateOffset
  153. obj = tm.get_obj(datetime_frame, frame_or_series)
  154. offset = offsets.BDay()
  155. shifted = obj.shift(5, freq=offset)
  156. assert len(shifted) == len(obj)
  157. unshifted = shifted.shift(-5, freq=offset)
  158. tm.assert_equal(unshifted, obj)
  159. shifted2 = obj.shift(5, freq="B")
  160. tm.assert_equal(shifted, shifted2)
  161. unshifted = obj.shift(0, freq=offset)
  162. tm.assert_equal(unshifted, obj)
  163. d = obj.index[0]
  164. shifted_d = d + offset * 5
  165. if frame_or_series is DataFrame:
  166. tm.assert_series_equal(obj.xs(d), shifted.xs(shifted_d), check_names=False)
  167. else:
  168. tm.assert_almost_equal(obj.at[d], shifted.at[shifted_d])
  def test_shift_with_periodindex(self, frame_or_series):
      """Shift data indexed by the business-day PeriodIndex of tm.makePeriodFrame."""
      pobj = tm.get_obj(tm.makePeriodFrame(), frame_or_series)

      fwd = pobj.shift(1)
      back = fwd.shift(-1)
      tm.assert_index_equal(fwd.index, pobj.index)
      tm.assert_index_equal(back.index, pobj.index)
      if frame_or_series is DataFrame:
          got = back.iloc[:, 0].dropna().values
          want = pobj.iloc[:-1, 0].values
      else:
          got = back.dropna().values
          want = pobj.values[:-1]
      tm.assert_numpy_array_equal(got, want)

      # freq given as an alias or as an offset object is equivalent
      by_alias = pobj.shift(1, "B")
      tm.assert_equal(by_alias, pobj.shift(1, offsets.BDay()))
      tm.assert_equal(pobj, by_alias.shift(-1, "B"))

      with pytest.raises(ValueError, match="does not match PeriodIndex freq"):
          pobj.shift(freq="D")

      # legacy support
      by_keyword = pobj.shift(1, freq="B")
      tm.assert_equal(by_alias, by_keyword)
      tm.assert_equal(pobj.shift(1, freq=offsets.BDay()), by_keyword)
  195. def test_shift_other_axis(self):
  196. # shift other axis
  197. # GH#6371
  198. df = DataFrame(np.random.rand(10, 5))
  199. expected = pd.concat(
  200. [DataFrame(np.nan, index=df.index, columns=[0]), df.iloc[:, 0:-1]],
  201. ignore_index=True,
  202. axis=1,
  203. )
  204. result = df.shift(1, axis=1)
  205. tm.assert_frame_equal(result, expected)
  206. def test_shift_named_axis(self):
  207. # shift named axis
  208. df = DataFrame(np.random.rand(10, 5))
  209. expected = pd.concat(
  210. [DataFrame(np.nan, index=df.index, columns=[0]), df.iloc[:, 0:-1]],
  211. ignore_index=True,
  212. axis=1,
  213. )
  214. result = df.shift(1, axis="columns")
  215. tm.assert_frame_equal(result, expected)
  216. def test_shift_other_axis_with_freq(self, datetime_frame):
  217. obj = datetime_frame.T
  218. offset = offsets.BDay()
  219. # GH#47039
  220. shifted = obj.shift(5, freq=offset, axis=1)
  221. assert len(shifted) == len(obj)
  222. unshifted = shifted.shift(-5, freq=offset, axis=1)
  223. tm.assert_equal(unshifted, obj)
  224. def test_shift_bool(self):
  225. df = DataFrame({"high": [True, False], "low": [False, False]})
  226. rs = df.shift(1)
  227. xp = DataFrame(
  228. np.array([[np.nan, np.nan], [True, False]], dtype=object),
  229. columns=["high", "low"],
  230. )
  231. tm.assert_frame_equal(rs, xp)
  232. def test_shift_categorical1(self, frame_or_series):
  233. # GH#9416
  234. obj = frame_or_series(["a", "b", "c", "d"], dtype="category")
  235. rt = obj.shift(1).shift(-1)
  236. tm.assert_equal(obj.iloc[:-1], rt.dropna())
  237. def get_cat_values(ndframe):
  238. # For Series we could just do ._values; for DataFrame
  239. # we may be able to do this if we ever have 2D Categoricals
  240. return ndframe._mgr.arrays[0]
  241. cat = get_cat_values(obj)
  242. sp1 = obj.shift(1)
  243. tm.assert_index_equal(obj.index, sp1.index)
  244. assert np.all(get_cat_values(sp1).codes[:1] == -1)
  245. assert np.all(cat.codes[:-1] == get_cat_values(sp1).codes[1:])
  246. sn2 = obj.shift(-2)
  247. tm.assert_index_equal(obj.index, sn2.index)
  248. assert np.all(get_cat_values(sn2).codes[-2:] == -1)
  249. assert np.all(cat.codes[2:] == get_cat_values(sn2).codes[:-2])
  250. tm.assert_index_equal(cat.categories, get_cat_values(sp1).categories)
  251. tm.assert_index_equal(cat.categories, get_cat_values(sn2).categories)
  252. def test_shift_categorical(self):
  253. # GH#9416
  254. s1 = Series(["a", "b", "c"], dtype="category")
  255. s2 = Series(["A", "B", "C"], dtype="category")
  256. df = DataFrame({"one": s1, "two": s2})
  257. rs = df.shift(1)
  258. xp = DataFrame({"one": s1.shift(1), "two": s2.shift(1)})
  259. tm.assert_frame_equal(rs, xp)
  def test_shift_categorical_fill_value(self, frame_or_series):
      """A categorical ``fill_value`` must be one of the existing categories."""
      obj = frame_or_series(["a", "b", "c", "d"], dtype="category")
      filled = obj.shift(1, fill_value="a")
      expected = frame_or_series(
          pd.Categorical(
              ["a", "a", "b", "c"], categories=["a", "b", "c", "d"], ordered=False
          )
      )
      tm.assert_equal(filled, expected)

      # "f" is not a category, so it is rejected
      msg = r"Cannot setitem on a Categorical with a new category \(f\)"
      with pytest.raises(TypeError, match=msg):
          obj.shift(1, fill_value="f")
  273. def test_shift_fill_value(self, frame_or_series):
  274. # GH#24128
  275. dti = date_range("1/1/2000", periods=5, freq="H")
  276. ts = frame_or_series([1.0, 2.0, 3.0, 4.0, 5.0], index=dti)
  277. exp = frame_or_series([0.0, 1.0, 2.0, 3.0, 4.0], index=dti)
  278. # check that fill value works
  279. result = ts.shift(1, fill_value=0.0)
  280. tm.assert_equal(result, exp)
  281. exp = frame_or_series([0.0, 0.0, 1.0, 2.0, 3.0], index=dti)
  282. result = ts.shift(2, fill_value=0.0)
  283. tm.assert_equal(result, exp)
  284. ts = frame_or_series([1, 2, 3])
  285. res = ts.shift(2, fill_value=0)
  286. assert tm.get_dtype(res) == tm.get_dtype(ts)
  287. # retain integer dtype
  288. obj = frame_or_series([1, 2, 3, 4, 5], index=dti)
  289. exp = frame_or_series([0, 1, 2, 3, 4], index=dti)
  290. result = obj.shift(1, fill_value=0)
  291. tm.assert_equal(result, exp)
  292. exp = frame_or_series([0, 0, 1, 2, 3], index=dti)
  293. result = obj.shift(2, fill_value=0)
  294. tm.assert_equal(result, exp)
  295. def test_shift_empty(self):
  296. # Regression test for GH#8019
  297. df = DataFrame({"foo": []})
  298. rs = df.shift(-1)
  299. tm.assert_frame_equal(df, rs)
  300. def test_shift_duplicate_columns(self):
  301. # GH#9092; verify that position-based shifting works
  302. # in the presence of duplicate columns
  303. column_lists = [list(range(5)), [1] * 5, [1, 1, 2, 2, 1]]
  304. data = np.random.randn(20, 5)
  305. shifted = []
  306. for columns in column_lists:
  307. df = DataFrame(data.copy(), columns=columns)
  308. for s in range(5):
  309. df.iloc[:, s] = df.iloc[:, s].shift(s + 1)
  310. df.columns = range(5)
  311. shifted.append(df)
  312. # sanity check the base case
  313. nulls = shifted[0].isna().sum()
  314. tm.assert_series_equal(nulls, Series(range(1, 6), dtype="int64"))
  315. # check all answers are the same
  316. tm.assert_frame_equal(shifted[0], shifted[1])
  317. tm.assert_frame_equal(shifted[0], shifted[2])
  318. def test_shift_axis1_multiple_blocks(self, using_array_manager):
  319. # GH#35488
  320. df1 = DataFrame(np.random.randint(1000, size=(5, 3)))
  321. df2 = DataFrame(np.random.randint(1000, size=(5, 2)))
  322. df3 = pd.concat([df1, df2], axis=1)
  323. if not using_array_manager:
  324. assert len(df3._mgr.blocks) == 2
  325. result = df3.shift(2, axis=1)
  326. expected = df3.take([-1, -1, 0, 1, 2], axis=1)
  327. # Explicit cast to float to avoid implicit cast when setting nan.
  328. # Column names aren't unique, so directly calling `expected.astype` won't work.
  329. expected = expected.pipe(
  330. lambda df: df.set_axis(range(df.shape[1]), axis=1)
  331. .astype({0: "float", 1: "float"})
  332. .set_axis(df.columns, axis=1)
  333. )
  334. expected.iloc[:, :2] = np.nan
  335. expected.columns = df3.columns
  336. tm.assert_frame_equal(result, expected)
  337. # Case with periods < 0
  338. # rebuild df3 because `take` call above consolidated
  339. df3 = pd.concat([df1, df2], axis=1)
  340. if not using_array_manager:
  341. assert len(df3._mgr.blocks) == 2
  342. result = df3.shift(-2, axis=1)
  343. expected = df3.take([2, 3, 4, -1, -1], axis=1)
  344. # Explicit cast to float to avoid implicit cast when setting nan.
  345. # Column names aren't unique, so directly calling `expected.astype` won't work.
  346. expected = expected.pipe(
  347. lambda df: df.set_axis(range(df.shape[1]), axis=1)
  348. .astype({3: "float", 4: "float"})
  349. .set_axis(df.columns, axis=1)
  350. )
  351. expected.iloc[:, -2:] = np.nan
  352. expected.columns = df3.columns
  353. tm.assert_frame_equal(result, expected)
  @td.skip_array_manager_not_yet_implemented  # TODO(ArrayManager) axis=1 support
  def test_shift_axis1_multiple_blocks_with_int_fill(self):
      """GH#42719: ``axis=1`` shift of a two-block frame with an integer fill."""
      df1 = DataFrame(np.random.randint(1000, size=(5, 3)))
      df2 = DataFrame(np.random.randint(1000, size=(5, 2)))
      zero = np.int_(0)

      # (periods, column order of the expected frame, columns holding the fill)
      cases = (
          (2, [-1, -1, 0, 1], slice(None, 2)),
          (-2, [2, 3, -1, -1], slice(-2, None)),
      )
      for periods, order, filled in cases:
          # a fresh two-block frame for each direction
          df3 = pd.concat([df1.iloc[:4, 1:3], df2.iloc[:4, :]], axis=1)
          result = df3.shift(periods, axis=1, fill_value=zero)
          assert len(df3._mgr.blocks) == 2
          expected = df3.take(order, axis=1)
          expected.iloc[:, filled] = zero
          expected.columns = df3.columns
          tm.assert_frame_equal(result, expected)
  374. def test_period_index_frame_shift_with_freq(self, frame_or_series):
  375. ps = tm.makePeriodFrame()
  376. ps = tm.get_obj(ps, frame_or_series)
  377. shifted = ps.shift(1, freq="infer")
  378. unshifted = shifted.shift(-1, freq="infer")
  379. tm.assert_equal(unshifted, ps)
  380. shifted2 = ps.shift(freq="B")
  381. tm.assert_equal(shifted, shifted2)
  382. shifted3 = ps.shift(freq=offsets.BDay())
  383. tm.assert_equal(shifted, shifted3)
  384. def test_datetime_frame_shift_with_freq(self, datetime_frame, frame_or_series):
  385. dtobj = tm.get_obj(datetime_frame, frame_or_series)
  386. shifted = dtobj.shift(1, freq="infer")
  387. unshifted = shifted.shift(-1, freq="infer")
  388. tm.assert_equal(dtobj, unshifted)
  389. shifted2 = dtobj.shift(freq=dtobj.index.freq)
  390. tm.assert_equal(shifted, shifted2)
  391. inferred_ts = DataFrame(
  392. datetime_frame.values,
  393. Index(np.asarray(datetime_frame.index)),
  394. columns=datetime_frame.columns,
  395. )
  396. inferred_ts = tm.get_obj(inferred_ts, frame_or_series)
  397. shifted = inferred_ts.shift(1, freq="infer")
  398. expected = dtobj.shift(1, freq="infer")
  399. expected.index = expected.index._with_freq(None)
  400. tm.assert_equal(shifted, expected)
  401. unshifted = shifted.shift(-1, freq="infer")
  402. tm.assert_equal(unshifted, inferred_ts)
  def test_period_index_frame_shift_with_freq_error(self, frame_or_series):
      """A ``freq`` that disagrees with the PeriodIndex freq raises ValueError."""
      pobj = tm.get_obj(tm.makePeriodFrame(), frame_or_series)
      with pytest.raises(
          ValueError, match="Given freq M does not match PeriodIndex freq B"
      ):
          pobj.shift(freq="M")
  def test_datetime_frame_shift_with_freq_error(
      self, datetime_frame, frame_or_series
  ):
      """``freq="infer"`` raises when the index has no freq to infer."""
      # rows 0, 5 and 7 are unevenly spaced
      irregular = tm.get_obj(datetime_frame, frame_or_series).iloc[[0, 5, 7]]
      with pytest.raises(
          ValueError, match="Freq was not set in the index hence cannot be inferred"
      ):
          irregular.shift(freq="infer")
  def test_shift_dt64values_int_fill_deprecated(self):
      """GH#31971: an integer ``fill_value`` for datetime64 data.

      Despite the test name, shifting along rows now raises TypeError. This is
      presumably the earlier deprecation being enforced; confirm in the history.
      Shifting across columns puts the integer fill into the vacated column.
      """
      ser = Series([pd.Timestamp("2020-01-01"), pd.Timestamp("2020-01-02")])
      for obj in (ser, ser.to_frame()):
          with pytest.raises(TypeError, match="value should be a"):
              obj.shift(1, fill_value=0)

      # axis = 1, after _consolidate_inplace
      consolidated = DataFrame({"A": ser, "B": ser})
      consolidated._consolidate_inplace()
      result = consolidated.shift(1, axis=1, fill_value=0)
      expected = DataFrame({"A": [0, 0], "B": consolidated["A"]})
      tm.assert_frame_equal(result, expected)

      # same thing but not consolidated; pre-2.0 we got different behavior
      split = DataFrame({"A": ser})
      split["B"] = ser
      assert len(split._mgr.arrays) == 2
      tm.assert_frame_equal(split.shift(1, axis=1, fill_value=0), expected)
  @pytest.mark.parametrize(
      "as_cat",
      [
          pytest.param(
              True,
              marks=pytest.mark.xfail(
                  reason="_can_hold_element incorrectly always returns True"
              ),
          ),
          False,
      ],
  )
  @pytest.mark.parametrize(
      "vals",
      [
          date_range("2020-01-01", periods=2),
          date_range("2020-01-01", periods=2, tz="US/Pacific"),
          pd.period_range("2020-01-01", periods=2, freq="D"),
          pd.timedelta_range("2020 Days", periods=2, freq="D"),
          pd.interval_range(0, 3, periods=2),
          pytest.param(
              pd.array([1, 2], dtype="Int64"),
              marks=pytest.mark.xfail(
                  reason="_can_hold_element incorrectly always returns True"
              ),
          ),
          pytest.param(
              pd.array([1, 2], dtype="Float32"),
              marks=pytest.mark.xfail(
                  reason="_can_hold_element incorrectly always returns True"
              ),
          ),
      ],
      ids=lambda x: str(x.dtype),
  )
  def test_shift_dt64values_axis1_invalid_fill(self, vals, as_cat):
      """GH#44564: ``axis=1`` shift with a string fill ("foo").

      Checked for one column, for two columns after ``_consolidate_inplace``,
      and for two columns added one at a time (two arrays in the manager).
      """
      ser = Series(vals)
      if as_cat:
          ser = ser.astype("category")

      one_col = DataFrame({"A": ser})
      tm.assert_frame_equal(
          one_col.shift(-1, axis=1, fill_value="foo"),
          DataFrame({"A": ["foo", "foo"]}),
      )

      # same thing but multiple blocks
      consolidated = DataFrame({"A": ser, "B": ser})
      consolidated._consolidate_inplace()
      result = consolidated.shift(-1, axis=1, fill_value="foo")
      expected = DataFrame({"A": consolidated["B"], "B": ["foo", "foo"]})
      tm.assert_frame_equal(result, expected)

      # same thing but not consolidated
      unconsolidated = DataFrame({"A": ser})
      unconsolidated["B"] = ser
      assert len(unconsolidated._mgr.arrays) == 2
      tm.assert_frame_equal(
          unconsolidated.shift(-1, axis=1, fill_value="foo"), expected
      )
  493. def test_shift_axis1_categorical_columns(self):
  494. # GH#38434
  495. ci = CategoricalIndex(["a", "b", "c"])
  496. df = DataFrame(
  497. {"a": [1, 3], "b": [2, 4], "c": [5, 6]}, index=ci[:-1], columns=ci
  498. )
  499. result = df.shift(axis=1)
  500. expected = DataFrame(
  501. {"a": [np.nan, np.nan], "b": [1, 3], "c": [2, 4]}, index=ci[:-1], columns=ci
  502. )
  503. tm.assert_frame_equal(result, expected)
  504. # periods != 1
  505. result = df.shift(2, axis=1)
  506. expected = DataFrame(
  507. {"a": [np.nan, np.nan], "b": [np.nan, np.nan], "c": [1, 3]},
  508. index=ci[:-1],
  509. columns=ci,
  510. )
  511. tm.assert_frame_equal(result, expected)
  512. def test_shift_axis1_many_periods(self):
  513. # GH#44978 periods > len(columns)
  514. df = DataFrame(np.random.rand(5, 3))
  515. shifted = df.shift(6, axis=1, fill_value=None)
  516. expected = df * np.nan
  517. tm.assert_frame_equal(shifted, expected)
  518. shifted2 = df.shift(-6, axis=1, fill_value=None)
  519. tm.assert_frame_equal(shifted2, expected)