12 #include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
13 #include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
14 #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
15 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
16 #include "mlir/Conversion/LLVMCommon/TypeConverter.h"
17 #include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"
18 #include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
19 #include "mlir/Dialect/Func/IR/FuncOps.h"
20 #include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
21 #include "mlir/Dialect/LLVMIR/LLVMAttrs.h"
22 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
23 #include "mlir/Dialect/SCF/IR/SCF.h"
24 #include "mlir/IR/BuiltinDialect.h"
25 #include "mlir/Pass/Pass.h"
26 #include "mlir/Transforms/DialectConversion.h"
27 #include "llvm/ADT/TypeSwitch.h"
28 #include "llvm/Support/Debug.h"
30 #define DEBUG_TYPE "lower-smt-to-z3-llvm"
33 #define GEN_PASS_DEF_LOWERSMTTOZ3LLVM
34 #include "circt/Conversion/Passes.h.inc"
38 using namespace circt;
// Fragment of SMTGlobalsHandler::create (several lines of the definition are
// not visible in this excerpt). Emits, at the start of the module body, the
// LLVM globals that hold the Z3 context and solver pointers.
47 OpBuilder::InsertionGuard guard(builder);
48 builder.setInsertionPointToStart(module.getBody());
55 Location loc = module.getLoc();
// Helper: create one pointer-typed, internal-linkage global whose symbol name
// is made unique through `names`. Its initializer region simply returns a
// zero (null) pointer. The `false` argument is presumably the isConstant flag
// of the LLVM::GlobalOp builder; the globals are stored to later during the
// solver lowering, so they must be mutable.
58 auto createGlobal = [&](StringRef namePrefix) {
59 auto global = builder.create<LLVM::GlobalOp>(
60 loc, ptrTy,
false, LLVM::Linkage::Internal, names.
newName(namePrefix),
// Restore the outer insertion point after building the initializer block.
62 OpBuilder::InsertionGuard g(builder);
63 builder.createBlock(&global.getInitializer());
64 Value res = builder.create<LLVM::ZeroOp>(loc, ptrTy);
65 builder.create<LLVM::ReturnOp>(loc, res);
// One global for the context pointer and one for the solver pointer.
69 auto ctxGlobal = createGlobal(
"ctx");
70 auto solverGlobal = createGlobal(
"solver");
75 SMTGlobalsHandler::SMTGlobalsHandler(
Namespace &&names,
76 mlir::LLVM::GlobalOp solver,
77 mlir::LLVM::GlobalOp ctx)
78 : solver(solver), ctx(ctx), names(names) {}
// Fragment of the SMTGlobalsHandler(ModuleOp, solver, ctx) constructor; the
// first signature line and the body are not visible in this excerpt. The
// visible part records the globals that hold the solver and context pointers.
81 mlir::LLVM::GlobalOp solver,
82 mlir::LLVM::GlobalOp ctx)
83 : solver(solver), ctx(ctx) {
// Common base of the SMT-to-Z3 lowering patterns. It gives access to the shared
// SMTGlobalsHandler (`globals`) and the lowering options, and provides helpers
// to load the context/solver pointers, declare and call external Z3
// functions, emit string constants, and build Z3 sorts.
// NOTE(review): several lines of this definition (base class, member
// declarations, closing braces) are not visible in this excerpt.
95 template <
typename OpTy>
98 SMTLoweringPattern(
const TypeConverter &typeConverter, MLIRContext *context,
100 const LowerSMTToZ3LLVMOptions &options)
// Returns the pointer value stored in `global`, loaded at the start of the
// builder's current block. The loaded value is cached per block in `cache`,
// so each block loads a given global at most once.
// NOTE(review): hoisting the load to the block start assumes the global has
// already been written when control enters that block -- confirm at callers.
105 Value buildGlobalPtrToGlobal(OpBuilder &builder, Location loc,
106 LLVM::GlobalOp global,
107 DenseMap<Block *, Value> &cache)
const {
108 Block *block = builder.getBlock();
109 if (
auto iter = cache.find(block); iter != cache.end())
110 return iter->getSecond();
112 OpBuilder::InsertionGuard g(builder);
113 builder.setInsertionPointToStart(block);
114 Value globalAddr = builder.create<LLVM::AddressOfOp>(loc, global);
115 return cache[block] = builder.create<LLVM::LoadOp>(
// Returns the Z3 context pointer for the current block (cached per block).
126 Value buildContextPtr(OpBuilder &builder, Location loc)
const {
127 return buildGlobalPtrToGlobal(builder, loc, globals.ctx, globals.ctxCache);
// Returns the Z3 solver pointer for the current block (cached per block).
135 Value buildSolverPtr(OpBuilder &builder, Location loc)
const {
136 return buildGlobalPtrToGlobal(builder, loc, globals.solver,
137 globals.solverCache);
// Emits a call to the external function `name` of type `funcType`. The
// function declaration is looked up or created at the end of the enclosing
// module and memoized in `globals.funcMap`, keyed by the name.
143 LLVM::CallOp buildCall(OpBuilder &builder, Location loc, StringRef name,
144 LLVM::LLVMFunctionType funcType,
145 ValueRange args)
const {
// operator[] default-inserts an empty entry for unseen names; the guard that
// tests it is on a line not visible here.
146 auto &funcOp = globals.funcMap[builder.getStringAttr(name)];
148 OpBuilder::InsertionGuard guard(builder);
150 builder.getBlock()->getParent()->getParentOfType<ModuleOp>();
151 builder.setInsertionPointToEnd(module.getBody());
152 funcOp = LLVM::lookupOrCreateFn(module, name, funcType.getParams(),
153 funcType.getReturnType(),
154 funcType.getVarArg());
156 return builder.create<LLVM::CallOp>(loc, funcOp, args);
// Returns the address of a constant (presumably -- `true` is passed as the
// isConstant argument), internal-linkage global holding `str` plus a NUL
// terminator. Globals are created at the end of the enclosing module and
// cached in `globals.stringCache` keyed by the string, so equal strings can
// share one global.
163 Value buildString(OpBuilder &builder, Location loc, StringRef str)
const {
164 auto &global = globals.stringCache[builder.getStringAttr(str)];
166 OpBuilder::InsertionGuard guard(builder);
168 builder.getBlock()->getParent()->getParentOfType<ModuleOp>();
169 builder.setInsertionPointToEnd(module.getBody());
// Append an explicit NUL so the global can be passed as a C string.
172 auto strAttr = builder.getStringAttr(str.str() +
'\00');
173 global = builder.create<LLVM::GlobalOp>(
174 loc, arrayTy,
true, LLVM::Linkage::Internal,
175 globals.names.newName(
"str"), strAttr);
177 return builder.create<LLVM::AddressOfOp>(loc, global);
// Calls the Z3 API function `name` with the context pointer prepended to
// `args`. The callee type is derived from `returnType` and the types of the
// resulting argument list.
182 LLVM::CallOp buildAPICallWithContext(OpBuilder &builder, Location loc,
183 StringRef name, Type returnType,
184 ValueRange args = {})
const {
185 auto ctx = buildContextPtr(builder, loc);
186 SmallVector<Value> arguments;
187 arguments.emplace_back(ctx);
188 arguments.append(SmallVector<Value>(args));
192 returnType, SmallVector<Type>(ValueRange(arguments).getTypes())),
// Convenience wrapper around buildAPICallWithContext that returns a Value;
// presumably for Z3 functions returning a pointer (the forwarded arguments
// are on lines not visible here).
199 Value buildPtrAPICall(OpBuilder &builder, Location loc, StringRef name,
200 ValueRange args = {})
const {
201 return buildAPICallWithContext(
// Builds the Z3 sort for an SMT type: int, bit-vector (width passed as an
// i32 constant), bool, uninterpreted sort (named through
// Z3_mk_string_symbol from the sort identifier), and array (domain and range
// sorts built recursively).
208 Value buildSort(OpBuilder &builder, Location loc, Type type)
const {
211 return TypeSwitch<Type, Value>(type)
212 .Case([&](smt::IntType ty) {
213 return buildPtrAPICall(builder, loc,
"Z3_mk_int_sort");
215 .Case([&](smt::BitVectorType ty) {
216 Value bitwidth = builder.create<LLVM::ConstantOp>(
217 loc, builder.getI32Type(), ty.getWidth());
218 return buildPtrAPICall(builder, loc,
"Z3_mk_bv_sort", {bitwidth});
220 .Case([&](smt::BoolType ty) {
221 return buildPtrAPICall(builder, loc,
"Z3_mk_bool_sort");
223 .Case([&](smt::SortType ty) {
224 Value str = buildString(builder, loc, ty.getIdentifier());
226 buildPtrAPICall(builder, loc,
"Z3_mk_string_symbol", {str});
227 return buildPtrAPICall(builder, loc,
"Z3_mk_uninterpreted_sort",
230 .Case([&](smt::ArrayType ty) {
231 return buildPtrAPICall(builder, loc,
"Z3_mk_array_sort",
232 {buildSort(builder, loc, ty.getDomainType()),
233 buildSort(builder, loc, ty.getRangeType())});
// Lowering options shared by all patterns (held by reference).
238 const LowerSMTToZ3LLVMOptions &options;
// Lowers smt.declare_fun to fresh Z3 declarations. The name prefix is an
// emitted C string when the op has one; otherwise it is an llvm.mlir.zero
// value (presumably a null pointer). Declarations of non-function type become
// Z3_mk_fresh_const. SMT function types become Z3_mk_fresh_func_decl, with
// the domain sorts passed through a stack-allocated LLVM array.
255 struct DeclareFunOpLowering :
public SMTLoweringPattern<DeclareFunOp> {
256 using SMTLoweringPattern::SMTLoweringPattern;
259 matchAndRewrite(DeclareFunOp op, OpAdaptor adaptor,
260 ConversionPatternRewriter &rewriter)
const final {
261 Location loc = op.getLoc();
265 if (adaptor.getNamePrefix())
266 prefix = buildString(rewriter, loc, *adaptor.getNamePrefix());
268 prefix = rewriter.create<LLVM::ZeroOp>(
// Constant declaration: one fresh constant of the corresponding sort.
272 if (!isa<SMTFuncType>(op.getType())) {
273 Value sort = buildSort(rewriter, loc, op.getType());
275 buildPtrAPICall(rewriter, loc,
"Z3_mk_fresh_const", {prefix, sort});
276 rewriter.replaceOp(op, constDecl);
// Function declaration: gather the domain sorts into an LLVM array value,
// store it into an alloca so it can be passed by pointer, and pass the arity
// (as i32) plus the range sort.
282 auto funcType = cast<SMTFuncType>(op.getResult().getType());
283 Value rangeSort = buildSort(rewriter, loc, funcType.getRangeType());
288 Value domain = rewriter.create<LLVM::UndefOp>(loc, arrTy);
289 for (
auto [i, ty] : llvm::enumerate(funcType.getDomainTypes())) {
290 Value sort = buildSort(rewriter, loc, ty);
291 domain = rewriter.create<LLVM::InsertValueOp>(loc, domain, sort, i);
295 rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI32Type(), 1);
296 Value domainStorage =
297 rewriter.create<LLVM::AllocaOp>(loc, llvmPtrTy, arrTy, one);
298 rewriter.create<LLVM::StoreOp>(loc, domain, domainStorage);
300 Value domainSize = rewriter.create<LLVM::ConstantOp>(
301 loc, rewriter.getI32Type(), funcType.getDomainTypes().size());
303 buildPtrAPICall(rewriter, loc,
"Z3_mk_fresh_func_decl",
304 {prefix, domainSize, domainStorage, rangeSort});
306 rewriter.replaceOp(op, decl);
// Lowers smt.apply_func to Z3_mk_app. The converted arguments are packed into
// an LLVM array value and stored into a single-element alloca. The call then
// receives the function declaration, the argument count (i32) and the
// pointer to that storage.
316 struct ApplyFuncOpLowering :
public SMTLoweringPattern<ApplyFuncOp> {
317 using SMTLoweringPattern::SMTLoweringPattern;
320 matchAndRewrite(ApplyFuncOp op, OpAdaptor adaptor,
321 ConversionPatternRewriter &rewriter)
const final {
322 Location loc = op.getLoc();
327 Value domain = rewriter.create<LLVM::UndefOp>(loc, arrTy);
328 for (
auto [i, arg] : llvm::enumerate(adaptor.getArgs()))
329 domain = rewriter.create<LLVM::InsertValueOp>(loc, domain, arg, i);
333 rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI32Type(), 1);
334 Value domainStorage =
335 rewriter.create<LLVM::AllocaOp>(loc, llvmPtrTy, arrTy, one);
336 rewriter.create<LLVM::StoreOp>(loc, domain, domainStorage);
340 Value domainSize = rewriter.create<LLVM::ConstantOp>(
341 loc, rewriter.getI32Type(), adaptor.getArgs().size());
343 buildPtrAPICall(rewriter, loc,
"Z3_mk_app",
344 {adaptor.getFunc(), domainSize, domainStorage});
345 rewriter.replaceOp(op, returnVal);
// Lowers smt.bv.constant. One path zero-extends the value to an i64 constant
// and calls Z3_mk_unsigned_int64 with the bit-vector sort. The other path
// writes the value into the string `str` (the printing line is not visible
// here) and calls Z3_mk_numeral. The condition that selects between them is
// not visible either; presumably it is `width <= 64`, because getZExtValue
// requires the value to fit in 64 bits.
363 struct BVConstantOpLowering :
public SMTLoweringPattern<smt::BVConstantOp> {
364 using SMTLoweringPattern::SMTLoweringPattern;
367 matchAndRewrite(smt::BVConstantOp op, OpAdaptor adaptor,
368 ConversionPatternRewriter &rewriter)
const final {
369 Location loc = op.getLoc();
370 unsigned width = op.getType().getWidth();
371 auto bvSort = buildSort(rewriter, loc, op.getResult().getType());
372 APInt val = adaptor.getValue().getValue();
375 Value bvConst = rewriter.create<LLVM::ConstantOp>(
376 loc, rewriter.getI64Type(), val.getZExtValue());
377 Value res = buildPtrAPICall(rewriter, loc,
"Z3_mk_unsigned_int64",
379 rewriter.replaceOp(op, res);
384 llvm::raw_string_ostream stream(str);
386 Value bvString = buildString(rewriter, loc, str);
388 buildPtrAPICall(rewriter, loc,
"Z3_mk_numeral", {bvString, bvSort});
390 rewriter.replaceOp(op, bvNumeral);
// Generic lowering for variadic SMT operations whose Z3 counterpart takes
// (context, operand count, pointer to operand array). The pattern fails to
// match when there are fewer than `minNumArgs` operands. Otherwise the
// converted operands are stored into a stack-allocated LLVM array.
// NOTE(review): the constructor forwards `globals`, whose parameter line is
// not visible in this excerpt. `apiFuncName` is a non-owning StringRef, so
// the name passed in must outlive the pattern.
400 template <
typename SourceTy>
401 struct VariadicSMTPattern :
public SMTLoweringPattern<SourceTy> {
402 using OpAdaptor =
typename SMTLoweringPattern<SourceTy>::OpAdaptor;
404 VariadicSMTPattern(
const TypeConverter &typeConverter, MLIRContext *context,
406 const LowerSMTToZ3LLVMOptions &options,
407 StringRef apiFuncName,
unsigned minNumArgs)
408 : SMTLoweringPattern<SourceTy>(typeConverter, context, globals, options),
409 apiFuncName(apiFuncName), minNumArgs(minNumArgs) {}
412 matchAndRewrite(SourceTy op, OpAdaptor adaptor,
413 ConversionPatternRewriter &rewriter)
const final {
414 if (adaptor.getOperands().size() < minNumArgs)
417 Location loc = op.getLoc();
418 Value numOperands = rewriter.create<LLVM::ConstantOp>(
419 loc, rewriter.getI32Type(), op->getNumOperands());
421 rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI32Type(), 1);
425 rewriter.create<LLVM::AllocaOp>(loc, ptrTy, arrTy, constOne);
426 Value array = rewriter.create<LLVM::UndefOp>(loc, arrTy);
428 for (
auto [i, operand] : llvm::enumerate(adaptor.getOperands()))
429 array = rewriter.create<LLVM::InsertValueOp>(
430 loc, array, operand, ArrayRef<int64_t>{(int64_t)i});
432 rewriter.create<LLVM::StoreOp>(loc, array, storage);
434 rewriter.replaceOp(op,
435 SMTLoweringPattern<SourceTy>::buildPtrAPICall(
436 rewriter, loc, apiFuncName, {numOperands, storage}));
441 StringRef apiFuncName;
// Generic lowering for SMT operations that map directly onto one Z3 API call
// taking the converted operands as individual arguments (after the context).
// The pattern only matches when the op has exactly `numOperands` operands.
// NOTE(review): the constructor forwards `globals`, whose parameter line is
// not visible in this excerpt.
447 template <
typename SourceTy>
448 struct OneToOneSMTPattern :
public SMTLoweringPattern<SourceTy> {
449 using OpAdaptor =
typename SMTLoweringPattern<SourceTy>::OpAdaptor;
451 OneToOneSMTPattern(
const TypeConverter &typeConverter, MLIRContext *context,
453 const LowerSMTToZ3LLVMOptions &options,
454 StringRef apiFuncName,
unsigned numOperands)
455 : SMTLoweringPattern<SourceTy>(typeConverter, context, globals, options),
456 apiFuncName(apiFuncName), numOperands(numOperands) {}
459 matchAndRewrite(SourceTy op, OpAdaptor adaptor,
460 ConversionPatternRewriter &rewriter)
const final {
461 if (adaptor.getOperands().size() != numOperands)
465 op, SMTLoweringPattern<SourceTy>::buildPtrAPICall(
466 rewriter, op.getLoc(), apiFuncName, adaptor.getOperands()));
471 StringRef apiFuncName;
472 unsigned numOperands;
// Expands an n-ary chainable predicate with more than two operands into the
// conjunction of its adjacent pairs. For example, `eq(a, b, c)` becomes
// `and(eq(a, b), eq(b, c))`. The new binary ops and the smt.and are created
// through the rewriter and left for other patterns to lower.
477 template <
typename SourceTy>
478 class LowerChainableSMTPattern :
public SMTLoweringPattern<SourceTy> {
479 using SMTLoweringPattern<SourceTy>::SMTLoweringPattern;
480 using OpAdaptor =
typename SMTLoweringPattern<SourceTy>::OpAdaptor;
483 matchAndRewrite(SourceTy op, OpAdaptor adaptor,
484 ConversionPatternRewriter &rewriter)
const final {
485 if (adaptor.getOperands().size() <= 2)
488 Location loc = op.getLoc();
489 SmallVector<Value> elements;
// Build one binary op per adjacent operand pair (i-1, i).
490 for (
int i = 1, e = adaptor.getOperands().size(); i < e; ++i) {
491 Value val = rewriter.create<SourceTy>(
492 loc, op->getResultTypes(),
493 ValueRange{adaptor.getOperands()[i - 1], adaptor.getOperands()[i]});
494 elements.push_back(val);
496 rewriter.replaceOpWithNewOp<smt::AndOp>(op, elements);
// Lowers an n-ary left-associative SMT operation with more than two operands
// into a left-nested chain of binary operations of the same kind. For
// example, `xor(a, b, c)` becomes `xor(xor(a, b), c)`. The binary ops created
// here are presumably lowered by a separate pattern registered for the op.
503 template <
typename SourceTy>
504 class LowerLeftAssocSMTPattern :
public SMTLoweringPattern<SourceTy> {
505 using SMTLoweringPattern<SourceTy>::SMTLoweringPattern;
506 using OpAdaptor =
typename SMTLoweringPattern<SourceTy>::OpAdaptor;
509 matchAndRewrite(SourceTy op, OpAdaptor adaptor,
510 ConversionPatternRewriter &rewriter)
const final {
// Only three or more operands are handled here; the diagnostic states that
// requirement precisely.
511 if (adaptor.getOperands().size() <= 2)
512 return rewriter.notifyMatchFailure(op,
"must have more than two operands");
// Fold the operands left to right: runner = op(runner, next).
514 Value runner = adaptor.getOperands()[0];
515 for (Value val : adaptor.getOperands().drop_front())
516 runner = rewriter.create<SourceTy>(op.getLoc(), op->getResultTypes(),
517 ValueRange{runner, val});
519 rewriter.replaceOp(op, runner);
// Lowers smt.solver and the solver's setup and teardown sequence:
//  * create a Z3 config and, on lines whose guard is not visible here, set
//    the "proof" parameter to "true";
//  * create a context from the config, store it into the ctx global so other
//    patterns can load it, and delete the config;
//  * create a solver, via Z3_mk_solver_for_logic when the body contains an
//    smt.set_logic op (erased here; the pass checks beforehand that there is
//    at most one) or Z3_mk_solver otherwise, take a reference to it, and store
//    it into the solver global;
//  * move the body into a new func.func at the end of the module and call it
//    with the op's inputs;
//  * release the solver reference, delete the context, and replace the op
//    with the call results.
554 struct SolverOpLowering :
public SMTLoweringPattern<SolverOp> {
555 using SMTLoweringPattern::SMTLoweringPattern;
558 matchAndRewrite(SolverOp op, OpAdaptor adaptor,
559 ConversionPatternRewriter &rewriter)
const final {
560 Location loc = op.getLoc();
// Create the configuration object.
569 Value config = buildCall(rewriter, loc,
"Z3_mk_config",
576 Value paramKey = buildString(rewriter, loc,
"proof");
577 Value paramValue = buildString(rewriter, loc,
"true");
578 buildCall(rewriter, loc,
"Z3_set_param_value",
580 {config, paramKey, paramValue});
// Pick up the requested logic, if any, and drop the set-logic op.
584 std::optional<StringRef> logic = std::nullopt;
585 auto setLogicOps = op.getBodyRegion().getOps<smt::SetLogicOp>();
586 if (!setLogicOps.empty()) {
589 auto setLogicOp = *setLogicOps.begin();
590 logic = setLogicOp.getLogic();
591 rewriter.eraseOp(setLogicOp);
// Create the context, publish it through the ctx global, free the config.
595 Value ctx = buildCall(rewriter, loc,
"Z3_mk_context", ptrToPtrFunc, config)
598 rewriter.create<LLVM::AddressOfOp>(loc, globals.ctx).getResult();
599 rewriter.create<LLVM::StoreOp>(loc, ctx, ctxAddr);
602 buildCall(rewriter, loc,
"Z3_del_config", ptrToVoidFunc, {config});
// Create the solver (logic-specific when requested) and take a reference.
608 auto logicStr = buildString(rewriter, loc, logic.value());
609 solver = buildCall(rewriter, loc,
"Z3_mk_solver_for_logic",
610 ptrPtrToPtrFunc, {ctx, logicStr})
613 solver = buildCall(rewriter, loc,
"Z3_mk_solver", ptrToPtrFunc, ctx)
616 buildCall(rewriter, loc,
"Z3_solver_inc_ref", ptrPtrToVoidFunc,
619 rewriter.create<LLVM::AddressOfOp>(loc, globals.solver).getResult();
620 rewriter.create<LLVM::StoreOp>(loc, solver, solverAddr);
// Convert the result types and outline the body into a uniquely named
// func.func at the end of the module.
628 SmallVector<Type> convertedTypes;
630 typeConverter->convertTypes(op->getResultTypes(), convertedTypes)))
635 OpBuilder::InsertionGuard guard(rewriter);
636 auto module = op->getParentOfType<ModuleOp>();
637 rewriter.setInsertionPointToEnd(module.getBody());
639 funcOp = rewriter.create<func::FuncOp>(
640 loc, globals.names.newName(
"solver"),
641 rewriter.getFunctionType(adaptor.getInputs().getTypes(),
643 rewriter.inlineRegionBefore(op.getBodyRegion(), funcOp.getBody(),
648 rewriter.create<func::CallOp>(loc, funcOp, adaptor.getInputs())
// Tear down after the outlined body has run.
659 buildCall(rewriter, loc,
"Z3_solver_dec_ref", ptrPtrToVoidFunc,
661 buildCall(rewriter, loc,
"Z3_del_context", ptrToVoidFunc, ctx);
663 rewriter.replaceOp(op, results);
// Lowers smt.assert to Z3_solver_assert(ctx, solver, input) and erases the op.
672 struct AssertOpLowering :
public SMTLoweringPattern<AssertOp> {
673 using SMTLoweringPattern::SMTLoweringPattern;
676 matchAndRewrite(AssertOp op, OpAdaptor adaptor,
677 ConversionPatternRewriter &rewriter)
const final {
678 Location loc = op.getLoc();
679 buildAPICallWithContext(
680 rewriter, loc,
"Z3_solver_assert",
682 {buildSolverPtr(rewriter, loc), adaptor.getInput()});
684 rewriter.eraseOp(op);
// Lowers smt.reset to Z3_solver_reset(ctx, solver) and erases the op.
693 struct ResetOpLowering :
public SMTLoweringPattern<ResetOp> {
694 using SMTLoweringPattern::SMTLoweringPattern;
697 matchAndRewrite(ResetOp op, OpAdaptor adaptor,
698 ConversionPatternRewriter &rewriter)
const final {
699 Location loc = op.getLoc();
700 buildAPICallWithContext(rewriter, loc,
"Z3_solver_reset",
702 {buildSolverPtr(rewriter, loc)});
704 rewriter.eraseOp(op);
// Lowers smt.push by emitting one Z3_solver_push(ctx, solver) call per
// requested level (`count` calls in total), then erases the op.
713 struct PushOpLowering :
public SMTLoweringPattern<PushOp> {
714 using SMTLoweringPattern::SMTLoweringPattern;
716 matchAndRewrite(PushOp op, OpAdaptor adaptor,
717 ConversionPatternRewriter &rewriter)
const final {
718 Location loc = op.getLoc();
722 for (uint32_t i = 0; i < op.getCount(); i++)
723 buildAPICallWithContext(rewriter, loc,
"Z3_solver_push",
725 {buildSolverPtr(rewriter, loc)});
726 rewriter.eraseOp(op);
// Lowers smt.pop to a single Z3_solver_pop(ctx, solver, count) call. The
// number of levels is passed as an i32 constant, unlike the push lowering,
// which loops. The op is then erased.
735 struct PopOpLowering :
public SMTLoweringPattern<PopOp> {
736 using SMTLoweringPattern::SMTLoweringPattern;
738 matchAndRewrite(PopOp op, OpAdaptor adaptor,
739 ConversionPatternRewriter &rewriter)
const final {
740 Location loc = op.getLoc();
741 Value constVal = rewriter.create<LLVM::ConstantOp>(
742 loc, rewriter.getI32Type(), op.getCount());
743 buildAPICallWithContext(rewriter, loc,
"Z3_solver_pop",
745 {buildSolverPtr(rewriter, loc), constVal});
746 rewriter.eraseOp(op);
// Lowers smt.yield to the terminator its context requires: func.return inside
// a func::FuncOp, llvm.return inside an LLVMFuncOp, and scf.yield when the
// direct parent op belongs to the SCF dialect. Any remaining case is handled
// on lines not visible here.
// NOTE(review): getParentOfType searches all ancestors, not just the direct
// parent. A yield nested in an scf.if inside a func::FuncOp therefore takes
// the first branch. The check lowering inlines its regions into scf.if ops,
// and the solver body is outlined into a func::FuncOp, so confirm this
// ordering is intended.
756 struct YieldOpLowering :
public SMTLoweringPattern<YieldOp> {
757 using SMTLoweringPattern::SMTLoweringPattern;
760 matchAndRewrite(YieldOp op, OpAdaptor adaptor,
761 ConversionPatternRewriter &rewriter)
const final {
762 if (op->getParentOfType<func::FuncOp>()) {
763 rewriter.replaceOpWithNewOp<func::ReturnOp>(op, adaptor.getValues());
766 if (op->getParentOfType<LLVM::LLVMFuncOp>()) {
767 rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(op, adaptor.getValues());
770 if (isa<scf::SCFDialect>(op->getParentOp()->getDialect())) {
771 rewriter.replaceOpWithNewOp<scf::YieldOp>(op, adaptor.getValues());
// Lowers smt.check. Z3_solver_check returns an i32 result, and two nested
// scf.if ops dispatch on it:
//  * 1 selects the sat region;
//  * -1 selects the unsat region;
//  * any other value selects the unknown region.
// On visible lines the code also printf's the solver state, the proof (unsat
// branch) and the model (sat branch). The conditions guarding that printing
// are not visible here; presumably it is the debug option.
789 struct CheckOpLowering :
public SMTLoweringPattern<CheckOp> {
790 using SMTLoweringPattern::SMTLoweringPattern;
793 matchAndRewrite(CheckOp op, OpAdaptor adaptor,
794 ConversionPatternRewriter &rewriter)
const final {
795 Location loc = op.getLoc();
// Builds an 80-column printf format string: the title centered in a dashed
// rule, a "%s" line for the payload, and a closing dashed rule.
// NOTE(review): `80 - titleSize` is unsigned and wraps when the title is
// longer than 78 characters; only short literal titles are used below.
800 auto getHeaderString = [](
const std::string &title) {
801 unsigned titleSize = title.size() + 2;
802 return std::string((80 - titleSize) / 2,
'-') +
" " + title +
" " +
803 std::string((80 - titleSize + 1) / 2,
'-') +
"\n%s\n" +
804 std::string(80,
'-') +
"\n";
808 Value solver = buildSolverPtr(rewriter, loc);
// Print the solver state.
813 auto solverStringPtr =
814 buildPtrAPICall(rewriter, loc,
"Z3_solver_to_string", {solver});
815 auto solverFormatString =
816 buildString(rewriter, loc, getHeaderString(
"Solver"));
817 buildCall(rewriter, op.getLoc(),
"printf", printfType,
818 {solverFormatString, solverStringPtr});
// Result types of the scf.if ops are the converted op result types.
822 SmallVector<Type> resultTypes;
823 if (failed(typeConverter->convertTypes(op->getResultTypes(), resultTypes)))
// Run the check and branch to the sat region when the result equals 1.
828 buildAPICallWithContext(rewriter, loc,
"Z3_solver_check",
829 rewriter.getI32Type(), {solver})
832 rewriter.create<LLVM::ConstantOp>(loc, checkResult.getType(), 1);
833 Value isSat = rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::eq,
834 checkResult, constOne);
837 auto satIfOp = rewriter.create<scf::IfOp>(loc, resultTypes, isSat);
838 rewriter.inlineRegionBefore(op.getSatRegion(), satIfOp.getThenRegion(),
839 satIfOp.getThenRegion().end());
// Else branch: a nested scf.if separates unsat (-1) from unknown, and its
// results are yielded out of the outer if.
844 rewriter.createBlock(&satIfOp.getElseRegion());
846 rewriter.create<LLVM::ConstantOp>(loc, checkResult.getType(), -1);
847 Value isUnsat = rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::eq,
848 checkResult, constNegOne);
849 auto unsatIfOp = rewriter.create<scf::IfOp>(loc, resultTypes, isUnsat);
850 rewriter.create<scf::YieldOp>(loc, unsatIfOp->getResults());
852 rewriter.inlineRegionBefore(op.getUnsatRegion(), unsatIfOp.getThenRegion(),
853 unsatIfOp.getThenRegion().end());
854 rewriter.inlineRegionBefore(op.getUnknownRegion(),
855 unsatIfOp.getElseRegion(),
856 unsatIfOp.getElseRegion().end());
858 rewriter.replaceOp(op, satIfOp->getResults());
// Unsat branch: print the proof at the start of the inlined region.
863 rewriter.setInsertionPointToStart(unsatIfOp.thenBlock());
864 auto proof = buildPtrAPICall(rewriter, op.getLoc(),
"Z3_solver_get_proof",
867 buildPtrAPICall(rewriter, op.getLoc(),
"Z3_ast_to_string", {proof});
869 buildString(rewriter, op.getLoc(), getHeaderString(
"Proof"));
870 buildCall(rewriter, op.getLoc(),
"printf", printfType,
871 {formatString, stringPtr});
// Sat branch: print the model at the start of the inlined region.
875 rewriter.setInsertionPointToStart(satIfOp.thenBlock());
876 auto model = buildPtrAPICall(rewriter, op.getLoc(),
"Z3_solver_get_model",
878 auto modelStringPtr =
879 buildPtrAPICall(rewriter, op.getLoc(),
"Z3_model_to_string", {model});
880 auto modelFormatString =
881 buildString(rewriter, op.getLoc(), getHeaderString(
"Model"));
882 buildCall(rewriter, op.getLoc(),
"printf", printfType,
883 {modelFormatString, modelStringPtr});
// Lowers smt.forall / smt.exists to Z3_mk_forall_const / Z3_mk_exists_const:
//  * each bound block argument is replaced by a fresh smt.declare_fun constant,
//    named after the bound-variable name when one is provided;
//  * the body and any pattern regions are inlined before the op;
//  * the bound constants, patterns and body term are passed to Z3, with the
//    lists given as pointers to stack-allocated arrays.
// The no-pattern attribute is rejected as not yet supported.
911 template <
typename QuantifierOp>
912 struct QuantifierLowering :
public SMTLoweringPattern<QuantifierOp> {
913 using SMTLoweringPattern<QuantifierOp>::SMTLoweringPattern;
914 using SMTLoweringPattern<QuantifierOp>::typeConverter;
915 using SMTLoweringPattern<QuantifierOp>::buildPtrAPICall;
916 using OpAdaptor =
typename QuantifierOp::Adaptor;
// Stores `values` into a single-element alloca of an LLVM array type. It
// presumably returns that alloca; the array type and the return statement
// are on lines not visible here.
918 Value createStorageForValueList(ValueRange values, Location loc,
919 ConversionPatternRewriter &rewriter)
const {
923 rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI32Type(), 1);
925 rewriter.create<LLVM::AllocaOp>(loc, ptrTy, arrTy, constOne);
926 Value array = rewriter.create<LLVM::UndefOp>(loc, arrTy);
928 for (
auto [i, val] : llvm::enumerate(values))
929 array = rewriter.create<LLVM::InsertValueOp>(loc, array, val,
930 ArrayRef<int64_t>(i));
932 rewriter.create<LLVM::StoreOp>(loc, array, storage);
938 matchAndRewrite(QuantifierOp op, OpAdaptor adaptor,
939 ConversionPatternRewriter &rewriter)
const final {
940 Location loc = op.getLoc();
947 if (adaptor.getNoPattern())
948 return rewriter.notifyMatchFailure(
949 op,
"no-pattern attribute not yet supported!");
951 rewriter.setInsertionPoint(op);
// Quantifier weight and number of bound variables, as i32 constants.
954 Value weight = rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI32Type(),
955 adaptor.getWeight());
958 unsigned numDecls = op.getBody().getNumArguments();
960 rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI32Type(), numDecls);
// Replace each bound block argument with a fresh constant, materialized to
// the converted (LLVM) type.
967 SmallVector<Value> repl;
968 for (
auto [i, arg] : llvm::enumerate(op.getBody().getArguments())) {
970 if (adaptor.getBoundVarNames().has_value())
971 newArg = rewriter.create<smt::DeclareFunOp>(
973 cast<StringAttr>((*adaptor.getBoundVarNames())[i]));
975 newArg = rewriter.create<smt::DeclareFunOp>(loc, arg.getType());
976 repl.push_back(typeConverter->materializeTargetConversion(
977 rewriter, loc, typeConverter->convertType(arg.getType()), newArg));
980 Value boundStorage = createStorageForValueList(repl, loc, rewriter);
// The body's first yielded value is the quantified formula; convert it,
// drop the yield, and inline the body before the op.
983 auto yieldOp = cast<smt::YieldOp>(op.getBody().front().getTerminator());
984 Value bodyExp = yieldOp.getValues()[0];
985 rewriter.setInsertionPointAfterValue(bodyExp);
986 bodyExp = typeConverter->materializeTargetConversion(
987 rewriter, loc, typeConverter->convertType(bodyExp.getType()), bodyExp);
988 rewriter.eraseOp(yieldOp);
990 rewriter.inlineBlockBefore(&op.getBody().front(), op, repl);
991 rewriter.setInsertionPoint(op);
// Each pattern region is inlined with the same bound-variable replacements;
// its yielded terms form one Z3 pattern. Without patterns, a zero pointer is
// passed instead of a pattern array.
// NOTE(review): `patternTerms.size()` is read after `eraseOp(yieldOp)`. This
// relies on the conversion rewriter deferring erasure; `patternList.size()`
// gives the same count without touching the erased op.
994 unsigned numPatterns = adaptor.getPatterns().size();
995 Value numPatternsVal = rewriter.create<LLVM::ConstantOp>(
996 loc, rewriter.getI32Type(), numPatterns);
998 Value patternStorage;
999 if (numPatterns > 0) {
1001 for (Region *patternRegion : adaptor.getPatterns()) {
1003 cast<smt::YieldOp>(patternRegion->front().getTerminator());
1004 auto patternTerms = yieldOp.getOperands();
1006 rewriter.setInsertionPoint(yieldOp);
1007 SmallVector<Value> patternList;
1008 for (
auto val : patternTerms)
1009 patternList.push_back(typeConverter->materializeTargetConversion(
1010 rewriter, loc, typeConverter->convertType(val.getType()), val));
1012 rewriter.eraseOp(yieldOp);
1013 rewriter.inlineBlockBefore(&patternRegion->front(), op, repl);
1015 rewriter.setInsertionPoint(op);
1016 Value numTerms = rewriter.create<LLVM::ConstantOp>(
1017 loc, rewriter.getI32Type(), patternTerms.size());
1018 Value patternTermStorage =
1019 createStorageForValueList(patternList, loc, rewriter);
1020 Value
pattern = buildPtrAPICall(rewriter, loc,
"Z3_mk_pattern",
1021 {numTerms, patternTermStorage});
1025 patternStorage = createStorageForValueList(
patterns, loc, rewriter);
1029 patternStorage = rewriter.create<LLVM::ZeroOp>(loc, ptrTy);
// Select the Z3 entry point from the op type (a compile-time condition).
1032 StringRef apiCallName =
"Z3_mk_forall_const";
1033 if (std::is_same_v<QuantifierOp, ExistsOp>)
1034 apiCallName =
"Z3_mk_exists_const";
1035 Value quantifierExp =
1036 buildPtrAPICall(rewriter, loc, apiCallName,
1037 {weight, numDeclsVal, boundStorage, numPatternsVal,
1038 patternStorage, bodyExp});
1040 rewriter.replaceOp(op, quantifierExp);
// Lowers smt.bv.repeat to Z3_mk_repeat, with the repetition count as an i32
// constant and the converted input.
1049 struct RepeatOpLowering :
public SMTLoweringPattern<RepeatOp> {
1050 using SMTLoweringPattern::SMTLoweringPattern;
1053 matchAndRewrite(RepeatOp op, OpAdaptor adaptor,
1054 ConversionPatternRewriter &rewriter)
const final {
1055 Value count = rewriter.create<LLVM::ConstantOp>(
1056 op.getLoc(), rewriter.getI32Type(), op.getCount());
1057 rewriter.replaceOp(op,
1058 buildPtrAPICall(rewriter, op.getLoc(),
"Z3_mk_repeat",
1059 {count, adaptor.getInput()}));
// Lowers smt.bv.extract to Z3_mk_extract(high, low, input). `low` is the op's
// low bit, and `high` is lowBit + resultWidth - 1, the index of the last
// extracted bit. Both are passed as i32 constants.
1071 struct ExtractOpLowering :
public SMTLoweringPattern<ExtractOp> {
1072 using SMTLoweringPattern::SMTLoweringPattern;
1075 matchAndRewrite(ExtractOp op, OpAdaptor adaptor,
1076 ConversionPatternRewriter &rewriter)
const final {
1077 Location loc = op.getLoc();
1078 Value low = rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI32Type(),
1079 adaptor.getLowBit());
1080 Value high = rewriter.create<LLVM::ConstantOp>(
1081 loc, rewriter.getI32Type(),
1082 adaptor.getLowBit() + op.getType().getWidth() - 1);
1083 rewriter.replaceOp(op, buildPtrAPICall(rewriter, loc,
"Z3_mk_extract",
1084 {high, low, adaptor.getInput()}));
// Lowers smt.array.broadcast to Z3_mk_const_array, passing the sort of the
// result array's domain and the converted broadcast value.
1093 struct ArrayBroadcastOpLowering
1094 :
public SMTLoweringPattern<smt::ArrayBroadcastOp> {
1095 using SMTLoweringPattern::SMTLoweringPattern;
1098 matchAndRewrite(smt::ArrayBroadcastOp op, OpAdaptor adaptor,
1099 ConversionPatternRewriter &rewriter)
const final {
1100 auto domainSort = buildSort(
1101 rewriter, op.getLoc(),
1102 cast<smt::ArrayType>(op.getResult().getType()).getDomainType());
1104 rewriter.replaceOp(op, buildPtrAPICall(rewriter, op.getLoc(),
1105 "Z3_mk_const_array",
1106 {domainSort, adaptor.getValue()}));
// Lowers smt.constant (bool) to Z3_mk_true or Z3_mk_false, depending on the
// attribute value.
1117 struct BoolConstantOpLowering :
public SMTLoweringPattern<smt::BoolConstantOp> {
1118 using SMTLoweringPattern::SMTLoweringPattern;
1121 matchAndRewrite(smt::BoolConstantOp op, OpAdaptor adaptor,
1122 ConversionPatternRewriter &rewriter)
const final {
1124 op, buildPtrAPICall(rewriter, op.getLoc(),
1125 adaptor.getValue() ?
"Z3_mk_true" :
"Z3_mk_false"));
// Lowers smt.int.constant to Z3.
//  * If the APInt's bit width is at most 64, the value is sign-extended to an
//    i64 constant and passed to Z3_mk_int64.
//  * Otherwise the decimal magnitude is emitted as a string for
//    Z3_mk_numeral, and the result is negated with Z3_mk_unary_minus when
//    the original value is negative.
1141 struct IntConstantOpLowering :
public SMTLoweringPattern<smt::IntConstantOp> {
1142 using SMTLoweringPattern::SMTLoweringPattern;
1145 matchAndRewrite(smt::IntConstantOp op, OpAdaptor adaptor,
1146 ConversionPatternRewriter &rewriter)
const final {
1147 Location loc = op.getLoc();
1148 Value type = buildPtrAPICall(rewriter, loc,
"Z3_mk_int_sort");
1149 if (adaptor.getValue().getBitWidth() <= 64) {
1150 Value val = rewriter.create<LLVM::ConstantOp>(
1151 loc, rewriter.getI64Type(), adaptor.getValue().getSExtValue());
1153 op, buildPtrAPICall(rewriter, loc,
"Z3_mk_int64", {val, type}));
1157 std::string numeralStr;
1158 llvm::raw_string_ostream stream(numeralStr);
// Print the magnitude as *unsigned*. For the most negative value of the
// APInt's width, abs() returns the same bit pattern, which a signed print
// (what `operator<<` does) renders as negative. The unary minus below would
// then produce the wrong sign. Read as unsigned, that bit pattern is exactly
// the magnitude 2^(width-1); for every other value, the unsigned and signed
// renderings of abs() are identical.
1159 adaptor.getValue().abs().print(stream, /*isSigned=*/false);
1161 Value numeral = buildString(rewriter, loc, numeralStr);
1163 buildPtrAPICall(rewriter, loc,
"Z3_mk_numeral", {numeral, type});
1165 if (adaptor.getValue().isNegative())
1167 buildPtrAPICall(rewriter, loc,
"Z3_mk_unary_minus", intNumeral);
1169 rewriter.replaceOp(op, intNumeral);
// Lowers smt.int.cmp to the Z3 function named "Z3_mk_" + <predicate name>
// (e.g. "Z3_mk_lt" for IntPredicate::lt), applied to the converted lhs and
// rhs. The concatenated name is a temporary std::string; it only needs to
// live for the duration of the buildPtrAPICall full-expression.
1179 struct IntCmpOpLowering :
public SMTLoweringPattern<IntCmpOp> {
1180 using SMTLoweringPattern::SMTLoweringPattern;
1183 matchAndRewrite(IntCmpOp op, OpAdaptor adaptor,
1184 ConversionPatternRewriter &rewriter)
const final {
1187 buildPtrAPICall(rewriter, op.getLoc(),
1188 "Z3_mk_" + stringifyIntPredicate(op.getPred()).str(),
1189 {adaptor.getLhs(), adaptor.getRhs()}));
// Lowers smt.bv.cmp to a Z3 function whose name is built from a prefix and
// the stringified predicate. The prefix is on a line not visible here,
// presumably "Z3_mk_bv". The function is applied to the converted lhs and
// rhs.
1200 struct BVCmpOpLowering :
public SMTLoweringPattern<BVCmpOp> {
1201 using SMTLoweringPattern::SMTLoweringPattern;
1204 matchAndRewrite(BVCmpOp op, OpAdaptor adaptor,
1205 ConversionPatternRewriter &rewriter)
const final {
1207 op, buildPtrAPICall(rewriter, op.getLoc(),
1209 stringifyBVCmpPredicate(op.getPred()).str(),
1210 {adaptor.getLhs(), adaptor.getRhs()}));
// Expands smt.int.abs into SMT ops that other patterns lower:
// abs(x) = ite(x < 0, 0 - x, x). The zero constant is built from an i1-typed
// IntegerAttr, which is wide enough to represent 0.
1216 struct IntAbsOpLowering :
public SMTLoweringPattern<IntAbsOp> {
1217 using SMTLoweringPattern::SMTLoweringPattern;
1220 matchAndRewrite(IntAbsOp op, OpAdaptor adaptor,
1221 ConversionPatternRewriter &rewriter)
const final {
1222 Location loc = op.getLoc();
1223 Value zero = rewriter.create<IntConstantOp>(
1224 loc, rewriter.getIntegerAttr(rewriter.getI1Type(), 0));
1225 Value cmp = rewriter.create<IntCmpOp>(loc, IntPredicate::lt,
1226 adaptor.getInput(), zero);
1227 Value neg = rewriter.create<IntSubOp>(loc, zero, adaptor.getInput());
1228 rewriter.replaceOpWithNewOp<IteOp>(op, cmp, neg, adaptor.getInput());
// The pass class, derived from the tablegen-generated base
// (GEN_PASS_DEF_LOWERSMTTOZ3LLVM). runOnOperation is defined out of line.
1240 struct LowerSMTToZ3LLVMPass
1241 :
public circt::impl::LowerSMTToZ3LLVMBase<LowerSMTToZ3LLVMPass> {
1243 void runOnOperation()
override;
// Fragment of populateSMTToZ3LLVMTypeConverter. It registers one conversion
// per SMT type (bool, bit-vector, array, int, function, sort). The converted
// types are on lines not visible here; presumably all of them become the
// opaque LLVM pointer type, since every SMT value is handed to Z3 as a
// pointer.
1248 converter.addConversion([](smt::BoolType type) {
1251 converter.addConversion([](smt::BitVectorType type) {
1254 converter.addConversion([](smt::ArrayType type) {
1257 converter.addConversion([](smt::IntType type) {
1260 converter.addConversion([](smt::SMTFuncType type) {
1263 converter.addConversion([](smt::SortType type) {
// Fragment of populateSMTToZ3LLVMConversionPatterns; the first signature
// line and most registrations are not visible here.
// ADD_VARIADIC_PATTERN registers a VariadicSMTPattern and
// ADD_ONE_TO_ONE_PATTERN a OneToOneSMTPattern, each for the given op and Z3
// API name.
1269 RewritePatternSet &
patterns, TypeConverter &converter,
1271 #define ADD_VARIADIC_PATTERN(OP, APINAME, MIN_NUM_ARGS) \
1272 patterns.add<VariadicSMTPattern<OP>>( \
1273 converter, patterns.getContext(), \
1274 globals, options, APINAME, \
1277 #define ADD_ONE_TO_ONE_PATTERN(OP, APINAME, NUM_ARGS) \
1278 patterns.add<OneToOneSMTPattern<OP>>( \
1279 converter, patterns.getContext(), \
1280 globals, options, APINAME, NUM_ARGS);
// Left-fold XOr ops with more than two operands into binary ones.
1333 patterns.add<LowerLeftAssocSMTPattern<XOrOp>>(
1334 converter,
patterns.getContext(), globals, options);
// The helper macros are only needed for the registrations above.
1434 #undef ADD_VARIADIC_PATTERN
1435 #undef ADD_ONE_TO_ONE_PATTERN
// Chainable equality. Part of this registration is on lines not visible here.
1452 patterns.add<LowerChainableSMTPattern<EqOp>>(converter,
patterns.getContext(),
1455 globals, options,
"Z3_mk_eq", 2);
// Dedicated patterns for ops that need custom lowering.
1459 patterns.add<BVConstantOpLowering, DeclareFunOpLowering, AssertOpLowering,
1460 ResetOpLowering, PushOpLowering, PopOpLowering, CheckOpLowering,
1461 SolverOpLowering, ApplyFuncOpLowering, YieldOpLowering,
1462 RepeatOpLowering, ExtractOpLowering, BoolConstantOpLowering,
1463 IntConstantOpLowering, ArrayBroadcastOpLowering, BVCmpOpLowering,
1464 IntCmpOpLowering, IntAbsOpLowering, QuantifierLowering<ForallOp>,
1465 QuantifierLowering<ExistsOp>>(converter,
patterns.getContext(),
// Pass entry point. It first validates set-logic usage in every smt.solver.
// It then runs a full conversion to the LLVM dialect, combining the
// SMT-to-Z3 patterns with the upstream func/arith/SCF/cf lowering patterns.
1469 void LowerSMTToZ3LLVMPass::runOnOperation() {
1470 LowerSMTToZ3LLVMOptions options;
1471 options.debug =
debug;
// Z3 only supports setting the logic once, so reject solvers containing
// more than one set-logic op. Also require that set-logic is preceded only
// by constant-like ops. Returning an error diagnostic from the callback
// interrupts the walk, which is detected by wasInterrupted() below.
1475 auto setLogicCheck = getOperation().walk([&](SolverOp solverOp)
1479 auto setLogicOps = solverOp.getBodyRegion().getOps<smt::SetLogicOp>();
1480 auto numSetLogicOps = std::distance(setLogicOps.begin(), setLogicOps.end());
1481 if (numSetLogicOps > 1) {
1482 return solverOp.emitError(
1483 "multiple set-logic operations found in one solver operation - Z3 "
1484 "only supports setting the logic once");
1486 if (numSetLogicOps == 1)
1488 for (
auto &blockOp : solverOp.getBodyRegion().getOps()) {
1489 if (isa<smt::SetLogicOp>(blockOp))
1491 if (!blockOp.hasTrait<OpTrait::ConstantLike>()) {
1492 return solverOp.emitError(
"set-logic operation must be the first "
1493 "non-constant operation in a solver "
1497 return WalkResult::advance();
1499 if (setLogicCheck.wasInterrupted())
1500 return signalPassFailure();
// Type converter and pattern set. The SMT-specific population happens on
// lines not visible here.
1503 LLVMTypeConverter converter(&getContext());
1506 RewritePatternSet
patterns(&getContext());
1522 populateFuncToLLVMConversionPatterns(converter,
patterns);
1523 arith::populateArithToLLVMConversionPatterns(converter,
patterns);
1528 populateSCFToControlFlowConversionPatterns(
patterns);
1529 mlir::cf::populateControlFlowToLLVMConversionPatterns(converter,
patterns);
// Builder presumably used to create the SMT globals handler (its use is not
// visible here).
1533 OpBuilder builder(&getContext());
// Everything must become LLVM dialect, except the module itself and
// scf.yield, which are explicitly marked legal.
1539 LLVMConversionTarget target(getContext());
1540 target.addLegalOp<mlir::ModuleOp>();
1541 target.addLegalOp<scf::YieldOp>();
1543 if (failed(applyFullConversion(getOperation(), target, std::move(
patterns))))
1544 return signalPassFailure();
#define ADD_VARIADIC_PATTERN(OP, APINAME, MIN_NUM_ARGS)
#define ADD_ONE_TO_ONE_PATTERN(OP, APINAME, NUM_ARGS)
RewritePatternSet pattern
A namespace that is used to store existing names and generate new names in some scope within the IR.
void add(mlir::ModuleOp module)
StringRef newName(const Twine &name)
Return a unique name, derived from the input name, and add the new name to the internal namespace.
void addDefinitions(mlir::Operation *top)
Populate the symbol cache with all symbol-defining operations within the 'top' operation.
Default symbol cache implementation; stores associations between names (StringAttr's) to mlir::Operat...
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
void populateSMTToZ3LLVMTypeConverter(TypeConverter &converter)
Populate the given type converter with the SMT to LLVM type conversions.
void populateSMTToZ3LLVMConversionPatterns(RewritePatternSet &patterns, TypeConverter &converter, SMTGlobalsHandler &globals, const LowerSMTToZ3LLVMOptions &options)
Add the SMT to LLVM IR conversion patterns to 'patterns'.
A symbol cache for LLVM globals and functions relevant to SMT lowering patterns.
static SMTGlobalsHandler create(OpBuilder &builder, ModuleOp module)
Creates the LLVM global operations that store the pointers to the solver and the context, and returns a handler that keeps track of them.
SMTGlobalsHandler(ModuleOp module, mlir::LLVM::GlobalOp solver, mlir::LLVM::GlobalOp ctx)
Initializes the caches and keeps track of the given globals that store the pointers to the SMT solver and context.