12 #include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
13 #include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
14 #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
15 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
16 #include "mlir/Conversion/LLVMCommon/TypeConverter.h"
17 #include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"
18 #include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
19 #include "mlir/Dialect/Func/IR/FuncOps.h"
20 #include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
21 #include "mlir/Dialect/LLVMIR/LLVMAttrs.h"
22 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
23 #include "mlir/Dialect/SCF/IR/SCF.h"
24 #include "mlir/IR/BuiltinDialect.h"
25 #include "mlir/Pass/Pass.h"
26 #include "mlir/Transforms/DialectConversion.h"
27 #include "llvm/ADT/TypeSwitch.h"
28 #include "llvm/Support/Debug.h"
#include <utility>
30 #define DEBUG_TYPE "lower-smt-to-z3-llvm"
33 #define GEN_PASS_DEF_LOWERSMTTOZ3LLVM
34 #include "circt/Conversion/Passes.h.inc"
38 using namespace circt;
47 OpBuilder::InsertionGuard guard(builder);
48 builder.setInsertionPointToStart(module.getBody());
55 Location loc = module.getLoc();
58 auto createGlobal = [&](StringRef namePrefix) {
59 auto global = builder.create<LLVM::GlobalOp>(
60 loc, ptrTy,
false, LLVM::Linkage::Internal, names.
newName(namePrefix),
62 OpBuilder::InsertionGuard g(builder);
63 builder.createBlock(&global.getInitializer());
64 Value res = builder.create<LLVM::ZeroOp>(loc, ptrTy);
65 builder.create<LLVM::ReturnOp>(loc, res);
69 auto ctxGlobal = createGlobal(
"ctx");
70 auto solverGlobal = createGlobal(
"solver");
75 SMTGlobalsHandler::SMTGlobalsHandler(
Namespace &&names,
76 mlir::LLVM::GlobalOp solver,
77 mlir::LLVM::GlobalOp ctx)
78 : solver(solver), ctx(ctx), names(names) {}
81 mlir::LLVM::GlobalOp solver,
82 mlir::LLVM::GlobalOp ctx)
83 : solver(solver), ctx(ctx) {
95 template <
typename OpTy>
98 SMTLoweringPattern(
const TypeConverter &typeConverter, MLIRContext *context,
100 const LowerSMTToZ3LLVMOptions &options)
105 Value buildGlobalPtrToGlobal(OpBuilder &builder, Location loc,
106 LLVM::GlobalOp global,
107 DenseMap<Block *, Value> &cache)
const {
108 Block *block = builder.getBlock();
109 if (
auto iter = cache.find(block); iter != cache.end())
110 return iter->getSecond();
112 OpBuilder::InsertionGuard g(builder);
113 builder.setInsertionPointToStart(block);
114 Value globalAddr = builder.create<LLVM::AddressOfOp>(loc, global);
115 return cache[block] = builder.create<LLVM::LoadOp>(
126 Value buildContextPtr(OpBuilder &builder, Location loc)
const {
127 return buildGlobalPtrToGlobal(builder, loc, globals.ctx, globals.ctxCache);
135 Value buildSolverPtr(OpBuilder &builder, Location loc)
const {
136 return buildGlobalPtrToGlobal(builder, loc, globals.solver,
137 globals.solverCache);
143 LLVM::CallOp buildCall(OpBuilder &builder, Location loc, StringRef name,
144 LLVM::LLVMFunctionType funcType,
145 ValueRange args)
const {
146 auto &funcOp = globals.funcMap[builder.getStringAttr(name)];
148 OpBuilder::InsertionGuard guard(builder);
150 builder.getBlock()->getParent()->getParentOfType<ModuleOp>();
151 builder.setInsertionPointToEnd(module.getBody());
152 funcOp = LLVM::lookupOrCreateFn(module, name, funcType.getParams(),
153 funcType.getReturnType(),
154 funcType.getVarArg());
156 return builder.create<LLVM::CallOp>(loc, funcOp, args);
163 Value buildString(OpBuilder &builder, Location loc, StringRef str)
const {
164 auto &global = globals.stringCache[builder.getStringAttr(str)];
166 OpBuilder::InsertionGuard guard(builder);
168 builder.getBlock()->getParent()->getParentOfType<ModuleOp>();
169 builder.setInsertionPointToEnd(module.getBody());
172 auto strAttr = builder.getStringAttr(str.str() +
'\00');
173 global = builder.create<LLVM::GlobalOp>(
174 loc, arrayTy,
true, LLVM::Linkage::Internal,
175 globals.names.newName(
"str"), strAttr);
177 return builder.create<LLVM::AddressOfOp>(loc, global);
182 LLVM::CallOp buildAPICallWithContext(OpBuilder &builder, Location loc,
183 StringRef name, Type returnType,
184 ValueRange args = {})
const {
185 auto ctx = buildContextPtr(builder, loc);
186 SmallVector<Value> arguments;
187 arguments.emplace_back(ctx);
188 arguments.append(SmallVector<Value>(args));
192 returnType, SmallVector<Type>(ValueRange(arguments).getTypes())),
199 Value buildPtrAPICall(OpBuilder &builder, Location loc, StringRef name,
200 ValueRange args = {})
const {
201 return buildAPICallWithContext(
208 Value buildSort(OpBuilder &builder, Location loc, Type type)
const {
211 return TypeSwitch<Type, Value>(type)
212 .Case([&](smt::IntType ty) {
213 return buildPtrAPICall(builder, loc,
"Z3_mk_int_sort");
215 .Case([&](smt::BitVectorType ty) {
216 Value bitwidth = builder.create<LLVM::ConstantOp>(
217 loc, builder.getI32Type(), ty.getWidth());
218 return buildPtrAPICall(builder, loc,
"Z3_mk_bv_sort", {bitwidth});
220 .Case([&](smt::BoolType ty) {
221 return buildPtrAPICall(builder, loc,
"Z3_mk_bool_sort");
223 .Case([&](smt::SortType ty) {
224 Value str = buildString(builder, loc, ty.getIdentifier());
226 buildPtrAPICall(builder, loc,
"Z3_mk_string_symbol", {str});
227 return buildPtrAPICall(builder, loc,
"Z3_mk_uninterpreted_sort",
230 .Case([&](smt::ArrayType ty) {
231 return buildPtrAPICall(builder, loc,
"Z3_mk_array_sort",
232 {buildSort(builder, loc, ty.getDomainType()),
233 buildSort(builder, loc, ty.getRangeType())});
238 const LowerSMTToZ3LLVMOptions &options;
255 struct DeclareFunOpLowering :
public SMTLoweringPattern<DeclareFunOp> {
256 using SMTLoweringPattern::SMTLoweringPattern;
259 matchAndRewrite(DeclareFunOp op, OpAdaptor adaptor,
260 ConversionPatternRewriter &rewriter)
const final {
261 Location loc = op.getLoc();
265 if (adaptor.getNamePrefix())
266 prefix = buildString(rewriter, loc, *adaptor.getNamePrefix());
268 prefix = rewriter.create<LLVM::ZeroOp>(
272 if (!isa<SMTFuncType>(op.getType())) {
273 Value sort = buildSort(rewriter, loc, op.getType());
275 buildPtrAPICall(rewriter, loc,
"Z3_mk_fresh_const", {prefix, sort});
276 rewriter.replaceOp(op, constDecl);
282 auto funcType = cast<SMTFuncType>(op.getResult().getType());
283 Value rangeSort = buildSort(rewriter, loc, funcType.getRangeType());
288 Value domain = rewriter.create<LLVM::UndefOp>(loc, arrTy);
289 for (
auto [i, ty] : llvm::enumerate(funcType.getDomainTypes())) {
290 Value sort = buildSort(rewriter, loc, ty);
291 domain = rewriter.create<LLVM::InsertValueOp>(loc, domain, sort, i);
295 rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI32Type(), 1);
296 Value domainStorage =
297 rewriter.create<LLVM::AllocaOp>(loc, llvmPtrTy, arrTy, one);
298 rewriter.create<LLVM::StoreOp>(loc, domain, domainStorage);
300 Value domainSize = rewriter.create<LLVM::ConstantOp>(
301 loc, rewriter.getI32Type(), funcType.getDomainTypes().size());
303 buildPtrAPICall(rewriter, loc,
"Z3_mk_fresh_func_decl",
304 {prefix, domainSize, domainStorage, rangeSort});
306 rewriter.replaceOp(op, decl);
316 struct ApplyFuncOpLowering :
public SMTLoweringPattern<ApplyFuncOp> {
317 using SMTLoweringPattern::SMTLoweringPattern;
320 matchAndRewrite(ApplyFuncOp op, OpAdaptor adaptor,
321 ConversionPatternRewriter &rewriter)
const final {
322 Location loc = op.getLoc();
327 Value domain = rewriter.create<LLVM::UndefOp>(loc, arrTy);
328 for (
auto [i, arg] : llvm::enumerate(adaptor.getArgs()))
329 domain = rewriter.create<LLVM::InsertValueOp>(loc, domain, arg, i);
333 rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI32Type(), 1);
334 Value domainStorage =
335 rewriter.create<LLVM::AllocaOp>(loc, llvmPtrTy, arrTy, one);
336 rewriter.create<LLVM::StoreOp>(loc, domain, domainStorage);
340 Value domainSize = rewriter.create<LLVM::ConstantOp>(
341 loc, rewriter.getI32Type(), adaptor.getArgs().size());
343 buildPtrAPICall(rewriter, loc,
"Z3_mk_app",
344 {adaptor.getFunc(), domainSize, domainStorage});
345 rewriter.replaceOp(op, returnVal);
363 struct BVConstantOpLowering :
public SMTLoweringPattern<smt::BVConstantOp> {
364 using SMTLoweringPattern::SMTLoweringPattern;
367 matchAndRewrite(smt::BVConstantOp op, OpAdaptor adaptor,
368 ConversionPatternRewriter &rewriter)
const final {
369 Location loc = op.getLoc();
370 unsigned width = op.getType().getWidth();
371 auto bvSort = buildSort(rewriter, loc, op.getResult().getType());
372 APInt val = adaptor.getValue().getValue();
375 Value bvConst = rewriter.create<LLVM::ConstantOp>(
376 loc, rewriter.getI64Type(), val.getZExtValue());
377 Value res = buildPtrAPICall(rewriter, loc,
"Z3_mk_unsigned_int64",
379 rewriter.replaceOp(op, res);
384 llvm::raw_string_ostream stream(str);
386 Value bvString = buildString(rewriter, loc, str);
388 buildPtrAPICall(rewriter, loc,
"Z3_mk_numeral", {bvString, bvSort});
390 rewriter.replaceOp(op, bvNumeral);
400 template <
typename SourceTy>
401 struct VariadicSMTPattern :
public SMTLoweringPattern<SourceTy> {
402 using OpAdaptor =
typename SMTLoweringPattern<SourceTy>::OpAdaptor;
404 VariadicSMTPattern(
const TypeConverter &typeConverter, MLIRContext *context,
406 const LowerSMTToZ3LLVMOptions &options,
407 StringRef apiFuncName,
unsigned minNumArgs)
408 : SMTLoweringPattern<SourceTy>(typeConverter, context, globals, options),
409 apiFuncName(apiFuncName), minNumArgs(minNumArgs) {}
412 matchAndRewrite(SourceTy op, OpAdaptor adaptor,
413 ConversionPatternRewriter &rewriter)
const final {
414 if (adaptor.getOperands().size() < minNumArgs)
417 Location loc = op.getLoc();
418 Value numOperands = rewriter.create<LLVM::ConstantOp>(
419 loc, rewriter.getI32Type(), op->getNumOperands());
421 rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI32Type(), 1);
425 rewriter.create<LLVM::AllocaOp>(loc, ptrTy, arrTy, constOne);
426 Value array = rewriter.create<LLVM::UndefOp>(loc, arrTy);
428 for (
auto [i, operand] : llvm::enumerate(adaptor.getOperands()))
429 array = rewriter.create<LLVM::InsertValueOp>(
430 loc, array, operand, ArrayRef<int64_t>{(int64_t)i});
432 rewriter.create<LLVM::StoreOp>(loc, array, storage);
434 rewriter.replaceOp(op,
435 SMTLoweringPattern<SourceTy>::buildPtrAPICall(
436 rewriter, loc, apiFuncName, {numOperands, storage}));
441 StringRef apiFuncName;
447 template <
typename SourceTy>
448 struct OneToOneSMTPattern :
public SMTLoweringPattern<SourceTy> {
449 using OpAdaptor =
typename SMTLoweringPattern<SourceTy>::OpAdaptor;
451 OneToOneSMTPattern(
const TypeConverter &typeConverter, MLIRContext *context,
453 const LowerSMTToZ3LLVMOptions &options,
454 StringRef apiFuncName,
unsigned numOperands)
455 : SMTLoweringPattern<SourceTy>(typeConverter, context, globals, options),
456 apiFuncName(apiFuncName), numOperands(numOperands) {}
459 matchAndRewrite(SourceTy op, OpAdaptor adaptor,
460 ConversionPatternRewriter &rewriter)
const final {
461 if (adaptor.getOperands().size() != numOperands)
465 op, SMTLoweringPattern<SourceTy>::buildPtrAPICall(
466 rewriter, op.getLoc(), apiFuncName, adaptor.getOperands()));
471 StringRef apiFuncName;
472 unsigned numOperands;
477 template <
typename SourceTy>
478 class LowerChainableSMTPattern :
public SMTLoweringPattern<SourceTy> {
479 using SMTLoweringPattern<SourceTy>::SMTLoweringPattern;
480 using OpAdaptor =
typename SMTLoweringPattern<SourceTy>::OpAdaptor;
483 matchAndRewrite(SourceTy op, OpAdaptor adaptor,
484 ConversionPatternRewriter &rewriter)
const final {
485 if (adaptor.getOperands().size() <= 2)
488 Location loc = op.getLoc();
489 SmallVector<Value> elements;
490 for (
int i = 1, e = adaptor.getOperands().size(); i < e; ++i) {
491 Value val = rewriter.create<SourceTy>(
492 loc, op->getResultTypes(),
493 ValueRange{adaptor.getOperands()[i - 1], adaptor.getOperands()[i]});
494 elements.push_back(val);
496 rewriter.replaceOpWithNewOp<smt::AndOp>(op, elements);
503 template <
typename SourceTy>
504 class LowerLeftAssocSMTPattern :
public SMTLoweringPattern<SourceTy> {
505 using SMTLoweringPattern<SourceTy>::SMTLoweringPattern;
506 using OpAdaptor =
typename SMTLoweringPattern<SourceTy>::OpAdaptor;
509 matchAndRewrite(SourceTy op, OpAdaptor adaptor,
510 ConversionPatternRewriter &rewriter)
const final {
511 if (adaptor.getOperands().size() <= 2)
512 return rewriter.notifyMatchFailure(op,
"must have at least two operands");
514 Value runner = adaptor.getOperands()[0];
515 for (Value val : adaptor.getOperands().drop_front())
516 runner = rewriter.create<SourceTy>(op.getLoc(), op->getResultTypes(),
517 ValueRange{runner, val});
519 rewriter.replaceOp(op, runner);
554 struct SolverOpLowering :
public SMTLoweringPattern<SolverOp> {
555 using SMTLoweringPattern::SMTLoweringPattern;
558 matchAndRewrite(SolverOp op, OpAdaptor adaptor,
559 ConversionPatternRewriter &rewriter)
const final {
560 Location loc = op.getLoc();
568 Value config = buildCall(rewriter, loc,
"Z3_mk_config",
575 Value paramKey = buildString(rewriter, loc,
"proof");
576 Value paramValue = buildString(rewriter, loc,
"true");
577 buildCall(rewriter, loc,
"Z3_set_param_value",
579 {config, paramKey, paramValue});
583 Value ctx = buildCall(rewriter, loc,
"Z3_mk_context", ptrToPtrFunc, config)
586 rewriter.create<LLVM::AddressOfOp>(loc, globals.ctx).getResult();
587 rewriter.create<LLVM::StoreOp>(loc, ctx, ctxAddr);
590 buildCall(rewriter, loc,
"Z3_del_config", ptrToVoidFunc, {config});
594 Value solver = buildCall(rewriter, loc,
"Z3_mk_solver", ptrToPtrFunc, ctx)
596 buildCall(rewriter, loc,
"Z3_solver_inc_ref", ptrPtrToVoidFunc,
599 rewriter.create<LLVM::AddressOfOp>(loc, globals.solver).getResult();
600 rewriter.create<LLVM::StoreOp>(loc, solver, solverAddr);
608 SmallVector<Type> convertedTypes;
610 typeConverter->convertTypes(op->getResultTypes(), convertedTypes)))
615 OpBuilder::InsertionGuard guard(rewriter);
616 auto module = op->getParentOfType<ModuleOp>();
617 rewriter.setInsertionPointToEnd(module.getBody());
619 funcOp = rewriter.create<func::FuncOp>(
620 loc, globals.names.newName(
"solver"),
621 rewriter.getFunctionType(adaptor.getInputs().getTypes(),
623 rewriter.inlineRegionBefore(op.getBodyRegion(), funcOp.getBody(),
628 rewriter.create<func::CallOp>(loc, funcOp, adaptor.getInputs())
639 buildCall(rewriter, loc,
"Z3_solver_dec_ref", ptrPtrToVoidFunc,
641 buildCall(rewriter, loc,
"Z3_del_context", ptrToVoidFunc, ctx);
643 rewriter.replaceOp(op, results);
652 struct AssertOpLowering :
public SMTLoweringPattern<AssertOp> {
653 using SMTLoweringPattern::SMTLoweringPattern;
656 matchAndRewrite(AssertOp op, OpAdaptor adaptor,
657 ConversionPatternRewriter &rewriter)
const final {
658 Location loc = op.getLoc();
659 buildAPICallWithContext(
660 rewriter, loc,
"Z3_solver_assert",
662 {buildSolverPtr(rewriter, loc), adaptor.getInput()});
664 rewriter.eraseOp(op);
674 struct YieldOpLowering :
public SMTLoweringPattern<YieldOp> {
675 using SMTLoweringPattern::SMTLoweringPattern;
678 matchAndRewrite(YieldOp op, OpAdaptor adaptor,
679 ConversionPatternRewriter &rewriter)
const final {
680 if (op->getParentOfType<func::FuncOp>()) {
681 rewriter.replaceOpWithNewOp<func::ReturnOp>(op, adaptor.getValues());
684 if (op->getParentOfType<LLVM::LLVMFuncOp>()) {
685 rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(op, adaptor.getValues());
688 if (isa<scf::SCFDialect>(op->getParentOp()->getDialect())) {
689 rewriter.replaceOpWithNewOp<scf::YieldOp>(op, adaptor.getValues());
707 struct CheckOpLowering :
public SMTLoweringPattern<CheckOp> {
708 using SMTLoweringPattern::SMTLoweringPattern;
711 matchAndRewrite(CheckOp op, OpAdaptor adaptor,
712 ConversionPatternRewriter &rewriter)
const final {
713 Location loc = op.getLoc();
718 auto getHeaderString = [](
const std::string &title) {
719 unsigned titleSize = title.size() + 2;
720 return std::string((80 - titleSize) / 2,
'-') +
" " + title +
" " +
721 std::string((80 - titleSize + 1) / 2,
'-') +
"\n%s\n" +
722 std::string(80,
'-') +
"\n";
726 Value solver = buildSolverPtr(rewriter, loc);
731 auto solverStringPtr =
732 buildPtrAPICall(rewriter, loc,
"Z3_solver_to_string", {solver});
733 auto solverFormatString =
734 buildString(rewriter, loc, getHeaderString(
"Solver"));
735 buildCall(rewriter, op.getLoc(),
"printf", printfType,
736 {solverFormatString, solverStringPtr});
740 SmallVector<Type> resultTypes;
741 if (failed(typeConverter->convertTypes(op->getResultTypes(), resultTypes)))
746 buildAPICallWithContext(rewriter, loc,
"Z3_solver_check",
747 rewriter.getI32Type(), {solver})
750 rewriter.create<LLVM::ConstantOp>(loc, checkResult.getType(), 1);
751 Value isSat = rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::eq,
752 checkResult, constOne);
755 auto satIfOp = rewriter.create<scf::IfOp>(loc, resultTypes, isSat);
756 rewriter.inlineRegionBefore(op.getSatRegion(), satIfOp.getThenRegion(),
757 satIfOp.getThenRegion().end());
762 rewriter.createBlock(&satIfOp.getElseRegion());
764 rewriter.create<LLVM::ConstantOp>(loc, checkResult.getType(), -1);
765 Value isUnsat = rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::eq,
766 checkResult, constNegOne);
767 auto unsatIfOp = rewriter.create<scf::IfOp>(loc, resultTypes, isUnsat);
768 rewriter.create<scf::YieldOp>(loc, unsatIfOp->getResults());
770 rewriter.inlineRegionBefore(op.getUnsatRegion(), unsatIfOp.getThenRegion(),
771 unsatIfOp.getThenRegion().end());
772 rewriter.inlineRegionBefore(op.getUnknownRegion(),
773 unsatIfOp.getElseRegion(),
774 unsatIfOp.getElseRegion().end());
776 rewriter.replaceOp(op, satIfOp->getResults());
781 rewriter.setInsertionPointToStart(unsatIfOp.thenBlock());
782 auto proof = buildPtrAPICall(rewriter, op.getLoc(),
"Z3_solver_get_proof",
785 buildPtrAPICall(rewriter, op.getLoc(),
"Z3_ast_to_string", {proof});
787 buildString(rewriter, op.getLoc(), getHeaderString(
"Proof"));
788 buildCall(rewriter, op.getLoc(),
"printf", printfType,
789 {formatString, stringPtr});
793 rewriter.setInsertionPointToStart(satIfOp.thenBlock());
794 auto model = buildPtrAPICall(rewriter, op.getLoc(),
"Z3_solver_get_model",
796 auto modelStringPtr =
797 buildPtrAPICall(rewriter, op.getLoc(),
"Z3_model_to_string", {model});
798 auto modelFormatString =
799 buildString(rewriter, op.getLoc(), getHeaderString(
"Model"));
800 buildCall(rewriter, op.getLoc(),
"printf", printfType,
801 {modelFormatString, modelStringPtr});
829 template <
typename QuantifierOp>
830 struct QuantifierLowering :
public SMTLoweringPattern<QuantifierOp> {
831 using SMTLoweringPattern<QuantifierOp>::SMTLoweringPattern;
832 using SMTLoweringPattern<QuantifierOp>::typeConverter;
833 using SMTLoweringPattern<QuantifierOp>::buildPtrAPICall;
834 using OpAdaptor =
typename QuantifierOp::Adaptor;
836 Value createStorageForValueList(ValueRange values, Location loc,
837 ConversionPatternRewriter &rewriter)
const {
841 rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI32Type(), 1);
843 rewriter.create<LLVM::AllocaOp>(loc, ptrTy, arrTy, constOne);
844 Value array = rewriter.create<LLVM::UndefOp>(loc, arrTy);
846 for (
auto [i, val] : llvm::enumerate(values))
847 array = rewriter.create<LLVM::InsertValueOp>(loc, array, val,
848 ArrayRef<int64_t>(i));
850 rewriter.create<LLVM::StoreOp>(loc, array, storage);
856 matchAndRewrite(QuantifierOp op, OpAdaptor adaptor,
857 ConversionPatternRewriter &rewriter)
const final {
858 Location loc = op.getLoc();
865 if (adaptor.getNoPattern())
866 return rewriter.notifyMatchFailure(
867 op,
"no-pattern attribute not yet supported!");
869 rewriter.setInsertionPoint(op);
872 Value weight = rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI32Type(),
873 adaptor.getWeight());
876 unsigned numDecls = op.getBody().getNumArguments();
878 rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI32Type(), numDecls);
885 SmallVector<Value> repl;
886 for (
auto [i, arg] : llvm::enumerate(op.getBody().getArguments())) {
888 if (adaptor.getBoundVarNames().has_value())
889 newArg = rewriter.create<smt::DeclareFunOp>(
891 cast<StringAttr>((*adaptor.getBoundVarNames())[i]));
893 newArg = rewriter.create<smt::DeclareFunOp>(loc, arg.getType());
894 repl.push_back(typeConverter->materializeTargetConversion(
895 rewriter, loc, typeConverter->convertType(arg.getType()), newArg));
898 Value boundStorage = createStorageForValueList(repl, loc, rewriter);
901 auto yieldOp = cast<smt::YieldOp>(op.getBody().front().getTerminator());
902 Value bodyExp = yieldOp.getValues()[0];
903 rewriter.setInsertionPointAfterValue(bodyExp);
904 bodyExp = typeConverter->materializeTargetConversion(
905 rewriter, loc, typeConverter->convertType(bodyExp.getType()), bodyExp);
906 rewriter.eraseOp(yieldOp);
908 rewriter.inlineBlockBefore(&op.getBody().front(), op, repl);
909 rewriter.setInsertionPoint(op);
912 unsigned numPatterns = adaptor.getPatterns().size();
913 Value numPatternsVal = rewriter.create<LLVM::ConstantOp>(
914 loc, rewriter.getI32Type(), numPatterns);
916 Value patternStorage;
917 if (numPatterns > 0) {
919 for (Region *patternRegion : adaptor.getPatterns()) {
921 cast<smt::YieldOp>(patternRegion->front().getTerminator());
922 auto patternTerms = yieldOp.getOperands();
924 rewriter.setInsertionPoint(yieldOp);
925 SmallVector<Value> patternList;
926 for (
auto val : patternTerms)
927 patternList.push_back(typeConverter->materializeTargetConversion(
928 rewriter, loc, typeConverter->convertType(val.getType()), val));
930 rewriter.eraseOp(yieldOp);
931 rewriter.inlineBlockBefore(&patternRegion->front(), op, repl);
933 rewriter.setInsertionPoint(op);
934 Value numTerms = rewriter.create<LLVM::ConstantOp>(
935 loc, rewriter.getI32Type(), patternTerms.size());
936 Value patternTermStorage =
937 createStorageForValueList(patternList, loc, rewriter);
938 Value
pattern = buildPtrAPICall(rewriter, loc,
"Z3_mk_pattern",
939 {numTerms, patternTermStorage});
943 patternStorage = createStorageForValueList(
patterns, loc, rewriter);
947 patternStorage = rewriter.create<LLVM::ZeroOp>(loc, ptrTy);
950 StringRef apiCallName =
"Z3_mk_forall_const";
951 if (std::is_same_v<QuantifierOp, ExistsOp>)
952 apiCallName =
"Z3_mk_exists_const";
953 Value quantifierExp =
954 buildPtrAPICall(rewriter, loc, apiCallName,
955 {weight, numDeclsVal, boundStorage, numPatternsVal,
956 patternStorage, bodyExp});
958 rewriter.replaceOp(op, quantifierExp);
967 struct RepeatOpLowering :
public SMTLoweringPattern<RepeatOp> {
968 using SMTLoweringPattern::SMTLoweringPattern;
971 matchAndRewrite(RepeatOp op, OpAdaptor adaptor,
972 ConversionPatternRewriter &rewriter)
const final {
973 Value count = rewriter.create<LLVM::ConstantOp>(
974 op.getLoc(), rewriter.getI32Type(), op.getCount());
975 rewriter.replaceOp(op,
976 buildPtrAPICall(rewriter, op.getLoc(),
"Z3_mk_repeat",
977 {count, adaptor.getInput()}));
989 struct ExtractOpLowering :
public SMTLoweringPattern<ExtractOp> {
990 using SMTLoweringPattern::SMTLoweringPattern;
993 matchAndRewrite(ExtractOp op, OpAdaptor adaptor,
994 ConversionPatternRewriter &rewriter)
const final {
995 Location loc = op.getLoc();
996 Value low = rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI32Type(),
997 adaptor.getLowBit());
998 Value high = rewriter.create<LLVM::ConstantOp>(
999 loc, rewriter.getI32Type(),
1000 adaptor.getLowBit() + op.getType().getWidth() - 1);
1001 rewriter.replaceOp(op, buildPtrAPICall(rewriter, loc,
"Z3_mk_extract",
1002 {high, low, adaptor.getInput()}));
1011 struct ArrayBroadcastOpLowering
1012 :
public SMTLoweringPattern<smt::ArrayBroadcastOp> {
1013 using SMTLoweringPattern::SMTLoweringPattern;
1016 matchAndRewrite(smt::ArrayBroadcastOp op, OpAdaptor adaptor,
1017 ConversionPatternRewriter &rewriter)
const final {
1018 auto domainSort = buildSort(
1019 rewriter, op.getLoc(),
1020 cast<smt::ArrayType>(op.getResult().getType()).getDomainType());
1022 rewriter.replaceOp(op, buildPtrAPICall(rewriter, op.getLoc(),
1023 "Z3_mk_const_array",
1024 {domainSort, adaptor.getValue()}));
1035 struct BoolConstantOpLowering :
public SMTLoweringPattern<smt::BoolConstantOp> {
1036 using SMTLoweringPattern::SMTLoweringPattern;
1039 matchAndRewrite(smt::BoolConstantOp op, OpAdaptor adaptor,
1040 ConversionPatternRewriter &rewriter)
const final {
1042 op, buildPtrAPICall(rewriter, op.getLoc(),
1043 adaptor.getValue() ?
"Z3_mk_true" :
"Z3_mk_false"));
1059 struct IntConstantOpLowering :
public SMTLoweringPattern<smt::IntConstantOp> {
1060 using SMTLoweringPattern::SMTLoweringPattern;
1063 matchAndRewrite(smt::IntConstantOp op, OpAdaptor adaptor,
1064 ConversionPatternRewriter &rewriter)
const final {
1065 Location loc = op.getLoc();
1066 Value type = buildPtrAPICall(rewriter, loc,
"Z3_mk_int_sort");
1067 if (adaptor.getValue().getBitWidth() <= 64) {
1068 Value val = rewriter.create<LLVM::ConstantOp>(
1069 loc, rewriter.getI64Type(), adaptor.getValue().getSExtValue());
1071 op, buildPtrAPICall(rewriter, loc,
"Z3_mk_int64", {val, type}));
1075 std::string numeralStr;
1076 llvm::raw_string_ostream stream(numeralStr);
1077 stream << adaptor.getValue().abs();
1079 Value numeral = buildString(rewriter, loc, numeralStr);
1081 buildPtrAPICall(rewriter, loc,
"Z3_mk_numeral", {numeral, type});
1083 if (adaptor.getValue().isNegative())
1085 buildPtrAPICall(rewriter, loc,
"Z3_mk_unary_minus", intNumeral);
1087 rewriter.replaceOp(op, intNumeral);
1097 struct IntCmpOpLowering :
public SMTLoweringPattern<IntCmpOp> {
1098 using SMTLoweringPattern::SMTLoweringPattern;
1101 matchAndRewrite(IntCmpOp op, OpAdaptor adaptor,
1102 ConversionPatternRewriter &rewriter)
const final {
1105 buildPtrAPICall(rewriter, op.getLoc(),
1106 "Z3_mk_" + stringifyIntPredicate(op.getPred()).str(),
1107 {adaptor.getLhs(), adaptor.getRhs()}));
1118 struct BVCmpOpLowering :
public SMTLoweringPattern<BVCmpOp> {
1119 using SMTLoweringPattern::SMTLoweringPattern;
1122 matchAndRewrite(BVCmpOp op, OpAdaptor adaptor,
1123 ConversionPatternRewriter &rewriter)
const final {
1125 op, buildPtrAPICall(rewriter, op.getLoc(),
1127 stringifyBVCmpPredicate(op.getPred()).str(),
1128 {adaptor.getLhs(), adaptor.getRhs()}));
1134 struct IntAbsOpLowering :
public SMTLoweringPattern<IntAbsOp> {
1135 using SMTLoweringPattern::SMTLoweringPattern;
1138 matchAndRewrite(IntAbsOp op, OpAdaptor adaptor,
1139 ConversionPatternRewriter &rewriter)
const final {
1140 Location loc = op.getLoc();
1141 Value zero = rewriter.create<IntConstantOp>(
1142 loc, rewriter.getIntegerAttr(rewriter.getI1Type(), 0));
1143 Value cmp = rewriter.create<IntCmpOp>(loc, IntPredicate::lt,
1144 adaptor.getInput(), zero);
1145 Value neg = rewriter.create<IntSubOp>(loc, zero, adaptor.getInput());
1146 rewriter.replaceOpWithNewOp<IteOp>(op, cmp, neg, adaptor.getInput());
1158 struct LowerSMTToZ3LLVMPass
1159 :
public circt::impl::LowerSMTToZ3LLVMBase<LowerSMTToZ3LLVMPass> {
1161 void runOnOperation()
override;
1166 converter.addConversion([](smt::BoolType type) {
1169 converter.addConversion([](smt::BitVectorType type) {
1172 converter.addConversion([](smt::ArrayType type) {
1175 converter.addConversion([](smt::IntType type) {
1178 converter.addConversion([](smt::SMTFuncType type) {
1181 converter.addConversion([](smt::SortType type) {
1187 RewritePatternSet &
patterns, TypeConverter &converter,
1189 #define ADD_VARIADIC_PATTERN(OP, APINAME, MIN_NUM_ARGS) \
1190 patterns.add<VariadicSMTPattern<OP>>( \
1191 converter, patterns.getContext(), \
1192 globals, options, APINAME, \
1195 #define ADD_ONE_TO_ONE_PATTERN(OP, APINAME, NUM_ARGS) \
1196 patterns.add<OneToOneSMTPattern<OP>>( \
1197 converter, patterns.getContext(), \
1198 globals, options, APINAME, NUM_ARGS);
1251 patterns.add<LowerLeftAssocSMTPattern<XOrOp>>(
1252 converter,
patterns.getContext(), globals, options);
1352 #undef ADD_VARIADIC_PATTERN
1353 #undef ADD_ONE_TO_ONE_PATTERN
1370 patterns.add<LowerChainableSMTPattern<EqOp>>(converter,
patterns.getContext(),
1373 globals, options,
"Z3_mk_eq", 2);
1377 patterns.add<BVConstantOpLowering, DeclareFunOpLowering, AssertOpLowering,
1378 CheckOpLowering, SolverOpLowering, ApplyFuncOpLowering,
1379 YieldOpLowering, RepeatOpLowering, ExtractOpLowering,
1380 BoolConstantOpLowering, IntConstantOpLowering,
1381 ArrayBroadcastOpLowering, BVCmpOpLowering, IntCmpOpLowering,
1382 IntAbsOpLowering, QuantifierLowering<ForallOp>,
1383 QuantifierLowering<ExistsOp>>(converter,
patterns.getContext(),
1387 void LowerSMTToZ3LLVMPass::runOnOperation() {
1388 LowerSMTToZ3LLVMOptions options;
1389 options.debug =
debug;
1392 LLVMTypeConverter converter(&getContext());
1395 RewritePatternSet
patterns(&getContext());
1411 populateFuncToLLVMConversionPatterns(converter,
patterns);
1412 arith::populateArithToLLVMConversionPatterns(converter,
patterns);
1417 populateSCFToControlFlowConversionPatterns(
patterns);
1418 mlir::cf::populateControlFlowToLLVMConversionPatterns(converter,
patterns);
1422 OpBuilder builder(&getContext());
1428 LLVMConversionTarget target(getContext());
1429 target.addLegalOp<mlir::ModuleOp>();
1430 target.addLegalOp<scf::YieldOp>();
1432 if (failed(applyFullConversion(getOperation(), target, std::move(
patterns))))
1433 return signalPassFailure();
#define ADD_VARIADIC_PATTERN(OP, APINAME, MIN_NUM_ARGS)
#define ADD_ONE_TO_ONE_PATTERN(OP, APINAME, NUM_ARGS)
RewritePatternSet pattern
A namespace that is used to store existing names and generate new names in some scope within the IR.
void add(mlir::ModuleOp module)
StringRef newName(const Twine &name)
Return a unique name, derived from the input name, and add the new name to the internal namespace.
void addDefinitions(mlir::Operation *top)
Populate the symbol cache with all symbol-defining operations within the 'top' operation.
Default symbol cache implementation; stores associations from names (StringAttrs) to mlir::Operat...
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
void populateSMTToZ3LLVMTypeConverter(TypeConverter &converter)
Populate the given type converter with the SMT to LLVM type conversions.
void populateSMTToZ3LLVMConversionPatterns(RewritePatternSet &patterns, TypeConverter &converter, SMTGlobalsHandler &globals, const LowerSMTToZ3LLVMOptions &options)
Add the SMT to LLVM IR conversion patterns to 'patterns'.
A symbol cache for LLVM globals and functions relevant to SMT lowering patterns.
static SMTGlobalsHandler create(OpBuilder &builder, ModuleOp module)
Creates the LLVM global operations to store the pointers to the solver and the context and returns a ...
SMTGlobalsHandler(ModuleOp module, mlir::LLVM::GlobalOp solver, mlir::LLVM::GlobalOp ctx)
Initializes the caches and keeps track of the given globals to store the pointers to the SMT solver a...