test_join.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571
  1. from datetime import datetime
  2. import numpy as np
  3. import pytest
  4. from pandas.errors import MergeError
  5. import pandas as pd
  6. from pandas import (
  7. DataFrame,
  8. Index,
  9. MultiIndex,
  10. date_range,
  11. period_range,
  12. )
  13. import pandas._testing as tm
  14. from pandas.core.reshape.concat import concat
  15. @pytest.fixture
  16. def frame_with_period_index():
  17. return DataFrame(
  18. data=np.arange(20).reshape(4, 5),
  19. columns=list("abcde"),
  20. index=period_range(start="2000", freq="A", periods=4),
  21. )
  22. @pytest.fixture
  23. def left():
  24. return DataFrame({"a": [20, 10, 0]}, index=[2, 1, 0])
  25. @pytest.fixture
  26. def right():
  27. return DataFrame({"b": [300, 100, 200]}, index=[3, 1, 2])
  28. @pytest.fixture
  29. def left_no_dup():
  30. return DataFrame(
  31. {"a": ["a", "b", "c", "d"], "b": ["cat", "dog", "weasel", "horse"]},
  32. index=range(4),
  33. )
  34. @pytest.fixture
  35. def right_no_dup():
  36. return DataFrame(
  37. {
  38. "a": ["a", "b", "c", "d", "e"],
  39. "c": ["meow", "bark", "um... weasel noise?", "nay", "chirp"],
  40. },
  41. index=range(5),
  42. ).set_index("a")
  43. @pytest.fixture
  44. def left_w_dups(left_no_dup):
  45. return concat(
  46. [left_no_dup, DataFrame({"a": ["a"], "b": ["cow"]}, index=[3])], sort=True
  47. )
  48. @pytest.fixture
  49. def right_w_dups(right_no_dup):
  50. return concat(
  51. [right_no_dup, DataFrame({"a": ["e"], "c": ["moo"]}, index=[3])]
  52. ).set_index("a")
  53. @pytest.mark.parametrize(
  54. "how, sort, expected",
  55. [
  56. ("inner", False, DataFrame({"a": [20, 10], "b": [200, 100]}, index=[2, 1])),
  57. ("inner", True, DataFrame({"a": [10, 20], "b": [100, 200]}, index=[1, 2])),
  58. (
  59. "left",
  60. False,
  61. DataFrame({"a": [20, 10, 0], "b": [200, 100, np.nan]}, index=[2, 1, 0]),
  62. ),
  63. (
  64. "left",
  65. True,
  66. DataFrame({"a": [0, 10, 20], "b": [np.nan, 100, 200]}, index=[0, 1, 2]),
  67. ),
  68. (
  69. "right",
  70. False,
  71. DataFrame({"a": [np.nan, 10, 20], "b": [300, 100, 200]}, index=[3, 1, 2]),
  72. ),
  73. (
  74. "right",
  75. True,
  76. DataFrame({"a": [10, 20, np.nan], "b": [100, 200, 300]}, index=[1, 2, 3]),
  77. ),
  78. (
  79. "outer",
  80. False,
  81. DataFrame(
  82. {"a": [0, 10, 20, np.nan], "b": [np.nan, 100, 200, 300]},
  83. index=[0, 1, 2, 3],
  84. ),
  85. ),
  86. (
  87. "outer",
  88. True,
  89. DataFrame(
  90. {"a": [0, 10, 20, np.nan], "b": [np.nan, 100, 200, 300]},
  91. index=[0, 1, 2, 3],
  92. ),
  93. ),
  94. ],
  95. )
  96. def test_join(left, right, how, sort, expected):
  97. result = left.join(right, how=how, sort=sort, validate="1:1")
  98. tm.assert_frame_equal(result, expected)
  99. def test_suffix_on_list_join():
  100. first = DataFrame({"key": [1, 2, 3, 4, 5]})
  101. second = DataFrame({"key": [1, 8, 3, 2, 5], "v1": [1, 2, 3, 4, 5]})
  102. third = DataFrame({"keys": [5, 2, 3, 4, 1], "v2": [1, 2, 3, 4, 5]})
  103. # check proper errors are raised
  104. msg = "Suffixes not supported when joining multiple DataFrames"
  105. with pytest.raises(ValueError, match=msg):
  106. first.join([second], lsuffix="y")
  107. with pytest.raises(ValueError, match=msg):
  108. first.join([second, third], rsuffix="x")
  109. with pytest.raises(ValueError, match=msg):
  110. first.join([second, third], lsuffix="y", rsuffix="x")
  111. with pytest.raises(ValueError, match="Indexes have overlapping values"):
  112. first.join([second, third])
  113. # no errors should be raised
  114. arr_joined = first.join([third])
  115. norm_joined = first.join(third)
  116. tm.assert_frame_equal(arr_joined, norm_joined)
  117. def test_join_invalid_validate(left_no_dup, right_no_dup):
  118. # GH 46622
  119. # Check invalid arguments
  120. msg = (
  121. '"invalid" is not a valid argument. '
  122. "Valid arguments are:\n"
  123. '- "1:1"\n'
  124. '- "1:m"\n'
  125. '- "m:1"\n'
  126. '- "m:m"\n'
  127. '- "one_to_one"\n'
  128. '- "one_to_many"\n'
  129. '- "many_to_one"\n'
  130. '- "many_to_many"'
  131. )
  132. with pytest.raises(ValueError, match=msg):
  133. left_no_dup.merge(right_no_dup, on="a", validate="invalid")
  134. def test_join_on_single_col_dup_on_right(left_no_dup, right_w_dups):
  135. # GH 46622
  136. # Dups on right allowed by one_to_many constraint
  137. left_no_dup.join(
  138. right_w_dups,
  139. on="a",
  140. validate="one_to_many",
  141. )
  142. # Dups on right not allowed by one_to_one constraint
  143. msg = "Merge keys are not unique in right dataset; not a one-to-one merge"
  144. with pytest.raises(MergeError, match=msg):
  145. left_no_dup.join(
  146. right_w_dups,
  147. on="a",
  148. validate="one_to_one",
  149. )
  150. def test_join_on_single_col_dup_on_left(left_w_dups, right_no_dup):
  151. # GH 46622
  152. # Dups on left allowed by many_to_one constraint
  153. left_w_dups.join(
  154. right_no_dup,
  155. on="a",
  156. validate="many_to_one",
  157. )
  158. # Dups on left not allowed by one_to_one constraint
  159. msg = "Merge keys are not unique in left dataset; not a one-to-one merge"
  160. with pytest.raises(MergeError, match=msg):
  161. left_w_dups.join(
  162. right_no_dup,
  163. on="a",
  164. validate="one_to_one",
  165. )
  166. def test_join_on_single_col_dup_on_both(left_w_dups, right_w_dups):
  167. # GH 46622
  168. # Dups on both allowed by many_to_many constraint
  169. left_w_dups.join(right_w_dups, on="a", validate="many_to_many")
  170. # Dups on both not allowed by many_to_one constraint
  171. msg = "Merge keys are not unique in right dataset; not a many-to-one merge"
  172. with pytest.raises(MergeError, match=msg):
  173. left_w_dups.join(
  174. right_w_dups,
  175. on="a",
  176. validate="many_to_one",
  177. )
  178. # Dups on both not allowed by one_to_many constraint
  179. msg = "Merge keys are not unique in left dataset; not a one-to-many merge"
  180. with pytest.raises(MergeError, match=msg):
  181. left_w_dups.join(
  182. right_w_dups,
  183. on="a",
  184. validate="one_to_many",
  185. )
  186. def test_join_on_multi_col_check_dup():
  187. # GH 46622
  188. # Two column join, dups in both, but jointly no dups
  189. left = DataFrame(
  190. {
  191. "a": ["a", "a", "b", "b"],
  192. "b": [0, 1, 0, 1],
  193. "c": ["cat", "dog", "weasel", "horse"],
  194. },
  195. index=range(4),
  196. ).set_index(["a", "b"])
  197. right = DataFrame(
  198. {
  199. "a": ["a", "a", "b"],
  200. "b": [0, 1, 0],
  201. "d": ["meow", "bark", "um... weasel noise?"],
  202. },
  203. index=range(3),
  204. ).set_index(["a", "b"])
  205. expected_multi = DataFrame(
  206. {
  207. "a": ["a", "a", "b"],
  208. "b": [0, 1, 0],
  209. "c": ["cat", "dog", "weasel"],
  210. "d": ["meow", "bark", "um... weasel noise?"],
  211. },
  212. index=range(3),
  213. ).set_index(["a", "b"])
  214. # Jointly no dups allowed by one_to_one constraint
  215. result = left.join(right, how="inner", validate="1:1")
  216. tm.assert_frame_equal(result, expected_multi)
  217. def test_join_index(float_frame):
  218. # left / right
  219. f = float_frame.loc[float_frame.index[:10], ["A", "B"]]
  220. f2 = float_frame.loc[float_frame.index[5:], ["C", "D"]].iloc[::-1]
  221. joined = f.join(f2)
  222. tm.assert_index_equal(f.index, joined.index)
  223. expected_columns = Index(["A", "B", "C", "D"])
  224. tm.assert_index_equal(joined.columns, expected_columns)
  225. joined = f.join(f2, how="left")
  226. tm.assert_index_equal(joined.index, f.index)
  227. tm.assert_index_equal(joined.columns, expected_columns)
  228. joined = f.join(f2, how="right")
  229. tm.assert_index_equal(joined.index, f2.index)
  230. tm.assert_index_equal(joined.columns, expected_columns)
  231. # inner
  232. joined = f.join(f2, how="inner")
  233. tm.assert_index_equal(joined.index, f.index[5:10])
  234. tm.assert_index_equal(joined.columns, expected_columns)
  235. # outer
  236. joined = f.join(f2, how="outer")
  237. tm.assert_index_equal(joined.index, float_frame.index.sort_values())
  238. tm.assert_index_equal(joined.columns, expected_columns)
  239. with pytest.raises(ValueError, match="join method"):
  240. f.join(f2, how="foo")
  241. # corner case - overlapping columns
  242. msg = "columns overlap but no suffix"
  243. for how in ("outer", "left", "inner"):
  244. with pytest.raises(ValueError, match=msg):
  245. float_frame.join(float_frame, how=how)
  246. def test_join_index_more(float_frame):
  247. af = float_frame.loc[:, ["A", "B"]]
  248. bf = float_frame.loc[::2, ["C", "D"]]
  249. expected = af.copy()
  250. expected["C"] = float_frame["C"][::2]
  251. expected["D"] = float_frame["D"][::2]
  252. result = af.join(bf)
  253. tm.assert_frame_equal(result, expected)
  254. result = af.join(bf, how="right")
  255. tm.assert_frame_equal(result, expected[::2])
  256. result = bf.join(af, how="right")
  257. tm.assert_frame_equal(result, expected.loc[:, result.columns])
  258. def test_join_index_series(float_frame):
  259. df = float_frame.copy()
  260. ser = df.pop(float_frame.columns[-1])
  261. joined = df.join(ser)
  262. tm.assert_frame_equal(joined, float_frame)
  263. ser.name = None
  264. with pytest.raises(ValueError, match="must have a name"):
  265. df.join(ser)
  266. def test_join_overlap(float_frame):
  267. df1 = float_frame.loc[:, ["A", "B", "C"]]
  268. df2 = float_frame.loc[:, ["B", "C", "D"]]
  269. joined = df1.join(df2, lsuffix="_df1", rsuffix="_df2")
  270. df1_suf = df1.loc[:, ["B", "C"]].add_suffix("_df1")
  271. df2_suf = df2.loc[:, ["B", "C"]].add_suffix("_df2")
  272. no_overlap = float_frame.loc[:, ["A", "D"]]
  273. expected = df1_suf.join(df2_suf).join(no_overlap)
  274. # column order not necessarily sorted
  275. tm.assert_frame_equal(joined, expected.loc[:, joined.columns])
  276. def test_join_period_index(frame_with_period_index):
  277. other = frame_with_period_index.rename(columns=lambda key: f"{key}{key}")
  278. joined_values = np.concatenate([frame_with_period_index.values] * 2, axis=1)
  279. joined_cols = frame_with_period_index.columns.append(other.columns)
  280. joined = frame_with_period_index.join(other)
  281. expected = DataFrame(
  282. data=joined_values, columns=joined_cols, index=frame_with_period_index.index
  283. )
  284. tm.assert_frame_equal(joined, expected)
  285. def test_join_left_sequence_non_unique_index():
  286. # https://github.com/pandas-dev/pandas/issues/19607
  287. df1 = DataFrame({"a": [0, 10, 20]}, index=[1, 2, 3])
  288. df2 = DataFrame({"b": [100, 200, 300]}, index=[4, 3, 2])
  289. df3 = DataFrame({"c": [400, 500, 600]}, index=[2, 2, 4])
  290. joined = df1.join([df2, df3], how="left")
  291. expected = DataFrame(
  292. {
  293. "a": [0, 10, 10, 20],
  294. "b": [np.nan, 300, 300, 200],
  295. "c": [np.nan, 400, 500, np.nan],
  296. },
  297. index=[1, 2, 2, 3],
  298. )
  299. tm.assert_frame_equal(joined, expected)
  300. def test_join_list_series(float_frame):
  301. # GH#46850
  302. # Join a DataFrame with a list containing both a Series and a DataFrame
  303. left = float_frame.A.to_frame()
  304. right = [float_frame.B, float_frame[["C", "D"]]]
  305. result = left.join(right)
  306. tm.assert_frame_equal(result, float_frame)
  307. @pytest.mark.parametrize("sort_kw", [True, False])
  308. def test_suppress_future_warning_with_sort_kw(sort_kw):
  309. a = DataFrame({"col1": [1, 2]}, index=["c", "a"])
  310. b = DataFrame({"col2": [4, 5]}, index=["b", "a"])
  311. c = DataFrame({"col3": [7, 8]}, index=["a", "b"])
  312. expected = DataFrame(
  313. {
  314. "col1": {"a": 2.0, "b": float("nan"), "c": 1.0},
  315. "col2": {"a": 5.0, "b": 4.0, "c": float("nan")},
  316. "col3": {"a": 7.0, "b": 8.0, "c": float("nan")},
  317. }
  318. )
  319. if sort_kw is False:
  320. expected = expected.reindex(index=["c", "a", "b"])
  321. with tm.assert_produces_warning(None):
  322. result = a.join([b, c], how="outer", sort=sort_kw)
  323. tm.assert_frame_equal(result, expected)
  324. class TestDataFrameJoin:
  325. def test_join(self, multiindex_dataframe_random_data):
  326. frame = multiindex_dataframe_random_data
  327. a = frame.loc[frame.index[:5], ["A"]]
  328. b = frame.loc[frame.index[2:], ["B", "C"]]
  329. joined = a.join(b, how="outer").reindex(frame.index)
  330. expected = frame.copy().values.copy()
  331. expected[np.isnan(joined.values)] = np.nan
  332. expected = DataFrame(expected, index=frame.index, columns=frame.columns)
  333. assert not np.isnan(joined.values).all()
  334. tm.assert_frame_equal(joined, expected)
  335. def test_join_segfault(self):
  336. # GH#1532
  337. df1 = DataFrame({"a": [1, 1], "b": [1, 2], "x": [1, 2]})
  338. df2 = DataFrame({"a": [2, 2], "b": [1, 2], "y": [1, 2]})
  339. df1 = df1.set_index(["a", "b"])
  340. df2 = df2.set_index(["a", "b"])
  341. # it works!
  342. for how in ["left", "right", "outer"]:
  343. df1.join(df2, how=how)
  344. def test_join_str_datetime(self):
  345. str_dates = ["20120209", "20120222"]
  346. dt_dates = [datetime(2012, 2, 9), datetime(2012, 2, 22)]
  347. A = DataFrame(str_dates, index=range(2), columns=["aa"])
  348. C = DataFrame([[1, 2], [3, 4]], index=str_dates, columns=dt_dates)
  349. tst = A.join(C, on="aa")
  350. assert len(tst.columns) == 3
  351. def test_join_multiindex_leftright(self):
  352. # GH 10741
  353. df1 = DataFrame(
  354. [
  355. ["a", "x", 0.471780],
  356. ["a", "y", 0.774908],
  357. ["a", "z", 0.563634],
  358. ["b", "x", -0.353756],
  359. ["b", "y", 0.368062],
  360. ["b", "z", -1.721840],
  361. ["c", "x", 1],
  362. ["c", "y", 2],
  363. ["c", "z", 3],
  364. ],
  365. columns=["first", "second", "value1"],
  366. ).set_index(["first", "second"])
  367. df2 = DataFrame([["a", 10], ["b", 20]], columns=["first", "value2"]).set_index(
  368. ["first"]
  369. )
  370. exp = DataFrame(
  371. [
  372. [0.471780, 10],
  373. [0.774908, 10],
  374. [0.563634, 10],
  375. [-0.353756, 20],
  376. [0.368062, 20],
  377. [-1.721840, 20],
  378. [1.000000, np.nan],
  379. [2.000000, np.nan],
  380. [3.000000, np.nan],
  381. ],
  382. index=df1.index,
  383. columns=["value1", "value2"],
  384. )
  385. # these must be the same results (but columns are flipped)
  386. tm.assert_frame_equal(df1.join(df2, how="left"), exp)
  387. tm.assert_frame_equal(df2.join(df1, how="right"), exp[["value2", "value1"]])
  388. exp_idx = MultiIndex.from_product(
  389. [["a", "b"], ["x", "y", "z"]], names=["first", "second"]
  390. )
  391. exp = DataFrame(
  392. [
  393. [0.471780, 10],
  394. [0.774908, 10],
  395. [0.563634, 10],
  396. [-0.353756, 20],
  397. [0.368062, 20],
  398. [-1.721840, 20],
  399. ],
  400. index=exp_idx,
  401. columns=["value1", "value2"],
  402. )
  403. tm.assert_frame_equal(df1.join(df2, how="right"), exp)
  404. tm.assert_frame_equal(df2.join(df1, how="left"), exp[["value2", "value1"]])
  405. def test_join_multiindex_dates(self):
  406. # GH 33692
  407. date = pd.Timestamp(2000, 1, 1).date()
  408. df1_index = MultiIndex.from_tuples([(0, date)], names=["index_0", "date"])
  409. df1 = DataFrame({"col1": [0]}, index=df1_index)
  410. df2_index = MultiIndex.from_tuples([(0, date)], names=["index_0", "date"])
  411. df2 = DataFrame({"col2": [0]}, index=df2_index)
  412. df3_index = MultiIndex.from_tuples([(0, date)], names=["index_0", "date"])
  413. df3 = DataFrame({"col3": [0]}, index=df3_index)
  414. result = df1.join([df2, df3])
  415. expected_index = MultiIndex.from_tuples([(0, date)], names=["index_0", "date"])
  416. expected = DataFrame(
  417. {"col1": [0], "col2": [0], "col3": [0]}, index=expected_index
  418. )
  419. tm.assert_equal(result, expected)
  420. def test_merge_join_different_levels_raises(self):
  421. # GH#9455
  422. # GH 40993: For raising, enforced in 2.0
  423. # first dataframe
  424. df1 = DataFrame(columns=["a", "b"], data=[[1, 11], [0, 22]])
  425. # second dataframe
  426. columns = MultiIndex.from_tuples([("a", ""), ("c", "c1")])
  427. df2 = DataFrame(columns=columns, data=[[1, 33], [0, 44]])
  428. # merge
  429. with pytest.raises(
  430. MergeError, match="Not allowed to merge between different levels"
  431. ):
  432. pd.merge(df1, df2, on="a")
  433. # join, see discussion in GH#12219
  434. with pytest.raises(
  435. MergeError, match="Not allowed to merge between different levels"
  436. ):
  437. df1.join(df2, on="a")
  438. def test_frame_join_tzaware(self):
  439. test1 = DataFrame(
  440. np.zeros((6, 3)),
  441. index=date_range(
  442. "2012-11-15 00:00:00", periods=6, freq="100L", tz="US/Central"
  443. ),
  444. )
  445. test2 = DataFrame(
  446. np.zeros((3, 3)),
  447. index=date_range(
  448. "2012-11-15 00:00:00", periods=3, freq="250L", tz="US/Central"
  449. ),
  450. columns=range(3, 6),
  451. )
  452. result = test1.join(test2, how="outer")
  453. expected = test1.index.union(test2.index)
  454. tm.assert_index_equal(result.index, expected)
  455. assert result.index.tz.zone == "US/Central"