test_rank.py 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698
  1. from datetime import datetime
  2. import numpy as np
  3. import pytest
  4. import pandas as pd
  5. from pandas import (
  6. DataFrame,
  7. NaT,
  8. Series,
  9. concat,
  10. )
  11. import pandas._testing as tm
  12. def test_rank_unordered_categorical_typeerror():
  13. # GH#51034 should be TypeError, not NotImplementedError
  14. cat = pd.Categorical([], ordered=False)
  15. ser = Series(cat)
  16. df = ser.to_frame()
  17. msg = "Cannot perform rank with non-ordered Categorical"
  18. gb = ser.groupby(cat)
  19. with pytest.raises(TypeError, match=msg):
  20. gb.rank()
  21. gb2 = df.groupby(cat)
  22. with pytest.raises(TypeError, match=msg):
  23. gb2.rank()
  24. def test_rank_apply():
  25. lev1 = tm.rands_array(10, 100)
  26. lev2 = tm.rands_array(10, 130)
  27. lab1 = np.random.randint(0, 100, size=500)
  28. lab2 = np.random.randint(0, 130, size=500)
  29. df = DataFrame(
  30. {
  31. "value": np.random.randn(500),
  32. "key1": lev1.take(lab1),
  33. "key2": lev2.take(lab2),
  34. }
  35. )
  36. result = df.groupby(["key1", "key2"]).value.rank()
  37. expected = [piece.value.rank() for key, piece in df.groupby(["key1", "key2"])]
  38. expected = concat(expected, axis=0)
  39. expected = expected.reindex(result.index)
  40. tm.assert_series_equal(result, expected)
  41. result = df.groupby(["key1", "key2"]).value.rank(pct=True)
  42. expected = [
  43. piece.value.rank(pct=True) for key, piece in df.groupby(["key1", "key2"])
  44. ]
  45. expected = concat(expected, axis=0)
  46. expected = expected.reindex(result.index)
  47. tm.assert_series_equal(result, expected)
  48. @pytest.mark.parametrize("grps", [["qux"], ["qux", "quux"]])
  49. @pytest.mark.parametrize(
  50. "vals",
  51. [
  52. np.array([2, 2, 8, 2, 6], dtype=dtype)
  53. for dtype in ["i8", "i4", "i2", "i1", "u8", "u4", "u2", "u1", "f8", "f4", "f2"]
  54. ]
  55. + [
  56. [
  57. pd.Timestamp("2018-01-02"),
  58. pd.Timestamp("2018-01-02"),
  59. pd.Timestamp("2018-01-08"),
  60. pd.Timestamp("2018-01-02"),
  61. pd.Timestamp("2018-01-06"),
  62. ],
  63. [
  64. pd.Timestamp("2018-01-02", tz="US/Pacific"),
  65. pd.Timestamp("2018-01-02", tz="US/Pacific"),
  66. pd.Timestamp("2018-01-08", tz="US/Pacific"),
  67. pd.Timestamp("2018-01-02", tz="US/Pacific"),
  68. pd.Timestamp("2018-01-06", tz="US/Pacific"),
  69. ],
  70. [
  71. pd.Timestamp("2018-01-02") - pd.Timestamp(0),
  72. pd.Timestamp("2018-01-02") - pd.Timestamp(0),
  73. pd.Timestamp("2018-01-08") - pd.Timestamp(0),
  74. pd.Timestamp("2018-01-02") - pd.Timestamp(0),
  75. pd.Timestamp("2018-01-06") - pd.Timestamp(0),
  76. ],
  77. [
  78. pd.Timestamp("2018-01-02").to_period("D"),
  79. pd.Timestamp("2018-01-02").to_period("D"),
  80. pd.Timestamp("2018-01-08").to_period("D"),
  81. pd.Timestamp("2018-01-02").to_period("D"),
  82. pd.Timestamp("2018-01-06").to_period("D"),
  83. ],
  84. ],
  85. ids=lambda x: type(x[0]),
  86. )
  87. @pytest.mark.parametrize(
  88. "ties_method,ascending,pct,exp",
  89. [
  90. ("average", True, False, [2.0, 2.0, 5.0, 2.0, 4.0]),
  91. ("average", True, True, [0.4, 0.4, 1.0, 0.4, 0.8]),
  92. ("average", False, False, [4.0, 4.0, 1.0, 4.0, 2.0]),
  93. ("average", False, True, [0.8, 0.8, 0.2, 0.8, 0.4]),
  94. ("min", True, False, [1.0, 1.0, 5.0, 1.0, 4.0]),
  95. ("min", True, True, [0.2, 0.2, 1.0, 0.2, 0.8]),
  96. ("min", False, False, [3.0, 3.0, 1.0, 3.0, 2.0]),
  97. ("min", False, True, [0.6, 0.6, 0.2, 0.6, 0.4]),
  98. ("max", True, False, [3.0, 3.0, 5.0, 3.0, 4.0]),
  99. ("max", True, True, [0.6, 0.6, 1.0, 0.6, 0.8]),
  100. ("max", False, False, [5.0, 5.0, 1.0, 5.0, 2.0]),
  101. ("max", False, True, [1.0, 1.0, 0.2, 1.0, 0.4]),
  102. ("first", True, False, [1.0, 2.0, 5.0, 3.0, 4.0]),
  103. ("first", True, True, [0.2, 0.4, 1.0, 0.6, 0.8]),
  104. ("first", False, False, [3.0, 4.0, 1.0, 5.0, 2.0]),
  105. ("first", False, True, [0.6, 0.8, 0.2, 1.0, 0.4]),
  106. ("dense", True, False, [1.0, 1.0, 3.0, 1.0, 2.0]),
  107. ("dense", True, True, [1.0 / 3.0, 1.0 / 3.0, 3.0 / 3.0, 1.0 / 3.0, 2.0 / 3.0]),
  108. ("dense", False, False, [3.0, 3.0, 1.0, 3.0, 2.0]),
  109. ("dense", False, True, [3.0 / 3.0, 3.0 / 3.0, 1.0 / 3.0, 3.0 / 3.0, 2.0 / 3.0]),
  110. ],
  111. )
  112. def test_rank_args(grps, vals, ties_method, ascending, pct, exp):
  113. key = np.repeat(grps, len(vals))
  114. orig_vals = vals
  115. vals = list(vals) * len(grps)
  116. if isinstance(orig_vals, np.ndarray):
  117. vals = np.array(vals, dtype=orig_vals.dtype)
  118. df = DataFrame({"key": key, "val": vals})
  119. result = df.groupby("key").rank(method=ties_method, ascending=ascending, pct=pct)
  120. exp_df = DataFrame(exp * len(grps), columns=["val"])
  121. tm.assert_frame_equal(result, exp_df)
  122. @pytest.mark.parametrize("grps", [["qux"], ["qux", "quux"]])
  123. @pytest.mark.parametrize(
  124. "vals", [[-np.inf, -np.inf, np.nan, 1.0, np.nan, np.inf, np.inf]]
  125. )
  126. @pytest.mark.parametrize(
  127. "ties_method,ascending,na_option,exp",
  128. [
  129. ("average", True, "keep", [1.5, 1.5, np.nan, 3, np.nan, 4.5, 4.5]),
  130. ("average", True, "top", [3.5, 3.5, 1.5, 5.0, 1.5, 6.5, 6.5]),
  131. ("average", True, "bottom", [1.5, 1.5, 6.5, 3.0, 6.5, 4.5, 4.5]),
  132. ("average", False, "keep", [4.5, 4.5, np.nan, 3, np.nan, 1.5, 1.5]),
  133. ("average", False, "top", [6.5, 6.5, 1.5, 5.0, 1.5, 3.5, 3.5]),
  134. ("average", False, "bottom", [4.5, 4.5, 6.5, 3.0, 6.5, 1.5, 1.5]),
  135. ("min", True, "keep", [1.0, 1.0, np.nan, 3.0, np.nan, 4.0, 4.0]),
  136. ("min", True, "top", [3.0, 3.0, 1.0, 5.0, 1.0, 6.0, 6.0]),
  137. ("min", True, "bottom", [1.0, 1.0, 6.0, 3.0, 6.0, 4.0, 4.0]),
  138. ("min", False, "keep", [4.0, 4.0, np.nan, 3.0, np.nan, 1.0, 1.0]),
  139. ("min", False, "top", [6.0, 6.0, 1.0, 5.0, 1.0, 3.0, 3.0]),
  140. ("min", False, "bottom", [4.0, 4.0, 6.0, 3.0, 6.0, 1.0, 1.0]),
  141. ("max", True, "keep", [2.0, 2.0, np.nan, 3.0, np.nan, 5.0, 5.0]),
  142. ("max", True, "top", [4.0, 4.0, 2.0, 5.0, 2.0, 7.0, 7.0]),
  143. ("max", True, "bottom", [2.0, 2.0, 7.0, 3.0, 7.0, 5.0, 5.0]),
  144. ("max", False, "keep", [5.0, 5.0, np.nan, 3.0, np.nan, 2.0, 2.0]),
  145. ("max", False, "top", [7.0, 7.0, 2.0, 5.0, 2.0, 4.0, 4.0]),
  146. ("max", False, "bottom", [5.0, 5.0, 7.0, 3.0, 7.0, 2.0, 2.0]),
  147. ("first", True, "keep", [1.0, 2.0, np.nan, 3.0, np.nan, 4.0, 5.0]),
  148. ("first", True, "top", [3.0, 4.0, 1.0, 5.0, 2.0, 6.0, 7.0]),
  149. ("first", True, "bottom", [1.0, 2.0, 6.0, 3.0, 7.0, 4.0, 5.0]),
  150. ("first", False, "keep", [4.0, 5.0, np.nan, 3.0, np.nan, 1.0, 2.0]),
  151. ("first", False, "top", [6.0, 7.0, 1.0, 5.0, 2.0, 3.0, 4.0]),
  152. ("first", False, "bottom", [4.0, 5.0, 6.0, 3.0, 7.0, 1.0, 2.0]),
  153. ("dense", True, "keep", [1.0, 1.0, np.nan, 2.0, np.nan, 3.0, 3.0]),
  154. ("dense", True, "top", [2.0, 2.0, 1.0, 3.0, 1.0, 4.0, 4.0]),
  155. ("dense", True, "bottom", [1.0, 1.0, 4.0, 2.0, 4.0, 3.0, 3.0]),
  156. ("dense", False, "keep", [3.0, 3.0, np.nan, 2.0, np.nan, 1.0, 1.0]),
  157. ("dense", False, "top", [4.0, 4.0, 1.0, 3.0, 1.0, 2.0, 2.0]),
  158. ("dense", False, "bottom", [3.0, 3.0, 4.0, 2.0, 4.0, 1.0, 1.0]),
  159. ],
  160. )
  161. def test_infs_n_nans(grps, vals, ties_method, ascending, na_option, exp):
  162. # GH 20561
  163. key = np.repeat(grps, len(vals))
  164. vals = vals * len(grps)
  165. df = DataFrame({"key": key, "val": vals})
  166. result = df.groupby("key").rank(
  167. method=ties_method, ascending=ascending, na_option=na_option
  168. )
  169. exp_df = DataFrame(exp * len(grps), columns=["val"])
  170. tm.assert_frame_equal(result, exp_df)
  171. @pytest.mark.parametrize("grps", [["qux"], ["qux", "quux"]])
  172. @pytest.mark.parametrize(
  173. "vals",
  174. [
  175. np.array([2, 2, np.nan, 8, 2, 6, np.nan, np.nan], dtype=dtype)
  176. for dtype in ["f8", "f4", "f2"]
  177. ]
  178. + [
  179. [
  180. pd.Timestamp("2018-01-02"),
  181. pd.Timestamp("2018-01-02"),
  182. np.nan,
  183. pd.Timestamp("2018-01-08"),
  184. pd.Timestamp("2018-01-02"),
  185. pd.Timestamp("2018-01-06"),
  186. np.nan,
  187. np.nan,
  188. ],
  189. [
  190. pd.Timestamp("2018-01-02", tz="US/Pacific"),
  191. pd.Timestamp("2018-01-02", tz="US/Pacific"),
  192. np.nan,
  193. pd.Timestamp("2018-01-08", tz="US/Pacific"),
  194. pd.Timestamp("2018-01-02", tz="US/Pacific"),
  195. pd.Timestamp("2018-01-06", tz="US/Pacific"),
  196. np.nan,
  197. np.nan,
  198. ],
  199. [
  200. pd.Timestamp("2018-01-02") - pd.Timestamp(0),
  201. pd.Timestamp("2018-01-02") - pd.Timestamp(0),
  202. np.nan,
  203. pd.Timestamp("2018-01-08") - pd.Timestamp(0),
  204. pd.Timestamp("2018-01-02") - pd.Timestamp(0),
  205. pd.Timestamp("2018-01-06") - pd.Timestamp(0),
  206. np.nan,
  207. np.nan,
  208. ],
  209. [
  210. pd.Timestamp("2018-01-02").to_period("D"),
  211. pd.Timestamp("2018-01-02").to_period("D"),
  212. np.nan,
  213. pd.Timestamp("2018-01-08").to_period("D"),
  214. pd.Timestamp("2018-01-02").to_period("D"),
  215. pd.Timestamp("2018-01-06").to_period("D"),
  216. np.nan,
  217. np.nan,
  218. ],
  219. ],
  220. ids=lambda x: type(x[0]),
  221. )
  222. @pytest.mark.parametrize(
  223. "ties_method,ascending,na_option,pct,exp",
  224. [
  225. (
  226. "average",
  227. True,
  228. "keep",
  229. False,
  230. [2.0, 2.0, np.nan, 5.0, 2.0, 4.0, np.nan, np.nan],
  231. ),
  232. (
  233. "average",
  234. True,
  235. "keep",
  236. True,
  237. [0.4, 0.4, np.nan, 1.0, 0.4, 0.8, np.nan, np.nan],
  238. ),
  239. (
  240. "average",
  241. False,
  242. "keep",
  243. False,
  244. [4.0, 4.0, np.nan, 1.0, 4.0, 2.0, np.nan, np.nan],
  245. ),
  246. (
  247. "average",
  248. False,
  249. "keep",
  250. True,
  251. [0.8, 0.8, np.nan, 0.2, 0.8, 0.4, np.nan, np.nan],
  252. ),
  253. ("min", True, "keep", False, [1.0, 1.0, np.nan, 5.0, 1.0, 4.0, np.nan, np.nan]),
  254. ("min", True, "keep", True, [0.2, 0.2, np.nan, 1.0, 0.2, 0.8, np.nan, np.nan]),
  255. (
  256. "min",
  257. False,
  258. "keep",
  259. False,
  260. [3.0, 3.0, np.nan, 1.0, 3.0, 2.0, np.nan, np.nan],
  261. ),
  262. ("min", False, "keep", True, [0.6, 0.6, np.nan, 0.2, 0.6, 0.4, np.nan, np.nan]),
  263. ("max", True, "keep", False, [3.0, 3.0, np.nan, 5.0, 3.0, 4.0, np.nan, np.nan]),
  264. ("max", True, "keep", True, [0.6, 0.6, np.nan, 1.0, 0.6, 0.8, np.nan, np.nan]),
  265. (
  266. "max",
  267. False,
  268. "keep",
  269. False,
  270. [5.0, 5.0, np.nan, 1.0, 5.0, 2.0, np.nan, np.nan],
  271. ),
  272. ("max", False, "keep", True, [1.0, 1.0, np.nan, 0.2, 1.0, 0.4, np.nan, np.nan]),
  273. (
  274. "first",
  275. True,
  276. "keep",
  277. False,
  278. [1.0, 2.0, np.nan, 5.0, 3.0, 4.0, np.nan, np.nan],
  279. ),
  280. (
  281. "first",
  282. True,
  283. "keep",
  284. True,
  285. [0.2, 0.4, np.nan, 1.0, 0.6, 0.8, np.nan, np.nan],
  286. ),
  287. (
  288. "first",
  289. False,
  290. "keep",
  291. False,
  292. [3.0, 4.0, np.nan, 1.0, 5.0, 2.0, np.nan, np.nan],
  293. ),
  294. (
  295. "first",
  296. False,
  297. "keep",
  298. True,
  299. [0.6, 0.8, np.nan, 0.2, 1.0, 0.4, np.nan, np.nan],
  300. ),
  301. (
  302. "dense",
  303. True,
  304. "keep",
  305. False,
  306. [1.0, 1.0, np.nan, 3.0, 1.0, 2.0, np.nan, np.nan],
  307. ),
  308. (
  309. "dense",
  310. True,
  311. "keep",
  312. True,
  313. [
  314. 1.0 / 3.0,
  315. 1.0 / 3.0,
  316. np.nan,
  317. 3.0 / 3.0,
  318. 1.0 / 3.0,
  319. 2.0 / 3.0,
  320. np.nan,
  321. np.nan,
  322. ],
  323. ),
  324. (
  325. "dense",
  326. False,
  327. "keep",
  328. False,
  329. [3.0, 3.0, np.nan, 1.0, 3.0, 2.0, np.nan, np.nan],
  330. ),
  331. (
  332. "dense",
  333. False,
  334. "keep",
  335. True,
  336. [
  337. 3.0 / 3.0,
  338. 3.0 / 3.0,
  339. np.nan,
  340. 1.0 / 3.0,
  341. 3.0 / 3.0,
  342. 2.0 / 3.0,
  343. np.nan,
  344. np.nan,
  345. ],
  346. ),
  347. ("average", True, "bottom", False, [2.0, 2.0, 7.0, 5.0, 2.0, 4.0, 7.0, 7.0]),
  348. (
  349. "average",
  350. True,
  351. "bottom",
  352. True,
  353. [0.25, 0.25, 0.875, 0.625, 0.25, 0.5, 0.875, 0.875],
  354. ),
  355. ("average", False, "bottom", False, [4.0, 4.0, 7.0, 1.0, 4.0, 2.0, 7.0, 7.0]),
  356. (
  357. "average",
  358. False,
  359. "bottom",
  360. True,
  361. [0.5, 0.5, 0.875, 0.125, 0.5, 0.25, 0.875, 0.875],
  362. ),
  363. ("min", True, "bottom", False, [1.0, 1.0, 6.0, 5.0, 1.0, 4.0, 6.0, 6.0]),
  364. (
  365. "min",
  366. True,
  367. "bottom",
  368. True,
  369. [0.125, 0.125, 0.75, 0.625, 0.125, 0.5, 0.75, 0.75],
  370. ),
  371. ("min", False, "bottom", False, [3.0, 3.0, 6.0, 1.0, 3.0, 2.0, 6.0, 6.0]),
  372. (
  373. "min",
  374. False,
  375. "bottom",
  376. True,
  377. [0.375, 0.375, 0.75, 0.125, 0.375, 0.25, 0.75, 0.75],
  378. ),
  379. ("max", True, "bottom", False, [3.0, 3.0, 8.0, 5.0, 3.0, 4.0, 8.0, 8.0]),
  380. ("max", True, "bottom", True, [0.375, 0.375, 1.0, 0.625, 0.375, 0.5, 1.0, 1.0]),
  381. ("max", False, "bottom", False, [5.0, 5.0, 8.0, 1.0, 5.0, 2.0, 8.0, 8.0]),
  382. (
  383. "max",
  384. False,
  385. "bottom",
  386. True,
  387. [0.625, 0.625, 1.0, 0.125, 0.625, 0.25, 1.0, 1.0],
  388. ),
  389. ("first", True, "bottom", False, [1.0, 2.0, 6.0, 5.0, 3.0, 4.0, 7.0, 8.0]),
  390. (
  391. "first",
  392. True,
  393. "bottom",
  394. True,
  395. [0.125, 0.25, 0.75, 0.625, 0.375, 0.5, 0.875, 1.0],
  396. ),
  397. ("first", False, "bottom", False, [3.0, 4.0, 6.0, 1.0, 5.0, 2.0, 7.0, 8.0]),
  398. (
  399. "first",
  400. False,
  401. "bottom",
  402. True,
  403. [0.375, 0.5, 0.75, 0.125, 0.625, 0.25, 0.875, 1.0],
  404. ),
  405. ("dense", True, "bottom", False, [1.0, 1.0, 4.0, 3.0, 1.0, 2.0, 4.0, 4.0]),
  406. ("dense", True, "bottom", True, [0.25, 0.25, 1.0, 0.75, 0.25, 0.5, 1.0, 1.0]),
  407. ("dense", False, "bottom", False, [3.0, 3.0, 4.0, 1.0, 3.0, 2.0, 4.0, 4.0]),
  408. ("dense", False, "bottom", True, [0.75, 0.75, 1.0, 0.25, 0.75, 0.5, 1.0, 1.0]),
  409. ],
  410. )
  411. def test_rank_args_missing(grps, vals, ties_method, ascending, na_option, pct, exp):
  412. key = np.repeat(grps, len(vals))
  413. orig_vals = vals
  414. vals = list(vals) * len(grps)
  415. if isinstance(orig_vals, np.ndarray):
  416. vals = np.array(vals, dtype=orig_vals.dtype)
  417. df = DataFrame({"key": key, "val": vals})
  418. result = df.groupby("key").rank(
  419. method=ties_method, ascending=ascending, na_option=na_option, pct=pct
  420. )
  421. exp_df = DataFrame(exp * len(grps), columns=["val"])
  422. tm.assert_frame_equal(result, exp_df)
  423. @pytest.mark.parametrize(
  424. "pct,exp", [(False, [3.0, 3.0, 3.0, 3.0, 3.0]), (True, [0.6, 0.6, 0.6, 0.6, 0.6])]
  425. )
  426. def test_rank_resets_each_group(pct, exp):
  427. df = DataFrame(
  428. {"key": ["a", "a", "a", "a", "a", "b", "b", "b", "b", "b"], "val": [1] * 10}
  429. )
  430. result = df.groupby("key").rank(pct=pct)
  431. exp_df = DataFrame(exp * 2, columns=["val"])
  432. tm.assert_frame_equal(result, exp_df)
  433. @pytest.mark.parametrize(
  434. "dtype", ["int64", "int32", "uint64", "uint32", "float64", "float32"]
  435. )
  436. @pytest.mark.parametrize("upper", [True, False])
  437. def test_rank_avg_even_vals(dtype, upper):
  438. if upper:
  439. # use IntegerDtype/FloatingDtype
  440. dtype = dtype[0].upper() + dtype[1:]
  441. dtype = dtype.replace("Ui", "UI")
  442. df = DataFrame({"key": ["a"] * 4, "val": [1] * 4})
  443. df["val"] = df["val"].astype(dtype)
  444. assert df["val"].dtype == dtype
  445. result = df.groupby("key").rank()
  446. exp_df = DataFrame([2.5, 2.5, 2.5, 2.5], columns=["val"])
  447. if upper:
  448. exp_df = exp_df.astype("Float64")
  449. tm.assert_frame_equal(result, exp_df)
  450. @pytest.mark.parametrize("ties_method", ["average", "min", "max", "first", "dense"])
  451. @pytest.mark.parametrize("ascending", [True, False])
  452. @pytest.mark.parametrize("na_option", ["keep", "top", "bottom"])
  453. @pytest.mark.parametrize("pct", [True, False])
  454. @pytest.mark.parametrize(
  455. "vals", [["bar", "bar", "foo", "bar", "baz"], ["bar", np.nan, "foo", np.nan, "baz"]]
  456. )
  457. def test_rank_object_dtype(ties_method, ascending, na_option, pct, vals):
  458. df = DataFrame({"key": ["foo"] * 5, "val": vals})
  459. mask = df["val"].isna()
  460. gb = df.groupby("key")
  461. res = gb.rank(method=ties_method, ascending=ascending, na_option=na_option, pct=pct)
  462. # construct our expected by using numeric values with the same ordering
  463. if mask.any():
  464. df2 = DataFrame({"key": ["foo"] * 5, "val": [0, np.nan, 2, np.nan, 1]})
  465. else:
  466. df2 = DataFrame({"key": ["foo"] * 5, "val": [0, 0, 2, 0, 1]})
  467. gb2 = df2.groupby("key")
  468. alt = gb2.rank(
  469. method=ties_method, ascending=ascending, na_option=na_option, pct=pct
  470. )
  471. tm.assert_frame_equal(res, alt)
  472. @pytest.mark.parametrize("na_option", [True, "bad", 1])
  473. @pytest.mark.parametrize("ties_method", ["average", "min", "max", "first", "dense"])
  474. @pytest.mark.parametrize("ascending", [True, False])
  475. @pytest.mark.parametrize("pct", [True, False])
  476. @pytest.mark.parametrize(
  477. "vals",
  478. [
  479. ["bar", "bar", "foo", "bar", "baz"],
  480. ["bar", np.nan, "foo", np.nan, "baz"],
  481. [1, np.nan, 2, np.nan, 3],
  482. ],
  483. )
  484. def test_rank_naoption_raises(ties_method, ascending, na_option, pct, vals):
  485. df = DataFrame({"key": ["foo"] * 5, "val": vals})
  486. msg = "na_option must be one of 'keep', 'top', or 'bottom'"
  487. with pytest.raises(ValueError, match=msg):
  488. df.groupby("key").rank(
  489. method=ties_method, ascending=ascending, na_option=na_option, pct=pct
  490. )
  491. def test_rank_empty_group():
  492. # see gh-22519
  493. column = "A"
  494. df = DataFrame({"A": [0, 1, 0], "B": [1.0, np.nan, 2.0]})
  495. result = df.groupby(column).B.rank(pct=True)
  496. expected = Series([0.5, np.nan, 1.0], name="B")
  497. tm.assert_series_equal(result, expected)
  498. result = df.groupby(column).rank(pct=True)
  499. expected = DataFrame({"B": [0.5, np.nan, 1.0]})
  500. tm.assert_frame_equal(result, expected)
  501. @pytest.mark.parametrize(
  502. "input_key,input_value,output_value",
  503. [
  504. ([1, 2], [1, 1], [1.0, 1.0]),
  505. ([1, 1, 2, 2], [1, 2, 1, 2], [0.5, 1.0, 0.5, 1.0]),
  506. ([1, 1, 2, 2], [1, 2, 1, np.nan], [0.5, 1.0, 1.0, np.nan]),
  507. ([1, 1, 2], [1, 2, np.nan], [0.5, 1.0, np.nan]),
  508. ],
  509. )
  510. def test_rank_zero_div(input_key, input_value, output_value):
  511. # GH 23666
  512. df = DataFrame({"A": input_key, "B": input_value})
  513. result = df.groupby("A").rank(method="dense", pct=True)
  514. expected = DataFrame({"B": output_value})
  515. tm.assert_frame_equal(result, expected)
  516. def test_rank_min_int():
  517. # GH-32859
  518. df = DataFrame(
  519. {
  520. "grp": [1, 1, 2],
  521. "int_col": [
  522. np.iinfo(np.int64).min,
  523. np.iinfo(np.int64).max,
  524. np.iinfo(np.int64).min,
  525. ],
  526. "datetimelike": [NaT, datetime(2001, 1, 1), NaT],
  527. }
  528. )
  529. result = df.groupby("grp").rank()
  530. expected = DataFrame(
  531. {"int_col": [1.0, 2.0, 1.0], "datetimelike": [np.NaN, 1.0, np.NaN]}
  532. )
  533. tm.assert_frame_equal(result, expected)
  534. @pytest.mark.parametrize("use_nan", [True, False])
  535. def test_rank_pct_equal_values_on_group_transition(use_nan):
  536. # GH#40518
  537. fill_value = np.nan if use_nan else 3
  538. df = DataFrame(
  539. [
  540. [-1, 1],
  541. [-1, 2],
  542. [1, fill_value],
  543. [-1, fill_value],
  544. ],
  545. columns=["group", "val"],
  546. )
  547. result = df.groupby(["group"])["val"].rank(
  548. method="dense",
  549. pct=True,
  550. )
  551. if use_nan:
  552. expected = Series([0.5, 1, np.nan, np.nan], name="val")
  553. else:
  554. expected = Series([1 / 3, 2 / 3, 1, 1], name="val")
  555. tm.assert_series_equal(result, expected)
