13#include "mlir/IR/Diagnostics.h"
14#include "mlir/IR/Matchers.h"
15#include "mlir/IR/PatternMatch.h"
16#include "llvm/ADT/SetVector.h"
17#include "llvm/ADT/SmallBitVector.h"
18#include "llvm/ADT/TypeSwitch.h"
19#include "llvm/Support/KnownBits.h"
24using namespace matchers;
28 return llvm::any_of(op->getOperands(), [op](
auto operand) {
29 return operand.getDefiningOp() == op;
39 ArrayRef<Value> operands, OpBuilder &builder) {
40 OperationState state(loc, name);
41 state.addOperands(operands);
42 state.addTypes(operands[0].getType());
43 return builder.create(state)->getResult(0);
47 return IntegerAttr::get(IntegerType::get(
context, value.getBitWidth()),
53 if (
auto concat = v.getDefiningOp<
ConcatOp>()) {
54 for (
auto op : concat.getOperands())
56 }
else if (
auto repl = v.getDefiningOp<ReplicateOp>()) {
57 for (
size_t i = 0, e = repl.getMultiple(); i != e; ++i)
68 return op->hasAttr(
"sv.attributes");
72template <
typename SubType>
73struct ComplementMatcher {
75 ComplementMatcher(SubType lhs) : lhs(std::move(lhs)) {}
76 bool match(Operation *op) {
77 auto xorOp = dyn_cast<XorOp>(op);
78 return xorOp && xorOp.isBinaryNot() &&
79 mlir::detail::matchOperandOrValueAtIndex(op, 0, lhs);
84template <
typename SubType>
85static inline ComplementMatcher<SubType>
m_Complement(
const SubType &subExpr) {
86 return ComplementMatcher<SubType>(subExpr);
92 assert((isa<AndOp, OrOp, XorOp, AddOp, MulOp>(op) &&
93 "must be commutative operations"));
94 if (op->hasOneUse()) {
95 auto *user = *op->getUsers().begin();
96 return user->getName() == op->getName() &&
97 op->getAttrOfType<UnitAttr>(
"twoState") ==
98 user->getAttrOfType<UnitAttr>(
"twoState") &&
99 op->getBlock() == user->getBlock();
114 auto inputs = op->getOperands();
116 SmallVector<Value, 4> newOperands;
117 SmallVector<Location, 4> newLocations{op->getLoc()};
118 newOperands.reserve(inputs.size());
120 decltype(inputs.begin()) current, end;
123 SmallVector<Element> worklist;
124 worklist.push_back({inputs.begin(), inputs.end()});
125 bool binFlag = op->hasAttrOfType<UnitAttr>(
"twoState");
126 bool changed =
false;
127 while (!worklist.empty()) {
128 auto &element = worklist.back();
131 if (element.current == element.end) {
136 Value value = *element.current++;
137 auto *flattenOp = value.getDefiningOp();
140 if (!flattenOp || flattenOp->getName() != op->getName() ||
141 flattenOp == op || binFlag != op->hasAttrOfType<UnitAttr>(
"twoState") ||
142 flattenOp->getBlock() != op->getBlock()) {
143 newOperands.push_back(value);
148 if (!value.hasOneUse()) {
156 if (flattenOp->getNumOperands() != 2 || !isa<AndOp, OrOp, XorOp>(op) ||
159 newOperands.push_back(value);
167 auto flattenOpInputs = flattenOp->getOperands();
168 worklist.push_back({flattenOpInputs.begin(), flattenOpInputs.end()});
169 newLocations.push_back(flattenOp->getLoc());
175 Value result =
createGenericOp(FusedLoc::get(op->getContext(), newLocations),
176 op->getName(), newOperands, rewriter);
178 result.getDefiningOp()->setAttr(
"twoState", rewriter.getUnitAttr());
186static std::pair<size_t, size_t>
188 size_t originalOpWidth) {
189 auto users = op->getUsers();
191 "getLowestBitAndHighestBitRequired cannot operate on "
192 "a empty list of uses.");
196 size_t lowestBitRequired = narrowTrailingBits ? originalOpWidth - 1 : 0;
197 size_t highestBitRequired = 0;
199 for (
auto *user : users) {
200 if (
auto extractOp = dyn_cast<ExtractOp>(user)) {
201 size_t lowBit = extractOp.getLowBit();
203 cast<IntegerType>(extractOp.getType()).getWidth() + lowBit - 1;
204 highestBitRequired = std::max(highestBitRequired, highBit);
205 lowestBitRequired = std::min(lowestBitRequired, lowBit);
209 highestBitRequired = originalOpWidth - 1;
210 lowestBitRequired = 0;
214 return {lowestBitRequired, highestBitRequired};
219 PatternRewriter &rewriter) {
220 IntegerType opType = dyn_cast<IntegerType>(op.getResult().getType());
226 if (range.second + 1 == opType.getWidth() && range.first == 0)
229 SmallVector<Value> args;
230 auto newType = rewriter.getIntegerType(range.second - range.first + 1);
231 for (
auto inop : op.getOperands()) {
233 if (inop.getType() != op.getType())
234 args.push_back(inop);
236 args.push_back(rewriter.createOrFold<
ExtractOp>(inop.getLoc(), newType,
239 auto newop = OpTy::create(rewriter, op.getLoc(), newType, args);
240 newop->setDialectAttrs(op->getDialectAttrs());
241 if (op.getTwoState())
242 newop.setTwoState(
true);
244 Value newResult = newop.getResult();
246 newResult = rewriter.createOrFold<
ConcatOp>(
247 op.getLoc(), newResult,
249 APInt::getZero(range.first)));
250 if (range.second + 1 < opType.getWidth())
251 newResult = rewriter.createOrFold<
ConcatOp>(
254 rewriter, op.getLoc(),
255 APInt::getZero(opType.getWidth() - range.second - 1)),
257 rewriter.replaceOp(op, newResult);
265OpFoldResult ReplicateOp::fold(FoldAdaptor adaptor) {
270 if (cast<IntegerType>(getType()).
getWidth() ==
271 getInput().getType().getIntOrFloatBitWidth())
275 if (
auto input = dyn_cast_or_null<IntegerAttr>(adaptor.getInput())) {
276 if (input.getValue().getBitWidth() == 1) {
277 if (input.getValue().isZero())
279 APInt::getZero(cast<IntegerType>(getType()).
getWidth()),
282 APInt::getAllOnes(cast<IntegerType>(getType()).
getWidth()),
286 APInt result = APInt::getZeroWidth();
287 for (
auto i = getMultiple(); i != 0; --i)
288 result = result.concat(input.getValue());
295OpFoldResult ParityOp::fold(FoldAdaptor adaptor) {
300 if (
auto input = dyn_cast_or_null<IntegerAttr>(adaptor.getInput()))
301 return getIntAttr(APInt(1, input.getValue().popcount() & 1), getContext());
313 hw::PEO paramOpcode) {
314 assert(operands.size() == 2 &&
"binary op takes two operands");
315 if (!operands[0] || !operands[1])
320 return hw::ParamExprAttr::get(paramOpcode, cast<TypedAttr>(operands[0]),
321 cast<TypedAttr>(operands[1]));
324OpFoldResult ShlOp::fold(FoldAdaptor adaptor) {
328 if (
auto rhs = dyn_cast_or_null<IntegerAttr>(adaptor.getRhs())) {
329 if (rhs.getValue().isZero())
330 return getOperand(0);
332 unsigned width = getType().getIntOrFloatBitWidth();
333 if (rhs.getValue().uge(width))
334 return getIntAttr(APInt::getZero(width), getContext());
339LogicalResult ShlOp::canonicalize(
ShlOp op, PatternRewriter &rewriter) {
345 if (!matchPattern(op.getRhs(), m_ConstantInt(&value)))
348 unsigned width = cast<IntegerType>(op.getLhs().getType()).getWidth();
349 if (value.ugt(width))
351 unsigned shift = value.getZExtValue();
354 if (width <= shift || shift == 0)
364 replaceOpWithNewOpAndCopyNamehint<ConcatOp>(rewriter, op, extract, zeros);
368OpFoldResult ShrUOp::fold(FoldAdaptor adaptor) {
372 if (
auto rhs = dyn_cast_or_null<IntegerAttr>(adaptor.getRhs())) {
373 if (rhs.getValue().isZero())
374 return getOperand(0);
376 unsigned width = getType().getIntOrFloatBitWidth();
377 if (rhs.getValue().uge(width))
378 return getIntAttr(APInt::getZero(width), getContext());
383LogicalResult ShrUOp::canonicalize(
ShrUOp op, PatternRewriter &rewriter) {
389 if (!matchPattern(op.getRhs(), m_ConstantInt(&value)))
392 unsigned width = cast<IntegerType>(op.getLhs().getType()).getWidth();
393 if (value.ugt(width))
395 unsigned shift = value.getZExtValue();
398 if (width <= shift || shift == 0)
408 replaceOpWithNewOpAndCopyNamehint<ConcatOp>(rewriter, op, zeros, extract);
412OpFoldResult ShrSOp::fold(FoldAdaptor adaptor) {
416 if (
auto rhs = dyn_cast_or_null<IntegerAttr>(adaptor.getRhs()))
417 if (rhs.getValue().isZero())
418 return getOperand(0);
422LogicalResult ShrSOp::canonicalize(
ShrSOp op, PatternRewriter &rewriter) {
428 if (!matchPattern(op.getRhs(), m_ConstantInt(&value)))
431 unsigned width = cast<IntegerType>(op.getLhs().getType()).getWidth();
432 if (value.ugt(width))
434 unsigned shift = value.getZExtValue();
437 rewriter.createOrFold<
ExtractOp>(op.getLoc(), op.getLhs(), width - 1, 1);
438 auto sext = rewriter.createOrFold<ReplicateOp>(op.getLoc(), topbit, shift);
440 if (width == shift) {
448 replaceOpWithNewOpAndCopyNamehint<ConcatOp>(rewriter, op, sext, extract);
456OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
461 if (getInput().getType() == getType())
465 if (
auto input = dyn_cast_or_null<IntegerAttr>(adaptor.getInput())) {
466 unsigned dstWidth = cast<IntegerType>(getType()).getWidth();
467 return getIntAttr(input.getValue().lshr(getLowBit()).trunc(dstWidth),
478 PatternRewriter &rewriter) {
479 auto reversedConcatArgs = llvm::reverse(innerCat.getInputs());
480 size_t beginOfFirstRelevantElement = 0;
481 auto it = reversedConcatArgs.begin();
482 size_t lowBit = op.getLowBit();
485 for (; it != reversedConcatArgs.end(); it++) {
486 assert(beginOfFirstRelevantElement <= lowBit &&
487 "incorrectly moved past an element that lowBit has coverage over");
490 size_t operandWidth = operand.getType().getIntOrFloatBitWidth();
491 if (lowBit < beginOfFirstRelevantElement + operandWidth) {
515 beginOfFirstRelevantElement += operandWidth;
517 assert(it != reversedConcatArgs.end() &&
518 "incorrectly failed to find an element which contains coverage of "
521 SmallVector<Value> reverseConcatArgs;
522 size_t widthRemaining = cast<IntegerType>(op.getType()).getWidth();
523 size_t extractLo = lowBit - beginOfFirstRelevantElement;
528 for (; widthRemaining != 0 && it != reversedConcatArgs.end(); it++) {
529 auto concatArg = *it;
530 size_t operandWidth = concatArg.getType().getIntOrFloatBitWidth();
531 size_t widthToConsume = std::min(widthRemaining, operandWidth - extractLo);
533 if (widthToConsume == operandWidth && extractLo == 0) {
534 reverseConcatArgs.push_back(concatArg);
536 auto resultType = IntegerType::get(rewriter.getContext(), widthToConsume);
537 reverseConcatArgs.push_back(
541 widthRemaining -= widthToConsume;
547 if (reverseConcatArgs.size() == 1) {
550 replaceOpWithNewOpAndCopyNamehint<ConcatOp>(
551 rewriter, op, SmallVector<Value>(llvm::reverse(reverseConcatArgs)));
558 PatternRewriter &rewriter) {
559 auto extractResultWidth = cast<IntegerType>(op.getType()).getWidth();
560 auto replicateEltWidth =
561 replicate.getOperand().getType().getIntOrFloatBitWidth();
565 if (op.getLowBit() % replicateEltWidth == 0 &&
566 extractResultWidth % replicateEltWidth == 0) {
567 replaceOpWithNewOpAndCopyNamehint<ReplicateOp>(rewriter, op, op.getType(),
568 replicate.getOperand());
574 if (op.getLowBit() % replicateEltWidth + extractResultWidth <=
576 replaceOpWithNewOpAndCopyNamehint<ExtractOp>(
577 rewriter, op, op.getType(), replicate.getOperand(),
578 op.getLowBit() % replicateEltWidth);
587LogicalResult ExtractOp::canonicalize(
ExtractOp op, PatternRewriter &rewriter) {
590 auto *inputOp = op.getInput().getDefiningOp();
597 .extractBits(cast<IntegerType>(op.getType()).getWidth(),
599 if (knownBits.isConstant()) {
600 replaceOpWithNewOpAndCopyNamehint<hw::ConstantOp>(rewriter, op,
601 knownBits.getConstant());
607 if (
auto innerExtract = dyn_cast_or_null<ExtractOp>(inputOp)) {
608 replaceOpWithNewOpAndCopyNamehint<ExtractOp>(
609 rewriter, op, op.getType(), innerExtract.getInput(),
610 innerExtract.getLowBit() + op.getLowBit());
615 if (
auto innerCat = dyn_cast_or_null<ConcatOp>(inputOp))
619 if (
auto replicate = dyn_cast_or_null<ReplicateOp>(inputOp))
625 if (inputOp && inputOp->getNumOperands() == 2 &&
626 isa<AndOp, OrOp, XorOp>(inputOp)) {
627 if (
auto cstRHS = inputOp->getOperand(1).getDefiningOp<
hw::ConstantOp>()) {
628 auto extractedCst = cstRHS.getValue().extractBits(
629 cast<IntegerType>(op.getType()).getWidth(), op.getLowBit());
630 if (isa<OrOp, XorOp>(inputOp) && extractedCst.isZero()) {
631 replaceOpWithNewOpAndCopyNamehint<ExtractOp>(
632 rewriter, op, op.getType(), inputOp->getOperand(0), op.getLowBit());
640 if (isa<AndOp>(inputOp)) {
643 unsigned lz = extractedCst.countLeadingZeros();
644 unsigned tz = extractedCst.countTrailingZeros();
645 unsigned pop = extractedCst.popcount();
646 if (extractedCst.getBitWidth() - lz - tz == pop) {
647 auto resultTy = rewriter.getIntegerType(pop);
648 SmallVector<Value> resultElts;
651 APInt::getZero(lz)));
652 resultElts.push_back(rewriter.createOrFold<
ExtractOp>(
653 op.getLoc(), resultTy, inputOp->getOperand(0),
654 op.getLowBit() + tz));
657 APInt::getZero(tz)));
658 replaceOpWithNewOpAndCopyNamehint<ConcatOp>(rewriter, op, resultElts);
667 if (cast<IntegerType>(op.getType()).getWidth() == 1 && inputOp)
668 if (
auto shlOp = dyn_cast<ShlOp>(inputOp)) {
670 if (shlOp->hasOneUse())
672 if (lhsCst.getValue().isOne()) {
674 rewriter, shlOp.getLoc(),
675 APInt(lhsCst.getValue().getBitWidth(), op.getLowBit()));
676 replaceOpWithNewOpAndCopyNamehint<ICmpOp>(
677 rewriter, op, ICmpPredicate::eq, shlOp->getOperand(1), newCst,
693 hw::PEO paramOpcode) {
694 assert(operands.size() > 1 &&
"caller should handle one-operand case");
697 if (!operands[1] || !operands[0])
701 if (llvm::all_of(operands.drop_front(2),
702 [&](Attribute in) { return !!in; })) {
703 SmallVector<mlir::TypedAttr> typedOperands;
704 typedOperands.reserve(operands.size());
705 for (
auto operand : operands) {
706 if (
auto typedOperand = dyn_cast<mlir::TypedAttr>(operand))
707 typedOperands.push_back(typedOperand);
711 if (typedOperands.size() == operands.size())
712 return hw::ParamExprAttr::get(paramOpcode, typedOperands);
728 size_t concatIdx,
const APInt &cst,
729 PatternRewriter &rewriter) {
730 auto concatOp = logicalOp->getOperand(concatIdx).getDefiningOp<
ConcatOp>();
731 assert((isa<AndOp, OrOp, XorOp>(logicalOp) && concatOp));
736 llvm::any_of(concatOp->getOperands(), [&](Value operand) ->
bool {
737 auto *operandOp = operand.getDefiningOp();
742 if (isa<hw::ConstantOp>(operandOp))
746 return operandOp->getName() == logicalOp->getName() &&
747 operandOp->hasOneUse() && operandOp->getNumOperands() != 0 &&
748 operandOp->getOperands().back().getDefiningOp<hw::ConstantOp>();
756 auto createLogicalOp = [&](ArrayRef<Value> operands) -> Value {
757 return createGenericOp(logicalOp->getLoc(), logicalOp->getName(), operands,
764 SmallVector<Value> newConcatOperands;
765 newConcatOperands.reserve(concatOp->getNumOperands());
768 size_t nextOperandBit = concatOp.getType().getIntOrFloatBitWidth();
769 for (Value operand : concatOp->getOperands()) {
770 size_t operandWidth = operand.getType().getIntOrFloatBitWidth();
771 nextOperandBit -= operandWidth;
775 cst.lshr(nextOperandBit).trunc(operandWidth));
777 newConcatOperands.push_back(createLogicalOp({operand, eltCst}));
782 ConcatOp::create(rewriter, concatOp.getLoc(), newConcatOperands);
786 if (logicalOp->getNumOperands() > 2) {
787 auto origOperands = logicalOp->getOperands();
788 SmallVector<Value> operands;
790 operands.append(origOperands.begin(), origOperands.begin() + concatIdx);
792 operands.append(origOperands.begin() + concatIdx + 1,
793 origOperands.begin() + (origOperands.size() - 1));
795 operands.push_back(newResult);
796 newResult = createLogicalOp(operands);
806 llvm::SmallDenseSet<std::tuple<ICmpPredicate, Value, Value>> seenPredicates;
808 for (
auto op : operands) {
809 if (
auto icmpOp = op.getDefiningOp<ICmpOp>();
810 icmpOp && icmpOp.getTwoState()) {
811 auto predicate = icmpOp.getPredicate();
812 auto lhs = icmpOp.getLhs();
813 auto rhs = icmpOp.getRhs();
814 if (seenPredicates.contains(
815 {ICmpOp::getNegatedPredicate(predicate), lhs, rhs}))
818 seenPredicates.insert({predicate, lhs, rhs});
824OpFoldResult AndOp::fold(FoldAdaptor adaptor) {
828 APInt value = APInt::getAllOnes(cast<IntegerType>(getType()).
getWidth());
830 auto inputs = adaptor.getInputs();
833 for (
auto operand : inputs) {
834 auto attr = dyn_cast_or_null<IntegerAttr>(operand);
837 value &= attr.getValue();
843 if (inputs.size() == 2)
844 if (
auto intAttr = dyn_cast_or_null<IntegerAttr>(inputs[1]))
845 if (intAttr.getValue().isAllOnes())
846 return getInputs()[0];
849 if (llvm::all_of(getInputs(),
850 [&](
auto in) {
return in == this->getInputs()[0]; }))
851 return getInputs()[0];
854 for (Value arg : getInputs()) {
857 for (Value arg2 : getInputs())
860 APInt::getZero(cast<IntegerType>(getType()).
getWidth()),
881template <
typename Op>
883 if (!op.getType().isInteger(1))
886 auto inputs = op.getInputs();
887 size_t size = inputs.size();
889 auto sourceOp = inputs[0].template getDefiningOp<ExtractOp>();
892 Value source = sourceOp.getOperand();
895 if (size != source.getType().getIntOrFloatBitWidth())
899 llvm::BitVector bits(size);
900 bits.set(sourceOp.getLowBit());
902 for (
size_t i = 1; i != size; ++i) {
903 auto extractOp = inputs[i].template getDefiningOp<ExtractOp>();
904 if (!extractOp || extractOp.getOperand() != source)
906 bits.set(extractOp.getLowBit());
909 return bits.all() ? source : Value();
916template <
typename Op>
919 constexpr unsigned limit = 3;
920 auto inputs = op.getInputs();
922 llvm::SmallSetVector<Value, 8> uniqueInputs(inputs.begin(), inputs.end());
923 llvm::SmallDenseSet<Op, 8> checked;
930 llvm::SmallVector<OpWithDepth, 8> worklist;
932 auto enqueue = [&worklist, &checked, &op](Value input,
unsigned depth) {
936 if (depth < limit && input.getParentBlock() == op->getBlock()) {
937 auto inputOp = input.template getDefiningOp<Op>();
938 if (inputOp && inputOp.getTwoState() == op.getTwoState() &&
939 checked.insert(inputOp).second)
940 worklist.push_back({inputOp, depth + 1});
944 for (
auto input : uniqueInputs)
947 while (!worklist.empty()) {
948 auto item = worklist.pop_back_val();
950 for (
auto input : item.op.getInputs()) {
951 uniqueInputs.remove(input);
952 enqueue(input, item.depth);
956 if (uniqueInputs.size() < inputs.size()) {
957 replaceOpWithNewOpAndCopyNamehint<Op>(rewriter, op, op.getType(),
958 uniqueInputs.getArrayRef(),
966LogicalResult AndOp::canonicalize(
AndOp op, PatternRewriter &rewriter) {
970 auto inputs = op.getInputs();
971 auto size = inputs.size();
983 assert(size > 1 &&
"expected 2 or more operands, `fold` should handle this");
987 if (matchPattern(inputs.back(), m_ConstantInt(&value))) {
989 if (value.isAllOnes()) {
990 replaceOpWithNewOpAndCopyNamehint<AndOp>(rewriter, op, op.getType(),
991 inputs.drop_back(),
false);
999 if (matchPattern(inputs[size - 2], m_ConstantInt(&value2))) {
1001 SmallVector<Value, 4> newOperands(inputs.drop_back(2));
1002 newOperands.push_back(cst);
1003 replaceOpWithNewOpAndCopyNamehint<AndOp>(rewriter, op, op.getType(),
1004 newOperands,
false);
1009 if (size == 2 && value.isPowerOf2()) {
1014 if (
auto replicate = inputs[0].getDefiningOp<ReplicateOp>()) {
1015 auto replicateOperand = replicate.getOperand();
1016 if (replicateOperand.getType().isInteger(1)) {
1017 unsigned resultWidth = op.getType().getIntOrFloatBitWidth();
1018 auto trailingZeros = value.countTrailingZeros();
1021 SmallVector<Value, 3> concatOperands;
1022 if (trailingZeros != resultWidth - 1) {
1024 rewriter, op.getLoc(),
1025 APInt::getZero(resultWidth - trailingZeros - 1));
1026 concatOperands.push_back(highZeros);
1028 concatOperands.push_back(replicateOperand);
1029 if (trailingZeros != 0) {
1031 rewriter, op.getLoc(), APInt::getZero(trailingZeros));
1032 concatOperands.push_back(lowZeros);
1034 replaceOpWithNewOpAndCopyNamehint<ConcatOp>(
1035 rewriter, op, op.getType(), concatOperands);
1044 unsigned leadingZeros = value.countLeadingZeros();
1045 unsigned trailingZeros = value.countTrailingZeros();
1046 if (leadingZeros > 0 || trailingZeros > 0) {
1047 unsigned maskLength = value.getBitWidth() - leadingZeros - trailingZeros;
1050 SmallVector<Value> operands;
1051 for (
auto input : inputs.drop_back()) {
1052 unsigned offset = trailingZeros;
1053 while (
auto extractOp = input.getDefiningOp<
ExtractOp>()) {
1054 input = extractOp.getInput();
1055 offset += extractOp.getLowBit();
1058 offset, maskLength));
1062 auto narrowMask = value.extractBits(maskLength, trailingZeros);
1063 if (!narrowMask.isAllOnes())
1065 rewriter, inputs.back().getLoc(), narrowMask));
1068 Value narrowValue = operands.back();
1069 if (operands.size() > 1)
1071 AndOp::create(rewriter, op.getLoc(), operands, op.getTwoState());
1075 if (leadingZeros > 0)
1077 rewriter, op.getLoc(), APInt::getZero(leadingZeros)));
1078 operands.push_back(narrowValue);
1079 if (trailingZeros > 0)
1081 rewriter, op.getLoc(), APInt::getZero(trailingZeros)));
1082 replaceOpWithNewOpAndCopyNamehint<ConcatOp>(rewriter, op, operands);
1089 for (
size_t i = 0; i < size - 1; ++i) {
1090 if (
auto concat = inputs[i].getDefiningOp<ConcatOp>())
1104 replaceOpWithNewOpAndCopyNamehint<ICmpOp>(rewriter, op, ICmpPredicate::eq,
1105 source, cmpAgainst);
1113OpFoldResult OrOp::fold(FoldAdaptor adaptor) {
1117 auto value = APInt::getZero(cast<IntegerType>(getType()).
getWidth());
1118 auto inputs = adaptor.getInputs();
1120 for (
auto operand : inputs) {
1121 auto attr = dyn_cast_or_null<IntegerAttr>(operand);
1124 value |= attr.getValue();
1125 if (value.isAllOnes())
1130 if (inputs.size() == 2)
1131 if (
auto intAttr = dyn_cast_or_null<IntegerAttr>(inputs[1]))
1132 if (intAttr.getValue().isZero())
1133 return getInputs()[0];
1136 if (llvm::all_of(getInputs(),
1137 [&](
auto in) {
return in == this->getInputs()[0]; }))
1138 return getInputs()[0];
1141 for (Value arg : getInputs()) {
1143 if (matchPattern(arg,
m_Complement(m_Any(&subExpr)))) {
1144 for (Value arg2 : getInputs())
1145 if (arg2 == subExpr)
1147 APInt::getAllOnes(cast<IntegerType>(getType()).
getWidth()),
1157 APInt::getAllOnes(cast<IntegerType>(getType()).
getWidth()),
1164LogicalResult OrOp::canonicalize(
OrOp op, PatternRewriter &rewriter) {
1168 auto inputs = op.getInputs();
1169 auto size = inputs.size();
1181 assert(size > 1 &&
"expected 2 or more operands");
1185 if (matchPattern(inputs.back(), m_ConstantInt(&value))) {
1187 if (value.isZero()) {
1188 replaceOpWithNewOpAndCopyNamehint<OrOp>(rewriter, op, op.getType(),
1189 inputs.drop_back());
1195 if (matchPattern(inputs[size - 2], m_ConstantInt(&value2))) {
1197 SmallVector<Value, 4> newOperands(inputs.drop_back(2));
1198 newOperands.push_back(cst);
1199 replaceOpWithNewOpAndCopyNamehint<OrOp>(rewriter, op, op.getType(),
1207 for (
size_t i = 0; i < size - 1; ++i) {
1208 if (
auto concat = inputs[i].getDefiningOp<ConcatOp>())
1222 replaceOpWithNewOpAndCopyNamehint<ICmpOp>(rewriter, op, ICmpPredicate::ne,
1223 source, cmpAgainst);
1229 if (
auto firstMux = op.getOperand(0).getDefiningOp<
comb::MuxOp>()) {
1231 if (op.getTwoState() && firstMux.getTwoState() &&
1232 matchPattern(firstMux.getFalseValue(), m_ConstantInt(&value)) &&
1234 SmallVector<Value> conditions{firstMux.getCond()};
1235 auto check = [&](Value v) {
1239 conditions.push_back(mux.getCond());
1240 return mux.getTwoState() &&
1241 firstMux.getTrueValue() == mux.getTrueValue() &&
1242 firstMux.getFalseValue() == mux.getFalseValue();
1244 if (llvm::all_of(op.getOperands().drop_front(), check)) {
1245 auto cond = comb::OrOp::create(rewriter, op.getLoc(), conditions,
true);
1246 replaceOpWithNewOpAndCopyNamehint<comb::MuxOp>(
1247 rewriter, op, cond, firstMux.getTrueValue(),
1248 firstMux.getFalseValue(),
true);
1258OpFoldResult XorOp::fold(FoldAdaptor adaptor) {
1262 auto size = getInputs().size();
1263 auto inputs = adaptor.getInputs();
1267 return getInputs()[0];
1270 if (size == 2 && getInputs()[0] == getInputs()[1])
1271 return IntegerAttr::get(getType(), 0);
1274 if (inputs.size() == 2)
1275 if (
auto intAttr = dyn_cast_or_null<IntegerAttr>(inputs[1]))
1276 if (intAttr.getValue().isZero())
1277 return getInputs()[0];
1283 subExpr != getResult())
1292 PatternRewriter &rewriter) {
1293 auto icmp = op.getOperand(icmpOperand).getDefiningOp<ICmpOp>();
1294 auto negatedPred = ICmpOp::getNegatedPredicate(icmp.getPredicate());
1297 ICmpOp::create(rewriter, icmp.getLoc(), negatedPred, icmp.getOperand(0),
1298 icmp.getOperand(1), icmp.getTwoState());
1301 if (op.getNumOperands() > 2) {
1302 SmallVector<Value, 4> newOperands(op.getOperands());
1303 newOperands.pop_back();
1304 newOperands.erase(newOperands.begin() + icmpOperand);
1305 newOperands.push_back(result);
1307 XorOp::create(rewriter, op.getLoc(), newOperands, op.getTwoState());
1313LogicalResult XorOp::canonicalize(
XorOp op, PatternRewriter &rewriter) {
1317 auto inputs = op.getInputs();
1318 auto size = inputs.size();
1319 assert(size > 1 &&
"expected 2 or more operands");
1322 if (inputs[size - 1] == inputs[size - 2]) {
1324 "expected idempotent case for 2 elements handled already.");
1325 replaceOpWithNewOpAndCopyNamehint<XorOp>(rewriter, op, op.getType(),
1326 inputs.drop_back(2),
false);
1332 if (matchPattern(inputs.back(), m_ConstantInt(&value))) {
1334 if (value.isZero()) {
1335 replaceOpWithNewOpAndCopyNamehint<XorOp>(rewriter, op, op.getType(),
1336 inputs.drop_back(),
false);
1342 if (matchPattern(inputs[size - 2], m_ConstantInt(&value2))) {
1344 SmallVector<Value, 4> newOperands(inputs.drop_back(2));
1345 newOperands.push_back(cst);
1346 replaceOpWithNewOpAndCopyNamehint<XorOp>(rewriter, op, op.getType(),
1347 newOperands,
false);
1351 bool isSingleBit = value.getBitWidth() == 1;
1354 for (
size_t i = 0; i < size - 1; ++i) {
1355 Value operand = inputs[i];
1361 if (
auto concat = operand.getDefiningOp<
ConcatOp>())
1366 if (isSingleBit && operand.hasOneUse()) {
1367 assert(value == 1 &&
"single bit constant has to be one if not zero");
1368 if (
auto icmp = operand.getDefiningOp<ICmpOp>())
1378 if (matchPattern(op.getResult(),
m_Complement(m_Sext(m_Any(&base))))) {
1397 replaceOpWithNewOpAndCopyNamehint<ParityOp>(rewriter, op, source);
1404OpFoldResult SubOp::fold(FoldAdaptor adaptor) {
1409 if (getRhs() == getLhs())
1411 APInt::getZero(getLhs().getType().getIntOrFloatBitWidth()),
1414 if (adaptor.getRhs()) {
1416 if (adaptor.getLhs()) {
1419 APInt::getAllOnes(getLhs().getType().getIntOrFloatBitWidth()),
1421 auto rhsNeg = hw::ParamExprAttr::get(
1422 hw::PEO::Mul, cast<TypedAttr>(adaptor.getRhs()), negOne);
1423 return hw::ParamExprAttr::get(hw::PEO::Add,
1424 cast<TypedAttr>(adaptor.getLhs()), rhsNeg);
1428 if (
auto rhsC = dyn_cast<IntegerAttr>(adaptor.getRhs())) {
1429 if (rhsC.getValue().isZero())
1437LogicalResult SubOp::canonicalize(
SubOp op, PatternRewriter &rewriter) {
1443 if (matchPattern(op.getRhs(), m_ConstantInt(&value))) {
1445 replaceOpWithNewOpAndCopyNamehint<AddOp>(rewriter, op, op.getLhs(), negCst,
1457OpFoldResult AddOp::fold(FoldAdaptor adaptor) {
1461 auto size = getInputs().size();
1465 return getInputs()[0];
1471LogicalResult AddOp::canonicalize(
AddOp op, PatternRewriter &rewriter) {
1475 auto inputs = op.getInputs();
1476 auto size = inputs.size();
1477 assert(size > 1 &&
"expected 2 or more operands");
1479 APInt value, value2;
1482 if (matchPattern(inputs.back(), m_ConstantInt(&value)) && value.isZero()) {
1483 replaceOpWithNewOpAndCopyNamehint<AddOp>(rewriter, op, op.getType(),
1484 inputs.drop_back(),
false);
1489 if (matchPattern(inputs[size - 1], m_ConstantInt(&value)) &&
1490 matchPattern(inputs[size - 2], m_ConstantInt(&value2))) {
1492 SmallVector<Value, 4> newOperands(inputs.drop_back(2));
1493 newOperands.push_back(cst);
1494 replaceOpWithNewOpAndCopyNamehint<AddOp>(rewriter, op, op.getType(),
1495 newOperands,
false);
1500 if (inputs[size - 1] == inputs[size - 2]) {
1501 SmallVector<Value, 4> newOperands(inputs.drop_back(2));
1505 comb::ShlOp::create(rewriter, op.getLoc(), inputs.back(), one,
false);
1507 newOperands.push_back(shiftLeftOp);
1508 replaceOpWithNewOpAndCopyNamehint<AddOp>(rewriter, op, op.getType(),
1509 newOperands,
false);
1513 auto shlOp = inputs[size - 1].getDefiningOp<
comb::ShlOp>();
1515 if (shlOp && shlOp.getLhs() == inputs[size - 2] &&
1516 matchPattern(shlOp.getRhs(), m_ConstantInt(&value))) {
1518 APInt one(value.getBitWidth(), 1,
false);
1522 std::array<Value, 2> factors = {shlOp.getLhs(), rhs};
1523 auto mulOp = comb::MulOp::create(rewriter, op.getLoc(), factors,
false);
1525 SmallVector<Value, 4> newOperands(inputs.drop_back(2));
1526 newOperands.push_back(mulOp);
1527 replaceOpWithNewOpAndCopyNamehint<AddOp>(rewriter, op, op.getType(),
1528 newOperands,
false);
1532 auto mulOp = inputs[size - 1].getDefiningOp<
comb::MulOp>();
1534 if (mulOp && mulOp.getInputs().size() == 2 &&
1535 mulOp.getInputs()[0] == inputs[size - 2] &&
1536 matchPattern(mulOp.getInputs()[1], m_ConstantInt(&value))) {
1538 APInt one(value.getBitWidth(), 1,
false);
1540 std::array<Value, 2> factors = {mulOp.getInputs()[0], rhs};
1541 auto newMulOp = comb::MulOp::create(rewriter, op.getLoc(), factors,
false);
1543 SmallVector<Value, 4> newOperands(inputs.drop_back(2));
1544 newOperands.push_back(newMulOp);
1545 replaceOpWithNewOpAndCopyNamehint<AddOp>(rewriter, op, op.getType(),
1546 newOperands,
false);
1559 auto addOp = inputs[0].getDefiningOp<
comb::AddOp>();
1560 if (addOp && addOp.getInputs().size() == 2 &&
1561 matchPattern(addOp.getInputs()[1], m_ConstantInt(&value2)) &&
1562 inputs.size() == 2 && matchPattern(inputs[1], m_ConstantInt(&value))) {
1565 replaceOpWithNewOpAndCopyNamehint<AddOp>(
1566 rewriter, op, op.getType(), ArrayRef<Value>{addOp.getInputs()[0], rhs},
1567 op.getTwoState() && addOp.getTwoState());
1574OpFoldResult MulOp::fold(FoldAdaptor adaptor) {
1578 auto size = getInputs().size();
1579 auto inputs = adaptor.getInputs();
1583 return getInputs()[0];
1585 auto width = cast<IntegerType>(getType()).getWidth();
1587 return getIntAttr(APInt::getZero(0), getContext());
1589 APInt value(width, 1,
false);
1592 for (
auto operand : inputs) {
1593 auto attr = dyn_cast_or_null<IntegerAttr>(operand);
1596 value *= attr.getValue();
1605LogicalResult MulOp::canonicalize(
MulOp op, PatternRewriter &rewriter) {
1609 auto inputs = op.getInputs();
1610 auto size = inputs.size();
1611 assert(size > 1 &&
"expected 2 or more operands");
1613 APInt value, value2;
1616 if (size == 2 && matchPattern(inputs.back(), m_ConstantInt(&value)) &&
1617 value.isPowerOf2()) {
1619 value.exactLogBase2());
1621 comb::ShlOp::create(rewriter, op.getLoc(), inputs[0], shift,
false);
1623 replaceOpWithNewOpAndCopyNamehint<MulOp>(rewriter, op, op.getType(),
1624 ArrayRef<Value>(shlOp),
false);
1629 if (matchPattern(inputs.back(), m_ConstantInt(&value)) && value.isOne()) {
1630 replaceOpWithNewOpAndCopyNamehint<MulOp>(rewriter, op, op.getType(),
1631 inputs.drop_back());
1636 if (matchPattern(inputs[size - 1], m_ConstantInt(&value)) &&
1637 matchPattern(inputs[size - 2], m_ConstantInt(&value2))) {
1639 SmallVector<Value, 4> newOperands(inputs.drop_back(2));
1640 newOperands.push_back(cst);
1641 replaceOpWithNewOpAndCopyNamehint<MulOp>(rewriter, op, op.getType(),
1657template <
class Op,
bool isSigned>
1658static OpFoldResult
foldDiv(Op op, ArrayRef<Attribute> constants) {
1659 if (
auto rhsValue = dyn_cast_or_null<IntegerAttr>(constants[1])) {
1661 if (rhsValue.getValue() == 1)
1665 if (rhsValue.getValue().isZero())
1672OpFoldResult DivUOp::fold(FoldAdaptor adaptor) {
1675 return foldDiv<
DivUOp,
false>(*
this, adaptor.getOperands());
1678OpFoldResult DivSOp::fold(FoldAdaptor adaptor) {
1684template <
class Op,
bool isSigned>
1685static OpFoldResult
foldMod(Op op, ArrayRef<Attribute> constants) {
1686 if (
auto rhsValue = dyn_cast_or_null<IntegerAttr>(constants[1])) {
1688 if (rhsValue.getValue() == 1)
1689 return getIntAttr(APInt::getZero(op.getType().getIntOrFloatBitWidth()),
1693 if (rhsValue.getValue().isZero())
1697 if (
auto lhsValue = dyn_cast_or_null<IntegerAttr>(constants[0])) {
1699 if (lhsValue.getValue().isZero())
1700 return getIntAttr(APInt::getZero(op.getType().getIntOrFloatBitWidth()),
1707OpFoldResult ModUOp::fold(FoldAdaptor adaptor) {
1710 return foldMod<
ModUOp,
false>(*
this, adaptor.getOperands());
1713OpFoldResult ModSOp::fold(FoldAdaptor adaptor) {
1719LogicalResult DivUOp::canonicalize(
DivUOp op, PatternRewriter &rewriter) {
1725LogicalResult ModUOp::canonicalize(
ModUOp op, PatternRewriter &rewriter) {
1737OpFoldResult ConcatOp::fold(FoldAdaptor adaptor) {
1741 if (getNumOperands() == 1)
1742 return getOperand(0);
1745 for (
auto attr : adaptor.getInputs())
1746 if (!attr || !isa<IntegerAttr>(attr))
1750 unsigned resultWidth = getType().getIntOrFloatBitWidth();
1751 APInt result(resultWidth, 0);
1753 unsigned nextInsertion = resultWidth;
1755 for (
auto attr : adaptor.getInputs()) {
1756 auto chunk = cast<IntegerAttr>(attr).getValue();
1757 nextInsertion -= chunk.getBitWidth();
1758 result.insertBits(chunk, nextInsertion);
1764LogicalResult ConcatOp::canonicalize(
ConcatOp op, PatternRewriter &rewriter) {
1768 auto inputs = op.getInputs();
1769 auto size = inputs.size();
1770 assert(size > 1 &&
"expected 2 or more operands");
1775 auto flattenConcat = [&](
size_t firstOpIndex,
size_t lastOpIndex,
1776 ValueRange replacements) -> LogicalResult {
1777 SmallVector<Value, 4> newOperands;
1778 newOperands.append(inputs.begin(), inputs.begin() + firstOpIndex);
1779 newOperands.append(replacements.begin(), replacements.end());
1780 newOperands.append(inputs.begin() + lastOpIndex + 1, inputs.end());
1781 if (newOperands.size() == 1)
1784 replaceOpWithNewOpAndCopyNamehint<ConcatOp>(rewriter, op, op.getType(),
1789 Value commonOperand = inputs[0];
1790 for (
size_t i = 0; i != size; ++i) {
1792 if (inputs[i] != commonOperand)
1793 commonOperand = Value();
1797 if (
auto subConcat = inputs[i].getDefiningOp<ConcatOp>())
1798 return flattenConcat(i, i, subConcat->getOperands());
1803 if (
auto cst = inputs[i].getDefiningOp<hw::ConstantOp>()) {
1804 if (
auto prevCst = inputs[i - 1].getDefiningOp<hw::ConstantOp>()) {
1805 unsigned prevWidth = prevCst.getValue().getBitWidth();
1806 unsigned thisWidth = cst.getValue().getBitWidth();
1807 auto resultCst = cst.getValue().zext(prevWidth + thisWidth);
1808 resultCst |= prevCst.getValue().zext(prevWidth + thisWidth)
1812 return flattenConcat(i - 1, i, replacement);
1817 if (inputs[i] == inputs[i - 1]) {
1819 rewriter.createOrFold<ReplicateOp>(op.getLoc(), inputs[i], 2);
1820 return flattenConcat(i - 1, i, replacement);
1825 if (
auto repl = inputs[i].getDefiningOp<ReplicateOp>()) {
1827 if (repl.getOperand() == inputs[i - 1]) {
1828 Value replacement = rewriter.createOrFold<ReplicateOp>(
1829 op.getLoc(), repl.getOperand(), repl.getMultiple() + 1);
1830 return flattenConcat(i - 1, i, replacement);
1833 if (
auto prevRepl = inputs[i - 1].getDefiningOp<ReplicateOp>()) {
1834 if (prevRepl.getOperand() == repl.getOperand()) {
1835 Value replacement = rewriter.createOrFold<ReplicateOp>(
1836 op.getLoc(), repl.getOperand(),
1837 repl.getMultiple() + prevRepl.getMultiple());
1838 return flattenConcat(i - 1, i, replacement);
1844 if (
auto repl = inputs[i - 1].getDefiningOp<ReplicateOp>()) {
1845 if (repl.getOperand() == inputs[i]) {
1846 Value replacement = rewriter.createOrFold<ReplicateOp>(
1847 op.getLoc(), inputs[i], repl.getMultiple() + 1);
1848 return flattenConcat(i - 1, i, replacement);
1854 if (
auto extract = inputs[i].getDefiningOp<ExtractOp>()) {
1855 if (
auto prevExtract = inputs[i - 1].getDefiningOp<ExtractOp>()) {
1856 if (extract.getInput() == prevExtract.getInput()) {
1857 auto thisWidth = cast<IntegerType>(extract.getType()).getWidth();
1858 if (prevExtract.getLowBit() == extract.getLowBit() + thisWidth) {
1859 auto prevWidth = prevExtract.getType().getIntOrFloatBitWidth();
1860 auto resType = rewriter.getIntegerType(thisWidth + prevWidth);
1863 extract.getInput(), extract.getLowBit());
1864 return flattenConcat(i - 1, i, replacement);
1877 static std::optional<ArraySlice>
get(Value value) {
1878 assert(isa<IntegerType>(value.getType()) &&
"expected integer type");
1880 return ArraySlice{arrayGet.getInput(), arrayGet.getIndex(), 1};
1883 if (
auto arraySlice =
1886 arraySlice.getInput(), arraySlice.getLowIndex(),
1887 hw::type_cast<hw::ArrayType>(arraySlice.getType())
1889 return std::nullopt;
1892 if (
auto extractOpt = ArraySlice::get(inputs[i])) {
1893 if (
auto prevExtractOpt = ArraySlice::get(inputs[i - 1])) {
1895 if (prevExtractOpt->index.getType() == extractOpt->index.getType() &&
1896 prevExtractOpt->input == extractOpt->input &&
1897 hw::isOffset(extractOpt->index, prevExtractOpt->index,
1898 extractOpt->width)) {
1899 auto resType = hw::ArrayType::get(
1900 hw::type_cast<hw::ArrayType>(prevExtractOpt->input.getType())
1902 extractOpt->width + prevExtractOpt->width);
1903 auto resIntType = rewriter.getIntegerType(hw::getBitWidth(resType));
1905 rewriter, op.getLoc(), resIntType,
1907 prevExtractOpt->input,
1908 extractOpt->index));
1909 return flattenConcat(i - 1, i, replacement);
1917 if (commonOperand) {
1918 replaceOpWithNewOpAndCopyNamehint<ReplicateOp>(rewriter, op, op.getType(),
1930OpFoldResult MuxOp::fold(FoldAdaptor adaptor) {
1935 if (getTrueValue() == getFalseValue() && getTrueValue() != getResult())
1936 return getTrueValue();
1937 if (
auto tv = adaptor.getTrueValue())
1938 if (tv == adaptor.getFalseValue())
1943 if (
auto pred = dyn_cast_or_null<IntegerAttr>(adaptor.getCond())) {
1944 if (pred.getValue().isZero() && getFalseValue() != getResult())
1945 return getFalseValue();
1946 if (pred.getValue().isOne() && getTrueValue() != getResult())
1947 return getTrueValue();
1951 if (getCond().getType() == getTrueValue().getType())
1952 if (
auto tv = dyn_cast_or_null<IntegerAttr>(adaptor.getTrueValue()))
1953 if (
auto fv = dyn_cast_or_null<IntegerAttr>(adaptor.getFalseValue()))
1954 if (tv.getValue().isOne() && fv.getValue().isZero() &&
1955 hw::getBitWidth(getType()) == 1 && getCond() != getResult())
1971 if (
auto cmp = cond.getDefiningOp<ICmpOp>()) {
1973 auto requiredPredicate =
1974 (isInverted ? ICmpPredicate::eq : ICmpPredicate::ne);
1975 if (cmp.getLhs() == indexValue && cmp.getPredicate() == requiredPredicate) {
1985 if (
auto orOp = cond.getDefiningOp<
OrOp>()) {
1988 for (
auto operand : orOp.getOperands())
1995 if (
auto andOp = cond.getDefiningOp<
AndOp>()) {
1998 for (
auto operand : andOp.getOperands())
2017 PatternRewriter &rewriter,
MuxOp rootMux,
bool isFalseSide,
2023 auto rootCmp = rootMux.getCond().getDefiningOp<ICmpOp>();
2026 Value indexValue = rootCmp.getLhs();
2029 auto getCaseValue = [&](
MuxOp mux) -> Value {
2030 return mux.getOperand(1 +
unsigned(!isFalseSide));
2035 auto getTreeValue = [&](
MuxOp mux) -> Value {
2036 return mux.getOperand(1 +
unsigned(isFalseSide));
2041 SmallVector<Location> locationsFound;
2042 SmallVector<std::pair<hw::ConstantOp, Value>, 4> valuesFound;
2046 auto collectConstantValues = [&](
MuxOp mux) ->
bool {
2048 mux.getCond(), indexValue, isFalseSide, [&](
hw::ConstantOp cst) {
2049 valuesFound.push_back({cst, getCaseValue(mux)});
2050 locationsFound.push_back(mux.getCond().getLoc());
2051 locationsFound.push_back(mux->getLoc());
2056 if (!collectConstantValues(rootMux))
2060 if (rootMux->hasOneUse()) {
2061 if (
auto userMux = dyn_cast<MuxOp>(*rootMux->user_begin())) {
2062 if (getTreeValue(userMux) == rootMux.getResult() &&
2070 auto nextTreeValue = getTreeValue(rootMux);
2072 auto nextMux = nextTreeValue.getDefiningOp<
MuxOp>();
2073 if (!nextMux || !nextMux->hasOneUse())
2075 if (!collectConstantValues(nextMux))
2077 nextTreeValue = getTreeValue(nextMux);
2080 auto indexWidth = cast<IntegerType>(indexValue.getType()).getWidth();
2082 if (indexWidth > 20)
2085 auto foldingStyle = styleFn(indexWidth, valuesFound.size());
2089 uint64_t tableSize = 1ULL << indexWidth;
2093 SmallVector<Value, 8> table(tableSize, nextTreeValue);
2098 for (
auto &elt :
llvm::reverse(valuesFound)) {
2099 uint64_t idx = elt.first.getValue().getZExtValue();
2100 assert(idx < table.size() &&
"constant should be same bitwidth as index");
2101 table[idx] = elt.second;
2105 SmallVector<Value> bits;
2114 "unknown folding style");
2118 std::reverse(table.begin(), table.end());
2121 auto fusedLoc = rewriter.getFusedLoc(locationsFound);
2123 replaceOpWithNewOpAndCopyNamehint<hw::ArrayGetOp>(rewriter, rootMux, array,
2138 PatternRewriter &rewriter) {
2139 assert(fullyAssoc->getNumOperands() >= 2 &&
"cannot split up unary ops");
2140 assert(operandNo < fullyAssoc->getNumOperands() &&
"Invalid operand #");
2144 if (fullyAssoc->getNumOperands() == 2)
2145 return fullyAssoc->getOperand(operandNo ^ 1);
2148 if (fullyAssoc->hasOneUse()) {
2149 rewriter.modifyOpInPlace(fullyAssoc,
2150 [&]() { fullyAssoc->eraseOperand(operandNo); });
2151 return fullyAssoc->getResult(0);
2155 SmallVector<Value> operands;
2156 operands.append(fullyAssoc->getOperands().begin(),
2157 fullyAssoc->getOperands().begin() + operandNo);
2158 operands.append(fullyAssoc->getOperands().begin() + operandNo + 1,
2159 fullyAssoc->getOperands().end());
2161 fullyAssoc->getLoc(), fullyAssoc->getName(), operands, rewriter);
2162 Value excluded = fullyAssoc->getOperand(operandNo);
2166 ArrayRef<Value>{opWithoutExcluded, excluded}, rewriter);
2168 return opWithoutExcluded;
2178 PatternRewriter &rewriter) {
2181 Operation *subExpr =
2182 (isTrueOperand ? op.getFalseValue() : op.getTrueValue()).getDefiningOp();
2183 if (!subExpr || subExpr->getNumOperands() < 2)
2187 if (!isa<AndOp, XorOp, OrOp, MuxOp>(subExpr))
2192 Value commonValue = isTrueOperand ? op.getTrueValue() : op.getFalseValue();
2193 size_t opNo = 0, e = subExpr->getNumOperands();
2194 while (opNo != e && subExpr->getOperand(opNo) != commonValue)
2200 Value cond = op.getCond();
2206 if (
auto subMux = dyn_cast<MuxOp>(subExpr)) {
2211 Value subCond = subMux.getCond();
2214 if (subMux.getTrueValue() == commonValue)
2215 otherValue = subMux.getFalseValue();
2216 else if (subMux.getFalseValue() == commonValue) {
2217 otherValue = subMux.getTrueValue();
2227 cond = rewriter.createOrFold<
OrOp>(op.getLoc(), cond, subCond,
false);
2228 replaceOpWithNewOpAndCopyNamehint<MuxOp>(rewriter, op, cond, commonValue,
2229 otherValue, op.getTwoState());
2235 bool isaAndOp = isa<AndOp>(subExpr);
2236 if (isTrueOperand ^ isaAndOp)
2240 rewriter.createOrFold<ReplicateOp>(op.getLoc(), op.getType(), cond);
2243 bool isaXorOp = isa<XorOp>(subExpr);
2244 bool isaOrOp = isa<OrOp>(subExpr);
2253 if (isaOrOp || isaXorOp) {
2254 auto masked = rewriter.createOrFold<
AndOp>(op.getLoc(), extendedCond,
2255 restOfAssoc,
false);
2257 replaceOpWithNewOpAndCopyNamehint<XorOp>(rewriter, op, masked,
2258 commonValue,
false);
2260 replaceOpWithNewOpAndCopyNamehint<OrOp>(rewriter, op, masked, commonValue,
2266 assert(isaAndOp &&
"unexpected operation here");
2267 auto masked = rewriter.createOrFold<
OrOp>(op.getLoc(), extendedCond,
2268 restOfAssoc,
false);
2269 replaceOpWithNewOpAndCopyNamehint<AndOp>(rewriter, op, masked, commonValue,
2280 PatternRewriter &rewriter) {
2283 if (!isa<ConcatOp>(trueOp))
2287 SmallVector<Value> trueOperands, falseOperands;
2291 size_t numTrueOperands = trueOperands.size();
2292 size_t numFalseOperands = falseOperands.size();
2294 if (!numTrueOperands || !numFalseOperands ||
2295 (trueOperands.front() != falseOperands.front() &&
2296 trueOperands.back() != falseOperands.back()))
2300 if (trueOperands.front() == falseOperands.front()) {
2301 SmallVector<Value> operands;
2303 for (i = 0; i < numTrueOperands; ++i) {
2304 Value trueOperand = trueOperands[i];
2305 if (trueOperand == falseOperands[i])
2306 operands.push_back(trueOperand);
2310 if (i == numTrueOperands) {
2317 if (llvm::all_of(operands, [&](Value v) {
return v == operands.front(); }))
2318 sharedMSB = rewriter.createOrFold<ReplicateOp>(
2319 mux->getLoc(), operands.front(), operands.size());
2321 sharedMSB = rewriter.createOrFold<
ConcatOp>(mux->getLoc(), operands);
2325 operands.append(trueOperands.begin() + i, trueOperands.end());
2326 Value trueLSB = rewriter.createOrFold<
ConcatOp>(trueOp->getLoc(), operands);
2328 operands.append(falseOperands.begin() + i, falseOperands.end());
2330 rewriter.createOrFold<
ConcatOp>(falseOp->getLoc(), operands);
2333 Value lsb = rewriter.createOrFold<
MuxOp>(
2334 mux->getLoc(), mux.getCond(), trueLSB, falseLSB, mux.getTwoState());
2335 replaceOpWithNewOpAndCopyNamehint<ConcatOp>(rewriter, mux, sharedMSB, lsb);
2340 if (trueOperands.back() == falseOperands.back()) {
2341 SmallVector<Value> operands;
2344 Value trueOperand = trueOperands[numTrueOperands - i - 1];
2345 if (trueOperand == falseOperands[numFalseOperands - i - 1])
2346 operands.push_back(trueOperand);
2350 std::reverse(operands.begin(), operands.end());
2351 Value sharedLSB = rewriter.createOrFold<
ConcatOp>(mux->getLoc(), operands);
2355 operands.append(trueOperands.begin(), trueOperands.end() - i);
2356 Value trueMSB = rewriter.createOrFold<
ConcatOp>(trueOp->getLoc(), operands);
2358 operands.append(falseOperands.begin(), falseOperands.end() - i);
2360 rewriter.createOrFold<
ConcatOp>(falseOp->getLoc(), operands);
2362 Value msb = rewriter.createOrFold<
MuxOp>(
2363 mux->getLoc(), mux.getCond(), trueMSB, falseMSB, mux.getTwoState());
2364 replaceOpWithNewOpAndCopyNamehint<ConcatOp>(rewriter, mux, msb, sharedLSB);
2376 if (!trueVec || !falseVec)
2378 if (!trueVec.isUniform() || !falseVec.isUniform())
2381 auto mux = MuxOp::create(rewriter, op.getLoc(), op.getCond(),
2382 trueVec.getUniformElement(),
2383 falseVec.getUniformElement(), op.getTwoState());
2385 SmallVector<Value> values(trueVec.getInputs().size(), mux);
2393 bool constCond, PatternRewriter &rewriter) {
2394 if (!muxValue.hasOneUse())
2396 auto *op = muxValue.getDefiningOp();
2397 if (!op || !isa_and_nonnull<CombDialect>(op->getDialect()))
2399 if (!llvm::is_contained(op->getOperands(), muxCond))
2401 OpBuilder::InsertionGuard guard(rewriter);
2402 rewriter.setInsertionPoint(op);
2405 rewriter.modifyOpInPlace(op, [&] {
2406 for (
auto &use : op->getOpOperands())
2407 if (use.get() == muxCond)
2415 using OpRewritePattern::OpRewritePattern;
2417 LogicalResult matchAndRewrite(
MuxOp op,
2418 PatternRewriter &rewriter)
const override;
2422foldToArrayCreateOnlyWhenDense(
size_t indexWidth,
size_t numEntries) {
2425 if (indexWidth >= 9 || numEntries < 3)
2431 uint64_t tableSize = 1ULL << indexWidth;
2432 if (numEntries >= tableSize * 5 / 8)
2437LogicalResult MuxRewriter::matchAndRewrite(
MuxOp op,
2438 PatternRewriter &rewriter)
const {
2442 bool isSignlessInt =
false;
2443 if (
auto intType = dyn_cast<IntegerType>(op.getType()))
2444 isSignlessInt = intType.isSignless();
2451 if (matchPattern(op.getTrueValue(), m_ConstantInt(&value)) && isSignlessInt) {
2452 if (value.getBitWidth() == 1) {
2454 if (value.isZero()) {
2456 replaceOpWithNewOpAndCopyNamehint<AndOp>(rewriter, op, notCond,
2457 op.getFalseValue(),
false);
2462 replaceOpWithNewOpAndCopyNamehint<OrOp>(rewriter, op, op.getCond(),
2463 op.getFalseValue(),
false);
2469 if (matchPattern(op.getFalseValue(), m_ConstantInt(&value2))) {
2474 APInt xorValue = value ^ value2;
2475 if (xorValue.isPowerOf2()) {
2476 unsigned leadingZeros = xorValue.countLeadingZeros();
2477 unsigned trailingZeros = value.getBitWidth() - leadingZeros - 1;
2478 SmallVector<Value, 3> operands;
2486 if (leadingZeros > 0)
2487 operands.push_back(rewriter.createOrFold<
ExtractOp>(
2488 op.getLoc(), op.getTrueValue(), trailingZeros + 1, leadingZeros));
2492 auto v1 = rewriter.createOrFold<
ExtractOp>(
2493 op.getLoc(), op.getTrueValue(), trailingZeros, 1);
2494 auto v2 = rewriter.createOrFold<
ExtractOp>(
2495 op.getLoc(), op.getFalseValue(), trailingZeros, 1);
2496 operands.push_back(rewriter.createOrFold<
MuxOp>(
2497 op.getLoc(), op.getCond(), v1, v2,
false));
2499 if (trailingZeros > 0)
2500 operands.push_back(rewriter.createOrFold<
ExtractOp>(
2501 op.getLoc(), op.getTrueValue(), 0, trailingZeros));
2503 replaceOpWithNewOpAndCopyNamehint<ConcatOp>(rewriter, op, op.getType(),
2510 if (value.isAllOnes() && value2.isZero()) {
2511 replaceOpWithNewOpAndCopyNamehint<ReplicateOp>(
2512 rewriter, op, op.getType(), op.getCond());
2518 if (matchPattern(op.getFalseValue(), m_ConstantInt(&value)) &&
2519 isSignlessInt && value.getBitWidth() == 1) {
2521 if (value.isZero()) {
2522 replaceOpWithNewOpAndCopyNamehint<AndOp>(rewriter, op, op.getCond(),
2523 op.getTrueValue(),
false);
2530 auto notCond = rewriter.createOrFold<
XorOp>(op.getLoc(), op.getCond(),
2531 op.getFalseValue(),
false);
2532 replaceOpWithNewOpAndCopyNamehint<OrOp>(rewriter, op, notCond,
2533 op.getTrueValue(),
false);
2539 Operation *condOp = op.getCond().getDefiningOp();
2540 if (condOp && matchPattern(condOp,
m_Complement(m_Any(&subExpr))) &&
2542 replaceOpWithNewOpAndCopyNamehint<MuxOp>(rewriter, op, op.getType(),
2543 subExpr, op.getFalseValue(),
2544 op.getTrueValue(),
true);
2551 if (condOp && condOp->hasOneUse()) {
2552 SmallVector<Value> invertedOperands;
2556 auto getInvertedOperands = [&]() ->
bool {
2557 for (Value operand : condOp->getOperands()) {
2558 if (matchPattern(operand,
m_Complement(m_Any(&subExpr))))
2559 invertedOperands.push_back(subExpr);
2566 if (isa<AndOp>(condOp) && getInvertedOperands()) {
2568 rewriter.createOrFold<
OrOp>(op.getLoc(), invertedOperands,
false);
2569 replaceOpWithNewOpAndCopyNamehint<MuxOp>(
2570 rewriter, op, newOr, op.getFalseValue(), op.getTrueValue(),
2574 if (isa<OrOp>(condOp) && getInvertedOperands()) {
2576 rewriter.createOrFold<
AndOp>(op.getLoc(), invertedOperands,
false);
2577 replaceOpWithNewOpAndCopyNamehint<MuxOp>(
2578 rewriter, op, newAnd, op.getFalseValue(), op.getTrueValue(),
2584 if (
auto falseMux = op.getFalseValue().getDefiningOp<
MuxOp>();
2585 falseMux && falseMux != op) {
2587 if (op.getCond() == falseMux.getCond() &&
2588 falseMux.getFalseValue() != falseMux) {
2589 replaceOpWithNewOpAndCopyNamehint<MuxOp>(
2590 rewriter, op, op.getCond(), op.getTrueValue(),
2591 falseMux.getFalseValue(), op.getTwoStateAttr());
2597 foldToArrayCreateOnlyWhenDense))
2601 if (
auto trueMux = op.getTrueValue().getDefiningOp<
MuxOp>();
2602 trueMux && trueMux != op) {
2604 if (op.getCond() == trueMux.getCond()) {
2605 replaceOpWithNewOpAndCopyNamehint<MuxOp>(
2606 rewriter, op, op.getCond(), trueMux.getTrueValue(),
2607 op.getFalseValue(), op.getTwoStateAttr());
2613 foldToArrayCreateOnlyWhenDense))
2618 if (
auto trueMux = dyn_cast_or_null<MuxOp>(op.getTrueValue().getDefiningOp()),
2619 falseMux = dyn_cast_or_null<MuxOp>(op.getFalseValue().getDefiningOp());
2620 trueMux && falseMux && trueMux.getCond() == falseMux.getCond() &&
2621 trueMux.getTrueValue() == falseMux.getTrueValue() && trueMux != op &&
2623 auto subMux = MuxOp::create(
2624 rewriter, rewriter.getFusedLoc({trueMux.getLoc(), falseMux.getLoc()}),
2625 op.getCond(), trueMux.getFalseValue(), falseMux.getFalseValue());
2626 replaceOpWithNewOpAndCopyNamehint<MuxOp>(rewriter, op, trueMux.getCond(),
2627 trueMux.getTrueValue(), subMux,
2628 op.getTwoStateAttr());
2633 if (
auto trueMux = dyn_cast_or_null<MuxOp>(op.getTrueValue().getDefiningOp()),
2634 falseMux = dyn_cast_or_null<MuxOp>(op.getFalseValue().getDefiningOp());
2635 trueMux && falseMux && trueMux.getCond() == falseMux.getCond() &&
2636 trueMux.getFalseValue() == falseMux.getFalseValue() && trueMux != op &&
2638 auto subMux = MuxOp::create(
2639 rewriter, rewriter.getFusedLoc({trueMux.getLoc(), falseMux.getLoc()}),
2640 op.getCond(), trueMux.getTrueValue(), falseMux.getTrueValue());
2641 replaceOpWithNewOpAndCopyNamehint<MuxOp>(rewriter, op, trueMux.getCond(),
2642 subMux, trueMux.getFalseValue(),
2643 op.getTwoStateAttr());
2648 if (
auto trueMux = dyn_cast_or_null<MuxOp>(op.getTrueValue().getDefiningOp()),
2649 falseMux = dyn_cast_or_null<MuxOp>(op.getFalseValue().getDefiningOp());
2650 trueMux && falseMux &&
2651 trueMux.getTrueValue() == falseMux.getTrueValue() &&
2652 trueMux.getFalseValue() == falseMux.getFalseValue() && trueMux != op &&
2655 MuxOp::create(rewriter,
2656 rewriter.getFusedLoc(
2657 {op.getLoc(), trueMux.getLoc(), falseMux.getLoc()}),
2658 op.getCond(), trueMux.getCond(), falseMux.getCond());
2659 replaceOpWithNewOpAndCopyNamehint<MuxOp>(
2660 rewriter, op, subMux, trueMux.getTrueValue(), trueMux.getFalseValue(),
2661 op.getTwoStateAttr());
2673 if (Operation *trueOp = op.getTrueValue().getDefiningOp())
2674 if (Operation *falseOp = op.getFalseValue().getDefiningOp())
2675 if (trueOp->getName() == falseOp->getName())
2688 if (op.getTrueValue().getDefiningOp() &&
2689 op.getTrueValue().getDefiningOp() != op)
2692 if (op.getFalseValue().getDefiningOp() &&
2693 op.getFalseValue().getDefiningOp() != op)
2704 if (op.getInputs().empty() || op.isUniform())
2706 auto inputs = op.getInputs();
2707 if (inputs.size() <= 1)
2712 auto first = inputs[0].getDefiningOp<
comb::MuxOp>();
2717 for (
size_t i = 1, n = inputs.size(); i < n; ++i) {
2718 auto input = inputs[i].getDefiningOp<
comb::MuxOp>();
2719 if (!input || first.getCond() != input.getCond())
2724 SmallVector<Value> trues{first.getTrueValue()};
2725 SmallVector<Value> falses{first.getFalseValue()};
2726 SmallVector<Location> locs{first->getLoc()};
2727 bool isTwoState =
true;
2728 for (
size_t i = 1, n = inputs.size(); i < n; ++i) {
2729 auto input = inputs[i].getDefiningOp<
comb::MuxOp>();
2730 trues.push_back(input.getTrueValue());
2731 falses.push_back(input.getFalseValue());
2732 locs.push_back(input->getLoc());
2733 if (!input.getTwoState())
2738 auto loc = FusedLoc::get(op.getContext(), locs);
2742 auto arrayTy = op.getType();
2745 rewriter.replaceOpWithNewOp<
comb::MuxOp>(op, arrayTy, first.getCond(),
2746 trueValues, falseValues, isTwoState);
2751 using OpRewritePattern::OpRewritePattern;
2754 PatternRewriter &rewriter)
const override {
2755 if (foldArrayOfMuxes(op, rewriter))
2763void MuxOp::getCanonicalizationPatterns(RewritePatternSet &results,
2765 results.insert<MuxRewriter, ArrayRewriter>(
context);
2776 switch (predicate) {
2777 case ICmpPredicate::eq:
2779 case ICmpPredicate::ne:
2781 case ICmpPredicate::slt:
2782 return lhs.slt(rhs);
2783 case ICmpPredicate::sle:
2784 return lhs.sle(rhs);
2785 case ICmpPredicate::sgt:
2786 return lhs.sgt(rhs);
2787 case ICmpPredicate::sge:
2788 return lhs.sge(rhs);
2789 case ICmpPredicate::ult:
2790 return lhs.ult(rhs);
2791 case ICmpPredicate::ule:
2792 return lhs.ule(rhs);
2793 case ICmpPredicate::ugt:
2794 return lhs.ugt(rhs);
2795 case ICmpPredicate::uge:
2796 return lhs.uge(rhs);
2797 case ICmpPredicate::ceq:
2799 case ICmpPredicate::cne:
2801 case ICmpPredicate::weq:
2803 case ICmpPredicate::wne:
2806 llvm_unreachable(
"unknown comparison predicate");
2812 switch (predicate) {
2813 case ICmpPredicate::eq:
2814 case ICmpPredicate::sle:
2815 case ICmpPredicate::sge:
2816 case ICmpPredicate::ule:
2817 case ICmpPredicate::uge:
2818 case ICmpPredicate::ceq:
2819 case ICmpPredicate::weq:
2821 case ICmpPredicate::ne:
2822 case ICmpPredicate::slt:
2823 case ICmpPredicate::sgt:
2824 case ICmpPredicate::ult:
2825 case ICmpPredicate::ugt:
2826 case ICmpPredicate::cne:
2827 case ICmpPredicate::wne:
2830 llvm_unreachable(
"unknown comparison predicate");
2833OpFoldResult ICmpOp::fold(FoldAdaptor adaptor) {
2836 if (getLhs() == getRhs()) {
2838 return IntegerAttr::get(getType(), val);
2842 if (
auto lhs = dyn_cast_or_null<IntegerAttr>(adaptor.getLhs())) {
2843 if (
auto rhs = dyn_cast_or_null<IntegerAttr>(adaptor.getRhs())) {
2846 return IntegerAttr::get(getType(), val);
2854template <
typename Range>
2856 size_t commonPrefixLength = 0;
2857 auto ia = a.begin();
2858 auto ib = b.begin();
2860 for (; ia != a.end() && ib != b.end(); ia++, ib++, commonPrefixLength++) {
2866 return commonPrefixLength;
2870 size_t totalWidth = 0;
2871 for (
auto operand : operands) {
2874 ssize_t width = operand.getType().getIntOrFloatBitWidth();
2876 totalWidth += width;
2886 PatternRewriter &rewriter) {
2890 SmallVector<Value> lhsOperands, rhsOperands;
2893 ArrayRef<Value> lhsOperandsRef = lhsOperands, rhsOperandsRef = rhsOperands;
2895 auto formCatOrReplicate = [&](Location loc,
2896 ArrayRef<Value> operands) -> Value {
2897 assert(!operands.empty());
2898 Value sameElement = operands[0];
2899 for (
size_t i = 1, e = operands.size(); i != e && sameElement; ++i)
2900 if (sameElement != operands[i])
2901 sameElement = Value();
2903 return rewriter.createOrFold<ReplicateOp>(loc, sameElement,
2905 return rewriter.createOrFold<
ConcatOp>(loc, operands);
2908 auto replaceWith = [&](ICmpPredicate predicate, Value lhs,
2909 Value rhs) -> LogicalResult {
2910 replaceOpWithNewOpAndCopyNamehint<ICmpOp>(rewriter, op, predicate, lhs, rhs,
2915 size_t commonPrefixLength =
2917 if (commonPrefixLength == lhsOperands.size()) {
2920 replaceOpWithNewOpAndCopyNamehint<hw::ConstantOp>(rewriter, op,
2926 llvm::reverse(lhsOperandsRef), llvm::reverse(rhsOperandsRef));
2928 size_t commonPrefixTotalWidth =
2929 getTotalWidth(lhsOperandsRef.take_front(commonPrefixLength));
2930 size_t commonSuffixTotalWidth =
2931 getTotalWidth(lhsOperandsRef.take_back(commonSuffixLength));
2932 auto lhsOnly = lhsOperandsRef.drop_front(commonPrefixLength)
2933 .drop_back(commonSuffixLength);
2934 auto rhsOnly = rhsOperandsRef.drop_front(commonPrefixLength)
2935 .drop_back(commonSuffixLength);
2937 auto replaceWithoutReplicatingSignBit = [&]() {
2938 auto newLhs = formCatOrReplicate(lhs->getLoc(), lhsOnly);
2939 auto newRhs = formCatOrReplicate(rhs->getLoc(), rhsOnly);
2940 return replaceWith(op.getPredicate(), newLhs, newRhs);
2943 auto replaceWithReplicatingSignBit = [&]() {
2944 auto firstNonEmptyValue = lhsOperands[0];
2945 auto firstNonEmptyElemWidth =
2946 firstNonEmptyValue.getType().getIntOrFloatBitWidth();
2947 Value signBit = rewriter.createOrFold<
ExtractOp>(
2948 op.getLoc(), firstNonEmptyValue, firstNonEmptyElemWidth - 1, 1);
2950 auto newLhs = ConcatOp::create(rewriter, lhs->getLoc(), signBit, lhsOnly);
2951 auto newRhs = ConcatOp::create(rewriter, rhs->getLoc(), signBit, rhsOnly);
2952 return replaceWith(op.getPredicate(), newLhs, newRhs);
2955 if (ICmpOp::isPredicateSigned(op.getPredicate())) {
2957 if (commonPrefixTotalWidth == 0 && commonSuffixTotalWidth > 0)
2958 return replaceWithoutReplicatingSignBit();
2964 if (commonPrefixTotalWidth > 1 || commonSuffixTotalWidth > 0)
2965 return replaceWithReplicatingSignBit();
2967 }
else if (commonPrefixTotalWidth > 0 || commonSuffixTotalWidth > 0) {
2969 return replaceWithoutReplicatingSignBit();
2983 ICmpOp cmpOp,
const KnownBits &bitAnalysis,
const APInt &rhsCst,
2984 PatternRewriter &rewriter) {
2988 APInt bitsKnown = bitAnalysis.Zero | bitAnalysis.One;
2989 if ((bitsKnown & rhsCst) != bitAnalysis.One) {
2992 bool result = cmpOp.getPredicate() == ICmpPredicate::ne;
2993 replaceOpWithNewOpAndCopyNamehint<hw::ConstantOp>(rewriter, cmpOp,
3001 SmallVector<Value> newConcatOperands;
3002 auto newConstant = APInt::getZeroWidth();
3007 unsigned knownMSB = bitsKnown.countLeadingOnes();
3009 Value operand = cmpOp.getLhs();
3014 while (knownMSB != bitsKnown.getBitWidth()) {
3017 bitsKnown = bitsKnown.trunc(bitsKnown.getBitWidth() - knownMSB);
3020 unsigned unknownBits = bitsKnown.countLeadingZeros();
3021 unsigned lowBit = bitsKnown.getBitWidth() - unknownBits;
3022 auto spanOperand = rewriter.createOrFold<
ExtractOp>(
3023 operand.getLoc(), operand, lowBit,
3025 auto spanConstant = rhsCst.lshr(lowBit).trunc(unknownBits);
3028 newConcatOperands.push_back(spanOperand);
3031 if (newConstant.getBitWidth() != 0)
3032 newConstant = newConstant.concat(spanConstant);
3034 newConstant = spanConstant;
3037 unsigned newWidth = bitsKnown.getBitWidth() - unknownBits;
3038 bitsKnown = bitsKnown.trunc(newWidth);
3039 knownMSB = bitsKnown.countLeadingOnes();
3045 if (newConcatOperands.empty()) {
3046 bool result = cmpOp.getPredicate() == ICmpPredicate::eq;
3047 replaceOpWithNewOpAndCopyNamehint<hw::ConstantOp>(rewriter, cmpOp,
3053 Value concatResult =
3054 rewriter.createOrFold<
ConcatOp>(operand.getLoc(), newConcatOperands);
3058 rewriter, cmpOp.getOperand(1).getLoc(), newConstant);
3060 replaceOpWithNewOpAndCopyNamehint<ICmpOp>(rewriter, cmpOp,
3061 cmpOp.getPredicate(), concatResult,
3062 newConstantOp, cmpOp.getTwoState());
3068 PatternRewriter &rewriter) {
3069 auto ip = rewriter.saveInsertionPoint();
3070 rewriter.setInsertionPoint(xorOp);
3072 auto xorRHS = xorOp.getOperands().back().getDefiningOp<
hw::ConstantOp>();
3074 xorRHS.getValue() ^ rhs);
3076 switch (xorOp.getNumOperands()) {
3080 APInt::getZero(rhs.getBitWidth()));
3084 newLHS = xorOp.getOperand(0);
3088 SmallVector<Value> newOperands(xorOp.getOperands());
3089 newOperands.pop_back();
3090 newLHS = XorOp::create(rewriter, xorOp.getLoc(), newOperands,
false);
3094 bool xorMultipleUses = !xorOp->hasOneUse();
3098 if (xorMultipleUses)
3099 replaceOpWithNewOpAndCopyNamehint<XorOp>(rewriter, xorOp, newLHS, xorRHS,
3103 rewriter.restoreInsertionPoint(ip);
3104 replaceOpWithNewOpAndCopyNamehint<ICmpOp>(
3105 rewriter, cmpOp, cmpOp.getPredicate(), newLHS, newRHS,
false);
3108LogicalResult ICmpOp::canonicalize(ICmpOp op, PatternRewriter &rewriter) {
3114 if (matchPattern(op.getLhs(), m_ConstantInt(&lhs))) {
3115 assert(!matchPattern(op.getRhs(), m_ConstantInt(&rhs)) &&
3116 "Should be folded");
3117 replaceOpWithNewOpAndCopyNamehint<ICmpOp>(
3118 rewriter, op, ICmpOp::getFlippedPredicate(op.getPredicate()),
3119 op.getRhs(), op.getLhs(), op.getTwoState());
3124 if (matchPattern(op.getRhs(), m_ConstantInt(&rhs))) {
3129 auto replaceWith = [&](ICmpPredicate predicate, Value lhs,
3130 Value rhs) -> LogicalResult {
3131 replaceOpWithNewOpAndCopyNamehint<ICmpOp>(rewriter, op, predicate, lhs,
3132 rhs, op.getTwoState());
3136 auto replaceWithConstantI1 = [&](
bool constant) -> LogicalResult {
3137 replaceOpWithNewOpAndCopyNamehint<hw::ConstantOp>(rewriter, op,
3138 APInt(1, constant));
3142 switch (op.getPredicate()) {
3143 case ICmpPredicate::slt:
3145 if (rhs.isMaxSignedValue())
3146 return replaceWith(ICmpPredicate::ne, op.getLhs(), op.getRhs());
3148 if (rhs.isMinSignedValue())
3149 return replaceWithConstantI1(0);
3151 if ((rhs - 1).isMinSignedValue())
3152 return replaceWith(ICmpPredicate::eq, op.getLhs(),
3155 case ICmpPredicate::sgt:
3157 if (rhs.isMinSignedValue())
3158 return replaceWith(ICmpPredicate::ne, op.getLhs(), op.getRhs());
3160 if (rhs.isMaxSignedValue())
3161 return replaceWithConstantI1(0);
3163 if ((rhs + 1).isMaxSignedValue())
3164 return replaceWith(ICmpPredicate::eq, op.getLhs(),
3167 case ICmpPredicate::ult:
3169 if (rhs.isAllOnes())
3170 return replaceWith(ICmpPredicate::ne, op.getLhs(), op.getRhs());
3173 return replaceWithConstantI1(0);
3175 if ((rhs - 1).isZero())
3176 return replaceWith(ICmpPredicate::eq, op.getLhs(),
3180 if (rhs.countLeadingOnes() + rhs.countTrailingZeros() ==
3181 rhs.getBitWidth()) {
3182 auto numOnes = rhs.countLeadingOnes();
3184 rhs.getBitWidth() - numOnes, numOnes);
3185 return replaceWith(ICmpPredicate::ne, smaller,
3190 case ICmpPredicate::ugt:
3193 return replaceWith(ICmpPredicate::ne, op.getLhs(), op.getRhs());
3195 if (rhs.isAllOnes())
3196 return replaceWithConstantI1(0);
3198 if ((rhs + 1).isAllOnes())
3199 return replaceWith(ICmpPredicate::eq, op.getLhs(),
3203 if ((rhs + 1).isPowerOf2()) {
3204 auto numOnes = rhs.countTrailingOnes();
3205 auto newWidth = rhs.getBitWidth() - numOnes;
3208 return replaceWith(ICmpPredicate::ne, smaller,
3213 case ICmpPredicate::sle:
3215 if (rhs.isMaxSignedValue())
3216 return replaceWithConstantI1(1);
3218 return replaceWith(ICmpPredicate::slt, op.getLhs(),
getConstant(rhs + 1));
3219 case ICmpPredicate::sge:
3221 if (rhs.isMinSignedValue())
3222 return replaceWithConstantI1(1);
3224 return replaceWith(ICmpPredicate::sgt, op.getLhs(),
getConstant(rhs - 1));
3225 case ICmpPredicate::ule:
3227 if (rhs.isAllOnes())
3228 return replaceWithConstantI1(1);
3230 return replaceWith(ICmpPredicate::ult, op.getLhs(),
getConstant(rhs + 1));
3231 case ICmpPredicate::uge:
3234 return replaceWithConstantI1(1);
3236 return replaceWith(ICmpPredicate::ugt, op.getLhs(),
getConstant(rhs - 1));
3237 case ICmpPredicate::eq:
3238 if (rhs.getBitWidth() == 1) {
3241 replaceOpWithNewOpAndCopyNamehint<XorOp>(rewriter, op, op.getLhs(),
3246 if (rhs.isAllOnes()) {
3253 case ICmpPredicate::ne:
3254 if (rhs.getBitWidth() == 1) {
3260 if (rhs.isAllOnes()) {
3262 replaceOpWithNewOpAndCopyNamehint<XorOp>(rewriter, op, op.getLhs(),
3269 case ICmpPredicate::ceq:
3270 case ICmpPredicate::cne:
3271 case ICmpPredicate::weq:
3272 case ICmpPredicate::wne:
3278 if (op.getPredicate() == ICmpPredicate::eq ||
3279 op.getPredicate() == ICmpPredicate::ne) {
3284 if (!knownBits.isUnknown())
3291 if (
auto xorOp = op.getLhs().getDefiningOp<
XorOp>())
3298 if (
auto replicateOp = op.getLhs().getDefiningOp<ReplicateOp>())
3299 if (rhs.isAllOnes() || rhs.isZero()) {
3300 auto width = replicateOp.getInput().getType().getIntOrFloatBitWidth();
3303 rhs.isAllOnes() ? APInt::getAllOnes(width)
3304 : APInt::getZero(width));
3305 replaceOpWithNewOpAndCopyNamehint<ICmpOp>(
3306 rewriter, op, op.getPredicate(), replicateOp.getInput(), cst,
3316 if (Operation *opLHS = op.getLhs().getDefiningOp())
3317 if (Operation *opRHS = op.getRhs().getDefiningOp())
3318 if (isa<ConcatOp, ReplicateOp>(opLHS) &&
3319 isa<ConcatOp, ReplicateOp>(opRHS)) {
assert(baseType &&"element must be base type")
static KnownBits computeKnownBits(Value v, unsigned depth)
Given an integer SSA value, check to see if we know anything about the result of the computation.
static bool foldMuxOfUniformArrays(MuxOp op, PatternRewriter &rewriter)
static Attribute constFoldAssociativeOp(ArrayRef< Attribute > operands, hw::PEO paramOpcode)
static Attribute constFoldBinaryOp(ArrayRef< Attribute > operands, hw::PEO paramOpcode)
Performs an element-wise constant-folding calculation on the two attributes in operands and ...
static bool applyCmpPredicateToEqualOperands(ICmpPredicate predicate)
static ComplementMatcher< SubType > m_Complement(const SubType &subExpr)
static bool canonicalizeLogicalCstWithConcat(Operation *logicalOp, size_t concatIdx, const APInt &cst, PatternRewriter &rewriter)
When we find a logical operation (and, or, xor) with a constant e.g.
static bool narrowOperationWidth(OpTy op, bool narrowTrailingBits, PatternRewriter &rewriter)
static OpFoldResult foldDiv(Op op, ArrayRef< Attribute > constants)
static Value getCommonOperand(Op op)
Returns a single common operand that all inputs of the operation op can be traced back to,...
static bool canCombineOppositeBinCmpIntoConstant(OperandRange operands)
static void getConcatOperands(Value v, SmallVectorImpl< Value > &result)
Flatten concat and replicate operands into a vector.
static Value extractOperandFromFullyAssociative(Operation *fullyAssoc, size_t operandNo, PatternRewriter &rewriter)
Given a fully associative variadic operation like (a+b+c+d), break the expression into two parts,...
static bool getMuxChainCondConstant(Value cond, Value indexValue, bool isInverted, std::function< void(hw::ConstantOp)> constantFn)
Check to see if the condition to the specified mux is an equality comparison indexValue and one or mo...
static TypedAttr getIntAttr(const APInt &value, MLIRContext *context)
static bool shouldBeFlattened(Operation *op)
Return true if the op will be flattened afterwards.
static void canonicalizeXorIcmpTrue(XorOp op, unsigned icmpOperand, PatternRewriter &rewriter)
static bool assumeMuxCondInOperand(Value muxCond, Value muxValue, bool constCond, PatternRewriter &rewriter)
If the mux condition is an operand to the op defining its true or false value, replace the condition ...
static bool extractFromReplicate(ExtractOp op, ReplicateOp replicate, PatternRewriter &rewriter)
static void combineEqualityICmpWithXorOfConstant(ICmpOp cmpOp, XorOp xorOp, const APInt &rhs, PatternRewriter &rewriter)
static size_t getTotalWidth(ArrayRef< Value > operands)
static bool foldCommonMuxOperation(MuxOp mux, Operation *trueOp, Operation *falseOp, PatternRewriter &rewriter)
This function is invoked when we find a mux with true/false operations that have the same opcode.
static std::pair< size_t, size_t > getLowestBitAndHighestBitRequired(Operation *op, bool narrowTrailingBits, size_t originalOpWidth)
static bool tryFlatteningOperands(Operation *op, PatternRewriter &rewriter)
Flattens a single input in op if hasOneUse is true and it can be defined as an Op.
static bool isOpTriviallyRecursive(Operation *op)
static bool canonicalizeIdempotentInputs(Op op, PatternRewriter &rewriter)
Canonicalize an idempotent operation op so that only one input of any kind occurs.
static bool applyCmpPredicate(ICmpPredicate predicate, const APInt &lhs, const APInt &rhs)
static void combineEqualityICmpWithKnownBitsAndConstant(ICmpOp cmpOp, const KnownBits &bitAnalysis, const APInt &rhsCst, PatternRewriter &rewriter)
Given an equality comparison with a constant value and some operand that has known bits,...
static bool hasSVAttributes(Operation *op)
static LogicalResult extractConcatToConcatExtract(ExtractOp op, ConcatOp innerCat, PatternRewriter &rewriter)
static OpFoldResult foldMod(Op op, ArrayRef< Attribute > constants)
static size_t computeCommonPrefixLength(const Range &a, const Range &b)
static bool foldCommonMuxValue(MuxOp op, bool isTrueOperand, PatternRewriter &rewriter)
Fold things like mux(cond, x|y|z|a, a) -> (x|y|z)&replicate(cond)|a and mux(cond, a,...
static LogicalResult matchAndRewriteCompareConcat(ICmpOp op, Operation *lhs, Operation *rhs, PatternRewriter &rewriter)
Reduce the strength of icmp(concat(...), concat(...)) by doing an element-wise comparison on common prefi...
static Value createGenericOp(Location loc, OperationName name, ArrayRef< Value > operands, OpBuilder &builder)
Create a new instance of a generic operation that only has value operands, and has a single result va...
static TypedAttr getIntAttr(MLIRContext *ctx, Type t, const APInt &value)
static std::unique_ptr< Context > context
static std::optional< APSInt > getConstant(Attribute operand)
Determine the value of a constant operand for the sake of constant folding.
create(elements, Type result_type=None)
create(array_value, low_index, ret_type)
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
void extractBits(OpBuilder &builder, Value val, SmallVectorImpl< Value > &bits)
Extract bits from a value.
bool foldMuxChainWithComparison(PatternRewriter &rewriter, MuxOp rootMux, bool isFalseSide, llvm::function_ref< MuxChainWithComparisonFoldingStyle(size_t indexWidth, size_t numEntries)> styleFn)
Mux chain folding that converts chains of muxes with index comparisons into array operations or balan...
Value createOrFoldNot(Location loc, Value value, OpBuilder &builder, bool twoState=false)
Create a 'Not' gate on a value.
MuxChainWithComparisonFoldingStyle
Enum for mux chain folding styles.
LogicalResult convertModUByPowerOfTwo(ModUOp modOp, mlir::PatternRewriter &rewriter)
KnownBits computeKnownBits(Value value)
Compute "known bits" information about the specified value - the set of bits that are guaranteed to a...
Value createOrFoldSExt(Location loc, Value value, Type destTy, OpBuilder &builder)
Create a sign extension operation from a value of integer type to an equal or larger integer type.
Value constructMuxTree(OpBuilder &builder, Location loc, ArrayRef< Value > selectors, ArrayRef< Value > leafNodes, Value outOfBoundsValue)
Construct a mux tree for given leaf nodes.
LogicalResult convertDivUByPowerOfTwo(DivUOp divOp, mlir::PatternRewriter &rewriter)
Convert unsigned division or modulo by a power of two.
uint64_t getWidth(Type t)
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
void replaceOpAndCopyNamehint(PatternRewriter &rewriter, Operation *op, Value newValue)
A wrapper of PatternRewriter::replaceOp to propagate "sv.namehint" attribute.