test_numeric.py 4.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139
  1. from sympy.core.function import nfloat
  2. from sympy.core.numbers import (Float, I, Rational, pi)
  3. from sympy.core.relational import Eq
  4. from sympy.core.symbol import (Symbol, symbols)
  5. from sympy.functions.elementary.miscellaneous import sqrt
  6. from sympy.functions.elementary.piecewise import Piecewise
  7. from sympy.functions.elementary.trigonometric import sin
  8. from sympy.integrals.integrals import Integral
  9. from sympy.matrices.dense import Matrix
  10. from mpmath import mnorm, mpf
  11. from sympy.solvers import nsolve
  12. from sympy.utilities.lambdify import lambdify
  13. from sympy.testing.pytest import raises, XFAIL
  14. from sympy.utilities.decorator import conserve_mpmath_dps
  15. @XFAIL
  16. def test_nsolve_fail():
  17. x = symbols('x')
  18. # Sometimes it is better to use the numerator (issue 4829)
  19. # but sometimes it is not (issue 11768) so leave this to
  20. # the discretion of the user
  21. ans = nsolve(x**2/(1 - x)/(1 - 2*x)**2 - 100, x, 0)
  22. assert ans > 0.46 and ans < 0.47
  23. def test_nsolve_denominator():
  24. x = symbols('x')
  25. # Test that nsolve uses the full expression (numerator and denominator).
  26. ans = nsolve((x**2 + 3*x + 2)/(x + 2), -2.1)
  27. # The root -2 was divided out, so make sure we don't find it.
  28. assert ans == -1.0
  29. def test_nsolve():
  30. # onedimensional
  31. x = Symbol('x')
  32. assert nsolve(sin(x), 2) - pi.evalf() < 1e-15
  33. assert nsolve(Eq(2*x, 2), x, -10) == nsolve(2*x - 2, -10)
  34. # Testing checks on number of inputs
  35. raises(TypeError, lambda: nsolve(Eq(2*x, 2)))
  36. raises(TypeError, lambda: nsolve(Eq(2*x, 2), x, 1, 2))
  37. # multidimensional
  38. x1 = Symbol('x1')
  39. x2 = Symbol('x2')
  40. f1 = 3 * x1**2 - 2 * x2**2 - 1
  41. f2 = x1**2 - 2 * x1 + x2**2 + 2 * x2 - 8
  42. f = Matrix((f1, f2)).T
  43. F = lambdify((x1, x2), f.T, modules='mpmath')
  44. for x0 in [(-1, 1), (1, -2), (4, 4), (-4, -4)]:
  45. x = nsolve(f, (x1, x2), x0, tol=1.e-8)
  46. assert mnorm(F(*x), 1) <= 1.e-10
  47. # The Chinese mathematician Zhu Shijie was the very first to solve this
  48. # nonlinear system 700 years ago (z was added to make it 3-dimensional)
  49. x = Symbol('x')
  50. y = Symbol('y')
  51. z = Symbol('z')
  52. f1 = -x + 2*y
  53. f2 = (x**2 + x*(y**2 - 2) - 4*y) / (x + 4)
  54. f3 = sqrt(x**2 + y**2)*z
  55. f = Matrix((f1, f2, f3)).T
  56. F = lambdify((x, y, z), f.T, modules='mpmath')
  57. def getroot(x0):
  58. root = nsolve(f, (x, y, z), x0)
  59. assert mnorm(F(*root), 1) <= 1.e-8
  60. return root
  61. assert list(map(round, getroot((1, 1, 1)))) == [2, 1, 0]
  62. assert nsolve([Eq(
  63. f1, 0), Eq(f2, 0), Eq(f3, 0)], [x, y, z], (1, 1, 1)) # just see that it works
  64. a = Symbol('a')
  65. assert abs(nsolve(1/(0.001 + a)**3 - 6/(0.9 - a)**3, a, 0.3) -
  66. mpf('0.31883011387318591')) < 1e-15
  67. def test_issue_6408():
  68. x = Symbol('x')
  69. assert nsolve(Piecewise((x, x < 1), (x**2, True)), x, 2) == 0.0
  70. def test_issue_6408_integral():
  71. x, y = symbols('x y')
  72. assert nsolve(Integral(x*y, (x, 0, 5)), y, 2) == 0.0
  73. @conserve_mpmath_dps
  74. def test_increased_dps():
  75. # Issue 8564
  76. import mpmath
  77. mpmath.mp.dps = 128
  78. x = Symbol('x')
  79. e1 = x**2 - pi
  80. q = nsolve(e1, x, 3.0)
  81. assert abs(sqrt(pi).evalf(128) - q) < 1e-128
  82. def test_nsolve_precision():
  83. x, y = symbols('x y')
  84. sol = nsolve(x**2 - pi, x, 3, prec=128)
  85. assert abs(sqrt(pi).evalf(128) - sol) < 1e-128
  86. assert isinstance(sol, Float)
  87. sols = nsolve((y**2 - x, x**2 - pi), (x, y), (3, 3), prec=128)
  88. assert isinstance(sols, Matrix)
  89. assert sols.shape == (2, 1)
  90. assert abs(sqrt(pi).evalf(128) - sols[0]) < 1e-128
  91. assert abs(sqrt(sqrt(pi)).evalf(128) - sols[1]) < 1e-128
  92. assert all(isinstance(i, Float) for i in sols)
  93. def test_nsolve_complex():
  94. x, y = symbols('x y')
  95. assert nsolve(x**2 + 2, 1j) == sqrt(2.)*I
  96. assert nsolve(x**2 + 2, I) == sqrt(2.)*I
  97. assert nsolve([x**2 + 2, y**2 + 2], [x, y], [I, I]) == Matrix([sqrt(2.)*I, sqrt(2.)*I])
  98. assert nsolve([x**2 + 2, y**2 + 2], [x, y], [I, I]) == Matrix([sqrt(2.)*I, sqrt(2.)*I])
  99. def test_nsolve_dict_kwarg():
  100. x, y = symbols('x y')
  101. # one variable
  102. assert nsolve(x**2 - 2, 1, dict = True) == \
  103. [{x: sqrt(2.)}]
  104. # one variable with complex solution
  105. assert nsolve(x**2 + 2, I, dict = True) == \
  106. [{x: sqrt(2.)*I}]
  107. # two variables
  108. assert nsolve([x**2 + y**2 - 5, x**2 - y**2 + 1], [x, y], [1, 1], dict = True) == \
  109. [{x: sqrt(2.), y: sqrt(3.)}]
  110. def test_nsolve_rational():
  111. x = symbols('x')
  112. assert nsolve(x - Rational(1, 3), 0, prec=100) == Rational(1, 3).evalf(100)
  113. def test_issue_14950():
  114. x = Matrix(symbols('t s'))
  115. x0 = Matrix([17, 23])
  116. eqn = x + x0
  117. assert nsolve(eqn, x, x0) == nfloat(-x0)
  118. assert nsolve(eqn.T, x.T, x0.T) == nfloat(-x0)