123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447 |
import functools
import os
import sys
import time
from itertools import zip_longest

import numpy as np
from numpy.testing import assert_
import pytest

from scipy.special._testutils import assert_func_equal

try:
    import mpmath
except ImportError:
    pass
# ------------------------------------------------------------------------------
# Machinery for systematic tests with mpmath
# ------------------------------------------------------------------------------
class Arg:
    """Generate a set of numbers on the real axis, concentrating on
    'interesting' regions and covering all orders of magnitude.

    Parameters
    ----------
    a, b : float, optional
        Lower and upper bounds of the interval. Infinite bounds are replaced
        by -/+ half of the largest finite double.
    inclusive_a, inclusive_b : bool, optional
        Whether the corresponding endpoint may appear in the output of
        `values`.

    Raises
    ------
    ValueError
        If ``a > b``.
    """

    def __init__(self, a=-np.inf, b=np.inf, inclusive_a=True, inclusive_b=True):
        if a > b:
            raise ValueError("a should be less than or equal to b")
        if a == -np.inf:
            a = -0.5*np.finfo(float).max
        if b == np.inf:
            b = 0.5*np.finfo(float).max
        self.a, self.b = a, b

        self.inclusive_a, self.inclusive_b = inclusive_a, inclusive_b

    def _positive_values(self, a, b, n):
        """Return ``n`` sorted points in ``[a, b]`` for ``0 <= a < b``.

        Roughly half of the points go into a linspace over ``[a, 10]``; the
        rest go into logspaces covering the remaining orders of magnitude
        (down to 1e-30 when ``a == 0``, up to ``b`` when ``b > 10``).
        """
        if a < 0:
            raise ValueError("a should be non-negative")

        # Split n between linspace and logspace; the linspace gets the
        # extra point when n is odd.
        nlogpts = n//2
        nlinpts = n - nlogpts

        if a >= 10:
            # Outside of linspace range; just return a logspace.
            pts = np.logspace(np.log10(a), np.log10(b), n)
        elif a > 0 and b <= 10:
            # Outside of logspace range; just return a linspace. b == 10
            # belongs here: the next branch would build logspace(1, 1, k),
            # i.e. k duplicated copies of 10.
            pts = np.linspace(a, b, n)
        elif a > 0:
            # Linspace between a and 10 and a logspace between 10 and b.
            linpts = np.linspace(a, 10, nlinpts, endpoint=False)
            logpts = np.logspace(1, np.log10(b), nlogpts)
            pts = np.hstack((linpts, logpts))
        elif a == 0 and b <= 10:
            # Linspace between 0 and b and a logspace between 1e-30 and
            # the smallest positive point of the linspace.
            linpts = np.linspace(0, b, nlinpts)
            right = np.log10(linpts[1]) if linpts.size > 1 else -30
            logpts = np.logspace(-30, right, nlogpts, endpoint=False)
            pts = np.hstack((logpts, linpts))
        else:
            # a == 0 and b > 10: linspace between 0 and 10, a logspace
            # between 1e-30 and the smallest positive linspace point, and a
            # logspace between 10 and b.
            nlogpts1 = nlogpts//2
            nlogpts2 = nlogpts - nlogpts1
            linpts = np.linspace(0, 10, nlinpts, endpoint=False)
            right = np.log10(linpts[1]) if linpts.size > 1 else -30
            logpts1 = np.logspace(-30, right, nlogpts1, endpoint=False)
            logpts2 = np.logspace(1, np.log10(b), nlogpts2)
            pts = np.hstack((logpts1, linpts, logpts2))

        return np.sort(pts)

    def values(self, n):
        """Return an array containing n numbers.

        An excluded endpoint is handled by generating one extra point and
        trimming it off afterwards. A degenerate interval ``a == b`` yields
        ``n`` copies of ``a``.
        """
        a, b = self.a, self.b
        if a == b:
            # The only value in a degenerate interval is the endpoint itself.
            return np.full(n, a, dtype=float)

        if not self.inclusive_a:
            n += 1
        if not self.inclusive_b:
            n += 1

        # When the interval straddles zero, n2 (negative side) receives the
        # extra point if n is odd.
        n1 = n//2
        n2 = n - n1

        if a >= 0:
            pospts = self._positive_values(a, b, n)
            negpts = []
        elif b <= 0:
            pospts = []
            negpts = -self._positive_values(-b, -a, n)
        else:
            pospts = self._positive_values(0, b, n1)
            # One extra point because the duplicated zero is dropped next.
            negpts = -self._positive_values(0, -a, n2 + 1)
            negpts = negpts[1:]
        pts = np.hstack((negpts[::-1], pospts))

        # Trim the extra points generated for excluded endpoints.
        if not self.inclusive_a:
            pts = pts[1:]
        if not self.inclusive_b:
            pts = pts[:-1]
        return pts
