test_diagonal.py 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156
  1. from sympy.matrices.expressions import MatrixSymbol
  2. from sympy.matrices.expressions.diagonal import DiagonalMatrix, DiagonalOf, DiagMatrix, diagonalize_vector
  3. from sympy.assumptions.ask import (Q, ask)
  4. from sympy.core.symbol import Symbol
  5. from sympy.functions.special.tensor_functions import KroneckerDelta
  6. from sympy.matrices.dense import Matrix
  7. from sympy.matrices.expressions.matmul import MatMul
  8. from sympy.matrices.expressions.special import Identity
  9. from sympy.testing.pytest import raises
  10. n = Symbol('n')
  11. m = Symbol('m')
  12. def test_DiagonalMatrix():
  13. x = MatrixSymbol('x', n, m)
  14. D = DiagonalMatrix(x)
  15. assert D.diagonal_length is None
  16. assert D.shape == (n, m)
  17. x = MatrixSymbol('x', n, n)
  18. D = DiagonalMatrix(x)
  19. assert D.diagonal_length == n
  20. assert D.shape == (n, n)
  21. assert D[1, 2] == 0
  22. assert D[1, 1] == x[1, 1]
  23. i = Symbol('i')
  24. j = Symbol('j')
  25. x = MatrixSymbol('x', 3, 3)
  26. ij = DiagonalMatrix(x)[i, j]
  27. assert ij != 0
  28. assert ij.subs({i:0, j:0}) == x[0, 0]
  29. assert ij.subs({i:0, j:1}) == 0
  30. assert ij.subs({i:1, j:1}) == x[1, 1]
  31. assert ask(Q.diagonal(D)) # affirm that D is diagonal
  32. x = MatrixSymbol('x', n, 3)
  33. D = DiagonalMatrix(x)
  34. assert D.diagonal_length == 3
  35. assert D.shape == (n, 3)
  36. assert D[2, m] == KroneckerDelta(2, m)*x[2, m]
  37. assert D[3, m] == 0
  38. raises(IndexError, lambda: D[m, 3])
  39. x = MatrixSymbol('x', 3, n)
  40. D = DiagonalMatrix(x)
  41. assert D.diagonal_length == 3
  42. assert D.shape == (3, n)
  43. assert D[m, 2] == KroneckerDelta(m, 2)*x[m, 2]
  44. assert D[m, 3] == 0
  45. raises(IndexError, lambda: D[3, m])
  46. x = MatrixSymbol('x', n, m)
  47. D = DiagonalMatrix(x)
  48. assert D.diagonal_length is None
  49. assert D.shape == (n, m)
  50. assert D[m, 4] != 0
  51. x = MatrixSymbol('x', 3, 4)
  52. assert [DiagonalMatrix(x)[i] for i in range(12)] == [
  53. x[0, 0], 0, 0, 0, 0, x[1, 1], 0, 0, 0, 0, x[2, 2], 0]
  54. # shape is retained, issue 12427
  55. assert (
  56. DiagonalMatrix(MatrixSymbol('x', 3, 4))*
  57. DiagonalMatrix(MatrixSymbol('x', 4, 2))).shape == (3, 2)
  58. def test_DiagonalOf():
  59. x = MatrixSymbol('x', n, n)
  60. d = DiagonalOf(x)
  61. assert d.shape == (n, 1)
  62. assert d.diagonal_length == n
  63. assert d[2, 0] == d[2] == x[2, 2]
  64. x = MatrixSymbol('x', n, m)
  65. d = DiagonalOf(x)
  66. assert d.shape == (None, 1)
  67. assert d.diagonal_length is None
  68. assert d[2, 0] == d[2] == x[2, 2]
  69. d = DiagonalOf(MatrixSymbol('x', 4, 3))
  70. assert d.shape == (3, 1)
  71. d = DiagonalOf(MatrixSymbol('x', n, 3))
  72. assert d.shape == (3, 1)
  73. d = DiagonalOf(MatrixSymbol('x', 3, n))
  74. assert d.shape == (3, 1)
  75. x = MatrixSymbol('x', n, m)
  76. assert [DiagonalOf(x)[i] for i in range(4)] ==[
  77. x[0, 0], x[1, 1], x[2, 2], x[3, 3]]
  78. def test_DiagMatrix():
  79. x = MatrixSymbol('x', n, 1)
  80. d = DiagMatrix(x)
  81. assert d.shape == (n, n)
  82. assert d[0, 1] == 0
  83. assert d[0, 0] == x[0, 0]
  84. a = MatrixSymbol('a', 1, 1)
  85. d = diagonalize_vector(a)
  86. assert isinstance(d, MatrixSymbol)
  87. assert a == d
  88. assert diagonalize_vector(Identity(3)) == Identity(3)
  89. assert DiagMatrix(Identity(3)).doit() == Identity(3)
  90. assert isinstance(DiagMatrix(Identity(3)), DiagMatrix)
  91. # A diagonal matrix is equal to its transpose:
  92. assert DiagMatrix(x).T == DiagMatrix(x)
  93. assert diagonalize_vector(x.T) == DiagMatrix(x)
  94. dx = DiagMatrix(x)
  95. assert dx[0, 0] == x[0, 0]
  96. assert dx[1, 1] == x[1, 0]
  97. assert dx[0, 1] == 0
  98. assert dx[0, m] == x[0, 0]*KroneckerDelta(0, m)
  99. z = MatrixSymbol('z', 1, n)
  100. dz = DiagMatrix(z)
  101. assert dz[0, 0] == z[0, 0]
  102. assert dz[1, 1] == z[0, 1]
  103. assert dz[0, 1] == 0
  104. assert dz[0, m] == z[0, m]*KroneckerDelta(0, m)
  105. v = MatrixSymbol('v', 3, 1)
  106. dv = DiagMatrix(v)
  107. assert dv.as_explicit() == Matrix([
  108. [v[0, 0], 0, 0],
  109. [0, v[1, 0], 0],
  110. [0, 0, v[2, 0]],
  111. ])
  112. v = MatrixSymbol('v', 1, 3)
  113. dv = DiagMatrix(v)
  114. assert dv.as_explicit() == Matrix([
  115. [v[0, 0], 0, 0],
  116. [0, v[0, 1], 0],
  117. [0, 0, v[0, 2]],
  118. ])
  119. dv = DiagMatrix(3*v)
  120. assert dv.args == (3*v,)
  121. assert dv.doit() == 3*DiagMatrix(v)
  122. assert isinstance(dv.doit(), MatMul)
  123. a = MatrixSymbol("a", 3, 1).as_explicit()
  124. expr = DiagMatrix(a)
  125. result = Matrix([
  126. [a[0, 0], 0, 0],
  127. [0, a[1, 0], 0],
  128. [0, 0, a[2, 0]],
  129. ])
  130. assert expr.doit() == result
  131. expr = DiagMatrix(a.T)
  132. assert expr.doit() == result