18 #include "mlir/IR/Builders.h"
19 #include "llvm/ADT/STLExtras.h"
21 #define DEBUG_TYPE "lec-circuit"
24 using namespace circt;
27 void Solver::Circuit::addInput(Value
value) {
28 LLVM_DEBUG(
lec::dbgs() << name <<
" addInput\n");
30 z3::expr input = fetchOrAllocateExpr(
value);
35 void Solver::Circuit::addOutput(Value
value) {
36 LLVM_DEBUG(
lec::dbgs() << name <<
" addOutput\n");
38 z3::expr output = fetchOrAllocateExpr(
value);
43 llvm::ArrayRef<z3::expr> Solver::Circuit::getInputs() {
return inputs; }
46 llvm::ArrayRef<z3::expr> Solver::Circuit::getOutputs() {
return outputs; }
52 void Solver::Circuit::addConstant(Value opResult,
const APInt &opValue) {
53 LLVM_DEBUG(
lec::dbgs() << name <<
" addConstant\n");
55 allocateConstant(opResult, opValue);
58 void Solver::Circuit::addInstance(llvm::StringRef instanceName,
59 circt::hw::HWModuleOp op,
60 OperandRange arguments, ResultRange results) {
61 LLVM_DEBUG(
lec::dbgs() << name <<
" addInstance\n");
63 LLVM_DEBUG(
lec::dbgs() <<
"instance name: " << instanceName <<
"\n");
64 LLVM_DEBUG(
lec::dbgs() <<
"module name: " << op->getName() <<
"\n");
67 std::string suffix =
"_" + std::to_string(assignments);
68 Circuit instance(name +
"@" + instanceName + suffix, solver);
73 assert(res.succeeded() &&
"Instance visit failed");
78 LLVM_DEBUG(
lec::dbgs() <<
"instance inputs:\n");
80 auto *input = instance.
inputs.begin();
81 for (Value argument : arguments) {
83 z3::expr argExpr = fetchOrAllocateExpr(argument);
84 solver.solver.add(argExpr == *input++);
88 LLVM_DEBUG(
lec::dbgs() <<
"instance results:\n");
90 auto *output = instance.
outputs.begin();
91 for (circt::OpResult result : results) {
92 z3::expr resultExpr = fetchOrAllocateExpr(result);
93 solver.solver.add(resultExpr == *output++);
102 void Solver::Circuit::performAdd(Value result, OperandRange operands) {
103 LLVM_DEBUG(
lec::dbgs() << name <<
" perform Add\n");
105 variadicOperation(result, operands,
106 [](
auto op1,
auto op2) {
return op1 + op2; });
109 void Solver::Circuit::performAnd(Value result, OperandRange operands) {
110 LLVM_DEBUG(
lec::dbgs() << name <<
" perform And\n");
112 variadicOperation(result, operands,
113 [](
auto op1,
auto op2) {
return z3::operator&(op1, op2); });
116 void Solver::Circuit::performConcat(Value result, OperandRange operands) {
117 LLVM_DEBUG(
lec::dbgs() << name <<
" perform Concat\n");
119 variadicOperation(result, operands,
120 [](
auto op1,
auto op2) {
return z3::concat(op1, op2); });
123 void Solver::Circuit::performDivS(Value result, Value lhs, Value rhs) {
124 LLVM_DEBUG(
lec::dbgs() << name <<
" perform DivS\n");
127 z3::expr lhsExpr = fetchOrAllocateExpr(lhs);
129 z3::expr rhsExpr = fetchOrAllocateExpr(rhs);
130 z3::expr op = z3::operator/(lhsExpr, rhsExpr);
131 constrainResult(result, op);
134 void Solver::Circuit::performDivU(Value result, Value lhs, Value rhs) {
135 LLVM_DEBUG(
lec::dbgs() << name <<
" perform DivU\n");
138 z3::expr lhsExpr = fetchOrAllocateExpr(lhs);
140 z3::expr rhsExpr = fetchOrAllocateExpr(rhs);
141 z3::expr op = z3::udiv(lhsExpr, rhsExpr);
142 constrainResult(result, op);
145 void Solver::Circuit::performExtract(Value result, Value input,
147 LLVM_DEBUG(
lec::dbgs() << name <<
" performExtract\n");
150 z3::expr inputExpr = fetchOrAllocateExpr(input);
151 unsigned width = result.getType().getIntOrFloatBitWidth();
153 z3::expr extract = inputExpr.extract(lowBit +
width - 1, lowBit);
154 constrainResult(result, extract);
157 LogicalResult Solver::Circuit::performICmp(Value result,
158 circt::comb::ICmpPredicate predicate,
159 Value lhs, Value rhs) {
160 LLVM_DEBUG(
lec::dbgs() << name <<
" performICmp\n");
163 z3::expr lhsExpr = fetchOrAllocateExpr(lhs);
165 z3::expr rhsExpr = fetchOrAllocateExpr(rhs);
166 z3::expr icmp(solver.context);
169 case circt::comb::ICmpPredicate::eq:
170 icmp = boolToBv(lhsExpr == rhsExpr);
172 case circt::comb::ICmpPredicate::ne:
173 icmp = boolToBv(lhsExpr != rhsExpr);
175 case circt::comb::ICmpPredicate::slt:
176 icmp = boolToBv(z3::slt(lhsExpr, rhsExpr));
178 case circt::comb::ICmpPredicate::sle:
179 icmp = boolToBv(z3::sle(lhsExpr, rhsExpr));
181 case circt::comb::ICmpPredicate::sgt:
182 icmp = boolToBv(z3::sgt(lhsExpr, rhsExpr));
184 case circt::comb::ICmpPredicate::sge:
185 icmp = boolToBv(z3::sge(lhsExpr, rhsExpr));
187 case circt::comb::ICmpPredicate::ult:
188 icmp = boolToBv(z3::ult(lhsExpr, rhsExpr));
190 case circt::comb::ICmpPredicate::ule:
191 icmp = boolToBv(z3::ule(lhsExpr, rhsExpr));
193 case circt::comb::ICmpPredicate::ugt:
194 icmp = boolToBv(z3::ugt(lhsExpr, rhsExpr));
196 case circt::comb::ICmpPredicate::uge:
197 icmp = boolToBv(z3::uge(lhsExpr, rhsExpr));
200 case circt::comb::ICmpPredicate::ceq:
201 case circt::comb::ICmpPredicate::weq:
202 case circt::comb::ICmpPredicate::cne:
203 case circt::comb::ICmpPredicate::wne:
204 result.getDefiningOp()->emitError(
205 "n-state logic predicates are not supported");
209 constrainResult(result, icmp);
213 void Solver::Circuit::performModS(Value result, Value lhs, Value rhs) {
214 LLVM_DEBUG(
lec::dbgs() << name <<
" perform ModS\n");
217 z3::expr lhsExpr = fetchOrAllocateExpr(lhs);
219 z3::expr rhsExpr = fetchOrAllocateExpr(rhs);
220 z3::expr op = z3::smod(lhsExpr, rhsExpr);
221 constrainResult(result, op);
224 void Solver::Circuit::performModU(Value result, Value lhs, Value rhs) {
225 LLVM_DEBUG(
lec::dbgs() << name <<
" perform ModU\n");
228 z3::expr lhsExpr = fetchOrAllocateExpr(lhs);
230 z3::expr rhsExpr = fetchOrAllocateExpr(rhs);
231 z3::expr op = z3::urem(lhsExpr, rhsExpr);
232 constrainResult(result, op);
235 void Solver::Circuit::performMul(Value result, OperandRange operands) {
236 LLVM_DEBUG(
lec::dbgs() << name <<
" perform Mul\n");
238 variadicOperation(result, operands,
239 [](
auto op1,
auto op2) {
return op1 * op2; });
242 void Solver::Circuit::performMux(Value result, Value cond, Value trueValue,
244 LLVM_DEBUG(
lec::dbgs() << name <<
" performMux\n");
247 z3::expr condExpr = fetchOrAllocateExpr(cond);
248 LLVM_DEBUG(
lec::dbgs() <<
"trueValue:\n");
249 z3::expr tvalue = fetchOrAllocateExpr(trueValue);
250 LLVM_DEBUG(
lec::dbgs() <<
"falseValue:\n");
251 z3::expr fvalue = fetchOrAllocateExpr(falseValue);
253 z3::expr mux = z3::ite(bvToBool(condExpr), tvalue, fvalue);
254 constrainResult(result, mux);
257 void Solver::Circuit::performOr(Value result, OperandRange operands) {
258 LLVM_DEBUG(
lec::dbgs() << name <<
" perform Or\n");
260 variadicOperation(result, operands,
261 [](
auto op1,
auto op2) {
return op1 | op2; });
264 void Solver::Circuit::performParity(Value result, Value input) {
265 LLVM_DEBUG(
lec::dbgs() << name <<
" performParity\n");
268 z3::expr inputExpr = fetchOrAllocateExpr(input);
270 unsigned width = inputExpr.get_sort().bv_size();
273 z3::expr parity = inputExpr.extract(0, 0);
275 for (
unsigned int i = 1; i <
width; i++) {
276 parity = parity ^ inputExpr.extract(i, i);
279 constrainResult(result, parity);
282 void Solver::Circuit::performReplicate(Value result, Value input) {
283 LLVM_DEBUG(
lec::dbgs() << name <<
" performReplicate\n");
286 z3::expr inputExpr = fetchOrAllocateExpr(input);
288 unsigned int final = result.getType().getIntOrFloatBitWidth();
289 unsigned int initial = input.getType().getIntOrFloatBitWidth();
290 unsigned int times =
final / initial;
291 LLVM_DEBUG(
lec::dbgs() <<
"replies: " << times <<
"\n");
293 z3::expr replicate = inputExpr;
294 for (
unsigned int i = 1; i < times; i++) {
298 constrainResult(result, replicate);
301 void Solver::Circuit::performShl(Value result, Value lhs, Value rhs) {
302 LLVM_DEBUG(
lec::dbgs() << name <<
" perform Shl\n");
305 z3::expr lhsExpr = fetchOrAllocateExpr(lhs);
307 z3::expr rhsExpr = fetchOrAllocateExpr(rhs);
308 z3::expr op = z3::shl(lhsExpr, rhsExpr);
309 constrainResult(result, op);
313 void Solver::Circuit::performShrS(Value result, Value lhs, Value rhs) {
314 LLVM_DEBUG(
lec::dbgs() << name <<
" perform ShrS\n");
317 z3::expr lhsExpr = fetchOrAllocateExpr(lhs);
319 z3::expr rhsExpr = fetchOrAllocateExpr(rhs);
320 z3::expr op = z3::ashr(lhsExpr, rhsExpr);
321 constrainResult(result, op);
325 void Solver::Circuit::performShrU(Value result, Value lhs, Value rhs) {
326 LLVM_DEBUG(
lec::dbgs() << name <<
" perform ShrU\n");
329 z3::expr lhsExpr = fetchOrAllocateExpr(lhs);
331 z3::expr rhsExpr = fetchOrAllocateExpr(rhs);
332 z3::expr op = z3::lshr(lhsExpr, rhsExpr);
333 constrainResult(result, op);
336 void Solver::Circuit::performSub(Value result, OperandRange operands) {
337 LLVM_DEBUG(
lec::dbgs() << name <<
" perform Sub\n");
339 variadicOperation(result, operands,
340 [](
auto op1,
auto op2) {
return op1 - op2; });
343 void Solver::Circuit::performXor(Value result, OperandRange operands) {
344 LLVM_DEBUG(
lec::dbgs() << name <<
" perform Xor\n");
346 variadicOperation(result, operands,
347 [](
auto op1,
auto op2) {
return op1 ^ op2; });
352 void Solver::Circuit::variadicOperation(
353 Value result, OperandRange operands,
354 llvm::function_ref<z3::expr(
const z3::expr &,
const z3::expr &)>
357 LLVM_DEBUG(
lec::dbgs() <<
"variadic operation\n");
360 auto it = operands.begin();
362 z3::expr varOp = fetchOrAllocateExpr(operand);
364 LLVM_DEBUG(
lec::dbgs() <<
"first operand:\n");
370 while (it != operands.end()) {
372 varOp = operation(varOp, fetchOrAllocateExpr(operand));
374 LLVM_DEBUG(
lec::dbgs() <<
"next operand:\n");
380 constrainResult(result, varOp);
385 z3::expr Solver::Circuit::fetchOrAllocateExpr(Value
value) {
386 z3::expr expr(solver.context);
387 auto exprPair = exprTable.find(
value);
388 if (exprPair != exprTable.end()) {
389 LLVM_DEBUG(
lec::dbgs() <<
"value already allocated:\n");
391 expr = exprPair->second;
395 std::string
valueName = name +
"%" + std::to_string(assignments++);
396 LLVM_DEBUG(
lec::dbgs() <<
"allocating value:\n");
398 Type type =
value.getType();
399 assert(type.isSignlessInteger() &&
"Unsupported type");
400 unsigned int width = type.getIntOrFloatBitWidth();
403 assert(
width > 0 &&
"0-width integers are not supported");
407 auto exprInsertion = exprTable.insert(std::pair(
value, expr));
409 assert(exprInsertion.second &&
"Value not inserted in expression table");
410 Builder
builder(solver.mlirCtx);
412 auto symInsertion = solver.symbolTable.insert(std::pair(symbol,
value));
414 assert(symInsertion.second &&
"Value not inserted in symbol table");
421 void Solver::Circuit::allocateConstant(Value result,
const APInt &
value) {
424 const z3::expr constant =
425 solver.context.bv_val(
value.getZExtValue(),
value.getBitWidth());
427 auto allocatedPair = exprTable.find(result);
428 if (allocatedPair == exprTable.end()) {
430 auto insertion = exprTable.insert(std::pair(result, constant));
432 assert(insertion.second &&
"Constant not inserted in expression table");
439 solver.solver.add(allocatedPair->second == constant);
440 LLVM_DEBUG(
lec::dbgs() <<
"constraining symbolic value to constant:\n");
449 void Solver::Circuit::constrainResult(Value &result, z3::expr &expr) {
450 LLVM_DEBUG(
lec::dbgs() <<
"constraining result:\n");
453 LLVM_DEBUG(
lec::dbgs() <<
"result expression:\n");
457 z3::expr resExpr = fetchOrAllocateExpr(result);
458 z3::expr constraint = resExpr == expr;
460 LLVM_DEBUG(
lec::dbgs() <<
"adding constraint:\n");
462 LLVM_DEBUG(
lec::dbgs() << constraint.to_string() <<
"\n");
464 solver.solver.add(constraint);
468 z3::expr Solver::Circuit::bvToBool(
const z3::expr &condition) {
470 return condition != 0;
474 z3::expr Solver::Circuit::boolToBv(
const z3::expr &condition) {
475 return z3::ite(condition, solver.context.bv_val(1, 1),
476 solver.context.bv_val(0, 1));
assert(baseType &&"element must be base type")
static std::string valueName(Operation *scopeOp, Value v)
Convenience function for getting the SSA name of v under the scope of operation scopeOp.
static SmallVector< T > concat(const SmallVectorImpl< T > &a, const SmallVectorImpl< T > &b)
Returns a new vector containing the concatenation of vectors a and b.
llvm::SmallVector< StringAttr > inputs
llvm::SmallVector< StringAttr > outputs
A class traversing MLIR IR to extrapolate the logic of a given circuit.
mlir::LogicalResult run(mlir::ModuleOp &module)
Initializes the exporting by visiting the builtin module.
The representation of a circuit within a logical engine.
llvm::SmallVector< z3::expr > outputs
The list for the circuit's outputs.
llvm::SmallVector< z3::expr > inputs
The list for the circuit's inputs.
This file defines an intermediate representation for circuits acting as an abstraction for constraint...
void printExpr(const z3::expr &expr)
Helper function to provide a common debug formatting for z3 expressions.
void printValue(const mlir::Value &value)
Helper function to provide a common debug formatting for MLIR values.
mlir::raw_indented_ostream & dbgs()
RAII struct to indent the output streams.