13#include "mlir/IR/Diagnostics.h"
14#include "mlir/IR/Matchers.h"
15#include "mlir/IR/PatternMatch.h"
16#include "llvm/ADT/SetVector.h"
17#include "llvm/ADT/SmallBitVector.h"
18#include "llvm/ADT/TypeSwitch.h"
19#include "llvm/Support/KnownBits.h"
24using namespace matchers;
// Returns true when at least one operand of `op` is defined by `op` itself,
// i.e. the operation consumes its own result.
28 return llvm::any_of(op->getOperands(), [op](
auto operand) {
29 return operand.getDefiningOp() == op;
// Generic op builder: creates an operation named `name` at `loc` with
// `operands`, gives it a single result typed like the first operand, and
// returns that result. `operands` must be non-empty because operands[0] is
// read unconditionally.
39 ArrayRef<Value> operands, OpBuilder &builder) {
40 OperationState state(loc, name);
41 state.addOperands(operands);
42 state.addTypes(operands[0].getType());
43 return builder.create(state)->getResult(0);
/// Builds an integer attribute for `value`. The integer type is obtained from
/// `context` with a width equal to `value.getBitWidth()`.
46static TypedAttr
getIntAttr(
const APInt &value, MLIRContext *context) {
47 return IntegerAttr::get(IntegerType::get(context, value.getBitWidth()),
// Iterates over the operands of `concat`; when `v` is instead produced by a
// ReplicateOp, loops `repl.getMultiple()` times.
// NOTE(review): the loop bodies are not visible here -- presumably they
// collect the flattened concat operands; confirm.
54 for (
auto op :
concat.getOperands())
56 }
else if (
auto repl = v.getDefiningOp<ReplicateOp>()) {
57 for (
size_t i = 0, e = repl.getMultiple(); i != e; ++i)
// Reports whether `op` carries an attribute named "sv.attributes".
68 return op->hasAttr(
"sv.attributes");
/// Matcher for a bitwise complement: succeeds on an XorOp for which
/// `isBinaryNot()` holds and whose operand 0 is accepted by the nested
/// matcher `lhs`.
72template <
typename SubType>
73struct ComplementMatcher {
75 ComplementMatcher(SubType lhs) : lhs(std::move(lhs)) {}
76 bool match(Operation *op) {
77 auto xorOp = dyn_cast<XorOp>(op);
// `xorOp` is null for non-xor operations, which short-circuits the check.
78 return xorOp && xorOp.isBinaryNot() && lhs.match(op->getOperand(0));
/// Convenience factory: wraps the sub-matcher `subExpr` in a
/// ComplementMatcher so it can be used inside a match pattern.
83template <
typename SubType>
84static inline ComplementMatcher<SubType>
m_Complement(
const SubType &subExpr) {
85 return ComplementMatcher<SubType>(subExpr);
// Precondition: `op` is one of AndOp, OrOp, XorOp, AddOp or MulOp.
// When `op` has exactly one use, the result is true only if that user has the
// same operation name, the same "twoState" unit attribute, and sits in the
// same block as `op`.
// NOTE(review): the result for ops with zero or several uses is not visible
// in this excerpt.
91 assert((isa<AndOp, OrOp, XorOp, AddOp, MulOp>(op) &&
92 "must be commutative operations"));
93 if (op->hasOneUse()) {
94 auto *user = *op->getUsers().begin();
95 return user->getName() == op->getName() &&
96 op->getAttrOfType<UnitAttr>(
"twoState") ==
97 user->getAttrOfType<UnitAttr>(
"twoState") &&
98 op->getBlock() == user->getBlock();
// Flattens operands that are produced by operations of the same kind as `op`
// into one operand list, using an explicit worklist of operand iterator
// ranges. The replacement op is created through createGenericOp with a fused
// location built from every op that was visited.
113 auto inputs = op->getOperands();
115 SmallVector<Value, 4> newOperands;
// Start the location list with `op` itself; flattened ops are appended below.
116 SmallVector<Location, 4> newLocations{op->getLoc()};
117 newOperands.reserve(inputs.size());
119 decltype(inputs.begin()) current, end;
122 SmallVector<Element> worklist;
123 worklist.push_back({inputs.begin(), inputs.end()});
124 bool binFlag = op->hasAttrOfType<UnitAttr>(
"twoState");
125 bool changed =
false;
126 while (!worklist.empty()) {
127 auto &element = worklist.back();
130 if (element.current == element.end) {
135 Value value = *element.current++;
136 auto *flattenOp = value.getDefiningOp();
// Keep `value` as a plain operand when it has no defining op, is a different
// kind of op, is `op` itself, or lives in another block.
// NOTE(review): `binFlag` was computed from `op` above and is compared against
// `op` again here, so that term can never be true; presumably `flattenOp` was
// intended -- confirm before changing.
139 if (!flattenOp || flattenOp->getName() != op->getName() ||
140 flattenOp == op || binFlag != op->hasAttrOfType<UnitAttr>(
"twoState") ||
141 flattenOp->getBlock() != op->getBlock()) {
142 newOperands.push_back(value);
147 if (!value.hasOneUse()) {
155 if (flattenOp->getNumOperands() != 2 || !isa<AndOp, OrOp, XorOp>(op) ||
158 newOperands.push_back(value);
// Descend into the nested op: its operands are processed next.
166 auto flattenOpInputs = flattenOp->getOperands();
167 worklist.push_back({flattenOpInputs.begin(), flattenOpInputs.end()});
168 newLocations.push_back(flattenOp->getLoc());
174 Value result =
createGenericOp(FusedLoc::get(op->getContext(), newLocations),
175 op->getName(), newOperands, rewriter);
177 result.getDefiningOp()->setAttr(
"twoState", rewriter.getUnitAttr());
/// Computes the pair {lowest bit, highest bit} of `op`'s result that its
/// users need. ExtractOp users contribute the range
/// [lowBit, lowBit + width - 1]; the function name is taken from the
/// assertion message below.
185static std::pair<size_t, size_t>
187 size_t originalOpWidth) {
188 auto users = op->getUsers();
190 "getLowestBitAndHighestBitRequired cannot operate on "
191 "a empty list of uses.");
// Seed the minimum with the top bit when trailing bits may be narrowed,
// otherwise pin it to bit 0.
195 size_t lowestBitRequired = narrowTrailingBits ? originalOpWidth - 1 : 0;
196 size_t highestBitRequired = 0;
198 for (
auto *user : users) {
199 if (
auto extractOp = dyn_cast<ExtractOp>(user)) {
200 size_t lowBit = extractOp.getLowBit();
202 cast<IntegerType>(extractOp.getType()).getWidth() + lowBit - 1;
203 highestBitRequired = std::max(highestBitRequired, highBit);
204 lowestBitRequired = std::min(lowestBitRequired, lowBit);
// NOTE(review): presumably the fallback for a user that is not an ExtractOp
// (the guarding lines are not visible): the full width is required.
208 highestBitRequired = originalOpWidth - 1;
209 lowestBitRequired = 0;
213 return {lowestBitRequired, highestBitRequired};
// Rewrites `op` to a narrower `OpTy` covering only bits
// [range.first, range.second], then pads the narrow result with zero
// constants via ConcatOp so the replacement keeps the original width.
218 PatternRewriter &rewriter) {
219 IntegerType opType = dyn_cast<IntegerType>(op.getResult().getType());
// The required range already spans the whole result: nothing to narrow.
225 if (range.second + 1 == opType.getWidth() && range.first == 0)
228 SmallVector<Value> args;
229 auto newType = rewriter.getIntegerType(range.second - range.first + 1);
230 for (
auto inop : op.getOperands()) {
// Operands whose type differs from the result type are forwarded unchanged;
// the others are narrowed with an ExtractOp.
232 if (inop.getType() != op.getType())
233 args.push_back(inop);
235 args.push_back(rewriter.createOrFold<
ExtractOp>(inop.getLoc(), newType,
238 auto newop = OpTy::create(rewriter, op.getLoc(), newType, args);
// Carry dialect attributes and the two-state flag over to the new op.
239 newop->setDialectAttrs(op->getDialectAttrs());
240 if (op.getTwoState())
241 newop.setTwoState(
true);
243 Value newResult = newop.getResult();
// Zero padding for the `range.first` dropped low bits.
245 newResult = rewriter.createOrFold<
ConcatOp>(
246 op.getLoc(), newResult,
248 APInt::getZero(range.first)));
// Zero padding for the dropped high bits, if any.
249 if (range.second + 1 < opType.getWidth())
250 newResult = rewriter.createOrFold<
ConcatOp>(
253 rewriter, op.getLoc(),
254 APInt::getZero(opType.getWidth() - range.second - 1)),
256 rewriter.replaceOp(op, newResult);
/// Folder for ReplicateOp. Handles a result as wide as the input, and
/// constant inputs (1-bit constants become all-zeros / all-ones, wider
/// constants are concatenated `getMultiple()` times).
264OpFoldResult ReplicateOp::fold(FoldAdaptor adaptor) {
// Result width equals input width, i.e. a single copy.
269 if (cast<IntegerType>(getType()).
getWidth() ==
270 getInput().getType().getIntOrFloatBitWidth())
274 if (
auto input = dyn_cast_or_null<IntegerAttr>(adaptor.getInput())) {
// Fast path for single-bit constants: no concatenation loop needed.
275 if (input.getValue().getBitWidth() == 1) {
276 if (input.getValue().isZero())
278 APInt::getZero(cast<IntegerType>(getType()).
getWidth()),
281 APInt::getAllOnes(cast<IntegerType>(getType()).
getWidth()),
// General case: start from a zero-width value and append the constant once
// per replication.
285 APInt result = APInt::getZeroWidth();
286 for (
auto i = getMultiple(); i != 0; --i)
287 result = result.concat(input.getValue());
/// Folder for ParityOp: a constant input folds to a 1-bit constant holding
/// the low bit of the input's population count.
294OpFoldResult ParityOp::fold(FoldAdaptor adaptor) {
299 if (
auto input = dyn_cast_or_null<IntegerAttr>(adaptor.getInput()))
300 return getIntAttr(APInt(1, input.getValue().popcount() & 1), getContext());
// Folds a binary op over two attribute operands into a hw::ParamExprAttr
// with opcode `paramOpcode`. Exactly two operands are required; both are
// cast to TypedAttr when building the expression.
312 hw::PEO paramOpcode) {
313 assert(operands.size() == 2 &&
"binary op takes two operands");
// A null attribute means that operand is not a known constant.
314 if (!operands[0] || !operands[1])
319 return hw::ParamExprAttr::get(paramOpcode, cast<TypedAttr>(operands[0]),
320 cast<TypedAttr>(operands[1]));
/// Folder for ShlOp with a constant shift amount: a zero shift returns
/// operand 0, and a shift of at least the result width yields a zero
/// constant.
323OpFoldResult ShlOp::fold(FoldAdaptor adaptor) {
327 if (
auto rhs = dyn_cast_or_null<IntegerAttr>(adaptor.getRhs())) {
328 if (rhs.getValue().isZero())
329 return getOperand(0);
331 unsigned width = getType().getIntOrFloatBitWidth();
332 if (rhs.getValue().uge(width))
333 return getIntAttr(APInt::getZero(width), getContext());
/// Canonicalizer for ShlOp: a shift by a constant amount is replaced by a
/// ConcatOp of `extract` followed by `zeros`.
338LogicalResult ShlOp::canonicalize(
ShlOp op, PatternRewriter &rewriter) {
// Only constant shift amounts are handled.
344 if (!matchPattern(op.getRhs(), m_ConstantInt(&value)))
347 unsigned width = cast<IntegerType>(op.getLhs().getType()).getWidth();
// Checked before getZExtValue() below converts the amount to `unsigned`.
348 if (value.ugt(width))
350 unsigned shift = value.getZExtValue();
353 if (width <= shift || shift == 0)
363 replaceOpWithNewOpAndCopyNamehint<ConcatOp>(rewriter, op, extract, zeros);
/// Folder for ShrUOp with a constant shift amount: a zero shift returns
/// operand 0, and a shift of at least the result width yields a zero
/// constant.
367OpFoldResult ShrUOp::fold(FoldAdaptor adaptor) {
371 if (
auto rhs = dyn_cast_or_null<IntegerAttr>(adaptor.getRhs())) {
372 if (rhs.getValue().isZero())
373 return getOperand(0);
375 unsigned width = getType().getIntOrFloatBitWidth();
376 if (rhs.getValue().uge(width))
377 return getIntAttr(APInt::getZero(width), getContext());
/// Canonicalizer for ShrUOp: a shift by a constant amount is replaced by a
/// ConcatOp of `zeros` followed by `extract`.
382LogicalResult ShrUOp::canonicalize(
ShrUOp op, PatternRewriter &rewriter) {
// Only constant shift amounts are handled.
388 if (!matchPattern(op.getRhs(), m_ConstantInt(&value)))
391 unsigned width = cast<IntegerType>(op.getLhs().getType()).getWidth();
// Checked before getZExtValue() below converts the amount to `unsigned`.
392 if (value.ugt(width))
394 unsigned shift = value.getZExtValue();
397 if (width <= shift || shift == 0)
407 replaceOpWithNewOpAndCopyNamehint<ConcatOp>(rewriter, op, zeros, extract);
/// Folder for ShrSOp: a constant zero shift amount returns operand 0.
411OpFoldResult ShrSOp::fold(FoldAdaptor adaptor) {
415 if (
auto rhs = dyn_cast_or_null<IntegerAttr>(adaptor.getRhs()))
416 if (rhs.getValue().isZero())
417 return getOperand(0);
/// Canonicalizer for ShrSOp with a constant shift amount: the top bit of the
/// lhs is replicated `shift` times (`sext`) and concatenated in front of
/// `extract`.
421LogicalResult ShrSOp::canonicalize(
ShrSOp op, PatternRewriter &rewriter) {
// Only constant shift amounts are handled.
427 if (!matchPattern(op.getRhs(), m_ConstantInt(&value)))
430 unsigned width = cast<IntegerType>(op.getLhs().getType()).getWidth();
431 if (value.ugt(width))
433 unsigned shift = value.getZExtValue();
// Extract the single bit at index width - 1 of the lhs.
436 rewriter.createOrFold<
ExtractOp>(op.getLoc(), op.getLhs(), width - 1, 1);
437 auto sext = rewriter.createOrFold<ReplicateOp>(op.getLoc(), topbit, shift);
// Shifting by the full width is handled separately from the concat below.
439 if (width == shift) {
447 replaceOpWithNewOpAndCopyNamehint<ConcatOp>(rewriter, op, sext, extract);
/// Folder for ExtractOp. Handles an extract whose result type equals the
/// input type, and constant inputs, which are shifted right by the low bit
/// and truncated to the result width.
455OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
460 if (getInput().getType() == getType())
464 if (
auto input = dyn_cast_or_null<IntegerAttr>(adaptor.getInput())) {
465 unsigned dstWidth = cast<IntegerType>(getType()).getWidth();
466 return getIntAttr(input.getValue().lshr(getLowBit()).trunc(dstWidth),
// Rewrites an extract of a concat (`innerCat`) into a concat of the pieces
// that are actually covered by the extracted range. The concat inputs are
// walked in reverse operand order while bit offsets are accumulated.
477 PatternRewriter &rewriter) {
478 auto reversedConcatArgs = llvm::reverse(innerCat.getInputs());
479 size_t beginOfFirstRelevantElement = 0;
480 auto it = reversedConcatArgs.begin();
481 size_t lowBit = op.getLowBit();
// Phase 1: find the first element whose bit range contains `lowBit`.
484 for (; it != reversedConcatArgs.end(); it++) {
485 assert(beginOfFirstRelevantElement <= lowBit &&
486 "incorrectly moved past an element that lowBit has coverage over");
489 size_t operandWidth = operand.getType().getIntOrFloatBitWidth();
490 if (lowBit < beginOfFirstRelevantElement + operandWidth) {
514 beginOfFirstRelevantElement += operandWidth;
516 assert(it != reversedConcatArgs.end() &&
517 "incorrectly failed to find an element which contains coverage of "
// Phase 2: consume elements until the extracted width is exhausted.
520 SmallVector<Value> reverseConcatArgs;
521 size_t widthRemaining = cast<IntegerType>(op.getType()).getWidth();
522 size_t extractLo = lowBit - beginOfFirstRelevantElement;
527 for (; widthRemaining != 0 && it != reversedConcatArgs.end(); it++) {
528 auto concatArg = *it;
529 size_t operandWidth = concatArg.getType().getIntOrFloatBitWidth();
530 size_t widthToConsume = std::min(widthRemaining, operandWidth - extractLo);
// A fully covered element is reused as-is; a partially covered one gets a
// narrower result type.
532 if (widthToConsume == operandWidth && extractLo == 0) {
533 reverseConcatArgs.push_back(concatArg);
535 auto resultType = IntegerType::get(rewriter.getContext(), widthToConsume);
536 reverseConcatArgs.push_back(
540 widthRemaining -= widthToConsume;
// A single collected piece is handled separately from the concat rebuild.
546 if (reverseConcatArgs.size() == 1) {
// The pieces were collected in reverse, so flip them back for the new concat.
549 replaceOpWithNewOpAndCopyNamehint<ConcatOp>(
550 rewriter, op, SmallVector<Value>(llvm::reverse(reverseConcatArgs)));
// Simplifies an extract whose input is a ReplicateOp (`replicate`).
557 PatternRewriter &rewriter) {
558 auto extractResultWidth = cast<IntegerType>(op.getType()).getWidth();
559 auto replicateEltWidth =
560 replicate.getOperand().getType().getIntOrFloatBitWidth();
// Low bit and width are both multiples of the element width: the result is
// itself a replicate of the same operand.
564 if (op.getLowBit() % replicateEltWidth == 0 &&
565 extractResultWidth % replicateEltWidth == 0) {
566 replaceOpWithNewOpAndCopyNamehint<ReplicateOp>(rewriter, op, op.getType(),
567 replicate.getOperand());
// NOTE(review): presumably the extracted range fits inside one replicated
// element (the right-hand side of the comparison is not visible); the extract
// is then taken directly from the operand at the wrapped low bit.
573 if (op.getLowBit() % replicateEltWidth + extractResultWidth <=
575 replaceOpWithNewOpAndCopyNamehint<ExtractOp>(
576 rewriter, op, op.getType(), replicate.getOperand(),
577 op.getLowBit() % replicateEltWidth);
/// Canonicalizer for ExtractOp. The visible patterns cover: extracted bits
/// that are fully known, extract of extract, extract of concat / replicate,
/// extract of a binary and/or/xor with a constant rhs, and a single-bit
/// extract of a shl.
586LogicalResult ExtractOp::canonicalize(
ExtractOp op, PatternRewriter &rewriter) {
589 auto *inputOp = op.getInput().getDefiningOp();
596 .extractBits(cast<IntegerType>(op.getType()).getWidth(),
// Every extracted bit is known: replace with a constant.
598 if (knownBits.isConstant()) {
599 replaceOpWithNewOpAndCopyNamehint<hw::ConstantOp>(rewriter, op,
600 knownBits.getConstant());
// extract(extract(x)) -> one extract of x with the low bits summed.
606 if (
auto innerExtract = dyn_cast_or_null<ExtractOp>(inputOp)) {
607 replaceOpWithNewOpAndCopyNamehint<ExtractOp>(
608 rewriter, op, op.getType(), innerExtract.getInput(),
609 innerExtract.getLowBit() + op.getLowBit());
614 if (
auto innerCat = dyn_cast_or_null<ConcatOp>(inputOp))
618 if (
auto replicate = dyn_cast_or_null<ReplicateOp>(inputOp))
// Binary and/or/xor whose second operand is a constant.
624 if (inputOp && inputOp->getNumOperands() == 2 &&
625 isa<AndOp, OrOp, XorOp>(inputOp)) {
626 if (
auto cstRHS = inputOp->getOperand(1).getDefiningOp<
hw::ConstantOp>()) {
627 auto extractedCst = cstRHS.getValue().extractBits(
628 cast<IntegerType>(op.getType()).getWidth(), op.getLowBit());
// or/xor with an all-zero slice of the constant: extract operand 0 directly.
629 if (isa<OrOp, XorOp>(inputOp) && extractedCst.isZero()) {
630 replaceOpWithNewOpAndCopyNamehint<ExtractOp>(
631 rewriter, op, op.getType(), inputOp->getOperand(0), op.getLowBit());
639 if (isa<AndOp>(inputOp)) {
642 unsigned lz = extractedCst.countLeadingZeros();
643 unsigned tz = extractedCst.countTrailingZeros();
644 unsigned pop = extractedCst.popcount();
// True when the set bits of the mask slice form one contiguous run.
645 if (extractedCst.getBitWidth() - lz - tz == pop) {
646 auto resultTy = rewriter.getIntegerType(pop);
647 SmallVector<Value> resultElts;
650 APInt::getZero(lz)));
651 resultElts.push_back(rewriter.createOrFold<
ExtractOp>(
652 op.getLoc(), resultTy, inputOp->getOperand(0),
653 op.getLowBit() + tz));
656 APInt::getZero(tz)));
657 replaceOpWithNewOpAndCopyNamehint<ConcatOp>(rewriter, op, resultElts);
// Single-bit extract of a shl: when the shifted constant is one, the bit is
// set exactly when the shift amount equals the extracted bit index.
666 if (cast<IntegerType>(op.getType()).getWidth() == 1 && inputOp)
667 if (
auto shlOp = dyn_cast<ShlOp>(inputOp)) {
669 if (shlOp->hasOneUse())
671 if (lhsCst.getValue().isOne()) {
673 rewriter, shlOp.getLoc(),
674 APInt(lhsCst.getValue().getBitWidth(), op.getLowBit()));
675 replaceOpWithNewOpAndCopyNamehint<ICmpOp>(
676 rewriter, op, ICmpPredicate::eq, shlOp->getOperand(1), newCst,
// Folds a variadic op over attribute operands into a hw::ParamExprAttr with
// opcode `paramOpcode`. Requires more than one operand.
692 hw::PEO paramOpcode) {
693 assert(operands.size() > 1 &&
"caller should handle one-operand case");
// The first two operands are checked for null here; the rest below.
696 if (!operands[1] || !operands[0])
700 if (llvm::all_of(operands.drop_front(2),
701 [&](Attribute in) { return !!in; })) {
702 SmallVector<mlir::TypedAttr> typedOperands;
703 typedOperands.reserve(operands.size());
704 for (
auto operand : operands) {
705 if (
auto typedOperand = dyn_cast<mlir::TypedAttr>(operand))
706 typedOperands.push_back(typedOperand);
// Build the expression only when every operand was a TypedAttr.
710 if (typedOperands.size() == operands.size())
711 return hw::ParamExprAttr::get(paramOpcode, typedOperands);
// Pushes a logical op (and/or/xor) with constant `cst` through the concat
// found at operand `concatIdx`: each concat operand is combined with the
// matching slice of `cst`, and the pieces are concatenated again.
727 size_t concatIdx,
const APInt &cst,
728 PatternRewriter &rewriter) {
729 auto concatOp = logicalOp->getOperand(concatIdx).getDefiningOp<
ConcatOp>();
730 assert((isa<AndOp, OrOp, XorOp>(logicalOp) && concatOp));
// Looks for a concat operand that is a constant, or a single-use op of the
// same kind as `logicalOp` whose last operand is a constant.
735 llvm::any_of(concatOp->getOperands(), [&](Value operand) ->
bool {
736 auto *operandOp = operand.getDefiningOp();
741 if (isa<hw::ConstantOp>(operandOp))
745 return operandOp->getName() == logicalOp->getName() &&
746 operandOp->hasOneUse() && operandOp->getNumOperands() != 0 &&
747 operandOp->getOperands().back().getDefiningOp<hw::ConstantOp>();
// Helper that creates an op of the same kind and location as `logicalOp`.
755 auto createLogicalOp = [&](ArrayRef<Value> operands) -> Value {
756 return createGenericOp(logicalOp->getLoc(), logicalOp->getName(), operands,
763 SmallVector<Value> newConcatOperands;
764 newConcatOperands.reserve(concatOp->getNumOperands());
// Start at the full width and count down: the first concat operand occupies
// the highest bits.
767 size_t nextOperandBit = concatOp.getType().getIntOrFloatBitWidth();
768 for (Value operand : concatOp->getOperands()) {
769 size_t operandWidth = operand.getType().getIntOrFloatBitWidth();
770 nextOperandBit -= operandWidth;
774 cst.lshr(nextOperandBit).trunc(operandWidth));
776 newConcatOperands.push_back(createLogicalOp({operand, eltCst}));
781 ConcatOp::create(rewriter, concatOp.getLoc(), newConcatOperands);
// With more than two operands, rebuild the logical op from the operands other
// than the concat and the last one, plus the new concat result.
785 if (logicalOp->getNumOperands() > 2) {
786 auto origOperands = logicalOp->getOperands();
787 SmallVector<Value> operands;
789 operands.append(origOperands.begin(), origOperands.begin() + concatIdx);
791 operands.append(origOperands.begin() + concatIdx + 1,
792 origOperands.begin() + (origOperands.size() - 1));
794 operands.push_back(newResult);
795 newResult = createLogicalOp(operands);
// Scans `operands` for two-state ICmpOps and records each
// (predicate, lhs, rhs) triple; a hit occurs when the negated predicate over
// the same lhs/rhs has already been recorded.
805 llvm::SmallDenseSet<std::tuple<ICmpPredicate, Value, Value>> seenPredicates;
807 for (
auto op : operands) {
808 if (
auto icmpOp = op.getDefiningOp<ICmpOp>();
809 icmpOp && icmpOp.getTwoState()) {
810 auto predicate = icmpOp.getPredicate();
811 auto lhs = icmpOp.getLhs();
812 auto rhs = icmpOp.getRhs();
813 if (seenPredicates.contains(
814 {ICmpOp::getNegatedPredicate(predicate), lhs, rhs}))
817 seenPredicates.insert({predicate, lhs, rhs});
/// Folder for AndOp. Accumulates constant operands starting from all-ones,
/// drops an all-ones second operand of a binary and, and returns the single
/// value when all inputs are identical.
823OpFoldResult AndOp::fold(FoldAdaptor adaptor) {
// All-ones is the identity element for `&`.
827 APInt value = APInt::getAllOnes(cast<IntegerType>(getType()).
getWidth());
829 auto inputs = adaptor.getInputs();
832 for (
auto operand : inputs) {
835 value &= cast<IntegerAttr>(operand).getValue();
// and(x, all-ones) -> x
841 if (inputs.size() == 2 && inputs[1] &&
842 cast<IntegerAttr>(inputs[1]).getValue().isAllOnes())
843 return getInputs()[0];
// and(x, x, ..., x) -> x
846 if (llvm::all_of(getInputs(),
847 [&](
auto in) {
return in == this->getInputs()[0]; }))
848 return getInputs()[0];
// Pairwise scan over the inputs that can produce an all-zero constant.
851 for (Value arg : getInputs()) {
854 for (Value arg2 : getInputs())
857 APInt::getZero(cast<IntegerType>(getType()).
getWidth()),
// For a 1-bit op whose inputs are all ExtractOps of one common source,
// returns that source when the extracts together touch every bit of it;
// otherwise returns a null Value.
878template <
typename Op>
880 if (!op.getType().isInteger(1))
883 auto inputs = op.getInputs();
884 size_t size = inputs.size();
886 auto sourceOp = inputs[0].template getDefiningOp<ExtractOp>();
889 Value source = sourceOp.getOperand();
// There must be exactly one input per bit of the source.
892 if (size != source.getType().getIntOrFloatBitWidth())
// One flag per source bit, set at the low bit of each extract.
896 llvm::BitVector bits(size);
897 bits.set(sourceOp.getLowBit());
899 for (
size_t i = 1; i != size; ++i) {
900 auto extractOp = inputs[i].template getDefiningOp<ExtractOp>();
901 if (!extractOp || extractOp.getOperand() != source)
903 bits.set(extractOp.getLowBit());
906 return bits.all() ? source : Value();
// Removes inputs of `op` that already appear as inputs of a nested op of the
// same kind, searching nested ops up to a fixed depth.
913template <
typename Op>
// Maximum nesting depth explored by the search.
916 constexpr unsigned limit = 3;
917 auto inputs = op.getInputs();
// Deduplicated, order-preserving copy of the inputs.
919 llvm::SmallSetVector<Value, 8> uniqueInputs(inputs.begin(), inputs.end());
920 llvm::SmallDenseSet<Op, 8> checked;
927 llvm::SmallVector<OpWithDepth, 8> worklist;
// Queues the defining op of `input` when it is below the depth limit, in the
// same block as `op`, has the same two-state flag, and was not seen before.
929 auto enqueue = [&worklist, &checked, &op](Value input,
unsigned depth) {
933 if (depth < limit && input.getParentBlock() == op->getBlock()) {
934 auto inputOp = input.template getDefiningOp<Op>();
935 if (inputOp && inputOp.getTwoState() == op.getTwoState() &&
936 checked.insert(inputOp).second)
937 worklist.push_back({inputOp, depth + 1});
941 for (
auto input : uniqueInputs)
944 while (!worklist.empty()) {
945 auto item = worklist.pop_back_val();
// Anything consumed by a nested op is redundant at the top level.
947 for (
auto input : item.op.getInputs()) {
948 uniqueInputs.remove(input);
949 enqueue(input, item.depth);
// Rewrite only when at least one input was dropped.
953 if (uniqueInputs.size() < inputs.size()) {
954 replaceOpWithNewOpAndCopyNamehint<Op>(rewriter, op, op.getType(),
955 uniqueInputs.getArrayRef(),
/// Canonicalizer for AndOp. The visible patterns handle a trailing constant
/// (all-ones, two trailing constants, single-bit mask over a replicate,
/// mask with leading/trailing zeros), concat operands, and a final rewrite
/// into an equality comparison.
963LogicalResult AndOp::canonicalize(
AndOp op, PatternRewriter &rewriter) {
967 auto inputs = op.getInputs();
968 auto size = inputs.size();
980 assert(size > 1 &&
"expected 2 or more operands, `fold` should handle this");
984 if (matchPattern(inputs.back(), m_ConstantInt(&value))) {
// and(..., all-ones) -> and(...)
986 if (value.isAllOnes()) {
987 replaceOpWithNewOpAndCopyNamehint<AndOp>(rewriter, op, op.getType(),
988 inputs.drop_back(),
false);
// Two trailing constants are replaced by the single constant `cst`.
996 if (matchPattern(inputs[size - 2], m_ConstantInt(&value2))) {
998 SmallVector<Value, 4> newOperands(inputs.drop_back(2));
999 newOperands.push_back(cst);
1000 replaceOpWithNewOpAndCopyNamehint<AndOp>(rewriter, op, op.getType(),
1001 newOperands,
false);
// and(replicate(bit), single-bit mask) -> concat(zeros, bit, zeros)
1006 if (size == 2 && value.isPowerOf2()) {
1011 if (
auto replicate = inputs[0].getDefiningOp<ReplicateOp>()) {
1012 auto replicateOperand = replicate.getOperand();
1013 if (replicateOperand.getType().isInteger(1)) {
1014 unsigned resultWidth = op.getType().getIntOrFloatBitWidth();
1015 auto trailingZeros = value.countTrailingZeros();
1018 SmallVector<Value, 3> concatOperands;
// High zero padding is skipped when the set bit is the top bit.
1019 if (trailingZeros != resultWidth - 1) {
1021 rewriter, op.getLoc(),
1022 APInt::getZero(resultWidth - trailingZeros - 1));
1023 concatOperands.push_back(highZeros);
1025 concatOperands.push_back(replicateOperand);
1026 if (trailingZeros != 0) {
1028 rewriter, op.getLoc(), APInt::getZero(trailingZeros));
1029 concatOperands.push_back(lowZeros);
1031 replaceOpWithNewOpAndCopyNamehint<ConcatOp>(
1032 rewriter, op, op.getType(), concatOperands);
// Mask with leading and/or trailing zeros: operate on the `maskLength` bits
// in the middle only, then pad the result back with zeros.
1041 unsigned leadingZeros = value.countLeadingZeros();
1042 unsigned trailingZeros = value.countTrailingZeros();
1043 if (leadingZeros > 0 || trailingZeros > 0) {
1044 unsigned maskLength = value.getBitWidth() - leadingZeros - trailingZeros;
1047 SmallVector<Value> operands;
1048 for (
auto input : inputs.drop_back()) {
1049 unsigned offset = trailingZeros;
// Look through chains of extracts, accumulating their low bits.
1050 while (
auto extractOp = input.getDefiningOp<
ExtractOp>()) {
1051 input = extractOp.getInput();
1052 offset += extractOp.getLowBit();
1055 offset, maskLength));
// The narrowed mask is only needed when it is not all ones.
1059 auto narrowMask = value.extractBits(maskLength, trailingZeros);
1060 if (!narrowMask.isAllOnes())
1062 rewriter, inputs.back().getLoc(), narrowMask));
1065 Value narrowValue = operands.back();
1066 if (operands.size() > 1)
1068 AndOp::create(rewriter, op.getLoc(), operands, op.getTwoState());
1072 if (leadingZeros > 0)
1074 rewriter, op.getLoc(), APInt::getZero(leadingZeros)));
1075 operands.push_back(narrowValue);
1076 if (trailingZeros > 0)
1078 rewriter, op.getLoc(), APInt::getZero(trailingZeros)));
1079 replaceOpWithNewOpAndCopyNamehint<ConcatOp>(rewriter, op, operands);
// Check every operand except the last for a defining ConcatOp.
1086 for (
size_t i = 0; i < size - 1; ++i) {
1087 if (
auto concat = inputs[i].getDefiningOp<ConcatOp>())
1101 replaceOpWithNewOpAndCopyNamehint<ICmpOp>(rewriter, op, ICmpPredicate::eq,
1102 source, cmpAgainst);
/// Folder for OrOp. Accumulates constant operands starting from zero, drops
/// a zero second operand of a binary or, returns the single value when all
/// inputs are identical, and folds `x | ~x` to all-ones.
1110OpFoldResult OrOp::fold(FoldAdaptor adaptor) {
// Zero is the identity element for `|`.
1114 auto value = APInt::getZero(cast<IntegerType>(getType()).
getWidth());
1115 auto inputs = adaptor.getInputs();
1117 for (
auto operand : inputs) {
1120 value |= cast<IntegerAttr>(operand).getValue();
1121 if (value.isAllOnes())
// or(x, 0) -> x
1126 if (inputs.size() == 2 && inputs[1] &&
1127 cast<IntegerAttr>(inputs[1]).getValue().isZero())
1128 return getInputs()[0];
// or(x, x, ..., x) -> x
1131 if (llvm::all_of(getInputs(),
1132 [&](
auto in) {
return in == this->getInputs()[0]; }))
1133 return getInputs()[0];
// An input that is the complement of another input makes the result all-ones.
1136 for (Value arg : getInputs()) {
1138 if (matchPattern(arg,
m_Complement(m_Any(&subExpr)))) {
1139 for (Value arg2 : getInputs())
1140 if (arg2 == subExpr)
1142 APInt::getAllOnes(cast<IntegerType>(getType()).
getWidth()),
1152 APInt::getAllOnes(cast<IntegerType>(getType()).
getWidth()),
/// Canonicalizer for OrOp. The visible patterns handle a trailing constant,
/// concat operands, a rewrite into an inequality comparison, and merging an
/// or of two-state muxes that share their true/false values.
1159LogicalResult OrOp::canonicalize(
OrOp op, PatternRewriter &rewriter) {
1163 auto inputs = op.getInputs();
1164 auto size = inputs.size();
1176 assert(size > 1 &&
"expected 2 or more operands");
1180 if (matchPattern(inputs.back(), m_ConstantInt(&value))) {
// or(..., 0) -> or(...)
1182 if (value.isZero()) {
1183 replaceOpWithNewOpAndCopyNamehint<OrOp>(rewriter, op, op.getType(),
1184 inputs.drop_back());
// Two trailing constants are replaced by the single constant `cst`.
1190 if (matchPattern(inputs[size - 2], m_ConstantInt(&value2))) {
1192 SmallVector<Value, 4> newOperands(inputs.drop_back(2));
1193 newOperands.push_back(cst);
1194 replaceOpWithNewOpAndCopyNamehint<OrOp>(rewriter, op, op.getType(),
// Check every operand except the last for a defining ConcatOp.
1202 for (
size_t i = 0; i < size - 1; ++i) {
1203 if (
auto concat = inputs[i].getDefiningOp<ConcatOp>())
1217 replaceOpWithNewOpAndCopyNamehint<ICmpOp>(rewriter, op, ICmpPredicate::ne,
1218 source, cmpAgainst);
// or(mux(c0, t, f), mux(c1, t, f), ...) -> mux(or(c0, c1, ...), t, f), for
// two-state ops whose false value is a constant.
1224 if (
auto firstMux = op.getOperand(0).getDefiningOp<
comb::MuxOp>()) {
1226 if (op.getTwoState() && firstMux.getTwoState() &&
1227 matchPattern(firstMux.getFalseValue(), m_ConstantInt(&value)) &&
1229 SmallVector<Value> conditions{firstMux.getCond()};
// Collects the condition of each remaining mux and verifies it matches the
// first mux's true/false values.
1230 auto check = [&](Value v) {
1234 conditions.push_back(mux.getCond());
1235 return mux.getTwoState() &&
1236 firstMux.getTrueValue() == mux.getTrueValue() &&
1237 firstMux.getFalseValue() == mux.getFalseValue();
1239 if (llvm::all_of(op.getOperands().drop_front(), check)) {
1240 auto cond = comb::OrOp::create(rewriter, op.getLoc(), conditions,
true);
1241 replaceOpWithNewOpAndCopyNamehint<comb::MuxOp>(
1242 rewriter, op, cond, firstMux.getTrueValue(),
1243 firstMux.getFalseValue(),
true);
/// Folder for XorOp. Visible cases: xor(x, x) folds to zero, xor(x, 0)
/// folds to x, and a binary-not applied to a complement is matched.
1253OpFoldResult XorOp::fold(FoldAdaptor adaptor) {
1257 auto size = getInputs().size();
1258 auto inputs = adaptor.getInputs();
1262 return getInputs()[0];
// xor(x, x) -> 0
1265 if (size == 2 && getInputs()[0] == getInputs()[1])
1266 return IntegerAttr::get(getType(), 0);
// xor(x, 0) -> x
1269 if (inputs.size() == 2 && inputs[1] &&
1270 cast<IntegerAttr>(inputs[1]).getValue().isZero())
1271 return getInputs()[0];
// Double negation; the match is rejected when `subExpr` is this op's own
// result.
1275 if (isBinaryNot()) {
1277 if (matchPattern(getOperand(0),
m_Complement(m_Any(&subExpr))) &&
1278 subExpr != getResult())
// Replaces the ICmpOp feeding operand `icmpOperand` of the xor `op` with an
// ICmpOp that uses the negated predicate on the same operands.
1288 PatternRewriter &rewriter) {
1289 auto icmp = op.getOperand(icmpOperand).getDefiningOp<ICmpOp>();
1290 auto negatedPred = ICmpOp::getNegatedPredicate(icmp.getPredicate());
1293 ICmpOp::create(rewriter, icmp.getLoc(), negatedPred, icmp.getOperand(0),
1294 icmp.getOperand(1), icmp.getTwoState());
// With more than two operands, rebuild the xor without its last operand and
// without the original icmp operand, then append the negated compare.
1297 if (op.getNumOperands() > 2) {
1298 SmallVector<Value, 4> newOperands(op.getOperands());
1299 newOperands.pop_back();
1300 newOperands.erase(newOperands.begin() + icmpOperand);
1301 newOperands.push_back(result);
1303 XorOp::create(rewriter, op.getLoc(), newOperands, op.getTwoState());
/// Canonicalizer for XorOp. The visible patterns handle two identical
/// trailing operands, a trailing constant, single-bit xor with an ICmpOp
/// operand, and a final rewrite into ParityOp.
1309LogicalResult XorOp::canonicalize(
XorOp op, PatternRewriter &rewriter) {
1313 auto inputs = op.getInputs();
1314 auto size = inputs.size();
1315 assert(size > 1 &&
"expected 2 or more operands");
// xor(..., x, x) -> xor(...)
1318 if (inputs[size - 1] == inputs[size - 2]) {
1320 "expected idempotent case for 2 elements handled already.");
1321 replaceOpWithNewOpAndCopyNamehint<XorOp>(rewriter, op, op.getType(),
1322 inputs.drop_back(2),
false);
1328 if (matchPattern(inputs.back(), m_ConstantInt(&value))) {
// xor(..., 0) -> xor(...)
1330 if (value.isZero()) {
1331 replaceOpWithNewOpAndCopyNamehint<XorOp>(rewriter, op, op.getType(),
1332 inputs.drop_back(),
false);
// Two trailing constants are replaced by the single constant `cst`.
1338 if (matchPattern(inputs[size - 2], m_ConstantInt(&value2))) {
1340 SmallVector<Value, 4> newOperands(inputs.drop_back(2));
1341 newOperands.push_back(cst);
1342 replaceOpWithNewOpAndCopyNamehint<XorOp>(rewriter, op, op.getType(),
1343 newOperands,
false);
1347 bool isSingleBit = value.getBitWidth() == 1;
// Inspect every operand except the trailing constant.
1350 for (
size_t i = 0; i < size - 1; ++i) {
1351 Value operand = inputs[i];
// Single-bit xor with the constant one and a single-use operand: look for an
// ICmpOp defining that operand.
1362 if (isSingleBit && operand.hasOneUse()) {
1363 assert(value == 1 &&
"single bit constant has to be one if not zero");
1364 if (
auto icmp = operand.getDefiningOp<ICmpOp>())
1380 replaceOpWithNewOpAndCopyNamehint<ParityOp>(rewriter, op, source);
/// Folder for SubOp. Visible cases: identical operands fold to zero, two
/// constant operands become the parameter expression lhs + (rhs * -1), and a
/// constant zero rhs is detected.
1387OpFoldResult SubOp::fold(FoldAdaptor adaptor) {
// sub(x, x) -> 0
1392 if (getRhs() == getLhs())
1394 APInt::getZero(getLhs().getType().getIntOrFloatBitWidth()),
1397 if (adaptor.getRhs()) {
1399 if (adaptor.getLhs()) {
// All-ones is the two's-complement encoding of -1 at this width.
1402 APInt::getAllOnes(getLhs().getType().getIntOrFloatBitWidth()),
1404 auto rhsNeg = hw::ParamExprAttr::get(
1405 hw::PEO::Mul, cast<TypedAttr>(adaptor.getRhs()), negOne);
1406 return hw::ParamExprAttr::get(hw::PEO::Add,
1407 cast<TypedAttr>(adaptor.getLhs()), rhsNeg);
1411 if (
auto rhsC = dyn_cast<IntegerAttr>(adaptor.getRhs())) {
1412 if (rhsC.getValue().isZero())
/// Canonicalizer for SubOp: a subtraction with a constant rhs is replaced by
/// an AddOp of the lhs and `negCst`.
1420LogicalResult SubOp::canonicalize(
SubOp op, PatternRewriter &rewriter) {
1426 if (matchPattern(op.getRhs(), m_ConstantInt(&value))) {
1428 replaceOpWithNewOpAndCopyNamehint<AddOp>(rewriter, op, op.getLhs(), negCst,
/// Folder for AddOp; the visible path returns the first input unchanged.
// NOTE(review): the condition guarding that return is not visible --
// presumably the single-operand case; confirm.
1440OpFoldResult AddOp::fold(FoldAdaptor adaptor) {
1444 auto size = getInputs().size();
1448 return getInputs()[0];
/// Canonicalizer for AddOp. The visible patterns work on the trailing
/// operands: dropping a zero constant, merging two constants, turning
/// repeated or shifted/multiplied copies of a value into a shl/mul, and
/// merging the constants of a nested binary add.
1454LogicalResult AddOp::canonicalize(
AddOp op, PatternRewriter &rewriter) {
1458 auto inputs = op.getInputs();
1459 auto size = inputs.size();
1460 assert(size > 1 &&
"expected 2 or more operands");
1462 APInt value, value2;
// add(..., 0) -> add(...)
1465 if (matchPattern(inputs.back(), m_ConstantInt(&value)) && value.isZero()) {
1466 replaceOpWithNewOpAndCopyNamehint<AddOp>(rewriter, op, op.getType(),
1467 inputs.drop_back(),
false);
// add(..., c1, c2) -> add(..., cst)
1472 if (matchPattern(inputs[size - 1], m_ConstantInt(&value)) &&
1473 matchPattern(inputs[size - 2], m_ConstantInt(&value2))) {
1475 SmallVector<Value, 4> newOperands(inputs.drop_back(2));
1476 newOperands.push_back(cst);
1477 replaceOpWithNewOpAndCopyNamehint<AddOp>(rewriter, op, op.getType(),
1478 newOperands,
false);
// add(..., x, x) -> add(..., shl(x, one))
1483 if (inputs[size - 1] == inputs[size - 2]) {
1484 SmallVector<Value, 4> newOperands(inputs.drop_back(2));
1488 comb::ShlOp::create(rewriter, op.getLoc(), inputs.back(), one,
false);
1490 newOperands.push_back(shiftLeftOp);
1491 replaceOpWithNewOpAndCopyNamehint<AddOp>(rewriter, op, op.getType(),
1492 newOperands,
false);
1496 auto shlOp = inputs[size - 1].getDefiningOp<
comb::ShlOp>();
// add(..., x, shl(x, c)) -> add(..., mul(x, rhs)); the computation of `rhs`
// is not visible in this excerpt.
1498 if (shlOp && shlOp.getLhs() == inputs[size - 2] &&
1499 matchPattern(shlOp.getRhs(), m_ConstantInt(&value))) {
1501 APInt one(value.getBitWidth(), 1,
false);
1505 std::array<Value, 2> factors = {shlOp.getLhs(), rhs};
1506 auto mulOp = comb::MulOp::create(rewriter, op.getLoc(), factors,
false);
1508 SmallVector<Value, 4> newOperands(inputs.drop_back(2));
1509 newOperands.push_back(mulOp);
1510 replaceOpWithNewOpAndCopyNamehint<AddOp>(rewriter, op, op.getType(),
1511 newOperands,
false);
1515 auto mulOp = inputs[size - 1].getDefiningOp<
comb::MulOp>();
// add(..., x, mul(x, c)) -> add(..., mul(x, rhs)); the computation of `rhs`
// is not visible in this excerpt.
1517 if (mulOp && mulOp.getInputs().size() == 2 &&
1518 mulOp.getInputs()[0] == inputs[size - 2] &&
1519 matchPattern(mulOp.getInputs()[1], m_ConstantInt(&value))) {
1521 APInt one(value.getBitWidth(), 1,
false);
1523 std::array<Value, 2> factors = {mulOp.getInputs()[0], rhs};
1524 auto newMulOp = comb::MulOp::create(rewriter, op.getLoc(), factors,
false);
1526 SmallVector<Value, 4> newOperands(inputs.drop_back(2));
1527 newOperands.push_back(newMulOp);
1528 replaceOpWithNewOpAndCopyNamehint<AddOp>(rewriter, op, op.getType(),
1529 newOperands,
false);
// add(add(x, c1), c2) -> add(x, rhs). The result is two-state only when both
// the outer and the inner add are.
1542 auto addOp = inputs[0].getDefiningOp<
comb::AddOp>();
1543 if (addOp && addOp.getInputs().size() == 2 &&
1544 matchPattern(addOp.getInputs()[1], m_ConstantInt(&value2)) &&
1545 inputs.size() == 2 && matchPattern(inputs[1], m_ConstantInt(&value))) {
1548 replaceOpWithNewOpAndCopyNamehint<AddOp>(
1549 rewriter, op, op.getType(), ArrayRef<Value>{addOp.getInputs()[0], rhs},
1550 op.getTwoState() && addOp.getTwoState());
/// Folder for MulOp. Visible paths: returning the first input, returning a
/// zero-width zero constant, and multiplying constant operands together
/// starting from one.
1557OpFoldResult MulOp::fold(FoldAdaptor adaptor) {
1561 auto size = getInputs().size();
1562 auto inputs = adaptor.getInputs();
1566 return getInputs()[0];
1568 auto width = cast<IntegerType>(getType()).getWidth();
1570 return getIntAttr(APInt::getZero(0), getContext());
// One is the identity element for `*`.
1572 APInt value(width, 1,
false);
1575 for (
auto operand : inputs) {
1578 value *= cast<IntegerAttr>(operand).getValue();
/// Canonicalizer for MulOp. The visible patterns handle a binary multiply by
/// a power of two, a trailing constant one, and two trailing constants.
1587LogicalResult MulOp::canonicalize(
MulOp op, PatternRewriter &rewriter) {
1591 auto inputs = op.getInputs();
1592 auto size = inputs.size();
1593 assert(size > 1 &&
"expected 2 or more operands");
1595 APInt value, value2;
// mul(x, 2^n): build shl(x, n) and use it as the sole operand of the
// replacement MulOp.
1598 if (size == 2 && matchPattern(inputs.back(), m_ConstantInt(&value)) &&
1599 value.isPowerOf2()) {
1601 value.exactLogBase2());
1603 comb::ShlOp::create(rewriter, op.getLoc(), inputs[0], shift,
false);
1605 replaceOpWithNewOpAndCopyNamehint<MulOp>(rewriter, op, op.getType(),
1606 ArrayRef<Value>(shlOp),
false);
// mul(..., 1) -> mul(...)
1611 if (matchPattern(inputs.back(), m_ConstantInt(&value)) && value.isOne()) {
1612 replaceOpWithNewOpAndCopyNamehint<MulOp>(rewriter, op, op.getType(),
1613 inputs.drop_back());
// mul(..., c1, c2) -> mul(..., cst)
1618 if (matchPattern(inputs[size - 1], m_ConstantInt(&value)) &&
1619 matchPattern(inputs[size - 2], m_ConstantInt(&value2))) {
1621 SmallVector<Value, 4> newOperands(inputs.drop_back(2));
1622 newOperands.push_back(cst);
1623 replaceOpWithNewOpAndCopyNamehint<MulOp>(rewriter, op, op.getType(),
/// Shared folder for the division ops, parameterized on the op type and on
/// signedness. `constants[1]` is the divisor attribute; the visible checks
/// single out a divisor of one and a divisor of zero.
1639template <
class Op,
bool isSigned>
1640static OpFoldResult
foldDiv(Op op, ArrayRef<Attribute> constants) {
1641 if (
auto rhsValue = dyn_cast_or_null<IntegerAttr>(constants[1])) {
1643 if (rhsValue.getValue() == 1)
1647 if (rhsValue.getValue().isZero())
/// Folder for DivUOp: delegates to foldDiv with `isSigned` set to false.
1654OpFoldResult DivUOp::fold(FoldAdaptor adaptor) {
1657 return foldDiv<
DivUOp,
false>(*
this, adaptor.getOperands());
/// Folder for DivSOp.
1660OpFoldResult DivSOp::fold(FoldAdaptor adaptor) {
/// Shared folder for the modulo ops, parameterized on the op type and on
/// signedness. A divisor of one and a zero dividend both fold to a zero
/// constant of the result width; a zero divisor is checked separately.
1666template <
class Op,
bool isSigned>
1667static OpFoldResult
foldMod(Op op, ArrayRef<Attribute> constants) {
1668 if (
auto rhsValue = dyn_cast_or_null<IntegerAttr>(constants[1])) {
// mod(x, 1) -> 0
1670 if (rhsValue.getValue() == 1)
1671 return getIntAttr(APInt::getZero(op.getType().getIntOrFloatBitWidth()),
1675 if (rhsValue.getValue().isZero())
1679 if (
auto lhsValue = dyn_cast_or_null<IntegerAttr>(constants[0])) {
// mod(0, x) -> 0
1681 if (lhsValue.getValue().isZero())
1682 return getIntAttr(APInt::getZero(op.getType().getIntOrFloatBitWidth()),
/// Folder for ModUOp: delegates to foldMod with `isSigned` set to false.
1689OpFoldResult ModUOp::fold(FoldAdaptor adaptor) {
1692 return foldMod<
ModUOp,
false>(*
this, adaptor.getOperands());
/// Folder for ModSOp.
1695OpFoldResult ModSOp::fold(FoldAdaptor adaptor) {
/// Folder for ConcatOp: a single-operand concat returns that operand, and a
/// concat of integer constants is assembled into one constant.
1705OpFoldResult ConcatOp::fold(FoldAdaptor adaptor) {
1709 if (getNumOperands() == 1)
1710 return getOperand(0);
// Every input must be a known integer constant for the fold below.
1713 for (
auto attr : adaptor.getInputs())
1714 if (!attr || !isa<IntegerAttr>(attr))
1718 unsigned resultWidth = getType().getIntOrFloatBitWidth();
1719 APInt result(resultWidth, 0);
// Insert from the top down: the first operand lands in the highest bits.
1721 unsigned nextInsertion = resultWidth;
1723 for (
auto attr : adaptor.getInputs()) {
1724 auto chunk = cast<IntegerAttr>(attr).getValue();
1725 nextInsertion -= chunk.getBitWidth();
1726 result.insertBits(chunk, nextInsertion);
1732LogicalResult ConcatOp::canonicalize(
ConcatOp op, PatternRewriter &rewriter) {
1736 auto inputs = op.getInputs();
1737 auto size = inputs.size();
1738 assert(size > 1 &&
"expected 2 or more operands");
1743 auto flattenConcat = [&](
size_t firstOpIndex,
size_t lastOpIndex,
1744 ValueRange replacements) -> LogicalResult {
1745 SmallVector<Value, 4> newOperands;
1746 newOperands.append(inputs.begin(), inputs.begin() + firstOpIndex);
1747 newOperands.append(replacements.begin(), replacements.end());
1748 newOperands.append(inputs.begin() + lastOpIndex + 1, inputs.end());
1749 if (newOperands.size() == 1)
1752 replaceOpWithNewOpAndCopyNamehint<ConcatOp>(rewriter, op, op.getType(),
1757 Value commonOperand = inputs[0];
1758 for (
size_t i = 0; i != size; ++i) {
1760 if (inputs[i] != commonOperand)
1761 commonOperand = Value();
1765 if (
auto subConcat = inputs[i].getDefiningOp<ConcatOp>())
1766 return flattenConcat(i, i, subConcat->getOperands());
1771 if (
auto cst = inputs[i].getDefiningOp<hw::ConstantOp>()) {
1772 if (
auto prevCst = inputs[i - 1].getDefiningOp<hw::ConstantOp>()) {
1773 unsigned prevWidth = prevCst.getValue().getBitWidth();
1774 unsigned thisWidth = cst.getValue().getBitWidth();
1775 auto resultCst = cst.getValue().zext(prevWidth + thisWidth);
1776 resultCst |= prevCst.getValue().zext(prevWidth + thisWidth)
1780 return flattenConcat(i - 1, i, replacement);
1785 if (inputs[i] == inputs[i - 1]) {
1787 rewriter.createOrFold<ReplicateOp>(op.getLoc(), inputs[i], 2);
1788 return flattenConcat(i - 1, i, replacement);
1793 if (
auto repl = inputs[i].getDefiningOp<ReplicateOp>()) {
1795 if (repl.getOperand() == inputs[i - 1]) {
1796 Value replacement = rewriter.createOrFold<ReplicateOp>(
1797 op.getLoc(), repl.getOperand(), repl.getMultiple() + 1);
1798 return flattenConcat(i - 1, i, replacement);
1801 if (
auto prevRepl = inputs[i - 1].getDefiningOp<ReplicateOp>()) {
1802 if (prevRepl.getOperand() == repl.getOperand()) {
1803 Value replacement = rewriter.createOrFold<ReplicateOp>(
1804 op.getLoc(), repl.getOperand(),
1805 repl.getMultiple() + prevRepl.getMultiple());
1806 return flattenConcat(i - 1, i, replacement);
1812 if (
auto repl = inputs[i - 1].getDefiningOp<ReplicateOp>()) {
1813 if (repl.getOperand() == inputs[i]) {
1814 Value replacement = rewriter.createOrFold<ReplicateOp>(
1815 op.getLoc(), inputs[i], repl.getMultiple() + 1);
1816 return flattenConcat(i - 1, i, replacement);
1822 if (
auto extract = inputs[i].getDefiningOp<ExtractOp>()) {
1823 if (
auto prevExtract = inputs[i - 1].getDefiningOp<ExtractOp>()) {
1824 if (extract.getInput() == prevExtract.getInput()) {
1825 auto thisWidth = cast<IntegerType>(extract.getType()).getWidth();
1826 if (prevExtract.getLowBit() == extract.getLowBit() + thisWidth) {
1827 auto prevWidth = prevExtract.getType().getIntOrFloatBitWidth();
1828 auto resType = rewriter.getIntegerType(thisWidth + prevWidth);
1831 extract.getInput(), extract.getLowBit());
1832 return flattenConcat(i - 1, i, replacement);
1845 static std::optional<ArraySlice>
get(Value value) {
1846 assert(isa<IntegerType>(value.getType()) &&
"expected integer type");
1848 return ArraySlice{arrayGet.getInput(), arrayGet.getIndex(), 1};
1851 if (
auto arraySlice =
1854 arraySlice.getInput(), arraySlice.getLowIndex(),
1855 hw::type_cast<hw::ArrayType>(arraySlice.getType())
1857 return std::nullopt;
1860 if (
auto extractOpt = ArraySlice::get(inputs[i])) {
1861 if (
auto prevExtractOpt = ArraySlice::get(inputs[i - 1])) {
1863 if (prevExtractOpt->index.getType() == extractOpt->index.getType() &&
1864 prevExtractOpt->input == extractOpt->input &&
1865 hw::isOffset(extractOpt->index, prevExtractOpt->index,
1866 extractOpt->width)) {
1867 auto resType = hw::ArrayType::get(
1868 hw::type_cast<hw::ArrayType>(prevExtractOpt->input.getType())
1870 extractOpt->width + prevExtractOpt->width);
1871 auto resIntType = rewriter.getIntegerType(hw::getBitWidth(resType));
1873 rewriter, op.getLoc(), resIntType,
1875 prevExtractOpt->input,
1876 extractOpt->index));
1877 return flattenConcat(i - 1, i, replacement);
1885 if (commonOperand) {
1886 replaceOpWithNewOpAndCopyNamehint<ReplicateOp>(rewriter, op, op.getType(),
1898OpFoldResult MuxOp::fold(FoldAdaptor adaptor) {
1903 if (getTrueValue() == getFalseValue() && getTrueValue() != getResult())
1904 return getTrueValue();
1905 if (
auto tv = adaptor.getTrueValue())
1906 if (tv == adaptor.getFalseValue())
1911 if (
auto pred = dyn_cast_or_null<IntegerAttr>(adaptor.getCond())) {
1912 if (pred.getValue().isZero() && getFalseValue() != getResult())
1913 return getFalseValue();
1914 if (pred.getValue().isOne() && getTrueValue() != getResult())
1915 return getTrueValue();
1919 if (getCond().getType() == getTrueValue().getType())
1920 if (
auto tv = dyn_cast_or_null<IntegerAttr>(adaptor.getTrueValue()))
1921 if (
auto fv = dyn_cast_or_null<IntegerAttr>(adaptor.getFalseValue()))
1922 if (tv.getValue().isOne() && fv.getValue().isZero() &&
1923 hw::getBitWidth(getType()) == 1 && getCond() != getResult())
1939 if (
auto cmp = cond.getDefiningOp<ICmpOp>()) {
1941 auto requiredPredicate =
1942 (isInverted ? ICmpPredicate::eq : ICmpPredicate::ne);
1943 if (cmp.getLhs() == indexValue && cmp.getPredicate() == requiredPredicate) {
1953 if (
auto orOp = cond.getDefiningOp<
OrOp>()) {
1956 for (
auto operand : orOp.getOperands())
1963 if (
auto andOp = cond.getDefiningOp<
AndOp>()) {
1966 for (
auto operand : andOp.getOperands())
1985 PatternRewriter &rewriter,
MuxOp rootMux,
bool isFalseSide,
1991 auto rootCmp = rootMux.getCond().getDefiningOp<ICmpOp>();
1994 Value indexValue = rootCmp.getLhs();
1997 auto getCaseValue = [&](
MuxOp mux) -> Value {
1998 return mux.getOperand(1 +
unsigned(!isFalseSide));
2003 auto getTreeValue = [&](
MuxOp mux) -> Value {
2004 return mux.getOperand(1 +
unsigned(isFalseSide));
2009 SmallVector<Location> locationsFound;
2010 SmallVector<std::pair<hw::ConstantOp, Value>, 4> valuesFound;
2014 auto collectConstantValues = [&](
MuxOp mux) ->
bool {
2016 mux.getCond(), indexValue, isFalseSide, [&](
hw::ConstantOp cst) {
2017 valuesFound.push_back({cst, getCaseValue(mux)});
2018 locationsFound.push_back(mux.getCond().getLoc());
2019 locationsFound.push_back(mux->getLoc());
2024 if (!collectConstantValues(rootMux))
2028 if (rootMux->hasOneUse()) {
2029 if (
auto userMux = dyn_cast<MuxOp>(*rootMux->user_begin())) {
2030 if (getTreeValue(userMux) == rootMux.getResult() &&
2038 auto nextTreeValue = getTreeValue(rootMux);
2040 auto nextMux = nextTreeValue.getDefiningOp<
MuxOp>();
2041 if (!nextMux || !nextMux->hasOneUse())
2043 if (!collectConstantValues(nextMux))
2045 nextTreeValue = getTreeValue(nextMux);
2048 auto indexWidth = cast<IntegerType>(indexValue.getType()).getWidth();
2050 if (indexWidth > 20)
2053 auto foldingStyle = styleFn(indexWidth, valuesFound.size());
2057 uint64_t tableSize = 1ULL << indexWidth;
2061 SmallVector<Value, 8> table(tableSize, nextTreeValue);
2066 for (
auto &elt :
llvm::reverse(valuesFound)) {
2067 uint64_t idx = elt.first.getValue().getZExtValue();
2068 assert(idx < table.size() &&
"constant should be same bitwidth as index");
2069 table[idx] = elt.second;
2073 SmallVector<Value> bits;
2082 "unknown folding style");
2086 std::reverse(table.begin(), table.end());
2089 auto fusedLoc = rewriter.getFusedLoc(locationsFound);
2091 replaceOpWithNewOpAndCopyNamehint<hw::ArrayGetOp>(rewriter, rootMux, array,
2106 PatternRewriter &rewriter) {
2107 assert(fullyAssoc->getNumOperands() >= 2 &&
"cannot split up unary ops");
2108 assert(operandNo < fullyAssoc->getNumOperands() &&
"Invalid operand #");
2112 if (fullyAssoc->getNumOperands() == 2)
2113 return fullyAssoc->getOperand(operandNo ^ 1);
2116 if (fullyAssoc->hasOneUse()) {
2117 rewriter.modifyOpInPlace(fullyAssoc,
2118 [&]() { fullyAssoc->eraseOperand(operandNo); });
2119 return fullyAssoc->getResult(0);
2123 SmallVector<Value> operands;
2124 operands.append(fullyAssoc->getOperands().begin(),
2125 fullyAssoc->getOperands().begin() + operandNo);
2126 operands.append(fullyAssoc->getOperands().begin() + operandNo + 1,
2127 fullyAssoc->getOperands().end());
2129 fullyAssoc->getLoc(), fullyAssoc->getName(), operands, rewriter);
2130 Value excluded = fullyAssoc->getOperand(operandNo);
2134 ArrayRef<Value>{opWithoutExcluded, excluded}, rewriter);
2136 return opWithoutExcluded;
2146 PatternRewriter &rewriter) {
2149 Operation *subExpr =
2150 (isTrueOperand ? op.getFalseValue() : op.getTrueValue()).getDefiningOp();
2151 if (!subExpr || subExpr->getNumOperands() < 2)
2155 if (!isa<AndOp, XorOp, OrOp, MuxOp>(subExpr))
2160 Value commonValue = isTrueOperand ? op.getTrueValue() : op.getFalseValue();
2161 size_t opNo = 0, e = subExpr->getNumOperands();
2162 while (opNo != e && subExpr->getOperand(opNo) != commonValue)
2168 Value cond = op.getCond();
2174 if (
auto subMux = dyn_cast<MuxOp>(subExpr)) {
2179 Value subCond = subMux.getCond();
2182 if (subMux.getTrueValue() == commonValue)
2183 otherValue = subMux.getFalseValue();
2184 else if (subMux.getFalseValue() == commonValue) {
2185 otherValue = subMux.getTrueValue();
2195 cond = rewriter.createOrFold<
OrOp>(op.getLoc(), cond, subCond,
false);
2196 replaceOpWithNewOpAndCopyNamehint<MuxOp>(rewriter, op, cond, commonValue,
2197 otherValue, op.getTwoState());
2203 bool isaAndOp = isa<AndOp>(subExpr);
2204 if (isTrueOperand ^ isaAndOp)
2208 rewriter.createOrFold<ReplicateOp>(op.getLoc(), op.getType(), cond);
2211 bool isaXorOp = isa<XorOp>(subExpr);
2212 bool isaOrOp = isa<OrOp>(subExpr);
2221 if (isaOrOp || isaXorOp) {
2222 auto masked = rewriter.createOrFold<
AndOp>(op.getLoc(), extendedCond,
2223 restOfAssoc,
false);
2225 replaceOpWithNewOpAndCopyNamehint<XorOp>(rewriter, op, masked,
2226 commonValue,
false);
2228 replaceOpWithNewOpAndCopyNamehint<OrOp>(rewriter, op, masked, commonValue,
2234 assert(isaAndOp &&
"unexpected operation here");
2235 auto masked = rewriter.createOrFold<
OrOp>(op.getLoc(), extendedCond,
2236 restOfAssoc,
false);
2237 replaceOpWithNewOpAndCopyNamehint<AndOp>(rewriter, op, masked, commonValue,
2248 PatternRewriter &rewriter) {
2251 if (!isa<ConcatOp>(trueOp))
2255 SmallVector<Value> trueOperands, falseOperands;
2259 size_t numTrueOperands = trueOperands.size();
2260 size_t numFalseOperands = falseOperands.size();
2262 if (!numTrueOperands || !numFalseOperands ||
2263 (trueOperands.front() != falseOperands.front() &&
2264 trueOperands.back() != falseOperands.back()))
2268 if (trueOperands.front() == falseOperands.front()) {
2269 SmallVector<Value> operands;
2271 for (i = 0; i < numTrueOperands; ++i) {
2272 Value trueOperand = trueOperands[i];
2273 if (trueOperand == falseOperands[i])
2274 operands.push_back(trueOperand);
2278 if (i == numTrueOperands) {
2285 if (llvm::all_of(operands, [&](Value v) {
return v == operands.front(); }))
2286 sharedMSB = rewriter.createOrFold<ReplicateOp>(
2287 mux->getLoc(), operands.front(), operands.size());
2289 sharedMSB = rewriter.createOrFold<
ConcatOp>(mux->getLoc(), operands);
2293 operands.append(trueOperands.begin() + i, trueOperands.end());
2294 Value trueLSB = rewriter.createOrFold<
ConcatOp>(trueOp->getLoc(), operands);
2296 operands.append(falseOperands.begin() + i, falseOperands.end());
2298 rewriter.createOrFold<
ConcatOp>(falseOp->getLoc(), operands);
2301 Value lsb = rewriter.createOrFold<
MuxOp>(
2302 mux->getLoc(), mux.getCond(), trueLSB, falseLSB, mux.getTwoState());
2303 replaceOpWithNewOpAndCopyNamehint<ConcatOp>(rewriter, mux, sharedMSB, lsb);
2308 if (trueOperands.back() == falseOperands.back()) {
2309 SmallVector<Value> operands;
2312 Value trueOperand = trueOperands[numTrueOperands - i - 1];
2313 if (trueOperand == falseOperands[numFalseOperands - i - 1])
2314 operands.push_back(trueOperand);
2318 std::reverse(operands.begin(), operands.end());
2319 Value sharedLSB = rewriter.createOrFold<
ConcatOp>(mux->getLoc(), operands);
2323 operands.append(trueOperands.begin(), trueOperands.end() - i);
2324 Value trueMSB = rewriter.createOrFold<
ConcatOp>(trueOp->getLoc(), operands);
2326 operands.append(falseOperands.begin(), falseOperands.end() - i);
2328 rewriter.createOrFold<
ConcatOp>(falseOp->getLoc(), operands);
2330 Value msb = rewriter.createOrFold<
MuxOp>(
2331 mux->getLoc(), mux.getCond(), trueMSB, falseMSB, mux.getTwoState());
2332 replaceOpWithNewOpAndCopyNamehint<ConcatOp>(rewriter, mux, msb, sharedLSB);
2344 if (!trueVec || !falseVec)
2346 if (!trueVec.isUniform() || !falseVec.isUniform())
2349 auto mux = MuxOp::create(rewriter, op.getLoc(), op.getCond(),
2350 trueVec.getUniformElement(),
2351 falseVec.getUniformElement(), op.getTwoState());
2353 SmallVector<Value> values(trueVec.getInputs().size(), mux);
2361 bool constCond, PatternRewriter &rewriter) {
2362 if (!muxValue.hasOneUse())
2364 auto *op = muxValue.getDefiningOp();
2365 if (!op || !isa_and_nonnull<CombDialect>(op->getDialect()))
2367 if (!llvm::is_contained(op->getOperands(), muxCond))
2369 OpBuilder::InsertionGuard guard(rewriter);
2370 rewriter.setInsertionPoint(op);
2373 rewriter.modifyOpInPlace(op, [&] {
2374 for (
auto &use : op->getOpOperands())
2375 if (use.get() == muxCond)
2383 using OpRewritePattern::OpRewritePattern;
2385 LogicalResult matchAndRewrite(
MuxOp op,
2386 PatternRewriter &rewriter)
const override;
2390foldToArrayCreateOnlyWhenDense(
size_t indexWidth,
size_t numEntries) {
2393 if (indexWidth >= 9 || numEntries < 3)
2399 uint64_t tableSize = 1ULL << indexWidth;
2400 if (numEntries >= tableSize * 5 / 8)
2405LogicalResult MuxRewriter::matchAndRewrite(
MuxOp op,
2406 PatternRewriter &rewriter)
const {
2410 bool isSignlessInt =
false;
2411 if (
auto intType = dyn_cast<IntegerType>(op.getType()))
2412 isSignlessInt = intType.isSignless();
2419 if (matchPattern(op.getTrueValue(), m_ConstantInt(&value)) && isSignlessInt) {
2420 if (value.getBitWidth() == 1) {
2422 if (value.isZero()) {
2424 replaceOpWithNewOpAndCopyNamehint<AndOp>(rewriter, op, notCond,
2425 op.getFalseValue(),
false);
2430 replaceOpWithNewOpAndCopyNamehint<OrOp>(rewriter, op, op.getCond(),
2431 op.getFalseValue(),
false);
2437 if (matchPattern(op.getFalseValue(), m_ConstantInt(&value2))) {
2442 APInt xorValue = value ^ value2;
2443 if (xorValue.isPowerOf2()) {
2444 unsigned leadingZeros = xorValue.countLeadingZeros();
2445 unsigned trailingZeros = value.getBitWidth() - leadingZeros - 1;
2446 SmallVector<Value, 3> operands;
2454 if (leadingZeros > 0)
2455 operands.push_back(rewriter.createOrFold<
ExtractOp>(
2456 op.getLoc(), op.getTrueValue(), trailingZeros + 1, leadingZeros));
2460 auto v1 = rewriter.createOrFold<
ExtractOp>(
2461 op.getLoc(), op.getTrueValue(), trailingZeros, 1);
2462 auto v2 = rewriter.createOrFold<
ExtractOp>(
2463 op.getLoc(), op.getFalseValue(), trailingZeros, 1);
2464 operands.push_back(rewriter.createOrFold<
MuxOp>(
2465 op.getLoc(), op.getCond(), v1, v2,
false));
2467 if (trailingZeros > 0)
2468 operands.push_back(rewriter.createOrFold<
ExtractOp>(
2469 op.getLoc(), op.getTrueValue(), 0, trailingZeros));
2471 replaceOpWithNewOpAndCopyNamehint<ConcatOp>(rewriter, op, op.getType(),
2478 if (value.isAllOnes() && value2.isZero()) {
2479 replaceOpWithNewOpAndCopyNamehint<ReplicateOp>(
2480 rewriter, op, op.getType(), op.getCond());
2486 if (matchPattern(op.getFalseValue(), m_ConstantInt(&value)) &&
2487 isSignlessInt && value.getBitWidth() == 1) {
2489 if (value.isZero()) {
2490 replaceOpWithNewOpAndCopyNamehint<AndOp>(rewriter, op, op.getCond(),
2491 op.getTrueValue(),
false);
2498 auto notCond = rewriter.createOrFold<
XorOp>(op.getLoc(), op.getCond(),
2499 op.getFalseValue(),
false);
2500 replaceOpWithNewOpAndCopyNamehint<OrOp>(rewriter, op, notCond,
2501 op.getTrueValue(),
false);
2507 Operation *condOp = op.getCond().getDefiningOp();
2508 if (condOp && matchPattern(condOp,
m_Complement(m_Any(&subExpr))) &&
2510 replaceOpWithNewOpAndCopyNamehint<MuxOp>(rewriter, op, op.getType(),
2511 subExpr, op.getFalseValue(),
2512 op.getTrueValue(),
true);
2519 if (condOp && condOp->hasOneUse()) {
2520 SmallVector<Value> invertedOperands;
2524 auto getInvertedOperands = [&]() ->
bool {
2525 for (Value operand : condOp->getOperands()) {
2526 if (matchPattern(operand,
m_Complement(m_Any(&subExpr))))
2527 invertedOperands.push_back(subExpr);
2534 if (isa<AndOp>(condOp) && getInvertedOperands()) {
2536 rewriter.createOrFold<
OrOp>(op.getLoc(), invertedOperands,
false);
2537 replaceOpWithNewOpAndCopyNamehint<MuxOp>(
2538 rewriter, op, newOr, op.getFalseValue(), op.getTrueValue(),
2542 if (isa<OrOp>(condOp) && getInvertedOperands()) {
2544 rewriter.createOrFold<
AndOp>(op.getLoc(), invertedOperands,
false);
2545 replaceOpWithNewOpAndCopyNamehint<MuxOp>(
2546 rewriter, op, newAnd, op.getFalseValue(), op.getTrueValue(),
2552 if (
auto falseMux = op.getFalseValue().getDefiningOp<
MuxOp>();
2553 falseMux && falseMux != op) {
2555 if (op.getCond() == falseMux.getCond() &&
2556 falseMux.getFalseValue() != falseMux) {
2557 replaceOpWithNewOpAndCopyNamehint<MuxOp>(
2558 rewriter, op, op.getCond(), op.getTrueValue(),
2559 falseMux.getFalseValue(), op.getTwoStateAttr());
2565 foldToArrayCreateOnlyWhenDense))
2569 if (
auto trueMux = op.getTrueValue().getDefiningOp<
MuxOp>();
2570 trueMux && trueMux != op) {
2572 if (op.getCond() == trueMux.getCond()) {
2573 replaceOpWithNewOpAndCopyNamehint<MuxOp>(
2574 rewriter, op, op.getCond(), trueMux.getTrueValue(),
2575 op.getFalseValue(), op.getTwoStateAttr());
2581 foldToArrayCreateOnlyWhenDense))
2586 if (
auto trueMux = dyn_cast_or_null<MuxOp>(op.getTrueValue().getDefiningOp()),
2587 falseMux = dyn_cast_or_null<MuxOp>(op.getFalseValue().getDefiningOp());
2588 trueMux && falseMux && trueMux.getCond() == falseMux.getCond() &&
2589 trueMux.getTrueValue() == falseMux.getTrueValue() && trueMux != op &&
2591 auto subMux = MuxOp::create(
2592 rewriter, rewriter.getFusedLoc({trueMux.getLoc(), falseMux.getLoc()}),
2593 op.getCond(), trueMux.getFalseValue(), falseMux.getFalseValue());
2594 replaceOpWithNewOpAndCopyNamehint<MuxOp>(rewriter, op, trueMux.getCond(),
2595 trueMux.getTrueValue(), subMux,
2596 op.getTwoStateAttr());
2601 if (
auto trueMux = dyn_cast_or_null<MuxOp>(op.getTrueValue().getDefiningOp()),
2602 falseMux = dyn_cast_or_null<MuxOp>(op.getFalseValue().getDefiningOp());
2603 trueMux && falseMux && trueMux.getCond() == falseMux.getCond() &&
2604 trueMux.getFalseValue() == falseMux.getFalseValue() && trueMux != op &&
2606 auto subMux = MuxOp::create(
2607 rewriter, rewriter.getFusedLoc({trueMux.getLoc(), falseMux.getLoc()}),
2608 op.getCond(), trueMux.getTrueValue(), falseMux.getTrueValue());
2609 replaceOpWithNewOpAndCopyNamehint<MuxOp>(rewriter, op, trueMux.getCond(),
2610 subMux, trueMux.getFalseValue(),
2611 op.getTwoStateAttr());
2616 if (
auto trueMux = dyn_cast_or_null<MuxOp>(op.getTrueValue().getDefiningOp()),
2617 falseMux = dyn_cast_or_null<MuxOp>(op.getFalseValue().getDefiningOp());
2618 trueMux && falseMux &&
2619 trueMux.getTrueValue() == falseMux.getTrueValue() &&
2620 trueMux.getFalseValue() == falseMux.getFalseValue() && trueMux != op &&
2623 MuxOp::create(rewriter,
2624 rewriter.getFusedLoc(
2625 {op.getLoc(), trueMux.getLoc(), falseMux.getLoc()}),
2626 op.getCond(), trueMux.getCond(), falseMux.getCond());
2627 replaceOpWithNewOpAndCopyNamehint<MuxOp>(
2628 rewriter, op, subMux, trueMux.getTrueValue(), trueMux.getFalseValue(),
2629 op.getTwoStateAttr());
2641 if (Operation *trueOp = op.getTrueValue().getDefiningOp())
2642 if (Operation *falseOp = op.getFalseValue().getDefiningOp())
2643 if (trueOp->getName() == falseOp->getName())
2656 if (op.getTrueValue().getDefiningOp() &&
2657 op.getTrueValue().getDefiningOp() != op)
2660 if (op.getFalseValue().getDefiningOp() &&
2661 op.getFalseValue().getDefiningOp() != op)
2672 if (op.getInputs().empty() || op.isUniform())
2674 auto inputs = op.getInputs();
2675 if (inputs.size() <= 1)
2680 auto first = inputs[0].getDefiningOp<
comb::MuxOp>();
2685 for (
size_t i = 1, n = inputs.size(); i < n; ++i) {
2686 auto input = inputs[i].getDefiningOp<
comb::MuxOp>();
2687 if (!input || first.getCond() != input.getCond())
2692 SmallVector<Value> trues{first.getTrueValue()};
2693 SmallVector<Value> falses{first.getFalseValue()};
2694 SmallVector<Location> locs{first->getLoc()};
2695 bool isTwoState =
true;
2696 for (
size_t i = 1, n = inputs.size(); i < n; ++i) {
2697 auto input = inputs[i].getDefiningOp<
comb::MuxOp>();
2698 trues.push_back(input.getTrueValue());
2699 falses.push_back(input.getFalseValue());
2700 locs.push_back(input->getLoc());
2701 if (!input.getTwoState())
2706 auto loc = FusedLoc::get(op.getContext(), locs);
2710 auto arrayTy = op.getType();
2713 rewriter.replaceOpWithNewOp<
comb::MuxOp>(op, arrayTy, first.getCond(),
2714 trueValues, falseValues, isTwoState);
2719 using OpRewritePattern::OpRewritePattern;
2722 PatternRewriter &rewriter)
const override {
2723 if (foldArrayOfMuxes(op, rewriter))
2731void MuxOp::getCanonicalizationPatterns(RewritePatternSet &results,
2732 MLIRContext *context) {
2733 results.insert<MuxRewriter, ArrayRewriter>(context);
2744 switch (predicate) {
2745 case ICmpPredicate::eq:
2747 case ICmpPredicate::ne:
2749 case ICmpPredicate::slt:
2750 return lhs.slt(rhs);
2751 case ICmpPredicate::sle:
2752 return lhs.sle(rhs);
2753 case ICmpPredicate::sgt:
2754 return lhs.sgt(rhs);
2755 case ICmpPredicate::sge:
2756 return lhs.sge(rhs);
2757 case ICmpPredicate::ult:
2758 return lhs.ult(rhs);
2759 case ICmpPredicate::ule:
2760 return lhs.ule(rhs);
2761 case ICmpPredicate::ugt:
2762 return lhs.ugt(rhs);
2763 case ICmpPredicate::uge:
2764 return lhs.uge(rhs);
2765 case ICmpPredicate::ceq:
2767 case ICmpPredicate::cne:
2769 case ICmpPredicate::weq:
2771 case ICmpPredicate::wne:
2774 llvm_unreachable(
"unknown comparison predicate");
2780 switch (predicate) {
2781 case ICmpPredicate::eq:
2782 case ICmpPredicate::sle:
2783 case ICmpPredicate::sge:
2784 case ICmpPredicate::ule:
2785 case ICmpPredicate::uge:
2786 case ICmpPredicate::ceq:
2787 case ICmpPredicate::weq:
2789 case ICmpPredicate::ne:
2790 case ICmpPredicate::slt:
2791 case ICmpPredicate::sgt:
2792 case ICmpPredicate::ult:
2793 case ICmpPredicate::ugt:
2794 case ICmpPredicate::cne:
2795 case ICmpPredicate::wne:
2798 llvm_unreachable(
"unknown comparison predicate");
2801OpFoldResult ICmpOp::fold(FoldAdaptor adaptor) {
2804 if (getLhs() == getRhs()) {
2806 return IntegerAttr::get(getType(), val);
2810 if (
auto lhs = dyn_cast_or_null<IntegerAttr>(adaptor.getLhs())) {
2811 if (
auto rhs = dyn_cast_or_null<IntegerAttr>(adaptor.getRhs())) {
2814 return IntegerAttr::get(getType(), val);
2822template <
typename Range>
2824 size_t commonPrefixLength = 0;
2825 auto ia = a.begin();
2826 auto ib = b.begin();
2828 for (; ia != a.end() && ib != b.end(); ia++, ib++, commonPrefixLength++) {
2834 return commonPrefixLength;
2838 size_t totalWidth = 0;
2839 for (
auto operand : operands) {
2842 ssize_t width = operand.getType().getIntOrFloatBitWidth();
2844 totalWidth += width;
2854 PatternRewriter &rewriter) {
2858 SmallVector<Value> lhsOperands, rhsOperands;
2861 ArrayRef<Value> lhsOperandsRef = lhsOperands, rhsOperandsRef = rhsOperands;
2863 auto formCatOrReplicate = [&](Location loc,
2864 ArrayRef<Value> operands) -> Value {
2865 assert(!operands.empty());
2866 Value sameElement = operands[0];
2867 for (
size_t i = 1, e = operands.size(); i != e && sameElement; ++i)
2868 if (sameElement != operands[i])
2869 sameElement = Value();
2871 return rewriter.createOrFold<ReplicateOp>(loc, sameElement,
2873 return rewriter.createOrFold<
ConcatOp>(loc, operands);
2876 auto replaceWith = [&](ICmpPredicate predicate, Value lhs,
2877 Value rhs) -> LogicalResult {
2878 replaceOpWithNewOpAndCopyNamehint<ICmpOp>(rewriter, op, predicate, lhs, rhs,
2883 size_t commonPrefixLength =
2885 if (commonPrefixLength == lhsOperands.size()) {
2888 replaceOpWithNewOpAndCopyNamehint<hw::ConstantOp>(rewriter, op,
2894 llvm::reverse(lhsOperandsRef), llvm::reverse(rhsOperandsRef));
2896 size_t commonPrefixTotalWidth =
2897 getTotalWidth(lhsOperandsRef.take_front(commonPrefixLength));
2898 size_t commonSuffixTotalWidth =
2899 getTotalWidth(lhsOperandsRef.take_back(commonSuffixLength));
2900 auto lhsOnly = lhsOperandsRef.drop_front(commonPrefixLength)
2901 .drop_back(commonSuffixLength);
2902 auto rhsOnly = rhsOperandsRef.drop_front(commonPrefixLength)
2903 .drop_back(commonSuffixLength);
2905 auto replaceWithoutReplicatingSignBit = [&]() {
2906 auto newLhs = formCatOrReplicate(lhs->getLoc(), lhsOnly);
2907 auto newRhs = formCatOrReplicate(rhs->getLoc(), rhsOnly);
2908 return replaceWith(op.getPredicate(), newLhs, newRhs);
2911 auto replaceWithReplicatingSignBit = [&]() {
2912 auto firstNonEmptyValue = lhsOperands[0];
2913 auto firstNonEmptyElemWidth =
2914 firstNonEmptyValue.getType().getIntOrFloatBitWidth();
2915 Value signBit = rewriter.createOrFold<
ExtractOp>(
2916 op.getLoc(), firstNonEmptyValue, firstNonEmptyElemWidth - 1, 1);
2918 auto newLhs = ConcatOp::create(rewriter, lhs->getLoc(), signBit, lhsOnly);
2919 auto newRhs = ConcatOp::create(rewriter, rhs->getLoc(), signBit, rhsOnly);
2920 return replaceWith(op.getPredicate(), newLhs, newRhs);
2923 if (ICmpOp::isPredicateSigned(op.getPredicate())) {
2925 if (commonPrefixTotalWidth == 0 && commonSuffixTotalWidth > 0)
2926 return replaceWithoutReplicatingSignBit();
2932 if (commonPrefixTotalWidth > 1 || commonSuffixTotalWidth > 0)
2933 return replaceWithReplicatingSignBit();
2935 }
else if (commonPrefixTotalWidth > 0 || commonSuffixTotalWidth > 0) {
2937 return replaceWithoutReplicatingSignBit();
2951 ICmpOp cmpOp,
const KnownBits &bitAnalysis,
const APInt &rhsCst,
2952 PatternRewriter &rewriter) {
2956 APInt bitsKnown = bitAnalysis.Zero | bitAnalysis.One;
2957 if ((bitsKnown & rhsCst) != bitAnalysis.One) {
2960 bool result = cmpOp.getPredicate() == ICmpPredicate::ne;
2961 replaceOpWithNewOpAndCopyNamehint<hw::ConstantOp>(rewriter, cmpOp,
2969 SmallVector<Value> newConcatOperands;
2970 auto newConstant = APInt::getZeroWidth();
2975 unsigned knownMSB = bitsKnown.countLeadingOnes();
2977 Value operand = cmpOp.getLhs();
2982 while (knownMSB != bitsKnown.getBitWidth()) {
2985 bitsKnown = bitsKnown.trunc(bitsKnown.getBitWidth() - knownMSB);
2988 unsigned unknownBits = bitsKnown.countLeadingZeros();
2989 unsigned lowBit = bitsKnown.getBitWidth() - unknownBits;
2990 auto spanOperand = rewriter.createOrFold<
ExtractOp>(
2991 operand.getLoc(), operand, lowBit,
2993 auto spanConstant = rhsCst.lshr(lowBit).trunc(unknownBits);
2996 newConcatOperands.push_back(spanOperand);
2999 if (newConstant.getBitWidth() != 0)
3000 newConstant = newConstant.concat(spanConstant);
3002 newConstant = spanConstant;
3005 unsigned newWidth = bitsKnown.getBitWidth() - unknownBits;
3006 bitsKnown = bitsKnown.trunc(newWidth);
3007 knownMSB = bitsKnown.countLeadingOnes();
3013 if (newConcatOperands.empty()) {
3014 bool result = cmpOp.getPredicate() == ICmpPredicate::eq;
3015 replaceOpWithNewOpAndCopyNamehint<hw::ConstantOp>(rewriter, cmpOp,
3021 Value concatResult =
3022 rewriter.createOrFold<
ConcatOp>(operand.getLoc(), newConcatOperands);
3026 rewriter, cmpOp.getOperand(1).getLoc(), newConstant);
3028 replaceOpWithNewOpAndCopyNamehint<ICmpOp>(rewriter, cmpOp,
3029 cmpOp.getPredicate(), concatResult,
3030 newConstantOp, cmpOp.getTwoState());
3036 PatternRewriter &rewriter) {
3037 auto ip = rewriter.saveInsertionPoint();
3038 rewriter.setInsertionPoint(xorOp);
3040 auto xorRHS = xorOp.getOperands().back().getDefiningOp<
hw::ConstantOp>();
3042 xorRHS.getValue() ^ rhs);
3044 switch (xorOp.getNumOperands()) {
3048 APInt::getZero(rhs.getBitWidth()));
3052 newLHS = xorOp.getOperand(0);
3056 SmallVector<Value> newOperands(xorOp.getOperands());
3057 newOperands.pop_back();
3058 newLHS = XorOp::create(rewriter, xorOp.getLoc(), newOperands,
false);
3062 bool xorMultipleUses = !xorOp->hasOneUse();
3066 if (xorMultipleUses)
3067 replaceOpWithNewOpAndCopyNamehint<XorOp>(rewriter, xorOp, newLHS, xorRHS,
3071 rewriter.restoreInsertionPoint(ip);
3072 replaceOpWithNewOpAndCopyNamehint<ICmpOp>(
3073 rewriter, cmpOp, cmpOp.getPredicate(), newLHS, newRHS,
false);
3076LogicalResult ICmpOp::canonicalize(ICmpOp op, PatternRewriter &rewriter) {
3082 if (matchPattern(op.getLhs(), m_ConstantInt(&lhs))) {
3083 assert(!matchPattern(op.getRhs(), m_ConstantInt(&rhs)) &&
3084 "Should be folded");
3085 replaceOpWithNewOpAndCopyNamehint<ICmpOp>(
3086 rewriter, op, ICmpOp::getFlippedPredicate(op.getPredicate()),
3087 op.getRhs(), op.getLhs(), op.getTwoState());
3092 if (matchPattern(op.getRhs(), m_ConstantInt(&rhs))) {
3097 auto replaceWith = [&](ICmpPredicate predicate, Value lhs,
3098 Value rhs) -> LogicalResult {
3099 replaceOpWithNewOpAndCopyNamehint<ICmpOp>(rewriter, op, predicate, lhs,
3100 rhs, op.getTwoState());
3104 auto replaceWithConstantI1 = [&](
bool constant) -> LogicalResult {
3105 replaceOpWithNewOpAndCopyNamehint<hw::ConstantOp>(rewriter, op,
3106 APInt(1, constant));
3110 switch (op.getPredicate()) {
3111 case ICmpPredicate::slt:
3113 if (rhs.isMaxSignedValue())
3114 return replaceWith(ICmpPredicate::ne, op.getLhs(), op.getRhs());
3116 if (rhs.isMinSignedValue())
3117 return replaceWithConstantI1(0);
3119 if ((rhs - 1).isMinSignedValue())
3120 return replaceWith(ICmpPredicate::eq, op.getLhs(),
3123 case ICmpPredicate::sgt:
3125 if (rhs.isMinSignedValue())
3126 return replaceWith(ICmpPredicate::ne, op.getLhs(), op.getRhs());
3128 if (rhs.isMaxSignedValue())
3129 return replaceWithConstantI1(0);
3131 if ((rhs + 1).isMaxSignedValue())
3132 return replaceWith(ICmpPredicate::eq, op.getLhs(),
3135 case ICmpPredicate::ult:
3137 if (rhs.isAllOnes())
3138 return replaceWith(ICmpPredicate::ne, op.getLhs(), op.getRhs());
3141 return replaceWithConstantI1(0);
3143 if ((rhs - 1).isZero())
3144 return replaceWith(ICmpPredicate::eq, op.getLhs(),
3148 if (rhs.countLeadingOnes() + rhs.countTrailingZeros() ==
3149 rhs.getBitWidth()) {
3150 auto numOnes = rhs.countLeadingOnes();
3152 rhs.getBitWidth() - numOnes, numOnes);
3153 return replaceWith(ICmpPredicate::ne, smaller,
3158 case ICmpPredicate::ugt:
3161 return replaceWith(ICmpPredicate::ne, op.getLhs(), op.getRhs());
3163 if (rhs.isAllOnes())
3164 return replaceWithConstantI1(0);
3166 if ((rhs + 1).isAllOnes())
3167 return replaceWith(ICmpPredicate::eq, op.getLhs(),
3171 if ((rhs + 1).isPowerOf2()) {
3172 auto numOnes = rhs.countTrailingOnes();
3173 auto newWidth = rhs.getBitWidth() - numOnes;
3176 return replaceWith(ICmpPredicate::ne, smaller,
3181 case ICmpPredicate::sle:
3183 if (rhs.isMaxSignedValue())
3184 return replaceWithConstantI1(1);
3186 return replaceWith(ICmpPredicate::slt, op.getLhs(),
getConstant(rhs + 1));
3187 case ICmpPredicate::sge:
3189 if (rhs.isMinSignedValue())
3190 return replaceWithConstantI1(1);
3192 return replaceWith(ICmpPredicate::sgt, op.getLhs(),
getConstant(rhs - 1));
3193 case ICmpPredicate::ule:
3195 if (rhs.isAllOnes())
3196 return replaceWithConstantI1(1);
3198 return replaceWith(ICmpPredicate::ult, op.getLhs(),
getConstant(rhs + 1));
3199 case ICmpPredicate::uge:
3202 return replaceWithConstantI1(1);
3204 return replaceWith(ICmpPredicate::ugt, op.getLhs(),
getConstant(rhs - 1));
3205 case ICmpPredicate::eq:
3206 if (rhs.getBitWidth() == 1) {
3209 replaceOpWithNewOpAndCopyNamehint<XorOp>(rewriter, op, op.getLhs(),
3214 if (rhs.isAllOnes()) {
3221 case ICmpPredicate::ne:
3222 if (rhs.getBitWidth() == 1) {
3228 if (rhs.isAllOnes()) {
3230 replaceOpWithNewOpAndCopyNamehint<XorOp>(rewriter, op, op.getLhs(),
3237 case ICmpPredicate::ceq:
3238 case ICmpPredicate::cne:
3239 case ICmpPredicate::weq:
3240 case ICmpPredicate::wne:
3246 if (op.getPredicate() == ICmpPredicate::eq ||
3247 op.getPredicate() == ICmpPredicate::ne) {
3252 if (!knownBits.isUnknown())
3259 if (
auto xorOp = op.getLhs().getDefiningOp<
XorOp>())
3266 if (
auto replicateOp = op.getLhs().getDefiningOp<ReplicateOp>())
3267 if (rhs.isAllOnes() || rhs.isZero()) {
3268 auto width = replicateOp.getInput().getType().getIntOrFloatBitWidth();
3271 rhs.isAllOnes() ? APInt::getAllOnes(width)
3272 : APInt::getZero(width));
3273 replaceOpWithNewOpAndCopyNamehint<ICmpOp>(
3274 rewriter, op, op.getPredicate(), replicateOp.getInput(), cst,
3284 if (Operation *opLHS = op.getLhs().getDefiningOp())
3285 if (Operation *opRHS = op.getRhs().getDefiningOp())
3286 if (isa<ConcatOp, ReplicateOp>(opLHS) &&
3287 isa<ConcatOp, ReplicateOp>(opRHS)) {
assert(baseType &&"element must be base type")
static SmallVector< T > concat(const SmallVectorImpl< T > &a, const SmallVectorImpl< T > &b)
Returns a new vector containing the concatenation of vectors a and b.
static KnownBits computeKnownBits(Value v, unsigned depth)
Given an integer SSA value, check to see if we know anything about the result of the computation.
static bool foldMuxOfUniformArrays(MuxOp op, PatternRewriter &rewriter)
static Attribute constFoldAssociativeOp(ArrayRef< Attribute > operands, hw::PEO paramOpcode)
static Attribute constFoldBinaryOp(ArrayRef< Attribute > operands, hw::PEO paramOpcode)
Performs constant folding calculate with element-wise behavior on the two attributes in operands and ...
static bool applyCmpPredicateToEqualOperands(ICmpPredicate predicate)
static ComplementMatcher< SubType > m_Complement(const SubType &subExpr)
static bool canonicalizeLogicalCstWithConcat(Operation *logicalOp, size_t concatIdx, const APInt &cst, PatternRewriter &rewriter)
When we find a logical operation (and, or, xor) with a constant e.g.
static bool narrowOperationWidth(OpTy op, bool narrowTrailingBits, PatternRewriter &rewriter)
static OpFoldResult foldDiv(Op op, ArrayRef< Attribute > constants)
static Value getCommonOperand(Op op)
Returns a single common operand that all inputs of the operation op can be traced back to,...
static bool canCombineOppositeBinCmpIntoConstant(OperandRange operands)
static void getConcatOperands(Value v, SmallVectorImpl< Value > &result)
Flatten concat and replicate operands into a vector.
static Value extractOperandFromFullyAssociative(Operation *fullyAssoc, size_t operandNo, PatternRewriter &rewriter)
Given a fully associative variadic operation like (a+b+c+d), break the expression into two parts,...
static bool getMuxChainCondConstant(Value cond, Value indexValue, bool isInverted, std::function< void(hw::ConstantOp)> constantFn)
Check to see if the condition to the specified mux is an equality comparison indexValue and one or mo...
static TypedAttr getIntAttr(const APInt &value, MLIRContext *context)
static bool shouldBeFlattened(Operation *op)
Return true if the op will be flattened afterwards.
static void canonicalizeXorIcmpTrue(XorOp op, unsigned icmpOperand, PatternRewriter &rewriter)
static bool assumeMuxCondInOperand(Value muxCond, Value muxValue, bool constCond, PatternRewriter &rewriter)
If the mux condition is an operand to the op defining its true or false value, replace the condition ...
static bool extractFromReplicate(ExtractOp op, ReplicateOp replicate, PatternRewriter &rewriter)
static void combineEqualityICmpWithXorOfConstant(ICmpOp cmpOp, XorOp xorOp, const APInt &rhs, PatternRewriter &rewriter)
static size_t getTotalWidth(ArrayRef< Value > operands)
static bool foldCommonMuxOperation(MuxOp mux, Operation *trueOp, Operation *falseOp, PatternRewriter &rewriter)
This function is invoked when we find a mux with true/false operations that have the same opcode.
static std::pair< size_t, size_t > getLowestBitAndHighestBitRequired(Operation *op, bool narrowTrailingBits, size_t originalOpWidth)
static bool tryFlatteningOperands(Operation *op, PatternRewriter &rewriter)
Flattens a single input in op if hasOneUse is true and it can be defined as an Op.
static bool isOpTriviallyRecursive(Operation *op)
static bool canonicalizeIdempotentInputs(Op op, PatternRewriter &rewriter)
Canonicalize an idempotent operation op so that only one input of any kind occurs.
static bool applyCmpPredicate(ICmpPredicate predicate, const APInt &lhs, const APInt &rhs)
static void combineEqualityICmpWithKnownBitsAndConstant(ICmpOp cmpOp, const KnownBits &bitAnalysis, const APInt &rhsCst, PatternRewriter &rewriter)
Given an equality comparison with a constant value and some operand that has known bits,...
static bool hasSVAttributes(Operation *op)
static LogicalResult extractConcatToConcatExtract(ExtractOp op, ConcatOp innerCat, PatternRewriter &rewriter)
static OpFoldResult foldMod(Op op, ArrayRef< Attribute > constants)
static size_t computeCommonPrefixLength(const Range &a, const Range &b)
static bool foldCommonMuxValue(MuxOp op, bool isTrueOperand, PatternRewriter &rewriter)
Fold things like mux(cond, x|y|z|a, a) -> (x|y|z)&replicate(cond)|a and mux(cond, a,...
static LogicalResult matchAndRewriteCompareConcat(ICmpOp op, Operation *lhs, Operation *rhs, PatternRewriter &rewriter)
Reduce the strength icmp(concat(...), concat(...)) by doing a element-wise comparison on common prefi...
static Value createGenericOp(Location loc, OperationName name, ArrayRef< Value > operands, OpBuilder &builder)
Create a new instance of a generic operation that only has value operands, and has a single result va...
static TypedAttr getIntAttr(MLIRContext *ctx, Type t, const APInt &value)
static std::optional< APSInt > getConstant(Attribute operand)
Determine the value of a constant operand for the sake of constant folding.
create(array_value, low_index, ret_type)
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
void extractBits(OpBuilder &builder, Value val, SmallVectorImpl< Value > &bits)
Extract bits from a value.
bool foldMuxChainWithComparison(PatternRewriter &rewriter, MuxOp rootMux, bool isFalseSide, llvm::function_ref< MuxChainWithComparisonFoldingStyle(size_t indexWidth, size_t numEntries)> styleFn)
Mux chain folding that converts chains of muxes with index comparisons into array operations or balan...
Value createOrFoldNot(Location loc, Value value, OpBuilder &builder, bool twoState=false)
Create a 'Not' gate on a value.
MuxChainWithComparisonFoldingStyle
Enum for mux chain folding styles.
KnownBits computeKnownBits(Value value)
Compute "known bits" information about the specified value - the set of bits that are guaranteed to a...
Value constructMuxTree(OpBuilder &builder, Location loc, ArrayRef< Value > selectors, ArrayRef< Value > leafNodes, Value outOfBoundsValue)
Construct a mux tree for given leaf nodes.
uint64_t getWidth(Type t)
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
void replaceOpAndCopyNamehint(PatternRewriter &rewriter, Operation *op, Value newValue)
A wrapper of PatternRewriter::replaceOp to propagate "sv.namehint" attribute.