13#include "mlir/IR/Matchers.h"
14#include "mlir/IR/PatternMatch.h"
15#include "llvm/ADT/SetVector.h"
16#include "llvm/ADT/SmallBitVector.h"
17#include "llvm/ADT/TypeSwitch.h"
18#include "llvm/Support/KnownBits.h"
23using namespace matchers;
// Fragment: the enclosing function's signature lies outside this excerpt.
// Yields true when at least one operand of `op` is defined by `op` itself.
27 return llvm::any_of(op->getOperands(), [op](
auto operand) {
28 return operand.getDefiningOp() == op;
// Fragment: the start of this signature (where `loc` and `name` are declared)
// is not visible in this excerpt.
38 ArrayRef<Value> operands, OpBuilder &builder) {
// Describe the operation to build from its location and name.
39 OperationState state(loc, name);
40 state.addOperands(operands);
// The single result type is copied from the first operand, so `operands`
// must be non-empty.
41 state.addTypes(operands[0].getType());
// Materialize the operation and hand back its first result.
42 return builder.create(state)->getResult(0);
// Builds an IntegerAttr over an integer type whose width equals
// `value.getBitWidth()`. The remaining arguments of the IntegerAttr::get call
// are outside this excerpt.
45static TypedAttr
getIntAttr(
const APInt &value, MLIRContext *context) {
46 return IntegerAttr::get(IntegerType::get(context, value.getBitWidth()),
// NOTE(review): the repeated signature with an elided body below looks like an
// extraction artifact rather than a second definition - confirm.
45static TypedAttr
getIntAttr(
const APInt &value, MLIRContext *context) {
…}
// Fragment: signature and loop bodies are outside this excerpt.
// Visits every operand of `concat`.
53 for (
auto op :
concat.getOperands())
55 }
// When `v` is produced by a ReplicateOp, the loop below runs once per
// replication (getMultiple() iterations).
else if (
auto repl = v.getDefiningOp<ReplicateOp>()) {
56 for (
size_t i = 0, e = repl.getMultiple(); i != e; ++i)
// Fragment: reports whether `op` carries an attribute named "sv.attributes".
67 return op->hasAttr(
"sv.attributes");
// Pattern matcher parameterized on a sub-matcher `SubType`. The member
// declaration and closing braces are outside this excerpt.
71template <
typename SubType>
72struct ComplementMatcher {
// Takes ownership of the sub-matcher applied to the complemented value.
74 ComplementMatcher(SubType lhs) : lhs(std::move(lhs)) {}
// Succeeds only when `op` is an XorOp for which isBinaryNot() holds and the
// sub-matcher accepts operand 0.
75 bool match(Operation *op) {
76 auto xorOp = dyn_cast<XorOp>(op);
77 return xorOp && xorOp.isBinaryNot() && lhs.match(op->getOperand(0));
// Convenience factory: wraps `subExpr` in a ComplementMatcher so the template
// argument is deduced at the call site.
82template <
typename SubType>
83static inline ComplementMatcher<SubType>
m_Complement(
const SubType &subExpr) {
84 return ComplementMatcher<SubType>(subExpr);
// NOTE(review): the repeated signature with an elided body below looks like an
// extraction artifact rather than a second definition - confirm.
83static inline ComplementMatcher<SubType>
m_Complement(
const SubType &subExpr) {
…}
// Fragment: signature and trailing return are outside this excerpt.
// Only And/Or/Xor/Add/Mul are accepted here.
90 assert((isa<AndOp, OrOp, XorOp, AddOp, MulOp>(op) &&
91 "must be commutative operations"));
// With exactly one use, the answer is whether that sole user has the same
// operation name, the same "twoState" unit attribute, and lives in the same
// block as `op`.
92 if (op->hasOneUse()) {
93 auto *user = *op->getUsers().begin();
94 return user->getName() == op->getName() &&
95 op->getAttrOfType<UnitAttr>(
"twoState") ==
96 user->getAttrOfType<UnitAttr>(
"twoState") &&
97 op->getBlock() == user->getBlock();
// Fragment: signature, the `Element` type header and several branch bodies are
// outside this excerpt. The visible code walks the operand tree of `op` with
// an explicit worklist and collects a flat operand list plus the locations of
// every operation that was folded in.
112 auto inputs = op->getOperands();
114 SmallVector<Value, 4> newOperands;
// Start the fused location list with the root operation's own location.
115 SmallVector<Location, 4> newLocations{op->getLoc()};
116 newOperands.reserve(inputs.size());
// Iterator pair describing the not-yet-visited operands of one operation.
118 decltype(inputs.begin()) current, end;
121 SmallVector<Element> worklist;
122 worklist.push_back({inputs.begin(), inputs.end()});
// Remember whether the root operation carries the "twoState" unit attribute.
123 bool binFlag = op->hasAttrOfType<UnitAttr>(
"twoState");
124 bool changed =
false;
125 while (!worklist.empty()) {
126 auto &element = worklist.back();
// This operand range is exhausted.
129 if (element.current == element.end) {
// Take the next operand and advance the range in place.
134 Value value = *element.current++;
135 auto *flattenOp = value.getDefiningOp();
// Keep `value` as a plain operand when it has no defining op, is a different
// kind of operation, is the root itself, or lives in another block.
// NOTE(review): `binFlag` was computed from `op` above, so comparing it with
// `op->hasAttrOfType<UnitAttr>("twoState")` again can never be true; the check
// was presumably meant to inspect `flattenOp` - confirm.
138 if (!flattenOp || flattenOp->getName() != op->getName() ||
139 flattenOp == op || binFlag != op->hasAttrOfType<UnitAttr>(
"twoState") ||
140 flattenOp->getBlock() != op->getBlock()) {
141 newOperands.push_back(value);
146 if (!value.hasOneUse()) {
154 if (flattenOp->getNumOperands() != 2 || !isa<AndOp, OrOp, XorOp>(op) ||
157 newOperands.push_back(value);
// Descend into the nested operation: queue its operands and record its
// location for the fused location of the replacement.
165 auto flattenOpInputs = flattenOp->getOperands();
166 worklist.push_back({flattenOpInputs.begin(), flattenOpInputs.end()});
167 newLocations.push_back(flattenOp->getLoc());
// Build one operation of the same name over the flattened operands.
173 Value result =
createGenericOp(FusedLoc::get(op->getContext(), newLocations),
174 op->getName(), newOperands, rewriter);
// Mark the replacement as "twoState" (the guarding condition is not visible).
176 result.getDefiningOp()->setAttr(
"twoState", rewriter.getUnitAttr());
// Computes the inclusive {lowest, highest} bit range of `op`'s result that its
// users need. Part of the signature and the assert condition are outside this
// excerpt.
184static std::pair<size_t, size_t>
186 size_t originalOpWidth) {
187 auto users = op->getUsers();
189 "getLowestBitAndHighestBitRequired cannot operate on "
190 "a empty list of uses.");
// Seed the low bound at the top bit only when trailing bits may be narrowed;
// otherwise it stays pinned at bit 0.
194 size_t lowestBitRequired = narrowTrailingBits ? originalOpWidth - 1 : 0;
195 size_t highestBitRequired = 0;
197 for (
auto *user : users) {
// An ExtractOp user needs bits [lowBit, lowBit + width - 1]; grow the range to
// cover them.
198 if (
auto extractOp = dyn_cast<ExtractOp>(user)) {
199 size_t lowBit = extractOp.getLowBit();
201 cast<IntegerType>(extractOp.getType()).getWidth() + lowBit - 1;
202 highestBitRequired = std::max(highestBitRequired, highBit);
203 lowestBitRequired = std::min(lowestBitRequired, lowBit);
// Full-width fallback; presumably taken for users that are not ExtractOps
// (the branch structure is not visible) - confirm.
207 highestBitRequired = originalOpWidth - 1;
208 lowestBitRequired = 0;
212 return {lowestBitRequired, highestBitRequired};
// Fragment: the start of the signature and the computation of `range` are
// outside this excerpt. The visible code rebuilds `op` on a narrower integer
// type covering bits [range.first, range.second] and pads the result with
// zero constants back to the original width.
217 PatternRewriter &rewriter) {
// NOTE(review): dyn_cast can yield a null type; any null check is in lines not
// shown here - confirm before relying on `opType` below.
218 IntegerType opType = dyn_cast<IntegerType>(op.getResult().getType());
// The required range already spans the whole result.
224 if (range.second + 1 == opType.getWidth() && range.first == 0)
227 SmallVector<Value> args;
228 auto newType = rewriter.getIntegerType(range.second - range.first + 1);
229 for (
auto inop : op.getOperands()) {
// Operands whose type differs from the result type are forwarded untouched;
// the others are narrowed through an ExtractOp.
231 if (inop.getType() != op.getType())
232 args.push_back(inop);
234 args.push_back(rewriter.createOrFold<
ExtractOp>(inop.getLoc(), newType,
// Recreate the operation on the narrow type, carrying over dialect attributes
// and the two-state flag.
237 auto newop = OpTy::create(rewriter, op.getLoc(), newType, args);
238 newop->setDialectAttrs(op->getDialectAttrs());
239 if (op.getTwoState())
240 newop.setTwoState(
true);
242 Value newResult = newop.getResult();
// Pad with `range.first` zero bits (guarding condition not visible).
244 newResult = rewriter.createOrFold<
ConcatOp>(
245 op.getLoc(), newResult,
247 APInt::getZero(range.first)));
// Pad with zeros for the bits above `range.second`.
248 if (range.second + 1 < opType.getWidth())
249 newResult = rewriter.createOrFold<
ConcatOp>(
252 rewriter, op.getLoc(),
253 APInt::getZero(opType.getWidth() - range.second - 1)),
255 rewriter.replaceOp(op, newResult);
// Folder for ReplicateOp. Return statements and closing braces are partly
// outside this excerpt.
263OpFoldResult ReplicateOp::fold(FoldAdaptor adaptor) {
// Result width equals input width, i.e. the input is replicated exactly once.
268 if (cast<IntegerType>(getType()).
getWidth() ==
269 getInput().getType().getIntOrFloatBitWidth())
// Constant input.
273 if (
auto input = dyn_cast_or_null<IntegerAttr>(adaptor.getInput())) {
// One-bit constants have dedicated all-zeros / all-ones results at the result
// width.
274 if (input.getValue().getBitWidth() == 1) {
275 if (input.getValue().isZero())
277 APInt::getZero(cast<IntegerType>(getType()).
getWidth()),
280 APInt::getAllOnes(cast<IntegerType>(getType()).
getWidth()),
// Wider constants: concatenate the input value getMultiple() times, starting
// from a zero-width value.
284 APInt result = APInt::getZeroWidth();
285 for (
auto i = getMultiple(); i != 0; --i)
286 result = result.concat(input.getValue());
// Folder for ParityOp.
293OpFoldResult ParityOp::fold(FoldAdaptor adaptor) {
// Constant input: the result is a 1-bit constant holding the low bit of the
// input's population count.
298 if (
auto input = dyn_cast_or_null<IntegerAttr>(adaptor.getInput()))
299 return getIntAttr(APInt(1, input.getValue().popcount() & 1), getContext());
// Fragment: the start of the signature and the intermediate folding steps are
// outside this excerpt.
311 hw::PEO paramOpcode) {
312 assert(operands.size() == 2 &&
"binary op takes two operands");
// Both operands must be known attributes for anything to fold.
313 if (!operands[0] || !operands[1])
// Build a parameter expression with `paramOpcode` over the two operands.
318 return hw::ParamExprAttr::get(paramOpcode, cast<TypedAttr>(operands[0]),
319 cast<TypedAttr>(operands[1]));
// Folder for ShlOp; only the constant-shift-amount cases are visible here.
322OpFoldResult ShlOp::fold(FoldAdaptor adaptor) {
326 if (
auto rhs = dyn_cast_or_null<IntegerAttr>(adaptor.getRhs())) {
// Shifting by zero yields the first operand unchanged.
327 if (rhs.getValue().isZero())
328 return getOperand(0);
// A shift amount of at least the result width yields an all-zero constant.
330 unsigned width = getType().getIntOrFloatBitWidth();
331 if (rhs.getValue().uge(width))
332 return getIntAttr(APInt::getZero(width), getContext());
// Canonicalizer for ShlOp with a constant shift amount. Early-exit returns and
// the construction of `extract` / `zeros` are outside this excerpt.
337LogicalResult ShlOp::canonicalize(
ShlOp op, PatternRewriter &rewriter) {
// Only applies when the shift amount is a constant integer.
343 if (!matchPattern(op.getRhs(), m_ConstantInt(&value)))
346 unsigned width = cast<IntegerType>(op.getLhs().getType()).getWidth();
// Checked before getZExtValue() is called on `value`.
347 if (value.ugt(width))
349 unsigned shift = value.getZExtValue();
352 if (width <= shift || shift == 0)
// Replace the shift with a concat of `extract` followed by `zeros`.
362 replaceOpWithNewOpAndCopyNamehint<ConcatOp>(rewriter, op, extract, zeros);
// Folder for ShrUOp; only the constant-shift-amount cases are visible here.
366OpFoldResult ShrUOp::fold(FoldAdaptor adaptor) {
370 if (
auto rhs = dyn_cast_or_null<IntegerAttr>(adaptor.getRhs())) {
// Shifting by zero yields the first operand unchanged.
371 if (rhs.getValue().isZero())
372 return getOperand(0);
// A shift amount of at least the result width yields an all-zero constant.
374 unsigned width = getType().getIntOrFloatBitWidth();
375 if (rhs.getValue().uge(width))
376 return getIntAttr(APInt::getZero(width), getContext());
// Canonicalizer for ShrUOp with a constant shift amount. Early-exit returns
// and the construction of `zeros` / `extract` are outside this excerpt.
381LogicalResult ShrUOp::canonicalize(
ShrUOp op, PatternRewriter &rewriter) {
// Only applies when the shift amount is a constant integer.
387 if (!matchPattern(op.getRhs(), m_ConstantInt(&value)))
390 unsigned width = cast<IntegerType>(op.getLhs().getType()).getWidth();
// Checked before getZExtValue() is called on `value`.
391 if (value.ugt(width))
393 unsigned shift = value.getZExtValue();
396 if (width <= shift || shift == 0)
// Replace the shift with a concat of `zeros` followed by `extract`.
406 replaceOpWithNewOpAndCopyNamehint<ConcatOp>(rewriter, op, zeros, extract);
// Folder for ShrSOp; only the shift-by-zero case is visible here.
410OpFoldResult ShrSOp::fold(FoldAdaptor adaptor) {
// A constant zero shift amount yields the first operand unchanged.
414 if (
auto rhs = dyn_cast_or_null<IntegerAttr>(adaptor.getRhs()))
415 if (rhs.getValue().isZero())
416 return getOperand(0);
// Canonicalizer for ShrSOp with a constant shift amount. Early-exit returns,
// the declaration receiving the top bit, and the construction of `extract`
// are outside this excerpt.
420LogicalResult ShrSOp::canonicalize(
ShrSOp op, PatternRewriter &rewriter) {
// Only applies when the shift amount is a constant integer.
426 if (!matchPattern(op.getRhs(), m_ConstantInt(&value)))
429 unsigned width = cast<IntegerType>(op.getLhs().getType()).getWidth();
// Checked before getZExtValue() is called on `value`.
430 if (value.ugt(width))
432 unsigned shift = value.getZExtValue();
// Extract the single most significant bit of the left-hand side.
435 rewriter.createOrFold<
ExtractOp>(op.getLoc(), op.getLhs(), width - 1, 1);
// Replicate that bit `shift` times to form the sign extension.
436 auto sext = rewriter.createOrFold<ReplicateOp>(op.getLoc(), topbit, shift);
// Shifting by the full width is handled separately (body not visible).
438 if (width == shift) {
// Otherwise the result is the sign bits concatenated with `extract`.
446 replaceOpWithNewOpAndCopyNamehint<ConcatOp>(rewriter, op, sext, extract);
// Folder for ExtractOp. Return statements are partly outside this excerpt.
454OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
// Input and result types are identical, i.e. the extract covers every bit.
459 if (getInput().getType() == getType())
// Constant input: shift right by the low bit, then truncate to the result
// width.
463 if (
auto input = dyn_cast_or_null<IntegerAttr>(adaptor.getInput())) {
464 unsigned dstWidth = cast<IntegerType>(getType()).getWidth();
465 return getIntAttr(input.getValue().lshr(getLowBit()).trunc(dstWidth),
// Fragment: the start of the signature and several statements are outside this
// excerpt. The visible code rewrites an extract of a concat (`innerCat`) into
// a concat of per-operand pieces.
476 PatternRewriter &rewriter) {
// Walk the concat operands in reverse order while tracking the bit offset at
// which the current operand starts.
477 auto reversedConcatArgs = llvm::reverse(innerCat.getInputs());
478 size_t beginOfFirstRelevantElement = 0;
479 auto it = reversedConcatArgs.begin();
480 size_t lowBit = op.getLowBit();
// Phase 1: advance to the operand that contains `lowBit`.
483 for (; it != reversedConcatArgs.end(); it++) {
484 assert(beginOfFirstRelevantElement <= lowBit &&
485 "incorrectly moved past an element that lowBit has coverage over");
488 size_t operandWidth = operand.getType().getIntOrFloatBitWidth();
489 if (lowBit < beginOfFirstRelevantElement + operandWidth) {
513 beginOfFirstRelevantElement += operandWidth;
515 assert(it != reversedConcatArgs.end() &&
516 "incorrectly failed to find an element which contains coverage of "
// Phase 2: consume operands until the requested result width is covered.
519 SmallVector<Value> reverseConcatArgs;
520 size_t widthRemaining = cast<IntegerType>(op.getType()).getWidth();
// Offset of the first requested bit within the operand found above.
521 size_t extractLo = lowBit - beginOfFirstRelevantElement;
526 for (; widthRemaining != 0 && it != reversedConcatArgs.end(); it++) {
527 auto concatArg = *it;
528 size_t operandWidth = concatArg.getType().getIntOrFloatBitWidth();
529 size_t widthToConsume = std::min(widthRemaining, operandWidth - extractLo);
// An operand consumed in full is reused as is; otherwise a narrower piece of
// `widthToConsume` bits is pushed instead.
531 if (widthToConsume == operandWidth && extractLo == 0) {
532 reverseConcatArgs.push_back(concatArg);
534 auto resultType = IntegerType::get(rewriter.getContext(), widthToConsume);
535 reverseConcatArgs.push_back(
539 widthRemaining -= widthToConsume;
// A single collected piece is special-cased (body not visible).
545 if (reverseConcatArgs.size() == 1) {
// Pieces were collected in reverse order, so reverse them again for the new
// concat.
548 replaceOpWithNewOpAndCopyNamehint<ConcatOp>(
549 rewriter, op, SmallVector<Value>(llvm::reverse(reverseConcatArgs)));
// Fragment: the start of the signature and the returns are outside this
// excerpt. The visible code simplifies an extract whose input is `replicate`.
556 PatternRewriter &rewriter) {
557 auto extractResultWidth = cast<IntegerType>(op.getType()).getWidth();
// Width of the value being replicated.
558 auto replicateEltWidth =
559 replicate.getOperand().getType().getIntOrFloatBitWidth();
// Extract starts on an element boundary and spans whole elements: emit a
// ReplicateOp of the result type over the same operand.
563 if (op.getLowBit() % replicateEltWidth == 0 &&
564 extractResultWidth % replicateEltWidth == 0) {
565 replaceOpWithNewOpAndCopyNamehint<ReplicateOp>(rewriter, op, op.getType(),
566 replicate.getOperand());
// Otherwise the extract may be re-based onto the replicated operand itself,
// using the low bit modulo the element width (the right-hand side of this
// comparison is not visible).
572 if (op.getLowBit() % replicateEltWidth + extractResultWidth <=
574 replaceOpWithNewOpAndCopyNamehint<ExtractOp>(
575 rewriter, op, op.getType(), replicate.getOperand(),
576 op.getLowBit() % replicateEltWidth);
// Canonicalizer for ExtractOp. Several statements (returns, helper calls and
// some declarations) are outside this excerpt.
585LogicalResult ExtractOp::canonicalize(
ExtractOp op, PatternRewriter &rewriter) {
// May be null; the casts below use dyn_cast_or_null for that reason.
588 auto *inputOp = op.getInput().getDefiningOp();
// Known-bits information restricted to the extracted range; when every bit is
// known, the extract becomes a constant.
595 .extractBits(cast<IntegerType>(op.getType()).getWidth(),
597 if (knownBits.isConstant()) {
598 replaceOpWithNewOpAndCopyNamehint<hw::ConstantOp>(rewriter, op,
599 knownBits.getConstant());
// extract(extract(x)) -> a single extract from x with the low bits summed.
605 if (
auto innerExtract = dyn_cast_or_null<ExtractOp>(inputOp)) {
606 replaceOpWithNewOpAndCopyNamehint<ExtractOp>(
607 rewriter, op, op.getType(), innerExtract.getInput(),
608 innerExtract.getLowBit() + op.getLowBit());
// Extract of a concat (handling not visible).
613 if (
auto innerCat = dyn_cast_or_null<ConcatOp>(inputOp))
// Extract of a replicate (handling not visible).
617 if (
auto replicate = dyn_cast_or_null<ReplicateOp>(inputOp))
// Extract of a two-operand and/or/xor whose second operand is a constant.
623 if (inputOp && inputOp->getNumOperands() == 2 &&
624 isa<AndOp, OrOp, XorOp>(inputOp)) {
625 if (
auto cstRHS = inputOp->getOperand(1).getDefiningOp<
hw::ConstantOp>()) {
// Slice of the constant that overlaps the extracted bits.
626 auto extractedCst = cstRHS.getValue().extractBits(
627 cast<IntegerType>(op.getType()).getWidth(), op.getLowBit());
// or/xor with an all-zero slice: extract directly from the other operand.
628 if (isa<OrOp, XorOp>(inputOp) && extractedCst.isZero()) {
629 replaceOpWithNewOpAndCopyNamehint<ExtractOp>(
630 rewriter, op, op.getType(), inputOp->getOperand(0), op.getLowBit());
638 if (isa<AndOp>(inputOp)) {
// The set bits of the slice form one contiguous run exactly when the width
// minus leading and trailing zeros equals the population count.
641 unsigned lz = extractedCst.countLeadingZeros();
642 unsigned tz = extractedCst.countTrailingZeros();
643 unsigned pop = extractedCst.popcount();
644 if (extractedCst.getBitWidth() - lz - tz == pop) {
// Result: `lz` zero bits, the `pop` live bits of operand 0, `tz` zero bits.
645 auto resultTy = rewriter.getIntegerType(pop);
646 SmallVector<Value> resultElts;
649 APInt::getZero(lz)));
650 resultElts.push_back(rewriter.createOrFold<
ExtractOp>(
651 op.getLoc(), resultTy, inputOp->getOperand(0),
652 op.getLowBit() + tz));
655 APInt::getZero(tz)));
656 replaceOpWithNewOpAndCopyNamehint<ConcatOp>(rewriter, op, resultElts);
// Single-bit extract of a shift-left used only here, whose left-hand constant
// is one: becomes an equality compare of the shift amount against a constant
// holding the extracted bit index.
665 if (cast<IntegerType>(op.getType()).getWidth() == 1 && inputOp)
666 if (
auto shlOp = dyn_cast<ShlOp>(inputOp)) {
668 if (shlOp->hasOneUse())
670 if (lhsCst.getValue().isOne()) {
672 rewriter, shlOp.getLoc(),
673 APInt(lhsCst.getValue().getBitWidth(), op.getLowBit()));
674 replaceOpWithNewOpAndCopyNamehint<ICmpOp>(
675 rewriter, op, ICmpPredicate::eq, shlOp->getOperand(1), newCst,
// Fragment: the start of the signature and the integer-folding path are
// outside this excerpt.
691 hw::PEO paramOpcode) {
692 assert(operands.size() > 1 &&
"caller should handle one-operand case");
// The first two operands must both be known attributes.
695 if (!operands[1] || !operands[0])
// When every remaining operand is known as well, try to build a parameter
// expression over all of them.
699 if (llvm::all_of(operands.drop_front(2),
700 [&](Attribute in) { return !!in; })) {
701 SmallVector<mlir::TypedAttr> typedOperands;
702 typedOperands.reserve(operands.size());
703 for (
auto operand : operands) {
704 if (
auto typedOperand = dyn_cast<mlir::TypedAttr>(operand))
705 typedOperands.push_back(typedOperand);
// Only succeed when every operand was a TypedAttr.
709 if (typedOperands.size() == operands.size())
710 return hw::ParamExprAttr::get(paramOpcode, typedOperands);
// Fragment: the start of the signature and several statements are outside this
// excerpt. The visible code distributes a logical operation with constant
// `cst` over the operands of the concat found at operand `concatIdx`.
726 size_t concatIdx,
const APInt &cst,
727 PatternRewriter &rewriter) {
728 auto concatOp = logicalOp->getOperand(concatIdx).getDefiningOp<
ConcatOp>();
// Callers must pass an and/or/xor whose operand at `concatIdx` is a concat.
729 assert((isa<AndOp, OrOp, XorOp>(logicalOp) && concatOp));
// Profitability probe over the concat operands: a constant operand, or a
// single-use operation of the same kind whose last operand is a constant.
734 llvm::any_of(concatOp->getOperands(), [&](Value operand) ->
bool {
735 auto *operandOp = operand.getDefiningOp();
740 if (isa<hw::ConstantOp>(operandOp))
744 return operandOp->getName() == logicalOp->getName() &&
745 operandOp->hasOneUse() && operandOp->getNumOperands() != 0 &&
746 operandOp->getOperands().back().getDefiningOp<hw::ConstantOp>();
// Helper that creates an operation of the same kind as `logicalOp`.
754 auto createLogicalOp = [&](ArrayRef<Value> operands) -> Value {
755 return createGenericOp(logicalOp->getLoc(), logicalOp->getName(), operands,
762 SmallVector<Value> newConcatOperands;
763 newConcatOperands.reserve(concatOp->getNumOperands());
// Walk the concat operands while counting bit positions down from the total
// width; each operand is paired with the matching slice of `cst`.
766 size_t nextOperandBit = concatOp.getType().getIntOrFloatBitWidth();
767 for (Value operand : concatOp->getOperands()) {
768 size_t operandWidth = operand.getType().getIntOrFloatBitWidth();
769 nextOperandBit -= operandWidth;
773 cst.lshr(nextOperandBit).trunc(operandWidth));
775 newConcatOperands.push_back(createLogicalOp({operand, eltCst}));
780 ConcatOp::create(rewriter, concatOp.getLoc(), newConcatOperands);
// With more than two original operands, rebuild the logical op from every
// operand except the concat (at `concatIdx`) and the last one, then append the
// new concat result.
784 if (logicalOp->getNumOperands() > 2) {
785 auto origOperands = logicalOp->getOperands();
786 SmallVector<Value> operands;
788 operands.append(origOperands.begin(), origOperands.begin() + concatIdx);
790 operands.append(origOperands.begin() + concatIdx + 1,
791 origOperands.begin() + (origOperands.size() - 1));
793 operands.push_back(newResult);
794 newResult = createLogicalOp(operands);
// Fragment: signature and return statements are outside this excerpt. The
// visible code looks for two two-state icmp operands that compare the same
// lhs/rhs pair with mutually negated predicates.
804 llvm::SmallDenseSet<std::tuple<ICmpPredicate, Value, Value>> seenPredicates;
806 for (
auto op : operands) {
// Only two-state ICmpOps take part.
807 if (
auto icmpOp = op.getDefiningOp<ICmpOp>();
808 icmpOp && icmpOp.getTwoState()) {
809 auto predicate = icmpOp.getPredicate();
810 auto lhs = icmpOp.getLhs();
811 auto rhs = icmpOp.getRhs();
// Has the negated form of this comparison been seen already?
812 if (seenPredicates.contains(
813 {ICmpOp::getNegatedPredicate(predicate), lhs, rhs}))
816 seenPredicates.insert({predicate, lhs, rhs});
// Folder for AndOp. Several guards and return statements are outside this
// excerpt.
822OpFoldResult AndOp::fold(FoldAdaptor adaptor) {
// Accumulator starts as all ones at the result width.
826 APInt value = APInt::getAllOnes(cast<IntegerType>(getType()).
getWidth());
828 auto inputs = adaptor.getInputs();
// AND the constant operands into the accumulator.
831 for (
auto operand : inputs) {
834 value &= cast<IntegerAttr>(operand).getValue();
// and(x, all-ones) -> x.
840 if (inputs.size() == 2 && inputs[1] &&
841 cast<IntegerAttr>(inputs[1]).getValue().isAllOnes())
842 return getInputs()[0];
// Every operand is the same value -> that value.
845 if (llvm::all_of(getInputs(),
846 [&](
auto in) {
return in == this->getInputs()[0]; }))
847 return getInputs()[0];
// Pairwise scan of the operands that can yield an all-zero constant (the
// matching condition is not visible).
850 for (Value arg : getInputs()) {
853 for (Value arg2 : getInputs())
856 APInt::getZero(cast<IntegerType>(getType()).
getWidth()),
// Fragment: the function signature and early returns are outside this excerpt.
// For a 1-bit operation whose inputs are all single-position extracts from one
// value, returns that value when every bit position is covered; otherwise
// returns a null Value.
877template <
typename Op>
// Only 1-bit results are considered.
879 if (!op.getType().isInteger(1))
882 auto inputs = op.getInputs();
883 size_t size = inputs.size();
// The first input determines the candidate source.
885 auto sourceOp = inputs[0].template getDefiningOp<ExtractOp>();
888 Value source = sourceOp.getOperand();
// The number of inputs must equal the bit width of the source.
891 if (size != source.getType().getIntOrFloatBitWidth())
// Track which bit positions of the source have been seen.
895 llvm::BitVector bits(size);
896 bits.set(sourceOp.getLowBit());
898 for (
size_t i = 1; i != size; ++i) {
899 auto extractOp = inputs[i].template getDefiningOp<ExtractOp>();
900 if (!extractOp || extractOp.getOperand() != source)
902 bits.set(extractOp.getLowBit());
905 return bits.all() ? source : Value();
// Fragment: the function signature, the `OpWithDepth` type and the returns are
// outside this excerpt. The visible code drops inputs of `op` that already
// appear as inputs of nested operations of the same kind.
912template <
typename Op>
// Maximum nesting depth explored.
915 constexpr unsigned limit = 3;
916 auto inputs = op.getInputs();
// Deduplicated inputs in original order; entries are removed as they are found
// nested below.
918 llvm::SmallSetVector<Value, 8> uniqueInputs(inputs.begin(), inputs.end());
919 llvm::SmallDenseSet<Op, 8> checked;
926 llvm::SmallVector<OpWithDepth, 8> worklist;
// Queue `input`'s defining op when it is within the depth limit, in the same
// block, of the same kind with the same two-state flag, and not yet visited.
928 auto enqueue = [&worklist, &checked, &op](Value input,
unsigned depth) {
932 if (depth < limit && input.getParentBlock() == op->getBlock()) {
933 auto inputOp = input.template getDefiningOp<Op>();
934 if (inputOp && inputOp.getTwoState() == op.getTwoState() &&
935 checked.insert(inputOp).second)
936 worklist.push_back({inputOp, depth + 1});
940 for (
auto input : uniqueInputs)
943 while (!worklist.empty()) {
944 auto item = worklist.pop_back_val();
// Anything consumed by a nested op is redundant at the top level.
946 for (
auto input : item.op.getInputs()) {
947 uniqueInputs.remove(input);
948 enqueue(input, item.depth);
// Rewrite only when at least one input was dropped.
952 if (uniqueInputs.size() < inputs.size()) {
953 replaceOpWithNewOpAndCopyNamehint<Op>(rewriter, op, op.getType(),
954 uniqueInputs.getArrayRef(),
// Canonicalizer for AndOp. Returns, helper calls and some declarations are
// outside this excerpt.
962LogicalResult AndOp::canonicalize(
AndOp op, PatternRewriter &rewriter) {
966 auto inputs = op.getInputs();
967 auto size = inputs.size();
979 assert(size > 1 &&
"expected 2 or more operands, `fold` should handle this");
// Patterns keyed on a constant in the last operand position.
983 if (matchPattern(inputs.back(), m_ConstantInt(&value))) {
// and(..., all-ones) -> and(...).
985 if (value.isAllOnes()) {
986 replaceOpWithNewOpAndCopyNamehint<AndOp>(rewriter, op, op.getType(),
987 inputs.drop_back(),
false);
// Two trailing constants are replaced by the single value `cst` (its
// computation is not visible).
995 if (matchPattern(inputs[size - 2], m_ConstantInt(&value2))) {
997 SmallVector<Value, 4> newOperands(inputs.drop_back(2));
998 newOperands.push_back(cst);
999 replaceOpWithNewOpAndCopyNamehint<AndOp>(rewriter, op, op.getType(),
1000 newOperands,
false);
// and(replicate(1-bit x), single-set-bit constant) -> concat of high zeros,
// x, and low zeros.
1005 if (size == 2 && value.isPowerOf2()) {
1010 if (
auto replicate = inputs[0].getDefiningOp<ReplicateOp>()) {
1011 auto replicateOperand = replicate.getOperand();
1012 if (replicateOperand.getType().isInteger(1)) {
1013 unsigned resultWidth = op.getType().getIntOrFloatBitWidth();
1014 auto trailingZeros = value.countTrailingZeros();
1017 SmallVector<Value, 3> concatOperands;
// High zeros are needed unless the set bit is the top bit.
1018 if (trailingZeros != resultWidth - 1) {
1020 rewriter, op.getLoc(),
1021 APInt::getZero(resultWidth - trailingZeros - 1));
1022 concatOperands.push_back(highZeros);
1024 concatOperands.push_back(replicateOperand);
// Low zeros are needed unless the set bit is bit 0.
1025 if (trailingZeros != 0) {
1027 rewriter, op.getLoc(), APInt::getZero(trailingZeros));
1028 concatOperands.push_back(lowZeros);
1030 replaceOpWithNewOpAndCopyNamehint<ConcatOp>(
1031 rewriter, op, op.getType(), concatOperands);
// Mask with leading and/or trailing zeros: perform the and on the narrow
// middle slice only, then pad the result with zero constants.
1040 unsigned leadingZeros = value.countLeadingZeros();
1041 unsigned trailingZeros = value.countTrailingZeros();
1042 if (leadingZeros > 0 || trailingZeros > 0) {
1043 unsigned maskLength = value.getBitWidth() - leadingZeros - trailingZeros;
1046 SmallVector<Value> operands;
1047 for (
auto input : inputs.drop_back()) {
// Look through chains of extracts, accumulating their low-bit offsets.
1048 unsigned offset = trailingZeros;
1049 while (
auto extractOp = input.getDefiningOp<
ExtractOp>()) {
1050 input = extractOp.getInput();
1051 offset += extractOp.getLowBit();
1054 offset, maskLength));
// The narrowed mask is only added when it is not all ones.
1058 auto narrowMask = value.extractBits(maskLength, trailingZeros);
1059 if (!narrowMask.isAllOnes())
1061 rewriter, inputs.back().getLoc(), narrowMask));
1064 Value narrowValue = operands.back();
1065 if (operands.size() > 1)
1067 AndOp::create(rewriter, op.getLoc(), operands, op.getTwoState());
// Reassemble: leading zeros, narrow value, trailing zeros.
1071 if (leadingZeros > 0)
1073 rewriter, op.getLoc(), APInt::getZero(leadingZeros)));
1074 operands.push_back(narrowValue);
1075 if (trailingZeros > 0)
1077 rewriter, op.getLoc(), APInt::getZero(trailingZeros)));
1078 replaceOpWithNewOpAndCopyNamehint<ConcatOp>(rewriter, op, operands);
// Scan all but the last operand for concats (handling not visible).
1085 for (
size_t i = 0; i < size - 1; ++i) {
1086 if (
auto concat = inputs[i].getDefiningOp<ConcatOp>())
// Replace the op with an equality compare of `source` against `cmpAgainst`
// (both computed in lines not shown).
1100 replaceOpWithNewOpAndCopyNamehint<ICmpOp>(rewriter, op, ICmpPredicate::eq,
1101 source, cmpAgainst);
// Folder for OrOp. Several guards and return statements are outside this
// excerpt.
1109OpFoldResult OrOp::fold(FoldAdaptor adaptor) {
// Accumulator starts as zero at the result width.
1113 auto value = APInt::getZero(cast<IntegerType>(getType()).
getWidth());
1114 auto inputs = adaptor.getInputs();
// OR the constant operands into the accumulator; all ones is checked after
// each step.
1116 for (
auto operand : inputs) {
1119 value |= cast<IntegerAttr>(operand).getValue();
1120 if (value.isAllOnes())
// or(x, 0) -> x.
1125 if (inputs.size() == 2 && inputs[1] &&
1126 cast<IntegerAttr>(inputs[1]).getValue().isZero())
1127 return getInputs()[0];
// Every operand is the same value -> that value.
1130 if (llvm::all_of(getInputs(),
1131 [&](
auto in) {
return in == this->getInputs()[0]; }))
1132 return getInputs()[0];
// An operand together with its complement yields all ones.
1135 for (Value arg : getInputs()) {
1137 if (matchPattern(arg,
m_Complement(m_Any(&subExpr)))) {
1138 for (Value arg2 : getInputs())
1139 if (arg2 == subExpr)
1141 APInt::getAllOnes(cast<IntegerType>(getType()).
getWidth()),
// A second all-ones result (its guarding condition is not visible).
1151 APInt::getAllOnes(cast<IntegerType>(getType()).
getWidth()),
// Canonicalizer for OrOp. Returns, helper calls and some declarations are
// outside this excerpt.
1158LogicalResult OrOp::canonicalize(
OrOp op, PatternRewriter &rewriter) {
1162 auto inputs = op.getInputs();
1163 auto size = inputs.size();
1175 assert(size > 1 &&
"expected 2 or more operands");
// Patterns keyed on a constant in the last operand position.
1179 if (matchPattern(inputs.back(), m_ConstantInt(&value))) {
// or(..., 0) -> or(...).
1181 if (value.isZero()) {
1182 replaceOpWithNewOpAndCopyNamehint<OrOp>(rewriter, op, op.getType(),
1183 inputs.drop_back());
// Two trailing constants are replaced by the single value `cst` (its
// computation is not visible).
1189 if (matchPattern(inputs[size - 2], m_ConstantInt(&value2))) {
1191 SmallVector<Value, 4> newOperands(inputs.drop_back(2));
1192 newOperands.push_back(cst);
1193 replaceOpWithNewOpAndCopyNamehint<OrOp>(rewriter, op, op.getType(),
// Scan all but the last operand for concats (handling not visible).
1201 for (
size_t i = 0; i < size - 1; ++i) {
1202 if (
auto concat = inputs[i].getDefiningOp<ConcatOp>())
// Replace the op with an inequality compare of `source` against `cmpAgainst`
// (both computed in lines not shown).
1216 replaceOpWithNewOpAndCopyNamehint<ICmpOp>(rewriter, op, ICmpPredicate::ne,
1217 source, cmpAgainst);
// Or of two-state muxes that all share the same true and false values:
// replace with one mux whose condition is the or of the individual conditions.
1223 if (
auto firstMux = op.getOperand(0).getDefiningOp<
comb::MuxOp>()) {
1225 if (op.getTwoState() && firstMux.getTwoState() &&
1226 matchPattern(firstMux.getFalseValue(), m_ConstantInt(&value)) &&
1228 SmallVector<Value> conditions{firstMux.getCond()};
// Predicate for the remaining operands; it records each mux condition as a
// side effect.
1229 auto check = [&](Value v) {
1233 conditions.push_back(mux.getCond());
1234 return mux.getTwoState() &&
1235 firstMux.getTrueValue() == mux.getTrueValue() &&
1236 firstMux.getFalseValue() == mux.getFalseValue();
1238 if (llvm::all_of(op.getOperands().drop_front(), check)) {
1239 auto cond = comb::OrOp::create(rewriter, op.getLoc(), conditions,
true);
1240 replaceOpWithNewOpAndCopyNamehint<comb::MuxOp>(
1241 rewriter, op, cond, firstMux.getTrueValue(),
1242 firstMux.getFalseValue(),
true);
// Folder for XorOp. Several guards and return statements are outside this
// excerpt.
1252OpFoldResult XorOp::fold(FoldAdaptor adaptor) {
1256 auto size = getInputs().size();
1257 auto inputs = adaptor.getInputs();
// Returns the first input (the guarding condition is not visible).
1261 return getInputs()[0];
// xor(x, x) -> 0.
1264 if (size == 2 && getInputs()[0] == getInputs()[1])
1265 return IntegerAttr::get(getType(), 0);
// xor(x, 0) -> x.
1268 if (inputs.size() == 2 && inputs[1] &&
1269 cast<IntegerAttr>(inputs[1]).getValue().isZero())
1270 return getInputs()[0];
// This op is a complement whose operand is itself a complement of `subExpr`;
// the extra check excludes the case where `subExpr` is this op's own result.
1274 if (isBinaryNot()) {
1276 if (matchPattern(getOperand(0),
m_Complement(m_Any(&subExpr))) &&
1277 subExpr != getResult())
// Fragment: the start of the signature and the final replacement are outside
// this excerpt. The visible code swaps the icmp at operand index `icmpOperand`
// for one with the negated predicate.
1287 PatternRewriter &rewriter) {
// NOTE(review): `icmp` is dereferenced without a null check here; presumably
// callers guarantee the operand is an ICmpOp - confirm.
1288 auto icmp = op.getOperand(icmpOperand).getDefiningOp<ICmpOp>();
1289 auto negatedPred = ICmpOp::getNegatedPredicate(icmp.getPredicate());
// New compare over the same operands, keeping the original two-state flag.
1292 ICmpOp::create(rewriter, icmp.getLoc(), negatedPred, icmp.getOperand(0),
1293 icmp.getOperand(1), icmp.getTwoState());
// With more than two operands, rebuild the xor without its last operand and
// without the original icmp, then append the new compare.
1296 if (op.getNumOperands() > 2) {
1297 SmallVector<Value, 4> newOperands(op.getOperands());
1298 newOperands.pop_back();
1299 newOperands.erase(newOperands.begin() + icmpOperand);
1300 newOperands.push_back(result);
1302 XorOp::create(rewriter, op.getLoc(), newOperands, op.getTwoState());
// Canonicalizer for XorOp. Returns, helper calls and some declarations are
// outside this excerpt.
1308LogicalResult XorOp::canonicalize(
XorOp op, PatternRewriter &rewriter) {
1312 auto inputs = op.getInputs();
1313 auto size = inputs.size();
1314 assert(size > 1 &&
"expected 2 or more operands");
// xor(..., x, x) -> xor(...): the two equal trailing operands are dropped.
1317 if (inputs[size - 1] == inputs[size - 2]) {
1319 "expected idempotent case for 2 elements handled already.");
1320 replaceOpWithNewOpAndCopyNamehint<XorOp>(rewriter, op, op.getType(),
1321 inputs.drop_back(2),
false);
// Patterns keyed on a constant in the last operand position.
1327 if (matchPattern(inputs.back(), m_ConstantInt(&value))) {
// xor(..., 0) -> xor(...).
1329 if (value.isZero()) {
1330 replaceOpWithNewOpAndCopyNamehint<XorOp>(rewriter, op, op.getType(),
1331 inputs.drop_back(),
false);
// Two trailing constants are replaced by the single value `cst` (its
// computation is not visible).
1337 if (matchPattern(inputs[size - 2], m_ConstantInt(&value2))) {
1339 SmallVector<Value, 4> newOperands(inputs.drop_back(2));
1340 newOperands.push_back(cst);
1341 replaceOpWithNewOpAndCopyNamehint<XorOp>(rewriter, op, op.getType(),
1342 newOperands,
false);
1346 bool isSingleBit = value.getBitWidth() == 1;
// Inspect every operand except the trailing constant.
1349 for (
size_t i = 0; i < size - 1; ++i) {
1350 Value operand = inputs[i];
// A 1-bit constant of one combined with a single-use icmp operand (handling
// not visible).
1361 if (isSingleBit && operand.hasOneUse()) {
1362 assert(value == 1 &&
"single bit constant has to be one if not zero");
1363 if (
auto icmp = operand.getDefiningOp<ICmpOp>())
// Replace the op with a ParityOp over `source` (computed in lines not shown).
1379 replaceOpWithNewOpAndCopyNamehint<ParityOp>(rewriter, op, source);
// Folder for SubOp. Several return statements are outside this excerpt.
1386OpFoldResult SubOp::fold(FoldAdaptor adaptor) {
// sub(x, x) -> 0 at the operand width.
1391 if (getRhs() == getLhs())
1393 APInt::getZero(getLhs().getType().getIntOrFloatBitWidth()),
1396 if (adaptor.getRhs()) {
// Both operands known: express lhs - rhs as lhs + (rhs * negOne), where negOne
// is built from an all-ones value at the operand width.
1398 if (adaptor.getLhs()) {
1401 APInt::getAllOnes(getLhs().getType().getIntOrFloatBitWidth()),
1403 auto rhsNeg = hw::ParamExprAttr::get(
1404 hw::PEO::Mul, cast<TypedAttr>(adaptor.getRhs()), negOne);
1405 return hw::ParamExprAttr::get(hw::PEO::Add,
1406 cast<TypedAttr>(adaptor.getLhs()), rhsNeg);
// Special case for a zero right-hand constant (its result is not visible).
1410 if (
auto rhsC = dyn_cast<IntegerAttr>(adaptor.getRhs())) {
1411 if (rhsC.getValue().isZero())
// Canonicalizer for SubOp; most of the body is outside this excerpt.
1419LogicalResult SubOp::canonicalize(
SubOp op, PatternRewriter &rewriter) {
// Subtracting a constant is rewritten as an AddOp of the lhs and `negCst`
// (its construction is not visible).
1425 if (matchPattern(op.getRhs(), m_ConstantInt(&value))) {
1427 replaceOpWithNewOpAndCopyNamehint<AddOp>(rewriter, op, op.getLhs(), negCst,
// Folder for AddOp; only one return is visible in this excerpt.
1439OpFoldResult AddOp::fold(FoldAdaptor adaptor) {
1443 auto size = getInputs().size();
// Returns the first input (the guarding condition is not visible).
1447 return getInputs()[0];
// Canonicalizer for AddOp. Returns and the construction of several constants
// (`cst`, `one`, `rhs`) are outside this excerpt.
1453LogicalResult AddOp::canonicalize(
AddOp op, PatternRewriter &rewriter) {
1457 auto inputs = op.getInputs();
1458 auto size = inputs.size();
1459 assert(size > 1 &&
"expected 2 or more operands");
1461 APInt value, value2;
// add(..., 0) -> add(...).
1464 if (matchPattern(inputs.back(), m_ConstantInt(&value)) && value.isZero()) {
1465 replaceOpWithNewOpAndCopyNamehint<AddOp>(rewriter, op, op.getType(),
1466 inputs.drop_back(),
false);
// Two trailing constants are replaced by the single value `cst`.
1471 if (matchPattern(inputs[size - 1], m_ConstantInt(&value)) &&
1472 matchPattern(inputs[size - 2], m_ConstantInt(&value2))) {
1474 SmallVector<Value, 4> newOperands(inputs.drop_back(2));
1475 newOperands.push_back(cst);
1476 replaceOpWithNewOpAndCopyNamehint<AddOp>(rewriter, op, op.getType(),
1477 newOperands,
false);
// add(..., x, x): the equal pair is replaced by a shift-left of x by `one`.
1482 if (inputs[size - 1] == inputs[size - 2]) {
1483 SmallVector<Value, 4> newOperands(inputs.drop_back(2));
1487 comb::ShlOp::create(rewriter, op.getLoc(), inputs.back(), one,
false);
1489 newOperands.push_back(shiftLeftOp);
1490 replaceOpWithNewOpAndCopyNamehint<AddOp>(rewriter, op, op.getType(),
1491 newOperands,
false);
// add(..., x, shl(x, c)): the pair is replaced by a multiply of x by `rhs`.
1495 auto shlOp = inputs[size - 1].getDefiningOp<
comb::ShlOp>();
1497 if (shlOp && shlOp.getLhs() == inputs[size - 2] &&
1498 matchPattern(shlOp.getRhs(), m_ConstantInt(&value))) {
1500 APInt one(value.getBitWidth(), 1,
false);
1504 std::array<Value, 2> factors = {shlOp.getLhs(), rhs};
1505 auto mulOp = comb::MulOp::create(rewriter, op.getLoc(), factors,
false);
1507 SmallVector<Value, 4> newOperands(inputs.drop_back(2));
1508 newOperands.push_back(mulOp);
1509 replaceOpWithNewOpAndCopyNamehint<AddOp>(rewriter, op, op.getType(),
1510 newOperands,
false);
// add(..., x, mul(x, c)): the pair is replaced by a multiply of x by `rhs`.
1514 auto mulOp = inputs[size - 1].getDefiningOp<
comb::MulOp>();
1516 if (mulOp && mulOp.getInputs().size() == 2 &&
1517 mulOp.getInputs()[0] == inputs[size - 2] &&
1518 matchPattern(mulOp.getInputs()[1], m_ConstantInt(&value))) {
1520 APInt one(value.getBitWidth(), 1,
false);
1522 std::array<Value, 2> factors = {mulOp.getInputs()[0], rhs};
1523 auto newMulOp = comb::MulOp::create(rewriter, op.getLoc(), factors,
false);
1525 SmallVector<Value, 4> newOperands(inputs.drop_back(2));
1526 newOperands.push_back(newMulOp);
1527 replaceOpWithNewOpAndCopyNamehint<AddOp>(rewriter, op, op.getType(),
1528 newOperands,
false);
// add(add(x, c1), c2) -> add(x, rhs); the result is two-state only when both
// the outer and the inner add are.
1541 auto addOp = inputs[0].getDefiningOp<
comb::AddOp>();
1542 if (addOp && addOp.getInputs().size() == 2 &&
1543 matchPattern(addOp.getInputs()[1], m_ConstantInt(&value2)) &&
1544 inputs.size() == 2 && matchPattern(inputs[1], m_ConstantInt(&value))) {
1547 replaceOpWithNewOpAndCopyNamehint<AddOp>(
1548 rewriter, op, op.getType(), ArrayRef<Value>{addOp.getInputs()[0], rhs},
1549 op.getTwoState() && addOp.getTwoState());
// Folder for MulOp. Several guards and return statements are outside this
// excerpt.
1556OpFoldResult MulOp::fold(FoldAdaptor adaptor) {
1560 auto size = getInputs().size();
1561 auto inputs = adaptor.getInputs();
// Returns the first input (the guarding condition is not visible).
1565 return getInputs()[0];
1567 auto width = cast<IntegerType>(getType()).getWidth();
// Zero-width constant result (the guarding condition is not visible).
1569 return getIntAttr(APInt::getZero(0), getContext());
// Accumulator starts as one at the result width.
1571 APInt value(width, 1,
false);
// Multiply the constant operands into the accumulator.
1574 for (
auto operand : inputs) {
1577 value *= cast<IntegerAttr>(operand).getValue();
// Canonicalizer for MulOp. Returns and the construction of `shift`, `shlOp`
// and `cst` are partly outside this excerpt.
1586LogicalResult MulOp::canonicalize(
MulOp op, PatternRewriter &rewriter) {
1590 auto inputs = op.getInputs();
1591 auto size = inputs.size();
1592 assert(size > 1 &&
"expected 2 or more operands");
1594 APInt value, value2;
// mul(x, 2^n): build a shift-left of x by log2 of the constant, then replace
// the op with a MulOp whose only operand is that shift.
1597 if (size == 2 && matchPattern(inputs.back(), m_ConstantInt(&value)) &&
1598 value.isPowerOf2()) {
1600 value.exactLogBase2());
1602 comb::ShlOp::create(rewriter, op.getLoc(), inputs[0], shift,
false);
1604 replaceOpWithNewOpAndCopyNamehint<MulOp>(rewriter, op, op.getType(),
1605 ArrayRef<Value>(shlOp),
false);
// mul(..., 1) -> mul(...).
1610 if (matchPattern(inputs.back(), m_ConstantInt(&value)) && value.isOne()) {
1611 replaceOpWithNewOpAndCopyNamehint<MulOp>(rewriter, op, op.getType(),
1612 inputs.drop_back());
// Two trailing constants are replaced by the single value `cst`.
1617 if (matchPattern(inputs[size - 1], m_ConstantInt(&value)) &&
1618 matchPattern(inputs[size - 2], m_ConstantInt(&value2))) {
1620 SmallVector<Value, 4> newOperands(inputs.drop_back(2));
1621 newOperands.push_back(cst);
1622 replaceOpWithNewOpAndCopyNamehint<MulOp>(rewriter, op, op.getType(),
// Shared folder for division operations, parameterized on the operation type
// and on signedness. The results of the special cases are outside this
// excerpt.
1638template <
class Op,
bool isSigned>
1639static OpFoldResult
foldDiv(Op op, ArrayRef<Attribute> constants) {
// Special cases on a constant divisor (second operand).
1640 if (
auto rhsValue = dyn_cast_or_null<IntegerAttr>(constants[1])) {
// Divisor equal to one.
1642 if (rhsValue.getValue() == 1)
// Divisor equal to zero.
1646 if (rhsValue.getValue().isZero())
// NOTE(review): the repeated signature with an elided body below looks like an
// extraction artifact rather than a second definition - confirm.
1639static OpFoldResult
foldDiv(Op op, ArrayRef<Attribute> constants) {
…}
// Folder for DivUOp: forwards to foldDiv instantiated with `false` as its
// boolean template argument.
1653OpFoldResult DivUOp::fold(FoldAdaptor adaptor) {
1656 return foldDiv<
DivUOp,
false>(*
this, adaptor.getOperands());
1659OpFoldResult DivSOp::fold(FoldAdaptor adaptor) {
// Shared folder for modulo operations, parameterized on the operation type and
// on signedness. Some results are outside this excerpt.
1665template <
class Op,
bool isSigned>
1666static OpFoldResult
foldMod(Op op, ArrayRef<Attribute> constants) {
// Special cases on a constant divisor (second operand).
1667 if (
auto rhsValue = dyn_cast_or_null<IntegerAttr>(constants[1])) {
// x mod 1 -> 0 at the result width.
1669 if (rhsValue.getValue() == 1)
1670 return getIntAttr(APInt::getZero(op.getType().getIntOrFloatBitWidth()),
// Divisor equal to zero (result not visible).
1674 if (rhsValue.getValue().isZero())
// Special case on a constant dividend (first operand): 0 mod x -> 0.
1678 if (
auto lhsValue = dyn_cast_or_null<IntegerAttr>(constants[0])) {
1680 if (lhsValue.getValue().isZero())
1681 return getIntAttr(APInt::getZero(op.getType().getIntOrFloatBitWidth()),
// NOTE(review): the repeated signature with an elided body below looks like an
// extraction artifact rather than a second definition - confirm.
1666static OpFoldResult
foldMod(Op op, ArrayRef<Attribute> constants) {
…}
// Folder for ModUOp: forwards to foldMod instantiated with `false` as its
// boolean template argument.
1688OpFoldResult ModUOp::fold(FoldAdaptor adaptor) {
1691 return foldMod<
ModUOp,
false>(*
this, adaptor.getOperands());
1694OpFoldResult ModSOp::fold(FoldAdaptor adaptor) {
// Folder for ConcatOp. The early-exit and final return are outside this
// excerpt.
1704OpFoldResult ConcatOp::fold(FoldAdaptor adaptor) {
// A single-operand concat is just that operand.
1708 if (getNumOperands() == 1)
1709 return getOperand(0);
// Every input must be a known integer constant for the fold below.
1712 for (
auto attr : adaptor.getInputs())
1713 if (!attr || !isa<IntegerAttr>(attr))
// Assemble the result by inserting each chunk at a position that counts down
// from the full result width.
1717 unsigned resultWidth = getType().getIntOrFloatBitWidth();
1718 APInt result(resultWidth, 0);
1720 unsigned nextInsertion = resultWidth;
1722 for (
auto attr : adaptor.getInputs()) {
1723 auto chunk = cast<IntegerAttr>(attr).getValue();
1724 nextInsertion -= chunk.getBitWidth();
1725 result.insertBits(chunk, nextInsertion);
1731LogicalResult ConcatOp::canonicalize(
ConcatOp op, PatternRewriter &rewriter) {
1735 auto inputs = op.getInputs();
1736 auto size = inputs.size();
1737 assert(size > 1 &&
"expected 2 or more operands");
1742 auto flattenConcat = [&](
size_t firstOpIndex,
size_t lastOpIndex,
1743 ValueRange replacements) -> LogicalResult {
1744 SmallVector<Value, 4> newOperands;
1745 newOperands.append(inputs.begin(), inputs.begin() + firstOpIndex);
1746 newOperands.append(replacements.begin(), replacements.end());
1747 newOperands.append(inputs.begin() + lastOpIndex + 1, inputs.end());
1748 if (newOperands.size() == 1)
1751 replaceOpWithNewOpAndCopyNamehint<ConcatOp>(rewriter, op, op.getType(),
1756 Value commonOperand = inputs[0];
1757 for (
size_t i = 0; i != size; ++i) {
1759 if (inputs[i] != commonOperand)
1760 commonOperand = Value();
1764 if (
auto subConcat = inputs[i].getDefiningOp<ConcatOp>())
1765 return flattenConcat(i, i, subConcat->getOperands());
1770 if (
auto cst = inputs[i].getDefiningOp<hw::ConstantOp>()) {
1771 if (
auto prevCst = inputs[i - 1].getDefiningOp<hw::ConstantOp>()) {
1772 unsigned prevWidth = prevCst.getValue().getBitWidth();
1773 unsigned thisWidth = cst.getValue().getBitWidth();
1774 auto resultCst = cst.getValue().zext(prevWidth + thisWidth);
1775 resultCst |= prevCst.getValue().zext(prevWidth + thisWidth)
1779 return flattenConcat(i - 1, i, replacement);
1784 if (inputs[i] == inputs[i - 1]) {
1786 rewriter.createOrFold<ReplicateOp>(op.getLoc(), inputs[i], 2);
1787 return flattenConcat(i - 1, i, replacement);
1792 if (
auto repl = inputs[i].getDefiningOp<ReplicateOp>()) {
1794 if (repl.getOperand() == inputs[i - 1]) {
1795 Value replacement = rewriter.createOrFold<ReplicateOp>(
1796 op.getLoc(), repl.getOperand(), repl.getMultiple() + 1);
1797 return flattenConcat(i - 1, i, replacement);
1800 if (
auto prevRepl = inputs[i - 1].getDefiningOp<ReplicateOp>()) {
1801 if (prevRepl.getOperand() == repl.getOperand()) {
1802 Value replacement = rewriter.createOrFold<ReplicateOp>(
1803 op.getLoc(), repl.getOperand(),
1804 repl.getMultiple() + prevRepl.getMultiple());
1805 return flattenConcat(i - 1, i, replacement);
1811 if (
auto repl = inputs[i - 1].getDefiningOp<ReplicateOp>()) {
1812 if (repl.getOperand() == inputs[i]) {
1813 Value replacement = rewriter.createOrFold<ReplicateOp>(
1814 op.getLoc(), inputs[i], repl.getMultiple() + 1);
1815 return flattenConcat(i - 1, i, replacement);
1821 if (
auto extract = inputs[i].getDefiningOp<ExtractOp>()) {
1822 if (
auto prevExtract = inputs[i - 1].getDefiningOp<ExtractOp>()) {
1823 if (extract.getInput() == prevExtract.getInput()) {
1824 auto thisWidth = cast<IntegerType>(extract.getType()).getWidth();
1825 if (prevExtract.getLowBit() == extract.getLowBit() + thisWidth) {
1826 auto prevWidth = prevExtract.getType().getIntOrFloatBitWidth();
1827 auto resType = rewriter.getIntegerType(thisWidth + prevWidth);
1830 extract.getInput(), extract.getLowBit());
1831 return flattenConcat(i - 1, i, replacement);
1844 static std::optional<ArraySlice>
get(Value value) {
1845 assert(isa<IntegerType>(value.getType()) &&
"expected integer type");
1847 return ArraySlice{arrayGet.getInput(), arrayGet.getIndex(), 1};
1850 if (
auto arraySlice =
1853 arraySlice.getInput(), arraySlice.getLowIndex(),
1854 hw::type_cast<hw::ArrayType>(arraySlice.getType())
1856 return std::nullopt;
1859 if (
auto extractOpt = ArraySlice::get(inputs[i])) {
1860 if (
auto prevExtractOpt = ArraySlice::get(inputs[i - 1])) {
1862 if (prevExtractOpt->index.getType() == extractOpt->index.getType() &&
1863 prevExtractOpt->input == extractOpt->input &&
1864 hw::isOffset(extractOpt->index, prevExtractOpt->index,
1865 extractOpt->width)) {
1866 auto resType = hw::ArrayType::get(
1867 hw::type_cast<hw::ArrayType>(prevExtractOpt->input.getType())
1869 extractOpt->width + prevExtractOpt->width);
1870 auto resIntType = rewriter.getIntegerType(hw::getBitWidth(resType));
1872 rewriter, op.getLoc(), resIntType,
1874 prevExtractOpt->input,
1875 extractOpt->index));
1876 return flattenConcat(i - 1, i, replacement);
1884 if (commonOperand) {
1885 replaceOpWithNewOpAndCopyNamehint<ReplicateOp>(rewriter, op, op.getType(),
1897OpFoldResult MuxOp::fold(FoldAdaptor adaptor) {
1902 if (getTrueValue() == getFalseValue() && getTrueValue() != getResult())
1903 return getTrueValue();
1904 if (
auto tv = adaptor.getTrueValue())
1905 if (tv == adaptor.getFalseValue())
1910 if (
auto pred = dyn_cast_or_null<IntegerAttr>(adaptor.getCond())) {
1911 if (pred.getValue().isZero() && getFalseValue() != getResult())
1912 return getFalseValue();
1913 if (pred.getValue().isOne() && getTrueValue() != getResult())
1914 return getTrueValue();
1918 if (
auto tv = dyn_cast_or_null<IntegerAttr>(adaptor.getTrueValue()))
1919 if (
auto fv = dyn_cast_or_null<IntegerAttr>(adaptor.getFalseValue()))
1920 if (tv.getValue().isOne() && fv.getValue().isZero() &&
1921 hw::getBitWidth(getType()) == 1 && getCond() != getResult())
1937 if (
auto cmp = cond.getDefiningOp<ICmpOp>()) {
1939 auto requiredPredicate =
1940 (isInverted ? ICmpPredicate::eq : ICmpPredicate::ne);
1941 if (cmp.getLhs() == indexValue && cmp.getPredicate() == requiredPredicate) {
1951 if (
auto orOp = cond.getDefiningOp<
OrOp>()) {
1954 for (
auto operand : orOp.getOperands())
1961 if (
auto andOp = cond.getDefiningOp<
AndOp>()) {
1964 for (
auto operand : andOp.getOperands())
1982 PatternRewriter &rewriter) {
1985 auto rootCmp = rootMux.getCond().getDefiningOp<ICmpOp>();
1988 Value indexValue = rootCmp.getLhs();
1991 auto getCaseValue = [&](
MuxOp mux) -> Value {
1992 return mux.getOperand(1 +
unsigned(!isFalseSide));
1997 auto getTreeValue = [&](
MuxOp mux) -> Value {
1998 return mux.getOperand(1 +
unsigned(isFalseSide));
2003 SmallVector<Location> locationsFound;
2004 SmallVector<std::pair<hw::ConstantOp, Value>, 4> valuesFound;
2008 auto collectConstantValues = [&](
MuxOp mux) ->
bool {
2010 mux.getCond(), indexValue, isFalseSide, [&](
hw::ConstantOp cst) {
2011 valuesFound.push_back({cst, getCaseValue(mux)});
2012 locationsFound.push_back(mux.getCond().getLoc());
2013 locationsFound.push_back(mux->getLoc());
2018 if (!collectConstantValues(rootMux))
2022 if (rootMux->hasOneUse()) {
2023 if (
auto userMux = dyn_cast<MuxOp>(*rootMux->user_begin())) {
2024 if (getTreeValue(userMux) == rootMux.getResult() &&
2032 auto nextTreeValue = getTreeValue(rootMux);
2034 auto nextMux = nextTreeValue.getDefiningOp<
MuxOp>();
2035 if (!nextMux || !nextMux->hasOneUse())
2037 if (!collectConstantValues(nextMux))
2039 nextTreeValue = getTreeValue(nextMux);
2045 if (valuesFound.size() < 3)
2050 auto indexWidth = cast<IntegerType>(indexValue.getType()).getWidth();
2051 if (indexWidth >= 9)
2057 uint64_t tableSize = 1ULL << indexWidth;
2058 if (valuesFound.size() < (tableSize * 5) / 8)
2063 SmallVector<Value, 8> table(tableSize, nextTreeValue);
2068 for (
auto &elt :
llvm::reverse(valuesFound)) {
2069 uint64_t idx = elt.first.getValue().getZExtValue();
2070 assert(idx < table.size() &&
"constant should be same bitwidth as index");
2071 table[idx] = elt.second;
2076 std::reverse(table.begin(), table.end());
2079 auto fusedLoc = rewriter.getFusedLoc(locationsFound);
2081 replaceOpWithNewOpAndCopyNamehint<hw::ArrayGetOp>(rewriter, rootMux, array,
2096 PatternRewriter &rewriter) {
2097 assert(fullyAssoc->getNumOperands() >= 2 &&
"cannot split up unary ops");
2098 assert(operandNo < fullyAssoc->getNumOperands() &&
"Invalid operand #");
2102 if (fullyAssoc->getNumOperands() == 2)
2103 return fullyAssoc->getOperand(operandNo ^ 1);
2106 if (fullyAssoc->hasOneUse()) {
2107 rewriter.modifyOpInPlace(fullyAssoc,
2108 [&]() { fullyAssoc->eraseOperand(operandNo); });
2109 return fullyAssoc->getResult(0);
2113 SmallVector<Value> operands;
2114 operands.append(fullyAssoc->getOperands().begin(),
2115 fullyAssoc->getOperands().begin() + operandNo);
2116 operands.append(fullyAssoc->getOperands().begin() + operandNo + 1,
2117 fullyAssoc->getOperands().end());
2119 fullyAssoc->getLoc(), fullyAssoc->getName(), operands, rewriter);
2120 Value excluded = fullyAssoc->getOperand(operandNo);
2124 ArrayRef<Value>{opWithoutExcluded, excluded}, rewriter);
2126 return opWithoutExcluded;
2136 PatternRewriter &rewriter) {
2139 Operation *subExpr =
2140 (isTrueOperand ? op.getFalseValue() : op.getTrueValue()).getDefiningOp();
2141 if (!subExpr || subExpr->getNumOperands() < 2)
2145 if (!isa<AndOp, XorOp, OrOp, MuxOp>(subExpr))
2150 Value commonValue = isTrueOperand ? op.getTrueValue() : op.getFalseValue();
2151 size_t opNo = 0, e = subExpr->getNumOperands();
2152 while (opNo != e && subExpr->getOperand(opNo) != commonValue)
2158 Value cond = op.getCond();
2164 if (
auto subMux = dyn_cast<MuxOp>(subExpr)) {
2169 Value subCond = subMux.getCond();
2172 if (subMux.getTrueValue() == commonValue)
2173 otherValue = subMux.getFalseValue();
2174 else if (subMux.getFalseValue() == commonValue) {
2175 otherValue = subMux.getTrueValue();
2185 cond = rewriter.createOrFold<
OrOp>(op.getLoc(), cond, subCond,
false);
2186 replaceOpWithNewOpAndCopyNamehint<MuxOp>(rewriter, op, cond, commonValue,
2187 otherValue, op.getTwoState());
2193 bool isaAndOp = isa<AndOp>(subExpr);
2194 if (isTrueOperand ^ isaAndOp)
2198 rewriter.createOrFold<ReplicateOp>(op.getLoc(), op.getType(), cond);
2201 bool isaXorOp = isa<XorOp>(subExpr);
2202 bool isaOrOp = isa<OrOp>(subExpr);
2211 if (isaOrOp || isaXorOp) {
2212 auto masked = rewriter.createOrFold<
AndOp>(op.getLoc(), extendedCond,
2213 restOfAssoc,
false);
2215 replaceOpWithNewOpAndCopyNamehint<XorOp>(rewriter, op, masked,
2216 commonValue,
false);
2218 replaceOpWithNewOpAndCopyNamehint<OrOp>(rewriter, op, masked, commonValue,
2224 assert(isaAndOp &&
"unexpected operation here");
2225 auto masked = rewriter.createOrFold<
OrOp>(op.getLoc(), extendedCond,
2226 restOfAssoc,
false);
2227 replaceOpWithNewOpAndCopyNamehint<AndOp>(rewriter, op, masked, commonValue,
2238 PatternRewriter &rewriter) {
2241 if (!isa<ConcatOp>(trueOp))
2245 SmallVector<Value> trueOperands, falseOperands;
2249 size_t numTrueOperands = trueOperands.size();
2250 size_t numFalseOperands = falseOperands.size();
2252 if (!numTrueOperands || !numFalseOperands ||
2253 (trueOperands.front() != falseOperands.front() &&
2254 trueOperands.back() != falseOperands.back()))
2258 if (trueOperands.front() == falseOperands.front()) {
2259 SmallVector<Value> operands;
2261 for (i = 0; i < numTrueOperands; ++i) {
2262 Value trueOperand = trueOperands[i];
2263 if (trueOperand == falseOperands[i])
2264 operands.push_back(trueOperand);
2268 if (i == numTrueOperands) {
2275 if (llvm::all_of(operands, [&](Value v) {
return v == operands.front(); }))
2276 sharedMSB = rewriter.createOrFold<ReplicateOp>(
2277 mux->getLoc(), operands.front(), operands.size());
2279 sharedMSB = rewriter.createOrFold<
ConcatOp>(mux->getLoc(), operands);
2283 operands.append(trueOperands.begin() + i, trueOperands.end());
2284 Value trueLSB = rewriter.createOrFold<
ConcatOp>(trueOp->getLoc(), operands);
2286 operands.append(falseOperands.begin() + i, falseOperands.end());
2288 rewriter.createOrFold<
ConcatOp>(falseOp->getLoc(), operands);
2291 Value lsb = rewriter.createOrFold<
MuxOp>(
2292 mux->getLoc(), mux.getCond(), trueLSB, falseLSB, mux.getTwoState());
2293 replaceOpWithNewOpAndCopyNamehint<ConcatOp>(rewriter, mux, sharedMSB, lsb);
2298 if (trueOperands.back() == falseOperands.back()) {
2299 SmallVector<Value> operands;
2302 Value trueOperand = trueOperands[numTrueOperands - i - 1];
2303 if (trueOperand == falseOperands[numFalseOperands - i - 1])
2304 operands.push_back(trueOperand);
2308 std::reverse(operands.begin(), operands.end());
2309 Value sharedLSB = rewriter.createOrFold<
ConcatOp>(mux->getLoc(), operands);
2313 operands.append(trueOperands.begin(), trueOperands.end() - i);
2314 Value trueMSB = rewriter.createOrFold<
ConcatOp>(trueOp->getLoc(), operands);
2316 operands.append(falseOperands.begin(), falseOperands.end() - i);
2318 rewriter.createOrFold<
ConcatOp>(falseOp->getLoc(), operands);
2320 Value msb = rewriter.createOrFold<
MuxOp>(
2321 mux->getLoc(), mux.getCond(), trueMSB, falseMSB, mux.getTwoState());
2322 replaceOpWithNewOpAndCopyNamehint<ConcatOp>(rewriter, mux, msb, sharedLSB);
2334 if (!trueVec || !falseVec)
2336 if (!trueVec.isUniform() || !falseVec.isUniform())
2339 auto mux = MuxOp::create(rewriter, op.getLoc(), op.getCond(),
2340 trueVec.getUniformElement(),
2341 falseVec.getUniformElement(), op.getTwoState());
2343 SmallVector<Value> values(trueVec.getInputs().size(), mux);
2351 bool constCond, PatternRewriter &rewriter) {
2352 if (!muxValue.hasOneUse())
2354 auto *op = muxValue.getDefiningOp();
2355 if (!op || !isa_and_nonnull<CombDialect>(op->getDialect()))
2357 if (!llvm::is_contained(op->getOperands(), muxCond))
2359 OpBuilder::InsertionGuard guard(rewriter);
2360 rewriter.setInsertionPoint(op);
2363 rewriter.modifyOpInPlace(op, [&] {
2364 for (
auto &use : op->getOpOperands())
2365 if (use.get() == muxCond)
2373 using OpRewritePattern::OpRewritePattern;
2375 LogicalResult matchAndRewrite(
MuxOp op,
2376 PatternRewriter &rewriter)
const override;
2379LogicalResult MuxRewriter::matchAndRewrite(
MuxOp op,
2380 PatternRewriter &rewriter)
const {
2389 if (matchPattern(op.getTrueValue(), m_ConstantInt(&value))) {
2390 if (value.getBitWidth() == 1) {
2392 if (value.isZero()) {
2394 replaceOpWithNewOpAndCopyNamehint<AndOp>(rewriter, op, notCond,
2395 op.getFalseValue(),
false);
2400 replaceOpWithNewOpAndCopyNamehint<OrOp>(rewriter, op, op.getCond(),
2401 op.getFalseValue(),
false);
2407 if (matchPattern(op.getFalseValue(), m_ConstantInt(&value2))) {
2412 APInt xorValue = value ^ value2;
2413 if (xorValue.isPowerOf2()) {
2414 unsigned leadingZeros = xorValue.countLeadingZeros();
2415 unsigned trailingZeros = value.getBitWidth() - leadingZeros - 1;
2416 SmallVector<Value, 3> operands;
2424 if (leadingZeros > 0)
2425 operands.push_back(rewriter.createOrFold<
ExtractOp>(
2426 op.getLoc(), op.getTrueValue(), trailingZeros + 1, leadingZeros));
2430 auto v1 = rewriter.createOrFold<
ExtractOp>(
2431 op.getLoc(), op.getTrueValue(), trailingZeros, 1);
2432 auto v2 = rewriter.createOrFold<
ExtractOp>(
2433 op.getLoc(), op.getFalseValue(), trailingZeros, 1);
2434 operands.push_back(rewriter.createOrFold<
MuxOp>(
2435 op.getLoc(), op.getCond(), v1, v2,
false));
2437 if (trailingZeros > 0)
2438 operands.push_back(rewriter.createOrFold<
ExtractOp>(
2439 op.getLoc(), op.getTrueValue(), 0, trailingZeros));
2441 replaceOpWithNewOpAndCopyNamehint<ConcatOp>(rewriter, op, op.getType(),
2448 if (value.isAllOnes() && value2.isZero()) {
2449 replaceOpWithNewOpAndCopyNamehint<ReplicateOp>(
2450 rewriter, op, op.getType(), op.getCond());
2456 if (matchPattern(op.getFalseValue(), m_ConstantInt(&value)) &&
2457 value.getBitWidth() == 1) {
2459 if (value.isZero()) {
2460 replaceOpWithNewOpAndCopyNamehint<AndOp>(rewriter, op, op.getCond(),
2461 op.getTrueValue(),
false);
2468 auto notCond = rewriter.createOrFold<
XorOp>(op.getLoc(), op.getCond(),
2469 op.getFalseValue(),
false);
2470 replaceOpWithNewOpAndCopyNamehint<OrOp>(rewriter, op, notCond,
2471 op.getTrueValue(),
false);
2477 Operation *condOp = op.getCond().getDefiningOp();
2478 if (condOp && matchPattern(condOp,
m_Complement(m_Any(&subExpr))) &&
2480 replaceOpWithNewOpAndCopyNamehint<MuxOp>(rewriter, op, op.getType(),
2481 subExpr, op.getFalseValue(),
2482 op.getTrueValue(),
true);
2489 if (condOp && condOp->hasOneUse()) {
2490 SmallVector<Value> invertedOperands;
2494 auto getInvertedOperands = [&]() ->
bool {
2495 for (Value operand : condOp->getOperands()) {
2496 if (matchPattern(operand,
m_Complement(m_Any(&subExpr))))
2497 invertedOperands.push_back(subExpr);
2504 if (isa<AndOp>(condOp) && getInvertedOperands()) {
2506 rewriter.createOrFold<
OrOp>(op.getLoc(), invertedOperands,
false);
2507 replaceOpWithNewOpAndCopyNamehint<MuxOp>(
2508 rewriter, op, newOr, op.getFalseValue(), op.getTrueValue(),
2512 if (isa<OrOp>(condOp) && getInvertedOperands()) {
2514 rewriter.createOrFold<
AndOp>(op.getLoc(), invertedOperands,
false);
2515 replaceOpWithNewOpAndCopyNamehint<MuxOp>(
2516 rewriter, op, newAnd, op.getFalseValue(), op.getTrueValue(),
2522 if (
auto falseMux = op.getFalseValue().getDefiningOp<
MuxOp>();
2523 falseMux && falseMux != op) {
2525 if (op.getCond() == falseMux.getCond()) {
2526 replaceOpWithNewOpAndCopyNamehint<MuxOp>(
2527 rewriter, op, op.getCond(), op.getTrueValue(),
2528 falseMux.getFalseValue(), op.getTwoStateAttr());
2537 if (
auto trueMux = op.getTrueValue().getDefiningOp<
MuxOp>();
2538 trueMux && trueMux != op) {
2540 if (op.getCond() == trueMux.getCond()) {
2541 replaceOpWithNewOpAndCopyNamehint<MuxOp>(
2542 rewriter, op, op.getCond(), trueMux.getTrueValue(),
2543 op.getFalseValue(), op.getTwoStateAttr());
2553 if (
auto trueMux = dyn_cast_or_null<MuxOp>(op.getTrueValue().getDefiningOp()),
2554 falseMux = dyn_cast_or_null<MuxOp>(op.getFalseValue().getDefiningOp());
2555 trueMux && falseMux && trueMux.getCond() == falseMux.getCond() &&
2556 trueMux.getTrueValue() == falseMux.getTrueValue() && trueMux != op &&
2558 auto subMux = MuxOp::create(
2559 rewriter, rewriter.getFusedLoc({trueMux.getLoc(), falseMux.getLoc()}),
2560 op.getCond(), trueMux.getFalseValue(), falseMux.getFalseValue());
2561 replaceOpWithNewOpAndCopyNamehint<MuxOp>(rewriter, op, trueMux.getCond(),
2562 trueMux.getTrueValue(), subMux,
2563 op.getTwoStateAttr());
2568 if (
auto trueMux = dyn_cast_or_null<MuxOp>(op.getTrueValue().getDefiningOp()),
2569 falseMux = dyn_cast_or_null<MuxOp>(op.getFalseValue().getDefiningOp());
2570 trueMux && falseMux && trueMux.getCond() == falseMux.getCond() &&
2571 trueMux.getFalseValue() == falseMux.getFalseValue() && trueMux != op &&
2573 auto subMux = MuxOp::create(
2574 rewriter, rewriter.getFusedLoc({trueMux.getLoc(), falseMux.getLoc()}),
2575 op.getCond(), trueMux.getTrueValue(), falseMux.getTrueValue());
2576 replaceOpWithNewOpAndCopyNamehint<MuxOp>(rewriter, op, trueMux.getCond(),
2577 subMux, trueMux.getFalseValue(),
2578 op.getTwoStateAttr());
2583 if (
auto trueMux = dyn_cast_or_null<MuxOp>(op.getTrueValue().getDefiningOp()),
2584 falseMux = dyn_cast_or_null<MuxOp>(op.getFalseValue().getDefiningOp());
2585 trueMux && falseMux &&
2586 trueMux.getTrueValue() == falseMux.getTrueValue() &&
2587 trueMux.getFalseValue() == falseMux.getFalseValue() && trueMux != op &&
2590 MuxOp::create(rewriter,
2591 rewriter.getFusedLoc(
2592 {op.getLoc(), trueMux.getLoc(), falseMux.getLoc()}),
2593 op.getCond(), trueMux.getCond(), falseMux.getCond());
2594 replaceOpWithNewOpAndCopyNamehint<MuxOp>(
2595 rewriter, op, subMux, trueMux.getTrueValue(), trueMux.getFalseValue(),
2596 op.getTwoStateAttr());
2608 if (Operation *trueOp = op.getTrueValue().getDefiningOp())
2609 if (Operation *falseOp = op.getFalseValue().getDefiningOp())
2610 if (trueOp->getName() == falseOp->getName())
2623 if (op.getTrueValue().getDefiningOp() &&
2624 op.getTrueValue().getDefiningOp() != op)
2627 if (op.getFalseValue().getDefiningOp() &&
2628 op.getFalseValue().getDefiningOp() != op)
2639 if (op.getInputs().empty() || op.isUniform())
2641 auto inputs = op.getInputs();
2642 if (inputs.size() <= 1)
2647 auto first = inputs[0].getDefiningOp<
comb::MuxOp>();
2652 for (
size_t i = 1, n = inputs.size(); i < n; ++i) {
2653 auto input = inputs[i].getDefiningOp<
comb::MuxOp>();
2654 if (!input || first.getCond() != input.getCond())
2659 SmallVector<Value> trues{first.getTrueValue()};
2660 SmallVector<Value> falses{first.getFalseValue()};
2661 SmallVector<Location> locs{first->getLoc()};
2662 bool isTwoState =
true;
2663 for (
size_t i = 1, n = inputs.size(); i < n; ++i) {
2664 auto input = inputs[i].getDefiningOp<
comb::MuxOp>();
2665 trues.push_back(input.getTrueValue());
2666 falses.push_back(input.getFalseValue());
2667 locs.push_back(input->getLoc());
2668 if (!input.getTwoState())
2673 auto loc = FusedLoc::get(op.getContext(), locs);
2677 auto arrayTy = op.getType();
2680 rewriter.replaceOpWithNewOp<
comb::MuxOp>(op, arrayTy, first.getCond(),
2681 trueValues, falseValues, isTwoState);
2686 using OpRewritePattern::OpRewritePattern;
2689 PatternRewriter &rewriter)
const override {
2690 if (foldArrayOfMuxes(op, rewriter))
2698void MuxOp::getCanonicalizationPatterns(RewritePatternSet &results,
2699 MLIRContext *context) {
2700 results.insert<MuxRewriter, ArrayRewriter>(context);
2711 switch (predicate) {
2712 case ICmpPredicate::eq:
2714 case ICmpPredicate::ne:
2716 case ICmpPredicate::slt:
2717 return lhs.slt(rhs);
2718 case ICmpPredicate::sle:
2719 return lhs.sle(rhs);
2720 case ICmpPredicate::sgt:
2721 return lhs.sgt(rhs);
2722 case ICmpPredicate::sge:
2723 return lhs.sge(rhs);
2724 case ICmpPredicate::ult:
2725 return lhs.ult(rhs);
2726 case ICmpPredicate::ule:
2727 return lhs.ule(rhs);
2728 case ICmpPredicate::ugt:
2729 return lhs.ugt(rhs);
2730 case ICmpPredicate::uge:
2731 return lhs.uge(rhs);
2732 case ICmpPredicate::ceq:
2734 case ICmpPredicate::cne:
2736 case ICmpPredicate::weq:
2738 case ICmpPredicate::wne:
2741 llvm_unreachable(
"unknown comparison predicate");
2747 switch (predicate) {
2748 case ICmpPredicate::eq:
2749 case ICmpPredicate::sle:
2750 case ICmpPredicate::sge:
2751 case ICmpPredicate::ule:
2752 case ICmpPredicate::uge:
2753 case ICmpPredicate::ceq:
2754 case ICmpPredicate::weq:
2756 case ICmpPredicate::ne:
2757 case ICmpPredicate::slt:
2758 case ICmpPredicate::sgt:
2759 case ICmpPredicate::ult:
2760 case ICmpPredicate::ugt:
2761 case ICmpPredicate::cne:
2762 case ICmpPredicate::wne:
2765 llvm_unreachable(
"unknown comparison predicate");
2768OpFoldResult ICmpOp::fold(FoldAdaptor adaptor) {
2771 if (getLhs() == getRhs()) {
2773 return IntegerAttr::get(getType(), val);
2777 if (
auto lhs = dyn_cast_or_null<IntegerAttr>(adaptor.getLhs())) {
2778 if (
auto rhs = dyn_cast_or_null<IntegerAttr>(adaptor.getRhs())) {
2781 return IntegerAttr::get(getType(), val);
2789template <
typename Range>
2791 size_t commonPrefixLength = 0;
2792 auto ia = a.begin();
2793 auto ib = b.begin();
2795 for (; ia != a.end() && ib != b.end(); ia++, ib++, commonPrefixLength++) {
2801 return commonPrefixLength;
2805 size_t totalWidth = 0;
2806 for (
auto operand : operands) {
2809 ssize_t width = operand.getType().getIntOrFloatBitWidth();
2811 totalWidth += width;
2821 PatternRewriter &rewriter) {
2825 SmallVector<Value> lhsOperands, rhsOperands;
2828 ArrayRef<Value> lhsOperandsRef = lhsOperands, rhsOperandsRef = rhsOperands;
2830 auto formCatOrReplicate = [&](Location loc,
2831 ArrayRef<Value> operands) -> Value {
2832 assert(!operands.empty());
2833 Value sameElement = operands[0];
2834 for (
size_t i = 1, e = operands.size(); i != e && sameElement; ++i)
2835 if (sameElement != operands[i])
2836 sameElement = Value();
2838 return rewriter.createOrFold<ReplicateOp>(loc, sameElement,
2840 return rewriter.createOrFold<
ConcatOp>(loc, operands);
2843 auto replaceWith = [&](ICmpPredicate predicate, Value lhs,
2844 Value rhs) -> LogicalResult {
2845 replaceOpWithNewOpAndCopyNamehint<ICmpOp>(rewriter, op, predicate, lhs, rhs,
2850 size_t commonPrefixLength =
2852 if (commonPrefixLength == lhsOperands.size()) {
2855 replaceOpWithNewOpAndCopyNamehint<hw::ConstantOp>(rewriter, op,
2861 llvm::reverse(lhsOperandsRef), llvm::reverse(rhsOperandsRef));
2863 size_t commonPrefixTotalWidth =
2864 getTotalWidth(lhsOperandsRef.take_front(commonPrefixLength));
2865 size_t commonSuffixTotalWidth =
2866 getTotalWidth(lhsOperandsRef.take_back(commonSuffixLength));
2867 auto lhsOnly = lhsOperandsRef.drop_front(commonPrefixLength)
2868 .drop_back(commonSuffixLength);
2869 auto rhsOnly = rhsOperandsRef.drop_front(commonPrefixLength)
2870 .drop_back(commonSuffixLength);
2872 auto replaceWithoutReplicatingSignBit = [&]() {
2873 auto newLhs = formCatOrReplicate(lhs->getLoc(), lhsOnly);
2874 auto newRhs = formCatOrReplicate(rhs->getLoc(), rhsOnly);
2875 return replaceWith(op.getPredicate(), newLhs, newRhs);
2878 auto replaceWithReplicatingSignBit = [&]() {
2879 auto firstNonEmptyValue = lhsOperands[0];
2880 auto firstNonEmptyElemWidth =
2881 firstNonEmptyValue.getType().getIntOrFloatBitWidth();
2882 Value signBit = rewriter.createOrFold<
ExtractOp>(
2883 op.getLoc(), firstNonEmptyValue, firstNonEmptyElemWidth - 1, 1);
2885 auto newLhs = ConcatOp::create(rewriter, lhs->getLoc(), signBit, lhsOnly);
2886 auto newRhs = ConcatOp::create(rewriter, rhs->getLoc(), signBit, rhsOnly);
2887 return replaceWith(op.getPredicate(), newLhs, newRhs);
2890 if (ICmpOp::isPredicateSigned(op.getPredicate())) {
2892 if (commonPrefixTotalWidth == 0 && commonSuffixTotalWidth > 0)
2893 return replaceWithoutReplicatingSignBit();
2899 if (commonPrefixTotalWidth > 1 || commonSuffixTotalWidth > 0)
2900 return replaceWithReplicatingSignBit();
2902 }
else if (commonPrefixTotalWidth > 0 || commonSuffixTotalWidth > 0) {
2904 return replaceWithoutReplicatingSignBit();
2918 ICmpOp cmpOp,
const KnownBits &bitAnalysis,
const APInt &rhsCst,
2919 PatternRewriter &rewriter) {
2923 APInt bitsKnown = bitAnalysis.Zero | bitAnalysis.One;
2924 if ((bitsKnown & rhsCst) != bitAnalysis.One) {
2927 bool result = cmpOp.getPredicate() == ICmpPredicate::ne;
2928 replaceOpWithNewOpAndCopyNamehint<hw::ConstantOp>(rewriter, cmpOp,
2936 SmallVector<Value> newConcatOperands;
2937 auto newConstant = APInt::getZeroWidth();
2942 unsigned knownMSB = bitsKnown.countLeadingOnes();
2944 Value operand = cmpOp.getLhs();
2949 while (knownMSB != bitsKnown.getBitWidth()) {
2952 bitsKnown = bitsKnown.trunc(bitsKnown.getBitWidth() - knownMSB);
2955 unsigned unknownBits = bitsKnown.countLeadingZeros();
2956 unsigned lowBit = bitsKnown.getBitWidth() - unknownBits;
2957 auto spanOperand = rewriter.createOrFold<
ExtractOp>(
2958 operand.getLoc(), operand, lowBit,
2960 auto spanConstant = rhsCst.lshr(lowBit).trunc(unknownBits);
2963 newConcatOperands.push_back(spanOperand);
2966 if (newConstant.getBitWidth() != 0)
2967 newConstant = newConstant.concat(spanConstant);
2969 newConstant = spanConstant;
2972 unsigned newWidth = bitsKnown.getBitWidth() - unknownBits;
2973 bitsKnown = bitsKnown.trunc(newWidth);
2974 knownMSB = bitsKnown.countLeadingOnes();
2980 if (newConcatOperands.empty()) {
2981 bool result = cmpOp.getPredicate() == ICmpPredicate::eq;
2982 replaceOpWithNewOpAndCopyNamehint<hw::ConstantOp>(rewriter, cmpOp,
2988 Value concatResult =
2989 rewriter.createOrFold<
ConcatOp>(operand.getLoc(), newConcatOperands);
2993 rewriter, cmpOp.getOperand(1).getLoc(), newConstant);
2995 replaceOpWithNewOpAndCopyNamehint<ICmpOp>(rewriter, cmpOp,
2996 cmpOp.getPredicate(), concatResult,
2997 newConstantOp, cmpOp.getTwoState());
3003 PatternRewriter &rewriter) {
3004 auto ip = rewriter.saveInsertionPoint();
3005 rewriter.setInsertionPoint(xorOp);
3007 auto xorRHS = xorOp.getOperands().back().getDefiningOp<
hw::ConstantOp>();
3009 xorRHS.getValue() ^ rhs);
3011 switch (xorOp.getNumOperands()) {
3015 APInt::getZero(rhs.getBitWidth()));
3019 newLHS = xorOp.getOperand(0);
3023 SmallVector<Value> newOperands(xorOp.getOperands());
3024 newOperands.pop_back();
3025 newLHS = XorOp::create(rewriter, xorOp.getLoc(), newOperands,
false);
3029 bool xorMultipleUses = !xorOp->hasOneUse();
3033 if (xorMultipleUses)
3034 replaceOpWithNewOpAndCopyNamehint<XorOp>(rewriter, xorOp, newLHS, xorRHS,
3038 rewriter.restoreInsertionPoint(ip);
3039 replaceOpWithNewOpAndCopyNamehint<ICmpOp>(
3040 rewriter, cmpOp, cmpOp.getPredicate(), newLHS, newRHS,
false);
3043LogicalResult ICmpOp::canonicalize(ICmpOp op, PatternRewriter &rewriter) {
3049 if (matchPattern(op.getLhs(), m_ConstantInt(&lhs))) {
3050 assert(!matchPattern(op.getRhs(), m_ConstantInt(&rhs)) &&
3051 "Should be folded");
3052 replaceOpWithNewOpAndCopyNamehint<ICmpOp>(
3053 rewriter, op, ICmpOp::getFlippedPredicate(op.getPredicate()),
3054 op.getRhs(), op.getLhs(), op.getTwoState());
3059 if (matchPattern(op.getRhs(), m_ConstantInt(&rhs))) {
3064 auto replaceWith = [&](ICmpPredicate predicate, Value lhs,
3065 Value rhs) -> LogicalResult {
3066 replaceOpWithNewOpAndCopyNamehint<ICmpOp>(rewriter, op, predicate, lhs,
3067 rhs, op.getTwoState());
3071 auto replaceWithConstantI1 = [&](
bool constant) -> LogicalResult {
3072 replaceOpWithNewOpAndCopyNamehint<hw::ConstantOp>(rewriter, op,
3073 APInt(1, constant));
3077 switch (op.getPredicate()) {
3078 case ICmpPredicate::slt:
3080 if (rhs.isMaxSignedValue())
3081 return replaceWith(ICmpPredicate::ne, op.getLhs(), op.getRhs());
3083 if (rhs.isMinSignedValue())
3084 return replaceWithConstantI1(0);
3086 if ((rhs - 1).isMinSignedValue())
3087 return replaceWith(ICmpPredicate::eq, op.getLhs(),
3090 case ICmpPredicate::sgt:
3092 if (rhs.isMinSignedValue())
3093 return replaceWith(ICmpPredicate::ne, op.getLhs(), op.getRhs());
3095 if (rhs.isMaxSignedValue())
3096 return replaceWithConstantI1(0);
3098 if ((rhs + 1).isMaxSignedValue())
3099 return replaceWith(ICmpPredicate::eq, op.getLhs(),
3102 case ICmpPredicate::ult:
3104 if (rhs.isAllOnes())
3105 return replaceWith(ICmpPredicate::ne, op.getLhs(), op.getRhs());
3108 return replaceWithConstantI1(0);
3110 if ((rhs - 1).isZero())
3111 return replaceWith(ICmpPredicate::eq, op.getLhs(),
3115 if (rhs.countLeadingOnes() + rhs.countTrailingZeros() ==
3116 rhs.getBitWidth()) {
3117 auto numOnes = rhs.countLeadingOnes();
3119 rhs.getBitWidth() - numOnes, numOnes);
3120 return replaceWith(ICmpPredicate::ne, smaller,
3125 case ICmpPredicate::ugt:
3128 return replaceWith(ICmpPredicate::ne, op.getLhs(), op.getRhs());
3130 if (rhs.isAllOnes())
3131 return replaceWithConstantI1(0);
3133 if ((rhs + 1).isAllOnes())
3134 return replaceWith(ICmpPredicate::eq, op.getLhs(),
3138 if ((rhs + 1).isPowerOf2()) {
3139 auto numOnes = rhs.countTrailingOnes();
3140 auto newWidth = rhs.getBitWidth() - numOnes;
3143 return replaceWith(ICmpPredicate::ne, smaller,
3148 case ICmpPredicate::sle:
3150 if (rhs.isMaxSignedValue())
3151 return replaceWithConstantI1(1);
3153 return replaceWith(ICmpPredicate::slt, op.getLhs(),
getConstant(rhs + 1));
3154 case ICmpPredicate::sge:
3156 if (rhs.isMinSignedValue())
3157 return replaceWithConstantI1(1);
3159 return replaceWith(ICmpPredicate::sgt, op.getLhs(),
getConstant(rhs - 1));
3160 case ICmpPredicate::ule:
3162 if (rhs.isAllOnes())
3163 return replaceWithConstantI1(1);
3165 return replaceWith(ICmpPredicate::ult, op.getLhs(),
getConstant(rhs + 1));
3166 case ICmpPredicate::uge:
3169 return replaceWithConstantI1(1);
3171 return replaceWith(ICmpPredicate::ugt, op.getLhs(),
getConstant(rhs - 1));
3172 case ICmpPredicate::eq:
3173 if (rhs.getBitWidth() == 1) {
3176 replaceOpWithNewOpAndCopyNamehint<XorOp>(rewriter, op, op.getLhs(),
3181 if (rhs.isAllOnes()) {
3188 case ICmpPredicate::ne:
3189 if (rhs.getBitWidth() == 1) {
3195 if (rhs.isAllOnes()) {
3197 replaceOpWithNewOpAndCopyNamehint<XorOp>(rewriter, op, op.getLhs(),
3204 case ICmpPredicate::ceq:
3205 case ICmpPredicate::cne:
3206 case ICmpPredicate::weq:
3207 case ICmpPredicate::wne:
3213 if (op.getPredicate() == ICmpPredicate::eq ||
3214 op.getPredicate() == ICmpPredicate::ne) {
3219 if (!knownBits.isUnknown())
3226 if (
auto xorOp = op.getLhs().getDefiningOp<
XorOp>())
3233 if (
auto replicateOp = op.getLhs().getDefiningOp<ReplicateOp>())
3234 if (rhs.isAllOnes() || rhs.isZero()) {
3235 auto width = replicateOp.getInput().getType().getIntOrFloatBitWidth();
3238 rhs.isAllOnes() ? APInt::getAllOnes(width)
3239 : APInt::getZero(width));
3240 replaceOpWithNewOpAndCopyNamehint<ICmpOp>(
3241 rewriter, op, op.getPredicate(), replicateOp.getInput(), cst,
3251 if (Operation *opLHS = op.getLhs().getDefiningOp())
3252 if (Operation *opRHS = op.getRhs().getDefiningOp())
3253 if (isa<ConcatOp, ReplicateOp>(opLHS) &&
3254 isa<ConcatOp, ReplicateOp>(opRHS)) {
assert(baseType &&"element must be base type")
static SmallVector< T > concat(const SmallVectorImpl< T > &a, const SmallVectorImpl< T > &b)
Returns a new vector containing the concatenation of vectors a and b.
static KnownBits computeKnownBits(Value v, unsigned depth)
Given an integer SSA value, check to see if we know anything about the result of the computation.
static bool foldMuxOfUniformArrays(MuxOp op, PatternRewriter &rewriter)
static Attribute constFoldAssociativeOp(ArrayRef< Attribute > operands, hw::PEO paramOpcode)
static Attribute constFoldBinaryOp(ArrayRef< Attribute > operands, hw::PEO paramOpcode)
Performs a constant-folding calculation with element-wise behavior on the two attributes in operands and ...
static bool applyCmpPredicateToEqualOperands(ICmpPredicate predicate)
static ComplementMatcher< SubType > m_Complement(const SubType &subExpr)
static bool canonicalizeLogicalCstWithConcat(Operation *logicalOp, size_t concatIdx, const APInt &cst, PatternRewriter &rewriter)
When we find a logical operation (and, or, xor) with a constant e.g.
static bool narrowOperationWidth(OpTy op, bool narrowTrailingBits, PatternRewriter &rewriter)
static OpFoldResult foldDiv(Op op, ArrayRef< Attribute > constants)
static Value getCommonOperand(Op op)
Returns a single common operand that all inputs of the operation op can be traced back to,...
static bool canCombineOppositeBinCmpIntoConstant(OperandRange operands)
static void getConcatOperands(Value v, SmallVectorImpl< Value > &result)
Flatten concat and replicate operands into a vector.
static Value extractOperandFromFullyAssociative(Operation *fullyAssoc, size_t operandNo, PatternRewriter &rewriter)
Given a fully associative variadic operation like (a+b+c+d), break the expression into two parts,...
static bool getMuxChainCondConstant(Value cond, Value indexValue, bool isInverted, std::function< void(hw::ConstantOp)> constantFn)
Check to see if the condition to the specified mux is an equality comparison indexValue and one or mo...
static TypedAttr getIntAttr(const APInt &value, MLIRContext *context)
static bool shouldBeFlattened(Operation *op)
Return true if the op will be flattened afterwards.
static void canonicalizeXorIcmpTrue(XorOp op, unsigned icmpOperand, PatternRewriter &rewriter)
static bool assumeMuxCondInOperand(Value muxCond, Value muxValue, bool constCond, PatternRewriter &rewriter)
If the mux condition is an operand to the op defining its true or false value, replace the condition ...
static bool extractFromReplicate(ExtractOp op, ReplicateOp replicate, PatternRewriter &rewriter)
static void combineEqualityICmpWithXorOfConstant(ICmpOp cmpOp, XorOp xorOp, const APInt &rhs, PatternRewriter &rewriter)
static size_t getTotalWidth(ArrayRef< Value > operands)
static bool foldCommonMuxOperation(MuxOp mux, Operation *trueOp, Operation *falseOp, PatternRewriter &rewriter)
This function is invoked when we find a mux with true/false operations that have the same opcode.
static std::pair< size_t, size_t > getLowestBitAndHighestBitRequired(Operation *op, bool narrowTrailingBits, size_t originalOpWidth)
static bool tryFlatteningOperands(Operation *op, PatternRewriter &rewriter)
Flattens a single input in op if hasOneUse is true and it can be defined as an Op.
static bool isOpTriviallyRecursive(Operation *op)
static bool canonicalizeIdempotentInputs(Op op, PatternRewriter &rewriter)
Canonicalize an idempotent operation op so that only one input of any kind occurs.
static bool applyCmpPredicate(ICmpPredicate predicate, const APInt &lhs, const APInt &rhs)
static void combineEqualityICmpWithKnownBitsAndConstant(ICmpOp cmpOp, const KnownBits &bitAnalysis, const APInt &rhsCst, PatternRewriter &rewriter)
Given an equality comparison with a constant value and some operand that has known bits,...
static bool foldMuxChain(MuxOp rootMux, bool isFalseSide, PatternRewriter &rewriter)
Given a mux, check to see if the "on true" value (or "on false" value if isFalseSide=true) is a mux t...
static bool hasSVAttributes(Operation *op)
static LogicalResult extractConcatToConcatExtract(ExtractOp op, ConcatOp innerCat, PatternRewriter &rewriter)
static OpFoldResult foldMod(Op op, ArrayRef< Attribute > constants)
static size_t computeCommonPrefixLength(const Range &a, const Range &b)
static bool foldCommonMuxValue(MuxOp op, bool isTrueOperand, PatternRewriter &rewriter)
Fold things like mux(cond, x|y|z|a, a) -> (x|y|z)&replicate(cond)|a and mux(cond, a,...
static LogicalResult matchAndRewriteCompareConcat(ICmpOp op, Operation *lhs, Operation *rhs, PatternRewriter &rewriter)
Reduce the strength of icmp(concat(...), concat(...)) by doing an element-wise comparison on common prefi...
static Value createGenericOp(Location loc, OperationName name, ArrayRef< Value > operands, OpBuilder &builder)
Create a new instance of a generic operation that only has value operands, and has a single result va...
static TypedAttr getIntAttr(MLIRContext *ctx, Type t, const APInt &value)
static std::optional< APSInt > getConstant(Attribute operand)
Determine the value of a constant operand for the sake of constant folding.
create(array_value, low_index, ret_type)
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
Value createOrFoldNot(Location loc, Value value, OpBuilder &builder, bool twoState=false)
Create a 'Not' gate on a value.
KnownBits computeKnownBits(Value value)
Compute "known bits" information about the specified value - the set of bits that are guaranteed to a...
uint64_t getWidth(Type t)
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
void replaceOpAndCopyNamehint(PatternRewriter &rewriter, Operation *op, Value newValue)
A wrapper of PatternRewriter::replaceOp to propagate "sv.namehint" attribute.