kronecker.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434
  1. """Implementation of the Kronecker product"""
  2. from functools import reduce
  3. from math import prod
  4. from sympy.core import Mul, sympify
  5. from sympy.functions import adjoint
  6. from sympy.matrices.common import ShapeError
  7. from sympy.matrices.expressions.matexpr import MatrixExpr
  8. from sympy.matrices.expressions.transpose import transpose
  9. from sympy.matrices.expressions.special import Identity
  10. from sympy.matrices.matrices import MatrixBase
  11. from sympy.strategies import (
  12. canon, condition, distribute, do_one, exhaust, flatten, typed, unpack)
  13. from sympy.strategies.traverse import bottom_up
  14. from sympy.utilities import sift
  15. from .matadd import MatAdd
  16. from .matmul import MatMul
  17. from .matpow import MatPow
  18. def kronecker_product(*matrices):
  19. """
  20. The Kronecker product of two or more arguments.
  21. This computes the explicit Kronecker product for subclasses of
  22. ``MatrixBase`` i.e. explicit matrices. Otherwise, a symbolic
  23. ``KroneckerProduct`` object is returned.
  24. Examples
  25. ========
  26. For ``MatrixSymbol`` arguments a ``KroneckerProduct`` object is returned.
  27. Elements of this matrix can be obtained by indexing, or for MatrixSymbols
  28. with known dimension the explicit matrix can be obtained with
  29. ``.as_explicit()``
  30. >>> from sympy import kronecker_product, MatrixSymbol
  31. >>> A = MatrixSymbol('A', 2, 2)
  32. >>> B = MatrixSymbol('B', 2, 2)
  33. >>> kronecker_product(A)
  34. A
  35. >>> kronecker_product(A, B)
  36. KroneckerProduct(A, B)
  37. >>> kronecker_product(A, B)[0, 1]
  38. A[0, 0]*B[0, 1]
  39. >>> kronecker_product(A, B).as_explicit()
  40. Matrix([
  41. [A[0, 0]*B[0, 0], A[0, 0]*B[0, 1], A[0, 1]*B[0, 0], A[0, 1]*B[0, 1]],
  42. [A[0, 0]*B[1, 0], A[0, 0]*B[1, 1], A[0, 1]*B[1, 0], A[0, 1]*B[1, 1]],
  43. [A[1, 0]*B[0, 0], A[1, 0]*B[0, 1], A[1, 1]*B[0, 0], A[1, 1]*B[0, 1]],
  44. [A[1, 0]*B[1, 0], A[1, 0]*B[1, 1], A[1, 1]*B[1, 0], A[1, 1]*B[1, 1]]])
  45. For explicit matrices the Kronecker product is returned as a Matrix
  46. >>> from sympy import Matrix, kronecker_product
  47. >>> sigma_x = Matrix([
  48. ... [0, 1],
  49. ... [1, 0]])
  50. ...
  51. >>> Isigma_y = Matrix([
  52. ... [0, 1],
  53. ... [-1, 0]])
  54. ...
  55. >>> kronecker_product(sigma_x, Isigma_y)
  56. Matrix([
  57. [ 0, 0, 0, 1],
  58. [ 0, 0, -1, 0],
  59. [ 0, 1, 0, 0],
  60. [-1, 0, 0, 0]])
  61. See Also
  62. ========
  63. KroneckerProduct
  64. """
  65. if not matrices:
  66. raise TypeError("Empty Kronecker product is undefined")
  67. if len(matrices) == 1:
  68. return matrices[0]
  69. else:
  70. return KroneckerProduct(*matrices).doit()
  71. class KroneckerProduct(MatrixExpr):
  72. """
  73. The Kronecker product of two or more arguments.
  74. The Kronecker product is a non-commutative product of matrices.
  75. Given two matrices of dimension (m, n) and (s, t) it produces a matrix
  76. of dimension (m s, n t).
  77. This is a symbolic object that simply stores its argument without
  78. evaluating it. To actually compute the product, use the function
  79. ``kronecker_product()`` or call the ``.doit()`` or ``.as_explicit()``
  80. methods.
  81. >>> from sympy import KroneckerProduct, MatrixSymbol
  82. >>> A = MatrixSymbol('A', 5, 5)
  83. >>> B = MatrixSymbol('B', 5, 5)
  84. >>> isinstance(KroneckerProduct(A, B), KroneckerProduct)
  85. True
  86. """
  87. is_KroneckerProduct = True
  88. def __new__(cls, *args, check=True):
  89. args = list(map(sympify, args))
  90. if all(a.is_Identity for a in args):
  91. ret = Identity(prod(a.rows for a in args))
  92. if all(isinstance(a, MatrixBase) for a in args):
  93. return ret.as_explicit()
  94. else:
  95. return ret
  96. if check:
  97. validate(*args)
  98. return super().__new__(cls, *args)
  99. @property
  100. def shape(self):
  101. rows, cols = self.args[0].shape
  102. for mat in self.args[1:]:
  103. rows *= mat.rows
  104. cols *= mat.cols
  105. return (rows, cols)
  106. def _entry(self, i, j, **kwargs):
  107. result = 1
  108. for mat in reversed(self.args):
  109. i, m = divmod(i, mat.rows)
  110. j, n = divmod(j, mat.cols)
  111. result *= mat[m, n]
  112. return result
  113. def _eval_adjoint(self):
  114. return KroneckerProduct(*list(map(adjoint, self.args))).doit()
  115. def _eval_conjugate(self):
  116. return KroneckerProduct(*[a.conjugate() for a in self.args]).doit()
  117. def _eval_transpose(self):
  118. return KroneckerProduct(*list(map(transpose, self.args))).doit()
  119. def _eval_trace(self):
  120. from .trace import trace
  121. return Mul(*[trace(a) for a in self.args])
  122. def _eval_determinant(self):
  123. from .determinant import det, Determinant
  124. if not all(a.is_square for a in self.args):
  125. return Determinant(self)
  126. m = self.rows
  127. return Mul(*[det(a)**(m/a.rows) for a in self.args])
  128. def _eval_inverse(self):
  129. try:
  130. return KroneckerProduct(*[a.inverse() for a in self.args])
  131. except ShapeError:
  132. from sympy.matrices.expressions.inverse import Inverse
  133. return Inverse(self)
  134. def structurally_equal(self, other):
  135. '''Determine whether two matrices have the same Kronecker product structure
  136. Examples
  137. ========
  138. >>> from sympy import KroneckerProduct, MatrixSymbol, symbols
  139. >>> m, n = symbols(r'm, n', integer=True)
  140. >>> A = MatrixSymbol('A', m, m)
  141. >>> B = MatrixSymbol('B', n, n)
  142. >>> C = MatrixSymbol('C', m, m)
  143. >>> D = MatrixSymbol('D', n, n)
  144. >>> KroneckerProduct(A, B).structurally_equal(KroneckerProduct(C, D))
  145. True
  146. >>> KroneckerProduct(A, B).structurally_equal(KroneckerProduct(D, C))
  147. False
  148. >>> KroneckerProduct(A, B).structurally_equal(C)
  149. False
  150. '''
  151. # Inspired by BlockMatrix
  152. return (isinstance(other, KroneckerProduct)
  153. and self.shape == other.shape
  154. and len(self.args) == len(other.args)
  155. and all(a.shape == b.shape for (a, b) in zip(self.args, other.args)))
  156. def has_matching_shape(self, other):
  157. '''Determine whether two matrices have the appropriate structure to bring matrix
  158. multiplication inside the KroneckerProdut
  159. Examples
  160. ========
  161. >>> from sympy import KroneckerProduct, MatrixSymbol, symbols
  162. >>> m, n = symbols(r'm, n', integer=True)
  163. >>> A = MatrixSymbol('A', m, n)
  164. >>> B = MatrixSymbol('B', n, m)
  165. >>> KroneckerProduct(A, B).has_matching_shape(KroneckerProduct(B, A))
  166. True
  167. >>> KroneckerProduct(A, B).has_matching_shape(KroneckerProduct(A, B))
  168. False
  169. >>> KroneckerProduct(A, B).has_matching_shape(A)
  170. False
  171. '''
  172. return (isinstance(other, KroneckerProduct)
  173. and self.cols == other.rows
  174. and len(self.args) == len(other.args)
  175. and all(a.cols == b.rows for (a, b) in zip(self.args, other.args)))
  176. def _eval_expand_kroneckerproduct(self, **hints):
  177. return flatten(canon(typed({KroneckerProduct: distribute(KroneckerProduct, MatAdd)}))(self))
  178. def _kronecker_add(self, other):
  179. if self.structurally_equal(other):
  180. return self.__class__(*[a + b for (a, b) in zip(self.args, other.args)])
  181. else:
  182. return self + other
  183. def _kronecker_mul(self, other):
  184. if self.has_matching_shape(other):
  185. return self.__class__(*[a*b for (a, b) in zip(self.args, other.args)])
  186. else:
  187. return self * other
  188. def doit(self, **hints):
  189. deep = hints.get('deep', True)
  190. if deep:
  191. args = [arg.doit(**hints) for arg in self.args]
  192. else:
  193. args = self.args
  194. return canonicalize(KroneckerProduct(*args))
  195. def validate(*args):
  196. if not all(arg.is_Matrix for arg in args):
  197. raise TypeError("Mix of Matrix and Scalar symbols")
  198. # rules
  199. def extract_commutative(kron):
  200. c_part = []
  201. nc_part = []
  202. for arg in kron.args:
  203. c, nc = arg.args_cnc()
  204. c_part.extend(c)
  205. nc_part.append(Mul._from_args(nc))
  206. c_part = Mul(*c_part)
  207. if c_part != 1:
  208. return c_part*KroneckerProduct(*nc_part)
  209. return kron
  210. def matrix_kronecker_product(*matrices):
  211. """Compute the Kronecker product of a sequence of SymPy Matrices.
  212. This is the standard Kronecker product of matrices [1].
  213. Parameters
  214. ==========
  215. matrices : tuple of MatrixBase instances
  216. The matrices to take the Kronecker product of.
  217. Returns
  218. =======
  219. matrix : MatrixBase
  220. The Kronecker product matrix.
  221. Examples
  222. ========
  223. >>> from sympy import Matrix
  224. >>> from sympy.matrices.expressions.kronecker import (
  225. ... matrix_kronecker_product)
  226. >>> m1 = Matrix([[1,2],[3,4]])
  227. >>> m2 = Matrix([[1,0],[0,1]])
  228. >>> matrix_kronecker_product(m1, m2)
  229. Matrix([
  230. [1, 0, 2, 0],
  231. [0, 1, 0, 2],
  232. [3, 0, 4, 0],
  233. [0, 3, 0, 4]])
  234. >>> matrix_kronecker_product(m2, m1)
  235. Matrix([
  236. [1, 2, 0, 0],
  237. [3, 4, 0, 0],
  238. [0, 0, 1, 2],
  239. [0, 0, 3, 4]])
  240. References
  241. ==========
  242. .. [1] https://en.wikipedia.org/wiki/Kronecker_product
  243. """
  244. # Make sure we have a sequence of Matrices
  245. if not all(isinstance(m, MatrixBase) for m in matrices):
  246. raise TypeError(
  247. 'Sequence of Matrices expected, got: %s' % repr(matrices)
  248. )
  249. # Pull out the first element in the product.
  250. matrix_expansion = matrices[-1]
  251. # Do the kronecker product working from right to left.
  252. for mat in reversed(matrices[:-1]):
  253. rows = mat.rows
  254. cols = mat.cols
  255. # Go through each row appending kronecker product to.
  256. # running matrix_expansion.
  257. for i in range(rows):
  258. start = matrix_expansion*mat[i*cols]
  259. # Go through each column joining each item
  260. for j in range(cols - 1):
  261. start = start.row_join(
  262. matrix_expansion*mat[i*cols + j + 1]
  263. )
  264. # If this is the first element, make it the start of the
  265. # new row.
  266. if i == 0:
  267. next = start
  268. else:
  269. next = next.col_join(start)
  270. matrix_expansion = next
  271. MatrixClass = max(matrices, key=lambda M: M._class_priority).__class__
  272. if isinstance(matrix_expansion, MatrixClass):
  273. return matrix_expansion
  274. else:
  275. return MatrixClass(matrix_expansion)
  276. def explicit_kronecker_product(kron):
  277. # Make sure we have a sequence of Matrices
  278. if not all(isinstance(m, MatrixBase) for m in kron.args):
  279. return kron
  280. return matrix_kronecker_product(*kron.args)
  281. rules = (unpack,
  282. explicit_kronecker_product,
  283. flatten,
  284. extract_commutative)
  285. canonicalize = exhaust(condition(lambda x: isinstance(x, KroneckerProduct),
  286. do_one(*rules)))
  287. def _kronecker_dims_key(expr):
  288. if isinstance(expr, KroneckerProduct):
  289. return tuple(a.shape for a in expr.args)
  290. else:
  291. return (0,)
  292. def kronecker_mat_add(expr):
  293. args = sift(expr.args, _kronecker_dims_key)
  294. nonkrons = args.pop((0,), None)
  295. if not args:
  296. return expr
  297. krons = [reduce(lambda x, y: x._kronecker_add(y), group)
  298. for group in args.values()]
  299. if not nonkrons:
  300. return MatAdd(*krons)
  301. else:
  302. return MatAdd(*krons) + nonkrons
  303. def kronecker_mat_mul(expr):
  304. # modified from block matrix code
  305. factor, matrices = expr.as_coeff_matrices()
  306. i = 0
  307. while i < len(matrices) - 1:
  308. A, B = matrices[i:i+2]
  309. if isinstance(A, KroneckerProduct) and isinstance(B, KroneckerProduct):
  310. matrices[i] = A._kronecker_mul(B)
  311. matrices.pop(i+1)
  312. else:
  313. i += 1
  314. return factor*MatMul(*matrices)
  315. def kronecker_mat_pow(expr):
  316. if isinstance(expr.base, KroneckerProduct) and all(a.is_square for a in expr.base.args):
  317. return KroneckerProduct(*[MatPow(a, expr.exp) for a in expr.base.args])
  318. else:
  319. return expr
  320. def combine_kronecker(expr):
  321. """Combine KronekeckerProduct with expression.
  322. If possible write operations on KroneckerProducts of compatible shapes
  323. as a single KroneckerProduct.
  324. Examples
  325. ========
  326. >>> from sympy.matrices.expressions import combine_kronecker
  327. >>> from sympy import MatrixSymbol, KroneckerProduct, symbols
  328. >>> m, n = symbols(r'm, n', integer=True)
  329. >>> A = MatrixSymbol('A', m, n)
  330. >>> B = MatrixSymbol('B', n, m)
  331. >>> combine_kronecker(KroneckerProduct(A, B)*KroneckerProduct(B, A))
  332. KroneckerProduct(A*B, B*A)
  333. >>> combine_kronecker(KroneckerProduct(A, B)+KroneckerProduct(B.T, A.T))
  334. KroneckerProduct(A + B.T, B + A.T)
  335. >>> C = MatrixSymbol('C', n, n)
  336. >>> D = MatrixSymbol('D', m, m)
  337. >>> combine_kronecker(KroneckerProduct(C, D)**m)
  338. KroneckerProduct(C**m, D**m)
  339. """
  340. def haskron(expr):
  341. return isinstance(expr, MatrixExpr) and expr.has(KroneckerProduct)
  342. rule = exhaust(
  343. bottom_up(exhaust(condition(haskron, typed(
  344. {MatAdd: kronecker_mat_add,
  345. MatMul: kronecker_mat_mul,
  346. MatPow: kronecker_mat_pow})))))
  347. result = rule(expr)
  348. doit = getattr(result, 'doit', None)
  349. if doit is not None:
  350. return doit()
  351. else:
  352. return result