test_ccallback.py 5.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197
  1. from numpy.testing import assert_equal, assert_
  2. from pytest import raises as assert_raises
  3. import time
  4. import pytest
  5. import ctypes
  6. import threading
  7. from scipy._lib import _ccallback_c as _test_ccallback_cython
  8. from scipy._lib import _test_ccallback
  9. from scipy._lib._ccallback import LowLevelCallable
  10. try:
  11. import cffi
  12. HAVE_CFFI = True
  13. except ImportError:
  14. HAVE_CFFI = False
  15. ERROR_VALUE = 2.0
  16. def callback_python(a, user_data=None):
  17. if a == ERROR_VALUE:
  18. raise ValueError("bad value")
  19. if user_data is None:
  20. return a + 1
  21. else:
  22. return a + user_data
  23. def _get_cffi_func(base, signature):
  24. if not HAVE_CFFI:
  25. pytest.skip("cffi not installed")
  26. # Get function address
  27. voidp = ctypes.cast(base, ctypes.c_void_p)
  28. address = voidp.value
  29. # Create corresponding cffi handle
  30. ffi = cffi.FFI()
  31. func = ffi.cast(signature, address)
  32. return func
  33. def _get_ctypes_data():
  34. value = ctypes.c_double(2.0)
  35. return ctypes.cast(ctypes.pointer(value), ctypes.c_voidp)
  36. def _get_cffi_data():
  37. if not HAVE_CFFI:
  38. pytest.skip("cffi not installed")
  39. ffi = cffi.FFI()
  40. return ffi.new('double *', 2.0)
  41. CALLERS = {
  42. 'simple': _test_ccallback.test_call_simple,
  43. 'nodata': _test_ccallback.test_call_nodata,
  44. 'nonlocal': _test_ccallback.test_call_nonlocal,
  45. 'cython': _test_ccallback_cython.test_call_cython,
  46. }
  47. # These functions have signatures known to the callers
  48. FUNCS = {
  49. 'python': lambda: callback_python,
  50. 'capsule': lambda: _test_ccallback.test_get_plus1_capsule(),
  51. 'cython': lambda: LowLevelCallable.from_cython(_test_ccallback_cython, "plus1_cython"),
  52. 'ctypes': lambda: _test_ccallback_cython.plus1_ctypes,
  53. 'cffi': lambda: _get_cffi_func(_test_ccallback_cython.plus1_ctypes,
  54. 'double (*)(double, int *, void *)'),
  55. 'capsule_b': lambda: _test_ccallback.test_get_plus1b_capsule(),
  56. 'cython_b': lambda: LowLevelCallable.from_cython(_test_ccallback_cython, "plus1b_cython"),
  57. 'ctypes_b': lambda: _test_ccallback_cython.plus1b_ctypes,
  58. 'cffi_b': lambda: _get_cffi_func(_test_ccallback_cython.plus1b_ctypes,
  59. 'double (*)(double, double, int *, void *)'),
  60. }
  61. # These functions have signatures the callers don't know
  62. BAD_FUNCS = {
  63. 'capsule_bc': lambda: _test_ccallback.test_get_plus1bc_capsule(),
  64. 'cython_bc': lambda: LowLevelCallable.from_cython(_test_ccallback_cython, "plus1bc_cython"),
  65. 'ctypes_bc': lambda: _test_ccallback_cython.plus1bc_ctypes,
  66. 'cffi_bc': lambda: _get_cffi_func(_test_ccallback_cython.plus1bc_ctypes,
  67. 'double (*)(double, double, double, int *, void *)'),
  68. }
  69. USER_DATAS = {
  70. 'ctypes': _get_ctypes_data,
  71. 'cffi': _get_cffi_data,
  72. 'capsule': _test_ccallback.test_get_data_capsule,
  73. }
  74. def test_callbacks():
  75. def check(caller, func, user_data):
  76. caller = CALLERS[caller]
  77. func = FUNCS[func]()
  78. user_data = USER_DATAS[user_data]()
  79. if func is callback_python:
  80. func2 = lambda x: func(x, 2.0)
  81. else:
  82. func2 = LowLevelCallable(func, user_data)
  83. func = LowLevelCallable(func)
  84. # Test basic call
  85. assert_equal(caller(func, 1.0), 2.0)
  86. # Test 'bad' value resulting to an error
  87. assert_raises(ValueError, caller, func, ERROR_VALUE)
  88. # Test passing in user_data
  89. assert_equal(caller(func2, 1.0), 3.0)
  90. for caller in sorted(CALLERS.keys()):
  91. for func in sorted(FUNCS.keys()):
  92. for user_data in sorted(USER_DATAS.keys()):
  93. check(caller, func, user_data)
  94. def test_bad_callbacks():
  95. def check(caller, func, user_data):
  96. caller = CALLERS[caller]
  97. user_data = USER_DATAS[user_data]()
  98. func = BAD_FUNCS[func]()
  99. if func is callback_python:
  100. func2 = lambda x: func(x, 2.0)
  101. else:
  102. func2 = LowLevelCallable(func, user_data)
  103. func = LowLevelCallable(func)
  104. # Test that basic call fails
  105. assert_raises(ValueError, caller, LowLevelCallable(func), 1.0)
  106. # Test that passing in user_data also fails
  107. assert_raises(ValueError, caller, func2, 1.0)
  108. # Test error message
  109. llfunc = LowLevelCallable(func)
  110. try:
  111. caller(llfunc, 1.0)
  112. except ValueError as err:
  113. msg = str(err)
  114. assert_(llfunc.signature in msg, msg)
  115. assert_('double (double, double, int *, void *)' in msg, msg)
  116. for caller in sorted(CALLERS.keys()):
  117. for func in sorted(BAD_FUNCS.keys()):
  118. for user_data in sorted(USER_DATAS.keys()):
  119. check(caller, func, user_data)
  120. def test_signature_override():
  121. caller = _test_ccallback.test_call_simple
  122. func = _test_ccallback.test_get_plus1_capsule()
  123. llcallable = LowLevelCallable(func, signature="bad signature")
  124. assert_equal(llcallable.signature, "bad signature")
  125. assert_raises(ValueError, caller, llcallable, 3)
  126. llcallable = LowLevelCallable(func, signature="double (double, int *, void *)")
  127. assert_equal(llcallable.signature, "double (double, int *, void *)")
  128. assert_equal(caller(llcallable, 3), 4)
  129. def test_threadsafety():
  130. def callback(a, caller):
  131. if a <= 0:
  132. return 1
  133. else:
  134. res = caller(lambda x: callback(x, caller), a - 1)
  135. return 2*res
  136. def check(caller):
  137. caller = CALLERS[caller]
  138. results = []
  139. count = 10
  140. def run():
  141. time.sleep(0.01)
  142. r = caller(lambda x: callback(x, caller), count)
  143. results.append(r)
  144. threads = [threading.Thread(target=run) for j in range(20)]
  145. for thread in threads:
  146. thread.start()
  147. for thread in threads:
  148. thread.join()
  149. assert_equal(results, [2.0**count]*len(threads))
  150. for caller in CALLERS.keys():
  151. check(caller)