# test_hadamard.py

  1. from sympy.matrices.dense import Matrix, eye
  2. from sympy.matrices.common import ShapeError
  3. from sympy.matrices.expressions.matadd import MatAdd
  4. from sympy.matrices.expressions.special import Identity, OneMatrix, ZeroMatrix
  5. from sympy.core import symbols
  6. from sympy.testing.pytest import raises, warns_deprecated_sympy
  7. from sympy.matrices import MatrixSymbol
  8. from sympy.matrices.expressions import (HadamardProduct, hadamard_product, HadamardPower, hadamard_power)
  9. n, m, k = symbols('n,m,k')
  10. Z = MatrixSymbol('Z', n, n)
  11. A = MatrixSymbol('A', n, m)
  12. B = MatrixSymbol('B', n, m)
  13. C = MatrixSymbol('C', m, k)
  14. def test_HadamardProduct():
  15. assert HadamardProduct(A, B, A).shape == A.shape
  16. raises(TypeError, lambda: HadamardProduct(A, n))
  17. raises(TypeError, lambda: HadamardProduct(A, 1))
  18. assert HadamardProduct(A, 2*B, -A)[1, 1] == \
  19. -2 * A[1, 1] * B[1, 1] * A[1, 1]
  20. mix = HadamardProduct(Z*A, B)*C
  21. assert mix.shape == (n, k)
  22. assert set(HadamardProduct(A, B, A).T.args) == {A.T, A.T, B.T}
  23. def test_HadamardProduct_isnt_commutative():
  24. assert HadamardProduct(A, B) != HadamardProduct(B, A)
  25. def test_mixed_indexing():
  26. X = MatrixSymbol('X', 2, 2)
  27. Y = MatrixSymbol('Y', 2, 2)
  28. Z = MatrixSymbol('Z', 2, 2)
  29. assert (X*HadamardProduct(Y, Z))[0, 0] == \
  30. X[0, 0]*Y[0, 0]*Z[0, 0] + X[0, 1]*Y[1, 0]*Z[1, 0]
  31. def test_canonicalize():
  32. X = MatrixSymbol('X', 2, 2)
  33. Y = MatrixSymbol('Y', 2, 2)
  34. with warns_deprecated_sympy():
  35. expr = HadamardProduct(X, check=False)
  36. assert isinstance(expr, HadamardProduct)
  37. expr2 = expr.doit() # unpack is called
  38. assert isinstance(expr2, MatrixSymbol)
  39. Z = ZeroMatrix(2, 2)
  40. U = OneMatrix(2, 2)
  41. assert HadamardProduct(Z, X).doit() == Z
  42. assert HadamardProduct(U, X, X, U).doit() == HadamardPower(X, 2)
  43. assert HadamardProduct(X, U, Y).doit() == HadamardProduct(X, Y)
  44. assert HadamardProduct(X, Z, U, Y).doit() == Z
  45. def test_hadamard():
  46. m, n, p = symbols('m, n, p', integer=True)
  47. A = MatrixSymbol('A', m, n)
  48. B = MatrixSymbol('B', m, n)
  49. X = MatrixSymbol('X', m, m)
  50. I = Identity(m)
  51. raises(TypeError, lambda: hadamard_product())
  52. assert hadamard_product(A) == A
  53. assert isinstance(hadamard_product(A, B), HadamardProduct)
  54. assert hadamard_product(A, B).doit() == hadamard_product(A, B)
  55. assert hadamard_product(X, I) == HadamardProduct(I, X)
  56. assert isinstance(hadamard_product(X, I), HadamardProduct)
  57. a = MatrixSymbol("a", k, 1)
  58. expr = MatAdd(ZeroMatrix(k, 1), OneMatrix(k, 1))
  59. expr = HadamardProduct(expr, a)
  60. assert expr.doit() == a
  61. raises(ValueError, lambda: HadamardProduct())
  62. def test_hadamard_product_with_explicit_mat():
  63. A = MatrixSymbol("A", 3, 3).as_explicit()
  64. B = MatrixSymbol("B", 3, 3).as_explicit()
  65. X = MatrixSymbol("X", 3, 3)
  66. expr = hadamard_product(A, B)
  67. ret = Matrix([i*j for i, j in zip(A, B)]).reshape(3, 3)
  68. assert expr == ret
  69. expr = hadamard_product(A, X, B)
  70. assert expr == HadamardProduct(ret, X)
  71. expr = hadamard_product(eye(3), A)
  72. assert expr == Matrix([[A[0, 0], 0, 0], [0, A[1, 1], 0], [0, 0, A[2, 2]]])
  73. expr = hadamard_product(eye(3), eye(3))
  74. assert expr == eye(3)
  75. def test_hadamard_power():
  76. m, n, p = symbols('m, n, p', integer=True)
  77. A = MatrixSymbol('A', m, n)
  78. assert hadamard_power(A, 1) == A
  79. assert isinstance(hadamard_power(A, 2), HadamardPower)
  80. assert hadamard_power(A, n).T == hadamard_power(A.T, n)
  81. assert hadamard_power(A, n)[0, 0] == A[0, 0]**n
  82. assert hadamard_power(m, n) == m**n
  83. raises(ValueError, lambda: hadamard_power(A, A))
  84. def test_hadamard_power_explicit():
  85. A = MatrixSymbol('A', 2, 2)
  86. B = MatrixSymbol('B', 2, 2)
  87. a, b = symbols('a b')
  88. assert HadamardPower(a, b) == a**b
  89. assert HadamardPower(a, B).as_explicit() == \
  90. Matrix([
  91. [a**B[0, 0], a**B[0, 1]],
  92. [a**B[1, 0], a**B[1, 1]]])
  93. assert HadamardPower(A, b).as_explicit() == \
  94. Matrix([
  95. [A[0, 0]**b, A[0, 1]**b],
  96. [A[1, 0]**b, A[1, 1]**b]])
  97. assert HadamardPower(A, B).as_explicit() == \
  98. Matrix([
  99. [A[0, 0]**B[0, 0], A[0, 1]**B[0, 1]],
  100. [A[1, 0]**B[1, 0], A[1, 1]**B[1, 1]]])
  101. def test_shape_error():
  102. A = MatrixSymbol('A', 2, 3)
  103. B = MatrixSymbol('B', 3, 3)
  104. raises(ShapeError, lambda: HadamardProduct(A, B))
  105. raises(ShapeError, lambda: HadamardPower(A, B))
  106. A = MatrixSymbol('A', 3, 2)
  107. raises(ShapeError, lambda: HadamardProduct(A, B))
  108. raises(ShapeError, lambda: HadamardPower(A, B))