test_interactions.py 2.0 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677
"""
We have a few different kinds of matrices:
Matrix, ImmutableMatrix, and MatrixExpr.
Here we test the extent to which they cooperate.
"""
  6. from sympy.core.symbol import symbols
  7. from sympy.matrices import (Matrix, MatrixSymbol, eye, Identity,
  8. ImmutableMatrix)
  9. from sympy.matrices.expressions import MatrixExpr, MatAdd
  10. from sympy.matrices.common import classof
  11. from sympy.testing.pytest import raises
  12. SM = MatrixSymbol('X', 3, 3)
  13. SV = MatrixSymbol('v', 3, 1)
  14. MM = Matrix([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
  15. IM = ImmutableMatrix([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
  16. meye = eye(3)
  17. imeye = ImmutableMatrix(eye(3))
  18. ideye = Identity(3)
  19. a, b, c = symbols('a,b,c')
  20. def test_IM_MM():
  21. assert isinstance(MM + IM, ImmutableMatrix)
  22. assert isinstance(IM + MM, ImmutableMatrix)
  23. assert isinstance(2*IM + MM, ImmutableMatrix)
  24. assert MM.equals(IM)
  25. def test_ME_MM():
  26. assert isinstance(Identity(3) + MM, MatrixExpr)
  27. assert isinstance(SM + MM, MatAdd)
  28. assert isinstance(MM + SM, MatAdd)
  29. assert (Identity(3) + MM)[1, 1] == 6
  30. def test_equality():
  31. a, b, c = Identity(3), eye(3), ImmutableMatrix(eye(3))
  32. for x in [a, b, c]:
  33. for y in [a, b, c]:
  34. assert x.equals(y)
  35. def test_matrix_symbol_MM():
  36. X = MatrixSymbol('X', 3, 3)
  37. Y = eye(3) + X
  38. assert Y[1, 1] == 1 + X[1, 1]
  39. def test_matrix_symbol_vector_matrix_multiplication():
  40. A = MM * SV
  41. B = IM * SV
  42. assert A == B
  43. C = (SV.T * MM.T).T
  44. assert B == C
  45. D = (SV.T * IM.T).T
  46. assert C == D
  47. def test_indexing_interactions():
  48. assert (a * IM)[1, 1] == 5*a
  49. assert (SM + IM)[1, 1] == SM[1, 1] + IM[1, 1]
  50. assert (SM * IM)[1, 1] == SM[1, 0]*IM[0, 1] + SM[1, 1]*IM[1, 1] + \
  51. SM[1, 2]*IM[2, 1]
  52. def test_classof():
  53. A = Matrix(3, 3, range(9))
  54. B = ImmutableMatrix(3, 3, range(9))
  55. C = MatrixSymbol('C', 3, 3)
  56. assert classof(A, A) == Matrix
  57. assert classof(B, B) == ImmutableMatrix
  58. assert classof(A, B) == ImmutableMatrix
  59. assert classof(B, A) == ImmutableMatrix
  60. raises(TypeError, lambda: classof(A, C))