test_codegen.py 54 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367136813691370137113721373137413751376137713781379138013811382138313841385138613871388138913901391139213931394139513961397139813991400140114021403140414051406140714081409141014111412141314141415141614171418141914201421142214231424142514261427142814291430143114321433143414351436143714381439144014411442144314441445144614471448144914501451145214531454145514561457145814591460146114621463146414651466146714681469147014711472147314741475147614771478147914801481148214831484148514861487148814891490149114921493149414951496149714981499150015011502150315041505150615071508150915101511151215131514151515161517151815191520152115221523152415251526152715281529153015311532153315341535153615371538153915401541154215431544154515461547154815491550155115521553155415551556155715581559156015611562156315641565156615671568156915701571157215731574157515761577157815791580158115821583158415851586158715881589159015911592159315941595159615971598159916001601160216031604160516061607160816091610161116121613
  1. from io import StringIO
  2. from sympy.core import symbols, Eq, pi, Catalan, Lambda, Dummy
  3. from sympy.core.relational import Equality
  4. from sympy.core.symbol import Symbol
  5. from sympy.functions.special.error_functions import erf
  6. from sympy.integrals.integrals import Integral
  7. from sympy.matrices import Matrix, MatrixSymbol
  8. from sympy.utilities.codegen import (
  9. codegen, make_routine, CCodeGen, C89CodeGen, C99CodeGen, InputArgument,
  10. CodeGenError, FCodeGen, CodeGenArgumentListError, OutputArgument,
  11. InOutArgument)
  12. from sympy.testing.pytest import raises
  13. from sympy.utilities.lambdify import implemented_function
