# _mannwhitneyu.py

import numpy as np
from collections import namedtuple
from scipy import special
from scipy import stats

from ._axis_nan_policy import _axis_nan_policy_factory


def _broadcast_concatenate(x, y, axis):
    '''Broadcast then concatenate arrays, leaving concatenation axis last'''
    x = np.moveaxis(x, axis, -1)
    y = np.moveaxis(y, axis, -1)
    z = np.broadcast(x[..., 0], y[..., 0])
    x = np.broadcast_to(x, z.shape + (x.shape[-1],))
    y = np.broadcast_to(y, z.shape + (y.shape[-1],))
    z = np.concatenate((x, y), axis=-1)
    return x, y, z
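
# A rough, informal sketch of how the helper above behaves (hypothetical
# shapes chosen for illustration; not part of the public API):
#
#     x = np.zeros((5, 3))   # 5 observations per test along axis 0
#     y = np.zeros((7, 3))   # 7 observations per test along axis 0
#     x2, y2, xy = _broadcast_concatenate(x, y, axis=0)
#     # x2.shape == (3, 5), y2.shape == (3, 7), xy.shape == (3, 12),
#     # i.e. the samples end up along the last axis, concatenated in `xy`.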


class _MWU:
    '''Distribution of MWU statistic under the null hypothesis'''
    # Possible improvement: if m and n are small enough, use integer arithmetic

    def __init__(self):
        '''Minimal initializer'''
        self._fmnks = -np.ones((1, 1, 1))
        self._recursive = None

    def pmf(self, k, m, n):
        if (self._recursive is None and m <= 500 and n <= 500
                or self._recursive):
            return self.pmf_recursive(k, m, n)
        else:
            return self.pmf_iterative(k, m, n)

    def pmf_recursive(self, k, m, n):
        '''Probability mass function, recursive version'''
        self._resize_fmnks(m, n, np.max(k))
        # could loop over just the unique elements, but probably not worth
        # the time to find them
        for i in np.ravel(k):
            self._f(m, n, i)
        return self._fmnks[m, n, k] / special.binom(m + n, m)

    def pmf_iterative(self, k, m, n):
        '''Probability mass function, iterative version'''
        fmnks = {}
        for i in np.ravel(k):
            fmnks = _mwu_f_iterative(m, n, i, fmnks)
        return (np.array([fmnks[(m, n, ki)] for ki in k])
                / special.binom(m + n, m))

    def cdf(self, k, m, n):
        '''Cumulative distribution function'''
        # We could use the fact that the distribution is symmetric to avoid
        # summing more than m*n/2 terms, but it might not be worth the
        # overhead. Let's leave that to an improvement.
        pmfs = self.pmf(np.arange(0, np.max(k) + 1), m, n)
        cdfs = np.cumsum(pmfs)
        return cdfs[k]

    def sf(self, k, m, n):
        '''Survival function'''
        # Use the fact that the distribution is symmetric; i.e.
        # _f(m, n, m*n-k) = _f(m, n, k), and sum from the left
        k = m*n - k
        # Note that both CDF and SF include the PMF at k. The p-value is
        # calculated from the SF and should include the mass at k, so this
        # is desirable.
        return self.cdf(k, m, n)

    def _resize_fmnks(self, m, n, k):
        '''If necessary, expand the array that remembers PMF values'''
        # could probably use `np.pad` but I'm not sure it would save code
        shape_old = np.array(self._fmnks.shape)
        shape_new = np.array((m+1, n+1, k+1))
        if np.any(shape_new > shape_old):
            shape = np.maximum(shape_old, shape_new)
            fmnks = -np.ones(shape)             # create the new array
            m0, n0, k0 = shape_old
            fmnks[:m0, :n0, :k0] = self._fmnks  # copy remembered values
            self._fmnks = fmnks

    def _f(self, m, n, k):
        '''Recursive implementation of function of [3] Theorem 2.5'''

        # [3] Theorem 2.5 Line 1
        if k < 0 or m < 0 or n < 0 or k > m*n:
            return 0

        # if already calculated, return the value
        if self._fmnks[m, n, k] >= 0:
            return self._fmnks[m, n, k]

        if k == 0 and m >= 0 and n >= 0:  # [3] Theorem 2.5 Line 2
            fmnk = 1
        else:  # [3] Theorem 2.5 Line 3 / Equation 3
            fmnk = self._f(m-1, n, k-n) + self._f(m, n-1, k)

        self._fmnks[m, n, k] = fmnk  # remember result

        return fmnk


# Maintain state for faster repeat calls to mannwhitneyu w/ method='exact'
_mwu_state = _MWU()
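
# A small worked example of the exact null distribution maintained above
# (an informal sketch; the values follow from the recurrence
# f(m, n, k) = f(m-1, n, k-n) + f(m, n-1, k) of [3] Theorem 2.5):
# for m=2, n=1 there are binom(3, 2) = 3 equally likely arrangements, and U
# takes each of the values 0, 1, 2 in exactly one of them, so
#
#     _mwu_state.pmf(np.arange(3), 2, 1)  # ~ array([1/3, 1/3, 1/3])
#     _mwu_state.sf(0, 2, 1)              # ~ 1.0 (SF includes the mass at k)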


def _mwu_f_iterative(m, n, k, fmnks):
    '''Iterative implementation of function of [3] Theorem 2.5'''

    def _base_case(m, n, k):
        '''Base cases from recursive version'''
        # if already calculated, return the value
        if fmnks.get((m, n, k), -1) >= 0:
            return fmnks[(m, n, k)]
        # [3] Theorem 2.5 Line 1
        elif k < 0 or m < 0 or n < 0 or k > m*n:
            return 0
        # [3] Theorem 2.5 Line 2
        elif k == 0 and m >= 0 and n >= 0:
            return 1
        return None

    stack = [(m, n, k)]
    fmnk = None

    while stack:
        # Popping only if necessary would save a tiny bit of time, but NWI.
        m, n, k = stack.pop()

        # If we're at a base case, continue (stack unwinds)
        fmnk = _base_case(m, n, k)
        if fmnk is not None:
            fmnks[(m, n, k)] = fmnk
            continue

        # If both terms are base cases, continue (stack unwinds)
        f1 = _base_case(m-1, n, k-n)
        f2 = _base_case(m, n-1, k)
        if f1 is not None and f2 is not None:
            # [3] Theorem 2.5 Line 3 / Equation 3
            fmnk = f1 + f2
            fmnks[(m, n, k)] = fmnk
            continue

        # recurse deeper
        stack.append((m, n, k))
        if f1 is None:
            stack.append((m-1, n, k-n))
        if f2 is None:
            stack.append((m, n-1, k))

    return fmnks
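
# Design note (an inference, not stated elsewhere in this file): `_MWU.pmf`
# falls back to this iterative version when m or n exceeds 500, presumably
# because the recursion in `_MWU._f` can reach a depth of roughly m + n and
# would then risk exceeding Python's default recursion limit.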


def _tie_term(ranks):
    """Tie correction term"""
    # element i of t is the number of elements sharing rank i
    _, t = np.unique(ranks, return_counts=True, axis=-1)
    return (t**3 - t).sum(axis=-1)
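
# Worked example for the tie term (illustrative only): for the ranks
# [1.0, 2.5, 2.5, 4.0], the counts per distinct rank are t = [1, 2, 1],
# so the term is sum(t**3 - t) = 0 + 6 + 0 = 6.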


def _get_mwu_z(U, n1, n2, ranks, axis=0, continuity=True):
    '''Standardized MWU statistic'''
    # Follows mannwhitneyu [2]
    mu = n1 * n2 / 2
    n = n1 + n2

    # Tie correction according to [2]
    tie_term = np.apply_along_axis(_tie_term, -1, ranks)
    s = np.sqrt(n1*n2/12 * ((n + 1) - tie_term/(n*(n-1))))
    # equivalent to using scipy.stats.tiecorrect
    # T = np.apply_along_axis(stats.tiecorrect, -1, ranks)
    # s = np.sqrt(T * n1 * n2 * (n1+n2+1) / 12.0)

    numerator = U - mu

    # Continuity correction.
    # Because SF is always used to calculate the p-value, we can always
    # _subtract_ 0.5 for the continuity correction. This always increases the
    # p-value to account for the rest of the probability mass _at_ q = U.
    if continuity:
        numerator -= 0.5

    # no problem evaluating the norm SF at an infinity
    with np.errstate(divide='ignore', invalid='ignore'):
        z = numerator / s
    return z
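
# Rough sanity check of the standardization above against the example in the
# `mannwhitneyu` docstring below (approximate values): with n1=5, n2=4, no
# ties and U1=17, we get mu = 10 and s = sqrt(20/12 * 10) ~ 4.082, so with
# the continuity correction z ~ (17 - 10 - 0.5)/4.082 ~ 1.59 and
# 2*norm.sf(z) ~ 0.111, matching the reported asymptotic p-value.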


def _mwu_input_validation(x, y, use_continuity, alternative, axis, method):
    ''' Input validation and standardization for mannwhitneyu '''
    # Would use np.asarray_chkfinite, but infs are OK
    x, y = np.atleast_1d(x), np.atleast_1d(y)
    if np.isnan(x).any() or np.isnan(y).any():
        raise ValueError('`x` and `y` must not contain NaNs.')
    if np.size(x) == 0 or np.size(y) == 0:
        raise ValueError('`x` and `y` must be of nonzero size.')

    bools = {True, False}
    if use_continuity not in bools:
        raise ValueError(f'`use_continuity` must be one of {bools}.')

    alternatives = {"two-sided", "less", "greater"}
    alternative = alternative.lower()
    if alternative not in alternatives:
        raise ValueError(f'`alternative` must be one of {alternatives}.')

    axis_int = int(axis)
    if axis != axis_int:
        raise ValueError('`axis` must be an integer.')

    methods = {"asymptotic", "exact", "auto"}
    method = method.lower()
    if method not in methods:
        raise ValueError(f'`method` must be one of {methods}.')

    return x, y, use_continuity, alternative, axis_int, method


def _tie_check(xy):
    """Find any ties in data"""
    _, t = np.unique(xy, return_counts=True, axis=-1)
    return np.any(t != 1)


def _mwu_choose_method(n1, n2, xy, method):
    """Choose method 'asymptotic' or 'exact' depending on input size, ties"""

    # if both inputs are large, asymptotic is OK
    if n1 > 8 and n2 > 8:
        return "asymptotic"

    # if there are any ties, asymptotic is preferred
    if np.apply_along_axis(_tie_check, -1, xy).any():
        return "asymptotic"

    return "exact"


MannwhitneyuResult = namedtuple('MannwhitneyuResult', ('statistic', 'pvalue'))


@_axis_nan_policy_factory(MannwhitneyuResult, n_samples=2)
def mannwhitneyu(x, y, use_continuity=True, alternative="two-sided",
                 axis=0, method="auto"):
    r'''Perform the Mann-Whitney U rank test on two independent samples.

    The Mann-Whitney U test is a nonparametric test of the null hypothesis
    that the distribution underlying sample `x` is the same as the
    distribution underlying sample `y`. It is often used as a test of
    difference in location between distributions.

    Parameters
    ----------
    x, y : array-like
        N-d arrays of samples. The arrays must be broadcastable except along
        the dimension given by `axis`.
    use_continuity : bool, optional
        Whether a continuity correction (1/2) should be applied.
        Default is True when `method` is ``'asymptotic'``; has no effect
        otherwise.
    alternative : {'two-sided', 'less', 'greater'}, optional
        Defines the alternative hypothesis. Default is 'two-sided'.
        Let *F(u)* and *G(u)* be the cumulative distribution functions of the
        distributions underlying `x` and `y`, respectively. Then the following
        alternative hypotheses are available:

        * 'two-sided': the distributions are not equal, i.e. *F(u) ≠ G(u)* for
          at least one *u*.
        * 'less': the distribution underlying `x` is stochastically less
          than the distribution underlying `y`, i.e. *F(u) > G(u)* for all *u*.
        * 'greater': the distribution underlying `x` is stochastically greater
          than the distribution underlying `y`, i.e. *F(u) < G(u)* for all *u*.

        Under a more restrictive set of assumptions, the alternative hypotheses
        can be expressed in terms of the locations of the distributions;
        see [5]_ section 5.1.
    axis : int, optional
        Axis along which to perform the test. Default is 0.
    method : {'auto', 'asymptotic', 'exact'}, optional
        Selects the method used to calculate the *p*-value.
        Default is 'auto'. The following options are available.

        * ``'asymptotic'``: compares the standardized test statistic
          against the normal distribution, correcting for ties.
        * ``'exact'``: computes the exact *p*-value by comparing the observed
          :math:`U` statistic against the exact distribution of the :math:`U`
          statistic under the null hypothesis. No correction is made for ties.
        * ``'auto'``: chooses ``'exact'`` when the size of one of the samples
          is less than 8 and there are no ties; chooses ``'asymptotic'``
          otherwise.

    Returns
    -------
    res : MannwhitneyuResult
        An object containing attributes:

        statistic : float
            The Mann-Whitney U statistic corresponding with sample `x`. See
            Notes for the test statistic corresponding with sample `y`.
        pvalue : float
            The associated *p*-value for the chosen `alternative`.

    Notes
    -----
    If ``U1`` is the statistic corresponding with sample `x`, then the
    statistic corresponding with sample `y` is
    ``U2 = x.shape[axis] * y.shape[axis] - U1``.

    `mannwhitneyu` is for independent samples. For related / paired samples,
    consider `scipy.stats.wilcoxon`.

    `method` ``'exact'`` is recommended when there are no ties and when either
    sample size is less than 8 [1]_. The implementation follows the recurrence
    relation originally proposed in [1]_ as it is described in [3]_.
    Note that the exact method is *not* corrected for ties, but
    `mannwhitneyu` will not raise errors or warnings if there are ties in the
    data.

    The Mann-Whitney U test is a non-parametric version of the t-test for
    independent samples. When the means of samples from the populations
    are normally distributed, consider `scipy.stats.ttest_ind`.

    See Also
    --------
    scipy.stats.wilcoxon, scipy.stats.ranksums, scipy.stats.ttest_ind

    References
    ----------
    .. [1] H.B. Mann and D.R. Whitney, "On a test of whether one of two random
           variables is stochastically larger than the other", The Annals of
           Mathematical Statistics, Vol. 18, pp. 50-60, 1947.
    .. [2] Mann-Whitney U Test, Wikipedia,
           http://en.wikipedia.org/wiki/Mann-Whitney_U_test
    .. [3] A. Di Bucchianico, "Combinatorics, computer algebra, and the
           Wilcoxon-Mann-Whitney test", Journal of Statistical Planning and
           Inference, Vol. 79, pp. 349-364, 1999.
    .. [4] Rosie Shier, "Statistics: 2.3 The Mann-Whitney U Test", Mathematics
           Learning Support Centre, 2004.
    .. [5] Michael P. Fay and Michael A. Proschan. "Wilcoxon-Mann-Whitney
           or t-test? On assumptions for hypothesis tests and multiple
           interpretations of decision rules." Statistics Surveys, Vol. 4, pp.
           1-39, 2010. https://www.ncbi.nlm.nih.gov/pmc/articles/PMC2857732/

    Examples
    --------
    We follow the example from [4]_: nine randomly sampled young adults were
    diagnosed with type II diabetes at the ages below.

    >>> males = [19, 22, 16, 29, 24]
    >>> females = [20, 11, 17, 12]

    We use the Mann-Whitney U test to assess whether there is a statistically
    significant difference in the diagnosis age of males and females.
    The null hypothesis is that the distribution of male diagnosis ages is
    the same as the distribution of female diagnosis ages. We decide
    that a confidence level of 95% is required to reject the null hypothesis
    in favor of the alternative that the distributions are different.
    Since the number of samples is very small and there are no ties in the
    data, we can compare the observed test statistic against the *exact*
    distribution of the test statistic under the null hypothesis.

    >>> from scipy.stats import mannwhitneyu
    >>> U1, p = mannwhitneyu(males, females, method="exact")
    >>> print(U1)
    17.0

    `mannwhitneyu` always reports the statistic associated with the first
    sample, which, in this case, is males. This agrees with :math:`U_M = 17`
    reported in [4]_. The statistic associated with the second sample
    can be calculated:

    >>> nx, ny = len(males), len(females)
    >>> U2 = nx*ny - U1
    >>> print(U2)
    3.0

    This agrees with :math:`U_F = 3` reported in [4]_. The two-sided
    *p*-value can be calculated from either statistic, and the value produced
    by `mannwhitneyu` agrees with :math:`p = 0.11` reported in [4]_.

    >>> print(p)
    0.1111111111111111

    The exact distribution of the test statistic is asymptotically normal, so
    the example continues by comparing the exact *p*-value against the
    *p*-value produced using the normal approximation.

    >>> _, pnorm = mannwhitneyu(males, females, method="asymptotic")
    >>> print(pnorm)
    0.11134688653314041

    Here `mannwhitneyu`'s reported *p*-value appears to conflict with the
    value :math:`p = 0.09` given in [4]_. The reason is that [4]_
    does not apply the continuity correction performed by `mannwhitneyu`;
    `mannwhitneyu` reduces the distance between the test statistic and the
    mean :math:`\mu = n_x n_y / 2` by 0.5 to correct for the fact that the
    discrete statistic is being compared against a continuous distribution.
    Here, the :math:`U` statistic used is less than the mean, so we reduce
    the distance by *adding* 0.5 in the numerator.

    >>> import numpy as np
    >>> from scipy.stats import norm
    >>> U = min(U1, U2)
    >>> N = nx + ny
    >>> z = (U - nx*ny/2 + 0.5) / np.sqrt(nx*ny * (N + 1)/ 12)
    >>> p = 2 * norm.cdf(z)  # use CDF to get p-value from smaller statistic
    >>> print(p)
    0.11134688653314041

    If desired, we can disable the continuity correction to get a result
    that agrees with that reported in [4]_.

    >>> _, pnorm = mannwhitneyu(males, females, use_continuity=False,
    ...                         method="asymptotic")
    >>> print(pnorm)
    0.0864107329737

    Regardless of whether we perform an exact or asymptotic test, the
    probability of the test statistic being as extreme or more extreme by
    chance exceeds 5%, so we do not consider the results statistically
    significant.

    Suppose that, before seeing the data, we had hypothesized that females
    would tend to be diagnosed at a younger age than males.
    In that case, it would be natural to provide the female ages as the
    first input, and we would have performed a one-sided test using
    ``alternative = 'less'``: females are diagnosed at an age that is
    stochastically less than that of males.

    >>> res = mannwhitneyu(females, males, alternative="less", method="exact")
    >>> print(res)
    MannwhitneyuResult(statistic=3.0, pvalue=0.05555555555555555)

    Again, the probability of getting a sufficiently low value of the
    test statistic by chance under the null hypothesis is greater than 5%,
    so we do not reject the null hypothesis in favor of our alternative.

    If it is reasonable to assume that the means of samples from the
    populations are normally distributed, we could have used a t-test to
    perform the analysis.

    >>> from scipy.stats import ttest_ind
    >>> res = ttest_ind(females, males, alternative="less")
    >>> print(res)
    Ttest_indResult(statistic=-2.239334696520584, pvalue=0.030068441095757924)

    Under this assumption, the *p*-value would be low enough to reject the
    null hypothesis in favor of the alternative.

    '''
    x, y, use_continuity, alternative, axis_int, method = (
        _mwu_input_validation(x, y, use_continuity, alternative, axis, method))

    x, y, xy = _broadcast_concatenate(x, y, axis)
    n1, n2 = x.shape[-1], y.shape[-1]

    if method == "auto":
        method = _mwu_choose_method(n1, n2, xy, method)

    # Follows [2]
    ranks = stats.rankdata(xy, axis=-1)  # method 2, step 1
    R1 = ranks[..., :n1].sum(axis=-1)    # method 2, step 2
    U1 = R1 - n1*(n1+1)/2                # method 2, step 3
    U2 = n1 * n2 - U1                    # as U1 + U2 = n1 * n2

    if alternative == "greater":
        U, f = U1, 1  # U is the statistic to use for p-value, f is a factor
    elif alternative == "less":
        U, f = U2, 1  # Due to symmetry, use SF of U2 rather than CDF of U1
    else:
        U, f = np.maximum(U1, U2), 2  # multiply SF by two for two-sided test

    if method == "exact":
        p = _mwu_state.sf(U.astype(int), n1, n2)
    elif method == "asymptotic":
        z = _get_mwu_z(U, n1, n2, ranks, continuity=use_continuity)
        p = stats.norm.sf(z)
    p *= f

    # Ensure that the p-value does not exceed 1.
    # This could happen for the exact test when U == m*n/2.
    p = np.clip(p, 0, 1)

    return MannwhitneyuResult(U1, p)
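
# Informal usage sketch for the vectorized path (hypothetical data, not a
# doctest): with `rng = np.random.default_rng()`, `x = rng.random((3, 10))`
# and `y = rng.random((3, 12))`, `mannwhitneyu(x, y, axis=-1)` returns a
# statistic and p-value of shape (3,), i.e. one test per row.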