123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711 |
- from contextlib import contextmanager
- import functools
- import operator
- import warnings
- import numbers
- from collections import namedtuple
- import inspect
- import math
- from typing import (
- Optional,
- Union,
- TYPE_CHECKING,
- TypeVar,
- )
- import numpy as np
- IntNumber = Union[int, np.integer]
- DecimalNumber = Union[float, np.floating, np.integer]
- # Since Generator was introduced in numpy 1.17, the following condition is needed for
- # backward compatibility
- if TYPE_CHECKING:
- SeedType = Optional[Union[IntNumber, np.random.Generator,
- np.random.RandomState]]
- GeneratorType = TypeVar("GeneratorType", bound=Union[np.random.Generator,
- np.random.RandomState])
- try:
- from numpy.random import Generator as Generator
- except ImportError:
- class Generator(): # type: ignore[no-redef]
- pass
- def _lazywhere(cond, arrays, f, fillvalue=None, f2=None):
- """
- np.where(cond, x, fillvalue) always evaluates x even where cond is False.
- This one only evaluates f(arr1[cond], arr2[cond], ...).
- Examples
- --------
- >>> import numpy as np
- >>> a, b = np.array([1, 2, 3, 4]), np.array([5, 6, 7, 8])
- >>> def f(a, b):
- ... return a*b
- >>> _lazywhere(a > 2, (a, b), f, np.nan)
- array([ nan, nan, 21., 32.])
- Notice, it assumes that all `arrays` are of the same shape, or can be
- broadcasted together.
- """
- cond = np.asarray(cond)
- if fillvalue is None:
- if f2 is None:
- raise ValueError("One of (fillvalue, f2) must be given.")
- else:
- fillvalue = np.nan
- else:
- if f2 is not None:
- raise ValueError("Only one of (fillvalue, f2) can be given.")
- args = np.broadcast_arrays(cond, *arrays)
- cond, arrays = args[0], args[1:]
- temp = tuple(np.extract(cond, arr) for arr in arrays)
- tcode = np.mintypecode([a.dtype.char for a in arrays])
- out = np.full(np.shape(arrays[0]), fill_value=fillvalue, dtype=tcode)
- np.place(out, cond, f(*temp))
- if f2 is not None:
- temp = tuple(np.extract(~cond, arr) for arr in arrays)
- np.place(out, ~cond, f2(*temp))
- return out
- def _lazyselect(condlist, choicelist, arrays, default=0):
- """
- Mimic `np.select(condlist, choicelist)`.
- Notice, it assumes that all `arrays` are of the same shape or can be
- broadcasted together.
- All functions in `choicelist` must accept array arguments in the order
- given in `arrays` and must return an array of the same shape as broadcasted
- `arrays`.
- Examples
- --------
- >>> import numpy as np
- >>> x = np.arange(6)
- >>> np.select([x <3, x > 3], [x**2, x**3], default=0)
- array([ 0, 1, 4, 0, 64, 125])
- >>> _lazyselect([x < 3, x > 3], [lambda x: x**2, lambda x: x**3], (x,))
- array([ 0., 1., 4., 0., 64., 125.])
- >>> a = -np.ones_like(x)
- >>> _lazyselect([x < 3, x > 3],
- ... [lambda x, a: x**2, lambda x, a: a * x**3],
- ... (x, a), default=np.nan)
- array([ 0., 1., 4., nan, -64., -125.])
- """
- arrays = np.broadcast_arrays(*arrays)
- tcode = np.mintypecode([a.dtype.char for a in arrays])
- out = np.full(np.shape(arrays[0]), fill_value=default, dtype=tcode)
- for func, cond in zip(choicelist, condlist):
- if np.all(cond is False):
- continue
- cond, _ = np.broadcast_arrays(cond, arrays[0])
- temp = tuple(np.extract(cond, arr) for arr in arrays)
- np.place(out, cond, func(*temp))
- return out
- def _aligned_zeros(shape, dtype=float, order="C", align=None):
- """Allocate a new ndarray with aligned memory.
- Primary use case for this currently is working around a f2py issue
- in NumPy 1.9.1, where dtype.alignment is such that np.zeros() does
- not necessarily create arrays aligned up to it.
- """
- dtype = np.dtype(dtype)
- if align is None:
- align = dtype.alignment
- if not hasattr(shape, '__len__'):
- shape = (shape,)
- size = functools.reduce(operator.mul, shape) * dtype.itemsize
- buf = np.empty(size + align + 1, np.uint8)
- offset = buf.__array_interface__['data'][0] % align
- if offset != 0:
- offset = align - offset
- # Note: slices producing 0-size arrays do not necessarily change
- # data pointer --- so we use and allocate size+1
- buf = buf[offset:offset+size+1][:-1]
- data = np.ndarray(shape, dtype, buf, order=order)
- data.fill(0)
- return data
- def _prune_array(array):
- """Return an array equivalent to the input array. If the input
- array is a view of a much larger array, copy its contents to a
- newly allocated array. Otherwise, return the input unchanged.
- """
- if array.base is not None and array.size < array.base.size // 2:
- return array.copy()
- return array
- def prod(iterable):
- """
- Product of a sequence of numbers.
- Faster than np.prod for short lists like array shapes, and does
- not overflow if using Python integers.
- """
- product = 1
- for x in iterable:
- product *= x
- return product
- def float_factorial(n: int) -> float:
- """Compute the factorial and return as a float
- Returns infinity when result is too large for a double
- """
- return float(math.factorial(n)) if n < 171 else np.inf
- # copy-pasted from scikit-learn utils/validation.py
- # change this to scipy.stats._qmc.check_random_state once numpy 1.16 is dropped
- def check_random_state(seed):
- """Turn `seed` into a `np.random.RandomState` instance.
- Parameters
- ----------
- seed : {None, int, `numpy.random.Generator`, `numpy.random.RandomState`}, optional
- If `seed` is None (or `np.random`), the `numpy.random.RandomState`
- singleton is used.
- If `seed` is an int, a new ``RandomState`` instance is used,
- seeded with `seed`.
- If `seed` is already a ``Generator`` or ``RandomState`` instance then
- that instance is used.
- Returns
- -------
- seed : {`numpy.random.Generator`, `numpy.random.RandomState`}
- Random number generator.
- """
- if seed is None or seed is np.random:
- return np.random.mtrand._rand
- if isinstance(seed, (numbers.Integral, np.integer)):
- return np.random.RandomState(seed)
- if isinstance(seed, (np.random.RandomState, np.random.Generator)):
- return seed
- raise ValueError('%r cannot be used to seed a numpy.random.RandomState'
- ' instance' % seed)
- def _asarray_validated(a, check_finite=True,
- sparse_ok=False, objects_ok=False, mask_ok=False,
- as_inexact=False):
- """
- Helper function for SciPy argument validation.
- Many SciPy linear algebra functions do support arbitrary array-like
- input arguments. Examples of commonly unsupported inputs include
- matrices containing inf/nan, sparse matrix representations, and
- matrices with complicated elements.
- Parameters
- ----------
- a : array_like
- The array-like input.
- check_finite : bool, optional
- Whether to check that the input matrices contain only finite numbers.
- Disabling may give a performance gain, but may result in problems
- (crashes, non-termination) if the inputs do contain infinities or NaNs.
- Default: True
- sparse_ok : bool, optional
- True if scipy sparse matrices are allowed.
- objects_ok : bool, optional
- True if arrays with dype('O') are allowed.
- mask_ok : bool, optional
- True if masked arrays are allowed.
- as_inexact : bool, optional
- True to convert the input array to a np.inexact dtype.
- Returns
- -------
- ret : ndarray
- The converted validated array.
- """
- if not sparse_ok:
- import scipy.sparse
- if scipy.sparse.issparse(a):
- msg = ('Sparse matrices are not supported by this function. '
- 'Perhaps one of the scipy.sparse.linalg functions '
- 'would work instead.')
- raise ValueError(msg)
- if not mask_ok:
- if np.ma.isMaskedArray(a):
- raise ValueError('masked arrays are not supported')
- toarray = np.asarray_chkfinite if check_finite else np.asarray
- a = toarray(a)
- if not objects_ok:
- if a.dtype is np.dtype('O'):
- raise ValueError('object arrays are not supported')
- if as_inexact:
- if not np.issubdtype(a.dtype, np.inexact):
- a = toarray(a, dtype=np.float_)
- return a
- def _validate_int(k, name, minimum=None):
- """
- Validate a scalar integer.
- This functon can be used to validate an argument to a function
- that expects the value to be an integer. It uses `operator.index`
- to validate the value (so, for example, k=2.0 results in a
- TypeError).
- Parameters
- ----------
- k : int
- The value to be validated.
- name : str
- The name of the parameter.
- minimum : int, optional
- An optional lower bound.
- """
- try:
- k = operator.index(k)
- except TypeError:
- raise TypeError(f'{name} must be an integer.') from None
- if minimum is not None and k < minimum:
- raise ValueError(f'{name} must be an integer not less '
- f'than {minimum}') from None
- return k
- # Add a replacement for inspect.getfullargspec()/
- # The version below is borrowed from Django,
- # https://github.com/django/django/pull/4846.
- # Note an inconsistency between inspect.getfullargspec(func) and
- # inspect.signature(func). If `func` is a bound method, the latter does *not*
- # list `self` as a first argument, while the former *does*.
- # Hence, cook up a common ground replacement: `getfullargspec_no_self` which
- # mimics `inspect.getfullargspec` but does not list `self`.
- #
- # This way, the caller code does not need to know whether it uses a legacy
- # .getfullargspec or a bright and shiny .signature.
- FullArgSpec = namedtuple('FullArgSpec',
- ['args', 'varargs', 'varkw', 'defaults',
- 'kwonlyargs', 'kwonlydefaults', 'annotations'])
- def getfullargspec_no_self(func):
- """inspect.getfullargspec replacement using inspect.signature.
- If func is a bound method, do not list the 'self' parameter.
- Parameters
- ----------
- func : callable
- A callable to inspect
- Returns
- -------
- fullargspec : FullArgSpec(args, varargs, varkw, defaults, kwonlyargs,
- kwonlydefaults, annotations)
- NOTE: if the first argument of `func` is self, it is *not*, I repeat
- *not*, included in fullargspec.args.
- This is done for consistency between inspect.getargspec() under
- Python 2.x, and inspect.signature() under Python 3.x.
- """
- sig = inspect.signature(func)
- args = [
- p.name for p in sig.parameters.values()
- if p.kind in [inspect.Parameter.POSITIONAL_OR_KEYWORD,
- inspect.Parameter.POSITIONAL_ONLY]
- ]
- varargs = [
- p.name for p in sig.parameters.values()
- if p.kind == inspect.Parameter.VAR_POSITIONAL
- ]
- varargs = varargs[0] if varargs else None
- varkw = [
- p.name for p in sig.parameters.values()
- if p.kind == inspect.Parameter.VAR_KEYWORD
- ]
- varkw = varkw[0] if varkw else None
- defaults = tuple(
- p.default for p in sig.parameters.values()
- if (p.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD and
- p.default is not p.empty)
- ) or None
- kwonlyargs = [
- p.name for p in sig.parameters.values()
- if p.kind == inspect.Parameter.KEYWORD_ONLY
- ]
- kwdefaults = {p.name: p.default for p in sig.parameters.values()
- if p.kind == inspect.Parameter.KEYWORD_ONLY and
- p.default is not p.empty}
- annotations = {p.name: p.annotation for p in sig.parameters.values()
- if p.annotation is not p.empty}
- return FullArgSpec(args, varargs, varkw, defaults, kwonlyargs,
- kwdefaults or None, annotations)
- class _FunctionWrapper:
- """
- Object to wrap user's function, allowing picklability
- """
- def __init__(self, f, args):
- self.f = f
- self.args = [] if args is None else args
- def __call__(self, x):
- return self.f(x, *self.args)
- class MapWrapper:
- """
- Parallelisation wrapper for working with map-like callables, such as
- `multiprocessing.Pool.map`.
- Parameters
- ----------
- pool : int or map-like callable
- If `pool` is an integer, then it specifies the number of threads to
- use for parallelization. If ``int(pool) == 1``, then no parallel
- processing is used and the map builtin is used.
- If ``pool == -1``, then the pool will utilize all available CPUs.
- If `pool` is a map-like callable that follows the same
- calling sequence as the built-in map function, then this callable is
- used for parallelization.
- """
- def __init__(self, pool=1):
- self.pool = None
- self._mapfunc = map
- self._own_pool = False
- if callable(pool):
- self.pool = pool
- self._mapfunc = self.pool
- else:
- from multiprocessing import Pool
- # user supplies a number
- if int(pool) == -1:
- # use as many processors as possible
- self.pool = Pool()
- self._mapfunc = self.pool.map
- self._own_pool = True
- elif int(pool) == 1:
- pass
- elif int(pool) > 1:
- # use the number of processors requested
- self.pool = Pool(processes=int(pool))
- self._mapfunc = self.pool.map
- self._own_pool = True
- else:
- raise RuntimeError("Number of workers specified must be -1,"
- " an int >= 1, or an object with a 'map' "
- "method")
- def __enter__(self):
- return self
- def terminate(self):
- if self._own_pool:
- self.pool.terminate()
- def join(self):
- if self._own_pool:
- self.pool.join()
- def close(self):
- if self._own_pool:
- self.pool.close()
- def __exit__(self, exc_type, exc_value, traceback):
- if self._own_pool:
- self.pool.close()
- self.pool.terminate()
- def __call__(self, func, iterable):
- # only accept one iterable because that's all Pool.map accepts
- try:
- return self._mapfunc(func, iterable)
- except TypeError as e:
- # wrong number of arguments
- raise TypeError("The map-like callable must be of the"
- " form f(func, iterable)") from e
- def rng_integers(gen, low, high=None, size=None, dtype='int64',
- endpoint=False):
- """
- Return random integers from low (inclusive) to high (exclusive), or if
- endpoint=True, low (inclusive) to high (inclusive). Replaces
- `RandomState.randint` (with endpoint=False) and
- `RandomState.random_integers` (with endpoint=True).
- Return random integers from the "discrete uniform" distribution of the
- specified dtype. If high is None (the default), then results are from
- 0 to low.
- Parameters
- ----------
- gen : {None, np.random.RandomState, np.random.Generator}
- Random number generator. If None, then the np.random.RandomState
- singleton is used.
- low : int or array-like of ints
- Lowest (signed) integers to be drawn from the distribution (unless
- high=None, in which case this parameter is 0 and this value is used
- for high).
- high : int or array-like of ints
- If provided, one above the largest (signed) integer to be drawn from
- the distribution (see above for behavior if high=None). If array-like,
- must contain integer values.
- size : array-like of ints, optional
- Output shape. If the given shape is, e.g., (m, n, k), then m * n * k
- samples are drawn. Default is None, in which case a single value is
- returned.
- dtype : {str, dtype}, optional
- Desired dtype of the result. All dtypes are determined by their name,
- i.e., 'int64', 'int', etc, so byteorder is not available and a specific
- precision may have different C types depending on the platform.
- The default value is np.int_.
- endpoint : bool, optional
- If True, sample from the interval [low, high] instead of the default
- [low, high) Defaults to False.
- Returns
- -------
- out: int or ndarray of ints
- size-shaped array of random integers from the appropriate distribution,
- or a single such random int if size not provided.
- """
- if isinstance(gen, Generator):
- return gen.integers(low, high=high, size=size, dtype=dtype,
- endpoint=endpoint)
- else:
- if gen is None:
- # default is RandomState singleton used by np.random.
- gen = np.random.mtrand._rand
- if endpoint:
- # inclusive of endpoint
- # remember that low and high can be arrays, so don't modify in
- # place
- if high is None:
- return gen.randint(low + 1, size=size, dtype=dtype)
- if high is not None:
- return gen.randint(low, high=high + 1, size=size, dtype=dtype)
- # exclusive
- return gen.randint(low, high=high, size=size, dtype=dtype)
- @contextmanager
- def _fixed_default_rng(seed=1638083107694713882823079058616272161):
- """Context with a fixed np.random.default_rng seed."""
- orig_fun = np.random.default_rng
- np.random.default_rng = lambda seed=seed: orig_fun(seed)
- try:
- yield
- finally:
- np.random.default_rng = orig_fun
- def _argmin(a, keepdims=False, axis=None):
- """
- argmin with a `keepdims` parameter.
- See https://github.com/numpy/numpy/issues/8710
- If axis is not None, a.shape[axis] must be greater than 0.
- """
- res = np.argmin(a, axis=axis)
- if keepdims and axis is not None:
- res = np.expand_dims(res, axis=axis)
- return res
- def _first_nonnan(a, axis):
- """
- Return the first non-nan value along the given axis.
- If a slice is all nan, nan is returned for that slice.
- The shape of the return value corresponds to ``keepdims=True``.
- Examples
- --------
- >>> import numpy as np
- >>> nan = np.nan
- >>> a = np.array([[ 3., 3., nan, 3.],
- [ 1., nan, 2., 4.],
- [nan, nan, 9., -1.],
- [nan, 5., 4., 3.],
- [ 2., 2., 2., 2.],
- [nan, nan, nan, nan]])
- >>> _first_nonnan(a, axis=0)
- array([[3., 3., 2., 3.]])
- >>> _first_nonnan(a, axis=1)
- array([[ 3.],
- [ 1.],
- [ 9.],
- [ 5.],
- [ 2.],
- [nan]])
- """
- k = _argmin(np.isnan(a), axis=axis, keepdims=True)
- return np.take_along_axis(a, k, axis=axis)
- def _nan_allsame(a, axis, keepdims=False):
- """
- Determine if the values along an axis are all the same.
- nan values are ignored.
- `a` must be a numpy array.
- `axis` is assumed to be normalized; that is, 0 <= axis < a.ndim.
- For an axis of length 0, the result is True. That is, we adopt the
- convention that ``allsame([])`` is True. (There are no values in the
- input that are different.)
- `True` is returned for slices that are all nan--not because all the
- values are the same, but because this is equivalent to ``allsame([])``.
- Examples
- --------
- >>> import numpy as np
- >>> a = np.array([[ 3., 3., nan, 3.],
- [ 1., nan, 2., 4.],
- [nan, nan, 9., -1.],
- [nan, 5., 4., 3.],
- [ 2., 2., 2., 2.],
- [nan, nan, nan, nan]])
- >>> _nan_allsame(a, axis=1, keepdims=True)
- array([[ True],
- [False],
- [False],
- [False],
- [ True],
- [ True]])
- """
- if axis is None:
- if a.size == 0:
- return True
- a = a.ravel()
- axis = 0
- else:
- shp = a.shape
- if shp[axis] == 0:
- shp = shp[:axis] + (1,)*keepdims + shp[axis + 1:]
- return np.full(shp, fill_value=True, dtype=bool)
- a0 = _first_nonnan(a, axis=axis)
- return ((a0 == a) | np.isnan(a)).all(axis=axis, keepdims=keepdims)
- def _contains_nan(a, nan_policy='propagate', use_summation=True):
- if not isinstance(a, np.ndarray):
- use_summation = False # some array_likes ignore nans (e.g. pandas)
- policies = ['propagate', 'raise', 'omit']
- if nan_policy not in policies:
- raise ValueError("nan_policy must be one of {%s}" %
- ', '.join("'%s'" % s for s in policies))
- if np.issubdtype(a.dtype, np.inexact):
- # The summation method avoids creating a (potentially huge) array.
- if use_summation:
- with np.errstate(invalid='ignore', over='ignore'):
- contains_nan = np.isnan(np.sum(a))
- else:
- contains_nan = np.isnan(a).any()
- elif np.issubdtype(a.dtype, object):
- contains_nan = False
- for el in a.ravel():
- # isnan doesn't work on non-numeric elements
- if np.issubdtype(type(el), np.number) and np.isnan(el):
- contains_nan = True
- break
- else:
- # Only `object` and `inexact` arrays can have NaNs
- contains_nan = False
- if contains_nan and nan_policy == 'raise':
- raise ValueError("The input contains nan values")
- return contains_nan, nan_policy
- def _rename_parameter(old_name, new_name, dep_version=None):
- """
- Generate decorator for backward-compatible keyword renaming.
- Apply the decorator generated by `_rename_parameter` to functions with a
- recently renamed parameter to maintain backward-compatibility.
- After decoration, the function behaves as follows:
- If only the new parameter is passed into the function, behave as usual.
- If only the old parameter is passed into the function (as a keyword), raise
- a DeprecationWarning if `dep_version` is provided, and behave as usual
- otherwise.
- If both old and new parameters are passed into the function, raise a
- DeprecationWarning if `dep_version` is provided, and raise the appropriate
- TypeError (function got multiple values for argument).
- Parameters
- ----------
- old_name : str
- Old name of parameter
- new_name : str
- New name of parameter
- dep_version : str, optional
- Version of SciPy in which old parameter was deprecated in the format
- 'X.Y.Z'. If supplied, the deprecation message will indicate that
- support for the old parameter will be removed in version 'X.Y+2.Z'
- Notes
- -----
- Untested with functions that accept *args. Probably won't work as written.
- """
- def decorator(fun):
- @functools.wraps(fun)
- def wrapper(*args, **kwargs):
- if old_name in kwargs:
- if dep_version:
- end_version = dep_version.split('.')
- end_version[1] = str(int(end_version[1]) + 2)
- end_version = '.'.join(end_version)
- message = (f"Use of keyword argument `{old_name}` is "
- f"deprecated and replaced by `{new_name}`. "
- f"Support for `{old_name}` will be removed "
- f"in SciPy {end_version}.")
- warnings.warn(message, DeprecationWarning, stacklevel=2)
- if new_name in kwargs:
- message = (f"{fun.__name__}() got multiple values for "
- f"argument now known as `{new_name}`")
- raise TypeError(message)
- kwargs[new_name] = kwargs.pop(old_name)
- return fun(*args, **kwargs)
- return wrapper
- return decorator
- def _rng_spawn(rng, n_children):
- # spawns independent RNGs from a parent RNG
- bg = rng._bit_generator
- ss = bg._seed_seq
- child_rngs = [np.random.Generator(type(bg)(child_ss))
- for child_ss in ss.spawn(n_children)]
- return child_rngs
|