# test_immutable.py (4.5 KB)
from itertools import product

from sympy.core.relational import (Equality, Unequality)
from sympy.core.singleton import S
from sympy.core.sympify import sympify
from sympy.integrals.integrals import integrate
from sympy.matrices.dense import (Matrix, eye, zeros)
from sympy.matrices.immutable import ImmutableMatrix
from sympy.matrices import SparseMatrix
from sympy.matrices.immutable import \
    ImmutableDenseMatrix, ImmutableSparseMatrix
from sympy.abc import x, y
from sympy.testing.pytest import raises
  13. IM = ImmutableDenseMatrix([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
  14. ISM = ImmutableSparseMatrix([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
  15. ieye = ImmutableDenseMatrix(eye(3))
  16. def test_creation():
  17. assert IM.shape == ISM.shape == (3, 3)
  18. assert IM[1, 2] == ISM[1, 2] == 6
  19. assert IM[2, 2] == ISM[2, 2] == 9
  20. def test_immutability():
  21. with raises(TypeError):
  22. IM[2, 2] = 5
  23. with raises(TypeError):
  24. ISM[2, 2] = 5
  25. def test_slicing():
  26. assert IM[1, :] == ImmutableDenseMatrix([[4, 5, 6]])
  27. assert IM[:2, :2] == ImmutableDenseMatrix([[1, 2], [4, 5]])
  28. assert ISM[1, :] == ImmutableSparseMatrix([[4, 5, 6]])
  29. assert ISM[:2, :2] == ImmutableSparseMatrix([[1, 2], [4, 5]])
  30. def test_subs():
  31. A = ImmutableMatrix([[1, 2], [3, 4]])
  32. B = ImmutableMatrix([[1, 2], [x, 4]])
  33. C = ImmutableMatrix([[-x, x*y], [-(x + y), y**2]])
  34. assert B.subs(x, 3) == A
  35. assert (x*B).subs(x, 3) == 3*A
  36. assert (x*eye(2) + B).subs(x, 3) == 3*eye(2) + A
  37. assert C.subs([[x, -1], [y, -2]]) == A
  38. assert C.subs([(x, -1), (y, -2)]) == A
  39. assert C.subs({x: -1, y: -2}) == A
  40. assert C.subs({x: y - 1, y: x - 1}, simultaneous=True) == \
  41. ImmutableMatrix([[1 - y, (x - 1)*(y - 1)], [2 - x - y, (x - 1)**2]])
  42. def test_as_immutable():
  43. data = [[1, 2], [3, 4]]
  44. X = Matrix(data)
  45. assert sympify(X) == X.as_immutable() == ImmutableMatrix(data)
  46. data = {(0, 0): 1, (0, 1): 2, (1, 0): 3, (1, 1): 4}
  47. X = SparseMatrix(2, 2, data)
  48. assert sympify(X) == X.as_immutable() == ImmutableSparseMatrix(2, 2, data)
  49. def test_function_return_types():
  50. # Lets ensure that decompositions of immutable matrices remain immutable
  51. # I.e. do MatrixBase methods return the correct class?
  52. X = ImmutableMatrix([[1, 2], [3, 4]])
  53. Y = ImmutableMatrix([[1], [0]])
  54. q, r = X.QRdecomposition()
  55. assert (type(q), type(r)) == (ImmutableMatrix, ImmutableMatrix)
  56. assert type(X.LUsolve(Y)) == ImmutableMatrix
  57. assert type(X.QRsolve(Y)) == ImmutableMatrix
  58. X = ImmutableMatrix([[5, 2], [2, 7]])
  59. assert X.T == X
  60. assert X.is_symmetric
  61. assert type(X.cholesky()) == ImmutableMatrix
  62. L, D = X.LDLdecomposition()
  63. assert (type(L), type(D)) == (ImmutableMatrix, ImmutableMatrix)
  64. X = ImmutableMatrix([[1, 2], [2, 1]])
  65. assert X.is_diagonalizable()
  66. assert X.det() == -3
  67. assert X.norm(2) == 3
  68. assert type(X.eigenvects()[0][2][0]) == ImmutableMatrix
  69. assert type(zeros(3, 3).as_immutable().nullspace()[0]) == ImmutableMatrix
  70. X = ImmutableMatrix([[1, 0], [2, 1]])
  71. assert type(X.lower_triangular_solve(Y)) == ImmutableMatrix
  72. assert type(X.T.upper_triangular_solve(Y)) == ImmutableMatrix
  73. assert type(X.minor_submatrix(0, 0)) == ImmutableMatrix
  74. # issue 6279
  75. # https://github.com/sympy/sympy/issues/6279
  76. # Test that Immutable _op_ Immutable => Immutable and not MatExpr
  77. def test_immutable_evaluation():
  78. X = ImmutableMatrix(eye(3))
  79. A = ImmutableMatrix(3, 3, range(9))
  80. assert isinstance(X + A, ImmutableMatrix)
  81. assert isinstance(X * A, ImmutableMatrix)
  82. assert isinstance(X * 2, ImmutableMatrix)
  83. assert isinstance(2 * X, ImmutableMatrix)
  84. assert isinstance(A**2, ImmutableMatrix)
  85. def test_deterimant():
  86. assert ImmutableMatrix(4, 4, lambda i, j: i + j).det() == 0
  87. def test_Equality():
  88. assert Equality(IM, IM) is S.true
  89. assert Unequality(IM, IM) is S.false
  90. assert Equality(IM, IM.subs(1, 2)) is S.false
  91. assert Unequality(IM, IM.subs(1, 2)) is S.true
  92. assert Equality(IM, 2) is S.false
  93. assert Unequality(IM, 2) is S.true
  94. M = ImmutableMatrix([x, y])
  95. assert Equality(M, IM) is S.false
  96. assert Unequality(M, IM) is S.true
  97. assert Equality(M, M.subs(x, 2)).subs(x, 2) is S.true
  98. assert Unequality(M, M.subs(x, 2)).subs(x, 2) is S.false
  99. assert Equality(M, M.subs(x, 2)).subs(x, 3) is S.false
  100. assert Unequality(M, M.subs(x, 2)).subs(x, 3) is S.true
  101. def test_integrate():
  102. intIM = integrate(IM, x)
  103. assert intIM.shape == IM.shape
  104. assert all([intIM[i, j] == (1 + j + 3*i)*x for i, j in
  105. product(range(3), range(3))])