test_inference.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315
  1. """For more tests on satisfiability, see test_dimacs"""
  2. from sympy.assumptions.ask import Q
  3. from sympy.core.symbol import symbols
  4. from sympy.logic.boolalg import And, Implies, Equivalent, true, false
  5. from sympy.logic.inference import literal_symbol, \
  6. pl_true, satisfiable, valid, entails, PropKB
  7. from sympy.logic.algorithms.dpll import dpll, dpll_satisfiable, \
  8. find_pure_symbol, find_unit_clause, unit_propagate, \
  9. find_pure_symbol_int_repr, find_unit_clause_int_repr, \
  10. unit_propagate_int_repr
  11. from sympy.logic.algorithms.dpll2 import dpll_satisfiable as dpll2_satisfiable
  12. from sympy.testing.pytest import raises
  13. def test_literal():
  14. A, B = symbols('A,B')
  15. assert literal_symbol(True) is True
  16. assert literal_symbol(False) is False
  17. assert literal_symbol(A) is A
  18. assert literal_symbol(~A) is A
  19. def test_find_pure_symbol():
  20. A, B, C = symbols('A,B,C')
  21. assert find_pure_symbol([A], [A]) == (A, True)
  22. assert find_pure_symbol([A, B], [~A | B, ~B | A]) == (None, None)
  23. assert find_pure_symbol([A, B, C], [ A | ~B, ~B | ~C, C | A]) == (A, True)
  24. assert find_pure_symbol([A, B, C], [~A | B, B | ~C, C | A]) == (B, True)
  25. assert find_pure_symbol([A, B, C], [~A | ~B, ~B | ~C, C | A]) == (B, False)
  26. assert find_pure_symbol(
  27. [A, B, C], [~A | B, ~B | ~C, C | A]) == (None, None)
  28. def test_find_pure_symbol_int_repr():
  29. assert find_pure_symbol_int_repr([1], [{1}]) == (1, True)
  30. assert find_pure_symbol_int_repr([1, 2],
  31. [{-1, 2}, {-2, 1}]) == (None, None)
  32. assert find_pure_symbol_int_repr([1, 2, 3],
  33. [{1, -2}, {-2, -3}, {3, 1}]) == (1, True)
  34. assert find_pure_symbol_int_repr([1, 2, 3],
  35. [{-1, 2}, {2, -3}, {3, 1}]) == (2, True)
  36. assert find_pure_symbol_int_repr([1, 2, 3],
  37. [{-1, -2}, {-2, -3}, {3, 1}]) == (2, False)
  38. assert find_pure_symbol_int_repr([1, 2, 3],
  39. [{-1, 2}, {-2, -3}, {3, 1}]) == (None, None)
  40. def test_unit_clause():
  41. A, B, C = symbols('A,B,C')
  42. assert find_unit_clause([A], {}) == (A, True)
  43. assert find_unit_clause([A, ~A], {}) == (A, True) # Wrong ??
  44. assert find_unit_clause([A | B], {A: True}) == (B, True)
  45. assert find_unit_clause([A | B], {B: True}) == (A, True)
  46. assert find_unit_clause(
  47. [A | B | C, B | ~C, A | ~B], {A: True}) == (B, False)
  48. assert find_unit_clause([A | B | C, B | ~C, A | B], {A: True}) == (B, True)
  49. assert find_unit_clause([A | B | C, B | ~C, A ], {}) == (A, True)
  50. def test_unit_clause_int_repr():
  51. assert find_unit_clause_int_repr(map(set, [[1]]), {}) == (1, True)
  52. assert find_unit_clause_int_repr(map(set, [[1], [-1]]), {}) == (1, True)
  53. assert find_unit_clause_int_repr([{1, 2}], {1: True}) == (2, True)
  54. assert find_unit_clause_int_repr([{1, 2}], {2: True}) == (1, True)
  55. assert find_unit_clause_int_repr(map(set,
  56. [[1, 2, 3], [2, -3], [1, -2]]), {1: True}) == (2, False)
  57. assert find_unit_clause_int_repr(map(set,
  58. [[1, 2, 3], [3, -3], [1, 2]]), {1: True}) == (2, True)
  59. A, B, C = symbols('A,B,C')
  60. assert find_unit_clause([A | B | C, B | ~C, A ], {}) == (A, True)
  61. def test_unit_propagate():
  62. A, B, C = symbols('A,B,C')
  63. assert unit_propagate([A | B], A) == []
  64. assert unit_propagate([A | B, ~A | C, ~C | B, A], A) == [C, ~C | B, A]
  65. def test_unit_propagate_int_repr():
  66. assert unit_propagate_int_repr([{1, 2}], 1) == []
  67. assert unit_propagate_int_repr(map(set,
  68. [[1, 2], [-1, 3], [-3, 2], [1]]), 1) == [{3}, {-3, 2}]
  69. def test_dpll():
  70. """This is also tested in test_dimacs"""
  71. A, B, C = symbols('A,B,C')
  72. assert dpll([A | B], [A, B], {A: True, B: True}) == {A: True, B: True}
  73. def test_dpll_satisfiable():
  74. A, B, C = symbols('A,B,C')
  75. assert dpll_satisfiable( A & ~A ) is False
  76. assert dpll_satisfiable( A & ~B ) == {A: True, B: False}
  77. assert dpll_satisfiable(
  78. A | B ) in ({A: True}, {B: True}, {A: True, B: True})
  79. assert dpll_satisfiable(
  80. (~A | B) & (~B | A) ) in ({A: True, B: True}, {A: False, B: False})
  81. assert dpll_satisfiable( (A | B) & (~B | C) ) in ({A: True, B: False},
  82. {A: True, C: True}, {B: True, C: True})
  83. assert dpll_satisfiable( A & B & C ) == {A: True, B: True, C: True}
  84. assert dpll_satisfiable( (A | B) & (A >> B) ) == {B: True}
  85. assert dpll_satisfiable( Equivalent(A, B) & A ) == {A: True, B: True}
  86. assert dpll_satisfiable( Equivalent(A, B) & ~A ) == {A: False, B: False}
  87. def test_dpll2_satisfiable():
  88. A, B, C = symbols('A,B,C')
  89. assert dpll2_satisfiable( A & ~A ) is False
  90. assert dpll2_satisfiable( A & ~B ) == {A: True, B: False}
  91. assert dpll2_satisfiable(
  92. A | B ) in ({A: True}, {B: True}, {A: True, B: True})
  93. assert dpll2_satisfiable(
  94. (~A | B) & (~B | A) ) in ({A: True, B: True}, {A: False, B: False})
  95. assert dpll2_satisfiable( (A | B) & (~B | C) ) in ({A: True, B: False, C: True},
  96. {A: True, B: True, C: True})
  97. assert dpll2_satisfiable( A & B & C ) == {A: True, B: True, C: True}
  98. assert dpll2_satisfiable( (A | B) & (A >> B) ) in ({B: True, A: False},
  99. {B: True, A: True})
  100. assert dpll2_satisfiable( Equivalent(A, B) & A ) == {A: True, B: True}
  101. assert dpll2_satisfiable( Equivalent(A, B) & ~A ) == {A: False, B: False}
  102. def test_minisat22_satisfiable():
  103. A, B, C = symbols('A,B,C')
  104. minisat22_satisfiable = lambda expr: satisfiable(expr, algorithm="minisat22")
  105. assert minisat22_satisfiable( A & ~A ) is False
  106. assert minisat22_satisfiable( A & ~B ) == {A: True, B: False}
  107. assert minisat22_satisfiable(
  108. A | B ) in ({A: True}, {B: False}, {A: False, B: True}, {A: True, B: True}, {A: True, B: False})
  109. assert minisat22_satisfiable(
  110. (~A | B) & (~B | A) ) in ({A: True, B: True}, {A: False, B: False})
  111. assert minisat22_satisfiable( (A | B) & (~B | C) ) in ({A: True, B: False, C: True},
  112. {A: True, B: True, C: True}, {A: False, B: True, C: True}, {A: True, B: False, C: False})
  113. assert minisat22_satisfiable( A & B & C ) == {A: True, B: True, C: True}
  114. assert minisat22_satisfiable( (A | B) & (A >> B) ) in ({B: True, A: False},
  115. {B: True, A: True})
  116. assert minisat22_satisfiable( Equivalent(A, B) & A ) == {A: True, B: True}
  117. assert minisat22_satisfiable( Equivalent(A, B) & ~A ) == {A: False, B: False}
  118. def test_minisat22_minimal_satisfiable():
  119. A, B, C = symbols('A,B,C')
  120. minisat22_satisfiable = lambda expr, minimal=True: satisfiable(expr, algorithm="minisat22", minimal=True)
  121. assert minisat22_satisfiable( A & ~A ) is False
  122. assert minisat22_satisfiable( A & ~B ) == {A: True, B: False}
  123. assert minisat22_satisfiable(
  124. A | B ) in ({A: True}, {B: False}, {A: False, B: True}, {A: True, B: True}, {A: True, B: False})
  125. assert minisat22_satisfiable(
  126. (~A | B) & (~B | A) ) in ({A: True, B: True}, {A: False, B: False})
  127. assert minisat22_satisfiable( (A | B) & (~B | C) ) in ({A: True, B: False, C: True},
  128. {A: True, B: True, C: True}, {A: False, B: True, C: True}, {A: True, B: False, C: False})
  129. assert minisat22_satisfiable( A & B & C ) == {A: True, B: True, C: True}
  130. assert minisat22_satisfiable( (A | B) & (A >> B) ) in ({B: True, A: False},
  131. {B: True, A: True})
  132. assert minisat22_satisfiable( Equivalent(A, B) & A ) == {A: True, B: True}
  133. assert minisat22_satisfiable( Equivalent(A, B) & ~A ) == {A: False, B: False}
  134. g = satisfiable((A | B | C),algorithm="minisat22",minimal=True,all_models=True)
  135. sol = next(g)
  136. first_solution = {key for key, value in sol.items() if value}
  137. sol=next(g)
  138. second_solution = {key for key, value in sol.items() if value}
  139. sol=next(g)
  140. third_solution = {key for key, value in sol.items() if value}
  141. assert not first_solution <= second_solution
  142. assert not second_solution <= third_solution
  143. assert not first_solution <= third_solution
  144. def test_satisfiable():
  145. A, B, C = symbols('A,B,C')
  146. assert satisfiable(A & (A >> B) & ~B) is False
  147. def test_valid():
  148. A, B, C = symbols('A,B,C')
  149. assert valid(A >> (B >> A)) is True
  150. assert valid((A >> (B >> C)) >> ((A >> B) >> (A >> C))) is True
  151. assert valid((~B >> ~A) >> (A >> B)) is True
  152. assert valid(A | B | C) is False
  153. assert valid(A >> B) is False
  154. def test_pl_true():
  155. A, B, C = symbols('A,B,C')
  156. assert pl_true(True) is True
  157. assert pl_true( A & B, {A: True, B: True}) is True
  158. assert pl_true( A | B, {A: True}) is True
  159. assert pl_true( A | B, {B: True}) is True
  160. assert pl_true( A | B, {A: None, B: True}) is True
  161. assert pl_true( A >> B, {A: False}) is True
  162. assert pl_true( A | B | ~C, {A: False, B: True, C: True}) is True
  163. assert pl_true(Equivalent(A, B), {A: False, B: False}) is True
  164. # test for false
  165. assert pl_true(False) is False
  166. assert pl_true( A & B, {A: False, B: False}) is False
  167. assert pl_true( A & B, {A: False}) is False
  168. assert pl_true( A & B, {B: False}) is False
  169. assert pl_true( A | B, {A: False, B: False}) is False
  170. #test for None
  171. assert pl_true(B, {B: None}) is None
  172. assert pl_true( A & B, {A: True, B: None}) is None
  173. assert pl_true( A >> B, {A: True, B: None}) is None
  174. assert pl_true(Equivalent(A, B), {A: None}) is None
  175. assert pl_true(Equivalent(A, B), {A: True, B: None}) is None
  176. # Test for deep
  177. assert pl_true(A | B, {A: False}, deep=True) is None
  178. assert pl_true(~A & ~B, {A: False}, deep=True) is None
  179. assert pl_true(A | B, {A: False, B: False}, deep=True) is False
  180. assert pl_true(A & B & (~A | ~B), {A: True}, deep=True) is False
  181. assert pl_true((C >> A) >> (B >> A), {C: True}, deep=True) is True
  182. def test_pl_true_wrong_input():
  183. from sympy.core.numbers import pi
  184. raises(ValueError, lambda: pl_true('John Cleese'))
  185. raises(ValueError, lambda: pl_true(42 + pi + pi ** 2))
  186. raises(ValueError, lambda: pl_true(42))
  187. def test_entails():
  188. A, B, C = symbols('A, B, C')
  189. assert entails(A, [A >> B, ~B]) is False
  190. assert entails(B, [Equivalent(A, B), A]) is True
  191. assert entails((A >> B) >> (~A >> ~B)) is False
  192. assert entails((A >> B) >> (~B >> ~A)) is True
  193. def test_PropKB():
  194. A, B, C = symbols('A,B,C')
  195. kb = PropKB()
  196. assert kb.ask(A >> B) is False
  197. assert kb.ask(A >> (B >> A)) is True
  198. kb.tell(A >> B)
  199. kb.tell(B >> C)
  200. assert kb.ask(A) is False
  201. assert kb.ask(B) is False
  202. assert kb.ask(C) is False
  203. assert kb.ask(~A) is False
  204. assert kb.ask(~B) is False
  205. assert kb.ask(~C) is False
  206. assert kb.ask(A >> C) is True
  207. kb.tell(A)
  208. assert kb.ask(A) is True
  209. assert kb.ask(B) is True
  210. assert kb.ask(C) is True
  211. assert kb.ask(~C) is False
  212. kb.retract(A)
  213. assert kb.ask(C) is False
  214. def test_propKB_tolerant():
  215. """"tolerant to bad input"""
  216. kb = PropKB()
  217. A, B, C = symbols('A,B,C')
  218. assert kb.ask(B) is False
  219. def test_satisfiable_non_symbols():
  220. x, y = symbols('x y')
  221. assumptions = Q.zero(x*y)
  222. facts = Implies(Q.zero(x*y), Q.zero(x) | Q.zero(y))
  223. query = ~Q.zero(x) & ~Q.zero(y)
  224. refutations = [
  225. {Q.zero(x): True, Q.zero(x*y): True},
  226. {Q.zero(y): True, Q.zero(x*y): True},
  227. {Q.zero(x): True, Q.zero(y): True, Q.zero(x*y): True},
  228. {Q.zero(x): True, Q.zero(y): False, Q.zero(x*y): True},
  229. {Q.zero(x): False, Q.zero(y): True, Q.zero(x*y): True}]
  230. assert not satisfiable(And(assumptions, facts, query), algorithm='dpll')
  231. assert satisfiable(And(assumptions, facts, ~query), algorithm='dpll') in refutations
  232. assert not satisfiable(And(assumptions, facts, query), algorithm='dpll2')
  233. assert satisfiable(And(assumptions, facts, ~query), algorithm='dpll2') in refutations
  234. def test_satisfiable_bool():
  235. from sympy.core.singleton import S
  236. assert satisfiable(true) == {true: true}
  237. assert satisfiable(S.true) == {true: true}
  238. assert satisfiable(false) is False
  239. assert satisfiable(S.false) is False
  240. def test_satisfiable_all_models():
  241. from sympy.abc import A, B
  242. assert next(satisfiable(False, all_models=True)) is False
  243. assert list(satisfiable((A >> ~A) & A, all_models=True)) == [False]
  244. assert list(satisfiable(True, all_models=True)) == [{true: true}]
  245. models = [{A: True, B: False}, {A: False, B: True}]
  246. result = satisfiable(A ^ B, all_models=True)
  247. models.remove(next(result))
  248. models.remove(next(result))
  249. raises(StopIteration, lambda: next(result))
  250. assert not models
  251. assert list(satisfiable(Equivalent(A, B), all_models=True)) == \
  252. [{A: False, B: False}, {A: True, B: True}]
  253. models = [{A: False, B: False}, {A: False, B: True}, {A: True, B: True}]
  254. for model in satisfiable(A >> B, all_models=True):
  255. models.remove(model)
  256. assert not models
  257. # This is a santiy test to check that only the required number
  258. # of solutions are generated. The expr below has 2**100 - 1 models
  259. # which would time out the test if all are generated at once.
  260. from sympy.utilities.iterables import numbered_symbols
  261. from sympy.logic.boolalg import Or
  262. sym = numbered_symbols()
  263. X = [next(sym) for i in range(100)]
  264. result = satisfiable(Or(*X), all_models=True)
  265. for i in range(10):
  266. assert next(result)