rv.py 53 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
7177817791780178117821783178417851786178717881789179017911792
  1. """
  2. Main Random Variables Module
  3. Defines abstract random variable type.
  4. Contains interfaces for probability space object (PSpace) as well as standard
  5. operators, P, E, sample, density, where, quantile
  6. See Also
  7. ========
  8. sympy.stats.crv
  9. sympy.stats.frv
  10. sympy.stats.rv_interface
  11. """
  12. from __future__ import annotations
  13. from functools import singledispatch
  14. from math import prod
  15. from sympy.core.add import Add
  16. from sympy.core.basic import Basic
  17. from sympy.core.containers import Tuple
  18. from sympy.core.expr import Expr
  19. from sympy.core.function import (Function, Lambda)
  20. from sympy.core.logic import fuzzy_and
  21. from sympy.core.mul import Mul
  22. from sympy.core.relational import (Eq, Ne)
  23. from sympy.core.singleton import S
  24. from sympy.core.symbol import (Dummy, Symbol)
  25. from sympy.core.sympify import sympify
  26. from sympy.functions.special.delta_functions import DiracDelta
  27. from sympy.functions.special.tensor_functions import KroneckerDelta
  28. from sympy.logic.boolalg import (And, Or)
  29. from sympy.matrices.expressions.matexpr import MatrixSymbol
  30. from sympy.tensor.indexed import Indexed
  31. from sympy.utilities.lambdify import lambdify
  32. from sympy.core.relational import Relational
  33. from sympy.core.sympify import _sympify
  34. from sympy.sets.sets import FiniteSet, ProductSet, Intersection
  35. from sympy.solvers.solveset import solveset
  36. from sympy.external import import_module
  37. from sympy.utilities.decorator import doctest_depends_on
  38. from sympy.utilities.exceptions import sympy_deprecation_warning
  39. from sympy.utilities.iterables import iterable
  40. x = Symbol('x')
  41. @singledispatch
  42. def is_random(x):
  43. return False
  44. @is_random.register(Basic)
  45. def _(x):
  46. atoms = x.free_symbols
  47. return any(is_random(i) for i in atoms)
  48. class RandomDomain(Basic):
  49. """
  50. Represents a set of variables and the values which they can take.
  51. See Also
  52. ========
  53. sympy.stats.crv.ContinuousDomain
  54. sympy.stats.frv.FiniteDomain
  55. """
  56. is_ProductDomain = False
  57. is_Finite = False
  58. is_Continuous = False
  59. is_Discrete = False
  60. def __new__(cls, symbols, *args):
  61. symbols = FiniteSet(*symbols)
  62. return Basic.__new__(cls, symbols, *args)
  63. @property
  64. def symbols(self):
  65. return self.args[0]
  66. @property
  67. def set(self):
  68. return self.args[1]
  69. def __contains__(self, other):
  70. raise NotImplementedError()
  71. def compute_expectation(self, expr):
  72. raise NotImplementedError()
  73. class SingleDomain(RandomDomain):
  74. """
  75. A single variable and its domain.
  76. See Also
  77. ========
  78. sympy.stats.crv.SingleContinuousDomain
  79. sympy.stats.frv.SingleFiniteDomain
  80. """
  81. def __new__(cls, symbol, set):
  82. assert symbol.is_Symbol
  83. return Basic.__new__(cls, symbol, set)
  84. @property
  85. def symbol(self):
  86. return self.args[0]
  87. @property
  88. def symbols(self):
  89. return FiniteSet(self.symbol)
  90. def __contains__(self, other):
  91. if len(other) != 1:
  92. return False
  93. sym, val = tuple(other)[0]
  94. return self.symbol == sym and val in self.set
  95. class MatrixDomain(RandomDomain):
  96. """
  97. A Random Matrix variable and its domain.
  98. """
  99. def __new__(cls, symbol, set):
  100. symbol, set = _symbol_converter(symbol), _sympify(set)
  101. return Basic.__new__(cls, symbol, set)
  102. @property
  103. def symbol(self):
  104. return self.args[0]
  105. @property
  106. def symbols(self):
  107. return FiniteSet(self.symbol)
  108. class ConditionalDomain(RandomDomain):
  109. """
  110. A RandomDomain with an attached condition.
  111. See Also
  112. ========
  113. sympy.stats.crv.ConditionalContinuousDomain
  114. sympy.stats.frv.ConditionalFiniteDomain
  115. """
  116. def __new__(cls, fulldomain, condition):
  117. condition = condition.xreplace({rs: rs.symbol
  118. for rs in random_symbols(condition)})
  119. return Basic.__new__(cls, fulldomain, condition)
  120. @property
  121. def symbols(self):
  122. return self.fulldomain.symbols
  123. @property
  124. def fulldomain(self):
  125. return self.args[0]
  126. @property
  127. def condition(self):
  128. return self.args[1]
  129. @property
  130. def set(self):
  131. raise NotImplementedError("Set of Conditional Domain not Implemented")
  132. def as_boolean(self):
  133. return And(self.fulldomain.as_boolean(), self.condition)
  134. class PSpace(Basic):
  135. """
  136. A Probability Space.
  137. Explanation
  138. ===========
  139. Probability Spaces encode processes that equal different values
  140. probabilistically. These underly Random Symbols which occur in SymPy
  141. expressions and contain the mechanics to evaluate statistical statements.
  142. See Also
  143. ========
  144. sympy.stats.crv.ContinuousPSpace
  145. sympy.stats.frv.FinitePSpace
  146. """
  147. is_Finite = None # type: bool
  148. is_Continuous = None # type: bool
  149. is_Discrete = None # type: bool
  150. is_real = None # type: bool
  151. @property
  152. def domain(self):
  153. return self.args[0]
  154. @property
  155. def density(self):
  156. return self.args[1]
  157. @property
  158. def values(self):
  159. return frozenset(RandomSymbol(sym, self) for sym in self.symbols)
  160. @property
  161. def symbols(self):
  162. return self.domain.symbols
  163. def where(self, condition):
  164. raise NotImplementedError()
  165. def compute_density(self, expr):
  166. raise NotImplementedError()
  167. def sample(self, size=(), library='scipy', seed=None):
  168. raise NotImplementedError()
  169. def probability(self, condition):
  170. raise NotImplementedError()
  171. def compute_expectation(self, expr):
  172. raise NotImplementedError()
  173. class SinglePSpace(PSpace):
  174. """
  175. Represents the probabilities of a set of random events that can be
  176. attributed to a single variable/symbol.
  177. """
  178. def __new__(cls, s, distribution):
  179. s = _symbol_converter(s)
  180. return Basic.__new__(cls, s, distribution)
  181. @property
  182. def value(self):
  183. return RandomSymbol(self.symbol, self)
  184. @property
  185. def symbol(self):
  186. return self.args[0]
  187. @property
  188. def distribution(self):
  189. return self.args[1]
  190. @property
  191. def pdf(self):
  192. return self.distribution.pdf(self.symbol)
  193. class RandomSymbol(Expr):
  194. """
  195. Random Symbols represent ProbabilitySpaces in SymPy Expressions.
  196. In principle they can take on any value that their symbol can take on
  197. within the associated PSpace with probability determined by the PSpace
  198. Density.
  199. Explanation
  200. ===========
  201. Random Symbols contain pspace and symbol properties.
  202. The pspace property points to the represented Probability Space
  203. The symbol is a standard SymPy Symbol that is used in that probability space
  204. for example in defining a density.
  205. You can form normal SymPy expressions using RandomSymbols and operate on
  206. those expressions with the Functions
  207. E - Expectation of a random expression
  208. P - Probability of a condition
  209. density - Probability Density of an expression
  210. given - A new random expression (with new random symbols) given a condition
  211. An object of the RandomSymbol type should almost never be created by the
  212. user. They tend to be created instead by the PSpace class's value method.
  213. Traditionally a user does not even do this but instead calls one of the
  214. convenience functions Normal, Exponential, Coin, Die, FiniteRV, etc....
  215. """
  216. def __new__(cls, symbol, pspace=None):
  217. from sympy.stats.joint_rv import JointRandomSymbol
  218. if pspace is None:
  219. # Allow single arg, representing pspace == PSpace()
  220. pspace = PSpace()
  221. symbol = _symbol_converter(symbol)
  222. if not isinstance(pspace, PSpace):
  223. raise TypeError("pspace variable should be of type PSpace")
  224. if cls == JointRandomSymbol and isinstance(pspace, SinglePSpace):
  225. cls = RandomSymbol
  226. return Basic.__new__(cls, symbol, pspace)
  227. is_finite = True
  228. is_symbol = True
  229. is_Atom = True
  230. _diff_wrt = True
  231. pspace = property(lambda self: self.args[1])
  232. symbol = property(lambda self: self.args[0])
  233. name = property(lambda self: self.symbol.name)
  234. def _eval_is_positive(self):
  235. return self.symbol.is_positive
  236. def _eval_is_integer(self):
  237. return self.symbol.is_integer
  238. def _eval_is_real(self):
  239. return self.symbol.is_real or self.pspace.is_real
  240. @property
  241. def is_commutative(self):
  242. return self.symbol.is_commutative
  243. @property
  244. def free_symbols(self):
  245. return {self}
  246. class RandomIndexedSymbol(RandomSymbol):
  247. def __new__(cls, idx_obj, pspace=None):
  248. if pspace is None:
  249. # Allow single arg, representing pspace == PSpace()
  250. pspace = PSpace()
  251. if not isinstance(idx_obj, (Indexed, Function)):
  252. raise TypeError("An Function or Indexed object is expected not %s"%(idx_obj))
  253. return Basic.__new__(cls, idx_obj, pspace)
  254. symbol = property(lambda self: self.args[0])
  255. name = property(lambda self: str(self.args[0]))
  256. @property
  257. def key(self):
  258. if isinstance(self.symbol, Indexed):
  259. return self.symbol.args[1]
  260. elif isinstance(self.symbol, Function):
  261. return self.symbol.args[0]
  262. @property
  263. def free_symbols(self):
  264. if self.key.free_symbols:
  265. free_syms = self.key.free_symbols
  266. free_syms.add(self)
  267. return free_syms
  268. return {self}
  269. @property
  270. def pspace(self):
  271. return self.args[1]
  272. class RandomMatrixSymbol(RandomSymbol, MatrixSymbol): # type: ignore
  273. def __new__(cls, symbol, n, m, pspace=None):
  274. n, m = _sympify(n), _sympify(m)
  275. symbol = _symbol_converter(symbol)
  276. if pspace is None:
  277. # Allow single arg, representing pspace == PSpace()
  278. pspace = PSpace()
  279. return Basic.__new__(cls, symbol, n, m, pspace)
  280. symbol = property(lambda self: self.args[0])
  281. pspace = property(lambda self: self.args[3])
  282. class ProductPSpace(PSpace):
  283. """
  284. Abstract class for representing probability spaces with multiple random
  285. variables.
  286. See Also
  287. ========
  288. sympy.stats.rv.IndependentProductPSpace
  289. sympy.stats.joint_rv.JointPSpace
  290. """
  291. pass
  292. class IndependentProductPSpace(ProductPSpace):
  293. """
  294. A probability space resulting from the merger of two independent probability
  295. spaces.
  296. Often created using the function, pspace.
  297. """
  298. def __new__(cls, *spaces):
  299. rs_space_dict = {}
  300. for space in spaces:
  301. for value in space.values:
  302. rs_space_dict[value] = space
  303. symbols = FiniteSet(*[val.symbol for val in rs_space_dict.keys()])
  304. # Overlapping symbols
  305. from sympy.stats.joint_rv import MarginalDistribution
  306. from sympy.stats.compound_rv import CompoundDistribution
  307. if len(symbols) < sum(len(space.symbols) for space in spaces if not
  308. isinstance(space.distribution, (
  309. CompoundDistribution, MarginalDistribution))):
  310. raise ValueError("Overlapping Random Variables")
  311. if all(space.is_Finite for space in spaces):
  312. from sympy.stats.frv import ProductFinitePSpace
  313. cls = ProductFinitePSpace
  314. obj = Basic.__new__(cls, *FiniteSet(*spaces))
  315. return obj
  316. @property
  317. def pdf(self):
  318. p = Mul(*[space.pdf for space in self.spaces])
  319. return p.subs({rv: rv.symbol for rv in self.values})
  320. @property
  321. def rs_space_dict(self):
  322. d = {}
  323. for space in self.spaces:
  324. for value in space.values:
  325. d[value] = space
  326. return d
  327. @property
  328. def symbols(self):
  329. return FiniteSet(*[val.symbol for val in self.rs_space_dict.keys()])
  330. @property
  331. def spaces(self):
  332. return FiniteSet(*self.args)
  333. @property
  334. def values(self):
  335. return sumsets(space.values for space in self.spaces)
  336. def compute_expectation(self, expr, rvs=None, evaluate=False, **kwargs):
  337. rvs = rvs or self.values
  338. rvs = frozenset(rvs)
  339. for space in self.spaces:
  340. expr = space.compute_expectation(expr, rvs & space.values, evaluate=False, **kwargs)
  341. if evaluate and hasattr(expr, 'doit'):
  342. return expr.doit(**kwargs)
  343. return expr
  344. @property
  345. def domain(self):
  346. return ProductDomain(*[space.domain for space in self.spaces])
  347. @property
  348. def density(self):
  349. raise NotImplementedError("Density not available for ProductSpaces")
  350. def sample(self, size=(), library='scipy', seed=None):
  351. return {k: v for space in self.spaces
  352. for k, v in space.sample(size=size, library=library, seed=seed).items()}
  353. def probability(self, condition, **kwargs):
  354. cond_inv = False
  355. if isinstance(condition, Ne):
  356. condition = Eq(condition.args[0], condition.args[1])
  357. cond_inv = True
  358. elif isinstance(condition, And): # they are independent
  359. return Mul(*[self.probability(arg) for arg in condition.args])
  360. elif isinstance(condition, Or): # they are independent
  361. return Add(*[self.probability(arg) for arg in condition.args])
  362. expr = condition.lhs - condition.rhs
  363. rvs = random_symbols(expr)
  364. dens = self.compute_density(expr)
  365. if any(pspace(rv).is_Continuous for rv in rvs):
  366. from sympy.stats.crv import SingleContinuousPSpace
  367. from sympy.stats.crv_types import ContinuousDistributionHandmade
  368. if expr in self.values:
  369. # Marginalize all other random symbols out of the density
  370. randomsymbols = tuple(set(self.values) - frozenset([expr]))
  371. symbols = tuple(rs.symbol for rs in randomsymbols)
  372. pdf = self.domain.integrate(self.pdf, symbols, **kwargs)
  373. return Lambda(expr.symbol, pdf)
  374. dens = ContinuousDistributionHandmade(dens)
  375. z = Dummy('z', real=True)
  376. space = SingleContinuousPSpace(z, dens)
  377. result = space.probability(condition.__class__(space.value, 0))
  378. else:
  379. from sympy.stats.drv import SingleDiscretePSpace
  380. from sympy.stats.drv_types import DiscreteDistributionHandmade
  381. dens = DiscreteDistributionHandmade(dens)
  382. z = Dummy('z', integer=True)
  383. space = SingleDiscretePSpace(z, dens)
  384. result = space.probability(condition.__class__(space.value, 0))
  385. return result if not cond_inv else S.One - result
  386. def compute_density(self, expr, **kwargs):
  387. rvs = random_symbols(expr)
  388. if any(pspace(rv).is_Continuous for rv in rvs):
  389. z = Dummy('z', real=True)
  390. expr = self.compute_expectation(DiracDelta(expr - z),
  391. **kwargs)
  392. else:
  393. z = Dummy('z', integer=True)
  394. expr = self.compute_expectation(KroneckerDelta(expr, z),
  395. **kwargs)
  396. return Lambda(z, expr)
  397. def compute_cdf(self, expr, **kwargs):
  398. raise ValueError("CDF not well defined on multivariate expressions")
  399. def conditional_space(self, condition, normalize=True, **kwargs):
  400. rvs = random_symbols(condition)
  401. condition = condition.xreplace({rv: rv.symbol for rv in self.values})
  402. pspaces = [pspace(rv) for rv in rvs]
  403. if any(ps.is_Continuous for ps in pspaces):
  404. from sympy.stats.crv import (ConditionalContinuousDomain,
  405. ContinuousPSpace)
  406. space = ContinuousPSpace
  407. domain = ConditionalContinuousDomain(self.domain, condition)
  408. elif any(ps.is_Discrete for ps in pspaces):
  409. from sympy.stats.drv import (ConditionalDiscreteDomain,
  410. DiscretePSpace)
  411. space = DiscretePSpace
  412. domain = ConditionalDiscreteDomain(self.domain, condition)
  413. elif all(ps.is_Finite for ps in pspaces):
  414. from sympy.stats.frv import FinitePSpace
  415. return FinitePSpace.conditional_space(self, condition)
  416. if normalize:
  417. replacement = {rv: Dummy(str(rv)) for rv in self.symbols}
  418. norm = domain.compute_expectation(self.pdf, **kwargs)
  419. pdf = self.pdf / norm.xreplace(replacement)
  420. # XXX: Converting symbols from set to tuple. The order matters to
  421. # Lambda though so we shouldn't be starting with a set here...
  422. density = Lambda(tuple(domain.symbols), pdf)
  423. return space(domain, density)
  424. class ProductDomain(RandomDomain):
  425. """
  426. A domain resulting from the merger of two independent domains.
  427. See Also
  428. ========
  429. sympy.stats.crv.ProductContinuousDomain
  430. sympy.stats.frv.ProductFiniteDomain
  431. """
  432. is_ProductDomain = True
  433. def __new__(cls, *domains):
  434. # Flatten any product of products
  435. domains2 = []
  436. for domain in domains:
  437. if not domain.is_ProductDomain:
  438. domains2.append(domain)
  439. else:
  440. domains2.extend(domain.domains)
  441. domains2 = FiniteSet(*domains2)
  442. if all(domain.is_Finite for domain in domains2):
  443. from sympy.stats.frv import ProductFiniteDomain
  444. cls = ProductFiniteDomain
  445. if all(domain.is_Continuous for domain in domains2):
  446. from sympy.stats.crv import ProductContinuousDomain
  447. cls = ProductContinuousDomain
  448. if all(domain.is_Discrete for domain in domains2):
  449. from sympy.stats.drv import ProductDiscreteDomain
  450. cls = ProductDiscreteDomain
  451. return Basic.__new__(cls, *domains2)
  452. @property
  453. def sym_domain_dict(self):
  454. return {symbol: domain for domain in self.domains
  455. for symbol in domain.symbols}
  456. @property
  457. def symbols(self):
  458. return FiniteSet(*[sym for domain in self.domains
  459. for sym in domain.symbols])
  460. @property
  461. def domains(self):
  462. return self.args
  463. @property
  464. def set(self):
  465. return ProductSet(*(domain.set for domain in self.domains))
  466. def __contains__(self, other):
  467. # Split event into each subdomain
  468. for domain in self.domains:
  469. # Collect the parts of this event which associate to this domain
  470. elem = frozenset([item for item in other
  471. if sympify(domain.symbols.contains(item[0]))
  472. is S.true])
  473. # Test this sub-event
  474. if elem not in domain:
  475. return False
  476. # All subevents passed
  477. return True
  478. def as_boolean(self):
  479. return And(*[domain.as_boolean() for domain in self.domains])
  480. def random_symbols(expr):
  481. """
  482. Returns all RandomSymbols within a SymPy Expression.
  483. """
  484. atoms = getattr(expr, 'atoms', None)
  485. if atoms is not None:
  486. comp = lambda rv: rv.symbol.name
  487. l = list(atoms(RandomSymbol))
  488. return sorted(l, key=comp)
  489. else:
  490. return []
  491. def pspace(expr):
  492. """
  493. Returns the underlying Probability Space of a random expression.
  494. For internal use.
  495. Examples
  496. ========
  497. >>> from sympy.stats import pspace, Normal
  498. >>> X = Normal('X', 0, 1)
  499. >>> pspace(2*X + 1) == X.pspace
  500. True
  501. """
  502. expr = sympify(expr)
  503. if isinstance(expr, RandomSymbol) and expr.pspace is not None:
  504. return expr.pspace
  505. if expr.has(RandomMatrixSymbol):
  506. rm = list(expr.atoms(RandomMatrixSymbol))[0]
  507. return rm.pspace
  508. rvs = random_symbols(expr)
  509. if not rvs:
  510. raise ValueError("Expression containing Random Variable expected, not %s" % (expr))
  511. # If only one space present
  512. if all(rv.pspace == rvs[0].pspace for rv in rvs):
  513. return rvs[0].pspace
  514. from sympy.stats.compound_rv import CompoundPSpace
  515. from sympy.stats.stochastic_process import StochasticPSpace
  516. for rv in rvs:
  517. if isinstance(rv.pspace, (CompoundPSpace, StochasticPSpace)):
  518. return rv.pspace
  519. # Otherwise make a product space
  520. return IndependentProductPSpace(*[rv.pspace for rv in rvs])
  521. def sumsets(sets):
  522. """
  523. Union of sets
  524. """
  525. return frozenset().union(*sets)
  526. def rs_swap(a, b):
  527. """
  528. Build a dictionary to swap RandomSymbols based on their underlying symbol.
  529. i.e.
  530. if ``X = ('x', pspace1)``
  531. and ``Y = ('x', pspace2)``
  532. then ``X`` and ``Y`` match and the key, value pair
  533. ``{X:Y}`` will appear in the result
  534. Inputs: collections a and b of random variables which share common symbols
  535. Output: dict mapping RVs in a to RVs in b
  536. """
  537. d = {}
  538. for rsa in a:
  539. d[rsa] = [rsb for rsb in b if rsa.symbol == rsb.symbol][0]
  540. return d
