123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476 |
- from sympy.core import (S, pi, oo, Symbol, symbols, Rational, Integer,
- GoldenRatio, EulerGamma, Catalan, Lambda, Dummy)
- from sympy.functions import (Piecewise, sin, cos, Abs, exp, ceiling, sqrt,
- gamma, sign, Max, Min, factorial, beta)
- from sympy.core.relational import (Eq, Ge, Gt, Le, Lt, Ne)
- from sympy.sets import Range
- from sympy.logic import ITE
- from sympy.codegen import For, aug_assign, Assignment
- from sympy.testing.pytest import raises
- from sympy.printing.rcode import RCodePrinter
- from sympy.utilities.lambdify import implemented_function
- from sympy.tensor import IndexedBase, Idx
- from sympy.matrices import Matrix, MatrixSymbol
- from sympy.printing.rcode import rcode
# Module-level symbols shared by the tests in this file.
x, y, z = symbols('x,y,z')
def test_printmethod():
    """Print an ``Abs`` subclass that defines its own ``_rcode`` method."""

    class fabs(Abs):
        def _rcode(self, printer):
            # Delegate the argument to the printer, then wrap it in abs(...).
            inner = printer._print(self.args[0])
            return "abs(%s)" % inner

    printed = rcode(fabs(x))
    assert printed == "abs(x)"
def test_rcode_sqrt():
    """Check that square roots print as ``sqrt(x)``.

    Both spellings are covered: the ``sqrt`` function and a float exponent
    of ``0.5``.
    """
    assert rcode(sqrt(x)) == "sqrt(x)"
    assert rcode(x**0.5) == "sqrt(x)"
def test_rcode_Pow():
    """Check power printing, with and without a user-supplied ``Pow`` mapping."""
    doubler = implemented_function('g', Lambda(x, 2*x))
    default_cases = [
        (x**3, "x^3"),
        (x**(y**3), "x^(y^3)"),
        (1/(doubler(x)*3.5)**(x - y**x)/(x**2 + y),
         "(3.5*2*x)^(-x + y^x)/(x^2 + y)"),
        (x**-1.0, '1.0/x'),
        (x**Rational(2, 3), 'x^(2.0/3.0)'),
    ]
    for expr, expected in default_cases:
        assert rcode(expr) == expected

    # (predicate, function name) pairs: an integer exponent selects "dpowi",
    # any other exponent selects "pow".
    def exponent_is_integer(base, exp):
        return exp.is_integer

    def exponent_is_not_integer(base, exp):
        return not exp.is_integer

    pow_dispatch = [(exponent_is_integer, "dpowi"),
                    (exponent_is_not_integer, "pow")]
    overrides = {'Pow': pow_dispatch}
    assert rcode(x**3, user_functions=overrides) == 'dpowi(x, 3)'
    assert rcode(x**3.2, user_functions=overrides) == 'pow(x, 3.2)'
def test_rcode_Max():
    """Regression test for gh-11926: user names for ``Max`` and ``Pow`` together."""
    overrides = {"Max": "my_max", "Pow": "my_pow"}
    printed = rcode(Max(x, x*x), user_functions=overrides)
    assert printed == 'my_max(x, my_pow(x, 2))'
def test_rcode_constants_mathh():
    """Check the printed form of ``exp(1)``, ``pi`` and the two infinities."""
    cases = [
        (exp(1), "exp(1)"),
        (pi, "pi"),
        (oo, "Inf"),
        (-oo, "-Inf"),
    ]
    for constant, expected in cases:
        assert rcode(constant) == expected
def test_rcode_constants_other():
    """Named constants are emitted as a definition line before the expression."""
    cases = [
        (2*GoldenRatio, "GoldenRatio = 1.61803398874989;\n2*GoldenRatio"),
        (2*Catalan, "Catalan = 0.915965594177219;\n2*Catalan"),
        (2*EulerGamma, "EulerGamma = 0.577215664901533;\n2*EulerGamma"),
    ]
    for expr, expected in cases:
        assert rcode(expr) == expected
