import numpy as np
from scipy._lib._util import _asarray_validated

__all__ = ["logsumexp", "softmax", "log_softmax"]


def logsumexp(a, axis=None, b=None, keepdims=False, return_sign=False):
    """Compute the log of the sum of exponentials of input elements.

    Parameters
    ----------
    a : array_like
        Input array.
    axis : None or int or tuple of ints, optional
        Axis or axes over which the sum is taken. By default `axis` is None,
        and all elements are summed.

        .. versionadded:: 0.11.0
    b : array_like, optional
        Scaling factor for exp(`a`); must be of the same shape as `a` or
        broadcastable to `a`. These values may be negative in order to
        implement subtraction.

        .. versionadded:: 0.12.0
    keepdims : bool, optional
        If this is set to True, the axes which are reduced are left in the
        result as dimensions with size one. With this option, the result
        will broadcast correctly against the original array.

        .. versionadded:: 0.15.0
    return_sign : bool, optional
        If this is set to True, the result will be a pair containing sign
        information; if False, results that are negative will be returned
        as NaN. Default is False (no sign information).

        .. versionadded:: 0.16.0

    Returns
    -------
    res : ndarray
        The result, ``np.log(np.sum(np.exp(a)))`` calculated in a numerically
        more stable way. If `b` is given then ``np.log(np.sum(b*np.exp(a)))``
        is returned.
    sgn : ndarray
        If `return_sign` is True, this will be an array of floating-point
        numbers matching `res`, containing +1, 0, or -1 depending on the sign
        of the result. If False, only one result is returned.

    See Also
    --------
    numpy.logaddexp, numpy.logaddexp2

    Notes
    -----
    NumPy has a logaddexp function which is very similar to `logsumexp`, but
    only handles two arguments. `logaddexp.reduce` is similar to this
    function, but may be less stable.

    Examples
    --------
    >>> import numpy as np
    >>> from scipy.special import logsumexp
    >>> a = np.arange(10)
    >>> logsumexp(a)
    9.4586297444267107
    >>> np.log(np.sum(np.exp(a)))
    9.4586297444267107

    With weights

    >>> a = np.arange(10)
    >>> b = np.arange(10, 0, -1)
    >>> logsumexp(a, b=b)
    9.9170178533034665
    >>> np.log(np.sum(b*np.exp(a)))
    9.9170178533034647

    Returning a sign flag

    >>> logsumexp([1, 2], b=[1, -1], return_sign=True)
    (1.5413248546129181, -1.0)

    Notice that `logsumexp` does not directly support masked arrays. To use it
    on a masked array, convert the mask into zero weights:

    >>> a = np.ma.array([np.log(2), 2, np.log(3)],
    ...                 mask=[False, True, False])
    >>> b = (~a.mask).astype(int)
    >>> logsumexp(a.data, b=b), np.log(5)
    (1.6094379124341005, 1.6094379124341005)

    """
    a = _asarray_validated(a, check_finite=False)
    if b is not None:
        a, b = np.broadcast_arrays(a, b)
        if np.any(b == 0):
            a = a + 0.  # promote to at least float
            a[b == 0] = -np.inf

    # Shift by the max so that np.exp never overflows; the shift is added
    # back after the log. Non-finite maxima (e.g. all elements -inf) are
    # replaced by 0 so the subtraction below stays well defined.
    a_max = np.amax(a, axis=axis, keepdims=True)

    if a_max.ndim > 0:
        a_max[~np.isfinite(a_max)] = 0
    elif not np.isfinite(a_max):
        a_max = 0

    if b is not None:
        b = np.asarray(b)
        tmp = b * np.exp(a - a_max)
    else:
        tmp = np.exp(a - a_max)

    # suppress warnings about log of zero
    with np.errstate(divide='ignore'):
        s = np.sum(tmp, axis=axis, keepdims=keepdims)
        if return_sign:
            sgn = np.sign(s)
            s *= sgn  # /= makes more sense but we need zero -> zero
        out = np.log(s)

    if not keepdims:
        a_max = np.squeeze(a_max, axis=axis)
    out += a_max

    if return_sign:
        return out, sgn
    else:
        return out
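

# Illustrative sketch, not part of the module's public API: the helper below
# (the name ``_demo_logsumexp_shift`` is made up for this example) checks the
# max-shift identity used above, log(sum(exp(a))) == max(a) +
# log(sum(exp(a - max(a)))), on an input where the naive formula overflows,
# and shows how ``return_sign`` recovers log(|sum|) for a negative sum.
def _demo_logsumexp_shift():
    a = np.array([1000.0, 1000.0])
    # Naive evaluation overflows: np.exp(1000) is inf in float64.  The
    # shifted form gives the exact answer, 1000 + log(2).
    shifted = a.max() + np.log(np.sum(np.exp(a - a.max())))
    assert np.isclose(shifted, 1000 + np.log(2))
    assert np.isclose(logsumexp(a), shifted)

    # Here b * exp(a) sums to a negative number; return_sign=True returns
    # log(|sum|) together with the sign instead of nan.
    res, sgn = logsumexp([1.0, 2.0], b=[1, -1], return_sign=True)
    assert sgn == -1.0 and np.isclose(res, np.log(np.e**2 - np.e))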


def softmax(x, axis=None):
    r"""Compute the softmax function.

    The softmax function transforms each element of a collection by
    computing the exponential of each element divided by the sum of the
    exponentials of all the elements. That is, if `x` is a one-dimensional
    numpy array::

        softmax(x) = np.exp(x)/sum(np.exp(x))

    Parameters
    ----------
    x : array_like
        Input array.
    axis : int or tuple of ints, optional
        Axis to compute values along. Default is None and softmax will be
        computed over the entire array `x`.

    Returns
    -------
    s : ndarray
        An array the same shape as `x`. The result will sum to 1 along the
        specified axis.

    Notes
    -----
    The formula for the softmax function :math:`\sigma(x)` for a vector
    :math:`x = \{x_0, x_1, ..., x_{n-1}\}` is

    .. math:: \sigma(x)_j = \frac{e^{x_j}}{\sum_k e^{x_k}}

    The `softmax` function is the gradient of `logsumexp`.

    The implementation uses shifting to avoid overflow. See [1]_ for more
    details.

    .. versionadded:: 1.2.0

    References
    ----------
    .. [1] P. Blanchard, D.J. Higham, N.J. Higham, "Accurately computing the
       log-sum-exp and softmax functions", IMA Journal of Numerical Analysis,
       Vol. 41(4), :doi:`10.1093/imanum/draa038`.

    Examples
    --------
    >>> import numpy as np
    >>> from scipy.special import softmax
    >>> np.set_printoptions(precision=5)
    >>> x = np.array([[1, 0.5, 0.2, 3],
    ...               [1, -1, 7, 3],
    ...               [2, 12, 13, 3]])
    ...

    Compute the softmax transformation over the entire array.

    >>> m = softmax(x)
    >>> m
    array([[ 4.48309e-06, 2.71913e-06, 2.01438e-06, 3.31258e-05],
           [ 4.48309e-06, 6.06720e-07, 1.80861e-03, 3.31258e-05],
           [ 1.21863e-05, 2.68421e-01, 7.29644e-01, 3.31258e-05]])
    >>> m.sum()
    1.0

    Compute the softmax transformation along the first axis (i.e., the
    columns).

    >>> m = softmax(x, axis=0)
    >>> m
    array([[ 2.11942e-01, 1.01300e-05, 2.75394e-06, 3.33333e-01],
           [ 2.11942e-01, 2.26030e-06, 2.47262e-03, 3.33333e-01],
           [ 5.76117e-01, 9.99988e-01, 9.97525e-01, 3.33333e-01]])
    >>> m.sum(axis=0)
    array([ 1., 1., 1., 1.])

    Compute the softmax transformation along the second axis (i.e., the rows).

    >>> m = softmax(x, axis=1)
    >>> m
    array([[ 1.05877e-01, 6.42177e-02, 4.75736e-02, 7.82332e-01],
           [ 2.42746e-03, 3.28521e-04, 9.79307e-01, 1.79366e-02],
           [ 1.22094e-05, 2.68929e-01, 7.31025e-01, 3.31885e-05]])
    >>> m.sum(axis=1)
    array([ 1., 1., 1.])

    """
    x = _asarray_validated(x, check_finite=False)
    # Subtract the max before exponentiating; softmax is invariant under a
    # constant shift, and the shift prevents overflow in np.exp.
    x_max = np.amax(x, axis=axis, keepdims=True)
    exp_x_shifted = np.exp(x - x_max)
    return exp_x_shifted / np.sum(exp_x_shifted, axis=axis, keepdims=True)
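

# Illustrative sketch, not part of the module's public API: the helper below
# (``_demo_softmax_shift`` is a made-up name for this example) checks the
# shift invariance that justifies the implementation above,
# softmax(x) == softmax(x + c) for any constant c, and numerically
# spot-checks the claim in the Notes that softmax is the gradient of
# logsumexp.
def _demo_softmax_shift():
    x = np.array([1.0, 2.0, 3.0])
    assert np.allclose(softmax(x), softmax(x + 100.0))

    # A central-difference gradient of logsumexp should match softmax(x).
    eps = 1e-6
    grad = np.array([
        (logsumexp(x + eps * e) - logsumexp(x - eps * e)) / (2 * eps)
        for e in np.eye(3)
    ])
    assert np.allclose(grad, softmax(x), atol=1e-6)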


def log_softmax(x, axis=None):
    r"""Compute the logarithm of the softmax function.

    In principle::

        log_softmax(x) = log(softmax(x))

    but using a more accurate implementation.

    Parameters
    ----------
    x : array_like
        Input array.
    axis : int or tuple of ints, optional
        Axis to compute values along. Default is None and `log_softmax` will
        be computed over the entire array `x`.

    Returns
    -------
    s : ndarray or scalar
        An array with the same shape as `x`. Exponential of the result will
        sum to 1 along the specified axis. If `x` is a scalar, a scalar is
        returned.

    Notes
    -----
    `log_softmax` is more accurate than ``np.log(softmax(x))`` with inputs
    that make `softmax` saturate (see examples below).

    .. versionadded:: 1.5.0

    Examples
    --------
    >>> import numpy as np
    >>> from scipy.special import log_softmax
    >>> from scipy.special import softmax
    >>> np.set_printoptions(precision=5)

    >>> x = np.array([1000.0, 1.0])
    >>> y = log_softmax(x)
    >>> y
    array([   0., -999.])

    >>> with np.errstate(divide='ignore'):
    ...     y = np.log(softmax(x))
    ...
    >>> y
    array([  0., -inf])

    """
    x = _asarray_validated(x, check_finite=False)

    x_max = np.amax(x, axis=axis, keepdims=True)

    if x_max.ndim > 0:
        x_max[~np.isfinite(x_max)] = 0
    elif not np.isfinite(x_max):
        x_max = 0

    # log(softmax(x)) == (x - x_max) - log(sum(exp(x - x_max))); staying in
    # log space avoids the exp-then-log round trip that underflows to log(0)
    # when softmax saturates.
    tmp = x - x_max
    exp_tmp = np.exp(tmp)

    # suppress warnings about log of zero
    with np.errstate(divide='ignore'):
        s = np.sum(exp_tmp, axis=axis, keepdims=True)
        out = np.log(s)

    out = tmp - out
    return out
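

# Illustrative sketch, not part of the module's public API: the helper below
# (``_demo_log_softmax`` is a made-up name for this example) reproduces the
# accuracy claim from the docstring: once softmax saturates and underflows
# to 0, np.log returns -inf, while log_softmax keeps the finite answer.
def _demo_log_softmax():
    x = np.array([1000.0, 1.0])
    with np.errstate(divide='ignore'):
        naive = np.log(softmax(x))  # array([   0., -inf])
    accurate = log_softmax(x)       # array([   0., -999.])
    assert np.isneginf(naive[1])
    assert np.isclose(accurate[1], -999.0)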