# FIXME: the import below fails due to a circular import with sympy.core
  15. # from sympy import codegen
  16. def get_string(dump_fn, routines, prefix="file", header=False, empty=False):
  17. """Wrapper for dump_fn. dump_fn writes its results to a stream object and
  18. this wrapper returns the contents of that stream as a string. This
  19. auxiliary function is used by many tests below.
  20. The header and the empty lines are not generated to facilitate the
  21. testing of the output.
  22. """
  23. output = StringIO()
  24. dump_fn(routines, output, prefix, header, empty)
  25. source = output.getvalue()
  26. output.close()
  27. return source
  28. def test_Routine_argument_order():
  29. a, x, y, z = symbols('a x y z')
  30. expr = (x + y)*z
  31. raises(CodeGenArgumentListError, lambda: make_routine("test", expr,
  32. argument_sequence=[z, x]))
  33. raises(CodeGenArgumentListError, lambda: make_routine("test", Eq(a,
  34. expr), argument_sequence=[z, x, y]))
  35. r = make_routine('test', Eq(a, expr), argument_sequence=[z, x, a, y])
  36. assert [ arg.name for arg in r.arguments ] == [z, x, a, y]
  37. assert [ type(arg) for arg in r.arguments ] == [
  38. InputArgument, InputArgument, OutputArgument, InputArgument ]
  39. r = make_routine('test', Eq(z, expr), argument_sequence=[z, x, y])
  40. assert [ type(arg) for arg in r.arguments ] == [
  41. InOutArgument, InputArgument, InputArgument ]
  42. from sympy.tensor import IndexedBase, Idx
  43. A, B = map(IndexedBase, ['A', 'B'])
  44. m = symbols('m', integer=True)
  45. i = Idx('i', m)
  46. r = make_routine('test', Eq(A[i], B[i]), argument_sequence=[B, A, m])
  47. assert [ arg.name for arg in r.arguments ] == [B.label, A.label, m]
  48. expr = Integral(x*y*z, (x, 1, 2), (y, 1, 3))
  49. r = make_routine('test', Eq(a, expr), argument_sequence=[z, x, a, y])
  50. assert [ arg.name for arg in r.arguments ] == [z, x, a, y]
  51. def test_empty_c_code():
  52. code_gen = C89CodeGen()
  53. source = get_string(code_gen.dump_c, [])
  54. assert source == "#include \"file.h\"\n#include <math.h>\n"
  55. def test_empty_c_code_with_comment():
  56. code_gen = C89CodeGen()
  57. source = get_string(code_gen.dump_c, [], header=True)
  58. assert source[:82] == (
  59. "/******************************************************************************\n *"
  60. )
  61. # " Code generated with SymPy 0.7.2-git "
  62. assert source[158:] == ( "*\n"
  63. " * *\n"
  64. " * See http://www.sympy.org/ for more information. *\n"
  65. " * *\n"
  66. " * This file is part of 'project' *\n"
  67. " ******************************************************************************/\n"
  68. "#include \"file.h\"\n"
  69. "#include <math.h>\n"
  70. )
  71. def test_empty_c_header():
  72. code_gen = C99CodeGen()
  73. source = get_string(code_gen.dump_h, [])
  74. assert source == "#ifndef PROJECT__FILE__H\n#define PROJECT__FILE__H\n#endif\n"
  75. def test_simple_c_code():
  76. x, y, z = symbols('x,y,z')
  77. expr = (x + y)*z
  78. routine = make_routine("test", expr)
  79. code_gen = C89CodeGen()
  80. source = get_string(code_gen.dump_c, [routine])
  81. expected = (
  82. "#include \"file.h\"\n"
  83. "#include <math.h>\n"
  84. "double test(double x, double y, double z) {\n"
  85. " double test_result;\n"
  86. " test_result = z*(x + y);\n"
  87. " return test_result;\n"
  88. "}\n"
  89. )
  90. assert source == expected
  91. def test_c_code_reserved_words():
  92. x, y, z = symbols('if, typedef, while')
  93. expr = (x + y) * z
  94. routine = make_routine("test", expr)
  95. code_gen = C99CodeGen()
  96. source = get_string(code_gen.dump_c, [routine])
  97. expected = (
  98. "#include \"file.h\"\n"
  99. "#include <math.h>\n"
  100. "double test(double if_, double typedef_, double while_) {\n"
  101. " double test_result;\n"
  102. " test_result = while_*(if_ + typedef_);\n"
  103. " return test_result;\n"
  104. "}\n"
  105. )
  106. assert source == expected
  107. def test_numbersymbol_c_code():
  108. routine = make_routine("test", pi**Catalan)
  109. code_gen = C89CodeGen()
  110. source = get_string(code_gen.dump_c, [routine])
  111. expected = (
  112. "#include \"file.h\"\n"
  113. "#include <math.h>\n"
  114. "double test() {\n"
  115. " double test_result;\n"
  116. " double const Catalan = %s;\n"
  117. " test_result = pow(M_PI, Catalan);\n"
  118. " return test_result;\n"
  119. "}\n"
  120. ) % Catalan.evalf(17)
  121. assert source == expected
  122. def test_c_code_argument_order():
  123. x, y, z = symbols('x,y,z')
  124. expr = x + y
  125. routine = make_routine("test", expr, argument_sequence=[z, x, y])
  126. code_gen = C89CodeGen()
  127. source = get_string(code_gen.dump_c, [routine])
  128. expected = (
  129. "#include \"file.h\"\n"
  130. "#include <math.h>\n"
  131. "double test(double z, double x, double y) {\n"
  132. " double test_result;\n"
  133. " test_result = x + y;\n"
  134. " return test_result;\n"
  135. "}\n"
  136. )
  137. assert source == expected
  138. def test_simple_c_header():
  139. x, y, z = symbols('x,y,z')
  140. expr = (x + y)*z
  141. routine = make_routine("test", expr)
  142. code_gen = C89CodeGen()
  143. source = get_string(code_gen.dump_h, [routine])
  144. expected = (
  145. "#ifndef PROJECT__FILE__H\n"
  146. "#define PROJECT__FILE__H\n"
  147. "double test(double x, double y, double z);\n"
  148. "#endif\n"
  149. )
  150. assert source == expected
  151. def test_simple_c_codegen():
  152. x, y, z = symbols('x,y,z')
  153. expr = (x + y)*z
  154. expected = [
  155. ("file.c",
  156. "#include \"file.h\"\n"
  157. "#include <math.h>\n"
  158. "double test(double x, double y, double z) {\n"
  159. " double test_result;\n"
  160. " test_result = z*(x + y);\n"
  161. " return test_result;\n"
  162. "}\n"),
  163. ("file.h",
  164. "#ifndef PROJECT__FILE__H\n"
  165. "#define PROJECT__FILE__H\n"
  166. "double test(double x, double y, double z);\n"
  167. "#endif\n")
  168. ]
  169. result = codegen(("test", expr), "C", "file", header=False, empty=False)
  170. assert result == expected
  171. def test_multiple_results_c():
  172. x, y, z = symbols('x,y,z')
  173. expr1 = (x + y)*z
  174. expr2 = (x - y)*z
  175. routine = make_routine(
  176. "test",
  177. [expr1, expr2]
  178. )
  179. code_gen = C99CodeGen()
  180. raises(CodeGenError, lambda: get_string(code_gen.dump_h, [routine]))
  181. def test_no_results_c():
  182. raises(ValueError, lambda: make_routine("test", []))
  183. def test_ansi_math1_codegen():
  184. # not included: log10
  185. from sympy.functions.elementary.complexes import Abs
  186. from sympy.functions.elementary.exponential import log
  187. from sympy.functions.elementary.hyperbolic import (cosh, sinh, tanh)
  188. from sympy.functions.elementary.integers import (ceiling, floor)
  189. from sympy.functions.elementary.miscellaneous import sqrt
  190. from sympy.functions.elementary.trigonometric import (acos, asin, atan, cos, sin, tan)
  191. x = symbols('x')
  192. name_expr = [
  193. ("test_fabs", Abs(x)),
  194. ("test_acos", acos(x)),
  195. ("test_asin", asin(x)),
  196. ("test_atan", atan(x)),
  197. ("test_ceil", ceiling(x)),
  198. ("test_cos", cos(x)),
  199. ("test_cosh", cosh(x)),
  200. ("test_floor", floor(x)),
  201. ("test_log", log(x)),
  202. ("test_ln", log(x)),
  203. ("test_sin", sin(x)),
  204. ("test_sinh", sinh(x)),
  205. ("test_sqrt", sqrt(x)),
  206. ("test_tan", tan(x)),
  207. ("test_tanh", tanh(x)),
  208. ]
  209. result = codegen(name_expr, "C89", "file", header=False, empty=False)
  210. assert result[0][0] == "file.c"
  211. assert result[0][1] == (
  212. '#include "file.h"\n#include <math.h>\n'
  213. 'double test_fabs(double x) {\n double test_fabs_result;\n test_fabs_result = fabs(x);\n return test_fabs_result;\n}\n'
  214. 'double test_acos(double x) {\n double test_acos_result;\n test_acos_result = acos(x);\n return test_acos_result;\n}\n'
  215. 'double test_asin(double x) {\n double test_asin_result;\n test_asin_result = asin(x);\n return test_asin_result;\n}\n'
  216. 'double test_atan(double x) {\n double test_atan_result;\n test_atan_result = atan(x);\n return test_atan_result;\n}\n'
  217. 'double test_ceil(double x) {\n double test_ceil_result;\n test_ceil_result = ceil(x);\n return test_ceil_result;\n}\n'
  218. 'double test_cos(double x) {\n double test_cos_result;\n test_cos_result = cos(x);\n return test_cos_result;\n}\n'
  219. 'double test_cosh(double x) {\n double test_cosh_result;\n test_cosh_result = cosh(x);\n return test_cosh_result;\n}\n'
  220. 'double test_floor(double x) {\n double test_floor_result;\n test_floor_result = floor(x);\n return test_floor_result;\n}\n'
  221. 'double test_log(double x) {\n double test_log_result;\n test_log_result = log(x);\n return test_log_result;\n}\n'
  222. 'double test_ln(double x) {\n double test_ln_result;\n test_ln_result = log(x);\n return test_ln_result;\n}\n'
  223. 'double test_sin(double x) {\n double test_sin_result;\n test_sin_result = sin(x);\n return test_sin_result;\n}\n'
  224. 'double test_sinh(double x) {\n double test_sinh_result;\n test_sinh_result = sinh(x);\n return test_sinh_result;\n}\n'
  225. 'double test_sqrt(double x) {\n double test_sqrt_result;\n test_sqrt_result = sqrt(x);\n return test_sqrt_result;\n}\n'
  226. 'double test_tan(double x) {\n double test_tan_result;\n test_tan_result = tan(x);\n return test_tan_result;\n}\n'
  227. 'double test_tanh(double x) {\n double test_tanh_result;\n test_tanh_result = tanh(x);\n return test_tanh_result;\n}\n'
  228. )
  229. assert result[1][0] == "file.h"
  230. assert result[1][1] == (
  231. '#ifndef PROJECT__FILE__H\n#define PROJECT__FILE__H\n'
  232. 'double test_fabs(double x);\ndouble test_acos(double x);\n'
  233. 'double test_asin(double x);\ndouble test_atan(double x);\n'
  234. 'double test_ceil(double x);\ndouble test_cos(double x);\n'
  235. 'double test_cosh(double x);\ndouble test_floor(double x);\n'
  236. 'double test_log(double x);\ndouble test_ln(double x);\n'
  237. 'double test_sin(double x);\ndouble test_sinh(double x);\n'
  238. 'double test_sqrt(double x);\ndouble test_tan(double x);\n'
  239. 'double test_tanh(double x);\n#endif\n'
  240. )
  241. def test_ansi_math2_codegen():
  242. # not included: frexp, ldexp, modf, fmod
  243. from sympy.functions.elementary.trigonometric import atan2
  244. x, y = symbols('x,y')
  245. name_expr = [
  246. ("test_atan2", atan2(x, y)),
  247. ("test_pow", x**y),
  248. ]
  249. result = codegen(name_expr, "C89", "file", header=False, empty=False)
  250. assert result[0][0] == "file.c"
  251. assert result[0][1] == (
  252. '#include "file.h"\n#include <math.h>\n'
  253. 'double test_atan2(double x, double y) {\n double test_atan2_result;\n test_atan2_result = atan2(x, y);\n return test_atan2_result;\n}\n'
  254. 'double test_pow(double x, double y) {\n double test_pow_result;\n test_pow_result = pow(x, y);\n return test_pow_result;\n}\n'
  255. )
  256. assert result[1][0] == "file.h"
  257. assert result[1][1] == (
  258. '#ifndef PROJECT__FILE__H\n#define PROJECT__FILE__H\n'
  259. 'double test_atan2(double x, double y);\n'
  260. 'double test_pow(double x, double y);\n'
  261. '#endif\n'
  262. )
  263. def test_complicated_codegen():
  264. from sympy.functions.elementary.trigonometric import (cos, sin, tan)
  265. x, y, z = symbols('x,y,z')
  266. name_expr = [
  267. ("test1", ((sin(x) + cos(y) + tan(z))**7).expand()),
  268. ("test2", cos(cos(cos(cos(cos(cos(cos(cos(x + y + z))))))))),
  269. ]
  270. result = codegen(name_expr, "C89", "file", header=False, empty=False)
  271. assert result[0][0] == "file.c"
  272. assert result[0][1] == (
  273. '#include "file.h"\n#include <math.h>\n'
  274. 'double test1(double x, double y, double z) {\n'
  275. ' double test1_result;\n'
  276. ' test1_result = '
  277. 'pow(sin(x), 7) + '
  278. '7*pow(sin(x), 6)*cos(y) + '
  279. '7*pow(sin(x), 6)*tan(z) + '
  280. '21*pow(sin(x), 5)*pow(cos(y), 2) + '
  281. '42*pow(sin(x), 5)*cos(y)*tan(z) + '
  282. '21*pow(sin(x), 5)*pow(tan(z), 2) + '
  283. '35*pow(sin(x), 4)*pow(cos(y), 3) + '
  284. '105*pow(sin(x), 4)*pow(cos(y), 2)*tan(z) + '
  285. '105*pow(sin(x), 4)*cos(y)*pow(tan(z), 2) + '
  286. '35*pow(sin(x), 4)*pow(tan(z), 3) + '
  287. '35*pow(sin(x), 3)*pow(cos(y), 4) + '
  288. '140*pow(sin(x), 3)*pow(cos(y), 3)*tan(z) + '
  289. '210*pow(sin(x), 3)*pow(cos(y), 2)*pow(tan(z), 2) + '
  290. '140*pow(sin(x), 3)*cos(y)*pow(tan(z), 3) + '
  291. '35*pow(sin(x), 3)*pow(tan(z), 4) + '
  292. '21*pow(sin(x), 2)*pow(cos(y), 5) + '
  293. '105*pow(sin(x), 2)*pow(cos(y), 4)*tan(z) + '
  294. '210*pow(sin(x), 2)*pow(cos(y), 3)*pow(tan(z), 2) + '
  295. '210*pow(sin(x), 2)*pow(cos(y), 2)*pow(tan(z), 3) + '
  296. '105*pow(sin(x), 2)*cos(y)*pow(tan(z), 4) + '
  297. '21*pow(sin(x), 2)*pow(tan(z), 5) + '
  298. '7*sin(x)*pow(cos(y), 6) + '
  299. '42*sin(x)*pow(cos(y), 5)*tan(z) + '
  300. '105*sin(x)*pow(cos(y), 4)*pow(tan(z), 2) + '
  301. '140*sin(x)*pow(cos(y), 3)*pow(tan(z), 3) + '
  302. '105*sin(x)*pow(cos(y), 2)*pow(tan(z), 4) + '
  303. '42*sin(x)*cos(y)*pow(tan(z), 5) + '
  304. '7*sin(x)*pow(tan(z), 6) + '
  305. 'pow(cos(y), 7) + '
  306. '7*pow(cos(y), 6)*tan(z) + '
  307. '21*pow(cos(y), 5)*pow(tan(z), 2) + '
  308. '35*pow(cos(y), 4)*pow(tan(z), 3) + '
  309. '35*pow(cos(y), 3)*pow(tan(z), 4) + '
  310. '21*pow(cos(y), 2)*pow(tan(z), 5) + '
  311. '7*cos(y)*pow(tan(z), 6) + '
  312. 'pow(tan(z), 7);\n'
  313. ' return test1_result;\n'
  314. '}\n'
  315. 'double test2(double x, double y, double z) {\n'
  316. ' double test2_result;\n'
  317. ' test2_result = cos(cos(cos(cos(cos(cos(cos(cos(x + y + z))))))));\n'
  318. ' return test2_result;\n'
  319. '}\n'
  320. )
  321. assert result[1][0] == "file.h"
  322. assert result[1][1] == (
  323. '#ifndef PROJECT__FILE__H\n'
  324. '#define PROJECT__FILE__H\n'
  325. 'double test1(double x, double y, double z);\n'
  326. 'double test2(double x, double y, double z);\n'
  327. '#endif\n'
  328. )
  329. def test_loops_c():
  330. from sympy.tensor import IndexedBase, Idx
  331. from sympy.core.symbol import symbols
  332. n, m = symbols('n m', integer=True)
  333. A = IndexedBase('A')
  334. x = IndexedBase('x')
  335. y = IndexedBase('y')
  336. i = Idx('i', m)
  337. j = Idx('j', n)
  338. (f1, code), (f2, interface) = codegen(
  339. ('matrix_vector', Eq(y[i], A[i, j]*x[j])), "C99", "file", header=False, empty=False)
  340. assert f1 == 'file.c'
  341. expected = (
  342. '#include "file.h"\n'
  343. '#include <math.h>\n'
  344. 'void matrix_vector(double *A, int m, int n, double *x, double *y) {\n'
  345. ' for (int i=0; i<m; i++){\n'
  346. ' y[i] = 0;\n'
  347. ' }\n'
  348. ' for (int i=0; i<m; i++){\n'
  349. ' for (int j=0; j<n; j++){\n'
  350. ' y[i] = %(rhs)s + y[i];\n'
  351. ' }\n'
  352. ' }\n'
  353. '}\n'
  354. )
  355. assert (code == expected % {'rhs': 'A[%s]*x[j]' % (i*n + j)} or
  356. code == expected % {'rhs': 'A[%s]*x[j]' % (j + i*n)} or
  357. code == expected % {'rhs': 'x[j]*A[%s]' % (i*n + j)} or
  358. code == expected % {'rhs': 'x[j]*A[%s]' % (j + i*n)})
  359. assert f2 == 'file.h'
  360. assert interface == (
  361. '#ifndef PROJECT__FILE__H\n'
  362. '#define PROJECT__FILE__H\n'
  363. 'void matrix_vector(double *A, int m, int n, double *x, double *y);\n'
  364. '#endif\n'
  365. )
  366. def test_dummy_loops_c():
  367. from sympy.tensor import IndexedBase, Idx
  368. i, m = symbols('i m', integer=True, cls=Dummy)
  369. x = IndexedBase('x')
  370. y = IndexedBase('y')
  371. i = Idx(i, m)
  372. expected = (
  373. '#include "file.h"\n'
  374. '#include <math.h>\n'
  375. 'void test_dummies(int m_%(mno)i, double *x, double *y) {\n'
  376. ' for (int i_%(ino)i=0; i_%(ino)i<m_%(mno)i; i_%(ino)i++){\n'
  377. ' y[i_%(ino)i] = x[i_%(ino)i];\n'
  378. ' }\n'
  379. '}\n'
  380. ) % {'ino': i.label.dummy_index, 'mno': m.dummy_index}
  381. r = make_routine('test_dummies', Eq(y[i], x[i]))
  382. c89 = C89CodeGen()
  383. c99 = C99CodeGen()
  384. code = get_string(c99.dump_c, [r])
  385. assert code == expected
  386. with raises(NotImplementedError):
  387. get_string(c89.dump_c, [r])
  388. def test_partial_loops_c():
  389. # check that loop boundaries are determined by Idx, and array strides
  390. # determined by shape of IndexedBase object.
  391. from sympy.tensor import IndexedBase, Idx
  392. from sympy.core.symbol import symbols
  393. n, m, o, p = symbols('n m o p', integer=True)
  394. A = IndexedBase('A', shape=(m, p))
  395. x = IndexedBase('x')
  396. y = IndexedBase('y')
  397. i = Idx('i', (o, m - 5)) # Note: bounds are inclusive
  398. j = Idx('j', n) # dimension n corresponds to bounds (0, n - 1)
  399. (f1, code), (f2, interface) = codegen(
  400. ('matrix_vector', Eq(y[i], A[i, j]*x[j])), "C99", "file", header=False, empty=False)
  401. assert f1 == 'file.c'
  402. expected = (
  403. '#include "file.h"\n'
  404. '#include <math.h>\n'
  405. 'void matrix_vector(double *A, int m, int n, int o, int p, double *x, double *y) {\n'
  406. ' for (int i=o; i<%(upperi)s; i++){\n'
  407. ' y[i] = 0;\n'
  408. ' }\n'
  409. ' for (int i=o; i<%(upperi)s; i++){\n'
  410. ' for (int j=0; j<n; j++){\n'
  411. ' y[i] = %(rhs)s + y[i];\n'
  412. ' }\n'
  413. ' }\n'
  414. '}\n'
  415. ) % {'upperi': m - 4, 'rhs': '%(rhs)s'}
  416. assert (code == expected % {'rhs': 'A[%s]*x[j]' % (i*p + j)} or
  417. code == expected % {'rhs': 'A[%s]*x[j]' % (j + i*p)} or
  418. code == expected % {'rhs': 'x[j]*A[%s]' % (i*p + j)} or
  419. code == expected % {'rhs': 'x[j]*A[%s]' % (j + i*p)})
  420. assert f2 == 'file.h'
  421. assert interface == (
  422. '#ifndef PROJECT__FILE__H\n'
  423. '#define PROJECT__FILE__H\n'
  424. 'void matrix_vector(double *A, int m, int n, int o, int p, double *x, double *y);\n'
  425. '#endif\n'
  426. )
  427. def test_output_arg_c():
  428. from sympy.core.relational import Equality
  429. from sympy.functions.elementary.trigonometric import (cos, sin)
  430. x, y, z = symbols("x,y,z")
  431. r = make_routine("foo", [Equality(y, sin(x)), cos(x)])
  432. c = C89CodeGen()
  433. result = c.write([r], "test", header=False, empty=False)
  434. assert result[0][0] == "test.c"
  435. expected = (
  436. '#include "test.h"\n'
  437. '#include <math.h>\n'
  438. 'double foo(double x, double *y) {\n'
  439. ' (*y) = sin(x);\n'
  440. ' double foo_result;\n'
  441. ' foo_result = cos(x);\n'
  442. ' return foo_result;\n'
  443. '}\n'
  444. )
  445. assert result[0][1] == expected
  446. def test_output_arg_c_reserved_words():
  447. from sympy.core.relational import Equality
  448. from sympy.functions.elementary.trigonometric import (cos, sin)
  449. x, y, z = symbols("if, while, z")
  450. r = make_routine("foo", [Equality(y, sin(x)), cos(x)])
  451. c = C89CodeGen()
  452. result = c.write([r], "test", header=False, empty=False)
  453. assert result[0][0] == "test.c"
  454. expected = (
  455. '#include "test.h"\n'
  456. '#include <math.h>\n'
  457. 'double foo(double if_, double *while_) {\n'
  458. ' (*while_) = sin(if_);\n'
  459. ' double foo_result;\n'
  460. ' foo_result = cos(if_);\n'
  461. ' return foo_result;\n'
  462. '}\n'
  463. )
  464. assert result[0][1] == expected
  465. def test_multidim_c_argument_cse():
  466. A_sym = MatrixSymbol('A', 3, 3)
  467. b_sym = MatrixSymbol('b', 3, 1)
  468. A = Matrix(A_sym)
  469. b = Matrix(b_sym)
  470. c = A*b
  471. cgen = CCodeGen(project="test", cse=True)
  472. r = cgen.routine("c", c)
  473. r.arguments[-1].result_var = "out"
  474. r.arguments[-1]._name = "out"
  475. code = get_string(cgen.dump_c, [r], prefix="test")
  476. expected = (
  477. '#include "test.h"\n'
  478. "#include <math.h>\n"
  479. "void c(double *A, double *b, double *out) {\n"
  480. " out[0] = A[0]*b[0] + A[1]*b[1] + A[2]*b[2];\n"
  481. " out[1] = A[3]*b[0] + A[4]*b[1] + A[5]*b[2];\n"
  482. " out[2] = A[6]*b[0] + A[7]*b[1] + A[8]*b[2];\n"
  483. "}\n"
  484. )
  485. assert code == expected
  486. def test_ccode_results_named_ordered():
  487. x, y, z = symbols('x,y,z')
  488. B, C = symbols('B,C')
  489. A = MatrixSymbol('A', 1, 3)
  490. expr1 = Equality(A, Matrix([[1, 2, x]]))
  491. expr2 = Equality(C, (x + y)*z)
  492. expr3 = Equality(B, 2*x)
  493. name_expr = ("test", [expr1, expr2, expr3])
  494. expected = (
  495. '#include "test.h"\n'
  496. '#include <math.h>\n'
  497. 'void test(double x, double *C, double z, double y, double *A, double *B) {\n'
  498. ' (*C) = z*(x + y);\n'
  499. ' A[0] = 1;\n'
  500. ' A[1] = 2;\n'
  501. ' A[2] = x;\n'
  502. ' (*B) = 2*x;\n'
  503. '}\n'
  504. )
  505. result = codegen(name_expr, "c", "test", header=False, empty=False,
  506. argument_sequence=(x, C, z, y, A, B))
  507. source = result[0][1]
  508. assert source == expected
  509. def test_ccode_matrixsymbol_slice():
  510. A = MatrixSymbol('A', 5, 3)
  511. B = MatrixSymbol('B', 1, 3)
  512. C = MatrixSymbol('C', 1, 3)
  513. D = MatrixSymbol('D', 5, 1)
  514. name_expr = ("test", [Equality(B, A[0, :]),
  515. Equality(C, A[1, :]),
  516. Equality(D, A[:, 2])])
  517. result = codegen(name_expr, "c99", "test", header=False, empty=False)
  518. source = result[0][1]
  519. expected = (
  520. '#include "test.h"\n'
  521. '#include <math.h>\n'
  522. 'void test(double *A, double *B, double *C, double *D) {\n'
  523. ' B[0] = A[0];\n'
  524. ' B[1] = A[1];\n'
  525. ' B[2] = A[2];\n'
  526. ' C[0] = A[3];\n'
  527. ' C[1] = A[4];\n'
  528. ' C[2] = A[5];\n'
  529. ' D[0] = A[2];\n'
  530. ' D[1] = A[5];\n'
  531. ' D[2] = A[8];\n'
  532. ' D[3] = A[11];\n'
  533. ' D[4] = A[14];\n'
  534. '}\n'
  535. )
  536. assert source == expected
  537. def test_ccode_cse():
  538. a, b, c, d = symbols('a b c d')
  539. e = MatrixSymbol('e', 3, 1)
  540. name_expr = ("test", [Equality(e, Matrix([[a*b], [a*b + c*d], [a*b*c*d]]))])
  541. generator = CCodeGen(cse=True)
  542. result = codegen(name_expr, code_gen=generator, header=False, empty=False)
  543. source = result[0][1]
  544. expected = (
  545. '#include "test.h"\n'
  546. '#include <math.h>\n'
  547. 'void test(double a, double b, double c, double d, double *e) {\n'
  548. ' const double x0 = a*b;\n'
  549. ' const double x1 = c*d;\n'
  550. ' e[0] = x0;\n'
  551. ' e[1] = x0 + x1;\n'
  552. ' e[2] = x0*x1;\n'
  553. '}\n'
  554. )
  555. assert source == expected
  556. def test_ccode_unused_array_arg():
  557. x = MatrixSymbol('x', 2, 1)
  558. # x does not appear in output
  559. name_expr = ("test", 1.0)
  560. generator = CCodeGen()
  561. result = codegen(name_expr, code_gen=generator, header=False, empty=False, argument_sequence=(x,))
  562. source = result[0][1]
  563. # note: x should appear as (double *)
  564. expected = (
  565. '#include "test.h"\n'
  566. '#include <math.h>\n'
  567. 'double test(double *x) {\n'
  568. ' double test_result;\n'
  569. ' test_result = 1.0;\n'
  570. ' return test_result;\n'
  571. '}\n'
  572. )
  573. assert source == expected
  574. def test_empty_f_code():
  575. code_gen = FCodeGen()
  576. source = get_string(code_gen.dump_f95, [])
  577. assert source == ""
  578. def test_empty_f_code_with_header():
  579. code_gen = FCodeGen()
  580. source = get_string(code_gen.dump_f95, [], header=True)
  581. assert source[:82] == (
  582. "!******************************************************************************\n!*"
  583. )
  584. # " Code generated with SymPy 0.7.2-git "
  585. assert source[158:] == ( "*\n"
  586. "!* *\n"
  587. "!* See http://www.sympy.org/ for more information. *\n"
  588. "!* *\n"
  589. "!* This file is part of 'project' *\n"
  590. "!******************************************************************************\n"
  591. )
  592. def test_empty_f_header():
  593. code_gen = FCodeGen()
  594. source = get_string(code_gen.dump_h, [])
  595. assert source == ""
  596. def test_simple_f_code():
  597. x, y, z = symbols('x,y,z')
  598. expr = (x + y)*z
  599. routine = make_routine("test", expr)
  600. code_gen = FCodeGen()
  601. source = get_string(code_gen.dump_f95, [routine])
  602. expected = (
  603. "REAL*8 function test(x, y, z)\n"
  604. "implicit none\n"
  605. "REAL*8, intent(in) :: x\n"
  606. "REAL*8, intent(in) :: y\n"
  607. "REAL*8, intent(in) :: z\n"
  608. "test = z*(x + y)\n"
  609. "end function\n"
  610. )
  611. assert source == expected
  612. def test_numbersymbol_f_code():
  613. routine = make_routine("test", pi**Catalan)
  614. code_gen = FCodeGen()
  615. source = get_string(code_gen.dump_f95, [routine])
  616. expected = (
  617. "REAL*8 function test()\n"
  618. "implicit none\n"
  619. "REAL*8, parameter :: Catalan = %sd0\n"
  620. "REAL*8, parameter :: pi = %sd0\n"
  621. "test = pi**Catalan\n"
  622. "end function\n"
  623. ) % (Catalan.evalf(17), pi.evalf(17))
  624. assert source == expected
  625. def test_erf_f_code():
  626. x = symbols('x')
  627. routine = make_routine("test", erf(x) - erf(-2 * x))
  628. code_gen = FCodeGen()
  629. source = get_string(code_gen.dump_f95, [routine])
  630. expected = (
  631. "REAL*8 function test(x)\n"
  632. "implicit none\n"
  633. "REAL*8, intent(in) :: x\n"
  634. "test = erf(x) + erf(2.0d0*x)\n"
  635. "end function\n"
  636. )
  637. assert source == expected, source
  638. def test_f_code_argument_order():
  639. x, y, z = symbols('x,y,z')
  640. expr = x + y
  641. routine = make_routine("test", expr, argument_sequence=[z, x, y])
  642. code_gen = FCodeGen()
  643. source = get_string(code_gen.dump_f95, [routine])
  644. expected = (
  645. "REAL*8 function test(z, x, y)\n"
  646. "implicit none\n"
  647. "REAL*8, intent(in) :: z\n"
  648. "REAL*8, intent(in) :: x\n"
  649. "REAL*8, intent(in) :: y\n"
  650. "test = x + y\n"
  651. "end function\n"
  652. )
  653. assert source == expected
  654. def test_simple_f_header():
  655. x, y, z = symbols('x,y,z')
  656. expr = (x + y)*z
  657. routine = make_routine("test", expr)
  658. code_gen = FCodeGen()
  659. source = get_string(code_gen.dump_h, [routine])
  660. expected = (
  661. "interface\n"
  662. "REAL*8 function test(x, y, z)\n"
  663. "implicit none\n"
  664. "REAL*8, intent(in) :: x\n"
  665. "REAL*8, intent(in) :: y\n"
  666. "REAL*8, intent(in) :: z\n"
  667. "end function\n"
  668. "end interface\n"
  669. )
  670. assert source == expected
  671. def test_simple_f_codegen():
  672. x, y, z = symbols('x,y,z')
  673. expr = (x + y)*z
  674. result = codegen(
  675. ("test", expr), "F95", "file", header=False, empty=False)
  676. expected = [
  677. ("file.f90",
  678. "REAL*8 function test(x, y, z)\n"
  679. "implicit none\n"
  680. "REAL*8, intent(in) :: x\n"
  681. "REAL*8, intent(in) :: y\n"
  682. "REAL*8, intent(in) :: z\n"
  683. "test = z*(x + y)\n"
  684. "end function\n"),
  685. ("file.h",
  686. "interface\n"
  687. "REAL*8 function test(x, y, z)\n"
  688. "implicit none\n"
  689. "REAL*8, intent(in) :: x\n"
  690. "REAL*8, intent(in) :: y\n"
  691. "REAL*8, intent(in) :: z\n"
  692. "end function\n"
  693. "end interface\n")
  694. ]
  695. assert result == expected
  696. def test_multiple_results_f():
  697. x, y, z = symbols('x,y,z')
  698. expr1 = (x + y)*z
  699. expr2 = (x - y)*z
  700. routine = make_routine(
  701. "test",
  702. [expr1, expr2]
  703. )
  704. code_gen = FCodeGen()
  705. raises(CodeGenError, lambda: get_string(code_gen.dump_h, [routine]))
  706. def test_no_results_f():
  707. raises(ValueError, lambda: make_routine("test", []))
  708. def test_intrinsic_math_codegen():
  709. # not included: log10
  710. from sympy.functions.elementary.complexes import Abs
  711. from sympy.functions.elementary.exponential import log
  712. from sympy.functions.elementary.hyperbolic import (cosh, sinh, tanh)
  713. from sympy.functions.elementary.miscellaneous import sqrt
  714. from sympy.functions.elementary.trigonometric import (acos, asin, atan, cos, sin, tan)
  715. x = symbols('x')
  716. name_expr = [
  717. ("test_abs", Abs(x)),
  718. ("test_acos", acos(x)),
  719. ("test_asin", asin(x)),
  720. ("test_atan", atan(x)),
  721. ("test_cos", cos(x)),
  722. ("test_cosh", cosh(x)),
  723. ("test_log", log(x)),
  724. ("test_ln", log(x)),
  725. ("test_sin", sin(x)),
  726. ("test_sinh", sinh(x)),
  727. ("test_sqrt", sqrt(x)),
  728. ("test_tan", tan(x)),
  729. ("test_tanh", tanh(x)),
  730. ]
  731. result = codegen(name_expr, "F95", "file", header=False, empty=False)
  732. assert result[0][0] == "file.f90"
  733. expected = (
  734. 'REAL*8 function test_abs(x)\n'
  735. 'implicit none\n'
  736. 'REAL*8, intent(in) :: x\n'
  737. 'test_abs = abs(x)\n'
  738. 'end function\n'
  739. 'REAL*8 function test_acos(x)\n'
  740. 'implicit none\n'
  741. 'REAL*8, intent(in) :: x\n'
  742. 'test_acos = acos(x)\n'
  743. 'end function\n'
  744. 'REAL*8 function test_asin(x)\n'
  745. 'implicit none\n'
  746. 'REAL*8, intent(in) :: x\n'
  747. 'test_asin = asin(x)\n'
  748. 'end function\n'
  749. 'REAL*8 function test_atan(x)\n'
  750. 'implicit none\n'
  751. 'REAL*8, intent(in) :: x\n'
  752. 'test_atan = atan(x)\n'
  753. 'end function\n'
  754. 'REAL*8 function test_cos(x)\n'
  755. 'implicit none\n'
  756. 'REAL*8, intent(in) :: x\n'
  757. 'test_cos = cos(x)\n'
  758. 'end function\n'
  759. 'REAL*8 function test_cosh(x)\n'
  760. 'implicit none\n'
  761. 'REAL*8, intent(in) :: x\n'
  762. 'test_cosh = cosh(x)\n'
  763. 'end function\n'
  764. 'REAL*8 function test_log(x)\n'
  765. 'implicit none\n'
  766. 'REAL*8, intent(in) :: x\n'
  767. 'test_log = log(x)\n'
  768. 'end function\n'
  769. 'REAL*8 function test_ln(x)\n'
  770. 'implicit none\n'
  771. 'REAL*8, intent(in) :: x\n'
  772. 'test_ln = log(x)\n'
  773. 'end function\n'
  774. 'REAL*8 function test_sin(x)\n'
  775. 'implicit none\n'
  776. 'REAL*8, intent(in) :: x\n'
  777. 'test_sin = sin(x)\n'
  778. 'end function\n'
  779. 'REAL*8 function test_sinh(x)\n'
  780. 'implicit none\n'
  781. 'REAL*8, intent(in) :: x\n'
  782. 'test_sinh = sinh(x)\n'
  783. 'end function\n'
  784. 'REAL*8 function test_sqrt(x)\n'
  785. 'implicit none\n'
  786. 'REAL*8, intent(in) :: x\n'
  787. 'test_sqrt = sqrt(x)\n'
  788. 'end function\n'
  789. 'REAL*8 function test_tan(x)\n'
  790. 'implicit none\n'
  791. 'REAL*8, intent(in) :: x\n'
  792. 'test_tan = tan(x)\n'
  793. 'end function\n'
  794. 'REAL*8 function test_tanh(x)\n'
  795. 'implicit none\n'
  796. 'REAL*8, intent(in) :: x\n'
  797. 'test_tanh = tanh(x)\n'
  798. 'end function\n'
  799. )
  800. assert result[0][1] == expected
  801. assert result[1][0] == "file.h"
  802. expected = (
  803. 'interface\n'
  804. 'REAL*8 function test_abs(x)\n'
  805. 'implicit none\n'
  806. 'REAL*8, intent(in) :: x\n'
  807. 'end function\n'
  808. 'end interface\n'
  809. 'interface\n'
  810. 'REAL*8 function test_acos(x)\n'
  811. 'implicit none\n'
  812. 'REAL*8, intent(in) :: x\n'
  813. 'end function\n'
  814. 'end interface\n'
  815. 'interface\n'
  816. 'REAL*8 function test_asin(x)\n'
  817. 'implicit none\n'
  818. 'REAL*8, intent(in) :: x\n'
  819. 'end function\n'
  820. 'end interface\n'
  821. 'interface\n'
  822. 'REAL*8 function test_atan(x)\n'
  823. 'implicit none\n'
  824. 'REAL*8, intent(in) :: x\n'
  825. 'end function\n'
  826. 'end interface\n'
  827. 'interface\n'
  828. 'REAL*8 function test_cos(x)\n'
  829. 'implicit none\n'
  830. 'REAL*8, intent(in) :: x\n'
  831. 'end function\n'
  832. 'end interface\n'
  833. 'interface\n'
  834. 'REAL*8 function test_cosh(x)\n'
  835. 'implicit none\n'
  836. 'REAL*8, intent(in) :: x\n'
  837. 'end function\n'
  838. 'end interface\n'
  839. 'interface\n'
  840. 'REAL*8 function test_log(x)\n'
  841. 'implicit none\n'
  842. 'REAL*8, intent(in) :: x\n'
  843. 'end function\n'
  844. 'end interface\n'
  845. 'interface\n'
  846. 'REAL*8 function test_ln(x)\n'
  847. 'implicit none\n'
  848. 'REAL*8, intent(in) :: x\n'
  849. 'end function\n'
  850. 'end interface\n'
  851. 'interface\n'
  852. 'REAL*8 function test_sin(x)\n'
  853. 'implicit none\n'
  854. 'REAL*8, intent(in) :: x\n'
  855. 'end function\n'
  856. 'end interface\n'
  857. 'interface\n'
  858. 'REAL*8 function test_sinh(x)\n'
  859. 'implicit none\n'
  860. 'REAL*8, intent(in) :: x\n'
  861. 'end function\n'
  862. 'end interface\n'
  863. 'interface\n'
  864. 'REAL*8 function test_sqrt(x)\n'
  865. 'implicit none\n'
  866. 'REAL*8, intent(in) :: x\n'
  867. 'end function\n'
  868. 'end interface\n'
  869. 'interface\n'
  870. 'REAL*8 function test_tan(x)\n'
  871. 'implicit none\n'
  872. 'REAL*8, intent(in) :: x\n'
  873. 'end function\n'
  874. 'end interface\n'
  875. 'interface\n'
  876. 'REAL*8 function test_tanh(x)\n'
  877. 'implicit none\n'
  878. 'REAL*8, intent(in) :: x\n'
  879. 'end function\n'
  880. 'end interface\n'
  881. )
  882. assert result[1][1] == expected
  883. def test_intrinsic_math2_codegen():
  884. # not included: frexp, ldexp, modf, fmod
  885. from sympy.functions.elementary.trigonometric import atan2
  886. x, y = symbols('x,y')
  887. name_expr = [
  888. ("test_atan2", atan2(x, y)),
  889. ("test_pow", x**y),
  890. ]
  891. result = codegen(name_expr, "F95", "file", header=False, empty=False)
  892. assert result[0][0] == "file.f90"
  893. expected = (
  894. 'REAL*8 function test_atan2(x, y)\n'
  895. 'implicit none\n'
  896. 'REAL*8, intent(in) :: x\n'
  897. 'REAL*8, intent(in) :: y\n'
  898. 'test_atan2 = atan2(x, y)\n'
  899. 'end function\n'
  900. 'REAL*8 function test_pow(x, y)\n'
  901. 'implicit none\n'
  902. 'REAL*8, intent(in) :: x\n'
  903. 'REAL*8, intent(in) :: y\n'
  904. 'test_pow = x**y\n'
  905. 'end function\n'
  906. )
  907. assert result[0][1] == expected
  908. assert result[1][0] == "file.h"
  909. expected = (
  910. 'interface\n'
  911. 'REAL*8 function test_atan2(x, y)\n'
  912. 'implicit none\n'
  913. 'REAL*8, intent(in) :: x\n'
  914. 'REAL*8, intent(in) :: y\n'
  915. 'end function\n'
  916. 'end interface\n'
  917. 'interface\n'
  918. 'REAL*8 function test_pow(x, y)\n'
  919. 'implicit none\n'
  920. 'REAL*8, intent(in) :: x\n'
  921. 'REAL*8, intent(in) :: y\n'
  922. 'end function\n'
  923. 'end interface\n'
  924. )
  925. assert result[1][1] == expected
  926. def test_complicated_codegen_f95():
  927. from sympy.functions.elementary.trigonometric import (cos, sin, tan)
  928. x, y, z = symbols('x,y,z')
  929. name_expr = [
  930. ("test1", ((sin(x) + cos(y) + tan(z))**7).expand()),
  931. ("test2", cos(cos(cos(cos(cos(cos(cos(cos(x + y + z))))))))),
  932. ]
  933. result = codegen(name_expr, "F95", "file", header=False, empty=False)
  934. assert result[0][0] == "file.f90"
  935. expected = (
  936. 'REAL*8 function test1(x, y, z)\n'
  937. 'implicit none\n'
  938. 'REAL*8, intent(in) :: x\n'
  939. 'REAL*8, intent(in) :: y\n'
  940. 'REAL*8, intent(in) :: z\n'
  941. 'test1 = sin(x)**7 + 7*sin(x)**6*cos(y) + 7*sin(x)**6*tan(z) + 21*sin(x) &\n'
  942. ' **5*cos(y)**2 + 42*sin(x)**5*cos(y)*tan(z) + 21*sin(x)**5*tan(z) &\n'
  943. ' **2 + 35*sin(x)**4*cos(y)**3 + 105*sin(x)**4*cos(y)**2*tan(z) + &\n'
  944. ' 105*sin(x)**4*cos(y)*tan(z)**2 + 35*sin(x)**4*tan(z)**3 + 35*sin( &\n'
  945. ' x)**3*cos(y)**4 + 140*sin(x)**3*cos(y)**3*tan(z) + 210*sin(x)**3* &\n'
  946. ' cos(y)**2*tan(z)**2 + 140*sin(x)**3*cos(y)*tan(z)**3 + 35*sin(x) &\n'
  947. ' **3*tan(z)**4 + 21*sin(x)**2*cos(y)**5 + 105*sin(x)**2*cos(y)**4* &\n'
  948. ' tan(z) + 210*sin(x)**2*cos(y)**3*tan(z)**2 + 210*sin(x)**2*cos(y) &\n'
  949. ' **2*tan(z)**3 + 105*sin(x)**2*cos(y)*tan(z)**4 + 21*sin(x)**2*tan &\n'
  950. ' (z)**5 + 7*sin(x)*cos(y)**6 + 42*sin(x)*cos(y)**5*tan(z) + 105* &\n'
  951. ' sin(x)*cos(y)**4*tan(z)**2 + 140*sin(x)*cos(y)**3*tan(z)**3 + 105 &\n'
  952. ' *sin(x)*cos(y)**2*tan(z)**4 + 42*sin(x)*cos(y)*tan(z)**5 + 7*sin( &\n'
  953. ' x)*tan(z)**6 + cos(y)**7 + 7*cos(y)**6*tan(z) + 21*cos(y)**5*tan( &\n'
  954. ' z)**2 + 35*cos(y)**4*tan(z)**3 + 35*cos(y)**3*tan(z)**4 + 21*cos( &\n'
  955. ' y)**2*tan(z)**5 + 7*cos(y)*tan(z)**6 + tan(z)**7\n'
  956. 'end function\n'
  957. 'REAL*8 function test2(x, y, z)\n'
  958. 'implicit none\n'
  959. 'REAL*8, intent(in) :: x\n'
  960. 'REAL*8, intent(in) :: y\n'
  961. 'REAL*8, intent(in) :: z\n'
  962. 'test2 = cos(cos(cos(cos(cos(cos(cos(cos(x + y + z))))))))\n'
  963. 'end function\n'
  964. )
  965. assert result[0][1] == expected
  966. assert result[1][0] == "file.h"
  967. expected = (
  968. 'interface\n'
  969. 'REAL*8 function test1(x, y, z)\n'
  970. 'implicit none\n'
  971. 'REAL*8, intent(in) :: x\n'
  972. 'REAL*8, intent(in) :: y\n'
  973. 'REAL*8, intent(in) :: z\n'
  974. 'end function\n'
  975. 'end interface\n'
  976. 'interface\n'
  977. 'REAL*8 function test2(x, y, z)\n'
  978. 'implicit none\n'
  979. 'REAL*8, intent(in) :: x\n'
  980. 'REAL*8, intent(in) :: y\n'
  981. 'REAL*8, intent(in) :: z\n'
  982. 'end function\n'
  983. 'end interface\n'
  984. )
  985. assert result[1][1] == expected
  986. def test_loops():
  987. from sympy.tensor import IndexedBase, Idx
  988. from sympy.core.symbol import symbols
  989. n, m = symbols('n,m', integer=True)
  990. A, x, y = map(IndexedBase, 'Axy')
  991. i = Idx('i', m)
  992. j = Idx('j', n)
  993. (f1, code), (f2, interface) = codegen(
  994. ('matrix_vector', Eq(y[i], A[i, j]*x[j])), "F95", "file", header=False, empty=False)
  995. assert f1 == 'file.f90'
  996. expected = (
  997. 'subroutine matrix_vector(A, m, n, x, y)\n'
  998. 'implicit none\n'
  999. 'INTEGER*4, intent(in) :: m\n'
  1000. 'INTEGER*4, intent(in) :: n\n'
  1001. 'REAL*8, intent(in), dimension(1:m, 1:n) :: A\n'
  1002. 'REAL*8, intent(in), dimension(1:n) :: x\n'
  1003. 'REAL*8, intent(out), dimension(1:m) :: y\n'
  1004. 'INTEGER*4 :: i\n'
  1005. 'INTEGER*4 :: j\n'
  1006. 'do i = 1, m\n'
  1007. ' y(i) = 0\n'
  1008. 'end do\n'
  1009. 'do i = 1, m\n'
  1010. ' do j = 1, n\n'
  1011. ' y(i) = %(rhs)s + y(i)\n'
  1012. ' end do\n'
  1013. 'end do\n'
  1014. 'end subroutine\n'
  1015. )
  1016. assert code == expected % {'rhs': 'A(i, j)*x(j)'} or\
  1017. code == expected % {'rhs': 'x(j)*A(i, j)'}
  1018. assert f2 == 'file.h'
  1019. assert interface == (
  1020. 'interface\n'
  1021. 'subroutine matrix_vector(A, m, n, x, y)\n'
  1022. 'implicit none\n'
  1023. 'INTEGER*4, intent(in) :: m\n'
  1024. 'INTEGER*4, intent(in) :: n\n'
  1025. 'REAL*8, intent(in), dimension(1:m, 1:n) :: A\n'
  1026. 'REAL*8, intent(in), dimension(1:n) :: x\n'
  1027. 'REAL*8, intent(out), dimension(1:m) :: y\n'
  1028. 'end subroutine\n'
  1029. 'end interface\n'
  1030. )
  1031. def test_dummy_loops_f95():
  1032. from sympy.tensor import IndexedBase, Idx
  1033. i, m = symbols('i m', integer=True, cls=Dummy)
  1034. x = IndexedBase('x')
  1035. y = IndexedBase('y')
  1036. i = Idx(i, m)
  1037. expected = (
  1038. 'subroutine test_dummies(m_%(mcount)i, x, y)\n'
  1039. 'implicit none\n'
  1040. 'INTEGER*4, intent(in) :: m_%(mcount)i\n'
  1041. 'REAL*8, intent(in), dimension(1:m_%(mcount)i) :: x\n'
  1042. 'REAL*8, intent(out), dimension(1:m_%(mcount)i) :: y\n'
  1043. 'INTEGER*4 :: i_%(icount)i\n'
  1044. 'do i_%(icount)i = 1, m_%(mcount)i\n'
  1045. ' y(i_%(icount)i) = x(i_%(icount)i)\n'
  1046. 'end do\n'
  1047. 'end subroutine\n'
  1048. ) % {'icount': i.label.dummy_index, 'mcount': m.dummy_index}
  1049. r = make_routine('test_dummies', Eq(y[i], x[i]))
  1050. c = FCodeGen()
  1051. code = get_string(c.dump_f95, [r])
  1052. assert code == expected
  1053. def test_loops_InOut():
  1054. from sympy.tensor import IndexedBase, Idx
  1055. from sympy.core.symbol import symbols
  1056. i, j, n, m = symbols('i,j,n,m', integer=True)
  1057. A, x, y = symbols('A,x,y')
  1058. A = IndexedBase(A)[Idx(i, m), Idx(j, n)]
  1059. x = IndexedBase(x)[Idx(j, n)]
  1060. y = IndexedBase(y)[Idx(i, m)]
  1061. (f1, code), (f2, interface) = codegen(
  1062. ('matrix_vector', Eq(y, y + A*x)), "F95", "file", header=False, empty=False)
  1063. assert f1 == 'file.f90'
  1064. expected = (
  1065. 'subroutine matrix_vector(A, m, n, x, y)\n'
  1066. 'implicit none\n'
  1067. 'INTEGER*4, intent(in) :: m\n'
  1068. 'INTEGER*4, intent(in) :: n\n'
  1069. 'REAL*8, intent(in), dimension(1:m, 1:n) :: A\n'
  1070. 'REAL*8, intent(in), dimension(1:n) :: x\n'
  1071. 'REAL*8, intent(inout), dimension(1:m) :: y\n'
  1072. 'INTEGER*4 :: i\n'
  1073. 'INTEGER*4 :: j\n'
  1074. 'do i = 1, m\n'
  1075. ' do j = 1, n\n'
  1076. ' y(i) = %(rhs)s + y(i)\n'
  1077. ' end do\n'
  1078. 'end do\n'
  1079. 'end subroutine\n'
  1080. )
  1081. assert (code == expected % {'rhs': 'A(i, j)*x(j)'} or
  1082. code == expected % {'rhs': 'x(j)*A(i, j)'})
  1083. assert f2 == 'file.h'
  1084. assert interface == (
  1085. 'interface\n'
  1086. 'subroutine matrix_vector(A, m, n, x, y)\n'
  1087. 'implicit none\n'
  1088. 'INTEGER*4, intent(in) :: m\n'
  1089. 'INTEGER*4, intent(in) :: n\n'
  1090. 'REAL*8, intent(in), dimension(1:m, 1:n) :: A\n'
  1091. 'REAL*8, intent(in), dimension(1:n) :: x\n'
  1092. 'REAL*8, intent(inout), dimension(1:m) :: y\n'
  1093. 'end subroutine\n'
  1094. 'end interface\n'
  1095. )
  1096. def test_partial_loops_f():
  1097. # check that loop boundaries are determined by Idx, and array strides
  1098. # determined by shape of IndexedBase object.
  1099. from sympy.tensor import IndexedBase, Idx
  1100. from sympy.core.symbol import symbols
  1101. n, m, o, p = symbols('n m o p', integer=True)
  1102. A = IndexedBase('A', shape=(m, p))
  1103. x = IndexedBase('x')
  1104. y = IndexedBase('y')
  1105. i = Idx('i', (o, m - 5)) # Note: bounds are inclusive
  1106. j = Idx('j', n) # dimension n corresponds to bounds (0, n - 1)
  1107. (f1, code), (f2, interface) = codegen(
  1108. ('matrix_vector', Eq(y[i], A[i, j]*x[j])), "F95", "file", header=False, empty=False)
  1109. expected = (
  1110. 'subroutine matrix_vector(A, m, n, o, p, x, y)\n'
  1111. 'implicit none\n'
  1112. 'INTEGER*4, intent(in) :: m\n'
  1113. 'INTEGER*4, intent(in) :: n\n'
  1114. 'INTEGER*4, intent(in) :: o\n'
  1115. 'INTEGER*4, intent(in) :: p\n'
  1116. 'REAL*8, intent(in), dimension(1:m, 1:p) :: A\n'
  1117. 'REAL*8, intent(in), dimension(1:n) :: x\n'
  1118. 'REAL*8, intent(out), dimension(1:%(iup-ilow)s) :: y\n'
  1119. 'INTEGER*4 :: i\n'
  1120. 'INTEGER*4 :: j\n'
  1121. 'do i = %(ilow)s, %(iup)s\n'
  1122. ' y(i) = 0\n'
  1123. 'end do\n'
  1124. 'do i = %(ilow)s, %(iup)s\n'
  1125. ' do j = 1, n\n'
  1126. ' y(i) = %(rhs)s + y(i)\n'
  1127. ' end do\n'
  1128. 'end do\n'
  1129. 'end subroutine\n'
  1130. ) % {
  1131. 'rhs': '%(rhs)s',
  1132. 'iup': str(m - 4),
  1133. 'ilow': str(1 + o),
  1134. 'iup-ilow': str(m - 4 - o)
  1135. }
  1136. assert code == expected % {'rhs': 'A(i, j)*x(j)'} or\
  1137. code == expected % {'rhs': 'x(j)*A(i, j)'}
  1138. def test_output_arg_f():
  1139. from sympy.core.relational import Equality
  1140. from sympy.functions.elementary.trigonometric import (cos, sin)
  1141. x, y, z = symbols("x,y,z")
  1142. r = make_routine("foo", [Equality(y, sin(x)), cos(x)])
  1143. c = FCodeGen()
  1144. result = c.write([r], "test", header=False, empty=False)
  1145. assert result[0][0] == "test.f90"
  1146. assert result[0][1] == (
  1147. 'REAL*8 function foo(x, y)\n'
  1148. 'implicit none\n'
  1149. 'REAL*8, intent(in) :: x\n'
  1150. 'REAL*8, intent(out) :: y\n'
  1151. 'y = sin(x)\n'
  1152. 'foo = cos(x)\n'
  1153. 'end function\n'
  1154. )
  1155. def test_inline_function():
  1156. from sympy.tensor import IndexedBase, Idx
  1157. from sympy.core.symbol import symbols
  1158. n, m = symbols('n m', integer=True)
  1159. A, x, y = map(IndexedBase, 'Axy')
  1160. i = Idx('i', m)
  1161. p = FCodeGen()
  1162. func = implemented_function('func', Lambda(n, n*(n + 1)))
  1163. routine = make_routine('test_inline', Eq(y[i], func(x[i])))
  1164. code = get_string(p.dump_f95, [routine])
  1165. expected = (
  1166. 'subroutine test_inline(m, x, y)\n'
  1167. 'implicit none\n'
  1168. 'INTEGER*4, intent(in) :: m\n'
  1169. 'REAL*8, intent(in), dimension(1:m) :: x\n'
  1170. 'REAL*8, intent(out), dimension(1:m) :: y\n'
  1171. 'INTEGER*4 :: i\n'
  1172. 'do i = 1, m\n'
  1173. ' y(i) = %s*%s\n'
  1174. 'end do\n'
  1175. 'end subroutine\n'
  1176. )
  1177. args = ('x(i)', '(x(i) + 1)')
  1178. assert code == expected % args or\
  1179. code == expected % args[::-1]
  1180. def test_f_code_call_signature_wrap():
  1181. # Issue #7934
  1182. x = symbols('x:20')
  1183. expr = 0
  1184. for sym in x:
  1185. expr += sym
  1186. routine = make_routine("test", expr)
  1187. code_gen = FCodeGen()
  1188. source = get_string(code_gen.dump_f95, [routine])
  1189. expected = """\
  1190. REAL*8 function test(x0, x1, x10, x11, x12, x13, x14, x15, x16, x17, x18, &
  1191. x19, x2, x3, x4, x5, x6, x7, x8, x9)
  1192. implicit none
  1193. REAL*8, intent(in) :: x0
  1194. REAL*8, intent(in) :: x1
  1195. REAL*8, intent(in) :: x10
  1196. REAL*8, intent(in) :: x11
  1197. REAL*8, intent(in) :: x12
  1198. REAL*8, intent(in) :: x13
  1199. REAL*8, intent(in) :: x14
  1200. REAL*8, intent(in) :: x15
  1201. REAL*8, intent(in) :: x16
  1202. REAL*8, intent(in) :: x17
  1203. REAL*8, intent(in) :: x18
  1204. REAL*8, intent(in) :: x19
  1205. REAL*8, intent(in) :: x2
  1206. REAL*8, intent(in) :: x3
  1207. REAL*8, intent(in) :: x4
  1208. REAL*8, intent(in) :: x5
  1209. REAL*8, intent(in) :: x6
  1210. REAL*8, intent(in) :: x7
  1211. REAL*8, intent(in) :: x8
  1212. REAL*8, intent(in) :: x9
  1213. test = x0 + x1 + x10 + x11 + x12 + x13 + x14 + x15 + x16 + x17 + x18 + &
  1214. x19 + x2 + x3 + x4 + x5 + x6 + x7 + x8 + x9
  1215. end function
  1216. """
  1217. assert source == expected
  1218. def test_check_case():
  1219. x, X = symbols('x,X')
  1220. raises(CodeGenError, lambda: codegen(('test', x*X), 'f95', 'prefix'))
  1221. def test_check_case_false_positive():
  1222. # The upper case/lower case exception should not be triggered by SymPy
  1223. # objects that differ only because of assumptions. (It may be useful to
  1224. # have a check for that as well, but here we only want to test against
  1225. # false positives with respect to case checking.)
  1226. x1 = symbols('x')
  1227. x2 = symbols('x', my_assumption=True)
  1228. try:
  1229. codegen(('test', x1*x2), 'f95', 'prefix')
  1230. except CodeGenError as e:
  1231. if e.args[0].startswith("Fortran ignores case."):
  1232. raise AssertionError("This exception should not be raised!")
  1233. def test_c_fortran_omit_routine_name():
  1234. x, y = symbols("x,y")
  1235. name_expr = [("foo", 2*x)]
  1236. result = codegen(name_expr, "F95", header=False, empty=False)
  1237. expresult = codegen(name_expr, "F95", "foo", header=False, empty=False)
  1238. assert result[0][1] == expresult[0][1]
  1239. name_expr = ("foo", x*y)
  1240. result = codegen(name_expr, "F95", header=False, empty=False)
  1241. expresult = codegen(name_expr, "F95", "foo", header=False, empty=False)
  1242. assert result[0][1] == expresult[0][1]
  1243. name_expr = ("foo", Matrix([[x, y], [x+y, x-y]]))
  1244. result = codegen(name_expr, "C89", header=False, empty=False)
  1245. expresult = codegen(name_expr, "C89", "foo", header=False, empty=False)
  1246. assert result[0][1] == expresult[0][1]
  1247. def test_fcode_matrix_output():
  1248. x, y, z = symbols('x,y,z')
  1249. e1 = x + y
  1250. e2 = Matrix([[x, y], [z, 16]])
  1251. name_expr = ("test", (e1, e2))
  1252. result = codegen(name_expr, "f95", "test", header=False, empty=False)
  1253. source = result[0][1]
  1254. expected = (
  1255. "REAL*8 function test(x, y, z, out_%(hash)s)\n"
  1256. "implicit none\n"
  1257. "REAL*8, intent(in) :: x\n"
  1258. "REAL*8, intent(in) :: y\n"
  1259. "REAL*8, intent(in) :: z\n"
  1260. "REAL*8, intent(out), dimension(1:2, 1:2) :: out_%(hash)s\n"
  1261. "out_%(hash)s(1, 1) = x\n"
  1262. "out_%(hash)s(2, 1) = z\n"
  1263. "out_%(hash)s(1, 2) = y\n"
  1264. "out_%(hash)s(2, 2) = 16\n"
  1265. "test = x + y\n"
  1266. "end function\n"
  1267. )
  1268. # look for the magic number
  1269. a = source.splitlines()[5]
  1270. b = a.split('_')
  1271. out = b[1]
  1272. expected = expected % {'hash': out}
  1273. assert source == expected
  1274. def test_fcode_results_named_ordered():
  1275. x, y, z = symbols('x,y,z')
  1276. B, C = symbols('B,C')
  1277. A = MatrixSymbol('A', 1, 3)
  1278. expr1 = Equality(A, Matrix([[1, 2, x]]))
  1279. expr2 = Equality(C, (x + y)*z)
  1280. expr3 = Equality(B, 2*x)
  1281. name_expr = ("test", [expr1, expr2, expr3])
  1282. result = codegen(name_expr, "f95", "test", header=False, empty=False,
  1283. argument_sequence=(x, z, y, C, A, B))
  1284. source = result[0][1]
  1285. expected = (
  1286. "subroutine test(x, z, y, C, A, B)\n"
  1287. "implicit none\n"
  1288. "REAL*8, intent(in) :: x\n"
  1289. "REAL*8, intent(in) :: z\n"
  1290. "REAL*8, intent(in) :: y\n"
  1291. "REAL*8, intent(out) :: C\n"
  1292. "REAL*8, intent(out) :: B\n"
  1293. "REAL*8, intent(out), dimension(1:1, 1:3) :: A\n"
  1294. "C = z*(x + y)\n"
  1295. "A(1, 1) = 1\n"
  1296. "A(1, 2) = 2\n"
  1297. "A(1, 3) = x\n"
  1298. "B = 2*x\n"
  1299. "end subroutine\n"
  1300. )
  1301. assert source == expected
  1302. def test_fcode_matrixsymbol_slice():
  1303. A = MatrixSymbol('A', 2, 3)
  1304. B = MatrixSymbol('B', 1, 3)
  1305. C = MatrixSymbol('C', 1, 3)
  1306. D = MatrixSymbol('D', 2, 1)
  1307. name_expr = ("test", [Equality(B, A[0, :]),
  1308. Equality(C, A[1, :]),
  1309. Equality(D, A[:, 2])])
  1310. result = codegen(name_expr, "f95", "test", header=False, empty=False)
  1311. source = result[0][1]
  1312. expected = (
  1313. "subroutine test(A, B, C, D)\n"
  1314. "implicit none\n"
  1315. "REAL*8, intent(in), dimension(1:2, 1:3) :: A\n"
  1316. "REAL*8, intent(out), dimension(1:1, 1:3) :: B\n"
  1317. "REAL*8, intent(out), dimension(1:1, 1:3) :: C\n"
  1318. "REAL*8, intent(out), dimension(1:2, 1:1) :: D\n"
  1319. "B(1, 1) = A(1, 1)\n"
  1320. "B(1, 2) = A(1, 2)\n"
  1321. "B(1, 3) = A(1, 3)\n"
  1322. "C(1, 1) = A(2, 1)\n"
  1323. "C(1, 2) = A(2, 2)\n"
  1324. "C(1, 3) = A(2, 3)\n"
  1325. "D(1, 1) = A(1, 3)\n"
  1326. "D(2, 1) = A(2, 3)\n"
  1327. "end subroutine\n"
  1328. )
  1329. assert source == expected
  1330. def test_fcode_matrixsymbol_slice_autoname():
  1331. # see issue #8093
  1332. A = MatrixSymbol('A', 2, 3)
  1333. name_expr = ("test", A[:, 1])
  1334. result = codegen(name_expr, "f95", "test", header=False, empty=False)
  1335. source = result[0][1]
  1336. expected = (
  1337. "subroutine test(A, out_%(hash)s)\n"
  1338. "implicit none\n"
  1339. "REAL*8, intent(in), dimension(1:2, 1:3) :: A\n"
  1340. "REAL*8, intent(out), dimension(1:2, 1:1) :: out_%(hash)s\n"
  1341. "out_%(hash)s(1, 1) = A(1, 2)\n"
  1342. "out_%(hash)s(2, 1) = A(2, 2)\n"
  1343. "end subroutine\n"
  1344. )
  1345. # look for the magic number
  1346. a = source.splitlines()[3]
  1347. b = a.split('_')
  1348. out = b[1]
  1349. expected = expected % {'hash': out}
  1350. assert source == expected
  1351. def test_global_vars():
  1352. x, y, z, t = symbols("x y z t")
  1353. result = codegen(('f', x*y), "F95", header=False, empty=False,
  1354. global_vars=(y,))
  1355. source = result[0][1]
  1356. expected = (
  1357. "REAL*8 function f(x)\n"
  1358. "implicit none\n"
  1359. "REAL*8, intent(in) :: x\n"
  1360. "f = x*y\n"
  1361. "end function\n"
  1362. )
  1363. assert source == expected
  1364. expected = (
  1365. '#include "f.h"\n'
  1366. '#include <math.h>\n'
  1367. 'double f(double x, double y) {\n'
  1368. ' double f_result;\n'
  1369. ' f_result = x*y + z;\n'
  1370. ' return f_result;\n'
  1371. '}\n'
  1372. )
  1373. result = codegen(('f', x*y+z), "C", header=False, empty=False,
  1374. global_vars=(z, t))
  1375. source = result[0][1]
  1376. assert source == expected
  1377. def test_custom_codegen():
  1378. from sympy.printing.c import C99CodePrinter
  1379. from sympy.functions.elementary.exponential import exp
  1380. printer = C99CodePrinter(settings={'user_functions': {'exp': 'fastexp'}})
  1381. x, y = symbols('x y')
  1382. expr = exp(x + y)
  1383. # replace math.h with a different header
  1384. gen = C99CodeGen(printer=printer,
  1385. preprocessor_statements=['#include "fastexp.h"'])
  1386. expected = (
  1387. '#include "expr.h"\n'
  1388. '#include "fastexp.h"\n'
  1389. 'double expr(double x, double y) {\n'
  1390. ' double expr_result;\n'
  1391. ' expr_result = fastexp(x + y);\n'
  1392. ' return expr_result;\n'
  1393. '}\n'
  1394. )
  1395. result = codegen(('expr', expr), header=False, empty=False, code_gen=gen)
  1396. source = result[0][1]
  1397. assert source == expected
  1398. # use both math.h and an external header
  1399. gen = C99CodeGen(printer=printer)
  1400. gen.preprocessor_statements.append('#include "fastexp.h"')
  1401. expected = (
  1402. '#include "expr.h"\n'
  1403. '#include <math.h>\n'
  1404. '#include "fastexp.h"\n'
  1405. 'double expr(double x, double y) {\n'
  1406. ' double expr_result;\n'
  1407. ' expr_result = fastexp(x + y);\n'
  1408. ' return expr_result;\n'
  1409. '}\n'
  1410. )
  1411. result = codegen(('expr', expr), header=False, empty=False, code_gen=gen)
  1412. source = result[0][1]
  1413. assert source == expected
  1414. def test_c_with_printer():
  1415. #issue 13586
  1416. from sympy.printing.c import C99CodePrinter
  1417. class CustomPrinter(C99CodePrinter):
  1418. def _print_Pow(self, expr):
  1419. return "fastpow({}, {})".format(self._print(expr.base),
  1420. self._print(expr.exp))
  1421. x = symbols('x')
  1422. expr = x**3
  1423. expected =[
  1424. ("file.c",
  1425. "#include \"file.h\"\n"
  1426. "#include <math.h>\n"
  1427. "double test(double x) {\n"
  1428. " double test_result;\n"
  1429. " test_result = fastpow(x, 3);\n"
  1430. " return test_result;\n"
  1431. "}\n"),
  1432. ("file.h",
  1433. "#ifndef PROJECT__FILE__H\n"
  1434. "#define PROJECT__FILE__H\n"
  1435. "double test(double x);\n"
  1436. "#endif\n")
  1437. ]
  1438. result = codegen(("test", expr), "C","file", header=False, empty=False, printer = CustomPrinter())
  1439. assert result == expected
  1440. def test_fcode_complex():
  1441. import sympy.utilities.codegen
  1442. sympy.utilities.codegen.COMPLEX_ALLOWED = True
  1443. x = Symbol('x', real=True)
  1444. y = Symbol('y',real=True)
  1445. result = codegen(('test',x+y), 'f95', 'test', header=False, empty=False)
  1446. source = (result[0][1])
  1447. expected = (
  1448. "REAL*8 function test(x, y)\n"
  1449. "implicit none\n"
  1450. "REAL*8, intent(in) :: x\n"
  1451. "REAL*8, intent(in) :: y\n"
  1452. "test = x + y\n"
  1453. "end function\n")
  1454. assert source == expected
  1455. x = Symbol('x')
  1456. y = Symbol('y',real=True)
  1457. result = codegen(('test',x+y), 'f95', 'test', header=False, empty=False)
  1458. source = (result[0][1])
  1459. expected = (
  1460. "COMPLEX*16 function test(x, y)\n"
  1461. "implicit none\n"
  1462. "COMPLEX*16, intent(in) :: x\n"
  1463. "REAL*8, intent(in) :: y\n"
  1464. "test = x + y\n"
  1465. "end function\n"
  1466. )
  1467. assert source==expected
  1468. sympy.utilities.codegen.COMPLEX_ALLOWED = False