# test_compilation.py (1.7 KB)
  1. import shutil
  2. from sympy.external import import_module
  3. from sympy.testing.pytest import skip
  4. from sympy.utilities._compilation.compilation import compile_link_import_strings
  5. numpy = import_module('numpy')
  6. cython = import_module('cython')
  7. _sources1 = [
  8. ('sigmoid.c', r"""
  9. #include <math.h>
  10. void sigmoid(int n, const double * const restrict in,
  11. double * const restrict out, double lim){
  12. for (int i=0; i<n; ++i){
  13. const double x = in[i];
  14. out[i] = x*pow(pow(x/lim, 8)+1, -1./8.);
  15. }
  16. }
  17. """),
  18. ('_sigmoid.pyx', r"""
  19. import numpy as np
  20. cimport numpy as cnp
  21. cdef extern void c_sigmoid "sigmoid" (int, const double * const,
  22. double * const, double)
  23. def sigmoid(double [:] inp, double lim=350.0):
  24. cdef cnp.ndarray[cnp.float64_t, ndim=1] out = np.empty(
  25. inp.size, dtype=np.float64)
  26. c_sigmoid(inp.size, &inp[0], &out[0], lim)
  27. return out
  28. """)
  29. ]
  30. def npy(data, lim=350.0):
  31. return data/((data/lim)**8+1)**(1/8.)
  32. def test_compile_link_import_strings():
  33. if not numpy:
  34. skip("numpy not installed.")
  35. if not cython:
  36. skip("cython not installed.")
  37. from sympy.utilities._compilation import has_c
  38. if not has_c():
  39. skip("No C compiler found.")
  40. compile_kw = {"std": 'c99', "include_dirs": [numpy.get_include()]}
  41. info = None
  42. try:
  43. mod, info = compile_link_import_strings(_sources1, compile_kwargs=compile_kw)
  44. data = numpy.random.random(1024*1024*8) # 64 MB of RAM needed..
  45. res_mod = mod.sigmoid(data)
  46. res_npy = npy(data)
  47. assert numpy.allclose(res_mod, res_npy)
  48. finally:
  49. if info and info['build_dir']:
  50. shutil.rmtree(info['build_dir'])