# test_iterables.py (34 KB)
  1. from textwrap import dedent
  2. from itertools import islice, product
  3. from sympy.core.basic import Basic
  4. from sympy.core.numbers import Integer
  5. from sympy.core.sorting import ordered
  6. from sympy.core.symbol import (Dummy, symbols)
  7. from sympy.functions.combinatorial.factorials import factorial
  8. from sympy.matrices.dense import Matrix
  9. from sympy.combinatorics import RGS_enum, RGS_unrank, Permutation
  10. from sympy.utilities.iterables import (
  11. _partition, _set_partitions, binary_partitions, bracelets, capture,
  12. cartes, common_prefix, common_suffix, connected_components, dict_merge,
  13. filter_symbols, flatten, generate_bell, generate_derangements,
  14. generate_involutions, generate_oriented_forest, group, has_dups, ibin,
  15. iproduct, kbins, minlex, multiset, multiset_combinations,
  16. multiset_partitions, multiset_permutations, necklaces, numbered_symbols,
  17. partitions, permutations, postfixes,
  18. prefixes, reshape, rotate_left, rotate_right, runs, sift,
  19. strongly_connected_components, subsets, take, topological_sort, unflatten,
  20. uniq, variations, ordered_partitions, rotations, is_palindromic, iterable,
  21. NotIterable, multiset_derangements,
  22. sequence_partitions, sequence_partitions_empty)
  23. from sympy.utilities.enumerative import (
  24. factoring_visitor, multiset_partitions_taocp )
  25. from sympy.core.singleton import S
  26. from sympy.testing.pytest import raises, warns_deprecated_sympy
# Module-level symbols shared by the tests below.
w, x, y, z = symbols('w,x,y,z')
def test_deprecated_iterables():
    """Using ``ordered``/``default_sort_key`` obtained from
    sympy.utilities.iterables must still work but emit a SymPy deprecation
    warning."""
    # Local import on purpose: this shadows the module-level ``ordered``
    # (imported from sympy.core.sorting) with the deprecated location.
    from sympy.utilities.iterables import default_sort_key, ordered
    with warns_deprecated_sympy():
        assert list(ordered([y, x])) == [x, y]
    with warns_deprecated_sympy():
        assert sorted([y, x], key=default_sort_key) == [x, y]
  34. def test_is_palindromic():
  35. assert is_palindromic('')
  36. assert is_palindromic('x')
  37. assert is_palindromic('xx')
  38. assert is_palindromic('xyx')
  39. assert not is_palindromic('xy')
  40. assert not is_palindromic('xyzx')
  41. assert is_palindromic('xxyzzyx', 1)
  42. assert not is_palindromic('xxyzzyx', 2)
  43. assert is_palindromic('xxyzzyx', 2, -1)
  44. assert is_palindromic('xxyzzyx', 2, 6)
  45. assert is_palindromic('xxyzyx', 1)
  46. assert not is_palindromic('xxyzyx', 2)
  47. assert is_palindromic('xxyzyx', 2, 2 + 3)
  48. def test_flatten():
  49. assert flatten((1, (1,))) == [1, 1]
  50. assert flatten((x, (x,))) == [x, x]
  51. ls = [[(-2, -1), (1, 2)], [(0, 0)]]
  52. assert flatten(ls, levels=0) == ls
  53. assert flatten(ls, levels=1) == [(-2, -1), (1, 2), (0, 0)]
  54. assert flatten(ls, levels=2) == [-2, -1, 1, 2, 0, 0]
  55. assert flatten(ls, levels=3) == [-2, -1, 1, 2, 0, 0]
  56. raises(ValueError, lambda: flatten(ls, levels=-1))
  57. class MyOp(Basic):
  58. pass
  59. assert flatten([MyOp(x, y), z]) == [MyOp(x, y), z]
  60. assert flatten([MyOp(x, y), z], cls=MyOp) == [x, y, z]
  61. assert flatten({1, 11, 2}) == list({1, 11, 2})
def test_iproduct():
    """iproduct: finite Cartesian products and lazy products of infinite sets."""
    # Finite inputs.
    assert list(iproduct()) == [()]
    assert list(iproduct([])) == []
    assert list(iproduct([1,2,3])) == [(1,),(2,),(3,)]
    # sorted() keeps these checks independent of the enumeration order.
    assert sorted(iproduct([1, 2], [3, 4, 5])) == [
        (1,3),(1,4),(1,5),(2,3),(2,4),(2,5)]
    assert sorted(iproduct([0,1],[0,1],[0,1])) == [
        (0,0,0),(0,0,1),(0,1,0),(0,1,1),(1,0,0),(1,0,1),(1,1,0),(1,1,1)]
    # Infinite inputs: the product is still iterable, and specific integer
    # tuples are found in it.
    assert iterable(iproduct(S.Integers)) is True
    assert iterable(iproduct(S.Integers, S.Integers)) is True
    assert (3,) in iproduct(S.Integers)
    assert (4, 5) in iproduct(S.Integers, S.Integers)
    assert (1, 2, 3) in iproduct(S.Integers, S.Integers, S.Integers)
    # Take a finite prefix of the infinite product; its entries must be
    # SymPy Integers, not Python ints.
    triples = set(islice(iproduct(S.Integers, S.Integers, S.Integers), 1000))
    for n1, n2, n3 in triples:
        assert isinstance(n1, Integer)
        assert isinstance(n2, Integer)
        assert isinstance(n3, Integer)
    # Every triple with entries in -2..2 must be reached by the enumeration.
    for t in set(product(*([range(-2, 3)]*3))):
        assert t in iproduct(S.Integers, S.Integers, S.Integers)
  82. def test_group():
  83. assert group([]) == []
  84. assert group([], multiple=False) == []
  85. assert group([1]) == [[1]]
  86. assert group([1], multiple=False) == [(1, 1)]
  87. assert group([1, 1]) == [[1, 1]]
  88. assert group([1, 1], multiple=False) == [(1, 2)]
  89. assert group([1, 1, 1]) == [[1, 1, 1]]
  90. assert group([1, 1, 1], multiple=False) == [(1, 3)]
  91. assert group([1, 2, 1]) == [[1], [2], [1]]
  92. assert group([1, 2, 1], multiple=False) == [(1, 1), (2, 1), (1, 1)]
  93. assert group([1, 1, 2, 2, 2, 1, 3, 3]) == [[1, 1], [2, 2, 2], [1], [3, 3]]
  94. assert group([1, 1, 2, 2, 2, 1, 3, 3], multiple=False) == [(1, 2),
  95. (2, 3), (1, 1), (3, 2)]
