13#include "mlir/IR/Matchers.h"
14#include "mlir/IR/PatternMatch.h"
15#include "llvm/ADT/SetVector.h"
16#include "llvm/ADT/SmallBitVector.h"
17#include "llvm/ADT/TypeSwitch.h"
18#include "llvm/Support/KnownBits.h"
23using namespace matchers;
31 ArrayRef<Value> operands, OpBuilder &builder) {
32 OperationState state(loc, name);
33 state.addOperands(operands);
34 state.addTypes(operands[0].getType());
35 return builder.create(state)->getResult(0);
// Builds an integer attribute for `value`. The attribute's type is an
// IntegerType created in `context` whose width equals `value.getBitWidth()`.
// The trailing argument(s) of IntegerAttr::get are on a line not visible in
// this excerpt.
38static TypedAttr
getIntAttr(
const APInt &value, MLIRContext *context) {
39 return IntegerAttr::get(IntegerType::get(context, value.getBitWidth()),
46 for (
auto op :
concat.getOperands())
48 }
else if (
auto repl = v.getDefiningOp<ReplicateOp>()) {
49 for (
size_t i = 0, e = repl.getMultiple(); i != e; ++i)
60 return op->hasAttr(
"sv.attributes");
// Matcher for a bitwise complement: accepts an XorOp for which isBinaryNot()
// holds and whose operand 0 satisfies the wrapped sub-matcher.
64template <
typename SubType>
65struct ComplementMatcher {
// Stores the sub-matcher that is applied to the complemented value.
67 ComplementMatcher(SubType lhs) : lhs(std::move(lhs)) {}
// Returns true only when `op` is an XorOp, it is a binary "not", and its
// first operand matches `lhs`. Any non-XorOp yields false via the failed
// dyn_cast.
68 bool match(Operation *op) {
69 auto xorOp = dyn_cast<XorOp>(op);
70 return xorOp && xorOp.isBinaryNot() && lhs.match(op->getOperand(0));
// Convenience factory for ComplementMatcher: lets the sub-matcher type be
// deduced from `subExpr` at the call site (used below together with m_Any).
75template <
typename SubType>
76static inline ComplementMatcher<SubType>
m_Complement(
const SubType &subExpr) {
77 return ComplementMatcher<SubType>(subExpr);
83 assert((isa<AndOp, OrOp, XorOp, AddOp, MulOp>(op) &&
84 "must be commutative operations"));
85 if (op->hasOneUse()) {
86 auto *user = *op->getUsers().begin();
87 return user->getName() == op->getName() &&
88 op->getAttrOfType<UnitAttr>(
"twoState") ==
89 user->getAttrOfType<UnitAttr>(
"twoState") &&
90 op->getBlock() == user->getBlock();
// Visible body of the operand-flattening helper (its signature line is not
// part of this excerpt). It walks the operand tree of `op` with an explicit
// worklist, collecting the leaf operands into `newOperands` and the locations
// of every op that gets absorbed into `newLocations`.
105 auto inputs = op->getOperands();
107 SmallVector<Value, 4> newOperands;
// The fused location of the rebuilt op starts with the location of `op`.
108 SmallVector<Location, 4> newLocations{op->getLoc()};
109 newOperands.reserve(inputs.size());
// Fields of the worklist element type (its enclosing declaration is on lines
// not shown): a [current, end) cursor over one op's operands.
111 decltype(inputs.begin()) current, end;
114 SmallVector<Element> worklist;
115 worklist.push_back({inputs.begin(), inputs.end()});
// Whether the root op carries the "twoState" unit attribute.
116 bool binFlag = op->hasAttrOfType<UnitAttr>(
"twoState");
117 bool changed =
false;
118 while (!worklist.empty()) {
119 auto &element = worklist.back();
// This operand range is exhausted (the handling is on lines not shown).
122 if (element.current == element.end) {
// Take the next operand and advance the cursor.
127 Value value = *element.current++;
128 auto *flattenOp = value.getDefiningOp();
// Operands that are not defined by an op of the same kind in the same block
// (or that are `op` itself) are kept as plain operands.
// NOTE(review): the "twoState" clause below queries `op`, which is the same
// op `binFlag` was read from, so as written it cannot differ unless the
// attribute changes in between. `flattenOp` may have been intended — confirm.
131 if (!flattenOp || flattenOp->getName() != op->getName() ||
132 flattenOp == op || binFlag != op->hasAttrOfType<UnitAttr>(
"twoState") ||
133 flattenOp->getBlock() != op->getBlock()) {
134 newOperands.push_back(value);
// Multi-use values get extra scrutiny before being flattened.
139 if (!value.hasOneUse()) {
// Part of the condition is on a line not shown; the visible part rejects
// defining ops without exactly two operands and roots that are not
// And/Or/Xor.
147 if (flattenOp->getNumOperands() != 2 || !isa<AndOp, OrOp, XorOp>(op) ||
150 newOperands.push_back(value);
// Flattenable: descend into the defining op's operands and remember its
// location for the fused location of the result.
158 auto flattenOpInputs = flattenOp->getOperands();
159 worklist.push_back({flattenOpInputs.begin(), flattenOpInputs.end()});
160 newLocations.push_back(flattenOp->getLoc());
// Rebuild a single op of the same kind over all collected operands.
166 Value result =
createGenericOp(FusedLoc::get(op->getContext(), newLocations),
167 op->getName(), newOperands, rewriter);
// Re-applies the "twoState" marker on the new op (the guarding condition is
// on a line not shown; presumably `binFlag` — confirm).
169 result.getDefiningOp()->setAttr(
"twoState", rewriter.getUnitAttr());
177static std::pair<size_t, size_t>
179 size_t originalOpWidth) {
180 auto users = op->getUsers();
182 "getLowestBitAndHighestBitRequired cannot operate on "
183 "a empty list of uses.");
187 size_t lowestBitRequired = narrowTrailingBits ? originalOpWidth - 1 : 0;
188 size_t highestBitRequired = 0;
190 for (
auto *user : users) {
191 if (
auto extractOp = dyn_cast<ExtractOp>(user)) {
192 size_t lowBit = extractOp.getLowBit();
194 cast<IntegerType>(extractOp.getType()).getWidth() + lowBit - 1;
195 highestBitRequired = std::max(highestBitRequired, highBit);
196 lowestBitRequired = std::min(lowestBitRequired, lowBit);
200 highestBitRequired = originalOpWidth - 1;
201 lowestBitRequired = 0;
205 return {lowestBitRequired, highestBitRequired};
210 PatternRewriter &rewriter) {
211 IntegerType opType = dyn_cast<IntegerType>(op.getResult().getType());
217 if (range.second + 1 == opType.getWidth() && range.first == 0)
220 SmallVector<Value> args;
221 auto newType = rewriter.getIntegerType(range.second - range.first + 1);
222 for (
auto inop : op.getOperands()) {
224 if (inop.getType() != op.getType())
225 args.push_back(inop);
227 args.push_back(rewriter.createOrFold<
ExtractOp>(inop.getLoc(), newType,
230 auto newop = rewriter.create<OpTy>(op.getLoc(), newType, args);
231 newop->setDialectAttrs(op->getDialectAttrs());
232 if (op.getTwoState())
233 newop.setTwoState(
true);
235 Value newResult = newop.getResult();
237 newResult = rewriter.createOrFold<
ConcatOp>(
238 op.getLoc(), newResult,
240 APInt::getZero(range.first)));
241 if (range.second + 1 < opType.getWidth())
242 newResult = rewriter.createOrFold<
ConcatOp>(
245 op.getLoc(), APInt::getZero(opType.getWidth() - range.second - 1)),
247 rewriter.replaceOp(op, newResult);
// Folder for ReplicateOp. Several lines (including the `return` statements)
// are not visible in this excerpt; the notes describe only what is shown.
255OpFoldResult ReplicateOp::fold(FoldAdaptor adaptor) {
// Case: the result is exactly as wide as the input (a single copy). The
// value returned for this case is on a line not shown.
257 if (cast<IntegerType>(getType()).
getWidth() ==
258 getInput().getType().getIntOrFloatBitWidth())
// Case: the input is a known integer constant.
262 if (
auto input = dyn_cast_or_null<IntegerAttr>(adaptor.getInput())) {
// One-bit constant: the visible code builds an all-zeros APInt of the result
// width when the bit is zero, and an all-ones APInt of the result width
// otherwise.
263 if (input.getValue().getBitWidth() == 1) {
264 if (input.getValue().isZero())
266 APInt::getZero(cast<IntegerType>(getType()).
getWidth()),
269 APInt::getAllOnes(cast<IntegerType>(getType()).
getWidth()),
// Wider constant: start from a zero-width value and concatenate the input
// once per getMultiple().
273 APInt result = APInt::getZeroWidth();
274 for (
auto i = getMultiple(); i != 0; --i)
275 result = result.concat(input.getValue());
// Folder for ParityOp: when the input is a constant integer, the result is a
// 1-bit constant holding the low bit of the input's population count.
// Handling of non-constant inputs is on lines not shown.
282OpFoldResult ParityOp::fold(FoldAdaptor adaptor) {
284 if (
auto input = dyn_cast_or_null<IntegerAttr>(adaptor.getInput()))
285 return getIntAttr(APInt(1, input.getValue().popcount() & 1), getContext());
297 hw::PEO paramOpcode) {
298 assert(operands.size() == 2 &&
"binary op takes two operands");
299 if (!operands[0] || !operands[1])
304 return hw::ParamExprAttr::get(paramOpcode, cast<TypedAttr>(operands[0]),
305 cast<TypedAttr>(operands[1]));
// Folder for ShlOp. Acts only when the shift amount (rhs) is a constant
// integer. The conditions selecting between the two returns below are on
// lines not shown in this excerpt.
308OpFoldResult ShlOp::fold(FoldAdaptor adaptor) {
309 if (
auto rhs = dyn_cast_or_null<IntegerAttr>(adaptor.getRhs())) {
// NOTE(review): getZExtValue() is stored into `unsigned`; confirm shift
// constants too large for that conversion cannot reach this point.
310 unsigned shift = rhs.getValue().getZExtValue();
311 unsigned width = getType().getIntOrFloatBitWidth();
// One outcome returns the first operand unchanged...
313 return getOperand(0);
// ...the other returns an all-zero constant of the result width.
315 return getIntAttr(APInt::getZero(width), getContext());
// Canonicalizer for ShlOp with a constant shift amount: replaces the shift
// with concat(extract(lhs, 0, width - shift), zeros(shift)). The declarations
// of `value`, `zeros` and `extract`, and all return statements, are on lines
// not shown in this excerpt.
321LogicalResult ShlOp::canonicalize(
ShlOp op, PatternRewriter &rewriter) {
// Only constant shift amounts are handled.
324 if (!matchPattern(op.getRhs(), m_ConstantInt(&value)))
327 unsigned width = cast<IntegerType>(op.getLhs().getType()).getWidth();
328 unsigned shift = value.getZExtValue();
// Shifts of zero or of at least the full width are excluded from this
// rewrite (the action taken is not shown).
331 if (width <= shift || shift == 0)
// `shift` zero bits for the low end of the result.
335 rewriter.create<
hw::ConstantOp>(op.getLoc(), APInt::getZero(shift));
// The low `width - shift` bits of lhs, which survive the shift.
339 rewriter.
create<
ExtractOp>(op.getLoc(), op.getLhs(), 0, width - shift);
341 replaceOpWithNewOpAndCopyNamehint<ConcatOp>(rewriter, op, extract, zeros);
// Folder for ShrUOp. Acts only when the shift amount (rhs) is a constant
// integer. The conditions guarding the two returns are on lines not shown.
345OpFoldResult ShrUOp::fold(FoldAdaptor adaptor) {
346 if (
auto rhs = dyn_cast_or_null<IntegerAttr>(adaptor.getRhs())) {
// NOTE(review): getZExtValue() is stored into `unsigned`; confirm shift
// constants too large for that conversion cannot reach this point.
347 unsigned shift = rhs.getValue().getZExtValue();
// One outcome returns the first operand unchanged...
349 return getOperand(0);
351 unsigned width = getType().getIntOrFloatBitWidth();
// ...the other returns an all-zero constant of the result width.
353 return getIntAttr(APInt::getZero(width), getContext());
// Canonicalizer for ShrUOp with a constant shift amount: replaces the shift
// with concat(zeros(shift), extract(lhs, shift, ...)). The declarations of
// `value` and `zeros`, the extract's final argument, and all return
// statements are on lines not shown in this excerpt.
358LogicalResult ShrUOp::canonicalize(
ShrUOp op, PatternRewriter &rewriter) {
// Only constant shift amounts are handled.
361 if (!matchPattern(op.getRhs(), m_ConstantInt(&value)))
364 unsigned width = cast<IntegerType>(op.getLhs().getType()).getWidth();
365 unsigned shift = value.getZExtValue();
// Shifts of zero or of at least the full width are excluded from this
// rewrite (the action taken is not shown).
368 if (width <= shift || shift == 0)
// `shift` zero bits for the high end of the result.
372 rewriter.create<
hw::ConstantOp>(op.getLoc(), APInt::getZero(shift));
// The bits of lhs starting at bit `shift`.
375 auto extract = rewriter.
create<
ExtractOp>(op.getLoc(), op.getLhs(), shift,
378 replaceOpWithNewOpAndCopyNamehint<ConcatOp>(rewriter, op, zeros, extract);
// Folder for ShrSOp: an arithmetic shift right by a constant zero folds to
// the first operand. Other cases are on lines not shown in this excerpt.
382OpFoldResult ShrSOp::fold(FoldAdaptor adaptor) {
383 if (
auto rhs = dyn_cast_or_null<IntegerAttr>(adaptor.getRhs())) {
384 if (rhs.getValue().getZExtValue() == 0)
385 return getOperand(0);
// Canonicalizer for ShrSOp with a constant shift amount: replaces the shift
// with concat(sext, extract(lhs, shift, ...)), where `sext` is the top bit of
// lhs replicated `shift` times. The declarations of `value` and `topbit`, the
// body of the `width <= shift` branch, and all returns are on lines not shown.
390LogicalResult ShrSOp::canonicalize(
ShrSOp op, PatternRewriter &rewriter) {
// Only constant shift amounts are handled.
393 if (!matchPattern(op.getRhs(), m_ConstantInt(&value)))
396 unsigned width = cast<IntegerType>(op.getLhs().getType()).getWidth();
397 unsigned shift = value.getZExtValue();
// The sign bit of lhs (bit width - 1, one bit wide).
400 rewriter.createOrFold<
ExtractOp>(op.getLoc(), op.getLhs(), width - 1, 1);
// NOTE(review): `sext` is created with `shift` copies before the
// `width <= shift` check below; confirm how shift > width is handled, since
// any clamping is not visible in this excerpt.
401 auto sext = rewriter.createOrFold<ReplicateOp>(op.getLoc(), topbit, shift);
// Shifting by at least the full width is handled separately.
403 if (width <= shift) {
// The bits of lhs starting at bit `shift`.
408 auto extract = rewriter.
create<
ExtractOp>(op.getLoc(), op.getLhs(), shift,
411 replaceOpWithNewOpAndCopyNamehint<ConcatOp>(rewriter, op, sext, extract);
// Folder for ExtractOp. Return statements other than the one shown, and the
// second argument of the visible getIntAttr call, are on lines not shown.
419OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
// Case: the extract covers the whole input (identical input/result types).
421 if (getInput().getType() == getType())
// Case: constant input. Shift the constant right by the low bit index and
// truncate to the result width.
425 if (
auto input = dyn_cast_or_null<IntegerAttr>(adaptor.getInput())) {
426 unsigned dstWidth = cast<IntegerType>(getType()).getWidth();
427 return getIntAttr(input.getValue().lshr(getLowBit()).trunc(dstWidth),
438 PatternRewriter &rewriter) {
439 auto reversedConcatArgs = llvm::reverse(innerCat.getInputs());
440 size_t beginOfFirstRelevantElement = 0;
441 auto it = reversedConcatArgs.begin();
442 size_t lowBit = op.getLowBit();
445 for (; it != reversedConcatArgs.end(); it++) {
446 assert(beginOfFirstRelevantElement <= lowBit &&
447 "incorrectly moved past an element that lowBit has coverage over");
450 size_t operandWidth = operand.getType().getIntOrFloatBitWidth();
451 if (lowBit < beginOfFirstRelevantElement + operandWidth) {
475 beginOfFirstRelevantElement += operandWidth;
477 assert(it != reversedConcatArgs.end() &&
478 "incorrectly failed to find an element which contains coverage of "
481 SmallVector<Value> reverseConcatArgs;
482 size_t widthRemaining = cast<IntegerType>(op.getType()).getWidth();
483 size_t extractLo = lowBit - beginOfFirstRelevantElement;
488 for (; widthRemaining != 0 && it != reversedConcatArgs.end(); it++) {
489 auto concatArg = *it;
490 size_t operandWidth = concatArg.getType().getIntOrFloatBitWidth();
491 size_t widthToConsume = std::min(widthRemaining, operandWidth - extractLo);
493 if (widthToConsume == operandWidth && extractLo == 0) {
494 reverseConcatArgs.push_back(concatArg);
496 auto resultType = IntegerType::get(rewriter.getContext(), widthToConsume);
497 reverseConcatArgs.push_back(
498 rewriter.create<
ExtractOp>(op.getLoc(), resultType, *it, extractLo));
501 widthRemaining -= widthToConsume;
507 if (reverseConcatArgs.size() == 1) {
510 replaceOpWithNewOpAndCopyNamehint<ConcatOp>(
511 rewriter, op, SmallVector<Value>(llvm::reverse(reverseConcatArgs)));
518 PatternRewriter &rewriter) {
519 auto extractResultWidth = cast<IntegerType>(op.getType()).getWidth();
520 auto replicateEltWidth =
521 replicate.getOperand().getType().getIntOrFloatBitWidth();
525 if (op.getLowBit() % replicateEltWidth == 0 &&
526 extractResultWidth % replicateEltWidth == 0) {
527 replaceOpWithNewOpAndCopyNamehint<ReplicateOp>(rewriter, op, op.getType(),
528 replicate.getOperand());
534 if (op.getLowBit() % replicateEltWidth + extractResultWidth <=
536 replaceOpWithNewOpAndCopyNamehint<ExtractOp>(
537 rewriter, op, op.getType(), replicate.getOperand(),
538 op.getLowBit() % replicateEltWidth);
// Canonicalizer for ExtractOp. It tries a sequence of rewrites keyed on the
// op defining the extract's input. Many lines (declarations, returns, the
// bodies of some branches) are not visible in this excerpt; the notes cover
// only what is shown.
547LogicalResult ExtractOp::canonicalize(
ExtractOp op, PatternRewriter &rewriter) {
// May be null when the input has no defining op; later checks use
// dyn_cast_or_null or test `inputOp` explicitly.
548 auto *inputOp = op.getInput().getDefiningOp();
// Tail of the expression initializing `knownBits` (its start is not shown):
// restricts the known-bits information to the extracted bit range.
555 .extractBits(cast<IntegerType>(op.getType()).getWidth(),
// Every extracted bit is known: replace the extract with that constant.
557 if (knownBits.isConstant()) {
558 replaceOpWithNewOpAndCopyNamehint<hw::ConstantOp>(rewriter, op,
559 knownBits.getConstant());
// extract(extract(x)): read directly from x with the two low-bit offsets
// added together.
565 if (
auto innerExtract = dyn_cast_or_null<ExtractOp>(inputOp)) {
566 replaceOpWithNewOpAndCopyNamehint<ExtractOp>(
567 rewriter, op, op.getType(), innerExtract.getInput(),
568 innerExtract.getLowBit() + op.getLowBit());
// extract(concat(...)): handled on lines not shown.
573 if (
auto innerCat = dyn_cast_or_null<ConcatOp>(inputOp))
// extract(replicate(...)): handled on lines not shown.
577 if (
auto replicate = dyn_cast_or_null<ReplicateOp>(inputOp))
// Extract of a two-operand and/or/xor whose second operand is a constant.
583 if (inputOp && inputOp->getNumOperands() == 2 &&
584 isa<AndOp, OrOp, XorOp>(inputOp)) {
585 if (
auto cstRHS = inputOp->getOperand(1).getDefiningOp<
hw::ConstantOp>()) {
// The slice of the constant that lines up with the extracted bits.
586 auto extractedCst = cstRHS.getValue().extractBits(
587 cast<IntegerType>(op.getType()).getWidth(), op.getLowBit());
// or/xor with zero in the extracted range is a no-op there, so extract
// straight from the other operand.
588 if (isa<OrOp, XorOp>(inputOp) && extractedCst.isZero()) {
589 replaceOpWithNewOpAndCopyNamehint<ExtractOp>(
590 rewriter, op, op.getType(), inputOp->getOperand(0), op.getLowBit());
// and with a mask: when the set bits of the mask slice form one contiguous
// run (width - lz - tz == popcount), only that run has to be extracted and
// the rest of the result is zero.
598 if (isa<AndOp>(inputOp)) {
601 unsigned lz = extractedCst.countLeadingZeros();
602 unsigned tz = extractedCst.countTrailingZeros();
603 unsigned pop = extractedCst.popcount();
604 if (extractedCst.getBitWidth() - lz - tz == pop) {
605 auto resultTy = rewriter.getIntegerType(pop);
606 SmallVector<Value> resultElts;
// Tail of a statement creating an `lz`-bit zero (its start and any guard are
// not shown).
609 op.getLoc(), APInt::getZero(lz)));
// The `pop` live bits, taken from operand 0 at offset lowBit + tz.
610 resultElts.push_back(rewriter.createOrFold<
ExtractOp>(
611 op.getLoc(), resultTy, inputOp->getOperand(0),
612 op.getLowBit() + tz));
// Tail of a statement creating a `tz`-bit zero (its start and any guard are
// not shown).
615 op.getLoc(), APInt::getZero(tz)));
616 replaceOpWithNewOpAndCopyNamehint<ConcatOp>(rewriter, op, resultElts);
// Single-bit extract of a one-use shl whose lhs constant is one: becomes an
// equality compare of the shift amount against the extracted bit index.
// The definition of `lhsCst`/`newCst` and the call's last argument are on
// lines not shown.
625 if (cast<IntegerType>(op.getType()).getWidth() == 1 && inputOp)
626 if (
auto shlOp = dyn_cast<ShlOp>(inputOp)) {
628 if (shlOp->hasOneUse())
630 if (lhsCst.getValue().isOne()) {
633 APInt(lhsCst.getValue().getBitWidth(), op.getLowBit()));
634 replaceOpWithNewOpAndCopyNamehint<ICmpOp>(
635 rewriter, op, ICmpPredicate::eq, shlOp->getOperand(1), newCst,
651 hw::PEO paramOpcode) {
652 assert(operands.size() > 1 &&
"caller should handle one-operand case");
655 if (!operands[1] || !operands[0])
659 if (llvm::all_of(operands.drop_front(2),
660 [&](Attribute in) { return !!in; })) {
661 SmallVector<mlir::TypedAttr> typedOperands;
662 typedOperands.reserve(operands.size());
663 for (
auto operand : operands) {
664 if (
auto typedOperand = dyn_cast<mlir::TypedAttr>(operand))
665 typedOperands.push_back(typedOperand);
669 if (typedOperands.size() == operands.size())
670 return hw::ParamExprAttr::get(paramOpcode, typedOperands);
686 size_t concatIdx,
const APInt &cst,
687 PatternRewriter &rewriter) {
688 auto concatOp = logicalOp->getOperand(concatIdx).getDefiningOp<
ConcatOp>();
689 assert((isa<AndOp, OrOp, XorOp>(logicalOp) && concatOp));
694 llvm::any_of(concatOp->getOperands(), [&](Value operand) ->
bool {
695 auto *operandOp = operand.getDefiningOp();
700 if (isa<hw::ConstantOp>(operandOp))
704 return operandOp->getName() == logicalOp->getName() &&
705 operandOp->hasOneUse() && operandOp->getNumOperands() != 0 &&
706 operandOp->getOperands().back().getDefiningOp<hw::ConstantOp>();
714 auto createLogicalOp = [&](ArrayRef<Value> operands) -> Value {
715 return createGenericOp(logicalOp->getLoc(), logicalOp->getName(), operands,
722 SmallVector<Value> newConcatOperands;
723 newConcatOperands.reserve(concatOp->getNumOperands());
726 size_t nextOperandBit = concatOp.getType().getIntOrFloatBitWidth();
727 for (Value operand : concatOp->getOperands()) {
728 size_t operandWidth = operand.getType().getIntOrFloatBitWidth();
729 nextOperandBit -= operandWidth;
732 logicalOp->getLoc(), cst.lshr(nextOperandBit).trunc(operandWidth));
734 newConcatOperands.push_back(createLogicalOp({operand, eltCst}));
743 if (logicalOp->getNumOperands() > 2) {
744 auto origOperands = logicalOp->getOperands();
745 SmallVector<Value> operands;
747 operands.append(origOperands.begin(), origOperands.begin() + concatIdx);
749 operands.append(origOperands.begin() + concatIdx + 1,
750 origOperands.begin() + (origOperands.size() - 1));
752 operands.push_back(newResult);
753 newResult = createLogicalOp(operands);
763 llvm::SmallDenseSet<std::tuple<ICmpPredicate, Value, Value>> seenPredicates;
765 for (
auto op : operands) {
766 if (
auto icmpOp = op.getDefiningOp<ICmpOp>();
767 icmpOp && icmpOp.getTwoState()) {
768 auto predicate = icmpOp.getPredicate();
769 auto lhs = icmpOp.getLhs();
770 auto rhs = icmpOp.getRhs();
771 if (seenPredicates.contains(
772 {ICmpOp::getNegatedPredicate(predicate), lhs, rhs}))
775 seenPredicates.insert({predicate, lhs, rhs});
// Folder for AndOp. Several lines (guards inside the loops and most returns)
// are not visible in this excerpt.
781OpFoldResult AndOp::fold(FoldAdaptor adaptor) {
// Accumulator starts at all-ones, the identity for bitwise and.
782 APInt value = APInt::getAllOnes(cast<IntegerType>(getType()).
getWidth());
784 auto inputs = adaptor.getInputs();
// Folds constant operands into `value`; the handling of non-constant
// operands is on lines not shown.
787 for (
auto operand : inputs) {
790 value &= cast<IntegerAttr>(operand).getValue();
// and(x, all-ones) -> x
796 if (inputs.size() == 2 && inputs[1] &&
797 cast<IntegerAttr>(inputs[1]).getValue().isAllOnes())
798 return getInputs()[0];
// and(x, x, ..., x) -> x
801 if (llvm::all_of(getInputs(),
802 [&](
auto in) {
return in == this->getInputs()[0]; }))
803 return getInputs()[0];
// Pairwise scan over the operands that ends in an all-zero constant of the
// result width; the matching conditions are on lines not shown.
806 for (Value arg : getInputs()) {
809 for (Value arg2 : getInputs())
812 APInt::getZero(cast<IntegerType>(getType()).
getWidth()),
833template <
typename Op>
835 if (!op.getType().isInteger(1))
838 auto inputs = op.getInputs();
839 size_t size = inputs.size();
841 auto sourceOp = inputs[0].template getDefiningOp<ExtractOp>();
844 Value source = sourceOp.getOperand();
847 if (size != source.getType().getIntOrFloatBitWidth())
851 llvm::BitVector bits(size);
852 bits.set(sourceOp.getLowBit());
854 for (
size_t i = 1; i != size; ++i) {
855 auto extractOp = inputs[i].template getDefiningOp<ExtractOp>();
856 if (!extractOp || extractOp.getOperand() != source)
858 bits.set(extractOp.getLowBit());
861 return bits.all() ? source : Value();
868template <
typename Op>
871 constexpr unsigned limit = 3;
872 auto inputs = op.getInputs();
874 llvm::SmallSetVector<Value, 8> uniqueInputs(inputs.begin(), inputs.end());
875 llvm::SmallDenseSet<Op, 8> checked;
882 llvm::SmallVector<OpWithDepth, 8> worklist;
884 auto enqueue = [&worklist, &checked, &op](Value input,
unsigned depth) {
888 if (depth < limit && input.getParentBlock() == op->getBlock()) {
889 auto inputOp = input.template getDefiningOp<Op>();
890 if (inputOp && inputOp.getTwoState() == op.getTwoState() &&
891 checked.insert(inputOp).second)
892 worklist.push_back({inputOp, depth + 1});
896 for (
auto input : uniqueInputs)
899 while (!worklist.empty()) {
900 auto item = worklist.pop_back_val();
902 for (
auto input : item.op.getInputs()) {
903 uniqueInputs.remove(input);
904 enqueue(input, item.depth);
908 if (uniqueInputs.size() < inputs.size()) {
909 replaceOpWithNewOpAndCopyNamehint<Op>(rewriter, op, op.getType(),
910 uniqueInputs.getArrayRef(),
// Canonicalizer for AndOp. Tries a series of rewrites, most of them keyed on
// a constant in the last operand position. Declarations of `value`/`value2`,
// several guards, and all returns are on lines not shown in this excerpt.
918LogicalResult AndOp::canonicalize(
AndOp op, PatternRewriter &rewriter) {
919 auto inputs = op.getInputs();
920 auto size = inputs.size();
932 assert(size > 1 &&
"expected 2 or more operands, `fold` should handle this");
// Rewrites that apply when the last operand is a constant.
936 if (matchPattern(inputs.back(), m_ConstantInt(&value))) {
// and(..., all-ones) -> and(...)
938 if (value.isAllOnes()) {
939 replaceOpWithNewOpAndCopyNamehint<AndOp>(rewriter, op, op.getType(),
940 inputs.drop_back(),
false);
// and(..., c1, c2) -> and(..., c1 & c2)
948 if (matchPattern(inputs[size - 2], m_ConstantInt(&value2))) {
949 auto cst = rewriter.create<
hw::ConstantOp>(op.getLoc(), value & value2);
950 SmallVector<Value, 4> newOperands(inputs.drop_back(2));
951 newOperands.push_back(cst);
952 replaceOpWithNewOpAndCopyNamehint<AndOp>(rewriter, op, op.getType(),
// and(replicate(bit), single-bit mask): the result is that one bit placed at
// the mask position, padded with zeros on both sides.
958 if (size == 2 && value.isPowerOf2()) {
963 if (
auto replicate = inputs[0].getDefiningOp<ReplicateOp>()) {
964 auto replicateOperand = replicate.getOperand();
965 if (replicateOperand.getType().isInteger(1)) {
966 unsigned resultWidth = op.getType().getIntOrFloatBitWidth();
967 auto trailingZeros = value.countTrailingZeros();
970 SmallVector<Value, 3> concatOperands;
// High zero padding, omitted when the set bit is the top bit.
971 if (trailingZeros != resultWidth - 1) {
973 op.getLoc(), APInt::getZero(resultWidth - trailingZeros - 1));
974 concatOperands.push_back(highZeros);
976 concatOperands.push_back(replicateOperand);
// Low zero padding, omitted when the set bit is bit 0.
977 if (trailingZeros != 0) {
979 op.getLoc(), APInt::getZero(trailingZeros));
980 concatOperands.push_back(lowZeros);
982 replaceOpWithNewOpAndCopyNamehint<ConcatOp>(
983 rewriter, op, op.getType(), concatOperands);
// and(extract(x), mask) where the mask has leading and/or trailing zeros:
// narrow the extract to the live bit range and pad with zeros. The first
// part of the guard is on a line not shown.
990 if (
auto extractOp = inputs[0].getDefiningOp<ExtractOp>()) {
993 (value.countLeadingZeros() || value.countTrailingZeros())) {
994 unsigned lz = value.countLeadingZeros();
995 unsigned tz = value.countTrailingZeros();
998 auto smallTy = rewriter.getIntegerType(value.getBitWidth() - lz - tz);
999 Value smallElt = rewriter.createOrFold<
ExtractOp>(
1000 extractOp.getLoc(), smallTy, extractOp->getOperand(0),
1001 extractOp.getLowBit() + tz);
// If the remaining mask still has holes, keep an and with the narrowed mask.
1003 APInt smallMask = value.extractBits(smallTy.getWidth(), tz);
1004 if (!smallMask.isAllOnes()) {
1005 auto loc = inputs.back().getLoc();
1006 smallElt = rewriter.createOrFold<
AndOp>(
1013 SmallVector<Value> resultElts;
// Zero padding above and below the live bits (any guards are not shown).
1015 resultElts.push_back(
1016 rewriter.create<
hw::ConstantOp>(op.getLoc(), APInt::getZero(lz)));
1017 resultElts.push_back(smallElt);
1019 resultElts.push_back(
1020 rewriter.create<
hw::ConstantOp>(op.getLoc(), APInt::getZero(tz)));
1021 replaceOpWithNewOpAndCopyNamehint<ConcatOp>(rewriter, op, resultElts);
// Scan the non-last operands for a concat; the rewrite applied to it is on
// lines not shown.
1029 for (
size_t i = 0; i < size - 1; ++i) {
1030 if (
auto concat = inputs[i].getDefiningOp<ConcatOp>())
// Replaces the and with `source == all-ones` using a constant of `size`
// bits. How `source` is obtained is on lines not shown.
1043 rewriter.create<
hw::ConstantOp>(op.getLoc(), APInt::getAllOnes(size));
1044 replaceOpWithNewOpAndCopyNamehint<ICmpOp>(rewriter, op, ICmpPredicate::eq,
1045 source, cmpAgainst);
// Folder for OrOp. Several lines (guards inside the loops and most returns)
// are not visible in this excerpt.
1053OpFoldResult OrOp::fold(FoldAdaptor adaptor) {
// Accumulator starts at zero, the identity for bitwise or.
1054 auto value = APInt::getZero(cast<IntegerType>(getType()).
getWidth());
1055 auto inputs = adaptor.getInputs();
// Folds constant operands into `value`; once it is all-ones the result is
// decided (the action taken is on a line not shown).
1057 for (
auto operand : inputs) {
1060 value |= cast<IntegerAttr>(operand).getValue();
1061 if (value.isAllOnes())
// or(x, 0) -> x
1066 if (inputs.size() == 2 && inputs[1] &&
1067 cast<IntegerAttr>(inputs[1]).getValue().isZero())
1068 return getInputs()[0];
// or(x, x, ..., x) -> x
1071 if (llvm::all_of(getInputs(),
1072 [&](
auto in) {
return in == this->getInputs()[0]; }))
1073 return getInputs()[0];
// or(..., x, ..., ~x, ...): when one operand is the complement of another,
// an all-ones constant of the result width is produced.
1076 for (Value arg : getInputs()) {
1078 if (matchPattern(arg,
m_Complement(m_Any(&subExpr)))) {
1079 for (Value arg2 : getInputs())
1080 if (arg2 == subExpr)
1082 APInt::getAllOnes(cast<IntegerType>(getType()).
getWidth()),
// A further case also yielding all-ones; its condition is on lines not shown.
1092 APInt::getAllOnes(cast<IntegerType>(getType()).
getWidth()),
// Canonicalizer for OrOp. Declarations of `value`/`value2`, several guards,
// and all returns are on lines not shown in this excerpt.
1099LogicalResult OrOp::canonicalize(
OrOp op, PatternRewriter &rewriter) {
1100 auto inputs = op.getInputs();
1101 auto size = inputs.size();
1113 assert(size > 1 &&
"expected 2 or more operands");
// Rewrites that apply when the last operand is a constant.
1117 if (matchPattern(inputs.back(), m_ConstantInt(&value))) {
// or(..., 0) -> or(...)
1119 if (value.isZero()) {
1120 replaceOpWithNewOpAndCopyNamehint<OrOp>(rewriter, op, op.getType(),
1121 inputs.drop_back());
// or(..., c1, c2) -> or(..., c1 | c2)
1127 if (matchPattern(inputs[size - 2], m_ConstantInt(&value2))) {
1128 auto cst = rewriter.create<
hw::ConstantOp>(op.getLoc(), value | value2);
1129 SmallVector<Value, 4> newOperands(inputs.drop_back(2));
1130 newOperands.push_back(cst);
1131 replaceOpWithNewOpAndCopyNamehint<OrOp>(rewriter, op, op.getType(),
// Scan the non-last operands for a concat; the rewrite applied to it is on
// lines not shown.
1139 for (
size_t i = 0; i < size - 1; ++i) {
1140 if (
auto concat = inputs[i].getDefiningOp<ConcatOp>())
// Replaces the or with `source != 0` using a zero constant of `size` bits.
// How `source` is obtained is on lines not shown.
1153 rewriter.create<
hw::ConstantOp>(op.getLoc(), APInt::getZero(size));
1154 replaceOpWithNewOpAndCopyNamehint<ICmpOp>(rewriter, op, ICmpPredicate::ne,
1155 source, cmpAgainst);
// or(mux(c0, t, f), mux(c1, t, f), ...) with identical true/false values:
// becomes mux(or(c0, c1, ...), t, f). Requires the or and the muxes to be
// two-state and the false value to be a constant; one more clause of the
// guard is on a line not shown.
1161 if (
auto firstMux = op.getOperand(0).getDefiningOp<
comb::MuxOp>()) {
1163 if (op.getTwoState() && firstMux.getTwoState() &&
1164 matchPattern(firstMux.getFalseValue(), m_ConstantInt(&value)) &&
1166 SmallVector<Value> conditions{firstMux.getCond()};
// Predicate applied to each remaining operand; it also records the mux
// condition as a side effect. The definition of `mux` is on lines not shown.
1167 auto check = [&](Value v) {
1171 conditions.push_back(mux.getCond());
1172 return mux.getTwoState() &&
1173 firstMux.getTrueValue() == mux.getTrueValue() &&
1174 firstMux.getFalseValue() == mux.getFalseValue();
1176 if (llvm::all_of(op.getOperands().drop_front(), check)) {
1177 auto cond = rewriter.create<
comb::OrOp>(op.getLoc(), conditions,
true);
1178 replaceOpWithNewOpAndCopyNamehint<comb::MuxOp>(
1179 rewriter, op, cond, firstMux.getTrueValue(),
1180 firstMux.getFalseValue(),
true);
// Folder for XorOp. Some guards and returns are on lines not shown in this
// excerpt.
1190OpFoldResult XorOp::fold(FoldAdaptor adaptor) {
1191 auto size = getInputs().size();
1192 auto inputs = adaptor.getInputs();
// Returns the first operand; the guard is on a line not shown (presumably
// the single-operand case — confirm).
1196 return getInputs()[0];
// xor(x, x) -> 0
1199 if (size == 2 && getInputs()[0] == getInputs()[1])
1200 return IntegerAttr::get(getType(), 0);
// xor(x, 0) -> x
1203 if (inputs.size() == 2 && inputs[1] &&
1204 cast<IntegerAttr>(inputs[1]).getValue().isZero())
1205 return getInputs()[0];
// Double negation: if this op is a binary "not" of another complement, and
// the inner value is not this op's own result, the inner value is used (the
// return itself is on a line not shown).
1209 if (isBinaryNot()) {
1211 if (matchPattern(getOperand(0),
m_Complement(m_Any(&subExpr))) &&
1212 subExpr != getResult())
1222 PatternRewriter &rewriter) {
1223 auto icmp = op.getOperand(icmpOperand).getDefiningOp<ICmpOp>();
1224 auto negatedPred = ICmpOp::getNegatedPredicate(icmp.getPredicate());
1227 rewriter.create<ICmpOp>(icmp.getLoc(), negatedPred, icmp.getOperand(0),
1228 icmp.getOperand(1), icmp.getTwoState());
1231 if (op.getNumOperands() > 2) {
1232 SmallVector<Value, 4> newOperands(op.getOperands());
1233 newOperands.pop_back();
1234 newOperands.erase(newOperands.begin() + icmpOperand);
1235 newOperands.push_back(result);
1236 result = rewriter.create<
XorOp>(op.getLoc(), newOperands, op.getTwoState());
// Canonicalizer for XorOp. Declarations of `value`/`value2`, several guards,
// and all returns are on lines not shown in this excerpt.
1242LogicalResult XorOp::canonicalize(
XorOp op, PatternRewriter &rewriter) {
1243 auto inputs = op.getInputs();
1244 auto size = inputs.size();
1245 assert(size > 1 &&
"expected 2 or more operands");
// xor(..., x, x) -> xor(...): the two trailing equal operands cancel.
1248 if (inputs[size - 1] == inputs[size - 2]) {
1250 "expected idempotent case for 2 elements handled already.");
1251 replaceOpWithNewOpAndCopyNamehint<XorOp>(rewriter, op, op.getType(),
1252 inputs.drop_back(2),
false);
// Rewrites that apply when the last operand is a constant.
1258 if (matchPattern(inputs.back(), m_ConstantInt(&value))) {
// xor(..., 0) -> xor(...)
1260 if (value.isZero()) {
1261 replaceOpWithNewOpAndCopyNamehint<XorOp>(rewriter, op, op.getType(),
1262 inputs.drop_back(),
false);
// xor(..., c1, c2) -> xor(..., c1 ^ c2)
1268 if (matchPattern(inputs[size - 2], m_ConstantInt(&value2))) {
1269 auto cst = rewriter.create<
hw::ConstantOp>(op.getLoc(), value ^ value2);
1270 SmallVector<Value, 4> newOperands(inputs.drop_back(2));
1271 newOperands.push_back(cst);
1272 replaceOpWithNewOpAndCopyNamehint<XorOp>(rewriter, op, op.getType(),
1273 newOperands,
false);
1277 bool isSingleBit = value.getBitWidth() == 1;
// Examine every operand except the trailing constant.
1280 for (
size_t i = 0; i < size - 1; ++i) {
1281 Value operand = inputs[i];
// Single-bit xor with the constant 1 applied to a one-use icmp; the rewrite
// itself is on lines not shown.
1292 if (isSingleBit && operand.hasOneUse()) {
1293 assert(value == 1 &&
"single bit constant has to be one if not zero");
1294 if (
auto icmp = operand.getDefiningOp<ICmpOp>())
// Replaces the xor with a ParityOp over `source`. How `source` is obtained
// is on lines not shown.
1310 replaceOpWithNewOpAndCopyNamehint<ParityOp>(rewriter, op, source);
// Folder for SubOp. Several returns and one declaration (`negOne`) are on
// lines not shown in this excerpt.
1317OpFoldResult SubOp::fold(FoldAdaptor adaptor) {
// sub(x, x): produces a zero of the operand width.
1319 if (getRhs() == getLhs())
1321 APInt::getZero(getLhs().getType().getIntOrFloatBitWidth()),
1324 if (adaptor.getRhs()) {
// Both operands are attributes: express lhs - rhs as the parameter
// expression lhs + (rhs * -1), with -1 built as an all-ones value of the
// operand width.
1326 if (adaptor.getLhs()) {
1329 APInt::getAllOnes(getLhs().getType().getIntOrFloatBitWidth()),
1331 auto rhsNeg = hw::ParamExprAttr::get(
1332 hw::PEO::Mul, cast<TypedAttr>(adaptor.getRhs()), negOne);
1333 return hw::ParamExprAttr::get(hw::PEO::Add,
1334 cast<TypedAttr>(adaptor.getLhs()), rhsNeg);
// Only rhs is constant: the zero case is special (its action is not shown).
1338 if (
auto rhsC = dyn_cast<IntegerAttr>(adaptor.getRhs())) {
1339 if (rhsC.getValue().isZero())
// Canonicalizer for SubOp: sub(x, c) with constant c becomes add(x, -c).
// The declaration of `value`, the final argument of the replacement call,
// and the returns are on lines not shown in this excerpt.
1347LogicalResult SubOp::canonicalize(
SubOp op, PatternRewriter &rewriter) {
1350 if (matchPattern(op.getRhs(), m_ConstantInt(&value))) {
1351 auto negCst = rewriter.create<
hw::ConstantOp>(op.getLoc(), -value);
1352 replaceOpWithNewOpAndCopyNamehint<AddOp>(rewriter, op, op.getLhs(), negCst,
// Folder for AddOp. Only one return is visible in this excerpt: it yields
// the first operand. Its guard is on a line not shown (presumably the
// single-operand case — confirm).
1364OpFoldResult AddOp::fold(FoldAdaptor adaptor) {
1365 auto size = getInputs().size();
1369 return getInputs()[0];
// Canonicalizer for AddOp. Each rewrite looks at the trailing operands.
// Returns and a few declarations are on lines not shown in this excerpt.
1375LogicalResult AddOp::canonicalize(
AddOp op, PatternRewriter &rewriter) {
1376 auto inputs = op.getInputs();
1377 auto size = inputs.size();
1378 assert(size > 1 &&
"expected 2 or more operands");
1380 APInt value, value2;
// add(..., 0) -> add(...)
1383 if (matchPattern(inputs.back(), m_ConstantInt(&value)) && value.isZero()) {
1384 replaceOpWithNewOpAndCopyNamehint<AddOp>(rewriter, op, op.getType(),
1385 inputs.drop_back(),
false);
// add(..., c1, c2) -> add(..., c1 + c2)
1390 if (matchPattern(inputs[size - 1], m_ConstantInt(&value)) &&
1391 matchPattern(inputs[size - 2], m_ConstantInt(&value2))) {
1392 auto cst = rewriter.create<
hw::ConstantOp>(op.getLoc(), value + value2);
1393 SmallVector<Value, 4> newOperands(inputs.drop_back(2));
1394 newOperands.push_back(cst);
1395 replaceOpWithNewOpAndCopyNamehint<AddOp>(rewriter, op, op.getType(),
1396 newOperands,
false);
// add(..., x, x): the duplicated operand is replaced by `shiftLeftOp`, built
// on a line not shown using the constant one created here.
1401 if (inputs[size - 1] == inputs[size - 2]) {
1402 SmallVector<Value, 4> newOperands(inputs.drop_back(2));
1404 auto one = rewriter.create<
hw::ConstantOp>(op.getLoc(), op.getType(), 1);
1408 newOperands.push_back(shiftLeftOp);
1409 replaceOpWithNewOpAndCopyNamehint<AddOp>(rewriter, op, op.getType(),
1410 newOperands,
false);
// add(..., x, shl(x, c)) -> add(..., mul(x, (1 << c) + 1))
1414 auto shlOp = inputs[size - 1].getDefiningOp<
comb::ShlOp>();
1416 if (shlOp && shlOp.getLhs() == inputs[size - 2] &&
1417 matchPattern(shlOp.getRhs(), m_ConstantInt(&value))) {
1419 APInt one(value.getBitWidth(), 1,
false);
1421 rewriter.create<
hw::ConstantOp>(op.getLoc(), (one << value) + one);
1423 std::array<Value, 2> factors = {shlOp.getLhs(), rhs};
1424 auto mulOp = rewriter.create<
comb::MulOp>(op.getLoc(), factors,
false);
1426 SmallVector<Value, 4> newOperands(inputs.drop_back(2));
1427 newOperands.push_back(mulOp);
1428 replaceOpWithNewOpAndCopyNamehint<AddOp>(rewriter, op, op.getType(),
1429 newOperands,
false);
// add(..., x, mul(x, c)) -> add(..., mul(x, c + 1)); `newMulOp` is created
// on a line not shown.
1433 auto mulOp = inputs[size - 1].getDefiningOp<
comb::MulOp>();
1435 if (mulOp && mulOp.getInputs().size() == 2 &&
1436 mulOp.getInputs()[0] == inputs[size - 2] &&
1437 matchPattern(mulOp.getInputs()[1], m_ConstantInt(&value))) {
1439 APInt one(value.getBitWidth(), 1,
false);
1440 auto rhs = rewriter.create<
hw::ConstantOp>(op.getLoc(), value + one);
1441 std::array<Value, 2> factors = {mulOp.getInputs()[0], rhs};
1444 SmallVector<Value, 4> newOperands(inputs.drop_back(2));
1445 newOperands.push_back(newMulOp);
1446 replaceOpWithNewOpAndCopyNamehint<AddOp>(rewriter, op, op.getType(),
1447 newOperands,
false);
// add(add(x, c1), c2) -> add(x, c1 + c2). The result is two-state only when
// both the outer and the inner add are two-state.
1460 auto addOp = inputs[0].getDefiningOp<
comb::AddOp>();
1461 if (addOp && addOp.getInputs().size() == 2 &&
1462 matchPattern(addOp.getInputs()[1], m_ConstantInt(&value2)) &&
1463 inputs.size() == 2 && matchPattern(inputs[1], m_ConstantInt(&value))) {
1465 auto rhs = rewriter.create<
hw::ConstantOp>(op.getLoc(), value + value2);
1466 replaceOpWithNewOpAndCopyNamehint<AddOp>(
1467 rewriter, op, op.getType(), ArrayRef<Value>{addOp.getInputs()[0], rhs},
1468 op.getTwoState() && addOp.getTwoState());
// Folder for MulOp. Guards and the final return are on lines not shown in
// this excerpt.
1475OpFoldResult MulOp::fold(FoldAdaptor adaptor) {
1476 auto size = getInputs().size();
1477 auto inputs = adaptor.getInputs();
// Returns the first operand; the guard is on a line not shown (presumably
// the single-operand case — confirm).
1481 return getInputs()[0];
1483 auto width = cast<IntegerType>(getType()).getWidth();
// Returns a zero-width zero constant; the guard is on a line not shown
// (presumably `width == 0` — confirm).
1485 return getIntAttr(APInt::getZero(0), getContext());
// Accumulator starts at one, the identity for multiplication.
1487 APInt value(width, 1,
false);
// Multiplies constant operands into `value`; the handling of non-constant
// operands is on lines not shown.
1490 for (
auto operand : inputs) {
1493 value *= cast<IntegerAttr>(operand).getValue();
// Canonicalizer for MulOp. Returns and the creation of `shlOp` are on lines
// not shown in this excerpt.
1502LogicalResult MulOp::canonicalize(
MulOp op, PatternRewriter &rewriter) {
1503 auto inputs = op.getInputs();
1504 auto size = inputs.size();
1505 assert(size > 1 &&
"expected 2 or more operands");
1507 APInt value, value2;
// mul(x, 2^n): builds a constant n (exact log2 of the multiplier) to be used
// as a shift amount; the op is then replaced by a one-operand mul over
// `shlOp`.
1510 if (size == 2 && matchPattern(inputs.back(), m_ConstantInt(&value)) &&
1511 value.isPowerOf2()) {
1512 auto shift = rewriter.create<
hw::ConstantOp>(op.getLoc(), op.getType(),
1513 value.exactLogBase2());
1517 replaceOpWithNewOpAndCopyNamehint<MulOp>(rewriter, op, op.getType(),
1518 ArrayRef<Value>(shlOp),
false);
// mul(..., 1) -> mul(...)
1523 if (matchPattern(inputs.back(), m_ConstantInt(&value)) && value.isOne()) {
1524 replaceOpWithNewOpAndCopyNamehint<MulOp>(rewriter, op, op.getType(),
1525 inputs.drop_back());
// mul(..., c1, c2) -> mul(..., c1 * c2)
1530 if (matchPattern(inputs[size - 1], m_ConstantInt(&value)) &&
1531 matchPattern(inputs[size - 2], m_ConstantInt(&value2))) {
1532 auto cst = rewriter.create<
hw::ConstantOp>(op.getLoc(), value * value2);
1533 SmallVector<Value, 4> newOperands(inputs.drop_back(2));
1534 newOperands.push_back(cst);
1535 replaceOpWithNewOpAndCopyNamehint<MulOp>(rewriter, op, op.getType(),
// Shared folder for the division ops, parameterized on the op type and on
// signedness. `constants` holds the constant-folded operands; index 1 is the
// divisor. The actions taken for the two special divisors below are on lines
// not shown in this excerpt.
1551template <
class Op,
bool isSigned>
1552static OpFoldResult
foldDiv(Op op, ArrayRef<Attribute> constants) {
1553 if (
auto rhsValue = dyn_cast_or_null<IntegerAttr>(constants[1])) {
// Division by one.
1555 if (rhsValue.getValue() == 1)
// Division by zero.
1559 if (rhsValue.getValue().isZero())
// Folder for DivUOp: delegates to foldDiv with isSigned = false.
1566OpFoldResult DivUOp::fold(FoldAdaptor adaptor) {
1567 return foldDiv<
DivUOp,
false>(*
this, adaptor.getOperands());
1570OpFoldResult DivSOp::fold(FoldAdaptor adaptor) {
// Shared folder for the modulo ops, parameterized on the op type and on
// signedness. `constants` holds the constant-folded operands: index 0 is the
// dividend and index 1 the divisor. The second argument of each getIntAttr
// call and the remaining returns are on lines not shown in this excerpt.
1574template <
class Op,
bool isSigned>
1575static OpFoldResult
foldMod(Op op, ArrayRef<Attribute> constants) {
1576 if (
auto rhsValue = dyn_cast_or_null<IntegerAttr>(constants[1])) {
// x mod 1 -> 0 (of the result width)
1578 if (rhsValue.getValue() == 1)
1579 return getIntAttr(APInt::getZero(op.getType().getIntOrFloatBitWidth()),
// Modulo by zero: the action taken is on a line not shown.
1583 if (rhsValue.getValue().isZero())
1587 if (
auto lhsValue = dyn_cast_or_null<IntegerAttr>(constants[0])) {
// 0 mod x -> 0 (of the result width)
1589 if (lhsValue.getValue().isZero())
1590 return getIntAttr(APInt::getZero(op.getType().getIntOrFloatBitWidth()),
// Folder for ModUOp: delegates to foldMod with isSigned = false.
1597OpFoldResult ModUOp::fold(FoldAdaptor adaptor) {
1598 return foldMod<
ModUOp,
false>(*
this, adaptor.getOperands());
1601OpFoldResult ModSOp::fold(FoldAdaptor adaptor) {
// Folder for ConcatOp. The bail-out taken for non-constant operands and the
// final return are on lines not shown in this excerpt.
1609OpFoldResult ConcatOp::fold(FoldAdaptor adaptor) {
// concat(x) -> x
1610 if (getNumOperands() == 1)
1611 return getOperand(0);
// Every operand must be a known integer constant for the rest to apply.
1614 for (
auto attr : adaptor.getInputs())
1615 if (!attr || !isa<IntegerAttr>(attr))
// Assemble the result constant. Operands are visited in order while the
// insertion offset counts down from the result width, so the first operand
// lands in the most significant bits.
1619 unsigned resultWidth = getType().getIntOrFloatBitWidth();
1620 APInt result(resultWidth, 0);
1622 unsigned nextInsertion = resultWidth;
1624 for (
auto attr : adaptor.getInputs()) {
1625 auto chunk = cast<IntegerAttr>(attr).getValue();
1626 nextInsertion -= chunk.getBitWidth();
1627 result.insertBits(chunk, nextInsertion);
1633LogicalResult ConcatOp::canonicalize(
ConcatOp op, PatternRewriter &rewriter) {
1634 auto inputs = op.getInputs();
1635 auto size = inputs.size();
1636 assert(size > 1 &&
"expected 2 or more operands");
1641 auto flattenConcat = [&](
size_t firstOpIndex,
size_t lastOpIndex,
1642 ValueRange replacements) -> LogicalResult {
1643 SmallVector<Value, 4> newOperands;
1644 newOperands.append(inputs.begin(), inputs.begin() + firstOpIndex);
1645 newOperands.append(replacements.begin(), replacements.end());
1646 newOperands.append(inputs.begin() + lastOpIndex + 1, inputs.end());
1647 if (newOperands.size() == 1)
1650 replaceOpWithNewOpAndCopyNamehint<ConcatOp>(rewriter, op, op.getType(),
1655 Value commonOperand = inputs[0];
1656 for (
size_t i = 0; i != size; ++i) {
1658 if (inputs[i] != commonOperand)
1659 commonOperand = Value();
1663 if (
auto subConcat = inputs[i].getDefiningOp<ConcatOp>())
1664 return flattenConcat(i, i, subConcat->getOperands());
1669 if (
auto cst = inputs[i].getDefiningOp<hw::ConstantOp>()) {
1670 if (
auto prevCst = inputs[i - 1].getDefiningOp<hw::ConstantOp>()) {
1671 unsigned prevWidth = prevCst.getValue().getBitWidth();
1672 unsigned thisWidth = cst.getValue().getBitWidth();
1673 auto resultCst = cst.getValue().zext(prevWidth + thisWidth);
1674 resultCst |= prevCst.getValue().zext(prevWidth + thisWidth)
1678 return flattenConcat(i - 1, i, replacement);
1683 if (inputs[i] == inputs[i - 1]) {
1685 rewriter.createOrFold<ReplicateOp>(op.getLoc(), inputs[i], 2);
1686 return flattenConcat(i - 1, i, replacement);
1691 if (
auto repl = inputs[i].getDefiningOp<ReplicateOp>()) {
1693 if (repl.getOperand() == inputs[i - 1]) {
1694 Value replacement = rewriter.createOrFold<ReplicateOp>(
1695 op.getLoc(), repl.getOperand(), repl.getMultiple() + 1);
1696 return flattenConcat(i - 1, i, replacement);
1699 if (
auto prevRepl = inputs[i - 1].getDefiningOp<ReplicateOp>()) {
1700 if (prevRepl.getOperand() == repl.getOperand()) {
1701 Value replacement = rewriter.createOrFold<ReplicateOp>(
1702 op.getLoc(), repl.getOperand(),
1703 repl.getMultiple() + prevRepl.getMultiple());
1704 return flattenConcat(i - 1, i, replacement);
1710 if (
auto repl = inputs[i - 1].getDefiningOp<ReplicateOp>()) {
1711 if (repl.getOperand() == inputs[i]) {
1712 Value replacement = rewriter.createOrFold<ReplicateOp>(
1713 op.getLoc(), inputs[i], repl.getMultiple() + 1);
1714 return flattenConcat(i - 1, i, replacement);
1720 if (
auto extract = inputs[i].getDefiningOp<ExtractOp>()) {
1721 if (
auto prevExtract = inputs[i - 1].getDefiningOp<ExtractOp>()) {
1722 if (extract.getInput() == prevExtract.getInput()) {
1723 auto thisWidth = cast<IntegerType>(extract.getType()).getWidth();
1724 if (prevExtract.getLowBit() == extract.getLowBit() + thisWidth) {
1725 auto prevWidth = prevExtract.getType().getIntOrFloatBitWidth();
1726 auto resType = rewriter.getIntegerType(thisWidth + prevWidth);
1727 Value replacement = rewriter.create<
ExtractOp>(
1728 op.getLoc(), resType, extract.getInput(),
1729 extract.getLowBit());
1730 return flattenConcat(i - 1, i, replacement);
1743 static std::optional<ArraySlice>
get(Value value) {
1744 assert(isa<IntegerType>(value.getType()) &&
"expected integer type");
1746 return ArraySlice{arrayGet.getInput(), arrayGet.getIndex(), 1};
1749 if (
auto arraySlice =
1752 arraySlice.getInput(), arraySlice.getLowIndex(),
1753 hw::type_cast<hw::ArrayType>(arraySlice.getType())
1755 return std::nullopt;
1758 if (
auto extractOpt = ArraySlice::get(inputs[i])) {
1759 if (
auto prevExtractOpt = ArraySlice::get(inputs[i - 1])) {
1761 if (prevExtractOpt->index.getType() == extractOpt->index.getType() &&
1762 prevExtractOpt->input == extractOpt->input &&
1763 hw::isOffset(extractOpt->index, prevExtractOpt->index,
1764 extractOpt->width)) {
1765 auto resType = hw::ArrayType::get(
1766 hw::type_cast<hw::ArrayType>(prevExtractOpt->input.getType())
1768 extractOpt->width + prevExtractOpt->width);
1769 auto resIntType = rewriter.getIntegerType(hw::getBitWidth(resType));
1771 op.getLoc(), resIntType,
1773 prevExtractOpt->input,
1774 extractOpt->index));
1775 return flattenConcat(i - 1, i, replacement);
1783 if (commonOperand) {
1784 replaceOpWithNewOpAndCopyNamehint<ReplicateOp>(rewriter, op, op.getType(),
1796OpFoldResult MuxOp::fold(FoldAdaptor adaptor) {
1798 if (getTrueValue() == getFalseValue() && getTrueValue() != getResult())
1799 return getTrueValue();
1800 if (
auto tv = adaptor.getTrueValue())
1801 if (tv == adaptor.getFalseValue())
1806 if (
auto pred = dyn_cast_or_null<IntegerAttr>(adaptor.getCond())) {
1807 if (pred.getValue().isZero())
1808 return getFalseValue();
1809 return getTrueValue();
1813 if (
auto tv = dyn_cast_or_null<IntegerAttr>(adaptor.getTrueValue()))
1814 if (
auto fv = dyn_cast_or_null<IntegerAttr>(adaptor.getFalseValue()))
1815 if (tv.getValue().isOne() && fv.getValue().isZero() &&
1816 hw::getBitWidth(getType()) == 1)
1832 if (
auto cmp = cond.getDefiningOp<ICmpOp>()) {
1834 auto requiredPredicate =
1835 (isInverted ? ICmpPredicate::eq : ICmpPredicate::ne);
1836 if (cmp.getLhs() == indexValue && cmp.getPredicate() == requiredPredicate) {
1846 if (
auto orOp = cond.getDefiningOp<
OrOp>()) {
1849 for (
auto operand : orOp.getOperands())
1856 if (
auto andOp = cond.getDefiningOp<
AndOp>()) {
1859 for (
auto operand : andOp.getOperands())
1877 PatternRewriter &rewriter) {
1880 auto rootCmp = rootMux.getCond().getDefiningOp<ICmpOp>();
1883 Value indexValue = rootCmp.getLhs();
1886 auto getCaseValue = [&](
MuxOp mux) -> Value {
1887 return mux.getOperand(1 +
unsigned(!isFalseSide));
1892 auto getTreeValue = [&](
MuxOp mux) -> Value {
1893 return mux.getOperand(1 +
unsigned(isFalseSide));
1898 SmallVector<Location> locationsFound;
1899 SmallVector<std::pair<hw::ConstantOp, Value>, 4> valuesFound;
1903 auto collectConstantValues = [&](
MuxOp mux) ->
bool {
1905 mux.getCond(), indexValue, isFalseSide, [&](
hw::ConstantOp cst) {
1906 valuesFound.push_back({cst, getCaseValue(mux)});
1907 locationsFound.push_back(mux.getCond().getLoc());
1908 locationsFound.push_back(mux->getLoc());
1913 if (!collectConstantValues(rootMux))
1917 if (rootMux->hasOneUse()) {
1918 if (
auto userMux = dyn_cast<MuxOp>(*rootMux->user_begin())) {
1919 if (getTreeValue(userMux) == rootMux.getResult() &&
1927 auto nextTreeValue = getTreeValue(rootMux);
1929 auto nextMux = nextTreeValue.getDefiningOp<
MuxOp>();
1930 if (!nextMux || !nextMux->hasOneUse())
1932 if (!collectConstantValues(nextMux))
1934 nextTreeValue = getTreeValue(nextMux);
1940 if (valuesFound.size() < 3)
1945 auto indexWidth = cast<IntegerType>(indexValue.getType()).getWidth();
1946 if (indexWidth >= 9)
1952 uint64_t tableSize = 1ULL << indexWidth;
1953 if (valuesFound.size() < (tableSize * 5) / 8)
1958 SmallVector<Value, 8> table(tableSize, nextTreeValue);
1963 for (
auto &elt :
llvm::reverse(valuesFound)) {
1964 uint64_t idx = elt.first.getValue().getZExtValue();
1965 assert(idx < table.size() &&
"constant should be same bitwidth as index");
1966 table[idx] = elt.second;
1971 std::reverse(table.begin(), table.end());
1974 auto fusedLoc = rewriter.getFusedLoc(locationsFound);
1976 replaceOpWithNewOpAndCopyNamehint<hw::ArrayGetOp>(rewriter, rootMux, array,
1991 PatternRewriter &rewriter) {
1992 assert(fullyAssoc->getNumOperands() >= 2 &&
"cannot split up unary ops");
1993 assert(operandNo < fullyAssoc->getNumOperands() &&
"Invalid operand #");
1997 if (fullyAssoc->getNumOperands() == 2)
1998 return fullyAssoc->getOperand(operandNo ^ 1);
2001 if (fullyAssoc->hasOneUse()) {
2002 rewriter.modifyOpInPlace(fullyAssoc,
2003 [&]() { fullyAssoc->eraseOperand(operandNo); });
2004 return fullyAssoc->getResult(0);
2008 SmallVector<Value> operands;
2009 operands.append(fullyAssoc->getOperands().begin(),
2010 fullyAssoc->getOperands().begin() + operandNo);
2011 operands.append(fullyAssoc->getOperands().begin() + operandNo + 1,
2012 fullyAssoc->getOperands().end());
2014 fullyAssoc->getLoc(), fullyAssoc->getName(), operands, rewriter);
2015 Value excluded = fullyAssoc->getOperand(operandNo);
2019 ArrayRef<Value>{opWithoutExcluded, excluded}, rewriter);
2021 return opWithoutExcluded;
2031 PatternRewriter &rewriter) {
2034 Operation *subExpr =
2035 (isTrueOperand ? op.getFalseValue() : op.getTrueValue()).getDefiningOp();
2036 if (!subExpr || subExpr->getNumOperands() < 2)
2040 if (!isa<AndOp, XorOp, OrOp, MuxOp>(subExpr))
2045 Value commonValue = isTrueOperand ? op.getTrueValue() : op.getFalseValue();
2046 size_t opNo = 0, e = subExpr->getNumOperands();
2047 while (opNo != e && subExpr->getOperand(opNo) != commonValue)
2053 Value cond = op.getCond();
2059 if (
auto subMux = dyn_cast<MuxOp>(subExpr)) {
2064 Value subCond = subMux.getCond();
2067 if (subMux.getTrueValue() == commonValue)
2068 otherValue = subMux.getFalseValue();
2069 else if (subMux.getFalseValue() == commonValue) {
2070 otherValue = subMux.getTrueValue();
2080 cond = rewriter.createOrFold<
OrOp>(op.getLoc(), cond, subCond,
false);
2081 replaceOpWithNewOpAndCopyNamehint<MuxOp>(rewriter, op, cond, commonValue,
2082 otherValue, op.getTwoState());
2088 bool isaAndOp = isa<AndOp>(subExpr);
2089 if (isTrueOperand ^ isaAndOp)
2093 rewriter.createOrFold<ReplicateOp>(op.getLoc(), op.getType(), cond);
2096 bool isaXorOp = isa<XorOp>(subExpr);
2097 bool isaOrOp = isa<OrOp>(subExpr);
2106 if (isaOrOp || isaXorOp) {
2107 auto masked = rewriter.createOrFold<
AndOp>(op.getLoc(), extendedCond,
2108 restOfAssoc,
false);
2110 replaceOpWithNewOpAndCopyNamehint<XorOp>(rewriter, op, masked,
2111 commonValue,
false);
2113 replaceOpWithNewOpAndCopyNamehint<OrOp>(rewriter, op, masked, commonValue,
2119 assert(isaAndOp &&
"unexpected operation here");
2120 auto masked = rewriter.createOrFold<
OrOp>(op.getLoc(), extendedCond,
2121 restOfAssoc,
false);
2122 replaceOpWithNewOpAndCopyNamehint<AndOp>(rewriter, op, masked, commonValue,
2133 PatternRewriter &rewriter) {
2136 if (!isa<ConcatOp>(trueOp))
2140 SmallVector<Value> trueOperands, falseOperands;
2144 size_t numTrueOperands = trueOperands.size();
2145 size_t numFalseOperands = falseOperands.size();
2147 if (!numTrueOperands || !numFalseOperands ||
2148 (trueOperands.front() != falseOperands.front() &&
2149 trueOperands.back() != falseOperands.back()))
2153 if (trueOperands.front() == falseOperands.front()) {
2154 SmallVector<Value> operands;
2156 for (i = 0; i < numTrueOperands; ++i) {
2157 Value trueOperand = trueOperands[i];
2158 if (trueOperand == falseOperands[i])
2159 operands.push_back(trueOperand);
2163 if (i == numTrueOperands) {
2170 if (llvm::all_of(operands, [&](Value v) {
return v == operands.front(); }))
2171 sharedMSB = rewriter.createOrFold<ReplicateOp>(
2172 mux->getLoc(), operands.front(), operands.size());
2174 sharedMSB = rewriter.createOrFold<
ConcatOp>(mux->getLoc(), operands);
2178 operands.append(trueOperands.begin() + i, trueOperands.end());
2179 Value trueLSB = rewriter.createOrFold<
ConcatOp>(trueOp->getLoc(), operands);
2181 operands.append(falseOperands.begin() + i, falseOperands.end());
2183 rewriter.createOrFold<
ConcatOp>(falseOp->getLoc(), operands);
2186 Value lsb = rewriter.createOrFold<
MuxOp>(
2187 mux->getLoc(), mux.getCond(), trueLSB, falseLSB, mux.getTwoState());
2188 replaceOpWithNewOpAndCopyNamehint<ConcatOp>(rewriter, mux, sharedMSB, lsb);
2193 if (trueOperands.back() == falseOperands.back()) {
2194 SmallVector<Value> operands;
2197 Value trueOperand = trueOperands[numTrueOperands - i - 1];
2198 if (trueOperand == falseOperands[numFalseOperands - i - 1])
2199 operands.push_back(trueOperand);
2203 std::reverse(operands.begin(), operands.end());
2204 Value sharedLSB = rewriter.createOrFold<
ConcatOp>(mux->getLoc(), operands);
2208 operands.append(trueOperands.begin(), trueOperands.end() - i);
2209 Value trueMSB = rewriter.createOrFold<
ConcatOp>(trueOp->getLoc(), operands);
2211 operands.append(falseOperands.begin(), falseOperands.end() - i);
2213 rewriter.createOrFold<
ConcatOp>(falseOp->getLoc(), operands);
2215 Value msb = rewriter.createOrFold<
MuxOp>(
2216 mux->getLoc(), mux.getCond(), trueMSB, falseMSB, mux.getTwoState());
2217 replaceOpWithNewOpAndCopyNamehint<ConcatOp>(rewriter, mux, msb, sharedLSB);
2229 if (!trueVec || !falseVec)
2231 if (!trueVec.isUniform() || !falseVec.isUniform())
2235 op.getLoc(), op.getCond(), trueVec.getUniformElement(),
2236 falseVec.getUniformElement(), op.getTwoState());
2238 SmallVector<Value> values(trueVec.getInputs().size(), mux);
2245 using OpRewritePattern::OpRewritePattern;
2247 LogicalResult matchAndRewrite(
MuxOp op,
2248 PatternRewriter &rewriter)
const override;
2251LogicalResult MuxRewriter::matchAndRewrite(
MuxOp op,
2252 PatternRewriter &rewriter)
const {
2258 if (matchPattern(op.getTrueValue(), m_ConstantInt(&value))) {
2259 if (value.getBitWidth() == 1) {
2261 if (value.isZero()) {
2263 replaceOpWithNewOpAndCopyNamehint<AndOp>(rewriter, op, notCond,
2264 op.getFalseValue(),
false);
2269 replaceOpWithNewOpAndCopyNamehint<OrOp>(rewriter, op, op.getCond(),
2270 op.getFalseValue(),
false);
2276 if (matchPattern(op.getFalseValue(), m_ConstantInt(&value2))) {
2281 APInt xorValue = value ^ value2;
2282 if (xorValue.isPowerOf2()) {
2283 unsigned leadingZeros = xorValue.countLeadingZeros();
2284 unsigned trailingZeros = value.getBitWidth() - leadingZeros - 1;
2285 SmallVector<Value, 3> operands;
2293 if (leadingZeros > 0)
2294 operands.push_back(rewriter.createOrFold<
ExtractOp>(
2295 op.getLoc(), op.getTrueValue(), trailingZeros + 1, leadingZeros));
2299 auto v1 = rewriter.createOrFold<
ExtractOp>(
2300 op.getLoc(), op.getTrueValue(), trailingZeros, 1);
2301 auto v2 = rewriter.createOrFold<
ExtractOp>(
2302 op.getLoc(), op.getFalseValue(), trailingZeros, 1);
2303 operands.push_back(rewriter.createOrFold<
MuxOp>(
2304 op.getLoc(), op.getCond(), v1, v2,
false));
2306 if (trailingZeros > 0)
2307 operands.push_back(rewriter.createOrFold<
ExtractOp>(
2308 op.getLoc(), op.getTrueValue(), 0, trailingZeros));
2310 replaceOpWithNewOpAndCopyNamehint<ConcatOp>(rewriter, op, op.getType(),
2317 if (value.isAllOnes() && value2.isZero()) {
2318 replaceOpWithNewOpAndCopyNamehint<ReplicateOp>(
2319 rewriter, op, op.getType(), op.getCond());
2325 if (matchPattern(op.getFalseValue(), m_ConstantInt(&value)) &&
2326 value.getBitWidth() == 1) {
2328 if (value.isZero()) {
2329 replaceOpWithNewOpAndCopyNamehint<AndOp>(rewriter, op, op.getCond(),
2330 op.getTrueValue(),
false);
2337 auto notCond = rewriter.createOrFold<
XorOp>(op.getLoc(), op.getCond(),
2338 op.getFalseValue(),
false);
2339 replaceOpWithNewOpAndCopyNamehint<OrOp>(rewriter, op, notCond,
2340 op.getTrueValue(),
false);
2346 Operation *condOp = op.getCond().getDefiningOp();
2347 if (condOp && matchPattern(condOp,
m_Complement(m_Any(&subExpr))) &&
2349 replaceOpWithNewOpAndCopyNamehint<MuxOp>(rewriter, op, op.getType(),
2350 subExpr, op.getFalseValue(),
2351 op.getTrueValue(),
true);
2358 if (condOp && condOp->hasOneUse()) {
2359 SmallVector<Value> invertedOperands;
2363 auto getInvertedOperands = [&]() ->
bool {
2364 for (Value operand : condOp->getOperands()) {
2365 if (matchPattern(operand,
m_Complement(m_Any(&subExpr))))
2366 invertedOperands.push_back(subExpr);
2373 if (isa<AndOp>(condOp) && getInvertedOperands()) {
2375 rewriter.createOrFold<
OrOp>(op.getLoc(), invertedOperands,
false);
2376 replaceOpWithNewOpAndCopyNamehint<MuxOp>(
2377 rewriter, op, newOr, op.getFalseValue(), op.getTrueValue(),
2381 if (isa<OrOp>(condOp) && getInvertedOperands()) {
2383 rewriter.createOrFold<
AndOp>(op.getLoc(), invertedOperands,
false);
2384 replaceOpWithNewOpAndCopyNamehint<MuxOp>(
2385 rewriter, op, newAnd, op.getFalseValue(), op.getTrueValue(),
2391 if (
auto falseMux = op.getFalseValue().getDefiningOp<
MuxOp>();
2392 falseMux && falseMux != op) {
2394 if (op.getCond() == falseMux.getCond()) {
2395 replaceOpWithNewOpAndCopyNamehint<MuxOp>(
2396 rewriter, op, op.getCond(), op.getTrueValue(),
2397 falseMux.getFalseValue(), op.getTwoStateAttr());
2406 if (
auto trueMux = op.getTrueValue().getDefiningOp<
MuxOp>();
2407 trueMux && trueMux != op) {
2409 if (op.getCond() == trueMux.getCond()) {
2410 replaceOpWithNewOpAndCopyNamehint<MuxOp>(
2411 rewriter, op, op.getCond(), trueMux.getTrueValue(),
2412 op.getFalseValue(), op.getTwoStateAttr());
2422 if (
auto trueMux = dyn_cast_or_null<MuxOp>(op.getTrueValue().getDefiningOp()),
2423 falseMux = dyn_cast_or_null<MuxOp>(op.getFalseValue().getDefiningOp());
2424 trueMux && falseMux && trueMux.getCond() == falseMux.getCond() &&
2425 trueMux.getTrueValue() == falseMux.getTrueValue() && trueMux != op &&
2427 auto subMux = rewriter.create<
MuxOp>(
2428 rewriter.getFusedLoc({trueMux.getLoc(), falseMux.getLoc()}),
2429 op.getCond(), trueMux.getFalseValue(), falseMux.getFalseValue());
2430 replaceOpWithNewOpAndCopyNamehint<MuxOp>(rewriter, op, trueMux.getCond(),
2431 trueMux.getTrueValue(), subMux,
2432 op.getTwoStateAttr());
2437 if (
auto trueMux = dyn_cast_or_null<MuxOp>(op.getTrueValue().getDefiningOp()),
2438 falseMux = dyn_cast_or_null<MuxOp>(op.getFalseValue().getDefiningOp());
2439 trueMux && falseMux && trueMux.getCond() == falseMux.getCond() &&
2440 trueMux.getFalseValue() == falseMux.getFalseValue() && trueMux != op &&
2442 auto subMux = rewriter.create<
MuxOp>(
2443 rewriter.getFusedLoc({trueMux.getLoc(), falseMux.getLoc()}),
2444 op.getCond(), trueMux.getTrueValue(), falseMux.getTrueValue());
2445 replaceOpWithNewOpAndCopyNamehint<MuxOp>(rewriter, op, trueMux.getCond(),
2446 subMux, trueMux.getFalseValue(),
2447 op.getTwoStateAttr());
2452 if (
auto trueMux = dyn_cast_or_null<MuxOp>(op.getTrueValue().getDefiningOp()),
2453 falseMux = dyn_cast_or_null<MuxOp>(op.getFalseValue().getDefiningOp());
2454 trueMux && falseMux &&
2455 trueMux.getTrueValue() == falseMux.getTrueValue() &&
2456 trueMux.getFalseValue() == falseMux.getFalseValue() && trueMux != op &&
2458 auto subMux = rewriter.create<
MuxOp>(
2459 rewriter.getFusedLoc(
2460 {op.getLoc(), trueMux.getLoc(), falseMux.getLoc()}),
2461 op.getCond(), trueMux.getCond(), falseMux.getCond());
2462 replaceOpWithNewOpAndCopyNamehint<MuxOp>(
2463 rewriter, op, subMux, trueMux.getTrueValue(), trueMux.getFalseValue(),
2464 op.getTwoStateAttr());
2476 if (Operation *trueOp = op.getTrueValue().getDefiningOp())
2477 if (Operation *falseOp = op.getFalseValue().getDefiningOp())
2478 if (trueOp->getName() == falseOp->getName())
2495 if (op.getInputs().empty() || op.isUniform())
2497 auto inputs = op.getInputs();
2498 if (inputs.size() <= 1)
2503 auto first = inputs[0].getDefiningOp<
comb::MuxOp>();
2508 for (
size_t i = 1, n = inputs.size(); i < n; ++i) {
2509 auto input = inputs[i].getDefiningOp<
comb::MuxOp>();
2510 if (!input || first.getCond() != input.getCond())
2515 SmallVector<Value> trues{first.getTrueValue()};
2516 SmallVector<Value> falses{first.getFalseValue()};
2517 SmallVector<Location> locs{first->getLoc()};
2518 bool isTwoState =
true;
2519 for (
size_t i = 1, n = inputs.size(); i < n; ++i) {
2520 auto input = inputs[i].getDefiningOp<
comb::MuxOp>();
2521 trues.push_back(input.getTrueValue());
2522 falses.push_back(input.getFalseValue());
2523 locs.push_back(input->getLoc());
2524 if (!input.getTwoState())
2529 auto loc = FusedLoc::get(op.getContext(), locs);
2533 auto arrayTy = op.getType();
2536 rewriter.replaceOpWithNewOp<
comb::MuxOp>(op, arrayTy, first.getCond(),
2537 trueValues, falseValues, isTwoState);
2542 using OpRewritePattern::OpRewritePattern;
2545 PatternRewriter &rewriter)
const override {
2546 if (foldArrayOfMuxes(op, rewriter))
2554void MuxOp::getCanonicalizationPatterns(RewritePatternSet &results,
2555 MLIRContext *context) {
2556 results.insert<MuxRewriter, ArrayRewriter>(context);
2567 switch (predicate) {
2568 case ICmpPredicate::eq:
2570 case ICmpPredicate::ne:
2572 case ICmpPredicate::slt:
2573 return lhs.slt(rhs);
2574 case ICmpPredicate::sle:
2575 return lhs.sle(rhs);
2576 case ICmpPredicate::sgt:
2577 return lhs.sgt(rhs);
2578 case ICmpPredicate::sge:
2579 return lhs.sge(rhs);
2580 case ICmpPredicate::ult:
2581 return lhs.ult(rhs);
2582 case ICmpPredicate::ule:
2583 return lhs.ule(rhs);
2584 case ICmpPredicate::ugt:
2585 return lhs.ugt(rhs);
2586 case ICmpPredicate::uge:
2587 return lhs.uge(rhs);
2588 case ICmpPredicate::ceq:
2590 case ICmpPredicate::cne:
2592 case ICmpPredicate::weq:
2594 case ICmpPredicate::wne:
2597 llvm_unreachable(
"unknown comparison predicate");
2603 switch (predicate) {
2604 case ICmpPredicate::eq:
2605 case ICmpPredicate::sle:
2606 case ICmpPredicate::sge:
2607 case ICmpPredicate::ule:
2608 case ICmpPredicate::uge:
2609 case ICmpPredicate::ceq:
2610 case ICmpPredicate::weq:
2612 case ICmpPredicate::ne:
2613 case ICmpPredicate::slt:
2614 case ICmpPredicate::sgt:
2615 case ICmpPredicate::ult:
2616 case ICmpPredicate::ugt:
2617 case ICmpPredicate::cne:
2618 case ICmpPredicate::wne:
2621 llvm_unreachable(
"unknown comparison predicate");
2624OpFoldResult ICmpOp::fold(FoldAdaptor adaptor) {
2627 if (getLhs() == getRhs()) {
2629 return IntegerAttr::get(getType(), val);
2633 if (
auto lhs = dyn_cast_or_null<IntegerAttr>(adaptor.getLhs())) {
2634 if (
auto rhs = dyn_cast_or_null<IntegerAttr>(adaptor.getRhs())) {
2637 return IntegerAttr::get(getType(), val);
2645template <
typename Range>
2647 size_t commonPrefixLength = 0;
2648 auto ia = a.begin();
2649 auto ib = b.begin();
2651 for (; ia != a.end() && ib != b.end(); ia++, ib++, commonPrefixLength++) {
2657 return commonPrefixLength;
2661 size_t totalWidth = 0;
2662 for (
auto operand : operands) {
2665 ssize_t width = operand.getType().getIntOrFloatBitWidth();
2667 totalWidth += width;
2677 PatternRewriter &rewriter) {
2681 SmallVector<Value> lhsOperands, rhsOperands;
2684 ArrayRef<Value> lhsOperandsRef = lhsOperands, rhsOperandsRef = rhsOperands;
2686 auto formCatOrReplicate = [&](Location loc,
2687 ArrayRef<Value> operands) -> Value {
2688 assert(!operands.empty());
2689 Value sameElement = operands[0];
2690 for (
size_t i = 1, e = operands.size(); i != e && sameElement; ++i)
2691 if (sameElement != operands[i])
2692 sameElement = Value();
2694 return rewriter.createOrFold<ReplicateOp>(loc, sameElement,
2696 return rewriter.createOrFold<
ConcatOp>(loc, operands);
2699 auto replaceWith = [&](ICmpPredicate predicate, Value lhs,
2700 Value rhs) -> LogicalResult {
2701 replaceOpWithNewOpAndCopyNamehint<ICmpOp>(rewriter, op, predicate, lhs, rhs,
2706 size_t commonPrefixLength =
2708 if (commonPrefixLength == lhsOperands.size()) {
2711 replaceOpWithNewOpAndCopyNamehint<hw::ConstantOp>(rewriter, op,
2717 llvm::reverse(lhsOperandsRef), llvm::reverse(rhsOperandsRef));
2719 size_t commonPrefixTotalWidth =
2720 getTotalWidth(lhsOperandsRef.take_front(commonPrefixLength));
2721 size_t commonSuffixTotalWidth =
2722 getTotalWidth(lhsOperandsRef.take_back(commonSuffixLength));
2723 auto lhsOnly = lhsOperandsRef.drop_front(commonPrefixLength)
2724 .drop_back(commonSuffixLength);
2725 auto rhsOnly = rhsOperandsRef.drop_front(commonPrefixLength)
2726 .drop_back(commonSuffixLength);
2728 auto replaceWithoutReplicatingSignBit = [&]() {
2729 auto newLhs = formCatOrReplicate(lhs->getLoc(), lhsOnly);
2730 auto newRhs = formCatOrReplicate(rhs->getLoc(), rhsOnly);
2731 return replaceWith(op.getPredicate(), newLhs, newRhs);
2734 auto replaceWithReplicatingSignBit = [&]() {
2735 auto firstNonEmptyValue = lhsOperands[0];
2736 auto firstNonEmptyElemWidth =
2737 firstNonEmptyValue.getType().getIntOrFloatBitWidth();
2738 Value signBit = rewriter.createOrFold<
ExtractOp>(
2739 op.getLoc(), firstNonEmptyValue, firstNonEmptyElemWidth - 1, 1);
2741 auto newLhs = rewriter.
create<
ConcatOp>(lhs->getLoc(), signBit, lhsOnly);
2742 auto newRhs = rewriter.create<
ConcatOp>(rhs->getLoc(), signBit, rhsOnly);
2743 return replaceWith(op.getPredicate(), newLhs, newRhs);
2746 if (ICmpOp::isPredicateSigned(op.getPredicate())) {
2748 if (commonPrefixTotalWidth == 0 && commonSuffixTotalWidth > 0)
2749 return replaceWithoutReplicatingSignBit();
2755 if (commonPrefixTotalWidth > 1 || commonSuffixTotalWidth > 0)
2756 return replaceWithReplicatingSignBit();
2758 }
else if (commonPrefixTotalWidth > 0 || commonSuffixTotalWidth > 0) {
2760 return replaceWithoutReplicatingSignBit();
2774 ICmpOp cmpOp,
const KnownBits &bitAnalysis,
const APInt &rhsCst,
2775 PatternRewriter &rewriter) {
2779 APInt bitsKnown = bitAnalysis.Zero | bitAnalysis.One;
2780 if ((bitsKnown & rhsCst) != bitAnalysis.One) {
2783 bool result = cmpOp.getPredicate() == ICmpPredicate::ne;
2784 replaceOpWithNewOpAndCopyNamehint<hw::ConstantOp>(rewriter, cmpOp,
2792 SmallVector<Value> newConcatOperands;
2793 auto newConstant = APInt::getZeroWidth();
2798 unsigned knownMSB = bitsKnown.countLeadingOnes();
2800 Value operand = cmpOp.getLhs();
2805 while (knownMSB != bitsKnown.getBitWidth()) {
2808 bitsKnown = bitsKnown.trunc(bitsKnown.getBitWidth() - knownMSB);
2811 unsigned unknownBits = bitsKnown.countLeadingZeros();
2812 unsigned lowBit = bitsKnown.getBitWidth() - unknownBits;
2813 auto spanOperand = rewriter.createOrFold<
ExtractOp>(
2814 operand.getLoc(), operand, lowBit,
2816 auto spanConstant = rhsCst.lshr(lowBit).trunc(unknownBits);
2819 newConcatOperands.push_back(spanOperand);
2822 if (newConstant.getBitWidth() != 0)
2823 newConstant = newConstant.concat(spanConstant);
2825 newConstant = spanConstant;
2828 unsigned newWidth = bitsKnown.getBitWidth() - unknownBits;
2829 bitsKnown = bitsKnown.trunc(newWidth);
2830 knownMSB = bitsKnown.countLeadingOnes();
2836 if (newConcatOperands.empty()) {
2837 bool result = cmpOp.getPredicate() == ICmpPredicate::eq;
2838 replaceOpWithNewOpAndCopyNamehint<hw::ConstantOp>(rewriter, cmpOp,
2844 Value concatResult =
2845 rewriter.createOrFold<
ConcatOp>(operand.getLoc(), newConcatOperands);
2849 cmpOp.getOperand(1).getLoc(), newConstant);
2851 replaceOpWithNewOpAndCopyNamehint<ICmpOp>(rewriter, cmpOp,
2852 cmpOp.getPredicate(), concatResult,
2853 newConstantOp, cmpOp.getTwoState());
2859 PatternRewriter &rewriter) {
2860 auto ip = rewriter.saveInsertionPoint();
2861 rewriter.setInsertionPoint(xorOp);
2863 auto xorRHS = xorOp.getOperands().back().getDefiningOp<
hw::ConstantOp>();
2865 xorRHS.getValue() ^ rhs);
2867 switch (xorOp.getNumOperands()) {
2871 APInt::getZero(rhs.getBitWidth()));
2875 newLHS = xorOp.getOperand(0);
2879 SmallVector<Value> newOperands(xorOp.getOperands());
2880 newOperands.pop_back();
2881 newLHS = rewriter.create<
XorOp>(xorOp.getLoc(), newOperands,
false);
2885 bool xorMultipleUses = !xorOp->hasOneUse();
2889 if (xorMultipleUses)
2890 replaceOpWithNewOpAndCopyNamehint<XorOp>(rewriter, xorOp, newLHS, xorRHS,
2894 rewriter.restoreInsertionPoint(ip);
2895 replaceOpWithNewOpAndCopyNamehint<ICmpOp>(
2896 rewriter, cmpOp, cmpOp.getPredicate(), newLHS, newRHS,
false);
2899LogicalResult ICmpOp::canonicalize(ICmpOp op, PatternRewriter &rewriter) {
2903 if (matchPattern(op.getLhs(), m_ConstantInt(&lhs))) {
2904 assert(!matchPattern(op.getRhs(), m_ConstantInt(&rhs)) &&
2905 "Should be folded");
2906 replaceOpWithNewOpAndCopyNamehint<ICmpOp>(
2907 rewriter, op, ICmpOp::getFlippedPredicate(op.getPredicate()),
2908 op.getRhs(), op.getLhs(), op.getTwoState());
2913 if (matchPattern(op.getRhs(), m_ConstantInt(&rhs))) {
2915 return rewriter.create<
hw::ConstantOp>(op.getLoc(), std::move(constant));
2918 auto replaceWith = [&](ICmpPredicate predicate, Value lhs,
2919 Value rhs) -> LogicalResult {
2920 replaceOpWithNewOpAndCopyNamehint<ICmpOp>(rewriter, op, predicate, lhs,
2921 rhs, op.getTwoState());
2925 auto replaceWithConstantI1 = [&](
bool constant) -> LogicalResult {
2926 replaceOpWithNewOpAndCopyNamehint<hw::ConstantOp>(rewriter, op,
2927 APInt(1, constant));
2931 switch (op.getPredicate()) {
2932 case ICmpPredicate::slt:
2934 if (rhs.isMaxSignedValue())
2935 return replaceWith(ICmpPredicate::ne, op.getLhs(), op.getRhs());
2937 if (rhs.isMinSignedValue())
2938 return replaceWithConstantI1(0);
2940 if ((rhs - 1).isMinSignedValue())
2941 return replaceWith(ICmpPredicate::eq, op.getLhs(),
2944 case ICmpPredicate::sgt:
2946 if (rhs.isMinSignedValue())
2947 return replaceWith(ICmpPredicate::ne, op.getLhs(), op.getRhs());
2949 if (rhs.isMaxSignedValue())
2950 return replaceWithConstantI1(0);
2952 if ((rhs + 1).isMaxSignedValue())
2953 return replaceWith(ICmpPredicate::eq, op.getLhs(),
2956 case ICmpPredicate::ult:
2958 if (rhs.isAllOnes())
2959 return replaceWith(ICmpPredicate::ne, op.getLhs(), op.getRhs());
2962 return replaceWithConstantI1(0);
2964 if ((rhs - 1).isZero())
2965 return replaceWith(ICmpPredicate::eq, op.getLhs(),
2969 if (rhs.countLeadingOnes() + rhs.countTrailingZeros() ==
2970 rhs.getBitWidth()) {
2971 auto numOnes = rhs.countLeadingOnes();
2972 auto smaller = rewriter.create<
ExtractOp>(
2973 op.getLoc(), op.getLhs(), rhs.getBitWidth() - numOnes, numOnes);
2974 return replaceWith(ICmpPredicate::ne, smaller,
2979 case ICmpPredicate::ugt:
2982 return replaceWith(ICmpPredicate::ne, op.getLhs(), op.getRhs());
2984 if (rhs.isAllOnes())
2985 return replaceWithConstantI1(0);
2987 if ((rhs + 1).isAllOnes())
2988 return replaceWith(ICmpPredicate::eq, op.getLhs(),
2992 if ((rhs + 1).isPowerOf2()) {
2993 auto numOnes = rhs.countTrailingOnes();
2994 auto newWidth = rhs.getBitWidth() - numOnes;
2995 auto smaller = rewriter.create<
ExtractOp>(op.getLoc(), op.getLhs(),
2997 return replaceWith(ICmpPredicate::ne, smaller,
3002 case ICmpPredicate::sle:
3004 if (rhs.isMaxSignedValue())
3005 return replaceWithConstantI1(1);
3007 return replaceWith(ICmpPredicate::slt, op.getLhs(),
getConstant(rhs + 1));
3008 case ICmpPredicate::sge:
3010 if (rhs.isMinSignedValue())
3011 return replaceWithConstantI1(1);
3013 return replaceWith(ICmpPredicate::sgt, op.getLhs(),
getConstant(rhs - 1));
3014 case ICmpPredicate::ule:
3016 if (rhs.isAllOnes())
3017 return replaceWithConstantI1(1);
3019 return replaceWith(ICmpPredicate::ult, op.getLhs(),
getConstant(rhs + 1));
3020 case ICmpPredicate::uge:
3023 return replaceWithConstantI1(1);
3025 return replaceWith(ICmpPredicate::ugt, op.getLhs(),
getConstant(rhs - 1));
3026 case ICmpPredicate::eq:
3027 if (rhs.getBitWidth() == 1) {
3030 replaceOpWithNewOpAndCopyNamehint<XorOp>(rewriter, op, op.getLhs(),
3035 if (rhs.isAllOnes()) {
3042 case ICmpPredicate::ne:
3043 if (rhs.getBitWidth() == 1) {
3049 if (rhs.isAllOnes()) {
3051 replaceOpWithNewOpAndCopyNamehint<XorOp>(rewriter, op, op.getLhs(),
3058 case ICmpPredicate::ceq:
3059 case ICmpPredicate::cne:
3060 case ICmpPredicate::weq:
3061 case ICmpPredicate::wne:
3067 if (op.getPredicate() == ICmpPredicate::eq ||
3068 op.getPredicate() == ICmpPredicate::ne) {
3073 if (!knownBits.isUnknown())
3080 if (
auto xorOp = op.getLhs().getDefiningOp<
XorOp>())
3087 if (
auto replicateOp = op.getLhs().getDefiningOp<ReplicateOp>())
3088 if (rhs.isAllOnes() || rhs.isZero()) {
3089 auto width = replicateOp.getInput().getType().getIntOrFloatBitWidth();
3091 op.getLoc(), rhs.isAllOnes() ? APInt::getAllOnes(width)
3092 : APInt::getZero(width));
3093 replaceOpWithNewOpAndCopyNamehint<ICmpOp>(
3094 rewriter, op, op.getPredicate(), replicateOp.getInput(), cst,
3104 if (Operation *opLHS = op.getLhs().getDefiningOp())
3105 if (Operation *opRHS = op.getRhs().getDefiningOp())
3106 if (isa<ConcatOp, ReplicateOp>(opLHS) &&
3107 isa<ConcatOp, ReplicateOp>(opRHS)) {
assert(baseType &&"element must be base type")
static SmallVector< T > concat(const SmallVectorImpl< T > &a, const SmallVectorImpl< T > &b)
Returns a new vector containing the concatenation of vectors a and b.
static KnownBits computeKnownBits(Value v, unsigned depth)
Given an integer SSA value, check to see if we know anything about the result of the computation.
static bool foldMuxOfUniformArrays(MuxOp op, PatternRewriter &rewriter)
static Attribute constFoldAssociativeOp(ArrayRef< Attribute > operands, hw::PEO paramOpcode)
static Attribute constFoldBinaryOp(ArrayRef< Attribute > operands, hw::PEO paramOpcode)
Performs constant folding calculate with element-wise behavior on the two attributes in operands and ...
static bool applyCmpPredicateToEqualOperands(ICmpPredicate predicate)
static ComplementMatcher< SubType > m_Complement(const SubType &subExpr)
static bool canonicalizeLogicalCstWithConcat(Operation *logicalOp, size_t concatIdx, const APInt &cst, PatternRewriter &rewriter)
When we find a logical operation (and, or, xor) with a constant e.g.
static bool narrowOperationWidth(OpTy op, bool narrowTrailingBits, PatternRewriter &rewriter)
static OpFoldResult foldDiv(Op op, ArrayRef< Attribute > constants)
static Value getCommonOperand(Op op)
Returns a single common operand that all inputs of the operation op can be traced back to,...
static bool canCombineOppositeBinCmpIntoConstant(OperandRange operands)
static void getConcatOperands(Value v, SmallVectorImpl< Value > &result)
Flatten concat and replicate operands into a vector.
static Value extractOperandFromFullyAssociative(Operation *fullyAssoc, size_t operandNo, PatternRewriter &rewriter)
Given a fully associative variadic operation like (a+b+c+d), break the expression into two parts,...
static bool getMuxChainCondConstant(Value cond, Value indexValue, bool isInverted, std::function< void(hw::ConstantOp)> constantFn)
Check to see if the condition to the specified mux is an equality comparison indexValue and one or mo...
static TypedAttr getIntAttr(const APInt &value, MLIRContext *context)
static bool shouldBeFlattened(Operation *op)
Return true if the op will be flattened afterwards.
static void canonicalizeXorIcmpTrue(XorOp op, unsigned icmpOperand, PatternRewriter &rewriter)
static bool extractFromReplicate(ExtractOp op, ReplicateOp replicate, PatternRewriter &rewriter)
static void combineEqualityICmpWithXorOfConstant(ICmpOp cmpOp, XorOp xorOp, const APInt &rhs, PatternRewriter &rewriter)
static size_t getTotalWidth(ArrayRef< Value > operands)
static bool foldCommonMuxOperation(MuxOp mux, Operation *trueOp, Operation *falseOp, PatternRewriter &rewriter)
This function is invoked when we find a mux with true/false operations that have the same opcode.
static std::pair< size_t, size_t > getLowestBitAndHighestBitRequired(Operation *op, bool narrowTrailingBits, size_t originalOpWidth)
static bool tryFlatteningOperands(Operation *op, PatternRewriter &rewriter)
Flattens a single input in op if hasOneUse is true and it can be defined as an Op.
static bool canonicalizeIdempotentInputs(Op op, PatternRewriter &rewriter)
Canonicalize an idempotent operation op so that only one input of any kind occurs.
static bool applyCmpPredicate(ICmpPredicate predicate, const APInt &lhs, const APInt &rhs)
static void combineEqualityICmpWithKnownBitsAndConstant(ICmpOp cmpOp, const KnownBits &bitAnalysis, const APInt &rhsCst, PatternRewriter &rewriter)
Given an equality comparison with a constant value and some operand that has known bits,...
static bool foldMuxChain(MuxOp rootMux, bool isFalseSide, PatternRewriter &rewriter)
Given a mux, check to see if the "on true" value (or "on false" value if isFalseSide=true) is a mux t...
static bool hasSVAttributes(Operation *op)
static LogicalResult extractConcatToConcatExtract(ExtractOp op, ConcatOp innerCat, PatternRewriter &rewriter)
static OpFoldResult foldMod(Op op, ArrayRef< Attribute > constants)
static size_t computeCommonPrefixLength(const Range &a, const Range &b)
static bool foldCommonMuxValue(MuxOp op, bool isTrueOperand, PatternRewriter &rewriter)
Fold things like mux(cond, x|y|z|a, a) -> (x|y|z)&replicate(cond)|a and mux(cond, a,...
static LogicalResult matchAndRewriteCompareConcat(ICmpOp op, Operation *lhs, Operation *rhs, PatternRewriter &rewriter)
Reduce the strength of icmp(concat(...), concat(...)) by doing an element-wise comparison on common prefi...
static Value createGenericOp(Location loc, OperationName name, ArrayRef< Value > operands, OpBuilder &builder)
Create a new instance of a generic operation that only has value operands, and has a single result va...
static TypedAttr getIntAttr(MLIRContext *ctx, Type t, const APInt &value)
static std::optional< APSInt > getConstant(Attribute operand)
Determine the value of a constant operand for the sake of constant folding.
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
Value createOrFoldNot(Location loc, Value value, OpBuilder &builder, bool twoState=false)
Create a 'Not' gate on a value.
KnownBits computeKnownBits(Value value)
Compute "known bits" information about the specified value - the set of bits that are guaranteed to a...
uint64_t getWidth(Type t)
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
void replaceOpAndCopyNamehint(PatternRewriter &rewriter, Operation *op, Value newValue)
A wrapper of PatternRewriter::replaceOp to propagate "sv.namehint" attribute.