123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441 |
- from sympy.concrete.summations import Sum
- from sympy.core.basic import Basic
- from sympy.core.containers import Tuple
- from sympy.core.function import Lambda
- from sympy.core.numbers import (Rational, nan, oo, pi)
- from sympy.core.relational import Eq
- from sympy.core.singleton import S
- from sympy.core.symbol import (Symbol, symbols)
- from sympy.functions.combinatorial.factorials import (FallingFactorial, binomial)
- from sympy.functions.elementary.exponential import (exp, log)
- from sympy.functions.elementary.trigonometric import (cos, sin)
- from sympy.functions.special.delta_functions import DiracDelta
- from sympy.integrals.integrals import integrate
- from sympy.logic.boolalg import (And, Or)
- from sympy.matrices.dense import Matrix
- from sympy.sets.sets import Interval
- from sympy.tensor.indexed import Indexed
- from sympy.stats import (Die, Normal, Exponential, FiniteRV, P, E, H, variance,
- density, given, independent, dependent, where, pspace, GaussianUnitaryEnsemble,
- random_symbols, sample, Geometric, factorial_moment, Binomial, Hypergeometric,
- DiscreteUniform, Poisson, characteristic_function, moment_generating_function,
- BernoulliProcess, Variance, Expectation, Probability, Covariance, covariance, cmoment,
- moment, median)
- from sympy.stats.rv import (IndependentProductPSpace, rs_swap, Density, NamedArgsMixin,
- RandomSymbol, sample_iter, PSpace, is_random, RandomIndexedSymbol, RandomMatrixSymbol)
- from sympy.testing.pytest import raises, skip, XFAIL, warns_deprecated_sympy
- from sympy.external import import_module
- from sympy.core.numbers import comp
- from sympy.stats.frv_types import BernoulliDistribution
- from sympy.core.symbol import Dummy
- from sympy.functions.elementary.piecewise import Piecewise
def test_where():
    """Exercise ``where`` and ``given`` on finite and continuous variables."""
    die_x, die_y = Die('X'), Die('Y')
    z_rv = Normal('Z', 0, 1)

    unit_region = where(z_rv**2 <= 1)
    assert unit_region.set == Interval(-1, 1)
    assert unit_region.as_boolean() == Interval(-1, 1).as_relational(z_rv.symbol)

    expected_rolls = And(Eq(die_x.symbol, 6), Eq(die_y.symbol, 5))
    assert where(And(die_x > die_y, die_y > 4)).as_boolean() == expected_rolls

    low_rolls = where(die_x < 3).set
    assert len(low_rolls) == 2
    assert 1 in low_rolls

    x_rv = Normal('X', 0, 1)
    half_unit = And(x_rv**2 <= 1, x_rv >= 0)
    assert where(half_unit).set == Interval(0, 1)

    conditioned = given(x_rv, half_unit)
    assert conditioned.pspace.domain.set == Interval(0, 1)
    sym = x_rv.symbol
    assert conditioned.pspace.domain.as_boolean() == And(
        0 <= sym, sym**2 <= 1, -oo < sym, sym < oo)

    # A non-boolean condition is rejected.
    with raises(TypeError):
        given(x_rv, x_rv + 3)
def test_random_symbols():
    """``random_symbols`` reports exactly the random symbols in an expression."""
    X, Y = Normal('X', 0, 1), Normal('Y', 0, 1)
    cases = [
        (2*X + 1, {X}),
        (2*X + Y, {X, Y}),
        (2*X + Y.symbol, {X}),  # Y.symbol is a plain Symbol, not random
        (2, set()),
    ]
    for expr, expected in cases:
        assert set(random_symbols(expr)) == expected
