matadd.py 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155
  1. from functools import reduce
  2. import operator
  3. from sympy.core import Basic, sympify
  4. from sympy.core.add import add, Add, _could_extract_minus_sign
  5. from sympy.core.sorting import default_sort_key
  6. from sympy.functions import adjoint
  7. from sympy.matrices.matrices import MatrixBase
  8. from sympy.matrices.expressions.transpose import transpose
  9. from sympy.strategies import (rm_id, unpack, flatten, sort, condition,
  10. exhaust, do_one, glom)
  11. from sympy.matrices.expressions.matexpr import MatrixExpr
  12. from sympy.matrices.expressions.special import ZeroMatrix, GenericZeroMatrix
  13. from sympy.matrices.expressions._shape import validate_matadd_integer as validate
  14. from sympy.utilities.iterables import sift
  15. from sympy.utilities.exceptions import sympy_deprecation_warning
  # XXX: MatAdd should perhaps not subclass directly from Add
  class MatAdd(MatrixExpr, Add):
      """A Sum of Matrix Expressions

      MatAdd inherits from and operates like SymPy Add

      Examples
      ========

      >>> from sympy import MatAdd, MatrixSymbol
      >>> A = MatrixSymbol('A', 5, 5)
      >>> B = MatrixSymbol('B', 5, 5)
      >>> C = MatrixSymbol('C', 5, 5)
      >>> MatAdd(A, B, C)
      A + B + C
      """

      is_MatAdd = True
      # Additive identity: returned for an empty sum and stripped from the
      # terms of every other sum in ``__new__``.
      identity = GenericZeroMatrix()

      def __new__(cls, *args, evaluate=False, check=None, _sympify=True):
          """Create a sum of matrix expressions.

          Parameters
          ==========

          *args
              The terms of the sum. Terms equal to ``cls.identity`` are
              discarded.
          evaluate : bool
              If True, the result is canonicalized through ``_evaluate``.
          check
              Deprecated. Any value other than None emits a deprecation
              warning; ``check=False`` additionally skips ``validate``.
          _sympify : bool
              Whether to ``sympify`` the terms before use.

          Returns ``cls.identity`` when no terms remain after the identity
          terms have been discarded.

          Raises
          ======

          TypeError
              If any remaining term is not a ``MatrixExpr``.
          """
          # This must be removed aggressively in the constructor to avoid
          # TypeErrors from GenericZeroMatrix().shape
          args = [arg for arg in args if cls.identity != arg]
          # Checked after filtering so that a sum made only of identity terms
          # behaves like MatAdd(): an argument-less MatAdd has no usable
          # ``shape`` (it reads ``self.args[0]``).
          if not args:
              return cls.identity
          if _sympify:
              args = [sympify(arg) for arg in args]
          if not all(isinstance(arg, MatrixExpr) for arg in args):
              raise TypeError("Mix of Matrix and Scalar symbols")

          obj = Basic.__new__(cls, *args)

          if check is not None:
              sympy_deprecation_warning(
                  "Passing check to MatAdd is deprecated and the check argument will be removed in a future version.",
                  deprecated_since_version="1.11",
                  active_deprecations_target='remove-check-argument-from-matrix-operations')

          if check is not False:
              validate(*args)

          if evaluate:
              obj = cls._evaluate(obj)

          return obj

      @classmethod
      def _evaluate(cls, expr):
          """Return ``expr`` after applying the module's ``canonicalize`` rules."""
          return canonicalize(expr)

      @property
      def shape(self):
          """Shape of the sum, read from the first term.

          All terms are assumed to share this shape; ``__new__`` runs
          ``validate`` on the terms unless ``check=False``.
          """
          return self.args[0].shape

      def could_extract_minus_sign(self):
          """Delegate to ``sympy.core.add._could_extract_minus_sign``."""
          return _could_extract_minus_sign(self)

      def expand(self, **kwargs):
          """Expand as a generic ``Add`` would, then re-canonicalize."""
          expanded = super().expand(**kwargs)
          return self._evaluate(expanded)

      def _entry(self, i, j, **kwargs):
          """Entry ``(i, j)``: the scalar ``Add`` of every term's ``(i, j)`` entry."""
          return Add(*[arg._entry(i, j, **kwargs) for arg in self.args])

      def _eval_transpose(self):
          """Transpose term by term: (A + B).T == A.T + B.T."""
          return MatAdd(*[transpose(arg) for arg in self.args]).doit()

      def _eval_adjoint(self):
          """Take the adjoint term by term."""
          return MatAdd(*[adjoint(arg) for arg in self.args]).doit()

      def _eval_trace(self):
          """Trace is linear: return the sum of the traces of the terms."""
          from .trace import trace
          return Add(*[trace(arg) for arg in self.args]).doit()

      def doit(self, **hints):
          """Evaluate the terms (unless ``deep=False``) and canonicalize the sum."""
          deep = hints.get('deep', True)
          if deep:
              args = [arg.doit(**hints) for arg in self.args]
          else:
              args = self.args
          return canonicalize(MatAdd(*args))

      def _eval_derivative_matrix_lines(self, x):
          """Concatenate the derivative matrix lines of every term into one list."""
          add_lines = [arg._eval_derivative_matrix_lines(x) for arg in self.args]
          return [j for i in add_lines for j in i]

  # Route ``add`` dispatch on (Add, MatAdd) operand pairs to MatAdd.
  add.register_handlerclass((Add, MatAdd), MatAdd)
  def factor_of(arg):
      """Return the scalar coefficient of ``arg``.

      This is the first item of ``arg.as_coeff_mmul()``. It is used as the
      count function of the ``glom`` rule in ``rules``.
      """
      return arg.as_coeff_mmul()[0]


  def matrix_of(arg):
      """Return the matrix part of ``arg``.

      This is the second item of ``arg.as_coeff_mmul()``, passed through
      ``unpack``. It is used as the key function of the ``glom`` rule in
      ``rules``.
      """
      return unpack(arg.as_coeff_mmul()[1])
  def combine(cnt, mat):
      """Scale ``mat`` by the count ``cnt``.

      A count of exactly 1 returns ``mat`` itself rather than ``1 * mat``.
      This is the combine callback of the ``glom`` rule in ``rules``.
      """
      return mat if cnt == 1 else cnt * mat
  def merge_explicit(matadd):
      """ Merge explicit MatrixBase arguments

      All explicit ``MatrixBase`` terms are summed into a single matrix,
      which is appended after the remaining terms. When there are fewer
      than two explicit terms, ``matadd`` is returned unchanged.

      Examples
      ========

      >>> from sympy import MatrixSymbol, eye, Matrix, MatAdd, pprint
      >>> from sympy.matrices.expressions.matadd import merge_explicit
      >>> A = MatrixSymbol('A', 2, 2)
      >>> B = eye(2)
      >>> C = Matrix([[1, 2], [3, 4]])
      >>> X = MatAdd(A, B, C)
      >>> pprint(X)
          [1  0]   [1  2]
      A + [    ] + [    ]
          [0  1]   [3  4]
      >>> pprint(merge_explicit(X))
          [2  2]
      A + [    ]
          [3  5]
      """
      explicit = [term for term in matadd.args if isinstance(term, MatrixBase)]
      if len(explicit) <= 1:
          # Zero or one explicit matrix: nothing to merge.
          return matadd
      others = [term for term in matadd.args if not isinstance(term, MatrixBase)]
      merged = reduce(operator.add, explicit)
      return MatAdd(*others, merged)
  # Rewrite rules for MatAdd, passed in this order to ``do_one``.
  # The order is significant, so keep it.
  #   1. drop terms that are literally 0 or a ZeroMatrix
  #   2. unpack
  #   3. flatten
  #   4. glom with matrix_of / factor_of / combine
  #   5. merge explicit MatrixBase terms
  #   6. sort the terms by default_sort_key
  rules = (
      rm_id(lambda term: term == 0 or isinstance(term, ZeroMatrix)),
      unpack,
      flatten,
      glom(matrix_of, factor_of, combine),
      merge_explicit,
      sort(default_sort_key),
  )

  # Apply ``do_one(*rules)`` repeatedly via ``exhaust``. ``condition`` only
  # lets the rules fire while the expression is still a MatAdd.
  canonicalize = exhaust(
      condition(
          lambda expr: isinstance(expr, MatAdd),
          do_one(*rules),
      )
  )