123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308 |
- """Implementation of DPLL algorithm
- Further improvements: eliminate calls to pl_true, implement branching rules,
- efficient unit propagation.
- References:
- - https://en.wikipedia.org/wiki/DPLL_algorithm
- - https://www.researchgate.net/publication/242384772_Implementations_of_the_DPLL_Algorithm
- """
- from sympy.core.sorting import default_sort_key
- from sympy.logic.boolalg import Or, Not, conjuncts, disjuncts, to_cnf, \
- to_int_repr, _find_predicates
- from sympy.assumptions.cnf import CNF
- from sympy.logic.inference import pl_true, literal_symbol
def dpll_satisfiable(expr):
    """
    Check satisfiability of a propositional sentence.

    On success the satisfying model is returned (a dict mapping each symbol
    to its truth value) instead of a bare ``True``; an unsatisfiable
    sentence yields ``False``.

    >>> from sympy.abc import A, B
    >>> from sympy.logic.algorithms.dpll import dpll_satisfiable
    >>> dpll_satisfiable(A & ~B)
    {A: True, B: False}
    >>> dpll_satisfiable(A & ~A)
    False

    """
    # A CNF object already carries its clauses; anything else is converted.
    clauses = expr.clauses if isinstance(expr, CNF) else conjuncts(to_cnf(expr))
    if False in clauses:
        return False

    # Symbols are sorted with default_sort_key and numbered 1..n, so the
    # integer k used by the solver stands for symbols[k - 1].
    symbols = sorted(_find_predicates(expr), key=default_sort_key)
    int_clauses = to_int_repr(clauses, symbols)
    int_symbols = set(range(1, len(symbols) + 1))

    int_model = dpll_int_repr(int_clauses, int_symbols, {})
    if not int_model:
        # Hand back the solver's falsy result unchanged.
        return int_model

    # Translate the integer-keyed model back to the original symbols.
    return {symbols[index - 1]: truth for index, truth in int_model.items()}
def dpll(clauses, symbols, model):
    """
    Compute satisfiability in a partial model.
    Clauses is an array of conjuncts.

    ``symbols`` is the list of symbols that are still unassigned and
    ``model`` is a dict mapping already assigned symbols to truth values.
    Both are modified in place: symbols are removed from ``symbols`` as they
    get assigned, and the assignments are written into ``model``.

    Returns ``model`` (extended with the assignments found) when the clauses
    can be satisfied, and ``False`` otherwise.

    >>> from sympy.abc import A, B, D
    >>> from sympy.logic.algorithms.dpll import dpll
    >>> dpll([A, B, D], [A, B], {D: False})
    False

    """
    # compute DP kernel
    # Stage 1: as long as a unit clause exists, assign its symbol and
    # simplify the clauses with the corresponding literal.
    P, value = find_unit_clause(clauses, model)
    while P:
        model.update({P: value})
        symbols.remove(P)
        if not value:
            # Propagate the literal that is true, i.e. the negated symbol.
            P = ~P
        clauses = unit_propagate(clauses, P)
        P, value = find_unit_clause(clauses, model)
    # Stage 2: same treatment for symbols that occur with one polarity only.
    P, value = find_pure_symbol(symbols, clauses)
    while P:
        model.update({P: value})
        symbols.remove(P)
        if not value:
            P = ~P
        clauses = unit_propagate(clauses, P)
        P, value = find_pure_symbol(symbols, clauses)
    # end DP kernel

    # Evaluate what is left: one false clause refutes the model, clauses
    # that are neither true nor false still need a decision.
    unknown_clauses = []
    for c in clauses:
        val = pl_true(c, model)
        if val is False:
            return False
        if val is not True:
            unknown_clauses.append(c)
    # An empty ``clauses`` also ends up here, since it leaves
    # ``unknown_clauses`` empty; no separate emptiness check is needed.
    if not unknown_clauses:
        return model

    # Branch on one remaining symbol: try True first, then False. The False
    # branch works on copies so that it does not see the True branch's
    # in-place changes.
    P = symbols.pop()
    model_copy = model.copy()
    model.update({P: True})
    model_copy.update({P: False})
    symbols_copy = symbols[:]
    return (dpll(unit_propagate(unknown_clauses, P), symbols, model) or
            dpll(unit_propagate(unknown_clauses, Not(P)), symbols_copy, model_copy))