def test_subsets():
    """subsets(seq, k, repetition): k-combinations with or without repeats."""
    # combinations
    assert list(subsets([1, 2, 3], 0)) == [()]
    assert list(subsets([1, 2, 3], 1)) == [(1,), (2,), (3,)]
    assert list(subsets([1, 2, 3], 2)) == [(1, 2), (1, 3), (2, 3)]
    assert list(subsets([1, 2, 3], 3)) == [(1, 2, 3)]
    # combinations with repetition
    l = list(range(4))
    assert list(subsets(l, 0, repetition=True)) == [()]
    assert list(subsets(l, 1, repetition=True)) == [(0,), (1,), (2,), (3,)]
    assert list(subsets(l, 2, repetition=True)) == [(0, 0), (0, 1), (0, 2),
                                                    (0, 3), (1, 1), (1, 2),
                                                    (1, 3), (2, 2), (2, 3),
                                                    (3, 3)]
    assert list(subsets(l, 3, repetition=True)) == [(0, 0, 0), (0, 0, 1),
                                                    (0, 0, 2), (0, 0, 3),
                                                    (0, 1, 1), (0, 1, 2),
                                                    (0, 1, 3), (0, 2, 2),
                                                    (0, 2, 3), (0, 3, 3),
                                                    (1, 1, 1), (1, 1, 2),
                                                    (1, 1, 3), (1, 2, 2),
                                                    (1, 2, 3), (1, 3, 3),
                                                    (2, 2, 2), (2, 2, 3),
                                                    (2, 3, 3), (3, 3, 3)]
    # C(4 + 4 - 1, 4) == 35 size-4 multisets drawn from 4 items.
    assert len(list(subsets(l, 4, repetition=True))) == 35
    # Without repetition, k larger than the pool yields nothing.
    assert list(subsets(l[:2], 3, repetition=False)) == []
    assert list(subsets(l[:2], 3, repetition=True)) == [(0, 0, 0),
                                                        (0, 0, 1),
                                                        (0, 1, 1),
                                                        (1, 1, 1)]
    # With k omitted, subsets of every size are generated, smallest first.
    assert list(subsets([1, 2], repetition=True)) == \
        [(), (1,), (2,), (1, 1), (1, 2), (2, 2)]
    assert list(subsets([1, 2], repetition=False)) == \
        [(), (1,), (2,), (1, 2)]
    assert list(subsets([1, 2, 3], 2)) == \
        [(1, 2), (1, 3), (2, 3)]
    assert list(subsets([1, 2, 3], 2, repetition=True)) == \
        [(1, 1), (1, 2), (1, 3), (2, 2), (2, 3), (3, 3)]
def test_variations():
    """variations(seq, n, repetition): ordered selections of n items."""
    # permutations
    l = list(range(4))
    assert list(variations(l, 0, repetition=False)) == [()]
    assert list(variations(l, 1, repetition=False)) == [(0,), (1,), (2,), (3,)]
    assert list(variations(l, 2, repetition=False)) == [(0, 1), (0, 2), (0, 3), (1, 0), (1, 2), (1, 3), (2, 0), (2, 1), (2, 3), (3, 0), (3, 1), (3, 2)]
    assert list(variations(l, 3, repetition=False)) == [(0, 1, 2), (0, 1, 3), (0, 2, 1), (0, 2, 3), (0, 3, 1), (0, 3, 2), (1, 0, 2), (1, 0, 3), (1, 2, 0), (1, 2, 3), (1, 3, 0), (1, 3, 2), (2, 0, 1), (2, 0, 3), (2, 1, 0), (2, 1, 3), (2, 3, 0), (2, 3, 1), (3, 0, 1), (3, 0, 2), (3, 1, 0), (3, 1, 2), (3, 2, 0), (3, 2, 1)]
    # With repetition: all 4**n ordered n-tuples.
    assert list(variations(l, 0, repetition=True)) == [()]
    assert list(variations(l, 1, repetition=True)) == [(0,), (1,), (2,), (3,)]
    assert list(variations(l, 2, repetition=True)) == [(0, 0), (0, 1), (0, 2),
                                                       (0, 3), (1, 0), (1, 1),
                                                       (1, 2), (1, 3), (2, 0),
                                                       (2, 1), (2, 2), (2, 3),
                                                       (3, 0), (3, 1), (3, 2),
                                                       (3, 3)]
    assert len(list(variations(l, 3, repetition=True))) == 64
    assert len(list(variations(l, 4, repetition=True))) == 256
    # n larger than the pool: nothing without repetition, 2**3 tuples with it.
    assert list(variations(l[:2], 3, repetition=False)) == []
    assert list(variations(l[:2], 3, repetition=True)) == [
        (0, 0, 0), (0, 0, 1), (0, 1, 0), (0, 1, 1),
        (1, 0, 0), (1, 0, 1), (1, 1, 0), (1, 1, 1)
    ]
  155. def test_cartes():
  156. assert list(cartes([1, 2], [3, 4, 5])) == \
  157. [(1, 3), (1, 4), (1, 5), (2, 3), (2, 4), (2, 5)]
  158. assert list(cartes()) == [()]
  159. assert list(cartes('a')) == [('a',)]
  160. assert list(cartes('a', repeat=2)) == [('a', 'a')]
  161. assert list(cartes(list(range(2)))) == [(0,), (1,)]
  162. def test_filter_symbols():
  163. s = numbered_symbols()
  164. filtered = filter_symbols(s, symbols("x0 x2 x3"))
  165. assert take(filtered, 3) == list(symbols("x1 x4 x5"))
  166. def test_numbered_symbols():
  167. s = numbered_symbols(cls=Dummy)
  168. assert isinstance(next(s), Dummy)
  169. assert next(numbered_symbols('C', start=1, exclude=[symbols('C1')])) == \
  170. symbols('C2')
def test_sift():
    """sift(seq, key): group items into a dict by key value; with
    binary=True, return a pair (items with truthy key, the rest)."""
    assert sift(list(range(5)), lambda _: _ % 2) == {1: [1, 3], 0: [0, 2, 4]}
    assert sift([x, y], lambda _: _.has(x)) == {False: [y], True: [x]}
    assert sift([S.One], lambda _: _.has(x)) == {False: [1]}
    # binary=True: a 2-tuple, truthy-key items first.
    assert sift([0, 1, 2, 3], lambda x: x % 2, binary=True) == (
        [1, 3], [0, 2])
    assert sift([0, 1, 2, 3], lambda x: x % 3 == 1, binary=True) == (
        [1], [0, 2, 3])
    # x % 3 produces three distinct keys (0, 1, 2), which binary=True rejects.
    raises(ValueError, lambda:
        sift([0, 1, 2, 3], lambda x: x % 3, binary=True))
