13#include "mlir/Pass/Pass.h"
14#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
18#define GEN_PASS_DEF_DATAPATHREDUCEDELAY
19#include "circt/Dialect/Datapath/DatapathPasses.h.inc"
24using namespace datapath;
40 using OpRewritePattern::OpRewritePattern;
43 PatternRewriter &rewriter)
const override {
45 SmallVector<Value, 8> newCompressOperands;
47 for (Value operand : addOp.getOperands()) {
49 llvm::append_range(newCompressOperands, nestedAddOp.getOperands());
51 newCompressOperands.push_back(operand);
56 if (newCompressOperands.size() <= addOp.getNumOperands())
60 auto newCompressOp = datapath::CompressOp::create(rewriter, addOp.getLoc(),
61 newCompressOperands, 2);
64 rewriter.replaceOpWithNewOp<
comb::AddOp>(addOp, newCompressOp.getResults(),
74 using OpRewritePattern::OpRewritePattern;
79 PatternRewriter &rewriter)
const override {
81 SmallVector<Value, 8> newCompressOperands;
82 for (Value operand : addOp.getOperands()) {
88 newCompressOperands.push_back(operand);
92 SmallVector<Value> trueValOperands = {nestedMuxOp.getTrueValue()};
93 SmallVector<Value> falseValOperands = {nestedMuxOp.getFalseValue()};
96 nestedMuxOp.getTrueValue().getDefiningOp<
comb::AddOp>())
97 trueValOperands = trueVal.getOperands();
101 nestedMuxOp.getFalseValue().getDefiningOp<
comb::AddOp>())
102 falseValOperands = falseVal.getOperands();
105 std::max(trueValOperands.size(), falseValOperands.size());
108 if (maxOperands <= 1) {
109 newCompressOperands.push_back(operand);
117 rewriter.getIntegerAttr(addOp.getType(), 0));
118 for (
size_t i = 0; i < maxOperands; ++i) {
119 auto tOp = i < trueValOperands.size() ? trueValOperands[i] : zero;
120 auto fOp = i < falseValOperands.size() ? falseValOperands[i] : zero;
121 auto newMux = comb::MuxOp::create(rewriter, addOp.getLoc(),
122 nestedMuxOp.getCond(), tOp, fOp);
123 newCompressOperands.push_back(newMux.getResult());
128 if (newCompressOperands.size() <= addOp.getNumOperands())
132 auto newCompressOp = datapath::CompressOp::create(rewriter, addOp.getLoc(),
133 newCompressOperands, 2);
136 rewriter.replaceOpWithNewOp<
comb::AddOp>(addOp, newCompressOp.getResults(),
143 using OpRewritePattern::OpRewritePattern;
149 LogicalResult matchAndRewrite(comb::ICmpOp op,
150 PatternRewriter &rewriter)
const override {
151 Value lhs = op.getLhs();
152 Value rhs = op.getRhs();
153 auto width = lhs.getType().getIntOrFloatBitWidth();
156 if (op.getPredicate() != comb::ICmpPredicate::ult &&
157 op.getPredicate() != comb::ICmpPredicate::ule &&
158 op.getPredicate() != comb::ICmpPredicate::ugt &&
159 op.getPredicate() != comb::ICmpPredicate::uge)
168 bool lhsMinusRhs = op.getPredicate() == comb::ICmpPredicate::ult ||
169 op.getPredicate() == comb::ICmpPredicate::uge;
171 bool invertOut = op.getPredicate() == comb::ICmpPredicate::uge ||
172 op.getPredicate() == comb::ICmpPredicate::ule;
177 SmallVector<Value> lhsAddends = {lhs};
181 if (lhsAdd->getAttrOfType<UnitAttr>(
"comb.nuw"))
182 lhsAddends = lhsAdd.getOperands();
185 SmallVector<Value> rhsAddends = {rhs};
189 if (rhsAdd->getAttrOfType<UnitAttr>(
"comb.nuw"))
190 rhsAddends = rhsAdd.getOperands();
195 if (lhsAddends.size() + rhsAddends.size() < 3)
198 SmallVector<Value> lhsExtend;
199 for (
auto addend : lhsAddends) {
200 auto ext = comb::createZExt(rewriter, op.getLoc(), addend, width + 1);
201 lhsExtend.push_back(ext);
204 SmallVector<Value> rhsExtend;
205 for (
auto addend : rhsAddends) {
206 auto ext = comb::createZExt(rewriter, op.getLoc(), addend, width + 1);
207 auto negatedAddend = comb::createOrFoldNot(op.getLoc(), ext, rewriter);
208 rhsExtend.push_back(negatedAddend);
212 rewriter, op.getLoc(), APInt(width + 1, rhsExtend.size())));
214 SmallVector<Value> allAddends = std::move(lhsExtend);
215 llvm::append_range(allAddends, rhsExtend);
216 auto add = comb::AddOp::create(rewriter, op.getLoc(), allAddends,
false);
218 op.getLoc(), add.getResult(), width, 1);
221 rewriter.replaceOp(op, msb);
225 auto notOp = comb::createOrFoldNot(op.getLoc(), msb, rewriter);
226 rewriter.replaceOp(op, notOp);
234struct DatapathReduceDelayPass
235 :
public circt::datapath::impl::DatapathReduceDelayBase<
236 DatapathReduceDelayPass> {
238 void runOnOperation()
override {
239 Operation *op = getOperation();
240 MLIRContext *ctx = op->getContext();
243 patterns.add<FoldAddReplicate, FoldMuxAdd, ConvertCmpToAdd>(ctx);
245 if (failed(applyPatternsGreedily(op, std::move(
patterns))))
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.