dotproduct.py 1.9 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455
  1. from sympy.core import Basic, Expr
  2. from sympy.core.sympify import _sympify
  3. from sympy.matrices.expressions.transpose import transpose
  4. class DotProduct(Expr):
  5. """
  6. Dot product of vector matrices
  7. The input should be two 1 x n or n x 1 matrices. The output represents the
  8. scalar dotproduct.
  9. This is similar to using MatrixElement and MatMul, except DotProduct does
  10. not require that one vector to be a row vector and the other vector to be
  11. a column vector.
  12. >>> from sympy import MatrixSymbol, DotProduct
  13. >>> A = MatrixSymbol('A', 1, 3)
  14. >>> B = MatrixSymbol('B', 1, 3)
  15. >>> DotProduct(A, B)
  16. DotProduct(A, B)
  17. >>> DotProduct(A, B).doit()
  18. A[0, 0]*B[0, 0] + A[0, 1]*B[0, 1] + A[0, 2]*B[0, 2]
  19. """
  20. def __new__(cls, arg1, arg2):
  21. arg1, arg2 = _sympify((arg1, arg2))
  22. if not arg1.is_Matrix:
  23. raise TypeError("Argument 1 of DotProduct is not a matrix")
  24. if not arg2.is_Matrix:
  25. raise TypeError("Argument 2 of DotProduct is not a matrix")
  26. if not (1 in arg1.shape):
  27. raise TypeError("Argument 1 of DotProduct is not a vector")
  28. if not (1 in arg2.shape):
  29. raise TypeError("Argument 2 of DotProduct is not a vector")
  30. if set(arg1.shape) != set(arg2.shape):
  31. raise TypeError("DotProduct arguments are not the same length")
  32. return Basic.__new__(cls, arg1, arg2)
  33. def doit(self, expand=False, **hints):
  34. if self.args[0].shape == self.args[1].shape:
  35. if self.args[0].shape[0] == 1:
  36. mul = self.args[0]*transpose(self.args[1])
  37. else:
  38. mul = transpose(self.args[0])*self.args[1]
  39. else:
  40. if self.args[0].shape[0] == 1:
  41. mul = self.args[0]*self.args[1]
  42. else:
  43. mul = transpose(self.args[0])*transpose(self.args[1])
  44. return mul[0]