test_autowrap.py 9.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313
  1. import sympy
  2. import tempfile
  3. import os
  4. from sympy.core.mod import Mod
  5. from sympy.core.relational import Eq
  6. from sympy.core.symbol import symbols
  7. from sympy.external import import_module
  8. from sympy.tensor import IndexedBase, Idx
  9. from sympy.utilities.autowrap import autowrap, ufuncify, CodeWrapError
  10. from sympy.testing.pytest import skip
  11. numpy = import_module('numpy', min_module_version='1.6.1')
  12. Cython = import_module('Cython', min_module_version='0.15.1')
  13. f2py = import_module('numpy.f2py', import_kwargs={'fromlist': ['f2py']})
  14. f2pyworks = False
  15. if f2py:
  16. try:
  17. autowrap(symbols('x'), 'f95', 'f2py')
  18. except (CodeWrapError, ImportError, OSError):
  19. f2pyworks = False
  20. else:
  21. f2pyworks = True
  22. a, b, c = symbols('a b c')
  23. n, m, d = symbols('n m d', integer=True)
  24. A, B, C = symbols('A B C', cls=IndexedBase)
  25. i = Idx('i', m)
  26. j = Idx('j', n)
  27. k = Idx('k', d)
  28. def has_module(module):
  29. """
  30. Return True if module exists, otherwise run skip().
  31. module should be a string.
  32. """
  33. # To give a string of the module name to skip(), this function takes a
  34. # string. So we don't waste time running import_module() more than once,
  35. # just map the three modules tested here in this dict.
  36. modnames = {'numpy': numpy, 'Cython': Cython, 'f2py': f2py}
  37. if modnames[module]:
  38. if module == 'f2py' and not f2pyworks:
  39. skip("Couldn't run f2py.")
  40. return True
  41. skip("Couldn't import %s." % module)
  42. #
  43. # test runners used by several language-backend combinations
  44. #
  45. def runtest_autowrap_twice(language, backend):
  46. f = autowrap((((a + b)/c)**5).expand(), language, backend)
  47. g = autowrap((((a + b)/c)**4).expand(), language, backend)
  48. # check that autowrap updates the module name. Else, g gives the same as f
  49. assert f(1, -2, 1) == -1.0
  50. assert g(1, -2, 1) == 1.0
  51. def runtest_autowrap_trace(language, backend):
  52. has_module('numpy')
  53. trace = autowrap(A[i, i], language, backend)
  54. assert trace(numpy.eye(100)) == 100
  55. def runtest_autowrap_matrix_vector(language, backend):
  56. has_module('numpy')
  57. x, y = symbols('x y', cls=IndexedBase)
  58. expr = Eq(y[i], A[i, j]*x[j])
  59. mv = autowrap(expr, language, backend)
  60. # compare with numpy's dot product
  61. M = numpy.random.rand(10, 20)
  62. x = numpy.random.rand(20)
  63. y = numpy.dot(M, x)
  64. assert numpy.sum(numpy.abs(y - mv(M, x))) < 1e-13
  65. def runtest_autowrap_matrix_matrix(language, backend):
  66. has_module('numpy')
  67. expr = Eq(C[i, j], A[i, k]*B[k, j])
  68. matmat = autowrap(expr, language, backend)
  69. # compare with numpy's dot product
  70. M1 = numpy.random.rand(10, 20)
  71. M2 = numpy.random.rand(20, 15)
  72. M3 = numpy.dot(M1, M2)
  73. assert numpy.sum(numpy.abs(M3 - matmat(M1, M2))) < 1e-13
  74. def runtest_ufuncify(language, backend):
  75. has_module('numpy')
  76. a, b, c = symbols('a b c')
  77. fabc = ufuncify([a, b, c], a*b + c, backend=backend)
  78. facb = ufuncify([a, c, b], a*b + c, backend=backend)
  79. grid = numpy.linspace(-2, 2, 50)
  80. b = numpy.linspace(-5, 4, 50)
  81. c = numpy.linspace(-1, 1, 50)
  82. expected = grid*b + c
  83. numpy.testing.assert_allclose(fabc(grid, b, c), expected)
  84. numpy.testing.assert_allclose(facb(grid, c, b), expected)
  85. def runtest_issue_10274(language, backend):
  86. expr = (a - b + c)**(13)
  87. tmp = tempfile.mkdtemp()
  88. f = autowrap(expr, language, backend, tempdir=tmp,
  89. helpers=('helper', a - b + c, (a, b, c)))
  90. assert f(1, 1, 1) == 1
  91. for file in os.listdir(tmp):
  92. if not (file.startswith("wrapped_code_") and file.endswith(".c")):
  93. continue
  94. with open(tmp + '/' + file) as fil:
  95. lines = fil.readlines()
  96. assert lines[0] == "/******************************************************************************\n"
  97. assert "Code generated with SymPy " + sympy.__version__ in lines[1]
  98. assert lines[2:] == [
  99. " * *\n",
  100. " * See http://www.sympy.org/ for more information. *\n",
  101. " * *\n",
  102. " * This file is part of 'autowrap' *\n",
  103. " ******************************************************************************/\n",
  104. "#include " + '"' + file[:-1]+ 'h"' + "\n",
  105. "#include <math.h>\n",
  106. "\n",
  107. "double helper(double a, double b, double c) {\n",
  108. "\n",
  109. " double helper_result;\n",
  110. " helper_result = a - b + c;\n",
  111. " return helper_result;\n",
  112. "\n",
  113. "}\n",
  114. "\n",
  115. "double autofunc(double a, double b, double c) {\n",
  116. "\n",
  117. " double autofunc_result;\n",
  118. " autofunc_result = pow(helper(a, b, c), 13);\n",
  119. " return autofunc_result;\n",
  120. "\n",
  121. "}\n",
  122. ]
  123. def runtest_issue_15337(language, backend):
  124. has_module('numpy')
  125. # NOTE : autowrap was originally designed to only accept an iterable for
  126. # the kwarg "helpers", but in issue 10274 the user mistakenly thought that
  127. # if there was only a single helper it did not need to be passed via an
  128. # iterable that wrapped the helper tuple. There were no tests for this
  129. # behavior so when the code was changed to accept a single tuple it broke
  130. # the original behavior. These tests below ensure that both now work.
  131. a, b, c, d, e = symbols('a, b, c, d, e')
  132. expr = (a - b + c - d + e)**13
  133. exp_res = (1. - 2. + 3. - 4. + 5.)**13
  134. f = autowrap(expr, language, backend, args=(a, b, c, d, e),
  135. helpers=('f1', a - b + c, (a, b, c)))
  136. numpy.testing.assert_allclose(f(1, 2, 3, 4, 5), exp_res)
  137. f = autowrap(expr, language, backend, args=(a, b, c, d, e),
  138. helpers=(('f1', a - b, (a, b)), ('f2', c - d, (c, d))))
  139. numpy.testing.assert_allclose(f(1, 2, 3, 4, 5), exp_res)
  140. def test_issue_15230():
  141. has_module('f2py')
  142. x, y = symbols('x, y')
  143. expr = Mod(x, 3.0) - Mod(y, -2.0)
  144. f = autowrap(expr, args=[x, y], language='F95')
  145. exp_res = float(expr.xreplace({x: 3.5, y: 2.7}).evalf())
  146. assert abs(f(3.5, 2.7) - exp_res) < 1e-14
  147. x, y = symbols('x, y', integer=True)
  148. expr = Mod(x, 3) - Mod(y, -2)
  149. f = autowrap(expr, args=[x, y], language='F95')
  150. assert f(3, 2) == expr.xreplace({x: 3, y: 2})
  151. #
  152. # tests of language-backend combinations
  153. #
  154. # f2py
  155. def test_wrap_twice_f95_f2py():
  156. has_module('f2py')
  157. runtest_autowrap_twice('f95', 'f2py')
  158. def test_autowrap_trace_f95_f2py():
  159. has_module('f2py')
  160. runtest_autowrap_trace('f95', 'f2py')
  161. def test_autowrap_matrix_vector_f95_f2py():
  162. has_module('f2py')
  163. runtest_autowrap_matrix_vector('f95', 'f2py')
  164. def test_autowrap_matrix_matrix_f95_f2py():
  165. has_module('f2py')
  166. runtest_autowrap_matrix_matrix('f95', 'f2py')
  167. def test_ufuncify_f95_f2py():
  168. has_module('f2py')
  169. runtest_ufuncify('f95', 'f2py')
  170. def test_issue_15337_f95_f2py():
  171. has_module('f2py')
  172. runtest_issue_15337('f95', 'f2py')
  173. # Cython
  174. def test_wrap_twice_c_cython():
  175. has_module('Cython')
  176. runtest_autowrap_twice('C', 'cython')
  177. def test_autowrap_trace_C_Cython():
  178. has_module('Cython')
  179. runtest_autowrap_trace('C99', 'cython')
  180. def test_autowrap_matrix_vector_C_cython():
  181. has_module('Cython')
  182. runtest_autowrap_matrix_vector('C99', 'cython')
  183. def test_autowrap_matrix_matrix_C_cython():
  184. has_module('Cython')
  185. runtest_autowrap_matrix_matrix('C99', 'cython')
  186. def test_ufuncify_C_Cython():
  187. has_module('Cython')
  188. runtest_ufuncify('C99', 'cython')
  189. def test_issue_10274_C_cython():
  190. has_module('Cython')
  191. runtest_issue_10274('C89', 'cython')
  192. def test_issue_15337_C_cython():
  193. has_module('Cython')
  194. runtest_issue_15337('C89', 'cython')
  195. def test_autowrap_custom_printer():
  196. has_module('Cython')
  197. from sympy.core.numbers import pi
  198. from sympy.utilities.codegen import C99CodeGen
  199. from sympy.printing.c import C99CodePrinter
  200. class PiPrinter(C99CodePrinter):
  201. def _print_Pi(self, expr):
  202. return "S_PI"
  203. printer = PiPrinter()
  204. gen = C99CodeGen(printer=printer)
  205. gen.preprocessor_statements.append('#include "shortpi.h"')
  206. expr = pi * a
  207. expected = (
  208. '#include "%s"\n'
  209. '#include <math.h>\n'
  210. '#include "shortpi.h"\n'
  211. '\n'
  212. 'double autofunc(double a) {\n'
  213. '\n'
  214. ' double autofunc_result;\n'
  215. ' autofunc_result = S_PI*a;\n'
  216. ' return autofunc_result;\n'
  217. '\n'
  218. '}\n'
  219. )
  220. tmpdir = tempfile.mkdtemp()
  221. # write a trivial header file to use in the generated code
  222. with open(os.path.join(tmpdir, 'shortpi.h'), 'w') as f:
  223. f.write('#define S_PI 3.14')
  224. func = autowrap(expr, backend='cython', tempdir=tmpdir, code_gen=gen)
  225. assert func(4.2) == 3.14 * 4.2
  226. # check that the generated code is correct
  227. for filename in os.listdir(tmpdir):
  228. if filename.startswith('wrapped_code') and filename.endswith('.c'):
  229. with open(os.path.join(tmpdir, filename)) as f:
  230. lines = f.readlines()
  231. expected = expected % filename.replace('.c', '.h')
  232. assert ''.join(lines[7:]) == expected
  233. # Numpy
  234. def test_ufuncify_numpy():
  235. # This test doesn't use Cython, but if Cython works, then there is a valid
  236. # C compiler, which is needed.
  237. has_module('Cython')
  238. runtest_ufuncify('C99', 'numpy')