test_transforms.py 5.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154
  1. from sympy.functions.elementary.miscellaneous import sqrt
  2. from sympy.core import S, Symbol, symbols, I, Rational
  3. from sympy.discrete import (fft, ifft, ntt, intt, fwht, ifwht,
  4. mobius_transform, inverse_mobius_transform)
  5. from sympy.testing.pytest import raises
  6. def test_fft_ifft():
  7. assert all(tf(ls) == ls for tf in (fft, ifft)
  8. for ls in ([], [Rational(5, 3)]))
  9. ls = list(range(6))
  10. fls = [15, -7*sqrt(2)/2 - 4 - sqrt(2)*I/2 + 2*I, 2 + 3*I,
  11. -4 + 7*sqrt(2)/2 - 2*I - sqrt(2)*I/2, -3,
  12. -4 + 7*sqrt(2)/2 + sqrt(2)*I/2 + 2*I,
  13. 2 - 3*I, -7*sqrt(2)/2 - 4 - 2*I + sqrt(2)*I/2]
  14. assert fft(ls) == fls
  15. assert ifft(fls) == ls + [S.Zero]*2
  16. ls = [1 + 2*I, 3 + 4*I, 5 + 6*I]
  17. ifls = [Rational(9, 4) + 3*I, I*Rational(-7, 4), Rational(3, 4) + I, -2 - I/4]
  18. assert ifft(ls) == ifls
  19. assert fft(ifls) == ls + [S.Zero]
  20. x = Symbol('x', real=True)
  21. raises(TypeError, lambda: fft(x))
  22. raises(ValueError, lambda: ifft([x, 2*x, 3*x**2, 4*x**3]))
  23. def test_ntt_intt():
  24. # prime moduli of the form (m*2**k + 1), sequence length
  25. # should be a divisor of 2**k
  26. p = 7*17*2**23 + 1
  27. q = 2*500000003 + 1 # only for sequences of length 1 or 2
  28. r = 2*3*5*7 # composite modulus
  29. assert all(tf(ls, p) == ls for tf in (ntt, intt)
  30. for ls in ([], [5]))
  31. ls = list(range(6))
  32. nls = [15, 801133602, 738493201, 334102277, 998244350, 849020224,
  33. 259751156, 12232587]
  34. assert ntt(ls, p) == nls
  35. assert intt(nls, p) == ls + [0]*2
  36. ls = [1 + 2*I, 3 + 4*I, 5 + 6*I]
  37. x = Symbol('x', integer=True)
  38. raises(TypeError, lambda: ntt(x, p))
  39. raises(ValueError, lambda: intt([x, 2*x, 3*x**2, 4*x**3], p))
  40. raises(ValueError, lambda: intt(ls, p))
  41. raises(ValueError, lambda: ntt([1.2, 2.1, 3.5], p))
  42. raises(ValueError, lambda: ntt([3, 5, 6], q))
  43. raises(ValueError, lambda: ntt([4, 5, 7], r))
  44. raises(ValueError, lambda: ntt([1.0, 2.0, 3.0], p))
  45. def test_fwht_ifwht():
  46. assert all(tf(ls) == ls for tf in (fwht, ifwht) \
  47. for ls in ([], [Rational(7, 4)]))
  48. ls = [213, 321, 43235, 5325, 312, 53]
  49. fls = [49459, 38061, -47661, -37759, 48729, 37543, -48391, -38277]
  50. assert fwht(ls) == fls
  51. assert ifwht(fls) == ls + [S.Zero]*2
  52. ls = [S.Half + 2*I, Rational(3, 7) + 4*I, Rational(5, 6) + 6*I, Rational(7, 3), Rational(9, 4)]
  53. ifls = [Rational(533, 672) + I*3/2, Rational(23, 224) + I/2, Rational(1, 672), Rational(107, 224) - I,
  54. Rational(155, 672) + I*3/2, Rational(-103, 224) + I/2, Rational(-377, 672), Rational(-19, 224) - I]
  55. assert ifwht(ls) == ifls
  56. assert fwht(ifls) == ls + [S.Zero]*3
  57. x, y = symbols('x y')
  58. raises(TypeError, lambda: fwht(x))
  59. ls = [x, 2*x, 3*x**2, 4*x**3]
  60. ifls = [x**3 + 3*x**2/4 + x*Rational(3, 4),
  61. -x**3 + 3*x**2/4 - x/4,
  62. -x**3 - 3*x**2/4 + x*Rational(3, 4),
  63. x**3 - 3*x**2/4 - x/4]
  64. assert ifwht(ls) == ifls
  65. assert fwht(ifls) == ls
  66. ls = [x, y, x**2, y**2, x*y]
  67. fls = [x**2 + x*y + x + y**2 + y,
  68. x**2 + x*y + x - y**2 - y,
  69. -x**2 + x*y + x - y**2 + y,
  70. -x**2 + x*y + x + y**2 - y,
  71. x**2 - x*y + x + y**2 + y,
  72. x**2 - x*y + x - y**2 - y,
  73. -x**2 - x*y + x - y**2 + y,
  74. -x**2 - x*y + x + y**2 - y]
  75. assert fwht(ls) == fls
  76. assert ifwht(fls) == ls + [S.Zero]*3
  77. ls = list(range(6))
  78. assert fwht(ls) == [x*8 for x in ifwht(ls)]
  79. def test_mobius_transform():
  80. assert all(tf(ls, subset=subset) == ls
  81. for ls in ([], [Rational(7, 4)]) for subset in (True, False)
  82. for tf in (mobius_transform, inverse_mobius_transform))
  83. w, x, y, z = symbols('w x y z')
  84. assert mobius_transform([x, y]) == [x, x + y]
  85. assert inverse_mobius_transform([x, x + y]) == [x, y]
  86. assert mobius_transform([x, y], subset=False) == [x + y, y]
  87. assert inverse_mobius_transform([x + y, y], subset=False) == [x, y]
  88. assert mobius_transform([w, x, y, z]) == [w, w + x, w + y, w + x + y + z]
  89. assert inverse_mobius_transform([w, w + x, w + y, w + x + y + z]) == \
  90. [w, x, y, z]
  91. assert mobius_transform([w, x, y, z], subset=False) == \
  92. [w + x + y + z, x + z, y + z, z]
  93. assert inverse_mobius_transform([w + x + y + z, x + z, y + z, z], subset=False) == \
  94. [w, x, y, z]
  95. ls = [Rational(2, 3), Rational(6, 7), Rational(5, 8), 9, Rational(5, 3) + 7*I]
  96. mls = [Rational(2, 3), Rational(32, 21), Rational(31, 24), Rational(1873, 168),
  97. Rational(7, 3) + 7*I, Rational(67, 21) + 7*I, Rational(71, 24) + 7*I,
  98. Rational(2153, 168) + 7*I]
  99. assert mobius_transform(ls) == mls
  100. assert inverse_mobius_transform(mls) == ls + [S.Zero]*3
  101. mls = [Rational(2153, 168) + 7*I, Rational(69, 7), Rational(77, 8), 9, Rational(5, 3) + 7*I, 0, 0, 0]
  102. assert mobius_transform(ls, subset=False) == mls
  103. assert inverse_mobius_transform(mls, subset=False) == ls + [S.Zero]*3
  104. ls = ls[:-1]
  105. mls = [Rational(2, 3), Rational(32, 21), Rational(31, 24), Rational(1873, 168)]
  106. assert mobius_transform(ls) == mls
  107. assert inverse_mobius_transform(mls) == ls
  108. mls = [Rational(1873, 168), Rational(69, 7), Rational(77, 8), 9]
  109. assert mobius_transform(ls, subset=False) == mls
  110. assert inverse_mobius_transform(mls, subset=False) == ls
  111. raises(TypeError, lambda: mobius_transform(x, subset=True))
  112. raises(TypeError, lambda: inverse_mobius_transform(y, subset=False))