test_codegen_julia.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620
  1. from io import StringIO
  2. from sympy.core import S, symbols, Eq, pi, Catalan, EulerGamma, Function
  3. from sympy.core.relational import Equality
  4. from sympy.functions.elementary.piecewise import Piecewise
  5. from sympy.matrices import Matrix, MatrixSymbol
  6. from sympy.utilities.codegen import JuliaCodeGen, codegen, make_routine
  7. from sympy.testing.pytest import XFAIL
  8. import sympy
  9. x, y, z = symbols('x,y,z')
  10. def test_empty_jl_code():
  11. code_gen = JuliaCodeGen()
  12. output = StringIO()
  13. code_gen.dump_jl([], output, "file", header=False, empty=False)
  14. source = output.getvalue()
  15. assert source == ""
  16. def test_jl_simple_code():
  17. name_expr = ("test", (x + y)*z)
  18. result, = codegen(name_expr, "Julia", header=False, empty=False)
  19. assert result[0] == "test.jl"
  20. source = result[1]
  21. expected = (
  22. "function test(x, y, z)\n"
  23. " out1 = z .* (x + y)\n"
  24. " return out1\n"
  25. "end\n"
  26. )
  27. assert source == expected
  28. def test_jl_simple_code_with_header():
  29. name_expr = ("test", (x + y)*z)
  30. result, = codegen(name_expr, "Julia", header=True, empty=False)
  31. assert result[0] == "test.jl"
  32. source = result[1]
  33. expected = (
  34. "# Code generated with SymPy " + sympy.__version__ + "\n"
  35. "#\n"
  36. "# See http://www.sympy.org/ for more information.\n"
  37. "#\n"
  38. "# This file is part of 'project'\n"
  39. "function test(x, y, z)\n"
  40. " out1 = z .* (x + y)\n"
  41. " return out1\n"
  42. "end\n"
  43. )
  44. assert source == expected
  45. def test_jl_simple_code_nameout():
  46. expr = Equality(z, (x + y))
  47. name_expr = ("test", expr)
  48. result, = codegen(name_expr, "Julia", header=False, empty=False)
  49. source = result[1]
  50. expected = (
  51. "function test(x, y)\n"
  52. " z = x + y\n"
  53. " return z\n"
  54. "end\n"
  55. )
  56. assert source == expected
  57. def test_jl_numbersymbol():
  58. name_expr = ("test", pi**Catalan)
  59. result, = codegen(name_expr, "Julia", header=False, empty=False)
  60. source = result[1]
  61. expected = (
  62. "function test()\n"
  63. " out1 = pi ^ catalan\n"
  64. " return out1\n"
  65. "end\n"
  66. )
  67. assert source == expected
  68. @XFAIL
  69. def test_jl_numbersymbol_no_inline():
  70. # FIXME: how to pass inline=False to the JuliaCodePrinter?
  71. name_expr = ("test", [pi**Catalan, EulerGamma])
  72. result, = codegen(name_expr, "Julia", header=False,
  73. empty=False, inline=False)
  74. source = result[1]
  75. expected = (
  76. "function test()\n"
  77. " Catalan = 0.915965594177219\n"
  78. " EulerGamma = 0.5772156649015329\n"
  79. " out1 = pi ^ Catalan\n"
  80. " out2 = EulerGamma\n"
  81. " return out1, out2\n"
  82. "end\n"
  83. )
  84. assert source == expected
  85. def test_jl_code_argument_order():
  86. expr = x + y
  87. routine = make_routine("test", expr, argument_sequence=[z, x, y], language="julia")
  88. code_gen = JuliaCodeGen()
  89. output = StringIO()
  90. code_gen.dump_jl([routine], output, "test", header=False, empty=False)
  91. source = output.getvalue()
  92. expected = (
  93. "function test(z, x, y)\n"
  94. " out1 = x + y\n"
  95. " return out1\n"
  96. "end\n"
  97. )
  98. assert source == expected
  99. def test_multiple_results_m():
  100. # Here the output order is the input order
  101. expr1 = (x + y)*z
  102. expr2 = (x - y)*z
  103. name_expr = ("test", [expr1, expr2])
  104. result, = codegen(name_expr, "Julia", header=False, empty=False)
  105. source = result[1]
  106. expected = (
  107. "function test(x, y, z)\n"
  108. " out1 = z .* (x + y)\n"
  109. " out2 = z .* (x - y)\n"
  110. " return out1, out2\n"
  111. "end\n"
  112. )
  113. assert source == expected
  114. def test_results_named_unordered():
  115. # Here output order is based on name_expr
  116. A, B, C = symbols('A,B,C')
  117. expr1 = Equality(C, (x + y)*z)
  118. expr2 = Equality(A, (x - y)*z)
  119. expr3 = Equality(B, 2*x)
  120. name_expr = ("test", [expr1, expr2, expr3])
  121. result, = codegen(name_expr, "Julia", header=False, empty=False)
  122. source = result[1]
  123. expected = (
  124. "function test(x, y, z)\n"
  125. " C = z .* (x + y)\n"
  126. " A = z .* (x - y)\n"
  127. " B = 2 * x\n"
  128. " return C, A, B\n"
  129. "end\n"
  130. )
  131. assert source == expected
  132. def test_results_named_ordered():
  133. A, B, C = symbols('A,B,C')
  134. expr1 = Equality(C, (x + y)*z)
  135. expr2 = Equality(A, (x - y)*z)
  136. expr3 = Equality(B, 2*x)
  137. name_expr = ("test", [expr1, expr2, expr3])
  138. result = codegen(name_expr, "Julia", header=False, empty=False,
  139. argument_sequence=(x, z, y))
  140. assert result[0][0] == "test.jl"
  141. source = result[0][1]
  142. expected = (
  143. "function test(x, z, y)\n"
  144. " C = z .* (x + y)\n"
  145. " A = z .* (x - y)\n"
  146. " B = 2 * x\n"
  147. " return C, A, B\n"
  148. "end\n"
  149. )
  150. assert source == expected
  151. def test_complicated_jl_codegen():
  152. from sympy.functions.elementary.trigonometric import (cos, sin, tan)
  153. name_expr = ("testlong",
  154. [ ((sin(x) + cos(y) + tan(z))**3).expand(),
  155. cos(cos(cos(cos(cos(cos(cos(cos(x + y + z))))))))
  156. ])
  157. result = codegen(name_expr, "Julia", header=False, empty=False)
  158. assert result[0][0] == "testlong.jl"
  159. source = result[0][1]
  160. expected = (
  161. "function testlong(x, y, z)\n"
  162. " out1 = sin(x) .^ 3 + 3 * sin(x) .^ 2 .* cos(y) + 3 * sin(x) .^ 2 .* tan(z)"
  163. " + 3 * sin(x) .* cos(y) .^ 2 + 6 * sin(x) .* cos(y) .* tan(z) + 3 * sin(x) .* tan(z) .^ 2"
  164. " + cos(y) .^ 3 + 3 * cos(y) .^ 2 .* tan(z) + 3 * cos(y) .* tan(z) .^ 2 + tan(z) .^ 3\n"
  165. " out2 = cos(cos(cos(cos(cos(cos(cos(cos(x + y + z))))))))\n"
  166. " return out1, out2\n"
  167. "end\n"
  168. )
  169. assert source == expected
  170. def test_jl_output_arg_mixed_unordered():
  171. # named outputs are alphabetical, unnamed output appear in the given order
  172. from sympy.functions.elementary.trigonometric import (cos, sin)
  173. a = symbols("a")
  174. name_expr = ("foo", [cos(2*x), Equality(y, sin(x)), cos(x), Equality(a, sin(2*x))])
  175. result, = codegen(name_expr, "Julia", header=False, empty=False)
  176. assert result[0] == "foo.jl"
  177. source = result[1];
  178. expected = (
  179. 'function foo(x)\n'
  180. ' out1 = cos(2 * x)\n'
  181. ' y = sin(x)\n'
  182. ' out3 = cos(x)\n'
  183. ' a = sin(2 * x)\n'
  184. ' return out1, y, out3, a\n'
  185. 'end\n'
  186. )
  187. assert source == expected
  188. def test_jl_piecewise_():
  189. pw = Piecewise((0, x < -1), (x**2, x <= 1), (-x+2, x > 1), (1, True), evaluate=False)
  190. name_expr = ("pwtest", pw)
  191. result, = codegen(name_expr, "Julia", header=False, empty=False)
  192. source = result[1]
  193. expected = (
  194. "function pwtest(x)\n"
  195. " out1 = ((x < -1) ? (0) :\n"
  196. " (x <= 1) ? (x .^ 2) :\n"
  197. " (x > 1) ? (2 - x) : (1))\n"
  198. " return out1\n"
  199. "end\n"
  200. )
  201. assert source == expected
  202. @XFAIL
  203. def test_jl_piecewise_no_inline():
  204. # FIXME: how to pass inline=False to the JuliaCodePrinter?
  205. pw = Piecewise((0, x < -1), (x**2, x <= 1), (-x+2, x > 1), (1, True))
  206. name_expr = ("pwtest", pw)
  207. result, = codegen(name_expr, "Julia", header=False, empty=False,
  208. inline=False)
  209. source = result[1]
  210. expected = (
  211. "function pwtest(x)\n"
  212. " if (x < -1)\n"
  213. " out1 = 0\n"
  214. " elseif (x <= 1)\n"
  215. " out1 = x .^ 2\n"
  216. " elseif (x > 1)\n"
  217. " out1 = -x + 2\n"
  218. " else\n"
  219. " out1 = 1\n"
  220. " end\n"
  221. " return out1\n"
  222. "end\n"
  223. )
  224. assert source == expected
  225. def test_jl_multifcns_per_file():
  226. name_expr = [ ("foo", [2*x, 3*y]), ("bar", [y**2, 4*y]) ]
  227. result = codegen(name_expr, "Julia", header=False, empty=False)
  228. assert result[0][0] == "foo.jl"
  229. source = result[0][1];
  230. expected = (
  231. "function foo(x, y)\n"
  232. " out1 = 2 * x\n"
  233. " out2 = 3 * y\n"
  234. " return out1, out2\n"
  235. "end\n"
  236. "function bar(y)\n"
  237. " out1 = y .^ 2\n"
  238. " out2 = 4 * y\n"
  239. " return out1, out2\n"
  240. "end\n"
  241. )
  242. assert source == expected
  243. def test_jl_multifcns_per_file_w_header():
  244. name_expr = [ ("foo", [2*x, 3*y]), ("bar", [y**2, 4*y]) ]
  245. result = codegen(name_expr, "Julia", header=True, empty=False)
  246. assert result[0][0] == "foo.jl"
  247. source = result[0][1];
  248. expected = (
  249. "# Code generated with SymPy " + sympy.__version__ + "\n"
  250. "#\n"
  251. "# See http://www.sympy.org/ for more information.\n"
  252. "#\n"
  253. "# This file is part of 'project'\n"
  254. "function foo(x, y)\n"
  255. " out1 = 2 * x\n"
  256. " out2 = 3 * y\n"
  257. " return out1, out2\n"
  258. "end\n"
  259. "function bar(y)\n"
  260. " out1 = y .^ 2\n"
  261. " out2 = 4 * y\n"
  262. " return out1, out2\n"
  263. "end\n"
  264. )
  265. assert source == expected
  266. def test_jl_filename_match_prefix():
  267. name_expr = [ ("foo", [2*x, 3*y]), ("bar", [y**2, 4*y]) ]
  268. result, = codegen(name_expr, "Julia", prefix="baz", header=False,
  269. empty=False)
  270. assert result[0] == "baz.jl"
  271. def test_jl_matrix_named():
  272. e2 = Matrix([[x, 2*y, pi*z]])
  273. name_expr = ("test", Equality(MatrixSymbol('myout1', 1, 3), e2))
  274. result = codegen(name_expr, "Julia", header=False, empty=False)
  275. assert result[0][0] == "test.jl"
  276. source = result[0][1]
  277. expected = (
  278. "function test(x, y, z)\n"
  279. " myout1 = [x 2 * y pi * z]\n"
  280. " return myout1\n"
  281. "end\n"
  282. )
  283. assert source == expected
  284. def test_jl_matrix_named_matsym():
  285. myout1 = MatrixSymbol('myout1', 1, 3)
  286. e2 = Matrix([[x, 2*y, pi*z]])
  287. name_expr = ("test", Equality(myout1, e2, evaluate=False))
  288. result, = codegen(name_expr, "Julia", header=False, empty=False)
  289. source = result[1]
  290. expected = (
  291. "function test(x, y, z)\n"
  292. " myout1 = [x 2 * y pi * z]\n"
  293. " return myout1\n"
  294. "end\n"
  295. )
  296. assert source == expected
  297. def test_jl_matrix_output_autoname():
  298. expr = Matrix([[x, x+y, 3]])
  299. name_expr = ("test", expr)
  300. result, = codegen(name_expr, "Julia", header=False, empty=False)
  301. source = result[1]
  302. expected = (
  303. "function test(x, y)\n"
  304. " out1 = [x x + y 3]\n"
  305. " return out1\n"
  306. "end\n"
  307. )
  308. assert source == expected
  309. def test_jl_matrix_output_autoname_2():
  310. e1 = (x + y)
  311. e2 = Matrix([[2*x, 2*y, 2*z]])
  312. e3 = Matrix([[x], [y], [z]])
  313. e4 = Matrix([[x, y], [z, 16]])
  314. name_expr = ("test", (e1, e2, e3, e4))
  315. result, = codegen(name_expr, "Julia", header=False, empty=False)
  316. source = result[1]
  317. expected = (
  318. "function test(x, y, z)\n"
  319. " out1 = x + y\n"
  320. " out2 = [2 * x 2 * y 2 * z]\n"
  321. " out3 = [x, y, z]\n"
  322. " out4 = [x y;\n"
  323. " z 16]\n"
  324. " return out1, out2, out3, out4\n"
  325. "end\n"
  326. )
  327. assert source == expected
  328. def test_jl_results_matrix_named_ordered():
  329. B, C = symbols('B,C')
  330. A = MatrixSymbol('A', 1, 3)
  331. expr1 = Equality(C, (x + y)*z)
  332. expr2 = Equality(A, Matrix([[1, 2, x]]))
  333. expr3 = Equality(B, 2*x)
  334. name_expr = ("test", [expr1, expr2, expr3])
  335. result, = codegen(name_expr, "Julia", header=False, empty=False,
  336. argument_sequence=(x, z, y))
  337. source = result[1]
  338. expected = (
  339. "function test(x, z, y)\n"
  340. " C = z .* (x + y)\n"
  341. " A = [1 2 x]\n"
  342. " B = 2 * x\n"
  343. " return C, A, B\n"
  344. "end\n"
  345. )
  346. assert source == expected
  347. def test_jl_matrixsymbol_slice():
  348. A = MatrixSymbol('A', 2, 3)
  349. B = MatrixSymbol('B', 1, 3)
  350. C = MatrixSymbol('C', 1, 3)
  351. D = MatrixSymbol('D', 2, 1)
  352. name_expr = ("test", [Equality(B, A[0, :]),
  353. Equality(C, A[1, :]),
  354. Equality(D, A[:, 2])])
  355. result, = codegen(name_expr, "Julia", header=False, empty=False)
  356. source = result[1]
  357. expected = (
  358. "function test(A)\n"
  359. " B = A[1,:]\n"
  360. " C = A[2,:]\n"
  361. " D = A[:,3]\n"
  362. " return B, C, D\n"
  363. "end\n"
  364. )
  365. assert source == expected
  366. def test_jl_matrixsymbol_slice2():
  367. A = MatrixSymbol('A', 3, 4)
  368. B = MatrixSymbol('B', 2, 2)
  369. C = MatrixSymbol('C', 2, 2)
  370. name_expr = ("test", [Equality(B, A[0:2, 0:2]),
  371. Equality(C, A[0:2, 1:3])])
  372. result, = codegen(name_expr, "Julia", header=False, empty=False)
  373. source = result[1]
  374. expected = (
  375. "function test(A)\n"
  376. " B = A[1:2,1:2]\n"
  377. " C = A[1:2,2:3]\n"
  378. " return B, C\n"
  379. "end\n"
  380. )
  381. assert source == expected
  382. def test_jl_matrixsymbol_slice3():
  383. A = MatrixSymbol('A', 8, 7)
  384. B = MatrixSymbol('B', 2, 2)
  385. C = MatrixSymbol('C', 4, 2)
  386. name_expr = ("test", [Equality(B, A[6:, 1::3]),
  387. Equality(C, A[::2, ::3])])
  388. result, = codegen(name_expr, "Julia", header=False, empty=False)
  389. source = result[1]
  390. expected = (
  391. "function test(A)\n"
  392. " B = A[7:end,2:3:end]\n"
  393. " C = A[1:2:end,1:3:end]\n"
  394. " return B, C\n"
  395. "end\n"
  396. )
  397. assert source == expected
  398. def test_jl_matrixsymbol_slice_autoname():
  399. A = MatrixSymbol('A', 2, 3)
  400. B = MatrixSymbol('B', 1, 3)
  401. name_expr = ("test", [Equality(B, A[0,:]), A[1,:], A[:,0], A[:,1]])
  402. result, = codegen(name_expr, "Julia", header=False, empty=False)
  403. source = result[1]
  404. expected = (
  405. "function test(A)\n"
  406. " B = A[1,:]\n"
  407. " out2 = A[2,:]\n"
  408. " out3 = A[:,1]\n"
  409. " out4 = A[:,2]\n"
  410. " return B, out2, out3, out4\n"
  411. "end\n"
  412. )
  413. assert source == expected
  414. def test_jl_loops():
  415. # Note: an Julia programmer would probably vectorize this across one or
  416. # more dimensions. Also, size(A) would be used rather than passing in m
  417. # and n. Perhaps users would expect us to vectorize automatically here?
  418. # Or is it possible to represent such things using IndexedBase?
  419. from sympy.tensor import IndexedBase, Idx
  420. from sympy.core.symbol import symbols
  421. n, m = symbols('n m', integer=True)
  422. A = IndexedBase('A')
  423. x = IndexedBase('x')
  424. y = IndexedBase('y')
  425. i = Idx('i', m)
  426. j = Idx('j', n)
  427. result, = codegen(('mat_vec_mult', Eq(y[i], A[i, j]*x[j])), "Julia",
  428. header=False, empty=False)
  429. source = result[1]
  430. expected = (
  431. 'function mat_vec_mult(y, A, m, n, x)\n'
  432. ' for i = 1:m\n'
  433. ' y[i] = 0\n'
  434. ' end\n'
  435. ' for i = 1:m\n'
  436. ' for j = 1:n\n'
  437. ' y[i] = %(rhs)s + y[i]\n'
  438. ' end\n'
  439. ' end\n'
  440. ' return y\n'
  441. 'end\n'
  442. )
  443. assert (source == expected % {'rhs': 'A[%s,%s] .* x[j]' % (i, j)} or
  444. source == expected % {'rhs': 'x[j] .* A[%s,%s]' % (i, j)})
  445. def test_jl_tensor_loops_multiple_contractions():
  446. # see comments in previous test about vectorizing
  447. from sympy.tensor import IndexedBase, Idx
  448. from sympy.core.symbol import symbols
  449. n, m, o, p = symbols('n m o p', integer=True)
  450. A = IndexedBase('A')
  451. B = IndexedBase('B')
  452. y = IndexedBase('y')
  453. i = Idx('i', m)
  454. j = Idx('j', n)
  455. k = Idx('k', o)
  456. l = Idx('l', p)
  457. result, = codegen(('tensorthing', Eq(y[i], B[j, k, l]*A[i, j, k, l])),
  458. "Julia", header=False, empty=False)
  459. source = result[1]
  460. expected = (
  461. 'function tensorthing(y, A, B, m, n, o, p)\n'
  462. ' for i = 1:m\n'
  463. ' y[i] = 0\n'
  464. ' end\n'
  465. ' for i = 1:m\n'
  466. ' for j = 1:n\n'
  467. ' for k = 1:o\n'
  468. ' for l = 1:p\n'
  469. ' y[i] = A[i,j,k,l] .* B[j,k,l] + y[i]\n'
  470. ' end\n'
  471. ' end\n'
  472. ' end\n'
  473. ' end\n'
  474. ' return y\n'
  475. 'end\n'
  476. )
  477. assert source == expected
  478. def test_jl_InOutArgument():
  479. expr = Equality(x, x**2)
  480. name_expr = ("mysqr", expr)
  481. result, = codegen(name_expr, "Julia", header=False, empty=False)
  482. source = result[1]
  483. expected = (
  484. "function mysqr(x)\n"
  485. " x = x .^ 2\n"
  486. " return x\n"
  487. "end\n"
  488. )
  489. assert source == expected
  490. def test_jl_InOutArgument_order():
  491. # can specify the order as (x, y)
  492. expr = Equality(x, x**2 + y)
  493. name_expr = ("test", expr)
  494. result, = codegen(name_expr, "Julia", header=False,
  495. empty=False, argument_sequence=(x,y))
  496. source = result[1]
  497. expected = (
  498. "function test(x, y)\n"
  499. " x = x .^ 2 + y\n"
  500. " return x\n"
  501. "end\n"
  502. )
  503. assert source == expected
  504. # make sure it gives (x, y) not (y, x)
  505. expr = Equality(x, x**2 + y)
  506. name_expr = ("test", expr)
  507. result, = codegen(name_expr, "Julia", header=False, empty=False)
  508. source = result[1]
  509. expected = (
  510. "function test(x, y)\n"
  511. " x = x .^ 2 + y\n"
  512. " return x\n"
  513. "end\n"
  514. )
  515. assert source == expected
  516. def test_jl_not_supported():
  517. f = Function('f')
  518. name_expr = ("test", [f(x).diff(x), S.ComplexInfinity])
  519. result, = codegen(name_expr, "Julia", header=False, empty=False)
  520. source = result[1]
  521. expected = (
  522. "function test(x)\n"
  523. " # unsupported: Derivative(f(x), x)\n"
  524. " # unsupported: zoo\n"
  525. " out1 = Derivative(f(x), x)\n"
  526. " out2 = zoo\n"
  527. " return out1, out2\n"
  528. "end\n"
  529. )
  530. assert source == expected
  531. def test_global_vars_octave():
  532. x, y, z, t = symbols("x y z t")
  533. result = codegen(('f', x*y), "Julia", header=False, empty=False,
  534. global_vars=(y,))
  535. source = result[0][1]
  536. expected = (
  537. "function f(x)\n"
  538. " out1 = x .* y\n"
  539. " return out1\n"
  540. "end\n"
  541. )
  542. assert source == expected
  543. result = codegen(('f', x*y+z), "Julia", header=False, empty=False,
  544. argument_sequence=(x, y), global_vars=(z, t))
  545. source = result[0][1]
  546. expected = (
  547. "function f(x, y)\n"
  548. " out1 = x .* y + z\n"
  549. " return out1\n"
  550. "end\n"
  551. )
  552. assert source == expected