# join.pyx
  1. cimport cython
  2. from cython cimport Py_ssize_t
  3. import numpy as np
  4. cimport numpy as cnp
  5. from numpy cimport (
  6. int64_t,
  7. intp_t,
  8. ndarray,
  9. uint64_t,
  10. )
  11. cnp.import_array()
  12. from pandas._libs.algos import groupsort_indexer
  13. from pandas._libs.dtypes cimport (
  14. numeric_object_t,
  15. numeric_t,
  16. )
  17. @cython.wraparound(False)
  18. @cython.boundscheck(False)
  19. def inner_join(const intp_t[:] left, const intp_t[:] right,
  20. Py_ssize_t max_groups):
  21. cdef:
  22. Py_ssize_t i, j, k, count = 0
  23. intp_t[::1] left_sorter, right_sorter
  24. intp_t[::1] left_count, right_count
  25. intp_t[::1] left_indexer, right_indexer
  26. intp_t lc, rc
  27. Py_ssize_t left_pos = 0, right_pos = 0, position = 0
  28. Py_ssize_t offset
  29. left_sorter, left_count = groupsort_indexer(left, max_groups)
  30. right_sorter, right_count = groupsort_indexer(right, max_groups)
  31. with nogil:
  32. # First pass, determine size of result set, do not use the NA group
  33. for i in range(1, max_groups + 1):
  34. lc = left_count[i]
  35. rc = right_count[i]
  36. if rc > 0 and lc > 0:
  37. count += lc * rc
  38. left_indexer = np.empty(count, dtype=np.intp)
  39. right_indexer = np.empty(count, dtype=np.intp)
  40. with nogil:
  41. # exclude the NA group
  42. left_pos = left_count[0]
  43. right_pos = right_count[0]
  44. for i in range(1, max_groups + 1):
  45. lc = left_count[i]
  46. rc = right_count[i]
  47. if rc > 0 and lc > 0:
  48. for j in range(lc):
  49. offset = position + j * rc
  50. for k in range(rc):
  51. left_indexer[offset + k] = left_pos + j
  52. right_indexer[offset + k] = right_pos + k
  53. position += lc * rc
  54. left_pos += lc
  55. right_pos += rc
  56. # Will overwrite left/right indexer with the result
  57. _get_result_indexer(left_sorter, left_indexer)
  58. _get_result_indexer(right_sorter, right_indexer)
  59. return np.asarray(left_indexer), np.asarray(right_indexer)
  60. @cython.wraparound(False)
  61. @cython.boundscheck(False)
  62. def left_outer_join(const intp_t[:] left, const intp_t[:] right,
  63. Py_ssize_t max_groups, bint sort=True):
  64. cdef:
  65. Py_ssize_t i, j, k, count = 0
  66. ndarray[intp_t] rev
  67. intp_t[::1] left_count, right_count
  68. intp_t[::1] left_sorter, right_sorter
  69. intp_t[::1] left_indexer, right_indexer
  70. intp_t lc, rc
  71. Py_ssize_t left_pos = 0, right_pos = 0, position = 0
  72. Py_ssize_t offset
  73. left_sorter, left_count = groupsort_indexer(left, max_groups)
  74. right_sorter, right_count = groupsort_indexer(right, max_groups)
  75. with nogil:
  76. # First pass, determine size of result set, do not use the NA group
  77. for i in range(1, max_groups + 1):
  78. lc = left_count[i]
  79. rc = right_count[i]
  80. if rc > 0:
  81. count += lc * rc
  82. else:
  83. count += lc
  84. left_indexer = np.empty(count, dtype=np.intp)
  85. right_indexer = np.empty(count, dtype=np.intp)
  86. with nogil:
  87. # exclude the NA group
  88. left_pos = left_count[0]
  89. right_pos = right_count[0]
  90. for i in range(1, max_groups + 1):
  91. lc = left_count[i]
  92. rc = right_count[i]
  93. if rc == 0:
  94. for j in range(lc):
  95. left_indexer[position + j] = left_pos + j
  96. right_indexer[position + j] = -1
  97. position += lc
  98. else:
  99. for j in range(lc):
  100. offset = position + j * rc
  101. for k in range(rc):
  102. left_indexer[offset + k] = left_pos + j
  103. right_indexer[offset + k] = right_pos + k
  104. position += lc * rc
  105. left_pos += lc
  106. right_pos += rc
  107. # Will overwrite left/right indexer with the result
  108. _get_result_indexer(left_sorter, left_indexer)
  109. _get_result_indexer(right_sorter, right_indexer)
  110. if not sort: # if not asked to sort, revert to original order
  111. if len(left) == len(left_indexer):
  112. # no multiple matches for any row on the left
  113. # this is a short-cut to avoid groupsort_indexer
  114. # otherwise, the `else` path also works in this case
  115. rev = np.empty(len(left), dtype=np.intp)
  116. rev.put(np.asarray(left_sorter), np.arange(len(left)))
  117. else:
  118. rev, _ = groupsort_indexer(left_indexer, len(left))
  119. return np.asarray(left_indexer).take(rev), np.asarray(right_indexer).take(rev)
  120. else:
  121. return np.asarray(left_indexer), np.asarray(right_indexer)
  122. @cython.wraparound(False)
  123. @cython.boundscheck(False)
  124. def full_outer_join(const intp_t[:] left, const intp_t[:] right,
  125. Py_ssize_t max_groups):
  126. cdef:
  127. Py_ssize_t i, j, k, count = 0
  128. intp_t[::1] left_sorter, right_sorter
  129. intp_t[::1] left_count, right_count
  130. intp_t[::1] left_indexer, right_indexer
  131. intp_t lc, rc
  132. intp_t left_pos = 0, right_pos = 0
  133. Py_ssize_t offset, position = 0
  134. left_sorter, left_count = groupsort_indexer(left, max_groups)
  135. right_sorter, right_count = groupsort_indexer(right, max_groups)
  136. with nogil:
  137. # First pass, determine size of result set, do not use the NA group
  138. for i in range(1, max_groups + 1):
  139. lc = left_count[i]
  140. rc = right_count[i]
  141. if rc > 0 and lc > 0:
  142. count += lc * rc
  143. else:
  144. count += lc + rc
  145. left_indexer = np.empty(count, dtype=np.intp)
  146. right_indexer = np.empty(count, dtype=np.intp)
  147. with nogil:
  148. # exclude the NA group
  149. left_pos = left_count[0]
  150. right_pos = right_count[0]
  151. for i in range(1, max_groups + 1):
  152. lc = left_count[i]
  153. rc = right_count[i]
  154. if rc == 0:
  155. for j in range(lc):
  156. left_indexer[position + j] = left_pos + j
  157. right_indexer[position + j] = -1
  158. position += lc
  159. elif lc == 0:
  160. for j in range(rc):
  161. left_indexer[position + j] = -1
  162. right_indexer[position + j] = right_pos + j
  163. position += rc
  164. else:
  165. for j in range(lc):
  166. offset = position + j * rc
  167. for k in range(rc):
  168. left_indexer[offset + k] = left_pos + j
  169. right_indexer[offset + k] = right_pos + k
  170. position += lc * rc
  171. left_pos += lc
  172. right_pos += rc
  173. # Will overwrite left/right indexer with the result
  174. _get_result_indexer(left_sorter, left_indexer)
  175. _get_result_indexer(right_sorter, right_indexer)
  176. return np.asarray(left_indexer), np.asarray(right_indexer)
  177. @cython.wraparound(False)
  178. @cython.boundscheck(False)
  179. cdef void _get_result_indexer(intp_t[::1] sorter, intp_t[::1] indexer) nogil:
  180. """NOTE: overwrites indexer with the result to avoid allocating another array"""
  181. cdef:
  182. Py_ssize_t i, n, idx
  183. if len(sorter) > 0:
  184. # cython-only equivalent to
  185. # `res = algos.take_nd(sorter, indexer, fill_value=-1)`
  186. n = indexer.shape[0]
  187. for i in range(n):
  188. idx = indexer[i]
  189. if idx == -1:
  190. indexer[i] = -1
  191. else:
  192. indexer[i] = sorter[idx]
  193. else:
  194. # length-0 case
  195. indexer[:] = -1
  196. @cython.wraparound(False)
  197. @cython.boundscheck(False)
  198. def ffill_indexer(const intp_t[:] indexer) -> np.ndarray:
  199. cdef:
  200. Py_ssize_t i, n = len(indexer)
  201. ndarray[intp_t] result
  202. intp_t val, last_obs
  203. result = np.empty(n, dtype=np.intp)
  204. last_obs = -1
  205. for i in range(n):
  206. val = indexer[i]
  207. if val == -1:
  208. result[i] = last_obs
  209. else:
  210. result[i] = val
  211. last_obs = val
  212. return result
  213. # ----------------------------------------------------------------------
  214. # left_join_indexer, inner_join_indexer, outer_join_indexer
  215. # ----------------------------------------------------------------------
  216. # Joins on ordered, unique indices
  217. # right might contain non-unique values
  218. @cython.wraparound(False)
  219. @cython.boundscheck(False)
  220. def left_join_indexer_unique(
  221. ndarray[numeric_object_t] left,
  222. ndarray[numeric_object_t] right
  223. ):
  224. """
  225. Both left and right are strictly monotonic increasing.
  226. """
  227. cdef:
  228. Py_ssize_t i, j, nleft, nright
  229. ndarray[intp_t] indexer
  230. numeric_object_t rval
  231. i = 0
  232. j = 0
  233. nleft = len(left)
  234. nright = len(right)
  235. indexer = np.empty(nleft, dtype=np.intp)
  236. while True:
  237. if i == nleft:
  238. break
  239. if j == nright:
  240. indexer[i] = -1
  241. i += 1
  242. continue
  243. rval = right[j]
  244. while i < nleft - 1 and left[i] == rval:
  245. indexer[i] = j
  246. i += 1
  247. if left[i] == rval:
  248. indexer[i] = j
  249. i += 1
  250. while i < nleft - 1 and left[i] == rval:
  251. indexer[i] = j
  252. i += 1
  253. j += 1
  254. elif left[i] > rval:
  255. indexer[i] = -1
  256. j += 1
  257. else:
  258. indexer[i] = -1
  259. i += 1
  260. return indexer
  261. @cython.wraparound(False)
  262. @cython.boundscheck(False)
  263. def left_join_indexer(ndarray[numeric_object_t] left, ndarray[numeric_object_t] right):
  264. """
  265. Two-pass algorithm for monotonic indexes. Handles many-to-one merges.
  266. Both left and right are monotonic increasing, but at least one of them
  267. is non-unique (if both were unique we'd use left_join_indexer_unique).
  268. """
  269. cdef:
  270. Py_ssize_t i, j, nright, nleft, count
  271. numeric_object_t lval, rval
  272. ndarray[intp_t] lindexer, rindexer
  273. ndarray[numeric_object_t] result
  274. nleft = len(left)
  275. nright = len(right)
  276. # First pass is to find the size 'count' of our output indexers.
  277. i = 0
  278. j = 0
  279. count = 0
  280. if nleft > 0:
  281. while i < nleft:
  282. if j == nright:
  283. count += nleft - i
  284. break
  285. lval = left[i]
  286. rval = right[j]
  287. if lval == rval:
  288. # This block is identical across
  289. # left_join_indexer, inner_join_indexer, outer_join_indexer
  290. count += 1
  291. if i < nleft - 1:
  292. if j < nright - 1 and right[j + 1] == rval:
  293. j += 1
  294. else:
  295. i += 1
  296. if left[i] != rval:
  297. j += 1
  298. elif j < nright - 1:
  299. j += 1
  300. if lval != right[j]:
  301. i += 1
  302. else:
  303. # end of the road
  304. break
  305. elif lval < rval:
  306. count += 1
  307. i += 1
  308. else:
  309. j += 1
  310. # do it again now that result size is known
  311. lindexer = np.empty(count, dtype=np.intp)
  312. rindexer = np.empty(count, dtype=np.intp)
  313. result = np.empty(count, dtype=left.dtype)
  314. i = 0
  315. j = 0
  316. count = 0
  317. if nleft > 0:
  318. while i < nleft:
  319. if j == nright:
  320. while i < nleft:
  321. lindexer[count] = i
  322. rindexer[count] = -1
  323. result[count] = left[i]
  324. i += 1
  325. count += 1
  326. break
  327. lval = left[i]
  328. rval = right[j]
  329. if lval == rval:
  330. lindexer[count] = i
  331. rindexer[count] = j
  332. result[count] = lval
  333. count += 1
  334. if i < nleft - 1:
  335. if j < nright - 1 and right[j + 1] == rval:
  336. j += 1
  337. else:
  338. i += 1
  339. if left[i] != rval:
  340. j += 1
  341. elif j < nright - 1:
  342. j += 1
  343. if lval != right[j]:
  344. i += 1
  345. else:
  346. # end of the road
  347. break
  348. elif lval < rval:
  349. # i.e. lval not in right; we keep for left_join_indexer
  350. lindexer[count] = i
  351. rindexer[count] = -1
  352. result[count] = lval
  353. count += 1
  354. i += 1
  355. else:
  356. # i.e. rval not in left; we discard for left_join_indexer
  357. j += 1
  358. return result, lindexer, rindexer
  359. @cython.wraparound(False)
  360. @cython.boundscheck(False)
  361. def inner_join_indexer(ndarray[numeric_object_t] left, ndarray[numeric_object_t] right):
  362. """
  363. Two-pass algorithm for monotonic indexes. Handles many-to-one merges.
  364. Both left and right are monotonic increasing but not necessarily unique.
  365. """
  366. cdef:
  367. Py_ssize_t i, j, nright, nleft, count
  368. numeric_object_t lval, rval
  369. ndarray[intp_t] lindexer, rindexer
  370. ndarray[numeric_object_t] result
  371. nleft = len(left)
  372. nright = len(right)
  373. # First pass is to find the size 'count' of our output indexers.
  374. i = 0
  375. j = 0
  376. count = 0
  377. if nleft > 0 and nright > 0:
  378. while True:
  379. if i == nleft:
  380. break
  381. if j == nright:
  382. break
  383. lval = left[i]
  384. rval = right[j]
  385. if lval == rval:
  386. count += 1
  387. if i < nleft - 1:
  388. if j < nright - 1 and right[j + 1] == rval:
  389. j += 1
  390. else:
  391. i += 1
  392. if left[i] != rval:
  393. j += 1
  394. elif j < nright - 1:
  395. j += 1
  396. if lval != right[j]:
  397. i += 1
  398. else:
  399. # end of the road
  400. break
  401. elif lval < rval:
  402. # i.e. lval not in right; we discard for inner_indexer
  403. i += 1
  404. else:
  405. # i.e. rval not in left; we discard for inner_indexer
  406. j += 1
  407. # do it again now that result size is known
  408. lindexer = np.empty(count, dtype=np.intp)
  409. rindexer = np.empty(count, dtype=np.intp)
  410. result = np.empty(count, dtype=left.dtype)
  411. i = 0
  412. j = 0
  413. count = 0
  414. if nleft > 0 and nright > 0:
  415. while True:
  416. if i == nleft:
  417. break
  418. if j == nright:
  419. break
  420. lval = left[i]
  421. rval = right[j]
  422. if lval == rval:
  423. lindexer[count] = i
  424. rindexer[count] = j
  425. result[count] = lval
  426. count += 1
  427. if i < nleft - 1:
  428. if j < nright - 1 and right[j + 1] == rval:
  429. j += 1
  430. else:
  431. i += 1
  432. if left[i] != rval:
  433. j += 1
  434. elif j < nright - 1:
  435. j += 1
  436. if lval != right[j]:
  437. i += 1
  438. else:
  439. # end of the road
  440. break
  441. elif lval < rval:
  442. # i.e. lval not in right; we discard for inner_indexer
  443. i += 1
  444. else:
  445. # i.e. rval not in left; we discard for inner_indexer
  446. j += 1
  447. return result, lindexer, rindexer
  448. @cython.wraparound(False)
  449. @cython.boundscheck(False)
  450. def outer_join_indexer(ndarray[numeric_object_t] left, ndarray[numeric_object_t] right):
  451. """
  452. Both left and right are monotonic increasing but not necessarily unique.
  453. """
  454. cdef:
  455. Py_ssize_t i, j, nright, nleft, count
  456. numeric_object_t lval, rval
  457. ndarray[intp_t] lindexer, rindexer
  458. ndarray[numeric_object_t] result
  459. nleft = len(left)
  460. nright = len(right)
  461. # First pass is to find the size 'count' of our output indexers.
  462. # count will be length of left plus the number of elements of right not in
  463. # left (counting duplicates)
  464. i = 0
  465. j = 0
  466. count = 0
  467. if nleft == 0:
  468. count = nright
  469. elif nright == 0:
  470. count = nleft
  471. else:
  472. while True:
  473. if i == nleft:
  474. count += nright - j
  475. break
  476. if j == nright:
  477. count += nleft - i
  478. break
  479. lval = left[i]
  480. rval = right[j]
  481. if lval == rval:
  482. count += 1
  483. if i < nleft - 1:
  484. if j < nright - 1 and right[j + 1] == rval:
  485. j += 1
  486. else:
  487. i += 1
  488. if left[i] != rval:
  489. j += 1
  490. elif j < nright - 1:
  491. j += 1
  492. if lval != right[j]:
  493. i += 1
  494. else:
  495. # end of the road
  496. break
  497. elif lval < rval:
  498. count += 1
  499. i += 1
  500. else:
  501. count += 1
  502. j += 1
  503. lindexer = np.empty(count, dtype=np.intp)
  504. rindexer = np.empty(count, dtype=np.intp)
  505. result = np.empty(count, dtype=left.dtype)
  506. # do it again, but populate the indexers / result
  507. i = 0
  508. j = 0
  509. count = 0
  510. if nleft == 0:
  511. for j in range(nright):
  512. lindexer[j] = -1
  513. rindexer[j] = j
  514. result[j] = right[j]
  515. elif nright == 0:
  516. for i in range(nleft):
  517. lindexer[i] = i
  518. rindexer[i] = -1
  519. result[i] = left[i]
  520. else:
  521. while True:
  522. if i == nleft:
  523. while j < nright:
  524. lindexer[count] = -1
  525. rindexer[count] = j
  526. result[count] = right[j]
  527. count += 1
  528. j += 1
  529. break
  530. if j == nright:
  531. while i < nleft:
  532. lindexer[count] = i
  533. rindexer[count] = -1
  534. result[count] = left[i]
  535. count += 1
  536. i += 1
  537. break
  538. lval = left[i]
  539. rval = right[j]
  540. if lval == rval:
  541. lindexer[count] = i
  542. rindexer[count] = j
  543. result[count] = lval
  544. count += 1
  545. if i < nleft - 1:
  546. if j < nright - 1 and right[j + 1] == rval:
  547. j += 1
  548. else:
  549. i += 1
  550. if left[i] != rval:
  551. j += 1
  552. elif j < nright - 1:
  553. j += 1
  554. if lval != right[j]:
  555. i += 1
  556. else:
  557. # end of the road
  558. break
  559. elif lval < rval:
  560. # i.e. lval not in right; we keep for outer_join_indexer
  561. lindexer[count] = i
  562. rindexer[count] = -1
  563. result[count] = lval
  564. count += 1
  565. i += 1
  566. else:
  567. # i.e. rval not in left; we keep for outer_join_indexer
  568. lindexer[count] = -1
  569. rindexer[count] = j
  570. result[count] = rval
  571. count += 1
  572. j += 1
  573. return result, lindexer, rindexer
  574. # ----------------------------------------------------------------------
  575. # asof_join_by
  576. # ----------------------------------------------------------------------
  577. from pandas._libs.hashtable cimport (
  578. HashTable,
  579. Int64HashTable,
  580. PyObjectHashTable,
  581. UInt64HashTable,
  582. )
  583. ctypedef fused by_t:
  584. object
  585. int64_t
  586. uint64_t
  587. def asof_join_backward_on_X_by_Y(numeric_t[:] left_values,
  588. numeric_t[:] right_values,
  589. by_t[:] left_by_values,
  590. by_t[:] right_by_values,
  591. bint allow_exact_matches=True,
  592. tolerance=None,
  593. bint use_hashtable=True):
  594. cdef:
  595. Py_ssize_t left_pos, right_pos, left_size, right_size, found_right_pos
  596. ndarray[intp_t] left_indexer, right_indexer
  597. bint has_tolerance = False
  598. numeric_t tolerance_ = 0
  599. numeric_t diff = 0
  600. HashTable hash_table
  601. by_t by_value
  602. # if we are using tolerance, set our objects
  603. if tolerance is not None:
  604. has_tolerance = True
  605. tolerance_ = tolerance
  606. left_size = len(left_values)
  607. right_size = len(right_values)
  608. left_indexer = np.empty(left_size, dtype=np.intp)
  609. right_indexer = np.empty(left_size, dtype=np.intp)
  610. if use_hashtable:
  611. if by_t is object:
  612. hash_table = PyObjectHashTable(right_size)
  613. elif by_t is int64_t:
  614. hash_table = Int64HashTable(right_size)
  615. elif by_t is uint64_t:
  616. hash_table = UInt64HashTable(right_size)
  617. right_pos = 0
  618. for left_pos in range(left_size):
  619. # restart right_pos if it went negative in a previous iteration
  620. if right_pos < 0:
  621. right_pos = 0
  622. # find last position in right whose value is less than left's
  623. if allow_exact_matches:
  624. while (right_pos < right_size and
  625. right_values[right_pos] <= left_values[left_pos]):
  626. if use_hashtable:
  627. hash_table.set_item(right_by_values[right_pos], right_pos)
  628. right_pos += 1
  629. else:
  630. while (right_pos < right_size and
  631. right_values[right_pos] < left_values[left_pos]):
  632. if use_hashtable:
  633. hash_table.set_item(right_by_values[right_pos], right_pos)
  634. right_pos += 1
  635. right_pos -= 1
  636. # save positions as the desired index
  637. if use_hashtable:
  638. by_value = left_by_values[left_pos]
  639. found_right_pos = (hash_table.get_item(by_value)
  640. if by_value in hash_table else -1)
  641. else:
  642. found_right_pos = right_pos
  643. left_indexer[left_pos] = left_pos
  644. right_indexer[left_pos] = found_right_pos
  645. # if needed, verify that tolerance is met
  646. if has_tolerance and found_right_pos != -1:
  647. diff = left_values[left_pos] - right_values[found_right_pos]
  648. if diff > tolerance_:
  649. right_indexer[left_pos] = -1
  650. return left_indexer, right_indexer
  651. def asof_join_forward_on_X_by_Y(numeric_t[:] left_values,
  652. numeric_t[:] right_values,
  653. by_t[:] left_by_values,
  654. by_t[:] right_by_values,
  655. bint allow_exact_matches=1,
  656. tolerance=None,
  657. bint use_hashtable=True):
  658. cdef:
  659. Py_ssize_t left_pos, right_pos, left_size, right_size, found_right_pos
  660. ndarray[intp_t] left_indexer, right_indexer
  661. bint has_tolerance = False
  662. numeric_t tolerance_ = 0
  663. numeric_t diff = 0
  664. HashTable hash_table
  665. by_t by_value
  666. # if we are using tolerance, set our objects
  667. if tolerance is not None:
  668. has_tolerance = True
  669. tolerance_ = tolerance
  670. left_size = len(left_values)
  671. right_size = len(right_values)
  672. left_indexer = np.empty(left_size, dtype=np.intp)
  673. right_indexer = np.empty(left_size, dtype=np.intp)
  674. if use_hashtable:
  675. if by_t is object:
  676. hash_table = PyObjectHashTable(right_size)
  677. elif by_t is int64_t:
  678. hash_table = Int64HashTable(right_size)
  679. elif by_t is uint64_t:
  680. hash_table = UInt64HashTable(right_size)
  681. right_pos = right_size - 1
  682. for left_pos in range(left_size - 1, -1, -1):
  683. # restart right_pos if it went over in a previous iteration
  684. if right_pos == right_size:
  685. right_pos = right_size - 1
  686. # find first position in right whose value is greater than left's
  687. if allow_exact_matches:
  688. while (right_pos >= 0 and
  689. right_values[right_pos] >= left_values[left_pos]):
  690. if use_hashtable:
  691. hash_table.set_item(right_by_values[right_pos], right_pos)
  692. right_pos -= 1
  693. else:
  694. while (right_pos >= 0 and
  695. right_values[right_pos] > left_values[left_pos]):
  696. if use_hashtable:
  697. hash_table.set_item(right_by_values[right_pos], right_pos)
  698. right_pos -= 1
  699. right_pos += 1
  700. # save positions as the desired index
  701. if use_hashtable:
  702. by_value = left_by_values[left_pos]
  703. found_right_pos = (hash_table.get_item(by_value)
  704. if by_value in hash_table else -1)
  705. else:
  706. found_right_pos = (right_pos
  707. if right_pos != right_size else -1)
  708. left_indexer[left_pos] = left_pos
  709. right_indexer[left_pos] = found_right_pos
  710. # if needed, verify that tolerance is met
  711. if has_tolerance and found_right_pos != -1:
  712. diff = right_values[found_right_pos] - left_values[left_pos]
  713. if diff > tolerance_:
  714. right_indexer[left_pos] = -1
  715. return left_indexer, right_indexer
  716. def asof_join_nearest_on_X_by_Y(ndarray[numeric_t] left_values,
  717. ndarray[numeric_t] right_values,
  718. ndarray[by_t] left_by_values,
  719. ndarray[by_t] right_by_values,
  720. bint allow_exact_matches=True,
  721. tolerance=None,
  722. bint use_hashtable=True):
  723. cdef:
  724. ndarray[intp_t] bli, bri, fli, fri
  725. ndarray[intp_t] left_indexer, right_indexer
  726. Py_ssize_t left_size, i
  727. numeric_t bdiff, fdiff
  728. # search both forward and backward
  729. # TODO(cython3):
  730. # Bug in beta1 preventing Cython from choosing
  731. # right specialization when one fused memview is None
  732. # Doesn't matter what type we choose
  733. # (nothing happens anyways since it is None)
  734. # GH 51640
  735. if left_by_values is not None and left_by_values.dtype != object:
  736. by_dtype = f"{left_by_values.dtype}_t"
  737. else:
  738. by_dtype = object
  739. bli, bri = asof_join_backward_on_X_by_Y[f"{left_values.dtype}_t", by_dtype](
  740. left_values,
  741. right_values,
  742. left_by_values,
  743. right_by_values,
  744. allow_exact_matches,
  745. tolerance,
  746. use_hashtable
  747. )
  748. fli, fri = asof_join_forward_on_X_by_Y[f"{left_values.dtype}_t", by_dtype](
  749. left_values,
  750. right_values,
  751. left_by_values,
  752. right_by_values,
  753. allow_exact_matches,
  754. tolerance,
  755. use_hashtable
  756. )
  757. # choose the smaller timestamp
  758. left_size = len(left_values)
  759. left_indexer = np.empty(left_size, dtype=np.intp)
  760. right_indexer = np.empty(left_size, dtype=np.intp)
  761. for i in range(len(bri)):
  762. # choose timestamp from right with smaller difference
  763. if bri[i] != -1 and fri[i] != -1:
  764. bdiff = left_values[bli[i]] - right_values[bri[i]]
  765. fdiff = right_values[fri[i]] - left_values[fli[i]]
  766. right_indexer[i] = bri[i] if bdiff <= fdiff else fri[i]
  767. else:
  768. right_indexer[i] = bri[i] if bri[i] != -1 else fri[i]
  769. left_indexer[i] = bli[i]
  770. return left_indexer, right_indexer