123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284 |
- from sympy.multipledispatch.dispatcher import (Dispatcher, MDNotImplementedError,
- MethodDispatcher, halt_ordering,
- restart_ordering,
- ambiguity_register_error_ignore_dup)
- from sympy.testing.pytest import raises, warns
def identity(x):
    """Return *x* unchanged (used as a trivial dispatch target)."""
    unchanged = x
    return unchanged
def inc(x):
    """Return ``x + 1``."""
    incremented = x + 1
    return incremented
def dec(x):
    """Return ``x - 1``."""
    decremented = x - 1
    return decremented
def test_dispatcher():
    """Single-argument dispatch routes int to ``inc`` and float to ``dec``."""
    disp = Dispatcher('f')
    for signature, impl in (((int,), inc), ((float,), dec)):
        disp.add(signature, impl)

    # ``resolve`` is expected to emit a DeprecationWarning.
    with warns(DeprecationWarning, test_stacklevel=False):
        assert disp.resolve((int,)) == inc
    assert disp.dispatch(int) is inc

    assert disp(1) == 2
    assert disp(1.0) == 0.0
def test_union_types():
    """Registering with a tuple of types makes the function serve each type."""
    disp = Dispatcher('f')
    register_numeric = disp.register((int, float))
    register_numeric(inc)

    assert disp(1) == 2
    assert disp(1.0) == 2.0
def test_dispatcher_as_decorator():
    """``register`` used as a decorator binds one implementation per type."""
    disp = Dispatcher('f')

    # Distinct local names avoid shadowing the module-level ``inc``.
    @disp.register(int)
    def _plus_one(x):
        return x + 1

    @disp.register(float)
    def _minus_one(x):
        return x - 1

    assert disp(1) == 2
    assert disp(1.0) == 0.0
def test_register_instance_method():
    """A MethodDispatcher used as ``__init__`` dispatches on the argument type."""

    class Test:
        __init__ = MethodDispatcher('f')

        @__init__.register(list)
        def _init_list(self, data):
            self.data = data

        @__init__.register(object)
        def _init_obj(self, datum):
            self.data = [datum]

    from_scalar = Test(3)
    from_list = Test([3])
    # A scalar is wrapped in a list, so both instances hold the same payload.
    assert from_scalar.data == from_list.data
def test_on_ambiguity():
    """The ``on_ambiguity`` callback fires only once an ambiguous pair exists."""
    disp = Dispatcher('f')

    def _same(x):
        return x

    calls = []

    def on_ambiguity(dispatcher, amb):
        calls.append(amb)

    # (object, float) alone is unambiguous; adding (float, object) creates a
    # pair where neither signature is more specific than the other.
    steps = (
        ((object, object), False),
        ((object, float), False),
        ((float, object), True),
    )
    for signature, expect_ambiguity in steps:
        disp.add(signature, _same, on_ambiguity=on_ambiguity)
        assert bool(calls) is expect_ambiguity
def test_raise_error_on_non_class():
    """A signature containing a non-class entry is rejected with TypeError."""
    bad_signature = (1,)
    disp = Dispatcher('f')
    assert raises(TypeError, lambda: disp.add(bad_signature, inc))
def test_docstring():
    """The dispatcher ``__doc__`` combines the master doc and implementation docs."""

    def one(x, y):
        """ Docstring number one """
        return x + y

    def two(x, y):
        """ Docstring number two """
        return x + y

    # Deliberately undocumented: only its signature should be listed.
    def three(x, y):
        return x + y

    master_doc = 'Doc of the multimethod itself'
    f = Dispatcher('f', doc=master_doc)
    for signature, impl in (((object, object), one),
                            ((int, int), two),
                            ((float, float), three)):
        f.add(signature, impl)

    combined = f.__doc__
    doc_one = one.__doc__.strip()
    doc_two = two.__doc__.strip()
    assert doc_one in combined
    assert doc_two in combined
    # The most general implementation's doc is listed before the int one.
    assert combined.find(doc_one) < combined.find(doc_two)
    assert 'object, object' in combined
    assert master_doc in combined
def test_help():
    """``_help`` returns the docstring of the implementation chosen for the args."""

    def one(x, y):
        """ Docstring number one """
        return x + y

    def two(x, y):
        """ Docstring number two """
        return x + y

    def three(x, y):
        """ Docstring number three """
        return x + y

    master_doc = 'Doc of the multimethod itself'
    f = Dispatcher('f', doc=master_doc)
    registrations = [
        ((object, object), one),
        ((int, int), two),
        ((float, float), three),
    ]
    for signature, impl in registrations:
        f.add(signature, impl)

    assert f._help(1, 1) == two.__doc__
    assert f._help(1.0, 2.0) == three.__doc__
