12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3 |
- import typing
- import types
- import inspect
- import functools
- from . import _uarray
- import copyreg
- import pickle
- import contextlib
# Callable taking the multimethod's arguments and returning a tuple of
# ``Dispatchable`` objects.
ArgumentExtractorType = typing.Callable[..., typing.Tuple["Dispatchable", ...]]

# Callable taking ``(args, kwargs, dispatchables)`` and returning an
# ``(args, kwargs)`` pair.
ArgumentReplacerType = typing.Callable[
    [typing.Tuple, typing.Dict, typing.Tuple],
    typing.Tuple[typing.Tuple, typing.Dict],
]
- from ._uarray import ( # type: ignore
- BackendNotImplementedError,
- _Function,
- _SkipBackendContext,
- _SetBackendContext,
- _BackendState,
- )
# Names exported by ``from <this module> import *``.  The underscore-prefixed
# extension types are listed explicitly so that they are exported as well.
__all__ = [
    "set_backend",
    "set_global_backend",
    "skip_backend",
    "register_backend",
    "determine_backend",
    "determine_backend_multi",
    "clear_backends",
    "create_multimethod",
    "generate_multimethod",
    "_Function",
    "BackendNotImplementedError",
    "Dispatchable",
    "wrap_single_convertor",
    "wrap_single_convertor_instance",
    "all_of_type",
    "mark_as",
    "set_state",
    "get_state",
    "reset_state",
    "_BackendState",
    "_SkipBackendContext",
    "_SetBackendContext",
]
def unpickle_function(mod_name, qname, self_):
    """
    Look up a function from its module name and qualified name.

    Parameters
    ----------
    mod_name
        Name of the module to import.
    qname
        Dotted attribute path that is resolved starting from the module.
    self_
        If not ``None``, the resolved function is bound to this object
        with :obj:`types.MethodType` before being returned.

    Returns
    -------
    The resolved function, or a method bound to ``self_``.

    Raises
    ------
    pickle.UnpicklingError
        If an ``ImportError`` or ``AttributeError`` occurs during the lookup.
    """
    import importlib

    try:
        module = importlib.import_module(mod_name)
        # Walk the dotted path one attribute at a time, starting at the module.
        target = functools.reduce(getattr, qname.split("."), module)
        return target if self_ is None else types.MethodType(target, self_)
    except (ImportError, AttributeError) as err:
        raise pickle.UnpicklingError from err
def pickle_function(func):
    """
    Reduce ``func`` to a call to :obj:`unpickle_function`.

    The function is described by its ``__module__``, ``__qualname__`` and
    ``__self__`` attributes (``None`` for any that are missing). The
    description is resolved once up front, and pickling is refused unless
    that lookup yields ``func`` itself.

    Returns
    -------
    tuple
        ``(unpickle_function, (mod_name, qname, self_))``.

    Raises
    ------
    pickle.PicklingError
        If resolving the description does not give back the same object.
    """
    location = tuple(
        getattr(func, attribute, None)
        for attribute in ("__module__", "__qualname__", "__self__")
    )

    try:
        resolved = unpickle_function(*location)
    except pickle.UnpicklingError:
        resolved = None

    if resolved is func:
        return unpickle_function, location

    raise pickle.PicklingError(
        "Can't pickle {}: it's not the same object as {}".format(func, resolved)
    )
def pickle_state(state):
    """
    Reduce a backend state object for pickling.

    Returns the ``_BackendState._unpickle`` callable together with the
    arguments produced by ``state._pickle()``.
    """
    rebuild = _uarray._BackendState._unpickle
    return rebuild, state._pickle()
def pickle_set_backend_context(ctx):
    """
    Reduce a ``_SetBackendContext`` for pickling.

    Returns the ``_SetBackendContext`` type together with the arguments
    produced by ``ctx._pickle()``.
    """
    constructor_args = ctx._pickle()
    return _SetBackendContext, constructor_args
