123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434 |
- from warnings import warn
- import inspect
- from .conflict import ordering, ambiguities, super_signature, AmbiguityWarning
- from .utils import expand_tuples
- from .variadic import Variadic, isvariadic
- import itertools as itl
- __all__ = ["MDNotImplementedError", "ambiguity_warn", "halt_ordering", "restart_ordering", "variadic_signature_matches_iter",
- "variadic_signature_matches", "Dispatcher", "source", "MethodDispatcher", "str_signature", "warning_text"]
class MDNotImplementedError(NotImplementedError):
    """Raised by an implementation to decline a multiple-dispatch call.

    It is a plain ``NotImplementedError`` subclass, so callers that already
    catch ``NotImplementedError`` keep working.
    """
def ambiguity_warn(dispatcher, ambiguities):
    """Emit an ``AmbiguityWarning`` describing the ambiguous signatures.

    Parameters
    ----------
    dispatcher : Dispatcher
        The dispatcher on which the ambiguity was detected; only its
        ``name`` attribute is read here.
    ambiguities : set
        Set of type signature pairs that are ambiguous within this dispatcher.

    See Also:
        Dispatcher.add
        warning_text
    """
    # Build the human-readable report first, then hand it to ``warn``.
    message = warning_text(dispatcher.name, ambiguities)
    warn(message, AmbiguityWarning)
def halt_ordering():
    """Deprecated interface to temporarily disable ordering.

    This function no longer does anything besides emitting a
    ``DeprecationWarning``.  The warning is issued with ``stacklevel=2`` so
    that it is attributed to the caller (the code that actually needs to be
    updated) instead of to this module.
    """
    warn(
        'halt_ordering is deprecated, you can safely remove this call.',
        DeprecationWarning,
        stacklevel=2,
    )
def restart_ordering(on_ambiguity=ambiguity_warn):
    """Deprecated interface to temporarily resume ordering.

    Parameters
    ----------
    on_ambiguity : callable, optional
        Accepted for backward compatibility only; it is never used.

    The function only emits a ``DeprecationWarning``.  The warning uses
    ``stacklevel=2`` so that it points at the caller rather than at this
    module.
    """
    warn(
        # Adjacent string literals are concatenated verbatim, so every
        # fragment boundary needs its own separating space.
        'restart_ordering is deprecated, if you would like to eagerly order '
        'the dispatchers, you should call the ``reorder()`` method on each'
        ' dispatcher.',
        DeprecationWarning,
        stacklevel=2,
    )
def variadic_signature_matches_iter(types, full_signature):
    """Check if a set of input types matches a variadic signature.

    This is a generator: it yields one boolean per element of ``types`` and
    then one final boolean describing whether the signature was fully
    consumed.  The overall match succeeds only if every yielded value is
    true.

    Notes
    -----
    The algorithm is as follows:

    Start on the first element of ``full_signature``.

    For each type in ``types``:
        Yield whether the type is a subclass of the current element.
        If the current element is not variadic, advance to the next element;
        a variadic element is kept so it can absorb further arguments.

    After all types are processed:
        If no signature elements are left, yield True (the element we
        stopped on is asserted to be variadic).
        Otherwise yield False, because part of the signature went unmatched.

    ``full_signature`` must be non-empty and must not run out while
    non-variadic elements are still being consumed; the only caller in this
    file passes signatures whose last element is variadic.
    """
    remaining = iter(full_signature)
    current = next(remaining)

    for arg_type in types:
        yield issubclass(arg_type, current)
        if isvariadic(current):
            # A variadic element keeps matching the following arguments.
            continue
        # Not a variadic element, so move to the next element in the
        # signature.
        current = next(remaining)

    # The loop above never breaks, so this tail always runs once the
    # arguments are exhausted.
    nothing_left = object()
    if next(remaining, nothing_left) is nothing_left:
        assert isvariadic(current)
        yield True
    else:
        # We have signature items left over, so all of our arguments
        # haven't matched.
        yield False
def variadic_signature_matches(types, full_signature):
    """Return True if ``types`` matches the variadic ``full_signature``.

    Every step reported by ``variadic_signature_matches_iter`` must be true;
    evaluation stops at the first false step.
    """
    # No arguments always matches a variadic signature, but the signature
    # itself must never be empty.
    assert full_signature
    for step_ok in variadic_signature_matches_iter(types, full_signature):
        if not step_ok:
            return False
    return True
