test__root.py 2.6 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485
  1. """
  2. Unit tests for optimization routines from _root.py.
  3. """
  4. from numpy.testing import assert_
  5. from pytest import raises as assert_raises
  6. import numpy as np
  7. from scipy.optimize import root
  8. class TestRoot:
  9. def test_tol_parameter(self):
  10. # Check that the minimize() tol= argument does something
  11. def func(z):
  12. x, y = z
  13. return np.array([x**3 - 1, y**3 - 1])
  14. def dfunc(z):
  15. x, y = z
  16. return np.array([[3*x**2, 0], [0, 3*y**2]])
  17. for method in ['hybr', 'lm', 'broyden1', 'broyden2', 'anderson',
  18. 'diagbroyden', 'krylov']:
  19. if method in ('linearmixing', 'excitingmixing'):
  20. # doesn't converge
  21. continue
  22. if method in ('hybr', 'lm'):
  23. jac = dfunc
  24. else:
  25. jac = None
  26. sol1 = root(func, [1.1,1.1], jac=jac, tol=1e-4, method=method)
  27. sol2 = root(func, [1.1,1.1], jac=jac, tol=0.5, method=method)
  28. msg = "%s: %s vs. %s" % (method, func(sol1.x), func(sol2.x))
  29. assert_(sol1.success, msg)
  30. assert_(sol2.success, msg)
  31. assert_(abs(func(sol1.x)).max() < abs(func(sol2.x)).max(),
  32. msg)
  33. def test_tol_norm(self):
  34. def norm(x):
  35. return abs(x[0])
  36. for method in ['excitingmixing',
  37. 'diagbroyden',
  38. 'linearmixing',
  39. 'anderson',
  40. 'broyden1',
  41. 'broyden2',
  42. 'krylov']:
  43. root(np.zeros_like, np.zeros(2), method=method,
  44. options={"tol_norm": norm})
  45. def test_minimize_scalar_coerce_args_param(self):
  46. # github issue #3503
  47. def func(z, f=1):
  48. x, y = z
  49. return np.array([x**3 - 1, y**3 - f])
  50. root(func, [1.1, 1.1], args=1.5)
  51. def test_f_size(self):
  52. # gh8320
  53. # check that decreasing the size of the returned array raises an error
  54. # and doesn't segfault
  55. class fun:
  56. def __init__(self):
  57. self.count = 0
  58. def __call__(self, x):
  59. self.count += 1
  60. if not (self.count % 5):
  61. ret = x[0] + 0.5 * (x[0] - x[1]) ** 3 - 1.0
  62. else:
  63. ret = ([x[0] + 0.5 * (x[0] - x[1]) ** 3 - 1.0,
  64. 0.5 * (x[1] - x[0]) ** 3 + x[1]])
  65. return ret
  66. F = fun()
  67. with assert_raises(ValueError):
  68. root(F, [0.1, 0.0], method='lm')