def test_characteristic_function():
    """Characteristic functions of Normal, DiscreteUniform and Poisson variables."""
    from sympy.core.numbers import I
    t = symbols('_t')
    cases = [
        (Normal('X', 0, 1), Lambda(t, exp(-t**2/2))),
        (DiscreteUniform('Y', [1, 2, 7]),
         Lambda(t, exp(7*t*I)/3 + exp(2*t*I)/3 + exp(t*I)/3)),
        (Poisson('Z', 2), Lambda(t, exp(2 * exp(t*I) - 2))),
    ]
    for rv, expected in cases:
        # dummy_eq lets the bound Lambda variable differ in identity.
        assert characteristic_function(rv).dummy_eq(expected)
def test_moment_generating_function():
    """Moment generating functions of Normal, DiscreteUniform and Poisson variables."""
    t = symbols('_t')
    cases = [
        (Normal('X', 0, 1), Lambda(t, exp(t**2/2))),
        (DiscreteUniform('Y', [1, 2, 7]),
         Lambda(t, (exp(7*t)/3 + exp(2*t)/3 + exp(t)/3))),
        (Poisson('Z', 2), Lambda(t, exp(2 * exp(t) - 2))),
    ]
    for rv, expected in cases:
        # dummy_eq lets the bound Lambda variable differ in identity.
        assert moment_generating_function(rv).dummy_eq(expected)
def test_sample_iter():
    """``sample_iter`` returns an iterator for continuous and discrete expressions."""
    from collections.abc import Iterator

    X = Normal('X', 0, 1)
    Y = DiscreteUniform('Y', [1, 2, 7])
    Z = Poisson('Z', 2)

    scipy = import_module('scipy')
    if not scipy:
        skip('Scipy is not installed. Abort tests')

    def is_iterator(obj):
        # Check the Python 3 iterator protocol (``__iter__`` + ``__next__``)
        # via the ABC instead of probing for the Python 2 ``next`` method,
        # and require that iterating the object returns the object itself.
        return isinstance(obj, Iterator) and iter(obj) is obj

    for expr in (X**2 + 3, Y**2 + 5*Y + 4, Z**3 + 4):
        assert is_iterator(sample_iter(expr))
def test_pspace():
    """``pspace`` of random expressions, and ValueError for non-random inputs."""
    X, Y = Normal('X', 0, 1), Normal('Y', 0, 1)
    x = Symbol('x')

    # Neither of these inputs contains a random symbol.
    for non_random in (5 + 3, x < 1):
        with raises(ValueError):
            pspace(non_random)

    assert pspace(X) == X.pspace
    assert pspace(2*X + 1) == X.pspace
    assert pspace(2*X + Y) == IndependentProductPSpace(Y.pspace, X.pspace)
def test_rs_swap():
    """``rs_swap`` pairs random symbols by their symbol, not by position."""
    old = (Normal('x', 0, 1), Exponential('y', 1))
    new_x, new_y = Normal('x', 0, 2), Normal('y', 0, 3)
    X, Y = old
    # Replacements are passed in the opposite order on purpose.
    mapping = rs_swap(old, (new_y, new_x))
    assert (2*X + Y).subs(mapping) == 2*new_x + new_y
def test_RandomSymbol():
    """Same-named random symbols share a symbol yet are distinct objects."""
    first = Normal('x', 0, 1)
    second = Normal('x', 0, 2)
    assert first.symbol == second.symbol
    assert first != second
    assert first.name == first.symbol.name
    # A Python keyword and a SymPy class name must both be usable as names.
    for reserved_name in ('lambda', 'Lambda'):
        Normal(reserved_name, 0, 1)
def test_RandomSymbol_diff():
    """A random symbol can be differentiated with respect to itself."""
    X = Normal('x', 0, 1)
    # Pin the derivative's value. The previous bare truthiness check would
    # also have passed for any nonzero wrong result.
    assert (2*X).diff(X) == 2
def test_random_symbol_no_pspace():
    """A RandomSymbol built from a bare Symbol has a pspace equal to ``PSpace()``."""
    bare = RandomSymbol(Symbol('x'))
    assert bare.pspace == PSpace()
