test_fourier.py 5.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165
  1. from sympy.core.add import Add
  2. from sympy.core.numbers import (Rational, oo, pi)
  3. from sympy.core.singleton import S
  4. from sympy.core.symbol import symbols
  5. from sympy.functions.elementary.exponential import (exp, log)
  6. from sympy.functions.elementary.piecewise import Piecewise
  7. from sympy.functions.elementary.trigonometric import (cos, sin, sinc, tan)
  8. from sympy.series.fourier import fourier_series
  9. from sympy.series.fourier import FourierSeries
  10. from sympy.testing.pytest import raises
  11. from functools import lru_cache
  12. x, y, z = symbols('x y z')
  13. # Don't declare these during import because they are slow
  14. @lru_cache()
  15. def _get_examples():
  16. fo = fourier_series(x, (x, -pi, pi))
  17. fe = fourier_series(x**2, (-pi, pi))
  18. fp = fourier_series(Piecewise((0, x < 0), (pi, True)), (x, -pi, pi))
  19. return fo, fe, fp
  20. def test_FourierSeries():
  21. fo, fe, fp = _get_examples()
  22. assert fourier_series(1, (-pi, pi)) == 1
  23. assert (Piecewise((0, x < 0), (pi, True)).
  24. fourier_series((x, -pi, pi)).truncate()) == fp.truncate()
  25. assert isinstance(fo, FourierSeries)
  26. assert fo.function == x
  27. assert fo.x == x
  28. assert fo.period == (-pi, pi)
  29. assert fo.term(3) == 2*sin(3*x) / 3
  30. assert fe.term(3) == -4*cos(3*x) / 9
  31. assert fp.term(3) == 2*sin(3*x) / 3
  32. assert fo.as_leading_term(x) == 2*sin(x)
  33. assert fe.as_leading_term(x) == pi**2 / 3
  34. assert fp.as_leading_term(x) == pi / 2
  35. assert fo.truncate() == 2*sin(x) - sin(2*x) + (2*sin(3*x) / 3)
  36. assert fe.truncate() == -4*cos(x) + cos(2*x) + pi**2 / 3
  37. assert fp.truncate() == 2*sin(x) + (2*sin(3*x) / 3) + pi / 2
  38. fot = fo.truncate(n=None)
  39. s = [0, 2*sin(x), -sin(2*x)]
  40. for i, t in enumerate(fot):
  41. if i == 3:
  42. break
  43. assert s[i] == t
  44. def _check_iter(f, i):
  45. for ind, t in enumerate(f):
  46. assert t == f[ind]
  47. if ind == i:
  48. break
  49. _check_iter(fo, 3)
  50. _check_iter(fe, 3)
  51. _check_iter(fp, 3)
  52. assert fo.subs(x, x**2) == fo
  53. raises(ValueError, lambda: fourier_series(x, (0, 1, 2)))
  54. raises(ValueError, lambda: fourier_series(x, (x, 0, oo)))
  55. raises(ValueError, lambda: fourier_series(x*y, (0, oo)))
  56. def test_FourierSeries_2():
  57. p = Piecewise((0, x < 0), (x, True))
  58. f = fourier_series(p, (x, -2, 2))
  59. assert f.term(3) == (2*sin(3*pi*x / 2) / (3*pi) -
  60. 4*cos(3*pi*x / 2) / (9*pi**2))
  61. assert f.truncate() == (2*sin(pi*x / 2) / pi - sin(pi*x) / pi -
  62. 4*cos(pi*x / 2) / pi**2 + S.Half)
  63. def test_square_wave():
  64. """Test if fourier_series approximates discontinuous function correctly."""
  65. square_wave = Piecewise((1, x < pi), (-1, True))
  66. s = fourier_series(square_wave, (x, 0, 2*pi))
  67. assert s.truncate(3) == 4 / pi * sin(x) + 4 / (3 * pi) * sin(3 * x) + \
  68. 4 / (5 * pi) * sin(5 * x)
  69. assert s.sigma_approximation(4) == 4 / pi * sin(x) * sinc(pi / 4) + \
  70. 4 / (3 * pi) * sin(3 * x) * sinc(3 * pi / 4)
  71. def test_sawtooth_wave():
  72. s = fourier_series(x, (x, 0, pi))
  73. assert s.truncate(4) == \
  74. pi/2 - sin(2*x) - sin(4*x)/2 - sin(6*x)/3
  75. s = fourier_series(x, (x, 0, 1))
  76. assert s.truncate(4) == \
  77. S.Half - sin(2*pi*x)/pi - sin(4*pi*x)/(2*pi) - sin(6*pi*x)/(3*pi)
  78. def test_FourierSeries__operations():
  79. fo, fe, fp = _get_examples()
  80. fes = fe.scale(-1).shift(pi**2)
  81. assert fes.truncate() == 4*cos(x) - cos(2*x) + 2*pi**2 / 3
  82. assert fp.shift(-pi/2).truncate() == (2*sin(x) + (2*sin(3*x) / 3) +
  83. (2*sin(5*x) / 5))
  84. fos = fo.scale(3)
  85. assert fos.truncate() == 6*sin(x) - 3*sin(2*x) + 2*sin(3*x)
  86. fx = fe.scalex(2).shiftx(1)
  87. assert fx.truncate() == -4*cos(2*x + 2) + cos(4*x + 4) + pi**2 / 3
  88. fl = fe.scalex(3).shift(-pi).scalex(2).shiftx(1).scale(4)
  89. assert fl.truncate() == (-16*cos(6*x + 6) + 4*cos(12*x + 12) -
  90. 4*pi + 4*pi**2 / 3)
  91. raises(ValueError, lambda: fo.shift(x))
  92. raises(ValueError, lambda: fo.shiftx(sin(x)))
  93. raises(ValueError, lambda: fo.scale(x*y))
  94. raises(ValueError, lambda: fo.scalex(x**2))
  95. def test_FourierSeries__neg():
  96. fo, fe, fp = _get_examples()
  97. assert (-fo).truncate() == -2*sin(x) + sin(2*x) - (2*sin(3*x) / 3)
  98. assert (-fe).truncate() == +4*cos(x) - cos(2*x) - pi**2 / 3
  99. def test_FourierSeries__add__sub():
  100. fo, fe, fp = _get_examples()
  101. assert fo + fo == fo.scale(2)
  102. assert fo - fo == 0
  103. assert -fe - fe == fe.scale(-2)
  104. assert (fo + fe).truncate() == 2*sin(x) - sin(2*x) - 4*cos(x) + cos(2*x) \
  105. + pi**2 / 3
  106. assert (fo - fe).truncate() == 2*sin(x) - sin(2*x) + 4*cos(x) - cos(2*x) \
  107. - pi**2 / 3
  108. assert isinstance(fo + 1, Add)
  109. raises(ValueError, lambda: fo + fourier_series(x, (x, 0, 2)))
  110. def test_FourierSeries_finite():
  111. assert fourier_series(sin(x)).truncate(1) == sin(x)
  112. # assert type(fourier_series(sin(x)*log(x))).truncate() == FourierSeries
  113. # assert type(fourier_series(sin(x**2+6))).truncate() == FourierSeries
  114. assert fourier_series(sin(x)*log(y)*exp(z),(x,pi,-pi)).truncate() == sin(x)*log(y)*exp(z)
  115. assert fourier_series(sin(x)**6).truncate(oo) == -15*cos(2*x)/32 + 3*cos(4*x)/16 - cos(6*x)/32 \
  116. + Rational(5, 16)
  117. assert fourier_series(sin(x) ** 6).truncate() == -15 * cos(2 * x) / 32 + 3 * cos(4 * x) / 16 \
  118. + Rational(5, 16)
  119. assert fourier_series(sin(4*x+3) + cos(3*x+4)).truncate(oo) == -sin(4)*sin(3*x) + sin(4*x)*cos(3) \
  120. + cos(4)*cos(3*x) + sin(3)*cos(4*x)
  121. assert fourier_series(sin(x)+cos(x)*tan(x)).truncate(oo) == 2*sin(x)
  122. assert fourier_series(cos(pi*x), (x, -1, 1)).truncate(oo) == cos(pi*x)
  123. assert fourier_series(cos(3*pi*x + 4) - sin(4*pi*x)*log(pi*y), (x, -1, 1)).truncate(oo) == -log(pi*y)*sin(4*pi*x)\
  124. - sin(4)*sin(3*pi*x) + cos(4)*cos(3*pi*x)