test_jscode.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396
  1. from sympy.core import (pi, oo, symbols, Rational, Integer, GoldenRatio,
  2. EulerGamma, Catalan, Lambda, Dummy, S, Eq, Ne, Le,
  3. Lt, Gt, Ge, Mod)
  4. from sympy.functions import (Piecewise, sin, cos, Abs, exp, ceiling, sqrt,
  5. sinh, cosh, tanh, asin, acos, acosh, Max, Min)
  6. from sympy.testing.pytest import raises
  7. from sympy.printing.jscode import JavascriptCodePrinter
  8. from sympy.utilities.lambdify import implemented_function
  9. from sympy.tensor import IndexedBase, Idx
  10. from sympy.matrices import Matrix, MatrixSymbol
  11. from sympy.printing.jscode import jscode
  12. x, y, z = symbols('x,y,z')
  13. def test_printmethod():
  14. assert jscode(Abs(x)) == "Math.abs(x)"
  15. def test_jscode_sqrt():
  16. assert jscode(sqrt(x)) == "Math.sqrt(x)"
  17. assert jscode(x**0.5) == "Math.sqrt(x)"
  18. assert jscode(x**(S.One/3)) == "Math.cbrt(x)"
  19. def test_jscode_Pow():
  20. g = implemented_function('g', Lambda(x, 2*x))
  21. assert jscode(x**3) == "Math.pow(x, 3)"
  22. assert jscode(x**(y**3)) == "Math.pow(x, Math.pow(y, 3))"
  23. assert jscode(1/(g(x)*3.5)**(x - y**x)/(x**2 + y)) == \
  24. "Math.pow(3.5*2*x, -x + Math.pow(y, x))/(Math.pow(x, 2) + y)"
  25. assert jscode(x**-1.0) == '1/x'
  26. def test_jscode_constants_mathh():
  27. assert jscode(exp(1)) == "Math.E"
  28. assert jscode(pi) == "Math.PI"
  29. assert jscode(oo) == "Number.POSITIVE_INFINITY"
  30. assert jscode(-oo) == "Number.NEGATIVE_INFINITY"
  31. def test_jscode_constants_other():
  32. assert jscode(
  33. 2*GoldenRatio) == "var GoldenRatio = %s;\n2*GoldenRatio" % GoldenRatio.evalf(17)
  34. assert jscode(2*Catalan) == "var Catalan = %s;\n2*Catalan" % Catalan.evalf(17)
  35. assert jscode(
  36. 2*EulerGamma) == "var EulerGamma = %s;\n2*EulerGamma" % EulerGamma.evalf(17)
  37. def test_jscode_Rational():
  38. assert jscode(Rational(3, 7)) == "3/7"
  39. assert jscode(Rational(18, 9)) == "2"
  40. assert jscode(Rational(3, -7)) == "-3/7"
  41. assert jscode(Rational(-3, -7)) == "3/7"
  42. def test_Relational():
  43. assert jscode(Eq(x, y)) == "x == y"
  44. assert jscode(Ne(x, y)) == "x != y"
  45. assert jscode(Le(x, y)) == "x <= y"
  46. assert jscode(Lt(x, y)) == "x < y"
  47. assert jscode(Gt(x, y)) == "x > y"
  48. assert jscode(Ge(x, y)) == "x >= y"
  49. def test_Mod():
  50. assert jscode(Mod(x, y)) == '((x % y) + y) % y'
  51. assert jscode(Mod(x, x + y)) == '((x % (x + y)) + (x + y)) % (x + y)'
  52. p1, p2 = symbols('p1 p2', positive=True)
  53. assert jscode(Mod(p1, p2)) == 'p1 % p2'
  54. assert jscode(Mod(p1, p2 + 3)) == 'p1 % (p2 + 3)'
  55. assert jscode(Mod(-3, -7, evaluate=False)) == '(-3) % (-7)'
  56. assert jscode(-Mod(p1, p2)) == '-(p1 % p2)'
  57. assert jscode(x*Mod(p1, p2)) == 'x*(p1 % p2)'
  58. def test_jscode_Integer():
  59. assert jscode(Integer(67)) == "67"
  60. assert jscode(Integer(-1)) == "-1"
  61. def test_jscode_functions():
  62. assert jscode(sin(x) ** cos(x)) == "Math.pow(Math.sin(x), Math.cos(x))"
  63. assert jscode(sinh(x) * cosh(x)) == "Math.sinh(x)*Math.cosh(x)"
  64. assert jscode(Max(x, y) + Min(x, y)) == "Math.max(x, y) + Math.min(x, y)"
  65. assert jscode(tanh(x)*acosh(y)) == "Math.tanh(x)*Math.acosh(y)"
  66. assert jscode(asin(x)-acos(y)) == "-Math.acos(y) + Math.asin(x)"
  67. def test_jscode_inline_function():
  68. x = symbols('x')
  69. g = implemented_function('g', Lambda(x, 2*x))
  70. assert jscode(g(x)) == "2*x"
  71. g = implemented_function('g', Lambda(x, 2*x/Catalan))
  72. assert jscode(g(x)) == "var Catalan = %s;\n2*x/Catalan" % Catalan.evalf(17)
  73. A = IndexedBase('A')
  74. i = Idx('i', symbols('n', integer=True))
  75. g = implemented_function('g', Lambda(x, x*(1 + x)*(2 + x)))
  76. assert jscode(g(A[i]), assign_to=A[i]) == (
  77. "for (var i=0; i<n; i++){\n"
  78. " A[i] = (A[i] + 1)*(A[i] + 2)*A[i];\n"
  79. "}"
  80. )
  81. def test_jscode_exceptions():
  82. assert jscode(ceiling(x)) == "Math.ceil(x)"
  83. assert jscode(Abs(x)) == "Math.abs(x)"
  84. def test_jscode_boolean():
  85. assert jscode(x & y) == "x && y"
  86. assert jscode(x | y) == "x || y"
  87. assert jscode(~x) == "!x"
  88. assert jscode(x & y & z) == "x && y && z"
  89. assert jscode(x | y | z) == "x || y || z"
  90. assert jscode((x & y) | z) == "z || x && y"
  91. assert jscode((x | y) & z) == "z && (x || y)"
  92. def test_jscode_Piecewise():
  93. expr = Piecewise((x, x < 1), (x**2, True))
  94. p = jscode(expr)
  95. s = \
  96. """\
  97. ((x < 1) ? (
  98. x
  99. )
  100. : (
  101. Math.pow(x, 2)
  102. ))\
  103. """
  104. assert p == s
  105. assert jscode(expr, assign_to="c") == (
  106. "if (x < 1) {\n"
  107. " c = x;\n"
  108. "}\n"
  109. "else {\n"
  110. " c = Math.pow(x, 2);\n"
  111. "}")
  112. # Check that Piecewise without a True (default) condition error
  113. expr = Piecewise((x, x < 1), (x**2, x > 1), (sin(x), x > 0))
  114. raises(ValueError, lambda: jscode(expr))
  115. def test_jscode_Piecewise_deep():
  116. p = jscode(2*Piecewise((x, x < 1), (x**2, True)))
  117. s = \
  118. """\
  119. 2*((x < 1) ? (
  120. x
  121. )
  122. : (
  123. Math.pow(x, 2)
  124. ))\
  125. """
  126. assert p == s
  127. def test_jscode_settings():
  128. raises(TypeError, lambda: jscode(sin(x), method="garbage"))
  129. def test_jscode_Indexed():
  130. n, m, o = symbols('n m o', integer=True)
  131. i, j, k = Idx('i', n), Idx('j', m), Idx('k', o)
  132. p = JavascriptCodePrinter()
  133. p._not_c = set()
  134. x = IndexedBase('x')[j]
  135. assert p._print_Indexed(x) == 'x[j]'
  136. A = IndexedBase('A')[i, j]
  137. assert p._print_Indexed(A) == 'A[%s]' % (m*i+j)
  138. B = IndexedBase('B')[i, j, k]
  139. assert p._print_Indexed(B) == 'B[%s]' % (i*o*m+j*o+k)
  140. assert p._not_c == set()
  141. def test_jscode_loops_matrix_vector():
  142. n, m = symbols('n m', integer=True)
  143. A = IndexedBase('A')
  144. x = IndexedBase('x')
  145. y = IndexedBase('y')
  146. i = Idx('i', m)
  147. j = Idx('j', n)
  148. s = (
  149. 'for (var i=0; i<m; i++){\n'
  150. ' y[i] = 0;\n'
  151. '}\n'
  152. 'for (var i=0; i<m; i++){\n'
  153. ' for (var j=0; j<n; j++){\n'
  154. ' y[i] = A[n*i + j]*x[j] + y[i];\n'
  155. ' }\n'
  156. '}'
  157. )
  158. c = jscode(A[i, j]*x[j], assign_to=y[i])
  159. assert c == s
  160. def test_dummy_loops():
  161. i, m = symbols('i m', integer=True, cls=Dummy)
  162. x = IndexedBase('x')
  163. y = IndexedBase('y')
  164. i = Idx(i, m)
  165. expected = (
  166. 'for (var i_%(icount)i=0; i_%(icount)i<m_%(mcount)i; i_%(icount)i++){\n'
  167. ' y[i_%(icount)i] = x[i_%(icount)i];\n'
  168. '}'
  169. ) % {'icount': i.label.dummy_index, 'mcount': m.dummy_index}
  170. code = jscode(x[i], assign_to=y[i])
  171. assert code == expected
  172. def test_jscode_loops_add():
  173. n, m = symbols('n m', integer=True)
  174. A = IndexedBase('A')
  175. x = IndexedBase('x')
  176. y = IndexedBase('y')
  177. z = IndexedBase('z')
  178. i = Idx('i', m)
  179. j = Idx('j', n)
  180. s = (
  181. 'for (var i=0; i<m; i++){\n'
  182. ' y[i] = x[i] + z[i];\n'
  183. '}\n'
  184. 'for (var i=0; i<m; i++){\n'
  185. ' for (var j=0; j<n; j++){\n'
  186. ' y[i] = A[n*i + j]*x[j] + y[i];\n'
  187. ' }\n'
  188. '}'
  189. )
  190. c = jscode(A[i, j]*x[j] + x[i] + z[i], assign_to=y[i])
  191. assert c == s
  192. def test_jscode_loops_multiple_contractions():
  193. n, m, o, p = symbols('n m o p', integer=True)
  194. a = IndexedBase('a')
  195. b = IndexedBase('b')
  196. y = IndexedBase('y')
  197. i = Idx('i', m)
  198. j = Idx('j', n)
  199. k = Idx('k', o)
  200. l = Idx('l', p)
  201. s = (
  202. 'for (var i=0; i<m; i++){\n'
  203. ' y[i] = 0;\n'
  204. '}\n'
  205. 'for (var i=0; i<m; i++){\n'
  206. ' for (var j=0; j<n; j++){\n'
  207. ' for (var k=0; k<o; k++){\n'
  208. ' for (var l=0; l<p; l++){\n'
  209. ' y[i] = a[%s]*b[%s] + y[i];\n' % (i*n*o*p + j*o*p + k*p + l, j*o*p + k*p + l) +\
  210. ' }\n'
  211. ' }\n'
  212. ' }\n'
  213. '}'
  214. )
  215. c = jscode(b[j, k, l]*a[i, j, k, l], assign_to=y[i])
  216. assert c == s
  217. def test_jscode_loops_addfactor():
  218. n, m, o, p = symbols('n m o p', integer=True)
  219. a = IndexedBase('a')
  220. b = IndexedBase('b')
  221. c = IndexedBase('c')
  222. y = IndexedBase('y')
  223. i = Idx('i', m)
  224. j = Idx('j', n)
  225. k = Idx('k', o)
  226. l = Idx('l', p)
  227. s = (
  228. 'for (var i=0; i<m; i++){\n'
  229. ' y[i] = 0;\n'
  230. '}\n'
  231. 'for (var i=0; i<m; i++){\n'
  232. ' for (var j=0; j<n; j++){\n'
  233. ' for (var k=0; k<o; k++){\n'
  234. ' for (var l=0; l<p; l++){\n'
  235. ' y[i] = (a[%s] + b[%s])*c[%s] + y[i];\n' % (i*n*o*p + j*o*p + k*p + l, i*n*o*p + j*o*p + k*p + l, j*o*p + k*p + l) +\
  236. ' }\n'
  237. ' }\n'
  238. ' }\n'
  239. '}'
  240. )
  241. c = jscode((a[i, j, k, l] + b[i, j, k, l])*c[j, k, l], assign_to=y[i])
  242. assert c == s
  243. def test_jscode_loops_multiple_terms():
  244. n, m, o, p = symbols('n m o p', integer=True)
  245. a = IndexedBase('a')
  246. b = IndexedBase('b')
  247. c = IndexedBase('c')
  248. y = IndexedBase('y')
  249. i = Idx('i', m)
  250. j = Idx('j', n)
  251. k = Idx('k', o)
  252. s0 = (
  253. 'for (var i=0; i<m; i++){\n'
  254. ' y[i] = 0;\n'
  255. '}\n'
  256. )
  257. s1 = (
  258. 'for (var i=0; i<m; i++){\n'
  259. ' for (var j=0; j<n; j++){\n'
  260. ' for (var k=0; k<o; k++){\n'
  261. ' y[i] = b[j]*b[k]*c[%s] + y[i];\n' % (i*n*o + j*o + k) +\
  262. ' }\n'
  263. ' }\n'
  264. '}\n'
  265. )
  266. s2 = (
  267. 'for (var i=0; i<m; i++){\n'
  268. ' for (var k=0; k<o; k++){\n'
  269. ' y[i] = a[%s]*b[k] + y[i];\n' % (i*o + k) +\
  270. ' }\n'
  271. '}\n'
  272. )
  273. s3 = (
  274. 'for (var i=0; i<m; i++){\n'
  275. ' for (var j=0; j<n; j++){\n'
  276. ' y[i] = a[%s]*b[j] + y[i];\n' % (i*n + j) +\
  277. ' }\n'
  278. '}\n'
  279. )
  280. c = jscode(
  281. b[j]*a[i, j] + b[k]*a[i, k] + b[j]*b[k]*c[i, j, k], assign_to=y[i])
  282. assert (c == s0 + s1 + s2 + s3[:-1] or
  283. c == s0 + s1 + s3 + s2[:-1] or
  284. c == s0 + s2 + s1 + s3[:-1] or
  285. c == s0 + s2 + s3 + s1[:-1] or
  286. c == s0 + s3 + s1 + s2[:-1] or
  287. c == s0 + s3 + s2 + s1[:-1])
  288. def test_Matrix_printing():
  289. # Test returning a Matrix
  290. mat = Matrix([x*y, Piecewise((2 + x, y>0), (y, True)), sin(z)])
  291. A = MatrixSymbol('A', 3, 1)
  292. assert jscode(mat, A) == (
  293. "A[0] = x*y;\n"
  294. "if (y > 0) {\n"
  295. " A[1] = x + 2;\n"
  296. "}\n"
  297. "else {\n"
  298. " A[1] = y;\n"
  299. "}\n"
  300. "A[2] = Math.sin(z);")
  301. # Test using MatrixElements in expressions
  302. expr = Piecewise((2*A[2, 0], x > 0), (A[2, 0], True)) + sin(A[1, 0]) + A[0, 0]
  303. assert jscode(expr) == (
  304. "((x > 0) ? (\n"
  305. " 2*A[2]\n"
  306. ")\n"
  307. ": (\n"
  308. " A[2]\n"
  309. ")) + Math.sin(A[1]) + A[0]")
  310. # Test using MatrixElements in a Matrix
  311. q = MatrixSymbol('q', 5, 1)
  312. M = MatrixSymbol('M', 3, 3)
  313. m = Matrix([[sin(q[1,0]), 0, cos(q[2,0])],
  314. [q[1,0] + q[2,0], q[3, 0], 5],
  315. [2*q[4, 0]/q[1,0], sqrt(q[0,0]) + 4, 0]])
  316. assert jscode(m, M) == (
  317. "M[0] = Math.sin(q[1]);\n"
  318. "M[1] = 0;\n"
  319. "M[2] = Math.cos(q[2]);\n"
  320. "M[3] = q[1] + q[2];\n"
  321. "M[4] = q[3];\n"
  322. "M[5] = 5;\n"
  323. "M[6] = 2*q[4]/q[1];\n"
  324. "M[7] = Math.sqrt(q[0]) + 4;\n"
  325. "M[8] = 0;")
  326. def test_MatrixElement_printing():
  327. # test cases for issue #11821
  328. A = MatrixSymbol("A", 1, 3)
  329. B = MatrixSymbol("B", 1, 3)
  330. C = MatrixSymbol("C", 1, 3)
  331. assert(jscode(A[0, 0]) == "A[0]")
  332. assert(jscode(3 * A[0, 0]) == "3*A[0]")
  333. F = C[0, 0].subs(C, A - B)
  334. assert(jscode(F) == "(A - B)[0]")