sample_scipy.py 6.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167
  1. from functools import singledispatch
  2. from sympy.core.symbol import Dummy
  3. from sympy.functions.elementary.exponential import exp
  4. from sympy.utilities.lambdify import lambdify
  5. from sympy.external import import_module
  6. from sympy.stats import DiscreteDistributionHandmade
  7. from sympy.stats.crv import SingleContinuousDistribution
  8. from sympy.stats.crv_types import ChiSquaredDistribution, ExponentialDistribution, GammaDistribution, \
  9. LogNormalDistribution, NormalDistribution, ParetoDistribution, UniformDistribution, BetaDistribution, \
  10. StudentTDistribution, CauchyDistribution
  11. from sympy.stats.drv_types import GeometricDistribution, LogarithmicDistribution, NegativeBinomialDistribution, \
  12. PoissonDistribution, SkellamDistribution, YuleSimonDistribution, ZetaDistribution
  13. from sympy.stats.frv import SingleFiniteDistribution
  14. scipy = import_module("scipy", import_kwargs={'fromlist':['stats']})
  15. @singledispatch
  16. def do_sample_scipy(dist, size, seed):
  17. return None
# Continuous random variables (CRV)
  19. @do_sample_scipy.register(SingleContinuousDistribution)
  20. def _(dist: SingleContinuousDistribution, size, seed):
  21. # if we don't need to make a handmade pdf, we won't
  22. import scipy.stats
  23. z = Dummy('z')
  24. handmade_pdf = lambdify(z, dist.pdf(z), ['numpy', 'scipy'])
  25. class scipy_pdf(scipy.stats.rv_continuous):
  26. def _pdf(dist, x):
  27. return handmade_pdf(x)
  28. scipy_rv = scipy_pdf(a=float(dist.set._inf),
  29. b=float(dist.set._sup), name='scipy_pdf')
  30. return scipy_rv.rvs(size=size, random_state=seed)
  31. @do_sample_scipy.register(ChiSquaredDistribution)
  32. def _(dist: ChiSquaredDistribution, size, seed):
  33. # same parametrisation
  34. return scipy.stats.chi2.rvs(df=float(dist.k), size=size, random_state=seed)
  35. @do_sample_scipy.register(ExponentialDistribution)
  36. def _(dist: ExponentialDistribution, size, seed):
  37. # https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.expon.html#scipy.stats.expon
  38. return scipy.stats.expon.rvs(scale=1 / float(dist.rate), size=size, random_state=seed)
  39. @do_sample_scipy.register(GammaDistribution)
  40. def _(dist: GammaDistribution, size, seed):
  41. # https://stackoverflow.com/questions/42150965/how-to-plot-gamma-distribution-with-alpha-and-beta-parameters-in-python
  42. return scipy.stats.gamma.rvs(a=float(dist.k), scale=float(dist.theta), size=size, random_state=seed)
  43. @do_sample_scipy.register(LogNormalDistribution)
  44. def _(dist: LogNormalDistribution, size, seed):
  45. # https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.lognorm.html
  46. return scipy.stats.lognorm.rvs(scale=float(exp(dist.mean)), s=float(dist.std), size=size, random_state=seed)
  47. @do_sample_scipy.register(NormalDistribution)
  48. def _(dist: NormalDistribution, size, seed):
  49. return scipy.stats.norm.rvs(loc=float(dist.mean), scale=float(dist.std), size=size, random_state=seed)
  50. @do_sample_scipy.register(ParetoDistribution)
  51. def _(dist: ParetoDistribution, size, seed):
  52. # https://stackoverflow.com/questions/42260519/defining-pareto-distribution-in-python-scipy
  53. return scipy.stats.pareto.rvs(b=float(dist.alpha), scale=float(dist.xm), size=size, random_state=seed)
  54. @do_sample_scipy.register(StudentTDistribution)
  55. def _(dist: StudentTDistribution, size, seed):
  56. return scipy.stats.t.rvs(df=float(dist.nu), size=size, random_state=seed)
  57. @do_sample_scipy.register(UniformDistribution)
  58. def _(dist: UniformDistribution, size, seed):
  59. # https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.uniform.html
  60. return scipy.stats.uniform.rvs(loc=float(dist.left), scale=float(dist.right - dist.left), size=size, random_state=seed)
  61. @do_sample_scipy.register(BetaDistribution)
  62. def _(dist: BetaDistribution, size, seed):
  63. # same parametrisation
  64. return scipy.stats.beta.rvs(a=float(dist.alpha), b=float(dist.beta), size=size, random_state=seed)
  65. @do_sample_scipy.register(CauchyDistribution)
  66. def _(dist: CauchyDistribution, size, seed):
  67. return scipy.stats.cauchy.rvs(loc=float(dist.x0), scale=float(dist.gamma), size=size, random_state=seed)
