123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191 |
- import scipy._lib.uarray as ua
- from . import _fftlog
- from . import _pocketfft
class _ScipyBackend:
    """Default ``uarray`` backend that routes fft calls to SciPy's own code.

    Notes
    -----
    The domain is ``numpy.scipy.fft`` rather than ``scipy``. ``uarray`` treats
    the domain as a hierarchy, so a user can install one backend for
    ``numpy`` and have it also cover ``numpy.scipy.fft``.
    """

    __ua_domain__ = "numpy.scipy.fft"

    @staticmethod
    def __ua_function__(method, args, kwargs):
        # Look the requested function up by name: first in the pocketfft
        # wrappers, then in fftlog. If neither module provides it, signal
        # ``NotImplemented`` so uarray can move on to another backend.
        wanted = method.__name__
        for provider in (_pocketfft, _fftlog):
            impl = getattr(provider, wanted, None)
            if impl is not None:
                return impl(*args, **kwargs)
        return NotImplemented
# Maps the backend names accepted by the public API (e.g. ``'scipy'``) to
# their implementations; looked up by ``_backend_from_arg``.
_named_backends = dict(scipy=_ScipyBackend)
def _backend_from_arg(backend):
    """Map a backend name to its implementation and validate the result.

    Parameters
    ----------
    backend : str or object
        Either the name of a known backend (a key of ``_named_backends``,
        e.g. ``'scipy'``) or an object implementing the ``uarray`` backend
        protocol.

    Returns
    -------
    object
        The resolved backend, whose ``__ua_domain__`` is
        ``"numpy.scipy.fft"``.

    Raises
    ------
    ValueError
        If ``backend`` is an unknown name, or if the resolved object does not
        declare ``__ua_domain__ == "numpy.scipy.fft"``. This includes objects
        that have no ``__ua_domain__`` attribute at all.
    """
    if isinstance(backend, str):
        try:
            backend = _named_backends[backend]
        except KeyError as e:
            raise ValueError('Unknown backend {}'.format(backend)) from e

    # Use getattr so that objects which do not implement the uarray protocol
    # are rejected with the documented ValueError, instead of an
    # AttributeError leaking out of the public set/register/skip functions.
    if getattr(backend, '__ua_domain__', None) != 'numpy.scipy.fft':
        raise ValueError('Backend does not implement "numpy.scipy.fft"')

    return backend
def set_global_backend(backend, coerce=False, only=False, try_last=False):
    """Set the global fft backend.

    This utility replaces the default backend for permanent use. It is
    tried automatically in the list of backends, unless the ``only`` flag is
    set on a backend. Unless ``try_last=True`` is passed, it is the first
    backend tried outside the :obj:`set_backend` context manager.

    Parameters
    ----------
    backend : {object, 'scipy'}
        The backend to use.
        Can either be a ``str`` containing the name of a known backend
        {'scipy'} or an object that implements the uarray protocol.
    coerce : bool, optional
        Whether to coerce input types when trying this backend.
    only : bool, optional
        If ``True``, no more backends will be tried if this fails.
        Implied by ``coerce=True``.
    try_last : bool, optional
        If ``True``, the global backend is tried after registered backends.

    Raises
    ------
    ValueError
        If the backend does not implement ``numpy.scipy.fft``.

    Notes
    -----
    This overwrites the previously set global backend, which by default is
    the SciPy implementation.

    Examples
    --------
    We can set the global fft backend:

    >>> from scipy.fft import fft, set_global_backend
    >>> set_global_backend("scipy")  # Sets global backend. "scipy" is the default backend.
    >>> fft([1])  # Calls the global backend
    array([1.+0.j])
    """
    validated = _backend_from_arg(backend)
    ua.set_global_backend(
        validated,
        coerce=coerce,
        only=only,
        try_last=try_last,
    )
def register_backend(backend):
    """Register a backend for permanent use.

    Registered backends are normally tried after the global backend. If the
    global backend was installed with ``try_last=True``, registered backends
    are tried before it instead. SciPy's default global backend is installed
    that way.

    Parameters
    ----------
    backend : {object, 'scipy'}
        The backend to use.
        Can either be a ``str`` containing the name of a known backend
        {'scipy'} or an object that implements the uarray protocol.

    Raises
    ------
    ValueError
        If the backend does not implement ``numpy.scipy.fft``.

    Examples
    --------
    We can register a new fft backend:

    >>> from scipy.fft import fft, register_backend, set_global_backend
    >>> class NoopBackend:  # Define an invalid Backend
    ...     __ua_domain__ = "numpy.scipy.fft"
    ...     def __ua_function__(self, func, args, kwargs):
    ...          return NotImplemented
    >>> set_global_backend(NoopBackend())  # Set the invalid backend as global
    >>> register_backend("scipy")  # Register a new backend
    >>> fft([1])  # The registered backend is called because the global backend returns `NotImplemented`
    array([1.+0.j])
    >>> set_global_backend("scipy")  # Restore global backend to default
    """
    ua.register_backend(_backend_from_arg(backend))
def set_backend(backend, coerce=False, only=False):
    """Context manager that sets the backend within a fixed scope.

    On entering the ``with`` statement, the given backend is added to the
    list of available backends with the highest priority. On exit, the
    backends are reset to the state they were in before entering the scope.

    Parameters
    ----------
    backend : {object, 'scipy'}
        The backend to use.
        Can either be a ``str`` containing the name of a known backend
        {'scipy'} or an object that implements the uarray protocol.
    coerce : bool, optional
        Whether to allow expensive conversions for the ``x`` parameter, e.g.
        copying a NumPy array to the GPU for a CuPy backend. Implies ``only``.
    only : bool, optional
        If ``True`` and this backend returns ``NotImplemented``, a
        ``BackendNotImplementedError`` is raised immediately, and any
        lower-priority backends are ignored.

    Examples
    --------
    >>> import scipy.fft as fft
    >>> with fft.set_backend('scipy', only=True):
    ...     fft.fft([1])  # Always calls the scipy implementation
    array([1.+0.j])
    """
    resolved = _backend_from_arg(backend)
    context = ua.set_backend(resolved, coerce=coerce, only=only)
    return context
def skip_backend(backend):
    """Context manager that skips a backend within a fixed scope.

    Inside the ``with`` statement, the given backend is not called. This
    applies to backends set both locally and globally. Once the scope is
    exited, the backend is considered again.

    Parameters
    ----------
    backend : {object, 'scipy'}
        The backend to skip.
        Can either be a ``str`` containing the name of a known backend
        {'scipy'} or an object that implements the uarray protocol.

    Examples
    --------
    >>> import scipy.fft as fft
    >>> fft.fft([1])  # Calls default SciPy backend
    array([1.+0.j])
    >>> with fft.skip_backend('scipy'):  # We explicitly skip the SciPy backend
    ...     fft.fft([1])                 # leaving no implementation available
    Traceback (most recent call last):
        ...
    BackendNotImplementedError: No selected backends had an implementation ...
    """
    resolved = _backend_from_arg(backend)
    context = ua.skip_backend(resolved)
    return context
# Install SciPy's own implementation as the global backend when this module
# is imported. With ``try_last=True`` it is tried after any backend added via
# ``register_backend``, so registered backends take precedence over it.
set_global_backend(backend='scipy', try_last=True)
|