class Dispatcher:
    """ Dispatch methods based on type signature

    Use ``dispatch`` to add implementations

    Examples
    --------

    >>> # xdoctest: +SKIP("bad import name")
    >>> from multipledispatch import dispatch
    >>> @dispatch(int)
    ... def f(x):
    ...     return x + 1

    >>> @dispatch(float)
    ... def f(x):
    ...     return x - 1

    >>> f(3)
    4
    >>> f(3.0)
    2.0
    """
    # ``_ordering`` is filled lazily by ``reorder`` and dropped again by
    # ``add``; ``_cache`` maps a tuple of argument types to the resolved
    # implementation.
    __slots__ = '__name__', 'name', 'funcs', '_ordering', '_cache', 'doc'

    def __init__(self, name, doc=None):
        """Create an empty dispatcher called ``name`` with optional ``doc``."""
        self.name = self.__name__ = name
        self.funcs = {}
        self.doc = doc

        self._cache = {}

    def register(self, *types, **kwargs):
        """ register dispatcher with new implementation

        Returns a decorator that adds the decorated function under ``types``
        and hands the function back unchanged.  ``kwargs`` are forwarded to
        ``add`` as-is; ``add`` currently accepts no extra keyword arguments.

        >>> # xdoctest: +SKIP
        >>> f = Dispatcher('f')
        >>> @f.register(int)
        ... def inc(x):
        ...     return x + 1

        >>> @f.register(float)
        ... def dec(x):
        ...     return x - 1

        >>> @f.register(list)
        ... @f.register(tuple)
        ... def reverse(x):
        ...     return x[::-1]

        >>> f(1)
        2

        >>> f(1.0)
        0.0

        >>> f([1, 2, 3])
        [3, 2, 1]
        """
        def _df(func):
            self.add(types, func, **kwargs)   # type: ignore[call-arg]
            return func
        return _df

    @classmethod
    def get_func_params(cls, func):
        """Return the parameters of ``func`` from ``inspect.signature``.

        Implicitly returns ``None`` when ``inspect.signature`` is missing.
        """
        if hasattr(inspect, "signature"):
            sig = inspect.signature(func)
            return sig.parameters.values()

    @classmethod
    def get_func_annotations(cls, func):
        """ get annotations of function positional parameters

        Returns a tuple of annotations when every positional parameter is
        annotated, and ``None`` otherwise.
        """
        params = cls.get_func_params(func)
        if params:
            Parameter = inspect.Parameter

            # Only positional parameters take part in dispatch.
            params = (param for param in params
                      if param.kind in
                      (Parameter.POSITIONAL_ONLY,
                       Parameter.POSITIONAL_OR_KEYWORD))

            annotations = tuple(
                param.annotation
                for param in params)

            if all(ann is not Parameter.empty for ann in annotations):
                return annotations

    def add(self, signature, func):
        """ Add new types/method pair to dispatcher

        >>> # xdoctest: +SKIP
        >>> D = Dispatcher('add')
        >>> D.add((int, int), lambda x, y: x + y)
        >>> D.add((float, float), lambda x, y: x + y)

        >>> D(1, 2)
        3
        >>> D(1, 2.0)
        Traceback (most recent call last):
        ...
        NotImplementedError: Could not find signature for add: <int, float>

        >>> # ``add`` itself does not check for ambiguities.  They are
        >>> # detected in ``reorder``, which calls its ``on_ambiguity``
        >>> # callback with the dispatcher and a set of ambiguous type
        >>> # signature pairs.  See ``ambiguity_warn`` for an example.

        Raises
        ------
        TypeError
            If an element of ``signature`` is neither a type nor a list, if a
            variadic (list) element is not last, or if it does not contain
            exactly one element.
        """
        # Handle annotations
        if not signature:
            annotations = self.get_func_annotations(func)
            if annotations:
                signature = annotations

        # Handle union types: register one entry per expanded combination.
        if any(isinstance(typ, tuple) for typ in signature):
            for typs in expand_tuples(signature):
                self.add(typs, func)
            return

        new_signature = []

        for index, typ in enumerate(signature, start=1):
            if not isinstance(typ, (type, list)):
                str_sig = ', '.join(c.__name__ if isinstance(c, type)
                                    else str(c) for c in signature)
                raise TypeError("Tried to dispatch on non-type: %s\n"
                                "In signature: <%s>\n"
                                "In function: %s" %
                                (typ, str_sig, self.name))

            # handle variadic signatures
            if isinstance(typ, list):
                if index != len(signature):
                    raise TypeError(
                        'Variadic signature must be the last element'
                    )

                if len(typ) != 1:
                    raise TypeError(
                        'Variadic signature must contain exactly one element. '
                        'To use a variadic union type place the desired types '
                        'inside of a tuple, e.g., [(int, str)]'
                    )
                new_signature.append(Variadic[typ[0]])
            else:
                new_signature.append(typ)

        self.funcs[tuple(new_signature)] = func
        # Any previously resolved call may now pick a different function.
        self._cache.clear()

        # Drop the stale ordering so it is recomputed on next access.
        try:
            del self._ordering
        except AttributeError:
            pass

    @property
    def ordering(self):
        """Resolution order of the signatures, computed on first access."""
        try:
            return self._ordering
        except AttributeError:
            return self.reorder()

    def reorder(self, on_ambiguity=ambiguity_warn):
        """Recompute and store the ordering; report ambiguities if any.

        ``on_ambiguity`` is called with this dispatcher and the ambiguities
        only when at least one ambiguity exists.  Returns the new ordering.
        """
        self._ordering = od = ordering(self.funcs)
        amb = ambiguities(self.funcs)
        if amb:
            on_ambiguity(self, amb)
        return od

    def __call__(self, *args, **kwargs):
        """Call the implementation selected by the types of ``args``.

        Keyword arguments are passed through but do not affect dispatch.

        Raises
        ------
        NotImplementedError
            If no implementation matches, or if every matching implementation
            raised ``MDNotImplementedError``.
        """
        types = tuple([type(arg) for arg in args])
        try:
            func = self._cache[types]
        except KeyError as e:
            func = self.dispatch(*types)
            if not func:
                raise NotImplementedError(
                    'Could not find signature for %s: <%s>' %
                    (self.name, str_signature(types))) from e
            self._cache[types] = func
        try:
            return func(*args, **kwargs)

        except MDNotImplementedError as e:
            # The chosen implementation declined; try the remaining matches
            # in resolution order.
            funcs = self.dispatch_iter(*types)
            next(funcs)  # burn first
            for func in funcs:
                try:
                    return func(*args, **kwargs)
                except MDNotImplementedError:
                    pass

            raise NotImplementedError(
                "Matching functions for "
                "%s: <%s> found, but none completed successfully" % (
                    self.name, str_signature(types),),) from e

    def __str__(self):
        return "<dispatched %s>" % self.name
    __repr__ = __str__

    def dispatch(self, *types):
        """Determine appropriate implementation for this type signature

        This method is internal.  Users should call this object as a function.
        Implementation resolution occurs within the ``__call__`` method.

        Returns the matching function, or ``None`` when nothing matches.

        >>> # xdoctest: +SKIP
        >>> from multipledispatch import dispatch
        >>> @dispatch(int)
        ... def inc(x):
        ...     return x + 1

        >>> implementation = inc.dispatch(int)
        >>> implementation(3)
        4

        >>> print(inc.dispatch(float))
        None

        See Also:
          ``multipledispatch.conflict`` - module to determine resolution order
        """

        # Exact signature hit: no need to walk the ordering.
        if types in self.funcs:
            return self.funcs[types]

        try:
            return next(self.dispatch_iter(*types))
        except StopIteration:
            return None

    def dispatch_iter(self, *types):
        """Yield every implementation matching ``types``, in ``ordering``."""
        n = len(types)
        for signature in self.ordering:
            if len(signature) == n and all(map(issubclass, types, signature)):
                result = self.funcs[signature]
                yield result
            elif len(signature) and isvariadic(signature[-1]):
                if variadic_signature_matches(types, signature):
                    result = self.funcs[signature]
                    yield result

    def resolve(self, types):
        """ Determine appropriate implementation for this type signature

        .. deprecated:: 0.4.4
            Use ``dispatch(*types)`` instead
        """
        # stacklevel=2 attributes the warning to the caller of ``resolve``.
        warn("resolve() is deprecated, use dispatch(*types)",
             DeprecationWarning, stacklevel=2)

        return self.dispatch(*types)

    def __getstate__(self):
        """Return the picklable state: name, registered functions and doc."""
        return {'name': self.name,
                'funcs': self.funcs,
                'doc': self.doc}

    def __setstate__(self, d):
        """Restore state produced by ``__getstate__``.

        Every slot is assigned so that no attribute is left unset after
        unpickling.  State written without a ``'doc'`` entry is still
        accepted and yields ``doc = None``.
        """
        self.name = self.__name__ = d['name']
        self.funcs = d['funcs']
        self.doc = d.get('doc')
        self._ordering = ordering(self.funcs)
        self._cache = {}

    @property
    def __doc__(self):
        """Docstring assembled from ``doc`` and every implementation."""
        docs = ["Multiply dispatched method: %s" % self.name]

        if self.doc:
            docs.append(self.doc)

        other = []
        for sig in self.ordering[::-1]:
            func = self.funcs[sig]
            if func.__doc__:
                s = 'Inputs: <%s>\n' % str_signature(sig)
                s += '-' * len(s) + '\n'
                s += func.__doc__.strip()
                docs.append(s)
            else:
                other.append(str_signature(sig))

        if other:
            # NOTE(review): the separators below are reproduced exactly as
            # they appear in the reviewed text; confirm the intended
            # indentation width against the canonical file.
            docs.append('Other signatures:\n ' + '\n '.join(other))

        return '\n\n'.join(docs)

    def _help(self, *args):
        """Return the docstring of the implementation matching ``args``."""
        return self.dispatch(*map(type, args)).__doc__

    def help(self, *args, **kwargs):
        """ Print docstring for the function corresponding to inputs """
        print(self._help(*args))

    def _source(self, *args):
        """Return source text of the implementation matching ``args``.

        Raises ``TypeError`` when no implementation is found.
        """
        func = self.dispatch(*map(type, args))
        if not func:
            raise TypeError("No function found")
        # Resolves to the module-level ``source`` function, not the method.
        return source(func)

    def source(self, *args, **kwargs):
        """ Print source code for the function corresponding to inputs """
        print(self._source(*args))
