hashtable_func_helper.pxi.in 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484
  """
  Templates for the dtype-specialised hashtable helper functions.

  WARNING: DO NOT edit .pxi FILE directly, .pxi is generated from .pxi.in
  """

  {{py:

  # One row per supported dtype; each row instantiates the templated
  # value_count_/duplicated_/ismember_ helpers below once.
  #   name      : prefix of the matching ``<name>Vector`` class
  #   dtype     : suffix of the generated function names; for non-object
  #               rows also the element type ``<dtype>_t``
  #   ttype     : khash table flavour (``kh_<ttype>_t``, ``kh_get_<ttype>``, ...)
  #   c_type    : C type of a key as held while hashing
  #   to_c_type : conversion applied to an element before hashing
  #               ('' means the element is used as-is)
  dtypes = [('Complex128', 'complex128', 'complex128',
                           'khcomplex128_t', 'to_khcomplex128_t'),
            ('Complex64', 'complex64', 'complex64',
                          'khcomplex64_t', 'to_khcomplex64_t'),
            ('Float64', 'float64', 'float64', 'float64_t', ''),
            ('Float32', 'float32', 'float32', 'float32_t', ''),
            ('UInt64', 'uint64', 'uint64', 'uint64_t', ''),
            ('UInt32', 'uint32', 'uint32', 'uint32_t', ''),
            ('UInt16', 'uint16', 'uint16', 'uint16_t', ''),
            ('UInt8', 'uint8', 'uint8', 'uint8_t', ''),
            ('Object', 'object', 'pymap', 'object', '<PyObject*>'),
            ('Int64', 'int64', 'int64', 'int64_t', ''),
            ('Int32', 'int32', 'int32', 'int32_t', ''),
            ('Int16', 'int16', 'int16', 'int16_t', ''),
            ('Int8', 'int8', 'int8', 'int8_t', '')]

  }}
  {{for name, dtype, ttype, c_type, to_c_type in dtypes}}


  @cython.wraparound(False)
  @cython.boundscheck(False)
  {{if dtype == 'object'}}
  cdef value_count_{{dtype}}(ndarray[{{dtype}}] values, bint dropna, const uint8_t[:] mask=None):
  {{else}}
  cdef value_count_{{dtype}}(const {{dtype}}_t[:] values, bint dropna, const uint8_t[:] mask=None):
  {{endif}}
      """
      Count how often each distinct value occurs in ``values``.

      Parameters
      ----------
      values : {{dtype}} array
      dropna : bool
          If True, missing entries are skipped and not counted.
      mask : uint8 array-like, optional
          Nonzero entries mark missing values; used instead of a NaN check.
          Only supported together with ``dropna=True`` and for non-object
          dtypes.

      Returns
      -------
      keys : ndarray
          Distinct values, in order of first appearance.
      counts : ndarray[int64]
          ``counts[i]`` is the number of occurrences of ``keys[i]``.

      Raises
      ------
      NotImplementedError
          If ``mask`` is given together with ``dropna=False``, or together
          with object-dtype ``values``.
      """
      cdef:
          Py_ssize_t i = 0
          Py_ssize_t n = len(values)
          kh_{{ttype}}_t *table

          # Don't use Py_ssize_t, since table.n_buckets is unsigned
          khiter_t k

          {{c_type}} val

          int ret = 0
          bint uses_mask = mask is not None
          bint isna_entry = False

      # Reject unsupported argument combinations *before* the hash table is
      # allocated: kh_destroy is only reached on the normal path, so raising
      # after kh_init would leak the table.
      if uses_mask and not dropna:
          raise NotImplementedError("uses_mask not implemented with dropna=False")
      {{if dtype == 'object'}}
      if uses_mask:
          raise NotImplementedError("uses_mask not implemented with object dtype")
      {{endif}}

      # we track the order in which keys are first seen (GH39009),
      # khash-map isn't insertion-ordered, thus:
      #    table maps keys to counts
      #    result_keys remembers the original order of keys
      result_keys = {{name}}Vector()
      table = kh_init_{{ttype}}()

      {{if dtype == 'object'}}
      kh_resize_{{ttype}}(table, n // 10)

      for i in range(n):
          val = values[i]
          # object dtype: missing values are detected with checknull
          if not dropna or not checknull(val):
              k = kh_get_{{ttype}}(table, {{to_c_type}}val)
              if k != table.n_buckets:
                  table.vals[k] += 1
              else:
                  k = kh_put_{{ttype}}(table, {{to_c_type}}val, &ret)
                  table.vals[k] = 1
                  result_keys.append(val)
      {{else}}
      kh_resize_{{ttype}}(table, n)

      for i in range(n):
          val = {{to_c_type}}(values[i])

          # isna_entry is only computed (and only consulted) when dropna is set
          if dropna:
              if uses_mask:
                  isna_entry = mask[i]
              else:
                  isna_entry = is_nan_{{c_type}}(val)

          if not dropna or not isna_entry:
              k = kh_get_{{ttype}}(table, val)
              if k != table.n_buckets:
                  table.vals[k] += 1
              else:
                  k = kh_put_{{ttype}}(table, val, &ret)
                  table.vals[k] = 1
                  result_keys.append(val)
      {{endif}}

      # collect counts in the order corresponding to result_keys:
      cdef:
          int64_t[::1] result_counts = np.empty(table.size, dtype=np.int64)

      for i in range(table.size):
          {{if dtype == 'object'}}
          k = kh_get_{{ttype}}(table, result_keys.data[i])
          {{else}}
          k = kh_get_{{ttype}}(table, result_keys.data.data[i])
          {{endif}}
          result_counts[i] = table.vals[k]

      kh_destroy_{{ttype}}(table)

      # .base hands back the ndarray that the memoryview wraps
      return result_keys.to_array(), result_counts.base
  @cython.wraparound(False)
  @cython.boundscheck(False)
  {{if dtype == 'object'}}
  cdef duplicated_{{dtype}}(ndarray[{{dtype}}] values, object keep='first', const uint8_t[:] mask=None):
  {{else}}
  cdef duplicated_{{dtype}}(const {{dtype}}_t[:] values, object keep='first', const uint8_t[:] mask=None):
  {{endif}}
      """
      Flag entries of ``values`` that duplicate another entry.

      Parameters
      ----------
      values : {{dtype}} array
      keep : {'first', 'last', False}, default 'first'
          - 'first' : every occurrence except the first one is flagged.
          - 'last' : every occurrence except the last one is flagged.
          - False : every occurrence of a repeated value is flagged.
      mask : uint8 array-like, optional
          Nonzero entries mark missing values. All masked entries are treated
          as one and the same value, regardless of the data in ``values``.

      Returns
      -------
      ndarray[bool] of len(values)

      Raises
      ------
      ValueError
          If ``keep`` is not 'first', 'last' or False.
      """
      cdef:
          int ret = 0
          {{if dtype != 'object'}}
          {{c_type}} value
          {{else}}
          PyObject* value
          {{endif}}
          Py_ssize_t i, n = len(values), first_na = -1
          khiter_t k
          kh_{{ttype}}_t *table
          ndarray[uint8_t, ndim=1, cast=True] out = np.empty(n, dtype='bool')
          bint seen_na = False, uses_mask = mask is not None
          bint seen_multiple_na = False

      # Validate before allocating the hash table: kh_destroy is only reached
      # on the normal path, so raising after kh_init would leak the table.
      if keep not in ('last', 'first', False):
          raise ValueError('keep must be either "first", "last" or False')

      table = kh_init_{{ttype}}()
      kh_resize_{{ttype}}(table, min(kh_needed_n_buckets(n), SIZE_HINT_LIMIT))

      # 'last' and 'first' share one body; only the scan direction differs.
      # Object dtype must hold the GIL, so ``if True:`` stands in for
      # ``with nogil:`` to keep the loop body at the same indentation.
      {{for cond, keep in [('if', '"last"'), ('elif', '"first"')]}}
      {{cond}} keep == {{keep}}:
          {{if dtype == 'object'}}
          if True:
          {{else}}
          with nogil:
          {{endif}}
              {{if keep == '"last"'}}
              # scan backwards so the *last* occurrence is the unflagged one
              for i in range(n - 1, -1, -1):
              {{else}}
              for i in range(n):
              {{endif}}
                  if uses_mask and mask[i]:
                      # the first NA met in scan order is kept, later ones flagged
                      if seen_na:
                          out[i] = True
                      else:
                          out[i] = False
                          seen_na = True
                  else:
                      value = {{to_c_type}}(values[i])
                      # kh_put leaves ret == 0 when the key was already present
                      kh_put_{{ttype}}(table, value, &ret)
                      out[i] = ret == 0
      {{endfor}}

      else:
          {{if dtype == 'object'}}
          if True:
          {{else}}
          with nogil:
          {{endif}}
              for i in range(n):
                  if uses_mask and mask[i]:
                      if not seen_na:
                          first_na = i
                          seen_na = True
                          out[i] = 0
                      elif not seen_multiple_na:
                          # second NA: flag it and, retroactively, the first one
                          out[i] = 1
                          out[first_na] = 1
                          seen_multiple_na = True
                      else:
                          out[i] = 1

                  else:
                      value = {{to_c_type}}(values[i])
                      k = kh_get_{{ttype}}(table, value)
                      if k != table.n_buckets:
                          # table.vals stores the position of the first occurrence
                          out[table.vals[k]] = 1
                          out[i] = 1
                      else:
                          k = kh_put_{{ttype}}(table, value, &ret)
                          table.vals[k] = i
                          out[i] = 0

      kh_destroy_{{ttype}}(table)
      return out
  # ----------------------------------------------------------------------
  # Membership
  # ----------------------------------------------------------------------


  @cython.wraparound(False)
  @cython.boundscheck(False)
  {{if dtype == 'object'}}
  cdef ismember_{{dtype}}(ndarray[{{c_type}}] arr, ndarray[{{c_type}}] values):
  {{else}}
  cdef ismember_{{dtype}}(const {{dtype}}_t[:] arr, const {{dtype}}_t[:] values):
  {{endif}}
      """
      Return a boolean mask telling, element by element, whether each entry
      of ``arr`` is contained in ``values``.

      Parameters
      ----------
      arr : {{dtype}} ndarray
      values : {{dtype}} ndarray

      Returns
      -------
      ndarray[bool] of len(arr)
          ``result[i]`` is True iff ``arr[i]`` occurs in ``values``.
      """
      cdef:
          Py_ssize_t i, n
          khiter_t k
          int ret = 0
          ndarray[uint8_t] result

          {{if dtype == "object"}}
          PyObject* val
          {{else}}
          {{c_type}} val
          {{endif}}

          kh_{{ttype}}_t *table

      # Allocate the output before the hash table, so that a failing
      # allocation cannot leak an already-created table.
      result = np.empty(len(arr), dtype=np.uint8)

      # construct the table from ``values``
      table = kh_init_{{ttype}}()
      n = len(values)
      kh_resize_{{ttype}}(table, n)

      # Object dtype must hold the GIL, so ``if True:`` stands in for
      # ``with nogil:`` to keep the loop body at the same indentation.
      {{if dtype == 'object'}}
      if True:
      {{else}}
      with nogil:
      {{endif}}
          for i in range(n):
              val = {{to_c_type}}(values[i])
              kh_put_{{ttype}}(table, val, &ret)

      # test membership of every element of ``arr``
      n = len(arr)

      {{if dtype == 'object'}}
      if True:
      {{else}}
      with nogil:
      {{endif}}
          for i in range(n):
              val = {{to_c_type}}(arr[i])
              k = kh_get_{{ttype}}(table, val)
              result[i] = (k != table.n_buckets)

      kh_destroy_{{ttype}}(table)
      return result.view(np.bool_)

  # ----------------------------------------------------------------------
  # Mode Computations
  # ----------------------------------------------------------------------

  {{endfor}}
  # Element types accepted by the fused-type dispatchers below: the members
  # of ``numeric_object_t`` (declared elsewhere) plus the two complex types.
  ctypedef fused htfunc_t:
      numeric_object_t
      complex128_t
      complex64_t
  cpdef value_count(ndarray[htfunc_t] values, bint dropna, const uint8_t[:] mask=None):
      """
      Count occurrences of each distinct value in ``values``.

      Forwards to the matching ``value_count_<dtype>`` helper. Each
      ``htfunc_t is ...`` test is resolved at compile time, so every fused
      specialisation contains a single call.

      Returns
      -------
      tuple of (keys ndarray, int64 counts ndarray)
          Distinct values in order of first appearance, and their counts.
      """
      if htfunc_t is object:
          return value_count_object(values, dropna, mask=mask)

      elif htfunc_t is int8_t:
          return value_count_int8(values, dropna, mask=mask)
      elif htfunc_t is int16_t:
          return value_count_int16(values, dropna, mask=mask)
      elif htfunc_t is int32_t:
          return value_count_int32(values, dropna, mask=mask)
      elif htfunc_t is int64_t:
          return value_count_int64(values, dropna, mask=mask)

      elif htfunc_t is uint8_t:
          return value_count_uint8(values, dropna, mask=mask)
      elif htfunc_t is uint16_t:
          return value_count_uint16(values, dropna, mask=mask)
      elif htfunc_t is uint32_t:
          return value_count_uint32(values, dropna, mask=mask)
      elif htfunc_t is uint64_t:
          return value_count_uint64(values, dropna, mask=mask)

      elif htfunc_t is float64_t:
          return value_count_float64(values, dropna, mask=mask)
      elif htfunc_t is float32_t:
          return value_count_float32(values, dropna, mask=mask)

      elif htfunc_t is complex128_t:
          return value_count_complex128(values, dropna, mask=mask)
      elif htfunc_t is complex64_t:
          return value_count_complex64(values, dropna, mask=mask)

      else:
          # fallback for any htfunc_t member without a dedicated helper
          raise TypeError(values.dtype)
  cpdef duplicated(ndarray[htfunc_t] values, object keep="first", const uint8_t[:] mask=None):
      """
      Flag duplicated entries of ``values``.

      Forwards to the matching ``duplicated_<dtype>`` helper. Each
      ``htfunc_t is ...`` test is resolved at compile time, so every fused
      specialisation contains a single call.

      Parameters
      ----------
      values : ndarray
      keep : {'first', 'last', False}, default 'first'
          Which occurrence(s) of a repeated value stay unflagged.
      mask : uint8 array-like, optional
          Nonzero entries mark missing values.

      Returns
      -------
      ndarray[bool] of len(values)
      """
      if htfunc_t is object:
          return duplicated_object(values, keep, mask=mask)

      elif htfunc_t is int8_t:
          return duplicated_int8(values, keep, mask=mask)
      elif htfunc_t is int16_t:
          return duplicated_int16(values, keep, mask=mask)
      elif htfunc_t is int32_t:
          return duplicated_int32(values, keep, mask=mask)
      elif htfunc_t is int64_t:
          return duplicated_int64(values, keep, mask=mask)

      elif htfunc_t is uint8_t:
          return duplicated_uint8(values, keep, mask=mask)
      elif htfunc_t is uint16_t:
          return duplicated_uint16(values, keep, mask=mask)
      elif htfunc_t is uint32_t:
          return duplicated_uint32(values, keep, mask=mask)
      elif htfunc_t is uint64_t:
          return duplicated_uint64(values, keep, mask=mask)

      elif htfunc_t is float64_t:
          return duplicated_float64(values, keep, mask=mask)
      elif htfunc_t is float32_t:
          return duplicated_float32(values, keep, mask=mask)

      elif htfunc_t is complex128_t:
          return duplicated_complex128(values, keep, mask=mask)
      elif htfunc_t is complex64_t:
          return duplicated_complex64(values, keep, mask=mask)

      else:
          # fallback for any htfunc_t member without a dedicated helper
          raise TypeError(values.dtype)
  cpdef ismember(ndarray[htfunc_t] arr, ndarray[htfunc_t] values):
      """
      Return a boolean mask of which elements of ``arr`` occur in ``values``.

      Forwards to the matching ``ismember_<dtype>`` helper. Each
      ``htfunc_t is ...`` test is resolved at compile time, so every fused
      specialisation contains a single call.

      Returns
      -------
      ndarray[bool] of len(arr)
      """
      if htfunc_t is object:
          return ismember_object(arr, values)

      elif htfunc_t is int8_t:
          return ismember_int8(arr, values)
      elif htfunc_t is int16_t:
          return ismember_int16(arr, values)
      elif htfunc_t is int32_t:
          return ismember_int32(arr, values)
      elif htfunc_t is int64_t:
          return ismember_int64(arr, values)

      elif htfunc_t is uint8_t:
          return ismember_uint8(arr, values)
      elif htfunc_t is uint16_t:
          return ismember_uint16(arr, values)
      elif htfunc_t is uint32_t:
          return ismember_uint32(arr, values)
      elif htfunc_t is uint64_t:
          return ismember_uint64(arr, values)

      elif htfunc_t is float64_t:
          return ismember_float64(arr, values)
      elif htfunc_t is float32_t:
          return ismember_float32(arr, values)

      elif htfunc_t is complex128_t:
          return ismember_complex128(arr, values)
      elif htfunc_t is complex64_t:
          return ismember_complex64(arr, values)

      else:
          # fallback for any htfunc_t member without a dedicated helper
          raise TypeError(values.dtype)
  @cython.wraparound(False)
  @cython.boundscheck(False)
  def mode(ndarray[htfunc_t] values, bint dropna, const uint8_t[:] mask=None):
      """
      Return the most frequent value(s) of ``values``.

      Parameters
      ----------
      values : ndarray
      dropna : bool
          Forwarded to ``value_count``; if True, missing entries are ignored.
      mask : uint8 array-like, optional
          Forwarded to ``value_count``.

      Returns
      -------
      ndarray
          Every value whose count equals the maximum count, in the order in
          which ``value_count`` reports them (first appearance), with the
          dtype of ``values``. Empty if no value was counted.
      """
      # TODO(cython3): use const htfunc_t[:]

      cdef:
          ndarray[htfunc_t] keys
          ndarray[htfunc_t] modes

          int64_t[::1] counts
          int64_t count, max_count = -1
          Py_ssize_t nkeys, k, j = 0

      keys, counts = value_count(values, dropna, mask=mask)
      nkeys = len(keys)

      modes = np.empty(nkeys, dtype=values.dtype)

      # Single pass over the counts; ``j`` is the slot of the last mode written.
      # A strictly larger count restarts the list at slot 0, an equal count
      # appends, a smaller count is skipped. Object dtype needs the GIL, so
      # the same loop is duplicated outside ``nogil`` for it.
      if htfunc_t is not object:
          with nogil:
              for k in range(nkeys):
                  count = counts[k]
                  if count == max_count:
                      j += 1
                  elif count > max_count:
                      max_count = count
                      j = 0
                  else:
                      continue

                  modes[j] = keys[k]
      else:
          for k in range(nkeys):
              count = counts[k]
              if count == max_count:
                  j += 1
              elif count > max_count:
                  max_count = count
                  j = 0
              else:
                  continue

              modes[j] = keys[k]

      # with nkeys == 0 this slices an empty array and returns it empty
      return modes[:j + 1]
  {{py:

  # name, dtype, ttype, c_type
  # Only int64 and int32 labels get a _unique_label_indices_<dtype> helper.
  dtypes = [('Int64', 'int64', 'int64', 'int64_t'),
            ('Int32', 'int32', 'int32', 'int32_t'), ]

  }}

  {{for name, dtype, ttype, c_type in dtypes}}


  @cython.wraparound(False)
  @cython.boundscheck(False)
  def _unique_label_indices_{{dtype}}(const {{c_type}}[:] labels) -> ndarray:
      """
      Indices of the first occurrences of the unique labels *excluding* -1.

      The result is ordered by label value, i.e. it is equivalent to::

          np.unique(labels, return_index=True)[1]

      with the index belonging to -1 (if present) removed.

      NOTE(review): only the smallest label is checked against -1, so this
      assumes every label is >= -1 -- confirm with callers.
      """
      cdef:
          int ret = 0
          Py_ssize_t i, n = len(labels)
          kh_{{ttype}}_t *table = kh_init_{{ttype}}()
          {{name}}Vector idx = {{name}}Vector()
          ndarray[{{c_type}}, ndim=1] arr
          {{name}}VectorData *ud = idx.data

      kh_resize_{{ttype}}(table, min(kh_needed_n_buckets(n), SIZE_HINT_LIMIT))

      # Record the position of each label the first time it is inserted
      # (kh_put leaves ret != 0 when it added a new key).
      # NOTE(review): positions are appended to a {{name}}Vector, so in the
      # Int32 variant positions beyond the int32 range would not fit --
      # confirm label arrays are never that long.
      with nogil:
          for i in range(n):
              kh_put_{{ttype}}(table, labels[i], &ret)
              if ret != 0:
                  if needs_resize(ud):
                      # idx is a Python-level object; resizing it needs the GIL
                      with gil:
                          idx.resize()
                  append_data_{{ttype}}(ud, i)

      kh_destroy_{{ttype}}(table)

      # order the first-occurrence positions by the label found there
      arr = idx.to_array()
      arr = arr[np.asarray(labels)[arr].argsort()]

      # after sorting, a -1 label (if present) sits in front; drop it
      return arr[1:] if arr.size != 0 and labels[arr[0]] == -1 else arr

  {{endfor}}