14#include "llvm/ADT/ArrayRef.h"
15#include "llvm/ADT/SmallVector.h"
16#include "llvm/Support/ErrorHandling.h"
17#include "llvm/Support/SMTAPI.h"
19#ifdef CIRCT_CADICAL_ENABLED
20#include "third_party/cadical/cadical.hpp"
35#ifdef CIRCT_CADICAL_ENABLED
// Fallback for a configuration value that has no CaDiCaL configuration name.
// NOTE(review): presumably the tail of toCadicalConfigName (called from the
// CadicalSATSolver constructor); the rest of that helper is not visible here.
49 llvm_unreachable(
"unknown CaDiCaL configuration");
/// IncrementalSATSolver backend that wraps an embedded CaDiCaL::Solver; only
/// compiled when CIRCT_CADICAL_ENABLED is defined.
52class CadicalSATSolver :
public IncrementalSATSolver {
// Creates the solver and, for configurations other than Default (the left-hand
// side of that comparison is not visible here), applies the matching CaDiCaL
// configuration by name.
54 explicit CadicalSATSolver(
const CadicalSATSolverOptions &options) {
56 CadicalSATSolverOptions::CadicalSolverConfig::Default) {
// NOTE(review): `configured` is read only by the assert; unless it is also
// voided on a line not visible here, NDEBUG builds may warn that it is unused.
57 bool configured = solver.configure(toCadicalConfigName(options.config));
58 assert(configured &&
"invalid CaDiCaL configuration");
// IPASIR-style clause construction: each literal is forwarded to CaDiCaL as-is.
62 void add(
int lit)
override { solver.add(
lit); }
63 void assume(
int lit)
override {
// Runs CaDiCaL and switches on its SATISFIABLE / UNSATISFIABLE status.
// A non-negative conflictLimit is installed as CaDiCaL's "conflicts" limit
// before every call; the default of -1 skips the limit call entirely.
67 Result solve()
override {
68 if (conflictLimit >= 0)
69 solver.limit(
"conflicts", conflictLimit);
70 switch (solver.solve()) {
71 case CaDiCaL::SATISFIABLE:
73 case CaDiCaL::UNSATISFIABLE:
// Model query for variable v; indices outside [1, maxVariable] take an early
// path before CaDiCaL is consulted.
79 int val(
int v)
const override {
80 if (v <= 0 || v > maxVariable)
// Stores the conflict limit used by subsequent solve() calls (negative: none).
84 void setConflictLimit(
int limit)
override { conflictLimit = limit; }
// Grows CaDiCaL's variable table to maxVar; requests that do not exceed
// maxVariable take an early path.
85 void reserveVars(
int maxVar)
override {
86 if (maxVar <= maxVariable)
88 solver.resize(maxVar);
// Adds a complete clause in one call via CaDiCaL's clause(data, size) API.
91 void addClause(llvm::ArrayRef<int> lits)
override {
96 solver.clause(lits.data(), lits.size());
// Asks CaDiCaL for one fresh variable and tracks the largest index handed out.
98 int newVar()
override {
99 int var = solver.declare_one_more_variable();
100 if (var > maxVariable)
// `mutable` presumably so const members such as val() can query CaDiCaL.
106 mutable CaDiCaL::Solver solver;
// Conflict limit applied by solve(); -1 means no limit is installed.
108 int conflictLimit = -1;
/// IncrementalSATSolver backend built on LLVM's SMT API (llvm::CreateZ3Solver).
/// Every SAT variable is a Boolean SMT symbol and clauses become disjunctions.
119class Z3SATSolver :
public IncrementalSATSolver {
122 ~Z3SATSolver()
override;
124 void add(
int lit)
override;
125 void assume(
int lit)
override;
126 Result solve()
override;
127 Result solve(llvm::ArrayRef<int> assumptions)
override;
128 int val(
int v)
const override;
129 void reserveVars(
int maxVar)
override;
130 int newVar()
override;
// Ends the per-solve state opened by solve(ArrayRef), if any.
133 void clearSolveScope();
// Converts a signed literal into its (possibly negated) SMT expression.
135 llvm::SMTExprRef literalToExpr(
int lit);
// Asserts the disjunction of lits on the underlying SMT solver.
136 void addClauseInternal(llvm::ArrayRef<int> lits);
138 llvm::SMTSolverRef solver;
// variables[i] is the SMT symbol for SAT variable i + 1.
139 llvm::SmallVector<llvm::SMTExprRef> variables;
// Assumptions queued by assume() for the next solve().
140 llvm::SmallVector<int> assumptions;
// Literals of the clause currently being assembled by add().
141 llvm::SmallVector<int> clauseBuffer;
// Outcome of the most recent solve(); val() only answers after kSAT.
143 Result lastResult = kUNKNOWN;
// Set by solve(ArrayRef) and cleared by clearSolveScope().
144 bool solveScopeActive =
false;
// Creates the underlying SMT solver through LLVM's Z3 factory.
147Z3SATSolver::Z3SATSolver() : solver(
llvm::CreateZ3Solver()) {}
// Closes any solve scope still active before the solver is destroyed.
149Z3SATSolver::~Z3SATSolver() { clearSolveScope(); }
// IPASIR-style add(). On the clause terminator (presumably lit == 0; the
// guarding condition is not visible here), the buffered clause is committed
// via addClauseInternal and the buffer is reset. Any other literal has its
// variable created on demand and is appended to clauseBuffer.
151void Z3SATSolver::add(
int lit) {
154 addClauseInternal(clauseBuffer);
155 clauseBuffer.clear();
159 reserveVars(std::abs(
lit));
160 clauseBuffer.push_back(
lit);
// Queues lit as an assumption that applies to the next solve() call.
163void Z3SATSolver::assume(
int lit) {
167 assumptions.push_back(
lit);
// Solves under the assumptions queued by assume(). They are first copied to a
// local vector, presumably because the member list is reset before the
// delegating call and that call must not observe the reset.
170IncrementalSATSolver::Result Z3SATSolver::solve() {
171 auto localAssumptions = assumptions;
173 return solve(localAssumptions);
// Solves with the given assumption literals. It marks a solve scope active,
// asserts each literal as a unit constraint, runs solver->check(), and caches
// the outcome in lastResult as kUNKNOWN, kSAT, or kUNSAT.
// NOTE: the parameter shadows the `assumptions` member; inside this function
// the name refers to the argument.
176IncrementalSATSolver::Result
177Z3SATSolver::solve(llvm::ArrayRef<int> assumptions) {
180 solveScopeActive =
true;
181 for (
int lit : assumptions)
182 solver->addConstraint(literalToExpr(
lit));
183 auto result = solver->check();
185 return lastResult = kUNKNOWN;
187 return lastResult = kSAT;
188 return lastResult = kUNSAT;
// Returns v if SAT variable v is true in the last model and -v if it is false.
// It only answers after a kSAT result and for 1 <= v <= maxVariable. Other
// queries, and a failed model lookup, take early paths whose return values are
// not visible here. The model value is read into a 1-bit unsigned APSInt.
191int Z3SATSolver::val(
int v)
const {
192 if (lastResult != kSAT || v <= 0 || v > maxVariable)
194 llvm::APSInt value(llvm::APInt(1, 0),
true);
198 if (!solver->getInterpretation(variables[v - 1], value))
200 return value != 0 ? v : -v;
// Makes SAT variables 1..maxVar available. It grows `variables` (presumably via
// newVariable()) until it holds maxVar entries, then records maxVar as the new
// maxVariable. Requests not exceeding maxVariable take an early path.
203void Z3SATSolver::reserveVars(
int maxVar) {
204 if (maxVar <= maxVariable)
206 while (
static_cast<int>(variables.size()) < maxVar)
208 maxVariable = maxVar;
// Allocates a fresh SAT variable through newVariable() and raises maxVariable
// when the new index exceeds it.
211int Z3SATSolver::newVar() {
212 int var = newVariable();
213 if (var > maxVariable)
// Ends the scope opened by solve(ArrayRef). It takes an early path when no
// scope is active. Otherwise, after steps not visible here (presumably removing
// the assumption constraints), it clears solveScopeActive and resets lastResult
// to kUNKNOWN, so val() stops answering until the next satisfiable solve.
218void Z3SATSolver::clearSolveScope() {
219 if (!solveScopeActive)
222 solveScopeActive =
false;
223 lastResult = kUNKNOWN;
// Creates a Boolean SMT symbol named "v<index>" for the next SAT variable and
// appends it to `variables`. Indices are 1-based (variables.size() + 1).
226int Z3SATSolver::newVariable() {
227 int varIndex =
static_cast<int>(variables.size()) + 1;
228 std::string name =
"v" + std::to_string(varIndex);
229 variables.push_back(solver->mkSymbol(name.c_str(), solver->getBoolSort()));
// Maps a signed literal to the symbol of variable |lit|, wrapped in mkNot when
// lit is negative. Assumes lit != 0 and that the variable already exists.
// TODO confirm against the checks on the lines not visible here.
233llvm::SMTExprRef Z3SATSolver::literalToExpr(
int lit) {
234 int absLit = std::abs(
lit);
237 auto *variable = variables[absLit - 1];
238 return lit > 0 ? variable : solver->mkNot(variable);
// Asserts the disjunction of lits as one SMT constraint, built left to right
// with mkOr. The constant `false` is asserted instead in two cases: presumably
// for an empty clause (the guard is not visible here), and when the loop leaves
// no disjunction behind.
241void Z3SATSolver::addClauseInternal(llvm::ArrayRef<int> lits) {
243 solver->addConstraint(solver->mkBoolean(
false));
247 llvm::SMTExprRef clause =
nullptr;
248 for (
int lit : lits) {
251 auto *expr = literalToExpr(
lit);
252 clause = clause ? solver->mkOr(clause, expr) : expr;
256 solver->addConstraint(solver->mkBoolean(
false));
259 solver->addConstraint(clause);
267 llvm::function_ref<
void(llvm::ArrayRef<int>)> addClause) {
// outVar -> lit for every input: a true output forces all inputs true.
268 for (
int lit : inputLits)
269 addClause({-outVar,
lit});
// (-in_1 v ... v -in_n v outVar): all inputs true force the output true.
271 llvm::SmallVector<int> clause;
272 clause.reserve(inputLits.size() + 1);
273 for (
int lit : inputLits)
274 clause.push_back(-
lit);
275 clause.push_back(outVar);
280 llvm::function_ref<
void(llvm::ArrayRef<int>)> addClause) {
// lit -> outVar for every input: any true input forces the output true.
281 for (
int lit : inputLits)
282 addClause({-
lit, outVar});
// (-outVar v in_1 v ... v in_n): a true output requires some true input.
284 llvm::SmallVector<int> clause;
285 clause.reserve(inputLits.size() + 1);
286 clause.push_back(-outVar);
287 clause.append(inputLits.begin(), inputLits.end());
292 llvm::function_ref<
void(llvm::ArrayRef<int>)> addClause) {
// The first two clauses forbid outVar when both inputs agree (both true or
// both false); the last two require outVar when the inputs differ.
293 addClause({-lhsLit, -rhsLit, -outVar});
294 addClause({lhsLit, rhsLit, -outVar});
295 addClause({lhsLit, -rhsLit, outVar});
296 addClause({-lhsLit, rhsLit, outVar});
300 llvm::function_ref<
void(llvm::ArrayRef<int>)> addClause,
301 llvm::function_ref<
int()> newVar) {
302 assert(!inputLits.empty() &&
"parity requires at least one input");
// A single input is simply equivalent to the output.
303 if (inputLits.size() == 1) {
304 addClause({-outVar, inputLits.front()});
305 addClause({outVar, -inputLits.front()});
// Otherwise fold the inputs left to right through a chain of two-input XOR
// gates (presumably encoded with addXorClauses on the line not visible here).
// Intermediate results get fresh variables from newVar(). The final gate,
// reached when index == inputLits.size() - 2, writes outVar directly.
309 int accumulatedLit = inputLits.front();
310 for (
auto [index,
lit] : llvm::enumerate(inputLits.drop_front())) {
311 bool isLast = index + 2 == inputLits.size();
312 int outLit = isLast ? outVar : newVar();
314 accumulatedLit = outLit;
319 llvm::ArrayRef<int> inputLits,
320 llvm::function_ref<
void(llvm::ArrayRef<int>)> addClause,
321 llvm::function_ref<
int()> newVar) {
// With fewer than two inputs the constraint holds trivially, so (presumably)
// no clauses are emitted.
322 if (inputLits.size() < 2)
326 auto imply = [&](
int lhs,
int rhs) { addClause({-lhs, rhs}); };
// Sequential-counter encoding with n - 1 auxiliary variables (presumably
// allocated via newVar()). ladder[i] is forced true whenever any of
// inputLits[0..i] is true, and a true ladder[i - 1] forbids inputLits[i].
// It emits 3n - 4 binary clauses in total.
335 llvm::SmallVector<int, 8> ladder(inputLits.size() - 1);
336 for (
int &var : ladder)
339 imply(inputLits.front(), ladder.front());
340 for (
unsigned i = 1, e = inputLits.size() - 1; i < e; ++i) {
343 imply(inputLits[i], ladder[i]);
346 imply(ladder[i - 1], ladder[i]);
349 imply(ladder[i - 1], -inputLits[i]);
// The last input is forbidden once any earlier input is true.
354 imply(ladder.back(), -inputLits.back());
358 llvm::ArrayRef<int> inputLits,
359 llvm::function_ref<
void(llvm::ArrayRef<int>)> addClause,
360 llvm::function_ref<
int()> newVar) {
// At-least-one: the inputs themselves form a single clause. The at-most-one
// half is presumably emitted on the following lines (not visible here).
361 addClause(inputLits);
// Construct the LLVM-SMT (Z3) backed incremental solver.
367 return std::make_unique<Z3SATSolver>();
// Factory for the CaDiCaL backend. When CIRCT_CADICAL_ENABLED is defined it
// builds a CadicalSATSolver from the caller's options; the fallback branch is
// not visible here.
373std::unique_ptr<IncrementalSATSolver>
375#ifdef CIRCT_CADICAL_ENABLED
376 return std::make_unique<CadicalSATSolver>(options);
assert(baseType &&"element must be base type")
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
bool hasIncrementalSATSolverBackend()
Return true when at least one incremental SAT backend is available.
void addExactlyOneClauses(llvm::ArrayRef< int > inputLits, llvm::function_ref< void(llvm::ArrayRef< int >)> addClause, llvm::function_ref< int()> newVar)
Emit clauses encoding that exactly one literal in inputLits is true.
void addAndClauses(int outVar, llvm::ArrayRef< int > inputLits, llvm::function_ref< void(llvm::ArrayRef< int >)> addClause)
Emit clauses encoding outVar <=> and(inputLits).
void addXorClauses(int outVar, int lhsLit, int rhsLit, llvm::function_ref< void(llvm::ArrayRef< int >)> addClause)
Emit clauses encoding outVar <=> (lhsLit xor rhsLit).
std::unique_ptr< IncrementalSATSolver > createZ3SATSolver()
Construct a Z3-backed incremental IPASIR-style SAT solver.
std::unique_ptr< IncrementalSATSolver > createCadicalSATSolver(const CadicalSATSolverOptions &options={})
Construct a CaDiCaL-backed incremental IPASIR-style SAT solver.
void addOrClauses(int outVar, llvm::ArrayRef< int > inputLits, llvm::function_ref< void(llvm::ArrayRef< int >)> addClause)
Emit clauses encoding outVar <=> or(inputLits).
void addParityClauses(int outVar, llvm::ArrayRef< int > inputLits, llvm::function_ref< void(llvm::ArrayRef< int >)> addClause, llvm::function_ref< int()> newVar)
Emit clauses encoding outVar <=> parity(inputLits).
void addAtMostOneClauses(llvm::ArrayRef< int > inputLits, llvm::function_ref< void(llvm::ArrayRef< int >)> addClause, llvm::function_ref< int()> newVar)
Emit clauses encoding that at most one literal in inputLits can be true.