def test_rcode_Rational():
    """Check how rationals print, alone and inside sums and products."""
    cases = [
        (Rational(3, 7), "3.0/7.0"),
        (Rational(18, 9), "2"),
        (Rational(3, -7), "-3.0/7.0"),
        (Rational(-3, -7), "3.0/7.0"),
        (x + Rational(3, 7), "x + 3.0/7.0"),
        (Rational(3, 7)*x, "(3.0/7.0)*x"),
    ]
    for expr, expected in cases:
        assert rcode(expr) == expected
def test_rcode_Integer():
    """Check that positive and negative integers print as plain literals."""
    for value, expected in ((67, "67"), (-1, "-1")):
        assert rcode(Integer(value)) == expected
def test_rcode_functions():
    """Check the printed names of several elementary and special functions."""
    cases = [
        (sin(x) ** cos(x), "sin(x)^cos(x)"),
        (factorial(x) + gamma(y), "factorial(x) + gamma(y)"),
        (beta(Min(x, y), Max(x, y)), "beta(min(x, y), max(x, y))"),
    ]
    for expr, expected in cases:
        assert rcode(expr) == expected
def test_rcode_inline_function():
    """Check that implemented functions are printed as their expanded body."""
    x = symbols('x')

    doubled = implemented_function('g', Lambda(x, 2*x))
    assert rcode(doubled(x)) == "2*x"

    # A body that uses a named constant also emits that constant's definition.
    scaled = implemented_function('g', Lambda(x, 2*x/Catalan))
    printed = rcode(scaled(x))
    assert printed == "Catalan = %s;\n2*x/Catalan" % Catalan.n()

    # Applied to an indexed element and assigned back to it, the output is a
    # loop over the index.
    # NOTE(review): confirm the leading whitespace inside the expected loop
    # body matches the printer's real indentation.
    A = IndexedBase('A')
    i = Idx('i', symbols('n', integer=True))
    cubic = implemented_function('g', Lambda(x, x*(1 + x)*(2 + x)))
    expected = (
        "for (i in 1:n){\n"
        " A[i] = (A[i] + 1)*(A[i] + 2)*A[i];\n"
        "}"
    )
    generated = rcode(cubic(A[i]), assign_to=A[i])
    assert generated == expected
def test_rcode_exceptions():
    """Check the printed names for ``ceiling``, ``Abs`` and ``gamma``."""
    cases = [
        (ceiling(x), "ceiling(x)"),
        (Abs(x), "abs(x)"),
        (gamma(x), "gamma(x)"),
    ]
    for expr, expected in cases:
        assert rcode(expr) == expected
def test_rcode_user_functions():
    """Check ``user_functions`` with a plain name and with (predicate, name) pairs."""
    x = symbols('x', integer=False)
    n = symbols('n', integer=True)

    def arg_is_not_integer(x):
        return not x.is_integer

    def arg_is_integer(x):
        return x.is_integer

    custom_functions = {
        "ceiling": "myceil",
        "Abs": [(arg_is_not_integer, "fabs"), (arg_is_integer, "abs")],
    }
    cases = [
        (ceiling(x), "myceil(x)"),
        (Abs(x), "fabs(x)"),
        (Abs(n), "abs(n)"),
    ]
    for expr, expected in cases:
        assert rcode(expr, user_functions=custom_functions) == expected
def test_rcode_boolean():
    """Check printing of boolean constants and of and/or/not combinations."""
    cases = [
        (True, "True"),
        (S.true, "True"),
        (False, "False"),
        (S.false, "False"),
        (x & y, "x & y"),
        (x | y, "x | y"),
        (~x, "!x"),
        (x & y & z, "x & y & z"),
        (x | y | z, "x | y | z"),
        ((x & y) | z, "z | x & y"),
        ((x | y) & z, "z & (x | y)"),
    ]
    for expr, expected in cases:
        assert rcode(expr) == expected
