test_codegen_rust.py 12 KB


  1. from io import StringIO
  2. from sympy.core import S, symbols, pi, Catalan, EulerGamma, Function
  3. from sympy.core.relational import Equality
  4. from sympy.functions.elementary.piecewise import Piecewise
  5. from sympy.utilities.codegen import RustCodeGen, codegen, make_routine
  6. from sympy.testing.pytest import XFAIL
  7. import sympy
# Module-level symbols shared by the tests below as routine arguments.
x, y, z = symbols('x,y,z')
  9. def test_empty_rust_code():
  10. code_gen = RustCodeGen()
  11. output = StringIO()
  12. code_gen.dump_rs([], output, "file", header=False, empty=False)
  13. source = output.getvalue()
  14. assert source == ""
  15. def test_simple_rust_code():
  16. name_expr = ("test", (x + y)*z)
  17. result, = codegen(name_expr, "Rust", header=False, empty=False)
  18. assert result[0] == "test.rs"
  19. source = result[1]
  20. expected = (
  21. "fn test(x: f64, y: f64, z: f64) -> f64 {\n"
  22. " let out1 = z*(x + y);\n"
  23. " out1\n"
  24. "}\n"
  25. )
  26. assert source == expected
  27. def test_simple_code_with_header():
  28. name_expr = ("test", (x + y)*z)
  29. result, = codegen(name_expr, "Rust", header=True, empty=False)
  30. assert result[0] == "test.rs"
  31. source = result[1]
  32. version_str = "Code generated with SymPy %s" % sympy.__version__
  33. version_line = version_str.center(76).rstrip()
  34. expected = (
  35. "/*\n"
  36. " *%(version_line)s\n"
  37. " *\n"
  38. " * See http://www.sympy.org/ for more information.\n"
  39. " *\n"
  40. " * This file is part of 'project'\n"
  41. " */\n"
  42. "fn test(x: f64, y: f64, z: f64) -> f64 {\n"
  43. " let out1 = z*(x + y);\n"
  44. " out1\n"
  45. "}\n"
  46. ) % {'version_line': version_line}
  47. assert source == expected
  48. def test_simple_code_nameout():
  49. expr = Equality(z, (x + y))
  50. name_expr = ("test", expr)
  51. result, = codegen(name_expr, "Rust", header=False, empty=False)
  52. source = result[1]
  53. expected = (
  54. "fn test(x: f64, y: f64) -> f64 {\n"
  55. " let z = x + y;\n"
  56. " z\n"
  57. "}\n"
  58. )
  59. assert source == expected
  60. def test_numbersymbol():
  61. name_expr = ("test", pi**Catalan)
  62. result, = codegen(name_expr, "Rust", header=False, empty=False)
  63. source = result[1]
  64. expected = (
  65. "fn test() -> f64 {\n"
  66. " const Catalan: f64 = %s;\n"
  67. " let out1 = PI.powf(Catalan);\n"
  68. " out1\n"
  69. "}\n"
  70. ) % Catalan.evalf(17)
  71. assert source == expected
  72. @XFAIL
  73. def test_numbersymbol_inline():
  74. # FIXME: how to pass inline to the RustCodePrinter?
  75. name_expr = ("test", [pi**Catalan, EulerGamma])
  76. result, = codegen(name_expr, "Rust", header=False,
  77. empty=False, inline=True)
  78. source = result[1]
  79. expected = (
  80. "fn test() -> (f64, f64) {\n"
  81. " const Catalan: f64 = %s;\n"
  82. " const EulerGamma: f64 = %s;\n"
  83. " let out1 = PI.powf(Catalan);\n"
  84. " let out2 = EulerGamma);\n"
  85. " (out1, out2)\n"
  86. "}\n"
  87. ) % (Catalan.evalf(17), EulerGamma.evalf(17))
  88. assert source == expected
  89. def test_argument_order():
  90. expr = x + y
  91. routine = make_routine("test", expr, argument_sequence=[z, x, y], language="rust")
  92. code_gen = RustCodeGen()
  93. output = StringIO()
  94. code_gen.dump_rs([routine], output, "test", header=False, empty=False)
  95. source = output.getvalue()
  96. expected = (
  97. "fn test(z: f64, x: f64, y: f64) -> f64 {\n"
  98. " let out1 = x + y;\n"
  99. " out1\n"
  100. "}\n"
  101. )
  102. assert source == expected
  103. def test_multiple_results_rust():
  104. # Here the output order is the input order
  105. expr1 = (x + y)*z
  106. expr2 = (x - y)*z
  107. name_expr = ("test", [expr1, expr2])
  108. result, = codegen(name_expr, "Rust", header=False, empty=False)
  109. source = result[1]
  110. expected = (
  111. "fn test(x: f64, y: f64, z: f64) -> (f64, f64) {\n"
  112. " let out1 = z*(x + y);\n"
  113. " let out2 = z*(x - y);\n"
  114. " (out1, out2)\n"
  115. "}\n"
  116. )
  117. assert source == expected
  118. def test_results_named_unordered():
  119. # Here output order is based on name_expr
  120. A, B, C = symbols('A,B,C')
  121. expr1 = Equality(C, (x + y)*z)
  122. expr2 = Equality(A, (x - y)*z)
  123. expr3 = Equality(B, 2*x)
  124. name_expr = ("test", [expr1, expr2, expr3])
  125. result, = codegen(name_expr, "Rust", header=False, empty=False)
  126. source = result[1]
  127. expected = (
  128. "fn test(x: f64, y: f64, z: f64) -> (f64, f64, f64) {\n"
  129. " let C = z*(x + y);\n"
  130. " let A = z*(x - y);\n"
  131. " let B = 2*x;\n"
  132. " (C, A, B)\n"
  133. "}\n"
  134. )
  135. assert source == expected
  136. def test_results_named_ordered():
  137. A, B, C = symbols('A,B,C')
  138. expr1 = Equality(C, (x + y)*z)
  139. expr2 = Equality(A, (x - y)*z)
  140. expr3 = Equality(B, 2*x)
  141. name_expr = ("test", [expr1, expr2, expr3])
  142. result = codegen(name_expr, "Rust", header=False, empty=False,
  143. argument_sequence=(x, z, y))
  144. assert result[0][0] == "test.rs"
  145. source = result[0][1]
  146. expected = (
  147. "fn test(x: f64, z: f64, y: f64) -> (f64, f64, f64) {\n"
  148. " let C = z*(x + y);\n"
  149. " let A = z*(x - y);\n"
  150. " let B = 2*x;\n"
  151. " (C, A, B)\n"
  152. "}\n"
  153. )
  154. assert source == expected
  155. def test_complicated_rs_codegen():
  156. from sympy.functions.elementary.trigonometric import (cos, sin, tan)
  157. name_expr = ("testlong",
  158. [ ((sin(x) + cos(y) + tan(z))**3).expand(),
  159. cos(cos(cos(cos(cos(cos(cos(cos(x + y + z))))))))
  160. ])
  161. result = codegen(name_expr, "Rust", header=False, empty=False)
  162. assert result[0][0] == "testlong.rs"
  163. source = result[0][1]
  164. expected = (
  165. "fn testlong(x: f64, y: f64, z: f64) -> (f64, f64) {\n"
  166. " let out1 = x.sin().powi(3) + 3*x.sin().powi(2)*y.cos()"
  167. " + 3*x.sin().powi(2)*z.tan() + 3*x.sin()*y.cos().powi(2)"
  168. " + 6*x.sin()*y.cos()*z.tan() + 3*x.sin()*z.tan().powi(2)"
  169. " + y.cos().powi(3) + 3*y.cos().powi(2)*z.tan()"
  170. " + 3*y.cos()*z.tan().powi(2) + z.tan().powi(3);\n"
  171. " let out2 = (x + y + z).cos().cos().cos().cos()"
  172. ".cos().cos().cos().cos();\n"
  173. " (out1, out2)\n"
  174. "}\n"
  175. )
  176. assert source == expected
  177. def test_output_arg_mixed_unordered():
  178. # named outputs are alphabetical, unnamed output appear in the given order
  179. from sympy.functions.elementary.trigonometric import (cos, sin)
  180. a = symbols("a")
  181. name_expr = ("foo", [cos(2*x), Equality(y, sin(x)), cos(x), Equality(a, sin(2*x))])
  182. result, = codegen(name_expr, "Rust", header=False, empty=False)
  183. assert result[0] == "foo.rs"
  184. source = result[1];
  185. expected = (
  186. "fn foo(x: f64) -> (f64, f64, f64, f64) {\n"
  187. " let out1 = (2*x).cos();\n"
  188. " let y = x.sin();\n"
  189. " let out3 = x.cos();\n"
  190. " let a = (2*x).sin();\n"
  191. " (out1, y, out3, a)\n"
  192. "}\n"
  193. )
  194. assert source == expected
  195. def test_piecewise_():
  196. pw = Piecewise((0, x < -1), (x**2, x <= 1), (-x+2, x > 1), (1, True), evaluate=False)
  197. name_expr = ("pwtest", pw)
  198. result, = codegen(name_expr, "Rust", header=False, empty=False)
  199. source = result[1]
  200. expected = (
  201. "fn pwtest(x: f64) -> f64 {\n"
  202. " let out1 = if (x < -1) {\n"
  203. " 0\n"
  204. " } else if (x <= 1) {\n"
  205. " x.powi(2)\n"
  206. " } else if (x > 1) {\n"
  207. " 2 - x\n"
  208. " } else {\n"
  209. " 1\n"
  210. " };\n"
  211. " out1\n"
  212. "}\n"
  213. )
  214. assert source == expected
  215. @XFAIL
  216. def test_piecewise_inline():
  217. # FIXME: how to pass inline to the RustCodePrinter?
  218. pw = Piecewise((0, x < -1), (x**2, x <= 1), (-x+2, x > 1), (1, True))
  219. name_expr = ("pwtest", pw)
  220. result, = codegen(name_expr, "Rust", header=False, empty=False,
  221. inline=True)
  222. source = result[1]
  223. expected = (
  224. "fn pwtest(x: f64) -> f64 {\n"
  225. " let out1 = if (x < -1) { 0 } else if (x <= 1) { x.powi(2) }"
  226. " else if (x > 1) { -x + 2 } else { 1 };\n"
  227. " out1\n"
  228. "}\n"
  229. )
  230. assert source == expected
  231. def test_multifcns_per_file():
  232. name_expr = [ ("foo", [2*x, 3*y]), ("bar", [y**2, 4*y]) ]
  233. result = codegen(name_expr, "Rust", header=False, empty=False)
  234. assert result[0][0] == "foo.rs"
  235. source = result[0][1];
  236. expected = (
  237. "fn foo(x: f64, y: f64) -> (f64, f64) {\n"
  238. " let out1 = 2*x;\n"
  239. " let out2 = 3*y;\n"
  240. " (out1, out2)\n"
  241. "}\n"
  242. "fn bar(y: f64) -> (f64, f64) {\n"
  243. " let out1 = y.powi(2);\n"
  244. " let out2 = 4*y;\n"
  245. " (out1, out2)\n"
  246. "}\n"
  247. )
  248. assert source == expected
  249. def test_multifcns_per_file_w_header():
  250. name_expr = [ ("foo", [2*x, 3*y]), ("bar", [y**2, 4*y]) ]
  251. result = codegen(name_expr, "Rust", header=True, empty=False)
  252. assert result[0][0] == "foo.rs"
  253. source = result[0][1];
  254. version_str = "Code generated with SymPy %s" % sympy.__version__
  255. version_line = version_str.center(76).rstrip()
  256. expected = (
  257. "/*\n"
  258. " *%(version_line)s\n"
  259. " *\n"
  260. " * See http://www.sympy.org/ for more information.\n"
  261. " *\n"
  262. " * This file is part of 'project'\n"
  263. " */\n"
  264. "fn foo(x: f64, y: f64) -> (f64, f64) {\n"
  265. " let out1 = 2*x;\n"
  266. " let out2 = 3*y;\n"
  267. " (out1, out2)\n"
  268. "}\n"
  269. "fn bar(y: f64) -> (f64, f64) {\n"
  270. " let out1 = y.powi(2);\n"
  271. " let out2 = 4*y;\n"
  272. " (out1, out2)\n"
  273. "}\n"
  274. ) % {'version_line': version_line}
  275. assert source == expected
  276. def test_filename_match_prefix():
  277. name_expr = [ ("foo", [2*x, 3*y]), ("bar", [y**2, 4*y]) ]
  278. result, = codegen(name_expr, "Rust", prefix="baz", header=False,
  279. empty=False)
  280. assert result[0] == "baz.rs"
  281. def test_InOutArgument():
  282. expr = Equality(x, x**2)
  283. name_expr = ("mysqr", expr)
  284. result, = codegen(name_expr, "Rust", header=False, empty=False)
  285. source = result[1]
  286. expected = (
  287. "fn mysqr(x: f64) -> f64 {\n"
  288. " let x = x.powi(2);\n"
  289. " x\n"
  290. "}\n"
  291. )
  292. assert source == expected
  293. def test_InOutArgument_order():
  294. # can specify the order as (x, y)
  295. expr = Equality(x, x**2 + y)
  296. name_expr = ("test", expr)
  297. result, = codegen(name_expr, "Rust", header=False,
  298. empty=False, argument_sequence=(x,y))
  299. source = result[1]
  300. expected = (
  301. "fn test(x: f64, y: f64) -> f64 {\n"
  302. " let x = x.powi(2) + y;\n"
  303. " x\n"
  304. "}\n"
  305. )
  306. assert source == expected
  307. # make sure it gives (x, y) not (y, x)
  308. expr = Equality(x, x**2 + y)
  309. name_expr = ("test", expr)
  310. result, = codegen(name_expr, "Rust", header=False, empty=False)
  311. source = result[1]
  312. expected = (
  313. "fn test(x: f64, y: f64) -> f64 {\n"
  314. " let x = x.powi(2) + y;\n"
  315. " x\n"
  316. "}\n"
  317. )
  318. assert source == expected
  319. def test_not_supported():
  320. f = Function('f')
  321. name_expr = ("test", [f(x).diff(x), S.ComplexInfinity])
  322. result, = codegen(name_expr, "Rust", header=False, empty=False)
  323. source = result[1]
  324. expected = (
  325. "fn test(x: f64) -> (f64, f64) {\n"
  326. " // unsupported: Derivative(f(x), x)\n"
  327. " // unsupported: zoo\n"
  328. " let out1 = Derivative(f(x), x);\n"
  329. " let out2 = zoo;\n"
  330. " (out1, out2)\n"
  331. "}\n"
  332. )
  333. assert source == expected
  334. def test_global_vars_rust():
  335. x, y, z, t = symbols("x y z t")
  336. result = codegen(('f', x*y), "Rust", header=False, empty=False,
  337. global_vars=(y,))
  338. source = result[0][1]
  339. expected = (
  340. "fn f(x: f64) -> f64 {\n"
  341. " let out1 = x*y;\n"
  342. " out1\n"
  343. "}\n"
  344. )
  345. assert source == expected
  346. result = codegen(('f', x*y+z), "Rust", header=False, empty=False,
  347. argument_sequence=(x, y), global_vars=(z, t))
  348. source = result[0][1]
  349. expected = (
  350. "fn f(x: f64, y: f64) -> f64 {\n"
  351. " let out1 = x*y + z;\n"
  352. " out1\n"
  353. "}\n"
  354. )
  355. assert source == expected