13#include "mlir/IR/Diagnostics.h"
14#include "mlir/IR/Matchers.h"
15#include "mlir/IR/PatternMatch.h"
16#include "llvm/ADT/SetVector.h"
17#include "llvm/ADT/SmallBitVector.h"
18#include "llvm/ADT/TypeSwitch.h"
19#include "llvm/Support/KnownBits.h"
24using namespace matchers;
28 return llvm::any_of(op->getOperands(), [op](
auto operand) {
29 return operand.getDefiningOp() == op;
39 ArrayRef<Value> operands, OpBuilder &builder) {
40 OperationState state(loc, name);
41 state.addOperands(operands);
42 state.addTypes(operands[0].getType());
43 return builder.create(state)->getResult(0);
47 return IntegerAttr::get(IntegerType::get(
context, value.getBitWidth()),
53 if (
auto concat = v.getDefiningOp<
ConcatOp>()) {
54 for (
auto op : concat.getOperands())
56 }
else if (
auto repl = v.getDefiningOp<ReplicateOp>()) {
57 for (
size_t i = 0, e = repl.getMultiple(); i != e; ++i)
68 return op->hasAttr(
"sv.attributes");
72template <
typename SubType>
73struct ComplementMatcher {
75 ComplementMatcher(SubType lhs) : lhs(std::move(lhs)) {}
76 bool match(Operation *op) {
77 auto xorOp = dyn_cast<XorOp>(op);
78 return xorOp && xorOp.isBinaryNot() &&
79 mlir::detail::matchOperandOrValueAtIndex(op, 0, lhs);
84template <
typename SubType>
85static inline ComplementMatcher<SubType>
m_Complement(
const SubType &subExpr) {
86 return ComplementMatcher<SubType>(subExpr);
92 assert((isa<AndOp, OrOp, XorOp, AddOp, MulOp>(op) &&
93 "must be commutative operations"));
94 if (op->hasOneUse()) {
95 auto *user = *op->getUsers().begin();
96 return user->getName() == op->getName() &&
97 op->getAttrOfType<UnitAttr>(
"twoState") ==
98 user->getAttrOfType<UnitAttr>(
"twoState") &&
99 op->getBlock() == user->getBlock();
114 auto inputs = op->getOperands();
116 SmallVector<Value, 4> newOperands;
117 SmallVector<Location, 4> newLocations{op->getLoc()};
118 newOperands.reserve(inputs.size());
120 decltype(inputs.begin()) current, end;
123 SmallVector<Element> worklist;
124 worklist.push_back({inputs.begin(), inputs.end()});
125 bool binFlag = op->hasAttrOfType<UnitAttr>(
"twoState");
126 bool changed =
false;
127 while (!worklist.empty()) {
128 auto &element = worklist.back();
131 if (element.current == element.end) {
136 Value value = *element.current++;
137 auto *flattenOp = value.getDefiningOp();
140 if (!flattenOp || flattenOp->getName() != op->getName() ||
141 flattenOp == op || binFlag != op->hasAttrOfType<UnitAttr>(
"twoState") ||
142 flattenOp->getBlock() != op->getBlock()) {
143 newOperands.push_back(value);
148 if (!value.hasOneUse()) {
156 if (flattenOp->getNumOperands() != 2 || !isa<AndOp, OrOp, XorOp>(op) ||
159 newOperands.push_back(value);
167 auto flattenOpInputs = flattenOp->getOperands();
168 worklist.push_back({flattenOpInputs.begin(), flattenOpInputs.end()});
169 newLocations.push_back(flattenOp->getLoc());
175 Value result =
createGenericOp(FusedLoc::get(op->getContext(), newLocations),
176 op->getName(), newOperands, rewriter);
178 result.getDefiningOp()->setAttr(
"twoState", rewriter.getUnitAttr());
186static std::pair<size_t, size_t>
188 size_t originalOpWidth) {
189 auto users = op->getUsers();
191 "getLowestBitAndHighestBitRequired cannot operate on "
192 "a empty list of uses.");
196 size_t lowestBitRequired = narrowTrailingBits ? originalOpWidth - 1 : 0;
197 size_t highestBitRequired = 0;
199 for (
auto *user : users) {
200 if (
auto extractOp = dyn_cast<ExtractOp>(user)) {
201 size_t lowBit = extractOp.getLowBit();
203 cast<IntegerType>(extractOp.getType()).getWidth() + lowBit - 1;
204 highestBitRequired = std::max(highestBitRequired, highBit);
205 lowestBitRequired = std::min(lowestBitRequired, lowBit);
209 highestBitRequired = originalOpWidth - 1;
210 lowestBitRequired = 0;
214 return {lowestBitRequired, highestBitRequired};
219 PatternRewriter &rewriter) {
220 IntegerType opType = dyn_cast<IntegerType>(op.getResult().getType());
226 if (range.second + 1 == opType.getWidth() && range.first == 0)
229 SmallVector<Value> args;
230 auto newType = rewriter.getIntegerType(range.second - range.first + 1);
231 for (
auto inop : op.getOperands()) {
233 if (inop.getType() != op.getType())
234 args.push_back(inop);
236 args.push_back(rewriter.createOrFold<
ExtractOp>(inop.getLoc(), newType,
239 auto newop = OpTy::create(rewriter, op.getLoc(), newType, args);
240 newop->setDialectAttrs(op->getDialectAttrs());
241 if (op.getTwoState())
242 newop.setTwoState(
true);
244 Value newResult = newop.getResult();
246 newResult = rewriter.createOrFold<
ConcatOp>(
247 op.getLoc(), newResult,
249 APInt::getZero(range.first)));
250 if (range.second + 1 < opType.getWidth())
251 newResult = rewriter.createOrFold<
ConcatOp>(
254 rewriter, op.getLoc(),
255 APInt::getZero(opType.getWidth() - range.second - 1)),
257 rewriter.replaceOp(op, newResult);
265OpFoldResult ReplicateOp::fold(FoldAdaptor adaptor) {
270 if (cast<IntegerType>(getType()).
getWidth() ==
271 getInput().getType().getIntOrFloatBitWidth())
275 if (
auto input = dyn_cast_or_null<IntegerAttr>(adaptor.getInput())) {
276 if (input.getValue().getBitWidth() == 1) {
277 if (input.getValue().isZero())
279 APInt::getZero(cast<IntegerType>(getType()).
getWidth()),
282 APInt::getAllOnes(cast<IntegerType>(getType()).
getWidth()),
286 APInt result = APInt::getZeroWidth();
287 for (
auto i = getMultiple(); i != 0; --i)
288 result = result.concat(input.getValue());
295OpFoldResult ParityOp::fold(FoldAdaptor adaptor) {
300 if (
auto input = dyn_cast_or_null<IntegerAttr>(adaptor.getInput()))
301 return getIntAttr(APInt(1, input.getValue().popcount() & 1), getContext());
313 hw::PEO paramOpcode) {
314 assert(operands.size() == 2 &&
"binary op takes two operands");
315 if (!operands[0] || !operands[1])
320 return hw::ParamExprAttr::get(paramOpcode, cast<TypedAttr>(operands[0]),
321 cast<TypedAttr>(operands[1]));
324OpFoldResult ShlOp::fold(FoldAdaptor adaptor) {
328 if (
auto rhs = dyn_cast_or_null<IntegerAttr>(adaptor.getRhs())) {
329 if (rhs.getValue().isZero())
330 return getOperand(0);
332 unsigned width = getType().getIntOrFloatBitWidth();
333 if (rhs.getValue().uge(width))
334 return getIntAttr(APInt::getZero(width), getContext());
339LogicalResult ShlOp::canonicalize(
ShlOp op, PatternRewriter &rewriter) {
345 if (!matchPattern(op.getRhs(), m_ConstantInt(&value)))
348 unsigned width = cast<IntegerType>(op.getLhs().getType()).getWidth();
349 if (value.ugt(width))
351 unsigned shift = value.getZExtValue();
354 if (width <= shift || shift == 0)
364 replaceOpWithNewOpAndCopyNamehint<ConcatOp>(rewriter, op, extract, zeros);
368OpFoldResult ShrUOp::fold(FoldAdaptor adaptor) {
372 if (
auto rhs = dyn_cast_or_null<IntegerAttr>(adaptor.getRhs())) {
373 if (rhs.getValue().isZero())
374 return getOperand(0);
376 unsigned width = getType().getIntOrFloatBitWidth();
377 if (rhs.getValue().uge(width))
378 return getIntAttr(APInt::getZero(width), getContext());
383LogicalResult ShrUOp::canonicalize(
ShrUOp op, PatternRewriter &rewriter) {
389 if (!matchPattern(op.getRhs(), m_ConstantInt(&value)))
392 unsigned width = cast<IntegerType>(op.getLhs().getType()).getWidth();
393 if (value.ugt(width))
395 unsigned shift = value.getZExtValue();
398 if (width <= shift || shift == 0)
408 replaceOpWithNewOpAndCopyNamehint<ConcatOp>(rewriter, op, zeros, extract);
412OpFoldResult ShrSOp::fold(FoldAdaptor adaptor) {
416 if (
auto rhs = dyn_cast_or_null<IntegerAttr>(adaptor.getRhs()))
417 if (rhs.getValue().isZero())
418 return getOperand(0);
422LogicalResult ShrSOp::canonicalize(
ShrSOp op, PatternRewriter &rewriter) {
428 if (!matchPattern(op.getRhs(), m_ConstantInt(&value)))
431 unsigned width = cast<IntegerType>(op.getLhs().getType()).getWidth();
432 if (value.ugt(width))
434 unsigned shift = value.getZExtValue();
437 rewriter.createOrFold<
ExtractOp>(op.getLoc(), op.getLhs(), width - 1, 1);
438 auto sext = rewriter.createOrFold<ReplicateOp>(op.getLoc(), topbit, shift);
440 if (width == shift) {
448 replaceOpWithNewOpAndCopyNamehint<ConcatOp>(rewriter, op, sext, extract);
456OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
461 if (getInput().getType() == getType())
465 if (
auto input = dyn_cast_or_null<IntegerAttr>(adaptor.getInput())) {
466 unsigned dstWidth = cast<IntegerType>(getType()).getWidth();
467 return getIntAttr(input.getValue().lshr(getLowBit()).trunc(dstWidth),
478 PatternRewriter &rewriter) {
479 auto reversedConcatArgs = llvm::reverse(innerCat.getInputs());
480 size_t beginOfFirstRelevantElement = 0;
481 auto it = reversedConcatArgs.begin();
482 size_t lowBit = op.getLowBit();
485 for (; it != reversedConcatArgs.end(); it++) {
486 assert(beginOfFirstRelevantElement <= lowBit &&
487 "incorrectly moved past an element that lowBit has coverage over");
490 size_t operandWidth = operand.getType().getIntOrFloatBitWidth();
491 if (lowBit < beginOfFirstRelevantElement + operandWidth) {
515 beginOfFirstRelevantElement += operandWidth;
517 assert(it != reversedConcatArgs.end() &&
518 "incorrectly failed to find an element which contains coverage of "
521 SmallVector<Value> reverseConcatArgs;
522 size_t widthRemaining = cast<IntegerType>(op.getType()).getWidth();
523 size_t extractLo = lowBit - beginOfFirstRelevantElement;
528 for (; widthRemaining != 0 && it != reversedConcatArgs.end(); it++) {
529 auto concatArg = *it;
530 size_t operandWidth = concatArg.getType().getIntOrFloatBitWidth();
531 size_t widthToConsume = std::min(widthRemaining, operandWidth - extractLo);
533 if (widthToConsume == operandWidth && extractLo == 0) {
534 reverseConcatArgs.push_back(concatArg);
536 auto resultType = IntegerType::get(rewriter.getContext(), widthToConsume);
537 reverseConcatArgs.push_back(
541 widthRemaining -= widthToConsume;
547 if (reverseConcatArgs.size() == 1) {
550 replaceOpWithNewOpAndCopyNamehint<ConcatOp>(
551 rewriter, op, SmallVector<Value>(llvm::reverse(reverseConcatArgs)));
558 PatternRewriter &rewriter) {
559 auto extractResultWidth = cast<IntegerType>(op.getType()).getWidth();
560 auto replicateEltWidth =
561 replicate.getOperand().getType().getIntOrFloatBitWidth();
565 if (op.getLowBit() % replicateEltWidth == 0 &&
566 extractResultWidth % replicateEltWidth == 0) {
567 replaceOpWithNewOpAndCopyNamehint<ReplicateOp>(rewriter, op, op.getType(),
568 replicate.getOperand());
574 if (op.getLowBit() % replicateEltWidth + extractResultWidth <=
576 replaceOpWithNewOpAndCopyNamehint<ExtractOp>(
577 rewriter, op, op.getType(), replicate.getOperand(),
578 op.getLowBit() % replicateEltWidth);
587LogicalResult ExtractOp::canonicalize(
ExtractOp op, PatternRewriter &rewriter) {
590 auto *inputOp = op.getInput().getDefiningOp();
597 .extractBits(cast<IntegerType>(op.getType()).getWidth(),
599 if (knownBits.isConstant()) {
600 replaceOpWithNewOpAndCopyNamehint<hw::ConstantOp>(rewriter, op,
601 knownBits.getConstant());
607 if (
auto innerExtract = dyn_cast_or_null<ExtractOp>(inputOp)) {
608 replaceOpWithNewOpAndCopyNamehint<ExtractOp>(
609 rewriter, op, op.getType(), innerExtract.getInput(),
610 innerExtract.getLowBit() + op.getLowBit());
615 if (
auto innerCat = dyn_cast_or_null<ConcatOp>(inputOp))
619 if (
auto replicate = dyn_cast_or_null<ReplicateOp>(inputOp))
625 if (inputOp && inputOp->getNumOperands() == 2 &&
626 isa<AndOp, OrOp, XorOp>(inputOp)) {
627 if (
auto cstRHS = inputOp->getOperand(1).getDefiningOp<
hw::ConstantOp>()) {
628 auto extractedCst = cstRHS.getValue().extractBits(
629 cast<IntegerType>(op.getType()).getWidth(), op.getLowBit());
630 if (isa<OrOp, XorOp>(inputOp) && extractedCst.isZero()) {
631 replaceOpWithNewOpAndCopyNamehint<ExtractOp>(
632 rewriter, op, op.getType(), inputOp->getOperand(0), op.getLowBit());
640 if (isa<AndOp>(inputOp)) {
643 unsigned lz = extractedCst.countLeadingZeros();
644 unsigned tz = extractedCst.countTrailingZeros();
645 unsigned pop = extractedCst.popcount();
646 if (extractedCst.getBitWidth() - lz - tz == pop) {
647 auto resultTy = rewriter.getIntegerType(pop);
648 SmallVector<Value> resultElts;
651 APInt::getZero(lz)));
652 resultElts.push_back(rewriter.createOrFold<
ExtractOp>(
653 op.getLoc(), resultTy, inputOp->getOperand(0),
654 op.getLowBit() + tz));
657 APInt::getZero(tz)));
658 replaceOpWithNewOpAndCopyNamehint<ConcatOp>(rewriter, op, resultElts);
667 if (cast<IntegerType>(op.getType()).getWidth() == 1 && inputOp)
668 if (
auto shlOp = dyn_cast<ShlOp>(inputOp)) {
670 if (shlOp->hasOneUse())
672 if (lhsCst.getValue().isOne()) {
674 rewriter, shlOp.getLoc(),
675 APInt(lhsCst.getValue().getBitWidth(), op.getLowBit()));
676 replaceOpWithNewOpAndCopyNamehint<ICmpOp>(
677 rewriter, op, ICmpPredicate::eq, shlOp->getOperand(1), newCst,
693 hw::PEO paramOpcode) {
694 assert(operands.size() > 1 &&
"caller should handle one-operand case");
697 if (!operands[1] || !operands[0])
701 if (llvm::all_of(operands.drop_front(2),
702 [&](Attribute in) { return !!in; })) {
703 SmallVector<mlir::TypedAttr> typedOperands;
704 typedOperands.reserve(operands.size());
705 for (
auto operand : operands) {
706 if (
auto typedOperand = dyn_cast<mlir::TypedAttr>(operand))
707 typedOperands.push_back(typedOperand);
711 if (typedOperands.size() == operands.size())
712 return hw::ParamExprAttr::get(paramOpcode, typedOperands);
728 size_t concatIdx,
const APInt &cst,
729 PatternRewriter &rewriter) {
730 auto concatOp = logicalOp->getOperand(concatIdx).getDefiningOp<
ConcatOp>();
731 assert((isa<AndOp, OrOp, XorOp>(logicalOp) && concatOp));
736 llvm::any_of(concatOp->getOperands(), [&](Value operand) ->
bool {
737 auto *operandOp = operand.getDefiningOp();
742 if (isa<hw::ConstantOp>(operandOp))
746 return operandOp->getName() == logicalOp->getName() &&
747 operandOp->hasOneUse() && operandOp->getNumOperands() != 0 &&
748 operandOp->getOperands().back().getDefiningOp<hw::ConstantOp>();
756 auto createLogicalOp = [&](ArrayRef<Value> operands) -> Value {
757 return createGenericOp(logicalOp->getLoc(), logicalOp->getName(), operands,
764 SmallVector<Value> newConcatOperands;
765 newConcatOperands.reserve(concatOp->getNumOperands());
768 size_t nextOperandBit = concatOp.getType().getIntOrFloatBitWidth();
769 for (Value operand : concatOp->getOperands()) {
770 size_t operandWidth = operand.getType().getIntOrFloatBitWidth();
771 nextOperandBit -= operandWidth;
775 cst.lshr(nextOperandBit).trunc(operandWidth));
777 newConcatOperands.push_back(createLogicalOp({operand, eltCst}));
782 ConcatOp::create(rewriter, concatOp.getLoc(), newConcatOperands);
786 if (logicalOp->getNumOperands() > 2) {
787 auto origOperands = logicalOp->getOperands();
788 SmallVector<Value> operands;
790 operands.append(origOperands.begin(), origOperands.begin() + concatIdx);
792 operands.append(origOperands.begin() + concatIdx + 1,
793 origOperands.begin() + (origOperands.size() - 1));
795 operands.push_back(newResult);
796 newResult = createLogicalOp(operands);
806 llvm::SmallDenseSet<std::tuple<ICmpPredicate, Value, Value>> seenPredicates;
808 for (
auto op : operands) {
809 if (
auto icmpOp = op.getDefiningOp<ICmpOp>();
810 icmpOp && icmpOp.getTwoState()) {
811 auto predicate = icmpOp.getPredicate();
812 auto lhs = icmpOp.getLhs();
813 auto rhs = icmpOp.getRhs();
814 if (seenPredicates.contains(
815 {ICmpOp::getNegatedPredicate(predicate), lhs, rhs}))
818 seenPredicates.insert({predicate, lhs, rhs});
824OpFoldResult AndOp::fold(FoldAdaptor adaptor) {
828 APInt value = APInt::getAllOnes(cast<IntegerType>(getType()).
getWidth());
830 auto inputs = adaptor.getInputs();
833 for (
auto operand : inputs) {
834 auto attr = dyn_cast_or_null<IntegerAttr>(operand);
837 value &= attr.getValue();
843 if (inputs.size() == 2)
844 if (
auto intAttr = dyn_cast_or_null<IntegerAttr>(inputs[1]))
845 if (intAttr.getValue().isAllOnes())
846 return getInputs()[0];
849 if (llvm::all_of(getInputs(),
850 [&](
auto in) {
return in == this->getInputs()[0]; }))
851 return getInputs()[0];
854 for (Value arg : getInputs()) {
857 for (Value arg2 : getInputs())
860 APInt::getZero(cast<IntegerType>(getType()).
getWidth()),
881template <
typename Op>
883 if (!op.getType().isInteger(1))
886 auto inputs = op.getInputs();
887 size_t size = inputs.size();
889 auto sourceOp = inputs[0].template getDefiningOp<ExtractOp>();
892 Value source = sourceOp.getOperand();
895 if (size != source.getType().getIntOrFloatBitWidth())
899 llvm::BitVector bits(size);
900 bits.set(sourceOp.getLowBit());
902 for (
size_t i = 1; i != size; ++i) {
903 auto extractOp = inputs[i].template getDefiningOp<ExtractOp>();
904 if (!extractOp || extractOp.getOperand() != source)
906 bits.set(extractOp.getLowBit());
909 return bits.all() ? source : Value();
916template <
typename Op>
919 constexpr unsigned limit = 3;
920 auto inputs = op.getInputs();
923 llvm::SmallDenseSet<Op, 8> checked;
930 llvm::SmallVector<OpWithDepth, 8> worklist;
932 auto enqueue = [&worklist, &checked, &op](Value input,
unsigned depth) {
936 if (depth < limit && input.getParentBlock() == op->getBlock()) {
937 auto inputOp = input.template getDefiningOp<Op>();
938 if (inputOp && inputOp.getTwoState() == op.getTwoState() &&
939 checked.insert(inputOp).second)
940 worklist.push_back({inputOp, depth + 1});
944 for (
auto input : uniqueInputs)
947 while (!worklist.empty()) {
948 auto item = worklist.pop_back_val();
950 for (
auto input : item.op.getInputs()) {
951 uniqueInputs.remove(input);
952 enqueue(input, item.depth);
956 if (uniqueInputs.size() < inputs.size()) {
957 replaceOpWithNewOpAndCopyNamehint<Op>(rewriter, op, op.getType(),
958 uniqueInputs.getArrayRef(),
966LogicalResult AndOp::canonicalize(
AndOp op, PatternRewriter &rewriter) {
970 auto inputs = op.getInputs();
971 auto size = inputs.size();
983 assert(size > 1 &&
"expected 2 or more operands, `fold` should handle this");
987 if (matchPattern(inputs.back(), m_ConstantInt(&value))) {
989 if (value.isAllOnes()) {
990 replaceOpWithNewOpAndCopyNamehint<AndOp>(rewriter, op, op.getType(),
991 inputs.drop_back(),
false);
999 if (matchPattern(inputs[size - 2], m_ConstantInt(&value2))) {
1001 SmallVector<Value, 4> newOperands(inputs.drop_back(2));
1002 newOperands.push_back(cst);
1003 replaceOpWithNewOpAndCopyNamehint<AndOp>(rewriter, op, op.getType(),
1004 newOperands,
false);
1009 if (size == 2 && value.isPowerOf2()) {
1014 if (
auto replicate = inputs[0].getDefiningOp<ReplicateOp>()) {
1015 auto replicateOperand = replicate.getOperand();
1016 if (replicateOperand.getType().isInteger(1)) {
1017 unsigned resultWidth = op.getType().getIntOrFloatBitWidth();
1018 auto trailingZeros = value.countTrailingZeros();
1021 SmallVector<Value, 3> concatOperands;
1022 if (trailingZeros != resultWidth - 1) {
1024 rewriter, op.getLoc(),
1025 APInt::getZero(resultWidth - trailingZeros - 1));
1026 concatOperands.push_back(highZeros);
1028 concatOperands.push_back(replicateOperand);
1029 if (trailingZeros != 0) {
1031 rewriter, op.getLoc(), APInt::getZero(trailingZeros));
1032 concatOperands.push_back(lowZeros);
1034 replaceOpWithNewOpAndCopyNamehint<ConcatOp>(
1035 rewriter, op, op.getType(), concatOperands);
1044 unsigned leadingZeros = value.countLeadingZeros();
1045 unsigned trailingZeros = value.countTrailingZeros();
1046 if (leadingZeros > 0 || trailingZeros > 0) {
1047 unsigned maskLength = value.getBitWidth() - leadingZeros - trailingZeros;
1050 SmallVector<Value> operands;
1051 for (
auto input : inputs.drop_back()) {
1052 unsigned offset = trailingZeros;
1053 while (
auto extractOp = input.getDefiningOp<
ExtractOp>()) {
1054 input = extractOp.getInput();
1055 offset += extractOp.getLowBit();
1058 offset, maskLength));
1062 auto narrowMask = value.extractBits(maskLength, trailingZeros);
1063 if (!narrowMask.isAllOnes())
1065 rewriter, inputs.back().getLoc(), narrowMask));
1068 Value narrowValue = operands.back();
1069 if (operands.size() > 1)
1071 AndOp::create(rewriter, op.getLoc(), operands, op.getTwoState());
1075 if (leadingZeros > 0)
1077 rewriter, op.getLoc(), APInt::getZero(leadingZeros)));
1078 operands.push_back(narrowValue);
1079 if (trailingZeros > 0)
1081 rewriter, op.getLoc(), APInt::getZero(trailingZeros)));
1082 replaceOpWithNewOpAndCopyNamehint<ConcatOp>(rewriter, op, operands);
1089 for (
size_t i = 0; i < size - 1; ++i) {
1090 if (
auto concat = inputs[i].getDefiningOp<ConcatOp>())
1104 replaceOpWithNewOpAndCopyNamehint<ICmpOp>(rewriter, op, ICmpPredicate::eq,
1105 source, cmpAgainst);
1110 if (op.getTwoState() && op.getNumOperands() == 2) {
1111 auto isReplicateOfI1 = [](Value v) {
1112 auto rep = v.getDefiningOp<ReplicateOp>();
1115 return rep.getOperand().getType().isInteger(1);
1117 Value x = op.getOperand(0);
1118 Value y = op.getOperand(1);
1119 if (isReplicateOfI1(x))
1121 if (isReplicateOfI1(y)) {
1122 Value p = y.getDefiningOp<ReplicateOp>().getInput();
1124 rewriter, op.getLoc(), rewriter.getIntegerAttr(op.getType(), 0));
1125 replaceOpWithNewOpAndCopyNamehint<MuxOp>(rewriter, op, p, x, zero,
1135OpFoldResult OrOp::fold(FoldAdaptor adaptor) {
1139 auto value = APInt::getZero(cast<IntegerType>(getType()).
getWidth());
1140 auto inputs = adaptor.getInputs();
1142 for (
auto operand : inputs) {
1143 auto attr = dyn_cast_or_null<IntegerAttr>(operand);
1146 value |= attr.getValue();
1147 if (value.isAllOnes())
1152 if (inputs.size() == 2)
1153 if (
auto intAttr = dyn_cast_or_null<IntegerAttr>(inputs[1]))
1154 if (intAttr.getValue().isZero())
1155 return getInputs()[0];
1158 if (llvm::all_of(getInputs(),
1159 [&](
auto in) {
return in == this->getInputs()[0]; }))
1160 return getInputs()[0];
1163 for (Value arg : getInputs()) {
1165 if (matchPattern(arg,
m_Complement(m_Any(&subExpr)))) {
1166 for (Value arg2 : getInputs())
1167 if (arg2 == subExpr)
1169 APInt::getAllOnes(cast<IntegerType>(getType()).
getWidth()),
1179 APInt::getAllOnes(cast<IntegerType>(getType()).
getWidth()),
1186LogicalResult OrOp::canonicalize(
OrOp op, PatternRewriter &rewriter) {
1190 auto inputs = op.getInputs();
1191 auto size = inputs.size();
1203 assert(size > 1 &&
"expected 2 or more operands");
1207 if (matchPattern(inputs.back(), m_ConstantInt(&value))) {
1209 if (value.isZero()) {
1210 replaceOpWithNewOpAndCopyNamehint<OrOp>(rewriter, op, op.getType(),
1211 inputs.drop_back());
1217 if (matchPattern(inputs[size - 2], m_ConstantInt(&value2))) {
1219 SmallVector<Value, 4> newOperands(inputs.drop_back(2));
1220 newOperands.push_back(cst);
1221 replaceOpWithNewOpAndCopyNamehint<OrOp>(rewriter, op, op.getType(),
1229 for (
size_t i = 0; i < size - 1; ++i) {
1230 if (
auto concat = inputs[i].getDefiningOp<ConcatOp>())
1244 replaceOpWithNewOpAndCopyNamehint<ICmpOp>(rewriter, op, ICmpPredicate::ne,
1245 source, cmpAgainst);
1251 if (
auto firstMux = op.getOperand(0).getDefiningOp<
comb::MuxOp>()) {
1253 if (op.getTwoState() && firstMux.getTwoState() &&
1254 matchPattern(firstMux.getFalseValue(), m_ConstantInt(&value)) &&
1256 SmallVector<Value> conditions{firstMux.getCond()};
1257 auto check = [&](Value v) {
1261 conditions.push_back(mux.getCond());
1262 return mux.getTwoState() &&
1263 firstMux.getTrueValue() == mux.getTrueValue() &&
1264 firstMux.getFalseValue() == mux.getFalseValue();
1266 if (llvm::all_of(op.getOperands().drop_front(), check)) {
1267 auto cond = comb::OrOp::create(rewriter, op.getLoc(), conditions,
true);
1268 replaceOpWithNewOpAndCopyNamehint<comb::MuxOp>(
1269 rewriter, op, cond, firstMux.getTrueValue(),
1270 firstMux.getFalseValue(),
true);
1280OpFoldResult XorOp::fold(FoldAdaptor adaptor) {
1284 auto size = getInputs().size();
1285 auto inputs = adaptor.getInputs();
1289 return getInputs()[0];
1292 if (size == 2 && getInputs()[0] == getInputs()[1])
1293 return IntegerAttr::get(getType(), 0);
1296 if (inputs.size() == 2)
1297 if (
auto intAttr = dyn_cast_or_null<IntegerAttr>(inputs[1]))
1298 if (intAttr.getValue().isZero())
1299 return getInputs()[0];
1305 subExpr != getResult())
1314 PatternRewriter &rewriter) {
1315 auto icmp = op.getOperand(icmpOperand).getDefiningOp<ICmpOp>();
1316 auto negatedPred = ICmpOp::getNegatedPredicate(icmp.getPredicate());
1319 ICmpOp::create(rewriter, icmp.getLoc(), negatedPred, icmp.getOperand(0),
1320 icmp.getOperand(1), icmp.getTwoState());
1323 if (op.getNumOperands() > 2) {
1324 SmallVector<Value, 4> newOperands(op.getOperands());
1325 newOperands.pop_back();
1326 newOperands.erase(newOperands.begin() + icmpOperand);
1327 newOperands.push_back(result);
1329 XorOp::create(rewriter, op.getLoc(), newOperands, op.getTwoState());
1335LogicalResult XorOp::canonicalize(
XorOp op, PatternRewriter &rewriter) {
1339 auto inputs = op.getInputs();
1340 auto size = inputs.size();
1341 assert(size > 1 &&
"expected 2 or more operands");
1344 if (inputs[size - 1] == inputs[size - 2]) {
1346 "expected idempotent case for 2 elements handled already.");
1347 replaceOpWithNewOpAndCopyNamehint<XorOp>(rewriter, op, op.getType(),
1348 inputs.drop_back(2),
false);
1354 if (matchPattern(inputs.back(), m_ConstantInt(&value))) {
1356 if (value.isZero()) {
1357 replaceOpWithNewOpAndCopyNamehint<XorOp>(rewriter, op, op.getType(),
1358 inputs.drop_back(),
false);
1364 if (matchPattern(inputs[size - 2], m_ConstantInt(&value2))) {
1366 SmallVector<Value, 4> newOperands(inputs.drop_back(2));
1367 newOperands.push_back(cst);
1368 replaceOpWithNewOpAndCopyNamehint<XorOp>(rewriter, op, op.getType(),
1369 newOperands,
false);
1373 bool isSingleBit = value.getBitWidth() == 1;
1376 for (
size_t i = 0; i < size - 1; ++i) {
1377 Value operand = inputs[i];
1383 if (
auto concat = operand.getDefiningOp<
ConcatOp>())
1388 if (isSingleBit && operand.hasOneUse()) {
1389 assert(value == 1 &&
"single bit constant has to be one if not zero");
1390 if (
auto icmp = operand.getDefiningOp<ICmpOp>())
1400 if (matchPattern(op.getResult(),
m_Complement(m_Sext(m_Any(&base))))) {
1419 replaceOpWithNewOpAndCopyNamehint<ParityOp>(rewriter, op, source);
1426OpFoldResult SubOp::fold(FoldAdaptor adaptor) {
1431 if (getRhs() == getLhs())
1433 APInt::getZero(getLhs().getType().getIntOrFloatBitWidth()),
1436 if (adaptor.getRhs()) {
1438 if (adaptor.getLhs()) {
1441 APInt::getAllOnes(getLhs().getType().getIntOrFloatBitWidth()),
1443 auto rhsNeg = hw::ParamExprAttr::get(
1444 hw::PEO::Mul, cast<TypedAttr>(adaptor.getRhs()), negOne);
1445 return hw::ParamExprAttr::get(hw::PEO::Add,
1446 cast<TypedAttr>(adaptor.getLhs()), rhsNeg);
1450 if (
auto rhsC = dyn_cast<IntegerAttr>(adaptor.getRhs())) {
1451 if (rhsC.getValue().isZero())
1459LogicalResult SubOp::canonicalize(
SubOp op, PatternRewriter &rewriter) {
1465 if (matchPattern(op.getRhs(), m_ConstantInt(&value))) {
1467 replaceOpWithNewOpAndCopyNamehint<AddOp>(rewriter, op, op.getLhs(), negCst,
1479OpFoldResult AddOp::fold(FoldAdaptor adaptor) {
1483 auto size = getInputs().size();
1487 return getInputs()[0];
1493LogicalResult AddOp::canonicalize(
AddOp op, PatternRewriter &rewriter) {
1497 auto inputs = op.getInputs();
1498 auto size = inputs.size();
1499 assert(size > 1 &&
"expected 2 or more operands");
1501 APInt value, value2;
1504 if (matchPattern(inputs.back(), m_ConstantInt(&value)) && value.isZero()) {
1505 replaceOpWithNewOpAndCopyNamehint<AddOp>(rewriter, op, op.getType(),
1506 inputs.drop_back(),
false);
1511 if (matchPattern(inputs[size - 1], m_ConstantInt(&value)) &&
1512 matchPattern(inputs[size - 2], m_ConstantInt(&value2))) {
1514 SmallVector<Value, 4> newOperands(inputs.drop_back(2));
1515 newOperands.push_back(cst);
1516 replaceOpWithNewOpAndCopyNamehint<AddOp>(rewriter, op, op.getType(),
1517 newOperands,
false);
1522 if (inputs[size - 1] == inputs[size - 2]) {
1523 SmallVector<Value, 4> newOperands(inputs.drop_back(2));
1527 comb::ShlOp::create(rewriter, op.getLoc(), inputs.back(), one,
false);
1529 newOperands.push_back(shiftLeftOp);
1530 replaceOpWithNewOpAndCopyNamehint<AddOp>(rewriter, op, op.getType(),
1531 newOperands,
false);
1535 auto shlOp = inputs[size - 1].getDefiningOp<
comb::ShlOp>();
1537 if (shlOp && shlOp.getLhs() == inputs[size - 2] &&
1538 matchPattern(shlOp.getRhs(), m_ConstantInt(&value))) {
1540 APInt one(value.getBitWidth(), 1,
false);
1544 std::array<Value, 2> factors = {shlOp.getLhs(), rhs};
1545 auto mulOp = comb::MulOp::create(rewriter, op.getLoc(), factors,
false);
1547 SmallVector<Value, 4> newOperands(inputs.drop_back(2));
1548 newOperands.push_back(mulOp);
1549 replaceOpWithNewOpAndCopyNamehint<AddOp>(rewriter, op, op.getType(),
1550 newOperands,
false);
1554 auto mulOp = inputs[size - 1].getDefiningOp<
comb::MulOp>();
1556 if (mulOp && mulOp.getInputs().size() == 2 &&
1557 mulOp.getInputs()[0] == inputs[size - 2] &&
1558 matchPattern(mulOp.getInputs()[1], m_ConstantInt(&value))) {
1560 APInt one(value.getBitWidth(), 1,
false);
1562 std::array<Value, 2> factors = {mulOp.getInputs()[0], rhs};
1563 auto newMulOp = comb::MulOp::create(rewriter, op.getLoc(), factors,
false);
1565 SmallVector<Value, 4> newOperands(inputs.drop_back(2));
1566 newOperands.push_back(newMulOp);
1567 replaceOpWithNewOpAndCopyNamehint<AddOp>(rewriter, op, op.getType(),
1568 newOperands,
false);
1581 auto addOp = inputs[0].getDefiningOp<
comb::AddOp>();
1582 if (addOp && addOp.getInputs().size() == 2 &&
1583 matchPattern(addOp.getInputs()[1], m_ConstantInt(&value2)) &&
1584 inputs.size() == 2 && matchPattern(inputs[1], m_ConstantInt(&value))) {
1587 replaceOpWithNewOpAndCopyNamehint<AddOp>(
1588 rewriter, op, op.getType(), ArrayRef<Value>{addOp.getInputs()[0], rhs},
1589 op.getTwoState() && addOp.getTwoState());
1596OpFoldResult MulOp::fold(FoldAdaptor adaptor) {
1600 auto size = getInputs().size();
1601 auto inputs = adaptor.getInputs();
1605 return getInputs()[0];
1607 auto width = cast<IntegerType>(getType()).getWidth();
1609 return getIntAttr(APInt::getZero(0), getContext());
1611 APInt value(width, 1,
false);
1614 for (
auto operand : inputs) {
1615 auto attr = dyn_cast_or_null<IntegerAttr>(operand);
1618 value *= attr.getValue();
1627LogicalResult MulOp::canonicalize(
MulOp op, PatternRewriter &rewriter) {
1631 auto inputs = op.getInputs();
1632 auto size = inputs.size();
1633 assert(size > 1 &&
"expected 2 or more operands");
1635 APInt value, value2;
1638 if (size == 2 && matchPattern(inputs.back(), m_ConstantInt(&value)) &&
1639 value.isPowerOf2()) {
1641 value.exactLogBase2());
1643 comb::ShlOp::create(rewriter, op.getLoc(), inputs[0], shift,
false);
1645 replaceOpWithNewOpAndCopyNamehint<MulOp>(rewriter, op, op.getType(),
1646 ArrayRef<Value>(shlOp),
false);
1651 if (matchPattern(inputs.back(), m_ConstantInt(&value)) && value.isOne()) {
1652 replaceOpWithNewOpAndCopyNamehint<MulOp>(rewriter, op, op.getType(),
1653 inputs.drop_back());
1658 if (matchPattern(inputs[size - 1], m_ConstantInt(&value)) &&
1659 matchPattern(inputs[size - 2], m_ConstantInt(&value2))) {
1661 SmallVector<Value, 4> newOperands(inputs.drop_back(2));
1662 newOperands.push_back(cst);
1663 replaceOpWithNewOpAndCopyNamehint<MulOp>(rewriter, op, op.getType(),
1679template <
class Op,
bool isSigned>
1680static OpFoldResult
foldDiv(Op op, ArrayRef<Attribute> constants) {
1681 if (
auto rhsValue = dyn_cast_or_null<IntegerAttr>(constants[1])) {
1683 if (rhsValue.getValue() == 1)
1687 if (rhsValue.getValue().isZero())
1694OpFoldResult DivUOp::fold(FoldAdaptor adaptor) {
1697 return foldDiv<
DivUOp,
false>(*
this, adaptor.getOperands());
1700OpFoldResult DivSOp::fold(FoldAdaptor adaptor) {
1706template <
class Op,
bool isSigned>
1707static OpFoldResult
foldMod(Op op, ArrayRef<Attribute> constants) {
1708 if (
auto rhsValue = dyn_cast_or_null<IntegerAttr>(constants[1])) {
1710 if (rhsValue.getValue() == 1)
1711 return getIntAttr(APInt::getZero(op.getType().getIntOrFloatBitWidth()),
1715 if (rhsValue.getValue().isZero())
1719 if (
auto lhsValue = dyn_cast_or_null<IntegerAttr>(constants[0])) {
1721 if (lhsValue.getValue().isZero())
1722 return getIntAttr(APInt::getZero(op.getType().getIntOrFloatBitWidth()),
1729OpFoldResult ModUOp::fold(FoldAdaptor adaptor) {
1732 return foldMod<
ModUOp,
false>(*
this, adaptor.getOperands());
1735OpFoldResult ModSOp::fold(FoldAdaptor adaptor) {
1741LogicalResult DivUOp::canonicalize(
DivUOp op, PatternRewriter &rewriter) {
1747LogicalResult ModUOp::canonicalize(
ModUOp op, PatternRewriter &rewriter) {
// Folder for ConcatOp.
//  * A concat with a single operand folds to that operand.
//  * When every input attribute is an IntegerAttr, the result is assembled as
//    one APInt of the result width.
// NOTE(review): this listing omits original lines; the statement executed when
// a non-constant input is found and the final return are not visible here.
1759OpFoldResult ConcatOp::fold(FoldAdaptor adaptor) {
// Single-operand concat is the operand itself.
1763 if (getNumOperands() == 1)
1764 return getOperand(0);
// Check that each input has a non-null IntegerAttr.
1767 for (
auto attr : adaptor.getInputs())
1768 if (!attr || !isa<IntegerAttr>(attr))
// Accumulator for the folded value, initially all zero bits.
1772 unsigned resultWidth = getType().getIntOrFloatBitWidth();
1773 APInt result(resultWidth, 0);
// Insertion position starts at the top of the result and moves downward, so
// the first input lands in the most significant bits.
1775 unsigned nextInsertion = resultWidth;
1777 for (
auto attr : adaptor.getInputs()) {
1778 auto chunk = cast<IntegerAttr>(attr).getValue();
1779 nextInsertion -= chunk.getBitWidth();
1780 result.insertBits(chunk, nextInsertion);
// Canonicalization for ConcatOp. Walks the operands once; the first pattern
// that matches rewrites the op by splicing replacement value(s) into the
// operand list through the local `flattenConcat` helper and returns.
// NOTE(review): this listing omits many original lines (comments, closing
// braces, some declarations and statements); see the gaps in the embedded
// numbering. Comments below describe only what the visible lines show.
1786LogicalResult ConcatOp::canonicalize(
ConcatOp op, PatternRewriter &rewriter) {
1790 auto inputs = op.getInputs();
1791 auto size = inputs.size();
1792 assert(size > 1 &&
"expected 2 or more operands");
// flattenConcat(first, last, replacements): builds a new operand list made of
// inputs[0, first) + replacements + inputs(last, end) and replaces `op` with a
// ConcatOp over it. The single-operand result is special-cased on a line that
// is not shown here.
1797 auto flattenConcat = [&](
size_t firstOpIndex,
size_t lastOpIndex,
1798 ValueRange replacements) -> LogicalResult {
1799 SmallVector<Value, 4> newOperands;
1800 newOperands.append(inputs.begin(), inputs.begin() + firstOpIndex);
1801 newOperands.append(replacements.begin(), replacements.end());
1802 newOperands.append(inputs.begin() + lastOpIndex + 1, inputs.end());
1803 if (newOperands.size() == 1)
1806 replaceOpWithNewOpAndCopyNamehint<ConcatOp>(rewriter, op, op.getType(),
// Tracks whether all operands are the same value; reset to a null Value on the
// first operand that differs from inputs[0].
1811 Value commonOperand = inputs[0];
1812 for (
size_t i = 0; i != size; ++i) {
1814 if (inputs[i] != commonOperand)
1815 commonOperand = Value();
// Operand is itself a concat: inline its operands in place.
1819 if (
auto subConcat = inputs[i].getDefiningOp<ConcatOp>())
1820 return flattenConcat(i, i, subConcat->getOperands());
// The patterns below read inputs[i - 1].
// NOTE(review): presumably an i == 0 guard sits on an omitted line — confirm.
// Two adjacent constants: combine them into one constant whose width is the
// sum of both widths (the expression on line 1830 continues on an omitted
// line, and the declaration of `replacement` is not shown).
1825 if (
auto cst = inputs[i].getDefiningOp<hw::ConstantOp>()) {
1826 if (
auto prevCst = inputs[i - 1].getDefiningOp<hw::ConstantOp>()) {
1827 unsigned prevWidth = prevCst.getValue().getBitWidth();
1828 unsigned thisWidth = cst.getValue().getBitWidth();
1829 auto resultCst = cst.getValue().zext(prevWidth + thisWidth);
1830 resultCst |= prevCst.getValue().zext(prevWidth + thisWidth)
1834 return flattenConcat(i - 1, i, replacement);
// Two identical adjacent operands: replace the pair with a 2x replicate.
1839 if (inputs[i] == inputs[i - 1]) {
1841 rewriter.createOrFold<ReplicateOp>(op.getLoc(), inputs[i], 2);
1842 return flattenConcat(i - 1, i, replacement);
// Current operand is a replicate.
1847 if (
auto repl = inputs[i].getDefiningOp<ReplicateOp>()) {
// Previous operand is the replicated value: fold it in (multiple + 1).
1849 if (repl.getOperand() == inputs[i - 1]) {
1850 Value replacement = rewriter.createOrFold<ReplicateOp>(
1851 op.getLoc(), repl.getOperand(), repl.getMultiple() + 1);
1852 return flattenConcat(i - 1, i, replacement);
// Previous operand is a replicate of the same value: add the multiples.
1855 if (
auto prevRepl = inputs[i - 1].getDefiningOp<ReplicateOp>()) {
1856 if (prevRepl.getOperand() == repl.getOperand()) {
1857 Value replacement = rewriter.createOrFold<ReplicateOp>(
1858 op.getLoc(), repl.getOperand(),
1859 repl.getMultiple() + prevRepl.getMultiple());
1860 return flattenConcat(i - 1, i, replacement);
// Previous operand is a replicate of the current operand: multiple + 1.
1866 if (
auto repl = inputs[i - 1].getDefiningOp<ReplicateOp>()) {
1867 if (repl.getOperand() == inputs[i]) {
1868 Value replacement = rewriter.createOrFold<ReplicateOp>(
1869 op.getLoc(), inputs[i], repl.getMultiple() + 1);
1870 return flattenConcat(i - 1, i, replacement);
// Two adjacent extracts of the same input whose bit ranges are contiguous
// (previous low bit == this low bit + this width): merge into one extract of
// the combined width starting at this operand's low bit.
1876 if (
auto extract = inputs[i].getDefiningOp<ExtractOp>()) {
1877 if (
auto prevExtract = inputs[i - 1].getDefiningOp<ExtractOp>()) {
1878 if (extract.getInput() == prevExtract.getInput()) {
1879 auto thisWidth = cast<IntegerType>(extract.getType()).getWidth();
1880 if (prevExtract.getLowBit() == extract.getLowBit() + thisWidth) {
1881 auto prevWidth = prevExtract.getType().getIntOrFloatBitWidth();
1882 auto resType = rewriter.getIntegerType(thisWidth + prevWidth);
1885 extract.getInput(), extract.getLowBit());
1886 return flattenConcat(i - 1, i, replacement);
// ArraySlice::get (the enclosing struct declaration is on omitted lines):
// describes an integer-typed value as a slice {input, index, width} of an
// array. One visible path returns width 1 from `arrayGet`; the fallback is
// std::nullopt.
1899 static std::optional<ArraySlice>
get(Value value) {
1900 assert(isa<IntegerType>(value.getType()) &&
"expected integer type");
1902 return ArraySlice{arrayGet.getInput(), arrayGet.getIndex(), 1};
1905 if (
auto arraySlice =
1908 arraySlice.getInput(), arraySlice.getLowIndex(),
1909 hw::type_cast<hw::ArrayType>(arraySlice.getType())
1911 return std::nullopt;
// Two adjacent array slices of the same input, with matching index types and
// indices related by hw::isOffset(..., extractOpt->width): replace the pair
// with a single value of the combined width.
1914 if (
auto extractOpt = ArraySlice::get(inputs[i])) {
1915 if (
auto prevExtractOpt = ArraySlice::get(inputs[i - 1])) {
1917 if (prevExtractOpt->index.getType() == extractOpt->index.getType() &&
1918 prevExtractOpt->input == extractOpt->input &&
1919 hw::isOffset(extractOpt->index, prevExtractOpt->index,
1920 extractOpt->width)) {
1921 auto resType = hw::ArrayType::get(
1922 hw::type_cast<hw::ArrayType>(prevExtractOpt->input.getType())
1924 extractOpt->width + prevExtractOpt->width);
1925 auto resIntType = rewriter.getIntegerType(hw::getBitWidth(resType));
1927 rewriter, op.getLoc(), resIntType,
1929 prevExtractOpt->input,
1930 extractOpt->index));
1931 return flattenConcat(i - 1, i, replacement);
// Every operand was the same value: the whole concat becomes a ReplicateOp.
1939 if (commonOperand) {
1940 replaceOpWithNewOpAndCopyNamehint<ReplicateOp>(rewriter, op, op.getType(),
// Folder for MuxOp. Each visible fold that returns an operand first checks
// that the operand is not this mux's own result.
// NOTE(review): this listing omits original lines; the values returned by the
// "equal constant arms" and "mux(c, 1, 0)" cases are not visible here.
1952OpFoldResult MuxOp::fold(FoldAdaptor adaptor) {
// mux(c, x, x) -> x.
1957 if (getTrueValue() == getFalseValue() && getTrueValue() != getResult())
1958 return getTrueValue();
// Both arms have the same (non-null) constant attribute.
1959 if (
auto tv = adaptor.getTrueValue())
1960 if (tv == adaptor.getFalseValue())
// Constant condition: zero selects the false arm, one selects the true arm.
1965 if (
auto pred = dyn_cast_or_null<IntegerAttr>(adaptor.getCond())) {
1966 if (pred.getValue().isZero() && getFalseValue() != getResult())
1967 return getFalseValue();
1968 if (pred.getValue().isOne() && getTrueValue() != getResult())
1969 return getTrueValue();
// Condition and data have the same type, the arms are the constants 1 and 0,
// and the result is one bit wide.
1973 if (getCond().getType() == getTrueValue().getType())
1974 if (
auto tv = dyn_cast_or_null<IntegerAttr>(adaptor.getTrueValue()))
1975 if (
auto fv = dyn_cast_or_null<IntegerAttr>(adaptor.getFalseValue()))
1976 if (tv.getValue().isOne() && fv.getValue().isZero() &&
1977 hw::getBitWidth(getType()) == 1 && getCond() != getResult())
1993 if (
auto cmp = cond.getDefiningOp<ICmpOp>()) {
1995 auto requiredPredicate =
1996 (isInverted ? ICmpPredicate::eq : ICmpPredicate::ne);
1997 if (cmp.getLhs() == indexValue && cmp.getPredicate() == requiredPredicate) {
2007 if (
auto orOp = cond.getDefiningOp<
OrOp>()) {
2010 for (
auto operand : orOp.getOperands())
2017 if (
auto andOp = cond.getDefiningOp<
AndOp>()) {
2020 for (
auto operand : andOp.getOperands())
2039 PatternRewriter &rewriter,
MuxOp rootMux,
bool isFalseSide,
2045 auto rootCmp = rootMux.getCond().getDefiningOp<ICmpOp>();
2048 Value indexValue = rootCmp.getLhs();
2051 auto getCaseValue = [&](
MuxOp mux) -> Value {
2052 return mux.getOperand(1 +
unsigned(!isFalseSide));
2057 auto getTreeValue = [&](
MuxOp mux) -> Value {
2058 return mux.getOperand(1 +
unsigned(isFalseSide));
2063 SmallVector<Location> locationsFound;
2064 SmallVector<std::pair<hw::ConstantOp, Value>, 4> valuesFound;
2068 auto collectConstantValues = [&](
MuxOp mux) ->
bool {
2070 mux.getCond(), indexValue, isFalseSide, [&](
hw::ConstantOp cst) {
2071 valuesFound.push_back({cst, getCaseValue(mux)});
2072 locationsFound.push_back(mux.getCond().getLoc());
2073 locationsFound.push_back(mux->getLoc());
2078 if (!collectConstantValues(rootMux))
2082 if (rootMux->hasOneUse()) {
2083 if (
auto userMux = dyn_cast<MuxOp>(*rootMux->user_begin())) {
2084 if (getTreeValue(userMux) == rootMux.getResult() &&
2092 auto nextTreeValue = getTreeValue(rootMux);
2094 auto nextMux = nextTreeValue.getDefiningOp<
MuxOp>();
2095 if (!nextMux || !nextMux->hasOneUse())
2097 if (!collectConstantValues(nextMux))
2099 nextTreeValue = getTreeValue(nextMux);
2102 auto indexWidth = cast<IntegerType>(indexValue.getType()).getWidth();
2104 if (indexWidth > 20)
2107 auto foldingStyle = styleFn(indexWidth, valuesFound.size());
2111 uint64_t tableSize = 1ULL << indexWidth;
2115 SmallVector<Value, 8> table(tableSize, nextTreeValue);
2120 for (
auto &elt :
llvm::reverse(valuesFound)) {
2121 uint64_t idx = elt.first.getValue().getZExtValue();
2122 assert(idx < table.size() &&
"constant should be same bitwidth as index");
2123 table[idx] = elt.second;
2127 SmallVector<Value> bits;
2136 "unknown folding style");
2140 std::reverse(table.begin(), table.end());
2143 auto fusedLoc = rewriter.getFusedLoc(locationsFound);
2145 replaceOpWithNewOpAndCopyNamehint<hw::ArrayGetOp>(rewriter, rootMux, array,
2160 PatternRewriter &rewriter) {
2161 assert(fullyAssoc->getNumOperands() >= 2 &&
"cannot split up unary ops");
2162 assert(operandNo < fullyAssoc->getNumOperands() &&
"Invalid operand #");
2166 if (fullyAssoc->getNumOperands() == 2)
2167 return fullyAssoc->getOperand(operandNo ^ 1);
2170 if (fullyAssoc->hasOneUse()) {
2171 rewriter.modifyOpInPlace(fullyAssoc,
2172 [&]() { fullyAssoc->eraseOperand(operandNo); });
2173 return fullyAssoc->getResult(0);
2177 SmallVector<Value> operands;
2178 operands.append(fullyAssoc->getOperands().begin(),
2179 fullyAssoc->getOperands().begin() + operandNo);
2180 operands.append(fullyAssoc->getOperands().begin() + operandNo + 1,
2181 fullyAssoc->getOperands().end());
2183 fullyAssoc->getLoc(), fullyAssoc->getName(), operands, rewriter);
2184 Value excluded = fullyAssoc->getOperand(operandNo);
2188 ArrayRef<Value>{opWithoutExcluded, excluded}, rewriter);
2190 return opWithoutExcluded;
2200 PatternRewriter &rewriter) {
2203 Operation *subExpr =
2204 (isTrueOperand ? op.getFalseValue() : op.getTrueValue()).getDefiningOp();
2205 if (!subExpr || subExpr->getNumOperands() < 2)
2209 if (!isa<AndOp, XorOp, OrOp, MuxOp>(subExpr))
2214 Value commonValue = isTrueOperand ? op.getTrueValue() : op.getFalseValue();
2215 size_t opNo = 0, e = subExpr->getNumOperands();
2216 while (opNo != e && subExpr->getOperand(opNo) != commonValue)
2222 Value cond = op.getCond();
2228 if (
auto subMux = dyn_cast<MuxOp>(subExpr)) {
2233 Value subCond = subMux.getCond();
2236 if (subMux.getTrueValue() == commonValue)
2237 otherValue = subMux.getFalseValue();
2238 else if (subMux.getFalseValue() == commonValue) {
2239 otherValue = subMux.getTrueValue();
2249 cond = rewriter.createOrFold<
OrOp>(op.getLoc(), cond, subCond,
false);
2250 replaceOpWithNewOpAndCopyNamehint<MuxOp>(rewriter, op, cond, commonValue,
2251 otherValue, op.getTwoState());
2257 bool isaAndOp = isa<AndOp>(subExpr);
2258 if (isTrueOperand ^ isaAndOp)
2262 rewriter.createOrFold<ReplicateOp>(op.getLoc(), op.getType(), cond);
2265 bool isaXorOp = isa<XorOp>(subExpr);
2266 bool isaOrOp = isa<OrOp>(subExpr);
2275 if (isaOrOp || isaXorOp) {
2276 auto masked = rewriter.createOrFold<
AndOp>(op.getLoc(), extendedCond,
2277 restOfAssoc,
false);
2279 replaceOpWithNewOpAndCopyNamehint<XorOp>(rewriter, op, masked,
2280 commonValue,
false);
2282 replaceOpWithNewOpAndCopyNamehint<OrOp>(rewriter, op, masked, commonValue,
2288 assert(isaAndOp &&
"unexpected operation here");
2289 auto masked = rewriter.createOrFold<
OrOp>(op.getLoc(), extendedCond,
2290 restOfAssoc,
false);
2291 replaceOpWithNewOpAndCopyNamehint<AndOp>(rewriter, op, masked, commonValue,
2302 PatternRewriter &rewriter) {
2305 if (!isa<ConcatOp>(trueOp))
2309 SmallVector<Value> trueOperands, falseOperands;
2313 size_t numTrueOperands = trueOperands.size();
2314 size_t numFalseOperands = falseOperands.size();
2316 if (!numTrueOperands || !numFalseOperands ||
2317 (trueOperands.front() != falseOperands.front() &&
2318 trueOperands.back() != falseOperands.back()))
2322 if (trueOperands.front() == falseOperands.front()) {
2323 SmallVector<Value> operands;
2325 for (i = 0; i < numTrueOperands; ++i) {
2326 Value trueOperand = trueOperands[i];
2327 if (trueOperand == falseOperands[i])
2328 operands.push_back(trueOperand);
2332 if (i == numTrueOperands) {
2339 if (llvm::all_of(operands, [&](Value v) {
return v == operands.front(); }))
2340 sharedMSB = rewriter.createOrFold<ReplicateOp>(
2341 mux->getLoc(), operands.front(), operands.size());
2343 sharedMSB = rewriter.createOrFold<
ConcatOp>(mux->getLoc(), operands);
2347 operands.append(trueOperands.begin() + i, trueOperands.end());
2348 Value trueLSB = rewriter.createOrFold<
ConcatOp>(trueOp->getLoc(), operands);
2350 operands.append(falseOperands.begin() + i, falseOperands.end());
2352 rewriter.createOrFold<
ConcatOp>(falseOp->getLoc(), operands);
2355 Value lsb = rewriter.createOrFold<
MuxOp>(
2356 mux->getLoc(), mux.getCond(), trueLSB, falseLSB, mux.getTwoState());
2357 replaceOpWithNewOpAndCopyNamehint<ConcatOp>(rewriter, mux, sharedMSB, lsb);
2362 if (trueOperands.back() == falseOperands.back()) {
2363 SmallVector<Value> operands;
2366 Value trueOperand = trueOperands[numTrueOperands - i - 1];
2367 if (trueOperand == falseOperands[numFalseOperands - i - 1])
2368 operands.push_back(trueOperand);
2372 std::reverse(operands.begin(), operands.end());
2373 Value sharedLSB = rewriter.createOrFold<
ConcatOp>(mux->getLoc(), operands);
2377 operands.append(trueOperands.begin(), trueOperands.end() - i);
2378 Value trueMSB = rewriter.createOrFold<
ConcatOp>(trueOp->getLoc(), operands);
2380 operands.append(falseOperands.begin(), falseOperands.end() - i);
2382 rewriter.createOrFold<
ConcatOp>(falseOp->getLoc(), operands);
2384 Value msb = rewriter.createOrFold<
MuxOp>(
2385 mux->getLoc(), mux.getCond(), trueMSB, falseMSB, mux.getTwoState());
2386 replaceOpWithNewOpAndCopyNamehint<ConcatOp>(rewriter, mux, msb, sharedLSB);
2398 if (!trueVec || !falseVec)
2400 if (!trueVec.isUniform() || !falseVec.isUniform())
2403 auto mux = MuxOp::create(rewriter, op.getLoc(), op.getCond(),
2404 trueVec.getUniformElement(),
2405 falseVec.getUniformElement(), op.getTwoState());
2407 SmallVector<Value> values(trueVec.getInputs().size(), mux);
2415 bool constCond, PatternRewriter &rewriter) {
2416 if (!muxValue.hasOneUse())
2418 auto *op = muxValue.getDefiningOp();
2419 if (!op || !isa_and_nonnull<CombDialect>(op->getDialect()))
2421 if (!llvm::is_contained(op->getOperands(), muxCond))
2423 OpBuilder::InsertionGuard guard(rewriter);
2424 rewriter.setInsertionPoint(op);
2427 rewriter.modifyOpInPlace(op, [&] {
2428 for (
auto &use : op->getOpOperands())
2429 if (use.get() == muxCond)
2437 using OpRewritePattern::OpRewritePattern;
2439 LogicalResult matchAndRewrite(
MuxOp op,
2440 PatternRewriter &rewriter)
const override;
// Style-selection heuristic taking the mux index width and the number of
// entries found.
// NOTE(review): the return type and every return statement are on lines
// omitted from this listing, so only the thresholds are documented.
2444foldToArrayCreateOnlyWhenDense(
size_t indexWidth,
size_t numEntries) {
// Threshold: index of 9 or more bits, or fewer than three entries.
2447 if (indexWidth >= 9 || numEntries < 3)
// Density threshold: entries cover at least 5/8 of the 2^indexWidth table.
2453 uint64_t tableSize = 1ULL << indexWidth;
2454 if (numEntries >= tableSize * 5 / 8)
// Rewrite pattern body for MuxOp. Tries a sequence of local rewrites; each
// visible rewrite replaces `op` through replaceOpWithNewOpAndCopyNamehint.
// In the comments below, mux(c, t, f) means cond c, true value t, false
// value f.
// NOTE(review): this listing omits many original lines (comments, closing
// braces, declarations such as `value`, `value2`, `subExpr`, and the return
// statements); see the gaps in the embedded numbering. Comments describe only
// what the visible lines show.
2459LogicalResult MuxRewriter::matchAndRewrite(
MuxOp op,
2460 PatternRewriter &rewriter)
const {
// Record whether the result type is a signless integer.
2464 bool isSignlessInt =
false;
2465 if (
auto intType = dyn_cast<IntegerType>(op.getType()))
2466 isSignlessInt = intType.isSignless();
// True value is an integer constant (`value`) and the type is signless.
2473 if (matchPattern(op.getTrueValue(), m_ConstantInt(&value)) && isSignlessInt) {
// One-bit constant true value.
2474 if (value.getBitWidth() == 1) {
// mux(c, 0, f) -> and(notCond, f); `notCond` is defined on an omitted line.
2476 if (value.isZero()) {
2478 replaceOpWithNewOpAndCopyNamehint<AndOp>(rewriter, op, notCond,
2479 op.getFalseValue(),
false);
// Remaining one-bit case: -> or(c, f).
2484 replaceOpWithNewOpAndCopyNamehint<OrOp>(rewriter, op, op.getCond(),
2485 op.getFalseValue(),
false);
// Both arms are integer constants (`value`, `value2`).
2491 if (matchPattern(op.getFalseValue(), m_ConstantInt(&value2))) {
// The constants differ in exactly one bit (their xor is a power of two):
// rebuild the result as a concat of the shared high bits, a one-bit mux on the
// differing bit, and the shared low bits. `trailingZeros` is the index of the
// differing bit; `leadingZeros` is the number of bits above it.
2496 APInt xorValue = value ^ value2;
2497 if (xorValue.isPowerOf2()) {
2498 unsigned leadingZeros = xorValue.countLeadingZeros();
2499 unsigned trailingZeros = value.getBitWidth() - leadingZeros - 1;
2500 SmallVector<Value, 3> operands;
// Shared bits above the differing bit, taken from the true value.
2508 if (leadingZeros > 0)
2509 operands.push_back(rewriter.createOrFold<
ExtractOp>(
2510 op.getLoc(), op.getTrueValue(), trailingZeros + 1, leadingZeros));
// The single differing bit, selected by the original condition.
2514 auto v1 = rewriter.createOrFold<
ExtractOp>(
2515 op.getLoc(), op.getTrueValue(), trailingZeros, 1);
2516 auto v2 = rewriter.createOrFold<
ExtractOp>(
2517 op.getLoc(), op.getFalseValue(), trailingZeros, 1);
2518 operands.push_back(rewriter.createOrFold<
MuxOp>(
2519 op.getLoc(), op.getCond(), v1, v2,
false));
// Shared bits below the differing bit, taken from the true value.
2521 if (trailingZeros > 0)
2522 operands.push_back(rewriter.createOrFold<
ExtractOp>(
2523 op.getLoc(), op.getTrueValue(), 0, trailingZeros));
2525 replaceOpWithNewOpAndCopyNamehint<ConcatOp>(rewriter, op, op.getType(),
// mux(c, all-ones, 0) -> replicate(c) to the result type.
2532 if (value.isAllOnes() && value2.isZero()) {
2533 replaceOpWithNewOpAndCopyNamehint<ReplicateOp>(
2534 rewriter, op, op.getType(), op.getCond());
// False value is a one-bit integer constant and the type is signless.
2540 if (matchPattern(op.getFalseValue(), m_ConstantInt(&value)) &&
2541 isSignlessInt && value.getBitWidth() == 1) {
// mux(c, t, 0) -> and(c, t).
2543 if (value.isZero()) {
2544 replaceOpWithNewOpAndCopyNamehint<AndOp>(rewriter, op, op.getCond(),
2545 op.getTrueValue(),
false);
// Remaining one-bit case: notCond = xor(c, f), then -> or(notCond, t).
2552 auto notCond = rewriter.createOrFold<
XorOp>(op.getLoc(), op.getCond(),
2553 op.getFalseValue(),
false);
2554 replaceOpWithNewOpAndCopyNamehint<OrOp>(rewriter, op, notCond,
2555 op.getTrueValue(),
false);
// Condition is a complement of `subExpr`: swap the arms and select on
// `subExpr` directly. The last clause of this condition is on an omitted line.
2561 Operation *condOp = op.getCond().getDefiningOp();
2562 if (condOp && matchPattern(condOp,
m_Complement(m_Any(&subExpr))) &&
2564 replaceOpWithNewOpAndCopyNamehint<MuxOp>(rewriter, op, op.getType(),
2565 subExpr, op.getFalseValue(),
2566 op.getTrueValue(),
true);
// Condition op has a single use: look for and/or whose operands are
// complements, collecting the un-complemented operands.
2573 if (condOp && condOp->hasOneUse()) {
2574 SmallVector<Value> invertedOperands;
2578 auto getInvertedOperands = [&]() ->
bool {
2579 for (Value operand : condOp->getOperands()) {
2580 if (matchPattern(operand,
m_Complement(m_Any(&subExpr))))
2581 invertedOperands.push_back(subExpr);
// and-condition: select on an or of the collected operands, arms swapped.
2588 if (isa<AndOp>(condOp) && getInvertedOperands()) {
2590 rewriter.createOrFold<
OrOp>(op.getLoc(), invertedOperands,
false);
2591 replaceOpWithNewOpAndCopyNamehint<MuxOp>(
2592 rewriter, op, newOr, op.getFalseValue(), op.getTrueValue(),
// or-condition: select on an and of the collected operands, arms swapped.
2596 if (isa<OrOp>(condOp) && getInvertedOperands()) {
2598 rewriter.createOrFold<
AndOp>(op.getLoc(), invertedOperands,
false);
2599 replaceOpWithNewOpAndCopyNamehint<MuxOp>(
2600 rewriter, op, newAnd, op.getFalseValue(), op.getTrueValue(),
// False arm is another mux (not this op).
2606 if (
auto falseMux = op.getFalseValue().getDefiningOp<
MuxOp>();
2607 falseMux && falseMux != op) {
// Same condition on both: mux(c, t, mux(c, _, f2)) -> mux(c, t, f2), provided
// the inner false value is not the inner mux itself.
2609 if (op.getCond() == falseMux.getCond() &&
2610 falseMux.getFalseValue() != falseMux) {
2611 replaceOpWithNewOpAndCopyNamehint<MuxOp>(
2612 rewriter, op, op.getCond(), op.getTrueValue(),
2613 falseMux.getFalseValue(), op.getTwoStateAttr());
// Tail of a call (its start is on an omitted line) that receives
// foldToArrayCreateOnlyWhenDense as its last argument.
2619 foldToArrayCreateOnlyWhenDense))
// True arm is another mux (not this op).
2623 if (
auto trueMux = op.getTrueValue().getDefiningOp<
MuxOp>();
2624 trueMux && trueMux != op) {
// Same condition on both: mux(c, mux(c, t2, _), f) -> mux(c, t2, f).
2626 if (op.getCond() == trueMux.getCond()) {
2627 replaceOpWithNewOpAndCopyNamehint<MuxOp>(
2628 rewriter, op, op.getCond(), trueMux.getTrueValue(),
2629 op.getFalseValue(), op.getTwoStateAttr());
// Tail of a call (its start is on an omitted line) that receives
// foldToArrayCreateOnlyWhenDense as its last argument.
2635 foldToArrayCreateOnlyWhenDense))
// Both arms are muxes sharing a condition and a true value:
// mux(a, mux(b, x, y), mux(b, x, z)) -> mux(b, x, mux(a, y, z)).
// The last clause of this condition is on an omitted line.
2640 if (
auto trueMux = dyn_cast_or_null<MuxOp>(op.getTrueValue().getDefiningOp()),
2641 falseMux = dyn_cast_or_null<MuxOp>(op.getFalseValue().getDefiningOp());
2642 trueMux && falseMux && trueMux.getCond() == falseMux.getCond() &&
2643 trueMux.getTrueValue() == falseMux.getTrueValue() && trueMux != op &&
2645 auto subMux = MuxOp::create(
2646 rewriter, rewriter.getFusedLoc({trueMux.getLoc(), falseMux.getLoc()}),
2647 op.getCond(), trueMux.getFalseValue(), falseMux.getFalseValue());
2648 replaceOpWithNewOpAndCopyNamehint<MuxOp>(rewriter, op, trueMux.getCond(),
2649 trueMux.getTrueValue(), subMux,
2650 op.getTwoStateAttr());
// Both arms are muxes sharing a condition and a false value:
// mux(a, mux(b, y, x), mux(b, z, x)) -> mux(b, mux(a, y, z), x).
// The last clause of this condition is on an omitted line.
2655 if (
auto trueMux = dyn_cast_or_null<MuxOp>(op.getTrueValue().getDefiningOp()),
2656 falseMux = dyn_cast_or_null<MuxOp>(op.getFalseValue().getDefiningOp());
2657 trueMux && falseMux && trueMux.getCond() == falseMux.getCond() &&
2658 trueMux.getFalseValue() == falseMux.getFalseValue() && trueMux != op &&
2660 auto subMux = MuxOp::create(
2661 rewriter, rewriter.getFusedLoc({trueMux.getLoc(), falseMux.getLoc()}),
2662 op.getCond(), trueMux.getTrueValue(), falseMux.getTrueValue());
2663 replaceOpWithNewOpAndCopyNamehint<MuxOp>(rewriter, op, trueMux.getCond(),
2664 subMux, trueMux.getFalseValue(),
2665 op.getTwoStateAttr());
// Both arms are muxes with identical true and false values:
// mux(a, mux(c1, x, y), mux(c2, x, y)) -> mux(mux(a, c1, c2), x, y).
// The last clause of this condition and the declaration of `subMux` are on
// omitted lines.
2670 if (
auto trueMux = dyn_cast_or_null<MuxOp>(op.getTrueValue().getDefiningOp()),
2671 falseMux = dyn_cast_or_null<MuxOp>(op.getFalseValue().getDefiningOp());
2672 trueMux && falseMux &&
2673 trueMux.getTrueValue() == falseMux.getTrueValue() &&
2674 trueMux.getFalseValue() == falseMux.getFalseValue() && trueMux != op &&
2677 MuxOp::create(rewriter,
2678 rewriter.getFusedLoc(
2679 {op.getLoc(), trueMux.getLoc(), falseMux.getLoc()}),
2680 op.getCond(), trueMux.getCond(), falseMux.getCond());
2681 replaceOpWithNewOpAndCopyNamehint<MuxOp>(
2682 rewriter, op, subMux, trueMux.getTrueValue(), trueMux.getFalseValue(),
2683 op.getTwoStateAttr());
// Both arms are defined by operations of the same kind; the rewrite attempted
// here is on omitted lines.
2695 if (Operation *trueOp = op.getTrueValue().getDefiningOp())
2696 if (Operation *falseOp = op.getFalseValue().getDefiningOp())
2697 if (trueOp->getName() == falseOp->getName())
// True arm has a defining op other than this mux; action is on omitted lines.
2710 if (op.getTrueValue().getDefiningOp() &&
2711 op.getTrueValue().getDefiningOp() != op)
// False arm has a defining op other than this mux; action is on omitted lines.
2714 if (op.getFalseValue().getDefiningOp() &&
2715 op.getFalseValue().getDefiningOp() != op)
2726 if (op.getInputs().empty() || op.isUniform())
2728 auto inputs = op.getInputs();
2729 if (inputs.size() <= 1)
2734 auto first = inputs[0].getDefiningOp<
comb::MuxOp>();
2739 for (
size_t i = 1, n = inputs.size(); i < n; ++i) {
2740 auto input = inputs[i].getDefiningOp<
comb::MuxOp>();
2741 if (!input || first.getCond() != input.getCond())
2746 SmallVector<Value> trues{first.getTrueValue()};
2747 SmallVector<Value> falses{first.getFalseValue()};
2748 SmallVector<Location> locs{first->getLoc()};
2749 bool isTwoState =
true;
2750 for (
size_t i = 1, n = inputs.size(); i < n; ++i) {
2751 auto input = inputs[i].getDefiningOp<
comb::MuxOp>();
2752 trues.push_back(input.getTrueValue());
2753 falses.push_back(input.getFalseValue());
2754 locs.push_back(input->getLoc());
2755 if (!input.getTwoState())
2760 auto loc = FusedLoc::get(op.getContext(), locs);
2764 auto arrayTy = op.getType();
2767 rewriter.replaceOpWithNewOp<
comb::MuxOp>(op, arrayTy, first.getCond(),
2768 trueValues, falseValues, isTwoState);
2773 using OpRewritePattern::OpRewritePattern;
2776 PatternRewriter &rewriter)
const override {
2777 if (foldArrayOfMuxes(op, rewriter))
// Registers the canonicalization patterns for MuxOp: inserts MuxRewriter and
// ArrayRewriter into `results`, constructed with `context`.
// NOTE(review): the second parameter line (original line 2786) is missing
// from this listing.
2785void MuxOp::getCanonicalizationPatterns(RewritePatternSet &results,
2787 results.insert<MuxRewriter, ArrayRewriter>(
context);
2798 switch (predicate) {
2799 case ICmpPredicate::eq:
2801 case ICmpPredicate::ne:
2803 case ICmpPredicate::slt:
2804 return lhs.slt(rhs);
2805 case ICmpPredicate::sle:
2806 return lhs.sle(rhs);
2807 case ICmpPredicate::sgt:
2808 return lhs.sgt(rhs);
2809 case ICmpPredicate::sge:
2810 return lhs.sge(rhs);
2811 case ICmpPredicate::ult:
2812 return lhs.ult(rhs);
2813 case ICmpPredicate::ule:
2814 return lhs.ule(rhs);
2815 case ICmpPredicate::ugt:
2816 return lhs.ugt(rhs);
2817 case ICmpPredicate::uge:
2818 return lhs.uge(rhs);
2819 case ICmpPredicate::ceq:
2821 case ICmpPredicate::cne:
2823 case ICmpPredicate::weq:
2825 case ICmpPredicate::wne:
2828 llvm_unreachable(
"unknown comparison predicate");
2834 switch (predicate) {
2835 case ICmpPredicate::eq:
2836 case ICmpPredicate::sle:
2837 case ICmpPredicate::sge:
2838 case ICmpPredicate::ule:
2839 case ICmpPredicate::uge:
2840 case ICmpPredicate::ceq:
2841 case ICmpPredicate::weq:
2843 case ICmpPredicate::ne:
2844 case ICmpPredicate::slt:
2845 case ICmpPredicate::sgt:
2846 case ICmpPredicate::ult:
2847 case ICmpPredicate::ugt:
2848 case ICmpPredicate::cne:
2849 case ICmpPredicate::wne:
2852 llvm_unreachable(
"unknown comparison predicate");
// Folder for ICmpOp. Two visible cases, each returning an IntegerAttr of the
// op's result type built from `val`.
// NOTE(review): this listing omits the lines that compute `val` in both cases
// (original lines 2859 and 2866-2867), as well as the fall-through return.
2855OpFoldResult ICmpOp::fold(FoldAdaptor adaptor) {
// Both operands are the same SSA value.
2858 if (getLhs() == getRhs()) {
2860 return IntegerAttr::get(getType(), val);
// Both operands are integer constants.
2864 if (
auto lhs = dyn_cast_or_null<IntegerAttr>(adaptor.getLhs())) {
2865 if (
auto rhs = dyn_cast_or_null<IntegerAttr>(adaptor.getRhs())) {
2868 return IntegerAttr::get(getType(), val);
2876template <
typename Range>
2878 size_t commonPrefixLength = 0;
2879 auto ia = a.begin();
2880 auto ib = b.begin();
2882 for (; ia != a.end() && ib != b.end(); ia++, ib++, commonPrefixLength++) {
2888 return commonPrefixLength;
2892 size_t totalWidth = 0;
2893 for (
auto operand : operands) {
2896 ssize_t width = operand.getType().getIntOrFloatBitWidth();
2898 totalWidth += width;
2908 PatternRewriter &rewriter) {
2912 SmallVector<Value> lhsOperands, rhsOperands;
2915 ArrayRef<Value> lhsOperandsRef = lhsOperands, rhsOperandsRef = rhsOperands;
2917 auto formCatOrReplicate = [&](Location loc,
2918 ArrayRef<Value> operands) -> Value {
2919 assert(!operands.empty());
2920 Value sameElement = operands[0];
2921 for (
size_t i = 1, e = operands.size(); i != e && sameElement; ++i)
2922 if (sameElement != operands[i])
2923 sameElement = Value();
2925 return rewriter.createOrFold<ReplicateOp>(loc, sameElement,
2927 return rewriter.createOrFold<
ConcatOp>(loc, operands);
2930 auto replaceWith = [&](ICmpPredicate predicate, Value lhs,
2931 Value rhs) -> LogicalResult {
2932 replaceOpWithNewOpAndCopyNamehint<ICmpOp>(rewriter, op, predicate, lhs, rhs,
2937 size_t commonPrefixLength =
2939 if (commonPrefixLength == lhsOperands.size()) {
2942 replaceOpWithNewOpAndCopyNamehint<hw::ConstantOp>(rewriter, op,
2948 llvm::reverse(lhsOperandsRef), llvm::reverse(rhsOperandsRef));
2950 size_t commonPrefixTotalWidth =
2951 getTotalWidth(lhsOperandsRef.take_front(commonPrefixLength));
2952 size_t commonSuffixTotalWidth =
2953 getTotalWidth(lhsOperandsRef.take_back(commonSuffixLength));
2954 auto lhsOnly = lhsOperandsRef.drop_front(commonPrefixLength)
2955 .drop_back(commonSuffixLength);
2956 auto rhsOnly = rhsOperandsRef.drop_front(commonPrefixLength)
2957 .drop_back(commonSuffixLength);
2959 auto replaceWithoutReplicatingSignBit = [&]() {
2960 auto newLhs = formCatOrReplicate(lhs->getLoc(), lhsOnly);
2961 auto newRhs = formCatOrReplicate(rhs->getLoc(), rhsOnly);
2962 return replaceWith(op.getPredicate(), newLhs, newRhs);
2965 auto replaceWithReplicatingSignBit = [&]() {
2966 auto firstNonEmptyValue = lhsOperands[0];
2967 auto firstNonEmptyElemWidth =
2968 firstNonEmptyValue.getType().getIntOrFloatBitWidth();
2969 Value signBit = rewriter.createOrFold<
ExtractOp>(
2970 op.getLoc(), firstNonEmptyValue, firstNonEmptyElemWidth - 1, 1);
2972 auto newLhs = ConcatOp::create(rewriter, lhs->getLoc(), signBit, lhsOnly);
2973 auto newRhs = ConcatOp::create(rewriter, rhs->getLoc(), signBit, rhsOnly);
2974 return replaceWith(op.getPredicate(), newLhs, newRhs);
2977 if (ICmpOp::isPredicateSigned(op.getPredicate())) {
2979 if (commonPrefixTotalWidth == 0 && commonSuffixTotalWidth > 0)
2980 return replaceWithoutReplicatingSignBit();
2986 if (commonPrefixTotalWidth > 1 || commonSuffixTotalWidth > 0)
2987 return replaceWithReplicatingSignBit();
2989 }
else if (commonPrefixTotalWidth > 0 || commonSuffixTotalWidth > 0) {
2991 return replaceWithoutReplicatingSignBit();
3005 ICmpOp cmpOp,
const KnownBits &bitAnalysis,
const APInt &rhsCst,
3006 PatternRewriter &rewriter) {
3010 APInt bitsKnown = bitAnalysis.Zero | bitAnalysis.One;
3011 if ((bitsKnown & rhsCst) != bitAnalysis.One) {
3014 bool result = cmpOp.getPredicate() == ICmpPredicate::ne;
3015 replaceOpWithNewOpAndCopyNamehint<hw::ConstantOp>(rewriter, cmpOp,
3023 SmallVector<Value> newConcatOperands;
3024 auto newConstant = APInt::getZeroWidth();
3029 unsigned knownMSB = bitsKnown.countLeadingOnes();
3031 Value operand = cmpOp.getLhs();
3036 while (knownMSB != bitsKnown.getBitWidth()) {
3039 bitsKnown = bitsKnown.trunc(bitsKnown.getBitWidth() - knownMSB);
3042 unsigned unknownBits = bitsKnown.countLeadingZeros();
3043 unsigned lowBit = bitsKnown.getBitWidth() - unknownBits;
3044 auto spanOperand = rewriter.createOrFold<
ExtractOp>(
3045 operand.getLoc(), operand, lowBit,
3047 auto spanConstant = rhsCst.lshr(lowBit).trunc(unknownBits);
3050 newConcatOperands.push_back(spanOperand);
3053 if (newConstant.getBitWidth() != 0)
3054 newConstant = newConstant.concat(spanConstant);
3056 newConstant = spanConstant;
3059 unsigned newWidth = bitsKnown.getBitWidth() - unknownBits;
3060 bitsKnown = bitsKnown.trunc(newWidth);
3061 knownMSB = bitsKnown.countLeadingOnes();
3067 if (newConcatOperands.empty()) {
3068 bool result = cmpOp.getPredicate() == ICmpPredicate::eq;
3069 replaceOpWithNewOpAndCopyNamehint<hw::ConstantOp>(rewriter, cmpOp,
3075 Value concatResult =
3076 rewriter.createOrFold<
ConcatOp>(operand.getLoc(), newConcatOperands);
3080 rewriter, cmpOp.getOperand(1).getLoc(), newConstant);
3082 replaceOpWithNewOpAndCopyNamehint<ICmpOp>(rewriter, cmpOp,
3083 cmpOp.getPredicate(), concatResult,
3084 newConstantOp, cmpOp.getTwoState());
3090 PatternRewriter &rewriter) {
3091 auto ip = rewriter.saveInsertionPoint();
3092 rewriter.setInsertionPoint(xorOp);
3094 auto xorRHS = xorOp.getOperands().back().getDefiningOp<
hw::ConstantOp>();
3096 xorRHS.getValue() ^ rhs);
3098 switch (xorOp.getNumOperands()) {
3102 APInt::getZero(rhs.getBitWidth()));
3106 newLHS = xorOp.getOperand(0);
3110 SmallVector<Value> newOperands(xorOp.getOperands());
3111 newOperands.pop_back();
3112 newLHS = XorOp::create(rewriter, xorOp.getLoc(), newOperands,
false);
3116 bool xorMultipleUses = !xorOp->hasOneUse();
3120 if (xorMultipleUses)
3121 replaceOpWithNewOpAndCopyNamehint<XorOp>(rewriter, xorOp, newLHS, xorRHS,
3125 rewriter.restoreInsertionPoint(ip);
3126 replaceOpWithNewOpAndCopyNamehint<ICmpOp>(
3127 rewriter, cmpOp, cmpOp.getPredicate(), newLHS, newRHS,
false);
3130LogicalResult ICmpOp::canonicalize(ICmpOp op, PatternRewriter &rewriter) {
3136 if (matchPattern(op.getLhs(), m_ConstantInt(&lhs))) {
3137 assert(!matchPattern(op.getRhs(), m_ConstantInt(&rhs)) &&
3138 "Should be folded");
3139 replaceOpWithNewOpAndCopyNamehint<ICmpOp>(
3140 rewriter, op, ICmpOp::getFlippedPredicate(op.getPredicate()),
3141 op.getRhs(), op.getLhs(), op.getTwoState());
3146 if (matchPattern(op.getRhs(), m_ConstantInt(&rhs))) {
3151 auto replaceWith = [&](ICmpPredicate predicate, Value lhs,
3152 Value rhs) -> LogicalResult {
3153 replaceOpWithNewOpAndCopyNamehint<ICmpOp>(rewriter, op, predicate, lhs,
3154 rhs, op.getTwoState());
3158 auto replaceWithConstantI1 = [&](
bool constant) -> LogicalResult {
3159 replaceOpWithNewOpAndCopyNamehint<hw::ConstantOp>(rewriter, op,
3160 APInt(1, constant));
3164 switch (op.getPredicate()) {
3165 case ICmpPredicate::slt:
3167 if (rhs.isMaxSignedValue())
3168 return replaceWith(ICmpPredicate::ne, op.getLhs(), op.getRhs());
3170 if (rhs.isMinSignedValue())
3171 return replaceWithConstantI1(0);
3173 if ((rhs - 1).isMinSignedValue())
3174 return replaceWith(ICmpPredicate::eq, op.getLhs(),
3177 case ICmpPredicate::sgt:
3179 if (rhs.isMinSignedValue())
3180 return replaceWith(ICmpPredicate::ne, op.getLhs(), op.getRhs());
3182 if (rhs.isMaxSignedValue())
3183 return replaceWithConstantI1(0);
3185 if ((rhs + 1).isMaxSignedValue())
3186 return replaceWith(ICmpPredicate::eq, op.getLhs(),
3189 case ICmpPredicate::ult:
3191 if (rhs.isAllOnes())
3192 return replaceWith(ICmpPredicate::ne, op.getLhs(), op.getRhs());
3195 return replaceWithConstantI1(0);
3197 if ((rhs - 1).isZero())
3198 return replaceWith(ICmpPredicate::eq, op.getLhs(),
3202 if (rhs.countLeadingOnes() + rhs.countTrailingZeros() ==
3203 rhs.getBitWidth()) {
3204 auto numOnes = rhs.countLeadingOnes();
3206 rhs.getBitWidth() - numOnes, numOnes);
3207 return replaceWith(ICmpPredicate::ne, smaller,
3212 case ICmpPredicate::ugt:
3215 return replaceWith(ICmpPredicate::ne, op.getLhs(), op.getRhs());
3217 if (rhs.isAllOnes())
3218 return replaceWithConstantI1(0);
3220 if ((rhs + 1).isAllOnes())
3221 return replaceWith(ICmpPredicate::eq, op.getLhs(),
3225 if ((rhs + 1).isPowerOf2()) {
3226 auto numOnes = rhs.countTrailingOnes();
3227 auto newWidth = rhs.getBitWidth() - numOnes;
3230 return replaceWith(ICmpPredicate::ne, smaller,
3235 case ICmpPredicate::sle:
3237 if (rhs.isMaxSignedValue())
3238 return replaceWithConstantI1(1);
3240 return replaceWith(ICmpPredicate::slt, op.getLhs(),
getConstant(rhs + 1));
3241 case ICmpPredicate::sge:
3243 if (rhs.isMinSignedValue())
3244 return replaceWithConstantI1(1);
3246 return replaceWith(ICmpPredicate::sgt, op.getLhs(),
getConstant(rhs - 1));
3247 case ICmpPredicate::ule:
3249 if (rhs.isAllOnes())
3250 return replaceWithConstantI1(1);
3252 return replaceWith(ICmpPredicate::ult, op.getLhs(),
getConstant(rhs + 1));
3253 case ICmpPredicate::uge:
3256 return replaceWithConstantI1(1);
3258 return replaceWith(ICmpPredicate::ugt, op.getLhs(),
getConstant(rhs - 1));
3259 case ICmpPredicate::eq:
3260 if (rhs.getBitWidth() == 1) {
3263 replaceOpWithNewOpAndCopyNamehint<XorOp>(rewriter, op, op.getLhs(),
3268 if (rhs.isAllOnes()) {
3275 case ICmpPredicate::ne:
3276 if (rhs.getBitWidth() == 1) {
3282 if (rhs.isAllOnes()) {
3284 replaceOpWithNewOpAndCopyNamehint<XorOp>(rewriter, op, op.getLhs(),
3291 case ICmpPredicate::ceq:
3292 case ICmpPredicate::cne:
3293 case ICmpPredicate::weq:
3294 case ICmpPredicate::wne:
3300 if (op.getPredicate() == ICmpPredicate::eq ||
3301 op.getPredicate() == ICmpPredicate::ne) {
3306 if (!knownBits.isUnknown())
3313 if (
auto xorOp = op.getLhs().getDefiningOp<
XorOp>())
3320 if (
auto replicateOp = op.getLhs().getDefiningOp<ReplicateOp>())
3321 if (rhs.isAllOnes() || rhs.isZero()) {
3322 auto width = replicateOp.getInput().getType().getIntOrFloatBitWidth();
3325 rhs.isAllOnes() ? APInt::getAllOnes(width)
3326 : APInt::getZero(width));
3327 replaceOpWithNewOpAndCopyNamehint<ICmpOp>(
3328 rewriter, op, op.getPredicate(), replicateOp.getInput(), cst,
3338 if (Operation *opLHS = op.getLhs().getDefiningOp())
3339 if (Operation *opRHS = op.getRhs().getDefiningOp())
3340 if (isa<ConcatOp, ReplicateOp>(opLHS) &&
3341 isa<ConcatOp, ReplicateOp>(opRHS)) {
assert(baseType &&"element must be base type")
static KnownBits computeKnownBits(Value v, unsigned depth)
Given an integer SSA value, check to see if we know anything about the result of the computation.
static bool foldMuxOfUniformArrays(MuxOp op, PatternRewriter &rewriter)
static Attribute constFoldAssociativeOp(ArrayRef< Attribute > operands, hw::PEO paramOpcode)
static Attribute constFoldBinaryOp(ArrayRef< Attribute > operands, hw::PEO paramOpcode)
Performs constant folding calculate with element-wise behavior on the two attributes in operands and ...
static bool applyCmpPredicateToEqualOperands(ICmpPredicate predicate)
static ComplementMatcher< SubType > m_Complement(const SubType &subExpr)
static bool canonicalizeLogicalCstWithConcat(Operation *logicalOp, size_t concatIdx, const APInt &cst, PatternRewriter &rewriter)
When we find a logical operation (and, or, xor) with a constant e.g.
static bool narrowOperationWidth(OpTy op, bool narrowTrailingBits, PatternRewriter &rewriter)
static OpFoldResult foldDiv(Op op, ArrayRef< Attribute > constants)
static Value getCommonOperand(Op op)
Returns a single common operand that all inputs of the operation op can be traced back to,...
static bool canCombineOppositeBinCmpIntoConstant(OperandRange operands)
static void getConcatOperands(Value v, SmallVectorImpl< Value > &result)
Flatten concat and replicate operands into a vector.
static Value extractOperandFromFullyAssociative(Operation *fullyAssoc, size_t operandNo, PatternRewriter &rewriter)
Given a fully associative variadic operation like (a+b+c+d), break the expression into two parts,...
static bool getMuxChainCondConstant(Value cond, Value indexValue, bool isInverted, std::function< void(hw::ConstantOp)> constantFn)
Check to see if the condition to the specified mux is an equality comparison indexValue and one or mo...
static TypedAttr getIntAttr(const APInt &value, MLIRContext *context)
static bool shouldBeFlattened(Operation *op)
Return true if the op will be flattened afterwards.
static void canonicalizeXorIcmpTrue(XorOp op, unsigned icmpOperand, PatternRewriter &rewriter)
static bool assumeMuxCondInOperand(Value muxCond, Value muxValue, bool constCond, PatternRewriter &rewriter)
If the mux condition is an operand to the op defining its true or false value, replace the condition ...
static bool extractFromReplicate(ExtractOp op, ReplicateOp replicate, PatternRewriter &rewriter)
static void combineEqualityICmpWithXorOfConstant(ICmpOp cmpOp, XorOp xorOp, const APInt &rhs, PatternRewriter &rewriter)
static size_t getTotalWidth(ArrayRef< Value > operands)
static bool foldCommonMuxOperation(MuxOp mux, Operation *trueOp, Operation *falseOp, PatternRewriter &rewriter)
This function is invoked when we find a mux with true/false operations that have the same opcode.
static std::pair< size_t, size_t > getLowestBitAndHighestBitRequired(Operation *op, bool narrowTrailingBits, size_t originalOpWidth)
static bool tryFlatteningOperands(Operation *op, PatternRewriter &rewriter)
Flattens a single input in op if hasOneUse is true and it can be defined as an Op.
static bool isOpTriviallyRecursive(Operation *op)
static bool canonicalizeIdempotentInputs(Op op, PatternRewriter &rewriter)
Canonicalize an idempotent operation op so that only one input of any kind occurs.
static bool applyCmpPredicate(ICmpPredicate predicate, const APInt &lhs, const APInt &rhs)
static void combineEqualityICmpWithKnownBitsAndConstant(ICmpOp cmpOp, const KnownBits &bitAnalysis, const APInt &rhsCst, PatternRewriter &rewriter)
Given an equality comparison with a constant value and some operand that has known bits,...
static bool hasSVAttributes(Operation *op)
static LogicalResult extractConcatToConcatExtract(ExtractOp op, ConcatOp innerCat, PatternRewriter &rewriter)
static OpFoldResult foldMod(Op op, ArrayRef< Attribute > constants)
static size_t computeCommonPrefixLength(const Range &a, const Range &b)
static bool foldCommonMuxValue(MuxOp op, bool isTrueOperand, PatternRewriter &rewriter)
Fold things like mux(cond, x|y|z|a, a) -> (x|y|z)&replicate(cond)|a and mux(cond, a,...
static LogicalResult matchAndRewriteCompareConcat(ICmpOp op, Operation *lhs, Operation *rhs, PatternRewriter &rewriter)
Reduce the strength icmp(concat(...), concat(...)) by doing a element-wise comparison on common prefi...
static Value createGenericOp(Location loc, OperationName name, ArrayRef< Value > operands, OpBuilder &builder)
Create a new instance of a generic operation that only has value operands, and has a single result va...
static TypedAttr getIntAttr(MLIRContext *ctx, Type t, const APInt &value)
static std::unique_ptr< Context > context
static std::optional< APSInt > getConstant(Attribute operand)
Determine the value of a constant operand for the sake of constant folding.
create(elements, Type result_type=None)
create(array_value, low_index, ret_type)
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
void extractBits(OpBuilder &builder, Value val, SmallVectorImpl< Value > &bits)
Extract bits from a value.
bool foldMuxChainWithComparison(PatternRewriter &rewriter, MuxOp rootMux, bool isFalseSide, llvm::function_ref< MuxChainWithComparisonFoldingStyle(size_t indexWidth, size_t numEntries)> styleFn)
Mux chain folding that converts chains of muxes with index comparisons into array operations or balan...
Value createOrFoldNot(OpBuilder &builder, Location loc, Value value, bool twoState=false)
Create a 'Not' gate on a value.
MuxChainWithComparisonFoldingStyle
Enum for mux chain folding styles.
LogicalResult convertModUByPowerOfTwo(ModUOp modOp, mlir::PatternRewriter &rewriter)
KnownBits computeKnownBits(Value value)
Compute "known bits" information about the specified value - the set of bits that are guaranteed to a...
Value constructMuxTree(OpBuilder &builder, Location loc, ArrayRef< Value > selectors, ArrayRef< Value > leafNodes, Value outOfBoundsValue)
Construct a mux tree for given leaf nodes.
Value createOrFoldSExt(OpBuilder &builder, Location loc, Value value, Type destTy)
Create a sign extension operation from a value of integer type to an equal or larger integer type.
LogicalResult convertDivUByPowerOfTwo(DivUOp divOp, mlir::PatternRewriter &rewriter)
Convert unsigned division or modulo by a power of two.
uint64_t getWidth(Type t)
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
void replaceOpAndCopyNamehint(PatternRewriter &rewriter, Operation *op, Value newValue)
A wrapper of PatternRewriter::replaceOp to propagate "sv.namehint" attribute.