test_transform.py 48 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482
  1. """ test with the .transform """
  2. from io import StringIO
  3. import numpy as np
  4. import pytest
  5. from pandas.core.dtypes.common import (
  6. ensure_platform_int,
  7. is_timedelta64_dtype,
  8. )
  9. import pandas as pd
  10. from pandas import (
  11. Categorical,
  12. DataFrame,
  13. MultiIndex,
  14. Series,
  15. Timestamp,
  16. concat,
  17. date_range,
  18. )
  19. import pandas._testing as tm
  20. from pandas.tests.groupby import get_groupby_method_args
  21. def assert_fp_equal(a, b):
  22. assert (np.abs(a - b) < 1e-12).all()
def test_transform():
    """Basic groupby-transform behavior for Series and DataFrame."""
    data = Series(np.arange(9) // 3, index=np.arange(9))

    # shuffle the index so the groups are not contiguous
    index = np.arange(9)
    np.random.shuffle(index)
    data = data.reindex(index)

    grouped = data.groupby(lambda x: x // 3)

    transformed = grouped.transform(lambda x: x * x.sum())
    assert transformed[7] == 12

    # GH 8046
    # make sure that we preserve the input order
    df = DataFrame(
        np.arange(6, dtype="int64").reshape(3, 2), columns=["a", "b"], index=[0, 2, 1]
    )
    key = [0, 0, 1]
    expected = (
        df.sort_index()
        .groupby(key)
        .transform(lambda x: x - x.mean())
        .groupby(key)
        .mean()
    )
    result = df.groupby(key).transform(lambda x: x - x.mean()).groupby(key).mean()
    tm.assert_frame_equal(result, expected)

    def demean(arr):
        # column-wise demeaning within a group
        return arr - arr.mean(axis=0)

    people = DataFrame(
        np.random.randn(5, 5),
        columns=["a", "b", "c", "d", "e"],
        index=["Joe", "Steve", "Wes", "Jim", "Travis"],
    )
    key = ["one", "two", "one", "two", "one"]
    # transform and (non-group-keyed) apply must agree for a transform-like func
    result = people.groupby(key).transform(demean).groupby(key).mean()
    expected = people.groupby(key, group_keys=False).apply(demean).groupby(key).mean()
    tm.assert_frame_equal(result, expected)

    # GH 8430
    # transforming on a time Grouper should not raise
    df = tm.makeTimeDataFrame()
    g = df.groupby(pd.Grouper(freq="M"))
    g.transform(lambda x: x - 1)

    # GH 9700
    df = DataFrame({"a": range(5, 10), "b": range(5)})
    result = df.groupby("a").transform(max)
    expected = DataFrame({"b": range(5)})
    tm.assert_frame_equal(result, expected)
def test_transform_fast():
    """Fast (cython/alias) transform paths match the broadcast aggregation."""
    df = DataFrame({"id": np.arange(100000) / 3, "val": np.random.randn(100000)})

    grp = df.groupby("id")["val"]

    # expected: each group's mean repeated once per row of the group
    values = np.repeat(grp.mean().values, ensure_platform_int(grp.count().values))
    expected = Series(values, index=df.index, name="val")

    result = grp.transform(np.mean)
    tm.assert_series_equal(result, expected)

    # string alias must take the same fast path and agree
    result = grp.transform("mean")
    tm.assert_series_equal(result, expected)

    # GH 12737
    # mixed dtypes, including datetimes, with "first"
    df = DataFrame(
        {
            "grouping": [0, 1, 1, 3],
            "f": [1.1, 2.1, 3.1, 4.5],
            "d": date_range("2014-1-1", "2014-1-4"),
            "i": [1, 2, 3, 4],
        },
        columns=["grouping", "f", "i", "d"],
    )
    result = df.groupby("grouping").transform("first")

    dates = [
        Timestamp("2014-1-1"),
        Timestamp("2014-1-2"),
        Timestamp("2014-1-2"),
        Timestamp("2014-1-4"),
    ]
    expected = DataFrame(
        {"f": [1.1, 2.1, 2.1, 4.5], "d": dates, "i": [1, 2, 2, 4]},
        columns=["f", "i", "d"],
    )
    tm.assert_frame_equal(result, expected)

    # selection
    result = df.groupby("grouping")[["f", "i"]].transform("first")
    expected = expected[["f", "i"]]
    tm.assert_frame_equal(result, expected)

    # dup columns
    df = DataFrame([[1, 2, 3], [4, 5, 6]], columns=["g", "a", "a"])
    result = df.groupby("g").transform("first")
    expected = df.drop("g", axis=1)
    tm.assert_frame_equal(result, expected)
def test_transform_broadcast(tsframe, ts):
    """An aggregating func passed to transform is broadcast back to the
    original index — Series, DataFrame, and axis=1 column grouping."""
    grouped = ts.groupby(lambda x: x.month)
    result = grouped.transform(np.mean)
    tm.assert_index_equal(result.index, ts.index)
    for _, gp in grouped:
        # every row of a group carries that group's mean
        assert_fp_equal(result.reindex(gp.index), gp.mean())

    grouped = tsframe.groupby(lambda x: x.month)
    result = grouped.transform(np.mean)
    tm.assert_index_equal(result.index, tsframe.index)
    for _, gp in grouped:
        agged = gp.mean(axis=0)
        res = result.reindex(gp.index)
        for col in tsframe:
            assert_fp_equal(res[col], agged[col])

    # group columns
    grouped = tsframe.groupby({"A": 0, "B": 0, "C": 1, "D": 1}, axis=1)
    result = grouped.transform(np.mean)
    tm.assert_index_equal(result.index, tsframe.index)
    tm.assert_index_equal(result.columns, tsframe.columns)
    for _, gp in grouped:
        agged = gp.mean(1)
        res = result.reindex(columns=gp.columns)
        for idx in gp.index:
            assert_fp_equal(res.xs(idx), agged[idx])
  130. def test_transform_axis_1(request, transformation_func):
  131. # GH 36308
  132. df = DataFrame({"a": [1, 2], "b": [3, 4], "c": [5, 6]}, index=["x", "y"])
  133. args = get_groupby_method_args(transformation_func, df)
  134. result = df.groupby([0, 0, 1], axis=1).transform(transformation_func, *args)
  135. expected = df.T.groupby([0, 0, 1]).transform(transformation_func, *args).T
  136. if transformation_func in ["diff", "shift"]:
  137. # Result contains nans, so transpose coerces to float
  138. expected["b"] = expected["b"].astype("int64")
  139. # cumcount returns Series; the rest are DataFrame
  140. tm.assert_equal(result, expected)
  141. def test_transform_axis_1_reducer(request, reduction_func):
  142. # GH#45715
  143. if reduction_func in (
  144. "corrwith",
  145. "ngroup",
  146. "nth",
  147. ):
  148. marker = pytest.mark.xfail(reason="transform incorrectly fails - GH#45986")
  149. request.node.add_marker(marker)
  150. df = DataFrame({"a": [1, 2], "b": [3, 4], "c": [5, 6]}, index=["x", "y"])
  151. result = df.groupby([0, 0, 1], axis=1).transform(reduction_func)
  152. expected = df.T.groupby([0, 0, 1]).transform(reduction_func).T
  153. tm.assert_equal(result, expected)
def test_transform_axis_ts(tsframe):
    # make sure that we are setting the axes
    # correctly when on axis=0 or 1
    # in the presence of a non-monotonic indexer
    # GH12713
    base = tsframe.iloc[0:5]
    r = len(base.index)
    c = len(base.columns)
    tso = DataFrame(
        np.random.randn(r, c), index=base.index, columns=base.columns, dtype="float64"
    )
    # monotonic
    ts = tso
    grouped = ts.groupby(lambda x: x.weekday(), group_keys=False)
    result = ts - grouped.transform("mean")
    expected = grouped.apply(lambda x: x - x.mean(axis=0))
    tm.assert_frame_equal(result, expected)

    # same comparison, grouping columns (axis=1) on the transpose
    ts = ts.T
    grouped = ts.groupby(lambda x: x.weekday(), axis=1, group_keys=False)
    result = ts - grouped.transform("mean")
    expected = grouped.apply(lambda x: (x.T - x.mean(1)).T)
    tm.assert_frame_equal(result, expected)

    # non-monotonic
    ts = tso.iloc[[1, 0] + list(range(2, len(base)))]
    grouped = ts.groupby(lambda x: x.weekday(), group_keys=False)
    result = ts - grouped.transform("mean")
    expected = grouped.apply(lambda x: x - x.mean(axis=0))
    tm.assert_frame_equal(result, expected)

    ts = ts.T
    grouped = ts.groupby(lambda x: x.weekday(), axis=1, group_keys=False)
    result = ts - grouped.transform("mean")
    expected = grouped.apply(lambda x: (x.T - x.mean(1)).T)
    tm.assert_frame_equal(result, expected)
  187. def test_transform_dtype():
  188. # GH 9807
  189. # Check transform dtype output is preserved
  190. df = DataFrame([[1, 3], [2, 3]])
  191. result = df.groupby(1).transform("mean")
  192. expected = DataFrame([[1.5], [1.5]])
  193. tm.assert_frame_equal(result, expected)
  194. def test_transform_bug():
  195. # GH 5712
  196. # transforming on a datetime column
  197. df = DataFrame({"A": Timestamp("20130101"), "B": np.arange(5)})
  198. result = df.groupby("A")["B"].transform(lambda x: x.rank(ascending=False))
  199. expected = Series(np.arange(5, 0, step=-1), name="B", dtype="float64")
  200. tm.assert_series_equal(result, expected)
  201. def test_transform_numeric_to_boolean():
  202. # GH 16875
  203. # inconsistency in transforming boolean values
  204. expected = Series([True, True], name="A")
  205. df = DataFrame({"A": [1.1, 2.2], "B": [1, 2]})
  206. result = df.groupby("B").A.transform(lambda x: True)
  207. tm.assert_series_equal(result, expected)
  208. df = DataFrame({"A": [1, 2], "B": [1, 2]})
  209. result = df.groupby("B").A.transform(lambda x: True)
  210. tm.assert_series_equal(result, expected)
  211. def test_transform_datetime_to_timedelta():
  212. # GH 15429
  213. # transforming a datetime to timedelta
  214. df = DataFrame({"A": Timestamp("20130101"), "B": np.arange(5)})
  215. expected = Series([Timestamp("20130101") - Timestamp("20130101")] * 5, name="A")
  216. # this does date math without changing result type in transform
  217. base_time = df["A"][0]
  218. result = (
  219. df.groupby("A")["A"].transform(lambda x: x.max() - x.min() + base_time)
  220. - base_time
  221. )
  222. tm.assert_series_equal(result, expected)
  223. # this does date math and causes the transform to return timedelta
  224. result = df.groupby("A")["A"].transform(lambda x: x.max() - x.min())
  225. tm.assert_series_equal(result, expected)
  226. def test_transform_datetime_to_numeric():
  227. # GH 10972
  228. # convert dt to float
  229. df = DataFrame({"a": 1, "b": date_range("2015-01-01", periods=2, freq="D")})
  230. result = df.groupby("a").b.transform(
  231. lambda x: x.dt.dayofweek - x.dt.dayofweek.mean()
  232. )
  233. expected = Series([-0.5, 0.5], name="b")
  234. tm.assert_series_equal(result, expected)
  235. # convert dt to int
  236. df = DataFrame({"a": 1, "b": date_range("2015-01-01", periods=2, freq="D")})
  237. result = df.groupby("a").b.transform(
  238. lambda x: x.dt.dayofweek - x.dt.dayofweek.min()
  239. )
  240. expected = Series([0, 1], dtype=np.int32, name="b")
  241. tm.assert_series_equal(result, expected)
def test_transform_casting():
    """Transform of a datetime column via diff keeps timedelta64 dtype.

    GH 13046.
    """
    # whitespace-separated fixture table; parsed with sep=r"\s+" below
    data = """
idx A ID3 DATETIME
0 B-028 b76cd912ff "2014-10-08 13:43:27"
1 B-054 4a57ed0b02 "2014-10-08 14:26:19"
2 B-076 1a682034f8 "2014-10-08 14:29:01"
3 B-023 b76cd912ff "2014-10-08 18:39:34"
4 B-023 f88g8d7sds "2014-10-08 18:40:18"
5 B-033 b76cd912ff "2014-10-08 18:44:30"
6 B-032 b76cd912ff "2014-10-08 18:46:00"
7 B-037 b76cd912ff "2014-10-08 18:52:15"
8 B-046 db959faf02 "2014-10-08 18:59:59"
9 B-053 b76cd912ff "2014-10-08 19:17:48"
10 B-065 b76cd912ff "2014-10-08 19:21:38"
"""
    df = pd.read_csv(
        StringIO(data), sep=r"\s+", index_col=[0], parse_dates=["DATETIME"]
    )
    # SeriesGroupBy path
    result = df.groupby("ID3")["DATETIME"].transform(lambda x: x.diff())
    assert is_timedelta64_dtype(result.dtype)
    # DataFrameGroupBy path
    result = df[["ID3", "DATETIME"]].groupby("ID3").transform(lambda x: x.diff())
    assert is_timedelta64_dtype(result.DATETIME.dtype)
  265. def test_transform_multiple(ts):
  266. grouped = ts.groupby([lambda x: x.year, lambda x: x.month])
  267. grouped.transform(lambda x: x * 2)
  268. grouped.transform(np.mean)
  269. def test_dispatch_transform(tsframe):
  270. df = tsframe[::5].reindex(tsframe.index)
  271. grouped = df.groupby(lambda x: x.month)
  272. filled = grouped.fillna(method="pad")
  273. fillit = lambda x: x.fillna(method="pad")
  274. expected = df.groupby(lambda x: x.month).transform(fillit)
  275. tm.assert_frame_equal(filled, expected)
def test_transform_transformation_func(transformation_func):
    # GH 30918
    # every transformation alias must equal applying the equivalent
    # operation group-by-group
    df = DataFrame(
        {
            "A": ["foo", "foo", "foo", "foo", "bar", "bar", "baz"],
            "B": [1, 2, np.nan, 3, 3, np.nan, 4],
        },
        index=date_range("2020-01-01", "2020-01-07"),
    )
    if transformation_func == "cumcount":
        # cumcount has no Series/DataFrame counterpart; emulate it
        test_op = lambda x: x.transform("cumcount")
        mock_op = lambda x: Series(range(len(x)), x.index)
    elif transformation_func == "fillna":
        test_op = lambda x: x.transform("fillna", value=0)
        mock_op = lambda x: x.fillna(value=0)
    elif transformation_func == "ngroup":
        test_op = lambda x: x.transform("ngroup")
        counter = -1

        # ngroup numbers groups in iteration order; emulate via a
        # closure that increments once per group
        def mock_op(x):
            nonlocal counter
            counter += 1
            return Series(counter, index=x.index)

    else:
        test_op = lambda x: x.transform(transformation_func)
        mock_op = lambda x: getattr(x, transformation_func)()
    result = test_op(df.groupby("A"))
    # pass the group in same order as iterating `for ... in df.groupby(...)`
    # but reorder to match df's index since this is a transform
    groups = [df[["B"]].iloc[4:6], df[["B"]].iloc[6:], df[["B"]].iloc[:4]]
    expected = concat([mock_op(g) for g in groups]).sort_index()
    # sort_index does not preserve the freq
    expected = expected.set_axis(df.index)
    if transformation_func in ("cumcount", "ngroup"):
        tm.assert_series_equal(result, expected)
    else:
        tm.assert_frame_equal(result, expected)
  312. def test_transform_select_columns(df):
  313. f = lambda x: x.mean()
  314. result = df.groupby("A")[["C", "D"]].transform(f)
  315. selection = df[["C", "D"]]
  316. expected = selection.groupby(df["A"]).transform(f)
  317. tm.assert_frame_equal(result, expected)
  318. def test_transform_nuisance_raises(df):
  319. # case that goes through _transform_item_by_item
  320. df.columns = ["A", "B", "B", "D"]
  321. # this also tests orderings in transform between
  322. # series/frame to make sure it's consistent
  323. grouped = df.groupby("A")
  324. gbc = grouped["B"]
  325. with pytest.raises(TypeError, match="Could not convert"):
  326. gbc.transform(lambda x: np.mean(x))
  327. with pytest.raises(TypeError, match="Could not convert"):
  328. df.groupby("A").transform(lambda x: np.mean(x))
  329. def test_transform_function_aliases(df):
  330. result = df.groupby("A").transform("mean", numeric_only=True)
  331. expected = df.groupby("A")[["C", "D"]].transform(np.mean)
  332. tm.assert_frame_equal(result, expected)
  333. result = df.groupby("A")["C"].transform("mean")
  334. expected = df.groupby("A")["C"].transform(np.mean)
  335. tm.assert_series_equal(result, expected)
  336. def test_series_fast_transform_date():
  337. # GH 13191
  338. df = DataFrame(
  339. {"grouping": [np.nan, 1, 1, 3], "d": date_range("2014-1-1", "2014-1-4")}
  340. )
  341. result = df.groupby("grouping")["d"].transform("first")
  342. dates = [
  343. pd.NaT,
  344. Timestamp("2014-1-2"),
  345. Timestamp("2014-1-2"),
  346. Timestamp("2014-1-4"),
  347. ]
  348. expected = Series(dates, name="d")
  349. tm.assert_series_equal(result, expected)
  350. def test_transform_length():
  351. # GH 9697
  352. df = DataFrame({"col1": [1, 1, 2, 2], "col2": [1, 2, 3, np.nan]})
  353. expected = Series([3.0] * 4)
  354. def nsum(x):
  355. return np.nansum(x)
  356. results = [
  357. df.groupby("col1").transform(sum)["col2"],
  358. df.groupby("col1")["col2"].transform(sum),
  359. df.groupby("col1").transform(nsum)["col2"],
  360. df.groupby("col1")["col2"].transform(nsum),
  361. ]
  362. for result in results:
  363. tm.assert_series_equal(result, expected, check_names=False)
  364. def test_transform_coercion():
  365. # 14457
  366. # when we are transforming be sure to not coerce
  367. # via assignment
  368. df = DataFrame({"A": ["a", "a", "b", "b"], "B": [0, 1, 3, 4]})
  369. g = df.groupby("A")
  370. expected = g.transform(np.mean)
  371. result = g.transform(lambda x: np.mean(x, axis=0))
  372. tm.assert_frame_equal(result, expected)
def test_groupby_transform_with_int():
    # GH 3740, make sure that we might upcast on item-by-item transform

    # floats
    df = DataFrame(
        {
            "A": [1, 1, 1, 2, 2, 2],
            "B": Series(1, dtype="float64"),
            "C": Series([1, 2, 3, 1, 2, 3], dtype="float64"),
            "D": "foo",
        }
    )
    with np.errstate(all="ignore"):
        result = df.groupby("A")[["B", "C"]].transform(
            lambda x: (x - x.mean()) / x.std()
        )
    # B is constant per group, so its zscore is NaN (std == 0)
    expected = DataFrame(
        {"B": np.nan, "C": Series([-1, 0, 1, -1, 0, 1], dtype="float64")}
    )
    tm.assert_frame_equal(result, expected)

    # int case
    df = DataFrame(
        {
            "A": [1, 1, 1, 2, 2, 2],
            "B": 1,
            "C": [1, 2, 3, 1, 2, 3],
            "D": "foo",
        }
    )
    with np.errstate(all="ignore"):
        # the object column "D" makes the whole-frame transform raise
        with pytest.raises(TypeError, match="Could not convert"):
            df.groupby("A").transform(lambda x: (x - x.mean()) / x.std())
        result = df.groupby("A")[["B", "C"]].transform(
            lambda x: (x - x.mean()) / x.std()
        )
    expected = DataFrame({"B": np.nan, "C": [-1.0, 0.0, 1.0, -1.0, 0.0, 1.0]})
    tm.assert_frame_equal(result, expected)

    # int that needs float conversion
    s = Series([2, 3, 4, 10, 5, -1])
    df = DataFrame({"A": [1, 1, 1, 2, 2, 2], "B": 1, "C": s, "D": "foo"})
    with np.errstate(all="ignore"):
        with pytest.raises(TypeError, match="Could not convert"):
            df.groupby("A").transform(lambda x: (x - x.mean()) / x.std())
        result = df.groupby("A")[["B", "C"]].transform(
            lambda x: (x - x.mean()) / x.std()
        )
    # expected: zscore computed per half of the series
    s1 = s.iloc[0:3]
    s1 = (s1 - s1.mean()) / s1.std()
    s2 = s.iloc[3:6]
    s2 = (s2 - s2.mean()) / s2.std()
    expected = DataFrame({"B": np.nan, "C": concat([s1, s2])})
    tm.assert_frame_equal(result, expected)

    # int doesn't get downcasted
    result = df.groupby("A")[["B", "C"]].transform(lambda x: x * 2 / 2)
    expected = DataFrame({"B": 1.0, "C": [2.0, 3.0, 4.0, 10.0, 5.0, -1.0]})
    tm.assert_frame_equal(result, expected)
  428. def test_groupby_transform_with_nan_group():
  429. # GH 9941
  430. df = DataFrame({"a": range(10), "b": [1, 1, 2, 3, np.nan, 4, 4, 5, 5, 5]})
  431. result = df.groupby(df.b)["a"].transform(max)
  432. expected = Series([1.0, 1.0, 2.0, 3.0, np.nan, 6.0, 6.0, 9.0, 9.0, 9.0], name="a")
  433. tm.assert_series_equal(result, expected)
def test_transform_mixed_type():
    """apply with a func that mutates its group and returns a slice."""
    index = MultiIndex.from_arrays([[0, 0, 0, 1, 1, 1], [1, 2, 3, 1, 2, 3]])
    df = DataFrame(
        {
            "d": [1.0, 1.0, 1.0, 2.0, 2.0, 2.0],
            "c": np.tile(["a", "b", "c"], 2),
            "v": np.arange(1.0, 7.0),
        },
        index=index,
    )

    def f(group):
        # NB: mutates the passed group (adds column "g") before slicing
        group["g"] = group["d"] * 2
        return group[:1]

    grouped = df.groupby("c")
    result = grouped.apply(f)
    assert result["d"].dtype == np.float64

    # this is by definition a mutating operation!
    # silence the chained-assignment warning while re-running f per group
    with pd.option_context("mode.chained_assignment", None):
        for key, group in grouped:
            res = f(group)
            tm.assert_frame_equal(res, result.loc[key])
  455. @pytest.mark.parametrize(
  456. "op, args, targop",
  457. [
  458. ("cumprod", (), lambda x: x.cumprod()),
  459. ("cumsum", (), lambda x: x.cumsum()),
  460. ("shift", (-1,), lambda x: x.shift(-1)),
  461. ("shift", (1,), lambda x: x.shift()),
  462. ],
  463. )
  464. def test_cython_transform_series(op, args, targop):
  465. # GH 4095
  466. s = Series(np.random.randn(1000))
  467. s_missing = s.copy()
  468. s_missing.iloc[2:10] = np.nan
  469. labels = np.random.randint(0, 50, size=1000).astype(float)
  470. # series
  471. for data in [s, s_missing]:
  472. # print(data.head())
  473. expected = data.groupby(labels).transform(targop)
  474. tm.assert_series_equal(expected, data.groupby(labels).transform(op, *args))
  475. tm.assert_series_equal(expected, getattr(data.groupby(labels), op)(*args))
@pytest.mark.parametrize("op", ["cumprod", "cumsum"])
@pytest.mark.parametrize("skipna", [False, True])
@pytest.mark.parametrize(
    "input, exp",
    [
        # When everything is NaN
        ({"key": ["b"] * 10, "value": np.nan}, Series([np.nan] * 10, name="value")),
        # When there is a single NaN
        (
            {"key": ["b"] * 10 + ["a"] * 2, "value": [3] * 3 + [np.nan] + [3] * 8},
            {
                # skipna=False: NaN poisons everything after it in the group
                ("cumprod", False): [3.0, 9.0, 27.0] + [np.nan] * 7 + [3.0, 9.0],
                # skipna=True: only the NaN position itself stays NaN
                ("cumprod", True): [
                    3.0,
                    9.0,
                    27.0,
                    np.nan,
                    81.0,
                    243.0,
                    729.0,
                    2187.0,
                    6561.0,
                    19683.0,
                    3.0,
                    9.0,
                ],
                ("cumsum", False): [3.0, 6.0, 9.0] + [np.nan] * 7 + [3.0, 6.0],
                ("cumsum", True): [
                    3.0,
                    6.0,
                    9.0,
                    np.nan,
                    12.0,
                    15.0,
                    18.0,
                    21.0,
                    24.0,
                    27.0,
                    3.0,
                    6.0,
                ],
            },
        ),
    ],
)
def test_groupby_cum_skipna(op, skipna, input, exp):
    """cumprod/cumsum skipna handling within groups."""
    df = DataFrame(input)
    result = df.groupby("key")["value"].transform(op, skipna=skipna)
    if isinstance(exp, dict):
        # the expectation depends on the (op, skipna) combination
        expected = exp[(op, skipna)]
    else:
        expected = exp
    expected = Series(expected, name="value")
    tm.assert_series_equal(expected, result)
@pytest.mark.slow
@pytest.mark.parametrize(
    "op, args, targop",
    [
        ("cumprod", (), lambda x: x.cumprod()),
        ("cumsum", (), lambda x: x.cumsum()),
        ("shift", (-1,), lambda x: x.shift(-1)),
        ("shift", (1,), lambda x: x.shift()),
    ],
)
def test_cython_transform_frame(op, args, targop):
    """Cython-backed frame transforms match the apply-based equivalent
    across dtypes, grouper kinds, and single-column selections."""
    s = Series(np.random.randn(1000))
    s_missing = s.copy()
    s_missing.iloc[2:10] = np.nan
    labels = np.random.randint(0, 50, size=1000).astype(float)
    strings = list("qwertyuiopasdfghjklz")
    strings_missing = strings[:]
    strings_missing[5] = np.nan
    df = DataFrame(
        {
            "float": s,
            "float_missing": s_missing,
            "int": [1, 1, 1, 1, 2] * 200,
            "datetime": date_range("1990-1-1", periods=1000),
            "timedelta": pd.timedelta_range(1, freq="s", periods=1000),
            "string": strings * 50,
            "string_missing": strings_missing * 50,
        },
        columns=[
            "float",
            "float_missing",
            "int",
            "datetime",
            "timedelta",
            "string",
            "string_missing",
        ],
    )
    df["cat"] = df["string"].astype("category")
    df2 = df.copy()
    df2.index = MultiIndex.from_product([range(100), range(10)])
    # DataFrame - Single and MultiIndex,
    # group by values, index level, columns
    for df in [df, df2]:
        for gb_target in [
            {"by": labels},
            {"level": 0},
            {"by": "string"},
        ]:  # {"by": 'string_missing'}]:
            # {"by": ['int','string']}]:
            # TODO: remove or enable commented-out code
            gb = df.groupby(group_keys=False, **gb_target)
            if op != "shift" and "int" not in gb_target:
                # numeric apply fastpath promotes dtype so have
                # to apply separately and concat
                i = gb[["int"]].apply(targop)
                f = gb[["float", "float_missing"]].apply(targop)
                expected = concat([f, i], axis=1)
            else:
                expected = gb.apply(targop)
            expected = expected.sort_index(axis=1)
            result = gb[expected.columns].transform(op, *args).sort_index(axis=1)
            tm.assert_frame_equal(result, expected)
            result = getattr(gb[expected.columns], op)(*args).sort_index(axis=1)
            tm.assert_frame_equal(result, expected)
            # individual columns
            for c in df:
                if (
                    c not in ["float", "int", "float_missing"]
                    and op != "shift"
                    and not (c == "timedelta" and op == "cumsum")
                ):
                    # non-numeric columns must raise for cumulative ops
                    msg = "|".join(
                        [
                            "does not support .* operations",
                            ".* is not supported for object dtype",
                            "is not implemented for this dtype",
                        ]
                    )
                    with pytest.raises(TypeError, match=msg):
                        gb[c].transform(op)
                    with pytest.raises(TypeError, match=msg):
                        getattr(gb[c], op)()
                else:
                    expected = gb[c].apply(targop)
                    expected.name = c
                    tm.assert_series_equal(expected, gb[c].transform(op, *args))
                    tm.assert_series_equal(expected, getattr(gb[c], op)(*args))
  618. def test_transform_with_non_scalar_group():
  619. # GH 10165
  620. cols = MultiIndex.from_tuples(
  621. [
  622. ("syn", "A"),
  623. ("mis", "A"),
  624. ("non", "A"),
  625. ("syn", "C"),
  626. ("mis", "C"),
  627. ("non", "C"),
  628. ("syn", "T"),
  629. ("mis", "T"),
  630. ("non", "T"),
  631. ("syn", "G"),
  632. ("mis", "G"),
  633. ("non", "G"),
  634. ]
  635. )
  636. df = DataFrame(
  637. np.random.randint(1, 10, (4, 12)), columns=cols, index=["A", "C", "G", "T"]
  638. )
  639. msg = "transform must return a scalar value for each group.*"
  640. with pytest.raises(ValueError, match=msg):
  641. df.groupby(axis=1, level=1).transform(lambda z: z.div(z.sum(axis=1), axis=0))
  642. @pytest.mark.parametrize(
  643. "cols,expected",
  644. [
  645. ("a", Series([1, 1, 1], name="a")),
  646. (
  647. ["a", "c"],
  648. DataFrame({"a": [1, 1, 1], "c": [1, 1, 1]}),
  649. ),
  650. ],
  651. )
  652. @pytest.mark.parametrize("agg_func", ["count", "rank", "size"])
  653. def test_transform_numeric_ret(cols, expected, agg_func):
  654. # GH#19200 and GH#27469
  655. df = DataFrame(
  656. {"a": date_range("2018-01-01", periods=3), "b": range(3), "c": range(7, 10)}
  657. )
  658. result = df.groupby("b")[cols].transform(agg_func)
  659. if agg_func == "rank":
  660. expected = expected.astype("float")
  661. elif agg_func == "size" and cols == ["a", "c"]:
  662. # transform("size") returns a Series
  663. expected = expected["a"].rename(None)
  664. tm.assert_equal(result, expected)
  665. def test_transform_ffill():
  666. # GH 24211
  667. data = [["a", 0.0], ["a", float("nan")], ["b", 1.0], ["b", float("nan")]]
  668. df = DataFrame(data, columns=["key", "values"])
  669. result = df.groupby("key").transform("ffill")
  670. expected = DataFrame({"values": [0.0, 0.0, 1.0, 1.0]})
  671. tm.assert_frame_equal(result, expected)
  672. result = df.groupby("key")["values"].transform("ffill")
  673. expected = Series([0.0, 0.0, 1.0, 1.0], name="values")
  674. tm.assert_series_equal(result, expected)
@pytest.mark.parametrize("mix_groupings", [True, False])
@pytest.mark.parametrize("as_series", [True, False])
@pytest.mark.parametrize("val1,val2", [("foo", "bar"), (1, 2), (1.0, 2.0)])
@pytest.mark.parametrize(
    "fill_method,limit,exp_vals",
    [
        (
            "ffill",
            None,
            [np.nan, np.nan, "val1", "val1", "val1", "val2", "val2", "val2"],
        ),
        ("ffill", 1, [np.nan, np.nan, "val1", "val1", np.nan, "val2", "val2", np.nan]),
        (
            "bfill",
            None,
            ["val1", "val1", "val1", "val2", "val2", "val2", np.nan, np.nan],
        ),
        ("bfill", 1, [np.nan, "val1", "val1", np.nan, "val2", "val2", np.nan, np.nan]),
    ],
)
def test_group_fill_methods(
    mix_groupings, as_series, val1, val2, fill_method, limit, exp_vals
):
    """Groupwise ffill/bfill with optional limit, for interleaved and
    contiguous groupings, Series and DataFrame paths."""
    vals = [np.nan, np.nan, val1, np.nan, np.nan, val2, np.nan, np.nan]
    _exp_vals = list(exp_vals)
    # Overwrite placeholder values ("val1"/"val2" markers in the
    # parametrized expectation) with the concrete values
    for index, exp_val in enumerate(_exp_vals):
        if exp_val == "val1":
            _exp_vals[index] = val1
        elif exp_val == "val2":
            _exp_vals[index] = val2
    # Need to modify values and expectations depending on the
    # Series / DataFrame that we ultimately want to generate
    if mix_groupings:  # ['a', 'b', 'a, 'b', ...]
        keys = ["a", "b"] * len(vals)

        def interweave(list_obj):
            # duplicate each element in place: [x, y] -> [x, x, y, y]
            temp = []
            for x in list_obj:
                temp.extend([x, x])
            return temp

        _exp_vals = interweave(_exp_vals)
        vals = interweave(vals)
    else:  # ['a', 'a', 'a', ... 'b', 'b', 'b']
        keys = ["a"] * len(vals) + ["b"] * len(vals)
        _exp_vals = _exp_vals * 2
        vals = vals * 2
    df = DataFrame({"key": keys, "val": vals})
    if as_series:
        result = getattr(df.groupby("key")["val"], fill_method)(limit=limit)
        exp = Series(_exp_vals, name="val")
        tm.assert_series_equal(result, exp)
    else:
        result = getattr(df.groupby("key"), fill_method)(limit=limit)
        exp = DataFrame({"val": _exp_vals})
        tm.assert_frame_equal(result, exp)
  730. @pytest.mark.parametrize("fill_method", ["ffill", "bfill"])
  731. def test_pad_stable_sorting(fill_method):
  732. # GH 21207
  733. x = [0] * 20
  734. y = [np.nan] * 10 + [1] * 10
  735. if fill_method == "bfill":
  736. y = y[::-1]
  737. df = DataFrame({"x": x, "y": y})
  738. expected = df.drop("x", axis=1)
  739. result = getattr(df.groupby("x"), fill_method)()
  740. tm.assert_frame_equal(result, expected)
@pytest.mark.parametrize(
    "freq",
    [
        None,
        pytest.param(
            "D",
            marks=pytest.mark.xfail(
                reason="GH#23918 before method uses freq in vectorized approach"
            ),
        ),
    ],
)
@pytest.mark.parametrize("periods", [1, -1])
@pytest.mark.parametrize("fill_method", ["ffill", "bfill", None])
@pytest.mark.parametrize("limit", [None, 1])
def test_pct_change(frame_or_series, freq, periods, fill_method, limit):
    # GH 21200, 21621, 30463
    # Grouped pct_change must equal "fill within group, then divide by the
    # within-group shift, minus 1".
    vals = [3, np.nan, np.nan, np.nan, 1, 2, 4, 10, np.nan, 4]
    keys = ["a", "b"]
    key_v = np.repeat(keys, len(vals))
    df = DataFrame({"key": key_v, "vals": vals * 2})
    df_g = df
    if fill_method is not None:
        # Pre-fill the frame the same way pct_change(fill_method=...) would
        df_g = getattr(df.groupby("key"), fill_method)(limit=limit)
    grp = df_g.groupby(df.key)
    # .obj recovers the (possibly filled) underlying Series
    expected = grp["vals"].obj / grp["vals"].shift(periods) - 1
    gb = df.groupby("key")
    if frame_or_series is Series:
        gb = gb["vals"]
    else:
        expected = expected.to_frame("vals")
    result = gb.pct_change(
        periods=periods, fill_method=fill_method, limit=limit, freq=freq
    )
    tm.assert_equal(result, expected)
  776. @pytest.mark.parametrize(
  777. "func, expected_status",
  778. [
  779. ("ffill", ["shrt", "shrt", "lng", np.nan, "shrt", "ntrl", "ntrl"]),
  780. ("bfill", ["shrt", "lng", "lng", "shrt", "shrt", "ntrl", np.nan]),
  781. ],
  782. )
  783. def test_ffill_bfill_non_unique_multilevel(func, expected_status):
  784. # GH 19437
  785. date = pd.to_datetime(
  786. [
  787. "2018-01-01",
  788. "2018-01-01",
  789. "2018-01-01",
  790. "2018-01-01",
  791. "2018-01-02",
  792. "2018-01-01",
  793. "2018-01-02",
  794. ]
  795. )
  796. symbol = ["MSFT", "MSFT", "MSFT", "AAPL", "AAPL", "TSLA", "TSLA"]
  797. status = ["shrt", np.nan, "lng", np.nan, "shrt", "ntrl", np.nan]
  798. df = DataFrame({"date": date, "symbol": symbol, "status": status})
  799. df = df.set_index(["date", "symbol"])
  800. result = getattr(df.groupby("symbol")["status"], func)()
  801. index = MultiIndex.from_tuples(
  802. tuples=list(zip(*[date, symbol])), names=["date", "symbol"]
  803. )
  804. expected = Series(expected_status, index=index, name="status")
  805. tm.assert_series_equal(result, expected)
  806. @pytest.mark.parametrize("func", [np.any, np.all])
  807. def test_any_all_np_func(func):
  808. # GH 20653
  809. df = DataFrame(
  810. [["foo", True], [np.nan, True], ["foo", True]], columns=["key", "val"]
  811. )
  812. exp = Series([True, np.nan, True], name="val")
  813. res = df.groupby("key")["val"].transform(func)
  814. tm.assert_series_equal(res, exp)
  815. def test_groupby_transform_rename():
  816. # https://github.com/pandas-dev/pandas/issues/23461
  817. def demean_rename(x):
  818. result = x - x.mean()
  819. if isinstance(x, Series):
  820. return result
  821. result = result.rename(columns={c: f"{c}_demeaned" for c in result.columns})
  822. return result
  823. df = DataFrame({"group": list("ababa"), "value": [1, 1, 1, 2, 2]})
  824. expected = DataFrame({"value": [-1.0 / 3, -0.5, -1.0 / 3, 0.5, 2.0 / 3]})
  825. result = df.groupby("group").transform(demean_rename)
  826. tm.assert_frame_equal(result, expected)
  827. result_single = df.groupby("group").value.transform(demean_rename)
  828. tm.assert_series_equal(result_single, expected["value"])
@pytest.mark.parametrize("func", [min, max, np.min, np.max, "first", "last"])
def test_groupby_transform_timezone_column(func):
    # GH 24198: transform over a tz-aware datetime column must preserve
    # the timezone for builtin/numpy reducers as well as string names.
    # NOTE(review): "now" makes the timestamp differ per run, but the test
    # only compares the result against the same `ts`, so it is stable.
    ts = pd.to_datetime("now", utc=True).tz_convert("Asia/Singapore")
    result = DataFrame({"end_time": [ts], "id": [1]})
    result["max_end_time"] = result.groupby("id").end_time.transform(func)
    expected = DataFrame([[ts, 1, ts]], columns=["end_time", "id", "max_end_time"])
    tm.assert_frame_equal(result, expected)
  837. @pytest.mark.parametrize(
  838. "func, values",
  839. [
  840. ("idxmin", ["1/1/2011"] * 2 + ["1/3/2011"] * 7 + ["1/10/2011"]),
  841. ("idxmax", ["1/2/2011"] * 2 + ["1/9/2011"] * 7 + ["1/10/2011"]),
  842. ],
  843. )
  844. def test_groupby_transform_with_datetimes(func, values):
  845. # GH 15306
  846. dates = date_range("1/1/2011", periods=10, freq="D")
  847. stocks = DataFrame({"price": np.arange(10.0)}, index=dates)
  848. stocks["week_id"] = dates.isocalendar().week
  849. result = stocks.groupby(stocks["week_id"])["price"].transform(func)
  850. expected = Series(data=pd.to_datetime(values), index=dates, name="price")
  851. tm.assert_series_equal(result, expected)
  852. def test_groupby_transform_dtype():
  853. # GH 22243
  854. df = DataFrame({"a": [1], "val": [1.35]})
  855. result = df["val"].transform(lambda x: x.map(lambda y: f"+{y}"))
  856. expected1 = Series(["+1.35"], name="val", dtype="object")
  857. tm.assert_series_equal(result, expected1)
  858. result = df.groupby("a")["val"].transform(lambda x: x.map(lambda y: f"+{y}"))
  859. tm.assert_series_equal(result, expected1)
  860. result = df.groupby("a")["val"].transform(lambda x: x.map(lambda y: f"+({y})"))
  861. expected2 = Series(["+(1.35)"], name="val", dtype="object")
  862. tm.assert_series_equal(result, expected2)
  863. df["val"] = df["val"].astype(object)
  864. result = df.groupby("a")["val"].transform(lambda x: x.map(lambda y: f"+{y}"))
  865. tm.assert_series_equal(result, expected1)
  866. @pytest.mark.parametrize("func", ["cumsum", "cumprod", "cummin", "cummax"])
  867. def test_transform_absent_categories(func):
  868. # GH 16771
  869. # cython transforms with more groups than rows
  870. x_vals = [1]
  871. x_cats = range(2)
  872. y = [1]
  873. df = DataFrame({"x": Categorical(x_vals, x_cats), "y": y})
  874. result = getattr(df.y.groupby(df.x), func)()
  875. expected = df.y
  876. tm.assert_series_equal(result, expected)
  877. @pytest.mark.parametrize("func", ["ffill", "bfill", "shift"])
  878. @pytest.mark.parametrize("key, val", [("level", 0), ("by", Series([0]))])
  879. def test_ffill_not_in_axis(func, key, val):
  880. # GH 21521
  881. df = DataFrame([[np.nan]])
  882. result = getattr(df.groupby(**{key: val}), func)()
  883. expected = df
  884. tm.assert_frame_equal(result, expected)
  885. def test_transform_invalid_name_raises():
  886. # GH#27486
  887. df = DataFrame({"a": [0, 1, 1, 2]})
  888. g = df.groupby(["a", "b", "b", "c"])
  889. with pytest.raises(ValueError, match="not a valid function name"):
  890. g.transform("some_arbitrary_name")
  891. # method exists on the object, but is not a valid transformation/agg
  892. assert hasattr(g, "aggregate") # make sure the method exists
  893. with pytest.raises(ValueError, match="not a valid function name"):
  894. g.transform("aggregate")
  895. # Test SeriesGroupBy
  896. g = df["a"].groupby(["a", "b", "b", "c"])
  897. with pytest.raises(ValueError, match="not a valid function name"):
  898. g.transform("some_arbitrary_name")
def test_transform_agg_by_name(request, reduction_func, frame_or_series):
    # Calling transform with the string name of any reduction must broadcast
    # the reduced value back over each group's rows and preserve the index.
    func = reduction_func
    obj = DataFrame(
        {"a": [0, 0, 0, 1, 1, 1], "b": range(6)},
        index=["A", "B", "C", "D", "E", "F"],
    )
    if frame_or_series is Series:
        obj = obj["a"]
    g = obj.groupby(np.repeat([0, 1], 3))
    if func == "corrwith" and isinstance(obj, Series):  # GH#32293
        request.node.add_marker(
            pytest.mark.xfail(reason="TODO: implement SeriesGroupBy.corrwith")
        )
    args = get_groupby_method_args(reduction_func, obj)
    result = g.transform(func, *args)
    # this is the *definition* of a transformation
    tm.assert_index_equal(result.index, obj.index)
    if func not in ("ngroup", "size") and obj.ndim == 2:
        # size/ngroup return a Series, unlike other transforms
        tm.assert_index_equal(result.columns, obj.columns)
    # verify that values were broadcasted across each group:
    # the last 3 rows are all one group, so the last column holds one value
    assert len(set(DataFrame(result).iloc[-3:, -1])) == 1
  921. def test_transform_lambda_with_datetimetz():
  922. # GH 27496
  923. df = DataFrame(
  924. {
  925. "time": [
  926. Timestamp("2010-07-15 03:14:45"),
  927. Timestamp("2010-11-19 18:47:06"),
  928. ],
  929. "timezone": ["Etc/GMT+4", "US/Eastern"],
  930. }
  931. )
  932. result = df.groupby(["timezone"])["time"].transform(
  933. lambda x: x.dt.tz_localize(x.name)
  934. )
  935. expected = Series(
  936. [
  937. Timestamp("2010-07-15 03:14:45", tz="Etc/GMT+4"),
  938. Timestamp("2010-11-19 18:47:06", tz="US/Eastern"),
  939. ],
  940. name="time",
  941. )
  942. tm.assert_series_equal(result, expected)
def test_transform_fastpath_raises():
    # GH#29631 case where fastpath defined in groupby.generic _choose_path
    # raises, but slow_path does not
    df = DataFrame({"A": [1, 1, 2, 2], "B": [1, -1, 1, 2]})
    gb = df.groupby("A")

    def func(grp):
        # we want a function such that func(frame) fails but func.apply(frame)
        # works
        if grp.ndim == 2:
            # Ensure that fast_path fails
            raise NotImplementedError("Don't cross the streams")
        return grp * 2

    # Check that the fastpath raises, see _transform_general
    # NOTE: pokes private groupby internals (_obj_with_exclusions,
    # grouper.get_iterator, _define_paths) to reach the first group directly.
    obj = gb._obj_with_exclusions
    gen = gb.grouper.get_iterator(obj, axis=gb.axis)
    fast_path, slow_path = gb._define_paths(func)
    _, group = next(gen)
    with pytest.raises(NotImplementedError, match="Don't cross the streams"):
        fast_path(group)

    # transform must fall back to the slow path and still succeed
    result = gb.transform(func)
    expected = DataFrame([2, -2, 2, 4], columns=["B"])
    tm.assert_frame_equal(result, expected)
  965. def test_transform_lambda_indexing():
  966. # GH 7883
  967. df = DataFrame(
  968. {
  969. "A": ["foo", "bar", "foo", "bar", "foo", "flux", "foo", "flux"],
  970. "B": ["one", "one", "two", "three", "two", "six", "five", "three"],
  971. "C": range(8),
  972. "D": range(8),
  973. "E": range(8),
  974. }
  975. )
  976. df = df.set_index(["A", "B"])
  977. df = df.sort_index()
  978. result = df.groupby(level="A").transform(lambda x: x.iloc[-1])
  979. expected = DataFrame(
  980. {
  981. "C": [3, 3, 7, 7, 4, 4, 4, 4],
  982. "D": [3, 3, 7, 7, 4, 4, 4, 4],
  983. "E": [3, 3, 7, 7, 4, 4, 4, 4],
  984. },
  985. index=MultiIndex.from_tuples(
  986. [
  987. ("bar", "one"),
  988. ("bar", "three"),
  989. ("flux", "six"),
  990. ("flux", "three"),
  991. ("foo", "five"),
  992. ("foo", "one"),
  993. ("foo", "two"),
  994. ("foo", "two"),
  995. ],
  996. names=["A", "B"],
  997. ),
  998. )
  999. tm.assert_frame_equal(result, expected)
  1000. def test_categorical_and_not_categorical_key(observed):
  1001. # Checks that groupby-transform, when grouping by both a categorical
  1002. # and a non-categorical key, doesn't try to expand the output to include
  1003. # non-observed categories but instead matches the input shape.
  1004. # GH 32494
  1005. df_with_categorical = DataFrame(
  1006. {
  1007. "A": Categorical(["a", "b", "a"], categories=["a", "b", "c"]),
  1008. "B": [1, 2, 3],
  1009. "C": ["a", "b", "a"],
  1010. }
  1011. )
  1012. df_without_categorical = DataFrame(
  1013. {"A": ["a", "b", "a"], "B": [1, 2, 3], "C": ["a", "b", "a"]}
  1014. )
  1015. # DataFrame case
  1016. result = df_with_categorical.groupby(["A", "C"], observed=observed).transform("sum")
  1017. expected = df_without_categorical.groupby(["A", "C"]).transform("sum")
  1018. tm.assert_frame_equal(result, expected)
  1019. expected_explicit = DataFrame({"B": [4, 2, 4]})
  1020. tm.assert_frame_equal(result, expected_explicit)
  1021. # Series case
  1022. result = df_with_categorical.groupby(["A", "C"], observed=observed)["B"].transform(
  1023. "sum"
  1024. )
  1025. expected = df_without_categorical.groupby(["A", "C"])["B"].transform("sum")
  1026. tm.assert_series_equal(result, expected)
  1027. expected_explicit = Series([4, 2, 4], name="B")
  1028. tm.assert_series_equal(result, expected_explicit)
  1029. def test_string_rank_grouping():
  1030. # GH 19354
  1031. df = DataFrame({"A": [1, 1, 2], "B": [1, 2, 3]})
  1032. result = df.groupby("A").transform("rank")
  1033. expected = DataFrame({"B": [1.0, 2.0, 1.0]})
  1034. tm.assert_frame_equal(result, expected)
  1035. def test_transform_cumcount():
  1036. # GH 27472
  1037. df = DataFrame({"a": [0, 0, 0, 1, 1, 1], "b": range(6)})
  1038. grp = df.groupby(np.repeat([0, 1], 3))
  1039. result = grp.cumcount()
  1040. expected = Series([0, 1, 2, 0, 1, 2])
  1041. tm.assert_series_equal(result, expected)
  1042. result = grp.transform("cumcount")
  1043. tm.assert_series_equal(result, expected)
@pytest.mark.parametrize("keys", [["A1"], ["A1", "A2"]])
def test_null_group_lambda_self(sort, dropna, keys):
    # GH 17093
    # Identity transform (lambda x: x) over groupers containing nulls: with
    # dropna=True, rows whose group key is null come back as NaN.
    size = 50
    nulls1 = np.random.choice([False, True], size)
    nulls2 = np.random.choice([False, True], size)
    # Whether a group contains a null value or not
    nulls_grouper = nulls1 if len(keys) == 1 else nulls1 | nulls2
    a1 = np.random.randint(0, 5, size=size).astype(float)
    a1[nulls1] = np.nan
    a2 = np.random.randint(0, 5, size=size).astype(float)
    a2[nulls2] = np.nan
    values = np.random.randint(0, 5, size=a1.shape)
    df = DataFrame({"A1": a1, "A2": a2, "B": values})
    expected_values = values
    if dropna and nulls_grouper.any():
        # Rows belonging to a dropped (null-keyed) group are NaN, which
        # forces the column to float
        expected_values = expected_values.astype(float)
        expected_values[nulls_grouper] = np.nan
    expected = DataFrame(expected_values, columns=["B"])
    gb = df.groupby(keys, dropna=dropna, sort=sort)
    result = gb[["B"]].transform(lambda x: x)
    tm.assert_frame_equal(result, expected)
def test_null_group_str_reducer(request, dropna, reduction_func):
    # GH 17093
    # Transform by reduction name when one group key is NaN: with
    # dropna=True the null-group rows must come back as NaN.
    if reduction_func == "corrwith":
        msg = "incorrectly raises"
        request.node.add_marker(pytest.mark.xfail(reason=msg))
    index = [1, 2, 3, 4]  # test transform preserves non-standard index
    df = DataFrame({"A": [1, 1, np.nan, np.nan], "B": [1, 2, 2, 3]}, index=index)
    gb = df.groupby("A", dropna=dropna)
    args = get_groupby_method_args(reduction_func, df)
    # Manually handle reducers that don't fit the generic pattern
    # Set expected with dropna=False, then replace if necessary
    if reduction_func == "first":
        expected = DataFrame({"B": [1, 1, 2, 2]}, index=index)
    elif reduction_func == "last":
        expected = DataFrame({"B": [2, 2, 3, 3]}, index=index)
    elif reduction_func == "nth":
        expected = DataFrame({"B": [1, 1, 2, 2]}, index=index)
    elif reduction_func == "size":
        expected = Series([2, 2, 2, 2], index=index)
    elif reduction_func == "corrwith":
        expected = DataFrame({"B": [1.0, 1.0, 1.0, 1.0]}, index=index)
    else:
        # Generic pattern: reduce each group (null group included) and
        # broadcast the scalar back over that group's rows
        expected_gb = df.groupby("A", dropna=False)
        buffer = []
        for idx, group in expected_gb:
            res = getattr(group["B"], reduction_func)()
            buffer.append(Series(res, index=group.index))
        expected = concat(buffer).to_frame("B")
    if dropna:
        # any/all reduce to bools, which become object dtype once NaN mixes in
        dtype = object if reduction_func in ("any", "all") else float
        expected = expected.astype(dtype)
        if expected.ndim == 2:
            expected.iloc[[2, 3], 0] = np.nan
        else:
            expected.iloc[[2, 3]] = np.nan
    result = gb.transform(reduction_func, *args)
    tm.assert_equal(result, expected)
def test_null_group_str_transformer(request, dropna, transformation_func):
    # GH 17093
    # Transform by transformer name when one group key is NaN; with
    # dropna=True the null-group row becomes NaN in the output.
    df = DataFrame({"A": [1, 1, np.nan], "B": [1, 2, 2]}, index=[1, 2, 3])
    args = get_groupby_method_args(transformation_func, df)
    gb = df.groupby("A", dropna=dropna)
    # Build the expectation group-by-group from the DataFrame methods
    buffer = []
    for k, (idx, group) in enumerate(gb):
        if transformation_func == "cumcount":
            # DataFrame has no cumcount method
            res = DataFrame({"B": range(len(group))}, index=group.index)
        elif transformation_func == "ngroup":
            res = DataFrame(len(group) * [k], index=group.index, columns=["B"])
        else:
            res = getattr(group[["B"]], transformation_func)(*args)
        buffer.append(res)
    if dropna:
        # The dropped null-group row reappears as an all-NaN row
        dtype = object if transformation_func in ("any", "all") else None
        buffer.append(DataFrame([[np.nan]], index=[3], dtype=dtype, columns=["B"]))
    expected = concat(buffer)
    if transformation_func in ("cumcount", "ngroup"):
        # ngroup/cumcount always returns a Series as it counts the groups, not values
        expected = expected["B"].rename(None)
    result = gb.transform(transformation_func, *args)
    tm.assert_equal(result, expected)
def test_null_group_str_reducer_series(request, dropna, reduction_func):
    # GH 17093
    # SeriesGroupBy flavor of the null-group reducer check.
    if reduction_func == "corrwith":
        msg = "corrwith not implemented for SeriesGroupBy"
        request.node.add_marker(pytest.mark.xfail(reason=msg))
    index = [1, 2, 3, 4]  # test transform preserves non-standard index
    ser = Series([1, 2, 2, 3], index=index)
    gb = ser.groupby([1, 1, np.nan, np.nan], dropna=dropna)
    args = get_groupby_method_args(reduction_func, ser)
    # Manually handle reducers that don't fit the generic pattern
    # Set expected with dropna=False, then replace if necessary
    if reduction_func == "first":
        expected = Series([1, 1, 2, 2], index=index)
    elif reduction_func == "last":
        expected = Series([2, 2, 3, 3], index=index)
    elif reduction_func == "nth":
        expected = Series([1, 1, 2, 2], index=index)
    elif reduction_func == "size":
        expected = Series([2, 2, 2, 2], index=index)
    elif reduction_func == "corrwith":
        # placeholder only — the corrwith case is xfailed above
        expected = Series([1, 1, 2, 2], index=index)
    else:
        # Generic pattern: reduce each group (null group included) and
        # broadcast the scalar back over that group's rows
        expected_gb = ser.groupby([1, 1, np.nan, np.nan], dropna=False)
        buffer = []
        for idx, group in expected_gb:
            res = getattr(group, reduction_func)()
            buffer.append(Series(res, index=group.index))
        expected = concat(buffer)
    if dropna:
        # any/all reduce to bools, which become object dtype once NaN mixes in
        dtype = object if reduction_func in ("any", "all") else float
        expected = expected.astype(dtype)
        expected.iloc[[2, 3]] = np.nan
    result = gb.transform(reduction_func, *args)
    tm.assert_series_equal(result, expected)
def test_null_group_str_transformer_series(dropna, transformation_func):
    # GH 17093
    # SeriesGroupBy flavor of the null-group transformer check.
    ser = Series([1, 2, 2], index=[1, 2, 3])
    args = get_groupby_method_args(transformation_func, ser)
    gb = ser.groupby([1, 1, np.nan], dropna=dropna)
    # Build the expectation group-by-group from the Series methods
    buffer = []
    for k, (idx, group) in enumerate(gb):
        if transformation_func == "cumcount":
            # Series has no cumcount method
            res = Series(range(len(group)), index=group.index)
        elif transformation_func == "ngroup":
            res = Series(k, index=group.index)
        else:
            res = getattr(group, transformation_func)(*args)
        buffer.append(res)
    if dropna:
        # The dropped null-group row reappears as NaN
        dtype = object if transformation_func in ("any", "all") else None
        buffer.append(Series([np.nan], index=[3], dtype=dtype))
    expected = concat(buffer)
    # No warnings should be emitted by the transform call
    with tm.assert_produces_warning(None):
        result = gb.transform(transformation_func, *args)
    tm.assert_equal(result, expected)
@pytest.mark.parametrize(
    "func, expected_values",
    [
        (Series.sort_values, [5, 4, 3, 2, 1]),
        (lambda x: x.head(1), [5.0, np.nan, 3, 2, np.nan]),
    ],
)
@pytest.mark.parametrize("keys", [["a1"], ["a1", "a2"]])
@pytest.mark.parametrize("keys_in_index", [True, False])
def test_transform_aligns(func, frame_or_series, expected_values, keys, keys_in_index):
    # GH#45648 - transform should align with the input's index
    # (sort_values reorders within the group; head(1) drops rows -> NaN holes)
    df = DataFrame({"a1": [1, 1, 3, 2, 2], "b": [5, 4, 3, 2, 1]})
    if "a2" in keys:
        # second key duplicates the first, so the grouping is unchanged
        df["a2"] = df["a1"]
    if keys_in_index:
        df = df.set_index(keys, append=True)
    gb = df.groupby(keys)
    if frame_or_series is Series:
        gb = gb["b"]
    result = gb.transform(func)
    expected = DataFrame({"b": expected_values}, index=df.index)
    if frame_or_series is Series:
        expected = expected["b"]
    tm.assert_equal(result, expected)
@pytest.mark.parametrize("keys", ["A", ["A", "B"]])
def test_as_index_no_change(keys, df, groupby_func):
    # GH#49834 - as_index should have no impact on DataFrameGroupBy.transform
    if keys == "A":
        # Column B is string dtype; will fail on some ops
        df = df.drop(columns="B")
    args = get_groupby_method_args(groupby_func, df)
    gb_as_index_true = df.groupby(keys, as_index=True)
    gb_as_index_false = df.groupby(keys, as_index=False)
    result = gb_as_index_true.transform(groupby_func, *args)
    expected = gb_as_index_false.transform(groupby_func, *args)
    tm.assert_equal(result, expected)