# test_sym_expr.py (5.5 KB; 209 lines in the original listing)

  1. from sympy.parsing.sym_expr import SymPyExpression
  2. from sympy.testing.pytest import raises
  3. from sympy.external import import_module
  4. lfortran = import_module('lfortran')
  5. cin = import_module('clang.cindex', import_kwargs = {'fromlist': ['cindex']})
  6. if lfortran and cin:
  7. from sympy.codegen.ast import (Variable, IntBaseType, FloatBaseType, String,
  8. Declaration, FloatType)
  9. from sympy.core import Integer, Float
  10. from sympy.core.symbol import Symbol
  11. expr1 = SymPyExpression()
  12. src = """\
  13. integer :: a, b, c, d
  14. real :: p, q, r, s
  15. """
  16. def test_c_parse():
  17. src1 = """\
  18. int a, b = 4;
  19. float c, d = 2.4;
  20. """
  21. expr1.convert_to_expr(src1, 'c')
  22. ls = expr1.return_expr()
  23. assert ls[0] == Declaration(
  24. Variable(
  25. Symbol('a'),
  26. type=IntBaseType(String('intc'))
  27. )
  28. )
  29. assert ls[1] == Declaration(
  30. Variable(
  31. Symbol('b'),
  32. type=IntBaseType(String('intc')),
  33. value=Integer(4)
  34. )
  35. )
  36. assert ls[2] == Declaration(
  37. Variable(
  38. Symbol('c'),
  39. type=FloatType(
  40. String('float32'),
  41. nbits=Integer(32),
  42. nmant=Integer(23),
  43. nexp=Integer(8)
  44. )
  45. )
  46. )
  47. assert ls[3] == Declaration(
  48. Variable(
  49. Symbol('d'),
  50. type=FloatType(
  51. String('float32'),
  52. nbits=Integer(32),
  53. nmant=Integer(23),
  54. nexp=Integer(8)
  55. ),
  56. value=Float('2.3999999999999999', precision=53)
  57. )
  58. )
  59. def test_fortran_parse():
  60. expr = SymPyExpression(src, 'f')
  61. ls = expr.return_expr()
  62. assert ls[0] == Declaration(
  63. Variable(
  64. Symbol('a'),
  65. type=IntBaseType(String('integer')),
  66. value=Integer(0)
  67. )
  68. )
  69. assert ls[1] == Declaration(
  70. Variable(
  71. Symbol('b'),
  72. type=IntBaseType(String('integer')),
  73. value=Integer(0)
  74. )
  75. )
  76. assert ls[2] == Declaration(
  77. Variable(
  78. Symbol('c'),
  79. type=IntBaseType(String('integer')),
  80. value=Integer(0)
  81. )
  82. )
  83. assert ls[3] == Declaration(
  84. Variable(
  85. Symbol('d'),
  86. type=IntBaseType(String('integer')),
  87. value=Integer(0)
  88. )
  89. )
  90. assert ls[4] == Declaration(
  91. Variable(
  92. Symbol('p'),
  93. type=FloatBaseType(String('real')),
  94. value=Float('0.0', precision=53)
  95. )
  96. )
  97. assert ls[5] == Declaration(
  98. Variable(
  99. Symbol('q'),
  100. type=FloatBaseType(String('real')),
  101. value=Float('0.0', precision=53)
  102. )
  103. )
  104. assert ls[6] == Declaration(
  105. Variable(
  106. Symbol('r'),
  107. type=FloatBaseType(String('real')),
  108. value=Float('0.0', precision=53)
  109. )
  110. )
  111. assert ls[7] == Declaration(
  112. Variable(
  113. Symbol('s'),
  114. type=FloatBaseType(String('real')),
  115. value=Float('0.0', precision=53)
  116. )
  117. )
  118. def test_convert_py():
  119. src1 = (
  120. src +
  121. """\
  122. a = b + c
  123. s = p * q / r
  124. """
  125. )
  126. expr1.convert_to_expr(src1, 'f')
  127. exp_py = expr1.convert_to_python()
  128. assert exp_py == [
  129. 'a = 0',
  130. 'b = 0',
  131. 'c = 0',
  132. 'd = 0',
  133. 'p = 0.0',
  134. 'q = 0.0',
  135. 'r = 0.0',
  136. 's = 0.0',
  137. 'a = b + c',
  138. 's = p*q/r'
  139. ]
  140. def test_convert_fort():
  141. src1 = (
  142. src +
  143. """\
  144. a = b + c
  145. s = p * q / r
  146. """
  147. )
  148. expr1.convert_to_expr(src1, 'f')
  149. exp_fort = expr1.convert_to_fortran()
  150. assert exp_fort == [
  151. ' integer*4 a',
  152. ' integer*4 b',
  153. ' integer*4 c',
  154. ' integer*4 d',
  155. ' real*8 p',
  156. ' real*8 q',
  157. ' real*8 r',
  158. ' real*8 s',
  159. ' a = b + c',
  160. ' s = p*q/r'
  161. ]
  162. def test_convert_c():
  163. src1 = (
  164. src +
  165. """\
  166. a = b + c
  167. s = p * q / r
  168. """
  169. )
  170. expr1.convert_to_expr(src1, 'f')
  171. exp_c = expr1.convert_to_c()
  172. assert exp_c == [
  173. 'int a = 0',
  174. 'int b = 0',
  175. 'int c = 0',
  176. 'int d = 0',
  177. 'double p = 0.0',
  178. 'double q = 0.0',
  179. 'double r = 0.0',
  180. 'double s = 0.0',
  181. 'a = b + c;',
  182. 's = p*q/r;'
  183. ]
  184. def test_exceptions():
  185. src = 'int a;'
  186. raises(ValueError, lambda: SymPyExpression(src))
  187. raises(ValueError, lambda: SymPyExpression(mode = 'c'))
  188. raises(NotImplementedError, lambda: SymPyExpression(src, mode = 'd'))
  189. elif not lfortran and not cin:
  190. def test_raise():
  191. raises(ImportError, lambda: SymPyExpression('int a;', 'c'))
  192. raises(ImportError, lambda: SymPyExpression('integer :: a', 'f'))