13#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
14#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
15#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
16#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
17#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
18#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"
19#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
20#include "mlir/Dialect/Func/IR/FuncOps.h"
21#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
22#include "mlir/Dialect/LLVMIR/LLVMAttrs.h"
23#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
24#include "mlir/Dialect/SCF/IR/SCF.h"
25#include "mlir/Dialect/SMT/IR/SMTOps.h"
26#include "mlir/IR/BuiltinDialect.h"
27#include "mlir/Pass/Pass.h"
28#include "mlir/Transforms/DialectConversion.h"
29#include "llvm/ADT/STLExtras.h"
30#include "llvm/ADT/TypeSwitch.h"
31#include "llvm/Support/Debug.h"
33#define DEBUG_TYPE "lower-smt-to-z3-llvm"
36#define GEN_PASS_DEF_LOWERSMTTOZ3LLVM
37#include "circt/Conversion/Passes.h.inc"
50 OpBuilder::InsertionGuard guard(builder);
51 builder.setInsertionPointToStart(module.getBody());
58 Location loc =
module.getLoc();
59 auto ptrTy = LLVM::LLVMPointerType::get(builder.getContext());
61 auto createGlobal = [&](StringRef namePrefix) {
62 auto global = LLVM::GlobalOp::create(
63 builder, loc, ptrTy,
false, LLVM::Linkage::Internal,
65 OpBuilder::InsertionGuard g(builder);
66 builder.createBlock(&global.getInitializer());
67 Value res = LLVM::ZeroOp::create(builder, loc, ptrTy);
68 LLVM::ReturnOp::create(builder, loc, res);
72 auto ctxGlobal = createGlobal(
"ctx");
73 auto solverGlobal = createGlobal(
"solver");
79 mlir::LLVM::GlobalOp solver,
80 mlir::LLVM::GlobalOp ctx)
81 : solver(solver), ctx(ctx), names(names) {}
84 mlir::LLVM::GlobalOp solver,
85 mlir::LLVM::GlobalOp ctx)
86 : solver(solver), ctx(ctx) {
98template <
typename OpTy>
101 SMTLoweringPattern(
const TypeConverter &typeConverter, MLIRContext *
context,
103 const LowerSMTToZ3LLVMOptions &options)
108 Value buildGlobalPtrToGlobal(OpBuilder &builder, Location loc,
109 LLVM::GlobalOp global,
110 DenseMap<Block *, Value> &cache)
const {
111 Block *block = builder.getBlock();
112 if (
auto iter = cache.find(block); iter != cache.end())
113 return iter->getSecond();
115 OpBuilder::InsertionGuard g(builder);
116 builder.setInsertionPointToStart(block);
117 Value globalAddr = LLVM::AddressOfOp::create(builder, loc, global);
118 return cache[block] = LLVM::LoadOp::create(
119 builder, loc, LLVM::LLVMPointerType::get(builder.getContext()),
129 Value buildContextPtr(OpBuilder &builder, Location loc)
const {
130 return buildGlobalPtrToGlobal(builder, loc, globals.ctx, globals.ctxCache);
138 Value buildSolverPtr(OpBuilder &builder, Location loc)
const {
139 return buildGlobalPtrToGlobal(builder, loc, globals.solver,
140 globals.solverCache);
146 LLVM::CallOp buildCall(OpBuilder &builder, Location loc, StringRef name,
147 LLVM::LLVMFunctionType funcType,
148 ValueRange args)
const {
149 auto &funcOp = globals.funcMap[builder.getStringAttr(name)];
151 OpBuilder::InsertionGuard guard(builder);
153 builder.getBlock()->getParent()->getParentOfType<ModuleOp>();
154 builder.setInsertionPointToEnd(module.getBody());
155 auto funcOpResult = LLVM::lookupOrCreateFn(
156 builder, module, name, funcType.getParams(), funcType.getReturnType(),
157 funcType.getVarArg());
158 assert(succeeded(funcOpResult) &&
"expected to lookup or create printf");
159 funcOp = funcOpResult.value();
161 return LLVM::CallOp::create(builder, loc, funcOp, args);
168 Value buildString(OpBuilder &builder, Location loc, StringRef str)
const {
169 auto &global = globals.stringCache[builder.getStringAttr(str)];
171 OpBuilder::InsertionGuard guard(builder);
173 builder.getBlock()->getParent()->getParentOfType<ModuleOp>();
174 builder.setInsertionPointToEnd(module.getBody());
176 LLVM::LLVMArrayType::get(builder.getI8Type(), str.size() + 1);
177 auto strAttr = builder.getStringAttr(str.str() +
'\00');
178 global = LLVM::GlobalOp::create(
179 builder, loc, arrayTy,
true, LLVM::Linkage::Internal,
180 globals.names.newName(
"str"), strAttr);
182 return LLVM::AddressOfOp::create(builder, loc, global);
187 LLVM::CallOp buildAPICallWithContext(OpBuilder &builder, Location loc,
188 StringRef name, Type returnType,
189 ValueRange args = {})
const {
190 auto ctx = buildContextPtr(builder, loc);
191 SmallVector<Value> arguments;
192 arguments.emplace_back(ctx);
193 arguments.append(SmallVector<Value>(args));
196 LLVM::LLVMFunctionType::get(
197 returnType, SmallVector<Type>(ValueRange(arguments).getTypes())),
204 Value buildPtrAPICall(OpBuilder &builder, Location loc, StringRef name,
205 ValueRange args = {})
const {
206 return buildAPICallWithContext(
208 LLVM::LLVMPointerType::get(builder.getContext()), args)
213 Value buildSort(OpBuilder &builder, Location loc, Type type)
const {
216 return TypeSwitch<Type, Value>(type)
217 .Case([&](smt::IntType ty) {
218 return buildPtrAPICall(builder, loc,
"Z3_mk_int_sort");
220 .Case([&](smt::BitVectorType ty) {
221 Value bitwidth = LLVM::ConstantOp::create(
222 builder, loc, builder.getI32Type(), ty.getWidth());
223 return buildPtrAPICall(builder, loc,
"Z3_mk_bv_sort", {bitwidth});
225 .Case([&](smt::BoolType ty) {
226 return buildPtrAPICall(builder, loc,
"Z3_mk_bool_sort");
228 .Case([&](smt::SortType ty) {
229 Value str = buildString(builder, loc, ty.getIdentifier());
231 buildPtrAPICall(builder, loc,
"Z3_mk_string_symbol", {str});
232 return buildPtrAPICall(builder, loc,
"Z3_mk_uninterpreted_sort",
235 .Case([&](smt::ArrayType ty) {
236 return buildPtrAPICall(builder, loc,
"Z3_mk_array_sort",
237 {buildSort(builder, loc, ty.getDomainType()),
238 buildSort(builder, loc, ty.getRangeType())});
243 const LowerSMTToZ3LLVMOptions &options;
260struct DeclareFunOpLowering :
public SMTLoweringPattern<DeclareFunOp> {
261 using SMTLoweringPattern::SMTLoweringPattern;
264 matchAndRewrite(DeclareFunOp op, OpAdaptor adaptor,
265 ConversionPatternRewriter &rewriter)
const final {
266 Location loc = op.getLoc();
270 if (adaptor.getNamePrefix())
271 prefix = buildString(rewriter, loc, *adaptor.getNamePrefix());
273 prefix = LLVM::ZeroOp::create(rewriter, loc,
274 LLVM::LLVMPointerType::get(getContext()));
277 if (!isa<SMTFuncType>(op.getType())) {
278 Value sort = buildSort(rewriter, loc, op.getType());
280 buildPtrAPICall(rewriter, loc,
"Z3_mk_fresh_const", {prefix, sort});
281 rewriter.replaceOp(op, constDecl);
286 Type llvmPtrTy = LLVM::LLVMPointerType::get(getContext());
287 auto funcType = cast<SMTFuncType>(op.getResult().getType());
288 Value rangeSort = buildSort(rewriter, loc, funcType.getRangeType());
291 LLVM::LLVMArrayType::get(llvmPtrTy, funcType.getDomainTypes().size());
293 Value domain = LLVM::UndefOp::create(rewriter, loc, arrTy);
294 for (
auto [i, ty] :
llvm::enumerate(funcType.getDomainTypes())) {
295 Value sort = buildSort(rewriter, loc, ty);
296 domain = LLVM::InsertValueOp::create(rewriter, loc, domain, sort, i);
300 LLVM::ConstantOp::create(rewriter, loc, rewriter.getI32Type(), 1);
301 Value domainStorage =
302 LLVM::AllocaOp::create(rewriter, loc, llvmPtrTy, arrTy, one);
303 LLVM::StoreOp::create(rewriter, loc, domain, domainStorage);
305 Value domainSize = LLVM::ConstantOp::create(
306 rewriter, loc, rewriter.getI32Type(), funcType.getDomainTypes().size());
308 buildPtrAPICall(rewriter, loc,
"Z3_mk_fresh_func_decl",
309 {prefix, domainSize, domainStorage, rangeSort});
311 rewriter.replaceOp(op, decl);
321struct ApplyFuncOpLowering :
public SMTLoweringPattern<ApplyFuncOp> {
322 using SMTLoweringPattern::SMTLoweringPattern;
325 matchAndRewrite(ApplyFuncOp op, OpAdaptor adaptor,
326 ConversionPatternRewriter &rewriter)
const final {
327 Location loc = op.getLoc();
328 Type llvmPtrTy = LLVM::LLVMPointerType::get(getContext());
329 Type arrTy = LLVM::LLVMArrayType::get(llvmPtrTy, adaptor.getArgs().size());
332 Value domain = LLVM::UndefOp::create(rewriter, loc, arrTy);
333 for (
auto [i, arg] :
llvm::enumerate(adaptor.getArgs()))
334 domain = LLVM::InsertValueOp::create(rewriter, loc, domain, arg, i);
338 LLVM::ConstantOp::create(rewriter, loc, rewriter.getI32Type(), 1);
339 Value domainStorage =
340 LLVM::AllocaOp::create(rewriter, loc, llvmPtrTy, arrTy, one);
341 LLVM::StoreOp::create(rewriter, loc, domain, domainStorage);
345 Value domainSize = LLVM::ConstantOp::create(
346 rewriter, loc, rewriter.getI32Type(), adaptor.getArgs().size());
348 buildPtrAPICall(rewriter, loc,
"Z3_mk_app",
349 {adaptor.getFunc(), domainSize, domainStorage});
350 rewriter.replaceOp(op, returnVal);
368struct BVConstantOpLowering :
public SMTLoweringPattern<smt::BVConstantOp> {
369 using SMTLoweringPattern::SMTLoweringPattern;
372 matchAndRewrite(smt::BVConstantOp op, OpAdaptor adaptor,
373 ConversionPatternRewriter &rewriter)
const final {
374 Location loc = op.getLoc();
375 unsigned width = op.getType().getWidth();
376 auto bvSort = buildSort(rewriter, loc, op.getResult().getType());
377 APInt val = adaptor.getValue().getValue();
380 Value bvConst = LLVM::ConstantOp::create(
381 rewriter, loc, rewriter.getI64Type(), val.getZExtValue());
382 Value res = buildPtrAPICall(rewriter, loc,
"Z3_mk_unsigned_int64",
384 rewriter.replaceOp(op, res);
389 llvm::raw_string_ostream stream(str);
391 Value bvString = buildString(rewriter, loc, str);
393 buildPtrAPICall(rewriter, loc,
"Z3_mk_numeral", {bvString, bvSort});
395 rewriter.replaceOp(op, bvNumeral);
405template <
typename SourceTy>
406struct VariadicSMTPattern :
public SMTLoweringPattern<SourceTy> {
407 using OpAdaptor =
typename SMTLoweringPattern<SourceTy>::OpAdaptor;
409 VariadicSMTPattern(
const TypeConverter &typeConverter, MLIRContext *
context,
411 const LowerSMTToZ3LLVMOptions &options,
412 StringRef apiFuncName,
unsigned minNumArgs)
413 : SMTLoweringPattern<SourceTy>(typeConverter,
context, globals, options),
414 apiFuncName(apiFuncName), minNumArgs(minNumArgs) {}
417 matchAndRewrite(SourceTy op, OpAdaptor adaptor,
418 ConversionPatternRewriter &rewriter)
const final {
419 if (adaptor.getOperands().size() < minNumArgs)
422 Location loc = op.getLoc();
423 Value numOperands = LLVM::ConstantOp::create(
424 rewriter, loc, rewriter.getI32Type(), op->getNumOperands());
426 LLVM::ConstantOp::create(rewriter, loc, rewriter.getI32Type(), 1);
427 Type ptrTy = LLVM::LLVMPointerType::get(rewriter.getContext());
428 Type arrTy = LLVM::LLVMArrayType::get(ptrTy, op->getNumOperands());
430 LLVM::AllocaOp::create(rewriter, loc, ptrTy, arrTy, constOne);
431 Value array = LLVM::UndefOp::create(rewriter, loc, arrTy);
433 for (
auto [i, operand] :
llvm::enumerate(adaptor.getOperands()))
434 array = LLVM::InsertValueOp::create(rewriter, loc, array, operand,
435 ArrayRef<int64_t>{(int64_t)i});
437 LLVM::StoreOp::create(rewriter, loc, array, storage);
439 rewriter.replaceOp(op,
440 SMTLoweringPattern<SourceTy>::buildPtrAPICall(
441 rewriter, loc, apiFuncName, {numOperands, storage}));
446 StringRef apiFuncName;
452template <
typename SourceTy>
453struct OneToOneSMTPattern :
public SMTLoweringPattern<SourceTy> {
454 using OpAdaptor =
typename SMTLoweringPattern<SourceTy>::OpAdaptor;
456 OneToOneSMTPattern(
const TypeConverter &typeConverter, MLIRContext *
context,
458 const LowerSMTToZ3LLVMOptions &options,
459 StringRef apiFuncName,
unsigned numOperands)
460 : SMTLoweringPattern<SourceTy>(typeConverter,
context, globals, options),
461 apiFuncName(apiFuncName), numOperands(numOperands) {}
464 matchAndRewrite(SourceTy op, OpAdaptor adaptor,
465 ConversionPatternRewriter &rewriter)
const final {
466 if (adaptor.getOperands().size() != numOperands)
470 op, SMTLoweringPattern<SourceTy>::buildPtrAPICall(
471 rewriter, op.getLoc(), apiFuncName, adaptor.getOperands()));
476 StringRef apiFuncName;
477 unsigned numOperands;
482template <
typename SourceTy>
483class LowerChainableSMTPattern :
public SMTLoweringPattern<SourceTy> {
484 using SMTLoweringPattern<SourceTy>::SMTLoweringPattern;
485 using OpAdaptor =
typename SMTLoweringPattern<SourceTy>::OpAdaptor;
488 matchAndRewrite(SourceTy op, OpAdaptor adaptor,
489 ConversionPatternRewriter &rewriter)
const final {
490 if (adaptor.getOperands().size() <= 2)
493 Location loc = op.getLoc();
494 SmallVector<Value> elements;
495 for (
int i = 1, e = adaptor.getOperands().size(); i < e; ++i) {
496 Value val = SourceTy::create(
497 rewriter, loc, op->getResultTypes(),
498 ValueRange{adaptor.getOperands()[i - 1], adaptor.getOperands()[i]});
499 elements.push_back(val);
501 rewriter.replaceOpWithNewOp<smt::AndOp>(op, elements);
508template <
typename SourceTy>
509class LowerLeftAssocSMTPattern :
public SMTLoweringPattern<SourceTy> {
510 using SMTLoweringPattern<SourceTy>::SMTLoweringPattern;
511 using OpAdaptor =
typename SMTLoweringPattern<SourceTy>::OpAdaptor;
514 matchAndRewrite(SourceTy op, OpAdaptor adaptor,
515 ConversionPatternRewriter &rewriter)
const final {
516 if (adaptor.getOperands().size() <= 2)
517 return rewriter.notifyMatchFailure(op,
"must have at least two operands");
519 Value runner = adaptor.getOperands()[0];
520 for (Value val : adaptor.getOperands().drop_front())
521 runner = SourceTy::create(rewriter, op.
getLoc(), op->getResultTypes(),
522 ValueRange{runner, val});
524 rewriter.replaceOp(op, runner);
559struct SolverOpLowering :
public SMTLoweringPattern<SolverOp> {
560 using SMTLoweringPattern::SMTLoweringPattern;
563 matchAndRewrite(SolverOp op, OpAdaptor adaptor,
564 ConversionPatternRewriter &rewriter)
const final {
565 Location loc = op.getLoc();
566 auto ptrTy = LLVM::LLVMPointerType::get(getContext());
567 auto voidTy = LLVM::LLVMVoidType::get(getContext());
568 auto ptrToPtrFunc = LLVM::LLVMFunctionType::get(ptrTy, ptrTy);
569 auto ptrPtrToPtrFunc = LLVM::LLVMFunctionType::get(ptrTy, {ptrTy, ptrTy});
570 auto ptrToVoidFunc = LLVM::LLVMFunctionType::get(voidTy, ptrTy);
571 auto ptrPtrToVoidFunc = LLVM::LLVMFunctionType::get(voidTy, {ptrTy, ptrTy});
574 Value config = buildCall(rewriter, loc,
"Z3_mk_config",
575 LLVM::LLVMFunctionType::get(ptrTy, {}), {})
581 Value paramKey = buildString(rewriter, loc,
"proof");
582 Value paramValue = buildString(rewriter, loc,
"true");
583 buildCall(rewriter, loc,
"Z3_set_param_value",
584 LLVM::LLVMFunctionType::get(voidTy, {ptrTy, ptrTy, ptrTy}),
585 {config, paramKey, paramValue});
589 std::optional<StringRef> logic = std::nullopt;
590 auto setLogicOps = op.getBodyRegion().getOps<smt::SetLogicOp>();
591 if (!setLogicOps.empty()) {
594 auto setLogicOp = *setLogicOps.begin();
595 logic = setLogicOp.getLogic();
596 rewriter.eraseOp(setLogicOp);
600 Value ctx = buildCall(rewriter, loc,
"Z3_mk_context", ptrToPtrFunc, config)
603 LLVM::AddressOfOp::create(rewriter, loc, globals.ctx).getResult();
604 LLVM::StoreOp::create(rewriter, loc, ctx, ctxAddr);
607 buildCall(rewriter, loc,
"Z3_del_config", ptrToVoidFunc, {config});
613 auto logicStr = buildString(rewriter, loc, logic.value());
614 solver = buildCall(rewriter, loc,
"Z3_mk_solver_for_logic",
615 ptrPtrToPtrFunc, {ctx, logicStr})
618 solver = buildCall(rewriter, loc,
"Z3_mk_solver", ptrToPtrFunc, ctx)
621 buildCall(rewriter, loc,
"Z3_solver_inc_ref", ptrPtrToVoidFunc,
624 LLVM::AddressOfOp::create(rewriter, loc, globals.solver).getResult();
625 LLVM::StoreOp::create(rewriter, loc, solver, solverAddr);
633 SmallVector<Type> convertedTypes;
635 typeConverter->convertTypes(op->getResultTypes(), convertedTypes)))
640 OpBuilder::InsertionGuard guard(rewriter);
641 auto module = op->getParentOfType<ModuleOp>();
642 rewriter.setInsertionPointToEnd(module.getBody());
644 funcOp = func::FuncOp::create(
645 rewriter, loc, globals.names.newName(
"solver"),
646 rewriter.getFunctionType(adaptor.getInputs().getTypes(),
648 rewriter.inlineRegionBefore(op.getBodyRegion(), funcOp.getBody(),
653 func::CallOp::create(rewriter, loc, funcOp, adaptor.getInputs())
664 buildCall(rewriter, loc,
"Z3_solver_dec_ref", ptrPtrToVoidFunc,
666 buildCall(rewriter, loc,
"Z3_del_context", ptrToVoidFunc, ctx);
668 rewriter.replaceOp(op, results);
677struct AssertOpLowering :
public SMTLoweringPattern<AssertOp> {
678 using SMTLoweringPattern::SMTLoweringPattern;
681 matchAndRewrite(AssertOp op, OpAdaptor adaptor,
682 ConversionPatternRewriter &rewriter)
const final {
683 Location loc = op.getLoc();
684 buildAPICallWithContext(
685 rewriter, loc,
"Z3_solver_assert",
686 LLVM::LLVMVoidType::get(getContext()),
687 {buildSolverPtr(rewriter, loc), adaptor.getInput()});
689 rewriter.eraseOp(op);
698struct ResetOpLowering :
public SMTLoweringPattern<ResetOp> {
699 using SMTLoweringPattern::SMTLoweringPattern;
702 matchAndRewrite(ResetOp op, OpAdaptor adaptor,
703 ConversionPatternRewriter &rewriter)
const final {
704 Location loc = op.getLoc();
705 buildAPICallWithContext(rewriter, loc,
"Z3_solver_reset",
706 LLVM::LLVMVoidType::get(getContext()),
707 {buildSolverPtr(rewriter, loc)});
709 rewriter.eraseOp(op);
718struct PushOpLowering :
public SMTLoweringPattern<PushOp> {
719 using SMTLoweringPattern::SMTLoweringPattern;
721 matchAndRewrite(PushOp op, OpAdaptor adaptor,
722 ConversionPatternRewriter &rewriter)
const final {
723 Location loc = op.getLoc();
727 for (uint32_t i = 0; i < op.getCount(); i++)
728 buildAPICallWithContext(rewriter, loc,
"Z3_solver_push",
729 LLVM::LLVMVoidType::get(getContext()),
730 {buildSolverPtr(rewriter, loc)});
731 rewriter.eraseOp(op);
740struct PopOpLowering :
public SMTLoweringPattern<PopOp> {
741 using SMTLoweringPattern::SMTLoweringPattern;
743 matchAndRewrite(PopOp op, OpAdaptor adaptor,
744 ConversionPatternRewriter &rewriter)
const final {
745 Location loc = op.getLoc();
746 Value constVal = LLVM::ConstantOp::create(
747 rewriter, loc, rewriter.getI32Type(), op.getCount());
748 buildAPICallWithContext(rewriter, loc,
"Z3_solver_pop",
749 LLVM::LLVMVoidType::get(getContext()),
750 {buildSolverPtr(rewriter, loc), constVal});
751 rewriter.eraseOp(op);
761struct YieldOpLowering :
public SMTLoweringPattern<YieldOp> {
762 using SMTLoweringPattern::SMTLoweringPattern;
765 matchAndRewrite(YieldOp op, OpAdaptor adaptor,
766 ConversionPatternRewriter &rewriter)
const final {
767 if (op->getParentOfType<func::FuncOp>()) {
768 rewriter.replaceOpWithNewOp<func::ReturnOp>(op, adaptor.getValues());
771 if (op->getParentOfType<LLVM::LLVMFuncOp>()) {
772 rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(op, adaptor.getValues());
775 if (isa_and_nonnull<scf::SCFDialect>(op->getParentOp()->getDialect())) {
776 rewriter.replaceOpWithNewOp<scf::YieldOp>(op, adaptor.getValues());
794struct CheckOpLowering :
public SMTLoweringPattern<CheckOp> {
795 using SMTLoweringPattern::SMTLoweringPattern;
798 matchAndRewrite(CheckOp op, OpAdaptor adaptor,
799 ConversionPatternRewriter &rewriter)
const final {
800 Location loc = op.getLoc();
801 auto ptrTy = LLVM::LLVMPointerType::get(rewriter.getContext());
802 auto printfType = LLVM::LLVMFunctionType::get(
803 LLVM::LLVMVoidType::get(rewriter.getContext()), {ptrTy},
true);
805 auto getHeaderString = [](
const std::string &title) {
806 unsigned titleSize = title.size() + 2;
807 return std::string((80 - titleSize) / 2,
'-') +
" " + title +
" " +
808 std::string((80 - titleSize + 1) / 2,
'-') +
"\n%s\n" +
809 std::string(80,
'-') +
"\n";
813 Value solver = buildSolverPtr(rewriter, loc);
818 auto solverStringPtr =
819 buildPtrAPICall(rewriter, loc,
"Z3_solver_to_string", {solver});
820 auto solverFormatString =
821 buildString(rewriter, loc, getHeaderString(
"Solver"));
822 buildCall(rewriter, op.getLoc(),
"printf", printfType,
823 {solverFormatString, solverStringPtr});
827 SmallVector<Type> resultTypes;
828 if (failed(typeConverter->convertTypes(op->getResultTypes(), resultTypes)))
833 buildAPICallWithContext(rewriter, loc,
"Z3_solver_check",
834 rewriter.getI32Type(), {solver})
837 LLVM::ConstantOp::create(rewriter, loc, checkResult.getType(), 1);
838 Value isSat = LLVM::ICmpOp::create(rewriter, loc, LLVM::ICmpPredicate::eq,
839 checkResult, constOne);
842 auto satIfOp = scf::IfOp::create(rewriter, loc, resultTypes, isSat);
843 rewriter.inlineRegionBefore(op.getSatRegion(), satIfOp.getThenRegion(),
844 satIfOp.getThenRegion().end());
849 rewriter.createBlock(&satIfOp.getElseRegion());
851 LLVM::ConstantOp::create(rewriter, loc, checkResult.getType(), -1);
852 Value isUnsat = LLVM::ICmpOp::create(rewriter, loc, LLVM::ICmpPredicate::eq,
853 checkResult, constNegOne);
854 auto unsatIfOp = scf::IfOp::create(rewriter, loc, resultTypes, isUnsat);
855 scf::YieldOp::create(rewriter, loc, unsatIfOp->getResults());
857 rewriter.inlineRegionBefore(op.getUnsatRegion(), unsatIfOp.getThenRegion(),
858 unsatIfOp.getThenRegion().end());
859 rewriter.inlineRegionBefore(op.getUnknownRegion(),
860 unsatIfOp.getElseRegion(),
861 unsatIfOp.getElseRegion().end());
863 rewriter.replaceOp(op, satIfOp->getResults());
868 rewriter.setInsertionPointToStart(unsatIfOp.thenBlock());
869 auto proof = buildPtrAPICall(rewriter, op.getLoc(),
"Z3_solver_get_proof",
872 buildPtrAPICall(rewriter, op.getLoc(),
"Z3_ast_to_string", {proof});
874 buildString(rewriter, op.getLoc(), getHeaderString(
"Proof"));
875 buildCall(rewriter, op.getLoc(),
"printf", printfType,
876 {formatString, stringPtr});
880 rewriter.setInsertionPointToStart(satIfOp.thenBlock());
881 auto model = buildPtrAPICall(rewriter, op.getLoc(),
"Z3_solver_get_model",
883 auto modelStringPtr =
884 buildPtrAPICall(rewriter, op.getLoc(),
"Z3_model_to_string", {model});
885 auto modelFormatString =
886 buildString(rewriter, op.getLoc(), getHeaderString(
"Model"));
887 buildCall(rewriter, op.getLoc(),
"printf", printfType,
888 {modelFormatString, modelStringPtr});
916template <
typename QuantifierOp>
917struct QuantifierLowering :
public SMTLoweringPattern<QuantifierOp> {
918 using SMTLoweringPattern<QuantifierOp>::SMTLoweringPattern;
919 using SMTLoweringPattern<QuantifierOp>::typeConverter;
920 using SMTLoweringPattern<QuantifierOp>::buildPtrAPICall;
921 using OpAdaptor =
typename QuantifierOp::Adaptor;
923 Value createStorageForValueList(ValueRange values, Location loc,
924 ConversionPatternRewriter &rewriter)
const {
925 Type ptrTy = LLVM::LLVMPointerType::get(rewriter.getContext());
926 Type arrTy = LLVM::LLVMArrayType::get(ptrTy, values.size());
928 LLVM::ConstantOp::create(rewriter, loc, rewriter.getI32Type(), 1);
930 LLVM::AllocaOp::create(rewriter, loc, ptrTy, arrTy, constOne);
931 Value array = LLVM::UndefOp::create(rewriter, loc, arrTy);
933 for (
auto [i, val] :
llvm::enumerate(values))
934 array = LLVM::InsertValueOp::create(rewriter, loc, array, val,
935 ArrayRef<int64_t>(i));
937 LLVM::StoreOp::create(rewriter, loc, array, storage);
943 matchAndRewrite(QuantifierOp op, OpAdaptor adaptor,
944 ConversionPatternRewriter &rewriter)
const final {
945 Location loc = op.getLoc();
946 Type ptrTy = LLVM::LLVMPointerType::get(rewriter.getContext());
952 if (adaptor.getNoPattern())
953 return rewriter.notifyMatchFailure(
954 op,
"no-pattern attribute not yet supported!");
956 rewriter.setInsertionPoint(op);
959 Value weight = LLVM::ConstantOp::create(
960 rewriter, loc, rewriter.getI32Type(), adaptor.getWeight());
963 unsigned numDecls = op.getBody().getNumArguments();
964 Value numDeclsVal = LLVM::ConstantOp::create(
965 rewriter, loc, rewriter.getI32Type(), numDecls);
972 SmallVector<Value> repl;
973 for (
auto [i, arg] :
llvm::enumerate(op.getBody().getArguments())) {
975 if (adaptor.getBoundVarNames().has_value())
976 newArg = smt::DeclareFunOp::create(
977 rewriter, loc, arg.getType(),
978 cast<StringAttr>((*adaptor.getBoundVarNames())[i]));
980 newArg = smt::DeclareFunOp::create(rewriter, loc, arg.getType());
981 repl.push_back(typeConverter->materializeTargetConversion(
982 rewriter, loc, typeConverter->convertType(arg.getType()), newArg));
985 Value boundStorage = createStorageForValueList(repl, loc, rewriter);
988 auto yieldOp = cast<smt::YieldOp>(op.getBody().front().getTerminator());
989 Value bodyExp = yieldOp.getValues()[0];
990 rewriter.setInsertionPointAfterValue(bodyExp);
991 bodyExp = typeConverter->materializeTargetConversion(
992 rewriter, loc, typeConverter->convertType(bodyExp.getType()), bodyExp);
993 rewriter.eraseOp(yieldOp);
995 rewriter.inlineBlockBefore(&op.getBody().front(), op, repl);
996 rewriter.setInsertionPoint(op);
999 unsigned numPatterns = adaptor.getPatterns().size();
1000 Value numPatternsVal = LLVM::ConstantOp::create(
1001 rewriter, loc, rewriter.getI32Type(), numPatterns);
1003 Value patternStorage;
1004 if (numPatterns > 0) {
1006 for (Region *patternRegion : adaptor.getPatterns()) {
1008 cast<smt::YieldOp>(patternRegion->front().getTerminator());
1009 auto patternTerms = yieldOp.getOperands();
1011 rewriter.setInsertionPoint(yieldOp);
1012 SmallVector<Value> patternList;
1013 for (
auto val : patternTerms)
1014 patternList.push_back(typeConverter->materializeTargetConversion(
1015 rewriter, loc, typeConverter->
convertType(val.getType()), val));
1017 rewriter.eraseOp(yieldOp);
1018 rewriter.inlineBlockBefore(&patternRegion->front(), op, repl);
1020 rewriter.setInsertionPoint(op);
1021 Value numTerms = LLVM::ConstantOp::create(
1022 rewriter, loc, rewriter.getI32Type(), patternTerms.size());
1023 Value patternTermStorage =
1024 createStorageForValueList(patternList, loc, rewriter);
1025 Value
pattern = buildPtrAPICall(rewriter, loc,
"Z3_mk_pattern",
1026 {numTerms, patternTermStorage});
1030 patternStorage = createStorageForValueList(
patterns, loc, rewriter);
1034 patternStorage = LLVM::ZeroOp::create(rewriter, loc, ptrTy);
1037 StringRef apiCallName =
"Z3_mk_forall_const";
1038 if (std::is_same_v<QuantifierOp, ExistsOp>)
1039 apiCallName =
"Z3_mk_exists_const";
1040 Value quantifierExp =
1041 buildPtrAPICall(rewriter, loc, apiCallName,
1042 {weight, numDeclsVal, boundStorage, numPatternsVal,
1043 patternStorage, bodyExp});
1045 rewriter.replaceOp(op, quantifierExp);
1054struct RepeatOpLowering :
public SMTLoweringPattern<RepeatOp> {
1055 using SMTLoweringPattern::SMTLoweringPattern;
1058 matchAndRewrite(RepeatOp op, OpAdaptor adaptor,
1059 ConversionPatternRewriter &rewriter)
const final {
1060 Value count = LLVM::ConstantOp::create(
1061 rewriter, op.getLoc(), rewriter.getI32Type(), op.getCount());
1062 rewriter.replaceOp(op,
1063 buildPtrAPICall(rewriter, op.getLoc(),
"Z3_mk_repeat",
1064 {count, adaptor.getInput()}));
1076struct ExtractOpLowering :
public SMTLoweringPattern<ExtractOp> {
1077 using SMTLoweringPattern::SMTLoweringPattern;
1080 matchAndRewrite(ExtractOp op, OpAdaptor adaptor,
1081 ConversionPatternRewriter &rewriter)
const final {
1082 Location loc = op.getLoc();
1083 Value low = LLVM::ConstantOp::create(rewriter, loc, rewriter.getI32Type(),
1084 adaptor.getLowBit());
1085 Value high = LLVM::ConstantOp::create(rewriter, loc, rewriter.getI32Type(),
1086 adaptor.getLowBit() +
1087 op.getType().getWidth() - 1);
1088 rewriter.replaceOp(op, buildPtrAPICall(rewriter, loc,
"Z3_mk_extract",
1089 {high, low, adaptor.getInput()}));
1098struct ArrayBroadcastOpLowering
1099 :
public SMTLoweringPattern<smt::ArrayBroadcastOp> {
1100 using SMTLoweringPattern::SMTLoweringPattern;
1103 matchAndRewrite(smt::ArrayBroadcastOp op, OpAdaptor adaptor,
1104 ConversionPatternRewriter &rewriter)
const final {
1105 auto domainSort = buildSort(
1106 rewriter, op.getLoc(),
1107 cast<smt::ArrayType>(op.getResult().getType()).getDomainType());
1109 rewriter.replaceOp(op, buildPtrAPICall(rewriter, op.getLoc(),
1110 "Z3_mk_const_array",
1111 {domainSort, adaptor.getValue()}));
1122struct BoolConstantOpLowering :
public SMTLoweringPattern<smt::BoolConstantOp> {
1123 using SMTLoweringPattern::SMTLoweringPattern;
1126 matchAndRewrite(smt::BoolConstantOp op, OpAdaptor adaptor,
1127 ConversionPatternRewriter &rewriter)
const final {
1129 op, buildPtrAPICall(rewriter, op.getLoc(),
1130 adaptor.getValue() ?
"Z3_mk_true" :
"Z3_mk_false"));
1146struct IntConstantOpLowering :
public SMTLoweringPattern<smt::IntConstantOp> {
1147 using SMTLoweringPattern::SMTLoweringPattern;
1150 matchAndRewrite(smt::IntConstantOp op, OpAdaptor adaptor,
1151 ConversionPatternRewriter &rewriter)
const final {
1152 Location loc = op.getLoc();
1153 Value type = buildPtrAPICall(rewriter, loc,
"Z3_mk_int_sort");
1154 if (adaptor.getValue().getBitWidth() <= 64) {
1155 Value val = LLVM::ConstantOp::create(rewriter, loc, rewriter.getI64Type(),
1156 adaptor.getValue().getSExtValue());
1158 op, buildPtrAPICall(rewriter, loc,
"Z3_mk_int64", {val, type}));
1162 std::string numeralStr;
1163 llvm::raw_string_ostream stream(numeralStr);
1164 stream << adaptor.getValue().abs();
1166 Value numeral = buildString(rewriter, loc, numeralStr);
1168 buildPtrAPICall(rewriter, loc,
"Z3_mk_numeral", {numeral, type});
1170 if (adaptor.getValue().isNegative())
1172 buildPtrAPICall(rewriter, loc,
"Z3_mk_unary_minus", intNumeral);
1174 rewriter.replaceOp(op, intNumeral);
1184struct IntCmpOpLowering :
public SMTLoweringPattern<IntCmpOp> {
1185 using SMTLoweringPattern::SMTLoweringPattern;
1188 matchAndRewrite(IntCmpOp op, OpAdaptor adaptor,
1189 ConversionPatternRewriter &rewriter)
const final {
1192 buildPtrAPICall(rewriter, op.getLoc(),
1193 "Z3_mk_" + stringifyIntPredicate(op.getPred()).str(),
1194 {adaptor.getLhs(), adaptor.getRhs()}));
1203struct Int2BVOpLowering :
public SMTLoweringPattern<Int2BVOp> {
1204 using SMTLoweringPattern::SMTLoweringPattern;
1207 matchAndRewrite(Int2BVOp op, OpAdaptor adaptor,
1208 ConversionPatternRewriter &rewriter)
const final {
1210 LLVM::ConstantOp::create(rewriter, op->getLoc(), rewriter.getI32Type(),
1211 op.getResult().getType().getWidth());
1212 rewriter.replaceOp(op,
1213 buildPtrAPICall(rewriter, op.getLoc(),
"Z3_mk_int2bv",
1214 {widthConst, adaptor.getInput()}));
1223struct BV2IntOpLowering :
public SMTLoweringPattern<BV2IntOp> {
1224 using SMTLoweringPattern::SMTLoweringPattern;
1227 matchAndRewrite(BV2IntOp op, OpAdaptor adaptor,
1228 ConversionPatternRewriter &rewriter)
const final {
1231 Value isSignedConst = LLVM::ConstantOp::create(
1232 rewriter, op->getLoc(), rewriter.getI1Type(), op.getIsSigned());
1233 rewriter.replaceOp(op,
1234 buildPtrAPICall(rewriter, op.getLoc(),
"Z3_mk_bv2int",
1235 {adaptor.getInput(), isSignedConst}));
1246struct BVCmpOpLowering :
public SMTLoweringPattern<BVCmpOp> {
1247 using SMTLoweringPattern::SMTLoweringPattern;
1250 matchAndRewrite(BVCmpOp op, OpAdaptor adaptor,
1251 ConversionPatternRewriter &rewriter)
const final {
1253 op, buildPtrAPICall(rewriter, op.getLoc(),
1255 stringifyBVCmpPredicate(op.getPred()).str(),
1256 {adaptor.getLhs(), adaptor.getRhs()}));
1262struct IntAbsOpLowering :
public SMTLoweringPattern<IntAbsOp> {
1263 using SMTLoweringPattern::SMTLoweringPattern;
1266 matchAndRewrite(IntAbsOp op, OpAdaptor adaptor,
1267 ConversionPatternRewriter &rewriter)
const final {
1268 Location loc = op.getLoc();
1269 Value zero = IntConstantOp::create(
1270 rewriter, loc, rewriter.getIntegerAttr(rewriter.getI1Type(), 0));
1271 Value cmp = IntCmpOp::create(rewriter, loc, IntPredicate::lt,
1272 adaptor.getInput(), zero);
1273 Value neg = IntSubOp::create(rewriter, loc, zero, adaptor.getInput());
1274 rewriter.replaceOpWithNewOp<IteOp>(op, cmp, neg, adaptor.getInput());
1287 using OpConversionPattern::OpConversionPattern;
1290 matchAndRewrite(debug::VariableOp op, OpAdaptor adaptor,
1291 ConversionPatternRewriter &rewriter)
const final {
1292 rewriter.eraseOp(op);
1299 using OpConversionPattern::OpConversionPattern;
1302 matchAndRewrite(debug::ScopeOp op, OpAdaptor adaptor,
1303 ConversionPatternRewriter &rewriter)
const final {
1305 if (llvm::any_of(op->getUsers(), [](Operation *user) {
1306 return !isa<debug::VariableOp>(user);
1309 rewriter.eraseOp(op);
/// Pass that lowers SMT dialect operations to the LLVM dialect, emitting calls
/// into the Z3 C API. It derives from the tablegen-generated pass base class
/// LowerSMTToZ3LLVMBase.
1321struct LowerSMTToZ3LLVMPass
1322 :
public circt::impl::LowerSMTToZ3LLVMBase<LowerSMTToZ3LLVMPass> {
1324 void runOnOperation()
override;
// Type conversions for the SMT-to-Z3/LLVM lowering. Every SMT type handled
// here (bool, bit-vector, array, int, function and sort types) becomes the
// opaque LLVM pointer type. The runtime values are presumably handles to Z3
// objects (e.g. Z3_ast / Z3_sort). TODO confirm.
// NOTE(review): the enclosing function header, the lambda closers and the
// function's closing lines are missing from this listing.
1329 converter.addConversion([](smt::BoolType type) {
1330 return LLVM::LLVMPointerType::get(type.getContext());
1332 converter.addConversion([](smt::BitVectorType type) {
1333 return LLVM::LLVMPointerType::get(type.getContext());
1335 converter.addConversion([](smt::ArrayType type) {
1336 return LLVM::LLVMPointerType::get(type.getContext());
1338 converter.addConversion([](smt::IntType type) {
1339 return LLVM::LLVMPointerType::get(type.getContext());
1341 converter.addConversion([](smt::SMTFuncType type) {
1342 return LLVM::LLVMPointerType::get(type.getContext());
1344 converter.addConversion([](smt::SortType type) {
1345 return LLVM::LLVMPointerType::get(type.getContext());
// Parameters of populateSMTToZ3LLVMConversionPatterns, which adds the SMT to
// LLVM IR conversion patterns to `patterns`. The function-name line and
// several body lines are missing from this listing.
1350 RewritePatternSet &
patterns, TypeConverter &converter,
// Helper macros that register a VariadicSMTPattern or OneToOneSMTPattern for
// OP, lowering it to the Z3 API function named APINAME. Each pattern gets the
// type converter, the globals handler and the lowering options.
1352#define ADD_VARIADIC_PATTERN(OP, APINAME, MIN_NUM_ARGS) \
1353 patterns.add<VariadicSMTPattern<OP>>( \
1354 converter, patterns.getContext(), \
1355 globals, options, APINAME, \
1358#define ADD_ONE_TO_ONE_PATTERN(OP, APINAME, NUM_ARGS) \
1359 patterns.add<OneToOneSMTPattern<OP>>( \
1360 converter, patterns.getContext(), \
1361 globals, options, APINAME, NUM_ARGS);
// XOr is lowered by the left-associative helper pattern.
1414 patterns.add<LowerLeftAssocSMTPattern<XOrOp>>(
1415 converter,
patterns.getContext(), globals, options);
// The registration macros are no longer needed past this point.
1515#undef ADD_VARIADIC_PATTERN
1516#undef ADD_ONE_TO_ONE_PATTERN
// Equality is registered through LowerChainableSMTPattern with the Z3 API
// function "Z3_mk_eq" and an argument count of 2.
1533 patterns.add<LowerChainableSMTPattern<EqOp>>(converter,
patterns.getContext(),
1536 globals, options,
"Z3_mk_eq", 2);
// Ops with dedicated lowering pattern classes.
1540 patterns.add<BVConstantOpLowering, DeclareFunOpLowering, AssertOpLowering,
1541 ResetOpLowering, PushOpLowering, PopOpLowering, CheckOpLowering,
1542 SolverOpLowering, ApplyFuncOpLowering, YieldOpLowering,
1543 RepeatOpLowering, ExtractOpLowering, BoolConstantOpLowering,
1544 IntConstantOpLowering, ArrayBroadcastOpLowering, BVCmpOpLowering,
1545 IntCmpOpLowering, IntAbsOpLowering, Int2BVOpLowering,
1546 BV2IntOpLowering, QuantifierLowering<ForallOp>,
1547 QuantifierLowering<ExistsOp>>(converter,
patterns.getContext(),
// NOTE(review): the rest of the preceding argument list is missing here.
// The debug-op patterns below are constructed from the MLIR context only,
// without the type converter.
1549 patterns.add<DbgVariableLowering, DbgScopeLowering>(
patterns.getContext());
/// Pass entry point. It first checks how smt.set_logic is used in every
/// smt.solver op, then runs a full dialect conversion of the operation to the
/// LLVM dialect. It signals pass failure if either step fails.
/// NOTE(review): several lines of this function are missing from this listing.
1552void LowerSMTToZ3LLVMPass::runOnOperation() {
1553 LowerSMTToZ3LLVMOptions options;
// Forward the pass's `debug` option to the lowering options.
1554 options.debug =
debug;
// Z3 only supports setting the logic once. Reject solver ops that contain
// more than one set-logic op, and require that only constant-like ops
// precede it.
1558 auto setLogicCheck = getOperation().walk([&](SolverOp solverOp)
1562 auto setLogicOps = solverOp.getBodyRegion().getOps<smt::SetLogicOp>();
1563 auto numSetLogicOps = std::distance(setLogicOps.begin(), setLogicOps.end());
1564 if (numSetLogicOps > 1) {
1565 return solverOp.emitError(
1566 "multiple set-logic operations found in one solver operation - Z3 "
1567 "only supports setting the logic once");
// Scan the solver body in order up to the set-logic op. Any op before it that
// lacks the ConstantLike trait is an error.
1569 if (numSetLogicOps == 1)
1571 for (
auto &blockOp : solverOp.getBodyRegion().getOps()) {
1572 if (isa<smt::SetLogicOp>(blockOp))
1574 if (!blockOp.hasTrait<OpTrait::ConstantLike>()) {
1575 return solverOp.emitError(
"set-logic operation must be the first "
1576 "non-constant operation in a solver "
1580 return WalkResult::advance();
// Abort the pass if any solver op failed the check above.
1582 if (setLogicCheck.wasInterrupted())
1583 return signalPassFailure();
// Conversion setup: the LLVM type converter, plus patterns from the func,
// arith, scf-to-cf and cf-to-llvm lowerings.
1586 LLVMTypeConverter converter(&getContext());
1589 RewritePatternSet
patterns(&getContext());
1605 populateFuncToLLVMConversionPatterns(converter,
patterns);
1606 arith::populateArithToLLVMConversionPatterns(converter,
patterns);
1611 populateSCFToControlFlowConversionPatterns(
patterns);
1612 mlir::cf::populateControlFlowToLLVMConversionPatterns(converter,
patterns);
// NOTE(review): the lines that use this builder are missing from this listing
// (presumably creating the SMT globals and adding the SMT patterns). Confirm.
1616 OpBuilder builder(&getContext());
// The module op and scf.yield stay legal; debug dialect ops must be removed.
1622 LLVMConversionTarget target(getContext());
1623 target.addLegalOp<mlir::ModuleOp>();
1624 target.addLegalOp<scf::YieldOp>();
1625 target.addIllegalDialect<debug::DebugDialect>();
// Full conversion: if any illegal op remains, the pass fails.
1627 if (failed(applyFullConversion(getOperation(), target, std::move(
patterns))))
1628 return signalPassFailure();
assert(baseType &&"element must be base type")
static std::unique_ptr< Context > context
static FIRRTLBaseType convertType(FIRRTLBaseType type)
Returns null type if no conversion is needed.
#define ADD_VARIADIC_PATTERN(OP, APINAME, MIN_NUM_ARGS)
#define ADD_ONE_TO_ONE_PATTERN(OP, APINAME, NUM_ARGS)
static Location getLoc(DefSlot slot)
RewritePatternSet pattern
A namespace that is used to store existing names and generate new names in some scope within the IR.
void add(mlir::ModuleOp module)
StringRef newName(const Twine &name)
Return a unique name, derived from the input name, and add the new name to the internal namespace.
void addDefinitions(mlir::Operation *top)
Populate the symbol cache with all symbol-defining operations within the 'top' operation.
Default symbol cache implementation; stores associations from names (StringAttrs) to mlir::Operation pointers.
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
void populateSMTToZ3LLVMTypeConverter(TypeConverter &converter)
Populate the given type converter with the SMT to LLVM type conversions.
void populateSMTToZ3LLVMConversionPatterns(RewritePatternSet &patterns, TypeConverter &converter, SMTGlobalsHandler &globals, const LowerSMTToZ3LLVMOptions &options)
Add the SMT to LLVM IR conversion patterns to 'patterns'.
A symbol cache for LLVM globals and functions relevant to SMT lowering patterns.
static SMTGlobalsHandler create(OpBuilder &builder, ModuleOp module)
Creates the LLVM global operations to store the pointers to the solver and the context and returns a new SMTGlobalsHandler.
SMTGlobalsHandler(ModuleOp module, mlir::LLVM::GlobalOp solver, mlir::LLVM::GlobalOp ctx)
Initializes the caches and keeps track of the given globals to store the pointers to the SMT solver and context.