test_axis_nan_policy.py 43 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044
# Many scipy.stats functions support `axis` and `nan_policy` parameters.
# When the two are combined, it can be tricky to get all the behavior just
# right. This file contains a suite of common tests for scipy.stats functions
# that support `axis` and `nan_policy` and additional tests for some associated
# functions in stats._axis_nan_policy.
  6. from itertools import product, combinations_with_replacement, permutations
  7. import re
  8. import pickle
  9. import pytest
  10. import numpy as np
  11. from numpy.testing import assert_allclose, assert_equal, suppress_warnings
  12. from scipy import stats
  13. from scipy.stats._axis_nan_policy import _masked_arrays_2_sentinel_arrays
  14. def unpack_ttest_result(res):
  15. low, high = res.confidence_interval()
  16. return (res.statistic, res.pvalue, res.df, res._standard_error,
  17. res._estimate, low, high)
  18. axis_nan_policy_cases = [
  19. # function, args, kwds, number of samples, number of outputs,
  20. # ... paired, unpacker function
  21. # args, kwds typically aren't needed; just showing that they work
  22. (stats.kruskal, tuple(), dict(), 3, 2, False, None), # 4 samples is slow
  23. (stats.ranksums, ('less',), dict(), 2, 2, False, None),
  24. (stats.mannwhitneyu, tuple(), {'method': 'asymptotic'}, 2, 2, False, None),
  25. (stats.wilcoxon, ('pratt',), {'mode': 'auto'}, 2, 2, True,
  26. lambda res: (res.statistic, res.pvalue)),
  27. (stats.wilcoxon, tuple(), dict(), 1, 2, True,
  28. lambda res: (res.statistic, res.pvalue)),
  29. (stats.wilcoxon, tuple(), {'mode': 'approx'}, 1, 3, True,
  30. lambda res: (res.statistic, res.pvalue, res.zstatistic)),
  31. (stats.gmean, tuple(), dict(), 1, 1, False, lambda x: (x,)),
  32. (stats.hmean, tuple(), dict(), 1, 1, False, lambda x: (x,)),
  33. (stats.pmean, (1.42,), dict(), 1, 1, False, lambda x: (x,)),
  34. (stats.kurtosis, tuple(), dict(), 1, 1, False, lambda x: (x,)),
  35. (stats.skew, tuple(), dict(), 1, 1, False, lambda x: (x,)),
  36. (stats.kstat, tuple(), dict(), 1, 1, False, lambda x: (x,)),
  37. (stats.kstatvar, tuple(), dict(), 1, 1, False, lambda x: (x,)),
  38. (stats.moment, tuple(), dict(), 1, 1, False, lambda x: (x,)),
  39. (stats.moment, tuple(), dict(moment=[1, 2]), 1, 2, False, None),
  40. (stats.jarque_bera, tuple(), dict(), 1, 2, False, None),
  41. (stats.ttest_1samp, (np.array([0]),), dict(), 1, 7, False,
  42. unpack_ttest_result),
  43. (stats.ttest_rel, tuple(), dict(), 2, 7, True, unpack_ttest_result)
  44. ]
  45. # If the message is one of those expected, put nans in
  46. # appropriate places of `statistics` and `pvalues`
  47. too_small_messages = {"The input contains nan", # for nan_policy="raise"
  48. "Degrees of freedom <= 0 for slice",
  49. "x and y should have at least 5 elements",
  50. "Data must be at least length 3",
  51. "The sample must contain at least two",
  52. "x and y must contain at least two",
  53. "division by zero",
  54. "Mean of empty slice",
  55. "Data passed to ks_2samp must not be empty",
  56. "Not enough test observations",
  57. "Not enough other observations",
  58. "At least one observation is required",
  59. "zero-size array to reduction operation maximum",
  60. "`x` and `y` must be of nonzero size.",
  61. "The exact distribution of the Wilcoxon test",
  62. "Data input must not be empty"}
  63. # If the message is one of these, results of the function may be inaccurate,
  64. # but NaNs are not to be placed
  65. inaccuracy_messages = {"Precision loss occurred in moment calculation",
  66. "Sample size too small for normal approximation."}
  67. def _mixed_data_generator(n_samples, n_repetitions, axis, rng,
  68. paired=False):
  69. # generate random samples to check the response of hypothesis tests to
  70. # samples with different (but broadcastable) shapes and various
  71. # nan patterns (e.g. all nans, some nans, no nans) along axis-slices
  72. data = []
  73. for i in range(n_samples):
  74. n_patterns = 6 # number of distinct nan patterns
  75. n_obs = 20 if paired else 20 + i # observations per axis-slice
  76. x = np.ones((n_repetitions, n_patterns, n_obs)) * np.nan
  77. for j in range(n_repetitions):
  78. samples = x[j, :, :]
  79. # case 0: axis-slice with all nans (0 reals)
  80. # cases 1-3: axis-slice with 1-3 reals (the rest nans)
  81. # case 4: axis-slice with mostly (all but two) reals
  82. # case 5: axis slice with all reals
  83. for k, n_reals in enumerate([0, 1, 2, 3, n_obs-2, n_obs]):
  84. # for cases 1-3, need paired nansw to be in the same place
  85. indices = rng.permutation(n_obs)[:n_reals]
  86. samples[k, indices] = rng.random(size=n_reals)
  87. # permute the axis-slices just to show that order doesn't matter
  88. samples[:] = rng.permutation(samples, axis=0)
  89. # For multi-sample tests, we want to test broadcasting and check
  90. # that nan policy works correctly for each nan pattern for each input.
  91. # This takes care of both simultaneosly.
  92. new_shape = [n_repetitions] + [1]*n_samples + [n_obs]
  93. new_shape[1 + i] = 6
  94. x = x.reshape(new_shape)
  95. x = np.moveaxis(x, -1, axis)
  96. data.append(x)
  97. return data
  98. def _homogeneous_data_generator(n_samples, n_repetitions, axis, rng,
  99. paired=False, all_nans=True):
  100. # generate random samples to check the response of hypothesis tests to
  101. # samples with different (but broadcastable) shapes and homogeneous
  102. # data (all nans or all finite)
  103. data = []
  104. for i in range(n_samples):
  105. n_obs = 20 if paired else 20 + i # observations per axis-slice
  106. shape = [n_repetitions] + [1]*n_samples + [n_obs]
  107. shape[1 + i] = 2
  108. x = np.ones(shape) * np.nan if all_nans else rng.random(shape)
  109. x = np.moveaxis(x, -1, axis)
  110. data.append(x)
  111. return data
  112. def nan_policy_1d(hypotest, data1d, unpacker, *args, n_outputs=2,
  113. nan_policy='raise', paired=False, _no_deco=True, **kwds):
  114. # Reference implementation for how `nan_policy` should work for 1d samples
  115. if nan_policy == 'raise':
  116. for sample in data1d:
  117. if np.any(np.isnan(sample)):
  118. raise ValueError("The input contains nan values")
  119. elif nan_policy == 'propagate':
  120. # For all hypothesis tests tested, returning nans is the right thing.
  121. # But many hypothesis tests don't propagate correctly (e.g. they treat
  122. # np.nan the same as np.inf, which doesn't make sense when ranks are
  123. # involved) so override that behavior here.
  124. for sample in data1d:
  125. if np.any(np.isnan(sample)):
  126. return np.full(n_outputs, np.nan)
  127. elif nan_policy == 'omit':
  128. # manually omit nans (or pairs in which at least one element is nan)
  129. if not paired:
  130. data1d = [sample[~np.isnan(sample)] for sample in data1d]
  131. else:
  132. nan_mask = np.isnan(data1d[0])
  133. for sample in data1d[1:]:
  134. nan_mask = np.logical_or(nan_mask, np.isnan(sample))
  135. data1d = [sample[~nan_mask] for sample in data1d]
  136. return unpacker(hypotest(*data1d, *args, _no_deco=_no_deco, **kwds))
  137. @pytest.mark.parametrize(("hypotest", "args", "kwds", "n_samples", "n_outputs",
  138. "paired", "unpacker"), axis_nan_policy_cases)
  139. @pytest.mark.parametrize(("nan_policy"), ("propagate", "omit", "raise"))
  140. @pytest.mark.parametrize(("axis"), (1,))
  141. @pytest.mark.parametrize(("data_generator"), ("mixed",))
  142. def test_axis_nan_policy_fast(hypotest, args, kwds, n_samples, n_outputs,
  143. paired, unpacker, nan_policy, axis,
  144. data_generator):
  145. _axis_nan_policy_test(hypotest, args, kwds, n_samples, n_outputs, paired,
  146. unpacker, nan_policy, axis, data_generator)
  147. @pytest.mark.slow
  148. @pytest.mark.parametrize(("hypotest", "args", "kwds", "n_samples", "n_outputs",
  149. "paired", "unpacker"), axis_nan_policy_cases)
  150. @pytest.mark.parametrize(("nan_policy"), ("propagate", "omit", "raise"))
  151. @pytest.mark.parametrize(("axis"), range(-3, 3))
  152. @pytest.mark.parametrize(("data_generator"),
  153. ("all_nans", "all_finite", "mixed"))
  154. def test_axis_nan_policy_full(hypotest, args, kwds, n_samples, n_outputs,
  155. paired, unpacker, nan_policy, axis,
  156. data_generator):
  157. _axis_nan_policy_test(hypotest, args, kwds, n_samples, n_outputs, paired,
  158. unpacker, nan_policy, axis, data_generator)
