matmul.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498
  1. from sympy.assumptions.ask import ask, Q
  2. from sympy.assumptions.refine import handlers_dict
  3. from sympy.core import Basic, sympify, S
  4. from sympy.core.mul import mul, Mul
  5. from sympy.core.numbers import Number, Integer
  6. from sympy.core.symbol import Dummy
  7. from sympy.functions import adjoint
  8. from sympy.strategies import (rm_id, unpack, typed, flatten, exhaust,
  9. do_one, new)
  10. from sympy.matrices.common import NonInvertibleMatrixError
  11. from sympy.matrices.matrices import MatrixBase
  12. from sympy.utilities.exceptions import sympy_deprecation_warning
  13. from sympy.matrices.expressions._shape import validate_matmul_integer as validate
  14. from .inverse import Inverse
  15. from .matexpr import MatrixExpr
  16. from .matpow import MatPow
  17. from .transpose import transpose
  18. from .permutation import PermutationMatrix
  19. from .special import ZeroMatrix, Identity, GenericIdentity, OneMatrix
  20. # XXX: MatMul should perhaps not subclass directly from Mul
  21. class MatMul(MatrixExpr, Mul):
  22. """
  23. A product of matrix expressions
  24. Examples
  25. ========
  26. >>> from sympy import MatMul, MatrixSymbol
  27. >>> A = MatrixSymbol('A', 5, 4)
  28. >>> B = MatrixSymbol('B', 4, 3)
  29. >>> C = MatrixSymbol('C', 3, 6)
  30. >>> MatMul(A, B, C)
  31. A*B*C
  32. """
  33. is_MatMul = True
  34. identity = GenericIdentity()
  35. def __new__(cls, *args, evaluate=False, check=None, _sympify=True):
  36. if not args:
  37. return cls.identity
  38. # This must be removed aggressively in the constructor to avoid
  39. # TypeErrors from GenericIdentity().shape
  40. args = list(filter(lambda i: cls.identity != i, args))
  41. if _sympify:
  42. args = list(map(sympify, args))
  43. obj = Basic.__new__(cls, *args)
  44. factor, matrices = obj.as_coeff_matrices()
  45. if check is not None:
  46. sympy_deprecation_warning(
  47. "Passing check to MatMul is deprecated and the check argument will be removed in a future version.",
  48. deprecated_since_version="1.11",
  49. active_deprecations_target='remove-check-argument-from-matrix-operations')
  50. if check is not False:
  51. validate(*matrices)
  52. if not matrices:
  53. # Should it be
  54. #
  55. # return Basic.__neq__(cls, factor, GenericIdentity()) ?
  56. return factor
  57. if evaluate:
  58. return cls._evaluate(obj)
  59. return obj
  60. @classmethod
  61. def _evaluate(cls, expr):
  62. return canonicalize(expr)
  63. @property
  64. def shape(self):
  65. matrices = [arg for arg in self.args if arg.is_Matrix]
  66. return (matrices[0].rows, matrices[-1].cols)
  67. def _entry(self, i, j, expand=True, **kwargs):
  68. # Avoid cyclic imports
  69. from sympy.concrete.summations import Sum
  70. from sympy.matrices.immutable import ImmutableMatrix
  71. coeff, matrices = self.as_coeff_matrices()
  72. if len(matrices) == 1: # situation like 2*X, matmul is just X
  73. return coeff * matrices[0][i, j]
  74. indices = [None]*(len(matrices) + 1)
  75. ind_ranges = [None]*(len(matrices) - 1)
  76. indices[0] = i
  77. indices[-1] = j
  78. def f():
  79. counter = 1
  80. while True:
  81. yield Dummy("i_%i" % counter)
  82. counter += 1
  83. dummy_generator = kwargs.get("dummy_generator", f())
  84. for i in range(1, len(matrices)):
  85. indices[i] = next(dummy_generator)
  86. for i, arg in enumerate(matrices[:-1]):
  87. ind_ranges[i] = arg.shape[1] - 1
  88. matrices = [arg._entry(indices[i], indices[i+1], dummy_generator=dummy_generator) for i, arg in enumerate(matrices)]
  89. expr_in_sum = Mul.fromiter(matrices)
  90. if any(v.has(ImmutableMatrix) for v in matrices):
  91. expand = True
  92. result = coeff*Sum(
  93. expr_in_sum,
  94. *zip(indices[1:-1], [0]*len(ind_ranges), ind_ranges)
  95. )
  96. # Don't waste time in result.doit() if the sum bounds are symbolic
  97. if not any(isinstance(v, (Integer, int)) for v in ind_ranges):
  98. expand = False
  99. return result.doit() if expand else result
  100. def as_coeff_matrices(self):
  101. scalars = [x for x in self.args if not x.is_Matrix]
  102. matrices = [x for x in self.args if x.is_Matrix]
  103. coeff = Mul(*scalars)
  104. if coeff.is_commutative is False:
  105. raise NotImplementedError("noncommutative scalars in MatMul are not supported.")
  106. return coeff, matrices
  107. def as_coeff_mmul(self):
  108. coeff, matrices = self.as_coeff_matrices()
  109. return coeff, MatMul(*matrices)
  110. def expand(self, **kwargs):
  111. expanded = super(MatMul, self).expand(**kwargs)
  112. return self._evaluate(expanded)
  113. def _eval_transpose(self):
  114. """Transposition of matrix multiplication.
  115. Notes
  116. =====
  117. The following rules are applied.
  118. Transposition for matrix multiplied with another matrix:
  119. `\\left(A B\\right)^{T} = B^{T} A^{T}`
  120. Transposition for matrix multiplied with scalar:
  121. `\\left(c A\\right)^{T} = c A^{T}`
  122. References
  123. ==========
  124. .. [1] https://en.wikipedia.org/wiki/Transpose
  125. """
  126. coeff, matrices = self.as_coeff_matrices()
  127. return MatMul(
  128. coeff, *[transpose(arg) for arg in matrices[::-1]]).doit()
  129. def _eval_adjoint(self):
  130. return MatMul(*[adjoint(arg) for arg in self.args[::-1]]).doit()
  131. def _eval_trace(self):
  132. factor, mmul = self.as_coeff_mmul()
  133. if factor != 1:
  134. from .trace import trace
  135. return factor * trace(mmul.doit())
  136. else:
  137. raise NotImplementedError("Can't simplify any further")
  138. def _eval_determinant(self):
  139. from sympy.matrices.expressions.determinant import Determinant
  140. factor, matrices = self.as_coeff_matrices()
  141. square_matrices = only_squares(*matrices)
  142. return factor**self.rows * Mul(*list(map(Determinant, square_matrices)))
  143. def _eval_inverse(self):
  144. if all(arg.is_square for arg in self.args if isinstance(arg, MatrixExpr)):
  145. return MatMul(*(
  146. arg.inverse() if isinstance(arg, MatrixExpr) else arg**-1
  147. for arg in self.args[::-1]
  148. )
  149. ).doit()
  150. return Inverse(self)
  151. def doit(self, **hints):
  152. deep = hints.get('deep', True)
  153. if deep:
  154. args = tuple(arg.doit(**hints) for arg in self.args)
  155. else:
  156. args = self.args
  157. # treat scalar*MatrixSymbol or scalar*MatPow separately
  158. expr = canonicalize(MatMul(*args))
  159. return expr
  160. # Needed for partial compatibility with Mul
  161. def args_cnc(self, cset=False, warn=True, **kwargs):
  162. coeff_c = [x for x in self.args if x.is_commutative]
  163. coeff_nc = [x for x in self.args if not x.is_commutative]
  164. if cset:
  165. clen = len(coeff_c)
  166. coeff_c = set(coeff_c)
  167. if clen and warn and len(coeff_c) != clen:
  168. raise ValueError('repeated commutative arguments: %s' %
  169. [ci for ci in coeff_c if list(self.args).count(ci) > 1])
  170. return [coeff_c, coeff_nc]
  171. def _eval_derivative_matrix_lines(self, x):
  172. from .transpose import Transpose
  173. with_x_ind = [i for i, arg in enumerate(self.args) if arg.has(x)]
  174. lines = []
  175. for ind in with_x_ind:
  176. left_args = self.args[:ind]
  177. right_args = self.args[ind+1:]
  178. if right_args:
  179. right_mat = MatMul.fromiter(right_args)
  180. else:
  181. right_mat = Identity(self.shape[1])
  182. if left_args:
  183. left_rev = MatMul.fromiter([Transpose(i).doit() if i.is_Matrix else i for i in reversed(left_args)])
  184. else:
  185. left_rev = Identity(self.shape[0])
  186. d = self.args[ind]._eval_derivative_matrix_lines(x)
  187. for i in d:
  188. i.append_first(left_rev)
  189. i.append_second(right_mat)
  190. lines.append(i)
  191. return lines
  192. mul.register_handlerclass((Mul, MatMul), MatMul)
  193. # Rules
  194. def newmul(*args):
  195. if args[0] == 1:
  196. args = args[1:]
  197. return new(MatMul, *args)
  198. def any_zeros(mul):
  199. if any(arg.is_zero or (arg.is_Matrix and arg.is_ZeroMatrix)
  200. for arg in mul.args):
  201. matrices = [arg for arg in mul.args if arg.is_Matrix]
  202. return ZeroMatrix(matrices[0].rows, matrices[-1].cols)
  203. return mul
  204. def merge_explicit(matmul):
  205. """ Merge explicit MatrixBase arguments
  206. >>> from sympy import MatrixSymbol, Matrix, MatMul, pprint
  207. >>> from sympy.matrices.expressions.matmul import merge_explicit
  208. >>> A = MatrixSymbol('A', 2, 2)
  209. >>> B = Matrix([[1, 1], [1, 1]])
  210. >>> C = Matrix([[1, 2], [3, 4]])
  211. >>> X = MatMul(A, B, C)
  212. >>> pprint(X)
  213. [1 1] [1 2]
  214. A*[ ]*[ ]
  215. [1 1] [3 4]
  216. >>> pprint(merge_explicit(X))
  217. [4 6]
  218. A*[ ]
  219. [4 6]
  220. >>> X = MatMul(B, A, C)
  221. >>> pprint(X)
  222. [1 1] [1 2]
  223. [ ]*A*[ ]
  224. [1 1] [3 4]
  225. >>> pprint(merge_explicit(X))
  226. [1 1] [1 2]
  227. [ ]*A*[ ]
  228. [1 1] [3 4]
  229. """
  230. if not any(isinstance(arg, MatrixBase) for arg in matmul.args):
  231. return matmul
  232. newargs = []
  233. last = matmul.args[0]
  234. for arg in matmul.args[1:]:
  235. if isinstance(arg, (MatrixBase, Number)) and isinstance(last, (MatrixBase, Number)):
  236. last = last * arg
  237. else:
  238. newargs.append(last)
  239. last = arg
  240. newargs.append(last)
  241. return MatMul(*newargs)
  242. def remove_ids(mul):
  243. """ Remove Identities from a MatMul
  244. This is a modified version of sympy.strategies.rm_id.
  245. This is necesssary because MatMul may contain both MatrixExprs and Exprs
  246. as args.
  247. See Also
  248. ========
  249. sympy.strategies.rm_id
  250. """
  251. # Separate Exprs from MatrixExprs in args
  252. factor, mmul = mul.as_coeff_mmul()
  253. # Apply standard rm_id for MatMuls
  254. result = rm_id(lambda x: x.is_Identity is True)(mmul)
  255. if result != mmul:
  256. return newmul(factor, *result.args) # Recombine and return
  257. else:
  258. return mul
  259. def factor_in_front(mul):
  260. factor, matrices = mul.as_coeff_matrices()
  261. if factor != 1:
  262. return newmul(factor, *matrices)
  263. return mul
