test_core.py 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213
  1. from __future__ import annotations
  2. from typing import Any
  3. from sympy.multipledispatch import dispatch
  4. from sympy.multipledispatch.conflict import AmbiguityWarning
  5. from sympy.testing.pytest import raises, warns
  6. from functools import partial
  7. test_namespace: dict[str, Any] = {}
  8. orig_dispatch = dispatch
  9. dispatch = partial(dispatch, namespace=test_namespace)
  10. def test_singledispatch():
  11. @dispatch(int)
  12. def f(x): # noqa:F811
  13. return x + 1
  14. @dispatch(int)
  15. def g(x): # noqa:F811
  16. return x + 2
  17. @dispatch(float) # noqa:F811
  18. def f(x): # noqa:F811
  19. return x - 1
  20. assert f(1) == 2
  21. assert g(1) == 3
  22. assert f(1.0) == 0
  23. assert raises(NotImplementedError, lambda: f('hello'))
  24. def test_multipledispatch():
  25. @dispatch(int, int)
  26. def f(x, y): # noqa:F811
  27. return x + y
  28. @dispatch(float, float) # noqa:F811
  29. def f(x, y): # noqa:F811
  30. return x - y
  31. assert f(1, 2) == 3
  32. assert f(1.0, 2.0) == -1.0
  33. class A: pass
  34. class B: pass
  35. class C(A): pass
  36. class D(C): pass
  37. class E(C): pass
  38. def test_inheritance():
  39. @dispatch(A)
  40. def f(x): # noqa:F811
  41. return 'a'
  42. @dispatch(B) # noqa:F811
  43. def f(x): # noqa:F811
  44. return 'b'
  45. assert f(A()) == 'a'
  46. assert f(B()) == 'b'
  47. assert f(C()) == 'a'
  48. def test_inheritance_and_multiple_dispatch():
  49. @dispatch(A, A)
  50. def f(x, y): # noqa:F811
  51. return type(x), type(y)
  52. @dispatch(A, B) # noqa:F811
  53. def f(x, y): # noqa:F811
  54. return 0
  55. assert f(A(), A()) == (A, A)
  56. assert f(A(), C()) == (A, C)
  57. assert f(A(), B()) == 0
  58. assert f(C(), B()) == 0
  59. assert raises(NotImplementedError, lambda: f(B(), B()))
  60. def test_competing_solutions():
  61. @dispatch(A)
  62. def h(x): # noqa:F811
  63. return 1
  64. @dispatch(C) # noqa:F811
  65. def h(x): # noqa:F811
  66. return 2
  67. assert h(D()) == 2
  68. def test_competing_multiple():
  69. @dispatch(A, B)
  70. def h(x, y): # noqa:F811
  71. return 1
  72. @dispatch(C, B) # noqa:F811
  73. def h(x, y): # noqa:F811
  74. return 2
  75. assert h(D(), B()) == 2
  76. def test_competing_ambiguous():
  77. test_namespace = {}
  78. dispatch = partial(orig_dispatch, namespace=test_namespace)
  79. @dispatch(A, C)
  80. def f(x, y): # noqa:F811
  81. return 2
  82. with warns(AmbiguityWarning, test_stacklevel=False):
  83. @dispatch(C, A) # noqa:F811
  84. def f(x, y): # noqa:F811
  85. return 2
  86. assert f(A(), C()) == f(C(), A()) == 2
  87. # assert raises(Warning, lambda : f(C(), C()))
  88. def test_caching_correct_behavior():
  89. @dispatch(A)
  90. def f(x): # noqa:F811
  91. return 1
  92. assert f(C()) == 1
  93. @dispatch(C)
  94. def f(x): # noqa:F811
  95. return 2
  96. assert f(C()) == 2
  97. def test_union_types():
  98. @dispatch((A, C))
  99. def f(x): # noqa:F811
  100. return 1
  101. assert f(A()) == 1
  102. assert f(C()) == 1
  103. def test_namespaces():
  104. ns1 = {}
  105. ns2 = {}
  106. def foo(x):
  107. return 1
  108. foo1 = orig_dispatch(int, namespace=ns1)(foo)
  109. def foo(x):
  110. return 2
  111. foo2 = orig_dispatch(int, namespace=ns2)(foo)
  112. assert foo1(0) == 1
  113. assert foo2(0) == 2
# Disabled test: stacking two ``dispatch`` decorators on one function. The
# text below labels it "Fails", so it is kept inside a string literal where
# it is neither collected nor executed.
"""
Fails
def test_dispatch_on_dispatch():
@dispatch(A)
@dispatch(C)
def q(x): # noqa:F811
return 1
assert q(A()) == 1
assert q(C()) == 1
"""
  124. def test_methods():
  125. class Foo:
  126. @dispatch(float)
  127. def f(self, x): # noqa:F811
  128. return x - 1
  129. @dispatch(int) # noqa:F811
  130. def f(self, x): # noqa:F811
  131. return x + 1
  132. @dispatch(int)
  133. def g(self, x): # noqa:F811
  134. return x + 3
  135. foo = Foo()
  136. assert foo.f(1) == 2
  137. assert foo.f(1.0) == 0.0
  138. assert foo.g(1) == 4
  139. def test_methods_multiple_dispatch():
  140. class Foo:
  141. @dispatch(A, A)
  142. def f(x, y): # noqa:F811
  143. return 1
  144. @dispatch(A, C) # noqa:F811
  145. def f(x, y): # noqa:F811
  146. return 2
  147. foo = Foo()
  148. assert foo.f(A(), A()) == 1
  149. assert foo.f(A(), C()) == 2
  150. assert foo.f(C(), C()) == 2