def pickle_skip_backend_context(ctx):
    """
    Reduce a ``_SkipBackendContext`` for pickling.

    Returns the ``_SkipBackendContext`` type together with the arguments
    produced by ``ctx._pickle()``.
    """
    constructor_args = ctx._pickle()
    return _SkipBackendContext, constructor_args
# Register the reduction functions above so that ``pickle`` and ``copy`` know
# how to handle the extension types.  The registrations are independent of
# each other.
copyreg.pickle(_SkipBackendContext, pickle_skip_backend_context)
copyreg.pickle(_SetBackendContext, pickle_set_backend_context)
copyreg.pickle(_uarray._BackendState, pickle_state)
copyreg.pickle(_Function, pickle_function)
def get_state():
    """
    Return an opaque object holding the current state of all the backends.

    The returned object can be used for synchronization between
    threads/processes.

    See Also
    --------
    set_state
        Sets the state returned by this function.
    """
    current = _uarray.get_state()
    return current
@contextlib.contextmanager
def reset_state():
    """
    Return a context manager that resets all state once exited.

    The state is captured on entry with :obj:`get_state` and handed to
    :obj:`set_state`, which takes care of restoring it on exit.

    See Also
    --------
    set_state
        Context manager that sets the backend state.
    get_state
        Gets a state to be set by this context manager.
    """
    snapshot = get_state()
    with set_state(snapshot):
        yield
@contextlib.contextmanager
def set_state(state):
    """
    A context manager that sets the state of the backends to one returned
    by :obj:`get_state`.

    The state that was active on entry is put back when the block exits,
    including when it exits through an exception.

    See Also
    --------
    get_state
        Gets a state to be set by this context manager.
    """
    previous = get_state()
    _uarray.set_state(state)
    try:
        yield
    finally:
        # NOTE(review): the extra ``True`` is passed only when restoring;
        # its meaning is defined by the extension module -- confirm there.
        _uarray.set_state(previous, True)
def create_multimethod(*args, **kwargs):
    """
    Create a decorator for generating multimethods.

    The returned decorator is applied to an argument extractor. It calls
    :obj:`generate_multimethod` with the extractor as the first argument,
    followed by the ``args`` and ``kwargs`` given here.

    See Also
    --------
    generate_multimethod
        Generates a multimethod.
    """

    def wrapper(argument_extractor):
        multimethod = generate_multimethod(argument_extractor, *args, **kwargs)
        return multimethod

    return wrapper
def generate_multimethod(
    argument_extractor: ArgumentExtractorType,
    argument_replacer: ArgumentReplacerType,
    domain: str,
    default: typing.Optional[typing.Callable] = None,
):
    """
    Generate a multimethod.

    Parameters
    ----------
    argument_extractor : ArgumentExtractorType
        A callable which extracts the dispatchable arguments. Extracted
        arguments should be marked by the :obj:`Dispatchable` class. It has
        the same signature as the desired multimethod.
    argument_replacer : ArgumentReplacerType
        A callable with the signature (args, kwargs, dispatchables), which
        should also return an (args, kwargs) pair with the dispatchables
        replaced inside the args/kwargs.
    domain : str
        A string value indicating the domain of this multimethod.
    default : Optional[Callable], optional
        The default implementation of this multimethod, where ``None``
        (the default) specifies there is no default implementation.

    Examples
    --------
    In this example, ``a`` is to be dispatched over, so we return it, while
    marking it as an ``int``. The trailing comma is needed because the args
    have to be returned as an iterable.

    >>> def override_me(a, b):
    ...   return Dispatchable(a, int),

    Next, we define the argument replacer that replaces the dispatchables
    inside args/kwargs with the supplied ones.

    >>> def override_replacer(args, kwargs, dispatchables):
    ...     return (dispatchables[0], args[1]), {}

    Next, we define the multimethod.

    >>> overridden_me = generate_multimethod(
    ...     override_me, override_replacer, "ua_examples"
    ... )

    Notice that there's no default implementation, unless you supply one.

    >>> overridden_me(1, "a")
    Traceback (most recent call last):
        ...
    uarray.BackendNotImplementedError: ...

    >>> overridden_me2 = generate_multimethod(
    ...     override_me, override_replacer, "ua_examples", default=lambda x, y: (x, y)
    ... )
    >>> overridden_me2(1, "a")
    (1, 'a')

    See Also
    --------
    uarray
        See the module documentation for how to override the method by
        creating backends.
    """
    # Only the defaults are needed here; the set of parameter names is unused.
    kw_defaults, arg_defaults, _unused_names = get_defaults(argument_extractor)

    multimethod = _Function(
        argument_extractor,
        argument_replacer,
        domain,
        arg_defaults,
        kw_defaults,
        default,
    )

    # Copy the extractor's metadata (name, docstring, ...) onto the result.
    functools.update_wrapper(multimethod, argument_extractor)
    return multimethod
