test_solvers.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564
  1. from sympy.core.function import expand_mul
  2. from sympy.core.numbers import (I, Rational)
  3. from sympy.core.singleton import S
  4. from sympy.core.symbol import (Symbol, symbols)
  5. from sympy.core.sympify import sympify
  6. from sympy.simplify.simplify import simplify
  7. from sympy.matrices.matrices import (ShapeError, NonSquareMatrixError)
  8. from sympy.matrices import (
  9. ImmutableMatrix, Matrix, eye, ones, ImmutableDenseMatrix, dotprodsimp)
  10. from sympy.testing.pytest import raises
  11. from sympy.matrices.common import NonInvertibleMatrixError
  12. from sympy.solvers.solveset import linsolve
  13. from sympy.abc import x, y
  14. def test_issue_17247_expression_blowup_29():
  15. M = Matrix(S('''[
  16. [ -3/4, 45/32 - 37*I/16, 0, 0],
  17. [-149/64 + 49*I/32, -177/128 - 1369*I/128, 0, -2063/256 + 541*I/128],
  18. [ 0, 9/4 + 55*I/16, 2473/256 + 137*I/64, 0],
  19. [ 0, 0, 0, -177/128 - 1369*I/128]]'''))
  20. with dotprodsimp(True):
  21. assert M.gauss_jordan_solve(ones(4, 1)) == (Matrix(S('''[
  22. [ -32549314808672/3306971225785 - 17397006745216*I/3306971225785],
  23. [ 67439348256/3306971225785 - 9167503335872*I/3306971225785],
  24. [-15091965363354518272/21217636514687010905 + 16890163109293858304*I/21217636514687010905],
  25. [ -11328/952745 + 87616*I/952745]]''')), Matrix(0, 1, []))
  26. def test_issue_17247_expression_blowup_30():
  27. M = Matrix(S('''[
  28. [ -3/4, 45/32 - 37*I/16, 0, 0],
  29. [-149/64 + 49*I/32, -177/128 - 1369*I/128, 0, -2063/256 + 541*I/128],
  30. [ 0, 9/4 + 55*I/16, 2473/256 + 137*I/64, 0],
  31. [ 0, 0, 0, -177/128 - 1369*I/128]]'''))
  32. with dotprodsimp(True):
  33. assert M.cholesky_solve(ones(4, 1)) == Matrix(S('''[
  34. [ -32549314808672/3306971225785 - 17397006745216*I/3306971225785],
  35. [ 67439348256/3306971225785 - 9167503335872*I/3306971225785],
  36. [-15091965363354518272/21217636514687010905 + 16890163109293858304*I/21217636514687010905],
  37. [ -11328/952745 + 87616*I/952745]]'''))
