match.py 3.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121
  1. from .core import unify, reify # type: ignore[attr-defined]
  2. from .variable import isvar
  3. from .utils import _toposort, freeze
  4. from .unification_tools import groupby, first # type: ignore[import]
  5. class Dispatcher:
  6. def __init__(self, name):
  7. self.name = name
  8. self.funcs = {}
  9. self.ordering = []
  10. def add(self, signature, func):
  11. self.funcs[freeze(signature)] = func
  12. self.ordering = ordering(self.funcs)
  13. def __call__(self, *args, **kwargs):
  14. func, s = self.resolve(args)
  15. return func(*args, **kwargs)
  16. def resolve(self, args):
  17. n = len(args)
  18. for signature in self.ordering:
  19. if len(signature) != n:
  20. continue
  21. s = unify(freeze(args), signature)
  22. if s is not False:
  23. result = self.funcs[signature]
  24. return result, s
  25. raise NotImplementedError("No match found. \nKnown matches: "
  26. + str(self.ordering) + "\nInput: " + str(args))
  27. def register(self, *signature):
  28. def _(func):
  29. self.add(signature, func)
  30. return self
  31. return _
  32. class VarDispatcher(Dispatcher):
  33. """ A dispatcher that calls functions with variable names
  34. >>> # xdoctest: +SKIP
  35. >>> d = VarDispatcher('d')
  36. >>> x = var('x')
  37. >>> @d.register('inc', x)
  38. ... def f(x):
  39. ... return x + 1
  40. >>> @d.register('double', x)
  41. ... def f(x):
  42. ... return x * 2
  43. >>> d('inc', 10)
  44. 11
  45. >>> d('double', 10)
  46. 20
  47. """
  48. def __call__(self, *args, **kwargs):
  49. func, s = self.resolve(args)
  50. d = {k.token: v for k, v in s.items()}
  51. return func(**d)
  52. global_namespace = {} # type: ignore[var-annotated]
  53. def match(*signature, **kwargs):
  54. namespace = kwargs.get('namespace', global_namespace)
  55. dispatcher = kwargs.get('Dispatcher', Dispatcher)
  56. def _(func):
  57. name = func.__name__
  58. if name not in namespace:
  59. namespace[name] = dispatcher(name)
  60. d = namespace[name]
  61. d.add(signature, func)
  62. return d
  63. return _
  64. def supercedes(a, b):
  65. """ ``a`` is a more specific match than ``b`` """
  66. if isvar(b) and not isvar(a):
  67. return True
  68. s = unify(a, b)
  69. if s is False:
  70. return False
  71. s = {k: v for k, v in s.items() if not isvar(k) or not isvar(v)}
  72. if reify(a, s) == a:
  73. return True
  74. if reify(b, s) == b:
  75. return False
  76. # Taken from multipledispatch
  77. def edge(a, b, tie_breaker=hash):
  78. """ A should be checked before B
  79. Tie broken by tie_breaker, defaults to ``hash``
  80. """
  81. if supercedes(a, b):
  82. if supercedes(b, a):
  83. return tie_breaker(a) > tie_breaker(b)
  84. else:
  85. return True
  86. return False
  87. # Taken from multipledispatch
  88. def ordering(signatures):
  89. """ A sane ordering of signatures to check, first to last
  90. Topoological sort of edges as given by ``edge`` and ``supercedes``
  91. """
  92. signatures = list(map(tuple, signatures))
  93. edges = [(a, b) for a in signatures for b in signatures if edge(a, b)]
  94. edges = groupby(first, edges)
  95. for s in signatures:
  96. if s not in edges:
  97. edges[s] = []
  98. edges = {k: [b for a, b in v] for k, v in edges.items()} # type: ignore[attr-defined, assignment]
  99. return _toposort(edges)