_util.py 24 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711
  1. from contextlib import contextmanager
  2. import functools
  3. import operator
  4. import warnings
  5. import numbers
  6. from collections import namedtuple
  7. import inspect
  8. import math
  9. from typing import (
  10. Optional,
  11. Union,
  12. TYPE_CHECKING,
  13. TypeVar,
  14. )
  15. import numpy as np
  16. IntNumber = Union[int, np.integer]
  17. DecimalNumber = Union[float, np.floating, np.integer]
  18. # Since Generator was introduced in numpy 1.17, the following condition is needed for
  19. # backward compatibility
  20. if TYPE_CHECKING:
  21. SeedType = Optional[Union[IntNumber, np.random.Generator,
  22. np.random.RandomState]]
  23. GeneratorType = TypeVar("GeneratorType", bound=Union[np.random.Generator,
  24. np.random.RandomState])
  25. try:
  26. from numpy.random import Generator as Generator
  27. except ImportError:
  28. class Generator(): # type: ignore[no-redef]
  29. pass
  30. def _lazywhere(cond, arrays, f, fillvalue=None, f2=None):
  31. """
  32. np.where(cond, x, fillvalue) always evaluates x even where cond is False.
  33. This one only evaluates f(arr1[cond], arr2[cond], ...).
  34. Examples
  35. --------
  36. >>> import numpy as np
  37. >>> a, b = np.array([1, 2, 3, 4]), np.array([5, 6, 7, 8])
  38. >>> def f(a, b):
  39. ... return a*b
  40. >>> _lazywhere(a > 2, (a, b), f, np.nan)
  41. array([ nan, nan, 21., 32.])
  42. Notice, it assumes that all `arrays` are of the same shape, or can be
  43. broadcasted together.
  44. """
  45. cond = np.asarray(cond)
  46. if fillvalue is None:
  47. if f2 is None:
  48. raise ValueError("One of (fillvalue, f2) must be given.")
  49. else:
  50. fillvalue = np.nan
  51. else:
  52. if f2 is not None:
  53. raise ValueError("Only one of (fillvalue, f2) can be given.")
  54. args = np.broadcast_arrays(cond, *arrays)
  55. cond, arrays = args[0], args[1:]
  56. temp = tuple(np.extract(cond, arr) for arr in arrays)
  57. tcode = np.mintypecode([a.dtype.char for a in arrays])
  58. out = np.full(np.shape(arrays[0]), fill_value=fillvalue, dtype=tcode)
  59. np.place(out, cond, f(*temp))
  60. if f2 is not None:
  61. temp = tuple(np.extract(~cond, arr) for arr in arrays)
  62. np.place(out, ~cond, f2(*temp))
  63. return out
  64. def _lazyselect(condlist, choicelist, arrays, default=0):
  65. """
  66. Mimic `np.select(condlist, choicelist)`.
  67. Notice, it assumes that all `arrays` are of the same shape or can be
  68. broadcasted together.
  69. All functions in `choicelist` must accept array arguments in the order
  70. given in `arrays` and must return an array of the same shape as broadcasted
  71. `arrays`.
  72. Examples
  73. --------
  74. >>> import numpy as np
  75. >>> x = np.arange(6)
  76. >>> np.select([x <3, x > 3], [x**2, x**3], default=0)
  77. array([ 0, 1, 4, 0, 64, 125])
  78. >>> _lazyselect([x < 3, x > 3], [lambda x: x**2, lambda x: x**3], (x,))
  79. array([ 0., 1., 4., 0., 64., 125.])
  80. >>> a = -np.ones_like(x)
  81. >>> _lazyselect([x < 3, x > 3],
  82. ... [lambda x, a: x**2, lambda x, a: a * x**3],
  83. ... (x, a), default=np.nan)
  84. array([ 0., 1., 4., nan, -64., -125.])
  85. """
  86. arrays = np.broadcast_arrays(*arrays)
  87. tcode = np.mintypecode([a.dtype.char for a in arrays])
  88. out = np.full(np.shape(arrays[0]), fill_value=default, dtype=tcode)
  89. for func, cond in zip(choicelist, condlist):
  90. if np.all(cond is False):
  91. continue
  92. cond, _ = np.broadcast_arrays(cond, arrays[0])
  93. temp = tuple(np.extract(cond, arr) for arr in arrays)
  94. np.place(out, cond, func(*temp))
  95. return out
  96. def _aligned_zeros(shape, dtype=float, order="C", align=None):
  97. """Allocate a new ndarray with aligned memory.
  98. Primary use case for this currently is working around a f2py issue
  99. in NumPy 1.9.1, where dtype.alignment is such that np.zeros() does
  100. not necessarily create arrays aligned up to it.
  101. """
  102. dtype = np.dtype(dtype)
  103. if align is None:
  104. align = dtype.alignment
  105. if not hasattr(shape, '__len__'):
  106. shape = (shape,)
  107. size = functools.reduce(operator.mul, shape) * dtype.itemsize
  108. buf = np.empty(size + align + 1, np.uint8)
  109. offset = buf.__array_interface__['data'][0] % align
  110. if offset != 0:
  111. offset = align - offset
  112. # Note: slices producing 0-size arrays do not necessarily change
  113. # data pointer --- so we use and allocate size+1
  114. buf = buf[offset:offset+size+1][:-1]
  115. data = np.ndarray(shape, dtype, buf, order=order)
  116. data.fill(0)
  117. return data
  118. def _prune_array(array):
  119. """Return an array equivalent to the input array. If the input
  120. array is a view of a much larger array, copy its contents to a
  121. newly allocated array. Otherwise, return the input unchanged.
  122. """
  123. if array.base is not None and array.size < array.base.size // 2:
  124. return array.copy()
  125. return array
  126. def prod(iterable):
  127. """
  128. Product of a sequence of numbers.
  129. Faster than np.prod for short lists like array shapes, and does
  130. not overflow if using Python integers.
  131. """
  132. product = 1
  133. for x in iterable:
  134. product *= x
  135. return product
  136. def float_factorial(n: int) -> float:
  137. """Compute the factorial and return as a float
  138. Returns infinity when result is too large for a double
  139. """
  140. return float(math.factorial(n)) if n < 171 else np.inf
  141. # copy-pasted from scikit-learn utils/validation.py
  142. # change this to scipy.stats._qmc.check_random_state once numpy 1.16 is dropped
  143. def check_random_state(seed):
  144. """Turn `seed` into a `np.random.RandomState` instance.
  145. Parameters
  146. ----------
  147. seed : {None, int, `numpy.random.Generator`, `numpy.random.RandomState`}, optional
  148. If `seed` is None (or `np.random`), the `numpy.random.RandomState`
  149. singleton is used.
  150. If `seed` is an int, a new ``RandomState`` instance is used,
  151. seeded with `seed`.
  152. If `seed` is already a ``Generator`` or ``RandomState`` instance then
  153. that instance is used.
  154. Returns
  155. -------
  156. seed : {`numpy.random.Generator`, `numpy.random.RandomState`}
  157. Random number generator.
  158. """
  159. if seed is None or seed is np.random:
  160. return np.random.mtrand._rand
  161. if isinstance(seed, (numbers.Integral, np.integer)):
  162. return np.random.RandomState(seed)
  163. if isinstance(seed, (np.random.RandomState, np.random.Generator)):
  164. return seed
  165. raise ValueError('%r cannot be used to seed a numpy.random.RandomState'
  166. ' instance' % seed)
  167. def _asarray_validated(a, check_finite=True,
  168. sparse_ok=False, objects_ok=False, mask_ok=False,
  169. as_inexact=False):
  170. """
  171. Helper function for SciPy argument validation.
  172. Many SciPy linear algebra functions do support arbitrary array-like
  173. input arguments. Examples of commonly unsupported inputs include
  174. matrices containing inf/nan, sparse matrix representations, and
  175. matrices with complicated elements.
  176. Parameters
  177. ----------
  178. a : array_like
  179. The array-like input.
  180. check_finite : bool, optional
  181. Whether to check that the input matrices contain only finite numbers.
  182. Disabling may give a performance gain, but may result in problems
  183. (crashes, non-termination) if the inputs do contain infinities or NaNs.
  184. Default: True
  185. sparse_ok : bool, optional
  186. True if scipy sparse matrices are allowed.
  187. objects_ok : bool, optional
  188. True if arrays with dype('O') are allowed.
  189. mask_ok : bool, optional
  190. True if masked arrays are allowed.
  191. as_inexact : bool, optional
  192. True to convert the input array to a np.inexact dtype.
  193. Returns
  194. -------
  195. ret : ndarray
  196. The converted validated array.
  197. """
  198. if not sparse_ok:
  199. import scipy.sparse
  200. if scipy.sparse.issparse(a):
  201. msg = ('Sparse matrices are not supported by this function. '
  202. 'Perhaps one of the scipy.sparse.linalg functions '
  203. 'would work instead.')
  204. raise ValueError(msg)
  205. if not mask_ok:
  206. if np.ma.isMaskedArray(a):
  207. raise ValueError('masked arrays are not supported')
  208. toarray = np.asarray_chkfinite if check_finite else np.asarray
  209. a = toarray(a)
  210. if not objects_ok:
  211. if a.dtype is np.dtype('O'):
  212. raise ValueError('object arrays are not supported')
  213. if as_inexact:
  214. if not np.issubdtype(a.dtype, np.inexact):
  215. a = toarray(a, dtype=np.float_)
  216. return a
  217. def _validate_int(k, name, minimum=None):
  218. """
  219. Validate a scalar integer.
  220. This functon can be used to validate an argument to a function
  221. that expects the value to be an integer. It uses `operator.index`
  222. to validate the value (so, for example, k=2.0 results in a
  223. TypeError).
  224. Parameters
  225. ----------
  226. k : int
  227. The value to be validated.
  228. name : str
  229. The name of the parameter.
  230. minimum : int, optional
  231. An optional lower bound.
  232. """
  233. try:
  234. k = operator.index(k)
  235. except TypeError:
  236. raise TypeError(f'{name} must be an integer.') from None
  237. if minimum is not None and k < minimum:
  238. raise ValueError(f'{name} must be an integer not less '
  239. f'than {minimum}') from None
  240. return k
  241. # Add a replacement for inspect.getfullargspec()/
  242. # The version below is borrowed from Django,
  243. # https://github.com/django/django/pull/4846.
  244. # Note an inconsistency between inspect.getfullargspec(func) and
  245. # inspect.signature(func). If `func` is a bound method, the latter does *not*
  246. # list `self` as a first argument, while the former *does*.
  247. # Hence, cook up a common ground replacement: `getfullargspec_no_self` which
  248. # mimics `inspect.getfullargspec` but does not list `self`.
  249. #
  250. # This way, the caller code does not need to know whether it uses a legacy
  251. # .getfullargspec or a bright and shiny .signature.
  252. FullArgSpec = namedtuple('FullArgSpec',
  253. ['args', 'varargs', 'varkw', 'defaults',
  254. 'kwonlyargs', 'kwonlydefaults', 'annotations'])
  255. def getfullargspec_no_self(func):
  256. """inspect.getfullargspec replacement using inspect.signature.
  257. If func is a bound method, do not list the 'self' parameter.
  258. Parameters
  259. ----------
  260. func : callable
  261. A callable to inspect
  262. Returns
  263. -------
  264. fullargspec : FullArgSpec(args, varargs, varkw, defaults, kwonlyargs,
  265. kwonlydefaults, annotations)
  266. NOTE: if the first argument of `func` is self, it is *not*, I repeat
  267. *not*, included in fullargspec.args.
  268. This is done for consistency between inspect.getargspec() under
  269. Python 2.x, and inspect.signature() under Python 3.x.
  270. """
  271. sig = inspect.signature(func)
  272. args = [
  273. p.name for p in sig.parameters.values()
  274. if p.kind in [inspect.Parameter.POSITIONAL_OR_KEYWORD,
  275. inspect.Parameter.POSITIONAL_ONLY]
  276. ]
  277. varargs = [
  278. p.name for p in sig.parameters.values()
  279. if p.kind == inspect.Parameter.VAR_POSITIONAL
  280. ]
  281. varargs = varargs[0] if varargs else None
  282. varkw = [
  283. p.name for p in sig.parameters.values()
  284. if p.kind == inspect.Parameter.VAR_KEYWORD
  285. ]
  286. varkw = varkw[0] if varkw else None
  287. defaults = tuple(
  288. p.default for p in sig.parameters.values()
  289. if (p.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD and
  290. p.default is not p.empty)
  291. ) or None
  292. kwonlyargs = [
  293. p.name for p in sig.parameters.values()
  294. if p.kind == inspect.Parameter.KEYWORD_ONLY
  295. ]
  296. kwdefaults = {p.name: p.default for p in sig.parameters.values()
  297. if p.kind == inspect.Parameter.KEYWORD_ONLY and
  298. p.default is not p.empty}
  299. annotations = {p.name: p.annotation for p in sig.parameters.values()
  300. if p.annotation is not p.empty}
  301. return FullArgSpec(args, varargs, varkw, defaults, kwonlyargs,
  302. kwdefaults or None, annotations)
  303. class _FunctionWrapper:
  304. """
  305. Object to wrap user's function, allowing picklability
  306. """
  307. def __init__(self, f, args):
  308. self.f = f
  309. self.args = [] if args is None else args
  310. def __call__(self, x):
  311. return self.f(x, *self.args)
  312. class MapWrapper:
  313. """
  314. Parallelisation wrapper for working with map-like callables, such as
  315. `multiprocessing.Pool.map`.
  316. Parameters
  317. ----------
  318. pool : int or map-like callable
  319. If `pool` is an integer, then it specifies the number of threads to
  320. use for parallelization. If ``int(pool) == 1``, then no parallel
  321. processing is used and the map builtin is used.
  322. If ``pool == -1``, then the pool will utilize all available CPUs.
  323. If `pool` is a map-like callable that follows the same
  324. calling sequence as the built-in map function, then this callable is
  325. used for parallelization.
  326. """
  327. def __init__(self, pool=1):
  328. self.pool = None
  329. self._mapfunc = map
  330. self._own_pool = False
  331. if callable(pool):
  332. self.pool = pool
  333. self._mapfunc = self.pool
  334. else:
  335. from multiprocessing import Pool
  336. # user supplies a number
  337. if int(pool) == -1:
  338. # use as many processors as possible
  339. self.pool = Pool()
  340. self._mapfunc = self.pool.map
  341. self._own_pool = True
  342. elif int(pool) == 1:
  343. pass
  344. elif int(pool) > 1:
  345. # use the number of processors requested
  346. self.pool = Pool(processes=int(pool))
  347. self._mapfunc = self.pool.map
  348. self._own_pool = True
  349. else:
  350. raise RuntimeError("Number of workers specified must be -1,"
  351. " an int >= 1, or an object with a 'map' "
  352. "method")
  353. def __enter__(self):
  354. return self
  355. def terminate(self):
  356. if self._own_pool:
  357. self.pool.terminate()
  358. def join(self):
  359. if self._own_pool:
  360. self.pool.join()
  361. def close(self):
  362. if self._own_pool:
  363. self.pool.close()
  364. def __exit__(self, exc_type, exc_value, traceback):
  365. if self._own_pool:
  366. self.pool.close()
  367. self.pool.terminate()
  368. def __call__(self, func, iterable):
  369. # only accept one iterable because that's all Pool.map accepts
  370. try:
  371. return self._mapfunc(func, iterable)
  372. except TypeError as e:
  373. # wrong number of arguments
  374. raise TypeError("The map-like callable must be of the"
  375. " form f(func, iterable)") from e
  376. def rng_integers(gen, low, high=None, size=None, dtype='int64',
  377. endpoint=False):
  378. """
  379. Return random integers from low (inclusive) to high (exclusive), or if
  380. endpoint=True, low (inclusive) to high (inclusive). Replaces
  381. `RandomState.randint` (with endpoint=False) and
  382. `RandomState.random_integers` (with endpoint=True).
  383. Return random integers from the "discrete uniform" distribution of the
  384. specified dtype. If high is None (the default), then results are from
  385. 0 to low.
  386. Parameters
  387. ----------
  388. gen : {None, np.random.RandomState, np.random.Generator}
  389. Random number generator. If None, then the np.random.RandomState
  390. singleton is used.
  391. low : int or array-like of ints
  392. Lowest (signed) integers to be drawn from the distribution (unless
  393. high=None, in which case this parameter is 0 and this value is used
  394. for high).
  395. high : int or array-like of ints
  396. If provided, one above the largest (signed) integer to be drawn from
  397. the distribution (see above for behavior if high=None). If array-like,
  398. must contain integer values.
  399. size : array-like of ints, optional
  400. Output shape. If the given shape is, e.g., (m, n, k), then m * n * k
  401. samples are drawn. Default is None, in which case a single value is
  402. returned.
  403. dtype : {str, dtype}, optional
  404. Desired dtype of the result. All dtypes are determined by their name,
  405. i.e., 'int64', 'int', etc, so byteorder is not available and a specific
  406. precision may have different C types depending on the platform.
  407. The default value is np.int_.
  408. endpoint : bool, optional
  409. If True, sample from the interval [low, high] instead of the default
  410. [low, high) Defaults to False.
  411. Returns
  412. -------
  413. out: int or ndarray of ints
  414. size-shaped array of random integers from the appropriate distribution,
  415. or a single such random int if size not provided.
  416. """
  417. if isinstance(gen, Generator):
  418. return gen.integers(low, high=high, size=size, dtype=dtype,
  419. endpoint=endpoint)
  420. else:
  421. if gen is None:
  422. # default is RandomState singleton used by np.random.
  423. gen = np.random.mtrand._rand
  424. if endpoint:
  425. # inclusive of endpoint
  426. # remember that low and high can be arrays, so don't modify in
  427. # place
  428. if high is None:
  429. return gen.randint(low + 1, size=size, dtype=dtype)
  430. if high is not None:
  431. return gen.randint(low, high=high + 1, size=size, dtype=dtype)
  432. # exclusive
  433. return gen.randint(low, high=high, size=size, dtype=dtype)
  434. @contextmanager
  435. def _fixed_default_rng(seed=1638083107694713882823079058616272161):
  436. """Context with a fixed np.random.default_rng seed."""
  437. orig_fun = np.random.default_rng
  438. np.random.default_rng = lambda seed=seed: orig_fun(seed)
  439. try:
  440. yield
  441. finally:
  442. np.random.default_rng = orig_fun
  443. def _argmin(a, keepdims=False, axis=None):
  444. """
  445. argmin with a `keepdims` parameter.
  446. See https://github.com/numpy/numpy/issues/8710
  447. If axis is not None, a.shape[axis] must be greater than 0.
  448. """
  449. res = np.argmin(a, axis=axis)
  450. if keepdims and axis is not None:
  451. res = np.expand_dims(res, axis=axis)
  452. return res
  453. def _first_nonnan(a, axis):
  454. """
  455. Return the first non-nan value along the given axis.
  456. If a slice is all nan, nan is returned for that slice.
  457. The shape of the return value corresponds to ``keepdims=True``.
  458. Examples
  459. --------
  460. >>> import numpy as np
  461. >>> nan = np.nan
  462. >>> a = np.array([[ 3., 3., nan, 3.],
  463. [ 1., nan, 2., 4.],
  464. [nan, nan, 9., -1.],
  465. [nan, 5., 4., 3.],
  466. [ 2., 2., 2., 2.],
  467. [nan, nan, nan, nan]])
  468. >>> _first_nonnan(a, axis=0)
  469. array([[3., 3., 2., 3.]])
  470. >>> _first_nonnan(a, axis=1)
  471. array([[ 3.],
  472. [ 1.],
  473. [ 9.],
  474. [ 5.],
  475. [ 2.],
  476. [nan]])
  477. """
  478. k = _argmin(np.isnan(a), axis=axis, keepdims=True)
  479. return np.take_along_axis(a, k, axis=axis)
  480. def _nan_allsame(a, axis, keepdims=False):
  481. """
  482. Determine if the values along an axis are all the same.
  483. nan values are ignored.
  484. `a` must be a numpy array.
  485. `axis` is assumed to be normalized; that is, 0 <= axis < a.ndim.
  486. For an axis of length 0, the result is True. That is, we adopt the
  487. convention that ``allsame([])`` is True. (There are no values in the
  488. input that are different.)
  489. `True` is returned for slices that are all nan--not because all the
  490. values are the same, but because this is equivalent to ``allsame([])``.
  491. Examples
  492. --------
  493. >>> import numpy as np
  494. >>> a = np.array([[ 3., 3., nan, 3.],
  495. [ 1., nan, 2., 4.],
  496. [nan, nan, 9., -1.],
  497. [nan, 5., 4., 3.],
  498. [ 2., 2., 2., 2.],
  499. [nan, nan, nan, nan]])
  500. >>> _nan_allsame(a, axis=1, keepdims=True)
  501. array([[ True],
  502. [False],
  503. [False],
  504. [False],
  505. [ True],
  506. [ True]])
  507. """
  508. if axis is None:
  509. if a.size == 0:
  510. return True
  511. a = a.ravel()
  512. axis = 0
  513. else:
  514. shp = a.shape
  515. if shp[axis] == 0:
  516. shp = shp[:axis] + (1,)*keepdims + shp[axis + 1:]
  517. return np.full(shp, fill_value=True, dtype=bool)
  518. a0 = _first_nonnan(a, axis=axis)
  519. return ((a0 == a) | np.isnan(a)).all(axis=axis, keepdims=keepdims)
  520. def _contains_nan(a, nan_policy='propagate', use_summation=True):
  521. if not isinstance(a, np.ndarray):
  522. use_summation = False # some array_likes ignore nans (e.g. pandas)
  523. policies = ['propagate', 'raise', 'omit']
  524. if nan_policy not in policies:
  525. raise ValueError("nan_policy must be one of {%s}" %
  526. ', '.join("'%s'" % s for s in policies))
  527. if np.issubdtype(a.dtype, np.inexact):
  528. # The summation method avoids creating a (potentially huge) array.
  529. if use_summation:
  530. with np.errstate(invalid='ignore', over='ignore'):
  531. contains_nan = np.isnan(np.sum(a))
  532. else:
  533. contains_nan = np.isnan(a).any()
  534. elif np.issubdtype(a.dtype, object):
  535. contains_nan = False
  536. for el in a.ravel():
  537. # isnan doesn't work on non-numeric elements
  538. if np.issubdtype(type(el), np.number) and np.isnan(el):
  539. contains_nan = True
  540. break
  541. else:
  542. # Only `object` and `inexact` arrays can have NaNs
  543. contains_nan = False
  544. if contains_nan and nan_policy == 'raise':
  545. raise ValueError("The input contains nan values")
  546. return contains_nan, nan_policy
  547. def _rename_parameter(old_name, new_name, dep_version=None):
  548. """
  549. Generate decorator for backward-compatible keyword renaming.
  550. Apply the decorator generated by `_rename_parameter` to functions with a
  551. recently renamed parameter to maintain backward-compatibility.
  552. After decoration, the function behaves as follows:
  553. If only the new parameter is passed into the function, behave as usual.
  554. If only the old parameter is passed into the function (as a keyword), raise
  555. a DeprecationWarning if `dep_version` is provided, and behave as usual
  556. otherwise.
  557. If both old and new parameters are passed into the function, raise a
  558. DeprecationWarning if `dep_version` is provided, and raise the appropriate
  559. TypeError (function got multiple values for argument).
  560. Parameters
  561. ----------
  562. old_name : str
  563. Old name of parameter
  564. new_name : str
  565. New name of parameter
  566. dep_version : str, optional
  567. Version of SciPy in which old parameter was deprecated in the format
  568. 'X.Y.Z'. If supplied, the deprecation message will indicate that
  569. support for the old parameter will be removed in version 'X.Y+2.Z'
  570. Notes
  571. -----
  572. Untested with functions that accept *args. Probably won't work as written.
  573. """
  574. def decorator(fun):
  575. @functools.wraps(fun)
  576. def wrapper(*args, **kwargs):
  577. if old_name in kwargs:
  578. if dep_version:
  579. end_version = dep_version.split('.')
  580. end_version[1] = str(int(end_version[1]) + 2)
  581. end_version = '.'.join(end_version)
  582. message = (f"Use of keyword argument `{old_name}` is "
  583. f"deprecated and replaced by `{new_name}`. "
  584. f"Support for `{old_name}` will be removed "
  585. f"in SciPy {end_version}.")
  586. warnings.warn(message, DeprecationWarning, stacklevel=2)
  587. if new_name in kwargs:
  588. message = (f"{fun.__name__}() got multiple values for "
  589. f"argument now known as `{new_name}`")
  590. raise TypeError(message)
  591. kwargs[new_name] = kwargs.pop(old_name)
  592. return fun(*args, **kwargs)
  593. return wrapper
  594. return decorator
  595. def _rng_spawn(rng, n_children):
  596. # spawns independent RNGs from a parent RNG
  597. bg = rng._bit_generator
  598. ss = bg._seed_seq
  599. child_rngs = [np.random.Generator(type(bg)(child_ss))
  600. for child_ss in ss.spawn(n_children)]
  601. return child_rngs