import torch
from torch.distributions.distribution import Distribution

__all__ = ['ExponentialFamily']


class ExponentialFamily(Distribution):
    r"""
    ExponentialFamily is the abstract base class for probability distributions belonging to an
    exponential family, whose probability mass/density function has the form

    .. math::

        p_{F}(x; \theta) = \exp(\langle t(x), \theta\rangle - F(\theta) + k(x))

    where :math:`\theta` denotes the natural parameters, :math:`t(x)` denotes the sufficient
    statistic, :math:`F(\theta)` is the log normalizer function for a given family and
    :math:`k(x)` is the carrier measure.

    Note:
        This class is an intermediary between the `Distribution` class and distributions which
        belong to an exponential family, mainly to check the correctness of the `.entropy()`
        and analytic KL divergence methods. We use this class to compute the entropy and KL
        divergence using the AD framework and Bregman divergences (courtesy of: Frank Nielsen
        and Richard Nock, Entropies and Cross-entropies of Exponential Families).
    """

    @property
    def _natural_params(self):
        """
        Abstract method for natural parameters. Returns a tuple of Tensors based
        on the distribution.
        """
        raise NotImplementedError

    def _log_normalizer(self, *natural_params):
        """
        Abstract method for the log normalizer function. Returns the log normalizer based on
        the distribution and input.
        """
        raise NotImplementedError

    @property
    def _mean_carrier_measure(self):
        """
        Abstract method for the expected carrier measure, which is required for computing
        entropy.
        """
        raise NotImplementedError

    def entropy(self):
        """
        Method to compute the entropy using Bregman divergence of the log normalizer.
        """
        result = -self._mean_carrier_measure
        nparams = [p.detach().requires_grad_() for p in self._natural_params]
        lg_normal = self._log_normalizer(*nparams)
        gradients = torch.autograd.grad(lg_normal.sum(), nparams, create_graph=True)
        result += lg_normal
        for np, g in zip(nparams, gradients):
            result -= (np * g).reshape(self._batch_shape + (-1,)).sum(-1)
        return result
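

# A minimal usage sketch: assuming the Exponential distribution's natural
# parameterization noted in the class docstring (theta = -rate, F(theta) = -log(-theta),
# k(x) = 0), a concrete subclass only has to supply the three hooks above, and the
# Bregman-divergence entropy should reproduce the closed form 1 - log(rate).
# ``_ToyExponential`` is a hypothetical name used only for illustration;
# torch.distributions ships its own Exponential implementation.
if __name__ == "__main__":

    class _ToyExponential(ExponentialFamily):
        def __init__(self, rate):
            self.rate = rate
            super().__init__(batch_shape=rate.shape, validate_args=False)

        @property
        def _natural_params(self):
            # Natural parameter theta = -rate.
            return (-self.rate,)

        def _log_normalizer(self, theta):
            # F(theta) = -log(-theta) = -log(rate).
            return -torch.log(-theta)

        @property
        def _mean_carrier_measure(self):
            # Carrier measure k(x) = 0, so its expectation is 0.
            return 0

    rate = torch.tensor([0.5, 2.0])
    d = _ToyExponential(rate)
    print(d.entropy())          # approximately [1.6931, 0.3069]
    print(1 - torch.log(rate))  # closed-form entropy, same values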