_arraytools.py 7.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241
  1. """
Functions for acting on an axis of an array.
  3. """
  4. import numpy as np
  5. def axis_slice(a, start=None, stop=None, step=None, axis=-1):
  6. """Take a slice along axis 'axis' from 'a'.
  7. Parameters
  8. ----------
  9. a : numpy.ndarray
  10. The array to be sliced.
  11. start, stop, step : int or None
  12. The slice parameters.
  13. axis : int, optional
  14. The axis of `a` to be sliced.
  15. Examples
  16. --------
  17. >>> a = array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
  18. >>> axis_slice(a, start=0, stop=1, axis=1)
  19. array([[1],
  20. [4],
  21. [7]])
  22. >>> axis_slice(a, start=1, axis=0)
  23. array([[4, 5, 6],
  24. [7, 8, 9]])
  25. Notes
  26. -----
  27. The keyword arguments start, stop and step are used by calling
  28. slice(start, stop, step). This implies axis_slice() does not
  29. handle its arguments the exactly the same as indexing. To select
  30. a single index k, for example, use
  31. axis_slice(a, start=k, stop=k+1)
  32. In this case, the length of the axis 'axis' in the result will
  33. be 1; the trivial dimension is not removed. (Use numpy.squeeze()
  34. to remove trivial axes.)
  35. """
  36. a_slice = [slice(None)] * a.ndim
  37. a_slice[axis] = slice(start, stop, step)
  38. b = a[tuple(a_slice)]
  39. return b
  40. def axis_reverse(a, axis=-1):
  41. """Reverse the 1-D slices of `a` along axis `axis`.
  42. Returns axis_slice(a, step=-1, axis=axis).
  43. """
  44. return axis_slice(a, step=-1, axis=axis)
  45. def odd_ext(x, n, axis=-1):
  46. """
  47. Odd extension at the boundaries of an array
  48. Generate a new ndarray by making an odd extension of `x` along an axis.
  49. Parameters
  50. ----------
  51. x : ndarray
  52. The array to be extended.
  53. n : int
  54. The number of elements by which to extend `x` at each end of the axis.
  55. axis : int, optional
  56. The axis along which to extend `x`. Default is -1.
  57. Examples
  58. --------
  59. >>> from scipy.signal._arraytools import odd_ext
  60. >>> a = np.array([[1, 2, 3, 4, 5], [0, 1, 4, 9, 16]])
  61. >>> odd_ext(a, 2)
  62. array([[-1, 0, 1, 2, 3, 4, 5, 6, 7],
  63. [-4, -1, 0, 1, 4, 9, 16, 23, 28]])
  64. Odd extension is a "180 degree rotation" at the endpoints of the original
  65. array:
  66. >>> t = np.linspace(0, 1.5, 100)
  67. >>> a = 0.9 * np.sin(2 * np.pi * t**2)
  68. >>> b = odd_ext(a, 40)
  69. >>> import matplotlib.pyplot as plt
  70. >>> plt.plot(arange(-40, 140), b, 'b', lw=1, label='odd extension')
  71. >>> plt.plot(arange(100), a, 'r', lw=2, label='original')
  72. >>> plt.legend(loc='best')
  73. >>> plt.show()
  74. """
  75. if n < 1:
  76. return x
  77. if n > x.shape[axis] - 1:
  78. raise ValueError(("The extension length n (%d) is too big. " +
  79. "It must not exceed x.shape[axis]-1, which is %d.")
  80. % (n, x.shape[axis] - 1))
  81. left_end = axis_slice(x, start=0, stop=1, axis=axis)
  82. left_ext = axis_slice(x, start=n, stop=0, step=-1, axis=axis)
  83. right_end = axis_slice(x, start=-1, axis=axis)
  84. right_ext = axis_slice(x, start=-2, stop=-(n + 2), step=-1, axis=axis)
  85. ext = np.concatenate((2 * left_end - left_ext,
  86. x,
  87. 2 * right_end - right_ext),
  88. axis=axis)
  89. return ext
  90. def even_ext(x, n, axis=-1):
  91. """
  92. Even extension at the boundaries of an array
  93. Generate a new ndarray by making an even extension of `x` along an axis.
  94. Parameters
  95. ----------
  96. x : ndarray
  97. The array to be extended.
  98. n : int
  99. The number of elements by which to extend `x` at each end of the axis.
  100. axis : int, optional
  101. The axis along which to extend `x`. Default is -1.
  102. Examples
  103. --------
  104. >>> from scipy.signal._arraytools import even_ext
  105. >>> a = np.array([[1, 2, 3, 4, 5], [0, 1, 4, 9, 16]])
  106. >>> even_ext(a, 2)
  107. array([[ 3, 2, 1, 2, 3, 4, 5, 4, 3],
  108. [ 4, 1, 0, 1, 4, 9, 16, 9, 4]])
  109. Even extension is a "mirror image" at the boundaries of the original array:
  110. >>> t = np.linspace(0, 1.5, 100)
  111. >>> a = 0.9 * np.sin(2 * np.pi * t**2)
  112. >>> b = even_ext(a, 40)
  113. >>> import matplotlib.pyplot as plt
  114. >>> plt.plot(arange(-40, 140), b, 'b', lw=1, label='even extension')
  115. >>> plt.plot(arange(100), a, 'r', lw=2, label='original')
  116. >>> plt.legend(loc='best')
  117. >>> plt.show()
  118. """
  119. if n < 1:
  120. return x
  121. if n > x.shape[axis] - 1:
  122. raise ValueError(("The extension length n (%d) is too big. " +
  123. "It must not exceed x.shape[axis]-1, which is %d.")
  124. % (n, x.shape[axis] - 1))
  125. left_ext = axis_slice(x, start=n, stop=0, step=-1, axis=axis)
  126. right_ext = axis_slice(x, start=-2, stop=-(n + 2), step=-1, axis=axis)
  127. ext = np.concatenate((left_ext,
  128. x,
  129. right_ext),
  130. axis=axis)
  131. return ext
  132. def const_ext(x, n, axis=-1):
  133. """
  134. Constant extension at the boundaries of an array
  135. Generate a new ndarray that is a constant extension of `x` along an axis.
  136. The extension repeats the values at the first and last element of
  137. the axis.
  138. Parameters
  139. ----------
  140. x : ndarray
  141. The array to be extended.
  142. n : int
  143. The number of elements by which to extend `x` at each end of the axis.
  144. axis : int, optional
  145. The axis along which to extend `x`. Default is -1.
  146. Examples
  147. --------
  148. >>> from scipy.signal._arraytools import const_ext
  149. >>> a = np.array([[1, 2, 3, 4, 5], [0, 1, 4, 9, 16]])
  150. >>> const_ext(a, 2)
  151. array([[ 1, 1, 1, 2, 3, 4, 5, 5, 5],
  152. [ 0, 0, 0, 1, 4, 9, 16, 16, 16]])
  153. Constant extension continues with the same values as the endpoints of the
  154. array:
  155. >>> t = np.linspace(0, 1.5, 100)
  156. >>> a = 0.9 * np.sin(2 * np.pi * t**2)
  157. >>> b = const_ext(a, 40)
  158. >>> import matplotlib.pyplot as plt
  159. >>> plt.plot(arange(-40, 140), b, 'b', lw=1, label='constant extension')
  160. >>> plt.plot(arange(100), a, 'r', lw=2, label='original')
  161. >>> plt.legend(loc='best')
  162. >>> plt.show()
  163. """
  164. if n < 1:
  165. return x
  166. left_end = axis_slice(x, start=0, stop=1, axis=axis)
  167. ones_shape = [1] * x.ndim
  168. ones_shape[axis] = n
  169. ones = np.ones(ones_shape, dtype=x.dtype)
  170. left_ext = ones * left_end
  171. right_end = axis_slice(x, start=-1, axis=axis)
  172. right_ext = ones * right_end
  173. ext = np.concatenate((left_ext,
  174. x,
  175. right_ext),
  176. axis=axis)
  177. return ext
  178. def zero_ext(x, n, axis=-1):
  179. """
  180. Zero padding at the boundaries of an array
  181. Generate a new ndarray that is a zero-padded extension of `x` along
  182. an axis.
  183. Parameters
  184. ----------
  185. x : ndarray
  186. The array to be extended.
  187. n : int
  188. The number of elements by which to extend `x` at each end of the
  189. axis.
  190. axis : int, optional
  191. The axis along which to extend `x`. Default is -1.
  192. Examples
  193. --------
  194. >>> from scipy.signal._arraytools import zero_ext
  195. >>> a = np.array([[1, 2, 3, 4, 5], [0, 1, 4, 9, 16]])
  196. >>> zero_ext(a, 2)
  197. array([[ 0, 0, 1, 2, 3, 4, 5, 0, 0],
  198. [ 0, 0, 0, 1, 4, 9, 16, 0, 0]])
  199. """
  200. if n < 1:
  201. return x
  202. zeros_shape = list(x.shape)
  203. zeros_shape[axis] = n
  204. zeros = np.zeros(zeros_shape, dtype=x.dtype)
  205. ext = np.concatenate((zeros, x, zeros), axis=axis)
  206. return ext