13#include "mlir/Dialect/Func/IR/FuncOps.h"
14#include "mlir/Dialect/SMT/IR/SMTOps.h"
15#include "mlir/Pass/Pass.h"
16#include "mlir/Transforms/DialectConversion.h"
19#define GEN_PASS_DEF_CONVERTCOMBTOSMT
20#include "circt/Conversion/Passes.h.inc"
37 matchAndRewrite(ReplicateOp op, OpAdaptor adaptor,
38 ConversionPatternRewriter &rewriter)
const override {
39 rewriter.replaceOpWithNewOp<smt::RepeatOp>(op, op.getMultiple(),
51 matchAndRewrite(ICmpOp op, OpAdaptor adaptor,
52 ConversionPatternRewriter &rewriter)
const override {
53 if (adaptor.getPredicate() == ICmpPredicate::weq ||
54 adaptor.getPredicate() == ICmpPredicate::ceq ||
55 adaptor.getPredicate() == ICmpPredicate::wne ||
56 adaptor.getPredicate() == ICmpPredicate::cne)
57 return rewriter.notifyMatchFailure(op,
58 "comparison predicate not supported");
61 if (adaptor.getPredicate() == ICmpPredicate::eq) {
62 result = smt::EqOp::create(rewriter, op.getLoc(), adaptor.getLhs(),
64 }
else if (adaptor.getPredicate() == ICmpPredicate::ne) {
65 result = smt::DistinctOp::create(rewriter, op.getLoc(), adaptor.getLhs(),
68 smt::BVCmpPredicate pred;
69 switch (adaptor.getPredicate()) {
70 case ICmpPredicate::sge:
71 pred = smt::BVCmpPredicate::sge;
73 case ICmpPredicate::sgt:
74 pred = smt::BVCmpPredicate::sgt;
76 case ICmpPredicate::sle:
77 pred = smt::BVCmpPredicate::sle;
79 case ICmpPredicate::slt:
80 pred = smt::BVCmpPredicate::slt;
82 case ICmpPredicate::uge:
83 pred = smt::BVCmpPredicate::uge;
85 case ICmpPredicate::ugt:
86 pred = smt::BVCmpPredicate::ugt;
88 case ICmpPredicate::ule:
89 pred = smt::BVCmpPredicate::ule;
91 case ICmpPredicate::ult:
92 pred = smt::BVCmpPredicate::ult;
95 llvm_unreachable(
"all cases handled above");
98 result = smt::BVCmpOp::create(rewriter, op.getLoc(), pred,
99 adaptor.getLhs(), adaptor.getRhs());
102 Value convVal = typeConverter->materializeTargetConversion(
103 rewriter, op.getLoc(), typeConverter->convertType(op.getType()),
108 rewriter.replaceOp(op, convVal);
118 matchAndRewrite(
ExtractOp op, OpAdaptor adaptor,
119 ConversionPatternRewriter &rewriter)
const override {
121 rewriter.replaceOpWithNewOp<smt::ExtractOp>(
122 op, typeConverter->convertType(op.getResult().getType()),
123 adaptor.getLowBitAttr(), adaptor.getInput());
133 matchAndRewrite(
MuxOp op, OpAdaptor adaptor,
134 ConversionPatternRewriter &rewriter)
const override {
135 Value condition = typeConverter->materializeTargetConversion(
136 rewriter, op.getLoc(), smt::BoolType::get(getContext()),
138 rewriter.replaceOpWithNewOp<smt::IteOp>(
139 op, condition, adaptor.getTrueValue(), adaptor.getFalseValue());
149 matchAndRewrite(
SubOp op, OpAdaptor adaptor,
150 ConversionPatternRewriter &rewriter)
const override {
152 smt::BVNegOp::create(rewriter, op.getLoc(), adaptor.getRhs());
153 rewriter.replaceOpWithNewOp<smt::BVAddOp>(op, adaptor.getLhs(), negRhs);
163 matchAndRewrite(
ParityOp op, OpAdaptor adaptor,
164 ConversionPatternRewriter &rewriter)
const override {
165 Location loc = op.getLoc();
167 cast<smt::BitVectorType>(adaptor.getInput().getType()).getWidth();
171 Type oneBitTy = smt::BitVectorType::get(getContext(), 1);
173 smt::ExtractOp::create(rewriter, loc, oneBitTy, 0, adaptor.getInput());
174 for (
unsigned i = 1; i < bitwidth; ++i) {
175 Value ext = smt::ExtractOp::create(rewriter, loc, oneBitTy, i,
177 runner = smt::BVXOrOp::create(rewriter, loc, runner, ext);
180 rewriter.replaceOp(op, runner);
186template <
typename SourceOp,
typename TargetOp>
189 using OpAdaptor =
typename SourceOp::Adaptor;
192 matchAndRewrite(SourceOp op, OpAdaptor adaptor,
193 ConversionPatternRewriter &rewriter)
const override {
195 rewriter.replaceOpWithNewOp<TargetOp>(
198 op.getResult().getType()),
199 adaptor.getOperands());
206template <
typename SourceOp,
typename TargetOp>
209 using OpAdaptor =
typename SourceOp::Adaptor;
212 matchAndRewrite(SourceOp op, OpAdaptor adaptor,
213 ConversionPatternRewriter &rewriter)
const override {
214 Location loc = op.getLoc();
215 auto type = dyn_cast<smt::BitVectorType>(adaptor.getRhs().getType());
220 op.getResult().getType());
222 smt::BVConstantOp::create(rewriter, loc, APInt(type.getWidth(), 0));
223 Value isZero = smt::EqOp::create(rewriter, loc, adaptor.getRhs(), zero);
224 Value symbolicVal = smt::DeclareFunOp::create(rewriter, loc, resultType);
226 TargetOp::create(rewriter, loc, resultType, adaptor.getOperands());
227 rewriter.replaceOpWithNewOp<smt::IteOp>(op, isZero, symbolicVal, division);
234template <
typename SourceOp,
typename TargetOp>
237 using OpAdaptor =
typename SourceOp::Adaptor;
240 matchAndRewrite(SourceOp op, OpAdaptor adaptor,
241 ConversionPatternRewriter &rewriter)
const override {
243 ValueRange operands = adaptor.getOperands();
244 if (operands.size() < 2)
247 Value runner = operands[0];
248 for (Value operand : operands.drop_front())
249 runner = TargetOp::create(rewriter, op.
getLoc(), runner, operand);
251 rewriter.replaceOp(op, runner);
263struct ConvertCombToSMTPass
264 :
public circt::impl::ConvertCombToSMTBase<ConvertCombToSMTPass> {
265 void runOnOperation()
override;
271 patterns.add<CombReplicateOpConversion, IcmpOpConversion, ExtractOpConversion,
272 SubOpConversion, MuxOpConversion, ParityOpConversion,
273 OneToOneOpConversion<ShlOp, smt::BVShlOp>,
274 OneToOneOpConversion<ShrUOp, smt::BVLShrOp>,
275 OneToOneOpConversion<ShrSOp, smt::BVAShrOp>,
276 DivisionOpConversion<DivSOp, smt::BVSDivOp>,
277 DivisionOpConversion<DivUOp, smt::BVUDivOp>,
278 DivisionOpConversion<ModSOp, smt::BVSRemOp>,
279 DivisionOpConversion<ModUOp, smt::BVURemOp>,
280 VariadicToBinaryOpConversion<ConcatOp, smt::ConcatOp>,
281 VariadicToBinaryOpConversion<AddOp, smt::BVAddOp>,
282 VariadicToBinaryOpConversion<MulOp, smt::BVMulOp>,
283 VariadicToBinaryOpConversion<AndOp, smt::BVAndOp>,
284 VariadicToBinaryOpConversion<OrOp, smt::BVOrOp>,
285 VariadicToBinaryOpConversion<XorOp, smt::BVXOrOp>>(
292void ConvertCombToSMTPass::runOnOperation() {
293 ConversionTarget target(getContext());
294 target.addIllegalDialect<hw::HWDialect>();
295 target.addIllegalOp<seq::FromClockOp>();
296 target.addIllegalOp<seq::ToClockOp>();
297 target.addIllegalDialect<comb::CombDialect>();
298 target.addLegalDialect<smt::SMTDialect>();
299 target.addLegalDialect<mlir::func::FuncDialect>();
301 RewritePatternSet
patterns(&getContext());
302 TypeConverter converter;
310 if (failed(mlir::applyPartialConversion(getOperation(), target,
312 return signalPassFailure();
static Location getLoc(DefSlot slot)
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
void populateHWToSMTConversionPatterns(TypeConverter &converter, RewritePatternSet &patterns)
Get the HW to SMT conversion patterns.
void populateCombToSMTConversionPatterns(TypeConverter &converter, RewritePatternSet &patterns)
Get the Comb to SMT conversion patterns.
void populateHWToSMTTypeConverter(TypeConverter &converter)
Get the HW to SMT type conversions.