# test_approximations.py (2.0 KB)

  1. import math
  2. from sympy.core.symbol import symbols
  3. from sympy.functions.elementary.exponential import exp
  4. from sympy.codegen.rewriting import optimize
  5. from sympy.codegen.approximations import SumApprox, SeriesApprox
  def test_SumApprox_trivial():
      """A summand that is negligible within the bounds is dropped.

      For |x| <= 1e-20 the term ``x`` is far below ``reltol=1e-16`` relative
      to the constant ``1``, so optimizing ``1 + x`` is expected to leave ``1``.
      """
      x = symbols('x')
      approx = SumApprox(bounds={x: (-1e-20, 1e-20)}, reltol=1e-16)
      simplified = optimize(1 + x, [approx])
      assert simplified - 1 == 0
  def test_SumApprox_monotone_terms():
      """Tighter tolerances keep progressively smaller summands.

      With x in [0, 1e-3] and y in [100, 1000] the summands are ordered
      x**2 <= 1e-6 < 1 < 1e4 <= y**2.  As ``reltol`` shrinks, the approximation
      is expected to keep only ``y**2``, then ``y**2 + 1``, then all three terms.
      The ``exp(z)`` prefactor is divided out before comparing.
      """
      x, y, z = symbols('x y z')
      prefactor = exp(z)
      expr1 = prefactor*(x**2 + y**2 + 1)
      bounds = {x: (0, 1e-3), y: (100, 1000)}
      # (relative tolerance, expected sum once the prefactor is removed)
      expected_by_reltol = [
          (1e-2, y**2),
          (1e-5, y**2 + 1),
          (1e-11, y**2 + 1 + x**2),
      ]
      for reltol, expected in expected_by_reltol:
          approx = SumApprox(bounds=bounds, reltol=reltol)
          residual = optimize(expr1, [approx])/prefactor - expected
          assert residual.simplify() == 0
  def test_SeriesApprox_trivial():
      """SeriesApprox truncates exp(x) to a Taylor polynomial for x in [-1, 1].

      Tighter relative tolerances must yield longer polynomials.  This is
      checked both for bare ``exp(x)`` and for ``exp(x)`` multiplied by the
      unrelated factor ``exp(z)``, which is divided out before comparing.  A
      ``max_order`` too low to reach the requested tolerance must leave the
      expression unchanged.
      """
      x, z = symbols('x z')
      bnds1 = {x: (-1, 1)}

      # None of the approximations depend on the loop factor, so build them once.
      series_approx_50 = SeriesApprox(bounds=bnds1, reltol=0.50)
      series_approx_10 = SeriesApprox(bounds=bnds1, reltol=0.10)
      series_approx_05 = SeriesApprox(bounds=bnds1, reltol=0.05)
      # At reltol=0.05 the expected polynomial (ref_05) needs an x**4 term, so
      # capping the order at 3 should mean no approximation is applied.
      max_ord3 = SeriesApprox(bounds=bnds1, reltol=0.05, max_order=3)

      # The references are Taylor expansions of exp about the midpoint of the
      # bounds: c == 0.0, so the constant term is exp(c) == 1.0.
      lower, upper = bnds1[x]
      c = (upper + lower)/2
      f0 = math.exp(c)
      ref_50 = f0 + x + x**2/2
      ref_10 = f0 + x + x**2/2 + x**3/6
      ref_05 = f0 + x + x**2/2 + x**3/6 + x**4/24

      for factor in [1, exp(z)]:
          expr1 = exp(x)*factor

          res_50 = optimize(expr1, [series_approx_50])
          res_10 = optimize(expr1, [series_approx_10])
          res_05 = optimize(expr1, [series_approx_05])

          assert (res_50/factor - ref_50).simplify() == 0
          assert (res_10/factor - ref_10).simplify() == 0
          assert (res_05/factor - ref_05).simplify() == 0

          assert optimize(expr1, [max_ord3]) == expr1