# test_funcmatrix.py (listed size: 2.3 KB)
  1. from sympy.core import symbols, Lambda
  2. from sympy.functions import KroneckerDelta
  3. from sympy.matrices import Matrix
  4. from sympy.matrices.expressions import FunctionMatrix, MatrixExpr, Identity
  5. from sympy.testing.pytest import raises, warns
  6. from sympy.utilities.exceptions import SymPyDeprecationWarning
  def test_funcmatrix_creation():
      """Check which constructor arguments FunctionMatrix accepts or rejects.

      Covers the two dimension arguments (concrete and symbolic) and the
      third, function-like argument.
      """
      i, j, k = symbols('i j k')

      # Construction succeeds for 2 x 2 and for the empty 0 x 0 case.
      assert FunctionMatrix(2, 2, Lambda((i, j), 0))
      assert FunctionMatrix(0, 0, Lambda((i, j), 0))

      # A negative, float or complex value is rejected in either dimension.
      raises(ValueError, lambda: FunctionMatrix(-1, 0, Lambda((i, j), 0)))
      raises(ValueError, lambda: FunctionMatrix(2.0, 0, Lambda((i, j), 0)))
      raises(ValueError, lambda: FunctionMatrix(2j, 0, Lambda((i, j), 0)))
      raises(ValueError, lambda: FunctionMatrix(0, -1, Lambda((i, j), 0)))
      raises(ValueError, lambda: FunctionMatrix(0, 2.0, Lambda((i, j), 0)))
      raises(ValueError, lambda: FunctionMatrix(0, 2j, Lambda((i, j), 0)))

      # Rejected third arguments: Lambdas with one or three parameters,
      # a plain Python lambda, and a bare expression.
      raises(ValueError, lambda: FunctionMatrix(2, 2, Lambda(i, 0)))
      # NOTE(review): indentation was lost in the copy this was restored from;
      # only the plain-lambda check is placed under warns(), since that is the
      # statement the comment below describes -- confirm against upstream.
      with warns(SymPyDeprecationWarning, test_stacklevel=False):
          # This raises a deprecation warning from sympify()
          raises(ValueError, lambda: FunctionMatrix(2, 2, lambda i, j: 0))
      raises(ValueError, lambda: FunctionMatrix(2, 2, Lambda((i,), 0)))
      raises(ValueError, lambda: FunctionMatrix(2, 2, Lambda((i, j, k), 0)))
      raises(ValueError, lambda: FunctionMatrix(2, 2, i+j))

      # A string spelling of the lambda compares equal to the Lambda form.
      assert FunctionMatrix(2, 2, "lambda i, j: 0") == \
          FunctionMatrix(2, 2, Lambda((i, j), 0))

      # Passing KroneckerDelta itself yields the identity matrix, and the
      # stored third argument is equivalent to the explicit Lambda wrapper.
      m = FunctionMatrix(2, 2, KroneckerDelta)
      assert m.as_explicit() == Identity(2).as_explicit()
      assert m.args[2].dummy_eq(Lambda((i, j), KroneckerDelta(i, j)))

      # Symbolic dimensions: a symbol with no assumptions is accepted, while
      # one declared non-integer or negative is rejected.
      n = symbols('n')
      assert FunctionMatrix(n, n, Lambda((i, j), 0))
      n = symbols('n', integer=False)
      raises(ValueError, lambda: FunctionMatrix(n, n, Lambda((i, j), 0)))
      n = symbols('n', negative=True)
      raises(ValueError, lambda: FunctionMatrix(n, n, Lambda((i, j), 0)))
  def test_funcmatrix():
      """Check element access, shape and conversion of a 3 x 3 FunctionMatrix.

      The matrix under test has entries defined by ``i - j``.
      """
      i, j = symbols('i,j')
      X = FunctionMatrix(3, 3, Lambda((i, j), i - j))

      # Individual entries evaluate the Lambda at the given indices.
      assert X[1, 1] == 0
      assert X[1, 2] == -1

      # Shape information is exposed both as a tuple and as rows/cols.
      assert X.shape == (3, 3)
      assert X.rows == X.cols == 3

      # Converting to an explicit Matrix gives the same entries as building
      # the Matrix directly from an equivalent Python lambda.
      assert Matrix(X) == Matrix(3, 3, lambda i, j: i - j)

      # Arithmetic on the FunctionMatrix produces a MatrixExpr instance.
      assert isinstance(X*X + X, MatrixExpr)
  def test_replace_issue():
      """Check that an identity ``replace`` leaves a FunctionMatrix unchanged.

      The query matches every node and the replacement returns the node as-is,
      so the result is expected to compare equal to the original matrix.
      """
      X = FunctionMatrix(3, 3, KroneckerDelta)
      assert X.replace(lambda x: True, lambda x: x) == X