test_sample_finite_rv.py 3.0 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394
  1. from sympy.core.numbers import Rational
  2. from sympy.core.singleton import S
  3. from sympy.external import import_module
  4. from sympy.stats import Binomial, sample, Die, FiniteRV, DiscreteUniform, Bernoulli, BetaBinomial, Hypergeometric, \
  5. Rademacher
  6. from sympy.testing.pytest import skip, raises
  7. def test_given_sample():
  8. X = Die('X', 6)
  9. scipy = import_module('scipy')
  10. if not scipy:
  11. skip('Scipy is not installed. Abort tests')
  12. assert sample(X, X > 5) == 6
  13. def test_sample_numpy():
  14. distribs_numpy = [
  15. Binomial("B", 5, 0.4),
  16. Hypergeometric("H", 2, 1, 1)
  17. ]
  18. size = 3
  19. numpy = import_module('numpy')
  20. if not numpy:
  21. skip('Numpy is not installed. Abort tests for _sample_numpy.')
  22. else:
  23. for X in distribs_numpy:
  24. samps = sample(X, size=size, library='numpy')
  25. for sam in samps:
  26. assert sam in X.pspace.domain.set
  27. raises(NotImplementedError,
  28. lambda: sample(Die("D"), library='numpy'))
  29. raises(NotImplementedError,
  30. lambda: Die("D").pspace.sample(library='tensorflow'))
  31. def test_sample_scipy():
  32. distribs_scipy = [
  33. FiniteRV('F', {1: S.Half, 2: Rational(1, 4), 3: Rational(1, 4)}),
  34. DiscreteUniform("Y", list(range(5))),
  35. Die("D"),
  36. Bernoulli("Be", 0.3),
  37. Binomial("Bi", 5, 0.4),
  38. BetaBinomial("Bb", 2, 1, 1),
  39. Hypergeometric("H", 1, 1, 1),
  40. Rademacher("R")
  41. ]
  42. size = 3
  43. scipy = import_module('scipy')
  44. if not scipy:
  45. skip('Scipy not installed. Abort tests for _sample_scipy.')
  46. else:
  47. for X in distribs_scipy:
  48. samps = sample(X, size=size)
  49. samps2 = sample(X, size=(2, 2))
  50. for sam in samps:
  51. assert sam in X.pspace.domain.set
  52. for i in range(2):
  53. for j in range(2):
  54. assert samps2[i][j] in X.pspace.domain.set
  55. def test_sample_pymc():
  56. distribs_pymc = [
  57. Bernoulli('B', 0.2),
  58. Binomial('N', 5, 0.4)
  59. ]
  60. size = 3
  61. pymc = import_module('pymc')
  62. if not pymc:
  63. skip('PyMC is not installed. Abort tests for _sample_pymc.')
  64. else:
  65. for X in distribs_pymc:
  66. samps = sample(X, size=size, library='pymc')
  67. for sam in samps:
  68. assert sam in X.pspace.domain.set
  69. raises(NotImplementedError,
  70. lambda: (sample(Die("D"), library='pymc')))
  71. def test_sample_seed():
  72. F = FiniteRV('F', {1: S.Half, 2: Rational(1, 4), 3: Rational(1, 4)})
  73. size = 10
  74. libraries = ['scipy', 'numpy', 'pymc']
  75. for lib in libraries:
  76. try:
  77. imported_lib = import_module(lib)
  78. if imported_lib:
  79. s0 = sample(F, size=size, library=lib, seed=0)
  80. s1 = sample(F, size=size, library=lib, seed=0)
  81. s2 = sample(F, size=size, library=lib, seed=1)
  82. assert all(s0 == s1)
  83. assert not all(s1 == s2)
  84. except NotImplementedError:
  85. continue