test_dispatcher.py 6.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284
  1. from sympy.multipledispatch.dispatcher import (Dispatcher, MDNotImplementedError,
  2. MethodDispatcher, halt_ordering,
  3. restart_ordering,
  4. ambiguity_register_error_ignore_dup)
  5. from sympy.testing.pytest import raises, warns
  6. def identity(x):
  7. return x
  8. def inc(x):
  9. return x + 1
  10. def dec(x):
  11. return x - 1
  12. def test_dispatcher():
  13. f = Dispatcher('f')
  14. f.add((int,), inc)
  15. f.add((float,), dec)
  16. with warns(DeprecationWarning, test_stacklevel=False):
  17. assert f.resolve((int,)) == inc
  18. assert f.dispatch(int) is inc
  19. assert f(1) == 2
  20. assert f(1.0) == 0.0
  21. def test_union_types():
  22. f = Dispatcher('f')
  23. f.register((int, float))(inc)
  24. assert f(1) == 2
  25. assert f(1.0) == 2.0
  26. def test_dispatcher_as_decorator():
  27. f = Dispatcher('f')
  28. @f.register(int)
  29. def inc(x): # noqa:F811
  30. return x + 1
  31. @f.register(float) # noqa:F811
  32. def inc(x): # noqa:F811
  33. return x - 1
  34. assert f(1) == 2
  35. assert f(1.0) == 0.0
  36. def test_register_instance_method():
  37. class Test:
  38. __init__ = MethodDispatcher('f')
  39. @__init__.register(list)
  40. def _init_list(self, data):
  41. self.data = data
  42. @__init__.register(object)
  43. def _init_obj(self, datum):
  44. self.data = [datum]
  45. a = Test(3)
  46. b = Test([3])
  47. assert a.data == b.data
  48. def test_on_ambiguity():
  49. f = Dispatcher('f')
  50. def identity(x): return x
  51. ambiguities = [False]
  52. def on_ambiguity(dispatcher, amb):
  53. ambiguities[0] = True
  54. f.add((object, object), identity, on_ambiguity=on_ambiguity)
  55. assert not ambiguities[0]
  56. f.add((object, float), identity, on_ambiguity=on_ambiguity)
  57. assert not ambiguities[0]
  58. f.add((float, object), identity, on_ambiguity=on_ambiguity)
  59. assert ambiguities[0]
  60. def test_raise_error_on_non_class():
  61. f = Dispatcher('f')
  62. assert raises(TypeError, lambda: f.add((1,), inc))
  63. def test_docstring():
  64. def one(x, y):
  65. """ Docstring number one """
  66. return x + y
  67. def two(x, y):
  68. """ Docstring number two """
  69. return x + y
  70. def three(x, y):
  71. return x + y
  72. master_doc = 'Doc of the multimethod itself'
  73. f = Dispatcher('f', doc=master_doc)
  74. f.add((object, object), one)
  75. f.add((int, int), two)
  76. f.add((float, float), three)
  77. assert one.__doc__.strip() in f.__doc__
  78. assert two.__doc__.strip() in f.__doc__
  79. assert f.__doc__.find(one.__doc__.strip()) < \
  80. f.__doc__.find(two.__doc__.strip())
  81. assert 'object, object' in f.__doc__
  82. assert master_doc in f.__doc__
  83. def test_help():
  84. def one(x, y):
  85. """ Docstring number one """
  86. return x + y
  87. def two(x, y):
  88. """ Docstring number two """
  89. return x + y
  90. def three(x, y):
  91. """ Docstring number three """
  92. return x + y
  93. master_doc = 'Doc of the multimethod itself'
  94. f = Dispatcher('f', doc=master_doc)
  95. f.add((object, object), one)
  96. f.add((int, int), two)
  97. f.add((float, float), three)
  98. assert f._help(1, 1) == two.__doc__
  99. assert f._help(1.0, 2.0) == three.__doc__
  100. def test_source():
  101. def one(x, y):
  102. """ Docstring number one """
  103. return x + y
  104. def two(x, y):
  105. """ Docstring number two """
  106. return x - y
  107. master_doc = 'Doc of the multimethod itself'
  108. f = Dispatcher('f', doc=master_doc)
  109. f.add((int, int), one)
  110. f.add((float, float), two)
  111. assert 'x + y' in f._source(1, 1)
  112. assert 'x - y' in f._source(1.0, 1.0)
  113. def test_source_raises_on_missing_function():
  114. f = Dispatcher('f')
  115. assert raises(TypeError, lambda: f.source(1))
  116. def test_halt_method_resolution():
  117. g = [0]
  118. def on_ambiguity(a, b):
  119. g[0] += 1
  120. f = Dispatcher('f')
  121. halt_ordering()
  122. def func(*args):
  123. pass
  124. f.add((int, object), func)
  125. f.add((object, int), func)
  126. assert g == [0]
  127. restart_ordering(on_ambiguity=on_ambiguity)
  128. assert g == [1]
  129. assert set(f.ordering) == {(int, object), (object, int)}
  130. def test_no_implementations():
  131. f = Dispatcher('f')
  132. assert raises(NotImplementedError, lambda: f('hello'))
  133. def test_register_stacking():
  134. f = Dispatcher('f')
  135. @f.register(list)
  136. @f.register(tuple)
  137. def rev(x):
  138. return x[::-1]
  139. assert f((1, 2, 3)) == (3, 2, 1)
  140. assert f([1, 2, 3]) == [3, 2, 1]
  141. assert raises(NotImplementedError, lambda: f('hello'))
  142. assert rev('hello') == 'olleh'
  143. def test_dispatch_method():
  144. f = Dispatcher('f')
  145. @f.register(list)
  146. def rev(x):
  147. return x[::-1]
  148. @f.register(int, int)
  149. def add(x, y):
  150. return x + y
  151. class MyList(list):
  152. pass
  153. assert f.dispatch(list) is rev
  154. assert f.dispatch(MyList) is rev
  155. assert f.dispatch(int, int) is add
  156. def test_not_implemented():
  157. f = Dispatcher('f')
  158. @f.register(object)
  159. def _(x):
  160. return 'default'
  161. @f.register(int)
  162. def _(x):
  163. if x % 2 == 0:
  164. return 'even'
  165. else:
  166. raise MDNotImplementedError()
  167. assert f('hello') == 'default' # default behavior
  168. assert f(2) == 'even' # specialized behavior
  169. assert f(3) == 'default' # fall bac to default behavior
  170. assert raises(NotImplementedError, lambda: f(1, 2))
  171. def test_not_implemented_error():
  172. f = Dispatcher('f')
  173. @f.register(float)
  174. def _(a):
  175. raise MDNotImplementedError()
  176. assert raises(NotImplementedError, lambda: f(1.0))
  177. def test_ambiguity_register_error_ignore_dup():
  178. f = Dispatcher('f')
  179. class A:
  180. pass
  181. class B(A):
  182. pass
  183. class C(A):
  184. pass
  185. # suppress warning for registering ambiguous signal
  186. f.add((A, B), lambda x,y: None, ambiguity_register_error_ignore_dup)
  187. f.add((B, A), lambda x,y: None, ambiguity_register_error_ignore_dup)
  188. f.add((A, C), lambda x,y: None, ambiguity_register_error_ignore_dup)
  189. f.add((C, A), lambda x,y: None, ambiguity_register_error_ignore_dup)
  190. # raises error if ambiguous signal is passed
  191. assert raises(NotImplementedError, lambda: f(B(), C()))