test_convolutions.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365
  1. from sympy.core.numbers import (E, Rational, pi)
  2. from sympy.functions.elementary.exponential import exp
  3. from sympy.functions.elementary.miscellaneous import sqrt
  4. from sympy.core import S, symbols, I
  5. from sympy.discrete.convolutions import (
  6. convolution, convolution_fft, convolution_ntt, convolution_fwht,
  7. convolution_subset, covering_product, intersecting_product)
  8. from sympy.testing.pytest import raises
  9. from sympy.abc import x, y
  10. def test_convolution():
  11. # fft
  12. a = [1, Rational(5, 3), sqrt(3), Rational(7, 5)]
  13. b = [9, 5, 5, 4, 3, 2]
  14. c = [3, 5, 3, 7, 8]
  15. d = [1422, 6572, 3213, 5552]
  16. assert convolution(a, b) == convolution_fft(a, b)
  17. assert convolution(a, b, dps=9) == convolution_fft(a, b, dps=9)
  18. assert convolution(a, d, dps=7) == convolution_fft(d, a, dps=7)
  19. assert convolution(a, d[1:], dps=3) == convolution_fft(d[1:], a, dps=3)
  20. # prime moduli of the form (m*2**k + 1), sequence length
  21. # should be a divisor of 2**k
  22. p = 7*17*2**23 + 1
  23. q = 19*2**10 + 1
  24. # ntt
  25. assert convolution(d, b, prime=q) == convolution_ntt(b, d, prime=q)
  26. assert convolution(c, b, prime=p) == convolution_ntt(b, c, prime=p)
  27. assert convolution(d, c, prime=p) == convolution_ntt(c, d, prime=p)
  28. raises(TypeError, lambda: convolution(b, d, dps=5, prime=q))
  29. raises(TypeError, lambda: convolution(b, d, dps=6, prime=q))
  30. # fwht
  31. assert convolution(a, b, dyadic=True) == convolution_fwht(a, b)
  32. assert convolution(a, b, dyadic=False) == convolution(a, b)
  33. raises(TypeError, lambda: convolution(b, d, dps=2, dyadic=True))
  34. raises(TypeError, lambda: convolution(b, d, prime=p, dyadic=True))
  35. raises(TypeError, lambda: convolution(a, b, dps=2, dyadic=True))
  36. raises(TypeError, lambda: convolution(b, c, prime=p, dyadic=True))
  37. # subset
  38. assert convolution(a, b, subset=True) == convolution_subset(a, b) == \
  39. convolution(a, b, subset=True, dyadic=False) == \
  40. convolution(a, b, subset=True)
  41. assert convolution(a, b, subset=False) == convolution(a, b)
  42. raises(TypeError, lambda: convolution(a, b, subset=True, dyadic=True))
  43. raises(TypeError, lambda: convolution(c, d, subset=True, dps=6))
  44. raises(TypeError, lambda: convolution(a, c, subset=True, prime=q))
  45. def test_cyclic_convolution():
  46. # fft
  47. a = [1, Rational(5, 3), sqrt(3), Rational(7, 5)]
  48. b = [9, 5, 5, 4, 3, 2]
  49. assert convolution([1, 2, 3], [4, 5, 6], cycle=0) == \
  50. convolution([1, 2, 3], [4, 5, 6], cycle=5) == \
  51. convolution([1, 2, 3], [4, 5, 6])
  52. assert convolution([1, 2, 3], [4, 5, 6], cycle=3) == [31, 31, 28]
  53. a = [Rational(1, 3), Rational(7, 3), Rational(5, 9), Rational(2, 7), Rational(5, 8)]
  54. b = [Rational(3, 5), Rational(4, 7), Rational(7, 8), Rational(8, 9)]
  55. assert convolution(a, b, cycle=0) == \
  56. convolution(a, b, cycle=len(a) + len(b) - 1)
  57. assert convolution(a, b, cycle=4) == [Rational(87277, 26460), Rational(30521, 11340),
  58. Rational(11125, 4032), Rational(3653, 1080)]
  59. assert convolution(a, b, cycle=6) == [Rational(20177, 20160), Rational(676, 315), Rational(47, 24),
  60. Rational(3053, 1080), Rational(16397, 5292), Rational(2497, 2268)]
  61. assert convolution(a, b, cycle=9) == \
  62. convolution(a, b, cycle=0) + [S.Zero]
  63. # ntt
  64. a = [2313, 5323532, S(3232), 42142, 42242421]
  65. b = [S(33456), 56757, 45754, 432423]
  66. assert convolution(a, b, prime=19*2**10 + 1, cycle=0) == \
  67. convolution(a, b, prime=19*2**10 + 1, cycle=8) == \
  68. convolution(a, b, prime=19*2**10 + 1)
  69. assert convolution(a, b, prime=19*2**10 + 1, cycle=5) == [96, 17146, 2664,
  70. 15534, 3517]
  71. assert convolution(a, b, prime=19*2**10 + 1, cycle=7) == [4643, 3458, 1260,
  72. 15534, 3517, 16314, 13688]
  73. assert convolution(a, b, prime=19*2**10 + 1, cycle=9) == \
  74. convolution(a, b, prime=19*2**10 + 1) + [0]
  75. # fwht
  76. u, v, w, x, y = symbols('u v w x y')
  77. p, q, r, s, t = symbols('p q r s t')
  78. c = [u, v, w, x, y]
  79. d = [p, q, r, s, t]
  80. assert convolution(a, b, dyadic=True, cycle=3) == \
  81. [2499522285783, 19861417974796, 4702176579021]
  82. assert convolution(a, b, dyadic=True, cycle=5) == [2718149225143,
  83. 2114320852171, 20571217906407, 246166418903, 1413262436976]
  84. assert convolution(c, d, dyadic=True, cycle=4) == \
  85. [p*u + p*y + q*v + r*w + s*x + t*u + t*y,
  86. p*v + q*u + q*y + r*x + s*w + t*v,
  87. p*w + q*x + r*u + r*y + s*v + t*w,
  88. p*x + q*w + r*v + s*u + s*y + t*x]
  89. assert convolution(c, d, dyadic=True, cycle=6) == \
  90. [p*u + q*v + r*w + r*y + s*x + t*w + t*y,
  91. p*v + q*u + r*x + s*w + s*y + t*x,
  92. p*w + q*x + r*u + s*v,
  93. p*x + q*w + r*v + s*u,
  94. p*y + t*u,
  95. q*y + t*v]
  96. # subset
  97. assert convolution(a, b, subset=True, cycle=7) == [18266671799811,
  98. 178235365533, 213958794, 246166418903, 1413262436976,
  99. 2397553088697, 1932759730434]
  100. assert convolution(a[1:], b, subset=True, cycle=4) == \
  101. [178104086592, 302255835516, 244982785880, 3717819845434]
  102. assert convolution(a, b[:-1], subset=True, cycle=6) == [1932837114162,
  103. 178235365533, 213958794, 245166224504, 1413262436976, 2397553088697]
  104. assert convolution(c, d, subset=True, cycle=3) == \
  105. [p*u + p*x + q*w + r*v + r*y + s*u + t*w,
  106. p*v + p*y + q*u + s*y + t*u + t*x,
  107. p*w + q*y + r*u + t*v]
  108. assert convolution(c, d, subset=True, cycle=5) == \
  109. [p*u + q*y + t*v,
  110. p*v + q*u + r*y + t*w,
  111. p*w + r*u + s*y + t*x,
  112. p*x + q*w + r*v + s*u,
  113. p*y + t*u]
  114. raises(ValueError, lambda: convolution([1, 2, 3], [4, 5, 6], cycle=-1))
  115. def test_convolution_fft():
  116. assert all(convolution_fft([], x, dps=y) == [] for x in ([], [1]) for y in (None, 3))
  117. assert convolution_fft([1, 2, 3], [4, 5, 6]) == [4, 13, 28, 27, 18]
  118. assert convolution_fft([1], [5, 6, 7]) == [5, 6, 7]
  119. assert convolution_fft([1, 3], [5, 6, 7]) == [5, 21, 25, 21]
  120. assert convolution_fft([1 + 2*I], [2 + 3*I]) == [-4 + 7*I]
  121. assert convolution_fft([1 + 2*I, 3 + 4*I, 5 + 3*I/5], [Rational(2, 5) + 4*I/7]) == \
  122. [Rational(-26, 35) + I*48/35, Rational(-38, 35) + I*116/35, Rational(58, 35) + I*542/175]
  123. assert convolution_fft([Rational(3, 4), Rational(5, 6)], [Rational(7, 8), Rational(1, 3), Rational(2, 5)]) == \
  124. [Rational(21, 32), Rational(47, 48), Rational(26, 45), Rational(1, 3)]
  125. assert convolution_fft([Rational(1, 9), Rational(2, 3), Rational(3, 5)], [Rational(2, 5), Rational(3, 7), Rational(4, 9)]) == \
  126. [Rational(2, 45), Rational(11, 35), Rational(8152, 14175), Rational(523, 945), Rational(4, 15)]
  127. assert convolution_fft([pi, E, sqrt(2)], [sqrt(3), 1/pi, 1/E]) == \
  128. [sqrt(3)*pi, 1 + sqrt(3)*E, E/pi + pi*exp(-1) + sqrt(6),
  129. sqrt(2)/pi + 1, sqrt(2)*exp(-1)]
  130. assert convolution_fft([2321, 33123], [5321, 6321, 71323]) == \
  131. [12350041, 190918524, 374911166, 2362431729]
  132. assert convolution_fft([312313, 31278232], [32139631, 319631]) == \
  133. [10037624576503, 1005370659728895, 9997492572392]
  134. raises(TypeError, lambda: convolution_fft(x, y))
  135. raises(ValueError, lambda: convolution_fft([x, y], [y, x]))
  136. def test_convolution_ntt():
  137. # prime moduli of the form (m*2**k + 1), sequence length
  138. # should be a divisor of 2**k
  139. p = 7*17*2**23 + 1
  140. q = 19*2**10 + 1
  141. r = 2*500000003 + 1 # only for sequences of length 1 or 2
  142. # s = 2*3*5*7 # composite modulus
  143. assert all(convolution_ntt([], x, prime=y) == [] for x in ([], [1]) for y in (p, q, r))
  144. assert convolution_ntt([2], [3], r) == [6]
  145. assert convolution_ntt([2, 3], [4], r) == [8, 12]
  146. assert convolution_ntt([32121, 42144, 4214, 4241], [32132, 3232, 87242], p) == [33867619,
  147. 459741727, 79180879, 831885249, 381344700, 369993322]
  148. assert convolution_ntt([121913, 3171831, 31888131, 12], [17882, 21292, 29921, 312], q) == \
  149. [8158, 3065, 3682, 7090, 1239, 2232, 3744]
  150. assert convolution_ntt([12, 19, 21, 98, 67], [2, 6, 7, 8, 9], p) == \
  151. convolution_ntt([12, 19, 21, 98, 67], [2, 6, 7, 8, 9], q)
  152. assert convolution_ntt([12, 19, 21, 98, 67], [21, 76, 17, 78, 69], p) == \
  153. convolution_ntt([12, 19, 21, 98, 67], [21, 76, 17, 78, 69], q)
  154. raises(ValueError, lambda: convolution_ntt([2, 3], [4, 5], r))
  155. raises(ValueError, lambda: convolution_ntt([x, y], [y, x], q))
  156. raises(TypeError, lambda: convolution_ntt(x, y, p))
  157. def test_convolution_fwht():
  158. assert convolution_fwht([], []) == []
  159. assert convolution_fwht([], [1]) == []
  160. assert convolution_fwht([1, 2, 3], [4, 5, 6]) == [32, 13, 18, 27]
  161. assert convolution_fwht([Rational(5, 7), Rational(6, 8), Rational(7, 3)], [2, 4, Rational(6, 7)]) == \
  162. [Rational(45, 7), Rational(61, 14), Rational(776, 147), Rational(419, 42)]
  163. a = [1, Rational(5, 3), sqrt(3), Rational(7, 5), 4 + 5*I]
  164. b = [94, 51, 53, 45, 31, 27, 13]
  165. c = [3 + 4*I, 5 + 7*I, 3, Rational(7, 6), 8]
  166. assert convolution_fwht(a, b) == [53*sqrt(3) + 366 + 155*I,
  167. 45*sqrt(3) + Rational(5848, 15) + 135*I,
  168. 94*sqrt(3) + Rational(1257, 5) + 65*I,
  169. 51*sqrt(3) + Rational(3974, 15),
  170. 13*sqrt(3) + 452 + 470*I,
  171. Rational(4513, 15) + 255*I,
  172. 31*sqrt(3) + Rational(1314, 5) + 265*I,
  173. 27*sqrt(3) + Rational(3676, 15) + 225*I]
  174. assert convolution_fwht(b, c) == [Rational(1993, 2) + 733*I, Rational(6215, 6) + 862*I,
  175. Rational(1659, 2) + 527*I, Rational(1988, 3) + 551*I, 1019 + 313*I, Rational(3955, 6) + 325*I,
  176. Rational(1175, 2) + 52*I, Rational(3253, 6) + 91*I]
  177. assert convolution_fwht(a[3:], c) == [Rational(-54, 5) + I*293/5, -1 + I*204/5,
  178. Rational(133, 15) + I*35/6, Rational(409, 30) + 15*I, Rational(56, 5), 32 + 40*I, 0, 0]
  179. u, v, w, x, y, z = symbols('u v w x y z')
  180. assert convolution_fwht([u, v], [x, y]) == [u*x + v*y, u*y + v*x]
  181. assert convolution_fwht([u, v, w], [x, y]) == \
  182. [u*x + v*y, u*y + v*x, w*x, w*y]
  183. assert convolution_fwht([u, v, w], [x, y, z]) == \
  184. [u*x + v*y + w*z, u*y + v*x, u*z + w*x, v*z + w*y]
  185. raises(TypeError, lambda: convolution_fwht(x, y))
  186. raises(TypeError, lambda: convolution_fwht(x*y, u + v))
  187. def test_convolution_subset():
  188. assert convolution_subset([], []) == []
  189. assert convolution_subset([], [Rational(1, 3)]) == []
  190. assert convolution_subset([6 + I*3/7], [Rational(2, 3)]) == [4 + I*2/7]
  191. a = [1, Rational(5, 3), sqrt(3), 4 + 5*I]
  192. b = [64, 71, 55, 47, 33, 29, 15]
  193. c = [3 + I*2/3, 5 + 7*I, 7, Rational(7, 5), 9]
  194. assert convolution_subset(a, b) == [64, Rational(533, 3), 55 + 64*sqrt(3),
  195. 71*sqrt(3) + Rational(1184, 3) + 320*I, 33, 84,
  196. 15 + 33*sqrt(3), 29*sqrt(3) + 157 + 165*I]
  197. assert convolution_subset(b, c) == [192 + I*128/3, 533 + I*1486/3,
  198. 613 + I*110/3, Rational(5013, 5) + I*1249/3,
  199. 675 + 22*I, 891 + I*751/3,
  200. 771 + 10*I, Rational(3736, 5) + 105*I]
  201. assert convolution_subset(a, c) == convolution_subset(c, a)
  202. assert convolution_subset(a[:2], b) == \
  203. [64, Rational(533, 3), 55, Rational(416, 3), 33, 84, 15, 25]
  204. assert convolution_subset(a[:2], c) == \
  205. [3 + I*2/3, 10 + I*73/9, 7, Rational(196, 15), 9, 15, 0, 0]
  206. u, v, w, x, y, z = symbols('u v w x y z')
  207. assert convolution_subset([u, v, w], [x, y]) == [u*x, u*y + v*x, w*x, w*y]
  208. assert convolution_subset([u, v, w, x], [y, z]) == \
  209. [u*y, u*z + v*y, w*y, w*z + x*y]
  210. assert convolution_subset([u, v], [x, y, z]) == \
  211. convolution_subset([x, y, z], [u, v])
  212. raises(TypeError, lambda: convolution_subset(x, z))
  213. raises(TypeError, lambda: convolution_subset(Rational(7, 3), u))
  214. def test_covering_product():
  215. assert covering_product([], []) == []
  216. assert covering_product([], [Rational(1, 3)]) == []
  217. assert covering_product([6 + I*3/7], [Rational(2, 3)]) == [4 + I*2/7]
  218. a = [1, Rational(5, 8), sqrt(7), 4 + 9*I]
  219. b = [66, 81, 95, 49, 37, 89, 17]
  220. c = [3 + I*2/3, 51 + 72*I, 7, Rational(7, 15), 91]
  221. assert covering_product(a, b) == [66, Rational(1383, 8), 95 + 161*sqrt(7),
  222. 130*sqrt(7) + 1303 + 2619*I, 37,
  223. Rational(671, 4), 17 + 54*sqrt(7),
  224. 89*sqrt(7) + Rational(4661, 8) + 1287*I]
  225. assert covering_product(b, c) == [198 + 44*I, 7740 + 10638*I,
  226. 1412 + I*190/3, Rational(42684, 5) + I*31202/3,
  227. 9484 + I*74/3, 22163 + I*27394/3,
  228. 10621 + I*34/3, Rational(90236, 15) + 1224*I]
  229. assert covering_product(a, c) == covering_product(c, a)
  230. assert covering_product(b, c[:-1]) == [198 + 44*I, 7740 + 10638*I,
  231. 1412 + I*190/3, Rational(42684, 5) + I*31202/3,
  232. 111 + I*74/3, 6693 + I*27394/3,
  233. 429 + I*34/3, Rational(23351, 15) + 1224*I]
  234. assert covering_product(a, c[:-1]) == [3 + I*2/3,
  235. Rational(339, 4) + I*1409/12, 7 + 10*sqrt(7) + 2*sqrt(7)*I/3,
  236. -403 + 772*sqrt(7)/15 + 72*sqrt(7)*I + I*12658/15]
  237. u, v, w, x, y, z = symbols('u v w x y z')
  238. assert covering_product([u, v, w], [x, y]) == \
  239. [u*x, u*y + v*x + v*y, w*x, w*y]
  240. assert covering_product([u, v, w, x], [y, z]) == \
  241. [u*y, u*z + v*y + v*z, w*y, w*z + x*y + x*z]
  242. assert covering_product([u, v], [x, y, z]) == \
  243. covering_product([x, y, z], [u, v])
  244. raises(TypeError, lambda: covering_product(x, z))
  245. raises(TypeError, lambda: covering_product(Rational(7, 3), u))
  246. def test_intersecting_product():
  247. assert intersecting_product([], []) == []
  248. assert intersecting_product([], [Rational(1, 3)]) == []
  249. assert intersecting_product([6 + I*3/7], [Rational(2, 3)]) == [4 + I*2/7]
  250. a = [1, sqrt(5), Rational(3, 8) + 5*I, 4 + 7*I]
  251. b = [67, 51, 65, 48, 36, 79, 27]
  252. c = [3 + I*2/5, 5 + 9*I, 7, Rational(7, 19), 13]
  253. assert intersecting_product(a, b) == [195*sqrt(5) + Rational(6979, 8) + 1886*I,
  254. 178*sqrt(5) + 520 + 910*I, Rational(841, 2) + 1344*I,
  255. 192 + 336*I, 0, 0, 0, 0]
  256. assert intersecting_product(b, c) == [Rational(128553, 19) + I*9521/5,
  257. Rational(17820, 19) + 1602*I, Rational(19264, 19), Rational(336, 19), 1846, 0, 0, 0]
  258. assert intersecting_product(a, c) == intersecting_product(c, a)
  259. assert intersecting_product(b[1:], c[:-1]) == [Rational(64788, 19) + I*8622/5,
  260. Rational(12804, 19) + 1152*I, Rational(11508, 19), Rational(252, 19), 0, 0, 0, 0]
  261. assert intersecting_product(a, c[:-2]) == \
  262. [Rational(-99, 5) + 10*sqrt(5) + 2*sqrt(5)*I/5 + I*3021/40,
  263. -43 + 5*sqrt(5) + 9*sqrt(5)*I + 71*I, Rational(245, 8) + 84*I, 0]
  264. u, v, w, x, y, z = symbols('u v w x y z')
  265. assert intersecting_product([u, v, w], [x, y]) == \
  266. [u*x + u*y + v*x + w*x + w*y, v*y, 0, 0]
  267. assert intersecting_product([u, v, w, x], [y, z]) == \
  268. [u*y + u*z + v*y + w*y + w*z + x*y, v*z + x*z, 0, 0]
  269. assert intersecting_product([u, v], [x, y, z]) == \
  270. intersecting_product([x, y, z], [u, v])
  271. raises(TypeError, lambda: intersecting_product(x, z))
  272. raises(TypeError, lambda: intersecting_product(u, Rational(8, 3)))