CIRCT 20.0.0git
Loading...
Searching...
No Matches
LowerSMTToZ3LLVM.cpp
Go to the documentation of this file.
1//===- LowerSMTToZ3LLVM.cpp -----------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
12#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
13#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
14#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
15#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
16#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
17#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"
18#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
19#include "mlir/Dialect/Func/IR/FuncOps.h"
20#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
21#include "mlir/Dialect/LLVMIR/LLVMAttrs.h"
22#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
23#include "mlir/Dialect/SCF/IR/SCF.h"
24#include "mlir/IR/BuiltinDialect.h"
25#include "mlir/Pass/Pass.h"
26#include "mlir/Transforms/DialectConversion.h"
27#include "llvm/ADT/TypeSwitch.h"
28#include "llvm/Support/Debug.h"
29
30#define DEBUG_TYPE "lower-smt-to-z3-llvm"
31
32namespace circt {
33#define GEN_PASS_DEF_LOWERSMTTOZ3LLVM
34#include "circt/Conversion/Passes.h.inc"
35} // namespace circt
36
37using namespace mlir;
38using namespace circt;
39using namespace smt;
40
41//===----------------------------------------------------------------------===//
42// SMTGlobalHandler implementation
43//===----------------------------------------------------------------------===//
44
46 ModuleOp module) {
47 OpBuilder::InsertionGuard guard(builder);
48 builder.setInsertionPointToStart(module.getBody());
49
50 SymbolCache symCache;
51 symCache.addDefinitions(module);
53 names.add(symCache);
54
55 Location loc = module.getLoc();
56 auto ptrTy = LLVM::LLVMPointerType::get(builder.getContext());
57
58 auto createGlobal = [&](StringRef namePrefix) {
59 auto global = builder.create<LLVM::GlobalOp>(
60 loc, ptrTy, false, LLVM::Linkage::Internal, names.newName(namePrefix),
61 Attribute{}, /*alignment=*/8);
62 OpBuilder::InsertionGuard g(builder);
63 builder.createBlock(&global.getInitializer());
64 Value res = builder.create<LLVM::ZeroOp>(loc, ptrTy);
65 builder.create<LLVM::ReturnOp>(loc, res);
66 return global;
67 };
68
69 auto ctxGlobal = createGlobal("ctx");
70 auto solverGlobal = createGlobal("solver");
71
72 return SMTGlobalsHandler(std::move(names), solverGlobal, ctxGlobal);
73}
74
76 mlir::LLVM::GlobalOp solver,
77 mlir::LLVM::GlobalOp ctx)
78 : solver(solver), ctx(ctx), names(names) {}
79
81 mlir::LLVM::GlobalOp solver,
82 mlir::LLVM::GlobalOp ctx)
83 : solver(solver), ctx(ctx) {
84 SymbolCache symCache;
85 symCache.addDefinitions(module);
86 names.add(symCache);
87}
88
89//===----------------------------------------------------------------------===//
90// Lowering Pattern Base
91//===----------------------------------------------------------------------===//
92
93namespace {
94
95template <typename OpTy>
96class SMTLoweringPattern : public OpConversionPattern<OpTy> {
97public:
98 SMTLoweringPattern(const TypeConverter &typeConverter, MLIRContext *context,
99 SMTGlobalsHandler &globals,
100 const LowerSMTToZ3LLVMOptions &options)
101 : OpConversionPattern<OpTy>(typeConverter, context), globals(globals),
102 options(options) {}
103
104private:
105 Value buildGlobalPtrToGlobal(OpBuilder &builder, Location loc,
106 LLVM::GlobalOp global,
107 DenseMap<Block *, Value> &cache) const {
108 Block *block = builder.getBlock();
109 if (auto iter = cache.find(block); iter != cache.end())
110 return iter->getSecond();
111
112 OpBuilder::InsertionGuard g(builder);
113 builder.setInsertionPointToStart(block);
114 Value globalAddr = builder.create<LLVM::AddressOfOp>(loc, global);
115 return cache[block] = builder.create<LLVM::LoadOp>(
116 loc, LLVM::LLVMPointerType::get(builder.getContext()),
117 globalAddr);
118 }
119
120protected:
121 /// A convenience function to get the pointer to the context from the 'global'
122 /// operation. The result is cached for each basic block, i.e., it is assumed
123 /// that this function is never called in the same basic block again at a
124 /// location (insertion point of the 'builder') not dominating all previous
125 /// locations this function was called at.
126 Value buildContextPtr(OpBuilder &builder, Location loc) const {
127 return buildGlobalPtrToGlobal(builder, loc, globals.ctx, globals.ctxCache);
128 }
129
130 /// A convenience function to get the pointer to the solver from the 'global'
131 /// operation. The result is cached for each basic block, i.e., it is assumed
132 /// that this function is never called in the same basic block again at a
133 /// location (insertion point of the 'builder') not dominating all previous
134 /// locations this function was called at.
135 Value buildSolverPtr(OpBuilder &builder, Location loc) const {
136 return buildGlobalPtrToGlobal(builder, loc, globals.solver,
137 globals.solverCache);
138 }
139
140 /// Create a `llvm.call` operation to a function with the given 'name' and
141 /// 'type'. If there does not already exist a (external) function with that
142 /// name create a matching external function declaration.
143 LLVM::CallOp buildCall(OpBuilder &builder, Location loc, StringRef name,
144 LLVM::LLVMFunctionType funcType,
145 ValueRange args) const {
146 auto &funcOp = globals.funcMap[builder.getStringAttr(name)];
147 if (!funcOp) {
148 OpBuilder::InsertionGuard guard(builder);
149 auto module =
150 builder.getBlock()->getParent()->getParentOfType<ModuleOp>();
151 builder.setInsertionPointToEnd(module.getBody());
152 funcOp = LLVM::lookupOrCreateFn(module, name, funcType.getParams(),
153 funcType.getReturnType(),
154 funcType.getVarArg());
155 }
156 return builder.create<LLVM::CallOp>(loc, funcOp, args);
157 }
158
159 /// Build a global constant for the given string and construct an 'addressof'
160 /// operation at the current 'builder' insertion point to get a pointer to it.
161 /// Multiple calls with the same string will reuse the same global. It is
162 /// guaranteed that the symbol of the global will be unique.
163 Value buildString(OpBuilder &builder, Location loc, StringRef str) const {
164 auto &global = globals.stringCache[builder.getStringAttr(str)];
165 if (!global) {
166 OpBuilder::InsertionGuard guard(builder);
167 auto module =
168 builder.getBlock()->getParent()->getParentOfType<ModuleOp>();
169 builder.setInsertionPointToEnd(module.getBody());
170 auto arrayTy =
171 LLVM::LLVMArrayType::get(builder.getI8Type(), str.size() + 1);
172 auto strAttr = builder.getStringAttr(str.str() + '\00');
173 global = builder.create<LLVM::GlobalOp>(
174 loc, arrayTy, /*isConstant=*/true, LLVM::Linkage::Internal,
175 globals.names.newName("str"), strAttr);
176 }
177 return builder.create<LLVM::AddressOfOp>(loc, global);
178 }
179 /// Most API functions require a pointer to the the Z3 context object as the
180 /// first argument. This helper function prepends this pointer value to the
181 /// call for convenience.
182 LLVM::CallOp buildAPICallWithContext(OpBuilder &builder, Location loc,
183 StringRef name, Type returnType,
184 ValueRange args = {}) const {
185 auto ctx = buildContextPtr(builder, loc);
186 SmallVector<Value> arguments;
187 arguments.emplace_back(ctx);
188 arguments.append(SmallVector<Value>(args));
189 return buildCall(
190 builder, loc, name,
191 LLVM::LLVMFunctionType::get(
192 returnType, SmallVector<Type>(ValueRange(arguments).getTypes())),
193 arguments);
194 }
195
196 /// Most API functions we need to call return a 'Z3_AST' object which is a
197 /// pointer in LLVM. This helper function simplifies calling those API
198 /// functions.
199 Value buildPtrAPICall(OpBuilder &builder, Location loc, StringRef name,
200 ValueRange args = {}) const {
201 return buildAPICallWithContext(
202 builder, loc, name,
203 LLVM::LLVMPointerType::get(builder.getContext()), args)
204 ->getResult(0);
205 }
206
207 /// Build a value representing the SMT sort given with 'type'.
208 Value buildSort(OpBuilder &builder, Location loc, Type type) const {
209 // NOTE: if a type not handled by this switch is passed, an assertion will
210 // be triggered.
211 return TypeSwitch<Type, Value>(type)
212 .Case([&](smt::IntType ty) {
213 return buildPtrAPICall(builder, loc, "Z3_mk_int_sort");
214 })
215 .Case([&](smt::BitVectorType ty) {
216 Value bitwidth = builder.create<LLVM::ConstantOp>(
217 loc, builder.getI32Type(), ty.getWidth());
218 return buildPtrAPICall(builder, loc, "Z3_mk_bv_sort", {bitwidth});
219 })
220 .Case([&](smt::BoolType ty) {
221 return buildPtrAPICall(builder, loc, "Z3_mk_bool_sort");
222 })
223 .Case([&](smt::SortType ty) {
224 Value str = buildString(builder, loc, ty.getIdentifier());
225 Value sym =
226 buildPtrAPICall(builder, loc, "Z3_mk_string_symbol", {str});
227 return buildPtrAPICall(builder, loc, "Z3_mk_uninterpreted_sort",
228 {sym});
229 })
230 .Case([&](smt::ArrayType ty) {
231 return buildPtrAPICall(builder, loc, "Z3_mk_array_sort",
232 {buildSort(builder, loc, ty.getDomainType()),
233 buildSort(builder, loc, ty.getRangeType())});
234 });
235 }
236
237 SMTGlobalsHandler &globals;
238 const LowerSMTToZ3LLVMOptions &options;
239};
240
241//===----------------------------------------------------------------------===//
242// Lowering Patterns
243//===----------------------------------------------------------------------===//
244
245/// The 'smt.declare_fun' operation is used to declare both constants and
246/// functions. The Z3 API, however, uses two different functions. Therefore,
247/// depending on the result type of this operation, one of the following two
248/// API functions is used to create the symbolic value:
249/// ```
250/// Z3_ast Z3_API Z3_mk_fresh_const(Z3_context c, Z3_string prefix, Z3_sort ty);
251/// Z3_func_decl Z3_API Z3_mk_fresh_func_decl(
252/// Z3_context c, Z3_string prefix, unsigned domain_size,
253/// Z3_sort const domain[], Z3_sort range);
254/// ```
255struct DeclareFunOpLowering : public SMTLoweringPattern<DeclareFunOp> {
256 using SMTLoweringPattern::SMTLoweringPattern;
257
258 LogicalResult
259 matchAndRewrite(DeclareFunOp op, OpAdaptor adaptor,
260 ConversionPatternRewriter &rewriter) const final {
261 Location loc = op.getLoc();
262
263 // Create the name prefix.
264 Value prefix;
265 if (adaptor.getNamePrefix())
266 prefix = buildString(rewriter, loc, *adaptor.getNamePrefix());
267 else
268 prefix = rewriter.create<LLVM::ZeroOp>(
269 loc, LLVM::LLVMPointerType::get(getContext()));
270
271 // Handle the constant value case.
272 if (!isa<SMTFuncType>(op.getType())) {
273 Value sort = buildSort(rewriter, loc, op.getType());
274 Value constDecl =
275 buildPtrAPICall(rewriter, loc, "Z3_mk_fresh_const", {prefix, sort});
276 rewriter.replaceOp(op, constDecl);
277 return success();
278 }
279
280 // Otherwise, we declare a function.
281 Type llvmPtrTy = LLVM::LLVMPointerType::get(getContext());
282 auto funcType = cast<SMTFuncType>(op.getResult().getType());
283 Value rangeSort = buildSort(rewriter, loc, funcType.getRangeType());
284
285 Type arrTy =
286 LLVM::LLVMArrayType::get(llvmPtrTy, funcType.getDomainTypes().size());
287
288 Value domain = rewriter.create<LLVM::UndefOp>(loc, arrTy);
289 for (auto [i, ty] : llvm::enumerate(funcType.getDomainTypes())) {
290 Value sort = buildSort(rewriter, loc, ty);
291 domain = rewriter.create<LLVM::InsertValueOp>(loc, domain, sort, i);
292 }
293
294 Value one =
295 rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI32Type(), 1);
296 Value domainStorage =
297 rewriter.create<LLVM::AllocaOp>(loc, llvmPtrTy, arrTy, one);
298 rewriter.create<LLVM::StoreOp>(loc, domain, domainStorage);
299
300 Value domainSize = rewriter.create<LLVM::ConstantOp>(
301 loc, rewriter.getI32Type(), funcType.getDomainTypes().size());
302 Value decl =
303 buildPtrAPICall(rewriter, loc, "Z3_mk_fresh_func_decl",
304 {prefix, domainSize, domainStorage, rangeSort});
305
306 rewriter.replaceOp(op, decl);
307 return success();
308 }
309};
310
311/// Lower the 'smt.apply_func' operation to Z3 API calls of the form:
312/// ```
313/// Z3_ast Z3_API Z3_mk_app(Z3_context c, Z3_func_decl d,
314/// unsigned num_args, Z3_ast const args[]);
315/// ```
316struct ApplyFuncOpLowering : public SMTLoweringPattern<ApplyFuncOp> {
317 using SMTLoweringPattern::SMTLoweringPattern;
318
319 LogicalResult
320 matchAndRewrite(ApplyFuncOp op, OpAdaptor adaptor,
321 ConversionPatternRewriter &rewriter) const final {
322 Location loc = op.getLoc();
323 Type llvmPtrTy = LLVM::LLVMPointerType::get(getContext());
324 Type arrTy = LLVM::LLVMArrayType::get(llvmPtrTy, adaptor.getArgs().size());
325
326 // Create an array of the function arguments.
327 Value domain = rewriter.create<LLVM::UndefOp>(loc, arrTy);
328 for (auto [i, arg] : llvm::enumerate(adaptor.getArgs()))
329 domain = rewriter.create<LLVM::InsertValueOp>(loc, domain, arg, i);
330
331 // Store the array on the stack.
332 Value one =
333 rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI32Type(), 1);
334 Value domainStorage =
335 rewriter.create<LLVM::AllocaOp>(loc, llvmPtrTy, arrTy, one);
336 rewriter.create<LLVM::StoreOp>(loc, domain, domainStorage);
337
338 // Call the API function with a pointer to the function, the number of
339 // arguments, and the pointer to the arguments stored on the stack.
340 Value domainSize = rewriter.create<LLVM::ConstantOp>(
341 loc, rewriter.getI32Type(), adaptor.getArgs().size());
342 Value returnVal =
343 buildPtrAPICall(rewriter, loc, "Z3_mk_app",
344 {adaptor.getFunc(), domainSize, domainStorage});
345 rewriter.replaceOp(op, returnVal);
346
347 return success();
348 }
349};
350
351/// Lower the `smt.bv.constant` operation to either
352/// ```
353/// Z3_ast Z3_API Z3_mk_unsigned_int64(Z3_context c, uint64_t v, Z3_sort ty);
354/// ```
355/// if the bit-vector fits into a 64-bit integer or convert it to a string and
356/// use the sligtly slower but arbitrary precision API function:
357/// ```
358/// Z3_ast Z3_API Z3_mk_numeral(Z3_context c, Z3_string numeral, Z3_sort ty);
359/// ```
360/// Note that there is also an API function taking an array of booleans, and
361/// while those are typically compiled to 'i8' in LLVM they don't necessarily
362/// have to (I think).
363struct BVConstantOpLowering : public SMTLoweringPattern<smt::BVConstantOp> {
364 using SMTLoweringPattern::SMTLoweringPattern;
365
366 LogicalResult
367 matchAndRewrite(smt::BVConstantOp op, OpAdaptor adaptor,
368 ConversionPatternRewriter &rewriter) const final {
369 Location loc = op.getLoc();
370 unsigned width = op.getType().getWidth();
371 auto bvSort = buildSort(rewriter, loc, op.getResult().getType());
372 APInt val = adaptor.getValue().getValue();
373
374 if (width <= 64) {
375 Value bvConst = rewriter.create<LLVM::ConstantOp>(
376 loc, rewriter.getI64Type(), val.getZExtValue());
377 Value res = buildPtrAPICall(rewriter, loc, "Z3_mk_unsigned_int64",
378 {bvConst, bvSort});
379 rewriter.replaceOp(op, res);
380 return success();
381 }
382
383 std::string str;
384 llvm::raw_string_ostream stream(str);
385 stream << val;
386 Value bvString = buildString(rewriter, loc, str);
387 Value bvNumeral =
388 buildPtrAPICall(rewriter, loc, "Z3_mk_numeral", {bvString, bvSort});
389
390 rewriter.replaceOp(op, bvNumeral);
391 return success();
392 }
393};
394
395/// Some of the Z3 API supports a variadic number of operands for some
396/// operations (in particular if the expansion would lead to a super-linear
397/// increase in operations such as with the ':pairwise' attribute). Those API
398/// calls take an 'unsigned' argument indicating the size of an array of
399/// pointers to the operands.
400template <typename SourceTy>
401struct VariadicSMTPattern : public SMTLoweringPattern<SourceTy> {
402 using OpAdaptor = typename SMTLoweringPattern<SourceTy>::OpAdaptor;
403
404 VariadicSMTPattern(const TypeConverter &typeConverter, MLIRContext *context,
405 SMTGlobalsHandler &globals,
406 const LowerSMTToZ3LLVMOptions &options,
407 StringRef apiFuncName, unsigned minNumArgs)
408 : SMTLoweringPattern<SourceTy>(typeConverter, context, globals, options),
409 apiFuncName(apiFuncName), minNumArgs(minNumArgs) {}
410
411 LogicalResult
412 matchAndRewrite(SourceTy op, OpAdaptor adaptor,
413 ConversionPatternRewriter &rewriter) const final {
414 if (adaptor.getOperands().size() < minNumArgs)
415 return failure();
416
417 Location loc = op.getLoc();
418 Value numOperands = rewriter.create<LLVM::ConstantOp>(
419 loc, rewriter.getI32Type(), op->getNumOperands());
420 Value constOne =
421 rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI32Type(), 1);
422 Type ptrTy = LLVM::LLVMPointerType::get(rewriter.getContext());
423 Type arrTy = LLVM::LLVMArrayType::get(ptrTy, op->getNumOperands());
424 Value storage =
425 rewriter.create<LLVM::AllocaOp>(loc, ptrTy, arrTy, constOne);
426 Value array = rewriter.create<LLVM::UndefOp>(loc, arrTy);
427
428 for (auto [i, operand] : llvm::enumerate(adaptor.getOperands()))
429 array = rewriter.create<LLVM::InsertValueOp>(
430 loc, array, operand, ArrayRef<int64_t>{(int64_t)i});
431
432 rewriter.create<LLVM::StoreOp>(loc, array, storage);
433
434 rewriter.replaceOp(op,
435 SMTLoweringPattern<SourceTy>::buildPtrAPICall(
436 rewriter, loc, apiFuncName, {numOperands, storage}));
437 return success();
438 }
439
440private:
441 StringRef apiFuncName;
442 unsigned minNumArgs;
443};
444
445/// Lower an SMT operation to a function call with the name 'apiFuncName' with
446/// arguments matching the operands one-to-one.
447template <typename SourceTy>
448struct OneToOneSMTPattern : public SMTLoweringPattern<SourceTy> {
449 using OpAdaptor = typename SMTLoweringPattern<SourceTy>::OpAdaptor;
450
451 OneToOneSMTPattern(const TypeConverter &typeConverter, MLIRContext *context,
452 SMTGlobalsHandler &globals,
453 const LowerSMTToZ3LLVMOptions &options,
454 StringRef apiFuncName, unsigned numOperands)
455 : SMTLoweringPattern<SourceTy>(typeConverter, context, globals, options),
456 apiFuncName(apiFuncName), numOperands(numOperands) {}
457
458 LogicalResult
459 matchAndRewrite(SourceTy op, OpAdaptor adaptor,
460 ConversionPatternRewriter &rewriter) const final {
461 if (adaptor.getOperands().size() != numOperands)
462 return failure();
463
464 rewriter.replaceOp(
465 op, SMTLoweringPattern<SourceTy>::buildPtrAPICall(
466 rewriter, op.getLoc(), apiFuncName, adaptor.getOperands()));
467 return success();
468 }
469
470private:
471 StringRef apiFuncName;
472 unsigned numOperands;
473};
474
475/// A pattern to lower SMT operations with a variadic number of operands
476/// modelling the ':chainable' attribute in SMT to binary operations.
477template <typename SourceTy>
478class LowerChainableSMTPattern : public SMTLoweringPattern<SourceTy> {
479 using SMTLoweringPattern<SourceTy>::SMTLoweringPattern;
480 using OpAdaptor = typename SMTLoweringPattern<SourceTy>::OpAdaptor;
481
482 LogicalResult
483 matchAndRewrite(SourceTy op, OpAdaptor adaptor,
484 ConversionPatternRewriter &rewriter) const final {
485 if (adaptor.getOperands().size() <= 2)
486 return failure();
487
488 Location loc = op.getLoc();
489 SmallVector<Value> elements;
490 for (int i = 1, e = adaptor.getOperands().size(); i < e; ++i) {
491 Value val = rewriter.create<SourceTy>(
492 loc, op->getResultTypes(),
493 ValueRange{adaptor.getOperands()[i - 1], adaptor.getOperands()[i]});
494 elements.push_back(val);
495 }
496 rewriter.replaceOpWithNewOp<smt::AndOp>(op, elements);
497 return success();
498 }
499};
500
501/// A pattern to lower SMT operations with a variadic number of operands
502/// modelling the `:left-assoc` attribute to a sequence of binary operators.
503template <typename SourceTy>
504class LowerLeftAssocSMTPattern : public SMTLoweringPattern<SourceTy> {
505 using SMTLoweringPattern<SourceTy>::SMTLoweringPattern;
506 using OpAdaptor = typename SMTLoweringPattern<SourceTy>::OpAdaptor;
507
508 LogicalResult
509 matchAndRewrite(SourceTy op, OpAdaptor adaptor,
510 ConversionPatternRewriter &rewriter) const final {
511 if (adaptor.getOperands().size() <= 2)
512 return rewriter.notifyMatchFailure(op, "must have at least two operands");
513
514 Value runner = adaptor.getOperands()[0];
515 for (Value val : adaptor.getOperands().drop_front())
516 runner = rewriter.create<SourceTy>(op.getLoc(), op->getResultTypes(),
517 ValueRange{runner, val});
518
519 rewriter.replaceOp(op, runner);
520 return success();
521 }
522};
523
524/// The 'smt.solver' operation has a region that corresponds to the lifetime of
525/// the Z3 context and one solver instance created within this context.
526/// To create a context, a Z3 configuration has to be built first and various
527/// configuration parameters can be set before creating a context from it. Once
528/// we have a context, we can create a solver and store a pointer to the context
529/// and the solver in an LLVM global such that operations in the child region
530/// have access to them. While the context created with `Z3_mk_context` takes
531/// care of the reference counting of `Z3_AST` objects, it still requires manual
532/// reference counting of `Z3_solver` objects, therefore, we need to increase
533/// the ref. counter of the solver we get from `Z3_mk_solver` and must decrease
534/// it again once we don't need it anymore. Finally, the configuration object
535/// can be deleted.
536/// ```
537/// Z3_config Z3_API Z3_mk_config(void);
538/// void Z3_API Z3_set_param_value(Z3_config c, Z3_string param_id,
539/// Z3_string param_value);
540/// Z3_context Z3_API Z3_mk_context(Z3_config c);
541/// Z3_solver Z3_API Z3_mk_solver(Z3_context c);
542/// void Z3_API Z3_solver_inc_ref(Z3_context c, Z3_solver s);
543/// void Z3_API Z3_del_config(Z3_config c);
544/// ```
545/// At the end of the solver lifetime, we have to tell the context that we
546/// don't need the solver anymore and delete the context itself.
547/// ```
548/// void Z3_API Z3_solver_dec_ref(Z3_context c, Z3_solver s);
549/// void Z3_API Z3_del_context(Z3_context c);
550/// ```
551/// Note that the solver created here is a combined solver. There might be some
552/// potential for optimization by creating more specialized solvers supported by
553/// the Z3 API according the the kind of operations present in the body region.
554struct SolverOpLowering : public SMTLoweringPattern<SolverOp> {
555 using SMTLoweringPattern::SMTLoweringPattern;
556
557 LogicalResult
558 matchAndRewrite(SolverOp op, OpAdaptor adaptor,
559 ConversionPatternRewriter &rewriter) const final {
560 Location loc = op.getLoc();
561 auto ptrTy = LLVM::LLVMPointerType::get(getContext());
562 auto voidTy = LLVM::LLVMVoidType::get(getContext());
563 auto ptrToPtrFunc = LLVM::LLVMFunctionType::get(ptrTy, ptrTy);
564 auto ptrPtrToPtrFunc = LLVM::LLVMFunctionType::get(ptrTy, {ptrTy, ptrTy});
565 auto ptrToVoidFunc = LLVM::LLVMFunctionType::get(voidTy, ptrTy);
566 auto ptrPtrToVoidFunc = LLVM::LLVMFunctionType::get(voidTy, {ptrTy, ptrTy});
567
568 // Create the configuration.
569 Value config = buildCall(rewriter, loc, "Z3_mk_config",
570 LLVM::LLVMFunctionType::get(ptrTy, {}), {})
571 .getResult();
572
573 // In debug-mode, we enable proofs such that we can fetch one in the 'unsat'
574 // region of each 'smt.check' operation.
575 if (options.debug) {
576 Value paramKey = buildString(rewriter, loc, "proof");
577 Value paramValue = buildString(rewriter, loc, "true");
578 buildCall(rewriter, loc, "Z3_set_param_value",
579 LLVM::LLVMFunctionType::get(voidTy, {ptrTy, ptrTy, ptrTy}),
580 {config, paramKey, paramValue});
581 }
582
583 // Check if the logic is set anywhere within the solver
584 std::optional<StringRef> logic = std::nullopt;
585 auto setLogicOps = op.getBodyRegion().getOps<smt::SetLogicOp>();
586 if (!setLogicOps.empty()) {
587 // We know from before patterns were applied that there is only one
588 // set_logic op
589 auto setLogicOp = *setLogicOps.begin();
590 logic = setLogicOp.getLogic();
591 rewriter.eraseOp(setLogicOp);
592 }
593
594 // Create the context and store a pointer to it in the global variable.
595 Value ctx = buildCall(rewriter, loc, "Z3_mk_context", ptrToPtrFunc, config)
596 .getResult();
597 Value ctxAddr =
598 rewriter.create<LLVM::AddressOfOp>(loc, globals.ctx).getResult();
599 rewriter.create<LLVM::StoreOp>(loc, ctx, ctxAddr);
600
601 // Delete the configuration again.
602 buildCall(rewriter, loc, "Z3_del_config", ptrToVoidFunc, {config});
603
604 // Create a solver instance, increase its reference counter, and store a
605 // pointer to it in the global variable.
606 Value solver;
607 if (logic) {
608 auto logicStr = buildString(rewriter, loc, logic.value());
609 solver = buildCall(rewriter, loc, "Z3_mk_solver_for_logic",
610 ptrPtrToPtrFunc, {ctx, logicStr})
611 ->getResult(0);
612 } else {
613 solver = buildCall(rewriter, loc, "Z3_mk_solver", ptrToPtrFunc, ctx)
614 ->getResult(0);
615 }
616 buildCall(rewriter, loc, "Z3_solver_inc_ref", ptrPtrToVoidFunc,
617 {ctx, solver});
618 Value solverAddr =
619 rewriter.create<LLVM::AddressOfOp>(loc, globals.solver).getResult();
620 rewriter.create<LLVM::StoreOp>(loc, solver, solverAddr);
621
622 // This assumes that no constant hoisting of the like happens inbetween
623 // the patterns defined in this pass because once the solver initialization
624 // and deallocation calls are inserted and the body region is inlined,
625 // canonicalizations and folders applied inbetween lowering patterns might
626 // hoist the SMT constants which means they would access uninitialized
627 // global variables once they are lowered.
628 SmallVector<Type> convertedTypes;
629 if (failed(
630 typeConverter->convertTypes(op->getResultTypes(), convertedTypes)))
631 return failure();
632
633 func::FuncOp funcOp;
634 {
635 OpBuilder::InsertionGuard guard(rewriter);
636 auto module = op->getParentOfType<ModuleOp>();
637 rewriter.setInsertionPointToEnd(module.getBody());
638
639 funcOp = rewriter.create<func::FuncOp>(
640 loc, globals.names.newName("solver"),
641 rewriter.getFunctionType(adaptor.getInputs().getTypes(),
642 convertedTypes));
643 rewriter.inlineRegionBefore(op.getBodyRegion(), funcOp.getBody(),
644 funcOp.end());
645 }
646
647 ValueRange results =
648 rewriter.create<func::CallOp>(loc, funcOp, adaptor.getInputs())
649 ->getResults();
650
651 // At the end of the region, decrease the solver's reference counter and
652 // delete the context.
653 // NOTE: we cannot use the convenience helper here because we don't want to
654 // load the context from the global but use the result from the 'mk_context'
655 // call directly for two reasons:
656 // * avoid an unnecessary load
657 // * the caching mechanism of the context does not work here because it
658 // would reuse the loaded context from a earlier solver
659 buildCall(rewriter, loc, "Z3_solver_dec_ref", ptrPtrToVoidFunc,
660 {ctx, solver});
661 buildCall(rewriter, loc, "Z3_del_context", ptrToVoidFunc, ctx);
662
663 rewriter.replaceOp(op, results);
664 return success();
665 }
666};
667
668/// Lower `smt.assert` operations to Z3 API calls of the form:
669/// ```
670/// void Z3_API Z3_solver_assert(Z3_context c, Z3_solver s, Z3_ast a);
671/// ```
672struct AssertOpLowering : public SMTLoweringPattern<AssertOp> {
673 using SMTLoweringPattern::SMTLoweringPattern;
674
675 LogicalResult
676 matchAndRewrite(AssertOp op, OpAdaptor adaptor,
677 ConversionPatternRewriter &rewriter) const final {
678 Location loc = op.getLoc();
679 buildAPICallWithContext(
680 rewriter, loc, "Z3_solver_assert",
681 LLVM::LLVMVoidType::get(getContext()),
682 {buildSolverPtr(rewriter, loc), adaptor.getInput()});
683
684 rewriter.eraseOp(op);
685 return success();
686 }
687};
688
689/// Lower `smt.reset` operations to Z3 API calls of the form:
690/// ```
691/// void Z3_API Z3_solver_reset(Z3_context c, Z3_solver s);
692/// ```
693struct ResetOpLowering : public SMTLoweringPattern<ResetOp> {
694 using SMTLoweringPattern::SMTLoweringPattern;
695
696 LogicalResult
697 matchAndRewrite(ResetOp op, OpAdaptor adaptor,
698 ConversionPatternRewriter &rewriter) const final {
699 Location loc = op.getLoc();
700 buildAPICallWithContext(rewriter, loc, "Z3_solver_reset",
701 LLVM::LLVMVoidType::get(getContext()),
702 {buildSolverPtr(rewriter, loc)});
703
704 rewriter.eraseOp(op);
705 return success();
706 }
707};
708
709/// Lower `smt.push` operations to (repeated) Z3 API calls of the form:
710/// ```
711/// void Z3_API Z3_solver_push(Z3_context c, Z3_solver s);
712/// ```
713struct PushOpLowering : public SMTLoweringPattern<PushOp> {
714 using SMTLoweringPattern::SMTLoweringPattern;
715 LogicalResult
716 matchAndRewrite(PushOp op, OpAdaptor adaptor,
717 ConversionPatternRewriter &rewriter) const final {
718 Location loc = op.getLoc();
719 // SMTLIB allows multiple levels to be pushed with one push command, but the
720 // Z3 C API doesn't let you provide a number of levels for push calls so
721 // multiple calls have to be created.
722 for (uint32_t i = 0; i < op.getCount(); i++)
723 buildAPICallWithContext(rewriter, loc, "Z3_solver_push",
724 LLVM::LLVMVoidType::get(getContext()),
725 {buildSolverPtr(rewriter, loc)});
726 rewriter.eraseOp(op);
727 return success();
728 }
729};
730
731/// Lower `smt.pop` operations to Z3 API calls of the form:
732/// ```
733/// void Z3_API Z3_solver_pop(Z3_context c, Z3_solver s, unsigned n);
734/// ```
735struct PopOpLowering : public SMTLoweringPattern<PopOp> {
736 using SMTLoweringPattern::SMTLoweringPattern;
737 LogicalResult
738 matchAndRewrite(PopOp op, OpAdaptor adaptor,
739 ConversionPatternRewriter &rewriter) const final {
740 Location loc = op.getLoc();
741 Value constVal = rewriter.create<LLVM::ConstantOp>(
742 loc, rewriter.getI32Type(), op.getCount());
743 buildAPICallWithContext(rewriter, loc, "Z3_solver_pop",
744 LLVM::LLVMVoidType::get(getContext()),
745 {buildSolverPtr(rewriter, loc), constVal});
746 rewriter.eraseOp(op);
747 return success();
748 }
749};
750
751/// Lower `smt.yield` operations to `scf.yield` operations. This not necessary
752/// for the yield in `smt.solver` or in quantifiers since they are deleted
753/// directly by the parent operation, but makes the lowering of the `smt.check`
754/// operation simpler and more convenient since the regions get translated
755/// directly to regions of `scf.if` operations.
756struct YieldOpLowering : public SMTLoweringPattern<YieldOp> {
757 using SMTLoweringPattern::SMTLoweringPattern;
758
759 LogicalResult
760 matchAndRewrite(YieldOp op, OpAdaptor adaptor,
761 ConversionPatternRewriter &rewriter) const final {
762 if (op->getParentOfType<func::FuncOp>()) {
763 rewriter.replaceOpWithNewOp<func::ReturnOp>(op, adaptor.getValues());
764 return success();
765 }
766 if (op->getParentOfType<LLVM::LLVMFuncOp>()) {
767 rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(op, adaptor.getValues());
768 return success();
769 }
770 if (isa<scf::SCFDialect>(op->getParentOp()->getDialect())) {
771 rewriter.replaceOpWithNewOp<scf::YieldOp>(op, adaptor.getValues());
772 return success();
773 }
774 return failure();
775 }
776};
777
778/// Lower `smt.check` operations to Z3 API calls and control-flow operations.
779/// ```
780/// Z3_lbool Z3_API Z3_solver_check(Z3_context c, Z3_solver s);
781///
782/// typedef enum
783/// {
784/// Z3_L_FALSE = -1, // means unsatisfiable here
785/// Z3_L_UNDEF, // means unknown here
786/// Z3_L_TRUE // means satisfiable here
787/// } Z3_lbool;
788/// ```
789struct CheckOpLowering : public SMTLoweringPattern<CheckOp> {
790 using SMTLoweringPattern::SMTLoweringPattern;
791
792 LogicalResult
793 matchAndRewrite(CheckOp op, OpAdaptor adaptor,
794 ConversionPatternRewriter &rewriter) const final {
795 Location loc = op.getLoc();
796 auto ptrTy = LLVM::LLVMPointerType::get(rewriter.getContext());
797 auto printfType =
798 LLVM::LLVMFunctionType::get(rewriter.getI32Type(), {ptrTy}, true);
799
800 auto getHeaderString = [](const std::string &title) {
801 unsigned titleSize = title.size() + 2; // Add a space left and right
802 return std::string((80 - titleSize) / 2, '-') + " " + title + " " +
803 std::string((80 - titleSize + 1) / 2, '-') + "\n%s\n" +
804 std::string(80, '-') + "\n";
805 };
806
807 // Get the pointer to the solver instance.
808 Value solver = buildSolverPtr(rewriter, loc);
809
810 // In debug-mode, print the state of the solver before calling 'check-sat'
811 // on it. This prints the asserted SMT expressions.
812 if (options.debug) {
813 auto solverStringPtr =
814 buildPtrAPICall(rewriter, loc, "Z3_solver_to_string", {solver});
815 auto solverFormatString =
816 buildString(rewriter, loc, getHeaderString("Solver"));
817 buildCall(rewriter, op.getLoc(), "printf", printfType,
818 {solverFormatString, solverStringPtr});
819 }
820
821 // Convert the result types of the `smt.check` operation.
822 SmallVector<Type> resultTypes;
823 if (failed(typeConverter->convertTypes(op->getResultTypes(), resultTypes)))
824 return failure();
825
826 // Call 'check-sat' and check if the assertions are satisfiable.
827 Value checkResult =
828 buildAPICallWithContext(rewriter, loc, "Z3_solver_check",
829 rewriter.getI32Type(), {solver})
830 ->getResult(0);
831 Value constOne =
832 rewriter.create<LLVM::ConstantOp>(loc, checkResult.getType(), 1);
833 Value isSat = rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::eq,
834 checkResult, constOne);
835
836 // Simply inline the 'sat' region into the 'then' region of the 'scf.if'
837 auto satIfOp = rewriter.create<scf::IfOp>(loc, resultTypes, isSat);
838 rewriter.inlineRegionBefore(op.getSatRegion(), satIfOp.getThenRegion(),
839 satIfOp.getThenRegion().end());
840
841 // Otherwise, the 'else' block checks if the assertions are unsatisfiable or
842 // unknown. The corresponding regions can also be simply inlined into the
843 // two branches of this nested if-statement as well.
844 rewriter.createBlock(&satIfOp.getElseRegion());
845 Value constNegOne =
846 rewriter.create<LLVM::ConstantOp>(loc, checkResult.getType(), -1);
847 Value isUnsat = rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::eq,
848 checkResult, constNegOne);
849 auto unsatIfOp = rewriter.create<scf::IfOp>(loc, resultTypes, isUnsat);
850 rewriter.create<scf::YieldOp>(loc, unsatIfOp->getResults());
851
852 rewriter.inlineRegionBefore(op.getUnsatRegion(), unsatIfOp.getThenRegion(),
853 unsatIfOp.getThenRegion().end());
854 rewriter.inlineRegionBefore(op.getUnknownRegion(),
855 unsatIfOp.getElseRegion(),
856 unsatIfOp.getElseRegion().end());
857
858 rewriter.replaceOp(op, satIfOp->getResults());
859
860 if (options.debug) {
861 // In debug-mode, if the assertions are unsatisfiable we can print the
862 // proof.
863 rewriter.setInsertionPointToStart(unsatIfOp.thenBlock());
864 auto proof = buildPtrAPICall(rewriter, op.getLoc(), "Z3_solver_get_proof",
865 {solver});
866 auto stringPtr =
867 buildPtrAPICall(rewriter, op.getLoc(), "Z3_ast_to_string", {proof});
868 auto formatString =
869 buildString(rewriter, op.getLoc(), getHeaderString("Proof"));
870 buildCall(rewriter, op.getLoc(), "printf", printfType,
871 {formatString, stringPtr});
872
873 // In debug mode, if the assertions are satisfiable we can print the model
874 // (effectively a counter-example).
875 rewriter.setInsertionPointToStart(satIfOp.thenBlock());
876 auto model = buildPtrAPICall(rewriter, op.getLoc(), "Z3_solver_get_model",
877 {solver});
878 auto modelStringPtr =
879 buildPtrAPICall(rewriter, op.getLoc(), "Z3_model_to_string", {model});
880 auto modelFormatString =
881 buildString(rewriter, op.getLoc(), getHeaderString("Model"));
882 buildCall(rewriter, op.getLoc(), "printf", printfType,
883 {modelFormatString, modelStringPtr});
884 }
885
886 return success();
887 }
888};
889
890/// Lower `smt.forall` and `smt.exists` operations to the following Z3 API call.
891/// ```
892/// Z3_ast Z3_API Z3_mk_{forall|exists}_const(
893/// Z3_context c,
894/// unsigned weight,
895/// unsigned num_bound,
896/// Z3_app const bound[],
897/// unsigned num_patterns,
898/// Z3_pattern const patterns[],
899/// Z3_ast body
900/// );
901/// ```
902/// All nested regions are inlined into the parent region and the block
903/// arguments are replaced with new `smt.declare_fun` constants that are also
904/// passed to the `bound` argument of above API function. Patterns are created
905/// with the following API function.
906/// ```
907/// Z3_pattern Z3_API Z3_mk_pattern(Z3_context c, unsigned num_patterns,
908/// Z3_ast const terms[]);
909/// ```
910/// Where each operand of the `smt.yield` in a pattern region is a 'term'.
911template <typename QuantifierOp>
912struct QuantifierLowering : public SMTLoweringPattern<QuantifierOp> {
913 using SMTLoweringPattern<QuantifierOp>::SMTLoweringPattern;
914 using SMTLoweringPattern<QuantifierOp>::typeConverter;
915 using SMTLoweringPattern<QuantifierOp>::buildPtrAPICall;
916 using OpAdaptor = typename QuantifierOp::Adaptor;
917
918 Value createStorageForValueList(ValueRange values, Location loc,
919 ConversionPatternRewriter &rewriter) const {
920 Type ptrTy = LLVM::LLVMPointerType::get(rewriter.getContext());
921 Type arrTy = LLVM::LLVMArrayType::get(ptrTy, values.size());
922 Value constOne =
923 rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI32Type(), 1);
924 Value storage =
925 rewriter.create<LLVM::AllocaOp>(loc, ptrTy, arrTy, constOne);
926 Value array = rewriter.create<LLVM::UndefOp>(loc, arrTy);
927
928 for (auto [i, val] : llvm::enumerate(values))
929 array = rewriter.create<LLVM::InsertValueOp>(loc, array, val,
930 ArrayRef<int64_t>(i));
931
932 rewriter.create<LLVM::StoreOp>(loc, array, storage);
933
934 return storage;
935 }
936
937 LogicalResult
938 matchAndRewrite(QuantifierOp op, OpAdaptor adaptor,
939 ConversionPatternRewriter &rewriter) const final {
940 Location loc = op.getLoc();
941 Type ptrTy = LLVM::LLVMPointerType::get(rewriter.getContext());
942
943 // no-pattern attribute not supported yet because the Z3 CAPI allows more
944 // fine-grained control where a list of patterns to be banned can be given.
945 // This means, the no-pattern attribute is equivalent to providing a list of
946 // all possible sub-expressions in the quantifier body to the CAPI.
947 if (adaptor.getNoPattern())
948 return rewriter.notifyMatchFailure(
949 op, "no-pattern attribute not yet supported!");
950
951 rewriter.setInsertionPoint(op);
952
953 // Weight attribute
954 Value weight = rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI32Type(),
955 adaptor.getWeight());
956
957 // Bound variables
958 unsigned numDecls = op.getBody().getNumArguments();
959 Value numDeclsVal =
960 rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI32Type(), numDecls);
961
962 // We replace the block arguments with constant symbolic values and inform
963 // the quantifier API call which constants it should treat as bound
964 // variables. We also need to make sure that we use the exact same SSA
965 // values in the pattern regions since we lower constant declaration
966 // operation to always produce fresh constants.
967 SmallVector<Value> repl;
968 for (auto [i, arg] : llvm::enumerate(op.getBody().getArguments())) {
969 Value newArg;
970 if (adaptor.getBoundVarNames().has_value())
971 newArg = rewriter.create<smt::DeclareFunOp>(
972 loc, arg.getType(),
973 cast<StringAttr>((*adaptor.getBoundVarNames())[i]));
974 else
975 newArg = rewriter.create<smt::DeclareFunOp>(loc, arg.getType());
976 repl.push_back(typeConverter->materializeTargetConversion(
977 rewriter, loc, typeConverter->convertType(arg.getType()), newArg));
978 }
979
980 Value boundStorage = createStorageForValueList(repl, loc, rewriter);
981
982 // Body Expression
983 auto yieldOp = cast<smt::YieldOp>(op.getBody().front().getTerminator());
984 Value bodyExp = yieldOp.getValues()[0];
985 rewriter.setInsertionPointAfterValue(bodyExp);
986 bodyExp = typeConverter->materializeTargetConversion(
987 rewriter, loc, typeConverter->convertType(bodyExp.getType()), bodyExp);
988 rewriter.eraseOp(yieldOp);
989
990 rewriter.inlineBlockBefore(&op.getBody().front(), op, repl);
991 rewriter.setInsertionPoint(op);
992
993 // Patterns
994 unsigned numPatterns = adaptor.getPatterns().size();
995 Value numPatternsVal = rewriter.create<LLVM::ConstantOp>(
996 loc, rewriter.getI32Type(), numPatterns);
997
998 Value patternStorage;
999 if (numPatterns > 0) {
1000 SmallVector<Value> patterns;
1001 for (Region *patternRegion : adaptor.getPatterns()) {
1002 auto yieldOp =
1003 cast<smt::YieldOp>(patternRegion->front().getTerminator());
1004 auto patternTerms = yieldOp.getOperands();
1005
1006 rewriter.setInsertionPoint(yieldOp);
1007 SmallVector<Value> patternList;
1008 for (auto val : patternTerms)
1009 patternList.push_back(typeConverter->materializeTargetConversion(
1010 rewriter, loc, typeConverter->convertType(val.getType()), val));
1011
1012 rewriter.eraseOp(yieldOp);
1013 rewriter.inlineBlockBefore(&patternRegion->front(), op, repl);
1014
1015 rewriter.setInsertionPoint(op);
1016 Value numTerms = rewriter.create<LLVM::ConstantOp>(
1017 loc, rewriter.getI32Type(), patternTerms.size());
1018 Value patternTermStorage =
1019 createStorageForValueList(patternList, loc, rewriter);
1020 Value pattern = buildPtrAPICall(rewriter, loc, "Z3_mk_pattern",
1021 {numTerms, patternTermStorage});
1022
1023 patterns.emplace_back(pattern);
1024 }
1025 patternStorage = createStorageForValueList(patterns, loc, rewriter);
1026 } else {
1027 // If we set the num_patterns parameter to 0, we can just pass a nullptr
1028 // as storage.
1029 patternStorage = rewriter.create<LLVM::ZeroOp>(loc, ptrTy);
1030 }
1031
1032 StringRef apiCallName = "Z3_mk_forall_const";
1033 if (std::is_same_v<QuantifierOp, ExistsOp>)
1034 apiCallName = "Z3_mk_exists_const";
1035 Value quantifierExp =
1036 buildPtrAPICall(rewriter, loc, apiCallName,
1037 {weight, numDeclsVal, boundStorage, numPatternsVal,
1038 patternStorage, bodyExp});
1039
1040 rewriter.replaceOp(op, quantifierExp);
1041 return success();
1042 }
1043};
1044
1045/// Lower `smt.bv.repeat` operations to Z3 API function calls of the form
1046/// ```
1047/// Z3_ast Z3_API Z3_mk_repeat(Z3_context c, unsigned i, Z3_ast t1);
1048/// ```
1049struct RepeatOpLowering : public SMTLoweringPattern<RepeatOp> {
1050 using SMTLoweringPattern::SMTLoweringPattern;
1051
1052 LogicalResult
1053 matchAndRewrite(RepeatOp op, OpAdaptor adaptor,
1054 ConversionPatternRewriter &rewriter) const final {
1055 Value count = rewriter.create<LLVM::ConstantOp>(
1056 op.getLoc(), rewriter.getI32Type(), op.getCount());
1057 rewriter.replaceOp(op,
1058 buildPtrAPICall(rewriter, op.getLoc(), "Z3_mk_repeat",
1059 {count, adaptor.getInput()}));
1060 return success();
1061 }
1062};
1063
1064/// Lower `smt.bv.extract` operations to Z3 API function calls of the following
1065/// form, where the output bit-vector has size `n = high - low + 1`. This means,
1066/// both the 'high' and 'low' indices are inclusive.
1067/// ```
1068/// Z3_ast Z3_API Z3_mk_extract(Z3_context c, unsigned high, unsigned low,
1069/// Z3_ast t1);
1070/// ```
1071struct ExtractOpLowering : public SMTLoweringPattern<ExtractOp> {
1072 using SMTLoweringPattern::SMTLoweringPattern;
1073
1074 LogicalResult
1075 matchAndRewrite(ExtractOp op, OpAdaptor adaptor,
1076 ConversionPatternRewriter &rewriter) const final {
1077 Location loc = op.getLoc();
1078 Value low = rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI32Type(),
1079 adaptor.getLowBit());
1080 Value high = rewriter.create<LLVM::ConstantOp>(
1081 loc, rewriter.getI32Type(),
1082 adaptor.getLowBit() + op.getType().getWidth() - 1);
1083 rewriter.replaceOp(op, buildPtrAPICall(rewriter, loc, "Z3_mk_extract",
1084 {high, low, adaptor.getInput()}));
1085 return success();
1086 }
1087};
1088
1089/// Lower `smt.array.broadcast` operations to Z3 API function calls of the form
1090/// ```
1091/// Z3_ast Z3_API Z3_mk_const_array(Z3_context c, Z3_sort domain, Z3_ast v);
1092/// ```
1093struct ArrayBroadcastOpLowering
1094 : public SMTLoweringPattern<smt::ArrayBroadcastOp> {
1095 using SMTLoweringPattern::SMTLoweringPattern;
1096
1097 LogicalResult
1098 matchAndRewrite(smt::ArrayBroadcastOp op, OpAdaptor adaptor,
1099 ConversionPatternRewriter &rewriter) const final {
1100 auto domainSort = buildSort(
1101 rewriter, op.getLoc(),
1102 cast<smt::ArrayType>(op.getResult().getType()).getDomainType());
1103
1104 rewriter.replaceOp(op, buildPtrAPICall(rewriter, op.getLoc(),
1105 "Z3_mk_const_array",
1106 {domainSort, adaptor.getValue()}));
1107 return success();
1108 }
1109};
1110
1111/// Lower the `smt.constant` operation to one of the following Z3 API function
1112/// calls depending on the value of the boolean attribute.
1113/// ```
1114/// Z3_ast Z3_API Z3_mk_true(Z3_context c);
1115/// Z3_ast Z3_API Z3_mk_false(Z3_context c);
1116/// ```
1117struct BoolConstantOpLowering : public SMTLoweringPattern<smt::BoolConstantOp> {
1118 using SMTLoweringPattern::SMTLoweringPattern;
1119
1120 LogicalResult
1121 matchAndRewrite(smt::BoolConstantOp op, OpAdaptor adaptor,
1122 ConversionPatternRewriter &rewriter) const final {
1123 rewriter.replaceOp(
1124 op, buildPtrAPICall(rewriter, op.getLoc(),
1125 adaptor.getValue() ? "Z3_mk_true" : "Z3_mk_false"));
1126 return success();
1127 }
1128};
1129
1130/// Lower `smt.int.constant` operations to one of the following two Z3 API
1131/// function calls depending on whether the storage APInt has a bit-width that
1132/// fits in a `uint64_t`.
1133/// ```
1134/// Z3_sort Z3_API Z3_mk_int_sort(Z3_context c);
1135///
1136/// Z3_ast Z3_API Z3_mk_int64(Z3_context c, int64_t v, Z3_sort ty);
1137///
1138/// Z3_ast Z3_API Z3_mk_numeral(Z3_context c, Z3_string numeral, Z3_sort ty);
1139/// Z3_ast Z3_API Z3_mk_unary_minus(Z3_context c, Z3_ast arg);
1140/// ```
1141struct IntConstantOpLowering : public SMTLoweringPattern<smt::IntConstantOp> {
1142 using SMTLoweringPattern::SMTLoweringPattern;
1143
1144 LogicalResult
1145 matchAndRewrite(smt::IntConstantOp op, OpAdaptor adaptor,
1146 ConversionPatternRewriter &rewriter) const final {
1147 Location loc = op.getLoc();
1148 Value type = buildPtrAPICall(rewriter, loc, "Z3_mk_int_sort");
1149 if (adaptor.getValue().getBitWidth() <= 64) {
1150 Value val = rewriter.create<LLVM::ConstantOp>(
1151 loc, rewriter.getI64Type(), adaptor.getValue().getSExtValue());
1152 rewriter.replaceOp(
1153 op, buildPtrAPICall(rewriter, loc, "Z3_mk_int64", {val, type}));
1154 return success();
1155 }
1156
1157 std::string numeralStr;
1158 llvm::raw_string_ostream stream(numeralStr);
1159 stream << adaptor.getValue().abs();
1160
1161 Value numeral = buildString(rewriter, loc, numeralStr);
1162 Value intNumeral =
1163 buildPtrAPICall(rewriter, loc, "Z3_mk_numeral", {numeral, type});
1164
1165 if (adaptor.getValue().isNegative())
1166 intNumeral =
1167 buildPtrAPICall(rewriter, loc, "Z3_mk_unary_minus", intNumeral);
1168
1169 rewriter.replaceOp(op, intNumeral);
1170 return success();
1171 }
1172};
1173
1174/// Lower `smt.int.cmp` operations to one of the following Z3 API function calls
1175/// depending on the predicate.
1176/// ```
1177/// Z3_ast Z3_API Z3_mk_{{pred}}(Z3_context c, Z3_ast t1, Z3_ast t2);
1178/// ```
1179struct IntCmpOpLowering : public SMTLoweringPattern<IntCmpOp> {
1180 using SMTLoweringPattern::SMTLoweringPattern;
1181
1182 LogicalResult
1183 matchAndRewrite(IntCmpOp op, OpAdaptor adaptor,
1184 ConversionPatternRewriter &rewriter) const final {
1185 rewriter.replaceOp(
1186 op,
1187 buildPtrAPICall(rewriter, op.getLoc(),
1188 "Z3_mk_" + stringifyIntPredicate(op.getPred()).str(),
1189 {adaptor.getLhs(), adaptor.getRhs()}));
1190 return success();
1191 }
1192};
1193
1194/// Lower `smt.int2bv` operations to the following Z3 API function calls.
1195/// ```
1196/// Z3_ast Z3_API Z3_mk_int2bv(Z3_context c, unsigned n, Z3_ast t1);
1197/// ```
1198struct Int2BVOpLowering : public SMTLoweringPattern<Int2BVOp> {
1199 using SMTLoweringPattern::SMTLoweringPattern;
1200
1201 LogicalResult
1202 matchAndRewrite(Int2BVOp op, OpAdaptor adaptor,
1203 ConversionPatternRewriter &rewriter) const final {
1204 Value widthConst =
1205 rewriter.create<LLVM::ConstantOp>(op->getLoc(), rewriter.getI32Type(),
1206 op.getResult().getType().getWidth());
1207 rewriter.replaceOp(op,
1208 buildPtrAPICall(rewriter, op.getLoc(), "Z3_mk_int2bv",
1209 {widthConst, adaptor.getInput()}));
1210 return success();
1211 }
1212};
1213
1214/// Lower `smt.bv2int` operations to the following Z3 API function call.
1215/// ```
1216/// Z3_ast Z3_API Z3_mk_bv2int(Z3_context c, Z3_ast t1, bool is_signed)
1217/// ```
1218struct BV2IntOpLowering : public SMTLoweringPattern<BV2IntOp> {
1219 using SMTLoweringPattern::SMTLoweringPattern;
1220
1221 LogicalResult
1222 matchAndRewrite(BV2IntOp op, OpAdaptor adaptor,
1223 ConversionPatternRewriter &rewriter) const final {
1224 // FIXME: ideally we don't want to use i1 here, since bools can sometimes be
1225 // compiled to wider widths in LLVM
1226 Value isSignedConst = rewriter.create<LLVM::ConstantOp>(
1227 op->getLoc(), rewriter.getI1Type(), op.getIsSigned());
1228 rewriter.replaceOp(op,
1229 buildPtrAPICall(rewriter, op.getLoc(), "Z3_mk_bv2int",
1230 {adaptor.getInput(), isSignedConst}));
1231 return success();
1232 }
1233};
1234
1235/// Lower `smt.bv.cmp` operations to one of the following Z3 API function calls,
1236/// performing two's complement comparison, depending on the predicate
1237/// attribute.
1238/// ```
1239/// Z3_ast Z3_API Z3_mk_bv{{pred}}(Z3_context c, Z3_ast t1, Z3_ast t2);
1240/// ```
1241struct BVCmpOpLowering : public SMTLoweringPattern<BVCmpOp> {
1242 using SMTLoweringPattern::SMTLoweringPattern;
1243
1244 LogicalResult
1245 matchAndRewrite(BVCmpOp op, OpAdaptor adaptor,
1246 ConversionPatternRewriter &rewriter) const final {
1247 rewriter.replaceOp(
1248 op, buildPtrAPICall(rewriter, op.getLoc(),
1249 "Z3_mk_bv" +
1250 stringifyBVCmpPredicate(op.getPred()).str(),
1251 {adaptor.getLhs(), adaptor.getRhs()}));
1252 return success();
1253 }
1254};
1255
1256/// Expand the `smt.int.abs` operation to a `smt.ite` operation.
1257struct IntAbsOpLowering : public SMTLoweringPattern<IntAbsOp> {
1258 using SMTLoweringPattern::SMTLoweringPattern;
1259
1260 LogicalResult
1261 matchAndRewrite(IntAbsOp op, OpAdaptor adaptor,
1262 ConversionPatternRewriter &rewriter) const final {
1263 Location loc = op.getLoc();
1264 Value zero = rewriter.create<IntConstantOp>(
1265 loc, rewriter.getIntegerAttr(rewriter.getI1Type(), 0));
1266 Value cmp = rewriter.create<IntCmpOp>(loc, IntPredicate::lt,
1267 adaptor.getInput(), zero);
1268 Value neg = rewriter.create<IntSubOp>(loc, zero, adaptor.getInput());
1269 rewriter.replaceOpWithNewOp<IteOp>(op, cmp, neg, adaptor.getInput());
1270 return success();
1271 }
1272};
1273
1274} // namespace
1275
1276//===----------------------------------------------------------------------===//
1277// Pass Implementation
1278//===----------------------------------------------------------------------===//
1279
1280namespace {
1281struct LowerSMTToZ3LLVMPass
1282 : public circt::impl::LowerSMTToZ3LLVMBase<LowerSMTToZ3LLVMPass> {
1283 using Base::Base;
1284 void runOnOperation() override;
1285};
1286} // namespace
1287
1288void circt::populateSMTToZ3LLVMTypeConverter(TypeConverter &converter) {
1289 converter.addConversion([](smt::BoolType type) {
1290 return LLVM::LLVMPointerType::get(type.getContext());
1291 });
1292 converter.addConversion([](smt::BitVectorType type) {
1293 return LLVM::LLVMPointerType::get(type.getContext());
1294 });
1295 converter.addConversion([](smt::ArrayType type) {
1296 return LLVM::LLVMPointerType::get(type.getContext());
1297 });
1298 converter.addConversion([](smt::IntType type) {
1299 return LLVM::LLVMPointerType::get(type.getContext());
1300 });
1301 converter.addConversion([](smt::SMTFuncType type) {
1302 return LLVM::LLVMPointerType::get(type.getContext());
1303 });
1304 converter.addConversion([](smt::SortType type) {
1305 return LLVM::LLVMPointerType::get(type.getContext());
1306 });
1307}
1308
1310 RewritePatternSet &patterns, TypeConverter &converter,
1311 SMTGlobalsHandler &globals, const LowerSMTToZ3LLVMOptions &options) {
1312#define ADD_VARIADIC_PATTERN(OP, APINAME, MIN_NUM_ARGS) \
1313 patterns.add<VariadicSMTPattern<OP>>(/*NOLINT(bugprone-macro-parentheses)*/ \
1314 converter, patterns.getContext(), \
1315 globals, options, APINAME, \
1316 MIN_NUM_ARGS);
1317
1318#define ADD_ONE_TO_ONE_PATTERN(OP, APINAME, NUM_ARGS) \
1319 patterns.add<OneToOneSMTPattern<OP>>(/*NOLINT(bugprone-macro-parentheses)*/ \
1320 converter, patterns.getContext(), \
1321 globals, options, APINAME, NUM_ARGS);
1322
1323 // Lower `smt.distinct` operations which allows a variadic number of operands
1324 // according to the `:pairwise` attribute. The Z3 API function supports a
1325 // variadic number of operands as well, i.e., a direct lowering is possible:
1326 // ```
1327 // Z3_ast Z3_API Z3_mk_distinct(Z3_context c, unsigned num_args, Z3_ast const
1328 // args[])
1329 // ```
1330 // The API function requires num_args > 1 which is guaranteed to be satisfied
1331 // because `smt.distinct` is verified to have > 1 operands.
1332 ADD_VARIADIC_PATTERN(DistinctOp, "Z3_mk_distinct", 2);
1333
1334 // Lower `smt.and` operations which allows a variadic number of operands
1335 // according to the `:left-assoc` attribute. The Z3 API function supports a
1336 // variadic number of operands as well, i.e., a direct lowering is possible:
1337 // ```
1338 // Z3_ast Z3_API Z3_mk_and(Z3_context c, unsigned num_args, Z3_ast const
1339 // args[])
1340 // ```
1341 // The API function requires num_args > 1. This is not guaranteed by the
1342 // `smt.and` operation and thus the pattern will not apply when no operand is
1343 // present. The constant folder of the operation is assumed to fold this to
1344 // a constant 'true' (neutral element of AND).
1345 ADD_VARIADIC_PATTERN(AndOp, "Z3_mk_and", 2);
1346
1347 // Lower `smt.or` operations which allows a variadic number of operands
1348 // according to the `:left-assoc` attribute. The Z3 API function supports a
1349 // variadic number of operands as well, i.e., a direct lowering is possible:
1350 // ```
1351 // Z3_ast Z3_API Z3_mk_or(Z3_context c, unsigned num_args, Z3_ast const
1352 // args[])
1353 // ```
1354 // The API function requires num_args > 1. This is not guaranteed by the
1355 // `smt.or` operation and thus the pattern will not apply when no operand is
1356 // present. The constant folder of the operation is assumed to fold this to
1357 // a constant 'false' (neutral element of OR).
1358 ADD_VARIADIC_PATTERN(OrOp, "Z3_mk_or", 2);
1359
1360 // Lower `smt.not` operations to the following Z3 API function:
1361 // ```
1362 // Z3_ast Z3_API Z3_mk_not(Z3_context c, Z3_ast a);
1363 // ```
1364 ADD_ONE_TO_ONE_PATTERN(NotOp, "Z3_mk_not", 1);
1365
1366 // Lower `smt.xor` operations which allows a variadic number of operands
1367 // according to the `:left-assoc` attribute. The Z3 API function, however,
1368 // only takes two operands.
1369 // ```
1370 // Z3_ast Z3_API Z3_mk_xor(Z3_context c, Z3_ast t1, Z3_ast t2);
1371 // ```
1372 // Therefore, we need to decompose the operation first to a sequence of XOR
1373 // operations matching the left associative behavior.
1374 patterns.add<LowerLeftAssocSMTPattern<XOrOp>>(
1375 converter, patterns.getContext(), globals, options);
1376 ADD_ONE_TO_ONE_PATTERN(XOrOp, "Z3_mk_xor", 2);
1377
1378 // Lower `smt.implies` operations to the following Z3 API function:
1379 // ```
1380 // Z3_ast Z3_API Z3_mk_implies(Z3_context c, Z3_ast t1, Z3_ast t2);
1381 // ```
1382 ADD_ONE_TO_ONE_PATTERN(ImpliesOp, "Z3_mk_implies", 2);
1383
1384 // All the bit-vector arithmetic and bitwise operations conveniently lower to
1385 // Z3 API function calls with essentially matching names and a one-to-one
1386 // correspondence of operands to call arguments.
1387 ADD_ONE_TO_ONE_PATTERN(BVNegOp, "Z3_mk_bvneg", 1);
1388 ADD_ONE_TO_ONE_PATTERN(BVAddOp, "Z3_mk_bvadd", 2);
1389 ADD_ONE_TO_ONE_PATTERN(BVMulOp, "Z3_mk_bvmul", 2);
1390 ADD_ONE_TO_ONE_PATTERN(BVURemOp, "Z3_mk_bvurem", 2);
1391 ADD_ONE_TO_ONE_PATTERN(BVSRemOp, "Z3_mk_bvsrem", 2);
1392 ADD_ONE_TO_ONE_PATTERN(BVSModOp, "Z3_mk_bvsmod", 2);
1393 ADD_ONE_TO_ONE_PATTERN(BVUDivOp, "Z3_mk_bvudiv", 2);
1394 ADD_ONE_TO_ONE_PATTERN(BVSDivOp, "Z3_mk_bvsdiv", 2);
1395 ADD_ONE_TO_ONE_PATTERN(BVShlOp, "Z3_mk_bvshl", 2);
1396 ADD_ONE_TO_ONE_PATTERN(BVLShrOp, "Z3_mk_bvlshr", 2);
1397 ADD_ONE_TO_ONE_PATTERN(BVAShrOp, "Z3_mk_bvashr", 2);
1398 ADD_ONE_TO_ONE_PATTERN(BVNotOp, "Z3_mk_bvnot", 1);
1399 ADD_ONE_TO_ONE_PATTERN(BVAndOp, "Z3_mk_bvand", 2);
1400 ADD_ONE_TO_ONE_PATTERN(BVOrOp, "Z3_mk_bvor", 2);
1401 ADD_ONE_TO_ONE_PATTERN(BVXOrOp, "Z3_mk_bvxor", 2);
1402
1403 // The `smt.bv.concat` operation only supports two operands, just like the
1404 // Z3 API function.
1405 // ```
1406 // Z3_ast Z3_API Z3_mk_concat(Z3_context c, Z3_ast t1, Z3_ast t2);
1407 // ```
1408 ADD_ONE_TO_ONE_PATTERN(ConcatOp, "Z3_mk_concat", 2);
1409
1410 // Lower the `smt.ite` operation to the following Z3 API function call, where
1411 // `t1` must have boolean sort.
1412 // ```
1413 // Z3_ast Z3_API Z3_mk_ite(Z3_context c, Z3_ast t1, Z3_ast t2, Z3_ast t3);
1414 // ```
1415 ADD_ONE_TO_ONE_PATTERN(IteOp, "Z3_mk_ite", 3);
1416
1417 // Lower the `smt.array.select` operation to the following Z3 function call.
1418 // The operand declaration of the operation matches the order of arguments of
1419 // the API function.
1420 // ```
1421 // Z3_ast Z3_API Z3_mk_select(Z3_context c, Z3_ast a, Z3_ast i);
1422 // ```
1423 // Where `a` is the array expression and `i` is the index expression.
1424 ADD_ONE_TO_ONE_PATTERN(ArraySelectOp, "Z3_mk_select", 2);
1425
1426 // Lower the `smt.array.store` operation to the following Z3 function call.
1427 // The operand declaration of the operation matches the order of arguments of
1428 // the API function.
1429 // ```
1430 // Z3_ast Z3_API Z3_mk_store(Z3_context c, Z3_ast a, Z3_ast i, Z3_ast v);
1431 // ```
1432 // Where `a` is the array expression, `i` is the index expression, and `v` is
1433 // the value expression to be stored.
1434 ADD_ONE_TO_ONE_PATTERN(ArrayStoreOp, "Z3_mk_store", 3);
1435
1436 // Lower the `smt.int.add` operation to the following Z3 API function call.
1437 // ```
1438 // Z3_ast Z3_API Z3_mk_add(Z3_context c, unsigned num_args, Z3_ast const
1439 // args[]);
1440 // ```
1441 // The number of arguments must be greater than zero. Therefore, the pattern
1442 // will fail if applied to an operation with less than two operands.
1443 ADD_VARIADIC_PATTERN(IntAddOp, "Z3_mk_add", 2);
1444
1445 // Lower the `smt.int.mul` operation to the following Z3 API function call.
1446 // ```
1447 // Z3_ast Z3_API Z3_mk_mul(Z3_context c, unsigned num_args, Z3_ast const
1448 // args[]);
1449 // ```
1450 // The number of arguments must be greater than zero. Therefore, the pattern
1451 // will fail if applied to an operation with less than two operands.
1452 ADD_VARIADIC_PATTERN(IntMulOp, "Z3_mk_mul", 2);
1453
1454 // Lower the `smt.int.sub` operation to the following Z3 API function call.
1455 // ```
1456 // Z3_ast Z3_API Z3_mk_sub(Z3_context c, unsigned num_args, Z3_ast const
1457 // args[]);
1458 // ```
1459 // The number of arguments must be greater than zero. Since the `smt.int.sub`
1460 // operation always has exactly two operands, this trivially holds.
1461 ADD_VARIADIC_PATTERN(IntSubOp, "Z3_mk_sub", 2);
1462
1463 // Lower the `smt.int.div` operation to the following Z3 API function call.
1464 // ```
1465 // Z3_ast Z3_API Z3_mk_div(Z3_context c, Z3_ast arg1, Z3_ast arg2);
1466 // ```
1467 ADD_ONE_TO_ONE_PATTERN(IntDivOp, "Z3_mk_div", 2);
1468
1469 // Lower the `smt.int.mod` operation to the following Z3 API function call.
1470 // ```
1471 // Z3_ast Z3_API Z3_mk_mod(Z3_context c, Z3_ast arg1, Z3_ast arg2);
1472 // ```
1473 ADD_ONE_TO_ONE_PATTERN(IntModOp, "Z3_mk_mod", 2);
1474
1475#undef ADD_VARIADIC_PATTERN
1476#undef ADD_ONE_TO_ONE_PATTERN
1477
1478 // Lower `smt.eq` operations which allows a variadic number of operands
1479 // according to the `:chainable` attribute. The Z3 API function does not
1480 // support a variadic number of operands, but exactly two:
1481 // ```
1482 // Z3_ast Z3_API Z3_mk_eq(Z3_context c, Z3_ast l, Z3_ast r)
1483 // ```
1484 // As a result, we first apply a rewrite pattern that unfolds chainable
1485 // operators and then lower it one-to-one to the API function. In this case,
1486 // this means:
1487 // ```
1488 // eq(a,b,c,d) ->
1489 // and(eq(a,b), eq(b,c), eq(c,d)) ->
1490 // and(Z3_mk_eq(ctx, a, b), Z3_mk_eq(ctx, b, c), Z3_mk_eq(ctx, c, d))
1491 // ```
1492 // The patterns for `smt.and` will then do the remaining work.
1493 patterns.add<LowerChainableSMTPattern<EqOp>>(converter, patterns.getContext(),
1494 globals, options);
1495 patterns.add<OneToOneSMTPattern<EqOp>>(converter, patterns.getContext(),
1496 globals, options, "Z3_mk_eq", 2);
1497
1498 // Other lowering patterns. Refer to their implementation directly for more
1499 // information.
1500 patterns.add<BVConstantOpLowering, DeclareFunOpLowering, AssertOpLowering,
1501 ResetOpLowering, PushOpLowering, PopOpLowering, CheckOpLowering,
1502 SolverOpLowering, ApplyFuncOpLowering, YieldOpLowering,
1503 RepeatOpLowering, ExtractOpLowering, BoolConstantOpLowering,
1504 IntConstantOpLowering, ArrayBroadcastOpLowering, BVCmpOpLowering,
1505 IntCmpOpLowering, IntAbsOpLowering, Int2BVOpLowering,
1506 BV2IntOpLowering, QuantifierLowering<ForallOp>,
1507 QuantifierLowering<ExistsOp>>(converter, patterns.getContext(),
1508 globals, options);
1509}
1510
1511void LowerSMTToZ3LLVMPass::runOnOperation() {
1512 LowerSMTToZ3LLVMOptions options;
1513 options.debug = debug;
1514
1515 // Check that the lowering is possible
1516 // Specifically, check that the use of set-logic ops is valid for z3
1517 auto setLogicCheck = getOperation().walk([&](SolverOp solverOp)
1518 -> WalkResult {
1519 // Check that solver ops only contain one set-logic op and that they're at
1520 // the start of the body
1521 auto setLogicOps = solverOp.getBodyRegion().getOps<smt::SetLogicOp>();
1522 auto numSetLogicOps = std::distance(setLogicOps.begin(), setLogicOps.end());
1523 if (numSetLogicOps > 1) {
1524 return solverOp.emitError(
1525 "multiple set-logic operations found in one solver operation - Z3 "
1526 "only supports setting the logic once");
1527 }
1528 if (numSetLogicOps == 1)
1529 // Check the only ops before the set-logic op are ConstantLike
1530 for (auto &blockOp : solverOp.getBodyRegion().getOps()) {
1531 if (isa<smt::SetLogicOp>(blockOp))
1532 break;
1533 if (!blockOp.hasTrait<OpTrait::ConstantLike>()) {
1534 return solverOp.emitError("set-logic operation must be the first "
1535 "non-constant operation in a solver "
1536 "operation");
1537 }
1538 }
1539 return WalkResult::advance();
1540 });
1541 if (setLogicCheck.wasInterrupted())
1542 return signalPassFailure();
1543
1544 // Set up the type converter
1545 LLVMTypeConverter converter(&getContext());
1547
1548 RewritePatternSet patterns(&getContext());
1549
1550 // Populate the func to LLVM conversion patterns for two reasons:
1551 // * Typically functions are represented using `func.func` and including the
1552 // patterns to lower them here is more convenient for most lowering
1553 // pipelines (avoids running another pass).
1554 // * Already having `llvm.func` in the input or lowering `func.func` before
1555 // the SMT in the body leads to issues because the SCF conversion patterns
1556 // don't take the type converter into consideration and thus create blocks
1557 // with the old types for block arguments. However, the conversion happens
1558 // top-down and thus are assumed to be converted by the parent function op
1559 // which at that point would have already been lowered (and the blocks are
1560 // also not there when doing everything in one pass, i.e.,
1561 // `populateAnyFunctionOpInterfaceTypeConversionPattern` does not have any
1562 // effect as well). Are the SCF lowering patterns actually broken and should
1563 // take a type-converter?
1564 populateFuncToLLVMConversionPatterns(converter, patterns);
1565 arith::populateArithToLLVMConversionPatterns(converter, patterns);
1566
1567 // Populate SCF to CF and CF to LLVM lowering patterns because we create
1568 // `scf.if` operations in the lowering patterns for convenience (given the
1569 // above issue we might want to lower to LLVM directly; or fix upstream?)
1570 populateSCFToControlFlowConversionPatterns(patterns);
1571 mlir::cf::populateControlFlowToLLVMConversionPatterns(converter, patterns);
1572
1573 // Create the globals to store the context and solver and populate the SMT
1574 // lowering patterns.
1575 OpBuilder builder(&getContext());
1576 auto globals = SMTGlobalsHandler::create(builder, getOperation());
1577 populateSMTToZ3LLVMConversionPatterns(patterns, converter, globals, options);
1578
1579 // Do a full conversion. This assumes that all other dialects have been
1580 // lowered before this pass already.
1581 LLVMConversionTarget target(getContext());
1582 target.addLegalOp<mlir::ModuleOp>();
1583 target.addLegalOp<scf::YieldOp>();
1584
1585 if (failed(applyFullConversion(getOperation(), target, std::move(patterns))))
1586 return signalPassFailure();
1587}
static FIRRTLBaseType convertType(FIRRTLBaseType type)
Returns null type if no conversion is needed.
Definition DropConst.cpp:32
#define ADD_VARIADIC_PATTERN(OP, APINAME, MIN_NUM_ARGS)
#define ADD_ONE_TO_ONE_PATTERN(OP, APINAME, NUM_ARGS)
RewritePatternSet pattern
A namespace that is used to store existing names and generate new names in some scope within the IR.
Definition Namespace.h:30
void add(mlir::ModuleOp module)
Definition Namespace.h:48
StringRef newName(const Twine &name)
Return a unique name, derived from the input name, and add the new name to the internal namespace.
Definition Namespace.h:87
void addDefinitions(mlir::Operation *top)
Populate the symbol cache with all symbol-defining operations within the 'top' operation.
Definition SymCache.cpp:23
Default symbol cache implementation; stores associations between names (StringAttr's) to mlir::Operat...
Definition SymCache.h:85
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
void populateSMTToZ3LLVMTypeConverter(TypeConverter &converter)
Populate the given type converter with the SMT to LLVM type conversions.
void populateSMTToZ3LLVMConversionPatterns(RewritePatternSet &patterns, TypeConverter &converter, SMTGlobalsHandler &globals, const LowerSMTToZ3LLVMOptions &options)
Add the SMT to LLVM IR conversion patterns to 'patterns'.
Definition debug.py:1
Definition smt.py:1
A symbol cache for LLVM globals and functions relevant to SMT lowering patterns.
Definition SMTToZ3LLVM.h:25
static SMTGlobalsHandler create(OpBuilder &builder, ModuleOp module)
Creates the LLVM global operations to store the pointers to the solver and the context and returns a ...
SMTGlobalsHandler(ModuleOp module, mlir::LLVM::GlobalOp solver, mlir::LLVM::GlobalOp ctx)
Initializes the caches and keeps track of the given globals to store the pointers to the SMT solver a...