def set_backend(backend, coerce=False, only=False):
    """
    A context manager that sets the preferred backend.

    The context object is cached on the backend in its ``__ua_cache__``
    dictionary, keyed by ``("set", coerce, only)``. Repeated calls with the
    same arguments return the same object.

    Parameters
    ----------
    backend
        The backend to set.
    coerce
        Whether or not to coerce to a specific backend's types. Implies ``only``.
    only
        Whether or not this should be the last backend to try.

    See Also
    --------
    skip_backend: A context manager that allows skipping of backends.
    set_global_backend: Set a single, global backend for a domain.
    """
    key = ("set", coerce, only)

    try:
        cached = backend.__ua_cache__[key]
    except AttributeError:
        # No cache on this backend yet: create an empty one.
        backend.__ua_cache__ = {}
    except KeyError:
        # The cache exists but has no entry for this combination.
        pass
    else:
        return cached

    context = _SetBackendContext(backend, coerce, only)
    backend.__ua_cache__[key] = context
    return context
def skip_backend(backend):
    """
    A context manager that allows one to skip a given backend from
    processing entirely. This allows one to use another backend's code in a
    library that is also a consumer of the same backend.

    The context object is cached on the backend in its ``__ua_cache__``
    dictionary under the key ``"skip"``.

    Parameters
    ----------
    backend
        The backend to skip.

    See Also
    --------
    set_backend: A context manager that allows setting of backends.
    set_global_backend: Set a single, global backend for a domain.
    """
    key = "skip"

    try:
        cached = backend.__ua_cache__[key]
    except AttributeError:
        # No cache on this backend yet: create an empty one.
        backend.__ua_cache__ = {}
    except KeyError:
        # The cache exists but holds no skip context.
        pass
    else:
        return cached

    context = _SkipBackendContext(backend)
    backend.__ua_cache__[key] = context
    return context
def get_defaults(f):
    """
    Collect default-value information from the signature of ``f``.

    Returns
    -------
    kw_defaults : dict
        Maps the name of every parameter that has a default to that default.
    arg_defaults : tuple
        One entry per positional-only or positional-or-keyword parameter,
        in signature order. Parameters without a default contribute
        ``inspect.Parameter.empty``.
    opts : set
        The names of all parameters.
    """
    positional_kinds = (
        inspect.Parameter.POSITIONAL_ONLY,
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
    )
    parameters = inspect.signature(f).parameters

    kw_defaults = {
        name: param.default
        for name, param in parameters.items()
        if param.default is not inspect.Parameter.empty
    }
    arg_defaults = tuple(
        param.default
        for param in parameters.values()
        if param.kind in positional_kinds
    )
    opts = set(parameters)

    return kw_defaults, arg_defaults, opts
def set_global_backend(backend, coerce=False, only=False, *, try_last=False):
    """
    Replace the default backend for permanent use.

    The global backend will be tried in the list of backends automatically,
    unless the ``only`` flag is set on a backend. This will be the first
    tried backend outside the :obj:`set_backend` context manager.

    Note that this method is not thread-safe.

    .. warning::
        We caution library authors against using this function in
        their code. We do *not* support this use-case. This function
        is meant to be used only by users themselves, or by a reference
        implementation, if one exists.

    Parameters
    ----------
    backend
        The backend to register.
    coerce : bool
        Whether to coerce input types when trying this backend.
    only : bool
        If ``True``, no more backends will be tried if this fails.
        Implied by ``coerce=True``.
    try_last : bool
        If ``True``, the global backend is tried after registered backends.

    See Also
    --------
    set_backend: A context manager that allows setting of backends.
    skip_backend: A context manager that allows skipping of backends.
    """
    # All four values are forwarded positionally to the extension module.
    _uarray.set_global_backend(backend, coerce, only, try_last)
