helper.py 5.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216
  1. from numbers import Number
  2. import operator
  3. import os
  4. import threading
  5. import contextlib
  6. import numpy as np
  7. # good_size is exposed (and used) from this import
  8. from .pypocketfft import good_size
  9. _config = threading.local()
  10. _cpu_count = os.cpu_count()
  11. def _iterable_of_int(x, name=None):
  12. """Convert ``x`` to an iterable sequence of int
  13. Parameters
  14. ----------
  15. x : value, or sequence of values, convertible to int
  16. name : str, optional
  17. Name of the argument being converted, only used in the error message
  18. Returns
  19. -------
  20. y : ``List[int]``
  21. """
  22. if isinstance(x, Number):
  23. x = (x,)
  24. try:
  25. x = [operator.index(a) for a in x]
  26. except TypeError as e:
  27. name = name or "value"
  28. raise ValueError("{} must be a scalar or iterable of integers"
  29. .format(name)) from e
  30. return x
  31. def _init_nd_shape_and_axes(x, shape, axes):
  32. """Handles shape and axes arguments for nd transforms"""
  33. noshape = shape is None
  34. noaxes = axes is None
  35. if not noaxes:
  36. axes = _iterable_of_int(axes, 'axes')
  37. axes = [a + x.ndim if a < 0 else a for a in axes]
  38. if any(a >= x.ndim or a < 0 for a in axes):
  39. raise ValueError("axes exceeds dimensionality of input")
  40. if len(set(axes)) != len(axes):
  41. raise ValueError("all axes must be unique")
  42. if not noshape:
  43. shape = _iterable_of_int(shape, 'shape')
  44. if axes and len(axes) != len(shape):
  45. raise ValueError("when given, axes and shape arguments"
  46. " have to be of the same length")
  47. if noaxes:
  48. if len(shape) > x.ndim:
  49. raise ValueError("shape requires more axes than are present")
  50. axes = range(x.ndim - len(shape), x.ndim)
  51. shape = [x.shape[a] if s == -1 else s for s, a in zip(shape, axes)]
  52. elif noaxes:
  53. shape = list(x.shape)
  54. axes = range(x.ndim)
  55. else:
  56. shape = [x.shape[a] for a in axes]
  57. if any(s < 1 for s in shape):
  58. raise ValueError(
  59. "invalid number of data points ({0}) specified".format(shape))
  60. return shape, axes
  61. def _asfarray(x):
  62. """
  63. Convert to array with floating or complex dtype.
  64. float16 values are also promoted to float32.
  65. """
  66. if not hasattr(x, "dtype"):
  67. x = np.asarray(x)
  68. if x.dtype == np.float16:
  69. return np.asarray(x, np.float32)
  70. elif x.dtype.kind not in 'fc':
  71. return np.asarray(x, np.float64)
  72. # Require native byte order
  73. dtype = x.dtype.newbyteorder('=')
  74. # Always align input
  75. copy = not x.flags['ALIGNED']
  76. return np.array(x, dtype=dtype, copy=copy)
  77. def _datacopied(arr, original):
  78. """
  79. Strict check for `arr` not sharing any data with `original`,
  80. under the assumption that arr = asarray(original)
  81. """
  82. if arr is original:
  83. return False
  84. if not isinstance(original, np.ndarray) and hasattr(original, '__array__'):
  85. return False
  86. return arr.base is None
  87. def _fix_shape(x, shape, axes):
  88. """Internal auxiliary function for _raw_fft, _raw_fftnd."""
  89. must_copy = False
  90. # Build an nd slice with the dimensions to be read from x
  91. index = [slice(None)]*x.ndim
  92. for n, ax in zip(shape, axes):
  93. if x.shape[ax] >= n:
  94. index[ax] = slice(0, n)
  95. else:
  96. index[ax] = slice(0, x.shape[ax])
  97. must_copy = True
  98. index = tuple(index)
  99. if not must_copy:
  100. return x[index], False
  101. s = list(x.shape)
  102. for n, axis in zip(shape, axes):
  103. s[axis] = n
  104. z = np.zeros(s, x.dtype)
  105. z[index] = x[index]
  106. return z, True
  107. def _fix_shape_1d(x, n, axis):
  108. if n < 1:
  109. raise ValueError(
  110. "invalid number of data points ({0}) specified".format(n))
  111. return _fix_shape(x, (n,), (axis,))
  112. _NORM_MAP = {None: 0, 'backward': 0, 'ortho': 1, 'forward': 2}
  113. def _normalization(norm, forward):
  114. """Returns the pypocketfft normalization mode from the norm argument"""
  115. try:
  116. inorm = _NORM_MAP[norm]
  117. return inorm if forward else (2 - inorm)
  118. except KeyError:
  119. raise ValueError(
  120. f'Invalid norm value {norm!r}, should '
  121. 'be "backward", "ortho" or "forward"') from None
  122. def _workers(workers):
  123. if workers is None:
  124. return getattr(_config, 'default_workers', 1)
  125. if workers < 0:
  126. if workers >= -_cpu_count:
  127. workers += 1 + _cpu_count
  128. else:
  129. raise ValueError("workers value out of range; got {}, must not be"
  130. " less than {}".format(workers, -_cpu_count))
  131. elif workers == 0:
  132. raise ValueError("workers must not be zero")
  133. return workers
  134. @contextlib.contextmanager
  135. def set_workers(workers):
  136. """Context manager for the default number of workers used in `scipy.fft`
  137. Parameters
  138. ----------
  139. workers : int
  140. The default number of workers to use
  141. Examples
  142. --------
  143. >>> import numpy as np
  144. >>> from scipy import fft, signal
  145. >>> rng = np.random.default_rng()
  146. >>> x = rng.standard_normal((128, 64))
  147. >>> with fft.set_workers(4):
  148. ... y = signal.fftconvolve(x, x)
  149. """
  150. old_workers = get_workers()
  151. _config.default_workers = _workers(operator.index(workers))
  152. try:
  153. yield
  154. finally:
  155. _config.default_workers = old_workers
  156. def get_workers():
  157. """Returns the default number of workers within the current context
  158. Examples
  159. --------
  160. >>> from scipy import fft
  161. >>> fft.get_workers()
  162. 1
  163. >>> with fft.set_workers(4):
  164. ... fft.get_workers()
  165. 4
  166. """
  167. return getattr(_config, 'default_workers', 1)