# test_permutation.py (5.5 KB)
  1. from sympy.combinatorics import Permutation
  2. from sympy.core.expr import unchanged
  3. from sympy.matrices import Matrix
  4. from sympy.matrices.expressions import \
  5. MatMul, BlockDiagMatrix, Determinant, Inverse
  6. from sympy.matrices.expressions.matexpr import MatrixSymbol
  7. from sympy.matrices.expressions.special import ZeroMatrix, OneMatrix, Identity
  8. from sympy.matrices.expressions.permutation import \
  9. MatrixPermute, PermutationMatrix
  10. from sympy.testing.pytest import raises
  11. from sympy.core.symbol import Symbol
  12. def test_PermutationMatrix_basic():
  13. p = Permutation([1, 0])
  14. assert unchanged(PermutationMatrix, p)
  15. raises(ValueError, lambda: PermutationMatrix((0, 1, 2)))
  16. assert PermutationMatrix(p).as_explicit() == Matrix([[0, 1], [1, 0]])
  17. assert isinstance(PermutationMatrix(p)*MatrixSymbol('A', 2, 2), MatMul)
  18. def test_PermutationMatrix_matmul():
  19. p = Permutation([1, 2, 0])
  20. P = PermutationMatrix(p)
  21. M = Matrix([[0, 1, 2], [3, 4, 5], [6, 7, 8]])
  22. assert (P*M).as_explicit() == P.as_explicit()*M
  23. assert (M*P).as_explicit() == M*P.as_explicit()
  24. P1 = PermutationMatrix(Permutation([1, 2, 0]))
  25. P2 = PermutationMatrix(Permutation([2, 1, 0]))
  26. P3 = PermutationMatrix(Permutation([1, 0, 2]))
  27. assert P1*P2 == P3
  28. def test_PermutationMatrix_matpow():
  29. p1 = Permutation([1, 2, 0])
  30. P1 = PermutationMatrix(p1)
  31. p2 = Permutation([2, 0, 1])
  32. P2 = PermutationMatrix(p2)
  33. assert P1**2 == P2
  34. assert P1**3 == Identity(3)
  35. def test_PermutationMatrix_identity():
  36. p = Permutation([0, 1])
  37. assert PermutationMatrix(p).is_Identity
  38. p = Permutation([1, 0])
  39. assert not PermutationMatrix(p).is_Identity
  40. def test_PermutationMatrix_determinant():
  41. P = PermutationMatrix(Permutation([0, 1, 2]))
  42. assert Determinant(P).doit() == 1
  43. P = PermutationMatrix(Permutation([0, 2, 1]))
  44. assert Determinant(P).doit() == -1
  45. P = PermutationMatrix(Permutation([2, 0, 1]))
  46. assert Determinant(P).doit() == 1
  47. def test_PermutationMatrix_inverse():
  48. P = PermutationMatrix(Permutation(0, 1, 2))
  49. assert Inverse(P).doit() == PermutationMatrix(Permutation(0, 2, 1))
  50. def test_PermutationMatrix_rewrite_BlockDiagMatrix():
  51. P = PermutationMatrix(Permutation([0, 1, 2, 3, 4, 5]))
  52. P0 = PermutationMatrix(Permutation([0]))
  53. assert P.rewrite(BlockDiagMatrix) == \
  54. BlockDiagMatrix(P0, P0, P0, P0, P0, P0)
  55. P = PermutationMatrix(Permutation([0, 1, 3, 2, 4, 5]))
  56. P10 = PermutationMatrix(Permutation(0, 1))
  57. assert P.rewrite(BlockDiagMatrix) == \
  58. BlockDiagMatrix(P0, P0, P10, P0, P0)
  59. P = PermutationMatrix(Permutation([1, 0, 3, 2, 5, 4]))
  60. assert P.rewrite(BlockDiagMatrix) == \
  61. BlockDiagMatrix(P10, P10, P10)
  62. P = PermutationMatrix(Permutation([0, 4, 3, 2, 1, 5]))
  63. P3210 = PermutationMatrix(Permutation([3, 2, 1, 0]))
  64. assert P.rewrite(BlockDiagMatrix) == \
  65. BlockDiagMatrix(P0, P3210, P0)
  66. P = PermutationMatrix(Permutation([0, 4, 2, 3, 1, 5]))
  67. P3120 = PermutationMatrix(Permutation([3, 1, 2, 0]))
  68. assert P.rewrite(BlockDiagMatrix) == \
  69. BlockDiagMatrix(P0, P3120, P0)
  70. P = PermutationMatrix(Permutation(0, 3)(1, 4)(2, 5))
  71. assert P.rewrite(BlockDiagMatrix) == BlockDiagMatrix(P)
  72. def test_MartrixPermute_basic():
  73. p = Permutation(0, 1)
  74. P = PermutationMatrix(p)
  75. A = MatrixSymbol('A', 2, 2)
  76. raises(ValueError, lambda: MatrixPermute(Symbol('x'), p))
  77. raises(ValueError, lambda: MatrixPermute(A, Symbol('x')))
  78. assert MatrixPermute(A, P) == MatrixPermute(A, p)
  79. raises(ValueError, lambda: MatrixPermute(A, p, 2))
  80. pp = Permutation(0, 1, size=3)
  81. assert MatrixPermute(A, pp) == MatrixPermute(A, p)
  82. pp = Permutation(0, 1, 2)
  83. raises(ValueError, lambda: MatrixPermute(A, pp))
  84. def test_MatrixPermute_shape():
  85. p = Permutation(0, 1)
  86. A = MatrixSymbol('A', 2, 3)
  87. assert MatrixPermute(A, p).shape == (2, 3)
  88. def test_MatrixPermute_explicit():
  89. p = Permutation(0, 1, 2)
  90. A = MatrixSymbol('A', 3, 3)
  91. AA = A.as_explicit()
  92. assert MatrixPermute(A, p, 0).as_explicit() == \
  93. AA.permute(p, orientation='rows')
  94. assert MatrixPermute(A, p, 1).as_explicit() == \
  95. AA.permute(p, orientation='cols')
  96. def test_MatrixPermute_rewrite_MatMul():
  97. p = Permutation(0, 1, 2)
  98. A = MatrixSymbol('A', 3, 3)
  99. assert MatrixPermute(A, p, 0).rewrite(MatMul).as_explicit() == \
  100. MatrixPermute(A, p, 0).as_explicit()
  101. assert MatrixPermute(A, p, 1).rewrite(MatMul).as_explicit() == \
  102. MatrixPermute(A, p, 1).as_explicit()
  103. def test_MatrixPermute_doit():
  104. p = Permutation(0, 1, 2)
  105. A = MatrixSymbol('A', 3, 3)
  106. assert MatrixPermute(A, p).doit() == MatrixPermute(A, p)
  107. p = Permutation(0, size=3)
  108. A = MatrixSymbol('A', 3, 3)
  109. assert MatrixPermute(A, p).doit().as_explicit() == \
  110. MatrixPermute(A, p).as_explicit()
  111. p = Permutation(0, 1, 2)
  112. A = Identity(3)
  113. assert MatrixPermute(A, p, 0).doit().as_explicit() == \
  114. MatrixPermute(A, p, 0).as_explicit()
  115. assert MatrixPermute(A, p, 1).doit().as_explicit() == \
  116. MatrixPermute(A, p, 1).as_explicit()
  117. A = ZeroMatrix(3, 3)
  118. assert MatrixPermute(A, p).doit() == A
  119. A = OneMatrix(3, 3)
  120. assert MatrixPermute(A, p).doit() == A
  121. A = MatrixSymbol('A', 4, 4)
  122. p1 = Permutation(0, 1, 2, 3)
  123. p2 = Permutation(0, 2, 3, 1)
  124. expr = MatrixPermute(MatrixPermute(A, p1, 0), p2, 0)
  125. assert expr.as_explicit() == expr.doit().as_explicit()
  126. expr = MatrixPermute(MatrixPermute(A, p1, 1), p2, 1)
  127. assert expr.as_explicit() == expr.doit().as_explicit()