test_symbolic_multivariate.py 5.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172
  1. from sympy.stats import Expectation, Normal, Variance, Covariance
  2. from sympy.testing.pytest import raises
  3. from sympy.core.symbol import symbols
  4. from sympy.matrices.common import ShapeError
  5. from sympy.matrices.dense import Matrix
  6. from sympy.matrices.expressions.matexpr import MatrixSymbol
  7. from sympy.matrices.expressions.special import ZeroMatrix
  8. from sympy.stats.rv import RandomMatrixSymbol
  9. from sympy.stats.symbolic_multivariate_probability import (ExpectationMatrix,
  10. VarianceMatrix, CrossCovarianceMatrix)
  11. j, k = symbols("j,k")
  12. A = MatrixSymbol("A", k, k)
  13. B = MatrixSymbol("B", k, k)
  14. C = MatrixSymbol("C", k, k)
  15. D = MatrixSymbol("D", k, k)
  16. a = MatrixSymbol("a", k, 1)
  17. b = MatrixSymbol("b", k, 1)
  18. A2 = MatrixSymbol("A2", 2, 2)
  19. B2 = MatrixSymbol("B2", 2, 2)
  20. X = RandomMatrixSymbol("X", k, 1)
  21. Y = RandomMatrixSymbol("Y", k, 1)
  22. Z = RandomMatrixSymbol("Z", k, 1)
  23. W = RandomMatrixSymbol("W", k, 1)
  24. R = RandomMatrixSymbol("R", k, k)
  25. X2 = RandomMatrixSymbol("X2", 2, 1)
  26. normal = Normal("normal", 0, 1)
  27. m1 = Matrix([
  28. [1, j*Normal("normal2", 2, 1)],
  29. [normal, 0]
  30. ])
  31. def test_multivariate_expectation():
  32. expr = Expectation(a)
  33. assert expr == Expectation(a) == ExpectationMatrix(a)
  34. assert expr.expand() == a
  35. expr = Expectation(X)
  36. assert expr == Expectation(X) == ExpectationMatrix(X)
  37. assert expr.shape == (k, 1)
  38. assert expr.rows == k
  39. assert expr.cols == 1
  40. assert isinstance(expr, ExpectationMatrix)
  41. expr = Expectation(A*X + b)
  42. assert expr == ExpectationMatrix(A*X + b)
  43. assert expr.expand() == A*ExpectationMatrix(X) + b
  44. assert isinstance(expr, ExpectationMatrix)
  45. assert expr.shape == (k, 1)
  46. expr = Expectation(m1*X2)
  47. assert expr.expand() == expr
  48. expr = Expectation(A2*m1*B2*X2)
  49. assert expr.args[0].args == (A2, m1, B2, X2)
  50. assert expr.expand() == A2*ExpectationMatrix(m1*B2*X2)
  51. expr = Expectation((X + Y)*(X - Y).T)
  52. assert expr.expand() == ExpectationMatrix(X*X.T) - ExpectationMatrix(X*Y.T) +\
  53. ExpectationMatrix(Y*X.T) - ExpectationMatrix(Y*Y.T)
  54. expr = Expectation(A*X + B*Y)
  55. assert expr.expand() == A*ExpectationMatrix(X) + B*ExpectationMatrix(Y)
  56. assert Expectation(m1).doit() == Matrix([[1, 2*j], [0, 0]])
  57. x1 = Matrix([
  58. [Normal('N11', 11, 1), Normal('N12', 12, 1)],
  59. [Normal('N21', 21, 1), Normal('N22', 22, 1)]
  60. ])
  61. x2 = Matrix([
  62. [Normal('M11', 1, 1), Normal('M12', 2, 1)],
  63. [Normal('M21', 3, 1), Normal('M22', 4, 1)]
  64. ])
  65. assert Expectation(Expectation(x1 + x2)).doit(deep=False) == ExpectationMatrix(x1 + x2)
  66. assert Expectation(Expectation(x1 + x2)).doit() == Matrix([[12, 14], [24, 26]])
  67. def test_multivariate_variance():
  68. raises(ShapeError, lambda: Variance(A))
  69. expr = Variance(a)
  70. assert expr == Variance(a) == VarianceMatrix(a)
  71. assert expr.expand() == ZeroMatrix(k, k)
  72. expr = Variance(a.T)
  73. assert expr == Variance(a.T) == VarianceMatrix(a.T)
  74. assert expr.expand() == ZeroMatrix(k, k)
  75. expr = Variance(X)
  76. assert expr == Variance(X) == VarianceMatrix(X)
  77. assert expr.shape == (k, k)
  78. assert expr.rows == k
  79. assert expr.cols == k
  80. assert isinstance(expr, VarianceMatrix)
  81. expr = Variance(A*X)
  82. assert expr == VarianceMatrix(A*X)
  83. assert expr.expand() == A*VarianceMatrix(X)*A.T
  84. assert isinstance(expr, VarianceMatrix)
  85. assert expr.shape == (k, k)
  86. expr = Variance(A*B*X)
  87. assert expr.expand() == A*B*VarianceMatrix(X)*B.T*A.T
  88. expr = Variance(m1*X2)
  89. assert expr.expand() == expr
  90. expr = Variance(A2*m1*B2*X2)
  91. assert expr.args[0].args == (A2, m1, B2, X2)
  92. assert expr.expand() == expr
  93. expr = Variance(A*X + B*Y)
  94. assert expr.expand() == 2*A*CrossCovarianceMatrix(X, Y)*B.T +\
  95. A*VarianceMatrix(X)*A.T + B*VarianceMatrix(Y)*B.T
  96. def test_multivariate_crosscovariance():
  97. raises(ShapeError, lambda: Covariance(X, Y.T))
  98. raises(ShapeError, lambda: Covariance(X, A))
  99. expr = Covariance(a.T, b.T)
  100. assert expr.shape == (1, 1)
  101. assert expr.expand() == ZeroMatrix(1, 1)
  102. expr = Covariance(a, b)
  103. assert expr == Covariance(a, b) == CrossCovarianceMatrix(a, b)
  104. assert expr.expand() == ZeroMatrix(k, k)
  105. assert expr.shape == (k, k)
  106. assert expr.rows == k
  107. assert expr.cols == k
  108. assert isinstance(expr, CrossCovarianceMatrix)
  109. expr = Covariance(A*X + a, b)
  110. assert expr.expand() == ZeroMatrix(k, k)
  111. expr = Covariance(X, Y)
  112. assert isinstance(expr, CrossCovarianceMatrix)
  113. assert expr.expand() == expr
  114. expr = Covariance(X, X)
  115. assert isinstance(expr, CrossCovarianceMatrix)
  116. assert expr.expand() == VarianceMatrix(X)
  117. expr = Covariance(X + Y, Z)
  118. assert isinstance(expr, CrossCovarianceMatrix)
  119. assert expr.expand() == CrossCovarianceMatrix(X, Z) + CrossCovarianceMatrix(Y, Z)
  120. expr = Covariance(A*X, Y)
  121. assert isinstance(expr, CrossCovarianceMatrix)
  122. assert expr.expand() == A*CrossCovarianceMatrix(X, Y)
  123. expr = Covariance(X, B*Y)
  124. assert isinstance(expr, CrossCovarianceMatrix)
  125. assert expr.expand() == CrossCovarianceMatrix(X, Y)*B.T
  126. expr = Covariance(A*X + a, B.T*Y + b)
  127. assert isinstance(expr, CrossCovarianceMatrix)
  128. assert expr.expand() == A*CrossCovarianceMatrix(X, Y)*B
  129. expr = Covariance(A*X + B*Y + a, C.T*Z + D.T*W + b)
  130. assert isinstance(expr, CrossCovarianceMatrix)
  131. assert expr.expand() == A*CrossCovarianceMatrix(X, W)*D + A*CrossCovarianceMatrix(X, Z)*C \
  132. + B*CrossCovarianceMatrix(Y, W)*D + B*CrossCovarianceMatrix(Y, Z)*C