17#include "mlir/Pass/Pass.h"
18#include "mlir/Transforms/DialectConversion.h"
19#include "llvm/ADT/PointerUnion.h"
20#include "llvm/Support/Debug.h"
22#define DEBUG_TYPE "comb-to-aig"
25#define GEN_PASS_DEF_CONVERTCOMBTOAIG
26#include "circt/Conversion/Passes.h.inc"
37static SmallVector<Value>
extractBits(OpBuilder &builder, Value val) {
38 SmallVector<Value> bits;
39 comb::extractBits(builder, val, bits);
37static SmallVector<Value>
extractBits(OpBuilder &builder, Value val) {
…}
50template <
bool isLeftShift>
52 Value shiftAmount, int64_t maxShiftAmount,
53 llvm::function_ref<Value(int64_t)> getPadding,
54 llvm::function_ref<Value(int64_t)> getExtract) {
59 SmallVector<Value> nodes;
60 nodes.reserve(maxShiftAmount);
61 for (int64_t i = 0; i < maxShiftAmount; ++i) {
62 Value extract = getExtract(i);
63 Value padding = getPadding(i);
66 nodes.push_back(extract);
80 auto outOfBoundsValue = getPadding(maxShiftAmount);
81 assert(outOfBoundsValue &&
"outOfBoundsValue must be valid");
85 comb::constructMuxTree(rewriter, loc, bits, nodes, outOfBoundsValue);
88 auto inBound = rewriter.createOrFold<comb::ICmpOp>(
89 loc, ICmpPredicate::ult, shiftAmount,
93 return rewriter.createOrFold<
comb::MuxOp>(loc, inBound, result,
99using ConstantOrValue = llvm::PointerUnion<Value, mlir::IntegerAttr>;
104 Value value, llvm::SmallVectorImpl<ConstantOrValue> &values) {
106 if (value.getType().isInteger(0))
111 int64_t totalUnknownBits = 0;
112 for (
auto concatInput : llvm::reverse(
concat.getInputs())) {
117 totalUnknownBits += unknownBits;
119 return totalUnknownBits;
124 values.push_back(constant.getValueAttr());
130 values.push_back(value);
131 return hw::getBitWidth(value.getType());
137 llvm::SmallVectorImpl<ConstantOrValue> &constantOrValues,
139 uint32_t bitPos = 0, unknownPos = 0;
140 APInt result(width, 0);
141 for (
auto constantOrValue : constantOrValues) {
143 if (
auto constant = dyn_cast<IntegerAttr>(constantOrValue)) {
144 elemWidth = constant.getValue().getBitWidth();
145 result.insertBits(constant.getValue(), bitPos);
147 elemWidth = hw::getBitWidth(cast<Value>(constantOrValue).getType());
148 assert(elemWidth >= 0 &&
"unknown bit width");
149 assert(elemWidth + unknownPos < 32 &&
"unknown bit width too large");
151 uint32_t usedBits = (mask >> unknownPos) & ((1 << elemWidth) - 1);
152 result.insertBits(APInt(elemWidth, usedBits), bitPos);
153 unknownPos += elemWidth;
165 ConversionPatternRewriter &rewriter, int64_t maxEmulationUnknownBits,
167 llvm::function_ref<APInt(
const APInt &,
const APInt &)> emulate) {
168 SmallVector<ConstantOrValue> lhsValues, rhsValues;
170 assert(op->getNumResults() == 1 && op->getNumOperands() == 2 &&
171 "op must be a single result binary operation");
173 auto lhs = op->getOperand(0);
174 auto rhs = op->getOperand(1);
175 auto width = op->getResult(0).getType().getIntOrFloatBitWidth();
176 auto loc = op->getLoc();
181 if (numLhsUnknownBits < 0 || numRhsUnknownBits < 0)
184 int64_t totalUnknownBits = numLhsUnknownBits + numRhsUnknownBits;
185 if (totalUnknownBits > maxEmulationUnknownBits)
188 SmallVector<Value> emulatedResults;
189 emulatedResults.reserve(1 << totalUnknownBits);
192 DenseMap<IntegerAttr, hw::ConstantOp> constantPool;
194 auto attr = rewriter.getIntegerAttr(rewriter.getIntegerType(width), value);
195 auto it = constantPool.find(attr);
196 if (it != constantPool.end())
199 constantPool[attr] = constant;
203 for (uint32_t lhsMask = 0, lhsMaskEnd = 1 << numLhsUnknownBits;
204 lhsMask < lhsMaskEnd; ++lhsMask) {
206 for (uint32_t rhsMask = 0, rhsMaskEnd = 1 << numRhsUnknownBits;
207 rhsMask < rhsMaskEnd; ++rhsMask) {
210 emulatedResults.push_back(
getConstant(emulate(lhsValue, rhsValue)));
215 SmallVector<Value> selectors;
216 selectors.reserve(totalUnknownBits);
217 for (
auto &concatedValues : {rhsValues, lhsValues})
218 for (
auto valueOrConstant : concatedValues) {
219 auto value = dyn_cast<Value>(valueOrConstant);
225 assert(totalUnknownBits ==
static_cast<int64_t
>(selectors.size()) &&
226 "number of selectors must match");
227 auto muxed = constructMuxTree(rewriter, loc, selectors, emulatedResults,
230 rewriter.replaceOp(op, muxed);
245 matchAndRewrite(
AndOp op, OpAdaptor adaptor,
246 ConversionPatternRewriter &rewriter)
const override {
247 SmallVector<bool> nonInverts(adaptor.getInputs().size(),
false);
248 rewriter.replaceOpWithNewOp<aig::AndInverterOp>(op, adaptor.getInputs(),
259 matchAndRewrite(
OrOp op, OpAdaptor adaptor,
260 ConversionPatternRewriter &rewriter)
const override {
262 SmallVector<bool> allInverts(adaptor.getInputs().size(),
true);
263 auto andOp = rewriter.create<aig::AndInverterOp>(
264 op.getLoc(), adaptor.getInputs(), allInverts);
265 rewriter.replaceOpWithNewOp<aig::AndInverterOp>(op, andOp,
276 matchAndRewrite(
XorOp op, OpAdaptor adaptor,
277 ConversionPatternRewriter &rewriter)
const override {
278 if (op.getNumOperands() != 2)
284 auto inputs = adaptor.getInputs();
285 SmallVector<bool> allInverts(inputs.size(),
true);
286 SmallVector<bool> allNotInverts(inputs.size(),
false);
289 rewriter.create<aig::AndInverterOp>(op.getLoc(), inputs, allInverts);
291 rewriter.create<aig::AndInverterOp>(op.getLoc(), inputs, allNotInverts);
293 rewriter.replaceOpWithNewOp<aig::AndInverterOp>(op, notAAndNotB, aAndB,
300template <
typename OpTy>
305 matchAndRewrite(OpTy op, OpAdaptor adaptor,
306 ConversionPatternRewriter &rewriter)
const override {
308 rewriter.replaceOp(op, result);
313 ConversionPatternRewriter &rewriter) {
315 switch (operands.size()) {
317 assert(
false &&
"cannot be called with empty operand range");
324 return rewriter.create<OpTy>(op.getLoc(), ValueRange{lhs, rhs},
true);
326 auto firstHalf = operands.size() / 2;
331 return rewriter.create<OpTy>(op.getLoc(), ValueRange{lhs, rhs},
true);
341 matchAndRewrite(
MuxOp op, OpAdaptor adaptor,
342 ConversionPatternRewriter &rewriter)
const override {
345 Value cond = op.getCond();
346 auto trueVal = op.getTrueValue();
347 auto falseVal = op.getFalseValue();
349 if (!op.getType().isInteger()) {
351 auto widthType = rewriter.getIntegerType(hw::getBitWidth(op.getType()));
353 rewriter.create<
hw::BitcastOp>(op->getLoc(), widthType, trueVal);
359 if (!trueVal.getType().isInteger(1))
360 cond = rewriter.
create<comb::ReplicateOp>(op.getLoc(), trueVal.getType(),
364 auto lhs = rewriter.create<aig::AndInverterOp>(op.getLoc(), cond, trueVal);
365 auto rhs = rewriter.create<aig::AndInverterOp>(op.getLoc(), cond, falseVal,
368 Value result = rewriter.create<
comb::OrOp>(op.getLoc(), lhs, rhs);
370 if (result.getType() != op.getType())
372 rewriter.create<
hw::BitcastOp>(op.getLoc(), op.getType(), result);
373 rewriter.replaceOp(op, result);
381 matchAndRewrite(
AddOp op, OpAdaptor adaptor,
382 ConversionPatternRewriter &rewriter)
const override {
383 auto inputs = adaptor.getInputs();
386 if (inputs.size() != 2)
389 auto width = op.getType().getIntOrFloatBitWidth();
397 lowerRippleCarryAdder(op, inputs, rewriter);
399 lowerParallelPrefixAdder(op, inputs, rewriter);
405 void lowerRippleCarryAdder(
comb::AddOp op, ValueRange inputs,
406 ConversionPatternRewriter &rewriter)
const {
407 auto width = op.getType().getIntOrFloatBitWidth();
413 SmallVector<Value> results;
414 results.resize(width);
415 for (int64_t i = 0; i < width; ++i) {
416 SmallVector<Value> xorOperands = {aBits[i], bBits[i]};
418 xorOperands.push_back(carry);
422 results[width - i - 1] =
423 rewriter.create<
comb::XorOp>(op.getLoc(), xorOperands,
true);
431 op.getLoc(), ValueRange{aBits[i], bBits[i]},
true);
439 op.getLoc(), ValueRange{aBits[i], bBits[i]},
true);
441 op.getLoc(), ValueRange{carry, aXnorB},
true);
442 carry = rewriter.create<
comb::OrOp>(op.getLoc(),
443 ValueRange{andOp, nextCarry},
true);
445 LLVM_DEBUG(llvm::dbgs() <<
"Lower comb.add to Ripple-Carry Adder of width "
454 void lowerParallelPrefixAdder(
comb::AddOp op, ValueRange inputs,
455 ConversionPatternRewriter &rewriter)
const {
456 auto width = op.getType().getIntOrFloatBitWidth();
461 SmallVector<Value> p, g;
465 for (
auto [aBit, bBit] :
llvm::zip(aBits, bBits)) {
467 p.push_back(rewriter.create<
comb::XorOp>(op.getLoc(), aBit, bBit));
469 g.push_back(rewriter.create<
comb::AndOp>(op.getLoc(), aBit, bBit));
473 llvm::dbgs() <<
"Lower comb.add to Parallel-Prefix of width " << width
474 <<
"\n--------------------------------------- Init\n";
476 for (int64_t i = 0; i < width; ++i) {
478 llvm::dbgs() <<
"P0" << i <<
" = A" << i <<
" XOR B" << i <<
"\n";
480 llvm::dbgs() <<
"G0" << i <<
" = A" << i <<
" AND B" << i <<
"\n";
485 SmallVector<Value> pPrefix = p;
486 SmallVector<Value> gPrefix = g;
488 lowerKoggeStonePrefixTree(op, inputs, rewriter, pPrefix, gPrefix);
490 lowerBrentKungPrefixTree(op, inputs, rewriter, pPrefix, gPrefix);
494 SmallVector<Value> results;
495 results.resize(width);
497 results[width - 1] = p[0];
501 for (int64_t i = 1; i < width; ++i)
502 results[width - 1 - i] =
503 rewriter.create<
comb::XorOp>(op.getLoc(), p[i], gPrefix[i - 1]);
508 llvm::dbgs() <<
"--------------------------------------- Completion\n"
510 for (int64_t i = 1; i < width; ++i)
511 llvm::dbgs() <<
"RES" << i <<
" = P" << i <<
" XOR G" << i - 1 <<
"\n";
518 void lowerKoggeStonePrefixTree(
comb::AddOp op, ValueRange inputs,
519 ConversionPatternRewriter &rewriter,
520 SmallVector<Value> &pPrefix,
521 SmallVector<Value> &gPrefix)
const {
522 auto width = op.getType().getIntOrFloatBitWidth();
525 for (int64_t stride = 1; stride < width; stride *= 2) {
526 for (int64_t i = stride; i < width; ++i) {
527 int64_t j = i - stride;
530 rewriter.create<
comb::AndOp>(op.getLoc(), pPrefix[i], gPrefix[j]);
532 rewriter.create<
comb::OrOp>(op.getLoc(), gPrefix[i], andPG);
536 rewriter.create<
comb::AndOp>(op.getLoc(), pPrefix[i], pPrefix[j]);
541 for (int64_t stride = 1; stride < width; stride *= 2) {
543 <<
"--------------------------------------- Kogge-Stone Stage "
545 for (int64_t i = stride; i < width; ++i) {
546 int64_t j = i - stride;
548 llvm::dbgs() <<
"G" << i << stage + 1 <<
" = G" << i << stage
549 <<
" OR (P" << i << stage <<
" AND G" << j << stage
553 llvm::dbgs() <<
"P" << i << stage + 1 <<
" = P" << i << stage
554 <<
" AND P" << j << stage <<
"\n";
564 void lowerBrentKungPrefixTree(
comb::AddOp op, ValueRange inputs,
565 ConversionPatternRewriter &rewriter,
566 SmallVector<Value> &pPrefix,
567 SmallVector<Value> &gPrefix)
const {
568 auto width = op.getType().getIntOrFloatBitWidth();
573 for (stride = 1; stride < width; stride *= 2) {
574 for (int64_t i = stride * 2 - 1; i < width; i += stride * 2) {
575 int64_t j = i - stride;
579 rewriter.create<
comb::AndOp>(op.getLoc(), pPrefix[i], gPrefix[j]);
581 rewriter.create<
comb::OrOp>(op.getLoc(), gPrefix[i], andPG);
585 rewriter.create<
comb::AndOp>(op.getLoc(), pPrefix[i], pPrefix[j]);
590 for (; stride > 0; stride /= 2) {
591 for (int64_t i = stride * 3 - 1; i < width; i += stride * 2) {
592 int64_t j = i - stride;
596 rewriter.create<
comb::AndOp>(op.getLoc(), pPrefix[i], gPrefix[j]);
597 gPrefix[i] = rewriter.create<
OrOp>(op.getLoc(), gPrefix[i], andPG);
601 rewriter.create<
comb::AndOp>(op.getLoc(), pPrefix[i], pPrefix[j]);
607 for (stride = 1; stride < width; stride *= 2) {
608 llvm::dbgs() <<
"--------------------------------------- Brent-Kung FW "
609 << stage <<
" : Stride " << stride <<
"\n";
610 for (int64_t i = stride * 2 - 1; i < width; i += stride * 2) {
611 int64_t j = i - stride;
614 llvm::dbgs() <<
"G" << i << stage + 1 <<
" = G" << i << stage
615 <<
" OR (P" << i << stage <<
" AND G" << j << stage
619 llvm::dbgs() <<
"P" << i << stage + 1 <<
" = P" << i << stage
620 <<
" AND P" << j << stage <<
"\n";
625 for (; stride > 0; stride /= 2) {
626 if (stride * 3 - 1 < width)
628 <<
"--------------------------------------- Brent-Kung BW "
629 << stage <<
" : Stride " << stride <<
"\n";
631 for (int64_t i = stride * 3 - 1; i < width; i += stride * 2) {
632 int64_t j = i - stride;
635 llvm::dbgs() <<
"G" << i << stage + 1 <<
" = G" << i << stage
636 <<
" OR (P" << i << stage <<
" AND G" << j << stage
640 llvm::dbgs() <<
"P" << i << stage + 1 <<
" = P" << i << stage
641 <<
" AND P" << j << stage <<
"\n";
652 matchAndRewrite(
SubOp op, OpAdaptor adaptor,
653 ConversionPatternRewriter &rewriter)
const override {
654 auto lhs = op.getLhs();
655 auto rhs = op.getRhs();
659 auto notRhs = rewriter.create<aig::AndInverterOp>(op.getLoc(), rhs,
661 auto one = rewriter.create<
hw::ConstantOp>(op.getLoc(), op.getType(), 1);
662 rewriter.replaceOpWithNewOp<
comb::AddOp>(op, ValueRange{lhs, notRhs, one},
672 matchAndRewrite(
MulOp op, OpAdaptor adaptor,
673 ConversionPatternRewriter &rewriter)
const override {
674 if (adaptor.getInputs().size() != 2)
683 int64_t width = op.getType().getIntOrFloatBitWidth();
684 auto aBits =
extractBits(rewriter, adaptor.getInputs()[0]);
685 SmallVector<Value> results;
686 auto rhs = op.getInputs()[1];
688 llvm::APInt::getZero(width));
689 for (int64_t i = 0; i < width; ++i) {
690 auto aBit = aBits[i];
692 rewriter.createOrFold<
comb::MuxOp>(op.getLoc(), aBit, rhs, zero);
694 op.getLoc(), andBit, 0, width - i);
696 results.push_back(upperBits);
704 op.getLoc(), op.getType(), ValueRange{upperBits, lowerBits});
705 results.push_back(shifted);
708 rewriter.replaceOpWithNewOp<
comb::AddOp>(op, results,
true);
713template <
typename OpTy>
715 DivModOpConversionBase(MLIRContext *context, int64_t maxEmulationUnknownBits)
717 maxEmulationUnknownBits(maxEmulationUnknownBits) {
718 assert(maxEmulationUnknownBits < 32 &&
719 "maxEmulationUnknownBits must be less than 32");
721 const int64_t maxEmulationUnknownBits;
724struct CombDivUOpConversion : DivModOpConversionBase<DivUOp> {
725 using DivModOpConversionBase<
DivUOp>::DivModOpConversionBase;
727 matchAndRewrite(
DivUOp op, OpAdaptor adaptor,
728 ConversionPatternRewriter &rewriter)
const override {
730 if (
auto rhsConstantOp = adaptor.getRhs().getDefiningOp<
hw::ConstantOp>())
731 if (rhsConstantOp.getValue().isPowerOf2()) {
733 size_t extractAmount = rhsConstantOp.getValue().ceilLogBase2();
734 size_t width = op.getType().getIntOrFloatBitWidth();
736 op.getLoc(), adaptor.getLhs(), extractAmount,
737 width - extractAmount);
739 op.getLoc(), APInt::getZero(extractAmount));
741 op, op.getType(), ArrayRef<Value>{constZero, upperBits});
748 rewriter, maxEmulationUnknownBits, op,
749 [](
const APInt &lhs,
const APInt &rhs) {
752 return APInt::getZero(rhs.getBitWidth());
753 return lhs.udiv(rhs);
758struct CombModUOpConversion : DivModOpConversionBase<ModUOp> {
759 using DivModOpConversionBase<
ModUOp>::DivModOpConversionBase;
761 matchAndRewrite(
ModUOp op, OpAdaptor adaptor,
762 ConversionPatternRewriter &rewriter)
const override {
764 if (
auto rhsConstantOp = adaptor.getRhs().getDefiningOp<
hw::ConstantOp>())
765 if (rhsConstantOp.getValue().isPowerOf2()) {
767 size_t extractAmount = rhsConstantOp.getValue().ceilLogBase2();
768 size_t width = op.getType().getIntOrFloatBitWidth();
770 op.getLoc(), adaptor.getLhs(), 0, extractAmount);
772 op.getLoc(), APInt::getZero(width - extractAmount));
774 op, op.getType(), ArrayRef<Value>{constZero, lowerBits});
781 rewriter, maxEmulationUnknownBits, op,
782 [](
const APInt &lhs,
const APInt &rhs) {
785 return APInt::getZero(rhs.getBitWidth());
786 return lhs.urem(rhs);
791struct CombDivSOpConversion : DivModOpConversionBase<DivSOp> {
792 using DivModOpConversionBase<
DivSOp>::DivModOpConversionBase;
795 matchAndRewrite(
DivSOp op, OpAdaptor adaptor,
796 ConversionPatternRewriter &rewriter)
const override {
800 rewriter, maxEmulationUnknownBits, op,
801 [](
const APInt &lhs,
const APInt &rhs) {
804 return APInt::getZero(rhs.getBitWidth());
805 return lhs.sdiv(rhs);
810struct CombModSOpConversion : DivModOpConversionBase<ModSOp> {
811 using DivModOpConversionBase<
ModSOp>::DivModOpConversionBase;
813 matchAndRewrite(
ModSOp op, OpAdaptor adaptor,
814 ConversionPatternRewriter &rewriter)
const override {
818 rewriter, maxEmulationUnknownBits, op,
819 [](
const APInt &lhs,
const APInt &rhs) {
822 return APInt::getZero(rhs.getBitWidth());
823 return lhs.srem(rhs);
830 static Value constructUnsignedCompare(ICmpOp op, ArrayRef<Value> aBits,
831 ArrayRef<Value> bBits,
bool isLess,
833 ConversionPatternRewriter &rewriter) {
840 rewriter.create<
hw::ConstantOp>(op.getLoc(), op.getType(), includeEq);
842 for (
auto [aBit, bBit] :
llvm::zip(aBits, bBits)) {
844 rewriter.createOrFold<
comb::XorOp>(op.getLoc(), aBit, bBit,
true);
845 auto aEqualB = rewriter.createOrFold<aig::AndInverterOp>(
846 op.getLoc(), aBitXorBBit,
true);
847 auto pred = rewriter.createOrFold<aig::AndInverterOp>(
848 op.getLoc(), aBit, bBit, isLess, !isLess);
850 auto aBitAndBBit = rewriter.createOrFold<
comb::AndOp>(
851 op.getLoc(), ValueRange{aEqualB, acc},
true);
852 acc = rewriter.createOrFold<
comb::OrOp>(op.getLoc(), pred, aBitAndBBit,
859 matchAndRewrite(ICmpOp op, OpAdaptor adaptor,
860 ConversionPatternRewriter &rewriter)
const override {
861 auto lhs = adaptor.getLhs();
862 auto rhs = adaptor.getRhs();
864 switch (op.getPredicate()) {
868 case ICmpPredicate::eq:
869 case ICmpPredicate::ceq: {
871 auto xorOp = rewriter.createOrFold<
comb::XorOp>(op.getLoc(), lhs, rhs);
873 SmallVector<bool> allInverts(xorBits.size(),
true);
874 rewriter.replaceOpWithNewOp<aig::AndInverterOp>(op, xorBits, allInverts);
878 case ICmpPredicate::ne:
879 case ICmpPredicate::cne: {
881 auto xorOp = rewriter.createOrFold<
comb::XorOp>(op.getLoc(), lhs, rhs);
887 case ICmpPredicate::uge:
888 case ICmpPredicate::ugt:
889 case ICmpPredicate::ule:
890 case ICmpPredicate::ult: {
891 bool isLess = op.getPredicate() == ICmpPredicate::ult ||
892 op.getPredicate() == ICmpPredicate::ule;
893 bool includeEq = op.getPredicate() == ICmpPredicate::uge ||
894 op.getPredicate() == ICmpPredicate::ule;
897 rewriter.replaceOp(op, constructUnsignedCompare(op, aBits, bBits, isLess,
898 includeEq, rewriter));
901 case ICmpPredicate::slt:
902 case ICmpPredicate::sle:
903 case ICmpPredicate::sgt:
904 case ICmpPredicate::sge: {
905 if (lhs.getType().getIntOrFloatBitWidth() == 0)
906 return rewriter.notifyMatchFailure(
907 op.getLoc(),
"i0 signed comparison is unsupported");
908 bool isLess = op.getPredicate() == ICmpPredicate::slt ||
909 op.getPredicate() == ICmpPredicate::sle;
910 bool includeEq = op.getPredicate() == ICmpPredicate::sge ||
911 op.getPredicate() == ICmpPredicate::sle;
917 auto signA = aBits.back();
918 auto signB = bBits.back();
921 auto sameSignResult = constructUnsignedCompare(
922 op, ArrayRef(aBits).drop_back(), ArrayRef(bBits).drop_back(), isLess,
923 includeEq, rewriter);
927 rewriter.create<
comb::XorOp>(op.getLoc(), signA, signB);
930 Value diffSignResult = isLess ? signA : signB;
933 rewriter.replaceOpWithNewOp<
comb::MuxOp>(op, signsDiffer, diffSignResult,
945 matchAndRewrite(
ParityOp op, OpAdaptor adaptor,
946 ConversionPatternRewriter &rewriter)
const override {
949 op,
extractBits(rewriter, adaptor.getInput()),
true);
959 ConversionPatternRewriter &rewriter)
const override {
960 auto width = op.getType().getIntOrFloatBitWidth();
961 auto lhs = adaptor.getLhs();
963 rewriter, op.getLoc(), adaptor.getRhs(), width,
971 op.getLoc(), rewriter.getIntegerType(index), 0);
975 assert(index < width &&
"index out of bounds");
981 rewriter.replaceOp(op, result);
991 ConversionPatternRewriter &rewriter)
const override {
992 auto width = op.getType().getIntOrFloatBitWidth();
993 auto lhs = adaptor.getLhs();
995 rewriter, op.getLoc(), adaptor.getRhs(), width,
1003 op.getLoc(), rewriter.getIntegerType(index), 0);
1006 [&](int64_t index) {
1007 assert(index < width &&
"index out of bounds");
1009 return rewriter.createOrFold<
comb::ExtractOp>(op.getLoc(), lhs, index,
1013 rewriter.replaceOp(op, result);
1023 ConversionPatternRewriter &rewriter)
const override {
1024 auto width = op.getType().getIntOrFloatBitWidth();
1026 return rewriter.notifyMatchFailure(op.getLoc(),
1027 "i0 signed shift is unsupported");
1028 auto lhs = adaptor.getLhs();
1031 rewriter.createOrFold<
comb::ExtractOp>(op.getLoc(), lhs, width - 1, 1);
1036 rewriter, op.getLoc(), adaptor.getRhs(), width - 1,
1038 [&](int64_t index) {
1039 return rewriter.createOrFold<comb::ReplicateOp>(op.getLoc(), sign,
1043 [&](int64_t index) {
1044 return rewriter.createOrFold<
comb::ExtractOp>(op.getLoc(), lhs, index,
1048 rewriter.replaceOp(op, result);
1060struct ConvertCombToAIGPass
1061 :
public impl::ConvertCombToAIGBase<ConvertCombToAIGPass> {
1062 void runOnOperation()
override;
1063 using ConvertCombToAIGBase<ConvertCombToAIGPass>::ConvertCombToAIGBase;
1064 using ConvertCombToAIGBase<ConvertCombToAIGPass>::additionalLegalOps;
1065 using ConvertCombToAIGBase<ConvertCombToAIGPass>::maxEmulationUnknownBits;
1071 uint32_t maxEmulationUnknownBits) {
1074 CombAndOpConversion, CombOrOpConversion, CombXorOpConversion,
1075 CombMuxOpConversion, CombParityOpConversion,
1077 CombAddOpConversion, CombSubOpConversion, CombMulOpConversion,
1078 CombICmpOpConversion,
1080 CombShlOpConversion, CombShrUOpConversion, CombShrSOpConversion,
1082 CombLowerVariadicOp<XorOp>, CombLowerVariadicOp<AddOp>,
1083 CombLowerVariadicOp<MulOp>>(
patterns.getContext());
1086 patterns.add<CombDivUOpConversion, CombModUOpConversion, CombDivSOpConversion,
1087 CombModSOpConversion>(
patterns.getContext(),
1088 maxEmulationUnknownBits);
1091void ConvertCombToAIGPass::runOnOperation() {
1092 ConversionTarget target(getContext());
1095 target.addIllegalDialect<comb::CombDialect>();
1105 hw::AggregateConstantOp>();
1108 target.addLegalDialect<aig::AIGDialect>();
1111 if (!additionalLegalOps.empty())
1112 for (
const auto &opName : additionalLegalOps)
1113 target.addLegalOp(OperationName(opName, &getContext()));
1115 RewritePatternSet
patterns(&getContext());
1118 if (failed(mlir::applyPartialConversion(getOperation(), target,
1120 return signalPassFailure();
assert(baseType &&"element must be base type")
static SmallVector< T > concat(const SmallVectorImpl< T > &a, const SmallVectorImpl< T > &b)
Returns a new vector containing the concatenation of vectors a and b.
static SmallVector< Value > extractBits(OpBuilder &builder, Value val)
static Value createShiftLogic(ConversionPatternRewriter &rewriter, Location loc, Value shiftAmount, int64_t maxShiftAmount, llvm::function_ref< Value(int64_t)> getPadding, llvm::function_ref< Value(int64_t)> getExtract)
static APInt substitueMaskToValues(size_t width, llvm::SmallVectorImpl< ConstantOrValue > &constantOrValues, uint32_t mask)
static LogicalResult emulateBinaryOpForUnknownBits(ConversionPatternRewriter &rewriter, int64_t maxEmulationUnknownBits, Operation *op, llvm::function_ref< APInt(const APInt &, const APInt &)> emulate)
static int64_t getNumUnknownBitsAndPopulateValues(Value value, llvm::SmallVectorImpl< ConstantOrValue > &values)
static void populateCombToAIGConversionPatterns(RewritePatternSet &patterns, uint32_t maxEmulationUnknownBits)
static std::optional< APSInt > getConstant(Attribute operand)
Determine the value of a constant operand for the sake of constant folding.
static Value lowerFullyAssociativeOp(Operation &op, OperandRange operands, SmallVector< Operation * > &newOps)
Lower a variadic fully-associative operation into an expression tree.
The InstanceGraph op interface; see InstanceGraphInterface.td for more details.