14#include "mlir/Analysis/TopologicalSortUtils.h"
15#include "mlir/Dialect/Func/IR/FuncOps.h"
16#include "mlir/IR/PatternMatch.h"
17#include "mlir/Pass/Pass.h"
18#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
19#include "llvm/Support/Debug.h"
20#include "llvm/Support/KnownBits.h"
22#define DEBUG_TYPE "datapath-to-comb"
25#define GEN_PASS_DEF_CONVERTDATAPATHTOCOMB
26#include "circt/Conversion/Passes.h.inc"
30using namespace datapath;
/// Split `val` into one single-bit Value per bit by delegating to
/// comb::extractBits. Callers treat bits[i - 1] and bits[i + 1] as the
/// neighbours of bits[i], so index 0 is presumably the least-significant
/// bit -- confirm against comb::extractBits.
33static SmallVector<Value>
extractBits(OpBuilder &builder, Value val) {
34 SmallVector<Value> bits;
35 comb::extractBits(builder, val, bits);
/// Lower a datapath.compress op to a single comb.add over all of its
/// operands. Every result except the last is `zeroOp` and the last result is
/// the adder output, so the results still sum to the operands' sum (assuming
/// `zeroOp` is a `width`-bit zero constant -- confirm).
50 matchAndRewrite(CompressOp op,
51 mlir::PatternRewriter &rewriter)
const override {
52 Location loc = op.getLoc();
53 auto inputs = op.getOperands();
// Only the first operand's width is read; presumably all operands share it.
54 unsigned width = inputs[0].getType().getIntOrFloatBitWidth();
// NOTE(review): the trailing `true` is presumably the two-state flag --
// verify against comb::AddOp::create.
56 auto addOp = comb::AddOp::create(rewriter, loc, inputs,
true);
// N - 1 zero results followed by the full sum.
59 SmallVector<Value> results(op.getNumResults() - 1, zeroOp);
60 results.push_back(addOp);
61 rewriter.replaceOp(op, results);
/// Timing-driven lowering of datapath.compress: the operands are reduced to
/// `op.getNumResults()` addend rows by a compressor object (`comp`) that is
/// annotated with per-input delays from the longest-path `analysis`.
68 DatapathCompressOpConversion(MLIRContext *context,
73 matchAndRewrite(CompressOp op,
74 mlir::PatternRewriter &rewriter)
const override {
75 Location loc = op.getLoc();
76 auto inputs = op.getOperands();
// One bit-vector per operand; presumably each input is split into bits by
// the loop below -- confirm.
78 SmallVector<SmallVector<Value>> addends;
79 for (
auto input : inputs) {
85 auto width = inputs[0].getType().getIntOrFloatBitWidth();
// The compressor must stop at exactly as many rows as the op has results.
86 auto targetAddends = op.getNumResults();
// Seed the compressor with each input's maximum delay from the analysis.
// NOTE(review): the meaning of the `0` argument to getMaxDelay (presumably
// a bit index) and the handling of a failed annotation should be confirmed.
91 if (failed(comp.withInputDelays(
92 [&](Value v) { return analysis->getMaxDelay(v, 0); })))
// Replace the op's results with the rows produced by compressToHeight.
96 rewriter.replaceOp(op, comp.compressToHeight(rewriter, targetAddends));
/// Lower datapath.partial_product into partial-product rows of lhs * rhs
/// built from comb logic, using either a plain AND array or a radix-4 Booth
/// array. The `forceBooth` flag selects the Booth lowering unconditionally.
104struct DatapathPartialProductOpConversion :
OpRewritePattern<PartialProductOp> {
107 DatapathPartialProductOpConversion(MLIRContext *context,
bool forceBooth)
110 const bool forceBooth;
112 LogicalResult matchAndRewrite(PartialProductOp op,
113 PatternRewriter &rewriter)
const override {
115 Value a = op.getLhs();
116 Value b = op.getRhs();
117 unsigned width = a.getType().getIntOrFloatBitWidth();
// NOTE(review): the condition guarding this early replacement by a zero
// constant of the first result type should be confirmed; presumably a
// degenerate case in which every row is zero.
121 rewriter.replaceOpWithNewOp<
hw::ConstantOp>(op, op.getType(0), 0);
// Use Booth encoding when more than 16 rows are requested or when forced;
// otherwise emit the simple AND array.
127 if (op.getNumResults() > 16 || forceBooth)
128 return lowerBoothArray(rewriter, a, b, op, width);
130 return lowerAndArray(rewriter, a, b, op, width);
/// Emit the plain AND partial-product array: row i is `a & replicate(b[i])`
/// shifted left by i and truncated to `width` bits. One row is produced per
/// op result, so only the low `op.getNumResults()` bits of `b` contribute.
134 static LogicalResult lowerAndArray(PatternRewriter &rewriter, Value a,
135 Value b, PartialProductOp op,
138 Location loc = op.getLoc();
140 SmallVector<Value> bBits =
extractBits(rewriter, b);
142 SmallVector<Value> partialProducts;
143 partialProducts.reserve(width);
146 assert(op.getNumResults() <= width &&
147 "Cannot return more results than the operator width");
// Row i: replicate bit i of b across the full width and AND it with a.
149 for (
unsigned i = 0; i < op.getNumResults(); ++i) {
151 rewriter.createOrFold<comb::ReplicateOp>(loc, bBits[i], width);
152 auto ppRow = rewriter.createOrFold<
comb::AndOp>(loc, repl, a);
// Presumably the unshifted row for i == 0 -- confirm its guarding branch.
154 partialProducts.push_back(ppRow);
// comb.concat puts its first operand in the most-significant position, so
// {ppRow, shiftBy} shifts the row left by shiftBy's width (presumably i zero
// bits); the extract then keeps the low `width` bits.
159 comb::ConcatOp::create(rewriter, loc, ValueRange{ppRow, shiftBy});
161 loc, ppAlign, 0, width);
162 partialProducts.push_back(ppAlignTrunc);
165 rewriter.replaceOp(op, partialProducts);
/// Emit a radix-4 Booth partial-product array for a * b. For each even bit
/// position i of b (0, 2, ..., up to width) the triplet (b[i+1], b[i], b[i-1])
/// selects a multiple of a in {0, +-a, +-2a}; bits outside b read as zero.
/// Negative multiples are built as a one's complement, and the +1 that
/// completes the negation is appended to the next row as a correction bit.
/// Known-zero leading bits of `a` narrow every row to `rowWidth` bits, and
/// known-zero bits of `b` are replaced by a constant so the encoders fold.
/// Finally, the row list is truncated or zero-padded to op.getNumResults().
169 static LogicalResult lowerBoothArray(PatternRewriter &rewriter, Value a,
170 Value b, PartialProductOp op,
172 Location loc = op.getLoc();
// Rows can be narrower than `width` when `a` has known-zero leading bits.
// One of those bits is kept, presumably so the 2a multiple (a << 1) still
// fits in `rowWidth` bits.
179 auto rowWidth = width;
180 auto knownBitsA = comb::computeKnownBits(a);
181 if (!knownBitsA.Zero.isZero()) {
182 if (knownBitsA.Zero.countLeadingOnes() > 1) {
185 rowWidth -= knownBitsA.Zero.countLeadingOnes() - 1;
// The 2a multiple; `oneRowWidth` is presumably the constant 1 at row width.
192 Value twoA = rewriter.createOrFold<
comb::ShlOp>(loc, a, oneRowWidth);
196 SmallVector<Value> bBits =
extractBits(rewriter, b);
// Replace bits of b that are known to be zero with the constant `zeroFalse`
// so the Booth encoders below can constant-fold.
199 auto knownBitsB = comb::computeKnownBits(b);
200 if (!knownBitsB.Zero.isZero()) {
201 for (
unsigned i = 0; i < width; ++i)
202 if (knownBitsB.Zero[i])
203 bBits[i] = zeroFalse;
206 SmallVector<Value> partialProducts;
207 partialProducts.reserve(width);
// One Booth row per even bit position i in [0, width]. The triplet bits
// outside the range of b are treated as zero.
217 for (
unsigned i = 0; i <= width; i += 2) {
219 Value bim1 = (i == 0) ? zeroFalse : bBits[i - 1];
220 Value bi = (i < width) ? bBits[i] : zeroFalse;
221 Value bip1 = (i + 1 < width) ? bBits[i + 1] : zeroFalse;
// |multiple| == 1 iff b[i] != b[i-1].
226 Value encOne = rewriter.createOrFold<
comb::XorOp>(loc, bi, bim1,
true);
// Bitwise inverses of the triplet bits (XOR with constOne).
229 Value biInv = rewriter.createOrFold<
comb::XorOp>(loc, bi, constOne,
true);
231 rewriter.createOrFold<
comb::XorOp>(loc, bip1, constOne,
true);
233 rewriter.createOrFold<
comb::XorOp>(loc, bim1, constOne,
true);
// |multiple| == 2 iff the triplet (b[i+1], b[i], b[i-1]) is 011 (andLeft)
// or 100 (andRight).
235 Value andLeft = rewriter.createOrFold<
comb::AndOp>(
236 loc, ValueRange{bip1Inv, bi, bim1},
true);
237 Value andRight = rewriter.createOrFold<
comb::AndOp>(
238 loc, ValueRange{bip1, biInv, bim1Inv},
true);
240 rewriter.createOrFold<
comb::OrOp>(loc, andLeft, andRight,
true);
// Broadcast the control bits across the row. `encNeg` is presumably b[i+1];
// its inverse is taken to be bip1Inv further down.
243 rewriter.createOrFold<comb::ReplicateOp>(loc, encNeg, rowWidth);
245 rewriter.createOrFold<comb::ReplicateOp>(loc, encOne, rowWidth);
247 rewriter.createOrFold<comb::ReplicateOp>(loc, encTwo, rowWidth);
// Unsigned magnitude of the selected multiple: 2a, a, or 0.
250 Value selTwoA = rewriter.createOrFold<
comb::AndOp>(loc, encTwoRepl, twoA);
251 Value selOneA = rewriter.createOrFold<
comb::AndOp>(loc, encOneRepl, a);
253 rewriter.createOrFold<
comb::OrOp>(loc, selTwoA, selOneA,
true);
// Conditionally invert the magnitude (one's complement). The +1 that
// completes the negation is added later as a correction bit in the next row.
257 rewriter.createOrFold<
comb::XorOp>(loc, magA, encNegRepl,
true);
// Narrowed rows are widened back towards `width` with sign-extension bits
// derived from encNeg. The {encNegInv, encNeg, encNeg, ...} and
// {constOne, encNegInv, ...} prefixes look like the standard Booth
// sign-extension-prevention encoding -- confirm which rows take which prefix.
270 if (rowWidth < width) {
271 auto padding = width - rowWidth;
272 auto encNegInv = bip1Inv;
277 rewriter.createOrFold<comb::ReplicateOp>(loc, encNeg, padding);
279 loc, ValueRange{encNegPad, ppRow});
283 loc, ValueRange{encNegInv, encNeg, encNeg, ppRow});
287 loc, ValueRange{constOne, encNegInv, ppRow});
// NOTE(review): this `rowWidth` shadows the outer one; it is the current
// row's actual width. Rows still narrower than `width` are zero-extended.
291 auto rowWidth = ppRow.getType().getIntOrFloatBitWidth();
292 if (rowWidth < width) {
296 loc, ValueRange{zeroPad, ppRow});
// Presumably the unshifted first row (i == 0).
302 partialProducts.push_back(ppRow);
// Append two low bits {0, encNegPrev} below the row. encNegPrev is
// presumably the previous row's negate bit. If shiftBy is i - 2 (also
// presumed), that bit lands at position i - 2, the LSB of the previous row,
// which adds the +1 that completes that row's negation. The row is then
// truncated to `width` bits and shifted into place by `shiftBy`.
308 assert(i >= 2 &&
"Expected i to be at least 2 for sign correction");
311 loc, ValueRange{ppRow, zeroFalse, encNegPrev});
313 loc, withSignCorrection, 0, width);
317 rewriter.createOrFold<
comb::ShlOp>(loc, ppAlignPre, shiftBy);
318 partialProducts.push_back(ppAlign);
// Stop once as many rows as the op has results have been produced.
321 if (partialProducts.size() == op.getNumResults())
// Pad with zero rows when the op asks for more rows than Booth produced.
326 while (partialProducts.size() < op.getNumResults())
327 partialProducts.push_back(zeroWidth);
329 assert(partialProducts.size() == op.getNumResults() &&
330 "Expected number of booth partial products to match results");
332 rewriter.replaceOp(op, partialProducts);
/// Pass that lowers datapath dialect operations to comb logic. The base
/// class is generated via GEN_PASS_DEF_CONVERTDATAPATHTOCOMB, presumably
/// together with options such as lowerCompressToAdd.
343struct ConvertDatapathToCombPass
344 :
public impl::ConvertDatapathToCombBase<ConvertDatapathToCombPass> {
345 void runOnOperation()
override;
// Inherit the generated base-class constructors.
346 using ConvertDatapathToCombBase<
347 ConvertDatapathToCombPass>::ConvertDatapathToCombBase;
352 Operation *op, RewritePatternSet &&
patterns,
356 mlir::GreedyRewriteConfig config;
// Greedy-driver settings for timing-aware lowering:
//  * at most two driver iterations;
//  * `analysis` is registered as the rewrite listener, presumably so its
//    incremental longest-path data tracks ops as they are replaced;
//  * top-down traversal, presumably so producers are lowered (and timed)
//    before their users.
361 config.setMaxIterations(2).setListener(analysis).setUseTopDownTraversal(
true);
// applyPatternsGreedily reports failure when the patterns do not converge
// within the iteration limit.
364 if (failed(mlir::applyPatternsGreedily(op, std::move(
patterns), config)))
/// Register the datapath lowering patterns, apply them with timing
/// information, and fail the pass if any datapath operation survives.
370void ConvertDatapathToCombPass::runOnOperation() {
371 RewritePatternSet
patterns(&getContext());
// Longest-path timing analysis. It is passed both to the timing-driven
// compress pattern and to the pattern driver call below.
377 analysis = &getAnalysis<aig::IncrementalLongestPathAnalysis>();
// NOTE(review): confirm which compress lowering each branch registers.
// Presumably lowerCompressToAdd selects the plain-adder lowering, and the
// other path uses the timing-driven DatapathCompressOpConversion.
378 if (lowerCompressToAdd)
383 patterns.add<DatapathCompressOpConversion>(
patterns.getContext(), analysis);
386 getOperation(), std::move(
patterns), analysis)))
387 return signalPassFailure();
// Safety net: any remaining datapath-dialect op means a pattern did not
// apply. Report the first one found and fail the pass.
392 auto result = getOperation()->walk([&](Operation *op) {
393 if (llvm::isa_and_nonnull<datapath::DatapathDialect>(op->getDialect())) {
394 op->emitError(
"Datapath operation not converted: ") << *op;
395 return WalkResult::interrupt();
397 return WalkResult::advance();
399 if (result.wasInterrupted())
400 return signalPassFailure();
assert(baseType &&"element must be base type")
static SmallVector< Value > extractBits(OpBuilder &builder, Value val)
static SmallVector< Value > extractBits(OpBuilder &builder, Value val)
static LogicalResult applyPatternsGreedilyWithTimingInfo(Operation *op, RewritePatternSet &&patterns, aig::IncrementalLongestPathAnalysis *analysis)
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.