Loading [MathJax]/jax/output/HTML-CSS/config.js
CIRCT 21.0.0git
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
CombToSMT.cpp
Go to the documentation of this file.
1//===- CombToSMT.cpp ------------------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
13#include "mlir/Dialect/Func/IR/FuncOps.h"
14#include "mlir/Dialect/SMT/IR/SMTOps.h"
15#include "mlir/Pass/Pass.h"
16#include "mlir/Transforms/DialectConversion.h"
17
18namespace circt {
19#define GEN_PASS_DEF_CONVERTCOMBTOSMT
20#include "circt/Conversion/Passes.h.inc"
21} // namespace circt
22
23using namespace mlir;
24using namespace circt;
25using namespace comb;
26
27//===----------------------------------------------------------------------===//
28// Conversion patterns
29//===----------------------------------------------------------------------===//
30
31namespace {
32/// Lower a comb::ReplicateOp operation to smt::RepeatOp
33struct CombReplicateOpConversion : OpConversionPattern<ReplicateOp> {
35
36 LogicalResult
37 matchAndRewrite(ReplicateOp op, OpAdaptor adaptor,
38 ConversionPatternRewriter &rewriter) const override {
39 rewriter.replaceOpWithNewOp<smt::RepeatOp>(op, op.getMultiple(),
40 adaptor.getInput());
41 return success();
42 }
43};
44
45/// Lower a comb::ICmpOp operation to a smt::BVCmpOp, smt::EqOp or
46/// smt::DistinctOp
47struct IcmpOpConversion : OpConversionPattern<ICmpOp> {
49
50 LogicalResult
51 matchAndRewrite(ICmpOp op, OpAdaptor adaptor,
52 ConversionPatternRewriter &rewriter) const override {
53 if (adaptor.getPredicate() == ICmpPredicate::weq ||
54 adaptor.getPredicate() == ICmpPredicate::ceq ||
55 adaptor.getPredicate() == ICmpPredicate::wne ||
56 adaptor.getPredicate() == ICmpPredicate::cne)
57 return rewriter.notifyMatchFailure(op,
58 "comparison predicate not supported");
59
60 if (adaptor.getPredicate() == ICmpPredicate::eq) {
61 rewriter.replaceOpWithNewOp<smt::EqOp>(op, adaptor.getLhs(),
62 adaptor.getRhs());
63 return success();
64 }
65
66 if (adaptor.getPredicate() == ICmpPredicate::ne) {
67 rewriter.replaceOpWithNewOp<smt::DistinctOp>(op, adaptor.getLhs(),
68 adaptor.getRhs());
69 return success();
70 }
71
72 smt::BVCmpPredicate pred;
73 switch (adaptor.getPredicate()) {
74 case ICmpPredicate::sge:
75 pred = smt::BVCmpPredicate::sge;
76 break;
77 case ICmpPredicate::sgt:
78 pred = smt::BVCmpPredicate::sgt;
79 break;
80 case ICmpPredicate::sle:
81 pred = smt::BVCmpPredicate::sle;
82 break;
83 case ICmpPredicate::slt:
84 pred = smt::BVCmpPredicate::slt;
85 break;
86 case ICmpPredicate::uge:
87 pred = smt::BVCmpPredicate::uge;
88 break;
89 case ICmpPredicate::ugt:
90 pred = smt::BVCmpPredicate::ugt;
91 break;
92 case ICmpPredicate::ule:
93 pred = smt::BVCmpPredicate::ule;
94 break;
95 case ICmpPredicate::ult:
96 pred = smt::BVCmpPredicate::ult;
97 break;
98 default:
99 llvm_unreachable("all cases handled above");
100 }
101
102 rewriter.replaceOpWithNewOp<smt::BVCmpOp>(op, pred, adaptor.getLhs(),
103 adaptor.getRhs());
104 return success();
105 }
106};
107
108/// Lower a comb::ExtractOp operation to an smt::ExtractOp
109struct ExtractOpConversion : OpConversionPattern<ExtractOp> {
111
112 LogicalResult
113 matchAndRewrite(ExtractOp op, OpAdaptor adaptor,
114 ConversionPatternRewriter &rewriter) const override {
115
116 rewriter.replaceOpWithNewOp<smt::ExtractOp>(
117 op, typeConverter->convertType(op.getResult().getType()),
118 adaptor.getLowBitAttr(), adaptor.getInput());
119 return success();
120 }
121};
122
123/// Lower a comb::MuxOp operation to an smt::IteOp
124struct MuxOpConversion : OpConversionPattern<MuxOp> {
126
127 LogicalResult
128 matchAndRewrite(MuxOp op, OpAdaptor adaptor,
129 ConversionPatternRewriter &rewriter) const override {
130 Value condition = typeConverter->materializeTargetConversion(
131 rewriter, op.getLoc(), smt::BoolType::get(getContext()),
132 adaptor.getCond());
133 rewriter.replaceOpWithNewOp<smt::IteOp>(
134 op, condition, adaptor.getTrueValue(), adaptor.getFalseValue());
135 return success();
136 }
137};
138
139/// Lower a comb::SubOp operation to an smt::BVNegOp + smt::BVAddOp
140struct SubOpConversion : OpConversionPattern<SubOp> {
142
143 LogicalResult
144 matchAndRewrite(SubOp op, OpAdaptor adaptor,
145 ConversionPatternRewriter &rewriter) const override {
146 Value negRhs = rewriter.create<smt::BVNegOp>(op.getLoc(), adaptor.getRhs());
147 rewriter.replaceOpWithNewOp<smt::BVAddOp>(op, adaptor.getLhs(), negRhs);
148 return success();
149 }
150};
151
152/// Lower a comb::ParityOp operation to a chain of smt::Extract + XOr ops
153struct ParityOpConversion : OpConversionPattern<ParityOp> {
155
156 LogicalResult
157 matchAndRewrite(ParityOp op, OpAdaptor adaptor,
158 ConversionPatternRewriter &rewriter) const override {
159 Location loc = op.getLoc();
160 unsigned bitwidth =
161 cast<smt::BitVectorType>(adaptor.getInput().getType()).getWidth();
162
163 // Note: the SMT bitvector type does not support 0 bitwidth vectors and thus
164 // the type conversion should already fail.
165 Type oneBitTy = smt::BitVectorType::get(getContext(), 1);
166 Value runner =
167 rewriter.create<smt::ExtractOp>(loc, oneBitTy, 0, adaptor.getInput());
168 for (unsigned i = 1; i < bitwidth; ++i) {
169 Value ext =
170 rewriter.create<smt::ExtractOp>(loc, oneBitTy, i, adaptor.getInput());
171 runner = rewriter.create<smt::BVXOrOp>(loc, runner, ext);
172 }
173
174 rewriter.replaceOp(op, runner);
175 return success();
176 }
177};
178
179/// Lower the SourceOp to the TargetOp one-to-one.
180template <typename SourceOp, typename TargetOp>
181struct OneToOneOpConversion : OpConversionPattern<SourceOp> {
183 using OpAdaptor = typename SourceOp::Adaptor;
184
185 LogicalResult
186 matchAndRewrite(SourceOp op, OpAdaptor adaptor,
187 ConversionPatternRewriter &rewriter) const override {
188
189 rewriter.replaceOpWithNewOp<TargetOp>(
190 op,
192 op.getResult().getType()),
193 adaptor.getOperands());
194 return success();
195 }
196};
197
198/// Lower the SourceOp to the TargetOp special-casing if the second operand is
199/// zero to return a new symbolic value.
200template <typename SourceOp, typename TargetOp>
201struct DivisionOpConversion : OpConversionPattern<SourceOp> {
203 using OpAdaptor = typename SourceOp::Adaptor;
204
205 LogicalResult
206 matchAndRewrite(SourceOp op, OpAdaptor adaptor,
207 ConversionPatternRewriter &rewriter) const override {
208 Location loc = op.getLoc();
209 auto type = dyn_cast<smt::BitVectorType>(adaptor.getRhs().getType());
210 if (!type)
211 return failure();
212
213 auto resultType = OpConversionPattern<SourceOp>::typeConverter->convertType(
214 op.getResult().getType());
215 Value zero =
216 rewriter.create<smt::BVConstantOp>(loc, APInt(type.getWidth(), 0));
217 Value isZero = rewriter.create<smt::EqOp>(loc, adaptor.getRhs(), zero);
218 Value symbolicVal = rewriter.create<smt::DeclareFunOp>(loc, resultType);
219 Value division =
220 rewriter.create<TargetOp>(loc, resultType, adaptor.getOperands());
221 rewriter.replaceOpWithNewOp<smt::IteOp>(op, isZero, symbolicVal, division);
222 return success();
223 }
224};
225
226/// Converts an operation with a variadic number of operands to a chain of
227/// binary operations assuming left-associativity of the operation.
228template <typename SourceOp, typename TargetOp>
229struct VariadicToBinaryOpConversion : OpConversionPattern<SourceOp> {
231 using OpAdaptor = typename SourceOp::Adaptor;
232
233 LogicalResult
234 matchAndRewrite(SourceOp op, OpAdaptor adaptor,
235 ConversionPatternRewriter &rewriter) const override {
236
237 ValueRange operands = adaptor.getOperands();
238 if (operands.size() < 2)
239 return failure();
240
241 Value runner = operands[0];
242 for (Value operand : operands.drop_front())
243 runner = rewriter.create<TargetOp>(op.getLoc(), runner, operand);
244
245 rewriter.replaceOp(op, runner);
246 return success();
247 }
248};
249
250} // namespace
251
252//===----------------------------------------------------------------------===//
253// Convert Comb to SMT pass
254//===----------------------------------------------------------------------===//
255
256namespace {
257struct ConvertCombToSMTPass
258 : public circt::impl::ConvertCombToSMTBase<ConvertCombToSMTPass> {
259 void runOnOperation() override;
260};
261} // namespace
262
263void circt::populateCombToSMTConversionPatterns(TypeConverter &converter,
264 RewritePatternSet &patterns) {
265 patterns.add<CombReplicateOpConversion, IcmpOpConversion, ExtractOpConversion,
266 SubOpConversion, MuxOpConversion, ParityOpConversion,
267 OneToOneOpConversion<ShlOp, smt::BVShlOp>,
268 OneToOneOpConversion<ShrUOp, smt::BVLShrOp>,
269 OneToOneOpConversion<ShrSOp, smt::BVAShrOp>,
270 DivisionOpConversion<DivSOp, smt::BVSDivOp>,
271 DivisionOpConversion<DivUOp, smt::BVUDivOp>,
272 DivisionOpConversion<ModSOp, smt::BVSRemOp>,
273 DivisionOpConversion<ModUOp, smt::BVURemOp>,
274 VariadicToBinaryOpConversion<ConcatOp, smt::ConcatOp>,
275 VariadicToBinaryOpConversion<AddOp, smt::BVAddOp>,
276 VariadicToBinaryOpConversion<MulOp, smt::BVMulOp>,
277 VariadicToBinaryOpConversion<AndOp, smt::BVAndOp>,
278 VariadicToBinaryOpConversion<OrOp, smt::BVOrOp>,
279 VariadicToBinaryOpConversion<XorOp, smt::BVXOrOp>>(
280 converter, patterns.getContext());
281
282 // TODO: there are two unsupported operations in the comb dialect: 'parity'
283 // and 'truth_table'.
284}
285
286void ConvertCombToSMTPass::runOnOperation() {
287 ConversionTarget target(getContext());
288 target.addIllegalDialect<hw::HWDialect>();
289 target.addIllegalOp<seq::FromClockOp>();
290 target.addIllegalOp<seq::ToClockOp>();
291 target.addIllegalDialect<comb::CombDialect>();
292 target.addLegalDialect<smt::SMTDialect>();
293 target.addLegalDialect<mlir::func::FuncDialect>();
294
295 RewritePatternSet patterns(&getContext());
296 TypeConverter converter;
298 // Also add HW patterns because some 'comb' canonicalizers produce constant
299 // operations, i.e., even if there is absolutely no HW operation present
300 // initially, we might have to convert one.
303
304 if (failed(mlir::applyPartialConversion(getOperation(), target,
305 std::move(patterns))))
306 return signalPassFailure();
307}
static Location getLoc(DefSlot slot)
Definition Mem2Reg.cpp:217
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
void populateHWToSMTConversionPatterns(TypeConverter &converter, RewritePatternSet &patterns)
Get the HW to SMT conversion patterns.
Definition HWToSMT.cpp:289
void populateCombToSMTConversionPatterns(TypeConverter &converter, RewritePatternSet &patterns)
Get the Comb to SMT conversion patterns.
void populateHWToSMTTypeConverter(TypeConverter &converter)
Get the HW to SMT type conversions.
Definition HWToSMT.cpp:184
Definition comb.py:1