#include "circt/Conversion/CombToAIG.h"
#include "circt/Dialect/AIG/AIGOps.h"
#include "circt/Dialect/Comb/CombOps.h"
#include "circt/Dialect/HW/HWOps.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/PointerUnion.h"
#include "llvm/Support/Debug.h"

#define DEBUG_TYPE "comb-to-aig"

namespace circt {
#define GEN_PASS_DEF_CONVERTCOMBTOAIG
#include "circt/Conversion/Passes.h.inc"
} // namespace circt

using namespace circt;
using namespace comb;
/// Extract the individual bits of a value.
static SmallVector<Value> extractBits(OpBuilder &builder, Value val) {
  SmallVector<Value> bits;
  comb::extractBits(builder, val, bits);
  return bits;
}
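// Helper for the shift lowerings below: enumerate the result for every
// possible shift amount and select the right one with a mux tree indexed by
// the bits of the (dynamic) shift amount. getExtract(i) returns the surviving
// operand bits for a shift by i and getPadding(i) the bits shifted in;
// amounts of maxShiftAmount or more select the all-padding value.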
template <bool isLeftShift>
static Value createShiftLogic(ConversionPatternRewriter &rewriter,
                              Location loc, Value shiftAmount,
                              int64_t maxShiftAmount,
                              llvm::function_ref<Value(int64_t)> getPadding,
                              llvm::function_ref<Value(int64_t)> getExtract) {
  // The bits of the shift amount select the leaves of the mux tree.
  auto bits = extractBits(rewriter, shiftAmount);

  SmallVector<Value> nodes;
  nodes.reserve(maxShiftAmount);
  for (int64_t i = 0; i < maxShiftAmount; ++i) {
    Value extract = getExtract(i);
    Value padding = getPadding(i);
    if (!padding) {
      nodes.push_back(extract);
      continue;
    }
    // Concatenate the padding with the surviving bits on the side implied by
    // the shift direction.
    if (isLeftShift)
      nodes.push_back(rewriter.createOrFold<comb::ConcatOp>(
          loc, ValueRange{extract, padding}));
    else
      nodes.push_back(rewriter.createOrFold<comb::ConcatOp>(
          loc, ValueRange{padding, extract}));
  }

  // Shift amounts of maxShiftAmount or more produce only padding.
  auto outOfBoundsValue = getPadding(maxShiftAmount);
  assert(outOfBoundsValue && "outOfBoundsValue must be valid");

  auto result =
      comb::constructMuxTree(rewriter, loc, bits, nodes, outOfBoundsValue);

  // Guard the tree with an explicit bounds check on the shift amount.
  auto inBound = rewriter.createOrFold<comb::ICmpOp>(
      loc, ICmpPredicate::ult, shiftAmount,
      hw::ConstantOp::create(rewriter, loc,
                             APInt(bits.size(), maxShiftAmount)));

  return rewriter.createOrFold<comb::MuxOp>(loc, inBound, result,
                                            outOfBoundsValue);
}
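// Support for emulating binary ops whose operands are only partially
// constant. Each operand is decomposed (looking through comb.concat and
// hw.constant) into a list of ConstantOrValue chunks; constant chunks
// contribute known bits, every other chunk counts as unknown bits.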
using ConstantOrValue = llvm::PointerUnion<Value, mlir::IntegerAttr>;

/// Return the number of unknown bits of `value` and append its chunks to
/// `values`, or a negative number if the bit width cannot be determined.
static int64_t getNumUnknownBitsAndPopulateValues(
    Value value, llvm::SmallVectorImpl<ConstantOrValue> &values) {
  // A zero-width value contributes nothing.
  if (value.getType().isInteger(0))
    return 0;

  // Recurse into concatenations, visiting operands LSB first.
  if (auto concat = value.getDefiningOp<comb::ConcatOp>()) {
    int64_t totalUnknownBits = 0;
    for (auto concatInput : llvm::reverse(concat.getInputs())) {
      auto unknownBits =
          getNumUnknownBitsAndPopulateValues(concatInput, values);
      if (unknownBits < 0)
        return unknownBits;
      totalUnknownBits += unknownBits;
    }
    return totalUnknownBits;
  }

  // Constants are fully known.
  if (auto constant = value.getDefiningOp<hw::ConstantOp>()) {
    values.push_back(constant.getValueAttr());
    return 0;
  }

  // Everything else is unknown; all of its bits count.
  values.push_back(value);
  return hw::getBitWidth(value.getType());
}
/// Substitute the bits of `mask` for the unknown chunks and return the
/// resulting constant of `width` bits.
static APInt substitueMaskToValues(
    size_t width, llvm::SmallVectorImpl<ConstantOrValue> &constantOrValues,
    uint32_t mask) {
  uint32_t bitPos = 0, unknownPos = 0;
  APInt result(width, 0);
  for (auto constantOrValue : constantOrValues) {
    int64_t elemWidth;
    if (auto constant = dyn_cast<IntegerAttr>(constantOrValue)) {
      elemWidth = constant.getValue().getBitWidth();
      result.insertBits(constant.getValue(), bitPos);
    } else {
      elemWidth = hw::getBitWidth(cast<Value>(constantOrValue).getType());
      assert(elemWidth >= 0 && "unknown bit width");
      assert(elemWidth + unknownPos < 32 && "unknown bit width too large");
      // Pick the bits assigned to this unknown chunk from the mask.
      uint32_t usedBits = (mask >> unknownPos) & ((1 << elemWidth) - 1);
      result.insertBits(APInt(elemWidth, usedBits), bitPos);
      unknownPos += elemWidth;
    }
    bitPos += elemWidth;
  }
  return result;
}
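// Emulate a two-operand op by constant-folding it for every assignment of
// the unknown operand bits and selecting the matching precomputed result
// with a mux tree driven by those bits. With n unknown bits this enumerates
// 2^n cases (e.g. 3 unknown bits yield 8 constants), which is why callers
// cap n with maxEmulationUnknownBits.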
static LogicalResult emulateBinaryOpForUnknownBits(
    ConversionPatternRewriter &rewriter, int64_t maxEmulationUnknownBits,
    Operation *op,
    llvm::function_ref<APInt(const APInt &, const APInt &)> emulate) {
  SmallVector<ConstantOrValue> lhsValues, rhsValues;

  assert(op->getNumResults() == 1 && op->getNumOperands() == 2 &&
         "op must be a single result binary operation");

  auto lhs = op->getOperand(0);
  auto rhs = op->getOperand(1);
  auto width = op->getResult(0).getType().getIntOrFloatBitWidth();
  auto loc = op->getLoc();
  auto numLhsUnknownBits = getNumUnknownBitsAndPopulateValues(lhs, lhsValues);
  auto numRhsUnknownBits = getNumUnknownBitsAndPopulateValues(rhs, rhsValues);

  // A negative count means the bit width could not be determined.
  if (numLhsUnknownBits < 0 || numRhsUnknownBits < 0)
    return failure();

  int64_t totalUnknownBits = numLhsUnknownBits + numRhsUnknownBits;
  if (totalUnknownBits > maxEmulationUnknownBits)
    return failure();

  SmallVector<Value> emulatedResults;
  emulatedResults.reserve(1 << totalUnknownBits);

  // Cache constants so identical results share one hw.constant.
  DenseMap<IntegerAttr, hw::ConstantOp> constantPool;
  auto getConstant = [&](const APInt &value) -> Value {
    auto attr = rewriter.getIntegerAttr(rewriter.getIntegerType(width), value);
    auto it = constantPool.find(attr);
    if (it != constantPool.end())
      return it->second;
    auto constant = hw::ConstantOp::create(rewriter, loc, value);
    constantPool[attr] = constant;
    return constant;
  };

  // Constant-fold the operation for every assignment of the unknown bits.
  for (uint32_t lhsMask = 0, lhsMaskEnd = 1 << numLhsUnknownBits;
       lhsMask < lhsMaskEnd; ++lhsMask) {
    APInt lhsValue = substitueMaskToValues(width, lhsValues, lhsMask);
    for (uint32_t rhsMask = 0, rhsMaskEnd = 1 << numRhsUnknownBits;
         rhsMask < rhsMaskEnd; ++rhsMask) {
      APInt rhsValue = substitueMaskToValues(width, rhsValues, rhsMask);
      emulatedResults.push_back(getConstant(emulate(lhsValue, rhsValue)));
    }
  }

  // The unknown bits themselves drive the selection of the mux tree.
  SmallVector<Value> selectors;
  selectors.reserve(totalUnknownBits);
  for (auto &concatedValues : {rhsValues, lhsValues})
    for (auto valueOrConstant : concatedValues) {
      auto value = dyn_cast<Value>(valueOrConstant);
      if (!value)
        continue;
      for (Value bit : extractBits(rewriter, value))
        selectors.push_back(bit);
    }

  assert(totalUnknownBits == static_cast<int64_t>(selectors.size()) &&
         "number of selectors must match");
  // Every selector combination is enumerated above, so any element can serve
  // as the (unreachable) out-of-bounds value of the mux tree.
  auto muxed = constructMuxTree(rewriter, loc, selectors, emulatedResults,
                                emulatedResults.back());

  replaceOpAndCopyNamehint(rewriter, op, muxed);
  return success();
}
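// Lower comb.and directly to aig.and_inv: an AND-inverter node with every
// inversion flag set to false is exactly a plain AND.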
struct CombAndOpConversion : OpConversionPattern<AndOp> {
  using OpConversionPattern<AndOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(AndOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    SmallVector<bool> nonInverts(adaptor.getInputs().size(), false);
    replaceOpWithNewOpAndCopyNamehint<aig::AndInverterOp>(
        rewriter, op, adaptor.getInputs(), nonInverts);
    return success();
  }
};
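// Lower comb.or through De Morgan's law: a | b | ... = ~(~a & ~b & ...),
// i.e. an and_inv of all-inverted inputs whose single result is inverted
// again by the replacement and_inv.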
struct CombOrOpConversion : OpConversionPattern<OrOp> {
  using OpConversionPattern<OrOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(OrOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    SmallVector<bool> allInverts(adaptor.getInputs().size(), true);
    auto andOp = aig::AndInverterOp::create(rewriter, op.getLoc(),
                                            adaptor.getInputs(), allInverts);
    replaceOpWithNewOpAndCopyNamehint<aig::AndInverterOp>(rewriter, op, andOp,
                                                          /*invert=*/true);
    return success();
  }
};
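// Lower binary comb.xor using only AND-inverter nodes:
//   a ^ b = ~(a & b) & ~(~a & ~b)
// (the first term rules out "both set", the second "both clear"). Variadic
// xors are split into binary ones by CombLowerVariadicOp first.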
struct CombXorOpConversion : OpConversionPattern<XorOp> {
  using OpConversionPattern<XorOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(XorOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (op.getNumOperands() != 2)
      return failure();

    auto inputs = adaptor.getInputs();
    SmallVector<bool> allInverts(inputs.size(), true);
    SmallVector<bool> allNotInverts(inputs.size(), false);

    auto notAAndNotB =
        aig::AndInverterOp::create(rewriter, op.getLoc(), inputs, allInverts);
    auto aAndB = aig::AndInverterOp::create(rewriter, op.getLoc(), inputs,
                                            allNotInverts);

    replaceOpWithNewOpAndCopyNamehint<aig::AndInverterOp>(
        rewriter, op, notAAndNotB, aAndB, /*invertLhs=*/true,
        /*invertRhs=*/true);
    return success();
  }
};
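// Lower a variadic, fully associative op (xor/add/mul) into a balanced
// binary tree by recursively splitting the operand list in half, which keeps
// the tree depth logarithmic in the number of operands.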
template <typename OpTy>
struct CombLowerVariadicOp : OpConversionPattern<OpTy> {
  using OpConversionPattern<OpTy>::OpConversionPattern;
  using OpAdaptor = typename OpConversionPattern<OpTy>::OpAdaptor;

  LogicalResult
  matchAndRewrite(OpTy op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto result = lowerFullyAssociativeOp(op, op.getOperands(), rewriter);
    replaceOpAndCopyNamehint(rewriter, op, result);
    return success();
  }

  static Value lowerFullyAssociativeOp(OpTy op, OperandRange operands,
                                       ConversionPatternRewriter &rewriter) {
    Value lhs, rhs;
    switch (operands.size()) {
    case 0:
      llvm_unreachable("cannot be called with empty operand range");
      break;
    case 1:
      return operands[0];
    case 2:
      lhs = operands[0];
      rhs = operands[1];
      return OpTy::create(rewriter, op.getLoc(), ValueRange{lhs, rhs}, true);
    default:
      auto firstHalf = operands.size() / 2;
      lhs = lowerFullyAssociativeOp(op, operands.take_front(firstHalf),
                                    rewriter);
      rhs = lowerFullyAssociativeOp(op, operands.drop_front(firstHalf),
                                    rewriter);
      return OpTy::create(rewriter, op.getLoc(), ValueRange{lhs, rhs}, true);
    }
  }
};
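// Lower comb.mux as mux(cond, t, f) = (cond & t) | (~cond & f), with the
// 1-bit condition replicated to the operand width. Non-integer operands are
// bitcast to integers of the same total width first.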
struct CombMuxOpConversion : OpConversionPattern<MuxOp> {
  using OpConversionPattern<MuxOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(MuxOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Value cond = op.getCond();
    auto trueVal = op.getTrueValue();
    auto falseVal = op.getFalseValue();

    if (!op.getType().isInteger()) {
      // Bitcast aggregate operands to integers of the same total bit width.
      auto widthType = rewriter.getIntegerType(hw::getBitWidth(op.getType()));
      trueVal =
          hw::BitcastOp::create(rewriter, op.getLoc(), widthType, trueVal);
      falseVal =
          hw::BitcastOp::create(rewriter, op.getLoc(), widthType, falseVal);
    }

    // Replicate the 1-bit condition to the operand width.
    if (!trueVal.getType().isInteger(1))
      cond = comb::ReplicateOp::create(rewriter, op.getLoc(),
                                       trueVal.getType(), cond);

    auto lhs = aig::AndInverterOp::create(rewriter, op.getLoc(), cond, trueVal);
    auto rhs = aig::AndInverterOp::create(rewriter, op.getLoc(), cond, falseVal,
                                          /*invertCond=*/true, false);
    Value result = comb::OrOp::create(rewriter, op.getLoc(), lhs, rhs);
    // Cast back to the original result type if the operands were bitcast.
    if (result.getType() != op.getType())
      result =
          hw::BitcastOp::create(rewriter, op.getLoc(), op.getType(), result);
    replaceOpAndCopyNamehint(rewriter, op, result);
    return success();
  }
};
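// Lower binary comb.add with one of two adder architectures: a ripple-carry
// adder for narrow results and a parallel-prefix adder for wider ones. The
// cutoffs written as rippleCarryWidthLimit / koggeStoneWidthLimit below are
// placeholder constants; the exact thresholds are not visible in this
// fragment.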
struct CombAddOpConversion : OpConversionPattern<AddOp> {
  using OpConversionPattern<AddOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(AddOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto inputs = adaptor.getInputs();
    // Variadic adds are first split into binary adds by CombLowerVariadicOp.
    if (inputs.size() != 2)
      return failure();

    auto width = op.getType().getIntOrFloatBitWidth();
    if (width == 0) {
      replaceOpWithNewOpAndCopyNamehint<hw::ConstantOp>(rewriter, op,
                                                        APInt(width, 0));
      return success();
    }

    // Placeholder cutoff; the exact threshold is not shown in this fragment.
    constexpr unsigned rippleCarryWidthLimit = 8;
    if (width < rippleCarryWidthLimit)
      lowerRippleCarryAdder(op, inputs, rewriter);
    else
      lowerParallelPrefixAdder(op, inputs, rewriter);

    return success();
  }
  void lowerRippleCarryAdder(comb::AddOp op, ValueRange inputs,
                             ConversionPatternRewriter &rewriter) const {
    auto width = op.getType().getIntOrFloatBitWidth();
    // Extract the operand bits (LSB first).
    auto aBits = extractBits(rewriter, inputs[0]);
    auto bBits = extractBits(rewriter, inputs[1]);

    Value carry;
    // Results are stored MSB first so they can be concatenated directly.
    SmallVector<Value> results;
    results.resize(width);
    for (int64_t i = 0; i < width; ++i) {
      SmallVector<Value> xorOperands = {aBits[i], bBits[i]};
      if (carry)
        xorOperands.push_back(carry);

      // Sum bit: a[i] ^ b[i] ^ carry.
      results[width - i - 1] =
          comb::XorOp::create(rewriter, op.getLoc(), xorOperands, true);

      // No carry-out is needed after the most significant bit.
      if (i == width - 1)
        break;

      // Carry-out: (a[i] & b[i]) | (carry & (a[i] ^ b[i])).
      Value nextCarry = comb::AndOp::create(
          rewriter, op.getLoc(), ValueRange{aBits[i], bBits[i]}, true);
      if (!carry) {
        carry = nextCarry;
        continue;
      }
      auto aXnorB = comb::XorOp::create(rewriter, op.getLoc(),
                                        ValueRange{aBits[i], bBits[i]}, true);
      auto andOp = comb::AndOp::create(rewriter, op.getLoc(),
                                       ValueRange{carry, aXnorB}, true);
      carry = comb::OrOp::create(rewriter, op.getLoc(),
                                 ValueRange{andOp, nextCarry}, true);
    }
    LLVM_DEBUG(llvm::dbgs() << "Lower comb.add to Ripple-Carry Adder of width "
                            << width << "\n");

    replaceOpWithNewOpAndCopyNamehint<comb::ConcatOp>(rewriter, op, results);
  }
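  // Parallel-prefix adder: compute per-bit generate g[i] = a[i] & b[i] and
  // propagate p[i] = a[i] ^ b[i], combine them with a prefix tree so that
  // gPrefix[i] becomes the carry out of bit i, then form the sum as
  //   sum[0] = p[0],  sum[i] = p[i] ^ gPrefix[i-1].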
  void lowerParallelPrefixAdder(comb::AddOp op, ValueRange inputs,
                                ConversionPatternRewriter &rewriter) const {
    auto width = op.getType().getIntOrFloatBitWidth();
    auto aBits = extractBits(rewriter, inputs[0]);
    auto bBits = extractBits(rewriter, inputs[1]);

    // Initial per-bit propagate and generate signals.
    SmallVector<Value> p, g;
    p.reserve(width);
    g.reserve(width);
    for (auto [aBit, bBit] : llvm::zip(aBits, bBits)) {
      p.push_back(comb::XorOp::create(rewriter, op.getLoc(), aBit, bBit));
      g.push_back(comb::AndOp::create(rewriter, op.getLoc(), aBit, bBit));
    }

    LLVM_DEBUG({
      llvm::dbgs() << "Lower comb.add to Parallel-Prefix of width " << width
                   << "\n--------------------------------------- Init\n";
      for (int64_t i = 0; i < width; ++i) {
        llvm::dbgs() << "P0" << i << " = A" << i << " XOR B" << i << "\n";
        llvm::dbgs() << "G0" << i << " = A" << i << " AND B" << i << "\n";
      }
    });

    // Compute the prefix generate/propagate signals.
    SmallVector<Value> pPrefix = p;
    SmallVector<Value> gPrefix = g;
    // Placeholder cutoff between the two prefix-tree topologies; the exact
    // threshold is not shown in this fragment.
    constexpr unsigned koggeStoneWidthLimit = 32;
    if (width < koggeStoneWidthLimit)
      lowerKoggeStonePrefixTree(op, inputs, rewriter, pPrefix, gPrefix);
    else
      lowerBrentKungPrefixTree(op, inputs, rewriter, pPrefix, gPrefix);

    // Completion: results are stored MSB first for concatenation.
    SmallVector<Value> results;
    results.resize(width);
    results[width - 1] = p[0];
    for (int64_t i = 1; i < width; ++i)
      results[width - 1 - i] =
          comb::XorOp::create(rewriter, op.getLoc(), p[i], gPrefix[i - 1]);

    replaceOpWithNewOpAndCopyNamehint<comb::ConcatOp>(rewriter, op, results);

    LLVM_DEBUG({
      llvm::dbgs() << "--------------------------------------- Completion\n";
      for (int64_t i = 1; i < width; ++i)
        llvm::dbgs() << "RES" << i << " = P" << i << " XOR G" << i - 1 << "\n";
    });
  }
  void lowerKoggeStonePrefixTree(comb::AddOp op, ValueRange inputs,
                                 ConversionPatternRewriter &rewriter,
                                 SmallVector<Value> &pPrefix,
                                 SmallVector<Value> &gPrefix) const {
    auto width = op.getType().getIntOrFloatBitWidth();

    for (int64_t stride = 1; stride < width; stride *= 2) {
      for (int64_t i = stride; i < width; ++i) {
        int64_t j = i - stride;
        Value andPG =
            comb::AndOp::create(rewriter, op.getLoc(), pPrefix[i], gPrefix[j]);
        gPrefix[i] =
            comb::OrOp::create(rewriter, op.getLoc(), gPrefix[i], andPG);
        pPrefix[i] =
            comb::AndOp::create(rewriter, op.getLoc(), pPrefix[i], pPrefix[j]);
      }
    }

    LLVM_DEBUG({
      int64_t stage = 0;
      for (int64_t stride = 1; stride < width; stride *= 2, ++stage) {
        llvm::dbgs()
            << "--------------------------------------- Kogge-Stone Stage "
            << stage << "\n";
        for (int64_t i = stride; i < width; ++i) {
          int64_t j = i - stride;
          llvm::dbgs() << "G" << i << stage + 1 << " = G" << i << stage
                       << " OR (P" << i << stage << " AND G" << j << stage
                       << ")\n";
          llvm::dbgs() << "P" << i << stage + 1 << " = P" << i << stage
                       << " AND P" << j << stage << "\n";
        }
      }
    });
  }
  void lowerBrentKungPrefixTree(comb::AddOp op, ValueRange inputs,
                                ConversionPatternRewriter &rewriter,
                                SmallVector<Value> &pPrefix,
                                SmallVector<Value> &gPrefix) const {
    auto width = op.getType().getIntOrFloatBitWidth();

    // Forward phase: combine at power-of-two boundaries.
    int64_t stride;
    for (stride = 1; stride < width; stride *= 2) {
      for (int64_t i = stride * 2 - 1; i < width; i += stride * 2) {
        int64_t j = i - stride;
        Value andPG =
            comb::AndOp::create(rewriter, op.getLoc(), pPrefix[i], gPrefix[j]);
        gPrefix[i] =
            comb::OrOp::create(rewriter, op.getLoc(), gPrefix[i], andPG);
        pPrefix[i] =
            comb::AndOp::create(rewriter, op.getLoc(), pPrefix[i], pPrefix[j]);
      }
    }

    // Backward phase: fill in the remaining positions.
    for (; stride > 0; stride /= 2) {
      for (int64_t i = stride * 3 - 1; i < width; i += stride * 2) {
        int64_t j = i - stride;
        Value andPG =
            comb::AndOp::create(rewriter, op.getLoc(), pPrefix[i], gPrefix[j]);
        gPrefix[i] = OrOp::create(rewriter, op.getLoc(), gPrefix[i], andPG);
        pPrefix[i] =
            comb::AndOp::create(rewriter, op.getLoc(), pPrefix[i], pPrefix[j]);
      }
    }

    LLVM_DEBUG({
      int64_t stage = 0;
      for (stride = 1; stride < width; stride *= 2, ++stage) {
        llvm::dbgs() << "--------------------------------------- Brent-Kung FW "
                     << stage << " : Stride " << stride << "\n";
        for (int64_t i = stride * 2 - 1; i < width; i += stride * 2) {
          int64_t j = i - stride;
          llvm::dbgs() << "G" << i << stage + 1 << " = G" << i << stage
                       << " OR (P" << i << stage << " AND G" << j << stage
                       << ")\n";
          llvm::dbgs() << "P" << i << stage + 1 << " = P" << i << stage
                       << " AND P" << j << stage << "\n";
        }
      }
      for (; stride > 0; stride /= 2, ++stage) {
        if (stride * 3 - 1 < width)
          llvm::dbgs()
              << "--------------------------------------- Brent-Kung BW "
              << stage << " : Stride " << stride << "\n";
        for (int64_t i = stride * 3 - 1; i < width; i += stride * 2) {
          int64_t j = i - stride;
          llvm::dbgs() << "G" << i << stage + 1 << " = G" << i << stage
                       << " OR (P" << i << stage << " AND G" << j << stage
                       << ")\n";
          llvm::dbgs() << "P" << i << stage + 1 << " = P" << i << stage
                       << " AND P" << j << stage << "\n";
        }
      }
    });
  }
};
struct CombSubOpConversion : OpConversionPattern<SubOp> {
  using OpConversionPattern<SubOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(SubOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto lhs = op.getLhs();
    auto rhs = op.getRhs();
    Value notRhs = aig::AndInverterOp::create(rewriter, op.getLoc(), rhs,
                                              /*invert=*/true);
    Value one = hw::ConstantOp::create(
        rewriter, op.getLoc(), APInt(op.getType().getIntOrFloatBitWidth(), 1));
    replaceOpWithNewOpAndCopyNamehint<comb::AddOp>(
        rewriter, op, ValueRange{lhs, notRhs, one}, true);
    return success();
  }
};
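// Lower binary comb.mul with a shift-and-add multiplier: partial product row
// i is (a & b[i]) shifted left by i and truncated to the result width. The
// rows are compressed by a Wallace-tree reduction down to two addends, which
// a single comb.add then sums.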
struct CombMulOpConversion : OpConversionPattern<MulOp> {
  using OpConversionPattern<MulOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(MulOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (adaptor.getInputs().size() != 2)
      return failure();

    Location loc = op.getLoc();
    Value a = adaptor.getInputs()[0];
    Value b = adaptor.getInputs()[1];
    unsigned width = op.getType().getIntOrFloatBitWidth();

    // Zero constant used to pad the shifted partial-product rows.
    Value falseValue = hw::ConstantOp::create(rewriter, loc, APInt(1, 0));

    SmallVector<Value> aBits = extractBits(rewriter, a);
    SmallVector<Value> bBits = extractBits(rewriter, b);

    // Row i is (a & b[i]) shifted left by i, truncated to the result width.
    SmallVector<SmallVector<Value>> partialProducts;
    partialProducts.reserve(width);
    for (unsigned i = 0; i < width; ++i) {
      SmallVector<Value> row(i, falseValue);
      for (unsigned j = 0; i + j < width; ++j)
        row.push_back(
            rewriter.createOrFold<comb::AndOp>(loc, aBits[j], bBits[i]));
      partialProducts.push_back(row);
    }

    if (width == 1) {
      rewriter.replaceOp(op, partialProducts[0][0]);
      return success();
    }

    // Reduce the partial products down to two addends, then add them.
    auto addends =
        comb::wallaceReduction(rewriter, loc, width, 2, partialProducts);
    auto newAdd = comb::AddOp::create(rewriter, loc, addends, true);
    replaceOpAndCopyNamehint(rewriter, op, newAdd);
    return success();
  }
};
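// Division and remainder lowerings. A constant power-of-two divisor reduces
// to bit extraction plus zero padding; every other case is emulated through
// emulateBinaryOpForUnknownBits, so the number of non-constant operand bits
// is capped by maxEmulationUnknownBits. Division by zero is folded to zero
// in the emulation callbacks.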
template <typename OpTy>
struct DivModOpConversionBase : OpConversionPattern<OpTy> {
  DivModOpConversionBase(MLIRContext *context, int64_t maxEmulationUnknownBits)
      : OpConversionPattern<OpTy>(context),
        maxEmulationUnknownBits(maxEmulationUnknownBits) {
    assert(maxEmulationUnknownBits < 32 &&
           "maxEmulationUnknownBits must be less than 32");
  }
  const int64_t maxEmulationUnknownBits;
};
struct CombDivUOpConversion : DivModOpConversionBase<DivUOp> {
  using DivModOpConversionBase<DivUOp>::DivModOpConversionBase;

  LogicalResult
  matchAndRewrite(DivUOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Dividing by a power of two keeps the upper bits and pads with zeros.
    if (auto rhsConstantOp = adaptor.getRhs().getDefiningOp<hw::ConstantOp>())
      if (rhsConstantOp.getValue().isPowerOf2()) {
        size_t extractAmount = rhsConstantOp.getValue().ceilLogBase2();
        size_t width = op.getType().getIntOrFloatBitWidth();
        Value upperBits = rewriter.createOrFold<comb::ExtractOp>(
            op.getLoc(), adaptor.getLhs(), extractAmount,
            width - extractAmount);
        Value constZero = hw::ConstantOp::create(
            rewriter, op.getLoc(), APInt::getZero(extractAmount));
        replaceOpWithNewOpAndCopyNamehint<comb::ConcatOp>(
            rewriter, op, op.getType(), ArrayRef<Value>{constZero, upperBits});
        return success();
      }

    // Otherwise enumerate the unknown bits and emulate the division.
    return emulateBinaryOpForUnknownBits(
        rewriter, maxEmulationUnknownBits, op,
        [](const APInt &lhs, const APInt &rhs) {
          // Division by zero is undefined; fold it to zero.
          if (rhs.isZero())
            return APInt::getZero(rhs.getBitWidth());
          return lhs.udiv(rhs);
        });
  }
};
struct CombModUOpConversion : DivModOpConversionBase<ModUOp> {
  using DivModOpConversionBase<ModUOp>::DivModOpConversionBase;

  LogicalResult
  matchAndRewrite(ModUOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // The remainder modulo a power of two keeps the lower bits, zero-padded.
    if (auto rhsConstantOp = adaptor.getRhs().getDefiningOp<hw::ConstantOp>())
      if (rhsConstantOp.getValue().isPowerOf2()) {
        size_t extractAmount = rhsConstantOp.getValue().ceilLogBase2();
        size_t width = op.getType().getIntOrFloatBitWidth();
        Value lowerBits = rewriter.createOrFold<comb::ExtractOp>(
            op.getLoc(), adaptor.getLhs(), 0, extractAmount);
        Value constZero = hw::ConstantOp::create(
            rewriter, op.getLoc(), APInt::getZero(width - extractAmount));
        replaceOpWithNewOpAndCopyNamehint<comb::ConcatOp>(
            rewriter, op, op.getType(), ArrayRef<Value>{constZero, lowerBits});
        return success();
      }

    return emulateBinaryOpForUnknownBits(
        rewriter, maxEmulationUnknownBits, op,
        [](const APInt &lhs, const APInt &rhs) {
          if (rhs.isZero())
            return APInt::getZero(rhs.getBitWidth());
          return lhs.urem(rhs);
        });
  }
};
struct CombDivSOpConversion : DivModOpConversionBase<DivSOp> {
  using DivModOpConversionBase<DivSOp>::DivModOpConversionBase;

  LogicalResult
  matchAndRewrite(DivSOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Signed division has no cheap special case here; emulate it over the
    // unknown bits.
    return emulateBinaryOpForUnknownBits(
        rewriter, maxEmulationUnknownBits, op,
        [](const APInt &lhs, const APInt &rhs) {
          if (rhs.isZero())
            return APInt::getZero(rhs.getBitWidth());
          return lhs.sdiv(rhs);
        });
  }
};
struct CombModSOpConversion : DivModOpConversionBase<ModSOp> {
  using DivModOpConversionBase<ModSOp>::DivModOpConversionBase;

  LogicalResult
  matchAndRewrite(ModSOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Like signed division, the signed remainder is emulated.
    return emulateBinaryOpForUnknownBits(
        rewriter, maxEmulationUnknownBits, op,
        [](const APInt &lhs, const APInt &rhs) {
          if (rhs.isZero())
            return APInt::getZero(rhs.getBitWidth());
          return lhs.srem(rhs);
        });
  }
};
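// Lower comb.icmp. Equality is an AND-reduction of the inverted XOR bits;
// unsigned orderings are built bit-serially from LSB to MSB with
//   acc' = pred(a[i], b[i]) | (eq(a[i], b[i]) & acc),
// seeded with `includeEq`. Signed orderings compare the sign bits separately
// and reuse the unsigned comparison on the remaining magnitude bits.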
struct CombICmpOpConversion : OpConversionPattern<ICmpOp> {
  using OpConversionPattern<ICmpOp>::OpConversionPattern;

  static Value constructUnsignedCompare(ICmpOp op, ArrayRef<Value> aBits,
                                        ArrayRef<Value> bBits, bool isLess,
                                        bool includeEq,
                                        ConversionPatternRewriter &rewriter) {
    // When all bits are equal, the result is `includeEq` (ule/uge hold,
    // ult/ugt do not).
    Value acc = hw::ConstantOp::create(rewriter, op.getLoc(),
                                       APInt(1, includeEq));

    for (auto [aBit, bBit] : llvm::zip(aBits, bBits)) {
      auto aBitXorBBit =
          rewriter.createOrFold<comb::XorOp>(op.getLoc(), aBit, bBit, true);
      auto aEqualB = rewriter.createOrFold<aig::AndInverterOp>(
          op.getLoc(), aBitXorBBit, true);
      auto pred = rewriter.createOrFold<aig::AndInverterOp>(
          op.getLoc(), aBit, bBit, isLess, !isLess);

      auto aBitAndBBit = rewriter.createOrFold<comb::AndOp>(
          op.getLoc(), ValueRange{aEqualB, acc}, true);
      acc = rewriter.createOrFold<comb::OrOp>(op.getLoc(), pred, aBitAndBBit,
                                              true);
    }
    return acc;
  }
  LogicalResult
  matchAndRewrite(ICmpOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto lhs = adaptor.getLhs();
    auto rhs = adaptor.getRhs();

    switch (op.getPredicate()) {
    default:
      return failure();

    case ICmpPredicate::eq:
    case ICmpPredicate::ceq: {
      // a == b  <=>  ~|(a ^ b)
      auto xorOp = rewriter.createOrFold<comb::XorOp>(op.getLoc(), lhs, rhs);
      auto xorBits = extractBits(rewriter, xorOp);
      SmallVector<bool> allInverts(xorBits.size(), true);
      replaceOpWithNewOpAndCopyNamehint<aig::AndInverterOp>(
          rewriter, op, xorBits, allInverts);
      return success();
    }

    case ICmpPredicate::ne:
    case ICmpPredicate::cne: {
      // a != b  <=>  |(a ^ b)
      auto xorOp = rewriter.createOrFold<comb::XorOp>(op.getLoc(), lhs, rhs);
      replaceOpWithNewOpAndCopyNamehint<comb::OrOp>(
          rewriter, op, extractBits(rewriter, xorOp), true);
      return success();
    }

    case ICmpPredicate::uge:
    case ICmpPredicate::ugt:
    case ICmpPredicate::ule:
    case ICmpPredicate::ult: {
      bool isLess = op.getPredicate() == ICmpPredicate::ult ||
                    op.getPredicate() == ICmpPredicate::ule;
      bool includeEq = op.getPredicate() == ICmpPredicate::uge ||
                       op.getPredicate() == ICmpPredicate::ule;
      auto aBits = extractBits(rewriter, lhs);
      auto bBits = extractBits(rewriter, rhs);
      replaceOpAndCopyNamehint(rewriter, op,
                               constructUnsignedCompare(op, aBits, bBits,
                                                        isLess, includeEq,
                                                        rewriter));
      return success();
    }

    case ICmpPredicate::slt:
    case ICmpPredicate::sle:
    case ICmpPredicate::sgt:
    case ICmpPredicate::sge: {
      if (lhs.getType().getIntOrFloatBitWidth() == 0)
        return rewriter.notifyMatchFailure(
            op.getLoc(), "i0 signed comparison is unsupported");
      bool isLess = op.getPredicate() == ICmpPredicate::slt ||
                    op.getPredicate() == ICmpPredicate::sle;
      bool includeEq = op.getPredicate() == ICmpPredicate::sge ||
                       op.getPredicate() == ICmpPredicate::sle;

      auto aBits = extractBits(rewriter, lhs);
      auto bBits = extractBits(rewriter, rhs);

      // Handle the sign bits separately from the magnitudes.
      auto signA = aBits.back();
      auto signB = bBits.back();

      // If the signs are the same, compare the remaining bits as unsigned.
      auto sameSignResult = constructUnsignedCompare(
          op, ArrayRef(aBits).drop_back(), ArrayRef(bBits).drop_back(), isLess,
          includeEq, rewriter);

      // If the signs differ, the negative operand is the smaller one.
      auto signsDiffer =
          comb::XorOp::create(rewriter, op.getLoc(), signA, signB);
      Value diffSignResult = isLess ? signA : signB;

      replaceOpWithNewOpAndCopyNamehint<comb::MuxOp>(
          rewriter, op, signsDiffer, diffSignResult, sameSignResult);
      return success();
    }
    }
  }
};
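// Lower comb.parity as the XOR reduction of all bits of the input; the
// resulting variadic comb.xor is lowered further by the xor and variadic
// patterns.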
struct CombParityOpConversion : OpConversionPattern<ParityOp> {
  using OpConversionPattern<ParityOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ParityOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    replaceOpWithNewOpAndCopyNamehint<comb::XorOp>(
        rewriter, op, extractBits(rewriter, adaptor.getInput()), true);
    return success();
  }
};
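// The three shift lowerings below all reuse createShiftLogic: comb.shl pads
// the vacated low bits with zeros, comb.shru pads the vacated high bits with
// zeros, and comb.shrs replicates the sign bit, which also means only
// width - 1 distinct shift amounts need to be enumerated.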
struct CombShlOpConversion : OpConversionPattern<comb::ShlOp> {
  using OpConversionPattern<comb::ShlOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(comb::ShlOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto width = op.getType().getIntOrFloatBitWidth();
    auto lhs = adaptor.getLhs();
    auto result = createShiftLogic</*isLeftShift=*/true>(
        rewriter, op.getLoc(), adaptor.getRhs(), width,
        /*getPadding=*/
        [&](int64_t index) -> Value {
          // `index` zero bits fill the vacated low end; a shift by zero needs
          // no padding.
          if (index == 0)
            return {};
          return hw::ConstantOp::create(
              rewriter, op.getLoc(), rewriter.getIntegerType(index), 0);
        },
        /*getExtract=*/
        [&](int64_t index) -> Value {
          assert(index < width && "index out of bounds");
          // Keep the low (width - index) bits of the operand.
          return rewriter.createOrFold<comb::ExtractOp>(op.getLoc(), lhs, 0,
                                                        width - index);
        });
    replaceOpAndCopyNamehint(rewriter, op, result);
    return success();
  }
};

struct CombShrUOpConversion : OpConversionPattern<comb::ShrUOp> {
  using OpConversionPattern<comb::ShrUOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(comb::ShrUOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto width = op.getType().getIntOrFloatBitWidth();
    auto lhs = adaptor.getLhs();
    auto result = createShiftLogic</*isLeftShift=*/false>(
        rewriter, op.getLoc(), adaptor.getRhs(), width,
        /*getPadding=*/
        [&](int64_t index) -> Value {
          // `index` zero bits fill the vacated high end.
          if (index == 0)
            return {};
          return hw::ConstantOp::create(
              rewriter, op.getLoc(), rewriter.getIntegerType(index), 0);
        },
        /*getExtract=*/
        [&](int64_t index) -> Value {
          assert(index < width && "index out of bounds");
          // Keep the high (width - index) bits of the operand.
          return rewriter.createOrFold<comb::ExtractOp>(op.getLoc(), lhs, index,
                                                        width - index);
        });
    replaceOpAndCopyNamehint(rewriter, op, result);
    return success();
  }
};

struct CombShrSOpConversion : OpConversionPattern<comb::ShrSOp> {
  using OpConversionPattern<comb::ShrSOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(comb::ShrSOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto width = op.getType().getIntOrFloatBitWidth();
    if (width == 0)
      return rewriter.notifyMatchFailure(op.getLoc(),
                                         "i0 signed shift is unsupported");
    auto lhs = adaptor.getLhs();
    // The sign bit is replicated into the vacated high bits.
    auto sign =
        rewriter.createOrFold<comb::ExtractOp>(op.getLoc(), lhs, width - 1, 1);

    // Shift amounts of width - 1 or more all produce the replicated sign bit,
    // so only width - 1 leaves are needed.
    auto result = createShiftLogic</*isLeftShift=*/false>(
        rewriter, op.getLoc(), adaptor.getRhs(), width - 1,
        /*getPadding=*/
        [&](int64_t index) {
          return rewriter.createOrFold<comb::ReplicateOp>(op.getLoc(), sign,
                                                          index + 1);
        },
        /*getExtract=*/
        [&](int64_t index) {
          return rewriter.createOrFold<comb::ExtractOp>(op.getLoc(), lhs, index,
                                                        width - index - 1);
        });
    replaceOpAndCopyNamehint(rewriter, op, result);
    return success();
  }
};
struct ConvertCombToAIGPass
    : public impl::ConvertCombToAIGBase<ConvertCombToAIGPass> {
  void runOnOperation() override;
  using ConvertCombToAIGBase<ConvertCombToAIGPass>::ConvertCombToAIGBase;
  using ConvertCombToAIGBase<ConvertCombToAIGPass>::additionalLegalOps;
  using ConvertCombToAIGBase<ConvertCombToAIGPass>::maxEmulationUnknownBits;
};
static void populateCombToAIGConversionPatterns(
    RewritePatternSet &patterns, uint32_t maxEmulationUnknownBits) {
  patterns.add<
      // Bitwise logical ops.
      CombAndOpConversion, CombOrOpConversion, CombXorOpConversion,
      CombMuxOpConversion, CombParityOpConversion,
      // Arithmetic ops.
      CombAddOpConversion, CombSubOpConversion, CombMulOpConversion,
      CombICmpOpConversion,
      // Shift ops.
      CombShlOpConversion, CombShrUOpConversion, CombShrSOpConversion,
      // Variadic ops that are first lowered to binary ops.
      CombLowerVariadicOp<XorOp>, CombLowerVariadicOp<AddOp>,
      CombLowerVariadicOp<MulOp>>(patterns.getContext());

  // The div/mod patterns additionally take the emulation limit.
  patterns.add<CombDivUOpConversion, CombModUOpConversion,
               CombDivSOpConversion, CombModSOpConversion>(
      patterns.getContext(), maxEmulationUnknownBits);
}
void ConvertCombToAIGPass::runOnOperation() {
  ConversionTarget target(getContext());

  // Comb is the source dialect of this conversion.
  target.addIllegalDialect<comb::CombDialect>();
  // Extract/concat/replicate and constant-like ops stay legal because the
  // lowering itself emits them. (The legal-op list is abbreviated here; only
  // its tail was visible in this fragment.)
  target.addLegalOp<comb::ExtractOp, comb::ConcatOp, comb::ReplicateOp,
                    hw::ConstantOp, hw::BitcastOp, hw::AggregateConstantOp>();
  // AIG is the target dialect.
  target.addLegalDialect<aig::AIGDialect>();

  // Treat additionally specified operations as legal.
  if (!additionalLegalOps.empty())
    for (const auto &opName : additionalLegalOps)
      target.addLegalOp(OperationName(opName, &getContext()));

  RewritePatternSet patterns(&getContext());
  populateCombToAIGConversionPatterns(patterns, maxEmulationUnknownBits);

  if (failed(mlir::applyPartialConversion(getOperation(), target,
                                          std::move(patterns))))
    return signalPassFailure();
}