exp_family.py

import torch
from torch.distributions.distribution import Distribution

__all__ = ['ExponentialFamily']


class ExponentialFamily(Distribution):
    r"""
    ExponentialFamily is the abstract base class for probability distributions belonging to an
    exponential family, whose probability mass/density function has the form defined below

    .. math::

        p_{F}(x; \theta) = \exp(\langle t(x), \theta\rangle - F(\theta) + k(x))

    where :math:`\theta` denotes the natural parameters, :math:`t(x)` denotes the sufficient
    statistic, :math:`F(\theta)` is the log normalizer function for a given family, and
    :math:`k(x)` is the carrier measure.

    Note:
        This class is an intermediary between the `Distribution` class and distributions which
        belong to an exponential family, mainly to check the correctness of the `.entropy()` and
        analytic KL divergence methods. We use this class to compute the entropy and KL divergence
        using the AD framework and Bregman divergences (courtesy of: Frank Nielsen and Richard
        Nock, Entropies and Cross-entropies of Exponential Families).
    """
    @property
    def _natural_params(self):
        """
        Abstract method for natural parameters. Returns a tuple of Tensors based
        on the distribution.
        """
        raise NotImplementedError
    def _log_normalizer(self, *natural_params):
        """
        Abstract method for the log normalizer function. Returns the log normalizer
        based on the distribution and input.
        """
        raise NotImplementedError
    @property
    def _mean_carrier_measure(self):
        """
        Abstract method for the expected carrier measure, which is required for
        computing entropy.
        """
        raise NotImplementedError
    def entropy(self):
        """
        Method to compute the entropy using Bregman divergence of the log normalizer.
        """
        # Start from -E[k(x)], the negated expected carrier measure.
        result = -self._mean_carrier_measure
        # Detach the natural parameters so the gradients below are taken with
        # respect to fresh leaf tensors.
        nparams = [p.detach().requires_grad_() for p in self._natural_params]
        lg_normal = self._log_normalizer(*nparams)
        # \nabla F(\theta) = E[t(x)], obtained via autograd.
        gradients = torch.autograd.grad(lg_normal.sum(), nparams, create_graph=True)
        result += lg_normal
        # Accumulate -<\theta, \nabla F(\theta)>, summed over event dimensions.
        for np, g in zip(nparams, gradients):
            result -= (np * g).reshape(self._batch_shape + (-1,)).sum(-1)
        return result
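

# --- Usage sketch (added; not part of the original module) ---
# A minimal concrete subclass, assuming the Exponential distribution with rate
# lambda written in exponential-family form: theta = -lambda, t(x) = x,
# F(theta) = -log(-theta), k(x) = 0. The class name `_ExponentialSketch` is
# hypothetical and for illustration only; `entropy()` should match the
# closed-form entropy 1 - log(lambda).
class _ExponentialSketch(ExponentialFamily):
    def __init__(self, rate):
        self.rate = rate
        super().__init__(batch_shape=rate.size(), validate_args=False)

    @property
    def _natural_params(self):
        # Natural parameter theta = -rate.
        return (-self.rate,)

    def _log_normalizer(self, theta):
        # Log normalizer F(theta) = -log(-theta).
        return -torch.log(-theta)

    @property
    def _mean_carrier_measure(self):
        # Carrier measure k(x) = 0, so its expectation is zero.
        return 0


if __name__ == "__main__":
    rate = torch.tensor([0.5, 1.0, 2.0])
    dist = _ExponentialSketch(rate)
    print(dist.entropy())           # Bregman-divergence route via autograd
    print(1 - torch.log(rate))      # closed form; the two should agree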