13 #include "mlir/Pass/Pass.h"
14 #include "mlir/Transforms/DialectConversion.h"
17 #define GEN_PASS_DEF_CONVERTCOMBTOSMT
18 #include "circt/Conversion/Passes.h.inc"
21 using namespace circt;
34 matchAndRewrite(ReplicateOp op, OpAdaptor adaptor,
35 ConversionPatternRewriter &rewriter)
const override {
36 rewriter.replaceOpWithNewOp<smt::RepeatOp>(op, op.getMultiple(),
48 matchAndRewrite(ICmpOp op, OpAdaptor adaptor,
49 ConversionPatternRewriter &rewriter)
const override {
50 if (adaptor.getPredicate() == ICmpPredicate::weq ||
51 adaptor.getPredicate() == ICmpPredicate::ceq ||
52 adaptor.getPredicate() == ICmpPredicate::wne ||
53 adaptor.getPredicate() == ICmpPredicate::cne)
54 return rewriter.notifyMatchFailure(op,
55 "comparison predicate not supported");
57 if (adaptor.getPredicate() == ICmpPredicate::eq) {
58 rewriter.replaceOpWithNewOp<smt::EqOp>(op, adaptor.getLhs(),
63 if (adaptor.getPredicate() == ICmpPredicate::ne) {
64 rewriter.replaceOpWithNewOp<smt::DistinctOp>(op, adaptor.getLhs(),
69 smt::BVCmpPredicate pred;
70 switch (adaptor.getPredicate()) {
71 case ICmpPredicate::sge:
72 pred = smt::BVCmpPredicate::sge;
74 case ICmpPredicate::sgt:
75 pred = smt::BVCmpPredicate::sgt;
77 case ICmpPredicate::sle:
78 pred = smt::BVCmpPredicate::sle;
80 case ICmpPredicate::slt:
81 pred = smt::BVCmpPredicate::slt;
83 case ICmpPredicate::uge:
84 pred = smt::BVCmpPredicate::uge;
86 case ICmpPredicate::ugt:
87 pred = smt::BVCmpPredicate::ugt;
89 case ICmpPredicate::ule:
90 pred = smt::BVCmpPredicate::ule;
92 case ICmpPredicate::ult:
93 pred = smt::BVCmpPredicate::ult;
96 llvm_unreachable(
"all cases handled above");
99 rewriter.replaceOpWithNewOp<smt::BVCmpOp>(op, pred, adaptor.getLhs(),
110 matchAndRewrite(
ExtractOp op, OpAdaptor adaptor,
111 ConversionPatternRewriter &rewriter)
const override {
113 rewriter.replaceOpWithNewOp<smt::ExtractOp>(
114 op, typeConverter->convertType(op.getResult().getType()),
115 adaptor.getLowBitAttr(), adaptor.getInput());
125 matchAndRewrite(
MuxOp op, OpAdaptor adaptor,
126 ConversionPatternRewriter &rewriter)
const override {
127 Value condition = typeConverter->materializeTargetConversion(
130 rewriter.replaceOpWithNewOp<smt::IteOp>(
131 op, condition, adaptor.getTrueValue(), adaptor.getFalseValue());
141 matchAndRewrite(
SubOp op, OpAdaptor adaptor,
142 ConversionPatternRewriter &rewriter)
const override {
143 Value negRhs = rewriter.create<smt::BVNegOp>(op.getLoc(), adaptor.getRhs());
144 rewriter.replaceOpWithNewOp<smt::BVAddOp>(op, adaptor.getLhs(), negRhs);
154 matchAndRewrite(
ParityOp op, OpAdaptor adaptor,
155 ConversionPatternRewriter &rewriter)
const override {
156 Location loc = op.getLoc();
158 cast<smt::BitVectorType>(adaptor.getInput().getType()).getWidth();
164 rewriter.create<smt::ExtractOp>(loc, oneBitTy, 0, adaptor.getInput());
165 for (
unsigned i = 1; i < bitwidth; ++i) {
167 rewriter.create<smt::ExtractOp>(loc, oneBitTy, i, adaptor.getInput());
168 runner = rewriter.create<smt::BVXOrOp>(loc, runner, ext);
171 rewriter.replaceOp(op, runner);
177 template <
typename SourceOp,
typename TargetOp>
180 using OpAdaptor =
typename SourceOp::Adaptor;
183 matchAndRewrite(SourceOp op, OpAdaptor adaptor,
184 ConversionPatternRewriter &rewriter)
const override {
186 rewriter.replaceOpWithNewOp<TargetOp>(
189 op.getResult().getType()),
190 adaptor.getOperands());
197 template <
typename SourceOp,
typename TargetOp>
200 using OpAdaptor =
typename SourceOp::Adaptor;
203 matchAndRewrite(SourceOp op, OpAdaptor adaptor,
204 ConversionPatternRewriter &rewriter)
const override {
206 ValueRange operands = adaptor.getOperands();
207 if (operands.size() < 2)
210 Value runner = operands[0];
211 for (Value operand : operands.drop_front())
212 runner = rewriter.create<TargetOp>(op.getLoc(), runner, operand);
214 rewriter.replaceOp(op, runner);
226 struct ConvertCombToSMTPass
227 :
public impl::ConvertCombToSMTBase<ConvertCombToSMTPass> {
228 void runOnOperation()
override;
234 patterns.add<CombReplicateOpConversion, IcmpOpConversion, ExtractOpConversion,
235 SubOpConversion, MuxOpConversion, ParityOpConversion,
236 OneToOneOpConversion<ShlOp, smt::BVShlOp>,
237 OneToOneOpConversion<ShrUOp, smt::BVLShrOp>,
238 OneToOneOpConversion<ShrSOp, smt::BVAShrOp>,
239 OneToOneOpConversion<DivSOp, smt::BVSDivOp>,
240 OneToOneOpConversion<DivUOp, smt::BVUDivOp>,
241 OneToOneOpConversion<ModSOp, smt::BVSRemOp>,
242 OneToOneOpConversion<ModUOp, smt::BVURemOp>,
243 VariadicToBinaryOpConversion<ConcatOp, smt::ConcatOp>,
244 VariadicToBinaryOpConversion<AddOp, smt::BVAddOp>,
245 VariadicToBinaryOpConversion<MulOp, smt::BVMulOp>,
246 VariadicToBinaryOpConversion<AndOp, smt::BVAndOp>,
247 VariadicToBinaryOpConversion<OrOp, smt::BVOrOp>,
248 VariadicToBinaryOpConversion<XorOp, smt::BVXOrOp>>(
255 void ConvertCombToSMTPass::runOnOperation() {
256 ConversionTarget target(getContext());
257 target.addIllegalDialect<comb::CombDialect>();
258 target.addLegalDialect<smt::SMTDialect>();
260 RewritePatternSet
patterns(&getContext());
261 TypeConverter converter;
269 if (failed(mlir::applyPartialConversion(getOperation(), target,
271 return signalPassFailure();
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
void populateHWToSMTConversionPatterns(TypeConverter &converter, RewritePatternSet &patterns)
Get the HW to SMT conversion patterns.
void populateCombToSMTConversionPatterns(TypeConverter &converter, RewritePatternSet &patterns)
Get the Comb to SMT conversion patterns.
void populateHWToSMTTypeConverter(TypeConverter &converter)
Get the HW to SMT type conversions.