# test_inverse.py
  1. from sympy.core import symbols, S
  2. from sympy.matrices.expressions import MatrixSymbol, Inverse, MatPow, ZeroMatrix, OneMatrix
  3. from sympy.matrices.common import NonInvertibleMatrixError, NonSquareMatrixError
  4. from sympy.matrices import eye, Identity
  5. from sympy.testing.pytest import raises
  6. from sympy.assumptions.ask import Q
  7. from sympy.assumptions.refine import refine
  8. n, m, l = symbols('n m l', integer=True)
  9. A = MatrixSymbol('A', n, m)
  10. B = MatrixSymbol('B', m, l)
  11. C = MatrixSymbol('C', n, n)
  12. D = MatrixSymbol('D', n, n)
  13. E = MatrixSymbol('E', m, n)
  14. def test_inverse():
  15. assert Inverse(C).args == (C, S.NegativeOne)
  16. assert Inverse(C).shape == (n, n)
  17. assert Inverse(A*E).shape == (n, n)
  18. assert Inverse(E*A).shape == (m, m)
  19. assert Inverse(C).inverse() == C
  20. assert Inverse(Inverse(C)).doit() == C
  21. assert isinstance(Inverse(Inverse(C)), Inverse)
  22. assert Inverse(*Inverse(E*A).args) == Inverse(E*A)
  23. assert C.inverse().inverse() == C
  24. assert C.inverse()*C == Identity(C.rows)
  25. assert Identity(n).inverse() == Identity(n)
  26. assert (3*Identity(n)).inverse() == Identity(n)/3
  27. # Simplifies Muls if possible (i.e. submatrices are square)
  28. assert (C*D).inverse() == D.I*C.I
  29. # But still works when not possible
  30. assert isinstance((A*E).inverse(), Inverse)
  31. assert Inverse(C*D).doit(inv_expand=False) == Inverse(C*D)
  32. assert Inverse(eye(3)).doit() == eye(3)
  33. assert Inverse(eye(3)).doit(deep=False) == eye(3)
  34. assert OneMatrix(1, 1).I == Identity(1)
  35. assert isinstance(OneMatrix(n, n).I, Inverse)
  36. def test_inverse_non_invertible():
  37. raises(NonInvertibleMatrixError, lambda: ZeroMatrix(n, n).I)
  38. raises(NonInvertibleMatrixError, lambda: OneMatrix(2, 2).I)
  39. def test_refine():
  40. assert refine(C.I, Q.orthogonal(C)) == C.T
  41. def test_inverse_matpow_canonicalization():
  42. A = MatrixSymbol('A', 3, 3)
  43. assert Inverse(MatPow(A, 3)).doit() == MatPow(Inverse(A), 3).doit()
  44. def test_nonsquare_error():
  45. A = MatrixSymbol('A', 3, 4)
  46. raises(NonSquareMatrixError, lambda: Inverse(A))