31#include "mlir/Pass/Pass.h"
32#include "mlir/Transforms/DialectConversion.h"
33#include "llvm/ADT/APInt.h"
34#include "llvm/ADT/PointerUnion.h"
35#include "llvm/Support/Debug.h"
36#include "llvm/Support/DivisionByConstantInfo.h"
39#define DEBUG_TYPE "comb-to-synth"
42#define GEN_PASS_DEF_CONVERTCOMBTOSYNTH
43#include "circt/Conversion/Passes.h.inc"
// Convenience wrapper around comb::extractBits that collects its output into
// a freshly created local vector.
// NOTE(review): the listing stops before the return statement and closing
// brace; presumably `bits` is returned -- confirm against the full file.
54static SmallVector<Value>
extractBits(OpBuilder &builder, Value val) {
55 SmallVector<Value> bits;
56 comb::extractBits(builder, val, bits);
// Shared helper for lowering a variable shift (named createShiftLogic in the
// file's symbol index). For every shift amount i in [0, maxShiftAmount) it asks
// the two callbacks for the extracted part and the padding part, collects one
// candidate value per shift amount, and selects among them with a mux tree.
// NOTE(review): the first signature line and several body lines are missing
// from this listing; `bits` and `result` are defined in omitted lines.
67template <
bool isLeftShift>
69 Value shiftAmount, int64_t maxShiftAmount,
70 llvm::function_ref<Value(int64_t)> getPadding,
71 llvm::function_ref<Value(int64_t)> getExtract) {
// One candidate node per in-range shift amount.
76 SmallVector<Value> nodes;
77 nodes.reserve(maxShiftAmount);
78 for (int64_t i = 0; i < maxShiftAmount; ++i) {
79 Value extract = getExtract(i);
80 Value padding = getPadding(i);
// NOTE(review): the lines guarding this push_back are omitted; presumably it
// handles the case where there is no padding to combine -- confirm.
83 nodes.push_back(extract);
// Value used when the shift amount is >= maxShiftAmount: padding only.
97 auto outOfBoundsValue = getPadding(maxShiftAmount);
98 assert(outOfBoundsValue &&
"outOfBoundsValue must be valid");
102 comb::constructMuxTree(rewriter, loc, bits, nodes, outOfBoundsValue);
// Unsigned-less-than comparison on the shift amount; the final mux below uses
// it to choose between `result` and another operand not visible here.
105 auto inBound = rewriter.createOrFold<comb::ICmpOp>(
106 loc, ICmpPredicate::ult, shiftAmount,
110 return rewriter.createOrFold<
comb::MuxOp>(loc, inBound, result,
117 Value b, Value carry) {
// Builds (carry & (a ^ b)) | (a & b) out of comb ops, i.e. the majority of
// the three inputs (the symbol index names this createMajorityFunction).
// Despite its name, `aXnorB` is created with comb::XorOp, so it holds a ^ b.
// NOTE(review): the trailing `true` argument is presumably the two-state
// flag of the comb op builders -- confirm.
119 auto aXnorB = comb::XorOp::create(rewriter, loc, ValueRange{a, b},
true);
// NOTE(review): the line binding this AndOp to `andOp` is missing from the
// listing; `andOp` is used in the final OrOp below.
121 comb::AndOp::create(rewriter, loc, ValueRange{carry, aXnorB},
true);
122 auto aAndB = comb::AndOp::create(rewriter, loc, ValueRange{a, b},
true);
123 return comb::OrOp::create(rewriter, loc, ValueRange{andOp, aAndB},
true);
128 val.getLoc(), val, val.getType().getIntOrFloatBitWidth() - 1, 1);
133 val.getLoc(), val, 0, val.getType().getIntOrFloatBitWidth() - 1);
// Holds either an SSA Value or a constant mlir::IntegerAttr. Used below to
// describe an operand as a sequence of known-constant and unknown pieces.
138using ConstantOrValue = llvm::PointerUnion<Value, mlir::IntegerAttr>;
143 Value value, llvm::SmallVectorImpl<ConstantOrValue> &values) {
// Appends the pieces of `value` to `values` and returns how many of its bits
// are not known constants (named getNumUnknownBitsAndPopulateValues in the
// symbol index, returning int64_t).
// NOTE(review): many lines are missing from this listing, including the
// bodies of the i0 check and the concat/constant matches; the summary below
// is limited to the visible statements.
145 if (value.getType().isInteger(0))
150 int64_t totalUnknownBits = 0;
// Concat inputs are visited in reverse order and their unknown-bit counts are
// accumulated.
151 for (
auto concatInput : llvm::reverse(concat.getInputs())) {
156 totalUnknownBits += unknownBits;
158 return totalUnknownBits;
// A constant contributes its IntegerAttr.
163 values.push_back(constant.getValueAttr());
// Anything else is recorded as an opaque Value; all of its bits are unknown.
169 values.push_back(value);
170 return hw::getBitWidth(value.getType());
176 llvm::SmallVectorImpl<ConstantOrValue> &constantOrValues,
// Builds a concrete APInt of `width` bits from a list of constant/unknown
// pieces: constants are copied in as-is, and each unknown piece takes its
// bits from `mask` (named substitueMaskToValues in the symbol index).
// `bitPos` is the insertion position in the result; `unknownPos` counts how
// many mask bits have been consumed.
// NOTE(review): the lines declaring `elemWidth` and advancing `bitPos` are
// missing from this listing.
178 uint32_t bitPos = 0, unknownPos = 0;
179 APInt result(width, 0);
180 for (
auto constantOrValue : constantOrValues) {
182 if (
auto constant = dyn_cast<IntegerAttr>(constantOrValue)) {
183 elemWidth = constant.getValue().getBitWidth();
184 result.insertBits(constant.getValue(), bitPos);
186 elemWidth = hw::getBitWidth(cast<Value>(constantOrValue).getType());
187 assert(elemWidth >= 0 &&
"unknown bit width");
// The mask is a uint32_t, so the consumed unknown bits must stay below 32.
188 assert(elemWidth + unknownPos < 32 &&
"unknown bit width too large");
// Take the next `elemWidth` bits of the mask, starting at `unknownPos`.
190 uint32_t usedBits = (mask >> unknownPos) & ((1 << elemWidth) - 1);
191 result.insertBits(APInt(elemWidth, usedBits), bitPos);
192 unknownPos += elemWidth;
204 ConversionPatternRewriter &rewriter, int64_t maxEmulationUnknownBits,
206 llvm::function_ref<APInt(
const APInt &,
const APInt &)> emulate) {
// Lowers a binary op by enumeration: every assignment of the unknown operand
// bits is evaluated with `emulate`, each result becomes a constant, and a mux
// tree keyed on the unknown values selects among them (named
// emulateBinaryOpForUnknownBits in the symbol index, returning LogicalResult).
// NOTE(review): several lines are missing from this listing, e.g. the
// computation of numLhsUnknownBits/numRhsUnknownBits, the bodies of the two
// early-exit checks, and the `getConstant` lambda header.
207 SmallVector<ConstantOrValue> lhsValues, rhsValues;
209 assert(op->getNumResults() == 1 && op->getNumOperands() == 2 &&
210 "op must be a single result binary operation");
212 auto lhs = op->getOperand(0);
213 auto rhs = op->getOperand(1);
214 auto width = op->getResult(0).getType().getIntOrFloatBitWidth();
215 auto loc = op->getLoc();
// A negative count is treated as a special case (handling not visible here).
220 if (numLhsUnknownBits < 0 || numRhsUnknownBits < 0)
// The table below has 2^totalUnknownBits entries, so the count is capped.
223 int64_t totalUnknownBits = numLhsUnknownBits + numRhsUnknownBits;
224 if (totalUnknownBits > maxEmulationUnknownBits)
227 SmallVector<Value> emulatedResults;
228 emulatedResults.reserve(1 << totalUnknownBits);
// Caches constants by attribute so equal results share one hw::ConstantOp.
231 DenseMap<IntegerAttr, hw::ConstantOp> constantPool;
233 auto attr = rewriter.getIntegerAttr(rewriter.getIntegerType(width), value);
234 auto it = constantPool.find(attr);
235 if (it != constantPool.end())
238 constantPool[attr] = constant;
// Enumerate every lhs assignment (outer loop) and rhs assignment (inner loop).
242 for (uint32_t lhsMask = 0, lhsMaskEnd = 1 << numLhsUnknownBits;
243 lhsMask < lhsMaskEnd; ++lhsMask) {
245 for (uint32_t rhsMask = 0, rhsMaskEnd = 1 << numRhsUnknownBits;
246 rhsMask < rhsMaskEnd; ++rhsMask) {
249 emulatedResults.push_back(
getConstant(emulate(lhsValue, rhsValue)));
// Selectors are gathered from the rhs pieces first, then the lhs pieces.
254 SmallVector<Value> selectors;
255 selectors.reserve(totalUnknownBits);
256 for (
auto &concatedValues : {rhsValues, lhsValues})
257 for (
auto valueOrConstant : concatedValues) {
258 auto value = dyn_cast<Value>(valueOrConstant);
264 assert(totalUnknownBits ==
static_cast<int64_t
>(selectors.size()) &&
265 "number of selectors must match");
266 auto muxed = constructMuxTree(rewriter, loc, selectors, emulatedResults,
281 APInt(value.getType().getIntOrFloatBitWidth(), amount)));
292 APInt(value.getType().getIntOrFloatBitWidth(), amount)));
// Computes the upper `width` bits of the 2*width-bit product of `lhs` and the
// constant `rhs` (named createMulHigh in the symbol index). `isSigned`
// selects sign- or zero-extension of both operands before multiplying.
// NOTE(review): the signature line and the lines binding `wideRhs` and
// `product` are missing from this listing.
295template <
bool isSigned>
298 unsigned width = lhs.getType().getIntOrFloatBitWidth();
299 auto destTy = builder.getIntegerType(width << 1);
// Widen the runtime operand to twice its width.
302 Value wideLhs = isSigned ? comb::createOrFoldSExt(builder, loc, lhs, destTy)
303 : comb::createZExt(builder, loc, lhs, width << 1);
// Widen the constant operand the same way.
305 builder, loc, isSigned ? rhs.sext(width << 1) : rhs.zext(width << 1));
307 loc, ValueRange{wideLhs, wideRhs},
true);
// Extract `width` bits starting at bit index `width`, i.e. the high half.
308 return builder.createOrFold<
comb::ExtractOp>(loc, product, width, width);
312 Value lhs,
const APInt &divisor) {
// Unsigned division by a constant using the magic number from
// llvm::UnsignedDivisionByConstantInfo (named lowerUnsignedDivByConstant in
// the symbol index).
// NOTE(review): the pre-shift, the condition guarding the sub/add fix-up, and
// the post-shift are missing from this listing.
313 auto info = llvm::UnsignedDivisionByConstantInfo::get(divisor);
// q = high half of (q * Magic).
315 q = createMulHigh<false>(builder, loc, q, info.Magic);
// Fix-up: combine q with (lhs - q).
317 Value diff = builder.createOrFold<
comb::SubOp>(loc, lhs, q);
319 q = builder.createOrFold<
comb::AddOp>(loc, q, diff);
325 Value lhs,
const APInt &divisor) {
// Signed division by a constant using the magic number from
// llvm::SignedDivisionByConstantInfo (named lowerSignedDivByConstant in the
// symbol index).
// NOTE(review): the shift step between the correction and the sign fix-up is
// missing from this listing.
326 unsigned width = lhs.getType().getIntOrFloatBitWidth();
327 auto info = llvm::SignedDivisionByConstantInfo::get(divisor);
328 Value q = createMulHigh<true>(builder, loc, lhs, info.Magic);
// Correct q when the divisor and the magic constant have opposite signs.
331 if (divisor.isStrictlyPositive() && info.Magic.isNegative())
332 q = builder.createOrFold<
comb::AddOp>(loc, q, lhs);
333 else if (divisor.isNegative() && info.Magic.isStrictlyPositive())
334 q = builder.createOrFold<
comb::SubOp>(loc, q, lhs);
// Add the sign bit of q (zero-extended to `width`) to q, i.e. add one when q
// is negative.
338 Value signBit = builder.createOrFold<
comb::ExtractOp>(loc, q, width - 1, 1);
339 Value signPadded = comb::createZExt(builder, loc, signBit, width);
340 return builder.createOrFold<
comb::AddOp>(loc, q, signPadded);
354 matchAndRewrite(
AndOp op, OpAdaptor adaptor,
355 ConversionPatternRewriter &rewriter)
const override {
// Replaces comb.and with a synth::aig::AndInverterOp over the same inputs,
// with every inversion flag set to false.
// NOTE(review): the enclosing struct header is missing from this listing;
// presumably this is CombAndOpConversion from the pattern list -- confirm.
356 SmallVector<bool> nonInverts(adaptor.getInputs().size(),
false);
357 replaceOpWithNewOpAndCopyNamehint<synth::aig::AndInverterOp>(
358 rewriter, op, adaptor.getInputs(), nonInverts);
368 matchAndRewrite(
OrOp op, OpAdaptor adaptor,
369 ConversionPatternRewriter &rewriter)
const override {
// Lowers comb.or through AND-inverter nodes: first an AndInverterOp over all
// inputs with every input inverted, then a second AndInverterOp that replaces
// the original op.
// NOTE(review): the arguments of the second op are missing from this listing;
// presumably it inverts `andOp` (De Morgan) -- confirm. The enclosing struct
// is presumably CombOrToAIGConversion from the pattern list.
371 SmallVector<bool> allInverts(adaptor.getInputs().size(),
true);
372 auto andOp = synth::aig::AndInverterOp::create(
373 rewriter, op.getLoc(), adaptor.getInputs(), allInverts);
374 replaceOpWithNewOpAndCopyNamehint<synth::aig::AndInverterOp>(
385 matchAndRewrite(
XorOp op, OpAdaptor adaptor,
386 ConversionPatternRewriter &rewriter)
const override {
// Replaces comb.xor with a synth::XorInverterOp over the same inputs, with
// every inversion flag set to false.
// NOTE(review): the enclosing struct header is missing from this listing;
// presumably this is CombXorOpToSynthConversion from the pattern list.
387 SmallVector<bool> inverted(adaptor.getInputs().size(),
false);
388 replaceOpWithNewOpAndCopyNamehint<synth::XorInverterOp>(
389 rewriter, op, adaptor.getInputs(), inverted);
// Rewrites a two-operand synth::XorInverterOp into synth::aig::AndInverterOp
// nodes.
// NOTE(review): the base class, the body of the operand-count check and the
// inversion flags of the final op are missing from this listing.
395struct SynthXorInverterOpConversion
400 matchAndRewrite(synth::XorInverterOp op, OpAdaptor adaptor,
401 ConversionPatternRewriter &rewriter)
const override {
// Only the binary form is handled by this pattern.
402 if (op.getNumOperands() != 2)
408 auto inputs = adaptor.getInputs();
// `allNotInverts` holds the op's own inversion flags; `allInverts` holds their
// negations.
409 auto allNotInverts = op.getInverted();
410 std::array<bool, 2> allInverts = {!allNotInverts[0], !allNotInverts[1]};
// Two AND terms: one built from the negated flags (arguments partly omitted
// in this listing) and one from the original flags.
412 auto notAAndNotB = synth::aig::AndInverterOp::create(rewriter, op.getLoc(),
414 auto aAndB = synth::aig::AndInverterOp::create(rewriter, op.getLoc(),
415 inputs, allNotInverts);
// The op is replaced by an AndInverterOp combining the two terms.
417 replaceOpWithNewOpAndCopyNamehint<synth::aig::AndInverterOp>(
418 rewriter, op, notAAndNotB, aAndB,
430 matchAndRewrite(
MuxOp op, OpAdaptor adaptor,
431 ConversionPatternRewriter &rewriter)
const override {
// Lowers comb.mux to synth::MuxInverterOp.
// NOTE(review): the enclosing struct header and several body lines are
// missing from this listing; presumably this is CombMuxOpToSynthConversion
// from the pattern list -- confirm.
432 Value cond = adaptor.getCond();
433 Value trueVal = adaptor.getTrueValue();
434 Value falseVal = adaptor.getFalseValue();
// Non-integer result types: an integer type of the same bit width is computed
// (its use is in omitted lines; presumably the operands are cast to it).
436 if (!op.getType().isInteger()) {
437 auto widthType = rewriter.getIntegerType(hw::getBitWidth(op.getType()));
// When the data operands are not i1, the condition is replicated to the type
// of the true value.
444 if (!trueVal.getType().isInteger(1))
445 cond = comb::ReplicateOp::create(rewriter, op.getLoc(), trueVal.getType(),
448 Value result = synth::MuxInverterOp::create(rewriter, op.getLoc(), cond,
// NOTE(review): the body of this check is omitted; presumably the result is
// converted back to the original type -- confirm.
451 if (result.getType() != op.getType())
// Rewrites synth::MuxInverterOp into synth::aig::AndInverterOp nodes.
// NOTE(review): the base class and the trailing arguments of the last two ops
// are missing from this listing.
461struct SynthMuxInverterOpConversion
466 matchAndRewrite(synth::MuxInverterOp op, OpAdaptor adaptor,
467 ConversionPatternRewriter &rewriter)
const override {
468 auto inputs = adaptor.getInputs();
469 auto inverted = op.getInverted();
// AND of input 0 and input 1, each with its own inversion flag.
471 auto lhs = synth::aig::AndInverterOp::create(
472 rewriter, op.getLoc(), inputs[0], inputs[1], inverted[0], inverted[1]);
// AND of input 0 with its inversion flag negated, and input 2.
474 auto rhs = synth::aig::AndInverterOp::create(
475 rewriter, op.getLoc(), inputs[0], inputs[2], !inverted[0], inverted[2]);
// The two terms are combined by a further AndInverterOp and the op is
// replaced by a final one; both argument lists are truncated here.
477 auto nand = synth::aig::AndInverterOp::create(rewriter, op.getLoc(), lhs,
479 replaceOpWithNewOpAndCopyNamehint<synth::aig::AndInverterOp>(rewriter, op,
// Pattern template, parameterised over the op type, that turns a variadic op
// into nested two-operand ops of the same kind.
// NOTE(review): the struct header, the body of matchAndRewrite and the
// helper's signature are missing from this listing; presumably this is
// CombLowerVariadicOp from the pattern list -- confirm.
485template <
typename OpTy>
490 matchAndRewrite(OpTy op, OpAdaptor adaptor,
491 ConversionPatternRewriter &rewriter)
const override {
498 ConversionPatternRewriter &rewriter) {
// Dispatch on the number of operands.
500 switch (operands.size()) {
502 llvm_unreachable(
"cannot be called with empty operand range");
// Combine two operands directly.
509 return OpTy::create(rewriter, op.getLoc(), ValueRange{lhs, rhs},
true);
// Larger ranges are split at the midpoint and the two partial results are
// combined with one more op (`lhs`/`rhs` are computed in omitted lines).
511 auto firstHalf = operands.size() / 2;
516 return OpTy::create(rewriter, op.getLoc(), ValueRange{lhs, rhs},
true);
// Adder/comparator implementation strategies chosen by determineAdderArch:
// a ripple-carry chain or one of three parallel-prefix trees.
525enum AdderArchitecture { RippleCarry, Sklanskey, KoggeStone, BrentKung };
// Chooses the adder architecture for `op`. The "synth.test.arch" string
// attribute on the op maps directly to an architecture; other return paths
// exist below.
// NOTE(review): the check that the attribute exists and the conditions
// guarding the three fallback returns (presumably based on `width`) are
// missing from this listing.
526AdderArchitecture determineAdderArch(Operation *op, int64_t width) {
527 auto strAttr = op->getAttrOfType<StringAttr>(
"synth.test.arch");
529 return llvm::StringSwitch<AdderArchitecture>(strAttr.getValue())
530 .Case(
"SKLANSKEY", Sklanskey)
531 .Case(
"KOGGE-STONE", KoggeStone)
532 .Case(
"BRENT-KUNG", BrentKung)
533 .Case(
"RIPPLE-CARRY", RippleCarry);
543 return AdderArchitecture::RippleCarry;
548 return AdderArchitecture::Sklanskey;
552 return AdderArchitecture::KoggeStone;
// Builds a Kogge-Stone prefix network in place over the propagate (`pPrefix`)
// and generate (`gPrefix`) vectors, which must have equal length. Each stage
// doubles the stride; at stage end the updated values are copied back into
// the in/out vectors.
// NOTE(review): closing braces, the LLVM_DEBUG wrapper and the declaration of
// `stage` are missing from this listing.
562void lowerKoggeStonePrefixTree(OpBuilder &builder, Location loc,
563 SmallVector<Value> &pPrefix,
564 SmallVector<Value> &gPrefix) {
566 auto width =
static_cast<int64_t
>(pPrefix.size());
567 assert(width ==
static_cast<int64_t
>(gPrefix.size()));
// Results are written to copies so every position in a stage reads the values
// from the previous stage.
568 SmallVector<Value> pPrefixNew = pPrefix;
569 SmallVector<Value> gPrefixNew = gPrefix;
572 for (int64_t stride = 1; stride < width; stride *= 2) {
// Every position i >= stride is combined with position i - stride.
574 for (int64_t i = stride; i < width; ++i) {
575 int64_t j = i - stride;
// G[i] = G[i] | (P[i] & G[j])
578 Value andPG = comb::AndOp::create(builder, loc, pPrefix[i], gPrefix[j]);
579 gPrefixNew[i] = comb::OrOp::create(builder, loc, gPrefix[i], andPG);
// P[i] = P[i] & P[j]
582 pPrefixNew[i] = comb::AndOp::create(builder, loc, pPrefix[i], pPrefix[j]);
585 pPrefix = pPrefixNew;
586 gPrefix = gPrefixNew;
// Second pass over the same loop structure: prints the generated equations
// to the debug stream.
591 for (int64_t stride = 1; stride < width; stride *= 2) {
593 <<
"--------------------------------------- Kogge-Stone Stage "
595 for (int64_t i = stride; i < width; ++i) {
596 int64_t j = i - stride;
598 llvm::dbgs() <<
"G" << i << stage + 1 <<
" = G" << i << stage
599 <<
" OR (P" << i << stage <<
" AND G" << j << stage
603 llvm::dbgs() <<
"P" << i << stage + 1 <<
" = P" << i << stage
604 <<
" AND P" << j << stage <<
"\n";
// Builds a Sklanskey-style prefix network in place over the propagate
// (`pPrefix`) and generate (`gPrefix`) vectors, which must have equal length.
// NOTE(review): the lines defining `idx`, `j`, `andPG` and `stage`, the
// left-hand side of the propagate update, closing braces and the LLVM_DEBUG
// wrapper are missing from this listing.
613void lowerSklanskeyPrefixTree(OpBuilder &builder, Location loc,
614 SmallVector<Value> &pPrefix,
615 SmallVector<Value> &gPrefix) {
616 auto width =
static_cast<int64_t
>(pPrefix.size());
617 assert(width ==
static_cast<int64_t
>(gPrefix.size()));
618 SmallVector<Value> pPrefixNew = pPrefix;
619 SmallVector<Value> gPrefixNew = gPrefix;
// Stride doubles per stage; blocks start at i = stride and repeat every
// 2 * stride, with k walking up to `stride` positions inside each block.
620 for (int64_t stride = 1; stride < width; stride *= 2) {
621 for (int64_t i = stride; i < width; i += 2 * stride) {
622 for (int64_t k = 0; k < stride && i + k < width; ++k) {
// G[idx] = G[idx] | (P[idx] & G[j])
628 comb::AndOp::create(builder, loc, pPrefix[idx], gPrefix[j]);
629 gPrefixNew[idx] = comb::OrOp::create(builder, loc, gPrefix[idx], andPG);
// AND of P[idx] and P[j] (assignment target omitted in this listing).
633 comb::AndOp::create(builder, loc, pPrefix[idx], pPrefix[j]);
// Commit this stage's results back to the in/out vectors.
637 pPrefix = pPrefixNew;
638 gPrefix = gPrefixNew;
// Second pass over the same loop structure: prints the generated equations
// to the debug stream.
643 for (int64_t stride = 1; stride < width; stride *= 2) {
644 llvm::dbgs() <<
"--------------------------------------- Sklanskey Stage "
646 for (int64_t i = stride; i < width; i += 2 * stride) {
647 for (int64_t k = 0; k < stride && i + k < width; ++k) {
651 llvm::dbgs() <<
"G" << idx << stage + 1 <<
" = G" << idx << stage
652 <<
" OR (P" << idx << stage <<
" AND G" << j << stage
656 llvm::dbgs() <<
"P" << idx << stage + 1 <<
" = P" << idx << stage
657 <<
" AND P" << j << stage <<
"\n";
// Builds a Brent-Kung prefix network in place over the propagate (`pPrefix`)
// and generate (`gPrefix`) vectors, which must have equal length. It runs a
// forward phase with doubling stride followed by a backward phase with
// halving stride.
// NOTE(review): the declarations of `stride` and `stage`, closing braces and
// the LLVM_DEBUG wrapper are missing from this listing.
668void lowerBrentKungPrefixTree(OpBuilder &builder, Location loc,
669 SmallVector<Value> &pPrefix,
670 SmallVector<Value> &gPrefix) {
671 auto width =
static_cast<int64_t
>(pPrefix.size());
672 assert(width ==
static_cast<int64_t
>(gPrefix.size()));
673 SmallVector<Value> pPrefixNew = pPrefix;
674 SmallVector<Value> gPrefixNew = gPrefix;
// Forward phase: positions 2*stride-1, 4*stride-1, ... combine with the
// position `stride` below them.
678 for (stride = 1; stride < width; stride *= 2) {
679 for (int64_t i = stride * 2 - 1; i < width; i += stride * 2) {
680 int64_t j = i - stride;
// G[i] = G[i] | (P[i] & G[j])
683 Value andPG = comb::AndOp::create(builder, loc, pPrefix[i], gPrefix[j]);
684 gPrefixNew[i] = comb::OrOp::create(builder, loc, gPrefix[i], andPG);
// P[i] = P[i] & P[j]
687 pPrefixNew[i] = comb::AndOp::create(builder, loc, pPrefix[i], pPrefix[j]);
689 pPrefix = pPrefixNew;
690 gPrefix = gPrefixNew;
// Backward phase: continues from the stride left by the forward loop and
// halves it, updating positions 3*stride-1, 5*stride-1, ...
694 for (; stride > 0; stride /= 2) {
695 for (int64_t i = stride * 3 - 1; i < width; i += stride * 2) {
696 int64_t j = i - stride;
699 Value andPG = comb::AndOp::create(builder, loc, pPrefix[i], gPrefix[j]);
700 gPrefixNew[i] = comb::OrOp::create(builder, loc, gPrefix[i], andPG);
703 pPrefixNew[i] = comb::AndOp::create(builder, loc, pPrefix[i], pPrefix[j]);
705 pPrefix = pPrefixNew;
706 gPrefix = gPrefixNew;
// Debug output: replays the forward phase and prints its equations.
711 for (stride = 1; stride < width; stride *= 2) {
712 llvm::dbgs() <<
"--------------------------------------- Brent-Kung FW "
713 << stage <<
" : Stride " << stride <<
"\n";
714 for (int64_t i = stride * 2 - 1; i < width; i += stride * 2) {
715 int64_t j = i - stride;
718 llvm::dbgs() <<
"G" << i << stage + 1 <<
" = G" << i << stage
719 <<
" OR (P" << i << stage <<
" AND G" << j << stage
723 llvm::dbgs() <<
"P" << i << stage + 1 <<
" = P" << i << stage
724 <<
" AND P" << j << stage <<
"\n";
// Debug output: replays the backward phase. The stage header is printed only
// when the stage has at least one position to update.
729 for (; stride > 0; stride /= 2) {
730 if (stride * 3 - 1 < width)
731 llvm::dbgs() <<
"--------------------------------------- Brent-Kung BW "
732 << stage <<
" : Stride " << stride <<
"\n";
734 for (int64_t i = stride * 3 - 1; i < width; i += stride * 2) {
735 int64_t j = i - stride;
738 llvm::dbgs() <<
"G" << i << stage + 1 <<
" = G" << i << stage
739 <<
" OR (P" << i << stage <<
" AND G" << j << stage
743 llvm::dbgs() <<
"P" << i << stage + 1 <<
" = P" << i << stage
744 <<
" AND P" << j << stage <<
"\n";
// Kogge-Stone prefix tree that creates (propagate, generate) pairs on demand
// and memoises them in `prefixCache`, keyed by (level, bit index).
// NOTE(review): access specifiers, the declarations of `builder`, `loc` and
// `width`, and closing braces are missing from this listing.
752class LazyKoggeStonePrefixTree {
// Seeds level 0 of the cache with the per-bit inputs.
754 LazyKoggeStonePrefixTree(OpBuilder &builder, Location loc, int64_t width,
755 ArrayRef<Value> pPrefix, ArrayRef<Value> gPrefix)
756 : builder(builder), loc(loc), width(width) {
757 assert(width > 0 &&
"width must be positive");
758 for (int64_t i = 0; i < width; ++i)
759 prefixCache[{0, i}] = {pPrefix[i], gPrefix[i]};
// Returns the (propagate, generate) pair for bit `i` at the last level,
// ceil(log2(width)).
763 std::pair<Value, Value> getFinal(int64_t i) {
764 assert(i >= 0 && i < width &&
"i out of bounds");
766 return getGroupAndPropagate(llvm::Log2_64_Ceil(width), i);
// Memoised lookup of the pair at (level, i); defined out of line.
774 std::pair<Value, Value> getGroupAndPropagate(int64_t level, int64_t i);
// (level, index) -> (propagate, generate).
778 DenseMap<std::pair<int64_t, int64_t>, std::pair<Value, Value>> prefixCache;
// Returns the (propagate, generate) pair for bit `i` at `level`, creating the
// ops only when the pair is not cached yet. Level 0 entries are seeded by the
// constructor, so recursion always ends in a cache hit.
// NOTE(review): the cached-hit return, the binding of `newPropagate` and
// closing braces are missing from this listing.
781std::pair<Value, Value>
782LazyKoggeStonePrefixTree::getGroupAndPropagate(int64_t level, int64_t i) {
783 assert(i < width &&
"i out of bounds");
784 auto key = std::make_pair(level, i);
785 auto it = prefixCache.find(key);
786 if (it != prefixCache.end())
789 assert(level > 0 &&
"If the level is 0, we should have hit the cache");
// Stride used by the previous level: 2^(level - 1).
791 int64_t previousStride = 1ULL << (level - 1);
// Positions below the stride have no partner at this level; the previous
// level's pair is reused unchanged.
792 if (i < previousStride) {
794 auto [propagateI, generateI] = getGroupAndPropagate(level - 1, i);
795 prefixCache[key] = {propagateI, generateI};
796 return prefixCache[key];
// Otherwise combine with the partner at i - previousStride:
//   G = G_i | (P_i & G_j),  P = P_i & P_j
799 int64_t j = i - previousStride;
800 auto [propagateI, generateI] = getGroupAndPropagate(level - 1, i);
801 auto [propagateJ, generateJ] = getGroupAndPropagate(level - 1, j);
803 Value andPG = comb::AndOp::create(builder, loc, propagateI, generateJ);
804 Value newGenerate = comb::OrOp::create(builder, loc, generateI, andPG);
807 comb::AndOp::create(builder, loc, propagateI, propagateJ);
808 prefixCache[key] = {newPropagate, newGenerate};
809 return prefixCache[key];
816 matchAndRewrite(
AddOp op, OpAdaptor adaptor,
817 ConversionPatternRewriter &rewriter)
const override {
// Lowers a two-operand comb.add to either a ripple-carry or a parallel-prefix
// adder, depending on determineAdderArch.
// NOTE(review): the enclosing struct header (presumably CombAddOpConversion),
// the body of the operand-count check and the condition guarding the constant
// replacement are missing from this listing.
818 auto inputs = adaptor.getInputs();
// Only the binary form is handled by this pattern.
821 if (inputs.size() != 2)
824 auto width = op.getType().getIntOrFloatBitWidth();
// Special case replaced by a constant (condition and value omitted here).
827 replaceOpWithNewOpAndCopyNamehint<hw::ConstantOp>(rewriter, op,
833 auto arch = determineAdderArch(op, width);
834 if (arch == AdderArchitecture::RippleCarry)
835 return lowerRippleCarryAdder(op, inputs, rewriter);
836 return lowerParallelPrefixAdder(op, inputs, rewriter);
841 lowerRippleCarryAdder(
comb::AddOp op, ValueRange inputs,
842 ConversionPatternRewriter &rewriter)
const {
// Bit-by-bit adder: each sum bit is the XOR of the two operand bits and,
// when one exists, the incoming carry; the per-bit results are concatenated
// to replace the op.
// NOTE(review): the definitions of `aBits`, `bBits` and `carry`, the
// conditions around the carry handling, and the carry update for bits after
// the first are missing from this listing.
843 auto width = op.getType().getIntOrFloatBitWidth();
849 SmallVector<Value> results;
850 results.resize(width);
851 for (int64_t i = 0; i < width; ++i) {
852 SmallVector<Value> xorOperands = {aBits[i], bBits[i]};
854 xorOperands.push_back(carry);
// Stored in reverse index order (bit i goes to slot width - i - 1) because
// `results` is passed to comb::ConcatOp below.
// NOTE(review): presumably ConcatOp takes the most significant operand
// first -- confirm.
858 results[width - i - 1] =
859 comb::XorOp::create(rewriter, op.getLoc(), xorOperands,
true);
// Carry computed as a AND b (visible form; other forms may be omitted).
868 carry = comb::AndOp::create(rewriter, op.getLoc(),
869 ValueRange{aBits[i], bBits[i]},
true);
876 LLVM_DEBUG(llvm::dbgs() <<
"Lower comb.add to Ripple-Carry Adder of width "
879 replaceOpWithNewOpAndCopyNamehint<comb::ConcatOp>(rewriter, op, results);
887 lowerParallelPrefixAdder(
comb::AddOp op, ValueRange inputs,
888 ConversionPatternRewriter &rewriter)
const {
// Parallel-prefix adder: computes per-bit propagate (a XOR b) and generate
// (a AND b), runs the selected prefix tree over them, and forms each sum bit
// as p[i] XOR the prefix generate of bit i - 1.
// NOTE(review): the definitions of `aBits`/`bBits`, the switch header, the
// `break`s, closing braces and the LLVM_DEBUG wrappers are missing from this
// listing.
889 auto width = op.getType().getIntOrFloatBitWidth();
895 SmallVector<Value> p, g;
899 for (
auto [aBit, bBit] :
llvm::zip(aBits, bBits)) {
// Propagate: a XOR b.
901 p.push_back(comb::XorOp::create(rewriter, op.getLoc(), aBit, bBit));
// Generate: a AND b.
903 g.push_back(comb::AndOp::create(rewriter, op.getLoc(), aBit, bBit));
907 llvm::dbgs() <<
"Lower comb.add to Parallel-Prefix of width " << width
908 <<
"\n--------------------------------------- Init\n";
910 for (int64_t i = 0; i < width; ++i) {
912 llvm::dbgs() <<
"P0" << i <<
" = A" << i <<
" XOR B" << i <<
"\n";
914 llvm::dbgs() <<
"G0" << i <<
" = A" << i <<
" AND B" << i <<
"\n";
// The prefix trees update their arguments in place, so they get copies and
// the original `p` stays available for the sum bits below.
919 SmallVector<Value> pPrefix = p;
920 SmallVector<Value> gPrefix = g;
923 auto arch = determineAdderArch(op, width);
// Ripple-carry is routed to lowerRippleCarryAdder by the caller.
926 case AdderArchitecture::RippleCarry:
927 llvm_unreachable(
"Ripple-Carry should be handled separately");
929 case AdderArchitecture::Sklanskey:
930 lowerSklanskeyPrefixTree(rewriter, op.getLoc(), pPrefix, gPrefix);
932 case AdderArchitecture::KoggeStone:
933 lowerKoggeStonePrefixTree(rewriter, op.getLoc(), pPrefix, gPrefix);
935 case AdderArchitecture::BrentKung:
936 lowerBrentKungPrefixTree(rewriter, op.getLoc(), pPrefix, gPrefix);
// `results` is filled in reverse index order for comb::ConcatOp.
942 SmallVector<Value> results;
943 results.resize(width);
// Bit 0 has no incoming carry, so its sum is just p[0].
945 results[width - 1] = p[0];
// Sum bit i = p[i] XOR (prefix generate of bit i - 1).
949 for (int64_t i = 1; i < width; ++i)
950 results[width - 1 - i] =
951 comb::XorOp::create(rewriter, op.getLoc(), p[i], gPrefix[i - 1]);
953 replaceOpWithNewOpAndCopyNamehint<comb::ConcatOp>(rewriter, op, results);
956 llvm::dbgs() <<
"--------------------------------------- Completion\n"
958 for (int64_t i = 1; i < width; ++i)
959 llvm::dbgs() <<
"RES" << i <<
" = P" << i <<
" XOR G" << i - 1 <<
"\n";
970 matchAndRewrite(
MulOp op, OpAdaptor adaptor,
971 ConversionPatternRewriter &rewriter)
const override {
// Lowers a two-operand comb.mul by building rows of partial products,
// compressing them to two addends and adding those with a single comb.add.
// NOTE(review): the enclosing struct header (presumably CombMulOpConversion),
// the definitions of `falseValue` and `comp`, the condition guarding the
// early replaceOp, and the final replacement are missing from this listing.
// Only the binary form is handled by this pattern.
972 if (adaptor.getInputs().size() != 2)
975 Location loc = op.getLoc();
976 Value a = adaptor.getInputs()[0];
977 Value b = adaptor.getInputs()[1];
978 unsigned width = op.getType().getIntOrFloatBitWidth();
987 SmallVector<Value> aBits =
extractBits(rewriter, a);
988 SmallVector<Value> bBits =
extractBits(rewriter, b);
993 SmallVector<SmallVector<Value>> partialProducts;
994 partialProducts.reserve(width);
// Row i starts with i copies of `falseValue` (the shift by i positions),
// followed by a[j] AND b[i] for every j that still fits in `width` bits.
995 for (
unsigned i = 0; i < width; ++i) {
996 SmallVector<Value> row(i, falseValue);
999 for (
unsigned j = 0; i + j < width; ++j)
1001 rewriter.createOrFold<
comb::AndOp>(loc, aBits[j], bBits[i]));
1003 partialProducts.push_back(row);
// Special case: the single partial-product bit is the result (the guarding
// condition is omitted; presumably width == 1 -- confirm).
1008 rewriter.replaceOp(op, partialProducts[0][0]);
// Reduce the rows to two addends, then add them.
1014 auto addends = comp.compressToHeight(rewriter, 2);
1017 auto newAdd = comb::AddOp::create(rewriter, loc, addends,
true);
// Common base for the div/mod conversion patterns below. It stores the
// maximum number of unknown operand bits for which the patterns fall back to
// emulateBinaryOpForUnknownBits.
// NOTE(review): the struct header and the base-class initialiser are missing
// from this listing.
1023template <
typename OpTy>
1025 DivModOpConversionBase(MLIRContext *
context, int64_t maxEmulationUnknownBits)
1027 maxEmulationUnknownBits(maxEmulationUnknownBits) {
// The emulation enumerates assignments with uint32_t masks, so the limit must
// stay below 32.
1028 assert(maxEmulationUnknownBits < 32 &&
1029 "maxEmulationUnknownBits must be less than 32");
1031 const int64_t maxEmulationUnknownBits;
// Lowers comb.divu. Visible strategies, in order: power-of-two divisor via
// comb::convertDivUByPowerOfTwo, constant divisor handling, and table-based
// emulation for operands with few unknown bits.
// NOTE(review): the bodies of the early exits, the non-zero constant path and
// the condition inside the emulation lambda are missing from this listing.
1034struct CombDivUOpConversion : DivModOpConversionBase<DivUOp> {
1035 using DivModOpConversionBase<
DivUOp>::DivModOpConversionBase;
1037 matchAndRewrite(
DivUOp op, OpAdaptor adaptor,
1038 ConversionPatternRewriter &rewriter)
const override {
1040 if (llvm::succeeded(comb::convertDivUByPowerOfTwo(op, rewriter)))
// Constant divisor.
1045 if (
auto rhsConst = adaptor.getRhs().getDefiningOp<
hw::ConstantOp>()) {
1046 APInt divisor = rhsConst.getValue();
// A zero divisor is replaced by a constant (its value is omitted here).
1048 if (divisor.isZero()) {
1049 replaceOpWithNewOpAndCopyNamehint<hw::ConstantOp>(rewriter, op,
// Fallback: enumerate operand values. The lambda has a path that yields zero
// (condition omitted; presumably for rhs == 0) and otherwise returns the
// unsigned quotient.
1063 rewriter, maxEmulationUnknownBits, op,
1064 [](
const APInt &lhs,
const APInt &rhs) {
1067 return APInt::getZero(rhs.getBitWidth());
1068 return lhs.udiv(rhs);
// Lowers comb.modu. Visible strategies, in order: power-of-two divisor via
// comb::convertModUByPowerOfTwo, constant divisor (remainder formed as
// lhs - q * rhs), and table-based emulation for operands with few unknown
// bits.
// NOTE(review): the bodies of the early exits, the computation of `q`, the
// bindings of `product` and the remainder, and the condition inside the
// emulation lambda are missing from this listing.
1073struct CombModUOpConversion : DivModOpConversionBase<ModUOp> {
1074 using DivModOpConversionBase<
ModUOp>::DivModOpConversionBase;
1076 matchAndRewrite(
ModUOp op, OpAdaptor adaptor,
1077 ConversionPatternRewriter &rewriter)
const override {
1079 if (llvm::succeeded(comb::convertModUByPowerOfTwo(op, rewriter)))
// Constant divisor.
1084 if (
auto rhsConst = adaptor.getRhs().getDefiningOp<
hw::ConstantOp>()) {
1085 APInt divisor = rhsConst.getValue();
// A zero divisor is replaced by a constant (its value is omitted here).
1087 if (divisor.isZero()) {
1088 replaceOpWithNewOpAndCopyNamehint<hw::ConstantOp>(rewriter, op,
1092 auto loc = op.getLoc();
// q * rhs, then lhs - (q * rhs).
1096 rewriter.createOrFold<
comb::MulOp>(loc, q, adaptor.getRhs());
1098 rewriter.createOrFold<
comb::SubOp>(loc, adaptor.getLhs(), product);
// Fallback: enumerate operand values. The lambda has a path that yields zero
// (condition omitted; presumably for rhs == 0) and otherwise returns the
// unsigned remainder.
1107 rewriter, maxEmulationUnknownBits, op,
1108 [](
const APInt &lhs,
const APInt &rhs) {
1111 return APInt::getZero(rhs.getBitWidth());
1112 return lhs.urem(rhs);
// Lowers comb.divs. Visible strategies: special constant divisors (0, 1 and
// all-ones, i.e. -1), and table-based emulation for operands with few unknown
// bits.
// NOTE(review): the bodies of the special cases, the general constant path
// and the condition inside the emulation lambda are missing from this
// listing.
1117struct CombDivSOpConversion : DivModOpConversionBase<DivSOp> {
1118 using DivModOpConversionBase<
DivSOp>::DivModOpConversionBase;
1121 matchAndRewrite(
DivSOp op, OpAdaptor adaptor,
1122 ConversionPatternRewriter &rewriter)
const override {
// Constant divisor.
1125 if (
auto rhsConst = adaptor.getRhs().getDefiningOp<
hw::ConstantOp>()) {
1126 APInt divisor = rhsConst.getValue();
1127 unsigned width = op.getType().getIntOrFloatBitWidth();
// A zero divisor is replaced by a constant (its value is omitted here).
1129 if (divisor.isZero()) {
1130 replaceOpWithNewOpAndCopyNamehint<hw::ConstantOp>(rewriter, op,
// Divisor 1 (body omitted).
1135 if (divisor.isOne()) {
// Divisor -1: the visible fragment involves a zero constant of `width` bits
// (presumably 0 - lhs -- confirm).
1140 if (divisor.isAllOnes()) {
1146 APInt::getZero(width)),
// Fallback: enumerate operand values. The lambda has a path that yields zero
// (condition omitted; presumably for rhs == 0) and otherwise returns the
// signed quotient.
1158 rewriter, maxEmulationUnknownBits, op,
1159 [](
const APInt &lhs,
const APInt &rhs) {
1162 return APInt::getZero(rhs.getBitWidth());
1163 return lhs.sdiv(rhs);
// Lowers comb.mods. Visible strategies: special constant divisors (0, 1 and
// -1 share one constant replacement), general constant divisor (remainder
// formed as lhs - q * rhs), and table-based emulation for operands with few
// unknown bits.
// NOTE(review): the constant used for the special cases, the computation of
// `q`, the bindings of `product` and the remainder, and the condition inside
// the emulation lambda are missing from this listing.
1168struct CombModSOpConversion : DivModOpConversionBase<ModSOp> {
1169 using DivModOpConversionBase<
ModSOp>::DivModOpConversionBase;
1171 matchAndRewrite(
ModSOp op, OpAdaptor adaptor,
1172 ConversionPatternRewriter &rewriter)
const override {
// Constant divisor.
1175 if (
auto rhsConst = adaptor.getRhs().getDefiningOp<
hw::ConstantOp>()) {
1176 APInt divisor = rhsConst.getValue();
1178 if (divisor.isZero() || divisor.isOne() || divisor.isAllOnes()) {
1179 replaceOpWithNewOpAndCopyNamehint<hw::ConstantOp>(rewriter, op,
1183 auto loc = op.getLoc();
// q * rhs, then lhs - (q * rhs).
1187 rewriter.createOrFold<
comb::MulOp>(loc, q, adaptor.getRhs());
1189 rewriter.createOrFold<
comb::SubOp>(loc, adaptor.getLhs(), product);
// Fallback: enumerate operand values. The lambda has a path that yields zero
// (condition omitted; presumably for rhs == 0) and otherwise returns the
// signed remainder.
1195 rewriter, maxEmulationUnknownBits, op,
1196 [](
const APInt &lhs,
const APInt &rhs) {
1199 return APInt::getZero(rhs.getBitWidth());
1200 return lhs.srem(rhs);
// Builds a comparator as a chain over the operand bits: at each bit position
// the running result `acc` is replaced by
//   pred | (aEqualB & acc)
// where aEqualB is NOT(a XOR b) and pred is (NOT a) AND b.
// NOTE(review): one signature line, the definitions of `aBits`, `bBits` and
// `acc`, the binding of `aBitXorBBit` and the return are missing from this
// listing. Which bit ends up dominating depends on the iteration order of
// `aBits`/`bBits`, which is not visible here.
1209 static Value constructRippleCarry(Location loc, Value a, Value b,
1211 ConversionPatternRewriter &rewriter) {
1219 for (
auto [aBit, bBit] :
llvm::zip(aBits, bBits)) {
// a XOR b for this bit.
1221 rewriter.createOrFold<
comb::XorOp>(loc, aBit, bBit,
true);
// Single-input AndInverterOp with the inversion flag set: NOT(a XOR b).
1222 auto aEqualB = rewriter.createOrFold<synth::aig::AndInverterOp>(
1223 loc, aBitXorBBit,
true);
// (NOT a) AND b.
1224 auto pred = rewriter.createOrFold<synth::aig::AndInverterOp>(
1225 loc, aBit, bBit,
true,
false);
// When the bits are equal, carry the earlier result through.
1227 auto aBitAndBBit = rewriter.createOrFold<
comb::AndOp>(
1228 loc, ValueRange{aEqualB,
acc},
true);
1229 acc = rewriter.createOrFold<
comb::OrOp>(loc, pred, aBitAndBBit,
true);
// Runs the selected prefix tree over the given propagate/generate vectors and
// derives the comparison result from the values of the last position.
// The vectors are taken by value because the non-lazy prefix-tree builders
// update their arguments in place.
// NOTE(review): the switch header, `break`s, the use of `includeEq` and the
// non-equality return path are missing from this listing.
1242 static Value computePrefixComparison(ConversionPatternRewriter &rewriter,
1243 Location loc, SmallVector<Value> pPrefix,
1244 SmallVector<Value> gPrefix,
1245 bool includeEq, AdderArchitecture arch) {
1246 auto width = pPrefix.size();
1247 Value finalGroup, finalPropagate;
// Ripple-carry is routed to constructRippleCarry by the caller.
1250 case AdderArchitecture::RippleCarry:
1251 llvm_unreachable(
"Ripple-Carry should be handled separately");
1253 case AdderArchitecture::Sklanskey: {
1254 lowerSklanskeyPrefixTree(rewriter, loc, pPrefix, gPrefix);
1255 finalGroup = gPrefix[width - 1];
1256 finalPropagate = pPrefix[width - 1];
// Only the last position is needed, so the lazy Kogge-Stone tree is used and
// just that position is requested.
1259 case AdderArchitecture::KoggeStone:
1262 std::tie(finalPropagate, finalGroup) =
1263 LazyKoggeStonePrefixTree(rewriter, loc, width, pPrefix, gPrefix)
1264 .getFinal(width - 1);
1266 case AdderArchitecture::BrentKung: {
1267 lowerBrentKungPrefixTree(rewriter, loc, pPrefix, gPrefix);
1268 finalGroup = gPrefix[width - 1];
1269 finalPropagate = pPrefix[width - 1];
// NOTE(review): presumably this is the `includeEq` path (strict result OR
// all-bits-equal) -- confirm.
1278 return comb::OrOp::create(rewriter, loc, finalGroup, finalPropagate);
// Builds an unsigned comparison of `a` and `b`. With the ripple-carry
// architecture it delegates to constructRippleCarry; otherwise it builds
// per-bit `eq` and `gt` signals and hands them to computePrefixComparison as
// the propagate and generate vectors respectively.
// NOTE(review): the use of `isLess` (presumably an operand swap), the
// definitions of `aBits`, `bBits` and `one`, and the trailing arguments of
// the final call are missing from this listing.
1287 static Value constructUnsignedCompare(Operation *op, Location loc, Value a,
1288 Value b,
bool isLess,
bool includeEq,
1289 ConversionPatternRewriter &rewriter) {
1293 auto width = a.getType().getIntOrFloatBitWidth();
1296 auto arch = determineAdderArch(op, width);
1297 if (arch == AdderArchitecture::RippleCarry)
1298 return constructRippleCarry(loc, a, b, includeEq, rewriter);
1309 SmallVector<Value> eq, gt;
1316 for (
auto [aBit, bBit] :
llvm::zip(aBits, bBits)) {
// eq bit: (a XOR b) XOR one.
1318 auto xorBit = comb::XorOp::create(rewriter, loc, aBit, bBit);
1319 eq.push_back(comb::XorOp::create(rewriter, loc, xorBit, one));
// gt bit: (a XOR one) AND b.
1322 auto notA = comb::XorOp::create(rewriter, loc, aBit, one);
1323 gt.push_back(comb::AndOp::create(rewriter, loc, notA, bBit));
1326 return computePrefixComparison(rewriter, loc, std::move(eq), std::move(gt),
// Lowers comb.icmp by predicate: equality and inequality via the XOR of the
// operands, unsigned orderings via constructUnsignedCompare, and signed
// orderings by comparing the sign bits separately from the remaining bits.
// NOTE(review): the enclosing struct header (presumably
// CombICmpOpConversion), the default case, the definitions of `xorBits`,
// `signA`, `signB`, `aRest`, `bRest` and `signsDiffer`, and the `return`s are
// missing from this listing.
1331 matchAndRewrite(ICmpOp op, OpAdaptor adaptor,
1332 ConversionPatternRewriter &rewriter)
const override {
1333 auto lhs = adaptor.getLhs();
1334 auto rhs = adaptor.getRhs();
1336 switch (op.getPredicate()) {
// Equal: AND over all inverted bits of lhs XOR rhs.
1340 case ICmpPredicate::eq:
1341 case ICmpPredicate::ceq: {
1343 auto xorOp = rewriter.createOrFold<
comb::XorOp>(op.getLoc(), lhs, rhs);
1345 SmallVector<bool> allInverts(xorBits.size(),
true);
1346 replaceOpWithNewOpAndCopyNamehint<synth::aig::AndInverterOp>(
1347 rewriter, op, xorBits, allInverts);
// Not equal: OR over all bits of lhs XOR rhs.
1351 case ICmpPredicate::ne:
1352 case ICmpPredicate::cne: {
1354 auto xorOp = rewriter.createOrFold<
comb::XorOp>(op.getLoc(), lhs, rhs);
1355 replaceOpWithNewOpAndCopyNamehint<comb::OrOp>(
1356 rewriter, op,
extractBits(rewriter, xorOp),
true);
// Unsigned orderings: the predicate is reduced to two flags.
1360 case ICmpPredicate::uge:
1361 case ICmpPredicate::ugt:
1362 case ICmpPredicate::ule:
1363 case ICmpPredicate::ult: {
1364 bool isLess = op.getPredicate() == ICmpPredicate::ult ||
1365 op.getPredicate() == ICmpPredicate::ule;
1366 bool includeEq = op.getPredicate() == ICmpPredicate::uge ||
1367 op.getPredicate() == ICmpPredicate::ule;
1369 constructUnsignedCompare(op, op.getLoc(), lhs,
1370 rhs, isLess, includeEq,
// Signed orderings.
1374 case ICmpPredicate::slt:
1375 case ICmpPredicate::sle:
1376 case ICmpPredicate::sgt:
1377 case ICmpPredicate::sge: {
// There is no sign bit to inspect on a zero-width value.
1378 if (lhs.getType().getIntOrFloatBitWidth() == 0)
1379 return rewriter.notifyMatchFailure(
1380 op.getLoc(),
"i0 signed comparison is unsupported");
1381 bool isLess = op.getPredicate() == ICmpPredicate::slt ||
1382 op.getPredicate() == ICmpPredicate::sle;
1383 bool includeEq = op.getPredicate() == ICmpPredicate::sge ||
1384 op.getPredicate() == ICmpPredicate::sle;
// Same sign: the result is the unsigned comparison of the remaining bits.
1393 auto sameSignResult = constructUnsignedCompare(
1394 op, op.getLoc(), aRest, bRest, isLess, includeEq, rewriter);
// XOR of the two sign bits (binding omitted; presumably `signsDiffer`).
1398 comb::XorOp::create(rewriter, op.getLoc(), signA, signB);
// Different signs: for "less" predicates the result is lhs's sign bit,
// otherwise rhs's sign bit.
1401 Value diffSignResult = isLess ? signA : signB;
1404 replaceOpWithNewOpAndCopyNamehint<comb::MuxOp>(
1405 rewriter, op, signsDiffer, diffSignResult, sameSignResult);
1416 matchAndRewrite(
ParityOp op, OpAdaptor adaptor,
1417 ConversionPatternRewriter &rewriter)
const override {
// Replaces comb.parity with a single comb.xor over the values returned by
// extractBits for the input.
// NOTE(review): the enclosing struct header is missing from this listing;
// presumably this is CombParityOpConversion from the pattern list.
1419 replaceOpWithNewOpAndCopyNamehint<comb::XorOp>(
1420 rewriter, op,
extractBits(rewriter, adaptor.getInput()),
true);
1429 matchAndRewrite(
comb::ShlOp op, OpAdaptor adaptor,
1430 ConversionPatternRewriter &rewriter)
const override {
// Lowers comb.shl with shift amounts up to `width`, supplying two callbacks:
// one producing zero padding of `index` bits and one producing the extracted
// part of `lhs`.
// NOTE(review): the enclosing struct header (presumably CombShlOpConversion),
// the callee receiving these arguments (presumably createShiftLogic<true>)
// and most of the lambda bodies are missing from this listing.
1431 auto width = op.getType().getIntOrFloatBitWidth();
1432 auto lhs = adaptor.getLhs();
1434 rewriter, op.getLoc(), adaptor.getRhs(), width,
// Padding callback: a zero constant that is `index` bits wide.
1436 [&](int64_t index) {
1442 op.getLoc(), rewriter.getIntegerType(index), 0);
// Extract callback: only valid for shift amounts below the width.
1445 [&](int64_t index) {
1446 assert(index < width &&
"index out of bounds");
1462 ConversionPatternRewriter &rewriter)
const override {
// Shift lowering with shift amounts up to `width`, zero padding of `index`
// bits, and an extract of `lhs` starting at bit `index`.
// NOTE(review): the matchAndRewrite signature and struct header are missing
// from this listing; from the pattern list and the zero padding this is
// presumably CombShrUOpConversion (logical right shift) -- confirm.
1463 auto width = op.getType().getIntOrFloatBitWidth();
1464 auto lhs = adaptor.getLhs();
1466 rewriter, op.getLoc(), adaptor.getRhs(), width,
// Padding callback: a zero constant that is `index` bits wide.
1468 [&](int64_t index) {
1474 op.getLoc(), rewriter.getIntegerType(index), 0);
// Extract callback: the part of `lhs` starting at bit `index` (the length
// argument is truncated in this listing).
1477 [&](int64_t index) {
1478 assert(index < width &&
"index out of bounds");
1480 return rewriter.createOrFold<
comb::ExtractOp>(op.getLoc(), lhs, index,
1494 ConversionPatternRewriter &rewriter)
const override {
// Signed shift lowering: the top bit of `lhs` is extracted as the sign,
// padding is the sign replicated, and the shift logic is built for amounts up
// to width - 1.
// NOTE(review): the matchAndRewrite signature and struct header (presumably
// CombShrSOpConversion), the zero-width check guarding the failure return,
// the binding of `sign` and the trailing arguments of the lambdas are missing
// from this listing.
1495 auto width = op.getType().getIntOrFloatBitWidth();
1497 return rewriter.notifyMatchFailure(op.getLoc(),
1498 "i0 signed shift is unsupported");
1499 auto lhs = adaptor.getLhs();
// Most significant bit of lhs.
1502 rewriter.createOrFold<
comb::ExtractOp>(op.getLoc(), lhs, width - 1, 1);
1507 rewriter, op.getLoc(), adaptor.getRhs(), width - 1,
// Padding callback: copies of the sign bit.
1509 [&](int64_t index) {
1510 return rewriter.createOrFold<comb::ReplicateOp>(op.getLoc(), sign,
// Extract callback: the part of `lhs` starting at bit `index`.
1514 [&](int64_t index) {
1515 return rewriter.createOrFold<
comb::ExtractOp>(op.getLoc(), lhs, index,
// Pass class for the comb-to-synth conversion. It derives from the generated
// ConvertCombToSynthBase, inherits its constructors, and overrides
// runOnOperation (defined out of line).
// NOTE(review): the closing brace is missing from this listing.
1531struct ConvertCombToSynthPass
1532 :
public impl::ConvertCombToSynthBase<ConvertCombToSynthPass> {
1533 void runOnOperation()
override;
1534 using ConvertCombToSynthBase<ConvertCombToSynthPass>::ConvertCombToSynthBase;
1540 uint32_t maxEmulationUnknownBits,
// Registers the conversion patterns of this file into `patterns` (named
// populateCombToAIGConversionPatterns in the symbol index, which also lists a
// `forceAIG` parameter).
// NOTE(review): the first signature line, the `patterns.add<` opener, the
// constructor arguments of several groups and any conditions (e.g. on
// `forceAIG`) are missing from this listing.
// Core comb lowerings: bitwise ops, mux, multiply, compare, shifts and the
// variadic-to-binary splitting of add/mul.
1544 CombAndOpConversion, CombParityOpConversion, CombXorOpToSynthConversion,
1545 CombMuxOpToSynthConversion,
1547 CombMulOpConversion, CombICmpOpConversion,
1549 CombShlOpConversion, CombShrUOpConversion, CombShrSOpConversion,
1551 CombLowerVariadicOp<AddOp>, CombLowerVariadicOp<MulOp>>(
// Lowerings of the intermediate synth xor/mux inverter ops.
1555 patterns.add<SynthXorInverterOpConversion, SynthMuxInverterOpConversion>(
// Subtraction is rewritten via comb::convertSubToAdd.
1558 patterns.add(comb::convertSubToAdd);
1560 patterns.add<CombOrToAIGConversion, CombAddOpConversion>(
1562 synth::populateVariadicAndInverterLoweringPatterns(
patterns);
1565 synth::populateVariadicXorInverterLoweringPatterns(
patterns);
// Division and modulo patterns receive the emulation limit.
1568 patterns.add<CombDivUOpConversion, CombModUOpConversion, CombDivSOpConversion,
1569 CombModSOpConversion>(
patterns.getContext(),
1570 maxEmulationUnknownBits);
// Pass entry point: configures the conversion target, populates the pattern
// set and runs a partial dialect conversion, signalling pass failure when the
// conversion fails.
// NOTE(review): the list of legal comb/hw ops, the pattern population call
// and closing braces are missing from this listing.
1573void ConvertCombToSynthPass::runOnOperation() {
1574 ConversionTarget target(getContext());
// The comb dialect is illegal as a whole (exceptions, if any, are in omitted
// lines).
1577 target.addIllegalDialect<comb::CombDialect>();
1587 hw::AggregateConstantOp>();
// The synth dialect is legal, except the xor/mux inverter ops, which must be
// lowered further.
1589 target.addLegalDialect<synth::SynthDialect>();
1591 target.addIllegalOp<synth::XorInverterOp, synth::MuxInverterOp>();
// Ops named in `additionalLegalOps` are marked legal by name.
1594 if (!additionalLegalOps.empty())
1595 for (
const auto &opName : additionalLegalOps)
1596 target.addLegalOp(OperationName(opName, &getContext()));
1598 RewritePatternSet
patterns(&getContext());
1602 if (failed(mlir::applyPartialConversion(getOperation(), target,
1604 return signalPassFailure();
assert(baseType &&"element must be base type")
static SmallVector< Value > extractBits(OpBuilder &builder, Value val)
static Value createShiftLogic(ConversionPatternRewriter &rewriter, Location loc, Value shiftAmount, int64_t maxShiftAmount, llvm::function_ref< Value(int64_t)> getPadding, llvm::function_ref< Value(int64_t)> getExtract)
static Value createAShrByConstant(OpBuilder &builder, Location loc, Value value, unsigned amount)
static Value createMulHigh(OpBuilder &builder, Location loc, Value lhs, const APInt &rhs)
static APInt substitueMaskToValues(size_t width, llvm::SmallVectorImpl< ConstantOrValue > &constantOrValues, uint32_t mask)
static Value lowerSignedDivByConstant(OpBuilder &builder, Location loc, Value lhs, const APInt &divisor)
static Value createLShrByConstant(OpBuilder &builder, Location loc, Value value, unsigned amount)
static LogicalResult emulateBinaryOpForUnknownBits(ConversionPatternRewriter &rewriter, int64_t maxEmulationUnknownBits, Operation *op, llvm::function_ref< APInt(const APInt &, const APInt &)> emulate)
static int64_t getNumUnknownBitsAndPopulateValues(Value value, llvm::SmallVectorImpl< ConstantOrValue > &values)
static Value createMajorityFunction(OpBuilder &rewriter, Location loc, Value a, Value b, Value carry)
static Value extractOtherThanMSB(OpBuilder &builder, Value val)
static Value extractMSB(OpBuilder &builder, Value val)
static void populateCombToAIGConversionPatterns(RewritePatternSet &patterns, uint32_t maxEmulationUnknownBits, bool forceAIG)
static Value lowerUnsignedDivByConstant(OpBuilder &builder, Location loc, Value lhs, const APInt &divisor)
static std::unique_ptr< Context > context
static std::optional< APSInt > getConstant(Attribute operand)
Determine the value of a constant operand for the sake of constant folding.
static Value lowerFullyAssociativeOp(Operation &op, OperandRange operands, SmallVector< Operation * > &newOps)
Lower a variadic fully-associative operation into an expression tree.
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
void replaceOpAndCopyNamehint(PatternRewriter &rewriter, Operation *op, Value newValue)
A wrapper of PatternRewriter::replaceOp to propagate "sv.namehint" attribute.