# torch/distributions/__init__.py
  1. r"""
  2. The ``distributions`` package contains parameterizable probability distributions
  3. and sampling functions. This allows the construction of stochastic computation
  4. graphs and stochastic gradient estimators for optimization. This package
  5. generally follows the design of the `TensorFlow Distributions`_ package.
  6. .. _`TensorFlow Distributions`:
  7. https://arxiv.org/abs/1711.10604
  8. It is not possible to directly backpropagate through random samples. However,
  9. there are two main methods for creating surrogate functions that can be
  10. backpropagated through. These are the score function estimator/likelihood ratio
  11. estimator/REINFORCE and the pathwise derivative estimator. REINFORCE is commonly
  12. seen as the basis for policy gradient methods in reinforcement learning, and the
  13. pathwise derivative estimator is commonly seen in the reparameterization trick
  14. in variational autoencoders. Whilst the score function only requires the value
  15. of samples :math:`f(x)`, the pathwise derivative requires the derivative
  16. :math:`f'(x)`. The next sections discuss these two in a reinforcement learning
  17. example. For more details see
  18. `Gradient Estimation Using Stochastic Computation Graphs`_ .
  19. .. _`Gradient Estimation Using Stochastic Computation Graphs`:
  20. https://arxiv.org/abs/1506.05254
  21. Score function
  22. ^^^^^^^^^^^^^^
  23. When the probability density function is differentiable with respect to its
  24. parameters, we only need :meth:`~torch.distributions.Distribution.sample` and
  25. :meth:`~torch.distributions.Distribution.log_prob` to implement REINFORCE:
  26. .. math::
  27. \Delta\theta = \alpha r \frac{\partial\log p(a|\pi^\theta(s))}{\partial\theta}
  28. where :math:`\theta` are the parameters, :math:`\alpha` is the learning rate,
  29. :math:`r` is the reward and :math:`p(a|\pi^\theta(s))` is the probability of
  30. taking action :math:`a` in state :math:`s` given policy :math:`\pi^\theta`.
  31. In practice we would sample an action from the output of a network, apply this
  32. action in an environment, and then use ``log_prob`` to construct an equivalent
  33. loss function. Note that we use a negative because optimizers use gradient
  34. descent, whilst the rule above assumes gradient ascent. With a categorical
  35. policy, the code for implementing REINFORCE would be as follows::
  36. probs = policy_network(state)
  37. # Note that this is equivalent to what used to be called multinomial
  38. m = Categorical(probs)
  39. action = m.sample()
  40. next_state, reward = env.step(action)
  41. loss = -m.log_prob(action) * reward
  42. loss.backward()
  43. Pathwise derivative
  44. ^^^^^^^^^^^^^^^^^^^
  45. The other way to implement these stochastic/policy gradients would be to use the
  46. reparameterization trick from the
  47. :meth:`~torch.distributions.Distribution.rsample` method, where the
  48. parameterized random variable can be constructed via a parameterized
  49. deterministic function of a parameter-free random variable. The reparameterized
  50. sample therefore becomes differentiable. The code for implementing the pathwise
  51. derivative would be as follows::
  52. params = policy_network(state)
  53. m = Normal(*params)
  54. # Any distribution with .has_rsample == True could work based on the application
  55. action = m.rsample()
  56. next_state, reward = env.step(action) # Assuming that reward is differentiable
  57. loss = -reward
  58. loss.backward()
  59. """
  60. from .bernoulli import Bernoulli
  61. from .beta import Beta
  62. from .binomial import Binomial
  63. from .categorical import Categorical
  64. from .cauchy import Cauchy
  65. from .chi2 import Chi2
  66. from .constraint_registry import biject_to, transform_to
  67. from .continuous_bernoulli import ContinuousBernoulli
  68. from .dirichlet import Dirichlet
  69. from .distribution import Distribution
  70. from .exp_family import ExponentialFamily
  71. from .exponential import Exponential
  72. from .fishersnedecor import FisherSnedecor
  73. from .gamma import Gamma
  74. from .geometric import Geometric
  75. from .gumbel import Gumbel
  76. from .half_cauchy import HalfCauchy
  77. from .half_normal import HalfNormal
  78. from .independent import Independent
  79. from .kl import kl_divergence, register_kl, _add_kl_info
  80. from .kumaraswamy import Kumaraswamy
  81. from .laplace import Laplace
  82. from .lkj_cholesky import LKJCholesky
  83. from .log_normal import LogNormal
  84. from .logistic_normal import LogisticNormal
  85. from .lowrank_multivariate_normal import LowRankMultivariateNormal
  86. from .mixture_same_family import MixtureSameFamily
  87. from .multinomial import Multinomial
  88. from .multivariate_normal import MultivariateNormal
  89. from .negative_binomial import NegativeBinomial
  90. from .normal import Normal
  91. from .one_hot_categorical import OneHotCategorical, OneHotCategoricalStraightThrough
  92. from .pareto import Pareto
  93. from .poisson import Poisson
  94. from .relaxed_bernoulli import RelaxedBernoulli
  95. from .relaxed_categorical import RelaxedOneHotCategorical
  96. from .studentT import StudentT
  97. from .transformed_distribution import TransformedDistribution
  98. from .transforms import * # noqa: F403
  99. from .uniform import Uniform
  100. from .von_mises import VonMises
  101. from .weibull import Weibull
  102. from .wishart import Wishart
  103. from . import transforms
  104. _add_kl_info()
  105. del _add_kl_info
  106. __all__ = [
  107. 'Bernoulli',
  108. 'Beta',
  109. 'Binomial',
  110. 'Categorical',
  111. 'Cauchy',
  112. 'Chi2',
  113. 'ContinuousBernoulli',
  114. 'Dirichlet',
  115. 'Distribution',
  116. 'Exponential',
  117. 'ExponentialFamily',
  118. 'FisherSnedecor',
  119. 'Gamma',
  120. 'Geometric',
  121. 'Gumbel',
  122. 'HalfCauchy',
  123. 'HalfNormal',
  124. 'Independent',
  125. 'Kumaraswamy',
  126. 'LKJCholesky',
  127. 'Laplace',
  128. 'LogNormal',
  129. 'LogisticNormal',
  130. 'LowRankMultivariateNormal',
  131. 'MixtureSameFamily',
  132. 'Multinomial',
  133. 'MultivariateNormal',
  134. 'NegativeBinomial',
  135. 'Normal',
  136. 'OneHotCategorical',
  137. 'OneHotCategoricalStraightThrough',
  138. 'Pareto',
  139. 'RelaxedBernoulli',
  140. 'RelaxedOneHotCategorical',
  141. 'StudentT',
  142. 'Poisson',
  143. 'Uniform',
  144. 'VonMises',
  145. 'Weibull',
  146. 'Wishart',
  147. 'TransformedDistribution',
  148. 'biject_to',
  149. 'kl_divergence',
  150. 'register_kl',
  151. 'transform_to',
  152. ]
  153. __all__.extend(transforms.__all__)