test__util.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380
  1. from multiprocessing import Pool
  2. from multiprocessing.pool import Pool as PWL
  3. import os
  4. import re
  5. import math
  6. from fractions import Fraction
  7. import numpy as np
  8. from numpy.testing import assert_equal, assert_
  9. import pytest
  10. from pytest import raises as assert_raises, deprecated_call
  11. import scipy
  12. from scipy._lib._util import (_aligned_zeros, check_random_state, MapWrapper,
  13. getfullargspec_no_self, FullArgSpec,
  14. rng_integers, _validate_int, _rename_parameter,
  15. _contains_nan)
  16. def test__aligned_zeros():
  17. niter = 10
  18. def check(shape, dtype, order, align):
  19. err_msg = repr((shape, dtype, order, align))
  20. x = _aligned_zeros(shape, dtype, order, align=align)
  21. if align is None:
  22. align = np.dtype(dtype).alignment
  23. assert_equal(x.__array_interface__['data'][0] % align, 0)
  24. if hasattr(shape, '__len__'):
  25. assert_equal(x.shape, shape, err_msg)
  26. else:
  27. assert_equal(x.shape, (shape,), err_msg)
  28. assert_equal(x.dtype, dtype)
  29. if order == "C":
  30. assert_(x.flags.c_contiguous, err_msg)
  31. elif order == "F":
  32. if x.size > 0:
  33. # Size-0 arrays get invalid flags on NumPy 1.5
  34. assert_(x.flags.f_contiguous, err_msg)
  35. elif order is None:
  36. assert_(x.flags.c_contiguous, err_msg)
  37. else:
  38. raise ValueError()
  39. # try various alignments
  40. for align in [1, 2, 3, 4, 8, 16, 32, 64, None]:
  41. for n in [0, 1, 3, 11]:
  42. for order in ["C", "F", None]:
  43. for dtype in [np.uint8, np.float64]:
  44. for shape in [n, (1, 2, 3, n)]:
  45. for j in range(niter):
  46. check(shape, dtype, order, align)
  47. def test_check_random_state():
  48. # If seed is None, return the RandomState singleton used by np.random.
  49. # If seed is an int, return a new RandomState instance seeded with seed.
  50. # If seed is already a RandomState instance, return it.
  51. # Otherwise raise ValueError.
  52. rsi = check_random_state(1)
  53. assert_equal(type(rsi), np.random.RandomState)
  54. rsi = check_random_state(rsi)
  55. assert_equal(type(rsi), np.random.RandomState)
  56. rsi = check_random_state(None)
  57. assert_equal(type(rsi), np.random.RandomState)
  58. assert_raises(ValueError, check_random_state, 'a')
  59. if hasattr(np.random, 'Generator'):
  60. # np.random.Generator is only available in NumPy >= 1.17
  61. rg = np.random.Generator(np.random.PCG64())
  62. rsi = check_random_state(rg)
  63. assert_equal(type(rsi), np.random.Generator)
  64. def test_getfullargspec_no_self():
  65. p = MapWrapper(1)
  66. argspec = getfullargspec_no_self(p.__init__)
  67. assert_equal(argspec, FullArgSpec(['pool'], None, None, (1,), [],
  68. None, {}))
  69. argspec = getfullargspec_no_self(p.__call__)
  70. assert_equal(argspec, FullArgSpec(['func', 'iterable'], None, None, None,
  71. [], None, {}))
  72. class _rv_generic:
  73. def _rvs(self, a, b=2, c=3, *args, size=None, **kwargs):
  74. return None
  75. rv_obj = _rv_generic()
  76. argspec = getfullargspec_no_self(rv_obj._rvs)
  77. assert_equal(argspec, FullArgSpec(['a', 'b', 'c'], 'args', 'kwargs',
  78. (2, 3), ['size'], {'size': None}, {}))
  79. def test_mapwrapper_serial():
  80. in_arg = np.arange(10.)
  81. out_arg = np.sin(in_arg)
  82. p = MapWrapper(1)
  83. assert_(p._mapfunc is map)
  84. assert_(p.pool is None)
  85. assert_(p._own_pool is False)
  86. out = list(p(np.sin, in_arg))
  87. assert_equal(out, out_arg)
  88. with assert_raises(RuntimeError):
  89. p = MapWrapper(0)
  90. def test_pool():
  91. with Pool(2) as p:
  92. p.map(math.sin, [1, 2, 3, 4])
  93. def test_mapwrapper_parallel():
  94. in_arg = np.arange(10.)
  95. out_arg = np.sin(in_arg)
  96. with MapWrapper(2) as p:
  97. out = p(np.sin, in_arg)
  98. assert_equal(list(out), out_arg)
  99. assert_(p._own_pool is True)
  100. assert_(isinstance(p.pool, PWL))
  101. assert_(p._mapfunc is not None)
  102. # the context manager should've closed the internal pool
  103. # check that it has by asking it to calculate again.
  104. with assert_raises(Exception) as excinfo:
  105. p(np.sin, in_arg)
  106. assert_(excinfo.type is ValueError)
  107. # can also set a PoolWrapper up with a map-like callable instance
  108. with Pool(2) as p:
  109. q = MapWrapper(p.map)
  110. assert_(q._own_pool is False)
  111. q.close()
  112. # closing the PoolWrapper shouldn't close the internal pool
  113. # because it didn't create it
  114. out = p.map(np.sin, in_arg)
  115. assert_equal(list(out), out_arg)
  116. # get our custom ones and a few from the "import *" cases
  117. @pytest.mark.parametrize(
  118. 'key', ('ifft', 'diag', 'arccos', 'randn', 'rand', 'array'))
  119. def test_numpy_deprecation(key):
  120. """Test that 'from numpy import *' functions are deprecated."""
  121. if key in ('ifft', 'diag', 'arccos'):
  122. arg = [1.0, 0.]
  123. elif key == 'finfo':
  124. arg = float
  125. else:
  126. arg = 2
  127. func = getattr(scipy, key)
  128. match = r'scipy\.%s is deprecated.*2\.0\.0' % key
  129. with deprecated_call(match=match) as dep:
  130. func(arg) # deprecated
  131. # in case we catch more than one dep warning
  132. fnames = [os.path.splitext(d.filename)[0] for d in dep.list]
  133. basenames = [os.path.basename(fname) for fname in fnames]
  134. assert 'test__util' in basenames
  135. if key in ('rand', 'randn'):
  136. root = np.random
  137. elif key == 'ifft':
  138. root = np.fft
  139. else:
  140. root = np
  141. func_np = getattr(root, key)
  142. func_np(arg) # not deprecated
  143. assert func_np is not func
  144. # classes should remain classes
  145. if isinstance(func_np, type):
  146. assert isinstance(func, type)
  147. def test_numpy_deprecation_functionality():
  148. # Check that the deprecation wrappers don't break basic NumPy
  149. # functionality
  150. with deprecated_call():
  151. x = scipy.array([1, 2, 3], dtype=scipy.float64)
  152. assert x.dtype == scipy.float64
  153. assert x.dtype == np.float64
  154. x = scipy.finfo(scipy.float32)
  155. assert x.eps == np.finfo(np.float32).eps
  156. assert scipy.float64 == np.float64
  157. assert issubclass(np.float64, scipy.float64)
  158. def test_rng_integers():
  159. rng = np.random.RandomState()
  160. # test that numbers are inclusive of high point
  161. arr = rng_integers(rng, low=2, high=5, size=100, endpoint=True)
  162. assert np.max(arr) == 5
  163. assert np.min(arr) == 2
  164. assert arr.shape == (100, )
  165. # test that numbers are inclusive of high point
  166. arr = rng_integers(rng, low=5, size=100, endpoint=True)
  167. assert np.max(arr) == 5
  168. assert np.min(arr) == 0
  169. assert arr.shape == (100, )
  170. # test that numbers are exclusive of high point
  171. arr = rng_integers(rng, low=2, high=5, size=100, endpoint=False)
  172. assert np.max(arr) == 4
  173. assert np.min(arr) == 2
  174. assert arr.shape == (100, )
  175. # test that numbers are exclusive of high point
  176. arr = rng_integers(rng, low=5, size=100, endpoint=False)
  177. assert np.max(arr) == 4
  178. assert np.min(arr) == 0
  179. assert arr.shape == (100, )
  180. # now try with np.random.Generator
  181. try:
  182. rng = np.random.default_rng()
  183. except AttributeError:
  184. return
  185. # test that numbers are inclusive of high point
  186. arr = rng_integers(rng, low=2, high=5, size=100, endpoint=True)
  187. assert np.max(arr) == 5
  188. assert np.min(arr) == 2
  189. assert arr.shape == (100, )
  190. # test that numbers are inclusive of high point
  191. arr = rng_integers(rng, low=5, size=100, endpoint=True)
  192. assert np.max(arr) == 5
  193. assert np.min(arr) == 0
  194. assert arr.shape == (100, )
  195. # test that numbers are exclusive of high point
  196. arr = rng_integers(rng, low=2, high=5, size=100, endpoint=False)
  197. assert np.max(arr) == 4
  198. assert np.min(arr) == 2
  199. assert arr.shape == (100, )
  200. # test that numbers are exclusive of high point
  201. arr = rng_integers(rng, low=5, size=100, endpoint=False)
  202. assert np.max(arr) == 4
  203. assert np.min(arr) == 0
  204. assert arr.shape == (100, )
  205. class TestValidateInt:
  206. @pytest.mark.parametrize('n', [4, np.uint8(4), np.int16(4), np.array(4)])
  207. def test_validate_int(self, n):
  208. n = _validate_int(n, 'n')
  209. assert n == 4
  210. @pytest.mark.parametrize('n', [4.0, np.array([4]), Fraction(4, 1)])
  211. def test_validate_int_bad(self, n):
  212. with pytest.raises(TypeError, match='n must be an integer'):
  213. _validate_int(n, 'n')
  214. def test_validate_int_below_min(self):
  215. with pytest.raises(ValueError, match='n must be an integer not '
  216. 'less than 0'):
  217. _validate_int(-1, 'n', 0)
  218. class TestRenameParameter:
  219. # check that wrapper `_rename_parameter` for backward-compatible
  220. # keyword renaming works correctly
  221. # Example method/function that still accepts keyword `old`
  222. @_rename_parameter("old", "new")
  223. def old_keyword_still_accepted(self, new):
  224. return new
  225. # Example method/function for which keyword `old` is deprecated
  226. @_rename_parameter("old", "new", dep_version="1.9.0")
  227. def old_keyword_deprecated(self, new):
  228. return new
  229. def test_old_keyword_still_accepted(self):
  230. # positional argument and both keyword work identically
  231. res1 = self.old_keyword_still_accepted(10)
  232. res2 = self.old_keyword_still_accepted(new=10)
  233. res3 = self.old_keyword_still_accepted(old=10)
  234. assert res1 == res2 == res3 == 10
  235. # unexpected keyword raises an error
  236. message = re.escape("old_keyword_still_accepted() got an unexpected")
  237. with pytest.raises(TypeError, match=message):
  238. self.old_keyword_still_accepted(unexpected=10)
  239. # multiple values for the same parameter raises an error
  240. message = re.escape("old_keyword_still_accepted() got multiple")
  241. with pytest.raises(TypeError, match=message):
  242. self.old_keyword_still_accepted(10, new=10)
  243. with pytest.raises(TypeError, match=message):
  244. self.old_keyword_still_accepted(10, old=10)
  245. with pytest.raises(TypeError, match=message):
  246. self.old_keyword_still_accepted(new=10, old=10)
  247. def test_old_keyword_deprecated(self):
  248. # positional argument and both keyword work identically,
  249. # but use of old keyword results in DeprecationWarning
  250. dep_msg = "Use of keyword argument `old` is deprecated"
  251. res1 = self.old_keyword_deprecated(10)
  252. res2 = self.old_keyword_deprecated(new=10)
  253. with pytest.warns(DeprecationWarning, match=dep_msg):
  254. res3 = self.old_keyword_deprecated(old=10)
  255. assert res1 == res2 == res3 == 10
  256. # unexpected keyword raises an error
  257. message = re.escape("old_keyword_deprecated() got an unexpected")
  258. with pytest.raises(TypeError, match=message):
  259. self.old_keyword_deprecated(unexpected=10)
  260. # multiple values for the same parameter raises an error and,
  261. # if old keyword is used, results in DeprecationWarning
  262. message = re.escape("old_keyword_deprecated() got multiple")
  263. with pytest.raises(TypeError, match=message):
  264. self.old_keyword_deprecated(10, new=10)
  265. with pytest.raises(TypeError, match=message), \
  266. pytest.warns(DeprecationWarning, match=dep_msg):
  267. self.old_keyword_deprecated(10, old=10)
  268. with pytest.raises(TypeError, match=message), \
  269. pytest.warns(DeprecationWarning, match=dep_msg):
  270. self.old_keyword_deprecated(new=10, old=10)
  271. class TestContainsNaNTest:
  272. def test_policy(self):
  273. data = np.array([1, 2, 3, np.nan])
  274. contains_nan, nan_policy = _contains_nan(data, nan_policy="propagate")
  275. assert contains_nan
  276. assert nan_policy == "propagate"
  277. contains_nan, nan_policy = _contains_nan(data, nan_policy="omit")
  278. assert contains_nan
  279. assert nan_policy == "omit"
  280. msg = "The input contains nan values"
  281. with pytest.raises(ValueError, match=msg):
  282. _contains_nan(data, nan_policy="raise")
  283. msg = "nan_policy must be one of"
  284. with pytest.raises(ValueError, match=msg):
  285. _contains_nan(data, nan_policy="nan")
  286. def test_contains_nan_1d(self):
  287. data1 = np.array([1, 2, 3])
  288. assert not _contains_nan(data1)[0]
  289. data2 = np.array([1, 2, 3, np.nan])
  290. assert _contains_nan(data2)[0]
  291. data3 = np.array([np.nan, 2, 3, np.nan])
  292. assert _contains_nan(data3)[0]
  293. data4 = np.array([1, 2, "3", np.nan]) # converted to string "nan"
  294. assert not _contains_nan(data4)[0]
  295. data5 = np.array([1, 2, "3", np.nan], dtype='object')
  296. assert _contains_nan(data5)[0]
  297. def test_contains_nan_2d(self):
  298. data1 = np.array([[1, 2], [3, 4]])
  299. assert not _contains_nan(data1)[0]
  300. data2 = np.array([[1, 2], [3, np.nan]])
  301. assert _contains_nan(data2)[0]
  302. data3 = np.array([["1", 2], [3, np.nan]]) # converted to string "nan"
  303. assert not _contains_nan(data3)[0]
  304. data4 = np.array([["1", 2], [3, np.nan]], dtype='object')
  305. assert _contains_nan(data4)[0]