test_rcode.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476
  1. from sympy.core import (S, pi, oo, Symbol, symbols, Rational, Integer,
  2. GoldenRatio, EulerGamma, Catalan, Lambda, Dummy)
  3. from sympy.functions import (Piecewise, sin, cos, Abs, exp, ceiling, sqrt,
  4. gamma, sign, Max, Min, factorial, beta)
  5. from sympy.core.relational import (Eq, Ge, Gt, Le, Lt, Ne)
  6. from sympy.sets import Range
  7. from sympy.logic import ITE
  8. from sympy.codegen import For, aug_assign, Assignment
  9. from sympy.testing.pytest import raises
  10. from sympy.printing.rcode import RCodePrinter
  11. from sympy.utilities.lambdify import implemented_function
  12. from sympy.tensor import IndexedBase, Idx
  13. from sympy.matrices import Matrix, MatrixSymbol
  14. from sympy.printing.rcode import rcode
  15. x, y, z = symbols('x,y,z')
  16. def test_printmethod():
  17. class fabs(Abs):
  18. def _rcode(self, printer):
  19. return "abs(%s)" % printer._print(self.args[0])
  20. assert rcode(fabs(x)) == "abs(x)"
  21. def test_rcode_sqrt():
  22. assert rcode(sqrt(x)) == "sqrt(x)"
  23. assert rcode(x**0.5) == "sqrt(x)"
  24. assert rcode(sqrt(x)) == "sqrt(x)"
  25. def test_rcode_Pow():
  26. assert rcode(x**3) == "x^3"
  27. assert rcode(x**(y**3)) == "x^(y^3)"
  28. g = implemented_function('g', Lambda(x, 2*x))
  29. assert rcode(1/(g(x)*3.5)**(x - y**x)/(x**2 + y)) == \
  30. "(3.5*2*x)^(-x + y^x)/(x^2 + y)"
  31. assert rcode(x**-1.0) == '1.0/x'
  32. assert rcode(x**Rational(2, 3)) == 'x^(2.0/3.0)'
  33. _cond_cfunc = [(lambda base, exp: exp.is_integer, "dpowi"),
  34. (lambda base, exp: not exp.is_integer, "pow")]
  35. assert rcode(x**3, user_functions={'Pow': _cond_cfunc}) == 'dpowi(x, 3)'
  36. assert rcode(x**3.2, user_functions={'Pow': _cond_cfunc}) == 'pow(x, 3.2)'
  37. def test_rcode_Max():
  38. # Test for gh-11926
  39. assert rcode(Max(x,x*x),user_functions={"Max":"my_max", "Pow":"my_pow"}) == 'my_max(x, my_pow(x, 2))'
  40. def test_rcode_constants_mathh():
  41. assert rcode(exp(1)) == "exp(1)"
  42. assert rcode(pi) == "pi"
  43. assert rcode(oo) == "Inf"
  44. assert rcode(-oo) == "-Inf"
  45. def test_rcode_constants_other():
  46. assert rcode(2*GoldenRatio) == "GoldenRatio = 1.61803398874989;\n2*GoldenRatio"
  47. assert rcode(
  48. 2*Catalan) == "Catalan = 0.915965594177219;\n2*Catalan"
  49. assert rcode(2*EulerGamma) == "EulerGamma = 0.577215664901533;\n2*EulerGamma"
  50. def test_rcode_Rational():
  51. assert rcode(Rational(3, 7)) == "3.0/7.0"
  52. assert rcode(Rational(18, 9)) == "2"
  53. assert rcode(Rational(3, -7)) == "-3.0/7.0"
  54. assert rcode(Rational(-3, -7)) == "3.0/7.0"
  55. assert rcode(x + Rational(3, 7)) == "x + 3.0/7.0"
  56. assert rcode(Rational(3, 7)*x) == "(3.0/7.0)*x"
  57. def test_rcode_Integer():
  58. assert rcode(Integer(67)) == "67"
  59. assert rcode(Integer(-1)) == "-1"
  60. def test_rcode_functions():
  61. assert rcode(sin(x) ** cos(x)) == "sin(x)^cos(x)"
  62. assert rcode(factorial(x) + gamma(y)) == "factorial(x) + gamma(y)"
  63. assert rcode(beta(Min(x, y), Max(x, y))) == "beta(min(x, y), max(x, y))"
  64. def test_rcode_inline_function():
  65. x = symbols('x')
  66. g = implemented_function('g', Lambda(x, 2*x))
  67. assert rcode(g(x)) == "2*x"
  68. g = implemented_function('g', Lambda(x, 2*x/Catalan))
  69. assert rcode(
  70. g(x)) == "Catalan = %s;\n2*x/Catalan" % Catalan.n()
  71. A = IndexedBase('A')
  72. i = Idx('i', symbols('n', integer=True))
  73. g = implemented_function('g', Lambda(x, x*(1 + x)*(2 + x)))
  74. res=rcode(g(A[i]), assign_to=A[i])
  75. ref=(
  76. "for (i in 1:n){\n"
  77. " A[i] = (A[i] + 1)*(A[i] + 2)*A[i];\n"
  78. "}"
  79. )
  80. assert res == ref
  81. def test_rcode_exceptions():
  82. assert rcode(ceiling(x)) == "ceiling(x)"
  83. assert rcode(Abs(x)) == "abs(x)"
  84. assert rcode(gamma(x)) == "gamma(x)"
  85. def test_rcode_user_functions():
  86. x = symbols('x', integer=False)
  87. n = symbols('n', integer=True)
  88. custom_functions = {
  89. "ceiling": "myceil",
  90. "Abs": [(lambda x: not x.is_integer, "fabs"), (lambda x: x.is_integer, "abs")],
  91. }
  92. assert rcode(ceiling(x), user_functions=custom_functions) == "myceil(x)"
  93. assert rcode(Abs(x), user_functions=custom_functions) == "fabs(x)"
  94. assert rcode(Abs(n), user_functions=custom_functions) == "abs(n)"
  95. def test_rcode_boolean():
  96. assert rcode(True) == "True"
  97. assert rcode(S.true) == "True"
  98. assert rcode(False) == "False"
  99. assert rcode(S.false) == "False"
  100. assert rcode(x & y) == "x & y"
  101. assert rcode(x | y) == "x | y"
  102. assert rcode(~x) == "!x"
  103. assert rcode(x & y & z) == "x & y & z"
  104. assert rcode(x | y | z) == "x | y | z"
  105. assert rcode((x & y) | z) == "z | x & y"
  106. assert rcode((x | y) & z) == "z & (x | y)"
  107. def test_rcode_Relational():
  108. assert rcode(Eq(x, y)) == "x == y"
  109. assert rcode(Ne(x, y)) == "x != y"
  110. assert rcode(Le(x, y)) == "x <= y"
  111. assert rcode(Lt(x, y)) == "x < y"
  112. assert rcode(Gt(x, y)) == "x > y"
  113. assert rcode(Ge(x, y)) == "x >= y"
  114. def test_rcode_Piecewise():
  115. expr = Piecewise((x, x < 1), (x**2, True))
  116. res=rcode(expr)
  117. ref="ifelse(x < 1,x,x^2)"
  118. assert res == ref
  119. tau=Symbol("tau")
  120. res=rcode(expr,tau)
  121. ref="tau = ifelse(x < 1,x,x^2);"
  122. assert res == ref
  123. expr = 2*Piecewise((x, x < 1), (x**2, x<2), (x**3,True))
  124. assert rcode(expr) == "2*ifelse(x < 1,x,ifelse(x < 2,x^2,x^3))"
  125. res = rcode(expr, assign_to='c')
  126. assert res == "c = 2*ifelse(x < 1,x,ifelse(x < 2,x^2,x^3));"
  127. # Check that Piecewise without a True (default) condition error
  128. #expr = Piecewise((x, x < 1), (x**2, x > 1), (sin(x), x > 0))
  129. #raises(ValueError, lambda: rcode(expr))
  130. expr = 2*Piecewise((x, x < 1), (x**2, x<2))
  131. assert(rcode(expr))== "2*ifelse(x < 1,x,ifelse(x < 2,x^2,NA))"
  132. def test_rcode_sinc():
  133. from sympy.functions.elementary.trigonometric import sinc
  134. expr = sinc(x)
  135. res = rcode(expr)
  136. ref = "ifelse(x != 0,sin(x)/x,1)"
  137. assert res == ref
  138. def test_rcode_Piecewise_deep():
  139. p = rcode(2*Piecewise((x, x < 1), (x + 1, x < 2), (x**2, True)))
  140. assert p == "2*ifelse(x < 1,x,ifelse(x < 2,x + 1,x^2))"
  141. expr = x*y*z + x**2 + y**2 + Piecewise((0, x < 0.5), (1, True)) + cos(z) - 1
  142. p = rcode(expr)
  143. ref="x^2 + x*y*z + y^2 + ifelse(x < 0.5,0,1) + cos(z) - 1"
  144. assert p == ref
  145. ref="c = x^2 + x*y*z + y^2 + ifelse(x < 0.5,0,1) + cos(z) - 1;"
  146. p = rcode(expr, assign_to='c')
  147. assert p == ref
  148. def test_rcode_ITE():
  149. expr = ITE(x < 1, y, z)
  150. p = rcode(expr)
  151. ref="ifelse(x < 1,y,z)"
  152. assert p == ref
  153. def test_rcode_settings():
  154. raises(TypeError, lambda: rcode(sin(x), method="garbage"))
  155. def test_rcode_Indexed():
  156. n, m, o = symbols('n m o', integer=True)
  157. i, j, k = Idx('i', n), Idx('j', m), Idx('k', o)
  158. p = RCodePrinter()
  159. p._not_r = set()
  160. x = IndexedBase('x')[j]
  161. assert p._print_Indexed(x) == 'x[j]'
  162. A = IndexedBase('A')[i, j]
  163. assert p._print_Indexed(A) == 'A[i, j]'
  164. B = IndexedBase('B')[i, j, k]
  165. assert p._print_Indexed(B) == 'B[i, j, k]'
  166. assert p._not_r == set()
  167. def test_rcode_Indexed_without_looking_for_contraction():
  168. len_y = 5
  169. y = IndexedBase('y', shape=(len_y,))
  170. x = IndexedBase('x', shape=(len_y,))
  171. Dy = IndexedBase('Dy', shape=(len_y-1,))
  172. i = Idx('i', len_y-1)
  173. e=Eq(Dy[i], (y[i+1]-y[i])/(x[i+1]-x[i]))
  174. code0 = rcode(e.rhs, assign_to=e.lhs, contract=False)
  175. assert code0 == 'Dy[i] = (y[%s] - y[i])/(x[%s] - x[i]);' % (i + 1, i + 1)
  176. def test_rcode_loops_matrix_vector():
  177. n, m = symbols('n m', integer=True)
  178. A = IndexedBase('A')
  179. x = IndexedBase('x')
  180. y = IndexedBase('y')
  181. i = Idx('i', m)
  182. j = Idx('j', n)
  183. s = (
  184. 'for (i in 1:m){\n'
  185. ' y[i] = 0;\n'
  186. '}\n'
  187. 'for (i in 1:m){\n'
  188. ' for (j in 1:n){\n'
  189. ' y[i] = A[i, j]*x[j] + y[i];\n'
  190. ' }\n'
  191. '}'
  192. )
  193. c = rcode(A[i, j]*x[j], assign_to=y[i])
  194. assert c == s
  195. def test_dummy_loops():
  196. # the following line could also be
  197. # [Dummy(s, integer=True) for s in 'im']
  198. # or [Dummy(integer=True) for s in 'im']
  199. i, m = symbols('i m', integer=True, cls=Dummy)
  200. x = IndexedBase('x')
  201. y = IndexedBase('y')
  202. i = Idx(i, m)
  203. expected = (
  204. 'for (i_%(icount)i in 1:m_%(mcount)i){\n'
  205. ' y[i_%(icount)i] = x[i_%(icount)i];\n'
  206. '}'
  207. ) % {'icount': i.label.dummy_index, 'mcount': m.dummy_index}
  208. code = rcode(x[i], assign_to=y[i])
  209. assert code == expected
  210. def test_rcode_loops_add():
  211. n, m = symbols('n m', integer=True)
  212. A = IndexedBase('A')
  213. x = IndexedBase('x')
  214. y = IndexedBase('y')
  215. z = IndexedBase('z')
  216. i = Idx('i', m)
  217. j = Idx('j', n)
  218. s = (
  219. 'for (i in 1:m){\n'
  220. ' y[i] = x[i] + z[i];\n'
  221. '}\n'
  222. 'for (i in 1:m){\n'
  223. ' for (j in 1:n){\n'
  224. ' y[i] = A[i, j]*x[j] + y[i];\n'
  225. ' }\n'
  226. '}'
  227. )
  228. c = rcode(A[i, j]*x[j] + x[i] + z[i], assign_to=y[i])
  229. assert c == s
  230. def test_rcode_loops_multiple_contractions():
  231. n, m, o, p = symbols('n m o p', integer=True)
  232. a = IndexedBase('a')
  233. b = IndexedBase('b')
  234. y = IndexedBase('y')
  235. i = Idx('i', m)
  236. j = Idx('j', n)
  237. k = Idx('k', o)
  238. l = Idx('l', p)
  239. s = (
  240. 'for (i in 1:m){\n'
  241. ' y[i] = 0;\n'
  242. '}\n'
  243. 'for (i in 1:m){\n'
  244. ' for (j in 1:n){\n'
  245. ' for (k in 1:o){\n'
  246. ' for (l in 1:p){\n'
  247. ' y[i] = a[i, j, k, l]*b[j, k, l] + y[i];\n'
  248. ' }\n'
  249. ' }\n'
  250. ' }\n'
  251. '}'
  252. )
  253. c = rcode(b[j, k, l]*a[i, j, k, l], assign_to=y[i])
  254. assert c == s
  255. def test_rcode_loops_addfactor():
  256. n, m, o, p = symbols('n m o p', integer=True)
  257. a = IndexedBase('a')
  258. b = IndexedBase('b')
  259. c = IndexedBase('c')
  260. y = IndexedBase('y')
  261. i = Idx('i', m)
  262. j = Idx('j', n)
  263. k = Idx('k', o)
  264. l = Idx('l', p)
  265. s = (
  266. 'for (i in 1:m){\n'
  267. ' y[i] = 0;\n'
  268. '}\n'
  269. 'for (i in 1:m){\n'
  270. ' for (j in 1:n){\n'
  271. ' for (k in 1:o){\n'
  272. ' for (l in 1:p){\n'
  273. ' y[i] = (a[i, j, k, l] + b[i, j, k, l])*c[j, k, l] + y[i];\n'
  274. ' }\n'
  275. ' }\n'
  276. ' }\n'
  277. '}'
  278. )
  279. c = rcode((a[i, j, k, l] + b[i, j, k, l])*c[j, k, l], assign_to=y[i])
  280. assert c == s
  281. def test_rcode_loops_multiple_terms():
  282. n, m, o, p = symbols('n m o p', integer=True)
  283. a = IndexedBase('a')
  284. b = IndexedBase('b')
  285. c = IndexedBase('c')
  286. y = IndexedBase('y')
  287. i = Idx('i', m)
  288. j = Idx('j', n)
  289. k = Idx('k', o)
  290. s0 = (
  291. 'for (i in 1:m){\n'
  292. ' y[i] = 0;\n'
  293. '}\n'
  294. )
  295. s1 = (
  296. 'for (i in 1:m){\n'
  297. ' for (j in 1:n){\n'
  298. ' for (k in 1:o){\n'
  299. ' y[i] = b[j]*b[k]*c[i, j, k] + y[i];\n'
  300. ' }\n'
  301. ' }\n'
  302. '}\n'
  303. )
  304. s2 = (
  305. 'for (i in 1:m){\n'
  306. ' for (k in 1:o){\n'
  307. ' y[i] = a[i, k]*b[k] + y[i];\n'
  308. ' }\n'
  309. '}\n'
  310. )
  311. s3 = (
  312. 'for (i in 1:m){\n'
  313. ' for (j in 1:n){\n'
  314. ' y[i] = a[i, j]*b[j] + y[i];\n'
  315. ' }\n'
  316. '}\n'
  317. )
  318. c = rcode(
  319. b[j]*a[i, j] + b[k]*a[i, k] + b[j]*b[k]*c[i, j, k], assign_to=y[i])
  320. ref={}
  321. ref[0] = s0 + s1 + s2 + s3[:-1]
  322. ref[1] = s0 + s1 + s3 + s2[:-1]
  323. ref[2] = s0 + s2 + s1 + s3[:-1]
  324. ref[3] = s0 + s2 + s3 + s1[:-1]
  325. ref[4] = s0 + s3 + s1 + s2[:-1]
  326. ref[5] = s0 + s3 + s2 + s1[:-1]
  327. assert (c == ref[0] or
  328. c == ref[1] or
  329. c == ref[2] or
  330. c == ref[3] or
  331. c == ref[4] or
  332. c == ref[5])
  333. def test_dereference_printing():
  334. expr = x + y + sin(z) + z
  335. assert rcode(expr, dereference=[z]) == "x + y + (*z) + sin((*z))"
  336. def test_Matrix_printing():
  337. # Test returning a Matrix
  338. mat = Matrix([x*y, Piecewise((2 + x, y>0), (y, True)), sin(z)])
  339. A = MatrixSymbol('A', 3, 1)
  340. p = rcode(mat, A)
  341. assert p == (
  342. "A[0] = x*y;\n"
  343. "A[1] = ifelse(y > 0,x + 2,y);\n"
  344. "A[2] = sin(z);")
  345. # Test using MatrixElements in expressions
  346. expr = Piecewise((2*A[2, 0], x > 0), (A[2, 0], True)) + sin(A[1, 0]) + A[0, 0]
  347. p = rcode(expr)
  348. assert p == ("ifelse(x > 0,2*A[2],A[2]) + sin(A[1]) + A[0]")
  349. # Test using MatrixElements in a Matrix
  350. q = MatrixSymbol('q', 5, 1)
  351. M = MatrixSymbol('M', 3, 3)
  352. m = Matrix([[sin(q[1,0]), 0, cos(q[2,0])],
  353. [q[1,0] + q[2,0], q[3, 0], 5],
  354. [2*q[4, 0]/q[1,0], sqrt(q[0,0]) + 4, 0]])
  355. assert rcode(m, M) == (
  356. "M[0] = sin(q[1]);\n"
  357. "M[1] = 0;\n"
  358. "M[2] = cos(q[2]);\n"
  359. "M[3] = q[1] + q[2];\n"
  360. "M[4] = q[3];\n"
  361. "M[5] = 5;\n"
  362. "M[6] = 2*q[4]/q[1];\n"
  363. "M[7] = sqrt(q[0]) + 4;\n"
  364. "M[8] = 0;")
  365. def test_rcode_sgn():
  366. expr = sign(x) * y
  367. assert rcode(expr) == 'y*sign(x)'
  368. p = rcode(expr, 'z')
  369. assert p == 'z = y*sign(x);'
  370. p = rcode(sign(2 * x + x**2) * x + x**2)
  371. assert p == "x^2 + x*sign(x^2 + 2*x)"
  372. expr = sign(cos(x))
  373. p = rcode(expr)
  374. assert p == 'sign(cos(x))'
  375. def test_rcode_Assignment():
  376. assert rcode(Assignment(x, y + z)) == 'x = y + z;'
  377. assert rcode(aug_assign(x, '+', y + z)) == 'x += y + z;'
  378. def test_rcode_For():
  379. f = For(x, Range(0, 10, 2), [aug_assign(y, '*', x)])
  380. sol = rcode(f)
  381. assert sol == ("for(x in seq(from=0, to=9, by=2){\n"
  382. " y *= x;\n"
  383. "}")
  384. def test_MatrixElement_printing():
  385. # test cases for issue #11821
  386. A = MatrixSymbol("A", 1, 3)
  387. B = MatrixSymbol("B", 1, 3)
  388. C = MatrixSymbol("C", 1, 3)
  389. assert(rcode(A[0, 0]) == "A[0]")
  390. assert(rcode(3 * A[0, 0]) == "3*A[0]")
  391. F = C[0, 0].subs(C, A - B)
  392. assert(rcode(F) == "(A - B)[0]")