test_disjoint_set.py 5.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201
  1. import pytest
  2. from pytest import raises as assert_raises
  3. import numpy as np
  4. from scipy.cluster.hierarchy import DisjointSet
  5. import string
  6. def generate_random_token():
  7. k = len(string.ascii_letters)
  8. tokens = list(np.arange(k, dtype=int))
  9. tokens += list(np.arange(k, dtype=float))
  10. tokens += list(string.ascii_letters)
  11. tokens += [None for i in range(k)]
  12. tokens = np.array(tokens, dtype=object)
  13. rng = np.random.RandomState(seed=0)
  14. while 1:
  15. size = rng.randint(1, 3)
  16. element = rng.choice(tokens, size)
  17. if size == 1:
  18. yield element[0]
  19. else:
  20. yield tuple(element)
  21. def get_elements(n):
  22. # dict is deterministic without difficulty of comparing numpy ints
  23. elements = {}
  24. for element in generate_random_token():
  25. if element not in elements:
  26. elements[element] = len(elements)
  27. if len(elements) >= n:
  28. break
  29. return list(elements.keys())
  30. def test_init():
  31. n = 10
  32. elements = get_elements(n)
  33. dis = DisjointSet(elements)
  34. assert dis.n_subsets == n
  35. assert list(dis) == elements
  36. def test_len():
  37. n = 10
  38. elements = get_elements(n)
  39. dis = DisjointSet(elements)
  40. assert len(dis) == n
  41. dis.add("dummy")
  42. assert len(dis) == n + 1
  43. @pytest.mark.parametrize("n", [10, 100])
  44. def test_contains(n):
  45. elements = get_elements(n)
  46. dis = DisjointSet(elements)
  47. for x in elements:
  48. assert x in dis
  49. assert "dummy" not in dis
  50. @pytest.mark.parametrize("n", [10, 100])
  51. def test_add(n):
  52. elements = get_elements(n)
  53. dis1 = DisjointSet(elements)
  54. dis2 = DisjointSet()
  55. for i, x in enumerate(elements):
  56. dis2.add(x)
  57. assert len(dis2) == i + 1
  58. # test idempotency by adding element again
  59. dis2.add(x)
  60. assert len(dis2) == i + 1
  61. assert list(dis1) == list(dis2)
  62. def test_element_not_present():
  63. elements = get_elements(n=10)
  64. dis = DisjointSet(elements)
  65. with assert_raises(KeyError):
  66. dis["dummy"]
  67. with assert_raises(KeyError):
  68. dis.merge(elements[0], "dummy")
  69. with assert_raises(KeyError):
  70. dis.connected(elements[0], "dummy")
  71. @pytest.mark.parametrize("direction", ["forwards", "backwards"])
  72. @pytest.mark.parametrize("n", [10, 100])
  73. def test_linear_union_sequence(n, direction):
  74. elements = get_elements(n)
  75. dis = DisjointSet(elements)
  76. assert elements == list(dis)
  77. indices = list(range(n - 1))
  78. if direction == "backwards":
  79. indices = indices[::-1]
  80. for it, i in enumerate(indices):
  81. assert not dis.connected(elements[i], elements[i + 1])
  82. assert dis.merge(elements[i], elements[i + 1])
  83. assert dis.connected(elements[i], elements[i + 1])
  84. assert dis.n_subsets == n - 1 - it
  85. roots = [dis[i] for i in elements]
  86. if direction == "forwards":
  87. assert all(elements[0] == r for r in roots)
  88. else:
  89. assert all(elements[-2] == r for r in roots)
  90. assert not dis.merge(elements[0], elements[-1])
  91. @pytest.mark.parametrize("n", [10, 100])
  92. def test_self_unions(n):
  93. elements = get_elements(n)
  94. dis = DisjointSet(elements)
  95. for x in elements:
  96. assert dis.connected(x, x)
  97. assert not dis.merge(x, x)
  98. assert dis.connected(x, x)
  99. assert dis.n_subsets == len(elements)
  100. assert elements == list(dis)
  101. roots = [dis[x] for x in elements]
  102. assert elements == roots
  103. @pytest.mark.parametrize("order", ["ab", "ba"])
  104. @pytest.mark.parametrize("n", [10, 100])
  105. def test_equal_size_ordering(n, order):
  106. elements = get_elements(n)
  107. dis = DisjointSet(elements)
  108. rng = np.random.RandomState(seed=0)
  109. indices = np.arange(n)
  110. rng.shuffle(indices)
  111. for i in range(0, len(indices), 2):
  112. a, b = elements[indices[i]], elements[indices[i + 1]]
  113. if order == "ab":
  114. assert dis.merge(a, b)
  115. else:
  116. assert dis.merge(b, a)
  117. expected = elements[min(indices[i], indices[i + 1])]
  118. assert dis[a] == expected
  119. assert dis[b] == expected
  120. @pytest.mark.parametrize("kmax", [5, 10])
  121. def test_binary_tree(kmax):
  122. n = 2**kmax
  123. elements = get_elements(n)
  124. dis = DisjointSet(elements)
  125. rng = np.random.RandomState(seed=0)
  126. for k in 2**np.arange(kmax):
  127. for i in range(0, n, 2 * k):
  128. r1, r2 = rng.randint(0, k, size=2)
  129. a, b = elements[i + r1], elements[i + k + r2]
  130. assert not dis.connected(a, b)
  131. assert dis.merge(a, b)
  132. assert dis.connected(a, b)
  133. assert elements == list(dis)
  134. roots = [dis[i] for i in elements]
  135. expected_indices = np.arange(n) - np.arange(n) % (2 * k)
  136. expected = [elements[i] for i in expected_indices]
  137. assert roots == expected
  138. @pytest.mark.parametrize("n", [10, 100])
  139. def test_subsets(n):
  140. elements = get_elements(n)
  141. dis = DisjointSet(elements)
  142. rng = np.random.RandomState(seed=0)
  143. for i, j in rng.randint(0, n, (n, 2)):
  144. x = elements[i]
  145. y = elements[j]
  146. expected = {element for element in dis if {dis[element]} == {dis[x]}}
  147. assert expected == dis.subset(x)
  148. expected = {dis[element]: set() for element in dis}
  149. for element in dis:
  150. expected[dis[element]].add(element)
  151. expected = list(expected.values())
  152. assert expected == dis.subsets()
  153. dis.merge(x, y)
  154. assert dis.subset(x) == dis.subset(y)