123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181 |
- from sympy.core.numbers import oo
- from sympy.core.symbol import Symbol
- from sympy.functions.elementary.exponential import exp
- from sympy.sets.sets import Interval
- from sympy.external import import_module
- from sympy.stats import Beta, Chi, Normal, Gamma, Exponential, LogNormal, Pareto, ChiSquared, Uniform, sample, \
- BetaPrime, Cauchy, GammaInverse, GaussianInverse, StudentT, Weibull, density, ContinuousRV, FDistribution, \
- Gumbel, Laplace, Logistic, Rayleigh, Triangular
- from sympy.testing.pytest import skip, raises
def test_sample_numpy():
    """Check sampling of continuous random variables with ``library='numpy'``.

    For each listed distribution, three samples are drawn and every one of
    them must belong to the domain set of its random variable.  Sampling
    ``Chi`` with ``library='numpy'`` and sampling its distribution with
    ``library='tensorflow'`` are both expected to raise
    ``NotImplementedError``.  The test is skipped when numpy cannot be
    imported.
    """
    numpy_rvs = [
        Beta("B", 1, 1),
        Normal("N", 0, 1),
        Gamma("G", 2, 7),
        Exponential("E", 2),
        LogNormal("LN", 0, 1),
        Pareto("P", 1, 1),
        ChiSquared("CS", 2),
        Uniform("U", 0, 1),
        FDistribution("FD", 1, 2),
        Gumbel("GB", 1, 2),
        Laplace("L", 1, 2),
        Logistic("LO", 1, 2),
        Rayleigh("R", 1),
        Triangular("T", 1, 2, 2),
    ]
    n_draws = 3
    numpy = import_module('numpy')
    if not numpy:
        # skip() aborts the test by raising, so nothing below runs.
        skip('Numpy is not installed. Abort tests for _sample_numpy.')

    for rv in numpy_rvs:
        draws = sample(rv, size=n_draws, library='numpy')
        support = rv.pspace.domain.set
        assert all(draw in support for draw in draws)

    def chi_with_numpy():
        return sample(Chi("C", 1), library='numpy')

    def chi_with_tensorflow():
        return Chi("C", 1).pspace.distribution.sample(library='tensorflow')

    raises(NotImplementedError, chi_with_numpy)
    raises(NotImplementedError, chi_with_tensorflow)
def test_sample_scipy():
    """Check sampling of continuous random variables with ``library='scipy'``.

    For each listed distribution, a batch requested with ``size=3`` and a
    batch requested with ``size=(2, 2)`` are drawn; every drawn value must
    belong to the domain set of its random variable.  The test is skipped
    when scipy cannot be imported.
    """
    scipy_rvs = [
        Beta("B", 1, 1),
        BetaPrime("BP", 1, 1),
        Cauchy("C", 1, 1),
        Chi("C", 1),
        Normal("N", 0, 1),
        Gamma("G", 2, 7),
        GammaInverse("GI", 1, 1),
        GaussianInverse("GUI", 1, 1),
        Exponential("E", 2),
        LogNormal("LN", 0, 1),
        Pareto("P", 1, 1),
        StudentT("S", 2),
        ChiSquared("CS", 2),
        Uniform("U", 0, 1),
    ]
    n_draws = 3
    scipy = import_module('scipy')
    if not scipy:
        # skip() aborts the test by raising, so nothing below runs.
        skip('Scipy is not installed. Abort tests for _sample_scipy.')

    for rv in scipy_rvs:
        flat = sample(rv, size=n_draws, library='scipy')
        grid = sample(rv, size=(2, 2), library='scipy')
        support = rv.pspace.domain.set
        assert all(draw in support for draw in flat)
        # The (2, 2) request is checked entry by entry via double indexing.
        assert all(grid[row][col] in support
                   for row in range(2) for col in range(2))
def test_sample_pymc():
    """Check sampling of continuous random variables with ``library='pymc'``.

    For each listed distribution, three samples are drawn and every one of
    them must belong to the domain set of its random variable.  Sampling
    ``Chi`` with ``library='pymc'`` is expected to raise
    ``NotImplementedError``.  The test is skipped when pymc cannot be
    imported.
    """
    pymc_rvs = [
        Beta("B", 1, 1),
        Cauchy("C", 1, 1),
        Normal("N", 0, 1),
        Gamma("G", 2, 7),
        GaussianInverse("GI", 1, 1),
        Exponential("E", 2),
        LogNormal("LN", 0, 1),
        Pareto("P", 1, 1),
        ChiSquared("CS", 2),
        Uniform("U", 0, 1),
    ]
    n_draws = 3
    pymc = import_module('pymc')
    if not pymc:
        # skip() aborts the test by raising, so nothing below runs.
        skip('PyMC is not installed. Abort tests for _sample_pymc.')

    for rv in pymc_rvs:
        draws = sample(rv, size=n_draws, library='pymc')
        support = rv.pspace.domain.set
        assert all(draw in support for draw in draws)

    def chi_with_pymc():
        return sample(Chi("C", 1), library='pymc')

    raises(NotImplementedError, chi_with_pymc)