def combine_powers(mul):
    r"""Combine consecutive powers with the same base into one, e.g.
    $$A \times A^2 \Rightarrow A^3$$

    This also cancels out the possible matrix inverses using the
    knowledgebase of :class:`~.Inverse`, e.g.,
    $$ Y \times X \times X^{-1} \Rightarrow Y $$

    The scalar coefficient is carried through unchanged and re-attached by
    ``newmul`` at the end.
    """
    factor, args = mul.as_coeff_matrices()
    # new_args holds the already-combined prefix; new_args[-1] is the factor
    # that the next one (B) may merge into.
    new_args = [args[0]]

    for i in range(1, len(args)):
        A = new_args[-1]
        B = args[i]

        # Case 1: B = Inverse(MatMul(...)) and the last l combined factors are
        # exactly that MatMul's args -> replace them all by one identity.
        if isinstance(B, Inverse) and isinstance(B.arg, MatMul):
            Bargs = B.arg.args
            l = len(Bargs)
            if list(Bargs) == new_args[-l:]:
                new_args = new_args[:-l] + [Identity(B.shape[0])]
                continue

        # Case 2: A = Inverse(MatMul(...)) and the next l input factors are
        # that MatMul's args -> A becomes an identity, and those input slots
        # are overwritten with identities in the local ``args`` list so that
        # later iterations merge them away.
        if isinstance(A, Inverse) and isinstance(A.arg, MatMul):
            Aargs = A.arg.args
            l = len(Aargs)
            if list(Aargs) == args[i:i+l]:
                identity = Identity(A.shape[0])
                new_args[-1] = identity
                for j in range(i, i+l):
                    args[j] = identity
                continue

        # Powers only make sense for square factors.  ``== False`` (rather
        # than ``not``) lets factors of unknown squareness (None) continue.
        if A.is_square == False or B.is_square == False:
            new_args.append(B)
            continue

        # View each factor as base**exp (exp = 1 for a non-MatPow).
        if isinstance(A, MatPow):
            A_base, A_exp = A.args
        else:
            A_base, A_exp = A, S.One

        if isinstance(B, MatPow):
            B_base, B_exp = B.args
        else:
            B_base, B_exp = B, S.One

        if A_base == B_base:
            # X**a * X**b -> X**(a+b)
            new_exp = A_exp + B_exp
            new_args[-1] = MatPow(A_base, new_exp).doit(deep=False)
            continue
        elif not isinstance(B_base, MatrixBase):
            # For a symbolic base, check whether A's base is B's inverse:
            # Y**a * (Y**-1)**b -> Y**(a-b).  Non-invertible bases are skipped.
            try:
                B_base_inv = B_base.inverse()
            except NonInvertibleMatrixError:
                B_base_inv = None
            if B_base_inv is not None and A_base == B_base_inv:
                new_exp = A_exp - B_exp
                new_args[-1] = MatPow(A_base, new_exp).doit(deep=False)
                continue
        # No combination possible: keep B as a separate factor.
        new_args.append(B)

    return newmul(factor, *new_args)
  317. def combine_permutations(mul):
  318. """Refine products of permutation matrices as the products of cycles.
  319. """
  320. args = mul.args
  321. l = len(args)
  322. if l < 2:
  323. return mul
  324. result = [args[0]]
  325. for i in range(1, l):
  326. A = result[-1]
  327. B = args[i]
  328. if isinstance(A, PermutationMatrix) and \
  329. isinstance(B, PermutationMatrix):
  330. cycle_1 = A.args[0]
  331. cycle_2 = B.args[0]
  332. result[-1] = PermutationMatrix(cycle_1 * cycle_2)
  333. else:
  334. result.append(B)
  335. return MatMul(*result)
  336. def combine_one_matrices(mul):
  337. """
  338. Combine products of OneMatrix
  339. e.g. OneMatrix(2, 3) * OneMatrix(3, 4) -> 3 * OneMatrix(2, 4)
  340. """
  341. factor, args = mul.as_coeff_matrices()
  342. new_args = [args[0]]
  343. for B in args[1:]:
  344. A = new_args[-1]
  345. if not isinstance(A, OneMatrix) or not isinstance(B, OneMatrix):
  346. new_args.append(B)
  347. continue
  348. new_args.pop()
  349. new_args.append(OneMatrix(A.shape[0], B.shape[1]))
  350. factor *= A.shape[1]
  351. return newmul(factor, *new_args)
  352. def distribute_monom(mul):
  353. """
  354. Simplify MatMul expressions but distributing
  355. rational term to MatMul.
  356. e.g. 2*(A+B) -> 2*A + 2*B
  357. """
  358. args = mul.args
  359. if len(args) == 2:
  360. from .matadd import MatAdd
  361. if args[0].is_MatAdd and args[1].is_Rational:
  362. return MatAdd(*[MatMul(mat, args[1]).doit() for mat in args[0].args])
  363. if args[1].is_MatAdd and args[0].is_Rational:
  364. return MatAdd(*[MatMul(args[0], mat).doit() for mat in args[1].args])
  365. return mul
  366. rules = (
  367. distribute_monom, any_zeros, remove_ids, combine_one_matrices, combine_powers, unpack, rm_id(lambda x: x == 1),
  368. merge_explicit, factor_in_front, flatten, combine_permutations)
  369. canonicalize = exhaust(typed({MatMul: do_one(*rules)}))
  370. def only_squares(*matrices):
  371. """factor matrices only if they are square"""
  372. if matrices[0].rows != matrices[-1].cols:
  373. raise RuntimeError("Invalid matrices being multiplied")
  374. out = []
  375. start = 0
  376. for i, M in enumerate(matrices):
  377. if M.cols == matrices[start].rows:
  378. out.append(MatMul(*matrices[start:i+1]).doit())
  379. start = i+1
  380. return out
  381. def refine_MatMul(expr, assumptions):
  382. """
  383. >>> from sympy import MatrixSymbol, Q, assuming, refine
  384. >>> X = MatrixSymbol('X', 2, 2)
  385. >>> expr = X * X.T
  386. >>> print(expr)
  387. X*X.T
  388. >>> with assuming(Q.orthogonal(X)):
  389. ... print(refine(expr))
  390. I
  391. """
  392. newargs = []
  393. exprargs = []
  394. for args in expr.args:
  395. if args.is_Matrix:
  396. exprargs.append(args)
  397. else:
  398. newargs.append(args)
  399. last = exprargs[0]
  400. for arg in exprargs[1:]:
  401. if arg == last.T and ask(Q.orthogonal(arg), assumptions):
  402. last = Identity(arg.shape[0])
  403. elif arg == last.conjugate() and ask(Q.unitary(arg), assumptions):
  404. last = Identity(arg.shape[0])
  405. else:
  406. newargs.append(last)
  407. last = arg
  408. newargs.append(last)
  409. return MatMul(*newargs)
  410. handlers_dict['MatMul'] = refine_MatMul