def _axis_nan_policy_test(hypotest, args, kwds, n_samples, n_outputs, paired,
                          unpacker, nan_policy, axis, data_generator):
    """Compare 1-D and vectorized calls of `hypotest` to a reference.

    For every axis-slice of the generated data, the reference
    `nan_policy_1d` result is computed and also compared with a direct 1-D
    call of `hypotest`. A single vectorized call of `hypotest` along `axis`
    is then compared (``rtol=1e-15`` and identical dtype) with the collected
    statistics and, when `hypotest` yields exactly two outputs, p-values.
    With ``nan_policy='raise'`` and NaN-containing data, the vectorized call
    must raise instead.

    `data_generator` must be "mixed", "all_nans" or "all_finite"; any other
    value leaves `inherent_size` and `data` unbound (NameError).
    """
    # Tests the 1D and vectorized behavior of hypothesis tests against a
    # reference implementation (nan_policy_1d with np.ndenumerate)

    # Some hypothesis tests return a non-iterable that needs an `unpacker` to
    # extract the statistic and p-value. For those that don't:
    if not unpacker:
        def unpacker(res):
            return res

    rng = np.random.default_rng(0)

    # Generate multi-dimensional test data with all important combinations
    # of patterns of nans along `axis`
    n_repetitions = 3  # number of repetitions of each pattern
    data_gen_kwds = {'n_samples': n_samples, 'n_repetitions': n_repetitions,
                     'axis': axis, 'rng': rng, 'paired': paired}
    if data_generator == 'mixed':
        inherent_size = 6  # number of distinct types of patterns
        data = _mixed_data_generator(**data_gen_kwds)
    elif data_generator == 'all_nans':
        inherent_size = 2  # hard-coded in _homogeneous_data_generator
        data_gen_kwds['all_nans'] = True
        data = _homogeneous_data_generator(**data_gen_kwds)
    elif data_generator == 'all_finite':
        inherent_size = 2  # hard-coded in _homogeneous_data_generator
        data_gen_kwds['all_nans'] = False
        data = _homogeneous_data_generator(**data_gen_kwds)

    # Shape of the broadcast result with the observation axis removed.
    output_shape = [n_repetitions] + [inherent_size]*n_samples

    # To generate reference behavior to compare against, loop over the axis-
    # slices in data. Make indexing easier by moving `axis` to the end and
    # broadcasting all samples to the same shape.
    data_b = [np.moveaxis(sample, axis, -1) for sample in data]
    data_b = [np.broadcast_to(sample, output_shape + [sample.shape[-1]])
              for sample in data_b]
    statistics = np.zeros(output_shape)
    pvalues = np.zeros(output_shape)

    for i, _ in np.ndenumerate(statistics):
        data1d = [sample[i] for sample in data_b]
        with np.errstate(divide='ignore', invalid='ignore'):
            try:
                res1d = nan_policy_1d(hypotest, data1d, unpacker, *args,
                                      n_outputs=n_outputs,
                                      nan_policy=nan_policy,
                                      paired=paired, _no_deco=True, **kwds)

                # Eventually we'll check the results of a single, vectorized
                # call of `hypotest` against the arrays `statistics` and
                # `pvalues` populated using the reference `nan_policy_1d`.
                # But while we're at it, check the results of a 1D call to
                # `hypotest` against the reference `nan_policy_1d`.
                res1db = unpacker(hypotest(*data1d, *args,
                                           nan_policy=nan_policy, **kwds))
                assert_equal(res1db[0], res1d[0])
                # NOTE(review): the second output is only compared when there
                # are exactly two outputs; cases with more outputs (e.g. 3 or
                # 7 in `axis_nan_policy_cases`) have only the statistic
                # checked here.
                if len(res1db) == 2:
                    assert_equal(res1db[1], res1d[1])

            # When there is not enough data in 1D samples, many existing
            # hypothesis tests raise errors instead of returning nans.
            # For vectorized calls, we put nans in the corresponding elements
            # of the output.
            except (RuntimeWarning, UserWarning, ValueError,
                    ZeroDivisionError) as e:

                # whatever it is, make sure same error is raised by both
                # `nan_policy_1d` and `hypotest`
                with pytest.raises(type(e), match=re.escape(str(e))):
                    nan_policy_1d(hypotest, data1d, unpacker, *args,
                                  n_outputs=n_outputs, nan_policy=nan_policy,
                                  paired=paired, _no_deco=True, **kwds)
                with pytest.raises(type(e), match=re.escape(str(e))):
                    hypotest(*data1d, *args, nan_policy=nan_policy, **kwds)

                # Messages are matched as prefixes of the error text.
                if any([str(e).startswith(message)
                        for message in too_small_messages]):
                    res1d = np.full(n_outputs, np.nan)
                elif any([str(e).startswith(message)
                          for message in inaccuracy_messages]):
                    # Result may be inaccurate but is still the reference:
                    # recompute it with the warnings silenced.
                    with suppress_warnings() as sup:
                        sup.filter(RuntimeWarning)
                        sup.filter(UserWarning)
                        res1d = nan_policy_1d(hypotest, data1d, unpacker,
                                              *args, n_outputs=n_outputs,
                                              nan_policy=nan_policy,
                                              paired=paired, _no_deco=True,
                                              **kwds)
                else:
                    raise e
        statistics[i] = res1d[0]
        # NOTE(review): as above, p-values are only recorded (and later
        # compared) when the reference yields exactly two outputs.
        if len(res1d) == 2:
            pvalues[i] = res1d[1]

    # Perform a vectorized call to the hypothesis test.
    # If `nan_policy == 'raise'`, check that it raises the appropriate error.
    # If not, compare the output against `statistics` and `pvalues`
    if nan_policy == 'raise' and not data_generator == "all_finite":
        message = 'The input contains nan values'
        with pytest.raises(ValueError, match=message):
            hypotest(*data, axis=axis, nan_policy=nan_policy, *args, **kwds)

    else:
        with suppress_warnings() as sup, \
             np.errstate(divide='ignore', invalid='ignore'):
            sup.filter(RuntimeWarning, "Precision loss occurred in moment")
            sup.filter(UserWarning, "Sample size too small for normal "
                                    "approximation.")
            res = unpacker(hypotest(*data, axis=axis, nan_policy=nan_policy,
                                    *args, **kwds))
        assert_allclose(res[0], statistics, rtol=1e-15)
        assert_equal(res[0].dtype, statistics.dtype)

        if len(res) == 2:
            assert_allclose(res[1], pvalues, rtol=1e-15)
            assert_equal(res[1].dtype, pvalues.dtype)
