# test_llvmjit.py (5.2 KB)

# (Line-number gutter from the page this file was copied from: the original file spans lines 1-224.)
  1. from sympy.external import import_module
  2. from sympy.testing.pytest import raises
  3. import ctypes
  4. if import_module('llvmlite'):
  5. import sympy.printing.llvmjitcode as g
  6. else:
  7. disabled = True
  8. import sympy
  9. from sympy.abc import a, b, n
  10. # copied from numpy.isclose documentation
  11. def isclose(a, b):
  12. rtol = 1e-5
  13. atol = 1e-8
  14. return abs(a-b) <= atol + rtol*abs(b)
  15. def test_simple_expr():
  16. e = a + 1.0
  17. f = g.llvm_callable([a], e)
  18. res = float(e.subs({a: 4.0}).evalf())
  19. jit_res = f(4.0)
  20. assert isclose(jit_res, res)
  21. def test_two_arg():
  22. e = 4.0*a + b + 3.0
  23. f = g.llvm_callable([a, b], e)
  24. res = float(e.subs({a: 4.0, b: 3.0}).evalf())
  25. jit_res = f(4.0, 3.0)
  26. assert isclose(jit_res, res)
  27. def test_func():
  28. e = 4.0*sympy.exp(-a)
  29. f = g.llvm_callable([a], e)
  30. res = float(e.subs({a: 1.5}).evalf())
  31. jit_res = f(1.5)
  32. assert isclose(jit_res, res)
  33. def test_two_func():
  34. e = 4.0*sympy.exp(-a) + sympy.exp(b)
  35. f = g.llvm_callable([a, b], e)
  36. res = float(e.subs({a: 1.5, b: 2.0}).evalf())
  37. jit_res = f(1.5, 2.0)
  38. assert isclose(jit_res, res)
  39. def test_two_sqrt():
  40. e = 4.0*sympy.sqrt(a) + sympy.sqrt(b)
  41. f = g.llvm_callable([a, b], e)
  42. res = float(e.subs({a: 1.5, b: 2.0}).evalf())
  43. jit_res = f(1.5, 2.0)
  44. assert isclose(jit_res, res)
  45. def test_two_pow():
  46. e = a**1.5 + b**7
  47. f = g.llvm_callable([a, b], e)
  48. res = float(e.subs({a: 1.5, b: 2.0}).evalf())
  49. jit_res = f(1.5, 2.0)
  50. assert isclose(jit_res, res)
  51. def test_callback():
  52. e = a + 1.2
  53. f = g.llvm_callable([a], e, callback_type='scipy.integrate.test')
  54. m = ctypes.c_int(1)
  55. array_type = ctypes.c_double * 1
  56. inp = {a: 2.2}
  57. array = array_type(inp[a])
  58. jit_res = f(m, array)
  59. res = float(e.subs(inp).evalf())
  60. assert isclose(jit_res, res)
  61. def test_callback_cubature():
  62. e = a + 1.2
  63. f = g.llvm_callable([a], e, callback_type='cubature')
  64. m = ctypes.c_int(1)
  65. array_type = ctypes.c_double * 1
  66. inp = {a: 2.2}
  67. array = array_type(inp[a])
  68. out_array = array_type(0.0)
  69. jit_ret = f(m, array, None, m, out_array)
  70. assert jit_ret == 0
  71. res = float(e.subs(inp).evalf())
  72. assert isclose(out_array[0], res)
  73. def test_callback_two():
  74. e = 3*a*b
  75. f = g.llvm_callable([a, b], e, callback_type='scipy.integrate.test')
  76. m = ctypes.c_int(2)
  77. array_type = ctypes.c_double * 2
  78. inp = {a: 0.2, b: 1.7}
  79. array = array_type(inp[a], inp[b])
  80. jit_res = f(m, array)
  81. res = float(e.subs(inp).evalf())
  82. assert isclose(jit_res, res)
  83. def test_callback_alt_two():
  84. d = sympy.IndexedBase('d')
  85. e = 3*d[0]*d[1]
  86. f = g.llvm_callable([n, d], e, callback_type='scipy.integrate.test')
  87. m = ctypes.c_int(2)
  88. array_type = ctypes.c_double * 2
  89. inp = {d[0]: 0.2, d[1]: 1.7}
  90. array = array_type(inp[d[0]], inp[d[1]])
  91. jit_res = f(m, array)
  92. res = float(e.subs(inp).evalf())
  93. assert isclose(jit_res, res)
  94. def test_multiple_statements():
  95. # Match return from CSE
  96. e = [[(b, 4.0*a)], [b + 5]]
  97. f = g.llvm_callable([a], e)
  98. b_val = e[0][0][1].subs({a: 1.5})
  99. res = float(e[1][0].subs({b: b_val}).evalf())
  100. jit_res = f(1.5)
  101. assert isclose(jit_res, res)
  102. f_callback = g.llvm_callable([a], e, callback_type='scipy.integrate.test')
  103. m = ctypes.c_int(1)
  104. array_type = ctypes.c_double * 1
  105. array = array_type(1.5)
  106. jit_callback_res = f_callback(m, array)
  107. assert isclose(jit_callback_res, res)
  108. def test_cse():
  109. e = a*a + b*b + sympy.exp(-a*a - b*b)
  110. e2 = sympy.cse(e)
  111. f = g.llvm_callable([a, b], e2)
  112. res = float(e.subs({a: 2.3, b: 0.1}).evalf())
  113. jit_res = f(2.3, 0.1)
  114. assert isclose(jit_res, res)
  115. def eval_cse(e, sub_dict):
  116. tmp_dict = {}
  117. for tmp_name, tmp_expr in e[0]:
  118. e2 = tmp_expr.subs(sub_dict)
  119. e3 = e2.subs(tmp_dict)
  120. tmp_dict[tmp_name] = e3
  121. return [e.subs(sub_dict).subs(tmp_dict) for e in e[1]]
  122. def test_cse_multiple():
  123. e1 = a*a
  124. e2 = a*a + b*b
  125. e3 = sympy.cse([e1, e2])
  126. raises(NotImplementedError,
  127. lambda: g.llvm_callable([a, b], e3, callback_type='scipy.integrate'))
  128. f = g.llvm_callable([a, b], e3)
  129. jit_res = f(0.1, 1.5)
  130. assert len(jit_res) == 2
  131. res = eval_cse(e3, {a: 0.1, b: 1.5})
  132. assert isclose(res[0], jit_res[0])
  133. assert isclose(res[1], jit_res[1])
  134. def test_callback_cubature_multiple():
  135. e1 = a*a
  136. e2 = a*a + b*b
  137. e3 = sympy.cse([e1, e2, 4*e2])
  138. f = g.llvm_callable([a, b], e3, callback_type='cubature')
  139. # Number of input variables
  140. ndim = 2
  141. # Number of output expression values
  142. outdim = 3
  143. m = ctypes.c_int(ndim)
  144. fdim = ctypes.c_int(outdim)
  145. array_type = ctypes.c_double * ndim
  146. out_array_type = ctypes.c_double * outdim
  147. inp = {a: 0.2, b: 1.5}
  148. array = array_type(inp[a], inp[b])
  149. out_array = out_array_type()
  150. jit_ret = f(m, array, None, fdim, out_array)
  151. assert jit_ret == 0
  152. res = eval_cse(e3, inp)
  153. assert isclose(out_array[0], res[0])
  154. assert isclose(out_array[1], res[1])
  155. assert isclose(out_array[2], res[2])
  156. def test_symbol_not_found():
  157. e = a*a + b
  158. raises(LookupError, lambda: g.llvm_callable([a], e))
  159. def test_bad_callback():
  160. e = a
  161. raises(ValueError, lambda: g.llvm_callable([a], e, callback_type='bad_callback'))