123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145 |
- from numpy import arange, newaxis, hstack, prod, array
def _central_diff_weights(Np, ndiv=1):
    """
    Return weights for an Np-point central derivative.

    Assumes equally-spaced function points. If the weights are in the
    vector ``w`` and ``ho = Np // 2``, the ``ndiv``-th derivative is
    approximated by

        (w[0] * f(x - ho*dx) + ... + w[-1] * f(x + ho*dx)) / dx**ndiv

    Parameters
    ----------
    Np : int
        Number of points for the central derivative. Must be odd and at
        least ``ndiv + 1``.
    ndiv : int, optional
        Order of the derivative (number of divisions). Must be
        non-negative. Default is 1.

    Returns
    -------
    w : ndarray
        Weights for an Np-point central derivative. Its size is `Np`.

    Raises
    ------
    ValueError
        If `ndiv` is negative, if ``Np < ndiv + 1``, or if `Np` is even.

    Notes
    -----
    Can be inaccurate for a large number of points, because the
    Vandermonde system solved internally becomes increasingly
    ill-conditioned as `Np` grows.

    Examples
    --------
    We can calculate a derivative value of a function.

    >>> def f(x):
    ...     return 2 * x**2 + 3
    >>> x = 3.0  # derivative point
    >>> h = 0.1  # differential step
    >>> Np = 3  # point number for central derivative
    >>> weights = _central_diff_weights(Np)  # weights for first derivative
    >>> vals = [f(x + (i - Np // 2) * h) for i in range(Np)]
    >>> deriv = sum(w * v for (w, v) in zip(weights, vals)) / h
    >>> round(float(deriv), 6)
    12.0

    This matches the analytical solution: f'(x) = 4x, so f'(3) = 12.
    A 3-point central difference is exact for quadratics, up to rounding.

    References
    ----------
    .. [1] https://en.wikipedia.org/wiki/Finite_difference
    """
    if ndiv < 0:
        # A negative index would silently select an unrelated row below.
        raise ValueError("The derivative order must be non-negative.")
    if Np < ndiv + 1:
        raise ValueError(
            "Number of points must be at least the derivative order + 1."
        )
    if Np % 2 == 0:
        raise ValueError("The number of points must be odd.")
    from scipy import linalg

    ho = Np >> 1
    # Stencil offsets -ho, ..., ho (in units of dx) as a float column vector.
    x = arange(-ho, ho + 1.0)[:, newaxis]
    # Vandermonde matrix X[i, k] = x[i]**k, built with one broadcast instead
    # of Np successive hstack calls (each of which copied all of X).
    X = x ** arange(Np)
    # Requiring the stencil to be exact on 1, x, ..., x**(Np-1) gives
    #     sum_i w[i] * x[i]**k == (ndiv! if k == ndiv else 0),
    # i.e. X.T @ w == ndiv! * e_ndiv. Solving this directly is cheaper and
    # more accurate than forming inv(X) just to read one of its rows.
    rhs = (arange(Np) == ndiv) * float(prod(arange(1, ndiv + 1), axis=0))
    return linalg.solve(X.T, rhs)
# Hand-tabulated central-difference stencils for the common cases, keyed by
# (derivative order n, number of points). Each entry is
# (integer numerators, common denominator) of the weights.
_CENTRAL_DIFF_STENCILS = {
    (1, 3): ((-1, 0, 1), 2.0),
    (1, 5): ((1, -8, 0, 8, -1), 12.0),
    (1, 7): ((-1, 9, -45, 0, 45, -9, 1), 60.0),
    (1, 9): ((3, -32, 168, -672, 0, 672, -168, 32, -3), 840.0),
    (2, 3): ((1, -2, 1), 1.0),
    (2, 5): ((-1, 16, -30, 16, -1), 12.0),
    (2, 7): ((2, -27, 270, -490, 270, -27, 2), 180.0),
    (2, 9): ((-9, 128, -1008, 8064, -14350, 8064, -1008, 128, -9), 5040.0),
}


def _derivative(func, x0, dx=1.0, n=1, args=(), order=3):
    """
    Find the nth derivative of a function at a point.

    The derivative is estimated with an `order`-point central difference
    formula whose sample points are spaced `dx` apart and centred on `x0`.

    Parameters
    ----------
    func : function
        Input function, called as ``func(x, *args)``.
    x0 : float
        The point at which the nth derivative is found.
    dx : float, optional
        Spacing between consecutive sample points. Default is 1.0.
    n : int, optional
        Order of the derivative. Default is 1.
    args : tuple, optional
        Extra positional arguments forwarded to `func` after the point.
    order : int, optional
        Number of points to use; must be odd and at least ``n + 1``.
        Default is 3.

    Returns
    -------
    The weighted sum of the `func` samples divided by ``dx**n``, i.e. the
    finite-difference estimate of the nth derivative at `x0`.

    Raises
    ------
    ValueError
        If `order` is smaller than ``n + 1`` or is even.

    Notes
    -----
    Decreasing the step size too small can result in round-off error.
    Weights for ``n`` in {1, 2} with ``order`` in {3, 5, 7, 9} are taken from
    a hand-written table; every other combination is computed by
    `_central_diff_weights`.

    Examples
    --------
    >>> def f(x):
    ...     return x**3 + x**2
    >>> _derivative(f, 1.0, dx=1e-6)
    4.9999999999217337
    """
    if order < n + 1:
        raise ValueError(
            "'order' (the number of points used to compute the derivative), "
            "must be at least the derivative order 'n' + 1."
        )
    if order % 2 == 0:
        raise ValueError(
            "'order' (the number of points used to compute the derivative) "
            "must be odd."
        )

    stencil = _CENTRAL_DIFF_STENCILS.get((n, order))
    if stencil is None:
        weights = _central_diff_weights(order, n)
    else:
        numerators, denominator = stencil
        weights = array(numerators) / denominator

    half_width = order // 2
    # Accumulate left-to-right explicitly (not via sum()) so the
    # floating-point summation order is fixed.
    total = 0.0
    for offset, weight in zip(range(-half_width, half_width + 1), weights):
        total += weight * func(x0 + offset * dx, *args)
    return total / prod((dx,) * n, axis=0)
|