13#include "mlir/IR/Diagnostics.h"
14#include "mlir/IR/Matchers.h"
15#include "mlir/IR/PatternMatch.h"
16#include "llvm/ADT/SetVector.h"
17#include "llvm/ADT/SmallBitVector.h"
18#include "llvm/ADT/TypeSwitch.h"
19#include "llvm/Support/KnownBits.h"
24using namespace matchers;
28 return llvm::any_of(op->getOperands(), [op](
auto operand) {
29 return operand.getDefiningOp() == op;
39 ArrayRef<Value> operands, OpBuilder &builder) {
40 OperationState state(loc, name);
41 state.addOperands(operands);
42 state.addTypes(operands[0].getType());
43 return builder.create(state)->getResult(0);
/// Helper returning the result of IntegerAttr::get. The first argument of
/// that call is the IntegerType created in `context` with the same bit width
/// as `value`; the rest of the call lies outside this excerpt (presumably
/// `value` itself -- confirm).
46static TypedAttr
getIntAttr(
const APInt &value, MLIRContext *context) {
47 return IntegerAttr::get(IntegerType::get(context, value.getBitWidth()),
54 for (
auto op :
concat.getOperands())
56 }
else if (
auto repl = v.getDefiningOp<ReplicateOp>()) {
57 for (
size_t i = 0, e = repl.getMultiple(); i != e; ++i)
68 return op->hasAttr(
"sv.attributes");
/// Matcher for a bitwise complement. It succeeds on an operation that is an
/// XorOp for which isBinaryNot() holds and whose operand 0 is accepted by the
/// wrapped sub-matcher.
72template <
typename SubType>
73struct ComplementMatcher {
// The sub-matcher is taken by value and moved into the `lhs` member.
75 ComplementMatcher(SubType lhs) : lhs(std::move(lhs)) {}
// Returns true only when all three conditions hold; dyn_cast yields a null
// XorOp for any other operation kind, which short-circuits the conjunction.
76 bool match(Operation *op) {
77 auto xorOp = dyn_cast<XorOp>(op);
78 return xorOp && xorOp.isBinaryNot() && lhs.match(op->getOperand(0));
/// Factory for ComplementMatcher: builds a ComplementMatcher<SubType> from
/// `subExpr`, letting the compiler deduce SubType from the argument.
83template <
typename SubType>
84static inline ComplementMatcher<SubType>
m_Complement(
const SubType &subExpr) {
85 return ComplementMatcher<SubType>(subExpr);
91 assert((isa<AndOp, OrOp, XorOp, AddOp, MulOp>(op) &&
92 "must be commutative operations"));
93 if (op->hasOneUse()) {
94 auto *user = *op->getUsers().begin();
95 return user->getName() == op->getName() &&
96 op->getAttrOfType<UnitAttr>(
"twoState") ==
97 user->getAttrOfType<UnitAttr>(
"twoState") &&
98 op->getBlock() == user->getBlock();
113 auto inputs = op->getOperands();
115 SmallVector<Value, 4> newOperands;
116 SmallVector<Location, 4> newLocations{op->getLoc()};
117 newOperands.reserve(inputs.size());
119 decltype(inputs.begin()) current, end;
122 SmallVector<Element> worklist;
123 worklist.push_back({inputs.begin(), inputs.end()});
124 bool binFlag = op->hasAttrOfType<UnitAttr>(
"twoState");
125 bool changed =
false;
126 while (!worklist.empty()) {
127 auto &element = worklist.back();
130 if (element.current == element.end) {
135 Value value = *element.current++;
136 auto *flattenOp = value.getDefiningOp();
// An operand is kept as-is (not flattened) when it has no defining op, when
// that op is of a different kind, is `op` itself, lives in another block, or
// disagrees with `op` on the "twoState" attribute. The twoState comparison
// must query `flattenOp`: `binFlag` was already computed from `op`, so
// comparing it against `op` again could never be true.
139 if (!flattenOp || flattenOp->getName() != op->getName() ||
140 flattenOp == op || binFlag != flattenOp->hasAttrOfType<UnitAttr>(
"twoState") ||
141 flattenOp->getBlock() != op->getBlock()) {
142 newOperands.push_back(value);
147 if (!value.hasOneUse()) {
155 if (flattenOp->getNumOperands() != 2 || !isa<AndOp, OrOp, XorOp>(op) ||
158 newOperands.push_back(value);
166 auto flattenOpInputs = flattenOp->getOperands();
167 worklist.push_back({flattenOpInputs.begin(), flattenOpInputs.end()});
168 newLocations.push_back(flattenOp->getLoc());
174 Value result =
createGenericOp(FusedLoc::get(op->getContext(), newLocations),
175 op->getName(), newOperands, rewriter);
177 result.getDefiningOp()->setAttr(
"twoState", rewriter.getUnitAttr());
/// Computes the bit range of `op`'s result that its users need and returns it
/// as the pair {lowestBitRequired, highestBitRequired}. Part of the signature
/// and several statements lie outside this excerpt.
185static std::pair<size_t, size_t>
187 size_t originalOpWidth) {
188 auto users = op->getUsers();
190 "getLowestBitAndHighestBitRequired cannot operate on "
191 "a empty list of uses.");
// Initial bounds. The low bound starts at the top bit only when
// narrowTrailingBits is set; otherwise it starts at 0, and since it is only
// ever lowered with std::min below it then stays at 0.
195 size_t lowestBitRequired = narrowTrailingBits ? originalOpWidth - 1 : 0;
196 size_t highestBitRequired = 0;
// Widen the range by what each user reads.
198 for (
auto *user : users) {
199 if (
auto extractOp = dyn_cast<ExtractOp>(user)) {
// An ExtractOp user reads bits [lowBit, lowBit + result width - 1]; the
// expression below is that upper index (bound to `highBit` on a line outside
// this excerpt).
200 size_t lowBit = extractOp.getLowBit();
202 cast<IntegerType>(extractOp.getType()).getWidth() + lowBit - 1;
203 highestBitRequired = std::max(highestBitRequired, highBit);
204 lowestBitRequired = std::min(lowestBitRequired, lowBit);
// Fallback: require the whole value, bits [0, originalOpWidth - 1]. The
// condition guarding this is not visible here (presumably a user that is not
// an ExtractOp -- confirm).
208 highestBitRequired = originalOpWidth - 1;
209 lowestBitRequired = 0;
213 return {lowestBitRequired, highestBitRequired};
218 PatternRewriter &rewriter) {
219 IntegerType opType = dyn_cast<IntegerType>(op.getResult().getType());
225 if (range.second + 1 == opType.getWidth() && range.first == 0)
228 SmallVector<Value> args;
229 auto newType = rewriter.getIntegerType(range.second - range.first + 1);
230 for (
auto inop : op.getOperands()) {
232 if (inop.getType() != op.getType())
233 args.push_back(inop);
235 args.push_back(rewriter.createOrFold<
ExtractOp>(inop.getLoc(), newType,
238 auto newop = OpTy::create(rewriter, op.getLoc(), newType, args);
239 newop->setDialectAttrs(op->getDialectAttrs());
240 if (op.getTwoState())
241 newop.setTwoState(
true);
243 Value newResult = newop.getResult();
245 newResult = rewriter.createOrFold<
ConcatOp>(
246 op.getLoc(), newResult,
248 APInt::getZero(range.first)));
249 if (range.second + 1 < opType.getWidth())
250 newResult = rewriter.createOrFold<
ConcatOp>(
253 rewriter, op.getLoc(),
254 APInt::getZero(opType.getWidth() - range.second - 1)),
256 rewriter.replaceOp(op, newResult);
/// Folder for ReplicateOp. The visible cases are: result width equal to input
/// width, a 1-bit constant input, and a general constant input. Several
/// return statements lie outside this excerpt.
264OpFoldResult ReplicateOp::fold(FoldAdaptor adaptor) {
// Result and input have the same width (the action taken is outside this
// excerpt; presumably the input is returned unchanged -- confirm).
269 if (cast<IntegerType>(getType()).
getWidth() ==
270 getInput().getType().getIntOrFloatBitWidth())
// The input is a known integer constant.
274 if (
auto input = dyn_cast_or_null<IntegerAttr>(adaptor.getInput())) {
// 1-bit constant: the visible arguments build an all-zeros value (when the
// bit is zero) or an all-ones value of the result width.
275 if (input.getValue().getBitWidth() == 1) {
276 if (input.getValue().isZero())
278 APInt::getZero(cast<IntegerType>(getType()).
getWidth()),
281 APInt::getAllOnes(cast<IntegerType>(getType()).
getWidth()),
// General constant: start from a zero-width APInt and concatenate the input
// value getMultiple() times.
285 APInt result = APInt::getZeroWidth();
286 for (
auto i = getMultiple(); i != 0; --i)
287 result = result.concat(input.getValue());
/// Folder for ParityOp. When the input is a known IntegerAttr, the result is
/// a 1-bit constant holding the low bit of the input's population count.
294OpFoldResult ParityOp::fold(FoldAdaptor adaptor) {
299 if (
auto input = dyn_cast_or_null<IntegerAttr>(adaptor.getInput()))
300 return getIntAttr(APInt(1, input.getValue().popcount() & 1), getContext());
312 hw::PEO paramOpcode) {
313 assert(operands.size() == 2 &&
"binary op takes two operands");
314 if (!operands[0] || !operands[1])
319 return hw::ParamExprAttr::get(paramOpcode, cast<TypedAttr>(operands[0]),
320 cast<TypedAttr>(operands[1]));
/// Folder for ShlOp. The visible cases apply when the shift amount (rhs) is a
/// known IntegerAttr.
323OpFoldResult ShlOp::fold(FoldAdaptor adaptor) {
327 if (
auto rhs = dyn_cast_or_null<IntegerAttr>(adaptor.getRhs())) {
// Shifting by zero returns operand 0 unchanged.
328 if (rhs.getValue().isZero())
329 return getOperand(0);
// A shift amount of at least the result width (unsigned compare) folds to an
// all-zero constant of that width.
331 unsigned width = getType().getIntOrFloatBitWidth();
332 if (rhs.getValue().uge(width))
333 return getIntAttr(APInt::getZero(width), getContext());
338LogicalResult ShlOp::canonicalize(
ShlOp op, PatternRewriter &rewriter) {
344 if (!matchPattern(op.getRhs(), m_ConstantInt(&value)))
347 unsigned width = cast<IntegerType>(op.getLhs().getType()).getWidth();
348 if (value.ugt(width))
350 unsigned shift = value.getZExtValue();
353 if (width <= shift || shift == 0)
363 replaceOpWithNewOpAndCopyNamehint<ConcatOp>(rewriter, op, extract, zeros);
/// Folder for ShrUOp. The visible cases apply when the shift amount (rhs) is
/// a known IntegerAttr.
367OpFoldResult ShrUOp::fold(FoldAdaptor adaptor) {
371 if (
auto rhs = dyn_cast_or_null<IntegerAttr>(adaptor.getRhs())) {
// Shifting by zero returns operand 0 unchanged.
372 if (rhs.getValue().isZero())
373 return getOperand(0);
// A shift amount of at least the result width (unsigned compare) folds to an
// all-zero constant of that width.
375 unsigned width = getType().getIntOrFloatBitWidth();
376 if (rhs.getValue().uge(width))
377 return getIntAttr(APInt::getZero(width), getContext());
382LogicalResult ShrUOp::canonicalize(
ShrUOp op, PatternRewriter &rewriter) {
388 if (!matchPattern(op.getRhs(), m_ConstantInt(&value)))
391 unsigned width = cast<IntegerType>(op.getLhs().getType()).getWidth();
392 if (value.ugt(width))
394 unsigned shift = value.getZExtValue();
397 if (width <= shift || shift == 0)
407 replaceOpWithNewOpAndCopyNamehint<ConcatOp>(rewriter, op, zeros, extract);
/// Folder for ShrSOp. The only case visible in this excerpt: a known constant
/// shift amount of zero returns operand 0 unchanged.
411OpFoldResult ShrSOp::fold(FoldAdaptor adaptor) {
415 if (
auto rhs = dyn_cast_or_null<IntegerAttr>(adaptor.getRhs()))
416 if (rhs.getValue().isZero())
417 return getOperand(0);
421LogicalResult ShrSOp::canonicalize(
ShrSOp op, PatternRewriter &rewriter) {
427 if (!matchPattern(op.getRhs(), m_ConstantInt(&value)))
430 unsigned width = cast<IntegerType>(op.getLhs().getType()).getWidth();
431 if (value.ugt(width))
433 unsigned shift = value.getZExtValue();
436 rewriter.createOrFold<
ExtractOp>(op.getLoc(), op.getLhs(), width - 1, 1);
437 auto sext = rewriter.createOrFold<ReplicateOp>(op.getLoc(), topbit, shift);
439 if (width == shift) {
447 replaceOpWithNewOpAndCopyNamehint<ConcatOp>(rewriter, op, sext, extract);
/// Folder for ExtractOp. Visible cases: input type identical to the result
/// type, and a constant input.
455OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
// Input and result types match (the returned value is outside this excerpt;
// presumably the input itself -- confirm).
460 if (getInput().getType() == getType())
// Constant input: logically shift right by the low-bit index, then truncate
// to the result width.
464 if (
auto input = dyn_cast_or_null<IntegerAttr>(adaptor.getInput())) {
465 unsigned dstWidth = cast<IntegerType>(getType()).getWidth();
466 return getIntAttr(input.getValue().lshr(getLowBit()).trunc(dstWidth),
477 PatternRewriter &rewriter) {
478 auto reversedConcatArgs = llvm::reverse(innerCat.getInputs());
479 size_t beginOfFirstRelevantElement = 0;
480 auto it = reversedConcatArgs.begin();
481 size_t lowBit = op.getLowBit();
484 for (; it != reversedConcatArgs.end(); it++) {
485 assert(beginOfFirstRelevantElement <= lowBit &&
486 "incorrectly moved past an element that lowBit has coverage over");
489 size_t operandWidth = operand.getType().getIntOrFloatBitWidth();
490 if (lowBit < beginOfFirstRelevantElement + operandWidth) {
514 beginOfFirstRelevantElement += operandWidth;
516 assert(it != reversedConcatArgs.end() &&
517 "incorrectly failed to find an element which contains coverage of "
520 SmallVector<Value> reverseConcatArgs;
521 size_t widthRemaining = cast<IntegerType>(op.getType()).getWidth();
522 size_t extractLo = lowBit - beginOfFirstRelevantElement;
527 for (; widthRemaining != 0 && it != reversedConcatArgs.end(); it++) {
528 auto concatArg = *it;
529 size_t operandWidth = concatArg.getType().getIntOrFloatBitWidth();
530 size_t widthToConsume = std::min(widthRemaining, operandWidth - extractLo);
532 if (widthToConsume == operandWidth && extractLo == 0) {
533 reverseConcatArgs.push_back(concatArg);
535 auto resultType = IntegerType::get(rewriter.getContext(), widthToConsume);
536 reverseConcatArgs.push_back(
540 widthRemaining -= widthToConsume;
546 if (reverseConcatArgs.size() == 1) {
549 replaceOpWithNewOpAndCopyNamehint<ConcatOp>(
550 rewriter, op, SmallVector<Value>(llvm::reverse(reverseConcatArgs)));
557 PatternRewriter &rewriter) {
558 auto extractResultWidth = cast<IntegerType>(op.getType()).getWidth();
559 auto replicateEltWidth =
560 replicate.getOperand().getType().getIntOrFloatBitWidth();
564 if (op.getLowBit() % replicateEltWidth == 0 &&
565 extractResultWidth % replicateEltWidth == 0) {
566 replaceOpWithNewOpAndCopyNamehint<ReplicateOp>(rewriter, op, op.getType(),
567 replicate.getOperand());
573 if (op.getLowBit() % replicateEltWidth + extractResultWidth <=
575 replaceOpWithNewOpAndCopyNamehint<ExtractOp>(
576 rewriter, op, op.getType(), replicate.getOperand(),
577 op.getLowBit() % replicateEltWidth);
586LogicalResult ExtractOp::canonicalize(
ExtractOp op, PatternRewriter &rewriter) {
589 auto *inputOp = op.getInput().getDefiningOp();
596 .extractBits(cast<IntegerType>(op.getType()).getWidth(),
598 if (knownBits.isConstant()) {
599 replaceOpWithNewOpAndCopyNamehint<hw::ConstantOp>(rewriter, op,
600 knownBits.getConstant());
606 if (
auto innerExtract = dyn_cast_or_null<ExtractOp>(inputOp)) {
607 replaceOpWithNewOpAndCopyNamehint<ExtractOp>(
608 rewriter, op, op.getType(), innerExtract.getInput(),
609 innerExtract.getLowBit() + op.getLowBit());
614 if (
auto innerCat = dyn_cast_or_null<ConcatOp>(inputOp))
618 if (
auto replicate = dyn_cast_or_null<ReplicateOp>(inputOp))
624 if (inputOp && inputOp->getNumOperands() == 2 &&
625 isa<AndOp, OrOp, XorOp>(inputOp)) {
626 if (
auto cstRHS = inputOp->getOperand(1).getDefiningOp<
hw::ConstantOp>()) {
627 auto extractedCst = cstRHS.getValue().extractBits(
628 cast<IntegerType>(op.getType()).getWidth(), op.getLowBit());
629 if (isa<OrOp, XorOp>(inputOp) && extractedCst.isZero()) {
630 replaceOpWithNewOpAndCopyNamehint<ExtractOp>(
631 rewriter, op, op.getType(), inputOp->getOperand(0), op.getLowBit());
639 if (isa<AndOp>(inputOp)) {
642 unsigned lz = extractedCst.countLeadingZeros();
643 unsigned tz = extractedCst.countTrailingZeros();
644 unsigned pop = extractedCst.popcount();
645 if (extractedCst.getBitWidth() - lz - tz == pop) {
646 auto resultTy = rewriter.getIntegerType(pop);
647 SmallVector<Value> resultElts;
650 APInt::getZero(lz)));
651 resultElts.push_back(rewriter.createOrFold<
ExtractOp>(
652 op.getLoc(), resultTy, inputOp->getOperand(0),
653 op.getLowBit() + tz));
656 APInt::getZero(tz)));
657 replaceOpWithNewOpAndCopyNamehint<ConcatOp>(rewriter, op, resultElts);
666 if (cast<IntegerType>(op.getType()).getWidth() == 1 && inputOp)
667 if (
auto shlOp = dyn_cast<ShlOp>(inputOp)) {
669 if (shlOp->hasOneUse())
671 if (lhsCst.getValue().isOne()) {
673 rewriter, shlOp.getLoc(),
674 APInt(lhsCst.getValue().getBitWidth(), op.getLowBit()));
675 replaceOpWithNewOpAndCopyNamehint<ICmpOp>(
676 rewriter, op, ICmpPredicate::eq, shlOp->getOperand(1), newCst,
692 hw::PEO paramOpcode) {
693 assert(operands.size() > 1 &&
"caller should handle one-operand case");
696 if (!operands[1] || !operands[0])
700 if (llvm::all_of(operands.drop_front(2),
701 [&](Attribute in) { return !!in; })) {
702 SmallVector<mlir::TypedAttr> typedOperands;
703 typedOperands.reserve(operands.size());
704 for (
auto operand : operands) {
705 if (
auto typedOperand = dyn_cast<mlir::TypedAttr>(operand))
706 typedOperands.push_back(typedOperand);
710 if (typedOperands.size() == operands.size())
711 return hw::ParamExprAttr::get(paramOpcode, typedOperands);
727 size_t concatIdx,
const APInt &cst,
728 PatternRewriter &rewriter) {
729 auto concatOp = logicalOp->getOperand(concatIdx).getDefiningOp<
ConcatOp>();
730 assert((isa<AndOp, OrOp, XorOp>(logicalOp) && concatOp));
735 llvm::any_of(concatOp->getOperands(), [&](Value operand) ->
bool {
736 auto *operandOp = operand.getDefiningOp();
741 if (isa<hw::ConstantOp>(operandOp))
745 return operandOp->getName() == logicalOp->getName() &&
746 operandOp->hasOneUse() && operandOp->getNumOperands() != 0 &&
747 operandOp->getOperands().back().getDefiningOp<hw::ConstantOp>();
755 auto createLogicalOp = [&](ArrayRef<Value> operands) -> Value {
756 return createGenericOp(logicalOp->getLoc(), logicalOp->getName(), operands,
763 SmallVector<Value> newConcatOperands;
764 newConcatOperands.reserve(concatOp->getNumOperands());
767 size_t nextOperandBit = concatOp.getType().getIntOrFloatBitWidth();
768 for (Value operand : concatOp->getOperands()) {
769 size_t operandWidth = operand.getType().getIntOrFloatBitWidth();
770 nextOperandBit -= operandWidth;
774 cst.lshr(nextOperandBit).trunc(operandWidth));
776 newConcatOperands.push_back(createLogicalOp({operand, eltCst}));
781 ConcatOp::create(rewriter, concatOp.getLoc(), newConcatOperands);
785 if (logicalOp->getNumOperands() > 2) {
786 auto origOperands = logicalOp->getOperands();
787 SmallVector<Value> operands;
789 operands.append(origOperands.begin(), origOperands.begin() + concatIdx);
791 operands.append(origOperands.begin() + concatIdx + 1,
792 origOperands.begin() + (origOperands.size() - 1));
794 operands.push_back(newResult);
795 newResult = createLogicalOp(operands);
805 llvm::SmallDenseSet<std::tuple<ICmpPredicate, Value, Value>> seenPredicates;
807 for (
auto op : operands) {
808 if (
auto icmpOp = op.getDefiningOp<ICmpOp>();
809 icmpOp && icmpOp.getTwoState()) {
810 auto predicate = icmpOp.getPredicate();
811 auto lhs = icmpOp.getLhs();
812 auto rhs = icmpOp.getRhs();
813 if (seenPredicates.contains(
814 {ICmpOp::getNegatedPredicate(predicate), lhs, rhs}))
817 seenPredicates.insert({predicate, lhs, rhs});
/// Folder for AndOp. Visible steps: AND together the constant operands,
/// drop an all-ones second operand of a binary op, collapse identical
/// operands, and scan operand pairs (the pair test itself is outside this
/// excerpt).
823OpFoldResult AndOp::fold(FoldAdaptor adaptor) {
// Accumulator starts at all ones, the identity of bitwise AND.
827 APInt value = APInt::getAllOnes(cast<IntegerType>(getType()).
getWidth());
829 auto inputs = adaptor.getInputs();
// Fold each operand's constant value into the accumulator. Lines between the
// loop header and this statement are not visible here.
// NOTE(review): cast<> assumes the attribute is an IntegerAttr -- confirm
// that other attribute kinds are filtered out before this point.
832 for (
auto operand : inputs) {
835 value &= cast<IntegerAttr>(operand).getValue();
// Binary op whose second operand is a constant with all bits set: the result
// is the first input.
841 if (inputs.size() == 2 && inputs[1] &&
842 cast<IntegerAttr>(inputs[1]).getValue().isAllOnes())
843 return getInputs()[0];
// Every operand equals the first one: the result is that operand.
846 if (llvm::all_of(getInputs(),
847 [&](
auto in) {
return in == this->getInputs()[0]; }))
848 return getInputs()[0];
// Pairwise scan over the operands; the visible argument below is an all-zero
// value of the result width.
851 for (Value arg : getInputs()) {
854 for (Value arg2 : getInputs())
857 APInt::getZero(cast<IntegerType>(getType()).
getWidth()),
878template <
typename Op>
880 if (!op.getType().isInteger(1))
883 auto inputs = op.getInputs();
884 size_t size = inputs.size();
886 auto sourceOp = inputs[0].template getDefiningOp<ExtractOp>();
889 Value source = sourceOp.getOperand();
892 if (size != source.getType().getIntOrFloatBitWidth())
896 llvm::BitVector bits(size);
897 bits.set(sourceOp.getLowBit());
899 for (
size_t i = 1; i != size; ++i) {
900 auto extractOp = inputs[i].template getDefiningOp<ExtractOp>();
901 if (!extractOp || extractOp.getOperand() != source)
903 bits.set(extractOp.getLowBit());
906 return bits.all() ? source : Value();
913template <
typename Op>
916 constexpr unsigned limit = 3;
917 auto inputs = op.getInputs();
919 llvm::SmallSetVector<Value, 8> uniqueInputs(inputs.begin(), inputs.end());
920 llvm::SmallDenseSet<Op, 8> checked;
927 llvm::SmallVector<OpWithDepth, 8> worklist;
929 auto enqueue = [&worklist, &checked, &op](Value input,
unsigned depth) {
933 if (depth < limit && input.getParentBlock() == op->getBlock()) {
934 auto inputOp = input.template getDefiningOp<Op>();
935 if (inputOp && inputOp.getTwoState() == op.getTwoState() &&
936 checked.insert(inputOp).second)
937 worklist.push_back({inputOp, depth + 1});
941 for (
auto input : uniqueInputs)
944 while (!worklist.empty()) {
945 auto item = worklist.pop_back_val();
947 for (
auto input : item.op.getInputs()) {
948 uniqueInputs.remove(input);
949 enqueue(input, item.depth);
953 if (uniqueInputs.size() < inputs.size()) {
954 replaceOpWithNewOpAndCopyNamehint<Op>(rewriter, op, op.getType(),
955 uniqueInputs.getArrayRef(),
963LogicalResult AndOp::canonicalize(
AndOp op, PatternRewriter &rewriter) {
967 auto inputs = op.getInputs();
968 auto size = inputs.size();
980 assert(size > 1 &&
"expected 2 or more operands, `fold` should handle this");
984 if (matchPattern(inputs.back(), m_ConstantInt(&value))) {
986 if (value.isAllOnes()) {
987 replaceOpWithNewOpAndCopyNamehint<AndOp>(rewriter, op, op.getType(),
988 inputs.drop_back(),
false);
996 if (matchPattern(inputs[size - 2], m_ConstantInt(&value2))) {
998 SmallVector<Value, 4> newOperands(inputs.drop_back(2));
999 newOperands.push_back(cst);
1000 replaceOpWithNewOpAndCopyNamehint<AndOp>(rewriter, op, op.getType(),
1001 newOperands,
false);
1006 if (size == 2 && value.isPowerOf2()) {
1011 if (
auto replicate = inputs[0].getDefiningOp<ReplicateOp>()) {
1012 auto replicateOperand = replicate.getOperand();
1013 if (replicateOperand.getType().isInteger(1)) {
1014 unsigned resultWidth = op.getType().getIntOrFloatBitWidth();
1015 auto trailingZeros = value.countTrailingZeros();
1018 SmallVector<Value, 3> concatOperands;
1019 if (trailingZeros != resultWidth - 1) {
1021 rewriter, op.getLoc(),
1022 APInt::getZero(resultWidth - trailingZeros - 1));
1023 concatOperands.push_back(highZeros);
1025 concatOperands.push_back(replicateOperand);
1026 if (trailingZeros != 0) {
1028 rewriter, op.getLoc(), APInt::getZero(trailingZeros));
1029 concatOperands.push_back(lowZeros);
1031 replaceOpWithNewOpAndCopyNamehint<ConcatOp>(
1032 rewriter, op, op.getType(), concatOperands);
1041 unsigned leadingZeros = value.countLeadingZeros();
1042 unsigned trailingZeros = value.countTrailingZeros();
1043 if (leadingZeros > 0 || trailingZeros > 0) {
1044 unsigned maskLength = value.getBitWidth() - leadingZeros - trailingZeros;
1047 SmallVector<Value> operands;
1048 for (
auto input : inputs.drop_back()) {
1049 unsigned offset = trailingZeros;
1050 while (
auto extractOp = input.getDefiningOp<
ExtractOp>()) {
1051 input = extractOp.getInput();
1052 offset += extractOp.getLowBit();
1055 offset, maskLength));
1059 auto narrowMask = value.extractBits(maskLength, trailingZeros);
1060 if (!narrowMask.isAllOnes())
1062 rewriter, inputs.back().getLoc(), narrowMask));
1065 Value narrowValue = operands.back();
1066 if (operands.size() > 1)
1068 AndOp::create(rewriter, op.getLoc(), operands, op.getTwoState());
1072 if (leadingZeros > 0)
1074 rewriter, op.getLoc(), APInt::getZero(leadingZeros)));
1075 operands.push_back(narrowValue);
1076 if (trailingZeros > 0)
1078 rewriter, op.getLoc(), APInt::getZero(trailingZeros)));
1079 replaceOpWithNewOpAndCopyNamehint<ConcatOp>(rewriter, op, operands);
1086 for (
size_t i = 0; i < size - 1; ++i) {
1087 if (
auto concat = inputs[i].getDefiningOp<ConcatOp>())
1101 replaceOpWithNewOpAndCopyNamehint<ICmpOp>(rewriter, op, ICmpPredicate::eq,
1102 source, cmpAgainst);
/// Folder for OrOp. Visible steps: OR together the constant operands, drop a
/// zero second operand of a binary op, collapse identical operands, and
/// detect an operand that appears together with its complement.
1110OpFoldResult OrOp::fold(FoldAdaptor adaptor) {
// Accumulator starts at zero, the identity of bitwise OR.
1114 auto value = APInt::getZero(cast<IntegerType>(getType()).
getWidth());
1115 auto inputs = adaptor.getInputs();
// Fold each operand's constant value into the accumulator, checking after
// each step whether every bit is already set (the action taken then is
// outside this excerpt).
1117 for (
auto operand : inputs) {
1120 value |= cast<IntegerAttr>(operand).getValue();
1121 if (value.isAllOnes())
// Binary op whose second operand is the constant zero: the result is the
// first input.
1126 if (inputs.size() == 2 && inputs[1] &&
1127 cast<IntegerAttr>(inputs[1]).getValue().isZero())
1128 return getInputs()[0];
// Every operand equals the first one: the result is that operand.
1131 if (llvm::all_of(getInputs(),
1132 [&](
auto in) {
return in == this->getInputs()[0]; }))
1133 return getInputs()[0];
// Look for an operand that is the complement of some `subExpr` where
// `subExpr` is itself also an operand; the visible argument for that case is
// an all-ones value of the result width.
1136 for (Value arg : getInputs()) {
1138 if (matchPattern(arg,
m_Complement(m_Any(&subExpr)))) {
1139 for (Value arg2 : getInputs())
1140 if (arg2 == subExpr)
1142 APInt::getAllOnes(cast<IntegerType>(getType()).
getWidth()),
1152 APInt::getAllOnes(cast<IntegerType>(getType()).
getWidth()),
1159LogicalResult OrOp::canonicalize(
OrOp op, PatternRewriter &rewriter) {
1163 auto inputs = op.getInputs();
1164 auto size = inputs.size();
1176 assert(size > 1 &&
"expected 2 or more operands");
1180 if (matchPattern(inputs.back(), m_ConstantInt(&value))) {
1182 if (value.isZero()) {
1183 replaceOpWithNewOpAndCopyNamehint<OrOp>(rewriter, op, op.getType(),
1184 inputs.drop_back());
1190 if (matchPattern(inputs[size - 2], m_ConstantInt(&value2))) {
1192 SmallVector<Value, 4> newOperands(inputs.drop_back(2));
1193 newOperands.push_back(cst);
1194 replaceOpWithNewOpAndCopyNamehint<OrOp>(rewriter, op, op.getType(),
1202 for (
size_t i = 0; i < size - 1; ++i) {
1203 if (
auto concat = inputs[i].getDefiningOp<ConcatOp>())
1217 replaceOpWithNewOpAndCopyNamehint<ICmpOp>(rewriter, op, ICmpPredicate::ne,
1218 source, cmpAgainst);
1224 if (
auto firstMux = op.getOperand(0).getDefiningOp<
comb::MuxOp>()) {
1226 if (op.getTwoState() && firstMux.getTwoState() &&
1227 matchPattern(firstMux.getFalseValue(), m_ConstantInt(&value)) &&
1229 SmallVector<Value> conditions{firstMux.getCond()};
1230 auto check = [&](Value v) {
1234 conditions.push_back(mux.getCond());
1235 return mux.getTwoState() &&
1236 firstMux.getTrueValue() == mux.getTrueValue() &&
1237 firstMux.getFalseValue() == mux.getFalseValue();
1239 if (llvm::all_of(op.getOperands().drop_front(), check)) {
1240 auto cond = comb::OrOp::create(rewriter, op.getLoc(), conditions,
true);
1241 replaceOpWithNewOpAndCopyNamehint<comb::MuxOp>(
1242 rewriter, op, cond, firstMux.getTrueValue(),
1243 firstMux.getFalseValue(),
true);
/// Folder for XorOp. Visible cases: two identical operands, a zero second
/// operand of a binary op, and a complement of a complement.
1253OpFoldResult XorOp::fold(FoldAdaptor adaptor) {
1257 auto size = getInputs().size();
1258 auto inputs = adaptor.getInputs();
// Returns the first input; the guarding condition is outside this excerpt
// (presumably the single-operand case -- confirm).
1262 return getInputs()[0];
// Exactly two operands that are the same value: the result is the integer
// constant 0 of the result type.
1265 if (size == 2 && getInputs()[0] == getInputs()[1])
1266 return IntegerAttr::get(getType(), 0);
// Binary op whose second operand is the constant zero: the result is the
// first input.
1269 if (inputs.size() == 2 && inputs[1] &&
1270 cast<IntegerAttr>(inputs[1]).getValue().isZero())
1271 return getInputs()[0];
// This op is itself a complement and operand 0 is the complement of
// `subExpr`. The extra check requires `subExpr` to differ from this op's own
// result; the value returned is outside this excerpt.
1275 if (isBinaryNot()) {
1277 if (matchPattern(getOperand(0),
m_Complement(m_Any(&subExpr))) &&
1278 subExpr != getResult())
1288 PatternRewriter &rewriter) {
1289 auto icmp = op.getOperand(icmpOperand).getDefiningOp<ICmpOp>();
1290 auto negatedPred = ICmpOp::getNegatedPredicate(icmp.getPredicate());
1293 ICmpOp::create(rewriter, icmp.getLoc(), negatedPred, icmp.getOperand(0),
1294 icmp.getOperand(1), icmp.getTwoState());
1297 if (op.getNumOperands() > 2) {
1298 SmallVector<Value, 4> newOperands(op.getOperands());
1299 newOperands.pop_back();
1300 newOperands.erase(newOperands.begin() + icmpOperand);
1301 newOperands.push_back(result);
1303 XorOp::create(rewriter, op.getLoc(), newOperands, op.getTwoState());
1309LogicalResult XorOp::canonicalize(
XorOp op, PatternRewriter &rewriter) {
1313 auto inputs = op.getInputs();
1314 auto size = inputs.size();
1315 assert(size > 1 &&
"expected 2 or more operands");
1318 if (inputs[size - 1] == inputs[size - 2]) {
1320 "expected idempotent case for 2 elements handled already.");
1321 replaceOpWithNewOpAndCopyNamehint<XorOp>(rewriter, op, op.getType(),
1322 inputs.drop_back(2),
false);
1328 if (matchPattern(inputs.back(), m_ConstantInt(&value))) {
1330 if (value.isZero()) {
1331 replaceOpWithNewOpAndCopyNamehint<XorOp>(rewriter, op, op.getType(),
1332 inputs.drop_back(),
false);
1338 if (matchPattern(inputs[size - 2], m_ConstantInt(&value2))) {
1340 SmallVector<Value, 4> newOperands(inputs.drop_back(2));
1341 newOperands.push_back(cst);
1342 replaceOpWithNewOpAndCopyNamehint<XorOp>(rewriter, op, op.getType(),
1343 newOperands,
false);
1347 bool isSingleBit = value.getBitWidth() == 1;
1350 for (
size_t i = 0; i < size - 1; ++i) {
1351 Value operand = inputs[i];
1362 if (isSingleBit && operand.hasOneUse()) {
1363 assert(value == 1 &&
"single bit constant has to be one if not zero");
1364 if (
auto icmp = operand.getDefiningOp<ICmpOp>())
1380 replaceOpWithNewOpAndCopyNamehint<ParityOp>(rewriter, op, source);
/// Folder for SubOp. Visible cases: identical operands, both operands
/// constant, and a constant integer rhs equal to zero.
1387OpFoldResult SubOp::fold(FoldAdaptor adaptor) {
// lhs and rhs are the same value: the visible argument is an all-zero value
// of the operand width.
1392 if (getRhs() == getLhs())
1394 APInt::getZero(getLhs().getType().getIntOrFloatBitWidth()),
1397 if (adaptor.getRhs()) {
// Both operands are known attributes: build the parameter expression
// Add(lhs, Mul(rhs, negOne)). `negOne` is declared on lines partly outside
// this excerpt; an all-ones value of the operand width is visible among its
// arguments.
1399 if (adaptor.getLhs()) {
1402 APInt::getAllOnes(getLhs().getType().getIntOrFloatBitWidth()),
1404 auto rhsNeg = hw::ParamExprAttr::get(
1405 hw::PEO::Mul, cast<TypedAttr>(adaptor.getRhs()), negOne);
1406 return hw::ParamExprAttr::get(hw::PEO::Add,
1407 cast<TypedAttr>(adaptor.getLhs()), rhsNeg);
// rhs is a plain integer constant equal to zero (the returned value is
// outside this excerpt; presumably the lhs -- confirm).
1411 if (
auto rhsC = dyn_cast<IntegerAttr>(adaptor.getRhs())) {
1412 if (rhsC.getValue().isZero())
1420LogicalResult SubOp::canonicalize(
SubOp op, PatternRewriter &rewriter) {
1426 if (matchPattern(op.getRhs(), m_ConstantInt(&value))) {
1428 replaceOpWithNewOpAndCopyNamehint<AddOp>(rewriter, op, op.getLhs(), negCst,
1440OpFoldResult AddOp::fold(FoldAdaptor adaptor) {
1444 auto size = getInputs().size();
1448 return getInputs()[0];
1454LogicalResult AddOp::canonicalize(
AddOp op, PatternRewriter &rewriter) {
1458 auto inputs = op.getInputs();
1459 auto size = inputs.size();
1460 assert(size > 1 &&
"expected 2 or more operands");
1462 APInt value, value2;
1465 if (matchPattern(inputs.back(), m_ConstantInt(&value)) && value.isZero()) {
1466 replaceOpWithNewOpAndCopyNamehint<AddOp>(rewriter, op, op.getType(),
1467 inputs.drop_back(),
false);
1472 if (matchPattern(inputs[size - 1], m_ConstantInt(&value)) &&
1473 matchPattern(inputs[size - 2], m_ConstantInt(&value2))) {
1475 SmallVector<Value, 4> newOperands(inputs.drop_back(2));
1476 newOperands.push_back(cst);
1477 replaceOpWithNewOpAndCopyNamehint<AddOp>(rewriter, op, op.getType(),
1478 newOperands,
false);
1483 if (inputs[size - 1] == inputs[size - 2]) {
1484 SmallVector<Value, 4> newOperands(inputs.drop_back(2));
1488 comb::ShlOp::create(rewriter, op.getLoc(), inputs.back(), one,
false);
1490 newOperands.push_back(shiftLeftOp);
1491 replaceOpWithNewOpAndCopyNamehint<AddOp>(rewriter, op, op.getType(),
1492 newOperands,
false);
1496 auto shlOp = inputs[size - 1].getDefiningOp<
comb::ShlOp>();
1498 if (shlOp && shlOp.getLhs() == inputs[size - 2] &&
1499 matchPattern(shlOp.getRhs(), m_ConstantInt(&value))) {
1501 APInt one(value.getBitWidth(), 1,
false);
1505 std::array<Value, 2> factors = {shlOp.getLhs(), rhs};
1506 auto mulOp = comb::MulOp::create(rewriter, op.getLoc(), factors,
false);
1508 SmallVector<Value, 4> newOperands(inputs.drop_back(2));
1509 newOperands.push_back(mulOp);
1510 replaceOpWithNewOpAndCopyNamehint<AddOp>(rewriter, op, op.getType(),
1511 newOperands,
false);
1515 auto mulOp = inputs[size - 1].getDefiningOp<
comb::MulOp>();
1517 if (mulOp && mulOp.getInputs().size() == 2 &&
1518 mulOp.getInputs()[0] == inputs[size - 2] &&
1519 matchPattern(mulOp.getInputs()[1], m_ConstantInt(&value))) {
1521 APInt one(value.getBitWidth(), 1,
false);
1523 std::array<Value, 2> factors = {mulOp.getInputs()[0], rhs};
1524 auto newMulOp = comb::MulOp::create(rewriter, op.getLoc(), factors,
false);
1526 SmallVector<Value, 4> newOperands(inputs.drop_back(2));
1527 newOperands.push_back(newMulOp);
1528 replaceOpWithNewOpAndCopyNamehint<AddOp>(rewriter, op, op.getType(),
1529 newOperands,
false);
1542 auto addOp = inputs[0].getDefiningOp<
comb::AddOp>();
1543 if (addOp && addOp.getInputs().size() == 2 &&
1544 matchPattern(addOp.getInputs()[1], m_ConstantInt(&value2)) &&
1545 inputs.size() == 2 && matchPattern(inputs[1], m_ConstantInt(&value))) {
1548 replaceOpWithNewOpAndCopyNamehint<AddOp>(
1549 rewriter, op, op.getType(), ArrayRef<Value>{addOp.getInputs()[0], rhs},
1550 op.getTwoState() && addOp.getTwoState());
/// Folder for MulOp. Visible steps: two early returns whose guards are
/// outside this excerpt, followed by multiplying the constant operands
/// together.
1557OpFoldResult MulOp::fold(FoldAdaptor adaptor) {
1561 auto size = getInputs().size();
1562 auto inputs = adaptor.getInputs();
// Returns the first input (guard not visible; presumably the single-operand
// case -- confirm).
1566 return getInputs()[0];
1568 auto width = cast<IntegerType>(getType()).getWidth();
// Returns a zero-width zero constant (guard not visible; presumably a
// zero-width result type -- confirm).
1570 return getIntAttr(APInt::getZero(0), getContext());
// Accumulator starts at 1, the identity of multiplication.
1572 APInt value(width, 1,
false);
// Multiply each operand's constant value into the accumulator. Lines between
// the loop header and this statement are not visible here.
1575 for (
auto operand : inputs) {
1578 value *= cast<IntegerAttr>(operand).getValue();
1587LogicalResult MulOp::canonicalize(
MulOp op, PatternRewriter &rewriter) {
1591 auto inputs = op.getInputs();
1592 auto size = inputs.size();
1593 assert(size > 1 &&
"expected 2 or more operands");
1595 APInt value, value2;
1598 if (size == 2 && matchPattern(inputs.back(), m_ConstantInt(&value)) &&
1599 value.isPowerOf2()) {
1601 value.exactLogBase2());
1603 comb::ShlOp::create(rewriter, op.getLoc(), inputs[0], shift,
false);
1605 replaceOpWithNewOpAndCopyNamehint<MulOp>(rewriter, op, op.getType(),
1606 ArrayRef<Value>(shlOp),
false);
1611 if (matchPattern(inputs.back(), m_ConstantInt(&value)) && value.isOne()) {
1612 replaceOpWithNewOpAndCopyNamehint<MulOp>(rewriter, op, op.getType(),
1613 inputs.drop_back());
1618 if (matchPattern(inputs[size - 1], m_ConstantInt(&value)) &&
1619 matchPattern(inputs[size - 2], m_ConstantInt(&value2))) {
1621 SmallVector<Value, 4> newOperands(inputs.drop_back(2));
1622 newOperands.push_back(cst);
1623 replaceOpWithNewOpAndCopyNamehint<MulOp>(rewriter, op, op.getType(),
1639template <
class Op,
bool isSigned>
1640static OpFoldResult
foldDiv(Op op, ArrayRef<Attribute> constants) {
1641 if (
auto rhsValue = dyn_cast_or_null<IntegerAttr>(constants[1])) {
1643 if (rhsValue.getValue() == 1)
1647 if (rhsValue.getValue().isZero())
1654OpFoldResult DivUOp::fold(FoldAdaptor adaptor) {
1657 return foldDiv<
DivUOp,
false>(*
this, adaptor.getOperands());
1660OpFoldResult DivSOp::fold(FoldAdaptor adaptor) {
/// Shared folder for the modulo ops. `constants[0]` and `constants[1]` are
/// the folded lhs and rhs attributes. `isSigned` is a template parameter; its
/// use is not visible in this excerpt.
1666template <
class Op,
bool isSigned>
1667static OpFoldResult
foldMod(Op op, ArrayRef<Attribute> constants) {
1668 if (
auto rhsValue = dyn_cast_or_null<IntegerAttr>(constants[1])) {
// Modulo by one: fold to a zero constant of the result width.
1670 if (rhsValue.getValue() == 1)
1671 return getIntAttr(APInt::getZero(op.getType().getIntOrFloatBitWidth()),
// Modulo by zero is tested separately; the action taken is outside this
// excerpt.
1675 if (rhsValue.getValue().isZero())
1679 if (
auto lhsValue = dyn_cast_or_null<IntegerAttr>(constants[0])) {
// Zero dividend: fold to a zero constant of the result width.
1681 if (lhsValue.getValue().isZero())
1682 return getIntAttr(APInt::getZero(op.getType().getIntOrFloatBitWidth()),
1689OpFoldResult ModUOp::fold(FoldAdaptor adaptor) {
1692 return foldMod<
ModUOp,
false>(*
this, adaptor.getOperands());
1695OpFoldResult ModSOp::fold(FoldAdaptor adaptor) {
1701LogicalResult DivUOp::canonicalize(
DivUOp op, PatternRewriter &rewriter) {
1707LogicalResult ModUOp::canonicalize(
ModUOp op, PatternRewriter &rewriter) {
/// Folder for ConcatOp. A single-operand concat folds to that operand; when
/// the operands are integer constants their bits are assembled into one
/// value.
1719OpFoldResult ConcatOp::fold(FoldAdaptor adaptor) {
1723 if (getNumOperands() == 1)
1724 return getOperand(0);
// Check each operand for being a known IntegerAttr (the action taken on a
// failed check is outside this excerpt; presumably the fold is abandoned --
// confirm).
1727 for (
auto attr : adaptor.getInputs())
1728 if (!attr || !isa<IntegerAttr>(attr))
// Assemble the constant. The insertion position starts at the result width
// and is lowered by each chunk's width before inserting, so the first
// operand occupies the most-significant bits.
1732 unsigned resultWidth = getType().getIntOrFloatBitWidth();
1733 APInt result(resultWidth, 0);
1735 unsigned nextInsertion = resultWidth;
1737 for (
auto attr : adaptor.getInputs()) {
1738 auto chunk = cast<IntegerAttr>(attr).getValue();
1739 nextInsertion -= chunk.getBitWidth();
1740 result.insertBits(chunk, nextInsertion);
1746LogicalResult ConcatOp::canonicalize(
ConcatOp op, PatternRewriter &rewriter) {
1750 auto inputs = op.getInputs();
1751 auto size = inputs.size();
1752 assert(size > 1 &&
"expected 2 or more operands");
1757 auto flattenConcat = [&](
size_t firstOpIndex,
size_t lastOpIndex,
1758 ValueRange replacements) -> LogicalResult {
1759 SmallVector<Value, 4> newOperands;
1760 newOperands.append(inputs.begin(), inputs.begin() + firstOpIndex);
1761 newOperands.append(replacements.begin(), replacements.end());
1762 newOperands.append(inputs.begin() + lastOpIndex + 1, inputs.end());
1763 if (newOperands.size() == 1)
1766 replaceOpWithNewOpAndCopyNamehint<ConcatOp>(rewriter, op, op.getType(),
1771 Value commonOperand = inputs[0];
1772 for (
size_t i = 0; i != size; ++i) {
1774 if (inputs[i] != commonOperand)
1775 commonOperand = Value();
1779 if (
auto subConcat = inputs[i].getDefiningOp<ConcatOp>())
1780 return flattenConcat(i, i, subConcat->getOperands());
1785 if (
auto cst = inputs[i].getDefiningOp<hw::ConstantOp>()) {
1786 if (
auto prevCst = inputs[i - 1].getDefiningOp<hw::ConstantOp>()) {
1787 unsigned prevWidth = prevCst.getValue().getBitWidth();
1788 unsigned thisWidth = cst.getValue().getBitWidth();
1789 auto resultCst = cst.getValue().zext(prevWidth + thisWidth);
1790 resultCst |= prevCst.getValue().zext(prevWidth + thisWidth)
1794 return flattenConcat(i - 1, i, replacement);
1799 if (inputs[i] == inputs[i - 1]) {
1801 rewriter.createOrFold<ReplicateOp>(op.getLoc(), inputs[i], 2);
1802 return flattenConcat(i - 1, i, replacement);
1807 if (
auto repl = inputs[i].getDefiningOp<ReplicateOp>()) {
1809 if (repl.getOperand() == inputs[i - 1]) {
1810 Value replacement = rewriter.createOrFold<ReplicateOp>(
1811 op.getLoc(), repl.getOperand(), repl.getMultiple() + 1);
1812 return flattenConcat(i - 1, i, replacement);
1815 if (
auto prevRepl = inputs[i - 1].getDefiningOp<ReplicateOp>()) {
1816 if (prevRepl.getOperand() == repl.getOperand()) {
1817 Value replacement = rewriter.createOrFold<ReplicateOp>(
1818 op.getLoc(), repl.getOperand(),
1819 repl.getMultiple() + prevRepl.getMultiple());
1820 return flattenConcat(i - 1, i, replacement);
1826 if (
auto repl = inputs[i - 1].getDefiningOp<ReplicateOp>()) {
1827 if (repl.getOperand() == inputs[i]) {
1828 Value replacement = rewriter.createOrFold<ReplicateOp>(
1829 op.getLoc(), inputs[i], repl.getMultiple() + 1);
1830 return flattenConcat(i - 1, i, replacement);
1836 if (
auto extract = inputs[i].getDefiningOp<ExtractOp>()) {
1837 if (
auto prevExtract = inputs[i - 1].getDefiningOp<ExtractOp>()) {
1838 if (extract.getInput() == prevExtract.getInput()) {
1839 auto thisWidth = cast<IntegerType>(extract.getType()).getWidth();
1840 if (prevExtract.getLowBit() == extract.getLowBit() + thisWidth) {
1841 auto prevWidth = prevExtract.getType().getIntOrFloatBitWidth();
1842 auto resType = rewriter.getIntegerType(thisWidth + prevWidth);
1845 extract.getInput(), extract.getLowBit());
1846 return flattenConcat(i - 1, i, replacement);
1859 static std::optional<ArraySlice>
get(Value value) {
1860 assert(isa<IntegerType>(value.getType()) &&
"expected integer type");
1862 return ArraySlice{arrayGet.getInput(), arrayGet.getIndex(), 1};
1865 if (
auto arraySlice =
1868 arraySlice.getInput(), arraySlice.getLowIndex(),
1869 hw::type_cast<hw::ArrayType>(arraySlice.getType())
1871 return std::nullopt;
1874 if (
auto extractOpt = ArraySlice::get(inputs[i])) {
1875 if (
auto prevExtractOpt = ArraySlice::get(inputs[i - 1])) {
1877 if (prevExtractOpt->index.getType() == extractOpt->index.getType() &&
1878 prevExtractOpt->input == extractOpt->input &&
1879 hw::isOffset(extractOpt->index, prevExtractOpt->index,
1880 extractOpt->width)) {
1881 auto resType = hw::ArrayType::get(
1882 hw::type_cast<hw::ArrayType>(prevExtractOpt->input.getType())
1884 extractOpt->width + prevExtractOpt->width);
1885 auto resIntType = rewriter.getIntegerType(hw::getBitWidth(resType));
1887 rewriter, op.getLoc(), resIntType,
1889 prevExtractOpt->input,
1890 extractOpt->index));
1891 return flattenConcat(i - 1, i, replacement);
1899 if (commonOperand) {
1900 replaceOpWithNewOpAndCopyNamehint<ReplicateOp>(rewriter, op, op.getType(),
1912OpFoldResult MuxOp::fold(FoldAdaptor adaptor) {
1917 if (getTrueValue() == getFalseValue() && getTrueValue() != getResult())
1918 return getTrueValue();
1919 if (
auto tv = adaptor.getTrueValue())
1920 if (tv == adaptor.getFalseValue())
1925 if (
auto pred = dyn_cast_or_null<IntegerAttr>(adaptor.getCond())) {
1926 if (pred.getValue().isZero() && getFalseValue() != getResult())
1927 return getFalseValue();
1928 if (pred.getValue().isOne() && getTrueValue() != getResult())
1929 return getTrueValue();
1933 if (getCond().getType() == getTrueValue().getType())
1934 if (
auto tv = dyn_cast_or_null<IntegerAttr>(adaptor.getTrueValue()))
1935 if (
auto fv = dyn_cast_or_null<IntegerAttr>(adaptor.getFalseValue()))
1936 if (tv.getValue().isOne() && fv.getValue().isZero() &&
1937 hw::getBitWidth(getType()) == 1 && getCond() != getResult())
1953 if (
auto cmp = cond.getDefiningOp<ICmpOp>()) {
1955 auto requiredPredicate =
1956 (isInverted ? ICmpPredicate::eq : ICmpPredicate::ne);
1957 if (cmp.getLhs() == indexValue && cmp.getPredicate() == requiredPredicate) {
1967 if (
auto orOp = cond.getDefiningOp<
OrOp>()) {
1970 for (
auto operand : orOp.getOperands())
1977 if (
auto andOp = cond.getDefiningOp<
AndOp>()) {
1980 for (
auto operand : andOp.getOperands())
1999 PatternRewriter &rewriter,
MuxOp rootMux,
bool isFalseSide,
2005 auto rootCmp = rootMux.getCond().getDefiningOp<ICmpOp>();
2008 Value indexValue = rootCmp.getLhs();
2011 auto getCaseValue = [&](
MuxOp mux) -> Value {
2012 return mux.getOperand(1 +
unsigned(!isFalseSide));
2017 auto getTreeValue = [&](
MuxOp mux) -> Value {
2018 return mux.getOperand(1 +
unsigned(isFalseSide));
2023 SmallVector<Location> locationsFound;
2024 SmallVector<std::pair<hw::ConstantOp, Value>, 4> valuesFound;
2028 auto collectConstantValues = [&](
MuxOp mux) ->
bool {
2030 mux.getCond(), indexValue, isFalseSide, [&](
hw::ConstantOp cst) {
2031 valuesFound.push_back({cst, getCaseValue(mux)});
2032 locationsFound.push_back(mux.getCond().getLoc());
2033 locationsFound.push_back(mux->getLoc());
2038 if (!collectConstantValues(rootMux))
2042 if (rootMux->hasOneUse()) {
2043 if (
auto userMux = dyn_cast<MuxOp>(*rootMux->user_begin())) {
2044 if (getTreeValue(userMux) == rootMux.getResult() &&
2052 auto nextTreeValue = getTreeValue(rootMux);
2054 auto nextMux = nextTreeValue.getDefiningOp<
MuxOp>();
2055 if (!nextMux || !nextMux->hasOneUse())
2057 if (!collectConstantValues(nextMux))
2059 nextTreeValue = getTreeValue(nextMux);
2062 auto indexWidth = cast<IntegerType>(indexValue.getType()).getWidth();
2064 if (indexWidth > 20)
2067 auto foldingStyle = styleFn(indexWidth, valuesFound.size());
2071 uint64_t tableSize = 1ULL << indexWidth;
2075 SmallVector<Value, 8> table(tableSize, nextTreeValue);
2080 for (
auto &elt :
llvm::reverse(valuesFound)) {
2081 uint64_t idx = elt.first.getValue().getZExtValue();
2082 assert(idx < table.size() &&
"constant should be same bitwidth as index");
2083 table[idx] = elt.second;
2087 SmallVector<Value> bits;
2096 "unknown folding style");
2100 std::reverse(table.begin(), table.end());
2103 auto fusedLoc = rewriter.getFusedLoc(locationsFound);
2105 replaceOpWithNewOpAndCopyNamehint<hw::ArrayGetOp>(rewriter, rootMux, array,
2120 PatternRewriter &rewriter) {
2121 assert(fullyAssoc->getNumOperands() >= 2 &&
"cannot split up unary ops");
2122 assert(operandNo < fullyAssoc->getNumOperands() &&
"Invalid operand #");
2126 if (fullyAssoc->getNumOperands() == 2)
2127 return fullyAssoc->getOperand(operandNo ^ 1);
2130 if (fullyAssoc->hasOneUse()) {
2131 rewriter.modifyOpInPlace(fullyAssoc,
2132 [&]() { fullyAssoc->eraseOperand(operandNo); });
2133 return fullyAssoc->getResult(0);
2137 SmallVector<Value> operands;
2138 operands.append(fullyAssoc->getOperands().begin(),
2139 fullyAssoc->getOperands().begin() + operandNo);
2140 operands.append(fullyAssoc->getOperands().begin() + operandNo + 1,
2141 fullyAssoc->getOperands().end());
2143 fullyAssoc->getLoc(), fullyAssoc->getName(), operands, rewriter);
2144 Value excluded = fullyAssoc->getOperand(operandNo);
2148 ArrayRef<Value>{opWithoutExcluded, excluded}, rewriter);
2150 return opWithoutExcluded;
2160 PatternRewriter &rewriter) {
2163 Operation *subExpr =
2164 (isTrueOperand ? op.getFalseValue() : op.getTrueValue()).getDefiningOp();
2165 if (!subExpr || subExpr->getNumOperands() < 2)
2169 if (!isa<AndOp, XorOp, OrOp, MuxOp>(subExpr))
2174 Value commonValue = isTrueOperand ? op.getTrueValue() : op.getFalseValue();
2175 size_t opNo = 0, e = subExpr->getNumOperands();
2176 while (opNo != e && subExpr->getOperand(opNo) != commonValue)
2182 Value cond = op.getCond();
2188 if (
auto subMux = dyn_cast<MuxOp>(subExpr)) {
2193 Value subCond = subMux.getCond();
2196 if (subMux.getTrueValue() == commonValue)
2197 otherValue = subMux.getFalseValue();
2198 else if (subMux.getFalseValue() == commonValue) {
2199 otherValue = subMux.getTrueValue();
2209 cond = rewriter.createOrFold<
OrOp>(op.getLoc(), cond, subCond,
false);
2210 replaceOpWithNewOpAndCopyNamehint<MuxOp>(rewriter, op, cond, commonValue,
2211 otherValue, op.getTwoState());
2217 bool isaAndOp = isa<AndOp>(subExpr);
2218 if (isTrueOperand ^ isaAndOp)
2222 rewriter.createOrFold<ReplicateOp>(op.getLoc(), op.getType(), cond);
2225 bool isaXorOp = isa<XorOp>(subExpr);
2226 bool isaOrOp = isa<OrOp>(subExpr);
2235 if (isaOrOp || isaXorOp) {
2236 auto masked = rewriter.createOrFold<
AndOp>(op.getLoc(), extendedCond,
2237 restOfAssoc,
false);
2239 replaceOpWithNewOpAndCopyNamehint<XorOp>(rewriter, op, masked,
2240 commonValue,
false);
2242 replaceOpWithNewOpAndCopyNamehint<OrOp>(rewriter, op, masked, commonValue,
2248 assert(isaAndOp &&
"unexpected operation here");
2249 auto masked = rewriter.createOrFold<
OrOp>(op.getLoc(), extendedCond,
2250 restOfAssoc,
false);
2251 replaceOpWithNewOpAndCopyNamehint<AndOp>(rewriter, op, masked, commonValue,
2262 PatternRewriter &rewriter) {
2265 if (!isa<ConcatOp>(trueOp))
2269 SmallVector<Value> trueOperands, falseOperands;
2273 size_t numTrueOperands = trueOperands.size();
2274 size_t numFalseOperands = falseOperands.size();
2276 if (!numTrueOperands || !numFalseOperands ||
2277 (trueOperands.front() != falseOperands.front() &&
2278 trueOperands.back() != falseOperands.back()))
2282 if (trueOperands.front() == falseOperands.front()) {
2283 SmallVector<Value> operands;
2285 for (i = 0; i < numTrueOperands; ++i) {
2286 Value trueOperand = trueOperands[i];
2287 if (trueOperand == falseOperands[i])
2288 operands.push_back(trueOperand);
2292 if (i == numTrueOperands) {
2299 if (llvm::all_of(operands, [&](Value v) {
return v == operands.front(); }))
2300 sharedMSB = rewriter.createOrFold<ReplicateOp>(
2301 mux->getLoc(), operands.front(), operands.size());
2303 sharedMSB = rewriter.createOrFold<
ConcatOp>(mux->getLoc(), operands);
2307 operands.append(trueOperands.begin() + i, trueOperands.end());
2308 Value trueLSB = rewriter.createOrFold<
ConcatOp>(trueOp->getLoc(), operands);
2310 operands.append(falseOperands.begin() + i, falseOperands.end());
2312 rewriter.createOrFold<
ConcatOp>(falseOp->getLoc(), operands);
2315 Value lsb = rewriter.createOrFold<
MuxOp>(
2316 mux->getLoc(), mux.getCond(), trueLSB, falseLSB, mux.getTwoState());
2317 replaceOpWithNewOpAndCopyNamehint<ConcatOp>(rewriter, mux, sharedMSB, lsb);
2322 if (trueOperands.back() == falseOperands.back()) {
2323 SmallVector<Value> operands;
2326 Value trueOperand = trueOperands[numTrueOperands - i - 1];
2327 if (trueOperand == falseOperands[numFalseOperands - i - 1])
2328 operands.push_back(trueOperand);
2332 std::reverse(operands.begin(), operands.end());
2333 Value sharedLSB = rewriter.createOrFold<
ConcatOp>(mux->getLoc(), operands);
2337 operands.append(trueOperands.begin(), trueOperands.end() - i);
2338 Value trueMSB = rewriter.createOrFold<
ConcatOp>(trueOp->getLoc(), operands);
2340 operands.append(falseOperands.begin(), falseOperands.end() - i);
2342 rewriter.createOrFold<
ConcatOp>(falseOp->getLoc(), operands);
2344 Value msb = rewriter.createOrFold<
MuxOp>(
2345 mux->getLoc(), mux.getCond(), trueMSB, falseMSB, mux.getTwoState());
2346 replaceOpWithNewOpAndCopyNamehint<ConcatOp>(rewriter, mux, msb, sharedLSB);
2358 if (!trueVec || !falseVec)
2360 if (!trueVec.isUniform() || !falseVec.isUniform())
2363 auto mux = MuxOp::create(rewriter, op.getLoc(), op.getCond(),
2364 trueVec.getUniformElement(),
2365 falseVec.getUniformElement(), op.getTwoState());
2367 SmallVector<Value> values(trueVec.getInputs().size(), mux);
2375 bool constCond, PatternRewriter &rewriter) {
2376 if (!muxValue.hasOneUse())
2378 auto *op = muxValue.getDefiningOp();
2379 if (!op || !isa_and_nonnull<CombDialect>(op->getDialect()))
2381 if (!llvm::is_contained(op->getOperands(), muxCond))
2383 OpBuilder::InsertionGuard guard(rewriter);
2384 rewriter.setInsertionPoint(op);
2387 rewriter.modifyOpInPlace(op, [&] {
2388 for (
auto &use : op->getOpOperands())
2389 if (use.get() == muxCond)
2397 using OpRewritePattern::OpRewritePattern;
2399 LogicalResult matchAndRewrite(
MuxOp op,
2400 PatternRewriter &rewriter)
const override;
2404foldToArrayCreateOnlyWhenDense(
size_t indexWidth,
size_t numEntries) {
2407 if (indexWidth >= 9 || numEntries < 3)
2413 uint64_t tableSize = 1ULL << indexWidth;
2414 if (numEntries >= tableSize * 5 / 8)
2419LogicalResult MuxRewriter::matchAndRewrite(
MuxOp op,
2420 PatternRewriter &rewriter)
const {
2424 bool isSignlessInt =
false;
2425 if (
auto intType = dyn_cast<IntegerType>(op.getType()))
2426 isSignlessInt = intType.isSignless();
2433 if (matchPattern(op.getTrueValue(), m_ConstantInt(&value)) && isSignlessInt) {
2434 if (value.getBitWidth() == 1) {
2436 if (value.isZero()) {
2438 replaceOpWithNewOpAndCopyNamehint<AndOp>(rewriter, op, notCond,
2439 op.getFalseValue(),
false);
2444 replaceOpWithNewOpAndCopyNamehint<OrOp>(rewriter, op, op.getCond(),
2445 op.getFalseValue(),
false);
2451 if (matchPattern(op.getFalseValue(), m_ConstantInt(&value2))) {
2456 APInt xorValue = value ^ value2;
2457 if (xorValue.isPowerOf2()) {
2458 unsigned leadingZeros = xorValue.countLeadingZeros();
2459 unsigned trailingZeros = value.getBitWidth() - leadingZeros - 1;
2460 SmallVector<Value, 3> operands;
2468 if (leadingZeros > 0)
2469 operands.push_back(rewriter.createOrFold<
ExtractOp>(
2470 op.getLoc(), op.getTrueValue(), trailingZeros + 1, leadingZeros));
2474 auto v1 = rewriter.createOrFold<
ExtractOp>(
2475 op.getLoc(), op.getTrueValue(), trailingZeros, 1);
2476 auto v2 = rewriter.createOrFold<
ExtractOp>(
2477 op.getLoc(), op.getFalseValue(), trailingZeros, 1);
2478 operands.push_back(rewriter.createOrFold<
MuxOp>(
2479 op.getLoc(), op.getCond(), v1, v2,
false));
2481 if (trailingZeros > 0)
2482 operands.push_back(rewriter.createOrFold<
ExtractOp>(
2483 op.getLoc(), op.getTrueValue(), 0, trailingZeros));
2485 replaceOpWithNewOpAndCopyNamehint<ConcatOp>(rewriter, op, op.getType(),
2492 if (value.isAllOnes() && value2.isZero()) {
2493 replaceOpWithNewOpAndCopyNamehint<ReplicateOp>(
2494 rewriter, op, op.getType(), op.getCond());
2500 if (matchPattern(op.getFalseValue(), m_ConstantInt(&value)) &&
2501 isSignlessInt && value.getBitWidth() == 1) {
2503 if (value.isZero()) {
2504 replaceOpWithNewOpAndCopyNamehint<AndOp>(rewriter, op, op.getCond(),
2505 op.getTrueValue(),
false);
2512 auto notCond = rewriter.createOrFold<
XorOp>(op.getLoc(), op.getCond(),
2513 op.getFalseValue(),
false);
2514 replaceOpWithNewOpAndCopyNamehint<OrOp>(rewriter, op, notCond,
2515 op.getTrueValue(),
false);
2521 Operation *condOp = op.getCond().getDefiningOp();
2522 if (condOp && matchPattern(condOp,
m_Complement(m_Any(&subExpr))) &&
2524 replaceOpWithNewOpAndCopyNamehint<MuxOp>(rewriter, op, op.getType(),
2525 subExpr, op.getFalseValue(),
2526 op.getTrueValue(),
true);
2533 if (condOp && condOp->hasOneUse()) {
2534 SmallVector<Value> invertedOperands;
2538 auto getInvertedOperands = [&]() ->
bool {
2539 for (Value operand : condOp->getOperands()) {
2540 if (matchPattern(operand,
m_Complement(m_Any(&subExpr))))
2541 invertedOperands.push_back(subExpr);
2548 if (isa<AndOp>(condOp) && getInvertedOperands()) {
2550 rewriter.createOrFold<
OrOp>(op.getLoc(), invertedOperands,
false);
2551 replaceOpWithNewOpAndCopyNamehint<MuxOp>(
2552 rewriter, op, newOr, op.getFalseValue(), op.getTrueValue(),
2556 if (isa<OrOp>(condOp) && getInvertedOperands()) {
2558 rewriter.createOrFold<
AndOp>(op.getLoc(), invertedOperands,
false);
2559 replaceOpWithNewOpAndCopyNamehint<MuxOp>(
2560 rewriter, op, newAnd, op.getFalseValue(), op.getTrueValue(),
2566 if (
auto falseMux = op.getFalseValue().getDefiningOp<
MuxOp>();
2567 falseMux && falseMux != op) {
2569 if (op.getCond() == falseMux.getCond() &&
2570 falseMux.getFalseValue() != falseMux) {
2571 replaceOpWithNewOpAndCopyNamehint<MuxOp>(
2572 rewriter, op, op.getCond(), op.getTrueValue(),
2573 falseMux.getFalseValue(), op.getTwoStateAttr());
2579 foldToArrayCreateOnlyWhenDense))
2583 if (
auto trueMux = op.getTrueValue().getDefiningOp<
MuxOp>();
2584 trueMux && trueMux != op) {
2586 if (op.getCond() == trueMux.getCond()) {
2587 replaceOpWithNewOpAndCopyNamehint<MuxOp>(
2588 rewriter, op, op.getCond(), trueMux.getTrueValue(),
2589 op.getFalseValue(), op.getTwoStateAttr());
2595 foldToArrayCreateOnlyWhenDense))
2600 if (
auto trueMux = dyn_cast_or_null<MuxOp>(op.getTrueValue().getDefiningOp()),
2601 falseMux = dyn_cast_or_null<MuxOp>(op.getFalseValue().getDefiningOp());
2602 trueMux && falseMux && trueMux.getCond() == falseMux.getCond() &&
2603 trueMux.getTrueValue() == falseMux.getTrueValue() && trueMux != op &&
2605 auto subMux = MuxOp::create(
2606 rewriter, rewriter.getFusedLoc({trueMux.getLoc(), falseMux.getLoc()}),
2607 op.getCond(), trueMux.getFalseValue(), falseMux.getFalseValue());
2608 replaceOpWithNewOpAndCopyNamehint<MuxOp>(rewriter, op, trueMux.getCond(),
2609 trueMux.getTrueValue(), subMux,
2610 op.getTwoStateAttr());
2615 if (
auto trueMux = dyn_cast_or_null<MuxOp>(op.getTrueValue().getDefiningOp()),
2616 falseMux = dyn_cast_or_null<MuxOp>(op.getFalseValue().getDefiningOp());
2617 trueMux && falseMux && trueMux.getCond() == falseMux.getCond() &&
2618 trueMux.getFalseValue() == falseMux.getFalseValue() && trueMux != op &&
2620 auto subMux = MuxOp::create(
2621 rewriter, rewriter.getFusedLoc({trueMux.getLoc(), falseMux.getLoc()}),
2622 op.getCond(), trueMux.getTrueValue(), falseMux.getTrueValue());
2623 replaceOpWithNewOpAndCopyNamehint<MuxOp>(rewriter, op, trueMux.getCond(),
2624 subMux, trueMux.getFalseValue(),
2625 op.getTwoStateAttr());
2630 if (
auto trueMux = dyn_cast_or_null<MuxOp>(op.getTrueValue().getDefiningOp()),
2631 falseMux = dyn_cast_or_null<MuxOp>(op.getFalseValue().getDefiningOp());
2632 trueMux && falseMux &&
2633 trueMux.getTrueValue() == falseMux.getTrueValue() &&
2634 trueMux.getFalseValue() == falseMux.getFalseValue() && trueMux != op &&
2637 MuxOp::create(rewriter,
2638 rewriter.getFusedLoc(
2639 {op.getLoc(), trueMux.getLoc(), falseMux.getLoc()}),
2640 op.getCond(), trueMux.getCond(), falseMux.getCond());
2641 replaceOpWithNewOpAndCopyNamehint<MuxOp>(
2642 rewriter, op, subMux, trueMux.getTrueValue(), trueMux.getFalseValue(),
2643 op.getTwoStateAttr());
2655 if (Operation *trueOp = op.getTrueValue().getDefiningOp())
2656 if (Operation *falseOp = op.getFalseValue().getDefiningOp())
2657 if (trueOp->getName() == falseOp->getName())
2670 if (op.getTrueValue().getDefiningOp() &&
2671 op.getTrueValue().getDefiningOp() != op)
2674 if (op.getFalseValue().getDefiningOp() &&
2675 op.getFalseValue().getDefiningOp() != op)
2686 if (op.getInputs().empty() || op.isUniform())
2688 auto inputs = op.getInputs();
2689 if (inputs.size() <= 1)
2694 auto first = inputs[0].getDefiningOp<
comb::MuxOp>();
2699 for (
size_t i = 1, n = inputs.size(); i < n; ++i) {
2700 auto input = inputs[i].getDefiningOp<
comb::MuxOp>();
2701 if (!input || first.getCond() != input.getCond())
2706 SmallVector<Value> trues{first.getTrueValue()};
2707 SmallVector<Value> falses{first.getFalseValue()};
2708 SmallVector<Location> locs{first->getLoc()};
2709 bool isTwoState =
true;
2710 for (
size_t i = 1, n = inputs.size(); i < n; ++i) {
2711 auto input = inputs[i].getDefiningOp<
comb::MuxOp>();
2712 trues.push_back(input.getTrueValue());
2713 falses.push_back(input.getFalseValue());
2714 locs.push_back(input->getLoc());
2715 if (!input.getTwoState())
2720 auto loc = FusedLoc::get(op.getContext(), locs);
2724 auto arrayTy = op.getType();
2727 rewriter.replaceOpWithNewOp<
comb::MuxOp>(op, arrayTy, first.getCond(),
2728 trueValues, falseValues, isTwoState);
2733 using OpRewritePattern::OpRewritePattern;
2736 PatternRewriter &rewriter)
const override {
2737 if (foldArrayOfMuxes(op, rewriter))
2745void MuxOp::getCanonicalizationPatterns(RewritePatternSet &results,
2746 MLIRContext *context) {
2747 results.insert<MuxRewriter, ArrayRewriter>(context);
2758 switch (predicate) {
2759 case ICmpPredicate::eq:
2761 case ICmpPredicate::ne:
2763 case ICmpPredicate::slt:
2764 return lhs.slt(rhs);
2765 case ICmpPredicate::sle:
2766 return lhs.sle(rhs);
2767 case ICmpPredicate::sgt:
2768 return lhs.sgt(rhs);
2769 case ICmpPredicate::sge:
2770 return lhs.sge(rhs);
2771 case ICmpPredicate::ult:
2772 return lhs.ult(rhs);
2773 case ICmpPredicate::ule:
2774 return lhs.ule(rhs);
2775 case ICmpPredicate::ugt:
2776 return lhs.ugt(rhs);
2777 case ICmpPredicate::uge:
2778 return lhs.uge(rhs);
2779 case ICmpPredicate::ceq:
2781 case ICmpPredicate::cne:
2783 case ICmpPredicate::weq:
2785 case ICmpPredicate::wne:
2788 llvm_unreachable(
"unknown comparison predicate");
2794 switch (predicate) {
2795 case ICmpPredicate::eq:
2796 case ICmpPredicate::sle:
2797 case ICmpPredicate::sge:
2798 case ICmpPredicate::ule:
2799 case ICmpPredicate::uge:
2800 case ICmpPredicate::ceq:
2801 case ICmpPredicate::weq:
2803 case ICmpPredicate::ne:
2804 case ICmpPredicate::slt:
2805 case ICmpPredicate::sgt:
2806 case ICmpPredicate::ult:
2807 case ICmpPredicate::ugt:
2808 case ICmpPredicate::cne:
2809 case ICmpPredicate::wne:
2812 llvm_unreachable(
"unknown comparison predicate");
2815OpFoldResult ICmpOp::fold(FoldAdaptor adaptor) {
2818 if (getLhs() == getRhs()) {
2820 return IntegerAttr::get(getType(), val);
2824 if (
auto lhs = dyn_cast_or_null<IntegerAttr>(adaptor.getLhs())) {
2825 if (
auto rhs = dyn_cast_or_null<IntegerAttr>(adaptor.getRhs())) {
2828 return IntegerAttr::get(getType(), val);
2836template <
typename Range>
2838 size_t commonPrefixLength = 0;
2839 auto ia = a.begin();
2840 auto ib = b.begin();
2842 for (; ia != a.end() && ib != b.end(); ia++, ib++, commonPrefixLength++) {
2848 return commonPrefixLength;
2852 size_t totalWidth = 0;
2853 for (
auto operand : operands) {
2856 ssize_t width = operand.getType().getIntOrFloatBitWidth();
2858 totalWidth += width;
2868 PatternRewriter &rewriter) {
2872 SmallVector<Value> lhsOperands, rhsOperands;
2875 ArrayRef<Value> lhsOperandsRef = lhsOperands, rhsOperandsRef = rhsOperands;
2877 auto formCatOrReplicate = [&](Location loc,
2878 ArrayRef<Value> operands) -> Value {
2879 assert(!operands.empty());
2880 Value sameElement = operands[0];
2881 for (
size_t i = 1, e = operands.size(); i != e && sameElement; ++i)
2882 if (sameElement != operands[i])
2883 sameElement = Value();
2885 return rewriter.createOrFold<ReplicateOp>(loc, sameElement,
2887 return rewriter.createOrFold<
ConcatOp>(loc, operands);
2890 auto replaceWith = [&](ICmpPredicate predicate, Value lhs,
2891 Value rhs) -> LogicalResult {
2892 replaceOpWithNewOpAndCopyNamehint<ICmpOp>(rewriter, op, predicate, lhs, rhs,
2897 size_t commonPrefixLength =
2899 if (commonPrefixLength == lhsOperands.size()) {
2902 replaceOpWithNewOpAndCopyNamehint<hw::ConstantOp>(rewriter, op,
2908 llvm::reverse(lhsOperandsRef), llvm::reverse(rhsOperandsRef));
2910 size_t commonPrefixTotalWidth =
2911 getTotalWidth(lhsOperandsRef.take_front(commonPrefixLength));
2912 size_t commonSuffixTotalWidth =
2913 getTotalWidth(lhsOperandsRef.take_back(commonSuffixLength));
2914 auto lhsOnly = lhsOperandsRef.drop_front(commonPrefixLength)
2915 .drop_back(commonSuffixLength);
2916 auto rhsOnly = rhsOperandsRef.drop_front(commonPrefixLength)
2917 .drop_back(commonSuffixLength);
2919 auto replaceWithoutReplicatingSignBit = [&]() {
2920 auto newLhs = formCatOrReplicate(lhs->getLoc(), lhsOnly);
2921 auto newRhs = formCatOrReplicate(rhs->getLoc(), rhsOnly);
2922 return replaceWith(op.getPredicate(), newLhs, newRhs);
2925 auto replaceWithReplicatingSignBit = [&]() {
2926 auto firstNonEmptyValue = lhsOperands[0];
2927 auto firstNonEmptyElemWidth =
2928 firstNonEmptyValue.getType().getIntOrFloatBitWidth();
2929 Value signBit = rewriter.createOrFold<
ExtractOp>(
2930 op.getLoc(), firstNonEmptyValue, firstNonEmptyElemWidth - 1, 1);
2932 auto newLhs = ConcatOp::create(rewriter, lhs->getLoc(), signBit, lhsOnly);
2933 auto newRhs = ConcatOp::create(rewriter, rhs->getLoc(), signBit, rhsOnly);
2934 return replaceWith(op.getPredicate(), newLhs, newRhs);
2937 if (ICmpOp::isPredicateSigned(op.getPredicate())) {
2939 if (commonPrefixTotalWidth == 0 && commonSuffixTotalWidth > 0)
2940 return replaceWithoutReplicatingSignBit();
2946 if (commonPrefixTotalWidth > 1 || commonSuffixTotalWidth > 0)
2947 return replaceWithReplicatingSignBit();
2949 }
else if (commonPrefixTotalWidth > 0 || commonSuffixTotalWidth > 0) {
2951 return replaceWithoutReplicatingSignBit();
2965 ICmpOp cmpOp,
const KnownBits &bitAnalysis,
const APInt &rhsCst,
2966 PatternRewriter &rewriter) {
2970 APInt bitsKnown = bitAnalysis.Zero | bitAnalysis.One;
2971 if ((bitsKnown & rhsCst) != bitAnalysis.One) {
2974 bool result = cmpOp.getPredicate() == ICmpPredicate::ne;
2975 replaceOpWithNewOpAndCopyNamehint<hw::ConstantOp>(rewriter, cmpOp,
2983 SmallVector<Value> newConcatOperands;
2984 auto newConstant = APInt::getZeroWidth();
2989 unsigned knownMSB = bitsKnown.countLeadingOnes();
2991 Value operand = cmpOp.getLhs();
2996 while (knownMSB != bitsKnown.getBitWidth()) {
2999 bitsKnown = bitsKnown.trunc(bitsKnown.getBitWidth() - knownMSB);
3002 unsigned unknownBits = bitsKnown.countLeadingZeros();
3003 unsigned lowBit = bitsKnown.getBitWidth() - unknownBits;
3004 auto spanOperand = rewriter.createOrFold<
ExtractOp>(
3005 operand.getLoc(), operand, lowBit,
3007 auto spanConstant = rhsCst.lshr(lowBit).trunc(unknownBits);
3010 newConcatOperands.push_back(spanOperand);
3013 if (newConstant.getBitWidth() != 0)
3014 newConstant = newConstant.concat(spanConstant);
3016 newConstant = spanConstant;
3019 unsigned newWidth = bitsKnown.getBitWidth() - unknownBits;
3020 bitsKnown = bitsKnown.trunc(newWidth);
3021 knownMSB = bitsKnown.countLeadingOnes();
3027 if (newConcatOperands.empty()) {
3028 bool result = cmpOp.getPredicate() == ICmpPredicate::eq;
3029 replaceOpWithNewOpAndCopyNamehint<hw::ConstantOp>(rewriter, cmpOp,
3035 Value concatResult =
3036 rewriter.createOrFold<
ConcatOp>(operand.getLoc(), newConcatOperands);
3040 rewriter, cmpOp.getOperand(1).getLoc(), newConstant);
3042 replaceOpWithNewOpAndCopyNamehint<ICmpOp>(rewriter, cmpOp,
3043 cmpOp.getPredicate(), concatResult,
3044 newConstantOp, cmpOp.getTwoState());
3050 PatternRewriter &rewriter) {
3051 auto ip = rewriter.saveInsertionPoint();
3052 rewriter.setInsertionPoint(xorOp);
3054 auto xorRHS = xorOp.getOperands().back().getDefiningOp<
hw::ConstantOp>();
3056 xorRHS.getValue() ^ rhs);
3058 switch (xorOp.getNumOperands()) {
3062 APInt::getZero(rhs.getBitWidth()));
3066 newLHS = xorOp.getOperand(0);
3070 SmallVector<Value> newOperands(xorOp.getOperands());
3071 newOperands.pop_back();
3072 newLHS = XorOp::create(rewriter, xorOp.getLoc(), newOperands,
false);
3076 bool xorMultipleUses = !xorOp->hasOneUse();
3080 if (xorMultipleUses)
3081 replaceOpWithNewOpAndCopyNamehint<XorOp>(rewriter, xorOp, newLHS, xorRHS,
3085 rewriter.restoreInsertionPoint(ip);
3086 replaceOpWithNewOpAndCopyNamehint<ICmpOp>(
3087 rewriter, cmpOp, cmpOp.getPredicate(), newLHS, newRHS,
false);
3090LogicalResult ICmpOp::canonicalize(ICmpOp op, PatternRewriter &rewriter) {
3096 if (matchPattern(op.getLhs(), m_ConstantInt(&lhs))) {
3097 assert(!matchPattern(op.getRhs(), m_ConstantInt(&rhs)) &&
3098 "Should be folded");
3099 replaceOpWithNewOpAndCopyNamehint<ICmpOp>(
3100 rewriter, op, ICmpOp::getFlippedPredicate(op.getPredicate()),
3101 op.getRhs(), op.getLhs(), op.getTwoState());
3106 if (matchPattern(op.getRhs(), m_ConstantInt(&rhs))) {
3111 auto replaceWith = [&](ICmpPredicate predicate, Value lhs,
3112 Value rhs) -> LogicalResult {
3113 replaceOpWithNewOpAndCopyNamehint<ICmpOp>(rewriter, op, predicate, lhs,
3114 rhs, op.getTwoState());
3118 auto replaceWithConstantI1 = [&](
bool constant) -> LogicalResult {
3119 replaceOpWithNewOpAndCopyNamehint<hw::ConstantOp>(rewriter, op,
3120 APInt(1, constant));
3124 switch (op.getPredicate()) {
3125 case ICmpPredicate::slt:
3127 if (rhs.isMaxSignedValue())
3128 return replaceWith(ICmpPredicate::ne, op.getLhs(), op.getRhs());
3130 if (rhs.isMinSignedValue())
3131 return replaceWithConstantI1(0);
3133 if ((rhs - 1).isMinSignedValue())
3134 return replaceWith(ICmpPredicate::eq, op.getLhs(),
3137 case ICmpPredicate::sgt:
3139 if (rhs.isMinSignedValue())
3140 return replaceWith(ICmpPredicate::ne, op.getLhs(), op.getRhs());
3142 if (rhs.isMaxSignedValue())
3143 return replaceWithConstantI1(0);
3145 if ((rhs + 1).isMaxSignedValue())
3146 return replaceWith(ICmpPredicate::eq, op.getLhs(),
3149 case ICmpPredicate::ult:
3151 if (rhs.isAllOnes())
3152 return replaceWith(ICmpPredicate::ne, op.getLhs(), op.getRhs());
3155 return replaceWithConstantI1(0);
3157 if ((rhs - 1).isZero())
3158 return replaceWith(ICmpPredicate::eq, op.getLhs(),
3162 if (rhs.countLeadingOnes() + rhs.countTrailingZeros() ==
3163 rhs.getBitWidth()) {
3164 auto numOnes = rhs.countLeadingOnes();
3166 rhs.getBitWidth() - numOnes, numOnes);
3167 return replaceWith(ICmpPredicate::ne, smaller,
3172 case ICmpPredicate::ugt:
3175 return replaceWith(ICmpPredicate::ne, op.getLhs(), op.getRhs());
3177 if (rhs.isAllOnes())
3178 return replaceWithConstantI1(0);
3180 if ((rhs + 1).isAllOnes())
3181 return replaceWith(ICmpPredicate::eq, op.getLhs(),
3185 if ((rhs + 1).isPowerOf2()) {
3186 auto numOnes = rhs.countTrailingOnes();
3187 auto newWidth = rhs.getBitWidth() - numOnes;
3190 return replaceWith(ICmpPredicate::ne, smaller,
3195 case ICmpPredicate::sle:
3197 if (rhs.isMaxSignedValue())
3198 return replaceWithConstantI1(1);
3200 return replaceWith(ICmpPredicate::slt, op.getLhs(),
getConstant(rhs + 1));
3201 case ICmpPredicate::sge:
3203 if (rhs.isMinSignedValue())
3204 return replaceWithConstantI1(1);
3206 return replaceWith(ICmpPredicate::sgt, op.getLhs(),
getConstant(rhs - 1));
3207 case ICmpPredicate::ule:
3209 if (rhs.isAllOnes())
3210 return replaceWithConstantI1(1);
3212 return replaceWith(ICmpPredicate::ult, op.getLhs(),
getConstant(rhs + 1));
3213 case ICmpPredicate::uge:
3216 return replaceWithConstantI1(1);
3218 return replaceWith(ICmpPredicate::ugt, op.getLhs(),
getConstant(rhs - 1));
3219 case ICmpPredicate::eq:
3220 if (rhs.getBitWidth() == 1) {
3223 replaceOpWithNewOpAndCopyNamehint<XorOp>(rewriter, op, op.getLhs(),
3228 if (rhs.isAllOnes()) {
3235 case ICmpPredicate::ne:
3236 if (rhs.getBitWidth() == 1) {
3242 if (rhs.isAllOnes()) {
3244 replaceOpWithNewOpAndCopyNamehint<XorOp>(rewriter, op, op.getLhs(),
3251 case ICmpPredicate::ceq:
3252 case ICmpPredicate::cne:
3253 case ICmpPredicate::weq:
3254 case ICmpPredicate::wne:
3260 if (op.getPredicate() == ICmpPredicate::eq ||
3261 op.getPredicate() == ICmpPredicate::ne) {
3266 if (!knownBits.isUnknown())
3273 if (
auto xorOp = op.getLhs().getDefiningOp<
XorOp>())
3280 if (
auto replicateOp = op.getLhs().getDefiningOp<ReplicateOp>())
3281 if (rhs.isAllOnes() || rhs.isZero()) {
3282 auto width = replicateOp.getInput().getType().getIntOrFloatBitWidth();
3285 rhs.isAllOnes() ? APInt::getAllOnes(width)
3286 : APInt::getZero(width));
3287 replaceOpWithNewOpAndCopyNamehint<ICmpOp>(
3288 rewriter, op, op.getPredicate(), replicateOp.getInput(), cst,
3298 if (Operation *opLHS = op.getLhs().getDefiningOp())
3299 if (Operation *opRHS = op.getRhs().getDefiningOp())
3300 if (isa<ConcatOp, ReplicateOp>(opLHS) &&
3301 isa<ConcatOp, ReplicateOp>(opRHS)) {
assert(baseType &&"element must be base type")
static SmallVector< T > concat(const SmallVectorImpl< T > &a, const SmallVectorImpl< T > &b)
Returns a new vector containing the concatenation of vectors a and b.
static KnownBits computeKnownBits(Value v, unsigned depth)
Given an integer SSA value, check to see if we know anything about the result of the computation.
static bool foldMuxOfUniformArrays(MuxOp op, PatternRewriter &rewriter)
static Attribute constFoldAssociativeOp(ArrayRef< Attribute > operands, hw::PEO paramOpcode)
static Attribute constFoldBinaryOp(ArrayRef< Attribute > operands, hw::PEO paramOpcode)
Performs constant folding calculate with element-wise behavior on the two attributes in operands and ...
static bool applyCmpPredicateToEqualOperands(ICmpPredicate predicate)
static ComplementMatcher< SubType > m_Complement(const SubType &subExpr)
static bool canonicalizeLogicalCstWithConcat(Operation *logicalOp, size_t concatIdx, const APInt &cst, PatternRewriter &rewriter)
When we find a logical operation (and, or, xor) with a constant e.g.
static bool narrowOperationWidth(OpTy op, bool narrowTrailingBits, PatternRewriter &rewriter)
static OpFoldResult foldDiv(Op op, ArrayRef< Attribute > constants)
static Value getCommonOperand(Op op)
Returns a single common operand that all inputs of the operation op can be traced back to,...
static bool canCombineOppositeBinCmpIntoConstant(OperandRange operands)
static void getConcatOperands(Value v, SmallVectorImpl< Value > &result)
Flatten concat and replicate operands into a vector.
static Value extractOperandFromFullyAssociative(Operation *fullyAssoc, size_t operandNo, PatternRewriter &rewriter)
Given a fully associative variadic operation like (a+b+c+d), break the expression into two parts,...
static bool getMuxChainCondConstant(Value cond, Value indexValue, bool isInverted, std::function< void(hw::ConstantOp)> constantFn)
Check to see if the condition to the specified mux is an equality comparison indexValue and one or mo...
static TypedAttr getIntAttr(const APInt &value, MLIRContext *context)
static bool shouldBeFlattened(Operation *op)
Return true if the op will be flattened afterwards.
static void canonicalizeXorIcmpTrue(XorOp op, unsigned icmpOperand, PatternRewriter &rewriter)
static bool assumeMuxCondInOperand(Value muxCond, Value muxValue, bool constCond, PatternRewriter &rewriter)
If the mux condition is an operand to the op defining its true or false value, replace the condition ...
static bool extractFromReplicate(ExtractOp op, ReplicateOp replicate, PatternRewriter &rewriter)
static void combineEqualityICmpWithXorOfConstant(ICmpOp cmpOp, XorOp xorOp, const APInt &rhs, PatternRewriter &rewriter)
static size_t getTotalWidth(ArrayRef< Value > operands)
static bool foldCommonMuxOperation(MuxOp mux, Operation *trueOp, Operation *falseOp, PatternRewriter &rewriter)
This function is invoked when we find a mux with true/false operations that have the same opcode.
static std::pair< size_t, size_t > getLowestBitAndHighestBitRequired(Operation *op, bool narrowTrailingBits, size_t originalOpWidth)
static bool tryFlatteningOperands(Operation *op, PatternRewriter &rewriter)
Flattens a single input in op if hasOneUse is true and it can be defined as an Op.
static bool isOpTriviallyRecursive(Operation *op)
static bool canonicalizeIdempotentInputs(Op op, PatternRewriter &rewriter)
Canonicalize an idempotent operation op so that only one input of any kind occurs.
static bool applyCmpPredicate(ICmpPredicate predicate, const APInt &lhs, const APInt &rhs)
static void combineEqualityICmpWithKnownBitsAndConstant(ICmpOp cmpOp, const KnownBits &bitAnalysis, const APInt &rhsCst, PatternRewriter &rewriter)
Given an equality comparison with a constant value and some operand that has known bits,...
static bool hasSVAttributes(Operation *op)
static LogicalResult extractConcatToConcatExtract(ExtractOp op, ConcatOp innerCat, PatternRewriter &rewriter)
static OpFoldResult foldMod(Op op, ArrayRef< Attribute > constants)
static size_t computeCommonPrefixLength(const Range &a, const Range &b)
static bool foldCommonMuxValue(MuxOp op, bool isTrueOperand, PatternRewriter &rewriter)
Fold things like mux(cond, x|y|z|a, a) -> (x|y|z)&replicate(cond)|a and mux(cond, a,...
static LogicalResult matchAndRewriteCompareConcat(ICmpOp op, Operation *lhs, Operation *rhs, PatternRewriter &rewriter)
Reduce the strength icmp(concat(...), concat(...)) by doing a element-wise comparison on common prefi...
static Value createGenericOp(Location loc, OperationName name, ArrayRef< Value > operands, OpBuilder &builder)
Create a new instance of a generic operation that only has value operands, and has a single result va...
static TypedAttr getIntAttr(MLIRContext *ctx, Type t, const APInt &value)
static std::optional< APSInt > getConstant(Attribute operand)
Determine the value of a constant operand for the sake of constant folding.
create(array_value, low_index, ret_type)
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
void extractBits(OpBuilder &builder, Value val, SmallVectorImpl< Value > &bits)
Extract bits from a value.
bool foldMuxChainWithComparison(PatternRewriter &rewriter, MuxOp rootMux, bool isFalseSide, llvm::function_ref< MuxChainWithComparisonFoldingStyle(size_t indexWidth, size_t numEntries)> styleFn)
Mux chain folding that converts chains of muxes with index comparisons into array operations or balan...
Value createOrFoldNot(Location loc, Value value, OpBuilder &builder, bool twoState=false)
Create a "Not" gate on a value.
MuxChainWithComparisonFoldingStyle
Enum for mux chain folding styles.
LogicalResult convertModUByPowerOfTwo(ModUOp modOp, mlir::PatternRewriter &rewriter)
KnownBits computeKnownBits(Value value)
Compute "known bits" information about the specified value - the set of bits that are guaranteed to a...
Value constructMuxTree(OpBuilder &builder, Location loc, ArrayRef< Value > selectors, ArrayRef< Value > leafNodes, Value outOfBoundsValue)
Construct a mux tree for given leaf nodes.
LogicalResult convertDivUByPowerOfTwo(DivUOp divOp, mlir::PatternRewriter &rewriter)
Convert unsigned division or modulo by a power of two.
uint64_t getWidth(Type t)
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
void replaceOpAndCopyNamehint(PatternRewriter &rewriter, Operation *op, Value newValue)
A wrapper of PatternRewriter::replaceOp to propagate "sv.namehint" attribute.