# test_matrix_distributions.py
  1. from sympy.concrete.products import Product
  2. from sympy.core.numbers import pi
  3. from sympy.core.singleton import S
  4. from sympy.core.symbol import (Dummy, symbols)
  5. from sympy.functions.elementary.exponential import exp
  6. from sympy.functions.elementary.miscellaneous import sqrt
  7. from sympy.functions.special.gamma_functions import gamma
  8. from sympy.matrices import Determinant, Matrix, Trace, MatrixSymbol, MatrixSet
  9. from sympy.stats import density, sample
  10. from sympy.stats.matrix_distributions import (MatrixGammaDistribution,
  11. MatrixGamma, MatrixPSpace, Wishart, MatrixNormal, MatrixStudentT)
  12. from sympy.testing.pytest import raises, skip
  13. from sympy.external import import_module
  14. def test_MatrixPSpace():
  15. M = MatrixGammaDistribution(1, 2, [[2, 1], [1, 2]])
  16. MP = MatrixPSpace('M', M, 2, 2)
  17. assert MP.distribution == M
  18. raises(ValueError, lambda: MatrixPSpace('M', M, 1.2, 2))
  19. def test_MatrixGamma():
  20. M = MatrixGamma('M', 1, 2, [[1, 0], [0, 1]])
  21. assert M.pspace.distribution.set == MatrixSet(2, 2, S.Reals)
  22. assert isinstance(density(M), MatrixGammaDistribution)
  23. X = MatrixSymbol('X', 2, 2)
  24. num = exp(Trace(Matrix([[-S(1)/2, 0], [0, -S(1)/2]])*X))
  25. assert density(M)(X).doit() == num/(4*pi*sqrt(Determinant(X)))
  26. assert density(M)([[2, 1], [1, 2]]).doit() == sqrt(3)*exp(-2)/(12*pi)
  27. X = MatrixSymbol('X', 1, 2)
  28. Y = MatrixSymbol('Y', 1, 2)
  29. assert density(M)([X, Y]).doit() == exp(-X[0, 0]/2 - Y[0, 1]/2)/(4*pi*sqrt(
  30. X[0, 0]*Y[0, 1] - X[0, 1]*Y[0, 0]))
  31. # symbolic
  32. a, b = symbols('a b', positive=True)
  33. d = symbols('d', positive=True, integer=True)
  34. Y = MatrixSymbol('Y', d, d)
  35. Z = MatrixSymbol('Z', 2, 2)
  36. SM = MatrixSymbol('SM', d, d)
  37. M2 = MatrixGamma('M2', a, b, SM)
  38. M3 = MatrixGamma('M3', 2, 3, [[2, 1], [1, 2]])
  39. k = Dummy('k')
  40. exprd = pi**(-d*(d - 1)/4)*b**(-a*d)*exp(Trace((-1/b)*SM**(-1)*Y)
  41. )*Determinant(SM)**(-a)*Determinant(Y)**(a - d/2 - S(1)/2)/Product(
  42. gamma(-k/2 + a + S(1)/2), (k, 1, d))
  43. assert density(M2)(Y).dummy_eq(exprd)
  44. raises(NotImplementedError, lambda: density(M3 + M)(Z))
  45. raises(ValueError, lambda: density(M)(1))
  46. raises(ValueError, lambda: MatrixGamma('M', -1, 2, [[1, 0], [0, 1]]))
  47. raises(ValueError, lambda: MatrixGamma('M', -1, -2, [[1, 0], [0, 1]]))
  48. raises(ValueError, lambda: MatrixGamma('M', -1, 2, [[1, 0], [2, 1]]))
  49. raises(ValueError, lambda: MatrixGamma('M', -1, 2, [[1, 0], [0]]))
  50. def test_Wishart():
  51. W = Wishart('W', 5, [[1, 0], [0, 1]])
  52. assert W.pspace.distribution.set == MatrixSet(2, 2, S.Reals)
  53. X = MatrixSymbol('X', 2, 2)
  54. term1 = exp(Trace(Matrix([[-S(1)/2, 0], [0, -S(1)/2]])*X))
  55. assert density(W)(X).doit() == term1 * Determinant(X)/(24*pi)
  56. assert density(W)([[2, 1], [1, 2]]).doit() == exp(-2)/(8*pi)
  57. n = symbols('n', positive=True)
  58. d = symbols('d', positive=True, integer=True)
  59. Y = MatrixSymbol('Y', d, d)
  60. SM = MatrixSymbol('SM', d, d)
  61. W = Wishart('W', n, SM)
  62. k = Dummy('k')
  63. exprd = 2**(-d*n/2)*pi**(-d*(d - 1)/4)*exp(Trace(-(S(1)/2)*SM**(-1)*Y)
  64. )*Determinant(SM)**(-n/2)*Determinant(Y)**(
  65. -d/2 + n/2 - S(1)/2)/Product(gamma(-k/2 + n/2 + S(1)/2), (k, 1, d))
  66. assert density(W)(Y).dummy_eq(exprd)
  67. raises(ValueError, lambda: density(W)(1))
  68. raises(ValueError, lambda: Wishart('W', -1, [[1, 0], [0, 1]]))
  69. raises(ValueError, lambda: Wishart('W', -1, [[1, 0], [2, 1]]))
  70. raises(ValueError, lambda: Wishart('W', 2, [[1, 0], [0]]))
  71. def test_MatrixNormal():
  72. M = MatrixNormal('M', [[5, 6]], [4], [[2, 1], [1, 2]])
  73. assert M.pspace.distribution.set == MatrixSet(1, 2, S.Reals)
  74. X = MatrixSymbol('X', 1, 2)
  75. term1 = exp(-Trace(Matrix([[ S(2)/3, -S(1)/3], [-S(1)/3, S(2)/3]])*(
  76. Matrix([[-5], [-6]]) + X.T)*Matrix([[S(1)/4]])*(Matrix([[-5, -6]]) + X))/2)
  77. assert density(M)(X).doit() == (sqrt(3)) * term1/(24*pi)
  78. assert density(M)([[7, 8]]).doit() == sqrt(3)*exp(-S(1)/3)/(24*pi)
  79. d, n = symbols('d n', positive=True, integer=True)
  80. SM2 = MatrixSymbol('SM2', d, d)
  81. SM1 = MatrixSymbol('SM1', n, n)
  82. LM = MatrixSymbol('LM', n, d)
  83. Y = MatrixSymbol('Y', n, d)
  84. M = MatrixNormal('M', LM, SM1, SM2)
  85. exprd = (2*pi)**(-d*n/2)*exp(-Trace(SM2**(-1)*(-LM.T + Y.T)*SM1**(-1)*(-LM + Y)
  86. )/2)*Determinant(SM1)**(-d/2)*Determinant(SM2)**(-n/2)
  87. assert density(M)(Y).doit() == exprd
  88. raises(ValueError, lambda: density(M)(1))
  89. raises(ValueError, lambda: MatrixNormal('M', [1, 2], [[1, 0], [0, 1]], [[1, 0], [2, 1]]))
  90. raises(ValueError, lambda: MatrixNormal('M', [1, 2], [[1, 0], [2, 1]], [[1, 0], [0, 1]]))
  91. raises(ValueError, lambda: MatrixNormal('M', [1, 2], [[1, 0], [0, 1]], [[1, 0], [0, 1]]))
  92. raises(ValueError, lambda: MatrixNormal('M', [1, 2], [[1, 0], [2]], [[1, 0], [0, 1]]))
  93. raises(ValueError, lambda: MatrixNormal('M', [1, 2], [[1, 0], [2, 1]], [[1, 0], [0]]))
  94. raises(ValueError, lambda: MatrixNormal('M', [[1, 2]], [[1, 0], [0, 1]], [[1, 0]]))
  95. raises(ValueError, lambda: MatrixNormal('M', [[1, 2]], [1], [[1, 0]]))
  96. def test_MatrixStudentT():
  97. M = MatrixStudentT('M', 2, [[5, 6]], [[2, 1], [1, 2]], [4])
  98. assert M.pspace.distribution.set == MatrixSet(1, 2, S.Reals)
  99. X = MatrixSymbol('X', 1, 2)
  100. D = pi ** (-1.0) * Determinant(Matrix([[4]])) ** (-1.0) * Determinant(Matrix([[2, 1], [1, 2]])) \
  101. ** (-0.5) / Determinant(Matrix([[S(1) / 4]]) * (Matrix([[-5, -6]]) + X)
  102. * Matrix([[S(2) / 3, -S(1) / 3], [-S(1) / 3, S(2) / 3]]) * (
  103. Matrix([[-5], [-6]]) + X.T) + Matrix([[1]])) ** 2
  104. assert density(M)(X) == D
  105. v = symbols('v', positive=True)
  106. n, p = 1, 2
  107. Omega = MatrixSymbol('Omega', p, p)
  108. Sigma = MatrixSymbol('Sigma', n, n)
  109. Location = MatrixSymbol('Location', n, p)
  110. Y = MatrixSymbol('Y', n, p)
  111. M = MatrixStudentT('M', v, Location, Omega, Sigma)
  112. exprd = gamma(v/2 + 1)*Determinant(Matrix([[1]]) + Sigma**(-1)*(-Location + Y)*Omega**(-1)*(-Location.T + Y.T))**(-v/2 - 1) / \
  113. (pi*gamma(v/2)*sqrt(Determinant(Omega))*Determinant(Sigma))
  114. assert density(M)(Y) == exprd
  115. raises(ValueError, lambda: density(M)(1))
  116. raises(ValueError, lambda: MatrixStudentT('M', 1, [1, 2], [[1, 0], [0, 1]], [[1, 0], [2, 1]]))
  117. raises(ValueError, lambda: MatrixStudentT('M', 1, [1, 2], [[1, 0], [2, 1]], [[1, 0], [0, 1]]))
  118. raises(ValueError, lambda: MatrixStudentT('M', 1, [1, 2], [[1, 0], [0, 1]], [[1, 0], [0, 1]]))
  119. raises(ValueError, lambda: MatrixStudentT('M', 1, [1, 2], [[1, 0], [2]], [[1, 0], [0, 1]]))
  120. raises(ValueError, lambda: MatrixStudentT('M', 1, [1, 2], [[1, 0], [2, 1]], [[1], [2]]))
  121. raises(ValueError, lambda: MatrixStudentT('M', 1, [[1, 2]], [[1, 0], [0, 1]], [[1, 0]]))
  122. raises(ValueError, lambda: MatrixStudentT('M', 1, [[1, 2]], [1], [[1, 0]]))
  123. raises(ValueError, lambda: MatrixStudentT('M', -1, [1, 2], [[1, 0], [0, 1]], [4]))
  124. def test_sample_scipy():
  125. distribs_scipy = [
  126. MatrixNormal('M', [[5, 6]], [4], [[2, 1], [1, 2]]),
  127. Wishart('W', 5, [[1, 0], [0, 1]])
  128. ]
  129. size = 5
  130. scipy = import_module('scipy')
  131. if not scipy:
  132. skip('Scipy not installed. Abort tests for _sample_scipy.')
  133. else:
  134. for X in distribs_scipy:
  135. samps = sample(X, size=size)
  136. for sam in samps:
  137. assert Matrix(sam) in X.pspace.distribution.set
  138. M = MatrixGamma('M', 1, 2, [[1, 0], [0, 1]])
  139. raises(NotImplementedError, lambda: sample(M, size=3))
  140. def test_sample_pymc():
  141. distribs_pymc = [
  142. MatrixNormal('M', [[5, 6], [3, 4]], [[1, 0], [0, 1]], [[2, 1], [1, 2]]),
  143. Wishart('W', 7, [[2, 1], [1, 2]])
  144. ]
  145. size = 3
  146. pymc = import_module('pymc')
  147. if not pymc:
  148. skip('PyMC is not installed. Abort tests for _sample_pymc.')
  149. else:
  150. for X in distribs_pymc:
  151. samps = sample(X, size=size, library='pymc')
  152. for sam in samps:
  153. assert Matrix(sam) in X.pspace.distribution.set
  154. M = MatrixGamma('M', 1, 2, [[1, 0], [0, 1]])
  155. raises(NotImplementedError, lambda: sample(M, size=3))
  156. def test_sample_seed():
  157. X = MatrixNormal('M', [[5, 6], [3, 4]], [[1, 0], [0, 1]], [[2, 1], [1, 2]])
  158. libraries = ['scipy', 'numpy', 'pymc']
  159. for lib in libraries:
  160. try:
  161. imported_lib = import_module(lib)
  162. if imported_lib:
  163. s0, s1, s2 = [], [], []
  164. s0 = sample(X, size=10, library=lib, seed=0)
  165. s1 = sample(X, size=10, library=lib, seed=0)
  166. s2 = sample(X, size=10, library=lib, seed=1)
  167. for i in range(10):
  168. assert (s0[i] == s1[i]).all()
  169. assert (s1[i] != s2[i]).all()
  170. except NotImplementedError:
  171. continue