test_codegen_octave.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589
  1. from io import StringIO
  2. from sympy.core import S, symbols, Eq, pi, Catalan, EulerGamma, Function
  3. from sympy.core.relational import Equality
  4. from sympy.functions.elementary.piecewise import Piecewise
  5. from sympy.matrices import Matrix, MatrixSymbol
  6. from sympy.utilities.codegen import OctaveCodeGen, codegen, make_routine
  7. from sympy.testing.pytest import raises
  8. from sympy.testing.pytest import XFAIL
  9. import sympy
# Scalar symbols shared by most tests in this module; a few tests rebind
# some of these names locally (e.g. as IndexedBase objects or fresh symbols).
x, y, z = symbols('x,y,z')
  11. def test_empty_m_code():
  12. code_gen = OctaveCodeGen()
  13. output = StringIO()
  14. code_gen.dump_m([], output, "file", header=False, empty=False)
  15. source = output.getvalue()
  16. assert source == ""
  17. def test_m_simple_code():
  18. name_expr = ("test", (x + y)*z)
  19. result, = codegen(name_expr, "Octave", header=False, empty=False)
  20. assert result[0] == "test.m"
  21. source = result[1]
  22. expected = (
  23. "function out1 = test(x, y, z)\n"
  24. " out1 = z.*(x + y);\n"
  25. "end\n"
  26. )
  27. assert source == expected
  28. def test_m_simple_code_with_header():
  29. name_expr = ("test", (x + y)*z)
  30. result, = codegen(name_expr, "Octave", header=True, empty=False)
  31. assert result[0] == "test.m"
  32. source = result[1]
  33. expected = (
  34. "function out1 = test(x, y, z)\n"
  35. " %TEST Autogenerated by SymPy\n"
  36. " % Code generated with SymPy " + sympy.__version__ + "\n"
  37. " %\n"
  38. " % See http://www.sympy.org/ for more information.\n"
  39. " %\n"
  40. " % This file is part of 'project'\n"
  41. " out1 = z.*(x + y);\n"
  42. "end\n"
  43. )
  44. assert source == expected
  45. def test_m_simple_code_nameout():
  46. expr = Equality(z, (x + y))
  47. name_expr = ("test", expr)
  48. result, = codegen(name_expr, "Octave", header=False, empty=False)
  49. source = result[1]
  50. expected = (
  51. "function z = test(x, y)\n"
  52. " z = x + y;\n"
  53. "end\n"
  54. )
  55. assert source == expected
  56. def test_m_numbersymbol():
  57. name_expr = ("test", pi**Catalan)
  58. result, = codegen(name_expr, "Octave", header=False, empty=False)
  59. source = result[1]
  60. expected = (
  61. "function out1 = test()\n"
  62. " out1 = pi^%s;\n"
  63. "end\n"
  64. ) % Catalan.evalf(17)
  65. assert source == expected
  66. @XFAIL
  67. def test_m_numbersymbol_no_inline():
  68. # FIXME: how to pass inline=False to the OctaveCodePrinter?
  69. name_expr = ("test", [pi**Catalan, EulerGamma])
  70. result, = codegen(name_expr, "Octave", header=False,
  71. empty=False, inline=False)
  72. source = result[1]
  73. expected = (
  74. "function [out1, out2] = test()\n"
  75. " Catalan = 0.915965594177219; % constant\n"
  76. " EulerGamma = 0.5772156649015329; % constant\n"
  77. " out1 = pi^Catalan;\n"
  78. " out2 = EulerGamma;\n"
  79. "end\n"
  80. )
  81. assert source == expected
  82. def test_m_code_argument_order():
  83. expr = x + y
  84. routine = make_routine("test", expr, argument_sequence=[z, x, y], language="octave")
  85. code_gen = OctaveCodeGen()
  86. output = StringIO()
  87. code_gen.dump_m([routine], output, "test", header=False, empty=False)
  88. source = output.getvalue()
  89. expected = (
  90. "function out1 = test(z, x, y)\n"
  91. " out1 = x + y;\n"
  92. "end\n"
  93. )
  94. assert source == expected
  95. def test_multiple_results_m():
  96. # Here the output order is the input order
  97. expr1 = (x + y)*z
  98. expr2 = (x - y)*z
  99. name_expr = ("test", [expr1, expr2])
  100. result, = codegen(name_expr, "Octave", header=False, empty=False)
  101. source = result[1]
  102. expected = (
  103. "function [out1, out2] = test(x, y, z)\n"
  104. " out1 = z.*(x + y);\n"
  105. " out2 = z.*(x - y);\n"
  106. "end\n"
  107. )
  108. assert source == expected
  109. def test_results_named_unordered():
  110. # Here output order is based on name_expr
  111. A, B, C = symbols('A,B,C')
  112. expr1 = Equality(C, (x + y)*z)
  113. expr2 = Equality(A, (x - y)*z)
  114. expr3 = Equality(B, 2*x)
  115. name_expr = ("test", [expr1, expr2, expr3])
  116. result, = codegen(name_expr, "Octave", header=False, empty=False)
  117. source = result[1]
  118. expected = (
  119. "function [C, A, B] = test(x, y, z)\n"
  120. " C = z.*(x + y);\n"
  121. " A = z.*(x - y);\n"
  122. " B = 2*x;\n"
  123. "end\n"
  124. )
  125. assert source == expected
  126. def test_results_named_ordered():
  127. A, B, C = symbols('A,B,C')
  128. expr1 = Equality(C, (x + y)*z)
  129. expr2 = Equality(A, (x - y)*z)
  130. expr3 = Equality(B, 2*x)
  131. name_expr = ("test", [expr1, expr2, expr3])
  132. result = codegen(name_expr, "Octave", header=False, empty=False,
  133. argument_sequence=(x, z, y))
  134. assert result[0][0] == "test.m"
  135. source = result[0][1]
  136. expected = (
  137. "function [C, A, B] = test(x, z, y)\n"
  138. " C = z.*(x + y);\n"
  139. " A = z.*(x - y);\n"
  140. " B = 2*x;\n"
  141. "end\n"
  142. )
  143. assert source == expected
  144. def test_complicated_m_codegen():
  145. from sympy.functions.elementary.trigonometric import (cos, sin, tan)
  146. name_expr = ("testlong",
  147. [ ((sin(x) + cos(y) + tan(z))**3).expand(),
  148. cos(cos(cos(cos(cos(cos(cos(cos(x + y + z))))))))
  149. ])
  150. result = codegen(name_expr, "Octave", header=False, empty=False)
  151. assert result[0][0] == "testlong.m"
  152. source = result[0][1]
  153. expected = (
  154. "function [out1, out2] = testlong(x, y, z)\n"
  155. " out1 = sin(x).^3 + 3*sin(x).^2.*cos(y) + 3*sin(x).^2.*tan(z)"
  156. " + 3*sin(x).*cos(y).^2 + 6*sin(x).*cos(y).*tan(z) + 3*sin(x).*tan(z).^2"
  157. " + cos(y).^3 + 3*cos(y).^2.*tan(z) + 3*cos(y).*tan(z).^2 + tan(z).^3;\n"
  158. " out2 = cos(cos(cos(cos(cos(cos(cos(cos(x + y + z))))))));\n"
  159. "end\n"
  160. )
  161. assert source == expected
  162. def test_m_output_arg_mixed_unordered():
  163. # named outputs are alphabetical, unnamed output appear in the given order
  164. from sympy.functions.elementary.trigonometric import (cos, sin)
  165. a = symbols("a")
  166. name_expr = ("foo", [cos(2*x), Equality(y, sin(x)), cos(x), Equality(a, sin(2*x))])
  167. result, = codegen(name_expr, "Octave", header=False, empty=False)
  168. assert result[0] == "foo.m"
  169. source = result[1];
  170. expected = (
  171. 'function [out1, y, out3, a] = foo(x)\n'
  172. ' out1 = cos(2*x);\n'
  173. ' y = sin(x);\n'
  174. ' out3 = cos(x);\n'
  175. ' a = sin(2*x);\n'
  176. 'end\n'
  177. )
  178. assert source == expected
  179. def test_m_piecewise_():
  180. pw = Piecewise((0, x < -1), (x**2, x <= 1), (-x+2, x > 1), (1, True), evaluate=False)
  181. name_expr = ("pwtest", pw)
  182. result, = codegen(name_expr, "Octave", header=False, empty=False)
  183. source = result[1]
  184. expected = (
  185. "function out1 = pwtest(x)\n"
  186. " out1 = ((x < -1).*(0) + (~(x < -1)).*( ...\n"
  187. " (x <= 1).*(x.^2) + (~(x <= 1)).*( ...\n"
  188. " (x > 1).*(2 - x) + (~(x > 1)).*(1))));\n"
  189. "end\n"
  190. )
  191. assert source == expected
  192. @XFAIL
  193. def test_m_piecewise_no_inline():
  194. # FIXME: how to pass inline=False to the OctaveCodePrinter?
  195. pw = Piecewise((0, x < -1), (x**2, x <= 1), (-x+2, x > 1), (1, True))
  196. name_expr = ("pwtest", pw)
  197. result, = codegen(name_expr, "Octave", header=False, empty=False,
  198. inline=False)
  199. source = result[1]
  200. expected = (
  201. "function out1 = pwtest(x)\n"
  202. " if (x < -1)\n"
  203. " out1 = 0;\n"
  204. " elseif (x <= 1)\n"
  205. " out1 = x.^2;\n"
  206. " elseif (x > 1)\n"
  207. " out1 = -x + 2;\n"
  208. " else\n"
  209. " out1 = 1;\n"
  210. " end\n"
  211. "end\n"
  212. )
  213. assert source == expected
  214. def test_m_multifcns_per_file():
  215. name_expr = [ ("foo", [2*x, 3*y]), ("bar", [y**2, 4*y]) ]
  216. result = codegen(name_expr, "Octave", header=False, empty=False)
  217. assert result[0][0] == "foo.m"
  218. source = result[0][1];
  219. expected = (
  220. "function [out1, out2] = foo(x, y)\n"
  221. " out1 = 2*x;\n"
  222. " out2 = 3*y;\n"
  223. "end\n"
  224. "function [out1, out2] = bar(y)\n"
  225. " out1 = y.^2;\n"
  226. " out2 = 4*y;\n"
  227. "end\n"
  228. )
  229. assert source == expected
  230. def test_m_multifcns_per_file_w_header():
  231. name_expr = [ ("foo", [2*x, 3*y]), ("bar", [y**2, 4*y]) ]
  232. result = codegen(name_expr, "Octave", header=True, empty=False)
  233. assert result[0][0] == "foo.m"
  234. source = result[0][1];
  235. expected = (
  236. "function [out1, out2] = foo(x, y)\n"
  237. " %FOO Autogenerated by SymPy\n"
  238. " % Code generated with SymPy " + sympy.__version__ + "\n"
  239. " %\n"
  240. " % See http://www.sympy.org/ for more information.\n"
  241. " %\n"
  242. " % This file is part of 'project'\n"
  243. " out1 = 2*x;\n"
  244. " out2 = 3*y;\n"
  245. "end\n"
  246. "function [out1, out2] = bar(y)\n"
  247. " out1 = y.^2;\n"
  248. " out2 = 4*y;\n"
  249. "end\n"
  250. )
  251. assert source == expected
  252. def test_m_filename_match_first_fcn():
  253. name_expr = [ ("foo", [2*x, 3*y]), ("bar", [y**2, 4*y]) ]
  254. raises(ValueError, lambda: codegen(name_expr,
  255. "Octave", prefix="bar", header=False, empty=False))
  256. def test_m_matrix_named():
  257. e2 = Matrix([[x, 2*y, pi*z]])
  258. name_expr = ("test", Equality(MatrixSymbol('myout1', 1, 3), e2))
  259. result = codegen(name_expr, "Octave", header=False, empty=False)
  260. assert result[0][0] == "test.m"
  261. source = result[0][1]
  262. expected = (
  263. "function myout1 = test(x, y, z)\n"
  264. " myout1 = [x 2*y pi*z];\n"
  265. "end\n"
  266. )
  267. assert source == expected
  268. def test_m_matrix_named_matsym():
  269. myout1 = MatrixSymbol('myout1', 1, 3)
  270. e2 = Matrix([[x, 2*y, pi*z]])
  271. name_expr = ("test", Equality(myout1, e2, evaluate=False))
  272. result, = codegen(name_expr, "Octave", header=False, empty=False)
  273. source = result[1]
  274. expected = (
  275. "function myout1 = test(x, y, z)\n"
  276. " myout1 = [x 2*y pi*z];\n"
  277. "end\n"
  278. )
  279. assert source == expected
  280. def test_m_matrix_output_autoname():
  281. expr = Matrix([[x, x+y, 3]])
  282. name_expr = ("test", expr)
  283. result, = codegen(name_expr, "Octave", header=False, empty=False)
  284. source = result[1]
  285. expected = (
  286. "function out1 = test(x, y)\n"
  287. " out1 = [x x + y 3];\n"
  288. "end\n"
  289. )
  290. assert source == expected
  291. def test_m_matrix_output_autoname_2():
  292. e1 = (x + y)
  293. e2 = Matrix([[2*x, 2*y, 2*z]])
  294. e3 = Matrix([[x], [y], [z]])
  295. e4 = Matrix([[x, y], [z, 16]])
  296. name_expr = ("test", (e1, e2, e3, e4))
  297. result, = codegen(name_expr, "Octave", header=False, empty=False)
  298. source = result[1]
  299. expected = (
  300. "function [out1, out2, out3, out4] = test(x, y, z)\n"
  301. " out1 = x + y;\n"
  302. " out2 = [2*x 2*y 2*z];\n"
  303. " out3 = [x; y; z];\n"
  304. " out4 = [x y; z 16];\n"
  305. "end\n"
  306. )
  307. assert source == expected
  308. def test_m_results_matrix_named_ordered():
  309. B, C = symbols('B,C')
  310. A = MatrixSymbol('A', 1, 3)
  311. expr1 = Equality(C, (x + y)*z)
  312. expr2 = Equality(A, Matrix([[1, 2, x]]))
  313. expr3 = Equality(B, 2*x)
  314. name_expr = ("test", [expr1, expr2, expr3])
  315. result, = codegen(name_expr, "Octave", header=False, empty=False,
  316. argument_sequence=(x, z, y))
  317. source = result[1]
  318. expected = (
  319. "function [C, A, B] = test(x, z, y)\n"
  320. " C = z.*(x + y);\n"
  321. " A = [1 2 x];\n"
  322. " B = 2*x;\n"
  323. "end\n"
  324. )
  325. assert source == expected
  326. def test_m_matrixsymbol_slice():
  327. A = MatrixSymbol('A', 2, 3)
  328. B = MatrixSymbol('B', 1, 3)
  329. C = MatrixSymbol('C', 1, 3)
  330. D = MatrixSymbol('D', 2, 1)
  331. name_expr = ("test", [Equality(B, A[0, :]),
  332. Equality(C, A[1, :]),
  333. Equality(D, A[:, 2])])
  334. result, = codegen(name_expr, "Octave", header=False, empty=False)
  335. source = result[1]
  336. expected = (
  337. "function [B, C, D] = test(A)\n"
  338. " B = A(1, :);\n"
  339. " C = A(2, :);\n"
  340. " D = A(:, 3);\n"
  341. "end\n"
  342. )
  343. assert source == expected
  344. def test_m_matrixsymbol_slice2():
  345. A = MatrixSymbol('A', 3, 4)
  346. B = MatrixSymbol('B', 2, 2)
  347. C = MatrixSymbol('C', 2, 2)
  348. name_expr = ("test", [Equality(B, A[0:2, 0:2]),
  349. Equality(C, A[0:2, 1:3])])
  350. result, = codegen(name_expr, "Octave", header=False, empty=False)
  351. source = result[1]
  352. expected = (
  353. "function [B, C] = test(A)\n"
  354. " B = A(1:2, 1:2);\n"
  355. " C = A(1:2, 2:3);\n"
  356. "end\n"
  357. )
  358. assert source == expected
  359. def test_m_matrixsymbol_slice3():
  360. A = MatrixSymbol('A', 8, 7)
  361. B = MatrixSymbol('B', 2, 2)
  362. C = MatrixSymbol('C', 4, 2)
  363. name_expr = ("test", [Equality(B, A[6:, 1::3]),
  364. Equality(C, A[::2, ::3])])
  365. result, = codegen(name_expr, "Octave", header=False, empty=False)
  366. source = result[1]
  367. expected = (
  368. "function [B, C] = test(A)\n"
  369. " B = A(7:end, 2:3:end);\n"
  370. " C = A(1:2:end, 1:3:end);\n"
  371. "end\n"
  372. )
  373. assert source == expected
  374. def test_m_matrixsymbol_slice_autoname():
  375. A = MatrixSymbol('A', 2, 3)
  376. B = MatrixSymbol('B', 1, 3)
  377. name_expr = ("test", [Equality(B, A[0,:]), A[1,:], A[:,0], A[:,1]])
  378. result, = codegen(name_expr, "Octave", header=False, empty=False)
  379. source = result[1]
  380. expected = (
  381. "function [B, out2, out3, out4] = test(A)\n"
  382. " B = A(1, :);\n"
  383. " out2 = A(2, :);\n"
  384. " out3 = A(:, 1);\n"
  385. " out4 = A(:, 2);\n"
  386. "end\n"
  387. )
  388. assert source == expected
  389. def test_m_loops():
  390. # Note: an Octave programmer would probably vectorize this across one or
  391. # more dimensions. Also, size(A) would be used rather than passing in m
  392. # and n. Perhaps users would expect us to vectorize automatically here?
  393. # Or is it possible to represent such things using IndexedBase?
  394. from sympy.tensor import IndexedBase, Idx
  395. from sympy.core.symbol import symbols
  396. n, m = symbols('n m', integer=True)
  397. A = IndexedBase('A')
  398. x = IndexedBase('x')
  399. y = IndexedBase('y')
  400. i = Idx('i', m)
  401. j = Idx('j', n)
  402. result, = codegen(('mat_vec_mult', Eq(y[i], A[i, j]*x[j])), "Octave",
  403. header=False, empty=False)
  404. source = result[1]
  405. expected = (
  406. 'function y = mat_vec_mult(A, m, n, x)\n'
  407. ' for i = 1:m\n'
  408. ' y(i) = 0;\n'
  409. ' end\n'
  410. ' for i = 1:m\n'
  411. ' for j = 1:n\n'
  412. ' y(i) = %(rhs)s + y(i);\n'
  413. ' end\n'
  414. ' end\n'
  415. 'end\n'
  416. )
  417. assert (source == expected % {'rhs': 'A(%s, %s).*x(j)' % (i, j)} or
  418. source == expected % {'rhs': 'x(j).*A(%s, %s)' % (i, j)})
  419. def test_m_tensor_loops_multiple_contractions():
  420. # see comments in previous test about vectorizing
  421. from sympy.tensor import IndexedBase, Idx
  422. from sympy.core.symbol import symbols
  423. n, m, o, p = symbols('n m o p', integer=True)
  424. A = IndexedBase('A')
  425. B = IndexedBase('B')
  426. y = IndexedBase('y')
  427. i = Idx('i', m)
  428. j = Idx('j', n)
  429. k = Idx('k', o)
  430. l = Idx('l', p)
  431. result, = codegen(('tensorthing', Eq(y[i], B[j, k, l]*A[i, j, k, l])),
  432. "Octave", header=False, empty=False)
  433. source = result[1]
  434. expected = (
  435. 'function y = tensorthing(A, B, m, n, o, p)\n'
  436. ' for i = 1:m\n'
  437. ' y(i) = 0;\n'
  438. ' end\n'
  439. ' for i = 1:m\n'
  440. ' for j = 1:n\n'
  441. ' for k = 1:o\n'
  442. ' for l = 1:p\n'
  443. ' y(i) = A(i, j, k, l).*B(j, k, l) + y(i);\n'
  444. ' end\n'
  445. ' end\n'
  446. ' end\n'
  447. ' end\n'
  448. 'end\n'
  449. )
  450. assert source == expected
  451. def test_m_InOutArgument():
  452. expr = Equality(x, x**2)
  453. name_expr = ("mysqr", expr)
  454. result, = codegen(name_expr, "Octave", header=False, empty=False)
  455. source = result[1]
  456. expected = (
  457. "function x = mysqr(x)\n"
  458. " x = x.^2;\n"
  459. "end\n"
  460. )
  461. assert source == expected
  462. def test_m_InOutArgument_order():
  463. # can specify the order as (x, y)
  464. expr = Equality(x, x**2 + y)
  465. name_expr = ("test", expr)
  466. result, = codegen(name_expr, "Octave", header=False,
  467. empty=False, argument_sequence=(x,y))
  468. source = result[1]
  469. expected = (
  470. "function x = test(x, y)\n"
  471. " x = x.^2 + y;\n"
  472. "end\n"
  473. )
  474. assert source == expected
  475. # make sure it gives (x, y) not (y, x)
  476. expr = Equality(x, x**2 + y)
  477. name_expr = ("test", expr)
  478. result, = codegen(name_expr, "Octave", header=False, empty=False)
  479. source = result[1]
  480. expected = (
  481. "function x = test(x, y)\n"
  482. " x = x.^2 + y;\n"
  483. "end\n"
  484. )
  485. assert source == expected
  486. def test_m_not_supported():
  487. f = Function('f')
  488. name_expr = ("test", [f(x).diff(x), S.ComplexInfinity])
  489. result, = codegen(name_expr, "Octave", header=False, empty=False)
  490. source = result[1]
  491. expected = (
  492. "function [out1, out2] = test(x)\n"
  493. " % unsupported: Derivative(f(x), x)\n"
  494. " % unsupported: zoo\n"
  495. " out1 = Derivative(f(x), x);\n"
  496. " out2 = zoo;\n"
  497. "end\n"
  498. )
  499. assert source == expected
  500. def test_global_vars_octave():
  501. x, y, z, t = symbols("x y z t")
  502. result = codegen(('f', x*y), "Octave", header=False, empty=False,
  503. global_vars=(y,))
  504. source = result[0][1]
  505. expected = (
  506. "function out1 = f(x)\n"
  507. " global y\n"
  508. " out1 = x.*y;\n"
  509. "end\n"
  510. )
  511. assert source == expected
  512. result = codegen(('f', x*y+z), "Octave", header=False, empty=False,
  513. argument_sequence=(x, y), global_vars=(z, t))
  514. source = result[0][1]
  515. expected = (
  516. "function out1 = f(x, y)\n"
  517. " global t z\n"
  518. " out1 = x.*y + z;\n"
  519. "end\n"
  520. )
  521. assert source == expected