dispatcher.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434
  1. from warnings import warn
  2. import inspect
  3. from .conflict import ordering, ambiguities, super_signature, AmbiguityWarning
  4. from .utils import expand_tuples
  5. from .variadic import Variadic, isvariadic
  6. import itertools as itl
  7. __all__ = ["MDNotImplementedError", "ambiguity_warn", "halt_ordering", "restart_ordering", "variadic_signature_matches_iter",
  8. "variadic_signature_matches", "Dispatcher", "source", "MethodDispatcher", "str_signature", "warning_text"]
  9. class MDNotImplementedError(NotImplementedError):
  10. """ A NotImplementedError for multiple dispatch """
  11. def ambiguity_warn(dispatcher, ambiguities):
  12. """ Raise warning when ambiguity is detected
  13. Parameters
  14. ----------
  15. dispatcher : Dispatcher
  16. The dispatcher on which the ambiguity was detected
  17. ambiguities : set
  18. Set of type signature pairs that are ambiguous within this dispatcher
  19. See Also:
  20. Dispatcher.add
  21. warning_text
  22. """
  23. warn(warning_text(dispatcher.name, ambiguities), AmbiguityWarning)
  24. def halt_ordering():
  25. """Deprecated interface to temporarily disable ordering.
  26. """
  27. warn(
  28. 'halt_ordering is deprecated, you can safely remove this call.',
  29. DeprecationWarning,
  30. )
  31. def restart_ordering(on_ambiguity=ambiguity_warn):
  32. """Deprecated interface to temporarily resume ordering.
  33. """
  34. warn(
  35. 'restart_ordering is deprecated, if you would like to eagerly order'
  36. 'the dispatchers, you should call the ``reorder()`` method on each'
  37. ' dispatcher.',
  38. DeprecationWarning,
  39. )
  40. def variadic_signature_matches_iter(types, full_signature):
  41. """Check if a set of input types matches a variadic signature.
  42. Notes
  43. -----
  44. The algorithm is as follows:
  45. Initialize the current signature to the first in the sequence
  46. For each type in `types`:
  47. If the current signature is variadic
  48. If the type matches the signature
  49. yield True
  50. Else
  51. Try to get the next signature
  52. If no signatures are left we can't possibly have a match
  53. so yield False
  54. Else
  55. yield True if the type matches the current signature
  56. Get the next signature
  57. """
  58. sigiter = iter(full_signature)
  59. sig = next(sigiter)
  60. for typ in types:
  61. matches = issubclass(typ, sig)
  62. yield matches
  63. if not isvariadic(sig):
  64. # we're not matching a variadic argument, so move to the next
  65. # element in the signature
  66. sig = next(sigiter)
  67. else:
  68. try:
  69. sig = next(sigiter)
  70. except StopIteration:
  71. assert isvariadic(sig)
  72. yield True
  73. else:
  74. # We have signature items left over, so all of our arguments
  75. # haven't matched
  76. yield False
  77. def variadic_signature_matches(types, full_signature):
  78. # No arguments always matches a variadic signature
  79. assert full_signature
  80. return all(variadic_signature_matches_iter(types, full_signature))
  81. class Dispatcher:
  82. """ Dispatch methods based on type signature
  83. Use ``dispatch`` to add implementations
  84. Examples
  85. --------
  86. >>> # xdoctest: +SKIP("bad import name")
  87. >>> from multipledispatch import dispatch
  88. >>> @dispatch(int)
  89. ... def f(x):
  90. ... return x + 1
  91. >>> @dispatch(float)
  92. ... def f(x):
  93. ... return x - 1
  94. >>> f(3)
  95. 4
  96. >>> f(3.0)
  97. 2.0
  98. """
  99. __slots__ = '__name__', 'name', 'funcs', '_ordering', '_cache', 'doc'
  100. def __init__(self, name, doc=None):
  101. self.name = self.__name__ = name
  102. self.funcs = {}
  103. self.doc = doc
  104. self._cache = {}
  105. def register(self, *types, **kwargs):
  106. """ register dispatcher with new implementation
  107. >>> # xdoctest: +SKIP
  108. >>> f = Dispatcher('f')
  109. >>> @f.register(int)
  110. ... def inc(x):
  111. ... return x + 1
  112. >>> @f.register(float)
  113. ... def dec(x):
  114. ... return x - 1
  115. >>> @f.register(list)
  116. ... @f.register(tuple)
  117. ... def reverse(x):
  118. ... return x[::-1]
  119. >>> f(1)
  120. 2
  121. >>> f(1.0)
  122. 0.0
  123. >>> f([1, 2, 3])
  124. [3, 2, 1]
  125. """
  126. def _df(func):
  127. self.add(types, func, **kwargs) # type: ignore[call-arg]
  128. return func
  129. return _df
  130. @classmethod
  131. def get_func_params(cls, func):
  132. if hasattr(inspect, "signature"):
  133. sig = inspect.signature(func)
  134. return sig.parameters.values()
  135. @classmethod
  136. def get_func_annotations(cls, func):
  137. """ get annotations of function positional parameters
  138. """
  139. params = cls.get_func_params(func)
  140. if params:
  141. Parameter = inspect.Parameter
  142. params = (param for param in params
  143. if param.kind in
  144. (Parameter.POSITIONAL_ONLY,
  145. Parameter.POSITIONAL_OR_KEYWORD))
  146. annotations = tuple(
  147. param.annotation
  148. for param in params)
  149. if all(ann is not Parameter.empty for ann in annotations):
  150. return annotations
  151. def add(self, signature, func):
  152. """ Add new types/method pair to dispatcher
  153. >>> # xdoctest: +SKIP
  154. >>> D = Dispatcher('add')
  155. >>> D.add((int, int), lambda x, y: x + y)
  156. >>> D.add((float, float), lambda x, y: x + y)
  157. >>> D(1, 2)
  158. 3
  159. >>> D(1, 2.0)
  160. Traceback (most recent call last):
  161. ...
  162. NotImplementedError: Could not find signature for add: <int, float>
  163. >>> # When ``add`` detects a warning it calls the ``on_ambiguity`` callback
  164. >>> # with a dispatcher/itself, and a set of ambiguous type signature pairs
  165. >>> # as inputs. See ``ambiguity_warn`` for an example.
  166. """
  167. # Handle annotations
  168. if not signature:
  169. annotations = self.get_func_annotations(func)
  170. if annotations:
  171. signature = annotations
  172. # Handle union types
  173. if any(isinstance(typ, tuple) for typ in signature):
  174. for typs in expand_tuples(signature):
  175. self.add(typs, func)
  176. return
  177. new_signature = []
  178. for index, typ in enumerate(signature, start=1):
  179. if not isinstance(typ, (type, list)):
  180. str_sig = ', '.join(c.__name__ if isinstance(c, type)
  181. else str(c) for c in signature)
  182. raise TypeError("Tried to dispatch on non-type: %s\n"
  183. "In signature: <%s>\n"
  184. "In function: %s" %
  185. (typ, str_sig, self.name))
  186. # handle variadic signatures
  187. if isinstance(typ, list):
  188. if index != len(signature):
  189. raise TypeError(
  190. 'Variadic signature must be the last element'
  191. )
  192. if len(typ) != 1:
  193. raise TypeError(
  194. 'Variadic signature must contain exactly one element. '
  195. 'To use a variadic union type place the desired types '
  196. 'inside of a tuple, e.g., [(int, str)]'
  197. )
  198. new_signature.append(Variadic[typ[0]])
  199. else:
  200. new_signature.append(typ)
  201. self.funcs[tuple(new_signature)] = func
  202. self._cache.clear()
  203. try:
  204. del self._ordering
  205. except AttributeError:
  206. pass
  207. @property
  208. def ordering(self):
  209. try:
  210. return self._ordering
  211. except AttributeError:
  212. return self.reorder()
  213. def reorder(self, on_ambiguity=ambiguity_warn):
  214. self._ordering = od = ordering(self.funcs)
  215. amb = ambiguities(self.funcs)
  216. if amb:
  217. on_ambiguity(self, amb)
  218. return od
  219. def __call__(self, *args, **kwargs):
  220. types = tuple([type(arg) for arg in args])
  221. try:
  222. func = self._cache[types]
  223. except KeyError as e:
  224. func = self.dispatch(*types)
  225. if not func:
  226. raise NotImplementedError(
  227. 'Could not find signature for %s: <%s>' %
  228. (self.name, str_signature(types))) from e
  229. self._cache[types] = func
  230. try:
  231. return func(*args, **kwargs)
  232. except MDNotImplementedError as e:
  233. funcs = self.dispatch_iter(*types)
  234. next(funcs) # burn first
  235. for func in funcs:
  236. try:
  237. return func(*args, **kwargs)
  238. except MDNotImplementedError:
  239. pass
  240. raise NotImplementedError(
  241. "Matching functions for "
  242. "%s: <%s> found, but none completed successfully" % (
  243. self.name, str_signature(types),),) from e
  244. def __str__(self):
  245. return "<dispatched %s>" % self.name
  246. __repr__ = __str__
  247. def dispatch(self, *types):
  248. """Deterimine appropriate implementation for this type signature
  249. This method is internal. Users should call this object as a function.
  250. Implementation resolution occurs within the ``__call__`` method.
  251. >>> # xdoctest: +SKIP
  252. >>> from multipledispatch import dispatch
  253. >>> @dispatch(int)
  254. ... def inc(x):
  255. ... return x + 1
  256. >>> implementation = inc.dispatch(int)
  257. >>> implementation(3)
  258. 4
  259. >>> print(inc.dispatch(float))
  260. None
  261. See Also:
  262. ``multipledispatch.conflict`` - module to determine resolution order
  263. """
  264. if types in self.funcs:
  265. return self.funcs[types]
  266. try:
  267. return next(self.dispatch_iter(*types))
  268. except StopIteration:
  269. return None
  270. def dispatch_iter(self, *types):
  271. n = len(types)
  272. for signature in self.ordering:
  273. if len(signature) == n and all(map(issubclass, types, signature)):
  274. result = self.funcs[signature]
  275. yield result
  276. elif len(signature) and isvariadic(signature[-1]):
  277. if variadic_signature_matches(types, signature):
  278. result = self.funcs[signature]
  279. yield result
  280. def resolve(self, types):
  281. """ Deterimine appropriate implementation for this type signature
  282. .. deprecated:: 0.4.4
  283. Use ``dispatch(*types)`` instead
  284. """
  285. warn("resolve() is deprecated, use dispatch(*types)",
  286. DeprecationWarning)
  287. return self.dispatch(*types)
  288. def __getstate__(self):
  289. return {'name': self.name,
  290. 'funcs': self.funcs}
  291. def __setstate__(self, d):
  292. self.name = d['name']
  293. self.funcs = d['funcs']
  294. self._ordering = ordering(self.funcs)
  295. self._cache = {}
  296. @property
  297. def __doc__(self):
  298. docs = ["Multiply dispatched method: %s" % self.name]
  299. if self.doc:
  300. docs.append(self.doc)
  301. other = []
  302. for sig in self.ordering[::-1]:
  303. func = self.funcs[sig]
  304. if func.__doc__:
  305. s = 'Inputs: <%s>\n' % str_signature(sig)
  306. s += '-' * len(s) + '\n'
  307. s += func.__doc__.strip()
  308. docs.append(s)
  309. else:
  310. other.append(str_signature(sig))
  311. if other:
  312. docs.append('Other signatures:\n ' + '\n '.join(other))
  313. return '\n\n'.join(docs)
  314. def _help(self, *args):
  315. return self.dispatch(*map(type, args)).__doc__
  316. def help(self, *args, **kwargs):
  317. """ Print docstring for the function corresponding to inputs """
  318. print(self._help(*args))
  319. def _source(self, *args):
  320. func = self.dispatch(*map(type, args))
  321. if not func:
  322. raise TypeError("No function found")
  323. return source(func)
  324. def source(self, *args, **kwargs):
  325. """ Print source code for the function corresponding to inputs """
  326. print(self._source(*args))
  327. def source(func):
  328. s = 'File: %s\n\n' % inspect.getsourcefile(func)
  329. s = s + inspect.getsource(func)
  330. return s
  331. class MethodDispatcher(Dispatcher):
  332. """ Dispatch methods based on type signature
  333. See Also:
  334. Dispatcher
  335. """
  336. __slots__ = ('obj', 'cls')
  337. @classmethod
  338. def get_func_params(cls, func):
  339. if hasattr(inspect, "signature"):
  340. sig = inspect.signature(func)
  341. return itl.islice(sig.parameters.values(), 1, None)
  342. def __get__(self, instance, owner):
  343. self.obj = instance
  344. self.cls = owner
  345. return self
  346. def __call__(self, *args, **kwargs):
  347. types = tuple([type(arg) for arg in args])
  348. func = self.dispatch(*types)
  349. if not func:
  350. raise NotImplementedError('Could not find signature for %s: <%s>' %
  351. (self.name, str_signature(types)))
  352. return func(self.obj, *args, **kwargs)
  353. def str_signature(sig):
  354. """ String representation of type signature
  355. >>> str_signature((int, float))
  356. 'int, float'
  357. """
  358. return ', '.join(cls.__name__ for cls in sig)
  359. def warning_text(name, amb):
  360. """ The text for ambiguity warnings """
  361. text = "\nAmbiguities exist in dispatched function %s\n\n" % (name)
  362. text += "The following signatures may result in ambiguous behavior:\n"
  363. for pair in amb:
  364. text += "\t" + \
  365. ', '.join('[' + str_signature(s) + ']' for s in pair) + "\n"
  366. text += "\n\nConsider making the following additions:\n\n"
  367. text += '\n\n'.join(['@dispatch(' + str_signature(super_signature(s))
  368. + ')\ndef %s(...)' % name for s in amb])
  369. return text