CIRCT 23.0.0git
Loading...
Searching...
No Matches
LowerSMTToZ3LLVM.cpp
Go to the documentation of this file.
1//===- LowerSMTToZ3LLVM.cpp -----------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
13#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
14#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
15#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
16#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
17#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
18#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"
19#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
20#include "mlir/Dialect/Func/IR/FuncOps.h"
21#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
22#include "mlir/Dialect/LLVMIR/LLVMAttrs.h"
23#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
24#include "mlir/Dialect/SCF/IR/SCF.h"
25#include "mlir/Dialect/SMT/IR/SMTOps.h"
26#include "mlir/IR/BuiltinDialect.h"
27#include "mlir/Pass/Pass.h"
28#include "mlir/Transforms/DialectConversion.h"
29#include "llvm/ADT/STLExtras.h"
30#include "llvm/ADT/TypeSwitch.h"
31#include "llvm/Support/Debug.h"
32
33#define DEBUG_TYPE "lower-smt-to-z3-llvm"
34
35namespace circt {
36#define GEN_PASS_DEF_LOWERSMTTOZ3LLVM
37#include "circt/Conversion/Passes.h.inc"
38} // namespace circt
39
40using namespace mlir;
41using namespace circt;
42using namespace smt;
43
44//===----------------------------------------------------------------------===//
45// SMTGlobalHandler implementation
46//===----------------------------------------------------------------------===//
47
49 ModuleOp module) {
  // Create the two pointer-typed globals that hold the Z3 context and solver
  // at the start of 'module' and return a handler tracking them. The guard
  // restores the caller's insertion point when this function returns.
50 OpBuilder::InsertionGuard guard(builder);
51 builder.setInsertionPointToStart(module.getBody());
52
  // Register all symbols already defined in 'module' with 'names', presumably
  // so that the symbol names generated below do not clash with existing ones.
53 SymbolCache symCache;
54 symCache.addDefinitions(module);
56 names.add(symCache);
57
58 Location loc = module.getLoc();
59 auto ptrTy = LLVM::LLVMPointerType::get(builder.getContext());
60
  // Helper creating one internal, non-constant ('false' is the isConstant
  // argument), 8-byte aligned global of pointer type. Its name is derived from
  // 'namePrefix' via 'names.newName', and its initializer region returns a
  // zero (null) pointer.
61 auto createGlobal = [&](StringRef namePrefix) {
62 auto global = LLVM::GlobalOp::create(
63 builder, loc, ptrTy, false, LLVM::Linkage::Internal,
64 names.newName(namePrefix), Attribute{}, /*alignment=*/8);
  // Populate the initializer without disturbing the insertion point at the
  // start of the module body.
65 OpBuilder::InsertionGuard g(builder);
66 builder.createBlock(&global.getInitializer());
67 Value res = LLVM::ZeroOp::create(builder, loc, ptrTy);
68 LLVM::ReturnOp::create(builder, loc, res);
69 return global;
70 };
71
72 auto ctxGlobal = createGlobal("ctx");
73 auto solverGlobal = createGlobal("solver");
74
  // Note the argument order: the handler takes the solver global before the
  // context global, i.e., the reverse of the creation order above.
75 return SMTGlobalsHandler(std::move(names), solverGlobal, ctxGlobal);
76}
77
79 mlir::LLVM::GlobalOp solver,
80 mlir::LLVM::GlobalOp ctx)
81 : solver(solver), ctx(ctx), names(names) {}
82
84 mlir::LLVM::GlobalOp solver,
85 mlir::LLVM::GlobalOp ctx)
86 : solver(solver), ctx(ctx) {
  // Unlike 'create', this constructor does not build any globals; it only
  // records the given ones. It also seeds 'names' with every symbol already
  // defined in 'module', presumably so that later 'names.newName' calls
  // (e.g., for string constants and outlined solver functions) cannot collide
  // with existing symbols.
87 SymbolCache symCache;
88 symCache.addDefinitions(module);
89 names.add(symCache);
90}
91
92//===----------------------------------------------------------------------===//
93// Lowering Pattern Base
94//===----------------------------------------------------------------------===//
95
96namespace {
97
98template <typename OpTy>
99class SMTLoweringPattern : public OpConversionPattern<OpTy> {
100public:
101 SMTLoweringPattern(const TypeConverter &typeConverter, MLIRContext *context,
102 SMTGlobalsHandler &globals,
103 const LowerSMTToZ3LLVMOptions &options)
104 : OpConversionPattern<OpTy>(typeConverter, context), globals(globals),
105 options(options) {}
106
107private:
108 Value buildGlobalPtrToGlobal(OpBuilder &builder, Location loc,
109 LLVM::GlobalOp global,
110 DenseMap<Block *, Value> &cache) const {
111 Block *block = builder.getBlock();
112 if (auto iter = cache.find(block); iter != cache.end())
113 return iter->getSecond();
114
115 OpBuilder::InsertionGuard g(builder);
116 builder.setInsertionPointToStart(block);
117 Value globalAddr = LLVM::AddressOfOp::create(builder, loc, global);
118 return cache[block] = LLVM::LoadOp::create(
119 builder, loc, LLVM::LLVMPointerType::get(builder.getContext()),
120 globalAddr);
121 }
122
123protected:
124 /// A convenience function to get the pointer to the context from the 'global'
125 /// operation. The result is cached for each basic block, i.e., it is assumed
126 /// that this function is never called in the same basic block again at a
127 /// location (insertion point of the 'builder') not dominating all previous
128 /// locations this function was called at.
129 Value buildContextPtr(OpBuilder &builder, Location loc) const {
130 return buildGlobalPtrToGlobal(builder, loc, globals.ctx, globals.ctxCache);
131 }
132
133 /// A convenience function to get the pointer to the solver from the 'global'
134 /// operation. The result is cached for each basic block, i.e., it is assumed
135 /// that this function is never called in the same basic block again at a
136 /// location (insertion point of the 'builder') not dominating all previous
137 /// locations this function was called at.
138 Value buildSolverPtr(OpBuilder &builder, Location loc) const {
139 return buildGlobalPtrToGlobal(builder, loc, globals.solver,
140 globals.solverCache);
141 }
142
143 /// Create a `llvm.call` operation to a function with the given 'name' and
144 /// 'type'. If there does not already exist a (external) function with that
145 /// name create a matching external function declaration.
146 LLVM::CallOp buildCall(OpBuilder &builder, Location loc, StringRef name,
147 LLVM::LLVMFunctionType funcType,
148 ValueRange args) const {
149 auto &funcOp = globals.funcMap[builder.getStringAttr(name)];
150 if (!funcOp) {
151 OpBuilder::InsertionGuard guard(builder);
152 auto module =
153 builder.getBlock()->getParent()->getParentOfType<ModuleOp>();
154 builder.setInsertionPointToEnd(module.getBody());
155 auto funcOpResult = LLVM::lookupOrCreateFn(
156 builder, module, name, funcType.getParams(), funcType.getReturnType(),
157 funcType.getVarArg());
158 assert(succeeded(funcOpResult) && "expected to lookup or create printf");
159 funcOp = funcOpResult.value();
160 }
161 return LLVM::CallOp::create(builder, loc, funcOp, args);
162 }
163
164 /// Build a global constant for the given string and construct an 'addressof'
165 /// operation at the current 'builder' insertion point to get a pointer to it.
166 /// Multiple calls with the same string will reuse the same global. It is
167 /// guaranteed that the symbol of the global will be unique.
168 Value buildString(OpBuilder &builder, Location loc, StringRef str) const {
169 auto &global = globals.stringCache[builder.getStringAttr(str)];
170 if (!global) {
171 OpBuilder::InsertionGuard guard(builder);
172 auto module =
173 builder.getBlock()->getParent()->getParentOfType<ModuleOp>();
174 builder.setInsertionPointToEnd(module.getBody());
175 auto arrayTy =
176 LLVM::LLVMArrayType::get(builder.getI8Type(), str.size() + 1);
177 auto strAttr = builder.getStringAttr(str.str() + '\00');
178 global = LLVM::GlobalOp::create(
179 builder, loc, arrayTy, /*isConstant=*/true, LLVM::Linkage::Internal,
180 globals.names.newName("str"), strAttr);
181 }
182 return LLVM::AddressOfOp::create(builder, loc, global);
183 }
184 /// Most API functions require a pointer to the the Z3 context object as the
185 /// first argument. This helper function prepends this pointer value to the
186 /// call for convenience.
187 LLVM::CallOp buildAPICallWithContext(OpBuilder &builder, Location loc,
188 StringRef name, Type returnType,
189 ValueRange args = {}) const {
190 auto ctx = buildContextPtr(builder, loc);
191 SmallVector<Value> arguments;
192 arguments.emplace_back(ctx);
193 arguments.append(SmallVector<Value>(args));
194 return buildCall(
195 builder, loc, name,
196 LLVM::LLVMFunctionType::get(
197 returnType, SmallVector<Type>(ValueRange(arguments).getTypes())),
198 arguments);
199 }
200
201 /// Most API functions we need to call return a 'Z3_AST' object which is a
202 /// pointer in LLVM. This helper function simplifies calling those API
203 /// functions.
204 Value buildPtrAPICall(OpBuilder &builder, Location loc, StringRef name,
205 ValueRange args = {}) const {
206 return buildAPICallWithContext(
207 builder, loc, name,
208 LLVM::LLVMPointerType::get(builder.getContext()), args)
209 ->getResult(0);
210 }
211
212 /// Build a value representing the SMT sort given with 'type'.
213 Value buildSort(OpBuilder &builder, Location loc, Type type) const {
214 // NOTE: if a type not handled by this switch is passed, an assertion will
215 // be triggered.
216 return TypeSwitch<Type, Value>(type)
217 .Case([&](smt::IntType ty) {
218 return buildPtrAPICall(builder, loc, "Z3_mk_int_sort");
219 })
220 .Case([&](smt::BitVectorType ty) {
221 Value bitwidth = LLVM::ConstantOp::create(
222 builder, loc, builder.getI32Type(), ty.getWidth());
223 return buildPtrAPICall(builder, loc, "Z3_mk_bv_sort", {bitwidth});
224 })
225 .Case([&](smt::BoolType ty) {
226 return buildPtrAPICall(builder, loc, "Z3_mk_bool_sort");
227 })
228 .Case([&](smt::SortType ty) {
229 Value str = buildString(builder, loc, ty.getIdentifier());
230 Value sym =
231 buildPtrAPICall(builder, loc, "Z3_mk_string_symbol", {str});
232 return buildPtrAPICall(builder, loc, "Z3_mk_uninterpreted_sort",
233 {sym});
234 })
235 .Case([&](smt::ArrayType ty) {
236 return buildPtrAPICall(builder, loc, "Z3_mk_array_sort",
237 {buildSort(builder, loc, ty.getDomainType()),
238 buildSort(builder, loc, ty.getRangeType())});
239 });
240 }
241
242 SMTGlobalsHandler &globals;
243 const LowerSMTToZ3LLVMOptions &options;
244};
245
246//===----------------------------------------------------------------------===//
247// Lowering Patterns
248//===----------------------------------------------------------------------===//
249
250/// The 'smt.declare_fun' operation is used to declare both constants and
251/// functions. The Z3 API, however, uses two different functions. Therefore,
252/// depending on the result type of this operation, one of the following two
253/// API functions is used to create the symbolic value:
254/// ```
255/// Z3_ast Z3_API Z3_mk_fresh_const(Z3_context c, Z3_string prefix, Z3_sort ty);
256/// Z3_func_decl Z3_API Z3_mk_fresh_func_decl(
257/// Z3_context c, Z3_string prefix, unsigned domain_size,
258/// Z3_sort const domain[], Z3_sort range);
259/// ```
260struct DeclareFunOpLowering : public SMTLoweringPattern<DeclareFunOp> {
261 using SMTLoweringPattern::SMTLoweringPattern;
262
263 LogicalResult
264 matchAndRewrite(DeclareFunOp op, OpAdaptor adaptor,
265 ConversionPatternRewriter &rewriter) const final {
266 Location loc = op.getLoc();
267
268 // Create the name prefix.
269 Value prefix;
270 if (adaptor.getNamePrefix())
271 prefix = buildString(rewriter, loc, *adaptor.getNamePrefix());
272 else
273 prefix = LLVM::ZeroOp::create(rewriter, loc,
274 LLVM::LLVMPointerType::get(getContext()));
275
276 // Handle the constant value case.
277 if (!isa<SMTFuncType>(op.getType())) {
278 Value sort = buildSort(rewriter, loc, op.getType());
279 Value constDecl =
280 buildPtrAPICall(rewriter, loc, "Z3_mk_fresh_const", {prefix, sort});
281 rewriter.replaceOp(op, constDecl);
282 return success();
283 }
284
285 // Otherwise, we declare a function.
286 Type llvmPtrTy = LLVM::LLVMPointerType::get(getContext());
287 auto funcType = cast<SMTFuncType>(op.getResult().getType());
288 Value rangeSort = buildSort(rewriter, loc, funcType.getRangeType());
289
290 Type arrTy =
291 LLVM::LLVMArrayType::get(llvmPtrTy, funcType.getDomainTypes().size());
292
293 Value domain = LLVM::UndefOp::create(rewriter, loc, arrTy);
294 for (auto [i, ty] : llvm::enumerate(funcType.getDomainTypes())) {
295 Value sort = buildSort(rewriter, loc, ty);
296 domain = LLVM::InsertValueOp::create(rewriter, loc, domain, sort, i);
297 }
298
299 Value one =
300 LLVM::ConstantOp::create(rewriter, loc, rewriter.getI32Type(), 1);
301 Value domainStorage =
302 LLVM::AllocaOp::create(rewriter, loc, llvmPtrTy, arrTy, one);
303 LLVM::StoreOp::create(rewriter, loc, domain, domainStorage);
304
305 Value domainSize = LLVM::ConstantOp::create(
306 rewriter, loc, rewriter.getI32Type(), funcType.getDomainTypes().size());
307 Value decl =
308 buildPtrAPICall(rewriter, loc, "Z3_mk_fresh_func_decl",
309 {prefix, domainSize, domainStorage, rangeSort});
310
311 rewriter.replaceOp(op, decl);
312 return success();
313 }
314};
315
316/// Lower the 'smt.apply_func' operation to Z3 API calls of the form:
317/// ```
318/// Z3_ast Z3_API Z3_mk_app(Z3_context c, Z3_func_decl d,
319/// unsigned num_args, Z3_ast const args[]);
320/// ```
321struct ApplyFuncOpLowering : public SMTLoweringPattern<ApplyFuncOp> {
322 using SMTLoweringPattern::SMTLoweringPattern;
323
324 LogicalResult
325 matchAndRewrite(ApplyFuncOp op, OpAdaptor adaptor,
326 ConversionPatternRewriter &rewriter) const final {
327 Location loc = op.getLoc();
328 Type llvmPtrTy = LLVM::LLVMPointerType::get(getContext());
329 Type arrTy = LLVM::LLVMArrayType::get(llvmPtrTy, adaptor.getArgs().size());
330
331 // Create an array of the function arguments.
332 Value domain = LLVM::UndefOp::create(rewriter, loc, arrTy);
333 for (auto [i, arg] : llvm::enumerate(adaptor.getArgs()))
334 domain = LLVM::InsertValueOp::create(rewriter, loc, domain, arg, i);
335
336 // Store the array on the stack.
337 Value one =
338 LLVM::ConstantOp::create(rewriter, loc, rewriter.getI32Type(), 1);
339 Value domainStorage =
340 LLVM::AllocaOp::create(rewriter, loc, llvmPtrTy, arrTy, one);
341 LLVM::StoreOp::create(rewriter, loc, domain, domainStorage);
342
343 // Call the API function with a pointer to the function, the number of
344 // arguments, and the pointer to the arguments stored on the stack.
345 Value domainSize = LLVM::ConstantOp::create(
346 rewriter, loc, rewriter.getI32Type(), adaptor.getArgs().size());
347 Value returnVal =
348 buildPtrAPICall(rewriter, loc, "Z3_mk_app",
349 {adaptor.getFunc(), domainSize, domainStorage});
350 rewriter.replaceOp(op, returnVal);
351
352 return success();
353 }
354};
355
356/// Lower the `smt.bv.constant` operation to either
357/// ```
358/// Z3_ast Z3_API Z3_mk_unsigned_int64(Z3_context c, uint64_t v, Z3_sort ty);
359/// ```
360/// if the bit-vector fits into a 64-bit integer or convert it to a string and
361/// use the sligtly slower but arbitrary precision API function:
362/// ```
363/// Z3_ast Z3_API Z3_mk_numeral(Z3_context c, Z3_string numeral, Z3_sort ty);
364/// ```
365/// Note that there is also an API function taking an array of booleans, and
366/// while those are typically compiled to 'i8' in LLVM they don't necessarily
367/// have to (I think).
368struct BVConstantOpLowering : public SMTLoweringPattern<smt::BVConstantOp> {
369 using SMTLoweringPattern::SMTLoweringPattern;
370
371 LogicalResult
372 matchAndRewrite(smt::BVConstantOp op, OpAdaptor adaptor,
373 ConversionPatternRewriter &rewriter) const final {
374 Location loc = op.getLoc();
375 unsigned width = op.getType().getWidth();
376 auto bvSort = buildSort(rewriter, loc, op.getResult().getType());
377 APInt val = adaptor.getValue().getValue();
378
379 if (width <= 64) {
380 Value bvConst = LLVM::ConstantOp::create(
381 rewriter, loc, rewriter.getI64Type(), val.getZExtValue());
382 Value res = buildPtrAPICall(rewriter, loc, "Z3_mk_unsigned_int64",
383 {bvConst, bvSort});
384 rewriter.replaceOp(op, res);
385 return success();
386 }
387
388 std::string str;
389 llvm::raw_string_ostream stream(str);
390 stream << val;
391 Value bvString = buildString(rewriter, loc, str);
392 Value bvNumeral =
393 buildPtrAPICall(rewriter, loc, "Z3_mk_numeral", {bvString, bvSort});
394
395 rewriter.replaceOp(op, bvNumeral);
396 return success();
397 }
398};
399
400/// Some of the Z3 API supports a variadic number of operands for some
401/// operations (in particular if the expansion would lead to a super-linear
402/// increase in operations such as with the ':pairwise' attribute). Those API
403/// calls take an 'unsigned' argument indicating the size of an array of
404/// pointers to the operands.
405template <typename SourceTy>
406struct VariadicSMTPattern : public SMTLoweringPattern<SourceTy> {
407 using OpAdaptor = typename SMTLoweringPattern<SourceTy>::OpAdaptor;
408
409 VariadicSMTPattern(const TypeConverter &typeConverter, MLIRContext *context,
410 SMTGlobalsHandler &globals,
411 const LowerSMTToZ3LLVMOptions &options,
412 StringRef apiFuncName, unsigned minNumArgs)
413 : SMTLoweringPattern<SourceTy>(typeConverter, context, globals, options),
414 apiFuncName(apiFuncName), minNumArgs(minNumArgs) {}
415
416 LogicalResult
417 matchAndRewrite(SourceTy op, OpAdaptor adaptor,
418 ConversionPatternRewriter &rewriter) const final {
419 if (adaptor.getOperands().size() < minNumArgs)
420 return failure();
421
422 Location loc = op.getLoc();
423 Value numOperands = LLVM::ConstantOp::create(
424 rewriter, loc, rewriter.getI32Type(), op->getNumOperands());
425 Value constOne =
426 LLVM::ConstantOp::create(rewriter, loc, rewriter.getI32Type(), 1);
427 Type ptrTy = LLVM::LLVMPointerType::get(rewriter.getContext());
428 Type arrTy = LLVM::LLVMArrayType::get(ptrTy, op->getNumOperands());
429 Value storage =
430 LLVM::AllocaOp::create(rewriter, loc, ptrTy, arrTy, constOne);
431 Value array = LLVM::UndefOp::create(rewriter, loc, arrTy);
432
433 for (auto [i, operand] : llvm::enumerate(adaptor.getOperands()))
434 array = LLVM::InsertValueOp::create(rewriter, loc, array, operand,
435 ArrayRef<int64_t>{(int64_t)i});
436
437 LLVM::StoreOp::create(rewriter, loc, array, storage);
438
439 rewriter.replaceOp(op,
440 SMTLoweringPattern<SourceTy>::buildPtrAPICall(
441 rewriter, loc, apiFuncName, {numOperands, storage}));
442 return success();
443 }
444
445private:
446 StringRef apiFuncName;
447 unsigned minNumArgs;
448};
449
450/// Lower an SMT operation to a function call with the name 'apiFuncName' with
451/// arguments matching the operands one-to-one.
452template <typename SourceTy>
453struct OneToOneSMTPattern : public SMTLoweringPattern<SourceTy> {
454 using OpAdaptor = typename SMTLoweringPattern<SourceTy>::OpAdaptor;
455
456 OneToOneSMTPattern(const TypeConverter &typeConverter, MLIRContext *context,
457 SMTGlobalsHandler &globals,
458 const LowerSMTToZ3LLVMOptions &options,
459 StringRef apiFuncName, unsigned numOperands)
460 : SMTLoweringPattern<SourceTy>(typeConverter, context, globals, options),
461 apiFuncName(apiFuncName), numOperands(numOperands) {}
462
463 LogicalResult
464 matchAndRewrite(SourceTy op, OpAdaptor adaptor,
465 ConversionPatternRewriter &rewriter) const final {
466 if (adaptor.getOperands().size() != numOperands)
467 return failure();
468
469 rewriter.replaceOp(
470 op, SMTLoweringPattern<SourceTy>::buildPtrAPICall(
471 rewriter, op.getLoc(), apiFuncName, adaptor.getOperands()));
472 return success();
473 }
474
475private:
476 StringRef apiFuncName;
477 unsigned numOperands;
478};
479
480/// A pattern to lower SMT operations with a variadic number of operands
481/// modelling the ':chainable' attribute in SMT to binary operations.
482template <typename SourceTy>
483class LowerChainableSMTPattern : public SMTLoweringPattern<SourceTy> {
484 using SMTLoweringPattern<SourceTy>::SMTLoweringPattern;
485 using OpAdaptor = typename SMTLoweringPattern<SourceTy>::OpAdaptor;
486
487 LogicalResult
488 matchAndRewrite(SourceTy op, OpAdaptor adaptor,
489 ConversionPatternRewriter &rewriter) const final {
490 if (adaptor.getOperands().size() <= 2)
491 return failure();
492
493 Location loc = op.getLoc();
494 SmallVector<Value> elements;
495 for (int i = 1, e = adaptor.getOperands().size(); i < e; ++i) {
496 Value val = SourceTy::create(
497 rewriter, loc, op->getResultTypes(),
498 ValueRange{adaptor.getOperands()[i - 1], adaptor.getOperands()[i]});
499 elements.push_back(val);
500 }
501 rewriter.replaceOpWithNewOp<smt::AndOp>(op, elements);
502 return success();
503 }
504};
505
506/// A pattern to lower SMT operations with a variadic number of operands
507/// modelling the `:left-assoc` attribute to a sequence of binary operators.
508template <typename SourceTy>
509class LowerLeftAssocSMTPattern : public SMTLoweringPattern<SourceTy> {
510 using SMTLoweringPattern<SourceTy>::SMTLoweringPattern;
511 using OpAdaptor = typename SMTLoweringPattern<SourceTy>::OpAdaptor;
512
513 LogicalResult
514 matchAndRewrite(SourceTy op, OpAdaptor adaptor,
515 ConversionPatternRewriter &rewriter) const final {
516 if (adaptor.getOperands().size() <= 2)
517 return rewriter.notifyMatchFailure(op, "must have at least two operands");
518
519 Value runner = adaptor.getOperands()[0];
520 for (Value val : adaptor.getOperands().drop_front())
521 runner = SourceTy::create(rewriter, op.getLoc(), op->getResultTypes(),
522 ValueRange{runner, val});
523
524 rewriter.replaceOp(op, runner);
525 return success();
526 }
527};
528
529/// The 'smt.solver' operation has a region that corresponds to the lifetime of
530/// the Z3 context and one solver instance created within this context.
531/// To create a context, a Z3 configuration has to be built first and various
532/// configuration parameters can be set before creating a context from it. Once
533/// we have a context, we can create a solver and store a pointer to the context
534/// and the solver in an LLVM global such that operations in the child region
535/// have access to them. While the context created with `Z3_mk_context` takes
536/// care of the reference counting of `Z3_AST` objects, it still requires manual
537/// reference counting of `Z3_solver` objects, therefore, we need to increase
538/// the ref. counter of the solver we get from `Z3_mk_solver` and must decrease
539/// it again once we don't need it anymore. Finally, the configuration object
540/// can be deleted.
541/// ```
542/// Z3_config Z3_API Z3_mk_config(void);
543/// void Z3_API Z3_set_param_value(Z3_config c, Z3_string param_id,
544/// Z3_string param_value);
545/// Z3_context Z3_API Z3_mk_context(Z3_config c);
546/// Z3_solver Z3_API Z3_mk_solver(Z3_context c);
547/// void Z3_API Z3_solver_inc_ref(Z3_context c, Z3_solver s);
548/// void Z3_API Z3_del_config(Z3_config c);
549/// ```
550/// At the end of the solver lifetime, we have to tell the context that we
551/// don't need the solver anymore and delete the context itself.
552/// ```
553/// void Z3_API Z3_solver_dec_ref(Z3_context c, Z3_solver s);
554/// void Z3_API Z3_del_context(Z3_context c);
555/// ```
556/// Note that the solver created here is a combined solver. There might be some
557/// potential for optimization by creating more specialized solvers supported by
558/// the Z3 API according the the kind of operations present in the body region.
559struct SolverOpLowering : public SMTLoweringPattern<SolverOp> {
560 using SMTLoweringPattern::SMTLoweringPattern;
561
562 LogicalResult
563 matchAndRewrite(SolverOp op, OpAdaptor adaptor,
564 ConversionPatternRewriter &rewriter) const final {
565 Location loc = op.getLoc();
566 auto ptrTy = LLVM::LLVMPointerType::get(getContext());
567 auto voidTy = LLVM::LLVMVoidType::get(getContext());
568 auto ptrToPtrFunc = LLVM::LLVMFunctionType::get(ptrTy, ptrTy);
569 auto ptrPtrToPtrFunc = LLVM::LLVMFunctionType::get(ptrTy, {ptrTy, ptrTy});
570 auto ptrToVoidFunc = LLVM::LLVMFunctionType::get(voidTy, ptrTy);
571 auto ptrPtrToVoidFunc = LLVM::LLVMFunctionType::get(voidTy, {ptrTy, ptrTy});
572
573 // Create the configuration.
574 Value config = buildCall(rewriter, loc, "Z3_mk_config",
575 LLVM::LLVMFunctionType::get(ptrTy, {}), {})
576 .getResult();
577
578 // In debug-mode, we enable proofs such that we can fetch one in the 'unsat'
579 // region of each 'smt.check' operation.
580 if (options.debug) {
581 Value paramKey = buildString(rewriter, loc, "proof");
582 Value paramValue = buildString(rewriter, loc, "true");
583 buildCall(rewriter, loc, "Z3_set_param_value",
584 LLVM::LLVMFunctionType::get(voidTy, {ptrTy, ptrTy, ptrTy}),
585 {config, paramKey, paramValue});
586 }
587
588 // Check if the logic is set anywhere within the solver
589 std::optional<StringRef> logic = std::nullopt;
590 auto setLogicOps = op.getBodyRegion().getOps<smt::SetLogicOp>();
591 if (!setLogicOps.empty()) {
592 // We know from before patterns were applied that there is only one
593 // set_logic op
594 auto setLogicOp = *setLogicOps.begin();
595 logic = setLogicOp.getLogic();
596 rewriter.eraseOp(setLogicOp);
597 }
598
599 // Create the context and store a pointer to it in the global variable.
600 Value ctx = buildCall(rewriter, loc, "Z3_mk_context", ptrToPtrFunc, config)
601 .getResult();
602 Value ctxAddr =
603 LLVM::AddressOfOp::create(rewriter, loc, globals.ctx).getResult();
604 LLVM::StoreOp::create(rewriter, loc, ctx, ctxAddr);
605
606 // Delete the configuration again.
607 buildCall(rewriter, loc, "Z3_del_config", ptrToVoidFunc, {config});
608
609 // Create a solver instance, increase its reference counter, and store a
610 // pointer to it in the global variable.
611 Value solver;
612 if (logic) {
613 auto logicStr = buildString(rewriter, loc, logic.value());
614 solver = buildCall(rewriter, loc, "Z3_mk_solver_for_logic",
615 ptrPtrToPtrFunc, {ctx, logicStr})
616 ->getResult(0);
617 } else {
618 solver = buildCall(rewriter, loc, "Z3_mk_solver", ptrToPtrFunc, ctx)
619 ->getResult(0);
620 }
621 buildCall(rewriter, loc, "Z3_solver_inc_ref", ptrPtrToVoidFunc,
622 {ctx, solver});
623 Value solverAddr =
624 LLVM::AddressOfOp::create(rewriter, loc, globals.solver).getResult();
625 LLVM::StoreOp::create(rewriter, loc, solver, solverAddr);
626
627 // This assumes that no constant hoisting of the like happens inbetween
628 // the patterns defined in this pass because once the solver initialization
629 // and deallocation calls are inserted and the body region is inlined,
630 // canonicalizations and folders applied inbetween lowering patterns might
631 // hoist the SMT constants which means they would access uninitialized
632 // global variables once they are lowered.
633 SmallVector<Type> convertedTypes;
634 if (failed(
635 typeConverter->convertTypes(op->getResultTypes(), convertedTypes)))
636 return failure();
637
638 func::FuncOp funcOp;
639 {
640 OpBuilder::InsertionGuard guard(rewriter);
641 auto module = op->getParentOfType<ModuleOp>();
642 rewriter.setInsertionPointToEnd(module.getBody());
643
644 funcOp = func::FuncOp::create(
645 rewriter, loc, globals.names.newName("solver"),
646 rewriter.getFunctionType(adaptor.getInputs().getTypes(),
647 convertedTypes));
648 rewriter.inlineRegionBefore(op.getBodyRegion(), funcOp.getBody(),
649 funcOp.end());
650 }
651
652 ValueRange results =
653 func::CallOp::create(rewriter, loc, funcOp, adaptor.getInputs())
654 ->getResults();
655
656 // At the end of the region, decrease the solver's reference counter and
657 // delete the context.
658 // NOTE: we cannot use the convenience helper here because we don't want to
659 // load the context from the global but use the result from the 'mk_context'
660 // call directly for two reasons:
661 // * avoid an unnecessary load
662 // * the caching mechanism of the context does not work here because it
663 // would reuse the loaded context from a earlier solver
664 buildCall(rewriter, loc, "Z3_solver_dec_ref", ptrPtrToVoidFunc,
665 {ctx, solver});
666 buildCall(rewriter, loc, "Z3_del_context", ptrToVoidFunc, ctx);
667
668 rewriter.replaceOp(op, results);
669 return success();
670 }
671};
672
673/// Lower `smt.assert` operations to Z3 API calls of the form:
674/// ```
675/// void Z3_API Z3_solver_assert(Z3_context c, Z3_solver s, Z3_ast a);
676/// ```
677struct AssertOpLowering : public SMTLoweringPattern<AssertOp> {
678 using SMTLoweringPattern::SMTLoweringPattern;
679
680 LogicalResult
681 matchAndRewrite(AssertOp op, OpAdaptor adaptor,
682 ConversionPatternRewriter &rewriter) const final {
683 Location loc = op.getLoc();
684 buildAPICallWithContext(
685 rewriter, loc, "Z3_solver_assert",
686 LLVM::LLVMVoidType::get(getContext()),
687 {buildSolverPtr(rewriter, loc), adaptor.getInput()});
688
689 rewriter.eraseOp(op);
690 return success();
691 }
692};
693
694/// Lower `smt.reset` operations to Z3 API calls of the form:
695/// ```
696/// void Z3_API Z3_solver_reset(Z3_context c, Z3_solver s);
697/// ```
698struct ResetOpLowering : public SMTLoweringPattern<ResetOp> {
699 using SMTLoweringPattern::SMTLoweringPattern;
700
701 LogicalResult
702 matchAndRewrite(ResetOp op, OpAdaptor adaptor,
703 ConversionPatternRewriter &rewriter) const final {
704 Location loc = op.getLoc();
705 buildAPICallWithContext(rewriter, loc, "Z3_solver_reset",
706 LLVM::LLVMVoidType::get(getContext()),
707 {buildSolverPtr(rewriter, loc)});
708
709 rewriter.eraseOp(op);
710 return success();
711 }
712};
713
714/// Lower `smt.push` operations to (repeated) Z3 API calls of the form:
715/// ```
716/// void Z3_API Z3_solver_push(Z3_context c, Z3_solver s);
717/// ```
718struct PushOpLowering : public SMTLoweringPattern<PushOp> {
719 using SMTLoweringPattern::SMTLoweringPattern;
720 LogicalResult
721 matchAndRewrite(PushOp op, OpAdaptor adaptor,
722 ConversionPatternRewriter &rewriter) const final {
723 Location loc = op.getLoc();
724 // SMTLIB allows multiple levels to be pushed with one push command, but the
725 // Z3 C API doesn't let you provide a number of levels for push calls so
726 // multiple calls have to be created.
727 for (uint32_t i = 0; i < op.getCount(); i++)
728 buildAPICallWithContext(rewriter, loc, "Z3_solver_push",
729 LLVM::LLVMVoidType::get(getContext()),
730 {buildSolverPtr(rewriter, loc)});
731 rewriter.eraseOp(op);
732 return success();
733 }
734};
735
736/// Lower `smt.pop` operations to Z3 API calls of the form:
737/// ```
738/// void Z3_API Z3_solver_pop(Z3_context c, Z3_solver s, unsigned n);
739/// ```
740struct PopOpLowering : public SMTLoweringPattern<PopOp> {
741 using SMTLoweringPattern::SMTLoweringPattern;
742 LogicalResult
743 matchAndRewrite(PopOp op, OpAdaptor adaptor,
744 ConversionPatternRewriter &rewriter) const final {
745 Location loc = op.getLoc();
746 Value constVal = LLVM::ConstantOp::create(
747 rewriter, loc, rewriter.getI32Type(), op.getCount());
748 buildAPICallWithContext(rewriter, loc, "Z3_solver_pop",
749 LLVM::LLVMVoidType::get(getContext()),
750 {buildSolverPtr(rewriter, loc), constVal});
751 rewriter.eraseOp(op);
752 return success();
753 }
754};
755
756/// Lower `smt.yield` operations to `scf.yield` operations. This not necessary
757/// for the yield in `smt.solver` or in quantifiers since they are deleted
758/// directly by the parent operation, but makes the lowering of the `smt.check`
759/// operation simpler and more convenient since the regions get translated
760/// directly to regions of `scf.if` operations.
761struct YieldOpLowering : public SMTLoweringPattern<YieldOp> {
762 using SMTLoweringPattern::SMTLoweringPattern;
763
764 LogicalResult
765 matchAndRewrite(YieldOp op, OpAdaptor adaptor,
766 ConversionPatternRewriter &rewriter) const final {
767 if (op->getParentOfType<func::FuncOp>()) {
768 rewriter.replaceOpWithNewOp<func::ReturnOp>(op, adaptor.getValues());
769 return success();
770 }
771 if (op->getParentOfType<LLVM::LLVMFuncOp>()) {
772 rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(op, adaptor.getValues());
773 return success();
774 }
775 if (isa_and_nonnull<scf::SCFDialect>(op->getParentOp()->getDialect())) {
776 rewriter.replaceOpWithNewOp<scf::YieldOp>(op, adaptor.getValues());
777 return success();
778 }
779 return failure();
780 }
781};
782
783/// Lower `smt.check` operations to Z3 API calls and control-flow operations.
784/// ```
785/// Z3_lbool Z3_API Z3_solver_check(Z3_context c, Z3_solver s);
786///
787/// typedef enum
788/// {
789/// Z3_L_FALSE = -1, // means unsatisfiable here
790/// Z3_L_UNDEF, // means unknown here
791/// Z3_L_TRUE // means satisfiable here
792/// } Z3_lbool;
793/// ```
794struct CheckOpLowering : public SMTLoweringPattern<CheckOp> {
795 using SMTLoweringPattern::SMTLoweringPattern;
796
797 LogicalResult
798 matchAndRewrite(CheckOp op, OpAdaptor adaptor,
799 ConversionPatternRewriter &rewriter) const final {
800 Location loc = op.getLoc();
801 auto ptrTy = LLVM::LLVMPointerType::get(rewriter.getContext());
802 auto printfType = LLVM::LLVMFunctionType::get(
803 LLVM::LLVMVoidType::get(rewriter.getContext()), {ptrTy}, true);
804
805 auto getHeaderString = [](const std::string &title) {
806 unsigned titleSize = title.size() + 2; // Add a space left and right
807 return std::string((80 - titleSize) / 2, '-') + " " + title + " " +
808 std::string((80 - titleSize + 1) / 2, '-') + "\n%s\n" +
809 std::string(80, '-') + "\n";
810 };
811
812 // Get the pointer to the solver instance.
813 Value solver = buildSolverPtr(rewriter, loc);
814
815 // In debug-mode, print the state of the solver before calling 'check-sat'
816 // on it. This prints the asserted SMT expressions.
817 if (options.debug) {
818 auto solverStringPtr =
819 buildPtrAPICall(rewriter, loc, "Z3_solver_to_string", {solver});
820 auto solverFormatString =
821 buildString(rewriter, loc, getHeaderString("Solver"));
822 buildCall(rewriter, op.getLoc(), "printf", printfType,
823 {solverFormatString, solverStringPtr});
824 }
825
826 // Convert the result types of the `smt.check` operation.
827 SmallVector<Type> resultTypes;
828 if (failed(typeConverter->convertTypes(op->getResultTypes(), resultTypes)))
829 return failure();
830
831 // Call 'check-sat' and check if the assertions are satisfiable.
832 Value checkResult =
833 buildAPICallWithContext(rewriter, loc, "Z3_solver_check",
834 rewriter.getI32Type(), {solver})
835 ->getResult(0);
836 Value constOne =
837 LLVM::ConstantOp::create(rewriter, loc, checkResult.getType(), 1);
838 Value isSat = LLVM::ICmpOp::create(rewriter, loc, LLVM::ICmpPredicate::eq,
839 checkResult, constOne);
840
841 // Simply inline the 'sat' region into the 'then' region of the 'scf.if'
842 auto satIfOp = scf::IfOp::create(rewriter, loc, resultTypes, isSat);
843 rewriter.inlineRegionBefore(op.getSatRegion(), satIfOp.getThenRegion(),
844 satIfOp.getThenRegion().end());
845
846 // Otherwise, the 'else' block checks if the assertions are unsatisfiable or
847 // unknown. The corresponding regions can also be simply inlined into the
848 // two branches of this nested if-statement as well.
849 rewriter.createBlock(&satIfOp.getElseRegion());
850 Value constNegOne =
851 LLVM::ConstantOp::create(rewriter, loc, checkResult.getType(), -1);
852 Value isUnsat = LLVM::ICmpOp::create(rewriter, loc, LLVM::ICmpPredicate::eq,
853 checkResult, constNegOne);
854 auto unsatIfOp = scf::IfOp::create(rewriter, loc, resultTypes, isUnsat);
855 scf::YieldOp::create(rewriter, loc, unsatIfOp->getResults());
856
857 rewriter.inlineRegionBefore(op.getUnsatRegion(), unsatIfOp.getThenRegion(),
858 unsatIfOp.getThenRegion().end());
859 rewriter.inlineRegionBefore(op.getUnknownRegion(),
860 unsatIfOp.getElseRegion(),
861 unsatIfOp.getElseRegion().end());
862
863 rewriter.replaceOp(op, satIfOp->getResults());
864
865 if (options.debug) {
866 // In debug-mode, if the assertions are unsatisfiable we can print the
867 // proof.
868 rewriter.setInsertionPointToStart(unsatIfOp.thenBlock());
869 auto proof = buildPtrAPICall(rewriter, op.getLoc(), "Z3_solver_get_proof",
870 {solver});
871 auto stringPtr =
872 buildPtrAPICall(rewriter, op.getLoc(), "Z3_ast_to_string", {proof});
873 auto formatString =
874 buildString(rewriter, op.getLoc(), getHeaderString("Proof"));
875 buildCall(rewriter, op.getLoc(), "printf", printfType,
876 {formatString, stringPtr});
877
878 // In debug mode, if the assertions are satisfiable we can print the model
879 // (effectively a counter-example).
880 rewriter.setInsertionPointToStart(satIfOp.thenBlock());
881 auto model = buildPtrAPICall(rewriter, op.getLoc(), "Z3_solver_get_model",
882 {solver});
883 auto modelStringPtr =
884 buildPtrAPICall(rewriter, op.getLoc(), "Z3_model_to_string", {model});
885 auto modelFormatString =
886 buildString(rewriter, op.getLoc(), getHeaderString("Model"));
887 buildCall(rewriter, op.getLoc(), "printf", printfType,
888 {modelFormatString, modelStringPtr});
889 }
890
891 return success();
892 }
893};
894
895/// Lower `smt.forall` and `smt.exists` operations to the following Z3 API call.
896/// ```
897/// Z3_ast Z3_API Z3_mk_{forall|exists}_const(
898/// Z3_context c,
899/// unsigned weight,
900/// unsigned num_bound,
901/// Z3_app const bound[],
902/// unsigned num_patterns,
903/// Z3_pattern const patterns[],
904/// Z3_ast body
905/// );
906/// ```
907/// All nested regions are inlined into the parent region and the block
908/// arguments are replaced with new `smt.declare_fun` constants that are also
909/// passed to the `bound` argument of above API function. Patterns are created
910/// with the following API function.
911/// ```
912/// Z3_pattern Z3_API Z3_mk_pattern(Z3_context c, unsigned num_patterns,
913/// Z3_ast const terms[]);
914/// ```
915/// Where each operand of the `smt.yield` in a pattern region is a 'term'.
916template <typename QuantifierOp>
917struct QuantifierLowering : public SMTLoweringPattern<QuantifierOp> {
918 using SMTLoweringPattern<QuantifierOp>::SMTLoweringPattern;
919 using SMTLoweringPattern<QuantifierOp>::typeConverter;
920 using SMTLoweringPattern<QuantifierOp>::buildPtrAPICall;
921 using OpAdaptor = typename QuantifierOp::Adaptor;
922
923 Value createStorageForValueList(ValueRange values, Location loc,
924 ConversionPatternRewriter &rewriter) const {
925 Type ptrTy = LLVM::LLVMPointerType::get(rewriter.getContext());
926 Type arrTy = LLVM::LLVMArrayType::get(ptrTy, values.size());
927 Value constOne =
928 LLVM::ConstantOp::create(rewriter, loc, rewriter.getI32Type(), 1);
929 Value storage =
930 LLVM::AllocaOp::create(rewriter, loc, ptrTy, arrTy, constOne);
931 Value array = LLVM::UndefOp::create(rewriter, loc, arrTy);
932
933 for (auto [i, val] : llvm::enumerate(values))
934 array = LLVM::InsertValueOp::create(rewriter, loc, array, val,
935 ArrayRef<int64_t>(i));
936
937 LLVM::StoreOp::create(rewriter, loc, array, storage);
938
939 return storage;
940 }
941
942 LogicalResult
943 matchAndRewrite(QuantifierOp op, OpAdaptor adaptor,
944 ConversionPatternRewriter &rewriter) const final {
945 Location loc = op.getLoc();
946 Type ptrTy = LLVM::LLVMPointerType::get(rewriter.getContext());
947
948 // no-pattern attribute not supported yet because the Z3 CAPI allows more
949 // fine-grained control where a list of patterns to be banned can be given.
950 // This means, the no-pattern attribute is equivalent to providing a list of
951 // all possible sub-expressions in the quantifier body to the CAPI.
952 if (adaptor.getNoPattern())
953 return rewriter.notifyMatchFailure(
954 op, "no-pattern attribute not yet supported!");
955
956 rewriter.setInsertionPoint(op);
957
958 // Weight attribute
959 Value weight = LLVM::ConstantOp::create(
960 rewriter, loc, rewriter.getI32Type(), adaptor.getWeight());
961
962 // Bound variables
963 unsigned numDecls = op.getBody().getNumArguments();
964 Value numDeclsVal = LLVM::ConstantOp::create(
965 rewriter, loc, rewriter.getI32Type(), numDecls);
966
967 // We replace the block arguments with constant symbolic values and inform
968 // the quantifier API call which constants it should treat as bound
969 // variables. We also need to make sure that we use the exact same SSA
970 // values in the pattern regions since we lower constant declaration
971 // operation to always produce fresh constants.
972 SmallVector<Value> repl;
973 for (auto [i, arg] : llvm::enumerate(op.getBody().getArguments())) {
974 Value newArg;
975 if (adaptor.getBoundVarNames().has_value())
976 newArg = smt::DeclareFunOp::create(
977 rewriter, loc, arg.getType(),
978 cast<StringAttr>((*adaptor.getBoundVarNames())[i]));
979 else
980 newArg = smt::DeclareFunOp::create(rewriter, loc, arg.getType());
981 repl.push_back(typeConverter->materializeTargetConversion(
982 rewriter, loc, typeConverter->convertType(arg.getType()), newArg));
983 }
984
985 Value boundStorage = createStorageForValueList(repl, loc, rewriter);
986
987 // Body Expression
988 auto yieldOp = cast<smt::YieldOp>(op.getBody().front().getTerminator());
989 Value bodyExp = yieldOp.getValues()[0];
990 rewriter.setInsertionPointAfterValue(bodyExp);
991 bodyExp = typeConverter->materializeTargetConversion(
992 rewriter, loc, typeConverter->convertType(bodyExp.getType()), bodyExp);
993 rewriter.eraseOp(yieldOp);
994
995 rewriter.inlineBlockBefore(&op.getBody().front(), op, repl);
996 rewriter.setInsertionPoint(op);
997
998 // Patterns
999 unsigned numPatterns = adaptor.getPatterns().size();
1000 Value numPatternsVal = LLVM::ConstantOp::create(
1001 rewriter, loc, rewriter.getI32Type(), numPatterns);
1002
1003 Value patternStorage;
1004 if (numPatterns > 0) {
1005 SmallVector<Value> patterns;
1006 for (Region *patternRegion : adaptor.getPatterns()) {
1007 auto yieldOp =
1008 cast<smt::YieldOp>(patternRegion->front().getTerminator());
1009 auto patternTerms = yieldOp.getOperands();
1010
1011 rewriter.setInsertionPoint(yieldOp);
1012 SmallVector<Value> patternList;
1013 for (auto val : patternTerms)
1014 patternList.push_back(typeConverter->materializeTargetConversion(
1015 rewriter, loc, typeConverter->convertType(val.getType()), val));
1016
1017 rewriter.eraseOp(yieldOp);
1018 rewriter.inlineBlockBefore(&patternRegion->front(), op, repl);
1019
1020 rewriter.setInsertionPoint(op);
1021 Value numTerms = LLVM::ConstantOp::create(
1022 rewriter, loc, rewriter.getI32Type(), patternTerms.size());
1023 Value patternTermStorage =
1024 createStorageForValueList(patternList, loc, rewriter);
1025 Value pattern = buildPtrAPICall(rewriter, loc, "Z3_mk_pattern",
1026 {numTerms, patternTermStorage});
1027
1028 patterns.emplace_back(pattern);
1029 }
1030 patternStorage = createStorageForValueList(patterns, loc, rewriter);
1031 } else {
1032 // If we set the num_patterns parameter to 0, we can just pass a nullptr
1033 // as storage.
1034 patternStorage = LLVM::ZeroOp::create(rewriter, loc, ptrTy);
1035 }
1036
1037 StringRef apiCallName = "Z3_mk_forall_const";
1038 if (std::is_same_v<QuantifierOp, ExistsOp>)
1039 apiCallName = "Z3_mk_exists_const";
1040 Value quantifierExp =
1041 buildPtrAPICall(rewriter, loc, apiCallName,
1042 {weight, numDeclsVal, boundStorage, numPatternsVal,
1043 patternStorage, bodyExp});
1044
1045 rewriter.replaceOp(op, quantifierExp);
1046 return success();
1047 }
1048};
1049
1050/// Lower `smt.bv.repeat` operations to Z3 API function calls of the form
1051/// ```
1052/// Z3_ast Z3_API Z3_mk_repeat(Z3_context c, unsigned i, Z3_ast t1);
1053/// ```
1054struct RepeatOpLowering : public SMTLoweringPattern<RepeatOp> {
1055 using SMTLoweringPattern::SMTLoweringPattern;
1056
1057 LogicalResult
1058 matchAndRewrite(RepeatOp op, OpAdaptor adaptor,
1059 ConversionPatternRewriter &rewriter) const final {
1060 Value count = LLVM::ConstantOp::create(
1061 rewriter, op.getLoc(), rewriter.getI32Type(), op.getCount());
1062 rewriter.replaceOp(op,
1063 buildPtrAPICall(rewriter, op.getLoc(), "Z3_mk_repeat",
1064 {count, adaptor.getInput()}));
1065 return success();
1066 }
1067};
1068
1069/// Lower `smt.bv.extract` operations to Z3 API function calls of the following
1070/// form, where the output bit-vector has size `n = high - low + 1`. This means,
1071/// both the 'high' and 'low' indices are inclusive.
1072/// ```
1073/// Z3_ast Z3_API Z3_mk_extract(Z3_context c, unsigned high, unsigned low,
1074/// Z3_ast t1);
1075/// ```
1076struct ExtractOpLowering : public SMTLoweringPattern<ExtractOp> {
1077 using SMTLoweringPattern::SMTLoweringPattern;
1078
1079 LogicalResult
1080 matchAndRewrite(ExtractOp op, OpAdaptor adaptor,
1081 ConversionPatternRewriter &rewriter) const final {
1082 Location loc = op.getLoc();
1083 Value low = LLVM::ConstantOp::create(rewriter, loc, rewriter.getI32Type(),
1084 adaptor.getLowBit());
1085 Value high = LLVM::ConstantOp::create(rewriter, loc, rewriter.getI32Type(),
1086 adaptor.getLowBit() +
1087 op.getType().getWidth() - 1);
1088 rewriter.replaceOp(op, buildPtrAPICall(rewriter, loc, "Z3_mk_extract",
1089 {high, low, adaptor.getInput()}));
1090 return success();
1091 }
1092};
1093
1094/// Lower `smt.array.broadcast` operations to Z3 API function calls of the form
1095/// ```
1096/// Z3_ast Z3_API Z3_mk_const_array(Z3_context c, Z3_sort domain, Z3_ast v);
1097/// ```
1098struct ArrayBroadcastOpLowering
1099 : public SMTLoweringPattern<smt::ArrayBroadcastOp> {
1100 using SMTLoweringPattern::SMTLoweringPattern;
1101
1102 LogicalResult
1103 matchAndRewrite(smt::ArrayBroadcastOp op, OpAdaptor adaptor,
1104 ConversionPatternRewriter &rewriter) const final {
1105 auto domainSort = buildSort(
1106 rewriter, op.getLoc(),
1107 cast<smt::ArrayType>(op.getResult().getType()).getDomainType());
1108
1109 rewriter.replaceOp(op, buildPtrAPICall(rewriter, op.getLoc(),
1110 "Z3_mk_const_array",
1111 {domainSort, adaptor.getValue()}));
1112 return success();
1113 }
1114};
1115
1116/// Lower the `smt.constant` operation to one of the following Z3 API function
1117/// calls depending on the value of the boolean attribute.
1118/// ```
1119/// Z3_ast Z3_API Z3_mk_true(Z3_context c);
1120/// Z3_ast Z3_API Z3_mk_false(Z3_context c);
1121/// ```
1122struct BoolConstantOpLowering : public SMTLoweringPattern<smt::BoolConstantOp> {
1123 using SMTLoweringPattern::SMTLoweringPattern;
1124
1125 LogicalResult
1126 matchAndRewrite(smt::BoolConstantOp op, OpAdaptor adaptor,
1127 ConversionPatternRewriter &rewriter) const final {
1128 rewriter.replaceOp(
1129 op, buildPtrAPICall(rewriter, op.getLoc(),
1130 adaptor.getValue() ? "Z3_mk_true" : "Z3_mk_false"));
1131 return success();
1132 }
1133};
1134
1135/// Lower `smt.int.constant` operations to one of the following two Z3 API
1136/// function calls depending on whether the storage APInt has a bit-width that
1137/// fits in a `uint64_t`.
1138/// ```
1139/// Z3_sort Z3_API Z3_mk_int_sort(Z3_context c);
1140///
1141/// Z3_ast Z3_API Z3_mk_int64(Z3_context c, int64_t v, Z3_sort ty);
1142///
1143/// Z3_ast Z3_API Z3_mk_numeral(Z3_context c, Z3_string numeral, Z3_sort ty);
1144/// Z3_ast Z3_API Z3_mk_unary_minus(Z3_context c, Z3_ast arg);
1145/// ```
1146struct IntConstantOpLowering : public SMTLoweringPattern<smt::IntConstantOp> {
1147 using SMTLoweringPattern::SMTLoweringPattern;
1148
1149 LogicalResult
1150 matchAndRewrite(smt::IntConstantOp op, OpAdaptor adaptor,
1151 ConversionPatternRewriter &rewriter) const final {
1152 Location loc = op.getLoc();
1153 Value type = buildPtrAPICall(rewriter, loc, "Z3_mk_int_sort");
1154 if (adaptor.getValue().getBitWidth() <= 64) {
1155 Value val = LLVM::ConstantOp::create(rewriter, loc, rewriter.getI64Type(),
1156 adaptor.getValue().getSExtValue());
1157 rewriter.replaceOp(
1158 op, buildPtrAPICall(rewriter, loc, "Z3_mk_int64", {val, type}));
1159 return success();
1160 }
1161
1162 std::string numeralStr;
1163 llvm::raw_string_ostream stream(numeralStr);
1164 stream << adaptor.getValue().abs();
1165
1166 Value numeral = buildString(rewriter, loc, numeralStr);
1167 Value intNumeral =
1168 buildPtrAPICall(rewriter, loc, "Z3_mk_numeral", {numeral, type});
1169
1170 if (adaptor.getValue().isNegative())
1171 intNumeral =
1172 buildPtrAPICall(rewriter, loc, "Z3_mk_unary_minus", intNumeral);
1173
1174 rewriter.replaceOp(op, intNumeral);
1175 return success();
1176 }
1177};
1178
1179/// Lower `smt.int.cmp` operations to one of the following Z3 API function calls
1180/// depending on the predicate.
1181/// ```
1182/// Z3_ast Z3_API Z3_mk_{{pred}}(Z3_context c, Z3_ast t1, Z3_ast t2);
1183/// ```
1184struct IntCmpOpLowering : public SMTLoweringPattern<IntCmpOp> {
1185 using SMTLoweringPattern::SMTLoweringPattern;
1186
1187 LogicalResult
1188 matchAndRewrite(IntCmpOp op, OpAdaptor adaptor,
1189 ConversionPatternRewriter &rewriter) const final {
1190 rewriter.replaceOp(
1191 op,
1192 buildPtrAPICall(rewriter, op.getLoc(),
1193 "Z3_mk_" + stringifyIntPredicate(op.getPred()).str(),
1194 {adaptor.getLhs(), adaptor.getRhs()}));
1195 return success();
1196 }
1197};
1198
1199/// Lower `smt.int2bv` operations to the following Z3 API function calls.
1200/// ```
1201/// Z3_ast Z3_API Z3_mk_int2bv(Z3_context c, unsigned n, Z3_ast t1);
1202/// ```
1203struct Int2BVOpLowering : public SMTLoweringPattern<Int2BVOp> {
1204 using SMTLoweringPattern::SMTLoweringPattern;
1205
1206 LogicalResult
1207 matchAndRewrite(Int2BVOp op, OpAdaptor adaptor,
1208 ConversionPatternRewriter &rewriter) const final {
1209 Value widthConst =
1210 LLVM::ConstantOp::create(rewriter, op->getLoc(), rewriter.getI32Type(),
1211 op.getResult().getType().getWidth());
1212 rewriter.replaceOp(op,
1213 buildPtrAPICall(rewriter, op.getLoc(), "Z3_mk_int2bv",
1214 {widthConst, adaptor.getInput()}));
1215 return success();
1216 }
1217};
1218
1219/// Lower `smt.bv2int` operations to the following Z3 API function call.
1220/// ```
1221/// Z3_ast Z3_API Z3_mk_bv2int(Z3_context c, Z3_ast t1, bool is_signed)
1222/// ```
1223struct BV2IntOpLowering : public SMTLoweringPattern<BV2IntOp> {
1224 using SMTLoweringPattern::SMTLoweringPattern;
1225
1226 LogicalResult
1227 matchAndRewrite(BV2IntOp op, OpAdaptor adaptor,
1228 ConversionPatternRewriter &rewriter) const final {
1229 // FIXME: ideally we don't want to use i1 here, since bools can sometimes be
1230 // compiled to wider widths in LLVM
1231 Value isSignedConst = LLVM::ConstantOp::create(
1232 rewriter, op->getLoc(), rewriter.getI1Type(), op.getIsSigned());
1233 rewriter.replaceOp(op,
1234 buildPtrAPICall(rewriter, op.getLoc(), "Z3_mk_bv2int",
1235 {adaptor.getInput(), isSignedConst}));
1236 return success();
1237 }
1238};
1239
1240/// Lower `smt.bv.cmp` operations to one of the following Z3 API function calls,
1241/// performing two's complement comparison, depending on the predicate
1242/// attribute.
1243/// ```
1244/// Z3_ast Z3_API Z3_mk_bv{{pred}}(Z3_context c, Z3_ast t1, Z3_ast t2);
1245/// ```
1246struct BVCmpOpLowering : public SMTLoweringPattern<BVCmpOp> {
1247 using SMTLoweringPattern::SMTLoweringPattern;
1248
1249 LogicalResult
1250 matchAndRewrite(BVCmpOp op, OpAdaptor adaptor,
1251 ConversionPatternRewriter &rewriter) const final {
1252 rewriter.replaceOp(
1253 op, buildPtrAPICall(rewriter, op.getLoc(),
1254 "Z3_mk_bv" +
1255 stringifyBVCmpPredicate(op.getPred()).str(),
1256 {adaptor.getLhs(), adaptor.getRhs()}));
1257 return success();
1258 }
1259};
1260
1261/// Expand the `smt.int.abs` operation to a `smt.ite` operation.
1262struct IntAbsOpLowering : public SMTLoweringPattern<IntAbsOp> {
1263 using SMTLoweringPattern::SMTLoweringPattern;
1264
1265 LogicalResult
1266 matchAndRewrite(IntAbsOp op, OpAdaptor adaptor,
1267 ConversionPatternRewriter &rewriter) const final {
1268 Location loc = op.getLoc();
1269 Value zero = IntConstantOp::create(
1270 rewriter, loc, rewriter.getIntegerAttr(rewriter.getI1Type(), 0));
1271 Value cmp = IntCmpOp::create(rewriter, loc, IntPredicate::lt,
1272 adaptor.getInput(), zero);
1273 Value neg = IntSubOp::create(rewriter, loc, zero, adaptor.getInput());
1274 rewriter.replaceOpWithNewOp<IteOp>(op, cmp, neg, adaptor.getInput());
1275 return success();
1276 }
1277};
1278
1279//===----------------------------------------------------------------------===//
1280// Placeholder Debug Patterns
1281//===----------------------------------------------------------------------===//
// For now, debug variable and scope ops are simply dropped; eventually this
// debug info should be passed on to Z3.
1284
1285/// Strip dbg.variable ops.
1286struct DbgVariableLowering : public OpConversionPattern<debug::VariableOp> {
1287 using OpConversionPattern::OpConversionPattern;
1288
1289 LogicalResult
1290 matchAndRewrite(debug::VariableOp op, OpAdaptor adaptor,
1291 ConversionPatternRewriter &rewriter) const final {
1292 rewriter.eraseOp(op);
1293 return success();
1294 }
1295};
1296
1297/// Strip dbg.scope ops.
1298struct DbgScopeLowering : public OpConversionPattern<debug::ScopeOp> {
1299 using OpConversionPattern::OpConversionPattern;
1300
1301 LogicalResult
1302 matchAndRewrite(debug::ScopeOp op, OpAdaptor adaptor,
1303 ConversionPatternRewriter &rewriter) const final {
1304 // Make sure scope's only users are variables and therefore being deleted
1305 if (llvm::any_of(op->getUsers(), [](Operation *user) {
1306 return !isa<debug::VariableOp>(user);
1307 }))
1308 return failure();
1309 rewriter.eraseOp(op);
1310 return success();
1311 }
1312};
1313
1314} // namespace
1315
1316//===----------------------------------------------------------------------===//
1317// Pass Implementation
1318//===----------------------------------------------------------------------===//
1319
1320namespace {
1321struct LowerSMTToZ3LLVMPass
1322 : public circt::impl::LowerSMTToZ3LLVMBase<LowerSMTToZ3LLVMPass> {
1323 using Base::Base;
1324 void runOnOperation() override;
1325};
1326} // namespace
1327
1328void circt::populateSMTToZ3LLVMTypeConverter(TypeConverter &converter) {
1329 converter.addConversion([](smt::BoolType type) {
1330 return LLVM::LLVMPointerType::get(type.getContext());
1331 });
1332 converter.addConversion([](smt::BitVectorType type) {
1333 return LLVM::LLVMPointerType::get(type.getContext());
1334 });
1335 converter.addConversion([](smt::ArrayType type) {
1336 return LLVM::LLVMPointerType::get(type.getContext());
1337 });
1338 converter.addConversion([](smt::IntType type) {
1339 return LLVM::LLVMPointerType::get(type.getContext());
1340 });
1341 converter.addConversion([](smt::SMTFuncType type) {
1342 return LLVM::LLVMPointerType::get(type.getContext());
1343 });
1344 converter.addConversion([](smt::SortType type) {
1345 return LLVM::LLVMPointerType::get(type.getContext());
1346 });
1347}
1348
1350 RewritePatternSet &patterns, TypeConverter &converter,
1351 SMTGlobalsHandler &globals, const LowerSMTToZ3LLVMOptions &options) {
1352#define ADD_VARIADIC_PATTERN(OP, APINAME, MIN_NUM_ARGS) \
1353 patterns.add<VariadicSMTPattern<OP>>(/*NOLINT(bugprone-macro-parentheses)*/ \
1354 converter, patterns.getContext(), \
1355 globals, options, APINAME, \
1356 MIN_NUM_ARGS);
1357
1358#define ADD_ONE_TO_ONE_PATTERN(OP, APINAME, NUM_ARGS) \
1359 patterns.add<OneToOneSMTPattern<OP>>(/*NOLINT(bugprone-macro-parentheses)*/ \
1360 converter, patterns.getContext(), \
1361 globals, options, APINAME, NUM_ARGS);
1362
1363 // Lower `smt.distinct` operations which allows a variadic number of operands
1364 // according to the `:pairwise` attribute. The Z3 API function supports a
1365 // variadic number of operands as well, i.e., a direct lowering is possible:
1366 // ```
1367 // Z3_ast Z3_API Z3_mk_distinct(Z3_context c, unsigned num_args, Z3_ast const
1368 // args[])
1369 // ```
1370 // The API function requires num_args > 1 which is guaranteed to be satisfied
1371 // because `smt.distinct` is verified to have > 1 operands.
1372 ADD_VARIADIC_PATTERN(DistinctOp, "Z3_mk_distinct", 2);
1373
1374 // Lower `smt.and` operations which allows a variadic number of operands
1375 // according to the `:left-assoc` attribute. The Z3 API function supports a
1376 // variadic number of operands as well, i.e., a direct lowering is possible:
1377 // ```
1378 // Z3_ast Z3_API Z3_mk_and(Z3_context c, unsigned num_args, Z3_ast const
1379 // args[])
1380 // ```
1381 // The API function requires num_args > 1. This is not guaranteed by the
1382 // `smt.and` operation and thus the pattern will not apply when no operand is
1383 // present. The constant folder of the operation is assumed to fold this to
1384 // a constant 'true' (neutral element of AND).
1385 ADD_VARIADIC_PATTERN(AndOp, "Z3_mk_and", 2);
1386
1387 // Lower `smt.or` operations which allows a variadic number of operands
1388 // according to the `:left-assoc` attribute. The Z3 API function supports a
1389 // variadic number of operands as well, i.e., a direct lowering is possible:
1390 // ```
1391 // Z3_ast Z3_API Z3_mk_or(Z3_context c, unsigned num_args, Z3_ast const
1392 // args[])
1393 // ```
1394 // The API function requires num_args > 1. This is not guaranteed by the
1395 // `smt.or` operation and thus the pattern will not apply when no operand is
1396 // present. The constant folder of the operation is assumed to fold this to
1397 // a constant 'false' (neutral element of OR).
1398 ADD_VARIADIC_PATTERN(OrOp, "Z3_mk_or", 2);
1399
1400 // Lower `smt.not` operations to the following Z3 API function:
1401 // ```
1402 // Z3_ast Z3_API Z3_mk_not(Z3_context c, Z3_ast a);
1403 // ```
1404 ADD_ONE_TO_ONE_PATTERN(NotOp, "Z3_mk_not", 1);
1405
1406 // Lower `smt.xor` operations which allows a variadic number of operands
1407 // according to the `:left-assoc` attribute. The Z3 API function, however,
1408 // only takes two operands.
1409 // ```
1410 // Z3_ast Z3_API Z3_mk_xor(Z3_context c, Z3_ast t1, Z3_ast t2);
1411 // ```
1412 // Therefore, we need to decompose the operation first to a sequence of XOR
1413 // operations matching the left associative behavior.
1414 patterns.add<LowerLeftAssocSMTPattern<XOrOp>>(
1415 converter, patterns.getContext(), globals, options);
1416 ADD_ONE_TO_ONE_PATTERN(XOrOp, "Z3_mk_xor", 2);
1417
1418 // Lower `smt.implies` operations to the following Z3 API function:
1419 // ```
1420 // Z3_ast Z3_API Z3_mk_implies(Z3_context c, Z3_ast t1, Z3_ast t2);
1421 // ```
1422 ADD_ONE_TO_ONE_PATTERN(ImpliesOp, "Z3_mk_implies", 2);
1423
1424 // All the bit-vector arithmetic and bitwise operations conveniently lower to
1425 // Z3 API function calls with essentially matching names and a one-to-one
1426 // correspondence of operands to call arguments.
1427 ADD_ONE_TO_ONE_PATTERN(BVNegOp, "Z3_mk_bvneg", 1);
1428 ADD_ONE_TO_ONE_PATTERN(BVAddOp, "Z3_mk_bvadd", 2);
1429 ADD_ONE_TO_ONE_PATTERN(BVMulOp, "Z3_mk_bvmul", 2);
1430 ADD_ONE_TO_ONE_PATTERN(BVURemOp, "Z3_mk_bvurem", 2);
1431 ADD_ONE_TO_ONE_PATTERN(BVSRemOp, "Z3_mk_bvsrem", 2);
1432 ADD_ONE_TO_ONE_PATTERN(BVSModOp, "Z3_mk_bvsmod", 2);
1433 ADD_ONE_TO_ONE_PATTERN(BVUDivOp, "Z3_mk_bvudiv", 2);
1434 ADD_ONE_TO_ONE_PATTERN(BVSDivOp, "Z3_mk_bvsdiv", 2);
1435 ADD_ONE_TO_ONE_PATTERN(BVShlOp, "Z3_mk_bvshl", 2);
1436 ADD_ONE_TO_ONE_PATTERN(BVLShrOp, "Z3_mk_bvlshr", 2);
1437 ADD_ONE_TO_ONE_PATTERN(BVAShrOp, "Z3_mk_bvashr", 2);
1438 ADD_ONE_TO_ONE_PATTERN(BVNotOp, "Z3_mk_bvnot", 1);
1439 ADD_ONE_TO_ONE_PATTERN(BVAndOp, "Z3_mk_bvand", 2);
1440 ADD_ONE_TO_ONE_PATTERN(BVOrOp, "Z3_mk_bvor", 2);
1441 ADD_ONE_TO_ONE_PATTERN(BVXOrOp, "Z3_mk_bvxor", 2);
1442
1443 // The `smt.bv.concat` operation only supports two operands, just like the
1444 // Z3 API function.
1445 // ```
1446 // Z3_ast Z3_API Z3_mk_concat(Z3_context c, Z3_ast t1, Z3_ast t2);
1447 // ```
1448 ADD_ONE_TO_ONE_PATTERN(ConcatOp, "Z3_mk_concat", 2);
1449
1450 // Lower the `smt.ite` operation to the following Z3 API function call, where
1451 // `t1` must have boolean sort.
1452 // ```
1453 // Z3_ast Z3_API Z3_mk_ite(Z3_context c, Z3_ast t1, Z3_ast t2, Z3_ast t3);
1454 // ```
1455 ADD_ONE_TO_ONE_PATTERN(IteOp, "Z3_mk_ite", 3);
1456
1457 // Lower the `smt.array.select` operation to the following Z3 function call.
1458 // The operand declaration of the operation matches the order of arguments of
1459 // the API function.
1460 // ```
1461 // Z3_ast Z3_API Z3_mk_select(Z3_context c, Z3_ast a, Z3_ast i);
1462 // ```
1463 // Where `a` is the array expression and `i` is the index expression.
1464 ADD_ONE_TO_ONE_PATTERN(ArraySelectOp, "Z3_mk_select", 2);
1465
1466 // Lower the `smt.array.store` operation to the following Z3 function call.
1467 // The operand declaration of the operation matches the order of arguments of
1468 // the API function.
1469 // ```
1470 // Z3_ast Z3_API Z3_mk_store(Z3_context c, Z3_ast a, Z3_ast i, Z3_ast v);
1471 // ```
1472 // Where `a` is the array expression, `i` is the index expression, and `v` is
1473 // the value expression to be stored.
1474 ADD_ONE_TO_ONE_PATTERN(ArrayStoreOp, "Z3_mk_store", 3);
1475
1476 // Lower the `smt.int.add` operation to the following Z3 API function call.
1477 // ```
1478 // Z3_ast Z3_API Z3_mk_add(Z3_context c, unsigned num_args, Z3_ast const
1479 // args[]);
1480 // ```
1481 // The number of arguments must be greater than zero. Therefore, the pattern
1482 // will fail if applied to an operation with less than two operands.
1483 ADD_VARIADIC_PATTERN(IntAddOp, "Z3_mk_add", 2);
1484
1485 // Lower the `smt.int.mul` operation to the following Z3 API function call.
1486 // ```
1487 // Z3_ast Z3_API Z3_mk_mul(Z3_context c, unsigned num_args, Z3_ast const
1488 // args[]);
1489 // ```
1490 // The number of arguments must be greater than zero. Therefore, the pattern
1491 // will fail if applied to an operation with less than two operands.
1492 ADD_VARIADIC_PATTERN(IntMulOp, "Z3_mk_mul", 2);
1493
1494 // Lower the `smt.int.sub` operation to the following Z3 API function call.
1495 // ```
1496 // Z3_ast Z3_API Z3_mk_sub(Z3_context c, unsigned num_args, Z3_ast const
1497 // args[]);
1498 // ```
1499 // The number of arguments must be greater than zero. Since the `smt.int.sub`
1500 // operation always has exactly two operands, this trivially holds.
1501 ADD_VARIADIC_PATTERN(IntSubOp, "Z3_mk_sub", 2);
1502
1503 // Lower the `smt.int.div` operation to the following Z3 API function call.
1504 // ```
1505 // Z3_ast Z3_API Z3_mk_div(Z3_context c, Z3_ast arg1, Z3_ast arg2);
1506 // ```
1507 ADD_ONE_TO_ONE_PATTERN(IntDivOp, "Z3_mk_div", 2);
1508
1509 // Lower the `smt.int.mod` operation to the following Z3 API function call.
1510 // ```
1511 // Z3_ast Z3_API Z3_mk_mod(Z3_context c, Z3_ast arg1, Z3_ast arg2);
1512 // ```
1513 ADD_ONE_TO_ONE_PATTERN(IntModOp, "Z3_mk_mod", 2);
1514
1515#undef ADD_VARIADIC_PATTERN
1516#undef ADD_ONE_TO_ONE_PATTERN
1517
1518 // Lower `smt.eq` operations which allows a variadic number of operands
1519 // according to the `:chainable` attribute. The Z3 API function does not
1520 // support a variadic number of operands, but exactly two:
1521 // ```
1522 // Z3_ast Z3_API Z3_mk_eq(Z3_context c, Z3_ast l, Z3_ast r)
1523 // ```
1524 // As a result, we first apply a rewrite pattern that unfolds chainable
1525 // operators and then lower it one-to-one to the API function. In this case,
1526 // this means:
1527 // ```
1528 // eq(a,b,c,d) ->
1529 // and(eq(a,b), eq(b,c), eq(c,d)) ->
1530 // and(Z3_mk_eq(ctx, a, b), Z3_mk_eq(ctx, b, c), Z3_mk_eq(ctx, c, d))
1531 // ```
1532 // The patterns for `smt.and` will then do the remaining work.
1533 patterns.add<LowerChainableSMTPattern<EqOp>>(converter, patterns.getContext(),
1534 globals, options);
1535 patterns.add<OneToOneSMTPattern<EqOp>>(converter, patterns.getContext(),
1536 globals, options, "Z3_mk_eq", 2);
1537
1538 // Other lowering patterns. Refer to their implementation directly for more
1539 // information.
1540 patterns.add<BVConstantOpLowering, DeclareFunOpLowering, AssertOpLowering,
1541 ResetOpLowering, PushOpLowering, PopOpLowering, CheckOpLowering,
1542 SolverOpLowering, ApplyFuncOpLowering, YieldOpLowering,
1543 RepeatOpLowering, ExtractOpLowering, BoolConstantOpLowering,
1544 IntConstantOpLowering, ArrayBroadcastOpLowering, BVCmpOpLowering,
1545 IntCmpOpLowering, IntAbsOpLowering, Int2BVOpLowering,
1546 BV2IntOpLowering, QuantifierLowering<ForallOp>,
1547 QuantifierLowering<ExistsOp>>(converter, patterns.getContext(),
1548 globals, options);
1549 patterns.add<DbgVariableLowering, DbgScopeLowering>(patterns.getContext());
1550}
1551
1552void LowerSMTToZ3LLVMPass::runOnOperation() {
1553 LowerSMTToZ3LLVMOptions options;
1554 options.debug = debug;
1555
1556 // Check that the lowering is possible
1557 // Specifically, check that the use of set-logic ops is valid for z3
1558 auto setLogicCheck = getOperation().walk([&](SolverOp solverOp)
1559 -> WalkResult {
1560 // Check that solver ops only contain one set-logic op and that they're at
1561 // the start of the body
1562 auto setLogicOps = solverOp.getBodyRegion().getOps<smt::SetLogicOp>();
1563 auto numSetLogicOps = std::distance(setLogicOps.begin(), setLogicOps.end());
1564 if (numSetLogicOps > 1) {
1565 return solverOp.emitError(
1566 "multiple set-logic operations found in one solver operation - Z3 "
1567 "only supports setting the logic once");
1568 }
1569 if (numSetLogicOps == 1)
1570 // Check the only ops before the set-logic op are ConstantLike
1571 for (auto &blockOp : solverOp.getBodyRegion().getOps()) {
1572 if (isa<smt::SetLogicOp>(blockOp))
1573 break;
1574 if (!blockOp.hasTrait<OpTrait::ConstantLike>()) {
1575 return solverOp.emitError("set-logic operation must be the first "
1576 "non-constant operation in a solver "
1577 "operation");
1578 }
1579 }
1580 return WalkResult::advance();
1581 });
1582 if (setLogicCheck.wasInterrupted())
1583 return signalPassFailure();
1584
1585 // Set up the type converter
1586 LLVMTypeConverter converter(&getContext());
1588
1589 RewritePatternSet patterns(&getContext());
1590
1591 // Populate the func to LLVM conversion patterns for two reasons:
1592 // * Typically functions are represented using `func.func` and including the
1593 // patterns to lower them here is more convenient for most lowering
1594 // pipelines (avoids running another pass).
1595 // * Already having `llvm.func` in the input or lowering `func.func` before
1596 // the SMT in the body leads to issues because the SCF conversion patterns
1597 // don't take the type converter into consideration and thus create blocks
1598 // with the old types for block arguments. However, the conversion happens
1599 // top-down and thus are assumed to be converted by the parent function op
1600 // which at that point would have already been lowered (and the blocks are
1601 // also not there when doing everything in one pass, i.e.,
1602 // `populateAnyFunctionOpInterfaceTypeConversionPattern` does not have any
1603 // effect as well). Are the SCF lowering patterns actually broken and should
1604 // take a type-converter?
1605 populateFuncToLLVMConversionPatterns(converter, patterns);
1606 arith::populateArithToLLVMConversionPatterns(converter, patterns);
1607
1608 // Populate SCF to CF and CF to LLVM lowering patterns because we create
1609 // `scf.if` operations in the lowering patterns for convenience (given the
1610 // above issue we might want to lower to LLVM directly; or fix upstream?)
1611 populateSCFToControlFlowConversionPatterns(patterns);
1612 mlir::cf::populateControlFlowToLLVMConversionPatterns(converter, patterns);
1613
1614 // Create the globals to store the context and solver and populate the SMT
1615 // lowering patterns.
1616 OpBuilder builder(&getContext());
1617 auto globals = SMTGlobalsHandler::create(builder, getOperation());
1618 populateSMTToZ3LLVMConversionPatterns(patterns, converter, globals, options);
1619
1620 // Do a full conversion. This assumes that all other dialects have been
1621 // lowered before this pass already.
1622 LLVMConversionTarget target(getContext());
1623 target.addLegalOp<mlir::ModuleOp>();
1624 target.addLegalOp<scf::YieldOp>();
1625 target.addIllegalDialect<debug::DebugDialect>();
1626
1627 if (failed(applyFullConversion(getOperation(), target, std::move(patterns))))
1628 return signalPassFailure();
1629}
assert(baseType &&"element must be base type")
static std::unique_ptr< Context > context
static FIRRTLBaseType convertType(FIRRTLBaseType type)
Returns null type if no conversion is needed.
Definition DropConst.cpp:32
#define ADD_VARIADIC_PATTERN(OP, APINAME, MIN_NUM_ARGS)
#define ADD_ONE_TO_ONE_PATTERN(OP, APINAME, NUM_ARGS)
static Location getLoc(DefSlot slot)
Definition Mem2Reg.cpp:216
RewritePatternSet pattern
A namespace that is used to store existing names and generate new names in some scope within the IR.
Definition Namespace.h:30
void add(mlir::ModuleOp module)
Definition Namespace.h:48
StringRef newName(const Twine &name)
Return a unique name, derived from the input name, and add the new name to the internal namespace.
Definition Namespace.h:87
void addDefinitions(mlir::Operation *top)
Populate the symbol cache with all symbol-defining operations within the 'top' operation.
Definition SymCache.cpp:23
Default symbol cache implementation; stores associations between names (StringAttr's) to mlir::Operat...
Definition SymCache.h:85
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
void populateSMTToZ3LLVMTypeConverter(TypeConverter &converter)
Populate the given type converter with the SMT to LLVM type conversions.
void populateSMTToZ3LLVMConversionPatterns(RewritePatternSet &patterns, TypeConverter &converter, SMTGlobalsHandler &globals, const LowerSMTToZ3LLVMOptions &options)
Add the SMT to LLVM IR conversion patterns to 'patterns'.
Definition debug.py:1
A symbol cache for LLVM globals and functions relevant to SMT lowering patterns.
Definition SMTToZ3LLVM.h:25
static SMTGlobalsHandler create(OpBuilder &builder, ModuleOp module)
Creates the LLVM global operations to store the pointers to the solver and the context and returns a ...
SMTGlobalsHandler(ModuleOp module, mlir::LLVM::GlobalOp solver, mlir::LLVM::GlobalOp ctx)
Initializes the caches and keeps track of the given globals to store the pointers to the SMT solver a...