16 #include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h"
17 #include "mlir/Dialect/Arith/IR/Arith.h"
18 #include "mlir/Dialect/Func/IR/FuncOps.h"
19 #include "mlir/Dialect/SCF/IR/SCF.h"
20 #include "mlir/IR/ValueRange.h"
21 #include "mlir/Pass/Pass.h"
22 #include "mlir/Transforms/DialectConversion.h"
23 #include "llvm/ADT/SmallVector.h"
26 #define GEN_PASS_DEF_CONVERTVERIFTOSMT
27 #include "circt/Conversion/Passes.h.inc"
31 using namespace circt;
// Lowers `verif::AssertOp` by asserting the *negation* of its property in
// SMT: any model the solver finds is an assignment under which the property
// is false, i.e. a counterexample to the assertion.
45 matchAndRewrite(verif::AssertOp op, OpAdaptor adaptor,
46 ConversionPatternRewriter &rewriter)
const override {
// Materialize the property operand as a value of the converter's target type.
47 Value cond = typeConverter->materializeTargetConversion(
49 adaptor.getProperty());
// Negate so that a satisfying assignment witnesses a violation.
50 Value notCond = rewriter.create<smt::NotOp>(op.getLoc(), cond);
51 rewriter.replaceOpWithNewOp<smt::AssertOp>(op, notCond);
// Lowers `verif::AssumeOp` to a plain `smt::AssertOp` of its property. The
// condition is not negated (unlike the `verif::AssertOp` lowering), so it
// constrains the solver's search space instead of being the search goal.
61 matchAndRewrite(verif::AssumeOp op, OpAdaptor adaptor,
62 ConversionPatternRewriter &rewriter)
const override {
// Materialize the property operand as a value of the converter's target type.
63 Value cond = typeConverter->materializeTargetConversion(
65 adaptor.getProperty());
66 rewriter.replaceOpWithNewOp<smt::AssertOp>(op, cond);
// Lowers `verif::LogicEquivalenceCheckingOp` to an `smt::SolverOp` that
// searches for an input assignment on which the two circuits produce a
// different output. The solver's i1 result is `true` when no such assignment
// exists (unsat) and `false` when one is found or the solver gives up
// (sat / unknown).
76 struct LogicEquivalenceCheckingOpConversion
82 matchAndRewrite(verif::LogicEquivalenceCheckingOp op, OpAdaptor adaptor,
83 ConversionPatternRewriter &rewriter)
const override {
84 Location loc = op.getLoc();
85 auto *firstOutputs = adaptor.getFirstCircuit().front().getTerminator();
86 auto *secondOutputs = adaptor.getSecondCircuit().front().getTerminator();
// Circuits without outputs are trivially equivalent: fold to constant true.
88 if (firstOutputs->getNumOperands() == 0) {
91 rewriter.create<arith::ConstantOp>(loc, rewriter.getBoolAttr(
true));
92 rewriter.replaceOp(op, trueVal);
// The solver region yields a single i1 verdict (see the check below).
96 smt::SolverOp solver =
97 rewriter.create<smt::SolverOp>(loc, rewriter.getI1Type(), ValueRange{});
98 rewriter.createBlock(&solver.getBodyRegion());
101 if (failed(rewriter.convertRegionTypes(&adaptor.getFirstCircuit(),
104 if (failed(rewriter.convertRegionTypes(&adaptor.getSecondCircuit(),
// One fresh symbolic constant per (type-converted) argument of the first
// circuit; presumably these same values are used as the arguments of both
// circuits when their blocks are merged below -- confirm.
109 SmallVector<Value> inputs;
110 for (
auto arg : adaptor.getFirstCircuit().getArguments())
111 inputs.push_back(rewriter.create<smt::DeclareFunOp>(loc, arg.getType()));
120 rewriter.mergeBlocks(&adaptor.getFirstCircuit().front(), solver.getBody(),
122 rewriter.mergeBlocks(&adaptor.getSecondCircuit().front(), solver.getBody(),
124 rewriter.setInsertionPointToEnd(solver.getBody());
// For each pair of corresponding outputs, build `distinct(o1, o2)`.
130 SmallVector<Value> outputsDifferent;
131 for (
auto [out1, out2] :
132 llvm::zip(firstOutputs->getOperands(), secondOutputs->getOperands())) {
133 Value o1 = typeConverter->materializeTargetConversion(
134 rewriter, loc, typeConverter->convertType(out1.getType()), out1);
// NOTE(review): the target type for `out2` is computed from `out1`'s type.
// That is only correct if both circuits have identical output types at each
// position (presumably enforced by the op verifier -- confirm). Using
// `out2.getType()` would state the intent directly.
135 Value o2 = typeConverter->materializeTargetConversion(
136 rewriter, loc, typeConverter->convertType(out1.getType()), out2);
137 outputsDifferent.emplace_back(
138 rewriter.create<smt::DistinctOp>(loc, o1, o2));
// The original terminators are no longer needed once their operands are used.
141 rewriter.eraseOp(firstOutputs);
142 rewriter.eraseOp(secondOutputs);
// Any single differing output is a counterexample to equivalence.
145 if (outputsDifferent.size() == 1)
146 toAssert = outputsDifferent[0];
148 toAssert = rewriter.create<smt::OrOp>(loc, outputsDifferent);
150 rewriter.create<smt::AssertOp>(loc, toAssert);
154 rewriter.create<arith::ConstantOp>(loc, rewriter.getBoolAttr(
false));
156 rewriter.create<arith::ConstantOp>(loc, rewriter.getBoolAttr(
true));
// sat / unknown -> not proven equivalent (false); unsat -> equivalent (true).
157 auto checkOp = rewriter.create<smt::CheckOp>(loc, rewriter.getI1Type());
158 rewriter.createBlock(&checkOp.getSatRegion());
159 rewriter.create<smt::YieldOp>(loc, falseVal);
160 rewriter.createBlock(&checkOp.getUnknownRegion());
161 rewriter.create<smt::YieldOp>(loc, falseVal);
162 rewriter.createBlock(&checkOp.getUnsatRegion());
163 rewriter.create<smt::YieldOp>(loc, trueVal);
164 rewriter.setInsertionPointAfter(checkOp);
165 rewriter.create<smt::YieldOp>(loc, checkOp->getResults());
167 rewriter.replaceOp(op, solver->getResults());
// Lowers `verif::BoundedModelCheckingOp` to an `smt::SolverOp`.
// - The op's init, loop and circuit regions are outlined into `func.func`s
//   ("bmc_init", "bmc_loop", "bmc_circuit", uniqued via `names`).
// - The solver body calls the init function once.
// - It then runs an `scf.for` from 0 to the op's bound. Each iteration calls
//   the circuit, checks satisfiability, ORs the result into a sticky
//   "violated" flag, and calls the loop function to advance clocks and state.
// The solver's i1 result is the negation of that flag: true when no violation
// was found within the bound.
174 struct VerifBoundedModelCheckingOpConversion
178 VerifBoundedModelCheckingOpConversion(TypeConverter &converter,
183 matchAndRewrite(verif::BoundedModelCheckingOp op, OpAdaptor adaptor,
184 ConversionPatternRewriter &rewriter)
const override {
185 Location loc = op.getLoc();
// Keep the pre-conversion argument types: clock arguments are identified by
// their original `seq::ClockType`.
186 SmallVector<Type> oldLoopInputTy(op.getLoop().getArgumentTypes());
187 SmallVector<Type> oldCircuitInputTy(op.getCircuit().getArgumentTypes());
191 SmallVector<Type> loopInputTy, circuitInputTy, initOutputTy,
193 if (failed(typeConverter->convertTypes(oldLoopInputTy, loopInputTy)))
195 if (failed(typeConverter->convertTypes(oldCircuitInputTy, circuitInputTy)))
197 if (failed(typeConverter->convertTypes(
198 op.getInit().front().back().getOperandTypes(), initOutputTy)))
200 if (failed(typeConverter->convertTypes(
201 op.getCircuit().front().back().getOperandTypes(), circuitOutputTy)))
203 if (failed(rewriter.convertRegionTypes(&op.getInit(), *typeConverter)))
205 if (failed(rewriter.convertRegionTypes(&op.getLoop(), *typeConverter)))
207 if (failed(rewriter.convertRegionTypes(&op.getCircuit(), *typeConverter)))
// The last `numRegs` circuit arguments are registers; `initialValues` holds
// one attribute per register.
210 unsigned numRegs = op.getNumRegs();
211 auto initialValues = op.getInitialValues();
// The loop function returns the same types as the init function.
213 auto initFuncTy = rewriter.getFunctionType({}, initOutputTy);
216 auto loopFuncTy = rewriter.getFunctionType(loopInputTy, initOutputTy);
218 rewriter.getFunctionType(circuitInputTy, circuitOutputTy);
// Outline the three regions into functions at the end of the parent module.
220 func::FuncOp initFuncOp, loopFuncOp, circuitFuncOp;
223 OpBuilder::InsertionGuard guard(rewriter);
224 rewriter.setInsertionPointToEnd(
225 op->getParentOfType<ModuleOp>().getBody());
226 initFuncOp = rewriter.create<func::FuncOp>(loc, names.newName(
"bmc_init"),
228 rewriter.inlineRegionBefore(op.getInit(), initFuncOp.getFunctionBody(),
230 loopFuncOp = rewriter.create<func::FuncOp>(loc, names.newName(
"bmc_loop"),
232 rewriter.inlineRegionBefore(op.getLoop(), loopFuncOp.getFunctionBody(),
234 circuitFuncOp = rewriter.create<func::FuncOp>(
235 loc, names.newName(
"bmc_circuit"), circuitFuncTy);
236 rewriter.inlineRegionBefore(op.getCircuit(),
237 circuitFuncOp.getFunctionBody(),
238 circuitFuncOp.end());
239 auto funcOps = {&initFuncOp, &loopFuncOp, &circuitFuncOp};
241 auto outputTys = {initOutputTy, initOutputTy, circuitOutputTy};
// Replace each outlined region's terminator with a `func.return` whose
// operands are materialized in the converted output types.
242 for (
auto [funcOp, outputTy] : llvm::zip(funcOps, outputTys)) {
// NOTE(review): `operands` is a view into the operand list of the terminator
// erased on the next line. This is only safe while the rewriter defers the
// actual erasure. Copying into a SmallVector<Value> first would remove that
// dependency.
243 auto operands = funcOp->getBody().front().back().getOperands();
244 rewriter.eraseOp(&funcOp->getFunctionBody().front().back());
245 rewriter.setInsertionPointToEnd(&funcOp->getBody().front());
246 SmallVector<Value> toReturn;
247 for (
unsigned i = 0; i < outputTy.size(); ++i)
248 toReturn.push_back(typeConverter->materializeTargetConversion(
249 rewriter, loc, outputTy[i], operands[i]));
250 rewriter.create<func::ReturnOp>(loc, toReturn);
255 rewriter.create<smt::SolverOp>(loc, rewriter.getI1Type(), ValueRange{});
256 rewriter.createBlock(&solver.getBodyRegion());
// Initial clock values followed by initial loop state.
259 ValueRange initVals =
260 rewriter.create<func::CallOp>(loc, initFuncOp)->getResults();
// Open an assertion scope. Each loop iteration pops and re-pushes it, so
// assertions added by the previous iteration's circuit call are discarded.
263 rewriter.create<smt::PushOp>(loc, 1);
// Build the initial iter args in circuit-argument order:
// - clock arguments take the leading init results;
// - registers (the last `numRegs` arguments) get a bit-vector constant when
//   an integer initial value is given;
// - otherwise a fresh `smt::DeclareFunOp` is used (presumably also for
//   registers whose initial value is not an IntegerAttr -- confirm).
268 size_t initIndex = 0;
269 SmallVector<Value> inputDecls;
270 SmallVector<int> clockIndexes;
271 for (
auto [curIndex, oldTy, newTy] :
272 llvm::enumerate(oldCircuitInputTy, circuitInputTy)) {
273 if (isa<seq::ClockType>(oldTy)) {
274 inputDecls.push_back(initVals[initIndex++]);
275 clockIndexes.push_back(curIndex);
278 if (curIndex >= oldCircuitInputTy.size() - numRegs) {
280 initialValues[curIndex - oldCircuitInputTy.size() + numRegs];
281 if (
auto initIntAttr = dyn_cast<IntegerAttr>(initVal)) {
// NOTE(review): APInt::getSExtValue() asserts if the value does not fit in
// 64 bits, so registers wider than 64 bits with an initial value will abort
// here. Consider building the constant from the APInt directly, if the
// BVConstantOp builders allow it.
282 inputDecls.push_back(rewriter.create<smt::BVConstantOp>(
283 loc, initIntAttr.getValue().getSExtValue(),
284 cast<smt::BitVectorType>(newTy).getWidth()));
288 inputDecls.push_back(rewriter.create<smt::DeclareFunOp>(loc, newTy));
// The init results after the clocks are loop state; append them after the
// circuit inputs.
291 auto numStateArgs = initVals.size() - initIndex;
293 for (; initIndex < initVals.size(); ++initIndex)
294 inputDecls.push_back(initVals[initIndex]);
297 rewriter.create<arith::ConstantOp>(loc, rewriter.getI32IntegerAttr(0));
299 rewriter.create<arith::ConstantOp>(loc, rewriter.getI32IntegerAttr(1));
301 rewriter.create<arith::ConstantOp>(loc, adaptor.getBoundAttr());
303 rewriter.create<arith::ConstantOp>(loc, rewriter.getBoolAttr(
false));
305 rewriter.create<arith::ConstantOp>(loc, rewriter.getBoolAttr(
true));
// Last iter arg: sticky "violated" flag, initially false.
306 inputDecls.push_back(constFalse);
// Iter args layout:
//   [circuit inputs (incl. registers)] [loop state] [violated flag]
311 auto forOp = rewriter.create<scf::ForOp>(
312 loc, lowerBound, upperBound, step, inputDecls,
313 [&](OpBuilder &builder, Location loc, Value i, ValueRange iterArgs) {
315 builder.create<smt::PopOp>(loc, 1);
316 builder.create<smt::PushOp>(loc, 1);
319 ValueRange circuitCallOuts =
321 .create<func::CallOp>(
323 iterArgs.take_front(circuitFuncOp.getNumArguments()))
326 rewriter.create<smt::CheckOp>(loc, builder.getI1Type());
// NOTE(review): the `smt::CheckOp` above is created via the outer `rewriter`,
// while the rest of this lambda uses `builder`. They presumably share
// insertion state here, but `builder.create` would be consistent.
// sat / unknown -> this step may violate an assertion (true); unsat -> no
// violation at this step (false).
328 OpBuilder::InsertionGuard guard(builder);
329 builder.createBlock(&checkOp.getSatRegion());
330 builder.create<smt::YieldOp>(loc, constTrue);
331 builder.createBlock(&checkOp.getUnknownRegion());
332 builder.create<smt::YieldOp>(loc, constTrue);
333 builder.createBlock(&checkOp.getUnsatRegion());
334 builder.create<smt::YieldOp>(loc, constFalse);
// Accumulate the violation flag across iterations.
337 Value violated = builder.create<arith::OrIOp>(
338 loc, checkOp.getResult(0), iterArgs.back());
// Loop-function inputs: the clock arguments first, then the state args
// (the trailing violated flag is excluded).
341 SmallVector<Value> loopCallInputs;
343 for (
auto index : clockIndexes)
344 loopCallInputs.push_back(iterArgs[index]);
346 for (
auto stateArg : iterArgs.drop_back().take_back(numStateArgs))
347 loopCallInputs.push_back(stateArg);
348 ValueRange loopVals =
349 builder.create<func::CallOp>(loc, loopFuncOp, loopCallInputs)
// Next non-register circuit inputs. Clocks are taken in order from the
// loop results (starting at index 0); other inputs get fresh symbolic values.
352 size_t loopIndex = 0;
354 SmallVector<Value> newDecls;
355 for (
auto [oldTy, newTy] :
356 llvm::zip(TypeRange(oldCircuitInputTy).drop_back(numRegs),
357 TypeRange(circuitInputTy).drop_back(numRegs))) {
358 if (isa<seq::ClockType>(oldTy))
359 newDecls.push_back(loopVals[loopIndex++]);
361 newDecls.push_back(builder.create<smt::DeclareFunOp>(loc, newTy));
// With exactly one clock, registers update only on a rising edge:
// next = posedge ? circuit output : current register state.
367 if (clockIndexes.size() == 1) {
368 auto clockIndex = clockIndexes[0];
369 auto oldClock = iterArgs[clockIndex];
// NOTE(review): `clockIndex` is an index into the circuit arguments, but
// `loopVals` are the loop function's results. The loop above reads clocks
// from `loopVals` starting at index 0, so the single clock is presumably
// `loopVals[0]`. This line is wrong unless the clock is circuit argument 0.
370 auto newClock = loopVals[clockIndex];
// posedge = ~oldClock & newClock, compared against the 1-bit constant 1.
371 auto oldClockLow = builder.create<smt::BVNotOp>(loc, oldClock);
373 builder.create<smt::BVAndOp>(loc, oldClockLow, newClock);
375 auto trueBV = builder.create<smt::BVConstantOp>(loc, 1, 1);
377 builder.create<smt::EqOp>(loc, isPosedgeBV, trueBV);
379 iterArgs.take_front(circuitFuncOp.getNumArguments())
// The circuit's last `numRegs` results are the registers' next-state inputs.
381 auto regInputs = circuitCallOuts.take_back(numRegs);
382 SmallVector<Value> nextRegStates;
383 for (
auto [regState, regInput] : llvm::zip(regStates, regInputs)) {
387 nextRegStates.push_back(builder.create<smt::IteOp>(
388 loc, isPosedge, regInput, regState));
390 newDecls.append(nextRegStates);
// The remaining loop results (after the clocks) are the next state values.
394 for (; loopIndex < loopVals.size(); ++loopIndex)
395 newDecls.push_back(loopVals[loopIndex]);
// Carry the accumulated violation flag into the next iteration.
397 newDecls.push_back(violated);
399 builder.create<scf::YieldOp>(loc, newDecls);
// Solver verdict: the final violated flag XORed with its second operand
// (presumably `constTrue`, giving "no violation found" -- confirm).
402 Value res = rewriter.create<arith::XOrIOp>(loc, forOp->getResults().back(),
404 rewriter.create<smt::YieldOp>(loc, res);
405 rewriter.replaceOp(op, solver.getResults());
// Pass that converts Verif dialect ops into SMT dialect ops. It also emits
// supporting func/scf/arith ops for the outlined BMC functions and loop.
419 struct ConvertVerifToSMTPass
420 :
public circt::impl::ConvertVerifToSMTBase<ConvertVerifToSMTPass> {
421 void runOnOperation()
override;
// Register the assert, assume and equivalence-checking patterns. The BMC
// pattern additionally receives `names`, which it uses to generate unique
// names for the functions it outlines.
428 patterns.add<VerifAssertOpConversion, VerifAssumeOpConversion,
429 LogicEquivalenceCheckingOpConversion>(converter,
431 patterns.add<VerifBoundedModelCheckingOpConversion>(
432 converter,
patterns.getContext(), names);
// Runs the Verif-to-SMT conversion. Before applying patterns, every
// `verif::BoundedModelCheckingOp` is checked for features the lowering does
// not support:
// - initial values on non-integer registers;
// - more than one clock argument;
// - more than one reachable `verif::AssertOp`.
// Any of these emits an error and fails the pass.
435 void ConvertVerifToSMTPass::runOnOperation() {
// All Verif ops must be eliminated. The SMT/arith/scf/func ops and unrealized
// casts produced by the patterns are legal.
436 ConversionTarget target(getContext());
437 target.addIllegalDialect<verif::VerifDialect>();
438 target.addLegalDialect<smt::SMTDialect, arith::ArithDialect, scf::SCFDialect,
439 func::FuncDialect>();
440 target.addLegalOp<UnrealizedConversionCastOp>();
// Used to follow `InstanceOp`s into their referenced modules when counting
// assertions below.
444 SymbolTable symbolTable(getOperation());
445 WalkResult assertionCheck = getOperation().walk(
447 if (
auto bmcOp = dyn_cast<verif::BoundedModelCheckingOp>(op)) {
// Registers are the last `getNumRegs()` circuit arguments.
450 auto regTypes = TypeRange(bmcOp.getCircuit().getArgumentTypes())
451 .take_back(bmcOp.getNumRegs());
452 for (
auto [regType, initVal] :
453 llvm::zip(regTypes, bmcOp.getInitialValues())) {
454 if (!isa<IntegerType>(regType) && !isa<UnitAttr>(initVal)) {
455 op->emitError(
"initial values are currently only supported for "
456 "registers with integer types");
458 return WalkResult::interrupt();
462 auto numClockArgs = 0;
463 for (
auto argType : bmcOp.getCircuit().getArgumentTypes())
464 if (isa<seq::ClockType>(argType))
468 if (numClockArgs > 1) {
470 "only modules with one or zero clocks are currently supported");
471 return WalkResult::interrupt();
// Count `verif::AssertOp`s inside the BMC op, transitively through instances.
// NOTE(review): `SymbolTable::lookup` returns nullptr when the name is not
// found in this table. The result is pushed unchecked and later dereferenced
// (`module->walk`). A module instantiated several times is also walked once
// per instance.
473 SmallVector<mlir::Operation *> worklist;
474 int numAssertions = 0;
475 op->walk([&](Operation *curOp) {
476 if (isa<verif::AssertOp>(curOp))
478 if (
auto inst = dyn_cast<InstanceOp>(curOp))
479 worklist.push_back(symbolTable.lookup(inst.getModuleName()));
483 while (!worklist.empty()) {
484 auto *module = worklist.pop_back_val();
485 module->walk([&](Operation *curOp) {
486 if (isa<verif::AssertOp>(curOp))
488 if (
auto inst = dyn_cast<InstanceOp>(curOp))
489 worklist.push_back(symbolTable.lookup(inst.getModuleName()));
491 if (numAssertions > 1)
494 if (numAssertions > 1) {
496 "bounded model checking problems with multiple assertions are "
498 "correctly handled - instead, you can assert the "
499 "conjunction of your assertions");
500 return WalkResult::interrupt();
503 return WalkResult::advance();
// Abort before any rewriting if a BMC op was rejected above.
505 if (assertionCheck.wasInterrupted())
506 return signalPassFailure();
507 RewritePatternSet
patterns(&getContext());
508 TypeConverter converter;
// Partial conversion: only ops marked illegal must be rewritten.
518 if (failed(mlir::applyPartialConversion(getOperation(), target,
520 return signalPassFailure();
A namespace that is used to store existing names and generate new names in some scope within the IR.
void add(mlir::ModuleOp module)
void addDefinitions(mlir::Operation *top)
Populate the symbol cache with all symbol-defining operations within the 'top' operation.
Default symbol cache implementation; stores associations from names (StringAttrs) to mlir::Operations.
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
void populateVerifToSMTConversionPatterns(TypeConverter &converter, RewritePatternSet &patterns, Namespace &names)
Get the Verif to SMT conversion patterns.
void populateHWToSMTTypeConverter(TypeConverter &converter)
Get the HW to SMT type conversions.