def test_rank_multiindex():
    # GH27721: ranking across columns (axis=1) under a column-wise groupby
    # should rank within each top-level column group independently.
    # NOTE(review): groupby(axis=1) and rank(axis=1) are deprecated in
    # pandas 2.x and removed later — this test targets the older API; confirm
    # the pinned pandas version before modifying.
    df = concat(
        {
            "a": DataFrame({"col1": [3, 4], "col2": [1, 2]}),
            "b": DataFrame({"col3": [5, 6], "col4": [7, 8]}),
        },
        axis=1,
    )
    gb = df.groupby(level=0, axis=1)
    result = gb.rank(axis=1)
    # expected: rank each top-level block on its own, then reassemble under
    # the same outer keys
    expected = concat(
        [
            df["a"].rank(axis=1),
            df["b"].rank(axis=1),
        ],
        axis=1,
        keys=["a", "b"],
    )
    tm.assert_frame_equal(result, expected)
def test_groupby_axis0_rank_axis1():
    # GH#41320: grouping over the index (axis=0) while ranking across
    # columns (axis=1) should rank within each row-group.
    # NOTE(review): the rank(axis=1) keyword on groupby is deprecated in
    # pandas 2.x — this test targets the older API.
    df = DataFrame(
        {0: [1, 3, 5, 7], 1: [2, 4, 6, 8], 2: [1.5, 3.5, 5.5, 7.5]},
        index=["a", "a", "b", "b"],
    )
    gb = df.groupby(level=0, axis=0)
    res = gb.rank(axis=1)
    # This should match what we get when "manually" operating group-by-group
    expected = concat([df.loc["a"].rank(axis=1), df.loc["b"].rank(axis=1)], axis=0)
    tm.assert_frame_equal(res, expected)
    # check that we haven't accidentally written a case that coincidentally
    # matches rank(axis=0)
    alt = gb.rank(axis=0)
    assert not alt.equals(expected)
