13#include "mlir/IR/Diagnostics.h"
14#include "mlir/IR/Matchers.h"
15#include "mlir/IR/PatternMatch.h"
16#include "llvm/ADT/SetVector.h"
17#include "llvm/ADT/SmallBitVector.h"
18#include "llvm/ADT/TypeSwitch.h"
19#include "llvm/Support/KnownBits.h"
24using namespace matchers;
// Returns true if any operand of `op` is produced by `op` itself, i.e. the
// operation feeds one of its own results back in as an input.
28 return llvm::any_of(op->getOperands(), [op](
auto operand) {
29 return operand.getDefiningOp() == op;
// Creates an operation named `name` at `loc` with `operands`, gives it a
// single result whose type is copied from the first operand, and returns
// that result.
// NOTE(review): operands[0] is read unconditionally, so callers must pass a
// non-empty operand list.
39 ArrayRef<Value> operands, OpBuilder &builder) {
40 OperationState state(loc, name);
41 state.addOperands(operands);
42 state.addTypes(operands[0].getType());
43 return builder.create(state)->getResult(0);
// Returns an IntegerAttr for the APInt `value`, typed as an integer type
// whose bit width matches `value`.
47 return IntegerAttr::get(IntegerType::get(
context, value.getBitWidth()),
// Walks the operands of `concat` in order. For a value `v` defined by a
// ReplicateOp, it iterates getMultiple() times (presumably appending the
// replicated operand once per repetition; TODO confirm against the loop
// body).
54 for (
auto op :
concat.getOperands())
56 }
else if (
auto repl = v.getDefiningOp<ReplicateOp>()) {
57 for (
size_t i = 0, e = repl.getMultiple(); i != e; ++i)
// Returns true if `op` carries an attribute named "sv.attributes".
68 return op->hasAttr(
"sv.attributes");
/// Pattern matcher that succeeds on a XorOp which is a bitwise NOT
/// (`isBinaryNot()`) and whose first operand satisfies the sub-matcher `lhs`.
72template <
typename SubType>
73struct ComplementMatcher {
75 ComplementMatcher(SubType lhs) : lhs(std::move(lhs)) {}
76 bool match(Operation *op) {
77 auto xorOp = dyn_cast<XorOp>(op);
// Only a complement qualifies; the complemented value is operand 0.
78 return xorOp && xorOp.isBinaryNot() && lhs.match(op->getOperand(0));
/// Builds a ComplementMatcher: matches `~x` where `x` satisfies `subExpr`.
83template <
typename SubType>
84static inline ComplementMatcher<SubType>
m_Complement(
const SubType &subExpr) {
85 return ComplementMatcher<SubType>(subExpr);
// For a commutative op (And/Or/Xor/Add/Mul), reports whether it will be
// absorbed into its user. This is true when `op` has exactly one use, and
// that user has the same operation name, the same "twoState" attribute,
// and the same parent block.
91 assert((isa<AndOp, OrOp, XorOp, AddOp, MulOp>(op) &&
92 "must be commutative operations"));
93 if (op->hasOneUse()) {
94 auto *user = *op->getUsers().begin();
95 return user->getName() == op->getName() &&
96 op->getAttrOfType<UnitAttr>(
"twoState") ==
97 user->getAttrOfType<UnitAttr>(
"twoState") &&
98 op->getBlock() == user->getBlock();
// Flattens nested operands of the same kind into `op`. An operand is
// absorbed when it is defined by a different op that has:
//  - the same operation name,
//  - the same parent block, and
//  - the same "twoState" attribute as `op`.
// Absorbed operands are replaced by that op's own inputs, walked depth-first
// through a worklist of operand ranges. The flattened op is rebuilt with
// createGenericOp, using a FusedLoc of `op` and every absorbed op's location.
113 auto inputs = op->getOperands();
115 SmallVector<Value, 4> newOperands;
116 SmallVector<Location, 4> newLocations{op->getLoc()};
117 newOperands.reserve(inputs.size());
119 decltype(inputs.begin()) current, end;
122 SmallVector<Element> worklist;
123 worklist.push_back({inputs.begin(), inputs.end()});
// twoState flag of the op being rebuilt; nested ops must match it.
124 bool binFlag = op->hasAttrOfType<UnitAttr>(
"twoState");
125 bool changed =
false;
126 while (!worklist.empty()) {
127 auto &element = worklist.back();
130 if (element.current == element.end) {
135 Value value = *element.current++;
136 auto *flattenOp = value.getDefiningOp();
// Keep `value` as-is unless it comes from a different same-named op in the
// same block whose twoState attribute matches `op`'s (`binFlag`). The
// twoState test must inspect `flattenOp`: comparing `binFlag` against `op`
// itself is always false, which would merge ops with mismatched twoState
// semantics.
139 if (!flattenOp || flattenOp->getName() != op->getName() ||
140 flattenOp == op || binFlag != flattenOp->hasAttrOfType<UnitAttr>(
"twoState") ||
141 flattenOp->getBlock() != op->getBlock()) {
142 newOperands.push_back(value);
147 if (!value.hasOneUse()) {
155 if (flattenOp->getNumOperands() != 2 || !isa<AndOp, OrOp, XorOp>(op) ||
158 newOperands.push_back(value);
166 auto flattenOpInputs = flattenOp->getOperands();
// Descend into the absorbed op's inputs and record its location for the
// fused location of the rebuilt op.
167 worklist.push_back({flattenOpInputs.begin(), flattenOpInputs.end()});
168 newLocations.push_back(flattenOp->getLoc());
174 Value result =
createGenericOp(FusedLoc::get(op->getContext(), newLocations),
175 op->getName(), newOperands, rewriter);
177 result.getDefiningOp()->setAttr(
"twoState", rewriter.getUnitAttr());
// Returns the inclusive [lowest, highest] range of `op`'s result bits that
// its users read.
//  - Each ExtractOp user widens the range to cover the bits it extracts.
//  - The assignments after the ExtractOp branch reset the range to the full
//    width [0, originalOpWidth - 1] (presumably for any non-extract user).
//  - When `narrowTrailingBits` is false the lower bound starts at 0, so only
//    the high end can shrink.
// The assertion message indicates `op` must have at least one use.
185static std::pair<size_t, size_t>
187 size_t originalOpWidth) {
188 auto users = op->getUsers();
190 "getLowestBitAndHighestBitRequired cannot operate on "
191 "a empty list of uses.");
// Start lowest at the top bit (or 0 when trailing bits are kept) and highest
// at 0; users then widen the range.
195 size_t lowestBitRequired = narrowTrailingBits ? originalOpWidth - 1 : 0;
196 size_t highestBitRequired = 0;
198 for (
auto *user : users) {
199 if (
auto extractOp = dyn_cast<ExtractOp>(user)) {
200 size_t lowBit = extractOp.getLowBit();
202 cast<IntegerType>(extractOp.getType()).getWidth() + lowBit - 1;
203 highestBitRequired = std::max(highestBitRequired, highBit);
204 lowestBitRequired = std::min(lowestBitRequired, lowBit);
208 highestBitRequired = originalOpWidth - 1;
209 lowestBitRequired = 0;
213 return {lowestBitRequired, highestBitRequired};
// Narrows `op` (an OpTy with an integer result) to the bit range `range`.
// `range` is presumably computed by getLowestBitAndHighestBitRequired
// (TODO confirm). The rewrite:
//  - leaves operands whose type differs from the result unchanged;
//  - replaces other operands with an ExtractOp of the narrowed width;
//  - builds a narrower OpTy, copying the dialect attributes and twoState;
//  - zero-pads the result back to full width with ConcatOps (range.first
//    zero bits below, width - range.second - 1 zero bits above).
// Nothing is rewritten when the range already spans the whole result.
218 PatternRewriter &rewriter) {
219 IntegerType opType = dyn_cast<IntegerType>(op.getResult().getType());
225 if (range.second + 1 == opType.getWidth() && range.first == 0)
228 SmallVector<Value> args;
229 auto newType = rewriter.getIntegerType(range.second - range.first + 1);
230 for (
auto inop : op.getOperands()) {
232 if (inop.getType() != op.getType())
233 args.push_back(inop);
235 args.push_back(rewriter.createOrFold<
ExtractOp>(inop.getLoc(), newType,
238 auto newop = OpTy::create(rewriter, op.getLoc(), newType, args);
239 newop->setDialectAttrs(op->getDialectAttrs());
240 if (op.getTwoState())
241 newop.setTwoState(
true);
243 Value newResult = newop.getResult();
// Concat's first operand is the high part, so the zeros fill the low bits.
245 newResult = rewriter.createOrFold<
ConcatOp>(
246 op.getLoc(), newResult,
248 APInt::getZero(range.first)));
249 if (range.second + 1 < opType.getWidth())
250 newResult = rewriter.createOrFold<
ConcatOp>(
253 rewriter, op.getLoc(),
254 APInt::getZero(opType.getWidth() - range.second - 1)),
256 rewriter.replaceOp(op, newResult);
// Folds ReplicateOp in two cases:
//  - The result width equals the input width, i.e. a single repetition
//    (the return is presumably the input).
//  - The input is a constant:
//      * a 1-bit zero becomes an all-zero constant of the result width;
//      * otherwise a 1-bit input becomes an all-ones constant;
//      * a wider constant is concatenated with itself getMultiple() times.
264OpFoldResult ReplicateOp::fold(FoldAdaptor adaptor) {
269 if (cast<IntegerType>(getType()).
getWidth() ==
270 getInput().getType().getIntOrFloatBitWidth())
274 if (
auto input = dyn_cast_or_null<IntegerAttr>(adaptor.getInput())) {
275 if (input.getValue().getBitWidth() == 1) {
276 if (input.getValue().isZero())
278 APInt::getZero(cast<IntegerType>(getType()).
getWidth()),
281 APInt::getAllOnes(cast<IntegerType>(getType()).
getWidth()),
285 APInt result = APInt::getZeroWidth();
286 for (
auto i = getMultiple(); i != 0; --i)
287 result = result.concat(input.getValue());
// Folds ParityOp on a constant input to a 1-bit constant: the input's
// popcount modulo 2.
294OpFoldResult ParityOp::fold(FoldAdaptor adaptor) {
299 if (
auto input = dyn_cast_or_null<IntegerAttr>(adaptor.getInput()))
300 return getIntAttr(APInt(1, input.getValue().popcount() & 1), getContext());
// Constant-folds a two-operand op into an hw::ParamExprAttr with opcode
// `paramOpcode`. Folding is not done when either operand attribute is null.
312 hw::PEO paramOpcode) {
313 assert(operands.size() == 2 &&
"binary op takes two operands");
314 if (!operands[0] || !operands[1])
319 return hw::ParamExprAttr::get(paramOpcode, cast<TypedAttr>(operands[0]),
320 cast<TypedAttr>(operands[1]));
// Folds a left shift by a constant amount. A shift by zero yields the LHS;
// a shift by at least the bit width yields an all-zero constant.
323OpFoldResult ShlOp::fold(FoldAdaptor adaptor) {
327 if (
auto rhs = dyn_cast_or_null<IntegerAttr>(adaptor.getRhs())) {
328 if (rhs.getValue().isZero())
329 return getOperand(0);
331 unsigned width = getType().getIntOrFloatBitWidth();
332 if (rhs.getValue().uge(width))
333 return getIntAttr(APInt::getZero(width), getContext());
// Rewrites `x << c`, for a constant 0 < c < width, as
// ConcatOp(extract, zeros): the extracted high part sits above `zeros`.
// Shift amounts of zero or >= width are not rewritten here.
338LogicalResult ShlOp::canonicalize(
ShlOp op, PatternRewriter &rewriter) {
344 if (!matchPattern(op.getRhs(), m_ConstantInt(&value)))
347 unsigned width = cast<IntegerType>(op.getLhs().getType()).getWidth();
// Bounding the constant by `width` first presumably guarantees the
// getZExtValue() below fits.
348 if (value.ugt(width))
350 unsigned shift = value.getZExtValue();
353 if (width <= shift || shift == 0)
363 replaceOpWithNewOpAndCopyNamehint<ConcatOp>(rewriter, op, extract, zeros);
// Folds a logical right shift by a constant amount. A shift by zero yields
// the LHS; a shift by at least the bit width yields an all-zero constant.
367OpFoldResult ShrUOp::fold(FoldAdaptor adaptor) {
371 if (
auto rhs = dyn_cast_or_null<IntegerAttr>(adaptor.getRhs())) {
372 if (rhs.getValue().isZero())
373 return getOperand(0);
375 unsigned width = getType().getIntOrFloatBitWidth();
376 if (rhs.getValue().uge(width))
377 return getIntAttr(APInt::getZero(width), getContext());
// Rewrites `x >>u c`, for a constant 0 < c < width, as
// ConcatOp(zeros, extract): zero high bits above the shifted-down part.
// Shift amounts of zero or >= width are not rewritten here.
382LogicalResult ShrUOp::canonicalize(
ShrUOp op, PatternRewriter &rewriter) {
388 if (!matchPattern(op.getRhs(), m_ConstantInt(&value)))
391 unsigned width = cast<IntegerType>(op.getLhs().getType()).getWidth();
// Bounding the constant by `width` first presumably guarantees the
// getZExtValue() below fits.
392 if (value.ugt(width))
394 unsigned shift = value.getZExtValue();
397 if (width <= shift || shift == 0)
407 replaceOpWithNewOpAndCopyNamehint<ConcatOp>(rewriter, op, zeros, extract);
// Folds an arithmetic right shift by a constant zero to its LHS.
411OpFoldResult ShrSOp::fold(FoldAdaptor adaptor) {
415 if (
auto rhs = dyn_cast_or_null<IntegerAttr>(adaptor.getRhs()))
416 if (rhs.getValue().isZero())
417 return getOperand(0);
// Rewrites `x >>s c` for a constant c <= width. The sign bit (bit width-1
// of the LHS) is replicated `shift` times and concatenated above `extract`.
// The `width == shift` case has its own branch (presumably the
// sign-extension alone).
421LogicalResult ShrSOp::canonicalize(
ShrSOp op, PatternRewriter &rewriter) {
427 if (!matchPattern(op.getRhs(), m_ConstantInt(&value)))
430 unsigned width = cast<IntegerType>(op.getLhs().getType()).getWidth();
431 if (value.ugt(width))
433 unsigned shift = value.getZExtValue();
// `topbit` is the single most-significant bit of the LHS.
436 rewriter.createOrFold<
ExtractOp>(op.getLoc(), op.getLhs(), width - 1, 1);
437 auto sext = rewriter.createOrFold<ReplicateOp>(op.getLoc(), topbit, shift);
439 if (width == shift) {
447 replaceOpWithNewOpAndCopyNamehint<ConcatOp>(rewriter, op, sext, extract);
// Folds ExtractOp in two cases:
//  - The result type equals the input type (presumably folding to the
//    input).
//  - The input is a constant: the result is (input >> lowBit) truncated to
//    the result width.
455OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
460 if (getInput().getType() == getType())
464 if (
auto input = dyn_cast_or_null<IntegerAttr>(adaptor.getInput())) {
465 unsigned dstWidth = cast<IntegerType>(getType()).getWidth();
466 return getIntAttr(input.getValue().lshr(getLowBit()).trunc(dstWidth),
// Rewrites extract(concat(...)) as a concat of only the concat inputs (or
// slices of them) that overlap the extracted range.
//  - Concat inputs are walked from the least-significant end, in reversed
//    order.
//  - The first loop advances to the input that contains `lowBit`.
//  - The second loop collects whole inputs, or `widthToConsume`-bit slices,
//    until the result width is covered.
//  - A single collected piece has its own branch. Otherwise the pieces are
//    reversed back into MSB-first order for the new ConcatOp.
477 PatternRewriter &rewriter) {
478 auto reversedConcatArgs = llvm::reverse(innerCat.getInputs());
479 size_t beginOfFirstRelevantElement = 0;
480 auto it = reversedConcatArgs.begin();
481 size_t lowBit = op.getLowBit();
484 for (; it != reversedConcatArgs.end(); it++) {
485 assert(beginOfFirstRelevantElement <= lowBit &&
486 "incorrectly moved past an element that lowBit has coverage over");
489 size_t operandWidth = operand.getType().getIntOrFloatBitWidth();
490 if (lowBit < beginOfFirstRelevantElement + operandWidth) {
514 beginOfFirstRelevantElement += operandWidth;
516 assert(it != reversedConcatArgs.end() &&
517 "incorrectly failed to find an element which contains coverage of "
520 SmallVector<Value> reverseConcatArgs;
521 size_t widthRemaining = cast<IntegerType>(op.getType()).getWidth();
// Offset of lowBit inside the first relevant concat input.
522 size_t extractLo = lowBit - beginOfFirstRelevantElement;
527 for (; widthRemaining != 0 && it != reversedConcatArgs.end(); it++) {
528 auto concatArg = *it;
529 size_t operandWidth = concatArg.getType().getIntOrFloatBitWidth();
530 size_t widthToConsume = std::min(widthRemaining, operandWidth - extractLo);
// Use the input directly when it is consumed in full; otherwise slice it.
532 if (widthToConsume == operandWidth && extractLo == 0) {
533 reverseConcatArgs.push_back(concatArg);
535 auto resultType = IntegerType::get(rewriter.getContext(), widthToConsume);
536 reverseConcatArgs.push_back(
540 widthRemaining -= widthToConsume;
546 if (reverseConcatArgs.size() == 1) {
549 replaceOpWithNewOpAndCopyNamehint<ConcatOp>(
550 rewriter, op, SmallVector<Value>(llvm::reverse(reverseConcatArgs)));
// Simplifies extract(replicate(x)):
//  - If both the low bit and the result width are multiples of x's width,
//    the result is a ReplicateOp of x.
//  - If the extracted range fits inside one copy of x (bound presumably
//    x's width), the result is an ExtractOp of x at lowBit % width(x).
557 PatternRewriter &rewriter) {
558 auto extractResultWidth = cast<IntegerType>(op.getType()).getWidth();
559 auto replicateEltWidth =
560 replicate.getOperand().getType().getIntOrFloatBitWidth();
564 if (op.getLowBit() % replicateEltWidth == 0 &&
565 extractResultWidth % replicateEltWidth == 0) {
566 replaceOpWithNewOpAndCopyNamehint<ReplicateOp>(rewriter, op, op.getType(),
567 replicate.getOperand());
573 if (op.getLowBit() % replicateEltWidth + extractResultWidth <=
575 replaceOpWithNewOpAndCopyNamehint<ExtractOp>(
576 rewriter, op, op.getType(), replicate.getOperand(),
577 op.getLowBit() % replicateEltWidth);
// Canonicalizes ExtractOp:
//  - When the known bits of the extracted range are constant, replace the
//    extract with an hw.constant.
//  - extract(extract(x)) becomes a single extract whose low bits are summed.
//  - Concat and replicate inputs get dedicated handling.
//  - extract(and/or/xor(x, cst)) is simplified using the extracted slice of
//    `cst`.
//  - A 1-bit extract of a single-use shl whose LHS constant is one becomes
//    an equality compare of the shift amount against lowBit.
586LogicalResult ExtractOp::canonicalize(
ExtractOp op, PatternRewriter &rewriter) {
589 auto *inputOp = op.getInput().getDefiningOp();
596 .extractBits(cast<IntegerType>(op.getType()).getWidth(),
598 if (knownBits.isConstant()) {
599 replaceOpWithNewOpAndCopyNamehint<hw::ConstantOp>(rewriter, op,
600 knownBits.getConstant());
// extract(extract(x)) -> extract(x) at the combined low bit.
606 if (
auto innerExtract = dyn_cast_or_null<ExtractOp>(inputOp)) {
607 replaceOpWithNewOpAndCopyNamehint<ExtractOp>(
608 rewriter, op, op.getType(), innerExtract.getInput(),
609 innerExtract.getLowBit() + op.getLowBit());
614 if (
auto innerCat = dyn_cast_or_null<ConcatOp>(inputOp))
618 if (
auto replicate = dyn_cast_or_null<ReplicateOp>(inputOp))
// Two-operand logical op whose second operand is a constant.
624 if (inputOp && inputOp->getNumOperands() == 2 &&
625 isa<AndOp, OrOp, XorOp>(inputOp)) {
626 if (
auto cstRHS = inputOp->getOperand(1).getDefiningOp<
hw::ConstantOp>()) {
627 auto extractedCst = cstRHS.getValue().extractBits(
628 cast<IntegerType>(op.getType()).getWidth(), op.getLowBit());
// OR/XOR with zero bits in the extracted range leave those bits unchanged.
629 if (isa<OrOp, XorOp>(inputOp) && extractedCst.isZero()) {
630 replaceOpWithNewOpAndCopyNamehint<ExtractOp>(
631 rewriter, op, op.getType(), inputOp->getOperand(0), op.getLowBit());
// AND with a mask that is one contiguous run of ones, written
// concat(zeros(lz), extract(x) of `pop` bits, zeros(tz)).
639 if (isa<AndOp>(inputOp)) {
642 unsigned lz = extractedCst.countLeadingZeros();
643 unsigned tz = extractedCst.countTrailingZeros();
644 unsigned pop = extractedCst.popcount();
645 if (extractedCst.getBitWidth() - lz - tz == pop) {
646 auto resultTy = rewriter.getIntegerType(pop);
647 SmallVector<Value> resultElts;
650 APInt::getZero(lz)));
651 resultElts.push_back(rewriter.createOrFold<
ExtractOp>(
652 op.getLoc(), resultTy, inputOp->getOperand(0),
653 op.getLowBit() + tz));
656 APInt::getZero(tz)));
657 replaceOpWithNewOpAndCopyNamehint<ConcatOp>(rewriter, op, resultElts);
// 1-bit extract at lowBit of (1 << amt) is true exactly when amt == lowBit.
666 if (cast<IntegerType>(op.getType()).getWidth() == 1 && inputOp)
667 if (
auto shlOp = dyn_cast<ShlOp>(inputOp)) {
669 if (shlOp->hasOneUse())
671 if (lhsCst.getValue().isOne()) {
673 rewriter, shlOp.getLoc(),
674 APInt(lhsCst.getValue().getBitWidth(), op.getLowBit()));
675 replaceOpWithNewOpAndCopyNamehint<ICmpOp>(
676 rewriter, op, ICmpPredicate::eq, shlOp->getOperand(1), newCst,
// Constant-folds a variadic associative op into an hw::ParamExprAttr with
// opcode `paramOpcode`. This requires every operand attribute to be
// non-null and a TypedAttr. The single-operand case must be handled by the
// caller (see the assert).
692 hw::PEO paramOpcode) {
693 assert(operands.size() > 1 &&
"caller should handle one-operand case");
696 if (!operands[1] || !operands[0])
700 if (llvm::all_of(operands.drop_front(2),
701 [&](Attribute in) { return !!in; })) {
702 SmallVector<mlir::TypedAttr> typedOperands;
703 typedOperands.reserve(operands.size());
704 for (
auto operand : operands) {
705 if (
auto typedOperand = dyn_cast<mlir::TypedAttr>(operand))
706 typedOperands.push_back(typedOperand);
// Fold only if every operand converted to a TypedAttr.
710 if (typedOperands.size() == operands.size())
711 return hw::ParamExprAttr::get(paramOpcode, typedOperands);
// Distributes a logical op (And/Or/Xor) with constant `cst` over the inputs
// of the ConcatOp at operand `concatIdx`. Each concat input is combined with
// the matching slice of `cst`, and the results are re-concatenated.
// The any_of check looks for a concat input that is either a constant or a
// single-use op of the same kind whose last operand is a constant
// (presumably so the rewrite exposes further folding).
// When the logical op has more than two operands, the rebuilt op takes the
// other operands (excluding the concat and the final operand) plus the new
// concat.
727 size_t concatIdx,
const APInt &cst,
728 PatternRewriter &rewriter) {
729 auto concatOp = logicalOp->getOperand(concatIdx).getDefiningOp<
ConcatOp>();
730 assert((isa<AndOp, OrOp, XorOp>(logicalOp) && concatOp));
735 llvm::any_of(concatOp->getOperands(), [&](Value operand) ->
bool {
736 auto *operandOp = operand.getDefiningOp();
741 if (isa<hw::ConstantOp>(operandOp))
745 return operandOp->getName() == logicalOp->getName() &&
746 operandOp->hasOneUse() && operandOp->getNumOperands() != 0 &&
747 operandOp->getOperands().back().getDefiningOp<hw::ConstantOp>();
755 auto createLogicalOp = [&](ArrayRef<Value> operands) -> Value {
756 return createGenericOp(logicalOp->getLoc(), logicalOp->getName(), operands,
763 SmallVector<Value> newConcatOperands;
764 newConcatOperands.reserve(concatOp->getNumOperands());
// Concat inputs are MSB-first, so bit positions count down from the top.
767 size_t nextOperandBit = concatOp.getType().getIntOrFloatBitWidth();
768 for (Value operand : concatOp->getOperands()) {
769 size_t operandWidth = operand.getType().getIntOrFloatBitWidth();
770 nextOperandBit -= operandWidth;
774 cst.lshr(nextOperandBit).trunc(operandWidth));
776 newConcatOperands.push_back(createLogicalOp({operand, eltCst}));
781 ConcatOp::create(rewriter, concatOp.getLoc(), newConcatOperands);
785 if (logicalOp->getNumOperands() > 2) {
786 auto origOperands = logicalOp->getOperands();
787 SmallVector<Value> operands;
789 operands.append(origOperands.begin(), origOperands.begin() + concatIdx);
791 operands.append(origOperands.begin() + concatIdx + 1,
792 origOperands.begin() + (origOperands.size() - 1));
794 operands.push_back(newResult);
795 newResult = createLogicalOp(operands);
// Scans `operands` for two-state ICmpOps over the same (lhs, rhs) pair with
// mutually negated predicates. The match branch presumably reports that the
// pair can be combined into a constant. Each visited compare is recorded as
// a (predicate, lhs, rhs) tuple.
805 llvm::SmallDenseSet<std::tuple<ICmpPredicate, Value, Value>> seenPredicates;
807 for (
auto op : operands) {
808 if (
auto icmpOp = op.getDefiningOp<ICmpOp>();
809 icmpOp && icmpOp.getTwoState()) {
810 auto predicate = icmpOp.getPredicate();
811 auto lhs = icmpOp.getLhs();
812 auto rhs = icmpOp.getRhs();
813 if (seenPredicates.contains(
814 {ICmpOp::getNegatedPredicate(predicate), lhs, rhs}))
817 seenPredicates.insert({predicate, lhs, rhs});
// Folds AndOp:
//  - Constant inputs are ANDed together, starting from all-ones.
//  - A two-input AND whose second input is all-ones folds to the first
//    input.
//  - An AND of identical inputs folds to that input.
//  - A pairwise scan of the inputs can produce an all-zero constant
//    (presumably for `x & ~x`).
823OpFoldResult AndOp::fold(FoldAdaptor adaptor) {
827 APInt value = APInt::getAllOnes(cast<IntegerType>(getType()).
getWidth());
829 auto inputs = adaptor.getInputs();
832 for (
auto operand : inputs) {
833 auto attr = dyn_cast_or_null<IntegerAttr>(operand);
836 value &= attr.getValue();
842 if (inputs.size() == 2)
843 if (
auto intAttr = dyn_cast_or_null<IntegerAttr>(inputs[1]))
844 if (intAttr.getValue().isAllOnes())
845 return getInputs()[0];
848 if (llvm::all_of(getInputs(),
849 [&](
auto in) {
return in == this->getInputs()[0]; }))
850 return getInputs()[0];
853 for (Value arg : getInputs()) {
856 for (Value arg2 : getInputs())
859 APInt::getZero(cast<IntegerType>(getType()).
getWidth()),
// For a 1-bit op, returns `source` when every input is an ExtractOp of the
// same `source` and the extracts together cover every bit of `source`. The
// number of inputs must equal source's width, so each bit appears exactly
// once. Otherwise returns a null Value.
880template <
typename Op>
882 if (!op.getType().isInteger(1))
885 auto inputs = op.getInputs();
886 size_t size = inputs.size();
888 auto sourceOp = inputs[0].template getDefiningOp<ExtractOp>();
891 Value source = sourceOp.getOperand();
894 if (size != source.getType().getIntOrFloatBitWidth())
// One bit per source position; set as each extract's low bit is seen.
898 llvm::BitVector bits(size);
899 bits.set(sourceOp.getLowBit());
901 for (
size_t i = 1; i != size; ++i) {
902 auto extractOp = inputs[i].template getDefiningOp<ExtractOp>();
903 if (!extractOp || extractOp.getOperand() != source)
905 bits.set(extractOp.getLowBit());
908 return bits.all() ? source : Value();
// Drops redundant inputs of `op` for an idempotent operation (presumably
// And/Or). An input is removed if either:
//  - it duplicates another input (SetVector dedup), or
//  - it also appears as an input of a nested op of the same kind, with the
//    same twoState and the same block, reachable within `limit` levels.
// `op` is rebuilt with the surviving inputs when any were removed.
915template <
typename Op>
918 constexpr unsigned limit = 3;
919 auto inputs = op.getInputs();
921 llvm::SmallSetVector<Value, 8> uniqueInputs(inputs.begin(), inputs.end());
922 llvm::SmallDenseSet<Op, 8> checked;
929 llvm::SmallVector<OpWithDepth, 8> worklist;
// Queue a same-kind, same-twoState op in the same block, visiting each op
// at most once and only while the depth is below `limit`.
931 auto enqueue = [&worklist, &checked, &op](Value input,
unsigned depth) {
935 if (depth < limit && input.getParentBlock() == op->getBlock()) {
936 auto inputOp = input.template getDefiningOp<Op>();
937 if (inputOp && inputOp.getTwoState() == op.getTwoState() &&
938 checked.insert(inputOp).second)
939 worklist.push_back({inputOp, depth + 1});
943 for (
auto input : uniqueInputs)
946 while (!worklist.empty()) {
947 auto item = worklist.pop_back_val();
949 for (
auto input : item.op.getInputs()) {
950 uniqueInputs.remove(input);
951 enqueue(input, item.depth);
955 if (uniqueInputs.size() < inputs.size()) {
956 replaceOpWithNewOpAndCopyNamehint<Op>(rewriter, op, op.getType(),
957 uniqueInputs.getArrayRef(),
// Canonicalizes AndOp (at least two inputs; fold() handles fewer). Visible
// rewrites:
//  - drop a trailing all-ones constant;
//  - merge two trailing constants into `cst`;
//  - and(replicate(i1 b), 1 << k) -> concat(zeros, b, zeros);
//  - a mask with leading/trailing zeros: AND only the masked bit range and
//    zero-pad the result;
//  - per-input handling of concat operands;
//  - finally, a replacement by `icmp eq source, cmpAgainst`.
965LogicalResult AndOp::canonicalize(
AndOp op, PatternRewriter &rewriter) {
969 auto inputs = op.getInputs();
970 auto size = inputs.size();
982 assert(size > 1 &&
"expected 2 or more operands, `fold` should handle this");
// Rewrites keyed on a trailing constant operand.
986 if (matchPattern(inputs.back(), m_ConstantInt(&value))) {
988 if (value.isAllOnes()) {
989 replaceOpWithNewOpAndCopyNamehint<AndOp>(rewriter, op, op.getType(),
990 inputs.drop_back(),
false);
998 if (matchPattern(inputs[size - 2], m_ConstantInt(&value2))) {
1000 SmallVector<Value, 4> newOperands(inputs.drop_back(2));
1001 newOperands.push_back(cst);
1002 replaceOpWithNewOpAndCopyNamehint<AndOp>(rewriter, op, op.getType(),
1003 newOperands,
false);
// Single-bit mask over a replicated i1: the result is that bit placed at
// position `trailingZeros`, with zeros around it.
1008 if (size == 2 && value.isPowerOf2()) {
1013 if (
auto replicate = inputs[0].getDefiningOp<ReplicateOp>()) {
1014 auto replicateOperand = replicate.getOperand();
1015 if (replicateOperand.getType().isInteger(1)) {
1016 unsigned resultWidth = op.getType().getIntOrFloatBitWidth();
1017 auto trailingZeros = value.countTrailingZeros();
1020 SmallVector<Value, 3> concatOperands;
1021 if (trailingZeros != resultWidth - 1) {
1023 rewriter, op.getLoc(),
1024 APInt::getZero(resultWidth - trailingZeros - 1));
1025 concatOperands.push_back(highZeros);
1027 concatOperands.push_back(replicateOperand);
1028 if (trailingZeros != 0) {
1030 rewriter, op.getLoc(), APInt::getZero(trailingZeros));
1031 concatOperands.push_back(lowZeros);
1033 replaceOpWithNewOpAndCopyNamehint<ConcatOp>(
1034 rewriter, op, op.getType(), concatOperands);
// Mask with known-zero high/low bits: compute the AND on the
// `maskLength`-bit window only.
1043 unsigned leadingZeros = value.countLeadingZeros();
1044 unsigned trailingZeros = value.countTrailingZeros();
1045 if (leadingZeros > 0 || trailingZeros > 0) {
1046 unsigned maskLength = value.getBitWidth() - leadingZeros - trailingZeros;
1049 SmallVector<Value> operands;
1050 for (
auto input : inputs.drop_back()) {
1051 unsigned offset = trailingZeros;
// Look through chains of ExtractOp so the new extract reads the
// underlying value directly at the accumulated offset.
1052 while (
auto extractOp = input.getDefiningOp<
ExtractOp>()) {
1053 input = extractOp.getInput();
1054 offset += extractOp.getLowBit();
1057 offset, maskLength));
// The narrowed mask is only needed if it is not all ones.
1061 auto narrowMask = value.extractBits(maskLength, trailingZeros);
1062 if (!narrowMask.isAllOnes())
1064 rewriter, inputs.back().getLoc(), narrowMask));
1067 Value narrowValue = operands.back();
1068 if (operands.size() > 1)
1070 AndOp::create(rewriter, op.getLoc(), operands, op.getTwoState());
// Re-pad the narrow result with zeros: high zeros first (MSB), then the
// value, then low zeros.
1074 if (leadingZeros > 0)
1076 rewriter, op.getLoc(), APInt::getZero(leadingZeros)));
1077 operands.push_back(narrowValue);
1078 if (trailingZeros > 0)
1080 rewriter, op.getLoc(), APInt::getZero(trailingZeros)));
1081 replaceOpWithNewOpAndCopyNamehint<ConcatOp>(rewriter, op, operands);
// Concat inputs (every input but the trailing one).
1088 for (
size_t i = 0; i < size - 1; ++i) {
1089 if (
auto concat = inputs[i].getDefiningOp<ConcatOp>())
1103 replaceOpWithNewOpAndCopyNamehint<ICmpOp>(rewriter, op, ICmpPredicate::eq,
1104 source, cmpAgainst);
// Folds OrOp:
//  - Constant inputs are ORed together starting from zero, with a check
//    once the accumulator becomes all-ones.
//  - A two-input OR whose second input is zero folds to the first input.
//  - An OR of identical inputs folds to that input.
//  - An input together with its complement (`x | ~x`) folds to all-ones.
1112OpFoldResult OrOp::fold(FoldAdaptor adaptor) {
1116 auto value = APInt::getZero(cast<IntegerType>(getType()).
getWidth());
1117 auto inputs = adaptor.getInputs();
1119 for (
auto operand : inputs) {
1120 auto attr = dyn_cast_or_null<IntegerAttr>(operand);
1123 value |= attr.getValue();
1124 if (value.isAllOnes())
1129 if (inputs.size() == 2)
1130 if (
auto intAttr = dyn_cast_or_null<IntegerAttr>(inputs[1]))
1131 if (intAttr.getValue().isZero())
1132 return getInputs()[0];
1135 if (llvm::all_of(getInputs(),
1136 [&](
auto in) {
return in == this->getInputs()[0]; }))
1137 return getInputs()[0];
// x | ~x -> all ones.
1140 for (Value arg : getInputs()) {
1142 if (matchPattern(arg,
m_Complement(m_Any(&subExpr)))) {
1143 for (Value arg2 : getInputs())
1144 if (arg2 == subExpr)
1146 APInt::getAllOnes(cast<IntegerType>(getType()).
getWidth()),
1156 APInt::getAllOnes(cast<IntegerType>(getType()).
getWidth()),
// Canonicalizes OrOp (at least two inputs). Visible rewrites:
//  - drop a trailing zero constant;
//  - merge two trailing constants into `cst`;
//  - per-input handling of concat operands;
//  - a replacement by `icmp ne source, cmpAgainst`;
//  - an OR of two-state muxes that share the same true and false values
//    (the false value being a constant) becomes one mux on the OR of their
//    conditions.
1163LogicalResult OrOp::canonicalize(
OrOp op, PatternRewriter &rewriter) {
1167 auto inputs = op.getInputs();
1168 auto size = inputs.size();
1180 assert(size > 1 &&
"expected 2 or more operands");
1184 if (matchPattern(inputs.back(), m_ConstantInt(&value))) {
1186 if (value.isZero()) {
1187 replaceOpWithNewOpAndCopyNamehint<OrOp>(rewriter, op, op.getType(),
1188 inputs.drop_back());
1194 if (matchPattern(inputs[size - 2], m_ConstantInt(&value2))) {
1196 SmallVector<Value, 4> newOperands(inputs.drop_back(2));
1197 newOperands.push_back(cst);
1198 replaceOpWithNewOpAndCopyNamehint<OrOp>(rewriter, op, op.getType(),
1206 for (
size_t i = 0; i < size - 1; ++i) {
1207 if (
auto concat = inputs[i].getDefiningOp<ConcatOp>())
1221 replaceOpWithNewOpAndCopyNamehint<ICmpOp>(rewriter, op, ICmpPredicate::ne,
1222 source, cmpAgainst);
// or(mux(c1, t, f), mux(c2, t, f), ...) -> mux(or(c1, c2, ...), t, f),
// restricted to two-state ops. `check` collects each mux condition.
1228 if (
auto firstMux = op.getOperand(0).getDefiningOp<
comb::MuxOp>()) {
1230 if (op.getTwoState() && firstMux.getTwoState() &&
1231 matchPattern(firstMux.getFalseValue(), m_ConstantInt(&value)) &&
1233 SmallVector<Value> conditions{firstMux.getCond()};
1234 auto check = [&](Value v) {
1238 conditions.push_back(mux.getCond());
1239 return mux.getTwoState() &&
1240 firstMux.getTrueValue() == mux.getTrueValue() &&
1241 firstMux.getFalseValue() == mux.getFalseValue();
1243 if (llvm::all_of(op.getOperands().drop_front(), check)) {
1244 auto cond = comb::OrOp::create(rewriter, op.getLoc(), conditions,
true);
1245 replaceOpWithNewOpAndCopyNamehint<comb::MuxOp>(
1246 rewriter, op, cond, firstMux.getTrueValue(),
1247 firstMux.getFalseValue(),
true);
// Folds XorOp:
//  - `x ^ x` with exactly two inputs folds to zero.
//  - A two-input XOR whose second input is zero folds to the first input.
//  - A binary NOT of a complement (`~~x`) is recognized, excluding the case
//    where the matched value is this op's own result.
//  - An earlier return also yields the first input (condition presumably
//    the single-input case).
1257OpFoldResult XorOp::fold(FoldAdaptor adaptor) {
1261 auto size = getInputs().size();
1262 auto inputs = adaptor.getInputs();
1266 return getInputs()[0];
1269 if (size == 2 && getInputs()[0] == getInputs()[1])
1270 return IntegerAttr::get(getType(), 0);
1273 if (inputs.size() == 2)
1274 if (
auto intAttr = dyn_cast_or_null<IntegerAttr>(inputs[1]))
1275 if (intAttr.getValue().isZero())
1276 return getInputs()[0];
1280 if (isBinaryNot()) {
1282 if (matchPattern(getOperand(0),
m_Complement(m_Any(&subExpr))) &&
1283 subExpr != getResult())
// Replaces the ICmpOp at operand `icmpOperand` of `op` with a compare that
// uses the negated predicate (same operands and twoState flag). When `op`
// has more than two operands, a new XorOp is built from:
//  - the original operands minus the last one and minus the compare, plus
//  - the negated compare.
// NOTE(review): pop_back() runs before erase(), so `icmpOperand` must not be
// the last operand index.
1293 PatternRewriter &rewriter) {
1294 auto icmp = op.getOperand(icmpOperand).getDefiningOp<ICmpOp>();
1295 auto negatedPred = ICmpOp::getNegatedPredicate(icmp.getPredicate());
1298 ICmpOp::create(rewriter, icmp.getLoc(), negatedPred, icmp.getOperand(0),
1299 icmp.getOperand(1), icmp.getTwoState());
1302 if (op.getNumOperands() > 2) {
1303 SmallVector<Value, 4> newOperands(op.getOperands());
1304 newOperands.pop_back();
1305 newOperands.erase(newOperands.begin() + icmpOperand);
1306 newOperands.push_back(result);
1308 XorOp::create(rewriter, op.getLoc(), newOperands, op.getTwoState());
// Canonicalizes XorOp (at least two inputs). Visible rewrites:
//  - two identical trailing inputs cancel;
//  - drop a trailing zero constant;
//  - merge two trailing constants into `cst`;
//  - with a 1-bit trailing constant (which must be one), single-use
//    operands, including ICmpOps, get special handling;
//  - a final replacement by ParityOp(source).
1314LogicalResult XorOp::canonicalize(
XorOp op, PatternRewriter &rewriter) {
1318 auto inputs = op.getInputs();
1319 auto size = inputs.size();
1320 assert(size > 1 &&
"expected 2 or more operands");
// x ^ x cancels out.
1323 if (inputs[size - 1] == inputs[size - 2]) {
1325 "expected idempotent case for 2 elements handled already.");
1326 replaceOpWithNewOpAndCopyNamehint<XorOp>(rewriter, op, op.getType(),
1327 inputs.drop_back(2),
false);
1333 if (matchPattern(inputs.back(), m_ConstantInt(&value))) {
1335 if (value.isZero()) {
1336 replaceOpWithNewOpAndCopyNamehint<XorOp>(rewriter, op, op.getType(),
1337 inputs.drop_back(),
false);
1343 if (matchPattern(inputs[size - 2], m_ConstantInt(&value2))) {
1345 SmallVector<Value, 4> newOperands(inputs.drop_back(2));
1346 newOperands.push_back(cst);
1347 replaceOpWithNewOpAndCopyNamehint<XorOp>(rewriter, op, op.getType(),
1348 newOperands,
false);
1352 bool isSingleBit = value.getBitWidth() == 1;
1355 for (
size_t i = 0; i < size - 1; ++i) {
1356 Value operand = inputs[i];
1367 if (isSingleBit && operand.hasOneUse()) {
1368 assert(value == 1 &&
"single bit constant has to be one if not zero");
1369 if (
auto icmp = operand.getDefiningOp<ICmpOp>())
1385 replaceOpWithNewOpAndCopyNamehint<ParityOp>(rewriter, op, source);
// Folds SubOp:
//  - `x - x` folds to zero.
//  - With two constant operand attributes, it folds to the parameter
//    expression lhs + rhs * (-1), using an all-ones (i.e. -1) multiplier.
//  - A right-hand IntegerAttr equal to zero is also recognized.
1392OpFoldResult SubOp::fold(FoldAdaptor adaptor) {
1397 if (getRhs() == getLhs())
1399 APInt::getZero(getLhs().getType().getIntOrFloatBitWidth()),
1402 if (adaptor.getRhs()) {
1404 if (adaptor.getLhs()) {
1407 APInt::getAllOnes(getLhs().getType().getIntOrFloatBitWidth()),
1409 auto rhsNeg = hw::ParamExprAttr::get(
1410 hw::PEO::Mul, cast<TypedAttr>(adaptor.getRhs()), negOne);
1411 return hw::ParamExprAttr::get(hw::PEO::Add,
1412 cast<TypedAttr>(adaptor.getLhs()), rhsNeg);
1416 if (
auto rhsC = dyn_cast<IntegerAttr>(adaptor.getRhs())) {
1417 if (rhsC.getValue().isZero())
// Rewrites `x - c` for a constant c as an AddOp of x and `negCst`
// (presumably the negation of c).
1425LogicalResult SubOp::canonicalize(
SubOp op, PatternRewriter &rewriter) {
1431 if (matchPattern(op.getRhs(), m_ConstantInt(&value))) {
1433 replaceOpWithNewOpAndCopyNamehint<AddOp>(rewriter, op, op.getLhs(), negCst,
// Folds AddOp. The visible path returns the first input (presumably when
// there is a single input).
1445OpFoldResult AddOp::fold(FoldAdaptor adaptor) {
1449 auto size = getInputs().size();
1453 return getInputs()[0];
// Canonicalizes AddOp (at least two inputs). Each rewrite acts on the last
// two inputs:
//  - drop a trailing zero;
//  - merge two trailing constants;
//  - x + x -> x << 1;
//  - x + (x << c) -> x * rhs, with rhs built from a constant one and c;
//  - x + x * c -> x * rhs;
//  - (a + c1) + c2 -> a + rhs; two-state only if both adds are two-state.
1459LogicalResult AddOp::canonicalize(
AddOp op, PatternRewriter &rewriter) {
1463 auto inputs = op.getInputs();
1464 auto size = inputs.size();
1465 assert(size > 1 &&
"expected 2 or more operands");
1467 APInt value, value2;
1470 if (matchPattern(inputs.back(), m_ConstantInt(&value)) && value.isZero()) {
1471 replaceOpWithNewOpAndCopyNamehint<AddOp>(rewriter, op, op.getType(),
1472 inputs.drop_back(),
false);
1477 if (matchPattern(inputs[size - 1], m_ConstantInt(&value)) &&
1478 matchPattern(inputs[size - 2], m_ConstantInt(&value2))) {
1480 SmallVector<Value, 4> newOperands(inputs.drop_back(2));
1481 newOperands.push_back(cst);
1482 replaceOpWithNewOpAndCopyNamehint<AddOp>(rewriter, op, op.getType(),
1483 newOperands,
false);
// x + x -> x << 1.
1488 if (inputs[size - 1] == inputs[size - 2]) {
1489 SmallVector<Value, 4> newOperands(inputs.drop_back(2));
1493 comb::ShlOp::create(rewriter, op.getLoc(), inputs.back(), one,
false);
1495 newOperands.push_back(shiftLeftOp);
1496 replaceOpWithNewOpAndCopyNamehint<AddOp>(rewriter, op, op.getType(),
1497 newOperands,
false);
// x + (x << c) -> x * rhs.
1501 auto shlOp = inputs[size - 1].getDefiningOp<
comb::ShlOp>();
1503 if (shlOp && shlOp.getLhs() == inputs[size - 2] &&
1504 matchPattern(shlOp.getRhs(), m_ConstantInt(&value))) {
1506 APInt one(value.getBitWidth(), 1,
false);
1510 std::array<Value, 2> factors = {shlOp.getLhs(), rhs};
1511 auto mulOp = comb::MulOp::create(rewriter, op.getLoc(), factors,
false);
1513 SmallVector<Value, 4> newOperands(inputs.drop_back(2));
1514 newOperands.push_back(mulOp);
1515 replaceOpWithNewOpAndCopyNamehint<AddOp>(rewriter, op, op.getType(),
1516 newOperands,
false);
// x + x * c -> x * rhs.
1520 auto mulOp = inputs[size - 1].getDefiningOp<
comb::MulOp>();
1522 if (mulOp && mulOp.getInputs().size() == 2 &&
1523 mulOp.getInputs()[0] == inputs[size - 2] &&
1524 matchPattern(mulOp.getInputs()[1], m_ConstantInt(&value))) {
1526 APInt one(value.getBitWidth(), 1,
false);
1528 std::array<Value, 2> factors = {mulOp.getInputs()[0], rhs};
1529 auto newMulOp = comb::MulOp::create(rewriter, op.getLoc(), factors,
false);
1531 SmallVector<Value, 4> newOperands(inputs.drop_back(2));
1532 newOperands.push_back(newMulOp);
1533 replaceOpWithNewOpAndCopyNamehint<AddOp>(rewriter, op, op.getType(),
1534 newOperands,
false);
// (a + c1) + c2 -> a + rhs, two-state only if both adds are two-state.
1547 auto addOp = inputs[0].getDefiningOp<
comb::AddOp>();
1548 if (addOp && addOp.getInputs().size() == 2 &&
1549 matchPattern(addOp.getInputs()[1], m_ConstantInt(&value2)) &&
1550 inputs.size() == 2 && matchPattern(inputs[1], m_ConstantInt(&value))) {
1553 replaceOpWithNewOpAndCopyNamehint<AddOp>(
1554 rewriter, op, op.getType(), ArrayRef<Value>{addOp.getInputs()[0], rhs},
1555 op.getTwoState() && addOp.getTwoState());
// Folds MulOp:
//  - The first input is returned on one path (presumably a single input).
//  - A zero-width result folds to a zero-width constant.
//  - Constant inputs are multiplied together, starting from one.
1562OpFoldResult MulOp::fold(FoldAdaptor adaptor) {
1566 auto size = getInputs().size();
1567 auto inputs = adaptor.getInputs();
1571 return getInputs()[0];
1573 auto width = cast<IntegerType>(getType()).getWidth();
1575 return getIntAttr(APInt::getZero(0), getContext());
1577 APInt value(width, 1,
false);
1580 for (
auto operand : inputs) {
1581 auto attr = dyn_cast_or_null<IntegerAttr>(operand);
1584 value *= attr.getValue();
// Canonicalizes MulOp (at least two inputs). Visible rewrites:
//  - x * 2^k (two inputs) -> x << k, wrapped in a single-operand MulOp;
//  - drop a trailing constant one;
//  - merge two trailing constants into `cst`.
1593LogicalResult MulOp::canonicalize(
MulOp op, PatternRewriter &rewriter) {
1597 auto inputs = op.getInputs();
1598 auto size = inputs.size();
1599 assert(size > 1 &&
"expected 2 or more operands");
1601 APInt value, value2;
1604 if (size == 2 && matchPattern(inputs.back(), m_ConstantInt(&value)) &&
1605 value.isPowerOf2()) {
1607 value.exactLogBase2());
1609 comb::ShlOp::create(rewriter, op.getLoc(), inputs[0], shift,
false);
1611 replaceOpWithNewOpAndCopyNamehint<MulOp>(rewriter, op, op.getType(),
1612 ArrayRef<Value>(shlOp),
false);
1617 if (matchPattern(inputs.back(), m_ConstantInt(&value)) && value.isOne()) {
1618 replaceOpWithNewOpAndCopyNamehint<MulOp>(rewriter, op, op.getType(),
1619 inputs.drop_back());
1624 if (matchPattern(inputs[size - 1], m_ConstantInt(&value)) &&
1625 matchPattern(inputs[size - 2], m_ConstantInt(&value2))) {
1627 SmallVector<Value, 4> newOperands(inputs.drop_back(2));
1628 newOperands.push_back(cst);
1629 replaceOpWithNewOpAndCopyNamehint<MulOp>(rewriter, op, op.getType(),
// Shared folder for signed/unsigned division. It inspects a constant
// divisor (`constants[1]`) equal to one or to zero; the results for those
// cases are presumably the dividend and no fold, respectively.
1645template <
class Op,
bool isSigned>
1646static OpFoldResult
foldDiv(Op op, ArrayRef<Attribute> constants) {
1647 if (
auto rhsValue = dyn_cast_or_null<IntegerAttr>(constants[1])) {
1649 if (rhsValue.getValue() == 1)
1653 if (rhsValue.getValue().isZero())
// Folds unsigned division through the shared foldDiv helper
// (isSigned = false).
1660OpFoldResult DivUOp::fold(FoldAdaptor adaptor) {
1663 return foldDiv<
DivUOp,
false>(*
this, adaptor.getOperands());
// Folder hook for signed division (DivSOp).
1666OpFoldResult DivSOp::fold(FoldAdaptor adaptor) {
// Shared folder for signed/unsigned modulus:
//  - `x mod 1` folds to zero.
//  - A zero divisor is checked separately (presumably not folded).
//  - A zero dividend folds to zero.
1672template <
class Op,
bool isSigned>
1673static OpFoldResult
foldMod(Op op, ArrayRef<Attribute> constants) {
1674 if (
auto rhsValue = dyn_cast_or_null<IntegerAttr>(constants[1])) {
1676 if (rhsValue.getValue() == 1)
1677 return getIntAttr(APInt::getZero(op.getType().getIntOrFloatBitWidth()),
1681 if (rhsValue.getValue().isZero())
1685 if (
auto lhsValue = dyn_cast_or_null<IntegerAttr>(constants[0])) {
1687 if (lhsValue.getValue().isZero())
1688 return getIntAttr(APInt::getZero(op.getType().getIntOrFloatBitWidth()),
// Folds unsigned modulus through the shared foldMod helper
// (isSigned = false).
1695OpFoldResult ModUOp::fold(FoldAdaptor adaptor) {
1698 return foldMod<
ModUOp,
false>(*
this, adaptor.getOperands());
// Folder hook for signed modulus (ModSOp).
1701OpFoldResult ModSOp::fold(FoldAdaptor adaptor) {
// Canonicalization hook for unsigned division (DivUOp).
1707LogicalResult DivUOp::canonicalize(
DivUOp op, PatternRewriter &rewriter) {
// Canonicalization hook for unsigned modulus (ModUOp).
1713LogicalResult ModUOp::canonicalize(
ModUOp op, PatternRewriter &rewriter) {
// Folds ConcatOp:
//  - A single-operand concat folds to that operand.
//  - When every input is an IntegerAttr, the inputs are packed MSB-first
//    into one APInt of the result width: the first input occupies the
//    highest bits.
1725OpFoldResult ConcatOp::fold(FoldAdaptor adaptor) {
1729 if (getNumOperands() == 1)
1730 return getOperand(0);
1733 for (
auto attr : adaptor.getInputs())
1734 if (!attr || !isa<IntegerAttr>(attr))
1738 unsigned resultWidth = getType().getIntOrFloatBitWidth();
1739 APInt result(resultWidth, 0);
// Insertion point moves down from the top bit as each chunk is placed.
1741 unsigned nextInsertion = resultWidth;
1743 for (
auto attr : adaptor.getInputs()) {
1744 auto chunk = cast<IntegerAttr>(attr).getValue();
1745 nextInsertion -= chunk.getBitWidth();
1746 result.insertBits(chunk, nextInsertion);
1752LogicalResult ConcatOp::canonicalize(
ConcatOp op, PatternRewriter &rewriter) {
1756 auto inputs = op.getInputs();
1757 auto size = inputs.size();
1758 assert(size > 1 &&
"expected 2 or more operands");
1763 auto flattenConcat = [&](
size_t firstOpIndex,
size_t lastOpIndex,
1764 ValueRange replacements) -> LogicalResult {
1765 SmallVector<Value, 4> newOperands;
1766 newOperands.append(inputs.begin(), inputs.begin() + firstOpIndex);
1767 newOperands.append(replacements.begin(), replacements.end());
1768 newOperands.append(inputs.begin() + lastOpIndex + 1, inputs.end());
1769 if (newOperands.size() == 1)
1772 replaceOpWithNewOpAndCopyNamehint<ConcatOp>(rewriter, op, op.getType(),
1777 Value commonOperand = inputs[0];
1778 for (
size_t i = 0; i != size; ++i) {
1780 if (inputs[i] != commonOperand)
1781 commonOperand = Value();
1785 if (
auto subConcat = inputs[i].getDefiningOp<ConcatOp>())
1786 return flattenConcat(i, i, subConcat->getOperands());
1791 if (
auto cst = inputs[i].getDefiningOp<hw::ConstantOp>()) {
1792 if (
auto prevCst = inputs[i - 1].getDefiningOp<hw::ConstantOp>()) {
1793 unsigned prevWidth = prevCst.getValue().getBitWidth();
1794 unsigned thisWidth = cst.getValue().getBitWidth();
1795 auto resultCst = cst.getValue().zext(prevWidth + thisWidth);
1796 resultCst |= prevCst.getValue().zext(prevWidth + thisWidth)
1800 return flattenConcat(i - 1, i, replacement);
1805 if (inputs[i] == inputs[i - 1]) {
1807 rewriter.createOrFold<ReplicateOp>(op.getLoc(), inputs[i], 2);
1808 return flattenConcat(i - 1, i, replacement);
1813 if (
auto repl = inputs[i].getDefiningOp<ReplicateOp>()) {
1815 if (repl.getOperand() == inputs[i - 1]) {
1816 Value replacement = rewriter.createOrFold<ReplicateOp>(
1817 op.getLoc(), repl.getOperand(), repl.getMultiple() + 1);
1818 return flattenConcat(i - 1, i, replacement);
1821 if (
auto prevRepl = inputs[i - 1].getDefiningOp<ReplicateOp>()) {
1822 if (prevRepl.getOperand() == repl.getOperand()) {
1823 Value replacement = rewriter.createOrFold<ReplicateOp>(
1824 op.getLoc(), repl.getOperand(),
1825 repl.getMultiple() + prevRepl.getMultiple());
1826 return flattenConcat(i - 1, i, replacement);
1832 if (
auto repl = inputs[i - 1].getDefiningOp<ReplicateOp>()) {
1833 if (repl.getOperand() == inputs[i]) {
1834 Value replacement = rewriter.createOrFold<ReplicateOp>(
1835 op.getLoc(), inputs[i], repl.getMultiple() + 1);
1836 return flattenConcat(i - 1, i, replacement);
1842 if (
auto extract = inputs[i].getDefiningOp<ExtractOp>()) {
1843 if (
auto prevExtract = inputs[i - 1].getDefiningOp<ExtractOp>()) {
1844 if (extract.getInput() == prevExtract.getInput()) {
1845 auto thisWidth = cast<IntegerType>(extract.getType()).getWidth();
1846 if (prevExtract.getLowBit() == extract.getLowBit() + thisWidth) {
1847 auto prevWidth = prevExtract.getType().getIntOrFloatBitWidth();
1848 auto resType = rewriter.getIntegerType(thisWidth + prevWidth);
1851 extract.getInput(), extract.getLowBit());
1852 return flattenConcat(i - 1, i, replacement);
1865 static std::optional<ArraySlice>
get(Value value) {
1866 assert(isa<IntegerType>(value.getType()) &&
"expected integer type");
1868 return ArraySlice{arrayGet.getInput(), arrayGet.getIndex(), 1};
1871 if (
auto arraySlice =
1874 arraySlice.getInput(), arraySlice.getLowIndex(),
1875 hw::type_cast<hw::ArrayType>(arraySlice.getType())
1877 return std::nullopt;
1880 if (
auto extractOpt = ArraySlice::get(inputs[i])) {
1881 if (
auto prevExtractOpt = ArraySlice::get(inputs[i - 1])) {
1883 if (prevExtractOpt->index.getType() == extractOpt->index.getType() &&
1884 prevExtractOpt->input == extractOpt->input &&
1885 hw::isOffset(extractOpt->index, prevExtractOpt->index,
1886 extractOpt->width)) {
1887 auto resType = hw::ArrayType::get(
1888 hw::type_cast<hw::ArrayType>(prevExtractOpt->input.getType())
1890 extractOpt->width + prevExtractOpt->width);
1891 auto resIntType = rewriter.getIntegerType(hw::getBitWidth(resType));
1893 rewriter, op.getLoc(), resIntType,
1895 prevExtractOpt->input,
1896 extractOpt->index));
1897 return flattenConcat(i - 1, i, replacement);
1905 if (commonOperand) {
1906 replaceOpWithNewOpAndCopyNamehint<ReplicateOp>(rewriter, op, op.getType(),
1918OpFoldResult MuxOp::fold(FoldAdaptor adaptor) {
1923 if (getTrueValue() == getFalseValue() && getTrueValue() != getResult())
1924 return getTrueValue();
1925 if (
auto tv = adaptor.getTrueValue())
1926 if (tv == adaptor.getFalseValue())
1931 if (
auto pred = dyn_cast_or_null<IntegerAttr>(adaptor.getCond())) {
1932 if (pred.getValue().isZero() && getFalseValue() != getResult())
1933 return getFalseValue();
1934 if (pred.getValue().isOne() && getTrueValue() != getResult())
1935 return getTrueValue();
1939 if (getCond().getType() == getTrueValue().getType())
1940 if (
auto tv = dyn_cast_or_null<IntegerAttr>(adaptor.getTrueValue()))
1941 if (
auto fv = dyn_cast_or_null<IntegerAttr>(adaptor.getFalseValue()))
1942 if (tv.getValue().isOne() && fv.getValue().isZero() &&
1943 hw::getBitWidth(getType()) == 1 && getCond() != getResult())
1959 if (
auto cmp = cond.getDefiningOp<ICmpOp>()) {
1961 auto requiredPredicate =
1962 (isInverted ? ICmpPredicate::eq : ICmpPredicate::ne);
1963 if (cmp.getLhs() == indexValue && cmp.getPredicate() == requiredPredicate) {
1973 if (
auto orOp = cond.getDefiningOp<
OrOp>()) {
1976 for (
auto operand : orOp.getOperands())
1983 if (
auto andOp = cond.getDefiningOp<
AndOp>()) {
1986 for (
auto operand : andOp.getOperands())
2005 PatternRewriter &rewriter,
MuxOp rootMux,
bool isFalseSide,
2011 auto rootCmp = rootMux.getCond().getDefiningOp<ICmpOp>();
2014 Value indexValue = rootCmp.getLhs();
2017 auto getCaseValue = [&](
MuxOp mux) -> Value {
2018 return mux.getOperand(1 +
unsigned(!isFalseSide));
2023 auto getTreeValue = [&](
MuxOp mux) -> Value {
2024 return mux.getOperand(1 +
unsigned(isFalseSide));
2029 SmallVector<Location> locationsFound;
2030 SmallVector<std::pair<hw::ConstantOp, Value>, 4> valuesFound;
2034 auto collectConstantValues = [&](
MuxOp mux) ->
bool {
2036 mux.getCond(), indexValue, isFalseSide, [&](
hw::ConstantOp cst) {
2037 valuesFound.push_back({cst, getCaseValue(mux)});
2038 locationsFound.push_back(mux.getCond().getLoc());
2039 locationsFound.push_back(mux->getLoc());
2044 if (!collectConstantValues(rootMux))
2048 if (rootMux->hasOneUse()) {
2049 if (
auto userMux = dyn_cast<MuxOp>(*rootMux->user_begin())) {
2050 if (getTreeValue(userMux) == rootMux.getResult() &&
2058 auto nextTreeValue = getTreeValue(rootMux);
2060 auto nextMux = nextTreeValue.getDefiningOp<
MuxOp>();
2061 if (!nextMux || !nextMux->hasOneUse())
2063 if (!collectConstantValues(nextMux))
2065 nextTreeValue = getTreeValue(nextMux);
2068 auto indexWidth = cast<IntegerType>(indexValue.getType()).getWidth();
2070 if (indexWidth > 20)
2073 auto foldingStyle = styleFn(indexWidth, valuesFound.size());
2077 uint64_t tableSize = 1ULL << indexWidth;
2081 SmallVector<Value, 8> table(tableSize, nextTreeValue);
2086 for (
auto &elt :
llvm::reverse(valuesFound)) {
2087 uint64_t idx = elt.first.getValue().getZExtValue();
2088 assert(idx < table.size() &&
"constant should be same bitwidth as index");
2089 table[idx] = elt.second;
2093 SmallVector<Value> bits;
2102 "unknown folding style");
2106 std::reverse(table.begin(), table.end());
2109 auto fusedLoc = rewriter.getFusedLoc(locationsFound);
2111 replaceOpWithNewOpAndCopyNamehint<hw::ArrayGetOp>(rewriter, rootMux, array,
2126 PatternRewriter &rewriter) {
2127 assert(fullyAssoc->getNumOperands() >= 2 &&
"cannot split up unary ops");
2128 assert(operandNo < fullyAssoc->getNumOperands() &&
"Invalid operand #");
2132 if (fullyAssoc->getNumOperands() == 2)
2133 return fullyAssoc->getOperand(operandNo ^ 1);
2136 if (fullyAssoc->hasOneUse()) {
2137 rewriter.modifyOpInPlace(fullyAssoc,
2138 [&]() { fullyAssoc->eraseOperand(operandNo); });
2139 return fullyAssoc->getResult(0);
2143 SmallVector<Value> operands;
2144 operands.append(fullyAssoc->getOperands().begin(),
2145 fullyAssoc->getOperands().begin() + operandNo);
2146 operands.append(fullyAssoc->getOperands().begin() + operandNo + 1,
2147 fullyAssoc->getOperands().end());
2149 fullyAssoc->getLoc(), fullyAssoc->getName(), operands, rewriter);
2150 Value excluded = fullyAssoc->getOperand(operandNo);
2154 ArrayRef<Value>{opWithoutExcluded, excluded}, rewriter);
2156 return opWithoutExcluded;
2166 PatternRewriter &rewriter) {
2169 Operation *subExpr =
2170 (isTrueOperand ? op.getFalseValue() : op.getTrueValue()).getDefiningOp();
2171 if (!subExpr || subExpr->getNumOperands() < 2)
2175 if (!isa<AndOp, XorOp, OrOp, MuxOp>(subExpr))
2180 Value commonValue = isTrueOperand ? op.getTrueValue() : op.getFalseValue();
2181 size_t opNo = 0, e = subExpr->getNumOperands();
2182 while (opNo != e && subExpr->getOperand(opNo) != commonValue)
2188 Value cond = op.getCond();
2194 if (
auto subMux = dyn_cast<MuxOp>(subExpr)) {
2199 Value subCond = subMux.getCond();
2202 if (subMux.getTrueValue() == commonValue)
2203 otherValue = subMux.getFalseValue();
2204 else if (subMux.getFalseValue() == commonValue) {
2205 otherValue = subMux.getTrueValue();
2215 cond = rewriter.createOrFold<
OrOp>(op.getLoc(), cond, subCond,
false);
2216 replaceOpWithNewOpAndCopyNamehint<MuxOp>(rewriter, op, cond, commonValue,
2217 otherValue, op.getTwoState());
2223 bool isaAndOp = isa<AndOp>(subExpr);
2224 if (isTrueOperand ^ isaAndOp)
2228 rewriter.createOrFold<ReplicateOp>(op.getLoc(), op.getType(), cond);
2231 bool isaXorOp = isa<XorOp>(subExpr);
2232 bool isaOrOp = isa<OrOp>(subExpr);
2241 if (isaOrOp || isaXorOp) {
2242 auto masked = rewriter.createOrFold<
AndOp>(op.getLoc(), extendedCond,
2243 restOfAssoc,
false);
2245 replaceOpWithNewOpAndCopyNamehint<XorOp>(rewriter, op, masked,
2246 commonValue,
false);
2248 replaceOpWithNewOpAndCopyNamehint<OrOp>(rewriter, op, masked, commonValue,
2254 assert(isaAndOp &&
"unexpected operation here");
2255 auto masked = rewriter.createOrFold<
OrOp>(op.getLoc(), extendedCond,
2256 restOfAssoc,
false);
2257 replaceOpWithNewOpAndCopyNamehint<AndOp>(rewriter, op, masked, commonValue,
2268 PatternRewriter &rewriter) {
2271 if (!isa<ConcatOp>(trueOp))
2275 SmallVector<Value> trueOperands, falseOperands;
2279 size_t numTrueOperands = trueOperands.size();
2280 size_t numFalseOperands = falseOperands.size();
2282 if (!numTrueOperands || !numFalseOperands ||
2283 (trueOperands.front() != falseOperands.front() &&
2284 trueOperands.back() != falseOperands.back()))
2288 if (trueOperands.front() == falseOperands.front()) {
2289 SmallVector<Value> operands;
2291 for (i = 0; i < numTrueOperands; ++i) {
2292 Value trueOperand = trueOperands[i];
2293 if (trueOperand == falseOperands[i])
2294 operands.push_back(trueOperand);
2298 if (i == numTrueOperands) {
2305 if (llvm::all_of(operands, [&](Value v) {
return v == operands.front(); }))
2306 sharedMSB = rewriter.createOrFold<ReplicateOp>(
2307 mux->getLoc(), operands.front(), operands.size());
2309 sharedMSB = rewriter.createOrFold<
ConcatOp>(mux->getLoc(), operands);
2313 operands.append(trueOperands.begin() + i, trueOperands.end());
2314 Value trueLSB = rewriter.createOrFold<
ConcatOp>(trueOp->getLoc(), operands);
2316 operands.append(falseOperands.begin() + i, falseOperands.end());
2318 rewriter.createOrFold<
ConcatOp>(falseOp->getLoc(), operands);
2321 Value lsb = rewriter.createOrFold<
MuxOp>(
2322 mux->getLoc(), mux.getCond(), trueLSB, falseLSB, mux.getTwoState());
2323 replaceOpWithNewOpAndCopyNamehint<ConcatOp>(rewriter, mux, sharedMSB, lsb);
2328 if (trueOperands.back() == falseOperands.back()) {
2329 SmallVector<Value> operands;
2332 Value trueOperand = trueOperands[numTrueOperands - i - 1];
2333 if (trueOperand == falseOperands[numFalseOperands - i - 1])
2334 operands.push_back(trueOperand);
2338 std::reverse(operands.begin(), operands.end());
2339 Value sharedLSB = rewriter.createOrFold<
ConcatOp>(mux->getLoc(), operands);
2343 operands.append(trueOperands.begin(), trueOperands.end() - i);
2344 Value trueMSB = rewriter.createOrFold<
ConcatOp>(trueOp->getLoc(), operands);
2346 operands.append(falseOperands.begin(), falseOperands.end() - i);
2348 rewriter.createOrFold<
ConcatOp>(falseOp->getLoc(), operands);
2350 Value msb = rewriter.createOrFold<
MuxOp>(
2351 mux->getLoc(), mux.getCond(), trueMSB, falseMSB, mux.getTwoState());
2352 replaceOpWithNewOpAndCopyNamehint<ConcatOp>(rewriter, mux, msb, sharedLSB);
2364 if (!trueVec || !falseVec)
2366 if (!trueVec.isUniform() || !falseVec.isUniform())
2369 auto mux = MuxOp::create(rewriter, op.getLoc(), op.getCond(),
2370 trueVec.getUniformElement(),
2371 falseVec.getUniformElement(), op.getTwoState());
2373 SmallVector<Value> values(trueVec.getInputs().size(), mux);
2381 bool constCond, PatternRewriter &rewriter) {
2382 if (!muxValue.hasOneUse())
2384 auto *op = muxValue.getDefiningOp();
2385 if (!op || !isa_and_nonnull<CombDialect>(op->getDialect()))
2387 if (!llvm::is_contained(op->getOperands(), muxCond))
2389 OpBuilder::InsertionGuard guard(rewriter);
2390 rewriter.setInsertionPoint(op);
2393 rewriter.modifyOpInPlace(op, [&] {
2394 for (
auto &use : op->getOpOperands())
2395 if (use.get() == muxCond)
2403 using OpRewritePattern::OpRewritePattern;
2405 LogicalResult matchAndRewrite(
MuxOp op,
2406 PatternRewriter &rewriter)
const override;
2410foldToArrayCreateOnlyWhenDense(
size_t indexWidth,
size_t numEntries) {
2413 if (indexWidth >= 9 || numEntries < 3)
2419 uint64_t tableSize = 1ULL << indexWidth;
2420 if (numEntries >= tableSize * 5 / 8)
2425LogicalResult MuxRewriter::matchAndRewrite(
MuxOp op,
2426 PatternRewriter &rewriter)
const {
2430 bool isSignlessInt =
false;
2431 if (
auto intType = dyn_cast<IntegerType>(op.getType()))
2432 isSignlessInt = intType.isSignless();
2439 if (matchPattern(op.getTrueValue(), m_ConstantInt(&value)) && isSignlessInt) {
2440 if (value.getBitWidth() == 1) {
2442 if (value.isZero()) {
2444 replaceOpWithNewOpAndCopyNamehint<AndOp>(rewriter, op, notCond,
2445 op.getFalseValue(),
false);
2450 replaceOpWithNewOpAndCopyNamehint<OrOp>(rewriter, op, op.getCond(),
2451 op.getFalseValue(),
false);
2457 if (matchPattern(op.getFalseValue(), m_ConstantInt(&value2))) {
2462 APInt xorValue = value ^ value2;
2463 if (xorValue.isPowerOf2()) {
2464 unsigned leadingZeros = xorValue.countLeadingZeros();
2465 unsigned trailingZeros = value.getBitWidth() - leadingZeros - 1;
2466 SmallVector<Value, 3> operands;
2474 if (leadingZeros > 0)
2475 operands.push_back(rewriter.createOrFold<
ExtractOp>(
2476 op.getLoc(), op.getTrueValue(), trailingZeros + 1, leadingZeros));
2480 auto v1 = rewriter.createOrFold<
ExtractOp>(
2481 op.getLoc(), op.getTrueValue(), trailingZeros, 1);
2482 auto v2 = rewriter.createOrFold<
ExtractOp>(
2483 op.getLoc(), op.getFalseValue(), trailingZeros, 1);
2484 operands.push_back(rewriter.createOrFold<
MuxOp>(
2485 op.getLoc(), op.getCond(), v1, v2,
false));
2487 if (trailingZeros > 0)
2488 operands.push_back(rewriter.createOrFold<
ExtractOp>(
2489 op.getLoc(), op.getTrueValue(), 0, trailingZeros));
2491 replaceOpWithNewOpAndCopyNamehint<ConcatOp>(rewriter, op, op.getType(),
2498 if (value.isAllOnes() && value2.isZero()) {
2499 replaceOpWithNewOpAndCopyNamehint<ReplicateOp>(
2500 rewriter, op, op.getType(), op.getCond());
2506 if (matchPattern(op.getFalseValue(), m_ConstantInt(&value)) &&
2507 isSignlessInt && value.getBitWidth() == 1) {
2509 if (value.isZero()) {
2510 replaceOpWithNewOpAndCopyNamehint<AndOp>(rewriter, op, op.getCond(),
2511 op.getTrueValue(),
false);
2518 auto notCond = rewriter.createOrFold<
XorOp>(op.getLoc(), op.getCond(),
2519 op.getFalseValue(),
false);
2520 replaceOpWithNewOpAndCopyNamehint<OrOp>(rewriter, op, notCond,
2521 op.getTrueValue(),
false);
2527 Operation *condOp = op.getCond().getDefiningOp();
2528 if (condOp && matchPattern(condOp,
m_Complement(m_Any(&subExpr))) &&
2530 replaceOpWithNewOpAndCopyNamehint<MuxOp>(rewriter, op, op.getType(),
2531 subExpr, op.getFalseValue(),
2532 op.getTrueValue(),
true);
2539 if (condOp && condOp->hasOneUse()) {
2540 SmallVector<Value> invertedOperands;
2544 auto getInvertedOperands = [&]() ->
bool {
2545 for (Value operand : condOp->getOperands()) {
2546 if (matchPattern(operand,
m_Complement(m_Any(&subExpr))))
2547 invertedOperands.push_back(subExpr);
2554 if (isa<AndOp>(condOp) && getInvertedOperands()) {
2556 rewriter.createOrFold<
OrOp>(op.getLoc(), invertedOperands,
false);
2557 replaceOpWithNewOpAndCopyNamehint<MuxOp>(
2558 rewriter, op, newOr, op.getFalseValue(), op.getTrueValue(),
2562 if (isa<OrOp>(condOp) && getInvertedOperands()) {
2564 rewriter.createOrFold<
AndOp>(op.getLoc(), invertedOperands,
false);
2565 replaceOpWithNewOpAndCopyNamehint<MuxOp>(
2566 rewriter, op, newAnd, op.getFalseValue(), op.getTrueValue(),
2572 if (
auto falseMux = op.getFalseValue().getDefiningOp<
MuxOp>();
2573 falseMux && falseMux != op) {
2575 if (op.getCond() == falseMux.getCond() &&
2576 falseMux.getFalseValue() != falseMux) {
2577 replaceOpWithNewOpAndCopyNamehint<MuxOp>(
2578 rewriter, op, op.getCond(), op.getTrueValue(),
2579 falseMux.getFalseValue(), op.getTwoStateAttr());
2585 foldToArrayCreateOnlyWhenDense))
2589 if (
auto trueMux = op.getTrueValue().getDefiningOp<
MuxOp>();
2590 trueMux && trueMux != op) {
2592 if (op.getCond() == trueMux.getCond()) {
2593 replaceOpWithNewOpAndCopyNamehint<MuxOp>(
2594 rewriter, op, op.getCond(), trueMux.getTrueValue(),
2595 op.getFalseValue(), op.getTwoStateAttr());
2601 foldToArrayCreateOnlyWhenDense))
2606 if (
auto trueMux = dyn_cast_or_null<MuxOp>(op.getTrueValue().getDefiningOp()),
2607 falseMux = dyn_cast_or_null<MuxOp>(op.getFalseValue().getDefiningOp());
2608 trueMux && falseMux && trueMux.getCond() == falseMux.getCond() &&
2609 trueMux.getTrueValue() == falseMux.getTrueValue() && trueMux != op &&
2611 auto subMux = MuxOp::create(
2612 rewriter, rewriter.getFusedLoc({trueMux.getLoc(), falseMux.getLoc()}),
2613 op.getCond(), trueMux.getFalseValue(), falseMux.getFalseValue());
2614 replaceOpWithNewOpAndCopyNamehint<MuxOp>(rewriter, op, trueMux.getCond(),
2615 trueMux.getTrueValue(), subMux,
2616 op.getTwoStateAttr());
2621 if (
auto trueMux = dyn_cast_or_null<MuxOp>(op.getTrueValue().getDefiningOp()),
2622 falseMux = dyn_cast_or_null<MuxOp>(op.getFalseValue().getDefiningOp());
2623 trueMux && falseMux && trueMux.getCond() == falseMux.getCond() &&
2624 trueMux.getFalseValue() == falseMux.getFalseValue() && trueMux != op &&
2626 auto subMux = MuxOp::create(
2627 rewriter, rewriter.getFusedLoc({trueMux.getLoc(), falseMux.getLoc()}),
2628 op.getCond(), trueMux.getTrueValue(), falseMux.getTrueValue());
2629 replaceOpWithNewOpAndCopyNamehint<MuxOp>(rewriter, op, trueMux.getCond(),
2630 subMux, trueMux.getFalseValue(),
2631 op.getTwoStateAttr());
2636 if (
auto trueMux = dyn_cast_or_null<MuxOp>(op.getTrueValue().getDefiningOp()),
2637 falseMux = dyn_cast_or_null<MuxOp>(op.getFalseValue().getDefiningOp());
2638 trueMux && falseMux &&
2639 trueMux.getTrueValue() == falseMux.getTrueValue() &&
2640 trueMux.getFalseValue() == falseMux.getFalseValue() && trueMux != op &&
2643 MuxOp::create(rewriter,
2644 rewriter.getFusedLoc(
2645 {op.getLoc(), trueMux.getLoc(), falseMux.getLoc()}),
2646 op.getCond(), trueMux.getCond(), falseMux.getCond());
2647 replaceOpWithNewOpAndCopyNamehint<MuxOp>(
2648 rewriter, op, subMux, trueMux.getTrueValue(), trueMux.getFalseValue(),
2649 op.getTwoStateAttr());
2661 if (Operation *trueOp = op.getTrueValue().getDefiningOp())
2662 if (Operation *falseOp = op.getFalseValue().getDefiningOp())
2663 if (trueOp->getName() == falseOp->getName())
2676 if (op.getTrueValue().getDefiningOp() &&
2677 op.getTrueValue().getDefiningOp() != op)
2680 if (op.getFalseValue().getDefiningOp() &&
2681 op.getFalseValue().getDefiningOp() != op)
2692 if (op.getInputs().empty() || op.isUniform())
2694 auto inputs = op.getInputs();
2695 if (inputs.size() <= 1)
2700 auto first = inputs[0].getDefiningOp<
comb::MuxOp>();
2705 for (
size_t i = 1, n = inputs.size(); i < n; ++i) {
2706 auto input = inputs[i].getDefiningOp<
comb::MuxOp>();
2707 if (!input || first.getCond() != input.getCond())
2712 SmallVector<Value> trues{first.getTrueValue()};
2713 SmallVector<Value> falses{first.getFalseValue()};
2714 SmallVector<Location> locs{first->getLoc()};
2715 bool isTwoState =
true;
2716 for (
size_t i = 1, n = inputs.size(); i < n; ++i) {
2717 auto input = inputs[i].getDefiningOp<
comb::MuxOp>();
2718 trues.push_back(input.getTrueValue());
2719 falses.push_back(input.getFalseValue());
2720 locs.push_back(input->getLoc());
2721 if (!input.getTwoState())
2726 auto loc = FusedLoc::get(op.getContext(), locs);
2730 auto arrayTy = op.getType();
2733 rewriter.replaceOpWithNewOp<
comb::MuxOp>(op, arrayTy, first.getCond(),
2734 trueValues, falseValues, isTwoState);
2739 using OpRewritePattern::OpRewritePattern;
2742 PatternRewriter &rewriter)
const override {
2743 if (foldArrayOfMuxes(op, rewriter))
2751void MuxOp::getCanonicalizationPatterns(RewritePatternSet &results,
2753 results.insert<MuxRewriter, ArrayRewriter>(
context);
2764 switch (predicate) {
2765 case ICmpPredicate::eq:
2767 case ICmpPredicate::ne:
2769 case ICmpPredicate::slt:
2770 return lhs.slt(rhs);
2771 case ICmpPredicate::sle:
2772 return lhs.sle(rhs);
2773 case ICmpPredicate::sgt:
2774 return lhs.sgt(rhs);
2775 case ICmpPredicate::sge:
2776 return lhs.sge(rhs);
2777 case ICmpPredicate::ult:
2778 return lhs.ult(rhs);
2779 case ICmpPredicate::ule:
2780 return lhs.ule(rhs);
2781 case ICmpPredicate::ugt:
2782 return lhs.ugt(rhs);
2783 case ICmpPredicate::uge:
2784 return lhs.uge(rhs);
2785 case ICmpPredicate::ceq:
2787 case ICmpPredicate::cne:
2789 case ICmpPredicate::weq:
2791 case ICmpPredicate::wne:
2794 llvm_unreachable(
"unknown comparison predicate");
2800 switch (predicate) {
2801 case ICmpPredicate::eq:
2802 case ICmpPredicate::sle:
2803 case ICmpPredicate::sge:
2804 case ICmpPredicate::ule:
2805 case ICmpPredicate::uge:
2806 case ICmpPredicate::ceq:
2807 case ICmpPredicate::weq:
2809 case ICmpPredicate::ne:
2810 case ICmpPredicate::slt:
2811 case ICmpPredicate::sgt:
2812 case ICmpPredicate::ult:
2813 case ICmpPredicate::ugt:
2814 case ICmpPredicate::cne:
2815 case ICmpPredicate::wne:
2818 llvm_unreachable(
"unknown comparison predicate");
2821OpFoldResult ICmpOp::fold(FoldAdaptor adaptor) {
2824 if (getLhs() == getRhs()) {
2826 return IntegerAttr::get(getType(), val);
2830 if (
auto lhs = dyn_cast_or_null<IntegerAttr>(adaptor.getLhs())) {
2831 if (
auto rhs = dyn_cast_or_null<IntegerAttr>(adaptor.getRhs())) {
2834 return IntegerAttr::get(getType(), val);
2842template <
typename Range>
2844 size_t commonPrefixLength = 0;
2845 auto ia = a.begin();
2846 auto ib = b.begin();
2848 for (; ia != a.end() && ib != b.end(); ia++, ib++, commonPrefixLength++) {
2854 return commonPrefixLength;
2858 size_t totalWidth = 0;
2859 for (
auto operand : operands) {
2862 ssize_t width = operand.getType().getIntOrFloatBitWidth();
2864 totalWidth += width;
2874 PatternRewriter &rewriter) {
2878 SmallVector<Value> lhsOperands, rhsOperands;
2881 ArrayRef<Value> lhsOperandsRef = lhsOperands, rhsOperandsRef = rhsOperands;
2883 auto formCatOrReplicate = [&](Location loc,
2884 ArrayRef<Value> operands) -> Value {
2885 assert(!operands.empty());
2886 Value sameElement = operands[0];
2887 for (
size_t i = 1, e = operands.size(); i != e && sameElement; ++i)
2888 if (sameElement != operands[i])
2889 sameElement = Value();
2891 return rewriter.createOrFold<ReplicateOp>(loc, sameElement,
2893 return rewriter.createOrFold<
ConcatOp>(loc, operands);
2896 auto replaceWith = [&](ICmpPredicate predicate, Value lhs,
2897 Value rhs) -> LogicalResult {
2898 replaceOpWithNewOpAndCopyNamehint<ICmpOp>(rewriter, op, predicate, lhs, rhs,
2903 size_t commonPrefixLength =
2905 if (commonPrefixLength == lhsOperands.size()) {
2908 replaceOpWithNewOpAndCopyNamehint<hw::ConstantOp>(rewriter, op,
2914 llvm::reverse(lhsOperandsRef), llvm::reverse(rhsOperandsRef));
2916 size_t commonPrefixTotalWidth =
2917 getTotalWidth(lhsOperandsRef.take_front(commonPrefixLength));
2918 size_t commonSuffixTotalWidth =
2919 getTotalWidth(lhsOperandsRef.take_back(commonSuffixLength));
2920 auto lhsOnly = lhsOperandsRef.drop_front(commonPrefixLength)
2921 .drop_back(commonSuffixLength);
2922 auto rhsOnly = rhsOperandsRef.drop_front(commonPrefixLength)
2923 .drop_back(commonSuffixLength);
2925 auto replaceWithoutReplicatingSignBit = [&]() {
2926 auto newLhs = formCatOrReplicate(lhs->getLoc(), lhsOnly);
2927 auto newRhs = formCatOrReplicate(rhs->getLoc(), rhsOnly);
2928 return replaceWith(op.getPredicate(), newLhs, newRhs);
2931 auto replaceWithReplicatingSignBit = [&]() {
2932 auto firstNonEmptyValue = lhsOperands[0];
2933 auto firstNonEmptyElemWidth =
2934 firstNonEmptyValue.getType().getIntOrFloatBitWidth();
2935 Value signBit = rewriter.createOrFold<
ExtractOp>(
2936 op.getLoc(), firstNonEmptyValue, firstNonEmptyElemWidth - 1, 1);
2938 auto newLhs = ConcatOp::create(rewriter, lhs->getLoc(), signBit, lhsOnly);
2939 auto newRhs = ConcatOp::create(rewriter, rhs->getLoc(), signBit, rhsOnly);
2940 return replaceWith(op.getPredicate(), newLhs, newRhs);
2943 if (ICmpOp::isPredicateSigned(op.getPredicate())) {
2945 if (commonPrefixTotalWidth == 0 && commonSuffixTotalWidth > 0)
2946 return replaceWithoutReplicatingSignBit();
2952 if (commonPrefixTotalWidth > 1 || commonSuffixTotalWidth > 0)
2953 return replaceWithReplicatingSignBit();
2955 }
else if (commonPrefixTotalWidth > 0 || commonSuffixTotalWidth > 0) {
2957 return replaceWithoutReplicatingSignBit();
2971 ICmpOp cmpOp,
const KnownBits &bitAnalysis,
const APInt &rhsCst,
2972 PatternRewriter &rewriter) {
2976 APInt bitsKnown = bitAnalysis.Zero | bitAnalysis.One;
2977 if ((bitsKnown & rhsCst) != bitAnalysis.One) {
2980 bool result = cmpOp.getPredicate() == ICmpPredicate::ne;
2981 replaceOpWithNewOpAndCopyNamehint<hw::ConstantOp>(rewriter, cmpOp,
2989 SmallVector<Value> newConcatOperands;
2990 auto newConstant = APInt::getZeroWidth();
2995 unsigned knownMSB = bitsKnown.countLeadingOnes();
2997 Value operand = cmpOp.getLhs();
3002 while (knownMSB != bitsKnown.getBitWidth()) {
3005 bitsKnown = bitsKnown.trunc(bitsKnown.getBitWidth() - knownMSB);
3008 unsigned unknownBits = bitsKnown.countLeadingZeros();
3009 unsigned lowBit = bitsKnown.getBitWidth() - unknownBits;
3010 auto spanOperand = rewriter.createOrFold<
ExtractOp>(
3011 operand.getLoc(), operand, lowBit,
3013 auto spanConstant = rhsCst.lshr(lowBit).trunc(unknownBits);
3016 newConcatOperands.push_back(spanOperand);
3019 if (newConstant.getBitWidth() != 0)
3020 newConstant = newConstant.concat(spanConstant);
3022 newConstant = spanConstant;
3025 unsigned newWidth = bitsKnown.getBitWidth() - unknownBits;
3026 bitsKnown = bitsKnown.trunc(newWidth);
3027 knownMSB = bitsKnown.countLeadingOnes();
3033 if (newConcatOperands.empty()) {
3034 bool result = cmpOp.getPredicate() == ICmpPredicate::eq;
3035 replaceOpWithNewOpAndCopyNamehint<hw::ConstantOp>(rewriter, cmpOp,
3041 Value concatResult =
3042 rewriter.createOrFold<
ConcatOp>(operand.getLoc(), newConcatOperands);
3046 rewriter, cmpOp.getOperand(1).getLoc(), newConstant);
3048 replaceOpWithNewOpAndCopyNamehint<ICmpOp>(rewriter, cmpOp,
3049 cmpOp.getPredicate(), concatResult,
3050 newConstantOp, cmpOp.getTwoState());
3056 PatternRewriter &rewriter) {
3057 auto ip = rewriter.saveInsertionPoint();
3058 rewriter.setInsertionPoint(xorOp);
3060 auto xorRHS = xorOp.getOperands().back().getDefiningOp<
hw::ConstantOp>();
3062 xorRHS.getValue() ^ rhs);
3064 switch (xorOp.getNumOperands()) {
3068 APInt::getZero(rhs.getBitWidth()));
3072 newLHS = xorOp.getOperand(0);
3076 SmallVector<Value> newOperands(xorOp.getOperands());
3077 newOperands.pop_back();
3078 newLHS = XorOp::create(rewriter, xorOp.getLoc(), newOperands,
false);
3082 bool xorMultipleUses = !xorOp->hasOneUse();
3086 if (xorMultipleUses)
3087 replaceOpWithNewOpAndCopyNamehint<XorOp>(rewriter, xorOp, newLHS, xorRHS,
3091 rewriter.restoreInsertionPoint(ip);
3092 replaceOpWithNewOpAndCopyNamehint<ICmpOp>(
3093 rewriter, cmpOp, cmpOp.getPredicate(), newLHS, newRHS,
false);
3096LogicalResult ICmpOp::canonicalize(ICmpOp op, PatternRewriter &rewriter) {
3102 if (matchPattern(op.getLhs(), m_ConstantInt(&lhs))) {
3103 assert(!matchPattern(op.getRhs(), m_ConstantInt(&rhs)) &&
3104 "Should be folded");
3105 replaceOpWithNewOpAndCopyNamehint<ICmpOp>(
3106 rewriter, op, ICmpOp::getFlippedPredicate(op.getPredicate()),
3107 op.getRhs(), op.getLhs(), op.getTwoState());
3112 if (matchPattern(op.getRhs(), m_ConstantInt(&rhs))) {
3117 auto replaceWith = [&](ICmpPredicate predicate, Value lhs,
3118 Value rhs) -> LogicalResult {
3119 replaceOpWithNewOpAndCopyNamehint<ICmpOp>(rewriter, op, predicate, lhs,
3120 rhs, op.getTwoState());
3124 auto replaceWithConstantI1 = [&](
bool constant) -> LogicalResult {
3125 replaceOpWithNewOpAndCopyNamehint<hw::ConstantOp>(rewriter, op,
3126 APInt(1, constant));
3130 switch (op.getPredicate()) {
3131 case ICmpPredicate::slt:
3133 if (rhs.isMaxSignedValue())
3134 return replaceWith(ICmpPredicate::ne, op.getLhs(), op.getRhs());
3136 if (rhs.isMinSignedValue())
3137 return replaceWithConstantI1(0);
3139 if ((rhs - 1).isMinSignedValue())
3140 return replaceWith(ICmpPredicate::eq, op.getLhs(),
3143 case ICmpPredicate::sgt:
3145 if (rhs.isMinSignedValue())
3146 return replaceWith(ICmpPredicate::ne, op.getLhs(), op.getRhs());
3148 if (rhs.isMaxSignedValue())
3149 return replaceWithConstantI1(0);
3151 if ((rhs + 1).isMaxSignedValue())
3152 return replaceWith(ICmpPredicate::eq, op.getLhs(),
3155 case ICmpPredicate::ult:
3157 if (rhs.isAllOnes())
3158 return replaceWith(ICmpPredicate::ne, op.getLhs(), op.getRhs());
3161 return replaceWithConstantI1(0);
3163 if ((rhs - 1).isZero())
3164 return replaceWith(ICmpPredicate::eq, op.getLhs(),
3168 if (rhs.countLeadingOnes() + rhs.countTrailingZeros() ==
3169 rhs.getBitWidth()) {
3170 auto numOnes = rhs.countLeadingOnes();
3172 rhs.getBitWidth() - numOnes, numOnes);
3173 return replaceWith(ICmpPredicate::ne, smaller,
3178 case ICmpPredicate::ugt:
3181 return replaceWith(ICmpPredicate::ne, op.getLhs(), op.getRhs());
3183 if (rhs.isAllOnes())
3184 return replaceWithConstantI1(0);
3186 if ((rhs + 1).isAllOnes())
3187 return replaceWith(ICmpPredicate::eq, op.getLhs(),
3191 if ((rhs + 1).isPowerOf2()) {
3192 auto numOnes = rhs.countTrailingOnes();
3193 auto newWidth = rhs.getBitWidth() - numOnes;
3196 return replaceWith(ICmpPredicate::ne, smaller,
3201 case ICmpPredicate::sle:
3203 if (rhs.isMaxSignedValue())
3204 return replaceWithConstantI1(1);
3206 return replaceWith(ICmpPredicate::slt, op.getLhs(),
getConstant(rhs + 1));
3207 case ICmpPredicate::sge:
3209 if (rhs.isMinSignedValue())
3210 return replaceWithConstantI1(1);
3212 return replaceWith(ICmpPredicate::sgt, op.getLhs(),
getConstant(rhs - 1));
3213 case ICmpPredicate::ule:
3215 if (rhs.isAllOnes())
3216 return replaceWithConstantI1(1);
3218 return replaceWith(ICmpPredicate::ult, op.getLhs(),
getConstant(rhs + 1));
3219 case ICmpPredicate::uge:
3222 return replaceWithConstantI1(1);
3224 return replaceWith(ICmpPredicate::ugt, op.getLhs(),
getConstant(rhs - 1));
3225 case ICmpPredicate::eq:
3226 if (rhs.getBitWidth() == 1) {
3229 replaceOpWithNewOpAndCopyNamehint<XorOp>(rewriter, op, op.getLhs(),
3234 if (rhs.isAllOnes()) {
3241 case ICmpPredicate::ne:
3242 if (rhs.getBitWidth() == 1) {
3248 if (rhs.isAllOnes()) {
3250 replaceOpWithNewOpAndCopyNamehint<XorOp>(rewriter, op, op.getLhs(),
3257 case ICmpPredicate::ceq:
3258 case ICmpPredicate::cne:
3259 case ICmpPredicate::weq:
3260 case ICmpPredicate::wne:
3266 if (op.getPredicate() == ICmpPredicate::eq ||
3267 op.getPredicate() == ICmpPredicate::ne) {
3272 if (!knownBits.isUnknown())
3279 if (
auto xorOp = op.getLhs().getDefiningOp<
XorOp>())
3286 if (
auto replicateOp = op.getLhs().getDefiningOp<ReplicateOp>())
3287 if (rhs.isAllOnes() || rhs.isZero()) {
3288 auto width = replicateOp.getInput().getType().getIntOrFloatBitWidth();
3291 rhs.isAllOnes() ? APInt::getAllOnes(width)
3292 : APInt::getZero(width));
3293 replaceOpWithNewOpAndCopyNamehint<ICmpOp>(
3294 rewriter, op, op.getPredicate(), replicateOp.getInput(), cst,
3304 if (Operation *opLHS = op.getLhs().getDefiningOp())
3305 if (Operation *opRHS = op.getRhs().getDefiningOp())
3306 if (isa<ConcatOp, ReplicateOp>(opLHS) &&
3307 isa<ConcatOp, ReplicateOp>(opRHS)) {
assert(baseType &&"element must be base type")
static SmallVector< T > concat(const SmallVectorImpl< T > &a, const SmallVectorImpl< T > &b)
Returns a new vector containing the concatenation of vectors a and b.
static KnownBits computeKnownBits(Value v, unsigned depth)
Given an integer SSA value, check to see if we know anything about the result of the computation.
static bool foldMuxOfUniformArrays(MuxOp op, PatternRewriter &rewriter)
static Attribute constFoldAssociativeOp(ArrayRef< Attribute > operands, hw::PEO paramOpcode)
static Attribute constFoldBinaryOp(ArrayRef< Attribute > operands, hw::PEO paramOpcode)
Performs constant folding calculate with element-wise behavior on the two attributes in operands and ...
static bool applyCmpPredicateToEqualOperands(ICmpPredicate predicate)
static ComplementMatcher< SubType > m_Complement(const SubType &subExpr)
static bool canonicalizeLogicalCstWithConcat(Operation *logicalOp, size_t concatIdx, const APInt &cst, PatternRewriter &rewriter)
When we find a logical operation (and, or, xor) with a constant e.g.
static bool narrowOperationWidth(OpTy op, bool narrowTrailingBits, PatternRewriter &rewriter)
static OpFoldResult foldDiv(Op op, ArrayRef< Attribute > constants)
static Value getCommonOperand(Op op)
Returns a single common operand that all inputs of the operation op can be traced back to,...
static bool canCombineOppositeBinCmpIntoConstant(OperandRange operands)
static void getConcatOperands(Value v, SmallVectorImpl< Value > &result)
Flatten concat and mux operands into a vector.
static Value extractOperandFromFullyAssociative(Operation *fullyAssoc, size_t operandNo, PatternRewriter &rewriter)
Given a fully associative variadic operation like (a+b+c+d), break the expression into two parts,...
static bool getMuxChainCondConstant(Value cond, Value indexValue, bool isInverted, std::function< void(hw::ConstantOp)> constantFn)
Check to see if the condition to the specified mux is an equality comparison indexValue and one or mo...
static TypedAttr getIntAttr(const APInt &value, MLIRContext *context)
static bool shouldBeFlattened(Operation *op)
Return true if the op will be flattened afterwards.
static void canonicalizeXorIcmpTrue(XorOp op, unsigned icmpOperand, PatternRewriter &rewriter)
static bool assumeMuxCondInOperand(Value muxCond, Value muxValue, bool constCond, PatternRewriter &rewriter)
If the mux condition is an operand to the op defining its true or false value, replace the condition ...
static bool extractFromReplicate(ExtractOp op, ReplicateOp replicate, PatternRewriter &rewriter)
static void combineEqualityICmpWithXorOfConstant(ICmpOp cmpOp, XorOp xorOp, const APInt &rhs, PatternRewriter &rewriter)
static size_t getTotalWidth(ArrayRef< Value > operands)
static bool foldCommonMuxOperation(MuxOp mux, Operation *trueOp, Operation *falseOp, PatternRewriter &rewriter)
This function is invoked when we find a mux with true/false operations that have the same opcode.
static std::pair< size_t, size_t > getLowestBitAndHighestBitRequired(Operation *op, bool narrowTrailingBits, size_t originalOpWidth)
static bool tryFlatteningOperands(Operation *op, PatternRewriter &rewriter)
Flattens a single input in op if hasOneUse is true and it can be defined as an Op.
static bool isOpTriviallyRecursive(Operation *op)
static bool canonicalizeIdempotentInputs(Op op, PatternRewriter &rewriter)
Canonicalize an idempotent operation op so that only one input of any kind occurs.
static bool applyCmpPredicate(ICmpPredicate predicate, const APInt &lhs, const APInt &rhs)
static void combineEqualityICmpWithKnownBitsAndConstant(ICmpOp cmpOp, const KnownBits &bitAnalysis, const APInt &rhsCst, PatternRewriter &rewriter)
Given an equality comparison with a constant value and some operand that has known bits,...
static bool hasSVAttributes(Operation *op)
static LogicalResult extractConcatToConcatExtract(ExtractOp op, ConcatOp innerCat, PatternRewriter &rewriter)
static OpFoldResult foldMod(Op op, ArrayRef< Attribute > constants)
static size_t computeCommonPrefixLength(const Range &a, const Range &b)
static bool foldCommonMuxValue(MuxOp op, bool isTrueOperand, PatternRewriter &rewriter)
Fold things like mux(cond, x|y|z|a, a) -> (x|y|z)&replicate(cond)|a and mux(cond, a,...
static LogicalResult matchAndRewriteCompareConcat(ICmpOp op, Operation *lhs, Operation *rhs, PatternRewriter &rewriter)
Reduce the strength icmp(concat(...), concat(...)) by doing a element-wise comparison on common prefi...
static Value createGenericOp(Location loc, OperationName name, ArrayRef< Value > operands, OpBuilder &builder)
Create a new instance of a generic operation that only has value operands, and has a single result va...
static TypedAttr getIntAttr(MLIRContext *ctx, Type t, const APInt &value)
static std::unique_ptr< Context > context
static std::optional< APSInt > getConstant(Attribute operand)
Determine the value of a constant operand for the sake of constant folding.
create(array_value, low_index, ret_type)
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
void extractBits(OpBuilder &builder, Value val, SmallVectorImpl< Value > &bits)
Extract bits from a value.
bool foldMuxChainWithComparison(PatternRewriter &rewriter, MuxOp rootMux, bool isFalseSide, llvm::function_ref< MuxChainWithComparisonFoldingStyle(size_t indexWidth, size_t numEntries)> styleFn)
Mux chain folding that converts chains of muxes with index comparisons into array operations or balan...
Value createOrFoldNot(Location loc, Value value, OpBuilder &builder, bool twoState=false)
Create a `Not` gate on a value.
MuxChainWithComparisonFoldingStyle
Enum for mux chain folding styles.
LogicalResult convertModUByPowerOfTwo(ModUOp modOp, mlir::PatternRewriter &rewriter)
KnownBits computeKnownBits(Value value)
Compute "known bits" information about the specified value - the set of bits that are guaranteed to a...
Value constructMuxTree(OpBuilder &builder, Location loc, ArrayRef< Value > selectors, ArrayRef< Value > leafNodes, Value outOfBoundsValue)
Construct a mux tree for given leaf nodes.
LogicalResult convertDivUByPowerOfTwo(DivUOp divOp, mlir::PatternRewriter &rewriter)
Convert unsigned division or modulo by a power of two.
uint64_t getWidth(Type t)
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
void replaceOpAndCopyNamehint(PatternRewriter &rewriter, Operation *op, Value newValue)
A wrapper of PatternRewriter::replaceOp to propagate "sv.namehint" attribute.