# test_matrices.py (12 KB)

  1. from sympy.assumptions.ask import (Q, ask)
  2. from sympy.core.symbol import Symbol
  3. from sympy.matrices.expressions.diagonal import (DiagMatrix, DiagonalMatrix)
  4. from sympy.matrices.dense import Matrix
  5. from sympy.matrices.expressions import (MatrixSymbol, Identity, ZeroMatrix,
  6. OneMatrix, Trace, MatrixSlice, Determinant, BlockMatrix, BlockDiagMatrix)
  7. from sympy.matrices.expressions.factorizations import LofLU
  8. from sympy.testing.pytest import XFAIL
  # Module-level matrix symbols shared by the tests below, grouped by shape.

  # 2x2 square symbols.
  X = MatrixSymbol('X', 2, 2)
  Z = MatrixSymbol('Z', 2, 2)

  # 2x3 rectangular symbol.
  Y = MatrixSymbol('Y', 2, 3)

  # 2x1 column vectors.
  V1 = MatrixSymbol('V1', 2, 1)
  V2 = MatrixSymbol('V2', 2, 1)

  # 1x1 and 0x0 symbols.
  A1x1 = MatrixSymbol('A1x1', 1, 1)
  B1x1 = MatrixSymbol('B1x1', 1, 1)
  C0x0 = MatrixSymbol('C0x0', 0, 0)
  def test_square():
      """Ask ``Q.square`` of a 2x2 symbol, a 2x3 symbol and the product Y*Y.T."""
      square_symbol = Q.square(X)
      rectangular_symbol = Q.square(Y)
      gram_product = Q.square(Y*Y.T)

      assert ask(square_symbol)
      assert not ask(rectangular_symbol)
      assert ask(gram_product)
  def test_invertible():
      """Ask ``Q.invertible`` of symbols, products, transposes, inverses and special matrices."""
      x_inv = Q.invertible(X)
      z_inv = Q.invertible(Z)

      # Symbolic expressions, with and without assumptions.
      assert ask(x_inv, x_inv)
      assert ask(Q.invertible(Y)) is False
      assert ask(Q.invertible(X*Y), x_inv) is False
      assert ask(Q.invertible(X*Z), x_inv) is None
      assert ask(Q.invertible(X*Z), x_inv & z_inv) is True
      assert ask(Q.invertible(X.T)) is None
      assert ask(Q.invertible(X.T), x_inv) is True
      assert ask(Q.invertible(X.I)) is True

      # Identity, zero and all-ones matrices.
      assert ask(Q.invertible(Identity(3))) is True
      assert ask(Q.invertible(ZeroMatrix(3, 3))) is False
      assert ask(Q.invertible(OneMatrix(1, 1))) is True
      assert ask(Q.invertible(OneMatrix(3, 3))) is False

      # Full rank together with squareness as the assumption.
      assert ask(x_inv, Q.fullrank(X) & Q.square(X))
  def test_singular():
      """Ask ``Q.singular(X)`` with no assumption, then with X invertible / not invertible."""
      singular_x = Q.singular(X)
      invertible_x = Q.invertible(X)

      assert ask(singular_x) is None
      assert ask(singular_x, invertible_x) is False
      assert ask(singular_x, ~invertible_x) is True
  @XFAIL
  def test_invertible_fullrank():
      """Expected failure: invertibility of X asked with only ``Q.fullrank(X)`` assumed."""
      fullrank_only = Q.fullrank(X)
      assert ask(Q.invertible(X), fullrank_only) is True
  def test_invertible_BlockMatrix():
      """Ask ``Q.invertible`` of single-block and 2x2-block ``BlockMatrix`` layouts."""
      # Single-block matrices.
      assert ask(Q.invertible(BlockMatrix([Identity(3)]))) == True
      assert ask(Q.invertible(BlockMatrix([ZeroMatrix(3, 3)]))) == False

      # Building blocks for the 5x5 layouts below.
      wide = Matrix([[1, 2, 3], [3, 5, 4]])
      tall = Matrix([[4, 2, 7], [2, 3, 5]]).T
      ones3 = Matrix.ones(3, 3)
      eye2 = Matrix.eye(2)

      # non-invertible A block
      layout = BlockMatrix([
          [ones3, tall],
          [wide, eye2],
      ])
      assert ask(Q.invertible(layout)) == True

      # non-invertible B block
      layout = BlockMatrix([
          [tall, ones3],
          [eye2, wide],
      ])
      assert ask(Q.invertible(layout)) == True

      # non-invertible C block
      layout = BlockMatrix([
          [wide, eye2],
          [ones3, tall],
      ])
      assert ask(Q.invertible(layout)) == True

      # non-invertible D block
      layout = BlockMatrix([
          [eye2, wide],
          [tall, ones3],
      ])
      assert ask(Q.invertible(layout)) == True
  def test_invertible_BlockDiagMatrix():
      """Ask ``Q.invertible`` of block-diagonal matrices built from special blocks."""
      all_identity = BlockDiagMatrix(Identity(3), Identity(5))
      with_zero_block = BlockDiagMatrix(ZeroMatrix(3, 3), Identity(5))
      with_ones_block = BlockDiagMatrix(Identity(3), OneMatrix(5, 5))

      assert ask(Q.invertible(all_identity)) == True
      assert ask(Q.invertible(with_zero_block)) == False
      assert ask(Q.invertible(with_ones_block)) == False
  def test_symmetric():
      """Ask ``Q.symmetric`` of sums, products, powers, 1x1 expressions and special matrices."""
      x_sym = Q.symmetric(X)
      xz_sym = x_sym & Q.symmetric(Z)
      congruence = Y.T*X*Y
      inner = V1.T*(V1 + V2)

      # 2x2 symbols and expressions built from them.
      assert ask(x_sym, x_sym)
      assert ask(Q.symmetric(X*Z), x_sym) is None
      assert ask(Q.symmetric(X*Z), xz_sym) is True
      assert ask(Q.symmetric(X + Z), xz_sym) is True
      assert ask(Q.symmetric(Y)) is False
      assert ask(Q.symmetric(Y*Y.T)) is True
      assert ask(Q.symmetric(congruence)) is None
      assert ask(Q.symmetric(congruence), x_sym) is True
      assert ask(Q.symmetric(X**10), x_sym) is True

      # 1x1 symbols and expressions with a 1x1 result.
      assert ask(Q.symmetric(A1x1)) is True
      assert ask(Q.symmetric(A1x1 + B1x1)) is True
      assert ask(Q.symmetric(A1x1 * B1x1)) is True
      assert ask(Q.symmetric(V1.T*V1)) is True
      assert ask(Q.symmetric(inner)) is True
      assert ask(Q.symmetric(inner + A1x1)) is True
      assert ask(Q.symmetric(MatrixSlice(Y, (0, 1), (1, 2)))) is True

      # Identity, zero and all-ones matrices.
      assert ask(Q.symmetric(Identity(3))) is True
      assert ask(Q.symmetric(ZeroMatrix(3, 3))) is True
      assert ask(Q.symmetric(OneMatrix(3, 3))) is True
  def _test_orthogonal_unitary(predicate):
      """Run the assertions shared by ``test_orthogonal`` and ``test_unitary``.

      ``predicate`` is called on matrix expressions to build the queries;
      the callers in this file pass ``Q.orthogonal`` and ``Q.unitary``.
      """
      holds_x = predicate(X)
      holds_xz = holds_x & predicate(Z)

      assert ask(holds_x, holds_x)
      assert ask(predicate(X.T), holds_x) is True
      assert ask(predicate(X.I), holds_x) is True
      assert ask(predicate(X**2), holds_x)
      assert ask(predicate(Y)) is False
      assert ask(holds_x) is None
      assert ask(holds_x, ~Q.invertible(X)) is False
      assert ask(predicate(X*Z*X), holds_xz) is True
      assert ask(predicate(Identity(3))) is True
      assert ask(predicate(ZeroMatrix(3, 3))) is False
      assert ask(Q.invertible(X), holds_x)
      assert not ask(predicate(X + Z), holds_xz)
  def test_orthogonal():
      """Run the shared orthogonal/unitary assertions with ``Q.orthogonal``."""
      predicate = Q.orthogonal
      _test_orthogonal_unitary(predicate)
  def test_unitary():
      """Run the shared assertions with ``Q.unitary``, then ask unitary given orthogonal."""
      predicate = Q.unitary
      _test_orthogonal_unitary(predicate)

      orthogonal_x = Q.orthogonal(X)
      assert ask(predicate(X), orthogonal_x)
  def test_fullrank():
      """Ask ``Q.fullrank`` of symbols, powers, products and special matrices."""
      x_full = Q.fullrank(X)

      # Symbolic expressions.
      assert ask(x_full, x_full)
      assert ask(Q.fullrank(X**2), x_full)
      assert ask(Q.fullrank(X.T), x_full) is True
      assert ask(x_full) is None
      assert ask(Q.fullrank(Y)) is None
      assert ask(Q.fullrank(X*Z), x_full & Q.fullrank(Z)) is True

      # Identity, zero and all-ones matrices.
      assert ask(Q.fullrank(Identity(3))) is True
      assert ask(Q.fullrank(ZeroMatrix(3, 3))) is False
      assert ask(Q.fullrank(OneMatrix(1, 1))) is True
      assert ask(Q.fullrank(OneMatrix(3, 3))) is False

      # Invertibility asked under the negated full-rank assumption.
      assert ask(Q.invertible(X), ~x_full) == False
  def test_positive_definite():
      """Ask ``Q.positive_definite`` of symbols, products, sums and special matrices."""
      pd_x = Q.positive_definite(X)
      pd_xz = pd_x & Q.positive_definite(Z)
      sandwich = Y.T*X*Y

      # Symbol, transpose, inverse and power.
      assert ask(pd_x, pd_x)
      assert ask(Q.positive_definite(X.T), pd_x) is True
      assert ask(Q.positive_definite(X.I), pd_x) is True
      assert ask(Q.positive_definite(Y)) is False
      assert ask(pd_x) is None
      assert ask(Q.positive_definite(X**3), pd_x)

      # Products, including Y.T*X*Y with and without full-rank Y.
      assert ask(Q.positive_definite(X*Z*X), pd_xz) is True
      assert ask(pd_x, Q.orthogonal(X))
      assert ask(Q.positive_definite(sandwich), pd_x & Q.fullrank(Y)) is True
      assert not ask(Q.positive_definite(sandwich), pd_x)

      # Identity, zero and all-ones matrices.
      assert ask(Q.positive_definite(Identity(3))) is True
      assert ask(Q.positive_definite(ZeroMatrix(3, 3))) is False
      assert ask(Q.positive_definite(OneMatrix(1, 1))) is True
      assert ask(Q.positive_definite(OneMatrix(3, 3))) is False

      # Sum, negation and a diagonal entry.
      assert ask(Q.positive_definite(X + Z), pd_xz) is True
      assert not ask(Q.positive_definite(-X), pd_x)
      assert ask(Q.positive(X[1, 1]), pd_x)
  def test_triangular():
      """Ask the upper/lower triangular predicates of sums, products, powers and special matrices."""
      upper_x = Q.upper_triangular(X)
      lower_x = Q.lower_triangular(X)
      upper_x_lower_z = upper_x & Q.lower_triangular(Z)

      # Sum and product involving the transpose of Z.
      shifted_sum = X + Z.T + Identity(2)
      assert ask(Q.upper_triangular(shifted_sum), upper_x_lower_z) is True
      assert ask(Q.upper_triangular(X*Z.T), upper_x_lower_z) is True

      # Identity, zero and all-ones matrices.
      assert ask(Q.lower_triangular(Identity(3))) is True
      assert ask(Q.lower_triangular(ZeroMatrix(3, 3))) is True
      assert ask(Q.upper_triangular(ZeroMatrix(3, 3))) is True
      assert ask(Q.lower_triangular(OneMatrix(1, 1))) is True
      assert ask(Q.upper_triangular(OneMatrix(1, 1))) is True
      assert ask(Q.lower_triangular(OneMatrix(3, 3))) is False
      assert ask(Q.upper_triangular(OneMatrix(3, 3))) is False

      # Unit-triangular assumption and powers.
      assert ask(Q.triangular(X), Q.unit_triangular(X))
      assert ask(Q.upper_triangular(X**3), upper_x)
      assert ask(Q.lower_triangular(X**3), lower_x)
  def test_diagonal():
      """Ask ``Q.diagonal`` of sums, 1x1/0x0 expressions, powers and diagonal matrix classes."""
      diag_x = Q.diagonal(X)
      both_triangular = Q.lower_triangular(X) & Q.upper_triangular(X)

      # Sum of diagonal symbols, then special matrices.
      shifted_sum = X + Z.T + Identity(2)
      assert ask(Q.diagonal(shifted_sum), diag_x & Q.diagonal(Z)) is True
      assert ask(Q.diagonal(ZeroMatrix(3, 3)))
      assert ask(Q.diagonal(OneMatrix(1, 1))) is True
      assert ask(Q.diagonal(OneMatrix(3, 3))) is False

      # Relations between diagonal, triangular and symmetric.
      assert ask(both_triangular, diag_x)
      assert ask(diag_x, both_triangular)
      assert ask(Q.symmetric(X), diag_x)
      assert ask(Q.triangular(X), diag_x)

      # 0x0 and 1x1 symbols, and expressions with a 1x1 result.
      assert ask(Q.diagonal(C0x0))
      assert ask(Q.diagonal(A1x1))
      assert ask(Q.diagonal(A1x1 + B1x1))
      assert ask(Q.diagonal(A1x1*B1x1))
      assert ask(Q.diagonal(V1.T*V2))
      assert ask(Q.diagonal(V1.T*(X + Z)*V1))
      assert ask(Q.diagonal(MatrixSlice(Y, (0, 1), (1, 2)))) is True
      assert ask(Q.diagonal(V1.T*(V1 + V2))) is True

      # Power, identity and the diagonal matrix expression classes.
      assert ask(Q.diagonal(X**3), diag_x)
      assert ask(Q.diagonal(Identity(3)))
      assert ask(Q.diagonal(DiagMatrix(V1)))
      assert ask(Q.diagonal(DiagonalMatrix(X)))
  def test_non_atoms():
      """Ask whether ``Trace(X)`` is real, assuming ``Trace(X)`` is positive."""
      trace_x = Trace(X)
      assert ask(Q.real(trace_x), Q.positive(trace_x))
  @XFAIL
  def test_non_trivial_implies():
      """Expected failure: triangular queries on 3x3 symbols assumed lower triangular."""
      lhs = MatrixSymbol('X', 3, 3)
      rhs = MatrixSymbol('Y', 3, 3)
      lhs_lower = Q.lower_triangular(lhs)
      both_lower = lhs_lower & Q.lower_triangular(rhs)

      assert ask(Q.lower_triangular(lhs+rhs), both_lower) is True
      assert ask(Q.triangular(lhs), lhs_lower) is True
      assert ask(Q.triangular(lhs+rhs), both_lower) is True
  def test_MatrixSlice():
      """Ask several predicates of two slices of a 4x4 symbol, given the predicate on the symbol."""
      parent = MatrixSymbol('X', 4, 4)
      square_slice = MatrixSlice(parent, (1, 3), (1, 3))
      offset_slice = MatrixSlice(parent, (0, 3), (1, 3))

      # Checked in this order for both slices.
      predicates = (
          Q.symmetric,
          Q.invertible,
          Q.diagonal,
          Q.orthogonal,
          Q.upper_triangular,
      )

      # Slice with identical row and column ranges.
      for predicate in predicates:
          assert ask(predicate(square_slice), predicate(parent))

      # Slice whose row and column ranges differ.
      for predicate in predicates:
          assert not ask(predicate(offset_slice), predicate(parent))
  def test_det_trace_positive():
      """Ask positivity of the trace and determinant of a positive-definite 4x4 symbol."""
      mat = MatrixSymbol('X', 4, 4)
      pos_def = Q.positive_definite(mat)

      assert ask(Q.positive(Trace(mat)), pos_def)
      assert ask(Q.positive(Determinant(mat)), pos_def)
  def test_field_assumptions():
      """Ask the ``*_elements`` predicates of expressions built from two 4x4 symbols."""
      from sympy.matrices.expressions.hadamard import HadamardProduct

      lhs = MatrixSymbol('X', 4, 4)
      rhs = MatrixSymbol('Y', 4, 4)
      real_lhs = Q.real_elements(lhs)
      int_lhs = Q.integer_elements(lhs)
      both_real = real_lhs & Q.real_elements(rhs)
      real_and_invertible = real_lhs & Q.invertible(lhs)

      # Single symbol and its powers.
      assert ask(real_lhs, real_lhs)
      assert not ask(int_lhs, real_lhs)
      assert ask(Q.complex_elements(lhs), real_lhs)
      assert ask(Q.complex_elements(lhs**2), real_lhs)
      assert ask(Q.real_elements(lhs**2), int_lhs)

      # Sums and element-wise products of the two symbols.
      assert ask(Q.real_elements(lhs+rhs), real_lhs) is None
      assert ask(Q.real_elements(lhs+rhs), both_real)
      assert ask(Q.real_elements(HadamardProduct(lhs, rhs)), both_real)
      assert ask(Q.complex_elements(lhs+rhs),
                 real_lhs & Q.complex_elements(rhs))

      # Transpose, inverse, trace and determinant.
      assert ask(Q.real_elements(lhs.T), real_lhs)
      assert ask(Q.real_elements(lhs.I), real_and_invertible)
      assert ask(Q.real_elements(Trace(lhs)), real_lhs)
      assert ask(Q.integer_elements(Determinant(lhs)), int_lhs)
      assert not ask(Q.integer_elements(lhs.I), int_lhs)

      # Scalar multiple and the L factor of an LU factorization.
      alpha = Symbol('alpha')
      assert ask(Q.real_elements(alpha*lhs), real_lhs & Q.real(alpha))
      assert ask(Q.real_elements(LofLU(lhs)), real_lhs)

      # Negative integer exponent, with and without invertibility assumed.
      e = Symbol('e', integer=True, negative=True)
      assert ask(Q.real_elements(lhs**e), real_and_invertible)
      assert ask(Q.real_elements(lhs**e), real_lhs) is None
  def test_matrix_element_sets():
      """Ask scalar predicates of a single entry, and ``*_elements`` of special matrices."""
      from sympy.matrices.expressions.fourier import DFT

      mat = MatrixSymbol('X', 4, 4)
      entry = mat[1, 2]

      # A single entry, given the matching assumption on the whole matrix.
      assert ask(Q.real(entry), Q.real_elements(mat))
      assert ask(Q.integer(entry), Q.integer_elements(mat))
      assert ask(Q.complex(entry), Q.complex_elements(mat))

      # Identity, zero and all-ones matrices.
      assert ask(Q.integer_elements(Identity(3)))
      assert ask(Q.integer_elements(ZeroMatrix(3, 3)))
      assert ask(Q.integer_elements(OneMatrix(3, 3)))

      # Discrete Fourier transform matrix.
      assert ask(Q.complex_elements(DFT(3)))
  def test_matrix_element_sets_slices_blocks():
      """Ask ``Q.integer_elements`` of a column slice and a stacked block matrix."""
      mat = MatrixSymbol('X', 4, 4)
      int_mat = Q.integer_elements(mat)

      last_column = mat[:, 3]
      stacked = BlockMatrix([[mat], [mat]])

      assert ask(Q.integer_elements(last_column), int_mat)
      assert ask(Q.integer_elements(stacked), int_mat)
  def test_matrix_element_sets_determinant_trace():
      """Ask ``Q.integer`` of the determinant and trace of the module-level X."""
      int_elements = Q.integer_elements(X)

      assert ask(Q.integer(Determinant(X)), int_elements)
      assert ask(Q.integer(Trace(X)), int_elements)