@pytest.mark.parametrize(("hypotest", "args", "kwds", "n_samples", "n_outputs",
                          "paired", "unpacker"), axis_nan_policy_cases)
@pytest.mark.parametrize(("nan_policy"), ("propagate", "omit", "raise"))
@pytest.mark.parametrize(("data_generator"),
                         ("all_nans", "all_finite", "mixed", "empty"))
def test_axis_nan_policy_axis_is_None(hypotest, args, kwds, n_samples,
                                      n_outputs, paired, unpacker, nan_policy,
                                      data_generator):
    """Check behavior when ``axis=None``.

    With ``nan_policy='raise'`` and NaN-containing data, both the N-d input
    and its pre-raveled version must raise. Otherwise three computations
    must agree: the reference `nan_policy_1d` on raveled data, `hypotest` on
    raveled data, and `hypotest` with ``axis=None`` on N-d data. Either all
    outputs are equal, or all raise with the same message, and that message
    must start with one of `too_small_messages`.
    """
    # check for correct behavior when `axis=None`

    # identity unpacker for results that are already indexable
    if not unpacker:
        def unpacker(res):
            return res

    rng = np.random.default_rng(0)

    # Each sample is 2-D: (2, 0) for "empty", (2, 20) otherwise.
    if data_generator == "empty":
        data = [rng.random((2, 0)) for i in range(n_samples)]
    else:
        data = [rng.random((2, 20)) for i in range(n_samples)]

    if data_generator == "mixed":
        # roughly 10% of elements become nan
        masks = [rng.random((2, 20)) > 0.9 for i in range(n_samples)]
        for sample, mask in zip(data, masks):
            sample[mask] = np.nan
    elif data_generator == "all_nans":
        data = [sample * np.nan for sample in data]

    data_raveled = [sample.ravel() for sample in data]

    if nan_policy == 'raise' and data_generator not in {"all_finite", "empty"}:
        message = 'The input contains nan values'

        # check for correct behavior whether or not data is 1d to begin with
        with pytest.raises(ValueError, match=message):
            hypotest(*data, axis=None, nan_policy=nan_policy,
                     *args, **kwds)
        with pytest.raises(ValueError, match=message):
            hypotest(*data_raveled, axis=None, nan_policy=nan_policy,
                     *args, **kwds)

    else:
        # behavior of reference implementation with 1d input, hypotest with 1d
        # input, and hypotest with Nd input should match, whether that means
        # that outputs are equal or they raise the same exception

        # Error messages of the three computations (None if no error).
        ea_str, eb_str, ec_str = None, None, None
        with np.errstate(divide='ignore', invalid='ignore'):
            try:
                res1da = nan_policy_1d(hypotest, data_raveled, unpacker, *args,
                                       n_outputs=n_outputs,
                                       nan_policy=nan_policy, paired=paired,
                                       _no_deco=True, **kwds)
            except (RuntimeWarning, ValueError, ZeroDivisionError) as ea:
                ea_str = str(ea)

            try:
                res1db = unpacker(hypotest(*data_raveled, *args,
                                           nan_policy=nan_policy, **kwds))
            except (RuntimeWarning, ValueError, ZeroDivisionError) as eb:
                eb_str = str(eb)

            try:
                res1dc = unpacker(hypotest(*data, *args, axis=None,
                                           nan_policy=nan_policy, **kwds))
            except (RuntimeWarning, ValueError, ZeroDivisionError) as ec:
                ec_str = str(ec)

            if ea_str or eb_str or ec_str:
                # If anything raised, the reference must have raised a
                # "too small" error (str(None) == "None" fails this check),
                # and all three messages must be identical.
                assert any([str(ea_str).startswith(message)
                            for message in too_small_messages])
                assert ea_str == eb_str == ec_str
            else:
                assert_equal(res1db, res1da)
                assert_equal(res1dc, res1da)
  327. # Test keepdims for:
  328. # - single-output and multi-output functions (gmean and mannwhitneyu)
  329. # - Axis negative, positive, None, and tuple
  330. # - 1D with no NaNs
  331. # - 1D with NaN propagation
  332. # - Zero-sized output
  333. @pytest.mark.parametrize("nan_policy", ("omit", "propagate"))
  334. @pytest.mark.parametrize(
  335. ("hypotest", "args", "kwds", "n_samples", "unpacker"),
  336. ((stats.gmean, tuple(), dict(), 1, lambda x: (x,)),
  337. (stats.mannwhitneyu, tuple(), {'method': 'asymptotic'}, 2, None))
  338. )
  339. @pytest.mark.parametrize(
  340. ("sample_shape", "axis_cases"),
  341. (((2, 3, 3, 4), (None, 0, -1, (0, 2), (1, -1), (3, 1, 2, 0))),
  342. ((10, ), (0, -1)),
  343. ((20, 0), (0, 1)))
  344. )
  345. def test_keepdims(hypotest, args, kwds, n_samples, unpacker,
  346. sample_shape, axis_cases, nan_policy):
  347. # test if keepdims parameter works correctly
  348. if not unpacker:
  349. def unpacker(res):
  350. return res
  351. rng = np.random.default_rng(0)
  352. data = [rng.random(sample_shape) for _ in range(n_samples)]
  353. nan_data = [sample.copy() for sample in data]
  354. nan_mask = [rng.random(sample_shape) < 0.2 for _ in range(n_samples)]
  355. for sample, mask in zip(nan_data, nan_mask):
  356. sample[mask] = np.nan
  357. for axis in axis_cases:
  358. expected_shape = list(sample_shape)
  359. if axis is None:
  360. expected_shape = np.ones(len(sample_shape))
  361. else:
  362. if isinstance(axis, int):
  363. expected_shape[axis] = 1
  364. else:
  365. for ax in axis:
  366. expected_shape[ax] = 1
  367. expected_shape = tuple(expected_shape)
  368. res = unpacker(hypotest(*data, *args, axis=axis, keepdims=True,
  369. **kwds))
  370. res_base = unpacker(hypotest(*data, *args, axis=axis, keepdims=False,
  371. **kwds))
  372. nan_res = unpacker(hypotest(*nan_data, *args, axis=axis,
  373. keepdims=True, nan_policy=nan_policy,
  374. **kwds))
  375. nan_res_base = unpacker(hypotest(*nan_data, *args, axis=axis,
  376. keepdims=False,
  377. nan_policy=nan_policy, **kwds))
  378. for r, r_base, rn, rn_base in zip(res, res_base, nan_res,
  379. nan_res_base):
  380. assert r.shape == expected_shape
  381. r = np.squeeze(r, axis=axis)
  382. assert_equal(r, r_base)
  383. assert rn.shape == expected_shape
  384. rn = np.squeeze(rn, axis=axis)
  385. assert_equal(rn, rn_base)
  386. @pytest.mark.parametrize(("fun", "nsamp"),
  387. [(stats.kstat, 1),
  388. (stats.kstatvar, 1)])
  389. def test_hypotest_back_compat_no_axis(fun, nsamp):
  390. m, n = 8, 9
  391. rng = np.random.default_rng(0)
  392. x = rng.random((nsamp, m, n))
  393. res = fun(*x)
  394. res2 = fun(*x, _no_deco=True)
  395. res3 = fun([xi.ravel() for xi in x])
  396. assert_equal(res, res2)
  397. assert_equal(res, res3)
  398. @pytest.mark.parametrize(("axis"), (0, 1, 2))
  399. def test_axis_nan_policy_decorated_positional_axis(axis):
  400. # Test for correct behavior of function decorated with
  401. # _axis_nan_policy_decorator whether `axis` is provided as positional or
  402. # keyword argument
  403. shape = (8, 9, 10)
  404. rng = np.random.default_rng(0)
  405. x = rng.random(shape)
  406. y = rng.random(shape)
  407. res1 = stats.mannwhitneyu(x, y, True, 'two-sided', axis)
  408. res2 = stats.mannwhitneyu(x, y, True, 'two-sided', axis=axis)
  409. assert_equal(res1, res2)
  410. message = "mannwhitneyu() got multiple values for argument 'axis'"
  411. with pytest.raises(TypeError, match=re.escape(message)):
  412. stats.mannwhitneyu(x, y, True, 'two-sided', axis, axis=axis)
  413. def test_axis_nan_policy_decorated_positional_args():
  414. # Test for correct behavior of function decorated with
  415. # _axis_nan_policy_decorator when function accepts *args
  416. shape = (3, 8, 9, 10)
  417. rng = np.random.default_rng(0)
  418. x = rng.random(shape)
  419. x[0, 0, 0, 0] = np.nan
  420. stats.kruskal(*x)
  421. message = "kruskal() got an unexpected keyword argument 'samples'"
  422. with pytest.raises(TypeError, match=re.escape(message)):
  423. stats.kruskal(samples=x)
  424. with pytest.raises(TypeError, match=re.escape(message)):
  425. stats.kruskal(*x, samples=x)
  426. def test_axis_nan_policy_decorated_keyword_samples():
  427. # Test for correct behavior of function decorated with
  428. # _axis_nan_policy_decorator whether samples are provided as positional or
  429. # keyword arguments
  430. shape = (2, 8, 9, 10)
  431. rng = np.random.default_rng(0)
  432. x = rng.random(shape)
  433. x[0, 0, 0, 0] = np.nan
  434. res1 = stats.mannwhitneyu(*x)
  435. res2 = stats.mannwhitneyu(x=x[0], y=x[1])
  436. assert_equal(res1, res2)
  437. message = "mannwhitneyu() got multiple values for argument"
  438. with pytest.raises(TypeError, match=re.escape(message)):
  439. stats.mannwhitneyu(*x, x=x[0], y=x[1])
  440. @pytest.mark.parametrize(("hypotest", "args", "kwds", "n_samples", "n_outputs",
  441. "paired", "unpacker"), axis_nan_policy_cases)
  442. def test_axis_nan_policy_decorated_pickled(hypotest, args, kwds, n_samples,
  443. n_outputs, paired, unpacker):
  444. rng = np.random.default_rng(0)
  445. # Some hypothesis tests return a non-iterable that needs an `unpacker` to
  446. # extract the statistic and p-value. For those that don't:
  447. if not unpacker:
  448. def unpacker(res):
  449. return res
  450. data = rng.uniform(size=(n_samples, 2, 30))
  451. pickled_hypotest = pickle.dumps(hypotest)
  452. unpickled_hypotest = pickle.loads(pickled_hypotest)
  453. res1 = unpacker(hypotest(*data, *args, axis=-1, **kwds))
  454. res2 = unpacker(unpickled_hypotest(*data, *args, axis=-1, **kwds))
  455. assert_allclose(res1, res2, rtol=1e-12)
  456. def test_check_empty_inputs():
  457. # Test that _check_empty_inputs is doing its job, at least for single-
  458. # sample inputs. (Multi-sample functionality is tested below.)
  459. # If the input sample is not empty, it should return None.
  460. # If the input sample is empty, it should return an array of NaNs or an
  461. # empty array of appropriate shape. np.mean is used as a reference for the
  462. # output because, like the statistics calculated by these functions,
  463. # it works along and "consumes" `axis` but preserves the other axes.
  464. for i in range(5):
  465. for combo in combinations_with_replacement([0, 1, 2], i):
  466. for axis in range(len(combo)):
  467. samples = (np.zeros(combo),)
  468. output = stats._axis_nan_policy._check_empty_inputs(samples,
  469. axis)
  470. if output is not None:
  471. with np.testing.suppress_warnings() as sup:
  472. sup.filter(RuntimeWarning, "Mean of empty slice.")
  473. sup.filter(RuntimeWarning, "invalid value encountered")
  474. reference = samples[0].mean(axis=axis)
  475. np.testing.assert_equal(output, reference)
  476. def _check_arrays_broadcastable(arrays, axis):
  477. # https://numpy.org/doc/stable/user/basics.broadcasting.html
  478. # "When operating on two arrays, NumPy compares their shapes element-wise.
  479. # It starts with the trailing (i.e. rightmost) dimensions and works its
  480. # way left.
  481. # Two dimensions are compatible when
  482. # 1. they are equal, or
  483. # 2. one of them is 1
  484. # ...
  485. # Arrays do not need to have the same number of dimensions."
  486. # (Clarification: if the arrays are compatible according to the criteria
  487. # above and an array runs out of dimensions, it is still compatible.)
  488. # Below, we follow the rules above except ignoring `axis`
  489. n_dims = max([arr.ndim for arr in arrays])
  490. if axis is not None:
  491. # convert to negative axis
  492. axis = (-n_dims + axis) if axis >= 0 else axis
  493. for dim in range(1, n_dims+1): # we'll index from -1 to -n_dims, inclusive
  494. if -dim == axis:
  495. continue # ignore lengths along `axis`
  496. dim_lengths = set()
  497. for arr in arrays:
  498. if dim <= arr.ndim and arr.shape[-dim] != 1:
  499. dim_lengths.add(arr.shape[-dim])
  500. if len(dim_lengths) > 1:
  501. return False
  502. return True
  503. @pytest.mark.slow
  504. @pytest.mark.parametrize(("hypotest", "args", "kwds", "n_samples", "n_outputs",
  505. "paired", "unpacker"), axis_nan_policy_cases)
  506. def test_empty(hypotest, args, kwds, n_samples, n_outputs, paired, unpacker):
  507. # test for correct output shape when at least one input is empty
  508. if unpacker is None:
  509. unpacker = lambda res: (res[0], res[1]) # noqa: E731
  510. def small_data_generator(n_samples, n_dims):
  511. def small_sample_generator(n_dims):
  512. # return all possible "small" arrays in up to n_dim dimensions
  513. for i in n_dims:
  514. # "small" means with size along dimension either 0 or 1
  515. for combo in combinations_with_replacement([0, 1, 2], i):
  516. yield np.zeros(combo)
  517. # yield all possible combinations of small samples
  518. gens = [small_sample_generator(n_dims) for i in range(n_samples)]
  519. for i in product(*gens):
  520. yield i
  521. n_dims = [2, 3]
  522. for samples in small_data_generator(n_samples, n_dims):
  523. # this test is only for arrays of zero size
  524. if not any((sample.size == 0 for sample in samples)):
  525. continue
  526. max_axis = max((sample.ndim for sample in samples))
  527. # need to test for all valid values of `axis` parameter, too
  528. for axis in range(-max_axis, max_axis):
  529. try:
  530. # After broadcasting, all arrays are the same shape, so
  531. # the shape of the output should be the same as a single-
  532. # sample statistic. Use np.mean as a reference.
  533. concat = stats._stats_py._broadcast_concatenate(samples, axis)
  534. with np.testing.suppress_warnings() as sup:
  535. sup.filter(RuntimeWarning, "Mean of empty slice.")
  536. sup.filter(RuntimeWarning, "invalid value encountered")
  537. expected = np.mean(concat, axis=axis) * np.nan
  538. res = hypotest(*samples, *args, axis=axis, **kwds)
  539. res = unpacker(res)
  540. for i in range(n_outputs):
  541. assert_equal(res[i], expected)
  542. except ValueError:
  543. # confirm that the arrays truly are not broadcastable
  544. assert not _check_arrays_broadcastable(samples, axis)
  545. # confirm that _both_ `_broadcast_concatenate` and `hypotest`
  546. # produce this information.
  547. message = "Array shapes are incompatible for broadcasting."
  548. with pytest.raises(ValueError, match=message):
  549. stats._stats_py._broadcast_concatenate(samples, axis)
  550. with pytest.raises(ValueError, match=message):
  551. hypotest(*samples, *args, axis=axis, **kwds)
  552. def test_masked_array_2_sentinel_array():
  553. # prepare arrays
  554. np.random.seed(0)
  555. A = np.random.rand(10, 11, 12)
  556. B = np.random.rand(12)
  557. mask = A < 0.5
  558. A = np.ma.masked_array(A, mask)
  559. # set arbitrary elements to special values
  560. # (these values might have been considered for use as sentinel values)
  561. max_float = np.finfo(np.float64).max
  562. max_float2 = np.nextafter(max_float, -np.inf)
  563. max_float3 = np.nextafter(max_float2, -np.inf)
  564. A[3, 4, 1] = np.nan
  565. A[4, 5, 2] = np.inf
  566. A[5, 6, 3] = max_float
  567. B[8] = np.nan
  568. B[7] = np.inf
  569. B[6] = max_float2
  570. # convert masked A to array with sentinel value, don't modify B
  571. out_arrays, sentinel = _masked_arrays_2_sentinel_arrays([A, B])
  572. A_out, B_out = out_arrays
  573. # check that good sentinel value was chosen (according to intended logic)
  574. assert (sentinel != max_float) and (sentinel != max_float2)
  575. assert sentinel == max_float3
  576. # check that output arrays are as intended
  577. A_reference = A.data
  578. A_reference[A.mask] = sentinel
  579. np.testing.assert_array_equal(A_out, A_reference)
  580. assert B_out is B
  581. def test_masked_dtype():
  582. # When _masked_arrays_2_sentinel_arrays was first added, it always
  583. # upcast the arrays to np.float64. After gh16662, check expected promotion
  584. # and that the expected sentinel is found.
  585. # these are important because the max of the promoted dtype is the first
  586. # candidate to be the sentinel value
  587. max16 = np.iinfo(np.int16).max
  588. max128c = np.finfo(np.complex128).max
  589. # a is a regular array, b has masked elements, and c has no masked elements
  590. a = np.array([1, 2, max16], dtype=np.int16)
  591. b = np.ma.array([1, 2, 1], dtype=np.int8, mask=[0, 1, 0])
  592. c = np.ma.array([1, 2, 1], dtype=np.complex128, mask=[0, 0, 0])
  593. # check integer masked -> sentinel conversion
  594. out_arrays, sentinel = _masked_arrays_2_sentinel_arrays([a, b])
  595. a_out, b_out = out_arrays
  596. assert sentinel == max16-1 # not max16 because max16 was in the data
  597. assert b_out.dtype == np.int16 # check expected promotion
  598. assert_allclose(b_out, [b[0], sentinel, b[-1]]) # check sentinel placement
  599. assert a_out is a # not a masked array, so left untouched
  600. assert not isinstance(b_out, np.ma.MaskedArray) # b became regular array
  601. # similarly with complex
  602. out_arrays, sentinel = _masked_arrays_2_sentinel_arrays([b, c])
  603. b_out, c_out = out_arrays
  604. assert sentinel == max128c # max128c was not in the data
  605. assert b_out.dtype == np.complex128 # b got promoted
  606. assert_allclose(b_out, [b[0], sentinel, b[-1]]) # check sentinel placement
  607. assert not isinstance(b_out, np.ma.MaskedArray) # b became regular array
  608. assert not isinstance(c_out, np.ma.MaskedArray) # c became regular array
  609. # Also, check edge case when a sentinel value cannot be found in the data
  610. min8, max8 = np.iinfo(np.int8).min, np.iinfo(np.int8).max
  611. a = np.arange(min8, max8+1, dtype=np.int8) # use all possible values
  612. mask1 = np.zeros_like(a, dtype=bool)
  613. mask0 = np.zeros_like(a, dtype=bool)
  614. # a masked value can be used as the sentinel
  615. mask1[1] = True
  616. a1 = np.ma.array(a, mask=mask1)
  617. out_arrays, sentinel = _masked_arrays_2_sentinel_arrays([a1])
  618. assert sentinel == min8+1
  619. # unless it's the smallest possible; skipped for simiplicity (see code)
  620. mask0[0] = True
  621. a0 = np.ma.array(a, mask=mask0)
  622. message = "This function replaces masked elements with sentinel..."
  623. with pytest.raises(ValueError, match=message):
  624. _masked_arrays_2_sentinel_arrays([a0])
  625. # test that dtype is preserved in functions
  626. a = np.ma.array([1, 2, 3], mask=[0, 1, 0], dtype=np.float32)
  627. assert stats.gmean(a).dtype == np.float32
  628. def test_masked_stat_1d():
  629. # basic test of _axis_nan_policy_factory with 1D masked sample
  630. males = [19, 22, 16, 29, 24]
  631. females = [20, 11, 17, 12]
  632. res = stats.mannwhitneyu(males, females)
  633. # same result when extra nan is omitted
  634. females2 = [20, 11, 17, np.nan, 12]
  635. res2 = stats.mannwhitneyu(males, females2, nan_policy='omit')
  636. np.testing.assert_array_equal(res2, res)
  637. # same result when extra element is masked
  638. females3 = [20, 11, 17, 1000, 12]
  639. mask3 = [False, False, False, True, False]
  640. females3 = np.ma.masked_array(females3, mask=mask3)
  641. res3 = stats.mannwhitneyu(males, females3)
  642. np.testing.assert_array_equal(res3, res)
  643. # same result when extra nan is omitted and additional element is masked
  644. females4 = [20, 11, 17, np.nan, 1000, 12]
  645. mask4 = [False, False, False, False, True, False]
  646. females4 = np.ma.masked_array(females4, mask=mask4)
  647. res4 = stats.mannwhitneyu(males, females4, nan_policy='omit')
  648. np.testing.assert_array_equal(res4, res)
  649. # same result when extra elements, including nan, are masked
  650. females5 = [20, 11, 17, np.nan, 1000, 12]
  651. mask5 = [False, False, False, True, True, False]
  652. females5 = np.ma.masked_array(females5, mask=mask5)
  653. res5 = stats.mannwhitneyu(males, females5, nan_policy='propagate')
  654. res6 = stats.mannwhitneyu(males, females5, nan_policy='raise')
  655. np.testing.assert_array_equal(res5, res)
  656. np.testing.assert_array_equal(res6, res)
  657. @pytest.mark.parametrize(("axis"), range(-3, 3))
  658. def test_masked_stat_3d(axis):
  659. # basic test of _axis_nan_policy_factory with 3D masked sample
  660. np.random.seed(0)
  661. a = np.random.rand(3, 4, 5)
  662. b = np.random.rand(4, 5)
  663. c = np.random.rand(4, 1)
  664. mask_a = a < 0.1
  665. mask_c = [False, False, False, True]
  666. a_masked = np.ma.masked_array(a, mask=mask_a)
  667. c_masked = np.ma.masked_array(c, mask=mask_c)
  668. a_nans = a.copy()
  669. a_nans[mask_a] = np.nan
  670. c_nans = c.copy()
  671. c_nans[mask_c] = np.nan
  672. res = stats.kruskal(a_nans, b, c_nans, nan_policy='omit', axis=axis)
  673. res2 = stats.kruskal(a_masked, b, c_masked, axis=axis)
  674. np.testing.assert_array_equal(res, res2)