def test_take():
    """take(iter, n): the next n items, consuming them from the iterator."""
    X = numbered_symbols()
    assert take(X, 5) == list(symbols('x0:5'))
    # The same generator resumes where the previous take() stopped.
    assert take(X, 5) == list(symbols('x5:10'))
    assert take([1, 2, 3, 4, 5], 5) == [1, 2, 3, 4, 5]
  186. def test_dict_merge():
  187. assert dict_merge({}, {1: x, y: z}) == {1: x, y: z}
  188. assert dict_merge({1: x, y: z}, {}) == {1: x, y: z}
  189. assert dict_merge({2: z}, {1: x, y: z}) == {1: x, 2: z, y: z}
  190. assert dict_merge({1: x, y: z}, {2: z}) == {1: x, 2: z, y: z}
  191. assert dict_merge({1: y, 2: z}, {1: x, y: z}) == {1: x, 2: z, y: z}
  192. assert dict_merge({1: x, y: z}, {1: y, 2: z}) == {1: y, 2: z, y: z}
  193. def test_prefixes():
  194. assert list(prefixes([])) == []
  195. assert list(prefixes([1])) == [[1]]
  196. assert list(prefixes([1, 2])) == [[1], [1, 2]]
  197. assert list(prefixes([1, 2, 3, 4, 5])) == \
  198. [[1], [1, 2], [1, 2, 3], [1, 2, 3, 4], [1, 2, 3, 4, 5]]
  199. def test_postfixes():
  200. assert list(postfixes([])) == []
  201. assert list(postfixes([1])) == [[1]]
  202. assert list(postfixes([1, 2])) == [[2], [1, 2]]
  203. assert list(postfixes([1, 2, 3, 4, 5])) == \
  204. [[5], [4, 5], [3, 4, 5], [2, 3, 4, 5], [1, 2, 3, 4, 5]]
def test_topological_sort():
    """topological_sort((V, E), key): order a DAG; a cycle raises ValueError."""
    V = [2, 3, 5, 7, 8, 9, 10, 11]
    E = [(7, 11), (7, 8), (5, 11),
         (3, 8), (3, 10), (11, 2),
         (11, 9), (11, 10), (8, 9)]
    assert topological_sort((V, E)) == [3, 5, 7, 8, 11, 2, 9, 10]
    # ``key`` changes which ready vertex is emitted first; with -v the
    # larger ready vertices come out first.
    assert topological_sort((V, E), key=lambda v: -v) == \
        [7, 5, 11, 3, 10, 8, 9, 2]
    # The extra edge (10, 7) closes the cycle 7 -> 11 -> 10 -> 7.
    raises(ValueError, lambda: topological_sort((V, E + [(10, 7)])))
  214. def test_strongly_connected_components():
  215. assert strongly_connected_components(([], [])) == []
  216. assert strongly_connected_components(([1, 2, 3], [])) == [[1], [2], [3]]
  217. V = [1, 2, 3]
  218. E = [(1, 2), (1, 3), (2, 1), (2, 3), (3, 1)]
  219. assert strongly_connected_components((V, E)) == [[1, 2, 3]]
  220. V = [1, 2, 3, 4]
  221. E = [(1, 2), (2, 3), (3, 2), (3, 4)]
  222. assert strongly_connected_components((V, E)) == [[4], [2, 3], [1]]
  223. V = [1, 2, 3, 4]
  224. E = [(1, 2), (2, 1), (3, 4), (4, 3)]
  225. assert strongly_connected_components((V, E)) == [[1, 2], [3, 4]]
  226. def test_connected_components():
  227. assert connected_components(([], [])) == []
  228. assert connected_components(([1, 2, 3], [])) == [[1], [2], [3]]
  229. V = [1, 2, 3]
  230. E = [(1, 2), (1, 3), (2, 1), (2, 3), (3, 1)]
  231. assert connected_components((V, E)) == [[1, 2, 3]]
  232. V = [1, 2, 3, 4]
  233. E = [(1, 2), (2, 3), (3, 2), (3, 4)]
  234. assert connected_components((V, E)) == [[1, 2, 3, 4]]
  235. V = [1, 2, 3, 4]
  236. E = [(1, 2), (3, 4)]
  237. assert connected_components((V, E)) == [[1, 2], [3, 4]]
  238. def test_rotate():
  239. A = [0, 1, 2, 3, 4]
  240. assert rotate_left(A, 2) == [2, 3, 4, 0, 1]
  241. assert rotate_right(A, 1) == [4, 0, 1, 2, 3]
  242. A = []
  243. B = rotate_right(A, 1)
  244. assert B == []
  245. B.append(1)
  246. assert A == []
  247. B = rotate_left(A, 1)
  248. assert B == []
  249. B.append(1)
  250. assert A == []
