test_fourier.py 1.6 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344
  1. from sympy.assumptions.ask import (Q, ask)
  2. from sympy.core.numbers import (I, Rational)
  3. from sympy.core.singleton import S
  4. from sympy.functions.elementary.complexes import Abs
  5. from sympy.functions.elementary.exponential import exp
  6. from sympy.functions.elementary.miscellaneous import sqrt
  7. from sympy.simplify.simplify import simplify
  8. from sympy.core.symbol import symbols
  9. from sympy.matrices.expressions.fourier import DFT, IDFT
  10. from sympy.matrices import det, Matrix, Identity
  11. from sympy.testing.pytest import raises
  12. def test_dft_creation():
  13. assert DFT(2)
  14. assert DFT(0)
  15. raises(ValueError, lambda: DFT(-1))
  16. raises(ValueError, lambda: DFT(2.0))
  17. raises(ValueError, lambda: DFT(2 + 1j))
  18. n = symbols('n')
  19. assert DFT(n)
  20. n = symbols('n', integer=False)
  21. raises(ValueError, lambda: DFT(n))
  22. n = symbols('n', negative=True)
  23. raises(ValueError, lambda: DFT(n))
  24. def test_dft():
  25. n, i, j = symbols('n i j')
  26. assert DFT(4).shape == (4, 4)
  27. assert ask(Q.unitary(DFT(4)))
  28. assert Abs(simplify(det(Matrix(DFT(4))))) == 1
  29. assert DFT(n)*IDFT(n) == Identity(n)
  30. assert DFT(n)[i, j] == exp(-2*S.Pi*I/n)**(i*j) / sqrt(n)
  31. def test_dft2():
  32. assert DFT(1).as_explicit() == Matrix([[1]])
  33. assert DFT(2).as_explicit() == 1/sqrt(2)*Matrix([[1,1],[1,-1]])
  34. assert DFT(4).as_explicit() == Matrix([[S.Half, S.Half, S.Half, S.Half],
  35. [S.Half, -I/2, Rational(-1,2), I/2],
  36. [S.Half, Rational(-1,2), S.Half, Rational(-1,2)],
  37. [S.Half, I/2, Rational(-1,2), -I/2]])