special.py 7.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299
  1. from sympy.assumptions.ask import ask, Q
  2. from sympy.core.relational import Eq
  3. from sympy.core.singleton import S
  4. from sympy.core.sympify import _sympify
  5. from sympy.functions.special.tensor_functions import KroneckerDelta
  6. from sympy.matrices.common import NonInvertibleMatrixError
  7. from .matexpr import MatrixExpr
  8. class ZeroMatrix(MatrixExpr):
  9. """The Matrix Zero 0 - additive identity
  10. Examples
  11. ========
  12. >>> from sympy import MatrixSymbol, ZeroMatrix
  13. >>> A = MatrixSymbol('A', 3, 5)
  14. >>> Z = ZeroMatrix(3, 5)
  15. >>> A + Z
  16. A
  17. >>> Z*A.T
  18. 0
  19. """
  20. is_ZeroMatrix = True
  21. def __new__(cls, m, n):
  22. m, n = _sympify(m), _sympify(n)
  23. cls._check_dim(m)
  24. cls._check_dim(n)
  25. return super().__new__(cls, m, n)
  26. @property
  27. def shape(self):
  28. return (self.args[0], self.args[1])
  29. def _eval_power(self, exp):
  30. # exp = -1, 0, 1 are already handled at this stage
  31. if (exp < 0) == True:
  32. raise NonInvertibleMatrixError("Matrix det == 0; not invertible")
  33. return self
  34. def _eval_transpose(self):
  35. return ZeroMatrix(self.cols, self.rows)
  36. def _eval_adjoint(self):
  37. return ZeroMatrix(self.cols, self.rows)
  38. def _eval_trace(self):
  39. return S.Zero
  40. def _eval_determinant(self):
  41. return S.Zero
  42. def _eval_inverse(self):
  43. raise NonInvertibleMatrixError("Matrix det == 0; not invertible.")
  44. def _eval_as_real_imag(self):
  45. return (self, self)
  46. def _eval_conjugate(self):
  47. return self
  48. def _entry(self, i, j, **kwargs):
  49. return S.Zero
  50. class GenericZeroMatrix(ZeroMatrix):
  51. """
  52. A zero matrix without a specified shape
  53. This exists primarily so MatAdd() with no arguments can return something
  54. meaningful.
  55. """
  56. def __new__(cls):
  57. # super(ZeroMatrix, cls) instead of super(GenericZeroMatrix, cls)
  58. # because ZeroMatrix.__new__ doesn't have the same signature
  59. return super(ZeroMatrix, cls).__new__(cls)
  60. @property
  61. def rows(self):
  62. raise TypeError("GenericZeroMatrix does not have a specified shape")
  63. @property
  64. def cols(self):
  65. raise TypeError("GenericZeroMatrix does not have a specified shape")
  66. @property
  67. def shape(self):
  68. raise TypeError("GenericZeroMatrix does not have a specified shape")
  69. # Avoid Matrix.__eq__ which might call .shape
  70. def __eq__(self, other):
  71. return isinstance(other, GenericZeroMatrix)
  72. def __ne__(self, other):
  73. return not (self == other)
  74. def __hash__(self):
  75. return super().__hash__()
  76. class Identity(MatrixExpr):
  77. """The Matrix Identity I - multiplicative identity
  78. Examples
  79. ========
  80. >>> from sympy import Identity, MatrixSymbol
  81. >>> A = MatrixSymbol('A', 3, 5)
  82. >>> I = Identity(3)
  83. >>> I*A
  84. A
  85. """
  86. is_Identity = True
  87. def __new__(cls, n):
  88. n = _sympify(n)
  89. cls._check_dim(n)
  90. return super().__new__(cls, n)
  91. @property
  92. def rows(self):
  93. return self.args[0]
  94. @property
  95. def cols(self):
  96. return self.args[0]
  97. @property
  98. def shape(self):
  99. return (self.args[0], self.args[0])
  100. @property
  101. def is_square(self):
  102. return True
  103. def _eval_transpose(self):
  104. return self
  105. def _eval_trace(self):
  106. return self.rows
  107. def _eval_inverse(self):
  108. return self
  109. def _eval_as_real_imag(self):
  110. return (self, ZeroMatrix(*self.shape))
  111. def _eval_conjugate(self):
  112. return self
  113. def _eval_adjoint(self):
  114. return self
  115. def _entry(self, i, j, **kwargs):
  116. eq = Eq(i, j)
  117. if eq is S.true:
  118. return S.One
  119. elif eq is S.false:
  120. return S.Zero
  121. return KroneckerDelta(i, j, (0, self.cols-1))
  122. def _eval_determinant(self):
  123. return S.One
  124. def _eval_power(self, exp):
  125. return self
  126. class GenericIdentity(Identity):
  127. """
  128. An identity matrix without a specified shape
  129. This exists primarily so MatMul() with no arguments can return something
  130. meaningful.
  131. """
  132. def __new__(cls):
  133. # super(Identity, cls) instead of super(GenericIdentity, cls) because
  134. # Identity.__new__ doesn't have the same signature
  135. return super(Identity, cls).__new__(cls)
  136. @property
  137. def rows(self):
  138. raise TypeError("GenericIdentity does not have a specified shape")
  139. @property
  140. def cols(self):
  141. raise TypeError("GenericIdentity does not have a specified shape")
  142. @property
  143. def shape(self):
  144. raise TypeError("GenericIdentity does not have a specified shape")
  145. @property
  146. def is_square(self):
  147. return True
  148. # Avoid Matrix.__eq__ which might call .shape
  149. def __eq__(self, other):
  150. return isinstance(other, GenericIdentity)
  151. def __ne__(self, other):
  152. return not (self == other)
  153. def __hash__(self):
  154. return super().__hash__()
  155. class OneMatrix(MatrixExpr):
  156. """
  157. Matrix whose all entries are ones.
  158. """
  159. def __new__(cls, m, n, evaluate=False):
  160. m, n = _sympify(m), _sympify(n)
  161. cls._check_dim(m)
  162. cls._check_dim(n)
  163. if evaluate:
  164. condition = Eq(m, 1) & Eq(n, 1)
  165. if condition == True:
  166. return Identity(1)
  167. obj = super().__new__(cls, m, n)
  168. return obj
  169. @property
  170. def shape(self):
  171. return self._args
  172. @property
  173. def is_Identity(self):
  174. return self._is_1x1() == True
  175. def as_explicit(self):
  176. from sympy.matrices.immutable import ImmutableDenseMatrix
  177. return ImmutableDenseMatrix.ones(*self.shape)
  178. def doit(self, **hints):
  179. args = self.args
  180. if hints.get('deep', True):
  181. args = [a.doit(**hints) for a in args]
  182. return self.func(*args, evaluate=True)
  183. def _eval_power(self, exp):
  184. # exp = -1, 0, 1 are already handled at this stage
  185. if self._is_1x1() == True:
  186. return Identity(1)
  187. if (exp < 0) == True:
  188. raise NonInvertibleMatrixError("Matrix det == 0; not invertible")
  189. if ask(Q.integer(exp)):
  190. return self.shape[0] ** (exp - 1) * OneMatrix(*self.shape)
  191. return super()._eval_power(exp)
  192. def _eval_transpose(self):
  193. return OneMatrix(self.cols, self.rows)
  194. def _eval_adjoint(self):
  195. return OneMatrix(self.cols, self.rows)
  196. def _eval_trace(self):
  197. return S.One*self.rows
  198. def _is_1x1(self):
  199. """Returns true if the matrix is known to be 1x1"""
  200. shape = self.shape
  201. return Eq(shape[0], 1) & Eq(shape[1], 1)
  202. def _eval_determinant(self):
  203. condition = self._is_1x1()
  204. if condition == True:
  205. return S.One
  206. elif condition == False:
  207. return S.Zero
  208. else:
  209. from sympy.matrices.expressions.determinant import Determinant
  210. return Determinant(self)
  211. def _eval_inverse(self):
  212. condition = self._is_1x1()
  213. if condition == True:
  214. return Identity(1)
  215. elif condition == False:
  216. raise NonInvertibleMatrixError("Matrix det == 0; not invertible.")
  217. else:
  218. from .inverse import Inverse
  219. return Inverse(self)
  220. def _eval_as_real_imag(self):
  221. return (self, ZeroMatrix(*self.shape))
  222. def _eval_conjugate(self):
  223. return self
  224. def _entry(self, i, j, **kwargs):
  225. return S.One