test_sequences.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312
  1. from sympy.core.containers import Tuple
  2. from sympy.core.function import Function
  3. from sympy.core.numbers import oo, Rational
  4. from sympy.core.singleton import S
  5. from sympy.core.symbol import symbols, Symbol
  6. from sympy.functions.combinatorial.numbers import tribonacci, fibonacci
  7. from sympy.functions.elementary.exponential import exp
  8. from sympy.functions.elementary.miscellaneous import sqrt
  9. from sympy.functions.elementary.trigonometric import cos, sin
  10. from sympy.series import EmptySequence
  11. from sympy.series.sequences import (SeqMul, SeqAdd, SeqPer, SeqFormula,
  12. sequence)
  13. from sympy.sets.sets import Interval
  14. from sympy.tensor.indexed import Indexed, Idx
  15. from sympy.series.sequences import SeqExpr, SeqExprOp, RecursiveSeq
  16. from sympy.testing.pytest import raises, slow
  17. x, y, z = symbols('x y z')
  18. n, m = symbols('n m')
  19. def test_EmptySequence():
  20. assert S.EmptySequence is EmptySequence
  21. assert S.EmptySequence.interval is S.EmptySet
  22. assert S.EmptySequence.length is S.Zero
  23. assert list(S.EmptySequence) == []
  24. def test_SeqExpr():
  25. #SeqExpr is a baseclass and does not take care of
  26. #ensuring all arguments are Basics hence the use of
  27. #Tuple(...) here.
  28. s = SeqExpr(Tuple(1, n, y), Tuple(x, 0, 10))
  29. assert isinstance(s, SeqExpr)
  30. assert s.gen == (1, n, y)
  31. assert s.interval == Interval(0, 10)
  32. assert s.start == 0
  33. assert s.stop == 10
  34. assert s.length == 11
  35. assert s.variables == (x,)
  36. assert SeqExpr(Tuple(1, 2, 3), Tuple(x, 0, oo)).length is oo
  37. def test_SeqPer():
  38. s = SeqPer((1, n, 3), (x, 0, 5))
  39. assert isinstance(s, SeqPer)
  40. assert s.periodical == Tuple(1, n, 3)
  41. assert s.period == 3
  42. assert s.coeff(3) == 1
  43. assert s.free_symbols == {n}
  44. assert list(s) == [1, n, 3, 1, n, 3]
  45. assert s[:] == [1, n, 3, 1, n, 3]
  46. assert SeqPer((1, n, 3), (x, -oo, 0))[0:6] == [1, n, 3, 1, n, 3]
  47. raises(ValueError, lambda: SeqPer((1, 2, 3), (0, 1, 2)))
  48. raises(ValueError, lambda: SeqPer((1, 2, 3), (x, -oo, oo)))
  49. raises(ValueError, lambda: SeqPer(n**2, (0, oo)))
  50. assert SeqPer((n, n**2, n**3), (m, 0, oo))[:6] == \
  51. [n, n**2, n**3, n, n**2, n**3]
  52. assert SeqPer((n, n**2, n**3), (n, 0, oo))[:6] == [0, 1, 8, 3, 16, 125]
  53. assert SeqPer((n, m), (n, 0, oo))[:6] == [0, m, 2, m, 4, m]
  54. def test_SeqFormula():
  55. s = SeqFormula(n**2, (n, 0, 5))
  56. assert isinstance(s, SeqFormula)
  57. assert s.formula == n**2
  58. assert s.coeff(3) == 9
  59. assert list(s) == [i**2 for i in range(6)]
  60. assert s[:] == [i**2 for i in range(6)]
  61. assert SeqFormula(n**2, (n, -oo, 0))[0:6] == [i**2 for i in range(6)]
  62. assert SeqFormula(n**2, (0, oo)) == SeqFormula(n**2, (n, 0, oo))
  63. assert SeqFormula(n**2, (0, m)).subs(m, x) == SeqFormula(n**2, (0, x))
  64. assert SeqFormula(m*n**2, (n, 0, oo)).subs(m, x) == \
  65. SeqFormula(x*n**2, (n, 0, oo))
  66. raises(ValueError, lambda: SeqFormula(n**2, (0, 1, 2)))
  67. raises(ValueError, lambda: SeqFormula(n**2, (n, -oo, oo)))
  68. raises(ValueError, lambda: SeqFormula(m*n**2, (0, oo)))
  69. seq = SeqFormula(x*(y**2 + z), (z, 1, 100))
  70. assert seq.expand() == SeqFormula(x*y**2 + x*z, (z, 1, 100))
  71. seq = SeqFormula(sin(x*(y**2 + z)),(z, 1, 100))
  72. assert seq.expand(trig=True) == SeqFormula(sin(x*y**2)*cos(x*z) + sin(x*z)*cos(x*y**2), (z, 1, 100))
  73. assert seq.expand() == SeqFormula(sin(x*y**2 + x*z), (z, 1, 100))
  74. assert seq.expand(trig=False) == SeqFormula(sin(x*y**2 + x*z), (z, 1, 100))
  75. seq = SeqFormula(exp(x*(y**2 + z)), (z, 1, 100))
  76. assert seq.expand() == SeqFormula(exp(x*y**2)*exp(x*z), (z, 1, 100))
  77. assert seq.expand(power_exp=False) == SeqFormula(exp(x*y**2 + x*z), (z, 1, 100))
  78. assert seq.expand(mul=False, power_exp=False) == SeqFormula(exp(x*(y**2 + z)), (z, 1, 100))
  79. def test_sequence():
  80. form = SeqFormula(n**2, (n, 0, 5))
  81. per = SeqPer((1, 2, 3), (n, 0, 5))
  82. inter = SeqFormula(n**2)
  83. assert sequence(n**2, (n, 0, 5)) == form
  84. assert sequence((1, 2, 3), (n, 0, 5)) == per
  85. assert sequence(n**2) == inter
  86. def test_SeqExprOp():
  87. form = SeqFormula(n**2, (n, 0, 10))
  88. per = SeqPer((1, 2, 3), (m, 5, 10))
  89. s = SeqExprOp(form, per)
  90. assert s.gen == (n**2, (1, 2, 3))
  91. assert s.interval == Interval(5, 10)
  92. assert s.start == 5
  93. assert s.stop == 10
  94. assert s.length == 6
  95. assert s.variables == (n, m)
  96. def test_SeqAdd():
  97. per = SeqPer((1, 2, 3), (n, 0, oo))
  98. form = SeqFormula(n**2)
  99. per_bou = SeqPer((1, 2), (n, 1, 5))
  100. form_bou = SeqFormula(n**2, (6, 10))
  101. form_bou2 = SeqFormula(n**2, (1, 5))
  102. assert SeqAdd() == S.EmptySequence
  103. assert SeqAdd(S.EmptySequence) == S.EmptySequence
  104. assert SeqAdd(per) == per
  105. assert SeqAdd(per, S.EmptySequence) == per
  106. assert SeqAdd(per_bou, form_bou) == S.EmptySequence
  107. s = SeqAdd(per_bou, form_bou2, evaluate=False)
  108. assert s.args == (form_bou2, per_bou)
  109. assert s[:] == [2, 6, 10, 18, 26]
  110. assert list(s) == [2, 6, 10, 18, 26]
  111. assert isinstance(SeqAdd(per, per_bou, evaluate=False), SeqAdd)
  112. s1 = SeqAdd(per, per_bou)
  113. assert isinstance(s1, SeqPer)
  114. assert s1 == SeqPer((2, 4, 4, 3, 3, 5), (n, 1, 5))
  115. s2 = SeqAdd(form, form_bou)
  116. assert isinstance(s2, SeqFormula)
  117. assert s2 == SeqFormula(2*n**2, (6, 10))
  118. assert SeqAdd(form, form_bou, per) == \
  119. SeqAdd(per, SeqFormula(2*n**2, (6, 10)))
  120. assert SeqAdd(form, SeqAdd(form_bou, per)) == \
  121. SeqAdd(per, SeqFormula(2*n**2, (6, 10)))
  122. assert SeqAdd(per, SeqAdd(form, form_bou), evaluate=False) == \
  123. SeqAdd(per, SeqFormula(2*n**2, (6, 10)))
  124. assert SeqAdd(SeqPer((1, 2), (n, 0, oo)), SeqPer((1, 2), (m, 0, oo))) == \
  125. SeqPer((2, 4), (n, 0, oo))
  126. def test_SeqMul():
  127. per = SeqPer((1, 2, 3), (n, 0, oo))
  128. form = SeqFormula(n**2)
  129. per_bou = SeqPer((1, 2), (n, 1, 5))
  130. form_bou = SeqFormula(n**2, (n, 6, 10))
  131. form_bou2 = SeqFormula(n**2, (1, 5))
  132. assert SeqMul() == S.EmptySequence
  133. assert SeqMul(S.EmptySequence) == S.EmptySequence
  134. assert SeqMul(per) == per
  135. assert SeqMul(per, S.EmptySequence) == S.EmptySequence
  136. assert SeqMul(per_bou, form_bou) == S.EmptySequence
  137. s = SeqMul(per_bou, form_bou2, evaluate=False)
  138. assert s.args == (form_bou2, per_bou)
  139. assert s[:] == [1, 8, 9, 32, 25]
  140. assert list(s) == [1, 8, 9, 32, 25]
  141. assert isinstance(SeqMul(per, per_bou, evaluate=False), SeqMul)
  142. s1 = SeqMul(per, per_bou)
  143. assert isinstance(s1, SeqPer)
  144. assert s1 == SeqPer((1, 4, 3, 2, 2, 6), (n, 1, 5))
  145. s2 = SeqMul(form, form_bou)
  146. assert isinstance(s2, SeqFormula)
  147. assert s2 == SeqFormula(n**4, (6, 10))
  148. assert SeqMul(form, form_bou, per) == \
  149. SeqMul(per, SeqFormula(n**4, (6, 10)))
  150. assert SeqMul(form, SeqMul(form_bou, per)) == \
  151. SeqMul(per, SeqFormula(n**4, (6, 10)))
  152. assert SeqMul(per, SeqMul(form, form_bou2,
  153. evaluate=False), evaluate=False) == \
  154. SeqMul(form, per, form_bou2, evaluate=False)
  155. assert SeqMul(SeqPer((1, 2), (n, 0, oo)), SeqPer((1, 2), (n, 0, oo))) == \
  156. SeqPer((1, 4), (n, 0, oo))
  157. def test_add():
  158. per = SeqPer((1, 2), (n, 0, oo))
  159. form = SeqFormula(n**2)
  160. assert per + (SeqPer((2, 3))) == SeqPer((3, 5), (n, 0, oo))
  161. assert form + SeqFormula(n**3) == SeqFormula(n**2 + n**3)
  162. assert per + form == SeqAdd(per, form)
  163. raises(TypeError, lambda: per + n)
  164. raises(TypeError, lambda: n + per)
  165. def test_sub():
  166. per = SeqPer((1, 2), (n, 0, oo))
  167. form = SeqFormula(n**2)
  168. assert per - (SeqPer((2, 3))) == SeqPer((-1, -1), (n, 0, oo))
  169. assert form - (SeqFormula(n**3)) == SeqFormula(n**2 - n**3)
  170. assert per - form == SeqAdd(per, -form)
  171. raises(TypeError, lambda: per - n)
  172. raises(TypeError, lambda: n - per)
  173. def test_mul__coeff_mul():
  174. assert SeqPer((1, 2), (n, 0, oo)).coeff_mul(2) == SeqPer((2, 4), (n, 0, oo))
  175. assert SeqFormula(n**2).coeff_mul(2) == SeqFormula(2*n**2)
  176. assert S.EmptySequence.coeff_mul(100) == S.EmptySequence
  177. assert SeqPer((1, 2), (n, 0, oo)) * (SeqPer((2, 3))) == \
  178. SeqPer((2, 6), (n, 0, oo))
  179. assert SeqFormula(n**2) * SeqFormula(n**3) == SeqFormula(n**5)
  180. assert S.EmptySequence * SeqFormula(n**2) == S.EmptySequence
  181. assert SeqFormula(n**2) * S.EmptySequence == S.EmptySequence
  182. raises(TypeError, lambda: sequence(n**2) * n)
  183. raises(TypeError, lambda: n * sequence(n**2))
  184. def test_neg():
  185. assert -SeqPer((1, -2), (n, 0, oo)) == SeqPer((-1, 2), (n, 0, oo))
  186. assert -SeqFormula(n**2) == SeqFormula(-n**2)
  187. def test_operations():
  188. per = SeqPer((1, 2), (n, 0, oo))
  189. per2 = SeqPer((2, 4), (n, 0, oo))
  190. form = SeqFormula(n**2)
  191. form2 = SeqFormula(n**3)
  192. assert per + form + form2 == SeqAdd(per, form, form2)
  193. assert per + form - form2 == SeqAdd(per, form, -form2)
  194. assert per + form - S.EmptySequence == SeqAdd(per, form)
  195. assert per + per2 + form == SeqAdd(SeqPer((3, 6), (n, 0, oo)), form)
  196. assert S.EmptySequence - per == -per
  197. assert form + form == SeqFormula(2*n**2)
  198. assert per * form * form2 == SeqMul(per, form, form2)
  199. assert form * form == SeqFormula(n**4)
  200. assert form * -form == SeqFormula(-n**4)
  201. assert form * (per + form2) == SeqMul(form, SeqAdd(per, form2))
  202. assert form * (per + per) == SeqMul(form, per2)
  203. assert form.coeff_mul(m) == SeqFormula(m*n**2, (n, 0, oo))
  204. assert per.coeff_mul(m) == SeqPer((m, 2*m), (n, 0, oo))
  205. def test_Idx_limits():
  206. i = symbols('i', cls=Idx)
  207. r = Indexed('r', i)
  208. assert SeqFormula(r, (i, 0, 5))[:] == [r.subs(i, j) for j in range(6)]
  209. assert SeqPer((1, 2), (i, 0, 5))[:] == [1, 2, 1, 2, 1, 2]
  210. @slow
  211. def test_find_linear_recurrence():
  212. assert sequence((0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55), \
  213. (n, 0, 10)).find_linear_recurrence(11) == [1, 1]
  214. assert sequence((1, 2, 4, 7, 28, 128, 582, 2745, 13021, 61699, 292521, \
  215. 1387138), (n, 0, 11)).find_linear_recurrence(12) == [5, -2, 6, -11]
  216. assert sequence(x*n**3+y*n, (n, 0, oo)).find_linear_recurrence(10) \
  217. == [4, -6, 4, -1]
  218. assert sequence(x**n, (n,0,20)).find_linear_recurrence(21) == [x]
  219. assert sequence((1,2,3)).find_linear_recurrence(10, 5) == [0, 0, 1]
  220. assert sequence(((1 + sqrt(5))/2)**n + \
  221. (-(1 + sqrt(5))/2)**(-n)).find_linear_recurrence(10) == [1, 1]
  222. assert sequence(x*((1 + sqrt(5))/2)**n + y*(-(1 + sqrt(5))/2)**(-n), \
  223. (n,0,oo)).find_linear_recurrence(10) == [1, 1]
  224. assert sequence((1,2,3,4,6),(n, 0, 4)).find_linear_recurrence(5) == []
  225. assert sequence((2,3,4,5,6,79),(n, 0, 5)).find_linear_recurrence(6,gfvar=x) \
  226. == ([], None)
  227. assert sequence((2,3,4,5,8,30),(n, 0, 5)).find_linear_recurrence(6,gfvar=x) \
  228. == ([Rational(19, 2), -20, Rational(27, 2)], (-31*x**2 + 32*x - 4)/(27*x**3 - 40*x**2 + 19*x -2))
  229. assert sequence(fibonacci(n)).find_linear_recurrence(30,gfvar=x) \
  230. == ([1, 1], -x/(x**2 + x - 1))
  231. assert sequence(tribonacci(n)).find_linear_recurrence(30,gfvar=x) \
  232. == ([1, 1, 1], -x/(x**3 + x**2 + x - 1))
  233. def test_RecursiveSeq():
  234. y = Function('y')
  235. n = Symbol('n')
  236. fib = RecursiveSeq(y(n - 1) + y(n - 2), y(n), n, [0, 1])
  237. assert fib.coeff(3) == 2