test_dotproduct.py 1.1 KB

1234567891011121314151617181920212223242526272829303132333435
  1. from sympy.core.expr import unchanged
  2. from sympy.core.mul import Mul
  3. from sympy.matrices import Matrix
  4. from sympy.matrices.expressions.matexpr import MatrixSymbol
  5. from sympy.matrices.expressions.dotproduct import DotProduct
  6. from sympy.testing.pytest import raises
  7. A = Matrix(3, 1, [1, 2, 3])
  8. B = Matrix(3, 1, [1, 3, 5])
  9. C = Matrix(4, 1, [1, 2, 4, 5])
  10. D = Matrix(2, 2, [1, 2, 3, 4])
  11. def test_docproduct():
  12. assert DotProduct(A, B).doit() == 22
  13. assert DotProduct(A.T, B).doit() == 22
  14. assert DotProduct(A, B.T).doit() == 22
  15. assert DotProduct(A.T, B.T).doit() == 22
  16. raises(TypeError, lambda: DotProduct(1, A))
  17. raises(TypeError, lambda: DotProduct(A, 1))
  18. raises(TypeError, lambda: DotProduct(A, D))
  19. raises(TypeError, lambda: DotProduct(D, A))
  20. raises(TypeError, lambda: DotProduct(B, C).doit())
  21. def test_dotproduct_symbolic():
  22. A = MatrixSymbol('A', 3, 1)
  23. B = MatrixSymbol('B', 3, 1)
  24. dot = DotProduct(A, B)
  25. assert dot.is_scalar == True
  26. assert unchanged(Mul, 2, dot)
  27. # XXX Fix forced evaluation for arithmetics with matrix expressions
  28. assert dot * A == (A[0, 0]*B[0, 0] + A[1, 0]*B[1, 0] + A[2, 0]*B[2, 0])*A