11#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
12#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
13#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
14#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
15#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
16#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"
17#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
18#include "mlir/Dialect/Func/IR/FuncOps.h"
19#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
20#include "mlir/Dialect/LLVMIR/LLVMAttrs.h"
21#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
22#include "mlir/Dialect/SCF/IR/SCF.h"
23#include "mlir/Dialect/SMT/IR/SMTOps.h"
24#include "mlir/IR/BuiltinDialect.h"
25#include "mlir/Pass/Pass.h"
26#include "mlir/Transforms/DialectConversion.h"
27#include "llvm/ADT/TypeSwitch.h"
28#include "llvm/Support/Debug.h"
30#define DEBUG_TYPE "lower-smt-to-z3-llvm"
33#define GEN_PASS_DEF_LOWERSMTTOZ3LLVM
34#include "circt/Conversion/Passes.h.inc"
// Fragment (the enclosing function header is not visible in this excerpt):
// creates two pointer-typed LLVM globals, one named from the prefix "ctx" and
// one from "solver".
// The guard restores the builder's insertion point on scope exit; the globals
// are inserted at the start of the module body.
47 OpBuilder::InsertionGuard guard(builder);
48 builder.setInsertionPointToStart(module.getBody());
55 Location loc =
module.getLoc();
56 auto ptrTy = LLVM::LLVMPointerType::get(builder.getContext());
// Helper: create an internal-linkage global of pointer type with a fresh name
// derived from `namePrefix`. The `false` argument is presumably `isConstant`
// (TODO confirm against the LLVM::GlobalOp builder signature).
58 auto createGlobal = [&](StringRef namePrefix) {
59 auto global = builder.create<LLVM::GlobalOp>(
60 loc, ptrTy,
false, LLVM::Linkage::Internal,
names.
newName(namePrefix),
// Fill the global's initializer region with a block that returns a zero
// (null) pointer.
62 OpBuilder::InsertionGuard g(builder);
63 builder.createBlock(&global.getInitializer());
64 Value res = builder.create<LLVM::ZeroOp>(loc, ptrTy);
65 builder.create<LLVM::ReturnOp>(loc, res);
69 auto ctxGlobal = createGlobal(
"ctx");
70 auto solverGlobal = createGlobal(
"solver");
// Constructor fragment: stores the solver and context globals and the `names`
// object. The leading parameter(s) are not visible in this excerpt.
76 mlir::LLVM::GlobalOp solver,
77 mlir::LLVM::GlobalOp ctx)
78 : solver(solver), ctx(ctx), names(names) {}
// Second constructor fragment: stores only the solver and context globals in
// its initializer list; its body is not visible in this excerpt.
81 mlir::LLVM::GlobalOp solver,
82 mlir::LLVM::GlobalOp ctx)
83 : solver(solver), ctx(ctx) {
/// Base pattern templated on the op type `OpTy`; the lowering patterns below
/// derive from it and use its build* helpers.
95template <
typename OpTy>
// Constructor fragment: takes the type converter, the MLIR context and the
// pass options (one parameter line is not visible in this excerpt).
98 SMTLoweringPattern(
const TypeConverter &typeConverter, MLIRContext *context,
100 const LowerSMTToZ3LLVMOptions &options)
/// Returns a value holding the pointer loaded from `global`, for use in the
/// block the builder currently points into. `cache` maps a block to the
/// load previously created for it, so each block gets a single load.
105 Value buildGlobalPtrToGlobal(OpBuilder &builder, Location loc,
106 LLVM::GlobalOp global,
107 DenseMap<Block *, Value> &cache)
const {
108 Block *block = builder.getBlock();
// Reuse the load already created for this block, if any.
109 if (
auto iter = cache.find(block); iter != cache.end())
110 return iter->getSecond();
// Otherwise emit address-of + load at the start of the block, restoring the
// caller's insertion point afterwards, and remember the result.
112 OpBuilder::InsertionGuard g(builder);
113 builder.setInsertionPointToStart(block);
114 Value globalAddr = builder.create<LLVM::AddressOfOp>(loc, global);
115 return cache[block] = builder.create<LLVM::LoadOp>(
116 loc, LLVM::LLVMPointerType::get(builder.getContext()),
/// Returns the pointer loaded from the context global (`globals.ctx`), using
/// the per-block cache `globals.ctxCache`.
126 Value buildContextPtr(OpBuilder &builder, Location loc)
const {
127 return buildGlobalPtrToGlobal(builder, loc, globals.ctx, globals.ctxCache);
/// Returns the pointer loaded from the solver global (`globals.solver`),
/// using the per-block cache `globals.solverCache`.
135 Value buildSolverPtr(OpBuilder &builder, Location loc)
const {
136 return buildGlobalPtrToGlobal(builder, loc, globals.solver,
137 globals.solverCache);
/// Emits an `llvm.call` to the function called `name` with signature
/// `funcType`, passing `args`. The callee op is stored in `globals.funcMap`
/// under the name's string attribute.
143 LLVM::CallOp buildCall(OpBuilder &builder, Location loc, StringRef name,
144 LLVM::LLVMFunctionType funcType,
145 ValueRange args)
const {
// Reference into the cache, so the assignment below populates the map entry.
146 auto &funcOp = globals.funcMap[builder.getStringAttr(name)];
// NOTE(review): the lookup-or-create below presumably only runs when the
// cache entry is empty; the guarding condition is not visible in this excerpt.
148 OpBuilder::InsertionGuard guard(builder);
150 builder.getBlock()->getParent()->getParentOfType<ModuleOp>();
// Any newly created declaration goes at the end of the module body.
151 builder.setInsertionPointToEnd(module.getBody());
152 auto funcOpResult = LLVM::lookupOrCreateFn(
153 builder, module, name, funcType.getParams(), funcType.getReturnType(),
154 funcType.getVarArg());
// This helper resolves arbitrary callees, so the message must not name a
// specific function.
155 assert(succeeded(funcOpResult) &&
"expected to lookup or create the callee declaration");
156 funcOp = funcOpResult.value();
158 return builder.create<LLVM::CallOp>(loc, funcOp, args);
/// Returns the address of a module-level global holding `str` followed by a
/// terminating NUL character. Globals are stored in `globals.stringCache`
/// under the string's attribute.
165 Value buildString(OpBuilder &builder, Location loc, StringRef str)
const {
// Reference into the cache, so the assignment below populates the map entry.
166 auto &global = globals.stringCache[builder.getStringAttr(str)];
// NOTE(review): the creation below presumably only runs when the cache entry
// is empty; the guarding condition is not visible in this excerpt.
168 OpBuilder::InsertionGuard guard(builder);
170 builder.getBlock()->getParent()->getParentOfType<ModuleOp>();
171 builder.setInsertionPointToEnd(module.getBody());
// The i8 array has one extra element for the appended NUL character.
173 LLVM::LLVMArrayType::get(builder.getI8Type(), str.size() + 1);
174 auto strAttr = builder.getStringAttr(str.str() +
'\00');
// Internal-linkage global with a fresh "str"-prefixed name; the `true`
// argument is presumably `isConstant` (TODO confirm).
175 global = builder.create<LLVM::GlobalOp>(
176 loc, arrayTy,
true, LLVM::Linkage::Internal,
177 globals.names.newName(
"str"), strAttr);
179 return builder.create<LLVM::AddressOfOp>(loc, global);
/// Emits a call to the function `name` returning `returnType`, with the
/// loaded context pointer prepended to `args` as the first argument. The
/// callee's parameter types are taken from the final argument list.
184 LLVM::CallOp buildAPICallWithContext(OpBuilder &builder, Location loc,
185 StringRef name, Type returnType,
186 ValueRange args = {})
const {
187 auto ctx = buildContextPtr(builder, loc);
188 SmallVector<Value> arguments;
189 arguments.emplace_back(ctx);
// Append the range directly; no temporary vector is needed just to copy it.
190 arguments.append(args.begin(), args.end());
193 LLVM::LLVMFunctionType::get(
194 returnType, SmallVector<Type>(ValueRange(arguments).getTypes())),
/// Convenience wrapper around buildAPICallWithContext for API functions
/// whose return type is an LLVM pointer.
201 Value buildPtrAPICall(OpBuilder &builder, Location loc, StringRef name,
202 ValueRange args = {})
const {
203 return buildAPICallWithContext(
205 LLVM::LLVMPointerType::get(builder.getContext()), args)
/// Emits the Z3 API calls that construct the sort corresponding to the SMT
/// type `type` and returns the resulting value. Dispatches on the concrete
/// type; array sorts recurse on their domain and range types.
210 Value buildSort(OpBuilder &builder, Location loc, Type type)
const {
213 return TypeSwitch<Type, Value>(type)
// Integer sort: no arguments beyond the context.
214 .Case([&](smt::IntType ty) {
215 return buildPtrAPICall(builder, loc,
"Z3_mk_int_sort");
// Bit-vector sort: the width is passed as an i32 constant.
217 .Case([&](smt::BitVectorType ty) {
218 Value bitwidth = builder.create<LLVM::ConstantOp>(
219 loc, builder.getI32Type(), ty.getWidth());
220 return buildPtrAPICall(builder, loc,
"Z3_mk_bv_sort", {bitwidth});
// Boolean sort.
222 .Case([&](smt::BoolType ty) {
223 return buildPtrAPICall(builder, loc,
"Z3_mk_bool_sort");
// User-declared sort: a string symbol is created from the sort identifier
// and used for the uninterpreted sort.
225 .Case([&](smt::SortType ty) {
226 Value str = buildString(builder, loc, ty.getIdentifier());
228 buildPtrAPICall(builder, loc,
"Z3_mk_string_symbol", {str});
229 return buildPtrAPICall(builder, loc,
"Z3_mk_uninterpreted_sort",
// Array sort: built from the recursively constructed domain and range sorts.
232 .Case([&](smt::ArrayType ty) {
233 return buildPtrAPICall(builder, loc,
"Z3_mk_array_sort",
234 {buildSort(builder, loc, ty.getDomainType()),
235 buildSort(builder, loc, ty.getRangeType())});
240 const LowerSMTToZ3LLVMOptions &options;
/// Lowers DeclareFunOp: non-function types become a call to
/// Z3_mk_fresh_const, function types a call to Z3_mk_fresh_func_decl.
257struct DeclareFunOpLowering :
public SMTLoweringPattern<DeclareFunOp> {
258 using SMTLoweringPattern::SMTLoweringPattern;
261 matchAndRewrite(DeclareFunOp op, OpAdaptor adaptor,
262 ConversionPatternRewriter &rewriter)
const final {
263 Location loc = op.getLoc();
// The prefix argument is the name-prefix string when one is given, otherwise
// a zero (null) pointer.
267 if (adaptor.getNamePrefix())
268 prefix = buildString(rewriter, loc, *adaptor.getNamePrefix());
270 prefix = rewriter.create<LLVM::ZeroOp>(
271 loc, LLVM::LLVMPointerType::get(getContext()));
// Non-function type: declare a fresh constant of the corresponding sort.
274 if (!isa<SMTFuncType>(op.getType())) {
275 Value sort = buildSort(rewriter, loc, op.getType());
277 buildPtrAPICall(rewriter, loc,
"Z3_mk_fresh_const", {prefix, sort});
278 rewriter.replaceOp(op, constDecl);
// Function type: build the range sort and an array of domain sorts.
283 Type llvmPtrTy = LLVM::LLVMPointerType::get(getContext());
284 auto funcType = cast<SMTFuncType>(op.getResult().getType());
285 Value rangeSort = buildSort(rewriter, loc, funcType.getRangeType());
288 LLVM::LLVMArrayType::get(llvmPtrTy, funcType.getDomainTypes().size());
// Assemble the domain-sort array as an SSA aggregate, one insertvalue per
// domain type.
290 Value domain = rewriter.create<LLVM::UndefOp>(loc, arrTy);
291 for (
auto [i, ty] :
llvm::enumerate(funcType.getDomainTypes())) {
292 Value sort = buildSort(rewriter, loc, ty);
293 domain = rewriter.create<LLVM::InsertValueOp>(loc, domain, sort, i);
// Store the aggregate into an alloca so a pointer to it can be passed to the
// API call.
297 rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI32Type(), 1);
298 Value domainStorage =
299 rewriter.create<LLVM::AllocaOp>(loc, llvmPtrTy, arrTy, one);
300 rewriter.create<LLVM::StoreOp>(loc, domain, domainStorage);
302 Value domainSize = rewriter.create<LLVM::ConstantOp>(
303 loc, rewriter.getI32Type(), funcType.getDomainTypes().size());
305 buildPtrAPICall(rewriter, loc,
"Z3_mk_fresh_func_decl",
306 {prefix, domainSize, domainStorage, rangeSort});
308 rewriter.replaceOp(op, decl);
/// Lowers ApplyFuncOp to a call to Z3_mk_app, passing the function
/// declaration, the argument count and a pointer to an array of the arguments.
318struct ApplyFuncOpLowering :
public SMTLoweringPattern<ApplyFuncOp> {
319 using SMTLoweringPattern::SMTLoweringPattern;
322 matchAndRewrite(ApplyFuncOp op, OpAdaptor adaptor,
323 ConversionPatternRewriter &rewriter)
const final {
324 Location loc = op.getLoc();
325 Type llvmPtrTy = LLVM::LLVMPointerType::get(getContext());
326 Type arrTy = LLVM::LLVMArrayType::get(llvmPtrTy, adaptor.getArgs().size());
// Assemble the converted arguments into an array aggregate.
329 Value domain = rewriter.create<LLVM::UndefOp>(loc, arrTy);
330 for (
auto [i, arg] :
llvm::enumerate(adaptor.getArgs()))
331 domain = rewriter.create<LLVM::InsertValueOp>(loc, domain, arg, i);
// Store the aggregate into an alloca so a pointer to it can be passed.
335 rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI32Type(), 1);
336 Value domainStorage =
337 rewriter.create<LLVM::AllocaOp>(loc, llvmPtrTy, arrTy, one);
338 rewriter.create<LLVM::StoreOp>(loc, domain, domainStorage);
342 Value domainSize = rewriter.create<LLVM::ConstantOp>(
343 loc, rewriter.getI32Type(), adaptor.getArgs().size());
345 buildPtrAPICall(rewriter, loc,
"Z3_mk_app",
346 {adaptor.getFunc(), domainSize, domainStorage});
347 rewriter.replaceOp(op, returnVal);
/// Lowers smt::BVConstantOp. Two strategies are visible: passing the value
/// as an i64 constant to Z3_mk_unsigned_int64, and passing it as a string
/// to Z3_mk_numeral.
365struct BVConstantOpLowering :
public SMTLoweringPattern<smt::BVConstantOp> {
366 using SMTLoweringPattern::SMTLoweringPattern;
369 matchAndRewrite(smt::BVConstantOp op, OpAdaptor adaptor,
370 ConversionPatternRewriter &rewriter)
const final {
371 Location loc = op.getLoc();
372 unsigned width = op.getType().getWidth();
373 auto bvSort = buildSort(rewriter, loc, op.getResult().getType());
374 APInt val = adaptor.getValue().getValue();
// NOTE(review): this path is presumably guarded by a check that the value
// fits into 64 bits (getZExtValue requires it); the condition is not visible
// in this excerpt.
377 Value bvConst = rewriter.create<LLVM::ConstantOp>(
378 loc, rewriter.getI64Type(), val.getZExtValue());
379 Value res = buildPtrAPICall(rewriter, loc,
"Z3_mk_unsigned_int64",
381 rewriter.replaceOp(op, res);
// Otherwise the value is written to a string (the formatting statement is not
// visible in this excerpt) and handed to Z3_mk_numeral with the sort.
386 llvm::raw_string_ostream stream(str);
388 Value bvString = buildString(rewriter, loc, str);
390 buildPtrAPICall(rewriter, loc,
"Z3_mk_numeral", {bvString, bvSort});
392 rewriter.replaceOp(op, bvNumeral);
/// Lowers a variadic SMT op to a single API call that receives the operand
/// count and a pointer to an array holding the converted operands.
/// `apiFuncName` is the function to call; `minNumArgs` is the minimum
/// operand count checked before rewriting.
402template <
typename SourceTy>
403struct VariadicSMTPattern :
public SMTLoweringPattern<SourceTy> {
404 using OpAdaptor =
typename SMTLoweringPattern<SourceTy>::OpAdaptor;
406 VariadicSMTPattern(
const TypeConverter &typeConverter, MLIRContext *context,
408 const LowerSMTToZ3LLVMOptions &options,
409 StringRef apiFuncName,
unsigned minNumArgs)
410 : SMTLoweringPattern<SourceTy>(typeConverter, context, globals, options),
411 apiFuncName(apiFuncName), minNumArgs(minNumArgs) {}
414 matchAndRewrite(SourceTy op, OpAdaptor adaptor,
415 ConversionPatternRewriter &rewriter)
const final {
// Guard on the minimum operand count (the statement taken when the check
// fails is not visible in this excerpt).
416 if (adaptor.getOperands().size() < minNumArgs)
419 Location loc = op.getLoc();
420 Value numOperands = rewriter.create<LLVM::ConstantOp>(
421 loc, rewriter.getI32Type(), op->getNumOperands());
423 rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI32Type(), 1);
424 Type ptrTy = LLVM::LLVMPointerType::get(rewriter.getContext());
425 Type arrTy = LLVM::LLVMArrayType::get(ptrTy, op->getNumOperands());
// Stack storage for the operand array; the aggregate is built with
// insertvalue and then stored.
427 rewriter.create<LLVM::AllocaOp>(loc, ptrTy, arrTy, constOne);
428 Value array = rewriter.create<LLVM::UndefOp>(loc, arrTy);
430 for (
auto [i, operand] :
llvm::enumerate(adaptor.getOperands()))
431 array = rewriter.create<LLVM::InsertValueOp>(
432 loc, array, operand, ArrayRef<int64_t>{(int64_t)i});
434 rewriter.create<LLVM::StoreOp>(loc, array, storage);
436 rewriter.replaceOp(op,
437 SMTLoweringPattern<SourceTy>::buildPtrAPICall(
438 rewriter, loc, apiFuncName, {numOperands, storage}));
// Non-owning view of the API function name.
// NOTE(review): presumably callers pass string literals — verify the
// referenced storage outlives the pattern.
443 StringRef apiFuncName;
/// Lowers an SMT op to a single API call whose arguments are the converted
/// operands, in order. `apiFuncName` is the function to call and
/// `numOperands` the exact operand count the op must have.
449template <
typename SourceTy>
450struct OneToOneSMTPattern :
public SMTLoweringPattern<SourceTy> {
451 using OpAdaptor =
typename SMTLoweringPattern<SourceTy>::OpAdaptor;
453 OneToOneSMTPattern(
const TypeConverter &typeConverter, MLIRContext *context,
455 const LowerSMTToZ3LLVMOptions &options,
456 StringRef apiFuncName,
unsigned numOperands)
457 : SMTLoweringPattern<SourceTy>(typeConverter, context, globals, options),
458 apiFuncName(apiFuncName), numOperands(numOperands) {}
461 matchAndRewrite(SourceTy op, OpAdaptor adaptor,
462 ConversionPatternRewriter &rewriter)
const final {
// Guard on the exact operand count (the statement taken when the check fails
// is not visible in this excerpt).
463 if (adaptor.getOperands().size() != numOperands)
467 op, SMTLoweringPattern<SourceTy>::buildPtrAPICall(
468 rewriter, op.getLoc(), apiFuncName, adaptor.getOperands()));
// Non-owning view of the API function name.
// NOTE(review): presumably callers pass string literals — verify the
// referenced storage outlives the pattern.
473 StringRef apiFuncName;
// Exact number of operands the op must have for this pattern to apply.
474 unsigned numOperands;
/// Rewrites a chainable op with more than two operands into one binary
/// `SourceTy` op per pair of adjacent operands, combined with smt::AndOp.
479template <
typename SourceTy>
480class LowerChainableSMTPattern :
public SMTLoweringPattern<SourceTy> {
481 using SMTLoweringPattern<SourceTy>::SMTLoweringPattern;
482 using OpAdaptor =
typename SMTLoweringPattern<SourceTy>::OpAdaptor;
485 matchAndRewrite(SourceTy op, OpAdaptor adaptor,
486 ConversionPatternRewriter &rewriter)
const final {
// Ops with two or fewer operands are not handled here (the statement taken
// in that case is not visible in this excerpt).
487 if (adaptor.getOperands().size() <= 2)
490 Location loc = op.getLoc();
491 SmallVector<Value> elements;
// Create one binary op for each adjacent pair (i - 1, i).
492 for (
int i = 1, e = adaptor.getOperands().size(); i < e; ++i) {
493 Value val = rewriter.create<SourceTy>(
494 loc, op->getResultTypes(),
495 ValueRange{adaptor.getOperands()[i - 1], adaptor.getOperands()[i]});
496 elements.push_back(val);
498 rewriter.replaceOpWithNewOp<smt::AndOp>(op, elements);
/// Rewrites a left-associative op with more than two operands into a chain
/// of binary `SourceTy` ops: ((a op b) op c) op ...
505template <
typename SourceTy>
506class LowerLeftAssocSMTPattern :
public SMTLoweringPattern<SourceTy> {
507 using SMTLoweringPattern<SourceTy>::SMTLoweringPattern;
508 using OpAdaptor =
typename SMTLoweringPattern<SourceTy>::OpAdaptor;
511 matchAndRewrite(SourceTy op, OpAdaptor adaptor,
512 ConversionPatternRewriter &rewriter)
const final {
// Ops with two or fewer operands are rejected here, so the diagnostic must
// ask for more than two.
513 if (adaptor.getOperands().size() <= 2)
514 return rewriter.notifyMatchFailure(op,
"must have more than two operands");
// Fold the remaining operands into the running result from left to right.
516 Value runner = adaptor.getOperands()[0];
517 for (Value val : adaptor.getOperands().drop_front())
518 runner = rewriter.create<SourceTy>(op.
getLoc(), op->getResultTypes(),
519 ValueRange{runner, val});
521 rewriter.replaceOp(op, runner);
/// Lowers SolverOp: emits the Z3 config/context/solver setup calls, stores
/// the context and solver pointers into their globals, moves the solver body
/// into a new func.func that is then called, and emits the teardown calls.
556struct SolverOpLowering :
public SMTLoweringPattern<SolverOp> {
557 using SMTLoweringPattern::SMTLoweringPattern;
560 matchAndRewrite(SolverOp op, OpAdaptor adaptor,
561 ConversionPatternRewriter &rewriter)
const final {
562 Location loc = op.getLoc();
// Function types of the API signatures used below.
563 auto ptrTy = LLVM::LLVMPointerType::get(getContext());
564 auto voidTy = LLVM::LLVMVoidType::get(getContext());
565 auto ptrToPtrFunc = LLVM::LLVMFunctionType::get(ptrTy, ptrTy);
566 auto ptrPtrToPtrFunc = LLVM::LLVMFunctionType::get(ptrTy, {ptrTy, ptrTy});
567 auto ptrToVoidFunc = LLVM::LLVMFunctionType::get(voidTy, ptrTy);
568 auto ptrPtrToVoidFunc = LLVM::LLVMFunctionType::get(voidTy, {ptrTy, ptrTy});
// Create the Z3 configuration object.
571 Value config = buildCall(rewriter, loc,
"Z3_mk_config",
572 LLVM::LLVMFunctionType::get(ptrTy, {}), {})
// Set the "proof" parameter to "true" on the config.
// NOTE(review): presumably only done when a pass option requests it; the
// guarding condition is not visible in this excerpt.
578 Value paramKey = buildString(rewriter, loc,
"proof");
579 Value paramValue = buildString(rewriter, loc,
"true");
580 buildCall(rewriter, loc,
"Z3_set_param_value",
581 LLVM::LLVMFunctionType::get(voidTy, {ptrTy, ptrTy, ptrTy}),
582 {config, paramKey, paramValue});
// Take the logic from the first smt::SetLogicOp in the body, if any, and
// erase that op.
586 std::optional<StringRef> logic = std::nullopt;
587 auto setLogicOps = op.getBodyRegion().getOps<smt::SetLogicOp>();
588 if (!setLogicOps.empty()) {
591 auto setLogicOp = *setLogicOps.begin();
592 logic = setLogicOp.getLogic();
593 rewriter.eraseOp(setLogicOp);
// Create the context from the config and store it into the context global.
597 Value ctx = buildCall(rewriter, loc,
"Z3_mk_context", ptrToPtrFunc, config)
600 rewriter.create<LLVM::AddressOfOp>(loc, globals.ctx).getResult();
601 rewriter.create<LLVM::StoreOp>(loc, ctx, ctxAddr);
// The config is deleted once the context has been created from it.
604 buildCall(rewriter, loc,
"Z3_del_config", ptrToVoidFunc, {config});
// Create the solver: for the requested logic, or a default solver otherwise
// (the branch condition is not visible in this excerpt).
610 auto logicStr = buildString(rewriter, loc, logic.value());
611 solver = buildCall(rewriter, loc,
"Z3_mk_solver_for_logic",
612 ptrPtrToPtrFunc, {ctx, logicStr})
615 solver = buildCall(rewriter, loc,
"Z3_mk_solver", ptrToPtrFunc, ctx)
// Increment the solver's reference count and store it into the solver global.
618 buildCall(rewriter, loc,
"Z3_solver_inc_ref", ptrPtrToVoidFunc,
621 rewriter.create<LLVM::AddressOfOp>(loc, globals.solver).getResult();
622 rewriter.create<LLVM::StoreOp>(loc, solver, solverAddr);
// Convert the solver's result types for the outlined function's signature.
630 SmallVector<Type> convertedTypes;
632 typeConverter->convertTypes(op->getResultTypes(), convertedTypes)))
// Create a func.func with a fresh "solver"-prefixed name at the end of the
// module and move the solver body region into it.
637 OpBuilder::InsertionGuard guard(rewriter);
638 auto module = op->getParentOfType<ModuleOp>();
639 rewriter.setInsertionPointToEnd(module.getBody());
641 funcOp = rewriter.create<func::FuncOp>(
642 loc, globals.names.newName(
"solver"),
643 rewriter.getFunctionType(adaptor.getInputs().getTypes(),
645 rewriter.inlineRegionBefore(op.getBodyRegion(), funcOp.getBody(),
// Call the outlined function with the converted solver inputs.
650 rewriter.create<func::CallOp>(loc, funcOp, adaptor.getInputs())
// Teardown: decrement the solver's reference count and delete the context.
661 buildCall(rewriter, loc,
"Z3_solver_dec_ref", ptrPtrToVoidFunc,
663 buildCall(rewriter, loc,
"Z3_del_context", ptrToVoidFunc, ctx);
665 rewriter.replaceOp(op, results);
/// Lowers AssertOp to a void call to Z3_solver_assert with the solver pointer
/// and the converted input expression, then erases the op.
674struct AssertOpLowering :
public SMTLoweringPattern<AssertOp> {
675 using SMTLoweringPattern::SMTLoweringPattern;
678 matchAndRewrite(AssertOp op, OpAdaptor adaptor,
679 ConversionPatternRewriter &rewriter)
const final {
680 Location loc = op.getLoc();
// The context pointer is prepended by buildAPICallWithContext.
681 buildAPICallWithContext(
682 rewriter, loc,
"Z3_solver_assert",
683 LLVM::LLVMVoidType::get(getContext()),
684 {buildSolverPtr(rewriter, loc), adaptor.getInput()});
686 rewriter.eraseOp(op);
/// Lowers ResetOp to a void call to Z3_solver_reset with the solver pointer,
/// then erases the op.
695struct ResetOpLowering :
public SMTLoweringPattern<ResetOp> {
696 using SMTLoweringPattern::SMTLoweringPattern;
699 matchAndRewrite(ResetOp op, OpAdaptor adaptor,
700 ConversionPatternRewriter &rewriter)
const final {
701 Location loc = op.getLoc();
// The context pointer is prepended by buildAPICallWithContext.
702 buildAPICallWithContext(rewriter, loc,
"Z3_solver_reset",
703 LLVM::LLVMVoidType::get(getContext()),
704 {buildSolverPtr(rewriter, loc)});
706 rewriter.eraseOp(op);
/// Lowers PushOp to `op.getCount()` void calls to Z3_solver_push, then
/// erases the op.
715struct PushOpLowering :
public SMTLoweringPattern<PushOp> {
716 using SMTLoweringPattern::SMTLoweringPattern;
718 matchAndRewrite(PushOp op, OpAdaptor adaptor,
719 ConversionPatternRewriter &rewriter)
const final {
720 Location loc = op.getLoc();
// One call is emitted per iteration; no count argument is passed to the call.
724 for (uint32_t i = 0; i < op.getCount(); i++)
725 buildAPICallWithContext(rewriter, loc,
"Z3_solver_push",
726 LLVM::LLVMVoidType::get(getContext()),
727 {buildSolverPtr(rewriter, loc)});
728 rewriter.eraseOp(op);
/// Lowers PopOp to a single void call to Z3_solver_pop with the solver
/// pointer and the pop count as an i32 constant, then erases the op.
737struct PopOpLowering :
public SMTLoweringPattern<PopOp> {
738 using SMTLoweringPattern::SMTLoweringPattern;
740 matchAndRewrite(PopOp op, OpAdaptor adaptor,
741 ConversionPatternRewriter &rewriter)
const final {
742 Location loc = op.getLoc();
// Unlike the push lowering, the count is passed as an argument to one call.
743 Value constVal = rewriter.create<LLVM::ConstantOp>(
744 loc, rewriter.getI32Type(), op.getCount());
745 buildAPICallWithContext(rewriter, loc,
"Z3_solver_pop",
746 LLVM::LLVMVoidType::get(getContext()),
747 {buildSolverPtr(rewriter, loc), constVal});
748 rewriter.eraseOp(op);
/// Lowers YieldOp to the terminator matching its surroundings: func.return,
/// llvm.return or scf.yield, each carrying the converted yielded values.
758struct YieldOpLowering :
public SMTLoweringPattern<YieldOp> {
759 using SMTLoweringPattern::SMTLoweringPattern;
762 matchAndRewrite(YieldOp op, OpAdaptor adaptor,
763 ConversionPatternRewriter &rewriter)
const final {
// Has a func.func ancestor: becomes func.return.
764 if (op->getParentOfType<func::FuncOp>()) {
765 rewriter.replaceOpWithNewOp<func::ReturnOp>(op, adaptor.getValues());
// Has an llvm.func ancestor: becomes llvm.return.
768 if (op->getParentOfType<LLVM::LLVMFuncOp>()) {
769 rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(op, adaptor.getValues());
// Direct parent op belongs to the SCF dialect: becomes scf.yield.
772 if (isa<scf::SCFDialect>(op->getParentOp()->getDialect())) {
773 rewriter.replaceOpWithNewOp<scf::YieldOp>(op, adaptor.getValues());
/// Lowers CheckOp to a call to Z3_solver_check followed by nested scf.if ops
/// that select the sat, unsat or unknown region based on the returned value.
/// Also contains code that prints the solver state, the proof and the model
/// via printf.
791struct CheckOpLowering :
public SMTLoweringPattern<CheckOp> {
792 using SMTLoweringPattern::SMTLoweringPattern;
795 matchAndRewrite(CheckOp op, OpAdaptor adaptor,
796 ConversionPatternRewriter &rewriter)
const final {
797 Location loc = op.getLoc();
798 auto ptrTy = LLVM::LLVMPointerType::get(rewriter.getContext());
// Variadic function type (void return, one pointer parameter) used for the
// printf calls below.
799 auto printfType = LLVM::LLVMFunctionType::get(
800 LLVM::LLVMVoidType::get(rewriter.getContext()), {ptrTy},
true);
// Builds a printf format string: a dash-padded title line sized for 80
// columns, a "%s" line, and a closing line of 80 dashes.
802 auto getHeaderString = [](
const std::string &title) {
803 unsigned titleSize = title.size() + 2;
804 return std::string((80 - titleSize) / 2,
'-') +
" " + title +
" " +
805 std::string((80 - titleSize + 1) / 2,
'-') +
"\n%s\n" +
806 std::string(80,
'-') +
"\n";
810 Value solver = buildSolverPtr(rewriter, loc);
// Print the solver state as a string.
// NOTE(review): presumably guarded by a debug option; the condition is not
// visible in this excerpt.
815 auto solverStringPtr =
816 buildPtrAPICall(rewriter, loc,
"Z3_solver_to_string", {solver});
817 auto solverFormatString =
818 buildString(rewriter, loc, getHeaderString(
"Solver"));
819 buildCall(rewriter, op.getLoc(),
"printf", printfType,
820 {solverFormatString, solverStringPtr});
// Convert the check op's result types for the scf.if ops below.
824 SmallVector<Type> resultTypes;
825 if (failed(typeConverter->convertTypes(op->getResultTypes(), resultTypes)))
// Run the check; the result is an i32.
830 buildAPICallWithContext(rewriter, loc,
"Z3_solver_check",
831 rewriter.getI32Type(), {solver})
// A result equal to 1 selects the sat region.
834 rewriter.create<LLVM::ConstantOp>(loc, checkResult.getType(), 1);
835 Value isSat = rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::eq,
836 checkResult, constOne);
839 auto satIfOp = rewriter.create<scf::IfOp>(loc, resultTypes, isSat);
840 rewriter.inlineRegionBefore(op.getSatRegion(), satIfOp.getThenRegion(),
841 satIfOp.getThenRegion().end());
// In the else branch, a result equal to -1 selects the unsat region;
// anything else falls through to the unknown region.
846 rewriter.createBlock(&satIfOp.getElseRegion());
848 rewriter.create<LLVM::ConstantOp>(loc, checkResult.getType(), -1);
849 Value isUnsat = rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::eq,
850 checkResult, constNegOne);
851 auto unsatIfOp = rewriter.create<scf::IfOp>(loc, resultTypes, isUnsat);
852 rewriter.create<scf::YieldOp>(loc, unsatIfOp->getResults());
854 rewriter.inlineRegionBefore(op.getUnsatRegion(), unsatIfOp.getThenRegion(),
855 unsatIfOp.getThenRegion().end());
856 rewriter.inlineRegionBefore(op.getUnknownRegion(),
857 unsatIfOp.getElseRegion(),
858 unsatIfOp.getElseRegion().end());
860 rewriter.replaceOp(op, satIfOp->getResults());
// Print the proof at the start of the unsat branch.
// NOTE(review): presumably guarded by a debug option; not visible here.
865 rewriter.setInsertionPointToStart(unsatIfOp.thenBlock());
866 auto proof = buildPtrAPICall(rewriter, op.getLoc(),
"Z3_solver_get_proof",
869 buildPtrAPICall(rewriter, op.getLoc(),
"Z3_ast_to_string", {proof});
871 buildString(rewriter, op.getLoc(), getHeaderString(
"Proof"));
872 buildCall(rewriter, op.getLoc(),
"printf", printfType,
873 {formatString, stringPtr});
// Print the model at the start of the sat branch.
// NOTE(review): presumably guarded by a debug option; not visible here.
877 rewriter.setInsertionPointToStart(satIfOp.thenBlock());
878 auto model = buildPtrAPICall(rewriter, op.getLoc(),
"Z3_solver_get_model",
880 auto modelStringPtr =
881 buildPtrAPICall(rewriter, op.getLoc(),
"Z3_model_to_string", {model});
882 auto modelFormatString =
883 buildString(rewriter, op.getLoc(), getHeaderString(
"Model"));
884 buildCall(rewriter, op.getLoc(),
"printf", printfType,
885 {modelFormatString, modelStringPtr});
/// Lowers a quantifier op (`QuantifierOp`) to Z3_mk_forall_const or
/// Z3_mk_exists_const. Bound variables become freshly declared constants,
/// the body and pattern regions are inlined before the op, and bound
/// variables and patterns are passed as pointers to stack arrays.
913template <
typename QuantifierOp>
914struct QuantifierLowering :
public SMTLoweringPattern<QuantifierOp> {
915 using SMTLoweringPattern<QuantifierOp>::SMTLoweringPattern;
916 using SMTLoweringPattern<QuantifierOp>::typeConverter;
917 using SMTLoweringPattern<QuantifierOp>::buildPtrAPICall;
918 using OpAdaptor =
typename QuantifierOp::Adaptor;
/// Stores `values` into a stack-allocated array of pointers (alloca, then
/// insertvalue per element, then one store).
920 Value createStorageForValueList(ValueRange values, Location loc,
921 ConversionPatternRewriter &rewriter)
const {
922 Type ptrTy = LLVM::LLVMPointerType::get(rewriter.getContext());
923 Type arrTy = LLVM::LLVMArrayType::get(ptrTy, values.size());
925 rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI32Type(), 1);
927 rewriter.create<LLVM::AllocaOp>(loc, ptrTy, arrTy, constOne);
928 Value array = rewriter.create<LLVM::UndefOp>(loc, arrTy);
930 for (
auto [i, val] :
llvm::enumerate(values))
931 array = rewriter.create<LLVM::InsertValueOp>(loc, array, val,
932 ArrayRef<int64_t>(i));
934 rewriter.create<LLVM::StoreOp>(loc, array, storage);
940 matchAndRewrite(QuantifierOp op, OpAdaptor adaptor,
941 ConversionPatternRewriter &rewriter)
const final {
942 Location loc = op.getLoc();
943 Type ptrTy = LLVM::LLVMPointerType::get(rewriter.getContext());
// Quantifiers carrying the no-pattern attribute are rejected.
949 if (adaptor.getNoPattern())
950 return rewriter.notifyMatchFailure(
951 op,
"no-pattern attribute not yet supported!");
953 rewriter.setInsertionPoint(op);
// Weight and number of bound variables, both as i32 constants.
956 Value weight = rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI32Type(),
957 adaptor.getWeight());
960 unsigned numDecls = op.getBody().getNumArguments();
962 rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI32Type(), numDecls);
// Replace each body block argument with a smt::DeclareFunOp result (named
// after the bound variable when names are given), materialized to the
// converted type.
969 SmallVector<Value> repl;
970 for (
auto [i, arg] :
llvm::enumerate(op.getBody().getArguments())) {
972 if (adaptor.getBoundVarNames().has_value())
973 newArg = rewriter.create<smt::DeclareFunOp>(
975 cast<StringAttr>((*adaptor.getBoundVarNames())[i]));
977 newArg = rewriter.create<smt::DeclareFunOp>(loc, arg.getType());
978 repl.push_back(typeConverter->materializeTargetConversion(
979 rewriter, loc, typeConverter->convertType(arg.getType()), newArg));
982 Value boundStorage = createStorageForValueList(repl, loc, rewriter);
// Take the first value yielded by the body as the quantified expression,
// materialize it to the converted type, drop the yield and inline the body
// before the op.
985 auto yieldOp = cast<smt::YieldOp>(op.getBody().front().getTerminator());
986 Value bodyExp = yieldOp.getValues()[0];
987 rewriter.setInsertionPointAfterValue(bodyExp);
988 bodyExp = typeConverter->materializeTargetConversion(
989 rewriter, loc, typeConverter->convertType(bodyExp.getType()), bodyExp);
990 rewriter.eraseOp(yieldOp);
992 rewriter.inlineBlockBefore(&op.getBody().front(), op, repl);
993 rewriter.setInsertionPoint(op);
// Patterns: one Z3_mk_pattern call per pattern region.
996 unsigned numPatterns = adaptor.getPatterns().size();
997 Value numPatternsVal = rewriter.create<LLVM::ConstantOp>(
998 loc, rewriter.getI32Type(), numPatterns);
1000 Value patternStorage;
1001 if (numPatterns > 0) {
1003 for (Region *patternRegion : adaptor.getPatterns()) {
1005 cast<smt::YieldOp>(patternRegion->front().getTerminator());
1006 auto patternTerms = yieldOp.getOperands();
// Materialize every yielded pattern term to its converted type, right before
// the yield.
1008 rewriter.setInsertionPoint(yieldOp);
1009 SmallVector<Value> patternList;
1010 for (
auto val : patternTerms)
1011 patternList.push_back(typeConverter->materializeTargetConversion(
1012 rewriter, loc, typeConverter->
convertType(val.getType()), val));
// Drop the yield and inline the pattern block before the op, substituting the
// same bound-variable replacements as for the body.
1014 rewriter.eraseOp(yieldOp);
1015 rewriter.inlineBlockBefore(&patternRegion->front(), op, repl);
1017 rewriter.setInsertionPoint(op);
1018 Value numTerms = rewriter.create<LLVM::ConstantOp>(
1019 loc, rewriter.getI32Type(), patternTerms.size());
1020 Value patternTermStorage =
1021 createStorageForValueList(patternList, loc, rewriter);
1022 Value
pattern = buildPtrAPICall(rewriter, loc,
"Z3_mk_pattern",
1023 {numTerms, patternTermStorage});
// All created patterns are stored in one array.
1027 patternStorage = createStorageForValueList(
patterns, loc, rewriter);
// Without patterns, a zero (null) pointer is passed instead.
1031 patternStorage = rewriter.create<LLVM::ZeroOp>(loc, ptrTy);
// Pick the API function from the quantifier kind.
1034 StringRef apiCallName =
"Z3_mk_forall_const";
1035 if (std::is_same_v<QuantifierOp, ExistsOp>)
1036 apiCallName =
"Z3_mk_exists_const";
1037 Value quantifierExp =
1038 buildPtrAPICall(rewriter, loc, apiCallName,
1039 {weight, numDeclsVal, boundStorage, numPatternsVal,
1040 patternStorage, bodyExp});
1042 rewriter.replaceOp(op, quantifierExp);
/// Lowers RepeatOp to a call to Z3_mk_repeat with the repetition count (i32
/// constant) followed by the converted input.
1051struct RepeatOpLowering :
public SMTLoweringPattern<RepeatOp> {
1052 using SMTLoweringPattern::SMTLoweringPattern;
1055 matchAndRewrite(RepeatOp op, OpAdaptor adaptor,
1056 ConversionPatternRewriter &rewriter)
const final {
1057 Value count = rewriter.create<LLVM::ConstantOp>(
1058 op.getLoc(), rewriter.getI32Type(), op.getCount());
1059 rewriter.replaceOp(op,
1060 buildPtrAPICall(rewriter, op.getLoc(),
"Z3_mk_repeat",
1061 {count, adaptor.getInput()}));
/// Lowers ExtractOp to a call to Z3_mk_extract with the high bit index, the
/// low bit index and the converted input, in that order.
1073struct ExtractOpLowering :
public SMTLoweringPattern<ExtractOp> {
1074 using SMTLoweringPattern::SMTLoweringPattern;
1077 matchAndRewrite(ExtractOp op, OpAdaptor adaptor,
1078 ConversionPatternRewriter &rewriter)
const final {
1079 Location loc = op.getLoc();
1080 Value low = rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI32Type(),
1081 adaptor.getLowBit());
// The high index is computed from the low bit and the result width.
1082 Value high = rewriter.create<LLVM::ConstantOp>(
1083 loc, rewriter.getI32Type(),
1084 adaptor.getLowBit() + op.getType().getWidth() - 1);
1085 rewriter.replaceOp(op, buildPtrAPICall(rewriter, loc,
"Z3_mk_extract",
1086 {high, low, adaptor.getInput()}));
/// Lowers smt::ArrayBroadcastOp to a call to Z3_mk_const_array with the sort
/// of the result array's domain type and the converted value.
1095struct ArrayBroadcastOpLowering
1096 :
public SMTLoweringPattern<smt::ArrayBroadcastOp> {
1097 using SMTLoweringPattern::SMTLoweringPattern;
1100 matchAndRewrite(smt::ArrayBroadcastOp op, OpAdaptor adaptor,
1101 ConversionPatternRewriter &rewriter)
const final {
// Only the domain sort is passed; no range sort is built here.
1102 auto domainSort = buildSort(
1103 rewriter, op.getLoc(),
1104 cast<smt::ArrayType>(op.getResult().getType()).getDomainType());
1106 rewriter.replaceOp(op, buildPtrAPICall(rewriter, op.getLoc(),
1107 "Z3_mk_const_array",
1108 {domainSort, adaptor.getValue()}));
/// Lowers smt::BoolConstantOp to a call to Z3_mk_true or Z3_mk_false,
/// depending on the constant's value.
1119struct BoolConstantOpLowering :
public SMTLoweringPattern<smt::BoolConstantOp> {
1120 using SMTLoweringPattern::SMTLoweringPattern;
1123 matchAndRewrite(smt::BoolConstantOp op, OpAdaptor adaptor,
1124 ConversionPatternRewriter &rewriter)
const final {
// The API function name is selected at lowering time; no argument other
// than the context is passed.
1126 op, buildPtrAPICall(rewriter, op.getLoc(),
1127 adaptor.getValue() ?
"Z3_mk_true" :
"Z3_mk_false"));
/// Lowers smt::IntConstantOp. Values whose bit width is at most 64 are passed
/// as an i64 to Z3_mk_int64; wider values are passed as a string of their
/// absolute value to Z3_mk_numeral and negated with Z3_mk_unary_minus when
/// the original value is negative.
1143struct IntConstantOpLowering :
public SMTLoweringPattern<smt::IntConstantOp> {
1144 using SMTLoweringPattern::SMTLoweringPattern;
1147 matchAndRewrite(smt::IntConstantOp op, OpAdaptor adaptor,
1148 ConversionPatternRewriter &rewriter)
const final {
1149 Location loc = op.getLoc();
1150 Value type = buildPtrAPICall(rewriter, loc,
"Z3_mk_int_sort");
// Narrow path: the sign-extended value fits into an i64 constant.
1151 if (adaptor.getValue().getBitWidth() <= 64) {
1152 Value val = rewriter.create<LLVM::ConstantOp>(
1153 loc, rewriter.getI64Type(), adaptor.getValue().getSExtValue());
1155 op, buildPtrAPICall(rewriter, loc,
"Z3_mk_int64", {val, type}));
// Wide path: only the absolute value is written to the string; the sign is
// applied separately below.
1159 std::string numeralStr;
1160 llvm::raw_string_ostream stream(numeralStr);
1161 stream << adaptor.getValue().abs();
1163 Value numeral = buildString(rewriter, loc, numeralStr);
1165 buildPtrAPICall(rewriter, loc,
"Z3_mk_numeral", {numeral, type});
1167 if (adaptor.getValue().isNegative())
1169 buildPtrAPICall(rewriter, loc,
"Z3_mk_unary_minus", intNumeral);
1171 rewriter.replaceOp(op, intNumeral);
/// Lowers IntCmpOp to a call whose function name is "Z3_mk_" followed by the
/// stringified predicate, with the converted lhs and rhs as arguments.
1181struct IntCmpOpLowering :
public SMTLoweringPattern<IntCmpOp> {
1182 using SMTLoweringPattern::SMTLoweringPattern;
1185 matchAndRewrite(IntCmpOp op, OpAdaptor adaptor,
1186 ConversionPatternRewriter &rewriter)
const final {
1189 buildPtrAPICall(rewriter, op.getLoc(),
1190 "Z3_mk_" + stringifyIntPredicate(op.getPred()).str(),
1191 {adaptor.getLhs(), adaptor.getRhs()}));
/// Lowers Int2BVOp to a call to Z3_mk_int2bv with the result bit width (i32
/// constant) followed by the converted input.
1200struct Int2BVOpLowering :
public SMTLoweringPattern<Int2BVOp> {
1201 using SMTLoweringPattern::SMTLoweringPattern;
1204 matchAndRewrite(Int2BVOp op, OpAdaptor adaptor,
1205 ConversionPatternRewriter &rewriter)
const final {
// The width is taken from the op's result type.
1207 rewriter.create<LLVM::ConstantOp>(op->getLoc(), rewriter.getI32Type(),
1208 op.getResult().getType().getWidth());
1209 rewriter.replaceOp(op,
1210 buildPtrAPICall(rewriter, op.getLoc(),
"Z3_mk_int2bv",
1211 {widthConst, adaptor.getInput()}));
/// Lowers BV2IntOp to a call to Z3_mk_bv2int with the converted input
/// followed by the signedness flag.
1220struct BV2IntOpLowering :
public SMTLoweringPattern<BV2IntOp> {
1221 using SMTLoweringPattern::SMTLoweringPattern;
1224 matchAndRewrite(BV2IntOp op, OpAdaptor adaptor,
1225 ConversionPatternRewriter &rewriter)
const final {
// The signedness flag is passed as an i1 constant.
1228 Value isSignedConst = rewriter.create<LLVM::ConstantOp>(
1229 op->getLoc(), rewriter.getI1Type(), op.getIsSigned());
1230 rewriter.replaceOp(op,
1231 buildPtrAPICall(rewriter, op.getLoc(),
"Z3_mk_bv2int",
1232 {adaptor.getInput(), isSignedConst}));
/// Lowers BVCmpOp to an API call whose function name is derived from the
/// stringified predicate (any prefix is not visible in this excerpt), with
/// the converted lhs and rhs as arguments.
1243struct BVCmpOpLowering :
public SMTLoweringPattern<BVCmpOp> {
1244 using SMTLoweringPattern::SMTLoweringPattern;
1247 matchAndRewrite(BVCmpOp op, OpAdaptor adaptor,
1248 ConversionPatternRewriter &rewriter)
const final {
1250 op, buildPtrAPICall(rewriter, op.getLoc(),
1252 stringifyBVCmpPredicate(op.getPred()).str(),
1253 {adaptor.getLhs(), adaptor.getRhs()}));
/// Expands IntAbsOp into other SMT ops: ite(input < 0, 0 - input, input).
/// The created ops are SMT dialect ops, not API calls.
1259struct IntAbsOpLowering :
public SMTLoweringPattern<IntAbsOp> {
1260 using SMTLoweringPattern::SMTLoweringPattern;
1263 matchAndRewrite(IntAbsOp op, OpAdaptor adaptor,
1264 ConversionPatternRewriter &rewriter)
const final {
1265 Location loc = op.getLoc();
// Integer constant zero, built from an i1-typed integer attribute.
1266 Value zero = rewriter.create<IntConstantOp>(
1267 loc, rewriter.getIntegerAttr(rewriter.getI1Type(), 0));
1268 Value cmp = rewriter.create<IntCmpOp>(loc, IntPredicate::lt,
1269 adaptor.getInput(), zero);
// Negation is expressed as a subtraction from zero.
1270 Value neg = rewriter.create<IntSubOp>(loc, zero, adaptor.getInput());
1271 rewriter.replaceOpWithNewOp<IteOp>(op, cmp, neg, adaptor.getInput());
/// Pass class derived from the generated LowerSMTToZ3LLVMBase; it overrides
/// runOnOperation, which is declared here.
1283struct LowerSMTToZ3LLVMPass
1284 :
public circt::impl::LowerSMTToZ3LLVMBase<LowerSMTToZ3LLVMPass> {
1286 void runOnOperation()
override;
// Type conversions registered on `converter`: each of the SMT types handled
// below (bool, bit-vector, array, int, SMT function, sort) is converted to an
// LLVM pointer type created in the type's own context.
// NOTE(review): the enclosing function header and the closing `});` of each
// lambda are not visible in this excerpt.
1291 converter.addConversion([](smt::BoolType type) {
1292 return LLVM::LLVMPointerType::get(type.getContext());
1294 converter.addConversion([](smt::BitVectorType type) {
1295 return LLVM::LLVMPointerType::get(type.getContext());
1297 converter.addConversion([](smt::ArrayType type) {
1298 return LLVM::LLVMPointerType::get(type.getContext());
1300 converter.addConversion([](smt::IntType type) {
1301 return LLVM::LLVMPointerType::get(type.getContext());
1303 converter.addConversion([](smt::SMTFuncType type) {
1304 return LLVM::LLVMPointerType::get(type.getContext());
1306 converter.addConversion([](smt::SortType type) {
1307 return LLVM::LLVMPointerType::get(type.getContext());
// Registers the SMT lowering patterns on `patterns`.
// NOTE(review): the function name, the remaining parameters and the individual
// macro invocations are not visible in this excerpt.
1312 RewritePatternSet &
patterns, TypeConverter &converter,
// Two local helper macros follow. ADD_VARIADIC_PATTERN registers a
// VariadicSMTPattern<OP> and ADD_ONE_TO_ONE_PATTERN registers a
// OneToOneSMTPattern<OP>; both forward the converter, the context, `globals`,
// `options` and the API function name (the one-to-one form also forwards
// NUM_ARGS). The tail of ADD_VARIADIC_PATTERN is not visible here.
1314#define ADD_VARIADIC_PATTERN(OP, APINAME, MIN_NUM_ARGS) \
1315 patterns.add<VariadicSMTPattern<OP>>( \
1316 converter, patterns.getContext(), \
1317 globals, options, APINAME, \
1320#define ADD_ONE_TO_ONE_PATTERN(OP, APINAME, NUM_ARGS) \
1321 patterns.add<OneToOneSMTPattern<OP>>( \
1322 converter, patterns.getContext(), \
1323 globals, options, APINAME, NUM_ARGS);
// XOrOp is registered with the left-associative lowering pattern.
1376 patterns.add<LowerLeftAssocSMTPattern<XOrOp>>(
1377 converter,
patterns.getContext(), globals, options);
// The helper macros are only needed for the registrations above.
1477#undef ADD_VARIADIC_PATTERN
1478#undef ADD_ONE_TO_ONE_PATTERN
// EqOp is registered with the chainable lowering pattern, using "Z3_mk_eq"
// and an argument count of 2.
1495 patterns.add<LowerChainableSMTPattern<EqOp>>(converter,
patterns.getContext(),
1498 globals, options,
"Z3_mk_eq", 2);
// Remaining ops each have a dedicated lowering pattern class.
1502 patterns.add<BVConstantOpLowering, DeclareFunOpLowering, AssertOpLowering,
1503 ResetOpLowering, PushOpLowering, PopOpLowering, CheckOpLowering,
1504 SolverOpLowering, ApplyFuncOpLowering, YieldOpLowering,
1505 RepeatOpLowering, ExtractOpLowering, BoolConstantOpLowering,
1506 IntConstantOpLowering, ArrayBroadcastOpLowering, BVCmpOpLowering,
1507 IntCmpOpLowering, IntAbsOpLowering, Int2BVOpLowering,
1508 BV2IntOpLowering, QuantifierLowering<ForallOp>,
1509 QuantifierLowering<ExistsOp>>(converter,
patterns.getContext(),
/// Pass entry point: validates `smt.set_logic` usage inside every solver op,
/// collects the conversion patterns and runs a full dialect conversion.
/// NOTE(review): several original lines are missing from this excerpt, so
/// some control flow (e.g. inside the set-logic check) is only partly visible.
1513void LowerSMTToZ3LLVMPass::runOnOperation() {
1514 LowerSMTToZ3LLVMOptions options;
// Forward the pass's `debug` setting into the lowering options.
1515 options.debug =
debug;
// Walk every SolverOp and check how set-logic ops are used in its body.
1519 auto setLogicCheck = getOperation().walk([&](SolverOp solverOp)
1523 auto setLogicOps = solverOp.getBodyRegion().getOps<smt::SetLogicOp>();
1524 auto numSetLogicOps = std::distance(setLogicOps.begin(), setLogicOps.end());
// More than one set-logic op in a solver is rejected with an error.
1525 if (numSetLogicOps > 1) {
1526 return solverOp.emitError(
1527 "multiple set-logic operations found in one solver operation - Z3 "
1528 "only supports setting the logic once");
// With exactly one set-logic op, scan the body ops: an op without the
// ConstantLike trait that is reached by this scan triggers an error.
// NOTE(review): the statement taken when a SetLogicOp is found is not visible
// here; presumably the scan stops at that point — confirm in the full source.
1530 if (numSetLogicOps == 1)
1532 for (
auto &blockOp : solverOp.getBodyRegion().getOps()) {
1533 if (isa<smt::SetLogicOp>(blockOp))
1535 if (!blockOp.hasTrait<OpTrait::ConstantLike>()) {
1536 return solverOp.emitError(
"set-logic operation must be the first "
1537 "non-constant operation in a solver "
1541 return WalkResult::advance();
// An interrupted walk means one of the checks above emitted an error.
1543 if (setLogicCheck.wasInterrupted())
1544 return signalPassFailure();
1547 LLVMTypeConverter converter(&getContext());
1550 RewritePatternSet
patterns(&getContext());
// Besides the SMT patterns, func, arith, SCF-to-CF and CF-to-LLVM patterns are
// added to the same pattern set.
1566 populateFuncToLLVMConversionPatterns(converter,
patterns);
1567 arith::populateArithToLLVMConversionPatterns(converter,
patterns);
1572 populateSCFToControlFlowConversionPatterns(
patterns);
1573 mlir::cf::populateControlFlowToLLVMConversionPatterns(converter,
patterns);
1577 OpBuilder builder(&getContext());
// Conversion target: the LLVM conversion target with ModuleOp and
// scf::YieldOp additionally marked legal.
1583 LLVMConversionTarget target(getContext());
1584 target.addLegalOp<mlir::ModuleOp>();
1585 target.addLegalOp<scf::YieldOp>();
// Run the full conversion; a failure is reported as a pass failure.
1587 if (failed(applyFullConversion(getOperation(), target, std::move(
patterns))))
1588 return signalPassFailure();
assert(baseType &&"element must be base type")
static FIRRTLBaseType convertType(FIRRTLBaseType type)
Returns null type if no conversion is needed.
#define ADD_VARIADIC_PATTERN(OP, APINAME, MIN_NUM_ARGS)
#define ADD_ONE_TO_ONE_PATTERN(OP, APINAME, NUM_ARGS)
static Location getLoc(DefSlot slot)
RewritePatternSet pattern
A namespace that is used to store existing names and generate new names in some scope within the IR.
void add(mlir::ModuleOp module)
StringRef newName(const Twine &name)
Return a unique name, derived from the input name, and add the new name to the internal namespace.
void addDefinitions(mlir::Operation *top)
Populate the symbol cache with all symbol-defining operations within the 'top' operation.
Default symbol cache implementation; stores associations between names (StringAttr's) to mlir::Operat...
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
void populateSMTToZ3LLVMTypeConverter(TypeConverter &converter)
Populate the given type converter with the SMT to LLVM type conversions.
void populateSMTToZ3LLVMConversionPatterns(RewritePatternSet &patterns, TypeConverter &converter, SMTGlobalsHandler &globals, const LowerSMTToZ3LLVMOptions &options)
Add the SMT to LLVM IR conversion patterns to 'patterns'.
A symbol cache for LLVM globals and functions relevant to SMT lowering patterns.
static SMTGlobalsHandler create(OpBuilder &builder, ModuleOp module)
Creates the LLVM global operations to store the pointers to the solver and the context and returns a ...
SMTGlobalsHandler(ModuleOp module, mlir::LLVM::GlobalOp solver, mlir::LLVM::GlobalOp ctx)
Initializes the caches and keeps track of the given globals to store the pointers to the SMT solver a...