def test_overlap():
    """P of a comparison between two same-named random symbols raises ValueError."""
    first, second = Normal('x', 0, 1), Normal('x', 0, 2)
    with raises(ValueError):
        P(first > second)
def test_IndependentProductPSpace():
    """The joint pspace of X + Y equals the product space in either order."""
    X, Y = Normal('X', 0, 1), Normal('Y', 0, 1)
    joint = pspace(X + Y)
    for ordering in ((X.pspace, Y.pspace), (Y.pspace, X.pspace)):
        assert joint == IndependentProductPSpace(*ordering)
def test_E():
    """The expectation of a constant is the constant itself."""
    constant = 5
    assert E(constant) == constant
def test_H():
    """Entropy of continuous, finite and discrete variables."""
    normal = Normal('X', 0, 1)
    die = Die('D', sides=4)
    geometric = Geometric('G', 0.5)

    assert H(normal, normal > 0) == -log(2)/2 + S.Half + log(pi)/2
    assert H(die, die > 2) == log(2)
    # The geometric entropy is checked numerically to two decimals.
    rounded = H(geometric).evalf().round(2)
    assert comp(rounded, 1.39)
def test_Sample():
    """Sampling-based P, E, variance, density and the ``sample`` return types."""
    X = Die('X', 6)
    Y = Normal('Y', 0, 1)
    z = Symbol('z', integer=True)

    if not import_module('scipy'):
        skip('Scipy is not installed. Abort tests')

    assert sample(X) in [1, 2, 3, 4, 5, 6]
    assert isinstance(sample(X + Y), float)

    assert P(X + Y > 0, Y < 0, numsamples=10).is_number
    for expr in (X + Y, X**2 + Y, (X + Y)**2):
        assert E(expr, numsamples=10).is_number
    assert variance(X + Y, numsamples=10).is_number

    # A condition involving the non-random symbol z raises TypeError.
    with raises(TypeError):
        P(Y > z, numsamples=5)

    assert P(sin(Y) <= 1, numsamples=10) == 1.0
    assert P(sin(Y) <= 1, cos(Y) < 1, numsamples=10) == 1.0
    assert all(i in range(1, 7) for i in density(X, numsamples=10))
    assert all(i in range(4, 7) for i in density(X, X>3, numsamples=10))

    numpy = import_module('numpy')
    if not numpy:
        skip('Numpy is not installed. Abort tests')
    # Issue #21563: sample returns a numpy scalar, or a numpy array when size is given.
    assert isinstance(sample(X), (numpy.int32, numpy.int64))
    assert isinstance(sample(Y), numpy.float64)
    assert isinstance(sample(X, size=2), numpy.ndarray)
    with warns_deprecated_sympy():
        sample(X, numsamples=2)
@XFAIL
def test_samplingE():
    """Sampled expectation of an infinite Sum under a condition (expected to fail)."""
    if not import_module('scipy'):
        skip('Scipy is not installed. Abort tests')
    Y = Normal('Y', 0, 1)
    z = Symbol('z', integer=True)
    series = Sum(1/z**Y, (z, 1, oo))
    assert E(series, Y > 2, numsamples=3).is_number
def test_given():
    """Conditioning on True, or on an event about an independent Y, leaves X unchanged."""
    X = Normal('X', 0, 1)
    Y = Normal('Y', 0, 1)
    for condition in (True, Y > 2):
        assert given(X, condition) == X
def test_factorial_moment():
    """Factorial moments for a numeric order and for a symbolic order."""
    # Order 2 with concrete distribution parameters.
    numeric_cases = [
        (Poisson('X', 2), 4),
        (Binomial('Y', 2, S.Half), S.Half),
        (Hypergeometric('Z', 4, 2, 2), Rational(1, 3)),
    ]
    for rv, expected in numeric_cases:
        assert factorial_moment(rv, 2) == expected

    # Symbolic order ``l`` (and a symbolic success probability ``y``).
    y, l = symbols('y l')
    binom = Binomial('Y', 2, y)
    hyper = Hypergeometric('Z', 10, 2, 3)
    assert factorial_moment(binom, l) == (
        y**2*FallingFactorial(2, l)
        + 2*y*(1 - y)*FallingFactorial(1, l)
        + (1 - y)**2*FallingFactorial(0, l))
    assert factorial_moment(hyper, l) == (
        7*FallingFactorial(0, l)/15
        + 7*FallingFactorial(1, l)/15
        + FallingFactorial(2, l)/15)
