123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102 |
- from sympy.core.relational import Eq
- from sympy.core.expr import Expr
- from sympy.core.numbers import Integer
- from sympy.logic.boolalg import Boolean, And
- from sympy.matrices.expressions.matexpr import MatrixExpr
- from sympy.matrices.common import ShapeError
- from typing import Union
def is_matadd_valid(*args: MatrixExpr) -> Boolean:
    """Return the symbolic condition how ``MatAdd``, ``HadamardProduct``
    makes sense.

    Every argument must have the same number of rows and the same number
    of columns as its neighbour. With zero or one argument there is
    nothing to compare and the condition is ``true``.

    Parameters
    ==========

    args
        The list of arguments of matrices to be tested for.

    Examples
    ========

    >>> from sympy import MatrixSymbol, symbols
    >>> from sympy.matrices.expressions._shape import is_matadd_valid
    >>> m, n, p, q = symbols('m n p q')
    >>> A = MatrixSymbol('A', m, n)
    >>> B = MatrixSymbol('B', p, q)
    >>> is_matadd_valid(A, B)
    Eq(m, p) & Eq(n, q)
    """
    shapes = [arg.shape for arg in args]
    # Pair each shape with its successor. Unlike unpacking ``zip(*shapes)``
    # into ``rows, cols``, this stays well-defined when ``args`` is empty,
    # and ``And()`` with no arguments evaluates to ``true``.
    neighbours = list(zip(shapes, shapes[1:]))
    return And(
        *(Eq(a[0], b[0]) for a, b in neighbours),
        *(Eq(a[1], b[1]) for a, b in neighbours),
    )
def is_matmul_valid(*args: Union[MatrixExpr, Expr]) -> Boolean:
    """Return the symbolic condition how ``MatMul`` makes sense

    Scalar (non-``MatrixExpr``) factors are ignored. For the remaining
    matrices, the column count of each must equal the row count of the
    next one. With fewer than two matrix operands there is nothing to
    compare and the condition is ``true``.

    Parameters
    ==========

    args
        The list of arguments of matrices and scalar expressions to be tested
        for.

    Examples
    ========

    >>> from sympy import MatrixSymbol, symbols
    >>> from sympy.matrices.expressions._shape import is_matmul_valid
    >>> m, n, p, q = symbols('m n p q')
    >>> A = MatrixSymbol('A', m, n)
    >>> B = MatrixSymbol('B', p, q)
    >>> is_matmul_valid(A, B)
    Eq(n, p)
    >>> is_matmul_valid(m, A)
    True
    """
    # Scalar factors do not constrain the product's shape.
    shapes = [arg.shape for arg in args if isinstance(arg, MatrixExpr)]
    # Build adjacent pairs from the shape list rather than unpacking
    # ``zip(*shapes)``. The unpacking raised ValueError when no matrix
    # operand was present; here an empty ``And()`` is ``true``.
    return And(*(Eq(left[1], right[0]) for left, right in zip(shapes, shapes[1:])))
def is_square(arg: MatrixExpr, /) -> Boolean:
    """Return the symbolic condition under which ``arg`` is square.

    The condition is the equation between the matrix's row count and its
    column count.

    Parameters
    ==========

    arg
        The matrix to be tested for.

    Examples
    ========

    >>> from sympy import MatrixSymbol, symbols
    >>> from sympy.matrices.expressions._shape import is_square
    >>> m, n = symbols('m n')
    >>> A = MatrixSymbol('A', m, n)
    >>> is_square(A)
    Eq(m, n)
    """
    num_rows = arg.rows
    num_cols = arg.cols
    condition = Eq(num_rows, num_cols)
    return condition
def validate_matadd_integer(*args: MatrixExpr) -> None:
    """Validate matrix shape for addition only for integer values

    Only dimensions that are concrete integers (``int`` or ``Integer``) are
    compared; symbolic dimensions are skipped because they cannot be
    decided here. Calling with no arguments is a no-op.

    Parameters
    ==========

    args
        The matrices whose shapes are checked.

    Raises
    ======

    ShapeError
        If the arguments have more than one distinct concrete row count, or
        more than one distinct concrete column count.
    """
    if not args:
        # Nothing to compare; ``zip(*())`` would otherwise fail to unpack
        # into ``rows, cols`` with a ValueError.
        return

    def _concrete(dims):
        # Distinct dimensions that are known integers.
        return {d for d in dims if isinstance(d, (int, Integer))}

    rows, cols = zip(*(x.shape for x in args))
    if len(_concrete(rows)) > 1:
        raise ShapeError(f"Matrices have mismatching shape: {rows}")
    if len(_concrete(cols)) > 1:
        raise ShapeError(f"Matrices have mismatching shape: {cols}")
def validate_matmul_integer(*args: MatrixExpr) -> None:
    """Check that adjacent factors of a matrix product are aligned.

    For each neighbouring pair, the column count of the left factor is
    compared with the row count of the right factor. The comparison is made
    only when both are concrete integers (``int`` or ``Integer``).

    Raises
    ======

    ShapeError
        If a neighbouring pair has concrete, unequal inner dimensions.
    """
    integer_types = (int, Integer)
    for left, right in zip(args, args[1:]):
        inner_left, inner_right = left.cols, right.rows
        both_concrete = (
            isinstance(inner_left, integer_types)
            and isinstance(inner_right, integer_types)
        )
        if not both_concrete:
            # A symbolic dimension cannot be decided here; skip this pair.
            continue
        if inner_left != inner_right:
            raise ShapeError("Matrices are not aligned", inner_left, inner_right)
|