blockmatrix.py 31 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979
  1. from sympy.assumptions.ask import (Q, ask)
  2. from sympy.core import Basic, Add, Mul, S
  3. from sympy.core.sympify import _sympify
  4. from sympy.functions import adjoint
  5. from sympy.functions.elementary.complexes import re, im
  6. from sympy.strategies import typed, exhaust, condition, do_one, unpack
  7. from sympy.strategies.traverse import bottom_up
  8. from sympy.utilities.iterables import is_sequence, sift
  9. from sympy.utilities.misc import filldedent
  10. from sympy.matrices import Matrix, ShapeError
  11. from sympy.matrices.common import NonInvertibleMatrixError
  12. from sympy.matrices.expressions.determinant import det, Determinant
  13. from sympy.matrices.expressions.inverse import Inverse
  14. from sympy.matrices.expressions.matadd import MatAdd
  15. from sympy.matrices.expressions.matexpr import MatrixExpr, MatrixElement
  16. from sympy.matrices.expressions.matmul import MatMul
  17. from sympy.matrices.expressions.matpow import MatPow
  18. from sympy.matrices.expressions.slice import MatrixSlice
  19. from sympy.matrices.expressions.special import ZeroMatrix, Identity
  20. from sympy.matrices.expressions.trace import trace
  21. from sympy.matrices.expressions.transpose import Transpose, transpose
  22. class BlockMatrix(MatrixExpr):
  23. """A BlockMatrix is a Matrix comprised of other matrices.
  24. The submatrices are stored in a SymPy Matrix object but accessed as part of
  25. a Matrix Expression
  26. >>> from sympy import (MatrixSymbol, BlockMatrix, symbols,
  27. ... Identity, ZeroMatrix, block_collapse)
  28. >>> n,m,l = symbols('n m l')
  29. >>> X = MatrixSymbol('X', n, n)
  30. >>> Y = MatrixSymbol('Y', m, m)
  31. >>> Z = MatrixSymbol('Z', n, m)
  32. >>> B = BlockMatrix([[X, Z], [ZeroMatrix(m,n), Y]])
  33. >>> print(B)
  34. Matrix([
  35. [X, Z],
  36. [0, Y]])
  37. >>> C = BlockMatrix([[Identity(n), Z]])
  38. >>> print(C)
  39. Matrix([[I, Z]])
  40. >>> print(block_collapse(C*B))
  41. Matrix([[X, Z + Z*Y]])
  42. Some matrices might be comprised of rows of blocks with
  43. the matrices in each row having the same height and the
  44. rows all having the same total number of columns but
  45. not having the same number of columns for each matrix
  46. in each row. In this case, the matrix is not a block
  47. matrix and should be instantiated by Matrix.
  48. >>> from sympy import ones, Matrix
  49. >>> dat = [
  50. ... [ones(3,2), ones(3,3)*2],
  51. ... [ones(2,3)*3, ones(2,2)*4]]
  52. ...
  53. >>> BlockMatrix(dat)
  54. Traceback (most recent call last):
  55. ...
  56. ValueError:
  57. Although this matrix is comprised of blocks, the blocks do not fill
  58. the matrix in a size-symmetric fashion. To create a full matrix from
  59. these arguments, pass them directly to Matrix.
  60. >>> Matrix(dat)
  61. Matrix([
  62. [1, 1, 2, 2, 2],
  63. [1, 1, 2, 2, 2],
  64. [1, 1, 2, 2, 2],
  65. [3, 3, 3, 4, 4],
  66. [3, 3, 3, 4, 4]])
  67. See Also
  68. ========
  69. sympy.matrices.matrices.MatrixBase.irregular
  70. """
  71. def __new__(cls, *args, **kwargs):
  72. from sympy.matrices.immutable import ImmutableDenseMatrix
  73. isMat = lambda i: getattr(i, 'is_Matrix', False)
  74. if len(args) != 1 or \
  75. not is_sequence(args[0]) or \
  76. len({isMat(r) for r in args[0]}) != 1:
  77. raise ValueError(filldedent('''
  78. expecting a sequence of 1 or more rows
  79. containing Matrices.'''))
  80. rows = args[0] if args else []
  81. if not isMat(rows):
  82. if rows and isMat(rows[0]):
  83. rows = [rows] # rows is not list of lists or []
  84. # regularity check
  85. # same number of matrices in each row
  86. blocky = ok = len({len(r) for r in rows}) == 1
  87. if ok:
  88. # same number of rows for each matrix in a row
  89. for r in rows:
  90. ok = len({i.rows for i in r}) == 1
  91. if not ok:
  92. break
  93. blocky = ok
  94. if ok:
  95. # same number of cols for each matrix in each col
  96. for c in range(len(rows[0])):
  97. ok = len({rows[i][c].cols
  98. for i in range(len(rows))}) == 1
  99. if not ok:
  100. break
  101. if not ok:
  102. # same total cols in each row
  103. ok = len({
  104. sum([i.cols for i in r]) for r in rows}) == 1
  105. if blocky and ok:
  106. raise ValueError(filldedent('''
  107. Although this matrix is comprised of blocks,
  108. the blocks do not fill the matrix in a
  109. size-symmetric fashion. To create a full matrix
  110. from these arguments, pass them directly to
  111. Matrix.'''))
  112. raise ValueError(filldedent('''
  113. When there are not the same number of rows in each
  114. row's matrices or there are not the same number of
  115. total columns in each row, the matrix is not a
  116. block matrix. If this matrix is known to consist of
  117. blocks fully filling a 2-D space then see
  118. Matrix.irregular.'''))
  119. mat = ImmutableDenseMatrix(rows, evaluate=False)
  120. obj = Basic.__new__(cls, mat)
  121. return obj
  122. @property
  123. def shape(self):
  124. numrows = numcols = 0
  125. M = self.blocks
  126. for i in range(M.shape[0]):
  127. numrows += M[i, 0].shape[0]
  128. for i in range(M.shape[1]):
  129. numcols += M[0, i].shape[1]
  130. return (numrows, numcols)
  131. @property
  132. def blockshape(self):
  133. return self.blocks.shape
  134. @property
  135. def blocks(self):
  136. return self.args[0]
  137. @property
  138. def rowblocksizes(self):
  139. return [self.blocks[i, 0].rows for i in range(self.blockshape[0])]
  140. @property
  141. def colblocksizes(self):
  142. return [self.blocks[0, i].cols for i in range(self.blockshape[1])]
  143. def structurally_equal(self, other):
  144. return (isinstance(other, BlockMatrix)
  145. and self.shape == other.shape
  146. and self.blockshape == other.blockshape
  147. and self.rowblocksizes == other.rowblocksizes
  148. and self.colblocksizes == other.colblocksizes)
  149. def _blockmul(self, other):
  150. if (isinstance(other, BlockMatrix) and
  151. self.colblocksizes == other.rowblocksizes):
  152. return BlockMatrix(self.blocks*other.blocks)
  153. return self * other
  154. def _blockadd(self, other):
  155. if (isinstance(other, BlockMatrix)
  156. and self.structurally_equal(other)):
  157. return BlockMatrix(self.blocks + other.blocks)
  158. return self + other
  159. def _eval_transpose(self):
  160. # Flip all the individual matrices
  161. matrices = [transpose(matrix) for matrix in self.blocks]
  162. # Make a copy
  163. M = Matrix(self.blockshape[0], self.blockshape[1], matrices)
  164. # Transpose the block structure
  165. M = M.transpose()
  166. return BlockMatrix(M)
  167. def _eval_adjoint(self):
  168. # Adjoint all the individual matrices
  169. matrices = [adjoint(matrix) for matrix in self.blocks]
  170. # Make a copy
  171. M = Matrix(self.blockshape[0], self.blockshape[1], matrices)
  172. # Transpose the block structure
  173. M = M.transpose()
  174. return BlockMatrix(M)
  175. def _eval_trace(self):
  176. if self.rowblocksizes == self.colblocksizes:
  177. return Add(*[trace(self.blocks[i, i])
  178. for i in range(self.blockshape[0])])
  179. raise NotImplementedError(
  180. "Can't perform trace of irregular blockshape")
  181. def _eval_determinant(self):
  182. if self.blockshape == (1, 1):
  183. return det(self.blocks[0, 0])
  184. if self.blockshape == (2, 2):
  185. [[A, B],
  186. [C, D]] = self.blocks.tolist()
  187. if ask(Q.invertible(A)):
  188. return det(A)*det(D - C*A.I*B)
  189. elif ask(Q.invertible(D)):
  190. return det(D)*det(A - B*D.I*C)
  191. return Determinant(self)
  192. def _eval_as_real_imag(self):
  193. real_matrices = [re(matrix) for matrix in self.blocks]
  194. real_matrices = Matrix(self.blockshape[0], self.blockshape[1], real_matrices)
  195. im_matrices = [im(matrix) for matrix in self.blocks]
  196. im_matrices = Matrix(self.blockshape[0], self.blockshape[1], im_matrices)
  197. return (BlockMatrix(real_matrices), BlockMatrix(im_matrices))
  198. def transpose(self):
  199. """Return transpose of matrix.
  200. Examples
  201. ========
  202. >>> from sympy import MatrixSymbol, BlockMatrix, ZeroMatrix
  203. >>> from sympy.abc import m, n
  204. >>> X = MatrixSymbol('X', n, n)
  205. >>> Y = MatrixSymbol('Y', m, m)
  206. >>> Z = MatrixSymbol('Z', n, m)
  207. >>> B = BlockMatrix([[X, Z], [ZeroMatrix(m,n), Y]])
  208. >>> B.transpose()
  209. Matrix([
  210. [X.T, 0],
  211. [Z.T, Y.T]])
  212. >>> _.transpose()
  213. Matrix([
  214. [X, Z],
  215. [0, Y]])
  216. """
  217. return self._eval_transpose()
  218. def schur(self, mat = 'A', generalized = False):
  219. """Return the Schur Complement of the 2x2 BlockMatrix
  220. Parameters
  221. ==========
  222. mat : String, optional
  223. The matrix with respect to which the
  224. Schur Complement is calculated. 'A' is
  225. used by default
  226. generalized : bool, optional
  227. If True, returns the generalized Schur
  228. Component which uses Moore-Penrose Inverse
  229. Examples
  230. ========
  231. >>> from sympy import symbols, MatrixSymbol, BlockMatrix
  232. >>> m, n = symbols('m n')
  233. >>> A = MatrixSymbol('A', n, n)
  234. >>> B = MatrixSymbol('B', n, m)
  235. >>> C = MatrixSymbol('C', m, n)
  236. >>> D = MatrixSymbol('D', m, m)
  237. >>> X = BlockMatrix([[A, B], [C, D]])
  238. The default Schur Complement is evaluated with "A"
  239. >>> X.schur()
  240. -C*A**(-1)*B + D
  241. >>> X.schur('D')
  242. A - B*D**(-1)*C
  243. Schur complement with non-invertible matrices is not
  244. defined. Instead, the generalized Schur complement can
  245. be calculated which uses the Moore-Penrose Inverse. To
  246. achieve this, `generalized` must be set to `True`
  247. >>> X.schur('B', generalized=True)
  248. C - D*(B.T*B)**(-1)*B.T*A
  249. >>> X.schur('C', generalized=True)
  250. -A*(C.T*C)**(-1)*C.T*D + B
  251. Returns
  252. =======
  253. M : Matrix
  254. The Schur Complement Matrix
  255. Raises
  256. ======
  257. ShapeError
  258. If the block matrix is not a 2x2 matrix
  259. NonInvertibleMatrixError
  260. If given matrix is non-invertible
  261. References
  262. ==========
  263. .. [1] Wikipedia Article on Schur Component : https://en.wikipedia.org/wiki/Schur_complement
  264. See Also
  265. ========
  266. sympy.matrices.matrices.MatrixBase.pinv
  267. """
  268. if self.blockshape == (2, 2):
  269. [[A, B],
  270. [C, D]] = self.blocks.tolist()
  271. d={'A' : A, 'B' : B, 'C' : C, 'D' : D}
  272. try:
  273. inv = (d[mat].T*d[mat]).inv()*d[mat].T if generalized else d[mat].inv()
  274. if mat == 'A':
  275. return D - C * inv * B
  276. elif mat == 'B':
  277. return C - D * inv * A
  278. elif mat == 'C':
  279. return B - A * inv * D
  280. elif mat == 'D':
  281. return A - B * inv * C
  282. #For matrices where no sub-matrix is square
  283. return self
  284. except NonInvertibleMatrixError:
  285. raise NonInvertibleMatrixError('The given matrix is not invertible. Please set generalized=True \
  286. to compute the generalized Schur Complement which uses Moore-Penrose Inverse')
  287. else:
  288. raise ShapeError('Schur Complement can only be calculated for 2x2 block matrices')
  289. def LDUdecomposition(self):
  290. """Returns the Block LDU decomposition of
  291. a 2x2 Block Matrix
  292. Returns
  293. =======
  294. (L, D, U) : Matrices
  295. L : Lower Diagonal Matrix
  296. D : Diagonal Matrix
  297. U : Upper Diagonal Matrix
  298. Examples
  299. ========
  300. >>> from sympy import symbols, MatrixSymbol, BlockMatrix, block_collapse
  301. >>> m, n = symbols('m n')
  302. >>> A = MatrixSymbol('A', n, n)
  303. >>> B = MatrixSymbol('B', n, m)
  304. >>> C = MatrixSymbol('C', m, n)
  305. >>> D = MatrixSymbol('D', m, m)
  306. >>> X = BlockMatrix([[A, B], [C, D]])
  307. >>> L, D, U = X.LDUdecomposition()
  308. >>> block_collapse(L*D*U)
  309. Matrix([
  310. [A, B],
  311. [C, D]])
  312. Raises
  313. ======
  314. ShapeError
  315. If the block matrix is not a 2x2 matrix
  316. NonInvertibleMatrixError
  317. If the matrix "A" is non-invertible
  318. See Also
  319. ========
  320. sympy.matrices.expressions.blockmatrix.BlockMatrix.UDLdecomposition
  321. sympy.matrices.expressions.blockmatrix.BlockMatrix.LUdecomposition
  322. """
  323. if self.blockshape == (2,2):
  324. [[A, B],
  325. [C, D]] = self.blocks.tolist()
  326. try:
  327. AI = A.I
  328. except NonInvertibleMatrixError:
  329. raise NonInvertibleMatrixError('Block LDU decomposition cannot be calculated when\
  330. "A" is singular')
  331. Ip = Identity(B.shape[0])
  332. Iq = Identity(B.shape[1])
  333. Z = ZeroMatrix(*B.shape)
  334. L = BlockMatrix([[Ip, Z], [C*AI, Iq]])
  335. D = BlockDiagMatrix(A, self.schur())
  336. U = BlockMatrix([[Ip, AI*B],[Z.T, Iq]])
  337. return L, D, U
  338. else:
  339. raise ShapeError("Block LDU decomposition is supported only for 2x2 block matrices")
  340. def UDLdecomposition(self):
  341. """Returns the Block UDL decomposition of
  342. a 2x2 Block Matrix
  343. Returns
  344. =======
  345. (U, D, L) : Matrices
  346. U : Upper Diagonal Matrix
  347. D : Diagonal Matrix
  348. L : Lower Diagonal Matrix
  349. Examples
  350. ========
  351. >>> from sympy import symbols, MatrixSymbol, BlockMatrix, block_collapse
  352. >>> m, n = symbols('m n')
  353. >>> A = MatrixSymbol('A', n, n)
  354. >>> B = MatrixSymbol('B', n, m)
  355. >>> C = MatrixSymbol('C', m, n)
  356. >>> D = MatrixSymbol('D', m, m)
  357. >>> X = BlockMatrix([[A, B], [C, D]])
  358. >>> U, D, L = X.UDLdecomposition()
  359. >>> block_collapse(U*D*L)
  360. Matrix([
  361. [A, B],
  362. [C, D]])
  363. Raises
  364. ======
  365. ShapeError
  366. If the block matrix is not a 2x2 matrix
  367. NonInvertibleMatrixError
  368. If the matrix "D" is non-invertible
  369. See Also
  370. ========
  371. sympy.matrices.expressions.blockmatrix.BlockMatrix.LDUdecomposition
  372. sympy.matrices.expressions.blockmatrix.BlockMatrix.LUdecomposition
  373. """
  374. if self.blockshape == (2,2):
  375. [[A, B],
  376. [C, D]] = self.blocks.tolist()
  377. try:
  378. DI = D.I
  379. except NonInvertibleMatrixError:
  380. raise NonInvertibleMatrixError('Block UDL decomposition cannot be calculated when\
  381. "D" is singular')
  382. Ip = Identity(A.shape[0])
  383. Iq = Identity(B.shape[1])
  384. Z = ZeroMatrix(*B.shape)
  385. U = BlockMatrix([[Ip, B*DI], [Z.T, Iq]])
  386. D = BlockDiagMatrix(self.schur('D'), D)
  387. L = BlockMatrix([[Ip, Z],[DI*C, Iq]])
  388. return U, D, L
  389. else:
  390. raise ShapeError("Block UDL decomposition is supported only for 2x2 block matrices")
  391. def LUdecomposition(self):
  392. """Returns the Block LU decomposition of
  393. a 2x2 Block Matrix
  394. Returns
  395. =======
  396. (L, U) : Matrices
  397. L : Lower Diagonal Matrix
  398. U : Upper Diagonal Matrix
  399. Examples
  400. ========
  401. >>> from sympy import symbols, MatrixSymbol, BlockMatrix, block_collapse
  402. >>> m, n = symbols('m n')
  403. >>> A = MatrixSymbol('A', n, n)
  404. >>> B = MatrixSymbol('B', n, m)
  405. >>> C = MatrixSymbol('C', m, n)
  406. >>> D = MatrixSymbol('D', m, m)
  407. >>> X = BlockMatrix([[A, B], [C, D]])
  408. >>> L, U = X.LUdecomposition()
  409. >>> block_collapse(L*U)
  410. Matrix([
  411. [A, B],
  412. [C, D]])
  413. Raises
  414. ======
  415. ShapeError
  416. If the block matrix is not a 2x2 matrix
  417. NonInvertibleMatrixError
  418. If the matrix "A" is non-invertible
  419. See Also
  420. ========
  421. sympy.matrices.expressions.blockmatrix.BlockMatrix.UDLdecomposition
  422. sympy.matrices.expressions.blockmatrix.BlockMatrix.LDUdecomposition
  423. """
  424. if self.blockshape == (2,2):
  425. [[A, B],
  426. [C, D]] = self.blocks.tolist()
  427. try:
  428. A = A**0.5
  429. AI = A.I
  430. except NonInvertibleMatrixError:
  431. raise NonInvertibleMatrixError('Block LU decomposition cannot be calculated when\
  432. "A" is singular')
  433. Z = ZeroMatrix(*B.shape)
  434. Q = self.schur()**0.5
  435. L = BlockMatrix([[A, Z], [C*AI, Q]])
  436. U = BlockMatrix([[A, AI*B],[Z.T, Q]])
  437. return L, U
  438. else:
  439. raise ShapeError("Block LU decomposition is supported only for 2x2 block matrices")
  440. def _entry(self, i, j, **kwargs):
  441. # Find row entry
  442. orig_i, orig_j = i, j
  443. for row_block, numrows in enumerate(self.rowblocksizes):
  444. cmp = i < numrows
  445. if cmp == True:
  446. break
  447. elif cmp == False:
  448. i -= numrows
  449. elif row_block < self.blockshape[0] - 1:
  450. # Can't tell which block and it's not the last one, return unevaluated
  451. return MatrixElement(self, orig_i, orig_j)
  452. for col_block, numcols in enumerate(self.colblocksizes):
  453. cmp = j < numcols
  454. if cmp == True:
  455. break
  456. elif cmp == False:
  457. j -= numcols
  458. elif col_block < self.blockshape[1] - 1:
  459. return MatrixElement(self, orig_i, orig_j)
  460. return self.blocks[row_block, col_block][i, j]
  461. @property
  462. def is_Identity(self):
  463. if self.blockshape[0] != self.blockshape[1]:
  464. return False
  465. for i in range(self.blockshape[0]):
  466. for j in range(self.blockshape[1]):
  467. if i==j and not self.blocks[i, j].is_Identity:
  468. return False
  469. if i!=j and not self.blocks[i, j].is_ZeroMatrix:
  470. return False
  471. return True
  472. @property
  473. def is_structurally_symmetric(self):
  474. return self.rowblocksizes == self.colblocksizes
  475. def equals(self, other):
  476. if self == other:
  477. return True
  478. if (isinstance(other, BlockMatrix) and self.blocks == other.blocks):
  479. return True
  480. return super().equals(other)
  481. class BlockDiagMatrix(BlockMatrix):
  482. """A sparse matrix with block matrices along its diagonals
  483. Examples
  484. ========
  485. >>> from sympy import MatrixSymbol, BlockDiagMatrix, symbols
  486. >>> n, m, l = symbols('n m l')
  487. >>> X = MatrixSymbol('X', n, n)
  488. >>> Y = MatrixSymbol('Y', m, m)
  489. >>> BlockDiagMatrix(X, Y)
  490. Matrix([
  491. [X, 0],
  492. [0, Y]])
  493. Notes
  494. =====
  495. If you want to get the individual diagonal blocks, use
  496. :meth:`get_diag_blocks`.
  497. See Also
  498. ========
  499. sympy.matrices.dense.diag
  500. """
  501. def __new__(cls, *mats):
  502. return Basic.__new__(BlockDiagMatrix, *[_sympify(m) for m in mats])
  503. @property
  504. def diag(self):
  505. return self.args
  506. @property
  507. def blocks(self):
  508. from sympy.matrices.immutable import ImmutableDenseMatrix
  509. mats = self.args
  510. data = [[mats[i] if i == j else ZeroMatrix(mats[i].rows, mats[j].cols)
  511. for j in range(len(mats))]
  512. for i in range(len(mats))]
  513. return ImmutableDenseMatrix(data, evaluate=False)
  514. @property
  515. def shape(self):
  516. return (sum(block.rows for block in self.args),
  517. sum(block.cols for block in self.args))
  518. @property
  519. def blockshape(self):
  520. n = len(self.args)
  521. return (n, n)
  522. @property
  523. def rowblocksizes(self):
  524. return [block.rows for block in self.args]
  525. @property
  526. def colblocksizes(self):
  527. return [block.cols for block in self.args]
  528. def _all_square_blocks(self):
  529. """Returns true if all blocks are square"""
  530. return all(mat.is_square for mat in self.args)
  531. def _eval_determinant(self):
  532. if self._all_square_blocks():
  533. return Mul(*[det(mat) for mat in self.args])
  534. # At least one block is non-square. Since the entire matrix must be square we know there must
  535. # be at least two blocks in this matrix, in which case the entire matrix is necessarily rank-deficient
  536. return S.Zero
  537. def _eval_inverse(self, expand='ignored'):
  538. if self._all_square_blocks():
  539. return BlockDiagMatrix(*[mat.inverse() for mat in self.args])
  540. # See comment in _eval_determinant()
  541. raise NonInvertibleMatrixError('Matrix det == 0; not invertible.')
  542. def _eval_transpose(self):
  543. return BlockDiagMatrix(*[mat.transpose() for mat in self.args])
  544. def _blockmul(self, other):
  545. if (isinstance(other, BlockDiagMatrix) and
  546. self.colblocksizes == other.rowblocksizes):
  547. return BlockDiagMatrix(*[a*b for a, b in zip(self.args, other.args)])
  548. else:
  549. return BlockMatrix._blockmul(self, other)
  550. def _blockadd(self, other):
  551. if (isinstance(other, BlockDiagMatrix) and
  552. self.blockshape == other.blockshape and
  553. self.rowblocksizes == other.rowblocksizes and
  554. self.colblocksizes == other.colblocksizes):
  555. return BlockDiagMatrix(*[a + b for a, b in zip(self.args, other.args)])
  556. else:
  557. return BlockMatrix._blockadd(self, other)
  558. def get_diag_blocks(self):
  559. """Return the list of diagonal blocks of the matrix.
  560. Examples
  561. ========
  562. >>> from sympy import BlockDiagMatrix, Matrix
  563. >>> A = Matrix([[1, 2], [3, 4]])
  564. >>> B = Matrix([[5, 6], [7, 8]])
  565. >>> M = BlockDiagMatrix(A, B)
  566. How to get diagonal blocks from the block diagonal matrix:
  567. >>> diag_blocks = M.get_diag_blocks()
  568. >>> diag_blocks[0]
  569. Matrix([
  570. [1, 2],
  571. [3, 4]])
  572. >>> diag_blocks[1]
  573. Matrix([
  574. [5, 6],
  575. [7, 8]])
  576. """
  577. return self.args
  578. def block_collapse(expr):
  579. """Evaluates a block matrix expression
  580. >>> from sympy import MatrixSymbol, BlockMatrix, symbols, Identity, ZeroMatrix, block_collapse
  581. >>> n,m,l = symbols('n m l')
  582. >>> X = MatrixSymbol('X', n, n)
  583. >>> Y = MatrixSymbol('Y', m, m)
  584. >>> Z = MatrixSymbol('Z', n, m)
  585. >>> B = BlockMatrix([[X, Z], [ZeroMatrix(m, n), Y]])
  586. >>> print(B)
  587. Matrix([
  588. [X, Z],
  589. [0, Y]])
  590. >>> C = BlockMatrix([[Identity(n), Z]])
  591. >>> print(C)
  592. Matrix([[I, Z]])
  593. >>> print(block_collapse(C*B))
  594. Matrix([[X, Z + Z*Y]])
  595. """
  596. from sympy.strategies.util import expr_fns
  597. hasbm = lambda expr: isinstance(expr, MatrixExpr) and expr.has(BlockMatrix)
  598. conditioned_rl = condition(
  599. hasbm,
  600. typed(
  601. {MatAdd: do_one(bc_matadd, bc_block_plus_ident),
  602. MatMul: do_one(bc_matmul, bc_dist),
  603. MatPow: bc_matmul,
  604. Transpose: bc_transpose,
  605. Inverse: bc_inverse,
  606. BlockMatrix: do_one(bc_unpack, deblock)}
  607. )
  608. )
  609. rule = exhaust(
  610. bottom_up(
  611. exhaust(conditioned_rl),
  612. fns=expr_fns
  613. )
  614. )
  615. result = rule(expr)
  616. doit = getattr(result, 'doit', None)
  617. if doit is not None:
  618. return doit()
  619. else:
  620. return result
  621. def bc_unpack(expr):
  622. if expr.blockshape == (1, 1):
  623. return expr.blocks[0, 0]
  624. return expr
  625. def bc_matadd(expr):
  626. args = sift(expr.args, lambda M: isinstance(M, BlockMatrix))
  627. blocks = args[True]
  628. if not blocks:
  629. return expr
  630. nonblocks = args[False]
  631. block = blocks[0]
  632. for b in blocks[1:]:
  633. block = block._blockadd(b)
  634. if nonblocks:
  635. return MatAdd(*nonblocks) + block
  636. else:
  637. return block
  638. def bc_block_plus_ident(expr):
  639. idents = [arg for arg in expr.args if arg.is_Identity]
  640. if not idents:
  641. return expr
  642. blocks = [arg for arg in expr.args if isinstance(arg, BlockMatrix)]
  643. if (blocks and all(b.structurally_equal(blocks[0]) for b in blocks)
  644. and blocks[0].is_structurally_symmetric):
  645. block_id = BlockDiagMatrix(*[Identity(k)
  646. for k in blocks[0].rowblocksizes])
  647. rest = [arg for arg in expr.args if not arg.is_Identity and not isinstance(arg, BlockMatrix)]
  648. return MatAdd(block_id * len(idents), *blocks, *rest).doit()
  649. return expr
  650. def bc_dist(expr):
  651. """ Turn a*[X, Y] into [a*X, a*Y] """
  652. factor, mat = expr.as_coeff_mmul()
  653. if factor == 1:
  654. return expr
  655. unpacked = unpack(mat)
  656. if isinstance(unpacked, BlockDiagMatrix):
  657. B = unpacked.diag
  658. new_B = [factor * mat for mat in B]
  659. return BlockDiagMatrix(*new_B)
  660. elif isinstance(unpacked, BlockMatrix):
  661. B = unpacked.blocks
  662. new_B = [
  663. [factor * B[i, j] for j in range(B.cols)] for i in range(B.rows)]
  664. return BlockMatrix(new_B)
  665. return expr
  666. def bc_matmul(expr):
  667. if isinstance(expr, MatPow):
  668. if expr.args[1].is_Integer:
  669. factor, matrices = (1, [expr.args[0]]*expr.args[1])
  670. else:
  671. return expr
  672. else:
  673. factor, matrices = expr.as_coeff_matrices()
  674. i = 0
  675. while (i+1 < len(matrices)):
  676. A, B = matrices[i:i+2]
  677. if isinstance(A, BlockMatrix) and isinstance(B, BlockMatrix):
  678. matrices[i] = A._blockmul(B)
  679. matrices.pop(i+1)
  680. elif isinstance(A, BlockMatrix):
  681. matrices[i] = A._blockmul(BlockMatrix([[B]]))
  682. matrices.pop(i+1)
  683. elif isinstance(B, BlockMatrix):
  684. matrices[i] = BlockMatrix([[A]])._blockmul(B)
  685. matrices.pop(i+1)
  686. else:
  687. i+=1
  688. return MatMul(factor, *matrices).doit()
  689. def bc_transpose(expr):
  690. collapse = block_collapse(expr.arg)
  691. return collapse._eval_transpose()
  692. def bc_inverse(expr):
  693. if isinstance(expr.arg, BlockDiagMatrix):
  694. return expr.inverse()
  695. expr2 = blockinverse_1x1(expr)
  696. if expr != expr2:
  697. return expr2
  698. return blockinverse_2x2(Inverse(reblock_2x2(expr.arg)))
  699. def blockinverse_1x1(expr):
  700. if isinstance(expr.arg, BlockMatrix) and expr.arg.blockshape == (1, 1):
  701. mat = Matrix([[expr.arg.blocks[0].inverse()]])
  702. return BlockMatrix(mat)
  703. return expr
  704. def blockinverse_2x2(expr):
  705. if isinstance(expr.arg, BlockMatrix) and expr.arg.blockshape == (2, 2):
  706. # See: Inverses of 2x2 Block Matrices, Tzon-Tzer Lu and Sheng-Hua Shiou
  707. [[A, B],
  708. [C, D]] = expr.arg.blocks.tolist()
  709. formula = _choose_2x2_inversion_formula(A, B, C, D)
  710. if formula != None:
  711. MI = expr.arg.schur(formula).I
  712. if formula == 'A':
  713. AI = A.I
  714. return BlockMatrix([[AI + AI * B * MI * C * AI, -AI * B * MI], [-MI * C * AI, MI]])
  715. if formula == 'B':
  716. BI = B.I
  717. return BlockMatrix([[-MI * D * BI, MI], [BI + BI * A * MI * D * BI, -BI * A * MI]])
  718. if formula == 'C':
  719. CI = C.I
  720. return BlockMatrix([[-CI * D * MI, CI + CI * D * MI * A * CI], [MI, -MI * A * CI]])
  721. if formula == 'D':
  722. DI = D.I
  723. return BlockMatrix([[MI, -MI * B * DI], [-DI * C * MI, DI + DI * C * MI * B * DI]])
  724. return expr
  725. def _choose_2x2_inversion_formula(A, B, C, D):
  726. """
  727. Assuming [[A, B], [C, D]] would form a valid square block matrix, find
  728. which of the classical 2x2 block matrix inversion formulas would be
  729. best suited.
  730. Returns 'A', 'B', 'C', 'D' to represent the algorithm involving inversion
  731. of the given argument or None if the matrix cannot be inverted using
  732. any of those formulas.
  733. """
  734. # Try to find a known invertible matrix. Note that the Schur complement
  735. # is currently not being considered for this
  736. A_inv = ask(Q.invertible(A))
  737. if A_inv == True:
  738. return 'A'
  739. B_inv = ask(Q.invertible(B))
  740. if B_inv == True:
  741. return 'B'
  742. C_inv = ask(Q.invertible(C))
  743. if C_inv == True:
  744. return 'C'
  745. D_inv = ask(Q.invertible(D))
  746. if D_inv == True:
  747. return 'D'
  748. # Otherwise try to find a matrix that isn't known to be non-invertible
  749. if A_inv != False:
  750. return 'A'
  751. if B_inv != False:
  752. return 'B'
  753. if C_inv != False:
  754. return 'C'
  755. if D_inv != False:
  756. return 'D'
  757. return None
  758. def deblock(B):
  759. """ Flatten a BlockMatrix of BlockMatrices """
  760. if not isinstance(B, BlockMatrix) or not B.blocks.has(BlockMatrix):
  761. return B
  762. wrap = lambda x: x if isinstance(x, BlockMatrix) else BlockMatrix([[x]])
  763. bb = B.blocks.applyfunc(wrap) # everything is a block
  764. try:
  765. MM = Matrix(0, sum(bb[0, i].blocks.shape[1] for i in range(bb.shape[1])), [])
  766. for row in range(0, bb.shape[0]):
  767. M = Matrix(bb[row, 0].blocks)
  768. for col in range(1, bb.shape[1]):
  769. M = M.row_join(bb[row, col].blocks)
  770. MM = MM.col_join(M)
  771. return BlockMatrix(MM)
  772. except ShapeError:
  773. return B
  774. def reblock_2x2(expr):
  775. """
  776. Reblock a BlockMatrix so that it has 2x2 blocks of block matrices. If
  777. possible in such a way that the matrix continues to be invertible using the
  778. classical 2x2 block inversion formulas.
  779. """
  780. if not isinstance(expr, BlockMatrix) or not all(d > 2 for d in expr.blockshape):
  781. return expr
  782. BM = BlockMatrix # for brevity's sake
  783. rowblocks, colblocks = expr.blockshape
  784. blocks = expr.blocks
  785. for i in range(1, rowblocks):
  786. for j in range(1, colblocks):
  787. # try to split rows at i and cols at j
  788. A = bc_unpack(BM(blocks[:i, :j]))
  789. B = bc_unpack(BM(blocks[:i, j:]))
  790. C = bc_unpack(BM(blocks[i:, :j]))
  791. D = bc_unpack(BM(blocks[i:, j:]))
  792. formula = _choose_2x2_inversion_formula(A, B, C, D)
  793. if formula is not None:
  794. return BlockMatrix([[A, B], [C, D]])
  795. # else: nothing worked, just split upper left corner
  796. return BM([[blocks[0, 0], BM(blocks[0, 1:])],
  797. [BM(blocks[1:, 0]), BM(blocks[1:, 1:])]])
  798. def bounds(sizes):
  799. """ Convert sequence of numbers into pairs of low-high pairs
  800. >>> from sympy.matrices.expressions.blockmatrix import bounds
  801. >>> bounds((1, 10, 50))
  802. [(0, 1), (1, 11), (11, 61)]
  803. """
  804. low = 0
  805. rv = []
  806. for size in sizes:
  807. rv.append((low, low + size))
  808. low += size
  809. return rv
  810. def blockcut(expr, rowsizes, colsizes):
  811. """ Cut a matrix expression into Blocks
  812. >>> from sympy import ImmutableMatrix, blockcut
  813. >>> M = ImmutableMatrix(4, 4, range(16))
  814. >>> B = blockcut(M, (1, 3), (1, 3))
  815. >>> type(B).__name__
  816. 'BlockMatrix'
  817. >>> ImmutableMatrix(B.blocks[0, 1])
  818. Matrix([[1, 2, 3]])
  819. """
  820. rowbounds = bounds(rowsizes)
  821. colbounds = bounds(colsizes)
  822. return BlockMatrix([[MatrixSlice(expr, rowbound, colbound)
  823. for colbound in colbounds]
  824. for rowbound in rowbounds])