12 #include "mlir/IR/Matchers.h"
13 #include "mlir/IR/PatternMatch.h"
14 #include "llvm/ADT/SetVector.h"
15 #include "llvm/ADT/SmallBitVector.h"
16 #include "llvm/ADT/TypeSwitch.h"
17 #include "llvm/Support/KnownBits.h"
20 using namespace circt;
22 using namespace matchers;
33 Block *thisBlock = op->getBlock();
34 return llvm::any_of(op->getOperands(), [&](Value operand) {
35 return operand.getParentBlock() != thisBlock;
45 ArrayRef<Value> operands, OpBuilder &builder) {
46 OperationState state(loc, name);
47 state.addOperands(operands);
48 state.addTypes(operands[0].getType());
49 return builder.create(state)->getResult(0);
52 static TypedAttr
getIntAttr(
const APInt &value, MLIRContext *context) {
60 for (
auto op :
concat.getOperands())
62 }
else if (
auto repl = v.getDefiningOp<ReplicateOp>()) {
63 for (
size_t i = 0, e = repl.getMultiple(); i != e; ++i)
75 if (
auto *newOp = newValue.getDefiningOp()) {
76 auto name = op->getAttrOfType<StringAttr>(
"sv.namehint");
77 if (name && !newOp->hasAttr(
"sv.namehint"))
78 rewriter.modifyOpInPlace(newOp,
79 [&] { newOp->setAttr(
"sv.namehint", name); });
81 rewriter.replaceOp(op, newValue);
87 template <
typename OpTy,
typename... Args>
89 Operation *op, Args &&...args) {
90 auto name = op->getAttrOfType<StringAttr>(
"sv.namehint");
92 rewriter.replaceOpWithNewOp<OpTy>(op, std::forward<Args>(args)...);
93 if (name && !newOp->hasAttr(
"sv.namehint"))
94 rewriter.modifyOpInPlace(newOp,
95 [&] { newOp->setAttr(
"sv.namehint", name); });
104 return op->hasAttr(
"sv.attributes");
108 template <
typename SubType>
109 struct ComplementMatcher {
111 ComplementMatcher(SubType lhs) : lhs(std::move(lhs)) {}
112 bool match(Operation *op) {
113 auto xorOp = dyn_cast<XorOp>(op);
114 return xorOp && xorOp.isBinaryNot() && lhs.match(op->getOperand(0));
119 template <
typename SubType>
120 static inline ComplementMatcher<SubType>
m_Complement(
const SubType &subExpr) {
121 return ComplementMatcher<SubType>(subExpr);
127 assert((isa<AndOp, OrOp, XorOp, AddOp, MulOp>(op) &&
128 "must be commutative operations"));
129 if (op->hasOneUse()) {
130 auto *user = *op->getUsers().begin();
131 return user->getName() == op->getName() &&
132 op->getAttrOfType<UnitAttr>(
"twoState") ==
133 user->getAttrOfType<UnitAttr>(
"twoState");
148 auto inputs = op->getOperands();
150 SmallVector<Value, 4> newOperands;
151 SmallVector<Location, 4> newLocations{op->getLoc()};
152 newOperands.reserve(inputs.size());
154 decltype(inputs.begin()) current, end;
157 SmallVector<Element> worklist;
158 worklist.push_back({inputs.begin(), inputs.end()});
159 bool binFlag = op->hasAttrOfType<UnitAttr>(
"twoState");
160 bool changed =
false;
161 while (!worklist.empty()) {
162 auto &element = worklist.back();
165 if (element.current == element.end) {
170 Value value = *element.current++;
171 auto *flattenOp = value.getDefiningOp();
172 if (!flattenOp || flattenOp->getName() != op->getName() ||
173 flattenOp == op || binFlag != op->hasAttrOfType<UnitAttr>(
"twoState")) {
174 newOperands.push_back(value);
179 if (!value.hasOneUse()) {
187 if (flattenOp->getNumOperands() != 2 || !isa<AndOp, OrOp, XorOp>(op) ||
190 newOperands.push_back(value);
198 auto flattenOpInputs = flattenOp->getOperands();
199 worklist.push_back({flattenOpInputs.begin(), flattenOpInputs.end()});
200 newLocations.push_back(flattenOp->getLoc());
207 op->getName(), newOperands, rewriter);
209 result.getDefiningOp()->setAttr(
"twoState", rewriter.getUnitAttr());
217 static std::pair<size_t, size_t>
219 size_t originalOpWidth) {
220 auto users = op->getUsers();
222 "getLowestBitAndHighestBitRequired cannot operate on "
223 "a empty list of uses.");
227 size_t lowestBitRequired = narrowTrailingBits ? originalOpWidth - 1 : 0;
228 size_t highestBitRequired = 0;
230 for (
auto *user : users) {
231 if (
auto extractOp = dyn_cast<ExtractOp>(user)) {
232 size_t lowBit = extractOp.getLowBit();
234 cast<IntegerType>(extractOp.getType()).getWidth() + lowBit - 1;
235 highestBitRequired = std::max(highestBitRequired, highBit);
236 lowestBitRequired = std::min(lowestBitRequired, lowBit);
240 highestBitRequired = originalOpWidth - 1;
241 lowestBitRequired = 0;
245 return {lowestBitRequired, highestBitRequired};
248 template <
class OpTy>
250 PatternRewriter &rewriter) {
251 IntegerType opType = dyn_cast<IntegerType>(op.getResult().getType());
257 if (range.second + 1 == opType.getWidth() && range.first == 0)
260 SmallVector<Value> args;
261 auto newType = rewriter.getIntegerType(range.second - range.first + 1);
262 for (
auto inop : op.getOperands()) {
264 if (inop.getType() != op.getType())
265 args.push_back(inop);
267 args.push_back(rewriter.createOrFold<
ExtractOp>(inop.getLoc(), newType,
270 Value newop = rewriter.createOrFold<OpTy>(op.getLoc(), newType, args);
271 newop.getDefiningOp()->setDialectAttrs(op->getDialectAttrs());
273 newop = rewriter.createOrFold<
ConcatOp>(
276 APInt::getZero(range.first)));
277 if (range.second + 1 < opType.getWidth())
278 newop = rewriter.createOrFold<
ConcatOp>(
281 op.getLoc(), APInt::getZero(opType.getWidth() - range.second - 1)),
283 rewriter.replaceOp(op, newop);
291 OpFoldResult ReplicateOp::fold(FoldAdaptor adaptor) {
296 if (cast<IntegerType>(getType()).
getWidth() ==
297 getInput().getType().getIntOrFloatBitWidth())
301 if (
auto input = dyn_cast_or_null<IntegerAttr>(adaptor.getInput())) {
302 if (input.getValue().getBitWidth() == 1) {
303 if (input.getValue().isZero())
305 APInt::getZero(cast<IntegerType>(getType()).
getWidth()),
308 APInt::getAllOnes(cast<IntegerType>(getType()).
getWidth()),
312 APInt result = APInt::getZeroWidth();
313 for (
auto i = getMultiple(); i != 0; --i)
314 result = result.concat(input.getValue());
321 OpFoldResult ParityOp::fold(FoldAdaptor adaptor) {
326 if (
auto input = dyn_cast_or_null<IntegerAttr>(adaptor.getInput()))
327 return getIntAttr(APInt(1, input.getValue().popcount() & 1), getContext());
339 hw::PEO paramOpcode) {
340 assert(operands.size() == 2 &&
"binary op takes two operands");
341 if (!operands[0] || !operands[1])
347 cast<TypedAttr>(operands[1]));
350 OpFoldResult ShlOp::fold(FoldAdaptor adaptor) {
354 if (
auto rhs = dyn_cast_or_null<IntegerAttr>(adaptor.getRhs())) {
355 unsigned shift = rhs.getValue().getZExtValue();
356 unsigned width = getType().getIntOrFloatBitWidth();
358 return getOperand(0);
372 if (!matchPattern(op.getRhs(), m_ConstantInt(&value)))
375 unsigned width = cast<IntegerType>(op.getLhs().getType()).getWidth();
376 unsigned shift = value.getZExtValue();
379 if (
width <= shift || shift == 0)
383 rewriter.create<
hw::ConstantOp>(op.getLoc(), APInt::getZero(shift));
389 replaceOpWithNewOpAndCopyName<ConcatOp>(rewriter, op, extract, zeros);
393 OpFoldResult ShrUOp::fold(FoldAdaptor adaptor) {
397 if (
auto rhs = dyn_cast_or_null<IntegerAttr>(adaptor.getRhs())) {
398 unsigned shift = rhs.getValue().getZExtValue();
400 return getOperand(0);
402 unsigned width = getType().getIntOrFloatBitWidth();
415 if (!matchPattern(op.getRhs(), m_ConstantInt(&value)))
418 unsigned width = cast<IntegerType>(op.getLhs().getType()).getWidth();
419 unsigned shift = value.getZExtValue();
422 if (
width <= shift || shift == 0)
426 rewriter.create<
hw::ConstantOp>(op.getLoc(), APInt::getZero(shift));
429 auto extract = rewriter.
create<
ExtractOp>(op.getLoc(), op.getLhs(), shift,
432 replaceOpWithNewOpAndCopyName<ConcatOp>(rewriter, op, zeros, extract);
436 OpFoldResult ShrSOp::fold(FoldAdaptor adaptor) {
440 if (
auto rhs = dyn_cast_or_null<IntegerAttr>(adaptor.getRhs())) {
441 if (rhs.getValue().getZExtValue() == 0)
442 return getOperand(0);
453 if (!matchPattern(op.getRhs(), m_ConstantInt(&value)))
456 unsigned width = cast<IntegerType>(op.getLhs().getType()).getWidth();
457 unsigned shift = value.getZExtValue();
460 rewriter.createOrFold<
ExtractOp>(op.getLoc(), op.getLhs(),
width - 1, 1);
461 auto sext = rewriter.createOrFold<ReplicateOp>(op.getLoc(), topbit, shift);
463 if (
width <= shift) {
468 auto extract = rewriter.
create<
ExtractOp>(op.getLoc(), op.getLhs(), shift,
471 replaceOpWithNewOpAndCopyName<ConcatOp>(rewriter, op, sext, extract);
479 OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
484 if (getInput().getType() == getType())
488 if (
auto input = dyn_cast_or_null<IntegerAttr>(adaptor.getInput())) {
489 unsigned dstWidth = cast<IntegerType>(getType()).getWidth();
490 return getIntAttr(input.getValue().lshr(getLowBit()).trunc(dstWidth),
501 PatternRewriter &rewriter) {
502 auto reversedConcatArgs = llvm::reverse(innerCat.getInputs());
503 size_t beginOfFirstRelevantElement = 0;
504 auto it = reversedConcatArgs.begin();
505 size_t lowBit = op.getLowBit();
508 for (; it != reversedConcatArgs.end(); it++) {
509 assert(beginOfFirstRelevantElement <= lowBit &&
510 "incorrectly moved past an element that lowBit has coverage over");
513 size_t operandWidth = operand.getType().getIntOrFloatBitWidth();
514 if (lowBit < beginOfFirstRelevantElement + operandWidth) {
538 beginOfFirstRelevantElement += operandWidth;
540 assert(it != reversedConcatArgs.end() &&
541 "incorrectly failed to find an element which contains coverage of "
544 SmallVector<Value> reverseConcatArgs;
545 size_t widthRemaining = cast<IntegerType>(op.getType()).getWidth();
546 size_t extractLo = lowBit - beginOfFirstRelevantElement;
551 for (; widthRemaining != 0 && it != reversedConcatArgs.end(); it++) {
552 auto concatArg = *it;
553 size_t operandWidth = concatArg.getType().getIntOrFloatBitWidth();
554 size_t widthToConsume = std::min(widthRemaining, operandWidth - extractLo);
556 if (widthToConsume == operandWidth && extractLo == 0) {
557 reverseConcatArgs.push_back(concatArg);
560 reverseConcatArgs.push_back(
561 rewriter.create<
ExtractOp>(op.getLoc(), resultType, *it, extractLo));
564 widthRemaining -= widthToConsume;
570 if (reverseConcatArgs.size() == 1) {
573 replaceOpWithNewOpAndCopyName<ConcatOp>(
574 rewriter, op, SmallVector<Value>(llvm::reverse(reverseConcatArgs)));
581 PatternRewriter &rewriter) {
582 auto extractResultWidth = cast<IntegerType>(op.getType()).getWidth();
583 auto replicateEltWidth =
584 replicate.getOperand().getType().getIntOrFloatBitWidth();
588 if (op.getLowBit() % replicateEltWidth == 0 &&
589 extractResultWidth % replicateEltWidth == 0) {
590 replaceOpWithNewOpAndCopyName<ReplicateOp>(rewriter, op, op.getType(),
591 replicate.getOperand());
597 if (op.getLowBit() % replicateEltWidth + extractResultWidth <=
599 replaceOpWithNewOpAndCopyName<ExtractOp>(
600 rewriter, op, op.getType(), replicate.getOperand(),
601 op.getLowBit() % replicateEltWidth);
614 auto *inputOp = op.getInput().getDefiningOp();
621 .extractBits(cast<IntegerType>(op.getType()).getWidth(),
623 if (knownBits.isConstant()) {
624 replaceOpWithNewOpAndCopyName<hw::ConstantOp>(rewriter, op,
625 knownBits.getConstant());
631 if (
auto innerExtract = dyn_cast_or_null<ExtractOp>(inputOp)) {
632 replaceOpWithNewOpAndCopyName<ExtractOp>(
633 rewriter, op, op.getType(), innerExtract.getInput(),
634 innerExtract.getLowBit() + op.getLowBit());
639 if (
auto innerCat = dyn_cast_or_null<ConcatOp>(inputOp))
643 if (
auto replicate = dyn_cast_or_null<ReplicateOp>(inputOp))
649 if (inputOp && inputOp->getNumOperands() == 2 &&
650 isa<AndOp, OrOp, XorOp>(inputOp)) {
651 if (
auto cstRHS = inputOp->getOperand(1).getDefiningOp<
hw::ConstantOp>()) {
652 auto extractedCst = cstRHS.getValue().extractBits(
653 cast<IntegerType>(op.getType()).getWidth(), op.getLowBit());
654 if (isa<OrOp, XorOp>(inputOp) && extractedCst.isZero()) {
655 replaceOpWithNewOpAndCopyName<ExtractOp>(
656 rewriter, op, op.getType(), inputOp->getOperand(0), op.getLowBit());
664 if (isa<AndOp>(inputOp)) {
667 unsigned lz = extractedCst.countLeadingZeros();
668 unsigned tz = extractedCst.countTrailingZeros();
669 unsigned pop = extractedCst.popcount();
670 if (extractedCst.getBitWidth() - lz - tz == pop) {
671 auto resultTy = rewriter.getIntegerType(pop);
672 SmallVector<Value> resultElts;
675 op.getLoc(), APInt::getZero(lz)));
676 resultElts.push_back(rewriter.createOrFold<
ExtractOp>(
677 op.getLoc(), resultTy, inputOp->getOperand(0),
678 op.getLowBit() + tz));
681 op.getLoc(), APInt::getZero(tz)));
682 replaceOpWithNewOpAndCopyName<ConcatOp>(rewriter, op, resultElts);
691 if (cast<IntegerType>(op.getType()).getWidth() == 1 && inputOp)
692 if (
auto shlOp = dyn_cast<ShlOp>(inputOp))
693 if (
auto lhsCst = shlOp.getOperand(0).getDefiningOp<
hw::ConstantOp>())
694 if (lhsCst.getValue().isOne()) {
697 APInt(lhsCst.getValue().getBitWidth(), op.getLowBit()));
698 replaceOpWithNewOpAndCopyName<ICmpOp>(rewriter, op, ICmpPredicate::eq,
699 shlOp->getOperand(1), newCst,
714 hw::PEO paramOpcode) {
715 assert(operands.size() > 1 &&
"caller should handle one-operand case");
718 if (!operands[1] || !operands[0])
722 if (llvm::all_of(operands.drop_front(2),
723 [&](Attribute in) { return !!in; })) {
724 SmallVector<mlir::TypedAttr> typedOperands;
725 typedOperands.reserve(operands.size());
726 for (
auto operand : operands) {
727 if (
auto typedOperand = dyn_cast<mlir::TypedAttr>(operand))
728 typedOperands.push_back(typedOperand);
732 if (typedOperands.size() == operands.size())
749 size_t concatIdx,
const APInt &cst,
750 PatternRewriter &rewriter) {
751 auto concatOp = logicalOp->getOperand(concatIdx).getDefiningOp<
ConcatOp>();
752 assert((isa<AndOp, OrOp, XorOp>(logicalOp) && concatOp));
757 llvm::any_of(concatOp->getOperands(), [&](Value operand) ->
bool {
758 auto *operandOp = operand.getDefiningOp();
763 if (isa<hw::ConstantOp>(operandOp))
767 return operandOp->getName() == logicalOp->getName() &&
768 operandOp->hasOneUse() && operandOp->getNumOperands() != 0 &&
769 operandOp->getOperands().back().getDefiningOp<hw::ConstantOp>();
777 auto createLogicalOp = [&](ArrayRef<Value> operands) -> Value {
778 return createGenericOp(logicalOp->getLoc(), logicalOp->getName(), operands,
785 SmallVector<Value> newConcatOperands;
786 newConcatOperands.reserve(concatOp->getNumOperands());
789 size_t nextOperandBit = concatOp.getType().getIntOrFloatBitWidth();
790 for (Value operand : concatOp->getOperands()) {
791 size_t operandWidth = operand.getType().getIntOrFloatBitWidth();
792 nextOperandBit -= operandWidth;
795 logicalOp->getLoc(), cst.lshr(nextOperandBit).trunc(operandWidth));
797 newConcatOperands.push_back(createLogicalOp({operand, eltCst}));
806 if (logicalOp->getNumOperands() > 2) {
807 auto origOperands = logicalOp->getOperands();
808 SmallVector<Value> operands;
810 operands.append(origOperands.begin(), origOperands.begin() + concatIdx);
812 operands.append(origOperands.begin() + concatIdx + 1,
813 origOperands.begin() + (origOperands.size() - 1));
815 operands.push_back(newResult);
816 newResult = createLogicalOp(operands);
826 llvm::SmallDenseSet<std::tuple<ICmpPredicate, Value, Value>> seenPredicates;
828 for (
auto op : operands) {
829 if (
auto icmpOp = op.getDefiningOp<ICmpOp>();
830 icmpOp && icmpOp.getTwoState()) {
831 auto predicate = icmpOp.getPredicate();
832 auto lhs = icmpOp.getLhs();
833 auto rhs = icmpOp.getRhs();
834 if (seenPredicates.contains(
835 {ICmpOp::getNegatedPredicate(predicate), lhs, rhs}))
838 seenPredicates.insert({predicate, lhs, rhs});
844 OpFoldResult AndOp::fold(FoldAdaptor adaptor) {
848 APInt value = APInt::getAllOnes(cast<IntegerType>(getType()).
getWidth());
850 auto inputs = adaptor.getInputs();
853 for (
auto operand : inputs) {
856 value &= cast<IntegerAttr>(operand).getValue();
862 if (inputs.size() == 2 && inputs[1] &&
863 cast<IntegerAttr>(inputs[1]).getValue().isAllOnes())
864 return getInputs()[0];
867 if (llvm::all_of(getInputs(),
868 [&](
auto in) {
return in == this->getInputs()[0]; }))
869 return getInputs()[0];
872 for (Value arg : getInputs()) {
875 for (Value arg2 : getInputs())
878 APInt::getZero(cast<IntegerType>(getType()).
getWidth()),
899 template <
typename Op>
901 if (!op.getType().isInteger(1))
904 auto inputs = op.getInputs();
905 size_t size = inputs.size();
907 auto sourceOp = inputs[0].template getDefiningOp<ExtractOp>();
910 Value source = sourceOp.getOperand();
913 if (size != source.getType().getIntOrFloatBitWidth())
917 llvm::BitVector bits(size);
918 bits.set(sourceOp.getLowBit());
920 for (
size_t i = 1; i != size; ++i) {
921 auto extractOp = inputs[i].template getDefiningOp<ExtractOp>();
922 if (!extractOp || extractOp.getOperand() != source)
924 bits.set(extractOp.getLowBit());
927 return bits.all() ? source : Value();
934 template <
typename Op>
936 auto inputs = op.getInputs();
938 llvm::SmallSetVector<Value, 8> uniqueInputs(inputs.begin(), inputs.end());
939 llvm::SmallDenseSet<Value, 8> checked;
942 llvm::SmallVector<Value, 8> worklist;
943 for (
auto input : inputs) {
945 worklist.push_back(input);
948 while (!worklist.empty()) {
949 auto element = worklist.pop_back_val();
951 if (
auto idempotentOp = element.getDefiningOp<Op>()) {
952 for (
auto input : idempotentOp.getInputs()) {
953 uniqueInputs.remove(input);
955 if (checked.insert(input).second)
956 worklist.push_back(input);
961 if (uniqueInputs.size() < inputs.size()) {
962 replaceOpWithNewOpAndCopyName<Op>(rewriter, op, op.getType(),
963 uniqueInputs.getArrayRef());
974 auto inputs = op.getInputs();
975 auto size = inputs.size();
976 assert(size > 1 &&
"expected 2 or more operands, `fold` should handle this");
990 if (matchPattern(inputs.back(), m_ConstantInt(&value))) {
992 if (value.isAllOnes()) {
993 replaceOpWithNewOpAndCopyName<AndOp>(rewriter, op, op.getType(),
994 inputs.drop_back(),
false);
1002 if (matchPattern(inputs[size - 2], m_ConstantInt(&value2))) {
1003 auto cst = rewriter.create<
hw::ConstantOp>(op.getLoc(), value & value2);
1004 SmallVector<Value, 4> newOperands(inputs.drop_back(2));
1005 newOperands.push_back(cst);
1006 replaceOpWithNewOpAndCopyName<AndOp>(rewriter, op, op.getType(),
1007 newOperands,
false);
1012 if (size == 2 && value.isPowerOf2()) {
1017 if (
auto replicate = inputs[0].getDefiningOp<ReplicateOp>()) {
1018 auto replicateOperand = replicate.getOperand();
1019 if (replicateOperand.getType().isInteger(1)) {
1020 unsigned resultWidth = op.getType().getIntOrFloatBitWidth();
1021 auto trailingZeros = value.countTrailingZeros();
1024 SmallVector<Value, 3> concatOperands;
1025 if (trailingZeros != resultWidth - 1) {
1027 op.getLoc(), APInt::getZero(resultWidth - trailingZeros - 1));
1028 concatOperands.push_back(highZeros);
1030 concatOperands.push_back(replicateOperand);
1031 if (trailingZeros != 0) {
1033 op.getLoc(), APInt::getZero(trailingZeros));
1034 concatOperands.push_back(lowZeros);
1036 replaceOpWithNewOpAndCopyName<ConcatOp>(rewriter, op, op.getType(),
1044 if (
auto extractOp = inputs[0].getDefiningOp<ExtractOp>()) {
1047 (value.countLeadingZeros() || value.countTrailingZeros())) {
1048 unsigned lz = value.countLeadingZeros();
1049 unsigned tz = value.countTrailingZeros();
1052 auto smallTy = rewriter.getIntegerType(value.getBitWidth() - lz - tz);
1053 Value smallElt = rewriter.createOrFold<
ExtractOp>(
1054 extractOp.getLoc(), smallTy, extractOp->getOperand(0),
1055 extractOp.getLowBit() + tz);
1057 APInt smallMask = value.extractBits(smallTy.getWidth(), tz);
1058 if (!smallMask.isAllOnes()) {
1059 auto loc = inputs.back().getLoc();
1060 smallElt = rewriter.createOrFold<
AndOp>(
1067 SmallVector<Value> resultElts;
1069 resultElts.push_back(
1070 rewriter.create<
hw::ConstantOp>(op.getLoc(), APInt::getZero(lz)));
1071 resultElts.push_back(smallElt);
1073 resultElts.push_back(
1074 rewriter.create<
hw::ConstantOp>(op.getLoc(), APInt::getZero(tz)));
1075 replaceOpWithNewOpAndCopyName<ConcatOp>(rewriter, op, resultElts);
1083 for (
size_t i = 0; i < size - 1; ++i) {
1084 if (
auto concat = inputs[i].getDefiningOp<ConcatOp>())
1097 rewriter.create<
hw::ConstantOp>(op.getLoc(), APInt::getAllOnes(size));
1098 replaceOpWithNewOpAndCopyName<ICmpOp>(rewriter, op, ICmpPredicate::eq,
1099 source, cmpAgainst);
1107 OpFoldResult OrOp::fold(FoldAdaptor adaptor) {
1111 auto value = APInt::getZero(cast<IntegerType>(getType()).
getWidth());
1112 auto inputs = adaptor.getInputs();
1114 for (
auto operand : inputs) {
1117 value |= cast<IntegerAttr>(operand).getValue();
1118 if (value.isAllOnes())
1123 if (inputs.size() == 2 && inputs[1] &&
1124 cast<IntegerAttr>(inputs[1]).getValue().isZero())
1125 return getInputs()[0];
1128 if (llvm::all_of(getInputs(),
1129 [&](
auto in) {
return in == this->getInputs()[0]; }))
1130 return getInputs()[0];
1133 for (Value arg : getInputs()) {
1135 if (matchPattern(arg,
m_Complement(m_Any(&subExpr)))) {
1136 for (Value arg2 : getInputs())
1137 if (arg2 == subExpr)
1139 APInt::getAllOnes(cast<IntegerType>(getType()).
getWidth()),
1149 APInt::getAllOnes(cast<IntegerType>(getType()).
getWidth()),
1173 PatternRewriter &rewriter) {
1174 assert(concatIdx1 < concatIdx2 &&
"concatIdx1 must be < concatIdx2");
1176 auto inputs = op.getInputs();
1177 auto concat1 = inputs[concatIdx1].getDefiningOp<
ConcatOp>();
1178 auto concat2 = inputs[concatIdx2].getDefiningOp<
ConcatOp>();
1180 assert(concat1 && concat2 &&
"expected indexes to point to ConcatOps");
1183 bool hasConstantOp1 =
1184 llvm::any_of(concat1->getOperands(), [&](Value operand) ->
bool {
1185 return operand.getDefiningOp<hw::ConstantOp>();
1187 if (!hasConstantOp1) {
1188 bool hasConstantOp2 =
1189 llvm::any_of(concat2->getOperands(), [&](Value operand) ->
bool {
1190 return operand.getDefiningOp<hw::ConstantOp>();
1192 if (!hasConstantOp2)
1196 SmallVector<Value> newConcatOperands;
1201 auto operands1 = concat1->getOperands();
1202 auto operands2 = concat2->getOperands();
1204 unsigned consumedWidth1 = 0;
1205 unsigned consumedWidth2 = 0;
1206 for (
auto it1 = operands1.begin(), end1 = operands1.end(),
1207 it2 = operands2.begin(), end2 = operands2.end();
1208 it1 != end1 && it2 != end2;) {
1209 auto operand1 = *it1;
1210 auto operand2 = *it2;
1212 unsigned remainingWidth1 =
1214 unsigned remainingWidth2 =
1216 unsigned widthToConsume = std::min(remainingWidth1, remainingWidth2);
1217 auto narrowedType = rewriter.getIntegerType(widthToConsume);
1219 auto extract1 = rewriter.createOrFold<
ExtractOp>(
1220 op.getLoc(), narrowedType, operand1, remainingWidth1 - widthToConsume);
1221 auto extract2 = rewriter.createOrFold<
ExtractOp>(
1222 op.getLoc(), narrowedType, operand2, remainingWidth2 - widthToConsume);
1224 newConcatOperands.push_back(
1225 rewriter.createOrFold<
OrOp>(op.getLoc(), extract1, extract2,
false));
1227 consumedWidth1 += widthToConsume;
1228 consumedWidth2 += widthToConsume;
1230 if (widthToConsume == remainingWidth1) {
1234 if (widthToConsume == remainingWidth2) {
1240 ConcatOp newOp = rewriter.create<
ConcatOp>(op.getLoc(), newConcatOperands);
1244 SmallVector<Value> newOrOperands;
1245 newOrOperands.append(inputs.begin(), inputs.begin() + concatIdx1);
1246 newOrOperands.append(inputs.begin() + concatIdx1 + 1,
1247 inputs.begin() + concatIdx2);
1248 newOrOperands.append(inputs.begin() + concatIdx2 + 1,
1249 inputs.begin() + inputs.size());
1250 newOrOperands.push_back(newOp);
1252 replaceOpWithNewOpAndCopyName<OrOp>(rewriter, op, op.getType(),
1261 auto inputs = op.getInputs();
1262 auto size = inputs.size();
1263 assert(size > 1 &&
"expected 2 or more operands");
1277 if (matchPattern(inputs.back(), m_ConstantInt(&value))) {
1279 if (value.isZero()) {
1280 replaceOpWithNewOpAndCopyName<OrOp>(rewriter, op, op.getType(),
1281 inputs.drop_back());
1287 if (matchPattern(inputs[size - 2], m_ConstantInt(&value2))) {
1288 auto cst = rewriter.create<
hw::ConstantOp>(op.getLoc(), value | value2);
1289 SmallVector<Value, 4> newOperands(inputs.drop_back(2));
1290 newOperands.push_back(cst);
1291 replaceOpWithNewOpAndCopyName<OrOp>(rewriter, op, op.getType(),
1299 for (
size_t i = 0; i < size - 1; ++i) {
1300 if (
auto concat = inputs[i].getDefiningOp<ConcatOp>())
1308 for (
size_t i = 0; i < size - 1; ++i) {
1309 if (
auto concat = inputs[i].getDefiningOp<ConcatOp>())
1310 for (
size_t j = i + 1; j < size; ++j)
1311 if (
auto concat = inputs[j].getDefiningOp<ConcatOp>())
1323 rewriter.create<
hw::ConstantOp>(op.getLoc(), APInt::getZero(size));
1324 replaceOpWithNewOpAndCopyName<ICmpOp>(rewriter, op, ICmpPredicate::ne,
1325 source, cmpAgainst);
1331 if (
auto firstMux = op.getOperand(0).getDefiningOp<
comb::MuxOp>()) {
1333 if (op.getTwoState() && firstMux.getTwoState() &&
1334 matchPattern(firstMux.getFalseValue(), m_ConstantInt(&value)) &&
1336 SmallVector<Value> conditions{firstMux.getCond()};
1337 auto check = [&](Value v) {
1341 conditions.push_back(mux.getCond());
1342 return mux.getTwoState() &&
1343 firstMux.getTrueValue() == mux.getTrueValue() &&
1344 firstMux.getFalseValue() == mux.getFalseValue();
1346 if (llvm::all_of(op.getOperands().drop_front(), check)) {
1347 auto cond = rewriter.create<
comb::OrOp>(op.getLoc(), conditions,
true);
1348 replaceOpWithNewOpAndCopyName<comb::MuxOp>(
1349 rewriter, op, cond, firstMux.getTrueValue(),
1350 firstMux.getFalseValue(),
true);
1360 OpFoldResult XorOp::fold(FoldAdaptor adaptor) {
1364 auto size = getInputs().size();
1365 auto inputs = adaptor.getInputs();
1369 return getInputs()[0];
1372 if (size == 2 && getInputs()[0] == getInputs()[1])
1376 if (inputs.size() == 2 && inputs[1] &&
1377 cast<IntegerAttr>(inputs[1]).getValue().isZero())
1378 return getInputs()[0];
1382 if (isBinaryNot()) {
1384 if (matchPattern(getOperand(0),
m_Complement(m_Any(&subExpr))) &&
1385 subExpr != getResult())
1395 PatternRewriter &rewriter) {
1396 auto icmp = op.getOperand(icmpOperand).getDefiningOp<ICmpOp>();
1397 auto negatedPred = ICmpOp::getNegatedPredicate(icmp.getPredicate());
1400 rewriter.create<ICmpOp>(icmp.getLoc(), negatedPred, icmp.getOperand(0),
1401 icmp.getOperand(1), icmp.getTwoState());
1404 if (op.getNumOperands() > 2) {
1405 SmallVector<Value, 4> newOperands(op.getOperands());
1406 newOperands.pop_back();
1407 newOperands.erase(newOperands.begin() + icmpOperand);
1408 newOperands.push_back(result);
1409 result = rewriter.create<
XorOp>(op.getLoc(), newOperands, op.getTwoState());
1419 auto inputs = op.getInputs();
1420 auto size = inputs.size();
1421 assert(size > 1 &&
"expected 2 or more operands");
1424 if (inputs[size - 1] == inputs[size - 2]) {
1426 "expected idempotent case for 2 elements handled already.");
1427 replaceOpWithNewOpAndCopyName<XorOp>(rewriter, op, op.getType(),
1428 inputs.drop_back(2),
false);
1434 if (matchPattern(inputs.back(), m_ConstantInt(&value))) {
1436 if (value.isZero()) {
1437 replaceOpWithNewOpAndCopyName<XorOp>(rewriter, op, op.getType(),
1438 inputs.drop_back(),
false);
1444 if (matchPattern(inputs[size - 2], m_ConstantInt(&value2))) {
1445 auto cst = rewriter.create<
hw::ConstantOp>(op.getLoc(), value ^ value2);
1446 SmallVector<Value, 4> newOperands(inputs.drop_back(2));
1447 newOperands.push_back(cst);
1448 replaceOpWithNewOpAndCopyName<XorOp>(rewriter, op, op.getType(),
1449 newOperands,
false);
1453 bool isSingleBit = value.getBitWidth() == 1;
1456 for (
size_t i = 0; i < size - 1; ++i) {
1457 Value operand = inputs[i];
1468 if (isSingleBit && operand.hasOneUse()) {
1469 assert(value == 1 &&
"single bit constant has to be one if not zero");
1470 if (
auto icmp = operand.getDefiningOp<ICmpOp>())
1486 replaceOpWithNewOpAndCopyName<ParityOp>(rewriter, op, source);
1493 OpFoldResult SubOp::fold(FoldAdaptor adaptor) {
1498 if (getRhs() == getLhs())
1500 APInt::getZero(getLhs().getType().getIntOrFloatBitWidth()),
1503 if (adaptor.getRhs()) {
1505 if (adaptor.getLhs()) {
1508 APInt::getAllOnes(getLhs().getType().getIntOrFloatBitWidth()),
1511 hw::PEO::Mul, cast<TypedAttr>(adaptor.getRhs()), negOne);
1513 cast<TypedAttr>(adaptor.getLhs()), rhsNeg);
1517 if (
auto rhsC = dyn_cast<IntegerAttr>(adaptor.getRhs())) {
1518 if (rhsC.getValue().isZero())
1532 if (matchPattern(op.getRhs(), m_ConstantInt(&value))) {
1533 auto negCst = rewriter.create<
hw::ConstantOp>(op.getLoc(), -value);
1534 replaceOpWithNewOpAndCopyName<AddOp>(rewriter, op, op.getLhs(), negCst,
1546 OpFoldResult AddOp::fold(FoldAdaptor adaptor) {
1550 auto size = getInputs().size();
1554 return getInputs()[0];
1564 auto inputs = op.getInputs();
1565 auto size = inputs.size();
1566 assert(size > 1 &&
"expected 2 or more operands");
1568 APInt value, value2;
1571 if (matchPattern(inputs.back(), m_ConstantInt(&value)) && value.isZero()) {
1572 replaceOpWithNewOpAndCopyName<AddOp>(rewriter, op, op.getType(),
1573 inputs.drop_back(),
false);
1578 if (matchPattern(inputs[size - 1], m_ConstantInt(&value)) &&
1579 matchPattern(inputs[size - 2], m_ConstantInt(&value2))) {
1580 auto cst = rewriter.create<
hw::ConstantOp>(op.getLoc(), value + value2);
1581 SmallVector<Value, 4> newOperands(inputs.drop_back(2));
1582 newOperands.push_back(cst);
1583 replaceOpWithNewOpAndCopyName<AddOp>(rewriter, op, op.getType(),
1584 newOperands,
false);
1589 if (inputs[size - 1] == inputs[size - 2]) {
1590 SmallVector<Value, 4> newOperands(inputs.drop_back(2));
1592 auto one = rewriter.create<
hw::ConstantOp>(op.getLoc(), op.getType(), 1);
1596 newOperands.push_back(shiftLeftOp);
1597 replaceOpWithNewOpAndCopyName<AddOp>(rewriter, op, op.getType(),
1598 newOperands,
false);
1602 auto shlOp = inputs[size - 1].getDefiningOp<
comb::ShlOp>();
1604 if (shlOp && shlOp.getLhs() == inputs[size - 2] &&
1605 matchPattern(shlOp.getRhs(), m_ConstantInt(&value))) {
1607 APInt one(value.getBitWidth(), 1,
false);
1609 rewriter.create<
hw::ConstantOp>(op.getLoc(), (one << value) + one);
1611 std::array<Value, 2> factors = {shlOp.getLhs(), rhs};
1612 auto mulOp = rewriter.create<
comb::MulOp>(op.getLoc(), factors,
false);
1614 SmallVector<Value, 4> newOperands(inputs.drop_back(2));
1615 newOperands.push_back(mulOp);
1616 replaceOpWithNewOpAndCopyName<AddOp>(rewriter, op, op.getType(),
1617 newOperands,
false);
1621 auto mulOp = inputs[size - 1].getDefiningOp<
comb::MulOp>();
1623 if (mulOp && mulOp.getInputs().size() == 2 &&
1624 mulOp.getInputs()[0] == inputs[size - 2] &&
1625 matchPattern(mulOp.getInputs()[1], m_ConstantInt(&value))) {
1627 APInt one(value.getBitWidth(), 1,
false);
1628 auto rhs = rewriter.create<
hw::ConstantOp>(op.getLoc(), value + one);
1629 std::array<Value, 2> factors = {mulOp.getInputs()[0], rhs};
1632 SmallVector<Value, 4> newOperands(inputs.drop_back(2));
1633 newOperands.push_back(newMulOp);
1634 replaceOpWithNewOpAndCopyName<AddOp>(rewriter, op, op.getType(),
1635 newOperands,
false);
1648 auto addOp = inputs[0].getDefiningOp<
comb::AddOp>();
1649 if (addOp && addOp.getInputs().size() == 2 &&
1650 matchPattern(addOp.getInputs()[1], m_ConstantInt(&value2)) &&
1651 inputs.size() == 2 && matchPattern(inputs[1], m_ConstantInt(&value))) {
1653 auto rhs = rewriter.create<
hw::ConstantOp>(op.getLoc(), value + value2);
1654 replaceOpWithNewOpAndCopyName<AddOp>(
1655 rewriter, op, op.getType(), ArrayRef<Value>{addOp.getInputs()[0], rhs},
1656 op.getTwoState() && addOp.getTwoState());
1663 OpFoldResult MulOp::fold(FoldAdaptor adaptor) {
1667 auto size = getInputs().size();
1668 auto inputs = adaptor.getInputs();
1672 return getInputs()[0];
1674 auto width = cast<IntegerType>(getType()).getWidth();
1675 APInt value(
width, 1,
false);
1678 for (
auto operand : inputs) {
1681 value *= cast<IntegerAttr>(operand).getValue();
1694 auto inputs = op.getInputs();
1695 auto size = inputs.size();
1696 assert(size > 1 &&
"expected 2 or more operands");
1698 APInt value, value2;
1701 if (size == 2 && matchPattern(inputs.back(), m_ConstantInt(&value)) &&
1702 value.isPowerOf2()) {
1703 auto shift = rewriter.create<
hw::ConstantOp>(op.getLoc(), op.getType(),
1704 value.exactLogBase2());
1708 replaceOpWithNewOpAndCopyName<MulOp>(rewriter, op, op.getType(),
1709 ArrayRef<Value>(shlOp),
false);
1714 if (matchPattern(inputs.back(), m_ConstantInt(&value)) && value.isOne()) {
1715 replaceOpWithNewOpAndCopyName<MulOp>(rewriter, op, op.getType(),
1716 inputs.drop_back());
1721 if (matchPattern(inputs[size - 1], m_ConstantInt(&value)) &&
1722 matchPattern(inputs[size - 2], m_ConstantInt(&value2))) {
1723 auto cst = rewriter.create<
hw::ConstantOp>(op.getLoc(), value * value2);
1724 SmallVector<Value, 4> newOperands(inputs.drop_back(2));
1725 newOperands.push_back(cst);
1726 replaceOpWithNewOpAndCopyName<MulOp>(rewriter, op, op.getType(),
1742 template <
class Op,
bool isSigned>
1743 static OpFoldResult
foldDiv(Op op, ArrayRef<Attribute> constants) {
1744 if (
auto rhsValue = dyn_cast_or_null<IntegerAttr>(constants[1])) {
1746 if (rhsValue.getValue() == 1)
1750 if (rhsValue.getValue().isZero())
1757 OpFoldResult DivUOp::fold(FoldAdaptor adaptor) {
1761 return foldDiv<
DivUOp,
false>(*
this, adaptor.getOperands());
1764 OpFoldResult DivSOp::fold(FoldAdaptor adaptor) {
1771 template <
class Op,
bool isSigned>
1772 static OpFoldResult
foldMod(Op op, ArrayRef<Attribute> constants) {
1773 if (
auto rhsValue = dyn_cast_or_null<IntegerAttr>(constants[1])) {
1775 if (rhsValue.getValue() == 1)
1776 return getIntAttr(APInt::getZero(op.getType().getIntOrFloatBitWidth()),
1780 if (rhsValue.getValue().isZero())
1784 if (
auto lhsValue = dyn_cast_or_null<IntegerAttr>(constants[0])) {
1786 if (lhsValue.getValue().isZero())
1787 return getIntAttr(APInt::getZero(op.getType().getIntOrFloatBitWidth()),
1794 OpFoldResult ModUOp::fold(FoldAdaptor adaptor) {
1798 return foldMod<
ModUOp,
false>(*
this, adaptor.getOperands());
1801 OpFoldResult ModSOp::fold(FoldAdaptor adaptor) {
1812 OpFoldResult ConcatOp::fold(FoldAdaptor adaptor) {
1816 if (getNumOperands() == 1)
1817 return getOperand(0);
1820 for (
auto attr : adaptor.getInputs())
1821 if (!attr || !isa<IntegerAttr>(attr))
1825 unsigned resultWidth = getType().getIntOrFloatBitWidth();
1826 APInt result(resultWidth, 0);
1828 unsigned nextInsertion = resultWidth;
1830 for (
auto attr : adaptor.getInputs()) {
1831 auto chunk = cast<IntegerAttr>(attr).getValue();
1832 nextInsertion -= chunk.getBitWidth();
1833 result.insertBits(chunk, nextInsertion);
1843 auto inputs = op.getInputs();
1844 auto size = inputs.size();
1845 assert(size > 1 &&
"expected 2 or more operands");
1850 auto flattenConcat = [&](
size_t firstOpIndex,
size_t lastOpIndex,
1851 ValueRange replacements) -> LogicalResult {
1852 SmallVector<Value, 4> newOperands;
1853 newOperands.append(inputs.begin(), inputs.begin() + firstOpIndex);
1854 newOperands.append(replacements.begin(), replacements.end());
1855 newOperands.append(inputs.begin() + lastOpIndex + 1, inputs.end());
1856 if (newOperands.size() == 1)
1859 replaceOpWithNewOpAndCopyName<ConcatOp>(rewriter, op, op.getType(),
1864 Value commonOperand = inputs[0];
1865 for (
size_t i = 0; i != size; ++i) {
1867 if (inputs[i] != commonOperand)
1868 commonOperand = Value();
1872 if (
auto subConcat = inputs[i].getDefiningOp<ConcatOp>())
1873 return flattenConcat(i, i, subConcat->getOperands());
1878 if (
auto cst = inputs[i].getDefiningOp<hw::ConstantOp>()) {
1879 if (
auto prevCst = inputs[i - 1].getDefiningOp<hw::ConstantOp>()) {
1880 unsigned prevWidth = prevCst.getValue().getBitWidth();
1881 unsigned thisWidth = cst.getValue().getBitWidth();
1882 auto resultCst = cst.getValue().zext(prevWidth + thisWidth);
1883 resultCst |= prevCst.getValue().zext(prevWidth + thisWidth)
1887 return flattenConcat(i - 1, i, replacement);
1892 if (inputs[i] == inputs[i - 1]) {
1894 rewriter.createOrFold<ReplicateOp>(op.getLoc(), inputs[i], 2);
1895 return flattenConcat(i - 1, i, replacement);
1900 if (
auto repl = inputs[i].getDefiningOp<ReplicateOp>()) {
1902 if (repl.getOperand() == inputs[i - 1]) {
1903 Value replacement = rewriter.createOrFold<ReplicateOp>(
1904 op.getLoc(), repl.getOperand(), repl.getMultiple() + 1);
1905 return flattenConcat(i - 1, i, replacement);
1908 if (
auto prevRepl = inputs[i - 1].getDefiningOp<ReplicateOp>()) {
1909 if (prevRepl.getOperand() == repl.getOperand()) {
1910 Value replacement = rewriter.createOrFold<ReplicateOp>(
1911 op.getLoc(), repl.getOperand(),
1912 repl.getMultiple() + prevRepl.getMultiple());
1913 return flattenConcat(i - 1, i, replacement);
1919 if (
auto repl = inputs[i - 1].getDefiningOp<ReplicateOp>()) {
1920 if (repl.getOperand() == inputs[i]) {
1921 Value replacement = rewriter.createOrFold<ReplicateOp>(
1922 op.getLoc(), inputs[i], repl.getMultiple() + 1);
1923 return flattenConcat(i - 1, i, replacement);
1929 if (
auto extract = inputs[i].getDefiningOp<ExtractOp>()) {
1930 if (
auto prevExtract = inputs[i - 1].getDefiningOp<ExtractOp>()) {
1931 if (extract.getInput() == prevExtract.getInput()) {
1932 auto thisWidth = cast<IntegerType>(extract.getType()).getWidth();
1933 if (prevExtract.getLowBit() == extract.getLowBit() + thisWidth) {
1934 auto prevWidth = prevExtract.getType().getIntOrFloatBitWidth();
1935 auto resType = rewriter.getIntegerType(thisWidth + prevWidth);
1936 Value replacement = rewriter.create<
ExtractOp>(
1937 op.getLoc(), resType, extract.getInput(),
1938 extract.getLowBit());
1939 return flattenConcat(i - 1, i, replacement);
1952 static std::optional<ArraySlice>
get(Value value) {
1953 assert(isa<IntegerType>(value.getType()) &&
"expected integer type");
1955 return ArraySlice{arrayGet.getInput(), arrayGet.getIndex(), 1};
1958 if (
auto arraySlice =
1961 arraySlice.getInput(), arraySlice.getLowIndex(),
1962 hw::type_cast<hw::ArrayType>(arraySlice.getType())
1964 return std::nullopt;
1970 if (prevExtractOpt->index.getType() == extractOpt->index.getType() &&
1971 prevExtractOpt->input == extractOpt->input &&
1973 extractOpt->width)) {
1975 hw::type_cast<hw::ArrayType>(prevExtractOpt->input.getType())
1977 extractOpt->width + prevExtractOpt->width);
1980 op.getLoc(), resIntType,
1982 prevExtractOpt->input,
1983 extractOpt->index));
1984 return flattenConcat(i - 1, i, replacement);
1992 if (commonOperand) {
1993 replaceOpWithNewOpAndCopyName<ReplicateOp>(rewriter, op, op.getType(),
2005 OpFoldResult MuxOp::fold(FoldAdaptor adaptor) {
2010 if (getTrueValue() == getFalseValue())
2011 return getTrueValue();
2012 if (
auto tv = adaptor.getTrueValue())
2013 if (tv == adaptor.getFalseValue())
2018 if (
auto pred = dyn_cast_or_null<IntegerAttr>(adaptor.getCond())) {
2019 if (pred.getValue().isZero())
2020 return getFalseValue();
2021 return getTrueValue();
2025 if (
auto tv = dyn_cast_or_null<IntegerAttr>(adaptor.getTrueValue()))
2026 if (
auto fv = dyn_cast_or_null<IntegerAttr>(adaptor.getFalseValue()))
2027 if (tv.getValue().isOne() && fv.getValue().isZero() &&
2044 if (
auto cmp = cond.getDefiningOp<ICmpOp>()) {
2046 auto requiredPredicate =
2047 (isInverted ? ICmpPredicate::eq : ICmpPredicate::ne);
2048 if (cmp.getLhs() == indexValue && cmp.getPredicate() == requiredPredicate) {
2058 if (
auto orOp = cond.getDefiningOp<
OrOp>()) {
2061 for (
auto operand : orOp.getOperands())
2068 if (
auto andOp = cond.getDefiningOp<
AndOp>()) {
2071 for (
auto operand : andOp.getOperands())
2089 PatternRewriter &rewriter) {
2092 auto rootCmp = rootMux.getCond().getDefiningOp<ICmpOp>();
2095 Value indexValue = rootCmp.getLhs();
2098 auto getCaseValue = [&](
MuxOp mux) -> Value {
2099 return mux.getOperand(1 +
unsigned(!isFalseSide));
2104 auto getTreeValue = [&](
MuxOp mux) -> Value {
2105 return mux.getOperand(1 +
unsigned(isFalseSide));
2110 SmallVector<Location> locationsFound;
2111 SmallVector<std::pair<hw::ConstantOp, Value>, 4> valuesFound;
2115 auto collectConstantValues = [&](
MuxOp mux) ->
bool {
2117 mux.getCond(), indexValue, isFalseSide, [&](
hw::ConstantOp cst) {
2118 valuesFound.push_back({cst, getCaseValue(mux)});
2119 locationsFound.push_back(mux.getCond().getLoc());
2120 locationsFound.push_back(mux->getLoc());
2125 if (!collectConstantValues(rootMux))
2129 if (rootMux->hasOneUse()) {
2130 if (
auto userMux = dyn_cast<MuxOp>(*rootMux->user_begin())) {
2131 if (getTreeValue(userMux) == rootMux.getResult() &&
2139 auto nextTreeValue = getTreeValue(rootMux);
2141 auto nextMux = nextTreeValue.getDefiningOp<
MuxOp>();
2142 if (!nextMux || !nextMux->hasOneUse())
2144 if (!collectConstantValues(nextMux))
2146 nextTreeValue = getTreeValue(nextMux);
2152 if (valuesFound.size() < 3)
2157 auto indexWidth = cast<IntegerType>(indexValue.getType()).getWidth();
2158 if (indexWidth >= 9)
2164 uint64_t tableSize = 1ULL << indexWidth;
2165 if (valuesFound.size() < (tableSize * 5) / 8)
2170 SmallVector<Value, 8> table(tableSize, nextTreeValue);
2175 for (
auto &elt : llvm::reverse(valuesFound)) {
2176 uint64_t idx = elt.first.getValue().getZExtValue();
2177 assert(idx < table.size() &&
"constant should be same bitwidth as index");
2178 table[idx] = elt.second;
2183 std::reverse(table.begin(), table.end());
2186 auto fusedLoc = rewriter.getFusedLoc(locationsFound);
2188 replaceOpWithNewOpAndCopyName<hw::ArrayGetOp>(rewriter, rootMux, array,
2203 PatternRewriter &rewriter) {
2204 assert(fullyAssoc->getNumOperands() >= 2 &&
"cannot split up unary ops");
2205 assert(operandNo < fullyAssoc->getNumOperands() &&
"Invalid operand #");
2209 if (fullyAssoc->getNumOperands() == 2)
2210 return fullyAssoc->getOperand(operandNo ^ 1);
2213 if (fullyAssoc->hasOneUse()) {
2214 rewriter.modifyOpInPlace(fullyAssoc,
2215 [&]() { fullyAssoc->eraseOperand(operandNo); });
2216 return fullyAssoc->getResult(0);
2220 SmallVector<Value> operands;
2221 operands.append(fullyAssoc->getOperands().begin(),
2222 fullyAssoc->getOperands().begin() + operandNo);
2223 operands.append(fullyAssoc->getOperands().begin() + operandNo + 1,
2224 fullyAssoc->getOperands().end());
2226 fullyAssoc->getLoc(), fullyAssoc->getName(), operands, rewriter);
2227 Value excluded = fullyAssoc->getOperand(operandNo);
2231 ArrayRef<Value>{opWithoutExcluded, excluded}, rewriter);
2233 return opWithoutExcluded;
2243 PatternRewriter &rewriter) {
2246 Operation *subExpr =
2247 (isTrueOperand ? op.getFalseValue() : op.getTrueValue()).getDefiningOp();
2248 if (!subExpr || subExpr->getNumOperands() < 2)
2252 if (!isa<AndOp, XorOp, OrOp, MuxOp>(subExpr))
2257 Value commonValue = isTrueOperand ? op.getTrueValue() : op.getFalseValue();
2258 size_t opNo = 0, e = subExpr->getNumOperands();
2259 while (opNo != e && subExpr->getOperand(opNo) != commonValue)
2265 Value cond = op.getCond();
2271 if (
auto subMux = dyn_cast<MuxOp>(subExpr)) {
2273 Value subCond = subMux.getCond();
2276 if (subMux.getTrueValue() == commonValue)
2277 otherValue = subMux.getFalseValue();
2278 else if (subMux.getFalseValue() == commonValue) {
2279 otherValue = subMux.getTrueValue();
2289 cond = rewriter.createOrFold<
OrOp>(op.getLoc(), cond, subCond,
false);
2290 replaceOpWithNewOpAndCopyName<MuxOp>(rewriter, op, cond, commonValue,
2291 otherValue, op.getTwoState());
2297 bool isaAndOp = isa<AndOp>(subExpr);
2298 if (isTrueOperand ^ isaAndOp)
2302 rewriter.createOrFold<ReplicateOp>(op.getLoc(), op.getType(), cond);
2305 bool isaXorOp = isa<XorOp>(subExpr);
2306 bool isaOrOp = isa<OrOp>(subExpr);
2315 if (isaOrOp || isaXorOp) {
2316 auto masked = rewriter.createOrFold<
AndOp>(op.getLoc(), extendedCond,
2317 restOfAssoc,
false);
2319 replaceOpWithNewOpAndCopyName<XorOp>(rewriter, op, masked, commonValue,
2322 replaceOpWithNewOpAndCopyName<OrOp>(rewriter, op, masked, commonValue,
2328 assert(isaAndOp &&
"unexpected operation here");
2329 auto masked = rewriter.createOrFold<
OrOp>(op.getLoc(), extendedCond,
2330 restOfAssoc,
false);
2331 replaceOpWithNewOpAndCopyName<AndOp>(rewriter, op, masked, commonValue,
2342 PatternRewriter &rewriter) {
2345 if (!isa<ConcatOp>(trueOp))
2349 SmallVector<Value> trueOperands, falseOperands;
2353 size_t numTrueOperands = trueOperands.size();
2354 size_t numFalseOperands = falseOperands.size();
2356 if (!numTrueOperands || !numFalseOperands ||
2357 (trueOperands.front() != falseOperands.front() &&
2358 trueOperands.back() != falseOperands.back()))
2362 if (trueOperands.front() == falseOperands.front()) {
2363 SmallVector<Value> operands;
2365 for (i = 0; i < numTrueOperands; ++i) {
2366 Value trueOperand = trueOperands[i];
2367 if (trueOperand == falseOperands[i])
2368 operands.push_back(trueOperand);
2372 if (i == numTrueOperands) {
2379 if (llvm::all_of(operands, [&](Value v) {
return v == operands.front(); }))
2380 sharedMSB = rewriter.createOrFold<ReplicateOp>(
2381 mux->getLoc(), operands.front(), operands.size());
2383 sharedMSB = rewriter.createOrFold<
ConcatOp>(mux->getLoc(), operands);
2387 operands.append(trueOperands.begin() + i, trueOperands.end());
2388 Value trueLSB = rewriter.createOrFold<
ConcatOp>(trueOp->getLoc(), operands);
2390 operands.append(falseOperands.begin() + i, falseOperands.end());
2392 rewriter.createOrFold<
ConcatOp>(falseOp->getLoc(), operands);
2395 Value lsb = rewriter.createOrFold<
MuxOp>(
2396 mux->getLoc(), mux.getCond(), trueLSB, falseLSB, mux.getTwoState());
2397 replaceOpWithNewOpAndCopyName<ConcatOp>(rewriter, mux, sharedMSB, lsb);
2402 if (trueOperands.back() == falseOperands.back()) {
2403 SmallVector<Value> operands;
2406 Value trueOperand = trueOperands[numTrueOperands - i - 1];
2407 if (trueOperand == falseOperands[numFalseOperands - i - 1])
2408 operands.push_back(trueOperand);
2412 std::reverse(operands.begin(), operands.end());
2413 Value sharedLSB = rewriter.createOrFold<
ConcatOp>(mux->getLoc(), operands);
2417 operands.append(trueOperands.begin(), trueOperands.end() - i);
2418 Value trueMSB = rewriter.createOrFold<
ConcatOp>(trueOp->getLoc(), operands);
2420 operands.append(falseOperands.begin(), falseOperands.end() - i);
2422 rewriter.createOrFold<
ConcatOp>(falseOp->getLoc(), operands);
2424 Value msb = rewriter.createOrFold<
MuxOp>(
2425 mux->getLoc(), mux.getCond(), trueMSB, falseMSB, mux.getTwoState());
2426 replaceOpWithNewOpAndCopyName<ConcatOp>(rewriter, mux, msb, sharedLSB);
2438 if (!trueVec || !falseVec)
2440 if (!trueVec.isUniform() || !falseVec.isUniform())
2444 op.getLoc(), op.getCond(), trueVec.getUniformElement(),
2445 falseVec.getUniformElement(), op.getTwoState());
2447 SmallVector<Value> values(trueVec.getInputs().size(), mux);
2454 using OpRewritePattern::OpRewritePattern;
2456 LogicalResult matchAndRewrite(
MuxOp op,
2457 PatternRewriter &rewriter)
const override;
2460 LogicalResult MuxRewriter::matchAndRewrite(
MuxOp op,
2461 PatternRewriter &rewriter)
const {
2470 if (matchPattern(op.getTrueValue(), m_ConstantInt(&value))) {
2471 if (value.getBitWidth() == 1) {
2473 if (value.isZero()) {
2475 replaceOpWithNewOpAndCopyName<AndOp>(rewriter, op, notCond,
2476 op.getFalseValue(),
false);
2481 replaceOpWithNewOpAndCopyName<OrOp>(rewriter, op, op.getCond(),
2482 op.getFalseValue(),
false);
2488 if (matchPattern(op.getFalseValue(), m_ConstantInt(&value2))) {
2493 APInt xorValue = value ^ value2;
2494 if (xorValue.isPowerOf2()) {
2495 unsigned leadingZeros = xorValue.countLeadingZeros();
2496 unsigned trailingZeros = value.getBitWidth() - leadingZeros - 1;
2497 SmallVector<Value, 3> operands;
2505 if (leadingZeros > 0)
2506 operands.push_back(rewriter.createOrFold<
ExtractOp>(
2507 op.getLoc(), op.getTrueValue(), trailingZeros + 1, leadingZeros));
2511 auto v1 = rewriter.createOrFold<
ExtractOp>(
2512 op.getLoc(), op.getTrueValue(), trailingZeros, 1);
2513 auto v2 = rewriter.createOrFold<
ExtractOp>(
2514 op.getLoc(), op.getFalseValue(), trailingZeros, 1);
2515 operands.push_back(rewriter.createOrFold<
MuxOp>(
2516 op.getLoc(), op.getCond(), v1, v2,
false));
2518 if (trailingZeros > 0)
2519 operands.push_back(rewriter.createOrFold<
ExtractOp>(
2520 op.getLoc(), op.getTrueValue(), 0, trailingZeros));
2522 replaceOpWithNewOpAndCopyName<ConcatOp>(rewriter, op, op.getType(),
2529 if (value.isAllOnes() && value2.isZero()) {
2530 replaceOpWithNewOpAndCopyName<ReplicateOp>(rewriter, op, op.getType(),
2537 if (matchPattern(op.getFalseValue(), m_ConstantInt(&value)) &&
2538 value.getBitWidth() == 1) {
2540 if (value.isZero()) {
2541 replaceOpWithNewOpAndCopyName<AndOp>(rewriter, op, op.getCond(),
2542 op.getTrueValue(),
false);
2549 auto notCond = rewriter.createOrFold<
XorOp>(op.getLoc(), op.getCond(),
2550 op.getFalseValue(),
false);
2551 replaceOpWithNewOpAndCopyName<OrOp>(rewriter, op, notCond,
2552 op.getTrueValue(),
false);
2558 Operation *condOp = op.getCond().getDefiningOp();
2559 if (condOp && matchPattern(condOp,
m_Complement(m_Any(&subExpr))) &&
2561 replaceOpWithNewOpAndCopyName<MuxOp>(rewriter, op, op.getType(), subExpr,
2562 op.getFalseValue(), op.getTrueValue(),
2570 if (condOp && condOp->hasOneUse()) {
2571 SmallVector<Value> invertedOperands;
2575 auto getInvertedOperands = [&]() ->
bool {
2576 for (Value operand : condOp->getOperands()) {
2577 if (matchPattern(operand,
m_Complement(m_Any(&subExpr))))
2578 invertedOperands.push_back(subExpr);
2585 if (isa<AndOp>(condOp) && getInvertedOperands()) {
2587 rewriter.createOrFold<
OrOp>(op.getLoc(), invertedOperands,
false);
2588 replaceOpWithNewOpAndCopyName<MuxOp>(rewriter, op, newOr,
2590 op.getTrueValue(), op.getTwoState());
2593 if (isa<OrOp>(condOp) && getInvertedOperands()) {
2595 rewriter.createOrFold<
AndOp>(op.getLoc(), invertedOperands,
false);
2596 replaceOpWithNewOpAndCopyName<MuxOp>(rewriter, op, newAnd,
2598 op.getTrueValue(), op.getTwoState());
2604 dyn_cast_or_null<MuxOp>(op.getFalseValue().getDefiningOp())) {
2606 if (op.getCond() == falseMux.getCond()) {
2607 replaceOpWithNewOpAndCopyName<MuxOp>(
2608 rewriter, op, op.getCond(), op.getTrueValue(),
2609 falseMux.getFalseValue(), op.getTwoStateAttr());
2619 dyn_cast_or_null<MuxOp>(op.getTrueValue().getDefiningOp())) {
2621 if (op.getCond() == trueMux.getCond()) {
2622 replaceOpWithNewOpAndCopyName<MuxOp>(
2623 rewriter, op, op.getCond(), trueMux.getTrueValue(),
2624 op.getFalseValue(), op.getTwoStateAttr());
2634 if (
auto trueMux = dyn_cast_or_null<MuxOp>(op.getTrueValue().getDefiningOp()),
2635 falseMux = dyn_cast_or_null<MuxOp>(op.getFalseValue().getDefiningOp());
2636 trueMux && falseMux && trueMux.getCond() == falseMux.getCond() &&
2637 trueMux.getTrueValue() == falseMux.getTrueValue()) {
2638 auto subMux = rewriter.create<
MuxOp>(
2639 rewriter.getFusedLoc({trueMux.getLoc(), falseMux.getLoc()}),
2640 op.getCond(), trueMux.getFalseValue(), falseMux.getFalseValue());
2641 replaceOpWithNewOpAndCopyName<MuxOp>(rewriter, op, trueMux.getCond(),
2642 trueMux.getTrueValue(), subMux,
2643 op.getTwoStateAttr());
2648 if (
auto trueMux = dyn_cast_or_null<MuxOp>(op.getTrueValue().getDefiningOp()),
2649 falseMux = dyn_cast_or_null<MuxOp>(op.getFalseValue().getDefiningOp());
2650 trueMux && falseMux && trueMux.getCond() == falseMux.getCond() &&
2651 trueMux.getFalseValue() == falseMux.getFalseValue()) {
2652 auto subMux = rewriter.create<
MuxOp>(
2653 rewriter.getFusedLoc({trueMux.getLoc(), falseMux.getLoc()}),
2654 op.getCond(), trueMux.getTrueValue(), falseMux.getTrueValue());
2655 replaceOpWithNewOpAndCopyName<MuxOp>(rewriter, op, trueMux.getCond(),
2656 subMux, trueMux.getFalseValue(),
2657 op.getTwoStateAttr());
2662 if (
auto trueMux = dyn_cast_or_null<MuxOp>(op.getTrueValue().getDefiningOp()),
2663 falseMux = dyn_cast_or_null<MuxOp>(op.getFalseValue().getDefiningOp());
2664 trueMux && falseMux &&
2665 trueMux.getTrueValue() == falseMux.getTrueValue() &&
2666 trueMux.getFalseValue() == falseMux.getFalseValue()) {
2667 auto subMux = rewriter.create<
MuxOp>(
2668 rewriter.getFusedLoc(
2669 {op.getLoc(), trueMux.getLoc(), falseMux.getLoc()}),
2670 op.getCond(), trueMux.getCond(), falseMux.getCond());
2671 replaceOpWithNewOpAndCopyName<MuxOp>(
2672 rewriter, op, subMux, trueMux.getTrueValue(), trueMux.getFalseValue(),
2673 op.getTwoStateAttr());
2685 if (Operation *trueOp = op.getTrueValue().getDefiningOp())
2686 if (Operation *falseOp = op.getFalseValue().getDefiningOp())
2687 if (trueOp->getName() == falseOp->getName())
2704 if (op.getInputs().empty() || op.isUniform())
2706 auto inputs = op.getInputs();
2707 if (inputs.size() <= 1)
2712 auto first = inputs[0].getDefiningOp<
comb::MuxOp>();
2717 for (
size_t i = 1, n = inputs.size(); i < n; ++i) {
2718 auto input = inputs[i].getDefiningOp<
comb::MuxOp>();
2719 if (!input || first.getCond() != input.getCond())
2724 SmallVector<Value> trues{first.getTrueValue()};
2725 SmallVector<Value> falses{first.getFalseValue()};
2726 SmallVector<Location> locs{first->getLoc()};
2727 bool isTwoState =
true;
2728 for (
size_t i = 1, n = inputs.size(); i < n; ++i) {
2729 auto input = inputs[i].getDefiningOp<
comb::MuxOp>();
2730 trues.push_back(input.getTrueValue());
2731 falses.push_back(input.getFalseValue());
2732 locs.push_back(input->getLoc());
2733 if (!input.getTwoState())
2742 auto arrayTy = op.getType();
2745 rewriter.replaceOpWithNewOp<
comb::MuxOp>(op, arrayTy, first.getCond(),
2746 trueValues, falseValues, isTwoState);
2751 using OpRewritePattern::OpRewritePattern;
2754 PatternRewriter &rewriter)
const override {
2758 if (foldArrayOfMuxes(op, rewriter))
2766 void MuxOp::getCanonicalizationPatterns(RewritePatternSet &results,
2767 MLIRContext *context) {
2768 results.insert<MuxRewriter, ArrayRewriter>(context);
2779 switch (predicate) {
2780 case ICmpPredicate::eq:
2782 case ICmpPredicate::ne:
2784 case ICmpPredicate::slt:
2785 return lhs.slt(rhs);
2786 case ICmpPredicate::sle:
2787 return lhs.sle(rhs);
2788 case ICmpPredicate::sgt:
2789 return lhs.sgt(rhs);
2790 case ICmpPredicate::sge:
2791 return lhs.sge(rhs);
2792 case ICmpPredicate::ult:
2793 return lhs.ult(rhs);
2794 case ICmpPredicate::ule:
2795 return lhs.ule(rhs);
2796 case ICmpPredicate::ugt:
2797 return lhs.ugt(rhs);
2798 case ICmpPredicate::uge:
2799 return lhs.uge(rhs);
2800 case ICmpPredicate::ceq:
2802 case ICmpPredicate::cne:
2804 case ICmpPredicate::weq:
2806 case ICmpPredicate::wne:
2809 llvm_unreachable(
"unknown comparison predicate");
2815 switch (predicate) {
2816 case ICmpPredicate::eq:
2817 case ICmpPredicate::sle:
2818 case ICmpPredicate::sge:
2819 case ICmpPredicate::ule:
2820 case ICmpPredicate::uge:
2821 case ICmpPredicate::ceq:
2822 case ICmpPredicate::weq:
2824 case ICmpPredicate::ne:
2825 case ICmpPredicate::slt:
2826 case ICmpPredicate::sgt:
2827 case ICmpPredicate::ult:
2828 case ICmpPredicate::ugt:
2829 case ICmpPredicate::cne:
2830 case ICmpPredicate::wne:
2833 llvm_unreachable(
"unknown comparison predicate");
2836 OpFoldResult ICmpOp::fold(FoldAdaptor adaptor) {
2842 if (getLhs() == getRhs()) {
2848 if (
auto lhs = dyn_cast_or_null<IntegerAttr>(adaptor.getLhs())) {
2849 if (
auto rhs = dyn_cast_or_null<IntegerAttr>(adaptor.getRhs())) {
2860 template <
typename Range>
2862 size_t commonPrefixLength = 0;
2863 auto ia = a.begin();
2864 auto ib = b.begin();
2866 for (; ia != a.end() && ib != b.end(); ia++, ib++, commonPrefixLength++) {
2872 return commonPrefixLength;
2876 size_t totalWidth = 0;
2877 for (
auto operand : operands) {
2880 ssize_t
width = operand.getType().getIntOrFloatBitWidth();
2882 totalWidth +=
width;
2892 PatternRewriter &rewriter) {
2896 SmallVector<Value> lhsOperands, rhsOperands;
2899 ArrayRef<Value> lhsOperandsRef = lhsOperands, rhsOperandsRef = rhsOperands;
2901 auto formCatOrReplicate = [&](Location loc,
2902 ArrayRef<Value> operands) -> Value {
2903 assert(!operands.empty());
2904 Value sameElement = operands[0];
2905 for (
size_t i = 1, e = operands.size(); i != e && sameElement; ++i)
2906 if (sameElement != operands[i])
2907 sameElement = Value();
2909 return rewriter.createOrFold<ReplicateOp>(loc, sameElement,
2911 return rewriter.createOrFold<
ConcatOp>(loc, operands);
2914 auto replaceWith = [&](ICmpPredicate predicate, Value lhs,
2915 Value rhs) -> LogicalResult {
2916 replaceOpWithNewOpAndCopyName<ICmpOp>(rewriter, op, predicate, lhs, rhs,
2921 size_t commonPrefixLength =
2923 if (commonPrefixLength == lhsOperands.size()) {
2926 replaceOpWithNewOpAndCopyName<hw::ConstantOp>(rewriter, op,
2932 llvm::reverse(lhsOperandsRef), llvm::reverse(rhsOperandsRef));
2934 size_t commonPrefixTotalWidth =
2935 getTotalWidth(lhsOperandsRef.take_front(commonPrefixLength));
2936 size_t commonSuffixTotalWidth =
2937 getTotalWidth(lhsOperandsRef.take_back(commonSuffixLength));
2938 auto lhsOnly = lhsOperandsRef.drop_front(commonPrefixLength)
2939 .drop_back(commonSuffixLength);
2940 auto rhsOnly = rhsOperandsRef.drop_front(commonPrefixLength)
2941 .drop_back(commonSuffixLength);
2943 auto replaceWithoutReplicatingSignBit = [&]() {
2944 auto newLhs = formCatOrReplicate(lhs->getLoc(), lhsOnly);
2945 auto newRhs = formCatOrReplicate(rhs->getLoc(), rhsOnly);
2946 return replaceWith(op.getPredicate(), newLhs, newRhs);
2949 auto replaceWithReplicatingSignBit = [&]() {
2950 auto firstNonEmptyValue = lhsOperands[0];
2951 auto firstNonEmptyElemWidth =
2952 firstNonEmptyValue.getType().getIntOrFloatBitWidth();
2953 Value signBit = rewriter.createOrFold<
ExtractOp>(
2954 op.getLoc(), firstNonEmptyValue, firstNonEmptyElemWidth - 1, 1);
2956 auto newLhs = rewriter.
create<
ConcatOp>(lhs->getLoc(), signBit, lhsOnly);
2957 auto newRhs = rewriter.create<
ConcatOp>(rhs->getLoc(), signBit, rhsOnly);
2958 return replaceWith(op.getPredicate(), newLhs, newRhs);
2961 if (ICmpOp::isPredicateSigned(op.getPredicate())) {
2963 if (commonPrefixTotalWidth == 0 && commonSuffixTotalWidth > 0)
2964 return replaceWithoutReplicatingSignBit();
2970 if (commonPrefixTotalWidth > 1 || commonSuffixTotalWidth > 0)
2971 return replaceWithReplicatingSignBit();
2973 }
else if (commonPrefixTotalWidth > 0 || commonSuffixTotalWidth > 0) {
2975 return replaceWithoutReplicatingSignBit();
2989 ICmpOp cmpOp,
const KnownBits &bitAnalysis,
const APInt &rhsCst,
2990 PatternRewriter &rewriter) {
2994 APInt bitsKnown = bitAnalysis.Zero | bitAnalysis.One;
2995 if ((bitsKnown & rhsCst) != bitAnalysis.One) {
2998 bool result = cmpOp.getPredicate() == ICmpPredicate::ne;
2999 replaceOpWithNewOpAndCopyName<hw::ConstantOp>(rewriter, cmpOp,
3007 SmallVector<Value> newConcatOperands;
3008 auto newConstant = APInt::getZeroWidth();
3013 unsigned knownMSB = bitsKnown.countLeadingOnes();
3015 Value operand = cmpOp.getLhs();
3020 while (knownMSB != bitsKnown.getBitWidth()) {
3023 bitsKnown = bitsKnown.trunc(bitsKnown.getBitWidth() - knownMSB);
3026 unsigned unknownBits = bitsKnown.countLeadingZeros();
3027 unsigned lowBit = bitsKnown.getBitWidth() - unknownBits;
3028 auto spanOperand = rewriter.createOrFold<
ExtractOp>(
3029 operand.getLoc(), operand, lowBit,
3031 auto spanConstant = rhsCst.lshr(lowBit).trunc(unknownBits);
3034 newConcatOperands.push_back(spanOperand);
3037 if (newConstant.getBitWidth() != 0)
3038 newConstant = newConstant.concat(spanConstant);
3040 newConstant = spanConstant;
3043 unsigned newWidth = bitsKnown.getBitWidth() - unknownBits;
3044 bitsKnown = bitsKnown.trunc(newWidth);
3045 knownMSB = bitsKnown.countLeadingOnes();
3051 if (newConcatOperands.empty()) {
3052 bool result = cmpOp.getPredicate() == ICmpPredicate::eq;
3053 replaceOpWithNewOpAndCopyName<hw::ConstantOp>(rewriter, cmpOp,
3059 Value concatResult =
3060 rewriter.createOrFold<
ConcatOp>(operand.getLoc(), newConcatOperands);
3064 cmpOp.getOperand(1).getLoc(), newConstant);
3066 replaceOpWithNewOpAndCopyName<ICmpOp>(rewriter, cmpOp, cmpOp.getPredicate(),
3067 concatResult, newConstantOp,
3068 cmpOp.getTwoState());
3074 PatternRewriter &rewriter) {
3075 auto ip = rewriter.saveInsertionPoint();
3076 rewriter.setInsertionPoint(xorOp);
3078 auto xorRHS = xorOp.getOperands().back().getDefiningOp<
hw::ConstantOp>();
3080 xorRHS.getValue() ^ rhs);
3082 switch (xorOp.getNumOperands()) {
3086 APInt::getZero(rhs.getBitWidth()));
3090 newLHS = xorOp.getOperand(0);
3094 SmallVector<Value> newOperands(xorOp.getOperands());
3095 newOperands.pop_back();
3096 newLHS = rewriter.create<
XorOp>(xorOp.getLoc(), newOperands,
false);
3100 bool xorMultipleUses = !xorOp->hasOneUse();
3104 if (xorMultipleUses)
3105 replaceOpWithNewOpAndCopyName<XorOp>(rewriter, xorOp, newLHS, xorRHS,
3109 rewriter.restoreInsertionPoint(ip);
3110 replaceOpWithNewOpAndCopyName<ICmpOp>(rewriter, cmpOp, cmpOp.getPredicate(),
3111 newLHS, newRHS,
false);
3121 if (matchPattern(op.getLhs(), m_ConstantInt(&lhs))) {
3122 assert(!matchPattern(op.getRhs(), m_ConstantInt(&rhs)) &&
3123 "Should be folded");
3124 replaceOpWithNewOpAndCopyName<ICmpOp>(
3125 rewriter, op, ICmpOp::getFlippedPredicate(op.getPredicate()),
3126 op.getRhs(), op.getLhs(), op.getTwoState());
3131 if (matchPattern(op.getRhs(), m_ConstantInt(&rhs))) {
3133 return rewriter.create<
hw::ConstantOp>(op.getLoc(), std::move(constant));
3136 auto replaceWith = [&](ICmpPredicate predicate, Value lhs,
3137 Value rhs) -> LogicalResult {
3138 replaceOpWithNewOpAndCopyName<ICmpOp>(rewriter, op, predicate, lhs, rhs,
3143 auto replaceWithConstantI1 = [&](
bool constant) -> LogicalResult {
3144 replaceOpWithNewOpAndCopyName<hw::ConstantOp>(rewriter, op,
3145 APInt(1, constant));
3149 switch (op.getPredicate()) {
3150 case ICmpPredicate::slt:
3152 if (rhs.isMaxSignedValue())
3153 return replaceWith(ICmpPredicate::ne, op.getLhs(), op.getRhs());
3155 if (rhs.isMinSignedValue())
3156 return replaceWithConstantI1(0);
3158 if ((rhs - 1).isMinSignedValue())
3159 return replaceWith(ICmpPredicate::eq, op.getLhs(),
3162 case ICmpPredicate::sgt:
3164 if (rhs.isMinSignedValue())
3165 return replaceWith(ICmpPredicate::ne, op.getLhs(), op.getRhs());
3167 if (rhs.isMaxSignedValue())
3168 return replaceWithConstantI1(0);
3170 if ((rhs + 1).isMaxSignedValue())
3171 return replaceWith(ICmpPredicate::eq, op.getLhs(),
3174 case ICmpPredicate::ult:
3176 if (rhs.isAllOnes())
3177 return replaceWith(ICmpPredicate::ne, op.getLhs(), op.getRhs());
3180 return replaceWithConstantI1(0);
3182 if ((rhs - 1).isZero())
3183 return replaceWith(ICmpPredicate::eq, op.getLhs(),
3187 if (rhs.countLeadingOnes() + rhs.countTrailingZeros() ==
3188 rhs.getBitWidth()) {
3189 auto numOnes = rhs.countLeadingOnes();
3190 auto smaller = rewriter.create<
ExtractOp>(
3191 op.getLoc(), op.getLhs(), rhs.getBitWidth() - numOnes, numOnes);
3192 return replaceWith(ICmpPredicate::ne, smaller,
3197 case ICmpPredicate::ugt:
3200 return replaceWith(ICmpPredicate::ne, op.getLhs(), op.getRhs());
3202 if (rhs.isAllOnes())
3203 return replaceWithConstantI1(0);
3205 if ((rhs + 1).isAllOnes())
3206 return replaceWith(ICmpPredicate::eq, op.getLhs(),
3210 if ((rhs + 1).isPowerOf2()) {
3211 auto numOnes = rhs.countTrailingOnes();
3212 auto newWidth = rhs.getBitWidth() - numOnes;
3213 auto smaller = rewriter.create<
ExtractOp>(op.getLoc(), op.getLhs(),
3215 return replaceWith(ICmpPredicate::ne, smaller,
3220 case ICmpPredicate::sle:
3222 if (rhs.isMaxSignedValue())
3223 return replaceWithConstantI1(1);
3225 return replaceWith(ICmpPredicate::slt, op.getLhs(),
getConstant(rhs + 1));
3226 case ICmpPredicate::sge:
3228 if (rhs.isMinSignedValue())
3229 return replaceWithConstantI1(1);
3231 return replaceWith(ICmpPredicate::sgt, op.getLhs(),
getConstant(rhs - 1));
3232 case ICmpPredicate::ule:
3234 if (rhs.isAllOnes())
3235 return replaceWithConstantI1(1);
3237 return replaceWith(ICmpPredicate::ult, op.getLhs(),
getConstant(rhs + 1));
3238 case ICmpPredicate::uge:
3241 return replaceWithConstantI1(1);
3243 return replaceWith(ICmpPredicate::ugt, op.getLhs(),
getConstant(rhs - 1));
3244 case ICmpPredicate::eq:
3245 if (rhs.getBitWidth() == 1) {
3248 replaceOpWithNewOpAndCopyName<XorOp>(rewriter, op, op.getLhs(),
3253 if (rhs.isAllOnes()) {
3260 case ICmpPredicate::ne:
3261 if (rhs.getBitWidth() == 1) {
3267 if (rhs.isAllOnes()) {
3269 replaceOpWithNewOpAndCopyName<XorOp>(rewriter, op, op.getLhs(),
3276 case ICmpPredicate::ceq:
3277 case ICmpPredicate::cne:
3278 case ICmpPredicate::weq:
3279 case ICmpPredicate::wne:
3285 if (op.getPredicate() == ICmpPredicate::eq ||
3286 op.getPredicate() == ICmpPredicate::ne) {
3291 if (!knownBits.isUnknown())
3298 if (
auto xorOp = op.getLhs().getDefiningOp<
XorOp>())
3305 if (
auto replicateOp = op.getLhs().getDefiningOp<ReplicateOp>())
3306 if (rhs.isAllOnes() || rhs.isZero()) {
3307 auto width = replicateOp.getInput().getType().getIntOrFloatBitWidth();
3309 op.getLoc(), rhs.isAllOnes() ? APInt::getAllOnes(
width)
3310 : APInt::getZero(
width));
3311 replaceOpWithNewOpAndCopyName<ICmpOp>(rewriter, op, op.getPredicate(),
3312 replicateOp.getInput(), cst,
3322 if (Operation *opLHS = op.getLhs().getDefiningOp())
3323 if (Operation *opRHS = op.getRhs().getDefiningOp())
3324 if (isa<ConcatOp, ReplicateOp>(opLHS) &&
3325 isa<ConcatOp, ReplicateOp>(opRHS)) {
assert(baseType &&"element must be base type")
static SmallVector< T > concat(const SmallVectorImpl< T > &a, const SmallVectorImpl< T > &b)
Returns a new vector containing the concatenation of vectors a and b.
static KnownBits computeKnownBits(Value v, unsigned depth)
Given an integer SSA value, check to see if we know anything about the result of the computation.
static bool foldMuxOfUniformArrays(MuxOp op, PatternRewriter &rewriter)
static Attribute constFoldAssociativeOp(ArrayRef< Attribute > operands, hw::PEO paramOpcode)
static Attribute constFoldBinaryOp(ArrayRef< Attribute > operands, hw::PEO paramOpcode)
Performs a constant-folding calculation with element-wise behavior on the two attributes in operands and ...
static bool applyCmpPredicateToEqualOperands(ICmpPredicate predicate)
static bool canonicalizeLogicalCstWithConcat(Operation *logicalOp, size_t concatIdx, const APInt &cst, PatternRewriter &rewriter)
When we find a logical operation (and, or, xor) with a constant e.g.
static bool hasOperandsOutsideOfBlock(Operation *op)
In comb, we assume no knowledge of the semantics of cross-block dataflow.
static bool narrowOperationWidth(OpTy op, bool narrowTrailingBits, PatternRewriter &rewriter)
static OpFoldResult foldDiv(Op op, ArrayRef< Attribute > constants)
static Value getCommonOperand(Op op)
Returns a single common operand that all inputs of the operation op can be traced back to,...
static bool canCombineOppositeBinCmpIntoConstant(OperandRange operands)
static void getConcatOperands(Value v, SmallVectorImpl< Value > &result)
Flatten concat and replicate operands into a vector.
static OpTy replaceOpWithNewOpAndCopyName(PatternRewriter &rewriter, Operation *op, Args &&...args)
A wrapper of PatternRewriter::replaceOpWithNewOp to propagate the "sv.namehint" attribute.
static Value extractOperandFromFullyAssociative(Operation *fullyAssoc, size_t operandNo, PatternRewriter &rewriter)
Given a fully associative variadic operation like (a+b+c+d), break the expression into two parts,...
static bool getMuxChainCondConstant(Value cond, Value indexValue, bool isInverted, std::function< void(hw::ConstantOp)> constantFn)
Check to see if the condition to the specified mux is an equality comparison indexValue and one or mo...
static TypedAttr getIntAttr(const APInt &value, MLIRContext *context)
static bool shouldBeFlattened(Operation *op)
Return true if the op will be flattened afterwards.
static std::pair< size_t, size_t > getLowestBitAndHighestBitRequired(Operation *op, bool narrowTrailingBits, size_t originalOpWidth)
static void canonicalizeXorIcmpTrue(XorOp op, unsigned icmpOperand, PatternRewriter &rewriter)
static bool extractFromReplicate(ExtractOp op, ReplicateOp replicate, PatternRewriter &rewriter)
static bool canonicalizeOrOfConcatsWithCstOperands(OrOp op, size_t concatIdx1, size_t concatIdx2, PatternRewriter &rewriter)
Simplify concat ops in an or op when a constant operand is present in either concat.
static void combineEqualityICmpWithXorOfConstant(ICmpOp cmpOp, XorOp xorOp, const APInt &rhs, PatternRewriter &rewriter)
static size_t getTotalWidth(ArrayRef< Value > operands)
static bool foldCommonMuxOperation(MuxOp mux, Operation *trueOp, Operation *falseOp, PatternRewriter &rewriter)
This function is invoked when we find a mux with true/false operations that have the same opcode.
static bool tryFlatteningOperands(Operation *op, PatternRewriter &rewriter)
Flattens a single input in op if hasOneUse is true and it can be defined as an Op.
static bool canonicalizeIdempotentInputs(Op op, PatternRewriter &rewriter)
Canonicalize an idempotent operation op so that only one input of any kind occurs.
static bool applyCmpPredicate(ICmpPredicate predicate, const APInt &lhs, const APInt &rhs)
static void combineEqualityICmpWithKnownBitsAndConstant(ICmpOp cmpOp, const KnownBits &bitAnalysis, const APInt &rhsCst, PatternRewriter &rewriter)
Given an equality comparison with a constant value and some operand that has known bits,...
static bool foldMuxChain(MuxOp rootMux, bool isFalseSide, PatternRewriter &rewriter)
Given a mux, check to see if the "on true" value (or "on false" value if isFalseSide=true) is a mux t...
static ComplementMatcher< SubType > m_Complement(const SubType &subExpr)
static bool hasSVAttributes(Operation *op)
static LogicalResult extractConcatToConcatExtract(ExtractOp op, ConcatOp innerCat, PatternRewriter &rewriter)
static OpFoldResult foldMod(Op op, ArrayRef< Attribute > constants)
static size_t computeCommonPrefixLength(const Range &a, const Range &b)
static bool foldCommonMuxValue(MuxOp op, bool isTrueOperand, PatternRewriter &rewriter)
Fold things like mux(cond, x|y|z|a, a) -> (x|y|z)&replicate(cond)|a and mux(cond, a,...
static LogicalResult matchAndRewriteCompareConcat(ICmpOp op, Operation *lhs, Operation *rhs, PatternRewriter &rewriter)
Reduce the strength icmp(concat(...), concat(...)) by doing a element-wise comparison on common prefi...
static Value createGenericOp(Location loc, OperationName name, ArrayRef< Value > operands, OpBuilder &builder)
Create a new instance of a generic operation that only has value operands, and has a single result va...
static void replaceOpAndCopyName(PatternRewriter &rewriter, Operation *op, Value newValue)
A wrapper of PatternRewriter::replaceOp to propagate the "sv.namehint" attribute.
static std::optional< APSInt > getConstant(Attribute operand)
Determine the value of a constant operand for the sake of constant folding.
def create(data_type, value)
static LogicalResult canonicalize(Op op, PatternRewriter &rewriter)
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
Value createOrFoldNot(Location loc, Value value, OpBuilder &builder, bool twoState=false)
Create a `Not` gate on a value.
uint64_t getWidth(Type t)
std::optional< int64_t > getBitWidth(FIRRTLBaseType type, bool ignoreFlip=false)
bool isOffset(Value base, Value index, uint64_t offset)
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.