33#include "mlir/Pass/Pass.h"
34#include "mlir/Transforms/DialectConversion.h"
35#include "llvm/ADT/APInt.h"
36#include "llvm/ADT/PointerUnion.h"
37#include "llvm/Support/Debug.h"
40#define DEBUG_TYPE "comb-to-synth"
43#define GEN_PASS_DEF_CONVERTCOMBTOSYNTH
44#include "circt/Conversion/Passes.h.inc"
// Split `val` into its individual i1 bits by delegating to the shared
// comb::extractBits helper, collecting them into a fresh vector.
// NOTE(review): the tail of this function (presumably `return bits;`) is
// elided in this view — confirm against the full source.
55static SmallVector<Value>
extractBits(OpBuilder &builder, Value val) {
56 SmallVector<Value> bits;
57 comb::extractBits(builder, val, bits);
// Build shift logic shared by shl/shru/shrs lowerings. For each possible
// shift amount 0..maxShiftAmount-1 a candidate result is formed from
// getExtract(i) (the kept bits) and getPadding(i) (the fill bits); a mux
// tree indexed by the shift amount selects among them, and an out-of-range
// shift amount selects getPadding(maxShiftAmount).
// NOTE(review): several interior lines (candidate concatenation, `bits`
// extraction from shiftAmount, the ult bound constant) are elided here —
// verify against the full source.
68template <
bool isLeftShift>
70 Value shiftAmount, int64_t maxShiftAmount,
71 llvm::function_ref<Value(int64_t)> getPadding,
72 llvm::function_ref<Value(int64_t)> getExtract) {
// One candidate result per in-range shift amount.
77 SmallVector<Value> nodes;
78 nodes.reserve(maxShiftAmount);
79 for (int64_t i = 0; i < maxShiftAmount; ++i) {
80 Value extract = getExtract(i);
81 Value padding = getPadding(i);
84 nodes.push_back(extract);
// Result used when shiftAmount >= maxShiftAmount (all padding).
98 auto outOfBoundsValue = getPadding(maxShiftAmount);
99 assert(outOfBoundsValue &&
"outOfBoundsValue must be valid");
// Select among the candidates with a mux tree keyed on the shift bits.
103 comb::constructMuxTree(rewriter, loc, bits, nodes, outOfBoundsValue);
// Guard: unsigned compare keeps the mux-tree result only when in range.
106 auto inBound = rewriter.createOrFold<comb::ICmpOp>(
107 loc, ICmpPredicate::ult, shiftAmount,
111 return rewriter.createOrFold<
comb::MuxOp>(loc, inBound, result,
// Carry-out of a full adder: majority(a, b, carry). When lowering to MIG a
// single 3-input majority op suffices; otherwise it is expanded into
// comb ops as (carry AND (a XNOR b)) OR (a AND b).
// NOTE(review): the head of the signature (return type, rewriter/loc/a
// params) and the MIG op's trailing arguments are elided in this view.
119 Value b, Value carry,
120 bool useMajorityInverterOp) {
121 if (useMajorityInverterOp) {
// No input is inverted: a plain 3-input majority.
122 std::array<Value, 3> inputs = {a, b, carry};
123 std::array<bool, 3> inverts = {
false,
false,
false};
124 return synth::mig::MajorityInverterOp::create(rewriter, loc, inputs,
// AIG-friendly expansion. The trailing `true` arguments are the two-state
// flags on the comb ops.
129 auto aXnorB = comb::XorOp::create(rewriter, loc, ValueRange{a, b},
true);
131 comb::AndOp::create(rewriter, loc, ValueRange{carry, aXnorB},
true);
132 auto aAndB = comb::AndOp::create(rewriter, loc, ValueRange{a, b},
true);
133 return comb::OrOp::create(rewriter, loc, ValueRange{andOp, aAndB},
true);
138 val.getLoc(), val, val.getType().getIntOrFloatBitWidth() - 1, 1);
143 val.getLoc(), val, 0, val.getType().getIntOrFloatBitWidth() - 1);
148using ConstantOrValue = llvm::PointerUnion<Value, mlir::IntegerAttr>;
// Decompose `value` into a list of constant attributes and opaque ("unknown")
// SSA values, returning the total number of unknown bits, or a negative value
// on failure (callers check `< 0`). Concats are recursed into in reverse
// (LSB-first) order; hw.constant contributes its attribute; anything else is
// recorded as an unknown value of its full bit width.
// NOTE(review): the dyn_cast dispatch lines for concat/constant are elided
// in this view — confirm the exact matching logic against the full source.
153 Value value, llvm::SmallVectorImpl<ConstantOrValue> &values) {
// i0 values contribute nothing.
155 if (value.getType().isInteger(0))
160 int64_t totalUnknownBits = 0;
161 for (
auto concatInput : llvm::reverse(concat.getInputs())) {
166 totalUnknownBits += unknownBits;
168 return totalUnknownBits;
173 values.push_back(constant.getValueAttr());
// Fallback: treat the whole value as unknown bits.
179 values.push_back(value);
180 return hw::getBitWidth(value.getType());
// Materialize one concrete APInt of `width` bits from the mixed
// constant/unknown decomposition: constant segments are copied through, and
// each unknown segment takes its bits from `mask` (consumed LSB-first via
// `unknownPos`). With < 32 total unknown bits (asserted), iterating `mask`
// over all combinations enumerates every possible concrete value.
186 llvm::SmallVectorImpl<ConstantOrValue> &constantOrValues,
188 uint32_t bitPos = 0, unknownPos = 0;
189 APInt result(width, 0);
190 for (
auto constantOrValue : constantOrValues) {
192 if (
auto constant = dyn_cast<IntegerAttr>(constantOrValue)) {
// Known segment: splice the constant's bits in place.
193 elemWidth = constant.getValue().getBitWidth();
194 result.insertBits(constant.getValue(), bitPos);
196 elemWidth = hw::getBitWidth(cast<Value>(constantOrValue).getType());
197 assert(elemWidth >= 0 &&
"unknown bit width");
198 assert(elemWidth + unknownPos < 32 &&
"unknown bit width too large");
// Unknown segment: take the next `elemWidth` bits of the mask.
200 uint32_t usedBits = (mask >> unknownPos) & ((1 << elemWidth) - 1);
201 result.insertBits(APInt(elemWidth, usedBits), bitPos);
202 unknownPos += elemWidth;
// Lower a 2-operand, 1-result op whose operands are mostly constant by
// exhaustively evaluating `emulate` over every assignment of the unknown
// bits (bounded by maxEmulationUnknownBits) and selecting the matching
// precomputed constant with a mux tree keyed on the unknown-bit values.
// Fails (match failure) when operands cannot be decomposed or the unknown
// bit budget is exceeded.
// NOTE(review): several lines are elided in this view (the calls that
// populate lhsValues/rhsValues, constant creation, lhs/rhs substitution
// inside the double loop, and the final replaceOp) — confirm against the
// full source.
214 ConversionPatternRewriter &rewriter, int64_t maxEmulationUnknownBits,
216 llvm::function_ref<APInt(
const APInt &,
const APInt &)> emulate) {
217 SmallVector<ConstantOrValue> lhsValues, rhsValues;
219 assert(op->getNumResults() == 1 && op->getNumOperands() == 2 &&
220 "op must be a single result binary operation");
222 auto lhs = op->getOperand(0);
223 auto rhs = op->getOperand(1);
224 auto width = op->getResult(0).getType().getIntOrFloatBitWidth();
225 auto loc = op->getLoc();
// Negative count means the operand could not be decomposed.
230 if (numLhsUnknownBits < 0 || numRhsUnknownBits < 0)
233 int64_t totalUnknownBits = numLhsUnknownBits + numRhsUnknownBits;
234 if (totalUnknownBits > maxEmulationUnknownBits)
// One emulated result per assignment of the unknown bits.
237 SmallVector<Value> emulatedResults;
238 emulatedResults.reserve(1 << totalUnknownBits);
// Deduplicate identical result constants across mask combinations.
241 DenseMap<IntegerAttr, hw::ConstantOp> constantPool;
243 auto attr = rewriter.getIntegerAttr(rewriter.getIntegerType(width), value);
244 auto it = constantPool.find(attr);
245 if (it != constantPool.end())
248 constantPool[attr] = constant;
// Enumerate every combination of unknown bits on both sides.
252 for (uint32_t lhsMask = 0, lhsMaskEnd = 1 << numLhsUnknownBits;
253 lhsMask < lhsMaskEnd; ++lhsMask) {
255 for (uint32_t rhsMask = 0, rhsMaskEnd = 1 << numRhsUnknownBits;
256 rhsMask < rhsMaskEnd; ++rhsMask) {
259 emulatedResults.push_back(
getConstant(emulate(lhsValue, rhsValue)));
// Mux selectors are the unknown SSA values, rhs first to match the
// mask enumeration order above.
264 SmallVector<Value> selectors;
265 selectors.reserve(totalUnknownBits);
266 for (
auto &concatedValues : {rhsValues, lhsValues})
267 for (
auto valueOrConstant : concatedValues) {
268 auto value = dyn_cast<Value>(valueOrConstant);
274 assert(totalUnknownBits ==
static_cast<int64_t
>(selectors.size()) &&
275 "number of selectors must match");
276 auto muxed = constructMuxTree(rewriter, loc, selectors, emulatedResults,
294 matchAndRewrite(
AndOp op, OpAdaptor adaptor,
295 ConversionPatternRewriter &rewriter)
const override {
296 SmallVector<bool> nonInverts(adaptor.getInputs().size(),
false);
297 replaceOpWithNewOpAndCopyNamehint<synth::aig::AndInverterOp>(
298 rewriter, op, adaptor.getInputs(), nonInverts);
308 matchAndRewrite(
OrOp op, OpAdaptor adaptor,
309 ConversionPatternRewriter &rewriter)
const override {
311 SmallVector<bool> allInverts(adaptor.getInputs().size(),
true);
312 auto andOp = synth::aig::AndInverterOp::create(
313 rewriter, op.getLoc(), adaptor.getInputs(), allInverts);
314 replaceOpWithNewOpAndCopyNamehint<synth::aig::AndInverterOp>(
324 matchAndRewrite(
OrOp op, OpAdaptor adaptor,
325 ConversionPatternRewriter &rewriter)
const override {
326 if (op.getNumOperands() != 2)
328 SmallVector<Value, 3> inputs(adaptor.getInputs());
330 rewriter, op.getLoc(),
331 APInt::getAllOnes(hw::getBitWidth(op.getType())));
332 inputs.push_back(one);
333 std::array<bool, 3> inverts = {
false,
false,
false};
334 replaceOpWithNewOpAndCopyNamehint<synth::mig::MajorityInverterOp>(
335 rewriter, op, inputs, inverts);
340struct AndInverterToMIGConversion
344 matchAndRewrite(synth::aig::AndInverterOp op, OpAdaptor adaptor,
345 ConversionPatternRewriter &rewriter)
const override {
346 if (op.getNumOperands() > 2)
348 if (op.getNumOperands() == 1) {
349 SmallVector<bool, 1> inverts{op.getInverted()[0]};
350 replaceOpWithNewOpAndCopyNamehint<synth::mig::MajorityInverterOp>(
351 rewriter, op, adaptor.getInputs(), inverts);
354 SmallVector<Value, 3> inputs(adaptor.getInputs());
356 rewriter, op.getLoc(), APInt::getZero(hw::getBitWidth(op.getType())));
357 inputs.push_back(one);
358 SmallVector<bool, 3> inverts(adaptor.getInverted());
359 inverts.push_back(
false);
360 replaceOpWithNewOpAndCopyNamehint<synth::mig::MajorityInverterOp>(
361 rewriter, op, inputs, inverts);
366struct MajorityInverterToAIGConversion
371 matchAndRewrite(synth::mig::MajorityInverterOp op, OpAdaptor adaptor,
372 ConversionPatternRewriter &rewriter)
const override {
374 if (op.getNumOperands() > 3)
377 if (op.getNumOperands() == 1) {
378 bool inverts[1] = {op.getInverted()[0]};
379 replaceOpWithNewOpAndCopyNamehint<synth::aig::AndInverterOp>(
380 rewriter, op, adaptor.getInputs(), inverts);
384 assert(op.getNumOperands() == 3 &&
"Expected 3 operands for majority op");
385 auto getProduct = [&](
unsigned idx1,
unsigned idx2) {
386 bool inverts[2] = {op.getInverted()[idx1], op.getInverted()[idx2]};
387 return synth::aig::AndInverterOp::create(
388 rewriter, op.getLoc(),
389 ValueRange{adaptor.getInputs()[idx1], adaptor.getInputs()[idx2]},
395 Value products[3] = {getProduct(0, 1), getProduct(0, 2), getProduct(1, 2)};
396 bool allInverted[3] = {
true,
true,
true};
397 auto notOr = synth::aig::AndInverterOp::create(rewriter, op.getLoc(),
398 products, allInverted);
399 replaceOpWithNewOpAndCopyNamehint<synth::aig::AndInverterOp>(
411 matchAndRewrite(
XorOp op, OpAdaptor adaptor,
412 ConversionPatternRewriter &rewriter)
const override {
413 if (op.getNumOperands() != 2)
419 auto inputs = adaptor.getInputs();
420 SmallVector<bool> allInverts(inputs.size(),
true);
421 SmallVector<bool> allNotInverts(inputs.size(),
false);
423 auto notAAndNotB = synth::aig::AndInverterOp::create(rewriter, op.getLoc(),
425 auto aAndB = synth::aig::AndInverterOp::create(rewriter, op.getLoc(),
426 inputs, allNotInverts);
428 replaceOpWithNewOpAndCopyNamehint<synth::aig::AndInverterOp>(
429 rewriter, op, notAAndNotB, aAndB,
436template <
typename OpTy>
441 matchAndRewrite(OpTy op, OpAdaptor adaptor,
442 ConversionPatternRewriter &rewriter)
const override {
449 ConversionPatternRewriter &rewriter) {
451 switch (operands.size()) {
453 llvm_unreachable(
"cannot be called with empty operand range");
460 return OpTy::create(rewriter, op.getLoc(), ValueRange{lhs, rhs},
true);
462 auto firstHalf = operands.size() / 2;
467 return OpTy::create(rewriter, op.getLoc(), ValueRange{lhs, rhs},
true);
477 matchAndRewrite(
MuxOp op, OpAdaptor adaptor,
478 ConversionPatternRewriter &rewriter)
const override {
479 Value cond = op.getCond();
480 auto trueVal = op.getTrueValue();
481 auto falseVal = op.getFalseValue();
483 if (!op.getType().isInteger()) {
485 auto widthType = rewriter.getIntegerType(hw::getBitWidth(op.getType()));
493 if (!trueVal.getType().isInteger(1))
494 cond = comb::ReplicateOp::create(rewriter, op.getLoc(), trueVal.getType(),
499 synth::aig::AndInverterOp::create(rewriter, op.getLoc(), cond, trueVal);
500 auto rhs = synth::aig::AndInverterOp::create(rewriter, op.getLoc(), cond,
501 falseVal,
true,
false);
503 Value result = comb::OrOp::create(rewriter, op.getLoc(), lhs, rhs);
505 if (result.getType() != op.getType())
// Supported adder/prefix-tree architectures for comb.add and comparison
// lowering.
517enum AdderArchitecture { RippleCarry, Sklanskey, KoggeStone, BrentKung };
// Pick an architecture: an explicit "synth.test.arch" attribute (used by
// tests) wins; otherwise fall back to a width-based heuristic.
// NOTE(review): the guard around strAttr, the StringSwitch default, and the
// width thresholds separating the fallback returns are elided in this view —
// confirm the cutoffs against the full source.
518AdderArchitecture determineAdderArch(Operation *op, int64_t width) {
519 auto strAttr = op->getAttrOfType<StringAttr>(
"synth.test.arch");
521 return llvm::StringSwitch<AdderArchitecture>(strAttr.getValue())
522 .Case(
"SKLANSKEY", Sklanskey)
523 .Case(
"KOGGE-STONE", KoggeStone)
524 .Case(
"BRENT-KUNG", BrentKung)
525 .Case(
"RIPPLE-CARRY", RippleCarry);
// Width-based heuristic (conditions elided): small widths use ripple-carry,
// mid widths Sklanskey, larger widths Kogge-Stone.
535 return AdderArchitecture::RippleCarry;
540 return AdderArchitecture::Sklanskey;
544 return AdderArchitecture::KoggeStone;
// Kogge-Stone parallel-prefix tree over propagate/generate vectors, updated
// in place. Per stage with doubling stride:
//   G[i] = G[i] OR (P[i] AND G[i-stride])
//   P[i] = P[i] AND P[i-stride]
// New values are staged in *New vectors so each stage reads the previous
// stage's results. The trailing loop is LLVM_DEBUG-only tracing that mirrors
// the structure of the builder loop.
554void lowerKoggeStonePrefixTree(OpBuilder &builder, Location loc,
555 SmallVector<Value> &pPrefix,
556 SmallVector<Value> &gPrefix) {
558 auto width =
static_cast<int64_t
>(pPrefix.size());
559 assert(width ==
static_cast<int64_t
>(gPrefix.size()));
// Double-buffer so updates within a stage don't feed each other.
560 SmallVector<Value> pPrefixNew = pPrefix;
561 SmallVector<Value> gPrefixNew = gPrefix;
564 for (int64_t stride = 1; stride < width; stride *= 2) {
566 for (int64_t i = stride; i < width; ++i) {
567 int64_t j = i - stride;
// Combine group/propagate of [j..] into position i.
570 Value andPG = comb::AndOp::create(builder, loc, pPrefix[i], gPrefix[j]);
571 gPrefixNew[i] = comb::OrOp::create(builder, loc, gPrefix[i], andPG);
574 pPrefixNew[i] = comb::AndOp::create(builder, loc, pPrefix[i], pPrefix[j]);
// Commit the stage.
577 pPrefix = pPrefixNew;
578 gPrefix = gPrefixNew;
// Debug-only trace of the same recurrence (no IR is built here).
583 for (int64_t stride = 1; stride < width; stride *= 2) {
585 <<
"--------------------------------------- Kogge-Stone Stage "
587 for (int64_t i = stride; i < width; ++i) {
588 int64_t j = i - stride;
590 llvm::dbgs() <<
"G" << i << stage + 1 <<
" = G" << i << stage
591 <<
" OR (P" << i << stage <<
" AND G" << j << stage
595 llvm::dbgs() <<
"P" << i << stage + 1 <<
" = P" << i << stage
596 <<
" AND P" << j << stage <<
"\n";
// Sklansky parallel-prefix tree over propagate/generate vectors, updated in
// place. Same G/P recurrence as Kogge-Stone but applied to blocks of size
// `stride` anchored at i, giving fewer cells at the cost of higher fanout.
// NOTE(review): the lines computing `idx` and `j` inside the inner loop are
// elided in this view — confirm their definitions against the full source.
// The trailing loop is LLVM_DEBUG-only tracing.
605void lowerSklanskeyPrefixTree(OpBuilder &builder, Location loc,
606 SmallVector<Value> &pPrefix,
607 SmallVector<Value> &gPrefix) {
608 auto width =
static_cast<int64_t
>(pPrefix.size());
609 assert(width ==
static_cast<int64_t
>(gPrefix.size()));
// Double-buffer so updates within a stage don't feed each other.
610 SmallVector<Value> pPrefixNew = pPrefix;
611 SmallVector<Value> gPrefixNew = gPrefix;
612 for (int64_t stride = 1; stride < width; stride *= 2) {
613 for (int64_t i = stride; i < width; i += 2 * stride) {
614 for (int64_t k = 0; k < stride && i + k < width; ++k) {
620 comb::AndOp::create(builder, loc, pPrefix[idx], gPrefix[j]);
621 gPrefixNew[idx] = comb::OrOp::create(builder, loc, gPrefix[idx], andPG);
625 comb::AndOp::create(builder, loc, pPrefix[idx], pPrefix[j]);
// Commit the stage.
629 pPrefix = pPrefixNew;
630 gPrefix = gPrefixNew;
// Debug-only trace of the same recurrence (no IR is built here).
635 for (int64_t stride = 1; stride < width; stride *= 2) {
636 llvm::dbgs() <<
"--------------------------------------- Sklanskey Stage "
638 for (int64_t i = stride; i < width; i += 2 * stride) {
639 for (int64_t k = 0; k < stride && i + k < width; ++k) {
643 llvm::dbgs() <<
"G" << idx << stage + 1 <<
" = G" << idx << stage
644 <<
" OR (P" << idx << stage <<
" AND G" << j << stage
648 llvm::dbgs() <<
"P" << idx << stage + 1 <<
" = P" << idx << stage
649 <<
" AND P" << j << stage <<
"\n";
// Brent-Kung parallel-prefix tree over propagate/generate vectors, updated
// in place: a forward (up-sweep) phase with doubling stride touching indices
// 2*stride-1, 4*stride-1, ... followed by a backward (down-sweep) phase with
// halving stride filling in indices 3*stride-1, 5*stride-1, ... Minimizes
// cell count at the cost of extra logic depth. The trailing loops are
// LLVM_DEBUG-only tracing mirroring both phases.
660void lowerBrentKungPrefixTree(OpBuilder &builder, Location loc,
661 SmallVector<Value> &pPrefix,
662 SmallVector<Value> &gPrefix) {
663 auto width =
static_cast<int64_t
>(pPrefix.size());
664 assert(width ==
static_cast<int64_t
>(gPrefix.size()));
// Double-buffer so updates within a phase don't feed each other.
665 SmallVector<Value> pPrefixNew = pPrefix;
666 SmallVector<Value> gPrefixNew = gPrefix;
// Forward phase: build prefixes at power-of-two boundaries.
// (`stride` is declared before this loop so the backward phase can reuse
// its final value; the declaration is elided in this view.)
670 for (stride = 1; stride < width; stride *= 2) {
671 for (int64_t i = stride * 2 - 1; i < width; i += stride * 2) {
672 int64_t j = i - stride;
675 Value andPG = comb::AndOp::create(builder, loc, pPrefix[i], gPrefix[j]);
676 gPrefixNew[i] = comb::OrOp::create(builder, loc, gPrefix[i], andPG);
679 pPrefixNew[i] = comb::AndOp::create(builder, loc, pPrefix[i], pPrefix[j]);
681 pPrefix = pPrefixNew;
682 gPrefix = gPrefixNew;
// Backward phase: fill in the intermediate positions.
686 for (; stride > 0; stride /= 2) {
687 for (int64_t i = stride * 3 - 1; i < width; i += stride * 2) {
688 int64_t j = i - stride;
691 Value andPG = comb::AndOp::create(builder, loc, pPrefix[i], gPrefix[j]);
692 gPrefixNew[i] = comb::OrOp::create(builder, loc, gPrefix[i], andPG);
695 pPrefixNew[i] = comb::AndOp::create(builder, loc, pPrefix[i], pPrefix[j]);
697 pPrefix = pPrefixNew;
698 gPrefix = gPrefixNew;
// Debug-only trace of the forward phase (no IR is built here).
703 for (stride = 1; stride < width; stride *= 2) {
704 llvm::dbgs() <<
"--------------------------------------- Brent-Kung FW "
705 << stage <<
" : Stride " << stride <<
"\n";
706 for (int64_t i = stride * 2 - 1; i < width; i += stride * 2) {
707 int64_t j = i - stride;
710 llvm::dbgs() <<
"G" << i << stage + 1 <<
" = G" << i << stage
711 <<
" OR (P" << i << stage <<
" AND G" << j << stage
715 llvm::dbgs() <<
"P" << i << stage + 1 <<
" = P" << i << stage
716 <<
" AND P" << j << stage <<
"\n";
// Debug-only trace of the backward phase.
721 for (; stride > 0; stride /= 2) {
722 if (stride * 3 - 1 < width)
723 llvm::dbgs() <<
"--------------------------------------- Brent-Kung BW "
724 << stage <<
" : Stride " << stride <<
"\n";
726 for (int64_t i = stride * 3 - 1; i < width; i += stride * 2) {
727 int64_t j = i - stride;
730 llvm::dbgs() <<
"G" << i << stage + 1 <<
" = G" << i << stage
731 <<
" OR (P" << i << stage <<
" AND G" << j << stage
735 llvm::dbgs() <<
"P" << i << stage + 1 <<
" = P" << i << stage
736 <<
" AND P" << j << stage <<
"\n";
// Demand-driven Kogge-Stone prefix tree: instead of materializing all
// (level, index) cells, nodes are built recursively on first use and
// memoized in `prefixCache`. Useful when only the final (MSB) group/
// propagate pair is needed, e.g. for comparisons.
744class LazyKoggeStonePrefixTree {
// Seed level 0 of the cache with the per-bit (propagate, generate) pairs.
746 LazyKoggeStonePrefixTree(OpBuilder &builder, Location loc, int64_t width,
747 ArrayRef<Value> pPrefix, ArrayRef<Value> gPrefix)
748 : builder(builder), loc(loc), width(width) {
749 assert(width > 0 &&
"width must be positive");
750 for (int64_t i = 0; i < width; ++i)
751 prefixCache[{0, i}] = {pPrefix[i], gPrefix[i]};
// Fully-combined (propagate, generate) pair for bit i: the top level of a
// Kogge-Stone tree is ceil(log2(width)).
755 std::pair<Value, Value> getFinal(int64_t i) {
756 assert(i >= 0 && i < width &&
"i out of bounds");
758 return getGroupAndPropagate(llvm::Log2_64_Ceil(width), i);
// Recursive worker; defined out of line below.
766 std::pair<Value, Value> getGroupAndPropagate(int64_t level, int64_t i);
// Memoized (propagate, generate) values keyed by (level, index).
770 DenseMap<std::pair<int64_t, int64_t>, std::pair<Value, Value>> prefixCache;
// Compute the (propagate, generate) pair for cell (level, i), building and
// caching only the cells actually reachable from the request. At each level
// cell i combines with cell i-2^(level-1); when i has no partner at this
// level its value simply passes through from the previous level.
773std::pair<Value, Value>
774LazyKoggeStonePrefixTree::getGroupAndPropagate(int64_t level, int64_t i) {
775 assert(i < width &&
"i out of bounds");
776 auto key = std::make_pair(level, i);
// Memoization hit: reuse the already-built node.
777 auto it = prefixCache.find(key);
778 if (it != prefixCache.end())
781 assert(level > 0 &&
"If the level is 0, we should have hit the cache");
783 int64_t previousStride = 1ULL << (level - 1);
// No partner below i at this stride: value passes through unchanged.
784 if (i < previousStride) {
786 auto [propagateI, generateI] = getGroupAndPropagate(level - 1, i);
787 prefixCache[key] = {propagateI, generateI};
788 return prefixCache[key];
// Combine with the partner cell at distance previousStride:
//   G = G_i OR (P_i AND G_j),  P = P_i AND P_j.
791 int64_t j = i - previousStride;
792 auto [propagateI, generateI] = getGroupAndPropagate(level - 1, i);
793 auto [propagateJ, generateJ] = getGroupAndPropagate(level - 1, j);
795 Value andPG = comb::AndOp::create(builder, loc, propagateI, generateJ);
796 Value newGenerate = comb::OrOp::create(builder, loc, generateI, andPG);
799 comb::AndOp::create(builder, loc, propagateI, propagateJ);
800 prefixCache[key] = {newPropagate, newGenerate};
801 return prefixCache[key];
804template <
bool lowerToMIG>
809 matchAndRewrite(
AddOp op, OpAdaptor adaptor,
810 ConversionPatternRewriter &rewriter)
const override {
811 auto inputs = adaptor.getInputs();
814 if (inputs.size() != 2)
817 auto width = op.getType().getIntOrFloatBitWidth();
820 replaceOpWithNewOpAndCopyNamehint<hw::ConstantOp>(rewriter, op,
826 auto arch = determineAdderArch(op, width);
827 if (arch == AdderArchitecture::RippleCarry)
828 return lowerRippleCarryAdder(op, inputs, rewriter);
829 return lowerParallelPrefixAdder(op, inputs, rewriter);
834 lowerRippleCarryAdder(
comb::AddOp op, ValueRange inputs,
835 ConversionPatternRewriter &rewriter)
const {
836 auto width = op.getType().getIntOrFloatBitWidth();
842 SmallVector<Value> results;
843 results.resize(width);
844 for (int64_t i = 0; i < width; ++i) {
845 SmallVector<Value> xorOperands = {aBits[i], bBits[i]};
847 xorOperands.push_back(carry);
851 results[width - i - 1] =
852 comb::XorOp::create(rewriter, op.getLoc(), xorOperands,
true);
861 carry = comb::AndOp::create(rewriter, op.getLoc(),
862 ValueRange{aBits[i], bBits[i]},
true);
869 LLVM_DEBUG(llvm::dbgs() <<
"Lower comb.add to Ripple-Carry Adder of width "
872 replaceOpWithNewOpAndCopyNamehint<comb::ConcatOp>(rewriter, op, results);
880 lowerParallelPrefixAdder(
comb::AddOp op, ValueRange inputs,
881 ConversionPatternRewriter &rewriter)
const {
882 auto width = op.getType().getIntOrFloatBitWidth();
888 SmallVector<Value> p, g;
892 for (
auto [aBit, bBit] :
llvm::zip(aBits, bBits)) {
894 p.push_back(comb::XorOp::create(rewriter, op.getLoc(), aBit, bBit));
896 g.push_back(comb::AndOp::create(rewriter, op.getLoc(), aBit, bBit));
900 llvm::dbgs() <<
"Lower comb.add to Parallel-Prefix of width " << width
901 <<
"\n--------------------------------------- Init\n";
903 for (int64_t i = 0; i < width; ++i) {
905 llvm::dbgs() <<
"P0" << i <<
" = A" << i <<
" XOR B" << i <<
"\n";
907 llvm::dbgs() <<
"G0" << i <<
" = A" << i <<
" AND B" << i <<
"\n";
912 SmallVector<Value> pPrefix = p;
913 SmallVector<Value> gPrefix = g;
916 auto arch = determineAdderArch(op, width);
919 case AdderArchitecture::RippleCarry:
920 llvm_unreachable(
"Ripple-Carry should be handled separately");
922 case AdderArchitecture::Sklanskey:
923 lowerSklanskeyPrefixTree(rewriter, op.getLoc(), pPrefix, gPrefix);
925 case AdderArchitecture::KoggeStone:
926 lowerKoggeStonePrefixTree(rewriter, op.getLoc(), pPrefix, gPrefix);
928 case AdderArchitecture::BrentKung:
929 lowerBrentKungPrefixTree(rewriter, op.getLoc(), pPrefix, gPrefix);
935 SmallVector<Value> results;
936 results.resize(width);
938 results[width - 1] = p[0];
942 for (int64_t i = 1; i < width; ++i)
943 results[width - 1 - i] =
944 comb::XorOp::create(rewriter, op.getLoc(), p[i], gPrefix[i - 1]);
946 replaceOpWithNewOpAndCopyNamehint<comb::ConcatOp>(rewriter, op, results);
949 llvm::dbgs() <<
"--------------------------------------- Completion\n"
951 for (int64_t i = 1; i < width; ++i)
952 llvm::dbgs() <<
"RES" << i <<
" = P" << i <<
" XOR G" << i - 1 <<
"\n";
963 matchAndRewrite(
MulOp op, OpAdaptor adaptor,
964 ConversionPatternRewriter &rewriter)
const override {
965 if (adaptor.getInputs().size() != 2)
968 Location loc = op.getLoc();
969 Value
a = adaptor.getInputs()[0];
970 Value
b = adaptor.getInputs()[1];
971 unsigned width = op.getType().getIntOrFloatBitWidth();
980 SmallVector<Value> aBits =
extractBits(rewriter, a);
981 SmallVector<Value> bBits =
extractBits(rewriter, b);
986 SmallVector<SmallVector<Value>> partialProducts;
987 partialProducts.reserve(width);
988 for (
unsigned i = 0; i < width; ++i) {
989 SmallVector<Value> row(i, falseValue);
992 for (
unsigned j = 0; i + j < width; ++j)
994 rewriter.createOrFold<
comb::AndOp>(loc, aBits[j], bBits[i]));
996 partialProducts.push_back(row);
1001 rewriter.replaceOp(op, partialProducts[0][0]);
1007 auto addends = comp.compressToHeight(rewriter, 2);
1010 auto newAdd = comb::AddOp::create(rewriter, loc, addends,
true);
1016template <
typename OpTy>
1018 DivModOpConversionBase(MLIRContext *
context, int64_t maxEmulationUnknownBits)
1020 maxEmulationUnknownBits(maxEmulationUnknownBits) {
1021 assert(maxEmulationUnknownBits < 32 &&
1022 "maxEmulationUnknownBits must be less than 32");
1024 const int64_t maxEmulationUnknownBits;
1027struct CombDivUOpConversion : DivModOpConversionBase<DivUOp> {
1028 using DivModOpConversionBase<
DivUOp>::DivModOpConversionBase;
1030 matchAndRewrite(
DivUOp op, OpAdaptor adaptor,
1031 ConversionPatternRewriter &rewriter)
const override {
1033 if (llvm::succeeded(comb::convertDivUByPowerOfTwo(op, rewriter)))
1039 rewriter, maxEmulationUnknownBits, op,
1040 [](
const APInt &lhs,
const APInt &rhs) {
1043 return APInt::getZero(rhs.getBitWidth());
1044 return lhs.udiv(rhs);
1049struct CombModUOpConversion : DivModOpConversionBase<ModUOp> {
1050 using DivModOpConversionBase<
ModUOp>::DivModOpConversionBase;
1052 matchAndRewrite(
ModUOp op, OpAdaptor adaptor,
1053 ConversionPatternRewriter &rewriter)
const override {
1055 if (llvm::succeeded(comb::convertModUByPowerOfTwo(op, rewriter)))
1061 rewriter, maxEmulationUnknownBits, op,
1062 [](
const APInt &lhs,
const APInt &rhs) {
1065 return APInt::getZero(rhs.getBitWidth());
1066 return lhs.urem(rhs);
1071struct CombDivSOpConversion : DivModOpConversionBase<DivSOp> {
1072 using DivModOpConversionBase<
DivSOp>::DivModOpConversionBase;
1075 matchAndRewrite(
DivSOp op, OpAdaptor adaptor,
1076 ConversionPatternRewriter &rewriter)
const override {
1080 rewriter, maxEmulationUnknownBits, op,
1081 [](
const APInt &lhs,
const APInt &rhs) {
1084 return APInt::getZero(rhs.getBitWidth());
1085 return lhs.sdiv(rhs);
1090struct CombModSOpConversion : DivModOpConversionBase<ModSOp> {
1091 using DivModOpConversionBase<
ModSOp>::DivModOpConversionBase;
1093 matchAndRewrite(
ModSOp op, OpAdaptor adaptor,
1094 ConversionPatternRewriter &rewriter)
const override {
1098 rewriter, maxEmulationUnknownBits, op,
1099 [](
const APInt &lhs,
const APInt &rhs) {
1102 return APInt::getZero(rhs.getBitWidth());
1103 return lhs.srem(rhs);
// Bit-serial (ripple) unsigned a < b (or a <= b with includeEq) comparator:
// walk the bits accumulating `acc` = result so far, where each bit position
// either decides the comparison (a_i=0, b_i=1 => less) or defers to the
// lower bits when a_i == b_i.
// NOTE(review): the initialization of `acc`/aBits/bBits and the final return
// are elided in this view — confirm bit order and seed value (presumably
// includeEq) against the full source.
1112 static Value constructRippleCarry(Location loc, Value a, Value b,
1114 ConversionPatternRewriter &rewriter) {
1122 for (
auto [aBit, bBit] :
llvm::zip(aBits, bBits)) {
// aEqualB = NOT(a_i XOR b_i); trailing bools are inversion flags.
1124 rewriter.createOrFold<
comb::XorOp>(loc, aBit, bBit,
true);
1125 auto aEqualB = rewriter.createOrFold<synth::aig::AndInverterOp>(
1126 loc, aBitXorBBit,
true);
// pred = (NOT a_i) AND b_i: this bit alone proves a < b.
1127 auto pred = rewriter.createOrFold<synth::aig::AndInverterOp>(
1128 loc, aBit, bBit,
true,
false);
// Carry the lower-bit verdict through only when this bit is equal.
1130 auto aBitAndBBit = rewriter.createOrFold<
comb::AndOp>(
1131 loc, ValueRange{aEqualB,
acc},
true);
1132 acc = rewriter.createOrFold<
comb::OrOp>(loc, pred, aBitAndBBit,
true);
// Reduce per-bit equal/greater vectors (pPrefix = "equal", gPrefix =
// "decides") with the selected parallel-prefix architecture and combine the
// final group/propagate into the comparison result. Kogge-Stone uses the
// lazy tree since only the final (MSB) node is needed.
// NOTE(review): the switch head, break statements, and the includeEq-only
// path before the final OR are elided in this view.
1145 static Value computePrefixComparison(ConversionPatternRewriter &rewriter,
1146 Location loc, SmallVector<Value> pPrefix,
1147 SmallVector<Value> gPrefix,
1148 bool includeEq, AdderArchitecture arch) {
1149 auto width = pPrefix.size();
1150 Value finalGroup, finalPropagate;
1153 case AdderArchitecture::RippleCarry:
1154 llvm_unreachable(
"Ripple-Carry should be handled separately");
1156 case AdderArchitecture::Sklanskey: {
1157 lowerSklanskeyPrefixTree(rewriter, loc, pPrefix, gPrefix);
1158 finalGroup = gPrefix[width - 1];
1159 finalPropagate = pPrefix[width - 1];
// Kogge-Stone: build only the final node on demand.
1162 case AdderArchitecture::KoggeStone:
1165 std::tie(finalPropagate, finalGroup) =
1166 LazyKoggeStonePrefixTree(rewriter, loc, width, pPrefix, gPrefix)
1167 .getFinal(width - 1);
1169 case AdderArchitecture::BrentKung: {
1170 lowerBrentKungPrefixTree(rewriter, loc, pPrefix, gPrefix);
1171 finalGroup = gPrefix[width - 1];
1172 finalPropagate = pPrefix[width - 1];
// includeEq: "decided less" OR "all bits equal".
1181 return comb::OrOp::create(rewriter, loc, finalGroup, finalPropagate);
// Unsigned comparison a <op> b: ripple-carry for the ripple architecture,
// otherwise build per-bit eq/gt vectors and reduce them with a prefix tree
// via computePrefixComparison. `isLess`/`includeEq` select among
// ult/ule/ugt/uge.
// NOTE(review): the isLess-driven operand swap, the `one` constant, and the
// aBits/bBits extraction are elided in this view — confirm against the full
// source.
1190 static Value constructUnsignedCompare(Operation *op, Location loc, Value a,
1191 Value b,
bool isLess,
bool includeEq,
1192 ConversionPatternRewriter &rewriter) {
1196 auto width =
a.getType().getIntOrFloatBitWidth();
1199 auto arch = determineAdderArch(op, width);
1200 if (arch == AdderArchitecture::RippleCarry)
1201 return constructRippleCarry(loc, a, b, includeEq, rewriter);
// eq[i] = (a_i == b_i), gt[i] = (NOT a_i) AND b_i (bit i decides "less").
1212 SmallVector<Value> eq, gt;
1219 for (
auto [aBit, bBit] :
llvm::zip(aBits, bBits)) {
1221 auto xorBit = comb::XorOp::create(rewriter, loc, aBit, bBit);
1222 eq.push_back(comb::XorOp::create(rewriter, loc, xorBit, one));
1225 auto notA = comb::XorOp::create(rewriter, loc, aBit, one);
1226 gt.push_back(comb::AndOp::create(rewriter, loc, notA, bBit));
1229 return computePrefixComparison(rewriter, loc, std::move(eq), std::move(gt),
1234 matchAndRewrite(ICmpOp op, OpAdaptor adaptor,
1235 ConversionPatternRewriter &rewriter)
const override {
1236 auto lhs = adaptor.getLhs();
1237 auto rhs = adaptor.getRhs();
1239 switch (op.getPredicate()) {
1243 case ICmpPredicate::eq:
1244 case ICmpPredicate::ceq: {
1246 auto xorOp = rewriter.createOrFold<
comb::XorOp>(op.getLoc(), lhs, rhs);
1248 SmallVector<bool> allInverts(xorBits.size(),
true);
1249 replaceOpWithNewOpAndCopyNamehint<synth::aig::AndInverterOp>(
1250 rewriter, op, xorBits, allInverts);
1254 case ICmpPredicate::ne:
1255 case ICmpPredicate::cne: {
1257 auto xorOp = rewriter.createOrFold<
comb::XorOp>(op.getLoc(), lhs, rhs);
1258 replaceOpWithNewOpAndCopyNamehint<comb::OrOp>(
1259 rewriter, op,
extractBits(rewriter, xorOp),
true);
1263 case ICmpPredicate::uge:
1264 case ICmpPredicate::ugt:
1265 case ICmpPredicate::ule:
1266 case ICmpPredicate::ult: {
1267 bool isLess = op.getPredicate() == ICmpPredicate::ult ||
1268 op.getPredicate() == ICmpPredicate::ule;
1269 bool includeEq = op.getPredicate() == ICmpPredicate::uge ||
1270 op.getPredicate() == ICmpPredicate::ule;
1272 constructUnsignedCompare(op, op.getLoc(), lhs,
1273 rhs, isLess, includeEq,
1277 case ICmpPredicate::slt:
1278 case ICmpPredicate::sle:
1279 case ICmpPredicate::sgt:
1280 case ICmpPredicate::sge: {
1281 if (lhs.getType().getIntOrFloatBitWidth() == 0)
1282 return rewriter.notifyMatchFailure(
1283 op.getLoc(),
"i0 signed comparison is unsupported");
1284 bool isLess = op.getPredicate() == ICmpPredicate::slt ||
1285 op.getPredicate() == ICmpPredicate::sle;
1286 bool includeEq = op.getPredicate() == ICmpPredicate::sge ||
1287 op.getPredicate() == ICmpPredicate::sle;
1296 auto sameSignResult = constructUnsignedCompare(
1297 op, op.getLoc(), aRest, bRest, isLess, includeEq, rewriter);
1301 comb::XorOp::create(rewriter, op.getLoc(), signA, signB);
1304 Value diffSignResult = isLess ? signA : signB;
1307 replaceOpWithNewOpAndCopyNamehint<comb::MuxOp>(
1308 rewriter, op, signsDiffer, diffSignResult, sameSignResult);
1319 matchAndRewrite(
ParityOp op, OpAdaptor adaptor,
1320 ConversionPatternRewriter &rewriter)
const override {
1322 replaceOpWithNewOpAndCopyNamehint<comb::XorOp>(
1323 rewriter, op,
extractBits(rewriter, adaptor.getInput()),
true);
1332 matchAndRewrite(
comb::ShlOp op, OpAdaptor adaptor,
1333 ConversionPatternRewriter &rewriter)
const override {
1334 auto width = op.getType().getIntOrFloatBitWidth();
1335 auto lhs = adaptor.getLhs();
1337 rewriter, op.getLoc(), adaptor.getRhs(), width,
1339 [&](int64_t index) {
1345 op.getLoc(), rewriter.getIntegerType(index), 0);
1348 [&](int64_t index) {
1349 assert(index < width &&
"index out of bounds");
1365 ConversionPatternRewriter &rewriter)
const override {
1366 auto width = op.getType().getIntOrFloatBitWidth();
1367 auto lhs = adaptor.getLhs();
1369 rewriter, op.getLoc(), adaptor.getRhs(), width,
1371 [&](int64_t index) {
1377 op.getLoc(), rewriter.getIntegerType(index), 0);
1380 [&](int64_t index) {
1381 assert(index < width &&
"index out of bounds");
1383 return rewriter.createOrFold<
comb::ExtractOp>(op.getLoc(), lhs, index,
1397 ConversionPatternRewriter &rewriter)
const override {
1398 auto width = op.getType().getIntOrFloatBitWidth();
1400 return rewriter.notifyMatchFailure(op.getLoc(),
1401 "i0 signed shift is unsupported");
1402 auto lhs = adaptor.getLhs();
1405 rewriter.createOrFold<
comb::ExtractOp>(op.getLoc(), lhs, width - 1, 1);
1410 rewriter, op.getLoc(), adaptor.getRhs(), width - 1,
1412 [&](int64_t index) {
1413 return rewriter.createOrFold<comb::ReplicateOp>(op.getLoc(), sign,
1417 [&](int64_t index) {
1418 return rewriter.createOrFold<
comb::ExtractOp>(op.getLoc(), lhs, index,
1434struct ConvertCombToSynthPass
1435 :
public impl::ConvertCombToSynthBase<ConvertCombToSynthPass> {
1436 void runOnOperation()
override;
1437 using ConvertCombToSynthBase<ConvertCombToSynthPass>::ConvertCombToSynthBase;
1443 uint32_t maxEmulationUnknownBits,
1447 CombAndOpConversion, CombXorOpConversion, CombMuxOpConversion,
1448 CombParityOpConversion,
1450 CombMulOpConversion, CombICmpOpConversion,
1452 CombShlOpConversion, CombShrUOpConversion, CombShrSOpConversion,
1454 CombLowerVariadicOp<XorOp>, CombLowerVariadicOp<AddOp>,
1455 CombLowerVariadicOp<MulOp>>(
patterns.getContext());
1457 patterns.add(comb::convertSubToAdd);
1460 patterns.add<CombOrToMIGConversion, CombLowerVariadicOp<OrOp>,
1461 AndInverterToMIGConversion,
1463 CombAddOpConversion<
true>>(
patterns.getContext());
1465 patterns.add<CombOrToAIGConversion, MajorityInverterToAIGConversion,
1466 CombAddOpConversion<
false>>(
patterns.getContext());
1470 patterns.add<CombDivUOpConversion, CombModUOpConversion, CombDivSOpConversion,
1471 CombModSOpConversion>(
patterns.getContext(),
1472 maxEmulationUnknownBits);
// Pass driver: mark the comb dialect illegal (with elided carve-outs for hw
// constants/aggregates), keep the synth dialect legal, forbid the op set of
// the non-selected target IR (AIG vs MIG), honor user-supplied
// additionalLegalOps, then run partial dialect conversion with the pattern
// set populated above.
1475void ConvertCombToSynthPass::runOnOperation() {
1476 ConversionTarget target(getContext());
// Everything in comb must be rewritten away...
1479 target.addIllegalDialect<comb::CombDialect>();
// ...except the legal-op list whose head is elided in this view.
1489 hw::AggregateConstantOp>();
1491 target.addLegalDialect<synth::SynthDialect>();
// Only one of the two synth sub-IRs may survive, per the targetIR option.
1493 if (targetIR == CombToSynthTargetIR::AIG) {
1495 target.addIllegalOp<synth::mig::MajorityInverterOp>();
1496 }
else if (targetIR == CombToSynthTargetIR::MIG) {
1497 target.addIllegalOp<synth::aig::AndInverterOp>();
// User escape hatch: ops named in the pass option stay legal.
1501 if (!additionalLegalOps.empty())
1502 for (
const auto &opName : additionalLegalOps)
1503 target.addLegalOp(OperationName(opName, &getContext()));
1505 RewritePatternSet
patterns(&getContext());
1507 targetIR == CombToSynthTargetIR::MIG);
1509 if (failed(mlir::applyPartialConversion(getOperation(), target,
1511 return signalPassFailure();
assert(baseType &&"element must be base type")
static SmallVector< Value > extractBits(OpBuilder &builder, Value val)
static Value createShiftLogic(ConversionPatternRewriter &rewriter, Location loc, Value shiftAmount, int64_t maxShiftAmount, llvm::function_ref< Value(int64_t)> getPadding, llvm::function_ref< Value(int64_t)> getExtract)
static APInt substitueMaskToValues(size_t width, llvm::SmallVectorImpl< ConstantOrValue > &constantOrValues, uint32_t mask)
static void populateCombToAIGConversionPatterns(RewritePatternSet &patterns, uint32_t maxEmulationUnknownBits, bool lowerToMIG)
static Value createMajorityFunction(OpBuilder &rewriter, Location loc, Value a, Value b, Value carry, bool useMajorityInverterOp)
static LogicalResult emulateBinaryOpForUnknownBits(ConversionPatternRewriter &rewriter, int64_t maxEmulationUnknownBits, Operation *op, llvm::function_ref< APInt(const APInt &, const APInt &)> emulate)
static int64_t getNumUnknownBitsAndPopulateValues(Value value, llvm::SmallVectorImpl< ConstantOrValue > &values)
static Value extractOtherThanMSB(OpBuilder &builder, Value val)
static Value extractMSB(OpBuilder &builder, Value val)
static std::unique_ptr< Context > context
static std::optional< APSInt > getConstant(Attribute operand)
Determine the value of a constant operand for the sake of constant folding.
static Value lowerFullyAssociativeOp(Operation &op, OperandRange operands, SmallVector< Operation * > &newOps)
Lower a variadic fully-associative operation into an expression tree.
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
void replaceOpAndCopyNamehint(PatternRewriter &rewriter, Operation *op, Value newValue)
A wrapper of PatternRewriter::replaceOp to propagate "sv.namehint" attribute.