_shape.py 3.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102
  1. from sympy.core.relational import Eq
  2. from sympy.core.expr import Expr
  3. from sympy.core.numbers import Integer
  4. from sympy.logic.boolalg import Boolean, And
  5. from sympy.matrices.expressions.matexpr import MatrixExpr
  6. from sympy.matrices.common import ShapeError
  7. from typing import Union
  8. def is_matadd_valid(*args: MatrixExpr) -> Boolean:
  9. """Return the symbolic condition how ``MatAdd``, ``HadamardProduct``
  10. makes sense.
  11. Parameters
  12. ==========
  13. args
  14. The list of arguments of matrices to be tested for.
  15. Examples
  16. ========
  17. >>> from sympy import MatrixSymbol, symbols
  18. >>> from sympy.matrices.expressions._shape import is_matadd_valid
  19. >>> m, n, p, q = symbols('m n p q')
  20. >>> A = MatrixSymbol('A', m, n)
  21. >>> B = MatrixSymbol('B', p, q)
  22. >>> is_matadd_valid(A, B)
  23. Eq(m, p) & Eq(n, q)
  24. """
  25. rows, cols = zip(*(arg.shape for arg in args))
  26. return And(
  27. *(Eq(i, j) for i, j in zip(rows[:-1], rows[1:])),
  28. *(Eq(i, j) for i, j in zip(cols[:-1], cols[1:])),
  29. )
  30. def is_matmul_valid(*args: Union[MatrixExpr, Expr]) -> Boolean:
  31. """Return the symbolic condition how ``MatMul`` makes sense
  32. Parameters
  33. ==========
  34. args
  35. The list of arguments of matrices and scalar expressions to be tested
  36. for.
  37. Examples
  38. ========
  39. >>> from sympy import MatrixSymbol, symbols
  40. >>> from sympy.matrices.expressions._shape import is_matmul_valid
  41. >>> m, n, p, q = symbols('m n p q')
  42. >>> A = MatrixSymbol('A', m, n)
  43. >>> B = MatrixSymbol('B', p, q)
  44. >>> is_matmul_valid(A, B)
  45. Eq(n, p)
  46. """
  47. rows, cols = zip(*(arg.shape for arg in args if isinstance(arg, MatrixExpr)))
  48. return And(*(Eq(i, j) for i, j in zip(cols[:-1], rows[1:])))
  49. def is_square(arg: MatrixExpr, /) -> Boolean:
  50. """Return the symbolic condition how the matrix is assumed to be square
  51. Parameters
  52. ==========
  53. arg
  54. The matrix to be tested for.
  55. Examples
  56. ========
  57. >>> from sympy import MatrixSymbol, symbols
  58. >>> from sympy.matrices.expressions._shape import is_square
  59. >>> m, n = symbols('m n')
  60. >>> A = MatrixSymbol('A', m, n)
  61. >>> is_square(A)
  62. Eq(m, n)
  63. """
  64. return Eq(arg.rows, arg.cols)
  65. def validate_matadd_integer(*args: MatrixExpr) -> None:
  66. """Validate matrix shape for addition only for integer values"""
  67. rows, cols = zip(*(x.shape for x in args))
  68. if len(set(filter(lambda x: isinstance(x, (int, Integer)), rows))) > 1:
  69. raise ShapeError(f"Matrices have mismatching shape: {rows}")
  70. if len(set(filter(lambda x: isinstance(x, (int, Integer)), cols))) > 1:
  71. raise ShapeError(f"Matrices have mismatching shape: {cols}")
  72. def validate_matmul_integer(*args: MatrixExpr) -> None:
  73. """Validate matrix shape for multiplication only for integer values"""
  74. for A, B in zip(args[:-1], args[1:]):
  75. i, j = A.cols, B.rows
  76. if isinstance(i, (int, Integer)) and isinstance(j, (int, Integer)) and i != j:
  77. raise ShapeError("Matrices are not aligned", i, j)