CIRCT 20.0.0git
Loading...
Searching...
No Matches
CombToSMT.cpp
Go to the documentation of this file.
1//===- CombToSMT.cpp ------------------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
#include "circt/Conversion/CombToSMT.h"
#include "circt/Conversion/HWToSMT.h"
#include "circt/Dialect/Comb/CombOps.h"
#include "circt/Dialect/SMT/SMTOps.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
16
17namespace circt {
18#define GEN_PASS_DEF_CONVERTCOMBTOSMT
19#include "circt/Conversion/Passes.h.inc"
20} // namespace circt
21
22using namespace circt;
23using namespace comb;
24
25//===----------------------------------------------------------------------===//
26// Conversion patterns
27//===----------------------------------------------------------------------===//
28
29namespace {
30/// Lower a comb::ReplicateOp operation to smt::RepeatOp
31struct CombReplicateOpConversion : OpConversionPattern<ReplicateOp> {
33
34 LogicalResult
35 matchAndRewrite(ReplicateOp op, OpAdaptor adaptor,
36 ConversionPatternRewriter &rewriter) const override {
37 rewriter.replaceOpWithNewOp<smt::RepeatOp>(op, op.getMultiple(),
38 adaptor.getInput());
39 return success();
40 }
41};
42
43/// Lower a comb::ICmpOp operation to a smt::BVCmpOp, smt::EqOp or
44/// smt::DistinctOp
45struct IcmpOpConversion : OpConversionPattern<ICmpOp> {
47
48 LogicalResult
49 matchAndRewrite(ICmpOp op, OpAdaptor adaptor,
50 ConversionPatternRewriter &rewriter) const override {
51 if (adaptor.getPredicate() == ICmpPredicate::weq ||
52 adaptor.getPredicate() == ICmpPredicate::ceq ||
53 adaptor.getPredicate() == ICmpPredicate::wne ||
54 adaptor.getPredicate() == ICmpPredicate::cne)
55 return rewriter.notifyMatchFailure(op,
56 "comparison predicate not supported");
57
58 if (adaptor.getPredicate() == ICmpPredicate::eq) {
59 rewriter.replaceOpWithNewOp<smt::EqOp>(op, adaptor.getLhs(),
60 adaptor.getRhs());
61 return success();
62 }
63
64 if (adaptor.getPredicate() == ICmpPredicate::ne) {
65 rewriter.replaceOpWithNewOp<smt::DistinctOp>(op, adaptor.getLhs(),
66 adaptor.getRhs());
67 return success();
68 }
69
70 smt::BVCmpPredicate pred;
71 switch (adaptor.getPredicate()) {
72 case ICmpPredicate::sge:
73 pred = smt::BVCmpPredicate::sge;
74 break;
75 case ICmpPredicate::sgt:
76 pred = smt::BVCmpPredicate::sgt;
77 break;
78 case ICmpPredicate::sle:
79 pred = smt::BVCmpPredicate::sle;
80 break;
81 case ICmpPredicate::slt:
82 pred = smt::BVCmpPredicate::slt;
83 break;
84 case ICmpPredicate::uge:
85 pred = smt::BVCmpPredicate::uge;
86 break;
87 case ICmpPredicate::ugt:
88 pred = smt::BVCmpPredicate::ugt;
89 break;
90 case ICmpPredicate::ule:
91 pred = smt::BVCmpPredicate::ule;
92 break;
93 case ICmpPredicate::ult:
94 pred = smt::BVCmpPredicate::ult;
95 break;
96 default:
97 llvm_unreachable("all cases handled above");
98 }
99
100 rewriter.replaceOpWithNewOp<smt::BVCmpOp>(op, pred, adaptor.getLhs(),
101 adaptor.getRhs());
102 return success();
103 }
104};
105
106/// Lower a comb::ExtractOp operation to an smt::ExtractOp
107struct ExtractOpConversion : OpConversionPattern<ExtractOp> {
109
110 LogicalResult
111 matchAndRewrite(ExtractOp op, OpAdaptor adaptor,
112 ConversionPatternRewriter &rewriter) const override {
113
114 rewriter.replaceOpWithNewOp<smt::ExtractOp>(
115 op, typeConverter->convertType(op.getResult().getType()),
116 adaptor.getLowBitAttr(), adaptor.getInput());
117 return success();
118 }
119};
120
121/// Lower a comb::MuxOp operation to an smt::IteOp
122struct MuxOpConversion : OpConversionPattern<MuxOp> {
124
125 LogicalResult
126 matchAndRewrite(MuxOp op, OpAdaptor adaptor,
127 ConversionPatternRewriter &rewriter) const override {
128 Value condition = typeConverter->materializeTargetConversion(
129 rewriter, op.getLoc(), smt::BoolType::get(getContext()),
130 adaptor.getCond());
131 rewriter.replaceOpWithNewOp<smt::IteOp>(
132 op, condition, adaptor.getTrueValue(), adaptor.getFalseValue());
133 return success();
134 }
135};
136
137/// Lower a comb::SubOp operation to an smt::BVNegOp + smt::BVAddOp
138struct SubOpConversion : OpConversionPattern<SubOp> {
140
141 LogicalResult
142 matchAndRewrite(SubOp op, OpAdaptor adaptor,
143 ConversionPatternRewriter &rewriter) const override {
144 Value negRhs = rewriter.create<smt::BVNegOp>(op.getLoc(), adaptor.getRhs());
145 rewriter.replaceOpWithNewOp<smt::BVAddOp>(op, adaptor.getLhs(), negRhs);
146 return success();
147 }
148};
149
150/// Lower a comb::ParityOp operation to a chain of smt::Extract + XOr ops
151struct ParityOpConversion : OpConversionPattern<ParityOp> {
153
154 LogicalResult
155 matchAndRewrite(ParityOp op, OpAdaptor adaptor,
156 ConversionPatternRewriter &rewriter) const override {
157 Location loc = op.getLoc();
158 unsigned bitwidth =
159 cast<smt::BitVectorType>(adaptor.getInput().getType()).getWidth();
160
161 // Note: the SMT bitvector type does not support 0 bitwidth vectors and thus
162 // the type conversion should already fail.
163 Type oneBitTy = smt::BitVectorType::get(getContext(), 1);
164 Value runner =
165 rewriter.create<smt::ExtractOp>(loc, oneBitTy, 0, adaptor.getInput());
166 for (unsigned i = 1; i < bitwidth; ++i) {
167 Value ext =
168 rewriter.create<smt::ExtractOp>(loc, oneBitTy, i, adaptor.getInput());
169 runner = rewriter.create<smt::BVXOrOp>(loc, runner, ext);
170 }
171
172 rewriter.replaceOp(op, runner);
173 return success();
174 }
175};
176
177/// Lower the SourceOp to the TargetOp one-to-one.
178template <typename SourceOp, typename TargetOp>
179struct OneToOneOpConversion : OpConversionPattern<SourceOp> {
181 using OpAdaptor = typename SourceOp::Adaptor;
182
183 LogicalResult
184 matchAndRewrite(SourceOp op, OpAdaptor adaptor,
185 ConversionPatternRewriter &rewriter) const override {
186
187 rewriter.replaceOpWithNewOp<TargetOp>(
188 op,
190 op.getResult().getType()),
191 adaptor.getOperands());
192 return success();
193 }
194};
195
196/// Lower the SourceOp to the TargetOp special-casing if the second operand is
197/// zero to return a new symbolic value.
198template <typename SourceOp, typename TargetOp>
199struct DivisionOpConversion : OpConversionPattern<SourceOp> {
201 using OpAdaptor = typename SourceOp::Adaptor;
202
203 LogicalResult
204 matchAndRewrite(SourceOp op, OpAdaptor adaptor,
205 ConversionPatternRewriter &rewriter) const override {
206 Location loc = op.getLoc();
207 auto type = dyn_cast<smt::BitVectorType>(adaptor.getRhs().getType());
208 if (!type)
209 return failure();
210
211 auto resultType = OpConversionPattern<SourceOp>::typeConverter->convertType(
212 op.getResult().getType());
213 Value zero =
214 rewriter.create<smt::BVConstantOp>(loc, APInt(type.getWidth(), 0));
215 Value isZero = rewriter.create<smt::EqOp>(loc, adaptor.getRhs(), zero);
216 Value symbolicVal = rewriter.create<smt::DeclareFunOp>(loc, resultType);
217 Value division =
218 rewriter.create<TargetOp>(loc, resultType, adaptor.getOperands());
219 rewriter.replaceOpWithNewOp<smt::IteOp>(op, isZero, symbolicVal, division);
220 return success();
221 }
222};
223
224/// Converts an operation with a variadic number of operands to a chain of
225/// binary operations assuming left-associativity of the operation.
226template <typename SourceOp, typename TargetOp>
227struct VariadicToBinaryOpConversion : OpConversionPattern<SourceOp> {
229 using OpAdaptor = typename SourceOp::Adaptor;
230
231 LogicalResult
232 matchAndRewrite(SourceOp op, OpAdaptor adaptor,
233 ConversionPatternRewriter &rewriter) const override {
234
235 ValueRange operands = adaptor.getOperands();
236 if (operands.size() < 2)
237 return failure();
238
239 Value runner = operands[0];
240 for (Value operand : operands.drop_front())
241 runner = rewriter.create<TargetOp>(op.getLoc(), runner, operand);
242
243 rewriter.replaceOp(op, runner);
244 return success();
245 }
246};
247
248} // namespace
249
250//===----------------------------------------------------------------------===//
251// Convert Comb to SMT pass
252//===----------------------------------------------------------------------===//
253
254namespace {
255struct ConvertCombToSMTPass
256 : public impl::ConvertCombToSMTBase<ConvertCombToSMTPass> {
257 void runOnOperation() override;
258};
259} // namespace
260
261void circt::populateCombToSMTConversionPatterns(TypeConverter &converter,
262 RewritePatternSet &patterns) {
263 patterns.add<CombReplicateOpConversion, IcmpOpConversion, ExtractOpConversion,
264 SubOpConversion, MuxOpConversion, ParityOpConversion,
265 OneToOneOpConversion<ShlOp, smt::BVShlOp>,
266 OneToOneOpConversion<ShrUOp, smt::BVLShrOp>,
267 OneToOneOpConversion<ShrSOp, smt::BVAShrOp>,
268 DivisionOpConversion<DivSOp, smt::BVSDivOp>,
269 DivisionOpConversion<DivUOp, smt::BVUDivOp>,
270 DivisionOpConversion<ModSOp, smt::BVSRemOp>,
271 DivisionOpConversion<ModUOp, smt::BVURemOp>,
272 VariadicToBinaryOpConversion<ConcatOp, smt::ConcatOp>,
273 VariadicToBinaryOpConversion<AddOp, smt::BVAddOp>,
274 VariadicToBinaryOpConversion<MulOp, smt::BVMulOp>,
275 VariadicToBinaryOpConversion<AndOp, smt::BVAndOp>,
276 VariadicToBinaryOpConversion<OrOp, smt::BVOrOp>,
277 VariadicToBinaryOpConversion<XorOp, smt::BVXOrOp>>(
278 converter, patterns.getContext());
279
280 // TODO: there are two unsupported operations in the comb dialect: 'parity'
281 // and 'truth_table'.
282}
283
284void ConvertCombToSMTPass::runOnOperation() {
285 ConversionTarget target(getContext());
286 target.addIllegalDialect<comb::CombDialect>();
287 target.addLegalDialect<smt::SMTDialect>();
288
289 RewritePatternSet patterns(&getContext());
290 TypeConverter converter;
292 // Also add HW patterns because some 'comb' canonicalizers produce constant
293 // operations, i.e., even if there is absolutely no HW operation present
294 // initially, we might have to convert one.
297
298 if (failed(mlir::applyPartialConversion(getOperation(), target,
299 std::move(patterns))))
300 return signalPassFailure();
301}
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
void populateHWToSMTConversionPatterns(TypeConverter &converter, RewritePatternSet &patterns)
Get the HW to SMT conversion patterns.
Definition HWToSMT.cpp:284
void populateCombToSMTConversionPatterns(TypeConverter &converter, RewritePatternSet &patterns)
Get the Comb to SMT conversion patterns.
void populateHWToSMTTypeConverter(TypeConverter &converter)
Get the HW to SMT type conversions.
Definition HWToSMT.cpp:181
Definition comb.py:1