test_c.py 30 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875
from itertools import permutations

from sympy.core import (
    S, pi, oo, Symbol, symbols, Rational, Integer, Float, Function, Mod, GoldenRatio, EulerGamma, Catalan,
    Lambda, Dummy, nan, Mul, Pow, UnevaluatedExpr
)
from sympy.core.relational import (Eq, Ge, Gt, Le, Lt, Ne)
from sympy.functions import (
    Abs, acos, acosh, asin, asinh, atan, atanh, atan2, ceiling, cos, cosh, erf,
    erfc, exp, floor, gamma, log, loggamma, Max, Min, Piecewise, sign, sin, sinh,
    sqrt, tan, tanh, fibonacci, lucas
)
from sympy.sets import Range
from sympy.logic import ITE, Implies, Equivalent
from sympy.codegen import For, aug_assign, Assignment
from sympy.testing.pytest import raises, XFAIL
from sympy.printing.c import C89CodePrinter, C99CodePrinter, get_math_macros
from sympy.codegen.ast import (
    AddAugmentedAssignment, Element, Type, FloatType, Declaration, Pointer, Variable, value_const, pointer_const,
    While, Scope, Print, FunctionPrototype, FunctionDefinition, FunctionCall, Return,
    real, float32, float64, float80, float128, intc, Comment, CodeBlock
)
from sympy.codegen.cfunctions import expm1, log1p, exp2, log2, fma, log10, Cbrt, hypot, Sqrt
from sympy.codegen.cnodes import restrict
from sympy.utilities.lambdify import implemented_function
from sympy.tensor import IndexedBase, Idx
from sympy.matrices import Matrix, MatrixSymbol, SparseMatrix
from sympy.printing.codeprinter import ccode
  27. x, y, z = symbols('x,y,z')
  28. def test_printmethod():
  29. class fabs(Abs):
  30. def _ccode(self, printer):
  31. return "fabs(%s)" % printer._print(self.args[0])
  32. assert ccode(fabs(x)) == "fabs(x)"
  33. def test_ccode_sqrt():
  34. assert ccode(sqrt(x)) == "sqrt(x)"
  35. assert ccode(x**0.5) == "sqrt(x)"
  36. assert ccode(sqrt(x)) == "sqrt(x)"
  37. def test_ccode_Pow():
  38. assert ccode(x**3) == "pow(x, 3)"
  39. assert ccode(x**(y**3)) == "pow(x, pow(y, 3))"
  40. g = implemented_function('g', Lambda(x, 2*x))
  41. assert ccode(1/(g(x)*3.5)**(x - y**x)/(x**2 + y)) == \
  42. "pow(3.5*2*x, -x + pow(y, x))/(pow(x, 2) + y)"
  43. assert ccode(x**-1.0) == '1.0/x'
  44. assert ccode(x**Rational(2, 3)) == 'pow(x, 2.0/3.0)'
  45. assert ccode(x**Rational(2, 3), type_aliases={real: float80}) == 'powl(x, 2.0L/3.0L)'
  46. _cond_cfunc = [(lambda base, exp: exp.is_integer, "dpowi"),
  47. (lambda base, exp: not exp.is_integer, "pow")]
  48. assert ccode(x**3, user_functions={'Pow': _cond_cfunc}) == 'dpowi(x, 3)'
  49. assert ccode(x**0.5, user_functions={'Pow': _cond_cfunc}) == 'pow(x, 0.5)'
  50. assert ccode(x**Rational(16, 5), user_functions={'Pow': _cond_cfunc}) == 'pow(x, 16.0/5.0)'
  51. _cond_cfunc2 = [(lambda base, exp: base == 2, lambda base, exp: 'exp2(%s)' % exp),
  52. (lambda base, exp: base != 2, 'pow')]
  53. # Related to gh-11353
  54. assert ccode(2**x, user_functions={'Pow': _cond_cfunc2}) == 'exp2(x)'
  55. assert ccode(x**2, user_functions={'Pow': _cond_cfunc2}) == 'pow(x, 2)'
  56. # For issue 14160
  57. assert ccode(Mul(-2, x, Pow(Mul(y,y,evaluate=False), -1, evaluate=False),
  58. evaluate=False)) == '-2*x/(y*y)'
  59. def test_ccode_Max():
  60. # Test for gh-11926
  61. assert ccode(Max(x,x*x),user_functions={"Max":"my_max", "Pow":"my_pow"}) == 'my_max(x, my_pow(x, 2))'
  62. def test_ccode_Min_performance():
  63. #Shouldn't take more than a few seconds
  64. big_min = Min(*symbols('a[0:50]'))
  65. for curr_standard in ('c89', 'c99', 'c11'):
  66. output = ccode(big_min, standard=curr_standard)
  67. assert output.count('(') == output.count(')')
  68. def test_ccode_constants_mathh():
  69. assert ccode(exp(1)) == "M_E"
  70. assert ccode(pi) == "M_PI"
  71. assert ccode(oo, standard='c89') == "HUGE_VAL"
  72. assert ccode(-oo, standard='c89') == "-HUGE_VAL"
  73. assert ccode(oo) == "INFINITY"
  74. assert ccode(-oo, standard='c99') == "-INFINITY"
  75. assert ccode(pi, type_aliases={real: float80}) == "M_PIl"
  76. def test_ccode_constants_other():
  77. assert ccode(2*GoldenRatio) == "const double GoldenRatio = %s;\n2*GoldenRatio" % GoldenRatio.evalf(17)
  78. assert ccode(
  79. 2*Catalan) == "const double Catalan = %s;\n2*Catalan" % Catalan.evalf(17)
  80. assert ccode(2*EulerGamma) == "const double EulerGamma = %s;\n2*EulerGamma" % EulerGamma.evalf(17)
  81. def test_ccode_Rational():
  82. assert ccode(Rational(3, 7)) == "3.0/7.0"
  83. assert ccode(Rational(3, 7), type_aliases={real: float80}) == "3.0L/7.0L"
  84. assert ccode(Rational(18, 9)) == "2"
  85. assert ccode(Rational(3, -7)) == "-3.0/7.0"
  86. assert ccode(Rational(3, -7), type_aliases={real: float80}) == "-3.0L/7.0L"
  87. assert ccode(Rational(-3, -7)) == "3.0/7.0"
  88. assert ccode(Rational(-3, -7), type_aliases={real: float80}) == "3.0L/7.0L"
  89. assert ccode(x + Rational(3, 7)) == "x + 3.0/7.0"
  90. assert ccode(x + Rational(3, 7), type_aliases={real: float80}) == "x + 3.0L/7.0L"
  91. assert ccode(Rational(3, 7)*x) == "(3.0/7.0)*x"
  92. assert ccode(Rational(3, 7)*x, type_aliases={real: float80}) == "(3.0L/7.0L)*x"
  93. def test_ccode_Integer():
  94. assert ccode(Integer(67)) == "67"
  95. assert ccode(Integer(-1)) == "-1"
  96. def test_ccode_functions():
  97. assert ccode(sin(x) ** cos(x)) == "pow(sin(x), cos(x))"
  98. def test_ccode_inline_function():
  99. x = symbols('x')
  100. g = implemented_function('g', Lambda(x, 2*x))
  101. assert ccode(g(x)) == "2*x"
  102. g = implemented_function('g', Lambda(x, 2*x/Catalan))
  103. assert ccode(
  104. g(x)) == "const double Catalan = %s;\n2*x/Catalan" % Catalan.evalf(17)
  105. A = IndexedBase('A')
  106. i = Idx('i', symbols('n', integer=True))
  107. g = implemented_function('g', Lambda(x, x*(1 + x)*(2 + x)))
  108. assert ccode(g(A[i]), assign_to=A[i]) == (
  109. "for (int i=0; i<n; i++){\n"
  110. " A[i] = (A[i] + 1)*(A[i] + 2)*A[i];\n"
  111. "}"
  112. )
  113. def test_ccode_exceptions():
  114. assert ccode(gamma(x), standard='C99') == "tgamma(x)"
  115. gamma_c89 = ccode(gamma(x), standard='C89')
  116. assert 'not supported in c' in gamma_c89.lower()
  117. gamma_c89 = ccode(gamma(x), standard='C89', allow_unknown_functions=False)
  118. assert 'not supported in c' in gamma_c89.lower()
  119. gamma_c89 = ccode(gamma(x), standard='C89', allow_unknown_functions=True)
  120. assert 'not supported in c' not in gamma_c89.lower()
  121. def test_ccode_functions2():
  122. assert ccode(ceiling(x)) == "ceil(x)"
  123. assert ccode(Abs(x)) == "fabs(x)"
  124. assert ccode(gamma(x)) == "tgamma(x)"
  125. r, s = symbols('r,s', real=True)
  126. assert ccode(Mod(ceiling(r), ceiling(s))) == '((ceil(r) % ceil(s)) + '\
  127. 'ceil(s)) % ceil(s)'
  128. assert ccode(Mod(r, s)) == "fmod(r, s)"
  129. p1, p2 = symbols('p1 p2', integer=True, positive=True)
  130. assert ccode(Mod(p1, p2)) == 'p1 % p2'
  131. assert ccode(Mod(p1, p2 + 3)) == 'p1 % (p2 + 3)'
  132. assert ccode(Mod(-3, -7, evaluate=False)) == '(-3) % (-7)'
  133. assert ccode(-Mod(3, 7, evaluate=False)) == '-(3 % 7)'
  134. assert ccode(r*Mod(p1, p2)) == 'r*(p1 % p2)'
  135. assert ccode(Mod(p1, p2)**s) == 'pow(p1 % p2, s)'
  136. n = symbols('n', integer=True, negative=True)
  137. assert ccode(Mod(-n, p2)) == '(-n) % p2'
  138. assert ccode(fibonacci(n)) == '(1.0/5.0)*pow(2, -n)*sqrt(5)*(-pow(1 - sqrt(5), n) + pow(1 + sqrt(5), n))'
  139. assert ccode(lucas(n)) == 'pow(2, -n)*(pow(1 - sqrt(5), n) + pow(1 + sqrt(5), n))'
  140. def test_ccode_user_functions():
  141. x = symbols('x', integer=False)
  142. n = symbols('n', integer=True)
  143. custom_functions = {
  144. "ceiling": "ceil",
  145. "Abs": [(lambda x: not x.is_integer, "fabs"), (lambda x: x.is_integer, "abs")],
  146. }
  147. assert ccode(ceiling(x), user_functions=custom_functions) == "ceil(x)"
  148. assert ccode(Abs(x), user_functions=custom_functions) == "fabs(x)"
  149. assert ccode(Abs(n), user_functions=custom_functions) == "abs(n)"
  150. expr = Symbol('a')
  151. muladd = Function('muladd')
  152. for i in range(0, 100):
  153. # the large number of terms acts as a regression test for gh-23839
  154. expr = muladd(Rational(1, 2), Symbol(f'a{i}'), expr)
  155. out = ccode(expr, user_functions={'muladd':'muladd'})
  156. assert 'a99' in out
  157. assert out.count('muladd') == 100
  158. def test_ccode_boolean():
  159. assert ccode(True) == "true"
  160. assert ccode(S.true) == "true"
  161. assert ccode(False) == "false"
  162. assert ccode(S.false) == "false"
  163. assert ccode(x & y) == "x && y"
  164. assert ccode(x | y) == "x || y"
  165. assert ccode(~x) == "!x"
  166. assert ccode(x & y & z) == "x && y && z"
  167. assert ccode(x | y | z) == "x || y || z"
  168. assert ccode((x & y) | z) == "z || x && y"
  169. assert ccode((x | y) & z) == "z && (x || y)"
  170. # Automatic rewrites
  171. assert ccode(x ^ y) == '(x || y) && (!x || !y)'
  172. assert ccode((x ^ y) ^ z) == '(x || y || z) && (x || !y || !z) && (y || !x || !z) && (z || !x || !y)'
  173. assert ccode(Implies(x, y)) == 'y || !x'
  174. assert ccode(Equivalent(x, z ^ y, Implies(z, x))) == '(x || (y || !z) && (z || !y)) && (z && !x || (y || z) && (!y || !z))'
  175. def test_ccode_Relational():
  176. assert ccode(Eq(x, y)) == "x == y"
  177. assert ccode(Ne(x, y)) == "x != y"
  178. assert ccode(Le(x, y)) == "x <= y"
  179. assert ccode(Lt(x, y)) == "x < y"
  180. assert ccode(Gt(x, y)) == "x > y"
  181. assert ccode(Ge(x, y)) == "x >= y"
  182. def test_ccode_Piecewise():
  183. expr = Piecewise((x, x < 1), (x**2, True))
  184. assert ccode(expr) == (
  185. "((x < 1) ? (\n"
  186. " x\n"
  187. ")\n"
  188. ": (\n"
  189. " pow(x, 2)\n"
  190. "))")
  191. assert ccode(expr, assign_to="c") == (
  192. "if (x < 1) {\n"
  193. " c = x;\n"
  194. "}\n"
  195. "else {\n"
  196. " c = pow(x, 2);\n"
  197. "}")
  198. expr = Piecewise((x, x < 1), (x + 1, x < 2), (x**2, True))
  199. assert ccode(expr) == (
  200. "((x < 1) ? (\n"
  201. " x\n"
  202. ")\n"
  203. ": ((x < 2) ? (\n"
  204. " x + 1\n"
  205. ")\n"
  206. ": (\n"
  207. " pow(x, 2)\n"
  208. ")))")
  209. assert ccode(expr, assign_to='c') == (
  210. "if (x < 1) {\n"
  211. " c = x;\n"
  212. "}\n"
  213. "else if (x < 2) {\n"
  214. " c = x + 1;\n"
  215. "}\n"
  216. "else {\n"
  217. " c = pow(x, 2);\n"
  218. "}")
  219. # Check that Piecewise without a True (default) condition error
  220. expr = Piecewise((x, x < 1), (x**2, x > 1), (sin(x), x > 0))
  221. raises(ValueError, lambda: ccode(expr))
  222. def test_ccode_sinc():
  223. from sympy.functions.elementary.trigonometric import sinc
  224. expr = sinc(x)
  225. assert ccode(expr) == (
  226. "((x != 0) ? (\n"
  227. " sin(x)/x\n"
  228. ")\n"
  229. ": (\n"
  230. " 1\n"
  231. "))")
  232. def test_ccode_Piecewise_deep():
  233. p = ccode(2*Piecewise((x, x < 1), (x + 1, x < 2), (x**2, True)))
  234. assert p == (
  235. "2*((x < 1) ? (\n"
  236. " x\n"
  237. ")\n"
  238. ": ((x < 2) ? (\n"
  239. " x + 1\n"
  240. ")\n"
  241. ": (\n"
  242. " pow(x, 2)\n"
  243. ")))")
  244. expr = x*y*z + x**2 + y**2 + Piecewise((0, x < 0.5), (1, True)) + cos(z) - 1
  245. assert ccode(expr) == (
  246. "pow(x, 2) + x*y*z + pow(y, 2) + ((x < 0.5) ? (\n"
  247. " 0\n"
  248. ")\n"
  249. ": (\n"
  250. " 1\n"
  251. ")) + cos(z) - 1")
  252. assert ccode(expr, assign_to='c') == (
  253. "c = pow(x, 2) + x*y*z + pow(y, 2) + ((x < 0.5) ? (\n"
  254. " 0\n"
  255. ")\n"
  256. ": (\n"
  257. " 1\n"
  258. ")) + cos(z) - 1;")
  259. def test_ccode_ITE():
  260. expr = ITE(x < 1, y, z)
  261. assert ccode(expr) == (
  262. "((x < 1) ? (\n"
  263. " y\n"
  264. ")\n"
  265. ": (\n"
  266. " z\n"
  267. "))")
  268. def test_ccode_settings():
  269. raises(TypeError, lambda: ccode(sin(x), method="garbage"))
  270. def test_ccode_Indexed():
  271. s, n, m, o = symbols('s n m o', integer=True)
  272. i, j, k = Idx('i', n), Idx('j', m), Idx('k', o)
  273. x = IndexedBase('x')[j]
  274. A = IndexedBase('A')[i, j]
  275. B = IndexedBase('B')[i, j, k]
  276. p = C99CodePrinter()
  277. assert p._print_Indexed(x) == 'x[j]'
  278. assert p._print_Indexed(A) == 'A[%s]' % (m*i+j)
  279. assert p._print_Indexed(B) == 'B[%s]' % (i*o*m+j*o+k)
  280. A = IndexedBase('A', shape=(5,3))[i, j]
  281. assert p._print_Indexed(A) == 'A[%s]' % (3*i + j)
  282. A = IndexedBase('A', shape=(5,3), strides='F')[i, j]
  283. assert ccode(A) == 'A[%s]' % (i + 5*j)
  284. A = IndexedBase('A', shape=(29,29), strides=(1, s), offset=o)[i, j]
  285. assert ccode(A) == 'A[o + s*j + i]'
  286. Abase = IndexedBase('A', strides=(s, m, n), offset=o)
  287. assert ccode(Abase[i, j, k]) == 'A[m*j + n*k + o + s*i]'
  288. assert ccode(Abase[2, 3, k]) == 'A[3*m + n*k + o + 2*s]'
  289. def test_Element():
  290. assert ccode(Element('x', 'ij')) == 'x[i][j]'
  291. assert ccode(Element('x', 'ij', strides='kl', offset='o')) == 'x[i*k + j*l + o]'
  292. assert ccode(Element('x', (3,))) == 'x[3]'
  293. assert ccode(Element('x', (3,4,5))) == 'x[3][4][5]'
  294. def test_ccode_Indexed_without_looking_for_contraction():
  295. len_y = 5
  296. y = IndexedBase('y', shape=(len_y,))
  297. x = IndexedBase('x', shape=(len_y,))
  298. Dy = IndexedBase('Dy', shape=(len_y-1,))
  299. i = Idx('i', len_y-1)
  300. e = Eq(Dy[i], (y[i+1]-y[i])/(x[i+1]-x[i]))
  301. code0 = ccode(e.rhs, assign_to=e.lhs, contract=False)
  302. assert code0 == 'Dy[i] = (y[%s] - y[i])/(x[%s] - x[i]);' % (i + 1, i + 1)
  303. def test_ccode_loops_matrix_vector():
  304. n, m = symbols('n m', integer=True)
  305. A = IndexedBase('A')
  306. x = IndexedBase('x')
  307. y = IndexedBase('y')
  308. i = Idx('i', m)
  309. j = Idx('j', n)
  310. s = (
  311. 'for (int i=0; i<m; i++){\n'
  312. ' y[i] = 0;\n'
  313. '}\n'
  314. 'for (int i=0; i<m; i++){\n'
  315. ' for (int j=0; j<n; j++){\n'
  316. ' y[i] = A[%s]*x[j] + y[i];\n' % (i*n + j) +\
  317. ' }\n'
  318. '}'
  319. )
  320. assert ccode(A[i, j]*x[j], assign_to=y[i]) == s
  321. def test_dummy_loops():
  322. i, m = symbols('i m', integer=True, cls=Dummy)
  323. x = IndexedBase('x')
  324. y = IndexedBase('y')
  325. i = Idx(i, m)
  326. expected = (
  327. 'for (int i_%(icount)i=0; i_%(icount)i<m_%(mcount)i; i_%(icount)i++){\n'
  328. ' y[i_%(icount)i] = x[i_%(icount)i];\n'
  329. '}'
  330. ) % {'icount': i.label.dummy_index, 'mcount': m.dummy_index}
  331. assert ccode(x[i], assign_to=y[i]) == expected
  332. def test_ccode_loops_add():
  333. n, m = symbols('n m', integer=True)
  334. A = IndexedBase('A')
  335. x = IndexedBase('x')
  336. y = IndexedBase('y')
  337. z = IndexedBase('z')
  338. i = Idx('i', m)
  339. j = Idx('j', n)
  340. s = (
  341. 'for (int i=0; i<m; i++){\n'
  342. ' y[i] = x[i] + z[i];\n'
  343. '}\n'
  344. 'for (int i=0; i<m; i++){\n'
  345. ' for (int j=0; j<n; j++){\n'
  346. ' y[i] = A[%s]*x[j] + y[i];\n' % (i*n + j) +\
  347. ' }\n'
  348. '}'
  349. )
  350. assert ccode(A[i, j]*x[j] + x[i] + z[i], assign_to=y[i]) == s
  351. def test_ccode_loops_multiple_contractions():
  352. n, m, o, p = symbols('n m o p', integer=True)
  353. a = IndexedBase('a')
  354. b = IndexedBase('b')
  355. y = IndexedBase('y')
  356. i = Idx('i', m)
  357. j = Idx('j', n)
  358. k = Idx('k', o)
  359. l = Idx('l', p)
  360. s = (
  361. 'for (int i=0; i<m; i++){\n'
  362. ' y[i] = 0;\n'
  363. '}\n'
  364. 'for (int i=0; i<m; i++){\n'
  365. ' for (int j=0; j<n; j++){\n'
  366. ' for (int k=0; k<o; k++){\n'
  367. ' for (int l=0; l<p; l++){\n'
  368. ' y[i] = a[%s]*b[%s] + y[i];\n' % (i*n*o*p + j*o*p + k*p + l, j*o*p + k*p + l) +\
  369. ' }\n'
  370. ' }\n'
  371. ' }\n'
  372. '}'
  373. )
  374. assert ccode(b[j, k, l]*a[i, j, k, l], assign_to=y[i]) == s
  375. def test_ccode_loops_addfactor():
  376. n, m, o, p = symbols('n m o p', integer=True)
  377. a = IndexedBase('a')
  378. b = IndexedBase('b')
  379. c = IndexedBase('c')
  380. y = IndexedBase('y')
  381. i = Idx('i', m)
  382. j = Idx('j', n)
  383. k = Idx('k', o)
  384. l = Idx('l', p)
  385. s = (
  386. 'for (int i=0; i<m; i++){\n'
  387. ' y[i] = 0;\n'
  388. '}\n'
  389. 'for (int i=0; i<m; i++){\n'
  390. ' for (int j=0; j<n; j++){\n'
  391. ' for (int k=0; k<o; k++){\n'
  392. ' for (int l=0; l<p; l++){\n'
  393. ' y[i] = (a[%s] + b[%s])*c[%s] + y[i];\n' % (i*n*o*p + j*o*p + k*p + l, i*n*o*p + j*o*p + k*p + l, j*o*p + k*p + l) +\
  394. ' }\n'
  395. ' }\n'
  396. ' }\n'
  397. '}'
  398. )
  399. assert ccode((a[i, j, k, l] + b[i, j, k, l])*c[j, k, l], assign_to=y[i]) == s
  400. def test_ccode_loops_multiple_terms():
  401. n, m, o, p = symbols('n m o p', integer=True)
  402. a = IndexedBase('a')
  403. b = IndexedBase('b')
  404. c = IndexedBase('c')
  405. y = IndexedBase('y')
  406. i = Idx('i', m)
  407. j = Idx('j', n)
  408. k = Idx('k', o)
  409. s0 = (
  410. 'for (int i=0; i<m; i++){\n'
  411. ' y[i] = 0;\n'
  412. '}\n'
  413. )
  414. s1 = (
  415. 'for (int i=0; i<m; i++){\n'
  416. ' for (int j=0; j<n; j++){\n'
  417. ' for (int k=0; k<o; k++){\n'
  418. ' y[i] = b[j]*b[k]*c[%s] + y[i];\n' % (i*n*o + j*o + k) +\
  419. ' }\n'
  420. ' }\n'
  421. '}\n'
  422. )
  423. s2 = (
  424. 'for (int i=0; i<m; i++){\n'
  425. ' for (int k=0; k<o; k++){\n'
  426. ' y[i] = a[%s]*b[k] + y[i];\n' % (i*o + k) +\
  427. ' }\n'
  428. '}\n'
  429. )
  430. s3 = (
  431. 'for (int i=0; i<m; i++){\n'
  432. ' for (int j=0; j<n; j++){\n'
  433. ' y[i] = a[%s]*b[j] + y[i];\n' % (i*n + j) +\
  434. ' }\n'
  435. '}\n'
  436. )
  437. c = ccode(b[j]*a[i, j] + b[k]*a[i, k] + b[j]*b[k]*c[i, j, k], assign_to=y[i])
  438. assert (c == s0 + s1 + s2 + s3[:-1] or
  439. c == s0 + s1 + s3 + s2[:-1] or
  440. c == s0 + s2 + s1 + s3[:-1] or
  441. c == s0 + s2 + s3 + s1[:-1] or
  442. c == s0 + s3 + s1 + s2[:-1] or
  443. c == s0 + s3 + s2 + s1[:-1])
  444. def test_dereference_printing():
  445. expr = x + y + sin(z) + z
  446. assert ccode(expr, dereference=[z]) == "x + y + (*z) + sin((*z))"
  447. def test_Matrix_printing():
  448. # Test returning a Matrix
  449. mat = Matrix([x*y, Piecewise((2 + x, y>0), (y, True)), sin(z)])
  450. A = MatrixSymbol('A', 3, 1)
  451. assert ccode(mat, A) == (
  452. "A[0] = x*y;\n"
  453. "if (y > 0) {\n"
  454. " A[1] = x + 2;\n"
  455. "}\n"
  456. "else {\n"
  457. " A[1] = y;\n"
  458. "}\n"
  459. "A[2] = sin(z);")
  460. # Test using MatrixElements in expressions
  461. expr = Piecewise((2*A[2, 0], x > 0), (A[2, 0], True)) + sin(A[1, 0]) + A[0, 0]
  462. assert ccode(expr) == (
  463. "((x > 0) ? (\n"
  464. " 2*A[2]\n"
  465. ")\n"
  466. ": (\n"
  467. " A[2]\n"
  468. ")) + sin(A[1]) + A[0]")
  469. # Test using MatrixElements in a Matrix
  470. q = MatrixSymbol('q', 5, 1)
  471. M = MatrixSymbol('M', 3, 3)
  472. m = Matrix([[sin(q[1,0]), 0, cos(q[2,0])],
  473. [q[1,0] + q[2,0], q[3, 0], 5],
  474. [2*q[4, 0]/q[1,0], sqrt(q[0,0]) + 4, 0]])
  475. assert ccode(m, M) == (
  476. "M[0] = sin(q[1]);\n"
  477. "M[1] = 0;\n"
  478. "M[2] = cos(q[2]);\n"
  479. "M[3] = q[1] + q[2];\n"
  480. "M[4] = q[3];\n"
  481. "M[5] = 5;\n"
  482. "M[6] = 2*q[4]/q[1];\n"
  483. "M[7] = sqrt(q[0]) + 4;\n"
  484. "M[8] = 0;")
  485. def test_sparse_matrix():
  486. # gh-15791
  487. assert 'Not supported in C' in ccode(SparseMatrix([[1, 2, 3]]))
  488. def test_ccode_reserved_words():
  489. x, y = symbols('x, if')
  490. with raises(ValueError):
  491. ccode(y**2, error_on_reserved=True, standard='C99')
  492. assert ccode(y**2) == 'pow(if_, 2)'
  493. assert ccode(x * y**2, dereference=[y]) == 'pow((*if_), 2)*x'
  494. assert ccode(y**2, reserved_word_suffix='_unreserved') == 'pow(if_unreserved, 2)'
  495. def test_ccode_sign():
  496. expr1, ref1 = sign(x) * y, 'y*(((x) > 0) - ((x) < 0))'
  497. expr2, ref2 = sign(cos(x)), '(((cos(x)) > 0) - ((cos(x)) < 0))'
  498. expr3, ref3 = sign(2 * x + x**2) * x + x**2, 'pow(x, 2) + x*(((pow(x, 2) + 2*x) > 0) - ((pow(x, 2) + 2*x) < 0))'
  499. assert ccode(expr1) == ref1
  500. assert ccode(expr1, 'z') == 'z = %s;' % ref1
  501. assert ccode(expr2) == ref2
  502. assert ccode(expr3) == ref3
  503. def test_ccode_Assignment():
  504. assert ccode(Assignment(x, y + z)) == 'x = y + z;'
  505. assert ccode(aug_assign(x, '+', y + z)) == 'x += y + z;'
  506. def test_ccode_For():
  507. f = For(x, Range(0, 10, 2), [aug_assign(y, '*', x)])
  508. assert ccode(f) == ("for (x = 0; x < 10; x += 2) {\n"
  509. " y *= x;\n"
  510. "}")
  511. def test_ccode_Max_Min():
  512. assert ccode(Max(x, 0), standard='C89') == '((0 > x) ? 0 : x)'
  513. assert ccode(Max(x, 0), standard='C99') == 'fmax(0, x)'
  514. assert ccode(Min(x, 0, sqrt(x)), standard='c89') == (
  515. '((0 < ((x < sqrt(x)) ? x : sqrt(x))) ? 0 : ((x < sqrt(x)) ? x : sqrt(x)))'
  516. )
  517. def test_ccode_standard():
  518. assert ccode(expm1(x), standard='c99') == 'expm1(x)'
  519. assert ccode(nan, standard='c99') == 'NAN'
  520. assert ccode(float('nan'), standard='c99') == 'NAN'
  521. def test_C89CodePrinter():
  522. c89printer = C89CodePrinter()
  523. assert c89printer.language == 'C'
  524. assert c89printer.standard == 'C89'
  525. assert 'void' in c89printer.reserved_words
  526. assert 'template' not in c89printer.reserved_words
  527. def test_C99CodePrinter():
  528. assert C99CodePrinter().doprint(expm1(x)) == 'expm1(x)'
  529. assert C99CodePrinter().doprint(log1p(x)) == 'log1p(x)'
  530. assert C99CodePrinter().doprint(exp2(x)) == 'exp2(x)'
  531. assert C99CodePrinter().doprint(log2(x)) == 'log2(x)'
  532. assert C99CodePrinter().doprint(fma(x, y, -z)) == 'fma(x, y, -z)'
  533. assert C99CodePrinter().doprint(log10(x)) == 'log10(x)'
  534. assert C99CodePrinter().doprint(Cbrt(x)) == 'cbrt(x)' # note Cbrt due to cbrt already taken.
  535. assert C99CodePrinter().doprint(hypot(x, y)) == 'hypot(x, y)'
  536. assert C99CodePrinter().doprint(loggamma(x)) == 'lgamma(x)'
  537. assert C99CodePrinter().doprint(Max(x, 3, x**2)) == 'fmax(3, fmax(x, pow(x, 2)))'
  538. assert C99CodePrinter().doprint(Min(x, 3)) == 'fmin(3, x)'
  539. c99printer = C99CodePrinter()
  540. assert c99printer.language == 'C'
  541. assert c99printer.standard == 'C99'
  542. assert 'restrict' in c99printer.reserved_words
  543. assert 'using' not in c99printer.reserved_words
  544. @XFAIL
  545. def test_C99CodePrinter__precision_f80():
  546. f80_printer = C99CodePrinter({"type_aliases": {real: float80}})
  547. assert f80_printer.doprint(sin(x+Float('2.1'))) == 'sinl(x + 2.1L)'
  548. def test_C99CodePrinter__precision():
  549. n = symbols('n', integer=True)
  550. p = symbols('p', integer=True, positive=True)
  551. f32_printer = C99CodePrinter({"type_aliases": {real: float32}})
  552. f64_printer = C99CodePrinter({"type_aliases": {real: float64}})
  553. f80_printer = C99CodePrinter({"type_aliases": {real: float80}})
  554. assert f32_printer.doprint(sin(x+2.1)) == 'sinf(x + 2.1F)'
  555. assert f64_printer.doprint(sin(x+2.1)) == 'sin(x + 2.1000000000000001)'
  556. assert f80_printer.doprint(sin(x+Float('2.0'))) == 'sinl(x + 2.0L)'
  557. for printer, suffix in zip([f32_printer, f64_printer, f80_printer], ['f', '', 'l']):
  558. def check(expr, ref):
  559. assert printer.doprint(expr) == ref.format(s=suffix, S=suffix.upper())
  560. check(Abs(n), 'abs(n)')
  561. check(Abs(x + 2.0), 'fabs{s}(x + 2.0{S})')
  562. check(sin(x + 4.0)**cos(x - 2.0), 'pow{s}(sin{s}(x + 4.0{S}), cos{s}(x - 2.0{S}))')
  563. check(exp(x*8.0), 'exp{s}(8.0{S}*x)')
  564. check(exp2(x), 'exp2{s}(x)')
  565. check(expm1(x*4.0), 'expm1{s}(4.0{S}*x)')
  566. check(Mod(p, 2), 'p % 2')
  567. check(Mod(2*p + 3, 3*p + 5, evaluate=False), '(2*p + 3) % (3*p + 5)')
  568. check(Mod(x + 2.0, 3.0), 'fmod{s}(1.0{S}*x + 2.0{S}, 3.0{S})')
  569. check(Mod(x, 2.0*x + 3.0), 'fmod{s}(1.0{S}*x, 2.0{S}*x + 3.0{S})')
  570. check(log(x/2), 'log{s}((1.0{S}/2.0{S})*x)')
  571. check(log10(3*x/2), 'log10{s}((3.0{S}/2.0{S})*x)')
  572. check(log2(x*8.0), 'log2{s}(8.0{S}*x)')
  573. check(log1p(x), 'log1p{s}(x)')
  574. check(2**x, 'pow{s}(2, x)')
  575. check(2.0**x, 'pow{s}(2.0{S}, x)')
  576. check(x**3, 'pow{s}(x, 3)')
  577. check(x**4.0, 'pow{s}(x, 4.0{S})')
  578. check(sqrt(3+x), 'sqrt{s}(x + 3)')
  579. check(Cbrt(x-2.0), 'cbrt{s}(x - 2.0{S})')
  580. check(hypot(x, y), 'hypot{s}(x, y)')
  581. check(sin(3.*x + 2.), 'sin{s}(3.0{S}*x + 2.0{S})')
  582. check(cos(3.*x - 1.), 'cos{s}(3.0{S}*x - 1.0{S})')
  583. check(tan(4.*y + 2.), 'tan{s}(4.0{S}*y + 2.0{S})')
  584. check(asin(3.*x + 2.), 'asin{s}(3.0{S}*x + 2.0{S})')
  585. check(acos(3.*x + 2.), 'acos{s}(3.0{S}*x + 2.0{S})')
  586. check(atan(3.*x + 2.), 'atan{s}(3.0{S}*x + 2.0{S})')
  587. check(atan2(3.*x, 2.*y), 'atan2{s}(3.0{S}*x, 2.0{S}*y)')
  588. check(sinh(3.*x + 2.), 'sinh{s}(3.0{S}*x + 2.0{S})')
  589. check(cosh(3.*x - 1.), 'cosh{s}(3.0{S}*x - 1.0{S})')
  590. check(tanh(4.0*y + 2.), 'tanh{s}(4.0{S}*y + 2.0{S})')
  591. check(asinh(3.*x + 2.), 'asinh{s}(3.0{S}*x + 2.0{S})')
  592. check(acosh(3.*x + 2.), 'acosh{s}(3.0{S}*x + 2.0{S})')
  593. check(atanh(3.*x + 2.), 'atanh{s}(3.0{S}*x + 2.0{S})')
  594. check(erf(42.*x), 'erf{s}(42.0{S}*x)')
  595. check(erfc(42.*x), 'erfc{s}(42.0{S}*x)')
  596. check(gamma(x), 'tgamma{s}(x)')
  597. check(loggamma(x), 'lgamma{s}(x)')
  598. check(ceiling(x + 2.), "ceil{s}(x + 2.0{S})")
  599. check(floor(x + 2.), "floor{s}(x + 2.0{S})")
  600. check(fma(x, y, -z), 'fma{s}(x, y, -z)')
  601. check(Max(x, 8.0, x**4.0), 'fmax{s}(8.0{S}, fmax{s}(x, pow{s}(x, 4.0{S})))')
  602. check(Min(x, 2.0), 'fmin{s}(2.0{S}, x)')
  603. def test_get_math_macros():
  604. macros = get_math_macros()
  605. assert macros[exp(1)] == 'M_E'
  606. assert macros[1/Sqrt(2)] == 'M_SQRT1_2'
  607. def test_ccode_Declaration():
  608. i = symbols('i', integer=True)
  609. var1 = Variable(i, type=Type.from_expr(i))
  610. dcl1 = Declaration(var1)
  611. assert ccode(dcl1) == 'int i'
  612. var2 = Variable(x, type=float32, attrs={value_const})
  613. dcl2a = Declaration(var2)
  614. assert ccode(dcl2a) == 'const float x'
  615. dcl2b = var2.as_Declaration(value=pi)
  616. assert ccode(dcl2b) == 'const float x = M_PI'
  617. var3 = Variable(y, type=Type('bool'))
  618. dcl3 = Declaration(var3)
  619. printer = C89CodePrinter()
  620. assert 'stdbool.h' not in printer.headers
  621. assert printer.doprint(dcl3) == 'bool y'
  622. assert 'stdbool.h' in printer.headers
  623. u = symbols('u', real=True)
  624. ptr4 = Pointer.deduced(u, attrs={pointer_const, restrict})
  625. dcl4 = Declaration(ptr4)
  626. assert ccode(dcl4) == 'double * const restrict u'
  627. var5 = Variable(x, Type('__float128'), attrs={value_const})
  628. dcl5a = Declaration(var5)
  629. assert ccode(dcl5a) == 'const __float128 x'
  630. var5b = Variable(var5.symbol, var5.type, pi, attrs=var5.attrs)
  631. dcl5b = Declaration(var5b)
  632. assert ccode(dcl5b) == 'const __float128 x = M_PI'
  633. def test_C99CodePrinter_custom_type():
  634. # We will look at __float128 (new in glibc 2.26)
  635. f128 = FloatType('_Float128', float128.nbits, float128.nmant, float128.nexp)
  636. p128 = C99CodePrinter({
  637. "type_aliases": {real: f128},
  638. "type_literal_suffixes": {f128: 'Q'},
  639. "type_func_suffixes": {f128: 'f128'},
  640. "type_math_macro_suffixes": {
  641. real: 'f128',
  642. f128: 'f128'
  643. },
  644. "type_macros": {
  645. f128: ('__STDC_WANT_IEC_60559_TYPES_EXT__',)
  646. }
  647. })
  648. assert p128.doprint(x) == 'x'
  649. assert not p128.headers
  650. assert not p128.libraries
  651. assert not p128.macros
  652. assert p128.doprint(2.0) == '2.0Q'
  653. assert not p128.headers
  654. assert not p128.libraries
  655. assert p128.macros == {'__STDC_WANT_IEC_60559_TYPES_EXT__'}
  656. assert p128.doprint(Rational(1, 2)) == '1.0Q/2.0Q'
  657. assert p128.doprint(sin(x)) == 'sinf128(x)'
  658. assert p128.doprint(cos(2., evaluate=False)) == 'cosf128(2.0Q)'
  659. assert p128.doprint(x**-1.0) == '1.0Q/x'
  660. var5 = Variable(x, f128, attrs={value_const})
  661. dcl5a = Declaration(var5)
  662. assert ccode(dcl5a) == 'const _Float128 x'
  663. var5b = Variable(x, f128, pi, attrs={value_const})
  664. dcl5b = Declaration(var5b)
  665. assert p128.doprint(dcl5b) == 'const _Float128 x = M_PIf128'
  666. var5b = Variable(x, f128, value=Catalan.evalf(38), attrs={value_const})
  667. dcl5c = Declaration(var5b)
  668. assert p128.doprint(dcl5c) == 'const _Float128 x = %sQ' % Catalan.evalf(f128.decimal_dig)
  669. def test_MatrixElement_printing():
  670. # test cases for issue #11821
  671. A = MatrixSymbol("A", 1, 3)
  672. B = MatrixSymbol("B", 1, 3)
  673. C = MatrixSymbol("C", 1, 3)
  674. assert(ccode(A[0, 0]) == "A[0]")
  675. assert(ccode(3 * A[0, 0]) == "3*A[0]")
  676. F = C[0, 0].subs(C, A - B)
  677. assert(ccode(F) == "(A - B)[0]")
  678. def test_ccode_math_macros():
  679. assert ccode(z + exp(1)) == 'z + M_E'
  680. assert ccode(z + log2(exp(1))) == 'z + M_LOG2E'
  681. assert ccode(z + 1/log(2)) == 'z + M_LOG2E'
  682. assert ccode(z + log(2)) == 'z + M_LN2'
  683. assert ccode(z + log(10)) == 'z + M_LN10'
  684. assert ccode(z + pi) == 'z + M_PI'
  685. assert ccode(z + pi/2) == 'z + M_PI_2'
  686. assert ccode(z + pi/4) == 'z + M_PI_4'
  687. assert ccode(z + 1/pi) == 'z + M_1_PI'
  688. assert ccode(z + 2/pi) == 'z + M_2_PI'
  689. assert ccode(z + 2/sqrt(pi)) == 'z + M_2_SQRTPI'
  690. assert ccode(z + 2/Sqrt(pi)) == 'z + M_2_SQRTPI'
  691. assert ccode(z + sqrt(2)) == 'z + M_SQRT2'
  692. assert ccode(z + Sqrt(2)) == 'z + M_SQRT2'
  693. assert ccode(z + 1/sqrt(2)) == 'z + M_SQRT1_2'
  694. assert ccode(z + 1/Sqrt(2)) == 'z + M_SQRT1_2'
  695. def test_ccode_Type():
  696. assert ccode(Type('float')) == 'float'
  697. assert ccode(intc) == 'int'
  698. def test_ccode_codegen_ast():
  699. # Note that C only allows comments of the form /* ... */, double forward
  700. # slash is not standard C, and some C compilers will grind to a halt upon
  701. # encountering them.
  702. assert ccode(Comment("this is a comment")) == "/* this is a comment */" # not //
  703. assert ccode(While(abs(x) > 1, [aug_assign(x, '-', 1)])) == (
  704. 'while (fabs(x) > 1) {\n'
  705. ' x -= 1;\n'
  706. '}'
  707. )
  708. assert ccode(Scope([AddAugmentedAssignment(x, 1)])) == (
  709. '{\n'
  710. ' x += 1;\n'
  711. '}'
  712. )
  713. inp_x = Declaration(Variable(x, type=real))
  714. assert ccode(FunctionPrototype(real, 'pwer', [inp_x])) == 'double pwer(double x)'
  715. assert ccode(FunctionDefinition(real, 'pwer', [inp_x], [Assignment(x, x**2)])) == (
  716. 'double pwer(double x){\n'
  717. ' x = pow(x, 2);\n'
  718. '}'
  719. )
  720. # Elements of CodeBlock are formatted as statements:
  721. block = CodeBlock(
  722. x,
  723. Print([x, y], "%d %d"),
  724. FunctionCall('pwer', [x]),
  725. Return(x),
  726. )
  727. assert ccode(block) == '\n'.join([
  728. 'x;',
  729. 'printf("%d %d", x, y);',
  730. 'pwer(x);',
  731. 'return x;',
  732. ])
  733. def test_ccode_UnevaluatedExpr():
  734. assert ccode(UnevaluatedExpr(y * x) + z) == "z + x*y"
  735. assert ccode(UnevaluatedExpr(y + x) + z) == "z + (x + y)" # gh-21955
  736. w = symbols('w')
  737. assert ccode(UnevaluatedExpr(y + x) + UnevaluatedExpr(z + w)) == "(w + z) + (x + y)"
  738. p, q, r = symbols("p q r", real=True)
  739. q_r = UnevaluatedExpr(q + r)
  740. expr = abs(exp(p+q_r))
  741. assert ccode(expr) == "exp(p + (q + r))"
  742. def test_ccode_array_like_containers():
  743. assert ccode([2,3,4]) == "{2, 3, 4}"
  744. assert ccode((2,3,4)) == "{2, 3, 4}"