class FixedArg:
    """Argument specification backed by an explicit, fixed array.

    The requested number of points is ignored: `values` always returns the
    array supplied at construction (converted with ``np.asarray``).
    """

    def __init__(self, values):
        self._values = np.asarray(values)

    def values(self, n):
        """Return the stored array; `n` has no effect and is accepted only
        so this class can be used wherever an `Arg` is expected."""
        return self._values
class ComplexArg:
    """Generate points on a rectangular grid in the complex plane.

    Real and imaginary parts are drawn independently from `Arg` instances
    spanning ``[a.real, b.real]`` and ``[a.imag, b.imag]``.
    """

    def __init__(self, a=complex(-np.inf, -np.inf), b=complex(np.inf, np.inf)):
        self.real = Arg(a.real, b.real)
        self.imag = Arg(a.imag, b.imag)

    def values(self, n):
        """Return a flattened grid of ``floor(sqrt(n))`` real parts combined
        with one more imaginary part than that."""
        side = int(np.floor(np.sqrt(n)))
        re_parts = self.real.values(side)
        im_parts = self.imag.values(side + 1)
        grid = re_parts[:, np.newaxis] + 1j*im_parts[np.newaxis, :]
        return grid.ravel()
class IntArg:
    """Generate integers in the half-open interval ``[a, b)``.

    Integer-truncated points from `Arg` are merged with the small integers
    -5..4, deduplicated, and filtered to the interval.
    """

    def __init__(self, a=-1000, b=1000):
        self.a = a
        self.b = b

    def values(self, n):
        """Return sorted unique integers ``v`` with ``a <= v < b``."""
        npts = max(1 + n//2, n - 5)
        spread = Arg(self.a, self.b).values(npts).astype(int)
        small = np.arange(-5, 5)
        candidates = np.unique(np.r_[spread, small])
        in_range = (candidates >= self.a) & (candidates < self.b)
        return candidates[in_range]
def get_args(argspec, n):
    """Build the argument array for a systematic test.

    Parameters
    ----------
    argspec : ndarray or sequence of argument specs
        Either an explicit argument array (returned as a copy) or one spec
        object per function argument, each providing ``values(m)``.
    n : int
        Approximate total number of argument tuples to generate.

    Returns
    -------
    ndarray
        For a sequence of specs, an array of shape ``(npoints, nargs)``
        holding the Cartesian product of every spec's values.
    """
    if isinstance(argspec, np.ndarray):
        return argspec.copy()

    nargs = len(argspec)
    # ComplexArg specs get weight 1.5 (others 1.0) when splitting the point
    # budget -- presumably because their values already form a 2-D grid.
    weights = np.asarray([1.5 if isinstance(spec, ComplexArg) else 1.0
                          for spec in argspec])
    counts = (n**(weights/sum(weights))).astype(int) + 1
    per_arg = [spec.values(m) for spec, m in zip(argspec, counts)]
    grid = np.broadcast_arrays(*np.ix_(*per_arg))
    return np.array(grid).reshape(nargs, -1).T
class MpmathData:
    """Compare a SciPy function against an mpmath reference implementation
    on a generated set of arguments.

    Parameters
    ----------
    scipy_func : callable
        Function under test.
    mpmath_func : callable
        Reference implementation; called with arguments converted to
        ``mpmath.mpf`` (real) or ``mpmath.mpc`` (complex).
    arg_spec : ndarray or sequence of argument specs
        Passed to `get_args` to build the argument array.
    name : str, optional
        Label used by ``repr``. If empty or ``'<lambda>'``, the
        ``__name__`` of `scipy_func`, then of `mpmath_func`, is used.
    dps, prec : int, optional
        mpmath working precision in decimal digits or in bits. If neither
        is given, 20 decimal digits are used; if both are given, `dps`
        takes precedence.
    n : int, optional
        Approximate number of test points. Defaults to 500, or 5000 when
        the ``SCIPY_XSLOW`` environment variable is a nonzero integer.
    rtol, atol, ignore_inf_sign, distinguish_nan_and_inf, nan_ok, param_filter
        Forwarded to ``assert_func_equal``.
    """

    def __init__(self, scipy_func, mpmath_func, arg_spec, name=None,
                 dps=None, prec=None, n=None, rtol=1e-7, atol=1e-300,
                 ignore_inf_sign=False, distinguish_nan_and_inf=True,
                 nan_ok=True, param_filter=None):
        # mpmath tests are really slow (see gh-6989). Use a small number of
        # points by default, increase back to 5000 (old default) if XSLOW is
        # set
        if n is None:
            try:
                is_xslow = int(os.environ.get('SCIPY_XSLOW', '0'))
            except ValueError:
                is_xslow = False
            n = 5000 if is_xslow else 500

        self.scipy_func = scipy_func
        self.mpmath_func = mpmath_func
        self.arg_spec = arg_spec
        self.dps = dps
        self.prec = prec
        self.n = n
        self.rtol = rtol
        self.atol = atol
        self.ignore_inf_sign = ignore_inf_sign
        self.nan_ok = nan_ok
        if isinstance(self.arg_spec, np.ndarray):
            self.is_complex = np.issubdtype(self.arg_spec.dtype, np.complexfloating)
        else:
            self.is_complex = any(isinstance(arg, ComplexArg) for arg in self.arg_spec)
        self.distinguish_nan_and_inf = distinguish_nan_and_inf
        # Prefer an explicit name; anonymous lambdas are not informative.
        if not name or name == '<lambda>':
            name = getattr(scipy_func, '__name__', None)
        if not name or name == '<lambda>':
            name = getattr(mpmath_func, '__name__', None)
        self.name = name
        self.param_filter = param_filter

    def check(self):
        """Run the comparison; raise AssertionError on mismatch.

        The global mpmath precision is restored afterwards, even on error.
        """
        np.random.seed(1234)

        # Generate values for the arguments
        argarr = get_args(self.arg_spec, self.n)

        old_dps, old_prec = mpmath.mp.dps, mpmath.mp.prec
        try:
            # mp.dps and mp.prec are two views of the same setting: assigning
            # one recomputes the other. Apply exactly one of them, so that
            # a default dps cannot overwrite a caller-supplied prec.
            if self.dps is not None:
                mpmath.mp.dps = self.dps
            elif self.prec is not None:
                mpmath.mp.prec = self.prec
            else:
                mpmath.mp.dps = 20

            # Proper casting of mpmath input and output types. Using
            # native mpmath types as inputs gives improved precision
            # in some cases.
            if np.issubdtype(argarr.dtype, np.complexfloating):
                pytype = mpc2complex

                def mptype(x):
                    return mpmath.mpc(complex(x))
            else:
                def mptype(x):
                    return mpmath.mpf(float(x))

                def pytype(x):
                    # A reference result with a non-negligible imaginary part
                    # has no real counterpart; compare it as nan.
                    if abs(x.imag) > 1e-16*(1 + abs(x.real)):
                        return np.nan
                    else:
                        return mpf2float(x.real)

            assert_func_equal(
                self.scipy_func,
                lambda *a: pytype(self.mpmath_func(*map(mptype, a))),
                argarr,
                vectorized=False,
                rtol=self.rtol, atol=self.atol,
                ignore_inf_sign=self.ignore_inf_sign,
                distinguish_nan_and_inf=self.distinguish_nan_and_inf,
                nan_ok=self.nan_ok,
                param_filter=self.param_filter)
        finally:
            mpmath.mp.dps, mpmath.mp.prec = old_dps, old_prec

    def __repr__(self):
        if self.is_complex:
            return "<MpmathData: %s (complex)>" % (self.name,)
        else:
            return "<MpmathData: %s>" % (self.name,)
def assert_mpmath_equal(*a, **kw):
    """Build an `MpmathData` from the given arguments and run its check.

    All positional and keyword arguments are forwarded unchanged to
    `MpmathData`.
    """
    MpmathData(*a, **kw).check()
def nonfunctional_tooslow(func):
    """Mark a test as skipped because it is still too slow to be usable."""
    marker = pytest.mark.skip(
        reason=" Test not yet functional (too slow), needs more work.")
    return marker(func)
# ------------------------------------------------------------------------------
# Tools for dealing with mpmath quirks
# ------------------------------------------------------------------------------
def mpf2float(x):
    """
    Convert an mpf to the nearest floating point number. Just using
    float directly doesn't work because of results like this:

    with mp.workdps(50):
        float(mpf("0.99999999999999999")) = 0.9999999999999999

    """
    # Render 17 significant digits in exponent notation and let float()
    # parse the resulting decimal string.
    digits = 17
    text = mpmath.nstr(x, digits, min_fixed=0, max_fixed=0)
    return float(text)
def mpc2complex(x):
    """Convert an mpc to a Python complex, rounding each part via `mpf2float`."""
    re_part = mpf2float(x.real)
    im_part = mpf2float(x.imag)
    return complex(re_part, im_part)
def trace_args(func):
    """Decorate `func` so every call logs its arguments and result to stderr.

    Each call writes ``<args>: -> <result>`` followed by a newline.
    Positional arguments are shown as floats, or as complex numbers for
    Python/NumPy complex and ``mpmath.mpc`` values. If `func` raises, the
    line is still terminated and the exception propagates.
    """
    def tofloat(x):
        if isinstance(x, complex) or isinstance(x, mpmath.mpc):
            return complex(x)
        else:
            return float(x)

    @functools.wraps(func)
    def wrap(*a, **kw):
        sys.stderr.write("%r: " % (tuple(map(tofloat, a)),))
        sys.stderr.flush()
        try:
            r = func(*a, **kw)
            # Wrap r in a 1-tuple: a tuple result would otherwise be
            # unpacked into several %-format arguments and raise TypeError.
            sys.stderr.write("-> %r" % (r,))
        finally:
            sys.stderr.write("\n")
            sys.stderr.flush()
        return r
    return wrap
# Detect whether SIGALRM-based interval timers (signal.setitimer) are
# available; time_limited uses sys.settrace when they are not.
try:
    import posix
    import signal
    POSIX = hasattr(signal, 'setitimer')
except ImportError:
    POSIX = False
class TimeoutError(Exception):
    """Raised by the SIGALRM handler or trace hook that `time_limited`
    installs once the time budget is exceeded.

    Note: this shadows the builtin ``TimeoutError`` within this module and,
    unlike the builtin, derives from ``Exception`` rather than ``OSError``.
    """
def time_limited(timeout=0.5, return_val=np.nan, use_sigalrm=True):
    """
    Decorator for setting a timeout for pure-Python functions.

    If the function does not return within `timeout` seconds, the
    value `return_val` is returned instead.

    On POSIX this uses SIGALRM by default. On non-POSIX (or with
    ``use_sigalrm=False``), settrace is used. Do not use this with threads:
    the SIGALRM implementation probably does not work well. The settrace
    implementation only traces the current thread, and any tracer that was
    already installed (debugger, coverage) is restored afterwards.

    The settrace implementation slows down execution speed. Slowdown
    by a factor around 10 is probably typical.
    """
    if POSIX and use_sigalrm:
        def sigalrm_handler(signum, frame):
            raise TimeoutError()

        def deco(func):
            @functools.wraps(func)
            def wrap(*a, **kw):
                old_handler = signal.signal(signal.SIGALRM, sigalrm_handler)
                try:
                    # Arm the timer inside the try so that even an alarm
                    # firing immediately maps to return_val and the timer and
                    # handler are always cleaned up.
                    signal.setitimer(signal.ITIMER_REAL, timeout)
                    return func(*a, **kw)
                except TimeoutError:
                    return return_val
                finally:
                    signal.setitimer(signal.ITIMER_REAL, 0)
                    signal.signal(signal.SIGALRM, old_handler)
            return wrap
    else:
        def deco(func):
            @functools.wraps(func)
            def wrap(*a, **kw):
                # Monotonic clock: immune to wall-clock adjustments.
                start_time = time.monotonic()

                def trace(frame, event, arg):
                    if time.monotonic() - start_time > timeout:
                        raise TimeoutError()
                    return trace

                # Remember any tracer already installed so it is reinstated
                # instead of being cleared.
                old_trace = sys.gettrace()
                sys.settrace(trace)
                try:
                    return func(*a, **kw)
                except TimeoutError:
                    return return_val
                finally:
                    sys.settrace(old_trace)
            return wrap
    return deco
def exception_to_nan(func):
    """Decorate function to return nan if it raises an exception.

    Only ``Exception`` subclasses are converted; ``KeyboardInterrupt`` and
    ``SystemExit`` still propagate. The wrapper keeps the wrapped function's
    metadata (``__name__`` etc.), which `MpmathData` reads to label tests.
    """
    @functools.wraps(func)
    def wrap(*a, **kw):
        try:
            return func(*a, **kw)
        except Exception:
            return np.nan
    return wrap
def inf_to_nan(func):
    """Decorate function to return nan if it returns a non-finite value.

    Both +/-inf and nan results come back as nan; finite results are passed
    through unchanged. The wrapper keeps the wrapped function's metadata
    (``__name__`` etc.), which `MpmathData` reads to label tests.
    """
    @functools.wraps(func)
    def wrap(*a, **kw):
        v = func(*a, **kw)
        if not np.isfinite(v):
            return np.nan
        return v
    return wrap
def mp_assert_allclose(res, std, atol=0, rtol=1e-17):
    """
    Compare lists of mpmath.mpf's or mpmath.mpc's directly so that it
    can be done to higher precision than double.

    Parameters
    ----------
    res, std : iterable
        Computed and reference values; must have the same length.
    atol, rtol : real, optional
        Absolute and relative tolerance: element ``k`` fails when
        ``|res[k] - std[k]| > atol + rtol*|std[k]|``.

    Raises
    ------
    ValueError
        If `res` and `std` have different lengths.
    AssertionError
        If any element is out of tolerance; the message lists every
        failing point.
    """
    # Private sentinel so a genuine None entry is not mistaken for
    # "one input ran out".
    missing = object()
    failures = []
    npoints = 0
    for k, (resval, stdval) in enumerate(zip_longest(res, std, fillvalue=missing)):
        if resval is missing or stdval is missing:
            raise ValueError('Lengths of inputs res and std are not equal.')
        npoints += 1
        if mpmath.fabs(resval - stdval) > atol + rtol*mpmath.fabs(stdval):
            failures.append((k, resval, stdval))

    nfail = len(failures)
    if nfail > 0:
        # Show as many digits as rtol resolves. With rtol == 0 (purely
        # absolute tolerance) log10 diverges, so fall back to the current
        # mpmath working precision.
        if rtol > 0:
            ndigits = max(1, int(abs(np.log10(float(rtol)))))
        else:
            ndigits = mpmath.mp.dps
        msg = [""]
        msg.append("Bad results ({} out of {}) for the following points:"
                   .format(nfail, npoints))
        for k, resval, stdval in failures:
            resrep = mpmath.nstr(resval, ndigits, min_fixed=0, max_fixed=0)
            stdrep = mpmath.nstr(stdval, ndigits, min_fixed=0, max_fixed=0)
            if stdval == 0:
                rdiff = "inf"
            else:
                rdiff = mpmath.fabs((resval - stdval)/stdval)
                rdiff = mpmath.nstr(rdiff, 3)
            msg.append("{}: {} != {} (rdiff {})".format(k, resrep, stdrep,
                                                       rdiff))
        assert_(False, "\n".join(msg))