def given(expr, condition=None, **kwargs):
    r""" Conditional Random Expression.

    Explanation
    ===========

    From a random expression and a condition on that expression creates a new
    probability space from the condition and returns the same expression on that
    conditional probability space.

    Parameters
    ==========

    expr : random expression to condition
    condition : condition (or random symbol) to condition on; when it is not
        random, or is independent of ``expr``, ``expr`` is returned unchanged
    **kwargs : forwarded to the product space's ``conditional_space``

    Examples
    ========

    >>> from sympy.stats import given, density, Die
    >>> X = Die('X', 6)
    >>> Y = given(X, X > 3)
    >>> density(Y).dict
    {4: 1/3, 5: 1/3, 6: 1/3}

    Following convention, if the condition is a random symbol then that symbol
    is considered fixed.

    >>> from sympy.stats import Normal
    >>> from sympy import pprint
    >>> from sympy.abc import z

    >>> X = Normal('X', 0, 1)
    >>> Y = Normal('Y', 0, 1)
    >>> pprint(density(X + Y, Y)(z), use_unicode=False)
                    2
           -(-Y + z)
           -----------
      ___       2
    \/ 2 *e
    ------------------
             ____
         2*\/ pi
    """

    # Nothing to condition on: the expression is unaffected.
    if not is_random(condition) or pspace_independent(expr, condition):
        return expr

    # Conditioning on a random symbol means fixing it to its plain symbol.
    if isinstance(condition, RandomSymbol):
        condition = Eq(condition, condition.symbol)

    condsymbols = random_symbols(condition)
    # Shortcut for an equality involving a single random symbol (when the
    # expression's space is not already conditional): solve the equality for
    # that symbol and substitute each solution into ``expr``.
    if (isinstance(condition, Eq) and len(condsymbols) == 1 and
        not isinstance(pspace(expr).domain, ConditionalDomain)):
        rv = tuple(condsymbols)[0]

        results = solveset(condition, rv)
        # Unwrap Intersection(Reals, S) to the elements of S.
        # NOTE(review): assumes S.Reals is args[0] so args[1] is the other
        # set — confirm against Intersection's argument ordering.
        if isinstance(results, Intersection) and S.Reals in results.args:
            results = list(results.args[1])

        sums = 0
        for res in results:
            temp = expr.subs(rv, res)
            if temp == True:
                return True
            if temp != False:
                # XXX: This seems nonsensical but preserves existing behaviour
                # after the change that Relational is no longer a subclass of
                # Expr. Here expr is sometimes Relational and sometimes Expr
                # but we are trying to add them with +=. This needs to be
                # fixed somehow.
                if sums == 0 and isinstance(expr, Relational):
                    sums = expr.subs(rv, res)
                else:
                    sums += expr.subs(rv, res)
        # No solution contributed anything: report the event as False.
        if sums == 0:
            return False
        return sums

    # Get full probability space of both the expression and the condition
    fullspace = pspace(Tuple(expr, condition))
    # Build new space given the condition
    space = fullspace.conditional_space(condition, **kwargs)
    # Dictionary to swap out RandomSymbols in expr with new RandomSymbols
    # That point to the new conditional space
    swapdict = rs_swap(fullspace.values, space.values)
    # Swap random variables in the expression
    expr = expr.xreplace(swapdict)
    return expr
  611. def expectation(expr, condition=None, numsamples=None, evaluate=True, **kwargs):
  612. """
  613. Returns the expected value of a random expression.
  614. Parameters
  615. ==========
  616. expr : Expr containing RandomSymbols
  617. The expression of which you want to compute the expectation value
  618. given : Expr containing RandomSymbols
  619. A conditional expression. E(X, X>0) is expectation of X given X > 0
  620. numsamples : int
  621. Enables sampling and approximates the expectation with this many samples
  622. evalf : Bool (defaults to True)
  623. If sampling return a number rather than a complex expression
  624. evaluate : Bool (defaults to True)
  625. In case of continuous systems return unevaluated integral
  626. Examples
  627. ========
  628. >>> from sympy.stats import E, Die
  629. >>> X = Die('X', 6)
  630. >>> E(X)
  631. 7/2
  632. >>> E(2*X + 1)
  633. 8
  634. >>> E(X, X > 3) # Expectation of X given that it is above 3
  635. 5
  636. """
  637. if not is_random(expr): # expr isn't random?
  638. return expr
  639. kwargs['numsamples'] = numsamples
  640. from sympy.stats.symbolic_probability import Expectation
  641. if evaluate:
  642. return Expectation(expr, condition).doit(**kwargs)
  643. return Expectation(expr, condition)
  644. def probability(condition, given_condition=None, numsamples=None,
  645. evaluate=True, **kwargs):
  646. """
  647. Probability that a condition is true, optionally given a second condition.
  648. Parameters
  649. ==========
  650. condition : Combination of Relationals containing RandomSymbols
  651. The condition of which you want to compute the probability
  652. given_condition : Combination of Relationals containing RandomSymbols
  653. A conditional expression. P(X > 1, X > 0) is expectation of X > 1
  654. given X > 0
  655. numsamples : int
  656. Enables sampling and approximates the probability with this many samples
  657. evaluate : Bool (defaults to True)
  658. In case of continuous systems return unevaluated integral
  659. Examples
  660. ========
  661. >>> from sympy.stats import P, Die
  662. >>> from sympy import Eq
  663. >>> X, Y = Die('X', 6), Die('Y', 6)
  664. >>> P(X > 3)
  665. 1/2
  666. >>> P(Eq(X, 5), X > 2) # Probability that X == 5 given that X > 2
  667. 1/4
  668. >>> P(X > Y)
  669. 5/12
  670. """
  671. kwargs['numsamples'] = numsamples
  672. from sympy.stats.symbolic_probability import Probability
  673. if evaluate:
  674. return Probability(condition, given_condition).doit(**kwargs)
  675. return Probability(condition, given_condition)
  676. class Density(Basic):
  677. expr = property(lambda self: self.args[0])
  678. def __new__(cls, expr, condition = None):
  679. expr = _sympify(expr)
  680. if condition is None:
  681. obj = Basic.__new__(cls, expr)
  682. else:
  683. condition = _sympify(condition)
  684. obj = Basic.__new__(cls, expr, condition)
  685. return obj
  686. @property
  687. def condition(self):
  688. if len(self.args) > 1:
  689. return self.args[1]
  690. else:
  691. return None
  692. def doit(self, evaluate=True, **kwargs):
  693. from sympy.stats.random_matrix import RandomMatrixPSpace
  694. from sympy.stats.joint_rv import JointPSpace
  695. from sympy.stats.matrix_distributions import MatrixPSpace
  696. from sympy.stats.compound_rv import CompoundPSpace
  697. from sympy.stats.frv import SingleFiniteDistribution
  698. expr, condition = self.expr, self.condition
  699. if isinstance(expr, SingleFiniteDistribution):
  700. return expr.dict
  701. if condition is not None:
  702. # Recompute on new conditional expr
  703. expr = given(expr, condition, **kwargs)
  704. if not random_symbols(expr):
  705. return Lambda(x, DiracDelta(x - expr))
  706. if isinstance(expr, RandomSymbol):
  707. if isinstance(expr.pspace, (SinglePSpace, JointPSpace, MatrixPSpace)) and \
  708. hasattr(expr.pspace, 'distribution'):
  709. return expr.pspace.distribution
  710. elif isinstance(expr.pspace, RandomMatrixPSpace):
  711. return expr.pspace.model
  712. if isinstance(pspace(expr), CompoundPSpace):
  713. kwargs['compound_evaluate'] = evaluate
  714. result = pspace(expr).compute_density(expr, **kwargs)
  715. if evaluate and hasattr(result, 'doit'):
  716. return result.doit()
  717. else:
  718. return result
  719. def density(expr, condition=None, evaluate=True, numsamples=None, **kwargs):
  720. """
  721. Probability density of a random expression, optionally given a second
  722. condition.
  723. Explanation
  724. ===========
  725. This density will take on different forms for different types of
  726. probability spaces. Discrete variables produce Dicts. Continuous
  727. variables produce Lambdas.
  728. Parameters
  729. ==========
  730. expr : Expr containing RandomSymbols
  731. The expression of which you want to compute the density value
  732. condition : Relational containing RandomSymbols
  733. A conditional expression. density(X > 1, X > 0) is density of X > 1
  734. given X > 0
  735. numsamples : int
  736. Enables sampling and approximates the density with this many samples
  737. Examples
  738. ========
  739. >>> from sympy.stats import density, Die, Normal
  740. >>> from sympy import Symbol
  741. >>> x = Symbol('x')
  742. >>> D = Die('D', 6)
  743. >>> X = Normal(x, 0, 1)
  744. >>> density(D).dict
  745. {1: 1/6, 2: 1/6, 3: 1/6, 4: 1/6, 5: 1/6, 6: 1/6}
  746. >>> density(2*D).dict
  747. {2: 1/6, 4: 1/6, 6: 1/6, 8: 1/6, 10: 1/6, 12: 1/6}
  748. >>> density(X)(x)
  749. sqrt(2)*exp(-x**2/2)/(2*sqrt(pi))
  750. """
  751. if numsamples:
  752. return sampling_density(expr, condition, numsamples=numsamples,
  753. **kwargs)
  754. return Density(expr, condition).doit(evaluate=evaluate, **kwargs)
  755. def cdf(expr, condition=None, evaluate=True, **kwargs):
  756. """
  757. Cumulative Distribution Function of a random expression.
  758. optionally given a second condition.
  759. Explanation
  760. ===========
  761. This density will take on different forms for different types of
  762. probability spaces.
  763. Discrete variables produce Dicts.
  764. Continuous variables produce Lambdas.
  765. Examples
  766. ========
  767. >>> from sympy.stats import density, Die, Normal, cdf
  768. >>> D = Die('D', 6)
  769. >>> X = Normal('X', 0, 1)
  770. >>> density(D).dict
  771. {1: 1/6, 2: 1/6, 3: 1/6, 4: 1/6, 5: 1/6, 6: 1/6}
  772. >>> cdf(D)
  773. {1: 1/6, 2: 1/3, 3: 1/2, 4: 2/3, 5: 5/6, 6: 1}
  774. >>> cdf(3*D, D > 2)
  775. {9: 1/4, 12: 1/2, 15: 3/4, 18: 1}
  776. >>> cdf(X)
  777. Lambda(_z, erf(sqrt(2)*_z/2)/2 + 1/2)
  778. """
  779. if condition is not None: # If there is a condition
  780. # Recompute on new conditional expr
  781. return cdf(given(expr, condition, **kwargs), **kwargs)
  782. # Otherwise pass work off to the ProbabilitySpace
  783. result = pspace(expr).compute_cdf(expr, **kwargs)
  784. if evaluate and hasattr(result, 'doit'):
  785. return result.doit()
  786. else:
  787. return result
  788. def characteristic_function(expr, condition=None, evaluate=True, **kwargs):
  789. """
  790. Characteristic function of a random expression, optionally given a second condition.
  791. Returns a Lambda.
  792. Examples
  793. ========
  794. >>> from sympy.stats import Normal, DiscreteUniform, Poisson, characteristic_function
  795. >>> X = Normal('X', 0, 1)
  796. >>> characteristic_function(X)
  797. Lambda(_t, exp(-_t**2/2))
  798. >>> Y = DiscreteUniform('Y', [1, 2, 7])
  799. >>> characteristic_function(Y)
  800. Lambda(_t, exp(7*_t*I)/3 + exp(2*_t*I)/3 + exp(_t*I)/3)
  801. >>> Z = Poisson('Z', 2)
  802. >>> characteristic_function(Z)
  803. Lambda(_t, exp(2*exp(_t*I) - 2))
  804. """
  805. if condition is not None:
  806. return characteristic_function(given(expr, condition, **kwargs), **kwargs)
  807. result = pspace(expr).compute_characteristic_function(expr, **kwargs)
  808. if evaluate and hasattr(result, 'doit'):
  809. return result.doit()
  810. else:
  811. return result
  812. def moment_generating_function(expr, condition=None, evaluate=True, **kwargs):
  813. if condition is not None:
  814. return moment_generating_function(given(expr, condition, **kwargs), **kwargs)
  815. result = pspace(expr).compute_moment_generating_function(expr, **kwargs)
  816. if evaluate and hasattr(result, 'doit'):
  817. return result.doit()
  818. else:
  819. return result
  820. def where(condition, given_condition=None, **kwargs):
  821. """
  822. Returns the domain where a condition is True.
  823. Examples
  824. ========
  825. >>> from sympy.stats import where, Die, Normal
  826. >>> from sympy import And
  827. >>> D1, D2 = Die('a', 6), Die('b', 6)
  828. >>> a, b = D1.symbol, D2.symbol
  829. >>> X = Normal('x', 0, 1)
  830. >>> where(X**2<1)
  831. Domain: (-1 < x) & (x < 1)
  832. >>> where(X**2<1).set
  833. Interval.open(-1, 1)
  834. >>> where(And(D1<=D2, D2<3))
  835. Domain: (Eq(a, 1) & Eq(b, 1)) | (Eq(a, 1) & Eq(b, 2)) | (Eq(a, 2) & Eq(b, 2))
  836. """
  837. if given_condition is not None: # If there is a condition
  838. # Recompute on new conditional expr
  839. return where(given(condition, given_condition, **kwargs), **kwargs)
  840. # Otherwise pass work off to the ProbabilitySpace
  841. return pspace(condition).where(condition, **kwargs)
  842. @doctest_depends_on(modules=('scipy',))
  843. def sample(expr, condition=None, size=(), library='scipy',
  844. numsamples=1, seed=None, **kwargs):
  845. """
  846. A realization of the random expression.
  847. Parameters
  848. ==========
  849. expr : Expression of random variables
  850. Expression from which sample is extracted
  851. condition : Expr containing RandomSymbols
  852. A conditional expression
  853. size : int, tuple
  854. Represents size of each sample in numsamples
  855. library : str
  856. - 'scipy' : Sample using scipy
  857. - 'numpy' : Sample using numpy
  858. - 'pymc' : Sample using PyMC
  859. Choose any of the available options to sample from as string,
  860. by default is 'scipy'
  861. numsamples : int
  862. Number of samples, each with size as ``size``.
  863. .. deprecated:: 1.9
  864. The ``numsamples`` parameter is deprecated and is only provided for
  865. compatibility with v1.8. Use a list comprehension or an additional
  866. dimension in ``size`` instead. See
  867. :ref:`deprecated-sympy-stats-numsamples` for details.
  868. seed :
  869. An object to be used as seed by the given external library for sampling `expr`.
  870. Following is the list of possible types of object for the supported libraries,
  871. - 'scipy': int, numpy.random.RandomState, numpy.random.Generator
  872. - 'numpy': int, numpy.random.RandomState, numpy.random.Generator
  873. - 'pymc': int
  874. Optional, by default None, in which case seed settings
  875. related to the given library will be used.
  876. No modifications to environment's global seed settings
  877. are done by this argument.
  878. Returns
  879. =======
  880. sample: float/list/numpy.ndarray
  881. one sample or a collection of samples of the random expression.
  882. - sample(X) returns float/numpy.float64/numpy.int64 object.
  883. - sample(X, size=int/tuple) returns numpy.ndarray object.
  884. Examples
  885. ========
  886. >>> from sympy.stats import Die, sample, Normal, Geometric
  887. >>> X, Y, Z = Die('X', 6), Die('Y', 6), Die('Z', 6) # Finite Random Variable
  888. >>> die_roll = sample(X + Y + Z)
  889. >>> die_roll # doctest: +SKIP
  890. 3
  891. >>> N = Normal('N', 3, 4) # Continuous Random Variable
  892. >>> samp = sample(N)
  893. >>> samp in N.pspace.domain.set
  894. True
  895. >>> samp = sample(N, N>0)
  896. >>> samp > 0
  897. True
  898. >>> samp_list = sample(N, size=4)
  899. >>> [sam in N.pspace.domain.set for sam in samp_list]
  900. [True, True, True, True]
  901. >>> sample(N, size = (2,3)) # doctest: +SKIP
  902. array([[5.42519758, 6.40207856, 4.94991743],
  903. [1.85819627, 6.83403519, 1.9412172 ]])
  904. >>> G = Geometric('G', 0.5) # Discrete Random Variable
  905. >>> samp_list = sample(G, size=3)
  906. >>> samp_list # doctest: +SKIP
  907. [1, 3, 2]
  908. >>> [sam in G.pspace.domain.set for sam in samp_list]
  909. [True, True, True]
  910. >>> MN = Normal("MN", [3, 4], [[2, 1], [1, 2]]) # Joint Random Variable
  911. >>> samp_list = sample(MN, size=4)
  912. >>> samp_list # doctest: +SKIP
  913. [array([2.85768055, 3.38954165]),
  914. array([4.11163337, 4.3176591 ]),
  915. array([0.79115232, 1.63232916]),
  916. array([4.01747268, 3.96716083])]
  917. >>> [tuple(sam) in MN.pspace.domain.set for sam in samp_list]
  918. [True, True, True, True]
  919. .. versionchanged:: 1.7.0
  920. sample used to return an iterator containing the samples instead of value.
  921. .. versionchanged:: 1.9.0
  922. sample returns values or array of values instead of an iterator and numsamples is deprecated.
  923. """
  924. iterator = sample_iter(expr, condition, size=size, library=library,
  925. numsamples=numsamples, seed=seed)
  926. if numsamples != 1:
  927. sympy_deprecation_warning(
  928. f"""
  929. The numsamples parameter to sympy.stats.sample() is deprecated.
  930. Either use a list comprehension, like
  931. [sample(...) for i in range({numsamples})]
  932. or add a dimension to size, like
  933. sample(..., size={(numsamples,) + size})
  934. """,
  935. deprecated_since_version="1.9",
  936. active_deprecations_target="deprecated-sympy-stats-numsamples",
  937. )
  938. return [next(iterator) for i in range(numsamples)]
  939. return next(iterator)
  940. def quantile(expr, evaluate=True, **kwargs):
  941. r"""
  942. Return the :math:`p^{th}` order quantile of a probability distribution.
  943. Explanation
  944. ===========
  945. Quantile is defined as the value at which the probability of the random
  946. variable is less than or equal to the given probability.
  947. .. math::
  948. Q(p) = \inf\{x \in (-\infty, \infty) : p \le F(x)\}
  949. Examples
  950. ========
  951. >>> from sympy.stats import quantile, Die, Exponential
  952. >>> from sympy import Symbol, pprint
  953. >>> p = Symbol("p")
  954. >>> l = Symbol("lambda", positive=True)
  955. >>> X = Exponential("x", l)
  956. >>> quantile(X)(p)
  957. -log(1 - p)/lambda
  958. >>> D = Die("d", 6)
  959. >>> pprint(quantile(D)(p), use_unicode=False)
  960. /nan for Or(p > 1, p < 0)
  961. |
  962. | 1 for p <= 1/6
  963. |
  964. | 2 for p <= 1/3
  965. |
  966. < 3 for p <= 1/2
  967. |
  968. | 4 for p <= 2/3
  969. |
  970. | 5 for p <= 5/6
  971. |
  972. \ 6 for p <= 1
  973. """
  974. result = pspace(expr).compute_quantile(expr, **kwargs)
  975. if evaluate and hasattr(result, 'doit'):
  976. return result.doit()
  977. else:
  978. return result
  979. def sample_iter(expr, condition=None, size=(), library='scipy',
  980. numsamples=S.Infinity, seed=None, **kwargs):
  981. """
  982. Returns an iterator of realizations from the expression given a condition.
  983. Parameters
  984. ==========
  985. expr: Expr
  986. Random expression to be realized
  987. condition: Expr, optional
  988. A conditional expression
  989. size : int, tuple
  990. Represents size of each sample in numsamples
  991. numsamples: integer, optional
  992. Length of the iterator (defaults to infinity)
  993. seed :
  994. An object to be used as seed by the given external library for sampling `expr`.
  995. Following is the list of possible types of object for the supported libraries,
  996. - 'scipy': int, numpy.random.RandomState, numpy.random.Generator
  997. - 'numpy': int, numpy.random.RandomState, numpy.random.Generator
  998. - 'pymc': int
  999. Optional, by default None, in which case seed settings
  1000. related to the given library will be used.
  1001. No modifications to environment's global seed settings
  1002. are done by this argument.
  1003. Examples
  1004. ========
  1005. >>> from sympy.stats import Normal, sample_iter
  1006. >>> X = Normal('X', 0, 1)
  1007. >>> expr = X*X + 3
  1008. >>> iterator = sample_iter(expr, numsamples=3) # doctest: +SKIP
  1009. >>> list(iterator) # doctest: +SKIP
  1010. [12, 4, 7]
  1011. Returns
  1012. =======
  1013. sample_iter: iterator object
  1014. iterator object containing the sample/samples of given expr
  1015. See Also
  1016. ========
  1017. sample
  1018. sampling_P
  1019. sampling_E
  1020. """
  1021. from sympy.stats.joint_rv import JointRandomSymbol
  1022. if not import_module(library):
  1023. raise ValueError("Failed to import %s" % library)
  1024. if condition is not None:
  1025. ps = pspace(Tuple(expr, condition))
  1026. else:
  1027. ps = pspace(expr)
  1028. rvs = list(ps.values)
  1029. if isinstance(expr, JointRandomSymbol):
  1030. expr = expr.subs({expr: RandomSymbol(expr.symbol, expr.pspace)})
  1031. else:
  1032. sub = {}
  1033. for arg in expr.args:
  1034. if isinstance(arg, JointRandomSymbol):
  1035. sub[arg] = RandomSymbol(arg.symbol, arg.pspace)
  1036. expr = expr.subs(sub)
  1037. def fn_subs(*args):
  1038. return expr.subs({rv: arg for rv, arg in zip(rvs, args)})
  1039. def given_fn_subs(*args):
  1040. if condition is not None:
  1041. return condition.subs({rv: arg for rv, arg in zip(rvs, args)})
  1042. return False
  1043. if library in ('pymc', 'pymc3'):
  1044. # Currently unable to lambdify in pymc
  1045. # TODO : Remove when lambdify accepts 'pymc' as module
  1046. fn = lambdify(rvs, expr, **kwargs)
  1047. else:
  1048. fn = lambdify(rvs, expr, modules=library, **kwargs)
  1049. if condition is not None:
  1050. given_fn = lambdify(rvs, condition, **kwargs)
  1051. def return_generator_infinite():
  1052. count = 0
  1053. _size = (1,)+((size,) if isinstance(size, int) else size)
  1054. while count < numsamples:
  1055. d = ps.sample(size=_size, library=library, seed=seed) # a dictionary that maps RVs to values
  1056. args = [d[rv][0] for rv in rvs]
  1057. if condition is not None: # Check that these values satisfy the condition
  1058. # TODO: Replace the try-except block with only given_fn(*args)
  1059. # once lambdify works with unevaluated SymPy objects.
  1060. try:
  1061. gd = given_fn(*args)
  1062. except (NameError, TypeError):
  1063. gd = given_fn_subs(*args)
  1064. if gd != True and gd != False:
  1065. raise ValueError(
  1066. "Conditions must not contain free symbols")
  1067. if not gd: # If the values don't satisfy then try again
  1068. continue
  1069. yield fn(*args)
  1070. count += 1
  1071. def return_generator_finite():
  1072. faulty = True
  1073. while faulty:
  1074. d = ps.sample(size=(numsamples,) + ((size,) if isinstance(size, int) else size),
  1075. library=library, seed=seed) # a dictionary that maps RVs to values
  1076. faulty = False
  1077. count = 0
  1078. while count < numsamples and not faulty:
  1079. args = [d[rv][count] for rv in rvs]
  1080. if condition is not None: # Check that these values satisfy the condition
  1081. # TODO: Replace the try-except block with only given_fn(*args)
  1082. # once lambdify works with unevaluated SymPy objects.
  1083. try:
  1084. gd = given_fn(*args)
  1085. except (NameError, TypeError):
  1086. gd = given_fn_subs(*args)
  1087. if gd != True and gd != False:
  1088. raise ValueError(
  1089. "Conditions must not contain free symbols")
  1090. if not gd: # If the values don't satisfy then try again
  1091. faulty = True
  1092. count += 1
  1093. count = 0
  1094. while count < numsamples:
  1095. args = [d[rv][count] for rv in rvs]
  1096. # TODO: Replace the try-except block with only fn(*args)
  1097. # once lambdify works with unevaluated SymPy objects.
  1098. try:
  1099. yield fn(*args)
  1100. except (NameError, TypeError):
  1101. yield fn_subs(*args)
  1102. count += 1
  1103. if numsamples is S.Infinity:
  1104. return return_generator_infinite()
  1105. return return_generator_finite()
  1106. def sample_iter_lambdify(expr, condition=None, size=(),
  1107. numsamples=S.Infinity, seed=None, **kwargs):
  1108. return sample_iter(expr, condition=condition, size=size,
  1109. numsamples=numsamples, seed=seed, **kwargs)
  1110. def sample_iter_subs(expr, condition=None, size=(),
  1111. numsamples=S.Infinity, seed=None, **kwargs):
  1112. return sample_iter(expr, condition=condition, size=size,
  1113. numsamples=numsamples, seed=seed, **kwargs)
  1114. def sampling_P(condition, given_condition=None, library='scipy', numsamples=1,
  1115. evalf=True, seed=None, **kwargs):
  1116. """
  1117. Sampling version of P.
  1118. See Also
  1119. ========
  1120. P
  1121. sampling_E
  1122. sampling_density
  1123. """
  1124. count_true = 0
  1125. count_false = 0
  1126. samples = sample_iter(condition, given_condition, library=library,
  1127. numsamples=numsamples, seed=seed, **kwargs)
  1128. for sample in samples:
  1129. if sample:
  1130. count_true += 1
  1131. else:
  1132. count_false += 1
  1133. result = S(count_true) / numsamples
  1134. if evalf:
  1135. return result.evalf()
  1136. else:
  1137. return result
  1138. def sampling_E(expr, given_condition=None, library='scipy', numsamples=1,
  1139. evalf=True, seed=None, **kwargs):
  1140. """
  1141. Sampling version of E.
  1142. See Also
  1143. ========
  1144. P
  1145. sampling_P
  1146. sampling_density
  1147. """
  1148. samples = list(sample_iter(expr, given_condition, library=library,
  1149. numsamples=numsamples, seed=seed, **kwargs))
  1150. result = Add(*samples) / numsamples
  1151. if evalf:
  1152. return result.evalf()
  1153. else:
  1154. return result
  1155. def sampling_density(expr, given_condition=None, library='scipy',
  1156. numsamples=1, seed=None, **kwargs):
  1157. """
  1158. Sampling version of density.
  1159. See Also
  1160. ========
  1161. density
  1162. sampling_P
  1163. sampling_E
  1164. """
  1165. results = {}
  1166. for result in sample_iter(expr, given_condition, library=library,
  1167. numsamples=numsamples, seed=seed, **kwargs):
  1168. results[result] = results.get(result, 0) + 1
  1169. return results
  1170. def dependent(a, b):
  1171. """
  1172. Dependence of two random expressions.
  1173. Two expressions are independent if knowledge of one does not change
  1174. computations on the other.
  1175. Examples
  1176. ========
  1177. >>> from sympy.stats import Normal, dependent, given
  1178. >>> from sympy import Tuple, Eq
  1179. >>> X, Y = Normal('X', 0, 1), Normal('Y', 0, 1)
  1180. >>> dependent(X, Y)
  1181. False
  1182. >>> dependent(2*X + Y, -Y)
  1183. True
  1184. >>> X, Y = given(Tuple(X, Y), Eq(X + Y, 3))
  1185. >>> dependent(X, Y)
  1186. True
  1187. See Also
  1188. ========
  1189. independent
  1190. """
  1191. if pspace_independent(a, b):
  1192. return False
  1193. z = Symbol('z', real=True)
  1194. # Dependent if density is unchanged when one is given information about
  1195. # the other
  1196. return (density(a, Eq(b, z)) != density(a) or
  1197. density(b, Eq(a, z)) != density(b))
  1198. def independent(a, b):
  1199. """
  1200. Independence of two random expressions.
  1201. Two expressions are independent if knowledge of one does not change
  1202. computations on the other.
  1203. Examples
  1204. ========
  1205. >>> from sympy.stats import Normal, independent, given
  1206. >>> from sympy import Tuple, Eq
  1207. >>> X, Y = Normal('X', 0, 1), Normal('Y', 0, 1)
  1208. >>> independent(X, Y)
  1209. True
  1210. >>> independent(2*X + Y, -Y)
  1211. False
  1212. >>> X, Y = given(Tuple(X, Y), Eq(X + Y, 3))
  1213. >>> independent(X, Y)
  1214. False
  1215. See Also
  1216. ========
  1217. dependent
  1218. """
  1219. return not dependent(a, b)
  1220. def pspace_independent(a, b):
  1221. """
  1222. Tests for independence between a and b by checking if their PSpaces have
  1223. overlapping symbols. This is a sufficient but not necessary condition for
  1224. independence and is intended to be used internally.
  1225. Notes
  1226. =====
  1227. pspace_independent(a, b) implies independent(a, b)
  1228. independent(a, b) does not imply pspace_independent(a, b)
  1229. """
  1230. a_symbols = set(pspace(b).symbols)
  1231. b_symbols = set(pspace(a).symbols)
  1232. if len(set(random_symbols(a)).intersection(random_symbols(b))) != 0:
  1233. return False
  1234. if len(a_symbols.intersection(b_symbols)) == 0:
  1235. return True
  1236. return None
  1237. def rv_subs(expr, symbols=None):
  1238. """
  1239. Given a random expression replace all random variables with their symbols.
  1240. If symbols keyword is given restrict the swap to only the symbols listed.
  1241. """
  1242. if symbols is None:
  1243. symbols = random_symbols(expr)
  1244. if not symbols:
  1245. return expr
  1246. swapdict = {rv: rv.symbol for rv in symbols}
  1247. return expr.subs(swapdict)
  1248. class NamedArgsMixin:
  1249. _argnames: tuple[str, ...] = ()
  1250. def __getattr__(self, attr):
  1251. try:
  1252. return self.args[self._argnames.index(attr)]
  1253. except ValueError:
  1254. raise AttributeError("'%s' object has no attribute '%s'" % (
  1255. type(self).__name__, attr))
