_quad_vec.py 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653
  1. import sys
  2. import copy
  3. import heapq
  4. import collections
  5. import functools
  6. import numpy as np
  7. from scipy._lib._util import MapWrapper, _FunctionWrapper
  class LRUDict(collections.OrderedDict):
      """Ordered mapping that evicts its oldest entry once it grows too large.

      Assigning to a new key appends it; if the mapping then holds more than
      ``max_size`` entries, the entry at the front (the oldest) is removed.
      Re-assigning an existing key moves it to the back (most recent).

      Parameters
      ----------
      max_size : int or float
          Maximum number of entries retained.
      """

      def __init__(self, max_size):
          # Initialise the OrderedDict base explicitly before adding state.
          super().__init__()
          self.__max_size = max_size

      def __setitem__(self, key, value):
          existing_key = (key in self)
          super().__setitem__(key, value)
          if existing_key:
              # An overwritten key becomes the most recently used one.
              self.move_to_end(key)
          elif len(self) > self.__max_size:
              # Drop the oldest entry (front of the ordering).
              self.popitem(last=False)

      def update(self, other):
          # Not needed below; disabled explicitly.
          raise NotImplementedError()
  class SemiInfiniteFunc:
      """
      Argument transform from (start, +-oo) to (0, 1)

      Substitutes ``x = start + sgn * (1 - t) / t`` so that ``t`` in (0, 1]
      sweeps the half-line beginning at ``start``; ``sgn`` is -1 when the
      infinite end is negative and +1 otherwise.
      """

      def __init__(self, func, start, infty):
          self._func = func
          self._start = start
          if infty < 0:
              self._sgn = -1
          else:
              self._sgn = 1
          # Below this t the 1/t**2 Jacobian could overflow, so the
          # transformed integrand is reported as zero there.
          self._tmin = sys.float_info.min**0.5

      def get_t(self, x):
          """Return the ``t`` value corresponding to the original point ``x``."""
          denom = self._sgn * (x - self._start) + 1
          if denom == 0:
              # Only reachable for a point outside the integration range.
              return np.inf
          return 1 / denom

      def __call__(self, t):
          """Evaluate the transformed integrand (including its Jacobian) at ``t``."""
          if t < self._tmin:
              return 0.0
          x = self._start + self._sgn * (1 - t) / t
          return self._sgn * (self._func(x) / t) / t
  class DoubleInfiniteFunc:
      """
      Argument transform from (-oo, oo) to (-1, 1)

      Substitutes ``x = (1 - |t|) / t``: positive ``t`` covers ``x > 0`` and
      negative ``t`` covers ``x < 0``.
      """

      def __init__(self, func):
          self._func = func
          # |t| below this threshold would make the 1/t**2 Jacobian overflow;
          # the transformed integrand is reported as zero there.
          self._tmin = sys.float_info.min**0.5

      def get_t(self, x):
          """Return ``sign(x) / (|x| + 1)``, treating ``x == 0`` as positive."""
          magnitude = abs(x) + 1
          return -1 / magnitude if x < 0 else 1 / magnitude

      def __call__(self, t):
          """Evaluate the transformed integrand (including its Jacobian) at ``t``."""
          if abs(t) < self._tmin:
              return 0.0
          x = (1 - abs(t)) / t
          return (self._func(x) / t) / t
  62. def _max_norm(x):
  63. return np.amax(abs(x))
  64. def _get_sizeof(obj):
  65. try:
  66. return sys.getsizeof(obj)
  67. except TypeError:
  68. # occurs on pypy
  69. if hasattr(obj, '__sizeof__'):
  70. return int(obj.__sizeof__())
  71. return 64
  72. class _Bunch:
  73. def __init__(self, **kwargs):
  74. self.__keys = kwargs.keys()
  75. self.__dict__.update(**kwargs)
  76. def __repr__(self):
  77. return "_Bunch({})".format(", ".join("{}={}".format(k, repr(self.__dict__[k]))
  78. for k in self.__keys))
  def quad_vec(f, a, b, epsabs=1e-200, epsrel=1e-8, norm='2', cache_size=100e6, limit=10000,
               workers=1, points=None, quadrature=None, full_output=False,
               *, args=()):
      r"""Adaptive integration of a vector-valued function.

      Parameters
      ----------
      f : callable
          Vector-valued function f(x) to integrate.
      a : float
          Initial point.
      b : float
          Final point.
      epsabs : float, optional
          Absolute tolerance.
      epsrel : float, optional
          Relative tolerance.
      norm : {'max', '2'} or callable, optional
          Vector norm to use for error estimation. A callable is used as the
          norm directly; ``None`` selects the max-norm.
      cache_size : int, optional
          Number of bytes to use for memoization.
      limit : float or int, optional
          An upper bound on the number of subintervals used in the adaptive
          algorithm.
      workers : int or map-like callable, optional
          If `workers` is an integer, part of the computation is done in
          parallel subdivided to this many tasks (using
          :class:`python:multiprocessing.pool.Pool`).
          Supply `-1` to use all cores available to the Process.
          Alternatively, supply a map-like callable, such as
          :meth:`python:multiprocessing.pool.Pool.map` for evaluating the
          population in parallel.
          This evaluation is carried out as ``workers(func, iterable)``.
      points : list, optional
          List of additional breakpoints. Only points lying strictly between
          `a` and `b` are used, for either ordering of the bounds.
      quadrature : {'gk21', 'gk15', 'trapezoid'}, optional
          Quadrature rule to use on subintervals.
          Options: 'gk21' (Gauss-Kronrod 21-point rule),
          'gk15' (Gauss-Kronrod 15-point rule),
          'trapezoid' (composite trapezoid rule).
          Default: 'gk21' for finite intervals and 'gk15' for (semi-)infinite
          intervals.
      full_output : bool, optional
          Return an additional ``info`` object.
      args : tuple, optional
          Extra arguments to pass to function, if any.

          .. versionadded:: 1.8.0

      Returns
      -------
      res : {float, array-like}
          Estimate for the result
      err : float
          Error estimate for the result in the given norm
      info : object
          Returned only when ``full_output=True``.
          Info object with the attributes:

          success : bool
              Whether integration reached target precision.
          status : int
              Indicator for convergence, success (0),
              failure (1), failure due to rounding error (2), and
              failure due to non-finite values (3).
          message : str
              Description of the termination reason.
          neval : int
              Number of function evaluations.
          intervals : ndarray, shape (num_intervals, 2)
              Start and end points of subdivision intervals. For
              (semi-)infinite ranges these refer to the transformed
              integration variable.
          integrals : ndarray, shape (num_intervals, ...)
              Integral for each interval.
              Note that only as many values as fit in ``cache_size`` bytes
              are recorded, and the array may contain *nan* for missing
              items.
          errors : ndarray, shape (num_intervals,)
              Estimated integration error for each interval.

      Notes
      -----
      The algorithm mainly follows the implementation of QUADPACK's
      DQAG* algorithms, implementing global error control and adaptive
      subdivision.

      The algorithm here has some differences to the QUADPACK approach:

      Instead of subdividing one interval at a time, the algorithm
      subdivides N intervals with largest errors at once. This enables
      (partial) parallelization of the integration.

      The logic of subdividing "next largest" intervals first is then
      not implemented, and we rely on the above extension to avoid
      concentrating on "small" intervals only.

      The Wynn epsilon table extrapolation is not used (QUADPACK uses it
      for infinite intervals). This is because the algorithm here is
      supposed to work on vector-valued functions, in a user-specified
      norm, and the extension of the epsilon algorithm to this case does
      not appear to be widely agreed. For max-norm, using elementwise
      Wynn epsilon could be possible, but we do not do this here with
      the hope that the epsilon extrapolation is mainly useful in
      special cases.

      References
      ----------
      [1] R. Piessens, E. de Doncker, QUADPACK (1983).

      Examples
      --------
      We can compute integrations of a vector-valued function:

      >>> from scipy.integrate import quad_vec
      >>> import numpy as np
      >>> import matplotlib.pyplot as plt
      >>> alpha = np.linspace(0.0, 2.0, num=30)
      >>> f = lambda x: x**alpha
      >>> x0, x1 = 0, 2
      >>> y, err = quad_vec(f, x0, x1)
      >>> plt.plot(alpha, y)
      >>> plt.xlabel(r"$\alpha$")
      >>> plt.ylabel(r"$\int_{0}^{2} x^\alpha dx$")
      >>> plt.show()

      """
      a = float(a)
      b = float(b)

      if args:
          if not isinstance(args, tuple):
              args = (args,)

          # create a wrapped function to allow the use of map and Pool.map
          f = _FunctionWrapper(f, args)

      # Use simple transformations to deal with integrals over infinite
      # intervals. The recursive calls integrate the transformed function
      # over a finite t-range; `f` is already wrapped, so `args` is not
      # forwarded again.
      kwargs = dict(epsabs=epsabs,
                    epsrel=epsrel,
                    norm=norm,
                    cache_size=cache_size,
                    limit=limit,
                    workers=workers,
                    points=points,
                    quadrature='gk15' if quadrature is None else quadrature,
                    full_output=full_output)
      if np.isfinite(a) and np.isinf(b):
          f2 = SemiInfiniteFunc(f, start=a, infty=b)
          if points is not None:
              kwargs['points'] = tuple(f2.get_t(xp) for xp in points)
          return quad_vec(f2, 0, 1, **kwargs)
      elif np.isfinite(b) and np.isinf(a):
          # Integrate from b towards a, then flip the sign.
          f2 = SemiInfiniteFunc(f, start=b, infty=a)
          if points is not None:
              kwargs['points'] = tuple(f2.get_t(xp) for xp in points)
          res = quad_vec(f2, 0, 1, **kwargs)
          return (-res[0],) + res[1:]
      elif np.isinf(a) and np.isinf(b):
          sgn = -1 if b < a else 1

          # NB. explicitly split integral at t=0, which separates
          # the positive and negative sides
          f2 = DoubleInfiniteFunc(f)
          if points is not None:
              kwargs['points'] = (0,) + tuple(f2.get_t(xp) for xp in points)
          else:
              kwargs['points'] = (0,)

          if a != b:
              res = quad_vec(f2, -1, 1, **kwargs)
          else:
              # Same infinite bound on both ends: zero-width range.
              res = quad_vec(f2, 1, 1, **kwargs)

          return (res[0]*sgn,) + res[1:]
      elif not (np.isfinite(a) and np.isfinite(b)):
          raise ValueError("invalid integration bounds a={}, b={}".format(a, b))

      norm_funcs = {
          None: _max_norm,
          'max': _max_norm,
          '2': np.linalg.norm
      }
      if callable(norm):
          norm_func = norm
      else:
          norm_func = norm_funcs[norm]

      # Maximum number of intervals subdivided per sweep, and the minimum
      # number of intervals required before convergence is tested.
      parallel_count = 128
      min_intervals = 2

      try:
          _quadrature = {None: _quadrature_gk21,
                         'gk21': _quadrature_gk21,
                         'gk15': _quadrature_gk15,
                         'trapz': _quadrature_trapezoid,  # alias for backcompat
                         'trapezoid': _quadrature_trapezoid}[quadrature]
      except KeyError as e:
          raise ValueError("unknown quadrature {!r}".format(quadrature)) from e

      # Initial interval set
      if points is None:
          initial_intervals = [(a, b)]
      else:
          # Visit breakpoints in the direction of integration so that the
          # initial subintervals chain from `a` to `b` even when b < a.
          # Points outside the open range between the bounds, and repeats,
          # are skipped.
          lo, hi = min(a, b), max(a, b)
          prev = a
          initial_intervals = []
          for p in sorted(points, reverse=(b < a)):
              p = float(p)
              if not (lo < p < hi) or p == prev:
                  continue
              initial_intervals.append((prev, p))
              prev = p
          initial_intervals.append((prev, b))

      global_integral = None
      global_error = None
      rounding_error = None
      interval_cache = None
      intervals = []
      neval = 0

      for x1, x2 in initial_intervals:
          ig, err, rnd = _quadrature(x1, x2, f, norm_func)
          neval += _quadrature.num_eval

          if global_integral is None:
              if isinstance(ig, (float, complex)):
                  # Specialize for scalars
                  if norm_func in (_max_norm, np.linalg.norm):
                      norm_func = abs

              global_integral = ig
              global_error = float(err)
              rounding_error = float(rnd)

              # Number of per-interval integrals that fit in cache_size bytes.
              cache_count = cache_size // _get_sizeof(ig)
              interval_cache = LRUDict(cache_count)
          else:
              global_integral += ig
              global_error += err
              rounding_error += rnd

          # Store a copy: global_integral may alias the first `ig` and is
          # updated in place with += below.
          interval_cache[(x1, x2)] = copy.copy(ig)
          # Negated error so the min-heap yields the largest error first.
          intervals.append((-err, x1, x2))

      heapq.heapify(intervals)

      CONVERGED = 0
      NOT_CONVERGED = 1
      ROUNDING_ERROR = 2
      NOT_A_NUMBER = 3

      status_msg = {
          CONVERGED: "Target precision reached.",
          NOT_CONVERGED: "Target precision not reached.",
          ROUNDING_ERROR: "Target precision could not be reached due to rounding error.",
          NOT_A_NUMBER: "Non-finite values encountered."
      }

      # Process intervals
      with MapWrapper(workers) as mapwrapper:
          ier = NOT_CONVERGED

          while intervals and len(intervals) < limit:
              # Select intervals with largest errors for subdivision
              tol = max(epsabs, epsrel*norm_func(global_integral))

              to_process = []
              err_sum = 0

              for j in range(parallel_count):
                  if not intervals:
                      break

                  if j > 0 and err_sum > global_error - tol/8:
                      # avoid unnecessary parallel splitting
                      break

                  interval = heapq.heappop(intervals)

                  neg_old_err, left, right = interval
                  old_int = interval_cache.pop((left, right), None)
                  to_process.append(((-neg_old_err, left, right, old_int), f, norm_func, _quadrature))
                  err_sum += -neg_old_err

              # Subdivide intervals
              for dint, derr, dround_err, subint, dneval in mapwrapper(_subdivide_interval, to_process):
                  neval += dneval
                  global_integral += dint
                  global_error += derr
                  rounding_error += dround_err
                  for x in subint:
                      x1, x2, ig, err = x
                      interval_cache[(x1, x2)] = ig
                      heapq.heappush(intervals, (-err, x1, x2))

              # Termination check
              if len(intervals) >= min_intervals:
                  tol = max(epsabs, epsrel*norm_func(global_integral))
                  if global_error < tol/8:
                      ier = CONVERGED
                      break
                  if global_error < rounding_error:
                      ier = ROUNDING_ERROR
                      break

              if not (np.isfinite(global_error) and np.isfinite(rounding_error)):
                  ier = NOT_A_NUMBER
                  break

      res = global_integral
      err = global_error + rounding_error

      if full_output:
          res_arr = np.asarray(res)
          # Placeholder for intervals whose integral was evicted from the cache.
          dummy = np.full(res_arr.shape, np.nan, dtype=res_arr.dtype)
          integrals = np.array([interval_cache.get((z[1], z[2]), dummy)
                                for z in intervals], dtype=res_arr.dtype)
          errors = np.array([-z[0] for z in intervals])
          intervals = np.array([[z[1], z[2]] for z in intervals])

          info = _Bunch(neval=neval,
                        success=(ier == CONVERGED),
                        status=ier,
                        message=status_msg[ier],
                        intervals=intervals,
                        integrals=integrals,
                        errors=errors)
          return (res, err, info)
      else:
          return (res, err)
  359. def _subdivide_interval(args):
  360. interval, f, norm_func, _quadrature = args
  361. old_err, a, b, old_int = interval
  362. c = 0.5 * (a + b)
  363. # Left-hand side
  364. if getattr(_quadrature, 'cache_size', 0) > 0:
  365. f = functools.lru_cache(_quadrature.cache_size)(f)
  366. s1, err1, round1 = _quadrature(a, c, f, norm_func)
  367. dneval = _quadrature.num_eval
  368. s2, err2, round2 = _quadrature(c, b, f, norm_func)
  369. dneval += _quadrature.num_eval
  370. if old_int is None:
  371. old_int, _, _ = _quadrature(a, b, f, norm_func)
  372. dneval += _quadrature.num_eval
  373. if getattr(_quadrature, 'cache_size', 0) > 0:
  374. dneval = f.cache_info().misses
  375. dint = s1 + s2 - old_int
  376. derr = err1 + err2 - old_err
  377. dround_err = round1 + round2
  378. subintervals = ((a, c, s1, err1), (c, b, s2, err2))
  379. return dint, derr, dround_err, subintervals, dneval
  380. def _quadrature_trapezoid(x1, x2, f, norm_func):
  381. """
  382. Composite trapezoid quadrature
  383. """
  384. x3 = 0.5*(x1 + x2)
  385. f1 = f(x1)
  386. f2 = f(x2)
  387. f3 = f(x3)
  388. s2 = 0.25 * (x2 - x1) * (f1 + 2*f3 + f2)
  389. round_err = 0.25 * abs(x2 - x1) * (float(norm_func(f1))
  390. + 2*float(norm_func(f3))
  391. + float(norm_func(f2))) * 2e-16
  392. s1 = 0.5 * (x2 - x1) * (f1 + f2)
  393. err = 1/3 * float(norm_func(s1 - s2))
  394. return s2, err, round_err
  395. _quadrature_trapezoid.cache_size = 3 * 3
  396. _quadrature_trapezoid.num_eval = 3
  397. def _quadrature_gk(a, b, f, norm_func, x, w, v):
  398. """
  399. Generic Gauss-Kronrod quadrature
  400. """
  401. fv = [0.0]*len(x)
  402. c = 0.5 * (a + b)
  403. h = 0.5 * (b - a)
  404. # Gauss-Kronrod
  405. s_k = 0.0
  406. s_k_abs = 0.0
  407. for i in range(len(x)):
  408. ff = f(c + h*x[i])
  409. fv[i] = ff
  410. vv = v[i]
  411. # \int f(x)
  412. s_k += vv * ff
  413. # \int |f(x)|
  414. s_k_abs += vv * abs(ff)
  415. # Gauss
  416. s_g = 0.0
  417. for i in range(len(w)):
  418. s_g += w[i] * fv[2*i + 1]
  419. # Quadrature of abs-deviation from average
  420. s_k_dabs = 0.0
  421. y0 = s_k / 2.0
  422. for i in range(len(x)):
  423. # \int |f(x) - y0|
  424. s_k_dabs += v[i] * abs(fv[i] - y0)
  425. # Use similar error estimation as quadpack
  426. err = float(norm_func((s_k - s_g) * h))
  427. dabs = float(norm_func(s_k_dabs * h))
  428. if dabs != 0 and err != 0:
  429. err = dabs * min(1.0, (200 * err / dabs)**1.5)
  430. eps = sys.float_info.epsilon
  431. round_err = float(norm_func(50 * eps * h * s_k_abs))
  432. if round_err > sys.float_info.min:
  433. err = max(err, round_err)
  434. return h * s_k, err, round_err
  def _quadrature_gk21(a, b, f, norm_func):
      """
      Gauss-Kronrod 21 quadrature with error estimate

      The node and weight tables are symmetric about the interval centre, so
      each is assembled from its non-negative half.
      """
      # Positive Kronrod abscissae, from the outer end inwards.
      half_x = (0.995657163025808080735527280689003,
                0.973906528517171720077964012084452,
                0.930157491355708226001207180059508,
                0.865063366688984510732096688423493,
                0.780817726586416897063717578345042,
                0.679409568299024406234327365114874,
                0.562757134668604683339000099272694,
                0.433395394129247190799265943165784,
                0.294392862701460198131126603103866,
                0.148874338981631210884826001129720)
      x = half_x + (0,) + tuple(-node for node in reversed(half_x))

      # 10-point Gauss weights (outer half; mirrored below).
      half_w = (0.066671344308688137593568809893332,
                0.149451349150580593145776339657697,
                0.219086362515982043995534934228163,
                0.269266719309996355091226921569469,
                0.295524224714752870173892994651338)
      w = half_w + half_w[::-1]

      # 21-point Kronrod weights (outer half, then the centre weight).
      half_v = (0.011694638867371874278064396062192,
                0.032558162307964727478818972459390,
                0.054755896574351996031381300244580,
                0.075039674810919952767043140916190,
                0.093125454583697605535065465083366,
                0.109387158802297641899210590325805,
                0.123491976262065851077958109831074,
                0.134709217311473325928054001771707,
                0.142775938577060080797094273138717,
                0.147739104901338491374841515972068)
      v = half_v + (0.149445554002916905664936468389821,) + half_v[::-1]

      return _quadrature_gk(a, b, f, norm_func, x, w, v)

  _quadrature_gk21.num_eval = 21
  def _quadrature_gk15(a, b, f, norm_func):
      """
      Gauss-Kronrod 15 quadrature with error estimate

      The node and weight tables are symmetric about the interval centre, so
      each is assembled from its non-negative half.
      """
      # Positive Kronrod abscissae, from the outer end inwards.
      half_x = (0.991455371120812639206854697526329,
                0.949107912342758524526189684047851,
                0.864864423359769072789712788640926,
                0.741531185599394439863864773280788,
                0.586087235467691130294144838258730,
                0.405845151377397166906606412076961,
                0.207784955007898467600689403773245)
      x = half_x + (0.0,) + tuple(-node for node in reversed(half_x))

      # 7-point Gauss weights (outer half, then the centre weight).
      half_w = (0.129484966168869693270611432679082,
                0.279705391489276667901467771423780,
                0.381830050505118944950369775488975)
      w = half_w + (0.417959183673469387755102040816327,) + half_w[::-1]

      # 15-point Kronrod weights (outer half, then the centre weight).
      half_v = (0.022935322010529224963732008058970,
                0.063092092629978553290700663189204,
                0.104790010322250183839876322541518,
                0.140653259715525918745189590510238,
                0.169004726639267902826583426598550,
                0.190350578064785409913256402421014,
                0.204432940075298892414161999234649)
      v = half_v + (0.209482141084727828012999174891714,) + half_v[::-1]

      return _quadrature_gk(a, b, f, norm_func, x, w, v)

  _quadrature_gk15.num_eval = 15