CIRCT  19.0.0git
LowerSMTToZ3LLVM.cpp
Go to the documentation of this file.
1 //===- LowerSMTToZ3LLVM.cpp -----------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
12 #include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
13 #include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
14 #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
15 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
16 #include "mlir/Conversion/LLVMCommon/TypeConverter.h"
17 #include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"
18 #include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
19 #include "mlir/Dialect/Func/IR/FuncOps.h"
20 #include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
21 #include "mlir/Dialect/LLVMIR/LLVMAttrs.h"
22 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
23 #include "mlir/Dialect/SCF/IR/SCF.h"
24 #include "mlir/IR/BuiltinDialect.h"
25 #include "mlir/Pass/Pass.h"
26 #include "mlir/Transforms/DialectConversion.h"
27 #include "llvm/ADT/TypeSwitch.h"
28 #include "llvm/Support/Debug.h"
29 
30 #define DEBUG_TYPE "lower-smt-to-z3-llvm"
31 
32 namespace circt {
33 #define GEN_PASS_DEF_LOWERSMTTOZ3LLVM
34 #include "circt/Conversion/Passes.h.inc"
35 } // namespace circt
36 
37 using namespace mlir;
38 using namespace circt;
39 using namespace smt;
40 
41 //===----------------------------------------------------------------------===//
42 // SMTGlobalsHandler implementation
43 //===----------------------------------------------------------------------===//
44 
45 SMTGlobalsHandler SMTGlobalsHandler::create(OpBuilder &builder,
46  ModuleOp module) {
47  OpBuilder::InsertionGuard guard(builder);
48  builder.setInsertionPointToStart(module.getBody());
49 
50  SymbolCache symCache;
51  symCache.addDefinitions(module);
52  Namespace names;
53  names.add(symCache);
54 
55  Location loc = module.getLoc();
56  auto ptrTy = LLVM::LLVMPointerType::get(builder.getContext());
57 
58  auto createGlobal = [&](StringRef namePrefix) {
59  auto global = builder.create<LLVM::GlobalOp>(
60  loc, ptrTy, false, LLVM::Linkage::Internal, names.newName(namePrefix),
61  Attribute{}, /*alignment=*/8);
62  OpBuilder::InsertionGuard g(builder);
63  builder.createBlock(&global.getInitializer());
64  Value res = builder.create<LLVM::ZeroOp>(loc, ptrTy);
65  builder.create<LLVM::ReturnOp>(loc, res);
66  return global;
67  };
68 
69  auto ctxGlobal = createGlobal("ctx");
70  auto solverGlobal = createGlobal("solver");
71 
72  return SMTGlobalsHandler(std::move(names), solverGlobal, ctxGlobal);
73 }
74 
75 SMTGlobalsHandler::SMTGlobalsHandler(Namespace &&names,
76  mlir::LLVM::GlobalOp solver,
77  mlir::LLVM::GlobalOp ctx)
78  : solver(solver), ctx(ctx), names(names) {}
79 
81  mlir::LLVM::GlobalOp solver,
82  mlir::LLVM::GlobalOp ctx)
83  : solver(solver), ctx(ctx) {
84  SymbolCache symCache;
85  symCache.addDefinitions(module);
86  names.add(symCache);
87 }
88 
89 //===----------------------------------------------------------------------===//
90 // Lowering Pattern Base
91 //===----------------------------------------------------------------------===//
92 
93 namespace {
94 
95 template <typename OpTy>
96 class SMTLoweringPattern : public OpConversionPattern<OpTy> {
97 public:
98  SMTLoweringPattern(const TypeConverter &typeConverter, MLIRContext *context,
99  SMTGlobalsHandler &globals,
100  const LowerSMTToZ3LLVMOptions &options)
101  : OpConversionPattern<OpTy>(typeConverter, context), globals(globals),
102  options(options) {}
103 
104 private:
105  Value buildGlobalPtrToGlobal(OpBuilder &builder, Location loc,
106  LLVM::GlobalOp global,
107  DenseMap<Block *, Value> &cache) const {
108  Block *block = builder.getBlock();
109  if (auto iter = cache.find(block); iter != cache.end())
110  return iter->getSecond();
111 
112  OpBuilder::InsertionGuard g(builder);
113  builder.setInsertionPointToStart(block);
114  Value globalAddr = builder.create<LLVM::AddressOfOp>(loc, global);
115  return cache[block] = builder.create<LLVM::LoadOp>(
116  loc, LLVM::LLVMPointerType::get(builder.getContext()),
117  globalAddr);
118  }
119 
120 protected:
121  /// A convenience function to get the pointer to the context from the 'global'
122  /// operation. The result is cached for each basic block, i.e., it is assumed
123  /// that this function is never called in the same basic block again at a
124  /// location (insertion point of the 'builder') not dominating all previous
125  /// locations this function was called at.
126  Value buildContextPtr(OpBuilder &builder, Location loc) const {
127  return buildGlobalPtrToGlobal(builder, loc, globals.ctx, globals.ctxCache);
128  }
129 
130  /// A convenience function to get the pointer to the solver from the 'global'
131  /// operation. The result is cached for each basic block, i.e., it is assumed
132  /// that this function is never called in the same basic block again at a
133  /// location (insertion point of the 'builder') not dominating all previous
134  /// locations this function was called at.
135  Value buildSolverPtr(OpBuilder &builder, Location loc) const {
136  return buildGlobalPtrToGlobal(builder, loc, globals.solver,
137  globals.solverCache);
138  }
139 
140  /// Create a `llvm.call` operation to a function with the given 'name' and
141  /// 'type'. If there does not already exist a (external) function with that
142  /// name create a matching external function declaration.
143  LLVM::CallOp buildCall(OpBuilder &builder, Location loc, StringRef name,
144  LLVM::LLVMFunctionType funcType,
145  ValueRange args) const {
146  auto &funcOp = globals.funcMap[builder.getStringAttr(name)];
147  if (!funcOp) {
148  OpBuilder::InsertionGuard guard(builder);
149  auto module =
150  builder.getBlock()->getParent()->getParentOfType<ModuleOp>();
151  builder.setInsertionPointToEnd(module.getBody());
152  funcOp = LLVM::lookupOrCreateFn(module, name, funcType.getParams(),
153  funcType.getReturnType(),
154  funcType.getVarArg());
155  }
156  return builder.create<LLVM::CallOp>(loc, funcOp, args);
157  }
158 
159  /// Build a global constant for the given string and construct an 'addressof'
160  /// operation at the current 'builder' insertion point to get a pointer to it.
161  /// Multiple calls with the same string will reuse the same global. It is
162  /// guaranteed that the symbol of the global will be unique.
163  Value buildString(OpBuilder &builder, Location loc, StringRef str) const {
164  auto &global = globals.stringCache[builder.getStringAttr(str)];
165  if (!global) {
166  OpBuilder::InsertionGuard guard(builder);
167  auto module =
168  builder.getBlock()->getParent()->getParentOfType<ModuleOp>();
169  builder.setInsertionPointToEnd(module.getBody());
170  auto arrayTy =
171  LLVM::LLVMArrayType::get(builder.getI8Type(), str.size() + 1);
172  auto strAttr = builder.getStringAttr(str.str() + '\00');
173  global = builder.create<LLVM::GlobalOp>(
174  loc, arrayTy, /*isConstant=*/true, LLVM::Linkage::Internal,
175  globals.names.newName("str"), strAttr);
176  }
177  return builder.create<LLVM::AddressOfOp>(loc, global);
178  }
179  /// Most API functions require a pointer to the the Z3 context object as the
180  /// first argument. This helper function prepends this pointer value to the
181  /// call for convenience.
182  LLVM::CallOp buildAPICallWithContext(OpBuilder &builder, Location loc,
183  StringRef name, Type returnType,
184  ValueRange args = {}) const {
185  auto ctx = buildContextPtr(builder, loc);
186  SmallVector<Value> arguments;
187  arguments.emplace_back(ctx);
188  arguments.append(SmallVector<Value>(args));
189  return buildCall(
190  builder, loc, name,
192  returnType, SmallVector<Type>(ValueRange(arguments).getTypes())),
193  arguments);
194  }
195 
196  /// Most API functions we need to call return a 'Z3_AST' object which is a
197  /// pointer in LLVM. This helper function simplifies calling those API
198  /// functions.
199  Value buildPtrAPICall(OpBuilder &builder, Location loc, StringRef name,
200  ValueRange args = {}) const {
201  return buildAPICallWithContext(
202  builder, loc, name,
203  LLVM::LLVMPointerType::get(builder.getContext()), args)
204  ->getResult(0);
205  }
206 
207  /// Build a value representing the SMT sort given with 'type'.
208  Value buildSort(OpBuilder &builder, Location loc, Type type) const {
209  // NOTE: if a type not handled by this switch is passed, an assertion will
210  // be triggered.
211  return TypeSwitch<Type, Value>(type)
212  .Case([&](smt::IntType ty) {
213  return buildPtrAPICall(builder, loc, "Z3_mk_int_sort");
214  })
215  .Case([&](smt::BitVectorType ty) {
216  Value bitwidth = builder.create<LLVM::ConstantOp>(
217  loc, builder.getI32Type(), ty.getWidth());
218  return buildPtrAPICall(builder, loc, "Z3_mk_bv_sort", {bitwidth});
219  })
220  .Case([&](smt::BoolType ty) {
221  return buildPtrAPICall(builder, loc, "Z3_mk_bool_sort");
222  })
223  .Case([&](smt::SortType ty) {
224  Value str = buildString(builder, loc, ty.getIdentifier());
225  Value sym =
226  buildPtrAPICall(builder, loc, "Z3_mk_string_symbol", {str});
227  return buildPtrAPICall(builder, loc, "Z3_mk_uninterpreted_sort",
228  {sym});
229  })
230  .Case([&](smt::ArrayType ty) {
231  return buildPtrAPICall(builder, loc, "Z3_mk_array_sort",
232  {buildSort(builder, loc, ty.getDomainType()),
233  buildSort(builder, loc, ty.getRangeType())});
234  });
235  }
236 
237  SMTGlobalsHandler &globals;
238  const LowerSMTToZ3LLVMOptions &options;
239 };
240 
241 //===----------------------------------------------------------------------===//
242 // Lowering Patterns
243 //===----------------------------------------------------------------------===//
244 
245 /// The 'smt.declare_fun' operation is used to declare both constants and
246 /// functions. The Z3 API, however, uses two different functions. Therefore,
247 /// depending on the result type of this operation, one of the following two
248 /// API functions is used to create the symbolic value:
249 /// ```
250 /// Z3_ast Z3_API Z3_mk_fresh_const(Z3_context c, Z3_string prefix, Z3_sort ty);
251 /// Z3_func_decl Z3_API Z3_mk_fresh_func_decl(
252 /// Z3_context c, Z3_string prefix, unsigned domain_size,
253 /// Z3_sort const domain[], Z3_sort range);
254 /// ```
255 struct DeclareFunOpLowering : public SMTLoweringPattern<DeclareFunOp> {
256  using SMTLoweringPattern::SMTLoweringPattern;
257 
258  LogicalResult
259  matchAndRewrite(DeclareFunOp op, OpAdaptor adaptor,
260  ConversionPatternRewriter &rewriter) const final {
261  Location loc = op.getLoc();
262 
263  // Create the name prefix.
264  Value prefix;
265  if (adaptor.getNamePrefix())
266  prefix = buildString(rewriter, loc, *adaptor.getNamePrefix());
267  else
268  prefix = rewriter.create<LLVM::ZeroOp>(
269  loc, LLVM::LLVMPointerType::get(getContext()));
270 
271  // Handle the constant value case.
272  if (!isa<SMTFuncType>(op.getType())) {
273  Value sort = buildSort(rewriter, loc, op.getType());
274  Value constDecl =
275  buildPtrAPICall(rewriter, loc, "Z3_mk_fresh_const", {prefix, sort});
276  rewriter.replaceOp(op, constDecl);
277  return success();
278  }
279 
280  // Otherwise, we declare a function.
281  Type llvmPtrTy = LLVM::LLVMPointerType::get(getContext());
282  auto funcType = cast<SMTFuncType>(op.getResult().getType());
283  Value rangeSort = buildSort(rewriter, loc, funcType.getRangeType());
284 
285  Type arrTy =
286  LLVM::LLVMArrayType::get(llvmPtrTy, funcType.getDomainTypes().size());
287 
288  Value domain = rewriter.create<LLVM::UndefOp>(loc, arrTy);
289  for (auto [i, ty] : llvm::enumerate(funcType.getDomainTypes())) {
290  Value sort = buildSort(rewriter, loc, ty);
291  domain = rewriter.create<LLVM::InsertValueOp>(loc, domain, sort, i);
292  }
293 
294  Value one =
295  rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI32Type(), 1);
296  Value domainStorage =
297  rewriter.create<LLVM::AllocaOp>(loc, llvmPtrTy, arrTy, one);
298  rewriter.create<LLVM::StoreOp>(loc, domain, domainStorage);
299 
300  Value domainSize = rewriter.create<LLVM::ConstantOp>(
301  loc, rewriter.getI32Type(), funcType.getDomainTypes().size());
302  Value decl =
303  buildPtrAPICall(rewriter, loc, "Z3_mk_fresh_func_decl",
304  {prefix, domainSize, domainStorage, rangeSort});
305 
306  rewriter.replaceOp(op, decl);
307  return success();
308  }
309 };
310 
311 /// Lower the 'smt.apply_func' operation to Z3 API calls of the form:
312 /// ```
313 /// Z3_ast Z3_API Z3_mk_app(Z3_context c, Z3_func_decl d,
314 /// unsigned num_args, Z3_ast const args[]);
315 /// ```
316 struct ApplyFuncOpLowering : public SMTLoweringPattern<ApplyFuncOp> {
317  using SMTLoweringPattern::SMTLoweringPattern;
318 
319  LogicalResult
320  matchAndRewrite(ApplyFuncOp op, OpAdaptor adaptor,
321  ConversionPatternRewriter &rewriter) const final {
322  Location loc = op.getLoc();
323  Type llvmPtrTy = LLVM::LLVMPointerType::get(getContext());
324  Type arrTy = LLVM::LLVMArrayType::get(llvmPtrTy, adaptor.getArgs().size());
325 
326  // Create an array of the function arguments.
327  Value domain = rewriter.create<LLVM::UndefOp>(loc, arrTy);
328  for (auto [i, arg] : llvm::enumerate(adaptor.getArgs()))
329  domain = rewriter.create<LLVM::InsertValueOp>(loc, domain, arg, i);
330 
331  // Store the array on the stack.
332  Value one =
333  rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI32Type(), 1);
334  Value domainStorage =
335  rewriter.create<LLVM::AllocaOp>(loc, llvmPtrTy, arrTy, one);
336  rewriter.create<LLVM::StoreOp>(loc, domain, domainStorage);
337 
338  // Call the API function with a pointer to the function, the number of
339  // arguments, and the pointer to the arguments stored on the stack.
340  Value domainSize = rewriter.create<LLVM::ConstantOp>(
341  loc, rewriter.getI32Type(), adaptor.getArgs().size());
342  Value returnVal =
343  buildPtrAPICall(rewriter, loc, "Z3_mk_app",
344  {adaptor.getFunc(), domainSize, domainStorage});
345  rewriter.replaceOp(op, returnVal);
346 
347  return success();
348  }
349 };
350 
351 /// Lower the `smt.bv.constant` operation to either
352 /// ```
353 /// Z3_ast Z3_API Z3_mk_unsigned_int64(Z3_context c, uint64_t v, Z3_sort ty);
354 /// ```
355 /// if the bit-vector fits into a 64-bit integer or convert it to a string and
356 /// use the sligtly slower but arbitrary precision API function:
357 /// ```
358 /// Z3_ast Z3_API Z3_mk_numeral(Z3_context c, Z3_string numeral, Z3_sort ty);
359 /// ```
360 /// Note that there is also an API function taking an array of booleans, and
361 /// while those are typically compiled to 'i8' in LLVM they don't necessarily
362 /// have to (I think).
363 struct BVConstantOpLowering : public SMTLoweringPattern<smt::BVConstantOp> {
364  using SMTLoweringPattern::SMTLoweringPattern;
365 
366  LogicalResult
367  matchAndRewrite(smt::BVConstantOp op, OpAdaptor adaptor,
368  ConversionPatternRewriter &rewriter) const final {
369  Location loc = op.getLoc();
370  unsigned width = op.getType().getWidth();
371  auto bvSort = buildSort(rewriter, loc, op.getResult().getType());
372  APInt val = adaptor.getValue().getValue();
373 
374  if (width <= 64) {
375  Value bvConst = rewriter.create<LLVM::ConstantOp>(
376  loc, rewriter.getI64Type(), val.getZExtValue());
377  Value res = buildPtrAPICall(rewriter, loc, "Z3_mk_unsigned_int64",
378  {bvConst, bvSort});
379  rewriter.replaceOp(op, res);
380  return success();
381  }
382 
383  std::string str;
384  llvm::raw_string_ostream stream(str);
385  stream << val;
386  Value bvString = buildString(rewriter, loc, str);
387  Value bvNumeral =
388  buildPtrAPICall(rewriter, loc, "Z3_mk_numeral", {bvString, bvSort});
389 
390  rewriter.replaceOp(op, bvNumeral);
391  return success();
392  }
393 };
394 
395 /// Some of the Z3 API supports a variadic number of operands for some
396 /// operations (in particular if the expansion would lead to a super-linear
397 /// increase in operations such as with the ':pairwise' attribute). Those API
398 /// calls take an 'unsigned' argument indicating the size of an array of
399 /// pointers to the operands.
400 template <typename SourceTy>
401 struct VariadicSMTPattern : public SMTLoweringPattern<SourceTy> {
402  using OpAdaptor = typename SMTLoweringPattern<SourceTy>::OpAdaptor;
403 
404  VariadicSMTPattern(const TypeConverter &typeConverter, MLIRContext *context,
405  SMTGlobalsHandler &globals,
406  const LowerSMTToZ3LLVMOptions &options,
407  StringRef apiFuncName, unsigned minNumArgs)
408  : SMTLoweringPattern<SourceTy>(typeConverter, context, globals, options),
409  apiFuncName(apiFuncName), minNumArgs(minNumArgs) {}
410 
411  LogicalResult
412  matchAndRewrite(SourceTy op, OpAdaptor adaptor,
413  ConversionPatternRewriter &rewriter) const final {
414  if (adaptor.getOperands().size() < minNumArgs)
415  return failure();
416 
417  Location loc = op.getLoc();
418  Value numOperands = rewriter.create<LLVM::ConstantOp>(
419  loc, rewriter.getI32Type(), op->getNumOperands());
420  Value constOne =
421  rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI32Type(), 1);
422  Type ptrTy = LLVM::LLVMPointerType::get(rewriter.getContext());
423  Type arrTy = LLVM::LLVMArrayType::get(ptrTy, op->getNumOperands());
424  Value storage =
425  rewriter.create<LLVM::AllocaOp>(loc, ptrTy, arrTy, constOne);
426  Value array = rewriter.create<LLVM::UndefOp>(loc, arrTy);
427 
428  for (auto [i, operand] : llvm::enumerate(adaptor.getOperands()))
429  array = rewriter.create<LLVM::InsertValueOp>(
430  loc, array, operand, ArrayRef<int64_t>{(int64_t)i});
431 
432  rewriter.create<LLVM::StoreOp>(loc, array, storage);
433 
434  rewriter.replaceOp(op,
435  SMTLoweringPattern<SourceTy>::buildPtrAPICall(
436  rewriter, loc, apiFuncName, {numOperands, storage}));
437  return success();
438  }
439 
440 private:
441  StringRef apiFuncName;
442  unsigned minNumArgs;
443 };
444 
445 /// Lower an SMT operation to a function call with the name 'apiFuncName' with
446 /// arguments matching the operands one-to-one.
447 template <typename SourceTy>
448 struct OneToOneSMTPattern : public SMTLoweringPattern<SourceTy> {
449  using OpAdaptor = typename SMTLoweringPattern<SourceTy>::OpAdaptor;
450 
451  OneToOneSMTPattern(const TypeConverter &typeConverter, MLIRContext *context,
452  SMTGlobalsHandler &globals,
453  const LowerSMTToZ3LLVMOptions &options,
454  StringRef apiFuncName, unsigned numOperands)
455  : SMTLoweringPattern<SourceTy>(typeConverter, context, globals, options),
456  apiFuncName(apiFuncName), numOperands(numOperands) {}
457 
458  LogicalResult
459  matchAndRewrite(SourceTy op, OpAdaptor adaptor,
460  ConversionPatternRewriter &rewriter) const final {
461  if (adaptor.getOperands().size() != numOperands)
462  return failure();
463 
464  rewriter.replaceOp(
465  op, SMTLoweringPattern<SourceTy>::buildPtrAPICall(
466  rewriter, op.getLoc(), apiFuncName, adaptor.getOperands()));
467  return success();
468  }
469 
470 private:
471  StringRef apiFuncName;
472  unsigned numOperands;
473 };
474 
475 /// A pattern to lower SMT operations with a variadic number of operands
476 /// modelling the ':chainable' attribute in SMT to binary operations.
477 template <typename SourceTy>
478 class LowerChainableSMTPattern : public SMTLoweringPattern<SourceTy> {
479  using SMTLoweringPattern<SourceTy>::SMTLoweringPattern;
480  using OpAdaptor = typename SMTLoweringPattern<SourceTy>::OpAdaptor;
481 
482  LogicalResult
483  matchAndRewrite(SourceTy op, OpAdaptor adaptor,
484  ConversionPatternRewriter &rewriter) const final {
485  if (adaptor.getOperands().size() <= 2)
486  return failure();
487 
488  Location loc = op.getLoc();
489  SmallVector<Value> elements;
490  for (int i = 1, e = adaptor.getOperands().size(); i < e; ++i) {
491  Value val = rewriter.create<SourceTy>(
492  loc, op->getResultTypes(),
493  ValueRange{adaptor.getOperands()[i - 1], adaptor.getOperands()[i]});
494  elements.push_back(val);
495  }
496  rewriter.replaceOpWithNewOp<smt::AndOp>(op, elements);
497  return success();
498  }
499 };
500 
501 /// A pattern to lower SMT operations with a variadic number of operands
502 /// modelling the `:left-assoc` attribute to a sequence of binary operators.
503 template <typename SourceTy>
504 class LowerLeftAssocSMTPattern : public SMTLoweringPattern<SourceTy> {
505  using SMTLoweringPattern<SourceTy>::SMTLoweringPattern;
506  using OpAdaptor = typename SMTLoweringPattern<SourceTy>::OpAdaptor;
507 
508  LogicalResult
509  matchAndRewrite(SourceTy op, OpAdaptor adaptor,
510  ConversionPatternRewriter &rewriter) const final {
511  if (adaptor.getOperands().size() <= 2)
512  return rewriter.notifyMatchFailure(op, "must have at least two operands");
513 
514  Value runner = adaptor.getOperands()[0];
515  for (Value val : adaptor.getOperands().drop_front())
516  runner = rewriter.create<SourceTy>(op.getLoc(), op->getResultTypes(),
517  ValueRange{runner, val});
518 
519  rewriter.replaceOp(op, runner);
520  return success();
521  }
522 };
523 
524 /// The 'smt.solver' operation has a region that corresponds to the lifetime of
525 /// the Z3 context and one solver instance created within this context.
526 /// To create a context, a Z3 configuration has to be built first and various
527 /// configuration parameters can be set before creating a context from it. Once
528 /// we have a context, we can create a solver and store a pointer to the context
529 /// and the solver in an LLVM global such that operations in the child region
530 /// have access to them. While the context created with `Z3_mk_context` takes
531 /// care of the reference counting of `Z3_AST` objects, it still requires manual
532 /// reference counting of `Z3_solver` objects, therefore, we need to increase
533 /// the ref. counter of the solver we get from `Z3_mk_solver` and must decrease
534 /// it again once we don't need it anymore. Finally, the configuration object
535 /// can be deleted.
536 /// ```
537 /// Z3_config Z3_API Z3_mk_config(void);
538 /// void Z3_API Z3_set_param_value(Z3_config c, Z3_string param_id,
539 /// Z3_string param_value);
540 /// Z3_context Z3_API Z3_mk_context(Z3_config c);
541 /// Z3_solver Z3_API Z3_mk_solver(Z3_context c);
542 /// void Z3_API Z3_solver_inc_ref(Z3_context c, Z3_solver s);
543 /// void Z3_API Z3_del_config(Z3_config c);
544 /// ```
545 /// At the end of the solver lifetime, we have to tell the context that we
546 /// don't need the solver anymore and delete the context itself.
547 /// ```
548 /// void Z3_API Z3_solver_dec_ref(Z3_context c, Z3_solver s);
549 /// void Z3_API Z3_del_context(Z3_context c);
550 /// ```
551 /// Note that the solver created here is a combined solver. There might be some
552 /// potential for optimization by creating more specialized solvers supported by
553 /// the Z3 API according the the kind of operations present in the body region.
554 struct SolverOpLowering : public SMTLoweringPattern<SolverOp> {
555  using SMTLoweringPattern::SMTLoweringPattern;
556 
557  LogicalResult
558  matchAndRewrite(SolverOp op, OpAdaptor adaptor,
559  ConversionPatternRewriter &rewriter) const final {
560  Location loc = op.getLoc();
561  auto ptrTy = LLVM::LLVMPointerType::get(getContext());
562  auto voidTy = LLVM::LLVMVoidType::get(getContext());
563  auto ptrToPtrFunc = LLVM::LLVMFunctionType::get(ptrTy, ptrTy);
564  auto ptrToVoidFunc = LLVM::LLVMFunctionType::get(voidTy, ptrTy);
565  auto ptrPtrToVoidFunc = LLVM::LLVMFunctionType::get(voidTy, {ptrTy, ptrTy});
566 
567  // Create the configuration.
568  Value config = buildCall(rewriter, loc, "Z3_mk_config",
569  LLVM::LLVMFunctionType::get(ptrTy, {}), {})
570  .getResult();
571 
572  // In debug-mode, we enable proofs such that we can fetch one in the 'unsat'
573  // region of each 'smt.check' operation.
574  if (options.debug) {
575  Value paramKey = buildString(rewriter, loc, "proof");
576  Value paramValue = buildString(rewriter, loc, "true");
577  buildCall(rewriter, loc, "Z3_set_param_value",
578  LLVM::LLVMFunctionType::get(voidTy, {ptrTy, ptrTy, ptrTy}),
579  {config, paramKey, paramValue});
580  }
581 
582  // Create the context and store a pointer to it in the global variable.
583  Value ctx = buildCall(rewriter, loc, "Z3_mk_context", ptrToPtrFunc, config)
584  .getResult();
585  Value ctxAddr =
586  rewriter.create<LLVM::AddressOfOp>(loc, globals.ctx).getResult();
587  rewriter.create<LLVM::StoreOp>(loc, ctx, ctxAddr);
588 
589  // Delete the configuration again.
590  buildCall(rewriter, loc, "Z3_del_config", ptrToVoidFunc, {config});
591 
592  // Create a solver instance, increase its reference counter, and store a
593  // pointer to it in the global variable.
594  Value solver = buildCall(rewriter, loc, "Z3_mk_solver", ptrToPtrFunc, ctx)
595  ->getResult(0);
596  buildCall(rewriter, loc, "Z3_solver_inc_ref", ptrPtrToVoidFunc,
597  {ctx, solver});
598  Value solverAddr =
599  rewriter.create<LLVM::AddressOfOp>(loc, globals.solver).getResult();
600  rewriter.create<LLVM::StoreOp>(loc, solver, solverAddr);
601 
602  // This assumes that no constant hoisting of the like happens inbetween
603  // the patterns defined in this pass because once the solver initialization
604  // and deallocation calls are inserted and the body region is inlined,
605  // canonicalizations and folders applied inbetween lowering patterns might
606  // hoist the SMT constants which means they would access uninitialized
607  // global variables once they are lowered.
608  SmallVector<Type> convertedTypes;
609  if (failed(
610  typeConverter->convertTypes(op->getResultTypes(), convertedTypes)))
611  return failure();
612 
613  func::FuncOp funcOp;
614  {
615  OpBuilder::InsertionGuard guard(rewriter);
616  auto module = op->getParentOfType<ModuleOp>();
617  rewriter.setInsertionPointToEnd(module.getBody());
618 
619  funcOp = rewriter.create<func::FuncOp>(
620  loc, globals.names.newName("solver"),
621  rewriter.getFunctionType(adaptor.getInputs().getTypes(),
622  convertedTypes));
623  rewriter.inlineRegionBefore(op.getBodyRegion(), funcOp.getBody(),
624  funcOp.end());
625  }
626 
627  ValueRange results =
628  rewriter.create<func::CallOp>(loc, funcOp, adaptor.getInputs())
629  ->getResults();
630 
631  // At the end of the region, decrease the solver's reference counter and
632  // delete the context.
633  // NOTE: we cannot use the convenience helper here because we don't want to
634  // load the context from the global but use the result from the 'mk_context'
635  // call directly for two reasons:
636  // * avoid an unnecessary load
637  // * the caching mechanism of the context does not work here because it
638  // would reuse the loaded context from a earlier solver
639  buildCall(rewriter, loc, "Z3_solver_dec_ref", ptrPtrToVoidFunc,
640  {ctx, solver});
641  buildCall(rewriter, loc, "Z3_del_context", ptrToVoidFunc, ctx);
642 
643  rewriter.replaceOp(op, results);
644  return success();
645  }
646 };
647 
648 /// Lower `smt.assert` operations to Z3 API calls of the form:
649 /// ```
650 /// void Z3_API Z3_solver_assert(Z3_context c, Z3_solver s, Z3_ast a);
651 /// ```
652 struct AssertOpLowering : public SMTLoweringPattern<AssertOp> {
653  using SMTLoweringPattern::SMTLoweringPattern;
654 
655  LogicalResult
656  matchAndRewrite(AssertOp op, OpAdaptor adaptor,
657  ConversionPatternRewriter &rewriter) const final {
658  Location loc = op.getLoc();
659  buildAPICallWithContext(
660  rewriter, loc, "Z3_solver_assert",
661  LLVM::LLVMVoidType::get(getContext()),
662  {buildSolverPtr(rewriter, loc), adaptor.getInput()});
663 
664  rewriter.eraseOp(op);
665  return success();
666  }
667 };
668 
669 /// Lower `smt.yield` operations to `scf.yield` operations. This not necessary
670 /// for the yield in `smt.solver` or in quantifiers since they are deleted
671 /// directly by the parent operation, but makes the lowering of the `smt.check`
672 /// operation simpler and more convenient since the regions get translated
673 /// directly to regions of `scf.if` operations.
674 struct YieldOpLowering : public SMTLoweringPattern<YieldOp> {
675  using SMTLoweringPattern::SMTLoweringPattern;
676 
677  LogicalResult
678  matchAndRewrite(YieldOp op, OpAdaptor adaptor,
679  ConversionPatternRewriter &rewriter) const final {
680  if (op->getParentOfType<func::FuncOp>()) {
681  rewriter.replaceOpWithNewOp<func::ReturnOp>(op, adaptor.getValues());
682  return success();
683  }
684  if (op->getParentOfType<LLVM::LLVMFuncOp>()) {
685  rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(op, adaptor.getValues());
686  return success();
687  }
688  if (isa<scf::SCFDialect>(op->getParentOp()->getDialect())) {
689  rewriter.replaceOpWithNewOp<scf::YieldOp>(op, adaptor.getValues());
690  return success();
691  }
692  return failure();
693  }
694 };
695 
696 /// Lower `smt.check` operations to Z3 API calls and control-flow operations.
697 /// ```
698 /// Z3_lbool Z3_API Z3_solver_check(Z3_context c, Z3_solver s);
699 ///
700 /// typedef enum
701 /// {
702 /// Z3_L_FALSE = -1, // means unsatisfiable here
703 /// Z3_L_UNDEF, // means unknown here
704 /// Z3_L_TRUE // means satisfiable here
705 /// } Z3_lbool;
706 /// ```
707 struct CheckOpLowering : public SMTLoweringPattern<CheckOp> {
708  using SMTLoweringPattern::SMTLoweringPattern;
709 
710  LogicalResult
711  matchAndRewrite(CheckOp op, OpAdaptor adaptor,
712  ConversionPatternRewriter &rewriter) const final {
713  Location loc = op.getLoc();
714  auto ptrTy = LLVM::LLVMPointerType::get(rewriter.getContext());
715  auto printfType =
716  LLVM::LLVMFunctionType::get(rewriter.getI32Type(), {ptrTy}, true);
717 
718  auto getHeaderString = [](const std::string &title) {
719  unsigned titleSize = title.size() + 2; // Add a space left and right
720  return std::string((80 - titleSize) / 2, '-') + " " + title + " " +
721  std::string((80 - titleSize + 1) / 2, '-') + "\n%s\n" +
722  std::string(80, '-') + "\n";
723  };
724 
725  // Get the pointer to the solver instance.
726  Value solver = buildSolverPtr(rewriter, loc);
727 
728  // In debug-mode, print the state of the solver before calling 'check-sat'
729  // on it. This prints the asserted SMT expressions.
730  if (options.debug) {
731  auto solverStringPtr =
732  buildPtrAPICall(rewriter, loc, "Z3_solver_to_string", {solver});
733  auto solverFormatString =
734  buildString(rewriter, loc, getHeaderString("Solver"));
735  buildCall(rewriter, op.getLoc(), "printf", printfType,
736  {solverFormatString, solverStringPtr});
737  }
738 
739  // Convert the result types of the `smt.check` operation.
740  SmallVector<Type> resultTypes;
741  if (failed(typeConverter->convertTypes(op->getResultTypes(), resultTypes)))
742  return failure();
743 
744  // Call 'check-sat' and check if the assertions are satisfiable.
745  Value checkResult =
746  buildAPICallWithContext(rewriter, loc, "Z3_solver_check",
747  rewriter.getI32Type(), {solver})
748  ->getResult(0);
749  Value constOne =
750  rewriter.create<LLVM::ConstantOp>(loc, checkResult.getType(), 1);
751  Value isSat = rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::eq,
752  checkResult, constOne);
753 
754  // Simply inline the 'sat' region into the 'then' region of the 'scf.if'
755  auto satIfOp = rewriter.create<scf::IfOp>(loc, resultTypes, isSat);
756  rewriter.inlineRegionBefore(op.getSatRegion(), satIfOp.getThenRegion(),
757  satIfOp.getThenRegion().end());
758 
759  // Otherwise, the 'else' block checks if the assertions are unsatisfiable or
760  // unknown. The corresponding regions can also be simply inlined into the
761  // two branches of this nested if-statement as well.
762  rewriter.createBlock(&satIfOp.getElseRegion());
763  Value constNegOne =
764  rewriter.create<LLVM::ConstantOp>(loc, checkResult.getType(), -1);
765  Value isUnsat = rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::eq,
766  checkResult, constNegOne);
767  auto unsatIfOp = rewriter.create<scf::IfOp>(loc, resultTypes, isUnsat);
768  rewriter.create<scf::YieldOp>(loc, unsatIfOp->getResults());
769 
770  rewriter.inlineRegionBefore(op.getUnsatRegion(), unsatIfOp.getThenRegion(),
771  unsatIfOp.getThenRegion().end());
772  rewriter.inlineRegionBefore(op.getUnknownRegion(),
773  unsatIfOp.getElseRegion(),
774  unsatIfOp.getElseRegion().end());
775 
776  rewriter.replaceOp(op, satIfOp->getResults());
777 
778  if (options.debug) {
779  // In debug-mode, if the assertions are unsatisfiable we can print the
780  // proof.
781  rewriter.setInsertionPointToStart(unsatIfOp.thenBlock());
782  auto proof = buildPtrAPICall(rewriter, op.getLoc(), "Z3_solver_get_proof",
783  {solver});
784  auto stringPtr =
785  buildPtrAPICall(rewriter, op.getLoc(), "Z3_ast_to_string", {proof});
786  auto formatString =
787  buildString(rewriter, op.getLoc(), getHeaderString("Proof"));
788  buildCall(rewriter, op.getLoc(), "printf", printfType,
789  {formatString, stringPtr});
790 
791  // In debug mode, if the assertions are satisfiable we can print the model
792  // (effectively a counter-example).
793  rewriter.setInsertionPointToStart(satIfOp.thenBlock());
794  auto model = buildPtrAPICall(rewriter, op.getLoc(), "Z3_solver_get_model",
795  {solver});
796  auto modelStringPtr =
797  buildPtrAPICall(rewriter, op.getLoc(), "Z3_model_to_string", {model});
798  auto modelFormatString =
799  buildString(rewriter, op.getLoc(), getHeaderString("Model"));
800  buildCall(rewriter, op.getLoc(), "printf", printfType,
801  {modelFormatString, modelStringPtr});
802  }
803 
804  return success();
805  }
806 };
807 
808 /// Lower `smt.forall` and `smt.exists` operations to the following Z3 API call.
809 /// ```
810 /// Z3_ast Z3_API Z3_mk_{forall|exists}_const(
811 /// Z3_context c,
812 /// unsigned weight,
813 /// unsigned num_bound,
814 /// Z3_app const bound[],
815 /// unsigned num_patterns,
816 /// Z3_pattern const patterns[],
817 /// Z3_ast body
818 /// );
819 /// ```
820 /// All nested regions are inlined into the parent region and the block
821 /// arguments are replaced with new `smt.declare_fun` constants that are also
822 /// passed to the `bound` argument of above API function. Patterns are created
823 /// with the following API function.
824 /// ```
825 /// Z3_pattern Z3_API Z3_mk_pattern(Z3_context c, unsigned num_patterns,
826 /// Z3_ast const terms[]);
827 /// ```
828 /// Where each operand of the `smt.yield` in a pattern region is a 'term'.
829 template <typename QuantifierOp>
830 struct QuantifierLowering : public SMTLoweringPattern<QuantifierOp> {
831  using SMTLoweringPattern<QuantifierOp>::SMTLoweringPattern;
832  using SMTLoweringPattern<QuantifierOp>::typeConverter;
833  using SMTLoweringPattern<QuantifierOp>::buildPtrAPICall;
834  using OpAdaptor = typename QuantifierOp::Adaptor;
835 
836  Value createStorageForValueList(ValueRange values, Location loc,
837  ConversionPatternRewriter &rewriter) const {
838  Type ptrTy = LLVM::LLVMPointerType::get(rewriter.getContext());
839  Type arrTy = LLVM::LLVMArrayType::get(ptrTy, values.size());
840  Value constOne =
841  rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI32Type(), 1);
842  Value storage =
843  rewriter.create<LLVM::AllocaOp>(loc, ptrTy, arrTy, constOne);
844  Value array = rewriter.create<LLVM::UndefOp>(loc, arrTy);
845 
846  for (auto [i, val] : llvm::enumerate(values))
847  array = rewriter.create<LLVM::InsertValueOp>(loc, array, val,
848  ArrayRef<int64_t>(i));
849 
850  rewriter.create<LLVM::StoreOp>(loc, array, storage);
851 
852  return storage;
853  }
854 
855  LogicalResult
856  matchAndRewrite(QuantifierOp op, OpAdaptor adaptor,
857  ConversionPatternRewriter &rewriter) const final {
858  Location loc = op.getLoc();
859  Type ptrTy = LLVM::LLVMPointerType::get(rewriter.getContext());
860 
861  // no-pattern attribute not supported yet because the Z3 CAPI allows more
862  // fine-grained control where a list of patterns to be banned can be given.
863  // This means, the no-pattern attribute is equivalent to providing a list of
864  // all possible sub-expressions in the quantifier body to the CAPI.
865  if (adaptor.getNoPattern())
866  return rewriter.notifyMatchFailure(
867  op, "no-pattern attribute not yet supported!");
868 
869  rewriter.setInsertionPoint(op);
870 
871  // Weight attribute
872  Value weight = rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI32Type(),
873  adaptor.getWeight());
874 
875  // Bound variables
876  unsigned numDecls = op.getBody().getNumArguments();
877  Value numDeclsVal =
878  rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI32Type(), numDecls);
879 
880  // We replace the block arguments with constant symbolic values and inform
881  // the quantifier API call which constants it should treat as bound
882  // variables. We also need to make sure that we use the exact same SSA
883  // values in the pattern regions since we lower constant declaration
884  // operation to always produce fresh constants.
885  SmallVector<Value> repl;
886  for (auto [i, arg] : llvm::enumerate(op.getBody().getArguments())) {
887  Value newArg;
888  if (adaptor.getBoundVarNames().has_value())
889  newArg = rewriter.create<smt::DeclareFunOp>(
890  loc, arg.getType(),
891  cast<StringAttr>((*adaptor.getBoundVarNames())[i]));
892  else
893  newArg = rewriter.create<smt::DeclareFunOp>(loc, arg.getType());
894  repl.push_back(typeConverter->materializeTargetConversion(
895  rewriter, loc, typeConverter->convertType(arg.getType()), newArg));
896  }
897 
898  Value boundStorage = createStorageForValueList(repl, loc, rewriter);
899 
900  // Body Expression
901  auto yieldOp = cast<smt::YieldOp>(op.getBody().front().getTerminator());
902  Value bodyExp = yieldOp.getValues()[0];
903  rewriter.setInsertionPointAfterValue(bodyExp);
904  bodyExp = typeConverter->materializeTargetConversion(
905  rewriter, loc, typeConverter->convertType(bodyExp.getType()), bodyExp);
906  rewriter.eraseOp(yieldOp);
907 
908  rewriter.inlineBlockBefore(&op.getBody().front(), op, repl);
909  rewriter.setInsertionPoint(op);
910 
911  // Patterns
912  unsigned numPatterns = adaptor.getPatterns().size();
913  Value numPatternsVal = rewriter.create<LLVM::ConstantOp>(
914  loc, rewriter.getI32Type(), numPatterns);
915 
916  Value patternStorage;
917  if (numPatterns > 0) {
918  SmallVector<Value> patterns;
919  for (Region *patternRegion : adaptor.getPatterns()) {
920  auto yieldOp =
921  cast<smt::YieldOp>(patternRegion->front().getTerminator());
922  auto patternTerms = yieldOp.getOperands();
923 
924  rewriter.setInsertionPoint(yieldOp);
925  SmallVector<Value> patternList;
926  for (auto val : patternTerms)
927  patternList.push_back(typeConverter->materializeTargetConversion(
928  rewriter, loc, typeConverter->convertType(val.getType()), val));
929 
930  rewriter.eraseOp(yieldOp);
931  rewriter.inlineBlockBefore(&patternRegion->front(), op, repl);
932 
933  rewriter.setInsertionPoint(op);
934  Value numTerms = rewriter.create<LLVM::ConstantOp>(
935  loc, rewriter.getI32Type(), patternTerms.size());
936  Value patternTermStorage =
937  createStorageForValueList(patternList, loc, rewriter);
938  Value pattern = buildPtrAPICall(rewriter, loc, "Z3_mk_pattern",
939  {numTerms, patternTermStorage});
940 
941  patterns.emplace_back(pattern);
942  }
943  patternStorage = createStorageForValueList(patterns, loc, rewriter);
944  } else {
945  // If we set the num_patterns parameter to 0, we can just pass a nullptr
946  // as storage.
947  patternStorage = rewriter.create<LLVM::ZeroOp>(loc, ptrTy);
948  }
949 
950  StringRef apiCallName = "Z3_mk_forall_const";
951  if (std::is_same_v<QuantifierOp, ExistsOp>)
952  apiCallName = "Z3_mk_exists_const";
953  Value quantifierExp =
954  buildPtrAPICall(rewriter, loc, apiCallName,
955  {weight, numDeclsVal, boundStorage, numPatternsVal,
956  patternStorage, bodyExp});
957 
958  rewriter.replaceOp(op, quantifierExp);
959  return success();
960  }
961 };
962 
963 /// Lower `smt.bv.repeat` operations to Z3 API function calls of the form
964 /// ```
965 /// Z3_ast Z3_API Z3_mk_repeat(Z3_context c, unsigned i, Z3_ast t1);
966 /// ```
967 struct RepeatOpLowering : public SMTLoweringPattern<RepeatOp> {
968  using SMTLoweringPattern::SMTLoweringPattern;
969 
970  LogicalResult
971  matchAndRewrite(RepeatOp op, OpAdaptor adaptor,
972  ConversionPatternRewriter &rewriter) const final {
973  Value count = rewriter.create<LLVM::ConstantOp>(
974  op.getLoc(), rewriter.getI32Type(), op.getCount());
975  rewriter.replaceOp(op,
976  buildPtrAPICall(rewriter, op.getLoc(), "Z3_mk_repeat",
977  {count, adaptor.getInput()}));
978  return success();
979  }
980 };
981 
982 /// Lower `smt.bv.extract` operations to Z3 API function calls of the following
983 /// form, where the output bit-vector has size `n = high - low + 1`. This means,
984 /// both the 'high' and 'low' indices are inclusive.
985 /// ```
986 /// Z3_ast Z3_API Z3_mk_extract(Z3_context c, unsigned high, unsigned low,
987 /// Z3_ast t1);
988 /// ```
989 struct ExtractOpLowering : public SMTLoweringPattern<ExtractOp> {
990  using SMTLoweringPattern::SMTLoweringPattern;
991 
992  LogicalResult
993  matchAndRewrite(ExtractOp op, OpAdaptor adaptor,
994  ConversionPatternRewriter &rewriter) const final {
995  Location loc = op.getLoc();
996  Value low = rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI32Type(),
997  adaptor.getLowBit());
998  Value high = rewriter.create<LLVM::ConstantOp>(
999  loc, rewriter.getI32Type(),
1000  adaptor.getLowBit() + op.getType().getWidth() - 1);
1001  rewriter.replaceOp(op, buildPtrAPICall(rewriter, loc, "Z3_mk_extract",
1002  {high, low, adaptor.getInput()}));
1003  return success();
1004  }
1005 };
1006 
1007 /// Lower `smt.array.broadcast` operations to Z3 API function calls of the form
1008 /// ```
1009 /// Z3_ast Z3_API Z3_mk_const_array(Z3_context c, Z3_sort domain, Z3_ast v);
1010 /// ```
1011 struct ArrayBroadcastOpLowering
1012  : public SMTLoweringPattern<smt::ArrayBroadcastOp> {
1013  using SMTLoweringPattern::SMTLoweringPattern;
1014 
1015  LogicalResult
1016  matchAndRewrite(smt::ArrayBroadcastOp op, OpAdaptor adaptor,
1017  ConversionPatternRewriter &rewriter) const final {
1018  auto domainSort = buildSort(
1019  rewriter, op.getLoc(),
1020  cast<smt::ArrayType>(op.getResult().getType()).getDomainType());
1021 
1022  rewriter.replaceOp(op, buildPtrAPICall(rewriter, op.getLoc(),
1023  "Z3_mk_const_array",
1024  {domainSort, adaptor.getValue()}));
1025  return success();
1026  }
1027 };
1028 
1029 /// Lower the `smt.constant` operation to one of the following Z3 API function
1030 /// calls depending on the value of the boolean attribute.
1031 /// ```
1032 /// Z3_ast Z3_API Z3_mk_true(Z3_context c);
1033 /// Z3_ast Z3_API Z3_mk_false(Z3_context c);
1034 /// ```
1035 struct BoolConstantOpLowering : public SMTLoweringPattern<smt::BoolConstantOp> {
1036  using SMTLoweringPattern::SMTLoweringPattern;
1037 
1038  LogicalResult
1039  matchAndRewrite(smt::BoolConstantOp op, OpAdaptor adaptor,
1040  ConversionPatternRewriter &rewriter) const final {
1041  rewriter.replaceOp(
1042  op, buildPtrAPICall(rewriter, op.getLoc(),
1043  adaptor.getValue() ? "Z3_mk_true" : "Z3_mk_false"));
1044  return success();
1045  }
1046 };
1047 
1048 /// Lower `smt.int.constant` operations to one of the following two Z3 API
1049 /// function calls depending on whether the storage APInt has a bit-width that
1050 /// fits in a `uint64_t`.
1051 /// ```
1052 /// Z3_sort Z3_API Z3_mk_int_sort(Z3_context c);
1053 ///
1054 /// Z3_ast Z3_API Z3_mk_int64(Z3_context c, int64_t v, Z3_sort ty);
1055 ///
1056 /// Z3_ast Z3_API Z3_mk_numeral(Z3_context c, Z3_string numeral, Z3_sort ty);
1057 /// Z3_ast Z3_API Z3_mk_unary_minus(Z3_context c, Z3_ast arg);
1058 /// ```
1059 struct IntConstantOpLowering : public SMTLoweringPattern<smt::IntConstantOp> {
1060  using SMTLoweringPattern::SMTLoweringPattern;
1061 
1062  LogicalResult
1063  matchAndRewrite(smt::IntConstantOp op, OpAdaptor adaptor,
1064  ConversionPatternRewriter &rewriter) const final {
1065  Location loc = op.getLoc();
1066  Value type = buildPtrAPICall(rewriter, loc, "Z3_mk_int_sort");
1067  if (adaptor.getValue().getBitWidth() <= 64) {
1068  Value val = rewriter.create<LLVM::ConstantOp>(
1069  loc, rewriter.getI64Type(), adaptor.getValue().getSExtValue());
1070  rewriter.replaceOp(
1071  op, buildPtrAPICall(rewriter, loc, "Z3_mk_int64", {val, type}));
1072  return success();
1073  }
1074 
1075  std::string numeralStr;
1076  llvm::raw_string_ostream stream(numeralStr);
1077  stream << adaptor.getValue().abs();
1078 
1079  Value numeral = buildString(rewriter, loc, numeralStr);
1080  Value intNumeral =
1081  buildPtrAPICall(rewriter, loc, "Z3_mk_numeral", {numeral, type});
1082 
1083  if (adaptor.getValue().isNegative())
1084  intNumeral =
1085  buildPtrAPICall(rewriter, loc, "Z3_mk_unary_minus", intNumeral);
1086 
1087  rewriter.replaceOp(op, intNumeral);
1088  return success();
1089  }
1090 };
1091 
1092 /// Lower `smt.int.cmp` operations to one of the following Z3 API function calls
1093 /// depending on the predicate.
1094 /// ```
1095 /// Z3_ast Z3_API Z3_mk_{{pred}}(Z3_context c, Z3_ast t1, Z3_ast t2);
1096 /// ```
1097 struct IntCmpOpLowering : public SMTLoweringPattern<IntCmpOp> {
1098  using SMTLoweringPattern::SMTLoweringPattern;
1099 
1100  LogicalResult
1101  matchAndRewrite(IntCmpOp op, OpAdaptor adaptor,
1102  ConversionPatternRewriter &rewriter) const final {
1103  rewriter.replaceOp(
1104  op,
1105  buildPtrAPICall(rewriter, op.getLoc(),
1106  "Z3_mk_" + stringifyIntPredicate(op.getPred()).str(),
1107  {adaptor.getLhs(), adaptor.getRhs()}));
1108  return success();
1109  }
1110 };
1111 
1112 /// Lower `smt.bv.cmp` operations to one of the following Z3 API function calls,
1113 /// performing two's complement comparison, depending on the predicate
1114 /// attribute.
1115 /// ```
1116 /// Z3_ast Z3_API Z3_mk_bv{{pred}}(Z3_context c, Z3_ast t1, Z3_ast t2);
1117 /// ```
1118 struct BVCmpOpLowering : public SMTLoweringPattern<BVCmpOp> {
1119  using SMTLoweringPattern::SMTLoweringPattern;
1120 
1121  LogicalResult
1122  matchAndRewrite(BVCmpOp op, OpAdaptor adaptor,
1123  ConversionPatternRewriter &rewriter) const final {
1124  rewriter.replaceOp(
1125  op, buildPtrAPICall(rewriter, op.getLoc(),
1126  "Z3_mk_bv" +
1127  stringifyBVCmpPredicate(op.getPred()).str(),
1128  {adaptor.getLhs(), adaptor.getRhs()}));
1129  return success();
1130  }
1131 };
1132 
1133 /// Expand the `smt.int.abs` operation to a `smt.ite` operation.
1134 struct IntAbsOpLowering : public SMTLoweringPattern<IntAbsOp> {
1135  using SMTLoweringPattern::SMTLoweringPattern;
1136 
1137  LogicalResult
1138  matchAndRewrite(IntAbsOp op, OpAdaptor adaptor,
1139  ConversionPatternRewriter &rewriter) const final {
1140  Location loc = op.getLoc();
1141  Value zero = rewriter.create<IntConstantOp>(
1142  loc, rewriter.getIntegerAttr(rewriter.getI1Type(), 0));
1143  Value cmp = rewriter.create<IntCmpOp>(loc, IntPredicate::lt,
1144  adaptor.getInput(), zero);
1145  Value neg = rewriter.create<IntSubOp>(loc, zero, adaptor.getInput());
1146  rewriter.replaceOpWithNewOp<IteOp>(op, cmp, neg, adaptor.getInput());
1147  return success();
1148  }
1149 };
1150 
1151 } // namespace
1152 
1153 //===----------------------------------------------------------------------===//
1154 // Pass Implementation
1155 //===----------------------------------------------------------------------===//
1156 
1157 namespace {
1158 struct LowerSMTToZ3LLVMPass
1159  : public circt::impl::LowerSMTToZ3LLVMBase<LowerSMTToZ3LLVMPass> {
1160  using Base::Base;
1161  void runOnOperation() override;
1162 };
1163 } // namespace
1164 
1165 void circt::populateSMTToZ3LLVMTypeConverter(TypeConverter &converter) {
1166  converter.addConversion([](smt::BoolType type) {
1167  return LLVM::LLVMPointerType::get(type.getContext());
1168  });
1169  converter.addConversion([](smt::BitVectorType type) {
1170  return LLVM::LLVMPointerType::get(type.getContext());
1171  });
1172  converter.addConversion([](smt::ArrayType type) {
1173  return LLVM::LLVMPointerType::get(type.getContext());
1174  });
1175  converter.addConversion([](smt::IntType type) {
1176  return LLVM::LLVMPointerType::get(type.getContext());
1177  });
1178  converter.addConversion([](smt::SMTFuncType type) {
1179  return LLVM::LLVMPointerType::get(type.getContext());
1180  });
1181  converter.addConversion([](smt::SortType type) {
1182  return LLVM::LLVMPointerType::get(type.getContext());
1183  });
1184 }
1185 
1187  RewritePatternSet &patterns, TypeConverter &converter,
1188  SMTGlobalsHandler &globals, const LowerSMTToZ3LLVMOptions &options) {
1189 #define ADD_VARIADIC_PATTERN(OP, APINAME, MIN_NUM_ARGS) \
1190  patterns.add<VariadicSMTPattern<OP>>(/*NOLINT(bugprone-macro-parentheses)*/ \
1191  converter, patterns.getContext(), \
1192  globals, options, APINAME, \
1193  MIN_NUM_ARGS);
1194 
1195 #define ADD_ONE_TO_ONE_PATTERN(OP, APINAME, NUM_ARGS) \
1196  patterns.add<OneToOneSMTPattern<OP>>(/*NOLINT(bugprone-macro-parentheses)*/ \
1197  converter, patterns.getContext(), \
1198  globals, options, APINAME, NUM_ARGS);
1199 
1200  // Lower `smt.distinct` operations which allows a variadic number of operands
1201  // according to the `:pairwise` attribute. The Z3 API function supports a
1202  // variadic number of operands as well, i.e., a direct lowering is possible:
1203  // ```
1204  // Z3_ast Z3_API Z3_mk_distinct(Z3_context c, unsigned num_args, Z3_ast const
1205  // args[])
1206  // ```
1207  // The API function requires num_args > 1 which is guaranteed to be satisfied
1208  // because `smt.distinct` is verified to have > 1 operands.
1209  ADD_VARIADIC_PATTERN(DistinctOp, "Z3_mk_distinct", 2);
1210 
1211  // Lower `smt.and` operations which allows a variadic number of operands
1212  // according to the `:left-assoc` attribute. The Z3 API function supports a
1213  // variadic number of operands as well, i.e., a direct lowering is possible:
1214  // ```
1215  // Z3_ast Z3_API Z3_mk_and(Z3_context c, unsigned num_args, Z3_ast const
1216  // args[])
1217  // ```
1218  // The API function requires num_args > 1. This is not guaranteed by the
1219  // `smt.and` operation and thus the pattern will not apply when no operand is
1220  // present. The constant folder of the operation is assumed to fold this to
1221  // a constant 'true' (neutral element of AND).
1222  ADD_VARIADIC_PATTERN(AndOp, "Z3_mk_and", 2);
1223 
1224  // Lower `smt.or` operations which allows a variadic number of operands
1225  // according to the `:left-assoc` attribute. The Z3 API function supports a
1226  // variadic number of operands as well, i.e., a direct lowering is possible:
1227  // ```
1228  // Z3_ast Z3_API Z3_mk_or(Z3_context c, unsigned num_args, Z3_ast const
1229  // args[])
1230  // ```
1231  // The API function requires num_args > 1. This is not guaranteed by the
1232  // `smt.or` operation and thus the pattern will not apply when no operand is
1233  // present. The constant folder of the operation is assumed to fold this to
1234  // a constant 'false' (neutral element of OR).
1235  ADD_VARIADIC_PATTERN(OrOp, "Z3_mk_or", 2);
1236 
1237  // Lower `smt.not` operations to the following Z3 API function:
1238  // ```
1239  // Z3_ast Z3_API Z3_mk_not(Z3_context c, Z3_ast a);
1240  // ```
1241  ADD_ONE_TO_ONE_PATTERN(NotOp, "Z3_mk_not", 1);
1242 
1243  // Lower `smt.xor` operations which allows a variadic number of operands
1244  // according to the `:left-assoc` attribute. The Z3 API function, however,
1245  // only takes two operands.
1246  // ```
1247  // Z3_ast Z3_API Z3_mk_xor(Z3_context c, Z3_ast t1, Z3_ast t2);
1248  // ```
1249  // Therefore, we need to decompose the operation first to a sequence of XOR
1250  // operations matching the left associative behavior.
1251  patterns.add<LowerLeftAssocSMTPattern<XOrOp>>(
1252  converter, patterns.getContext(), globals, options);
1253  ADD_ONE_TO_ONE_PATTERN(XOrOp, "Z3_mk_xor", 2);
1254 
1255  // Lower `smt.implies` operations to the following Z3 API function:
1256  // ```
1257  // Z3_ast Z3_API Z3_mk_implies(Z3_context c, Z3_ast t1, Z3_ast t2);
1258  // ```
1259  ADD_ONE_TO_ONE_PATTERN(ImpliesOp, "Z3_mk_implies", 2);
1260 
1261  // All the bit-vector arithmetic and bitwise operations conveniently lower to
1262  // Z3 API function calls with essentially matching names and a one-to-one
1263  // correspondence of operands to call arguments.
1264  ADD_ONE_TO_ONE_PATTERN(BVNegOp, "Z3_mk_bvneg", 1);
1265  ADD_ONE_TO_ONE_PATTERN(BVAddOp, "Z3_mk_bvadd", 2);
1266  ADD_ONE_TO_ONE_PATTERN(BVMulOp, "Z3_mk_bvmul", 2);
1267  ADD_ONE_TO_ONE_PATTERN(BVURemOp, "Z3_mk_bvurem", 2);
1268  ADD_ONE_TO_ONE_PATTERN(BVSRemOp, "Z3_mk_bvsrem", 2);
1269  ADD_ONE_TO_ONE_PATTERN(BVSModOp, "Z3_mk_bvsmod", 2);
1270  ADD_ONE_TO_ONE_PATTERN(BVUDivOp, "Z3_mk_bvudiv", 2);
1271  ADD_ONE_TO_ONE_PATTERN(BVSDivOp, "Z3_mk_bvsdiv", 2);
1272  ADD_ONE_TO_ONE_PATTERN(BVShlOp, "Z3_mk_bvshl", 2);
1273  ADD_ONE_TO_ONE_PATTERN(BVLShrOp, "Z3_mk_bvlshr", 2);
1274  ADD_ONE_TO_ONE_PATTERN(BVAShrOp, "Z3_mk_bvashr", 2);
1275  ADD_ONE_TO_ONE_PATTERN(BVNotOp, "Z3_mk_bvnot", 1);
1276  ADD_ONE_TO_ONE_PATTERN(BVAndOp, "Z3_mk_bvand", 2);
1277  ADD_ONE_TO_ONE_PATTERN(BVOrOp, "Z3_mk_bvor", 2);
1278  ADD_ONE_TO_ONE_PATTERN(BVXOrOp, "Z3_mk_bvxor", 2);
1279 
1280  // The `smt.bv.concat` operation only supports two operands, just like the
1281  // Z3 API function.
1282  // ```
1283  // Z3_ast Z3_API Z3_mk_concat(Z3_context c, Z3_ast t1, Z3_ast t2);
1284  // ```
1285  ADD_ONE_TO_ONE_PATTERN(ConcatOp, "Z3_mk_concat", 2);
1286 
1287  // Lower the `smt.ite` operation to the following Z3 API function call, where
1288  // `t1` must have boolean sort.
1289  // ```
1290  // Z3_ast Z3_API Z3_mk_ite(Z3_context c, Z3_ast t1, Z3_ast t2, Z3_ast t3);
1291  // ```
1292  ADD_ONE_TO_ONE_PATTERN(IteOp, "Z3_mk_ite", 3);
1293 
1294  // Lower the `smt.array.select` operation to the following Z3 function call.
1295  // The operand declaration of the operation matches the order of arguments of
1296  // the API function.
1297  // ```
1298  // Z3_ast Z3_API Z3_mk_select(Z3_context c, Z3_ast a, Z3_ast i);
1299  // ```
1300  // Where `a` is the array expression and `i` is the index expression.
1301  ADD_ONE_TO_ONE_PATTERN(ArraySelectOp, "Z3_mk_select", 2);
1302 
1303  // Lower the `smt.array.store` operation to the following Z3 function call.
1304  // The operand declaration of the operation matches the order of arguments of
1305  // the API function.
1306  // ```
1307  // Z3_ast Z3_API Z3_mk_store(Z3_context c, Z3_ast a, Z3_ast i, Z3_ast v);
1308  // ```
1309  // Where `a` is the array expression, `i` is the index expression, and `v` is
1310  // the value expression to be stored.
1311  ADD_ONE_TO_ONE_PATTERN(ArrayStoreOp, "Z3_mk_store", 3);
1312 
1313  // Lower the `smt.int.add` operation to the following Z3 API function call.
1314  // ```
1315  // Z3_ast Z3_API Z3_mk_add(Z3_context c, unsigned num_args, Z3_ast const
1316  // args[]);
1317  // ```
1318  // The number of arguments must be greater than zero. Therefore, the pattern
1319  // will fail if applied to an operation with less than two operands.
1320  ADD_VARIADIC_PATTERN(IntAddOp, "Z3_mk_add", 2);
1321 
1322  // Lower the `smt.int.mul` operation to the following Z3 API function call.
1323  // ```
1324  // Z3_ast Z3_API Z3_mk_mul(Z3_context c, unsigned num_args, Z3_ast const
1325  // args[]);
1326  // ```
1327  // The number of arguments must be greater than zero. Therefore, the pattern
1328  // will fail if applied to an operation with less than two operands.
1329  ADD_VARIADIC_PATTERN(IntMulOp, "Z3_mk_mul", 2);
1330 
1331  // Lower the `smt.int.sub` operation to the following Z3 API function call.
1332  // ```
1333  // Z3_ast Z3_API Z3_mk_sub(Z3_context c, unsigned num_args, Z3_ast const
1334  // args[]);
1335  // ```
1336  // The number of arguments must be greater than zero. Since the `smt.int.sub`
1337  // operation always has exactly two operands, this trivially holds.
1338  ADD_VARIADIC_PATTERN(IntSubOp, "Z3_mk_sub", 2);
1339 
1340  // Lower the `smt.int.div` operation to the following Z3 API function call.
1341  // ```
1342  // Z3_ast Z3_API Z3_mk_div(Z3_context c, Z3_ast arg1, Z3_ast arg2);
1343  // ```
1344  ADD_ONE_TO_ONE_PATTERN(IntDivOp, "Z3_mk_div", 2);
1345 
1346  // Lower the `smt.int.mod` operation to the following Z3 API function call.
1347  // ```
1348  // Z3_ast Z3_API Z3_mk_mod(Z3_context c, Z3_ast arg1, Z3_ast arg2);
1349  // ```
1350  ADD_ONE_TO_ONE_PATTERN(IntModOp, "Z3_mk_mod", 2);
1351 
1352 #undef ADD_VARIADIC_PATTERN
1353 #undef ADD_ONE_TO_ONE_PATTERN
1354 
1355  // Lower `smt.eq` operations which allows a variadic number of operands
1356  // according to the `:chainable` attribute. The Z3 API function does not
1357  // support a variadic number of operands, but exactly two:
1358  // ```
1359  // Z3_ast Z3_API Z3_mk_eq(Z3_context c, Z3_ast l, Z3_ast r)
1360  // ```
1361  // As a result, we first apply a rewrite pattern that unfolds chainable
1362  // operators and then lower it one-to-one to the API function. In this case,
1363  // this means:
1364  // ```
1365  // eq(a,b,c,d) ->
1366  // and(eq(a,b), eq(b,c), eq(c,d)) ->
1367  // and(Z3_mk_eq(ctx, a, b), Z3_mk_eq(ctx, b, c), Z3_mk_eq(ctx, c, d))
1368  // ```
1369  // The patterns for `smt.and` will then do the remaining work.
1370  patterns.add<LowerChainableSMTPattern<EqOp>>(converter, patterns.getContext(),
1371  globals, options);
1372  patterns.add<OneToOneSMTPattern<EqOp>>(converter, patterns.getContext(),
1373  globals, options, "Z3_mk_eq", 2);
1374 
1375  // Other lowering patterns. Refer to their implementation directly for more
1376  // information.
1377  patterns.add<BVConstantOpLowering, DeclareFunOpLowering, AssertOpLowering,
1378  CheckOpLowering, SolverOpLowering, ApplyFuncOpLowering,
1379  YieldOpLowering, RepeatOpLowering, ExtractOpLowering,
1380  BoolConstantOpLowering, IntConstantOpLowering,
1381  ArrayBroadcastOpLowering, BVCmpOpLowering, IntCmpOpLowering,
1382  IntAbsOpLowering, QuantifierLowering<ForallOp>,
1383  QuantifierLowering<ExistsOp>>(converter, patterns.getContext(),
1384  globals, options);
1385 }
1386 
1387 void LowerSMTToZ3LLVMPass::runOnOperation() {
1388  LowerSMTToZ3LLVMOptions options;
1389  options.debug = debug;
1390 
1391  // Set up the type converter
1392  LLVMTypeConverter converter(&getContext());
1394 
1395  RewritePatternSet patterns(&getContext());
1396 
1397  // Populate the func to LLVM conversion patterns for two reasons:
1398  // * Typically functions are represented using `func.func` and including the
1399  // patterns to lower them here is more convenient for most lowering
1400  // pipelines (avoids running another pass).
1401  // * Already having `llvm.func` in the input or lowering `func.func` before
1402  // the SMT in the body leads to issues because the SCF conversion patterns
1403  // don't take the type converter into consideration and thus create blocks
1404  // with the old types for block arguments. However, the conversion happens
1405  // top-down and thus are assumed to be converted by the parent function op
1406  // which at that point would have already been lowered (and the blocks are
1407  // also not there when doing everything in one pass, i.e.,
1408  // `populateAnyFunctionOpInterfaceTypeConversionPattern` does not have any
1409  // effect as well). Are the SCF lowering patterns actually broken and should
1410  // take a type-converter?
1411  populateFuncToLLVMConversionPatterns(converter, patterns);
1412  arith::populateArithToLLVMConversionPatterns(converter, patterns);
1413 
1414  // Populate SCF to CF and CF to LLVM lowering patterns because we create
1415  // `scf.if` operations in the lowering patterns for convenience (given the
1416  // above issue we might want to lower to LLVM directly; or fix upstream?)
1417  populateSCFToControlFlowConversionPatterns(patterns);
1418  mlir::cf::populateControlFlowToLLVMConversionPatterns(converter, patterns);
1419 
1420  // Create the globals to store the context and solver and populate the SMT
1421  // lowering patterns.
1422  OpBuilder builder(&getContext());
1423  auto globals = SMTGlobalsHandler::create(builder, getOperation());
1424  populateSMTToZ3LLVMConversionPatterns(patterns, converter, globals, options);
1425 
1426  // Do a full conversion. This assumes that all other dialects have been
1427  // lowered before this pass already.
1428  LLVMConversionTarget target(getContext());
1429  target.addLegalOp<mlir::ModuleOp>();
1430  target.addLegalOp<scf::YieldOp>();
1431 
1432  if (failed(applyFullConversion(getOperation(), target, std::move(patterns))))
1433  return signalPassFailure();
1434 }
int32_t width
Definition: FIRRTL.cpp:36
#define ADD_VARIADIC_PATTERN(OP, APINAME, MIN_NUM_ARGS)
#define ADD_ONE_TO_ONE_PATTERN(OP, APINAME, NUM_ARGS)
RewritePatternSet pattern
A namespace that is used to store existing names and generate new names in some scope within the IR.
Definition: Namespace.h:30
void add(mlir::ModuleOp module)
Definition: Namespace.h:46
StringRef newName(const Twine &name)
Return a unique name, derived from the input name, and add the new name to the internal namespace.
Definition: Namespace.h:72
void addDefinitions(mlir::Operation *top)
Populate the symbol cache with all symbol-defining operations within the 'top' operation.
Definition: SymCache.cpp:23
Default symbol cache implementation; stores associations between names (StringAttr's) to mlir::Operation's.
Definition: SymCache.h:85
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
Definition: CalyxOps.cpp:54
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Definition: DebugAnalysis.h:21
void populateSMTToZ3LLVMTypeConverter(TypeConverter &converter)
Populate the given type converter with the SMT to LLVM type conversions.
void populateSMTToZ3LLVMConversionPatterns(RewritePatternSet &patterns, TypeConverter &converter, SMTGlobalsHandler &globals, const LowerSMTToZ3LLVMOptions &options)
Add the SMT to LLVM IR conversion patterns to 'patterns'.
Definition: debug.py:1
A symbol cache for LLVM globals and functions relevant to SMT lowering patterns.
Definition: SMTToZ3LLVM.h:25
static SMTGlobalsHandler create(OpBuilder &builder, ModuleOp module)
Creates the LLVM global operations to store the pointers to the solver and the context and returns a 'SMTGlobalsHandler' initialized with those new globals.
SMTGlobalsHandler(ModuleOp module, mlir::LLVM::GlobalOp solver, mlir::LLVM::GlobalOp ctx)
Initializes the caches and keeps track of the given globals to store the pointers to the SMT solver and the context.