# test_mix.py (3.9 KB)

  1. from sympy.concrete.summations import Sum
  2. from sympy.core.add import Add
  3. from sympy.core.mul import Mul
  4. from sympy.core.numbers import (Integer, oo, pi)
  5. from sympy.core.power import Pow
  6. from sympy.core.relational import (Eq, Ne)
  7. from sympy.core.symbol import (Dummy, Symbol, symbols)
  8. from sympy.functions.combinatorial.factorials import factorial
  9. from sympy.functions.elementary.exponential import exp
  10. from sympy.functions.elementary.miscellaneous import sqrt
  11. from sympy.functions.elementary.piecewise import Piecewise
  12. from sympy.functions.special.delta_functions import DiracDelta
  13. from sympy.functions.special.gamma_functions import gamma
  14. from sympy.integrals.integrals import Integral
  15. from sympy.simplify.simplify import simplify
  16. from sympy.tensor.indexed import (Indexed, IndexedBase)
  17. from sympy.functions.elementary.piecewise import ExprCondPair
  18. from sympy.stats import (Poisson, Beta, Exponential, P,
  19. Multinomial, MultivariateBeta)
  20. from sympy.stats.crv_types import Normal
  21. from sympy.stats.drv_types import PoissonDistribution
  22. from sympy.stats.compound_rv import CompoundPSpace, CompoundDistribution
  23. from sympy.stats.joint_rv import MarginalDistribution
  24. from sympy.stats.rv import pspace, density
  25. from sympy.testing.pytest import ignore_warnings
  def test_density():
      """Check densities of random variables whose parameters are random.

      Covers a Poisson variable whose rate argument is a Beta variable, and
      a Normal variable whose first parameter is another Normal variable.
      """
      x = Symbol('x')
      lam = Symbol('l', positive=True)

      # Poisson built on top of a Beta-distributed rate.
      beta_rate = Beta(lam, 2, 3)
      poisson_rv = Poisson(x, beta_rate)
      assert isinstance(pspace(poisson_rv), CompoundPSpace)
      # Conditioning the rate on its own symbol yields a plain Poisson law.
      conditioned = density(poisson_rv, Eq(beta_rate, beta_rate.symbol))
      assert conditioned == PoissonDistribution(lam)

      # Normal whose first parameter is itself a standard Normal.
      inner = Normal('N1', 0, 1)
      outer = Normal('N2', inner, 2)

      value_at_zero = density(outer)(0).doit()
      assert value_at_zero == sqrt(10)/(10*sqrt(pi))

      given_inner_is_one = simplify(density(outer, Eq(inner, 1))(x))
      assert given_inner_is_one == sqrt(2)*exp(-(x - 1)**2/8)/(4*sqrt(pi))

      unconditional = simplify(density(outer)(x))
      assert unconditional == sqrt(10)*exp(-x**2/10)/(10*sqrt(pi))
  def test_MarginalDistribution():
      """Compare ``MarginalDistribution`` of a MultivariateBeta to a fixed tree.

      The MultivariateBeta ``B`` takes ``C[0]`` of a Multinomial ``C`` as one
      of its parameters; the marginal evaluated at ``C`` must equal the
      explicitly constructed expression below.
      """
      a1, p1, p2 = symbols('a1 p1 p2', positive=True)
      C = Multinomial('C', 2, p1, p2)
      B = MultivariateBeta('B', a1, C[0])
      MGR = MarginalDistribution(B, (C[0],))

      # Building blocks of the expected expression, created independently of
      # the random variables above.
      sym_a1 = Symbol('a1', positive=True)
      sym_p1 = Symbol('p1', positive=True)
      sym_p2 = Symbol('p2', positive=True)
      c_0 = Indexed(IndexedBase(Symbol('C')), Integer(0))
      c_1 = Indexed(IndexedBase(Symbol('C')), Integer(1))
      b_0 = Indexed(IndexedBase(Symbol('B')), Integer(0))
      b_1 = Indexed(IndexedBase(Symbol('B')), Integer(1))

      # Piecewise factor: non-zero only where C[0] + C[1] == 2.
      nonzero_branch = Mul(
          Integer(2),
          Pow(sym_p1, c_0),
          Pow(sym_p2, c_1),
          Pow(factorial(c_0), Integer(-1)),
          Pow(factorial(c_1), Integer(-1)))
      piecewise_factor = Piecewise(
          ExprCondPair(nonzero_branch, Eq(Add(c_0, c_1), Integer(2))),
          ExprCondPair(Integer(0), True))

      mgrc = Mul(
          Symbol('B'),
          piecewise_factor,
          Pow(gamma(sym_a1), Integer(-1)),
          gamma(Add(sym_a1, c_0)),
          Pow(gamma(c_0), Integer(-1)),
          Pow(b_0, Add(sym_a1, Integer(-1))),
          Pow(b_1, Add(c_0, Integer(-1))))
      assert MGR(C) == mgrc
  def test_compound_distribution():
      """Check the probability space and pdf of a Poisson with a Poisson rate."""
      rate_rv = Poisson('Y', 1)
      compound_rv = Poisson('Z', rate_rv)

      # Both the space and its distribution must be the compound variants.
      assert isinstance(pspace(compound_rv), CompoundPSpace)
      assert isinstance(pspace(compound_rv).distribution, CompoundDistribution)

      pdf_at_one = compound_rv.pspace.distribution.pdf(1).doit()
      expected = exp(-2)*exp(exp(-1))
      assert pdf_at_one == expected
  def test_mix_expression():
      """Check probabilities of expressions mixing Poisson and Exponential."""
      pois, expo = Poisson('Y', 1), Exponential('E', 1)
      k = Dummy('k')

      # expr1 and expr2 share this summand and differ only in the limits of
      # the outermost integral.
      shared_sum = Sum(
          exp(-1)*Integral(exp(-k)*DiracDelta(k - 2), (k, 0, oo))/factorial(k),
          (k, 0, oo))
      expr1 = Integral(shared_sum, (k, -oo, 0))
      expr2 = Integral(shared_sum, (k, 0, oo))

      assert P(Eq(pois + expo, 1)) == 0
      assert P(Ne(pois + expo, 2)) == 1

      # TODO: Restore tests once warnings are removed
      with ignore_warnings(UserWarning):
          below = P(expo + pois < 2, evaluate=False).rewrite(Integral)
          assert below.dummy_eq(expr1)
          above = P(expo + pois > 2, evaluate=False).rewrite(Integral)
          assert above.dummy_eq(expr2)