test_rust.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359
  1. from sympy.core import (S, pi, oo, symbols, Rational, Integer,
  2. GoldenRatio, EulerGamma, Catalan, Lambda, Dummy,
  3. Eq, Ne, Le, Lt, Gt, Ge, Mod)
  4. from sympy.functions import (Piecewise, sin, cos, Abs, exp, ceiling, sqrt,
  5. sign, floor)
  6. from sympy.logic import ITE
  7. from sympy.testing.pytest import raises
  8. from sympy.utilities.lambdify import implemented_function
  9. from sympy.tensor import IndexedBase, Idx
  10. from sympy.matrices import MatrixSymbol, SparseMatrix, Matrix
  11. from sympy.printing.rust import rust_code
  12. x, y, z = symbols('x,y,z')
  13. def test_Integer():
  14. assert rust_code(Integer(42)) == "42"
  15. assert rust_code(Integer(-56)) == "-56"
  16. def test_Relational():
  17. assert rust_code(Eq(x, y)) == "x == y"
  18. assert rust_code(Ne(x, y)) == "x != y"
  19. assert rust_code(Le(x, y)) == "x <= y"
  20. assert rust_code(Lt(x, y)) == "x < y"
  21. assert rust_code(Gt(x, y)) == "x > y"
  22. assert rust_code(Ge(x, y)) == "x >= y"
  23. def test_Rational():
  24. assert rust_code(Rational(3, 7)) == "3_f64/7.0"
  25. assert rust_code(Rational(18, 9)) == "2"
  26. assert rust_code(Rational(3, -7)) == "-3_f64/7.0"
  27. assert rust_code(Rational(-3, -7)) == "3_f64/7.0"
  28. assert rust_code(x + Rational(3, 7)) == "x + 3_f64/7.0"
  29. assert rust_code(Rational(3, 7)*x) == "(3_f64/7.0)*x"
  30. def test_basic_ops():
  31. assert rust_code(x + y) == "x + y"
  32. assert rust_code(x - y) == "x - y"
  33. assert rust_code(x * y) == "x*y"
  34. assert rust_code(x / y) == "x/y"
  35. assert rust_code(-x) == "-x"
  36. def test_printmethod():
  37. class fabs(Abs):
  38. def _rust_code(self, printer):
  39. return "%s.fabs()" % printer._print(self.args[0])
  40. assert rust_code(fabs(x)) == "x.fabs()"
  41. a = MatrixSymbol("a", 1, 3)
  42. assert rust_code(a[0,0]) == 'a[0]'
  43. def test_Functions():
  44. assert rust_code(sin(x) ** cos(x)) == "x.sin().powf(x.cos())"
  45. assert rust_code(abs(x)) == "x.abs()"
  46. assert rust_code(ceiling(x)) == "x.ceil()"
  47. assert rust_code(floor(x)) == "x.floor()"
  48. # Automatic rewrite
  49. assert rust_code(Mod(x, 3)) == 'x - 3*((1_f64/3.0)*x).floor()'
  50. def test_Pow():
  51. assert rust_code(1/x) == "x.recip()"
  52. assert rust_code(x**-1) == rust_code(x**-1.0) == "x.recip()"
  53. assert rust_code(sqrt(x)) == "x.sqrt()"
  54. assert rust_code(x**S.Half) == rust_code(x**0.5) == "x.sqrt()"
  55. assert rust_code(1/sqrt(x)) == "x.sqrt().recip()"
  56. assert rust_code(x**-S.Half) == rust_code(x**-0.5) == "x.sqrt().recip()"
  57. assert rust_code(1/pi) == "PI.recip()"
  58. assert rust_code(pi**-1) == rust_code(pi**-1.0) == "PI.recip()"
  59. assert rust_code(pi**-0.5) == "PI.sqrt().recip()"
  60. assert rust_code(x**Rational(1, 3)) == "x.cbrt()"
  61. assert rust_code(2**x) == "x.exp2()"
  62. assert rust_code(exp(x)) == "x.exp()"
  63. assert rust_code(x**3) == "x.powi(3)"
  64. assert rust_code(x**(y**3)) == "x.powf(y.powi(3))"
  65. assert rust_code(x**Rational(2, 3)) == "x.powf(2_f64/3.0)"
  66. g = implemented_function('g', Lambda(x, 2*x))
  67. assert rust_code(1/(g(x)*3.5)**(x - y**x)/(x**2 + y)) == \
  68. "(3.5*2*x).powf(-x + y.powf(x))/(x.powi(2) + y)"
  69. _cond_cfunc = [(lambda base, exp: exp.is_integer, "dpowi", 1),
  70. (lambda base, exp: not exp.is_integer, "pow", 1)]
  71. assert rust_code(x**3, user_functions={'Pow': _cond_cfunc}) == 'x.dpowi(3)'
  72. assert rust_code(x**3.2, user_functions={'Pow': _cond_cfunc}) == 'x.pow(3.2)'
  73. def test_constants():
  74. assert rust_code(pi) == "PI"
  75. assert rust_code(oo) == "INFINITY"
  76. assert rust_code(S.Infinity) == "INFINITY"
  77. assert rust_code(-oo) == "NEG_INFINITY"
  78. assert rust_code(S.NegativeInfinity) == "NEG_INFINITY"
  79. assert rust_code(S.NaN) == "NAN"
  80. assert rust_code(exp(1)) == "E"
  81. assert rust_code(S.Exp1) == "E"
  82. def test_constants_other():
  83. assert rust_code(2*GoldenRatio) == "const GoldenRatio: f64 = %s;\n2*GoldenRatio" % GoldenRatio.evalf(17)
  84. assert rust_code(
  85. 2*Catalan) == "const Catalan: f64 = %s;\n2*Catalan" % Catalan.evalf(17)
  86. assert rust_code(2*EulerGamma) == "const EulerGamma: f64 = %s;\n2*EulerGamma" % EulerGamma.evalf(17)
  87. def test_boolean():
  88. assert rust_code(True) == "true"
  89. assert rust_code(S.true) == "true"
  90. assert rust_code(False) == "false"
  91. assert rust_code(S.false) == "false"
  92. assert rust_code(x & y) == "x && y"
  93. assert rust_code(x | y) == "x || y"
  94. assert rust_code(~x) == "!x"
  95. assert rust_code(x & y & z) == "x && y && z"
  96. assert rust_code(x | y | z) == "x || y || z"
  97. assert rust_code((x & y) | z) == "z || x && y"
  98. assert rust_code((x | y) & z) == "z && (x || y)"
  99. def test_Piecewise():
  100. expr = Piecewise((x, x < 1), (x + 2, True))
  101. assert rust_code(expr) == (
  102. "if (x < 1) {\n"
  103. " x\n"
  104. "} else {\n"
  105. " x + 2\n"
  106. "}")
  107. assert rust_code(expr, assign_to="r") == (
  108. "r = if (x < 1) {\n"
  109. " x\n"
  110. "} else {\n"
  111. " x + 2\n"
  112. "};")
  113. assert rust_code(expr, assign_to="r", inline=True) == (
  114. "r = if (x < 1) { x } else { x + 2 };")
  115. expr = Piecewise((x, x < 1), (x + 1, x < 5), (x + 2, True))
  116. assert rust_code(expr, inline=True) == (
  117. "if (x < 1) { x } else if (x < 5) { x + 1 } else { x + 2 }")
  118. assert rust_code(expr, assign_to="r", inline=True) == (
  119. "r = if (x < 1) { x } else if (x < 5) { x + 1 } else { x + 2 };")
  120. assert rust_code(expr, assign_to="r") == (
  121. "r = if (x < 1) {\n"
  122. " x\n"
  123. "} else if (x < 5) {\n"
  124. " x + 1\n"
  125. "} else {\n"
  126. " x + 2\n"
  127. "};")
  128. expr = 2*Piecewise((x, x < 1), (x + 1, x < 5), (x + 2, True))
  129. assert rust_code(expr, inline=True) == (
  130. "2*if (x < 1) { x } else if (x < 5) { x + 1 } else { x + 2 }")
  131. expr = 2*Piecewise((x, x < 1), (x + 1, x < 5), (x + 2, True)) - 42
  132. assert rust_code(expr, inline=True) == (
  133. "2*if (x < 1) { x } else if (x < 5) { x + 1 } else { x + 2 } - 42")
  134. # Check that Piecewise without a True (default) condition error
  135. expr = Piecewise((x, x < 1), (x**2, x > 1), (sin(x), x > 0))
  136. raises(ValueError, lambda: rust_code(expr))
  137. def test_dereference_printing():
  138. expr = x + y + sin(z) + z
  139. assert rust_code(expr, dereference=[z]) == "x + y + (*z) + (*z).sin()"
  140. def test_sign():
  141. expr = sign(x) * y
  142. assert rust_code(expr) == "y*x.signum()"
  143. assert rust_code(expr, assign_to='r') == "r = y*x.signum();"
  144. expr = sign(x + y) + 42
  145. assert rust_code(expr) == "(x + y).signum() + 42"
  146. assert rust_code(expr, assign_to='r') == "r = (x + y).signum() + 42;"
  147. expr = sign(cos(x))
  148. assert rust_code(expr) == "x.cos().signum()"
  149. def test_reserved_words():
  150. x, y = symbols("x if")
  151. expr = sin(y)
  152. assert rust_code(expr) == "if_.sin()"
  153. assert rust_code(expr, dereference=[y]) == "(*if_).sin()"
  154. assert rust_code(expr, reserved_word_suffix='_unreserved') == "if_unreserved.sin()"
  155. with raises(ValueError):
  156. rust_code(expr, error_on_reserved=True)
  157. def test_ITE():
  158. expr = ITE(x < 1, y, z)
  159. assert rust_code(expr) == (
  160. "if (x < 1) {\n"
  161. " y\n"
  162. "} else {\n"
  163. " z\n"
  164. "}")
  165. def test_Indexed():
  166. n, m, o = symbols('n m o', integer=True)
  167. i, j, k = Idx('i', n), Idx('j', m), Idx('k', o)
  168. x = IndexedBase('x')[j]
  169. assert rust_code(x) == "x[j]"
  170. A = IndexedBase('A')[i, j]
  171. assert rust_code(A) == "A[m*i + j]"
  172. B = IndexedBase('B')[i, j, k]
  173. assert rust_code(B) == "B[m*o*i + o*j + k]"
  174. def test_dummy_loops():
  175. i, m = symbols('i m', integer=True, cls=Dummy)
  176. x = IndexedBase('x')
  177. y = IndexedBase('y')
  178. i = Idx(i, m)
  179. assert rust_code(x[i], assign_to=y[i]) == (
  180. "for i in 0..m {\n"
  181. " y[i] = x[i];\n"
  182. "}")
  183. def test_loops():
  184. m, n = symbols('m n', integer=True)
  185. A = IndexedBase('A')
  186. x = IndexedBase('x')
  187. y = IndexedBase('y')
  188. z = IndexedBase('z')
  189. i = Idx('i', m)
  190. j = Idx('j', n)
  191. assert rust_code(A[i, j]*x[j], assign_to=y[i]) == (
  192. "for i in 0..m {\n"
  193. " y[i] = 0;\n"
  194. "}\n"
  195. "for i in 0..m {\n"
  196. " for j in 0..n {\n"
  197. " y[i] = A[n*i + j]*x[j] + y[i];\n"
  198. " }\n"
  199. "}")
  200. assert rust_code(A[i, j]*x[j] + x[i] + z[i], assign_to=y[i]) == (
  201. "for i in 0..m {\n"
  202. " y[i] = x[i] + z[i];\n"
  203. "}\n"
  204. "for i in 0..m {\n"
  205. " for j in 0..n {\n"
  206. " y[i] = A[n*i + j]*x[j] + y[i];\n"
  207. " }\n"
  208. "}")
  209. def test_loops_multiple_contractions():
  210. n, m, o, p = symbols('n m o p', integer=True)
  211. a = IndexedBase('a')
  212. b = IndexedBase('b')
  213. y = IndexedBase('y')
  214. i = Idx('i', m)
  215. j = Idx('j', n)
  216. k = Idx('k', o)
  217. l = Idx('l', p)
  218. assert rust_code(b[j, k, l]*a[i, j, k, l], assign_to=y[i]) == (
  219. "for i in 0..m {\n"
  220. " y[i] = 0;\n"
  221. "}\n"
  222. "for i in 0..m {\n"
  223. " for j in 0..n {\n"
  224. " for k in 0..o {\n"
  225. " for l in 0..p {\n"
  226. " y[i] = a[%s]*b[%s] + y[i];\n" % (i*n*o*p + j*o*p + k*p + l, j*o*p + k*p + l) +\
  227. " }\n"
  228. " }\n"
  229. " }\n"
  230. "}")
  231. def test_loops_addfactor():
  232. m, n, o, p = symbols('m n o p', integer=True)
  233. a = IndexedBase('a')
  234. b = IndexedBase('b')
  235. c = IndexedBase('c')
  236. y = IndexedBase('y')
  237. i = Idx('i', m)
  238. j = Idx('j', n)
  239. k = Idx('k', o)
  240. l = Idx('l', p)
  241. code = rust_code((a[i, j, k, l] + b[i, j, k, l])*c[j, k, l], assign_to=y[i])
  242. assert code == (
  243. "for i in 0..m {\n"
  244. " y[i] = 0;\n"
  245. "}\n"
  246. "for i in 0..m {\n"
  247. " for j in 0..n {\n"
  248. " for k in 0..o {\n"
  249. " for l in 0..p {\n"
  250. " y[i] = (a[%s] + b[%s])*c[%s] + y[i];\n" % (i*n*o*p + j*o*p + k*p + l, i*n*o*p + j*o*p + k*p + l, j*o*p + k*p + l) +\
  251. " }\n"
  252. " }\n"
  253. " }\n"
  254. "}")
  255. def test_settings():
  256. raises(TypeError, lambda: rust_code(sin(x), method="garbage"))
  257. def test_inline_function():
  258. x = symbols('x')
  259. g = implemented_function('g', Lambda(x, 2*x))
  260. assert rust_code(g(x)) == "2*x"
  261. g = implemented_function('g', Lambda(x, 2*x/Catalan))
  262. assert rust_code(g(x)) == (
  263. "const Catalan: f64 = %s;\n2*x/Catalan" % Catalan.evalf(17))
  264. A = IndexedBase('A')
  265. i = Idx('i', symbols('n', integer=True))
  266. g = implemented_function('g', Lambda(x, x*(1 + x)*(2 + x)))
  267. assert rust_code(g(A[i]), assign_to=A[i]) == (
  268. "for i in 0..n {\n"
  269. " A[i] = (A[i] + 1)*(A[i] + 2)*A[i];\n"
  270. "}")
  271. def test_user_functions():
  272. x = symbols('x', integer=False)
  273. n = symbols('n', integer=True)
  274. custom_functions = {
  275. "ceiling": "ceil",
  276. "Abs": [(lambda x: not x.is_integer, "fabs", 4), (lambda x: x.is_integer, "abs", 4)],
  277. }
  278. assert rust_code(ceiling(x), user_functions=custom_functions) == "x.ceil()"
  279. assert rust_code(Abs(x), user_functions=custom_functions) == "fabs(x)"
  280. assert rust_code(Abs(n), user_functions=custom_functions) == "abs(n)"
  281. def test_matrix():
  282. assert rust_code(Matrix([1, 2, 3])) == '[1, 2, 3]'
  283. with raises(ValueError):
  284. rust_code(Matrix([[1, 2, 3]]))
  285. def test_sparse_matrix():
  286. # gh-15791
  287. assert 'Not supported in Rust' in rust_code(SparseMatrix([[1, 2, 3]]))