14#include "mlir/Analysis/TopologicalSortUtils.h"
15#include "mlir/Dialect/Func/IR/FuncOps.h"
16#include "mlir/IR/PatternMatch.h"
17#include "mlir/Pass/Pass.h"
18#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
19#include "llvm/Support/Debug.h"
20#include "llvm/Support/KnownBits.h"
23#define DEBUG_TYPE "datapath-to-comb"
26#define GEN_PASS_DEF_CONVERTDATAPATHTOCOMB
27#include "circt/Conversion/Passes.h.inc"
31using namespace datapath;
// Splits `val` into single-bit values using comb::extractBits and collects
// them in a SmallVector.
// NOTE(review): the return statement and closing brace are not visible in
// this excerpt; presumably `bits` is returned. Callers pair bits[i] with
// KnownBits bit i, which suggests index 0 is the LSB — confirm against
// comb::extractBits.
34static SmallVector<Value>
extractBits(OpBuilder &builder, Value val) {
35 SmallVector<Value> bits;
36 comb::extractBits(builder, val, bits);
// Lowers a CompressOp to a single comb.add over all of its operands. The sum
// becomes the last result and every other result is `zeroOp`.
// NOTE(review): `zeroOp` is defined on a line not shown here; presumably a
// zero constant of `width` bits. This replacement only preserves meaning if
// consumers add the results together — confirm against the CompressOp
// definition.
51 matchAndRewrite(CompressOp op,
52 mlir::PatternRewriter &rewriter)
const override {
53 Location loc = op.getLoc();
54 auto inputs = op.getOperands();
// Bit width read from the first operand only (presumably all operands share
// it).
55 unsigned width = inputs[0].getType().getIntOrFloatBitWidth();
// NOTE(review): the trailing `true` is presumably comb.add's two-state flag.
57 auto addOp = comb::AddOp::create(rewriter, loc, inputs,
true);
// (numResults - 1) copies of the zero value, followed by the full sum.
60 SmallVector<Value> results(op.getNumResults() - 1, zeroOp);
61 results.push_back(addOp);
62 rewriter.replaceOp(op, results);
// Timing-driven lowering of CompressOp. The operands are gathered into
// `addends`, and a compressor object `comp` reduces them to
// op.getNumResults() rows. `comp` is constructed on lines not shown here.
// Input delays for `comp` come from `analysis`.
69 DatapathCompressOpConversion(MLIRContext *
context,
74 matchAndRewrite(CompressOp op,
75 mlir::PatternRewriter &rewriter)
const override {
76 Location loc = op.getLoc();
77 auto inputs = op.getOperands();
// One inner vector per operand; the loop body filling it is not shown
// (presumably the operand's bits — confirm).
79 SmallVector<SmallVector<Value>> addends;
80 for (
auto input : inputs) {
86 auto width = inputs[0].getType().getIntOrFloatBitWidth();
// Number of rows the compressor must reduce to: one per result.
87 auto targetAddends = op.getNumResults();
// Supply per-value input delays from the longest-path analysis.
// NOTE(review): the meaning of the second getMaxDelay argument (0) is not
// visible here — presumably a bit index; confirm.
92 if (failed(comp.withInputDelays(
93 [&](Value v) { return analysis->getMaxDelay(v, 0); })))
97 rewriter.replaceOp(op, comp.compressToHeight(rewriter, targetAddends));
// Lowers datapath PartialProductOp (the partial-product rows of lhs * rhs)
// to comb logic. matchAndRewrite chooses among a constant-zero result, a
// squaring array, a Booth-encoded array and a plain AND array.
105struct DatapathPartialProductOpConversion :
OpRewritePattern<PartialProductOp> {
108 DatapathPartialProductOpConversion(MLIRContext *
context,
bool forceBooth)
// When set, matchAndRewrite picks lowerBoothArray over lowerAndArray
// regardless of the result-count heuristic.
111 const bool forceBooth;
// Chooses how to materialize the partial products of lhs * rhs:
//  - an all-zero constant, under a guard on lines not shown here;
//  - lowerSqrAndArray, under a guard not shown (presumably lhs == rhs —
//    confirm);
//  - lowerBoothArray when there are more than 16 results or forceBooth is set;
//  - lowerAndArray otherwise.
113 LogicalResult matchAndRewrite(PartialProductOp op,
114 PatternRewriter &rewriter)
const override {
116 Value
a = op.getLhs();
117 Value
b = op.getRhs();
118 unsigned width =
a.getType().getIntOrFloatBitWidth();
122 rewriter.replaceOpWithNewOp<
hw::ConstantOp>(op, op.getType(0), 0);
139 return lowerSqrAndArray(rewriter, a, op, width);
// Heuristic: Booth encoding produces one row per two multiplier bits, which
// pays off for larger arrays.
143 if (op.getNumResults() > 16 || forceBooth)
144 return lowerBoothArray(rewriter, a, b, op, width);
146 return lowerAndArray(rewriter, a, b, op, width);
// Plain AND-array partial products: row i is (b[i] ? a : 0), shifted left by
// i and truncated to `width` bits. Exactly op.getNumResults() rows are built,
// and the assert below requires that count to be at most `width`.
150 static LogicalResult lowerAndArray(PatternRewriter &rewriter, Value a,
151 Value b, PartialProductOp op,
154 Location loc = op.getLoc();
156 SmallVector<Value> bBits =
extractBits(rewriter, b);
// Leading bits of `a` that are known zero contribute nothing. Rows are
// therefore built only rowWidth bits wide and zero-padded back to `width`.
// NOTE(review): the `> 1` guard means a single known leading zero is not
// exploited. `a` is presumably truncated to rowWidth on lines not shown, so
// the AND below has matching widths — confirm.
158 auto rowWidth = width;
159 auto knownBitsA = comb::computeKnownBits(a);
160 if (!knownBitsA.Zero.isZero()) {
161 if (knownBitsA.Zero.countLeadingOnes() > 1) {
162 rowWidth -= knownBitsA.Zero.countLeadingOnes();
167 SmallVector<Value> partialProducts;
168 partialProducts.reserve(width);
171 assert(op.getNumResults() <= width &&
172 "Cannot return more results than the operator width");
174 for (
unsigned i = 0; i < op.getNumResults(); ++i) {
// Gate every bit of `a` with multiplier bit b[i].
176 rewriter.createOrFold<comb::ReplicateOp>(loc, bBits[i], rowWidth);
177 auto ppRow = rewriter.createOrFold<
comb::AndOp>(loc, repl,
a);
// Restore full width by prepending zero bits above the row.
178 if (rowWidth < width) {
179 auto padding = width - rowWidth;
182 loc, ValueRange{
zeroPad, ppRow});
// Pushed unshifted; presumably the i == 0 case (guard not shown).
186 partialProducts.push_back(ppRow);
// Shift left by appending the low zero bits `shiftBy` (defined on lines not
// shown), then keep the low `width` bits.
191 comb::ConcatOp::create(rewriter, loc, ValueRange{ppRow, shiftBy});
193 loc, ppAlign, 0, width);
194 partialProducts.push_back(ppAlignTrunc);
197 rewriter.replaceOp(op, partialProducts);
// Partial products for a * a, exploiting symmetry:
//  - the diagonal term a[i] * a[i] == a[i] sits at bit 2i;
//  - each pair a[i]*a[j] + a[j]*a[i] (j > i) equals 2 * (a[i] & a[j]), so it
//    is emitted once, one position higher, at bit i + j + 1.
// Row i therefore starts at bit 2i, and rows starting at or beyond `width`
// are all zero. This layout assumes `shiftBy` is 2*i zero bits (defined on
// lines not shown) — confirm.
201 static LogicalResult lowerSqrAndArray(PatternRewriter &rewriter, Value a,
202 PartialProductOp op,
unsigned width) {
204 Location loc = op.getLoc();
205 SmallVector<Value> aBits =
extractBits(rewriter, a);
207 SmallVector<Value> partialProducts;
208 partialProducts.reserve(width);
212 assert(op.getNumResults() <= width &&
213 "Cannot return more results than the operator width");
215 for (
unsigned i = 0; i < op.getNumResults(); ++i) {
// Row bits are appended from least- to most-significant.
216 SmallVector<Value> row;
219 if (2 * i >= width) {
222 partialProducts.push_back(zeroWidth);
228 row.push_back(shiftBy);
// Diagonal term at bit 2i.
230 row.push_back(aBits[i]);
// Bit 2i + 1 is always zero (no cross term lands there for this row).
233 unsigned rowWidth = 2 * i + 1;
234 if (rowWidth < width) {
235 row.push_back(zeroFalse);
// Cross terms a[i] & a[j] for j > i, until the row reaches `width` bits.
239 for (
unsigned j = i + 1; j < width; ++j) {
241 if (rowWidth == width)
// NOTE(review): bits a[j] with j >= numResults are treated as zero;
// presumably the op guarantees `a` has at most numResults significant
// bits — confirm.
247 if (j >= op.getNumResults()) {
248 row.push_back(zeroFalse);
253 rewriter.createOrFold<
comb::AndOp>(loc, aBits[i], aBits[j]);
254 row.push_back(ppBit);
// comb.concat takes its most-significant operand first, so reverse the
// LSB-first list before concatenating.
256 std::reverse(row.begin(), row.end());
257 auto ppRow = comb::ConcatOp::create(rewriter, loc, row);
258 partialProducts.push_back(ppRow);
261 rewriter.replaceOp(op, partialProducts);
// Radix-4 Booth-encoded partial products for a * b. For each even i, the
// multiplier bits (b[i+1], b[i], b[i-1]) select a multiple of `a` from
// {0, a, 2a}, optionally negated. Out-of-range bits read as 0.
//   encOne = b[i] ^ b[i-1]
//   encTwo = (~b[i+1] & b[i] & b[i-1]) | (b[i+1] & ~b[i] & ~b[i-1])
//   encNeg is computed on lines not shown (presumably b[i+1] — confirm).
// A negated row is the bitwise complement of the magnitude. The missing +1
// is supplied through the encNegPrev bit placed in the following row.
265 static LogicalResult lowerBoothArray(PatternRewriter &rewriter, Value a,
266 Value b, PartialProductOp op,
268 Location loc = op.getLoc();
// Known leading zeros of `a` shrink the rows. One bit is kept back so that 2a
// still fits in rowWidth bits.
275 auto rowWidth = width;
276 auto knownBitsA = comb::computeKnownBits(a);
277 if (!knownBitsA.Zero.isZero()) {
278 if (knownBitsA.Zero.countLeadingOnes() > 1) {
281 rowWidth -= knownBitsA.Zero.countLeadingOnes() - 1;
// 2a computed as a left shift by one.
// NOTE(review): `a` is presumably truncated to rowWidth, and `oneRowWidth` is
// a rowWidth-bit constant 1; both are defined on lines not shown — confirm.
290 Value twoA = rewriter.createOrFold<
comb::ShlOp>(loc,
a, oneRowWidth);
294 SmallVector<Value> bBits =
extractBits(rewriter, b);
// Bits of `b` known to be zero are replaced with a constant false, so the
// encoder logic below can fold. bWidth drops the known-zero leading bits.
297 auto bWidth =
b.getType().getIntOrFloatBitWidth();
298 auto knownBitsB = comb::computeKnownBits(b);
299 if (!knownBitsB.Zero.isZero()) {
300 bWidth -= knownBitsB.Zero.countLeadingOnes();
301 for (
unsigned i = 0; i < width; ++i)
302 if (knownBitsB.Zero[i])
303 bBits[i] = zeroFalse;
306 SmallVector<Value> partialProducts;
307 partialProducts.reserve(width);
// encNeg for each emitted group, reused later for sign extension.
314 SmallVector<Value> encNegs;
// The loop runs while i <= width, so the final group always sees
// b[i+1] == 0. This presumably treats `b` as unsigned — confirm.
318 for (
unsigned i = 0; i <= width; i += 2) {
320 Value bim1 = (i == 0) ? zeroFalse : bBits[i - 1];
321 Value bi = (i < width) ? bBits[i] : zeroFalse;
322 Value bip1 = (i + 1 < width) ? bBits[i + 1] : zeroFalse;
326 encNegs.push_back(encNeg);
328 Value encOne = rewriter.createOrFold<
comb::XorOp>(loc, bi, bim1,
true);
// Bit inversion computed as x ^ constOne (constOne is presumably a 1-bit
// constant 1 defined on lines not shown).
331 Value biInv = rewriter.createOrFold<
comb::XorOp>(loc, bi, constOne,
true);
333 rewriter.createOrFold<
comb::XorOp>(loc, bip1, constOne,
true);
335 rewriter.createOrFold<
comb::XorOp>(loc, bim1, constOne,
true);
// encTwo: group is 011 or 100.
337 Value andLeft = rewriter.createOrFold<
comb::AndOp>(
338 loc, ValueRange{bip1Inv, bi, bim1},
true);
339 Value andRight = rewriter.createOrFold<
comb::AndOp>(
340 loc, ValueRange{bip1, biInv, bim1Inv},
true);
342 rewriter.createOrFold<
comb::OrOp>(loc, andLeft, andRight,
true);
345 rewriter.createOrFold<comb::ReplicateOp>(loc, encNeg, rowWidth);
347 rewriter.createOrFold<comb::ReplicateOp>(loc, encOne, rowWidth);
349 rewriter.createOrFold<comb::ReplicateOp>(loc, encTwo, rowWidth);
// magnitude = (encTwo ? 2a : 0) | (encOne ? a : 0). encTwo requires
// b[i] == b[i-1], so encOne is 0 whenever encTwo is set, and the OR selects
// a single multiple.
352 Value selTwoA = rewriter.createOrFold<
comb::AndOp>(loc, encTwoRepl, twoA);
353 Value selOneA = rewriter.createOrFold<
comb::AndOp>(loc, encOneRepl,
a);
355 rewriter.createOrFold<
comb::OrOp>(loc, selTwoA, selOneA,
true);
// Complement every bit when encNeg is set.
359 rewriter.createOrFold<
comb::XorOp>(loc, magA, encNegRepl,
true);
363 partialProducts.push_back(ppRow);
// comb.concat lists operands most-significant first. From the LSB up, a row
// is: alignment zeros `shiftBy` (when present; defined on lines not shown),
// the previous row's +1 correction encNegPrev, one zero bit, then the row.
370 loc, ValueRange{ppRow, zeroFalse, encNegPrev});
371 partialProducts.push_back(withSignCorrection);
380 loc, ValueRange{ppRow, zeroFalse, encNegPrev, shiftBy});
381 partialProducts.push_back(withSignCorrection);
// Stop condition (body not shown): enough rows for every result.
384 if (partialProducts.size() == op.getNumResults())
// Sign-extend each row to `width` with copies of its encNeg bit. The
// magnitude is unsigned, so padding mag with zeros or ~mag with ones gives
// the correct value modulo 2^width. Rows wider than `width` are truncated
// (lines not shown).
401 for (
unsigned i = 0; i < partialProducts.size(); ++i) {
402 auto ppRow = partialProducts[i];
403 auto encNeg = encNegs[i];
404 auto ppWidth = ppRow.getType().getIntOrFloatBitWidth();
405 if (ppWidth < width) {
406 auto padding = width - ppWidth;
410 rewriter.createOrFold<comb::ReplicateOp>(loc, encNeg, padding);
412 loc, ValueRange{encNegPad, ppRow});
416 ppWidth = ppRow.getType().getIntOrFloatBitWidth();
417 if (ppWidth > width) {
420 partialProducts[i] = ppRow;
421 assert(partialProducts[i].getType().getIntOrFloatBitWidth() == width &&
422 "Expected sign-extended partial product to be full width");
// Fill any remaining results with all-zero rows.
426 while (partialProducts.size() < op.getNumResults())
427 partialProducts.push_back(zeroWidth);
429 assert(partialProducts.size() == op.getNumResults() &&
430 "Expected number of booth partial products to match results");
432 rewriter.replaceOp(op, partialProducts);
// Lowers datapath PosPartialProductOp to comb logic: partial products of
// (addend0 + addend1) * multiplicand, as read in matchAndRewrite.
437struct DatapathPosPartialProductOpConversion
441 DatapathPosPartialProductOpConversion(MLIRContext *
context,
bool forceBooth)
443 forceBooth(forceBooth){};
// NOTE(review): in the visible code forceBooth is stored but never read;
// matchAndRewrite only dispatches to lowerAndArray.
445 const bool forceBooth;
// Produces either an all-zero constant (under a guard on lines not shown) or
// the carry/save AND array from lowerAndArray.
447 LogicalResult matchAndRewrite(PosPartialProductOp op,
448 PatternRewriter &rewriter)
const override {
450 Value
a = op.getAddend0();
451 Value
b = op.getAddend1();
452 Value c = op.getMultiplicand();
453 unsigned width =
a.getType().getIntOrFloatBitWidth();
457 rewriter.replaceOpWithNewOp<
hw::ConstantOp>(op, op.getType(0), 0);
462 return lowerAndArray(rewriter, a, b, c, op, width);
// Partial products for (a + b) * c. The sum is used in the form
// save + 2 * carry, where `save` and `carry` are built from a and b on lines
// not shown. Row i is then
//   2^i * ((save[i] ? c : 0) | (carry[i] ? 2c : 0)).
// NOTE(review): the OR equals addition only if save[i] and carry[i] are never
// both set (e.g. save = a ^ b and carry = a & b) — confirm.
466 static LogicalResult lowerAndArray(PatternRewriter &rewriter, Value a,
467 Value b, Value c, PosPartialProductOp op,
470 Location loc = op.getLoc();
476 SmallVector<Value> carryBits =
extractBits(rewriter, carry);
477 SmallVector<Value> saveBits =
extractBits(rewriter, save);
// Known leading zeros of `c` shrink the rows. One bit is kept back so that 2c
// still fits in rowWidth bits.
480 auto rowWidth = width;
481 auto knownBitsC = comb::computeKnownBits(c);
482 if (!knownBitsC.Zero.isZero()) {
483 if (knownBitsC.Zero.countLeadingOnes() > 1) {
486 rowWidth -= knownBitsC.Zero.countLeadingOnes() - 1;
// 2c formed by appending a low zero bit to c.
// NOTE(review): c and twoC presumably end up rowWidth bits wide via lines not
// shown, so the ANDs below have matching widths — confirm.
494 comb::ConcatOp::create(rewriter, loc, ValueRange{c, zero});
499 SmallVector<Value> partialProducts;
500 partialProducts.reserve(width);
502 assert(op.getNumResults() <= width &&
503 "Cannot return more results than the operator width");
505 for (
unsigned i = 0; i < op.getNumResults(); ++i) {
507 rewriter.createOrFold<comb::ReplicateOp>(loc, saveBits[i], rowWidth);
509 rewriter.createOrFold<comb::ReplicateOp>(loc, carryBits[i], rowWidth);
511 auto ppRowSave = rewriter.createOrFold<
comb::AndOp>(loc, replSave, c);
513 rewriter.createOrFold<
comb::AndOp>(loc, replCarry, twoC);
515 rewriter.createOrFold<
comb::OrOp>(loc, ppRowSave, ppRowCarry);
// Align row i: shift left by appending `shiftBy` low zeros (defined on lines
// not shown). Then truncate to `width` if the row is too wide, or zero-pad
// above it if too narrow.
516 auto ppAlign = ppRow;
520 comb::ConcatOp::create(rewriter, loc, ValueRange{ppRow, shiftBy});
524 if (rowWidth + i > width) {
527 partialProducts.push_back(ppAlignTrunc);
531 if (rowWidth + i < width) {
532 auto padding = width - rowWidth - i;
536 loc, ValueRange{zeroPad, ppAlign}));
540 partialProducts.push_back(ppAlign);
543 rewriter.replaceOp(op, partialProducts);
// Pass that converts datapath dialect operations into comb operations.
// Pattern setup and the final legality check live in runOnOperation.
555struct ConvertDatapathToCombPass
556 :
public impl::ConvertDatapathToCombBase<ConvertDatapathToCombPass> {
557 void runOnOperation()
override;
// Reuse the constructors of the tablegen-generated pass base class.
558 using ConvertDatapathToCombBase<
559 ConvertDatapathToCombPass>::ConvertDatapathToCombBase;
564 Operation *op, RewritePatternSet &&
patterns,
// Greedy-driver configuration for the timing-aware lowering.
568 mlir::GreedyRewriteConfig config;
// Configure at most two iterations and top-down traversal, and register
// `analysis` as the rewriter listener.
// NOTE(review): presumably the listener lets the incremental longest-path
// analysis track IR changes while patterns query delays — confirm.
573 config.setMaxIterations(2).setListener(analysis).setUseTopDownTraversal(
true);
576 if (failed(mlir::applyPatternsGreedily(op, std::move(
patterns), config)))
// Populates the datapath-to-comb lowering patterns and applies them with the
// timing-aware greedy driver. Then fails the pass if any datapath dialect
// operation is left unconverted.
582void ConvertDatapathToCombPass::runOnOperation() {
583 RewritePatternSet
patterns(&getContext());
585 patterns.add<DatapathPartialProductOpConversion,
586 DatapathPosPartialProductOpConversion>(
patterns.getContext(),
590 analysis = &getAnalysis<synth::IncrementalLongestPathAnalysis>();
// CompressOp lowering choice. The body of this branch is not shown; it
// presumably registers the CompressOp-to-add pattern. The timing-driven
// compressor pattern, which uses `analysis`, is added below.
591 if (lowerCompressToAdd)
596 patterns.add<DatapathCompressOpConversion>(
patterns.getContext(), analysis);
599 getOperation(), std::move(
patterns), analysis)))
600 return signalPassFailure();
// Legality check: report the first remaining datapath op (the walk is
// interrupted there) and fail the pass.
605 auto result = getOperation()->walk([&](Operation *op) {
606 if (llvm::isa_and_nonnull<datapath::DatapathDialect>(op->getDialect())) {
607 op->emitError(
"Datapath operation not converted: ") << *op;
608 return WalkResult::interrupt();
610 return WalkResult::advance();
612 if (result.wasInterrupted())
613 return signalPassFailure();
assert(baseType &&"element must be base type")
static SmallVector< Value > extractBits(OpBuilder &builder, Value val)
static Value zeroPad(PatternRewriter &rewriter, Location loc, Value input, size_t targetWidth, size_t trailingZeros)
static SmallVector< Value > extractBits(OpBuilder &builder, Value val)
static LogicalResult applyPatternsGreedilyWithTimingInfo(Operation *op, RewritePatternSet &&patterns, synth::IncrementalLongestPathAnalysis *analysis)
static std::unique_ptr< Context > context
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.