14#include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h"
15#include "mlir/Dialect/Arith/IR/Arith.h"
16#include "mlir/Dialect/Func/IR/FuncOps.h"
17#include "mlir/Dialect/SCF/IR/SCF.h"
18#include "mlir/Dialect/SMT/IR/SMTOps.h"
19#include "mlir/Dialect/SMT/IR/SMTTypes.h"
20#include "mlir/IR/ValueRange.h"
21#include "mlir/Pass/Pass.h"
22#include "mlir/Transforms/DialectConversion.h"
23#include "llvm/ADT/SmallVector.h"
26#define GEN_PASS_DEF_CONVERTVERIFTOSMT
27#include "circt/Conversion/Passes.h.inc"
// Lowers a verif::AssertOp: the adapted property operand is materialized as
// an smt::BoolType value, negated, and the op is replaced by an
// smt::AssertOp on the negation.
// NOTE(review): this listing omits the enclosing pattern struct, the return
// type and the closing lines of this method.
45 matchAndRewrite(verif::AssertOp op, OpAdaptor adaptor,
46 ConversionPatternRewriter &rewriter)
const override {
// Materialize the property as an SMT boolean at the op's location.
47 Value cond = typeConverter->materializeTargetConversion(
48 rewriter, op.getLoc(), smt::BoolType::get(getContext()),
49 adaptor.getProperty());
// The negated property is what gets asserted (presumably so that a
// satisfiable query corresponds to a violated assertion — confirm).
50 Value notCond = rewriter.create<smt::NotOp>(op.getLoc(), cond);
51 rewriter.replaceOpWithNewOp<smt::AssertOp>(op, notCond);
// Lowers a verif::AssumeOp: the adapted property operand is materialized as
// an smt::BoolType value and the op is replaced by an smt::AssertOp on that
// value directly (no negation, unlike the verif::AssertOp lowering).
// NOTE(review): this listing omits the enclosing pattern struct, the return
// type and the closing lines of this method.
61 matchAndRewrite(verif::AssumeOp op, OpAdaptor adaptor,
62 ConversionPatternRewriter &rewriter)
const override {
// Materialize the property as an SMT boolean at the op's location.
63 Value cond = typeConverter->materializeTargetConversion(
64 rewriter, op.getLoc(), smt::BoolType::get(getContext()),
65 adaptor.getProperty());
66 rewriter.replaceOpWithNewOp<smt::AssertOp>(op, cond);
// Conversion pattern for verif::LogicEquivalenceCheckingOp. The visible code
// builds an smt::SolverOp with an i1 result, merges both circuit bodies into
// it, asserts that at least one pair of corresponding outputs is distinct,
// and yields the outcome of an smt::CheckOp.
// NOTE(review): several lines of this definition are absent from this
// listing (base class, return type, some declarations and closing braces).
76struct LogicEquivalenceCheckingOpConversion
82 matchAndRewrite(verif::LogicEquivalenceCheckingOp op, OpAdaptor adaptor,
83 ConversionPatternRewriter &rewriter)
const override {
84 Location loc = op.getLoc();
// The terminators of the two circuit regions; their operands are the circuit
// outputs that get compared below.
85 auto *firstOutputs = adaptor.getFirstCircuit().front().getTerminator();
86 auto *secondOutputs = adaptor.getSecondCircuit().front().getTerminator();
// No outputs to compare: the op is replaced by a constant `true`.
88 if (firstOutputs->getNumOperands() == 0) {
91 rewriter.create<arith::ConstantOp>(loc, rewriter.getBoolAttr(
true));
92 rewriter.replaceOp(op, trueVal);
// Create the solver op (single i1 result, no inputs) and its body block.
96 smt::SolverOp solver =
97 rewriter.create<smt::SolverOp>(loc, rewriter.getI1Type(), ValueRange{});
98 rewriter.createBlock(&solver.getBodyRegion());
// Convert the block argument types of both circuit regions (the remaining
// call arguments are not visible in this listing).
101 if (failed(rewriter.convertRegionTypes(&adaptor.getFirstCircuit(),
104 if (failed(rewriter.convertRegionTypes(&adaptor.getSecondCircuit(),
// One smt::DeclareFunOp per argument of the first circuit, typed like that
// argument.
109 SmallVector<Value> inputs;
110 for (
auto arg : adaptor.getFirstCircuit().getArguments())
111 inputs.push_back(rewriter.create<smt::DeclareFunOp>(loc, arg.getType()));
// Move the contents of both circuit blocks into the solver body.
120 rewriter.mergeBlocks(&adaptor.getFirstCircuit().front(), solver.getBody(),
122 rewriter.mergeBlocks(&adaptor.getSecondCircuit().front(), solver.getBody(),
124 rewriter.setInsertionPointToEnd(solver.getBody());
// For each pair of corresponding outputs, materialize both in the converted
// type and record an smt::DistinctOp over them.
130 SmallVector<Value> outputsDifferent;
131 for (
auto [out1, out2] :
132 llvm::zip(firstOutputs->getOperands(), secondOutputs->getOperands())) {
133 Value o1 = typeConverter->materializeTargetConversion(
134 rewriter, loc, typeConverter->convertType(out1.getType()), out1);
// NOTE(review): the target type for out2 is derived from out1's type; this
// presumably relies on both circuits having identical output types — confirm.
135 Value o2 = typeConverter->materializeTargetConversion(
136 rewriter, loc, typeConverter->convertType(out1.getType()), out2);
137 outputsDifferent.emplace_back(
138 rewriter.create<smt::DistinctOp>(loc, o1, o2));
// The original terminators are no longer needed once their operands have
// been consumed above.
141 rewriter.eraseOp(firstOutputs);
142 rewriter.eraseOp(secondOutputs);
// A single comparison is asserted directly; otherwise the comparisons are
// combined with smt::OrOp (the `else` line is not visible in this listing).
145 if (outputsDifferent.size() == 1)
146 toAssert = outputsDifferent[0];
148 toAssert = rewriter.create<smt::OrOp>(loc, outputsDifferent);
150 rewriter.create<smt::AssertOp>(loc, toAssert);
// Boolean constants yielded from the check regions below.
154 rewriter.create<arith::ConstantOp>(loc, rewriter.getBoolAttr(
false));
156 rewriter.create<arith::ConstantOp>(loc, rewriter.getBoolAttr(
true));
// sat and unknown yield false; unsat yields true.
157 auto checkOp = rewriter.create<smt::CheckOp>(loc, rewriter.getI1Type());
158 rewriter.createBlock(&checkOp.getSatRegion());
159 rewriter.create<smt::YieldOp>(loc, falseVal);
160 rewriter.createBlock(&checkOp.getUnknownRegion());
161 rewriter.create<smt::YieldOp>(loc, falseVal);
162 rewriter.createBlock(&checkOp.getUnsatRegion());
163 rewriter.create<smt::YieldOp>(loc, trueVal);
// The solver body yields the check result, which then replaces the op.
164 rewriter.setInsertionPointAfter(checkOp);
165 rewriter.create<smt::YieldOp>(loc, checkOp->getResults());
167 rewriter.replaceOp(op, solver->getResults());
// Conversion pattern for verif::BoundedModelCheckingOp. The visible code
// outlines the init, loop and circuit regions into func::FuncOps, then builds
// an smt::SolverOp whose body runs an scf::ForOp; each iteration calls the
// circuit function, runs an smt::CheckOp and accumulates a "violated" flag.
// NOTE(review): many lines of this definition are absent from this listing
// (base class, some declarations, call arguments and closing braces).
174struct VerifBoundedModelCheckingOpConversion
// The constructor stores the `risingClocksOnly` option in the member of the
// same name (other parameters/initializers are not visible here).
178 VerifBoundedModelCheckingOpConversion(TypeConverter &converter,
180 bool risingClocksOnly)
182 risingClocksOnly(risingClocksOnly) {}
184 matchAndRewrite(verif::BoundedModelCheckingOp op, OpAdaptor adaptor,
185 ConversionPatternRewriter &rewriter)
const override {
186 Location loc = op.getLoc();
// Remember the unconverted argument types of the loop and circuit regions;
// the old circuit types are used later to recognise seq::ClockType inputs.
187 SmallVector<Type> oldLoopInputTy(op.getLoop().getArgumentTypes());
188 SmallVector<Type> oldCircuitInputTy(op.getCircuit().getArgumentTypes());
// Converted input/output types. Output types are taken from the operand
// types of the last op in each region's first block.
192 SmallVector<Type> loopInputTy, circuitInputTy, initOutputTy,
194 if (failed(typeConverter->convertTypes(oldLoopInputTy, loopInputTy)))
196 if (failed(typeConverter->convertTypes(oldCircuitInputTy, circuitInputTy)))
198 if (failed(typeConverter->convertTypes(
199 op.getInit().front().back().getOperandTypes(), initOutputTy)))
201 if (failed(typeConverter->convertTypes(
202 op.getCircuit().front().back().getOperandTypes(), circuitOutputTy)))
// Convert the block argument types of all three regions in place.
204 if (failed(rewriter.convertRegionTypes(&op.getInit(), *typeConverter)))
206 if (failed(rewriter.convertRegionTypes(&op.getLoop(), *typeConverter)))
208 if (failed(rewriter.convertRegionTypes(&op.getCircuit(), *typeConverter)))
211 unsigned numRegs = op.getNumRegs();
212 auto initialValues = op.getInitialValues();
// Function signatures: init takes no arguments; init and loop both return
// values of the init output types.
214 auto initFuncTy = rewriter.getFunctionType({}, initOutputTy);
217 auto loopFuncTy = rewriter.getFunctionType(loopInputTy, initOutputTy);
219 rewriter.getFunctionType(circuitInputTy, circuitOutputTy);
221 func::FuncOp initFuncOp, loopFuncOp, circuitFuncOp;
// Create the three functions at the end of the enclosing module; the guard
// restores the insertion point afterwards.
224 OpBuilder::InsertionGuard guard(rewriter);
225 rewriter.setInsertionPointToEnd(
226 op->getParentOfType<ModuleOp>().getBody());
// Names are obtained from `names.newName(...)` with the bases "bmc_init",
// "bmc_loop" and "bmc_circuit"; each region is inlined into its function.
227 initFuncOp = rewriter.create<func::FuncOp>(loc, names.newName(
"bmc_init"),
229 rewriter.inlineRegionBefore(op.getInit(), initFuncOp.getFunctionBody(),
231 loopFuncOp = rewriter.create<func::FuncOp>(loc, names.newName(
"bmc_loop"),
233 rewriter.inlineRegionBefore(op.getLoop(), loopFuncOp.getFunctionBody(),
235 circuitFuncOp = rewriter.create<func::FuncOp>(
236 loc, names.newName(
"bmc_circuit"), circuitFuncTy);
237 rewriter.inlineRegionBefore(op.getCircuit(),
238 circuitFuncOp.getFunctionBody(),
239 circuitFuncOp.end());
// Replace each function's last op with a func::ReturnOp whose operands are
// the old operands materialized in the converted output types.
240 auto funcOps = {&initFuncOp, &loopFuncOp, &circuitFuncOp};
242 auto outputTys = {initOutputTy, initOutputTy, circuitOutputTy};
243 for (
auto [funcOp, outputTy] :
llvm::zip(funcOps, outputTys)) {
244 auto operands = funcOp->getBody().front().back().getOperands();
245 rewriter.eraseOp(&funcOp->getFunctionBody().front().back());
246 rewriter.setInsertionPointToEnd(&funcOp->getBody().front());
247 SmallVector<Value> toReturn;
248 for (
unsigned i = 0; i < outputTy.size(); ++i)
249 toReturn.push_back(typeConverter->materializeTargetConversion(
250 rewriter, loc, outputTy[i], operands[i]));
251 rewriter.create<func::ReturnOp>(loc, toReturn);
// Solver op with a single i1 result and a fresh body block.
256 rewriter.create<smt::SolverOp>(loc, rewriter.getI1Type(), ValueRange{});
257 rewriter.createBlock(&solver.getBodyRegion());
// Call the init function to obtain the initial values.
260 ValueRange initVals =
261 rewriter.create<func::CallOp>(loc, initFuncOp)->getResults();
// Initial push of one solver frame; the loop body below pops and pushes one
// frame at the start of every iteration.
264 rewriter.create<smt::PushOp>(loc, 1);
// Build the initial loop-carried values, one per circuit argument:
//  - clock-typed arguments take the next value from `initVals` and their
//    index is recorded in `clockIndexes`;
//  - the trailing `numRegs` arguments are registers; an IntegerAttr initial
//    value becomes an smt::BVConstantOp;
//  - an smt::DeclareFunOp of the converted type is created otherwise.
// NOTE(review): the control flow between these cases (e.g. `continue`
// statements) is not visible in this listing.
269 size_t initIndex = 0;
270 SmallVector<Value> inputDecls;
271 SmallVector<int> clockIndexes;
272 for (
auto [curIndex, oldTy, newTy] :
273 llvm::enumerate(oldCircuitInputTy, circuitInputTy)) {
274 if (isa<seq::ClockType>(oldTy)) {
275 inputDecls.push_back(initVals[initIndex++]);
276 clockIndexes.push_back(curIndex);
279 if (curIndex >= oldCircuitInputTy.size() - numRegs) {
281 initialValues[curIndex - oldCircuitInputTy.size() + numRegs];
282 if (
auto initIntAttr = dyn_cast<IntegerAttr>(initVal)) {
// NOTE(review): getSExtValue() presumably requires the value to fit in 64
// bits — confirm behaviour for registers wider than 64 bits.
283 inputDecls.push_back(rewriter.create<smt::BVConstantOp>(
284 loc, initIntAttr.getValue().getSExtValue(),
285 cast<smt::BitVectorType>(newTy).getWidth()));
289 inputDecls.push_back(rewriter.create<smt::DeclareFunOp>(loc, newTy));
// Init values not consumed by clocks are treated as state arguments and
// appended after the circuit inputs.
292 auto numStateArgs = initVals.size() - initIndex;
294 for (; initIndex < initVals.size(); ++initIndex)
295 inputDecls.push_back(initVals[initIndex]);
// Constants: i32 0 and 1, the op's bound attribute, and boolean false/true.
// The variable names they are bound to are not visible in this listing;
// `lowerBound`, `upperBound`, `step`, `constFalse` and `constTrue` are used
// below.
298 rewriter.create<arith::ConstantOp>(loc, rewriter.getI32IntegerAttr(0));
300 rewriter.create<arith::ConstantOp>(loc, rewriter.getI32IntegerAttr(1));
302 rewriter.create<arith::ConstantOp>(loc, adaptor.getBoundAttr());
304 rewriter.create<arith::ConstantOp>(loc, rewriter.getBoolAttr(
false));
306 rewriter.create<arith::ConstantOp>(loc, rewriter.getBoolAttr(
true));
// The last loop-carried value is the "violated" flag, initially false.
307 inputDecls.push_back(constFalse);
// Main loop over [lowerBound, upperBound) with the given step, carrying
// `inputDecls` as iteration arguments.
312 auto forOp = rewriter.create<scf::ForOp>(
313 loc, lowerBound, upperBound, step, inputDecls,
314 [&](OpBuilder &builder, Location loc, Value i, ValueRange iterArgs) {
// Drop the previous frame and open a new one for this iteration.
316 builder.create<smt::PopOp>(loc, 1);
317 builder.create<smt::PushOp>(loc, 1);
// Call the circuit function on the leading iteration arguments.
320 ValueRange circuitCallOuts =
322 .create<func::CallOp>(
324 iterArgs.take_front(circuitFuncOp.getNumArguments()))
// NOTE(review): this op is created through `rewriter` while the surrounding
// callback otherwise uses `builder` — confirm this is intentional.
327 rewriter.create<smt::CheckOp>(loc, builder.getI1Type());
// sat and unknown yield true; unsat yields false.
329 OpBuilder::InsertionGuard guard(builder);
330 builder.createBlock(&checkOp.getSatRegion());
331 builder.create<smt::YieldOp>(loc, constTrue);
332 builder.createBlock(&checkOp.getUnknownRegion());
333 builder.create<smt::YieldOp>(loc, constTrue);
334 builder.createBlock(&checkOp.getUnsatRegion());
335 builder.create<smt::YieldOp>(loc, constFalse);
// Accumulate: violated now, or violated in an earlier iteration.
338 Value violated = builder.create<arith::OrIOp>(
339 loc, checkOp.getResult(0), iterArgs.back());
// Loop function inputs: current clock values followed by the state
// arguments (the `numStateArgs` values just before the violated flag).
342 SmallVector<Value> loopCallInputs;
344 for (
auto index : clockIndexes)
345 loopCallInputs.push_back(iterArgs[index]);
347 for (
auto stateArg : iterArgs.drop_back().take_back(numStateArgs))
348 loopCallInputs.push_back(stateArg);
349 ValueRange loopVals =
350 builder.create<func::CallOp>(loc, loopFuncOp, loopCallInputs)
353 size_t loopIndex = 0;
// Next-iteration values for the non-register circuit inputs: clocks come
// from the loop function results, the visible alternative creates a new
// smt::DeclareFunOp.
355 SmallVector<Value> newDecls;
356 for (
auto [oldTy, newTy] :
357 llvm::zip(TypeRange(oldCircuitInputTy).drop_back(numRegs),
358 TypeRange(circuitInputTy).drop_back(numRegs))) {
359 if (isa<seq::ClockType>(oldTy))
360 newDecls.push_back(loopVals[loopIndex++]);
362 newDecls.push_back(builder.create<smt::DeclareFunOp>(loc, newTy));
// Register update for the single-clock case. The register inputs are the
// trailing `numRegs` circuit outputs.
369 if (clockIndexes.size() == 1) {
370 SmallVector<Value> regInputs = circuitCallOuts.take_back(numRegs);
// With `risingClocksOnly`, the register inputs are used directly.
371 if (risingClocksOnly) {
374 newDecls.append(regInputs);
// Otherwise a posedge condition is computed as (~oldClock & newClock) == 1
// and each register keeps its old state unless that condition holds.
376 auto clockIndex = clockIndexes[0];
377 auto oldClock = iterArgs[clockIndex];
380 auto newClock = loopVals[0];
381 auto oldClockLow = builder.create<smt::BVNotOp>(loc, oldClock);
383 builder.create<smt::BVAndOp>(loc, oldClockLow, newClock);
385 auto trueBV = builder.create<smt::BVConstantOp>(loc, 1, 1);
387 builder.create<smt::EqOp>(loc, isPosedgeBV, trueBV);
389 iterArgs.take_front(circuitFuncOp.getNumArguments())
391 SmallVector<Value> nextRegStates;
392 for (
auto [regState, regInput] :
393 llvm::zip(regStates, regInputs)) {
397 nextRegStates.push_back(builder.create<smt::IteOp>(
398 loc, isPosedge, regInput, regState));
400 newDecls.append(nextRegStates);
// Append the remaining loop function results, then the violated flag, and
// yield everything to the next iteration.
405 for (; loopIndex < loopVals.size(); ++loopIndex)
406 newDecls.push_back(loopVals[loopIndex]);
408 newDecls.push_back(violated);
410 builder.create<scf::YieldOp>(loc, newDecls);
// The solver result is the XOR of the final violated flag with a second
// operand that is not visible in this listing; the op is then replaced by
// the solver results.
413 Value res = rewriter.create<arith::XOrIOp>(loc, forOp->getResults().back(),
415 rewriter.create<smt::YieldOp>(loc, res);
416 rewriter.replaceOp(op, solver.getResults());
// Option selecting the register-update strategy in matchAndRewrite.
421 bool risingClocksOnly;
// Pass declaration: derives from the generated
// circt::impl::ConvertVerifToSMTBase and overrides runOnOperation().
// NOTE(review): the closing lines of the struct are not visible in this
// listing.
431struct ConvertVerifToSMTPass
432 :
public circt::impl::ConvertVerifToSMTBase<ConvertVerifToSMTPass> {
434 void runOnOperation()
override;
// Body of the pattern-population function (the start of its signature is not
// visible in this listing). It registers the assert, assume and
// logic-equivalence patterns, then the bounded-model-checking pattern, which
// additionally receives `names` and `risingClocksOnly`.
441 bool risingClocksOnly) {
442 patterns.add<VerifAssertOpConversion, VerifAssumeOpConversion,
443 LogicEquivalenceCheckingOpConversion>(converter,
445 patterns.add<VerifBoundedModelCheckingOpConversion>(
446 converter,
patterns.getContext(), names, risingClocksOnly);
// Pass entry point. Sets up the conversion target, validates every
// verif::BoundedModelCheckingOp found under the root operation, and then
// applies a partial conversion.
// NOTE(review): several lines of this definition are absent from this
// listing (lambda headers, counters being incremented, closing braces).
449void ConvertVerifToSMTPass::runOnOperation() {
// The Verif dialect is illegal after conversion; SMT, Arith, SCF and Func
// ops plus UnrealizedConversionCastOp are legal.
450 ConversionTarget target(getContext());
451 target.addIllegalDialect<verif::VerifDialect>();
452 target.addLegalDialect<smt::SMTDialect, arith::ArithDialect, scf::SCFDialect,
453 func::FuncDialect>();
454 target.addLegalOp<UnrealizedConversionCastOp>();
// Pre-conversion checks; any failure interrupts the walk.
458 SymbolTable symbolTable(getOperation());
459 WalkResult assertionCheck = getOperation().walk(
461 if (
auto bmcOp = dyn_cast<verif::BoundedModelCheckingOp>(op)) {
// Check 1: a register whose type is not an IntegerType must have a UnitAttr
// initial value, otherwise an error is emitted.
464 auto regTypes = TypeRange(bmcOp.getCircuit().getArgumentTypes())
465 .take_back(bmcOp.getNumRegs());
466 for (
auto [regType, initVal] :
467 llvm::zip(regTypes, bmcOp.getInitialValues())) {
468 if (!isa<IntegerType>(regType) && !isa<UnitAttr>(initVal)) {
469 op->emitError(
"initial values are currently only supported for "
470 "registers with integer types");
472 return WalkResult::interrupt();
// Check 2: more than one clock-typed circuit argument is rejected.
476 auto numClockArgs = 0;
477 for (
auto argType : bmcOp.getCircuit().getArgumentTypes())
478 if (isa<
seq::ClockType>(argType))
482 if (numClockArgs > 1) {
484 "only modules with one or zero clocks are currently supported");
485 return WalkResult::interrupt();
// Check 3: more than one assertion is rejected. Modules referenced by
// InstanceOps are looked up in the symbol table and visited through a
// worklist.
487 SmallVector<mlir::Operation *> worklist;
488 int numAssertions = 0;
489 op->walk([&](Operation *curOp) {
490 if (isa<verif::AssertOp>(curOp))
492 if (
auto inst = dyn_cast<InstanceOp>(curOp))
// NOTE(review): the lookup result is pushed without a visible null check and
// is later dereferenced via `module->walk` — confirm it cannot be null.
493 worklist.push_back(symbolTable.lookup(inst.getModuleName()));
497 while (!worklist.empty()) {
498 auto *
module = worklist.pop_back_val();
499 module->walk([&](Operation *curOp) {
500 if (isa<verif::AssertOp>(curOp))
502 if (
auto inst = dyn_cast<InstanceOp>(curOp))
503 worklist.push_back(symbolTable.lookup(inst.getModuleName()));
505 if (numAssertions > 1)
508 if (numAssertions > 1) {
510 "bounded model checking problems with multiple assertions are "
512 "correctly handled - instead, you can assert the "
513 "conjunction of your assertions");
514 return WalkResult::interrupt();
517 return WalkResult::advance();
// Any interrupted check fails the pass before patterns are applied.
519 if (assertionCheck.wasInterrupted())
520 return signalPassFailure();
521 RewritePatternSet
patterns(&getContext());
522 TypeConverter converter;
// Run the partial conversion; a failure fails the pass.
533 if (failed(mlir::applyPartialConversion(getOperation(), target,
535 return signalPassFailure();
A namespace that is used to store existing names and generate new names in some scope within the IR.
void add(mlir::ModuleOp module)
void addDefinitions(mlir::Operation *top)
Populate the symbol cache with all symbol-defining operations within the 'top' operation.
Default symbol cache implementation; stores associations between names (StringAttr's) and mlir::Operation's.
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
void populateVerifToSMTConversionPatterns(TypeConverter &converter, RewritePatternSet &patterns, Namespace &names, bool risingClocksOnly)
Get the Verif to SMT conversion patterns.
void populateHWToSMTTypeConverter(TypeConverter &converter)
Get the HW to SMT type conversions.