# Source listing: test_matadd.py (1.8 KB)

  1. from sympy.matrices.expressions import MatrixSymbol, MatAdd, MatPow, MatMul
  2. from sympy.matrices.expressions.special import GenericZeroMatrix, ZeroMatrix
  3. from sympy.matrices.common import ShapeError
  4. from sympy.matrices import eye, ImmutableMatrix
  5. from sympy.core import Add, Basic, S
  6. from sympy.core.add import add
  7. from sympy.testing.pytest import XFAIL, raises
  8. X = MatrixSymbol('X', 2, 2)
  9. Y = MatrixSymbol('Y', 2, 2)
  10. def test_evaluate():
  11. assert MatAdd(X, X, evaluate=True) == add(X, X, evaluate=True) == MatAdd(X, X).doit()
  12. def test_sort_key():
  13. assert MatAdd(Y, X).doit().args == add(Y, X).doit().args == (X, Y)
  14. def test_matadd_sympify():
  15. assert isinstance(MatAdd(eye(1), eye(1)).args[0], Basic)
  16. assert isinstance(add(eye(1), eye(1)).args[0], Basic)
  17. def test_matadd_of_matrices():
  18. assert MatAdd(eye(2), 4*eye(2), eye(2)).doit() == ImmutableMatrix(6*eye(2))
  19. assert add(eye(2), 4*eye(2), eye(2)).doit() == ImmutableMatrix(6*eye(2))
  20. def test_doit_args():
  21. A = ImmutableMatrix([[1, 2], [3, 4]])
  22. B = ImmutableMatrix([[2, 3], [4, 5]])
  23. assert MatAdd(A, MatPow(B, 2)).doit() == A + B**2
  24. assert MatAdd(A, MatMul(A, B)).doit() == A + A*B
  25. assert (MatAdd(A, X, MatMul(A, B), Y, MatAdd(2*A, B)).doit() ==
  26. add(A, X, MatMul(A, B), Y, add(2*A, B)).doit() ==
  27. MatAdd(3*A + A*B + B, X, Y))
  28. def test_generic_identity():
  29. assert MatAdd.identity == GenericZeroMatrix()
  30. assert MatAdd.identity != S.Zero
  31. def test_zero_matrix_add():
  32. assert Add(ZeroMatrix(2, 2), ZeroMatrix(2, 2)) == ZeroMatrix(2, 2)
  33. @XFAIL
  34. def test_matrix_Add_with_scalar():
  35. raises(TypeError, lambda: Add(0, ZeroMatrix(2, 2)))
  36. def test_shape_error():
  37. A = MatrixSymbol('A', 2, 3)
  38. B = MatrixSymbol('B', 3, 3)
  39. raises(ShapeError, lambda: MatAdd(A, B))
  40. A = MatrixSymbol('A', 3, 2)
  41. raises(ShapeError, lambda: MatAdd(A, B))