def test_mixed_mask_nan_1():
    # targeted test of _axis_nan_policy_factory with 2D masked sample:
    # omitting samples with masks and nan_policy='omit' are equivalent
    # also checks paired-sample sentinel value removal
    #
    # Each paired sample is given two random sets of elements to drop
    # (mask_*1 and mask_*2). The reference replaces both sets with nan. The
    # variants mask one set and write nan into the other, or mask both sets.
    # Every result must equal the reference.
    m, n = 3, 20
    axis = -1
    # legacy global seeding: the random draws below must stay in this order
    # for the test data to be reproducible
    np.random.seed(0)
    a = np.random.rand(m, n)
    b = np.random.rand(m, n)
    mask_a1 = np.random.rand(m, n) < 0.2
    mask_a2 = np.random.rand(m, n) < 0.1
    mask_b1 = np.random.rand(m, n) < 0.15
    mask_b2 = np.random.rand(m, n) < 0.15
    # drop every element of the last row of `a`
    mask_a1[2, :] = True
    # reference: every dropped element is nan and removed by 'omit'
    a_nans = a.copy()
    b_nans = b.copy()
    a_nans[mask_a1 | mask_a2] = np.nan
    b_nans[mask_b1 | mask_b2] = np.nan
    # NOTE(review): np.ma.masked_array is called without copy=True, so the
    # masked arrays below presumably share their data with `a`/`b`. If so,
    # the nan assignments also write into `a`/`b`, and therefore into the
    # later variants. Those nans only land on elements that
    # a_masked3/b_masked3 mask, so the comparisons still hold; confirm this
    # aliasing is intended.
    # variant 1: mask set 1, write nan into set 2
    a_masked1 = np.ma.masked_array(a, mask=mask_a1)
    b_masked1 = np.ma.masked_array(b, mask=mask_b1)
    a_masked1[mask_a2] = np.nan
    b_masked1[mask_b2] = np.nan
    # variant 2: mask set 2, write nan into set 1
    a_masked2 = np.ma.masked_array(a, mask=mask_a2)
    b_masked2 = np.ma.masked_array(b, mask=mask_b2)
    a_masked2[mask_a1] = np.nan
    b_masked2[mask_b1] = np.nan
    # variant 3: mask both sets
    a_masked3 = np.ma.masked_array(a, mask=(mask_a1 | mask_a2))
    b_masked3 = np.ma.masked_array(b, mask=(mask_b1 | mask_b2))
    res = stats.wilcoxon(a_nans, b_nans, nan_policy='omit', axis=axis)
    res1 = stats.wilcoxon(a_masked1, b_masked1, nan_policy='omit', axis=axis)
    res2 = stats.wilcoxon(a_masked2, b_masked2, nan_policy='omit', axis=axis)
    # with everything dropped hidden by the mask, 'raise' and 'propagate' are
    # expected to give the same result as the 'omit' reference
    res3 = stats.wilcoxon(a_masked3, b_masked3, nan_policy='raise', axis=axis)
    res4 = stats.wilcoxon(a_masked3, b_masked3,
                          nan_policy='propagate', axis=axis)
    np.testing.assert_array_equal(res1, res)
    np.testing.assert_array_equal(res2, res)
    np.testing.assert_array_equal(res3, res)
    np.testing.assert_array_equal(res4, res)
  713. def test_mixed_mask_nan_2():
  714. # targeted test of _axis_nan_policy_factory with 2D masked sample:
  715. # check for expected interaction between masks and nans
  716. # Cases here are
  717. # [mixed nan/mask, all nans, all masked,
  718. # unmasked nan, masked nan, unmasked non-nan]
  719. a = [[1, np.nan, 2], [np.nan, np.nan, np.nan], [1, 2, 3],
  720. [1, np.nan, 3], [1, np.nan, 3], [1, 2, 3]]
  721. mask = [[1, 0, 1], [0, 0, 0], [1, 1, 1],
  722. [0, 0, 0], [0, 1, 0], [0, 0, 0]]
  723. a_masked = np.ma.masked_array(a, mask=mask)
  724. b = [[4, 5, 6]]
  725. ref1 = stats.ranksums([1, 3], [4, 5, 6])
  726. ref2 = stats.ranksums([1, 2, 3], [4, 5, 6])
  727. # nan_policy = 'omit'
  728. # all elements are removed from first three rows
  729. # middle element is removed from fourth and fifth rows
  730. # no elements removed from last row
  731. res = stats.ranksums(a_masked, b, nan_policy='omit', axis=-1)
  732. stat_ref = [np.nan, np.nan, np.nan,
  733. ref1.statistic, ref1.statistic, ref2.statistic]
  734. p_ref = [np.nan, np.nan, np.nan,
  735. ref1.pvalue, ref1.pvalue, ref2.pvalue]
  736. np.testing.assert_array_equal(res.statistic, stat_ref)
  737. np.testing.assert_array_equal(res.pvalue, p_ref)
  738. # nan_policy = 'propagate'
  739. # nans propagate in first, second, and fourth row
  740. # all elements are removed by mask from third row
  741. # middle element is removed from fifth row
  742. # no elements removed from last row
  743. res = stats.ranksums(a_masked, b, nan_policy='propagate', axis=-1)
  744. stat_ref = [np.nan, np.nan, np.nan,
  745. np.nan, ref1.statistic, ref2.statistic]
  746. p_ref = [np.nan, np.nan, np.nan,
  747. np.nan, ref1.pvalue, ref2.pvalue]
  748. np.testing.assert_array_equal(res.statistic, stat_ref)
  749. np.testing.assert_array_equal(res.pvalue, p_ref)
  750. def test_axis_None_vs_tuple():
  751. # `axis` `None` should be equivalent to tuple with all axes
  752. shape = (3, 8, 9, 10)
  753. rng = np.random.default_rng(0)
  754. x = rng.random(shape)
  755. res = stats.kruskal(*x, axis=None)
  756. res2 = stats.kruskal(*x, axis=(0, 1, 2))
  757. np.testing.assert_array_equal(res, res2)
  758. def test_axis_None_vs_tuple_with_broadcasting():
  759. # `axis` `None` should be equivalent to tuple with all axes,
  760. # which should be equivalent to raveling the arrays before passing them
  761. rng = np.random.default_rng(0)
  762. x = rng.random((5, 1))
  763. y = rng.random((1, 5))
  764. x2, y2 = np.broadcast_arrays(x, y)
  765. res0 = stats.mannwhitneyu(x.ravel(), y.ravel())
  766. res1 = stats.mannwhitneyu(x, y, axis=None)
  767. res2 = stats.mannwhitneyu(x, y, axis=(0, 1))
  768. res3 = stats.mannwhitneyu(x2.ravel(), y2.ravel())
  769. assert res1 == res0
  770. assert res2 == res0
  771. assert res3 != res0
  772. @pytest.mark.parametrize(("axis"),
  773. list(permutations(range(-3, 3), 2)) + [(-4, 1)])
  774. def test_other_axis_tuples(axis):
  775. # Check that _axis_nan_policy_factory treates all `axis` tuples as expected
  776. rng = np.random.default_rng(0)
  777. shape_x = (4, 5, 6)
  778. shape_y = (1, 6)
  779. x = rng.random(shape_x)
  780. y = rng.random(shape_y)
  781. axis_original = axis
  782. # convert axis elements to positive
  783. axis = tuple([(i if i >= 0 else 3 + i) for i in axis])
  784. axis = sorted(axis)
  785. if len(set(axis)) != len(axis):
  786. message = "`axis` must contain only distinct elements"
  787. with pytest.raises(np.AxisError, match=re.escape(message)):
  788. stats.mannwhitneyu(x, y, axis=axis_original)
  789. return
  790. if axis[0] < 0 or axis[-1] > 2:
  791. message = "`axis` is out of bounds for array of dimension 3"
  792. with pytest.raises(np.AxisError, match=re.escape(message)):
  793. stats.mannwhitneyu(x, y, axis=axis_original)
  794. return
  795. res = stats.mannwhitneyu(x, y, axis=axis_original)
  796. # reference behavior
  797. not_axis = {0, 1, 2} - set(axis) # which axis is not part of `axis`
  798. not_axis = next(iter(not_axis)) # take it out of the set
  799. x2 = x
  800. shape_y_broadcasted = [1, 1, 6]
  801. shape_y_broadcasted[not_axis] = shape_x[not_axis]
  802. y2 = np.broadcast_to(y, shape_y_broadcasted)
  803. m = x2.shape[not_axis]
  804. x2 = np.moveaxis(x2, axis, (1, 2))
  805. y2 = np.moveaxis(y2, axis, (1, 2))
  806. x2 = np.reshape(x2, (m, -1))
  807. y2 = np.reshape(y2, (m, -1))
  808. res2 = stats.mannwhitneyu(x2, y2, axis=1)
  809. np.testing.assert_array_equal(res, res2)
