Loading [MathJax]/extensions/tex2jax.js
CIRCT 21.0.0git
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
LowerSMTToZ3LLVM.cpp
Go to the documentation of this file.
1//===- LowerSMTToZ3LLVM.cpp -----------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
11#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
12#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
13#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
14#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
15#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
16#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"
17#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
18#include "mlir/Dialect/Func/IR/FuncOps.h"
19#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
20#include "mlir/Dialect/LLVMIR/LLVMAttrs.h"
21#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
22#include "mlir/Dialect/SCF/IR/SCF.h"
23#include "mlir/Dialect/SMT/IR/SMTOps.h"
24#include "mlir/IR/BuiltinDialect.h"
25#include "mlir/Pass/Pass.h"
26#include "mlir/Transforms/DialectConversion.h"
27#include "llvm/ADT/TypeSwitch.h"
28#include "llvm/Support/Debug.h"
29
30#define DEBUG_TYPE "lower-smt-to-z3-llvm"
31
32namespace circt {
33#define GEN_PASS_DEF_LOWERSMTTOZ3LLVM
34#include "circt/Conversion/Passes.h.inc"
35} // namespace circt
36
37using namespace mlir;
38using namespace circt;
39using namespace smt;
40
41//===----------------------------------------------------------------------===//
42// SMTGlobalHandler implementation
43//===----------------------------------------------------------------------===//
44
46 ModuleOp module) {
// Emit everything at the very start of the module body. The guard restores
// the builder's previous insertion point when this function returns.
47 OpBuilder::InsertionGuard guard(builder);
48 builder.setInsertionPointToStart(module.getBody());
49
// Register the symbols already defined in 'module' with 'names' before any
// new name is requested from it below.
50 SymbolCache symCache;
51 symCache.addDefinitions(module);
53 names.add(symCache);
54
55 Location loc = module.getLoc();
56 auto ptrTy = LLVM::LLVMPointerType::get(builder.getContext());
57
// Helper creating one internal-linkage, pointer-typed LLVM global with
// alignment 8 whose initializer region returns a zero (null) pointer. The
// symbol name is requested from 'names' using 'namePrefix'. The literal
// 'false' is presumably the is-constant flag of the global -- confirm against
// the LLVM::GlobalOp builder.
58 auto createGlobal = [&](StringRef namePrefix) {
59 auto global = builder.create<LLVM::GlobalOp>(
60 loc, ptrTy, false, LLVM::Linkage::Internal, names.newName(namePrefix),
61 Attribute{}, /*alignment=*/8);
// Build the initializer region; the nested guard keeps the insertion point
// of the enclosing function untouched.
62 OpBuilder::InsertionGuard g(builder);
63 builder.createBlock(&global.getInitializer());
64 Value res = builder.create<LLVM::ZeroOp>(loc, ptrTy);
65 builder.create<LLVM::ReturnOp>(loc, res);
66 return global;
67 };
68
69 auto ctxGlobal = createGlobal("ctx");
70 auto solverGlobal = createGlobal("solver");
71
// Note the argument order: the solver global is passed before the context
// global.
72 return SMTGlobalsHandler(std::move(names), solverGlobal, ctxGlobal);
73}
74
76 mlir::LLVM::GlobalOp solver,
77 mlir::LLVM::GlobalOp ctx)
// Stores the given global handles and initializes the member 'names' from the
// parameter of the same name.
// NOTE(review): 'names(names)' copy-initializes the member even though the
// caller above passes std::move(names); the parameter declaration is not
// visible here -- confirm whether 'std::move(names)' was intended.
78 : solver(solver), ctx(ctx), names(names) {}
79
81 mlir::LLVM::GlobalOp solver,
82 mlir::LLVM::GlobalOp ctx)
83 : solver(solver), ctx(ctx) {
// Collect the symbols defined in 'module' and add them to 'names';
// presumably so that names generated later do not clash with existing
// symbols -- verify against the Namespace/SymbolCache documentation.
84 SymbolCache symCache;
85 symCache.addDefinitions(module);
86 names.add(symCache);
87}
88
89//===----------------------------------------------------------------------===//
90// Lowering Pattern Base
91//===----------------------------------------------------------------------===//
92
93namespace {
94
95template <typename OpTy>
96class SMTLoweringPattern : public OpConversionPattern<OpTy> {
97public:
98 SMTLoweringPattern(const TypeConverter &typeConverter, MLIRContext *context,
99 SMTGlobalsHandler &globals,
100 const LowerSMTToZ3LLVMOptions &options)
101 : OpConversionPattern<OpTy>(typeConverter, context), globals(globals),
102 options(options) {}
103
104private:
105 Value buildGlobalPtrToGlobal(OpBuilder &builder, Location loc,
106 LLVM::GlobalOp global,
107 DenseMap<Block *, Value> &cache) const {
108 Block *block = builder.getBlock();
109 if (auto iter = cache.find(block); iter != cache.end())
110 return iter->getSecond();
111
112 OpBuilder::InsertionGuard g(builder);
113 builder.setInsertionPointToStart(block);
114 Value globalAddr = builder.create<LLVM::AddressOfOp>(loc, global);
115 return cache[block] = builder.create<LLVM::LoadOp>(
116 loc, LLVM::LLVMPointerType::get(builder.getContext()),
117 globalAddr);
118 }
119
120protected:
121 /// A convenience function to get the pointer to the context from the 'global'
122 /// operation. The result is cached for each basic block, i.e., it is assumed
123 /// that this function is never called in the same basic block again at a
124 /// location (insertion point of the 'builder') not dominating all previous
125 /// locations this function was called at.
126 Value buildContextPtr(OpBuilder &builder, Location loc) const {
127 return buildGlobalPtrToGlobal(builder, loc, globals.ctx, globals.ctxCache);
128 }
129
130 /// A convenience function to get the pointer to the solver from the 'global'
131 /// operation. The result is cached for each basic block, i.e., it is assumed
132 /// that this function is never called in the same basic block again at a
133 /// location (insertion point of the 'builder') not dominating all previous
134 /// locations this function was called at.
135 Value buildSolverPtr(OpBuilder &builder, Location loc) const {
136 return buildGlobalPtrToGlobal(builder, loc, globals.solver,
137 globals.solverCache);
138 }
139
140 /// Create a `llvm.call` operation to a function with the given 'name' and
141 /// 'type'. If there does not already exist a (external) function with that
142 /// name create a matching external function declaration.
143 LLVM::CallOp buildCall(OpBuilder &builder, Location loc, StringRef name,
144 LLVM::LLVMFunctionType funcType,
145 ValueRange args) const {
146 auto &funcOp = globals.funcMap[builder.getStringAttr(name)];
147 if (!funcOp) {
148 OpBuilder::InsertionGuard guard(builder);
149 auto module =
150 builder.getBlock()->getParent()->getParentOfType<ModuleOp>();
151 builder.setInsertionPointToEnd(module.getBody());
152 auto funcOpResult = LLVM::lookupOrCreateFn(
153 builder, module, name, funcType.getParams(), funcType.getReturnType(),
154 funcType.getVarArg());
155 assert(succeeded(funcOpResult) && "expected to lookup or create printf");
156 funcOp = funcOpResult.value();
157 }
158 return builder.create<LLVM::CallOp>(loc, funcOp, args);
159 }
160
161 /// Build a global constant for the given string and construct an 'addressof'
162 /// operation at the current 'builder' insertion point to get a pointer to it.
163 /// Multiple calls with the same string will reuse the same global. It is
164 /// guaranteed that the symbol of the global will be unique.
165 Value buildString(OpBuilder &builder, Location loc, StringRef str) const {
166 auto &global = globals.stringCache[builder.getStringAttr(str)];
167 if (!global) {
168 OpBuilder::InsertionGuard guard(builder);
169 auto module =
170 builder.getBlock()->getParent()->getParentOfType<ModuleOp>();
171 builder.setInsertionPointToEnd(module.getBody());
172 auto arrayTy =
173 LLVM::LLVMArrayType::get(builder.getI8Type(), str.size() + 1);
174 auto strAttr = builder.getStringAttr(str.str() + '\00');
175 global = builder.create<LLVM::GlobalOp>(
176 loc, arrayTy, /*isConstant=*/true, LLVM::Linkage::Internal,
177 globals.names.newName("str"), strAttr);
178 }
179 return builder.create<LLVM::AddressOfOp>(loc, global);
180 }
181 /// Most API functions require a pointer to the the Z3 context object as the
182 /// first argument. This helper function prepends this pointer value to the
183 /// call for convenience.
184 LLVM::CallOp buildAPICallWithContext(OpBuilder &builder, Location loc,
185 StringRef name, Type returnType,
186 ValueRange args = {}) const {
187 auto ctx = buildContextPtr(builder, loc);
188 SmallVector<Value> arguments;
189 arguments.emplace_back(ctx);
190 arguments.append(SmallVector<Value>(args));
191 return buildCall(
192 builder, loc, name,
193 LLVM::LLVMFunctionType::get(
194 returnType, SmallVector<Type>(ValueRange(arguments).getTypes())),
195 arguments);
196 }
197
198 /// Most API functions we need to call return a 'Z3_AST' object which is a
199 /// pointer in LLVM. This helper function simplifies calling those API
200 /// functions.
201 Value buildPtrAPICall(OpBuilder &builder, Location loc, StringRef name,
202 ValueRange args = {}) const {
203 return buildAPICallWithContext(
204 builder, loc, name,
205 LLVM::LLVMPointerType::get(builder.getContext()), args)
206 ->getResult(0);
207 }
208
209 /// Build a value representing the SMT sort given with 'type'.
210 Value buildSort(OpBuilder &builder, Location loc, Type type) const {
211 // NOTE: if a type not handled by this switch is passed, an assertion will
212 // be triggered.
213 return TypeSwitch<Type, Value>(type)
214 .Case([&](smt::IntType ty) {
215 return buildPtrAPICall(builder, loc, "Z3_mk_int_sort");
216 })
217 .Case([&](smt::BitVectorType ty) {
218 Value bitwidth = builder.create<LLVM::ConstantOp>(
219 loc, builder.getI32Type(), ty.getWidth());
220 return buildPtrAPICall(builder, loc, "Z3_mk_bv_sort", {bitwidth});
221 })
222 .Case([&](smt::BoolType ty) {
223 return buildPtrAPICall(builder, loc, "Z3_mk_bool_sort");
224 })
225 .Case([&](smt::SortType ty) {
226 Value str = buildString(builder, loc, ty.getIdentifier());
227 Value sym =
228 buildPtrAPICall(builder, loc, "Z3_mk_string_symbol", {str});
229 return buildPtrAPICall(builder, loc, "Z3_mk_uninterpreted_sort",
230 {sym});
231 })
232 .Case([&](smt::ArrayType ty) {
233 return buildPtrAPICall(builder, loc, "Z3_mk_array_sort",
234 {buildSort(builder, loc, ty.getDomainType()),
235 buildSort(builder, loc, ty.getRangeType())});
236 });
237 }
238
239 SMTGlobalsHandler &globals;
240 const LowerSMTToZ3LLVMOptions &options;
241};
242
243//===----------------------------------------------------------------------===//
244// Lowering Patterns
245//===----------------------------------------------------------------------===//
246
247/// The 'smt.declare_fun' operation is used to declare both constants and
248/// functions. The Z3 API, however, uses two different functions. Therefore,
249/// depending on the result type of this operation, one of the following two
250/// API functions is used to create the symbolic value:
251/// ```
252/// Z3_ast Z3_API Z3_mk_fresh_const(Z3_context c, Z3_string prefix, Z3_sort ty);
253/// Z3_func_decl Z3_API Z3_mk_fresh_func_decl(
254/// Z3_context c, Z3_string prefix, unsigned domain_size,
255/// Z3_sort const domain[], Z3_sort range);
256/// ```
257struct DeclareFunOpLowering : public SMTLoweringPattern<DeclareFunOp> {
258 using SMTLoweringPattern::SMTLoweringPattern;
259
260 LogicalResult
261 matchAndRewrite(DeclareFunOp op, OpAdaptor adaptor,
262 ConversionPatternRewriter &rewriter) const final {
263 Location loc = op.getLoc();
264
265 // Create the name prefix.
266 Value prefix;
267 if (adaptor.getNamePrefix())
268 prefix = buildString(rewriter, loc, *adaptor.getNamePrefix());
269 else
270 prefix = rewriter.create<LLVM::ZeroOp>(
271 loc, LLVM::LLVMPointerType::get(getContext()));
272
273 // Handle the constant value case.
274 if (!isa<SMTFuncType>(op.getType())) {
275 Value sort = buildSort(rewriter, loc, op.getType());
276 Value constDecl =
277 buildPtrAPICall(rewriter, loc, "Z3_mk_fresh_const", {prefix, sort});
278 rewriter.replaceOp(op, constDecl);
279 return success();
280 }
281
282 // Otherwise, we declare a function.
283 Type llvmPtrTy = LLVM::LLVMPointerType::get(getContext());
284 auto funcType = cast<SMTFuncType>(op.getResult().getType());
285 Value rangeSort = buildSort(rewriter, loc, funcType.getRangeType());
286
287 Type arrTy =
288 LLVM::LLVMArrayType::get(llvmPtrTy, funcType.getDomainTypes().size());
289
290 Value domain = rewriter.create<LLVM::UndefOp>(loc, arrTy);
291 for (auto [i, ty] : llvm::enumerate(funcType.getDomainTypes())) {
292 Value sort = buildSort(rewriter, loc, ty);
293 domain = rewriter.create<LLVM::InsertValueOp>(loc, domain, sort, i);
294 }
295
296 Value one =
297 rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI32Type(), 1);
298 Value domainStorage =
299 rewriter.create<LLVM::AllocaOp>(loc, llvmPtrTy, arrTy, one);
300 rewriter.create<LLVM::StoreOp>(loc, domain, domainStorage);
301
302 Value domainSize = rewriter.create<LLVM::ConstantOp>(
303 loc, rewriter.getI32Type(), funcType.getDomainTypes().size());
304 Value decl =
305 buildPtrAPICall(rewriter, loc, "Z3_mk_fresh_func_decl",
306 {prefix, domainSize, domainStorage, rangeSort});
307
308 rewriter.replaceOp(op, decl);
309 return success();
310 }
311};
312
313/// Lower the 'smt.apply_func' operation to Z3 API calls of the form:
314/// ```
315/// Z3_ast Z3_API Z3_mk_app(Z3_context c, Z3_func_decl d,
316/// unsigned num_args, Z3_ast const args[]);
317/// ```
318struct ApplyFuncOpLowering : public SMTLoweringPattern<ApplyFuncOp> {
319 using SMTLoweringPattern::SMTLoweringPattern;
320
321 LogicalResult
322 matchAndRewrite(ApplyFuncOp op, OpAdaptor adaptor,
323 ConversionPatternRewriter &rewriter) const final {
324 Location loc = op.getLoc();
325 Type llvmPtrTy = LLVM::LLVMPointerType::get(getContext());
326 Type arrTy = LLVM::LLVMArrayType::get(llvmPtrTy, adaptor.getArgs().size());
327
328 // Create an array of the function arguments.
329 Value domain = rewriter.create<LLVM::UndefOp>(loc, arrTy);
330 for (auto [i, arg] : llvm::enumerate(adaptor.getArgs()))
331 domain = rewriter.create<LLVM::InsertValueOp>(loc, domain, arg, i);
332
333 // Store the array on the stack.
334 Value one =
335 rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI32Type(), 1);
336 Value domainStorage =
337 rewriter.create<LLVM::AllocaOp>(loc, llvmPtrTy, arrTy, one);
338 rewriter.create<LLVM::StoreOp>(loc, domain, domainStorage);
339
340 // Call the API function with a pointer to the function, the number of
341 // arguments, and the pointer to the arguments stored on the stack.
342 Value domainSize = rewriter.create<LLVM::ConstantOp>(
343 loc, rewriter.getI32Type(), adaptor.getArgs().size());
344 Value returnVal =
345 buildPtrAPICall(rewriter, loc, "Z3_mk_app",
346 {adaptor.getFunc(), domainSize, domainStorage});
347 rewriter.replaceOp(op, returnVal);
348
349 return success();
350 }
351};
352
353/// Lower the `smt.bv.constant` operation to either
354/// ```
355/// Z3_ast Z3_API Z3_mk_unsigned_int64(Z3_context c, uint64_t v, Z3_sort ty);
356/// ```
357/// if the bit-vector fits into a 64-bit integer or convert it to a string and
358/// use the sligtly slower but arbitrary precision API function:
359/// ```
360/// Z3_ast Z3_API Z3_mk_numeral(Z3_context c, Z3_string numeral, Z3_sort ty);
361/// ```
362/// Note that there is also an API function taking an array of booleans, and
363/// while those are typically compiled to 'i8' in LLVM they don't necessarily
364/// have to (I think).
365struct BVConstantOpLowering : public SMTLoweringPattern<smt::BVConstantOp> {
366 using SMTLoweringPattern::SMTLoweringPattern;
367
368 LogicalResult
369 matchAndRewrite(smt::BVConstantOp op, OpAdaptor adaptor,
370 ConversionPatternRewriter &rewriter) const final {
371 Location loc = op.getLoc();
372 unsigned width = op.getType().getWidth();
373 auto bvSort = buildSort(rewriter, loc, op.getResult().getType());
374 APInt val = adaptor.getValue().getValue();
375
376 if (width <= 64) {
377 Value bvConst = rewriter.create<LLVM::ConstantOp>(
378 loc, rewriter.getI64Type(), val.getZExtValue());
379 Value res = buildPtrAPICall(rewriter, loc, "Z3_mk_unsigned_int64",
380 {bvConst, bvSort});
381 rewriter.replaceOp(op, res);
382 return success();
383 }
384
385 std::string str;
386 llvm::raw_string_ostream stream(str);
387 stream << val;
388 Value bvString = buildString(rewriter, loc, str);
389 Value bvNumeral =
390 buildPtrAPICall(rewriter, loc, "Z3_mk_numeral", {bvString, bvSort});
391
392 rewriter.replaceOp(op, bvNumeral);
393 return success();
394 }
395};
396
397/// Some of the Z3 API supports a variadic number of operands for some
398/// operations (in particular if the expansion would lead to a super-linear
399/// increase in operations such as with the ':pairwise' attribute). Those API
400/// calls take an 'unsigned' argument indicating the size of an array of
401/// pointers to the operands.
402template <typename SourceTy>
403struct VariadicSMTPattern : public SMTLoweringPattern<SourceTy> {
404 using OpAdaptor = typename SMTLoweringPattern<SourceTy>::OpAdaptor;
405
406 VariadicSMTPattern(const TypeConverter &typeConverter, MLIRContext *context,
407 SMTGlobalsHandler &globals,
408 const LowerSMTToZ3LLVMOptions &options,
409 StringRef apiFuncName, unsigned minNumArgs)
410 : SMTLoweringPattern<SourceTy>(typeConverter, context, globals, options),
411 apiFuncName(apiFuncName), minNumArgs(minNumArgs) {}
412
413 LogicalResult
414 matchAndRewrite(SourceTy op, OpAdaptor adaptor,
415 ConversionPatternRewriter &rewriter) const final {
416 if (adaptor.getOperands().size() < minNumArgs)
417 return failure();
418
419 Location loc = op.getLoc();
420 Value numOperands = rewriter.create<LLVM::ConstantOp>(
421 loc, rewriter.getI32Type(), op->getNumOperands());
422 Value constOne =
423 rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI32Type(), 1);
424 Type ptrTy = LLVM::LLVMPointerType::get(rewriter.getContext());
425 Type arrTy = LLVM::LLVMArrayType::get(ptrTy, op->getNumOperands());
426 Value storage =
427 rewriter.create<LLVM::AllocaOp>(loc, ptrTy, arrTy, constOne);
428 Value array = rewriter.create<LLVM::UndefOp>(loc, arrTy);
429
430 for (auto [i, operand] : llvm::enumerate(adaptor.getOperands()))
431 array = rewriter.create<LLVM::InsertValueOp>(
432 loc, array, operand, ArrayRef<int64_t>{(int64_t)i});
433
434 rewriter.create<LLVM::StoreOp>(loc, array, storage);
435
436 rewriter.replaceOp(op,
437 SMTLoweringPattern<SourceTy>::buildPtrAPICall(
438 rewriter, loc, apiFuncName, {numOperands, storage}));
439 return success();
440 }
441
442private:
443 StringRef apiFuncName;
444 unsigned minNumArgs;
445};
446
447/// Lower an SMT operation to a function call with the name 'apiFuncName' with
448/// arguments matching the operands one-to-one.
449template <typename SourceTy>
450struct OneToOneSMTPattern : public SMTLoweringPattern<SourceTy> {
451 using OpAdaptor = typename SMTLoweringPattern<SourceTy>::OpAdaptor;
452
453 OneToOneSMTPattern(const TypeConverter &typeConverter, MLIRContext *context,
454 SMTGlobalsHandler &globals,
455 const LowerSMTToZ3LLVMOptions &options,
456 StringRef apiFuncName, unsigned numOperands)
457 : SMTLoweringPattern<SourceTy>(typeConverter, context, globals, options),
458 apiFuncName(apiFuncName), numOperands(numOperands) {}
459
460 LogicalResult
461 matchAndRewrite(SourceTy op, OpAdaptor adaptor,
462 ConversionPatternRewriter &rewriter) const final {
463 if (adaptor.getOperands().size() != numOperands)
464 return failure();
465
466 rewriter.replaceOp(
467 op, SMTLoweringPattern<SourceTy>::buildPtrAPICall(
468 rewriter, op.getLoc(), apiFuncName, adaptor.getOperands()));
469 return success();
470 }
471
472private:
473 StringRef apiFuncName;
474 unsigned numOperands;
475};
476
477/// A pattern to lower SMT operations with a variadic number of operands
478/// modelling the ':chainable' attribute in SMT to binary operations.
479template <typename SourceTy>
480class LowerChainableSMTPattern : public SMTLoweringPattern<SourceTy> {
481 using SMTLoweringPattern<SourceTy>::SMTLoweringPattern;
482 using OpAdaptor = typename SMTLoweringPattern<SourceTy>::OpAdaptor;
483
484 LogicalResult
485 matchAndRewrite(SourceTy op, OpAdaptor adaptor,
486 ConversionPatternRewriter &rewriter) const final {
487 if (adaptor.getOperands().size() <= 2)
488 return failure();
489
490 Location loc = op.getLoc();
491 SmallVector<Value> elements;
492 for (int i = 1, e = adaptor.getOperands().size(); i < e; ++i) {
493 Value val = rewriter.create<SourceTy>(
494 loc, op->getResultTypes(),
495 ValueRange{adaptor.getOperands()[i - 1], adaptor.getOperands()[i]});
496 elements.push_back(val);
497 }
498 rewriter.replaceOpWithNewOp<smt::AndOp>(op, elements);
499 return success();
500 }
501};
502
503/// A pattern to lower SMT operations with a variadic number of operands
504/// modelling the `:left-assoc` attribute to a sequence of binary operators.
505template <typename SourceTy>
506class LowerLeftAssocSMTPattern : public SMTLoweringPattern<SourceTy> {
507 using SMTLoweringPattern<SourceTy>::SMTLoweringPattern;
508 using OpAdaptor = typename SMTLoweringPattern<SourceTy>::OpAdaptor;
509
510 LogicalResult
511 matchAndRewrite(SourceTy op, OpAdaptor adaptor,
512 ConversionPatternRewriter &rewriter) const final {
513 if (adaptor.getOperands().size() <= 2)
514 return rewriter.notifyMatchFailure(op, "must have at least two operands");
515
516 Value runner = adaptor.getOperands()[0];
517 for (Value val : adaptor.getOperands().drop_front())
518 runner = rewriter.create<SourceTy>(op.getLoc(), op->getResultTypes(),
519 ValueRange{runner, val});
520
521 rewriter.replaceOp(op, runner);
522 return success();
523 }
524};
525
526/// The 'smt.solver' operation has a region that corresponds to the lifetime of
527/// the Z3 context and one solver instance created within this context.
528/// To create a context, a Z3 configuration has to be built first and various
529/// configuration parameters can be set before creating a context from it. Once
530/// we have a context, we can create a solver and store a pointer to the context
531/// and the solver in an LLVM global such that operations in the child region
532/// have access to them. While the context created with `Z3_mk_context` takes
533/// care of the reference counting of `Z3_AST` objects, it still requires manual
534/// reference counting of `Z3_solver` objects, therefore, we need to increase
535/// the ref. counter of the solver we get from `Z3_mk_solver` and must decrease
536/// it again once we don't need it anymore. Finally, the configuration object
537/// can be deleted.
538/// ```
539/// Z3_config Z3_API Z3_mk_config(void);
540/// void Z3_API Z3_set_param_value(Z3_config c, Z3_string param_id,
541/// Z3_string param_value);
542/// Z3_context Z3_API Z3_mk_context(Z3_config c);
543/// Z3_solver Z3_API Z3_mk_solver(Z3_context c);
544/// void Z3_API Z3_solver_inc_ref(Z3_context c, Z3_solver s);
545/// void Z3_API Z3_del_config(Z3_config c);
546/// ```
547/// At the end of the solver lifetime, we have to tell the context that we
548/// don't need the solver anymore and delete the context itself.
549/// ```
550/// void Z3_API Z3_solver_dec_ref(Z3_context c, Z3_solver s);
551/// void Z3_API Z3_del_context(Z3_context c);
552/// ```
553/// Note that the solver created here is a combined solver. There might be some
554/// potential for optimization by creating more specialized solvers supported by
555/// the Z3 API according the the kind of operations present in the body region.
556struct SolverOpLowering : public SMTLoweringPattern<SolverOp> {
557 using SMTLoweringPattern::SMTLoweringPattern;
558
559 LogicalResult
560 matchAndRewrite(SolverOp op, OpAdaptor adaptor,
561 ConversionPatternRewriter &rewriter) const final {
562 Location loc = op.getLoc();
563 auto ptrTy = LLVM::LLVMPointerType::get(getContext());
564 auto voidTy = LLVM::LLVMVoidType::get(getContext());
565 auto ptrToPtrFunc = LLVM::LLVMFunctionType::get(ptrTy, ptrTy);
566 auto ptrPtrToPtrFunc = LLVM::LLVMFunctionType::get(ptrTy, {ptrTy, ptrTy});
567 auto ptrToVoidFunc = LLVM::LLVMFunctionType::get(voidTy, ptrTy);
568 auto ptrPtrToVoidFunc = LLVM::LLVMFunctionType::get(voidTy, {ptrTy, ptrTy});
569
570 // Create the configuration.
571 Value config = buildCall(rewriter, loc, "Z3_mk_config",
572 LLVM::LLVMFunctionType::get(ptrTy, {}), {})
573 .getResult();
574
575 // In debug-mode, we enable proofs such that we can fetch one in the 'unsat'
576 // region of each 'smt.check' operation.
577 if (options.debug) {
578 Value paramKey = buildString(rewriter, loc, "proof");
579 Value paramValue = buildString(rewriter, loc, "true");
580 buildCall(rewriter, loc, "Z3_set_param_value",
581 LLVM::LLVMFunctionType::get(voidTy, {ptrTy, ptrTy, ptrTy}),
582 {config, paramKey, paramValue});
583 }
584
585 // Check if the logic is set anywhere within the solver
586 std::optional<StringRef> logic = std::nullopt;
587 auto setLogicOps = op.getBodyRegion().getOps<smt::SetLogicOp>();
588 if (!setLogicOps.empty()) {
589 // We know from before patterns were applied that there is only one
590 // set_logic op
591 auto setLogicOp = *setLogicOps.begin();
592 logic = setLogicOp.getLogic();
593 rewriter.eraseOp(setLogicOp);
594 }
595
596 // Create the context and store a pointer to it in the global variable.
597 Value ctx = buildCall(rewriter, loc, "Z3_mk_context", ptrToPtrFunc, config)
598 .getResult();
599 Value ctxAddr =
600 rewriter.create<LLVM::AddressOfOp>(loc, globals.ctx).getResult();
601 rewriter.create<LLVM::StoreOp>(loc, ctx, ctxAddr);
602
603 // Delete the configuration again.
604 buildCall(rewriter, loc, "Z3_del_config", ptrToVoidFunc, {config});
605
606 // Create a solver instance, increase its reference counter, and store a
607 // pointer to it in the global variable.
608 Value solver;
609 if (logic) {
610 auto logicStr = buildString(rewriter, loc, logic.value());
611 solver = buildCall(rewriter, loc, "Z3_mk_solver_for_logic",
612 ptrPtrToPtrFunc, {ctx, logicStr})
613 ->getResult(0);
614 } else {
615 solver = buildCall(rewriter, loc, "Z3_mk_solver", ptrToPtrFunc, ctx)
616 ->getResult(0);
617 }
618 buildCall(rewriter, loc, "Z3_solver_inc_ref", ptrPtrToVoidFunc,
619 {ctx, solver});
620 Value solverAddr =
621 rewriter.create<LLVM::AddressOfOp>(loc, globals.solver).getResult();
622 rewriter.create<LLVM::StoreOp>(loc, solver, solverAddr);
623
624 // This assumes that no constant hoisting of the like happens inbetween
625 // the patterns defined in this pass because once the solver initialization
626 // and deallocation calls are inserted and the body region is inlined,
627 // canonicalizations and folders applied inbetween lowering patterns might
628 // hoist the SMT constants which means they would access uninitialized
629 // global variables once they are lowered.
630 SmallVector<Type> convertedTypes;
631 if (failed(
632 typeConverter->convertTypes(op->getResultTypes(), convertedTypes)))
633 return failure();
634
635 func::FuncOp funcOp;
636 {
637 OpBuilder::InsertionGuard guard(rewriter);
638 auto module = op->getParentOfType<ModuleOp>();
639 rewriter.setInsertionPointToEnd(module.getBody());
640
641 funcOp = rewriter.create<func::FuncOp>(
642 loc, globals.names.newName("solver"),
643 rewriter.getFunctionType(adaptor.getInputs().getTypes(),
644 convertedTypes));
645 rewriter.inlineRegionBefore(op.getBodyRegion(), funcOp.getBody(),
646 funcOp.end());
647 }
648
649 ValueRange results =
650 rewriter.create<func::CallOp>(loc, funcOp, adaptor.getInputs())
651 ->getResults();
652
653 // At the end of the region, decrease the solver's reference counter and
654 // delete the context.
655 // NOTE: we cannot use the convenience helper here because we don't want to
656 // load the context from the global but use the result from the 'mk_context'
657 // call directly for two reasons:
658 // * avoid an unnecessary load
659 // * the caching mechanism of the context does not work here because it
660 // would reuse the loaded context from a earlier solver
661 buildCall(rewriter, loc, "Z3_solver_dec_ref", ptrPtrToVoidFunc,
662 {ctx, solver});
663 buildCall(rewriter, loc, "Z3_del_context", ptrToVoidFunc, ctx);
664
665 rewriter.replaceOp(op, results);
666 return success();
667 }
668};
669
670/// Lower `smt.assert` operations to Z3 API calls of the form:
671/// ```
672/// void Z3_API Z3_solver_assert(Z3_context c, Z3_solver s, Z3_ast a);
673/// ```
674struct AssertOpLowering : public SMTLoweringPattern<AssertOp> {
675 using SMTLoweringPattern::SMTLoweringPattern;
676
677 LogicalResult
678 matchAndRewrite(AssertOp op, OpAdaptor adaptor,
679 ConversionPatternRewriter &rewriter) const final {
680 Location loc = op.getLoc();
681 buildAPICallWithContext(
682 rewriter, loc, "Z3_solver_assert",
683 LLVM::LLVMVoidType::get(getContext()),
684 {buildSolverPtr(rewriter, loc), adaptor.getInput()});
685
686 rewriter.eraseOp(op);
687 return success();
688 }
689};
690
691/// Lower `smt.reset` operations to Z3 API calls of the form:
692/// ```
693/// void Z3_API Z3_solver_reset(Z3_context c, Z3_solver s);
694/// ```
695struct ResetOpLowering : public SMTLoweringPattern<ResetOp> {
696 using SMTLoweringPattern::SMTLoweringPattern;
697
698 LogicalResult
699 matchAndRewrite(ResetOp op, OpAdaptor adaptor,
700 ConversionPatternRewriter &rewriter) const final {
701 Location loc = op.getLoc();
702 buildAPICallWithContext(rewriter, loc, "Z3_solver_reset",
703 LLVM::LLVMVoidType::get(getContext()),
704 {buildSolverPtr(rewriter, loc)});
705
706 rewriter.eraseOp(op);
707 return success();
708 }
709};
710
711/// Lower `smt.push` operations to (repeated) Z3 API calls of the form:
712/// ```
713/// void Z3_API Z3_solver_push(Z3_context c, Z3_solver s);
714/// ```
715struct PushOpLowering : public SMTLoweringPattern<PushOp> {
716 using SMTLoweringPattern::SMTLoweringPattern;
717 LogicalResult
718 matchAndRewrite(PushOp op, OpAdaptor adaptor,
719 ConversionPatternRewriter &rewriter) const final {
720 Location loc = op.getLoc();
721 // SMTLIB allows multiple levels to be pushed with one push command, but the
722 // Z3 C API doesn't let you provide a number of levels for push calls so
723 // multiple calls have to be created.
724 for (uint32_t i = 0; i < op.getCount(); i++)
725 buildAPICallWithContext(rewriter, loc, "Z3_solver_push",
726 LLVM::LLVMVoidType::get(getContext()),
727 {buildSolverPtr(rewriter, loc)});
728 rewriter.eraseOp(op);
729 return success();
730 }
731};
732
733/// Lower `smt.pop` operations to Z3 API calls of the form:
734/// ```
735/// void Z3_API Z3_solver_pop(Z3_context c, Z3_solver s, unsigned n);
736/// ```
737struct PopOpLowering : public SMTLoweringPattern<PopOp> {
738 using SMTLoweringPattern::SMTLoweringPattern;
739 LogicalResult
740 matchAndRewrite(PopOp op, OpAdaptor adaptor,
741 ConversionPatternRewriter &rewriter) const final {
742 Location loc = op.getLoc();
743 Value constVal = rewriter.create<LLVM::ConstantOp>(
744 loc, rewriter.getI32Type(), op.getCount());
745 buildAPICallWithContext(rewriter, loc, "Z3_solver_pop",
746 LLVM::LLVMVoidType::get(getContext()),
747 {buildSolverPtr(rewriter, loc), constVal});
748 rewriter.eraseOp(op);
749 return success();
750 }
751};
752
753/// Lower `smt.yield` operations to `scf.yield` operations. This not necessary
754/// for the yield in `smt.solver` or in quantifiers since they are deleted
755/// directly by the parent operation, but makes the lowering of the `smt.check`
756/// operation simpler and more convenient since the regions get translated
757/// directly to regions of `scf.if` operations.
758struct YieldOpLowering : public SMTLoweringPattern<YieldOp> {
759 using SMTLoweringPattern::SMTLoweringPattern;
760
761 LogicalResult
762 matchAndRewrite(YieldOp op, OpAdaptor adaptor,
763 ConversionPatternRewriter &rewriter) const final {
764 if (op->getParentOfType<func::FuncOp>()) {
765 rewriter.replaceOpWithNewOp<func::ReturnOp>(op, adaptor.getValues());
766 return success();
767 }
768 if (op->getParentOfType<LLVM::LLVMFuncOp>()) {
769 rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(op, adaptor.getValues());
770 return success();
771 }
772 if (isa<scf::SCFDialect>(op->getParentOp()->getDialect())) {
773 rewriter.replaceOpWithNewOp<scf::YieldOp>(op, adaptor.getValues());
774 return success();
775 }
776 return failure();
777 }
778};
779
780/// Lower `smt.check` operations to Z3 API calls and control-flow operations.
781/// ```
782/// Z3_lbool Z3_API Z3_solver_check(Z3_context c, Z3_solver s);
783///
784/// typedef enum
785/// {
786/// Z3_L_FALSE = -1, // means unsatisfiable here
787/// Z3_L_UNDEF, // means unknown here
788/// Z3_L_TRUE // means satisfiable here
789/// } Z3_lbool;
790/// ```
791struct CheckOpLowering : public SMTLoweringPattern<CheckOp> {
792 using SMTLoweringPattern::SMTLoweringPattern;
793
794 LogicalResult
795 matchAndRewrite(CheckOp op, OpAdaptor adaptor,
796 ConversionPatternRewriter &rewriter) const final {
797 Location loc = op.getLoc();
798 auto ptrTy = LLVM::LLVMPointerType::get(rewriter.getContext());
799 auto printfType = LLVM::LLVMFunctionType::get(
800 LLVM::LLVMVoidType::get(rewriter.getContext()), {ptrTy}, true);
801
802 auto getHeaderString = [](const std::string &title) {
803 unsigned titleSize = title.size() + 2; // Add a space left and right
804 return std::string((80 - titleSize) / 2, '-') + " " + title + " " +
805 std::string((80 - titleSize + 1) / 2, '-') + "\n%s\n" +
806 std::string(80, '-') + "\n";
807 };
808
809 // Get the pointer to the solver instance.
810 Value solver = buildSolverPtr(rewriter, loc);
811
812 // In debug-mode, print the state of the solver before calling 'check-sat'
813 // on it. This prints the asserted SMT expressions.
814 if (options.debug) {
815 auto solverStringPtr =
816 buildPtrAPICall(rewriter, loc, "Z3_solver_to_string", {solver});
817 auto solverFormatString =
818 buildString(rewriter, loc, getHeaderString("Solver"));
819 buildCall(rewriter, op.getLoc(), "printf", printfType,
820 {solverFormatString, solverStringPtr});
821 }
822
823 // Convert the result types of the `smt.check` operation.
824 SmallVector<Type> resultTypes;
825 if (failed(typeConverter->convertTypes(op->getResultTypes(), resultTypes)))
826 return failure();
827
828 // Call 'check-sat' and check if the assertions are satisfiable.
829 Value checkResult =
830 buildAPICallWithContext(rewriter, loc, "Z3_solver_check",
831 rewriter.getI32Type(), {solver})
832 ->getResult(0);
833 Value constOne =
834 rewriter.create<LLVM::ConstantOp>(loc, checkResult.getType(), 1);
835 Value isSat = rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::eq,
836 checkResult, constOne);
837
838 // Simply inline the 'sat' region into the 'then' region of the 'scf.if'
839 auto satIfOp = rewriter.create<scf::IfOp>(loc, resultTypes, isSat);
840 rewriter.inlineRegionBefore(op.getSatRegion(), satIfOp.getThenRegion(),
841 satIfOp.getThenRegion().end());
842
843 // Otherwise, the 'else' block checks if the assertions are unsatisfiable or
844 // unknown. The corresponding regions can also be simply inlined into the
845 // two branches of this nested if-statement as well.
846 rewriter.createBlock(&satIfOp.getElseRegion());
847 Value constNegOne =
848 rewriter.create<LLVM::ConstantOp>(loc, checkResult.getType(), -1);
849 Value isUnsat = rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::eq,
850 checkResult, constNegOne);
851 auto unsatIfOp = rewriter.create<scf::IfOp>(loc, resultTypes, isUnsat);
852 rewriter.create<scf::YieldOp>(loc, unsatIfOp->getResults());
853
854 rewriter.inlineRegionBefore(op.getUnsatRegion(), unsatIfOp.getThenRegion(),
855 unsatIfOp.getThenRegion().end());
856 rewriter.inlineRegionBefore(op.getUnknownRegion(),
857 unsatIfOp.getElseRegion(),
858 unsatIfOp.getElseRegion().end());
859
860 rewriter.replaceOp(op, satIfOp->getResults());
861
862 if (options.debug) {
863 // In debug-mode, if the assertions are unsatisfiable we can print the
864 // proof.
865 rewriter.setInsertionPointToStart(unsatIfOp.thenBlock());
866 auto proof = buildPtrAPICall(rewriter, op.getLoc(), "Z3_solver_get_proof",
867 {solver});
868 auto stringPtr =
869 buildPtrAPICall(rewriter, op.getLoc(), "Z3_ast_to_string", {proof});
870 auto formatString =
871 buildString(rewriter, op.getLoc(), getHeaderString("Proof"));
872 buildCall(rewriter, op.getLoc(), "printf", printfType,
873 {formatString, stringPtr});
874
875 // In debug mode, if the assertions are satisfiable we can print the model
876 // (effectively a counter-example).
877 rewriter.setInsertionPointToStart(satIfOp.thenBlock());
878 auto model = buildPtrAPICall(rewriter, op.getLoc(), "Z3_solver_get_model",
879 {solver});
880 auto modelStringPtr =
881 buildPtrAPICall(rewriter, op.getLoc(), "Z3_model_to_string", {model});
882 auto modelFormatString =
883 buildString(rewriter, op.getLoc(), getHeaderString("Model"));
884 buildCall(rewriter, op.getLoc(), "printf", printfType,
885 {modelFormatString, modelStringPtr});
886 }
887
888 return success();
889 }
890};
891
892/// Lower `smt.forall` and `smt.exists` operations to the following Z3 API call.
893/// ```
894/// Z3_ast Z3_API Z3_mk_{forall|exists}_const(
895/// Z3_context c,
896/// unsigned weight,
897/// unsigned num_bound,
898/// Z3_app const bound[],
899/// unsigned num_patterns,
900/// Z3_pattern const patterns[],
901/// Z3_ast body
902/// );
903/// ```
904/// All nested regions are inlined into the parent region and the block
905/// arguments are replaced with new `smt.declare_fun` constants that are also
906/// passed to the `bound` argument of above API function. Patterns are created
907/// with the following API function.
908/// ```
909/// Z3_pattern Z3_API Z3_mk_pattern(Z3_context c, unsigned num_patterns,
910/// Z3_ast const terms[]);
911/// ```
912/// Where each operand of the `smt.yield` in a pattern region is a 'term'.
913template <typename QuantifierOp>
914struct QuantifierLowering : public SMTLoweringPattern<QuantifierOp> {
915 using SMTLoweringPattern<QuantifierOp>::SMTLoweringPattern;
916 using SMTLoweringPattern<QuantifierOp>::typeConverter;
917 using SMTLoweringPattern<QuantifierOp>::buildPtrAPICall;
918 using OpAdaptor = typename QuantifierOp::Adaptor;
919
920 Value createStorageForValueList(ValueRange values, Location loc,
921 ConversionPatternRewriter &rewriter) const {
922 Type ptrTy = LLVM::LLVMPointerType::get(rewriter.getContext());
923 Type arrTy = LLVM::LLVMArrayType::get(ptrTy, values.size());
924 Value constOne =
925 rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI32Type(), 1);
926 Value storage =
927 rewriter.create<LLVM::AllocaOp>(loc, ptrTy, arrTy, constOne);
928 Value array = rewriter.create<LLVM::UndefOp>(loc, arrTy);
929
930 for (auto [i, val] : llvm::enumerate(values))
931 array = rewriter.create<LLVM::InsertValueOp>(loc, array, val,
932 ArrayRef<int64_t>(i));
933
934 rewriter.create<LLVM::StoreOp>(loc, array, storage);
935
936 return storage;
937 }
938
939 LogicalResult
940 matchAndRewrite(QuantifierOp op, OpAdaptor adaptor,
941 ConversionPatternRewriter &rewriter) const final {
942 Location loc = op.getLoc();
943 Type ptrTy = LLVM::LLVMPointerType::get(rewriter.getContext());
944
945 // no-pattern attribute not supported yet because the Z3 CAPI allows more
946 // fine-grained control where a list of patterns to be banned can be given.
947 // This means, the no-pattern attribute is equivalent to providing a list of
948 // all possible sub-expressions in the quantifier body to the CAPI.
949 if (adaptor.getNoPattern())
950 return rewriter.notifyMatchFailure(
951 op, "no-pattern attribute not yet supported!");
952
953 rewriter.setInsertionPoint(op);
954
955 // Weight attribute
956 Value weight = rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI32Type(),
957 adaptor.getWeight());
958
959 // Bound variables
960 unsigned numDecls = op.getBody().getNumArguments();
961 Value numDeclsVal =
962 rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI32Type(), numDecls);
963
964 // We replace the block arguments with constant symbolic values and inform
965 // the quantifier API call which constants it should treat as bound
966 // variables. We also need to make sure that we use the exact same SSA
967 // values in the pattern regions since we lower constant declaration
968 // operation to always produce fresh constants.
969 SmallVector<Value> repl;
970 for (auto [i, arg] : llvm::enumerate(op.getBody().getArguments())) {
971 Value newArg;
972 if (adaptor.getBoundVarNames().has_value())
973 newArg = rewriter.create<smt::DeclareFunOp>(
974 loc, arg.getType(),
975 cast<StringAttr>((*adaptor.getBoundVarNames())[i]));
976 else
977 newArg = rewriter.create<smt::DeclareFunOp>(loc, arg.getType());
978 repl.push_back(typeConverter->materializeTargetConversion(
979 rewriter, loc, typeConverter->convertType(arg.getType()), newArg));
980 }
981
982 Value boundStorage = createStorageForValueList(repl, loc, rewriter);
983
984 // Body Expression
985 auto yieldOp = cast<smt::YieldOp>(op.getBody().front().getTerminator());
986 Value bodyExp = yieldOp.getValues()[0];
987 rewriter.setInsertionPointAfterValue(bodyExp);
988 bodyExp = typeConverter->materializeTargetConversion(
989 rewriter, loc, typeConverter->convertType(bodyExp.getType()), bodyExp);
990 rewriter.eraseOp(yieldOp);
991
992 rewriter.inlineBlockBefore(&op.getBody().front(), op, repl);
993 rewriter.setInsertionPoint(op);
994
995 // Patterns
996 unsigned numPatterns = adaptor.getPatterns().size();
997 Value numPatternsVal = rewriter.create<LLVM::ConstantOp>(
998 loc, rewriter.getI32Type(), numPatterns);
999
1000 Value patternStorage;
1001 if (numPatterns > 0) {
1002 SmallVector<Value> patterns;
1003 for (Region *patternRegion : adaptor.getPatterns()) {
1004 auto yieldOp =
1005 cast<smt::YieldOp>(patternRegion->front().getTerminator());
1006 auto patternTerms = yieldOp.getOperands();
1007
1008 rewriter.setInsertionPoint(yieldOp);
1009 SmallVector<Value> patternList;
1010 for (auto val : patternTerms)
1011 patternList.push_back(typeConverter->materializeTargetConversion(
1012 rewriter, loc, typeConverter->convertType(val.getType()), val));
1013
1014 rewriter.eraseOp(yieldOp);
1015 rewriter.inlineBlockBefore(&patternRegion->front(), op, repl);
1016
1017 rewriter.setInsertionPoint(op);
1018 Value numTerms = rewriter.create<LLVM::ConstantOp>(
1019 loc, rewriter.getI32Type(), patternTerms.size());
1020 Value patternTermStorage =
1021 createStorageForValueList(patternList, loc, rewriter);
1022 Value pattern = buildPtrAPICall(rewriter, loc, "Z3_mk_pattern",
1023 {numTerms, patternTermStorage});
1024
1025 patterns.emplace_back(pattern);
1026 }
1027 patternStorage = createStorageForValueList(patterns, loc, rewriter);
1028 } else {
1029 // If we set the num_patterns parameter to 0, we can just pass a nullptr
1030 // as storage.
1031 patternStorage = rewriter.create<LLVM::ZeroOp>(loc, ptrTy);
1032 }
1033
1034 StringRef apiCallName = "Z3_mk_forall_const";
1035 if (std::is_same_v<QuantifierOp, ExistsOp>)
1036 apiCallName = "Z3_mk_exists_const";
1037 Value quantifierExp =
1038 buildPtrAPICall(rewriter, loc, apiCallName,
1039 {weight, numDeclsVal, boundStorage, numPatternsVal,
1040 patternStorage, bodyExp});
1041
1042 rewriter.replaceOp(op, quantifierExp);
1043 return success();
1044 }
1045};
1046
1047/// Lower `smt.bv.repeat` operations to Z3 API function calls of the form
1048/// ```
1049/// Z3_ast Z3_API Z3_mk_repeat(Z3_context c, unsigned i, Z3_ast t1);
1050/// ```
1051struct RepeatOpLowering : public SMTLoweringPattern<RepeatOp> {
1052 using SMTLoweringPattern::SMTLoweringPattern;
1053
1054 LogicalResult
1055 matchAndRewrite(RepeatOp op, OpAdaptor adaptor,
1056 ConversionPatternRewriter &rewriter) const final {
1057 Value count = rewriter.create<LLVM::ConstantOp>(
1058 op.getLoc(), rewriter.getI32Type(), op.getCount());
1059 rewriter.replaceOp(op,
1060 buildPtrAPICall(rewriter, op.getLoc(), "Z3_mk_repeat",
1061 {count, adaptor.getInput()}));
1062 return success();
1063 }
1064};
1065
1066/// Lower `smt.bv.extract` operations to Z3 API function calls of the following
1067/// form, where the output bit-vector has size `n = high - low + 1`. This means,
1068/// both the 'high' and 'low' indices are inclusive.
1069/// ```
1070/// Z3_ast Z3_API Z3_mk_extract(Z3_context c, unsigned high, unsigned low,
1071/// Z3_ast t1);
1072/// ```
1073struct ExtractOpLowering : public SMTLoweringPattern<ExtractOp> {
1074 using SMTLoweringPattern::SMTLoweringPattern;
1075
1076 LogicalResult
1077 matchAndRewrite(ExtractOp op, OpAdaptor adaptor,
1078 ConversionPatternRewriter &rewriter) const final {
1079 Location loc = op.getLoc();
1080 Value low = rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI32Type(),
1081 adaptor.getLowBit());
1082 Value high = rewriter.create<LLVM::ConstantOp>(
1083 loc, rewriter.getI32Type(),
1084 adaptor.getLowBit() + op.getType().getWidth() - 1);
1085 rewriter.replaceOp(op, buildPtrAPICall(rewriter, loc, "Z3_mk_extract",
1086 {high, low, adaptor.getInput()}));
1087 return success();
1088 }
1089};
1090
1091/// Lower `smt.array.broadcast` operations to Z3 API function calls of the form
1092/// ```
1093/// Z3_ast Z3_API Z3_mk_const_array(Z3_context c, Z3_sort domain, Z3_ast v);
1094/// ```
1095struct ArrayBroadcastOpLowering
1096 : public SMTLoweringPattern<smt::ArrayBroadcastOp> {
1097 using SMTLoweringPattern::SMTLoweringPattern;
1098
1099 LogicalResult
1100 matchAndRewrite(smt::ArrayBroadcastOp op, OpAdaptor adaptor,
1101 ConversionPatternRewriter &rewriter) const final {
1102 auto domainSort = buildSort(
1103 rewriter, op.getLoc(),
1104 cast<smt::ArrayType>(op.getResult().getType()).getDomainType());
1105
1106 rewriter.replaceOp(op, buildPtrAPICall(rewriter, op.getLoc(),
1107 "Z3_mk_const_array",
1108 {domainSort, adaptor.getValue()}));
1109 return success();
1110 }
1111};
1112
1113/// Lower the `smt.constant` operation to one of the following Z3 API function
1114/// calls depending on the value of the boolean attribute.
1115/// ```
1116/// Z3_ast Z3_API Z3_mk_true(Z3_context c);
1117/// Z3_ast Z3_API Z3_mk_false(Z3_context c);
1118/// ```
1119struct BoolConstantOpLowering : public SMTLoweringPattern<smt::BoolConstantOp> {
1120 using SMTLoweringPattern::SMTLoweringPattern;
1121
1122 LogicalResult
1123 matchAndRewrite(smt::BoolConstantOp op, OpAdaptor adaptor,
1124 ConversionPatternRewriter &rewriter) const final {
1125 rewriter.replaceOp(
1126 op, buildPtrAPICall(rewriter, op.getLoc(),
1127 adaptor.getValue() ? "Z3_mk_true" : "Z3_mk_false"));
1128 return success();
1129 }
1130};
1131
1132/// Lower `smt.int.constant` operations to one of the following two Z3 API
1133/// function calls depending on whether the storage APInt has a bit-width that
1134/// fits in a `uint64_t`.
1135/// ```
1136/// Z3_sort Z3_API Z3_mk_int_sort(Z3_context c);
1137///
1138/// Z3_ast Z3_API Z3_mk_int64(Z3_context c, int64_t v, Z3_sort ty);
1139///
1140/// Z3_ast Z3_API Z3_mk_numeral(Z3_context c, Z3_string numeral, Z3_sort ty);
1141/// Z3_ast Z3_API Z3_mk_unary_minus(Z3_context c, Z3_ast arg);
1142/// ```
1143struct IntConstantOpLowering : public SMTLoweringPattern<smt::IntConstantOp> {
1144 using SMTLoweringPattern::SMTLoweringPattern;
1145
1146 LogicalResult
1147 matchAndRewrite(smt::IntConstantOp op, OpAdaptor adaptor,
1148 ConversionPatternRewriter &rewriter) const final {
1149 Location loc = op.getLoc();
1150 Value type = buildPtrAPICall(rewriter, loc, "Z3_mk_int_sort");
1151 if (adaptor.getValue().getBitWidth() <= 64) {
1152 Value val = rewriter.create<LLVM::ConstantOp>(
1153 loc, rewriter.getI64Type(), adaptor.getValue().getSExtValue());
1154 rewriter.replaceOp(
1155 op, buildPtrAPICall(rewriter, loc, "Z3_mk_int64", {val, type}));
1156 return success();
1157 }
1158
1159 std::string numeralStr;
1160 llvm::raw_string_ostream stream(numeralStr);
1161 stream << adaptor.getValue().abs();
1162
1163 Value numeral = buildString(rewriter, loc, numeralStr);
1164 Value intNumeral =
1165 buildPtrAPICall(rewriter, loc, "Z3_mk_numeral", {numeral, type});
1166
1167 if (adaptor.getValue().isNegative())
1168 intNumeral =
1169 buildPtrAPICall(rewriter, loc, "Z3_mk_unary_minus", intNumeral);
1170
1171 rewriter.replaceOp(op, intNumeral);
1172 return success();
1173 }
1174};
1175
1176/// Lower `smt.int.cmp` operations to one of the following Z3 API function calls
1177/// depending on the predicate.
1178/// ```
1179/// Z3_ast Z3_API Z3_mk_{{pred}}(Z3_context c, Z3_ast t1, Z3_ast t2);
1180/// ```
1181struct IntCmpOpLowering : public SMTLoweringPattern<IntCmpOp> {
1182 using SMTLoweringPattern::SMTLoweringPattern;
1183
1184 LogicalResult
1185 matchAndRewrite(IntCmpOp op, OpAdaptor adaptor,
1186 ConversionPatternRewriter &rewriter) const final {
1187 rewriter.replaceOp(
1188 op,
1189 buildPtrAPICall(rewriter, op.getLoc(),
1190 "Z3_mk_" + stringifyIntPredicate(op.getPred()).str(),
1191 {adaptor.getLhs(), adaptor.getRhs()}));
1192 return success();
1193 }
1194};
1195
1196/// Lower `smt.int2bv` operations to the following Z3 API function calls.
1197/// ```
1198/// Z3_ast Z3_API Z3_mk_int2bv(Z3_context c, unsigned n, Z3_ast t1);
1199/// ```
1200struct Int2BVOpLowering : public SMTLoweringPattern<Int2BVOp> {
1201 using SMTLoweringPattern::SMTLoweringPattern;
1202
1203 LogicalResult
1204 matchAndRewrite(Int2BVOp op, OpAdaptor adaptor,
1205 ConversionPatternRewriter &rewriter) const final {
1206 Value widthConst =
1207 rewriter.create<LLVM::ConstantOp>(op->getLoc(), rewriter.getI32Type(),
1208 op.getResult().getType().getWidth());
1209 rewriter.replaceOp(op,
1210 buildPtrAPICall(rewriter, op.getLoc(), "Z3_mk_int2bv",
1211 {widthConst, adaptor.getInput()}));
1212 return success();
1213 }
1214};
1215
1216/// Lower `smt.bv2int` operations to the following Z3 API function call.
1217/// ```
1218/// Z3_ast Z3_API Z3_mk_bv2int(Z3_context c, Z3_ast t1, bool is_signed)
1219/// ```
1220struct BV2IntOpLowering : public SMTLoweringPattern<BV2IntOp> {
1221 using SMTLoweringPattern::SMTLoweringPattern;
1222
1223 LogicalResult
1224 matchAndRewrite(BV2IntOp op, OpAdaptor adaptor,
1225 ConversionPatternRewriter &rewriter) const final {
1226 // FIXME: ideally we don't want to use i1 here, since bools can sometimes be
1227 // compiled to wider widths in LLVM
1228 Value isSignedConst = rewriter.create<LLVM::ConstantOp>(
1229 op->getLoc(), rewriter.getI1Type(), op.getIsSigned());
1230 rewriter.replaceOp(op,
1231 buildPtrAPICall(rewriter, op.getLoc(), "Z3_mk_bv2int",
1232 {adaptor.getInput(), isSignedConst}));
1233 return success();
1234 }
1235};
1236
1237/// Lower `smt.bv.cmp` operations to one of the following Z3 API function calls,
1238/// performing two's complement comparison, depending on the predicate
1239/// attribute.
1240/// ```
1241/// Z3_ast Z3_API Z3_mk_bv{{pred}}(Z3_context c, Z3_ast t1, Z3_ast t2);
1242/// ```
1243struct BVCmpOpLowering : public SMTLoweringPattern<BVCmpOp> {
1244 using SMTLoweringPattern::SMTLoweringPattern;
1245
1246 LogicalResult
1247 matchAndRewrite(BVCmpOp op, OpAdaptor adaptor,
1248 ConversionPatternRewriter &rewriter) const final {
1249 rewriter.replaceOp(
1250 op, buildPtrAPICall(rewriter, op.getLoc(),
1251 "Z3_mk_bv" +
1252 stringifyBVCmpPredicate(op.getPred()).str(),
1253 {adaptor.getLhs(), adaptor.getRhs()}));
1254 return success();
1255 }
1256};
1257
1258/// Expand the `smt.int.abs` operation to a `smt.ite` operation.
1259struct IntAbsOpLowering : public SMTLoweringPattern<IntAbsOp> {
1260 using SMTLoweringPattern::SMTLoweringPattern;
1261
1262 LogicalResult
1263 matchAndRewrite(IntAbsOp op, OpAdaptor adaptor,
1264 ConversionPatternRewriter &rewriter) const final {
1265 Location loc = op.getLoc();
1266 Value zero = rewriter.create<IntConstantOp>(
1267 loc, rewriter.getIntegerAttr(rewriter.getI1Type(), 0));
1268 Value cmp = rewriter.create<IntCmpOp>(loc, IntPredicate::lt,
1269 adaptor.getInput(), zero);
1270 Value neg = rewriter.create<IntSubOp>(loc, zero, adaptor.getInput());
1271 rewriter.replaceOpWithNewOp<IteOp>(op, cmp, neg, adaptor.getInput());
1272 return success();
1273 }
1274};
1275
1276} // namespace
1277
1278//===----------------------------------------------------------------------===//
1279// Pass Implementation
1280//===----------------------------------------------------------------------===//
1281
1282namespace {
1283struct LowerSMTToZ3LLVMPass
1284 : public circt::impl::LowerSMTToZ3LLVMBase<LowerSMTToZ3LLVMPass> {
1285 using Base::Base;
1286 void runOnOperation() override;
1287};
1288} // namespace
1289
1290void circt::populateSMTToZ3LLVMTypeConverter(TypeConverter &converter) {
1291 converter.addConversion([](smt::BoolType type) {
1292 return LLVM::LLVMPointerType::get(type.getContext());
1293 });
1294 converter.addConversion([](smt::BitVectorType type) {
1295 return LLVM::LLVMPointerType::get(type.getContext());
1296 });
1297 converter.addConversion([](smt::ArrayType type) {
1298 return LLVM::LLVMPointerType::get(type.getContext());
1299 });
1300 converter.addConversion([](smt::IntType type) {
1301 return LLVM::LLVMPointerType::get(type.getContext());
1302 });
1303 converter.addConversion([](smt::SMTFuncType type) {
1304 return LLVM::LLVMPointerType::get(type.getContext());
1305 });
1306 converter.addConversion([](smt::SortType type) {
1307 return LLVM::LLVMPointerType::get(type.getContext());
1308 });
1309}
1310
1312 RewritePatternSet &patterns, TypeConverter &converter,
1313 SMTGlobalsHandler &globals, const LowerSMTToZ3LLVMOptions &options) {
1314#define ADD_VARIADIC_PATTERN(OP, APINAME, MIN_NUM_ARGS) \
1315 patterns.add<VariadicSMTPattern<OP>>(/*NOLINT(bugprone-macro-parentheses)*/ \
1316 converter, patterns.getContext(), \
1317 globals, options, APINAME, \
1318 MIN_NUM_ARGS);
1319
1320#define ADD_ONE_TO_ONE_PATTERN(OP, APINAME, NUM_ARGS) \
1321 patterns.add<OneToOneSMTPattern<OP>>(/*NOLINT(bugprone-macro-parentheses)*/ \
1322 converter, patterns.getContext(), \
1323 globals, options, APINAME, NUM_ARGS);
1324
1325 // Lower `smt.distinct` operations which allows a variadic number of operands
1326 // according to the `:pairwise` attribute. The Z3 API function supports a
1327 // variadic number of operands as well, i.e., a direct lowering is possible:
1328 // ```
1329 // Z3_ast Z3_API Z3_mk_distinct(Z3_context c, unsigned num_args, Z3_ast const
1330 // args[])
1331 // ```
1332 // The API function requires num_args > 1 which is guaranteed to be satisfied
1333 // because `smt.distinct` is verified to have > 1 operands.
1334 ADD_VARIADIC_PATTERN(DistinctOp, "Z3_mk_distinct", 2);
1335
1336 // Lower `smt.and` operations which allows a variadic number of operands
1337 // according to the `:left-assoc` attribute. The Z3 API function supports a
1338 // variadic number of operands as well, i.e., a direct lowering is possible:
1339 // ```
1340 // Z3_ast Z3_API Z3_mk_and(Z3_context c, unsigned num_args, Z3_ast const
1341 // args[])
1342 // ```
1343 // The API function requires num_args > 1. This is not guaranteed by the
1344 // `smt.and` operation and thus the pattern will not apply when no operand is
1345 // present. The constant folder of the operation is assumed to fold this to
1346 // a constant 'true' (neutral element of AND).
1347 ADD_VARIADIC_PATTERN(AndOp, "Z3_mk_and", 2);
1348
1349 // Lower `smt.or` operations which allows a variadic number of operands
1350 // according to the `:left-assoc` attribute. The Z3 API function supports a
1351 // variadic number of operands as well, i.e., a direct lowering is possible:
1352 // ```
1353 // Z3_ast Z3_API Z3_mk_or(Z3_context c, unsigned num_args, Z3_ast const
1354 // args[])
1355 // ```
1356 // The API function requires num_args > 1. This is not guaranteed by the
1357 // `smt.or` operation and thus the pattern will not apply when no operand is
1358 // present. The constant folder of the operation is assumed to fold this to
1359 // a constant 'false' (neutral element of OR).
1360 ADD_VARIADIC_PATTERN(OrOp, "Z3_mk_or", 2);
1361
1362 // Lower `smt.not` operations to the following Z3 API function:
1363 // ```
1364 // Z3_ast Z3_API Z3_mk_not(Z3_context c, Z3_ast a);
1365 // ```
1366 ADD_ONE_TO_ONE_PATTERN(NotOp, "Z3_mk_not", 1);
1367
1368 // Lower `smt.xor` operations which allows a variadic number of operands
1369 // according to the `:left-assoc` attribute. The Z3 API function, however,
1370 // only takes two operands.
1371 // ```
1372 // Z3_ast Z3_API Z3_mk_xor(Z3_context c, Z3_ast t1, Z3_ast t2);
1373 // ```
1374 // Therefore, we need to decompose the operation first to a sequence of XOR
1375 // operations matching the left associative behavior.
1376 patterns.add<LowerLeftAssocSMTPattern<XOrOp>>(
1377 converter, patterns.getContext(), globals, options);
1378 ADD_ONE_TO_ONE_PATTERN(XOrOp, "Z3_mk_xor", 2);
1379
1380 // Lower `smt.implies` operations to the following Z3 API function:
1381 // ```
1382 // Z3_ast Z3_API Z3_mk_implies(Z3_context c, Z3_ast t1, Z3_ast t2);
1383 // ```
1384 ADD_ONE_TO_ONE_PATTERN(ImpliesOp, "Z3_mk_implies", 2);
1385
1386 // All the bit-vector arithmetic and bitwise operations conveniently lower to
1387 // Z3 API function calls with essentially matching names and a one-to-one
1388 // correspondence of operands to call arguments.
1389 ADD_ONE_TO_ONE_PATTERN(BVNegOp, "Z3_mk_bvneg", 1);
1390 ADD_ONE_TO_ONE_PATTERN(BVAddOp, "Z3_mk_bvadd", 2);
1391 ADD_ONE_TO_ONE_PATTERN(BVMulOp, "Z3_mk_bvmul", 2);
1392 ADD_ONE_TO_ONE_PATTERN(BVURemOp, "Z3_mk_bvurem", 2);
1393 ADD_ONE_TO_ONE_PATTERN(BVSRemOp, "Z3_mk_bvsrem", 2);
1394 ADD_ONE_TO_ONE_PATTERN(BVSModOp, "Z3_mk_bvsmod", 2);
1395 ADD_ONE_TO_ONE_PATTERN(BVUDivOp, "Z3_mk_bvudiv", 2);
1396 ADD_ONE_TO_ONE_PATTERN(BVSDivOp, "Z3_mk_bvsdiv", 2);
1397 ADD_ONE_TO_ONE_PATTERN(BVShlOp, "Z3_mk_bvshl", 2);
1398 ADD_ONE_TO_ONE_PATTERN(BVLShrOp, "Z3_mk_bvlshr", 2);
1399 ADD_ONE_TO_ONE_PATTERN(BVAShrOp, "Z3_mk_bvashr", 2);
1400 ADD_ONE_TO_ONE_PATTERN(BVNotOp, "Z3_mk_bvnot", 1);
1401 ADD_ONE_TO_ONE_PATTERN(BVAndOp, "Z3_mk_bvand", 2);
1402 ADD_ONE_TO_ONE_PATTERN(BVOrOp, "Z3_mk_bvor", 2);
1403 ADD_ONE_TO_ONE_PATTERN(BVXOrOp, "Z3_mk_bvxor", 2);
1404
1405 // The `smt.bv.concat` operation only supports two operands, just like the
1406 // Z3 API function.
1407 // ```
1408 // Z3_ast Z3_API Z3_mk_concat(Z3_context c, Z3_ast t1, Z3_ast t2);
1409 // ```
1410 ADD_ONE_TO_ONE_PATTERN(ConcatOp, "Z3_mk_concat", 2);
1411
1412 // Lower the `smt.ite` operation to the following Z3 API function call, where
1413 // `t1` must have boolean sort.
1414 // ```
1415 // Z3_ast Z3_API Z3_mk_ite(Z3_context c, Z3_ast t1, Z3_ast t2, Z3_ast t3);
1416 // ```
1417 ADD_ONE_TO_ONE_PATTERN(IteOp, "Z3_mk_ite", 3);
1418
1419 // Lower the `smt.array.select` operation to the following Z3 function call.
1420 // The operand declaration of the operation matches the order of arguments of
1421 // the API function.
1422 // ```
1423 // Z3_ast Z3_API Z3_mk_select(Z3_context c, Z3_ast a, Z3_ast i);
1424 // ```
1425 // Where `a` is the array expression and `i` is the index expression.
1426 ADD_ONE_TO_ONE_PATTERN(ArraySelectOp, "Z3_mk_select", 2);
1427
1428 // Lower the `smt.array.store` operation to the following Z3 function call.
1429 // The operand declaration of the operation matches the order of arguments of
1430 // the API function.
1431 // ```
1432 // Z3_ast Z3_API Z3_mk_store(Z3_context c, Z3_ast a, Z3_ast i, Z3_ast v);
1433 // ```
1434 // Where `a` is the array expression, `i` is the index expression, and `v` is
1435 // the value expression to be stored.
1436 ADD_ONE_TO_ONE_PATTERN(ArrayStoreOp, "Z3_mk_store", 3);
1437
1438 // Lower the `smt.int.add` operation to the following Z3 API function call.
1439 // ```
1440 // Z3_ast Z3_API Z3_mk_add(Z3_context c, unsigned num_args, Z3_ast const
1441 // args[]);
1442 // ```
1443 // The number of arguments must be greater than zero. Therefore, the pattern
1444 // will fail if applied to an operation with less than two operands.
1445 ADD_VARIADIC_PATTERN(IntAddOp, "Z3_mk_add", 2);
1446
1447 // Lower the `smt.int.mul` operation to the following Z3 API function call.
1448 // ```
1449 // Z3_ast Z3_API Z3_mk_mul(Z3_context c, unsigned num_args, Z3_ast const
1450 // args[]);
1451 // ```
1452 // The number of arguments must be greater than zero. Therefore, the pattern
1453 // will fail if applied to an operation with less than two operands.
1454 ADD_VARIADIC_PATTERN(IntMulOp, "Z3_mk_mul", 2);
1455
1456 // Lower the `smt.int.sub` operation to the following Z3 API function call.
1457 // ```
1458 // Z3_ast Z3_API Z3_mk_sub(Z3_context c, unsigned num_args, Z3_ast const
1459 // args[]);
1460 // ```
1461 // The number of arguments must be greater than zero. Since the `smt.int.sub`
1462 // operation always has exactly two operands, this trivially holds.
1463 ADD_VARIADIC_PATTERN(IntSubOp, "Z3_mk_sub", 2);
1464
1465 // Lower the `smt.int.div` operation to the following Z3 API function call.
1466 // ```
1467 // Z3_ast Z3_API Z3_mk_div(Z3_context c, Z3_ast arg1, Z3_ast arg2);
1468 // ```
1469 ADD_ONE_TO_ONE_PATTERN(IntDivOp, "Z3_mk_div", 2);
1470
1471 // Lower the `smt.int.mod` operation to the following Z3 API function call.
1472 // ```
1473 // Z3_ast Z3_API Z3_mk_mod(Z3_context c, Z3_ast arg1, Z3_ast arg2);
1474 // ```
1475 ADD_ONE_TO_ONE_PATTERN(IntModOp, "Z3_mk_mod", 2);
1476
1477#undef ADD_VARIADIC_PATTERN
1478#undef ADD_ONE_TO_ONE_PATTERN
1479
1480 // Lower `smt.eq` operations which allows a variadic number of operands
1481 // according to the `:chainable` attribute. The Z3 API function does not
1482 // support a variadic number of operands, but exactly two:
1483 // ```
1484 // Z3_ast Z3_API Z3_mk_eq(Z3_context c, Z3_ast l, Z3_ast r)
1485 // ```
1486 // As a result, we first apply a rewrite pattern that unfolds chainable
1487 // operators and then lower it one-to-one to the API function. In this case,
1488 // this means:
1489 // ```
1490 // eq(a,b,c,d) ->
1491 // and(eq(a,b), eq(b,c), eq(c,d)) ->
1492 // and(Z3_mk_eq(ctx, a, b), Z3_mk_eq(ctx, b, c), Z3_mk_eq(ctx, c, d))
1493 // ```
1494 // The patterns for `smt.and` will then do the remaining work.
1495 patterns.add<LowerChainableSMTPattern<EqOp>>(converter, patterns.getContext(),
1496 globals, options);
1497 patterns.add<OneToOneSMTPattern<EqOp>>(converter, patterns.getContext(),
1498 globals, options, "Z3_mk_eq", 2);
1499
1500 // Other lowering patterns. Refer to their implementation directly for more
1501 // information.
1502 patterns.add<BVConstantOpLowering, DeclareFunOpLowering, AssertOpLowering,
1503 ResetOpLowering, PushOpLowering, PopOpLowering, CheckOpLowering,
1504 SolverOpLowering, ApplyFuncOpLowering, YieldOpLowering,
1505 RepeatOpLowering, ExtractOpLowering, BoolConstantOpLowering,
1506 IntConstantOpLowering, ArrayBroadcastOpLowering, BVCmpOpLowering,
1507 IntCmpOpLowering, IntAbsOpLowering, Int2BVOpLowering,
1508 BV2IntOpLowering, QuantifierLowering<ForallOp>,
1509 QuantifierLowering<ExistsOp>>(converter, patterns.getContext(),
1510 globals, options);
1511}
1512
1513void LowerSMTToZ3LLVMPass::runOnOperation() {
1514 LowerSMTToZ3LLVMOptions options;
1515 options.debug = debug;
1516
1517 // Check that the lowering is possible
1518 // Specifically, check that the use of set-logic ops is valid for z3
1519 auto setLogicCheck = getOperation().walk([&](SolverOp solverOp)
1520 -> WalkResult {
1521 // Check that solver ops only contain one set-logic op and that they're at
1522 // the start of the body
1523 auto setLogicOps = solverOp.getBodyRegion().getOps<smt::SetLogicOp>();
1524 auto numSetLogicOps = std::distance(setLogicOps.begin(), setLogicOps.end());
1525 if (numSetLogicOps > 1) {
1526 return solverOp.emitError(
1527 "multiple set-logic operations found in one solver operation - Z3 "
1528 "only supports setting the logic once");
1529 }
1530 if (numSetLogicOps == 1)
1531 // Check the only ops before the set-logic op are ConstantLike
1532 for (auto &blockOp : solverOp.getBodyRegion().getOps()) {
1533 if (isa<smt::SetLogicOp>(blockOp))
1534 break;
1535 if (!blockOp.hasTrait<OpTrait::ConstantLike>()) {
1536 return solverOp.emitError("set-logic operation must be the first "
1537 "non-constant operation in a solver "
1538 "operation");
1539 }
1540 }
1541 return WalkResult::advance();
1542 });
1543 if (setLogicCheck.wasInterrupted())
1544 return signalPassFailure();
1545
1546 // Set up the type converter
1547 LLVMTypeConverter converter(&getContext());
1549
1550 RewritePatternSet patterns(&getContext());
1551
1552 // Populate the func to LLVM conversion patterns for two reasons:
1553 // * Typically functions are represented using `func.func` and including the
1554 // patterns to lower them here is more convenient for most lowering
1555 // pipelines (avoids running another pass).
1556 // * Already having `llvm.func` in the input or lowering `func.func` before
1557 // the SMT in the body leads to issues because the SCF conversion patterns
1558 // don't take the type converter into consideration and thus create blocks
1559 // with the old types for block arguments. However, the conversion happens
1560 // top-down and thus are assumed to be converted by the parent function op
1561 // which at that point would have already been lowered (and the blocks are
1562 // also not there when doing everything in one pass, i.e.,
1563 // `populateAnyFunctionOpInterfaceTypeConversionPattern` does not have any
1564 // effect as well). Are the SCF lowering patterns actually broken and should
1565 // take a type-converter?
1566 populateFuncToLLVMConversionPatterns(converter, patterns);
1567 arith::populateArithToLLVMConversionPatterns(converter, patterns);
1568
1569 // Populate SCF to CF and CF to LLVM lowering patterns because we create
1570 // `scf.if` operations in the lowering patterns for convenience (given the
1571 // above issue we might want to lower to LLVM directly; or fix upstream?)
1572 populateSCFToControlFlowConversionPatterns(patterns);
1573 mlir::cf::populateControlFlowToLLVMConversionPatterns(converter, patterns);
1574
1575 // Create the globals to store the context and solver and populate the SMT
1576 // lowering patterns.
1577 OpBuilder builder(&getContext());
1578 auto globals = SMTGlobalsHandler::create(builder, getOperation());
1579 populateSMTToZ3LLVMConversionPatterns(patterns, converter, globals, options);
1580
1581 // Do a full conversion. This assumes that all other dialects have been
1582 // lowered before this pass already.
1583 LLVMConversionTarget target(getContext());
1584 target.addLegalOp<mlir::ModuleOp>();
1585 target.addLegalOp<scf::YieldOp>();
1586
1587 if (failed(applyFullConversion(getOperation(), target, std::move(patterns))))
1588 return signalPassFailure();
1589}
assert(baseType &&"element must be base type")
static FIRRTLBaseType convertType(FIRRTLBaseType type)
Returns null type if no conversion is needed.
Definition DropConst.cpp:32
#define ADD_VARIADIC_PATTERN(OP, APINAME, MIN_NUM_ARGS)
#define ADD_ONE_TO_ONE_PATTERN(OP, APINAME, NUM_ARGS)
static Location getLoc(DefSlot slot)
Definition Mem2Reg.cpp:212
RewritePatternSet pattern
A namespace that is used to store existing names and generate new names in some scope within the IR.
Definition Namespace.h:30
void add(mlir::ModuleOp module)
Definition Namespace.h:48
StringRef newName(const Twine &name)
Return a unique name, derived from the input name, and add the new name to the internal namespace.
Definition Namespace.h:87
void addDefinitions(mlir::Operation *top)
Populate the symbol cache with all symbol-defining operations within the 'top' operation.
Definition SymCache.cpp:23
Default symbol cache implementation; stores associations between names (StringAttr's) to mlir::Operat...
Definition SymCache.h:85
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
void populateSMTToZ3LLVMTypeConverter(TypeConverter &converter)
Populate the given type converter with the SMT to LLVM type conversions.
void populateSMTToZ3LLVMConversionPatterns(RewritePatternSet &patterns, TypeConverter &converter, SMTGlobalsHandler &globals, const LowerSMTToZ3LLVMOptions &options)
Add the SMT to LLVM IR conversion patterns to 'patterns'.
Definition debug.py:1
A symbol cache for LLVM globals and functions relevant to SMT lowering patterns.
Definition SMTToZ3LLVM.h:25
static SMTGlobalsHandler create(OpBuilder &builder, ModuleOp module)
Creates the LLVM global operations to store the pointers to the solver and the context and returns a ...
SMTGlobalsHandler(ModuleOp module, mlir::LLVM::GlobalOp solver, mlir::LLVM::GlobalOp ctx)
Initializes the caches and keeps track of the given globals to store the pointers to the SMT solver a...