test_solvers.py 30 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766
  1. import os
  2. import numpy as np
  3. from numpy.testing import assert_array_almost_equal
  4. import pytest
  5. from pytest import raises as assert_raises
  6. from scipy.linalg import solve_sylvester
  7. from scipy.linalg import solve_continuous_lyapunov, solve_discrete_lyapunov
  8. from scipy.linalg import solve_continuous_are, solve_discrete_are
  9. from scipy.linalg import block_diag, solve, LinAlgError
  10. from scipy.sparse._sputils import matrix
  11. def _load_data(name):
  12. """
  13. Load npz data file under data/
  14. Returns a copy of the data, rather than keeping the npz file open.
  15. """
  16. filename = os.path.join(os.path.abspath(os.path.dirname(__file__)),
  17. 'data', name)
  18. with np.load(filename) as f:
  19. return dict(f.items())
  20. class TestSolveLyapunov:
  21. cases = [
  22. (np.array([[1, 2], [3, 4]]),
  23. np.array([[9, 10], [11, 12]])),
  24. # a, q all complex.
  25. (np.array([[1.0+1j, 2.0], [3.0-4.0j, 5.0]]),
  26. np.array([[2.0-2j, 2.0+2j], [-1.0-1j, 2.0]])),
  27. # a real; q complex.
  28. (np.array([[1.0, 2.0], [3.0, 5.0]]),
  29. np.array([[2.0-2j, 2.0+2j], [-1.0-1j, 2.0]])),
  30. # a complex; q real.
  31. (np.array([[1.0+1j, 2.0], [3.0-4.0j, 5.0]]),
  32. np.array([[2.0, 2.0], [-1.0, 2.0]])),
  33. # An example from Kitagawa, 1977
  34. (np.array([[3, 9, 5, 1, 4], [1, 2, 3, 8, 4], [4, 6, 6, 6, 3],
  35. [1, 5, 2, 0, 7], [5, 3, 3, 1, 5]]),
  36. np.array([[2, 4, 1, 0, 1], [4, 1, 0, 2, 0], [1, 0, 3, 0, 3],
  37. [0, 2, 0, 1, 0], [1, 0, 3, 0, 4]])),
  38. # Companion matrix example. a complex; q real; a.shape[0] = 11
  39. (np.array([[0.100+0.j, 0.091+0.j, 0.082+0.j, 0.073+0.j, 0.064+0.j,
  40. 0.055+0.j, 0.046+0.j, 0.037+0.j, 0.028+0.j, 0.019+0.j,
  41. 0.010+0.j],
  42. [1.000+0.j, 0.000+0.j, 0.000+0.j, 0.000+0.j, 0.000+0.j,
  43. 0.000+0.j, 0.000+0.j, 0.000+0.j, 0.000+0.j, 0.000+0.j,
  44. 0.000+0.j],
  45. [0.000+0.j, 1.000+0.j, 0.000+0.j, 0.000+0.j, 0.000+0.j,
  46. 0.000+0.j, 0.000+0.j, 0.000+0.j, 0.000+0.j, 0.000+0.j,
  47. 0.000+0.j],
  48. [0.000+0.j, 0.000+0.j, 1.000+0.j, 0.000+0.j, 0.000+0.j,
  49. 0.000+0.j, 0.000+0.j, 0.000+0.j, 0.000+0.j, 0.000+0.j,
  50. 0.000+0.j],
  51. [0.000+0.j, 0.000+0.j, 0.000+0.j, 1.000+0.j, 0.000+0.j,
  52. 0.000+0.j, 0.000+0.j, 0.000+0.j, 0.000+0.j, 0.000+0.j,
  53. 0.000+0.j],
  54. [0.000+0.j, 0.000+0.j, 0.000+0.j, 0.000+0.j, 1.000+0.j,
  55. 0.000+0.j, 0.000+0.j, 0.000+0.j, 0.000+0.j, 0.000+0.j,
  56. 0.000+0.j],
  57. [0.000+0.j, 0.000+0.j, 0.000+0.j, 0.000+0.j, 0.000+0.j,
  58. 1.000+0.j, 0.000+0.j, 0.000+0.j, 0.000+0.j, 0.000+0.j,
  59. 0.000+0.j],
  60. [0.000+0.j, 0.000+0.j, 0.000+0.j, 0.000+0.j, 0.000+0.j,
  61. 0.000+0.j, 1.000+0.j, 0.000+0.j, 0.000+0.j, 0.000+0.j,
  62. 0.000+0.j],
  63. [0.000+0.j, 0.000+0.j, 0.000+0.j, 0.000+0.j, 0.000+0.j,
  64. 0.000+0.j, 0.000+0.j, 1.000+0.j, 0.000+0.j, 0.000+0.j,
  65. 0.000+0.j],
  66. [0.000+0.j, 0.000+0.j, 0.000+0.j, 0.000+0.j, 0.000+0.j,
  67. 0.000+0.j, 0.000+0.j, 0.000+0.j, 1.000+0.j, 0.000+0.j,
  68. 0.000+0.j],
  69. [0.000+0.j, 0.000+0.j, 0.000+0.j, 0.000+0.j, 0.000+0.j,
  70. 0.000+0.j, 0.000+0.j, 0.000+0.j, 0.000+0.j, 1.000+0.j,
  71. 0.000+0.j]]),
  72. np.eye(11)),
  73. # https://github.com/scipy/scipy/issues/4176
  74. (matrix([[0, 1], [-1/2, -1]]),
  75. (matrix([0, 3]).T @ matrix([0, 3]).T.T)),
  76. # https://github.com/scipy/scipy/issues/4176
  77. (matrix([[0, 1], [-1/2, -1]]),
  78. (np.array(matrix([0, 3]).T @ matrix([0, 3]).T.T))),
  79. ]
  80. def test_continuous_squareness_and_shape(self):
  81. nsq = np.ones((3, 2))
  82. sq = np.eye(3)
  83. assert_raises(ValueError, solve_continuous_lyapunov, nsq, sq)
  84. assert_raises(ValueError, solve_continuous_lyapunov, sq, nsq)
  85. assert_raises(ValueError, solve_continuous_lyapunov, sq, np.eye(2))
  86. def check_continuous_case(self, a, q):
  87. x = solve_continuous_lyapunov(a, q)
  88. assert_array_almost_equal(
  89. np.dot(a, x) + np.dot(x, a.conj().transpose()), q)
  90. def check_discrete_case(self, a, q, method=None):
  91. x = solve_discrete_lyapunov(a, q, method=method)
  92. assert_array_almost_equal(
  93. np.dot(np.dot(a, x), a.conj().transpose()) - x, -1.0*q)
  94. def test_cases(self):
  95. for case in self.cases:
  96. self.check_continuous_case(case[0], case[1])
  97. self.check_discrete_case(case[0], case[1])
  98. self.check_discrete_case(case[0], case[1], method='direct')
  99. self.check_discrete_case(case[0], case[1], method='bilinear')
  100. def test_solve_continuous_are():
  101. mat6 = _load_data('carex_6_data.npz')
  102. mat15 = _load_data('carex_15_data.npz')
  103. mat18 = _load_data('carex_18_data.npz')
  104. mat19 = _load_data('carex_19_data.npz')
  105. mat20 = _load_data('carex_20_data.npz')
  106. cases = [
  107. # Carex examples taken from (with default parameters):
  108. # [1] P.BENNER, A.J. LAUB, V. MEHRMANN: 'A Collection of Benchmark
  109. # Examples for the Numerical Solution of Algebraic Riccati
  110. # Equations II: Continuous-Time Case', Tech. Report SPC 95_23,
  111. # Fak. f. Mathematik, TU Chemnitz-Zwickau (Germany), 1995.
  112. #
  113. # The format of the data is (a, b, q, r, knownfailure), where
  114. # knownfailure is None if the test passes or a string
  115. # indicating the reason for failure.
  116. #
  117. # Test Case 0: carex #1
  118. (np.diag([1.], 1),
  119. np.array([[0], [1]]),
  120. block_diag(1., 2.),
  121. 1,
  122. None),
  123. # Test Case 1: carex #2
  124. (np.array([[4, 3], [-4.5, -3.5]]),
  125. np.array([[1], [-1]]),
  126. np.array([[9, 6], [6, 4.]]),
  127. 1,
  128. None),
  129. # Test Case 2: carex #3
  130. (np.array([[0, 1, 0, 0],
  131. [0, -1.89, 0.39, -5.53],
  132. [0, -0.034, -2.98, 2.43],
  133. [0.034, -0.0011, -0.99, -0.21]]),
  134. np.array([[0, 0], [0.36, -1.6], [-0.95, -0.032], [0.03, 0]]),
  135. np.array([[2.313, 2.727, 0.688, 0.023],
  136. [2.727, 4.271, 1.148, 0.323],
  137. [0.688, 1.148, 0.313, 0.102],
  138. [0.023, 0.323, 0.102, 0.083]]),
  139. np.eye(2),
  140. None),
  141. # Test Case 3: carex #4
  142. (np.array([[-0.991, 0.529, 0, 0, 0, 0, 0, 0],
  143. [0.522, -1.051, 0.596, 0, 0, 0, 0, 0],
  144. [0, 0.522, -1.118, 0.596, 0, 0, 0, 0],
  145. [0, 0, 0.522, -1.548, 0.718, 0, 0, 0],
  146. [0, 0, 0, 0.922, -1.64, 0.799, 0, 0],
  147. [0, 0, 0, 0, 0.922, -1.721, 0.901, 0],
  148. [0, 0, 0, 0, 0, 0.922, -1.823, 1.021],
  149. [0, 0, 0, 0, 0, 0, 0.922, -1.943]]),
  150. np.array([[3.84, 4.00, 37.60, 3.08, 2.36, 2.88, 3.08, 3.00],
  151. [-2.88, -3.04, -2.80, -2.32, -3.32, -3.82, -4.12, -3.96]]
  152. ).T * 0.001,
  153. np.array([[1.0, 0.0, 0.0, 0.0, 0.5, 0.0, 0.0, 0.1],
  154. [0.0, 1.0, 0.0, 0.0, 0.1, 0.0, 0.0, 0.0],
  155. [0.0, 0.0, 1.0, 0.0, 0.0, 0.5, 0.0, 0.0],
  156. [0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0],
  157. [0.5, 0.1, 0.0, 0.0, 0.1, 0.0, 0.0, 0.0],
  158. [0.0, 0.0, 0.5, 0.0, 0.0, 0.1, 0.0, 0.0],
  159. [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.1, 0.0],
  160. [0.1, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.1]]),
  161. np.eye(2),
  162. None),
  163. # Test Case 4: carex #5
  164. (np.array(
  165. [[-4.019, 5.120, 0., 0., -2.082, 0., 0., 0., 0.870],
  166. [-0.346, 0.986, 0., 0., -2.340, 0., 0., 0., 0.970],
  167. [-7.909, 15.407, -4.069, 0., -6.450, 0., 0., 0., 2.680],
  168. [-21.816, 35.606, -0.339, -3.870, -17.800, 0., 0., 0., 7.390],
  169. [-60.196, 98.188, -7.907, 0.340, -53.008, 0., 0., 0., 20.400],
  170. [0, 0, 0, 0, 94.000, -147.200, 0., 53.200, 0.],
  171. [0, 0, 0, 0, 0, 94.000, -147.200, 0, 0],
  172. [0, 0, 0, 0, 0, 12.800, 0.000, -31.600, 0],
  173. [0, 0, 0, 0, 12.800, 0.000, 0.000, 18.800, -31.600]]),
  174. np.array([[0.010, -0.011, -0.151],
  175. [0.003, -0.021, 0.000],
  176. [0.009, -0.059, 0.000],
  177. [0.024, -0.162, 0.000],
  178. [0.068, -0.445, 0.000],
  179. [0.000, 0.000, 0.000],
  180. [0.000, 0.000, 0.000],
  181. [0.000, 0.000, 0.000],
  182. [0.000, 0.000, 0.000]]),
  183. np.eye(9),
  184. np.eye(3),
  185. None),
  186. # Test Case 5: carex #6
  187. (mat6['A'], mat6['B'], mat6['Q'], mat6['R'], None),
  188. # Test Case 6: carex #7
  189. (np.array([[1, 0], [0, -2.]]),
  190. np.array([[1e-6], [0]]),
  191. np.ones((2, 2)),
  192. 1.,
  193. 'Bad residual accuracy'),
  194. # Test Case 7: carex #8
  195. (block_diag(-0.1, -0.02),
  196. np.array([[0.100, 0.000], [0.001, 0.010]]),
  197. np.array([[100, 1000], [1000, 10000]]),
  198. np.ones((2, 2)) + block_diag(1e-6, 0),
  199. None),
  200. # Test Case 8: carex #9
  201. (np.array([[0, 1e6], [0, 0]]),
  202. np.array([[0], [1.]]),
  203. np.eye(2),
  204. 1.,
  205. None),
  206. # Test Case 9: carex #10
  207. (np.array([[1.0000001, 1], [1., 1.0000001]]),
  208. np.eye(2),
  209. np.eye(2),
  210. np.eye(2),
  211. None),
  212. # Test Case 10: carex #11
  213. (np.array([[3, 1.], [4, 2]]),
  214. np.array([[1], [1]]),
  215. np.array([[-11, -5], [-5, -2.]]),
  216. 1.,
  217. None),
  218. # Test Case 11: carex #12
  219. (np.array([[7000000., 2000000., -0.],
  220. [2000000., 6000000., -2000000.],
  221. [0., -2000000., 5000000.]]) / 3,
  222. np.eye(3),
  223. np.array([[1., -2., -2.], [-2., 1., -2.], [-2., -2., 1.]]).dot(
  224. np.diag([1e-6, 1, 1e6])).dot(
  225. np.array([[1., -2., -2.], [-2., 1., -2.], [-2., -2., 1.]])) / 9,
  226. np.eye(3) * 1e6,
  227. 'Bad Residual Accuracy'),
  228. # Test Case 12: carex #13
  229. (np.array([[0, 0.4, 0, 0],
  230. [0, 0, 0.345, 0],
  231. [0, -0.524e6, -0.465e6, 0.262e6],
  232. [0, 0, 0, -1e6]]),
  233. np.array([[0, 0, 0, 1e6]]).T,
  234. np.diag([1, 0, 1, 0]),
  235. 1.,
  236. None),
  237. # Test Case 13: carex #14
  238. (np.array([[-1e-6, 1, 0, 0],
  239. [-1, -1e-6, 0, 0],
  240. [0, 0, 1e-6, 1],
  241. [0, 0, -1, 1e-6]]),
  242. np.ones((4, 1)),
  243. np.ones((4, 4)),
  244. 1.,
  245. None),
  246. # Test Case 14: carex #15
  247. (mat15['A'], mat15['B'], mat15['Q'], mat15['R'], None),
  248. # Test Case 15: carex #16
  249. (np.eye(64, 64, k=-1) + np.eye(64, 64)*(-2.) + np.rot90(
  250. block_diag(1, np.zeros((62, 62)), 1)) + np.eye(64, 64, k=1),
  251. np.eye(64),
  252. np.eye(64),
  253. np.eye(64),
  254. None),
  255. # Test Case 16: carex #17
  256. (np.diag(np.ones((20, )), 1),
  257. np.flipud(np.eye(21, 1)),
  258. np.eye(21, 1) * np.eye(21, 1).T,
  259. 1,
  260. 'Bad Residual Accuracy'),
  261. # Test Case 17: carex #18
  262. (mat18['A'], mat18['B'], mat18['Q'], mat18['R'], None),
  263. # Test Case 18: carex #19
  264. (mat19['A'], mat19['B'], mat19['Q'], mat19['R'],
  265. 'Bad Residual Accuracy'),
  266. # Test Case 19: carex #20
  267. (mat20['A'], mat20['B'], mat20['Q'], mat20['R'],
  268. 'Bad Residual Accuracy')
  269. ]
  270. # Makes the minimum precision requirements customized to the test.
  271. # Here numbers represent the number of decimals that agrees with zero
  272. # matrix when the solution x is plugged in to the equation.
  273. #
  274. # res = array([[8e-3,1e-16],[1e-16,1e-20]]) --> min_decimal[k] = 2
  275. #
  276. # If the test is failing use "None" for that entry.
  277. #
  278. min_decimal = (14, 12, 13, 14, 11, 6, None, 5, 7, 14, 14,
  279. None, 9, 14, 13, 14, None, 12, None, None)
  280. def _test_factory(case, dec):
  281. """Checks if 0 = XA + A'X - XB(R)^{-1} B'X + Q is true"""
  282. a, b, q, r, knownfailure = case
  283. if knownfailure:
  284. pytest.xfail(reason=knownfailure)
  285. x = solve_continuous_are(a, b, q, r)
  286. res = x.dot(a) + a.conj().T.dot(x) + q
  287. out_fact = x.dot(b)
  288. res -= out_fact.dot(solve(np.atleast_2d(r), out_fact.conj().T))
  289. assert_array_almost_equal(res, np.zeros_like(res), decimal=dec)
  290. for ind, case in enumerate(cases):
  291. _test_factory(case, min_decimal[ind])
  292. def test_solve_discrete_are():
  293. cases = [
  294. # Darex examples taken from (with default parameters):
  295. # [1] P.BENNER, A.J. LAUB, V. MEHRMANN: 'A Collection of Benchmark
  296. # Examples for the Numerical Solution of Algebraic Riccati
  297. # Equations II: Discrete-Time Case', Tech. Report SPC 95_23,
  298. # Fak. f. Mathematik, TU Chemnitz-Zwickau (Germany), 1995.
  299. # [2] T. GUDMUNDSSON, C. KENNEY, A.J. LAUB: 'Scaling of the
  300. # Discrete-Time Algebraic Riccati Equation to Enhance Stability
  301. # of the Schur Solution Method', IEEE Trans.Aut.Cont., vol.37(4)
  302. #
  303. # The format of the data is (a, b, q, r, knownfailure), where
  304. # knownfailure is None if the test passes or a string
  305. # indicating the reason for failure.
  306. #
  307. # TEST CASE 0 : Complex a; real b, q, r
  308. (np.array([[2, 1-2j], [0, -3j]]),
  309. np.array([[0], [1]]),
  310. np.array([[1, 0], [0, 2]]),
  311. np.array([[1]]),
  312. None),
  313. # TEST CASE 1 :Real a, q, r; complex b
  314. (np.array([[2, 1], [0, -1]]),
  315. np.array([[-2j], [1j]]),
  316. np.array([[1, 0], [0, 2]]),
  317. np.array([[1]]),
  318. None),
  319. # TEST CASE 2 : Real a, b; complex q, r
  320. (np.array([[3, 1], [0, -1]]),
  321. np.array([[1, 2], [1, 3]]),
  322. np.array([[1, 1+1j], [1-1j, 2]]),
  323. np.array([[2, -2j], [2j, 3]]),
  324. None),
  325. # TEST CASE 3 : User-reported gh-2251 (Trac #1732)
  326. (np.array([[0.63399379, 0.54906824, 0.76253406],
  327. [0.5404729, 0.53745766, 0.08731853],
  328. [0.27524045, 0.84922129, 0.4681622]]),
  329. np.array([[0.96861695], [0.05532739], [0.78934047]]),
  330. np.eye(3),
  331. np.eye(1),
  332. None),
  333. # TEST CASE 4 : darex #1
  334. (np.array([[4, 3], [-4.5, -3.5]]),
  335. np.array([[1], [-1]]),
  336. np.array([[9, 6], [6, 4]]),
  337. np.array([[1]]),
  338. None),
  339. # TEST CASE 5 : darex #2
  340. (np.array([[0.9512, 0], [0, 0.9048]]),
  341. np.array([[4.877, 4.877], [-1.1895, 3.569]]),
  342. np.array([[0.005, 0], [0, 0.02]]),
  343. np.array([[1/3, 0], [0, 3]]),
  344. None),
  345. # TEST CASE 6 : darex #3
  346. (np.array([[2, -1], [1, 0]]),
  347. np.array([[1], [0]]),
  348. np.array([[0, 0], [0, 1]]),
  349. np.array([[0]]),
  350. None),
  351. # TEST CASE 7 : darex #4 (skipped the gen. Ric. term S)
  352. (np.array([[0, 1], [0, -1]]),
  353. np.array([[1, 0], [2, 1]]),
  354. np.array([[-4, -4], [-4, 7]]) * (1/11),
  355. np.array([[9, 3], [3, 1]]),
  356. None),
  357. # TEST CASE 8 : darex #5
  358. (np.array([[0, 1], [0, 0]]),
  359. np.array([[0], [1]]),
  360. np.array([[1, 2], [2, 4]]),
  361. np.array([[1]]),
  362. None),
  363. # TEST CASE 9 : darex #6
  364. (np.array([[0.998, 0.067, 0, 0],
  365. [-.067, 0.998, 0, 0],
  366. [0, 0, 0.998, 0.153],
  367. [0, 0, -.153, 0.998]]),
  368. np.array([[0.0033, 0.0200],
  369. [0.1000, -.0007],
  370. [0.0400, 0.0073],
  371. [-.0028, 0.1000]]),
  372. np.array([[1.87, 0, 0, -0.244],
  373. [0, 0.744, 0.205, 0],
  374. [0, 0.205, 0.589, 0],
  375. [-0.244, 0, 0, 1.048]]),
  376. np.eye(2),
  377. None),
  378. # TEST CASE 10 : darex #7
  379. (np.array([[0.984750, -.079903, 0.0009054, -.0010765],
  380. [0.041588, 0.998990, -.0358550, 0.0126840],
  381. [-.546620, 0.044916, -.3299100, 0.1931800],
  382. [2.662400, -.100450, -.9245500, -.2632500]]),
  383. np.array([[0.0037112, 0.0007361],
  384. [-.0870510, 9.3411e-6],
  385. [-1.198440, -4.1378e-4],
  386. [-3.192700, 9.2535e-4]]),
  387. np.eye(4)*1e-2,
  388. np.eye(2),
  389. None),
  390. # TEST CASE 11 : darex #8
  391. (np.array([[-0.6000000, -2.2000000, -3.6000000, -5.4000180],
  392. [1.0000000, 0.6000000, 0.8000000, 3.3999820],
  393. [0.0000000, 1.0000000, 1.8000000, 3.7999820],
  394. [0.0000000, 0.0000000, 0.0000000, -0.9999820]]),
  395. np.array([[1.0, -1.0, -1.0, -1.0],
  396. [0.0, 1.0, -1.0, -1.0],
  397. [0.0, 0.0, 1.0, -1.0],
  398. [0.0, 0.0, 0.0, 1.0]]),
  399. np.array([[2, 1, 3, 6],
  400. [1, 2, 2, 5],
  401. [3, 2, 6, 11],
  402. [6, 5, 11, 22]]),
  403. np.eye(4),
  404. None),
  405. # TEST CASE 12 : darex #9
  406. (np.array([[95.4070, 1.9643, 0.3597, 0.0673, 0.0190],
  407. [40.8490, 41.3170, 16.0840, 4.4679, 1.1971],
  408. [12.2170, 26.3260, 36.1490, 15.9300, 12.3830],
  409. [4.1118, 12.8580, 27.2090, 21.4420, 40.9760],
  410. [0.1305, 0.5808, 1.8750, 3.6162, 94.2800]]) * 0.01,
  411. np.array([[0.0434, -0.0122],
  412. [2.6606, -1.0453],
  413. [3.7530, -5.5100],
  414. [3.6076, -6.6000],
  415. [0.4617, -0.9148]]) * 0.01,
  416. np.eye(5),
  417. np.eye(2),
  418. None),
  419. # TEST CASE 13 : darex #10
  420. (np.kron(np.eye(2), np.diag([1, 1], k=1)),
  421. np.kron(np.eye(2), np.array([[0], [0], [1]])),
  422. np.array([[1, 1, 0, 0, 0, 0],
  423. [1, 1, 0, 0, 0, 0],
  424. [0, 0, 0, 0, 0, 0],
  425. [0, 0, 0, 1, -1, 0],
  426. [0, 0, 0, -1, 1, 0],
  427. [0, 0, 0, 0, 0, 0]]),
  428. np.array([[3, 0], [0, 1]]),
  429. None),
  430. # TEST CASE 14 : darex #11
  431. (0.001 * np.array(
  432. [[870.1, 135.0, 11.59, .5014, -37.22, .3484, 0, 4.242, 7.249],
  433. [76.55, 897.4, 12.72, 0.5504, -40.16, .3743, 0, 4.53, 7.499],
  434. [-127.2, 357.5, 817, 1.455, -102.8, .987, 0, 11.85, 18.72],
  435. [-363.5, 633.9, 74.91, 796.6, -273.5, 2.653, 0, 31.72, 48.82],
  436. [-960, 1645.9, -128.9, -5.597, 71.42, 7.108, 0, 84.52, 125.9],
  437. [-664.4, 112.96, -88.89, -3.854, 84.47, 13.6, 0, 144.3, 101.6],
  438. [-410.2, 693, -54.71, -2.371, 66.49, 12.49, .1063, 99.97, 69.67],
  439. [-179.9, 301.7, -23.93, -1.035, 60.59, 22.16, 0, 213.9, 35.54],
  440. [-345.1, 580.4, -45.96, -1.989, 105.6, 19.86, 0, 219.1, 215.2]]),
  441. np.array([[4.7600, -0.5701, -83.6800],
  442. [0.8790, -4.7730, -2.7300],
  443. [1.4820, -13.1200, 8.8760],
  444. [3.8920, -35.1300, 24.8000],
  445. [10.3400, -92.7500, 66.8000],
  446. [7.2030, -61.5900, 38.3400],
  447. [4.4540, -36.8300, 20.2900],
  448. [1.9710, -15.5400, 6.9370],
  449. [3.7730, -30.2800, 14.6900]]) * 0.001,
  450. np.diag([50, 0, 0, 0, 50, 0, 0, 0, 0]),
  451. np.eye(3),
  452. None),
  453. # TEST CASE 15 : darex #12 - numerically least accurate example
  454. (np.array([[0, 1e6], [0, 0]]),
  455. np.array([[0], [1]]),
  456. np.eye(2),
  457. np.array([[1]]),
  458. "Presumed issue with OpenBLAS, see gh-16926"),
  459. # TEST CASE 16 : darex #13
  460. (np.array([[16, 10, -2],
  461. [10, 13, -8],
  462. [-2, -8, 7]]) * (1/9),
  463. np.eye(3),
  464. 1e6 * np.eye(3),
  465. 1e6 * np.eye(3),
  466. "Issue with OpenBLAS, see gh-16926"),
  467. # TEST CASE 17 : darex #14
  468. (np.array([[1 - 1/1e8, 0, 0, 0],
  469. [1, 0, 0, 0],
  470. [0, 1, 0, 0],
  471. [0, 0, 1, 0]]),
  472. np.array([[1e-08], [0], [0], [0]]),
  473. np.diag([0, 0, 0, 1]),
  474. np.array([[0.25]]),
  475. None),
  476. # TEST CASE 18 : darex #15
  477. (np.eye(100, k=1),
  478. np.flipud(np.eye(100, 1)),
  479. np.eye(100),
  480. np.array([[1]]),
  481. None)
  482. ]
  483. # Makes the minimum precision requirements customized to the test.
  484. # Here numbers represent the number of decimals that agrees with zero
  485. # matrix when the solution x is plugged in to the equation.
  486. #
  487. # res = array([[8e-3,1e-16],[1e-16,1e-20]]) --> min_decimal[k] = 2
  488. #
  489. # If the test is failing use "None" for that entry.
  490. #
  491. min_decimal = (12, 14, 13, 14, 13, 16, 18, 14, 14, 13,
  492. 14, 13, 13, 14, 12, 2, 5, 6, 10)
  493. def _test_factory(case, dec):
  494. """Checks if X = A'XA-(A'XB)(R+B'XB)^-1(B'XA)+Q) is true"""
  495. a, b, q, r, knownfailure = case
  496. if knownfailure:
  497. pytest.xfail(reason=knownfailure)
  498. x = solve_discrete_are(a, b, q, r)
  499. res = a.conj().T.dot(x.dot(a)) - x + q
  500. res -= a.conj().T.dot(x.dot(b)).dot(
  501. solve(r+b.conj().T.dot(x.dot(b)), b.conj().T).dot(x.dot(a))
  502. )
  503. assert_array_almost_equal(res, np.zeros_like(res), decimal=dec)
  504. for ind, case in enumerate(cases):
  505. _test_factory(case, min_decimal[ind])
  506. # An infeasible example taken from https://arxiv.org/abs/1505.04861v1
  507. A = np.triu(np.ones((3, 3)))
  508. A[0, 1] = -1
  509. B = np.array([[1, 1, 0], [0, 0, 1]]).T
  510. Q = np.full_like(A, -2) + np.diag([8, -1, -1.9])
  511. R = np.diag([-10, 0.1])
  512. assert_raises(LinAlgError, solve_continuous_are, A, B, Q, R)
  513. def test_solve_generalized_continuous_are():
  514. cases = [
  515. # Two random examples differ by s term
  516. # in the absence of any literature for demanding examples.
  517. (np.array([[2.769230e-01, 8.234578e-01, 9.502220e-01],
  518. [4.617139e-02, 6.948286e-01, 3.444608e-02],
  519. [9.713178e-02, 3.170995e-01, 4.387444e-01]]),
  520. np.array([[3.815585e-01, 1.868726e-01],
  521. [7.655168e-01, 4.897644e-01],
  522. [7.951999e-01, 4.455862e-01]]),
  523. np.eye(3),
  524. np.eye(2),
  525. np.array([[6.463130e-01, 2.760251e-01, 1.626117e-01],
  526. [7.093648e-01, 6.797027e-01, 1.189977e-01],
  527. [7.546867e-01, 6.550980e-01, 4.983641e-01]]),
  528. np.zeros((3, 2)),
  529. None),
  530. (np.array([[2.769230e-01, 8.234578e-01, 9.502220e-01],
  531. [4.617139e-02, 6.948286e-01, 3.444608e-02],
  532. [9.713178e-02, 3.170995e-01, 4.387444e-01]]),
  533. np.array([[3.815585e-01, 1.868726e-01],
  534. [7.655168e-01, 4.897644e-01],
  535. [7.951999e-01, 4.455862e-01]]),
  536. np.eye(3),
  537. np.eye(2),
  538. np.array([[6.463130e-01, 2.760251e-01, 1.626117e-01],
  539. [7.093648e-01, 6.797027e-01, 1.189977e-01],
  540. [7.546867e-01, 6.550980e-01, 4.983641e-01]]),
  541. np.ones((3, 2)),
  542. None)
  543. ]
  544. min_decimal = (10, 10)
  545. def _test_factory(case, dec):
  546. """Checks if X = A'XA-(A'XB)(R+B'XB)^-1(B'XA)+Q) is true"""
  547. a, b, q, r, e, s, knownfailure = case
  548. if knownfailure:
  549. pytest.xfail(reason=knownfailure)
  550. x = solve_continuous_are(a, b, q, r, e, s)
  551. res = a.conj().T.dot(x.dot(e)) + e.conj().T.dot(x.dot(a)) + q
  552. out_fact = e.conj().T.dot(x).dot(b) + s
  553. res -= out_fact.dot(solve(np.atleast_2d(r), out_fact.conj().T))
  554. assert_array_almost_equal(res, np.zeros_like(res), decimal=dec)
  555. for ind, case in enumerate(cases):
  556. _test_factory(case, min_decimal[ind])
  557. def test_solve_generalized_discrete_are():
  558. mat20170120 = _load_data('gendare_20170120_data.npz')
  559. cases = [
  560. # Two random examples differ by s term
  561. # in the absence of any literature for demanding examples.
  562. (np.array([[2.769230e-01, 8.234578e-01, 9.502220e-01],
  563. [4.617139e-02, 6.948286e-01, 3.444608e-02],
  564. [9.713178e-02, 3.170995e-01, 4.387444e-01]]),
  565. np.array([[3.815585e-01, 1.868726e-01],
  566. [7.655168e-01, 4.897644e-01],
  567. [7.951999e-01, 4.455862e-01]]),
  568. np.eye(3),
  569. np.eye(2),
  570. np.array([[6.463130e-01, 2.760251e-01, 1.626117e-01],
  571. [7.093648e-01, 6.797027e-01, 1.189977e-01],
  572. [7.546867e-01, 6.550980e-01, 4.983641e-01]]),
  573. np.zeros((3, 2)),
  574. None),
  575. (np.array([[2.769230e-01, 8.234578e-01, 9.502220e-01],
  576. [4.617139e-02, 6.948286e-01, 3.444608e-02],
  577. [9.713178e-02, 3.170995e-01, 4.387444e-01]]),
  578. np.array([[3.815585e-01, 1.868726e-01],
  579. [7.655168e-01, 4.897644e-01],
  580. [7.951999e-01, 4.455862e-01]]),
  581. np.eye(3),
  582. np.eye(2),
  583. np.array([[6.463130e-01, 2.760251e-01, 1.626117e-01],
  584. [7.093648e-01, 6.797027e-01, 1.189977e-01],
  585. [7.546867e-01, 6.550980e-01, 4.983641e-01]]),
  586. np.ones((3, 2)),
  587. None),
  588. # user-reported (under PR-6616) 20-Jan-2017
  589. # tests against the case where E is None but S is provided
  590. (mat20170120['A'],
  591. mat20170120['B'],
  592. mat20170120['Q'],
  593. mat20170120['R'],
  594. None,
  595. mat20170120['S'],
  596. None),
  597. ]
  598. min_decimal = (11, 11, 16)
  599. def _test_factory(case, dec):
  600. """Checks if X = A'XA-(A'XB)(R+B'XB)^-1(B'XA)+Q) is true"""
  601. a, b, q, r, e, s, knownfailure = case
  602. if knownfailure:
  603. pytest.xfail(reason=knownfailure)
  604. x = solve_discrete_are(a, b, q, r, e, s)
  605. if e is None:
  606. e = np.eye(a.shape[0])
  607. if s is None:
  608. s = np.zeros_like(b)
  609. res = a.conj().T.dot(x.dot(a)) - e.conj().T.dot(x.dot(e)) + q
  610. res -= (a.conj().T.dot(x.dot(b)) + s).dot(
  611. solve(r+b.conj().T.dot(x.dot(b)),
  612. (b.conj().T.dot(x.dot(a)) + s.conj().T)
  613. )
  614. )
  615. assert_array_almost_equal(res, np.zeros_like(res), decimal=dec)
  616. for ind, case in enumerate(cases):
  617. _test_factory(case, min_decimal[ind])
  618. def test_are_validate_args():
  619. def test_square_shape():
  620. nsq = np.ones((3, 2))
  621. sq = np.eye(3)
  622. for x in (solve_continuous_are, solve_discrete_are):
  623. assert_raises(ValueError, x, nsq, 1, 1, 1)
  624. assert_raises(ValueError, x, sq, sq, nsq, 1)
  625. assert_raises(ValueError, x, sq, sq, sq, nsq)
  626. assert_raises(ValueError, x, sq, sq, sq, sq, nsq)
  627. def test_compatible_sizes():
  628. nsq = np.ones((3, 2))
  629. sq = np.eye(4)
  630. for x in (solve_continuous_are, solve_discrete_are):
  631. assert_raises(ValueError, x, sq, nsq, 1, 1)
  632. assert_raises(ValueError, x, sq, sq, sq, sq, sq, nsq)
  633. assert_raises(ValueError, x, sq, sq, np.eye(3), sq)
  634. assert_raises(ValueError, x, sq, sq, sq, np.eye(3))
  635. assert_raises(ValueError, x, sq, sq, sq, sq, np.eye(3))
  636. def test_symmetry():
  637. nsym = np.arange(9).reshape(3, 3)
  638. sym = np.eye(3)
  639. for x in (solve_continuous_are, solve_discrete_are):
  640. assert_raises(ValueError, x, sym, sym, nsym, sym)
  641. assert_raises(ValueError, x, sym, sym, sym, nsym)
  642. def test_singularity():
  643. sing = np.full((3, 3), 1e12)
  644. sing[2, 2] -= 1
  645. sq = np.eye(3)
  646. for x in (solve_continuous_are, solve_discrete_are):
  647. assert_raises(ValueError, x, sq, sq, sq, sq, sing)
  648. assert_raises(ValueError, solve_continuous_are, sq, sq, sq, sing)
  649. def test_finiteness():
  650. nm = np.full((2, 2), np.nan)
  651. sq = np.eye(2)
  652. for x in (solve_continuous_are, solve_discrete_are):
  653. assert_raises(ValueError, x, nm, sq, sq, sq)
  654. assert_raises(ValueError, x, sq, nm, sq, sq)
  655. assert_raises(ValueError, x, sq, sq, nm, sq)
  656. assert_raises(ValueError, x, sq, sq, sq, nm)
  657. assert_raises(ValueError, x, sq, sq, sq, sq, nm)
  658. assert_raises(ValueError, x, sq, sq, sq, sq, sq, nm)
  659. class TestSolveSylvester:
  660. cases = [
  661. # a, b, c all real.
  662. (np.array([[1, 2], [0, 4]]),
  663. np.array([[5, 6], [0, 8]]),
  664. np.array([[9, 10], [11, 12]])),
  665. # a, b, c all real, 4x4. a and b have non-trival 2x2 blocks in their
  666. # quasi-triangular form.
  667. (np.array([[1.0, 0, 0, 0],
  668. [0, 1.0, 2.0, 0.0],
  669. [0, 0, 3.0, -4],
  670. [0, 0, 2, 5]]),
  671. np.array([[2.0, 0, 0, 1.0],
  672. [0, 1.0, 0.0, 0.0],
  673. [0, 0, 1.0, -1],
  674. [0, 0, 1, 1]]),
  675. np.array([[1.0, 0, 0, 0],
  676. [0, 1.0, 0, 0],
  677. [0, 0, 1.0, 0],
  678. [0, 0, 0, 1.0]])),
  679. # a, b, c all complex.
  680. (np.array([[1.0+1j, 2.0], [3.0-4.0j, 5.0]]),
  681. np.array([[-1.0, 2j], [3.0, 4.0]]),
  682. np.array([[2.0-2j, 2.0+2j], [-1.0-1j, 2.0]])),
  683. # a and b real; c complex.
  684. (np.array([[1.0, 2.0], [3.0, 5.0]]),
  685. np.array([[-1.0, 0], [3.0, 4.0]]),
  686. np.array([[2.0-2j, 2.0+2j], [-1.0-1j, 2.0]])),
  687. # a and c complex; b real.
  688. (np.array([[1.0+1j, 2.0], [3.0-4.0j, 5.0]]),
  689. np.array([[-1.0, 0], [3.0, 4.0]]),
  690. np.array([[2.0-2j, 2.0+2j], [-1.0-1j, 2.0]])),
  691. # a complex; b and c real.
  692. (np.array([[1.0+1j, 2.0], [3.0-4.0j, 5.0]]),
  693. np.array([[-1.0, 0], [3.0, 4.0]]),
  694. np.array([[2.0, 2.0], [-1.0, 2.0]])),
  695. # not square matrices, real
  696. (np.array([[8, 1, 6], [3, 5, 7], [4, 9, 2]]),
  697. np.array([[2, 3], [4, 5]]),
  698. np.array([[1, 2], [3, 4], [5, 6]])),
  699. # not square matrices, complex
  700. (np.array([[8, 1j, 6+2j], [3, 5, 7], [4, 9, 2]]),
  701. np.array([[2, 3], [4, 5-1j]]),
  702. np.array([[1, 2j], [3, 4j], [5j, 6+7j]])),
  703. ]
  704. def check_case(self, a, b, c):
  705. x = solve_sylvester(a, b, c)
  706. assert_array_almost_equal(np.dot(a, x) + np.dot(x, b), c)
  707. def test_cases(self):
  708. for case in self.cases:
  709. self.check_case(case[0], case[1], case[2])
  710. def test_trivial(self):
  711. a = np.array([[1.0, 0.0], [0.0, 1.0]])
  712. b = np.array([[1.0]])
  713. c = np.array([2.0, 2.0]).reshape(-1, 1)
  714. x = solve_sylvester(a, b, c)
  715. assert_array_almost_equal(x, np.array([1.0, 1.0]).reshape(-1, 1))