33#include "mlir/Pass/Pass.h"
34#include "mlir/Transforms/DialectConversion.h"
35#include "llvm/ADT/APInt.h"
36#include "llvm/ADT/PointerUnion.h"
37#include "llvm/Support/Debug.h"
40#define DEBUG_TYPE "comb-to-synth"
43#define GEN_PASS_DEF_CONVERTCOMBTOSYNTH
44#include "circt/Conversion/Passes.h.inc"
// Splits `val` into its individual 1-bit Values via comb::extractBits and
// returns them. NOTE(review): the trailing `return bits;`/closing brace is
// elided from this excerpt — confirm against the full file.
55static SmallVector<Value>
extractBits(OpBuilder &builder, Value val) {
56 SmallVector<Value> bits;
57 comb::extractBits(builder, val, bits);
// Builds shift logic as a mux tree over every possible shift amount
// (0..maxShiftAmount-1). `getExtract(i)` supplies the shifted slice of the
// operand for amount i and `getPadding(i)` the fill bits; the direction is
// selected at compile time via the `isLeftShift` template parameter.
// NOTE(review): several body lines are elided here (e.g. the concat of
// padding+extract, the `bits` selector setup) — see the full file.
68template <
bool isLeftShift>
70 Value shiftAmount, int64_t maxShiftAmount,
71 llvm::function_ref<Value(int64_t)> getPadding,
72 llvm::function_ref<Value(int64_t)> getExtract) {
// One pre-computed result node per in-range shift amount.
77 SmallVector<Value> nodes;
78 nodes.reserve(maxShiftAmount);
79 for (int64_t i = 0; i < maxShiftAmount; ++i) {
80 Value extract = getExtract(i);
81 Value padding = getPadding(i);
84 nodes.push_back(extract);
// Out-of-range shift amounts all collapse to pure padding.
98 auto outOfBoundsValue = getPadding(maxShiftAmount);
99 assert(outOfBoundsValue &&
"outOfBoundsValue must be valid");
// Select among the pre-computed nodes with a balanced mux tree.
103 comb::constructMuxTree(rewriter, loc, bits, nodes, outOfBoundsValue);
// Guard: only use the mux-tree result when shiftAmount < maxShiftAmount.
106 auto inBound = rewriter.createOrFold<comb::ICmpOp>(
107 loc, ICmpPredicate::ult, shiftAmount,
111 return rewriter.createOrFold<
comb::MuxOp>(loc, inBound, result,
// Builds a 3-input majority (carry) function of a, b, carry.
// When targeting MIG, emits a single synth::mig::MajorityInverterOp;
// otherwise expands to AND/OR/XOR comb ops:
//   maj(a,b,c) = (c AND (a XOR b)) OR (a AND b)
// NOTE(review): the signature line and some intermediate lines are elided.
119 Value b, Value carry,
120 bool useMajorityInverterOp) {
121 if (useMajorityInverterOp) {
122 std::array<Value, 3> inputs = {a, b, carry};
// No inversion on any input.
123 std::array<bool, 3> inverts = {
false,
false,
false};
124 return synth::mig::MajorityInverterOp::create(rewriter, loc, inputs,
// AIG/comb expansion path (the trailing `true` args presumably request
// two-state/twoState folding — confirm against comb op builders).
129 auto aXnorB = comb::XorOp::create(rewriter, loc, ValueRange{a, b},
true);
131 comb::AndOp::create(rewriter, loc, ValueRange{carry, aXnorB},
true);
132 auto aAndB = comb::AndOp::create(rewriter, loc, ValueRange{a, b},
true);
133 return comb::OrOp::create(rewriter, loc, ValueRange{andOp, aAndB},
true);
// A bit-slice is either a known constant (IntegerAttr) or an opaque Value.
138using ConstantOrValue = llvm::PointerUnion<Value, mlir::IntegerAttr>;
// Walks `value` (recursing through concats, per the loop below) and appends
// its constituent constant/unknown slices to `values`. Returns the number of
// unknown (non-constant) bits, or a negative count on failure (callers check
// `< 0`). NOTE(review): the recursion/early-return lines are elided here.
143 Value value, llvm::SmallVectorImpl<ConstantOrValue> &values) {
// i0 values contribute nothing.
145 if (value.getType().isInteger(0))
150 int64_t totalUnknownBits = 0;
// Concat operands are visited LSB-first (reverse of the printed order).
151 for (
auto concatInput : llvm::reverse(
concat.getInputs())) {
156 totalUnknownBits += unknownBits;
158 return totalUnknownBits;
// Constant leaf: record the attribute, zero unknown bits.
163 values.push_back(constant.getValueAttr());
// Opaque leaf: every bit is unknown.
169 values.push_back(value);
170 return hw::getBitWidth(value.getType());
// Materializes a concrete APInt of `width` bits from the slice list: constant
// slices are copied verbatim, and each unknown slice consumes its bits from
// `mask` (so enumerating all masks enumerates all assignments of the unknown
// bits). NOTE(review): the signature line and `bitPos` advance are elided.
176 llvm::SmallVectorImpl<ConstantOrValue> &constantOrValues,
178 uint32_t bitPos = 0, unknownPos = 0;
179 APInt result(width, 0);
180 for (
auto constantOrValue : constantOrValues) {
182 if (
auto constant = dyn_cast<IntegerAttr>(constantOrValue)) {
183 elemWidth = constant.getValue().getBitWidth();
184 result.insertBits(constant.getValue(), bitPos);
186 elemWidth = hw::getBitWidth(cast<Value>(constantOrValue).getType());
187 assert(elemWidth >= 0 &&
"unknown bit width");
// `mask` is a uint32_t, so all unknown bits must fit in 32 bits.
188 assert(elemWidth + unknownPos < 32 &&
"unknown bit width too large");
190 uint32_t usedBits = (mask >> unknownPos) & ((1 << elemWidth) - 1);
191 result.insertBits(APInt(elemWidth, usedBits), bitPos);
192 unknownPos += elemWidth;
// Lowers a 2-operand, 1-result op whose operands are mostly constant by
// brute-force emulation: enumerate every assignment of the unknown bits,
// compute the constant result via `emulate`, and select the right one at
// runtime with a mux tree keyed on the unknown-bit Values. Bails out when
// the number of unknown bits exceeds `maxEmulationUnknownBits`.
// NOTE(review): several lines (failure returns, mask substitution, final
// replace) are elided from this excerpt.
204 ConversionPatternRewriter &rewriter, int64_t maxEmulationUnknownBits,
206 llvm::function_ref<APInt(
const APInt &,
const APInt &)> emulate) {
207 SmallVector<ConstantOrValue> lhsValues, rhsValues;
209 assert(op->getNumResults() == 1 && op->getNumOperands() == 2 &&
210 "op must be a single result binary operation");
212 auto lhs = op->getOperand(0);
213 auto rhs = op->getOperand(1);
214 auto width = op->getResult(0).getType().getIntOrFloatBitWidth();
215 auto loc = op->getLoc();
// Negative counts signal decomposition failure (see
// getNumUnknownBitsAndPopulateValues).
220 if (numLhsUnknownBits < 0 || numRhsUnknownBits < 0)
223 int64_t totalUnknownBits = numLhsUnknownBits + numRhsUnknownBits;
224 if (totalUnknownBits > maxEmulationUnknownBits)
// 2^totalUnknownBits pre-computed results.
227 SmallVector<Value> emulatedResults;
228 emulatedResults.reserve(1 << totalUnknownBits);
// Deduplicate hw::ConstantOp creation per distinct result value.
231 DenseMap<IntegerAttr, hw::ConstantOp> constantPool;
233 auto attr = rewriter.getIntegerAttr(rewriter.getIntegerType(width), value);
234 auto it = constantPool.find(attr);
235 if (it != constantPool.end())
238 constantPool[attr] = constant;
// Enumerate all unknown-bit assignments of lhs x rhs.
242 for (uint32_t lhsMask = 0, lhsMaskEnd = 1 << numLhsUnknownBits;
243 lhsMask < lhsMaskEnd; ++lhsMask) {
245 for (uint32_t rhsMask = 0, rhsMaskEnd = 1 << numRhsUnknownBits;
246 rhsMask < rhsMaskEnd; ++rhsMask) {
249 emulatedResults.push_back(
getConstant(emulate(lhsValue, rhsValue)));
// Collect the unknown-bit Values as mux selectors. rhs first: it matches
// the mask enumeration order above (rhs is the inner/fastest-varying loop).
254 SmallVector<Value> selectors;
255 selectors.reserve(totalUnknownBits);
256 for (
auto &concatedValues : {rhsValues, lhsValues})
257 for (
auto valueOrConstant : concatedValues) {
258 auto value = dyn_cast<Value>(valueOrConstant);
264 assert(totalUnknownBits ==
static_cast<int64_t
>(selectors.size()) &&
265 "number of selectors must match");
266 auto muxed = constructMuxTree(rewriter, loc, selectors, emulatedResults,
// comb.and -> synth.aig.and_inv with no inputs inverted (a plain AND).
284 matchAndRewrite(
AndOp op, OpAdaptor adaptor,
285 ConversionPatternRewriter &rewriter)
const override {
286 SmallVector<bool> nonInverts(adaptor.getInputs().size(),
false);
287 replaceOpWithNewOpAndCopyNamehint<synth::aig::AndInverterOp>(
288 rewriter, op, adaptor.getInputs(), nonInverts);
// comb.or -> AIG via De Morgan: or(a,b,...) = not(and(not(a), not(b), ...)).
// First AND inverts every input; the replacement (elided below) inverts the
// AND's result.
298 matchAndRewrite(
OrOp op, OpAdaptor adaptor,
299 ConversionPatternRewriter &rewriter)
const override {
301 SmallVector<bool> allInverts(adaptor.getInputs().size(),
true);
302 auto andOp = synth::aig::AndInverterOp::create(
303 rewriter, op.getLoc(), adaptor.getInputs(), allInverts);
304 replaceOpWithNewOpAndCopyNamehint<synth::aig::AndInverterOp>(
// Binary comb.or -> MIG: or(a,b) = maj(a, b, all-ones). Variadic ORs are
// rejected here (handled by CombLowerVariadicOp first).
314 matchAndRewrite(
OrOp op, OpAdaptor adaptor,
315 ConversionPatternRewriter &rewriter)
const override {
316 if (op.getNumOperands() != 2)
318 SmallVector<Value, 3> inputs(adaptor.getInputs());
// Third majority input: the all-ones constant.
320 rewriter, op.getLoc(),
321 APInt::getAllOnes(hw::getBitWidth(op.getType())));
322 inputs.push_back(one);
323 std::array<bool, 3> inverts = {
false,
false,
false};
324 replaceOpWithNewOpAndCopyNamehint<synth::mig::MajorityInverterOp>(
325 rewriter, op, inputs, inverts);
// aig.and_inv -> MIG: unary ops map directly (maj is a pass-through/inverter
// for one input); binary ops use and(a,b) = maj(a, b, zero). >2 operands are
// rejected (lowered to binary first).
330struct AndInverterToMIGConversion
334 matchAndRewrite(synth::aig::AndInverterOp op, OpAdaptor adaptor,
335 ConversionPatternRewriter &rewriter)
const override {
336 if (op.getNumOperands() > 2)
338 if (op.getNumOperands() == 1) {
339 SmallVector<bool, 1> inverts{op.getInverted()[0]};
340 replaceOpWithNewOpAndCopyNamehint<synth::mig::MajorityInverterOp>(
341 rewriter, op, adaptor.getInputs(), inverts);
// Binary case: append an all-zero constant as the third majority input
// (variable named `one` but holds APInt::getZero — naming is misleading).
344 SmallVector<Value, 3> inputs(adaptor.getInputs());
346 rewriter, op.getLoc(), APInt::getZero(hw::getBitWidth(op.getType())));
347 inputs.push_back(one);
348 SmallVector<bool, 3> inverts(adaptor.getInverted());
349 inverts.push_back(
false);
350 replaceOpWithNewOpAndCopyNamehint<synth::mig::MajorityInverterOp>(
351 rewriter, op, inputs, inverts);
// Binary comb.xor -> AIG: xor(a,b) = not(and(not(a),not(b))) AND not(and(a,b))
// i.e. (a OR b) AND NOT(a AND b), expressed with inverted AND inputs/outputs.
// Variadic xors are rejected (lowered to binary first).
361 matchAndRewrite(
XorOp op, OpAdaptor adaptor,
362 ConversionPatternRewriter &rewriter)
const override {
363 if (op.getNumOperands() != 2)
369 auto inputs = adaptor.getInputs();
370 SmallVector<bool> allInverts(inputs.size(),
true);
371 SmallVector<bool> allNotInverts(inputs.size(),
false);
373 auto notAAndNotB = synth::aig::AndInverterOp::create(rewriter, op.getLoc(),
375 auto aAndB = synth::aig::AndInverterOp::create(rewriter, op.getLoc(),
376 inputs, allNotInverts);
// Final AND inverts both intermediate results (invert flags elided here).
378 replaceOpWithNewOpAndCopyNamehint<synth::aig::AndInverterOp>(
379 rewriter, op, notAAndNotB, aAndB,
// Lowers a variadic associative op (and/or/xor/add/mul) into a balanced
// binary tree by recursively splitting the operand range in half.
// NOTE(review): the recursive helper's signature and the 1-/2-operand cases
// are partially elided in this excerpt.
386template <
typename OpTy>
391 matchAndRewrite(OpTy op, OpAdaptor adaptor,
392 ConversionPatternRewriter &rewriter)
const override {
// Recursive tree-builder over an operand range.
399 ConversionPatternRewriter &rewriter) {
401 switch (operands.size()) {
403 llvm_unreachable(
"cannot be called with empty operand range");
410 return OpTy::create(rewriter, op.getLoc(), ValueRange{lhs, rhs},
true);
// General case: split in half and recurse on each side.
412 auto firstHalf = operands.size() / 2;
417 return OpTy::create(rewriter, op.getLoc(), ValueRange{lhs, rhs},
true);
// comb.mux -> (cond AND trueVal) OR (NOT cond AND falseVal), with the i1
// condition replicated to the data width first. Non-integer (aggregate)
// results are bitcast to an integer of equal width (elided), muxed, then
// bitcast back.
427 matchAndRewrite(
MuxOp op, OpAdaptor adaptor,
428 ConversionPatternRewriter &rewriter)
const override {
429 Value cond = op.getCond();
430 auto trueVal = op.getTrueValue();
431 auto falseVal = op.getFalseValue();
433 if (!op.getType().isInteger()) {
// Work on an integer of the same bit width as the aggregate.
435 auto widthType = rewriter.getIntegerType(hw::getBitWidth(op.getType()));
// Widen the 1-bit condition to the operand width so it can gate each bit.
443 if (!trueVal.getType().isInteger(1))
444 cond = comb::ReplicateOp::create(rewriter, op.getLoc(), trueVal.getType(),
449 synth::aig::AndInverterOp::create(rewriter, op.getLoc(), cond, trueVal);
// rhs gates falseVal with the inverted condition (invert flags true,false).
450 auto rhs = synth::aig::AndInverterOp::create(rewriter, op.getLoc(), cond,
451 falseVal,
true,
false);
453 Value result = comb::OrOp::create(rewriter, op.getLoc(), lhs, rhs);
// Bitcast back if we muxed through an integer surrogate type.
455 if (result.getType() != op.getType())
// comb.add lowering entry point. Binary adds only (variadic handled by
// CombLowerVariadicOp). i0-width adds fold to a constant. Chooses between a
// ripple-carry adder and a parallel-prefix adder (selection condition elided).
463template <
bool lowerToMIG>
467 matchAndRewrite(
AddOp op, OpAdaptor adaptor,
468 ConversionPatternRewriter &rewriter)
const override {
469 auto inputs = adaptor.getInputs();
472 if (inputs.size() != 2)
475 auto width = op.getType().getIntOrFloatBitWidth();
// Zero-width add: replace with a constant.
478 replaceOpWithNewOpAndCopyNamehint<hw::ConstantOp>(rewriter, op,
484 lowerRippleCarryAdder(op, inputs, rewriter);
486 lowerParallelPrefixAdder(op, inputs, rewriter);
// Classic ripple-carry adder: per bit i, sum = a^b^carry and the next carry
// is the majority of (a, b, carry). Results are stored MSB-first
// (results[width-1-i]) so the final concat yields the correct word order.
// NOTE(review): carry initialization and the majority/carry-update lines are
// partially elided.
492 void lowerRippleCarryAdder(
comb::AddOp op, ValueRange inputs,
493 ConversionPatternRewriter &rewriter)
const {
494 auto width = op.getType().getIntOrFloatBitWidth();
500 SmallVector<Value> results;
501 results.resize(width);
502 for (int64_t i = 0; i < width; ++i) {
503 SmallVector<Value> xorOperands = {aBits[i], bBits[i]};
505 xorOperands.push_back(carry);
509 results[width - i - 1] =
510 comb::XorOp::create(rewriter, op.getLoc(), xorOperands,
true);
// First-iteration carry: a AND b (no incoming carry yet).
519 carry = comb::AndOp::create(rewriter, op.getLoc(),
520 ValueRange{aBits[i], bBits[i]},
true);
527 LLVM_DEBUG(llvm::dbgs() <<
"Lower comb.add to Ripple-Carry Adder of width "
530 replaceOpWithNewOpAndCopyNamehint<comb::ConcatOp>(rewriter, op, results);
// Parallel-prefix adder: computes per-bit propagate p[i] = a^b and generate
// g[i] = a&b, runs a prefix tree (Kogge-Stone or Brent-Kung, selection
// condition elided) to produce group generates, then sums via
// result[i] = p[i] ^ gPrefix[i-1]. Extensive LLVM_DEBUG tracing mirrors each
// construction step.
536 void lowerParallelPrefixAdder(
comb::AddOp op, ValueRange inputs,
537 ConversionPatternRewriter &rewriter)
const {
538 auto width = op.getType().getIntOrFloatBitWidth();
543 SmallVector<Value> p, g;
// Stage 0: propagate/generate per bit.
547 for (
auto [aBit, bBit] :
llvm::zip(aBits, bBits)) {
549 p.push_back(comb::XorOp::create(rewriter, op.getLoc(), aBit, bBit));
551 g.push_back(comb::AndOp::create(rewriter, op.getLoc(), aBit, bBit));
555 llvm::dbgs() <<
"Lower comb.add to Parallel-Prefix of width " << width
556 <<
"\n--------------------------------------- Init\n";
558 for (int64_t i = 0; i < width; ++i) {
560 llvm::dbgs() <<
"P0" << i <<
" = A" << i <<
" XOR B" << i <<
"\n";
562 llvm::dbgs() <<
"G0" << i <<
" = A" << i <<
" AND B" << i <<
"\n";
// Prefix computation (tree variant chosen by an elided condition).
567 SmallVector<Value> pPrefix = p;
568 SmallVector<Value> gPrefix = g;
570 lowerKoggeStonePrefixTree(op, inputs, rewriter, pPrefix, gPrefix);
572 lowerBrentKungPrefixTree(op, inputs, rewriter, pPrefix, gPrefix);
// Completion: bit 0 is just p[0]; higher bits xor in the prefix generate.
576 SmallVector<Value> results;
577 results.resize(width);
579 results[width - 1] = p[0];
583 for (int64_t i = 1; i < width; ++i)
584 results[width - 1 - i] =
585 comb::XorOp::create(rewriter, op.getLoc(), p[i], gPrefix[i - 1]);
587 replaceOpWithNewOpAndCopyNamehint<comb::ConcatOp>(rewriter, op, results);
590 llvm::dbgs() <<
"--------------------------------------- Completion\n"
592 for (int64_t i = 1; i < width; ++i)
593 llvm::dbgs() <<
"RES" << i <<
" = P" << i <<
" XOR G" << i - 1 <<
"\n";
// Kogge-Stone prefix tree: log2(width) stages; at each stage with stride s,
// position i (i >= s) combines with j = i - s:
//   g[i] = g[i] OR (p[i] AND g[j]);  p[i] = p[i] AND p[j]
// New values are staged in *New copies and committed after each stride so a
// stage reads only previous-stage values. The second loop nest replays the
// same iteration purely for LLVM_DEBUG tracing.
600 void lowerKoggeStonePrefixTree(
comb::AddOp op, ValueRange inputs,
601 ConversionPatternRewriter &rewriter,
602 SmallVector<Value> &pPrefix,
603 SmallVector<Value> &gPrefix)
const {
604 auto width = op.getType().getIntOrFloatBitWidth();
605 SmallVector<Value> pPrefixNew = pPrefix;
606 SmallVector<Value> gPrefixNew = gPrefix;
609 for (int64_t stride = 1; stride < width; stride *= 2) {
610 for (int64_t i = stride; i < width; ++i) {
611 int64_t j = i - stride;
615 comb::AndOp::create(rewriter, op.getLoc(), pPrefix[i], gPrefix[j]);
617 comb::OrOp::create(rewriter, op.getLoc(), gPrefix[i], andPG);
621 comb::AndOp::create(rewriter, op.getLoc(), pPrefix[i], pPrefix[j]);
// Commit this stage's results.
623 pPrefix = pPrefixNew;
624 gPrefix = gPrefixNew;
// Debug-only replay of the stage structure.
628 for (int64_t stride = 1; stride < width; stride *= 2) {
630 <<
"--------------------------------------- Kogge-Stone Stage "
632 for (int64_t i = stride; i < width; ++i) {
633 int64_t j = i - stride;
635 llvm::dbgs() <<
"G" << i << stage + 1 <<
" = G" << i << stage
636 <<
" OR (P" << i << stage <<
" AND G" << j << stage
640 llvm::dbgs() <<
"P" << i << stage + 1 <<
" = P" << i << stage
641 <<
" AND P" << j << stage <<
"\n";
// Brent-Kung prefix tree: a forward (up-sweep) phase with doubling strides
// combining positions i = 2s-1, 4s-1, ... with j = i - s, followed by a
// backward (down-sweep) phase with halving strides filling in i = 3s-1,
// 5s-1, ... Same (g,p) combine as Kogge-Stone; fewer ops, longer depth.
// The trailing loop nests replay both phases for LLVM_DEBUG tracing only.
651 void lowerBrentKungPrefixTree(
comb::AddOp op, ValueRange inputs,
652 ConversionPatternRewriter &rewriter,
653 SmallVector<Value> &pPrefix,
654 SmallVector<Value> &gPrefix)
const {
655 auto width = op.getType().getIntOrFloatBitWidth();
656 SmallVector<Value> pPrefixNew = pPrefix;
657 SmallVector<Value> gPrefixNew = gPrefix;
// Forward (reduction) phase; `stride` persists into the backward phase.
661 for (stride = 1; stride < width; stride *= 2) {
662 for (int64_t i = stride * 2 - 1; i < width; i += stride * 2) {
663 int64_t j = i - stride;
667 comb::AndOp::create(rewriter, op.getLoc(), pPrefix[i], gPrefix[j]);
669 comb::OrOp::create(rewriter, op.getLoc(), gPrefix[i], andPG);
673 comb::AndOp::create(rewriter, op.getLoc(), pPrefix[i], pPrefix[j]);
675 pPrefix = pPrefixNew;
676 gPrefix = gPrefixNew;
// Backward (distribution) phase.
680 for (; stride > 0; stride /= 2) {
681 for (int64_t i = stride * 3 - 1; i < width; i += stride * 2) {
682 int64_t j = i - stride;
686 comb::AndOp::create(rewriter, op.getLoc(), pPrefix[i], gPrefix[j]);
688 comb::OrOp::create(rewriter, op.getLoc(), gPrefix[i], andPG);
692 comb::AndOp::create(rewriter, op.getLoc(), pPrefix[i], pPrefix[j]);
694 pPrefix = pPrefixNew;
695 gPrefix = gPrefixNew;
// Debug-only replay of the forward phase.
700 for (stride = 1; stride < width; stride *= 2) {
701 llvm::dbgs() <<
"--------------------------------------- Brent-Kung FW "
702 << stage <<
" : Stride " << stride <<
"\n";
703 for (int64_t i = stride * 2 - 1; i < width; i += stride * 2) {
704 int64_t j = i - stride;
707 llvm::dbgs() <<
"G" << i << stage + 1 <<
" = G" << i << stage
708 <<
" OR (P" << i << stage <<
" AND G" << j << stage
712 llvm::dbgs() <<
"P" << i << stage + 1 <<
" = P" << i << stage
713 <<
" AND P" << j << stage <<
"\n";
// Debug-only replay of the backward phase.
718 for (; stride > 0; stride /= 2) {
719 if (stride * 3 - 1 < width)
721 <<
"--------------------------------------- Brent-Kung BW "
722 << stage <<
" : Stride " << stride <<
"\n";
724 for (int64_t i = stride * 3 - 1; i < width; i += stride * 2) {
725 int64_t j = i - stride;
728 llvm::dbgs() <<
"G" << i << stage + 1 <<
" = G" << i << stage
729 <<
" OR (P" << i << stage <<
" AND G" << j << stage
733 llvm::dbgs() <<
"P" << i << stage + 1 <<
" = P" << i << stage
734 <<
" AND P" << j << stage <<
"\n";
// comb.sub -> two's-complement add: lhs + ~rhs + 1 (the `one` constant's
// creation is elided). The replacement is a variadic comb.add, later lowered
// by the add patterns.
745 matchAndRewrite(
SubOp op, OpAdaptor adaptor,
746 ConversionPatternRewriter &rewriter)
const override {
747 auto lhs = op.getLhs();
748 auto rhs = op.getRhs();
// ~rhs via an and_inv with the input inverted.
752 auto notRhs = synth::aig::AndInverterOp::create(rewriter, op.getLoc(), rhs,
755 replaceOpWithNewOpAndCopyNamehint<comb::AddOp>(
756 rewriter, op, ValueRange{lhs, notRhs, one},
true);
// Binary comb.mul -> partial-product array (row i = aBits & bBit[i], shifted
// left i via `falseValue` zero padding), then a compressor tree reduced to
// two addends that feed a final comb.add. Width-1 multiplies collapse to a
// single AND (partialProducts[0][0]).
765 matchAndRewrite(
MulOp op, OpAdaptor adaptor,
766 ConversionPatternRewriter &rewriter)
const override {
767 if (adaptor.getInputs().size() != 2)
770 Location loc = op.getLoc();
771 Value a = adaptor.getInputs()[0];
772 Value b = adaptor.getInputs()[1];
773 unsigned width = op.getType().getIntOrFloatBitWidth();
782 SmallVector<Value> aBits =
extractBits(rewriter, a);
783 SmallVector<Value> bBits =
extractBits(rewriter, b);
// Row i contributes (a & b[i]) << i, truncated to `width` bits.
788 SmallVector<SmallVector<Value>> partialProducts;
789 partialProducts.reserve(width);
790 for (
unsigned i = 0; i < width; ++i) {
791 SmallVector<Value> row(i, falseValue);
794 for (
unsigned j = 0; i + j < width; ++j)
796 rewriter.createOrFold<
comb::AndOp>(loc, aBits[j], bBits[i]));
798 partialProducts.push_back(row);
// width == 1: the single AND is the product.
803 rewriter.replaceOp(op, partialProducts[0][0]);
// Compress the partial-product matrix down to two rows, then add them.
809 auto addends = comp.compressToHeight(rewriter, 2);
812 auto newAdd = comb::AddOp::create(rewriter, loc, addends,
true);
// Shared base for div/mod conversions: stores the emulation budget used by
// emulateBinaryOpForUnknownBits. The mask enumeration uses uint32_t, hence
// the < 32 assertion.
818template <
typename OpTy>
820 DivModOpConversionBase(MLIRContext *context, int64_t maxEmulationUnknownBits)
822 maxEmulationUnknownBits(maxEmulationUnknownBits) {
823 assert(maxEmulationUnknownBits < 32 &&
824 "maxEmulationUnknownBits must be less than 32");
826 const int64_t maxEmulationUnknownBits;
// comb.divu: power-of-two divisors become a right shift expressed as
// extract(upper bits) concatenated below a zero constant; otherwise fall back
// to constant-emulation. Division by zero emulates to 0 (x/0 = 0 convention).
829struct CombDivUOpConversion : DivModOpConversionBase<DivUOp> {
830 using DivModOpConversionBase<
DivUOp>::DivModOpConversionBase;
832 matchAndRewrite(
DivUOp op, OpAdaptor adaptor,
833 ConversionPatternRewriter &rewriter)
const override {
835 if (
auto rhsConstantOp = adaptor.getRhs().getDefiningOp<
hw::ConstantOp>())
836 if (rhsConstantOp.getValue().isPowerOf2()) {
// divu by 2^k == logical shift right by k.
838 size_t extractAmount = rhsConstantOp.getValue().ceilLogBase2();
839 size_t width = op.getType().getIntOrFloatBitWidth();
841 op.getLoc(), adaptor.getLhs(), extractAmount,
842 width - extractAmount);
844 APInt::getZero(extractAmount));
845 replaceOpWithNewOpAndCopyNamehint<comb::ConcatOp>(
846 rewriter, op, op.getType(), ArrayRef<Value>{constZero, upperBits});
// General case: brute-force emulation over the unknown bits.
853 rewriter, maxEmulationUnknownBits, op,
854 [](
const APInt &lhs,
const APInt &rhs) {
857 return APInt::getZero(rhs.getBitWidth());
858 return lhs.udiv(rhs);
// comb.modu: power-of-two divisors become a bit mask (low k bits kept, upper
// bits replaced by a zero constant); otherwise constant-emulation with
// x % 0 = 0.
863struct CombModUOpConversion : DivModOpConversionBase<ModUOp> {
864 using DivModOpConversionBase<
ModUOp>::DivModOpConversionBase;
866 matchAndRewrite(
ModUOp op, OpAdaptor adaptor,
867 ConversionPatternRewriter &rewriter)
const override {
869 if (
auto rhsConstantOp = adaptor.getRhs().getDefiningOp<
hw::ConstantOp>())
870 if (rhsConstantOp.getValue().isPowerOf2()) {
// modu by 2^k == keep the low k bits.
872 size_t extractAmount = rhsConstantOp.getValue().ceilLogBase2();
873 size_t width = op.getType().getIntOrFloatBitWidth();
875 op.getLoc(), adaptor.getLhs(), 0, extractAmount);
877 rewriter, op.getLoc(), APInt::getZero(width - extractAmount));
878 replaceOpWithNewOpAndCopyNamehint<comb::ConcatOp>(
879 rewriter, op, op.getType(), ArrayRef<Value>{constZero, lowerBits});
// General case: brute-force emulation over the unknown bits.
886 rewriter, maxEmulationUnknownBits, op,
887 [](
const APInt &lhs,
const APInt &rhs) {
890 return APInt::getZero(rhs.getBitWidth());
891 return lhs.urem(rhs);
// comb.divs: no power-of-two fast path (signed semantics); always lowered
// through constant-emulation, with x / 0 = 0.
896struct CombDivSOpConversion : DivModOpConversionBase<DivSOp> {
897 using DivModOpConversionBase<
DivSOp>::DivModOpConversionBase;
900 matchAndRewrite(
DivSOp op, OpAdaptor adaptor,
901 ConversionPatternRewriter &rewriter)
const override {
905 rewriter, maxEmulationUnknownBits, op,
906 [](
const APInt &lhs,
const APInt &rhs) {
909 return APInt::getZero(rhs.getBitWidth());
910 return lhs.sdiv(rhs);
// comb.mods: mirrors CombDivSOpConversion but emulates srem; x % 0 = 0.
915struct CombModSOpConversion : DivModOpConversionBase<ModSOp> {
916 using DivModOpConversionBase<
ModSOp>::DivModOpConversionBase;
918 matchAndRewrite(
ModSOp op, OpAdaptor adaptor,
919 ConversionPatternRewriter &rewriter)
const override {
923 rewriter, maxEmulationUnknownBits, op,
924 [](
const APInt &lhs,
const APInt &rhs) {
927 return APInt::getZero(rhs.getBitWidth());
928 return lhs.srem(rhs);
// Builds an unsigned bitwise comparator over equal-length bit vectors:
// scanning bit-by-bit, the accumulator becomes
//   acc = pred(aBit,bBit) OR (aBit == bBit AND acc)
// where pred is a<b or a>b per `isLess` (and_inv invert flags select the
// direction). The `includeEq` seed of `acc` is elided from this excerpt.
935 static Value constructUnsignedCompare(ICmpOp op, ArrayRef<Value> aBits,
936 ArrayRef<Value> bBits,
bool isLess,
938 ConversionPatternRewriter &rewriter) {
947 for (
auto [aBit, bBit] :
llvm::zip(aBits, bBits)) {
949 rewriter.createOrFold<
comb::XorOp>(op.getLoc(), aBit, bBit,
true);
// aEqualB = NOT(a XOR b).
950 auto aEqualB = rewriter.createOrFold<synth::aig::AndInverterOp>(
951 op.getLoc(), aBitXorBBit,
true);
// pred: (NOT a AND b) when isLess, (a AND NOT b) otherwise.
952 auto pred = rewriter.createOrFold<synth::aig::AndInverterOp>(
953 op.getLoc(), aBit, bBit, isLess, !isLess);
955 auto aBitAndBBit = rewriter.createOrFold<
comb::AndOp>(
956 op.getLoc(), ValueRange{aEqualB, acc},
true);
957 acc = rewriter.createOrFold<
comb::OrOp>(op.getLoc(), pred, aBitAndBBit,
// comb.icmp lowering by predicate class:
//  - eq/ceq: NOR-reduce the per-bit XOR (all bits equal).
//  - ne/cne: OR-reduce the per-bit XOR (any bit differs).
//  - unsigned lt/le/gt/ge: bitwise comparator (constructUnsignedCompare);
//    ge/gt are realized as the complemented lt/le (inversion elided).
//  - signed lt/le/gt/ge: compare magnitudes with the sign bits stripped,
//    then mux on "signs differ" (a negative lhs is less). i0 is rejected.
964 matchAndRewrite(ICmpOp op, OpAdaptor adaptor,
965 ConversionPatternRewriter &rewriter)
const override {
966 auto lhs = adaptor.getLhs();
967 auto rhs = adaptor.getRhs();
969 switch (op.getPredicate()) {
973 case ICmpPredicate::eq:
974 case ICmpPredicate::ceq: {
976 auto xorOp = rewriter.createOrFold<
comb::XorOp>(op.getLoc(), lhs, rhs);
// Equal iff every xor bit is 0: AND of all inverted bits.
978 SmallVector<bool> allInverts(xorBits.size(),
true);
979 replaceOpWithNewOpAndCopyNamehint<synth::aig::AndInverterOp>(
980 rewriter, op, xorBits, allInverts);
984 case ICmpPredicate::ne:
985 case ICmpPredicate::cne: {
987 auto xorOp = rewriter.createOrFold<
comb::XorOp>(op.getLoc(), lhs, rhs);
988 replaceOpWithNewOpAndCopyNamehint<comb::OrOp>(
993 case ICmpPredicate::uge:
994 case ICmpPredicate::ugt:
995 case ICmpPredicate::ule:
996 case ICmpPredicate::ult: {
997 bool isLess = op.getPredicate() == ICmpPredicate::ult ||
998 op.getPredicate() == ICmpPredicate::ule;
999 bool includeEq = op.getPredicate() == ICmpPredicate::uge ||
1000 op.getPredicate() == ICmpPredicate::ule;
1004 constructUnsignedCompare(op, aBits, bBits,
1009 case ICmpPredicate::slt:
1010 case ICmpPredicate::sle:
1011 case ICmpPredicate::sgt:
1012 case ICmpPredicate::sge: {
1013 if (lhs.getType().getIntOrFloatBitWidth() == 0)
1014 return rewriter.notifyMatchFailure(
1015 op.getLoc(),
"i0 signed comparison is unsupported");
1016 bool isLess = op.getPredicate() == ICmpPredicate::slt ||
1017 op.getPredicate() == ICmpPredicate::sle;
1018 bool includeEq = op.getPredicate() == ICmpPredicate::sge ||
1019 op.getPredicate() == ICmpPredicate::sle;
// Sign bits are the MSBs (back of the LSB-first bit vectors).
1025 auto signA = aBits.back();
1026 auto signB = bBits.back();
// Same signs: magnitude comparison on the remaining bits is valid.
1029 auto sameSignResult = constructUnsignedCompare(
1030 op, ArrayRef(aBits).drop_back(), ArrayRef(bBits).drop_back(), isLess,
1031 includeEq, rewriter);
1035 comb::XorOp::create(rewriter, op.getLoc(), signA, signB);
// Different signs: "a < b" iff a is the negative one (signA set).
1038 Value diffSignResult = isLess ? signA : signB;
1041 replaceOpWithNewOpAndCopyNamehint<comb::MuxOp>(
1042 rewriter, op, signsDiffer, diffSignResult, sameSignResult);
// comb.parity -> XOR-reduce all bits of the input.
1053 matchAndRewrite(
ParityOp op, OpAdaptor adaptor,
1054 ConversionPatternRewriter &rewriter)
const override {
1056 replaceOpWithNewOpAndCopyNamehint<comb::XorOp>(
1057 rewriter, op,
extractBits(rewriter, adaptor.getInput()),
true);
// comb.shl via createShiftLogic: for shift amount i, padding is an i-bit zero
// constant (the low bits) and the extract takes the low width-i bits of lhs
// (extract details elided).
1066 matchAndRewrite(
comb::ShlOp op, OpAdaptor adaptor,
1067 ConversionPatternRewriter &rewriter)
const override {
1068 auto width = op.getType().getIntOrFloatBitWidth();
1069 auto lhs = adaptor.getLhs();
1071 rewriter, op.getLoc(), adaptor.getRhs(), width,
// Padding: `index` zero bits shifted in at the bottom.
1073 [&](int64_t index) {
1079 op.getLoc(), rewriter.getIntegerType(index), 0);
1082 [&](int64_t index) {
1083 assert(index < width &&
"index out of bounds");
// comb.shru via createShiftLogic: padding is an `index`-bit zero constant at
// the top, and the extract takes the upper bits of lhs starting at `index`.
// NOTE(review): the matchAndRewrite signature line is elided in this excerpt.
1099 ConversionPatternRewriter &rewriter)
const override {
1100 auto width = op.getType().getIntOrFloatBitWidth();
1101 auto lhs = adaptor.getLhs();
1103 rewriter, op.getLoc(), adaptor.getRhs(), width,
// Padding: zeros shifted in from the top.
1105 [&](int64_t index) {
1111 op.getLoc(), rewriter.getIntegerType(index), 0);
1114 [&](int64_t index) {
1115 assert(index < width &&
"index out of bounds");
1117 return rewriter.createOrFold<
comb::ExtractOp>(op.getLoc(), lhs, index,
// comb.shrs (arithmetic shift right) via createShiftLogic: padding replicates
// the sign bit instead of zeros; i0 is rejected. Max shift is width-1 since
// shifting further saturates to all-sign-bits.
// NOTE(review): the matchAndRewrite signature line is elided in this excerpt.
1131 ConversionPatternRewriter &rewriter)
const override {
1132 auto width = op.getType().getIntOrFloatBitWidth();
1134 return rewriter.notifyMatchFailure(op.getLoc(),
1135 "i0 signed shift is unsupported");
1136 auto lhs = adaptor.getLhs();
// Sign bit = MSB of lhs.
1139 rewriter.createOrFold<
comb::ExtractOp>(op.getLoc(), lhs, width - 1, 1);
1144 rewriter, op.getLoc(), adaptor.getRhs(), width - 1,
// Padding: replicate the sign bit `index` times.
1146 [&](int64_t index) {
1147 return rewriter.createOrFold<comb::ReplicateOp>(op.getLoc(), sign,
1151 [&](int64_t index) {
1152 return rewriter.createOrFold<
comb::ExtractOp>(op.getLoc(), lhs, index,
// Pass shell generated from Passes.h.inc (GEN_PASS_DEF_CONVERTCOMBTOSYNTH);
// options (targetIR, additionalLegalOps, ...) come from the tablegen base.
1168struct ConvertCombToSynthPass
1169 :
public impl::ConvertCombToSynthBase<ConvertCombToSynthPass> {
1170 void runOnOperation()
override;
1171 using ConvertCombToSynthBase<ConvertCombToSynthPass>::ConvertCombToSynthBase;
// Registers every comb->synth conversion pattern. Common patterns first;
// then either the MIG-flavored or AIG-flavored set (branch condition elided,
// presumably on `lowerToMIG`); div/mod patterns last, carrying the emulation
// budget.
1177 uint32_t maxEmulationUnknownBits,
1181 CombAndOpConversion, CombXorOpConversion, CombMuxOpConversion,
1182 CombParityOpConversion,
1184 CombSubOpConversion, CombMulOpConversion, CombICmpOpConversion,
1186 CombShlOpConversion, CombShrUOpConversion, CombShrSOpConversion,
1188 CombLowerVariadicOp<XorOp>, CombLowerVariadicOp<AddOp>,
1189 CombLowerVariadicOp<MulOp>>(
patterns.getContext());
// MIG target: route or/and_inv/add through majority-inverter ops.
1192 patterns.add<CombOrToMIGConversion, CombLowerVariadicOp<OrOp>,
1193 AndInverterToMIGConversion,
1195 CombAddOpConversion<
true>>(
patterns.getContext());
// AIG target.
1197 patterns.add<CombOrToAIGConversion, CombAddOpConversion<
false>>(
1202 patterns.add<CombDivUOpConversion, CombModUOpConversion, CombDivSOpConversion,
1203 CombModSOpConversion>(
patterns.getContext(),
1204 maxEmulationUnknownBits);
// Pass driver: mark comb illegal and synth legal, narrow legality per the
// targetIR option (AIG forbids MIG ops and vice versa), whitelist any
// user-supplied additionalLegalOps, then run partial dialect conversion.
1207void ConvertCombToSynthPass::runOnOperation() {
1208 ConversionTarget target(getContext());
1211 target.addIllegalDialect<comb::CombDialect>();
1221 hw::AggregateConstantOp>();
1223 target.addLegalDialect<synth::SynthDialect>();
1225 if (targetIR == CombToSynthTargetIR::AIG) {
1227 target.addIllegalOp<synth::mig::MajorityInverterOp>();
1228 }
else if (targetIR == CombToSynthTargetIR::MIG) {
1229 target.addIllegalOp<synth::aig::AndInverterOp>();
// Escape hatch: ops the user wants left untouched.
1233 if (!additionalLegalOps.empty())
1234 for (
const auto &opName : additionalLegalOps)
1235 target.addLegalOp(OperationName(opName, &getContext()));
1237 RewritePatternSet
patterns(&getContext());
1239 targetIR == CombToSynthTargetIR::MIG);
1241 if (failed(mlir::applyPartialConversion(getOperation(), target,
1243 return signalPassFailure();
assert(baseType && "element must be base type")
static SmallVector< T > concat(const SmallVectorImpl< T > &a, const SmallVectorImpl< T > &b)
Returns a new vector containing the concatenation of vectors a and b.
static SmallVector< Value > extractBits(OpBuilder &builder, Value val)
static Value createShiftLogic(ConversionPatternRewriter &rewriter, Location loc, Value shiftAmount, int64_t maxShiftAmount, llvm::function_ref< Value(int64_t)> getPadding, llvm::function_ref< Value(int64_t)> getExtract)
static APInt substitueMaskToValues(size_t width, llvm::SmallVectorImpl< ConstantOrValue > &constantOrValues, uint32_t mask)
static void populateCombToAIGConversionPatterns(RewritePatternSet &patterns, uint32_t maxEmulationUnknownBits, bool lowerToMIG)
static Value createMajorityFunction(OpBuilder &rewriter, Location loc, Value a, Value b, Value carry, bool useMajorityInverterOp)
static LogicalResult emulateBinaryOpForUnknownBits(ConversionPatternRewriter &rewriter, int64_t maxEmulationUnknownBits, Operation *op, llvm::function_ref< APInt(const APInt &, const APInt &)> emulate)
static int64_t getNumUnknownBitsAndPopulateValues(Value value, llvm::SmallVectorImpl< ConstantOrValue > &values)
static std::optional< APSInt > getConstant(Attribute operand)
Determine the value of a constant operand for the sake of constant folding.
static Value lowerFullyAssociativeOp(Operation &op, OperandRange operands, SmallVector< Operation * > &newOps)
Lower a variadic fully-associative operation into an expression tree.
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
void replaceOpAndCopyNamehint(PatternRewriter &rewriter, Operation *op, Value newValue)
A wrapper of PatternRewriter::replaceOp to propagate "sv.namehint" attribute.