# extending.py: calling numpy.random's PCG64 bit generator from Numba-compiled code.
  1. import numpy as np
  2. import numba as nb
  3. from numpy.random import PCG64
  4. from timeit import timeit
  5. bit_gen = PCG64()
  6. next_d = bit_gen.cffi.next_double
  7. state_addr = bit_gen.cffi.state_address
  8. def normals(n, state):
  9. out = np.empty(n)
  10. for i in range((n + 1) // 2):
  11. x1 = 2.0 * next_d(state) - 1.0
  12. x2 = 2.0 * next_d(state) - 1.0
  13. r2 = x1 * x1 + x2 * x2
  14. while r2 >= 1.0 or r2 == 0.0:
  15. x1 = 2.0 * next_d(state) - 1.0
  16. x2 = 2.0 * next_d(state) - 1.0
  17. r2 = x1 * x1 + x2 * x2
  18. f = np.sqrt(-2.0 * np.log(r2) / r2)
  19. out[2 * i] = f * x1
  20. if 2 * i + 1 < n:
  21. out[2 * i + 1] = f * x2
  22. return out
  23. # Compile using Numba
  24. normalsj = nb.jit(normals, nopython=True)
  25. # Must use state address not state with numba
  26. n = 10000
  27. def numbacall():
  28. return normalsj(n, state_addr)
  29. rg = np.random.Generator(PCG64())
  30. def numpycall():
  31. return rg.normal(size=n)
  32. # Check that the functions work
  33. r1 = numbacall()
  34. r2 = numpycall()
  35. assert r1.shape == (n,)
  36. assert r1.shape == r2.shape
  37. t1 = timeit(numbacall, number=1000)
  38. print(f'{t1:.2f} secs for {n} PCG64 (Numba/PCG64) gaussian randoms')
  39. t2 = timeit(numpycall, number=1000)
  40. print(f'{t2:.2f} secs for {n} PCG64 (NumPy/PCG64) gaussian randoms')
  41. # example 2
  42. next_u32 = bit_gen.ctypes.next_uint32
  43. ctypes_state = bit_gen.ctypes.state
  44. @nb.jit(nopython=True)
  45. def bounded_uint(lb, ub, state):
  46. mask = delta = ub - lb
  47. mask |= mask >> 1
  48. mask |= mask >> 2
  49. mask |= mask >> 4
  50. mask |= mask >> 8
  51. mask |= mask >> 16
  52. val = next_u32(state) & mask
  53. while val > delta:
  54. val = next_u32(state) & mask
  55. return lb + val
# Draw and print one value from the closed range [323, 2394691], passing the
# ctypes state pointer's integer value as the generator state.
print(bounded_uint(323, 2394691, ctypes_state.value))
  57. @nb.jit(nopython=True)
  58. def bounded_uints(lb, ub, n, state):
  59. out = np.empty(n, dtype=np.uint32)
  60. for i in range(n):
  61. out[i] = bounded_uint(lb, ub, state)
  62. bounded_uints(323, 2394691, 10000000, ctypes_state.value)