def register_backend(backend):
    """
    Register a backend for permanent use.

    The backend will be tried in the list of backends automatically, unless
    the ``only`` flag is set on a backend.

    Note that this method is not thread-safe.

    Parameters
    ----------
    backend
        The backend to register.
    """
    # Registration itself is carried out by the extension module.
    _uarray.register_backend(backend)
def clear_backends(domain, registered=True, globals=False):
    """
    Clear registered backends.

    .. warning::
        We caution library authors against using this function in
        their code. We do *not* support this use-case. This function
        is meant to be used only by users themselves.

    .. warning::
        Do NOT use this method inside a multimethod call, or the
        program is likely to crash.

    Parameters
    ----------
    domain : Optional[str]
        The domain for which to de-register backends. ``None`` means
        de-register for all domains.
    registered : bool
        Whether or not to clear registered backends. See :obj:`register_backend`.
    globals : bool
        Whether or not to clear global backends. See :obj:`set_global_backend`.

    See Also
    --------
    register_backend : Register a backend globally.
    set_global_backend : Set a global backend.
    """
    # The three values are forwarded positionally to the extension module.
    _uarray.clear_backends(domain, registered, globals)
class Dispatchable:
    """
    A utility class which marks an argument with a specific dispatch type.

    Indexing an instance behaves like indexing the tuple ``(type, value)``.

    Attributes
    ----------
    value
        The value of the Dispatchable.
    type
        The type of the Dispatchable.
    coercible
        The flag given at construction time; ``True`` unless stated otherwise.

    Examples
    --------
    >>> x = Dispatchable(1, str)
    >>> x
    <Dispatchable: type=<class 'str'>, value=1>

    See Also
    --------
    all_of_type
        Marks all unmarked parameters of a function.
    mark_as
        Allows one to create a utility function to mark as a given type.
    """

    def __init__(self, value, dispatch_type, coercible=True):
        self.value = value
        self.type = dispatch_type
        self.coercible = coercible

    def __getitem__(self, index):
        # Expose the instance as a ``(type, value)`` pair.
        as_pair = (self.type, self.value)
        return as_pair[index]

    def __str__(self):
        class_name = type(self).__name__
        template = "<{0}: type={1!r}, value={2!r}>"
        return template.format(class_name, self.type, self.value)

    __repr__ = __str__
def mark_as(dispatch_type):
    """
    Create a utility function to mark something as a specific type.

    The result is a :obj:`functools.partial` of :obj:`Dispatchable` with
    ``dispatch_type`` already supplied as a keyword argument.

    Examples
    --------
    >>> mark_int = mark_as(int)
    >>> mark_int(1)
    <Dispatchable: type=<class 'int'>, value=1>
    """
    marker = functools.partial(Dispatchable, dispatch_type=dispatch_type)
    return marker
def all_of_type(arg_type):
    """
    Mark all unmarked arguments as a given type.

    This is a decorator factory. The decorated function's return value is
    iterated, every item that is not already a :obj:`Dispatchable` is
    wrapped in one using ``arg_type``, and the items are returned as a tuple.

    Examples
    --------
    >>> @all_of_type(str)
    ... def f(a, b):
    ...     return a, Dispatchable(b, int)
    >>> f('a', 1)
    (<Dispatchable: type=<class 'str'>, value='a'>, <Dispatchable: type=<class 'int'>, value=1>)
    """

    def outer(func):
        @functools.wraps(func)
        def inner(*args, **kwargs):
            marked = []
            for item in func(*args, **kwargs):
                if isinstance(item, Dispatchable):
                    # Already marked: keep the caller's choice of type.
                    marked.append(item)
                else:
                    marked.append(Dispatchable(item, arg_type))
            return tuple(marked)

        return inner

    return outer