def test_rcode_Relational():
    """Check the operator printed for each of the six relational classes."""
    cases = [
        (Eq(x, y), "x == y"),
        (Ne(x, y), "x != y"),
        (Le(x, y), "x <= y"),
        (Lt(x, y), "x < y"),
        (Gt(x, y), "x > y"),
        (Ge(x, y), "x >= y"),
    ]
    for relation, expected in cases:
        assert rcode(relation) == expected
def test_rcode_Piecewise():
    """Check that ``Piecewise`` prints as (nested) ``ifelse`` calls."""
    two_branch = Piecewise((x, x < 1), (x**2, True))
    assert rcode(two_branch) == "ifelse(x < 1,x,x^2)"

    # Assigning to a symbol wraps the same expression in an assignment.
    tau = Symbol("tau")
    assert rcode(two_branch, tau) == "tau = ifelse(x < 1,x,x^2);"

    # Three branches nest a second ifelse in the "else" position.
    three_branch = 2*Piecewise((x, x < 1), (x**2, x<2), (x**3,True))
    assert rcode(three_branch) == "2*ifelse(x < 1,x,ifelse(x < 2,x^2,x^3))"
    assigned = rcode(three_branch, assign_to='c')
    assert assigned == "c = 2*ifelse(x < 1,x,ifelse(x < 2,x^2,x^3));"

    # A Piecewise without a True (default) condition gets NA as its final
    # fallback value.
    no_default = 2*Piecewise((x, x < 1), (x**2, x<2))
    assert rcode(no_default) == "2*ifelse(x < 1,x,ifelse(x < 2,x^2,NA))"
def test_rcode_sinc():
    """Check that ``sinc(x)`` prints as an ``ifelse`` guarding the x == 0 case."""
    from sympy.functions.elementary.trigonometric import sinc
    printed = rcode(sinc(x))
    assert printed == "ifelse(x != 0,sin(x)/x,1)"
def test_rcode_Piecewise_deep():
    """Check ``Piecewise`` printing when it sits inside a larger expression."""
    scaled = 2*Piecewise((x, x < 1), (x + 1, x < 2), (x**2, True))
    assert rcode(scaled) == "2*ifelse(x < 1,x,ifelse(x < 2,x + 1,x^2))"

    expr = x*y*z + x**2 + y**2 + Piecewise((0, x < 0.5), (1, True)) + cos(z) - 1
    # Printed once bare and once as an assignment to 'c'.
    bare = rcode(expr)
    assert bare == "x^2 + x*y*z + y^2 + ifelse(x < 0.5,0,1) + cos(z) - 1"
    assigned = rcode(expr, assign_to='c')
    assert assigned == "c = x^2 + x*y*z + y^2 + ifelse(x < 0.5,0,1) + cos(z) - 1;"
def test_rcode_ITE():
    """Check that an if-then-else expression prints as a single ``ifelse``."""
    printed = rcode(ITE(x < 1, y, z))
    assert printed == "ifelse(x < 1,y,z)"
def test_rcode_settings():
    """Check that an unrecognised keyword setting raises ``TypeError``."""

    def call_with_unknown_setting():
        return rcode(sin(x), method="garbage")

    raises(TypeError, call_with_unknown_setting)
def test_rcode_Indexed():
    """Check ``_print_Indexed`` for one, two and three indices."""
    n, m, o = symbols('n m o', integer=True)
    i, j, k = Idx('i', n), Idx('j', m), Idx('k', o)

    printer = RCodePrinter()
    printer._not_r = set()

    cases = [
        (IndexedBase('x')[j], 'x[j]'),
        (IndexedBase('A')[i, j], 'A[i, j]'),
        (IndexedBase('B')[i, j, k], 'B[i, j, k]'),
    ]
    for indexed, expected in cases:
        assert printer._print_Indexed(indexed) == expected

    # The set assigned above must still be empty after printing.
    assert printer._not_r == set()