def dpll_int_repr(clauses, symbols, model):
    """
    Compute satisfiability in a partial model.
    Arguments are expected to be in integer representation

    ``symbols`` (a set of positive integers) and ``model`` (a dict from
    integer symbol to truth value) are updated in place while solving.
    The result is the model on success and ``False`` on failure.

    >>> from sympy.logic.algorithms.dpll import dpll_int_repr
    >>> dpll_int_repr([{1}, {2}, {3}], {1, 2}, {3: False})
    False

    """
    # --- DP kernel, part 1: exhaust the unit clauses --------------------
    sym, truth = find_unit_clause_int_repr(clauses, model)
    while sym:
        model[sym] = truth
        symbols.remove(sym)
        # The literal made true is ``sym`` itself or its negation.
        literal = sym if truth else -sym
        clauses = unit_propagate_int_repr(clauses, literal)
        sym, truth = find_unit_clause_int_repr(clauses, model)

    # --- DP kernel, part 2: exhaust the pure symbols --------------------
    sym, truth = find_pure_symbol_int_repr(symbols, clauses)
    while sym:
        model[sym] = truth
        symbols.remove(sym)
        literal = sym if truth else -sym
        clauses = unit_propagate_int_repr(clauses, literal)
        sym, truth = find_pure_symbol_int_repr(symbols, clauses)
    # --- end of DP kernel -----------------------------------------------

    # Sort the remaining clauses: a false one ends the search, a true one
    # is dropped, anything else is still open.
    pending = []
    for clause in clauses:
        status = pl_true_int_repr(clause, model)
        if status is False:
            return False
        if status is not True:
            pending.append(clause)
    if not pending:
        return model

    # Split on one remaining symbol. The copies for the False branch are
    # taken before the True branch can modify the shared containers.
    branch = symbols.pop()
    false_model = model.copy()
    model[branch] = True
    false_model[branch] = False
    false_symbols = symbols.copy()

    outcome = dpll_int_repr(
        unit_propagate_int_repr(pending, branch), symbols, model)
    if outcome:
        return outcome
    return dpll_int_repr(
        unit_propagate_int_repr(pending, -branch), false_symbols, false_model)
- ### helper methods for DPLL
def pl_true_int_repr(clause, model=None):
    """
    Lightweight version of pl_true.
    Argument clause represents the set of args of an Or clause. This is used
    inside dpll_int_repr, it is not meant to be used directly.

    ``model`` maps positive integer symbols to truth values; a negative
    literal ``-k`` refers to the negation of symbol ``k``. When ``model`` is
    omitted, an empty model is used.

    Returns ``True`` if some literal of the clause is true under the model,
    ``None`` if no literal is true but at least one is unassigned, and
    ``False`` otherwise (which includes the empty clause).

    >>> from sympy.logic.algorithms.dpll import pl_true_int_repr
    >>> pl_true_int_repr({1, 2}, {1: False})
    >>> pl_true_int_repr({1, 2}, {1: False, 2: False})
    False

    """
    # A ``None`` sentinel replaces the former mutable ``{}`` default.
    if model is None:
        model = {}
    result = False
    for lit in clause:
        if lit < 0:
            # Negative literal: look up the symbol and flip its value.
            value = model.get(-lit)
            if value is not None:
                value = not value
        else:
            value = model.get(lit)
        if value is True:
            # One true literal makes the whole disjunction true.
            return True
        if value is None:
            # Unassigned literal: the clause can no longer be called False.
            result = None
    return result
def unit_propagate(clauses, symbol):
    """
    Returns an equivalent set of clauses

    Taking ``symbol`` as the unit literal l, every ``Or`` clause is
    simplified by the first of these rules that matches one of its args:

        1. a clause containing ~l is kept with that literal deleted
        2. a clause containing l is removed

    Clauses that are not an ``Or`` are passed through untouched (see the
    lone ``B`` in the example below).

    Arguments are expected to be in CNF.

    >>> from sympy.abc import A, B, D
    >>> from sympy.logic.algorithms.dpll import unit_propagate
    >>> unit_propagate([A | B, D | ~B, B], B)
    [D, B]

    """
    negated = ~symbol
    simplified = []
    for clause in clauses:
        if clause.func != Or:
            simplified.append(clause)
            continue
        # First argument that is either the literal or its negation.
        hit = next(
            (arg for arg in clause.args if arg == negated or arg == symbol),
            None)
        if hit is None:
            # The literal does not occur: the clause is unaffected.
            simplified.append(clause)
        elif hit == negated:
            # The negated literal is false, so strip it from the clause.
            remaining = [arg for arg in clause.args if arg != negated]
            simplified.append(Or(*remaining))
        # Otherwise the clause contains the literal and is satisfied: drop.
    return simplified
