test_other.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664
  1. """
  2. test all other .agg behavior
  3. """
  4. import datetime as dt
  5. from functools import partial
  6. import numpy as np
  7. import pytest
  8. from pandas.errors import SpecificationError
  9. import pandas as pd
  10. from pandas import (
  11. DataFrame,
  12. Index,
  13. MultiIndex,
  14. PeriodIndex,
  15. Series,
  16. date_range,
  17. period_range,
  18. )
  19. import pandas._testing as tm
  20. from pandas.io.formats.printing import pprint_thing
  21. def test_agg_partial_failure_raises():
  22. # GH#43741
  23. df = DataFrame(
  24. {
  25. "data1": np.random.randn(5),
  26. "data2": np.random.randn(5),
  27. "key1": ["a", "a", "b", "b", "a"],
  28. "key2": ["one", "two", "one", "two", "one"],
  29. }
  30. )
  31. grouped = df.groupby("key1")
  32. def peak_to_peak(arr):
  33. return arr.max() - arr.min()
  34. with pytest.raises(TypeError, match="unsupported operand type"):
  35. grouped.agg([peak_to_peak])
  36. with pytest.raises(TypeError, match="unsupported operand type"):
  37. grouped.agg(peak_to_peak)
  38. def test_agg_datetimes_mixed():
  39. data = [[1, "2012-01-01", 1.0], [2, "2012-01-02", 2.0], [3, None, 3.0]]
  40. df1 = DataFrame(
  41. {
  42. "key": [x[0] for x in data],
  43. "date": [x[1] for x in data],
  44. "value": [x[2] for x in data],
  45. }
  46. )
  47. data = [
  48. [
  49. row[0],
  50. (dt.datetime.strptime(row[1], "%Y-%m-%d").date() if row[1] else None),
  51. row[2],
  52. ]
  53. for row in data
  54. ]
  55. df2 = DataFrame(
  56. {
  57. "key": [x[0] for x in data],
  58. "date": [x[1] for x in data],
  59. "value": [x[2] for x in data],
  60. }
  61. )
  62. df1["weights"] = df1["value"] / df1["value"].sum()
  63. gb1 = df1.groupby("date").aggregate(np.sum)
  64. df2["weights"] = df1["value"] / df1["value"].sum()
  65. gb2 = df2.groupby("date").aggregate(np.sum)
  66. assert len(gb1) == len(gb2)
  67. def test_agg_period_index():
  68. prng = period_range("2012-1-1", freq="M", periods=3)
  69. df = DataFrame(np.random.randn(3, 2), index=prng)
  70. rs = df.groupby(level=0).sum()
  71. assert isinstance(rs.index, PeriodIndex)
  72. # GH 3579
  73. index = period_range(start="1999-01", periods=5, freq="M")
  74. s1 = Series(np.random.rand(len(index)), index=index)
  75. s2 = Series(np.random.rand(len(index)), index=index)
  76. df = DataFrame.from_dict({"s1": s1, "s2": s2})
  77. grouped = df.groupby(df.index.month)
  78. list(grouped)
  79. def test_agg_dict_parameter_cast_result_dtypes():
  80. # GH 12821
  81. df = DataFrame(
  82. {
  83. "class": ["A", "A", "B", "B", "C", "C", "D", "D"],
  84. "time": date_range("1/1/2011", periods=8, freq="H"),
  85. }
  86. )
  87. df.loc[[0, 1, 2, 5], "time"] = None
  88. # test for `first` function
  89. exp = df.loc[[0, 3, 4, 6]].set_index("class")
  90. grouped = df.groupby("class")
  91. tm.assert_frame_equal(grouped.first(), exp)
  92. tm.assert_frame_equal(grouped.agg("first"), exp)
  93. tm.assert_frame_equal(grouped.agg({"time": "first"}), exp)
  94. tm.assert_series_equal(grouped.time.first(), exp["time"])
  95. tm.assert_series_equal(grouped.time.agg("first"), exp["time"])
  96. # test for `last` function
  97. exp = df.loc[[0, 3, 4, 7]].set_index("class")
  98. grouped = df.groupby("class")
  99. tm.assert_frame_equal(grouped.last(), exp)
  100. tm.assert_frame_equal(grouped.agg("last"), exp)
  101. tm.assert_frame_equal(grouped.agg({"time": "last"}), exp)
  102. tm.assert_series_equal(grouped.time.last(), exp["time"])
  103. tm.assert_series_equal(grouped.time.agg("last"), exp["time"])
  104. # count
  105. exp = Series([2, 2, 2, 2], index=Index(list("ABCD"), name="class"), name="time")
  106. tm.assert_series_equal(grouped.time.agg(len), exp)
  107. tm.assert_series_equal(grouped.time.size(), exp)
  108. exp = Series([0, 1, 1, 2], index=Index(list("ABCD"), name="class"), name="time")
  109. tm.assert_series_equal(grouped.time.count(), exp)
  110. def test_agg_cast_results_dtypes():
  111. # similar to GH12821
  112. # xref #11444
  113. u = [dt.datetime(2015, x + 1, 1) for x in range(12)]
  114. v = list("aaabbbbbbccd")
  115. df = DataFrame({"X": v, "Y": u})
  116. result = df.groupby("X")["Y"].agg(len)
  117. expected = df.groupby("X")["Y"].count()
  118. tm.assert_series_equal(result, expected)
  119. def test_aggregate_float64_no_int64():
  120. # see gh-11199
  121. df = DataFrame({"a": [1, 2, 3, 4, 5], "b": [1, 2, 2, 4, 5], "c": [1, 2, 3, 4, 5]})
  122. expected = DataFrame({"a": [1, 2.5, 4, 5]}, index=[1, 2, 4, 5])
  123. expected.index.name = "b"
  124. result = df.groupby("b")[["a"]].mean()
  125. tm.assert_frame_equal(result, expected)
  126. expected = DataFrame({"a": [1, 2.5, 4, 5], "c": [1, 2.5, 4, 5]}, index=[1, 2, 4, 5])
  127. expected.index.name = "b"
  128. result = df.groupby("b")[["a", "c"]].mean()
  129. tm.assert_frame_equal(result, expected)
  130. def test_aggregate_api_consistency():
  131. # GH 9052
  132. # make sure that the aggregates via dict
  133. # are consistent
  134. df = DataFrame(
  135. {
  136. "A": ["foo", "bar", "foo", "bar", "foo", "bar", "foo", "foo"],
  137. "B": ["one", "one", "two", "two", "two", "two", "one", "two"],
  138. "C": np.random.randn(8) + 1.0,
  139. "D": np.arange(8),
  140. }
  141. )
  142. grouped = df.groupby(["A", "B"])
  143. c_mean = grouped["C"].mean()
  144. c_sum = grouped["C"].sum()
  145. d_mean = grouped["D"].mean()
  146. d_sum = grouped["D"].sum()
  147. result = grouped["D"].agg(["sum", "mean"])
  148. expected = pd.concat([d_sum, d_mean], axis=1)
  149. expected.columns = ["sum", "mean"]
  150. tm.assert_frame_equal(result, expected, check_like=True)
  151. result = grouped.agg([np.sum, np.mean])
  152. expected = pd.concat([c_sum, c_mean, d_sum, d_mean], axis=1)
  153. expected.columns = MultiIndex.from_product([["C", "D"], ["sum", "mean"]])
  154. tm.assert_frame_equal(result, expected, check_like=True)
  155. result = grouped[["D", "C"]].agg([np.sum, np.mean])
  156. expected = pd.concat([d_sum, d_mean, c_sum, c_mean], axis=1)
  157. expected.columns = MultiIndex.from_product([["D", "C"], ["sum", "mean"]])
  158. tm.assert_frame_equal(result, expected, check_like=True)
  159. result = grouped.agg({"C": "mean", "D": "sum"})
  160. expected = pd.concat([d_sum, c_mean], axis=1)
  161. tm.assert_frame_equal(result, expected, check_like=True)
  162. result = grouped.agg({"C": ["mean", "sum"], "D": ["mean", "sum"]})
  163. expected = pd.concat([c_mean, c_sum, d_mean, d_sum], axis=1)
  164. expected.columns = MultiIndex.from_product([["C", "D"], ["mean", "sum"]])
  165. msg = r"Column\(s\) \['r', 'r2'\] do not exist"
  166. with pytest.raises(KeyError, match=msg):
  167. grouped[["D", "C"]].agg({"r": np.sum, "r2": np.mean})
  168. def test_agg_dict_renaming_deprecation():
  169. # 15931
  170. df = DataFrame({"A": [1, 1, 1, 2, 2], "B": range(5), "C": range(5)})
  171. msg = r"nested renamer is not supported"
  172. with pytest.raises(SpecificationError, match=msg):
  173. df.groupby("A").agg(
  174. {"B": {"foo": ["sum", "max"]}, "C": {"bar": ["count", "min"]}}
  175. )
  176. msg = r"Column\(s\) \['ma'\] do not exist"
  177. with pytest.raises(KeyError, match=msg):
  178. df.groupby("A")[["B", "C"]].agg({"ma": "max"})
  179. msg = r"nested renamer is not supported"
  180. with pytest.raises(SpecificationError, match=msg):
  181. df.groupby("A").B.agg({"foo": "count"})
  182. def test_agg_compat():
  183. # GH 12334
  184. df = DataFrame(
  185. {
  186. "A": ["foo", "bar", "foo", "bar", "foo", "bar", "foo", "foo"],
  187. "B": ["one", "one", "two", "two", "two", "two", "one", "two"],
  188. "C": np.random.randn(8) + 1.0,
  189. "D": np.arange(8),
  190. }
  191. )
  192. g = df.groupby(["A", "B"])
  193. msg = r"nested renamer is not supported"
  194. with pytest.raises(SpecificationError, match=msg):
  195. g["D"].agg({"C": ["sum", "std"]})
  196. with pytest.raises(SpecificationError, match=msg):
  197. g["D"].agg({"C": "sum", "D": "std"})
  198. def test_agg_nested_dicts():
  199. # API change for disallowing these types of nested dicts
  200. df = DataFrame(
  201. {
  202. "A": ["foo", "bar", "foo", "bar", "foo", "bar", "foo", "foo"],
  203. "B": ["one", "one", "two", "two", "two", "two", "one", "two"],
  204. "C": np.random.randn(8) + 1.0,
  205. "D": np.arange(8),
  206. }
  207. )
  208. g = df.groupby(["A", "B"])
  209. msg = r"nested renamer is not supported"
  210. with pytest.raises(SpecificationError, match=msg):
  211. g.aggregate({"r1": {"C": ["mean", "sum"]}, "r2": {"D": ["mean", "sum"]}})
  212. with pytest.raises(SpecificationError, match=msg):
  213. g.agg({"C": {"ra": ["mean", "std"]}, "D": {"rb": ["mean", "std"]}})
  214. # same name as the original column
  215. # GH9052
  216. with pytest.raises(SpecificationError, match=msg):
  217. g["D"].agg({"result1": np.sum, "result2": np.mean})
  218. with pytest.raises(SpecificationError, match=msg):
  219. g["D"].agg({"D": np.sum, "result2": np.mean})
  220. def test_agg_item_by_item_raise_typeerror():
  221. df = DataFrame(np.random.randint(10, size=(20, 10)))
  222. def raiseException(df):
  223. pprint_thing("----------------------------------------")
  224. pprint_thing(df.to_string())
  225. raise TypeError("test")
  226. with pytest.raises(TypeError, match="test"):
  227. df.groupby(0).agg(raiseException)
  228. def test_series_agg_multikey():
  229. ts = tm.makeTimeSeries()
  230. grouped = ts.groupby([lambda x: x.year, lambda x: x.month])
  231. result = grouped.agg(np.sum)
  232. expected = grouped.sum()
  233. tm.assert_series_equal(result, expected)
  234. def test_series_agg_multi_pure_python():
  235. data = DataFrame(
  236. {
  237. "A": [
  238. "foo",
  239. "foo",
  240. "foo",
  241. "foo",
  242. "bar",
  243. "bar",
  244. "bar",
  245. "bar",
  246. "foo",
  247. "foo",
  248. "foo",
  249. ],
  250. "B": [
  251. "one",
  252. "one",
  253. "one",
  254. "two",
  255. "one",
  256. "one",
  257. "one",
  258. "two",
  259. "two",
  260. "two",
  261. "one",
  262. ],
  263. "C": [
  264. "dull",
  265. "dull",
  266. "shiny",
  267. "dull",
  268. "dull",
  269. "shiny",
  270. "shiny",
  271. "dull",
  272. "shiny",
  273. "shiny",
  274. "shiny",
  275. ],
  276. "D": np.random.randn(11),
  277. "E": np.random.randn(11),
  278. "F": np.random.randn(11),
  279. }
  280. )
  281. def bad(x):
  282. assert len(x.values.base) > 0
  283. return "foo"
  284. result = data.groupby(["A", "B"]).agg(bad)
  285. expected = data.groupby(["A", "B"]).agg(lambda x: "foo")
  286. tm.assert_frame_equal(result, expected)
  287. def test_agg_consistency():
  288. # agg with ([]) and () not consistent
  289. # GH 6715
  290. def P1(a):
  291. return np.percentile(a.dropna(), q=1)
  292. df = DataFrame(
  293. {
  294. "col1": [1, 2, 3, 4],
  295. "col2": [10, 25, 26, 31],
  296. "date": [
  297. dt.date(2013, 2, 10),
  298. dt.date(2013, 2, 10),
  299. dt.date(2013, 2, 11),
  300. dt.date(2013, 2, 11),
  301. ],
  302. }
  303. )
  304. g = df.groupby("date")
  305. expected = g.agg([P1])
  306. expected.columns = expected.columns.levels[0]
  307. result = g.agg(P1)
  308. tm.assert_frame_equal(result, expected)
  309. def test_agg_callables():
  310. # GH 7929
  311. df = DataFrame({"foo": [1, 2], "bar": [3, 4]}).astype(np.int64)
  312. class fn_class:
  313. def __call__(self, x):
  314. return sum(x)
  315. equiv_callables = [
  316. sum,
  317. np.sum,
  318. lambda x: sum(x),
  319. lambda x: x.sum(),
  320. partial(sum),
  321. fn_class(),
  322. ]
  323. expected = df.groupby("foo").agg(sum)
  324. for ecall in equiv_callables:
  325. result = df.groupby("foo").agg(ecall)
  326. tm.assert_frame_equal(result, expected)
  327. def test_agg_over_numpy_arrays():
  328. # GH 3788
  329. df = DataFrame(
  330. [
  331. [1, np.array([10, 20, 30])],
  332. [1, np.array([40, 50, 60])],
  333. [2, np.array([20, 30, 40])],
  334. ],
  335. columns=["category", "arraydata"],
  336. )
  337. gb = df.groupby("category")
  338. expected_data = [[np.array([50, 70, 90])], [np.array([20, 30, 40])]]
  339. expected_index = Index([1, 2], name="category")
  340. expected_column = ["arraydata"]
  341. expected = DataFrame(expected_data, index=expected_index, columns=expected_column)
  342. alt = gb.sum(numeric_only=False)
  343. tm.assert_frame_equal(alt, expected)
  344. result = gb.agg("sum", numeric_only=False)
  345. tm.assert_frame_equal(result, expected)
  346. # FIXME: the original version of this test called `gb.agg(sum)`
  347. # and that raises TypeError if `numeric_only=False` is passed
  348. @pytest.mark.parametrize("as_period", [True, False])
  349. def test_agg_tzaware_non_datetime_result(as_period):
  350. # discussed in GH#29589, fixed in GH#29641, operating on tzaware values
  351. # with function that is not dtype-preserving
  352. dti = date_range("2012-01-01", periods=4, tz="UTC")
  353. if as_period:
  354. dti = dti.tz_localize(None).to_period("D")
  355. df = DataFrame({"a": [0, 0, 1, 1], "b": dti})
  356. gb = df.groupby("a")
  357. # Case that _does_ preserve the dtype
  358. result = gb["b"].agg(lambda x: x.iloc[0])
  359. expected = Series(dti[::2], name="b")
  360. expected.index.name = "a"
  361. tm.assert_series_equal(result, expected)
  362. # Cases that do _not_ preserve the dtype
  363. result = gb["b"].agg(lambda x: x.iloc[0].year)
  364. expected = Series([2012, 2012], name="b")
  365. expected.index.name = "a"
  366. tm.assert_series_equal(result, expected)
  367. result = gb["b"].agg(lambda x: x.iloc[-1] - x.iloc[0])
  368. expected = Series([pd.Timedelta(days=1), pd.Timedelta(days=1)], name="b")
  369. expected.index.name = "a"
  370. if as_period:
  371. expected = Series([pd.offsets.Day(1), pd.offsets.Day(1)], name="b")
  372. expected.index.name = "a"
  373. tm.assert_series_equal(result, expected)
  374. def test_agg_timezone_round_trip():
  375. # GH 15426
  376. ts = pd.Timestamp("2016-01-01 12:00:00", tz="US/Pacific")
  377. df = DataFrame({"a": 1, "b": [ts + dt.timedelta(minutes=nn) for nn in range(10)]})
  378. result1 = df.groupby("a")["b"].agg(np.min).iloc[0]
  379. result2 = df.groupby("a")["b"].agg(lambda x: np.min(x)).iloc[0]
  380. result3 = df.groupby("a")["b"].min().iloc[0]
  381. assert result1 == ts
  382. assert result2 == ts
  383. assert result3 == ts
  384. dates = [
  385. pd.Timestamp(f"2016-01-0{i:d} 12:00:00", tz="US/Pacific") for i in range(1, 5)
  386. ]
  387. df = DataFrame({"A": ["a", "b"] * 2, "B": dates})
  388. grouped = df.groupby("A")
  389. ts = df["B"].iloc[0]
  390. assert ts == grouped.nth(0)["B"].iloc[0]
  391. assert ts == grouped.head(1)["B"].iloc[0]
  392. assert ts == grouped.first()["B"].iloc[0]
  393. # GH#27110 applying iloc should return a DataFrame
  394. assert ts == grouped.apply(lambda x: x.iloc[0]).iloc[0, 1]
  395. ts = df["B"].iloc[2]
  396. assert ts == grouped.last()["B"].iloc[0]
  397. # GH#27110 applying iloc should return a DataFrame
  398. assert ts == grouped.apply(lambda x: x.iloc[-1]).iloc[0, 1]
  399. def test_sum_uint64_overflow():
  400. # see gh-14758
  401. # Convert to uint64 and don't overflow
  402. df = DataFrame([[1, 2], [3, 4], [5, 6]], dtype=object)
  403. df = df + 9223372036854775807
  404. index = Index(
  405. [9223372036854775808, 9223372036854775810, 9223372036854775812], dtype=np.uint64
  406. )
  407. expected = DataFrame(
  408. {1: [9223372036854775809, 9223372036854775811, 9223372036854775813]},
  409. index=index,
  410. dtype=object,
  411. )
  412. expected.index.name = 0
  413. result = df.groupby(0).sum(numeric_only=False)
  414. tm.assert_frame_equal(result, expected)
  415. # out column is non-numeric, so with numeric_only=True it is dropped
  416. result2 = df.groupby(0).sum(numeric_only=True)
  417. expected2 = expected[[]]
  418. tm.assert_frame_equal(result2, expected2)
  419. @pytest.mark.parametrize(
  420. "structure, expected",
  421. [
  422. (tuple, DataFrame({"C": {(1, 1): (1, 1, 1), (3, 4): (3, 4, 4)}})),
  423. (list, DataFrame({"C": {(1, 1): [1, 1, 1], (3, 4): [3, 4, 4]}})),
  424. (
  425. lambda x: tuple(x),
  426. DataFrame({"C": {(1, 1): (1, 1, 1), (3, 4): (3, 4, 4)}}),
  427. ),
  428. (
  429. lambda x: list(x),
  430. DataFrame({"C": {(1, 1): [1, 1, 1], (3, 4): [3, 4, 4]}}),
  431. ),
  432. ],
  433. )
  434. def test_agg_structs_dataframe(structure, expected):
  435. df = DataFrame(
  436. {"A": [1, 1, 1, 3, 3, 3], "B": [1, 1, 1, 4, 4, 4], "C": [1, 1, 1, 3, 4, 4]}
  437. )
  438. result = df.groupby(["A", "B"]).aggregate(structure)
  439. expected.index.names = ["A", "B"]
  440. tm.assert_frame_equal(result, expected)
  441. @pytest.mark.parametrize(
  442. "structure, expected",
  443. [
  444. (tuple, Series([(1, 1, 1), (3, 4, 4)], index=[1, 3], name="C")),
  445. (list, Series([[1, 1, 1], [3, 4, 4]], index=[1, 3], name="C")),
  446. (lambda x: tuple(x), Series([(1, 1, 1), (3, 4, 4)], index=[1, 3], name="C")),
  447. (lambda x: list(x), Series([[1, 1, 1], [3, 4, 4]], index=[1, 3], name="C")),
  448. ],
  449. )
  450. def test_agg_structs_series(structure, expected):
  451. # Issue #18079
  452. df = DataFrame(
  453. {"A": [1, 1, 1, 3, 3, 3], "B": [1, 1, 1, 4, 4, 4], "C": [1, 1, 1, 3, 4, 4]}
  454. )
  455. result = df.groupby("A")["C"].aggregate(structure)
  456. expected.index.name = "A"
  457. tm.assert_series_equal(result, expected)
  458. def test_agg_category_nansum(observed):
  459. categories = ["a", "b", "c"]
  460. df = DataFrame(
  461. {"A": pd.Categorical(["a", "a", "b"], categories=categories), "B": [1, 2, 3]}
  462. )
  463. result = df.groupby("A", observed=observed).B.agg(np.nansum)
  464. expected = Series(
  465. [3, 3, 0],
  466. index=pd.CategoricalIndex(["a", "b", "c"], categories=categories, name="A"),
  467. name="B",
  468. )
  469. if observed:
  470. expected = expected[expected != 0]
  471. tm.assert_series_equal(result, expected)
  472. def test_agg_list_like_func():
  473. # GH 18473
  474. df = DataFrame({"A": [str(x) for x in range(3)], "B": [str(x) for x in range(3)]})
  475. grouped = df.groupby("A", as_index=False, sort=False)
  476. result = grouped.agg({"B": lambda x: list(x)})
  477. expected = DataFrame(
  478. {"A": [str(x) for x in range(3)], "B": [[str(x)] for x in range(3)]}
  479. )
  480. tm.assert_frame_equal(result, expected)
  481. def test_agg_lambda_with_timezone():
  482. # GH 23683
  483. df = DataFrame(
  484. {
  485. "tag": [1, 1],
  486. "date": [
  487. pd.Timestamp("2018-01-01", tz="UTC"),
  488. pd.Timestamp("2018-01-02", tz="UTC"),
  489. ],
  490. }
  491. )
  492. result = df.groupby("tag").agg({"date": lambda e: e.head(1)})
  493. expected = DataFrame(
  494. [pd.Timestamp("2018-01-01", tz="UTC")],
  495. index=Index([1], name="tag"),
  496. columns=["date"],
  497. )
  498. tm.assert_frame_equal(result, expected)
  499. @pytest.mark.parametrize(
  500. "err_cls",
  501. [
  502. NotImplementedError,
  503. RuntimeError,
  504. KeyError,
  505. IndexError,
  506. OSError,
  507. ValueError,
  508. ArithmeticError,
  509. AttributeError,
  510. ],
  511. )
  512. def test_groupby_agg_err_catching(err_cls):
  513. # make sure we suppress anything other than TypeError or AssertionError
  514. # in _python_agg_general
  515. # Use a non-standard EA to make sure we don't go down ndarray paths
  516. from pandas.tests.extension.decimal.array import (
  517. DecimalArray,
  518. make_data,
  519. to_decimal,
  520. )
  521. data = make_data()[:5]
  522. df = DataFrame(
  523. {"id1": [0, 0, 0, 1, 1], "id2": [0, 1, 0, 1, 1], "decimals": DecimalArray(data)}
  524. )
  525. expected = Series(to_decimal([data[0], data[3]]))
  526. def weird_func(x):
  527. # weird function that raise something other than TypeError or IndexError
  528. # in _python_agg_general
  529. if len(x) == 0:
  530. raise err_cls
  531. return x.iloc[0]
  532. result = df["decimals"].groupby(df["id1"]).agg(weird_func)
  533. tm.assert_series_equal(result, expected, check_names=False)