def test_source():
    """``_source`` returns text containing the dispatched implementation's code."""

    def one(x, y):
        """ Docstring number one """
        return x + y

    def two(x, y):
        """ Docstring number two """
        return x - y

    master_doc = 'Doc of the multimethod itself'
    f = Dispatcher('f', doc=master_doc)
    f.add((int, int), one)
    f.add((float, float), two)

    int_source = f._source(1, 1)
    float_source = f._source(1.0, 1.0)
    assert 'x + y' in int_source
    assert 'x - y' in float_source
def test_source_raises_on_missing_function():
    """``source`` raises TypeError on a dispatcher with nothing registered."""
    empty = Dispatcher('f')
    assert raises(TypeError, lambda: empty.source(1))
def test_halt_method_resolution():
    """Ordering is deferred while halted and computed by ``restart_ordering``.

    ``halt_ordering``/``restart_ordering`` take no dispatcher argument yet
    affect ``f``, so they toggle shared (module-level) state.  The restart is
    therefore done in a ``finally`` block: if an assertion fails while
    ordering is halted, later tests must not inherit the halted state.
    """
    g = [0]

    def on_ambiguity(a, b):
        g[0] += 1

    f = Dispatcher('f')
    halt_ordering()
    try:
        def func(*args):
            pass

        # These two signatures are mutually ambiguous, but while ordering is
        # halted no ambiguity callback should have run yet.
        f.add((int, object), func)
        f.add((object, int), func)
        assert g == [0]
    finally:
        restart_ordering(on_ambiguity=on_ambiguity)

    # Restarting orders the pending dispatcher and reports the one ambiguity.
    assert g == [1]
    assert set(f.ordering) == {(int, object), (object, int)}
def test_no_implementations():
    """Calling a dispatcher with nothing registered raises NotImplementedError."""
    empty = Dispatcher('f')

    def call_with_str():
        return empty('hello')

    assert raises(NotImplementedError, call_with_str)
def test_register_stacking():
    """Stacked ``register`` decorators bind one function to several types."""
    f = Dispatcher('f')

    @f.register(list)
    @f.register(tuple)
    def rev(x):
        return x[::-1]

    # Both registered sequence types dispatch to ``rev``.
    for seq, expected in (((1, 2, 3), (3, 2, 1)), ([1, 2, 3], [3, 2, 1])):
        assert f(seq) == expected

    # str was never registered, even though ``rev`` itself can reverse it;
    # the decorator hands back the undecorated function.
    assert raises(NotImplementedError, lambda: f('hello'))
    assert rev('hello') == 'olleh'
def test_dispatch_method():
    """``dispatch`` maps argument types, including subclasses, to implementations."""
    f = Dispatcher('f')

    @f.register(list)
    def rev(x):
        return x[::-1]

    @f.register(int, int)
    def add(x, y):
        return x + y

    class MyList(list):
        pass

    expectations = [
        ((list,), rev),
        ((MyList,), rev),  # subclass falls back to the list implementation
        ((int, int), add),
    ]
    for types, impl in expectations:
        assert f.dispatch(*types) is impl
def test_not_implemented():
    """Raising MDNotImplementedError defers to the next matching implementation."""
    f = Dispatcher('f')

    @f.register(object)
    def _default(x):
        return 'default'

    @f.register(int)
    def _even_only(x):
        if x % 2 != 0:
            # Odd ints are not handled here; defer to the object fallback.
            raise MDNotImplementedError()
        return 'even'

    assert f('hello') == 'default'  # default behavior
    assert f(2) == 'even'  # specialized behavior
    assert f(3) == 'default'  # fall back to default behavior
    # No two-argument signature is registered.
    assert raises(NotImplementedError, lambda: f(1, 2))
def test_not_implemented_error():
    """If the only candidate defers, the call raises NotImplementedError."""
    f = Dispatcher('f')

    @f.register(float)
    def _always_defer(a):
        raise MDNotImplementedError()

    def call_with_float():
        return f(1.0)

    assert raises(NotImplementedError, call_with_float)
def test_ambiguity_register_error_ignore_dup():
    """Ambiguous signatures handled by this policy fail when actually called."""
    f = Dispatcher('f')

    class A:
        pass

    class B(A):
        pass

    class C(A):
        pass

    # (A, B)/(B, A) and (A, C)/(C, A) are pairwise ambiguous.  Passing
    # ``ambiguity_register_error_ignore_dup`` as the third argument (the
    # ambiguity handler) suppresses the registration-time warning.  Each
    # iteration creates a fresh lambda, as in separate ``add`` calls.
    for signature in ((A, B), (B, A), (A, C), (C, A)):
        f.add(signature, lambda x, y: None, ambiguity_register_error_ignore_dup)

    # raises error if ambiguous signature is passed
    assert raises(NotImplementedError, lambda: f(B(), C()))
|