test_transpose.py 1.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869
  1. from sympy.functions import adjoint, conjugate, transpose
  2. from sympy.matrices.expressions import MatrixSymbol, Adjoint, trace, Transpose
  3. from sympy.matrices import eye, Matrix
  4. from sympy.assumptions.ask import Q
  5. from sympy.assumptions.refine import refine
  6. from sympy.core.singleton import S
  7. from sympy.core.symbol import symbols
  8. n, m, l, k, p = symbols('n m l k p', integer=True)
  9. A = MatrixSymbol('A', n, m)
  10. B = MatrixSymbol('B', m, l)
  11. C = MatrixSymbol('C', n, n)
  12. def test_transpose():
  13. Sq = MatrixSymbol('Sq', n, n)
  14. assert transpose(A) == Transpose(A)
  15. assert Transpose(A).shape == (m, n)
  16. assert Transpose(A*B).shape == (l, n)
  17. assert transpose(Transpose(A)) == A
  18. assert isinstance(Transpose(Transpose(A)), Transpose)
  19. assert adjoint(Transpose(A)) == Adjoint(Transpose(A))
  20. assert conjugate(Transpose(A)) == Adjoint(A)
  21. assert Transpose(eye(3)).doit() == eye(3)
  22. assert Transpose(S(5)).doit() == S(5)
  23. assert Transpose(Matrix([[1, 2], [3, 4]])).doit() == Matrix([[1, 3], [2, 4]])
  24. assert transpose(trace(Sq)) == trace(Sq)
  25. assert trace(Transpose(Sq)) == trace(Sq)
  26. assert Transpose(Sq)[0, 1] == Sq[1, 0]
  27. assert Transpose(A*B).doit() == Transpose(B) * Transpose(A)
  28. def test_transpose_MatAdd_MatMul():
  29. # Issue 16807
  30. from sympy.functions.elementary.trigonometric import cos
  31. x = symbols('x')
  32. M = MatrixSymbol('M', 3, 3)
  33. N = MatrixSymbol('N', 3, 3)
  34. assert (N + (cos(x) * M)).T == cos(x)*M.T + N.T
  35. def test_refine():
  36. assert refine(C.T, Q.symmetric(C)) == C
  37. def test_transpose1x1():
  38. m = MatrixSymbol('m', 1, 1)
  39. assert m == refine(m.T)
  40. assert m == refine(m.T.T)
  41. def test_issue_9817():
  42. from sympy.matrices.expressions import Identity
  43. v = MatrixSymbol('v', 3, 1)
  44. A = MatrixSymbol('A', 3, 3)
  45. x = Matrix([i + 1 for i in range(3)])
  46. X = Identity(3)
  47. quadratic = v.T * A * v
  48. subbed = quadratic.xreplace({v:x, A:X})
  49. assert subbed.as_explicit() == Matrix([[14]])