14#include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h"
15#include "mlir/Dialect/Arith/IR/Arith.h"
16#include "mlir/Dialect/Func/IR/FuncOps.h"
17#include "mlir/Dialect/SCF/IR/SCF.h"
18#include "mlir/Dialect/SMT/IR/SMTOps.h"
19#include "mlir/Dialect/SMT/IR/SMTTypes.h"
20#include "mlir/IR/ValueRange.h"
21#include "mlir/Pass/Pass.h"
22#include "mlir/Transforms/DialectConversion.h"
23#include "llvm/ADT/SmallVector.h"
26#define GEN_PASS_DEF_CONVERTVERIFTOSMT
27#include "circt/Conversion/Passes.h.inc"
/// Lowers `verif.assert` to `smt.assert` of the *negated* property.
/// The property is first materialized as an `!smt.bool`. Asserting its
/// negation means any model the solver finds is an assignment that violates
/// the original assertion.
45 matchAndRewrite(verif::AssertOp op, OpAdaptor adaptor,
46 ConversionPatternRewriter &rewriter)
const override {
// Cast the adaptor's (remapped) property operand to !smt.bool.
47 Value cond = typeConverter->materializeTargetConversion(
48 rewriter, op.getLoc(), smt::BoolType::get(getContext()),
49 adaptor.getProperty());
// Negate it so that satisfiability corresponds to a violation.
50 Value notCond = rewriter.create<smt::NotOp>(op.getLoc(), cond);
51 rewriter.replaceOpWithNewOp<smt::AssertOp>(op, notCond);
/// Lowers `verif.assume` to `smt.assert` of the property itself (not
/// negated). Unlike an assertion, an assumption is added as a constraint on
/// the solver rather than being checked for violations.
61 matchAndRewrite(verif::AssumeOp op, OpAdaptor adaptor,
62 ConversionPatternRewriter &rewriter)
const override {
// Cast the adaptor's (remapped) property operand to !smt.bool.
63 Value cond = typeConverter->materializeTargetConversion(
64 rewriter, op.getLoc(), smt::BoolType::get(getContext()),
65 adaptor.getProperty());
66 rewriter.replaceOpWithNewOp<smt::AssertOp>(op, cond);
/// Lowers `verif.lec` (logic equivalence checking) into an `smt.solver`.
/// Fresh symbolic inputs (`smt.declare_fun`) are created from the first
/// circuit's arguments. Both circuit bodies are merged into the solver region,
/// presumably with those symbols substituted for their block arguments.
/// The solver then looks for a model in which any pair of corresponding
/// outputs differs.
/// - `unsat` means the circuits are equivalent (true).
/// - `sat` and `unknown` both yield false (not proven equivalent).
76struct LogicEquivalenceCheckingOpConversion
82 matchAndRewrite(verif::LogicEquivalenceCheckingOp op, OpAdaptor adaptor,
83 ConversionPatternRewriter &rewriter)
const override {
84 Location loc = op.getLoc();
// Terminators of the two circuits carry the outputs to compare.
85 auto *firstOutputs = adaptor.getFirstCircuit().front().getTerminator();
86 auto *secondOutputs = adaptor.getSecondCircuit().front().getTerminator();
// Circuits without outputs are trivially equivalent: fold to `true`.
88 if (firstOutputs->getNumOperands() == 0) {
91 rewriter.create<arith::ConstantOp>(loc, rewriter.getBoolAttr(
true));
92 rewriter.replaceOp(op, trueVal);
// When the op's result is unused the solver yields nothing. Otherwise it
// yields a single i1 verdict. The choice is presumably made on `unusedResult`.
96 auto unusedResult = op.use_empty();
100 smt::SolverOp solver;
102 solver = rewriter.create<smt::SolverOp>(loc, TypeRange{}, ValueRange{});
104 solver = rewriter.create<smt::SolverOp>(loc, rewriter.getI1Type(),
107 rewriter.createBlock(&solver.getBodyRegion());
// Convert both circuits' block-argument types to SMT types.
110 if (failed(rewriter.convertRegionTypes(&adaptor.getFirstCircuit(),
113 if (failed(rewriter.convertRegionTypes(&adaptor.getSecondCircuit(),
// One symbolic constant per input, typed after the first circuit's arguments.
118 SmallVector<Value> inputs;
119 for (
auto arg : adaptor.getFirstCircuit().getArguments())
120 inputs.push_back(rewriter.create<smt::DeclareFunOp>(loc, arg.getType()));
// Inline both circuit bodies into the solver region.
129 rewriter.mergeBlocks(&adaptor.getFirstCircuit().front(), solver.getBody(),
131 rewriter.mergeBlocks(&adaptor.getSecondCircuit().front(), solver.getBody(),
133 rewriter.setInsertionPointToEnd(solver.getBody());
// One `smt.distinct` per output pair; the circuits differ if any one holds.
139 SmallVector<Value> outputsDifferent;
140 for (
auto [out1, out2] :
141 llvm::zip(firstOutputs->getOperands(), secondOutputs->getOperands())) {
142 Value o1 = typeConverter->materializeTargetConversion(
143 rewriter, loc, typeConverter->convertType(out1.getType()), out1);
// NOTE(review): o2 is materialized to the type converted from out1, not out2.
// This is only correct if corresponding outputs of both circuits have the
// same type (presumably enforced by the op verifier). Confirm, or use
// out2.getType() for clarity.
144 Value o2 = typeConverter->materializeTargetConversion(
145 rewriter, loc, typeConverter->convertType(out1.getType()), out2);
146 outputsDifferent.emplace_back(
147 rewriter.create<smt::DistinctOp>(loc, o1, o2));
// The original terminators' operands have been consumed; remove them.
150 rewriter.eraseOp(firstOutputs);
151 rewriter.eraseOp(secondOutputs);
// Assert the single `distinct` directly, or the disjunction of all of them.
// A model then describes an input that distinguishes the circuits.
154 if (outputsDifferent.size() == 1)
155 toAssert = outputsDifferent[0];
157 toAssert = rewriter.create<smt::OrOp>(loc, outputsDifferent);
159 rewriter.create<smt::AssertOp>(loc, toAssert);
// Unused-result path: every check region yields nothing, and the original op
// is erased rather than replaced.
167 auto checkOp = rewriter.create<smt::CheckOp>(loc, TypeRange{});
168 rewriter.createBlock(&checkOp.getSatRegion());
169 rewriter.create<smt::YieldOp>(loc);
170 rewriter.createBlock(&checkOp.getUnknownRegion());
171 rewriter.create<smt::YieldOp>(loc);
172 rewriter.createBlock(&checkOp.getUnsatRegion());
173 rewriter.create<smt::YieldOp>(loc);
174 rewriter.setInsertionPointAfter(checkOp);
175 rewriter.create<smt::YieldOp>(loc);
178 rewriter.eraseOp(op);
// Used-result path: sat/unknown -> false, unsat -> true (equivalent).
181 rewriter.create<arith::ConstantOp>(loc, rewriter.getBoolAttr(
false));
183 rewriter.create<arith::ConstantOp>(loc, rewriter.getBoolAttr(
true));
184 auto checkOp = rewriter.create<smt::CheckOp>(loc, rewriter.getI1Type());
185 rewriter.createBlock(&checkOp.getSatRegion());
186 rewriter.create<smt::YieldOp>(loc, falseVal);
187 rewriter.createBlock(&checkOp.getUnknownRegion());
188 rewriter.create<smt::YieldOp>(loc, falseVal);
189 rewriter.createBlock(&checkOp.getUnsatRegion());
190 rewriter.create<smt::YieldOp>(loc, trueVal);
191 rewriter.setInsertionPointAfter(checkOp);
192 rewriter.create<smt::YieldOp>(loc, checkOp->getResults());
194 rewriter.replaceOp(op, solver->getResults());
/// Lowers `verif.bmc` (bounded model checking) to SMT.
/// - The init, loop and circuit regions are outlined into uniquely named
///   `func.func`s at the end of the parent module.
/// - Inside an i1-returning `smt.solver`, an `scf.for` unrolls the circuit up
///   to the bound.
/// - Each iteration calls the circuit, runs `smt.check` for a possible
///   violation, accumulates a "violated" flag, and advances clocks/state via
///   the loop function.
/// - Each iteration also computes the next register values.
/// The solver's result is an XOR of the final "violated" flag. The other XOR
/// operand is presumably `true`, so that `true` means no violation was found
/// within the bound.
203struct VerifBoundedModelCheckingOpConversion
207 VerifBoundedModelCheckingOpConversion(TypeConverter &converter,
209 bool risingClocksOnly)
211 risingClocksOnly(risingClocksOnly) {}
213 matchAndRewrite(verif::BoundedModelCheckingOp op, OpAdaptor adaptor,
214 ConversionPatternRewriter &rewriter)
const override {
215 Location loc = op.getLoc();
// Keep the pre-conversion input types; clock inputs are identified by them.
216 SmallVector<Type> oldLoopInputTy(op.getLoop().getArgumentTypes());
217 SmallVector<Type> oldCircuitInputTy(op.getCircuit().getArgumentTypes());
// Convert the region signatures, and the operand types of each region's
// terminator, to SMT types.
221 SmallVector<Type> loopInputTy, circuitInputTy, initOutputTy,
223 if (failed(typeConverter->convertTypes(oldLoopInputTy, loopInputTy)))
225 if (failed(typeConverter->convertTypes(oldCircuitInputTy, circuitInputTy)))
227 if (failed(typeConverter->convertTypes(
228 op.getInit().front().back().getOperandTypes(), initOutputTy)))
230 if (failed(typeConverter->convertTypes(
231 op.getCircuit().front().back().getOperandTypes(), circuitOutputTy)))
233 if (failed(rewriter.convertRegionTypes(&op.getInit(), *typeConverter)))
235 if (failed(rewriter.convertRegionTypes(&op.getLoop(), *typeConverter)))
237 if (failed(rewriter.convertRegionTypes(&op.getCircuit(), *typeConverter)))
// Registers are the trailing `numRegs` circuit inputs.
240 unsigned numRegs = op.getNumRegs();
241 auto initialValues = op.getInitialValues();
// Function types for the outlined regions:
//   init:    () -> initOutputTy
//   loop:    loopInputTy -> initOutputTy
//   circuit: circuitInputTy -> circuitOutputTy
243 auto initFuncTy = rewriter.getFunctionType({}, initOutputTy);
246 auto loopFuncTy = rewriter.getFunctionType(loopInputTy, initOutputTy);
248 rewriter.getFunctionType(circuitInputTy, circuitOutputTy);
250 func::FuncOp initFuncOp, loopFuncOp, circuitFuncOp;
// Outline the regions into functions at the end of the enclosing module.
// `names` keeps the generated symbol names unique.
253 OpBuilder::InsertionGuard guard(rewriter);
254 rewriter.setInsertionPointToEnd(
255 op->getParentOfType<ModuleOp>().getBody());
256 initFuncOp = rewriter.create<func::FuncOp>(loc, names.newName(
"bmc_init"),
258 rewriter.inlineRegionBefore(op.getInit(), initFuncOp.getFunctionBody(),
260 loopFuncOp = rewriter.create<func::FuncOp>(loc, names.newName(
"bmc_loop"),
262 rewriter.inlineRegionBefore(op.getLoop(), loopFuncOp.getFunctionBody(),
264 circuitFuncOp = rewriter.create<func::FuncOp>(
265 loc, names.newName(
"bmc_circuit"), circuitFuncTy);
266 rewriter.inlineRegionBefore(op.getCircuit(),
267 circuitFuncOp.getFunctionBody(),
268 circuitFuncOp.end());
// Replace each outlined body's terminator with a func.return. The returned
// values are cast to the converted result types.
269 auto funcOps = {&initFuncOp, &loopFuncOp, &circuitFuncOp};
271 auto outputTys = {initOutputTy, initOutputTy, circuitOutputTy};
272 for (
auto [funcOp, outputTy] :
llvm::zip(funcOps, outputTys)) {
273 auto operands = funcOp->getBody().front().back().getOperands();
274 rewriter.eraseOp(&funcOp->getFunctionBody().front().back());
275 rewriter.setInsertionPointToEnd(&funcOp->getBody().front());
276 SmallVector<Value> toReturn;
277 for (
unsigned i = 0; i < outputTy.size(); ++i)
278 toReturn.push_back(typeConverter->materializeTargetConversion(
279 rewriter, loc, outputTy[i], operands[i]));
280 rewriter.create<func::ReturnOp>(loc, toReturn);
// The solver region produces a single i1 verdict.
285 rewriter.create<smt::SolverOp>(loc, rewriter.getI1Type(), ValueRange{});
286 rewriter.createBlock(&solver.getBodyRegion());
// Call init once for the initial clock/state values, then open a solver
// scope that each iteration pops and re-pushes.
289 ValueRange initVals =
290 rewriter.create<func::CallOp>(loc, initFuncOp)->getResults();
293 rewriter.create<smt::PushOp>(loc, 1);
// Build the first iteration's circuit inputs, in circuit-argument order:
// - clocks take the next init value, and their positions are recorded;
// - registers with an IntegerAttr initial value become a BV constant;
// - other inputs (and presumably registers without an integer initial value)
//   become fresh smt.declare_fun symbols.
298 size_t initIndex = 0;
299 SmallVector<Value> inputDecls;
300 SmallVector<int> clockIndexes;
301 for (
auto [curIndex, oldTy, newTy] :
302 llvm::enumerate(oldCircuitInputTy, circuitInputTy)) {
303 if (isa<seq::ClockType>(oldTy)) {
304 inputDecls.push_back(initVals[initIndex++]);
305 clockIndexes.push_back(curIndex);
308 if (curIndex >= oldCircuitInputTy.size() - numRegs) {
310 initialValues[curIndex - oldCircuitInputTy.size() + numRegs];
311 if (
auto initIntAttr = dyn_cast<IntegerAttr>(initVal)) {
// NOTE(review): APInt::getSExtValue() asserts if the value needs more than
// 64 bits. Very wide initial values would presumably trip this.
312 inputDecls.push_back(rewriter.create<smt::BVConstantOp>(
313 loc, initIntAttr.getValue().getSExtValue(),
314 cast<smt::BitVectorType>(newTy).getWidth()));
318 inputDecls.push_back(rewriter.create<smt::DeclareFunOp>(loc, newTy));
// Init results not consumed as clocks are extra state. They are carried after
// the circuit inputs in the loop's iteration arguments.
321 auto numStateArgs = initVals.size() - initIndex;
323 for (; initIndex < initVals.size(); ++initIndex)
324 inputDecls.push_back(initVals[initIndex]);
// Loop constants: 0, 1 and the bound (presumably lowerBound/step/upperBound),
// plus i1 false/true. The last iteration argument is the accumulated
// "violated" flag, which starts out false.
327 rewriter.create<arith::ConstantOp>(loc, rewriter.getI32IntegerAttr(0));
329 rewriter.create<arith::ConstantOp>(loc, rewriter.getI32IntegerAttr(1));
331 rewriter.create<arith::ConstantOp>(loc, adaptor.getBoundAttr());
333 rewriter.create<arith::ConstantOp>(loc, rewriter.getBoolAttr(
false));
335 rewriter.create<arith::ConstantOp>(loc, rewriter.getBoolAttr(
true));
336 inputDecls.push_back(constFalse);
341 auto forOp = rewriter.create<scf::ForOp>(
342 loc, lowerBound, upperBound, step, inputDecls,
343 [&](OpBuilder &builder, Location loc, Value i, ValueRange iterArgs) {
// Discard the previous iteration's assertions and open a fresh scope.
345 builder.create<smt::PopOp>(loc, 1);
346 builder.create<smt::PushOp>(loc, 1);
// Evaluate the circuit on the current inputs and register values.
349 ValueRange circuitCallOuts =
351 .create<func::CallOp>(
353 iterArgs.take_front(circuitFuncOp.getNumArguments()))
// sat/unknown -> a violation may exist (true); unsat -> none (false).
// NOTE(review): this uses `rewriter` rather than the lambda's `builder`.
// They presumably share the insertion point here, but `builder` would be
// more consistent with the surrounding code.
356 rewriter.create<smt::CheckOp>(loc, builder.getI1Type());
358 OpBuilder::InsertionGuard guard(builder);
359 builder.createBlock(&checkOp.getSatRegion());
360 builder.create<smt::YieldOp>(loc, constTrue);
361 builder.createBlock(&checkOp.getUnknownRegion());
362 builder.create<smt::YieldOp>(loc, constTrue);
363 builder.createBlock(&checkOp.getUnsatRegion());
364 builder.create<smt::YieldOp>(loc, constFalse);
// Once violated, stays violated: OR with the carried flag.
367 Value violated = builder.create<arith::OrIOp>(
368 loc, checkOp.getResult(0), iterArgs.back());
// Loop-function inputs: the current clock values, then the extra state
// arguments. The trailing violated flag is excluded by drop_back().
371 SmallVector<Value> loopCallInputs;
373 for (
auto index : clockIndexes)
374 loopCallInputs.push_back(iterArgs[index]);
376 for (
auto stateArg : iterArgs.drop_back().take_back(numStateArgs))
377 loopCallInputs.push_back(stateArg);
378 ValueRange loopVals =
379 builder.create<func::CallOp>(loc, loopFuncOp, loopCallInputs)
// Next non-register inputs: clocks come from the loop function's results in
// order; every other input is a fresh symbol.
382 size_t loopIndex = 0;
384 SmallVector<Value> newDecls;
385 for (
auto [oldTy, newTy] :
386 llvm::zip(TypeRange(oldCircuitInputTy).drop_back(numRegs),
387 TypeRange(circuitInputTy).drop_back(numRegs))) {
388 if (isa<seq::ClockType>(oldTy))
389 newDecls.push_back(loopVals[loopIndex++]);
391 newDecls.push_back(builder.create<smt::DeclareFunOp>(loc, newTy));
// Next register values when there is exactly one clock. The register inputs
// are the trailing `numRegs` circuit outputs.
// - With risingClocksOnly they are taken unconditionally.
// - Otherwise (presumably the else branch) a register updates only on a
//   rising edge: old clock low AND new clock high, computed on 1-bit
//   bit-vectors. A register without a rising edge keeps its current state.
398 if (clockIndexes.size() == 1) {
399 SmallVector<Value> regInputs = circuitCallOuts.take_back(numRegs);
400 if (risingClocksOnly) {
403 newDecls.append(regInputs);
405 auto clockIndex = clockIndexes[0];
406 auto oldClock = iterArgs[clockIndex];
// NOTE(review): the new clock is read from loopVals[0], which assumes the
// loop function returns the clock as its first result.
409 auto newClock = loopVals[0];
410 auto oldClockLow = builder.create<smt::BVNotOp>(loc, oldClock);
412 builder.create<smt::BVAndOp>(loc, oldClockLow, newClock);
414 auto trueBV = builder.create<smt::BVConstantOp>(loc, 1, 1);
416 builder.create<smt::EqOp>(loc, isPosedgeBV, trueBV);
418 iterArgs.take_front(circuitFuncOp.getNumArguments())
420 SmallVector<Value> nextRegStates;
421 for (
auto [regState, regInput] :
422 llvm::zip(regStates, regInputs)) {
426 nextRegStates.push_back(builder.create<smt::IteOp>(
427 loc, isPosedge, regInput, regState));
429 newDecls.append(nextRegStates);
// Carry the remaining loop outputs (extra state) and the violated flag.
434 for (; loopIndex < loopVals.size(); ++loopIndex)
435 newDecls.push_back(loopVals[loopIndex]);
437 newDecls.push_back(violated);
439 builder.create<scf::YieldOp>(loc, newDecls);
// Derive the solver's verdict from the final violated flag.
442 Value res = rewriter.create<arith::XOrIOp>(loc, forOp->getResults().back(),
444 rewriter.create<smt::YieldOp>(loc, res);
445 rewriter.replaceOp(op, solver.getResults());
// Whether registers update on every step (true) or only on a detected
// rising clock edge (false).
450 bool risingClocksOnly;
/// Verif-to-SMT conversion pass. The base class is the tablegen-generated
/// pass definition selected by GEN_PASS_DEF_CONVERTVERIFTOSMT. All of the
/// work happens in runOnOperation().
460struct ConvertVerifToSMTPass
461 :
public circt::impl::ConvertVerifToSMTBase<ConvertVerifToSMTPass> {
463 void runOnOperation()
override;
470 bool risingClocksOnly) {
// The assert, assume and LEC patterns need only the type converter and
// context. The BMC pattern also takes the namespace, used to give outlined
// functions unique names, and the rising-clock-only flag.
471 patterns.add<VerifAssertOpConversion, VerifAssumeOpConversion,
472 LogicEquivalenceCheckingOpConversion>(converter,
474 patterns.add<VerifBoundedModelCheckingOpConversion>(
475 converter,
patterns.getContext(), names, risingClocksOnly);
/// Runs the Verif-to-SMT conversion. First, every `verif.bmc` op is checked
/// for features the lowering does not support:
/// - registers of non-integer type that have an initial value;
/// - more than one clock input;
/// - more than one assertion.
/// Any violation emits an error and fails the pass before any rewriting.
478void ConvertVerifToSMTPass::runOnOperation() {
// All Verif ops must be rewritten into SMT/arith/scf/func ops. Unrealized
// casts stay legal and are presumably cleaned up by a reconcile step.
479 ConversionTarget target(getContext());
480 target.addIllegalDialect<verif::VerifDialect>();
481 target.addLegalDialect<smt::SMTDialect, arith::ArithDialect, scf::SCFDialect,
482 func::FuncDialect>();
483 target.addLegalOp<UnrealizedConversionCastOp>();
// Resolves instantiated modules while counting assertions.
487 SymbolTable symbolTable(getOperation());
488 WalkResult assertionCheck = getOperation().walk(
490 if (
auto bmcOp = dyn_cast<verif::BoundedModelCheckingOp>(op)) {
// Registers are the trailing `numRegs` circuit arguments.
493 auto regTypes = TypeRange(bmcOp.getCircuit().getArgumentTypes())
494 .take_back(bmcOp.getNumRegs());
495 for (
auto [regType, initVal] :
496 llvm::zip(regTypes, bmcOp.getInitialValues())) {
497 if (!isa<IntegerType>(regType) && !isa<UnitAttr>(initVal)) {
498 op->emitError(
"initial values are currently only supported for "
499 "registers with integer types");
501 return WalkResult::interrupt();
// Count clock-typed circuit inputs; only zero or one is supported.
505 auto numClockArgs = 0;
506 for (
auto argType : bmcOp.getCircuit().getArgumentTypes())
507 if (isa<
seq::ClockType>(argType))
511 if (numClockArgs > 1) {
513 "only modules with one or zero clocks are currently supported");
514 return WalkResult::interrupt();
// Count verif.assert ops inside the BMC op and, transitively, inside every
// module it instantiates.
516 SmallVector<mlir::Operation *> worklist;
517 int numAssertions = 0;
518 op->walk([&](Operation *curOp) {
519 if (isa<verif::AssertOp>(curOp))
521 if (
auto inst = dyn_cast<InstanceOp>(curOp))
522 worklist.push_back(symbolTable.lookup(inst.getModuleName()));
// NOTE(review): SymbolTable::lookup returns null when the module name is not
// found, and the popped entry below is dereferenced unchecked. There is also
// no visited set, so a module instantiated N times is walked N times. The
// numAssertions > 1 check presumably keeps this bounded in practice.
526 while (!worklist.empty()) {
527 auto *
module = worklist.pop_back_val();
528 module->walk([&](Operation *curOp) {
529 if (isa<verif::AssertOp>(curOp))
531 if (
auto inst = dyn_cast<InstanceOp>(curOp))
532 worklist.push_back(symbolTable.lookup(inst.getModuleName()));
534 if (numAssertions > 1)
537 if (numAssertions > 1) {
539 "bounded model checking problems with multiple assertions are "
541 "correctly handled - instead, you can assert the "
542 "conjunction of your assertions");
543 return WalkResult::interrupt();
546 return WalkResult::advance();
// Abort before rewriting if any BMC op failed validation.
548 if (assertionCheck.wasInterrupted())
549 return signalPassFailure();
550 RewritePatternSet
patterns(&getContext());
551 TypeConverter converter;
562 if (failed(mlir::applyPartialConversion(getOperation(), target,
564 return signalPassFailure();
A namespace that is used to store existing names and generate new names in some scope within the IR.
void add(mlir::ModuleOp module)
void addDefinitions(mlir::Operation *top)
Populate the symbol cache with all symbol-defining operations within the 'top' operation.
Default symbol cache implementation; stores associations from names (StringAttrs) to mlir::Operation*.
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
void populateVerifToSMTConversionPatterns(TypeConverter &converter, RewritePatternSet &patterns, Namespace &names, bool risingClocksOnly)
Get the Verif to SMT conversion patterns.
void populateHWToSMTTypeConverter(TypeConverter &converter)
Get the HW to SMT type conversions.