class Distribution(Basic):
    """
    Base class for distribution objects; provides :meth:`sample`, which
    draws realizations through an external library.
    """

    def sample(self, size=(), library='scipy', seed=None):
        """ A random realization from the distribution

        Parameters
        ==========

        size : int or tuple
            Requested sample shape. For 'numpy' an empty tuple is passed on
            as ``None``; for 'pymc'/'pymc3' ``prod(size)`` draws are taken
            and reshaped to ``size``.
        library : str
            One of 'scipy', 'numpy', 'pymc' or 'pymc3'.
        seed : optional
            For 'scipy'/'numpy', ``None`` or an int is turned into
            ``numpy.random.default_rng(seed=seed)``; any other object is
            used directly as the random state. For 'pymc'/'pymc3' it is
            passed as ``random_seed`` to ``pymc.sample``.

        Raises
        ======

        ValueError
            If ``library`` is one of the supported names but cannot be
            imported.
        NotImplementedError
            If ``library`` is not supported, or no sample was produced for
            this distribution with that library.
        """
        module = import_module(library)
        if library in {'scipy', 'numpy', 'pymc3', 'pymc'} and module is None:
            raise ValueError("Failed to import %s" % library)

        if library == 'scipy':
            # scipy can handle custom distributions on its own, but known
            # distributions are still mapped to scipy's implementations
            # where possible (inside do_sample_scipy).
            # TODO: do this for drv.py and frv.py if necessary.
            # TODO: add more distributions here if there are more
            from sympy.stats.sampling.sample_scipy import do_sample_scipy
            import numpy
            # None or an int seeds a new Generator; any other object (e.g. an
            # existing RandomState/Generator) is used as-is.
            if seed is None or isinstance(seed, int):
                rand_state = numpy.random.default_rng(seed=seed)
            else:
                rand_state = seed
            samps = do_sample_scipy(self, size, rand_state)

        elif library == 'numpy':
            from sympy.stats.sampling.sample_numpy import do_sample_numpy
            import numpy
            if seed is None or isinstance(seed, int):
                rand_state = numpy.random.default_rng(seed=seed)
            else:
                rand_state = seed
            # An empty shape is handed to numpy as None (presumably a single
            # scalar draw).
            _size = None if size == () else size
            samps = do_sample_numpy(self, _size, rand_state)
        elif library in ('pymc', 'pymc3'):
            from sympy.stats.sampling.sample_pymc import do_sample_pymc
            import logging
            # Suppress pymc log messages below ERROR level.
            logging.getLogger("pymc").setLevel(logging.ERROR)
            # Prefer the ``pymc`` package; fall back to legacy ``pymc3``.
            try:
                import pymc
            except ImportError:
                import pymc3 as pymc

            with pymc.Model():
                # NOTE(review): do_sample_pymc presumably registers this
                # distribution in the active model under the name 'X', which
                # is the key the trace is indexed with below — confirm.
                if do_sample_pymc(self):
                    samps = pymc.sample(draws=prod(size), chains=1, compute_convergence_checks=False,
                            progressbar=False, random_seed=seed, return_inferencedata=False)[:]['X']
                    samps = samps.reshape(size)
                else:
                    samps = None

        else:
            raise NotImplementedError("Sampling from %s is not supported yet."
                                      % str(library))

        # A None result (from do_sample_* or the pymc branch above) means no
        # sample could be produced for this distribution with ``library``.
        if samps is not None:
            return samps
        raise NotImplementedError(
            "Sampling for %s is not currently implemented from %s"
            % (self, library))
  1308. def _value_check(condition, message):
  1309. """
  1310. Raise a ValueError with message if condition is False, else
  1311. return True if all conditions were True, else False.
  1312. Examples
  1313. ========
  1314. >>> from sympy.stats.rv import _value_check
  1315. >>> from sympy.abc import a, b, c
  1316. >>> from sympy import And, Dummy
  1317. >>> _value_check(2 < 3, '')
  1318. True
  1319. Here, the condition is not False, but it does not evaluate to True
  1320. so False is returned (but no error is raised). So checking if the
  1321. return value is True or False will tell you if all conditions were
  1322. evaluated.
  1323. >>> _value_check(a < b, '')
  1324. False
  1325. In this case the condition is False so an error is raised:
  1326. >>> r = Dummy(real=True)
  1327. >>> _value_check(r < r - 1, 'condition is not true')
  1328. Traceback (most recent call last):
  1329. ...
  1330. ValueError: condition is not true
  1331. If no condition of many conditions must be False, they can be
  1332. checked by passing them as an iterable:
  1333. >>> _value_check((a < 0, b < 0, c < 0), '')
  1334. False
  1335. The iterable can be a generator, too:
  1336. >>> _value_check((i < 0 for i in (a, b, c)), '')
  1337. False
  1338. The following are equivalent to the above but do not pass
  1339. an iterable:
  1340. >>> all(_value_check(i < 0, '') for i in (a, b, c))
  1341. False
  1342. >>> _value_check(And(a < 0, b < 0, c < 0), '')
  1343. False
  1344. """
  1345. if not iterable(condition):
  1346. condition = [condition]
  1347. truth = fuzzy_and(condition)
  1348. if truth == False:
  1349. raise ValueError(message)
  1350. return truth == True
  1351. def _symbol_converter(sym):
  1352. """
  1353. Casts the parameter to Symbol if it is 'str'
  1354. otherwise no operation is performed on it.
  1355. Parameters
  1356. ==========
  1357. sym
  1358. The parameter to be converted.
  1359. Returns
  1360. =======
  1361. Symbol
  1362. the parameter converted to Symbol.
  1363. Raises
  1364. ======
  1365. TypeError
  1366. If the parameter is not an instance of both str and
  1367. Symbol.
  1368. Examples
  1369. ========
  1370. >>> from sympy import Symbol
  1371. >>> from sympy.stats.rv import _symbol_converter
  1372. >>> s = _symbol_converter('s')
  1373. >>> isinstance(s, Symbol)
  1374. True
  1375. >>> _symbol_converter(1)
  1376. Traceback (most recent call last):
  1377. ...
  1378. TypeError: 1 is neither a Symbol nor a string
  1379. >>> r = Symbol('r')
  1380. >>> isinstance(r, Symbol)
  1381. True
  1382. """
  1383. if isinstance(sym, str):
  1384. sym = Symbol(sym)
  1385. if not isinstance(sym, Symbol):
  1386. raise TypeError("%s is neither a Symbol nor a string"%(sym))
  1387. return sym
  1388. def sample_stochastic_process(process):
  1389. """
  1390. This function is used to sample from stochastic process.
  1391. Parameters
  1392. ==========
  1393. process: StochasticProcess
  1394. Process used to extract the samples. It must be an instance of
  1395. StochasticProcess
  1396. Examples
  1397. ========
  1398. >>> from sympy.stats import sample_stochastic_process, DiscreteMarkovChain
  1399. >>> from sympy import Matrix
  1400. >>> T = Matrix([[0.5, 0.2, 0.3],[0.2, 0.5, 0.3],[0.2, 0.3, 0.5]])
  1401. >>> Y = DiscreteMarkovChain("Y", [0, 1, 2], T)
  1402. >>> next(sample_stochastic_process(Y)) in Y.state_space
  1403. True
  1404. >>> next(sample_stochastic_process(Y)) # doctest: +SKIP
  1405. 0
  1406. >>> next(sample_stochastic_process(Y)) # doctest: +SKIP
  1407. 2
  1408. Returns
  1409. =======
  1410. sample: iterator object
  1411. iterator object containing the sample of given process
  1412. """
  1413. from sympy.stats.stochastic_process_types import StochasticProcess
  1414. if not isinstance(process, StochasticProcess):
  1415. raise ValueError("Process must be an instance of Stochastic Process")
  1416. return process.sample()