12 #include "mlir/IR/Matchers.h"
13 #include "mlir/IR/PatternMatch.h"
14 #include "llvm/ADT/SetVector.h"
15 #include "llvm/ADT/SmallBitVector.h"
16 #include "llvm/ADT/TypeSwitch.h"
17 #include "llvm/Support/KnownBits.h"
20 using namespace circt;
22 using namespace matchers;
33 Block *thisBlock = op->getBlock();
34 return llvm::any_of(op->getOperands(), [&](Value operand) {
35 return operand.getParentBlock() != thisBlock;
45 ArrayRef<Value> operands, OpBuilder &builder) {
46 OperationState state(loc, name);
47 state.addOperands(operands);
48 state.addTypes(operands[0].getType());
49 return builder.create(state)->getResult(0);
52 static TypedAttr
getIntAttr(
const APInt &value, MLIRContext *context) {
60 for (
auto op :
concat.getOperands())
62 }
else if (
auto repl = v.getDefiningOp<ReplicateOp>()) {
63 for (
size_t i = 0, e = repl.getMultiple(); i != e; ++i)
75 if (
auto *newOp = newValue.getDefiningOp()) {
76 auto name = op->getAttrOfType<StringAttr>(
"sv.namehint");
77 if (name && !newOp->hasAttr(
"sv.namehint"))
78 rewriter.modifyOpInPlace(newOp,
79 [&] { newOp->setAttr(
"sv.namehint", name); });
81 rewriter.replaceOp(op, newValue);
87 template <
typename OpTy,
typename... Args>
89 Operation *op, Args &&...args) {
90 auto name = op->getAttrOfType<StringAttr>(
"sv.namehint");
92 rewriter.replaceOpWithNewOp<OpTy>(op, std::forward<Args>(args)...);
93 if (name && !newOp->hasAttr(
"sv.namehint"))
94 rewriter.modifyOpInPlace(newOp,
95 [&] { newOp->setAttr(
"sv.namehint", name); });
104 return op->hasAttr(
"sv.attributes");
/// Matcher for the bitwise complement of a sub-expression: it accepts an
/// operation only if that operation is an XorOp for which isBinaryNot() holds
/// and whose operand 0 is accepted by the wrapped sub-matcher `lhs`.
108 template <
typename SubType>
109 struct ComplementMatcher {
// Takes the sub-matcher by value and moves it into the `lhs` member.
111 ComplementMatcher(SubType lhs) : lhs(std::move(lhs)) {}
112 bool match(Operation *op) {
// dyn_cast yields a null XorOp for any other operation, which short-circuits
// the conjunction below to false.
113 auto xorOp = dyn_cast<XorOp>(op);
114 return xorOp && xorOp.isBinaryNot() && lhs.match(op->getOperand(0));
/// Factory helper that wraps `subExpr` in a ComplementMatcher, letting the
/// sub-matcher type be deduced from the argument instead of spelled out.
119 template <
typename SubType>
120 static inline ComplementMatcher<SubType>
m_Complement(
const SubType &subExpr) {
121 return ComplementMatcher<SubType>(subExpr);
127 assert((isa<AndOp, OrOp, XorOp, AddOp, MulOp>(op) &&
128 "must be commutative operations"));
129 if (op->hasOneUse()) {
130 auto *user = *op->getUsers().begin();
131 return user->getName() == op->getName() &&
132 op->getAttrOfType<UnitAttr>(
"twoState") ==
133 user->getAttrOfType<UnitAttr>(
"twoState") &&
134 op->getBlock() == user->getBlock();
// NOTE(review): the signature of this routine and several interior lines are
// outside the visible excerpt; the comments below describe only the statements
// that are shown.
149 auto inputs = op->getOperands();
// Flattened operand list and the locations of every op that contributed to it
// (seeded with the location of `op` itself).
151 SmallVector<Value, 4> newOperands;
152 SmallVector<Location, 4> newLocations{op->getLoc()};
153 newOperands.reserve(inputs.size());
// Iterator pair over an operand range (presumably the fields of the `Element`
// type used by the worklist below -- its declaration line is not shown).
155 decltype(inputs.begin()) current, end;
// Explicit stack of operand ranges still to be visited, starting with the
// operands of `op`.
158 SmallVector<Element> worklist;
159 worklist.push_back({inputs.begin(), inputs.end()});
// Whether `op` carries the "twoState" unit attribute.
160 bool binFlag = op->hasAttrOfType<UnitAttr>(
"twoState");
161 bool changed =
false;
162 while (!worklist.empty()) {
163 auto &element = worklist.back();
// Range exhausted (the handling of this case is on lines not shown).
166 if (element.current == element.end) {
// Take the next operand of the range on top of the stack and advance it.
171 Value value = *element.current++;
172 auto *flattenOp = value.getDefiningOp();
// Operands that are not produced by an op of the same name in the same block
// (or that are produced by `op` itself) are kept unchanged.
// NOTE(review): `binFlag` was computed from `op` above, so comparing it with
// `op->hasAttrOfType<UnitAttr>("twoState")` again can never be true; this
// check presumably meant to query `flattenOp` instead -- confirm.
175 if (!flattenOp || flattenOp->getName() != op->getName() ||
176 flattenOp == op || binFlag != op->hasAttrOfType<UnitAttr>(
"twoState") ||
177 flattenOp->getBlock() != op->getBlock()) {
178 newOperands.push_back(value);
// Values with more than one use get separate treatment (body not shown).
183 if (!value.hasOneUse()) {
// Further restriction involving two-operand and/or/xor ops; the remainder of
// the condition is on lines not shown.
191 if (flattenOp->getNumOperands() != 2 || !isa<AndOp, OrOp, XorOp>(op) ||
194 newOperands.push_back(value);
// Descend into the nested op: push its operand range and remember its
// location.
202 auto flattenOpInputs = flattenOp->getOperands();
203 worklist.push_back({flattenOpInputs.begin(), flattenOpInputs.end()});
204 newLocations.push_back(flattenOp->getLoc());
// Tail of a call that builds a replacement from `newOperands` using the same
// operation name as `op` (the start of the call is not shown).
211 op->getName(), newOperands, rewriter);
// Marks the new op as "twoState"; the guarding condition is on a line not
// shown (presumably `binFlag` -- confirm).
213 result.getDefiningOp()->setAttr(
"twoState", rewriter.getUnitAttr());
221 static std::pair<size_t, size_t>
223 size_t originalOpWidth) {
224 auto users = op->getUsers();
226 "getLowestBitAndHighestBitRequired cannot operate on "
227 "a empty list of uses.");
231 size_t lowestBitRequired = narrowTrailingBits ? originalOpWidth - 1 : 0;
232 size_t highestBitRequired = 0;
234 for (
auto *user : users) {
235 if (
auto extractOp = dyn_cast<ExtractOp>(user)) {
236 size_t lowBit = extractOp.getLowBit();
238 cast<IntegerType>(extractOp.getType()).getWidth() + lowBit - 1;
239 highestBitRequired = std::max(highestBitRequired, highBit);
240 lowestBitRequired = std::min(lowestBitRequired, lowBit);
244 highestBitRequired = originalOpWidth - 1;
245 lowestBitRequired = 0;
249 return {lowestBitRequired, highestBitRequired};
252 template <
class OpTy>
254 PatternRewriter &rewriter) {
255 IntegerType opType = dyn_cast<IntegerType>(op.getResult().getType());
261 if (range.second + 1 == opType.getWidth() && range.first == 0)
264 SmallVector<Value> args;
265 auto newType = rewriter.getIntegerType(range.second - range.first + 1);
266 for (
auto inop : op.getOperands()) {
268 if (inop.getType() != op.getType())
269 args.push_back(inop);
271 args.push_back(rewriter.createOrFold<
ExtractOp>(inop.getLoc(), newType,
274 auto newop = rewriter.create<OpTy>(op.getLoc(), newType, args);
275 newop->setDialectAttrs(op->getDialectAttrs());
276 if (op.getTwoState())
277 newop.setTwoState(
true);
279 Value newResult = newop.getResult();
281 newResult = rewriter.createOrFold<
ConcatOp>(
282 op.getLoc(), newResult,
284 APInt::getZero(range.first)));
285 if (range.second + 1 < opType.getWidth())
286 newResult = rewriter.createOrFold<
ConcatOp>(
289 op.getLoc(), APInt::getZero(opType.getWidth() - range.second - 1)),
291 rewriter.replaceOp(op, newResult);
/// Constant folder for ReplicateOp. Several lines (including some return
/// statements) are missing from this excerpt.
299 OpFoldResult ReplicateOp::fold(FoldAdaptor adaptor) {
// Case: the result is exactly as wide as the input (the value returned for
// this case is on a line not shown).
304 if (cast<IntegerType>(getType()).
getWidth() ==
305 getInput().getType().getIntOrFloatBitWidth())
// Case: the input folded to an integer constant.
309 if (
auto input = dyn_cast_or_null<IntegerAttr>(adaptor.getInput())) {
// Single-bit constant: the visible lines build an all-zeros constant of the
// result width when the bit is zero, and an all-ones constant of the result
// width otherwise.
310 if (input.getValue().getBitWidth() == 1) {
311 if (input.getValue().isZero())
313 APInt::getZero(cast<IntegerType>(getType()).
getWidth()),
316 APInt::getAllOnes(cast<IntegerType>(getType()).
getWidth()),
// Wider constant: start from a zero-width APInt and concatenate the input
// value getMultiple() times.
320 APInt result = APInt::getZeroWidth();
321 for (
auto i = getMultiple(); i != 0; --i)
322 result = result.concat(input.getValue());
/// Constant folder for ParityOp: a constant integer input folds to a 1-bit
/// constant holding the low bit of the input's population count.
329 OpFoldResult ParityOp::fold(FoldAdaptor adaptor) {
334 if (
auto input = dyn_cast_or_null<IntegerAttr>(adaptor.getInput()))
335 return getIntAttr(APInt(1, input.getValue().popcount() & 1), getContext());
347 hw::PEO paramOpcode) {
348 assert(operands.size() == 2 &&
"binary op takes two operands");
349 if (!operands[0] || !operands[1])
355 cast<TypedAttr>(operands[1]));
/// Constant folder for ShlOp. Only the constant-shift-amount path is visible
/// in this excerpt.
358 OpFoldResult ShlOp::fold(FoldAdaptor adaptor) {
// Proceed only when the shift amount folded to an integer constant.
362 if (
auto rhs = dyn_cast_or_null<IntegerAttr>(adaptor.getRhs())) {
// NOTE(review): getZExtValue() yields a 64-bit value that is narrowed to
// `unsigned` here, and it asserts for values needing more than 64 bits;
// confirm very large shift constants cannot reach this point.
363 unsigned shift = rhs.getValue().getZExtValue();
364 unsigned width = getType().getIntOrFloatBitWidth();
// The conditions guarding the two returns below are on lines not shown;
// presumably a zero shift returns the operand unchanged and a shift of at
// least `width` folds to zero -- confirm.
366 return getOperand(0);
368 return getIntAttr(APInt::getZero(width), getContext());
380 if (!matchPattern(op.getRhs(), m_ConstantInt(&value)))
383 unsigned width = cast<IntegerType>(op.getLhs().getType()).getWidth();
384 unsigned shift = value.getZExtValue();
387 if (width <= shift || shift == 0)
391 rewriter.create<
hw::ConstantOp>(op.getLoc(), APInt::getZero(shift));
395 rewriter.
create<
ExtractOp>(op.getLoc(), op.getLhs(), 0, width - shift);
397 replaceOpWithNewOpAndCopyName<ConcatOp>(rewriter, op, extract, zeros);
/// Constant folder for ShrUOp. Only the constant-shift-amount path is visible
/// in this excerpt.
401 OpFoldResult ShrUOp::fold(FoldAdaptor adaptor) {
// Proceed only when the shift amount folded to an integer constant.
405 if (
auto rhs = dyn_cast_or_null<IntegerAttr>(adaptor.getRhs())) {
// NOTE(review): getZExtValue() yields a 64-bit value that is narrowed to
// `unsigned` here, and it asserts for values needing more than 64 bits;
// confirm very large shift constants cannot reach this point.
406 unsigned shift = rhs.getValue().getZExtValue();
// Guard not shown; presumably taken when `shift` is zero -- confirm.
408 return getOperand(0);
410 unsigned width = getType().getIntOrFloatBitWidth();
// Guard not shown; presumably taken when `shift` is at least `width` --
// confirm.
412 return getIntAttr(APInt::getZero(width), getContext());
423 if (!matchPattern(op.getRhs(), m_ConstantInt(&value)))
426 unsigned width = cast<IntegerType>(op.getLhs().getType()).getWidth();
427 unsigned shift = value.getZExtValue();
430 if (width <= shift || shift == 0)
434 rewriter.create<
hw::ConstantOp>(op.getLoc(), APInt::getZero(shift));
437 auto extract = rewriter.
create<
ExtractOp>(op.getLoc(), op.getLhs(), shift,
440 replaceOpWithNewOpAndCopyName<ConcatOp>(rewriter, op, zeros, extract);
/// Constant folder for ShrSOp. The visible part folds a shift by a constant
/// zero to the unshifted first operand.
444 OpFoldResult ShrSOp::fold(FoldAdaptor adaptor) {
448 if (
auto rhs = dyn_cast_or_null<IntegerAttr>(adaptor.getRhs())) {
// NOTE(review): getZExtValue() asserts for values needing more than 64 bits;
// confirm wide shift constants cannot reach this comparison.
449 if (rhs.getValue().getZExtValue() == 0)
450 return getOperand(0);
// NOTE(review): the header of this routine and several interior lines are
// outside the visible excerpt; comments describe only the statements shown.
// Only a constant right-hand side is handled (the early exit taken otherwise
// is on a line not shown).
461 if (!matchPattern(op.getRhs(), m_ConstantInt(&value)))
464 unsigned width = cast<IntegerType>(op.getLhs().getType()).getWidth();
465 unsigned shift = value.getZExtValue();
// One-bit extract at position width - 1 of the left-hand side, i.e. its top
// bit (the variable it is assigned to is declared on a line not shown).
468 rewriter.createOrFold<
ExtractOp>(op.getLoc(), op.getLhs(), width - 1, 1);
// Replicate the top bit `shift` times.
// NOTE(review): this is built before the `width <= shift` check below, so for
// shift > width the replicated value is wider than the result; confirm the
// branch on the missing lines accounts for that.
469 auto sext = rewriter.createOrFold<ReplicateOp>(op.getLoc(), topbit, shift);
471 if (width <= shift) {
// Remaining bits of the left-hand side starting at bit `shift` (the width
// argument is on a line not shown).
476 auto extract = rewriter.
create<
ExtractOp>(op.getLoc(), op.getLhs(), shift,
// Replace the shift with concat(replicated top bit, extracted bits).
479 replaceOpWithNewOpAndCopyName<ConcatOp>(rewriter, op, sext, extract);
/// Constant folder for ExtractOp. Some lines are missing from this excerpt.
487 OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
// Case: input and result types are identical (the value returned for this
// case is on a line not shown).
492 if (getInput().getType() == getType())
// Case: constant input -- shift the constant right by the low bit index and
// truncate it to the result width.
496 if (
auto input = dyn_cast_or_null<IntegerAttr>(adaptor.getInput())) {
497 unsigned dstWidth = cast<IntegerType>(getType()).getWidth();
498 return getIntAttr(input.getValue().lshr(getLowBit()).trunc(dstWidth),
509 PatternRewriter &rewriter) {
510 auto reversedConcatArgs = llvm::reverse(innerCat.getInputs());
511 size_t beginOfFirstRelevantElement = 0;
512 auto it = reversedConcatArgs.begin();
513 size_t lowBit = op.getLowBit();
516 for (; it != reversedConcatArgs.end(); it++) {
517 assert(beginOfFirstRelevantElement <= lowBit &&
518 "incorrectly moved past an element that lowBit has coverage over");
521 size_t operandWidth = operand.getType().getIntOrFloatBitWidth();
522 if (lowBit < beginOfFirstRelevantElement + operandWidth) {
546 beginOfFirstRelevantElement += operandWidth;
548 assert(it != reversedConcatArgs.end() &&
549 "incorrectly failed to find an element which contains coverage of "
552 SmallVector<Value> reverseConcatArgs;
553 size_t widthRemaining = cast<IntegerType>(op.getType()).getWidth();
554 size_t extractLo = lowBit - beginOfFirstRelevantElement;
559 for (; widthRemaining != 0 && it != reversedConcatArgs.end(); it++) {
560 auto concatArg = *it;
561 size_t operandWidth = concatArg.getType().getIntOrFloatBitWidth();
562 size_t widthToConsume = std::min(widthRemaining, operandWidth - extractLo);
564 if (widthToConsume == operandWidth && extractLo == 0) {
565 reverseConcatArgs.push_back(concatArg);
568 reverseConcatArgs.push_back(
569 rewriter.create<
ExtractOp>(op.getLoc(), resultType, *it, extractLo));
572 widthRemaining -= widthToConsume;
578 if (reverseConcatArgs.size() == 1) {
581 replaceOpWithNewOpAndCopyName<ConcatOp>(
582 rewriter, op, SmallVector<Value>(llvm::reverse(reverseConcatArgs)));
589 PatternRewriter &rewriter) {
590 auto extractResultWidth = cast<IntegerType>(op.getType()).getWidth();
591 auto replicateEltWidth =
592 replicate.getOperand().getType().getIntOrFloatBitWidth();
596 if (op.getLowBit() % replicateEltWidth == 0 &&
597 extractResultWidth % replicateEltWidth == 0) {
598 replaceOpWithNewOpAndCopyName<ReplicateOp>(rewriter, op, op.getType(),
599 replicate.getOperand());
605 if (op.getLowBit() % replicateEltWidth + extractResultWidth <=
607 replaceOpWithNewOpAndCopyName<ExtractOp>(
608 rewriter, op, op.getType(), replicate.getOperand(),
609 op.getLowBit() % replicateEltWidth);
622 auto *inputOp = op.getInput().getDefiningOp();
629 .extractBits(cast<IntegerType>(op.getType()).getWidth(),
631 if (knownBits.isConstant()) {
632 replaceOpWithNewOpAndCopyName<hw::ConstantOp>(rewriter, op,
633 knownBits.getConstant());
639 if (
auto innerExtract = dyn_cast_or_null<ExtractOp>(inputOp)) {
640 replaceOpWithNewOpAndCopyName<ExtractOp>(
641 rewriter, op, op.getType(), innerExtract.getInput(),
642 innerExtract.getLowBit() + op.getLowBit());
647 if (
auto innerCat = dyn_cast_or_null<ConcatOp>(inputOp))
651 if (
auto replicate = dyn_cast_or_null<ReplicateOp>(inputOp))
657 if (inputOp && inputOp->getNumOperands() == 2 &&
658 isa<AndOp, OrOp, XorOp>(inputOp)) {
659 if (
auto cstRHS = inputOp->getOperand(1).getDefiningOp<
hw::ConstantOp>()) {
660 auto extractedCst = cstRHS.getValue().extractBits(
661 cast<IntegerType>(op.getType()).getWidth(), op.getLowBit());
662 if (isa<OrOp, XorOp>(inputOp) && extractedCst.isZero()) {
663 replaceOpWithNewOpAndCopyName<ExtractOp>(
664 rewriter, op, op.getType(), inputOp->getOperand(0), op.getLowBit());
672 if (isa<AndOp>(inputOp)) {
675 unsigned lz = extractedCst.countLeadingZeros();
676 unsigned tz = extractedCst.countTrailingZeros();
677 unsigned pop = extractedCst.popcount();
678 if (extractedCst.getBitWidth() - lz - tz == pop) {
679 auto resultTy = rewriter.getIntegerType(pop);
680 SmallVector<Value> resultElts;
683 op.getLoc(), APInt::getZero(lz)));
684 resultElts.push_back(rewriter.createOrFold<
ExtractOp>(
685 op.getLoc(), resultTy, inputOp->getOperand(0),
686 op.getLowBit() + tz));
689 op.getLoc(), APInt::getZero(tz)));
690 replaceOpWithNewOpAndCopyName<ConcatOp>(rewriter, op, resultElts);
699 if (cast<IntegerType>(op.getType()).getWidth() == 1 && inputOp)
700 if (
auto shlOp = dyn_cast<ShlOp>(inputOp)) {
702 if (shlOp->hasOneUse())
704 if (lhsCst.getValue().isOne()) {
707 APInt(lhsCst.getValue().getBitWidth(), op.getLowBit()));
708 replaceOpWithNewOpAndCopyName<ICmpOp>(
709 rewriter, op, ICmpPredicate::eq, shlOp->getOperand(1), newCst,
725 hw::PEO paramOpcode) {
726 assert(operands.size() > 1 &&
"caller should handle one-operand case");
729 if (!operands[1] || !operands[0])
733 if (llvm::all_of(operands.drop_front(2),
734 [&](Attribute in) { return !!in; })) {
735 SmallVector<mlir::TypedAttr> typedOperands;
736 typedOperands.reserve(operands.size());
737 for (
auto operand : operands) {
738 if (
auto typedOperand = dyn_cast<mlir::TypedAttr>(operand))
739 typedOperands.push_back(typedOperand);
743 if (typedOperands.size() == operands.size())
760 size_t concatIdx,
const APInt &cst,
761 PatternRewriter &rewriter) {
762 auto concatOp = logicalOp->getOperand(concatIdx).getDefiningOp<
ConcatOp>();
763 assert((isa<AndOp, OrOp, XorOp>(logicalOp) && concatOp));
768 llvm::any_of(concatOp->getOperands(), [&](Value operand) ->
bool {
769 auto *operandOp = operand.getDefiningOp();
774 if (isa<hw::ConstantOp>(operandOp))
778 return operandOp->getName() == logicalOp->getName() &&
779 operandOp->hasOneUse() && operandOp->getNumOperands() != 0 &&
780 operandOp->getOperands().back().getDefiningOp<hw::ConstantOp>();
788 auto createLogicalOp = [&](ArrayRef<Value> operands) -> Value {
789 return createGenericOp(logicalOp->getLoc(), logicalOp->getName(), operands,
796 SmallVector<Value> newConcatOperands;
797 newConcatOperands.reserve(concatOp->getNumOperands());
800 size_t nextOperandBit = concatOp.getType().getIntOrFloatBitWidth();
801 for (Value operand : concatOp->getOperands()) {
802 size_t operandWidth = operand.getType().getIntOrFloatBitWidth();
803 nextOperandBit -= operandWidth;
806 logicalOp->getLoc(), cst.lshr(nextOperandBit).trunc(operandWidth));
808 newConcatOperands.push_back(createLogicalOp({operand, eltCst}));
817 if (logicalOp->getNumOperands() > 2) {
818 auto origOperands = logicalOp->getOperands();
819 SmallVector<Value> operands;
821 operands.append(origOperands.begin(), origOperands.begin() + concatIdx);
823 operands.append(origOperands.begin() + concatIdx + 1,
824 origOperands.begin() + (origOperands.size() - 1));
826 operands.push_back(newResult);
827 newResult = createLogicalOp(operands);
837 llvm::SmallDenseSet<std::tuple<ICmpPredicate, Value, Value>> seenPredicates;
839 for (
auto op : operands) {
840 if (
auto icmpOp = op.getDefiningOp<ICmpOp>();
841 icmpOp && icmpOp.getTwoState()) {
842 auto predicate = icmpOp.getPredicate();
843 auto lhs = icmpOp.getLhs();
844 auto rhs = icmpOp.getRhs();
845 if (seenPredicates.contains(
846 {ICmpOp::getNegatedPredicate(predicate), lhs, rhs}))
849 seenPredicates.insert({predicate, lhs, rhs});
/// Constant folder for AndOp. Several lines (including some returns) are
/// missing from this excerpt.
855 OpFoldResult AndOp::fold(FoldAdaptor adaptor) {
// Accumulator for the AND of constant operands, seeded with all-ones of the
// result width.
859 APInt value = APInt::getAllOnes(cast<IntegerType>(getType()).
getWidth());
861 auto inputs = adaptor.getInputs();
// AND each operand attribute into the accumulator (any guard for non-constant
// operands is on lines not shown).
864 for (
auto operand : inputs) {
867 value &= cast<IntegerAttr>(operand).getValue();
// and(x, all-ones) -> x for the two-operand form.
873 if (inputs.size() == 2 && inputs[1] &&
874 cast<IntegerAttr>(inputs[1]).getValue().isAllOnes())
875 return getInputs()[0];
// Every operand is the same value as the first one -> that value.
878 if (llvm::all_of(getInputs(),
879 [&](
auto in) {
return in == this->getInputs()[0]; }))
880 return getInputs()[0];
// Pairwise scan over the operands that ends in an all-zeros constant of the
// result width; the matching condition is on lines not shown (presumably an
// operand together with its complement -- confirm).
883 for (Value arg : getInputs()) {
886 for (Value arg2 : getInputs())
889 APInt::getZero(cast<IntegerType>(getType()).
getWidth()),
910 template <
typename Op>
912 if (!op.getType().isInteger(1))
915 auto inputs = op.getInputs();
916 size_t size = inputs.size();
918 auto sourceOp = inputs[0].template getDefiningOp<ExtractOp>();
921 Value source = sourceOp.getOperand();
924 if (size != source.getType().getIntOrFloatBitWidth())
928 llvm::BitVector bits(size);
929 bits.set(sourceOp.getLowBit());
931 for (
size_t i = 1; i != size; ++i) {
932 auto extractOp = inputs[i].template getDefiningOp<ExtractOp>();
933 if (!extractOp || extractOp.getOperand() != source)
935 bits.set(extractOp.getLowBit());
938 return bits.all() ? source : Value();
945 template <
typename Op>
948 constexpr
unsigned limit = 3;
949 auto inputs = op.getInputs();
951 llvm::SmallSetVector<Value, 8> uniqueInputs(inputs.begin(), inputs.end());
952 llvm::SmallDenseSet<Op, 8> checked;
959 llvm::SmallVector<OpWithDepth, 8> worklist;
961 auto enqueue = [&worklist, &checked, &op](Value input,
unsigned depth) {
965 if (depth < limit && input.getParentBlock() == op->getBlock()) {
966 auto inputOp = input.template getDefiningOp<Op>();
967 if (inputOp && inputOp.getTwoState() == op.getTwoState() &&
968 checked.insert(inputOp).second)
969 worklist.push_back({inputOp, depth + 1});
973 for (
auto input : uniqueInputs)
976 while (!worklist.empty()) {
977 auto item = worklist.pop_back_val();
979 for (
auto input : item.op.getInputs()) {
980 uniqueInputs.remove(input);
981 enqueue(input, item.depth);
985 if (uniqueInputs.size() < inputs.size()) {
986 replaceOpWithNewOpAndCopyName<Op>(rewriter, op, op.getType(),
987 uniqueInputs.getArrayRef(),
996 auto inputs = op.getInputs();
997 auto size = inputs.size();
1011 assert(size > 1 &&
"expected 2 or more operands, `fold` should handle this");
1015 if (matchPattern(inputs.back(), m_ConstantInt(&value))) {
1017 if (value.isAllOnes()) {
1018 replaceOpWithNewOpAndCopyName<AndOp>(rewriter, op, op.getType(),
1019 inputs.drop_back(),
false);
1027 if (matchPattern(inputs[size - 2], m_ConstantInt(&value2))) {
1028 auto cst = rewriter.create<
hw::ConstantOp>(op.getLoc(), value & value2);
1029 SmallVector<Value, 4> newOperands(inputs.drop_back(2));
1030 newOperands.push_back(cst);
1031 replaceOpWithNewOpAndCopyName<AndOp>(rewriter, op, op.getType(),
1032 newOperands,
false);
1037 if (size == 2 && value.isPowerOf2()) {
1042 if (
auto replicate = inputs[0].getDefiningOp<ReplicateOp>()) {
1043 auto replicateOperand = replicate.getOperand();
1044 if (replicateOperand.getType().isInteger(1)) {
1045 unsigned resultWidth = op.getType().getIntOrFloatBitWidth();
1046 auto trailingZeros = value.countTrailingZeros();
1049 SmallVector<Value, 3> concatOperands;
1050 if (trailingZeros != resultWidth - 1) {
1052 op.getLoc(), APInt::getZero(resultWidth - trailingZeros - 1));
1053 concatOperands.push_back(highZeros);
1055 concatOperands.push_back(replicateOperand);
1056 if (trailingZeros != 0) {
1058 op.getLoc(), APInt::getZero(trailingZeros));
1059 concatOperands.push_back(lowZeros);
1061 replaceOpWithNewOpAndCopyName<ConcatOp>(rewriter, op, op.getType(),
1069 if (
auto extractOp = inputs[0].getDefiningOp<ExtractOp>()) {
1072 (value.countLeadingZeros() || value.countTrailingZeros())) {
1073 unsigned lz = value.countLeadingZeros();
1074 unsigned tz = value.countTrailingZeros();
1077 auto smallTy = rewriter.getIntegerType(value.getBitWidth() - lz - tz);
1078 Value smallElt = rewriter.createOrFold<
ExtractOp>(
1079 extractOp.getLoc(), smallTy, extractOp->getOperand(0),
1080 extractOp.getLowBit() + tz);
1082 APInt smallMask = value.extractBits(smallTy.getWidth(), tz);
1083 if (!smallMask.isAllOnes()) {
1084 auto loc = inputs.back().getLoc();
1085 smallElt = rewriter.createOrFold<
AndOp>(
1092 SmallVector<Value> resultElts;
1094 resultElts.push_back(
1095 rewriter.create<
hw::ConstantOp>(op.getLoc(), APInt::getZero(lz)));
1096 resultElts.push_back(smallElt);
1098 resultElts.push_back(
1099 rewriter.create<
hw::ConstantOp>(op.getLoc(), APInt::getZero(tz)));
1100 replaceOpWithNewOpAndCopyName<ConcatOp>(rewriter, op, resultElts);
1108 for (
size_t i = 0; i < size - 1; ++i) {
1109 if (
auto concat = inputs[i].getDefiningOp<ConcatOp>())
1122 rewriter.create<
hw::ConstantOp>(op.getLoc(), APInt::getAllOnes(size));
1123 replaceOpWithNewOpAndCopyName<ICmpOp>(rewriter, op, ICmpPredicate::eq,
1124 source, cmpAgainst);
/// Constant folder for OrOp. Several lines (including some returns) are
/// missing from this excerpt.
1132 OpFoldResult OrOp::fold(FoldAdaptor adaptor) {
// Accumulator for the OR of constant operands, seeded with zero of the result
// width.
1136 auto value = APInt::getZero(cast<IntegerType>(getType()).
getWidth());
1137 auto inputs = adaptor.getInputs();
// OR each operand attribute into the accumulator and test for saturation to
// all-ones (the action taken then is on a line not shown).
1139 for (
auto operand : inputs) {
1142 value |= cast<IntegerAttr>(operand).getValue();
1143 if (value.isAllOnes())
// or(x, 0) -> x for the two-operand form.
1148 if (inputs.size() == 2 && inputs[1] &&
1149 cast<IntegerAttr>(inputs[1]).getValue().isZero())
1150 return getInputs()[0];
// Every operand is the same value as the first one -> that value.
1153 if (llvm::all_of(getInputs(),
1154 [&](
auto in) {
return in == this->getInputs()[0]; }))
1155 return getInputs()[0];
// Look for an operand that is the complement of some `subExpr` while
// `subExpr` itself is also an operand; the visible continuation builds an
// all-ones constant of the result width.
1158 for (Value arg : getInputs()) {
1160 if (matchPattern(arg,
m_Complement(m_Any(&subExpr)))) {
1161 for (Value arg2 : getInputs())
1162 if (arg2 == subExpr)
1164 APInt::getAllOnes(cast<IntegerType>(getType()).
getWidth()),
// A second all-ones result; the condition leading to it is on lines not
// shown.
1174 APInt::getAllOnes(cast<IntegerType>(getType()).
getWidth()),
1182 auto inputs = op.getInputs();
1183 auto size = inputs.size();
1197 assert(size > 1 &&
"expected 2 or more operands");
1201 if (matchPattern(inputs.back(), m_ConstantInt(&value))) {
1203 if (value.isZero()) {
1204 replaceOpWithNewOpAndCopyName<OrOp>(rewriter, op, op.getType(),
1205 inputs.drop_back());
1211 if (matchPattern(inputs[size - 2], m_ConstantInt(&value2))) {
1212 auto cst = rewriter.create<
hw::ConstantOp>(op.getLoc(), value | value2);
1213 SmallVector<Value, 4> newOperands(inputs.drop_back(2));
1214 newOperands.push_back(cst);
1215 replaceOpWithNewOpAndCopyName<OrOp>(rewriter, op, op.getType(),
1223 for (
size_t i = 0; i < size - 1; ++i) {
1224 if (
auto concat = inputs[i].getDefiningOp<ConcatOp>())
1237 rewriter.create<
hw::ConstantOp>(op.getLoc(), APInt::getZero(size));
1238 replaceOpWithNewOpAndCopyName<ICmpOp>(rewriter, op, ICmpPredicate::ne,
1239 source, cmpAgainst);
1245 if (
auto firstMux = op.getOperand(0).getDefiningOp<
comb::MuxOp>()) {
1247 if (op.getTwoState() && firstMux.getTwoState() &&
1248 matchPattern(firstMux.getFalseValue(), m_ConstantInt(&value)) &&
1250 SmallVector<Value> conditions{firstMux.getCond()};
1251 auto check = [&](Value v) {
1255 conditions.push_back(mux.getCond());
1256 return mux.getTwoState() &&
1257 firstMux.getTrueValue() == mux.getTrueValue() &&
1258 firstMux.getFalseValue() == mux.getFalseValue();
1260 if (llvm::all_of(op.getOperands().drop_front(), check)) {
1261 auto cond = rewriter.create<
comb::OrOp>(op.getLoc(), conditions,
true);
1262 replaceOpWithNewOpAndCopyName<comb::MuxOp>(
1263 rewriter, op, cond, firstMux.getTrueValue(),
1264 firstMux.getFalseValue(),
true);
/// Constant folder for XorOp. Several lines (including some returns) are
/// missing from this excerpt.
1274 OpFoldResult XorOp::fold(FoldAdaptor adaptor) {
1278 auto size = getInputs().size();
1279 auto inputs = adaptor.getInputs();
// Guard not shown; presumably the single-operand case -- confirm.
1283 return getInputs()[0];
// xor(x, x): the folded value is on a line not shown.
1286 if (size == 2 && getInputs()[0] == getInputs()[1])
// xor(x, 0) -> x for the two-operand form.
1290 if (inputs.size() == 2 && inputs[1] &&
1291 cast<IntegerAttr>(inputs[1]).getValue().isZero())
1292 return getInputs()[0];
// This op is a binary "not": check whether its operand is itself the
// complement of some `subExpr`, excluding the case where `subExpr` is this
// op's own result (the value returned is on a line not shown).
1296 if (isBinaryNot()) {
1298 if (matchPattern(getOperand(0),
m_Complement(m_Any(&subExpr))) &&
1299 subExpr != getResult())
1309 PatternRewriter &rewriter) {
1310 auto icmp = op.getOperand(icmpOperand).getDefiningOp<ICmpOp>();
1311 auto negatedPred = ICmpOp::getNegatedPredicate(icmp.getPredicate());
1314 rewriter.create<ICmpOp>(icmp.getLoc(), negatedPred, icmp.getOperand(0),
1315 icmp.getOperand(1), icmp.getTwoState());
1318 if (op.getNumOperands() > 2) {
1319 SmallVector<Value, 4> newOperands(op.getOperands());
1320 newOperands.pop_back();
1321 newOperands.erase(newOperands.begin() + icmpOperand);
1322 newOperands.push_back(result);
1323 result = rewriter.create<
XorOp>(op.getLoc(), newOperands, op.getTwoState());
1333 auto inputs = op.getInputs();
1334 auto size = inputs.size();
1335 assert(size > 1 &&
"expected 2 or more operands");
1338 if (inputs[size - 1] == inputs[size - 2]) {
1340 "expected idempotent case for 2 elements handled already.");
1341 replaceOpWithNewOpAndCopyName<XorOp>(rewriter, op, op.getType(),
1342 inputs.drop_back(2),
false);
1348 if (matchPattern(inputs.back(), m_ConstantInt(&value))) {
1350 if (value.isZero()) {
1351 replaceOpWithNewOpAndCopyName<XorOp>(rewriter, op, op.getType(),
1352 inputs.drop_back(),
false);
1358 if (matchPattern(inputs[size - 2], m_ConstantInt(&value2))) {
1359 auto cst = rewriter.create<
hw::ConstantOp>(op.getLoc(), value ^ value2);
1360 SmallVector<Value, 4> newOperands(inputs.drop_back(2));
1361 newOperands.push_back(cst);
1362 replaceOpWithNewOpAndCopyName<XorOp>(rewriter, op, op.getType(),
1363 newOperands,
false);
1367 bool isSingleBit = value.getBitWidth() == 1;
1370 for (
size_t i = 0; i < size - 1; ++i) {
1371 Value operand = inputs[i];
1382 if (isSingleBit && operand.hasOneUse()) {
1383 assert(value == 1 &&
"single bit constant has to be one if not zero");
1384 if (
auto icmp = operand.getDefiningOp<ICmpOp>())
1400 replaceOpWithNewOpAndCopyName<ParityOp>(rewriter, op, source);
/// Constant folder for SubOp. Several lines are missing from this excerpt.
1407 OpFoldResult SubOp::fold(FoldAdaptor adaptor) {
// Both operands are the same value: the visible continuation builds a zero of
// the operand width.
1412 if (getRhs() == getLhs())
1414 APInt::getZero(getLhs().getType().getIntOrFloatBitWidth()),
// Right-hand side folded to a constant attribute.
1417 if (adaptor.getRhs()) {
// Both sides constant: the visible pieces build an all-ones (-1) constant,
// multiply the right-hand side by it, and combine the result with the
// left-hand side (the combining opcode is on a line not shown).
1419 if (adaptor.getLhs()) {
1422 APInt::getAllOnes(getLhs().getType().getIntOrFloatBitWidth()),
1425 hw::PEO::Mul, cast<TypedAttr>(adaptor.getRhs()), negOne);
1427 cast<TypedAttr>(adaptor.getLhs()), rhsNeg);
// Right-hand side is the integer constant zero (the value returned is on a
// line not shown).
1431 if (
auto rhsC = dyn_cast<IntegerAttr>(adaptor.getRhs())) {
1432 if (rhsC.getValue().isZero())
1446 if (matchPattern(op.getRhs(), m_ConstantInt(&value))) {
1447 auto negCst = rewriter.create<
hw::ConstantOp>(op.getLoc(), -value);
1448 replaceOpWithNewOpAndCopyName<AddOp>(rewriter, op, op.getLhs(), negCst,
/// Constant folder for AddOp. Most of the body is missing from this excerpt.
1460 OpFoldResult AddOp::fold(FoldAdaptor adaptor) {
1464 auto size = getInputs().size();
// Guard not shown; presumably the single-operand case -- confirm.
1468 return getInputs()[0];
1478 auto inputs = op.getInputs();
1479 auto size = inputs.size();
1480 assert(size > 1 &&
"expected 2 or more operands");
1482 APInt value, value2;
1485 if (matchPattern(inputs.back(), m_ConstantInt(&value)) && value.isZero()) {
1486 replaceOpWithNewOpAndCopyName<AddOp>(rewriter, op, op.getType(),
1487 inputs.drop_back(),
false);
1492 if (matchPattern(inputs[size - 1], m_ConstantInt(&value)) &&
1493 matchPattern(inputs[size - 2], m_ConstantInt(&value2))) {
1494 auto cst = rewriter.create<
hw::ConstantOp>(op.getLoc(), value + value2);
1495 SmallVector<Value, 4> newOperands(inputs.drop_back(2));
1496 newOperands.push_back(cst);
1497 replaceOpWithNewOpAndCopyName<AddOp>(rewriter, op, op.getType(),
1498 newOperands,
false);
1503 if (inputs[size - 1] == inputs[size - 2]) {
1504 SmallVector<Value, 4> newOperands(inputs.drop_back(2));
1506 auto one = rewriter.create<
hw::ConstantOp>(op.getLoc(), op.getType(), 1);
1510 newOperands.push_back(shiftLeftOp);
1511 replaceOpWithNewOpAndCopyName<AddOp>(rewriter, op, op.getType(),
1512 newOperands,
false);
1516 auto shlOp = inputs[size - 1].getDefiningOp<
comb::ShlOp>();
1518 if (shlOp && shlOp.getLhs() == inputs[size - 2] &&
1519 matchPattern(shlOp.getRhs(), m_ConstantInt(&value))) {
1521 APInt one(value.getBitWidth(), 1,
false);
1523 rewriter.create<
hw::ConstantOp>(op.getLoc(), (one << value) + one);
1525 std::array<Value, 2> factors = {shlOp.getLhs(), rhs};
1526 auto mulOp = rewriter.create<
comb::MulOp>(op.getLoc(), factors,
false);
1528 SmallVector<Value, 4> newOperands(inputs.drop_back(2));
1529 newOperands.push_back(mulOp);
1530 replaceOpWithNewOpAndCopyName<AddOp>(rewriter, op, op.getType(),
1531 newOperands,
false);
1535 auto mulOp = inputs[size - 1].getDefiningOp<
comb::MulOp>();
1537 if (mulOp && mulOp.getInputs().size() == 2 &&
1538 mulOp.getInputs()[0] == inputs[size - 2] &&
1539 matchPattern(mulOp.getInputs()[1], m_ConstantInt(&value))) {
1541 APInt one(value.getBitWidth(), 1,
false);
1542 auto rhs = rewriter.create<
hw::ConstantOp>(op.getLoc(), value + one);
1543 std::array<Value, 2> factors = {mulOp.getInputs()[0], rhs};
1546 SmallVector<Value, 4> newOperands(inputs.drop_back(2));
1547 newOperands.push_back(newMulOp);
1548 replaceOpWithNewOpAndCopyName<AddOp>(rewriter, op, op.getType(),
1549 newOperands,
false);
1562 auto addOp = inputs[0].getDefiningOp<
comb::AddOp>();
1563 if (addOp && addOp.getInputs().size() == 2 &&
1564 matchPattern(addOp.getInputs()[1], m_ConstantInt(&value2)) &&
1565 inputs.size() == 2 && matchPattern(inputs[1], m_ConstantInt(&value))) {
1567 auto rhs = rewriter.create<
hw::ConstantOp>(op.getLoc(), value + value2);
1568 replaceOpWithNewOpAndCopyName<AddOp>(
1569 rewriter, op, op.getType(), ArrayRef<Value>{addOp.getInputs()[0], rhs},
1570 op.getTwoState() && addOp.getTwoState());
/// Constant folder for MulOp. Several lines are missing from this excerpt.
1577 OpFoldResult MulOp::fold(FoldAdaptor adaptor) {
1581 auto size = getInputs().size();
1582 auto inputs = adaptor.getInputs();
// Guard not shown; presumably the single-operand case -- confirm.
1586 return getInputs()[0];
// Accumulator for the product of constant operands, seeded with the value 1
// at the result width.
1588 auto width = cast<IntegerType>(getType()).getWidth();
1589 APInt value(width, 1,
false);
// Multiply each operand attribute into the accumulator (any guard for
// non-constant operands is on lines not shown).
1592 for (
auto operand : inputs) {
1595 value *= cast<IntegerAttr>(operand).getValue();
1608 auto inputs = op.getInputs();
1609 auto size = inputs.size();
1610 assert(size > 1 &&
"expected 2 or more operands");
1612 APInt value, value2;
1615 if (size == 2 && matchPattern(inputs.back(), m_ConstantInt(&value)) &&
1616 value.isPowerOf2()) {
1617 auto shift = rewriter.create<
hw::ConstantOp>(op.getLoc(), op.getType(),
1618 value.exactLogBase2());
1622 replaceOpWithNewOpAndCopyName<MulOp>(rewriter, op, op.getType(),
1623 ArrayRef<Value>(shlOp),
false);
1628 if (matchPattern(inputs.back(), m_ConstantInt(&value)) && value.isOne()) {
1629 replaceOpWithNewOpAndCopyName<MulOp>(rewriter, op, op.getType(),
1630 inputs.drop_back());
1635 if (matchPattern(inputs[size - 1], m_ConstantInt(&value)) &&
1636 matchPattern(inputs[size - 2], m_ConstantInt(&value2))) {
1637 auto cst = rewriter.create<
hw::ConstantOp>(op.getLoc(), value * value2);
1638 SmallVector<Value, 4> newOperands(inputs.drop_back(2));
1639 newOperands.push_back(cst);
1640 replaceOpWithNewOpAndCopyName<MulOp>(rewriter, op, op.getType(),
/// Shared folding helper for division ops, parameterized on the op type and
/// on signedness. `constants` holds the folded operand attributes; index 1 is
/// inspected as the divisor. The values returned are on lines not shown.
1656 template <
class Op,
bool isSigned>
1657 static OpFoldResult
foldDiv(Op op, ArrayRef<Attribute> constants) {
// Special cases that depend only on a constant divisor.
1658 if (
auto rhsValue = dyn_cast_or_null<IntegerAttr>(constants[1])) {
// Divisor equals one.
1660 if (rhsValue.getValue() == 1)
// Divisor equals zero.
1664 if (rhsValue.getValue().isZero())
/// Constant folder for DivUOp; the visible statement delegates to foldDiv
/// with isSigned = false.
1671 OpFoldResult DivUOp::fold(FoldAdaptor adaptor) {
1675 return foldDiv<
DivUOp,
false>(*
this, adaptor.getOperands());
1678 OpFoldResult DivSOp::fold(FoldAdaptor adaptor) {
/// Shared folding helper for modulo ops, parameterized on the op type and on
/// signedness. `constants` holds the folded operand attributes; index 0 is
/// the dividend and index 1 the divisor. Some lines are not shown.
1685 template <
class Op,
bool isSigned>
1686 static OpFoldResult
foldMod(Op op, ArrayRef<Attribute> constants) {
// Special cases that depend only on a constant divisor.
1687 if (
auto rhsValue = dyn_cast_or_null<IntegerAttr>(constants[1])) {
// x mod 1 folds to zero of the result width.
1689 if (rhsValue.getValue() == 1)
1690 return getIntAttr(APInt::getZero(op.getType().getIntOrFloatBitWidth()),
// Divisor equals zero (the handling is on a line not shown).
1694 if (rhsValue.getValue().isZero())
// Special case that depends only on a constant dividend: 0 mod x folds to
// zero of the result width.
1698 if (
auto lhsValue = dyn_cast_or_null<IntegerAttr>(constants[0])) {
1700 if (lhsValue.getValue().isZero())
1701 return getIntAttr(APInt::getZero(op.getType().getIntOrFloatBitWidth()),
/// Constant folder for ModUOp; the visible statement delegates to foldMod
/// with isSigned = false.
1708 OpFoldResult ModUOp::fold(FoldAdaptor adaptor) {
1712 return foldMod<
ModUOp,
false>(*
this, adaptor.getOperands());
1715 OpFoldResult ModSOp::fold(FoldAdaptor adaptor) {
/// Constant folder for ConcatOp. Some lines are missing from this excerpt.
1726 OpFoldResult ConcatOp::fold(FoldAdaptor adaptor) {
// A concat of a single operand is that operand.
1730 if (getNumOperands() == 1)
1731 return getOperand(0);
// Check every operand attribute for being a non-null IntegerAttr (the early
// exit taken otherwise is on a line not shown).
1734 for (
auto attr : adaptor.getInputs())
1735 if (!attr || !isa<IntegerAttr>(attr))
// Assemble the constant result: start from zero at the full result width and
// insert each operand's bits, working downward from the most significant end
// so that the first operand lands in the highest bits.
1739 unsigned resultWidth = getType().getIntOrFloatBitWidth();
1740 APInt result(resultWidth, 0);
1742 unsigned nextInsertion = resultWidth;
1744 for (
auto attr : adaptor.getInputs()) {
1745 auto chunk = cast<IntegerAttr>(attr).getValue();
1746 nextInsertion -= chunk.getBitWidth();
1747 result.insertBits(chunk, nextInsertion);
1757 auto inputs = op.getInputs();
1758 auto size = inputs.size();
1759 assert(size > 1 &&
"expected 2 or more operands");
1764 auto flattenConcat = [&](
size_t firstOpIndex,
size_t lastOpIndex,
1765 ValueRange replacements) -> LogicalResult {
1766 SmallVector<Value, 4> newOperands;
1767 newOperands.append(inputs.begin(), inputs.begin() + firstOpIndex);
1768 newOperands.append(replacements.begin(), replacements.end());
1769 newOperands.append(inputs.begin() + lastOpIndex + 1, inputs.end());
1770 if (newOperands.size() == 1)
1773 replaceOpWithNewOpAndCopyName<ConcatOp>(rewriter, op, op.getType(),
1778 Value commonOperand = inputs[0];
1779 for (
size_t i = 0; i != size; ++i) {
1781 if (inputs[i] != commonOperand)
1782 commonOperand = Value();
1786 if (
auto subConcat = inputs[i].getDefiningOp<ConcatOp>())
1787 return flattenConcat(i, i, subConcat->getOperands());
1792 if (
auto cst = inputs[i].getDefiningOp<hw::ConstantOp>()) {
1793 if (
auto prevCst = inputs[i - 1].getDefiningOp<hw::ConstantOp>()) {
1794 unsigned prevWidth = prevCst.getValue().getBitWidth();
1795 unsigned thisWidth = cst.getValue().getBitWidth();
1796 auto resultCst = cst.getValue().zext(prevWidth + thisWidth);
1797 resultCst |= prevCst.getValue().zext(prevWidth + thisWidth)
1801 return flattenConcat(i - 1, i, replacement);
1806 if (inputs[i] == inputs[i - 1]) {
1808 rewriter.createOrFold<ReplicateOp>(op.getLoc(), inputs[i], 2);
1809 return flattenConcat(i - 1, i, replacement);
1814 if (
auto repl = inputs[i].getDefiningOp<ReplicateOp>()) {
1816 if (repl.getOperand() == inputs[i - 1]) {
1817 Value replacement = rewriter.createOrFold<ReplicateOp>(
1818 op.getLoc(), repl.getOperand(), repl.getMultiple() + 1);
1819 return flattenConcat(i - 1, i, replacement);
1822 if (
auto prevRepl = inputs[i - 1].getDefiningOp<ReplicateOp>()) {
1823 if (prevRepl.getOperand() == repl.getOperand()) {
1824 Value replacement = rewriter.createOrFold<ReplicateOp>(
1825 op.getLoc(), repl.getOperand(),
1826 repl.getMultiple() + prevRepl.getMultiple());
1827 return flattenConcat(i - 1, i, replacement);
1833 if (
auto repl = inputs[i - 1].getDefiningOp<ReplicateOp>()) {
1834 if (repl.getOperand() == inputs[i]) {
1835 Value replacement = rewriter.createOrFold<ReplicateOp>(
1836 op.getLoc(), inputs[i], repl.getMultiple() + 1);
1837 return flattenConcat(i - 1, i, replacement);
1843 if (
auto extract = inputs[i].getDefiningOp<ExtractOp>()) {
1844 if (
auto prevExtract = inputs[i - 1].getDefiningOp<ExtractOp>()) {
1845 if (extract.getInput() == prevExtract.getInput()) {
1846 auto thisWidth = cast<IntegerType>(extract.getType()).getWidth();
1847 if (prevExtract.getLowBit() == extract.getLowBit() + thisWidth) {
1848 auto prevWidth = prevExtract.getType().getIntOrFloatBitWidth();
1849 auto resType = rewriter.getIntegerType(thisWidth + prevWidth);
1850 Value replacement = rewriter.create<
ExtractOp>(
1851 op.getLoc(), resType, extract.getInput(),
1852 extract.getLowBit());
1853 return flattenConcat(i - 1, i, replacement);
1866 static std::optional<ArraySlice>
get(Value value) {
1867 assert(isa<IntegerType>(value.getType()) &&
"expected integer type");
1869 return ArraySlice{arrayGet.getInput(), arrayGet.getIndex(), 1};
1872 if (
auto arraySlice =
1875 arraySlice.getInput(), arraySlice.getLowIndex(),
1876 hw::type_cast<hw::ArrayType>(arraySlice.getType())
1878 return std::nullopt;
1884 if (prevExtractOpt->index.getType() == extractOpt->index.getType() &&
1885 prevExtractOpt->input == extractOpt->input &&
1887 extractOpt->width)) {
1889 hw::type_cast<hw::ArrayType>(prevExtractOpt->input.getType())
1891 extractOpt->width + prevExtractOpt->width);
1894 op.getLoc(), resIntType,
1896 prevExtractOpt->input,
1897 extractOpt->index));
1898 return flattenConcat(i - 1, i, replacement);
1906 if (commonOperand) {
1907 replaceOpWithNewOpAndCopyName<ReplicateOp>(rewriter, op, op.getType(),
1919 OpFoldResult MuxOp::fold(FoldAdaptor adaptor) {
1924 if (getTrueValue() == getFalseValue())
1925 return getTrueValue();
1926 if (
auto tv = adaptor.getTrueValue())
1927 if (tv == adaptor.getFalseValue())
1932 if (
auto pred = dyn_cast_or_null<IntegerAttr>(adaptor.getCond())) {
1933 if (pred.getValue().isZero())
1934 return getFalseValue();
1935 return getTrueValue();
1939 if (
auto tv = dyn_cast_or_null<IntegerAttr>(adaptor.getTrueValue()))
1940 if (
auto fv = dyn_cast_or_null<IntegerAttr>(adaptor.getFalseValue()))
1941 if (tv.getValue().isOne() && fv.getValue().isZero() &&
1958 if (
auto cmp = cond.getDefiningOp<ICmpOp>()) {
1960 auto requiredPredicate =
1961 (isInverted ? ICmpPredicate::eq : ICmpPredicate::ne);
1962 if (cmp.getLhs() == indexValue && cmp.getPredicate() == requiredPredicate) {
1972 if (
auto orOp = cond.getDefiningOp<
OrOp>()) {
1975 for (
auto operand : orOp.getOperands())
1982 if (
auto andOp = cond.getDefiningOp<
AndOp>()) {
1985 for (
auto operand : andOp.getOperands())
2003 PatternRewriter &rewriter) {
2006 auto rootCmp = rootMux.getCond().getDefiningOp<ICmpOp>();
2009 Value indexValue = rootCmp.getLhs();
2012 auto getCaseValue = [&](
MuxOp mux) -> Value {
2013 return mux.getOperand(1 +
unsigned(!isFalseSide));
2018 auto getTreeValue = [&](
MuxOp mux) -> Value {
2019 return mux.getOperand(1 +
unsigned(isFalseSide));
2024 SmallVector<Location> locationsFound;
2025 SmallVector<std::pair<hw::ConstantOp, Value>, 4> valuesFound;
2029 auto collectConstantValues = [&](
MuxOp mux) ->
bool {
2031 mux.getCond(), indexValue, isFalseSide, [&](
hw::ConstantOp cst) {
2032 valuesFound.push_back({cst, getCaseValue(mux)});
2033 locationsFound.push_back(mux.getCond().getLoc());
2034 locationsFound.push_back(mux->getLoc());
2039 if (!collectConstantValues(rootMux))
2043 if (rootMux->hasOneUse()) {
2044 if (
auto userMux = dyn_cast<MuxOp>(*rootMux->user_begin())) {
2045 if (getTreeValue(userMux) == rootMux.getResult() &&
2053 auto nextTreeValue = getTreeValue(rootMux);
2055 auto nextMux = nextTreeValue.getDefiningOp<
MuxOp>();
2056 if (!nextMux || !nextMux->hasOneUse())
2058 if (!collectConstantValues(nextMux))
2060 nextTreeValue = getTreeValue(nextMux);
2066 if (valuesFound.size() < 3)
2071 auto indexWidth = cast<IntegerType>(indexValue.getType()).getWidth();
2072 if (indexWidth >= 9)
2078 uint64_t tableSize = 1ULL << indexWidth;
2079 if (valuesFound.size() < (tableSize * 5) / 8)
2084 SmallVector<Value, 8> table(tableSize, nextTreeValue);
2089 for (
auto &elt : llvm::reverse(valuesFound)) {
2090 uint64_t idx = elt.first.getValue().getZExtValue();
2091 assert(idx < table.size() &&
"constant should be same bitwidth as index");
2092 table[idx] = elt.second;
2097 std::reverse(table.begin(), table.end());
2100 auto fusedLoc = rewriter.getFusedLoc(locationsFound);
2102 replaceOpWithNewOpAndCopyName<hw::ArrayGetOp>(rewriter, rootMux, array,
2117 PatternRewriter &rewriter) {
2118 assert(fullyAssoc->getNumOperands() >= 2 &&
"cannot split up unary ops");
2119 assert(operandNo < fullyAssoc->getNumOperands() &&
"Invalid operand #");
2123 if (fullyAssoc->getNumOperands() == 2)
2124 return fullyAssoc->getOperand(operandNo ^ 1);
2127 if (fullyAssoc->hasOneUse()) {
2128 rewriter.modifyOpInPlace(fullyAssoc,
2129 [&]() { fullyAssoc->eraseOperand(operandNo); });
2130 return fullyAssoc->getResult(0);
2134 SmallVector<Value> operands;
2135 operands.append(fullyAssoc->getOperands().begin(),
2136 fullyAssoc->getOperands().begin() + operandNo);
2137 operands.append(fullyAssoc->getOperands().begin() + operandNo + 1,
2138 fullyAssoc->getOperands().end());
2140 fullyAssoc->getLoc(), fullyAssoc->getName(), operands, rewriter);
2141 Value excluded = fullyAssoc->getOperand(operandNo);
2145 ArrayRef<Value>{opWithoutExcluded, excluded}, rewriter);
2147 return opWithoutExcluded;
2157 PatternRewriter &rewriter) {
2160 Operation *subExpr =
2161 (isTrueOperand ? op.getFalseValue() : op.getTrueValue()).getDefiningOp();
2162 if (!subExpr || subExpr->getNumOperands() < 2)
2166 if (!isa<AndOp, XorOp, OrOp, MuxOp>(subExpr))
2171 Value commonValue = isTrueOperand ? op.getTrueValue() : op.getFalseValue();
2172 size_t opNo = 0, e = subExpr->getNumOperands();
2173 while (opNo != e && subExpr->getOperand(opNo) != commonValue)
2179 Value cond = op.getCond();
2185 if (
auto subMux = dyn_cast<MuxOp>(subExpr)) {
2187 Value subCond = subMux.getCond();
2190 if (subMux.getTrueValue() == commonValue)
2191 otherValue = subMux.getFalseValue();
2192 else if (subMux.getFalseValue() == commonValue) {
2193 otherValue = subMux.getTrueValue();
2203 cond = rewriter.createOrFold<
OrOp>(op.getLoc(), cond, subCond,
false);
2204 replaceOpWithNewOpAndCopyName<MuxOp>(rewriter, op, cond, commonValue,
2205 otherValue, op.getTwoState());
2211 bool isaAndOp = isa<AndOp>(subExpr);
2212 if (isTrueOperand ^ isaAndOp)
2216 rewriter.createOrFold<ReplicateOp>(op.getLoc(), op.getType(), cond);
2219 bool isaXorOp = isa<XorOp>(subExpr);
2220 bool isaOrOp = isa<OrOp>(subExpr);
2229 if (isaOrOp || isaXorOp) {
2230 auto masked = rewriter.createOrFold<
AndOp>(op.getLoc(), extendedCond,
2231 restOfAssoc,
false);
2233 replaceOpWithNewOpAndCopyName<XorOp>(rewriter, op, masked, commonValue,
2236 replaceOpWithNewOpAndCopyName<OrOp>(rewriter, op, masked, commonValue,
2242 assert(isaAndOp &&
"unexpected operation here");
2243 auto masked = rewriter.createOrFold<
OrOp>(op.getLoc(), extendedCond,
2244 restOfAssoc,
false);
2245 replaceOpWithNewOpAndCopyName<AndOp>(rewriter, op, masked, commonValue,
2256 PatternRewriter &rewriter) {
2259 if (!isa<ConcatOp>(trueOp))
2263 SmallVector<Value> trueOperands, falseOperands;
2267 size_t numTrueOperands = trueOperands.size();
2268 size_t numFalseOperands = falseOperands.size();
2270 if (!numTrueOperands || !numFalseOperands ||
2271 (trueOperands.front() != falseOperands.front() &&
2272 trueOperands.back() != falseOperands.back()))
2276 if (trueOperands.front() == falseOperands.front()) {
2277 SmallVector<Value> operands;
2279 for (i = 0; i < numTrueOperands; ++i) {
2280 Value trueOperand = trueOperands[i];
2281 if (trueOperand == falseOperands[i])
2282 operands.push_back(trueOperand);
2286 if (i == numTrueOperands) {
2293 if (llvm::all_of(operands, [&](Value v) {
return v == operands.front(); }))
2294 sharedMSB = rewriter.createOrFold<ReplicateOp>(
2295 mux->getLoc(), operands.front(), operands.size());
2297 sharedMSB = rewriter.createOrFold<
ConcatOp>(mux->getLoc(), operands);
2301 operands.append(trueOperands.begin() + i, trueOperands.end());
2302 Value trueLSB = rewriter.createOrFold<
ConcatOp>(trueOp->getLoc(), operands);
2304 operands.append(falseOperands.begin() + i, falseOperands.end());
2306 rewriter.createOrFold<
ConcatOp>(falseOp->getLoc(), operands);
2309 Value lsb = rewriter.createOrFold<
MuxOp>(
2310 mux->getLoc(), mux.getCond(), trueLSB, falseLSB, mux.getTwoState());
2311 replaceOpWithNewOpAndCopyName<ConcatOp>(rewriter, mux, sharedMSB, lsb);
2316 if (trueOperands.back() == falseOperands.back()) {
2317 SmallVector<Value> operands;
2320 Value trueOperand = trueOperands[numTrueOperands - i - 1];
2321 if (trueOperand == falseOperands[numFalseOperands - i - 1])
2322 operands.push_back(trueOperand);
2326 std::reverse(operands.begin(), operands.end());
2327 Value sharedLSB = rewriter.createOrFold<
ConcatOp>(mux->getLoc(), operands);
2331 operands.append(trueOperands.begin(), trueOperands.end() - i);
2332 Value trueMSB = rewriter.createOrFold<
ConcatOp>(trueOp->getLoc(), operands);
2334 operands.append(falseOperands.begin(), falseOperands.end() - i);
2336 rewriter.createOrFold<
ConcatOp>(falseOp->getLoc(), operands);
2338 Value msb = rewriter.createOrFold<
MuxOp>(
2339 mux->getLoc(), mux.getCond(), trueMSB, falseMSB, mux.getTwoState());
2340 replaceOpWithNewOpAndCopyName<ConcatOp>(rewriter, mux, msb, sharedLSB);
2352 if (!trueVec || !falseVec)
2354 if (!trueVec.isUniform() || !falseVec.isUniform())
2358 op.getLoc(), op.getCond(), trueVec.getUniformElement(),
2359 falseVec.getUniformElement(), op.getTwoState());
2361 SmallVector<Value> values(trueVec.getInputs().size(), mux);
2368 using OpRewritePattern::OpRewritePattern;
2370 LogicalResult matchAndRewrite(
MuxOp op,
2371 PatternRewriter &rewriter)
const override;
2374 LogicalResult MuxRewriter::matchAndRewrite(
MuxOp op,
2375 PatternRewriter &rewriter)
const {
2384 if (matchPattern(op.getTrueValue(), m_ConstantInt(&value))) {
2385 if (value.getBitWidth() == 1) {
2387 if (value.isZero()) {
2389 replaceOpWithNewOpAndCopyName<AndOp>(rewriter, op, notCond,
2390 op.getFalseValue(),
false);
2395 replaceOpWithNewOpAndCopyName<OrOp>(rewriter, op, op.getCond(),
2396 op.getFalseValue(),
false);
2402 if (matchPattern(op.getFalseValue(), m_ConstantInt(&value2))) {
2407 APInt xorValue = value ^ value2;
2408 if (xorValue.isPowerOf2()) {
2409 unsigned leadingZeros = xorValue.countLeadingZeros();
2410 unsigned trailingZeros = value.getBitWidth() - leadingZeros - 1;
2411 SmallVector<Value, 3> operands;
2419 if (leadingZeros > 0)
2420 operands.push_back(rewriter.createOrFold<
ExtractOp>(
2421 op.getLoc(), op.getTrueValue(), trailingZeros + 1, leadingZeros));
2425 auto v1 = rewriter.createOrFold<
ExtractOp>(
2426 op.getLoc(), op.getTrueValue(), trailingZeros, 1);
2427 auto v2 = rewriter.createOrFold<
ExtractOp>(
2428 op.getLoc(), op.getFalseValue(), trailingZeros, 1);
2429 operands.push_back(rewriter.createOrFold<
MuxOp>(
2430 op.getLoc(), op.getCond(), v1, v2,
false));
2432 if (trailingZeros > 0)
2433 operands.push_back(rewriter.createOrFold<
ExtractOp>(
2434 op.getLoc(), op.getTrueValue(), 0, trailingZeros));
2436 replaceOpWithNewOpAndCopyName<ConcatOp>(rewriter, op, op.getType(),
2443 if (value.isAllOnes() && value2.isZero()) {
2444 replaceOpWithNewOpAndCopyName<ReplicateOp>(rewriter, op, op.getType(),
2451 if (matchPattern(op.getFalseValue(), m_ConstantInt(&value)) &&
2452 value.getBitWidth() == 1) {
2454 if (value.isZero()) {
2455 replaceOpWithNewOpAndCopyName<AndOp>(rewriter, op, op.getCond(),
2456 op.getTrueValue(),
false);
2463 auto notCond = rewriter.createOrFold<
XorOp>(op.getLoc(), op.getCond(),
2464 op.getFalseValue(),
false);
2465 replaceOpWithNewOpAndCopyName<OrOp>(rewriter, op, notCond,
2466 op.getTrueValue(),
false);
2472 Operation *condOp = op.getCond().getDefiningOp();
2473 if (condOp && matchPattern(condOp,
m_Complement(m_Any(&subExpr))) &&
2475 replaceOpWithNewOpAndCopyName<MuxOp>(rewriter, op, op.getType(), subExpr,
2476 op.getFalseValue(), op.getTrueValue(),
2484 if (condOp && condOp->hasOneUse()) {
2485 SmallVector<Value> invertedOperands;
2489 auto getInvertedOperands = [&]() ->
bool {
2490 for (Value operand : condOp->getOperands()) {
2491 if (matchPattern(operand,
m_Complement(m_Any(&subExpr))))
2492 invertedOperands.push_back(subExpr);
2499 if (isa<AndOp>(condOp) && getInvertedOperands()) {
2501 rewriter.createOrFold<
OrOp>(op.getLoc(), invertedOperands,
false);
2502 replaceOpWithNewOpAndCopyName<MuxOp>(rewriter, op, newOr,
2504 op.getTrueValue(), op.getTwoState());
2507 if (isa<OrOp>(condOp) && getInvertedOperands()) {
2509 rewriter.createOrFold<
AndOp>(op.getLoc(), invertedOperands,
false);
2510 replaceOpWithNewOpAndCopyName<MuxOp>(rewriter, op, newAnd,
2512 op.getTrueValue(), op.getTwoState());
2518 dyn_cast_or_null<MuxOp>(op.getFalseValue().getDefiningOp())) {
2520 if (op.getCond() == falseMux.getCond()) {
2521 replaceOpWithNewOpAndCopyName<MuxOp>(
2522 rewriter, op, op.getCond(), op.getTrueValue(),
2523 falseMux.getFalseValue(), op.getTwoStateAttr());
2533 dyn_cast_or_null<MuxOp>(op.getTrueValue().getDefiningOp())) {
2535 if (op.getCond() == trueMux.getCond()) {
2536 replaceOpWithNewOpAndCopyName<MuxOp>(
2537 rewriter, op, op.getCond(), trueMux.getTrueValue(),
2538 op.getFalseValue(), op.getTwoStateAttr());
2548 if (
auto trueMux = dyn_cast_or_null<MuxOp>(op.getTrueValue().getDefiningOp()),
2549 falseMux = dyn_cast_or_null<MuxOp>(op.getFalseValue().getDefiningOp());
2550 trueMux && falseMux && trueMux.getCond() == falseMux.getCond() &&
2551 trueMux.getTrueValue() == falseMux.getTrueValue()) {
2552 auto subMux = rewriter.create<
MuxOp>(
2553 rewriter.getFusedLoc({trueMux.getLoc(), falseMux.getLoc()}),
2554 op.getCond(), trueMux.getFalseValue(), falseMux.getFalseValue());
2555 replaceOpWithNewOpAndCopyName<MuxOp>(rewriter, op, trueMux.getCond(),
2556 trueMux.getTrueValue(), subMux,
2557 op.getTwoStateAttr());
2562 if (
auto trueMux = dyn_cast_or_null<MuxOp>(op.getTrueValue().getDefiningOp()),
2563 falseMux = dyn_cast_or_null<MuxOp>(op.getFalseValue().getDefiningOp());
2564 trueMux && falseMux && trueMux.getCond() == falseMux.getCond() &&
2565 trueMux.getFalseValue() == falseMux.getFalseValue()) {
2566 auto subMux = rewriter.create<
MuxOp>(
2567 rewriter.getFusedLoc({trueMux.getLoc(), falseMux.getLoc()}),
2568 op.getCond(), trueMux.getTrueValue(), falseMux.getTrueValue());
2569 replaceOpWithNewOpAndCopyName<MuxOp>(rewriter, op, trueMux.getCond(),
2570 subMux, trueMux.getFalseValue(),
2571 op.getTwoStateAttr());
2576 if (
auto trueMux = dyn_cast_or_null<MuxOp>(op.getTrueValue().getDefiningOp()),
2577 falseMux = dyn_cast_or_null<MuxOp>(op.getFalseValue().getDefiningOp());
2578 trueMux && falseMux &&
2579 trueMux.getTrueValue() == falseMux.getTrueValue() &&
2580 trueMux.getFalseValue() == falseMux.getFalseValue()) {
2581 auto subMux = rewriter.create<
MuxOp>(
2582 rewriter.getFusedLoc(
2583 {op.getLoc(), trueMux.getLoc(), falseMux.getLoc()}),
2584 op.getCond(), trueMux.getCond(), falseMux.getCond());
2585 replaceOpWithNewOpAndCopyName<MuxOp>(
2586 rewriter, op, subMux, trueMux.getTrueValue(), trueMux.getFalseValue(),
2587 op.getTwoStateAttr());
2599 if (Operation *trueOp = op.getTrueValue().getDefiningOp())
2600 if (Operation *falseOp = op.getFalseValue().getDefiningOp())
2601 if (trueOp->getName() == falseOp->getName())
2618 if (op.getInputs().empty() || op.isUniform())
2620 auto inputs = op.getInputs();
2621 if (inputs.size() <= 1)
2626 auto first = inputs[0].getDefiningOp<
comb::MuxOp>();
2631 for (
size_t i = 1, n = inputs.size(); i < n; ++i) {
2632 auto input = inputs[i].getDefiningOp<
comb::MuxOp>();
2633 if (!input || first.getCond() != input.getCond())
2638 SmallVector<Value> trues{first.getTrueValue()};
2639 SmallVector<Value> falses{first.getFalseValue()};
2640 SmallVector<Location> locs{first->getLoc()};
2641 bool isTwoState =
true;
2642 for (
size_t i = 1, n = inputs.size(); i < n; ++i) {
2643 auto input = inputs[i].getDefiningOp<
comb::MuxOp>();
2644 trues.push_back(input.getTrueValue());
2645 falses.push_back(input.getFalseValue());
2646 locs.push_back(input->getLoc());
2647 if (!input.getTwoState())
2656 auto arrayTy = op.getType();
2659 rewriter.replaceOpWithNewOp<
comb::MuxOp>(op, arrayTy, first.getCond(),
2660 trueValues, falseValues, isTwoState);
2665 using OpRewritePattern::OpRewritePattern;
2668 PatternRewriter &rewriter)
const override {
2672 if (foldArrayOfMuxes(op, rewriter))
2680 void MuxOp::getCanonicalizationPatterns(RewritePatternSet &results,
2681 MLIRContext *context) {
2682 results.insert<MuxRewriter, ArrayRewriter>(context);
2693 switch (predicate) {
2694 case ICmpPredicate::eq:
2696 case ICmpPredicate::ne:
2698 case ICmpPredicate::slt:
2699 return lhs.slt(rhs);
2700 case ICmpPredicate::sle:
2701 return lhs.sle(rhs);
2702 case ICmpPredicate::sgt:
2703 return lhs.sgt(rhs);
2704 case ICmpPredicate::sge:
2705 return lhs.sge(rhs);
2706 case ICmpPredicate::ult:
2707 return lhs.ult(rhs);
2708 case ICmpPredicate::ule:
2709 return lhs.ule(rhs);
2710 case ICmpPredicate::ugt:
2711 return lhs.ugt(rhs);
2712 case ICmpPredicate::uge:
2713 return lhs.uge(rhs);
2714 case ICmpPredicate::ceq:
2716 case ICmpPredicate::cne:
2718 case ICmpPredicate::weq:
2720 case ICmpPredicate::wne:
2723 llvm_unreachable(
"unknown comparison predicate");
2729 switch (predicate) {
2730 case ICmpPredicate::eq:
2731 case ICmpPredicate::sle:
2732 case ICmpPredicate::sge:
2733 case ICmpPredicate::ule:
2734 case ICmpPredicate::uge:
2735 case ICmpPredicate::ceq:
2736 case ICmpPredicate::weq:
2738 case ICmpPredicate::ne:
2739 case ICmpPredicate::slt:
2740 case ICmpPredicate::sgt:
2741 case ICmpPredicate::ult:
2742 case ICmpPredicate::ugt:
2743 case ICmpPredicate::cne:
2744 case ICmpPredicate::wne:
2747 llvm_unreachable(
"unknown comparison predicate");
2750 OpFoldResult ICmpOp::fold(FoldAdaptor adaptor) {
2756 if (getLhs() == getRhs()) {
2762 if (
auto lhs = dyn_cast_or_null<IntegerAttr>(adaptor.getLhs())) {
2763 if (
auto rhs = dyn_cast_or_null<IntegerAttr>(adaptor.getRhs())) {
2774 template <
typename Range>
2776 size_t commonPrefixLength = 0;
2777 auto ia = a.begin();
2778 auto ib = b.begin();
2780 for (; ia != a.end() && ib != b.end(); ia++, ib++, commonPrefixLength++) {
2786 return commonPrefixLength;
2790 size_t totalWidth = 0;
2791 for (
auto operand : operands) {
2794 ssize_t width = operand.getType().getIntOrFloatBitWidth();
2796 totalWidth += width;
2806 PatternRewriter &rewriter) {
2810 SmallVector<Value> lhsOperands, rhsOperands;
2813 ArrayRef<Value> lhsOperandsRef = lhsOperands, rhsOperandsRef = rhsOperands;
2815 auto formCatOrReplicate = [&](Location loc,
2816 ArrayRef<Value> operands) -> Value {
2817 assert(!operands.empty());
2818 Value sameElement = operands[0];
2819 for (
size_t i = 1, e = operands.size(); i != e && sameElement; ++i)
2820 if (sameElement != operands[i])
2821 sameElement = Value();
2823 return rewriter.createOrFold<ReplicateOp>(loc, sameElement,
2825 return rewriter.createOrFold<
ConcatOp>(loc, operands);
2828 auto replaceWith = [&](ICmpPredicate predicate, Value lhs,
2829 Value rhs) -> LogicalResult {
2830 replaceOpWithNewOpAndCopyName<ICmpOp>(rewriter, op, predicate, lhs, rhs,
2835 size_t commonPrefixLength =
2837 if (commonPrefixLength == lhsOperands.size()) {
2840 replaceOpWithNewOpAndCopyName<hw::ConstantOp>(rewriter, op,
2846 llvm::reverse(lhsOperandsRef), llvm::reverse(rhsOperandsRef));
2848 size_t commonPrefixTotalWidth =
2849 getTotalWidth(lhsOperandsRef.take_front(commonPrefixLength));
2850 size_t commonSuffixTotalWidth =
2851 getTotalWidth(lhsOperandsRef.take_back(commonSuffixLength));
2852 auto lhsOnly = lhsOperandsRef.drop_front(commonPrefixLength)
2853 .drop_back(commonSuffixLength);
2854 auto rhsOnly = rhsOperandsRef.drop_front(commonPrefixLength)
2855 .drop_back(commonSuffixLength);
2857 auto replaceWithoutReplicatingSignBit = [&]() {
2858 auto newLhs = formCatOrReplicate(lhs->getLoc(), lhsOnly);
2859 auto newRhs = formCatOrReplicate(rhs->getLoc(), rhsOnly);
2860 return replaceWith(op.getPredicate(), newLhs, newRhs);
2863 auto replaceWithReplicatingSignBit = [&]() {
2864 auto firstNonEmptyValue = lhsOperands[0];
2865 auto firstNonEmptyElemWidth =
2866 firstNonEmptyValue.getType().getIntOrFloatBitWidth();
2867 Value signBit = rewriter.createOrFold<
ExtractOp>(
2868 op.getLoc(), firstNonEmptyValue, firstNonEmptyElemWidth - 1, 1);
2870 auto newLhs = rewriter.
create<
ConcatOp>(lhs->getLoc(), signBit, lhsOnly);
2871 auto newRhs = rewriter.create<
ConcatOp>(rhs->getLoc(), signBit, rhsOnly);
2872 return replaceWith(op.getPredicate(), newLhs, newRhs);
2875 if (ICmpOp::isPredicateSigned(op.getPredicate())) {
2877 if (commonPrefixTotalWidth == 0 && commonSuffixTotalWidth > 0)
2878 return replaceWithoutReplicatingSignBit();
2884 if (commonPrefixTotalWidth > 1 || commonSuffixTotalWidth > 0)
2885 return replaceWithReplicatingSignBit();
2887 }
else if (commonPrefixTotalWidth > 0 || commonSuffixTotalWidth > 0) {
2889 return replaceWithoutReplicatingSignBit();
2903 ICmpOp cmpOp,
const KnownBits &bitAnalysis,
const APInt &rhsCst,
2904 PatternRewriter &rewriter) {
2908 APInt bitsKnown = bitAnalysis.Zero | bitAnalysis.One;
2909 if ((bitsKnown & rhsCst) != bitAnalysis.One) {
2912 bool result = cmpOp.getPredicate() == ICmpPredicate::ne;
2913 replaceOpWithNewOpAndCopyName<hw::ConstantOp>(rewriter, cmpOp,
2921 SmallVector<Value> newConcatOperands;
2922 auto newConstant = APInt::getZeroWidth();
2927 unsigned knownMSB = bitsKnown.countLeadingOnes();
2929 Value operand = cmpOp.getLhs();
2934 while (knownMSB != bitsKnown.getBitWidth()) {
2937 bitsKnown = bitsKnown.trunc(bitsKnown.getBitWidth() - knownMSB);
2940 unsigned unknownBits = bitsKnown.countLeadingZeros();
2941 unsigned lowBit = bitsKnown.getBitWidth() - unknownBits;
2942 auto spanOperand = rewriter.createOrFold<
ExtractOp>(
2943 operand.getLoc(), operand, lowBit,
2945 auto spanConstant = rhsCst.lshr(lowBit).trunc(unknownBits);
2948 newConcatOperands.push_back(spanOperand);
2951 if (newConstant.getBitWidth() != 0)
2952 newConstant = newConstant.concat(spanConstant);
2954 newConstant = spanConstant;
2957 unsigned newWidth = bitsKnown.getBitWidth() - unknownBits;
2958 bitsKnown = bitsKnown.trunc(newWidth);
2959 knownMSB = bitsKnown.countLeadingOnes();
2965 if (newConcatOperands.empty()) {
2966 bool result = cmpOp.getPredicate() == ICmpPredicate::eq;
2967 replaceOpWithNewOpAndCopyName<hw::ConstantOp>(rewriter, cmpOp,
2973 Value concatResult =
2974 rewriter.createOrFold<
ConcatOp>(operand.getLoc(), newConcatOperands);
2978 cmpOp.getOperand(1).getLoc(), newConstant);
2980 replaceOpWithNewOpAndCopyName<ICmpOp>(rewriter, cmpOp, cmpOp.getPredicate(),
2981 concatResult, newConstantOp,
2982 cmpOp.getTwoState());
2988 PatternRewriter &rewriter) {
2989 auto ip = rewriter.saveInsertionPoint();
2990 rewriter.setInsertionPoint(xorOp);
2992 auto xorRHS = xorOp.getOperands().back().getDefiningOp<
hw::ConstantOp>();
2994 xorRHS.getValue() ^ rhs);
2996 switch (xorOp.getNumOperands()) {
3000 APInt::getZero(rhs.getBitWidth()));
3004 newLHS = xorOp.getOperand(0);
3008 SmallVector<Value> newOperands(xorOp.getOperands());
3009 newOperands.pop_back();
3010 newLHS = rewriter.create<
XorOp>(xorOp.getLoc(), newOperands,
false);
3014 bool xorMultipleUses = !xorOp->hasOneUse();
3018 if (xorMultipleUses)
3019 replaceOpWithNewOpAndCopyName<XorOp>(rewriter, xorOp, newLHS, xorRHS,
3023 rewriter.restoreInsertionPoint(ip);
3024 replaceOpWithNewOpAndCopyName<ICmpOp>(rewriter, cmpOp, cmpOp.getPredicate(),
3025 newLHS, newRHS,
false);
3035 if (matchPattern(op.getLhs(), m_ConstantInt(&lhs))) {
3036 assert(!matchPattern(op.getRhs(), m_ConstantInt(&rhs)) &&
3037 "Should be folded");
3038 replaceOpWithNewOpAndCopyName<ICmpOp>(
3039 rewriter, op, ICmpOp::getFlippedPredicate(op.getPredicate()),
3040 op.getRhs(), op.getLhs(), op.getTwoState());
3045 if (matchPattern(op.getRhs(), m_ConstantInt(&rhs))) {
3047 return rewriter.create<
hw::ConstantOp>(op.getLoc(), std::move(constant));
3050 auto replaceWith = [&](ICmpPredicate predicate, Value lhs,
3051 Value rhs) -> LogicalResult {
3052 replaceOpWithNewOpAndCopyName<ICmpOp>(rewriter, op, predicate, lhs, rhs,
3057 auto replaceWithConstantI1 = [&](
bool constant) -> LogicalResult {
3058 replaceOpWithNewOpAndCopyName<hw::ConstantOp>(rewriter, op,
3059 APInt(1, constant));
3063 switch (op.getPredicate()) {
3064 case ICmpPredicate::slt:
3066 if (rhs.isMaxSignedValue())
3067 return replaceWith(ICmpPredicate::ne, op.getLhs(), op.getRhs());
3069 if (rhs.isMinSignedValue())
3070 return replaceWithConstantI1(0);
3072 if ((rhs - 1).isMinSignedValue())
3073 return replaceWith(ICmpPredicate::eq, op.getLhs(),
3076 case ICmpPredicate::sgt:
3078 if (rhs.isMinSignedValue())
3079 return replaceWith(ICmpPredicate::ne, op.getLhs(), op.getRhs());
3081 if (rhs.isMaxSignedValue())
3082 return replaceWithConstantI1(0);
3084 if ((rhs + 1).isMaxSignedValue())
3085 return replaceWith(ICmpPredicate::eq, op.getLhs(),
3088 case ICmpPredicate::ult:
3090 if (rhs.isAllOnes())
3091 return replaceWith(ICmpPredicate::ne, op.getLhs(), op.getRhs());
3094 return replaceWithConstantI1(0);
3096 if ((rhs - 1).isZero())
3097 return replaceWith(ICmpPredicate::eq, op.getLhs(),
3101 if (rhs.countLeadingOnes() + rhs.countTrailingZeros() ==
3102 rhs.getBitWidth()) {
3103 auto numOnes = rhs.countLeadingOnes();
3104 auto smaller = rewriter.create<
ExtractOp>(
3105 op.getLoc(), op.getLhs(), rhs.getBitWidth() - numOnes, numOnes);
3106 return replaceWith(ICmpPredicate::ne, smaller,
3111 case ICmpPredicate::ugt:
3114 return replaceWith(ICmpPredicate::ne, op.getLhs(), op.getRhs());
3116 if (rhs.isAllOnes())
3117 return replaceWithConstantI1(0);
3119 if ((rhs + 1).isAllOnes())
3120 return replaceWith(ICmpPredicate::eq, op.getLhs(),
3124 if ((rhs + 1).isPowerOf2()) {
3125 auto numOnes = rhs.countTrailingOnes();
3126 auto newWidth = rhs.getBitWidth() - numOnes;
3127 auto smaller = rewriter.create<
ExtractOp>(op.getLoc(), op.getLhs(),
3129 return replaceWith(ICmpPredicate::ne, smaller,
3134 case ICmpPredicate::sle:
3136 if (rhs.isMaxSignedValue())
3137 return replaceWithConstantI1(1);
3139 return replaceWith(ICmpPredicate::slt, op.getLhs(),
getConstant(rhs + 1));
3140 case ICmpPredicate::sge:
3142 if (rhs.isMinSignedValue())
3143 return replaceWithConstantI1(1);
3145 return replaceWith(ICmpPredicate::sgt, op.getLhs(),
getConstant(rhs - 1));
3146 case ICmpPredicate::ule:
3148 if (rhs.isAllOnes())
3149 return replaceWithConstantI1(1);
3151 return replaceWith(ICmpPredicate::ult, op.getLhs(),
getConstant(rhs + 1));
3152 case ICmpPredicate::uge:
3155 return replaceWithConstantI1(1);
3157 return replaceWith(ICmpPredicate::ugt, op.getLhs(),
getConstant(rhs - 1));
3158 case ICmpPredicate::eq:
3159 if (rhs.getBitWidth() == 1) {
3162 replaceOpWithNewOpAndCopyName<XorOp>(rewriter, op, op.getLhs(),
3167 if (rhs.isAllOnes()) {
3174 case ICmpPredicate::ne:
3175 if (rhs.getBitWidth() == 1) {
3181 if (rhs.isAllOnes()) {
3183 replaceOpWithNewOpAndCopyName<XorOp>(rewriter, op, op.getLhs(),
3190 case ICmpPredicate::ceq:
3191 case ICmpPredicate::cne:
3192 case ICmpPredicate::weq:
3193 case ICmpPredicate::wne:
3199 if (op.getPredicate() == ICmpPredicate::eq ||
3200 op.getPredicate() == ICmpPredicate::ne) {
3205 if (!knownBits.isUnknown())
3212 if (
auto xorOp = op.getLhs().getDefiningOp<
XorOp>())
3219 if (
auto replicateOp = op.getLhs().getDefiningOp<ReplicateOp>())
3220 if (rhs.isAllOnes() || rhs.isZero()) {
3221 auto width = replicateOp.getInput().getType().getIntOrFloatBitWidth();
3223 op.getLoc(), rhs.isAllOnes() ? APInt::getAllOnes(width)
3224 : APInt::getZero(width));
3225 replaceOpWithNewOpAndCopyName<ICmpOp>(rewriter, op, op.getPredicate(),
3226 replicateOp.getInput(), cst,
3236 if (Operation *opLHS = op.getLhs().getDefiningOp())
3237 if (Operation *opRHS = op.getRhs().getDefiningOp())
3238 if (isa<ConcatOp, ReplicateOp>(opLHS) &&
3239 isa<ConcatOp, ReplicateOp>(opRHS)) {
assert(baseType &&"element must be base type")
static SmallVector< T > concat(const SmallVectorImpl< T > &a, const SmallVectorImpl< T > &b)
Returns a new vector containing the concatenation of vectors a and b.
static KnownBits computeKnownBits(Value v, unsigned depth)
Given an integer SSA value, check to see if we know anything about the result of the computation.
static bool foldMuxOfUniformArrays(MuxOp op, PatternRewriter &rewriter)
static Attribute constFoldAssociativeOp(ArrayRef< Attribute > operands, hw::PEO paramOpcode)
static Attribute constFoldBinaryOp(ArrayRef< Attribute > operands, hw::PEO paramOpcode)
Performs constant folding calculate with element-wise behavior on the two attributes in operands and ...
static bool applyCmpPredicateToEqualOperands(ICmpPredicate predicate)
static bool canonicalizeLogicalCstWithConcat(Operation *logicalOp, size_t concatIdx, const APInt &cst, PatternRewriter &rewriter)
When we find a logical operation (and, or, xor) with a constant e.g.
static bool hasOperandsOutsideOfBlock(Operation *op)
In comb, we assume no knowledge of the semantics of cross-block dataflow.
static bool narrowOperationWidth(OpTy op, bool narrowTrailingBits, PatternRewriter &rewriter)
static OpFoldResult foldDiv(Op op, ArrayRef< Attribute > constants)
static Value getCommonOperand(Op op)
Returns a single common operand that all inputs of the operation op can be traced back to,...
static bool canCombineOppositeBinCmpIntoConstant(OperandRange operands)
static void getConcatOperands(Value v, SmallVectorImpl< Value > &result)
Flatten concat and replicate operands into a vector.
static OpTy replaceOpWithNewOpAndCopyName(PatternRewriter &rewriter, Operation *op, Args &&...args)
A wrapper of PatternRewriter::replaceOpWithNewOp to propagate "sv.namehint" attribute.
static Value extractOperandFromFullyAssociative(Operation *fullyAssoc, size_t operandNo, PatternRewriter &rewriter)
Given a fully associative variadic operation like (a+b+c+d), break the expression into two parts,...
static bool getMuxChainCondConstant(Value cond, Value indexValue, bool isInverted, std::function< void(hw::ConstantOp)> constantFn)
Check to see if the condition to the specified mux is an equality comparison indexValue and one or mo...
static TypedAttr getIntAttr(const APInt &value, MLIRContext *context)
static bool shouldBeFlattened(Operation *op)
Return true if the op will be flattened afterwards.
static std::pair< size_t, size_t > getLowestBitAndHighestBitRequired(Operation *op, bool narrowTrailingBits, size_t originalOpWidth)
static void canonicalizeXorIcmpTrue(XorOp op, unsigned icmpOperand, PatternRewriter &rewriter)
static bool extractFromReplicate(ExtractOp op, ReplicateOp replicate, PatternRewriter &rewriter)
static void combineEqualityICmpWithXorOfConstant(ICmpOp cmpOp, XorOp xorOp, const APInt &rhs, PatternRewriter &rewriter)
static size_t getTotalWidth(ArrayRef< Value > operands)
static bool foldCommonMuxOperation(MuxOp mux, Operation *trueOp, Operation *falseOp, PatternRewriter &rewriter)
This function is invoked when we find a mux with true/false operations that have the same opcode.
static bool tryFlatteningOperands(Operation *op, PatternRewriter &rewriter)
Flattens a single input in op if hasOneUse is true and it can be defined as an Op.
static bool canonicalizeIdempotentInputs(Op op, PatternRewriter &rewriter)
Canonicalize an idempotent operation op so that only one input of any kind occurs.
static bool applyCmpPredicate(ICmpPredicate predicate, const APInt &lhs, const APInt &rhs)
static void combineEqualityICmpWithKnownBitsAndConstant(ICmpOp cmpOp, const KnownBits &bitAnalysis, const APInt &rhsCst, PatternRewriter &rewriter)
Given an equality comparison with a constant value and some operand that has known bits,...
static bool foldMuxChain(MuxOp rootMux, bool isFalseSide, PatternRewriter &rewriter)
Given a mux, check to see if the "on true" value (or "on false" value if isFalseSide=true) is a mux t...
static ComplementMatcher< SubType > m_Complement(const SubType &subExpr)
static bool hasSVAttributes(Operation *op)
static LogicalResult extractConcatToConcatExtract(ExtractOp op, ConcatOp innerCat, PatternRewriter &rewriter)
static OpFoldResult foldMod(Op op, ArrayRef< Attribute > constants)
static size_t computeCommonPrefixLength(const Range &a, const Range &b)
static bool foldCommonMuxValue(MuxOp op, bool isTrueOperand, PatternRewriter &rewriter)
Fold things like mux(cond, x|y|z|a, a) -> (x|y|z)&replicate(cond)|a and mux(cond, a,...
static LogicalResult matchAndRewriteCompareConcat(ICmpOp op, Operation *lhs, Operation *rhs, PatternRewriter &rewriter)
Reduce the strength icmp(concat(...), concat(...)) by doing a element-wise comparison on common prefi...
static Value createGenericOp(Location loc, OperationName name, ArrayRef< Value > operands, OpBuilder &builder)
Create a new instance of a generic operation that only has value operands, and has a single result va...
static void replaceOpAndCopyName(PatternRewriter &rewriter, Operation *op, Value newValue)
A wrapper of PatternRewriter::replaceOp to propagate "sv.namehint" attribute.
static std::optional< APSInt > getConstant(Attribute operand)
Determine the value of a constant operand for the sake of constant folding.
def create(data_type, value)
static LogicalResult canonicalize(Op op, PatternRewriter &rewriter)
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
Value createOrFoldNot(Location loc, Value value, OpBuilder &builder, bool twoState=false)
Create a "Not" gate on a value.
uint64_t getWidth(Type t)
std::optional< int64_t > getBitWidth(FIRRTLBaseType type, bool ignoreFlip=false)
bool isOffset(Value base, Value index, uint64_t offset)
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.