13#include "mlir/IR/Diagnostics.h"
14#include "mlir/IR/Matchers.h"
15#include "mlir/IR/PatternMatch.h"
16#include "llvm/ADT/SetVector.h"
17#include "llvm/ADT/SmallBitVector.h"
18#include "llvm/ADT/TypeSwitch.h"
19#include "llvm/Support/KnownBits.h"
24using namespace matchers;
// Fragment (the enclosing signature is not visible here): evaluates to true
// when at least one operand of `op` is produced by `op` itself.
28 return llvm::any_of(op->getOperands(), [op](
auto operand) {
29 return operand.getDefiningOp() == op;
// Fragment of a helper taking `operands` and an OpBuilder (the start of the
// signature is not visible; presumably `createGenericOp`, which is called
// elsewhere in this file with matching arguments -- confirm).
// Builds an operation named `name` at `loc` from `operands`, gives it one
// result typed like the first operand, and returns that result.
39 ArrayRef<Value> operands, OpBuilder &builder) {
40 OperationState state(loc, name);
41 state.addOperands(operands);
// The result type is read from operands[0], so `operands` must be non-empty.
42 state.addTypes(operands[0].getType());
43 return builder.create(state)->getResult(0);
/// Builds an IntegerAttr for `value`; the IntegerType is created in `context`
/// with the same bit width as the APInt. (The tail of the call is not visible
/// in this excerpt.)
46static TypedAttr
getIntAttr(
const APInt &value, MLIRContext *context) {
47 return IntegerAttr::get(IntegerType::get(context, value.getBitWidth()),
// Fragment: one branch iterates over every operand of `concat`; the other
// branch applies when `v` is defined by a ReplicateOp and loops
// `getMultiple()` times. The loop bodies are not visible in this excerpt.
54 for (
auto op :
concat.getOperands())
56 }
else if (
auto repl = v.getDefiningOp<ReplicateOp>()) {
57 for (
size_t i = 0, e = repl.getMultiple(); i != e; ++i)
// Fragment: reports whether `op` carries an attribute named "sv.attributes".
68 return op->hasAttr(
"sv.attributes");
/// Matcher for a bitwise complement: accepts an operation that is an XorOp
/// for which `isBinaryNot()` holds and whose first operand is accepted by the
/// nested matcher `lhs`.
72template <
typename SubType>
73struct ComplementMatcher {
75 ComplementMatcher(SubType lhs) : lhs(std::move(lhs)) {}
76 bool match(Operation *op) {
77 auto xorOp = dyn_cast<XorOp>(op);
78 return xorOp && xorOp.isBinaryNot() && lhs.match(op->getOperand(0));
/// Convenience factory: wraps the sub-matcher `subExpr` in a
/// ComplementMatcher so it can be used with `matchPattern`.
83template <
typename SubType>
84static inline ComplementMatcher<SubType>
m_Complement(
const SubType &subExpr) {
85 return ComplementMatcher<SubType>(subExpr);
// Fragment (signature not visible): `op` must be an AndOp, OrOp, XorOp, AddOp
// or MulOp (asserted). When `op` has exactly one use, the result is whether
// that sole user has the same operation name, the same "twoState" unit
// attribute, and sits in the same block as `op`.
91 assert((isa<AndOp, OrOp, XorOp, AddOp, MulOp>(op) &&
92 "must be commutative operations"));
93 if (op->hasOneUse()) {
94 auto *user = *op->getUsers().begin();
95 return user->getName() == op->getName() &&
96 op->getAttrOfType<UnitAttr>(
"twoState") ==
97 user->getAttrOfType<UnitAttr>(
"twoState") &&
98 op->getBlock() == user->getBlock();
// Fragment of an operand-flattening rewrite (signature not visible): walks the
// operand tree of `op` with an explicit worklist, collecting operands into
// `newOperands` and the locations of absorbed ops into `newLocations`, then
// rebuilds one op of the same name at a fused location.
113 auto inputs = op->getOperands();
115 SmallVector<Value, 4> newOperands;
116 SmallVector<Location, 4> newLocations{op->getLoc()};
117 newOperands.reserve(inputs.size());
// Presumably the fields of the `Element` worklist entry: an iterator range
// [current, end) over one op's operands.
119 decltype(inputs.begin()) current, end;
122 SmallVector<Element> worklist;
123 worklist.push_back({inputs.begin(), inputs.end()});
124 bool binFlag = op->hasAttrOfType<UnitAttr>(
"twoState");
125 bool changed =
false;
126 while (!worklist.empty()) {
127 auto &element = worklist.back();
130 if (element.current == element.end) {
135 Value value = *element.current++;
136 auto *flattenOp = value.getDefiningOp();
// A value is kept as a plain operand when it has no defining op, is defined
// by a differently named op, is defined by `op` itself, or is defined in
// another block.
// NOTE(review): `binFlag` was computed from `op` above and is compared against
// `op` again here, so that sub-condition can never be true;
// `flattenOp->hasAttrOfType<UnitAttr>("twoState")` was probably intended --
// confirm.
139 if (!flattenOp || flattenOp->getName() != op->getName() ||
140 flattenOp == op || binFlag != op->hasAttrOfType<UnitAttr>(
"twoState") ||
141 flattenOp->getBlock() != op->getBlock()) {
142 newOperands.push_back(value);
147 if (!value.hasOneUse()) {
155 if (flattenOp->getNumOperands() != 2 || !isa<AndOp, OrOp, XorOp>(op) ||
158 newOperands.push_back(value);
// Descend into the operands of the op being absorbed and remember its
// location for the fused location of the rebuilt op.
166 auto flattenOpInputs = flattenOp->getOperands();
167 worklist.push_back({flattenOpInputs.begin(), flattenOpInputs.end()});
168 newLocations.push_back(flattenOp->getLoc());
174 Value result =
createGenericOp(FusedLoc::get(op->getContext(), newLocations),
175 op->getName(), newOperands, rewriter);
// Marks the rebuilt op as two-state (the guarding condition is not visible
// in this excerpt).
177 result.getDefiningOp()->setAttr(
"twoState", rewriter.getUnitAttr());
/// Fragment of `getLowestBitAndHighestBitRequired` (name taken from the
/// assertion message below): scans the users of `op` and returns the pair
/// {lowestBitRequired, highestBitRequired}.
185static std::pair<size_t, size_t>
187 size_t originalOpWidth) {
188 auto users = op->getUsers();
190 "getLowestBitAndHighestBitRequired cannot operate on "
191 "a empty list of uses.");
// The low bound starts at the top bit only when `narrowTrailingBits` is set;
// otherwise it is pinned to bit 0 from the start.
195 size_t lowestBitRequired = narrowTrailingBits ? originalOpWidth - 1 : 0;
196 size_t highestBitRequired = 0;
// An ExtractOp user covers bits [lowBit, lowBit + width - 1]; the running
// range is widened to include them.
198 for (
auto *user : users) {
199 if (
auto extractOp = dyn_cast<ExtractOp>(user)) {
200 size_t lowBit = extractOp.getLowBit();
202 cast<IntegerType>(extractOp.getType()).getWidth() + lowBit - 1;
203 highestBitRequired = std::max(highestBitRequired, highBit);
204 lowestBitRequired = std::min(lowestBitRequired, lowBit);
// Fallback (its condition is not visible here): the full range
// [0, originalOpWidth - 1] is required.
208 highestBitRequired = originalOpWidth - 1;
209 lowestBitRequired = 0;
213 return {lowestBitRequired, highestBitRequired};
// Fragment of a width-narrowing rewrite for `op` (signature not visible):
// `range.first` / `range.second` are the lowest / highest bit to keep. A
// narrower `OpTy` is created over the operands restricted to that range, and
// the result is widened back with zeros through ConcatOp before `op` is
// replaced.
218 PatternRewriter &rewriter) {
219 IntegerType opType = dyn_cast<IntegerType>(op.getResult().getType());
// Case where the kept range already spans every bit of the result.
225 if (range.second + 1 == opType.getWidth() && range.first == 0)
228 SmallVector<Value> args;
229 auto newType = rewriter.getIntegerType(range.second - range.first + 1);
230 for (
auto inop : op.getOperands()) {
// Operands whose type differs from the result type are passed through
// unchanged; the visible alternative extracts to `newType`.
232 if (inop.getType() != op.getType())
233 args.push_back(inop);
235 args.push_back(rewriter.createOrFold<
ExtractOp>(inop.getLoc(), newType,
238 auto newop = OpTy::create(rewriter, op.getLoc(), newType, args);
// Carry the dialect attributes and the two-state flag over to the new op.
239 newop->setDialectAttrs(op->getDialectAttrs());
240 if (op.getTwoState())
241 newop.setTwoState(
true);
243 Value newResult = newop.getResult();
245 newResult = rewriter.createOrFold<
ConcatOp>(
246 op.getLoc(), newResult,
248 APInt::getZero(range.first)));
249 if (range.second + 1 < opType.getWidth())
250 newResult = rewriter.createOrFold<
ConcatOp>(
253 rewriter, op.getLoc(),
254 APInt::getZero(opType.getWidth() - range.second - 1)),
256 rewriter.replaceOp(op, newResult);
/// Folds ReplicateOp. Cases visible in this excerpt: the result width equals
/// the input width; a constant 1-bit input (all-zeros or all-ones of the
/// result width); and a wider constant input, which is concatenated with
/// itself `getMultiple()` times.
264OpFoldResult ReplicateOp::fold(FoldAdaptor adaptor) {
269 if (cast<IntegerType>(getType()).
getWidth() ==
270 getInput().getType().getIntOrFloatBitWidth())
274 if (
auto input = dyn_cast_or_null<IntegerAttr>(adaptor.getInput())) {
275 if (input.getValue().getBitWidth() == 1) {
276 if (input.getValue().isZero())
278 APInt::getZero(cast<IntegerType>(getType()).
getWidth()),
281 APInt::getAllOnes(cast<IntegerType>(getType()).
getWidth()),
// General constant case: start from a zero-width value and append the input
// once per replication.
285 APInt result = APInt::getZeroWidth();
286 for (
auto i = getMultiple(); i != 0; --i)
287 result = result.concat(input.getValue());
/// Folds ParityOp on a constant input to a 1-bit constant holding the low bit
/// of the input's population count.
294OpFoldResult ParityOp::fold(FoldAdaptor adaptor) {
299 if (
auto input = dyn_cast_or_null<IntegerAttr>(adaptor.getInput()))
300 return getIntAttr(APInt(1, input.getValue().popcount() & 1), getContext());
// Fragment of a binary constant-folding helper (start of signature not
// visible): exactly two operand attributes are expected (asserted). The
// visible tail builds an hw::ParamExprAttr with opcode `paramOpcode` from both
// operands; what happens when either is null is not visible here.
312 hw::PEO paramOpcode) {
313 assert(operands.size() == 2 &&
"binary op takes two operands");
314 if (!operands[0] || !operands[1])
319 return hw::ParamExprAttr::get(paramOpcode, cast<TypedAttr>(operands[0]),
320 cast<TypedAttr>(operands[1]));
/// Folds ShlOp with a constant shift amount: a zero shift yields the first
/// operand, and a shift of at least the result width yields a zero constant.
323OpFoldResult ShlOp::fold(FoldAdaptor adaptor) {
327 if (
auto rhs = dyn_cast_or_null<IntegerAttr>(adaptor.getRhs())) {
328 if (rhs.getValue().isZero())
329 return getOperand(0);
331 unsigned width = getType().getIntOrFloatBitWidth();
332 if (rhs.getValue().uge(width))
333 return getIntAttr(APInt::getZero(width), getContext());
/// Canonicalizes ShlOp whose shift amount is a constant. The visible tail
/// replaces the op with a ConcatOp of `extract` followed by `zeros`.
338LogicalResult ShlOp::canonicalize(
ShlOp op, PatternRewriter &rewriter) {
344 if (!matchPattern(op.getRhs(), m_ConstantInt(&value)))
347 unsigned width = cast<IntegerType>(op.getLhs().getType()).getWidth();
348 if (value.ugt(width))
350 unsigned shift = value.getZExtValue();
// Guard for a zero shift or a shift of at least the full width (the action
// taken is not visible in this excerpt).
353 if (width <= shift || shift == 0)
363 replaceOpWithNewOpAndCopyNamehint<ConcatOp>(rewriter, op, extract, zeros);
/// Folds ShrUOp with a constant shift amount: a zero shift yields the first
/// operand, and a shift of at least the result width yields a zero constant.
367OpFoldResult ShrUOp::fold(FoldAdaptor adaptor) {
371 if (
auto rhs = dyn_cast_or_null<IntegerAttr>(adaptor.getRhs())) {
372 if (rhs.getValue().isZero())
373 return getOperand(0);
375 unsigned width = getType().getIntOrFloatBitWidth();
376 if (rhs.getValue().uge(width))
377 return getIntAttr(APInt::getZero(width), getContext());
/// Canonicalizes ShrUOp whose shift amount is a constant. The visible tail
/// replaces the op with a ConcatOp of `zeros` followed by `extract`.
382LogicalResult ShrUOp::canonicalize(
ShrUOp op, PatternRewriter &rewriter) {
388 if (!matchPattern(op.getRhs(), m_ConstantInt(&value)))
391 unsigned width = cast<IntegerType>(op.getLhs().getType()).getWidth();
392 if (value.ugt(width))
394 unsigned shift = value.getZExtValue();
// Guard for a zero shift or a shift of at least the full width (the action
// taken is not visible in this excerpt).
397 if (width <= shift || shift == 0)
407 replaceOpWithNewOpAndCopyNamehint<ConcatOp>(rewriter, op, zeros, extract);
/// Folds ShrSOp: a constant zero shift amount yields the first operand. No
/// other case is visible in this excerpt.
411OpFoldResult ShrSOp::fold(FoldAdaptor adaptor) {
415 if (
auto rhs = dyn_cast_or_null<IntegerAttr>(adaptor.getRhs()))
416 if (rhs.getValue().isZero())
417 return getOperand(0);
/// Canonicalizes ShrSOp whose shift amount is a constant. Bit `width - 1` of
/// the lhs is extracted and replicated `shift` times into `sext`; the visible
/// tail replaces the op with a ConcatOp of `sext` followed by `extract`.
421LogicalResult ShrSOp::canonicalize(
ShrSOp op, PatternRewriter &rewriter) {
427 if (!matchPattern(op.getRhs(), m_ConstantInt(&value)))
430 unsigned width = cast<IntegerType>(op.getLhs().getType()).getWidth();
431 if (value.ugt(width))
433 unsigned shift = value.getZExtValue();
436 rewriter.createOrFold<
ExtractOp>(op.getLoc(), op.getLhs(), width - 1, 1);
437 auto sext = rewriter.createOrFold<ReplicateOp>(op.getLoc(), topbit, shift);
// Shifting by the full width is handled separately (body not visible here).
439 if (width == shift) {
447 replaceOpWithNewOpAndCopyNamehint<ConcatOp>(rewriter, op, sext, extract);
/// Folds ExtractOp. Visible cases: the input type equals the result type; and
/// a constant input, which is logically shifted right by `getLowBit()` and
/// truncated to the result width.
455OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
460 if (getInput().getType() == getType())
464 if (
auto input = dyn_cast_or_null<IntegerAttr>(adaptor.getInput())) {
465 unsigned dstWidth = cast<IntegerType>(getType()).getWidth();
466 return getIntAttr(input.getValue().lshr(getLowBit()).trunc(dstWidth),
// Fragment of a rewrite for an ExtractOp `op` reading from the concat
// `innerCat` (signature not visible). The concat inputs are walked in reverse
// order while bit offsets are accumulated from 0, to find the first input that
// covers `lowBit`. From there inputs are consumed, whole or as narrower
// pieces, until the extracted width is used up.
477 PatternRewriter &rewriter) {
478 auto reversedConcatArgs = llvm::reverse(innerCat.getInputs());
479 size_t beginOfFirstRelevantElement = 0;
480 auto it = reversedConcatArgs.begin();
481 size_t lowBit = op.getLowBit();
484 for (; it != reversedConcatArgs.end(); it++) {
485 assert(beginOfFirstRelevantElement <= lowBit &&
486 "incorrectly moved past an element that lowBit has coverage over");
489 size_t operandWidth = operand.getType().getIntOrFloatBitWidth();
490 if (lowBit < beginOfFirstRelevantElement + operandWidth) {
514 beginOfFirstRelevantElement += operandWidth;
516 assert(it != reversedConcatArgs.end() &&
517 "incorrectly failed to find an element which contains coverage of "
520 SmallVector<Value> reverseConcatArgs;
521 size_t widthRemaining = cast<IntegerType>(op.getType()).getWidth();
// Offset of `lowBit` inside the first relevant input.
522 size_t extractLo = lowBit - beginOfFirstRelevantElement;
527 for (; widthRemaining != 0 && it != reversedConcatArgs.end(); it++) {
528 auto concatArg = *it;
529 size_t operandWidth = concatArg.getType().getIntOrFloatBitWidth();
530 size_t widthToConsume = std::min(widthRemaining, operandWidth - extractLo);
// An input that is used in full is forwarded as-is; otherwise a value of
// `widthToConsume` bits is pushed instead.
532 if (widthToConsume == operandWidth && extractLo == 0) {
533 reverseConcatArgs.push_back(concatArg);
535 auto resultType = IntegerType::get(rewriter.getContext(), widthToConsume);
536 reverseConcatArgs.push_back(
540 widthRemaining -= widthToConsume;
546 if (reverseConcatArgs.size() == 1) {
// The pieces were collected in reverse order, so they are reversed again
// before the ConcatOp is built.
549 replaceOpWithNewOpAndCopyNamehint<ConcatOp>(
550 rewriter, op, SmallVector<Value>(llvm::reverse(reverseConcatArgs)));
// Fragment of a rewrite for an ExtractOp `op` reading from `replicate`
// (signature not visible). When both the low bit and the result width are
// multiples of the replicated element width, the extract becomes a ReplicateOp
// of the element. A second case (the right-hand side of its comparison is not
// visible) becomes an ExtractOp of the element at
// `lowBit % replicateEltWidth`.
557 PatternRewriter &rewriter) {
558 auto extractResultWidth = cast<IntegerType>(op.getType()).getWidth();
559 auto replicateEltWidth =
560 replicate.getOperand().getType().getIntOrFloatBitWidth();
564 if (op.getLowBit() % replicateEltWidth == 0 &&
565 extractResultWidth % replicateEltWidth == 0) {
566 replaceOpWithNewOpAndCopyNamehint<ReplicateOp>(rewriter, op, op.getType(),
567 replicate.getOperand());
573 if (op.getLowBit() % replicateEltWidth + extractResultWidth <=
575 replaceOpWithNewOpAndCopyNamehint<ExtractOp>(
576 rewriter, op, op.getType(), replicate.getOperand(),
577 op.getLowBit() % replicateEltWidth);
/// Canonicalizes ExtractOp. Patterns visible in this excerpt: known-bits
/// constant result; extract of extract; extract of concat; extract of
/// replicate; extract of a two-operand and/or/xor with a constant rhs; and a
/// 1-bit extract of a shl.
586LogicalResult ExtractOp::canonicalize(
ExtractOp op, PatternRewriter &rewriter) {
589 auto *inputOp = op.getInput().getDefiningOp();
596 .extractBits(cast<IntegerType>(op.getType()).getWidth(),
// Every extracted bit is known: replace the op with that constant.
598 if (knownBits.isConstant()) {
599 replaceOpWithNewOpAndCopyNamehint<hw::ConstantOp>(rewriter, op,
600 knownBits.getConstant());
// extract(extract(x)): read from `x` directly, adding the two low bits.
606 if (
auto innerExtract = dyn_cast_or_null<ExtractOp>(inputOp)) {
607 replaceOpWithNewOpAndCopyNamehint<ExtractOp>(
608 rewriter, op, op.getType(), innerExtract.getInput(),
609 innerExtract.getLowBit() + op.getLowBit());
614 if (
auto innerCat = dyn_cast_or_null<ConcatOp>(inputOp))
618 if (
auto replicate = dyn_cast_or_null<ReplicateOp>(inputOp))
// Two-operand and/or/xor with a constant rhs: look at the slice of the
// constant that the extract selects.
624 if (inputOp && inputOp->getNumOperands() == 2 &&
625 isa<AndOp, OrOp, XorOp>(inputOp)) {
626 if (
auto cstRHS = inputOp->getOperand(1).getDefiningOp<
hw::ConstantOp>()) {
627 auto extractedCst = cstRHS.getValue().extractBits(
628 cast<IntegerType>(op.getType()).getWidth(), op.getLowBit());
// or/xor with an all-zero slice: extract from the other operand instead.
629 if (isa<OrOp, XorOp>(inputOp) && extractedCst.isZero()) {
630 replaceOpWithNewOpAndCopyNamehint<ExtractOp>(
631 rewriter, op, op.getType(), inputOp->getOperand(0), op.getLowBit());
// and with a slice whose set bits are one contiguous run (width - lz - tz ==
// popcount): the result is built from `lz` zeros, the `pop` bits extracted at
// `lowBit + tz`, and `tz` zeros, joined by ConcatOp.
639 if (isa<AndOp>(inputOp)) {
642 unsigned lz = extractedCst.countLeadingZeros();
643 unsigned tz = extractedCst.countTrailingZeros();
644 unsigned pop = extractedCst.popcount();
645 if (extractedCst.getBitWidth() - lz - tz == pop) {
646 auto resultTy = rewriter.getIntegerType(pop);
647 SmallVector<Value> resultElts;
650 APInt::getZero(lz)));
651 resultElts.push_back(rewriter.createOrFold<
ExtractOp>(
652 op.getLoc(), resultTy, inputOp->getOperand(0),
653 op.getLowBit() + tz));
656 APInt::getZero(tz)));
657 replaceOpWithNewOpAndCopyNamehint<ConcatOp>(rewriter, op, resultElts);
// 1-bit extract of a single-use shl whose constant lhs is one: becomes an
// `icmp eq` between the shift amount and a constant equal to the low bit.
666 if (cast<IntegerType>(op.getType()).getWidth() == 1 && inputOp)
667 if (
auto shlOp = dyn_cast<ShlOp>(inputOp)) {
669 if (shlOp->hasOneUse())
671 if (lhsCst.getValue().isOne()) {
673 rewriter, shlOp.getLoc(),
674 APInt(lhsCst.getValue().getBitWidth(), op.getLowBit()));
675 replaceOpWithNewOpAndCopyNamehint<ICmpOp>(
676 rewriter, op, ICmpPredicate::eq, shlOp->getOperand(1), newCst,
// Fragment of a constant-folding helper for ops with more than one operand
// (asserted; start of signature not visible). After a null check on the first
// two operands, and if every remaining operand attribute is non-null, the
// TypedAttr operands are collected. An hw::ParamExprAttr with opcode
// `paramOpcode` is returned only when every operand was a TypedAttr.
692 hw::PEO paramOpcode) {
693 assert(operands.size() > 1 &&
"caller should handle one-operand case");
696 if (!operands[1] || !operands[0])
700 if (llvm::all_of(operands.drop_front(2),
701 [&](Attribute in) { return !!in; })) {
702 SmallVector<mlir::TypedAttr> typedOperands;
703 typedOperands.reserve(operands.size());
704 for (
auto operand : operands) {
705 if (
auto typedOperand = dyn_cast<mlir::TypedAttr>(operand))
706 typedOperands.push_back(typedOperand);
710 if (typedOperands.size() == operands.size())
711 return hw::ParamExprAttr::get(paramOpcode, typedOperands);
// Fragment of a rewrite for a logical op (AndOp/OrOp/XorOp, asserted) whose
// operand number `concatIdx` is a ConcatOp and which is combined with the
// constant `cst` (start of signature not visible). The constant is sliced per
// concat operand, the logical op is applied to each operand with its slice,
// and the results are concatenated again.
727 size_t concatIdx,
const APInt &cst,
728 PatternRewriter &rewriter) {
729 auto concatOp = logicalOp->getOperand(concatIdx).getDefiningOp<
ConcatOp>();
730 assert((isa<AndOp, OrOp, XorOp>(logicalOp) && concatOp));
// Looks for a concat operand that is a constant, or a single-use op of the
// same kind as `logicalOp` whose last operand is a constant.
735 llvm::any_of(concatOp->getOperands(), [&](Value operand) ->
bool {
736 auto *operandOp = operand.getDefiningOp();
741 if (isa<hw::ConstantOp>(operandOp))
745 return operandOp->getName() == logicalOp->getName() &&
746 operandOp->hasOneUse() && operandOp->getNumOperands() != 0 &&
747 operandOp->getOperands().back().getDefiningOp<hw::ConstantOp>();
755 auto createLogicalOp = [&](ArrayRef<Value> operands) -> Value {
756 return createGenericOp(logicalOp->getLoc(), logicalOp->getName(), operands,
763 SmallVector<Value> newConcatOperands;
764 newConcatOperands.reserve(concatOp->getNumOperands());
// `nextOperandBit` counts down from the concat width, so each operand is
// paired with the slice of `cst` at its own bit position.
767 size_t nextOperandBit = concatOp.getType().getIntOrFloatBitWidth();
768 for (Value operand : concatOp->getOperands()) {
769 size_t operandWidth = operand.getType().getIntOrFloatBitWidth();
770 nextOperandBit -= operandWidth;
774 cst.lshr(nextOperandBit).trunc(operandWidth));
776 newConcatOperands.push_back(createLogicalOp({operand, eltCst}));
781 ConcatOp::create(rewriter, concatOp.getLoc(), newConcatOperands);
// With more than two operands, the operands other than the one at `concatIdx`
// and the last one are combined with the new result in a fresh logical op.
785 if (logicalOp->getNumOperands() > 2) {
786 auto origOperands = logicalOp->getOperands();
787 SmallVector<Value> operands;
789 operands.append(origOperands.begin(), origOperands.begin() + concatIdx);
791 operands.append(origOperands.begin() + concatIdx + 1,
792 origOperands.begin() + (origOperands.size() - 1));
794 operands.push_back(newResult);
795 newResult = createLogicalOp(operands);
// Fragment (signature not visible): records the (predicate, lhs, rhs) triple
// of every two-state ICmpOp that defines one of `operands`, and checks
// whether the negated predicate over the same lhs/rhs has already been seen.
// The action taken on a hit is not visible in this excerpt.
805 llvm::SmallDenseSet<std::tuple<ICmpPredicate, Value, Value>> seenPredicates;
807 for (
auto op : operands) {
808 if (
auto icmpOp = op.getDefiningOp<ICmpOp>();
809 icmpOp && icmpOp.getTwoState()) {
810 auto predicate = icmpOp.getPredicate();
811 auto lhs = icmpOp.getLhs();
812 auto rhs = icmpOp.getRhs();
813 if (seenPredicates.contains(
814 {ICmpOp::getNegatedPredicate(predicate), lhs, rhs}))
817 seenPredicates.insert({predicate, lhs, rhs});
/// Folds AndOp. Visible cases: constant operands are AND-ed into `value`,
/// which starts as all ones; `and(x, all-ones)` yields `x`; all inputs being
/// the same value yields that value; and a nested scan over the inputs that
/// produces a zero constant (its matching condition is not visible here).
823OpFoldResult AndOp::fold(FoldAdaptor adaptor) {
827 APInt value = APInt::getAllOnes(cast<IntegerType>(getType()).
getWidth());
829 auto inputs = adaptor.getInputs();
832 for (
auto operand : inputs) {
835 value &= cast<IntegerAttr>(operand).getValue();
841 if (inputs.size() == 2 && inputs[1] &&
842 cast<IntegerAttr>(inputs[1]).getValue().isAllOnes())
843 return getInputs()[0];
846 if (llvm::all_of(getInputs(),
847 [&](
auto in) {
return in == this->getInputs()[0]; }))
848 return getInputs()[0];
851 for (Value arg : getInputs()) {
854 for (Value arg2 : getInputs())
857 APInt::getZero(cast<IntegerType>(getType()).
getWidth()),
// Fragment of a template helper over `Op` (signature not visible). It only
// proceeds for 1-bit result types. It checks that every input is an ExtractOp
// of the same `source`, that the number of inputs equals the width of
// `source`, and records each extract's low bit. It returns `source` when all
// bits are covered, and a null Value otherwise.
878template <
typename Op>
880 if (!op.getType().isInteger(1))
883 auto inputs = op.getInputs();
884 size_t size = inputs.size();
886 auto sourceOp = inputs[0].template getDefiningOp<ExtractOp>();
889 Value source = sourceOp.getOperand();
892 if (size != source.getType().getIntOrFloatBitWidth())
// One flag per bit of `source`, set by each extract's low bit.
896 llvm::BitVector bits(size);
897 bits.set(sourceOp.getLowBit());
899 for (
size_t i = 1; i != size; ++i) {
900 auto extractOp = inputs[i].template getDefiningOp<ExtractOp>();
901 if (!extractOp || extractOp.getOperand() != source)
903 bits.set(extractOp.getLowBit());
906 return bits.all() ? source : Value();
// Fragment of a template rewrite over `Op` (signature not visible). It
// de-duplicates the inputs of `op`, then walks nested ops of the same kind
// (same block, same two-state flag, enqueued only while depth < `limit`) and
// removes from the unique set every input that a nested op already consumes.
// `op` is rebuilt only if that leaves fewer inputs than it started with.
913template <
typename Op>
916 constexpr unsigned limit = 3;
917 auto inputs = op.getInputs();
919 llvm::SmallSetVector<Value, 8> uniqueInputs(inputs.begin(), inputs.end());
920 llvm::SmallDenseSet<Op, 8> checked;
927 llvm::SmallVector<OpWithDepth, 8> worklist;
// Enqueues the op defining `input` once, provided it matches the constraints
// above; `checked` prevents repeated visits.
929 auto enqueue = [&worklist, &checked, &op](Value input,
unsigned depth) {
933 if (depth < limit && input.getParentBlock() == op->getBlock()) {
934 auto inputOp = input.template getDefiningOp<Op>();
935 if (inputOp && inputOp.getTwoState() == op.getTwoState() &&
936 checked.insert(inputOp).second)
937 worklist.push_back({inputOp, depth + 1});
941 for (
auto input : uniqueInputs)
944 while (!worklist.empty()) {
945 auto item = worklist.pop_back_val();
947 for (
auto input : item.op.getInputs()) {
948 uniqueInputs.remove(input);
949 enqueue(input, item.depth);
953 if (uniqueInputs.size() < inputs.size()) {
954 replaceOpWithNewOpAndCopyNamehint<Op>(rewriter, op, op.getType(),
955 uniqueInputs.getArrayRef(),
/// Canonicalizes AndOp. Patterns visible in this excerpt, all keyed on a
/// constant last operand unless noted: dropping an all-ones constant; merging
/// two trailing constants; and(replicate(1-bit), power-of-two mask); narrowing
/// around a mask with leading/trailing zeros; a scan for ConcatOp operands;
/// and a replacement by `icmp eq`.
963LogicalResult AndOp::canonicalize(
AndOp op, PatternRewriter &rewriter) {
967 auto inputs = op.getInputs();
968 auto size = inputs.size();
980 assert(size > 1 &&
"expected 2 or more operands, `fold` should handle this");
984 if (matchPattern(inputs.back(), m_ConstantInt(&value))) {
// and(..., all-ones): rebuild without the constant.
986 if (value.isAllOnes()) {
987 replaceOpWithNewOpAndCopyNamehint<AndOp>(rewriter, op, op.getType(),
988 inputs.drop_back(),
false);
// and(..., c1, c2): replace both constants with the single value `cst`.
996 if (matchPattern(inputs[size - 2], m_ConstantInt(&value2))) {
998 SmallVector<Value, 4> newOperands(inputs.drop_back(2));
999 newOperands.push_back(cst);
1000 replaceOpWithNewOpAndCopyNamehint<AndOp>(rewriter, op, op.getType(),
1001 newOperands,
false);
// and(replicate(1-bit x), single-bit mask): rebuild as a concat of high
// zeros, `x`, and low zeros, leaving out whichever zero piece would be empty.
1006 if (size == 2 && value.isPowerOf2()) {
1011 if (
auto replicate = inputs[0].getDefiningOp<ReplicateOp>()) {
1012 auto replicateOperand = replicate.getOperand();
1013 if (replicateOperand.getType().isInteger(1)) {
1014 unsigned resultWidth = op.getType().getIntOrFloatBitWidth();
1015 auto trailingZeros = value.countTrailingZeros();
1018 SmallVector<Value, 3> concatOperands;
1019 if (trailingZeros != resultWidth - 1) {
1021 rewriter, op.getLoc(),
1022 APInt::getZero(resultWidth - trailingZeros - 1));
1023 concatOperands.push_back(highZeros);
1025 concatOperands.push_back(replicateOperand);
1026 if (trailingZeros != 0) {
1028 rewriter, op.getLoc(), APInt::getZero(trailingZeros));
1029 concatOperands.push_back(lowZeros);
1031 replaceOpWithNewOpAndCopyNamehint<ConcatOp>(
1032 rewriter, op, op.getType(), concatOperands);
// Mask with leading and/or trailing zeros: only the middle `maskLength` bits
// are AND-ed (looking through ExtractOps on the inputs and accumulating their
// low bits into `offset`), and zeros are concatenated on either side.
1041 unsigned leadingZeros = value.countLeadingZeros();
1042 unsigned trailingZeros = value.countTrailingZeros();
1043 if (leadingZeros > 0 || trailingZeros > 0) {
1044 unsigned maskLength = value.getBitWidth() - leadingZeros - trailingZeros;
1047 SmallVector<Value> operands;
1048 for (
auto input : inputs.drop_back()) {
1049 unsigned offset = trailingZeros;
1050 while (
auto extractOp = input.getDefiningOp<
ExtractOp>()) {
1051 input = extractOp.getInput();
1052 offset += extractOp.getLowBit();
1055 offset, maskLength));
// The narrowed mask is only kept as an operand when it is not all ones.
1059 auto narrowMask = value.extractBits(maskLength, trailingZeros);
1060 if (!narrowMask.isAllOnes())
1062 rewriter, inputs.back().getLoc(), narrowMask));
1065 Value narrowValue = operands.back();
1066 if (operands.size() > 1)
1068 AndOp::create(rewriter, op.getLoc(), operands, op.getTwoState());
1072 if (leadingZeros > 0)
1074 rewriter, op.getLoc(), APInt::getZero(leadingZeros)));
1075 operands.push_back(narrowValue);
1076 if (trailingZeros > 0)
1078 rewriter, op.getLoc(), APInt::getZero(trailingZeros)));
1079 replaceOpWithNewOpAndCopyNamehint<ConcatOp>(rewriter, op, operands);
// Scans every operand except the last for a ConcatOp (the action taken is
// not visible in this excerpt).
1086 for (
size_t i = 0; i < size - 1; ++i) {
1087 if (
auto concat = inputs[i].getDefiningOp<ConcatOp>())
// Replaces the op with `icmp eq source, cmpAgainst` (the setup of both values
// is not visible in this excerpt).
1101 replaceOpWithNewOpAndCopyNamehint<ICmpOp>(rewriter, op, ICmpPredicate::eq,
1102 source, cmpAgainst);
/// Folds OrOp. Visible cases: constant operands are OR-ed into `value`, which
/// starts at zero, with a check for reaching all ones; `or(x, 0)` yields `x`;
/// all inputs being the same value yields that value; and an input that is
/// the complement of another input yields all ones.
1110OpFoldResult OrOp::fold(FoldAdaptor adaptor) {
1114 auto value = APInt::getZero(cast<IntegerType>(getType()).
getWidth());
1115 auto inputs = adaptor.getInputs();
1117 for (
auto operand : inputs) {
1120 value |= cast<IntegerAttr>(operand).getValue();
1121 if (value.isAllOnes())
1126 if (inputs.size() == 2 && inputs[1] &&
1127 cast<IntegerAttr>(inputs[1]).getValue().isZero())
1128 return getInputs()[0];
1131 if (llvm::all_of(getInputs(),
1132 [&](
auto in) {
return in == this->getInputs()[0]; }))
1133 return getInputs()[0];
// Quadratic scan: for each input of the form ~subExpr, look for `subExpr`
// among the inputs.
1136 for (Value arg : getInputs()) {
1138 if (matchPattern(arg,
m_Complement(m_Any(&subExpr)))) {
1139 for (Value arg2 : getInputs())
1140 if (arg2 == subExpr)
1142 APInt::getAllOnes(cast<IntegerType>(getType()).
getWidth()),
1152 APInt::getAllOnes(cast<IntegerType>(getType()).
getWidth()),
/// Canonicalizes OrOp. Patterns visible in this excerpt: dropping a trailing
/// zero constant; merging two trailing constants; a scan for ConcatOp
/// operands; a replacement by `icmp ne`; and merging an OR of two-state muxes
/// that share their true/false values into a single mux.
1159LogicalResult OrOp::canonicalize(
OrOp op, PatternRewriter &rewriter) {
1163 auto inputs = op.getInputs();
1164 auto size = inputs.size();
1176 assert(size > 1 &&
"expected 2 or more operands");
1180 if (matchPattern(inputs.back(), m_ConstantInt(&value))) {
// or(..., 0): rebuild without the constant.
1182 if (value.isZero()) {
1183 replaceOpWithNewOpAndCopyNamehint<OrOp>(rewriter, op, op.getType(),
1184 inputs.drop_back());
// or(..., c1, c2): replace both constants with the single value `cst`.
1190 if (matchPattern(inputs[size - 2], m_ConstantInt(&value2))) {
1192 SmallVector<Value, 4> newOperands(inputs.drop_back(2));
1193 newOperands.push_back(cst);
1194 replaceOpWithNewOpAndCopyNamehint<OrOp>(rewriter, op, op.getType(),
1202 for (
size_t i = 0; i < size - 1; ++i) {
1203 if (
auto concat = inputs[i].getDefiningOp<ConcatOp>())
// Replaces the op with `icmp ne source, cmpAgainst` (the setup of both values
// is not visible in this excerpt).
1217 replaceOpWithNewOpAndCopyNamehint<ICmpOp>(rewriter, op, ICmpPredicate::ne,
1218 source, cmpAgainst);
// The first operand is a mux: when the or and the mux are two-state and the
// mux's false value is a constant, check that every other operand is a
// two-state mux with the same true and false values. If so, OR their
// conditions together and select once.
1224 if (
auto firstMux = op.getOperand(0).getDefiningOp<
comb::MuxOp>()) {
1226 if (op.getTwoState() && firstMux.getTwoState() &&
1227 matchPattern(firstMux.getFalseValue(), m_ConstantInt(&value)) &&
1229 SmallVector<Value> conditions{firstMux.getCond()};
// Side effect: `check` appends each mux condition to `conditions`.
1230 auto check = [&](Value v) {
1234 conditions.push_back(mux.getCond());
1235 return mux.getTwoState() &&
1236 firstMux.getTrueValue() == mux.getTrueValue() &&
1237 firstMux.getFalseValue() == mux.getFalseValue();
1239 if (llvm::all_of(op.getOperands().drop_front(), check)) {
1240 auto cond = comb::OrOp::create(rewriter, op.getLoc(), conditions,
true);
1241 replaceOpWithNewOpAndCopyNamehint<comb::MuxOp>(
1242 rewriter, op, cond, firstMux.getTrueValue(),
1243 firstMux.getFalseValue(),
true);
/// Folds XorOp. Visible cases: a path returning the first input (its
/// condition is not visible here); `xor(x, x)` yields zero; `xor(x, 0)`
/// yields `x`; and a binary-not whose operand is itself a complement, with a
/// guard that the inner expression is not this op's own result.
1253OpFoldResult XorOp::fold(FoldAdaptor adaptor) {
1257 auto size = getInputs().size();
1258 auto inputs = adaptor.getInputs();
1262 return getInputs()[0];
1265 if (size == 2 && getInputs()[0] == getInputs()[1])
1266 return IntegerAttr::get(getType(), 0);
1269 if (inputs.size() == 2 && inputs[1] &&
1270 cast<IntegerAttr>(inputs[1]).getValue().isZero())
1271 return getInputs()[0];
1275 if (isBinaryNot()) {
1277 if (matchPattern(getOperand(0),
m_Complement(m_Any(&subExpr))) &&
1278 subExpr != getResult())
// Fragment of a rewrite on an xor `op` whose operand number `icmpOperand` is
// an ICmpOp (signature not visible). It creates a new ICmpOp with the negated
// predicate on the same operands and two-state flag. When `op` has more than
// two operands, a new XorOp is built from the remaining operands: the last
// one and the icmp are removed and the negated compare is appended.
1288 PatternRewriter &rewriter) {
1289 auto icmp = op.getOperand(icmpOperand).getDefiningOp<ICmpOp>();
1290 auto negatedPred = ICmpOp::getNegatedPredicate(icmp.getPredicate());
1293 ICmpOp::create(rewriter, icmp.getLoc(), negatedPred, icmp.getOperand(0),
1294 icmp.getOperand(1), icmp.getTwoState());
1297 if (op.getNumOperands() > 2) {
1298 SmallVector<Value, 4> newOperands(op.getOperands());
// The last operand is removed first, so the erase index still refers to the
// icmp operand.
1299 newOperands.pop_back();
1300 newOperands.erase(newOperands.begin() + icmpOperand);
1301 newOperands.push_back(result);
1303 XorOp::create(rewriter, op.getLoc(), newOperands, op.getTwoState());
/// Canonicalizes XorOp. Patterns visible in this excerpt: two equal trailing
/// operands are dropped; a trailing zero constant is dropped; two trailing
/// constants are merged; a single-bit constant with a single-use ICmpOp
/// operand; and a replacement by ParityOp.
1309LogicalResult XorOp::canonicalize(
XorOp op, PatternRewriter &rewriter) {
1313 auto inputs = op.getInputs();
1314 auto size = inputs.size();
1315 assert(size > 1 &&
"expected 2 or more operands");
// xor(..., x, x): rebuild without the last two operands.
1318 if (inputs[size - 1] == inputs[size - 2]) {
1320 "expected idempotent case for 2 elements handled already.");
1321 replaceOpWithNewOpAndCopyNamehint<XorOp>(rewriter, op, op.getType(),
1322 inputs.drop_back(2),
false);
1328 if (matchPattern(inputs.back(), m_ConstantInt(&value))) {
// xor(..., 0): rebuild without the constant.
1330 if (value.isZero()) {
1331 replaceOpWithNewOpAndCopyNamehint<XorOp>(rewriter, op, op.getType(),
1332 inputs.drop_back(),
false);
// xor(..., c1, c2): replace both constants with the single value `cst`.
1338 if (matchPattern(inputs[size - 2], m_ConstantInt(&value2))) {
1340 SmallVector<Value, 4> newOperands(inputs.drop_back(2));
1341 newOperands.push_back(cst);
1342 replaceOpWithNewOpAndCopyNamehint<XorOp>(rewriter, op, op.getType(),
1343 newOperands,
false);
1347 bool isSingleBit = value.getBitWidth() == 1;
// Scan the non-constant operands; for a 1-bit constant (asserted to be one
// here), single-use ICmpOp operands are looked for.
1350 for (
size_t i = 0; i < size - 1; ++i) {
1351 Value operand = inputs[i];
1362 if (isSingleBit && operand.hasOneUse()) {
1363 assert(value == 1 &&
"single bit constant has to be one if not zero");
1364 if (
auto icmp = operand.getDefiningOp<ICmpOp>())
// Replaces the op with a ParityOp of `source` (the setup of `source` is not
// visible in this excerpt).
1380 replaceOpWithNewOpAndCopyNamehint<ParityOp>(rewriter, op, source);
/// Folds SubOp. Visible cases: `sub(x, x)` yields zero of the lhs width; with
/// both operands constant, the result is expressed as the parameter
/// expression `lhs + rhs * negOne`, where `negOne` appears to be built from an
/// all-ones APInt (its full construction is not visible); and a constant-zero
/// rhs check.
1387OpFoldResult SubOp::fold(FoldAdaptor adaptor) {
1392 if (getRhs() == getLhs())
1394 APInt::getZero(getLhs().getType().getIntOrFloatBitWidth()),
1397 if (adaptor.getRhs()) {
1399 if (adaptor.getLhs()) {
1402 APInt::getAllOnes(getLhs().getType().getIntOrFloatBitWidth()),
1404 auto rhsNeg = hw::ParamExprAttr::get(
1405 hw::PEO::Mul, cast<TypedAttr>(adaptor.getRhs()), negOne);
1406 return hw::ParamExprAttr::get(hw::PEO::Add,
1407 cast<TypedAttr>(adaptor.getLhs()), rhsNeg);
1411 if (
auto rhsC = dyn_cast<IntegerAttr>(adaptor.getRhs())) {
1412 if (rhsC.getValue().isZero())
/// Canonicalizes SubOp: with a constant rhs, the op is replaced by an AddOp of
/// the lhs and `negCst` (the construction of `negCst` is not visible here).
1420LogicalResult SubOp::canonicalize(
SubOp op, PatternRewriter &rewriter) {
1426 if (matchPattern(op.getRhs(), m_ConstantInt(&value))) {
1428 replaceOpWithNewOpAndCopyNamehint<AddOp>(rewriter, op, op.getLhs(), negCst,
/// Folds AddOp. The only visible path returns the first input; its condition
/// is not visible in this excerpt.
1440OpFoldResult AddOp::fold(FoldAdaptor adaptor) {
1444 auto size = getInputs().size();
1448 return getInputs()[0];
/// Canonicalizes AddOp. Patterns visible in this excerpt: dropping a trailing
/// zero; merging two trailing constants; add(..., x, x) via a shift by `one`;
/// add(..., x, shl(x, c)) and add(..., x, mul(x, c)) via a multiply; and
/// add(add(x, c1), c2) as a single add.
1454LogicalResult AddOp::canonicalize(
AddOp op, PatternRewriter &rewriter) {
1458 auto inputs = op.getInputs();
1459 auto size = inputs.size();
1460 assert(size > 1 &&
"expected 2 or more operands");
1462 APInt value, value2;
// add(..., 0): rebuild without the constant.
1465 if (matchPattern(inputs.back(), m_ConstantInt(&value)) && value.isZero()) {
1466 replaceOpWithNewOpAndCopyNamehint<AddOp>(rewriter, op, op.getType(),
1467 inputs.drop_back(),
false);
// add(..., c1, c2): replace both constants with the single value `cst`.
1472 if (matchPattern(inputs[size - 1], m_ConstantInt(&value)) &&
1473 matchPattern(inputs[size - 2], m_ConstantInt(&value2))) {
1475 SmallVector<Value, 4> newOperands(inputs.drop_back(2));
1476 newOperands.push_back(cst);
1477 replaceOpWithNewOpAndCopyNamehint<AddOp>(rewriter, op, op.getType(),
1478 newOperands,
false);
// add(..., x, x): the two copies are replaced with shl(x, one).
1483 if (inputs[size - 1] == inputs[size - 2]) {
1484 SmallVector<Value, 4> newOperands(inputs.drop_back(2));
1488 comb::ShlOp::create(rewriter, op.getLoc(), inputs.back(), one,
false);
1490 newOperands.push_back(shiftLeftOp);
1491 replaceOpWithNewOpAndCopyNamehint<AddOp>(rewriter, op, op.getType(),
1492 newOperands,
false);
// add(..., x, shl(x, c)): both are replaced with mul(x, rhs); the computation
// of `rhs` is not visible here.
1496 auto shlOp = inputs[size - 1].getDefiningOp<
comb::ShlOp>();
1498 if (shlOp && shlOp.getLhs() == inputs[size - 2] &&
1499 matchPattern(shlOp.getRhs(), m_ConstantInt(&value))) {
1501 APInt one(value.getBitWidth(), 1,
false);
1505 std::array<Value, 2> factors = {shlOp.getLhs(), rhs};
1506 auto mulOp = comb::MulOp::create(rewriter, op.getLoc(), factors,
false);
1508 SmallVector<Value, 4> newOperands(inputs.drop_back(2));
1509 newOperands.push_back(mulOp);
1510 replaceOpWithNewOpAndCopyNamehint<AddOp>(rewriter, op, op.getType(),
1511 newOperands,
false);
// add(..., x, mul(x, c)): both are replaced with mul(x, rhs); the computation
// of `rhs` is not visible here.
1515 auto mulOp = inputs[size - 1].getDefiningOp<
comb::MulOp>();
1517 if (mulOp && mulOp.getInputs().size() == 2 &&
1518 mulOp.getInputs()[0] == inputs[size - 2] &&
1519 matchPattern(mulOp.getInputs()[1], m_ConstantInt(&value))) {
1521 APInt one(value.getBitWidth(), 1,
false);
1523 std::array<Value, 2> factors = {mulOp.getInputs()[0], rhs};
1524 auto newMulOp = comb::MulOp::create(rewriter, op.getLoc(), factors,
false);
1526 SmallVector<Value, 4> newOperands(inputs.drop_back(2));
1527 newOperands.push_back(newMulOp);
1528 replaceOpWithNewOpAndCopyNamehint<AddOp>(rewriter, op, op.getType(),
1529 newOperands,
false);
// add(add(x, c1), c2): becomes add(x, rhs); the new op is two-state only when
// both original adds were.
1542 auto addOp = inputs[0].getDefiningOp<
comb::AddOp>();
1543 if (addOp && addOp.getInputs().size() == 2 &&
1544 matchPattern(addOp.getInputs()[1], m_ConstantInt(&value2)) &&
1545 inputs.size() == 2 && matchPattern(inputs[1], m_ConstantInt(&value))) {
1548 replaceOpWithNewOpAndCopyNamehint<AddOp>(
1549 rewriter, op, op.getType(), ArrayRef<Value>{addOp.getInputs()[0], rhs},
1550 op.getTwoState() && addOp.getTwoState());
/// Folds MulOp. Visible paths: one returning the first input and one
/// returning a zero-width zero constant (neither condition is visible here);
/// and a product of the constant operands accumulated in `value`, which
/// starts at 1.
1557OpFoldResult MulOp::fold(FoldAdaptor adaptor) {
1561 auto size = getInputs().size();
1562 auto inputs = adaptor.getInputs();
1566 return getInputs()[0];
1568 auto width = cast<IntegerType>(getType()).getWidth();
1570 return getIntAttr(APInt::getZero(0), getContext());
1572 APInt value(width, 1,
false);
1575 for (
auto operand : inputs) {
1578 value *= cast<IntegerAttr>(operand).getValue();
/// Canonicalizes MulOp. Patterns visible in this excerpt: mul(x, 2^n) becomes
/// a shift by n; a trailing constant one is dropped; and two trailing
/// constants are merged.
1587LogicalResult MulOp::canonicalize(
MulOp op, PatternRewriter &rewriter) {
1591 auto inputs = op.getInputs();
1592 auto size = inputs.size();
1593 assert(size > 1 &&
"expected 2 or more operands");
1595 APInt value, value2;
// mul(x, 2^n): a ShlOp by log2 of the constant is created, then wrapped in a
// single-operand MulOp that replaces the original.
1598 if (size == 2 && matchPattern(inputs.back(), m_ConstantInt(&value)) &&
1599 value.isPowerOf2()) {
1601 value.exactLogBase2());
1603 comb::ShlOp::create(rewriter, op.getLoc(), inputs[0], shift,
false);
1605 replaceOpWithNewOpAndCopyNamehint<MulOp>(rewriter, op, op.getType(),
1606 ArrayRef<Value>(shlOp),
false);
// mul(..., 1): rebuild without the constant.
1611 if (matchPattern(inputs.back(), m_ConstantInt(&value)) && value.isOne()) {
1612 replaceOpWithNewOpAndCopyNamehint<MulOp>(rewriter, op, op.getType(),
1613 inputs.drop_back());
// mul(..., c1, c2): replace both constants with the single value `cst`.
1618 if (matchPattern(inputs[size - 1], m_ConstantInt(&value)) &&
1619 matchPattern(inputs[size - 2], m_ConstantInt(&value2))) {
1621 SmallVector<Value, 4> newOperands(inputs.drop_back(2));
1622 newOperands.push_back(cst);
1623 replaceOpWithNewOpAndCopyNamehint<MulOp>(rewriter, op, op.getType(),
/// Shared folding helper for division ops, parameterized on the op type and
/// signedness. With a constant divisor (`constants[1]`) it checks for a
/// divisor of 1 and a divisor of 0; the result of each case is not visible in
/// this excerpt.
1639template <
class Op,
bool isSigned>
1640static OpFoldResult
foldDiv(Op op, ArrayRef<Attribute> constants) {
1641 if (
auto rhsValue = dyn_cast_or_null<IntegerAttr>(constants[1])) {
1643 if (rhsValue.getValue() == 1)
1647 if (rhsValue.getValue().isZero())
/// Folds DivUOp by delegating to `foldDiv` with `isSigned` = false.
1654OpFoldResult DivUOp::fold(FoldAdaptor adaptor) {
1657 return foldDiv<
DivUOp,
false>(*
this, adaptor.getOperands());
/// Fold hook for DivSOp; its body is not visible in this excerpt.
1660OpFoldResult DivSOp::fold(FoldAdaptor adaptor) {
/// Shared folding helper for modulo ops, parameterized on the op type and
/// signedness. A constant divisor of 1 yields a zero of the result width; a
/// constant divisor of 0 is checked (its result is not visible here); a
/// constant zero dividend yields a zero of the result width.
1666template <
class Op,
bool isSigned>
1667static OpFoldResult
foldMod(Op op, ArrayRef<Attribute> constants) {
1668 if (
auto rhsValue = dyn_cast_or_null<IntegerAttr>(constants[1])) {
1670 if (rhsValue.getValue() == 1)
1671 return getIntAttr(APInt::getZero(op.getType().getIntOrFloatBitWidth()),
1675 if (rhsValue.getValue().isZero())
1679 if (
auto lhsValue = dyn_cast_or_null<IntegerAttr>(constants[0])) {
1681 if (lhsValue.getValue().isZero())
1682 return getIntAttr(APInt::getZero(op.getType().getIntOrFloatBitWidth()),
/// Folds ModUOp by delegating to `foldMod` with `isSigned` = false.
1689OpFoldResult ModUOp::fold(FoldAdaptor adaptor) {
1692 return foldMod<
ModUOp,
false>(*
this, adaptor.getOperands());
/// Fold hook for ModSOp; its body is not visible in this excerpt.
1695OpFoldResult ModSOp::fold(FoldAdaptor adaptor) {
/// Folds ConcatOp. A single-operand concat yields its operand. Every input is
/// checked for being an IntegerAttr (the action on a miss is not visible
/// here); the constant chunks are then assembled into one APInt of the result
/// width.
1705OpFoldResult ConcatOp::fold(FoldAdaptor adaptor) {
1709 if (getNumOperands() == 1)
1710 return getOperand(0);
1713 for (
auto attr : adaptor.getInputs())
1714 if (!attr || !isa<IntegerAttr>(attr))
1718 unsigned resultWidth = getType().getIntOrFloatBitWidth();
1719 APInt result(resultWidth, 0);
// The insertion point counts down from the top, so the first input lands in
// the most significant bits.
1721 unsigned nextInsertion = resultWidth;
1723 for (
auto attr : adaptor.getInputs()) {
1724 auto chunk = cast<IntegerAttr>(attr).getValue();
1725 nextInsertion -= chunk.getBitWidth();
1726 result.insertBits(chunk, nextInsertion);
// Canonicalization for ConcatOp. Walks the inputs looking for an operand, or
// a pair of adjacent operands, that can be replaced by something simpler and
// rewrites the concat with the replacement spliced in. The visible patterns
// are: nested concats, adjacent constants, adjacent identical values,
// replicate merging, adjacent extracts of the same input, and adjacent array
// element/slice reads. If every operand is the same value the whole concat
// becomes a ReplicateOp.
// NOTE(review): several original lines are missing from this excerpt
// (including some declarations and returns); confirm against the full file.
1732LogicalResult ConcatOp::canonicalize(
ConcatOp op, PatternRewriter &rewriter) {
1736 auto inputs = op.getInputs();
1737 auto size = inputs.size();
1738 assert(size > 1 &&
"expected 2 or more operands");
// Helper: rebuild the operand list with inputs [firstOpIndex, lastOpIndex]
// (inclusive) replaced by `replacements`, then replace `op` with a concat of
// the new list.
1743 auto flattenConcat = [&](
size_t firstOpIndex,
size_t lastOpIndex,
1744 ValueRange replacements) -> LogicalResult {
1745 SmallVector<Value, 4> newOperands;
1746 newOperands.append(inputs.begin(), inputs.begin() + firstOpIndex);
1747 newOperands.append(replacements.begin(), replacements.end());
1748 newOperands.append(inputs.begin() + lastOpIndex + 1, inputs.end());
// A single remaining operand is special-cased (handling not visible here).
1749 if (newOperands.size() == 1)
1752 replaceOpWithNewOpAndCopyNamehint<ConcatOp>(rewriter, op, op.getType(),
// Tracks whether all inputs are the same value; cleared to a null Value as
// soon as a differing input is seen.
1757 Value commonOperand = inputs[0];
1758 for (
size_t i = 0; i != size; ++i) {
1760 if (inputs[i] != commonOperand)
1761 commonOperand = Value();
// An operand that is itself a concat is flattened into this one.
1765 if (
auto subConcat = inputs[i].getDefiningOp<ConcatOp>())
1766 return flattenConcat(i, i, subConcat->getOperands());
// Two adjacent constants are merged into one wider constant: the current
// constant is zero-extended to the combined width and OR-ed with the
// (zero-extended) previous constant.
1771 if (
auto cst = inputs[i].getDefiningOp<hw::ConstantOp>()) {
1772 if (
auto prevCst = inputs[i - 1].getDefiningOp<hw::ConstantOp>()) {
1773 unsigned prevWidth = prevCst.getValue().getBitWidth();
1774 unsigned thisWidth = cst.getValue().getBitWidth();
1775 auto resultCst = cst.getValue().zext(prevWidth + thisWidth);
1776 resultCst |= prevCst.getValue().zext(prevWidth + thisWidth)
1780 return flattenConcat(i - 1, i, replacement);
// Two adjacent identical values become a replicate with multiple 2.
1785 if (inputs[i] == inputs[i - 1]) {
1787 rewriter.createOrFold<ReplicateOp>(op.getLoc(), inputs[i], 2);
1788 return flattenConcat(i - 1, i, replacement);
// Current operand is a replicate.
1793 if (
auto repl = inputs[i].getDefiningOp<ReplicateOp>()) {
// Previous operand is the replicated value itself: absorb it by bumping the
// multiple by one.
1795 if (repl.getOperand() == inputs[i - 1]) {
1796 Value replacement = rewriter.createOrFold<ReplicateOp>(
1797 op.getLoc(), repl.getOperand(), repl.getMultiple() + 1);
1798 return flattenConcat(i - 1, i, replacement);
// Previous operand is a replicate of the same value: add the two multiples.
1801 if (
auto prevRepl = inputs[i - 1].getDefiningOp<ReplicateOp>()) {
1802 if (prevRepl.getOperand() == repl.getOperand()) {
1803 Value replacement = rewriter.createOrFold<ReplicateOp>(
1804 op.getLoc(), repl.getOperand(),
1805 repl.getMultiple() + prevRepl.getMultiple());
1806 return flattenConcat(i - 1, i, replacement);
// Mirror case: previous operand is a replicate of the current operand.
1812 if (
auto repl = inputs[i - 1].getDefiningOp<ReplicateOp>()) {
1813 if (repl.getOperand() == inputs[i]) {
1814 Value replacement = rewriter.createOrFold<ReplicateOp>(
1815 op.getLoc(), inputs[i], repl.getMultiple() + 1);
1816 return flattenConcat(i - 1, i, replacement);
// Two adjacent extracts of the same input whose bit ranges are contiguous
// (previous extract starts exactly where the current one ends) are merged
// into one wider extract starting at the current extract's low bit.
1822 if (
auto extract = inputs[i].getDefiningOp<ExtractOp>()) {
1823 if (
auto prevExtract = inputs[i - 1].getDefiningOp<ExtractOp>()) {
1824 if (extract.getInput() == prevExtract.getInput()) {
1825 auto thisWidth = cast<IntegerType>(extract.getType()).getWidth();
1826 if (prevExtract.getLowBit() == extract.getLowBit() + thisWidth) {
1827 auto prevWidth = prevExtract.getType().getIntOrFloatBitWidth();
1828 auto resType = rewriter.getIntegerType(thisWidth + prevWidth);
1831 extract.getInput(), extract.getLowBit());
1832 return flattenConcat(i - 1, i, replacement);
// Local helper describing a read from an array as {input, index, width}.
// A single element read is returned with width 1; the array-slice form is
// handled below it; anything else yields std::nullopt.
// NOTE(review): the enclosing struct declaration and parts of this helper
// are not visible in this excerpt.
1845 static std::optional<ArraySlice>
get(Value value) {
1846 assert(isa<IntegerType>(value.getType()) &&
"expected integer type");
1848 return ArraySlice{arrayGet.getInput(), arrayGet.getIndex(), 1};
1851 if (
auto arraySlice =
1854 arraySlice.getInput(), arraySlice.getLowIndex(),
1855 hw::type_cast<hw::ArrayType>(arraySlice.getType())
1857 return std::nullopt;
// Two adjacent array reads from the same array, with matching index types
// and indices related as checked by hw::isOffset, are merged into a single
// read covering the sum of both widths.
1860 if (
auto extractOpt = ArraySlice::get(inputs[i])) {
1861 if (
auto prevExtractOpt = ArraySlice::get(inputs[i - 1])) {
1863 if (prevExtractOpt->index.getType() == extractOpt->index.getType() &&
1864 prevExtractOpt->input == extractOpt->input &&
1865 hw::isOffset(extractOpt->index, prevExtractOpt->index,
1866 extractOpt->width)) {
1867 auto resType = hw::ArrayType::get(
1868 hw::type_cast<hw::ArrayType>(prevExtractOpt->input.getType())
1870 extractOpt->width + prevExtractOpt->width);
1871 auto resIntType = rewriter.getIntegerType(hw::getBitWidth(resType));
1873 rewriter, op.getLoc(), resIntType,
1875 prevExtractOpt->input,
1876 extractOpt->index));
1877 return flattenConcat(i - 1, i, replacement);
// All operands were the same value: replace the concat with a replicate.
1885 if (commonOperand) {
1886 replaceOpWithNewOpAndCopyNamehint<ReplicateOp>(rewriter, op, op.getType(),
// Folder for MuxOp. Each case that returns an existing value first checks
// that the value is not this op's own result, so the op is never folded to
// itself.
// NOTE(review): the results of two of the cases below are not visible in
// this excerpt; confirm against the full file.
1898OpFoldResult MuxOp::fold(FoldAdaptor adaptor) {
// Both arms are the same SSA value: the condition is irrelevant.
1903 if (getTrueValue() == getFalseValue() && getTrueValue() != getResult())
1904 return getTrueValue();
// Both arms fold to the same (non-null) attribute.
1905 if (
auto tv = adaptor.getTrueValue())
1906 if (tv == adaptor.getFalseValue())
// Constant condition: zero selects the false arm, one selects the true arm.
1911 if (
auto pred = dyn_cast_or_null<IntegerAttr>(adaptor.getCond())) {
1912 if (pred.getValue().isZero() && getFalseValue() != getResult())
1913 return getFalseValue();
1914 if (pred.getValue().isOne() && getTrueValue() != getResult())
1915 return getTrueValue();
// Condition has the same type as the arms, the true arm is the constant one,
// the false arm is the constant zero, and the result is one bit wide.
1919 if (getCond().getType() == getTrueValue().getType())
1920 if (
auto tv = dyn_cast_or_null<IntegerAttr>(adaptor.getTrueValue()))
1921 if (
auto fv = dyn_cast_or_null<IntegerAttr>(adaptor.getFalseValue()))
1922 if (tv.getValue().isOne() && fv.getValue().isZero() &&
1923 hw::getBitWidth(getType()) == 1 && getCond() != getResult())
1939 if (
auto cmp = cond.getDefiningOp<ICmpOp>()) {
1941 auto requiredPredicate =
1942 (isInverted ? ICmpPredicate::eq : ICmpPredicate::ne);
1943 if (cmp.getLhs() == indexValue && cmp.getPredicate() == requiredPredicate) {
1953 if (
auto orOp = cond.getDefiningOp<
OrOp>()) {
1956 for (
auto operand : orOp.getOperands())
1963 if (
auto andOp = cond.getDefiningOp<
AndOp>()) {
1966 for (
auto operand : andOp.getOperands())
1985 PatternRewriter &rewriter,
MuxOp rootMux,
bool isFalseSide,
1991 auto rootCmp = rootMux.getCond().getDefiningOp<ICmpOp>();
1994 Value indexValue = rootCmp.getLhs();
1997 auto getCaseValue = [&](
MuxOp mux) -> Value {
1998 return mux.getOperand(1 +
unsigned(!isFalseSide));
2003 auto getTreeValue = [&](
MuxOp mux) -> Value {
2004 return mux.getOperand(1 +
unsigned(isFalseSide));
2009 SmallVector<Location> locationsFound;
2010 SmallVector<std::pair<hw::ConstantOp, Value>, 4> valuesFound;
2014 auto collectConstantValues = [&](
MuxOp mux) ->
bool {
2016 mux.getCond(), indexValue, isFalseSide, [&](
hw::ConstantOp cst) {
2017 valuesFound.push_back({cst, getCaseValue(mux)});
2018 locationsFound.push_back(mux.getCond().getLoc());
2019 locationsFound.push_back(mux->getLoc());
2024 if (!collectConstantValues(rootMux))
2028 if (rootMux->hasOneUse()) {
2029 if (
auto userMux = dyn_cast<MuxOp>(*rootMux->user_begin())) {
2030 if (getTreeValue(userMux) == rootMux.getResult() &&
2038 auto nextTreeValue = getTreeValue(rootMux);
2040 auto nextMux = nextTreeValue.getDefiningOp<
MuxOp>();
2041 if (!nextMux || !nextMux->hasOneUse())
2043 if (!collectConstantValues(nextMux))
2045 nextTreeValue = getTreeValue(nextMux);
2048 auto indexWidth = cast<IntegerType>(indexValue.getType()).getWidth();
2050 if (indexWidth > 20)
2053 auto foldingStyle = styleFn(indexWidth, valuesFound.size());
2057 uint64_t tableSize = 1ULL << indexWidth;
2061 SmallVector<Value, 8> table(tableSize, nextTreeValue);
2066 for (
auto &elt :
llvm::reverse(valuesFound)) {
2067 uint64_t idx = elt.first.getValue().getZExtValue();
2068 assert(idx < table.size() &&
"constant should be same bitwidth as index");
2069 table[idx] = elt.second;
2073 SmallVector<Value> bits;
2082 "unknown folding style");
2086 std::reverse(table.begin(), table.end());
2089 auto fusedLoc = rewriter.getFusedLoc(locationsFound);
2091 replaceOpWithNewOpAndCopyNamehint<hw::ArrayGetOp>(rewriter, rootMux, array,
2106 PatternRewriter &rewriter) {
2107 assert(fullyAssoc->getNumOperands() >= 2 &&
"cannot split up unary ops");
2108 assert(operandNo < fullyAssoc->getNumOperands() &&
"Invalid operand #");
2112 if (fullyAssoc->getNumOperands() == 2)
2113 return fullyAssoc->getOperand(operandNo ^ 1);
2116 if (fullyAssoc->hasOneUse()) {
2117 rewriter.modifyOpInPlace(fullyAssoc,
2118 [&]() { fullyAssoc->eraseOperand(operandNo); });
2119 return fullyAssoc->getResult(0);
2123 SmallVector<Value> operands;
2124 operands.append(fullyAssoc->getOperands().begin(),
2125 fullyAssoc->getOperands().begin() + operandNo);
2126 operands.append(fullyAssoc->getOperands().begin() + operandNo + 1,
2127 fullyAssoc->getOperands().end());
2129 fullyAssoc->getLoc(), fullyAssoc->getName(), operands, rewriter);
2130 Value excluded = fullyAssoc->getOperand(operandNo);
2134 ArrayRef<Value>{opWithoutExcluded, excluded}, rewriter);
2136 return opWithoutExcluded;
2146 PatternRewriter &rewriter) {
2149 Operation *subExpr =
2150 (isTrueOperand ? op.getFalseValue() : op.getTrueValue()).getDefiningOp();
2151 if (!subExpr || subExpr->getNumOperands() < 2)
2155 if (!isa<AndOp, XorOp, OrOp, MuxOp>(subExpr))
2160 Value commonValue = isTrueOperand ? op.getTrueValue() : op.getFalseValue();
2161 size_t opNo = 0, e = subExpr->getNumOperands();
2162 while (opNo != e && subExpr->getOperand(opNo) != commonValue)
2168 Value cond = op.getCond();
2174 if (
auto subMux = dyn_cast<MuxOp>(subExpr)) {
2179 Value subCond = subMux.getCond();
2182 if (subMux.getTrueValue() == commonValue)
2183 otherValue = subMux.getFalseValue();
2184 else if (subMux.getFalseValue() == commonValue) {
2185 otherValue = subMux.getTrueValue();
2195 cond = rewriter.createOrFold<
OrOp>(op.getLoc(), cond, subCond,
false);
2196 replaceOpWithNewOpAndCopyNamehint<MuxOp>(rewriter, op, cond, commonValue,
2197 otherValue, op.getTwoState());
2203 bool isaAndOp = isa<AndOp>(subExpr);
2204 if (isTrueOperand ^ isaAndOp)
2208 rewriter.createOrFold<ReplicateOp>(op.getLoc(), op.getType(), cond);
2211 bool isaXorOp = isa<XorOp>(subExpr);
2212 bool isaOrOp = isa<OrOp>(subExpr);
2221 if (isaOrOp || isaXorOp) {
2222 auto masked = rewriter.createOrFold<
AndOp>(op.getLoc(), extendedCond,
2223 restOfAssoc,
false);
2225 replaceOpWithNewOpAndCopyNamehint<XorOp>(rewriter, op, masked,
2226 commonValue,
false);
2228 replaceOpWithNewOpAndCopyNamehint<OrOp>(rewriter, op, masked, commonValue,
2234 assert(isaAndOp &&
"unexpected operation here");
2235 auto masked = rewriter.createOrFold<
OrOp>(op.getLoc(), extendedCond,
2236 restOfAssoc,
false);
2237 replaceOpWithNewOpAndCopyNamehint<AndOp>(rewriter, op, masked, commonValue,
2248 PatternRewriter &rewriter) {
2251 if (!isa<ConcatOp>(trueOp))
2255 SmallVector<Value> trueOperands, falseOperands;
2259 size_t numTrueOperands = trueOperands.size();
2260 size_t numFalseOperands = falseOperands.size();
2262 if (!numTrueOperands || !numFalseOperands ||
2263 (trueOperands.front() != falseOperands.front() &&
2264 trueOperands.back() != falseOperands.back()))
2268 if (trueOperands.front() == falseOperands.front()) {
2269 SmallVector<Value> operands;
2271 for (i = 0; i < numTrueOperands; ++i) {
2272 Value trueOperand = trueOperands[i];
2273 if (trueOperand == falseOperands[i])
2274 operands.push_back(trueOperand);
2278 if (i == numTrueOperands) {
2285 if (llvm::all_of(operands, [&](Value v) {
return v == operands.front(); }))
2286 sharedMSB = rewriter.createOrFold<ReplicateOp>(
2287 mux->getLoc(), operands.front(), operands.size());
2289 sharedMSB = rewriter.createOrFold<
ConcatOp>(mux->getLoc(), operands);
2293 operands.append(trueOperands.begin() + i, trueOperands.end());
2294 Value trueLSB = rewriter.createOrFold<
ConcatOp>(trueOp->getLoc(), operands);
2296 operands.append(falseOperands.begin() + i, falseOperands.end());
2298 rewriter.createOrFold<
ConcatOp>(falseOp->getLoc(), operands);
2301 Value lsb = rewriter.createOrFold<
MuxOp>(
2302 mux->getLoc(), mux.getCond(), trueLSB, falseLSB, mux.getTwoState());
2303 replaceOpWithNewOpAndCopyNamehint<ConcatOp>(rewriter, mux, sharedMSB, lsb);
2308 if (trueOperands.back() == falseOperands.back()) {
2309 SmallVector<Value> operands;
2312 Value trueOperand = trueOperands[numTrueOperands - i - 1];
2313 if (trueOperand == falseOperands[numFalseOperands - i - 1])
2314 operands.push_back(trueOperand);
2318 std::reverse(operands.begin(), operands.end());
2319 Value sharedLSB = rewriter.createOrFold<
ConcatOp>(mux->getLoc(), operands);
2323 operands.append(trueOperands.begin(), trueOperands.end() - i);
2324 Value trueMSB = rewriter.createOrFold<
ConcatOp>(trueOp->getLoc(), operands);
2326 operands.append(falseOperands.begin(), falseOperands.end() - i);
2328 rewriter.createOrFold<
ConcatOp>(falseOp->getLoc(), operands);
2330 Value msb = rewriter.createOrFold<
MuxOp>(
2331 mux->getLoc(), mux.getCond(), trueMSB, falseMSB, mux.getTwoState());
2332 replaceOpWithNewOpAndCopyNamehint<ConcatOp>(rewriter, mux, msb, sharedLSB);
2344 if (!trueVec || !falseVec)
2346 if (!trueVec.isUniform() || !falseVec.isUniform())
2349 auto mux = MuxOp::create(rewriter, op.getLoc(), op.getCond(),
2350 trueVec.getUniformElement(),
2351 falseVec.getUniformElement(), op.getTwoState());
2353 SmallVector<Value> values(trueVec.getInputs().size(), mux);
2361 bool constCond, PatternRewriter &rewriter) {
2362 if (!muxValue.hasOneUse())
2364 auto *op = muxValue.getDefiningOp();
2365 if (!op || !isa_and_nonnull<CombDialect>(op->getDialect()))
2367 if (!llvm::is_contained(op->getOperands(), muxCond))
2369 OpBuilder::InsertionGuard guard(rewriter);
2370 rewriter.setInsertionPoint(op);
2373 rewriter.modifyOpInPlace(op, [&] {
2374 for (
auto &use : op->getOpOperands())
2375 if (use.get() == muxCond)
2383 using OpRewritePattern::OpRewritePattern;
2385 LogicalResult matchAndRewrite(
MuxOp op,
2386 PatternRewriter &rewriter)
const override;
// Policy callback that decides how a mux chain should be folded, given the
// bit width of the index (`indexWidth`) and the number of constant cases
// found (`numEntries`). It is passed by name as an argument from
// MuxRewriter::matchAndRewrite.
// NOTE(review): the return type and the values returned by each branch are
// not visible in this excerpt; confirm against the full file.
2390foldToArrayCreateOnlyWhenDense(
size_t indexWidth,
size_t numEntries) {
// Indices of 9 or more bits, or fewer than 3 entries, are singled out first.
2393 if (indexWidth >= 9 || numEntries < 3)
// Density check: the entries must cover at least 5/8 of the 2^indexWidth
// possible index values.
2399 uint64_t tableSize = 1ULL << indexWidth;
2400 if (numEntries >= tableSize * 5 / 8)
// Main rewrite pattern for MuxOp. Tries a sequence of local simplifications
// in order: muxes with constant arms, muxes with an inverted condition,
// nested muxes sharing a condition or sharing arms, and finally rewrites
// keyed on the operations defining the two arms.
// NOTE(review): many original lines (declarations, returns, closing braces,
// and some helper calls) are missing from this excerpt; the comments below
// describe only what the visible lines establish.
2405LogicalResult MuxRewriter::matchAndRewrite(
MuxOp op,
2406 PatternRewriter &rewriter)
const {
// The constant-arm rewrites below are restricted to signless integer results.
2410 bool isSignlessInt =
false;
2411 if (
auto intType = dyn_cast<IntegerType>(op.getType()))
2412 isSignlessInt = intType.isSignless();
// --- True arm is an integer constant ---
2419 if (matchPattern(op.getTrueValue(), m_ConstantInt(&value)) && isSignlessInt) {
// One-bit mux with a constant true arm becomes a logic gate.
2420 if (value.getBitWidth() == 1) {
// True arm is 0: AndOp of `notCond` and the false arm.
2422 if (value.isZero()) {
2424 replaceOpWithNewOpAndCopyNamehint<AndOp>(rewriter, op, notCond,
2425 op.getFalseValue(),
false);
// Other one-bit constant: OrOp of the condition and the false arm.
2430 replaceOpWithNewOpAndCopyNamehint<OrOp>(rewriter, op, op.getCond(),
2431 op.getFalseValue(),
false);
// --- Both arms are integer constants ---
2437 if (matchPattern(op.getFalseValue(), m_ConstantInt(&value2))) {
// The constants differ in exactly one bit (their XOR is a power of two):
// only that bit needs a mux. `trailingZeros` is the position of the
// differing bit; `leadingZeros` is the number of shared bits above it.
2442 APInt xorValue = value ^ value2;
2443 if (xorValue.isPowerOf2()) {
2444 unsigned leadingZeros = xorValue.countLeadingZeros();
2445 unsigned trailingZeros = value.getBitWidth() - leadingZeros - 1;
2446 SmallVector<Value, 3> operands;
// Shared high bits, taken from the true arm.
2454 if (leadingZeros > 0)
2455 operands.push_back(rewriter.createOrFold<
ExtractOp>(
2456 op.getLoc(), op.getTrueValue(), trailingZeros + 1, leadingZeros));
// The single differing bit from each arm, selected by the original condition.
2460 auto v1 = rewriter.createOrFold<
ExtractOp>(
2461 op.getLoc(), op.getTrueValue(), trailingZeros, 1);
2462 auto v2 = rewriter.createOrFold<
ExtractOp>(
2463 op.getLoc(), op.getFalseValue(), trailingZeros, 1);
2464 operands.push_back(rewriter.createOrFold<
MuxOp>(
2465 op.getLoc(), op.getCond(), v1, v2,
false));
// Shared low bits, taken from the true arm.
2467 if (trailingZeros > 0)
2468 operands.push_back(rewriter.createOrFold<
ExtractOp>(
2469 op.getLoc(), op.getTrueValue(), 0, trailingZeros));
2471 replaceOpWithNewOpAndCopyNamehint<ConcatOp>(rewriter, op, op.getType(),
// True arm all ones and false arm zero: replicate the condition to the
// result type.
2478 if (value.isAllOnes() && value2.isZero()) {
2479 replaceOpWithNewOpAndCopyNamehint<ReplicateOp>(
2480 rewriter, op, op.getType(), op.getCond());
// --- False arm is a one-bit integer constant ---
2486 if (matchPattern(op.getFalseValue(), m_ConstantInt(&value)) &&
2487 isSignlessInt && value.getBitWidth() == 1) {
// False arm is 0: AndOp of the condition and the true arm.
2489 if (value.isZero()) {
2490 replaceOpWithNewOpAndCopyNamehint<AndOp>(rewriter, op, op.getCond(),
2491 op.getTrueValue(),
false);
// Other one-bit constant: XOR the condition with the false arm to form
// `notCond`, then OrOp it with the true arm.
2498 auto notCond = rewriter.createOrFold<
XorOp>(op.getLoc(), op.getCond(),
2499 op.getFalseValue(),
false);
2500 replaceOpWithNewOpAndCopyNamehint<OrOp>(rewriter, op, notCond,
2501 op.getTrueValue(),
false);
// --- Condition is a complement: use the un-inverted value and swap arms ---
2507 Operation *condOp = op.getCond().getDefiningOp();
2508 if (condOp && matchPattern(condOp,
m_Complement(m_Any(&subExpr))) &&
2510 replaceOpWithNewOpAndCopyNamehint<MuxOp>(rewriter, op, op.getType(),
2511 subExpr, op.getFalseValue(),
2512 op.getTrueValue(),
true);
// --- Condition is a single-use and/or of complemented operands ---
2519 if (condOp && condOp->hasOneUse()) {
2520 SmallVector<Value> invertedOperands;
// Collects the un-inverted value behind each complemented operand of condOp.
2524 auto getInvertedOperands = [&]() ->
bool {
2525 for (Value operand : condOp->getOperands()) {
2526 if (matchPattern(operand,
m_Complement(m_Any(&subExpr))))
2527 invertedOperands.push_back(subExpr);
// AndOp of complements: OrOp the un-inverted operands and swap the arms.
2534 if (isa<AndOp>(condOp) && getInvertedOperands()) {
2536 rewriter.createOrFold<
OrOp>(op.getLoc(), invertedOperands,
false);
2537 replaceOpWithNewOpAndCopyNamehint<MuxOp>(
2538 rewriter, op, newOr, op.getFalseValue(), op.getTrueValue(),
// OrOp of complements: AndOp the un-inverted operands and swap the arms.
2542 if (isa<OrOp>(condOp) && getInvertedOperands()) {
2544 rewriter.createOrFold<
AndOp>(op.getLoc(), invertedOperands,
false);
2545 replaceOpWithNewOpAndCopyNamehint<MuxOp>(
2546 rewriter, op, newAnd, op.getFalseValue(), op.getTrueValue(),
// --- False arm is another mux (and not this op itself) ---
2552 if (
auto falseMux = op.getFalseValue().getDefiningOp<
MuxOp>();
2553 falseMux && falseMux != op) {
// Same condition on both muxes: the inner mux's true arm can never be
// selected, so use the inner false arm directly.
2555 if (op.getCond() == falseMux.getCond()) {
2556 replaceOpWithNewOpAndCopyNamehint<MuxOp>(
2557 rewriter, op, op.getCond(), op.getTrueValue(),
2558 falseMux.getFalseValue(), op.getTwoStateAttr());
// Tail of a call that receives the density policy callback.
2564 foldToArrayCreateOnlyWhenDense))
// --- True arm is another mux (and not this op itself) ---
2568 if (
auto trueMux = op.getTrueValue().getDefiningOp<
MuxOp>();
2569 trueMux && trueMux != op) {
// Same condition on both muxes: the inner mux's false arm can never be
// selected, so use the inner true arm directly.
2571 if (op.getCond() == trueMux.getCond()) {
2572 replaceOpWithNewOpAndCopyNamehint<MuxOp>(
2573 rewriter, op, op.getCond(), trueMux.getTrueValue(),
2574 op.getFalseValue(), op.getTwoStateAttr());
// Tail of a call that receives the density policy callback.
2580 foldToArrayCreateOnlyWhenDense))
// --- Both arms are muxes with the same condition and the same true arm ---
// Hoist the shared inner condition outward: the new inner mux selects
// between the two inner false arms using this op's condition.
2585 if (
auto trueMux = dyn_cast_or_null<MuxOp>(op.getTrueValue().getDefiningOp()),
2586 falseMux = dyn_cast_or_null<MuxOp>(op.getFalseValue().getDefiningOp());
2587 trueMux && falseMux && trueMux.getCond() == falseMux.getCond() &&
2588 trueMux.getTrueValue() == falseMux.getTrueValue() && trueMux != op &&
2590 auto subMux = MuxOp::create(
2591 rewriter, rewriter.getFusedLoc({trueMux.getLoc(), falseMux.getLoc()}),
2592 op.getCond(), trueMux.getFalseValue(), falseMux.getFalseValue());
2593 replaceOpWithNewOpAndCopyNamehint<MuxOp>(rewriter, op, trueMux.getCond(),
2594 trueMux.getTrueValue(), subMux,
2595 op.getTwoStateAttr());
// --- Both arms are muxes with the same condition and the same false arm ---
// Symmetric to the case above: the new inner mux selects between the two
// inner true arms.
2600 if (
auto trueMux = dyn_cast_or_null<MuxOp>(op.getTrueValue().getDefiningOp()),
2601 falseMux = dyn_cast_or_null<MuxOp>(op.getFalseValue().getDefiningOp());
2602 trueMux && falseMux && trueMux.getCond() == falseMux.getCond() &&
2603 trueMux.getFalseValue() == falseMux.getFalseValue() && trueMux != op &&
2605 auto subMux = MuxOp::create(
2606 rewriter, rewriter.getFusedLoc({trueMux.getLoc(), falseMux.getLoc()}),
2607 op.getCond(), trueMux.getTrueValue(), falseMux.getTrueValue());
2608 replaceOpWithNewOpAndCopyNamehint<MuxOp>(rewriter, op, trueMux.getCond(),
2609 subMux, trueMux.getFalseValue(),
2610 op.getTwoStateAttr());
// --- Both arms are muxes with identical true and false arms ---
// Only the conditions differ, so mux the two conditions and keep one copy of
// the arms.
2615 if (
auto trueMux = dyn_cast_or_null<MuxOp>(op.getTrueValue().getDefiningOp()),
2616 falseMux = dyn_cast_or_null<MuxOp>(op.getFalseValue().getDefiningOp());
2617 trueMux && falseMux &&
2618 trueMux.getTrueValue() == falseMux.getTrueValue() &&
2619 trueMux.getFalseValue() == falseMux.getFalseValue() && trueMux != op &&
2622 MuxOp::create(rewriter,
2623 rewriter.getFusedLoc(
2624 {op.getLoc(), trueMux.getLoc(), falseMux.getLoc()}),
2625 op.getCond(), trueMux.getCond(), falseMux.getCond());
2626 replaceOpWithNewOpAndCopyNamehint<MuxOp>(
2627 rewriter, op, subMux, trueMux.getTrueValue(), trueMux.getFalseValue(),
2628 op.getTwoStateAttr());
// --- Both arms are defined by operations with the same operation name ---
2640 if (Operation *trueOp = op.getTrueValue().getDefiningOp())
2641 if (Operation *falseOp = op.getFalseValue().getDefiningOp())
2642 if (trueOp->getName() == falseOp->getName())
// --- Per-arm rewrites; only tried when the arm has a defining op that is
// not this mux itself ---
2655 if (op.getTrueValue().getDefiningOp() &&
2656 op.getTrueValue().getDefiningOp() != op)
2659 if (op.getFalseValue().getDefiningOp() &&
2660 op.getFalseValue().getDefiningOp() != op)
2671 if (op.getInputs().empty() || op.isUniform())
2673 auto inputs = op.getInputs();
2674 if (inputs.size() <= 1)
2679 auto first = inputs[0].getDefiningOp<
comb::MuxOp>();
2684 for (
size_t i = 1, n = inputs.size(); i < n; ++i) {
2685 auto input = inputs[i].getDefiningOp<
comb::MuxOp>();
2686 if (!input || first.getCond() != input.getCond())
2691 SmallVector<Value> trues{first.getTrueValue()};
2692 SmallVector<Value> falses{first.getFalseValue()};
2693 SmallVector<Location> locs{first->getLoc()};
2694 bool isTwoState =
true;
2695 for (
size_t i = 1, n = inputs.size(); i < n; ++i) {
2696 auto input = inputs[i].getDefiningOp<
comb::MuxOp>();
2697 trues.push_back(input.getTrueValue());
2698 falses.push_back(input.getFalseValue());
2699 locs.push_back(input->getLoc());
2700 if (!input.getTwoState())
2705 auto loc = FusedLoc::get(op.getContext(), locs);
2709 auto arrayTy = op.getType();
2712 rewriter.replaceOpWithNewOp<
comb::MuxOp>(op, arrayTy, first.getCond(),
2713 trueValues, falseValues, isTwoState);
2718 using OpRewritePattern::OpRewritePattern;
2721 PatternRewriter &rewriter)
const override {
2722 if (foldArrayOfMuxes(op, rewriter))
// Registers the canonicalization patterns associated with MuxOp: inserts the
// MuxRewriter and ArrayRewriter patterns, constructed with `context`, into
// `results`.
2730void MuxOp::getCanonicalizationPatterns(RewritePatternSet &results,
2731 MLIRContext *context) {
2732 results.insert<MuxRewriter, ArrayRewriter>(context);
2743 switch (predicate) {
2744 case ICmpPredicate::eq:
2746 case ICmpPredicate::ne:
2748 case ICmpPredicate::slt:
2749 return lhs.slt(rhs);
2750 case ICmpPredicate::sle:
2751 return lhs.sle(rhs);
2752 case ICmpPredicate::sgt:
2753 return lhs.sgt(rhs);
2754 case ICmpPredicate::sge:
2755 return lhs.sge(rhs);
2756 case ICmpPredicate::ult:
2757 return lhs.ult(rhs);
2758 case ICmpPredicate::ule:
2759 return lhs.ule(rhs);
2760 case ICmpPredicate::ugt:
2761 return lhs.ugt(rhs);
2762 case ICmpPredicate::uge:
2763 return lhs.uge(rhs);
2764 case ICmpPredicate::ceq:
2766 case ICmpPredicate::cne:
2768 case ICmpPredicate::weq:
2770 case ICmpPredicate::wne:
2773 llvm_unreachable(
"unknown comparison predicate");
2779 switch (predicate) {
2780 case ICmpPredicate::eq:
2781 case ICmpPredicate::sle:
2782 case ICmpPredicate::sge:
2783 case ICmpPredicate::ule:
2784 case ICmpPredicate::uge:
2785 case ICmpPredicate::ceq:
2786 case ICmpPredicate::weq:
2788 case ICmpPredicate::ne:
2789 case ICmpPredicate::slt:
2790 case ICmpPredicate::sgt:
2791 case ICmpPredicate::ult:
2792 case ICmpPredicate::ugt:
2793 case ICmpPredicate::cne:
2794 case ICmpPredicate::wne:
2797 llvm_unreachable(
"unknown comparison predicate");
// Folder for ICmpOp. Produces a constant IntegerAttr of the op's result type
// in two situations: both operands are the same SSA value, or both operands
// are integer constants.
// NOTE(review): the computation of `val` in each case is not visible in this
// excerpt; confirm against the full file.
2800OpFoldResult ICmpOp::fold(FoldAdaptor adaptor) {
// Comparing a value with itself.
2803 if (getLhs() == getRhs()) {
2805 return IntegerAttr::get(getType(), val);
// Both operands are integer constants (null attributes are tolerated by
// dyn_cast_or_null and skip this case).
2809 if (
auto lhs = dyn_cast_or_null<IntegerAttr>(adaptor.getLhs())) {
2810 if (
auto rhs = dyn_cast_or_null<IntegerAttr>(adaptor.getRhs())) {
2813 return IntegerAttr::get(getType(), val);
2821template <
typename Range>
2823 size_t commonPrefixLength = 0;
2824 auto ia = a.begin();
2825 auto ib = b.begin();
2827 for (; ia != a.end() && ib != b.end(); ia++, ib++, commonPrefixLength++) {
2833 return commonPrefixLength;
2837 size_t totalWidth = 0;
2838 for (
auto operand : operands) {
2841 ssize_t width = operand.getType().getIntOrFloatBitWidth();
2843 totalWidth += width;
2853 PatternRewriter &rewriter) {
2857 SmallVector<Value> lhsOperands, rhsOperands;
2860 ArrayRef<Value> lhsOperandsRef = lhsOperands, rhsOperandsRef = rhsOperands;
2862 auto formCatOrReplicate = [&](Location loc,
2863 ArrayRef<Value> operands) -> Value {
2864 assert(!operands.empty());
2865 Value sameElement = operands[0];
2866 for (
size_t i = 1, e = operands.size(); i != e && sameElement; ++i)
2867 if (sameElement != operands[i])
2868 sameElement = Value();
2870 return rewriter.createOrFold<ReplicateOp>(loc, sameElement,
2872 return rewriter.createOrFold<
ConcatOp>(loc, operands);
2875 auto replaceWith = [&](ICmpPredicate predicate, Value lhs,
2876 Value rhs) -> LogicalResult {
2877 replaceOpWithNewOpAndCopyNamehint<ICmpOp>(rewriter, op, predicate, lhs, rhs,
2882 size_t commonPrefixLength =
2884 if (commonPrefixLength == lhsOperands.size()) {
2887 replaceOpWithNewOpAndCopyNamehint<hw::ConstantOp>(rewriter, op,
2893 llvm::reverse(lhsOperandsRef), llvm::reverse(rhsOperandsRef));
2895 size_t commonPrefixTotalWidth =
2896 getTotalWidth(lhsOperandsRef.take_front(commonPrefixLength));
2897 size_t commonSuffixTotalWidth =
2898 getTotalWidth(lhsOperandsRef.take_back(commonSuffixLength));
2899 auto lhsOnly = lhsOperandsRef.drop_front(commonPrefixLength)
2900 .drop_back(commonSuffixLength);
2901 auto rhsOnly = rhsOperandsRef.drop_front(commonPrefixLength)
2902 .drop_back(commonSuffixLength);
2904 auto replaceWithoutReplicatingSignBit = [&]() {
2905 auto newLhs = formCatOrReplicate(lhs->getLoc(), lhsOnly);
2906 auto newRhs = formCatOrReplicate(rhs->getLoc(), rhsOnly);
2907 return replaceWith(op.getPredicate(), newLhs, newRhs);
2910 auto replaceWithReplicatingSignBit = [&]() {
2911 auto firstNonEmptyValue = lhsOperands[0];
2912 auto firstNonEmptyElemWidth =
2913 firstNonEmptyValue.getType().getIntOrFloatBitWidth();
2914 Value signBit = rewriter.createOrFold<
ExtractOp>(
2915 op.getLoc(), firstNonEmptyValue, firstNonEmptyElemWidth - 1, 1);
2917 auto newLhs = ConcatOp::create(rewriter, lhs->getLoc(), signBit, lhsOnly);
2918 auto newRhs = ConcatOp::create(rewriter, rhs->getLoc(), signBit, rhsOnly);
2919 return replaceWith(op.getPredicate(), newLhs, newRhs);
2922 if (ICmpOp::isPredicateSigned(op.getPredicate())) {
2924 if (commonPrefixTotalWidth == 0 && commonSuffixTotalWidth > 0)
2925 return replaceWithoutReplicatingSignBit();
2931 if (commonPrefixTotalWidth > 1 || commonSuffixTotalWidth > 0)
2932 return replaceWithReplicatingSignBit();
2934 }
else if (commonPrefixTotalWidth > 0 || commonSuffixTotalWidth > 0) {
2936 return replaceWithoutReplicatingSignBit();
2950 ICmpOp cmpOp,
const KnownBits &bitAnalysis,
const APInt &rhsCst,
2951 PatternRewriter &rewriter) {
2955 APInt bitsKnown = bitAnalysis.Zero | bitAnalysis.One;
2956 if ((bitsKnown & rhsCst) != bitAnalysis.One) {
2959 bool result = cmpOp.getPredicate() == ICmpPredicate::ne;
2960 replaceOpWithNewOpAndCopyNamehint<hw::ConstantOp>(rewriter, cmpOp,
2968 SmallVector<Value> newConcatOperands;
2969 auto newConstant = APInt::getZeroWidth();
2974 unsigned knownMSB = bitsKnown.countLeadingOnes();
2976 Value operand = cmpOp.getLhs();
2981 while (knownMSB != bitsKnown.getBitWidth()) {
2984 bitsKnown = bitsKnown.trunc(bitsKnown.getBitWidth() - knownMSB);
2987 unsigned unknownBits = bitsKnown.countLeadingZeros();
2988 unsigned lowBit = bitsKnown.getBitWidth() - unknownBits;
2989 auto spanOperand = rewriter.createOrFold<
ExtractOp>(
2990 operand.getLoc(), operand, lowBit,
2992 auto spanConstant = rhsCst.lshr(lowBit).trunc(unknownBits);
2995 newConcatOperands.push_back(spanOperand);
2998 if (newConstant.getBitWidth() != 0)
2999 newConstant = newConstant.concat(spanConstant);
3001 newConstant = spanConstant;
3004 unsigned newWidth = bitsKnown.getBitWidth() - unknownBits;
3005 bitsKnown = bitsKnown.trunc(newWidth);
3006 knownMSB = bitsKnown.countLeadingOnes();
3012 if (newConcatOperands.empty()) {
3013 bool result = cmpOp.getPredicate() == ICmpPredicate::eq;
3014 replaceOpWithNewOpAndCopyNamehint<hw::ConstantOp>(rewriter, cmpOp,
3020 Value concatResult =
3021 rewriter.createOrFold<
ConcatOp>(operand.getLoc(), newConcatOperands);
3025 rewriter, cmpOp.getOperand(1).getLoc(), newConstant);
3027 replaceOpWithNewOpAndCopyNamehint<ICmpOp>(rewriter, cmpOp,
3028 cmpOp.getPredicate(), concatResult,
3029 newConstantOp, cmpOp.getTwoState());
3035 PatternRewriter &rewriter) {
3036 auto ip = rewriter.saveInsertionPoint();
3037 rewriter.setInsertionPoint(xorOp);
3039 auto xorRHS = xorOp.getOperands().back().getDefiningOp<
hw::ConstantOp>();
3041 xorRHS.getValue() ^ rhs);
3043 switch (xorOp.getNumOperands()) {
3047 APInt::getZero(rhs.getBitWidth()));
3051 newLHS = xorOp.getOperand(0);
3055 SmallVector<Value> newOperands(xorOp.getOperands());
3056 newOperands.pop_back();
3057 newLHS = XorOp::create(rewriter, xorOp.getLoc(), newOperands,
false);
3061 bool xorMultipleUses = !xorOp->hasOneUse();
3065 if (xorMultipleUses)
3066 replaceOpWithNewOpAndCopyNamehint<XorOp>(rewriter, xorOp, newLHS, xorRHS,
3070 rewriter.restoreInsertionPoint(ip);
3071 replaceOpWithNewOpAndCopyNamehint<ICmpOp>(
3072 rewriter, cmpOp, cmpOp.getPredicate(), newLHS, newRHS,
false);
3075LogicalResult ICmpOp::canonicalize(ICmpOp op, PatternRewriter &rewriter) {
3081 if (matchPattern(op.getLhs(), m_ConstantInt(&lhs))) {
3082 assert(!matchPattern(op.getRhs(), m_ConstantInt(&rhs)) &&
3083 "Should be folded");
3084 replaceOpWithNewOpAndCopyNamehint<ICmpOp>(
3085 rewriter, op, ICmpOp::getFlippedPredicate(op.getPredicate()),
3086 op.getRhs(), op.getLhs(), op.getTwoState());
3091 if (matchPattern(op.getRhs(), m_ConstantInt(&rhs))) {
3096 auto replaceWith = [&](ICmpPredicate predicate, Value lhs,
3097 Value rhs) -> LogicalResult {
3098 replaceOpWithNewOpAndCopyNamehint<ICmpOp>(rewriter, op, predicate, lhs,
3099 rhs, op.getTwoState());
3103 auto replaceWithConstantI1 = [&](
bool constant) -> LogicalResult {
3104 replaceOpWithNewOpAndCopyNamehint<hw::ConstantOp>(rewriter, op,
3105 APInt(1, constant));
3109 switch (op.getPredicate()) {
3110 case ICmpPredicate::slt:
3112 if (rhs.isMaxSignedValue())
3113 return replaceWith(ICmpPredicate::ne, op.getLhs(), op.getRhs());
3115 if (rhs.isMinSignedValue())
3116 return replaceWithConstantI1(0);
3118 if ((rhs - 1).isMinSignedValue())
3119 return replaceWith(ICmpPredicate::eq, op.getLhs(),
3122 case ICmpPredicate::sgt:
3124 if (rhs.isMinSignedValue())
3125 return replaceWith(ICmpPredicate::ne, op.getLhs(), op.getRhs());
3127 if (rhs.isMaxSignedValue())
3128 return replaceWithConstantI1(0);
3130 if ((rhs + 1).isMaxSignedValue())
3131 return replaceWith(ICmpPredicate::eq, op.getLhs(),
3134 case ICmpPredicate::ult:
3136 if (rhs.isAllOnes())
3137 return replaceWith(ICmpPredicate::ne, op.getLhs(), op.getRhs());
3140 return replaceWithConstantI1(0);
3142 if ((rhs - 1).isZero())
3143 return replaceWith(ICmpPredicate::eq, op.getLhs(),
3147 if (rhs.countLeadingOnes() + rhs.countTrailingZeros() ==
3148 rhs.getBitWidth()) {
3149 auto numOnes = rhs.countLeadingOnes();
3151 rhs.getBitWidth() - numOnes, numOnes);
3152 return replaceWith(ICmpPredicate::ne, smaller,
3157 case ICmpPredicate::ugt:
3160 return replaceWith(ICmpPredicate::ne, op.getLhs(), op.getRhs());
3162 if (rhs.isAllOnes())
3163 return replaceWithConstantI1(0);
3165 if ((rhs + 1).isAllOnes())
3166 return replaceWith(ICmpPredicate::eq, op.getLhs(),
3170 if ((rhs + 1).isPowerOf2()) {
3171 auto numOnes = rhs.countTrailingOnes();
3172 auto newWidth = rhs.getBitWidth() - numOnes;
3175 return replaceWith(ICmpPredicate::ne, smaller,
3180 case ICmpPredicate::sle:
3182 if (rhs.isMaxSignedValue())
3183 return replaceWithConstantI1(1);
3185 return replaceWith(ICmpPredicate::slt, op.getLhs(),
getConstant(rhs + 1));
3186 case ICmpPredicate::sge:
3188 if (rhs.isMinSignedValue())
3189 return replaceWithConstantI1(1);
3191 return replaceWith(ICmpPredicate::sgt, op.getLhs(),
getConstant(rhs - 1));
3192 case ICmpPredicate::ule:
3194 if (rhs.isAllOnes())
3195 return replaceWithConstantI1(1);
3197 return replaceWith(ICmpPredicate::ult, op.getLhs(),
getConstant(rhs + 1));
3198 case ICmpPredicate::uge:
3201 return replaceWithConstantI1(1);
3203 return replaceWith(ICmpPredicate::ugt, op.getLhs(),
getConstant(rhs - 1));
3204 case ICmpPredicate::eq:
3205 if (rhs.getBitWidth() == 1) {
3208 replaceOpWithNewOpAndCopyNamehint<XorOp>(rewriter, op, op.getLhs(),
3213 if (rhs.isAllOnes()) {
3220 case ICmpPredicate::ne:
3221 if (rhs.getBitWidth() == 1) {
3227 if (rhs.isAllOnes()) {
3229 replaceOpWithNewOpAndCopyNamehint<XorOp>(rewriter, op, op.getLhs(),
3236 case ICmpPredicate::ceq:
3237 case ICmpPredicate::cne:
3238 case ICmpPredicate::weq:
3239 case ICmpPredicate::wne:
3245 if (op.getPredicate() == ICmpPredicate::eq ||
3246 op.getPredicate() == ICmpPredicate::ne) {
3251 if (!knownBits.isUnknown())
3258 if (
auto xorOp = op.getLhs().getDefiningOp<
XorOp>())
3265 if (
auto replicateOp = op.getLhs().getDefiningOp<ReplicateOp>())
3266 if (rhs.isAllOnes() || rhs.isZero()) {
3267 auto width = replicateOp.getInput().getType().getIntOrFloatBitWidth();
3270 rhs.isAllOnes() ? APInt::getAllOnes(width)
3271 : APInt::getZero(width));
3272 replaceOpWithNewOpAndCopyNamehint<ICmpOp>(
3273 rewriter, op, op.getPredicate(), replicateOp.getInput(), cst,
3283 if (Operation *opLHS = op.getLhs().getDefiningOp())
3284 if (Operation *opRHS = op.getRhs().getDefiningOp())
3285 if (isa<ConcatOp, ReplicateOp>(opLHS) &&
3286 isa<ConcatOp, ReplicateOp>(opRHS)) {
assert(baseType &&"element must be base type")
static SmallVector< T > concat(const SmallVectorImpl< T > &a, const SmallVectorImpl< T > &b)
Returns a new vector containing the concatenation of vectors a and b.
static KnownBits computeKnownBits(Value v, unsigned depth)
Given an integer SSA value, check to see if we know anything about the result of the computation.
static bool foldMuxOfUniformArrays(MuxOp op, PatternRewriter &rewriter)
static Attribute constFoldAssociativeOp(ArrayRef< Attribute > operands, hw::PEO paramOpcode)
static Attribute constFoldBinaryOp(ArrayRef< Attribute > operands, hw::PEO paramOpcode)
Performs constant folding calculate with element-wise behavior on the two attributes in operands and ...
static bool applyCmpPredicateToEqualOperands(ICmpPredicate predicate)
static ComplementMatcher< SubType > m_Complement(const SubType &subExpr)
static bool canonicalizeLogicalCstWithConcat(Operation *logicalOp, size_t concatIdx, const APInt &cst, PatternRewriter &rewriter)
When we find a logical operation (and, or, xor) with a constant e.g.
static bool narrowOperationWidth(OpTy op, bool narrowTrailingBits, PatternRewriter &rewriter)
static OpFoldResult foldDiv(Op op, ArrayRef< Attribute > constants)
static Value getCommonOperand(Op op)
Returns a single common operand that all inputs of the operation op can be traced back to,...
static bool canCombineOppositeBinCmpIntoConstant(OperandRange operands)
static void getConcatOperands(Value v, SmallVectorImpl< Value > &result)
Flatten concat and replicate operands into a vector.
static Value extractOperandFromFullyAssociative(Operation *fullyAssoc, size_t operandNo, PatternRewriter &rewriter)
Given a fully associative variadic operation like (a+b+c+d), break the expression into two parts,...
static bool getMuxChainCondConstant(Value cond, Value indexValue, bool isInverted, std::function< void(hw::ConstantOp)> constantFn)
Check to see if the condition to the specified mux is an equality comparison indexValue and one or mo...
static TypedAttr getIntAttr(const APInt &value, MLIRContext *context)
static bool shouldBeFlattened(Operation *op)
Return true if the op will be flattened afterwards.
static void canonicalizeXorIcmpTrue(XorOp op, unsigned icmpOperand, PatternRewriter &rewriter)
static bool assumeMuxCondInOperand(Value muxCond, Value muxValue, bool constCond, PatternRewriter &rewriter)
If the mux condition is an operand to the op defining its true or false value, replace the condition ...
static bool extractFromReplicate(ExtractOp op, ReplicateOp replicate, PatternRewriter &rewriter)
static void combineEqualityICmpWithXorOfConstant(ICmpOp cmpOp, XorOp xorOp, const APInt &rhs, PatternRewriter &rewriter)
static size_t getTotalWidth(ArrayRef< Value > operands)
static bool foldCommonMuxOperation(MuxOp mux, Operation *trueOp, Operation *falseOp, PatternRewriter &rewriter)
This function is invoked when we find a mux with true/false operations that have the same opcode.
static std::pair< size_t, size_t > getLowestBitAndHighestBitRequired(Operation *op, bool narrowTrailingBits, size_t originalOpWidth)
static bool tryFlatteningOperands(Operation *op, PatternRewriter &rewriter)
Flattens a single input in op if hasOneUse is true and it can be defined as an Op.
static bool isOpTriviallyRecursive(Operation *op)
static bool canonicalizeIdempotentInputs(Op op, PatternRewriter &rewriter)
Canonicalize an idempotent operation op so that only one input of any kind occurs.
static bool applyCmpPredicate(ICmpPredicate predicate, const APInt &lhs, const APInt &rhs)
static void combineEqualityICmpWithKnownBitsAndConstant(ICmpOp cmpOp, const KnownBits &bitAnalysis, const APInt &rhsCst, PatternRewriter &rewriter)
Given an equality comparison with a constant value and some operand that has known bits,...
static bool hasSVAttributes(Operation *op)
static LogicalResult extractConcatToConcatExtract(ExtractOp op, ConcatOp innerCat, PatternRewriter &rewriter)
static OpFoldResult foldMod(Op op, ArrayRef< Attribute > constants)
static size_t computeCommonPrefixLength(const Range &a, const Range &b)
static bool foldCommonMuxValue(MuxOp op, bool isTrueOperand, PatternRewriter &rewriter)
Fold things like mux(cond, x|y|z|a, a) -> (x|y|z)&replicate(cond)|a and mux(cond, a,...
static LogicalResult matchAndRewriteCompareConcat(ICmpOp op, Operation *lhs, Operation *rhs, PatternRewriter &rewriter)
Reduce the strength of icmp(concat(...), concat(...)) by doing an element-wise comparison on common prefi...
static Value createGenericOp(Location loc, OperationName name, ArrayRef< Value > operands, OpBuilder &builder)
Create a new instance of a generic operation that only has value operands, and has a single result va...
static TypedAttr getIntAttr(MLIRContext *ctx, Type t, const APInt &value)
static std::optional< APSInt > getConstant(Attribute operand)
Determine the value of a constant operand for the sake of constant folding.
create(array_value, low_index, ret_type)
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
void extractBits(OpBuilder &builder, Value val, SmallVectorImpl< Value > &bits)
Extract bits from a value.
bool foldMuxChainWithComparison(PatternRewriter &rewriter, MuxOp rootMux, bool isFalseSide, llvm::function_ref< MuxChainWithComparisonFoldingStyle(size_t indexWidth, size_t numEntries)> styleFn)
Mux chain folding that converts chains of muxes with index comparisons into array operations or balan...
Value createOrFoldNot(Location loc, Value value, OpBuilder &builder, bool twoState=false)
Create a 'Not' gate on a value.
MuxChainWithComparisonFoldingStyle
Enum for mux chain folding styles.
KnownBits computeKnownBits(Value value)
Compute "known bits" information about the specified value - the set of bits that are guaranteed to a...
Value constructMuxTree(OpBuilder &builder, Location loc, ArrayRef< Value > selectors, ArrayRef< Value > leafNodes, Value outOfBoundsValue)
Construct a mux tree for given leaf nodes.
uint64_t getWidth(Type t)
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
void replaceOpAndCopyNamehint(PatternRewriter &rewriter, Operation *op, Value newValue)
A wrapper of PatternRewriter::replaceOp to propagate "sv.namehint" attribute.