def test_rcode_Indexed_without_looking_for_contraction():
    """With ``contract=False`` an indexed assignment prints as one statement."""
    len_y = 5
    x = IndexedBase('x', shape=(len_y,))
    y = IndexedBase('y', shape=(len_y,))
    Dy = IndexedBase('Dy', shape=(len_y-1,))
    i = Idx('i', len_y-1)

    # Finite-difference quotient between neighbouring elements.
    equation = Eq(Dy[i], (y[i+1]-y[i])/(x[i+1]-x[i]))
    generated = rcode(equation.rhs, assign_to=equation.lhs, contract=False)

    shifted = i + 1
    expected = 'Dy[i] = (y[%s] - y[i])/(x[%s] - x[i]);' % (shifted, shifted)
    assert generated == expected
def test_rcode_loops_matrix_vector():
    """A matrix-vector product prints as a zeroing loop plus a nested sum loop."""
    n, m = symbols('n m', integer=True)
    i = Idx('i', m)
    j = Idx('j', n)
    A = IndexedBase('A')
    x = IndexedBase('x')
    y = IndexedBase('y')

    # NOTE(review): confirm the leading whitespace inside these expected
    # lines matches the printer's real indentation.
    expected = (
        'for (i in 1:m){\n'
        ' y[i] = 0;\n'
        '}\n'
        'for (i in 1:m){\n'
        ' for (j in 1:n){\n'
        ' y[i] = A[i, j]*x[j] + y[i];\n'
        ' }\n'
        '}'
    )
    generated = rcode(A[i, j]*x[j], assign_to=y[i])
    assert generated == expected
def test_dummy_loops():
    """Loops over ``Dummy`` indices print names suffixed with the dummy index."""
    # Equivalent ways to build the two dummies would be
    #   [Dummy(s, integer=True) for s in 'im']
    # or
    #   [Dummy(integer=True) for s in 'im']
    i_label, m = symbols('i m', integer=True, cls=Dummy)
    i = Idx(i_label, m)
    x = IndexedBase('x')
    y = IndexedBase('y')

    # NOTE(review): confirm the leading whitespace inside the expected loop
    # body matches the printer's real indentation.
    template = (
        'for (i_%(icount)i in 1:m_%(mcount)i){\n'
        ' y[i_%(icount)i] = x[i_%(icount)i];\n'
        '}'
    )
    suffixes = {'icount': i.label.dummy_index, 'mcount': m.dummy_index}
    expected = template % suffixes

    generated = rcode(x[i], assign_to=y[i])
    assert generated == expected
def test_rcode_loops_add():
    """A sum of plain terms and a contraction prints as two separate loops."""
    n, m = symbols('n m', integer=True)
    i = Idx('i', m)
    j = Idx('j', n)
    A = IndexedBase('A')
    x = IndexedBase('x')
    y = IndexedBase('y')
    z = IndexedBase('z')

    # NOTE(review): confirm the leading whitespace inside these expected
    # lines matches the printer's real indentation.
    expected = (
        'for (i in 1:m){\n'
        ' y[i] = x[i] + z[i];\n'
        '}\n'
        'for (i in 1:m){\n'
        ' for (j in 1:n){\n'
        ' y[i] = A[i, j]*x[j] + y[i];\n'
        ' }\n'
        '}'
    )
    generated = rcode(A[i, j]*x[j] + x[i] + z[i], assign_to=y[i])
    assert generated == expected