def test_groupby_axis0_cummax_axis1():
    # case where groupby axis is 0 and axis keyword in transform is 1
    # df has mixed dtype -> multiple blocks
    # NOTE(review): the cummax(axis=1) keyword on groupby is deprecated in
    # pandas 2.x — this test targets the older API.
    df = DataFrame(
        {0: [1, 3, 5, 7], 1: [2, 4, 6, 8], 2: [1.5, 3.5, 5.5, 7.5]},
        index=["a", "a", "b", "b"],
    )
    gb = df.groupby(level=0, axis=0)
    cmax = gb.cummax(axis=1)
    # each row increases left-to-right through columns 0 and 1, so the
    # running max for column 2 equals column 1's value
    expected = df[[0, 1]].astype(np.float64)
    expected[2] = expected[1]
    tm.assert_frame_equal(cmax, expected)
  603. def test_non_unique_index():
  604. # GH 16577
  605. df = DataFrame(
  606. {"A": [1.0, 2.0, 3.0, np.nan], "value": 1.0},
  607. index=[pd.Timestamp("20170101", tz="US/Eastern")] * 4,
  608. )
  609. result = df.groupby([df.index, "A"]).value.rank(ascending=True, pct=True)
  610. expected = Series(
  611. [1.0, 1.0, 1.0, np.nan],
  612. index=[pd.Timestamp("20170101", tz="US/Eastern")] * 4,
  613. name="value",
  614. )
  615. tm.assert_series_equal(result, expected)
  616. def test_rank_categorical():
  617. cat = pd.Categorical(["a", "a", "b", np.nan, "c", "b"], ordered=True)
  618. cat2 = pd.Categorical([1, 2, 3, np.nan, 4, 5], ordered=True)
  619. df = DataFrame({"col1": [0, 1, 0, 1, 0, 1], "col2": cat, "col3": cat2})
  620. gb = df.groupby("col1")
  621. res = gb.rank()
  622. expected = df.astype(object).groupby("col1").rank()
  623. tm.assert_frame_equal(res, expected)