def test_dependence():
    """``independent``/``dependent`` on dice, normals and a conditioned pair."""
    die_x, die_y = Die('X'), Die('Y')
    assert independent(die_x, 2*die_y)
    assert not dependent(die_x, 2*die_y)

    X, Y = Normal('X', 0, 1), Normal('Y', 0, 1)
    assert independent(X, Y)
    assert dependent(X, 2*X)

    # Conditioning on X + Y == 3 couples the two variables.
    XX, YY = given(Tuple(X, Y), Eq(X + Y, 3))
    assert dependent(XX, YY)
def test_dependent_finite():
    """Dependence checks between dice, including after conditioning."""
    first, second = Die('X'), Die('Y')
    # NOTE(review): an earlier comment claimed symbolic conditions break finite
    # random variables in dependence testing; these assertions exercise that
    # path. Confirm whether the note still applies.
    assert dependent(first, second + first)
    # Conditioning on the sum creates a dependency between the dice.
    cond_first, cond_second = given(Tuple(first, second), first + second > 5)
    assert dependent(cond_first, cond_second)
def test_normality():
    """The conditional density of X - Y given X + Y == z integrates to one."""
    X, Y = Normal('X', 0, 1), Normal('Y', 0, 1)
    x, z = Symbol('x', real=True), Symbol('z', real=True)
    conditional = density(X - Y, Eq(X + Y, z))
    total_mass = integrate(conditional(x), (x, -oo, oo))
    assert total_mass == 1
def test_Density():
    """``Density(X).doit()`` agrees with ``density(X)``."""
    die = Die('X', 6)
    assert Density(die).doit() == density(die)
def test_NamedArgsMixin():
    """``NamedArgsMixin`` exposes positional args under the names in ``_argnames``."""
    class Foo(Basic, NamedArgsMixin):
        _argnames = 'foo', 'bar'

    instance = Foo(S(1), S(2))
    assert instance.foo == 1
    assert instance.bar == 2
    with raises(AttributeError):
        instance.baz

    # A subclass without ``_argnames`` has no named attributes.
    class Bar(Basic, NamedArgsMixin):
        pass

    with raises(AttributeError):
        Bar(S(1), S(2)).foo
def test_density_constant():
    """The density of the constant 3 is zero at 2 and DiracDelta(0) at 3."""
    const_density = density(3)
    assert const_density(2) == 0
    assert const_density(3) == DiracDelta(0)
def test_cmoment_constant():
    """Central moments of constants: variance and orders >= 1 vanish, order 0 is 1."""
    x = Symbol('x')
    for value in (3, x):
        assert variance(value) == 0
    for value, order in ((3, 3), (3, 4), (x, 15)):
        assert cmoment(value, order) == 0
    assert cmoment(x, 0) == 1
def test_moment_constant():
    """Raw moments of a constant c are the powers c**n."""
    for order, expected in enumerate((1, 3, 9)):
        assert moment(3, order) == expected
    x = Symbol('x')
    assert moment(x, 2) == x**2
def test_median_constant():
    """The median of a numeric or symbolic constant is the constant itself."""
    for value in (3, Symbol('x')):
        assert median(value) == value
def test_real():
    """A Normal random symbol reports ``is_real``."""
    assert Normal('x', 0, 1).is_real