def test_multiset_partitions():
    """multiset_partitions(multiset, m): partitions into exactly m parts, or
    all partitions when m is omitted; an int n stands for range(n)."""
    A = [0, 1, 2, 3, 4]
    assert list(multiset_partitions(A, 5)) == [[[0], [1], [2], [3], [4]]]
    # Stirling numbers of the second kind: S(5, 4) == 10, S(5, 3) == 25.
    assert len(list(multiset_partitions(A, 4))) == 10
    assert len(list(multiset_partitions(A, 3))) == 25

    assert list(multiset_partitions([1, 1, 1, 2, 2], 2)) == [
        [[1, 1, 1, 2], [2]], [[1, 1, 1], [2, 2]], [[1, 1, 2, 2], [1]],
        [[1, 1, 2], [1, 2]], [[1, 1], [1, 2, 2]]]

    assert list(multiset_partitions([1, 1, 2, 2], 2)) == [
        [[1, 1, 2], [2]], [[1, 1], [2, 2]], [[1, 2, 2], [1]],
        [[1, 2], [1, 2]]]

    assert list(multiset_partitions([1, 2, 3, 4], 2)) == [
        [[1, 2, 3], [4]], [[1, 2, 4], [3]], [[1, 2], [3, 4]],
        [[1, 3, 4], [2]], [[1, 3], [2, 4]], [[1, 4], [2, 3]],
        [[1], [2, 3, 4]]]

    assert list(multiset_partitions([1, 2, 2], 2)) == [
        [[1, 2], [2]], [[1], [2, 2]]]

    # An integer argument n partitions range(n).
    assert list(multiset_partitions(3)) == [
        [[0, 1, 2]], [[0, 1], [2]], [[0, 2], [1]], [[0], [1, 2]],
        [[0], [1], [2]]]
    assert list(multiset_partitions(3, 2)) == [
        [[0, 1], [2]], [[0, 2], [1]], [[0], [1, 2]]]
    assert list(multiset_partitions([1] * 3, 2)) == [[[1], [1, 1]]]
    assert list(multiset_partitions([1] * 3)) == [
        [[1, 1, 1]], [[1], [1, 1]], [[1], [1], [1]]]
    # Unsorted input gives the same partitions as sorted input.
    a = [3, 2, 1]
    assert list(multiset_partitions(a)) == \
        list(multiset_partitions(sorted(a)))
    # m larger than the number of elements yields nothing; m == 1 yields
    # the whole multiset as one part.
    assert list(multiset_partitions(a, 5)) == []
    assert list(multiset_partitions(a, 1)) == [[[1, 2, 3]]]
    assert list(multiset_partitions(a + [4], 5)) == []
    assert list(multiset_partitions(a + [4], 1)) == [[[1, 2, 3, 4]]]
    assert list(multiset_partitions(2, 5)) == []
    assert list(multiset_partitions(2, 1)) == [[[0, 1]]]
    # Strings are partitioned as sequences of characters...
    assert list(multiset_partitions('a')) == [[['a']]]
    assert list(multiset_partitions('a', 2)) == []
    assert list(multiset_partitions('ab')) == [[['a', 'b']], [['a'], ['b']]]
    assert list(multiset_partitions('ab', 1)) == [[['a', 'b']]]
    # ...but NOTE(review): for a string of one repeated character with m=1,
    # the single part is the string itself ('aaa'), not a list of
    # characters as for 'ab' above. This assertion pins that quirk.
    assert list(multiset_partitions('aaa', 1)) == [['aaa']]
    assert list(multiset_partitions([1, 1], 1)) == [[[1, 1]]]
    ans = [('mpsyy',), ('mpsy', 'y'), ('mps', 'yy'), ('mps', 'y', 'y'),
           ('mpyy', 's'), ('mpy', 'sy'), ('mpy', 's', 'y'), ('mp', 'syy'),
           ('mp', 'sy', 'y'), ('mp', 's', 'yy'), ('mp', 's', 'y', 'y'),
           ('msyy', 'p'), ('msy', 'py'), ('msy', 'p', 'y'), ('ms', 'pyy'),
           ('ms', 'py', 'y'), ('ms', 'p', 'yy'), ('ms', 'p', 'y', 'y'),
           ('myy', 'ps'), ('myy', 'p', 's'), ('my', 'psy'), ('my', 'ps', 'y'),
           ('my', 'py', 's'), ('my', 'p', 'sy'), ('my', 'p', 's', 'y'),
           ('m', 'psyy'), ('m', 'psy', 'y'), ('m', 'ps', 'yy'),
           ('m', 'ps', 'y', 'y'), ('m', 'pyy', 's'), ('m', 'py', 'sy'),
           ('m', 'py', 's', 'y'), ('m', 'p', 'syy'),
           ('m', 'p', 'sy', 'y'), ('m', 'p', 's', 'yy'),
           ('m', 'p', 's', 'y', 'y')]
    assert [tuple("".join(part) for part in p)
            for p in multiset_partitions('sympy')] == ans
    # Cross-check against the lower-level enumerative routine: multiplicities
    # [3, 1] with primes [2, 3] give the factorings of 24 = 2**3 * 3.
    factorings = [[24], [8, 3], [12, 2], [4, 6], [4, 2, 3],
                  [6, 2, 2], [2, 2, 2, 3]]
    assert [factoring_visitor(p, [2,3]) for
            p in multiset_partitions_taocp([3, 1])] == factorings
def test_multiset_combinations():
    """multiset_combinations(m, n): distinct n-element sub-multisets."""
    ans = ['iii', 'iim', 'iip', 'iis', 'imp', 'ims', 'ipp', 'ips',
           'iss', 'mpp', 'mps', 'mss', 'pps', 'pss', 'sss']
    assert [''.join(i) for i in
            list(multiset_combinations('mississippi', 3))] == ans
    # A precomputed multiset(...) gives the same result as the raw string
    # (presumably an element -> count mapping; verify in iterables.multiset).
    M = multiset('mississippi')
    assert [''.join(i) for i in
            list(multiset_combinations(M, 3))] == ans
    assert [''.join(i) for i in multiset_combinations(M, 30)] == []
    # Unhashable (list) elements are supported.
    assert list(multiset_combinations([[1], [2, 3]], 2)) == [[[1], [2, 3]]]
    assert len(list(multiset_combinations('a', 3))) == 0
    assert len(list(multiset_combinations('a', 0))) == 1
    assert list(multiset_combinations('abc', 1)) == [['a'], ['b'], ['c']]
    # A negative count in a dict input is rejected.
    raises(ValueError, lambda: list(multiset_combinations({0: 3, 1: -1}, 2)))
def test_multiset_permutations():
    """multiset_permutations(m, size): distinct arrangements of a multiset,
    produced in sorted order."""
    ans = ['abby', 'abyb', 'aybb', 'baby', 'bayb', 'bbay', 'bbya', 'byab',
           'byba', 'yabb', 'ybab', 'ybba']
    assert [''.join(i) for i in multiset_permutations('baby')] == ans
    assert [''.join(i) for i in multiset_permutations(multiset('baby'))] == ans
    assert list(multiset_permutations([0, 0, 0], 2)) == [[0, 0]]
    assert list(multiset_permutations([0, 2, 1], 2)) == [
        [0, 1], [0, 2], [1, 0], [1, 2], [2, 0], [2, 1]]
    assert len(list(multiset_permutations('a', 0))) == 1
    assert len(list(multiset_permutations('a', 3))) == 0
    for nul in ([], {}, ''):
        assert list(multiset_permutations(nul)) == [[]]
        assert list(multiset_permutations(nul, 0)) == [[]]
        # impossible requests give no result
        assert list(multiset_permutations(nul, 1)) == []
        assert list(multiset_permutations(nul, -1)) == []

    # Print every size-i arrangement of [0, 0, 1, 0, 1] for i = 1..6 and
    # compare the captured stdout. Size 6 exceeds the 5 available items, so
    # only the header "6" is printed for it.
    def test():
        for i in range(1, 7):
            print(i)
            for p in multiset_permutations([0, 0, 1, 0, 1], i):
                print(p)
    assert capture(lambda: test()) == dedent('''\
        1
        [0]
        [1]
        2
        [0, 0]
        [0, 1]
        [1, 0]
        [1, 1]
        3
        [0, 0, 0]
        [0, 0, 1]
        [0, 1, 0]
        [0, 1, 1]
        [1, 0, 0]
        [1, 0, 1]
        [1, 1, 0]
        4
        [0, 0, 0, 1]
        [0, 0, 1, 0]
        [0, 0, 1, 1]
        [0, 1, 0, 0]
        [0, 1, 0, 1]
        [0, 1, 1, 0]
        [1, 0, 0, 0]
        [1, 0, 0, 1]
        [1, 0, 1, 0]
        [1, 1, 0, 0]
        5
        [0, 0, 0, 1, 1]
        [0, 0, 1, 0, 1]
        [0, 0, 1, 1, 0]
        [0, 1, 0, 0, 1]
        [0, 1, 0, 1, 0]
        [0, 1, 1, 0, 0]
        [1, 0, 0, 0, 1]
        [1, 0, 0, 1, 0]
        [1, 0, 1, 0, 0]
        [1, 1, 0, 0, 0]
        6\n''')
    # A negative count in a dict input is rejected.
    raises(ValueError, lambda: list(multiset_permutations({0: 3, 1: -1})))