@pytest.mark.parametrize(("weighted_fun_name"), ["gmean", "hmean", "pmean"])
def test_mean_mixed_mask_nan_weights(weighted_fun_name):
    # targeted test of _axis_nan_policy_factory with 2D masked sample:
    # omitting samples with masks and nan_policy='omit' are equivalent
    # also checks paired-sample sentinel value removal
    #
    # The data `a` and the `weights` `b` are the paired samples. Each has two
    # random sets of elements to drop (mask_*1 and mask_*2). The reference
    # replaces both sets with nan. The variants mask one set and write nan
    # into the other, or mask both sets. All results must equal the reference.
    if weighted_fun_name == 'pmean':
        # pmean also needs the power `p`; wrap it so that all three functions
        # can be called the same way below
        def weighted_fun(a, **kwargs):
            return stats.pmean(a, p=0.42, **kwargs)
    else:
        weighted_fun = getattr(stats, weighted_fun_name)
    m, n = 3, 20
    axis = -1
    # the random draws below must stay in this order for the test data to be
    # reproducible
    rng = np.random.default_rng(6541968121)
    a = rng.uniform(size=(m, n))
    b = rng.uniform(size=(m, n))
    mask_a1 = rng.uniform(size=(m, n)) < 0.2
    mask_a2 = rng.uniform(size=(m, n)) < 0.1
    mask_b1 = rng.uniform(size=(m, n)) < 0.15
    mask_b2 = rng.uniform(size=(m, n)) < 0.15
    # drop every element of the last row of `a`
    mask_a1[2, :] = True
    # reference: every dropped element is nan and removed by 'omit'
    a_nans = a.copy()
    b_nans = b.copy()
    a_nans[mask_a1 | mask_a2] = np.nan
    b_nans[mask_b1 | mask_b2] = np.nan
    # NOTE(review): np.ma.masked_array is called without copy=True, so the
    # masked arrays below presumably share their data with `a`/`b`. If so,
    # the nan assignments also write into `a`/`b`, and therefore into the
    # later variants. Those nans only land on elements that the later
    # variants mask; confirm this aliasing is intended.
    # variant 1: mask set 1, write nan into set 2
    a_masked1 = np.ma.masked_array(a, mask=mask_a1)
    b_masked1 = np.ma.masked_array(b, mask=mask_b1)
    a_masked1[mask_a2] = np.nan
    b_masked1[mask_b2] = np.nan
    # variant 2: mask set 2, write nan into set 1
    a_masked2 = np.ma.masked_array(a, mask=mask_a2)
    b_masked2 = np.ma.masked_array(b, mask=mask_b2)
    a_masked2[mask_a1] = np.nan
    b_masked2[mask_b1] = np.nan
    # variant 3: mask both sets of each sample separately
    a_masked3 = np.ma.masked_array(a, mask=(mask_a1 | mask_a2))
    b_masked3 = np.ma.masked_array(b, mask=(mask_b1 | mask_b2))
    # variant 4: one shared mask for data and weights; only used for the
    # `stats.mstats` comparison below
    mask_all = (mask_a1 | mask_a2 | mask_b1 | mask_b2)
    a_masked4 = np.ma.masked_array(a, mask=mask_all)
    b_masked4 = np.ma.masked_array(b, mask=mask_all)
    with np.testing.suppress_warnings() as sup:
        # silence RuntimeWarnings matching 'invalid value encountered';
        # presumably raised for the last row of `a`, where everything is
        # dropped
        message = 'invalid value encountered'
        sup.filter(RuntimeWarning, message)
        res = weighted_fun(a_nans, weights=b_nans,
                           nan_policy='omit', axis=axis)
        res1 = weighted_fun(a_masked1, weights=b_masked1,
                            nan_policy='omit', axis=axis)
        res2 = weighted_fun(a_masked2, weights=b_masked2,
                            nan_policy='omit', axis=axis)
        # with everything dropped hidden by the mask, 'raise' and 'propagate'
        # are expected to match the 'omit' reference
        res3 = weighted_fun(a_masked3, weights=b_masked3,
                            nan_policy='raise', axis=axis)
        res4 = weighted_fun(a_masked3, weights=b_masked3,
                            nan_policy='propagate', axis=axis)
        # Would test with a_masked3/b_masked3, but there is a bug in np.average
        # that causes a bug in _no_deco mean with masked weights. Would use
        # np.ma.average, but that causes other problems. See numpy/numpy#7330.
        # Given the parametrization, this branch only runs for 'hmean'.
        if weighted_fun_name not in {'pmean', 'gmean'}:
            weighted_fun_ma = getattr(stats.mstats, weighted_fun_name)
            res5 = weighted_fun_ma(a_masked4, weights=b_masked4,
                                   axis=axis, _no_deco=True)

    np.testing.assert_array_equal(res1, res)
    np.testing.assert_array_equal(res2, res)
    np.testing.assert_array_equal(res3, res)
    np.testing.assert_array_equal(res4, res)
    if weighted_fun_name not in {'pmean', 'gmean'}:
        # _no_deco mean returns masked array, last element was masked;
        # compare only against the non-nan entries of the reference
        np.testing.assert_allclose(res5.compressed(), res[~np.isnan(res)])