dpll.py 9.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308
  1. """Implementation of DPLL algorithm
  2. Further improvements: eliminate calls to pl_true, implement branching rules,
  3. efficient unit propagation.
  4. References:
  5. - https://en.wikipedia.org/wiki/DPLL_algorithm
  6. - https://www.researchgate.net/publication/242384772_Implementations_of_the_DPLL_Algorithm
  7. """
  8. from sympy.core.sorting import default_sort_key
  9. from sympy.logic.boolalg import Or, Not, conjuncts, disjuncts, to_cnf, \
  10. to_int_repr, _find_predicates
  11. from sympy.assumptions.cnf import CNF
  12. from sympy.logic.inference import pl_true, literal_symbol
  13. def dpll_satisfiable(expr):
  14. """
  15. Check satisfiability of a propositional sentence.
  16. It returns a model rather than True when it succeeds
  17. >>> from sympy.abc import A, B
  18. >>> from sympy.logic.algorithms.dpll import dpll_satisfiable
  19. >>> dpll_satisfiable(A & ~B)
  20. {A: True, B: False}
  21. >>> dpll_satisfiable(A & ~A)
  22. False
  23. """
  24. if not isinstance(expr, CNF):
  25. clauses = conjuncts(to_cnf(expr))
  26. else:
  27. clauses = expr.clauses
  28. if False in clauses:
  29. return False
  30. symbols = sorted(_find_predicates(expr), key=default_sort_key)
  31. symbols_int_repr = set(range(1, len(symbols) + 1))
  32. clauses_int_repr = to_int_repr(clauses, symbols)
  33. result = dpll_int_repr(clauses_int_repr, symbols_int_repr, {})
  34. if not result:
  35. return result
  36. output = {}
  37. for key in result:
  38. output.update({symbols[key - 1]: result[key]})
  39. return output
  40. def dpll(clauses, symbols, model):
  41. """
  42. Compute satisfiability in a partial model.
  43. Clauses is an array of conjuncts.
  44. >>> from sympy.abc import A, B, D
  45. >>> from sympy.logic.algorithms.dpll import dpll
  46. >>> dpll([A, B, D], [A, B], {D: False})
  47. False
  48. """
  49. # compute DP kernel
  50. P, value = find_unit_clause(clauses, model)
  51. while P:
  52. model.update({P: value})
  53. symbols.remove(P)
  54. if not value:
  55. P = ~P
  56. clauses = unit_propagate(clauses, P)
  57. P, value = find_unit_clause(clauses, model)
  58. P, value = find_pure_symbol(symbols, clauses)
  59. while P:
  60. model.update({P: value})
  61. symbols.remove(P)
  62. if not value:
  63. P = ~P
  64. clauses = unit_propagate(clauses, P)
  65. P, value = find_pure_symbol(symbols, clauses)
  66. # end DP kernel
  67. unknown_clauses = []
  68. for c in clauses:
  69. val = pl_true(c, model)
  70. if val is False:
  71. return False
  72. if val is not True:
  73. unknown_clauses.append(c)
  74. if not unknown_clauses:
  75. return model
  76. if not clauses:
  77. return model
  78. P = symbols.pop()
  79. model_copy = model.copy()
  80. model.update({P: True})
  81. model_copy.update({P: False})
  82. symbols_copy = symbols[:]
  83. return (dpll(unit_propagate(unknown_clauses, P), symbols, model) or
  84. dpll(unit_propagate(unknown_clauses, Not(P)), symbols_copy, model_copy))
  85. def dpll_int_repr(clauses, symbols, model):
  86. """
  87. Compute satisfiability in a partial model.
  88. Arguments are expected to be in integer representation
  89. >>> from sympy.logic.algorithms.dpll import dpll_int_repr
  90. >>> dpll_int_repr([{1}, {2}, {3}], {1, 2}, {3: False})
  91. False
  92. """
  93. # compute DP kernel
  94. P, value = find_unit_clause_int_repr(clauses, model)
  95. while P:
  96. model.update({P: value})
  97. symbols.remove(P)
  98. if not value:
  99. P = -P
  100. clauses = unit_propagate_int_repr(clauses, P)
  101. P, value = find_unit_clause_int_repr(clauses, model)
  102. P, value = find_pure_symbol_int_repr(symbols, clauses)
  103. while P:
  104. model.update({P: value})
  105. symbols.remove(P)
  106. if not value:
  107. P = -P
  108. clauses = unit_propagate_int_repr(clauses, P)
  109. P, value = find_pure_symbol_int_repr(symbols, clauses)
  110. # end DP kernel
  111. unknown_clauses = []
  112. for c in clauses:
  113. val = pl_true_int_repr(c, model)
  114. if val is False:
  115. return False
  116. if val is not True:
  117. unknown_clauses.append(c)
  118. if not unknown_clauses:
  119. return model
  120. P = symbols.pop()
  121. model_copy = model.copy()
  122. model.update({P: True})
  123. model_copy.update({P: False})
  124. symbols_copy = symbols.copy()
  125. return (dpll_int_repr(unit_propagate_int_repr(unknown_clauses, P), symbols, model) or
  126. dpll_int_repr(unit_propagate_int_repr(unknown_clauses, -P), symbols_copy, model_copy))
  127. ### helper methods for DPLL
  128. def pl_true_int_repr(clause, model={}):
  129. """
  130. Lightweight version of pl_true.
  131. Argument clause represents the set of args of an Or clause. This is used
  132. inside dpll_int_repr, it is not meant to be used directly.
  133. >>> from sympy.logic.algorithms.dpll import pl_true_int_repr
  134. >>> pl_true_int_repr({1, 2}, {1: False})
  135. >>> pl_true_int_repr({1, 2}, {1: False, 2: False})
  136. False
  137. """
  138. result = False
  139. for lit in clause:
  140. if lit < 0:
  141. p = model.get(-lit)
  142. if p is not None:
  143. p = not p
  144. else:
  145. p = model.get(lit)
  146. if p is True:
  147. return True
  148. elif p is None:
  149. result = None
  150. return result
  151. def unit_propagate(clauses, symbol):
  152. """
  153. Returns an equivalent set of clauses
  154. If a set of clauses contains the unit clause l, the other clauses are
  155. simplified by the application of the two following rules:
  156. 1. every clause containing l is removed
  157. 2. in every clause that contains ~l this literal is deleted
  158. Arguments are expected to be in CNF.
  159. >>> from sympy.abc import A, B, D
  160. >>> from sympy.logic.algorithms.dpll import unit_propagate
  161. >>> unit_propagate([A | B, D | ~B, B], B)
  162. [D, B]
  163. """
  164. output = []
  165. for c in clauses:
  166. if c.func != Or:
  167. output.append(c)
  168. continue
  169. for arg in c.args:
  170. if arg == ~symbol:
  171. output.append(Or(*[x for x in c.args if x != ~symbol]))
  172. break
  173. if arg == symbol:
  174. break
  175. else:
  176. output.append(c)
  177. return output
  178. def unit_propagate_int_repr(clauses, s):
  179. """
  180. Same as unit_propagate, but arguments are expected to be in integer
  181. representation
  182. >>> from sympy.logic.algorithms.dpll import unit_propagate_int_repr
  183. >>> unit_propagate_int_repr([{1, 2}, {3, -2}, {2}], 2)
  184. [{3}]
  185. """
  186. negated = {-s}
  187. return [clause - negated for clause in clauses if s not in clause]
  188. def find_pure_symbol(symbols, unknown_clauses):
  189. """
  190. Find a symbol and its value if it appears only as a positive literal
  191. (or only as a negative) in clauses.
  192. >>> from sympy.abc import A, B, D
  193. >>> from sympy.logic.algorithms.dpll import find_pure_symbol
  194. >>> find_pure_symbol([A, B, D], [A|~B,~B|~D,D|A])
  195. (A, True)
  196. """
  197. for sym in symbols:
  198. found_pos, found_neg = False, False
  199. for c in unknown_clauses:
  200. if not found_pos and sym in disjuncts(c):
  201. found_pos = True
  202. if not found_neg and Not(sym) in disjuncts(c):
  203. found_neg = True
  204. if found_pos != found_neg:
  205. return sym, found_pos
  206. return None, None
  207. def find_pure_symbol_int_repr(symbols, unknown_clauses):
  208. """
  209. Same as find_pure_symbol, but arguments are expected
  210. to be in integer representation
  211. >>> from sympy.logic.algorithms.dpll import find_pure_symbol_int_repr
  212. >>> find_pure_symbol_int_repr({1,2,3},
  213. ... [{1, -2}, {-2, -3}, {3, 1}])
  214. (1, True)
  215. """
  216. all_symbols = set().union(*unknown_clauses)
  217. found_pos = all_symbols.intersection(symbols)
  218. found_neg = all_symbols.intersection([-s for s in symbols])
  219. for p in found_pos:
  220. if -p not in found_neg:
  221. return p, True
  222. for p in found_neg:
  223. if -p not in found_pos:
  224. return -p, False
  225. return None, None
  226. def find_unit_clause(clauses, model):
  227. """
  228. A unit clause has only 1 variable that is not bound in the model.
  229. >>> from sympy.abc import A, B, D
  230. >>> from sympy.logic.algorithms.dpll import find_unit_clause
  231. >>> find_unit_clause([A | B | D, B | ~D, A | ~B], {A:True})
  232. (B, False)
  233. """
  234. for clause in clauses:
  235. num_not_in_model = 0
  236. for literal in disjuncts(clause):
  237. sym = literal_symbol(literal)
  238. if sym not in model:
  239. num_not_in_model += 1
  240. P, value = sym, not isinstance(literal, Not)
  241. if num_not_in_model == 1:
  242. return P, value
  243. return None, None
  244. def find_unit_clause_int_repr(clauses, model):
  245. """
  246. Same as find_unit_clause, but arguments are expected to be in
  247. integer representation.
  248. >>> from sympy.logic.algorithms.dpll import find_unit_clause_int_repr
  249. >>> find_unit_clause_int_repr([{1, 2, 3},
  250. ... {2, -3}, {1, -2}], {1: True})
  251. (2, False)
  252. """
  253. bound = set(model) | {-sym for sym in model}
  254. for clause in clauses:
  255. unbound = clause - bound
  256. if len(unbound) == 1:
  257. p = unbound.pop()
  258. if p < 0:
  259. return -p, False
  260. else:
  261. return p, True
  262. return None, None