CIRCT 23.0.0git
Loading...
Searching...
No Matches
CombToSMT.cpp
Go to the documentation of this file.
1//===- CombToSMT.cpp ------------------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
13#include "mlir/Dialect/Func/IR/FuncOps.h"
14#include "mlir/Dialect/SMT/IR/SMTOps.h"
15#include "mlir/Pass/Pass.h"
16#include "mlir/Transforms/DialectConversion.h"
17
18namespace circt {
19#define GEN_PASS_DEF_CONVERTCOMBTOSMT
20#include "circt/Conversion/Passes.h.inc"
21} // namespace circt
22
23using namespace mlir;
24using namespace circt;
25using namespace comb;
26
27//===----------------------------------------------------------------------===//
28// Conversion patterns
29//===----------------------------------------------------------------------===//
30
31namespace {
32/// Lower a comb::ReplicateOp operation to smt::RepeatOp
33struct CombReplicateOpConversion : OpConversionPattern<ReplicateOp> {
35
36 LogicalResult
37 matchAndRewrite(ReplicateOp op, OpAdaptor adaptor,
38 ConversionPatternRewriter &rewriter) const override {
39 rewriter.replaceOpWithNewOp<smt::RepeatOp>(op, op.getMultiple(),
40 adaptor.getInput());
41 return success();
42 }
43};
44
45/// Lower a comb::ICmpOp operation to a smt::BVCmpOp, smt::EqOp or
46/// smt::DistinctOp
47struct IcmpOpConversion : OpConversionPattern<ICmpOp> {
49
50 LogicalResult
51 matchAndRewrite(ICmpOp op, OpAdaptor adaptor,
52 ConversionPatternRewriter &rewriter) const override {
53 if (adaptor.getPredicate() == ICmpPredicate::weq ||
54 adaptor.getPredicate() == ICmpPredicate::ceq ||
55 adaptor.getPredicate() == ICmpPredicate::wne ||
56 adaptor.getPredicate() == ICmpPredicate::cne)
57 return rewriter.notifyMatchFailure(op,
58 "comparison predicate not supported");
59
60 Value result;
61 if (adaptor.getPredicate() == ICmpPredicate::eq) {
62 result = smt::EqOp::create(rewriter, op.getLoc(), adaptor.getLhs(),
63 adaptor.getRhs());
64 } else if (adaptor.getPredicate() == ICmpPredicate::ne) {
65 result = smt::DistinctOp::create(rewriter, op.getLoc(), adaptor.getLhs(),
66 adaptor.getRhs());
67 } else {
68 smt::BVCmpPredicate pred;
69 switch (adaptor.getPredicate()) {
70 case ICmpPredicate::sge:
71 pred = smt::BVCmpPredicate::sge;
72 break;
73 case ICmpPredicate::sgt:
74 pred = smt::BVCmpPredicate::sgt;
75 break;
76 case ICmpPredicate::sle:
77 pred = smt::BVCmpPredicate::sle;
78 break;
79 case ICmpPredicate::slt:
80 pred = smt::BVCmpPredicate::slt;
81 break;
82 case ICmpPredicate::uge:
83 pred = smt::BVCmpPredicate::uge;
84 break;
85 case ICmpPredicate::ugt:
86 pred = smt::BVCmpPredicate::ugt;
87 break;
88 case ICmpPredicate::ule:
89 pred = smt::BVCmpPredicate::ule;
90 break;
91 case ICmpPredicate::ult:
92 pred = smt::BVCmpPredicate::ult;
93 break;
94 default:
95 llvm_unreachable("all cases handled above");
96 }
97
98 result = smt::BVCmpOp::create(rewriter, op.getLoc(), pred,
99 adaptor.getLhs(), adaptor.getRhs());
100 }
101
102 Value convVal = typeConverter->materializeTargetConversion(
103 rewriter, op.getLoc(), typeConverter->convertType(op.getType()),
104 result);
105 if (!convVal)
106 return failure();
107
108 rewriter.replaceOp(op, convVal);
109 return success();
110 }
111};
112
113/// Lower a comb::ExtractOp operation to an smt::ExtractOp
114struct ExtractOpConversion : OpConversionPattern<ExtractOp> {
116
117 LogicalResult
118 matchAndRewrite(ExtractOp op, OpAdaptor adaptor,
119 ConversionPatternRewriter &rewriter) const override {
120
121 rewriter.replaceOpWithNewOp<smt::ExtractOp>(
122 op, typeConverter->convertType(op.getResult().getType()),
123 adaptor.getLowBitAttr(), adaptor.getInput());
124 return success();
125 }
126};
127
128/// Lower a comb::MuxOp operation to an smt::IteOp
129struct MuxOpConversion : OpConversionPattern<MuxOp> {
131
132 LogicalResult
133 matchAndRewrite(MuxOp op, OpAdaptor adaptor,
134 ConversionPatternRewriter &rewriter) const override {
135 Value condition = typeConverter->materializeTargetConversion(
136 rewriter, op.getLoc(), smt::BoolType::get(getContext()),
137 adaptor.getCond());
138 rewriter.replaceOpWithNewOp<smt::IteOp>(
139 op, condition, adaptor.getTrueValue(), adaptor.getFalseValue());
140 return success();
141 }
142};
143
144/// Lower a comb::SubOp operation to an smt::BVNegOp + smt::BVAddOp
145struct SubOpConversion : OpConversionPattern<SubOp> {
147
148 LogicalResult
149 matchAndRewrite(SubOp op, OpAdaptor adaptor,
150 ConversionPatternRewriter &rewriter) const override {
151 Value negRhs =
152 smt::BVNegOp::create(rewriter, op.getLoc(), adaptor.getRhs());
153 rewriter.replaceOpWithNewOp<smt::BVAddOp>(op, adaptor.getLhs(), negRhs);
154 return success();
155 }
156};
157
158/// Lower a comb::ParityOp operation to a chain of smt::Extract + XOr ops
159struct ParityOpConversion : OpConversionPattern<ParityOp> {
161
162 LogicalResult
163 matchAndRewrite(ParityOp op, OpAdaptor adaptor,
164 ConversionPatternRewriter &rewriter) const override {
165 Location loc = op.getLoc();
166 unsigned bitwidth =
167 cast<smt::BitVectorType>(adaptor.getInput().getType()).getWidth();
168
169 // Note: the SMT bitvector type does not support 0 bitwidth vectors and thus
170 // the type conversion should already fail.
171 Type oneBitTy = smt::BitVectorType::get(getContext(), 1);
172 Value runner =
173 smt::ExtractOp::create(rewriter, loc, oneBitTy, 0, adaptor.getInput());
174 for (unsigned i = 1; i < bitwidth; ++i) {
175 Value ext = smt::ExtractOp::create(rewriter, loc, oneBitTy, i,
176 adaptor.getInput());
177 runner = smt::BVXOrOp::create(rewriter, loc, runner, ext);
178 }
179
180 rewriter.replaceOp(op, runner);
181 return success();
182 }
183};
184
185/// Lower a comb::ReverseOp operation to a chain of smt::Extract + Concat ops
186struct ReverseOpConversion : OpConversionPattern<ReverseOp> {
188
189 LogicalResult
190 matchAndRewrite(ReverseOp op, OpAdaptor adaptor,
191 ConversionPatternRewriter &rewriter) const override {
192 Location loc = op.getLoc();
193 unsigned bitwidth =
194 cast<smt::BitVectorType>(adaptor.getInput().getType()).getWidth();
195
196 Type oneBitTy = smt::BitVectorType::get(getContext(), 1);
197 // Extract bit 0 (LSB), which becomes the MSB of the reversed result.
198 Value result =
199 smt::ExtractOp::create(rewriter, loc, oneBitTy, 0, adaptor.getInput());
200 // Concatenate remaining bits in ascending order (LSB-to-MSB of input
201 // becomes MSB-to-LSB of output).
202 for (unsigned i = 1; i < bitwidth; ++i) {
203 Value ext = smt::ExtractOp::create(rewriter, loc, oneBitTy, i,
204 adaptor.getInput());
205 result = smt::ConcatOp::create(rewriter, loc, result, ext);
206 }
207
208 rewriter.replaceOp(op, result);
209 return success();
210 }
211};
212
213/// Lower the SourceOp to the TargetOp one-to-one.
214template <typename SourceOp, typename TargetOp>
215struct OneToOneOpConversion : OpConversionPattern<SourceOp> {
217 using OpAdaptor = typename SourceOp::Adaptor;
218
219 LogicalResult
220 matchAndRewrite(SourceOp op, OpAdaptor adaptor,
221 ConversionPatternRewriter &rewriter) const override {
222
223 rewriter.replaceOpWithNewOp<TargetOp>(
224 op,
226 op.getResult().getType()),
227 adaptor.getOperands());
228 return success();
229 }
230};
231
232/// Lower the SourceOp to the TargetOp special-casing if the second operand is
233/// zero to return a new symbolic value.
234template <typename SourceOp, typename TargetOp>
235struct DivisionOpConversion : OpConversionPattern<SourceOp> {
237 using OpAdaptor = typename SourceOp::Adaptor;
238
239 LogicalResult
240 matchAndRewrite(SourceOp op, OpAdaptor adaptor,
241 ConversionPatternRewriter &rewriter) const override {
242 Location loc = op.getLoc();
243 auto type = dyn_cast<smt::BitVectorType>(adaptor.getRhs().getType());
244 if (!type)
245 return failure();
246
247 auto resultType = OpConversionPattern<SourceOp>::typeConverter->convertType(
248 op.getResult().getType());
249 Value zero =
250 smt::BVConstantOp::create(rewriter, loc, APInt(type.getWidth(), 0));
251 Value isZero = smt::EqOp::create(rewriter, loc, adaptor.getRhs(), zero);
252 Value symbolicVal = smt::DeclareFunOp::create(rewriter, loc, resultType);
253 Value division =
254 TargetOp::create(rewriter, loc, resultType, adaptor.getOperands());
255 rewriter.replaceOpWithNewOp<smt::IteOp>(op, isZero, symbolicVal, division);
256 return success();
257 }
258};
259
260/// Converts an operation with a variadic number of operands to a chain of
261/// binary operations assuming left-associativity of the operation.
262template <typename SourceOp, typename TargetOp>
263struct VariadicToBinaryOpConversion : OpConversionPattern<SourceOp> {
265 using OpAdaptor = typename SourceOp::Adaptor;
266
267 LogicalResult
268 matchAndRewrite(SourceOp op, OpAdaptor adaptor,
269 ConversionPatternRewriter &rewriter) const override {
270
271 ValueRange operands = adaptor.getOperands();
272 if (operands.size() < 2)
273 return failure();
274
275 Value runner = operands[0];
276 for (Value operand : operands.drop_front())
277 runner = TargetOp::create(rewriter, op.getLoc(), runner, operand);
278
279 rewriter.replaceOp(op, runner);
280 return success();
281 }
282};
283
284} // namespace
285
286//===----------------------------------------------------------------------===//
287// Convert Comb to SMT pass
288//===----------------------------------------------------------------------===//
289
290namespace {
291struct ConvertCombToSMTPass
292 : public circt::impl::ConvertCombToSMTBase<ConvertCombToSMTPass> {
293 void runOnOperation() override;
294};
295} // namespace
296
297void circt::populateCombToSMTConversionPatterns(TypeConverter &converter,
298 RewritePatternSet &patterns) {
299 patterns.add<CombReplicateOpConversion, IcmpOpConversion, ExtractOpConversion,
300 SubOpConversion, MuxOpConversion, ParityOpConversion,
301 ReverseOpConversion, OneToOneOpConversion<ShlOp, smt::BVShlOp>,
302 OneToOneOpConversion<ShrUOp, smt::BVLShrOp>,
303 OneToOneOpConversion<ShrSOp, smt::BVAShrOp>,
304 DivisionOpConversion<DivSOp, smt::BVSDivOp>,
305 DivisionOpConversion<DivUOp, smt::BVUDivOp>,
306 DivisionOpConversion<ModSOp, smt::BVSRemOp>,
307 DivisionOpConversion<ModUOp, smt::BVURemOp>,
308 VariadicToBinaryOpConversion<ConcatOp, smt::ConcatOp>,
309 VariadicToBinaryOpConversion<AddOp, smt::BVAddOp>,
310 VariadicToBinaryOpConversion<MulOp, smt::BVMulOp>,
311 VariadicToBinaryOpConversion<AndOp, smt::BVAndOp>,
312 VariadicToBinaryOpConversion<OrOp, smt::BVOrOp>,
313 VariadicToBinaryOpConversion<XorOp, smt::BVXOrOp>>(
314 converter, patterns.getContext());
315
316 // TODO: there is one unsupported operation in the comb dialect:
317 // 'truth_table'.
318}
319
320void ConvertCombToSMTPass::runOnOperation() {
321 ConversionTarget target(getContext());
322 target.addIllegalDialect<hw::HWDialect>();
323 target.addIllegalOp<seq::FromClockOp>();
324 target.addIllegalOp<seq::ToClockOp>();
325 target.addIllegalDialect<comb::CombDialect>();
326 target.addLegalDialect<smt::SMTDialect>();
327 target.addLegalDialect<mlir::func::FuncDialect>();
328
329 RewritePatternSet patterns(&getContext());
330 TypeConverter converter;
332 // Also add HW patterns because some 'comb' canonicalizers produce constant
333 // operations, i.e., even if there is absolutely no HW operation present
334 // initially, we might have to convert one.
337
338 if (failed(mlir::applyPartialConversion(getOperation(), target,
339 std::move(patterns))))
340 return signalPassFailure();
341}
static Location getLoc(DefSlot slot)
Definition Mem2Reg.cpp:216
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
void populateHWToSMTConversionPatterns(TypeConverter &converter, RewritePatternSet &patterns, bool forSMTLIBExport)
Get the HW to SMT conversion patterns.
Definition HWToSMT.cpp:370
void populateCombToSMTConversionPatterns(TypeConverter &converter, RewritePatternSet &patterns)
Get the Comb to SMT conversion patterns.
void populateHWToSMTTypeConverter(TypeConverter &converter)
Get the HW to SMT type conversions.
Definition HWToSMT.cpp:265
Definition comb.py:1