test_symbolic.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494
  1. import pytest
  2. from numpy.f2py.symbolic import (
  3. Expr,
  4. Op,
  5. ArithOp,
  6. Language,
  7. as_symbol,
  8. as_number,
  9. as_string,
  10. as_array,
  11. as_complex,
  12. as_terms,
  13. as_factors,
  14. eliminate_quotes,
  15. insert_quotes,
  16. fromstring,
  17. as_expr,
  18. as_apply,
  19. as_numer_denom,
  20. as_ternary,
  21. as_ref,
  22. as_deref,
  23. normalize,
  24. as_eq,
  25. as_ne,
  26. as_lt,
  27. as_gt,
  28. as_le,
  29. as_ge,
  30. )
  31. from . import util
  32. class TestSymbolic(util.F2PyTest):
  33. def test_eliminate_quotes(self):
  34. def worker(s):
  35. r, d = eliminate_quotes(s)
  36. s1 = insert_quotes(r, d)
  37. assert s1 == s
  38. for kind in ["", "mykind_"]:
  39. worker(kind + '"1234" // "ABCD"')
  40. worker(kind + '"1234" // ' + kind + '"ABCD"')
  41. worker(kind + "\"1234\" // 'ABCD'")
  42. worker(kind + '"1234" // ' + kind + "'ABCD'")
  43. worker(kind + '"1\\"2\'AB\'34"')
  44. worker("a = " + kind + "'1\\'2\"AB\"34'")
  45. def test_sanity(self):
  46. x = as_symbol("x")
  47. y = as_symbol("y")
  48. z = as_symbol("z")
  49. assert x.op == Op.SYMBOL
  50. assert repr(x) == "Expr(Op.SYMBOL, 'x')"
  51. assert x == x
  52. assert x != y
  53. assert hash(x) is not None
  54. n = as_number(123)
  55. m = as_number(456)
  56. assert n.op == Op.INTEGER
  57. assert repr(n) == "Expr(Op.INTEGER, (123, 4))"
  58. assert n == n
  59. assert n != m
  60. assert hash(n) is not None
  61. fn = as_number(12.3)
  62. fm = as_number(45.6)
  63. assert fn.op == Op.REAL
  64. assert repr(fn) == "Expr(Op.REAL, (12.3, 4))"
  65. assert fn == fn
  66. assert fn != fm
  67. assert hash(fn) is not None
  68. c = as_complex(1, 2)
  69. c2 = as_complex(3, 4)
  70. assert c.op == Op.COMPLEX
  71. assert repr(c) == ("Expr(Op.COMPLEX, (Expr(Op.INTEGER, (1, 4)),"
  72. " Expr(Op.INTEGER, (2, 4))))")
  73. assert c == c
  74. assert c != c2
  75. assert hash(c) is not None
  76. s = as_string("'123'")
  77. s2 = as_string('"ABC"')
  78. assert s.op == Op.STRING
  79. assert repr(s) == "Expr(Op.STRING, (\"'123'\", 1))", repr(s)
  80. assert s == s
  81. assert s != s2
  82. a = as_array((n, m))
  83. b = as_array((n, ))
  84. assert a.op == Op.ARRAY
  85. assert repr(a) == ("Expr(Op.ARRAY, (Expr(Op.INTEGER, (123, 4)),"
  86. " Expr(Op.INTEGER, (456, 4))))")
  87. assert a == a
  88. assert a != b
  89. t = as_terms(x)
  90. u = as_terms(y)
  91. assert t.op == Op.TERMS
  92. assert repr(t) == "Expr(Op.TERMS, {Expr(Op.SYMBOL, 'x'): 1})"
  93. assert t == t
  94. assert t != u
  95. assert hash(t) is not None
  96. v = as_factors(x)
  97. w = as_factors(y)
  98. assert v.op == Op.FACTORS
  99. assert repr(v) == "Expr(Op.FACTORS, {Expr(Op.SYMBOL, 'x'): 1})"
  100. assert v == v
  101. assert w != v
  102. assert hash(v) is not None
  103. t = as_ternary(x, y, z)
  104. u = as_ternary(x, z, y)
  105. assert t.op == Op.TERNARY
  106. assert t == t
  107. assert t != u
  108. assert hash(t) is not None
  109. e = as_eq(x, y)
  110. f = as_lt(x, y)
  111. assert e.op == Op.RELATIONAL
  112. assert e == e
  113. assert e != f
  114. assert hash(e) is not None
  115. def test_tostring_fortran(self):
  116. x = as_symbol("x")
  117. y = as_symbol("y")
  118. z = as_symbol("z")
  119. n = as_number(123)
  120. m = as_number(456)
  121. a = as_array((n, m))
  122. c = as_complex(n, m)
  123. assert str(x) == "x"
  124. assert str(n) == "123"
  125. assert str(a) == "[123, 456]"
  126. assert str(c) == "(123, 456)"
  127. assert str(Expr(Op.TERMS, {x: 1})) == "x"
  128. assert str(Expr(Op.TERMS, {x: 2})) == "2 * x"
  129. assert str(Expr(Op.TERMS, {x: -1})) == "-x"
  130. assert str(Expr(Op.TERMS, {x: -2})) == "-2 * x"
  131. assert str(Expr(Op.TERMS, {x: 1, y: 1})) == "x + y"
  132. assert str(Expr(Op.TERMS, {x: -1, y: -1})) == "-x - y"
  133. assert str(Expr(Op.TERMS, {x: 2, y: 3})) == "2 * x + 3 * y"
  134. assert str(Expr(Op.TERMS, {x: -2, y: 3})) == "-2 * x + 3 * y"
  135. assert str(Expr(Op.TERMS, {x: 2, y: -3})) == "2 * x - 3 * y"
  136. assert str(Expr(Op.FACTORS, {x: 1})) == "x"
  137. assert str(Expr(Op.FACTORS, {x: 2})) == "x ** 2"
  138. assert str(Expr(Op.FACTORS, {x: -1})) == "x ** -1"
  139. assert str(Expr(Op.FACTORS, {x: -2})) == "x ** -2"
  140. assert str(Expr(Op.FACTORS, {x: 1, y: 1})) == "x * y"
  141. assert str(Expr(Op.FACTORS, {x: 2, y: 3})) == "x ** 2 * y ** 3"
  142. v = Expr(Op.FACTORS, {x: 2, Expr(Op.TERMS, {x: 1, y: 1}): 3})
  143. assert str(v) == "x ** 2 * (x + y) ** 3", str(v)
  144. v = Expr(Op.FACTORS, {x: 2, Expr(Op.FACTORS, {x: 1, y: 1}): 3})
  145. assert str(v) == "x ** 2 * (x * y) ** 3", str(v)
  146. assert str(Expr(Op.APPLY, ("f", (), {}))) == "f()"
  147. assert str(Expr(Op.APPLY, ("f", (x, ), {}))) == "f(x)"
  148. assert str(Expr(Op.APPLY, ("f", (x, y), {}))) == "f(x, y)"
  149. assert str(Expr(Op.INDEXING, ("f", x))) == "f[x]"
  150. assert str(as_ternary(x, y, z)) == "merge(y, z, x)"
  151. assert str(as_eq(x, y)) == "x .eq. y"
  152. assert str(as_ne(x, y)) == "x .ne. y"
  153. assert str(as_lt(x, y)) == "x .lt. y"
  154. assert str(as_le(x, y)) == "x .le. y"
  155. assert str(as_gt(x, y)) == "x .gt. y"
  156. assert str(as_ge(x, y)) == "x .ge. y"
  157. def test_tostring_c(self):
  158. language = Language.C
  159. x = as_symbol("x")
  160. y = as_symbol("y")
  161. z = as_symbol("z")
  162. n = as_number(123)
  163. assert Expr(Op.FACTORS, {x: 2}).tostring(language=language) == "x * x"
  164. assert (Expr(Op.FACTORS, {
  165. x + y: 2
  166. }).tostring(language=language) == "(x + y) * (x + y)")
  167. assert Expr(Op.FACTORS, {
  168. x: 12
  169. }).tostring(language=language) == "pow(x, 12)"
  170. assert as_apply(ArithOp.DIV, x,
  171. y).tostring(language=language) == "x / y"
  172. assert (as_apply(ArithOp.DIV, x,
  173. x + y).tostring(language=language) == "x / (x + y)")
  174. assert (as_apply(ArithOp.DIV, x - y, x +
  175. y).tostring(language=language) == "(x - y) / (x + y)")
  176. assert (x + (x - y) / (x + y) +
  177. n).tostring(language=language) == "123 + x + (x - y) / (x + y)"
  178. assert as_ternary(x, y, z).tostring(language=language) == "(x?y:z)"
  179. assert as_eq(x, y).tostring(language=language) == "x == y"
  180. assert as_ne(x, y).tostring(language=language) == "x != y"
  181. assert as_lt(x, y).tostring(language=language) == "x < y"
  182. assert as_le(x, y).tostring(language=language) == "x <= y"
  183. assert as_gt(x, y).tostring(language=language) == "x > y"
  184. assert as_ge(x, y).tostring(language=language) == "x >= y"
  185. def test_operations(self):
  186. x = as_symbol("x")
  187. y = as_symbol("y")
  188. z = as_symbol("z")
  189. assert x + x == Expr(Op.TERMS, {x: 2})
  190. assert x - x == Expr(Op.INTEGER, (0, 4))
  191. assert x + y == Expr(Op.TERMS, {x: 1, y: 1})
  192. assert x - y == Expr(Op.TERMS, {x: 1, y: -1})
  193. assert x * x == Expr(Op.FACTORS, {x: 2})
  194. assert x * y == Expr(Op.FACTORS, {x: 1, y: 1})
  195. assert +x == x
  196. assert -x == Expr(Op.TERMS, {x: -1}), repr(-x)
  197. assert 2 * x == Expr(Op.TERMS, {x: 2})
  198. assert 2 + x == Expr(Op.TERMS, {x: 1, as_number(1): 2})
  199. assert 2 * x + 3 * y == Expr(Op.TERMS, {x: 2, y: 3})
  200. assert (x + y) * 2 == Expr(Op.TERMS, {x: 2, y: 2})
  201. assert x**2 == Expr(Op.FACTORS, {x: 2})
  202. assert (x + y)**2 == Expr(
  203. Op.TERMS,
  204. {
  205. Expr(Op.FACTORS, {x: 2}): 1,
  206. Expr(Op.FACTORS, {y: 2}): 1,
  207. Expr(Op.FACTORS, {
  208. x: 1,
  209. y: 1
  210. }): 2,
  211. },
  212. )
  213. assert (x + y) * x == x**2 + x * y
  214. assert (x + y)**2 == x**2 + 2 * x * y + y**2
  215. assert (x + y)**2 + (x - y)**2 == 2 * x**2 + 2 * y**2
  216. assert (x + y) * z == x * z + y * z
  217. assert z * (x + y) == x * z + y * z
  218. assert (x / 2) == as_apply(ArithOp.DIV, x, as_number(2))
  219. assert (2 * x / 2) == x
  220. assert (3 * x / 2) == as_apply(ArithOp.DIV, 3 * x, as_number(2))
  221. assert (4 * x / 2) == 2 * x
  222. assert (5 * x / 2) == as_apply(ArithOp.DIV, 5 * x, as_number(2))
  223. assert (6 * x / 2) == 3 * x
  224. assert ((3 * 5) * x / 6) == as_apply(ArithOp.DIV, 5 * x, as_number(2))
  225. assert (30 * x**2 * y**4 / (24 * x**3 * y**3)) == as_apply(
  226. ArithOp.DIV, 5 * y, 4 * x)
  227. assert ((15 * x / 6) / 5) == as_apply(ArithOp.DIV, x,
  228. as_number(2)), (15 * x / 6) / 5
  229. assert (x / (5 / x)) == as_apply(ArithOp.DIV, x**2, as_number(5))
  230. assert (x / 2.0) == Expr(Op.TERMS, {x: 0.5})
  231. s = as_string('"ABC"')
  232. t = as_string('"123"')
  233. assert s // t == Expr(Op.STRING, ('"ABC123"', 1))
  234. assert s // x == Expr(Op.CONCAT, (s, x))
  235. assert x // s == Expr(Op.CONCAT, (x, s))
  236. c = as_complex(1.0, 2.0)
  237. assert -c == as_complex(-1.0, -2.0)
  238. assert c + c == as_expr((1 + 2j) * 2)
  239. assert c * c == as_expr((1 + 2j)**2)
  240. def test_substitute(self):
  241. x = as_symbol("x")
  242. y = as_symbol("y")
  243. z = as_symbol("z")
  244. a = as_array((x, y))
  245. assert x.substitute({x: y}) == y
  246. assert (x + y).substitute({x: z}) == y + z
  247. assert (x * y).substitute({x: z}) == y * z
  248. assert (x**4).substitute({x: z}) == z**4
  249. assert (x / y).substitute({x: z}) == z / y
  250. assert x.substitute({x: y + z}) == y + z
  251. assert a.substitute({x: y + z}) == as_array((y + z, y))
  252. assert as_ternary(x, y,
  253. z).substitute({x: y + z}) == as_ternary(y + z, y, z)
  254. assert as_eq(x, y).substitute({x: y + z}) == as_eq(y + z, y)
  255. def test_fromstring(self):
  256. x = as_symbol("x")
  257. y = as_symbol("y")
  258. z = as_symbol("z")
  259. f = as_symbol("f")
  260. s = as_string('"ABC"')
  261. t = as_string('"123"')
  262. a = as_array((x, y))
  263. assert fromstring("x") == x
  264. assert fromstring("+ x") == x
  265. assert fromstring("- x") == -x
  266. assert fromstring("x + y") == x + y
  267. assert fromstring("x + 1") == x + 1
  268. assert fromstring("x * y") == x * y
  269. assert fromstring("x * 2") == x * 2
  270. assert fromstring("x / y") == x / y
  271. assert fromstring("x ** 2", language=Language.Python) == x**2
  272. assert fromstring("x ** 2 ** 3", language=Language.Python) == x**2**3
  273. assert fromstring("(x + y) * z") == (x + y) * z
  274. assert fromstring("f(x)") == f(x)
  275. assert fromstring("f(x,y)") == f(x, y)
  276. assert fromstring("f[x]") == f[x]
  277. assert fromstring("f[x][y]") == f[x][y]
  278. assert fromstring('"ABC"') == s
  279. assert (normalize(
  280. fromstring('"ABC" // "123" ',
  281. language=Language.Fortran)) == s // t)
  282. assert fromstring('f("ABC")') == f(s)
  283. assert fromstring('MYSTRKIND_"ABC"') == as_string('"ABC"', "MYSTRKIND")
  284. assert fromstring("(/x, y/)") == a, fromstring("(/x, y/)")
  285. assert fromstring("f((/x, y/))") == f(a)
  286. assert fromstring("(/(x+y)*z/)") == as_array(((x + y) * z, ))
  287. assert fromstring("123") == as_number(123)
  288. assert fromstring("123_2") == as_number(123, 2)
  289. assert fromstring("123_myintkind") == as_number(123, "myintkind")
  290. assert fromstring("123.0") == as_number(123.0, 4)
  291. assert fromstring("123.0_4") == as_number(123.0, 4)
  292. assert fromstring("123.0_8") == as_number(123.0, 8)
  293. assert fromstring("123.0e0") == as_number(123.0, 4)
  294. assert fromstring("123.0d0") == as_number(123.0, 8)
  295. assert fromstring("123d0") == as_number(123.0, 8)
  296. assert fromstring("123e-0") == as_number(123.0, 4)
  297. assert fromstring("123d+0") == as_number(123.0, 8)
  298. assert fromstring("123.0_myrealkind") == as_number(123.0, "myrealkind")
  299. assert fromstring("3E4") == as_number(30000.0, 4)
  300. assert fromstring("(1, 2)") == as_complex(1, 2)
  301. assert fromstring("(1e2, PI)") == as_complex(as_number(100.0),
  302. as_symbol("PI"))
  303. assert fromstring("[1, 2]") == as_array((as_number(1), as_number(2)))
  304. assert fromstring("POINT(x, y=1)") == as_apply(as_symbol("POINT"),
  305. x,
  306. y=as_number(1))
  307. assert fromstring(
  308. 'PERSON(name="John", age=50, shape=(/34, 23/))') == as_apply(
  309. as_symbol("PERSON"),
  310. name=as_string('"John"'),
  311. age=as_number(50),
  312. shape=as_array((as_number(34), as_number(23))),
  313. )
  314. assert fromstring("x?y:z") == as_ternary(x, y, z)
  315. assert fromstring("*x") == as_deref(x)
  316. assert fromstring("**x") == as_deref(as_deref(x))
  317. assert fromstring("&x") == as_ref(x)
  318. assert fromstring("(*x) * (*y)") == as_deref(x) * as_deref(y)
  319. assert fromstring("(*x) * *y") == as_deref(x) * as_deref(y)
  320. assert fromstring("*x * *y") == as_deref(x) * as_deref(y)
  321. assert fromstring("*x**y") == as_deref(x) * as_deref(y)
  322. assert fromstring("x == y") == as_eq(x, y)
  323. assert fromstring("x != y") == as_ne(x, y)
  324. assert fromstring("x < y") == as_lt(x, y)
  325. assert fromstring("x > y") == as_gt(x, y)
  326. assert fromstring("x <= y") == as_le(x, y)
  327. assert fromstring("x >= y") == as_ge(x, y)
  328. assert fromstring("x .eq. y", language=Language.Fortran) == as_eq(x, y)
  329. assert fromstring("x .ne. y", language=Language.Fortran) == as_ne(x, y)
  330. assert fromstring("x .lt. y", language=Language.Fortran) == as_lt(x, y)
  331. assert fromstring("x .gt. y", language=Language.Fortran) == as_gt(x, y)
  332. assert fromstring("x .le. y", language=Language.Fortran) == as_le(x, y)
  333. assert fromstring("x .ge. y", language=Language.Fortran) == as_ge(x, y)
  334. def test_traverse(self):
  335. x = as_symbol("x")
  336. y = as_symbol("y")
  337. z = as_symbol("z")
  338. f = as_symbol("f")
  339. # Use traverse to substitute a symbol
  340. def replace_visit(s, r=z):
  341. if s == x:
  342. return r
  343. assert x.traverse(replace_visit) == z
  344. assert y.traverse(replace_visit) == y
  345. assert z.traverse(replace_visit) == z
  346. assert (f(y)).traverse(replace_visit) == f(y)
  347. assert (f(x)).traverse(replace_visit) == f(z)
  348. assert (f[y]).traverse(replace_visit) == f[y]
  349. assert (f[z]).traverse(replace_visit) == f[z]
  350. assert (x + y + z).traverse(replace_visit) == (2 * z + y)
  351. assert (x +
  352. f(y, x - z)).traverse(replace_visit) == (z +
  353. f(y, as_number(0)))
  354. assert as_eq(x, y).traverse(replace_visit) == as_eq(z, y)
  355. # Use traverse to collect symbols, method 1
  356. function_symbols = set()
  357. symbols = set()
  358. def collect_symbols(s):
  359. if s.op is Op.APPLY:
  360. oper = s.data[0]
  361. function_symbols.add(oper)
  362. if oper in symbols:
  363. symbols.remove(oper)
  364. elif s.op is Op.SYMBOL and s not in function_symbols:
  365. symbols.add(s)
  366. (x + f(y, x - z)).traverse(collect_symbols)
  367. assert function_symbols == {f}
  368. assert symbols == {x, y, z}
  369. # Use traverse to collect symbols, method 2
  370. def collect_symbols2(expr, symbols):
  371. if expr.op is Op.SYMBOL:
  372. symbols.add(expr)
  373. symbols = set()
  374. (x + f(y, x - z)).traverse(collect_symbols2, symbols)
  375. assert symbols == {x, y, z, f}
  376. # Use traverse to partially collect symbols
  377. def collect_symbols3(expr, symbols):
  378. if expr.op is Op.APPLY:
  379. # skip traversing function calls
  380. return expr
  381. if expr.op is Op.SYMBOL:
  382. symbols.add(expr)
  383. symbols = set()
  384. (x + f(y, x - z)).traverse(collect_symbols3, symbols)
  385. assert symbols == {x}
  386. def test_linear_solve(self):
  387. x = as_symbol("x")
  388. y = as_symbol("y")
  389. z = as_symbol("z")
  390. assert x.linear_solve(x) == (as_number(1), as_number(0))
  391. assert (x + 1).linear_solve(x) == (as_number(1), as_number(1))
  392. assert (2 * x).linear_solve(x) == (as_number(2), as_number(0))
  393. assert (2 * x + 3).linear_solve(x) == (as_number(2), as_number(3))
  394. assert as_number(3).linear_solve(x) == (as_number(0), as_number(3))
  395. assert y.linear_solve(x) == (as_number(0), y)
  396. assert (y * z).linear_solve(x) == (as_number(0), y * z)
  397. assert (x + y).linear_solve(x) == (as_number(1), y)
  398. assert (z * x + y).linear_solve(x) == (z, y)
  399. assert ((z + y) * x + y).linear_solve(x) == (z + y, y)
  400. assert (z * y * x + y).linear_solve(x) == (z * y, y)
  401. pytest.raises(RuntimeError, lambda: (x * x).linear_solve(x))
  402. def test_as_numer_denom(self):
  403. x = as_symbol("x")
  404. y = as_symbol("y")
  405. n = as_number(123)
  406. assert as_numer_denom(x) == (x, as_number(1))
  407. assert as_numer_denom(x / n) == (x, n)
  408. assert as_numer_denom(n / x) == (n, x)
  409. assert as_numer_denom(x / y) == (x, y)
  410. assert as_numer_denom(x * y) == (x * y, as_number(1))
  411. assert as_numer_denom(n + x / y) == (x + n * y, y)
  412. assert as_numer_denom(n + x / (y - x / n)) == (y * n**2, y * n - x)
  413. def test_polynomial_atoms(self):
  414. x = as_symbol("x")
  415. y = as_symbol("y")
  416. n = as_number(123)
  417. assert x.polynomial_atoms() == {x}
  418. assert n.polynomial_atoms() == set()
  419. assert (y[x]).polynomial_atoms() == {y[x]}
  420. assert (y(x)).polynomial_atoms() == {y(x)}
  421. assert (y(x) + x).polynomial_atoms() == {y(x), x}
  422. assert (y(x) * x[y]).polynomial_atoms() == {y(x), x[y]}
  423. assert (y(x)**x).polynomial_atoms() == {y(x)}