def test_partitions():
    """partitions(n, m, k, size): integer partitions of n as
    {part: multiplicity} dicts, with at most m parts, none larger than k."""
    # Impossible or degenerate requests yield only the empty partition.
    # With size=1 each result is paired with a count (0 for the empty one).
    ans = [[{}], [(0, {})]]
    for i in range(2):
        assert list(partitions(0, size=i)) == ans[i]
        assert list(partitions(1, 0, size=i)) == ans[i]
        assert list(partitions(6, 2, 2, size=i)) == ans[i]
        assert list(partitions(6, 2, None, size=i)) != ans[i]
        assert list(partitions(6, None, 2, size=i)) != ans[i]
        assert list(partitions(6, 2, 0, size=i)) == ans[i]

    assert list(partitions(6, k=2)) == [
        {2: 3}, {1: 2, 2: 2}, {1: 4, 2: 1}, {1: 6}]

    assert list(partitions(6, k=3)) == [
        {3: 2}, {1: 1, 2: 1, 3: 1}, {1: 3, 3: 1}, {2: 3}, {1: 2, 2: 2},
        {1: 4, 2: 1}, {1: 6}]

    # k bounds the part size and m bounds the number of parts; the filtered
    # comprehension restates that contract.
    assert list(partitions(8, k=4, m=3)) == [
        {4: 2}, {1: 1, 3: 1, 4: 1}, {2: 2, 4: 1}, {2: 1, 3: 2}] == [
        i for i in partitions(8, k=4, m=3) if all(k <= 4 for k in i)
        and sum(i.values()) <=3]

    assert list(partitions(S(3), m=2)) == [
        {3: 1}, {1: 1, 2: 1}]

    assert list(partitions(4, k=3)) == [
        {1: 1, 3: 1}, {2: 2}, {1: 2, 2: 1}, {1: 4}] == [
        i for i in partitions(4) if all(k <= 3 for k in i)]

    # Consistency check on output of _set_partitions and RGS_unrank.
    # This provides a sanity test on both routines. Also verifies that
    # the total number of partitions is the same in each case.
    # (from pkrathmann2)
    for n in range(2, 6):
        i = 0
        for m, q in _set_partitions(n):
            assert q == RGS_unrank(i, n)
            i += 1
        assert i == RGS_enum(n)
