test_autowrap.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461
  1. # Tests that require installed backends go into
  2. # sympy/test_external/test_autowrap
  3. import os
  4. import tempfile
  5. import shutil
  6. from io import StringIO
  7. from sympy.core import symbols, Eq
  8. from sympy.utilities.autowrap import (autowrap, binary_function,
  9. CythonCodeWrapper, UfuncifyCodeWrapper, CodeWrapper)
  10. from sympy.utilities.codegen import (
  11. CCodeGen, C99CodeGen, CodeGenArgumentListError, make_routine
  12. )
  13. from sympy.testing.pytest import raises
  14. from sympy.testing.tmpfiles import TmpFileManager
  15. def get_string(dump_fn, routines, prefix="file", **kwargs):
  16. """Wrapper for dump_fn. dump_fn writes its results to a stream object and
  17. this wrapper returns the contents of that stream as a string. This
  18. auxiliary function is used by many tests below.
  19. The header and the empty lines are not generator to facilitate the
  20. testing of the output.
  21. """
  22. output = StringIO()
  23. dump_fn(routines, output, prefix, **kwargs)
  24. source = output.getvalue()
  25. output.close()
  26. return source
  27. def test_cython_wrapper_scalar_function():
  28. x, y, z = symbols('x,y,z')
  29. expr = (x + y)*z
  30. routine = make_routine("test", expr)
  31. code_gen = CythonCodeWrapper(CCodeGen())
  32. source = get_string(code_gen.dump_pyx, [routine])
  33. expected = (
  34. "cdef extern from 'file.h':\n"
  35. " double test(double x, double y, double z)\n"
  36. "\n"
  37. "def test_c(double x, double y, double z):\n"
  38. "\n"
  39. " return test(x, y, z)")
  40. assert source == expected
  41. def test_cython_wrapper_outarg():
  42. from sympy.core.relational import Equality
  43. x, y, z = symbols('x,y,z')
  44. code_gen = CythonCodeWrapper(C99CodeGen())
  45. routine = make_routine("test", Equality(z, x + y))
  46. source = get_string(code_gen.dump_pyx, [routine])
  47. expected = (
  48. "cdef extern from 'file.h':\n"
  49. " void test(double x, double y, double *z)\n"
  50. "\n"
  51. "def test_c(double x, double y):\n"
  52. "\n"
  53. " cdef double z = 0\n"
  54. " test(x, y, &z)\n"
  55. " return z")
  56. assert source == expected
  57. def test_cython_wrapper_inoutarg():
  58. from sympy.core.relational import Equality
  59. x, y, z = symbols('x,y,z')
  60. code_gen = CythonCodeWrapper(C99CodeGen())
  61. routine = make_routine("test", Equality(z, x + y + z))
  62. source = get_string(code_gen.dump_pyx, [routine])
  63. expected = (
  64. "cdef extern from 'file.h':\n"
  65. " void test(double x, double y, double *z)\n"
  66. "\n"
  67. "def test_c(double x, double y, double z):\n"
  68. "\n"
  69. " test(x, y, &z)\n"
  70. " return z")
  71. assert source == expected
  72. def test_cython_wrapper_compile_flags():
  73. from sympy.core.relational import Equality
  74. x, y, z = symbols('x,y,z')
  75. routine = make_routine("test", Equality(z, x + y))
  76. code_gen = CythonCodeWrapper(CCodeGen())
  77. expected = """\
  78. from setuptools import setup
  79. from setuptools import Extension
  80. from Cython.Build import cythonize
  81. cy_opts = {'compiler_directives': {'language_level': '3'}}
  82. ext_mods = [Extension(
  83. 'wrapper_module_%(num)s', ['wrapper_module_%(num)s.pyx', 'wrapped_code_%(num)s.c'],
  84. include_dirs=[],
  85. library_dirs=[],
  86. libraries=[],
  87. extra_compile_args=['-std=c99'],
  88. extra_link_args=[]
  89. )]
  90. setup(ext_modules=cythonize(ext_mods, **cy_opts))
  91. """ % {'num': CodeWrapper._module_counter}
  92. temp_dir = tempfile.mkdtemp()
  93. TmpFileManager.tmp_folder(temp_dir)
  94. setup_file_path = os.path.join(temp_dir, 'setup.py')
  95. code_gen._prepare_files(routine, build_dir=temp_dir)
  96. with open(setup_file_path) as f:
  97. setup_text = f.read()
  98. assert setup_text == expected
  99. code_gen = CythonCodeWrapper(CCodeGen(),
  100. include_dirs=['/usr/local/include', '/opt/booger/include'],
  101. library_dirs=['/user/local/lib'],
  102. libraries=['thelib', 'nilib'],
  103. extra_compile_args=['-slow-math'],
  104. extra_link_args=['-lswamp', '-ltrident'],
  105. cythonize_options={'compiler_directives': {'boundscheck': False}}
  106. )
  107. expected = """\
  108. from setuptools import setup
  109. from setuptools import Extension
  110. from Cython.Build import cythonize
  111. cy_opts = {'compiler_directives': {'boundscheck': False}}
  112. ext_mods = [Extension(
  113. 'wrapper_module_%(num)s', ['wrapper_module_%(num)s.pyx', 'wrapped_code_%(num)s.c'],
  114. include_dirs=['/usr/local/include', '/opt/booger/include'],
  115. library_dirs=['/user/local/lib'],
  116. libraries=['thelib', 'nilib'],
  117. extra_compile_args=['-slow-math', '-std=c99'],
  118. extra_link_args=['-lswamp', '-ltrident']
  119. )]
  120. setup(ext_modules=cythonize(ext_mods, **cy_opts))
  121. """ % {'num': CodeWrapper._module_counter}
  122. code_gen._prepare_files(routine, build_dir=temp_dir)
  123. with open(setup_file_path) as f:
  124. setup_text = f.read()
  125. assert setup_text == expected
  126. expected = """\
  127. from setuptools import setup
  128. from setuptools import Extension
  129. from Cython.Build import cythonize
  130. cy_opts = {'compiler_directives': {'boundscheck': False}}
  131. import numpy as np
  132. ext_mods = [Extension(
  133. 'wrapper_module_%(num)s', ['wrapper_module_%(num)s.pyx', 'wrapped_code_%(num)s.c'],
  134. include_dirs=['/usr/local/include', '/opt/booger/include', np.get_include()],
  135. library_dirs=['/user/local/lib'],
  136. libraries=['thelib', 'nilib'],
  137. extra_compile_args=['-slow-math', '-std=c99'],
  138. extra_link_args=['-lswamp', '-ltrident']
  139. )]
  140. setup(ext_modules=cythonize(ext_mods, **cy_opts))
  141. """ % {'num': CodeWrapper._module_counter}
  142. code_gen._need_numpy = True
  143. code_gen._prepare_files(routine, build_dir=temp_dir)
  144. with open(setup_file_path) as f:
  145. setup_text = f.read()
  146. assert setup_text == expected
  147. TmpFileManager.cleanup()
  148. def test_cython_wrapper_unique_dummyvars():
  149. from sympy.core.relational import Equality
  150. from sympy.core.symbol import Dummy
  151. x, y, z = Dummy('x'), Dummy('y'), Dummy('z')
  152. x_id, y_id, z_id = [str(d.dummy_index) for d in [x, y, z]]
  153. expr = Equality(z, x + y)
  154. routine = make_routine("test", expr)
  155. code_gen = CythonCodeWrapper(CCodeGen())
  156. source = get_string(code_gen.dump_pyx, [routine])
  157. expected_template = (
  158. "cdef extern from 'file.h':\n"
  159. " void test(double x_{x_id}, double y_{y_id}, double *z_{z_id})\n"
  160. "\n"
  161. "def test_c(double x_{x_id}, double y_{y_id}):\n"
  162. "\n"
  163. " cdef double z_{z_id} = 0\n"
  164. " test(x_{x_id}, y_{y_id}, &z_{z_id})\n"
  165. " return z_{z_id}")
  166. expected = expected_template.format(x_id=x_id, y_id=y_id, z_id=z_id)
  167. assert source == expected
  168. def test_autowrap_dummy():
  169. x, y, z = symbols('x y z')
  170. # Uses DummyWrapper to test that codegen works as expected
  171. f = autowrap(x + y, backend='dummy')
  172. assert f() == str(x + y)
  173. assert f.args == "x, y"
  174. assert f.returns == "nameless"
  175. f = autowrap(Eq(z, x + y), backend='dummy')
  176. assert f() == str(x + y)
  177. assert f.args == "x, y"
  178. assert f.returns == "z"
  179. f = autowrap(Eq(z, x + y + z), backend='dummy')
  180. assert f() == str(x + y + z)
  181. assert f.args == "x, y, z"
  182. assert f.returns == "z"
  183. def test_autowrap_args():
  184. x, y, z = symbols('x y z')
  185. raises(CodeGenArgumentListError, lambda: autowrap(Eq(z, x + y),
  186. backend='dummy', args=[x]))
  187. f = autowrap(Eq(z, x + y), backend='dummy', args=[y, x])
  188. assert f() == str(x + y)
  189. assert f.args == "y, x"
  190. assert f.returns == "z"
  191. raises(CodeGenArgumentListError, lambda: autowrap(Eq(z, x + y + z),
  192. backend='dummy', args=[x, y]))
  193. f = autowrap(Eq(z, x + y + z), backend='dummy', args=[y, x, z])
  194. assert f() == str(x + y + z)
  195. assert f.args == "y, x, z"
  196. assert f.returns == "z"
  197. f = autowrap(Eq(z, x + y + z), backend='dummy', args=(y, x, z))
  198. assert f() == str(x + y + z)
  199. assert f.args == "y, x, z"
  200. assert f.returns == "z"
  201. def test_autowrap_store_files():
  202. x, y = symbols('x y')
  203. tmp = tempfile.mkdtemp()
  204. TmpFileManager.tmp_folder(tmp)
  205. f = autowrap(x + y, backend='dummy', tempdir=tmp)
  206. assert f() == str(x + y)
  207. assert os.access(tmp, os.F_OK)
  208. TmpFileManager.cleanup()
  209. def test_autowrap_store_files_issue_gh12939():
  210. x, y = symbols('x y')
  211. tmp = './tmp'
  212. saved_cwd = os.getcwd()
  213. temp_cwd = tempfile.mkdtemp()
  214. try:
  215. os.chdir(temp_cwd)
  216. f = autowrap(x + y, backend='dummy', tempdir=tmp)
  217. assert f() == str(x + y)
  218. assert os.access(tmp, os.F_OK)
  219. finally:
  220. os.chdir(saved_cwd)
  221. shutil.rmtree(temp_cwd)
  222. def test_binary_function():
  223. x, y = symbols('x y')
  224. f = binary_function('f', x + y, backend='dummy')
  225. assert f._imp_() == str(x + y)
  226. def test_ufuncify_source():
  227. x, y, z = symbols('x,y,z')
  228. code_wrapper = UfuncifyCodeWrapper(C99CodeGen("ufuncify"))
  229. routine = make_routine("test", x + y + z)
  230. source = get_string(code_wrapper.dump_c, [routine])
  231. expected = """\
  232. #include "Python.h"
  233. #include "math.h"
  234. #include "numpy/ndarraytypes.h"
  235. #include "numpy/ufuncobject.h"
  236. #include "numpy/halffloat.h"
  237. #include "file.h"
  238. static PyMethodDef wrapper_module_%(num)sMethods[] = {
  239. {NULL, NULL, 0, NULL}
  240. };
  241. static void test_ufunc(char **args, npy_intp *dimensions, npy_intp* steps, void* data)
  242. {
  243. npy_intp i;
  244. npy_intp n = dimensions[0];
  245. char *in0 = args[0];
  246. char *in1 = args[1];
  247. char *in2 = args[2];
  248. char *out0 = args[3];
  249. npy_intp in0_step = steps[0];
  250. npy_intp in1_step = steps[1];
  251. npy_intp in2_step = steps[2];
  252. npy_intp out0_step = steps[3];
  253. for (i = 0; i < n; i++) {
  254. *((double *)out0) = test(*(double *)in0, *(double *)in1, *(double *)in2);
  255. in0 += in0_step;
  256. in1 += in1_step;
  257. in2 += in2_step;
  258. out0 += out0_step;
  259. }
  260. }
  261. PyUFuncGenericFunction test_funcs[1] = {&test_ufunc};
  262. static char test_types[4] = {NPY_DOUBLE, NPY_DOUBLE, NPY_DOUBLE, NPY_DOUBLE};
  263. static void *test_data[1] = {NULL};
  264. #if PY_VERSION_HEX >= 0x03000000
  265. static struct PyModuleDef moduledef = {
  266. PyModuleDef_HEAD_INIT,
  267. "wrapper_module_%(num)s",
  268. NULL,
  269. -1,
  270. wrapper_module_%(num)sMethods,
  271. NULL,
  272. NULL,
  273. NULL,
  274. NULL
  275. };
  276. PyMODINIT_FUNC PyInit_wrapper_module_%(num)s(void)
  277. {
  278. PyObject *m, *d;
  279. PyObject *ufunc0;
  280. m = PyModule_Create(&moduledef);
  281. if (!m) {
  282. return NULL;
  283. }
  284. import_array();
  285. import_umath();
  286. d = PyModule_GetDict(m);
  287. ufunc0 = PyUFunc_FromFuncAndData(test_funcs, test_data, test_types, 1, 3, 1,
  288. PyUFunc_None, "wrapper_module_%(num)s", "Created in SymPy with Ufuncify", 0);
  289. PyDict_SetItemString(d, "test", ufunc0);
  290. Py_DECREF(ufunc0);
  291. return m;
  292. }
  293. #else
  294. PyMODINIT_FUNC initwrapper_module_%(num)s(void)
  295. {
  296. PyObject *m, *d;
  297. PyObject *ufunc0;
  298. m = Py_InitModule("wrapper_module_%(num)s", wrapper_module_%(num)sMethods);
  299. if (m == NULL) {
  300. return;
  301. }
  302. import_array();
  303. import_umath();
  304. d = PyModule_GetDict(m);
  305. ufunc0 = PyUFunc_FromFuncAndData(test_funcs, test_data, test_types, 1, 3, 1,
  306. PyUFunc_None, "wrapper_module_%(num)s", "Created in SymPy with Ufuncify", 0);
  307. PyDict_SetItemString(d, "test", ufunc0);
  308. Py_DECREF(ufunc0);
  309. }
  310. #endif""" % {'num': CodeWrapper._module_counter}
  311. assert source == expected
  312. def test_ufuncify_source_multioutput():
  313. x, y, z = symbols('x,y,z')
  314. var_symbols = (x, y, z)
  315. expr = x + y**3 + 10*z**2
  316. code_wrapper = UfuncifyCodeWrapper(C99CodeGen("ufuncify"))
  317. routines = [make_routine("func{}".format(i), expr.diff(var_symbols[i]), var_symbols) for i in range(len(var_symbols))]
  318. source = get_string(code_wrapper.dump_c, routines, funcname='multitest')
  319. expected = """\
  320. #include "Python.h"
  321. #include "math.h"
  322. #include "numpy/ndarraytypes.h"
  323. #include "numpy/ufuncobject.h"
  324. #include "numpy/halffloat.h"
  325. #include "file.h"
  326. static PyMethodDef wrapper_module_%(num)sMethods[] = {
  327. {NULL, NULL, 0, NULL}
  328. };
  329. static void multitest_ufunc(char **args, npy_intp *dimensions, npy_intp* steps, void* data)
  330. {
  331. npy_intp i;
  332. npy_intp n = dimensions[0];
  333. char *in0 = args[0];
  334. char *in1 = args[1];
  335. char *in2 = args[2];
  336. char *out0 = args[3];
  337. char *out1 = args[4];
  338. char *out2 = args[5];
  339. npy_intp in0_step = steps[0];
  340. npy_intp in1_step = steps[1];
  341. npy_intp in2_step = steps[2];
  342. npy_intp out0_step = steps[3];
  343. npy_intp out1_step = steps[4];
  344. npy_intp out2_step = steps[5];
  345. for (i = 0; i < n; i++) {
  346. *((double *)out0) = func0(*(double *)in0, *(double *)in1, *(double *)in2);
  347. *((double *)out1) = func1(*(double *)in0, *(double *)in1, *(double *)in2);
  348. *((double *)out2) = func2(*(double *)in0, *(double *)in1, *(double *)in2);
  349. in0 += in0_step;
  350. in1 += in1_step;
  351. in2 += in2_step;
  352. out0 += out0_step;
  353. out1 += out1_step;
  354. out2 += out2_step;
  355. }
  356. }
  357. PyUFuncGenericFunction multitest_funcs[1] = {&multitest_ufunc};
  358. static char multitest_types[6] = {NPY_DOUBLE, NPY_DOUBLE, NPY_DOUBLE, NPY_DOUBLE, NPY_DOUBLE, NPY_DOUBLE};
  359. static void *multitest_data[1] = {NULL};
  360. #if PY_VERSION_HEX >= 0x03000000
  361. static struct PyModuleDef moduledef = {
  362. PyModuleDef_HEAD_INIT,
  363. "wrapper_module_%(num)s",
  364. NULL,
  365. -1,
  366. wrapper_module_%(num)sMethods,
  367. NULL,
  368. NULL,
  369. NULL,
  370. NULL
  371. };
  372. PyMODINIT_FUNC PyInit_wrapper_module_%(num)s(void)
  373. {
  374. PyObject *m, *d;
  375. PyObject *ufunc0;
  376. m = PyModule_Create(&moduledef);
  377. if (!m) {
  378. return NULL;
  379. }
  380. import_array();
  381. import_umath();
  382. d = PyModule_GetDict(m);
  383. ufunc0 = PyUFunc_FromFuncAndData(multitest_funcs, multitest_data, multitest_types, 1, 3, 3,
  384. PyUFunc_None, "wrapper_module_%(num)s", "Created in SymPy with Ufuncify", 0);
  385. PyDict_SetItemString(d, "multitest", ufunc0);
  386. Py_DECREF(ufunc0);
  387. return m;
  388. }
  389. #else
  390. PyMODINIT_FUNC initwrapper_module_%(num)s(void)
  391. {
  392. PyObject *m, *d;
  393. PyObject *ufunc0;
  394. m = Py_InitModule("wrapper_module_%(num)s", wrapper_module_%(num)sMethods);
  395. if (m == NULL) {
  396. return;
  397. }
  398. import_array();
  399. import_umath();
  400. d = PyModule_GetDict(m);
  401. ufunc0 = PyUFunc_FromFuncAndData(multitest_funcs, multitest_data, multitest_types, 1, 3, 3,
  402. PyUFunc_None, "wrapper_module_%(num)s", "Created in SymPy with Ufuncify", 0);
  403. PyDict_SetItemString(d, "multitest", ufunc0);
  404. Py_DECREF(ufunc0);
  405. }
  406. #endif""" % {'num': CodeWrapper._module_counter}
  407. assert source == expected