def unit_propagate_int_repr(clauses, s):
    """
    Same as unit_propagate, but arguments are expected to be in integer
    representation

    Clauses containing the literal ``s`` are dropped; the literal ``-s`` is
    deleted from every remaining clause.

    >>> from sympy.logic.algorithms.dpll import unit_propagate_int_repr
    >>> unit_propagate_int_repr([{1, 2}, {3, -2}, {2}], 2)
    [{3}]

    """
    opposite = {-s}
    simplified = []
    for clause in clauses:
        if s in clause:
            # Satisfied by the literal, nothing left to prove.
            continue
        simplified.append(clause - opposite)
    return simplified
def find_pure_symbol(symbols, unknown_clauses):
    """
    Find a symbol and its value if it appears only as a positive literal
    (or only as a negative) in clauses.

    Symbols are examined in the given order and the first pure one is
    returned together with the truth value that satisfies its occurrences.
    ``(None, None)`` is returned when no symbol is pure.

    >>> from sympy.abc import A, B, D
    >>> from sympy.logic.algorithms.dpll import find_pure_symbol
    >>> find_pure_symbol([A, B, D], [A|~B,~B|~D,D|A])
    (A, True)

    """
    for candidate in symbols:
        negation = Not(candidate)
        seen_positive = False
        seen_negative = False
        for clause in unknown_clauses:
            literals = disjuncts(clause)
            seen_positive = seen_positive or candidate in literals
            seen_negative = seen_negative or negation in literals
        # Pure means exactly one of the two polarities was seen.
        if seen_positive != seen_negative:
            return candidate, seen_positive
    return None, None
def find_pure_symbol_int_repr(symbols, unknown_clauses):
    """
    Same as find_pure_symbol, but arguments are expected
    to be in integer representation

    Positive occurrences are checked before negative ones. The returned
    symbol is always the positive integer, paired with ``True`` or
    ``False``; ``(None, None)`` means no pure symbol exists.

    >>> from sympy.logic.algorithms.dpll import find_pure_symbol_int_repr
    >>> find_pure_symbol_int_repr({1,2,3},
    ...     [{1, -2}, {-2, -3}, {3, 1}])
    (1, True)

    """
    # Every literal that occurs in any clause.
    literals = set().union(*unknown_clauses)
    positives = literals.intersection(symbols)
    negatives = literals.intersection(-sym for sym in symbols)

    # A positive literal whose negation never occurs.
    pure = next((lit for lit in positives if -lit not in negatives), None)
    if pure is not None:
        return pure, True

    # A negative literal whose positive counterpart never occurs.
    pure = next((lit for lit in negatives if -lit not in positives), None)
    if pure is not None:
        return -pure, False

    return None, None
def find_unit_clause(clauses, model):
    """
    A unit clause has only 1 variable that is not bound in the model.

    The clauses are scanned in order; for the first unit clause found, its
    unbound symbol is returned with the value that makes the literal true
    (``False`` when the literal is a ``Not``). ``(None, None)`` is returned
    when there is no unit clause.

    >>> from sympy.abc import A, B, D
    >>> from sympy.logic.algorithms.dpll import find_unit_clause
    >>> find_unit_clause([A | B | D, B | ~D, A | ~B], {A:True})
    (B, False)

    """
    for clause in clauses:
        # Literals whose underlying symbol has no value in the model yet.
        unbound = [lit for lit in disjuncts(clause)
                   if literal_symbol(lit) not in model]
        if len(unbound) != 1:
            continue
        literal = unbound[0]
        return literal_symbol(literal), not isinstance(literal, Not)
    return None, None
def find_unit_clause_int_repr(clauses, model):
    """
    Same as find_unit_clause, but arguments are expected to be in
    integer representation.

    Each clause is a ``set`` or ``frozenset`` of nonzero integers, and
    ``model`` maps positive integer symbols to truth values. For the first
    clause with exactly one literal whose symbol is not in the model, the
    symbol is returned with the value that makes that literal true.
    ``(None, None)`` is returned when there is no such clause.

    >>> from sympy.logic.algorithms.dpll import find_unit_clause_int_repr
    >>> find_unit_clause_int_repr([{1, 2, 3},
    ...     {2, -3}, {1, -2}], {1: True})
    (2, False)

    """
    # Both polarities of every assigned symbol count as bound.
    bound = set(model) | {-sym for sym in model}
    for clause in clauses:
        unbound = clause - bound
        if len(unbound) == 1:
            # Unpack instead of ``pop()``: the difference of a frozenset is
            # a frozenset, which has no ``pop``.
            (literal,) = unbound
            if literal < 0:
                return -literal, False
            return literal, True
    return None, None