test_sample_continuous_rv.py 5.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181
  1. from sympy.core.numbers import oo
  2. from sympy.core.symbol import Symbol
  3. from sympy.functions.elementary.exponential import exp
  4. from sympy.sets.sets import Interval
  5. from sympy.external import import_module
  6. from sympy.stats import Beta, Chi, Normal, Gamma, Exponential, LogNormal, Pareto, ChiSquared, Uniform, sample, \
  7. BetaPrime, Cauchy, GammaInverse, GaussianInverse, StudentT, Weibull, density, ContinuousRV, FDistribution, \
  8. Gumbel, Laplace, Logistic, Rayleigh, Triangular
  9. from sympy.testing.pytest import skip, raises
  10. def test_sample_numpy():
  11. distribs_numpy = [
  12. Beta("B", 1, 1),
  13. Normal("N", 0, 1),
  14. Gamma("G", 2, 7),
  15. Exponential("E", 2),
  16. LogNormal("LN", 0, 1),
  17. Pareto("P", 1, 1),
  18. ChiSquared("CS", 2),
  19. Uniform("U", 0, 1),
  20. FDistribution("FD", 1, 2),
  21. Gumbel("GB", 1, 2),
  22. Laplace("L", 1, 2),
  23. Logistic("LO", 1, 2),
  24. Rayleigh("R", 1),
  25. Triangular("T", 1, 2, 2),
  26. ]
  27. size = 3
  28. numpy = import_module('numpy')
  29. if not numpy:
  30. skip('Numpy is not installed. Abort tests for _sample_numpy.')
  31. else:
  32. for X in distribs_numpy:
  33. samps = sample(X, size=size, library='numpy')
  34. for sam in samps:
  35. assert sam in X.pspace.domain.set
  36. raises(NotImplementedError,
  37. lambda: sample(Chi("C", 1), library='numpy'))
  38. raises(NotImplementedError,
  39. lambda: Chi("C", 1).pspace.distribution.sample(library='tensorflow'))
  40. def test_sample_scipy():
  41. distribs_scipy = [
  42. Beta("B", 1, 1),
  43. BetaPrime("BP", 1, 1),
  44. Cauchy("C", 1, 1),
  45. Chi("C", 1),
  46. Normal("N", 0, 1),
  47. Gamma("G", 2, 7),
  48. GammaInverse("GI", 1, 1),
  49. GaussianInverse("GUI", 1, 1),
  50. Exponential("E", 2),
  51. LogNormal("LN", 0, 1),
  52. Pareto("P", 1, 1),
  53. StudentT("S", 2),
  54. ChiSquared("CS", 2),
  55. Uniform("U", 0, 1)
  56. ]
  57. size = 3
  58. scipy = import_module('scipy')
  59. if not scipy:
  60. skip('Scipy is not installed. Abort tests for _sample_scipy.')
  61. else:
  62. for X in distribs_scipy:
  63. samps = sample(X, size=size, library='scipy')
  64. samps2 = sample(X, size=(2, 2), library='scipy')
  65. for sam in samps:
  66. assert sam in X.pspace.domain.set
  67. for i in range(2):
  68. for j in range(2):
  69. assert samps2[i][j] in X.pspace.domain.set
  70. def test_sample_pymc():
  71. distribs_pymc = [
  72. Beta("B", 1, 1),
  73. Cauchy("C", 1, 1),
  74. Normal("N", 0, 1),
  75. Gamma("G", 2, 7),
  76. GaussianInverse("GI", 1, 1),
  77. Exponential("E", 2),
  78. LogNormal("LN", 0, 1),
  79. Pareto("P", 1, 1),
  80. ChiSquared("CS", 2),
  81. Uniform("U", 0, 1)
  82. ]
  83. size = 3
  84. pymc = import_module('pymc')
  85. if not pymc:
  86. skip('PyMC is not installed. Abort tests for _sample_pymc.')
  87. else:
  88. for X in distribs_pymc:
  89. samps = sample(X, size=size, library='pymc')
  90. for sam in samps:
  91. assert sam in X.pspace.domain.set
  92. raises(NotImplementedError,
  93. lambda: sample(Chi("C", 1), library='pymc'))
  94. def test_sampling_gamma_inverse():
  95. scipy = import_module('scipy')
  96. if not scipy:
  97. skip('Scipy not installed. Abort tests for sampling of gamma inverse.')
  98. X = GammaInverse("x", 1, 1)
  99. assert sample(X) in X.pspace.domain.set
  100. def test_lognormal_sampling():
  101. # Right now, only density function and sampling works
  102. scipy = import_module('scipy')
  103. if not scipy:
  104. skip('Scipy is not installed. Abort tests')
  105. for i in range(3):
  106. X = LogNormal('x', i, 1)
  107. assert sample(X) in X.pspace.domain.set
  108. size = 5
  109. samps = sample(X, size=size)
  110. for samp in samps:
  111. assert samp in X.pspace.domain.set
  112. def test_sampling_gaussian_inverse():
  113. scipy = import_module('scipy')
  114. if not scipy:
  115. skip('Scipy not installed. Abort tests for sampling of Gaussian inverse.')
  116. X = GaussianInverse("x", 1, 1)
  117. assert sample(X, library='scipy') in X.pspace.domain.set
  118. def test_prefab_sampling():
  119. scipy = import_module('scipy')
  120. if not scipy:
  121. skip('Scipy is not installed. Abort tests')
  122. N = Normal('X', 0, 1)
  123. L = LogNormal('L', 0, 1)
  124. E = Exponential('Ex', 1)
  125. P = Pareto('P', 1, 3)
  126. W = Weibull('W', 1, 1)
  127. U = Uniform('U', 0, 1)
  128. B = Beta('B', 2, 5)
  129. G = Gamma('G', 1, 3)
  130. variables = [N, L, E, P, W, U, B, G]
  131. niter = 10
  132. size = 5
  133. for var in variables:
  134. for _ in range(niter):
  135. assert sample(var) in var.pspace.domain.set
  136. samps = sample(var, size=size)
  137. for samp in samps:
  138. assert samp in var.pspace.domain.set
  139. def test_sample_continuous():
  140. z = Symbol('z')
  141. Z = ContinuousRV(z, exp(-z), set=Interval(0, oo))
  142. assert density(Z)(-1) == 0
  143. scipy = import_module('scipy')
  144. if not scipy:
  145. skip('Scipy is not installed. Abort tests')
  146. assert sample(Z) in Z.pspace.domain.set
  147. sym, val = list(Z.pspace.sample().items())[0]
  148. assert sym == Z and val in Interval(0, oo)
  149. libraries = ['scipy', 'numpy', 'pymc']
  150. for lib in libraries:
  151. try:
  152. imported_lib = import_module(lib)
  153. if imported_lib:
  154. s0, s1, s2 = [], [], []
  155. s0 = sample(Z, size=10, library=lib, seed=0)
  156. s1 = sample(Z, size=10, library=lib, seed=0)
  157. s2 = sample(Z, size=10, library=lib, seed=1)
  158. assert all(s0 == s1)
  159. assert all(s1 != s2)
  160. except NotImplementedError:
  161. continue