def test_binary_partitions():
    """binary_partitions(n): partitions of n into powers of 2, largest first."""
    # Each yielded partition is copied with [:] before collecting.
    # NOTE(review): presumably the generator reuses one list object between
    # yields, so the copies are required -- confirm in binary_partitions.
    assert [i[:] for i in binary_partitions(10)] == [[8, 2], [8, 1, 1],
        [4, 4, 2], [4, 4, 1, 1], [4, 2, 2, 2], [4, 2, 2, 1, 1],
        [4, 2, 1, 1, 1, 1], [4, 1, 1, 1, 1, 1, 1], [2, 2, 2, 2, 2],
        [2, 2, 2, 2, 1, 1], [2, 2, 2, 1, 1, 1, 1], [2, 2, 1, 1, 1, 1, 1, 1],
        [2, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]
    # 16 has 36 binary partitions.
    assert len([j[:] for j in binary_partitions(16)]) == 36
def test_bell_perm():
    """generate_bell(n): all n! permutations of range(n), checked against
    Permutation.next_trotterjohnson; n == 0 raises ValueError."""
    assert [len(set(generate_bell(i))) for i in range(1, 7)] == [
        factorial(i) for i in range(1, 7)]
    assert list(generate_bell(3)) == [
        (0, 1, 2), (0, 2, 1), (2, 0, 1), (2, 1, 0), (1, 2, 0), (1, 0, 2)]
    # generate_bell and trotterjohnson are advertised to return the same
    # permutations; this is not technically necessary so this test could
    # be removed
    for n in range(1, 5):
        p = Permutation(range(n))
        b = generate_bell(n)
        for bi in b:
            assert bi == tuple(p.array_form)
            p = p.next_trotterjohnson()
    raises(ValueError, lambda: list(generate_bell(0)))  # XXX is this consistent with other permutation algorithms?
  440. def test_involutions():
  441. lengths = [1, 2, 4, 10, 26, 76]
  442. for n, N in enumerate(lengths):
  443. i = list(generate_involutions(n + 1))
  444. assert len(i) == N
  445. assert len({Permutation(j)**2 for j in i}) == 1
def test_derangements():
    """generate_derangements and multiset_derangements: rearrangements in
    which no position keeps (a copy of) its original element."""
    # 6 distinct items have 265 derangements.
    assert len(list(generate_derangements(list(range(6))))) == 265
    assert ''.join(''.join(i) for i in generate_derangements('abcde')) == (
        'badecbaecdbcaedbcdeabceadbdaecbdeacbdecabeacdbedacbedcacabedcadebcaebd'
        'cdaebcdbeacdeabcdebaceabdcebadcedabcedbadabecdaebcdaecbdcaebdcbeadceab'
        'dcebadeabcdeacbdebacdebcaeabcdeadbceadcbecabdecbadecdabecdbaedabcedacb'
        'edbacedbca')
    assert list(generate_derangements([0, 1, 2, 3])) == [
        [1, 0, 3, 2], [1, 2, 3, 0], [1, 3, 0, 2], [2, 0, 3, 1],
        [2, 3, 0, 1], [2, 3, 1, 0], [3, 0, 1, 2], [3, 2, 0, 1], [3, 2, 1, 0]]
    # With repeated items, neither slot that held a 2 may receive a 2.
    assert list(generate_derangements([0, 1, 2, 2])) == [
        [2, 2, 0, 1], [2, 2, 1, 0]]
    assert list(generate_derangements('ba')) == [list('ab')]
    # multiset_derangements
    D = multiset_derangements
    # 'abb' has none: two b's but only one non-b position to move into.
    assert list(D('abb')) == []
    assert [''.join(i) for i in D('ab')] == ['ba']
    assert [''.join(i) for i in D('abc')] == ['bca', 'cab']
    assert [''.join(i) for i in D('aabb')] == ['bbaa']
    assert [''.join(i) for i in D('aabbcccc')] == [
        'ccccaabb', 'ccccabab', 'ccccabba', 'ccccbaab', 'ccccbaba',
        'ccccbbaa']
    assert [''.join(i) for i in D('aabbccc')] == [
        'cccabba', 'cccabab', 'cccaabb', 'ccacbba', 'ccacbab',
        'ccacabb', 'cbccbaa', 'cbccaba', 'cbccaab', 'bcccbaa',
        'bcccaba', 'bcccaab']
    assert [''.join(i) for i in D('books')] == ['kbsoo', 'ksboo',
        'sbkoo', 'skboo', 'oksbo', 'oskbo', 'okbso', 'obkso', 'oskob',
        'oksob', 'osbok', 'obsok']
    # Unhashable (list) elements are supported.
    assert list(generate_derangements([[3], [2], [2], [1]])) == [
        [[2], [1], [3], [2]], [[2], [3], [1], [2]]]
def test_necklaces():
    """Counts of necklaces(n, k, f) for n = 1..7.

    A nonzero third argument gives fewer results (n=6, k=2: 13 vs 14).
    Presumably that flag also identifies reflections (free necklaces,
    i.e. bracelets) -- confirm in iterables.necklaces.
    """
    def count(n, k, f):
        return len(list(necklaces(n, k, f)))
    m = []
    for i in range(1, 8):
        m.append((
            i, count(i, 2, 0), count(i, 2, 1), count(i, 3, 1)))
    # Columns: n, fixed with k=2, f=1 with k=2, f=1 with k=3.
    assert Matrix(m) == Matrix([
        [1, 2, 2, 3],
        [2, 3, 3, 6],
        [3, 4, 4, 10],
        [4, 6, 6, 21],
        [5, 8, 8, 39],
        [6, 14, 13, 92],
        [7, 20, 18, 198]])
def test_bracelets():
    """bracelets(n, k): length-n bracelets over k symbols, in the order
    listed below."""
    bc = list(bracelets(2, 4))
    assert Matrix(bc) == Matrix([
        [0, 0],
        [0, 1],
        [0, 2],
        [0, 3],
        [1, 1],
        [1, 2],
        [1, 3],
        [2, 2],
        [2, 3],
        [3, 3]
        ])
    bc = list(bracelets(4, 2))
    assert Matrix(bc) == Matrix([
        [0, 0, 0, 0],
        [0, 0, 0, 1],
        [0, 0, 1, 1],
        [0, 1, 0, 1],
        [0, 1, 1, 1],
        [1, 1, 1, 1]
    ])
def test_generate_oriented_forest():
    """generate_oriented_forest(n): the rooted forests on n nodes, each
    encoded as a length-n list of ints (20 of them for n=5)."""
    assert list(generate_oriented_forest(5)) == [[0, 1, 2, 3, 4],
        [0, 1, 2, 3, 3], [0, 1, 2, 3, 2], [0, 1, 2, 3, 1], [0, 1, 2, 3, 0],
        [0, 1, 2, 2, 2], [0, 1, 2, 2, 1], [0, 1, 2, 2, 0], [0, 1, 2, 1, 2],
        [0, 1, 2, 1, 1], [0, 1, 2, 1, 0], [0, 1, 2, 0, 1], [0, 1, 2, 0, 0],
        [0, 1, 1, 1, 1], [0, 1, 1, 1, 0], [0, 1, 1, 0, 1], [0, 1, 1, 0, 0],
        [0, 1, 0, 1, 0], [0, 1, 0, 0, 0], [0, 0, 0, 0, 0]]
    # 1842 rooted forests on 10 nodes (equivalently, rooted trees on 11).
    assert len(list(generate_oriented_forest(10))) == 1842
  523. def test_unflatten():
  524. r = list(range(10))
  525. assert unflatten(r) == list(zip(r[::2], r[1::2]))
  526. assert unflatten(r, 5) == [tuple(r[:5]), tuple(r[5:])]
  527. raises(ValueError, lambda: unflatten(list(range(10)), 3))
  528. raises(ValueError, lambda: unflatten(list(range(10)), -2))
  529. def test_common_prefix_suffix():
  530. assert common_prefix([], [1]) == []
  531. assert common_prefix(list(range(3))) == [0, 1, 2]
  532. assert common_prefix(list(range(3)), list(range(4))) == [0, 1, 2]
  533. assert common_prefix([1, 2, 3], [1, 2, 5]) == [1, 2]
  534. assert common_prefix([1, 2, 3], [1, 3, 5]) == [1]
  535. assert common_suffix([], [1]) == []
  536. assert common_suffix(list(range(3))) == [0, 1, 2]
  537. assert common_suffix(list(range(3)), list(range(3))) == [0, 1, 2]
  538. assert common_suffix(list(range(3)), list(range(4))) == []
  539. assert common_suffix([1, 2, 3], [9, 2, 3]) == [2, 3]
  540. assert common_suffix([1, 2, 3], [9, 7, 3]) == [3]
  541. def test_minlex():
  542. assert minlex([1, 2, 0]) == (0, 1, 2)
  543. assert minlex((1, 2, 0)) == (0, 1, 2)
  544. assert minlex((1, 0, 2)) == (0, 2, 1)
  545. assert minlex((1, 0, 2), directed=False) == (0, 1, 2)
  546. assert minlex('aba') == 'aab'
  547. assert minlex(('bb', 'aaa', 'c', 'a'), key=len) == ('c', 'a', 'bb', 'aaa')
def test_ordered():
    """ordered(seq, keys, default, warn): sort by successive tie-break keys."""
    # With only ``hash`` as key and no default tie-breaker the exact order is
    # not fixed, but it must not depend on the input order.
    assert list(ordered((x, y), hash, default=False)) in [[x, y], [y, x]]
    assert list(ordered((x, y), hash, default=False)) == \
        list(ordered((y, x), hash, default=False))
    assert list(ordered((x, y))) == [x, y]

    seq, keys = [[[1, 2, 1], [0, 3, 1], [1, 1, 3], [2], [1]],
                 (lambda x: len(x), lambda x: sum(x))]
    # Sorted by length, then by sum. [1, 2, 1] and [0, 3, 1] tie on both
    # keys and stay in input order here.
    assert list(ordered(seq, keys, default=False, warn=False)) == \
        [[1], [2], [1, 2, 1], [0, 3, 1], [1, 1, 3]]
    # With warn=True the unresolved tie above raises ValueError.
    raises(ValueError, lambda:
           list(ordered(seq, keys, default=False, warn=True)))
  559. def test_runs():
  560. assert runs([]) == []
  561. assert runs([1]) == [[1]]
  562. assert runs([1, 1]) == [[1], [1]]
  563. assert runs([1, 1, 2]) == [[1], [1, 2]]
  564. assert runs([1, 2, 1]) == [[1, 2], [1]]
  565. assert runs([2, 1, 1]) == [[2], [1], [1]]
  566. from operator import lt
  567. assert runs([2, 1, 1], lt) == [[2, 1], [1]]
def test_reshape():
    """reshape(seq, how): regroup a flat sequence following a template; the
    template's container types (list/tuple/set) appear in the output."""
    seq = list(range(1, 9))
    assert reshape(seq, [4]) == \
        [[1, 2, 3, 4], [5, 6, 7, 8]]
    assert reshape(seq, (4,)) == \
        [(1, 2, 3, 4), (5, 6, 7, 8)]
    assert reshape(seq, (2, 2)) == \
        [(1, 2, 3, 4), (5, 6, 7, 8)]
    assert reshape(seq, (2, [2])) == \
        [(1, 2, [3, 4]), (5, 6, [7, 8])]
    assert reshape(seq, ((2,), [2])) == \
        [((1, 2), [3, 4]), ((5, 6), [7, 8])]
    assert reshape(seq, (1, [2], 1)) == \
        [(1, [2, 3], 4), (5, [6, 7], 8)]
    # A tuple input produces a tuple of groups.
    assert reshape(tuple(seq), ([[1], 1, (2,)],)) == \
        (([[1], 2, (3, 4)],), ([[5], 6, (7, 8)],))
    assert reshape(tuple(seq), ([1], 1, (2,))) == \
        (([1], 2, (3, 4)), ([5], 6, (7, 8)))
    assert reshape(list(range(12)), [2, [3], {2}, (1, (3,), 1)]) == \
        [[0, 1, [2, 3, 4], {5, 6}, (7, (8, 9, 10), 11)]]
    # A negative size, or a template that does not fit the input length,
    # raises ValueError.
    raises(ValueError, lambda: reshape([0, 1], [-1]))
    raises(ValueError, lambda: reshape([0, 1], [3]))
def test_uniq():
    """uniq(seq): unique items in first-seen order, unhashable ones included."""
    assert list(uniq(p for p in partitions(4))) == \
        [{4: 1}, {1: 1, 3: 1}, {2: 2}, {1: 2, 2: 1}, {1: 4}]
    assert list(uniq(x % 2 for x in range(5))) == [0, 1]
    assert list(uniq('a')) == ['a']
    assert list(uniq('ababc')) == list('abc')
    assert list(uniq([[1], [2, 1], [1]])) == [[1], [2, 1]]
    assert list(uniq(permutations(i for i in [[1], 2, 2]))) == \
        [([1], 2, 2), (2, [1], 2), (2, 2, [1])]
    assert list(uniq([2, 3, 2, 4, [2], [1], [2], [3], [1]])) == \
        [2, 3, 4, [2], [1], [3]]
    # Removing items from the list while uniq() iterates over it raises
    # RuntimeError, for hashable and unhashable elements alike.
    f = [1]
    raises(RuntimeError, lambda: [f.remove(i) for i in uniq(f)])
    f = [[1]]
    raises(RuntimeError, lambda: [f.remove(i) for i in uniq(f)])
  def test_kbins():
      """Check kbins counts and the exact 2-bin enumeration for each ``ordered`` flag."""
      assert len(list(kbins('1123', 2, ordered=1))) == 24
      assert len(list(kbins('1123', 2, ordered=11))) == 36
      assert len(list(kbins('1123', 2, ordered=10))) == 10
      assert len(list(kbins('1123', 2, ordered=0))) == 5
      assert len(list(kbins('1123', 2, ordered=None))) == 3

      def show(seq):
          # Print every 2-bin split of ``seq`` under each ``ordered`` setting.
          # This lets the whole enumeration be compared as one block of text.
          # ``print(' ', p)`` emits one space plus print's separator space,
          # so every split appears indented two spaces under its header.
          for orderedval in [None, 0, 1, 10, 11]:
              print('ordered =', orderedval)
              for p in kbins(seq, 2, ordered=orderedval):
                  print(' ', p)

      # The expected text keeps the same two-space relative indent that
      # show() prints; dedent() strips only the shared left margin.
      assert capture(lambda: show([0, 0, 1])) == dedent('''\
          ordered = None
            [[0], [0, 1]]
            [[0, 0], [1]]
          ordered = 0
            [[0, 0], [1]]
            [[0, 1], [0]]
          ordered = 1
            [[0], [0, 1]]
            [[0], [1, 0]]
            [[1], [0, 0]]
          ordered = 10
            [[0, 0], [1]]
            [[1], [0, 0]]
            [[0, 1], [0]]
            [[0], [0, 1]]
          ordered = 11
            [[0], [0, 1]]
            [[0, 0], [1]]
            [[0], [1, 0]]
            [[0, 1], [0]]
            [[1], [0, 0]]
            [[1, 0], [0]]\n''')

      assert capture(lambda: show(list(range(3)))) == dedent('''\
          ordered = None
            [[0], [1, 2]]
            [[0, 1], [2]]
          ordered = 0
            [[0, 1], [2]]
            [[0, 2], [1]]
            [[0], [1, 2]]
          ordered = 1
            [[0], [1, 2]]
            [[0], [2, 1]]
            [[1], [0, 2]]
            [[1], [2, 0]]
            [[2], [0, 1]]
            [[2], [1, 0]]
          ordered = 10
            [[0, 1], [2]]
            [[2], [0, 1]]
            [[0, 2], [1]]
            [[1], [0, 2]]
            [[0], [1, 2]]
            [[1, 2], [0]]
          ordered = 11
            [[0], [1, 2]]
            [[0, 1], [2]]
            [[0], [2, 1]]
            [[0, 2], [1]]
            [[1], [0, 2]]
            [[1, 0], [2]]
            [[1], [2, 0]]
            [[1, 2], [0]]
            [[2], [0, 1]]
            [[2, 0], [1]]
            [[2], [1, 0]]
            [[2, 1], [0]]\n''')
  def test_has_dups():
      """has_dups reports repeated elements, including unhashable (list) elements."""
      cases = [
          (set(), False),
          (list(range(3)), False),
          ([1, 2, 1], True),
          ([[1], [1]], True),
          ([[1], [2]], False),
      ]
      for seq, expected in cases:
          assert has_dups(seq) is expected
  def test__partition():
      """_partition groups items by label; three calling conventions agree."""
      groups = [['b', 'e'], ['a', 'c'], ['d']]
      # Labels only, without an explicit group count.
      assert _partition('abcde', [1, 0, 1, 2, 0]) == groups
      # Labels followed by an explicit group count.
      assert _partition('abcde', [1, 0, 1, 2, 0], 3) == groups
      # Count first, then labels, supplied by unpacking a (count, labels) tuple.
      output = (3, [1, 0, 1, 2, 0])
      assert _partition('abcde', *output) == groups
  def test_ordered_partitions():
      """ordered_partitions counts agree with nT for both third-argument settings."""
      from sympy.functions.combinatorial.numbers import nT
      f = ordered_partitions
      # Degenerate inputs each yield exactly one empty partition.
      assert list(f(0, 1)) == [[]]
      assert list(f(1, 0)) == [[]]
      # For each n, check m=None and every m in 1..n-1: the number of
      # partitions produced with the third argument set to 1 and to 0 must
      # both equal nT(n, m).
      for n in range(1, 7):
          for m in [None, *range(1, n)]:
              counts = [sum(1 for _ in f(n, m, flag)) for flag in (1, 0)]
              assert counts[0] == counts[1] == nT(n, m)
  def test_rotations():
      """rotations cycles items left by default and right with dir=-1."""
      assert list(rotations('ab')) == [['a', 'b'], ['b', 'a']]
      # Default: the first item moves to the end on each step.
      leftward = [[0, 1, 2], [1, 2, 0], [2, 0, 1]]
      # dir=-1: the last item moves to the front on each step.
      rightward = [[0, 1, 2], [2, 0, 1], [1, 2, 0]]
      assert list(rotations(range(3))) == leftward
      assert list(rotations(range(3), dir=-1)) == rightward
  def test_ibin():
      """ibin renders integers as bit lists/strings and enumerates all patterns."""
      # The value 3 as a digit list and as a string, natural and padded width.
      assert ibin(3) == [1, 1]
      assert ibin(3, 3) == [0, 1, 1]
      assert ibin(3, str=True) == '11'
      assert ibin(3, 3, str=True) == '011'
      # With 'all' (or '') as the second argument, every 2-bit pattern is produced.
      assert list(ibin(2, 'all')) == [(0, 0), (0, 1), (1, 0), (1, 1)]
      assert list(ibin(2, '', str=True)) == ['00', '01', '10', '11']
      # Rejected inputs: a negative non-integer value, and a width of 1 for 2.
      for bad_args in [(-.5,), (2, 1)]:
          raises(ValueError, lambda bad_args=bad_args: ibin(*bad_args))
  def test_iterable():
      """iterable() honours __iter__, the NotIterable mixin and the _iterable flag."""
      for scalar in (0, 1, None):
          assert iterable(scalar) is False

      # NotIterable subclasses are rejected unless _iterable is set True.
      class Test1(NotIterable):
          pass

      class Test2(NotIterable):
          _iterable = True

      assert iterable(Test1()) is False
      assert iterable(Test2()) is True

      # Plain classes: rejected by default, accepted with _iterable = True.
      class Test3:
          pass

      class Test4:
          _iterable = True

      assert iterable(Test3()) is False
      assert iterable(Test4()) is True

      # Defining __iter__ makes an object iterable, unless a subclass
      # sets _iterable = False.
      class Test5:
          def __iter__(self):
              yield 1

      class Test6(Test5):
          _iterable = False

      assert iterable(Test5()) is True
      assert iterable(Test6()) is False
  def test_sequence_partitions():
      """sequence_partitions splits a sequence into n non-empty contiguous parts."""
      # (sequence, number of parts, expected list of splits)
      cases = [
          ([1], 1, [[[1]]]),
          ([1, 2], 1, [[[1, 2]]]),
          ([1, 2], 2, [[[1], [2]]]),
          ([1, 2, 3], 1, [[[1, 2, 3]]]),
          ([1, 2, 3], 2, [[[1], [2, 3]], [[1, 2], [3]]]),
          ([1, 2, 3], 3, [[[1], [2], [3]]]),
          # Exceptional cases: each of these yields nothing.
          ([], 0, []),
          ([], 1, []),
          ([1, 2], 0, []),
          ([1, 2], 3, []),
      ]
      for seq, n, expected in cases:
          assert list(sequence_partitions(seq, n)) == expected
  def test_sequence_partitions_empty():
      """sequence_partitions_empty splits into n contiguous parts, allowing empty ones."""
      assert list(sequence_partitions_empty([], 1)) == [[[]]]
      assert list(sequence_partitions_empty([], 2)) == [[[], []]]
      assert list(sequence_partitions_empty([], 3)) == [[[], [], []]]
      assert list(sequence_partitions_empty([1], 1)) == [[[1]]]
      assert list(sequence_partitions_empty([1], 2)) == [[[], [1]], [[1], []]]
      assert list(sequence_partitions_empty([1], 3)) == \
          [[[], [], [1]], [[], [1], []], [[1], [], []]]
      assert list(sequence_partitions_empty([1, 2], 1)) == [[[1, 2]]]
      assert list(sequence_partitions_empty([1, 2], 2)) == \
          [[[], [1, 2]], [[1], [2]], [[1, 2], []]]
      assert list(sequence_partitions_empty([1, 2], 3)) == [
          [[], [], [1, 2]], [[], [1], [2]], [[], [1, 2], []],
          [[1], [], [2]], [[1], [2], []], [[1, 2], [], []]
      ]
      assert list(sequence_partitions_empty([1, 2, 3], 1)) == [[[1, 2, 3]]]
      assert list(sequence_partitions_empty([1, 2, 3], 2)) == \
          [[[], [1, 2, 3]], [[1], [2, 3]], [[1, 2], [3]], [[1, 2, 3], []]]
      assert list(sequence_partitions_empty([1, 2, 3], 3)) == [
          [[], [], [1, 2, 3]], [[], [1], [2, 3]],
          [[], [1, 2], [3]], [[], [1, 2, 3], []],
          [[1], [], [2, 3]], [[1], [2], [3]],
          [[1], [2, 3], []], [[1, 2], [], [3]],
          [[1, 2], [3], []], [[1, 2, 3], [], []]
      ]
      # Exceptional cases: asking for zero parts yields nothing.  These must
      # call sequence_partitions_empty itself; calling sequence_partitions
      # here would only repeat checks made in test_sequence_partitions.
      assert list(sequence_partitions_empty([], 0)) == []
      assert list(sequence_partitions_empty([1], 0)) == []
      assert list(sequence_partitions_empty([1, 2], 0)) == []