12 #include "mlir/IR/Matchers.h"
13 #include "mlir/IR/PatternMatch.h"
14 #include "llvm/ADT/SetVector.h"
15 #include "llvm/ADT/SmallBitVector.h"
16 #include "llvm/ADT/TypeSwitch.h"
17 #include "llvm/Support/KnownBits.h"
20 using namespace circt;
22 using namespace matchers;
33 Block *thisBlock = op->getBlock();
34 return llvm::any_of(op->getOperands(), [&](Value operand) {
35 return operand.getParentBlock() != thisBlock;
45 ArrayRef<Value> operands, OpBuilder &
builder) {
46 OperationState state(loc, name);
47 state.addOperands(operands);
48 state.addTypes(operands[0].getType());
49 return builder.create(state)->getResult(0);
60 for (
auto op :
concat.getOperands())
62 }
else if (
auto repl = v.getDefiningOp<ReplicateOp>()) {
63 for (
size_t i = 0, e = repl.getMultiple(); i != e; ++i)
75 if (
auto *newOp = newValue.getDefiningOp()) {
76 auto name = op->getAttrOfType<StringAttr>(
"sv.namehint");
77 if (name && !newOp->hasAttr(
"sv.namehint"))
78 rewriter.updateRootInPlace(newOp,
79 [&] { newOp->setAttr(
"sv.namehint", name); });
81 rewriter.replaceOp(op, newValue);
87 template <
typename OpTy,
typename... Args>
89 Operation *op, Args &&...args) {
90 auto name = op->getAttrOfType<StringAttr>(
"sv.namehint");
92 rewriter.replaceOpWithNewOp<OpTy>(op, std::forward<Args>(args)...);
93 if (name && !newOp->hasAttr(
"sv.namehint"))
94 rewriter.updateRootInPlace(newOp,
95 [&] { newOp->setAttr(
"sv.namehint", name); });
104 return op->hasAttr(
"sv.attributes");
108 template <
typename SubType>
109 struct ComplementMatcher {
111 ComplementMatcher(SubType lhs) : lhs(std::move(lhs)) {}
112 bool match(Operation *op) {
113 auto xorOp = dyn_cast<XorOp>(op);
114 return xorOp && xorOp.isBinaryNot() && lhs.match(op->getOperand(0));
119 template <
typename SubType>
120 static inline ComplementMatcher<SubType>
m_Complement(
const SubType &subExpr) {
121 return ComplementMatcher<SubType>(subExpr);
130 auto inputs = op->getOperands();
132 for (
size_t i = 0, size =
inputs.size(); i != size; ++i) {
133 Operation *flattenOp =
inputs[i].getDefiningOp();
134 if (!flattenOp || flattenOp->getName() != op->getName())
142 if (!
inputs[i].hasOneUse()) {
150 if (flattenOp->getNumOperands() != 2 || !isa<AndOp, OrOp, XorOp>(op) ||
157 auto flattenOpInputs = flattenOp->getOperands();
159 SmallVector<Value, 4> newOperands;
160 newOperands.reserve(size + flattenOpInputs.size());
162 auto flattenOpIndex =
inputs.begin() + i;
163 newOperands.append(
inputs.begin(), flattenOpIndex);
164 newOperands.append(flattenOpInputs.begin(), flattenOpInputs.end());
165 newOperands.append(flattenOpIndex + 1,
inputs.end());
172 if (op->hasAttrOfType<UnitAttr>(
"twoState") &&
173 flattenOp->hasAttrOfType<UnitAttr>(
"twoState"))
174 result.getDefiningOp()->setAttr(
"twoState", rewriter.getUnitAttr());
184 static std::pair<size_t, size_t>
186 size_t originalOpWidth) {
187 auto users = op->getUsers();
189 "getLowestBitAndHighestBitRequired cannot operate on "
190 "a empty list of uses.");
194 size_t lowestBitRequired = narrowTrailingBits ? originalOpWidth - 1 : 0;
195 size_t highestBitRequired = 0;
197 for (
auto *user : users) {
198 if (
auto extractOp = dyn_cast<ExtractOp>(user)) {
199 size_t lowBit = extractOp.getLowBit();
201 extractOp.getType().cast<IntegerType>().
getWidth() + lowBit - 1;
202 highestBitRequired = std::max(highestBitRequired, highBit);
203 lowestBitRequired = std::min(lowestBitRequired, lowBit);
207 highestBitRequired = originalOpWidth - 1;
208 lowestBitRequired = 0;
212 return {lowestBitRequired, highestBitRequired};
215 template <
class OpTy>
217 PatternRewriter &rewriter) {
219 op.getResult().getType().template dyn_cast<IntegerType>();
225 if (range.second + 1 == opType.getWidth() && range.first == 0)
228 SmallVector<Value> args;
229 auto newType = rewriter.getIntegerType(range.second - range.first + 1);
230 for (
auto inop : op.getOperands()) {
232 if (inop.getType() != op.getType())
233 args.push_back(inop);
235 args.push_back(rewriter.createOrFold<
ExtractOp>(inop.getLoc(), newType,
238 Value newop = rewriter.createOrFold<OpTy>(op.getLoc(), newType, args);
239 newop.getDefiningOp()->setDialectAttrs(op->getDialectAttrs());
241 newop = rewriter.createOrFold<
ConcatOp>(
244 APInt::getZero(range.first)));
245 if (range.second + 1 < opType.getWidth())
246 newop = rewriter.createOrFold<
ConcatOp>(
249 op.getLoc(), APInt::getZero(opType.getWidth() - range.second - 1)),
251 rewriter.replaceOp(op, newop);
259 OpFoldResult ReplicateOp::fold(FoldAdaptor adaptor) {
264 if (getType().cast<IntegerType>().
getWidth() ==
265 getInput().getType().getIntOrFloatBitWidth())
269 if (
auto input = adaptor.getInput().dyn_cast_or_null<IntegerAttr>()) {
270 if (input.getValue().getBitWidth() == 1) {
271 if (input.getValue().isZero())
273 APInt::getZero(getType().cast<IntegerType>().
getWidth()),
276 APInt::getAllOnes(getType().cast<IntegerType>().
getWidth()),
280 APInt result = APInt::getZeroWidth();
281 for (
auto i = getMultiple(); i != 0; --i)
282 result = result.concat(input.getValue());
289 OpFoldResult ParityOp::fold(FoldAdaptor adaptor) {
294 if (
auto input = adaptor.getInput().dyn_cast_or_null<IntegerAttr>())
295 return getIntAttr(APInt(1, input.getValue().popcount() & 1), getContext());
307 hw::PEO paramOpcode) {
308 assert(operands.size() == 2 &&
"binary op takes two operands");
309 if (!operands[0] || !operands[1])
315 operands[1].cast<TypedAttr>());
318 OpFoldResult ShlOp::fold(FoldAdaptor adaptor) {
322 if (
auto rhs = adaptor.getRhs().dyn_cast_or_null<IntegerAttr>()) {
323 unsigned shift = rhs.getValue().getZExtValue();
324 unsigned width = getType().getIntOrFloatBitWidth();
326 return getOperand(0);
334 LogicalResult ShlOp::canonicalize(
ShlOp op, PatternRewriter &rewriter) {
340 if (!matchPattern(op.getRhs(), m_ConstantInt(&
value)))
343 unsigned width = op.getLhs().getType().cast<IntegerType>().
getWidth();
344 unsigned shift =
value.getZExtValue();
347 if (
width <= shift || shift == 0)
351 rewriter.create<
hw::ConstantOp>(op.getLoc(), APInt::getZero(shift));
357 replaceOpWithNewOpAndCopyName<ConcatOp>(rewriter, op, extract, zeros);
361 OpFoldResult ShrUOp::fold(FoldAdaptor adaptor) {
365 if (
auto rhs = adaptor.getRhs().dyn_cast_or_null<IntegerAttr>()) {
366 unsigned shift = rhs.getValue().getZExtValue();
368 return getOperand(0);
370 unsigned width = getType().getIntOrFloatBitWidth();
377 LogicalResult ShrUOp::canonicalize(
ShrUOp op, PatternRewriter &rewriter) {
383 if (!matchPattern(op.getRhs(), m_ConstantInt(&
value)))
386 unsigned width = op.getLhs().getType().cast<IntegerType>().
getWidth();
387 unsigned shift =
value.getZExtValue();
390 if (
width <= shift || shift == 0)
394 rewriter.create<
hw::ConstantOp>(op.getLoc(), APInt::getZero(shift));
397 auto extract = rewriter.
create<
ExtractOp>(op.getLoc(), op.getLhs(), shift,
400 replaceOpWithNewOpAndCopyName<ConcatOp>(rewriter, op, zeros, extract);
404 OpFoldResult ShrSOp::fold(FoldAdaptor adaptor) {
408 if (
auto rhs = adaptor.getRhs().dyn_cast_or_null<IntegerAttr>()) {
409 if (rhs.getValue().getZExtValue() == 0)
410 return getOperand(0);
415 LogicalResult ShrSOp::canonicalize(
ShrSOp op, PatternRewriter &rewriter) {
421 if (!matchPattern(op.getRhs(), m_ConstantInt(&
value)))
424 unsigned width = op.getLhs().getType().cast<IntegerType>().
getWidth();
425 unsigned shift =
value.getZExtValue();
428 rewriter.createOrFold<
ExtractOp>(op.getLoc(), op.getLhs(),
width - 1, 1);
429 auto sext = rewriter.createOrFold<ReplicateOp>(op.getLoc(), topbit, shift);
431 if (
width <= shift) {
436 auto extract = rewriter.
create<
ExtractOp>(op.getLoc(), op.getLhs(), shift,
439 replaceOpWithNewOpAndCopyName<ConcatOp>(rewriter, op, sext, extract);
447 OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
452 if (getInput().getType() == getType())
456 if (
auto input = adaptor.getInput().dyn_cast_or_null<IntegerAttr>()) {
457 unsigned dstWidth = getType().cast<IntegerType>().
getWidth();
458 return getIntAttr(input.getValue().lshr(getLowBit()).trunc(dstWidth),
469 PatternRewriter &rewriter) {
470 auto reversedConcatArgs = llvm::reverse(innerCat.getInputs());
471 size_t beginOfFirstRelevantElement = 0;
472 auto it = reversedConcatArgs.begin();
473 size_t lowBit = op.getLowBit();
476 for (; it != reversedConcatArgs.end(); it++) {
477 assert(beginOfFirstRelevantElement <= lowBit &&
478 "incorrectly moved past an element that lowBit has coverage over");
481 size_t operandWidth = operand.getType().getIntOrFloatBitWidth();
482 if (lowBit < beginOfFirstRelevantElement + operandWidth) {
506 beginOfFirstRelevantElement += operandWidth;
508 assert(it != reversedConcatArgs.end() &&
509 "incorrectly failed to find an element which contains coverage of "
512 SmallVector<Value> reverseConcatArgs;
513 size_t widthRemaining = op.getType().cast<IntegerType>().
getWidth();
514 size_t extractLo = lowBit - beginOfFirstRelevantElement;
519 for (; widthRemaining != 0 && it != reversedConcatArgs.end(); it++) {
520 auto concatArg = *it;
521 size_t operandWidth = concatArg.getType().getIntOrFloatBitWidth();
522 size_t widthToConsume = std::min(widthRemaining, operandWidth - extractLo);
524 if (widthToConsume == operandWidth && extractLo == 0) {
525 reverseConcatArgs.push_back(concatArg);
528 reverseConcatArgs.push_back(
529 rewriter.create<
ExtractOp>(op.getLoc(), resultType, *it, extractLo));
532 widthRemaining -= widthToConsume;
538 if (reverseConcatArgs.size() == 1) {
541 replaceOpWithNewOpAndCopyName<ConcatOp>(
542 rewriter, op, SmallVector<Value>(llvm::reverse(reverseConcatArgs)));
549 PatternRewriter &rewriter) {
550 auto extractResultWidth = op.getType().cast<IntegerType>().
getWidth();
551 auto replicateEltWidth =
552 replicate.getOperand().getType().getIntOrFloatBitWidth();
556 if (op.getLowBit() % replicateEltWidth == 0 &&
557 extractResultWidth % replicateEltWidth == 0) {
558 replaceOpWithNewOpAndCopyName<ReplicateOp>(rewriter, op, op.getType(),
559 replicate.getOperand());
565 if (op.getLowBit() % replicateEltWidth + extractResultWidth <=
567 replaceOpWithNewOpAndCopyName<ExtractOp>(
568 rewriter, op, op.getType(), replicate.getOperand(),
569 op.getLowBit() % replicateEltWidth);
578 LogicalResult ExtractOp::canonicalize(
ExtractOp op, PatternRewriter &rewriter) {
582 auto *inputOp = op.getInput().getDefiningOp();
589 .extractBits(op.getType().cast<IntegerType>().getWidth(),
591 if (knownBits.isConstant()) {
592 replaceOpWithNewOpAndCopyName<hw::ConstantOp>(rewriter, op,
593 knownBits.getConstant());
599 if (
auto innerExtract = dyn_cast_or_null<ExtractOp>(inputOp)) {
600 replaceOpWithNewOpAndCopyName<ExtractOp>(
601 rewriter, op, op.getType(), innerExtract.getInput(),
602 innerExtract.getLowBit() + op.getLowBit());
607 if (
auto innerCat = dyn_cast_or_null<ConcatOp>(inputOp))
611 if (
auto replicate = dyn_cast_or_null<ReplicateOp>(inputOp))
617 if (inputOp && inputOp->getNumOperands() == 2 &&
618 isa<AndOp, OrOp, XorOp>(inputOp)) {
619 if (
auto cstRHS = inputOp->getOperand(1).getDefiningOp<
hw::ConstantOp>()) {
620 auto extractedCst = cstRHS.getValue().extractBits(
621 op.getType().cast<IntegerType>().getWidth(), op.getLowBit());
622 if (isa<OrOp, XorOp>(inputOp) && extractedCst.isZero()) {
623 replaceOpWithNewOpAndCopyName<ExtractOp>(
624 rewriter, op, op.getType(), inputOp->getOperand(0), op.getLowBit());
632 if (isa<AndOp>(inputOp)) {
635 unsigned lz = extractedCst.countLeadingZeros();
636 unsigned tz = extractedCst.countTrailingZeros();
637 unsigned pop = extractedCst.popcount();
638 if (extractedCst.getBitWidth() - lz - tz == pop) {
639 auto resultTy = rewriter.getIntegerType(pop);
640 SmallVector<Value> resultElts;
643 op.getLoc(), APInt::getZero(lz)));
644 resultElts.push_back(rewriter.createOrFold<
ExtractOp>(
645 op.getLoc(), resultTy, inputOp->getOperand(0),
646 op.getLowBit() + tz));
649 op.getLoc(), APInt::getZero(tz)));
650 replaceOpWithNewOpAndCopyName<ConcatOp>(rewriter, op, resultElts);
659 if (op.getType().cast<IntegerType>().getWidth() == 1 && inputOp)
660 if (
auto shlOp = dyn_cast<ShlOp>(inputOp))
661 if (
auto lhsCst = shlOp.getOperand(0).getDefiningOp<
hw::ConstantOp>())
662 if (lhsCst.getValue().isOne()) {
665 APInt(lhsCst.getValue().getBitWidth(), op.getLowBit()));
666 replaceOpWithNewOpAndCopyName<ICmpOp>(rewriter, op, ICmpPredicate::eq,
667 shlOp->getOperand(1), newCst,
682 hw::PEO paramOpcode) {
683 assert(operands.size() > 1 &&
"caller should handle one-operand case");
686 if (!operands[1] || !operands[0])
690 if (llvm::all_of(operands.drop_front(2),
691 [&](Attribute in) { return !!in; })) {
692 SmallVector<mlir::TypedAttr> typedOperands;
693 typedOperands.reserve(operands.size());
694 for (
auto operand : operands) {
695 if (
auto typedOperand = operand.dyn_cast<mlir::TypedAttr>())
696 typedOperands.push_back(typedOperand);
700 if (typedOperands.size() == operands.size())
717 size_t concatIdx,
const APInt &cst,
718 PatternRewriter &rewriter) {
719 auto concatOp = logicalOp->getOperand(concatIdx).getDefiningOp<
ConcatOp>();
720 assert((isa<AndOp, OrOp, XorOp>(logicalOp) && concatOp));
725 llvm::any_of(concatOp->getOperands(), [&](Value operand) ->
bool {
726 auto *operandOp = operand.getDefiningOp();
731 if (isa<hw::ConstantOp>(operandOp))
735 return operandOp->getName() == logicalOp->getName() &&
736 operandOp->hasOneUse() && operandOp->getNumOperands() != 0 &&
737 operandOp->getOperands().back().getDefiningOp<hw::ConstantOp>();
745 auto createLogicalOp = [&](ArrayRef<Value> operands) -> Value {
746 return createGenericOp(logicalOp->getLoc(), logicalOp->getName(), operands,
753 SmallVector<Value> newConcatOperands;
754 newConcatOperands.reserve(concatOp->getNumOperands());
757 size_t nextOperandBit = concatOp.getType().getIntOrFloatBitWidth();
758 for (Value operand : concatOp->getOperands()) {
759 size_t operandWidth = operand.getType().getIntOrFloatBitWidth();
760 nextOperandBit -= operandWidth;
763 logicalOp->getLoc(), cst.lshr(nextOperandBit).trunc(operandWidth));
765 newConcatOperands.push_back(createLogicalOp({operand, eltCst}));
774 if (logicalOp->getNumOperands() > 2) {
775 auto origOperands = logicalOp->getOperands();
776 SmallVector<Value> operands;
778 operands.append(origOperands.begin(), origOperands.begin() + concatIdx);
780 operands.append(origOperands.begin() + concatIdx + 1,
781 origOperands.begin() + (origOperands.size() - 1));
783 operands.push_back(newResult);
784 newResult = createLogicalOp(operands);
791 OpFoldResult AndOp::fold(FoldAdaptor adaptor) {
795 APInt
value = APInt::getAllOnes(getType().cast<IntegerType>().
getWidth());
797 auto inputs = adaptor.getInputs();
800 for (
auto operand :
inputs) {
803 value &= operand.cast<IntegerAttr>().getValue();
810 inputs[1].cast<IntegerAttr>().getValue().isAllOnes())
811 return getInputs()[0];
814 if (llvm::all_of(getInputs(),
815 [&](
auto in) {
return in == this->getInputs()[0]; }))
816 return getInputs()[0];
819 for (Value arg : getInputs()) {
822 for (Value arg2 : getInputs())
825 APInt::getZero(getType().cast<IntegerType>().
getWidth()),
839 template <
typename Op>
841 if (!op.getType().isInteger(1))
844 auto inputs = op.getInputs();
845 size_t size =
inputs.size();
847 auto sourceOp =
inputs[0].template getDefiningOp<ExtractOp>();
850 Value source = sourceOp.getOperand();
853 if (size != source.getType().getIntOrFloatBitWidth())
857 llvm::BitVector bits(size);
858 bits.set(sourceOp.getLowBit());
860 for (
size_t i = 1; i != size; ++i) {
861 auto extractOp =
inputs[i].template getDefiningOp<ExtractOp>();
862 if (!extractOp || extractOp.getOperand() != source)
864 bits.set(extractOp.getLowBit());
867 return bits.all() ? source : Value();
874 template <
typename Op>
876 auto inputs = op.getInputs();
877 llvm::SmallSetVector<Value, 8> uniqueInputs;
879 for (
const auto input :
inputs)
880 uniqueInputs.insert(input);
882 if (uniqueInputs.size() <
inputs.size()) {
883 replaceOpWithNewOpAndCopyName<Op>(rewriter, op, op.getType(),
884 uniqueInputs.getArrayRef());
891 LogicalResult AndOp::canonicalize(
AndOp op, PatternRewriter &rewriter) {
895 auto inputs = op.getInputs();
896 auto size =
inputs.size();
897 assert(size > 1 &&
"expected 2 or more operands, `fold` should handle this");
906 if (matchPattern(
inputs.back(), m_ConstantInt(&
value))) {
908 if (
value.isAllOnes()) {
909 replaceOpWithNewOpAndCopyName<AndOp>(rewriter, op, op.getType(),
910 inputs.drop_back(),
false);
918 if (matchPattern(
inputs[size - 2], m_ConstantInt(&value2))) {
920 SmallVector<Value, 4> newOperands(
inputs.drop_back(2));
921 newOperands.push_back(cst);
922 replaceOpWithNewOpAndCopyName<AndOp>(rewriter, op, op.getType(),
928 if (size == 2 &&
value.isPowerOf2()) {
933 if (
auto replicate =
inputs[0].getDefiningOp<ReplicateOp>()) {
934 auto replicateOperand = replicate.getOperand();
935 if (replicateOperand.getType().isInteger(1)) {
936 unsigned resultWidth = op.getType().getIntOrFloatBitWidth();
937 auto trailingZeros =
value.countTrailingZeros();
940 SmallVector<Value, 3> concatOperands;
941 if (trailingZeros != resultWidth - 1) {
943 op.getLoc(), APInt::getZero(resultWidth - trailingZeros - 1));
944 concatOperands.push_back(highZeros);
946 concatOperands.push_back(replicateOperand);
947 if (trailingZeros != 0) {
949 op.getLoc(), APInt::getZero(trailingZeros));
950 concatOperands.push_back(lowZeros);
952 replaceOpWithNewOpAndCopyName<ConcatOp>(rewriter, op, op.getType(),
960 if (
auto extractOp =
inputs[0].getDefiningOp<ExtractOp>()) {
963 (
value.countLeadingZeros() ||
value.countTrailingZeros())) {
964 unsigned lz =
value.countLeadingZeros();
965 unsigned tz =
value.countTrailingZeros();
968 auto smallTy = rewriter.getIntegerType(
value.getBitWidth() - lz - tz);
969 Value smallElt = rewriter.createOrFold<
ExtractOp>(
970 extractOp.getLoc(), smallTy, extractOp->getOperand(0),
971 extractOp.getLowBit() + tz);
973 APInt smallMask =
value.extractBits(smallTy.getWidth(), tz);
974 if (!smallMask.isAllOnes()) {
975 auto loc =
inputs.back().getLoc();
976 smallElt = rewriter.createOrFold<
AndOp>(
983 SmallVector<Value> resultElts;
985 resultElts.push_back(
986 rewriter.create<
hw::ConstantOp>(op.getLoc(), APInt::getZero(lz)));
987 resultElts.push_back(smallElt);
989 resultElts.push_back(
990 rewriter.create<
hw::ConstantOp>(op.getLoc(), APInt::getZero(tz)));
991 replaceOpWithNewOpAndCopyName<ConcatOp>(rewriter, op, resultElts);
999 for (
size_t i = 0; i < size - 1; ++i) {
1017 rewriter.create<
hw::ConstantOp>(op.getLoc(), APInt::getAllOnes(size));
1018 replaceOpWithNewOpAndCopyName<ICmpOp>(rewriter, op, ICmpPredicate::eq,
1019 source, cmpAgainst);
1027 OpFoldResult OrOp::fold(FoldAdaptor adaptor) {
1031 auto value = APInt::getZero(getType().cast<IntegerType>().
getWidth());
1032 auto inputs = adaptor.getInputs();
1034 for (
auto operand :
inputs) {
1037 value |= operand.cast<IntegerAttr>().getValue();
1038 if (
value.isAllOnes())
1044 inputs[1].cast<IntegerAttr>().getValue().isZero())
1045 return getInputs()[0];
1048 if (llvm::all_of(getInputs(),
1049 [&](
auto in) {
return in == this->getInputs()[0]; }))
1050 return getInputs()[0];
1053 for (Value arg : getInputs()) {
1055 if (matchPattern(arg,
m_Complement(m_Any(&subExpr)))) {
1056 for (Value arg2 : getInputs())
1057 if (arg2 == subExpr)
1059 APInt::getAllOnes(getType().cast<IntegerType>().
getWidth()),
1085 PatternRewriter &rewriter) {
1086 assert(concatIdx1 < concatIdx2 &&
"concatIdx1 must be < concatIdx2");
1088 auto inputs = op.getInputs();
1092 assert(concat1 && concat2 &&
"expected indexes to point to ConcatOps");
1095 bool hasConstantOp1 =
1096 llvm::any_of(concat1->getOperands(), [&](Value operand) ->
bool {
1097 return operand.getDefiningOp<hw::ConstantOp>();
1099 if (!hasConstantOp1) {
1100 bool hasConstantOp2 =
1101 llvm::any_of(concat2->getOperands(), [&](Value operand) ->
bool {
1102 return operand.getDefiningOp<hw::ConstantOp>();
1104 if (!hasConstantOp2)
1108 SmallVector<Value> newConcatOperands;
1113 auto operands1 = concat1->getOperands();
1114 auto operands2 = concat2->getOperands();
1116 unsigned consumedWidth1 = 0;
1117 unsigned consumedWidth2 = 0;
1118 for (
auto it1 = operands1.begin(), end1 = operands1.end(),
1119 it2 = operands2.begin(), end2 = operands2.end();
1120 it1 != end1 && it2 != end2;) {
1121 auto operand1 = *it1;
1122 auto operand2 = *it2;
1124 unsigned remainingWidth1 =
1126 unsigned remainingWidth2 =
1128 unsigned widthToConsume = std::min(remainingWidth1, remainingWidth2);
1129 auto narrowedType = rewriter.getIntegerType(widthToConsume);
1131 auto extract1 = rewriter.createOrFold<
ExtractOp>(
1132 op.getLoc(), narrowedType, operand1, remainingWidth1 - widthToConsume);
1133 auto extract2 = rewriter.createOrFold<
ExtractOp>(
1134 op.getLoc(), narrowedType, operand2, remainingWidth2 - widthToConsume);
1136 newConcatOperands.push_back(
1137 rewriter.createOrFold<
OrOp>(op.getLoc(), extract1, extract2,
false));
1139 consumedWidth1 += widthToConsume;
1140 consumedWidth2 += widthToConsume;
1142 if (widthToConsume == remainingWidth1) {
1146 if (widthToConsume == remainingWidth2) {
1152 ConcatOp newOp = rewriter.create<
ConcatOp>(op.getLoc(), newConcatOperands);
1156 SmallVector<Value> newOrOperands;
1157 newOrOperands.append(
inputs.begin(),
inputs.begin() + concatIdx1);
1158 newOrOperands.append(
inputs.begin() + concatIdx1 + 1,
1159 inputs.begin() + concatIdx2);
1160 newOrOperands.append(
inputs.begin() + concatIdx2 + 1,
1162 newOrOperands.push_back(newOp);
1164 replaceOpWithNewOpAndCopyName<OrOp>(rewriter, op, op.getType(),
1169 LogicalResult OrOp::canonicalize(
OrOp op, PatternRewriter &rewriter) {
1173 auto inputs = op.getInputs();
1174 auto size =
inputs.size();
1175 assert(size > 1 &&
"expected 2 or more operands");
1184 if (matchPattern(
inputs.back(), m_ConstantInt(&
value))) {
1186 if (
value.isZero()) {
1187 replaceOpWithNewOpAndCopyName<OrOp>(rewriter, op, op.getType(),
1194 if (matchPattern(
inputs[size - 2], m_ConstantInt(&value2))) {
1196 SmallVector<Value, 4> newOperands(
inputs.drop_back(2));
1197 newOperands.push_back(cst);
1198 replaceOpWithNewOpAndCopyName<OrOp>(rewriter, op, op.getType(),
1206 for (
size_t i = 0; i < size - 1; ++i) {
1219 for (
size_t i = 0; i < size - 1; ++i) {
1221 for (
size_t j = i + 1; j < size; ++j)
1234 rewriter.create<
hw::ConstantOp>(op.getLoc(), APInt::getZero(size));
1235 replaceOpWithNewOpAndCopyName<ICmpOp>(rewriter, op, ICmpPredicate::ne,
1236 source, cmpAgainst);
1242 if (
auto firstMux = op.getOperand(0).getDefiningOp<
comb::MuxOp>()) {
1244 if (op.getTwoState() && firstMux.getTwoState() &&
1245 matchPattern(firstMux.getFalseValue(), m_ConstantInt(&
value)) &&
1247 SmallVector<Value> conditions{firstMux.getCond()};
1248 auto check = [&](Value v) {
1252 conditions.push_back(mux.getCond());
1253 return mux.getTwoState() &&
1254 firstMux.getTrueValue() == mux.getTrueValue() &&
1255 firstMux.getFalseValue() == mux.getFalseValue();
1257 if (llvm::all_of(op.getOperands().drop_front(), check)) {
1258 auto cond = rewriter.create<
comb::OrOp>(op.getLoc(), conditions,
true);
1259 replaceOpWithNewOpAndCopyName<comb::MuxOp>(
1260 rewriter, op, cond, firstMux.getTrueValue(),
1261 firstMux.getFalseValue(),
true);
1271 OpFoldResult XorOp::fold(FoldAdaptor adaptor) {
1275 auto size = getInputs().size();
1276 auto inputs = adaptor.getInputs();
1280 return getInputs()[0];
1283 if (size == 2 && getInputs()[0] == getInputs()[1])
1288 inputs[1].cast<IntegerAttr>().getValue().isZero())
1289 return getInputs()[0];
1293 if (isBinaryNot()) {
1295 if (matchPattern(getOperand(0),
m_Complement(m_Any(&subExpr))) &&
1296 subExpr != getResult())
1306 PatternRewriter &rewriter) {
1307 auto icmp = op.getOperand(icmpOperand).getDefiningOp<ICmpOp>();
1308 auto negatedPred = ICmpOp::getNegatedPredicate(icmp.getPredicate());
1311 rewriter.create<ICmpOp>(icmp.getLoc(), negatedPred, icmp.getOperand(0),
1312 icmp.getOperand(1), icmp.getTwoState());
1315 if (op.getNumOperands() > 2) {
1316 SmallVector<Value, 4> newOperands(op.getOperands());
1317 newOperands.pop_back();
1318 newOperands.erase(newOperands.begin() + icmpOperand);
1319 newOperands.push_back(result);
1320 result = rewriter.create<
XorOp>(op.getLoc(), newOperands, op.getTwoState());
1326 LogicalResult XorOp::canonicalize(
XorOp op, PatternRewriter &rewriter) {
1330 auto inputs = op.getInputs();
1331 auto size =
inputs.size();
1332 assert(size > 1 &&
"expected 2 or more operands");
1337 "expected idempotent case for 2 elements handled already.");
1338 replaceOpWithNewOpAndCopyName<XorOp>(rewriter, op, op.getType(),
1339 inputs.drop_back(2),
false);
1345 if (matchPattern(
inputs.back(), m_ConstantInt(&
value))) {
1347 if (
value.isZero()) {
1348 replaceOpWithNewOpAndCopyName<XorOp>(rewriter, op, op.getType(),
1349 inputs.drop_back(),
false);
1355 if (matchPattern(
inputs[size - 2], m_ConstantInt(&value2))) {
1357 SmallVector<Value, 4> newOperands(
inputs.drop_back(2));
1358 newOperands.push_back(cst);
1359 replaceOpWithNewOpAndCopyName<XorOp>(rewriter, op, op.getType(),
1360 newOperands,
false);
1364 bool isSingleBit =
value.getBitWidth() == 1;
1367 for (
size_t i = 0; i < size - 1; ++i) {
1368 Value operand =
inputs[i];
1379 if (isSingleBit && operand.hasOneUse()) {
1380 assert(
value == 1 &&
"single bit constant has to be one if not zero");
1381 if (
auto icmp = operand.getDefiningOp<ICmpOp>())
1397 replaceOpWithNewOpAndCopyName<ParityOp>(rewriter, op, source);
1404 OpFoldResult SubOp::fold(FoldAdaptor adaptor) {
1409 if (getRhs() == getLhs())
1411 APInt::getZero(getLhs().getType().getIntOrFloatBitWidth()),
1414 if (adaptor.getRhs()) {
1416 if (adaptor.getLhs()) {
1419 APInt::getAllOnes(getLhs().getType().getIntOrFloatBitWidth()),
1422 hw::PEO::Mul, adaptor.getRhs().cast<TypedAttr>(), negOne);
1424 adaptor.getLhs().cast<TypedAttr>(), rhsNeg);
1428 if (
auto rhsC = adaptor.getRhs().dyn_cast<IntegerAttr>()) {
1429 if (rhsC.getValue().isZero())
1437 LogicalResult SubOp::canonicalize(
SubOp op, PatternRewriter &rewriter) {
1443 if (matchPattern(op.getRhs(), m_ConstantInt(&
value))) {
1445 replaceOpWithNewOpAndCopyName<AddOp>(rewriter, op, op.getLhs(), negCst,
1457 OpFoldResult AddOp::fold(FoldAdaptor adaptor) {
1461 auto size = getInputs().size();
1465 return getInputs()[0];
1471 LogicalResult AddOp::canonicalize(
AddOp op, PatternRewriter &rewriter) {
1475 auto inputs = op.getInputs();
1476 auto size =
inputs.size();
1477 assert(size > 1 &&
"expected 2 or more operands");
1479 APInt
value, value2;
1482 if (matchPattern(
inputs.back(), m_ConstantInt(&
value)) &&
value.isZero()) {
1483 replaceOpWithNewOpAndCopyName<AddOp>(rewriter, op, op.getType(),
1484 inputs.drop_back(),
false);
1489 if (matchPattern(
inputs[size - 1], m_ConstantInt(&
value)) &&
1490 matchPattern(
inputs[size - 2], m_ConstantInt(&value2))) {
1492 SmallVector<Value, 4> newOperands(
inputs.drop_back(2));
1493 newOperands.push_back(cst);
1494 replaceOpWithNewOpAndCopyName<AddOp>(rewriter, op, op.getType(),
1495 newOperands,
false);
1501 SmallVector<Value, 4> newOperands(
inputs.drop_back(2));
1503 auto one = rewriter.create<
hw::ConstantOp>(op.getLoc(), op.getType(), 1);
1507 newOperands.push_back(shiftLeftOp);
1508 replaceOpWithNewOpAndCopyName<AddOp>(rewriter, op, op.getType(),
1509 newOperands,
false);
1515 if (shlOp && shlOp.getLhs() ==
inputs[size - 2] &&
1516 matchPattern(shlOp.getRhs(), m_ConstantInt(&
value))) {
1518 APInt one(
value.getBitWidth(), 1,
false);
1522 std::array<Value, 2> factors = {shlOp.getLhs(), rhs};
1523 auto mulOp = rewriter.create<
comb::MulOp>(op.getLoc(), factors,
false);
1525 SmallVector<Value, 4> newOperands(
inputs.drop_back(2));
1526 newOperands.push_back(mulOp);
1527 replaceOpWithNewOpAndCopyName<AddOp>(rewriter, op, op.getType(),
1528 newOperands,
false);
1534 if (mulOp && mulOp.getInputs().size() == 2 &&
1535 mulOp.getInputs()[0] ==
inputs[size - 2] &&
1536 matchPattern(mulOp.getInputs()[1], m_ConstantInt(&
value))) {
1538 APInt one(
value.getBitWidth(), 1,
false);
1540 std::array<Value, 2> factors = {mulOp.getInputs()[0], rhs};
1543 SmallVector<Value, 4> newOperands(
inputs.drop_back(2));
1544 newOperands.push_back(newMulOp);
1545 replaceOpWithNewOpAndCopyName<AddOp>(rewriter, op, op.getType(),
1546 newOperands,
false);
1560 if (addOp && addOp.getInputs().size() == 2 &&
1561 matchPattern(addOp.getInputs()[1], m_ConstantInt(&value2)) &&
1565 replaceOpWithNewOpAndCopyName<AddOp>(
1566 rewriter, op, op.getType(), ArrayRef<Value>{addOp.getInputs()[0], rhs},
1567 op.getTwoState() && addOp.getTwoState());
1574 OpFoldResult MulOp::fold(FoldAdaptor adaptor) {
1578 auto size = getInputs().size();
1579 auto inputs = adaptor.getInputs();
1583 return getInputs()[0];
1589 for (
auto operand :
inputs) {
1592 value *= operand.cast<IntegerAttr>().getValue();
1601 LogicalResult MulOp::canonicalize(
MulOp op, PatternRewriter &rewriter) {
1605 auto inputs = op.getInputs();
1606 auto size =
inputs.size();
1607 assert(size > 1 &&
"expected 2 or more operands");
1609 APInt
value, value2;
1612 if (size == 2 && matchPattern(
inputs.back(), m_ConstantInt(&
value)) &&
1613 value.isPowerOf2()) {
1614 auto shift = rewriter.create<
hw::ConstantOp>(op.getLoc(), op.getType(),
1615 value.exactLogBase2());
1619 replaceOpWithNewOpAndCopyName<MulOp>(rewriter, op, op.getType(),
1620 ArrayRef<Value>(shlOp),
false);
1626 replaceOpWithNewOpAndCopyName<MulOp>(rewriter, op, op.getType(),
1632 if (matchPattern(
inputs[size - 1], m_ConstantInt(&
value)) &&
1633 matchPattern(
inputs[size - 2], m_ConstantInt(&value2))) {
1635 SmallVector<Value, 4> newOperands(
inputs.drop_back(2));
1636 newOperands.push_back(cst);
1637 replaceOpWithNewOpAndCopyName<MulOp>(rewriter, op, op.getType(),
1653 template <
class Op,
bool isSigned>
1654 static OpFoldResult
foldDiv(Op op, ArrayRef<Attribute> constants) {
1655 if (
auto rhsValue = constants[1].dyn_cast_or_null<IntegerAttr>()) {
1657 if (rhsValue.getValue() == 1)
1661 if (rhsValue.getValue().isZero())
1668 OpFoldResult DivUOp::fold(FoldAdaptor adaptor) {
1672 return foldDiv<
DivUOp,
false>(*
this, adaptor.getOperands());
1675 OpFoldResult DivSOp::fold(FoldAdaptor adaptor) {
1682 template <
class Op,
bool isSigned>
1683 static OpFoldResult
foldMod(Op op, ArrayRef<Attribute> constants) {
1684 if (
auto rhsValue = constants[1].dyn_cast_or_null<IntegerAttr>()) {
1686 if (rhsValue.getValue() == 1)
1687 return getIntAttr(APInt::getZero(op.getType().getIntOrFloatBitWidth()),
1691 if (rhsValue.getValue().isZero())
1695 if (
auto lhsValue = constants[0].dyn_cast_or_null<IntegerAttr>()) {
1697 if (lhsValue.getValue().isZero())
1698 return getIntAttr(APInt::getZero(op.getType().getIntOrFloatBitWidth()),
1705 OpFoldResult ModUOp::fold(FoldAdaptor adaptor) {
1709 return foldMod<
ModUOp,
false>(*
this, adaptor.getOperands());
1712 OpFoldResult ModSOp::fold(FoldAdaptor adaptor) {
1723 OpFoldResult ConcatOp::fold(FoldAdaptor adaptor) {
1727 if (getNumOperands() == 1)
1728 return getOperand(0);
1731 for (
auto attr : adaptor.getInputs())
1732 if (!attr || !attr.isa<IntegerAttr>())
1736 unsigned resultWidth = getType().getIntOrFloatBitWidth();
1737 APInt result(resultWidth, 0);
1739 unsigned nextInsertion = resultWidth;
1741 for (
auto attr : adaptor.getInputs()) {
1742 auto chunk = attr.cast<IntegerAttr>().getValue();
1743 nextInsertion -= chunk.getBitWidth();
1744 result.insertBits(chunk, nextInsertion);
1750 LogicalResult ConcatOp::canonicalize(
ConcatOp op, PatternRewriter &rewriter) {
1754 auto inputs = op.getInputs();
1755 auto size =
inputs.size();
1756 assert(size > 1 &&
"expected 2 or more operands");
1761 auto flattenConcat = [&](
size_t firstOpIndex,
size_t lastOpIndex,
1762 ValueRange replacements) -> LogicalResult {
1763 SmallVector<Value, 4> newOperands;
1764 newOperands.append(
inputs.begin(),
inputs.begin() + firstOpIndex);
1765 newOperands.append(replacements.begin(), replacements.end());
1766 newOperands.append(
inputs.begin() + lastOpIndex + 1,
inputs.end());
1767 if (newOperands.size() == 1)
1770 replaceOpWithNewOpAndCopyName<ConcatOp>(rewriter, op, op.getType(),
1775 Value commonOperand =
inputs[0];
1776 for (
size_t i = 0; i != size; ++i) {
1778 if (
inputs[i] != commonOperand)
1779 commonOperand = Value();
1783 if (
auto subConcat =
inputs[i].getDefiningOp<ConcatOp>())
1784 return flattenConcat(i, i, subConcat->getOperands());
1789 if (
auto cst =
inputs[i].getDefiningOp<hw::ConstantOp>()) {
1790 if (
auto prevCst =
inputs[i - 1].getDefiningOp<hw::ConstantOp>()) {
1791 unsigned prevWidth = prevCst.getValue().getBitWidth();
1792 unsigned thisWidth = cst.getValue().getBitWidth();
1793 auto resultCst = cst.getValue().zext(prevWidth + thisWidth);
1794 resultCst |= prevCst.getValue().zext(prevWidth + thisWidth)
1798 return flattenConcat(i - 1, i, replacement);
1805 rewriter.createOrFold<ReplicateOp>(op.getLoc(),
inputs[i], 2);
1806 return flattenConcat(i - 1, i, replacement);
1811 if (
auto repl =
inputs[i].getDefiningOp<ReplicateOp>()) {
1813 if (repl.getOperand() ==
inputs[i - 1]) {
1814 Value replacement = rewriter.createOrFold<ReplicateOp>(
1815 op.getLoc(), repl.getOperand(), repl.getMultiple() + 1);
1816 return flattenConcat(i - 1, i, replacement);
1819 if (
auto prevRepl =
inputs[i - 1].getDefiningOp<ReplicateOp>()) {
1820 if (prevRepl.getOperand() == repl.getOperand()) {
1821 Value replacement = rewriter.createOrFold<ReplicateOp>(
1822 op.getLoc(), repl.getOperand(),
1823 repl.getMultiple() + prevRepl.getMultiple());
1824 return flattenConcat(i - 1, i, replacement);
1830 if (
auto repl =
inputs[i - 1].getDefiningOp<ReplicateOp>()) {
1831 if (repl.getOperand() ==
inputs[i]) {
1832 Value replacement = rewriter.createOrFold<ReplicateOp>(
1833 op.getLoc(),
inputs[i], repl.getMultiple() + 1);
1834 return flattenConcat(i - 1, i, replacement);
1840 if (
auto extract =
inputs[i].getDefiningOp<ExtractOp>()) {
1841 if (
auto prevExtract =
inputs[i - 1].getDefiningOp<ExtractOp>()) {
1842 if (extract.getInput() == prevExtract.getInput()) {
1843 auto thisWidth = extract.getType().cast<IntegerType>().
getWidth();
1844 if (prevExtract.getLowBit() == extract.getLowBit() + thisWidth) {
1845 auto prevWidth = prevExtract.getType().getIntOrFloatBitWidth();
1846 auto resType = rewriter.getIntegerType(thisWidth + prevWidth);
1847 Value replacement = rewriter.create<
ExtractOp>(
1848 op.getLoc(), resType, extract.getInput(),
1849 extract.getLowBit());
1850 return flattenConcat(i - 1, i, replacement);
1863 static std::optional<ArraySlice>
get(Value
value) {
1864 assert(
value.getType().isa<IntegerType>() &&
"expected integer type");
1866 return ArraySlice{arrayGet.getInput(), arrayGet.getIndex(), 1};
1869 if (
auto arraySlice =
1872 arraySlice.getInput(), arraySlice.getLowIndex(),
1873 hw::type_cast<hw::ArrayType>(arraySlice.getType())
1875 return std::nullopt;
1881 if (prevExtractOpt->index.getType() == extractOpt->index.getType() &&
1882 prevExtractOpt->input == extractOpt->input &&
1884 extractOpt->width)) {
1886 hw::type_cast<hw::ArrayType>(prevExtractOpt->input.getType())
1888 extractOpt->width + prevExtractOpt->width);
1891 op.getLoc(), resIntType,
1893 prevExtractOpt->input,
1894 extractOpt->index));
1895 return flattenConcat(i - 1, i, replacement);
1903 if (commonOperand) {
1904 replaceOpWithNewOpAndCopyName<ReplicateOp>(rewriter, op, op.getType(),
1916 OpFoldResult MuxOp::fold(FoldAdaptor adaptor) {
1921 if (getTrueValue() == getFalseValue())
1922 return getTrueValue();
1926 if (
auto pred = adaptor.getCond().dyn_cast_or_null<IntegerAttr>()) {
1927 if (pred.getValue().isZero())
1928 return getFalseValue();
1929 return getTrueValue();
1933 if (
auto tv = adaptor.getTrueValue().dyn_cast_or_null<IntegerAttr>())
1934 if (
auto fv = adaptor.getFalseValue().dyn_cast_or_null<IntegerAttr>())
1935 if (tv.getValue().isOne() && fv.getValue().isZero() &&
1952 if (
auto cmp = cond.getDefiningOp<ICmpOp>()) {
1954 auto requiredPredicate =
1955 (isInverted ? ICmpPredicate::eq : ICmpPredicate::ne);
1956 if (cmp.getLhs() == indexValue && cmp.getPredicate() == requiredPredicate) {
1966 if (
auto orOp = cond.getDefiningOp<
OrOp>()) {
1969 for (
auto operand : orOp.getOperands())
1976 if (
auto andOp = cond.getDefiningOp<
AndOp>()) {
1979 for (
auto operand : andOp.getOperands())
1997 PatternRewriter &rewriter) {
2000 auto rootCmp = rootMux.getCond().getDefiningOp<ICmpOp>();
2003 Value indexValue = rootCmp.getLhs();
2006 auto getCaseValue = [&](
MuxOp mux) -> Value {
2007 return mux.getOperand(1 +
unsigned(!isFalseSide));
2012 auto getTreeValue = [&](
MuxOp mux) -> Value {
2013 return mux.getOperand(1 +
unsigned(isFalseSide));
2018 SmallVector<Location> locationsFound;
2019 SmallVector<std::pair<hw::ConstantOp, Value>, 4> valuesFound;
2023 auto collectConstantValues = [&](
MuxOp mux) ->
bool {
2025 mux.getCond(), indexValue, isFalseSide, [&](
hw::ConstantOp cst) {
2026 valuesFound.push_back({cst, getCaseValue(mux)});
2027 locationsFound.push_back(mux.getCond().getLoc());
2028 locationsFound.push_back(mux->getLoc());
2033 if (!collectConstantValues(rootMux))
2037 if (rootMux->hasOneUse()) {
2038 if (
auto userMux = dyn_cast<MuxOp>(*rootMux->user_begin())) {
2039 if (getTreeValue(userMux) == rootMux.getResult() &&
2047 auto nextTreeValue = getTreeValue(rootMux);
2049 auto nextMux = nextTreeValue.getDefiningOp<
MuxOp>();
2050 if (!nextMux || !nextMux->hasOneUse())
2052 if (!collectConstantValues(nextMux))
2054 nextTreeValue = getTreeValue(nextMux);
2060 if (valuesFound.size() < 3)
2065 auto indexWidth = indexValue.getType().cast<IntegerType>().
getWidth();
2066 if (indexWidth >= 9)
2072 uint64_t tableSize = 1ULL << indexWidth;
2073 if (valuesFound.size() < (tableSize * 5) / 8)
2078 SmallVector<Value, 8> table(tableSize, nextTreeValue);
2083 for (
auto &elt : llvm::reverse(valuesFound)) {
2084 uint64_t idx = elt.first.getValue().getZExtValue();
2085 assert(idx < table.size() &&
"constant should be same bitwidth as index");
2086 table[idx] = elt.second;
2091 std::reverse(table.begin(), table.end());
2094 auto fusedLoc = rewriter.getFusedLoc(locationsFound);
2096 replaceOpWithNewOpAndCopyName<hw::ArrayGetOp>(rewriter, rootMux, array,
2111 PatternRewriter &rewriter) {
2112 assert(fullyAssoc->getNumOperands() >= 2 &&
"cannot split up unary ops");
2113 assert(operandNo < fullyAssoc->getNumOperands() &&
"Invalid operand #");
2117 if (fullyAssoc->getNumOperands() == 2)
2118 return fullyAssoc->getOperand(operandNo ^ 1);
2121 if (fullyAssoc->hasOneUse()) {
2122 fullyAssoc->eraseOperand(operandNo);
2123 return fullyAssoc->getResult(0);
2127 SmallVector<Value> operands;
2128 operands.append(fullyAssoc->getOperands().begin(),
2129 fullyAssoc->getOperands().begin() + operandNo);
2130 operands.append(fullyAssoc->getOperands().begin() + operandNo + 1,
2131 fullyAssoc->getOperands().end());
2133 fullyAssoc->getLoc(), fullyAssoc->getName(), operands, rewriter);
2134 Value excluded = fullyAssoc->getOperand(operandNo);
2138 ArrayRef<Value>{opWithoutExcluded, excluded}, rewriter);
2140 return opWithoutExcluded;
2150 PatternRewriter &rewriter) {
2153 Operation *subExpr =
2154 (isTrueOperand ? op.getFalseValue() : op.getTrueValue()).getDefiningOp();
2155 if (!subExpr || subExpr->getNumOperands() < 2)
2159 if (!isa<AndOp, XorOp, OrOp, MuxOp>(subExpr))
2164 Value commonValue = isTrueOperand ? op.getTrueValue() : op.getFalseValue();
2165 size_t opNo = 0, e = subExpr->getNumOperands();
2166 while (opNo != e && subExpr->getOperand(opNo) != commonValue)
2172 Value cond = op.getCond();
2178 if (
auto subMux = dyn_cast<MuxOp>(subExpr)) {
2180 Value subCond = subMux.getCond();
2183 if (subMux.getTrueValue() == commonValue)
2184 otherValue = subMux.getFalseValue();
2185 else if (subMux.getFalseValue() == commonValue) {
2186 otherValue = subMux.getTrueValue();
2196 cond = rewriter.createOrFold<
OrOp>(op.getLoc(), cond, subCond,
false);
2197 replaceOpWithNewOpAndCopyName<MuxOp>(rewriter, op, cond, commonValue,
2198 otherValue, op.getTwoState());
2204 bool isaAndOp = isa<AndOp>(subExpr);
2205 if (isTrueOperand ^ isaAndOp)
2209 rewriter.createOrFold<ReplicateOp>(op.getLoc(), op.getType(), cond);
2212 bool isaXorOp = isa<XorOp>(subExpr);
2213 bool isaOrOp = isa<OrOp>(subExpr);
2222 if (isaOrOp || isaXorOp) {
2223 auto masked = rewriter.createOrFold<
AndOp>(op.getLoc(), extendedCond,
2224 restOfAssoc,
false);
2226 replaceOpWithNewOpAndCopyName<XorOp>(rewriter, op, masked, commonValue,
2229 replaceOpWithNewOpAndCopyName<OrOp>(rewriter, op, masked, commonValue,
2235 assert(isaAndOp &&
"unexpected operation here");
2236 auto masked = rewriter.createOrFold<
OrOp>(op.getLoc(), extendedCond,
2237 restOfAssoc,
false);
2238 replaceOpWithNewOpAndCopyName<AndOp>(rewriter, op, masked, commonValue,
2249 PatternRewriter &rewriter) {
2252 if (!isa<ConcatOp>(trueOp))
2256 SmallVector<Value> trueOperands, falseOperands;
2260 size_t numTrueOperands = trueOperands.size();
2261 size_t numFalseOperands = falseOperands.size();
2263 if (!numTrueOperands || !numFalseOperands ||
2264 (trueOperands.front() != falseOperands.front() &&
2265 trueOperands.back() != falseOperands.back()))
2269 if (trueOperands.front() == falseOperands.front()) {
2270 SmallVector<Value> operands;
2272 for (i = 0; i < numTrueOperands; ++i) {
2273 Value trueOperand = trueOperands[i];
2274 if (trueOperand == falseOperands[i])
2275 operands.push_back(trueOperand);
2279 if (i == numTrueOperands) {
2286 if (llvm::all_of(operands, [&](Value v) {
return v == operands.front(); }))
2287 sharedMSB = rewriter.createOrFold<ReplicateOp>(
2288 mux->getLoc(), operands.front(), operands.size());
2290 sharedMSB = rewriter.createOrFold<
ConcatOp>(mux->getLoc(), operands);
2294 operands.append(trueOperands.begin() + i, trueOperands.end());
2295 Value trueLSB = rewriter.createOrFold<
ConcatOp>(trueOp->getLoc(), operands);
2297 operands.append(falseOperands.begin() + i, falseOperands.end());
2299 rewriter.createOrFold<
ConcatOp>(falseOp->getLoc(), operands);
2302 Value lsb = rewriter.createOrFold<
MuxOp>(
2303 mux->getLoc(), mux.getCond(), trueLSB, falseLSB, mux.getTwoState());
2304 replaceOpWithNewOpAndCopyName<ConcatOp>(rewriter, mux, sharedMSB, lsb);
2309 if (trueOperands.back() == falseOperands.back()) {
2310 SmallVector<Value> operands;
2313 Value trueOperand = trueOperands[numTrueOperands - i - 1];
2314 if (trueOperand == falseOperands[numFalseOperands - i - 1])
2315 operands.push_back(trueOperand);
2319 std::reverse(operands.begin(), operands.end());
2320 Value sharedLSB = rewriter.createOrFold<
ConcatOp>(mux->getLoc(), operands);
2324 operands.append(trueOperands.begin(), trueOperands.end() - i);
2325 Value trueMSB = rewriter.createOrFold<
ConcatOp>(trueOp->getLoc(), operands);
2327 operands.append(falseOperands.begin(), falseOperands.end() - i);
2329 rewriter.createOrFold<
ConcatOp>(falseOp->getLoc(), operands);
2331 Value msb = rewriter.createOrFold<
MuxOp>(
2332 mux->getLoc(), mux.getCond(), trueMSB, falseMSB, mux.getTwoState());
2333 replaceOpWithNewOpAndCopyName<ConcatOp>(rewriter, mux, msb, sharedLSB);
2345 if (!trueVec || !falseVec)
2347 if (!trueVec.isUniform() || !falseVec.isUniform())
2351 op.getLoc(), op.getCond(), trueVec.getUniformElement(),
2352 falseVec.getUniformElement(), op.getTwoState());
2354 SmallVector<Value> values(trueVec.getInputs().size(), mux);
2361 using OpRewritePattern::OpRewritePattern;
2363 LogicalResult matchAndRewrite(
MuxOp op,
2364 PatternRewriter &rewriter)
const override;
2367 LogicalResult MuxRewriter::matchAndRewrite(
MuxOp op,
2368 PatternRewriter &rewriter)
const {
2377 if (matchPattern(op.getTrueValue(), m_ConstantInt(&
value))) {
2378 if (
value.getBitWidth() == 1) {
2380 if (
value.isZero()) {
2382 replaceOpWithNewOpAndCopyName<AndOp>(rewriter, op, notCond,
2383 op.getFalseValue(),
false);
2388 replaceOpWithNewOpAndCopyName<OrOp>(rewriter, op, op.getCond(),
2389 op.getFalseValue(),
false);
2395 if (matchPattern(op.getFalseValue(), m_ConstantInt(&value2))) {
2400 APInt xorValue =
value ^ value2;
2401 if (xorValue.isPowerOf2()) {
2402 unsigned leadingZeros = xorValue.countLeadingZeros();
2403 unsigned trailingZeros =
value.getBitWidth() - leadingZeros - 1;
2404 SmallVector<Value, 3> operands;
2412 if (leadingZeros > 0)
2413 operands.push_back(rewriter.createOrFold<
ExtractOp>(
2414 op.getLoc(), op.getTrueValue(), trailingZeros + 1, leadingZeros));
2418 auto v1 = rewriter.createOrFold<
ExtractOp>(
2419 op.getLoc(), op.getTrueValue(), trailingZeros, 1);
2420 auto v2 = rewriter.createOrFold<
ExtractOp>(
2421 op.getLoc(), op.getFalseValue(), trailingZeros, 1);
2422 operands.push_back(rewriter.createOrFold<
MuxOp>(
2423 op.getLoc(), op.getCond(), v1, v2,
false));
2425 if (trailingZeros > 0)
2426 operands.push_back(rewriter.createOrFold<
ExtractOp>(
2427 op.getLoc(), op.getTrueValue(), 0, trailingZeros));
2429 replaceOpWithNewOpAndCopyName<ConcatOp>(rewriter, op, op.getType(),
2436 if (
value.isAllOnes() && value2.isZero()) {
2437 replaceOpWithNewOpAndCopyName<ReplicateOp>(rewriter, op, op.getType(),
2444 if (matchPattern(op.getFalseValue(), m_ConstantInt(&
value)) &&
2445 value.getBitWidth() == 1) {
2447 if (
value.isZero()) {
2448 replaceOpWithNewOpAndCopyName<AndOp>(rewriter, op, op.getCond(),
2449 op.getTrueValue(),
false);
2456 auto notCond = rewriter.createOrFold<
XorOp>(op.getLoc(), op.getCond(),
2457 op.getFalseValue(),
false);
2458 replaceOpWithNewOpAndCopyName<OrOp>(rewriter, op, notCond,
2459 op.getTrueValue(),
false);
2465 Operation *condOp = op.getCond().getDefiningOp();
2466 if (condOp && matchPattern(condOp,
m_Complement(m_Any(&subExpr))) &&
2468 replaceOpWithNewOpAndCopyName<MuxOp>(rewriter, op, op.getType(), subExpr,
2469 op.getFalseValue(), op.getTrueValue(),
2477 if (condOp && condOp->hasOneUse()) {
2478 SmallVector<Value> invertedOperands;
2482 auto getInvertedOperands = [&]() ->
bool {
2483 for (Value operand : condOp->getOperands()) {
2484 if (matchPattern(operand,
m_Complement(m_Any(&subExpr))))
2485 invertedOperands.push_back(subExpr);
2492 if (isa<AndOp>(condOp) && getInvertedOperands()) {
2494 rewriter.createOrFold<
OrOp>(op.getLoc(), invertedOperands,
false);
2495 replaceOpWithNewOpAndCopyName<MuxOp>(rewriter, op, newOr,
2497 op.getTrueValue(), op.getTwoState());
2500 if (isa<OrOp>(condOp) && getInvertedOperands()) {
2502 rewriter.createOrFold<
AndOp>(op.getLoc(), invertedOperands,
false);
2503 replaceOpWithNewOpAndCopyName<MuxOp>(rewriter, op, newAnd,
2505 op.getTrueValue(), op.getTwoState());
2511 dyn_cast_or_null<MuxOp>(op.getFalseValue().getDefiningOp())) {
2513 if (op.getCond() == falseMux.getCond()) {
2514 replaceOpWithNewOpAndCopyName<MuxOp>(
2515 rewriter, op, op.getCond(), op.getTrueValue(),
2516 falseMux.getFalseValue(), op.getTwoStateAttr());
2526 dyn_cast_or_null<MuxOp>(op.getTrueValue().getDefiningOp())) {
2528 if (op.getCond() == trueMux.getCond()) {
2529 replaceOpWithNewOpAndCopyName<MuxOp>(
2530 rewriter, op, op.getCond(), trueMux.getTrueValue(),
2531 op.getFalseValue(), op.getTwoStateAttr());
2541 if (
auto trueMux = dyn_cast_or_null<MuxOp>(op.getTrueValue().getDefiningOp()),
2542 falseMux = dyn_cast_or_null<MuxOp>(op.getFalseValue().getDefiningOp());
2543 trueMux && falseMux && trueMux.getCond() == falseMux.getCond() &&
2544 trueMux.getTrueValue() == falseMux.getTrueValue()) {
2545 auto subMux = rewriter.create<
MuxOp>(
2546 rewriter.getFusedLoc({trueMux.getLoc(), falseMux.getLoc()}),
2547 op.getCond(), trueMux.getFalseValue(), falseMux.getFalseValue());
2548 replaceOpWithNewOpAndCopyName<MuxOp>(rewriter, op, trueMux.getCond(),
2549 trueMux.getTrueValue(), subMux,
2550 op.getTwoStateAttr());
2555 if (
auto trueMux = dyn_cast_or_null<MuxOp>(op.getTrueValue().getDefiningOp()),
2556 falseMux = dyn_cast_or_null<MuxOp>(op.getFalseValue().getDefiningOp());
2557 trueMux && falseMux && trueMux.getCond() == falseMux.getCond() &&
2558 trueMux.getFalseValue() == falseMux.getFalseValue()) {
2559 auto subMux = rewriter.create<
MuxOp>(
2560 rewriter.getFusedLoc({trueMux.getLoc(), falseMux.getLoc()}),
2561 op.getCond(), trueMux.getTrueValue(), falseMux.getTrueValue());
2562 replaceOpWithNewOpAndCopyName<MuxOp>(rewriter, op, trueMux.getCond(),
2563 subMux, trueMux.getFalseValue(),
2564 op.getTwoStateAttr());
2569 if (
auto trueMux = dyn_cast_or_null<MuxOp>(op.getTrueValue().getDefiningOp()),
2570 falseMux = dyn_cast_or_null<MuxOp>(op.getFalseValue().getDefiningOp());
2571 trueMux && falseMux &&
2572 trueMux.getTrueValue() == falseMux.getTrueValue() &&
2573 trueMux.getFalseValue() == falseMux.getFalseValue()) {
2574 auto subMux = rewriter.create<
MuxOp>(
2575 rewriter.getFusedLoc(
2576 {op.getLoc(), trueMux.getLoc(), falseMux.getLoc()}),
2577 op.getCond(), trueMux.getCond(), falseMux.getCond());
2578 replaceOpWithNewOpAndCopyName<MuxOp>(
2579 rewriter, op, subMux, trueMux.getTrueValue(), trueMux.getFalseValue(),
2580 op.getTwoStateAttr());
2592 if (Operation *trueOp = op.getTrueValue().getDefiningOp())
2593 if (Operation *falseOp = op.getFalseValue().getDefiningOp())
2594 if (trueOp->getName() == falseOp->getName())
2611 if (op.getInputs().empty() || op.isUniform())
2613 auto inputs = op.getInputs();
2624 for (
size_t i = 1, n =
inputs.size(); i < n; ++i) {
2626 if (!input || first.getCond() != input.getCond())
2631 SmallVector<Value> trues{first.getTrueValue()};
2632 SmallVector<Value> falses{first.getFalseValue()};
2633 SmallVector<Location> locs{first->getLoc()};
2634 bool isTwoState =
true;
2635 for (
size_t i = 1, n =
inputs.size(); i < n; ++i) {
2637 trues.push_back(input.getTrueValue());
2638 falses.push_back(input.getFalseValue());
2639 locs.push_back(input->getLoc());
2640 if (!input.getTwoState())
2649 auto arrayTy = op.getType();
2652 rewriter.replaceOpWithNewOp<
comb::MuxOp>(op, arrayTy, first.getCond(),
2653 trueValues, falseValues, isTwoState);
2658 using OpRewritePattern::OpRewritePattern;
2661 PatternRewriter &rewriter)
const override {
2665 if (foldArrayOfMuxes(op, rewriter))
2673 void MuxOp::getCanonicalizationPatterns(RewritePatternSet &results,
2674 MLIRContext *context) {
2675 results.insert<MuxRewriter, ArrayRewriter>(context);
2686 switch (predicate) {
2687 case ICmpPredicate::eq:
2689 case ICmpPredicate::ne:
2691 case ICmpPredicate::slt:
2692 return lhs.slt(rhs);
2693 case ICmpPredicate::sle:
2694 return lhs.sle(rhs);
2695 case ICmpPredicate::sgt:
2696 return lhs.sgt(rhs);
2697 case ICmpPredicate::sge:
2698 return lhs.sge(rhs);
2699 case ICmpPredicate::ult:
2700 return lhs.ult(rhs);
2701 case ICmpPredicate::ule:
2702 return lhs.ule(rhs);
2703 case ICmpPredicate::ugt:
2704 return lhs.ugt(rhs);
2705 case ICmpPredicate::uge:
2706 return lhs.uge(rhs);
2707 case ICmpPredicate::ceq:
2709 case ICmpPredicate::cne:
2711 case ICmpPredicate::weq:
2713 case ICmpPredicate::wne:
2716 llvm_unreachable(
"unknown comparison predicate");
2722 switch (predicate) {
2723 case ICmpPredicate::eq:
2724 case ICmpPredicate::sle:
2725 case ICmpPredicate::sge:
2726 case ICmpPredicate::ule:
2727 case ICmpPredicate::uge:
2728 case ICmpPredicate::ceq:
2729 case ICmpPredicate::weq:
2731 case ICmpPredicate::ne:
2732 case ICmpPredicate::slt:
2733 case ICmpPredicate::sgt:
2734 case ICmpPredicate::ult:
2735 case ICmpPredicate::ugt:
2736 case ICmpPredicate::cne:
2737 case ICmpPredicate::wne:
2740 llvm_unreachable(
"unknown comparison predicate");
2743 OpFoldResult ICmpOp::fold(FoldAdaptor adaptor) {
2749 if (getLhs() == getRhs()) {
2755 if (
auto lhs = adaptor.getLhs().dyn_cast_or_null<IntegerAttr>()) {
2756 if (
auto rhs = adaptor.getRhs().dyn_cast_or_null<IntegerAttr>()) {
2767 template <
typename Range>
2769 size_t commonPrefixLength = 0;
2770 auto ia = a.begin();
2771 auto ib = b.begin();
2773 for (; ia != a.end() && ib != b.end(); ia++, ib++, commonPrefixLength++) {
2779 return commonPrefixLength;
2783 size_t totalWidth = 0;
2784 for (
auto operand : operands) {
2787 ssize_t
width = operand.getType().getIntOrFloatBitWidth();
2789 totalWidth +=
width;
2799 PatternRewriter &rewriter) {
2803 SmallVector<Value> lhsOperands, rhsOperands;
2806 ArrayRef<Value> lhsOperandsRef = lhsOperands, rhsOperandsRef = rhsOperands;
2808 auto formCatOrReplicate = [&](Location loc,
2809 ArrayRef<Value> operands) -> Value {
2810 assert(!operands.empty());
2811 Value sameElement = operands[0];
2812 for (
size_t i = 1, e = operands.size(); i != e && sameElement; ++i)
2813 if (sameElement != operands[i])
2814 sameElement = Value();
2816 return rewriter.createOrFold<ReplicateOp>(loc, sameElement,
2818 return rewriter.createOrFold<
ConcatOp>(loc, operands);
2821 auto replaceWith = [&](ICmpPredicate predicate, Value lhs,
2822 Value rhs) -> LogicalResult {
2823 replaceOpWithNewOpAndCopyName<ICmpOp>(rewriter, op, predicate, lhs, rhs,
2828 size_t commonPrefixLength =
2830 if (commonPrefixLength == lhsOperands.size()) {
2833 replaceOpWithNewOpAndCopyName<hw::ConstantOp>(rewriter, op,
2839 llvm::reverse(lhsOperandsRef), llvm::reverse(rhsOperandsRef));
2841 size_t commonPrefixTotalWidth =
2842 getTotalWidth(lhsOperandsRef.take_front(commonPrefixLength));
2843 size_t commonSuffixTotalWidth =
2844 getTotalWidth(lhsOperandsRef.take_back(commonSuffixLength));
2845 auto lhsOnly = lhsOperandsRef.drop_front(commonPrefixLength)
2846 .drop_back(commonSuffixLength);
2847 auto rhsOnly = rhsOperandsRef.drop_front(commonPrefixLength)
2848 .drop_back(commonSuffixLength);
2850 auto replaceWithoutReplicatingSignBit = [&]() {
2851 auto newLhs = formCatOrReplicate(lhs->getLoc(), lhsOnly);
2852 auto newRhs = formCatOrReplicate(rhs->getLoc(), rhsOnly);
2853 return replaceWith(op.getPredicate(), newLhs, newRhs);
2856 auto replaceWithReplicatingSignBit = [&]() {
2857 auto firstNonEmptyValue = lhsOperands[0];
2858 auto firstNonEmptyElemWidth =
2859 firstNonEmptyValue.getType().getIntOrFloatBitWidth();
2860 Value signBit = rewriter.createOrFold<
ExtractOp>(
2861 op.getLoc(), firstNonEmptyValue, firstNonEmptyElemWidth - 1, 1);
2863 auto newLhs = rewriter.
create<
ConcatOp>(lhs->getLoc(), signBit, lhsOnly);
2864 auto newRhs = rewriter.create<
ConcatOp>(rhs->getLoc(), signBit, rhsOnly);
2865 return replaceWith(op.getPredicate(), newLhs, newRhs);
2868 if (ICmpOp::isPredicateSigned(op.getPredicate())) {
2870 if (commonPrefixTotalWidth == 0 && commonSuffixTotalWidth > 0)
2871 return replaceWithoutReplicatingSignBit();
2877 if (commonPrefixTotalWidth > 1 || commonSuffixTotalWidth > 0)
2878 return replaceWithReplicatingSignBit();
2880 }
else if (commonPrefixTotalWidth > 0 || commonSuffixTotalWidth > 0) {
2882 return replaceWithoutReplicatingSignBit();
2896 ICmpOp cmpOp,
const KnownBits &bitAnalysis,
const APInt &rhsCst,
2897 PatternRewriter &rewriter) {
2901 APInt bitsKnown = bitAnalysis.Zero | bitAnalysis.One;
2902 if ((bitsKnown & rhsCst) != bitAnalysis.One) {
2905 bool result = cmpOp.getPredicate() == ICmpPredicate::ne;
2906 replaceOpWithNewOpAndCopyName<hw::ConstantOp>(rewriter, cmpOp,
2914 SmallVector<Value> newConcatOperands;
2915 auto newConstant = APInt::getZeroWidth();
2920 unsigned knownMSB = bitsKnown.countLeadingOnes();
2922 Value operand = cmpOp.getLhs();
2927 while (knownMSB != bitsKnown.getBitWidth()) {
2930 bitsKnown = bitsKnown.trunc(bitsKnown.getBitWidth() - knownMSB);
2933 unsigned unknownBits = bitsKnown.countLeadingZeros();
2934 unsigned lowBit = bitsKnown.getBitWidth() - unknownBits;
2935 auto spanOperand = rewriter.createOrFold<
ExtractOp>(
2936 operand.getLoc(), operand, lowBit,
2938 auto spanConstant = rhsCst.lshr(lowBit).trunc(unknownBits);
2941 newConcatOperands.push_back(spanOperand);
2944 if (newConstant.getBitWidth() != 0)
2945 newConstant = newConstant.concat(spanConstant);
2947 newConstant = spanConstant;
2950 unsigned newWidth = bitsKnown.getBitWidth() - unknownBits;
2951 bitsKnown = bitsKnown.trunc(newWidth);
2952 knownMSB = bitsKnown.countLeadingOnes();
2958 if (newConcatOperands.empty()) {
2959 bool result = cmpOp.getPredicate() == ICmpPredicate::eq;
2960 replaceOpWithNewOpAndCopyName<hw::ConstantOp>(rewriter, cmpOp,
2966 Value concatResult =
2967 rewriter.createOrFold<
ConcatOp>(operand.getLoc(), newConcatOperands);
2971 cmpOp.getOperand(1).getLoc(), newConstant);
2973 replaceOpWithNewOpAndCopyName<ICmpOp>(rewriter, cmpOp, cmpOp.getPredicate(),
2974 concatResult, newConstantOp,
2975 cmpOp.getTwoState());
2981 PatternRewriter &rewriter) {
2982 auto ip = rewriter.saveInsertionPoint();
2983 rewriter.setInsertionPoint(xorOp);
2985 auto xorRHS = xorOp.getOperands().back().getDefiningOp<
hw::ConstantOp>();
2987 xorRHS.getValue() ^ rhs);
2989 switch (xorOp.getNumOperands()) {
2993 APInt::getZero(rhs.getBitWidth()));
2997 newLHS = xorOp.getOperand(0);
3001 SmallVector<Value> newOperands(xorOp.getOperands());
3002 newOperands.pop_back();
3003 newLHS = rewriter.create<
XorOp>(xorOp.getLoc(), newOperands,
false);
3007 bool xorMultipleUses = !xorOp->hasOneUse();
3011 if (xorMultipleUses)
3012 replaceOpWithNewOpAndCopyName<XorOp>(rewriter, xorOp, newLHS, xorRHS,
3016 rewriter.restoreInsertionPoint(ip);
3017 replaceOpWithNewOpAndCopyName<ICmpOp>(rewriter, cmpOp, cmpOp.getPredicate(),
3018 newLHS, newRHS,
false);
3021 LogicalResult ICmpOp::canonicalize(ICmpOp op, PatternRewriter &rewriter) {
3028 if (matchPattern(op.getLhs(), m_ConstantInt(&lhs))) {
3029 assert(!matchPattern(op.getRhs(), m_ConstantInt(&rhs)) &&
3030 "Should be folded");
3031 replaceOpWithNewOpAndCopyName<ICmpOp>(
3032 rewriter, op, ICmpOp::getFlippedPredicate(op.getPredicate()),
3033 op.getRhs(), op.getLhs(), op.getTwoState());
3038 if (matchPattern(op.getRhs(), m_ConstantInt(&rhs))) {
3040 return rewriter.create<
hw::ConstantOp>(op.getLoc(), std::move(constant));
3043 auto replaceWith = [&](ICmpPredicate predicate, Value lhs,
3044 Value rhs) -> LogicalResult {
3045 replaceOpWithNewOpAndCopyName<ICmpOp>(rewriter, op, predicate, lhs, rhs,
3050 auto replaceWithConstantI1 = [&](
bool constant) -> LogicalResult {
3051 replaceOpWithNewOpAndCopyName<hw::ConstantOp>(rewriter, op,
3052 APInt(1, constant));
3056 switch (op.getPredicate()) {
3057 case ICmpPredicate::slt:
3059 if (rhs.isMaxSignedValue())
3060 return replaceWith(ICmpPredicate::ne, op.getLhs(), op.getRhs());
3062 if (rhs.isMinSignedValue())
3063 return replaceWithConstantI1(0);
3065 if ((rhs - 1).isMinSignedValue())
3066 return replaceWith(ICmpPredicate::eq, op.getLhs(),
3069 case ICmpPredicate::sgt:
3071 if (rhs.isMinSignedValue())
3072 return replaceWith(ICmpPredicate::ne, op.getLhs(), op.getRhs());
3074 if (rhs.isMaxSignedValue())
3075 return replaceWithConstantI1(0);
3077 if ((rhs + 1).isMaxSignedValue())
3078 return replaceWith(ICmpPredicate::eq, op.getLhs(),
3081 case ICmpPredicate::ult:
3083 if (rhs.isAllOnes())
3084 return replaceWith(ICmpPredicate::ne, op.getLhs(), op.getRhs());
3087 return replaceWithConstantI1(0);
3089 if ((rhs - 1).isZero())
3090 return replaceWith(ICmpPredicate::eq, op.getLhs(),
3094 if (rhs.countLeadingOnes() + rhs.countTrailingZeros() ==
3095 rhs.getBitWidth()) {
3096 auto numOnes = rhs.countLeadingOnes();
3097 auto smaller = rewriter.create<
ExtractOp>(
3098 op.getLoc(), op.getLhs(), rhs.getBitWidth() - numOnes, numOnes);
3099 return replaceWith(ICmpPredicate::ne, smaller,
3104 case ICmpPredicate::ugt:
3107 return replaceWith(ICmpPredicate::ne, op.getLhs(), op.getRhs());
3109 if (rhs.isAllOnes())
3110 return replaceWithConstantI1(0);
3112 if ((rhs + 1).isAllOnes())
3113 return replaceWith(ICmpPredicate::eq, op.getLhs(),
3117 if ((rhs + 1).isPowerOf2()) {
3118 auto numOnes = rhs.countTrailingOnes();
3119 auto newWidth = rhs.getBitWidth() - numOnes;
3120 auto smaller = rewriter.create<
ExtractOp>(op.getLoc(), op.getLhs(),
3122 return replaceWith(ICmpPredicate::ne, smaller,
3127 case ICmpPredicate::sle:
3129 if (rhs.isMaxSignedValue())
3130 return replaceWithConstantI1(1);
3132 return replaceWith(ICmpPredicate::slt, op.getLhs(),
getConstant(rhs + 1));
3133 case ICmpPredicate::sge:
3135 if (rhs.isMinSignedValue())
3136 return replaceWithConstantI1(1);
3138 return replaceWith(ICmpPredicate::sgt, op.getLhs(),
getConstant(rhs - 1));
3139 case ICmpPredicate::ule:
3141 if (rhs.isAllOnes())
3142 return replaceWithConstantI1(1);
3144 return replaceWith(ICmpPredicate::ult, op.getLhs(),
getConstant(rhs + 1));
3145 case ICmpPredicate::uge:
3148 return replaceWithConstantI1(1);
3150 return replaceWith(ICmpPredicate::ugt, op.getLhs(),
getConstant(rhs - 1));
3151 case ICmpPredicate::eq:
3152 if (rhs.getBitWidth() == 1) {
3155 replaceOpWithNewOpAndCopyName<XorOp>(rewriter, op, op.getLhs(),
3160 if (rhs.isAllOnes()) {
3167 case ICmpPredicate::ne:
3168 if (rhs.getBitWidth() == 1) {
3174 if (rhs.isAllOnes()) {
3176 replaceOpWithNewOpAndCopyName<XorOp>(rewriter, op, op.getLhs(),
3183 case ICmpPredicate::ceq:
3184 case ICmpPredicate::cne:
3185 case ICmpPredicate::weq:
3186 case ICmpPredicate::wne:
3192 if (op.getPredicate() == ICmpPredicate::eq ||
3193 op.getPredicate() == ICmpPredicate::ne) {
3198 if (!knownBits.isUnknown())
3205 if (
auto xorOp = op.getLhs().getDefiningOp<
XorOp>())
3212 if (
auto replicateOp = op.getLhs().getDefiningOp<ReplicateOp>())
3213 if (rhs.isAllOnes() || rhs.isZero()) {
3214 auto width = replicateOp.getInput().getType().getIntOrFloatBitWidth();
3216 op.getLoc(), rhs.isAllOnes() ? APInt::getAllOnes(
width)
3217 : APInt::getZero(
width));
3218 replaceOpWithNewOpAndCopyName<ICmpOp>(rewriter, op, op.getPredicate(),
3219 replicateOp.getInput(), cst,
3229 if (Operation *opLHS = op.getLhs().getDefiningOp())
3230 if (Operation *opRHS = op.getRhs().getDefiningOp())
3231 if (isa<ConcatOp, ReplicateOp>(opLHS) &&
3232 isa<ConcatOp, ReplicateOp>(opRHS)) {
assert(baseType &&"element must be base type")
static SmallVector< T > concat(const SmallVectorImpl< T > &a, const SmallVectorImpl< T > &b)
Returns a new vector containing the concatenation of vectors a and b.
static KnownBits computeKnownBits(Value v, unsigned depth)
Given an integer SSA value, check to see if we know anything about the result of the computation.
static bool foldMuxOfUniformArrays(MuxOp op, PatternRewriter &rewriter)
static Attribute constFoldAssociativeOp(ArrayRef< Attribute > operands, hw::PEO paramOpcode)
static Attribute constFoldBinaryOp(ArrayRef< Attribute > operands, hw::PEO paramOpcode)
Performs constant-folding calculations with element-wise behavior on the two attributes in operands and ...
static bool applyCmpPredicateToEqualOperands(ICmpPredicate predicate)
static bool canonicalizeLogicalCstWithConcat(Operation *logicalOp, size_t concatIdx, const APInt &cst, PatternRewriter &rewriter)
When we find a logical operation (and, or, xor) with a constant e.g.
static bool hasOperandsOutsideOfBlock(Operation *op)
In comb, we assume no knowledge of the semantics of cross-block dataflow.
static bool narrowOperationWidth(OpTy op, bool narrowTrailingBits, PatternRewriter &rewriter)
static OpFoldResult foldDiv(Op op, ArrayRef< Attribute > constants)
static Value getCommonOperand(Op op)
Returns a single common operand that all inputs of the operation op can be traced back to,...
static void getConcatOperands(Value v, SmallVectorImpl< Value > &result)
Flatten concat and replicate operands into a vector.
static OpTy replaceOpWithNewOpAndCopyName(PatternRewriter &rewriter, Operation *op, Args &&...args)
A wrapper of PatternRewriter::replaceOpWithNewOp to propagate "sv.namehint" attribute.
static Value extractOperandFromFullyAssociative(Operation *fullyAssoc, size_t operandNo, PatternRewriter &rewriter)
Given a fully associative variadic operation like (a+b+c+d), break the expression into two parts,...
static bool getMuxChainCondConstant(Value cond, Value indexValue, bool isInverted, std::function< void(hw::ConstantOp)> constantFn)
Check to see if the condition to the specified mux is an equality comparison of indexValue and one or mo...
static TypedAttr getIntAttr(const APInt &value, MLIRContext *context)
static std::pair< size_t, size_t > getLowestBitAndHighestBitRequired(Operation *op, bool narrowTrailingBits, size_t originalOpWidth)
static void canonicalizeXorIcmpTrue(XorOp op, unsigned icmpOperand, PatternRewriter &rewriter)
static bool extractFromReplicate(ExtractOp op, ReplicateOp replicate, PatternRewriter &rewriter)
static bool canonicalizeOrOfConcatsWithCstOperands(OrOp op, size_t concatIdx1, size_t concatIdx2, PatternRewriter &rewriter)
Simplify concat ops in an or op when a constant operand is present in either concat.
static void combineEqualityICmpWithXorOfConstant(ICmpOp cmpOp, XorOp xorOp, const APInt &rhs, PatternRewriter &rewriter)
static size_t getTotalWidth(ArrayRef< Value > operands)
static bool foldCommonMuxOperation(MuxOp mux, Operation *trueOp, Operation *falseOp, PatternRewriter &rewriter)
This function is invoked when we find a mux with true/false operations that have the same opcode.
static bool tryFlatteningOperands(Operation *op, PatternRewriter &rewriter)
Flattens a single input in op if hasOneUse is true and it can be defined as an Op.
static bool canonicalizeIdempotentInputs(Op op, PatternRewriter &rewriter)
Canonicalize an idempotent operation op so that only one input of any kind occurs.
static bool applyCmpPredicate(ICmpPredicate predicate, const APInt &lhs, const APInt &rhs)
static void combineEqualityICmpWithKnownBitsAndConstant(ICmpOp cmpOp, const KnownBits &bitAnalysis, const APInt &rhsCst, PatternRewriter &rewriter)
Given an equality comparison with a constant value and some operand that has known bits,...
static bool foldMuxChain(MuxOp rootMux, bool isFalseSide, PatternRewriter &rewriter)
Given a mux, check to see if the "on true" value (or "on false" value if isFalseSide=true) is a mux t...
static ComplementMatcher< SubType > m_Complement(const SubType &subExpr)
static bool hasSVAttributes(Operation *op)
static LogicalResult extractConcatToConcatExtract(ExtractOp op, ConcatOp innerCat, PatternRewriter &rewriter)
static OpFoldResult foldMod(Op op, ArrayRef< Attribute > constants)
static size_t computeCommonPrefixLength(const Range &a, const Range &b)
static bool foldCommonMuxValue(MuxOp op, bool isTrueOperand, PatternRewriter &rewriter)
Fold things like mux(cond, x|y|z|a, a) -> (x|y|z)&replicate(cond)|a and mux(cond, a,...
static LogicalResult matchAndRewriteCompareConcat(ICmpOp op, Operation *lhs, Operation *rhs, PatternRewriter &rewriter)
Reduce the strength icmp(concat(...), concat(...)) by doing a element-wise comparison on common prefi...
static Value createGenericOp(Location loc, OperationName name, ArrayRef< Value > operands, OpBuilder &builder)
Create a new instance of a generic operation that only has value operands, and has a single result va...
static void replaceOpAndCopyName(PatternRewriter &rewriter, Operation *op, Value newValue)
A wrapper of PatternRewriter::replaceOp to propagate the "sv.namehint" attribute.
static std::optional< APSInt > getConstant(Attribute operand)
Determine the value of a constant operand for the sake of constant folding.
llvm::SmallVector< StringAttr > inputs
def create(data_type, value)
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
Value createOrFoldNot(Location loc, Value value, OpBuilder &builder, bool twoState=false)
Create a `Not` gate on a value.
uint64_t getWidth(Type t)
std::optional< int64_t > getBitWidth(FIRRTLBaseType type, bool ignoreFlip=false)
bool isOffset(Value base, Value index, uint64_t offset)
This file defines an intermediate representation for circuits acting as an abstraction for constraint...