matexpr.py 27 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885
  1. from __future__ import annotations
  2. from functools import wraps
  3. from sympy.core import S, Integer, Basic, Mul, Add
  4. from sympy.core.assumptions import check_assumptions
  5. from sympy.core.decorators import call_highest_priority
  6. from sympy.core.expr import Expr, ExprBuilder
  7. from sympy.core.logic import FuzzyBool
  8. from sympy.core.symbol import Str, Dummy, symbols, Symbol
  9. from sympy.core.sympify import SympifyError, _sympify
  10. from sympy.external.gmpy import SYMPY_INTS
  11. from sympy.functions import conjugate, adjoint
  12. from sympy.functions.special.tensor_functions import KroneckerDelta
  13. from sympy.matrices.common import NonSquareMatrixError
  14. from sympy.matrices.matrices import MatrixKind, MatrixBase
  15. from sympy.multipledispatch import dispatch
  16. from sympy.utilities.misc import filldedent
  17. def _sympifyit(arg, retval=None):
  18. # This version of _sympifyit sympifies MutableMatrix objects
  19. def deco(func):
  20. @wraps(func)
  21. def __sympifyit_wrapper(a, b):
  22. try:
  23. b = _sympify(b)
  24. return func(a, b)
  25. except SympifyError:
  26. return retval
  27. return __sympifyit_wrapper
  28. return deco
  29. class MatrixExpr(Expr):
  30. """Superclass for Matrix Expressions
  31. MatrixExprs represent abstract matrices, linear transformations represented
  32. within a particular basis.
  33. Examples
  34. ========
  35. >>> from sympy import MatrixSymbol
  36. >>> A = MatrixSymbol('A', 3, 3)
  37. >>> y = MatrixSymbol('y', 3, 1)
  38. >>> x = (A.T*A).I * A * y
  39. See Also
  40. ========
  41. MatrixSymbol, MatAdd, MatMul, Transpose, Inverse
  42. """
  43. __slots__: tuple[str, ...] = ()
  44. # Should not be considered iterable by the
  45. # sympy.utilities.iterables.iterable function. Subclass that actually are
  46. # iterable (i.e., explicit matrices) should set this to True.
  47. _iterable = False
  48. _op_priority = 11.0
  49. is_Matrix: bool = True
  50. is_MatrixExpr: bool = True
  51. is_Identity: FuzzyBool = None
  52. is_Inverse = False
  53. is_Transpose = False
  54. is_ZeroMatrix = False
  55. is_MatAdd = False
  56. is_MatMul = False
  57. is_commutative = False
  58. is_number = False
  59. is_symbol = False
  60. is_scalar = False
  61. kind: MatrixKind = MatrixKind()
  62. def __new__(cls, *args, **kwargs):
  63. args = map(_sympify, args)
  64. return Basic.__new__(cls, *args, **kwargs)
  65. # The following is adapted from the core Expr object
  66. @property
  67. def shape(self) -> tuple[Expr, Expr]:
  68. raise NotImplementedError
  69. @property
  70. def _add_handler(self):
  71. return MatAdd
  72. @property
  73. def _mul_handler(self):
  74. return MatMul
  75. def __neg__(self):
  76. return MatMul(S.NegativeOne, self).doit()
  77. def __abs__(self):
  78. raise NotImplementedError
  79. @_sympifyit('other', NotImplemented)
  80. @call_highest_priority('__radd__')
  81. def __add__(self, other):
  82. return MatAdd(self, other).doit()
  83. @_sympifyit('other', NotImplemented)
  84. @call_highest_priority('__add__')
  85. def __radd__(self, other):
  86. return MatAdd(other, self).doit()
  87. @_sympifyit('other', NotImplemented)
  88. @call_highest_priority('__rsub__')
  89. def __sub__(self, other):
  90. return MatAdd(self, -other).doit()
  91. @_sympifyit('other', NotImplemented)
  92. @call_highest_priority('__sub__')
  93. def __rsub__(self, other):
  94. return MatAdd(other, -self).doit()
  95. @_sympifyit('other', NotImplemented)
  96. @call_highest_priority('__rmul__')
  97. def __mul__(self, other):
  98. return MatMul(self, other).doit()
  99. @_sympifyit('other', NotImplemented)
  100. @call_highest_priority('__rmul__')
  101. def __matmul__(self, other):
  102. return MatMul(self, other).doit()
  103. @_sympifyit('other', NotImplemented)
  104. @call_highest_priority('__mul__')
  105. def __rmul__(self, other):
  106. return MatMul(other, self).doit()
  107. @_sympifyit('other', NotImplemented)
  108. @call_highest_priority('__mul__')
  109. def __rmatmul__(self, other):
  110. return MatMul(other, self).doit()
  111. @_sympifyit('other', NotImplemented)
  112. @call_highest_priority('__rpow__')
  113. def __pow__(self, other):
  114. return MatPow(self, other).doit()
  115. @_sympifyit('other', NotImplemented)
  116. @call_highest_priority('__pow__')
  117. def __rpow__(self, other):
  118. raise NotImplementedError("Matrix Power not defined")
  119. @_sympifyit('other', NotImplemented)
  120. @call_highest_priority('__rtruediv__')
  121. def __truediv__(self, other):
  122. return self * other**S.NegativeOne
  123. @_sympifyit('other', NotImplemented)
  124. @call_highest_priority('__truediv__')
  125. def __rtruediv__(self, other):
  126. raise NotImplementedError()
  127. #return MatMul(other, Pow(self, S.NegativeOne))
  128. @property
  129. def rows(self):
  130. return self.shape[0]
  131. @property
  132. def cols(self):
  133. return self.shape[1]
  134. @property
  135. def is_square(self) -> bool | None:
  136. rows, cols = self.shape
  137. if isinstance(rows, Integer) and isinstance(cols, Integer):
  138. return rows == cols
  139. if rows == cols:
  140. return True
  141. return None
  142. def _eval_conjugate(self):
  143. from sympy.matrices.expressions.adjoint import Adjoint
  144. return Adjoint(Transpose(self))
  145. def as_real_imag(self, deep=True, **hints):
  146. return self._eval_as_real_imag()
  147. def _eval_as_real_imag(self):
  148. real = S.Half * (self + self._eval_conjugate())
  149. im = (self - self._eval_conjugate())/(2*S.ImaginaryUnit)
  150. return (real, im)
  151. def _eval_inverse(self):
  152. return Inverse(self)
  153. def _eval_determinant(self):
  154. return Determinant(self)
  155. def _eval_transpose(self):
  156. return Transpose(self)
  157. def _eval_power(self, exp):
  158. """
  159. Override this in sub-classes to implement simplification of powers. The cases where the exponent
  160. is -1, 0, 1 are already covered in MatPow.doit(), so implementations can exclude these cases.
  161. """
  162. return MatPow(self, exp)
  163. def _eval_simplify(self, **kwargs):
  164. if self.is_Atom:
  165. return self
  166. else:
  167. from sympy.simplify import simplify
  168. return self.func(*[simplify(x, **kwargs) for x in self.args])
  169. def _eval_adjoint(self):
  170. from sympy.matrices.expressions.adjoint import Adjoint
  171. return Adjoint(self)
  172. def _eval_derivative_n_times(self, x, n):
  173. return Basic._eval_derivative_n_times(self, x, n)
  174. def _eval_derivative(self, x):
  175. # `x` is a scalar:
  176. if self.has(x):
  177. # See if there are other methods using it:
  178. return super()._eval_derivative(x)
  179. else:
  180. return ZeroMatrix(*self.shape)
  181. @classmethod
  182. def _check_dim(cls, dim):
  183. """Helper function to check invalid matrix dimensions"""
  184. ok = check_assumptions(dim, integer=True, nonnegative=True)
  185. if ok is False:
  186. raise ValueError(
  187. "The dimension specification {} should be "
  188. "a nonnegative integer.".format(dim))
  189. def _entry(self, i, j, **kwargs):
  190. raise NotImplementedError(
  191. "Indexing not implemented for %s" % self.__class__.__name__)
  192. def adjoint(self):
  193. return adjoint(self)
  194. def as_coeff_Mul(self, rational=False):
  195. """Efficiently extract the coefficient of a product."""
  196. return S.One, self
  197. def conjugate(self):
  198. return conjugate(self)
  199. def transpose(self):
  200. from sympy.matrices.expressions.transpose import transpose
  201. return transpose(self)
  202. @property
  203. def T(self):
  204. '''Matrix transposition'''
  205. return self.transpose()
  206. def inverse(self):
  207. if self.is_square is False:
  208. raise NonSquareMatrixError('Inverse of non-square matrix')
  209. return self._eval_inverse()
  210. def inv(self):
  211. return self.inverse()
  212. def det(self):
  213. from sympy.matrices.expressions.determinant import det
  214. return det(self)
  215. @property
  216. def I(self):
  217. return self.inverse()
  218. def valid_index(self, i, j):
  219. def is_valid(idx):
  220. return isinstance(idx, (int, Integer, Symbol, Expr))
  221. return (is_valid(i) and is_valid(j) and
  222. (self.rows is None or
  223. (i >= -self.rows) != False and (i < self.rows) != False) and
  224. (j >= -self.cols) != False and (j < self.cols) != False)
  225. def __getitem__(self, key):
  226. if not isinstance(key, tuple) and isinstance(key, slice):
  227. from sympy.matrices.expressions.slice import MatrixSlice
  228. return MatrixSlice(self, key, (0, None, 1))
  229. if isinstance(key, tuple) and len(key) == 2:
  230. i, j = key
  231. if isinstance(i, slice) or isinstance(j, slice):
  232. from sympy.matrices.expressions.slice import MatrixSlice
  233. return MatrixSlice(self, i, j)
  234. i, j = _sympify(i), _sympify(j)
  235. if self.valid_index(i, j) != False:
  236. return self._entry(i, j)
  237. else:
  238. raise IndexError("Invalid indices (%s, %s)" % (i, j))
  239. elif isinstance(key, (SYMPY_INTS, Integer)):
  240. # row-wise decomposition of matrix
  241. rows, cols = self.shape
  242. # allow single indexing if number of columns is known
  243. if not isinstance(cols, Integer):
  244. raise IndexError(filldedent('''
  245. Single indexing is only supported when the number
  246. of columns is known.'''))
  247. key = _sympify(key)
  248. i = key // cols
  249. j = key % cols
  250. if self.valid_index(i, j) != False:
  251. return self._entry(i, j)
  252. else:
  253. raise IndexError("Invalid index %s" % key)
  254. elif isinstance(key, (Symbol, Expr)):
  255. raise IndexError(filldedent('''
  256. Only integers may be used when addressing the matrix
  257. with a single index.'''))
  258. raise IndexError("Invalid index, wanted %s[i,j]" % self)
  259. def _is_shape_symbolic(self) -> bool:
  260. return (not isinstance(self.rows, (SYMPY_INTS, Integer))
  261. or not isinstance(self.cols, (SYMPY_INTS, Integer)))
  262. def as_explicit(self):
  263. """
  264. Returns a dense Matrix with elements represented explicitly
  265. Returns an object of type ImmutableDenseMatrix.
  266. Examples
  267. ========
  268. >>> from sympy import Identity
  269. >>> I = Identity(3)
  270. >>> I
  271. I
  272. >>> I.as_explicit()
  273. Matrix([
  274. [1, 0, 0],
  275. [0, 1, 0],
  276. [0, 0, 1]])
  277. See Also
  278. ========
  279. as_mutable: returns mutable Matrix type
  280. """
  281. if self._is_shape_symbolic():
  282. raise ValueError(
  283. 'Matrix with symbolic shape '
  284. 'cannot be represented explicitly.')
  285. from sympy.matrices.immutable import ImmutableDenseMatrix
  286. return ImmutableDenseMatrix([[self[i, j]
  287. for j in range(self.cols)]
  288. for i in range(self.rows)])
  289. def as_mutable(self):
  290. """
  291. Returns a dense, mutable matrix with elements represented explicitly
  292. Examples
  293. ========
  294. >>> from sympy import Identity
  295. >>> I = Identity(3)
  296. >>> I
  297. I
  298. >>> I.shape
  299. (3, 3)
  300. >>> I.as_mutable()
  301. Matrix([
  302. [1, 0, 0],
  303. [0, 1, 0],
  304. [0, 0, 1]])
  305. See Also
  306. ========
  307. as_explicit: returns ImmutableDenseMatrix
  308. """
  309. return self.as_explicit().as_mutable()
  310. def __array__(self):
  311. from numpy import empty
  312. a = empty(self.shape, dtype=object)
  313. for i in range(self.rows):
  314. for j in range(self.cols):
  315. a[i, j] = self[i, j]
  316. return a
  317. def equals(self, other):
  318. """
  319. Test elementwise equality between matrices, potentially of different
  320. types
  321. >>> from sympy import Identity, eye
  322. >>> Identity(3).equals(eye(3))
  323. True
  324. """
  325. return self.as_explicit().equals(other)
  326. def canonicalize(self):
  327. return self
  328. def as_coeff_mmul(self):
  329. return S.One, MatMul(self)
  330. @staticmethod
  331. def from_index_summation(expr, first_index=None, last_index=None, dimensions=None):
  332. r"""
  333. Parse expression of matrices with explicitly summed indices into a
  334. matrix expression without indices, if possible.
  335. This transformation expressed in mathematical notation:
  336. `\sum_{j=0}^{N-1} A_{i,j} B_{j,k} \Longrightarrow \mathbf{A}\cdot \mathbf{B}`
  337. Optional parameter ``first_index``: specify which free index to use as
  338. the index starting the expression.
  339. Examples
  340. ========
  341. >>> from sympy import MatrixSymbol, MatrixExpr, Sum
  342. >>> from sympy.abc import i, j, k, l, N
  343. >>> A = MatrixSymbol("A", N, N)
  344. >>> B = MatrixSymbol("B", N, N)
  345. >>> expr = Sum(A[i, j]*B[j, k], (j, 0, N-1))
  346. >>> MatrixExpr.from_index_summation(expr)
  347. A*B
  348. Transposition is detected:
  349. >>> expr = Sum(A[j, i]*B[j, k], (j, 0, N-1))
  350. >>> MatrixExpr.from_index_summation(expr)
  351. A.T*B
  352. Detect the trace:
  353. >>> expr = Sum(A[i, i], (i, 0, N-1))
  354. >>> MatrixExpr.from_index_summation(expr)
  355. Trace(A)
  356. More complicated expressions:
  357. >>> expr = Sum(A[i, j]*B[k, j]*A[l, k], (j, 0, N-1), (k, 0, N-1))
  358. >>> MatrixExpr.from_index_summation(expr)
  359. A*B.T*A.T
  360. """
  361. from sympy.tensor.array.expressions.from_indexed_to_array import convert_indexed_to_array
  362. from sympy.tensor.array.expressions.from_array_to_matrix import convert_array_to_matrix
  363. first_indices = []
  364. if first_index is not None:
  365. first_indices.append(first_index)
  366. if last_index is not None:
  367. first_indices.append(last_index)
  368. arr = convert_indexed_to_array(expr, first_indices=first_indices)
  369. return convert_array_to_matrix(arr)
  370. def applyfunc(self, func):
  371. from .applyfunc import ElementwiseApplyFunction
  372. return ElementwiseApplyFunction(func, self)
  373. @dispatch(MatrixExpr, Expr)
  374. def _eval_is_eq(lhs, rhs): # noqa:F811
  375. return False
  376. @dispatch(MatrixExpr, MatrixExpr) # type: ignore
  377. def _eval_is_eq(lhs, rhs): # noqa:F811
  378. if lhs.shape != rhs.shape:
  379. return False
  380. if (lhs - rhs).is_ZeroMatrix:
  381. return True
  382. def get_postprocessor(cls):
  383. def _postprocessor(expr):
  384. # To avoid circular imports, we can't have MatMul/MatAdd on the top level
  385. mat_class = {Mul: MatMul, Add: MatAdd}[cls]
  386. nonmatrices = []
  387. matrices = []
  388. for term in expr.args:
  389. if isinstance(term, MatrixExpr):
  390. matrices.append(term)
  391. else:
  392. nonmatrices.append(term)
  393. if not matrices:
  394. return cls._from_args(nonmatrices)
  395. if nonmatrices:
  396. if cls == Mul:
  397. for i in range(len(matrices)):
  398. if not matrices[i].is_MatrixExpr:
  399. # If one of the matrices explicit, absorb the scalar into it
  400. # (doit will combine all explicit matrices into one, so it
  401. # doesn't matter which)
  402. matrices[i] = matrices[i].__mul__(cls._from_args(nonmatrices))
  403. nonmatrices = []
  404. break
  405. else:
  406. # Maintain the ability to create Add(scalar, matrix) without
  407. # raising an exception. That way different algorithms can
  408. # replace matrix expressions with non-commutative symbols to
  409. # manipulate them like non-commutative scalars.
  410. return cls._from_args(nonmatrices + [mat_class(*matrices).doit(deep=False)])
  411. if mat_class == MatAdd:
  412. return mat_class(*matrices).doit(deep=False)
  413. return mat_class(cls._from_args(nonmatrices), *matrices).doit(deep=False)
  414. return _postprocessor
  415. Basic._constructor_postprocessor_mapping[MatrixExpr] = {
  416. "Mul": [get_postprocessor(Mul)],
  417. "Add": [get_postprocessor(Add)],
  418. }
  419. def _matrix_derivative(expr, x, old_algorithm=False):
  420. if isinstance(expr, MatrixBase) or isinstance(x, MatrixBase):
  421. # Do not use array expressions for explicit matrices:
  422. old_algorithm = True
  423. if old_algorithm:
  424. return _matrix_derivative_old_algorithm(expr, x)
  425. from sympy.tensor.array.expressions.from_matrix_to_array import convert_matrix_to_array
  426. from sympy.tensor.array.expressions.arrayexpr_derivatives import array_derive
  427. from sympy.tensor.array.expressions.from_array_to_matrix import convert_array_to_matrix
  428. array_expr = convert_matrix_to_array(expr)
  429. diff_array_expr = array_derive(array_expr, x)
  430. diff_matrix_expr = convert_array_to_matrix(diff_array_expr)
  431. return diff_matrix_expr
  432. def _matrix_derivative_old_algorithm(expr, x):
  433. from sympy.tensor.array.array_derivatives import ArrayDerivative
  434. lines = expr._eval_derivative_matrix_lines(x)
  435. parts = [i.build() for i in lines]
  436. from sympy.tensor.array.expressions.from_array_to_matrix import convert_array_to_matrix
  437. parts = [[convert_array_to_matrix(j) for j in i] for i in parts]
  438. def _get_shape(elem):
  439. if isinstance(elem, MatrixExpr):
  440. return elem.shape
  441. return 1, 1
  442. def get_rank(parts):
  443. return sum([j not in (1, None) for i in parts for j in _get_shape(i)])
  444. ranks = [get_rank(i) for i in parts]
  445. rank = ranks[0]
  446. def contract_one_dims(parts):
  447. if len(parts) == 1:
  448. return parts[0]
  449. else:
  450. p1, p2 = parts[:2]
  451. if p2.is_Matrix:
  452. p2 = p2.T
  453. if p1 == Identity(1):
  454. pbase = p2
  455. elif p2 == Identity(1):
  456. pbase = p1
  457. else:
  458. pbase = p1*p2
  459. if len(parts) == 2:
  460. return pbase
  461. else: # len(parts) > 2
  462. if pbase.is_Matrix:
  463. raise ValueError("")
  464. return pbase*Mul.fromiter(parts[2:])
  465. if rank <= 2:
  466. return Add.fromiter([contract_one_dims(i) for i in parts])
  467. return ArrayDerivative(expr, x)
  468. class MatrixElement(Expr):
  469. parent = property(lambda self: self.args[0])
  470. i = property(lambda self: self.args[1])
  471. j = property(lambda self: self.args[2])
  472. _diff_wrt = True
  473. is_symbol = True
  474. is_commutative = True
  475. def __new__(cls, name, n, m):
  476. n, m = map(_sympify, (n, m))
  477. from sympy.matrices.matrices import MatrixBase
  478. if isinstance(name, str):
  479. name = Symbol(name)
  480. else:
  481. if isinstance(name, MatrixBase):
  482. if n.is_Integer and m.is_Integer:
  483. return name[n, m]
  484. name = _sympify(name) # change mutable into immutable
  485. else:
  486. name = _sympify(name)
  487. if not isinstance(name.kind, MatrixKind):
  488. raise TypeError("First argument of MatrixElement should be a matrix")
  489. if not getattr(name, 'valid_index', lambda n, m: True)(n, m):
  490. raise IndexError('indices out of range')
  491. obj = Expr.__new__(cls, name, n, m)
  492. return obj
  493. @property
  494. def symbol(self):
  495. return self.args[0]
  496. def doit(self, **hints):
  497. deep = hints.get('deep', True)
  498. if deep:
  499. args = [arg.doit(**hints) for arg in self.args]
  500. else:
  501. args = self.args
  502. return args[0][args[1], args[2]]
  503. @property
  504. def indices(self):
  505. return self.args[1:]
  506. def _eval_derivative(self, v):
  507. if not isinstance(v, MatrixElement):
  508. from sympy.matrices.matrices import MatrixBase
  509. if isinstance(self.parent, MatrixBase):
  510. return self.parent.diff(v)[self.i, self.j]
  511. return S.Zero
  512. M = self.args[0]
  513. m, n = self.parent.shape
  514. if M == v.args[0]:
  515. return KroneckerDelta(self.args[1], v.args[1], (0, m-1)) * \
  516. KroneckerDelta(self.args[2], v.args[2], (0, n-1))
  517. if isinstance(M, Inverse):
  518. from sympy.concrete.summations import Sum
  519. i, j = self.args[1:]
  520. i1, i2 = symbols("z1, z2", cls=Dummy)
  521. Y = M.args[0]
  522. r1, r2 = Y.shape
  523. return -Sum(M[i, i1]*Y[i1, i2].diff(v)*M[i2, j], (i1, 0, r1-1), (i2, 0, r2-1))
  524. if self.has(v.args[0]):
  525. return None
  526. return S.Zero
  527. class MatrixSymbol(MatrixExpr):
  528. """Symbolic representation of a Matrix object
  529. Creates a SymPy Symbol to represent a Matrix. This matrix has a shape and
  530. can be included in Matrix Expressions
  531. Examples
  532. ========
  533. >>> from sympy import MatrixSymbol, Identity
  534. >>> A = MatrixSymbol('A', 3, 4) # A 3 by 4 Matrix
  535. >>> B = MatrixSymbol('B', 4, 3) # A 4 by 3 Matrix
  536. >>> A.shape
  537. (3, 4)
  538. >>> 2*A*B + Identity(3)
  539. I + 2*A*B
  540. """
  541. is_commutative = False
  542. is_symbol = True
  543. _diff_wrt = True
  544. def __new__(cls, name, n, m):
  545. n, m = _sympify(n), _sympify(m)
  546. cls._check_dim(m)
  547. cls._check_dim(n)
  548. if isinstance(name, str):
  549. name = Str(name)
  550. obj = Basic.__new__(cls, name, n, m)
  551. return obj
  552. @property
  553. def shape(self):
  554. return self.args[1], self.args[2]
  555. @property
  556. def name(self):
  557. return self.args[0].name
  558. def _entry(self, i, j, **kwargs):
  559. return MatrixElement(self, i, j)
  560. @property
  561. def free_symbols(self):
  562. return {self}
  563. def _eval_simplify(self, **kwargs):
  564. return self
  565. def _eval_derivative(self, x):
  566. # x is a scalar:
  567. return ZeroMatrix(self.shape[0], self.shape[1])
  568. def _eval_derivative_matrix_lines(self, x):
  569. if self != x:
  570. first = ZeroMatrix(x.shape[0], self.shape[0]) if self.shape[0] != 1 else S.Zero
  571. second = ZeroMatrix(x.shape[1], self.shape[1]) if self.shape[1] != 1 else S.Zero
  572. return [_LeftRightArgs(
  573. [first, second],
  574. )]
  575. else:
  576. first = Identity(self.shape[0]) if self.shape[0] != 1 else S.One
  577. second = Identity(self.shape[1]) if self.shape[1] != 1 else S.One
  578. return [_LeftRightArgs(
  579. [first, second],
  580. )]
  581. def matrix_symbols(expr):
  582. return [sym for sym in expr.free_symbols if sym.is_Matrix]
  583. class _LeftRightArgs:
  584. r"""
  585. Helper class to compute matrix derivatives.
  586. The logic: when an expression is derived by a matrix `X_{mn}`, two lines of
  587. matrix multiplications are created: the one contracted to `m` (first line),
  588. and the one contracted to `n` (second line).
  589. Transposition flips the side by which new matrices are connected to the
  590. lines.
  591. The trace connects the end of the two lines.
  592. """
  593. def __init__(self, lines, higher=S.One):
  594. self._lines = list(lines)
  595. self._first_pointer_parent = self._lines
  596. self._first_pointer_index = 0
  597. self._first_line_index = 0
  598. self._second_pointer_parent = self._lines
  599. self._second_pointer_index = 1
  600. self._second_line_index = 1
  601. self.higher = higher
  602. @property
  603. def first_pointer(self):
  604. return self._first_pointer_parent[self._first_pointer_index]
  605. @first_pointer.setter
  606. def first_pointer(self, value):
  607. self._first_pointer_parent[self._first_pointer_index] = value
  608. @property
  609. def second_pointer(self):
  610. return self._second_pointer_parent[self._second_pointer_index]
  611. @second_pointer.setter
  612. def second_pointer(self, value):
  613. self._second_pointer_parent[self._second_pointer_index] = value
  614. def __repr__(self):
  615. built = [self._build(i) for i in self._lines]
  616. return "_LeftRightArgs(lines=%s, higher=%s)" % (
  617. built,
  618. self.higher,
  619. )
  620. def transpose(self):
  621. self._first_pointer_parent, self._second_pointer_parent = self._second_pointer_parent, self._first_pointer_parent
  622. self._first_pointer_index, self._second_pointer_index = self._second_pointer_index, self._first_pointer_index
  623. self._first_line_index, self._second_line_index = self._second_line_index, self._first_line_index
  624. return self
  625. @staticmethod
  626. def _build(expr):
  627. if isinstance(expr, ExprBuilder):
  628. return expr.build()
  629. if isinstance(expr, list):
  630. if len(expr) == 1:
  631. return expr[0]
  632. else:
  633. return expr[0](*[_LeftRightArgs._build(i) for i in expr[1]])
  634. else:
  635. return expr
  636. def build(self):
  637. data = [self._build(i) for i in self._lines]
  638. if self.higher != 1:
  639. data += [self._build(self.higher)]
  640. data = list(data)
  641. return data
  642. def matrix_form(self):
  643. if self.first != 1 and self.higher != 1:
  644. raise ValueError("higher dimensional array cannot be represented")
  645. def _get_shape(elem):
  646. if isinstance(elem, MatrixExpr):
  647. return elem.shape
  648. return (None, None)
  649. if _get_shape(self.first)[1] != _get_shape(self.second)[1]:
  650. # Remove one-dimensional identity matrices:
  651. # (this is needed by `a.diff(a)` where `a` is a vector)
  652. if _get_shape(self.second) == (1, 1):
  653. return self.first*self.second[0, 0]
  654. if _get_shape(self.first) == (1, 1):
  655. return self.first[1, 1]*self.second.T
  656. raise ValueError("incompatible shapes")
  657. if self.first != 1:
  658. return self.first*self.second.T
  659. else:
  660. return self.higher
  661. def rank(self):
  662. """
  663. Number of dimensions different from trivial (warning: not related to
  664. matrix rank).
  665. """
  666. rank = 0
  667. if self.first != 1:
  668. rank += sum([i != 1 for i in self.first.shape])
  669. if self.second != 1:
  670. rank += sum([i != 1 for i in self.second.shape])
  671. if self.higher != 1:
  672. rank += 2
  673. return rank
  674. def _multiply_pointer(self, pointer, other):
  675. from ...tensor.array.expressions.array_expressions import ArrayTensorProduct
  676. from ...tensor.array.expressions.array_expressions import ArrayContraction
  677. subexpr = ExprBuilder(
  678. ArrayContraction,
  679. [
  680. ExprBuilder(
  681. ArrayTensorProduct,
  682. [
  683. pointer,
  684. other
  685. ]
  686. ),
  687. (1, 2)
  688. ],
  689. validator=ArrayContraction._validate
  690. )
  691. return subexpr
  692. def append_first(self, other):
  693. self.first_pointer *= other
  694. def append_second(self, other):
  695. self.second_pointer *= other
  696. def _make_matrix(x):
  697. from sympy.matrices.immutable import ImmutableDenseMatrix
  698. if isinstance(x, MatrixExpr):
  699. return x
  700. return ImmutableDenseMatrix([[x]])
  701. from .matmul import MatMul
  702. from .matadd import MatAdd
  703. from .matpow import MatPow
  704. from .transpose import Transpose
  705. from .inverse import Inverse
  706. from .special import ZeroMatrix, Identity
  707. from .determinant import Determinant