_finite_differences.py 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145
  1. from numpy import arange, newaxis, hstack, prod, array
  2. def _central_diff_weights(Np, ndiv=1):
  3. """
  4. Return weights for an Np-point central derivative.
  5. Assumes equally-spaced function points.
  6. If weights are in the vector w, then
  7. derivative is w[0] * f(x-ho*dx) + ... + w[-1] * f(x+h0*dx)
  8. Parameters
  9. ----------
  10. Np : int
  11. Number of points for the central derivative.
  12. ndiv : int, optional
  13. Number of divisions. Default is 1.
  14. Returns
  15. -------
  16. w : ndarray
  17. Weights for an Np-point central derivative. Its size is `Np`.
  18. Notes
  19. -----
  20. Can be inaccurate for a large number of points.
  21. Examples
  22. --------
  23. We can calculate a derivative value of a function.
  24. >>> def f(x):
  25. ... return 2 * x**2 + 3
  26. >>> x = 3.0 # derivative point
  27. >>> h = 0.1 # differential step
  28. >>> Np = 3 # point number for central derivative
  29. >>> weights = _central_diff_weights(Np) # weights for first derivative
  30. >>> vals = [f(x + (i - Np/2) * h) for i in range(Np)]
  31. >>> sum(w * v for (w, v) in zip(weights, vals))/h
  32. 11.79999999999998
  33. This value is close to the analytical solution:
  34. f'(x) = 4x, so f'(3) = 12
  35. References
  36. ----------
  37. .. [1] https://en.wikipedia.org/wiki/Finite_difference
  38. """
  39. if Np < ndiv + 1:
  40. raise ValueError(
  41. "Number of points must be at least the derivative order + 1."
  42. )
  43. if Np % 2 == 0:
  44. raise ValueError("The number of points must be odd.")
  45. from scipy import linalg
  46. ho = Np >> 1
  47. x = arange(-ho, ho + 1.0)
  48. x = x[:, newaxis]
  49. X = x**0.0
  50. for k in range(1, Np):
  51. X = hstack([X, x**k])
  52. w = prod(arange(1, ndiv + 1), axis=0) * linalg.inv(X)[ndiv]
  53. return w
  54. def _derivative(func, x0, dx=1.0, n=1, args=(), order=3):
  55. """
  56. Find the nth derivative of a function at a point.
  57. Given a function, use a central difference formula with spacing `dx` to
  58. compute the nth derivative at `x0`.
  59. Parameters
  60. ----------
  61. func : function
  62. Input function.
  63. x0 : float
  64. The point at which the nth derivative is found.
  65. dx : float, optional
  66. Spacing.
  67. n : int, optional
  68. Order of the derivative. Default is 1.
  69. args : tuple, optional
  70. Arguments
  71. order : int, optional
  72. Number of points to use, must be odd.
  73. Notes
  74. -----
  75. Decreasing the step size too small can result in round-off error.
  76. Examples
  77. --------
  78. >>> def f(x):
  79. ... return x**3 + x**2
  80. >>> _derivative(f, 1.0, dx=1e-6)
  81. 4.9999999999217337
  82. """
  83. if order < n + 1:
  84. raise ValueError(
  85. "'order' (the number of points used to compute the derivative), "
  86. "must be at least the derivative order 'n' + 1."
  87. )
  88. if order % 2 == 0:
  89. raise ValueError(
  90. "'order' (the number of points used to compute the derivative) "
  91. "must be odd."
  92. )
  93. # pre-computed for n=1 and 2 and low-order for speed.
  94. if n == 1:
  95. if order == 3:
  96. weights = array([-1, 0, 1]) / 2.0
  97. elif order == 5:
  98. weights = array([1, -8, 0, 8, -1]) / 12.0
  99. elif order == 7:
  100. weights = array([-1, 9, -45, 0, 45, -9, 1]) / 60.0
  101. elif order == 9:
  102. weights = array([3, -32, 168, -672, 0, 672, -168, 32, -3]) / 840.0
  103. else:
  104. weights = _central_diff_weights(order, 1)
  105. elif n == 2:
  106. if order == 3:
  107. weights = array([1, -2.0, 1])
  108. elif order == 5:
  109. weights = array([-1, 16, -30, 16, -1]) / 12.0
  110. elif order == 7:
  111. weights = array([2, -27, 270, -490, 270, -27, 2]) / 180.0
  112. elif order == 9:
  113. weights = (
  114. array([-9, 128, -1008, 8064, -14350, 8064, -1008, 128, -9])
  115. / 5040.0
  116. )
  117. else:
  118. weights = _central_diff_weights(order, 2)
  119. else:
  120. weights = _central_diff_weights(order, n)
  121. val = 0.0
  122. ho = order >> 1
  123. for k in range(order):
  124. val += weights[k] * func(x0 + (k - ho) * dx, *args)
  125. return val / prod((dx,) * n, axis=0)