error_prop.py 3.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100
  1. """Tools for arithmetic error propagation."""
  2. from itertools import repeat, combinations
  3. from sympy.core.add import Add
  4. from sympy.core.mul import Mul
  5. from sympy.core.power import Pow
  6. from sympy.core.singleton import S
  7. from sympy.core.symbol import Symbol
  8. from sympy.functions.elementary.exponential import exp
  9. from sympy.simplify.simplify import simplify
  10. from sympy.stats.symbolic_probability import RandomSymbol, Variance, Covariance
  11. from sympy.stats.rv import is_random
  12. _arg0_or_var = lambda var: var.args[0] if len(var.args) > 0 else var
  13. def variance_prop(expr, consts=(), include_covar=False):
  14. r"""Symbolically propagates variance (`\sigma^2`) for expressions.
  15. This is computed as as seen in [1]_.
  16. Parameters
  17. ==========
  18. expr : Expr
  19. A SymPy expression to compute the variance for.
  20. consts : sequence of Symbols, optional
  21. Represents symbols that are known constants in the expr,
  22. and thus have zero variance. All symbols not in consts are
  23. assumed to be variant.
  24. include_covar : bool, optional
  25. Flag for whether or not to include covariances, default=False.
  26. Returns
  27. =======
  28. var_expr : Expr
  29. An expression for the total variance of the expr.
  30. The variance for the original symbols (e.g. x) are represented
  31. via instance of the Variance symbol (e.g. Variance(x)).
  32. Examples
  33. ========
  34. >>> from sympy import symbols, exp
  35. >>> from sympy.stats.error_prop import variance_prop
  36. >>> x, y = symbols('x y')
  37. >>> variance_prop(x + y)
  38. Variance(x) + Variance(y)
  39. >>> variance_prop(x * y)
  40. x**2*Variance(y) + y**2*Variance(x)
  41. >>> variance_prop(exp(2*x))
  42. 4*exp(4*x)*Variance(x)
  43. References
  44. ==========
  45. .. [1] https://en.wikipedia.org/wiki/Propagation_of_uncertainty
  46. """
  47. args = expr.args
  48. if len(args) == 0:
  49. if expr in consts:
  50. return S.Zero
  51. elif is_random(expr):
  52. return Variance(expr).doit()
  53. elif isinstance(expr, Symbol):
  54. return Variance(RandomSymbol(expr)).doit()
  55. else:
  56. return S.Zero
  57. nargs = len(args)
  58. var_args = list(map(variance_prop, args, repeat(consts, nargs),
  59. repeat(include_covar, nargs)))
  60. if isinstance(expr, Add):
  61. var_expr = Add(*var_args)
  62. if include_covar:
  63. terms = [2 * Covariance(_arg0_or_var(x), _arg0_or_var(y)).expand() \
  64. for x, y in combinations(var_args, 2)]
  65. var_expr += Add(*terms)
  66. elif isinstance(expr, Mul):
  67. terms = [v/a**2 for a, v in zip(args, var_args)]
  68. var_expr = simplify(expr**2 * Add(*terms))
  69. if include_covar:
  70. terms = [2*Covariance(_arg0_or_var(x), _arg0_or_var(y)).expand()/(a*b) \
  71. for (a, b), (x, y) in zip(combinations(args, 2),
  72. combinations(var_args, 2))]
  73. var_expr += Add(*terms)
  74. elif isinstance(expr, Pow):
  75. b = args[1]
  76. v = var_args[0] * (expr * b / args[0])**2
  77. var_expr = simplify(v)
  78. elif isinstance(expr, exp):
  79. var_expr = simplify(var_args[0] * expr**2)
  80. else:
  81. # unknown how to proceed, return variance of whole expr.
  82. var_expr = Variance(expr)
  83. return var_expr