test_hashtable.py 25 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734
  1. from contextlib import contextmanager
  2. import re
  3. import struct
  4. import tracemalloc
  5. from typing import Generator
  6. import numpy as np
  7. import pytest
  8. from pandas._libs import hashtable as ht
  9. import pandas as pd
  10. import pandas._testing as tm
  11. from pandas.core.algorithms import isin
  12. @contextmanager
  13. def activated_tracemalloc() -> Generator[None, None, None]:
  14. tracemalloc.start()
  15. try:
  16. yield
  17. finally:
  18. tracemalloc.stop()
  19. def get_allocated_khash_memory():
  20. snapshot = tracemalloc.take_snapshot()
  21. snapshot = snapshot.filter_traces(
  22. (tracemalloc.DomainFilter(True, ht.get_hashtable_trace_domain()),)
  23. )
  24. return sum(map(lambda x: x.size, snapshot.traces))
  25. @pytest.mark.parametrize(
  26. "table_type, dtype",
  27. [
  28. (ht.PyObjectHashTable, np.object_),
  29. (ht.Complex128HashTable, np.complex128),
  30. (ht.Int64HashTable, np.int64),
  31. (ht.UInt64HashTable, np.uint64),
  32. (ht.Float64HashTable, np.float64),
  33. (ht.Complex64HashTable, np.complex64),
  34. (ht.Int32HashTable, np.int32),
  35. (ht.UInt32HashTable, np.uint32),
  36. (ht.Float32HashTable, np.float32),
  37. (ht.Int16HashTable, np.int16),
  38. (ht.UInt16HashTable, np.uint16),
  39. (ht.Int8HashTable, np.int8),
  40. (ht.UInt8HashTable, np.uint8),
  41. (ht.IntpHashTable, np.intp),
  42. ],
  43. )
  44. class TestHashTable:
  45. def test_get_set_contains_len(self, table_type, dtype):
  46. index = 5
  47. table = table_type(55)
  48. assert len(table) == 0
  49. assert index not in table
  50. table.set_item(index, 42)
  51. assert len(table) == 1
  52. assert index in table
  53. assert table.get_item(index) == 42
  54. table.set_item(index + 1, 41)
  55. assert index in table
  56. assert index + 1 in table
  57. assert len(table) == 2
  58. assert table.get_item(index) == 42
  59. assert table.get_item(index + 1) == 41
  60. table.set_item(index, 21)
  61. assert index in table
  62. assert index + 1 in table
  63. assert len(table) == 2
  64. assert table.get_item(index) == 21
  65. assert table.get_item(index + 1) == 41
  66. assert index + 2 not in table
  67. table.set_item(index + 1, 21)
  68. assert index in table
  69. assert index + 1 in table
  70. assert len(table) == 2
  71. assert table.get_item(index) == 21
  72. assert table.get_item(index + 1) == 21
  73. with pytest.raises(KeyError, match=str(index + 2)):
  74. table.get_item(index + 2)
  75. def test_get_set_contains_len_mask(self, table_type, dtype):
  76. if table_type == ht.PyObjectHashTable:
  77. pytest.skip("Mask not supported for object")
  78. index = 5
  79. table = table_type(55, uses_mask=True)
  80. assert len(table) == 0
  81. assert index not in table
  82. table.set_item(index, 42)
  83. assert len(table) == 1
  84. assert index in table
  85. assert table.get_item(index) == 42
  86. with pytest.raises(KeyError, match="NA"):
  87. table.get_na()
  88. table.set_item(index + 1, 41)
  89. table.set_na(41)
  90. assert pd.NA in table
  91. assert index in table
  92. assert index + 1 in table
  93. assert len(table) == 3
  94. assert table.get_item(index) == 42
  95. assert table.get_item(index + 1) == 41
  96. assert table.get_na() == 41
  97. table.set_na(21)
  98. assert index in table
  99. assert index + 1 in table
  100. assert len(table) == 3
  101. assert table.get_item(index + 1) == 41
  102. assert table.get_na() == 21
  103. assert index + 2 not in table
  104. with pytest.raises(KeyError, match=str(index + 2)):
  105. table.get_item(index + 2)
  106. def test_map_keys_to_values(self, table_type, dtype, writable):
  107. # only Int64HashTable has this method
  108. if table_type == ht.Int64HashTable:
  109. N = 77
  110. table = table_type()
  111. keys = np.arange(N).astype(dtype)
  112. vals = np.arange(N).astype(np.int64) + N
  113. keys.flags.writeable = writable
  114. vals.flags.writeable = writable
  115. table.map_keys_to_values(keys, vals)
  116. for i in range(N):
  117. assert table.get_item(keys[i]) == i + N
  118. def test_map_locations(self, table_type, dtype, writable):
  119. N = 8
  120. table = table_type()
  121. keys = (np.arange(N) + N).astype(dtype)
  122. keys.flags.writeable = writable
  123. table.map_locations(keys)
  124. for i in range(N):
  125. assert table.get_item(keys[i]) == i
  126. def test_map_locations_mask(self, table_type, dtype, writable):
  127. if table_type == ht.PyObjectHashTable:
  128. pytest.skip("Mask not supported for object")
  129. N = 3
  130. table = table_type(uses_mask=True)
  131. keys = (np.arange(N) + N).astype(dtype)
  132. keys.flags.writeable = writable
  133. table.map_locations(keys, np.array([False, False, True]))
  134. for i in range(N - 1):
  135. assert table.get_item(keys[i]) == i
  136. with pytest.raises(KeyError, match=re.escape(str(keys[N - 1]))):
  137. table.get_item(keys[N - 1])
  138. assert table.get_na() == 2
  139. def test_lookup(self, table_type, dtype, writable):
  140. N = 3
  141. table = table_type()
  142. keys = (np.arange(N) + N).astype(dtype)
  143. keys.flags.writeable = writable
  144. table.map_locations(keys)
  145. result = table.lookup(keys)
  146. expected = np.arange(N)
  147. tm.assert_numpy_array_equal(result.astype(np.int64), expected.astype(np.int64))
  148. def test_lookup_wrong(self, table_type, dtype):
  149. if dtype in (np.int8, np.uint8):
  150. N = 100
  151. else:
  152. N = 512
  153. table = table_type()
  154. keys = (np.arange(N) + N).astype(dtype)
  155. table.map_locations(keys)
  156. wrong_keys = np.arange(N).astype(dtype)
  157. result = table.lookup(wrong_keys)
  158. assert np.all(result == -1)
  159. def test_lookup_mask(self, table_type, dtype, writable):
  160. if table_type == ht.PyObjectHashTable:
  161. pytest.skip("Mask not supported for object")
  162. N = 3
  163. table = table_type(uses_mask=True)
  164. keys = (np.arange(N) + N).astype(dtype)
  165. mask = np.array([False, True, False])
  166. keys.flags.writeable = writable
  167. table.map_locations(keys, mask)
  168. result = table.lookup(keys, mask)
  169. expected = np.arange(N)
  170. tm.assert_numpy_array_equal(result.astype(np.int64), expected.astype(np.int64))
  171. result = table.lookup(np.array([1 + N]).astype(dtype), np.array([False]))
  172. tm.assert_numpy_array_equal(
  173. result.astype(np.int64), np.array([-1], dtype=np.int64)
  174. )
  175. def test_unique(self, table_type, dtype, writable):
  176. if dtype in (np.int8, np.uint8):
  177. N = 88
  178. else:
  179. N = 1000
  180. table = table_type()
  181. expected = (np.arange(N) + N).astype(dtype)
  182. keys = np.repeat(expected, 5)
  183. keys.flags.writeable = writable
  184. unique = table.unique(keys)
  185. tm.assert_numpy_array_equal(unique, expected)
  186. def test_tracemalloc_works(self, table_type, dtype):
  187. if dtype in (np.int8, np.uint8):
  188. N = 256
  189. else:
  190. N = 30000
  191. keys = np.arange(N).astype(dtype)
  192. with activated_tracemalloc():
  193. table = table_type()
  194. table.map_locations(keys)
  195. used = get_allocated_khash_memory()
  196. my_size = table.sizeof()
  197. assert used == my_size
  198. del table
  199. assert get_allocated_khash_memory() == 0
  200. def test_tracemalloc_for_empty(self, table_type, dtype):
  201. with activated_tracemalloc():
  202. table = table_type()
  203. used = get_allocated_khash_memory()
  204. my_size = table.sizeof()
  205. assert used == my_size
  206. del table
  207. assert get_allocated_khash_memory() == 0
  208. def test_get_state(self, table_type, dtype):
  209. table = table_type(1000)
  210. state = table.get_state()
  211. assert state["size"] == 0
  212. assert state["n_occupied"] == 0
  213. assert "n_buckets" in state
  214. assert "upper_bound" in state
  215. @pytest.mark.parametrize("N", range(1, 110))
  216. def test_no_reallocation(self, table_type, dtype, N):
  217. keys = np.arange(N).astype(dtype)
  218. preallocated_table = table_type(N)
  219. n_buckets_start = preallocated_table.get_state()["n_buckets"]
  220. preallocated_table.map_locations(keys)
  221. n_buckets_end = preallocated_table.get_state()["n_buckets"]
  222. # original number of buckets was enough:
  223. assert n_buckets_start == n_buckets_end
  224. # check with clean table (not too much preallocated)
  225. clean_table = table_type()
  226. clean_table.map_locations(keys)
  227. assert n_buckets_start == clean_table.get_state()["n_buckets"]
  228. class TestHashTableUnsorted:
  229. # TODO: moved from test_algos; may be redundancies with other tests
  230. def test_string_hashtable_set_item_signature(self):
  231. # GH#30419 fix typing in StringHashTable.set_item to prevent segfault
  232. tbl = ht.StringHashTable()
  233. tbl.set_item("key", 1)
  234. assert tbl.get_item("key") == 1
  235. with pytest.raises(TypeError, match="'key' has incorrect type"):
  236. # key arg typed as string, not object
  237. tbl.set_item(4, 6)
  238. with pytest.raises(TypeError, match="'val' has incorrect type"):
  239. tbl.get_item(4)
  240. def test_lookup_nan(self, writable):
  241. # GH#21688 ensure we can deal with readonly memory views
  242. xs = np.array([2.718, 3.14, np.nan, -7, 5, 2, 3])
  243. xs.setflags(write=writable)
  244. m = ht.Float64HashTable()
  245. m.map_locations(xs)
  246. tm.assert_numpy_array_equal(m.lookup(xs), np.arange(len(xs), dtype=np.intp))
  247. def test_add_signed_zeros(self):
  248. # GH#21866 inconsistent hash-function for float64
  249. # default hash-function would lead to different hash-buckets
  250. # for 0.0 and -0.0 if there are more than 2^30 hash-buckets
  251. # but this would mean 16GB
  252. N = 4 # 12 * 10**8 would trigger the error, if you have enough memory
  253. m = ht.Float64HashTable(N)
  254. m.set_item(0.0, 0)
  255. m.set_item(-0.0, 0)
  256. assert len(m) == 1 # 0.0 and -0.0 are equivalent
  257. def test_add_different_nans(self):
  258. # GH#21866 inconsistent hash-function for float64
  259. # create different nans from bit-patterns:
  260. NAN1 = struct.unpack("d", struct.pack("=Q", 0x7FF8000000000000))[0]
  261. NAN2 = struct.unpack("d", struct.pack("=Q", 0x7FF8000000000001))[0]
  262. assert NAN1 != NAN1
  263. assert NAN2 != NAN2
  264. # default hash function would lead to different hash-buckets
  265. # for NAN1 and NAN2 even if there are only 4 buckets:
  266. m = ht.Float64HashTable()
  267. m.set_item(NAN1, 0)
  268. m.set_item(NAN2, 0)
  269. assert len(m) == 1 # NAN1 and NAN2 are equivalent
  270. def test_lookup_overflow(self, writable):
  271. xs = np.array([1, 2, 2**63], dtype=np.uint64)
  272. # GH 21688 ensure we can deal with readonly memory views
  273. xs.setflags(write=writable)
  274. m = ht.UInt64HashTable()
  275. m.map_locations(xs)
  276. tm.assert_numpy_array_equal(m.lookup(xs), np.arange(len(xs), dtype=np.intp))
  277. @pytest.mark.parametrize("nvals", [0, 10]) # resizing to 0 is special case
  278. @pytest.mark.parametrize(
  279. "htable, uniques, dtype, safely_resizes",
  280. [
  281. (ht.PyObjectHashTable, ht.ObjectVector, "object", False),
  282. (ht.StringHashTable, ht.ObjectVector, "object", True),
  283. (ht.Float64HashTable, ht.Float64Vector, "float64", False),
  284. (ht.Int64HashTable, ht.Int64Vector, "int64", False),
  285. (ht.Int32HashTable, ht.Int32Vector, "int32", False),
  286. (ht.UInt64HashTable, ht.UInt64Vector, "uint64", False),
  287. ],
  288. )
  289. def test_vector_resize(
  290. self, writable, htable, uniques, dtype, safely_resizes, nvals
  291. ):
  292. # Test for memory errors after internal vector
  293. # reallocations (GH 7157)
  294. # Changed from using np.random.rand to range
  295. # which could cause flaky CI failures when safely_resizes=False
  296. vals = np.array(range(1000), dtype=dtype)
  297. # GH 21688 ensures we can deal with read-only memory views
  298. vals.setflags(write=writable)
  299. # initialise instances; cannot initialise in parametrization,
  300. # as otherwise external views would be held on the array (which is
  301. # one of the things this test is checking)
  302. htable = htable()
  303. uniques = uniques()
  304. # get_labels may append to uniques
  305. htable.get_labels(vals[:nvals], uniques, 0, -1)
  306. # to_array() sets an external_view_exists flag on uniques.
  307. tmp = uniques.to_array()
  308. oldshape = tmp.shape
  309. # subsequent get_labels() calls can no longer append to it
  310. # (except for StringHashTables + ObjectVector)
  311. if safely_resizes:
  312. htable.get_labels(vals, uniques, 0, -1)
  313. else:
  314. with pytest.raises(ValueError, match="external reference.*"):
  315. htable.get_labels(vals, uniques, 0, -1)
  316. uniques.to_array() # should not raise here
  317. assert tmp.shape == oldshape
  318. @pytest.mark.parametrize(
  319. "hashtable",
  320. [
  321. ht.PyObjectHashTable,
  322. ht.StringHashTable,
  323. ht.Float64HashTable,
  324. ht.Int64HashTable,
  325. ht.Int32HashTable,
  326. ht.UInt64HashTable,
  327. ],
  328. )
  329. def test_hashtable_large_sizehint(self, hashtable):
  330. # GH#22729 smoketest for not raising when passing a large size_hint
  331. size_hint = np.iinfo(np.uint32).max + 1
  332. hashtable(size_hint=size_hint)
  333. class TestPyObjectHashTableWithNans:
  334. def test_nan_float(self):
  335. nan1 = float("nan")
  336. nan2 = float("nan")
  337. assert nan1 is not nan2
  338. table = ht.PyObjectHashTable()
  339. table.set_item(nan1, 42)
  340. assert table.get_item(nan2) == 42
  341. def test_nan_complex_both(self):
  342. nan1 = complex(float("nan"), float("nan"))
  343. nan2 = complex(float("nan"), float("nan"))
  344. assert nan1 is not nan2
  345. table = ht.PyObjectHashTable()
  346. table.set_item(nan1, 42)
  347. assert table.get_item(nan2) == 42
  348. def test_nan_complex_real(self):
  349. nan1 = complex(float("nan"), 1)
  350. nan2 = complex(float("nan"), 1)
  351. other = complex(float("nan"), 2)
  352. assert nan1 is not nan2
  353. table = ht.PyObjectHashTable()
  354. table.set_item(nan1, 42)
  355. assert table.get_item(nan2) == 42
  356. with pytest.raises(KeyError, match=None) as error:
  357. table.get_item(other)
  358. assert str(error.value) == str(other)
  359. def test_nan_complex_imag(self):
  360. nan1 = complex(1, float("nan"))
  361. nan2 = complex(1, float("nan"))
  362. other = complex(2, float("nan"))
  363. assert nan1 is not nan2
  364. table = ht.PyObjectHashTable()
  365. table.set_item(nan1, 42)
  366. assert table.get_item(nan2) == 42
  367. with pytest.raises(KeyError, match=None) as error:
  368. table.get_item(other)
  369. assert str(error.value) == str(other)
  370. def test_nan_in_tuple(self):
  371. nan1 = (float("nan"),)
  372. nan2 = (float("nan"),)
  373. assert nan1[0] is not nan2[0]
  374. table = ht.PyObjectHashTable()
  375. table.set_item(nan1, 42)
  376. assert table.get_item(nan2) == 42
  377. def test_nan_in_nested_tuple(self):
  378. nan1 = (1, (2, (float("nan"),)))
  379. nan2 = (1, (2, (float("nan"),)))
  380. other = (1, 2)
  381. table = ht.PyObjectHashTable()
  382. table.set_item(nan1, 42)
  383. assert table.get_item(nan2) == 42
  384. with pytest.raises(KeyError, match=None) as error:
  385. table.get_item(other)
  386. assert str(error.value) == str(other)
  387. def test_hash_equal_tuple_with_nans():
  388. a = (float("nan"), (float("nan"), float("nan")))
  389. b = (float("nan"), (float("nan"), float("nan")))
  390. assert ht.object_hash(a) == ht.object_hash(b)
  391. assert ht.objects_are_equal(a, b)
  392. def test_get_labels_groupby_for_Int64(writable):
  393. table = ht.Int64HashTable()
  394. vals = np.array([1, 2, -1, 2, 1, -1], dtype=np.int64)
  395. vals.flags.writeable = writable
  396. arr, unique = table.get_labels_groupby(vals)
  397. expected_arr = np.array([0, 1, -1, 1, 0, -1], dtype=np.intp)
  398. expected_unique = np.array([1, 2], dtype=np.int64)
  399. tm.assert_numpy_array_equal(arr, expected_arr)
  400. tm.assert_numpy_array_equal(unique, expected_unique)
  401. def test_tracemalloc_works_for_StringHashTable():
  402. N = 1000
  403. keys = np.arange(N).astype(np.str_).astype(np.object_)
  404. with activated_tracemalloc():
  405. table = ht.StringHashTable()
  406. table.map_locations(keys)
  407. used = get_allocated_khash_memory()
  408. my_size = table.sizeof()
  409. assert used == my_size
  410. del table
  411. assert get_allocated_khash_memory() == 0
  412. def test_tracemalloc_for_empty_StringHashTable():
  413. with activated_tracemalloc():
  414. table = ht.StringHashTable()
  415. used = get_allocated_khash_memory()
  416. my_size = table.sizeof()
  417. assert used == my_size
  418. del table
  419. assert get_allocated_khash_memory() == 0
  420. @pytest.mark.parametrize("N", range(1, 110))
  421. def test_no_reallocation_StringHashTable(N):
  422. keys = np.arange(N).astype(np.str_).astype(np.object_)
  423. preallocated_table = ht.StringHashTable(N)
  424. n_buckets_start = preallocated_table.get_state()["n_buckets"]
  425. preallocated_table.map_locations(keys)
  426. n_buckets_end = preallocated_table.get_state()["n_buckets"]
  427. # original number of buckets was enough:
  428. assert n_buckets_start == n_buckets_end
  429. # check with clean table (not too much preallocated)
  430. clean_table = ht.StringHashTable()
  431. clean_table.map_locations(keys)
  432. assert n_buckets_start == clean_table.get_state()["n_buckets"]
  433. @pytest.mark.parametrize(
  434. "table_type, dtype",
  435. [
  436. (ht.Float64HashTable, np.float64),
  437. (ht.Float32HashTable, np.float32),
  438. (ht.Complex128HashTable, np.complex128),
  439. (ht.Complex64HashTable, np.complex64),
  440. ],
  441. )
  442. class TestHashTableWithNans:
  443. def test_get_set_contains_len(self, table_type, dtype):
  444. index = float("nan")
  445. table = table_type()
  446. assert index not in table
  447. table.set_item(index, 42)
  448. assert len(table) == 1
  449. assert index in table
  450. assert table.get_item(index) == 42
  451. table.set_item(index, 41)
  452. assert len(table) == 1
  453. assert index in table
  454. assert table.get_item(index) == 41
  455. def test_map_locations(self, table_type, dtype):
  456. N = 10
  457. table = table_type()
  458. keys = np.full(N, np.nan, dtype=dtype)
  459. table.map_locations(keys)
  460. assert len(table) == 1
  461. assert table.get_item(np.nan) == N - 1
  462. def test_unique(self, table_type, dtype):
  463. N = 1020
  464. table = table_type()
  465. keys = np.full(N, np.nan, dtype=dtype)
  466. unique = table.unique(keys)
  467. assert np.all(np.isnan(unique)) and len(unique) == 1
  468. def test_unique_for_nan_objects_floats():
  469. table = ht.PyObjectHashTable()
  470. keys = np.array([float("nan") for i in range(50)], dtype=np.object_)
  471. unique = table.unique(keys)
  472. assert len(unique) == 1
  473. def test_unique_for_nan_objects_complex():
  474. table = ht.PyObjectHashTable()
  475. keys = np.array([complex(float("nan"), 1.0) for i in range(50)], dtype=np.object_)
  476. unique = table.unique(keys)
  477. assert len(unique) == 1
  478. def test_unique_for_nan_objects_tuple():
  479. table = ht.PyObjectHashTable()
  480. keys = np.array(
  481. [1] + [(1.0, (float("nan"), 1.0)) for i in range(50)], dtype=np.object_
  482. )
  483. unique = table.unique(keys)
  484. assert len(unique) == 2
  485. @pytest.mark.parametrize(
  486. "dtype",
  487. [
  488. np.object_,
  489. np.complex128,
  490. np.int64,
  491. np.uint64,
  492. np.float64,
  493. np.complex64,
  494. np.int32,
  495. np.uint32,
  496. np.float32,
  497. np.int16,
  498. np.uint16,
  499. np.int8,
  500. np.uint8,
  501. np.intp,
  502. ],
  503. )
  504. class TestHelpFunctions:
  505. def test_value_count(self, dtype, writable):
  506. N = 43
  507. expected = (np.arange(N) + N).astype(dtype)
  508. values = np.repeat(expected, 5)
  509. values.flags.writeable = writable
  510. keys, counts = ht.value_count(values, False)
  511. tm.assert_numpy_array_equal(np.sort(keys), expected)
  512. assert np.all(counts == 5)
  513. def test_value_count_stable(self, dtype, writable):
  514. # GH12679
  515. values = np.array([2, 1, 5, 22, 3, -1, 8]).astype(dtype)
  516. values.flags.writeable = writable
  517. keys, counts = ht.value_count(values, False)
  518. tm.assert_numpy_array_equal(keys, values)
  519. assert np.all(counts == 1)
  520. def test_duplicated_first(self, dtype, writable):
  521. N = 100
  522. values = np.repeat(np.arange(N).astype(dtype), 5)
  523. values.flags.writeable = writable
  524. result = ht.duplicated(values)
  525. expected = np.ones_like(values, dtype=np.bool_)
  526. expected[::5] = False
  527. tm.assert_numpy_array_equal(result, expected)
  528. def test_ismember_yes(self, dtype, writable):
  529. N = 127
  530. arr = np.arange(N).astype(dtype)
  531. values = np.arange(N).astype(dtype)
  532. arr.flags.writeable = writable
  533. values.flags.writeable = writable
  534. result = ht.ismember(arr, values)
  535. expected = np.ones_like(values, dtype=np.bool_)
  536. tm.assert_numpy_array_equal(result, expected)
  537. def test_ismember_no(self, dtype):
  538. N = 17
  539. arr = np.arange(N).astype(dtype)
  540. values = (np.arange(N) + N).astype(dtype)
  541. result = ht.ismember(arr, values)
  542. expected = np.zeros_like(values, dtype=np.bool_)
  543. tm.assert_numpy_array_equal(result, expected)
  544. def test_mode(self, dtype, writable):
  545. if dtype in (np.int8, np.uint8):
  546. N = 53
  547. else:
  548. N = 11111
  549. values = np.repeat(np.arange(N).astype(dtype), 5)
  550. values[0] = 42
  551. values.flags.writeable = writable
  552. result = ht.mode(values, False)
  553. assert result == 42
  554. def test_mode_stable(self, dtype, writable):
  555. values = np.array([2, 1, 5, 22, 3, -1, 8]).astype(dtype)
  556. values.flags.writeable = writable
  557. keys = ht.mode(values, False)
  558. tm.assert_numpy_array_equal(keys, values)
  559. def test_modes_with_nans():
  560. # GH42688, nans aren't mangled
  561. nulls = [pd.NA, np.nan, pd.NaT, None]
  562. values = np.array([True] + nulls * 2, dtype=np.object_)
  563. modes = ht.mode(values, False)
  564. assert modes.size == len(nulls)
  565. def test_unique_label_indices_intp(writable):
  566. keys = np.array([1, 2, 2, 2, 1, 3], dtype=np.intp)
  567. keys.flags.writeable = writable
  568. result = ht.unique_label_indices(keys)
  569. expected = np.array([0, 1, 5], dtype=np.intp)
  570. tm.assert_numpy_array_equal(result, expected)
  571. def test_unique_label_indices():
  572. a = np.random.randint(1, 1 << 10, 1 << 15).astype(np.intp)
  573. left = ht.unique_label_indices(a)
  574. right = np.unique(a, return_index=True)[1]
  575. tm.assert_numpy_array_equal(left, right, check_dtype=False)
  576. a[np.random.choice(len(a), 10)] = -1
  577. left = ht.unique_label_indices(a)
  578. right = np.unique(a, return_index=True)[1][1:]
  579. tm.assert_numpy_array_equal(left, right, check_dtype=False)
  580. @pytest.mark.parametrize(
  581. "dtype",
  582. [
  583. np.float64,
  584. np.float32,
  585. np.complex128,
  586. np.complex64,
  587. ],
  588. )
  589. class TestHelpFunctionsWithNans:
  590. def test_value_count(self, dtype):
  591. values = np.array([np.nan, np.nan, np.nan], dtype=dtype)
  592. keys, counts = ht.value_count(values, True)
  593. assert len(keys) == 0
  594. keys, counts = ht.value_count(values, False)
  595. assert len(keys) == 1 and np.all(np.isnan(keys))
  596. assert counts[0] == 3
  597. def test_duplicated_first(self, dtype):
  598. values = np.array([np.nan, np.nan, np.nan], dtype=dtype)
  599. result = ht.duplicated(values)
  600. expected = np.array([False, True, True])
  601. tm.assert_numpy_array_equal(result, expected)
  602. def test_ismember_yes(self, dtype):
  603. arr = np.array([np.nan, np.nan, np.nan], dtype=dtype)
  604. values = np.array([np.nan, np.nan], dtype=dtype)
  605. result = ht.ismember(arr, values)
  606. expected = np.array([True, True, True], dtype=np.bool_)
  607. tm.assert_numpy_array_equal(result, expected)
  608. def test_ismember_no(self, dtype):
  609. arr = np.array([np.nan, np.nan, np.nan], dtype=dtype)
  610. values = np.array([1], dtype=dtype)
  611. result = ht.ismember(arr, values)
  612. expected = np.array([False, False, False], dtype=np.bool_)
  613. tm.assert_numpy_array_equal(result, expected)
  614. def test_mode(self, dtype):
  615. values = np.array([42, np.nan, np.nan, np.nan], dtype=dtype)
  616. assert ht.mode(values, True) == 42
  617. assert np.isnan(ht.mode(values, False))
  618. def test_ismember_tuple_with_nans():
  619. # GH-41836
  620. values = [("a", float("nan")), ("b", 1)]
  621. comps = [("a", float("nan"))]
  622. result = isin(values, comps)
  623. expected = np.array([True, False], dtype=np.bool_)
  624. tm.assert_numpy_array_equal(result, expected)
  625. def test_float_complex_int_are_equal_as_objects():
  626. values = ["a", 5, 5.0, 5.0 + 0j]
  627. comps = list(range(129))
  628. result = isin(values, comps)
  629. expected = np.array([False, True, True, True], dtype=np.bool_)
  630. tm.assert_numpy_array_equal(result, expected)