- r"""
- The ``distributions`` package contains parameterizable probability distributions
- and sampling functions. This allows the construction of stochastic computation
- graphs and stochastic gradient estimators for optimization. This package
- generally follows the design of the `TensorFlow Distributions`_ package.
- .. _`TensorFlow Distributions`:
- https://arxiv.org/abs/1711.10604
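
Distributions are constructed from tensor parameters and queried through methods
such as :meth:`~torch.distributions.Distribution.sample` and
:meth:`~torch.distributions.Distribution.log_prob`. A minimal illustration
(the parameter values here are arbitrary)::

    import torch
    from torch.distributions import Normal

    m = Normal(torch.tensor([0.0]), torch.tensor([1.0]))
    x = m.sample()    # draw a sample from N(0, 1)
    m.log_prob(x)     # log density evaluated at the sample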

It is not possible to directly backpropagate through random samples. However,
there are two main methods for creating surrogate functions that can be
backpropagated through. These are the score function estimator/likelihood ratio
estimator/REINFORCE and the pathwise derivative estimator. REINFORCE is commonly
seen as the basis for policy gradient methods in reinforcement learning, and the
pathwise derivative estimator is commonly seen in the reparameterization trick
in variational autoencoders. Whilst the score function only requires the value
of samples :math:`f(x)`, the pathwise derivative requires the derivative
:math:`f'(x)`. The next sections discuss these two in a reinforcement learning
example. For more details see
`Gradient Estimation Using Stochastic Computation Graphs`_.

.. _`Gradient Estimation Using Stochastic Computation Graphs`:
    https://arxiv.org/abs/1506.05254
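
The distinction can be seen directly: a draw from
:meth:`~torch.distributions.Distribution.sample` carries no gradient history,
whereas :meth:`~torch.distributions.Distribution.rsample` (where supported)
does. A minimal sketch (the parameter values are arbitrary)::

    import torch
    from torch.distributions import Normal

    loc = torch.tensor([0.0], requires_grad=True)
    m = Normal(loc, torch.tensor([1.0]))
    m.sample().requires_grad   # False: no path back to ``loc``
    m.rsample().requires_grad  # True: reparameterized, differentiable w.r.t. ``loc``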

Score function
^^^^^^^^^^^^^^

When the probability density function is differentiable with respect to its
parameters, we only need :meth:`~torch.distributions.Distribution.sample` and
:meth:`~torch.distributions.Distribution.log_prob` to implement REINFORCE:

.. math::

    \Delta\theta = \alpha r \frac{\partial\log p(a|\pi^\theta(s))}{\partial\theta}

where :math:`\theta` are the parameters, :math:`\alpha` is the learning rate,
:math:`r` is the reward and :math:`p(a|\pi^\theta(s))` is the probability of
taking action :math:`a` in state :math:`s` given policy :math:`\pi^\theta`.

In practice we would sample an action from the output of a network, apply this
action in an environment, and then use ``log_prob`` to construct an equivalent
loss function. Note that we use a negative because optimizers use gradient
descent, whilst the rule above assumes gradient ascent. With a categorical
policy, the code for implementing REINFORCE would be as follows::

    probs = policy_network(state)
    # Note that this is equivalent to what used to be called multinomial
    m = Categorical(probs)
    action = m.sample()
    next_state, reward = env.step(action)
    loss = -m.log_prob(action) * reward
    loss.backward()
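
A self-contained version of the same pattern, with a toy linear policy and a
dummy reward standing in for ``policy_network`` and ``env`` (both of which are
placeholders above, not part of this package)::

    import torch
    from torch.distributions import Categorical

    state = torch.randn(4)               # toy 4-dimensional state
    policy = torch.nn.Linear(4, 2)       # toy policy over 2 discrete actions
    probs = policy(state).softmax(dim=-1)
    m = Categorical(probs)
    action = m.sample()
    reward = torch.tensor(1.0)           # dummy reward from a dummy environment
    loss = -m.log_prob(action) * reward  # REINFORCE surrogate loss
    loss.backward()                      # populates policy.weight.grad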

Pathwise derivative
^^^^^^^^^^^^^^^^^^^

The other way to implement these stochastic/policy gradients would be to use the
reparameterization trick from the
:meth:`~torch.distributions.Distribution.rsample` method, where the
parameterized random variable can be constructed via a parameterized
deterministic function of a parameter-free random variable. The reparameterized
sample therefore becomes differentiable. The code for implementing the pathwise
derivative would be as follows::

    params = policy_network(state)
    m = Normal(*params)
    # Any distribution with .has_rsample == True could work based on the application
    action = m.rsample()
    next_state, reward = env.step(action)  # Assuming that reward is differentiable
    loss = -reward
    loss.backward()
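
Concretely, for a :class:`Normal` distribution ``rsample`` draws parameter-free
noise :math:`\epsilon \sim N(0, 1)` and returns ``loc + eps * scale``, so
gradients flow back to the parameters. A minimal sketch with a differentiable
stand-in for the reward::

    import torch
    from torch.distributions import Normal

    loc = torch.tensor([0.0], requires_grad=True)
    scale = torch.tensor([1.0], requires_grad=True)
    action = Normal(loc, scale).rsample()
    reward = -(action ** 2).sum()  # dummy differentiable reward
    (-reward).backward()           # gradients reach ``loc`` and ``scale``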
- """
from .bernoulli import Bernoulli
from .beta import Beta
from .binomial import Binomial
from .categorical import Categorical
from .cauchy import Cauchy
from .chi2 import Chi2
from .constraint_registry import biject_to, transform_to
from .continuous_bernoulli import ContinuousBernoulli
from .dirichlet import Dirichlet
from .distribution import Distribution
from .exp_family import ExponentialFamily
from .exponential import Exponential
from .fishersnedecor import FisherSnedecor
from .gamma import Gamma
from .geometric import Geometric
from .gumbel import Gumbel
from .half_cauchy import HalfCauchy
from .half_normal import HalfNormal
from .independent import Independent
from .kl import kl_divergence, register_kl, _add_kl_info
from .kumaraswamy import Kumaraswamy
from .laplace import Laplace
from .lkj_cholesky import LKJCholesky
from .log_normal import LogNormal
from .logistic_normal import LogisticNormal
from .lowrank_multivariate_normal import LowRankMultivariateNormal
from .mixture_same_family import MixtureSameFamily
from .multinomial import Multinomial
from .multivariate_normal import MultivariateNormal
from .negative_binomial import NegativeBinomial
from .normal import Normal
from .one_hot_categorical import OneHotCategorical, OneHotCategoricalStraightThrough
from .pareto import Pareto
from .poisson import Poisson
from .relaxed_bernoulli import RelaxedBernoulli
from .relaxed_categorical import RelaxedOneHotCategorical
from .studentT import StudentT
from .transformed_distribution import TransformedDistribution
from .transforms import *  # noqa: F403
from .uniform import Uniform
from .von_mises import VonMises
from .weibull import Weibull
from .wishart import Wishart
from . import transforms

_add_kl_info()
del _add_kl_info

__all__ = [
    'Bernoulli',
    'Beta',
    'Binomial',
    'Categorical',
    'Cauchy',
    'Chi2',
    'ContinuousBernoulli',
    'Dirichlet',
    'Distribution',
    'Exponential',
    'ExponentialFamily',
    'FisherSnedecor',
    'Gamma',
    'Geometric',
    'Gumbel',
    'HalfCauchy',
    'HalfNormal',
    'Independent',
    'Kumaraswamy',
    'LKJCholesky',
    'Laplace',
    'LogNormal',
    'LogisticNormal',
    'LowRankMultivariateNormal',
    'MixtureSameFamily',
    'Multinomial',
    'MultivariateNormal',
    'NegativeBinomial',
    'Normal',
    'OneHotCategorical',
    'OneHotCategoricalStraightThrough',
    'Pareto',
    'Poisson',
    'RelaxedBernoulli',
    'RelaxedOneHotCategorical',
    'StudentT',
    'TransformedDistribution',
    'Uniform',
    'VonMises',
    'Weibull',
    'Wishart',
    'biject_to',
    'kl_divergence',
    'register_kl',
    'transform_to',
]
__all__.extend(transforms.__all__)