Loading [MathJax]/extensions/tex2jax.js
CIRCT 21.0.0git
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
CombToSMT.cpp
Go to the documentation of this file.
1//===- CombToSMT.cpp ------------------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
#include "circt/Conversion/CombToSMT.h"
#include "circt/Conversion/HWToSMT.h"
#include "circt/Dialect/Comb/CombOps.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/SMT/IR/SMTOps.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
16
17namespace circt {
18#define GEN_PASS_DEF_CONVERTCOMBTOSMT
19#include "circt/Conversion/Passes.h.inc"
20} // namespace circt
21
22using namespace mlir;
23using namespace circt;
24using namespace comb;
25
26//===----------------------------------------------------------------------===//
27// Conversion patterns
28//===----------------------------------------------------------------------===//
29
30namespace {
31/// Lower a comb::ReplicateOp operation to smt::RepeatOp
32struct CombReplicateOpConversion : OpConversionPattern<ReplicateOp> {
34
35 LogicalResult
36 matchAndRewrite(ReplicateOp op, OpAdaptor adaptor,
37 ConversionPatternRewriter &rewriter) const override {
38 rewriter.replaceOpWithNewOp<smt::RepeatOp>(op, op.getMultiple(),
39 adaptor.getInput());
40 return success();
41 }
42};
43
44/// Lower a comb::ICmpOp operation to a smt::BVCmpOp, smt::EqOp or
45/// smt::DistinctOp
46struct IcmpOpConversion : OpConversionPattern<ICmpOp> {
48
49 LogicalResult
50 matchAndRewrite(ICmpOp op, OpAdaptor adaptor,
51 ConversionPatternRewriter &rewriter) const override {
52 if (adaptor.getPredicate() == ICmpPredicate::weq ||
53 adaptor.getPredicate() == ICmpPredicate::ceq ||
54 adaptor.getPredicate() == ICmpPredicate::wne ||
55 adaptor.getPredicate() == ICmpPredicate::cne)
56 return rewriter.notifyMatchFailure(op,
57 "comparison predicate not supported");
58
59 if (adaptor.getPredicate() == ICmpPredicate::eq) {
60 rewriter.replaceOpWithNewOp<smt::EqOp>(op, adaptor.getLhs(),
61 adaptor.getRhs());
62 return success();
63 }
64
65 if (adaptor.getPredicate() == ICmpPredicate::ne) {
66 rewriter.replaceOpWithNewOp<smt::DistinctOp>(op, adaptor.getLhs(),
67 adaptor.getRhs());
68 return success();
69 }
70
71 smt::BVCmpPredicate pred;
72 switch (adaptor.getPredicate()) {
73 case ICmpPredicate::sge:
74 pred = smt::BVCmpPredicate::sge;
75 break;
76 case ICmpPredicate::sgt:
77 pred = smt::BVCmpPredicate::sgt;
78 break;
79 case ICmpPredicate::sle:
80 pred = smt::BVCmpPredicate::sle;
81 break;
82 case ICmpPredicate::slt:
83 pred = smt::BVCmpPredicate::slt;
84 break;
85 case ICmpPredicate::uge:
86 pred = smt::BVCmpPredicate::uge;
87 break;
88 case ICmpPredicate::ugt:
89 pred = smt::BVCmpPredicate::ugt;
90 break;
91 case ICmpPredicate::ule:
92 pred = smt::BVCmpPredicate::ule;
93 break;
94 case ICmpPredicate::ult:
95 pred = smt::BVCmpPredicate::ult;
96 break;
97 default:
98 llvm_unreachable("all cases handled above");
99 }
100
101 rewriter.replaceOpWithNewOp<smt::BVCmpOp>(op, pred, adaptor.getLhs(),
102 adaptor.getRhs());
103 return success();
104 }
105};
106
107/// Lower a comb::ExtractOp operation to an smt::ExtractOp
108struct ExtractOpConversion : OpConversionPattern<ExtractOp> {
110
111 LogicalResult
112 matchAndRewrite(ExtractOp op, OpAdaptor adaptor,
113 ConversionPatternRewriter &rewriter) const override {
114
115 rewriter.replaceOpWithNewOp<smt::ExtractOp>(
116 op, typeConverter->convertType(op.getResult().getType()),
117 adaptor.getLowBitAttr(), adaptor.getInput());
118 return success();
119 }
120};
121
122/// Lower a comb::MuxOp operation to an smt::IteOp
123struct MuxOpConversion : OpConversionPattern<MuxOp> {
125
126 LogicalResult
127 matchAndRewrite(MuxOp op, OpAdaptor adaptor,
128 ConversionPatternRewriter &rewriter) const override {
129 Value condition = typeConverter->materializeTargetConversion(
130 rewriter, op.getLoc(), smt::BoolType::get(getContext()),
131 adaptor.getCond());
132 rewriter.replaceOpWithNewOp<smt::IteOp>(
133 op, condition, adaptor.getTrueValue(), adaptor.getFalseValue());
134 return success();
135 }
136};
137
138/// Lower a comb::SubOp operation to an smt::BVNegOp + smt::BVAddOp
139struct SubOpConversion : OpConversionPattern<SubOp> {
141
142 LogicalResult
143 matchAndRewrite(SubOp op, OpAdaptor adaptor,
144 ConversionPatternRewriter &rewriter) const override {
145 Value negRhs = rewriter.create<smt::BVNegOp>(op.getLoc(), adaptor.getRhs());
146 rewriter.replaceOpWithNewOp<smt::BVAddOp>(op, adaptor.getLhs(), negRhs);
147 return success();
148 }
149};
150
151/// Lower a comb::ParityOp operation to a chain of smt::Extract + XOr ops
152struct ParityOpConversion : OpConversionPattern<ParityOp> {
154
155 LogicalResult
156 matchAndRewrite(ParityOp op, OpAdaptor adaptor,
157 ConversionPatternRewriter &rewriter) const override {
158 Location loc = op.getLoc();
159 unsigned bitwidth =
160 cast<smt::BitVectorType>(adaptor.getInput().getType()).getWidth();
161
162 // Note: the SMT bitvector type does not support 0 bitwidth vectors and thus
163 // the type conversion should already fail.
164 Type oneBitTy = smt::BitVectorType::get(getContext(), 1);
165 Value runner =
166 rewriter.create<smt::ExtractOp>(loc, oneBitTy, 0, adaptor.getInput());
167 for (unsigned i = 1; i < bitwidth; ++i) {
168 Value ext =
169 rewriter.create<smt::ExtractOp>(loc, oneBitTy, i, adaptor.getInput());
170 runner = rewriter.create<smt::BVXOrOp>(loc, runner, ext);
171 }
172
173 rewriter.replaceOp(op, runner);
174 return success();
175 }
176};
177
178/// Lower the SourceOp to the TargetOp one-to-one.
179template <typename SourceOp, typename TargetOp>
180struct OneToOneOpConversion : OpConversionPattern<SourceOp> {
182 using OpAdaptor = typename SourceOp::Adaptor;
183
184 LogicalResult
185 matchAndRewrite(SourceOp op, OpAdaptor adaptor,
186 ConversionPatternRewriter &rewriter) const override {
187
188 rewriter.replaceOpWithNewOp<TargetOp>(
189 op,
191 op.getResult().getType()),
192 adaptor.getOperands());
193 return success();
194 }
195};
196
197/// Lower the SourceOp to the TargetOp special-casing if the second operand is
198/// zero to return a new symbolic value.
199template <typename SourceOp, typename TargetOp>
200struct DivisionOpConversion : OpConversionPattern<SourceOp> {
202 using OpAdaptor = typename SourceOp::Adaptor;
203
204 LogicalResult
205 matchAndRewrite(SourceOp op, OpAdaptor adaptor,
206 ConversionPatternRewriter &rewriter) const override {
207 Location loc = op.getLoc();
208 auto type = dyn_cast<smt::BitVectorType>(adaptor.getRhs().getType());
209 if (!type)
210 return failure();
211
212 auto resultType = OpConversionPattern<SourceOp>::typeConverter->convertType(
213 op.getResult().getType());
214 Value zero =
215 rewriter.create<smt::BVConstantOp>(loc, APInt(type.getWidth(), 0));
216 Value isZero = rewriter.create<smt::EqOp>(loc, adaptor.getRhs(), zero);
217 Value symbolicVal = rewriter.create<smt::DeclareFunOp>(loc, resultType);
218 Value division =
219 rewriter.create<TargetOp>(loc, resultType, adaptor.getOperands());
220 rewriter.replaceOpWithNewOp<smt::IteOp>(op, isZero, symbolicVal, division);
221 return success();
222 }
223};
224
225/// Converts an operation with a variadic number of operands to a chain of
226/// binary operations assuming left-associativity of the operation.
227template <typename SourceOp, typename TargetOp>
228struct VariadicToBinaryOpConversion : OpConversionPattern<SourceOp> {
230 using OpAdaptor = typename SourceOp::Adaptor;
231
232 LogicalResult
233 matchAndRewrite(SourceOp op, OpAdaptor adaptor,
234 ConversionPatternRewriter &rewriter) const override {
235
236 ValueRange operands = adaptor.getOperands();
237 if (operands.size() < 2)
238 return failure();
239
240 Value runner = operands[0];
241 for (Value operand : operands.drop_front())
242 runner = rewriter.create<TargetOp>(op.getLoc(), runner, operand);
243
244 rewriter.replaceOp(op, runner);
245 return success();
246 }
247};
248
249} // namespace
250
251//===----------------------------------------------------------------------===//
252// Convert Comb to SMT pass
253//===----------------------------------------------------------------------===//
254
255namespace {
256struct ConvertCombToSMTPass
257 : public circt::impl::ConvertCombToSMTBase<ConvertCombToSMTPass> {
258 void runOnOperation() override;
259};
260} // namespace
261
262void circt::populateCombToSMTConversionPatterns(TypeConverter &converter,
263 RewritePatternSet &patterns) {
264 patterns.add<CombReplicateOpConversion, IcmpOpConversion, ExtractOpConversion,
265 SubOpConversion, MuxOpConversion, ParityOpConversion,
266 OneToOneOpConversion<ShlOp, smt::BVShlOp>,
267 OneToOneOpConversion<ShrUOp, smt::BVLShrOp>,
268 OneToOneOpConversion<ShrSOp, smt::BVAShrOp>,
269 DivisionOpConversion<DivSOp, smt::BVSDivOp>,
270 DivisionOpConversion<DivUOp, smt::BVUDivOp>,
271 DivisionOpConversion<ModSOp, smt::BVSRemOp>,
272 DivisionOpConversion<ModUOp, smt::BVURemOp>,
273 VariadicToBinaryOpConversion<ConcatOp, smt::ConcatOp>,
274 VariadicToBinaryOpConversion<AddOp, smt::BVAddOp>,
275 VariadicToBinaryOpConversion<MulOp, smt::BVMulOp>,
276 VariadicToBinaryOpConversion<AndOp, smt::BVAndOp>,
277 VariadicToBinaryOpConversion<OrOp, smt::BVOrOp>,
278 VariadicToBinaryOpConversion<XorOp, smt::BVXOrOp>>(
279 converter, patterns.getContext());
280
281 // TODO: there are two unsupported operations in the comb dialect: 'parity'
282 // and 'truth_table'.
283}
284
285void ConvertCombToSMTPass::runOnOperation() {
286 ConversionTarget target(getContext());
287 target.addIllegalDialect<comb::CombDialect>();
288 target.addLegalDialect<smt::SMTDialect>();
289
290 RewritePatternSet patterns(&getContext());
291 TypeConverter converter;
293 // Also add HW patterns because some 'comb' canonicalizers produce constant
294 // operations, i.e., even if there is absolutely no HW operation present
295 // initially, we might have to convert one.
298
299 if (failed(mlir::applyPartialConversion(getOperation(), target,
300 std::move(patterns))))
301 return signalPassFailure();
302}
static Location getLoc(DefSlot slot)
Definition Mem2Reg.cpp:212
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
void populateHWToSMTConversionPatterns(TypeConverter &converter, RewritePatternSet &patterns)
Get the HW to SMT conversion patterns.
Definition HWToSMT.cpp:289
void populateCombToSMTConversionPatterns(TypeConverter &converter, RewritePatternSet &patterns)
Get the Comb to SMT conversion patterns.
void populateHWToSMTTypeConverter(TypeConverter &converter)
Get the HW to SMT type conversions.
Definition HWToSMT.cpp:184
Definition comb.py:1