# Discrete random variables (DRV):
  69. @do_sample_scipy.register(DiscreteDistributionHandmade)
  70. def _(dist: DiscreteDistributionHandmade, size, seed):
  71. from scipy.stats import rv_discrete
  72. z = Dummy('z')
  73. handmade_pmf = lambdify(z, dist.pdf(z), ['numpy', 'scipy'])
  74. class scipy_pmf(rv_discrete):
  75. def _pmf(dist, x):
  76. return handmade_pmf(x)
  77. scipy_rv = scipy_pmf(a=float(dist.set._inf), b=float(dist.set._sup),
  78. name='scipy_pmf')
  79. return scipy_rv.rvs(size=size, random_state=seed)
  80. @do_sample_scipy.register(GeometricDistribution)
  81. def _(dist: GeometricDistribution, size, seed):
  82. return scipy.stats.geom.rvs(p=float(dist.p), size=size, random_state=seed)
  83. @do_sample_scipy.register(LogarithmicDistribution)
  84. def _(dist: LogarithmicDistribution, size, seed):
  85. return scipy.stats.logser.rvs(p=float(dist.p), size=size, random_state=seed)
  86. @do_sample_scipy.register(NegativeBinomialDistribution)
  87. def _(dist: NegativeBinomialDistribution, size, seed):
  88. return scipy.stats.nbinom.rvs(n=float(dist.r), p=float(dist.p), size=size, random_state=seed)
  89. @do_sample_scipy.register(PoissonDistribution)
  90. def _(dist: PoissonDistribution, size, seed):
  91. return scipy.stats.poisson.rvs(mu=float(dist.lamda), size=size, random_state=seed)
  92. @do_sample_scipy.register(SkellamDistribution)
  93. def _(dist: SkellamDistribution, size, seed):
  94. return scipy.stats.skellam.rvs(mu1=float(dist.mu1), mu2=float(dist.mu2), size=size, random_state=seed)
  95. @do_sample_scipy.register(YuleSimonDistribution)
  96. def _(dist: YuleSimonDistribution, size, seed):
  97. return scipy.stats.yulesimon.rvs(alpha=float(dist.rho), size=size, random_state=seed)
  98. @do_sample_scipy.register(ZetaDistribution)
  99. def _(dist: ZetaDistribution, size, seed):
  100. return scipy.stats.zipf.rvs(a=float(dist.s), size=size, random_state=seed)
# Finite random variables (FRV):
  102. @do_sample_scipy.register(SingleFiniteDistribution)
  103. def _(dist: SingleFiniteDistribution, size, seed):
  104. # scipy can handle with custom distributions
  105. from scipy.stats import rv_discrete
  106. density_ = dist.dict
  107. x, y = [], []
  108. for k, v in density_.items():
  109. x.append(int(k))
  110. y.append(float(v))
  111. scipy_rv = rv_discrete(name='scipy_rv', values=(x, y))
  112. return scipy_rv.rvs(size=size, random_state=seed)