test_joint_rv.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436
  1. from sympy.concrete.products import Product
  2. from sympy.concrete.summations import Sum
  3. from sympy.core.numbers import (Rational, oo, pi)
  4. from sympy.core.relational import Eq
  5. from sympy.core.singleton import S
  6. from sympy.core.symbol import symbols
  7. from sympy.functions.combinatorial.factorials import (RisingFactorial, factorial)
  8. from sympy.functions.elementary.complexes import polar_lift
  9. from sympy.functions.elementary.exponential import exp
  10. from sympy.functions.elementary.miscellaneous import sqrt
  11. from sympy.functions.elementary.piecewise import Piecewise
  12. from sympy.functions.special.bessel import besselk
  13. from sympy.functions.special.gamma_functions import gamma
  14. from sympy.matrices.dense import eye
  15. from sympy.matrices.expressions.determinant import Determinant
  16. from sympy.sets.fancysets import Range
  17. from sympy.sets.sets import (Interval, ProductSet)
  18. from sympy.simplify.simplify import simplify
  19. from sympy.tensor.indexed import (Indexed, IndexedBase)
  20. from sympy.core.numbers import comp
  21. from sympy.integrals.integrals import integrate
  22. from sympy.matrices import Matrix, MatrixSymbol
  23. from sympy.matrices.expressions.matexpr import MatrixElement
  24. from sympy.stats import density, median, marginal_distribution, Normal, Laplace, E, sample
  25. from sympy.stats.joint_rv_types import (JointRV, MultivariateNormalDistribution,
  26. JointDistributionHandmade, MultivariateT, NormalGamma,
  27. GeneralizedMultivariateLogGammaOmega as GMVLGO, MultivariateBeta,
  28. GeneralizedMultivariateLogGamma as GMVLG, MultivariateEwens,
  29. Multinomial, NegativeMultinomial, MultivariateNormal,
  30. MultivariateLaplace)
  31. from sympy.testing.pytest import raises, XFAIL, skip, slow
  32. from sympy.external import import_module
  33. from sympy.abc import x, y
  34. def test_Normal():
  35. m = Normal('A', [1, 2], [[1, 0], [0, 1]])
  36. A = MultivariateNormal('A', [1, 2], [[1, 0], [0, 1]])
  37. assert m == A
  38. assert density(m)(1, 2) == 1/(2*pi)
  39. assert m.pspace.distribution.set == ProductSet(S.Reals, S.Reals)
  40. raises (ValueError, lambda:m[2])
  41. n = Normal('B', [1, 2, 3], [[1, 0, 0], [0, 1, 0], [0, 0, 1]])
  42. p = Normal('C', Matrix([1, 2]), Matrix([[1, 0], [0, 1]]))
  43. assert density(m)(x, y) == density(p)(x, y)
  44. assert marginal_distribution(n, 0, 1)(1, 2) == 1/(2*pi)
  45. raises(ValueError, lambda: marginal_distribution(m))
  46. assert integrate(density(m)(x, y), (x, -oo, oo), (y, -oo, oo)).evalf() == 1.0
  47. N = Normal('N', [1, 2], [[x, 0], [0, y]])
  48. assert density(N)(0, 0) == exp(-((4*x + y)/(2*x*y)))/(2*pi*sqrt(x*y))
  49. raises (ValueError, lambda: Normal('M', [1, 2], [[1, 1], [1, -1]]))
  50. # symbolic
  51. n = symbols('n', integer=True, positive=True)
  52. mu = MatrixSymbol('mu', n, 1)
  53. sigma = MatrixSymbol('sigma', n, n)
  54. X = Normal('X', mu, sigma)
  55. assert density(X) == MultivariateNormalDistribution(mu, sigma)
  56. raises (NotImplementedError, lambda: median(m))
  57. # Below tests should work after issue #17267 is resolved
  58. # assert E(X) == mu
  59. # assert variance(X) == sigma
  60. # test symbolic multivariate normal densities
  61. n = 3
  62. Sg = MatrixSymbol('Sg', n, n)
  63. mu = MatrixSymbol('mu', n, 1)
  64. obs = MatrixSymbol('obs', n, 1)
  65. X = MultivariateNormal('X', mu, Sg)
  66. density_X = density(X)
  67. eval_a = density_X(obs).subs({Sg: eye(3),
  68. mu: Matrix([0, 0, 0]), obs: Matrix([0, 0, 0])}).doit()
  69. eval_b = density_X(0, 0, 0).subs({Sg: eye(3), mu: Matrix([0, 0, 0])}).doit()
  70. assert eval_a == sqrt(2)/(4*pi**Rational(3/2))
  71. assert eval_b == sqrt(2)/(4*pi**Rational(3/2))
  72. n = symbols('n', integer=True, positive=True)
  73. Sg = MatrixSymbol('Sg', n, n)
  74. mu = MatrixSymbol('mu', n, 1)
  75. obs = MatrixSymbol('obs', n, 1)
  76. X = MultivariateNormal('X', mu, Sg)
  77. density_X_at_obs = density(X)(obs)
  78. expected_density = MatrixElement(
  79. exp((S(1)/2) * (mu.T - obs.T) * Sg**(-1) * (-mu + obs)) / \
  80. sqrt((2*pi)**n * Determinant(Sg)), 0, 0)
  81. assert density_X_at_obs == expected_density
  82. def test_MultivariateTDist():
  83. t1 = MultivariateT('T', [0, 0], [[1, 0], [0, 1]], 2)
  84. assert(density(t1))(1, 1) == 1/(8*pi)
  85. assert t1.pspace.distribution.set == ProductSet(S.Reals, S.Reals)
  86. assert integrate(density(t1)(x, y), (x, -oo, oo), \
  87. (y, -oo, oo)).evalf() == 1.0
  88. raises(ValueError, lambda: MultivariateT('T', [1, 2], [[1, 1], [1, -1]], 1))
  89. t2 = MultivariateT('t2', [1, 2], [[x, 0], [0, y]], 1)
  90. assert density(t2)(1, 2) == 1/(2*pi*sqrt(x*y))
  91. def test_multivariate_laplace():
  92. raises(ValueError, lambda: Laplace('T', [1, 2], [[1, 2], [2, 1]]))
  93. L = Laplace('L', [1, 0], [[1, 0], [0, 1]])
  94. L2 = MultivariateLaplace('L2', [1, 0], [[1, 0], [0, 1]])
  95. assert density(L)(2, 3) == exp(2)*besselk(0, sqrt(39))/pi
  96. L1 = Laplace('L1', [1, 2], [[x, 0], [0, y]])
  97. assert density(L1)(0, 1) == \
  98. exp(2/y)*besselk(0, sqrt((2 + 4/y + 1/x)/y))/(pi*sqrt(x*y))
  99. assert L.pspace.distribution.set == ProductSet(S.Reals, S.Reals)
  100. assert L.pspace.distribution == L2.pspace.distribution
  101. def test_NormalGamma():
  102. ng = NormalGamma('G', 1, 2, 3, 4)
  103. assert density(ng)(1, 1) == 32*exp(-4)/sqrt(pi)
  104. assert ng.pspace.distribution.set == ProductSet(S.Reals, Interval(0, oo))
  105. raises(ValueError, lambda:NormalGamma('G', 1, 2, 3, -1))
  106. assert marginal_distribution(ng, 0)(1) == \
  107. 3*sqrt(10)*gamma(Rational(7, 4))/(10*sqrt(pi)*gamma(Rational(5, 4)))
  108. assert marginal_distribution(ng, y)(1) == exp(Rational(-1, 4))/128
  109. assert marginal_distribution(ng,[0,1])(x) == x**2*exp(-x/4)/128
  110. def test_GeneralizedMultivariateLogGammaDistribution():
  111. h = S.Half
  112. omega = Matrix([[1, h, h, h],
  113. [h, 1, h, h],
  114. [h, h, 1, h],
  115. [h, h, h, 1]])
  116. v, l, mu = (4, [1, 2, 3, 4], [1, 2, 3, 4])
  117. y_1, y_2, y_3, y_4 = symbols('y_1:5', real=True)
  118. delta = symbols('d', positive=True)
  119. G = GMVLGO('G', omega, v, l, mu)
  120. Gd = GMVLG('Gd', delta, v, l, mu)
  121. dend = ("d**4*Sum(4*24**(-n - 4)*(1 - d)**n*exp((n + 4)*(y_1 + 2*y_2 + 3*y_3 "
  122. "+ 4*y_4) - exp(y_1) - exp(2*y_2)/2 - exp(3*y_3)/3 - exp(4*y_4)/4)/"
  123. "(gamma(n + 1)*gamma(n + 4)**3), (n, 0, oo))")
  124. assert str(density(Gd)(y_1, y_2, y_3, y_4)) == dend
  125. den = ("5*2**(2/3)*5**(1/3)*Sum(4*24**(-n - 4)*(-2**(2/3)*5**(1/3)/4 + 1)**n*"
  126. "exp((n + 4)*(y_1 + 2*y_2 + 3*y_3 + 4*y_4) - exp(y_1) - exp(2*y_2)/2 - "
  127. "exp(3*y_3)/3 - exp(4*y_4)/4)/(gamma(n + 1)*gamma(n + 4)**3), (n, 0, oo))/64")
  128. assert str(density(G)(y_1, y_2, y_3, y_4)) == den
  129. marg = ("5*2**(2/3)*5**(1/3)*exp(4*y_1)*exp(-exp(y_1))*Integral(exp(-exp(4*G[3])"
  130. "/4)*exp(16*G[3])*Integral(exp(-exp(3*G[2])/3)*exp(12*G[2])*Integral(exp("
  131. "-exp(2*G[1])/2)*exp(8*G[1])*Sum((-1/4)**n*(-4 + 2**(2/3)*5**(1/3"
  132. "))**n*exp(n*y_1)*exp(2*n*G[1])*exp(3*n*G[2])*exp(4*n*G[3])/(24**n*gamma(n + 1)"
  133. "*gamma(n + 4)**3), (n, 0, oo)), (G[1], -oo, oo)), (G[2], -oo, oo)), (G[3]"
  134. ", -oo, oo))/5308416")
  135. assert str(marginal_distribution(G, G[0])(y_1)) == marg
  136. omega_f1 = Matrix([[1, h, h]])
  137. omega_f2 = Matrix([[1, h, h, h],
  138. [h, 1, 2, h],
  139. [h, h, 1, h],
  140. [h, h, h, 1]])
  141. omega_f3 = Matrix([[6, h, h, h],
  142. [h, 1, 2, h],
  143. [h, h, 1, h],
  144. [h, h, h, 1]])
  145. v_f = symbols("v_f", positive=False, real=True)
  146. l_f = [1, 2, v_f, 4]
  147. m_f = [v_f, 2, 3, 4]
  148. omega_f4 = Matrix([[1, h, h, h, h],
  149. [h, 1, h, h, h],
  150. [h, h, 1, h, h],
  151. [h, h, h, 1, h],
  152. [h, h, h, h, 1]])
  153. l_f1 = [1, 2, 3, 4, 5]
  154. omega_f5 = Matrix([[1]])
  155. mu_f5 = l_f5 = [1]
  156. raises(ValueError, lambda: GMVLGO('G', omega_f1, v, l, mu))
  157. raises(ValueError, lambda: GMVLGO('G', omega_f2, v, l, mu))
  158. raises(ValueError, lambda: GMVLGO('G', omega_f3, v, l, mu))
  159. raises(ValueError, lambda: GMVLGO('G', omega, v_f, l, mu))
  160. raises(ValueError, lambda: GMVLGO('G', omega, v, l_f, mu))
  161. raises(ValueError, lambda: GMVLGO('G', omega, v, l, m_f))
  162. raises(ValueError, lambda: GMVLGO('G', omega_f4, v, l, mu))
  163. raises(ValueError, lambda: GMVLGO('G', omega, v, l_f1, mu))
  164. raises(ValueError, lambda: GMVLGO('G', omega_f5, v, l_f5, mu_f5))
  165. raises(ValueError, lambda: GMVLG('G', Rational(3, 2), v, l, mu))
  166. def test_MultivariateBeta():
  167. a1, a2 = symbols('a1, a2', positive=True)
  168. a1_f, a2_f = symbols('a1, a2', positive=False, real=True)
  169. mb = MultivariateBeta('B', [a1, a2])
  170. mb_c = MultivariateBeta('C', a1, a2)
  171. assert density(mb)(1, 2) == S(2)**(a2 - 1)*gamma(a1 + a2)/\
  172. (gamma(a1)*gamma(a2))
  173. assert marginal_distribution(mb_c, 0)(3) == S(3)**(a1 - 1)*gamma(a1 + a2)/\
  174. (a2*gamma(a1)*gamma(a2))
  175. raises(ValueError, lambda: MultivariateBeta('b1', [a1_f, a2]))
  176. raises(ValueError, lambda: MultivariateBeta('b2', [a1, a2_f]))
  177. raises(ValueError, lambda: MultivariateBeta('b3', [0, 0]))
  178. raises(ValueError, lambda: MultivariateBeta('b4', [a1_f, a2_f]))
  179. assert mb.pspace.distribution.set == ProductSet(Interval(0, 1), Interval(0, 1))
  180. def test_MultivariateEwens():
  181. n, theta, i = symbols('n theta i', positive=True)
  182. # tests for integer dimensions
  183. theta_f = symbols('t_f', negative=True)
  184. a = symbols('a_1:4', positive = True, integer = True)
  185. ed = MultivariateEwens('E', 3, theta)
  186. assert density(ed)(a[0], a[1], a[2]) == Piecewise((6*2**(-a[1])*3**(-a[2])*
  187. theta**a[0]*theta**a[1]*theta**a[2]/
  188. (theta*(theta + 1)*(theta + 2)*
  189. factorial(a[0])*factorial(a[1])*
  190. factorial(a[2])), Eq(a[0] + 2*a[1] +
  191. 3*a[2], 3)), (0, True))
  192. assert marginal_distribution(ed, ed[1])(a[1]) == Piecewise((6*2**(-a[1])*
  193. theta**a[1]/((theta + 1)*
  194. (theta + 2)*factorial(a[1])),
  195. Eq(2*a[1] + 1, 3)), (0, True))
  196. raises(ValueError, lambda: MultivariateEwens('e1', 5, theta_f))
  197. assert ed.pspace.distribution.set == ProductSet(Range(0, 4, 1),
  198. Range(0, 2, 1), Range(0, 2, 1))
  199. # tests for symbolic dimensions
  200. eds = MultivariateEwens('E', n, theta)
  201. a = IndexedBase('a')
  202. j, k = symbols('j, k')
  203. den = Piecewise((factorial(n)*Product(theta**a[j]*(j + 1)**(-a[j])/
  204. factorial(a[j]), (j, 0, n - 1))/RisingFactorial(theta, n),
  205. Eq(n, Sum((k + 1)*a[k], (k, 0, n - 1)))), (0, True))
  206. assert density(eds)(a).dummy_eq(den)
  207. def test_Multinomial():
  208. n, x1, x2, x3, x4 = symbols('n, x1, x2, x3, x4', nonnegative=True, integer=True)
  209. p1, p2, p3, p4 = symbols('p1, p2, p3, p4', positive=True)
  210. p1_f, n_f = symbols('p1_f, n_f', negative=True)
  211. M = Multinomial('M', n, [p1, p2, p3, p4])
  212. C = Multinomial('C', 3, p1, p2, p3)
  213. f = factorial
  214. assert density(M)(x1, x2, x3, x4) == Piecewise((p1**x1*p2**x2*p3**x3*p4**x4*
  215. f(n)/(f(x1)*f(x2)*f(x3)*f(x4)),
  216. Eq(n, x1 + x2 + x3 + x4)), (0, True))
  217. assert marginal_distribution(C, C[0])(x1).subs(x1, 1) ==\
  218. 3*p1*p2**2 +\
  219. 6*p1*p2*p3 +\
  220. 3*p1*p3**2
  221. raises(ValueError, lambda: Multinomial('b1', 5, [p1, p2, p3, p1_f]))
  222. raises(ValueError, lambda: Multinomial('b2', n_f, [p1, p2, p3, p4]))
  223. raises(ValueError, lambda: Multinomial('b3', n, 0.5, 0.4, 0.3, 0.1))
  224. def test_NegativeMultinomial():
  225. k0, x1, x2, x3, x4 = symbols('k0, x1, x2, x3, x4', nonnegative=True, integer=True)
  226. p1, p2, p3, p4 = symbols('p1, p2, p3, p4', positive=True)
  227. p1_f = symbols('p1_f', negative=True)
  228. N = NegativeMultinomial('N', 4, [p1, p2, p3, p4])
  229. C = NegativeMultinomial('C', 4, 0.1, 0.2, 0.3)
  230. g = gamma
  231. f = factorial
  232. assert simplify(density(N)(x1, x2, x3, x4) -
  233. p1**x1*p2**x2*p3**x3*p4**x4*(-p1 - p2 - p3 - p4 + 1)**4*g(x1 + x2 +
  234. x3 + x4 + 4)/(6*f(x1)*f(x2)*f(x3)*f(x4))) is S.Zero
  235. assert comp(marginal_distribution(C, C[0])(1).evalf(), 0.33, .01)
  236. raises(ValueError, lambda: NegativeMultinomial('b1', 5, [p1, p2, p3, p1_f]))
  237. raises(ValueError, lambda: NegativeMultinomial('b2', k0, 0.5, 0.4, 0.3, 0.4))
  238. assert N.pspace.distribution.set == ProductSet(Range(0, oo, 1),
  239. Range(0, oo, 1), Range(0, oo, 1), Range(0, oo, 1))
  240. @slow
  241. def test_JointPSpace_marginal_distribution():
  242. T = MultivariateT('T', [0, 0], [[1, 0], [0, 1]], 2)
  243. got = marginal_distribution(T, T[1])(x)
  244. ans = sqrt(2)*(x**2/2 + 1)/(4*polar_lift(x**2/2 + 1)**(S(5)/2))
  245. assert got == ans, got
  246. assert integrate(marginal_distribution(T, 1)(x), (x, -oo, oo)) == 1
  247. t = MultivariateT('T', [0, 0, 0], [[1, 0, 0], [0, 1, 0], [0, 0, 1]], 3)
  248. assert comp(marginal_distribution(t, 0)(1).evalf(), 0.2, .01)
  249. def test_JointRV():
  250. x1, x2 = (Indexed('x', i) for i in (1, 2))
  251. pdf = exp(-x1**2/2 + x1 - x2**2/2 - S.Half)/(2*pi)
  252. X = JointRV('x', pdf)
  253. assert density(X)(1, 2) == exp(-2)/(2*pi)
  254. assert isinstance(X.pspace.distribution, JointDistributionHandmade)
  255. assert marginal_distribution(X, 0)(2) == sqrt(2)*exp(Rational(-1, 2))/(2*sqrt(pi))
  256. def test_expectation():
  257. m = Normal('A', [x, y], [[1, 0], [0, 1]])
  258. assert simplify(E(m[1])) == y
  259. @XFAIL
  260. def test_joint_vector_expectation():
  261. m = Normal('A', [x, y], [[1, 0], [0, 1]])
  262. assert E(m) == (x, y)
  263. def test_sample_numpy():
  264. distribs_numpy = [
  265. MultivariateNormal("M", [3, 4], [[2, 1], [1, 2]]),
  266. MultivariateBeta("B", [0.4, 5, 15, 50, 203]),
  267. Multinomial("N", 50, [0.3, 0.2, 0.1, 0.25, 0.15])
  268. ]
  269. size = 3
  270. numpy = import_module('numpy')
  271. if not numpy:
  272. skip('Numpy is not installed. Abort tests for _sample_numpy.')
  273. else:
  274. for X in distribs_numpy:
  275. samps = sample(X, size=size, library='numpy')
  276. for sam in samps:
  277. assert tuple(sam) in X.pspace.distribution.set
  278. N_c = NegativeMultinomial('N', 3, 0.1, 0.1, 0.1)
  279. raises(NotImplementedError, lambda: sample(N_c, library='numpy'))
  280. def test_sample_scipy():
  281. distribs_scipy = [
  282. MultivariateNormal("M", [0, 0], [[0.1, 0.025], [0.025, 0.1]]),
  283. MultivariateBeta("B", [0.4, 5, 15]),
  284. Multinomial("N", 8, [0.3, 0.2, 0.1, 0.4])
  285. ]
  286. size = 3
  287. scipy = import_module('scipy')
  288. if not scipy:
  289. skip('Scipy not installed. Abort tests for _sample_scipy.')
  290. else:
  291. for X in distribs_scipy:
  292. samps = sample(X, size=size)
  293. samps2 = sample(X, size=(2, 2))
  294. for sam in samps:
  295. assert tuple(sam) in X.pspace.distribution.set
  296. for i in range(2):
  297. for j in range(2):
  298. assert tuple(samps2[i][j]) in X.pspace.distribution.set
  299. N_c = NegativeMultinomial('N', 3, 0.1, 0.1, 0.1)
  300. raises(NotImplementedError, lambda: sample(N_c))
  301. def test_sample_pymc():
  302. distribs_pymc = [
  303. MultivariateNormal("M", [5, 2], [[1, 0], [0, 1]]),
  304. MultivariateBeta("B", [0.4, 5, 15]),
  305. Multinomial("N", 4, [0.3, 0.2, 0.1, 0.4])
  306. ]
  307. size = 3
  308. pymc = import_module('pymc')
  309. if not pymc:
  310. skip('PyMC is not installed. Abort tests for _sample_pymc.')
  311. else:
  312. for X in distribs_pymc:
  313. samps = sample(X, size=size, library='pymc')
  314. for sam in samps:
  315. assert tuple(sam.flatten()) in X.pspace.distribution.set
  316. N_c = NegativeMultinomial('N', 3, 0.1, 0.1, 0.1)
  317. raises(NotImplementedError, lambda: sample(N_c, library='pymc'))
  318. def test_sample_seed():
  319. x1, x2 = (Indexed('x', i) for i in (1, 2))
  320. pdf = exp(-x1**2/2 + x1 - x2**2/2 - S.Half)/(2*pi)
  321. X = JointRV('x', pdf)
  322. libraries = ['scipy', 'numpy', 'pymc']
  323. for lib in libraries:
  324. try:
  325. imported_lib = import_module(lib)
  326. if imported_lib:
  327. s0, s1, s2 = [], [], []
  328. s0 = sample(X, size=10, library=lib, seed=0)
  329. s1 = sample(X, size=10, library=lib, seed=0)
  330. s2 = sample(X, size=10, library=lib, seed=1)
  331. assert all(s0 == s1)
  332. assert all(s1 != s2)
  333. except NotImplementedError:
  334. continue
  335. #
  336. # XXX: This fails for pymc. Previously the test appeared to pass but that is
  337. # just because the library argument was not passed so the test always used
  338. # scipy.
  339. #
  340. def test_issue_21057():
  341. m = Normal("x", [0, 0], [[0, 0], [0, 0]])
  342. n = MultivariateNormal("x", [0, 0], [[0, 0], [0, 0]])
  343. p = Normal("x", [0, 0], [[0, 0], [0, 1]])
  344. assert m == n
  345. libraries = ('scipy', 'numpy') # , 'pymc') # <-- pymc fails
  346. for library in libraries:
  347. try:
  348. imported_lib = import_module(library)
  349. if imported_lib:
  350. s1 = sample(m, size=8, library=library)
  351. s2 = sample(n, size=8, library=library)
  352. s3 = sample(p, size=8, library=library)
  353. assert tuple(s1.flatten()) == tuple(s2.flatten())
  354. for s in s3:
  355. assert tuple(s.flatten()) in p.pspace.distribution.set
  356. except NotImplementedError:
  357. continue
  358. #
  359. # When this passes the pymc part can be uncommented in test_issue_21057 above
  360. # and this can be deleted.
  361. #
  362. @XFAIL
  363. def test_issue_21057_pymc():
  364. m = Normal("x", [0, 0], [[0, 0], [0, 0]])
  365. n = MultivariateNormal("x", [0, 0], [[0, 0], [0, 0]])
  366. p = Normal("x", [0, 0], [[0, 0], [0, 1]])
  367. assert m == n
  368. libraries = ('pymc',)
  369. for library in libraries:
  370. try:
  371. imported_lib = import_module(library)
  372. if imported_lib:
  373. s1 = sample(m, size=8, library=library)
  374. s2 = sample(n, size=8, library=library)
  375. s3 = sample(p, size=8, library=library)
  376. assert tuple(s1.flatten()) == tuple(s2.flatten())
  377. for s in s3:
  378. assert tuple(s.flatten()) in p.pspace.distribution.set
  379. except NotImplementedError:
  380. continue