# test_smtlib.py
  1. import contextlib
  2. import itertools
  3. import re
  4. import typing
  5. from enum import Enum
  6. from typing import Callable
  7. import sympy
  8. from sympy import Add, Implies, sqrt
  9. from sympy.core import Mul, Pow
  10. from sympy.core import (S, pi, symbols, Function, Rational, Integer,
  11. Symbol, Eq, Ne, Le, Lt, Gt, Ge)
  12. from sympy.functions import Piecewise, exp, sin, cos
  13. from sympy.printing.smtlib import smtlib_code
  14. from sympy.testing.pytest import raises, Failed
  15. x, y, z = symbols('x,y,z')
  16. class _W(Enum):
  17. DEFAULTING_TO_FLOAT = re.compile("Could not infer type of `.+`. Defaulting to float.", re.I)
  18. WILL_NOT_DECLARE = re.compile("Non-Symbol/Function `.+` will not be declared.", re.I)
  19. WILL_NOT_ASSERT = re.compile("Non-Boolean expression `.+` will not be asserted. Converting to SMTLib verbatim.", re.I)
  20. @contextlib.contextmanager
  21. def _check_warns(expected: typing.Iterable[_W]):
  22. warns: typing.List[str] = []
  23. log_warn = warns.append
  24. yield log_warn
  25. errors = []
  26. for i, (w, e) in enumerate(itertools.zip_longest(warns, expected)):
  27. if not e:
  28. errors += [f"[{i}] Received unexpected warning `{w}`."]
  29. elif not w:
  30. errors += [f"[{i}] Did not receive expected warning `{e.name}`."]
  31. elif not e.value.match(w):
  32. errors += [f"[{i}] Warning `{w}` does not match expected {e.name}."]
  33. if errors: raise Failed('\n'.join(errors))
  34. def test_Integer():
  35. with _check_warns([_W.WILL_NOT_ASSERT] * 2) as w:
  36. assert smtlib_code(Integer(67), log_warn=w) == "67"
  37. assert smtlib_code(Integer(-1), log_warn=w) == "-1"
  38. with _check_warns([]) as w:
  39. assert smtlib_code(Integer(67)) == "67"
  40. assert smtlib_code(Integer(-1)) == "-1"
  41. def test_Rational():
  42. with _check_warns([_W.WILL_NOT_ASSERT] * 4) as w:
  43. assert smtlib_code(Rational(3, 7), log_warn=w) == "(/ 3 7)"
  44. assert smtlib_code(Rational(18, 9), log_warn=w) == "2"
  45. assert smtlib_code(Rational(3, -7), log_warn=w) == "(/ -3 7)"
  46. assert smtlib_code(Rational(-3, -7), log_warn=w) == "(/ 3 7)"
  47. with _check_warns([_W.DEFAULTING_TO_FLOAT, _W.WILL_NOT_ASSERT] * 2) as w:
  48. assert smtlib_code(x + Rational(3, 7), auto_declare=False, log_warn=w) == "(+ (/ 3 7) x)"
  49. assert smtlib_code(Rational(3, 7) * x, log_warn=w) == "(declare-const x Real)\n" \
  50. "(* (/ 3 7) x)"
  51. def test_Relational():
  52. with _check_warns([_W.DEFAULTING_TO_FLOAT] * 12) as w:
  53. assert smtlib_code(Eq(x, y), auto_declare=False, log_warn=w) == "(assert (= x y))"
  54. assert smtlib_code(Ne(x, y), auto_declare=False, log_warn=w) == "(assert (not (= x y)))"
  55. assert smtlib_code(Le(x, y), auto_declare=False, log_warn=w) == "(assert (<= x y))"
  56. assert smtlib_code(Lt(x, y), auto_declare=False, log_warn=w) == "(assert (< x y))"
  57. assert smtlib_code(Gt(x, y), auto_declare=False, log_warn=w) == "(assert (> x y))"
  58. assert smtlib_code(Ge(x, y), auto_declare=False, log_warn=w) == "(assert (>= x y))"
  59. def test_Function():
  60. with _check_warns([_W.DEFAULTING_TO_FLOAT, _W.WILL_NOT_ASSERT]) as w:
  61. assert smtlib_code(sin(x) ** cos(x), auto_declare=False, log_warn=w) == "(pow (sin x) (cos x))"
  62. with _check_warns([_W.WILL_NOT_ASSERT]) as w:
  63. assert smtlib_code(
  64. abs(x),
  65. symbol_table={x: int, y: bool},
  66. known_types={int: "INTEGER_TYPE"},
  67. known_functions={sympy.Abs: "ABSOLUTE_VALUE_OF"},
  68. log_warn=w
  69. ) == "(declare-const x INTEGER_TYPE)\n" \
  70. "(ABSOLUTE_VALUE_OF x)"
  71. my_fun1 = Function('f1')
  72. with _check_warns([_W.WILL_NOT_ASSERT]) as w:
  73. assert smtlib_code(
  74. my_fun1(x),
  75. symbol_table={my_fun1: Callable[[bool], float]},
  76. log_warn=w
  77. ) == "(declare-const x Bool)\n" \
  78. "(declare-fun f1 (Bool) Real)\n" \
  79. "(f1 x)"
  80. with _check_warns([]) as w:
  81. assert smtlib_code(
  82. my_fun1(x),
  83. symbol_table={my_fun1: Callable[[bool], bool]},
  84. log_warn=w
  85. ) == "(declare-const x Bool)\n" \
  86. "(declare-fun f1 (Bool) Bool)\n" \
  87. "(assert (f1 x))"
  88. assert smtlib_code(
  89. Eq(my_fun1(x, z), y),
  90. symbol_table={my_fun1: Callable[[int, bool], bool]},
  91. log_warn=w
  92. ) == "(declare-const x Int)\n" \
  93. "(declare-const y Bool)\n" \
  94. "(declare-const z Bool)\n" \
  95. "(declare-fun f1 (Int Bool) Bool)\n" \
  96. "(assert (= (f1 x z) y))"
  97. assert smtlib_code(
  98. Eq(my_fun1(x, z), y),
  99. symbol_table={my_fun1: Callable[[int, bool], bool]},
  100. known_functions={my_fun1: "MY_KNOWN_FUN", Eq: '=='},
  101. log_warn=w
  102. ) == "(declare-const x Int)\n" \
  103. "(declare-const y Bool)\n" \
  104. "(declare-const z Bool)\n" \
  105. "(assert (== (MY_KNOWN_FUN x z) y))"
  106. with _check_warns([_W.DEFAULTING_TO_FLOAT] * 3) as w:
  107. assert smtlib_code(
  108. Eq(my_fun1(x, z), y),
  109. known_functions={my_fun1: "MY_KNOWN_FUN", Eq: '=='},
  110. log_warn=w
  111. ) == "(declare-const x Real)\n" \
  112. "(declare-const y Real)\n" \
  113. "(declare-const z Real)\n" \
  114. "(assert (== (MY_KNOWN_FUN x z) y))"
  115. def test_Pow():
  116. with _check_warns([_W.DEFAULTING_TO_FLOAT, _W.WILL_NOT_ASSERT]) as w:
  117. assert smtlib_code(x ** 3, auto_declare=False, log_warn=w) == "(pow x 3)"
  118. with _check_warns([_W.DEFAULTING_TO_FLOAT, _W.DEFAULTING_TO_FLOAT, _W.WILL_NOT_ASSERT]) as w:
  119. assert smtlib_code(x ** (y ** 3), auto_declare=False, log_warn=w) == "(pow x (pow y 3))"
  120. with _check_warns([_W.DEFAULTING_TO_FLOAT, _W.WILL_NOT_ASSERT]) as w:
  121. assert smtlib_code(x ** Rational(2, 3), auto_declare=False, log_warn=w) == '(pow x (/ 2 3))'
  122. a = Symbol('a', integer=True)
  123. b = Symbol('b', real=True)
  124. c = Symbol('c')
  125. def g(x): return 2 * x
  126. # if x=1, y=2, then expr=2.333...
  127. expr = 1 / (g(a) * 3.5) ** (a - b ** a) / (a ** 2 + b)
  128. with _check_warns([]) as w:
  129. assert smtlib_code(
  130. [
  131. Eq(a < 2, c),
  132. Eq(b > a, c),
  133. c & True,
  134. Eq(expr, 2 + Rational(1, 3))
  135. ],
  136. log_warn=w
  137. ) == '(declare-const a Int)\n' \
  138. '(declare-const b Real)\n' \
  139. '(declare-const c Bool)\n' \
  140. '(assert (= (< a 2) c))\n' \
  141. '(assert (= (> b a) c))\n' \
  142. '(assert c)\n' \
  143. '(assert (= ' \
  144. '(* (pow (* 7. a) (+ (pow b a) (* -1 a))) (pow (+ b (pow a 2)) -1)) ' \
  145. '(/ 7 3)' \
  146. '))'
  147. with _check_warns([_W.DEFAULTING_TO_FLOAT, _W.WILL_NOT_ASSERT]) as w:
  148. assert smtlib_code(
  149. Mul(-2, c, Pow(Mul(b, b, evaluate=False), -1, evaluate=False), evaluate=False),
  150. log_warn=w
  151. ) == '(declare-const b Real)\n' \
  152. '(declare-const c Real)\n' \
  153. '(* -2 c (pow (* b b) -1))'
  154. def test_basic_ops():
  155. with _check_warns([_W.DEFAULTING_TO_FLOAT, _W.DEFAULTING_TO_FLOAT, _W.WILL_NOT_ASSERT]) as w:
  156. assert smtlib_code(x * y, auto_declare=False, log_warn=w) == "(* x y)"
  157. with _check_warns([_W.DEFAULTING_TO_FLOAT, _W.DEFAULTING_TO_FLOAT, _W.WILL_NOT_ASSERT]) as w:
  158. assert smtlib_code(x + y, auto_declare=False, log_warn=w) == "(+ x y)"
  159. # with _check_warns([_SmtlibWarnings.DEFAULTING_TO_FLOAT, _SmtlibWarnings.DEFAULTING_TO_FLOAT, _SmtlibWarnings.WILL_NOT_ASSERT]) as w:
  160. # todo: implement re-write, currently does '(+ x (* -1 y))' instead
  161. # assert smtlib_code(x - y, auto_declare=False, log_warn=w) == "(- x y)"
  162. with _check_warns([_W.DEFAULTING_TO_FLOAT, _W.WILL_NOT_ASSERT]) as w:
  163. assert smtlib_code(-x, auto_declare=False, log_warn=w) == "(* -1 x)"
def test_quantifier_extensions():
    """Print a user-defined ``ForAll`` Boolean through its own ``_smtlib`` method."""
    from sympy.logic.boolalg import Boolean
    from sympy import Interval, Tuple, sympify

    # NOTE(review): the "start"/"end" marker comments below look like anchors
    # for including this example elsewhere (e.g. in the docs); keep them verbatim.
    # start For-all quantifier class example
    class ForAll(Boolean):
        # Emit (forall ( (sym Type [start, end]) ... ) body).  The type name
        # comes from the printer's symbol table mapped through its known
        # types; the bounds are an Interval built from the limits.
        def _smtlib(self, printer):
            bound_symbol_declarations = [
                printer._s_expr(sym.name, [
                    printer._known_types[printer.symbol_table[sym]],
                    Interval(start, end)
                ]) for sym, start, end in self.limits
            ]
            return printer._s_expr('forall', [
                printer._s_expr('', bound_symbol_declarations),
                self.function
            ])

        @property
        def bound_symbols(self):
            return {s for s, _, _ in self.limits}

        # Free symbols of the body minus the bound ones, compared by name.
        @property
        def free_symbols(self):
            bound_symbol_names = {s.name for s in self.bound_symbols}
            return {
                s for s in self.function.free_symbols
                if s.name not in bound_symbol_names
            }

        # Arguments: any number of (symbol, start, end) tuples as limits,
        # plus exactly one Boolean body.
        def __new__(cls, *args):
            limits = [sympify(a) for a in args if isinstance(a, tuple) or isinstance(a, Tuple)]
            function = [sympify(a) for a in args if isinstance(a, Boolean)]
            assert len(limits) + len(function) == len(args)
            assert len(function) == 1
            function = function[0]

            # Flatten a nested ForAll: the inner limits follow the outer ones.
            if isinstance(function, ForAll): return ForAll.__new__(
                ForAll, *(limits + function.limits), function.function
            )
            inst = Boolean.__new__(cls)
            inst._args = tuple(limits + [function])
            inst.limits = limits
            inst.function = function
            return inst

    # end For-All Quantifier class example

    f = Function('f')
    # Eq(f(x), f(x)) evaluates to true, so f does not appear in the output
    # and is not declared.
    with _check_warns([_W.DEFAULTING_TO_FLOAT]) as w:
        assert smtlib_code(
            ForAll((x, -42, +21), Eq(f(x), f(x))),
            symbol_table={f: Callable[[float], float]},
            log_warn=w
        ) == '(assert (forall ( (x Real [-42, 21])) true))'

    with _check_warns([_W.DEFAULTING_TO_FLOAT] * 2) as w:
        assert smtlib_code(
            ForAll(
                (x, -42, +21), (y, -100, 3),
                Implies(Eq(x, y), Eq(f(x), f(y)))
            ),
            symbol_table={f: Callable[[float], float]},
            log_warn=w
        ) == '(declare-fun f (Real) Real)\n' \
             '(assert (' \
             'forall ( (x Real [-42, 21]) (y Real [-100, 3])) ' \
             '(=> (= x y) (= (f x) (f y)))' \
             '))'

    a = Symbol('a', integer=True)
    b = Symbol('b', real=True)
    c = Symbol('c')

    # Nested ForAlls are flattened into a single quantifier; only the free
    # symbol c gets a declaration.
    with _check_warns([]) as w:
        assert smtlib_code(
            ForAll(
                (a, 2, 100), ForAll(
                    (b, 2, 100),
                    Implies(a < b, sqrt(a) < b) | c
                )),
            log_warn=w
        ) == '(declare-const c Bool)\n' \
             '(assert (forall ( (a Int [2, 100]) (b Real [2, 100])) ' \
             '(or c (=> (< a b) (< (pow a (/ 1 2)) b)))' \
             '))'
  240. def test_mix_number_mult_symbols():
  241. with _check_warns([_W.WILL_NOT_ASSERT]) as w:
  242. assert smtlib_code(
  243. 1 / pi,
  244. known_constants={pi: "MY_PI"},
  245. log_warn=w
  246. ) == '(pow MY_PI -1)'
  247. with _check_warns([_W.WILL_NOT_ASSERT]) as w:
  248. assert smtlib_code(
  249. [
  250. Eq(pi, 3.14, evaluate=False),
  251. 1 / pi,
  252. ],
  253. known_constants={pi: "MY_PI"},
  254. log_warn=w
  255. ) == '(assert (= MY_PI 3.14))\n' \
  256. '(pow MY_PI -1)'
  257. with _check_warns([_W.WILL_NOT_ASSERT]) as w:
  258. assert smtlib_code(
  259. Add(S.Zero, S.One, S.NegativeOne, S.Half,
  260. S.Exp1, S.Pi, S.GoldenRatio, evaluate=False),
  261. known_constants={
  262. S.Pi: 'p', S.GoldenRatio: 'g',
  263. S.Exp1: 'e'
  264. },
  265. known_functions={
  266. Add: 'plus',
  267. exp: 'exp'
  268. },
  269. precision=3,
  270. log_warn=w
  271. ) == '(plus 0 1 -1 (/ 1 2) (exp 1) p g)'
  272. with _check_warns([_W.WILL_NOT_ASSERT]) as w:
  273. assert smtlib_code(
  274. Add(S.Zero, S.One, S.NegativeOne, S.Half,
  275. S.Exp1, S.Pi, S.GoldenRatio, evaluate=False),
  276. known_constants={
  277. S.Pi: 'p'
  278. },
  279. known_functions={
  280. Add: 'plus',
  281. exp: 'exp'
  282. },
  283. precision=3,
  284. log_warn=w
  285. ) == '(plus 0 1 -1 (/ 1 2) (exp 1) p 1.62)'
  286. with _check_warns([_W.WILL_NOT_ASSERT]) as w:
  287. assert smtlib_code(
  288. Add(S.Zero, S.One, S.NegativeOne, S.Half,
  289. S.Exp1, S.Pi, S.GoldenRatio, evaluate=False),
  290. known_functions={Add: 'plus'},
  291. precision=3,
  292. log_warn=w
  293. ) == '(plus 0 1 -1 (/ 1 2) 2.72 3.14 1.62)'
  294. with _check_warns([_W.WILL_NOT_ASSERT]) as w:
  295. assert smtlib_code(
  296. Add(S.Zero, S.One, S.NegativeOne, S.Half,
  297. S.Exp1, S.Pi, S.GoldenRatio, evaluate=False),
  298. known_constants={S.Exp1: 'e'},
  299. known_functions={Add: 'plus'},
  300. precision=3,
  301. log_warn=w
  302. ) == '(plus 0 1 -1 (/ 1 2) e 3.14 1.62)'
  303. def test_boolean():
  304. with _check_warns([]) as w:
  305. assert smtlib_code(x & y, log_warn=w) == '(declare-const x Bool)\n' \
  306. '(declare-const y Bool)\n' \
  307. '(assert (and x y))'
  308. assert smtlib_code(x | y, log_warn=w) == '(declare-const x Bool)\n' \
  309. '(declare-const y Bool)\n' \
  310. '(assert (or x y))'
  311. assert smtlib_code(~x, log_warn=w) == '(declare-const x Bool)\n' \
  312. '(assert (not x))'
  313. assert smtlib_code(x & y & z, log_warn=w) == '(declare-const x Bool)\n' \
  314. '(declare-const y Bool)\n' \
  315. '(declare-const z Bool)\n' \
  316. '(assert (and x y z))'
  317. with _check_warns([_W.DEFAULTING_TO_FLOAT]) as w:
  318. assert smtlib_code((x & ~y) | (z > 3), log_warn=w) == '(declare-const x Bool)\n' \
  319. '(declare-const y Bool)\n' \
  320. '(declare-const z Real)\n' \
  321. '(assert (or (> z 3) (and x (not y))))'
  322. f = Function('f')
  323. g = Function('g')
  324. h = Function('h')
  325. with _check_warns([_W.DEFAULTING_TO_FLOAT]) as w:
  326. assert smtlib_code(
  327. [Gt(f(x), y),
  328. Lt(y, g(z))],
  329. symbol_table={
  330. f: Callable[[bool], int], g: Callable[[bool], int],
  331. }, log_warn=w
  332. ) == '(declare-const x Bool)\n' \
  333. '(declare-const y Real)\n' \
  334. '(declare-const z Bool)\n' \
  335. '(declare-fun f (Bool) Int)\n' \
  336. '(declare-fun g (Bool) Int)\n' \
  337. '(assert (> (f x) y))\n' \
  338. '(assert (< y (g z)))'
  339. with _check_warns([]) as w:
  340. assert smtlib_code(
  341. [Eq(f(x), y),
  342. Lt(y, g(z))],
  343. symbol_table={
  344. f: Callable[[bool], int], g: Callable[[bool], int],
  345. }, log_warn=w
  346. ) == '(declare-const x Bool)\n' \
  347. '(declare-const y Int)\n' \
  348. '(declare-const z Bool)\n' \
  349. '(declare-fun f (Bool) Int)\n' \
  350. '(declare-fun g (Bool) Int)\n' \
  351. '(assert (= (f x) y))\n' \
  352. '(assert (< y (g z)))'
  353. with _check_warns([]) as w:
  354. assert smtlib_code(
  355. [Eq(f(x), y),
  356. Eq(g(f(x)), z),
  357. Eq(h(g(f(x))), x)],
  358. symbol_table={
  359. f: Callable[[float], int],
  360. g: Callable[[int], bool],
  361. h: Callable[[bool], float]
  362. },
  363. log_warn=w
  364. ) == '(declare-const x Real)\n' \
  365. '(declare-const y Int)\n' \
  366. '(declare-const z Bool)\n' \
  367. '(declare-fun f (Real) Int)\n' \
  368. '(declare-fun g (Int) Bool)\n' \
  369. '(declare-fun h (Bool) Real)\n' \
  370. '(assert (= (f x) y))\n' \
  371. '(assert (= (g (f x)) z))\n' \
  372. '(assert (= (h (g (f x))) x))'
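# TODO: make smtlib_code support arrays.
# The placeholder below still uses julia_code and Julia-style expected output;
# it must be rewritten for smtlib_code before it is enabled.
# def test_containers():
#     assert julia_code([1, 2, 3, [4, 5, [6, 7]], 8, [9, 10], 11]) == \
#         "Any[1, 2, 3, Any[4, 5, Any[6, 7]], 8, Any[9, 10], 11]"
#     assert julia_code((1, 2, (3, 4))) == "(1, 2, (3, 4))"
#     assert julia_code([1]) == "Any[1]"
#     assert julia_code((1,)) == "(1,)"
#     assert julia_code(Tuple(*[1, 2, 3])) == "(1, 2, 3)"
#     assert julia_code((1, x * y, (3, x ** 2))) == "(1, x .* y, (3, x .^ 2))"
#     # scalar, matrix, empty matrix and empty list
#     assert julia_code((1, eye(3), Matrix(0, 0, []), [])) == "(1, [1 0 0;\n0 1 0;\n0 0 1], zeros(0, 0), Any[])"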
  381. # assert julia_code((1, x * y, (3, x ** 2))) == "(1, x .* y, (3, x .^ 2))"
  382. # # scalar, matrix, empty matrix and empty list
  383. # assert julia_code((1, eye(3), Matrix(0, 0, []), [])) == "(1, [1 0 0;\n0 1 0;\n0 0 1], zeros(0, 0), Any[])"
  384. def test_smtlib_piecewise():
  385. with _check_warns([_W.DEFAULTING_TO_FLOAT, _W.WILL_NOT_ASSERT]) as w:
  386. assert smtlib_code(
  387. Piecewise((x, x < 1),
  388. (x ** 2, True)),
  389. auto_declare=False,
  390. log_warn=w
  391. ) == '(ite (< x 1) x (pow x 2))'
  392. with _check_warns([_W.DEFAULTING_TO_FLOAT, _W.WILL_NOT_ASSERT]) as w:
  393. assert smtlib_code(
  394. Piecewise((x ** 2, x < 1),
  395. (x ** 3, x < 2),
  396. (x ** 4, x < 3),
  397. (x ** 5, True)),
  398. auto_declare=False,
  399. log_warn=w
  400. ) == '(ite (< x 1) (pow x 2) ' \
  401. '(ite (< x 2) (pow x 3) ' \
  402. '(ite (< x 3) (pow x 4) ' \
  403. '(pow x 5))))'
  404. # Check that Piecewise without a True (default) condition error
  405. expr = Piecewise((x, x < 1), (x ** 2, x > 1), (sin(x), x > 0))
  406. with _check_warns([_W.DEFAULTING_TO_FLOAT, _W.WILL_NOT_ASSERT]) as w:
  407. raises(AssertionError, lambda: smtlib_code(expr, log_warn=w))
  408. def test_smtlib_piecewise_times_const():
  409. pw = Piecewise((x, x < 1), (x ** 2, True))
  410. with _check_warns([_W.DEFAULTING_TO_FLOAT, _W.WILL_NOT_ASSERT]) as w:
  411. assert smtlib_code(2 * pw, log_warn=w) == '(declare-const x Real)\n(* 2 (ite (< x 1) x (pow x 2)))'
  412. with _check_warns([_W.DEFAULTING_TO_FLOAT, _W.WILL_NOT_ASSERT]) as w:
  413. assert smtlib_code(pw / x, log_warn=w) == '(declare-const x Real)\n(* (pow x -1) (ite (< x 1) x (pow x 2)))'
  414. with _check_warns([_W.DEFAULTING_TO_FLOAT, _W.DEFAULTING_TO_FLOAT, _W.WILL_NOT_ASSERT]) as w:
  415. assert smtlib_code(pw / (x * y), log_warn=w) == '(declare-const x Real)\n(declare-const y Real)\n(* (pow x -1) (pow y -1) (ite (< x 1) x (pow x 2)))'
  416. with _check_warns([_W.DEFAULTING_TO_FLOAT, _W.WILL_NOT_ASSERT]) as w:
  417. assert smtlib_code(pw / 3, log_warn=w) == '(declare-const x Real)\n(* (/ 1 3) (ite (< x 1) x (pow x 2)))'
  418. # todo: make smtlib_code support arrays / matrices ?
  419. # def test_smtlib_matrix_assign_to():
  420. # A = Matrix([[1, 2, 3]])
  421. # assert smtlib_code(A, assign_to='a') == "a = [1 2 3]"
  422. # A = Matrix([[1, 2], [3, 4]])
  423. # assert smtlib_code(A, assign_to='A') == "A = [1 2;\n3 4]"
  424. # def test_julia_matrix_1x1():
  425. # A = Matrix([[3]])
  426. # B = MatrixSymbol('B', 1, 1)
  427. # C = MatrixSymbol('C', 1, 2)
  428. # assert julia_code(A, assign_to=B) == "B = [3]"
  429. # raises(ValueError, lambda: julia_code(A, assign_to=C))
  430. # def test_julia_matrix_elements():
  431. # A = Matrix([[x, 2, x * y]])
  432. # assert julia_code(A[0, 0] ** 2 + A[0, 1] + A[0, 2]) == "x .^ 2 + x .* y + 2"
  433. # A = MatrixSymbol('AA', 1, 3)
  434. # assert julia_code(A) == "AA"
  435. # assert julia_code(A[0, 0] ** 2 + sin(A[0, 1]) + A[0, 2]) == \
  436. # "sin(AA[1,2]) + AA[1,1] .^ 2 + AA[1,3]"
  437. # assert julia_code(sum(A)) == "AA[1,1] + AA[1,2] + AA[1,3]"
  438. def test_smtlib_boolean():
  439. with _check_warns([]) as w:
  440. assert smtlib_code(True, auto_assert=False, log_warn=w) == 'true'
  441. assert smtlib_code(True, log_warn=w) == '(assert true)'
  442. assert smtlib_code(S.true, log_warn=w) == '(assert true)'
  443. assert smtlib_code(S.false, log_warn=w) == '(assert false)'
  444. assert smtlib_code(False, log_warn=w) == '(assert false)'
  445. assert smtlib_code(False, auto_assert=False, log_warn=w) == 'false'
  446. def test_not_supported():
  447. f = Function('f')
  448. with _check_warns([_W.DEFAULTING_TO_FLOAT, _W.WILL_NOT_ASSERT]) as w:
  449. raises(KeyError, lambda: smtlib_code(f(x).diff(x), symbol_table={f: Callable[[float], float]}, log_warn=w))
  450. with _check_warns([_W.WILL_NOT_ASSERT]) as w:
  451. raises(KeyError, lambda: smtlib_code(S.ComplexInfinity, log_warn=w))