test_matpow.py 7.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217
  1. from sympy.functions.elementary.miscellaneous import sqrt
  2. from sympy.simplify.powsimp import powsimp
  3. from sympy.testing.pytest import raises
  4. from sympy.core.expr import unchanged
  5. from sympy.core import symbols, S
  6. from sympy.matrices import Identity, MatrixSymbol, ImmutableMatrix, ZeroMatrix, OneMatrix, Matrix
  7. from sympy.matrices.common import NonSquareMatrixError
  8. from sympy.matrices.expressions import MatPow, MatAdd, MatMul
  9. from sympy.matrices.expressions.inverse import Inverse
  10. from sympy.matrices.expressions.matexpr import MatrixElement
  11. n, m, l, k = symbols('n m l k', integer=True)
  12. A = MatrixSymbol('A', n, m)
  13. B = MatrixSymbol('B', m, l)
  14. C = MatrixSymbol('C', n, n)
  15. D = MatrixSymbol('D', n, n)
  16. E = MatrixSymbol('E', m, n)
  17. def test_entry_matrix():
  18. X = ImmutableMatrix([[1, 2], [3, 4]])
  19. assert MatPow(X, 0)[0, 0] == 1
  20. assert MatPow(X, 0)[0, 1] == 0
  21. assert MatPow(X, 1)[0, 0] == 1
  22. assert MatPow(X, 1)[0, 1] == 2
  23. assert MatPow(X, 2)[0, 0] == 7
  24. def test_entry_symbol():
  25. from sympy.concrete import Sum
  26. assert MatPow(C, 0)[0, 0] == 1
  27. assert MatPow(C, 0)[0, 1] == 0
  28. assert MatPow(C, 1)[0, 0] == C[0, 0]
  29. assert isinstance(MatPow(C, 2)[0, 0], Sum)
  30. assert isinstance(MatPow(C, n)[0, 0], MatrixElement)
  31. def test_as_explicit_symbol():
  32. X = MatrixSymbol('X', 2, 2)
  33. assert MatPow(X, 0).as_explicit() == ImmutableMatrix(Identity(2))
  34. assert MatPow(X, 1).as_explicit() == X.as_explicit()
  35. assert MatPow(X, 2).as_explicit() == (X.as_explicit())**2
  36. assert MatPow(X, n).as_explicit() == ImmutableMatrix([
  37. [(X ** n)[0, 0], (X ** n)[0, 1]],
  38. [(X ** n)[1, 0], (X ** n)[1, 1]],
  39. ])
  40. a = MatrixSymbol("a", 3, 1)
  41. b = MatrixSymbol("b", 3, 1)
  42. c = MatrixSymbol("c", 3, 1)
  43. expr = (a.T*b)**S.Half
  44. assert expr.as_explicit() == Matrix([[sqrt(a[0, 0]*b[0, 0] + a[1, 0]*b[1, 0] + a[2, 0]*b[2, 0])]])
  45. expr = c*(a.T*b)**S.Half
  46. m = sqrt(a[0, 0]*b[0, 0] + a[1, 0]*b[1, 0] + a[2, 0]*b[2, 0])
  47. assert expr.as_explicit() == Matrix([[c[0, 0]*m], [c[1, 0]*m], [c[2, 0]*m]])
  48. expr = (a*b.T)**S.Half
  49. denom = sqrt(a[0, 0]*b[0, 0] + a[1, 0]*b[1, 0] + a[2, 0]*b[2, 0])
  50. expected = (a*b.T).as_explicit()/denom
  51. assert expr.as_explicit() == expected
  52. expr = X**-1
  53. det = X[0, 0]*X[1, 1] - X[1, 0]*X[0, 1]
  54. expected = Matrix([[X[1, 1], -X[0, 1]], [-X[1, 0], X[0, 0]]])/det
  55. assert expr.as_explicit() == expected
  56. expr = X**m
  57. assert expr.as_explicit() == X.as_explicit()**m
  58. def test_as_explicit_matrix():
  59. A = ImmutableMatrix([[1, 2], [3, 4]])
  60. assert MatPow(A, 0).as_explicit() == ImmutableMatrix(Identity(2))
  61. assert MatPow(A, 1).as_explicit() == A
  62. assert MatPow(A, 2).as_explicit() == A**2
  63. assert MatPow(A, -1).as_explicit() == A.inv()
  64. assert MatPow(A, -2).as_explicit() == (A.inv())**2
  65. # less expensive than testing on a 2x2
  66. A = ImmutableMatrix([4])
  67. assert MatPow(A, S.Half).as_explicit() == A**S.Half
  68. def test_doit_symbol():
  69. assert MatPow(C, 0).doit() == Identity(n)
  70. assert MatPow(C, 1).doit() == C
  71. assert MatPow(C, -1).doit() == C.I
  72. for r in [2, S.Half, S.Pi, n]:
  73. assert MatPow(C, r).doit() == MatPow(C, r)
  74. def test_doit_matrix():
  75. X = ImmutableMatrix([[1, 2], [3, 4]])
  76. assert MatPow(X, 0).doit() == ImmutableMatrix(Identity(2))
  77. assert MatPow(X, 1).doit() == X
  78. assert MatPow(X, 2).doit() == X**2
  79. assert MatPow(X, -1).doit() == X.inv()
  80. assert MatPow(X, -2).doit() == (X.inv())**2
  81. # less expensive than testing on a 2x2
  82. assert MatPow(ImmutableMatrix([4]), S.Half).doit() == ImmutableMatrix([2])
  83. X = ImmutableMatrix([[0, 2], [0, 4]]) # det() == 0
  84. raises(ValueError, lambda: MatPow(X,-1).doit())
  85. raises(ValueError, lambda: MatPow(X,-2).doit())
  86. def test_nonsquare():
  87. A = MatrixSymbol('A', 2, 3)
  88. B = ImmutableMatrix([[1, 2, 3], [4, 5, 6]])
  89. for r in [-1, 0, 1, 2, S.Half, S.Pi, n]:
  90. raises(NonSquareMatrixError, lambda: MatPow(A, r))
  91. raises(NonSquareMatrixError, lambda: MatPow(B, r))
  92. def test_doit_equals_pow(): #17179
  93. X = ImmutableMatrix ([[1,0],[0,1]])
  94. assert MatPow(X, n).doit() == X**n == X
  95. def test_doit_nested_MatrixExpr():
  96. X = ImmutableMatrix([[1, 2], [3, 4]])
  97. Y = ImmutableMatrix([[2, 3], [4, 5]])
  98. assert MatPow(MatMul(X, Y), 2).doit() == (X*Y)**2
  99. assert MatPow(MatAdd(X, Y), 2).doit() == (X + Y)**2
  100. def test_identity_power():
  101. k = Identity(n)
  102. assert MatPow(k, 4).doit() == k
  103. assert MatPow(k, n).doit() == k
  104. assert MatPow(k, -3).doit() == k
  105. assert MatPow(k, 0).doit() == k
  106. l = Identity(3)
  107. assert MatPow(l, n).doit() == l
  108. assert MatPow(l, -1).doit() == l
  109. assert MatPow(l, 0).doit() == l
  110. def test_zero_power():
  111. z1 = ZeroMatrix(n, n)
  112. assert MatPow(z1, 3).doit() == z1
  113. raises(ValueError, lambda:MatPow(z1, -1).doit())
  114. assert MatPow(z1, 0).doit() == Identity(n)
  115. assert MatPow(z1, n).doit() == z1
  116. raises(ValueError, lambda:MatPow(z1, -2).doit())
  117. z2 = ZeroMatrix(4, 4)
  118. assert MatPow(z2, n).doit() == z2
  119. raises(ValueError, lambda:MatPow(z2, -3).doit())
  120. assert MatPow(z2, 2).doit() == z2
  121. assert MatPow(z2, 0).doit() == Identity(4)
  122. raises(ValueError, lambda:MatPow(z2, -1).doit())
  123. def test_OneMatrix_power():
  124. o = OneMatrix(3, 3)
  125. assert o ** 0 == Identity(3)
  126. assert o ** 1 == o
  127. assert o * o == o ** 2 == 3 * o
  128. assert o * o * o == o ** 3 == 9 * o
  129. o = OneMatrix(n, n)
  130. assert o * o == o ** 2 == n * o
  131. # powsimp necessary as n ** (n - 2) * n does not produce n ** (n - 1)
  132. assert powsimp(o ** (n - 1) * o) == o ** n == n ** (n - 1) * o
  133. def test_transpose_power():
  134. from sympy.matrices.expressions.transpose import Transpose as TP
  135. assert (C*D).T**5 == ((C*D)**5).T == (D.T * C.T)**5
  136. assert ((C*D).T**5).T == (C*D)**5
  137. assert (C.T.I.T)**7 == C**-7
  138. assert (C.T**l).T**k == C**(l*k)
  139. assert ((E.T * A.T)**5).T == (A*E)**5
  140. assert ((A*E).T**5).T**7 == (A*E)**35
  141. assert TP(TP(C**2 * D**3)**5).doit() == (C**2 * D**3)**5
  142. assert ((D*C)**-5).T**-5 == ((D*C)**25).T
  143. assert (((D*C)**l).T**k).T == (D*C)**(l*k)
  144. def test_Inverse():
  145. assert Inverse(MatPow(C, 0)).doit() == Identity(n)
  146. assert Inverse(MatPow(C, 1)).doit() == Inverse(C)
  147. assert Inverse(MatPow(C, 2)).doit() == MatPow(C, -2)
  148. assert Inverse(MatPow(C, -1)).doit() == C
  149. assert MatPow(Inverse(C), 0).doit() == Identity(n)
  150. assert MatPow(Inverse(C), 1).doit() == Inverse(C)
  151. assert MatPow(Inverse(C), 2).doit() == MatPow(C, -2)
  152. assert MatPow(Inverse(C), -1).doit() == C
  153. def test_combine_powers():
  154. assert (C ** 1) ** 1 == C
  155. assert (C ** 2) ** 3 == MatPow(C, 6)
  156. assert (C ** -2) ** -3 == MatPow(C, 6)
  157. assert (C ** -1) ** -1 == C
  158. assert (((C ** 2) ** 3) ** 4) ** 5 == MatPow(C, 120)
  159. assert (C ** n) ** n == C ** (n ** 2)
  160. def test_unchanged():
  161. assert unchanged(MatPow, C, 0)
  162. assert unchanged(MatPow, C, 1)
  163. assert unchanged(MatPow, Inverse(C), -1)
  164. assert unchanged(Inverse, MatPow(C, -1), -1)
  165. assert unchanged(MatPow, MatPow(C, -1), -1)
  166. assert unchanged(MatPow, MatPow(C, 1), 1)
  167. def test_no_exponentiation():
  168. # if this passes, Pow.as_numer_denom should recognize
  169. # MatAdd as exponent
  170. raises(NotImplementedError, lambda: 3**(-2*C))