# test_compound_rv.py (6.1 KB)
  1. from sympy.concrete.summations import Sum
  2. from sympy.core.numbers import (oo, pi)
  3. from sympy.core.relational import Eq
  4. from sympy.core.singleton import S
  5. from sympy.core.symbol import symbols
  6. from sympy.functions.combinatorial.factorials import factorial
  7. from sympy.functions.elementary.exponential import exp
  8. from sympy.functions.elementary.miscellaneous import sqrt
  9. from sympy.functions.elementary.piecewise import Piecewise
  10. from sympy.functions.special.beta_functions import beta
  11. from sympy.functions.special.error_functions import erf
  12. from sympy.functions.special.gamma_functions import gamma
  13. from sympy.integrals.integrals import Integral
  14. from sympy.sets.sets import Interval
  15. from sympy.stats import (Normal, P, E, density, Gamma, Poisson, Rayleigh,
  16. variance, Bernoulli, Beta, Uniform, cdf)
  17. from sympy.stats.compound_rv import CompoundDistribution, CompoundPSpace
  18. from sympy.stats.crv_types import NormalDistribution
  19. from sympy.stats.drv_types import PoissonDistribution
  20. from sympy.stats.frv_types import BernoulliDistribution
  21. from sympy.testing.pytest import raises, ignore_warnings
  22. from sympy.stats.joint_rv_types import MultivariateNormalDistribution
  23. from sympy.abc import x
  24. # helpers for testing troublesome unevaluated expressions
  25. flat = lambda s: ''.join(str(s).split())
  26. streq = lambda *a: len(set(map(flat, a))) == 1
  27. assert streq(x, x)
  28. assert streq(x, 'x')
  29. assert not streq(x, x + 1)
  30. def test_normal_CompoundDist():
  31. X = Normal('X', 1, 2)
  32. Y = Normal('X', X, 4)
  33. assert density(Y)(x).simplify() == sqrt(10)*exp(-x**2/40 + x/20 - S(1)/40)/(20*sqrt(pi))
  34. assert E(Y) == 1 # it is always equal to mean of X
  35. assert P(Y > 1) == S(1)/2 # as 1 is the mean
  36. assert P(Y > 5).simplify() == S(1)/2 - erf(sqrt(10)/5)/2
  37. assert variance(Y) == variance(X) + 4**2 # 2**2 + 4**2
  38. # https://math.stackexchange.com/questions/1484451/
  39. # (Contains proof of E and variance computation)
  40. def test_poisson_CompoundDist():
  41. k, t, y = symbols('k t y', positive=True, real=True)
  42. G = Gamma('G', k, t)
  43. D = Poisson('P', G)
  44. assert density(D)(y).simplify() == t**y*(t + 1)**(-k - y)*gamma(k + y)/(gamma(k)*gamma(y + 1))
  45. # https://en.wikipedia.org/wiki/Negative_binomial_distribution#Gamma%E2%80%93Poisson_mixture
  46. assert E(D).simplify() == k*t # mean of NegativeBinomialDistribution
  47. def test_bernoulli_CompoundDist():
  48. X = Beta('X', 1, 2)
  49. Y = Bernoulli('Y', X)
  50. assert density(Y).dict == {0: S(2)/3, 1: S(1)/3}
  51. assert E(Y) == P(Eq(Y, 1)) == S(1)/3
  52. assert variance(Y) == S(2)/9
  53. assert cdf(Y) == {0: S(2)/3, 1: 1}
  54. # test issue 8128
  55. a = Bernoulli('a', S(1)/2)
  56. b = Bernoulli('b', a)
  57. assert density(b).dict == {0: S(1)/2, 1: S(1)/2}
  58. assert P(b > 0.5) == S(1)/2
  59. X = Uniform('X', 0, 1)
  60. Y = Bernoulli('Y', X)
  61. assert E(Y) == S(1)/2
  62. assert P(Eq(Y, 1)) == E(Y)
  63. def test_unevaluated_CompoundDist():
  64. # these tests need to be removed once they work with evaluation as they are currently not
  65. # evaluated completely in sympy.
  66. R = Rayleigh('R', 4)
  67. X = Normal('X', 3, R)
  68. ans = '''
  69. Piecewise(((-sqrt(pi)*sinh(x/4 - 3/4) + sqrt(pi)*cosh(x/4 - 3/4))/(
  70. 8*sqrt(pi)), Abs(arg(x - 3)) <= pi/4), (Integral(sqrt(2)*exp(-(x - 3)
  71. **2/(2*R**2))*exp(-R**2/32)/(32*sqrt(pi)), (R, 0, oo)), True))'''
  72. assert streq(density(X)(x), ans)
  73. expre = '''
  74. Integral(X*Integral(sqrt(2)*exp(-(X-3)**2/(2*R**2))*exp(-R**2/32)/(32*
  75. sqrt(pi)),(R,0,oo)),(X,-oo,oo))'''
  76. with ignore_warnings(UserWarning): ### TODO: Restore tests once warnings are removed
  77. assert streq(E(X, evaluate=False).rewrite(Integral), expre)
  78. X = Poisson('X', 1)
  79. Y = Poisson('Y', X)
  80. Z = Poisson('Z', Y)
  81. exprd = Sum(exp(-Y)*Y**x*Sum(exp(-1)*exp(-X)*X**Y/(factorial(X)*factorial(Y)
  82. ), (X, 0, oo))/factorial(x), (Y, 0, oo))
  83. assert density(Z)(x) == exprd
  84. N = Normal('N', 1, 2)
  85. M = Normal('M', 3, 4)
  86. D = Normal('D', M, N)
  87. exprd = '''
  88. Integral(sqrt(2)*exp(-(N-1)**2/8)*Integral(exp(-(x-M)**2/(2*N**2))*exp
  89. (-(M-3)**2/32)/(8*pi*N),(M,-oo,oo))/(4*sqrt(pi)),(N,-oo,oo))'''
  90. assert streq(density(D, evaluate=False)(x), exprd)
  91. def test_Compound_Distribution():
  92. X = Normal('X', 2, 4)
  93. N = NormalDistribution(X, 4)
  94. C = CompoundDistribution(N)
  95. assert C.is_Continuous
  96. assert C.set == Interval(-oo, oo)
  97. assert C.pdf(x, evaluate=True).simplify() == exp(-x**2/64 + x/16 - S(1)/16)/(8*sqrt(pi))
  98. assert not isinstance(CompoundDistribution(NormalDistribution(2, 3)),
  99. CompoundDistribution)
  100. M = MultivariateNormalDistribution([1, 2], [[2, 1], [1, 2]])
  101. raises(NotImplementedError, lambda: CompoundDistribution(M))
  102. X = Beta('X', 2, 4)
  103. B = BernoulliDistribution(X, 1, 0)
  104. C = CompoundDistribution(B)
  105. assert C.is_Finite
  106. assert C.set == {0, 1}
  107. y = symbols('y', negative=False, integer=True)
  108. assert C.pdf(y, evaluate=True) == Piecewise((S(1)/(30*beta(2, 4)), Eq(y, 0)),
  109. (S(1)/(60*beta(2, 4)), Eq(y, 1)), (0, True))
  110. k, t, z = symbols('k t z', positive=True, real=True)
  111. G = Gamma('G', k, t)
  112. X = PoissonDistribution(G)
  113. C = CompoundDistribution(X)
  114. assert C.is_Discrete
  115. assert C.set == S.Naturals0
  116. assert C.pdf(z, evaluate=True).simplify() == t**z*(t + 1)**(-k - z)*gamma(k \
  117. + z)/(gamma(k)*gamma(z + 1))
  118. def test_compound_pspace():
  119. X = Normal('X', 2, 4)
  120. Y = Normal('Y', 3, 6)
  121. assert not isinstance(Y.pspace, CompoundPSpace)
  122. N = NormalDistribution(1, 2)
  123. D = PoissonDistribution(3)
  124. B = BernoulliDistribution(0.2, 1, 0)
  125. pspace1 = CompoundPSpace('N', N)
  126. pspace2 = CompoundPSpace('D', D)
  127. pspace3 = CompoundPSpace('B', B)
  128. assert not isinstance(pspace1, CompoundPSpace)
  129. assert not isinstance(pspace2, CompoundPSpace)
  130. assert not isinstance(pspace3, CompoundPSpace)
  131. M = MultivariateNormalDistribution([1, 2], [[2, 1], [1, 2]])
  132. raises(ValueError, lambda: CompoundPSpace('M', M))
  133. Y = Normal('Y', X, 6)
  134. assert isinstance(Y.pspace, CompoundPSpace)
  135. assert Y.pspace.distribution == CompoundDistribution(NormalDistribution(X, 6))
  136. assert Y.pspace.domain.set == Interval(-oo, oo)