def source(func):
    """Return the source file path and source code of ``func`` as one string.

    The result starts with a ``File: <path>`` line, followed by a blank line
    and the text returned by ``inspect.getsource``.
    """
    header = 'File: %s\n\n' % inspect.getsourcefile(func)
    return header + inspect.getsource(func)
class MethodDispatcher(Dispatcher):
    """ Dispatch methods based on type signature

    Works as a descriptor: attribute access records the instance and owner
    class, and calls pass the recorded instance as the first argument.

    See Also:
        Dispatcher
    """
    __slots__ = ('obj', 'cls')

    @classmethod
    def get_func_params(cls, func):
        """Return the parameters of ``func`` without the first one.

        Implicitly returns ``None`` when ``inspect.signature`` is missing.
        """
        if not hasattr(inspect, "signature"):
            return None
        all_params = inspect.signature(func).parameters.values()
        # Skip the leading parameter, which receives the bound instance.
        return itl.islice(all_params, 1, None)

    def __get__(self, instance, owner):
        """Remember ``instance`` and ``owner`` and return this dispatcher."""
        self.obj, self.cls = instance, owner
        return self

    def __call__(self, *args, **kwargs):
        """Dispatch on the types of ``args`` and call with the stored object.

        Raises ``NotImplementedError`` when no implementation matches.
        """
        types = tuple(map(type, args))
        func = self.dispatch(*types)
        if func:
            return func(self.obj, *args, **kwargs)
        raise NotImplementedError('Could not find signature for %s: <%s>' %
                                  (self.name, str_signature(types)))