def wrap_single_convertor(convert_single):
    """
    Wrap a ``__ua_convert__`` defined for a single element so that it
    applies to all elements.

    If any of them return ``NotImplemented``, the operation is assumed to be
    undefined: conversion stops there and ``NotImplemented`` is returned.

    Accepts a signature of (value, type, coerce).
    """

    @functools.wraps(convert_single)
    def __ua_convert__(dispatchables, coerce):
        results = []
        for item in dispatchables:
            value = item.value
            dispatch_type = item.type
            # Coerce only if requested and this dispatchable permits it.
            allow_coercion = coerce and item.coercible
            outcome = convert_single(value, dispatch_type, allow_coercion)
            if outcome is NotImplemented:
                return NotImplemented
            results.append(outcome)
        return results

    return __ua_convert__
def wrap_single_convertor_instance(convert_single):
    """
    Wrap a ``__ua_convert__`` defined for a single element so that it
    applies to all elements. The wrapped callable takes ``self`` as its
    first argument and passes it on to ``convert_single``.

    If any of them return ``NotImplemented``, the operation is assumed to be
    undefined: conversion stops there and ``NotImplemented`` is returned.

    Accepts a signature of (value, type, coerce).
    """

    @functools.wraps(convert_single)
    def __ua_convert__(self, dispatchables, coerce):
        results = []
        for item in dispatchables:
            value = item.value
            dispatch_type = item.type
            # Coerce only if requested and this dispatchable permits it.
            allow_coercion = coerce and item.coercible
            outcome = convert_single(self, value, dispatch_type, allow_coercion)
            if outcome is NotImplemented:
                return NotImplemented
            results.append(outcome)
        return results

    return __ua_convert__
def determine_backend(value, dispatch_type, *, domain, only=True, coerce=False):
    """Set the backend to the first active backend that supports ``value``

    This is useful for functions that call multimethods without any
    dispatchable arguments. You can use :func:`determine_backend` to ensure
    the same backend is used everywhere in a block of multimethod calls.

    Parameters
    ----------
    value
        The value being tested
    dispatch_type
        The dispatch type associated with ``value``, aka
        ":ref:`marking <MarkingGlossary>`".
    domain: string
        The domain to query for backends and set.
    coerce: bool
        Whether or not to allow coercion to the backend's types. Implies ``only``.
    only: bool
        Whether or not this should be the last backend to try.

    See Also
    --------
    set_backend: For when you know which backend to set

    Notes
    -----
    Support is determined by the ``__ua_convert__`` protocol. Backends not
    supporting the type must return ``NotImplemented`` from their
    ``__ua_convert__`` if they don't support input of that type.

    Examples
    --------
    Suppose we have two backends ``BackendA`` and ``BackendB`` each supporting
    different types, ``TypeA`` and ``TypeB``. Neither supporting the other type:

    >>> with ua.set_backend(ex.BackendA):
    ...     ex.call_multimethod(ex.TypeB(), ex.TypeB())
    Traceback (most recent call last):
        ...
    uarray.BackendNotImplementedError: ...

    Now consider a multimethod that creates a new object of ``TypeA``, or
    ``TypeB`` depending on the active backend.

    >>> with ua.set_backend(ex.BackendA), ua.set_backend(ex.BackendB):
    ...         res = ex.creation_multimethod()
    ...         ex.call_multimethod(res, ex.TypeA())
    Traceback (most recent call last):
        ...
    uarray.BackendNotImplementedError: ...

    ``res`` is an object of ``TypeB`` because ``BackendB`` is set in the
    innermost with statement. So, ``call_multimethod`` fails since the types
    don't match.

    Instead, we need to first find a backend suitable for all of our objects.

    >>> with ua.set_backend(ex.BackendA), ua.set_backend(ex.BackendB):
    ...     x = ex.TypeA()
    ...     with ua.determine_backend(x, "mark", domain="ua_examples"):
    ...         res = ex.creation_multimethod()
    ...         ex.call_multimethod(res, x)
    TypeA

    """
    # ``coerce`` is also passed as the Dispatchable's ``coercible`` flag.
    marked = Dispatchable(value, dispatch_type, coerce)
    chosen = _uarray.determine_backend(domain, (marked,), coerce)
    return set_backend(chosen, coerce=coerce, only=only)
