sample_pymc.py 2.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899
  1. from functools import singledispatch
  2. from sympy.external import import_module
  3. from sympy.stats.crv_types import BetaDistribution, CauchyDistribution, ChiSquaredDistribution, ExponentialDistribution, \
  4. GammaDistribution, LogNormalDistribution, NormalDistribution, ParetoDistribution, UniformDistribution, \
  5. GaussianInverseDistribution
  6. from sympy.stats.drv_types import PoissonDistribution, GeometricDistribution, NegativeBinomialDistribution
  7. from sympy.stats.frv_types import BinomialDistribution, BernoulliDistribution
  8. try:
  9. import pymc
  10. except ImportError:
  11. pymc = import_module('pymc3')
  12. @singledispatch
  13. def do_sample_pymc(dist):
  14. return None
  15. # CRV:
  16. @do_sample_pymc.register(BetaDistribution)
  17. def _(dist: BetaDistribution):
  18. return pymc.Beta('X', alpha=float(dist.alpha), beta=float(dist.beta))
  19. @do_sample_pymc.register(CauchyDistribution)
  20. def _(dist: CauchyDistribution):
  21. return pymc.Cauchy('X', alpha=float(dist.x0), beta=float(dist.gamma))
  22. @do_sample_pymc.register(ChiSquaredDistribution)
  23. def _(dist: ChiSquaredDistribution):
  24. return pymc.ChiSquared('X', nu=float(dist.k))
  25. @do_sample_pymc.register(ExponentialDistribution)
  26. def _(dist: ExponentialDistribution):
  27. return pymc.Exponential('X', lam=float(dist.rate))
  28. @do_sample_pymc.register(GammaDistribution)
  29. def _(dist: GammaDistribution):
  30. return pymc.Gamma('X', alpha=float(dist.k), beta=1 / float(dist.theta))
  31. @do_sample_pymc.register(LogNormalDistribution)
  32. def _(dist: LogNormalDistribution):
  33. return pymc.Lognormal('X', mu=float(dist.mean), sigma=float(dist.std))
  34. @do_sample_pymc.register(NormalDistribution)
  35. def _(dist: NormalDistribution):
  36. return pymc.Normal('X', float(dist.mean), float(dist.std))
  37. @do_sample_pymc.register(GaussianInverseDistribution)
  38. def _(dist: GaussianInverseDistribution):
  39. return pymc.Wald('X', mu=float(dist.mean), lam=float(dist.shape))
  40. @do_sample_pymc.register(ParetoDistribution)
  41. def _(dist: ParetoDistribution):
  42. return pymc.Pareto('X', alpha=float(dist.alpha), m=float(dist.xm))
  43. @do_sample_pymc.register(UniformDistribution)
  44. def _(dist: UniformDistribution):
  45. return pymc.Uniform('X', lower=float(dist.left), upper=float(dist.right))
  46. # DRV:
  47. @do_sample_pymc.register(GeometricDistribution)
  48. def _(dist: GeometricDistribution):
  49. return pymc.Geometric('X', p=float(dist.p))
  50. @do_sample_pymc.register(NegativeBinomialDistribution)
  51. def _(dist: NegativeBinomialDistribution):
  52. return pymc.NegativeBinomial('X', mu=float((dist.p * dist.r) / (1 - dist.p)),
  53. alpha=float(dist.r))
  54. @do_sample_pymc.register(PoissonDistribution)
  55. def _(dist: PoissonDistribution):
  56. return pymc.Poisson('X', mu=float(dist.lamda))
  57. # FRV:
  58. @do_sample_pymc.register(BernoulliDistribution)
  59. def _(dist: BernoulliDistribution):
  60. return pymc.Bernoulli('X', p=float(dist.p))
  61. @do_sample_pymc.register(BinomialDistribution)
  62. def _(dist: BinomialDistribution):
  63. return pymc.Binomial('X', n=int(dist.n), p=float(dist.p))