def str_signature(sig):
    """ String representation of type signature

    Joins the ``__name__`` of every element of ``sig`` with ``', '``.

    >>> str_signature((int, float))
    'int, float'
    """
    type_names = [typ.__name__ for typ in sig]
    return ', '.join(type_names)
def warning_text(name, amb):
    """ The text for ambiguity warnings

    Parameters
    ----------
    name : str
        Name of the dispatched function, interpolated into the message.
    amb : iterable
        Ambiguous signature pairs.  Each pair is listed in the report, and a
        suggested ``@dispatch`` addition is generated for each via
        ``super_signature``.
    """
    intro = "\nAmbiguities exist in dispatched function %s\n\n" % (name)
    intro += "The following signatures may result in ambiguous behavior:\n"

    # One tab-indented line per ambiguous pair.
    pair_lines = [
        "\t" + ', '.join('[' + str_signature(s) + ']' for s in pair) + "\n"
        for pair in amb
    ]

    # One suggested disambiguating definition per ambiguous pair.
    suggestions = [
        '@dispatch(' + str_signature(super_signature(s))
        + ')\ndef %s(...)' % name
        for s in amb
    ]

    return ''.join([
        intro,
        *pair_lines,
        "\n\nConsider making the following additions:\n\n",
        '\n\n'.join(suggestions),
    ])
|