# test_fortran_parser.py: tests for parsing Fortran source with SymPyExpression (mode 'f').
  1. from sympy.testing.pytest import raises
  2. from sympy.parsing.sym_expr import SymPyExpression
  3. from sympy.external import import_module
  4. lfortran = import_module('lfortran')
  5. if lfortran:
  6. from sympy.codegen.ast import (Variable, IntBaseType, FloatBaseType, String,
  7. Return, FunctionDefinition, Assignment,
  8. Declaration, CodeBlock)
  9. from sympy.core import Integer, Float, Add
  10. from sympy.core.symbol import Symbol
  11. expr1 = SymPyExpression()
  12. expr2 = SymPyExpression()
  13. src = """\
  14. integer :: a, b, c, d
  15. real :: p, q, r, s
  16. """
  17. def test_sym_expr():
  18. src1 = (
  19. src +
  20. """\
  21. d = a + b -c
  22. """
  23. )
  24. expr3 = SymPyExpression(src,'f')
  25. expr4 = SymPyExpression(src1,'f')
  26. ls1 = expr3.return_expr()
  27. ls2 = expr4.return_expr()
  28. for i in range(0, 7):
  29. assert isinstance(ls1[i], Declaration)
  30. assert isinstance(ls2[i], Declaration)
  31. assert isinstance(ls2[8], Assignment)
  32. assert ls1[0] == Declaration(
  33. Variable(
  34. Symbol('a'),
  35. type = IntBaseType(String('integer')),
  36. value = Integer(0)
  37. )
  38. )
  39. assert ls1[1] == Declaration(
  40. Variable(
  41. Symbol('b'),
  42. type = IntBaseType(String('integer')),
  43. value = Integer(0)
  44. )
  45. )
  46. assert ls1[2] == Declaration(
  47. Variable(
  48. Symbol('c'),
  49. type = IntBaseType(String('integer')),
  50. value = Integer(0)
  51. )
  52. )
  53. assert ls1[3] == Declaration(
  54. Variable(
  55. Symbol('d'),
  56. type = IntBaseType(String('integer')),
  57. value = Integer(0)
  58. )
  59. )
  60. assert ls1[4] == Declaration(
  61. Variable(
  62. Symbol('p'),
  63. type = FloatBaseType(String('real')),
  64. value = Float(0.0)
  65. )
  66. )
  67. assert ls1[5] == Declaration(
  68. Variable(
  69. Symbol('q'),
  70. type = FloatBaseType(String('real')),
  71. value = Float(0.0)
  72. )
  73. )
  74. assert ls1[6] == Declaration(
  75. Variable(
  76. Symbol('r'),
  77. type = FloatBaseType(String('real')),
  78. value = Float(0.0)
  79. )
  80. )
  81. assert ls1[7] == Declaration(
  82. Variable(
  83. Symbol('s'),
  84. type = FloatBaseType(String('real')),
  85. value = Float(0.0)
  86. )
  87. )
  88. assert ls2[8] == Assignment(
  89. Variable(Symbol('d')),
  90. Symbol('a') + Symbol('b') - Symbol('c')
  91. )
  92. def test_assignment():
  93. src1 = (
  94. src +
  95. """\
  96. a = b
  97. c = d
  98. p = q
  99. r = s
  100. """
  101. )
  102. expr1.convert_to_expr(src1, 'f')
  103. ls1 = expr1.return_expr()
  104. for iter in range(0, 12):
  105. if iter < 8:
  106. assert isinstance(ls1[iter], Declaration)
  107. else:
  108. assert isinstance(ls1[iter], Assignment)
  109. assert ls1[8] == Assignment(
  110. Variable(Symbol('a')),
  111. Variable(Symbol('b'))
  112. )
  113. assert ls1[9] == Assignment(
  114. Variable(Symbol('c')),
  115. Variable(Symbol('d'))
  116. )
  117. assert ls1[10] == Assignment(
  118. Variable(Symbol('p')),
  119. Variable(Symbol('q'))
  120. )
  121. assert ls1[11] == Assignment(
  122. Variable(Symbol('r')),
  123. Variable(Symbol('s'))
  124. )
  125. def test_binop_add():
  126. src1 = (
  127. src +
  128. """\
  129. c = a + b
  130. d = a + c
  131. s = p + q + r
  132. """
  133. )
  134. expr1.convert_to_expr(src1, 'f')
  135. ls1 = expr1.return_expr()
  136. for iter in range(8, 11):
  137. assert isinstance(ls1[iter], Assignment)
  138. assert ls1[8] == Assignment(
  139. Variable(Symbol('c')),
  140. Symbol('a') + Symbol('b')
  141. )
  142. assert ls1[9] == Assignment(
  143. Variable(Symbol('d')),
  144. Symbol('a') + Symbol('c')
  145. )
  146. assert ls1[10] == Assignment(
  147. Variable(Symbol('s')),
  148. Symbol('p') + Symbol('q') + Symbol('r')
  149. )
  150. def test_binop_sub():
  151. src1 = (
  152. src +
  153. """\
  154. c = a - b
  155. d = a - c
  156. s = p - q - r
  157. """
  158. )
  159. expr1.convert_to_expr(src1, 'f')
  160. ls1 = expr1.return_expr()
  161. for iter in range(8, 11):
  162. assert isinstance(ls1[iter], Assignment)
  163. assert ls1[8] == Assignment(
  164. Variable(Symbol('c')),
  165. Symbol('a') - Symbol('b')
  166. )
  167. assert ls1[9] == Assignment(
  168. Variable(Symbol('d')),
  169. Symbol('a') - Symbol('c')
  170. )
  171. assert ls1[10] == Assignment(
  172. Variable(Symbol('s')),
  173. Symbol('p') - Symbol('q') - Symbol('r')
  174. )
  175. def test_binop_mul():
  176. src1 = (
  177. src +
  178. """\
  179. c = a * b
  180. d = a * c
  181. s = p * q * r
  182. """
  183. )
  184. expr1.convert_to_expr(src1, 'f')
  185. ls1 = expr1.return_expr()
  186. for iter in range(8, 11):
  187. assert isinstance(ls1[iter], Assignment)
  188. assert ls1[8] == Assignment(
  189. Variable(Symbol('c')),
  190. Symbol('a') * Symbol('b')
  191. )
  192. assert ls1[9] == Assignment(
  193. Variable(Symbol('d')),
  194. Symbol('a') * Symbol('c')
  195. )
  196. assert ls1[10] == Assignment(
  197. Variable(Symbol('s')),
  198. Symbol('p') * Symbol('q') * Symbol('r')
  199. )
  200. def test_binop_div():
  201. src1 = (
  202. src +
  203. """\
  204. c = a / b
  205. d = a / c
  206. s = p / q
  207. r = q / p
  208. """
  209. )
  210. expr1.convert_to_expr(src1, 'f')
  211. ls1 = expr1.return_expr()
  212. for iter in range(8, 12):
  213. assert isinstance(ls1[iter], Assignment)
  214. assert ls1[8] == Assignment(
  215. Variable(Symbol('c')),
  216. Symbol('a') / Symbol('b')
  217. )
  218. assert ls1[9] == Assignment(
  219. Variable(Symbol('d')),
  220. Symbol('a') / Symbol('c')
  221. )
  222. assert ls1[10] == Assignment(
  223. Variable(Symbol('s')),
  224. Symbol('p') / Symbol('q')
  225. )
  226. assert ls1[11] == Assignment(
  227. Variable(Symbol('r')),
  228. Symbol('q') / Symbol('p')
  229. )
  230. def test_mul_binop():
  231. src1 = (
  232. src +
  233. """\
  234. d = a + b - c
  235. c = a * b + d
  236. s = p * q / r
  237. r = p * s + q / p
  238. """
  239. )
  240. expr1.convert_to_expr(src1, 'f')
  241. ls1 = expr1.return_expr()
  242. for iter in range(8, 12):
  243. assert isinstance(ls1[iter], Assignment)
  244. assert ls1[8] == Assignment(
  245. Variable(Symbol('d')),
  246. Symbol('a') + Symbol('b') - Symbol('c')
  247. )
  248. assert ls1[9] == Assignment(
  249. Variable(Symbol('c')),
  250. Symbol('a') * Symbol('b') + Symbol('d')
  251. )
  252. assert ls1[10] == Assignment(
  253. Variable(Symbol('s')),
  254. Symbol('p') * Symbol('q') / Symbol('r')
  255. )
  256. assert ls1[11] == Assignment(
  257. Variable(Symbol('r')),
  258. Symbol('p') * Symbol('s') + Symbol('q') / Symbol('p')
  259. )
  260. def test_function():
  261. src1 = """\
  262. integer function f(a,b)
  263. integer :: x, y
  264. f = x + y
  265. end function
  266. """
  267. expr1.convert_to_expr(src1, 'f')
  268. for iter in expr1.return_expr():
  269. assert isinstance(iter, FunctionDefinition)
  270. assert iter == FunctionDefinition(
  271. IntBaseType(String('integer')),
  272. name=String('f'),
  273. parameters=(
  274. Variable(Symbol('a')),
  275. Variable(Symbol('b'))
  276. ),
  277. body=CodeBlock(
  278. Declaration(
  279. Variable(
  280. Symbol('a'),
  281. type=IntBaseType(String('integer')),
  282. value=Integer(0)
  283. )
  284. ),
  285. Declaration(
  286. Variable(
  287. Symbol('b'),
  288. type=IntBaseType(String('integer')),
  289. value=Integer(0)
  290. )
  291. ),
  292. Declaration(
  293. Variable(
  294. Symbol('f'),
  295. type=IntBaseType(String('integer')),
  296. value=Integer(0)
  297. )
  298. ),
  299. Declaration(
  300. Variable(
  301. Symbol('x'),
  302. type=IntBaseType(String('integer')),
  303. value=Integer(0)
  304. )
  305. ),
  306. Declaration(
  307. Variable(
  308. Symbol('y'),
  309. type=IntBaseType(String('integer')),
  310. value=Integer(0)
  311. )
  312. ),
  313. Assignment(
  314. Variable(Symbol('f')),
  315. Add(Symbol('x'), Symbol('y'))
  316. ),
  317. Return(Variable(Symbol('f')))
  318. )
  319. )
  320. def test_var():
  321. expr1.convert_to_expr(src, 'f')
  322. ls = expr1.return_expr()
  323. for iter in expr1.return_expr():
  324. assert isinstance(iter, Declaration)
  325. assert ls[0] == Declaration(
  326. Variable(
  327. Symbol('a'),
  328. type = IntBaseType(String('integer')),
  329. value = Integer(0)
  330. )
  331. )
  332. assert ls[1] == Declaration(
  333. Variable(
  334. Symbol('b'),
  335. type = IntBaseType(String('integer')),
  336. value = Integer(0)
  337. )
  338. )
  339. assert ls[2] == Declaration(
  340. Variable(
  341. Symbol('c'),
  342. type = IntBaseType(String('integer')),
  343. value = Integer(0)
  344. )
  345. )
  346. assert ls[3] == Declaration(
  347. Variable(
  348. Symbol('d'),
  349. type = IntBaseType(String('integer')),
  350. value = Integer(0)
  351. )
  352. )
  353. assert ls[4] == Declaration(
  354. Variable(
  355. Symbol('p'),
  356. type = FloatBaseType(String('real')),
  357. value = Float(0.0)
  358. )
  359. )
  360. assert ls[5] == Declaration(
  361. Variable(
  362. Symbol('q'),
  363. type = FloatBaseType(String('real')),
  364. value = Float(0.0)
  365. )
  366. )
  367. assert ls[6] == Declaration(
  368. Variable(
  369. Symbol('r'),
  370. type = FloatBaseType(String('real')),
  371. value = Float(0.0)
  372. )
  373. )
  374. assert ls[7] == Declaration(
  375. Variable(
  376. Symbol('s'),
  377. type = FloatBaseType(String('real')),
  378. value = Float(0.0)
  379. )
  380. )
  381. else:
  382. def test_raise():
  383. from sympy.parsing.fortran.fortran_parser import ASR2PyVisitor
  384. raises(ImportError, lambda: ASR2PyVisitor())
  385. raises(ImportError, lambda: SymPyExpression(' ', mode = 'f'))