31#include "mlir/Pass/Pass.h"
32#include "mlir/Transforms/DialectConversion.h"
33#include "llvm/ADT/APInt.h"
34#include "llvm/ADT/PointerUnion.h"
35#include "llvm/Support/Debug.h"
38#define DEBUG_TYPE "comb-to-synth"
41#define GEN_PASS_DEF_CONVERTCOMBTOSYNTH
42#include "circt/Conversion/Passes.h.inc"
// File-local convenience wrapper: fills a fresh SmallVector with the bits of
// `val` by delegating to comb::extractBits.
// NOTE(review): the return statement is not visible in this listing;
// presumably `bits` is returned - confirm.
53static SmallVector<Value>
extractBits(OpBuilder &builder, Value val) {
54 SmallVector<Value> bits;
55 comb::extractBits(builder, val, bits);
// Shift-lowering helper (createShiftLogic), templated on the shift direction
// via `isLeftShift`. For each shift amount i in [0, maxShiftAmount) the two
// callbacks are queried for an extracted slice (getExtract) and a padding
// value (getPadding), and candidates are collected in `nodes`. The candidates
// are handed to comb::constructMuxTree together with an out-of-bounds value,
// and an unsigned less-than compare on `shiftAmount` feeds the final mux.
// NOTE(review): how `extract` and `padding` are combined, what `bits` holds,
// and the right-hand operands of the compare/mux are not visible here.
66template <
bool isLeftShift>
68 Value shiftAmount, int64_t maxShiftAmount,
69 llvm::function_ref<Value(int64_t)> getPadding,
70 llvm::function_ref<Value(int64_t)> getExtract) {
// One candidate per in-range shift amount.
75 SmallVector<Value> nodes;
76 nodes.reserve(maxShiftAmount);
77 for (int64_t i = 0; i < maxShiftAmount; ++i) {
78 Value extract = getExtract(i);
79 Value padding = getPadding(i);
82 nodes.push_back(extract);
// Padding for a shift by maxShiftAmount; passed to the mux tree as its
// out-of-bounds value.
96 auto outOfBoundsValue = getPadding(maxShiftAmount);
97 assert(outOfBoundsValue &&
"outOfBoundsValue must be valid");
101 comb::constructMuxTree(rewriter, loc, bits, nodes, outOfBoundsValue);
104 auto inBound = rewriter.createOrFold<comb::ICmpOp>(
105 loc, ICmpPredicate::ult, shiftAmount,
109 return rewriter.createOrFold<
comb::MuxOp>(loc, inBound, result,
// Tail of the majority-function helper (createMajorityFunction). The visible
// ops compute: (carry AND (a XOR b)) OR (a AND b).
// Note that `aXnorB` is produced by comb::XorOp, i.e. it holds a XOR b
// despite its name.
// NOTE(review): the binding of `andOp` is not visible; presumably it names
// the carry/aXnorB AndOp below. The trailing `true` arguments are presumably
// the two-state flag - confirm.
116 Value b, Value carry) {
118 auto aXnorB = comb::XorOp::create(rewriter, loc, ValueRange{a, b},
true);
120 comb::AndOp::create(rewriter, loc, ValueRange{carry, aXnorB},
true);
121 auto aAndB = comb::AndOp::create(rewriter, loc, ValueRange{a, b},
true);
122 return comb::OrOp::create(rewriter, loc, ValueRange{andOp, aAndB},
true);
// Call arguments only (callee not visible): location of `val`, `val` itself,
// start index = bit width - 1, length = 1, i.e. the most significant bit.
// NOTE(review): presumably part of extractMSB - confirm.
127 val.getLoc(), val, val.getType().getIntOrFloatBitWidth() - 1, 1);
// Call arguments only (callee not visible): location of `val`, `val` itself,
// start index = 0, length = bit width - 1, i.e. every bit except the MSB.
// NOTE(review): presumably part of extractOtherThanMSB - confirm.
132 val.getLoc(), val, 0, val.getType().getIntOrFloatBitWidth() - 1);
// Either an SSA Value or an IntegerAttr. The helpers in this file push a
// constant's value attribute for constant pieces and the Value itself for
// everything else.
137using ConstantOrValue = llvm::PointerUnion<Value, mlir::IntegerAttr>;
// Fragment of getNumUnknownBitsAndPopulateValues: appends the pieces of
// `value` to `values` and returns a bit count.
// NOTE(review): the guards that select each case below (concat / constant /
// other) and the result for zero-width integers are not visible here.
142 Value value, llvm::SmallVectorImpl<ConstantOrValue> &values) {
// Zero-width integers are special-cased.
144 if (value.getType().isInteger(0))
// Concatenation: visit the inputs in reverse order and sum their counts.
149 int64_t totalUnknownBits = 0;
150 for (
auto concatInput : llvm::reverse(concat.getInputs())) {
155 totalUnknownBits += unknownBits;
157 return totalUnknownBits;
// Constant: record its value attribute.
162 values.push_back(constant.getValueAttr());
// Fallback: record the value itself and return its full bit width.
168 values.push_back(value);
169 return hw::getBitWidth(value.getType());
// Fragment of substitueMaskToValues: builds a `width`-bit APInt by walking
// `constantOrValues`. Constant entries are inserted verbatim at `bitPos`;
// for Value entries, `elemWidth` bits are taken from `mask` starting at
// `unknownPos` and inserted at `bitPos`, after which `unknownPos` advances.
// NOTE(review): the update of `bitPos`, the declaration of `elemWidth`, and
// the return statement are not visible in this listing.
175 llvm::SmallVectorImpl<ConstantOrValue> &constantOrValues,
177 uint32_t bitPos = 0, unknownPos = 0;
178 APInt result(width, 0);
179 for (
auto constantOrValue : constantOrValues) {
181 if (
auto constant = dyn_cast<IntegerAttr>(constantOrValue)) {
182 elemWidth = constant.getValue().getBitWidth();
183 result.insertBits(constant.getValue(), bitPos);
185 elemWidth = hw::getBitWidth(cast<Value>(constantOrValue).getType());
186 assert(elemWidth >= 0 &&
"unknown bit width");
// The mask is a uint32_t, so the consumed bit range must stay below 32.
187 assert(elemWidth + unknownPos < 32 &&
"unknown bit width too large");
// Slice `elemWidth` bits out of the mask, starting at `unknownPos`.
189 uint32_t usedBits = (mask >> unknownPos) & ((1 << elemWidth) - 1);
190 result.insertBits(APInt(elemWidth, usedBits), bitPos);
191 unknownPos += elemWidth;
// Fragment of emulateBinaryOpForUnknownBits: lowers a single-result binary
// op by enumerating every assignment of its unknown operand bits.
// Visible steps:
//   1. Assert the op has exactly one result and two operands.
//   2. Check the unknown-bit counts (negative counts, or a total above
//      `maxEmulationUnknownBits`, are special-cased).
//   3. For every (lhsMask, rhsMask) pair, evaluate `emulate` on concrete
//      APInts and record the result as a pooled constant.
//   4. Gather selector values from rhsValues, then lhsValues, and select
//      among the results with constructMuxTree.
// NOTE(review): what is returned on the early-exit checks, how lhsValue /
// rhsValue are derived from the masks, and the final replacement are not
// visible in this listing.
203 ConversionPatternRewriter &rewriter, int64_t maxEmulationUnknownBits,
205 llvm::function_ref<APInt(
const APInt &,
const APInt &)> emulate) {
206 SmallVector<ConstantOrValue> lhsValues, rhsValues;
208 assert(op->getNumResults() == 1 && op->getNumOperands() == 2 &&
209 "op must be a single result binary operation");
211 auto lhs = op->getOperand(0);
212 auto rhs = op->getOperand(1);
213 auto width = op->getResult(0).getType().getIntOrFloatBitWidth();
214 auto loc = op->getLoc();
219 if (numLhsUnknownBits < 0 || numRhsUnknownBits < 0)
222 int64_t totalUnknownBits = numLhsUnknownBits + numRhsUnknownBits;
223 if (totalUnknownBits > maxEmulationUnknownBits)
// One emulated result per combination of unknown bits.
226 SmallVector<Value> emulatedResults;
227 emulatedResults.reserve(1 << totalUnknownBits);
// Pool of constants keyed by attribute, looked up before a new entry is
// stored.
230 DenseMap<IntegerAttr, hw::ConstantOp> constantPool;
232 auto attr = rewriter.getIntegerAttr(rewriter.getIntegerType(width), value);
233 auto it = constantPool.find(attr);
234 if (it != constantPool.end())
237 constantPool[attr] = constant;
// Outer loop over lhs masks, inner loop over rhs masks.
241 for (uint32_t lhsMask = 0, lhsMaskEnd = 1 << numLhsUnknownBits;
242 lhsMask < lhsMaskEnd; ++lhsMask) {
244 for (uint32_t rhsMask = 0, rhsMaskEnd = 1 << numRhsUnknownBits;
245 rhsMask < rhsMaskEnd; ++rhsMask) {
248 emulatedResults.push_back(
getConstant(emulate(lhsValue, rhsValue)));
// Selectors are gathered from the rhs pieces first, then the lhs pieces.
253 SmallVector<Value> selectors;
254 selectors.reserve(totalUnknownBits);
255 for (
auto &concatedValues : {rhsValues, lhsValues})
256 for (
auto valueOrConstant : concatedValues) {
257 auto value = dyn_cast<Value>(valueOrConstant);
263 assert(totalUnknownBits ==
static_cast<int64_t
>(selectors.size()) &&
264 "number of selectors must match");
265 auto muxed = constructMuxTree(rewriter, loc, selectors, emulatedResults,
// Rewrites a comb AndOp into a synth::aig::AndInverterOp over the same
// (already converted) inputs, with every inversion flag set to false.
283 matchAndRewrite(
AndOp op, OpAdaptor adaptor,
284 ConversionPatternRewriter &rewriter)
const override {
285 SmallVector<bool> nonInverts(adaptor.getInputs().size(),
false);
286 replaceOpWithNewOpAndCopyNamehint<synth::aig::AndInverterOp>(
287 rewriter, op, adaptor.getInputs(), nonInverts);
// Rewrites a comb OrOp in terms of synth::aig::AndInverterOp: first an
// and-inverter over all inputs with every input inverted, then a second
// AndInverterOp that replaces the original op.
// NOTE(review): the arguments of the replacing op are not visible;
// presumably it inverts `andOp` (De Morgan) - confirm.
297 matchAndRewrite(
OrOp op, OpAdaptor adaptor,
298 ConversionPatternRewriter &rewriter)
const override {
300 SmallVector<bool> allInverts(adaptor.getInputs().size(),
true);
301 auto andOp = synth::aig::AndInverterOp::create(
302 rewriter, op.getLoc(), adaptor.getInputs(), allInverts);
303 replaceOpWithNewOpAndCopyNamehint<synth::aig::AndInverterOp>(
// Rewrites a comb XorOp into a synth::XorInverterOp over the same inputs,
// with every inversion flag set to false.
314 matchAndRewrite(
XorOp op, OpAdaptor adaptor,
315 ConversionPatternRewriter &rewriter)
const override {
316 SmallVector<bool> inverted(adaptor.getInputs().size(),
false);
317 replaceOpWithNewOpAndCopyNamehint<synth::XorInverterOp>(
318 rewriter, op, adaptor.getInputs(), inverted);
// Lowers a two-operand synth::XorInverterOp to and-inverter ops.
// Two intermediate AndInverterOps are created over the inputs - one using
// the complement of the op's inversion flags (`allInverts`), one using the
// op's own flags (`allNotInverts`) - and a final AndInverterOp over those
// two results replaces the op.
// NOTE(review): the handling of operand counts other than 2, the flags of
// `notAAndNotB`, and the inversion flags of the final op are not visible.
324struct SynthXorInverterOpConversion
329 matchAndRewrite(synth::XorInverterOp op, OpAdaptor adaptor,
330 ConversionPatternRewriter &rewriter)
const override {
331 if (op.getNumOperands() != 2)
337 auto inputs = adaptor.getInputs();
338 auto allNotInverts = op.getInverted();
339 std::array<bool, 2> allInverts = {!allNotInverts[0], !allNotInverts[1]};
341 auto notAAndNotB = synth::aig::AndInverterOp::create(rewriter, op.getLoc(),
343 auto aAndB = synth::aig::AndInverterOp::create(rewriter, op.getLoc(),
344 inputs, allNotInverts);
346 replaceOpWithNewOpAndCopyNamehint<synth::aig::AndInverterOp>(
347 rewriter, op, notAAndNotB, aAndB,
// Rewrites a comb MuxOp into a synth::MuxInverterOp.
// Visible steps: non-integer result types get an integer type of the same
// bit width; when the true value is wider than one bit, the condition is
// rebuilt with comb::ReplicateOp using the true value's type; a result whose
// type differs from the op's type is special-cased afterwards.
// NOTE(review): the casts applied for non-integer types and the final
// replacement are not visible in this listing.
359 matchAndRewrite(
MuxOp op, OpAdaptor adaptor,
360 ConversionPatternRewriter &rewriter)
const override {
361 Value cond = adaptor.getCond();
362 Value trueVal = adaptor.getTrueValue();
363 Value falseVal = adaptor.getFalseValue();
365 if (!op.getType().isInteger()) {
366 auto widthType = rewriter.getIntegerType(hw::getBitWidth(op.getType()));
373 if (!trueVal.getType().isInteger(1))
374 cond = comb::ReplicateOp::create(rewriter, op.getLoc(), trueVal.getType(),
377 Value result = synth::MuxInverterOp::create(rewriter, op.getLoc(), cond,
380 if (result.getType() != op.getType())
// Lowers synth::MuxInverterOp to and-inverter ops.
//   lhs = AND(inputs[0], inputs[1]) with inversion flags inverted[0], inverted[1]
//   rhs = AND(inputs[0], inputs[2]) with inversion flags !inverted[0], inverted[2]
// The two are combined by a further AndInverterOp (`nand`), and a final
// AndInverterOp replaces the original op.
// NOTE(review): the inversion flags of `nand` and of the final op are not
// visible in this listing.
390struct SynthMuxInverterOpConversion
395 matchAndRewrite(synth::MuxInverterOp op, OpAdaptor adaptor,
396 ConversionPatternRewriter &rewriter)
const override {
397 auto inputs = adaptor.getInputs();
398 auto inverted = op.getInverted();
400 auto lhs = synth::aig::AndInverterOp::create(
401 rewriter, op.getLoc(), inputs[0], inputs[1], inverted[0], inverted[1]);
403 auto rhs = synth::aig::AndInverterOp::create(
404 rewriter, op.getLoc(), inputs[0], inputs[2], !inverted[0], inverted[2]);
406 auto nand = synth::aig::AndInverterOp::create(rewriter, op.getLoc(), lhs,
408 replaceOpWithNewOpAndCopyNamehint<synth::aig::AndInverterOp>(rewriter, op,
// Pattern template over an op type `OpTy`. The visible helper switches on
// the operand count: an empty range is unreachable, and two-operand OpTy
// instances are created from `lhs`/`rhs`; for larger ranges the operands are
// split at operands.size() / 2.
// NOTE(review): how `lhs`/`rhs` are obtained for each case (presumably by
// recursing on the two halves) is not visible here. The trailing `true`
// arguments are presumably the two-state flag - confirm.
414template <
typename OpTy>
419 matchAndRewrite(OpTy op, OpAdaptor adaptor,
420 ConversionPatternRewriter &rewriter)
const override {
427 ConversionPatternRewriter &rewriter) {
429 switch (operands.size()) {
431 llvm_unreachable(
"cannot be called with empty operand range");
438 return OpTy::create(rewriter, op.getLoc(), ValueRange{lhs, rhs},
true);
440 auto firstHalf = operands.size() / 2;
445 return OpTy::create(rewriter, op.getLoc(), ValueRange{lhs, rhs},
true);
// Adder/prefix-tree implementation choices used by the add and compare
// lowerings in this file.
454enum AdderArchitecture { RippleCarry, Sklanskey, KoggeStone, BrentKung };
// Picks the adder architecture for `op` at the given bit width.
// The "synth.test.arch" string attribute on the op is mapped by name
// ("SKLANSKEY", "KOGGE-STONE", "BRENT-KUNG", "RIPPLE-CARRY") to an enum value.
// Three further returns yield RippleCarry, Sklanskey or KoggeStone.
// NOTE(review): the guard on the attribute's presence and the conditions
// selecting among the three fallback returns (presumably width-based) are
// not visible in this listing.
455AdderArchitecture determineAdderArch(Operation *op, int64_t width) {
456 auto strAttr = op->getAttrOfType<StringAttr>(
"synth.test.arch");
458 return llvm::StringSwitch<AdderArchitecture>(strAttr.getValue())
459 .Case(
"SKLANSKEY", Sklanskey)
460 .Case(
"KOGGE-STONE", KoggeStone)
461 .Case(
"BRENT-KUNG", BrentKung)
462 .Case(
"RIPPLE-CARRY", RippleCarry);
472 return AdderArchitecture::RippleCarry;
477 return AdderArchitecture::Sklanskey;
481 return AdderArchitecture::KoggeStone;
// Builds a Kogge-Stone prefix tree in place over the propagate (pPrefix) and
// generate (gPrefix) vectors, which must have equal size.
// For each stride 1, 2, 4, ... below the width and each i >= stride, with
// j = i - stride:
//   G'[i] = G[i] OR (P[i] AND G[j])
//   P'[i] = P[i] AND P[j]
// The updated vectors are then copied back into pPrefix/gPrefix. A second
// pair of loops with the same bounds prints the equations to llvm::dbgs().
// NOTE(review): several lines (closing braces, the debug macro, the `stage`
// counter) are not visible in this listing.
491void lowerKoggeStonePrefixTree(OpBuilder &builder, Location loc,
492 SmallVector<Value> &pPrefix,
493 SmallVector<Value> &gPrefix) {
495 auto width =
static_cast<int64_t
>(pPrefix.size());
496 assert(width ==
static_cast<int64_t
>(gPrefix.size()));
// Scratch copies: new values are written here while the loop still reads
// the previous values from pPrefix/gPrefix.
497 SmallVector<Value> pPrefixNew = pPrefix;
498 SmallVector<Value> gPrefixNew = gPrefix;
501 for (int64_t stride = 1; stride < width; stride *= 2) {
503 for (int64_t i = stride; i < width; ++i) {
504 int64_t j = i - stride;
507 Value andPG = comb::AndOp::create(builder, loc, pPrefix[i], gPrefix[j]);
508 gPrefixNew[i] = comb::OrOp::create(builder, loc, gPrefix[i], andPG);
511 pPrefixNew[i] = comb::AndOp::create(builder, loc, pPrefix[i], pPrefix[j]);
514 pPrefix = pPrefixNew;
515 gPrefix = gPrefixNew;
// Debug trace mirroring the loop structure above.
520 for (int64_t stride = 1; stride < width; stride *= 2) {
522 <<
"--------------------------------------- Kogge-Stone Stage "
524 for (int64_t i = stride; i < width; ++i) {
525 int64_t j = i - stride;
527 llvm::dbgs() <<
"G" << i << stage + 1 <<
" = G" << i << stage
528 <<
" OR (P" << i << stage <<
" AND G" << j << stage
532 llvm::dbgs() <<
"P" << i << stage + 1 <<
" = P" << i << stage
533 <<
" AND P" << j << stage <<
"\n";
// Builds a Sklanskey prefix tree in place over the propagate (pPrefix) and
// generate (gPrefix) vectors, which must have equal size.
// For each stride 1, 2, 4, ... below the width, blocks start at i = stride
// and advance by 2 * stride; within a block, up to `stride` positions k are
// updated (bounded by i + k < width) using:
//   G'[idx] = G[idx] OR (P[idx] AND G[j])
//   P'[idx] = P[idx] AND P[j]
// The updated vectors are then copied back. A second loop nest with the same
// bounds prints the equations to llvm::dbgs().
// NOTE(review): the definitions of `idx` and `j` are not visible in this
// listing (presumably idx = i + k and j = i - 1 - confirm).
542void lowerSklanskeyPrefixTree(OpBuilder &builder, Location loc,
543 SmallVector<Value> &pPrefix,
544 SmallVector<Value> &gPrefix) {
545 auto width =
static_cast<int64_t
>(pPrefix.size());
546 assert(width ==
static_cast<int64_t
>(gPrefix.size()));
547 SmallVector<Value> pPrefixNew = pPrefix;
548 SmallVector<Value> gPrefixNew = gPrefix;
549 for (int64_t stride = 1; stride < width; stride *= 2) {
550 for (int64_t i = stride; i < width; i += 2 * stride) {
551 for (int64_t k = 0; k < stride && i + k < width; ++k) {
557 comb::AndOp::create(builder, loc, pPrefix[idx], gPrefix[j]);
558 gPrefixNew[idx] = comb::OrOp::create(builder, loc, gPrefix[idx], andPG);
562 comb::AndOp::create(builder, loc, pPrefix[idx], pPrefix[j]);
566 pPrefix = pPrefixNew;
567 gPrefix = gPrefixNew;
// Debug trace mirroring the loop structure above.
572 for (int64_t stride = 1; stride < width; stride *= 2) {
573 llvm::dbgs() <<
"--------------------------------------- Sklanskey Stage "
575 for (int64_t i = stride; i < width; i += 2 * stride) {
576 for (int64_t k = 0; k < stride && i + k < width; ++k) {
580 llvm::dbgs() <<
"G" << idx << stage + 1 <<
" = G" << idx << stage
581 <<
" OR (P" << idx << stage <<
" AND G" << j << stage
585 llvm::dbgs() <<
"P" << idx << stage + 1 <<
" = P" << idx << stage
586 <<
" AND P" << j << stage <<
"\n";
// Builds a Brent-Kung prefix tree in place over the propagate (pPrefix) and
// generate (gPrefix) vectors, which must have equal size.
// Forward pass: stride doubles from 1; positions i = 2*stride - 1, stepping
// by 2*stride, are combined with j = i - stride.
// Backward pass: stride halves down to 1; positions i = 3*stride - 1,
// stepping by 2*stride, are combined with j = i - stride.
// Both passes use:
//   G'[i] = G[i] OR (P[i] AND G[j])
//   P'[i] = P[i] AND P[j]
// and copy the updated vectors back. The second half repeats both loop nests
// to print the equations to llvm::dbgs().
// NOTE(review): the declaration of `stride`, any adjustment of it between
// the two passes, and the `stage` counter are not visible in this listing.
597void lowerBrentKungPrefixTree(OpBuilder &builder, Location loc,
598 SmallVector<Value> &pPrefix,
599 SmallVector<Value> &gPrefix) {
600 auto width =
static_cast<int64_t
>(pPrefix.size());
601 assert(width ==
static_cast<int64_t
>(gPrefix.size()));
602 SmallVector<Value> pPrefixNew = pPrefix;
603 SmallVector<Value> gPrefixNew = gPrefix;
// Forward pass.
607 for (stride = 1; stride < width; stride *= 2) {
608 for (int64_t i = stride * 2 - 1; i < width; i += stride * 2) {
609 int64_t j = i - stride;
612 Value andPG = comb::AndOp::create(builder, loc, pPrefix[i], gPrefix[j]);
613 gPrefixNew[i] = comb::OrOp::create(builder, loc, gPrefix[i], andPG);
616 pPrefixNew[i] = comb::AndOp::create(builder, loc, pPrefix[i], pPrefix[j]);
618 pPrefix = pPrefixNew;
619 gPrefix = gPrefixNew;
// Backward pass: continues from the stride left by the forward pass.
623 for (; stride > 0; stride /= 2) {
624 for (int64_t i = stride * 3 - 1; i < width; i += stride * 2) {
625 int64_t j = i - stride;
628 Value andPG = comb::AndOp::create(builder, loc, pPrefix[i], gPrefix[j]);
629 gPrefixNew[i] = comb::OrOp::create(builder, loc, gPrefix[i], andPG);
632 pPrefixNew[i] = comb::AndOp::create(builder, loc, pPrefix[i], pPrefix[j]);
634 pPrefix = pPrefixNew;
635 gPrefix = gPrefixNew;
// Debug trace of the forward pass.
640 for (stride = 1; stride < width; stride *= 2) {
641 llvm::dbgs() <<
"--------------------------------------- Brent-Kung FW "
642 << stage <<
" : Stride " << stride <<
"\n";
643 for (int64_t i = stride * 2 - 1; i < width; i += stride * 2) {
644 int64_t j = i - stride;
647 llvm::dbgs() <<
"G" << i << stage + 1 <<
" = G" << i << stage
648 <<
" OR (P" << i << stage <<
" AND G" << j << stage
652 llvm::dbgs() <<
"P" << i << stage + 1 <<
" = P" << i << stage
653 <<
" AND P" << j << stage <<
"\n";
// Debug trace of the backward pass; the header is printed only when the
// stride yields at least one position.
658 for (; stride > 0; stride /= 2) {
659 if (stride * 3 - 1 < width)
660 llvm::dbgs() <<
"--------------------------------------- Brent-Kung BW "
661 << stage <<
" : Stride " << stride <<
"\n";
663 for (int64_t i = stride * 3 - 1; i < width; i += stride * 2) {
664 int64_t j = i - stride;
667 llvm::dbgs() <<
"G" << i << stage + 1 <<
" = G" << i << stage
668 <<
" OR (P" << i << stage <<
" AND G" << j << stage
672 llvm::dbgs() <<
"P" << i << stage + 1 <<
" = P" << i << stage
673 <<
" AND P" << j << stage <<
"\n";
// Kogge-Stone prefix tree that creates ops on demand and memoizes them.
// The cache is keyed by (level, bit index) and stores a pair of values; the
// constructor seeds level 0 with {pPrefix[i], gPrefix[i]} for every bit.
681class LazyKoggeStonePrefixTree {
683 LazyKoggeStonePrefixTree(OpBuilder &builder, Location loc, int64_t width,
684 ArrayRef<Value> pPrefix, ArrayRef<Value> gPrefix)
685 : builder(builder), loc(loc), width(width) {
686 assert(width > 0 &&
"width must be positive");
687 for (int64_t i = 0; i < width; ++i)
688 prefixCache[{0, i}] = {pPrefix[i], gPrefix[i]};
// Returns the pair for bit `i` at level ceil(log2(width)).
692 std::pair<Value, Value> getFinal(int64_t i) {
693 assert(i >= 0 && i < width &&
"i out of bounds");
695 return getGroupAndPropagate(llvm::Log2_64_Ceil(width), i);
// Memoized lookup/construction of the pair for (level, i).
703 std::pair<Value, Value> getGroupAndPropagate(int64_t level, int64_t i);
// (level, bit index) -> pair of values, stored propagate-first.
707 DenseMap<std::pair<int64_t, int64_t>, std::pair<Value, Value>> prefixCache;
// Returns the (propagate, generate) pair for bit `i` at `level`, building
// and caching it on first request.
// With previousStride = 2^(level-1):
//   - if i < previousStride, the pair from level-1 is reused unchanged;
//   - otherwise, with j = i - previousStride, the level-1 pairs of i and j
//     are combined as
//       G = G_i OR (P_i AND G_j),  P = P_i AND P_j.
// NOTE(review): the statement executed on a cache hit is not visible;
// presumably it returns the cached pair - confirm.
710std::pair<Value, Value>
711LazyKoggeStonePrefixTree::getGroupAndPropagate(int64_t level, int64_t i) {
712 assert(i < width &&
"i out of bounds");
713 auto key = std::make_pair(level, i);
714 auto it = prefixCache.find(key);
715 if (it != prefixCache.end())
718 assert(level > 0 &&
"If the level is 0, we should have hit the cache");
720 int64_t previousStride = 1ULL << (level - 1);
721 if (i < previousStride) {
723 auto [propagateI, generateI] = getGroupAndPropagate(level - 1, i);
724 prefixCache[key] = {propagateI, generateI};
725 return prefixCache[key];
728 int64_t j = i - previousStride;
729 auto [propagateI, generateI] = getGroupAndPropagate(level - 1, i);
730 auto [propagateJ, generateJ] = getGroupAndPropagate(level - 1, j);
732 Value andPG = comb::AndOp::create(builder, loc, propagateI, generateJ);
733 Value newGenerate = comb::OrOp::create(builder, loc, generateI, andPG);
736 comb::AndOp::create(builder, loc, propagateI, propagateJ);
737 prefixCache[key] = {newPropagate, newGenerate};
738 return prefixCache[key];
// Entry point for lowering comb AddOp. Input counts other than two are
// special-cased, as is a case replaced directly by an hw::ConstantOp;
// otherwise the architecture chosen by determineAdderArch decides between
// the ripple-carry and the parallel-prefix lowering.
// NOTE(review): the result for non-binary adds and the condition guarding
// the constant replacement are not visible in this listing.
745 matchAndRewrite(
AddOp op, OpAdaptor adaptor,
746 ConversionPatternRewriter &rewriter)
const override {
747 auto inputs = adaptor.getInputs();
750 if (inputs.size() != 2)
753 auto width = op.getType().getIntOrFloatBitWidth();
756 replaceOpWithNewOpAndCopyNamehint<hw::ConstantOp>(rewriter, op,
762 auto arch = determineAdderArch(op, width);
763 if (arch == AdderArchitecture::RippleCarry)
764 return lowerRippleCarryAdder(op, inputs, rewriter);
765 return lowerParallelPrefixAdder(op, inputs, rewriter);
// Ripple-carry lowering of a two-input comb AddOp.
// For each bit i, the sum is an XorOp over aBits[i], bBits[i] and (when
// appended) the carry; it is stored at results[width - i - 1]. The op is
// finally replaced by a comb::ConcatOp of `results`.
// NOTE(review): the definitions of aBits/bBits/carry, the condition under
// which the carry is appended, and the general carry-update formula are not
// visible; only a carry of the form a[i] AND b[i] is shown below.
770 lowerRippleCarryAdder(
comb::AddOp op, ValueRange inputs,
771 ConversionPatternRewriter &rewriter)
const {
772 auto width = op.getType().getIntOrFloatBitWidth();
778 SmallVector<Value> results;
779 results.resize(width);
780 for (int64_t i = 0; i < width; ++i) {
781 SmallVector<Value> xorOperands = {aBits[i], bBits[i]};
783 xorOperands.push_back(carry);
// Sum bit i is written at the mirrored index width - i - 1.
787 results[width - i - 1] =
788 comb::XorOp::create(rewriter, op.getLoc(), xorOperands,
true);
797 carry = comb::AndOp::create(rewriter, op.getLoc(),
798 ValueRange{aBits[i], bBits[i]},
true);
805 LLVM_DEBUG(llvm::dbgs() <<
"Lower comb.add to Ripple-Carry Adder of width "
808 replaceOpWithNewOpAndCopyNamehint<comb::ConcatOp>(rewriter, op, results);
// Parallel-prefix lowering of a two-input comb AddOp.
// Visible steps:
//   1. Per bit: p = a XOR b, g = a AND b.
//   2. Run the prefix tree selected by determineAdderArch over copies of
//      p/g (Sklanskey, Kogge-Stone or Brent-Kung; ripple-carry is
//      unreachable here).
//   3. Result bit 0 is p[0]; result bit i (i >= 1) is p[i] XOR gPrefix[i-1].
//      Bits are stored at results[width - 1 - i] and concatenated.
// NOTE(review): the definitions of aBits/bBits and the debug-macro wrappers
// around the llvm::dbgs() output are not visible in this listing.
816 lowerParallelPrefixAdder(
comb::AddOp op, ValueRange inputs,
817 ConversionPatternRewriter &rewriter)
const {
818 auto width = op.getType().getIntOrFloatBitWidth();
824 SmallVector<Value> p, g;
828 for (
auto [aBit, bBit] :
llvm::zip(aBits, bBits)) {
830 p.push_back(comb::XorOp::create(rewriter, op.getLoc(), aBit, bBit));
832 g.push_back(comb::AndOp::create(rewriter, op.getLoc(), aBit, bBit));
836 llvm::dbgs() <<
"Lower comb.add to Parallel-Prefix of width " << width
837 <<
"\n--------------------------------------- Init\n";
839 for (int64_t i = 0; i < width; ++i) {
841 llvm::dbgs() <<
"P0" << i <<
" = A" << i <<
" XOR B" << i <<
"\n";
843 llvm::dbgs() <<
"G0" << i <<
" = A" << i <<
" AND B" << i <<
"\n";
// Working copies: the prefix-tree helpers update these in place, while the
// original `p` is still needed for the final sum bits.
848 SmallVector<Value> pPrefix = p;
849 SmallVector<Value> gPrefix = g;
852 auto arch = determineAdderArch(op, width);
855 case AdderArchitecture::RippleCarry:
856 llvm_unreachable(
"Ripple-Carry should be handled separately");
858 case AdderArchitecture::Sklanskey:
859 lowerSklanskeyPrefixTree(rewriter, op.getLoc(), pPrefix, gPrefix);
861 case AdderArchitecture::KoggeStone:
862 lowerKoggeStonePrefixTree(rewriter, op.getLoc(), pPrefix, gPrefix);
864 case AdderArchitecture::BrentKung:
865 lowerBrentKungPrefixTree(rewriter, op.getLoc(), pPrefix, gPrefix);
871 SmallVector<Value> results;
872 results.resize(width);
// Bit 0 has no incoming carry.
874 results[width - 1] = p[0];
878 for (int64_t i = 1; i < width; ++i)
879 results[width - 1 - i] =
880 comb::XorOp::create(rewriter, op.getLoc(), p[i], gPrefix[i - 1]);
882 replaceOpWithNewOpAndCopyNamehint<comb::ConcatOp>(rewriter, op, results);
885 llvm::dbgs() <<
"--------------------------------------- Completion\n"
887 for (int64_t i = 1; i < width; ++i)
888 llvm::dbgs() <<
"RES" << i <<
" = P" << i <<
" XOR G" << i - 1 <<
"\n";
// Lowers a two-input comb MulOp through partial products.
// Row i consists of i leading `falseValue` entries followed by
// aBits[j] AND bBits[i] for every j with i + j < width. The rows are reduced
// with compressToHeight(rewriter, 2) and the remaining addends are summed by
// a comb AddOp.
// NOTE(review): the handling of non-binary multiplies, the definition of
// `falseValue` and `comp`, the condition guarding the direct replacement by
// partialProducts[0][0], and the final replacement are not visible here.
899 matchAndRewrite(
MulOp op, OpAdaptor adaptor,
900 ConversionPatternRewriter &rewriter)
const override {
901 if (adaptor.getInputs().size() != 2)
904 Location loc = op.getLoc();
905 Value
a = adaptor.getInputs()[0];
906 Value
b = adaptor.getInputs()[1];
907 unsigned width = op.getType().getIntOrFloatBitWidth();
916 SmallVector<Value> aBits =
extractBits(rewriter, a);
917 SmallVector<Value> bBits =
extractBits(rewriter, b);
922 SmallVector<SmallVector<Value>> partialProducts;
923 partialProducts.reserve(width);
924 for (
unsigned i = 0; i < width; ++i) {
// Row i starts with i padding entries.
925 SmallVector<Value> row(i, falseValue);
928 for (
unsigned j = 0; i + j < width; ++j)
930 rewriter.createOrFold<
comb::AndOp>(loc, aBits[j], bBits[i]));
932 partialProducts.push_back(row);
937 rewriter.replaceOp(op, partialProducts[0][0]);
943 auto addends = comp.compressToHeight(rewriter, 2);
946 auto newAdd = comb::AddOp::create(rewriter, loc, addends,
true);
// Common base for the division/modulo patterns. Stores the limit on the
// number of unknown bits the emulation may enumerate; the constructor
// asserts that the limit is below 32.
952template <
typename OpTy>
954 DivModOpConversionBase(MLIRContext *
context, int64_t maxEmulationUnknownBits)
956 maxEmulationUnknownBits(maxEmulationUnknownBits) {
957 assert(maxEmulationUnknownBits < 32 &&
958 "maxEmulationUnknownBits must be less than 32");
960 const int64_t maxEmulationUnknownBits;
// Lowers comb DivUOp. comb::convertDivUByPowerOfTwo is tried first; the
// fallback passes a callback to the unknown-bit emulation that yields either
// zero or lhs.udiv(rhs).
// NOTE(review): the condition selecting the zero result (presumably a zero
// divisor) and the statement after a successful power-of-two rewrite are
// not visible in this listing.
963struct CombDivUOpConversion : DivModOpConversionBase<DivUOp> {
964 using DivModOpConversionBase<
DivUOp>::DivModOpConversionBase;
966 matchAndRewrite(
DivUOp op, OpAdaptor adaptor,
967 ConversionPatternRewriter &rewriter)
const override {
969 if (llvm::succeeded(comb::convertDivUByPowerOfTwo(op, rewriter)))
975 rewriter, maxEmulationUnknownBits, op,
976 [](
const APInt &lhs,
const APInt &rhs) {
979 return APInt::getZero(rhs.getBitWidth());
980 return lhs.udiv(rhs);
// Lowers comb ModUOp. comb::convertModUByPowerOfTwo is tried first; the
// fallback passes a callback to the unknown-bit emulation that yields either
// zero or lhs.urem(rhs).
// NOTE(review): the condition selecting the zero result (presumably a zero
// divisor) and the statement after a successful power-of-two rewrite are
// not visible in this listing.
985struct CombModUOpConversion : DivModOpConversionBase<ModUOp> {
986 using DivModOpConversionBase<
ModUOp>::DivModOpConversionBase;
988 matchAndRewrite(
ModUOp op, OpAdaptor adaptor,
989 ConversionPatternRewriter &rewriter)
const override {
991 if (llvm::succeeded(comb::convertModUByPowerOfTwo(op, rewriter)))
997 rewriter, maxEmulationUnknownBits, op,
998 [](
const APInt &lhs,
const APInt &rhs) {
1001 return APInt::getZero(rhs.getBitWidth());
1002 return lhs.urem(rhs);
// Lowers comb DivSOp by passing a callback to the unknown-bit emulation that
// yields either zero or lhs.sdiv(rhs).
// NOTE(review): the condition selecting the zero result (presumably a zero
// divisor) is not visible in this listing.
1007struct CombDivSOpConversion : DivModOpConversionBase<DivSOp> {
1008 using DivModOpConversionBase<
DivSOp>::DivModOpConversionBase;
1011 matchAndRewrite(
DivSOp op, OpAdaptor adaptor,
1012 ConversionPatternRewriter &rewriter)
const override {
1016 rewriter, maxEmulationUnknownBits, op,
1017 [](
const APInt &lhs,
const APInt &rhs) {
1020 return APInt::getZero(rhs.getBitWidth());
1021 return lhs.sdiv(rhs);
// Lowers comb ModSOp by passing a callback to the unknown-bit emulation that
// yields either zero or lhs.srem(rhs).
// NOTE(review): the condition selecting the zero result (presumably a zero
// divisor) is not visible in this listing.
1026struct CombModSOpConversion : DivModOpConversionBase<ModSOp> {
1027 using DivModOpConversionBase<
ModSOp>::DivModOpConversionBase;
1029 matchAndRewrite(
ModSOp op, OpAdaptor adaptor,
1030 ConversionPatternRewriter &rewriter)
const override {
1034 rewriter, maxEmulationUnknownBits, op,
1035 [](
const APInt &lhs,
const APInt &rhs) {
1038 return APInt::getZero(rhs.getBitWidth());
1039 return lhs.srem(rhs);
// Bit-serial comparison of `a` and `b`. For each bit pair the loop computes
//   aEqualB = NOT (aBit XOR bBit)
//   pred    = (NOT aBit) AND bBit
//   acc     = pred OR (aEqualB AND acc)
// so that `acc` accumulates the comparison result across the bits.
// NOTE(review): the initial value of `acc` (presumably derived from the
// `includeEq` argument callers pass), the definitions of aBits/bBits, and
// the returned value are not visible in this listing.
1048 static Value constructRippleCarry(Location loc, Value a, Value b,
1050 ConversionPatternRewriter &rewriter) {
1058 for (
auto [aBit, bBit] :
llvm::zip(aBits, bBits)) {
1060 rewriter.createOrFold<
comb::XorOp>(loc, aBit, bBit,
true);
// Single-input and-inverter with the input inverted, i.e. a NOT.
1061 auto aEqualB = rewriter.createOrFold<synth::aig::AndInverterOp>(
1062 loc, aBitXorBBit,
true);
// First input inverted, second not: (NOT aBit) AND bBit.
1063 auto pred = rewriter.createOrFold<synth::aig::AndInverterOp>(
1064 loc, aBit, bBit,
true,
false);
1066 auto aBitAndBBit = rewriter.createOrFold<
comb::AndOp>(
1067 loc, ValueRange{aEqualB,
acc},
true);
1068 acc = rewriter.createOrFold<
comb::OrOp>(loc, pred, aBitAndBBit,
true);
// Reduces per-bit propagate/generate vectors to a single comparison value
// using the prefix tree selected by `arch`.
// Sklanskey and Brent-Kung run the full in-place tree and read the last
// element of each vector; Kogge-Stone uses LazyKoggeStonePrefixTree and asks
// only for the final position (width - 1). Ripple-carry is unreachable here.
// One visible return yields finalGroup OR finalPropagate.
// NOTE(review): the condition guarding that return (presumably `includeEq`)
// and the alternative return are not visible in this listing.
1081 static Value computePrefixComparison(ConversionPatternRewriter &rewriter,
1082 Location loc, SmallVector<Value> pPrefix,
1083 SmallVector<Value> gPrefix,
1084 bool includeEq, AdderArchitecture arch) {
1085 auto width = pPrefix.size();
1086 Value finalGroup, finalPropagate;
1089 case AdderArchitecture::RippleCarry:
1090 llvm_unreachable(
"Ripple-Carry should be handled separately");
1092 case AdderArchitecture::Sklanskey: {
1093 lowerSklanskeyPrefixTree(rewriter, loc, pPrefix, gPrefix);
1094 finalGroup = gPrefix[width - 1];
1095 finalPropagate = pPrefix[width - 1];
1098 case AdderArchitecture::KoggeStone:
// getFinal returns the pair propagate-first.
1101 std::tie(finalPropagate, finalGroup) =
1102 LazyKoggeStonePrefixTree(rewriter, loc, width, pPrefix, gPrefix)
1103 .getFinal(width - 1);
1105 case AdderArchitecture::BrentKung: {
1106 lowerBrentKungPrefixTree(rewriter, loc, pPrefix, gPrefix);
1107 finalGroup = gPrefix[width - 1];
1108 finalPropagate = pPrefix[width - 1];
1117 return comb::OrOp::create(rewriter, loc, finalGroup, finalPropagate);
// Builds an unsigned comparison of `a` and `b`.
// The ripple-carry form is used when determineAdderArch selects it;
// otherwise per-bit vectors are built and handed to computePrefixComparison:
//   eq[i] = (aBit XOR bBit) XOR one
//   gt[i] = (aBit XOR one) AND bBit
// NOTE(review): the use of `isLess` (presumably an operand swap), the
// definitions of aBits/bBits/`one`, and the remaining arguments of the final
// call are not visible in this listing.
1126 static Value constructUnsignedCompare(Operation *op, Location loc, Value a,
1127 Value b,
bool isLess,
bool includeEq,
1128 ConversionPatternRewriter &rewriter) {
1132 auto width =
a.getType().getIntOrFloatBitWidth();
1135 auto arch = determineAdderArch(op, width);
1136 if (arch == AdderArchitecture::RippleCarry)
1137 return constructRippleCarry(loc, a, b, includeEq, rewriter);
1148 SmallVector<Value> eq, gt;
1155 for (
auto [aBit, bBit] :
llvm::zip(aBits, bBits)) {
1157 auto xorBit = comb::XorOp::create(rewriter, loc, aBit, bBit);
1158 eq.push_back(comb::XorOp::create(rewriter, loc, xorBit, one));
1161 auto notA = comb::XorOp::create(rewriter, loc, aBit, one);
1162 gt.push_back(comb::AndOp::create(rewriter, loc, notA, bBit));
1165 return computePrefixComparison(rewriter, loc, std::move(eq), std::move(gt),
// Lowers comb ICmpOp by predicate.
//   eq / ceq : XOR the operands, then an and-inverter over the XOR bits
//              with every bit inverted (true when all XOR bits are zero).
//   ne / cne : XOR the operands, then an OrOp over the XOR bits.
//   unsigned : delegate to constructUnsignedCompare.
//   signed   : reject zero-width operands; otherwise the same-sign case
//              uses an unsigned compare of aRest/bRest, the differing-sign
//              case picks signA (for less-than forms) or signB, and a MuxOp
//              on `signsDiffer` selects between them.
// NOTE(review): the definitions of xorBits, signA/signB, aRest/bRest and
// signsDiffer, and any predicates handled before `eq`, are not visible in
// this listing.
1170 matchAndRewrite(ICmpOp op, OpAdaptor adaptor,
1171 ConversionPatternRewriter &rewriter)
const override {
1172 auto lhs = adaptor.getLhs();
1173 auto rhs = adaptor.getRhs();
1175 switch (op.getPredicate()) {
1179 case ICmpPredicate::eq:
1180 case ICmpPredicate::ceq: {
1182 auto xorOp = rewriter.createOrFold<
comb::XorOp>(op.getLoc(), lhs, rhs);
1184 SmallVector<bool> allInverts(xorBits.size(),
true);
1185 replaceOpWithNewOpAndCopyNamehint<synth::aig::AndInverterOp>(
1186 rewriter, op, xorBits, allInverts);
1190 case ICmpPredicate::ne:
1191 case ICmpPredicate::cne: {
1193 auto xorOp = rewriter.createOrFold<
comb::XorOp>(op.getLoc(), lhs, rhs);
1194 replaceOpWithNewOpAndCopyNamehint<comb::OrOp>(
1195 rewriter, op,
extractBits(rewriter, xorOp),
true);
1199 case ICmpPredicate::uge:
1200 case ICmpPredicate::ugt:
1201 case ICmpPredicate::ule:
1202 case ICmpPredicate::ult: {
1203 bool isLess = op.getPredicate() == ICmpPredicate::ult ||
1204 op.getPredicate() == ICmpPredicate::ule;
1205 bool includeEq = op.getPredicate() == ICmpPredicate::uge ||
1206 op.getPredicate() == ICmpPredicate::ule;
1208 constructUnsignedCompare(op, op.getLoc(), lhs,
1209 rhs, isLess, includeEq,
1213 case ICmpPredicate::slt:
1214 case ICmpPredicate::sle:
1215 case ICmpPredicate::sgt:
1216 case ICmpPredicate::sge: {
1217 if (lhs.getType().getIntOrFloatBitWidth() == 0)
1218 return rewriter.notifyMatchFailure(
1219 op.getLoc(),
"i0 signed comparison is unsupported");
1220 bool isLess = op.getPredicate() == ICmpPredicate::slt ||
1221 op.getPredicate() == ICmpPredicate::sle;
1222 bool includeEq = op.getPredicate() == ICmpPredicate::sge ||
1223 op.getPredicate() == ICmpPredicate::sle;
// Same sign: compare the non-sign bits as unsigned numbers.
1232 auto sameSignResult = constructUnsignedCompare(
1233 op, op.getLoc(), aRest, bRest, isLess, includeEq, rewriter);
1237 comb::XorOp::create(rewriter, op.getLoc(), signA, signB);
// Different signs: the sign bit of one operand decides the result.
1240 Value diffSignResult = isLess ? signA : signB;
1243 replaceOpWithNewOpAndCopyNamehint<comb::MuxOp>(
1244 rewriter, op, signsDiffer, diffSignResult, sameSignResult);
// Rewrites a comb ParityOp as a comb XorOp over the individual bits of its
// input.
// NOTE(review): the trailing `true` argument is presumably the two-state
// flag - confirm.
1255 matchAndRewrite(
ParityOp op, OpAdaptor adaptor,
1256 ConversionPatternRewriter &rewriter)
const override {
1258 replaceOpWithNewOpAndCopyNamehint<comb::XorOp>(
1259 rewriter, op,
extractBits(rewriter, adaptor.getInput()),
true);
// Lowers comb ShlOp via a shift helper call that receives the shift amount
// (rhs), the op width as the maximum shift amount, and two callbacks. The
// first visible callback builds something of integer type with `index` bits
// and value 0; the second asserts index < width.
// NOTE(review): the callee name (presumably createShiftLogic), the rest of
// both callbacks, and the final replacement are not visible in this listing.
1268 matchAndRewrite(
comb::ShlOp op, OpAdaptor adaptor,
1269 ConversionPatternRewriter &rewriter)
const override {
1270 auto width = op.getType().getIntOrFloatBitWidth();
1271 auto lhs = adaptor.getLhs();
1273 rewriter, op.getLoc(), adaptor.getRhs(), width,
1275 [&](int64_t index) {
1281 op.getLoc(), rewriter.getIntegerType(index), 0);
1284 [&](int64_t index) {
1285 assert(index < width &&
"index out of bounds");
// Fragment of a shift pattern's matchAndRewrite (presumably the logical
// right shift - the op type is not visible here). A shift helper call
// receives the shift amount (rhs), the op width as the maximum shift amount,
// and two callbacks: the first builds something of integer type with `index`
// bits and value 0, the second asserts index < width and extracts from `lhs`
// starting at `index`.
// NOTE(review): the extract length and the final replacement are not
// visible in this listing.
1301 ConversionPatternRewriter &rewriter)
const override {
1302 auto width = op.getType().getIntOrFloatBitWidth();
1303 auto lhs = adaptor.getLhs();
1305 rewriter, op.getLoc(), adaptor.getRhs(), width,
1307 [&](int64_t index) {
1313 op.getLoc(), rewriter.getIntegerType(index), 0);
1316 [&](int64_t index) {
1317 assert(index < width &&
"index out of bounds");
1319 return rewriter.createOrFold<
comb::ExtractOp>(op.getLoc(), lhs, index,
// Fragment of the signed (arithmetic) right-shift pattern's matchAndRewrite.
// Zero-width shifts are rejected with a match failure. The sign bit is
// extracted from `lhs` at index width - 1, and a shift helper call receives
// the shift amount (rhs), width - 1 as the maximum shift amount, and two
// callbacks: padding is built by replicating the sign bit, and the extract
// callback reads from `lhs` starting at `index`.
// NOTE(review): the condition guarding the match failure, the binding of
// `sign`, the replicate/extract widths, and the final replacement are not
// visible in this listing.
1333 ConversionPatternRewriter &rewriter)
const override {
1334 auto width = op.getType().getIntOrFloatBitWidth();
1336 return rewriter.notifyMatchFailure(op.getLoc(),
1337 "i0 signed shift is unsupported");
1338 auto lhs = adaptor.getLhs();
1341 rewriter.createOrFold<
comb::ExtractOp>(op.getLoc(), lhs, width - 1, 1);
1346 rewriter, op.getLoc(), adaptor.getRhs(), width - 1,
1348 [&](int64_t index) {
1349 return rewriter.createOrFold<comb::ReplicateOp>(op.getLoc(), sign,
1353 [&](int64_t index) {
1354 return rewriter.createOrFold<
comb::ExtractOp>(op.getLoc(), lhs, index,
// Pass class built on the generated ConvertCombToSynthBase; it inherits the
// base constructors and overrides runOnOperation.
1370struct ConvertCombToSynthPass
1371 :
public impl::ConvertCombToSynthBase<ConvertCombToSynthPass> {
1372 void runOnOperation()
override;
1373 using ConvertCombToSynthBase<ConvertCombToSynthPass>::ConvertCombToSynthBase;
// Fragment of populateCombToAIGConversionPatterns: registers the conversion
// patterns defined in this file, comb::convertSubToAdd, and the variadic
// and-inverter / xor-inverter lowering patterns from the synth dialect.
// The division/modulo patterns additionally receive
// `maxEmulationUnknownBits`.
// NOTE(review): the conditions under which the individual groups are added
// (presumably depending on `forceAIG`) are not visible in this listing.
1379 uint32_t maxEmulationUnknownBits,
1383 CombAndOpConversion, CombParityOpConversion, CombXorOpToSynthConversion,
1384 CombMuxOpToSynthConversion,
1386 CombMulOpConversion, CombICmpOpConversion,
1388 CombShlOpConversion, CombShrUOpConversion, CombShrSOpConversion,
1390 CombLowerVariadicOp<AddOp>, CombLowerVariadicOp<MulOp>>(
1394 patterns.add<SynthXorInverterOpConversion, SynthMuxInverterOpConversion>(
1397 patterns.add(comb::convertSubToAdd);
1399 patterns.add<CombOrToAIGConversion, CombAddOpConversion>(
1401 synth::populateVariadicAndInverterLoweringPatterns(
patterns);
1404 synth::populateVariadicXorInverterLoweringPatterns(
patterns);
1407 patterns.add<CombDivUOpConversion, CombModUOpConversion, CombDivSOpConversion,
1408 CombModSOpConversion>(
patterns.getContext(),
1409 maxEmulationUnknownBits);
// Pass driver: configures the conversion target and applies the patterns as
// a partial conversion, signalling pass failure if the conversion fails.
// Visible target setup: the Comb dialect is illegal, the Synth dialect is
// legal, synth::XorInverterOp and synth::MuxInverterOp are marked illegal,
// and every name in `additionalLegalOps` is marked legal.
// NOTE(review): the list of ops ending in hw::AggregateConstantOp, any
// condition guarding the addIllegalOp call, and the pattern population call
// are not visible in this listing.
1412void ConvertCombToSynthPass::runOnOperation() {
1413 ConversionTarget target(getContext());
1416 target.addIllegalDialect<comb::CombDialect>();
1426 hw::AggregateConstantOp>();
1428 target.addLegalDialect<synth::SynthDialect>();
1430 target.addIllegalOp<synth::XorInverterOp, synth::MuxInverterOp>();
1433 if (!additionalLegalOps.empty())
1434 for (
const auto &opName : additionalLegalOps)
1435 target.addLegalOp(OperationName(opName, &getContext()));
1437 RewritePatternSet
patterns(&getContext());
1441 if (failed(mlir::applyPartialConversion(getOperation(), target,
1443 return signalPassFailure();
assert(baseType &&"element must be base type")
static SmallVector< Value > extractBits(OpBuilder &builder, Value val)
static Value createShiftLogic(ConversionPatternRewriter &rewriter, Location loc, Value shiftAmount, int64_t maxShiftAmount, llvm::function_ref< Value(int64_t)> getPadding, llvm::function_ref< Value(int64_t)> getExtract)
static APInt substitueMaskToValues(size_t width, llvm::SmallVectorImpl< ConstantOrValue > &constantOrValues, uint32_t mask)
static LogicalResult emulateBinaryOpForUnknownBits(ConversionPatternRewriter &rewriter, int64_t maxEmulationUnknownBits, Operation *op, llvm::function_ref< APInt(const APInt &, const APInt &)> emulate)
static int64_t getNumUnknownBitsAndPopulateValues(Value value, llvm::SmallVectorImpl< ConstantOrValue > &values)
static Value createMajorityFunction(OpBuilder &rewriter, Location loc, Value a, Value b, Value carry)
static Value extractOtherThanMSB(OpBuilder &builder, Value val)
static Value extractMSB(OpBuilder &builder, Value val)
static void populateCombToAIGConversionPatterns(RewritePatternSet &patterns, uint32_t maxEmulationUnknownBits, bool forceAIG)
static std::unique_ptr< Context > context
static std::optional< APSInt > getConstant(Attribute operand)
Determine the value of a constant operand for the sake of constant folding.
static Value lowerFullyAssociativeOp(Operation &op, OperandRange operands, SmallVector< Operation * > &newOps)
Lower a variadic fully-associative operation into an expression tree.
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
void replaceOpAndCopyNamehint(PatternRewriter &rewriter, Operation *op, Value newValue)
A wrapper of PatternRewriter::replaceOp to propagate "sv.namehint" attribute.