CIRCT  19.0.0git
Solver.cpp
Go to the documentation of this file.
1 //===-- Solver.cpp - SMT solver interface -----------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 ///
9 /// This file defines an SMT solver interface for the 'circt-lec' tool.
10 ///
11 //===----------------------------------------------------------------------===//
12 
17 #include "mlir/IR/Builders.h"
18 #include <string>
19 #include <z3++.h>
20 
21 #define DEBUG_TYPE "lec-solver"
22 
23 using namespace circt;
24 using namespace mlir;
25 
26 Solver::Solver(MLIRContext *mlirCtx, bool statisticsOpt)
27  : circuits{}, mlirCtx(mlirCtx), context(), solver(context),
28  statisticsOpt(statisticsOpt) {}
29 
30 /// Solve the equivalence problem between the two circuits, then present the
31 /// results to the user.
32 LogicalResult Solver::solve() {
33  // Constrain the circuits for equivalence checking to be made:
34  // require them to produce different outputs starting from the same inputs.
35  if (constrainCircuits().failed())
36  return failure();
37 
38  // Instruct the logical engine to solve the constraints:
39  // if they can't be satisfied it must mean the two circuits are functionally
40  // equivalent. Otherwise, print a model to act as a counterexample.
41  LogicalResult outcome = success();
42  switch (solver.check()) {
43  case z3::unsat:
44  lec::outs() << "c1 == c2\n";
45  break;
46  case z3::sat:
47  lec::outs() << "c1 != c2\n";
48  printModel();
49  outcome = failure();
50  break;
51  case z3::unknown:
52  outcome = failure();
53  lec::errs() << "circt-lec error: solver ran out of time\n";
54  }
55 
56  // Print further relevant information as requested.
57  LLVM_DEBUG(printAssertions());
58  if (statisticsOpt)
60 
61  return outcome;
62 }
63 
64 /// Create a new circuit to be compared and return it.
65 Solver::Circuit *Solver::addCircuit(llvm::StringRef name) {
66  // NOLINTNEXTLINE
67  assert(!(circuits[0] && circuits[1]) && "already added two circuits");
68  // Hack: entities within the logical engine are namespaced by the circuit
69  // they belong to, which may cause shadowing when parsing two files with a
70  // similar module naming scheme.
71  // To avoid that, they're differentiated by a prefix.
72  unsigned n = circuits[0] ? 1 : 0;
73  std::string prefix = n == 0 ? "c1@" : "c2@";
74  circuits[n] = new Solver::Circuit(prefix + name, *this);
75  return circuits[n];
76 }
77 
78 /// Prints a model satisfying the solved constraints.
80  lec::outs() << "Model:\n";
81  lec::Scope indent;
82  z3::model model = solver.get_model();
83  for (unsigned int i = 0; i < model.size(); i++) {
84  // Recover the corresponding Value for the z3::expression
85  // then emit a remark for its location.
86  z3::func_decl f = model.get_const_decl(i);
87  Builder builder(mlirCtx);
88  std::string symbolStr = f.name().str();
89  StringAttr symbol = builder.getStringAttr(symbolStr);
90  Value value = symbolTable.find(symbol)->second;
91  z3::expr e = model.get_const_interp(f);
92  emitRemark(value.getLoc(), "");
93  // Explicitly unfolded asm printing for `Value`.
94  if (auto arg = value.dyn_cast<BlockArgument>()) {
95  // Value is an argument rather than a SSA'ed value of an operation.
96  Operation *parentOp = value.getParentRegion()->getParentOp();
97  if (auto op = llvm::dyn_cast<hw::HWModuleOp>(parentOp)) {
98  // Argument of a `hw.module`.
99  lec::outs() << "argument name: " << op.getArgName(arg.getArgNumber())
100  << "\n";
101  } else {
102  // Argument of a different operation.
103  lec::outs() << arg << "\n";
104  }
105  }
106  // Accompanying model information.
107  lec::outs() << "internal symbol: " << symbol << "\n";
108  lec::outs() << "model interpretation: " << e.to_string() << "\n\n";
109  }
110 }
111 
112 /// Prints the constraints which were added to the solver.
113 /// Compared to solver.assertions().to_string() this method exposes each
114 /// assertion as a z3::expression for eventual in-depth debugging.
116  lec::dbgs() << "Assertions:\n";
117  lec::Scope indent;
118  for (z3::expr assertion : solver.assertions()) {
119  lec::dbgs() << assertion.to_string() << "\n";
120  }
121 }
122 
123 /// Prints the internal statistics of the SMT solver for benchmarking purposes
124 /// and operational insight.
126  lec::outs() << "SMT solver statistics:\n";
127  lec::Scope indent;
128  z3::stats stats = solver.statistics();
129  for (unsigned i = 0; i < stats.size(); i++) {
130  lec::outs() << stats.key(i) << " : " << stats.uint_value(i) << "\n";
131  }
132 }
133 
134 /// Formulates additional constraints which are satisfiable if only if the
135 /// two circuits which are being compared are NOT equivalent, in which case
136 /// there would be a model acting as a counterexample.
137 /// The procedure fails when detecting a mismatch of arity or type between
138 /// the inputs and outputs of the circuits.
139 LogicalResult Solver::constrainCircuits() {
140  // TODO: Perform these failure checks before nalyzing the whole IR of the
141  // modules during the pass.
142  auto c1Inputs = circuits[0]->getInputs();
143  auto c2Inputs = circuits[1]->getInputs();
144  unsigned nc1Inputs = std::distance(c1Inputs.begin(), c1Inputs.end());
145  unsigned nc2Inputs = std::distance(c2Inputs.begin(), c2Inputs.end());
146 
147  // Can't compare two circuits with different number of inputs.
148  if (nc1Inputs != nc2Inputs) {
149  lec::errs() << "circt-lec error: different input arity\n";
150  return failure();
151  }
152 
153  const auto *c1inIt = c1Inputs.begin();
154  const auto *c2inIt = c2Inputs.begin();
155  for (unsigned i = 0; i < nc1Inputs; i++) {
156  // Can't compare two circuits when their ith inputs differ in type.
157  if (c1inIt->get_sort().bv_size() != c2inIt->get_sort().bv_size()) {
158  lec::errs() << "circt-lec error: input #" << i + 1 << " type mismatch\n";
159  return failure();
160  }
161  // Their ith inputs have to be equivalent.
162  solver.add(*c1inIt++ == *c2inIt++);
163  }
164 
165  auto c1Outputs = circuits[0]->getOutputs();
166  auto c2Outputs = circuits[1]->getOutputs();
167  unsigned nc1Outputs = std::distance(c1Outputs.begin(), c1Outputs.end());
168  unsigned nc2Outputs = std::distance(c2Outputs.begin(), c2Outputs.end());
169 
170  // Can't compare two circuits with different number of outputs.
171  if (nc1Outputs != nc2Outputs) {
172  lec::errs() << "circt-lec error: different output arity\n";
173  return failure();
174  }
175 
176  z3::expr_vector outputTerms(context);
177 
178  const auto *c1outIt = c1Outputs.begin();
179  const auto *c2outIt = c2Outputs.begin();
180  for (unsigned i = 0; i < nc1Outputs; i++) {
181  // Can't compare two circuits when their ith outputs differ in type.
182  if (c1outIt->get_sort().bv_size() != c2outIt->get_sort().bv_size()) {
183  lec::errs() << "circt-lec error: output #" << i + 1 << " type mismatch\n";
184  return failure();
185  }
186  // Their ith outputs have to be equivalent.
187  outputTerms.push_back(*c1outIt++ != *c2outIt++);
188  }
189 
190  // The circuits are not equivalent iff any of the outputs is not equal
191  solver.add(z3::mk_or(outputTerms));
192 
193  return success();
194 }
assert(baseType &&"element must be base type")
Builder builder
The representation of a circuit within a logical engine.
Definition: Circuit.h:35
llvm::ArrayRef< z3::expr > getOutputs()
Recover the outputs.
Definition: Circuit.cpp:46
llvm::ArrayRef< z3::expr > getInputs()
Recover the inputs.
Definition: Circuit.cpp:43
mlir::LogicalResult constrainCircuits()
Formulates additional constraints which are satisfiable if only if the two circuits which are being c...
Definition: Solver.cpp:139
void printAssertions()
Prints the constraints which were added to the solver.
Definition: Solver.cpp:115
Circuit * circuits[2]
The two circuits to be compared.
Definition: Solver.h:65
z3::context context
The Z3 context of reference, owning all the declared values, constants and expressions.
Definition: Solver.h:70
void printModel()
Prints a model satisfying the solved constraints.
Definition: Solver.cpp:79
llvm::DenseMap< mlir::StringAttr, mlir::Value > symbolTable
A map from internal solver symbols to the IR values they represent.
Definition: Solver.h:63
mlir::LogicalResult solve()
Solve the equivalence problem between the two circuits, then present the results to the user.
Definition: Solver.cpp:32
Circuit * addCircuit(llvm::StringRef name)
Create a new circuit to be compared and return it.
Definition: Solver.cpp:65
mlir::MLIRContext * mlirCtx
The MLIR context of reference, owning all the MLIR entities.
Definition: Solver.h:67
bool statisticsOpt
The value of the statistics command-line option.
Definition: Solver.h:74
void printStatistics()
Prints the internal statistics of the SMT solver for benchmarking purposes and operational insight.
Definition: Solver.cpp:125
Solver(mlir::MLIRContext *mlirCtx, bool statisticsOpt)
Definition: Solver.cpp:26
z3::solver solver
The Z3 solver acting as the logical engine backend.
Definition: Solver.h:72
This file defines an intermediate representation for circuits acting as an abstraction for constraint...
Definition: DebugAnalysis.h:21
mlir::raw_indented_ostream & outs()
Definition: Utility.h:38
mlir::raw_indented_ostream & errs()
Definition: Utility.h:33
mlir::raw_indented_ostream & dbgs()
Definition: Utility.h:28
RAII struct to indent the output streams.
Definition: Utility.h:44