def test_rcode_loops_multiple_contractions():
    """A contraction over three indices prints as a four-deep loop nest."""
    n, m, o, p = symbols('n m o p', integer=True)
    i = Idx('i', m)
    j = Idx('j', n)
    k = Idx('k', o)
    l = Idx('l', p)
    a = IndexedBase('a')
    b = IndexedBase('b')
    y = IndexedBase('y')

    # NOTE(review): confirm the leading whitespace inside these expected
    # lines matches the printer's real indentation.
    expected = (
        'for (i in 1:m){\n'
        ' y[i] = 0;\n'
        '}\n'
        'for (i in 1:m){\n'
        ' for (j in 1:n){\n'
        ' for (k in 1:o){\n'
        ' for (l in 1:p){\n'
        ' y[i] = a[i, j, k, l]*b[j, k, l] + y[i];\n'
        ' }\n'
        ' }\n'
        ' }\n'
        '}'
    )
    generated = rcode(b[j, k, l]*a[i, j, k, l], assign_to=y[i])
    assert generated == expected
def test_rcode_loops_addfactor():
    """A contracted product with a sum factor keeps the sum in parentheses."""
    n, m, o, p = symbols('n m o p', integer=True)
    i = Idx('i', m)
    j = Idx('j', n)
    k = Idx('k', o)
    l = Idx('l', p)
    a = IndexedBase('a')
    b = IndexedBase('b')
    c = IndexedBase('c')
    y = IndexedBase('y')

    # NOTE(review): confirm the leading whitespace inside these expected
    # lines matches the printer's real indentation.
    expected = (
        'for (i in 1:m){\n'
        ' y[i] = 0;\n'
        '}\n'
        'for (i in 1:m){\n'
        ' for (j in 1:n){\n'
        ' for (k in 1:o){\n'
        ' for (l in 1:p){\n'
        ' y[i] = (a[i, j, k, l] + b[i, j, k, l])*c[j, k, l] + y[i];\n'
        ' }\n'
        ' }\n'
        ' }\n'
        '}'
    )
    product = (a[i, j, k, l] + b[i, j, k, l])*c[j, k, l]
    generated = rcode(product, assign_to=y[i])
    assert generated == expected
def test_rcode_loops_multiple_terms():
    """Three contracted terms print as three loop nests, in any order."""
    n, m, o, p = symbols('n m o p', integer=True)
    i = Idx('i', m)
    j = Idx('j', n)
    k = Idx('k', o)
    a = IndexedBase('a')
    b = IndexedBase('b')
    c = IndexedBase('c')
    y = IndexedBase('y')

    # NOTE(review): confirm the leading whitespace inside these expected
    # lines matches the printer's real indentation.
    # Zero-initialisation of the output; always comes first.
    s0 = (
        'for (i in 1:m){\n'
        ' y[i] = 0;\n'
        '}\n'
    )
    # Term contracted over both j and k.
    s1 = (
        'for (i in 1:m){\n'
        ' for (j in 1:n){\n'
        ' for (k in 1:o){\n'
        ' y[i] = b[j]*b[k]*c[i, j, k] + y[i];\n'
        ' }\n'
        ' }\n'
        '}\n'
    )
    # Term contracted over k only.
    s2 = (
        'for (i in 1:m){\n'
        ' for (k in 1:o){\n'
        ' y[i] = a[i, k]*b[k] + y[i];\n'
        ' }\n'
        '}\n'
    )
    # Term contracted over j only.
    s3 = (
        'for (i in 1:m){\n'
        ' for (j in 1:n){\n'
        ' y[i] = a[i, j]*b[j] + y[i];\n'
        ' }\n'
        '}\n'
    )

    generated = rcode(
        b[j]*a[i, j] + b[k]*a[i, k] + b[j]*b[k]*c[i, j, k], assign_to=y[i])

    # Any ordering of the three term loops is accepted; the final block is
    # compared without its trailing newline.
    orderings = (
        (s1, s2, s3),
        (s1, s3, s2),
        (s2, s1, s3),
        (s2, s3, s1),
        (s3, s1, s2),
        (s3, s2, s1),
    )
    accepted = [s0 + first + second + last[:-1]
                for first, second, last in orderings]
    assert generated in accepted