# @XFAIL # This calculation hangs with dotprodsimp.
# def test_issue_17247_expression_blowup_31():
#     M = Matrix([
#         [x + 1, 1 - x,     0,     0],
#         [1 - x, x + 1,     0, x + 1],
#         [    0, 1 - x, x + 1,     0],
#         [    0,     0,     0, x + 1]])
#     with dotprodsimp(True):
#         assert M.LDLsolve(ones(4, 1)) == Matrix([
#             [(x + 1)/(4*x)],
#             [(x - 1)/(4*x)],
#             [(x + 1)/(4*x)],
#             [    1/(x + 1)]])
  51. def test_issue_17247_expression_blowup_32():
  52. M = Matrix([
  53. [x + 1, 1 - x, 0, 0],
  54. [1 - x, x + 1, 0, x + 1],
  55. [ 0, 1 - x, x + 1, 0],
  56. [ 0, 0, 0, x + 1]])
  57. with dotprodsimp(True):
  58. assert M.LUsolve(ones(4, 1)) == Matrix([
  59. [(x + 1)/(4*x)],
  60. [(x - 1)/(4*x)],
  61. [(x + 1)/(4*x)],
  62. [ 1/(x + 1)]])
  63. def test_LUsolve():
  64. A = Matrix([[2, 3, 5],
  65. [3, 6, 2],
  66. [8, 3, 6]])
  67. x = Matrix(3, 1, [3, 7, 5])
  68. b = A*x
  69. soln = A.LUsolve(b)
  70. assert soln == x
  71. A = Matrix([[0, -1, 2],
  72. [5, 10, 7],
  73. [8, 3, 4]])
  74. x = Matrix(3, 1, [-1, 2, 5])
  75. b = A*x
  76. soln = A.LUsolve(b)
  77. assert soln == x
  78. A = Matrix([[2, 1], [1, 0], [1, 0]]) # issue 14548
  79. b = Matrix([3, 1, 1])
  80. assert A.LUsolve(b) == Matrix([1, 1])
  81. b = Matrix([3, 1, 2]) # inconsistent
  82. raises(ValueError, lambda: A.LUsolve(b))
  83. A = Matrix([[0, -1, 2],
  84. [5, 10, 7],
  85. [8, 3, 4],
  86. [2, 3, 5],
  87. [3, 6, 2],
  88. [8, 3, 6]])
  89. x = Matrix([2, 1, -4])
  90. b = A*x
  91. soln = A.LUsolve(b)
  92. assert soln == x
  93. A = Matrix([[0, -1, 2], [5, 10, 7]]) # underdetermined
  94. x = Matrix([-1, 2, 0])
  95. b = A*x
  96. raises(NotImplementedError, lambda: A.LUsolve(b))
  97. A = Matrix(4, 4, lambda i, j: 1/(i+j+1) if i != 3 else 0)
  98. b = Matrix.zeros(4, 1)
  99. raises(NonInvertibleMatrixError, lambda: A.LUsolve(b))
  100. def test_QRsolve():
  101. A = Matrix([[2, 3, 5],
  102. [3, 6, 2],
  103. [8, 3, 6]])
  104. x = Matrix(3, 1, [3, 7, 5])
  105. b = A*x
  106. soln = A.QRsolve(b)
  107. assert soln == x
  108. x = Matrix([[1, 2], [3, 4], [5, 6]])
  109. b = A*x
  110. soln = A.QRsolve(b)
  111. assert soln == x
  112. A = Matrix([[0, -1, 2],
  113. [5, 10, 7],
  114. [8, 3, 4]])
  115. x = Matrix(3, 1, [-1, 2, 5])
  116. b = A*x
  117. soln = A.QRsolve(b)
  118. assert soln == x
  119. x = Matrix([[7, 8], [9, 10], [11, 12]])
  120. b = A*x
  121. soln = A.QRsolve(b)
  122. assert soln == x
  123. def test_errors():
  124. raises(ShapeError, lambda: Matrix([1]).LUsolve(Matrix([[1, 2], [3, 4]])))
  125. def test_cholesky_solve():
  126. A = Matrix([[2, 3, 5],
  127. [3, 6, 2],
  128. [8, 3, 6]])
  129. x = Matrix(3, 1, [3, 7, 5])
  130. b = A*x
  131. soln = A.cholesky_solve(b)
  132. assert soln == x
  133. A = Matrix([[0, -1, 2],
  134. [5, 10, 7],
  135. [8, 3, 4]])
  136. x = Matrix(3, 1, [-1, 2, 5])
  137. b = A*x
  138. soln = A.cholesky_solve(b)
  139. assert soln == x
  140. A = Matrix(((1, 5), (5, 1)))
  141. x = Matrix((4, -3))
  142. b = A*x
  143. soln = A.cholesky_solve(b)
  144. assert soln == x
  145. A = Matrix(((9, 3*I), (-3*I, 5)))
  146. x = Matrix((-2, 1))
  147. b = A*x
  148. soln = A.cholesky_solve(b)
  149. assert expand_mul(soln) == x
  150. A = Matrix(((9*I, 3), (-3 + I, 5)))
  151. x = Matrix((2 + 3*I, -1))
  152. b = A*x
  153. soln = A.cholesky_solve(b)
  154. assert expand_mul(soln) == x
  155. a00, a01, a11, b0, b1 = symbols('a00, a01, a11, b0, b1')
  156. A = Matrix(((a00, a01), (a01, a11)))
  157. b = Matrix((b0, b1))
  158. x = A.cholesky_solve(b)
  159. assert simplify(A*x) == b
  160. def test_LDLsolve():
  161. A = Matrix([[2, 3, 5],
  162. [3, 6, 2],
  163. [8, 3, 6]])
  164. x = Matrix(3, 1, [3, 7, 5])
  165. b = A*x
  166. soln = A.LDLsolve(b)
  167. assert soln == x
  168. A = Matrix([[0, -1, 2],
  169. [5, 10, 7],
  170. [8, 3, 4]])
  171. x = Matrix(3, 1, [-1, 2, 5])
  172. b = A*x
  173. soln = A.LDLsolve(b)
  174. assert soln == x
  175. A = Matrix(((9, 3*I), (-3*I, 5)))
  176. x = Matrix((-2, 1))
  177. b = A*x
  178. soln = A.LDLsolve(b)
  179. assert expand_mul(soln) == x
  180. A = Matrix(((9*I, 3), (-3 + I, 5)))
  181. x = Matrix((2 + 3*I, -1))
  182. b = A*x
  183. soln = A.LDLsolve(b)
  184. assert expand_mul(soln) == x
  185. A = Matrix(((9, 3), (3, 9)))
  186. x = Matrix((1, 1))
  187. b = A * x
  188. soln = A.LDLsolve(b)
  189. assert expand_mul(soln) == x
  190. A = Matrix([[-5, -3, -4], [-3, -7, 7]])
  191. x = Matrix([[8], [7], [-2]])
  192. b = A * x
  193. raises(NotImplementedError, lambda: A.LDLsolve(b))
  194. def test_lower_triangular_solve():
  195. raises(NonSquareMatrixError,
  196. lambda: Matrix([1, 0]).lower_triangular_solve(Matrix([0, 1])))
  197. raises(ShapeError,
  198. lambda: Matrix([[1, 0], [0, 1]]).lower_triangular_solve(Matrix([1])))
  199. raises(ValueError,
  200. lambda: Matrix([[2, 1], [1, 2]]).lower_triangular_solve(
  201. Matrix([[1, 0], [0, 1]])))
  202. A = Matrix([[1, 0], [0, 1]])
  203. B = Matrix([[x, y], [y, x]])
  204. C = Matrix([[4, 8], [2, 9]])
  205. assert A.lower_triangular_solve(B) == B
  206. assert A.lower_triangular_solve(C) == C
  207. def test_upper_triangular_solve():
  208. raises(NonSquareMatrixError,
  209. lambda: Matrix([1, 0]).upper_triangular_solve(Matrix([0, 1])))
  210. raises(ShapeError,
  211. lambda: Matrix([[1, 0], [0, 1]]).upper_triangular_solve(Matrix([1])))
  212. raises(TypeError,
  213. lambda: Matrix([[2, 1], [1, 2]]).upper_triangular_solve(
  214. Matrix([[1, 0], [0, 1]])))
  215. A = Matrix([[1, 0], [0, 1]])
  216. B = Matrix([[x, y], [y, x]])
  217. C = Matrix([[2, 4], [3, 8]])
  218. assert A.upper_triangular_solve(B) == B
  219. assert A.upper_triangular_solve(C) == C
  220. def test_diagonal_solve():
  221. raises(TypeError, lambda: Matrix([1, 1]).diagonal_solve(Matrix([1])))
  222. A = Matrix([[1, 0], [0, 1]])*2
  223. B = Matrix([[x, y], [y, x]])
  224. assert A.diagonal_solve(B) == B/2
  225. A = Matrix([[1, 0], [1, 2]])
  226. raises(TypeError, lambda: A.diagonal_solve(B))
  227. def test_pinv_solve():
  228. # Fully determined system (unique result, identical to other solvers).
  229. A = Matrix([[1, 5], [7, 9]])
  230. B = Matrix([12, 13])
  231. assert A.pinv_solve(B) == A.cholesky_solve(B)
  232. assert A.pinv_solve(B) == A.LDLsolve(B)
  233. assert A.pinv_solve(B) == Matrix([sympify('-43/26'), sympify('71/26')])
  234. assert A * A.pinv() * B == B
  235. # Fully determined, with two-dimensional B matrix.
  236. B = Matrix([[12, 13, 14], [15, 16, 17]])
  237. assert A.pinv_solve(B) == A.cholesky_solve(B)
  238. assert A.pinv_solve(B) == A.LDLsolve(B)
  239. assert A.pinv_solve(B) == Matrix([[-33, -37, -41], [69, 75, 81]]) / 26
  240. assert A * A.pinv() * B == B
  241. # Underdetermined system (infinite results).
  242. A = Matrix([[1, 0, 1], [0, 1, 1]])
  243. B = Matrix([5, 7])
  244. solution = A.pinv_solve(B)
  245. w = {}
  246. for s in solution.atoms(Symbol):
  247. # Extract dummy symbols used in the solution.
  248. w[s.name] = s
  249. assert solution == Matrix([[w['w0_0']/3 + w['w1_0']/3 - w['w2_0']/3 + 1],
  250. [w['w0_0']/3 + w['w1_0']/3 - w['w2_0']/3 + 3],
  251. [-w['w0_0']/3 - w['w1_0']/3 + w['w2_0']/3 + 4]])
  252. assert A * A.pinv() * B == B
  253. # Overdetermined system (least squares results).
  254. A = Matrix([[1, 0], [0, 0], [0, 1]])
  255. B = Matrix([3, 2, 1])
  256. assert A.pinv_solve(B) == Matrix([3, 1])
  257. # Proof the solution is not exact.
  258. assert A * A.pinv() * B != B
  259. def test_pinv_rank_deficient():
  260. # Test the four properties of the pseudoinverse for various matrices.
  261. As = [Matrix([[1, 1, 1], [2, 2, 2]]),
  262. Matrix([[1, 0], [0, 0]]),
  263. Matrix([[1, 2], [2, 4], [3, 6]])]
  264. for A in As:
  265. A_pinv = A.pinv(method="RD")
  266. AAp = A * A_pinv
  267. ApA = A_pinv * A
  268. assert simplify(AAp * A) == A
  269. assert simplify(ApA * A_pinv) == A_pinv
  270. assert AAp.H == AAp
  271. assert ApA.H == ApA
  272. for A in As:
  273. A_pinv = A.pinv(method="ED")
  274. AAp = A * A_pinv
  275. ApA = A_pinv * A
  276. assert simplify(AAp * A) == A
  277. assert simplify(ApA * A_pinv) == A_pinv
  278. assert AAp.H == AAp
  279. assert ApA.H == ApA
  280. # Test solving with rank-deficient matrices.
  281. A = Matrix([[1, 0], [0, 0]])
  282. # Exact, non-unique solution.
  283. B = Matrix([3, 0])
  284. solution = A.pinv_solve(B)
  285. w1 = solution.atoms(Symbol).pop()
  286. assert w1.name == 'w1_0'
  287. assert solution == Matrix([3, w1])
  288. assert A * A.pinv() * B == B
  289. # Least squares, non-unique solution.
  290. B = Matrix([3, 1])
  291. solution = A.pinv_solve(B)
  292. w1 = solution.atoms(Symbol).pop()
  293. assert w1.name == 'w1_0'
  294. assert solution == Matrix([3, w1])
  295. assert A * A.pinv() * B != B
  296. def test_gauss_jordan_solve():
  297. # Square, full rank, unique solution
  298. A = Matrix([[1, 2, 3], [4, 5, 6], [7, 8, 10]])
  299. b = Matrix([3, 6, 9])
  300. sol, params = A.gauss_jordan_solve(b)
  301. assert sol == Matrix([[-1], [2], [0]])
  302. assert params == Matrix(0, 1, [])
  303. # Square, full rank, unique solution, B has more columns than rows
  304. A = eye(3)
  305. B = Matrix([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]])
  306. sol, params = A.gauss_jordan_solve(B)
  307. assert sol == B
  308. assert params == Matrix(0, 4, [])
  309. # Square, reduced rank, parametrized solution
  310. A = Matrix([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
  311. b = Matrix([3, 6, 9])
  312. sol, params, freevar = A.gauss_jordan_solve(b, freevar=True)
  313. w = {}
  314. for s in sol.atoms(Symbol):
  315. # Extract dummy symbols used in the solution.
  316. w[s.name] = s
  317. assert sol == Matrix([[w['tau0'] - 1], [-2*w['tau0'] + 2], [w['tau0']]])
  318. assert params == Matrix([[w['tau0']]])
  319. assert freevar == [2]
  320. # Square, reduced rank, parametrized solution, B has two columns
  321. A = Matrix([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
  322. B = Matrix([[3, 4], [6, 8], [9, 12]])
  323. sol, params, freevar = A.gauss_jordan_solve(B, freevar=True)
  324. w = {}
  325. for s in sol.atoms(Symbol):
  326. # Extract dummy symbols used in the solution.
  327. w[s.name] = s
  328. assert sol == Matrix([[w['tau0'] - 1, w['tau1'] - Rational(4, 3)],
  329. [-2*w['tau0'] + 2, -2*w['tau1'] + Rational(8, 3)],
  330. [w['tau0'], w['tau1']],])
  331. assert params == Matrix([[w['tau0'], w['tau1']]])
  332. assert freevar == [2]
  333. # Square, reduced rank, parametrized solution
  334. A = Matrix([[1, 2, 3], [2, 4, 6], [3, 6, 9]])
  335. b = Matrix([0, 0, 0])
  336. sol, params = A.gauss_jordan_solve(b)
  337. w = {}
  338. for s in sol.atoms(Symbol):
  339. w[s.name] = s
  340. assert sol == Matrix([[-2*w['tau0'] - 3*w['tau1']],
  341. [w['tau0']], [w['tau1']]])
  342. assert params == Matrix([[w['tau0']], [w['tau1']]])
  343. # Square, reduced rank, parametrized solution
  344. A = Matrix([[0, 0, 0], [0, 0, 0], [0, 0, 0]])
  345. b = Matrix([0, 0, 0])
  346. sol, params = A.gauss_jordan_solve(b)
  347. w = {}
  348. for s in sol.atoms(Symbol):
  349. w[s.name] = s
  350. assert sol == Matrix([[w['tau0']], [w['tau1']], [w['tau2']]])
  351. assert params == Matrix([[w['tau0']], [w['tau1']], [w['tau2']]])
  352. # Square, reduced rank, no solution
  353. A = Matrix([[1, 2, 3], [2, 4, 6], [3, 6, 9]])
  354. b = Matrix([0, 0, 1])
  355. raises(ValueError, lambda: A.gauss_jordan_solve(b))
  356. # Rectangular, tall, full rank, unique solution
  357. A = Matrix([[1, 5, 3], [2, 1, 6], [1, 7, 9], [1, 4, 3]])
  358. b = Matrix([0, 0, 1, 0])
  359. sol, params = A.gauss_jordan_solve(b)
  360. assert sol == Matrix([[Rational(-1, 2)], [0], [Rational(1, 6)]])
  361. assert params == Matrix(0, 1, [])
  362. # Rectangular, tall, full rank, unique solution, B has less columns than rows
  363. A = Matrix([[1, 5, 3], [2, 1, 6], [1, 7, 9], [1, 4, 3]])
  364. B = Matrix([[0,0], [0, 0], [1, 2], [0, 0]])
  365. sol, params = A.gauss_jordan_solve(B)
  366. assert sol == Matrix([[Rational(-1, 2), Rational(-2, 2)], [0, 0], [Rational(1, 6), Rational(2, 6)]])
  367. assert params == Matrix(0, 2, [])
  368. # Rectangular, tall, full rank, no solution
  369. A = Matrix([[1, 5, 3], [2, 1, 6], [1, 7, 9], [1, 4, 3]])
  370. b = Matrix([0, 0, 0, 1])
  371. raises(ValueError, lambda: A.gauss_jordan_solve(b))
  372. # Rectangular, tall, full rank, no solution, B has two columns (2nd has no solution)
  373. A = Matrix([[1, 5, 3], [2, 1, 6], [1, 7, 9], [1, 4, 3]])
  374. B = Matrix([[0,0], [0, 0], [1, 0], [0, 1]])
  375. raises(ValueError, lambda: A.gauss_jordan_solve(B))
  376. # Rectangular, tall, full rank, no solution, B has two columns (1st has no solution)
  377. A = Matrix([[1, 5, 3], [2, 1, 6], [1, 7, 9], [1, 4, 3]])
  378. B = Matrix([[0,0], [0, 0], [0, 1], [1, 0]])
  379. raises(ValueError, lambda: A.gauss_jordan_solve(B))
  380. # Rectangular, tall, reduced rank, parametrized solution
  381. A = Matrix([[1, 5, 3], [2, 10, 6], [3, 15, 9], [1, 4, 3]])
  382. b = Matrix([0, 0, 0, 1])
  383. sol, params = A.gauss_jordan_solve(b)
  384. w = {}
  385. for s in sol.atoms(Symbol):
  386. w[s.name] = s
  387. assert sol == Matrix([[-3*w['tau0'] + 5], [-1], [w['tau0']]])
  388. assert params == Matrix([[w['tau0']]])
  389. # Rectangular, tall, reduced rank, no solution
  390. A = Matrix([[1, 5, 3], [2, 10, 6], [3, 15, 9], [1, 4, 3]])
  391. b = Matrix([0, 0, 1, 1])
  392. raises(ValueError, lambda: A.gauss_jordan_solve(b))
  393. # Rectangular, wide, full rank, parametrized solution
  394. A = Matrix([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 1, 12]])
  395. b = Matrix([1, 1, 1])
  396. sol, params = A.gauss_jordan_solve(b)
  397. w = {}
  398. for s in sol.atoms(Symbol):
  399. w[s.name] = s
  400. assert sol == Matrix([[2*w['tau0'] - 1], [-3*w['tau0'] + 1], [0],
  401. [w['tau0']]])
  402. assert params == Matrix([[w['tau0']]])
  403. # Rectangular, wide, reduced rank, parametrized solution
  404. A = Matrix([[1, 2, 3, 4], [5, 6, 7, 8], [2, 4, 6, 8]])
  405. b = Matrix([0, 1, 0])
  406. sol, params = A.gauss_jordan_solve(b)
  407. w = {}
  408. for s in sol.atoms(Symbol):
  409. w[s.name] = s
  410. assert sol == Matrix([[w['tau0'] + 2*w['tau1'] + S.Half],
  411. [-2*w['tau0'] - 3*w['tau1'] - Rational(1, 4)],
  412. [w['tau0']], [w['tau1']]])
  413. assert params == Matrix([[w['tau0']], [w['tau1']]])
  414. # watch out for clashing symbols
  415. x0, x1, x2, _x0 = symbols('_tau0 _tau1 _tau2 tau1')
  416. M = Matrix([[0, 1, 0, 0, 0, 0], [0, 0, 0, 1, 0, _x0]])
  417. A = M[:, :-1]
  418. b = M[:, -1:]
  419. sol, params = A.gauss_jordan_solve(b)
  420. assert params == Matrix(3, 1, [x0, x1, x2])
  421. assert sol == Matrix(5, 1, [x0, 0, x1, _x0, x2])
  422. # Rectangular, wide, reduced rank, no solution
  423. A = Matrix([[1, 2, 3, 4], [5, 6, 7, 8], [2, 4, 6, 8]])
  424. b = Matrix([1, 1, 1])
  425. raises(ValueError, lambda: A.gauss_jordan_solve(b))
  426. # Test for immutable matrix
  427. A = ImmutableMatrix([[1, 0], [0, 1]])
  428. B = ImmutableMatrix([1, 2])
  429. sol, params = A.gauss_jordan_solve(B)
  430. assert sol == ImmutableMatrix([1, 2])
  431. assert params == ImmutableMatrix(0, 1, [])
  432. assert sol.__class__ == ImmutableDenseMatrix
  433. assert params.__class__ == ImmutableDenseMatrix
  434. # Test placement of free variables
  435. A = Matrix([[1, 0, 0, 0], [0, 0, 0, 1]])
  436. b = Matrix([1, 1])
  437. sol, params = A.gauss_jordan_solve(b)
  438. w = {}
  439. for s in sol.atoms(Symbol):
  440. w[s.name] = s
  441. assert sol == Matrix([[1], [w['tau0']], [w['tau1']], [1]])
  442. assert params == Matrix([[w['tau0']], [w['tau1']]])
  443. def test_linsolve_underdetermined_AND_gauss_jordan_solve():
  444. #Test placement of free variables as per issue 19815
  445. A = Matrix([[1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
  446. [1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0],
  447. [0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0],
  448. [0, 1, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0],
  449. [0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0],
  450. [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0],
  451. [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0],
  452. [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1]])
  453. B = Matrix([1, 2, 1, 1, 1, 1, 1, 2])
  454. sol, params = A.gauss_jordan_solve(B)
  455. w = {}
  456. for s in sol.atoms(Symbol):
  457. w[s.name] = s
  458. assert params == Matrix([[w['tau0']], [w['tau1']], [w['tau2']],
  459. [w['tau3']], [w['tau4']], [w['tau5']]])
  460. assert sol == Matrix([[1 - 1*w['tau2']],
  461. [w['tau2']],
  462. [1 - 1*w['tau0'] + w['tau1']],
  463. [w['tau0']],
  464. [w['tau3'] + w['tau4']],
  465. [-1*w['tau3'] - 1*w['tau4'] - 1*w['tau1']],
  466. [1 - 1*w['tau2']],
  467. [w['tau1']],
  468. [w['tau2']],
  469. [w['tau3']],
  470. [w['tau4']],
  471. [1 - 1*w['tau5']],
  472. [w['tau5']],
  473. [1]])
  474. from sympy.abc import j,f
  475. # https://github.com/sympy/sympy/issues/20046
  476. A = Matrix([
  477. [1, 1, 1, 1, 1, 1, 1, 1, 1],
  478. [0, -1, 0, -1, 0, -1, 0, -1, -j],
  479. [0, 0, 0, 0, 1, 1, 1, 1, f]
  480. ])
  481. sol_1=Matrix(list(linsolve(A))[0])
  482. tau0, tau1, tau2, tau3, tau4 = symbols('tau:5')
  483. assert sol_1 == Matrix([[-f - j - tau0 + tau2 + tau4 + 1],
  484. [j - tau1 - tau2 - tau4],
  485. [tau0],
  486. [tau1],
  487. [f - tau2 - tau3 - tau4],
  488. [tau2],
  489. [tau3],
  490. [tau4]])
  491. # https://github.com/sympy/sympy/issues/19815
  492. sol_2 = A[:, : -1 ] * sol_1 - A[:, -1 ]
  493. assert sol_2 == Matrix([[0], [0], [0]])
  494. def test_solve():
  495. A = Matrix([[1,2], [2,4]])
  496. b = Matrix([[3], [4]])
  497. raises(ValueError, lambda: A.solve(b)) #no solution
  498. b = Matrix([[ 4], [8]])
  499. raises(ValueError, lambda: A.solve(b)) #infinite solution