def test_issue_10052():
    """Probabilities of events with infinite bounds, plus invalid P arguments."""
    X = Exponential('X', 3)
    expectations = [
        ((X < oo,), 1),
        ((X > oo,), 0),
        ((X < 2, X > oo), 0),
        ((X < oo, X > oo), 0),
        ((X < oo, X > 2), 1),
        ((X < 3, X == 2), 0),
    ]
    for args, expected in expectations:
        assert P(*args) == expected

    with raises(ValueError):
        P(1)
    with raises(ValueError):
        P(X < 1, 2)
def test_issue_11934():
    """A FiniteRV built from a plain dict of float probabilities."""
    pmf = {0: .5, 1: .5}
    X = FiniteRV('X', pmf)
    assert E(X) == 0.5
    assert P(X >= 2) == 0
def test_issue_8129():
    """Probabilities of comparing a variable with itself."""
    X = Exponential('X', 4)
    for event, expected in ((X >= X, 1), (X > X, 0), (X > X+1, 0)):
        assert P(event) == expected
def test_issue_12237():
    """P(event, X), i.e. a probability conditioned on a random variable."""
    X = Normal('X', 0, 1)
    Y = Normal('Y', 0, 1)
    # Conditioning an event about X on X itself gives a Bernoulli distribution.
    assert P(X > 0, X) == BernoulliDistribution(S.Half, S.Zero, S.One)
    # An event about the independent Y keeps probability 1/2.
    assert P(Y < 0, X) == S.Half
    # Rebuilding the same conditional probability gives an equal result.
    mixed = P(X + Y > 0, X)
    assert mixed == P(X + Y > 0, X)
def test_is_random():
    """``is_random`` on plain symbols, random symbols, processes and matrices."""
    X = Normal('X', 0, 1)
    Y = Normal('Y', 0, 1)
    a, b = symbols('a, b')
    G = GaussianUnitaryEnsemble('U', 2)
    B = BernoulliProcess('B', 0.9)

    deterministic = [a, a + b, a * b, Matrix([a**2, b**2])]
    stochastic = [
        X,
        X**2 + Y,
        Y + b**2,
        Y > 5,
        B[3] < 1,
        G,
        X * Y * B[1],
        Matrix([[X, B[2]], [G, Y]]),
        Eq(X, 4),
    ]
    for expr in deterministic:
        assert not is_random(expr)
    for expr in stochastic:
        assert is_random(expr)
def test_issue_12283():
    """Random symbols without a pspace yield unevaluated statistical objects."""
    x = symbols('x')
    X = RandomSymbol(x)
    Y = RandomSymbol('Y')
    Z = RandomMatrixSymbol('Z', 2, 1)
    W = RandomMatrixSymbol('W', 2, 1)
    RI = RandomIndexedSymbol(Indexed('RI', 3))

    for rv in (Z, RI, X):
        assert pspace(rv) == PSpace()

    assert E(X) == Expectation(X)
    assert P(Y > 3) == Probability(Y > 3)
    for rv in (X, RI):
        assert variance(rv) == Variance(rv)
    for left, right in ((X, Y), (W, Z)):
        assert covariance(left, right) == Covariance(left, right)
def test_issue_6810():
    """P accepts Eq, Or and And events for finite and continuous variables."""
    die = Die('X', 6)
    normal = Normal('Y', 0, 1)
    assert P(Eq(die, 2)) == Rational(1, 6)
    assert P(Eq(normal, 0)) == 0
    assert P(Or(die > 2, die < 3)) == 1
    assert P(And(die > 3, die > 2)) == S.Half
def test_issue_20286():
    """Entropy of a symbolic Binomial matches an explicit Sum over k."""
    n, p = symbols('n p')
    B = Binomial('B', n, p)
    k = Dummy('k', integer=True)

    pmf = p**k*(1 - p)**(-k + n)*binomial(n, k)
    in_support = (k >= 0) & (k <= n)
    summand = Piecewise((-pmf*log(pmf), in_support), (nan, True))
    expected = Sum(summand, (k, 0, n))
    assert expected.dummy_eq(H(B))
|