test_sample_discrete_rv.py 3.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899
  1. from sympy.core.singleton import S
  2. from sympy.core.symbol import Symbol
  3. from sympy.external import import_module
  4. from sympy.stats import Geometric, Poisson, Zeta, sample, Skellam, DiscreteRV, Logarithmic, NegativeBinomial, YuleSimon
  5. from sympy.testing.pytest import skip, raises, slow
  6. def test_sample_numpy():
  7. distribs_numpy = [
  8. Geometric('G', 0.5),
  9. Poisson('P', 1),
  10. Zeta('Z', 2)
  11. ]
  12. size = 3
  13. numpy = import_module('numpy')
  14. if not numpy:
  15. skip('Numpy is not installed. Abort tests for _sample_numpy.')
  16. else:
  17. for X in distribs_numpy:
  18. samps = sample(X, size=size, library='numpy')
  19. for sam in samps:
  20. assert sam in X.pspace.domain.set
  21. raises(NotImplementedError,
  22. lambda: sample(Skellam('S', 1, 1), library='numpy'))
  23. raises(NotImplementedError,
  24. lambda: Skellam('S', 1, 1).pspace.distribution.sample(library='tensorflow'))
  25. def test_sample_scipy():
  26. p = S(2)/3
  27. x = Symbol('x', integer=True, positive=True)
  28. pdf = p*(1 - p)**(x - 1) # pdf of Geometric Distribution
  29. distribs_scipy = [
  30. DiscreteRV(x, pdf, set=S.Naturals),
  31. Geometric('G', 0.5),
  32. Logarithmic('L', 0.5),
  33. NegativeBinomial('N', 5, 0.4),
  34. Poisson('P', 1),
  35. Skellam('S', 1, 1),
  36. YuleSimon('Y', 1),
  37. Zeta('Z', 2)
  38. ]
  39. size = 3
  40. scipy = import_module('scipy')
  41. if not scipy:
  42. skip('Scipy is not installed. Abort tests for _sample_scipy.')
  43. else:
  44. for X in distribs_scipy:
  45. samps = sample(X, size=size, library='scipy')
  46. samps2 = sample(X, size=(2, 2), library='scipy')
  47. for sam in samps:
  48. assert sam in X.pspace.domain.set
  49. for i in range(2):
  50. for j in range(2):
  51. assert samps2[i][j] in X.pspace.domain.set
  52. def test_sample_pymc():
  53. distribs_pymc = [
  54. Geometric('G', 0.5),
  55. Poisson('P', 1),
  56. NegativeBinomial('N', 5, 0.4)
  57. ]
  58. size = 3
  59. pymc = import_module('pymc')
  60. if not pymc:
  61. skip('PyMC is not installed. Abort tests for _sample_pymc.')
  62. else:
  63. for X in distribs_pymc:
  64. samps = sample(X, size=size, library='pymc')
  65. for sam in samps:
  66. assert sam in X.pspace.domain.set
  67. raises(NotImplementedError,
  68. lambda: sample(Skellam('S', 1, 1), library='pymc'))
  69. @slow
  70. def test_sample_discrete():
  71. X = Geometric('X', S.Half)
  72. scipy = import_module('scipy')
  73. if not scipy:
  74. skip('Scipy not installed. Abort tests')
  75. assert sample(X) in X.pspace.domain.set
  76. samps = sample(X, size=2) # This takes long time if ran without scipy
  77. for samp in samps:
  78. assert samp in X.pspace.domain.set
  79. libraries = ['scipy', 'numpy', 'pymc']
  80. for lib in libraries:
  81. try:
  82. imported_lib = import_module(lib)
  83. if imported_lib:
  84. s0, s1, s2 = [], [], []
  85. s0 = sample(X, size=10, library=lib, seed=0)
  86. s1 = sample(X, size=10, library=lib, seed=0)
  87. s2 = sample(X, size=10, library=lib, seed=1)
  88. assert all(s0 == s1)
  89. assert not all(s1 == s2)
  90. except NotImplementedError:
  91. continue