test__threadsafety.py 1.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051
  1. import threading
  2. import time
  3. import traceback
  4. from numpy.testing import assert_
  5. from pytest import raises as assert_raises
  6. from scipy._lib._threadsafety import ReentrancyLock, non_reentrant, ReentrancyError
  7. def test_parallel_threads():
  8. # Check that ReentrancyLock serializes work in parallel threads.
  9. #
  10. # The test is not fully deterministic, and may succeed falsely if
  11. # the timings go wrong.
  12. lock = ReentrancyLock("failure")
  13. failflag = [False]
  14. exceptions_raised = []
  15. def worker(k):
  16. try:
  17. with lock:
  18. assert_(not failflag[0])
  19. failflag[0] = True
  20. time.sleep(0.1 * k)
  21. assert_(failflag[0])
  22. failflag[0] = False
  23. except Exception:
  24. exceptions_raised.append(traceback.format_exc(2))
  25. threads = [threading.Thread(target=lambda k=k: worker(k))
  26. for k in range(3)]
  27. for t in threads:
  28. t.start()
  29. for t in threads:
  30. t.join()
  31. exceptions_raised = "\n".join(exceptions_raised)
  32. assert_(not exceptions_raised, exceptions_raised)
  33. def test_reentering():
  34. # Check that ReentrancyLock prevents re-entering from the same thread.
  35. @non_reentrant()
  36. def func(x):
  37. return func(x)
  38. assert_raises(ReentrancyError, func, 0)