# test_trace.py (3.3 KB)
  1. from sympy.core import Lambda, S, symbols
  2. from sympy.concrete import Sum
  3. from sympy.functions import adjoint, conjugate, transpose
  4. from sympy.matrices import eye, Matrix, ShapeError, ImmutableMatrix
  5. from sympy.matrices.expressions import (
  6. Adjoint, Identity, FunctionMatrix, MatrixExpr, MatrixSymbol, Trace,
  7. ZeroMatrix, trace, MatPow, MatAdd, MatMul
  8. )
  9. from sympy.matrices.expressions.special import OneMatrix
  10. from sympy.testing.pytest import raises
  11. from sympy.abc import i
  12. n = symbols('n', integer=True)
  13. A = MatrixSymbol('A', n, n)
  14. B = MatrixSymbol('B', n, n)
  15. C = MatrixSymbol('C', 3, 4)
  16. def test_Trace():
  17. assert isinstance(Trace(A), Trace)
  18. assert not isinstance(Trace(A), MatrixExpr)
  19. raises(ShapeError, lambda: Trace(C))
  20. assert trace(eye(3)) == 3
  21. assert trace(Matrix(3, 3, [1, 2, 3, 4, 5, 6, 7, 8, 9])) == 15
  22. assert adjoint(Trace(A)) == trace(Adjoint(A))
  23. assert conjugate(Trace(A)) == trace(Adjoint(A))
  24. assert transpose(Trace(A)) == Trace(A)
  25. _ = A / Trace(A) # Make sure this is possible
  26. # Some easy simplifications
  27. assert trace(Identity(5)) == 5
  28. assert trace(ZeroMatrix(5, 5)) == 0
  29. assert trace(OneMatrix(1, 1)) == 1
  30. assert trace(OneMatrix(2, 2)) == 2
  31. assert trace(OneMatrix(n, n)) == n
  32. assert trace(2*A*B) == 2*Trace(A*B)
  33. assert trace(A.T) == trace(A)
  34. i, j = symbols('i j')
  35. F = FunctionMatrix(3, 3, Lambda((i, j), i + j))
  36. assert trace(F) == (0 + 0) + (1 + 1) + (2 + 2)
  37. raises(TypeError, lambda: Trace(S.One))
  38. assert Trace(A).arg is A
  39. assert str(trace(A)) == str(Trace(A).doit())
  40. assert Trace(A).is_commutative is True
  41. def test_Trace_A_plus_B():
  42. assert trace(A + B) == Trace(A) + Trace(B)
  43. assert Trace(A + B).arg == MatAdd(A, B)
  44. assert Trace(A + B).doit() == Trace(A) + Trace(B)
  45. def test_Trace_MatAdd_doit():
  46. # See issue #9028
  47. X = ImmutableMatrix([[1, 2, 3]]*3)
  48. Y = MatrixSymbol('Y', 3, 3)
  49. q = MatAdd(X, 2*X, Y, -3*Y)
  50. assert Trace(q).arg == q
  51. assert Trace(q).doit() == 18 - 2*Trace(Y)
  52. def test_Trace_MatPow_doit():
  53. X = Matrix([[1, 2], [3, 4]])
  54. assert Trace(X).doit() == 5
  55. q = MatPow(X, 2)
  56. assert Trace(q).arg == q
  57. assert Trace(q).doit() == 29
  58. def test_Trace_MutableMatrix_plus():
  59. # See issue #9043
  60. X = Matrix([[1, 2], [3, 4]])
  61. assert Trace(X) + Trace(X) == 2*Trace(X)
  62. def test_Trace_doit_deep_False():
  63. X = Matrix([[1, 2], [3, 4]])
  64. q = MatPow(X, 2)
  65. assert Trace(q).doit(deep=False).arg == q
  66. q = MatAdd(X, 2*X)
  67. assert Trace(q).doit(deep=False).arg == q
  68. q = MatMul(X, 2*X)
  69. assert Trace(q).doit(deep=False).arg == q
  70. def test_trace_constant_factor():
  71. # Issue 9052: gave 2*Trace(MatMul(A)) instead of 2*Trace(A)
  72. assert trace(2*A) == 2*Trace(A)
  73. X = ImmutableMatrix([[1, 2], [3, 4]])
  74. assert trace(MatMul(2, X)) == 10
  75. def test_trace_rewrite():
  76. assert trace(A).rewrite(Sum) == Sum(A[i, i], (i, 0, n - 1))
  77. assert trace(eye(3)).rewrite(Sum) == 3
  78. def test_trace_normalize():
  79. assert Trace(B*A) != Trace(A*B)
  80. assert Trace(B*A)._normalize() == Trace(A*B)
  81. assert Trace(B*A.T)._normalize() == Trace(A*B.T)
  82. def test_trace_as_explicit():
  83. raises(ValueError, lambda: Trace(A).as_explicit())
  84. X = MatrixSymbol("X", 3, 3)
  85. assert Trace(X).as_explicit() == X[0, 0] + X[1, 1] + X[2, 2]
  86. assert Trace(eye(3)).as_explicit() == 3