test_indexing_slow.py 2.8 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697
  1. from typing import (
  2. Any,
  3. List,
  4. )
  5. import warnings
  6. import numpy as np
  7. import pytest
  8. import pandas as pd
  9. from pandas import (
  10. DataFrame,
  11. Series,
  12. )
  13. import pandas._testing as tm
  14. m = 50
  15. n = 1000
  16. cols = ["jim", "joe", "jolie", "joline", "jolia"]
  17. vals: List[Any] = [
  18. np.random.randint(0, 10, n),
  19. np.random.choice(list("abcdefghij"), n),
  20. np.random.choice(pd.date_range("20141009", periods=10).tolist(), n),
  21. np.random.choice(list("ZYXWVUTSRQ"), n),
  22. np.random.randn(n),
  23. ]
  24. vals = list(map(tuple, zip(*vals)))
  25. # bunch of keys for testing
  26. keys: List[Any] = [
  27. np.random.randint(0, 11, m),
  28. np.random.choice(list("abcdefghijk"), m),
  29. np.random.choice(pd.date_range("20141009", periods=11).tolist(), m),
  30. np.random.choice(list("ZYXWVUTSRQP"), m),
  31. ]
  32. keys = list(map(tuple, zip(*keys)))
  33. keys += list(map(lambda t: t[:-1], vals[:: n // m]))
  34. # covers both unique index and non-unique index
  35. df = DataFrame(vals, columns=cols)
  36. a = pd.concat([df, df])
  37. b = df.drop_duplicates(subset=cols[:-1])
  38. def validate(mi, df, key):
  39. # check indexing into a multi-index before & past the lexsort depth
  40. mask = np.ones(len(df)).astype("bool")
  41. # test for all partials of this key
  42. for i, k in enumerate(key):
  43. mask &= df.iloc[:, i] == k
  44. if not mask.any():
  45. assert key[: i + 1] not in mi.index
  46. continue
  47. assert key[: i + 1] in mi.index
  48. right = df[mask].copy()
  49. if i + 1 != len(key): # partial key
  50. return_value = right.drop(cols[: i + 1], axis=1, inplace=True)
  51. assert return_value is None
  52. return_value = right.set_index(cols[i + 1 : -1], inplace=True)
  53. assert return_value is None
  54. tm.assert_frame_equal(mi.loc[key[: i + 1]], right)
  55. else: # full key
  56. return_value = right.set_index(cols[:-1], inplace=True)
  57. assert return_value is None
  58. if len(right) == 1: # single hit
  59. right = Series(
  60. right["jolia"].values, name=right.index[0], index=["jolia"]
  61. )
  62. tm.assert_series_equal(mi.loc[key[: i + 1]], right)
  63. else: # multi hit
  64. tm.assert_frame_equal(mi.loc[key[: i + 1]], right)
  65. @pytest.mark.filterwarnings("ignore::pandas.errors.PerformanceWarning")
  66. @pytest.mark.parametrize("lexsort_depth", list(range(5)))
  67. @pytest.mark.parametrize("key", keys)
  68. @pytest.mark.parametrize("frame", [a, b])
  69. def test_multiindex_get_loc(lexsort_depth, key, frame):
  70. # GH7724, GH2646
  71. with warnings.catch_warnings(record=True):
  72. if lexsort_depth == 0:
  73. df = frame.copy()
  74. else:
  75. df = frame.sort_values(by=cols[:lexsort_depth])
  76. mi = df.set_index(cols[:-1])
  77. assert not mi.index._lexsort_depth < lexsort_depth
  78. validate(mi, df, key)