def determine_backend_multi(
    dispatchables, *, domain, only=True, coerce=False, **kwargs
):
    """Set a backend supporting all ``dispatchables``

    This is useful for functions that call multimethods without any
    dispatchable arguments. You can use :func:`determine_backend_multi` to
    ensure the same backend is used everywhere in a block of multimethod
    calls involving multiple arrays.

    Parameters
    ----------
    dispatchables: Sequence[Union[uarray.Dispatchable, Any]]
        The dispatchables that must be supported
    domain: string
        The domain to query for backends and set.
    coerce: bool
        Whether or not to allow coercion to the backend's types. Implies ``only``.
    only: bool
        Whether or not this should be the last backend to try.
    dispatch_type: Optional[Any]
        The default dispatch type associated with ``dispatchables``, aka
        ":ref:`marking <MarkingGlossary>`".

    Raises
    ------
    TypeError
        If ``dispatch_type`` is not given and some item is not a
        :obj:`Dispatchable`, or if any other unexpected keyword argument
        is received.

    See Also
    --------
    determine_backend: For a single dispatch value
    set_backend: For when you know which backend to set

    Notes
    -----
    Support is determined by the ``__ua_convert__`` protocol. Backends not
    supporting the type must return ``NotImplemented`` from their
    ``__ua_convert__`` if they don't support input of that type.

    Examples
    --------
    :func:`determine_backend` allows the backend to be set from a single
    object. :func:`determine_backend_multi` allows multiple objects to be
    checked simultaneously for support in the backend. Suppose we have a
    ``BackendAB`` which supports ``TypeA`` and ``TypeB`` in the same call,
    and a ``BackendBC`` that doesn't support ``TypeA``.

    >>> with ua.set_backend(ex.BackendAB), ua.set_backend(ex.BackendBC):
    ...     a, b = ex.TypeA(), ex.TypeB()
    ...     with ua.determine_backend_multi(
    ...         [ua.Dispatchable(a, "mark"), ua.Dispatchable(b, "mark")],
    ...         domain="ua_examples"
    ...     ):
    ...         res = ex.creation_multimethod()
    ...         ex.call_multimethod(res, a, b)
    TypeA

    This won't call ``BackendBC`` because it doesn't support ``TypeA``.

    We can also leave out the ``ua.Dispatchable`` if we specify the
    default ``dispatch_type`` for the ``dispatchables`` argument.

    >>> with ua.set_backend(ex.BackendAB), ua.set_backend(ex.BackendBC):
    ...     a, b = ex.TypeA(), ex.TypeB()
    ...     with ua.determine_backend_multi(
    ...         [a, b], dispatch_type="mark", domain="ua_examples"
    ...     ):
    ...         res = ex.creation_multimethod()
    ...         ex.call_multimethod(res, a, b)
    TypeA

    """
    # A private sentinel tells "not given" apart from an explicit value
    # such as ``None``.
    not_given = object()
    default_type = kwargs.pop("dispatch_type", not_given)

    if default_type is not_given:
        # Without a default type every item must already be marked.
        dispatchables = tuple(dispatchables)
        for item in dispatchables:
            if not isinstance(item, Dispatchable):
                raise TypeError(
                    "dispatchables must be instances of uarray.Dispatchable"
                )
    else:
        marked = []
        for item in dispatchables:
            if not isinstance(item, Dispatchable):
                item = Dispatchable(item, default_type)
            marked.append(item)
        dispatchables = tuple(marked)

    # ``dispatch_type`` is the only extra keyword accepted.
    if kwargs:
        raise TypeError("Received unexpected keyword arguments: {}".format(kwargs))

    chosen = _uarray.determine_backend(domain, dispatchables, coerce)
    return set_backend(chosen, coerce=coerce, only=only)
|