def test_dereference_printing():
    """Symbols listed in ``dereference`` print as ``(*name)``."""
    total = x + y + sin(z) + z
    printed = rcode(total, dereference=[z])
    assert printed == "x + y + (*z) + sin((*z))"
def test_Matrix_printing():
    """Check matrix assignment and printing of matrix elements."""
    # Assigning a column Matrix to a MatrixSymbol gives one statement per entry.
    A = MatrixSymbol('A', 3, 1)
    column = Matrix([x*y, Piecewise((2 + x, y>0), (y, True)), sin(z)])
    generated = rcode(column, A)
    assert generated == (
        "A[0] = x*y;\n"
        "A[1] = ifelse(y > 0,x + 2,y);\n"
        "A[2] = sin(z);")

    # Matrix elements used inside an ordinary expression.
    mixed = Piecewise((2*A[2, 0], x > 0), (A[2, 0], True)) + sin(A[1, 0]) + A[0, 0]
    generated = rcode(mixed)
    assert generated == ("ifelse(x > 0,2*A[2],A[2]) + sin(A[1]) + A[0]")

    # Matrix elements used as entries of a 3x3 Matrix; the expected output
    # uses a single flat index for each entry.
    q = MatrixSymbol('q', 5, 1)
    M = MatrixSymbol('M', 3, 3)
    entries = Matrix([[sin(q[1,0]), 0, cos(q[2,0])],
                      [q[1,0] + q[2,0], q[3, 0], 5],
                      [2*q[4, 0]/q[1,0], sqrt(q[0,0]) + 4, 0]])
    expected = (
        "M[0] = sin(q[1]);\n"
        "M[1] = 0;\n"
        "M[2] = cos(q[2]);\n"
        "M[3] = q[1] + q[2];\n"
        "M[4] = q[3];\n"
        "M[5] = 5;\n"
        "M[6] = 2*q[4]/q[1];\n"
        "M[7] = sqrt(q[0]) + 4;\n"
        "M[8] = 0;")
    assert rcode(entries, M) == expected
def test_rcode_sgn():
    """Check that ``sign`` prints as ``sign(...)`` in several contexts."""
    product = sign(x) * y
    assert rcode(product) == 'y*sign(x)'
    assert rcode(product, 'z') == 'z = y*sign(x);'

    polynomial = sign(2 * x + x**2) * x + x**2
    assert rcode(polynomial) == "x^2 + x*sign(x^2 + 2*x)"

    nested = sign(cos(x))
    assert rcode(nested) == 'sign(cos(x))'
def test_rcode_Assignment():
    """Check printing of a plain assignment and an augmented assignment."""
    plain = Assignment(x, y + z)
    assert rcode(plain) == 'x = y + z;'

    augmented = aug_assign(x, '+', y + z)
    assert rcode(augmented) == 'x += y + z;'
def test_rcode_For():
    """Check that a ``For`` over a ``Range`` prints as a loop over ``seq``."""
    body = [aug_assign(y, '*', x)]
    loop = For(x, Range(0, 10, 2), body)
    generated = rcode(loop)
    # NOTE(review): the expected header has no ')' closing 'for(' -- confirm
    # this is the printer's intended output. Also confirm the leading
    # whitespace of the loop body line.
    expected = ("for(x in seq(from=0, to=9, by=2){\n"
                " y *= x;\n"
                "}")
    assert generated == expected
def test_MatrixElement_printing():
    """Regression tests for issue #11821: printing of matrix elements."""
    A = MatrixSymbol("A", 1, 3)
    B = MatrixSymbol("B", 1, 3)
    C = MatrixSymbol("C", 1, 3)

    first = A[0, 0]
    assert rcode(first) == "A[0]"
    assert rcode(3 * first) == "3*A[0]"

    # Element of a matrix expression, obtained by substituting C -> A - B.
    difference_element = C[0, 0].subs(C, A - B)
    assert rcode(difference_element) == "(A - B)[0]"
|