13 #include "mlir/Dialect/Func/IR/FuncOps.h"
14 #include "mlir/Pass/Pass.h"
15 #include "mlir/Transforms/DialectConversion.h"
18 #define GEN_PASS_DEF_CONVERTCOMBTOSMT
19 #include "circt/Conversion/Passes.h.inc"
22 using namespace circt;
35 matchAndRewrite(ReplicateOp op, OpAdaptor adaptor,
36 ConversionPatternRewriter &rewriter)
const override {
37 rewriter.replaceOpWithNewOp<smt::RepeatOp>(op, op.getMultiple(),
49 matchAndRewrite(ICmpOp op, OpAdaptor adaptor,
50 ConversionPatternRewriter &rewriter)
const override {
51 if (adaptor.getPredicate() == ICmpPredicate::weq ||
52 adaptor.getPredicate() == ICmpPredicate::ceq ||
53 adaptor.getPredicate() == ICmpPredicate::wne ||
54 adaptor.getPredicate() == ICmpPredicate::cne)
55 return rewriter.notifyMatchFailure(op,
56 "comparison predicate not supported");
58 if (adaptor.getPredicate() == ICmpPredicate::eq) {
59 rewriter.replaceOpWithNewOp<smt::EqOp>(op, adaptor.getLhs(),
64 if (adaptor.getPredicate() == ICmpPredicate::ne) {
65 rewriter.replaceOpWithNewOp<smt::DistinctOp>(op, adaptor.getLhs(),
70 smt::BVCmpPredicate pred;
71 switch (adaptor.getPredicate()) {
72 case ICmpPredicate::sge:
73 pred = smt::BVCmpPredicate::sge;
75 case ICmpPredicate::sgt:
76 pred = smt::BVCmpPredicate::sgt;
78 case ICmpPredicate::sle:
79 pred = smt::BVCmpPredicate::sle;
81 case ICmpPredicate::slt:
82 pred = smt::BVCmpPredicate::slt;
84 case ICmpPredicate::uge:
85 pred = smt::BVCmpPredicate::uge;
87 case ICmpPredicate::ugt:
88 pred = smt::BVCmpPredicate::ugt;
90 case ICmpPredicate::ule:
91 pred = smt::BVCmpPredicate::ule;
93 case ICmpPredicate::ult:
94 pred = smt::BVCmpPredicate::ult;
97 llvm_unreachable(
"all cases handled above");
100 rewriter.replaceOpWithNewOp<smt::BVCmpOp>(op, pred, adaptor.getLhs(),
111 matchAndRewrite(
ExtractOp op, OpAdaptor adaptor,
112 ConversionPatternRewriter &rewriter)
const override {
114 rewriter.replaceOpWithNewOp<smt::ExtractOp>(
115 op, typeConverter->convertType(op.getResult().getType()),
116 adaptor.getLowBitAttr(), adaptor.getInput());
126 matchAndRewrite(
MuxOp op, OpAdaptor adaptor,
127 ConversionPatternRewriter &rewriter)
const override {
128 Value condition = typeConverter->materializeTargetConversion(
131 rewriter.replaceOpWithNewOp<smt::IteOp>(
132 op, condition, adaptor.getTrueValue(), adaptor.getFalseValue());
142 matchAndRewrite(
SubOp op, OpAdaptor adaptor,
143 ConversionPatternRewriter &rewriter)
const override {
144 Value negRhs = rewriter.create<smt::BVNegOp>(op.getLoc(), adaptor.getRhs());
145 rewriter.replaceOpWithNewOp<smt::BVAddOp>(op, adaptor.getLhs(), negRhs);
155 matchAndRewrite(
ParityOp op, OpAdaptor adaptor,
156 ConversionPatternRewriter &rewriter)
const override {
157 Location loc = op.getLoc();
159 cast<smt::BitVectorType>(adaptor.getInput().getType()).getWidth();
165 rewriter.create<smt::ExtractOp>(loc, oneBitTy, 0, adaptor.getInput());
166 for (
unsigned i = 1; i < bitwidth; ++i) {
168 rewriter.create<smt::ExtractOp>(loc, oneBitTy, i, adaptor.getInput());
169 runner = rewriter.create<smt::BVXOrOp>(loc, runner, ext);
172 rewriter.replaceOp(op, runner);
178 template <
typename SourceOp,
typename TargetOp>
181 using OpAdaptor =
typename SourceOp::Adaptor;
184 matchAndRewrite(SourceOp op, OpAdaptor adaptor,
185 ConversionPatternRewriter &rewriter)
const override {
187 rewriter.replaceOpWithNewOp<TargetOp>(
190 op.getResult().getType()),
191 adaptor.getOperands());
198 template <
typename SourceOp,
typename TargetOp>
201 using OpAdaptor =
typename SourceOp::Adaptor;
204 matchAndRewrite(SourceOp op, OpAdaptor adaptor,
205 ConversionPatternRewriter &rewriter)
const override {
206 Location loc = op.getLoc();
207 auto type = dyn_cast<smt::BitVectorType>(adaptor.getRhs().getType());
212 op.getResult().getType());
214 rewriter.create<smt::BVConstantOp>(loc, APInt(type.getWidth(), 0));
215 Value isZero = rewriter.create<smt::EqOp>(loc, adaptor.getRhs(), zero);
216 Value symbolicVal = rewriter.create<smt::DeclareFunOp>(loc, resultType);
218 rewriter.create<TargetOp>(loc, resultType, adaptor.getOperands());
219 rewriter.replaceOpWithNewOp<smt::IteOp>(op, isZero, symbolicVal, division);
226 template <
typename SourceOp,
typename TargetOp>
229 using OpAdaptor =
typename SourceOp::Adaptor;
232 matchAndRewrite(SourceOp op, OpAdaptor adaptor,
233 ConversionPatternRewriter &rewriter)
const override {
235 ValueRange operands = adaptor.getOperands();
236 if (operands.size() < 2)
239 Value runner = operands[0];
240 for (Value operand : operands.drop_front())
241 runner = rewriter.create<TargetOp>(op.getLoc(), runner, operand);
243 rewriter.replaceOp(op, runner);
255 struct ConvertCombToSMTPass
256 :
public impl::ConvertCombToSMTBase<ConvertCombToSMTPass> {
257 void runOnOperation()
override;
263 patterns.add<CombReplicateOpConversion, IcmpOpConversion, ExtractOpConversion,
264 SubOpConversion, MuxOpConversion, ParityOpConversion,
265 OneToOneOpConversion<ShlOp, smt::BVShlOp>,
266 OneToOneOpConversion<ShrUOp, smt::BVLShrOp>,
267 OneToOneOpConversion<ShrSOp, smt::BVAShrOp>,
268 DivisionOpConversion<DivSOp, smt::BVSDivOp>,
269 DivisionOpConversion<DivUOp, smt::BVUDivOp>,
270 DivisionOpConversion<ModSOp, smt::BVSRemOp>,
271 DivisionOpConversion<ModUOp, smt::BVURemOp>,
272 VariadicToBinaryOpConversion<ConcatOp, smt::ConcatOp>,
273 VariadicToBinaryOpConversion<AddOp, smt::BVAddOp>,
274 VariadicToBinaryOpConversion<MulOp, smt::BVMulOp>,
275 VariadicToBinaryOpConversion<AndOp, smt::BVAndOp>,
276 VariadicToBinaryOpConversion<OrOp, smt::BVOrOp>,
277 VariadicToBinaryOpConversion<XorOp, smt::BVXOrOp>>(
284 void ConvertCombToSMTPass::runOnOperation() {
285 ConversionTarget target(getContext());
286 target.addIllegalDialect<comb::CombDialect>();
287 target.addLegalDialect<smt::SMTDialect>();
289 RewritePatternSet
patterns(&getContext());
290 TypeConverter converter;
298 if (failed(mlir::applyPartialConversion(getOperation(), target,
300 return signalPassFailure();
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
void populateHWToSMTConversionPatterns(TypeConverter &converter, RewritePatternSet &patterns)
Get the HW to SMT conversion patterns.
void populateCombToSMTConversionPatterns(TypeConverter &converter, RewritePatternSet &patterns)
Get the Comb to SMT conversion patterns.
void populateHWToSMTTypeConverter(TypeConverter &converter)
Get the HW to SMT type conversions.