test_sorting.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518
  1. from collections import defaultdict
  2. from datetime import datetime
  3. from itertools import product
  4. import numpy as np
  5. import pytest
  6. from pandas.compat import (
  7. is_ci_environment,
  8. is_platform_windows,
  9. )
  10. from pandas import (
  11. NA,
  12. DataFrame,
  13. MultiIndex,
  14. Series,
  15. array,
  16. concat,
  17. merge,
  18. )
  19. import pandas._testing as tm
  20. from pandas.core.algorithms import safe_sort
  21. import pandas.core.common as com
  22. from pandas.core.sorting import (
  23. _decons_group_index,
  24. get_group_index,
  25. is_int64_overflow_possible,
  26. lexsort_indexer,
  27. nargsort,
  28. )
  29. @pytest.fixture
  30. def left_right():
  31. low, high, n = -1 << 10, 1 << 10, 1 << 20
  32. left = DataFrame(np.random.randint(low, high, (n, 7)), columns=list("ABCDEFG"))
  33. left["left"] = left.sum(axis=1)
  34. # one-2-one match
  35. i = np.random.permutation(len(left))
  36. right = left.iloc[i].copy()
  37. right.columns = right.columns[:-1].tolist() + ["right"]
  38. right.index = np.arange(len(right))
  39. right["right"] *= -1
  40. return left, right
  41. class TestSorting:
  42. @pytest.mark.slow
  43. def test_int64_overflow(self):
  44. B = np.concatenate((np.arange(1000), np.arange(1000), np.arange(500)))
  45. A = np.arange(2500)
  46. df = DataFrame(
  47. {
  48. "A": A,
  49. "B": B,
  50. "C": A,
  51. "D": B,
  52. "E": A,
  53. "F": B,
  54. "G": A,
  55. "H": B,
  56. "values": np.random.randn(2500),
  57. }
  58. )
  59. lg = df.groupby(["A", "B", "C", "D", "E", "F", "G", "H"])
  60. rg = df.groupby(["H", "G", "F", "E", "D", "C", "B", "A"])
  61. left = lg.sum()["values"]
  62. right = rg.sum()["values"]
  63. exp_index, _ = left.index.sortlevel()
  64. tm.assert_index_equal(left.index, exp_index)
  65. exp_index, _ = right.index.sortlevel(0)
  66. tm.assert_index_equal(right.index, exp_index)
  67. tups = list(map(tuple, df[["A", "B", "C", "D", "E", "F", "G", "H"]].values))
  68. tups = com.asarray_tuplesafe(tups)
  69. expected = df.groupby(tups).sum()["values"]
  70. for k, v in expected.items():
  71. assert left[k] == right[k[::-1]]
  72. assert left[k] == v
  73. assert len(left) == len(right)
  74. def test_int64_overflow_groupby_large_range(self):
  75. # GH9096
  76. values = range(55109)
  77. data = DataFrame.from_dict({"a": values, "b": values, "c": values, "d": values})
  78. grouped = data.groupby(["a", "b", "c", "d"])
  79. assert len(grouped) == len(values)
  80. @pytest.mark.parametrize("agg", ["mean", "median"])
  81. def test_int64_overflow_groupby_large_df_shuffled(self, agg):
  82. arr = np.random.randint(-1 << 12, 1 << 12, (1 << 15, 5))
  83. i = np.random.choice(len(arr), len(arr) * 4)
  84. arr = np.vstack((arr, arr[i])) # add some duplicate rows
  85. i = np.random.permutation(len(arr))
  86. arr = arr[i] # shuffle rows
  87. df = DataFrame(arr, columns=list("abcde"))
  88. df["jim"], df["joe"] = np.random.randn(2, len(df)) * 10
  89. gr = df.groupby(list("abcde"))
  90. # verify this is testing what it is supposed to test!
  91. assert is_int64_overflow_possible(gr.grouper.shape)
  92. # manually compute groupings
  93. jim, joe = defaultdict(list), defaultdict(list)
  94. for key, a, b in zip(map(tuple, arr), df["jim"], df["joe"]):
  95. jim[key].append(a)
  96. joe[key].append(b)
  97. assert len(gr) == len(jim)
  98. mi = MultiIndex.from_tuples(jim.keys(), names=list("abcde"))
  99. f = lambda a: np.fromiter(map(getattr(np, agg), a), dtype="f8")
  100. arr = np.vstack((f(jim.values()), f(joe.values()))).T
  101. res = DataFrame(arr, columns=["jim", "joe"], index=mi).sort_index()
  102. tm.assert_frame_equal(getattr(gr, agg)(), res)
  103. @pytest.mark.parametrize(
  104. "order, na_position, exp",
  105. [
  106. [
  107. True,
  108. "last",
  109. list(range(5, 105)) + list(range(5)) + list(range(105, 110)),
  110. ],
  111. [
  112. True,
  113. "first",
  114. list(range(5)) + list(range(105, 110)) + list(range(5, 105)),
  115. ],
  116. [
  117. False,
  118. "last",
  119. list(range(104, 4, -1)) + list(range(5)) + list(range(105, 110)),
  120. ],
  121. [
  122. False,
  123. "first",
  124. list(range(5)) + list(range(105, 110)) + list(range(104, 4, -1)),
  125. ],
  126. ],
  127. )
  128. def test_lexsort_indexer(self, order, na_position, exp):
  129. keys = [[np.nan] * 5 + list(range(100)) + [np.nan] * 5]
  130. result = lexsort_indexer(keys, orders=order, na_position=na_position)
  131. tm.assert_numpy_array_equal(result, np.array(exp, dtype=np.intp))
  132. @pytest.mark.parametrize(
  133. "ascending, na_position, exp, box",
  134. [
  135. [
  136. True,
  137. "last",
  138. list(range(5, 105)) + list(range(5)) + list(range(105, 110)),
  139. list,
  140. ],
  141. [
  142. True,
  143. "first",
  144. list(range(5)) + list(range(105, 110)) + list(range(5, 105)),
  145. list,
  146. ],
  147. [
  148. False,
  149. "last",
  150. list(range(104, 4, -1)) + list(range(5)) + list(range(105, 110)),
  151. list,
  152. ],
  153. [
  154. False,
  155. "first",
  156. list(range(5)) + list(range(105, 110)) + list(range(104, 4, -1)),
  157. list,
  158. ],
  159. [
  160. True,
  161. "last",
  162. list(range(5, 105)) + list(range(5)) + list(range(105, 110)),
  163. lambda x: np.array(x, dtype="O"),
  164. ],
  165. [
  166. True,
  167. "first",
  168. list(range(5)) + list(range(105, 110)) + list(range(5, 105)),
  169. lambda x: np.array(x, dtype="O"),
  170. ],
  171. [
  172. False,
  173. "last",
  174. list(range(104, 4, -1)) + list(range(5)) + list(range(105, 110)),
  175. lambda x: np.array(x, dtype="O"),
  176. ],
  177. [
  178. False,
  179. "first",
  180. list(range(5)) + list(range(105, 110)) + list(range(104, 4, -1)),
  181. lambda x: np.array(x, dtype="O"),
  182. ],
  183. ],
  184. )
  185. def test_nargsort(self, ascending, na_position, exp, box):
  186. # list places NaNs last, np.array(..., dtype="O") may not place NaNs first
  187. items = box([np.nan] * 5 + list(range(100)) + [np.nan] * 5)
  188. # mergesort is the most difficult to get right because we want it to be
  189. # stable.
  190. # According to numpy/core/tests/test_multiarray, """The number of
  191. # sorted items must be greater than ~50 to check the actual algorithm
  192. # because quick and merge sort fall over to insertion sort for small
  193. # arrays."""
  194. result = nargsort(
  195. items, kind="mergesort", ascending=ascending, na_position=na_position
  196. )
  197. tm.assert_numpy_array_equal(result, np.array(exp), check_dtype=False)
  198. class TestMerge:
  199. def test_int64_overflow_outer_merge(self):
  200. # #2690, combinatorial explosion
  201. df1 = DataFrame(np.random.randn(1000, 7), columns=list("ABCDEF") + ["G1"])
  202. df2 = DataFrame(np.random.randn(1000, 7), columns=list("ABCDEF") + ["G2"])
  203. result = merge(df1, df2, how="outer")
  204. assert len(result) == 2000
  205. @pytest.mark.slow
  206. def test_int64_overflow_check_sum_col(self, left_right):
  207. left, right = left_right
  208. out = merge(left, right, how="outer")
  209. assert len(out) == len(left)
  210. tm.assert_series_equal(out["left"], -out["right"], check_names=False)
  211. result = out.iloc[:, :-2].sum(axis=1)
  212. tm.assert_series_equal(out["left"], result, check_names=False)
  213. assert result.name is None
  214. @pytest.mark.slow
  215. @pytest.mark.parametrize("how", ["left", "right", "outer", "inner"])
  216. def test_int64_overflow_how_merge(self, left_right, how):
  217. left, right = left_right
  218. out = merge(left, right, how="outer")
  219. out.sort_values(out.columns.tolist(), inplace=True)
  220. out.index = np.arange(len(out))
  221. tm.assert_frame_equal(out, merge(left, right, how=how, sort=True))
  222. @pytest.mark.slow
  223. def test_int64_overflow_sort_false_order(self, left_right):
  224. left, right = left_right
  225. # check that left merge w/ sort=False maintains left frame order
  226. out = merge(left, right, how="left", sort=False)
  227. tm.assert_frame_equal(left, out[left.columns.tolist()])
  228. out = merge(right, left, how="left", sort=False)
  229. tm.assert_frame_equal(right, out[right.columns.tolist()])
  230. @pytest.mark.slow
  231. @pytest.mark.parametrize("how", ["left", "right", "outer", "inner"])
  232. @pytest.mark.parametrize("sort", [True, False])
  233. def test_int64_overflow_one_to_many_none_match(self, how, sort):
  234. # one-2-many/none match
  235. low, high, n = -1 << 10, 1 << 10, 1 << 11
  236. left = DataFrame(
  237. np.random.randint(low, high, (n, 7)).astype("int64"),
  238. columns=list("ABCDEFG"),
  239. )
  240. # confirm that this is checking what it is supposed to check
  241. shape = left.apply(Series.nunique).values
  242. assert is_int64_overflow_possible(shape)
  243. # add duplicates to left frame
  244. left = concat([left, left], ignore_index=True)
  245. right = DataFrame(
  246. np.random.randint(low, high, (n // 2, 7)).astype("int64"),
  247. columns=list("ABCDEFG"),
  248. )
  249. # add duplicates & overlap with left to the right frame
  250. i = np.random.choice(len(left), n)
  251. right = concat([right, right, left.iloc[i]], ignore_index=True)
  252. left["left"] = np.random.randn(len(left))
  253. right["right"] = np.random.randn(len(right))
  254. # shuffle left & right frames
  255. i = np.random.permutation(len(left))
  256. left = left.iloc[i].copy()
  257. left.index = np.arange(len(left))
  258. i = np.random.permutation(len(right))
  259. right = right.iloc[i].copy()
  260. right.index = np.arange(len(right))
  261. # manually compute outer merge
  262. ldict, rdict = defaultdict(list), defaultdict(list)
  263. for idx, row in left.set_index(list("ABCDEFG")).iterrows():
  264. ldict[idx].append(row["left"])
  265. for idx, row in right.set_index(list("ABCDEFG")).iterrows():
  266. rdict[idx].append(row["right"])
  267. vals = []
  268. for k, lval in ldict.items():
  269. rval = rdict.get(k, [np.nan])
  270. for lv, rv in product(lval, rval):
  271. vals.append(
  272. k
  273. + (
  274. lv,
  275. rv,
  276. )
  277. )
  278. for k, rval in rdict.items():
  279. if k not in ldict:
  280. for rv in rval:
  281. vals.append(
  282. k
  283. + (
  284. np.nan,
  285. rv,
  286. )
  287. )
  288. def align(df):
  289. df = df.sort_values(df.columns.tolist())
  290. df.index = np.arange(len(df))
  291. return df
  292. out = DataFrame(vals, columns=list("ABCDEFG") + ["left", "right"])
  293. out = align(out)
  294. jmask = {
  295. "left": out["left"].notna(),
  296. "right": out["right"].notna(),
  297. "inner": out["left"].notna() & out["right"].notna(),
  298. "outer": np.ones(len(out), dtype="bool"),
  299. }
  300. mask = jmask[how]
  301. frame = align(out[mask].copy())
  302. assert mask.all() ^ mask.any() or how == "outer"
  303. res = merge(left, right, how=how, sort=sort)
  304. if sort:
  305. kcols = list("ABCDEFG")
  306. tm.assert_frame_equal(
  307. res[kcols].copy(), res[kcols].sort_values(kcols, kind="mergesort")
  308. )
  309. # as in GH9092 dtypes break with outer/right join
  310. # 2021-12-18: dtype does not break anymore
  311. tm.assert_frame_equal(frame, align(res))
  312. @pytest.mark.parametrize(
  313. "codes_list, shape",
  314. [
  315. [
  316. [
  317. np.tile([0, 1, 2, 3, 0, 1, 2, 3], 100).astype(np.int64),
  318. np.tile([0, 2, 4, 3, 0, 1, 2, 3], 100).astype(np.int64),
  319. np.tile([5, 1, 0, 2, 3, 0, 5, 4], 100).astype(np.int64),
  320. ],
  321. (4, 5, 6),
  322. ],
  323. [
  324. [
  325. np.tile(np.arange(10000, dtype=np.int64), 5),
  326. np.tile(np.arange(10000, dtype=np.int64), 5),
  327. ],
  328. (10000, 10000),
  329. ],
  330. ],
  331. )
  332. def test_decons(codes_list, shape):
  333. group_index = get_group_index(codes_list, shape, sort=True, xnull=True)
  334. codes_list2 = _decons_group_index(group_index, shape)
  335. for a, b in zip(codes_list, codes_list2):
  336. tm.assert_numpy_array_equal(a, b)
  337. class TestSafeSort:
  338. @pytest.mark.parametrize(
  339. "arg, exp",
  340. [
  341. [[3, 1, 2, 0, 4], [0, 1, 2, 3, 4]],
  342. [list("baaacb"), np.array(list("aaabbc"), dtype=object)],
  343. [[], []],
  344. ],
  345. )
  346. def test_basic_sort(self, arg, exp):
  347. result = safe_sort(arg)
  348. expected = np.array(exp)
  349. tm.assert_numpy_array_equal(result, expected)
  350. @pytest.mark.parametrize("verify", [True, False])
  351. @pytest.mark.parametrize(
  352. "codes, exp_codes",
  353. [
  354. [[0, 1, 1, 2, 3, 0, -1, 4], [3, 1, 1, 2, 0, 3, -1, 4]],
  355. [[], []],
  356. ],
  357. )
  358. def test_codes(self, verify, codes, exp_codes):
  359. values = [3, 1, 2, 0, 4]
  360. expected = np.array([0, 1, 2, 3, 4])
  361. result, result_codes = safe_sort(
  362. values, codes, use_na_sentinel=True, verify=verify
  363. )
  364. expected_codes = np.array(exp_codes, dtype=np.intp)
  365. tm.assert_numpy_array_equal(result, expected)
  366. tm.assert_numpy_array_equal(result_codes, expected_codes)
  367. @pytest.mark.skipif(
  368. is_platform_windows() and is_ci_environment(),
  369. reason="In CI environment can crash thread with: "
  370. "Windows fatal exception: access violation",
  371. )
  372. def test_codes_out_of_bound(self):
  373. values = [3, 1, 2, 0, 4]
  374. expected = np.array([0, 1, 2, 3, 4])
  375. # out of bound indices
  376. codes = [0, 101, 102, 2, 3, 0, 99, 4]
  377. result, result_codes = safe_sort(values, codes, use_na_sentinel=True)
  378. expected_codes = np.array([3, -1, -1, 2, 0, 3, -1, 4], dtype=np.intp)
  379. tm.assert_numpy_array_equal(result, expected)
  380. tm.assert_numpy_array_equal(result_codes, expected_codes)
  381. @pytest.mark.parametrize("box", [lambda x: np.array(x, dtype=object), list])
  382. def test_mixed_integer(self, box):
  383. values = box(["b", 1, 0, "a", 0, "b"])
  384. result = safe_sort(values)
  385. expected = np.array([0, 0, 1, "a", "b", "b"], dtype=object)
  386. tm.assert_numpy_array_equal(result, expected)
  387. def test_mixed_integer_with_codes(self):
  388. values = np.array(["b", 1, 0, "a"], dtype=object)
  389. codes = [0, 1, 2, 3, 0, -1, 1]
  390. result, result_codes = safe_sort(values, codes)
  391. expected = np.array([0, 1, "a", "b"], dtype=object)
  392. expected_codes = np.array([3, 1, 0, 2, 3, -1, 1], dtype=np.intp)
  393. tm.assert_numpy_array_equal(result, expected)
  394. tm.assert_numpy_array_equal(result_codes, expected_codes)
  395. def test_unsortable(self):
  396. # GH 13714
  397. arr = np.array([1, 2, datetime.now(), 0, 3], dtype=object)
  398. msg = "'[<>]' not supported between instances of .*"
  399. with pytest.raises(TypeError, match=msg):
  400. safe_sort(arr)
  401. @pytest.mark.parametrize(
  402. "arg, codes, err, msg",
  403. [
  404. [1, None, TypeError, "Only list-like objects are allowed"],
  405. [[0, 1, 2], 1, TypeError, "Only list-like objects or None"],
  406. [[0, 1, 2, 1], [0, 1], ValueError, "values should be unique"],
  407. ],
  408. )
  409. def test_exceptions(self, arg, codes, err, msg):
  410. with pytest.raises(err, match=msg):
  411. safe_sort(values=arg, codes=codes)
  412. @pytest.mark.parametrize(
  413. "arg, exp", [[[1, 3, 2], [1, 2, 3]], [[1, 3, np.nan, 2], [1, 2, 3, np.nan]]]
  414. )
  415. def test_extension_array(self, arg, exp):
  416. a = array(arg, dtype="Int64")
  417. result = safe_sort(a)
  418. expected = array(exp, dtype="Int64")
  419. tm.assert_extension_array_equal(result, expected)
  420. @pytest.mark.parametrize("verify", [True, False])
  421. def test_extension_array_codes(self, verify):
  422. a = array([1, 3, 2], dtype="Int64")
  423. result, codes = safe_sort(a, [0, 1, -1, 2], use_na_sentinel=True, verify=verify)
  424. expected_values = array([1, 2, 3], dtype="Int64")
  425. expected_codes = np.array([0, 2, -1, 1], dtype=np.intp)
  426. tm.assert_extension_array_equal(result, expected_values)
  427. tm.assert_numpy_array_equal(codes, expected_codes)
  428. def test_mixed_str_null(nulls_fixture):
  429. values = np.array(["b", nulls_fixture, "a", "b"], dtype=object)
  430. result = safe_sort(values)
  431. expected = np.array(["a", "b", "b", nulls_fixture], dtype=object)
  432. tm.assert_numpy_array_equal(result, expected)
  433. def test_safe_sort_multiindex():
  434. # GH#48412
  435. arr1 = Series([2, 1, NA, NA], dtype="Int64")
  436. arr2 = [2, 1, 3, 3]
  437. midx = MultiIndex.from_arrays([arr1, arr2])
  438. result = safe_sort(midx)
  439. expected = MultiIndex.from_arrays(
  440. [Series([1, 2, NA, NA], dtype="Int64"), [1, 2, 3, 3]]
  441. )
  442. tm.assert_index_equal(result, expected)