test_fnodes.py 6.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213
  1. import os
  2. import tempfile
  3. from sympy.core.symbol import (Symbol, symbols)
  4. from sympy.codegen.ast import (
  5. Assignment, Print, Declaration, FunctionDefinition, Return, real,
  6. FunctionCall, Variable, Element, integer
  7. )
  8. from sympy.codegen.fnodes import (
  9. allocatable, ArrayConstructor, isign, dsign, cmplx, kind, literal_dp,
  10. Program, Module, use, Subroutine, dimension, assumed_extent, ImpliedDoLoop,
  11. intent_out, size, Do, SubroutineCall, sum_, array, bind_C
  12. )
  13. from sympy.codegen.futils import render_as_module
  14. from sympy.core.expr import unchanged
  15. from sympy.external import import_module
  16. from sympy.printing.codeprinter import fcode
  17. from sympy.utilities._compilation import has_fortran, compile_run_strings, compile_link_import_strings
  18. from sympy.utilities._compilation.util import may_xfail
  19. from sympy.testing.pytest import skip, XFAIL
  # Optional third-party modules used by test_bind_C, which skips when either
  # lookup yields a falsy value (presumably None when the package is missing).
  cython, np = (import_module(pkg) for pkg in ('cython', 'numpy'))
  def test_size():
      """``size`` of a real symbol prints as the Fortran ``size`` intrinsic."""
      arg = Symbol('x', real=True)
      printed = fcode(size(arg), source_format='free')
      assert printed == 'size(x)'
  @may_xfail
  def test_size_assumed_shape():
      """Compile and run an ``rms`` function that uses ``size`` on an assumed-shape array.

      ``rms`` takes a ``dim=[':']`` (assumed-shape) real array and returns
      ``sqrt(sum(a**2)/size(a))``. The driver program prints
      ``dsqrt(7d0) - rms(x)`` for ``x = [4, 2, 2, 2]``, which should be zero.
      Skipped when no Fortran compiler is available.
      """
      if not has_fortran():
          skip("No fortran compiler found.")
      a = Symbol('a', real=True)
      body = [Return((sum_(a**2)/size(a))**.5)]
      arr = array(a, dim=[':'], intent='in')
      fd = FunctionDefinition(real, 'rms', [arr], body)
      # Render the Fortran module source once and reuse it below.
      mod_src = render_as_module([fd], 'mod_rms')
      (stdout, stderr), info = compile_run_strings([
          ('rms.f90', mod_src),
          ('main.f90', (
              'program myprog\n'
              'use mod_rms, only: rms\n'
              'real*8, dimension(4), parameter :: x = [4, 2, 2, 2]\n'
              'print *, dsqrt(7d0) - rms(x)\n'
              'end program\n'
          ))
      ], clean=True)
      assert '0.00000' in stdout
      assert stderr == ''
      assert info['exit_status'] == os.EX_OK
  @XFAIL # https://github.com/sympy/sympy/issues/20265
  @may_xfail
  def test_ImpliedDoLoop():
      """Print an allocatable integer array built with an implied-do loop.

      The array constructor is ``[-28, <i**3 for i = -3, -1, 1, 3>, 28]``, so the
      program is expected to print -28, -27, -1, 1, 27 and 28.
      Skipped when no Fortran compiler is available.
      """
      if not has_fortran():
          skip("No fortran compiler found.")
      a_sym, i = symbols('a i', integer=True)
      idl = ImpliedDoLoop(i**3, i, -3, 3, 2)
      ac = ArrayConstructor([-28, idl, 28])
      a = array(a_sym, dim=[':'], attrs=[allocatable])
      prog = Program('idlprog', [
          a.as_Declaration(),
          Assignment(a, ac),
          Print([a])
      ])
      fsrc = fcode(prog, standard=2003, source_format='free')
      (stdout, stderr), info = compile_run_strings([('main.f90', fsrc)], clean=True)
      # Compare whole whitespace-separated tokens: a plain substring test such
      # as '1' in stdout would already be satisfied by '-1'.
      printed = stdout.split()
      for numstr in '-28 -27 -1 1 27 28'.split():
          assert numstr in printed
      assert stderr == ''
      assert info['exit_status'] == os.EX_OK
  @may_xfail
  def test_Program():
      """Compile and run a program that declares ``x = 42`` and prints ``x, x + 1``.

      Skipped when no Fortran compiler is available.
      """
      sym_x = Symbol('x', real=True)
      program = Program('foo', [
          Declaration(Variable.deduced(sym_x, 42)),
          Print([sym_x, sym_x+1]),
      ])
      if not has_fortran():
          skip("No fortran compiler found.")
      source = fcode(program, standard=90)
      (stdout, stderr), info = compile_run_strings([('main.f90', source)], clean=True)
      for expected in ('42', '43'):
          assert expected in stdout
      assert stderr == ''
      assert info['exit_status'] == os.EX_OK
  @may_xfail
  def test_Module():
      """Compile a module defining ``sqr`` and a program that calls it on 42.

      Checks that both 42 and its square (1764) appear in the output and that
      the program ran successfully. Skipped when no Fortran compiler is available.
      """
      x = Symbol('x', real=True)
      v_x = Variable.deduced(x)
      sq = FunctionDefinition(real, 'sqr', [v_x], [Return(x**2)])
      mod_sq = Module('mod_sq', [], [sq])
      sq_call = FunctionCall('sqr', [42.])
      prg_sq = Program('foobar', [
          use('mod_sq', only=['sqr']),
          Print(['"Square of 42 = "', sq_call])
      ])
      if not has_fortran():
          skip("No fortran compiler found.")
      (stdout, stderr), info = compile_run_strings([
          ('mod_sq.f90', fcode(mod_sq, standard=90)),
          ('main.f90', fcode(prg_sq, standard=90))
      ], clean=True)
      assert '42' in stdout
      assert str(42**2) in stdout
      assert stderr == ''
      # An empty stderr alone does not prove the program exited successfully.
      assert info['exit_status'] == os.EX_OK
  @XFAIL # https://github.com/sympy/sympy/issues/20265
  @may_xfail
  def test_Subroutine():
      """Compile a module subroutine that fills an assumed-shape array, then call it.

      The subroutine sets ``r(i) = 1/i**2`` for ``i = 1..size(r)``. The driver
      program calls it on a 3-element array and prints the sum followed by the
      elements. Those printed values are compared numerically against
      reference values. Skipped when no Fortran compiler is available.
      """
      # Code to generate the subroutine in the example from
      # http://www.fortran90.org/src/best-practices.html#arrays
      r = Symbol('r', real=True)
      i = Symbol('i', integer=True)
      v_r = Variable.deduced(r, attrs=(dimension(assumed_extent), intent_out))
      v_i = Variable.deduced(i)
      v_n = Variable('n', integer)
      do_loop = Do([
          Assignment(Element(r, [i]), literal_dp(1)/i**2)
      ], i, 1, v_n)
      sub = Subroutine("f", [v_r], [
          Declaration(v_n),
          Declaration(v_i),
          Assignment(v_n, size(r)),
          do_loop
      ])
      x = Symbol('x', real=True)
      v_x3 = Variable.deduced(x, attrs=[dimension(3)])
      mod = Module('mymod', definitions=[sub])
      prog = Program('foo', [
          use(mod, only=[sub]),
          Declaration(v_x3),
          SubroutineCall(sub, [v_x3]),
          Print([sum_(v_x3), v_x3])
      ])
      if not has_fortran():
          skip("No fortran compiler found.")
      (stdout, stderr), info = compile_run_strings([
          ('a.f90', fcode(mod, standard=90)),
          ('b.f90', fcode(prog, standard=90))
      ], clean=True)
      ref = [1.0/k**2 for k in range(1, 4)]
      expected = [sum(ref)] + ref
      # Compare numerically. Truncated str() checks are vacuous for 1.0
      # (str(1.0)[:-3] == '') and weak for 0.25.
      # Fortran may write double-precision exponents with 'D'; normalise to 'E'.
      printed = [float(tok.replace('D', 'E').replace('d', 'e'))
                 for tok in stdout.split()]
      assert len(printed) == len(expected)
      for got, want in zip(printed, expected):
          assert abs(got - want) <= 1e-12*abs(want)
      assert stderr == ''
      assert info['exit_status'] == os.EX_OK
  def test_isign():
      """``isign`` stays unevaluated and prints as the Fortran intrinsic."""
      n = Symbol('x', integer=True)
      assert unchanged(isign, 1, n)
      printed = fcode(isign(1, n), standard=95, source_format='free')
      assert printed == 'isign(1, x)'
  def test_dsign():
      """``dsign`` stays unevaluated and prints with a double-precision literal."""
      arg = Symbol('x')
      assert unchanged(dsign, 1, arg)
      printed = fcode(dsign(literal_dp(1), arg), standard=95, source_format='free')
      assert printed == 'dsign(1d0, x)'
  def test_cmplx():
      """``cmplx`` with a numeric and a symbolic argument stays unevaluated."""
      imag_part = Symbol('x')
      assert unchanged(cmplx, 1, imag_part)
  def test_kind():
      """``kind`` applied to a plain symbol stays unevaluated."""
      assert unchanged(kind, Symbol('x'))
  def test_literal_dp():
      """A double-precision zero literal prints with the ``d0`` exponent."""
      zero = literal_dp(0)
      assert fcode(zero, source_format='free') == '0d0'
  @may_xfail
  def test_bind_C():
      """Call a ``bind(C)`` Fortran ``rms`` function from Python via a Cython wrapper.

      The Fortran function takes an explicit-size array and its length. The
      Cython shim passes a contiguous double buffer and its size. The result
      for ``[2, 4, 2, 2]`` must equal ``sqrt(7)``. Skipped unless a Fortran
      compiler, Cython and NumPy are all available.
      """
      prerequisites = (
          (has_fortran(), "No fortran compiler found."),
          (cython, "Cython not found."),
          (np, "NumPy not found."),
      )
      for available, reason in prerequisites:
          if not available:
              skip(reason)
      a = Symbol('a', real=True)
      s = Symbol('s', integer=True)
      rms_body = [Return((sum_(a**2)/s)**.5)]
      params = [array(a, dim=[s], intent='in'), s]
      fd = FunctionDefinition(real, 'rms', params, rms_body, attrs=[bind_C('rms')])
      fortran_src = render_as_module([fd], 'mod_rms')
      pyx_src = "".join([
          "#cython: language_level=3\n",
          "cdef extern double rms(double*, int*)\n",
          "def py_rms(double[::1] x):\n",
          " cdef int s = x.size\n",
          " return rms(&x[0], &s)\n",
      ])
      with tempfile.TemporaryDirectory() as folder:
          sources = [('rms.f90', fortran_src), ('_rms.pyx', pyx_src)]
          mod, info = compile_link_import_strings(sources, build_dir=folder)
          # Call while the build directory (holding the extension) still exists.
          result = mod.py_rms(np.array([2., 4., 2., 2.]))
          assert abs(result - 7**0.5) < 1e-14