12#include "mlir/Dialect/SMT/IR/SMTOps.h"
13#include "mlir/Pass/Pass.h"
14#include "mlir/Transforms/DialectConversion.h"
17#define GEN_PASS_DEF_CONVERTDATAPATHTOSMT
18#include "circt/Conversion/Passes.h.inc"
23using namespace datapath;
39 matchAndRewrite(CompressOp op, OpAdaptor adaptor,
40 ConversionPatternRewriter &rewriter)
const override {
42 ValueRange operands = adaptor.getOperands();
43 ValueRange results = op.getResults();
46 Value operandRunner = operands[0];
47 for (Value operand : operands.drop_front())
49 smt::BVAddOp::create(rewriter, op.getLoc(), operandRunner, operand);
52 SmallVector<Value, 2> newResults;
53 newResults.reserve(results.size());
54 for (Value result : results) {
55 auto declareFunOp = smt::DeclareFunOp::create(
56 rewriter, op.getLoc(), typeConverter->convertType(result.getType()));
57 newResults.push_back(declareFunOp.getResult());
61 Value resultRunner = newResults.front();
62 for (
auto freeVar : llvm::drop_begin(newResults, 1))
64 smt::BVAddOp::create(rewriter, op.getLoc(), resultRunner, freeVar);
68 smt::EqOp::create(rewriter, op.getLoc(), operandRunner, resultRunner);
70 smt::AssertOp::create(rewriter, op.getLoc(), premise);
72 if (newResults.size() != results.size())
73 return rewriter.notifyMatchFailure(op,
"expected same number of results");
75 rewriter.replaceOp(op, newResults);
88 matchAndRewrite(PartialProductOp op, OpAdaptor adaptor,
89 ConversionPatternRewriter &rewriter)
const override {
91 ValueRange operands = adaptor.getOperands();
92 ValueRange results = op.getResults();
96 smt::BVMulOp::create(rewriter, op.getLoc(), operands[0], operands[1]);
99 SmallVector<Value, 2> newResults;
100 newResults.reserve(results.size());
101 for (Value result : results) {
102 auto declareFunOp = smt::DeclareFunOp::create(
103 rewriter, op.getLoc(), typeConverter->convertType(result.getType()));
104 newResults.push_back(declareFunOp.getResult());
108 Value resultRunner = newResults.front();
109 for (
auto freeVar :
llvm::drop_begin(newResults, 1))
111 smt::BVAddOp::create(rewriter, op.
getLoc(), resultRunner, freeVar);
115 smt::EqOp::create(rewriter, op.getLoc(), mulResult, resultRunner);
117 smt::AssertOp::create(rewriter, op.getLoc(), premise);
119 if (newResults.size() != results.size())
120 return rewriter.notifyMatchFailure(op,
"expected same number of results");
122 rewriter.replaceOp(op, newResults);
131struct PosPartialProductOpConversion
136 matchAndRewrite(PosPartialProductOp op, OpAdaptor adaptor,
137 ConversionPatternRewriter &rewriter)
const override {
139 ValueRange operands = adaptor.getOperands();
140 ValueRange results = op.getResults();
144 smt::BVAddOp::create(rewriter, op.getLoc(), operands[0], operands[1]);
147 smt::BVMulOp::create(rewriter, op.getLoc(), addResult, operands[2]);
150 SmallVector<Value, 2> newResults;
151 newResults.reserve(results.size());
152 for (Value result : results) {
153 auto declareFunOp = smt::DeclareFunOp::create(
154 rewriter, op.getLoc(), typeConverter->convertType(result.getType()));
155 newResults.push_back(declareFunOp.getResult());
159 Value resultRunner = newResults.front();
160 for (
auto freeVar :
llvm::drop_begin(newResults, 1))
162 smt::BVAddOp::create(rewriter, op.
getLoc(), resultRunner, freeVar);
166 smt::EqOp::create(rewriter, op.getLoc(), mulResult, resultRunner);
168 smt::AssertOp::create(rewriter, op.getLoc(), premise);
170 if (newResults.size() != results.size())
171 return rewriter.notifyMatchFailure(op,
"expected same number of results");
173 rewriter.replaceOp(op, newResults);
185struct ConvertDatapathToSMTPass
186 :
public circt::impl::ConvertDatapathToSMTBase<ConvertDatapathToSMTPass> {
187 void runOnOperation()
override;
192 TypeConverter &converter, RewritePatternSet &
patterns) {
193 patterns.add<CompressOpConversion, PartialProductOpConversion,
194 PosPartialProductOpConversion>(converter,
patterns.getContext());
197void ConvertDatapathToSMTPass::runOnOperation() {
198 ConversionTarget target(getContext());
199 target.addIllegalDialect<datapath::DatapathDialect>();
200 target.addLegalDialect<smt::SMTDialect>();
202 RewritePatternSet
patterns(&getContext());
203 TypeConverter converter;
207 if (failed(mlir::applyPartialConversion(getOperation(), target,
209 return signalPassFailure();
static Location getLoc(DefSlot slot)
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
void populateDatapathToSMTConversionPatterns(TypeConverter &converter, RewritePatternSet &patterns)
Get the Datapath to SMT conversion patterns.
void populateHWToSMTTypeConverter(TypeConverter &converter)
Get the HW to SMT type conversions.