12 #include "mlir/IR/Matchers.h"
13 #include "mlir/IR/PatternMatch.h"
14 #include "llvm/ADT/SetVector.h"
15 #include "llvm/ADT/SmallBitVector.h"
16 #include "llvm/ADT/TypeSwitch.h"
17 #include "llvm/Support/KnownBits.h"
20 using namespace circt;
22 using namespace matchers;
33 Block *thisBlock = op->getBlock();
34 return llvm::any_of(op->getOperands(), [&](Value operand) {
35 return operand.getParentBlock() != thisBlock;
45 ArrayRef<Value> operands, OpBuilder &builder) {
46 OperationState state(loc, name);
47 state.addOperands(operands);
48 state.addTypes(operands[0].getType());
49 return builder.create(state)->getResult(0);
52 static TypedAttr
getIntAttr(
const APInt &value, MLIRContext *context) {
60 for (
auto op :
concat.getOperands())
62 }
else if (
auto repl = v.getDefiningOp<ReplicateOp>()) {
63 for (
size_t i = 0, e = repl.getMultiple(); i != e; ++i)
75 if (
auto *newOp = newValue.getDefiningOp()) {
76 auto name = op->getAttrOfType<StringAttr>(
"sv.namehint");
77 if (name && !newOp->hasAttr(
"sv.namehint"))
78 rewriter.modifyOpInPlace(newOp,
79 [&] { newOp->setAttr(
"sv.namehint", name); });
81 rewriter.replaceOp(op, newValue);
87 template <
typename OpTy,
typename... Args>
89 Operation *op, Args &&...args) {
90 auto name = op->getAttrOfType<StringAttr>(
"sv.namehint");
92 rewriter.replaceOpWithNewOp<OpTy>(op, std::forward<Args>(args)...);
93 if (name && !newOp->hasAttr(
"sv.namehint"))
94 rewriter.modifyOpInPlace(newOp,
95 [&] { newOp->setAttr(
"sv.namehint", name); });
104 return op->hasAttr(
"sv.attributes");
108 template <
typename SubType>
109 struct ComplementMatcher {
111 ComplementMatcher(SubType lhs) : lhs(std::move(lhs)) {}
112 bool match(Operation *op) {
113 auto xorOp = dyn_cast<XorOp>(op);
114 return xorOp && xorOp.isBinaryNot() && lhs.match(op->getOperand(0));
119 template <
typename SubType>
120 static inline ComplementMatcher<SubType>
m_Complement(
const SubType &subExpr) {
121 return ComplementMatcher<SubType>(subExpr);
127 assert((isa<AndOp, OrOp, XorOp, AddOp, MulOp>(op) &&
128 "must be commutative operations"));
129 if (op->hasOneUse()) {
130 auto *user = *op->getUsers().begin();
131 return user->getName() == op->getName() &&
132 op->getAttrOfType<UnitAttr>(
"twoState") ==
133 user->getAttrOfType<UnitAttr>(
"twoState") &&
134 op->getBlock() == user->getBlock();
149 auto inputs = op->getOperands();
151 SmallVector<Value, 4> newOperands;
152 SmallVector<Location, 4> newLocations{op->getLoc()};
153 newOperands.reserve(inputs.size());
155 decltype(inputs.begin()) current, end;
158 SmallVector<Element> worklist;
159 worklist.push_back({inputs.begin(), inputs.end()});
160 bool binFlag = op->hasAttrOfType<UnitAttr>(
"twoState");
161 bool changed =
false;
162 while (!worklist.empty()) {
163 auto &element = worklist.back();
166 if (element.current == element.end) {
171 Value value = *element.current++;
172 auto *flattenOp = value.getDefiningOp();
175 if (!flattenOp || flattenOp->getName() != op->getName() ||
176 flattenOp == op || binFlag != op->hasAttrOfType<UnitAttr>(
"twoState") ||
177 flattenOp->getBlock() != op->getBlock()) {
178 newOperands.push_back(value);
183 if (!value.hasOneUse()) {
191 if (flattenOp->getNumOperands() != 2 || !isa<AndOp, OrOp, XorOp>(op) ||
194 newOperands.push_back(value);
202 auto flattenOpInputs = flattenOp->getOperands();
203 worklist.push_back({flattenOpInputs.begin(), flattenOpInputs.end()});
204 newLocations.push_back(flattenOp->getLoc());
211 op->getName(), newOperands, rewriter);
213 result.getDefiningOp()->setAttr(
"twoState", rewriter.getUnitAttr());
221 static std::pair<size_t, size_t>
223 size_t originalOpWidth) {
224 auto users = op->getUsers();
226 "getLowestBitAndHighestBitRequired cannot operate on "
227 "a empty list of uses.");
231 size_t lowestBitRequired = narrowTrailingBits ? originalOpWidth - 1 : 0;
232 size_t highestBitRequired = 0;
234 for (
auto *user : users) {
235 if (
auto extractOp = dyn_cast<ExtractOp>(user)) {
236 size_t lowBit = extractOp.getLowBit();
238 cast<IntegerType>(extractOp.getType()).getWidth() + lowBit - 1;
239 highestBitRequired = std::max(highestBitRequired, highBit);
240 lowestBitRequired = std::min(lowestBitRequired, lowBit);
244 highestBitRequired = originalOpWidth - 1;
245 lowestBitRequired = 0;
249 return {lowestBitRequired, highestBitRequired};
252 template <
class OpTy>
254 PatternRewriter &rewriter) {
255 IntegerType opType = dyn_cast<IntegerType>(op.getResult().getType());
261 if (range.second + 1 == opType.getWidth() && range.first == 0)
264 SmallVector<Value> args;
265 auto newType = rewriter.getIntegerType(range.second - range.first + 1);
266 for (
auto inop : op.getOperands()) {
268 if (inop.getType() != op.getType())
269 args.push_back(inop);
271 args.push_back(rewriter.createOrFold<
ExtractOp>(inop.getLoc(), newType,
274 auto newop = rewriter.create<OpTy>(op.getLoc(), newType, args);
275 newop->setDialectAttrs(op->getDialectAttrs());
276 if (op.getTwoState())
277 newop.setTwoState(
true);
279 Value newResult = newop.getResult();
281 newResult = rewriter.createOrFold<
ConcatOp>(
282 op.getLoc(), newResult,
284 APInt::getZero(range.first)));
285 if (range.second + 1 < opType.getWidth())
286 newResult = rewriter.createOrFold<
ConcatOp>(
289 op.getLoc(), APInt::getZero(opType.getWidth() - range.second - 1)),
291 rewriter.replaceOp(op, newResult);
299 OpFoldResult ReplicateOp::fold(FoldAdaptor adaptor) {
304 if (cast<IntegerType>(getType()).
getWidth() ==
305 getInput().getType().getIntOrFloatBitWidth())
309 if (
auto input = dyn_cast_or_null<IntegerAttr>(adaptor.getInput())) {
310 if (input.getValue().getBitWidth() == 1) {
311 if (input.getValue().isZero())
313 APInt::getZero(cast<IntegerType>(getType()).
getWidth()),
316 APInt::getAllOnes(cast<IntegerType>(getType()).
getWidth()),
320 APInt result = APInt::getZeroWidth();
321 for (
auto i = getMultiple(); i != 0; --i)
322 result = result.concat(input.getValue());
329 OpFoldResult ParityOp::fold(FoldAdaptor adaptor) {
334 if (
auto input = dyn_cast_or_null<IntegerAttr>(adaptor.getInput()))
335 return getIntAttr(APInt(1, input.getValue().popcount() & 1), getContext());
347 hw::PEO paramOpcode) {
348 assert(operands.size() == 2 &&
"binary op takes two operands");
349 if (!operands[0] || !operands[1])
355 cast<TypedAttr>(operands[1]));
358 OpFoldResult ShlOp::fold(FoldAdaptor adaptor) {
362 if (
auto rhs = dyn_cast_or_null<IntegerAttr>(adaptor.getRhs())) {
363 unsigned shift = rhs.getValue().getZExtValue();
364 unsigned width = getType().getIntOrFloatBitWidth();
366 return getOperand(0);
380 if (!matchPattern(op.getRhs(), m_ConstantInt(&value)))
383 unsigned width = cast<IntegerType>(op.getLhs().getType()).getWidth();
384 unsigned shift = value.getZExtValue();
387 if (
width <= shift || shift == 0)
391 rewriter.create<
hw::ConstantOp>(op.getLoc(), APInt::getZero(shift));
397 replaceOpWithNewOpAndCopyName<ConcatOp>(rewriter, op, extract, zeros);
401 OpFoldResult ShrUOp::fold(FoldAdaptor adaptor) {
405 if (
auto rhs = dyn_cast_or_null<IntegerAttr>(adaptor.getRhs())) {
406 unsigned shift = rhs.getValue().getZExtValue();
408 return getOperand(0);
410 unsigned width = getType().getIntOrFloatBitWidth();
423 if (!matchPattern(op.getRhs(), m_ConstantInt(&value)))
426 unsigned width = cast<IntegerType>(op.getLhs().getType()).getWidth();
427 unsigned shift = value.getZExtValue();
430 if (
width <= shift || shift == 0)
434 rewriter.create<
hw::ConstantOp>(op.getLoc(), APInt::getZero(shift));
437 auto extract = rewriter.
create<
ExtractOp>(op.getLoc(), op.getLhs(), shift,
440 replaceOpWithNewOpAndCopyName<ConcatOp>(rewriter, op, zeros, extract);
444 OpFoldResult ShrSOp::fold(FoldAdaptor adaptor) {
448 if (
auto rhs = dyn_cast_or_null<IntegerAttr>(adaptor.getRhs())) {
449 if (rhs.getValue().getZExtValue() == 0)
450 return getOperand(0);
461 if (!matchPattern(op.getRhs(), m_ConstantInt(&value)))
464 unsigned width = cast<IntegerType>(op.getLhs().getType()).getWidth();
465 unsigned shift = value.getZExtValue();
468 rewriter.createOrFold<
ExtractOp>(op.getLoc(), op.getLhs(),
width - 1, 1);
469 auto sext = rewriter.createOrFold<ReplicateOp>(op.getLoc(), topbit, shift);
471 if (
width <= shift) {
476 auto extract = rewriter.
create<
ExtractOp>(op.getLoc(), op.getLhs(), shift,
479 replaceOpWithNewOpAndCopyName<ConcatOp>(rewriter, op, sext, extract);
487 OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
492 if (getInput().getType() == getType())
496 if (
auto input = dyn_cast_or_null<IntegerAttr>(adaptor.getInput())) {
497 unsigned dstWidth = cast<IntegerType>(getType()).getWidth();
498 return getIntAttr(input.getValue().lshr(getLowBit()).trunc(dstWidth),
509 PatternRewriter &rewriter) {
510 auto reversedConcatArgs = llvm::reverse(innerCat.getInputs());
511 size_t beginOfFirstRelevantElement = 0;
512 auto it = reversedConcatArgs.begin();
513 size_t lowBit = op.getLowBit();
516 for (; it != reversedConcatArgs.end(); it++) {
517 assert(beginOfFirstRelevantElement <= lowBit &&
518 "incorrectly moved past an element that lowBit has coverage over");
521 size_t operandWidth = operand.getType().getIntOrFloatBitWidth();
522 if (lowBit < beginOfFirstRelevantElement + operandWidth) {
546 beginOfFirstRelevantElement += operandWidth;
548 assert(it != reversedConcatArgs.end() &&
549 "incorrectly failed to find an element which contains coverage of "
552 SmallVector<Value> reverseConcatArgs;
553 size_t widthRemaining = cast<IntegerType>(op.getType()).getWidth();
554 size_t extractLo = lowBit - beginOfFirstRelevantElement;
559 for (; widthRemaining != 0 && it != reversedConcatArgs.end(); it++) {
560 auto concatArg = *it;
561 size_t operandWidth = concatArg.getType().getIntOrFloatBitWidth();
562 size_t widthToConsume = std::min(widthRemaining, operandWidth - extractLo);
564 if (widthToConsume == operandWidth && extractLo == 0) {
565 reverseConcatArgs.push_back(concatArg);
568 reverseConcatArgs.push_back(
569 rewriter.create<
ExtractOp>(op.getLoc(), resultType, *it, extractLo));
572 widthRemaining -= widthToConsume;
578 if (reverseConcatArgs.size() == 1) {
581 replaceOpWithNewOpAndCopyName<ConcatOp>(
582 rewriter, op, SmallVector<Value>(llvm::reverse(reverseConcatArgs)));
589 PatternRewriter &rewriter) {
590 auto extractResultWidth = cast<IntegerType>(op.getType()).getWidth();
591 auto replicateEltWidth =
592 replicate.getOperand().getType().getIntOrFloatBitWidth();
596 if (op.getLowBit() % replicateEltWidth == 0 &&
597 extractResultWidth % replicateEltWidth == 0) {
598 replaceOpWithNewOpAndCopyName<ReplicateOp>(rewriter, op, op.getType(),
599 replicate.getOperand());
605 if (op.getLowBit() % replicateEltWidth + extractResultWidth <=
607 replaceOpWithNewOpAndCopyName<ExtractOp>(
608 rewriter, op, op.getType(), replicate.getOperand(),
609 op.getLowBit() % replicateEltWidth);
622 auto *inputOp = op.getInput().getDefiningOp();
629 .extractBits(cast<IntegerType>(op.getType()).getWidth(),
631 if (knownBits.isConstant()) {
632 replaceOpWithNewOpAndCopyName<hw::ConstantOp>(rewriter, op,
633 knownBits.getConstant());
639 if (
auto innerExtract = dyn_cast_or_null<ExtractOp>(inputOp)) {
640 replaceOpWithNewOpAndCopyName<ExtractOp>(
641 rewriter, op, op.getType(), innerExtract.getInput(),
642 innerExtract.getLowBit() + op.getLowBit());
647 if (
auto innerCat = dyn_cast_or_null<ConcatOp>(inputOp))
651 if (
auto replicate = dyn_cast_or_null<ReplicateOp>(inputOp))
657 if (inputOp && inputOp->getNumOperands() == 2 &&
658 isa<AndOp, OrOp, XorOp>(inputOp)) {
659 if (
auto cstRHS = inputOp->getOperand(1).getDefiningOp<
hw::ConstantOp>()) {
660 auto extractedCst = cstRHS.getValue().extractBits(
661 cast<IntegerType>(op.getType()).getWidth(), op.getLowBit());
662 if (isa<OrOp, XorOp>(inputOp) && extractedCst.isZero()) {
663 replaceOpWithNewOpAndCopyName<ExtractOp>(
664 rewriter, op, op.getType(), inputOp->getOperand(0), op.getLowBit());
672 if (isa<AndOp>(inputOp)) {
675 unsigned lz = extractedCst.countLeadingZeros();
676 unsigned tz = extractedCst.countTrailingZeros();
677 unsigned pop = extractedCst.popcount();
678 if (extractedCst.getBitWidth() - lz - tz == pop) {
679 auto resultTy = rewriter.getIntegerType(pop);
680 SmallVector<Value> resultElts;
683 op.getLoc(), APInt::getZero(lz)));
684 resultElts.push_back(rewriter.createOrFold<
ExtractOp>(
685 op.getLoc(), resultTy, inputOp->getOperand(0),
686 op.getLowBit() + tz));
689 op.getLoc(), APInt::getZero(tz)));
690 replaceOpWithNewOpAndCopyName<ConcatOp>(rewriter, op, resultElts);
699 if (cast<IntegerType>(op.getType()).getWidth() == 1 && inputOp)
700 if (
auto shlOp = dyn_cast<ShlOp>(inputOp)) {
702 if (shlOp->hasOneUse())
704 if (lhsCst.getValue().isOne()) {
707 APInt(lhsCst.getValue().getBitWidth(), op.getLowBit()));
708 replaceOpWithNewOpAndCopyName<ICmpOp>(
709 rewriter, op, ICmpPredicate::eq, shlOp->getOperand(1), newCst,
725 hw::PEO paramOpcode) {
726 assert(operands.size() > 1 &&
"caller should handle one-operand case");
729 if (!operands[1] || !operands[0])
733 if (llvm::all_of(operands.drop_front(2),
734 [&](Attribute in) { return !!in; })) {
735 SmallVector<mlir::TypedAttr> typedOperands;
736 typedOperands.reserve(operands.size());
737 for (
auto operand : operands) {
738 if (
auto typedOperand = dyn_cast<mlir::TypedAttr>(operand))
739 typedOperands.push_back(typedOperand);
743 if (typedOperands.size() == operands.size())
760 size_t concatIdx,
const APInt &cst,
761 PatternRewriter &rewriter) {
762 auto concatOp = logicalOp->getOperand(concatIdx).getDefiningOp<
ConcatOp>();
763 assert((isa<AndOp, OrOp, XorOp>(logicalOp) && concatOp));
768 llvm::any_of(concatOp->getOperands(), [&](Value operand) ->
bool {
769 auto *operandOp = operand.getDefiningOp();
774 if (isa<hw::ConstantOp>(operandOp))
778 return operandOp->getName() == logicalOp->getName() &&
779 operandOp->hasOneUse() && operandOp->getNumOperands() != 0 &&
780 operandOp->getOperands().back().getDefiningOp<hw::ConstantOp>();
788 auto createLogicalOp = [&](ArrayRef<Value> operands) -> Value {
789 return createGenericOp(logicalOp->getLoc(), logicalOp->getName(), operands,
796 SmallVector<Value> newConcatOperands;
797 newConcatOperands.reserve(concatOp->getNumOperands());
800 size_t nextOperandBit = concatOp.getType().getIntOrFloatBitWidth();
801 for (Value operand : concatOp->getOperands()) {
802 size_t operandWidth = operand.getType().getIntOrFloatBitWidth();
803 nextOperandBit -= operandWidth;
806 logicalOp->getLoc(), cst.lshr(nextOperandBit).trunc(operandWidth));
808 newConcatOperands.push_back(createLogicalOp({operand, eltCst}));
817 if (logicalOp->getNumOperands() > 2) {
818 auto origOperands = logicalOp->getOperands();
819 SmallVector<Value> operands;
821 operands.append(origOperands.begin(), origOperands.begin() + concatIdx);
823 operands.append(origOperands.begin() + concatIdx + 1,
824 origOperands.begin() + (origOperands.size() - 1));
826 operands.push_back(newResult);
827 newResult = createLogicalOp(operands);
837 llvm::SmallDenseSet<std::tuple<ICmpPredicate, Value, Value>> seenPredicates;
839 for (
auto op : operands) {
840 if (
auto icmpOp = op.getDefiningOp<ICmpOp>();
841 icmpOp && icmpOp.getTwoState()) {
842 auto predicate = icmpOp.getPredicate();
843 auto lhs = icmpOp.getLhs();
844 auto rhs = icmpOp.getRhs();
845 if (seenPredicates.contains(
846 {ICmpOp::getNegatedPredicate(predicate), lhs, rhs}))
849 seenPredicates.insert({predicate, lhs, rhs});
855 OpFoldResult AndOp::fold(FoldAdaptor adaptor) {
859 APInt value = APInt::getAllOnes(cast<IntegerType>(getType()).
getWidth());
861 auto inputs = adaptor.getInputs();
864 for (
auto operand : inputs) {
867 value &= cast<IntegerAttr>(operand).getValue();
873 if (inputs.size() == 2 && inputs[1] &&
874 cast<IntegerAttr>(inputs[1]).getValue().isAllOnes())
875 return getInputs()[0];
878 if (llvm::all_of(getInputs(),
879 [&](
auto in) {
return in == this->getInputs()[0]; }))
880 return getInputs()[0];
883 for (Value arg : getInputs()) {
886 for (Value arg2 : getInputs())
889 APInt::getZero(cast<IntegerType>(getType()).
getWidth()),
910 template <
typename Op>
912 if (!op.getType().isInteger(1))
915 auto inputs = op.getInputs();
916 size_t size = inputs.size();
918 auto sourceOp = inputs[0].template getDefiningOp<ExtractOp>();
921 Value source = sourceOp.getOperand();
924 if (size != source.getType().getIntOrFloatBitWidth())
928 llvm::BitVector bits(size);
929 bits.set(sourceOp.getLowBit());
931 for (
size_t i = 1; i != size; ++i) {
932 auto extractOp = inputs[i].template getDefiningOp<ExtractOp>();
933 if (!extractOp || extractOp.getOperand() != source)
935 bits.set(extractOp.getLowBit());
938 return bits.all() ? source : Value();
945 template <
typename Op>
948 constexpr
unsigned limit = 3;
949 auto inputs = op.getInputs();
951 llvm::SmallSetVector<Value, 8> uniqueInputs(inputs.begin(), inputs.end());
952 llvm::SmallDenseSet<Op, 8> checked;
959 llvm::SmallVector<OpWithDepth, 8> worklist;
961 auto enqueue = [&worklist, &checked, &op](Value input,
unsigned depth) {
965 if (depth < limit && input.getParentBlock() == op->getBlock()) {
966 auto inputOp = input.template getDefiningOp<Op>();
967 if (inputOp && inputOp.getTwoState() == op.getTwoState() &&
968 checked.insert(inputOp).second)
969 worklist.push_back({inputOp, depth + 1});
973 for (
auto input : uniqueInputs)
976 while (!worklist.empty()) {
977 auto item = worklist.pop_back_val();
979 for (
auto input : item.op.getInputs()) {
980 uniqueInputs.remove(input);
981 enqueue(input, item.depth);
985 if (uniqueInputs.size() < inputs.size()) {
986 replaceOpWithNewOpAndCopyName<Op>(rewriter, op, op.getType(),
987 uniqueInputs.getArrayRef(),
996 auto inputs = op.getInputs();
997 auto size = inputs.size();
1011 assert(size > 1 &&
"expected 2 or more operands, `fold` should handle this");
1015 if (matchPattern(inputs.back(), m_ConstantInt(&value))) {
1017 if (value.isAllOnes()) {
1018 replaceOpWithNewOpAndCopyName<AndOp>(rewriter, op, op.getType(),
1019 inputs.drop_back(),
false);
1027 if (matchPattern(inputs[size - 2], m_ConstantInt(&value2))) {
1028 auto cst = rewriter.create<
hw::ConstantOp>(op.getLoc(), value & value2);
1029 SmallVector<Value, 4> newOperands(inputs.drop_back(2));
1030 newOperands.push_back(cst);
1031 replaceOpWithNewOpAndCopyName<AndOp>(rewriter, op, op.getType(),
1032 newOperands,
false);
1037 if (size == 2 && value.isPowerOf2()) {
1042 if (
auto replicate = inputs[0].getDefiningOp<ReplicateOp>()) {
1043 auto replicateOperand = replicate.getOperand();
1044 if (replicateOperand.getType().isInteger(1)) {
1045 unsigned resultWidth = op.getType().getIntOrFloatBitWidth();
1046 auto trailingZeros = value.countTrailingZeros();
1049 SmallVector<Value, 3> concatOperands;
1050 if (trailingZeros != resultWidth - 1) {
1052 op.getLoc(), APInt::getZero(resultWidth - trailingZeros - 1));
1053 concatOperands.push_back(highZeros);
1055 concatOperands.push_back(replicateOperand);
1056 if (trailingZeros != 0) {
1058 op.getLoc(), APInt::getZero(trailingZeros));
1059 concatOperands.push_back(lowZeros);
1061 replaceOpWithNewOpAndCopyName<ConcatOp>(rewriter, op, op.getType(),
1069 if (
auto extractOp = inputs[0].getDefiningOp<ExtractOp>()) {
1072 (value.countLeadingZeros() || value.countTrailingZeros())) {
1073 unsigned lz = value.countLeadingZeros();
1074 unsigned tz = value.countTrailingZeros();
1077 auto smallTy = rewriter.getIntegerType(value.getBitWidth() - lz - tz);
1078 Value smallElt = rewriter.createOrFold<
ExtractOp>(
1079 extractOp.getLoc(), smallTy, extractOp->getOperand(0),
1080 extractOp.getLowBit() + tz);
1082 APInt smallMask = value.extractBits(smallTy.getWidth(), tz);
1083 if (!smallMask.isAllOnes()) {
1084 auto loc = inputs.back().getLoc();
1085 smallElt = rewriter.createOrFold<
AndOp>(
1092 SmallVector<Value> resultElts;
1094 resultElts.push_back(
1095 rewriter.create<
hw::ConstantOp>(op.getLoc(), APInt::getZero(lz)));
1096 resultElts.push_back(smallElt);
1098 resultElts.push_back(
1099 rewriter.create<
hw::ConstantOp>(op.getLoc(), APInt::getZero(tz)));
1100 replaceOpWithNewOpAndCopyName<ConcatOp>(rewriter, op, resultElts);
1108 for (
size_t i = 0; i < size - 1; ++i) {
1109 if (
auto concat = inputs[i].getDefiningOp<ConcatOp>())
1122 rewriter.create<
hw::ConstantOp>(op.getLoc(), APInt::getAllOnes(size));
1123 replaceOpWithNewOpAndCopyName<ICmpOp>(rewriter, op, ICmpPredicate::eq,
1124 source, cmpAgainst);
1132 OpFoldResult OrOp::fold(FoldAdaptor adaptor) {
1136 auto value = APInt::getZero(cast<IntegerType>(getType()).
getWidth());
1137 auto inputs = adaptor.getInputs();
1139 for (
auto operand : inputs) {
1142 value |= cast<IntegerAttr>(operand).getValue();
1143 if (value.isAllOnes())
1148 if (inputs.size() == 2 && inputs[1] &&
1149 cast<IntegerAttr>(inputs[1]).getValue().isZero())
1150 return getInputs()[0];
1153 if (llvm::all_of(getInputs(),
1154 [&](
auto in) {
return in == this->getInputs()[0]; }))
1155 return getInputs()[0];
1158 for (Value arg : getInputs()) {
1160 if (matchPattern(arg,
m_Complement(m_Any(&subExpr)))) {
1161 for (Value arg2 : getInputs())
1162 if (arg2 == subExpr)
1164 APInt::getAllOnes(cast<IntegerType>(getType()).
getWidth()),
1174 APInt::getAllOnes(cast<IntegerType>(getType()).
getWidth()),
1198 PatternRewriter &rewriter) {
1199 assert(concatIdx1 < concatIdx2 &&
"concatIdx1 must be < concatIdx2");
1201 auto inputs = op.getInputs();
1202 auto concat1 = inputs[concatIdx1].getDefiningOp<
ConcatOp>();
1203 auto concat2 = inputs[concatIdx2].getDefiningOp<
ConcatOp>();
1205 assert(concat1 && concat2 &&
"expected indexes to point to ConcatOps");
1208 bool hasConstantOp1 =
1209 llvm::any_of(concat1->getOperands(), [&](Value operand) ->
bool {
1210 return operand.getDefiningOp<hw::ConstantOp>();
1212 if (!hasConstantOp1) {
1213 bool hasConstantOp2 =
1214 llvm::any_of(concat2->getOperands(), [&](Value operand) ->
bool {
1215 return operand.getDefiningOp<hw::ConstantOp>();
1217 if (!hasConstantOp2)
1221 SmallVector<Value> newConcatOperands;
1226 auto operands1 = concat1->getOperands();
1227 auto operands2 = concat2->getOperands();
1229 unsigned consumedWidth1 = 0;
1230 unsigned consumedWidth2 = 0;
1231 for (
auto it1 = operands1.begin(), end1 = operands1.end(),
1232 it2 = operands2.begin(), end2 = operands2.end();
1233 it1 != end1 && it2 != end2;) {
1234 auto operand1 = *it1;
1235 auto operand2 = *it2;
1237 unsigned remainingWidth1 =
1239 unsigned remainingWidth2 =
1241 unsigned widthToConsume = std::min(remainingWidth1, remainingWidth2);
1242 auto narrowedType = rewriter.getIntegerType(widthToConsume);
1244 auto extract1 = rewriter.createOrFold<
ExtractOp>(
1245 op.getLoc(), narrowedType, operand1, remainingWidth1 - widthToConsume);
1246 auto extract2 = rewriter.createOrFold<
ExtractOp>(
1247 op.getLoc(), narrowedType, operand2, remainingWidth2 - widthToConsume);
1249 newConcatOperands.push_back(
1250 rewriter.createOrFold<
OrOp>(op.getLoc(), extract1, extract2,
false));
1252 consumedWidth1 += widthToConsume;
1253 consumedWidth2 += widthToConsume;
1255 if (widthToConsume == remainingWidth1) {
1259 if (widthToConsume == remainingWidth2) {
1265 ConcatOp newOp = rewriter.create<
ConcatOp>(op.getLoc(), newConcatOperands);
1269 SmallVector<Value> newOrOperands;
1270 newOrOperands.append(inputs.begin(), inputs.begin() + concatIdx1);
1271 newOrOperands.append(inputs.begin() + concatIdx1 + 1,
1272 inputs.begin() + concatIdx2);
1273 newOrOperands.append(inputs.begin() + concatIdx2 + 1,
1274 inputs.begin() + inputs.size());
1275 newOrOperands.push_back(newOp);
1277 replaceOpWithNewOpAndCopyName<OrOp>(rewriter, op, op.getType(),
1283 auto inputs = op.getInputs();
1284 auto size = inputs.size();
1298 assert(size > 1 &&
"expected 2 or more operands");
1302 if (matchPattern(inputs.back(), m_ConstantInt(&value))) {
1304 if (value.isZero()) {
1305 replaceOpWithNewOpAndCopyName<OrOp>(rewriter, op, op.getType(),
1306 inputs.drop_back());
1312 if (matchPattern(inputs[size - 2], m_ConstantInt(&value2))) {
1313 auto cst = rewriter.create<
hw::ConstantOp>(op.getLoc(), value | value2);
1314 SmallVector<Value, 4> newOperands(inputs.drop_back(2));
1315 newOperands.push_back(cst);
1316 replaceOpWithNewOpAndCopyName<OrOp>(rewriter, op, op.getType(),
1324 for (
size_t i = 0; i < size - 1; ++i) {
1325 if (
auto concat = inputs[i].getDefiningOp<ConcatOp>())
1333 for (
size_t i = 0; i < size - 1; ++i) {
1334 if (
auto concat = inputs[i].getDefiningOp<ConcatOp>())
1335 for (
size_t j = i + 1; j < size; ++j)
1336 if (
auto concat = inputs[j].getDefiningOp<ConcatOp>())
1348 rewriter.create<
hw::ConstantOp>(op.getLoc(), APInt::getZero(size));
1349 replaceOpWithNewOpAndCopyName<ICmpOp>(rewriter, op, ICmpPredicate::ne,
1350 source, cmpAgainst);
1356 if (
auto firstMux = op.getOperand(0).getDefiningOp<
comb::MuxOp>()) {
1358 if (op.getTwoState() && firstMux.getTwoState() &&
1359 matchPattern(firstMux.getFalseValue(), m_ConstantInt(&value)) &&
1361 SmallVector<Value> conditions{firstMux.getCond()};
1362 auto check = [&](Value v) {
1366 conditions.push_back(mux.getCond());
1367 return mux.getTwoState() &&
1368 firstMux.getTrueValue() == mux.getTrueValue() &&
1369 firstMux.getFalseValue() == mux.getFalseValue();
1371 if (llvm::all_of(op.getOperands().drop_front(), check)) {
1372 auto cond = rewriter.create<
comb::OrOp>(op.getLoc(), conditions,
true);
1373 replaceOpWithNewOpAndCopyName<comb::MuxOp>(
1374 rewriter, op, cond, firstMux.getTrueValue(),
1375 firstMux.getFalseValue(),
true);
1385 OpFoldResult XorOp::fold(FoldAdaptor adaptor) {
1389 auto size = getInputs().size();
1390 auto inputs = adaptor.getInputs();
1394 return getInputs()[0];
1397 if (size == 2 && getInputs()[0] == getInputs()[1])
1401 if (inputs.size() == 2 && inputs[1] &&
1402 cast<IntegerAttr>(inputs[1]).getValue().isZero())
1403 return getInputs()[0];
1407 if (isBinaryNot()) {
1409 if (matchPattern(getOperand(0),
m_Complement(m_Any(&subExpr))) &&
1410 subExpr != getResult())
1420 PatternRewriter &rewriter) {
1421 auto icmp = op.getOperand(icmpOperand).getDefiningOp<ICmpOp>();
1422 auto negatedPred = ICmpOp::getNegatedPredicate(icmp.getPredicate());
1425 rewriter.create<ICmpOp>(icmp.getLoc(), negatedPred, icmp.getOperand(0),
1426 icmp.getOperand(1), icmp.getTwoState());
1429 if (op.getNumOperands() > 2) {
1430 SmallVector<Value, 4> newOperands(op.getOperands());
1431 newOperands.pop_back();
1432 newOperands.erase(newOperands.begin() + icmpOperand);
1433 newOperands.push_back(result);
1434 result = rewriter.create<
XorOp>(op.getLoc(), newOperands, op.getTwoState());
1444 auto inputs = op.getInputs();
1445 auto size = inputs.size();
1446 assert(size > 1 &&
"expected 2 or more operands");
1449 if (inputs[size - 1] == inputs[size - 2]) {
1451 "expected idempotent case for 2 elements handled already.");
1452 replaceOpWithNewOpAndCopyName<XorOp>(rewriter, op, op.getType(),
1453 inputs.drop_back(2),
false);
1459 if (matchPattern(inputs.back(), m_ConstantInt(&value))) {
1461 if (value.isZero()) {
1462 replaceOpWithNewOpAndCopyName<XorOp>(rewriter, op, op.getType(),
1463 inputs.drop_back(),
false);
1469 if (matchPattern(inputs[size - 2], m_ConstantInt(&value2))) {
1470 auto cst = rewriter.create<
hw::ConstantOp>(op.getLoc(), value ^ value2);
1471 SmallVector<Value, 4> newOperands(inputs.drop_back(2));
1472 newOperands.push_back(cst);
1473 replaceOpWithNewOpAndCopyName<XorOp>(rewriter, op, op.getType(),
1474 newOperands,
false);
1478 bool isSingleBit = value.getBitWidth() == 1;
1481 for (
size_t i = 0; i < size - 1; ++i) {
1482 Value operand = inputs[i];
1493 if (isSingleBit && operand.hasOneUse()) {
1494 assert(value == 1 &&
"single bit constant has to be one if not zero");
1495 if (
auto icmp = operand.getDefiningOp<ICmpOp>())
1511 replaceOpWithNewOpAndCopyName<ParityOp>(rewriter, op, source);
1518 OpFoldResult SubOp::fold(FoldAdaptor adaptor) {
1523 if (getRhs() == getLhs())
1525 APInt::getZero(getLhs().getType().getIntOrFloatBitWidth()),
1528 if (adaptor.getRhs()) {
1530 if (adaptor.getLhs()) {
1533 APInt::getAllOnes(getLhs().getType().getIntOrFloatBitWidth()),
1536 hw::PEO::Mul, cast<TypedAttr>(adaptor.getRhs()), negOne);
1538 cast<TypedAttr>(adaptor.getLhs()), rhsNeg);
1542 if (
auto rhsC = dyn_cast<IntegerAttr>(adaptor.getRhs())) {
1543 if (rhsC.getValue().isZero())
1557 if (matchPattern(op.getRhs(), m_ConstantInt(&value))) {
1558 auto negCst = rewriter.create<
hw::ConstantOp>(op.getLoc(), -value);
1559 replaceOpWithNewOpAndCopyName<AddOp>(rewriter, op, op.getLhs(), negCst,
1571 OpFoldResult AddOp::fold(FoldAdaptor adaptor) {
1575 auto size = getInputs().size();
1579 return getInputs()[0];
1589 auto inputs = op.getInputs();
1590 auto size = inputs.size();
1591 assert(size > 1 &&
"expected 2 or more operands");
1593 APInt value, value2;
1596 if (matchPattern(inputs.back(), m_ConstantInt(&value)) && value.isZero()) {
1597 replaceOpWithNewOpAndCopyName<AddOp>(rewriter, op, op.getType(),
1598 inputs.drop_back(),
false);
1603 if (matchPattern(inputs[size - 1], m_ConstantInt(&value)) &&
1604 matchPattern(inputs[size - 2], m_ConstantInt(&value2))) {
1605 auto cst = rewriter.create<
hw::ConstantOp>(op.getLoc(), value + value2);
1606 SmallVector<Value, 4> newOperands(inputs.drop_back(2));
1607 newOperands.push_back(cst);
1608 replaceOpWithNewOpAndCopyName<AddOp>(rewriter, op, op.getType(),
1609 newOperands,
false);
1614 if (inputs[size - 1] == inputs[size - 2]) {
1615 SmallVector<Value, 4> newOperands(inputs.drop_back(2));
1617 auto one = rewriter.create<
hw::ConstantOp>(op.getLoc(), op.getType(), 1);
1621 newOperands.push_back(shiftLeftOp);
1622 replaceOpWithNewOpAndCopyName<AddOp>(rewriter, op, op.getType(),
1623 newOperands,
false);
1627 auto shlOp = inputs[size - 1].getDefiningOp<
comb::ShlOp>();
1629 if (shlOp && shlOp.getLhs() == inputs[size - 2] &&
1630 matchPattern(shlOp.getRhs(), m_ConstantInt(&value))) {
1632 APInt one(value.getBitWidth(), 1,
false);
1634 rewriter.create<
hw::ConstantOp>(op.getLoc(), (one << value) + one);
1636 std::array<Value, 2> factors = {shlOp.getLhs(), rhs};
1637 auto mulOp = rewriter.create<
comb::MulOp>(op.getLoc(), factors,
false);
1639 SmallVector<Value, 4> newOperands(inputs.drop_back(2));
1640 newOperands.push_back(mulOp);
1641 replaceOpWithNewOpAndCopyName<AddOp>(rewriter, op, op.getType(),
1642 newOperands,
false);
1646 auto mulOp = inputs[size - 1].getDefiningOp<
comb::MulOp>();
1648 if (mulOp && mulOp.getInputs().size() == 2 &&
1649 mulOp.getInputs()[0] == inputs[size - 2] &&
1650 matchPattern(mulOp.getInputs()[1], m_ConstantInt(&value))) {
1652 APInt one(value.getBitWidth(), 1,
false);
1653 auto rhs = rewriter.create<
hw::ConstantOp>(op.getLoc(), value + one);
1654 std::array<Value, 2> factors = {mulOp.getInputs()[0], rhs};
1657 SmallVector<Value, 4> newOperands(inputs.drop_back(2));
1658 newOperands.push_back(newMulOp);
1659 replaceOpWithNewOpAndCopyName<AddOp>(rewriter, op, op.getType(),
1660 newOperands,
false);
1673 auto addOp = inputs[0].getDefiningOp<
comb::AddOp>();
1674 if (addOp && addOp.getInputs().size() == 2 &&
1675 matchPattern(addOp.getInputs()[1], m_ConstantInt(&value2)) &&
1676 inputs.size() == 2 && matchPattern(inputs[1], m_ConstantInt(&value))) {
1678 auto rhs = rewriter.create<
hw::ConstantOp>(op.getLoc(), value + value2);
1679 replaceOpWithNewOpAndCopyName<AddOp>(
1680 rewriter, op, op.getType(), ArrayRef<Value>{addOp.getInputs()[0], rhs},
1681 op.getTwoState() && addOp.getTwoState());
1688 OpFoldResult MulOp::fold(FoldAdaptor adaptor) {
1692 auto size = getInputs().size();
1693 auto inputs = adaptor.getInputs();
1697 return getInputs()[0];
1699 auto width = cast<IntegerType>(getType()).getWidth();
1700 APInt value(
width, 1,
false);
1703 for (
auto operand : inputs) {
1706 value *= cast<IntegerAttr>(operand).getValue();
1719 auto inputs = op.getInputs();
1720 auto size = inputs.size();
1721 assert(size > 1 &&
"expected 2 or more operands");
1723 APInt value, value2;
1726 if (size == 2 && matchPattern(inputs.back(), m_ConstantInt(&value)) &&
1727 value.isPowerOf2()) {
1728 auto shift = rewriter.create<
hw::ConstantOp>(op.getLoc(), op.getType(),
1729 value.exactLogBase2());
1733 replaceOpWithNewOpAndCopyName<MulOp>(rewriter, op, op.getType(),
1734 ArrayRef<Value>(shlOp),
false);
1739 if (matchPattern(inputs.back(), m_ConstantInt(&value)) && value.isOne()) {
1740 replaceOpWithNewOpAndCopyName<MulOp>(rewriter, op, op.getType(),
1741 inputs.drop_back());
1746 if (matchPattern(inputs[size - 1], m_ConstantInt(&value)) &&
1747 matchPattern(inputs[size - 2], m_ConstantInt(&value2))) {
1748 auto cst = rewriter.create<
hw::ConstantOp>(op.getLoc(), value * value2);
1749 SmallVector<Value, 4> newOperands(inputs.drop_back(2));
1750 newOperands.push_back(cst);
1751 replaceOpWithNewOpAndCopyName<MulOp>(rewriter, op, op.getType(),
1767 template <
class Op,
bool isSigned>
1768 static OpFoldResult
foldDiv(Op op, ArrayRef<Attribute> constants) {
1769 if (
auto rhsValue = dyn_cast_or_null<IntegerAttr>(constants[1])) {
1771 if (rhsValue.getValue() == 1)
1775 if (rhsValue.getValue().isZero())
1782 OpFoldResult DivUOp::fold(FoldAdaptor adaptor) {
1786 return foldDiv<
DivUOp,
false>(*
this, adaptor.getOperands());
1789 OpFoldResult DivSOp::fold(FoldAdaptor adaptor) {
1796 template <
class Op,
bool isSigned>
1797 static OpFoldResult
foldMod(Op op, ArrayRef<Attribute> constants) {
1798 if (
auto rhsValue = dyn_cast_or_null<IntegerAttr>(constants[1])) {
1800 if (rhsValue.getValue() == 1)
1801 return getIntAttr(APInt::getZero(op.getType().getIntOrFloatBitWidth()),
1805 if (rhsValue.getValue().isZero())
1809 if (
auto lhsValue = dyn_cast_or_null<IntegerAttr>(constants[0])) {
1811 if (lhsValue.getValue().isZero())
1812 return getIntAttr(APInt::getZero(op.getType().getIntOrFloatBitWidth()),
1819 OpFoldResult ModUOp::fold(FoldAdaptor adaptor) {
1823 return foldMod<
ModUOp,
false>(*
this, adaptor.getOperands());
1826 OpFoldResult ModSOp::fold(FoldAdaptor adaptor) {
1837 OpFoldResult ConcatOp::fold(FoldAdaptor adaptor) {
1841 if (getNumOperands() == 1)
1842 return getOperand(0);
1845 for (
auto attr : adaptor.getInputs())
1846 if (!attr || !isa<IntegerAttr>(attr))
1850 unsigned resultWidth = getType().getIntOrFloatBitWidth();
1851 APInt result(resultWidth, 0);
1853 unsigned nextInsertion = resultWidth;
1855 for (
auto attr : adaptor.getInputs()) {
1856 auto chunk = cast<IntegerAttr>(attr).getValue();
1857 nextInsertion -= chunk.getBitWidth();
1858 result.insertBits(chunk, nextInsertion);
1868 auto inputs = op.getInputs();
1869 auto size = inputs.size();
1870 assert(size > 1 &&
"expected 2 or more operands");
1875 auto flattenConcat = [&](
size_t firstOpIndex,
size_t lastOpIndex,
1876 ValueRange replacements) -> LogicalResult {
1877 SmallVector<Value, 4> newOperands;
1878 newOperands.append(inputs.begin(), inputs.begin() + firstOpIndex);
1879 newOperands.append(replacements.begin(), replacements.end());
1880 newOperands.append(inputs.begin() + lastOpIndex + 1, inputs.end());
1881 if (newOperands.size() == 1)
1884 replaceOpWithNewOpAndCopyName<ConcatOp>(rewriter, op, op.getType(),
1889 Value commonOperand = inputs[0];
1890 for (
size_t i = 0; i != size; ++i) {
1892 if (inputs[i] != commonOperand)
1893 commonOperand = Value();
1897 if (
auto subConcat = inputs[i].getDefiningOp<ConcatOp>())
1898 return flattenConcat(i, i, subConcat->getOperands());
1903 if (
auto cst = inputs[i].getDefiningOp<hw::ConstantOp>()) {
1904 if (
auto prevCst = inputs[i - 1].getDefiningOp<hw::ConstantOp>()) {
1905 unsigned prevWidth = prevCst.getValue().getBitWidth();
1906 unsigned thisWidth = cst.getValue().getBitWidth();
1907 auto resultCst = cst.getValue().zext(prevWidth + thisWidth);
1908 resultCst |= prevCst.getValue().zext(prevWidth + thisWidth)
1912 return flattenConcat(i - 1, i, replacement);
1917 if (inputs[i] == inputs[i - 1]) {
1919 rewriter.createOrFold<ReplicateOp>(op.getLoc(), inputs[i], 2);
1920 return flattenConcat(i - 1, i, replacement);
1925 if (
auto repl = inputs[i].getDefiningOp<ReplicateOp>()) {
1927 if (repl.getOperand() == inputs[i - 1]) {
1928 Value replacement = rewriter.createOrFold<ReplicateOp>(
1929 op.getLoc(), repl.getOperand(), repl.getMultiple() + 1);
1930 return flattenConcat(i - 1, i, replacement);
1933 if (
auto prevRepl = inputs[i - 1].getDefiningOp<ReplicateOp>()) {
1934 if (prevRepl.getOperand() == repl.getOperand()) {
1935 Value replacement = rewriter.createOrFold<ReplicateOp>(
1936 op.getLoc(), repl.getOperand(),
1937 repl.getMultiple() + prevRepl.getMultiple());
1938 return flattenConcat(i - 1, i, replacement);
1944 if (
auto repl = inputs[i - 1].getDefiningOp<ReplicateOp>()) {
1945 if (repl.getOperand() == inputs[i]) {
1946 Value replacement = rewriter.createOrFold<ReplicateOp>(
1947 op.getLoc(), inputs[i], repl.getMultiple() + 1);
1948 return flattenConcat(i - 1, i, replacement);
1954 if (
auto extract = inputs[i].getDefiningOp<ExtractOp>()) {
1955 if (
auto prevExtract = inputs[i - 1].getDefiningOp<ExtractOp>()) {
1956 if (extract.getInput() == prevExtract.getInput()) {
1957 auto thisWidth = cast<IntegerType>(extract.getType()).getWidth();
1958 if (prevExtract.getLowBit() == extract.getLowBit() + thisWidth) {
1959 auto prevWidth = prevExtract.getType().getIntOrFloatBitWidth();
1960 auto resType = rewriter.getIntegerType(thisWidth + prevWidth);
1961 Value replacement = rewriter.create<
ExtractOp>(
1962 op.getLoc(), resType, extract.getInput(),
1963 extract.getLowBit());
1964 return flattenConcat(i - 1, i, replacement);
1977 static std::optional<ArraySlice>
get(Value value) {
1978 assert(isa<IntegerType>(value.getType()) &&
"expected integer type");
1980 return ArraySlice{arrayGet.getInput(), arrayGet.getIndex(), 1};
1983 if (
auto arraySlice =
1986 arraySlice.getInput(), arraySlice.getLowIndex(),
1987 hw::type_cast<hw::ArrayType>(arraySlice.getType())
1989 return std::nullopt;
1995 if (prevExtractOpt->index.getType() == extractOpt->index.getType() &&
1996 prevExtractOpt->input == extractOpt->input &&
1998 extractOpt->width)) {
2000 hw::type_cast<hw::ArrayType>(prevExtractOpt->input.getType())
2002 extractOpt->width + prevExtractOpt->width);
2005 op.getLoc(), resIntType,
2007 prevExtractOpt->input,
2008 extractOpt->index));
2009 return flattenConcat(i - 1, i, replacement);
2017 if (commonOperand) {
2018 replaceOpWithNewOpAndCopyName<ReplicateOp>(rewriter, op, op.getType(),
2030 OpFoldResult MuxOp::fold(FoldAdaptor adaptor) {
2035 if (getTrueValue() == getFalseValue())
2036 return getTrueValue();
2037 if (
auto tv = adaptor.getTrueValue())
2038 if (tv == adaptor.getFalseValue())
2043 if (
auto pred = dyn_cast_or_null<IntegerAttr>(adaptor.getCond())) {
2044 if (pred.getValue().isZero())
2045 return getFalseValue();
2046 return getTrueValue();
2050 if (
auto tv = dyn_cast_or_null<IntegerAttr>(adaptor.getTrueValue()))
2051 if (
auto fv = dyn_cast_or_null<IntegerAttr>(adaptor.getFalseValue()))
2052 if (tv.getValue().isOne() && fv.getValue().isZero() &&
2069 if (
auto cmp = cond.getDefiningOp<ICmpOp>()) {
2071 auto requiredPredicate =
2072 (isInverted ? ICmpPredicate::eq : ICmpPredicate::ne);
2073 if (cmp.getLhs() == indexValue && cmp.getPredicate() == requiredPredicate) {
2083 if (
auto orOp = cond.getDefiningOp<
OrOp>()) {
2086 for (
auto operand : orOp.getOperands())
2093 if (
auto andOp = cond.getDefiningOp<
AndOp>()) {
2096 for (
auto operand : andOp.getOperands())
2114 PatternRewriter &rewriter) {
2117 auto rootCmp = rootMux.getCond().getDefiningOp<ICmpOp>();
2120 Value indexValue = rootCmp.getLhs();
2123 auto getCaseValue = [&](
MuxOp mux) -> Value {
2124 return mux.getOperand(1 +
unsigned(!isFalseSide));
2129 auto getTreeValue = [&](
MuxOp mux) -> Value {
2130 return mux.getOperand(1 +
unsigned(isFalseSide));
2135 SmallVector<Location> locationsFound;
2136 SmallVector<std::pair<hw::ConstantOp, Value>, 4> valuesFound;
2140 auto collectConstantValues = [&](
MuxOp mux) ->
bool {
2142 mux.getCond(), indexValue, isFalseSide, [&](
hw::ConstantOp cst) {
2143 valuesFound.push_back({cst, getCaseValue(mux)});
2144 locationsFound.push_back(mux.getCond().getLoc());
2145 locationsFound.push_back(mux->getLoc());
2150 if (!collectConstantValues(rootMux))
2154 if (rootMux->hasOneUse()) {
2155 if (
auto userMux = dyn_cast<MuxOp>(*rootMux->user_begin())) {
2156 if (getTreeValue(userMux) == rootMux.getResult() &&
2164 auto nextTreeValue = getTreeValue(rootMux);
2166 auto nextMux = nextTreeValue.getDefiningOp<
MuxOp>();
2167 if (!nextMux || !nextMux->hasOneUse())
2169 if (!collectConstantValues(nextMux))
2171 nextTreeValue = getTreeValue(nextMux);
2177 if (valuesFound.size() < 3)
2182 auto indexWidth = cast<IntegerType>(indexValue.getType()).getWidth();
2183 if (indexWidth >= 9)
2189 uint64_t tableSize = 1ULL << indexWidth;
2190 if (valuesFound.size() < (tableSize * 5) / 8)
2195 SmallVector<Value, 8> table(tableSize, nextTreeValue);
2200 for (
auto &elt : llvm::reverse(valuesFound)) {
2201 uint64_t idx = elt.first.getValue().getZExtValue();
2202 assert(idx < table.size() &&
"constant should be same bitwidth as index");
2203 table[idx] = elt.second;
2208 std::reverse(table.begin(), table.end());
2211 auto fusedLoc = rewriter.getFusedLoc(locationsFound);
2213 replaceOpWithNewOpAndCopyName<hw::ArrayGetOp>(rewriter, rootMux, array,
2228 PatternRewriter &rewriter) {
2229 assert(fullyAssoc->getNumOperands() >= 2 &&
"cannot split up unary ops");
2230 assert(operandNo < fullyAssoc->getNumOperands() &&
"Invalid operand #");
2234 if (fullyAssoc->getNumOperands() == 2)
2235 return fullyAssoc->getOperand(operandNo ^ 1);
2238 if (fullyAssoc->hasOneUse()) {
2239 rewriter.modifyOpInPlace(fullyAssoc,
2240 [&]() { fullyAssoc->eraseOperand(operandNo); });
2241 return fullyAssoc->getResult(0);
2245 SmallVector<Value> operands;
2246 operands.append(fullyAssoc->getOperands().begin(),
2247 fullyAssoc->getOperands().begin() + operandNo);
2248 operands.append(fullyAssoc->getOperands().begin() + operandNo + 1,
2249 fullyAssoc->getOperands().end());
2251 fullyAssoc->getLoc(), fullyAssoc->getName(), operands, rewriter);
2252 Value excluded = fullyAssoc->getOperand(operandNo);
2256 ArrayRef<Value>{opWithoutExcluded, excluded}, rewriter);
2258 return opWithoutExcluded;
2268 PatternRewriter &rewriter) {
2271 Operation *subExpr =
2272 (isTrueOperand ? op.getFalseValue() : op.getTrueValue()).getDefiningOp();
2273 if (!subExpr || subExpr->getNumOperands() < 2)
2277 if (!isa<AndOp, XorOp, OrOp, MuxOp>(subExpr))
2282 Value commonValue = isTrueOperand ? op.getTrueValue() : op.getFalseValue();
2283 size_t opNo = 0, e = subExpr->getNumOperands();
2284 while (opNo != e && subExpr->getOperand(opNo) != commonValue)
2290 Value cond = op.getCond();
2296 if (
auto subMux = dyn_cast<MuxOp>(subExpr)) {
2298 Value subCond = subMux.getCond();
2301 if (subMux.getTrueValue() == commonValue)
2302 otherValue = subMux.getFalseValue();
2303 else if (subMux.getFalseValue() == commonValue) {
2304 otherValue = subMux.getTrueValue();
2314 cond = rewriter.createOrFold<
OrOp>(op.getLoc(), cond, subCond,
false);
2315 replaceOpWithNewOpAndCopyName<MuxOp>(rewriter, op, cond, commonValue,
2316 otherValue, op.getTwoState());
2322 bool isaAndOp = isa<AndOp>(subExpr);
2323 if (isTrueOperand ^ isaAndOp)
2327 rewriter.createOrFold<ReplicateOp>(op.getLoc(), op.getType(), cond);
2330 bool isaXorOp = isa<XorOp>(subExpr);
2331 bool isaOrOp = isa<OrOp>(subExpr);
2340 if (isaOrOp || isaXorOp) {
2341 auto masked = rewriter.createOrFold<
AndOp>(op.getLoc(), extendedCond,
2342 restOfAssoc,
false);
2344 replaceOpWithNewOpAndCopyName<XorOp>(rewriter, op, masked, commonValue,
2347 replaceOpWithNewOpAndCopyName<OrOp>(rewriter, op, masked, commonValue,
2353 assert(isaAndOp &&
"unexpected operation here");
2354 auto masked = rewriter.createOrFold<
OrOp>(op.getLoc(), extendedCond,
2355 restOfAssoc,
false);
2356 replaceOpWithNewOpAndCopyName<AndOp>(rewriter, op, masked, commonValue,
2367 PatternRewriter &rewriter) {
2370 if (!isa<ConcatOp>(trueOp))
2374 SmallVector<Value> trueOperands, falseOperands;
2378 size_t numTrueOperands = trueOperands.size();
2379 size_t numFalseOperands = falseOperands.size();
2381 if (!numTrueOperands || !numFalseOperands ||
2382 (trueOperands.front() != falseOperands.front() &&
2383 trueOperands.back() != falseOperands.back()))
2387 if (trueOperands.front() == falseOperands.front()) {
2388 SmallVector<Value> operands;
2390 for (i = 0; i < numTrueOperands; ++i) {
2391 Value trueOperand = trueOperands[i];
2392 if (trueOperand == falseOperands[i])
2393 operands.push_back(trueOperand);
2397 if (i == numTrueOperands) {
2404 if (llvm::all_of(operands, [&](Value v) {
return v == operands.front(); }))
2405 sharedMSB = rewriter.createOrFold<ReplicateOp>(
2406 mux->getLoc(), operands.front(), operands.size());
2408 sharedMSB = rewriter.createOrFold<
ConcatOp>(mux->getLoc(), operands);
2412 operands.append(trueOperands.begin() + i, trueOperands.end());
2413 Value trueLSB = rewriter.createOrFold<
ConcatOp>(trueOp->getLoc(), operands);
2415 operands.append(falseOperands.begin() + i, falseOperands.end());
2417 rewriter.createOrFold<
ConcatOp>(falseOp->getLoc(), operands);
2420 Value lsb = rewriter.createOrFold<
MuxOp>(
2421 mux->getLoc(), mux.getCond(), trueLSB, falseLSB, mux.getTwoState());
2422 replaceOpWithNewOpAndCopyName<ConcatOp>(rewriter, mux, sharedMSB, lsb);
2427 if (trueOperands.back() == falseOperands.back()) {
2428 SmallVector<Value> operands;
2431 Value trueOperand = trueOperands[numTrueOperands - i - 1];
2432 if (trueOperand == falseOperands[numFalseOperands - i - 1])
2433 operands.push_back(trueOperand);
2437 std::reverse(operands.begin(), operands.end());
2438 Value sharedLSB = rewriter.createOrFold<
ConcatOp>(mux->getLoc(), operands);
2442 operands.append(trueOperands.begin(), trueOperands.end() - i);
2443 Value trueMSB = rewriter.createOrFold<
ConcatOp>(trueOp->getLoc(), operands);
2445 operands.append(falseOperands.begin(), falseOperands.end() - i);
2447 rewriter.createOrFold<
ConcatOp>(falseOp->getLoc(), operands);
2449 Value msb = rewriter.createOrFold<
MuxOp>(
2450 mux->getLoc(), mux.getCond(), trueMSB, falseMSB, mux.getTwoState());
2451 replaceOpWithNewOpAndCopyName<ConcatOp>(rewriter, mux, msb, sharedLSB);
2463 if (!trueVec || !falseVec)
2465 if (!trueVec.isUniform() || !falseVec.isUniform())
2469 op.getLoc(), op.getCond(), trueVec.getUniformElement(),
2470 falseVec.getUniformElement(), op.getTwoState());
2472 SmallVector<Value> values(trueVec.getInputs().size(), mux);
2479 using OpRewritePattern::OpRewritePattern;
2481 LogicalResult matchAndRewrite(
MuxOp op,
2482 PatternRewriter &rewriter)
const override;
2485 LogicalResult MuxRewriter::matchAndRewrite(
MuxOp op,
2486 PatternRewriter &rewriter)
const {
2495 if (matchPattern(op.getTrueValue(), m_ConstantInt(&value))) {
2496 if (value.getBitWidth() == 1) {
2498 if (value.isZero()) {
2500 replaceOpWithNewOpAndCopyName<AndOp>(rewriter, op, notCond,
2501 op.getFalseValue(),
false);
2506 replaceOpWithNewOpAndCopyName<OrOp>(rewriter, op, op.getCond(),
2507 op.getFalseValue(),
false);
2513 if (matchPattern(op.getFalseValue(), m_ConstantInt(&value2))) {
2518 APInt xorValue = value ^ value2;
2519 if (xorValue.isPowerOf2()) {
2520 unsigned leadingZeros = xorValue.countLeadingZeros();
2521 unsigned trailingZeros = value.getBitWidth() - leadingZeros - 1;
2522 SmallVector<Value, 3> operands;
2530 if (leadingZeros > 0)
2531 operands.push_back(rewriter.createOrFold<
ExtractOp>(
2532 op.getLoc(), op.getTrueValue(), trailingZeros + 1, leadingZeros));
2536 auto v1 = rewriter.createOrFold<
ExtractOp>(
2537 op.getLoc(), op.getTrueValue(), trailingZeros, 1);
2538 auto v2 = rewriter.createOrFold<
ExtractOp>(
2539 op.getLoc(), op.getFalseValue(), trailingZeros, 1);
2540 operands.push_back(rewriter.createOrFold<
MuxOp>(
2541 op.getLoc(), op.getCond(), v1, v2,
false));
2543 if (trailingZeros > 0)
2544 operands.push_back(rewriter.createOrFold<
ExtractOp>(
2545 op.getLoc(), op.getTrueValue(), 0, trailingZeros));
2547 replaceOpWithNewOpAndCopyName<ConcatOp>(rewriter, op, op.getType(),
2554 if (value.isAllOnes() && value2.isZero()) {
2555 replaceOpWithNewOpAndCopyName<ReplicateOp>(rewriter, op, op.getType(),
2562 if (matchPattern(op.getFalseValue(), m_ConstantInt(&value)) &&
2563 value.getBitWidth() == 1) {
2565 if (value.isZero()) {
2566 replaceOpWithNewOpAndCopyName<AndOp>(rewriter, op, op.getCond(),
2567 op.getTrueValue(),
false);
2574 auto notCond = rewriter.createOrFold<
XorOp>(op.getLoc(), op.getCond(),
2575 op.getFalseValue(),
false);
2576 replaceOpWithNewOpAndCopyName<OrOp>(rewriter, op, notCond,
2577 op.getTrueValue(),
false);
2583 Operation *condOp = op.getCond().getDefiningOp();
2584 if (condOp && matchPattern(condOp,
m_Complement(m_Any(&subExpr))) &&
2586 replaceOpWithNewOpAndCopyName<MuxOp>(rewriter, op, op.getType(), subExpr,
2587 op.getFalseValue(), op.getTrueValue(),
2595 if (condOp && condOp->hasOneUse()) {
2596 SmallVector<Value> invertedOperands;
2600 auto getInvertedOperands = [&]() ->
bool {
2601 for (Value operand : condOp->getOperands()) {
2602 if (matchPattern(operand,
m_Complement(m_Any(&subExpr))))
2603 invertedOperands.push_back(subExpr);
2610 if (isa<AndOp>(condOp) && getInvertedOperands()) {
2612 rewriter.createOrFold<
OrOp>(op.getLoc(), invertedOperands,
false);
2613 replaceOpWithNewOpAndCopyName<MuxOp>(rewriter, op, newOr,
2615 op.getTrueValue(), op.getTwoState());
2618 if (isa<OrOp>(condOp) && getInvertedOperands()) {
2620 rewriter.createOrFold<
AndOp>(op.getLoc(), invertedOperands,
false);
2621 replaceOpWithNewOpAndCopyName<MuxOp>(rewriter, op, newAnd,
2623 op.getTrueValue(), op.getTwoState());
2629 dyn_cast_or_null<MuxOp>(op.getFalseValue().getDefiningOp())) {
2631 if (op.getCond() == falseMux.getCond()) {
2632 replaceOpWithNewOpAndCopyName<MuxOp>(
2633 rewriter, op, op.getCond(), op.getTrueValue(),
2634 falseMux.getFalseValue(), op.getTwoStateAttr());
2644 dyn_cast_or_null<MuxOp>(op.getTrueValue().getDefiningOp())) {
2646 if (op.getCond() == trueMux.getCond()) {
2647 replaceOpWithNewOpAndCopyName<MuxOp>(
2648 rewriter, op, op.getCond(), trueMux.getTrueValue(),
2649 op.getFalseValue(), op.getTwoStateAttr());
2659 if (
auto trueMux = dyn_cast_or_null<MuxOp>(op.getTrueValue().getDefiningOp()),
2660 falseMux = dyn_cast_or_null<MuxOp>(op.getFalseValue().getDefiningOp());
2661 trueMux && falseMux && trueMux.getCond() == falseMux.getCond() &&
2662 trueMux.getTrueValue() == falseMux.getTrueValue()) {
2663 auto subMux = rewriter.create<
MuxOp>(
2664 rewriter.getFusedLoc({trueMux.getLoc(), falseMux.getLoc()}),
2665 op.getCond(), trueMux.getFalseValue(), falseMux.getFalseValue());
2666 replaceOpWithNewOpAndCopyName<MuxOp>(rewriter, op, trueMux.getCond(),
2667 trueMux.getTrueValue(), subMux,
2668 op.getTwoStateAttr());
2673 if (
auto trueMux = dyn_cast_or_null<MuxOp>(op.getTrueValue().getDefiningOp()),
2674 falseMux = dyn_cast_or_null<MuxOp>(op.getFalseValue().getDefiningOp());
2675 trueMux && falseMux && trueMux.getCond() == falseMux.getCond() &&
2676 trueMux.getFalseValue() == falseMux.getFalseValue()) {
2677 auto subMux = rewriter.create<
MuxOp>(
2678 rewriter.getFusedLoc({trueMux.getLoc(), falseMux.getLoc()}),
2679 op.getCond(), trueMux.getTrueValue(), falseMux.getTrueValue());
2680 replaceOpWithNewOpAndCopyName<MuxOp>(rewriter, op, trueMux.getCond(),
2681 subMux, trueMux.getFalseValue(),
2682 op.getTwoStateAttr());
2687 if (
auto trueMux = dyn_cast_or_null<MuxOp>(op.getTrueValue().getDefiningOp()),
2688 falseMux = dyn_cast_or_null<MuxOp>(op.getFalseValue().getDefiningOp());
2689 trueMux && falseMux &&
2690 trueMux.getTrueValue() == falseMux.getTrueValue() &&
2691 trueMux.getFalseValue() == falseMux.getFalseValue()) {
2692 auto subMux = rewriter.create<
MuxOp>(
2693 rewriter.getFusedLoc(
2694 {op.getLoc(), trueMux.getLoc(), falseMux.getLoc()}),
2695 op.getCond(), trueMux.getCond(), falseMux.getCond());
2696 replaceOpWithNewOpAndCopyName<MuxOp>(
2697 rewriter, op, subMux, trueMux.getTrueValue(), trueMux.getFalseValue(),
2698 op.getTwoStateAttr());
2710 if (Operation *trueOp = op.getTrueValue().getDefiningOp())
2711 if (Operation *falseOp = op.getFalseValue().getDefiningOp())
2712 if (trueOp->getName() == falseOp->getName())
2729 if (op.getInputs().empty() || op.isUniform())
2731 auto inputs = op.getInputs();
2732 if (inputs.size() <= 1)
2737 auto first = inputs[0].getDefiningOp<
comb::MuxOp>();
2742 for (
size_t i = 1, n = inputs.size(); i < n; ++i) {
2743 auto input = inputs[i].getDefiningOp<
comb::MuxOp>();
2744 if (!input || first.getCond() != input.getCond())
2749 SmallVector<Value> trues{first.getTrueValue()};
2750 SmallVector<Value> falses{first.getFalseValue()};
2751 SmallVector<Location> locs{first->getLoc()};
2752 bool isTwoState =
true;
2753 for (
size_t i = 1, n = inputs.size(); i < n; ++i) {
2754 auto input = inputs[i].getDefiningOp<
comb::MuxOp>();
2755 trues.push_back(input.getTrueValue());
2756 falses.push_back(input.getFalseValue());
2757 locs.push_back(input->getLoc());
2758 if (!input.getTwoState())
2767 auto arrayTy = op.getType();
2770 rewriter.replaceOpWithNewOp<
comb::MuxOp>(op, arrayTy, first.getCond(),
2771 trueValues, falseValues, isTwoState);
2776 using OpRewritePattern::OpRewritePattern;
2779 PatternRewriter &rewriter)
const override {
2783 if (foldArrayOfMuxes(op, rewriter))
2791 void MuxOp::getCanonicalizationPatterns(RewritePatternSet &results,
2792 MLIRContext *context) {
2793 results.insert<MuxRewriter, ArrayRewriter>(context);
2804 switch (predicate) {
2805 case ICmpPredicate::eq:
2807 case ICmpPredicate::ne:
2809 case ICmpPredicate::slt:
2810 return lhs.slt(rhs);
2811 case ICmpPredicate::sle:
2812 return lhs.sle(rhs);
2813 case ICmpPredicate::sgt:
2814 return lhs.sgt(rhs);
2815 case ICmpPredicate::sge:
2816 return lhs.sge(rhs);
2817 case ICmpPredicate::ult:
2818 return lhs.ult(rhs);
2819 case ICmpPredicate::ule:
2820 return lhs.ule(rhs);
2821 case ICmpPredicate::ugt:
2822 return lhs.ugt(rhs);
2823 case ICmpPredicate::uge:
2824 return lhs.uge(rhs);
2825 case ICmpPredicate::ceq:
2827 case ICmpPredicate::cne:
2829 case ICmpPredicate::weq:
2831 case ICmpPredicate::wne:
2834 llvm_unreachable(
"unknown comparison predicate");
2840 switch (predicate) {
2841 case ICmpPredicate::eq:
2842 case ICmpPredicate::sle:
2843 case ICmpPredicate::sge:
2844 case ICmpPredicate::ule:
2845 case ICmpPredicate::uge:
2846 case ICmpPredicate::ceq:
2847 case ICmpPredicate::weq:
2849 case ICmpPredicate::ne:
2850 case ICmpPredicate::slt:
2851 case ICmpPredicate::sgt:
2852 case ICmpPredicate::ult:
2853 case ICmpPredicate::ugt:
2854 case ICmpPredicate::cne:
2855 case ICmpPredicate::wne:
2858 llvm_unreachable(
"unknown comparison predicate");
2861 OpFoldResult ICmpOp::fold(FoldAdaptor adaptor) {
2867 if (getLhs() == getRhs()) {
2873 if (
auto lhs = dyn_cast_or_null<IntegerAttr>(adaptor.getLhs())) {
2874 if (
auto rhs = dyn_cast_or_null<IntegerAttr>(adaptor.getRhs())) {
2885 template <
typename Range>
2887 size_t commonPrefixLength = 0;
2888 auto ia = a.begin();
2889 auto ib = b.begin();
2891 for (; ia != a.end() && ib != b.end(); ia++, ib++, commonPrefixLength++) {
2897 return commonPrefixLength;
2901 size_t totalWidth = 0;
2902 for (
auto operand : operands) {
2905 ssize_t
width = operand.getType().getIntOrFloatBitWidth();
2907 totalWidth +=
width;
2917 PatternRewriter &rewriter) {
2921 SmallVector<Value> lhsOperands, rhsOperands;
2924 ArrayRef<Value> lhsOperandsRef = lhsOperands, rhsOperandsRef = rhsOperands;
2926 auto formCatOrReplicate = [&](Location loc,
2927 ArrayRef<Value> operands) -> Value {
2928 assert(!operands.empty());
2929 Value sameElement = operands[0];
2930 for (
size_t i = 1, e = operands.size(); i != e && sameElement; ++i)
2931 if (sameElement != operands[i])
2932 sameElement = Value();
2934 return rewriter.createOrFold<ReplicateOp>(loc, sameElement,
2936 return rewriter.createOrFold<
ConcatOp>(loc, operands);
2939 auto replaceWith = [&](ICmpPredicate predicate, Value lhs,
2940 Value rhs) -> LogicalResult {
2941 replaceOpWithNewOpAndCopyName<ICmpOp>(rewriter, op, predicate, lhs, rhs,
2946 size_t commonPrefixLength =
2948 if (commonPrefixLength == lhsOperands.size()) {
2951 replaceOpWithNewOpAndCopyName<hw::ConstantOp>(rewriter, op,
2957 llvm::reverse(lhsOperandsRef), llvm::reverse(rhsOperandsRef));
2959 size_t commonPrefixTotalWidth =
2960 getTotalWidth(lhsOperandsRef.take_front(commonPrefixLength));
2961 size_t commonSuffixTotalWidth =
2962 getTotalWidth(lhsOperandsRef.take_back(commonSuffixLength));
2963 auto lhsOnly = lhsOperandsRef.drop_front(commonPrefixLength)
2964 .drop_back(commonSuffixLength);
2965 auto rhsOnly = rhsOperandsRef.drop_front(commonPrefixLength)
2966 .drop_back(commonSuffixLength);
2968 auto replaceWithoutReplicatingSignBit = [&]() {
2969 auto newLhs = formCatOrReplicate(lhs->getLoc(), lhsOnly);
2970 auto newRhs = formCatOrReplicate(rhs->getLoc(), rhsOnly);
2971 return replaceWith(op.getPredicate(), newLhs, newRhs);
2974 auto replaceWithReplicatingSignBit = [&]() {
2975 auto firstNonEmptyValue = lhsOperands[0];
2976 auto firstNonEmptyElemWidth =
2977 firstNonEmptyValue.getType().getIntOrFloatBitWidth();
2978 Value signBit = rewriter.createOrFold<
ExtractOp>(
2979 op.getLoc(), firstNonEmptyValue, firstNonEmptyElemWidth - 1, 1);
2981 auto newLhs = rewriter.
create<
ConcatOp>(lhs->getLoc(), signBit, lhsOnly);
2982 auto newRhs = rewriter.create<
ConcatOp>(rhs->getLoc(), signBit, rhsOnly);
2983 return replaceWith(op.getPredicate(), newLhs, newRhs);
2986 if (ICmpOp::isPredicateSigned(op.getPredicate())) {
2988 if (commonPrefixTotalWidth == 0 && commonSuffixTotalWidth > 0)
2989 return replaceWithoutReplicatingSignBit();
2995 if (commonPrefixTotalWidth > 1 || commonSuffixTotalWidth > 0)
2996 return replaceWithReplicatingSignBit();
2998 }
else if (commonPrefixTotalWidth > 0 || commonSuffixTotalWidth > 0) {
3000 return replaceWithoutReplicatingSignBit();
3014 ICmpOp cmpOp,
const KnownBits &bitAnalysis,
const APInt &rhsCst,
3015 PatternRewriter &rewriter) {
3019 APInt bitsKnown = bitAnalysis.Zero | bitAnalysis.One;
3020 if ((bitsKnown & rhsCst) != bitAnalysis.One) {
3023 bool result = cmpOp.getPredicate() == ICmpPredicate::ne;
3024 replaceOpWithNewOpAndCopyName<hw::ConstantOp>(rewriter, cmpOp,
3032 SmallVector<Value> newConcatOperands;
3033 auto newConstant = APInt::getZeroWidth();
3038 unsigned knownMSB = bitsKnown.countLeadingOnes();
3040 Value operand = cmpOp.getLhs();
3045 while (knownMSB != bitsKnown.getBitWidth()) {
3048 bitsKnown = bitsKnown.trunc(bitsKnown.getBitWidth() - knownMSB);
3051 unsigned unknownBits = bitsKnown.countLeadingZeros();
3052 unsigned lowBit = bitsKnown.getBitWidth() - unknownBits;
3053 auto spanOperand = rewriter.createOrFold<
ExtractOp>(
3054 operand.getLoc(), operand, lowBit,
3056 auto spanConstant = rhsCst.lshr(lowBit).trunc(unknownBits);
3059 newConcatOperands.push_back(spanOperand);
3062 if (newConstant.getBitWidth() != 0)
3063 newConstant = newConstant.concat(spanConstant);
3065 newConstant = spanConstant;
3068 unsigned newWidth = bitsKnown.getBitWidth() - unknownBits;
3069 bitsKnown = bitsKnown.trunc(newWidth);
3070 knownMSB = bitsKnown.countLeadingOnes();
3076 if (newConcatOperands.empty()) {
3077 bool result = cmpOp.getPredicate() == ICmpPredicate::eq;
3078 replaceOpWithNewOpAndCopyName<hw::ConstantOp>(rewriter, cmpOp,
3084 Value concatResult =
3085 rewriter.createOrFold<
ConcatOp>(operand.getLoc(), newConcatOperands);
3089 cmpOp.getOperand(1).getLoc(), newConstant);
3091 replaceOpWithNewOpAndCopyName<ICmpOp>(rewriter, cmpOp, cmpOp.getPredicate(),
3092 concatResult, newConstantOp,
3093 cmpOp.getTwoState());
3099 PatternRewriter &rewriter) {
3100 auto ip = rewriter.saveInsertionPoint();
3101 rewriter.setInsertionPoint(xorOp);
3103 auto xorRHS = xorOp.getOperands().back().getDefiningOp<
hw::ConstantOp>();
3105 xorRHS.getValue() ^ rhs);
3107 switch (xorOp.getNumOperands()) {
3111 APInt::getZero(rhs.getBitWidth()));
3115 newLHS = xorOp.getOperand(0);
3119 SmallVector<Value> newOperands(xorOp.getOperands());
3120 newOperands.pop_back();
3121 newLHS = rewriter.create<
XorOp>(xorOp.getLoc(), newOperands,
false);
3125 bool xorMultipleUses = !xorOp->hasOneUse();
3129 if (xorMultipleUses)
3130 replaceOpWithNewOpAndCopyName<XorOp>(rewriter, xorOp, newLHS, xorRHS,
3134 rewriter.restoreInsertionPoint(ip);
3135 replaceOpWithNewOpAndCopyName<ICmpOp>(rewriter, cmpOp, cmpOp.getPredicate(),
3136 newLHS, newRHS,
false);
3146 if (matchPattern(op.getLhs(), m_ConstantInt(&lhs))) {
3147 assert(!matchPattern(op.getRhs(), m_ConstantInt(&rhs)) &&
3148 "Should be folded");
3149 replaceOpWithNewOpAndCopyName<ICmpOp>(
3150 rewriter, op, ICmpOp::getFlippedPredicate(op.getPredicate()),
3151 op.getRhs(), op.getLhs(), op.getTwoState());
3156 if (matchPattern(op.getRhs(), m_ConstantInt(&rhs))) {
3158 return rewriter.create<
hw::ConstantOp>(op.getLoc(), std::move(constant));
3161 auto replaceWith = [&](ICmpPredicate predicate, Value lhs,
3162 Value rhs) -> LogicalResult {
3163 replaceOpWithNewOpAndCopyName<ICmpOp>(rewriter, op, predicate, lhs, rhs,
3168 auto replaceWithConstantI1 = [&](
bool constant) -> LogicalResult {
3169 replaceOpWithNewOpAndCopyName<hw::ConstantOp>(rewriter, op,
3170 APInt(1, constant));
3174 switch (op.getPredicate()) {
3175 case ICmpPredicate::slt:
3177 if (rhs.isMaxSignedValue())
3178 return replaceWith(ICmpPredicate::ne, op.getLhs(), op.getRhs());
3180 if (rhs.isMinSignedValue())
3181 return replaceWithConstantI1(0);
3183 if ((rhs - 1).isMinSignedValue())
3184 return replaceWith(ICmpPredicate::eq, op.getLhs(),
3187 case ICmpPredicate::sgt:
3189 if (rhs.isMinSignedValue())
3190 return replaceWith(ICmpPredicate::ne, op.getLhs(), op.getRhs());
3192 if (rhs.isMaxSignedValue())
3193 return replaceWithConstantI1(0);
3195 if ((rhs + 1).isMaxSignedValue())
3196 return replaceWith(ICmpPredicate::eq, op.getLhs(),
3199 case ICmpPredicate::ult:
3201 if (rhs.isAllOnes())
3202 return replaceWith(ICmpPredicate::ne, op.getLhs(), op.getRhs());
3205 return replaceWithConstantI1(0);
3207 if ((rhs - 1).isZero())
3208 return replaceWith(ICmpPredicate::eq, op.getLhs(),
3212 if (rhs.countLeadingOnes() + rhs.countTrailingZeros() ==
3213 rhs.getBitWidth()) {
3214 auto numOnes = rhs.countLeadingOnes();
3215 auto smaller = rewriter.create<
ExtractOp>(
3216 op.getLoc(), op.getLhs(), rhs.getBitWidth() - numOnes, numOnes);
3217 return replaceWith(ICmpPredicate::ne, smaller,
3222 case ICmpPredicate::ugt:
3225 return replaceWith(ICmpPredicate::ne, op.getLhs(), op.getRhs());
3227 if (rhs.isAllOnes())
3228 return replaceWithConstantI1(0);
3230 if ((rhs + 1).isAllOnes())
3231 return replaceWith(ICmpPredicate::eq, op.getLhs(),
3235 if ((rhs + 1).isPowerOf2()) {
3236 auto numOnes = rhs.countTrailingOnes();
3237 auto newWidth = rhs.getBitWidth() - numOnes;
3238 auto smaller = rewriter.create<
ExtractOp>(op.getLoc(), op.getLhs(),
3240 return replaceWith(ICmpPredicate::ne, smaller,
3245 case ICmpPredicate::sle:
3247 if (rhs.isMaxSignedValue())
3248 return replaceWithConstantI1(1);
3250 return replaceWith(ICmpPredicate::slt, op.getLhs(),
getConstant(rhs + 1));
3251 case ICmpPredicate::sge:
3253 if (rhs.isMinSignedValue())
3254 return replaceWithConstantI1(1);
3256 return replaceWith(ICmpPredicate::sgt, op.getLhs(),
getConstant(rhs - 1));
3257 case ICmpPredicate::ule:
3259 if (rhs.isAllOnes())
3260 return replaceWithConstantI1(1);
3262 return replaceWith(ICmpPredicate::ult, op.getLhs(),
getConstant(rhs + 1));
3263 case ICmpPredicate::uge:
3266 return replaceWithConstantI1(1);
3268 return replaceWith(ICmpPredicate::ugt, op.getLhs(),
getConstant(rhs - 1));
3269 case ICmpPredicate::eq:
3270 if (rhs.getBitWidth() == 1) {
3273 replaceOpWithNewOpAndCopyName<XorOp>(rewriter, op, op.getLhs(),
3278 if (rhs.isAllOnes()) {
3285 case ICmpPredicate::ne:
3286 if (rhs.getBitWidth() == 1) {
3292 if (rhs.isAllOnes()) {
3294 replaceOpWithNewOpAndCopyName<XorOp>(rewriter, op, op.getLhs(),
3301 case ICmpPredicate::ceq:
3302 case ICmpPredicate::cne:
3303 case ICmpPredicate::weq:
3304 case ICmpPredicate::wne:
3310 if (op.getPredicate() == ICmpPredicate::eq ||
3311 op.getPredicate() == ICmpPredicate::ne) {
3316 if (!knownBits.isUnknown())
3323 if (
auto xorOp = op.getLhs().getDefiningOp<
XorOp>())
3330 if (
auto replicateOp = op.getLhs().getDefiningOp<ReplicateOp>())
3331 if (rhs.isAllOnes() || rhs.isZero()) {
3332 auto width = replicateOp.getInput().getType().getIntOrFloatBitWidth();
3334 op.getLoc(), rhs.isAllOnes() ? APInt::getAllOnes(
width)
3335 : APInt::getZero(
width));
3336 replaceOpWithNewOpAndCopyName<ICmpOp>(rewriter, op, op.getPredicate(),
3337 replicateOp.getInput(), cst,
3347 if (Operation *opLHS = op.getLhs().getDefiningOp())
3348 if (Operation *opRHS = op.getRhs().getDefiningOp())
3349 if (isa<ConcatOp, ReplicateOp>(opLHS) &&
3350 isa<ConcatOp, ReplicateOp>(opRHS)) {
assert(baseType &&"element must be base type")
static SmallVector< T > concat(const SmallVectorImpl< T > &a, const SmallVectorImpl< T > &b)
Returns a new vector containing the concatenation of vectors a and b.
static KnownBits computeKnownBits(Value v, unsigned depth)
Given an integer SSA value, check to see if we know anything about the result of the computation.
static bool foldMuxOfUniformArrays(MuxOp op, PatternRewriter &rewriter)
static Attribute constFoldAssociativeOp(ArrayRef< Attribute > operands, hw::PEO paramOpcode)
static Attribute constFoldBinaryOp(ArrayRef< Attribute > operands, hw::PEO paramOpcode)
Performs constant folding calculate with element-wise behavior on the two attributes in operands and ...
static bool applyCmpPredicateToEqualOperands(ICmpPredicate predicate)
static bool canonicalizeLogicalCstWithConcat(Operation *logicalOp, size_t concatIdx, const APInt &cst, PatternRewriter &rewriter)
When we find a logical operation (and, or, xor) with a constant e.g.
static bool hasOperandsOutsideOfBlock(Operation *op)
In comb, we assume no knowledge of the semantics of cross-block dataflow.
static bool narrowOperationWidth(OpTy op, bool narrowTrailingBits, PatternRewriter &rewriter)
static OpFoldResult foldDiv(Op op, ArrayRef< Attribute > constants)
static Value getCommonOperand(Op op)
Returns a single common operand that all inputs of the operation op can be traced back to,...
static bool canCombineOppositeBinCmpIntoConstant(OperandRange operands)
static void getConcatOperands(Value v, SmallVectorImpl< Value > &result)
Flatten concat and mux operands into a vector.
static OpTy replaceOpWithNewOpAndCopyName(PatternRewriter &rewriter, Operation *op, Args &&...args)
A wrapper of PatternRewriter::replaceOpWithNewOp to propagate "sv.namehint" attribute.
static Value extractOperandFromFullyAssociative(Operation *fullyAssoc, size_t operandNo, PatternRewriter &rewriter)
Given a fully associative variadic operation like (a+b+c+d), break the expression into two parts,...
static bool getMuxChainCondConstant(Value cond, Value indexValue, bool isInverted, std::function< void(hw::ConstantOp)> constantFn)
Check to see if the condition to the specified mux is an equality comparison indexValue and one or mo...
static TypedAttr getIntAttr(const APInt &value, MLIRContext *context)
static bool shouldBeFlattened(Operation *op)
Return true if the op will be flattened afterwards.
static std::pair< size_t, size_t > getLowestBitAndHighestBitRequired(Operation *op, bool narrowTrailingBits, size_t originalOpWidth)
static void canonicalizeXorIcmpTrue(XorOp op, unsigned icmpOperand, PatternRewriter &rewriter)
static bool extractFromReplicate(ExtractOp op, ReplicateOp replicate, PatternRewriter &rewriter)
static bool canonicalizeOrOfConcatsWithCstOperands(OrOp op, size_t concatIdx1, size_t concatIdx2, PatternRewriter &rewriter)
Simplify concat ops in an or op when a constant operand is present in either concat.
static void combineEqualityICmpWithXorOfConstant(ICmpOp cmpOp, XorOp xorOp, const APInt &rhs, PatternRewriter &rewriter)
static size_t getTotalWidth(ArrayRef< Value > operands)
static bool foldCommonMuxOperation(MuxOp mux, Operation *trueOp, Operation *falseOp, PatternRewriter &rewriter)
This function is invoked when we find a mux whose true/false operations have the same opcode.
static bool tryFlatteningOperands(Operation *op, PatternRewriter &rewriter)
Flattens a single input in op if hasOneUse is true and it can be defined as an Op.
static bool canonicalizeIdempotentInputs(Op op, PatternRewriter &rewriter)
Canonicalize an idempotent operation op so that only one input of any kind occurs.
static bool applyCmpPredicate(ICmpPredicate predicate, const APInt &lhs, const APInt &rhs)
static void combineEqualityICmpWithKnownBitsAndConstant(ICmpOp cmpOp, const KnownBits &bitAnalysis, const APInt &rhsCst, PatternRewriter &rewriter)
Given an equality comparison with a constant value and some operand that has known bits,...
static bool foldMuxChain(MuxOp rootMux, bool isFalseSide, PatternRewriter &rewriter)
Given a mux, check to see if the "on true" value (or "on false" value if isFalseSide=true) is a mux t...
static ComplementMatcher< SubType > m_Complement(const SubType &subExpr)
static bool hasSVAttributes(Operation *op)
static LogicalResult extractConcatToConcatExtract(ExtractOp op, ConcatOp innerCat, PatternRewriter &rewriter)
static OpFoldResult foldMod(Op op, ArrayRef< Attribute > constants)
static size_t computeCommonPrefixLength(const Range &a, const Range &b)
static bool foldCommonMuxValue(MuxOp op, bool isTrueOperand, PatternRewriter &rewriter)
Fold things like mux(cond, x|y|z|a, a) -> (x|y|z)&replicate(cond)|a and mux(cond, a,...
static LogicalResult matchAndRewriteCompareConcat(ICmpOp op, Operation *lhs, Operation *rhs, PatternRewriter &rewriter)
Reduce the strength icmp(concat(...), concat(...)) by doing a element-wise comparison on common prefi...
static Value createGenericOp(Location loc, OperationName name, ArrayRef< Value > operands, OpBuilder &builder)
Create a new instance of a generic operation that only has value operands, and has a single result va...
static void replaceOpAndCopyName(PatternRewriter &rewriter, Operation *op, Value newValue)
A wrapper of PatternRewriter::replaceOp to propagate "sv.namehint" attribute.
static std::optional< APSInt > getConstant(Attribute operand)
Determine the value of a constant operand for the sake of constant folding.
def create(data_type, value)
static LogicalResult canonicalize(Op op, PatternRewriter &rewriter)
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
Value createOrFoldNot(Location loc, Value value, OpBuilder &builder, bool twoState=false)
Create a `Not` gate on a value.
uint64_t getWidth(Type t)
std::optional< int64_t > getBitWidth(FIRRTLBaseType type, bool ignoreFlip=false)
bool isOffset(Value base, Value index, uint64_t offset)
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.