def test_sampling_gamma_inverse():
    """Check that one sample of ``GammaInverse("x", 1, 1)`` lies in its domain.

    The test is skipped when scipy cannot be imported.
    """
    if not import_module('scipy'):
        skip('Scipy not installed. Abort tests for sampling of gamma inverse.')
    rv = GammaInverse("x", 1, 1)
    drawn = sample(rv)
    assert drawn in rv.pspace.domain.set
def test_lognormal_sampling():
    """Check that ``LogNormal`` samples lie in the variable's domain set.

    For first-parameter values 0, 1 and 2 (second parameter fixed at 1), a
    single sample and a batch requested with ``size=5`` are drawn, and each
    drawn value must belong to the domain set.  The test is skipped when
    scipy cannot be imported.
    """
    # Right now, only the density function and sampling work
    scipy = import_module('scipy')
    if not scipy:
        skip('Scipy is not installed. Abort tests')

    size = 5
    for i in range(3):
        X = LogNormal('x', i, 1)
        support = X.pspace.domain.set
        assert sample(X) in support
        # Run the batched check for every parameter value rather than only
        # for whichever variable the loop happened to leave behind.
        for samp in sample(X, size=size):
            assert samp in support
def test_sampling_gaussian_inverse():
    """Check that one scipy sample of ``GaussianInverse("x", 1, 1)`` lies in its domain.

    The test is skipped when scipy cannot be imported.
    """
    if not import_module('scipy'):
        skip('Scipy not installed. Abort tests for sampling of Gaussian inverse.')
    rv = GaussianInverse("x", 1, 1)
    drawn = sample(rv, library='scipy')
    assert drawn in rv.pspace.domain.set
def test_prefab_sampling():
    """Repeatedly sample several predefined distributions and check domains.

    For each random variable, ten rounds are run; each round draws a single
    sample and a batch requested with ``size=5``, and every drawn value must
    belong to the variable's domain set.  The test is skipped when scipy
    cannot be imported.
    """
    scipy = import_module('scipy')
    if not scipy:
        skip('Scipy is not installed. Abort tests')

    prefab_rvs = (
        Normal('X', 0, 1),
        LogNormal('L', 0, 1),
        Exponential('Ex', 1),
        Pareto('P', 1, 3),
        Weibull('W', 1, 1),
        Uniform('U', 0, 1),
        Beta('B', 2, 5),
        Gamma('G', 1, 3),
    )
    rounds, batch = 10, 5
    for rv in prefab_rvs:
        support = rv.pspace.domain.set
        for _ in range(rounds):
            assert sample(rv) in support
            assert all(draw in support for draw in sample(rv, size=batch))
def test_sample_continuous():
    """Check sampling of a user-defined continuous random variable.

    The variable has density ``exp(-z)`` on ``Interval(0, oo)``.  The test
    verifies that:

    * the density evaluates to 0 at -1 (outside the interval);
    * a single sample lies in the domain set;
    * ``pspace.sample()`` maps the variable itself to a value in the
      interval;
    * for each of scipy, numpy and pymc that can be imported, two draws with
      ``seed=0`` agree element-wise and a draw with ``seed=1`` differs from
      them element-wise.

    Everything after the density check is skipped when scipy cannot be
    imported.  A library for which ``sample`` raises ``NotImplementedError``
    is passed over.
    """
    z = Symbol('z')
    Z = ContinuousRV(z, exp(-z), set=Interval(0, oo))
    assert density(Z)(-1) == 0

    scipy = import_module('scipy')
    if not scipy:
        skip('Scipy is not installed. Abort tests')
    assert sample(Z) in Z.pspace.domain.set
    sym, val = next(iter(Z.pspace.sample().items()))
    assert sym == Z and val in Interval(0, oo)

    libraries = ['scipy', 'numpy', 'pymc']
    for lib in libraries:
        if not import_module(lib):
            continue
        # Only the sample() calls are guarded: a library without support for
        # this variable is passed over, while a failing assertion still
        # fails the test.
        try:
            s0 = sample(Z, size=10, library=lib, seed=0)
            s1 = sample(Z, size=10, library=lib, seed=0)
            s2 = sample(Z, size=10, library=lib, seed=1)
        except NotImplementedError:
            continue
        assert all(s0 == s1)
        assert all(s1 != s2)
|