test_algorithms.py 4.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128
  1. import tempfile
  2. import sympy as sp
  3. from sympy.codegen.ast import Assignment
  4. from sympy.codegen.algorithms import newtons_method, newtons_method_function
  5. from sympy.codegen.fnodes import bind_C
  6. from sympy.codegen.futils import render_as_module as f_module
  7. from sympy.codegen.pyutils import render_as_module as py_module
  8. from sympy.external import import_module
  9. from sympy.printing.codeprinter import ccode
  10. from sympy.utilities._compilation import compile_link_import_strings, has_c, has_fortran
  11. from sympy.utilities._compilation.util import may_xfail
  12. from sympy.testing.pytest import skip, raises
# Optional third-party dependencies used only by the compile-and-run tests.
# The tests below check these with plain truthiness (``if not cython``), so
# import_module presumably returns a falsy value (None) when the package is
# not installed — TODO confirm against sympy.external.import_module.
cython = import_module('cython')
# Used to capture C-level stdout/stderr from the compiled extension.
wurlitzer = import_module('wurlitzer')
  15. def test_newtons_method():
  16. x, dx, atol = sp.symbols('x dx atol')
  17. expr = sp.cos(x) - x**3
  18. algo = newtons_method(expr, x, atol, dx)
  19. assert algo.has(Assignment(dx, -expr/expr.diff(x)))
  20. @may_xfail
  21. def test_newtons_method_function__ccode():
  22. x = sp.Symbol('x', real=True)
  23. expr = sp.cos(x) - x**3
  24. func = newtons_method_function(expr, x)
  25. if not cython:
  26. skip("cython not installed.")
  27. if not has_c():
  28. skip("No C compiler found.")
  29. compile_kw = {"std": 'c99'}
  30. with tempfile.TemporaryDirectory() as folder:
  31. mod, info = compile_link_import_strings([
  32. ('newton.c', ('#include <math.h>\n'
  33. '#include <stdio.h>\n') + ccode(func)),
  34. ('_newton.pyx', ("#cython: language_level={}\n".format("3") +
  35. "cdef extern double newton(double)\n"
  36. "def py_newton(x):\n"
  37. " return newton(x)\n"))
  38. ], build_dir=folder, compile_kwargs=compile_kw)
  39. assert abs(mod.py_newton(0.5) - 0.865474033102) < 1e-12
  40. @may_xfail
  41. def test_newtons_method_function__fcode():
  42. x = sp.Symbol('x', real=True)
  43. expr = sp.cos(x) - x**3
  44. func = newtons_method_function(expr, x, attrs=[bind_C(name='newton')])
  45. if not cython:
  46. skip("cython not installed.")
  47. if not has_fortran():
  48. skip("No Fortran compiler found.")
  49. f_mod = f_module([func], 'mod_newton')
  50. with tempfile.TemporaryDirectory() as folder:
  51. mod, info = compile_link_import_strings([
  52. ('newton.f90', f_mod),
  53. ('_newton.pyx', ("#cython: language_level={}\n".format("3") +
  54. "cdef extern double newton(double*)\n"
  55. "def py_newton(double x):\n"
  56. " return newton(&x)\n"))
  57. ], build_dir=folder)
  58. assert abs(mod.py_newton(0.5) - 0.865474033102) < 1e-12
  59. def test_newtons_method_function__pycode():
  60. x = sp.Symbol('x', real=True)
  61. expr = sp.cos(x) - x**3
  62. func = newtons_method_function(expr, x)
  63. py_mod = py_module(func)
  64. namespace = {}
  65. exec(py_mod, namespace, namespace)
  66. res = eval('newton(0.5)', namespace)
  67. assert abs(res - 0.865474033102) < 1e-12
  68. @may_xfail
  69. def test_newtons_method_function__ccode_parameters():
  70. args = x, A, k, p = sp.symbols('x A k p')
  71. expr = A*sp.cos(k*x) - p*x**3
  72. raises(ValueError, lambda: newtons_method_function(expr, x))
  73. use_wurlitzer = wurlitzer
  74. func = newtons_method_function(expr, x, args, debug=use_wurlitzer)
  75. if not has_c():
  76. skip("No C compiler found.")
  77. if not cython:
  78. skip("cython not installed.")
  79. compile_kw = {"std": 'c99'}
  80. with tempfile.TemporaryDirectory() as folder:
  81. mod, info = compile_link_import_strings([
  82. ('newton_par.c', ('#include <math.h>\n'
  83. '#include <stdio.h>\n') + ccode(func)),
  84. ('_newton_par.pyx', ("#cython: language_level={}\n".format("3") +
  85. "cdef extern double newton(double, double, double, double)\n"
  86. "def py_newton(x, A=1, k=1, p=1):\n"
  87. " return newton(x, A, k, p)\n"))
  88. ], compile_kwargs=compile_kw, build_dir=folder)
  89. if use_wurlitzer:
  90. with wurlitzer.pipes() as (out, err):
  91. result = mod.py_newton(0.5)
  92. else:
  93. result = mod.py_newton(0.5)
  94. assert abs(result - 0.865474033102) < 1e-12
  95. if not use_wurlitzer:
  96. skip("C-level output only tested when package 'wurlitzer' is available.")
  97. out, err = out.read(), err.read()
  98. assert err == ''
  99. assert out == """\
  100. x= 0.5 d_x= 0.61214
  101. x= 1.1121 d_x= -0.20247
  102. x= 0.90967 d_x= -0.042409
  103. x= 0.86726 d_x= -0.0017867
  104. x= 0.86548 d_x= -3.1022e-06
  105. x= 0.86547 d_x= -9.3421e-12
  106. x= 0.86547 d_x= 3.6902e-17
  107. """ # try to run tests with LC_ALL=C if this assertion fails