12#include "mlir/IR/Matchers.h"
13#include "mlir/IR/PatternMatch.h"
14#include "llvm/ADT/SetVector.h"
15#include "llvm/ADT/SmallBitVector.h"
16#include "llvm/ADT/TypeSwitch.h"
17#include "llvm/Support/KnownBits.h"
22using namespace matchers;
33 Block *thisBlock = op->getBlock();
34 return llvm::any_of(op->getOperands(), [&](Value operand) {
35 return operand.getParentBlock() != thisBlock;
45 ArrayRef<Value> operands, OpBuilder &builder) {
46 OperationState state(loc, name);
47 state.addOperands(operands);
48 state.addTypes(operands[0].getType());
49 return builder.create(state)->getResult(0);
52static TypedAttr
getIntAttr(
const APInt &value, MLIRContext *context) {
53 return IntegerAttr::get(IntegerType::get(context, value.getBitWidth()),
60 for (
auto op :
concat.getOperands())
62 }
else if (
auto repl = v.getDefiningOp<ReplicateOp>()) {
63 for (
size_t i = 0, e = repl.getMultiple(); i != e; ++i)
75 if (
auto *newOp = newValue.getDefiningOp()) {
76 auto name = op->getAttrOfType<StringAttr>(
"sv.namehint");
77 if (name && !newOp->hasAttr(
"sv.namehint"))
78 rewriter.modifyOpInPlace(newOp,
79 [&] { newOp->setAttr(
"sv.namehint", name); });
81 rewriter.replaceOp(op, newValue);
87template <
typename OpTy,
typename... Args>
89 Operation *op, Args &&...args) {
90 auto name = op->getAttrOfType<StringAttr>(
"sv.namehint");
92 rewriter.replaceOpWithNewOp<OpTy>(op, std::forward<Args>(args)...);
93 if (name && !newOp->hasAttr(
"sv.namehint"))
94 rewriter.modifyOpInPlace(newOp,
95 [&] { newOp->setAttr(
"sv.namehint", name); });
104 return op->hasAttr(
"sv.attributes");
108template <
typename SubType>
109struct ComplementMatcher {
111 ComplementMatcher(SubType lhs) : lhs(std::move(lhs)) {}
112 bool match(Operation *op) {
113 auto xorOp = dyn_cast<XorOp>(op);
114 return xorOp && xorOp.isBinaryNot() && lhs.match(op->getOperand(0));
119template <
typename SubType>
120static inline ComplementMatcher<SubType>
m_Complement(
const SubType &subExpr) {
121 return ComplementMatcher<SubType>(subExpr);
127 assert((isa<AndOp, OrOp, XorOp, AddOp, MulOp>(op) &&
128 "must be commutative operations"));
129 if (op->hasOneUse()) {
130 auto *user = *op->getUsers().begin();
131 return user->getName() == op->getName() &&
132 op->getAttrOfType<UnitAttr>(
"twoState") ==
133 user->getAttrOfType<UnitAttr>(
"twoState") &&
134 op->getBlock() == user->getBlock();
149 auto inputs = op->getOperands();
151 SmallVector<Value, 4> newOperands;
152 SmallVector<Location, 4> newLocations{op->getLoc()};
153 newOperands.reserve(inputs.size());
155 decltype(inputs.begin()) current, end;
158 SmallVector<Element> worklist;
159 worklist.push_back({inputs.begin(), inputs.end()});
160 bool binFlag = op->hasAttrOfType<UnitAttr>(
"twoState");
161 bool changed =
false;
162 while (!worklist.empty()) {
163 auto &element = worklist.back();
166 if (element.current == element.end) {
171 Value value = *element.current++;
172 auto *flattenOp = value.getDefiningOp();
175 if (!flattenOp || flattenOp->getName() != op->getName() ||
176 flattenOp == op || binFlag != op->hasAttrOfType<UnitAttr>(
"twoState") ||
177 flattenOp->getBlock() != op->getBlock()) {
178 newOperands.push_back(value);
183 if (!value.hasOneUse()) {
191 if (flattenOp->getNumOperands() != 2 || !isa<AndOp, OrOp, XorOp>(op) ||
194 newOperands.push_back(value);
202 auto flattenOpInputs = flattenOp->getOperands();
203 worklist.push_back({flattenOpInputs.begin(), flattenOpInputs.end()});
204 newLocations.push_back(flattenOp->getLoc());
210 Value result =
createGenericOp(FusedLoc::get(op->getContext(), newLocations),
211 op->getName(), newOperands, rewriter);
213 result.getDefiningOp()->setAttr(
"twoState", rewriter.getUnitAttr());
221static std::pair<size_t, size_t>
223 size_t originalOpWidth) {
224 auto users = op->getUsers();
226 "getLowestBitAndHighestBitRequired cannot operate on "
227 "a empty list of uses.");
231 size_t lowestBitRequired = narrowTrailingBits ? originalOpWidth - 1 : 0;
232 size_t highestBitRequired = 0;
234 for (
auto *user : users) {
235 if (
auto extractOp = dyn_cast<ExtractOp>(user)) {
236 size_t lowBit = extractOp.getLowBit();
238 cast<IntegerType>(extractOp.getType()).getWidth() + lowBit - 1;
239 highestBitRequired = std::max(highestBitRequired, highBit);
240 lowestBitRequired = std::min(lowestBitRequired, lowBit);
244 highestBitRequired = originalOpWidth - 1;
245 lowestBitRequired = 0;
249 return {lowestBitRequired, highestBitRequired};
254 PatternRewriter &rewriter) {
255 IntegerType opType = dyn_cast<IntegerType>(op.getResult().getType());
261 if (range.second + 1 == opType.getWidth() && range.first == 0)
264 SmallVector<Value> args;
265 auto newType = rewriter.getIntegerType(range.second - range.first + 1);
266 for (
auto inop : op.getOperands()) {
268 if (inop.getType() != op.getType())
269 args.push_back(inop);
271 args.push_back(rewriter.createOrFold<
ExtractOp>(inop.getLoc(), newType,
274 auto newop = rewriter.create<OpTy>(op.getLoc(), newType, args);
275 newop->setDialectAttrs(op->getDialectAttrs());
276 if (op.getTwoState())
277 newop.setTwoState(
true);
279 Value newResult = newop.getResult();
281 newResult = rewriter.createOrFold<
ConcatOp>(
282 op.getLoc(), newResult,
284 APInt::getZero(range.first)));
285 if (range.second + 1 < opType.getWidth())
286 newResult = rewriter.createOrFold<
ConcatOp>(
289 op.getLoc(), APInt::getZero(opType.getWidth() - range.second - 1)),
291 rewriter.replaceOp(op, newResult);
299OpFoldResult ReplicateOp::fold(FoldAdaptor adaptor) {
304 if (cast<IntegerType>(getType()).
getWidth() ==
305 getInput().getType().getIntOrFloatBitWidth())
309 if (
auto input = dyn_cast_or_null<IntegerAttr>(adaptor.getInput())) {
310 if (input.getValue().getBitWidth() == 1) {
311 if (input.getValue().isZero())
313 APInt::getZero(cast<IntegerType>(getType()).
getWidth()),
316 APInt::getAllOnes(cast<IntegerType>(getType()).
getWidth()),
320 APInt result = APInt::getZeroWidth();
321 for (
auto i = getMultiple(); i != 0; --i)
322 result = result.concat(input.getValue());
329OpFoldResult ParityOp::fold(FoldAdaptor adaptor) {
334 if (
auto input = dyn_cast_or_null<IntegerAttr>(adaptor.getInput()))
335 return getIntAttr(APInt(1, input.getValue().popcount() & 1), getContext());
347 hw::PEO paramOpcode) {
348 assert(operands.size() == 2 &&
"binary op takes two operands");
349 if (!operands[0] || !operands[1])
354 return hw::ParamExprAttr::get(paramOpcode, cast<TypedAttr>(operands[0]),
355 cast<TypedAttr>(operands[1]));
358OpFoldResult ShlOp::fold(FoldAdaptor adaptor) {
362 if (
auto rhs = dyn_cast_or_null<IntegerAttr>(adaptor.getRhs())) {
363 unsigned shift = rhs.getValue().getZExtValue();
364 unsigned width = getType().getIntOrFloatBitWidth();
366 return getOperand(0);
368 return getIntAttr(APInt::getZero(width), getContext());
374LogicalResult ShlOp::canonicalize(
ShlOp op, PatternRewriter &rewriter) {
380 if (!matchPattern(op.getRhs(), m_ConstantInt(&value)))
383 unsigned width = cast<IntegerType>(op.getLhs().getType()).getWidth();
384 unsigned shift = value.getZExtValue();
387 if (width <= shift || shift == 0)
391 rewriter.create<
hw::ConstantOp>(op.getLoc(), APInt::getZero(shift));
395 rewriter.
create<
ExtractOp>(op.getLoc(), op.getLhs(), 0, width - shift);
397 replaceOpWithNewOpAndCopyName<ConcatOp>(rewriter, op, extract, zeros);
401OpFoldResult ShrUOp::fold(FoldAdaptor adaptor) {
405 if (
auto rhs = dyn_cast_or_null<IntegerAttr>(adaptor.getRhs())) {
406 unsigned shift = rhs.getValue().getZExtValue();
408 return getOperand(0);
410 unsigned width = getType().getIntOrFloatBitWidth();
412 return getIntAttr(APInt::getZero(width), getContext());
417LogicalResult ShrUOp::canonicalize(
ShrUOp op, PatternRewriter &rewriter) {
423 if (!matchPattern(op.getRhs(), m_ConstantInt(&value)))
426 unsigned width = cast<IntegerType>(op.getLhs().getType()).getWidth();
427 unsigned shift = value.getZExtValue();
430 if (width <= shift || shift == 0)
434 rewriter.create<
hw::ConstantOp>(op.getLoc(), APInt::getZero(shift));
437 auto extract = rewriter.
create<
ExtractOp>(op.getLoc(), op.getLhs(), shift,
440 replaceOpWithNewOpAndCopyName<ConcatOp>(rewriter, op, zeros, extract);
444OpFoldResult ShrSOp::fold(FoldAdaptor adaptor) {
448 if (
auto rhs = dyn_cast_or_null<IntegerAttr>(adaptor.getRhs())) {
449 if (rhs.getValue().getZExtValue() == 0)
450 return getOperand(0);
455LogicalResult ShrSOp::canonicalize(
ShrSOp op, PatternRewriter &rewriter) {
461 if (!matchPattern(op.getRhs(), m_ConstantInt(&value)))
464 unsigned width = cast<IntegerType>(op.getLhs().getType()).getWidth();
465 unsigned shift = value.getZExtValue();
468 rewriter.createOrFold<
ExtractOp>(op.getLoc(), op.getLhs(), width - 1, 1);
469 auto sext = rewriter.createOrFold<ReplicateOp>(op.getLoc(), topbit, shift);
471 if (width <= shift) {
476 auto extract = rewriter.
create<
ExtractOp>(op.getLoc(), op.getLhs(), shift,
479 replaceOpWithNewOpAndCopyName<ConcatOp>(rewriter, op, sext, extract);
487OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
492 if (getInput().getType() == getType())
496 if (
auto input = dyn_cast_or_null<IntegerAttr>(adaptor.getInput())) {
497 unsigned dstWidth = cast<IntegerType>(getType()).getWidth();
498 return getIntAttr(input.getValue().lshr(getLowBit()).trunc(dstWidth),
509 PatternRewriter &rewriter) {
510 auto reversedConcatArgs = llvm::reverse(innerCat.getInputs());
511 size_t beginOfFirstRelevantElement = 0;
512 auto it = reversedConcatArgs.begin();
513 size_t lowBit = op.getLowBit();
516 for (; it != reversedConcatArgs.end(); it++) {
517 assert(beginOfFirstRelevantElement <= lowBit &&
518 "incorrectly moved past an element that lowBit has coverage over");
521 size_t operandWidth = operand.getType().getIntOrFloatBitWidth();
522 if (lowBit < beginOfFirstRelevantElement + operandWidth) {
546 beginOfFirstRelevantElement += operandWidth;
548 assert(it != reversedConcatArgs.end() &&
549 "incorrectly failed to find an element which contains coverage of "
552 SmallVector<Value> reverseConcatArgs;
553 size_t widthRemaining = cast<IntegerType>(op.getType()).getWidth();
554 size_t extractLo = lowBit - beginOfFirstRelevantElement;
559 for (; widthRemaining != 0 && it != reversedConcatArgs.end(); it++) {
560 auto concatArg = *it;
561 size_t operandWidth = concatArg.getType().getIntOrFloatBitWidth();
562 size_t widthToConsume = std::min(widthRemaining, operandWidth - extractLo);
564 if (widthToConsume == operandWidth && extractLo == 0) {
565 reverseConcatArgs.push_back(concatArg);
567 auto resultType = IntegerType::get(rewriter.getContext(), widthToConsume);
568 reverseConcatArgs.push_back(
569 rewriter.create<
ExtractOp>(op.getLoc(), resultType, *it, extractLo));
572 widthRemaining -= widthToConsume;
578 if (reverseConcatArgs.size() == 1) {
581 replaceOpWithNewOpAndCopyName<ConcatOp>(
582 rewriter, op, SmallVector<Value>(llvm::reverse(reverseConcatArgs)));
589 PatternRewriter &rewriter) {
590 auto extractResultWidth = cast<IntegerType>(op.getType()).getWidth();
591 auto replicateEltWidth =
592 replicate.getOperand().getType().getIntOrFloatBitWidth();
596 if (op.getLowBit() % replicateEltWidth == 0 &&
597 extractResultWidth % replicateEltWidth == 0) {
598 replaceOpWithNewOpAndCopyName<ReplicateOp>(rewriter, op, op.getType(),
599 replicate.getOperand());
605 if (op.getLowBit() % replicateEltWidth + extractResultWidth <=
607 replaceOpWithNewOpAndCopyName<ExtractOp>(
608 rewriter, op, op.getType(), replicate.getOperand(),
609 op.getLowBit() % replicateEltWidth);
618LogicalResult ExtractOp::canonicalize(
ExtractOp op, PatternRewriter &rewriter) {
622 auto *inputOp = op.getInput().getDefiningOp();
629 .extractBits(cast<IntegerType>(op.getType()).getWidth(),
631 if (knownBits.isConstant()) {
632 replaceOpWithNewOpAndCopyName<hw::ConstantOp>(rewriter, op,
633 knownBits.getConstant());
639 if (
auto innerExtract = dyn_cast_or_null<ExtractOp>(inputOp)) {
640 replaceOpWithNewOpAndCopyName<ExtractOp>(
641 rewriter, op, op.getType(), innerExtract.getInput(),
642 innerExtract.getLowBit() + op.getLowBit());
647 if (
auto innerCat = dyn_cast_or_null<ConcatOp>(inputOp))
651 if (
auto replicate = dyn_cast_or_null<ReplicateOp>(inputOp))
657 if (inputOp && inputOp->getNumOperands() == 2 &&
658 isa<AndOp, OrOp, XorOp>(inputOp)) {
659 if (
auto cstRHS = inputOp->getOperand(1).getDefiningOp<
hw::ConstantOp>()) {
660 auto extractedCst = cstRHS.getValue().extractBits(
661 cast<IntegerType>(op.getType()).getWidth(), op.getLowBit());
662 if (isa<OrOp, XorOp>(inputOp) && extractedCst.isZero()) {
663 replaceOpWithNewOpAndCopyName<ExtractOp>(
664 rewriter, op, op.getType(), inputOp->getOperand(0), op.getLowBit());
672 if (isa<AndOp>(inputOp)) {
675 unsigned lz = extractedCst.countLeadingZeros();
676 unsigned tz = extractedCst.countTrailingZeros();
677 unsigned pop = extractedCst.popcount();
678 if (extractedCst.getBitWidth() - lz - tz == pop) {
679 auto resultTy = rewriter.getIntegerType(pop);
680 SmallVector<Value> resultElts;
683 op.getLoc(), APInt::getZero(lz)));
684 resultElts.push_back(rewriter.createOrFold<
ExtractOp>(
685 op.getLoc(), resultTy, inputOp->getOperand(0),
686 op.getLowBit() + tz));
689 op.getLoc(), APInt::getZero(tz)));
690 replaceOpWithNewOpAndCopyName<ConcatOp>(rewriter, op, resultElts);
699 if (cast<IntegerType>(op.getType()).getWidth() == 1 && inputOp)
700 if (
auto shlOp = dyn_cast<ShlOp>(inputOp)) {
702 if (shlOp->hasOneUse())
704 if (lhsCst.getValue().isOne()) {
707 APInt(lhsCst.getValue().getBitWidth(), op.getLowBit()));
708 replaceOpWithNewOpAndCopyName<ICmpOp>(
709 rewriter, op, ICmpPredicate::eq, shlOp->getOperand(1), newCst,
725 hw::PEO paramOpcode) {
726 assert(operands.size() > 1 &&
"caller should handle one-operand case");
729 if (!operands[1] || !operands[0])
733 if (llvm::all_of(operands.drop_front(2),
734 [&](Attribute in) { return !!in; })) {
735 SmallVector<mlir::TypedAttr> typedOperands;
736 typedOperands.reserve(operands.size());
737 for (
auto operand : operands) {
738 if (
auto typedOperand = dyn_cast<mlir::TypedAttr>(operand))
739 typedOperands.push_back(typedOperand);
743 if (typedOperands.size() == operands.size())
744 return hw::ParamExprAttr::get(paramOpcode, typedOperands);
760 size_t concatIdx,
const APInt &cst,
761 PatternRewriter &rewriter) {
762 auto concatOp = logicalOp->getOperand(concatIdx).getDefiningOp<
ConcatOp>();
763 assert((isa<AndOp, OrOp, XorOp>(logicalOp) && concatOp));
768 llvm::any_of(concatOp->getOperands(), [&](Value operand) ->
bool {
769 auto *operandOp = operand.getDefiningOp();
774 if (isa<hw::ConstantOp>(operandOp))
778 return operandOp->getName() == logicalOp->getName() &&
779 operandOp->hasOneUse() && operandOp->getNumOperands() != 0 &&
780 operandOp->getOperands().back().getDefiningOp<hw::ConstantOp>();
788 auto createLogicalOp = [&](ArrayRef<Value> operands) -> Value {
789 return createGenericOp(logicalOp->getLoc(), logicalOp->getName(), operands,
796 SmallVector<Value> newConcatOperands;
797 newConcatOperands.reserve(concatOp->getNumOperands());
800 size_t nextOperandBit = concatOp.getType().getIntOrFloatBitWidth();
801 for (Value operand : concatOp->getOperands()) {
802 size_t operandWidth = operand.getType().getIntOrFloatBitWidth();
803 nextOperandBit -= operandWidth;
806 logicalOp->getLoc(), cst.lshr(nextOperandBit).trunc(operandWidth));
808 newConcatOperands.push_back(createLogicalOp({operand, eltCst}));
817 if (logicalOp->getNumOperands() > 2) {
818 auto origOperands = logicalOp->getOperands();
819 SmallVector<Value> operands;
821 operands.append(origOperands.begin(), origOperands.begin() + concatIdx);
823 operands.append(origOperands.begin() + concatIdx + 1,
824 origOperands.begin() + (origOperands.size() - 1));
826 operands.push_back(newResult);
827 newResult = createLogicalOp(operands);
837 llvm::SmallDenseSet<std::tuple<ICmpPredicate, Value, Value>> seenPredicates;
839 for (
auto op : operands) {
840 if (
auto icmpOp = op.getDefiningOp<ICmpOp>();
841 icmpOp && icmpOp.getTwoState()) {
842 auto predicate = icmpOp.getPredicate();
843 auto lhs = icmpOp.getLhs();
844 auto rhs = icmpOp.getRhs();
845 if (seenPredicates.contains(
846 {ICmpOp::getNegatedPredicate(predicate), lhs, rhs}))
849 seenPredicates.insert({predicate, lhs, rhs});
855OpFoldResult AndOp::fold(FoldAdaptor adaptor) {
859 APInt value = APInt::getAllOnes(cast<IntegerType>(getType()).
getWidth());
861 auto inputs = adaptor.getInputs();
864 for (
auto operand : inputs) {
867 value &= cast<IntegerAttr>(operand).getValue();
873 if (inputs.size() == 2 && inputs[1] &&
874 cast<IntegerAttr>(inputs[1]).getValue().isAllOnes())
875 return getInputs()[0];
878 if (llvm::all_of(getInputs(),
879 [&](
auto in) {
return in == this->getInputs()[0]; }))
880 return getInputs()[0];
883 for (Value arg : getInputs()) {
886 for (Value arg2 : getInputs())
889 APInt::getZero(cast<IntegerType>(getType()).
getWidth()),
910template <
typename Op>
912 if (!op.getType().isInteger(1))
915 auto inputs = op.getInputs();
916 size_t size = inputs.size();
918 auto sourceOp = inputs[0].template getDefiningOp<ExtractOp>();
921 Value source = sourceOp.getOperand();
924 if (size != source.getType().getIntOrFloatBitWidth())
928 llvm::BitVector bits(size);
929 bits.set(sourceOp.getLowBit());
931 for (
size_t i = 1; i != size; ++i) {
932 auto extractOp = inputs[i].template getDefiningOp<ExtractOp>();
933 if (!extractOp || extractOp.getOperand() != source)
935 bits.set(extractOp.getLowBit());
938 return bits.all() ? source : Value();
945template <
typename Op>
948 constexpr unsigned limit = 3;
949 auto inputs = op.getInputs();
951 llvm::SmallSetVector<Value, 8> uniqueInputs(inputs.begin(), inputs.end());
952 llvm::SmallDenseSet<Op, 8> checked;
959 llvm::SmallVector<OpWithDepth, 8> worklist;
961 auto enqueue = [&worklist, &checked, &op](Value input,
unsigned depth) {
965 if (depth < limit && input.getParentBlock() == op->getBlock()) {
966 auto inputOp = input.template getDefiningOp<Op>();
967 if (inputOp && inputOp.getTwoState() == op.getTwoState() &&
968 checked.insert(inputOp).second)
969 worklist.push_back({inputOp, depth + 1});
973 for (
auto input : uniqueInputs)
976 while (!worklist.empty()) {
977 auto item = worklist.pop_back_val();
979 for (
auto input : item.op.getInputs()) {
980 uniqueInputs.remove(input);
981 enqueue(input, item.depth);
985 if (uniqueInputs.size() < inputs.size()) {
986 replaceOpWithNewOpAndCopyName<Op>(rewriter, op, op.getType(),
987 uniqueInputs.getArrayRef(),
995LogicalResult AndOp::canonicalize(
AndOp op, PatternRewriter &rewriter) {
996 auto inputs = op.getInputs();
997 auto size = inputs.size();
1011 assert(size > 1 &&
"expected 2 or more operands, `fold` should handle this");
1015 if (matchPattern(inputs.back(), m_ConstantInt(&value))) {
1017 if (value.isAllOnes()) {
1018 replaceOpWithNewOpAndCopyName<AndOp>(rewriter, op, op.getType(),
1019 inputs.drop_back(),
false);
1027 if (matchPattern(inputs[size - 2], m_ConstantInt(&value2))) {
1028 auto cst = rewriter.create<
hw::ConstantOp>(op.getLoc(), value & value2);
1029 SmallVector<Value, 4> newOperands(inputs.drop_back(2));
1030 newOperands.push_back(cst);
1031 replaceOpWithNewOpAndCopyName<AndOp>(rewriter, op, op.getType(),
1032 newOperands,
false);
1037 if (size == 2 && value.isPowerOf2()) {
1042 if (
auto replicate = inputs[0].getDefiningOp<ReplicateOp>()) {
1043 auto replicateOperand = replicate.getOperand();
1044 if (replicateOperand.getType().isInteger(1)) {
1045 unsigned resultWidth = op.getType().getIntOrFloatBitWidth();
1046 auto trailingZeros = value.countTrailingZeros();
1049 SmallVector<Value, 3> concatOperands;
1050 if (trailingZeros != resultWidth - 1) {
1052 op.getLoc(), APInt::getZero(resultWidth - trailingZeros - 1));
1053 concatOperands.push_back(highZeros);
1055 concatOperands.push_back(replicateOperand);
1056 if (trailingZeros != 0) {
1058 op.getLoc(), APInt::getZero(trailingZeros));
1059 concatOperands.push_back(lowZeros);
1061 replaceOpWithNewOpAndCopyName<ConcatOp>(rewriter, op, op.getType(),
1069 if (
auto extractOp = inputs[0].getDefiningOp<ExtractOp>()) {
1072 (value.countLeadingZeros() || value.countTrailingZeros())) {
1073 unsigned lz = value.countLeadingZeros();
1074 unsigned tz = value.countTrailingZeros();
1077 auto smallTy = rewriter.getIntegerType(value.getBitWidth() - lz - tz);
1078 Value smallElt = rewriter.createOrFold<
ExtractOp>(
1079 extractOp.getLoc(), smallTy, extractOp->getOperand(0),
1080 extractOp.getLowBit() + tz);
1082 APInt smallMask = value.extractBits(smallTy.getWidth(), tz);
1083 if (!smallMask.isAllOnes()) {
1084 auto loc = inputs.back().getLoc();
1085 smallElt = rewriter.createOrFold<
AndOp>(
1092 SmallVector<Value> resultElts;
1094 resultElts.push_back(
1095 rewriter.create<
hw::ConstantOp>(op.getLoc(), APInt::getZero(lz)));
1096 resultElts.push_back(smallElt);
1098 resultElts.push_back(
1099 rewriter.create<
hw::ConstantOp>(op.getLoc(), APInt::getZero(tz)));
1100 replaceOpWithNewOpAndCopyName<ConcatOp>(rewriter, op, resultElts);
1108 for (
size_t i = 0; i < size - 1; ++i) {
1109 if (
auto concat = inputs[i].getDefiningOp<ConcatOp>())
1122 rewriter.create<
hw::ConstantOp>(op.getLoc(), APInt::getAllOnes(size));
1123 replaceOpWithNewOpAndCopyName<ICmpOp>(rewriter, op, ICmpPredicate::eq,
1124 source, cmpAgainst);
1132OpFoldResult OrOp::fold(FoldAdaptor adaptor) {
1136 auto value = APInt::getZero(cast<IntegerType>(getType()).
getWidth());
1137 auto inputs = adaptor.getInputs();
1139 for (
auto operand : inputs) {
1142 value |= cast<IntegerAttr>(operand).getValue();
1143 if (value.isAllOnes())
1148 if (inputs.size() == 2 && inputs[1] &&
1149 cast<IntegerAttr>(inputs[1]).getValue().isZero())
1150 return getInputs()[0];
1153 if (llvm::all_of(getInputs(),
1154 [&](
auto in) {
return in == this->getInputs()[0]; }))
1155 return getInputs()[0];
1158 for (Value arg : getInputs()) {
1160 if (matchPattern(arg,
m_Complement(m_Any(&subExpr)))) {
1161 for (Value arg2 : getInputs())
1162 if (arg2 == subExpr)
1164 APInt::getAllOnes(cast<IntegerType>(getType()).
getWidth()),
1174 APInt::getAllOnes(cast<IntegerType>(getType()).
getWidth()),
1181LogicalResult OrOp::canonicalize(
OrOp op, PatternRewriter &rewriter) {
1182 auto inputs = op.getInputs();
1183 auto size = inputs.size();
1197 assert(size > 1 &&
"expected 2 or more operands");
1201 if (matchPattern(inputs.back(), m_ConstantInt(&value))) {
1203 if (value.isZero()) {
1204 replaceOpWithNewOpAndCopyName<OrOp>(rewriter, op, op.getType(),
1205 inputs.drop_back());
1211 if (matchPattern(inputs[size - 2], m_ConstantInt(&value2))) {
1212 auto cst = rewriter.create<
hw::ConstantOp>(op.getLoc(), value | value2);
1213 SmallVector<Value, 4> newOperands(inputs.drop_back(2));
1214 newOperands.push_back(cst);
1215 replaceOpWithNewOpAndCopyName<OrOp>(rewriter, op, op.getType(),
1223 for (
size_t i = 0; i < size - 1; ++i) {
1224 if (
auto concat = inputs[i].getDefiningOp<ConcatOp>())
1237 rewriter.create<
hw::ConstantOp>(op.getLoc(), APInt::getZero(size));
1238 replaceOpWithNewOpAndCopyName<ICmpOp>(rewriter, op, ICmpPredicate::ne,
1239 source, cmpAgainst);
1245 if (
auto firstMux = op.getOperand(0).getDefiningOp<
comb::MuxOp>()) {
1247 if (op.getTwoState() && firstMux.getTwoState() &&
1248 matchPattern(firstMux.getFalseValue(), m_ConstantInt(&value)) &&
1250 SmallVector<Value> conditions{firstMux.getCond()};
1251 auto check = [&](Value v) {
1255 conditions.push_back(mux.getCond());
1256 return mux.getTwoState() &&
1257 firstMux.getTrueValue() == mux.getTrueValue() &&
1258 firstMux.getFalseValue() == mux.getFalseValue();
1260 if (llvm::all_of(op.getOperands().drop_front(), check)) {
1261 auto cond = rewriter.create<
comb::OrOp>(op.getLoc(), conditions,
true);
1262 replaceOpWithNewOpAndCopyName<comb::MuxOp>(
1263 rewriter, op, cond, firstMux.getTrueValue(),
1264 firstMux.getFalseValue(),
true);
1274OpFoldResult XorOp::fold(FoldAdaptor adaptor) {
1278 auto size = getInputs().size();
1279 auto inputs = adaptor.getInputs();
1283 return getInputs()[0];
1286 if (size == 2 && getInputs()[0] == getInputs()[1])
1287 return IntegerAttr::get(getType(), 0);
1290 if (inputs.size() == 2 && inputs[1] &&
1291 cast<IntegerAttr>(inputs[1]).getValue().isZero())
1292 return getInputs()[0];
1296 if (isBinaryNot()) {
1298 if (matchPattern(getOperand(0),
m_Complement(m_Any(&subExpr))) &&
1299 subExpr != getResult())
1309 PatternRewriter &rewriter) {
1310 auto icmp = op.getOperand(icmpOperand).getDefiningOp<ICmpOp>();
1311 auto negatedPred = ICmpOp::getNegatedPredicate(icmp.getPredicate());
1314 rewriter.create<ICmpOp>(icmp.getLoc(), negatedPred, icmp.getOperand(0),
1315 icmp.getOperand(1), icmp.getTwoState());
1318 if (op.getNumOperands() > 2) {
1319 SmallVector<Value, 4> newOperands(op.getOperands());
1320 newOperands.pop_back();
1321 newOperands.erase(newOperands.begin() + icmpOperand);
1322 newOperands.push_back(result);
1323 result = rewriter.create<
XorOp>(op.getLoc(), newOperands, op.getTwoState());
1329LogicalResult XorOp::canonicalize(
XorOp op, PatternRewriter &rewriter) {
1333 auto inputs = op.getInputs();
1334 auto size = inputs.size();
1335 assert(size > 1 &&
"expected 2 or more operands");
1338 if (inputs[size - 1] == inputs[size - 2]) {
1340 "expected idempotent case for 2 elements handled already.");
1341 replaceOpWithNewOpAndCopyName<XorOp>(rewriter, op, op.getType(),
1342 inputs.drop_back(2),
false);
1348 if (matchPattern(inputs.back(), m_ConstantInt(&value))) {
1350 if (value.isZero()) {
1351 replaceOpWithNewOpAndCopyName<XorOp>(rewriter, op, op.getType(),
1352 inputs.drop_back(),
false);
1358 if (matchPattern(inputs[size - 2], m_ConstantInt(&value2))) {
1359 auto cst = rewriter.create<
hw::ConstantOp>(op.getLoc(), value ^ value2);
1360 SmallVector<Value, 4> newOperands(inputs.drop_back(2));
1361 newOperands.push_back(cst);
1362 replaceOpWithNewOpAndCopyName<XorOp>(rewriter, op, op.getType(),
1363 newOperands,
false);
1367 bool isSingleBit = value.getBitWidth() == 1;
1370 for (
size_t i = 0; i < size - 1; ++i) {
1371 Value operand = inputs[i];
1382 if (isSingleBit && operand.hasOneUse()) {
1383 assert(value == 1 &&
"single bit constant has to be one if not zero");
1384 if (
auto icmp = operand.getDefiningOp<ICmpOp>())
1400 replaceOpWithNewOpAndCopyName<ParityOp>(rewriter, op, source);
1407OpFoldResult SubOp::fold(FoldAdaptor adaptor) {
1412 if (getRhs() == getLhs())
1414 APInt::getZero(getLhs().getType().getIntOrFloatBitWidth()),
1417 if (adaptor.getRhs()) {
1419 if (adaptor.getLhs()) {
1422 APInt::getAllOnes(getLhs().getType().getIntOrFloatBitWidth()),
1424 auto rhsNeg = hw::ParamExprAttr::get(
1425 hw::PEO::Mul, cast<TypedAttr>(adaptor.getRhs()), negOne);
1426 return hw::ParamExprAttr::get(hw::PEO::Add,
1427 cast<TypedAttr>(adaptor.getLhs()), rhsNeg);
1431 if (
auto rhsC = dyn_cast<IntegerAttr>(adaptor.getRhs())) {
1432 if (rhsC.getValue().isZero())
1440LogicalResult SubOp::canonicalize(
SubOp op, PatternRewriter &rewriter) {
1446 if (matchPattern(op.getRhs(), m_ConstantInt(&value))) {
1447 auto negCst = rewriter.create<
hw::ConstantOp>(op.getLoc(), -value);
1448 replaceOpWithNewOpAndCopyName<AddOp>(rewriter, op, op.getLhs(), negCst,
1460OpFoldResult AddOp::fold(FoldAdaptor adaptor) {
1464 auto size = getInputs().size();
1468 return getInputs()[0];
1474LogicalResult AddOp::canonicalize(
AddOp op, PatternRewriter &rewriter) {
1478 auto inputs = op.getInputs();
1479 auto size = inputs.size();
1480 assert(size > 1 &&
"expected 2 or more operands");
1482 APInt value, value2;
1485 if (matchPattern(inputs.back(), m_ConstantInt(&value)) && value.isZero()) {
1486 replaceOpWithNewOpAndCopyName<AddOp>(rewriter, op, op.getType(),
1487 inputs.drop_back(),
false);
1492 if (matchPattern(inputs[size - 1], m_ConstantInt(&value)) &&
1493 matchPattern(inputs[size - 2], m_ConstantInt(&value2))) {
1494 auto cst = rewriter.create<
hw::ConstantOp>(op.getLoc(), value + value2);
1495 SmallVector<Value, 4> newOperands(inputs.drop_back(2));
1496 newOperands.push_back(cst);
1497 replaceOpWithNewOpAndCopyName<AddOp>(rewriter, op, op.getType(),
1498 newOperands,
false);
1503 if (inputs[size - 1] == inputs[size - 2]) {
1504 SmallVector<Value, 4> newOperands(inputs.drop_back(2));
1506 auto one = rewriter.create<
hw::ConstantOp>(op.getLoc(), op.getType(), 1);
1510 newOperands.push_back(shiftLeftOp);
1511 replaceOpWithNewOpAndCopyName<AddOp>(rewriter, op, op.getType(),
1512 newOperands,
false);
1516 auto shlOp = inputs[size - 1].getDefiningOp<
comb::ShlOp>();
1518 if (shlOp && shlOp.getLhs() == inputs[size - 2] &&
1519 matchPattern(shlOp.getRhs(), m_ConstantInt(&value))) {
1521 APInt one(value.getBitWidth(), 1,
false);
1523 rewriter.create<
hw::ConstantOp>(op.getLoc(), (one << value) + one);
1525 std::array<Value, 2> factors = {shlOp.getLhs(), rhs};
1526 auto mulOp = rewriter.create<
comb::MulOp>(op.getLoc(), factors,
false);
1528 SmallVector<Value, 4> newOperands(inputs.drop_back(2));
1529 newOperands.push_back(mulOp);
1530 replaceOpWithNewOpAndCopyName<AddOp>(rewriter, op, op.getType(),
1531 newOperands,
false);
1535 auto mulOp = inputs[size - 1].getDefiningOp<
comb::MulOp>();
1537 if (mulOp && mulOp.getInputs().size() == 2 &&
1538 mulOp.getInputs()[0] == inputs[size - 2] &&
1539 matchPattern(mulOp.getInputs()[1], m_ConstantInt(&value))) {
1541 APInt one(value.getBitWidth(), 1,
false);
1542 auto rhs = rewriter.create<
hw::ConstantOp>(op.getLoc(), value + one);
1543 std::array<Value, 2> factors = {mulOp.getInputs()[0], rhs};
1546 SmallVector<Value, 4> newOperands(inputs.drop_back(2));
1547 newOperands.push_back(newMulOp);
1548 replaceOpWithNewOpAndCopyName<AddOp>(rewriter, op, op.getType(),
1549 newOperands,
false);
1562 auto addOp = inputs[0].getDefiningOp<
comb::AddOp>();
1563 if (addOp && addOp.getInputs().size() == 2 &&
1564 matchPattern(addOp.getInputs()[1], m_ConstantInt(&value2)) &&
1565 inputs.size() == 2 && matchPattern(inputs[1], m_ConstantInt(&value))) {
1567 auto rhs = rewriter.create<
hw::ConstantOp>(op.getLoc(), value + value2);
1568 replaceOpWithNewOpAndCopyName<AddOp>(
1569 rewriter, op, op.getType(), ArrayRef<Value>{addOp.getInputs()[0], rhs},
1570 op.getTwoState() && addOp.getTwoState());
1577OpFoldResult MulOp::fold(FoldAdaptor adaptor) {
1581 auto size = getInputs().size();
1582 auto inputs = adaptor.getInputs();
1586 return getInputs()[0];
1588 auto width = cast<IntegerType>(getType()).getWidth();
1589 APInt value(width, 1,
false);
1592 for (
auto operand : inputs) {
1595 value *= cast<IntegerAttr>(operand).getValue();
1604LogicalResult MulOp::canonicalize(
MulOp op, PatternRewriter &rewriter) {
1608 auto inputs = op.getInputs();
1609 auto size = inputs.size();
1610 assert(size > 1 &&
"expected 2 or more operands");
1612 APInt value, value2;
1615 if (size == 2 && matchPattern(inputs.back(), m_ConstantInt(&value)) &&
1616 value.isPowerOf2()) {
1617 auto shift = rewriter.create<
hw::ConstantOp>(op.getLoc(), op.getType(),
1618 value.exactLogBase2());
1622 replaceOpWithNewOpAndCopyName<MulOp>(rewriter, op, op.getType(),
1623 ArrayRef<Value>(shlOp),
false);
1628 if (matchPattern(inputs.back(), m_ConstantInt(&value)) && value.isOne()) {
1629 replaceOpWithNewOpAndCopyName<MulOp>(rewriter, op, op.getType(),
1630 inputs.drop_back());
1635 if (matchPattern(inputs[size - 1], m_ConstantInt(&value)) &&
1636 matchPattern(inputs[size - 2], m_ConstantInt(&value2))) {
1637 auto cst = rewriter.create<
hw::ConstantOp>(op.getLoc(), value * value2);
1638 SmallVector<Value, 4> newOperands(inputs.drop_back(2));
1639 newOperands.push_back(cst);
1640 replaceOpWithNewOpAndCopyName<MulOp>(rewriter, op, op.getType(),
1656template <
class Op,
bool isSigned>
1657static OpFoldResult
foldDiv(Op op, ArrayRef<Attribute> constants) {
1658 if (
auto rhsValue = dyn_cast_or_null<IntegerAttr>(constants[1])) {
1660 if (rhsValue.getValue() == 1)
1664 if (rhsValue.getValue().isZero())
1671OpFoldResult DivUOp::fold(FoldAdaptor adaptor) {
1675 return foldDiv<
DivUOp,
false>(*
this, adaptor.getOperands());
1678OpFoldResult DivSOp::fold(FoldAdaptor adaptor) {
1685template <
class Op,
bool isSigned>
1686static OpFoldResult
foldMod(Op op, ArrayRef<Attribute> constants) {
1687 if (
auto rhsValue = dyn_cast_or_null<IntegerAttr>(constants[1])) {
1689 if (rhsValue.getValue() == 1)
1690 return getIntAttr(APInt::getZero(op.getType().getIntOrFloatBitWidth()),
1694 if (rhsValue.getValue().isZero())
1698 if (
auto lhsValue = dyn_cast_or_null<IntegerAttr>(constants[0])) {
1700 if (lhsValue.getValue().isZero())
1701 return getIntAttr(APInt::getZero(op.getType().getIntOrFloatBitWidth()),
1708OpFoldResult ModUOp::fold(FoldAdaptor adaptor) {
1712 return foldMod<
ModUOp,
false>(*
this, adaptor.getOperands());
1715OpFoldResult ModSOp::fold(FoldAdaptor adaptor) {
1726OpFoldResult ConcatOp::fold(FoldAdaptor adaptor) {
1730 if (getNumOperands() == 1)
1731 return getOperand(0);
1734 for (
auto attr : adaptor.getInputs())
1735 if (!attr || !isa<IntegerAttr>(attr))
1739 unsigned resultWidth = getType().getIntOrFloatBitWidth();
1740 APInt result(resultWidth, 0);
1742 unsigned nextInsertion = resultWidth;
1744 for (
auto attr : adaptor.getInputs()) {
1745 auto chunk = cast<IntegerAttr>(attr).getValue();
1746 nextInsertion -= chunk.getBitWidth();
1747 result.insertBits(chunk, nextInsertion);
1753LogicalResult ConcatOp::canonicalize(
ConcatOp op, PatternRewriter &rewriter) {
1757 auto inputs = op.getInputs();
1758 auto size = inputs.size();
1759 assert(size > 1 &&
"expected 2 or more operands");
1764 auto flattenConcat = [&](
size_t firstOpIndex,
size_t lastOpIndex,
1765 ValueRange replacements) -> LogicalResult {
1766 SmallVector<Value, 4> newOperands;
1767 newOperands.append(inputs.begin(), inputs.begin() + firstOpIndex);
1768 newOperands.append(replacements.begin(), replacements.end());
1769 newOperands.append(inputs.begin() + lastOpIndex + 1, inputs.end());
1770 if (newOperands.size() == 1)
1773 replaceOpWithNewOpAndCopyName<ConcatOp>(rewriter, op, op.getType(),
1778 Value commonOperand = inputs[0];
1779 for (
size_t i = 0; i != size; ++i) {
1781 if (inputs[i] != commonOperand)
1782 commonOperand = Value();
1786 if (
auto subConcat = inputs[i].getDefiningOp<ConcatOp>())
1787 return flattenConcat(i, i, subConcat->getOperands());
1792 if (
auto cst = inputs[i].getDefiningOp<hw::ConstantOp>()) {
1793 if (
auto prevCst = inputs[i - 1].getDefiningOp<hw::ConstantOp>()) {
1794 unsigned prevWidth = prevCst.getValue().getBitWidth();
1795 unsigned thisWidth = cst.getValue().getBitWidth();
1796 auto resultCst = cst.getValue().zext(prevWidth + thisWidth);
1797 resultCst |= prevCst.getValue().zext(prevWidth + thisWidth)
1801 return flattenConcat(i - 1, i, replacement);
1806 if (inputs[i] == inputs[i - 1]) {
1808 rewriter.createOrFold<ReplicateOp>(op.getLoc(), inputs[i], 2);
1809 return flattenConcat(i - 1, i, replacement);
1814 if (
auto repl = inputs[i].getDefiningOp<ReplicateOp>()) {
1816 if (repl.getOperand() == inputs[i - 1]) {
1817 Value replacement = rewriter.createOrFold<ReplicateOp>(
1818 op.getLoc(), repl.getOperand(), repl.getMultiple() + 1);
1819 return flattenConcat(i - 1, i, replacement);
1822 if (
auto prevRepl = inputs[i - 1].getDefiningOp<ReplicateOp>()) {
1823 if (prevRepl.getOperand() == repl.getOperand()) {
1824 Value replacement = rewriter.createOrFold<ReplicateOp>(
1825 op.getLoc(), repl.getOperand(),
1826 repl.getMultiple() + prevRepl.getMultiple());
1827 return flattenConcat(i - 1, i, replacement);
1833 if (
auto repl = inputs[i - 1].getDefiningOp<ReplicateOp>()) {
1834 if (repl.getOperand() == inputs[i]) {
1835 Value replacement = rewriter.createOrFold<ReplicateOp>(
1836 op.getLoc(), inputs[i], repl.getMultiple() + 1);
1837 return flattenConcat(i - 1, i, replacement);
1843 if (
auto extract = inputs[i].getDefiningOp<ExtractOp>()) {
1844 if (
auto prevExtract = inputs[i - 1].getDefiningOp<ExtractOp>()) {
1845 if (extract.getInput() == prevExtract.getInput()) {
1846 auto thisWidth = cast<IntegerType>(extract.getType()).getWidth();
1847 if (prevExtract.getLowBit() == extract.getLowBit() + thisWidth) {
1848 auto prevWidth = prevExtract.getType().getIntOrFloatBitWidth();
1849 auto resType = rewriter.getIntegerType(thisWidth + prevWidth);
1850 Value replacement = rewriter.create<
ExtractOp>(
1851 op.getLoc(), resType, extract.getInput(),
1852 extract.getLowBit());
1853 return flattenConcat(i - 1, i, replacement);
1866 static std::optional<ArraySlice>
get(Value value) {
1867 assert(isa<IntegerType>(value.getType()) &&
"expected integer type");
1869 return ArraySlice{arrayGet.getInput(), arrayGet.getIndex(), 1};
1872 if (
auto arraySlice =
1875 arraySlice.getInput(), arraySlice.getLowIndex(),
1876 hw::type_cast<hw::ArrayType>(arraySlice.getType())
1878 return std::nullopt;
1881 if (
auto extractOpt = ArraySlice::get(inputs[i])) {
1882 if (
auto prevExtractOpt = ArraySlice::get(inputs[i - 1])) {
1884 if (prevExtractOpt->index.getType() == extractOpt->index.getType() &&
1885 prevExtractOpt->input == extractOpt->input &&
1886 hw::isOffset(extractOpt->index, prevExtractOpt->index,
1887 extractOpt->width)) {
1888 auto resType = hw::ArrayType::get(
1889 hw::type_cast<hw::ArrayType>(prevExtractOpt->input.getType())
1891 extractOpt->width + prevExtractOpt->width);
1892 auto resIntType = rewriter.getIntegerType(hw::getBitWidth(resType));
1894 op.getLoc(), resIntType,
1896 prevExtractOpt->input,
1897 extractOpt->index));
1898 return flattenConcat(i - 1, i, replacement);
1906 if (commonOperand) {
1907 replaceOpWithNewOpAndCopyName<ReplicateOp>(rewriter, op, op.getType(),
1919OpFoldResult MuxOp::fold(FoldAdaptor adaptor) {
1924 if (getTrueValue() == getFalseValue() && getTrueValue() != getResult())
1925 return getTrueValue();
1926 if (
auto tv = adaptor.getTrueValue())
1927 if (tv == adaptor.getFalseValue())
1932 if (
auto pred = dyn_cast_or_null<IntegerAttr>(adaptor.getCond())) {
1933 if (pred.getValue().isZero())
1934 return getFalseValue();
1935 return getTrueValue();
1939 if (
auto tv = dyn_cast_or_null<IntegerAttr>(adaptor.getTrueValue()))
1940 if (
auto fv = dyn_cast_or_null<IntegerAttr>(adaptor.getFalseValue()))
1941 if (tv.getValue().isOne() && fv.getValue().isZero() &&
1942 hw::getBitWidth(getType()) == 1)
1958 if (
auto cmp = cond.getDefiningOp<ICmpOp>()) {
1960 auto requiredPredicate =
1961 (isInverted ? ICmpPredicate::eq : ICmpPredicate::ne);
1962 if (cmp.getLhs() == indexValue && cmp.getPredicate() == requiredPredicate) {
1972 if (
auto orOp = cond.getDefiningOp<
OrOp>()) {
1975 for (
auto operand : orOp.getOperands())
1982 if (
auto andOp = cond.getDefiningOp<
AndOp>()) {
1985 for (
auto operand : andOp.getOperands())
2003 PatternRewriter &rewriter) {
2006 auto rootCmp = rootMux.getCond().getDefiningOp<ICmpOp>();
2009 Value indexValue = rootCmp.getLhs();
2012 auto getCaseValue = [&](
MuxOp mux) -> Value {
2013 return mux.getOperand(1 +
unsigned(!isFalseSide));
2018 auto getTreeValue = [&](
MuxOp mux) -> Value {
2019 return mux.getOperand(1 +
unsigned(isFalseSide));
2024 SmallVector<Location> locationsFound;
2025 SmallVector<std::pair<hw::ConstantOp, Value>, 4> valuesFound;
2029 auto collectConstantValues = [&](
MuxOp mux) ->
bool {
2031 mux.getCond(), indexValue, isFalseSide, [&](
hw::ConstantOp cst) {
2032 valuesFound.push_back({cst, getCaseValue(mux)});
2033 locationsFound.push_back(mux.getCond().getLoc());
2034 locationsFound.push_back(mux->getLoc());
2039 if (!collectConstantValues(rootMux))
2043 if (rootMux->hasOneUse()) {
2044 if (
auto userMux = dyn_cast<MuxOp>(*rootMux->user_begin())) {
2045 if (getTreeValue(userMux) == rootMux.getResult() &&
2053 auto nextTreeValue = getTreeValue(rootMux);
2055 auto nextMux = nextTreeValue.getDefiningOp<
MuxOp>();
2056 if (!nextMux || !nextMux->hasOneUse())
2058 if (!collectConstantValues(nextMux))
2060 nextTreeValue = getTreeValue(nextMux);
2066 if (valuesFound.size() < 3)
2071 auto indexWidth = cast<IntegerType>(indexValue.getType()).getWidth();
2072 if (indexWidth >= 9)
2078 uint64_t tableSize = 1ULL << indexWidth;
2079 if (valuesFound.size() < (tableSize * 5) / 8)
2084 SmallVector<Value, 8> table(tableSize, nextTreeValue);
2089 for (
auto &elt :
llvm::reverse(valuesFound)) {
2090 uint64_t idx = elt.first.getValue().getZExtValue();
2091 assert(idx < table.size() &&
"constant should be same bitwidth as index");
2092 table[idx] = elt.second;
2097 std::reverse(table.begin(), table.end());
2100 auto fusedLoc = rewriter.getFusedLoc(locationsFound);
2102 replaceOpWithNewOpAndCopyName<hw::ArrayGetOp>(rewriter, rootMux, array,
2117 PatternRewriter &rewriter) {
2118 assert(fullyAssoc->getNumOperands() >= 2 &&
"cannot split up unary ops");
2119 assert(operandNo < fullyAssoc->getNumOperands() &&
"Invalid operand #");
2123 if (fullyAssoc->getNumOperands() == 2)
2124 return fullyAssoc->getOperand(operandNo ^ 1);
2127 if (fullyAssoc->hasOneUse()) {
2128 rewriter.modifyOpInPlace(fullyAssoc,
2129 [&]() { fullyAssoc->eraseOperand(operandNo); });
2130 return fullyAssoc->getResult(0);
2134 SmallVector<Value> operands;
2135 operands.append(fullyAssoc->getOperands().begin(),
2136 fullyAssoc->getOperands().begin() + operandNo);
2137 operands.append(fullyAssoc->getOperands().begin() + operandNo + 1,
2138 fullyAssoc->getOperands().end());
2140 fullyAssoc->getLoc(), fullyAssoc->getName(), operands, rewriter);
2141 Value excluded = fullyAssoc->getOperand(operandNo);
2145 ArrayRef<Value>{opWithoutExcluded, excluded}, rewriter);
2147 return opWithoutExcluded;
2157 PatternRewriter &rewriter) {
2160 Operation *subExpr =
2161 (isTrueOperand ? op.getFalseValue() : op.getTrueValue()).getDefiningOp();
2162 if (!subExpr || subExpr->getNumOperands() < 2)
2166 if (!isa<AndOp, XorOp, OrOp, MuxOp>(subExpr))
2171 Value commonValue = isTrueOperand ? op.getTrueValue() : op.getFalseValue();
2172 size_t opNo = 0, e = subExpr->getNumOperands();
2173 while (opNo != e && subExpr->getOperand(opNo) != commonValue)
2179 Value cond = op.getCond();
2185 if (
auto subMux = dyn_cast<MuxOp>(subExpr)) {
2190 Value subCond = subMux.getCond();
2193 if (subMux.getTrueValue() == commonValue)
2194 otherValue = subMux.getFalseValue();
2195 else if (subMux.getFalseValue() == commonValue) {
2196 otherValue = subMux.getTrueValue();
2206 cond = rewriter.createOrFold<
OrOp>(op.getLoc(), cond, subCond,
false);
2207 replaceOpWithNewOpAndCopyName<MuxOp>(rewriter, op, cond, commonValue,
2208 otherValue, op.getTwoState());
2214 bool isaAndOp = isa<AndOp>(subExpr);
2215 if (isTrueOperand ^ isaAndOp)
2219 rewriter.createOrFold<ReplicateOp>(op.getLoc(), op.getType(), cond);
2222 bool isaXorOp = isa<XorOp>(subExpr);
2223 bool isaOrOp = isa<OrOp>(subExpr);
2232 if (isaOrOp || isaXorOp) {
2233 auto masked = rewriter.createOrFold<
AndOp>(op.getLoc(), extendedCond,
2234 restOfAssoc,
false);
2236 replaceOpWithNewOpAndCopyName<XorOp>(rewriter, op, masked, commonValue,
2239 replaceOpWithNewOpAndCopyName<OrOp>(rewriter, op, masked, commonValue,
2245 assert(isaAndOp &&
"unexpected operation here");
2246 auto masked = rewriter.createOrFold<
OrOp>(op.getLoc(), extendedCond,
2247 restOfAssoc,
false);
2248 replaceOpWithNewOpAndCopyName<AndOp>(rewriter, op, masked, commonValue,
2259 PatternRewriter &rewriter) {
2262 if (!isa<ConcatOp>(trueOp))
2266 SmallVector<Value> trueOperands, falseOperands;
2270 size_t numTrueOperands = trueOperands.size();
2271 size_t numFalseOperands = falseOperands.size();
2273 if (!numTrueOperands || !numFalseOperands ||
2274 (trueOperands.front() != falseOperands.front() &&
2275 trueOperands.back() != falseOperands.back()))
2279 if (trueOperands.front() == falseOperands.front()) {
2280 SmallVector<Value> operands;
2282 for (i = 0; i < numTrueOperands; ++i) {
2283 Value trueOperand = trueOperands[i];
2284 if (trueOperand == falseOperands[i])
2285 operands.push_back(trueOperand);
2289 if (i == numTrueOperands) {
2296 if (llvm::all_of(operands, [&](Value v) {
return v == operands.front(); }))
2297 sharedMSB = rewriter.createOrFold<ReplicateOp>(
2298 mux->getLoc(), operands.front(), operands.size());
2300 sharedMSB = rewriter.createOrFold<
ConcatOp>(mux->getLoc(), operands);
2304 operands.append(trueOperands.begin() + i, trueOperands.end());
2305 Value trueLSB = rewriter.createOrFold<
ConcatOp>(trueOp->getLoc(), operands);
2307 operands.append(falseOperands.begin() + i, falseOperands.end());
2309 rewriter.createOrFold<
ConcatOp>(falseOp->getLoc(), operands);
2312 Value lsb = rewriter.createOrFold<
MuxOp>(
2313 mux->getLoc(), mux.getCond(), trueLSB, falseLSB, mux.getTwoState());
2314 replaceOpWithNewOpAndCopyName<ConcatOp>(rewriter, mux, sharedMSB, lsb);
2319 if (trueOperands.back() == falseOperands.back()) {
2320 SmallVector<Value> operands;
2323 Value trueOperand = trueOperands[numTrueOperands - i - 1];
2324 if (trueOperand == falseOperands[numFalseOperands - i - 1])
2325 operands.push_back(trueOperand);
2329 std::reverse(operands.begin(), operands.end());
2330 Value sharedLSB = rewriter.createOrFold<
ConcatOp>(mux->getLoc(), operands);
2334 operands.append(trueOperands.begin(), trueOperands.end() - i);
2335 Value trueMSB = rewriter.createOrFold<
ConcatOp>(trueOp->getLoc(), operands);
2337 operands.append(falseOperands.begin(), falseOperands.end() - i);
2339 rewriter.createOrFold<
ConcatOp>(falseOp->getLoc(), operands);
2341 Value msb = rewriter.createOrFold<
MuxOp>(
2342 mux->getLoc(), mux.getCond(), trueMSB, falseMSB, mux.getTwoState());
2343 replaceOpWithNewOpAndCopyName<ConcatOp>(rewriter, mux, msb, sharedLSB);
2355 if (!trueVec || !falseVec)
2357 if (!trueVec.isUniform() || !falseVec.isUniform())
2361 op.getLoc(), op.getCond(), trueVec.getUniformElement(),
2362 falseVec.getUniformElement(), op.getTwoState());
2364 SmallVector<Value> values(trueVec.getInputs().size(), mux);
2371 using OpRewritePattern::OpRewritePattern;
2373 LogicalResult matchAndRewrite(
MuxOp op,
2374 PatternRewriter &rewriter)
const override;
2377LogicalResult MuxRewriter::matchAndRewrite(
MuxOp op,
2378 PatternRewriter &rewriter)
const {
2387 if (matchPattern(op.getTrueValue(), m_ConstantInt(&value))) {
2388 if (value.getBitWidth() == 1) {
2390 if (value.isZero()) {
2392 replaceOpWithNewOpAndCopyName<AndOp>(rewriter, op, notCond,
2393 op.getFalseValue(),
false);
2398 replaceOpWithNewOpAndCopyName<OrOp>(rewriter, op, op.getCond(),
2399 op.getFalseValue(),
false);
2405 if (matchPattern(op.getFalseValue(), m_ConstantInt(&value2))) {
2410 APInt xorValue = value ^ value2;
2411 if (xorValue.isPowerOf2()) {
2412 unsigned leadingZeros = xorValue.countLeadingZeros();
2413 unsigned trailingZeros = value.getBitWidth() - leadingZeros - 1;
2414 SmallVector<Value, 3> operands;
2422 if (leadingZeros > 0)
2423 operands.push_back(rewriter.createOrFold<
ExtractOp>(
2424 op.getLoc(), op.getTrueValue(), trailingZeros + 1, leadingZeros));
2428 auto v1 = rewriter.createOrFold<
ExtractOp>(
2429 op.getLoc(), op.getTrueValue(), trailingZeros, 1);
2430 auto v2 = rewriter.createOrFold<
ExtractOp>(
2431 op.getLoc(), op.getFalseValue(), trailingZeros, 1);
2432 operands.push_back(rewriter.createOrFold<
MuxOp>(
2433 op.getLoc(), op.getCond(), v1, v2,
false));
2435 if (trailingZeros > 0)
2436 operands.push_back(rewriter.createOrFold<
ExtractOp>(
2437 op.getLoc(), op.getTrueValue(), 0, trailingZeros));
2439 replaceOpWithNewOpAndCopyName<ConcatOp>(rewriter, op, op.getType(),
2446 if (value.isAllOnes() && value2.isZero()) {
2447 replaceOpWithNewOpAndCopyName<ReplicateOp>(rewriter, op, op.getType(),
2454 if (matchPattern(op.getFalseValue(), m_ConstantInt(&value)) &&
2455 value.getBitWidth() == 1) {
2457 if (value.isZero()) {
2458 replaceOpWithNewOpAndCopyName<AndOp>(rewriter, op, op.getCond(),
2459 op.getTrueValue(),
false);
2466 auto notCond = rewriter.createOrFold<
XorOp>(op.getLoc(), op.getCond(),
2467 op.getFalseValue(),
false);
2468 replaceOpWithNewOpAndCopyName<OrOp>(rewriter, op, notCond,
2469 op.getTrueValue(),
false);
2475 Operation *condOp = op.getCond().getDefiningOp();
2476 if (condOp && matchPattern(condOp,
m_Complement(m_Any(&subExpr))) &&
2478 replaceOpWithNewOpAndCopyName<MuxOp>(rewriter, op, op.getType(), subExpr,
2479 op.getFalseValue(), op.getTrueValue(),
2487 if (condOp && condOp->hasOneUse()) {
2488 SmallVector<Value> invertedOperands;
2492 auto getInvertedOperands = [&]() ->
bool {
2493 for (Value operand : condOp->getOperands()) {
2494 if (matchPattern(operand,
m_Complement(m_Any(&subExpr))))
2495 invertedOperands.push_back(subExpr);
2502 if (isa<AndOp>(condOp) && getInvertedOperands()) {
2504 rewriter.createOrFold<
OrOp>(op.getLoc(), invertedOperands,
false);
2505 replaceOpWithNewOpAndCopyName<MuxOp>(rewriter, op, newOr,
2507 op.getTrueValue(), op.getTwoState());
2510 if (isa<OrOp>(condOp) && getInvertedOperands()) {
2512 rewriter.createOrFold<
AndOp>(op.getLoc(), invertedOperands,
false);
2513 replaceOpWithNewOpAndCopyName<MuxOp>(rewriter, op, newAnd,
2515 op.getTrueValue(), op.getTwoState());
2520 if (
auto falseMux = op.getFalseValue().getDefiningOp<
MuxOp>();
2521 falseMux && falseMux != op) {
2523 if (op.getCond() == falseMux.getCond()) {
2524 replaceOpWithNewOpAndCopyName<MuxOp>(
2525 rewriter, op, op.getCond(), op.getTrueValue(),
2526 falseMux.getFalseValue(), op.getTwoStateAttr());
2535 if (
auto trueMux = op.getTrueValue().getDefiningOp<
MuxOp>();
2536 trueMux && trueMux != op) {
2538 if (op.getCond() == trueMux.getCond()) {
2539 replaceOpWithNewOpAndCopyName<MuxOp>(
2540 rewriter, op, op.getCond(), trueMux.getTrueValue(),
2541 op.getFalseValue(), op.getTwoStateAttr());
2551 if (
auto trueMux = dyn_cast_or_null<MuxOp>(op.getTrueValue().getDefiningOp()),
2552 falseMux = dyn_cast_or_null<MuxOp>(op.getFalseValue().getDefiningOp());
2553 trueMux && falseMux && trueMux.getCond() == falseMux.getCond() &&
2554 trueMux.getTrueValue() == falseMux.getTrueValue() && trueMux != op &&
2556 auto subMux = rewriter.create<
MuxOp>(
2557 rewriter.getFusedLoc({trueMux.getLoc(), falseMux.getLoc()}),
2558 op.getCond(), trueMux.getFalseValue(), falseMux.getFalseValue());
2559 replaceOpWithNewOpAndCopyName<MuxOp>(rewriter, op, trueMux.getCond(),
2560 trueMux.getTrueValue(), subMux,
2561 op.getTwoStateAttr());
2566 if (
auto trueMux = dyn_cast_or_null<MuxOp>(op.getTrueValue().getDefiningOp()),
2567 falseMux = dyn_cast_or_null<MuxOp>(op.getFalseValue().getDefiningOp());
2568 trueMux && falseMux && trueMux.getCond() == falseMux.getCond() &&
2569 trueMux.getFalseValue() == falseMux.getFalseValue() && trueMux != op &&
2571 auto subMux = rewriter.create<
MuxOp>(
2572 rewriter.getFusedLoc({trueMux.getLoc(), falseMux.getLoc()}),
2573 op.getCond(), trueMux.getTrueValue(), falseMux.getTrueValue());
2574 replaceOpWithNewOpAndCopyName<MuxOp>(rewriter, op, trueMux.getCond(),
2575 subMux, trueMux.getFalseValue(),
2576 op.getTwoStateAttr());
2581 if (
auto trueMux = dyn_cast_or_null<MuxOp>(op.getTrueValue().getDefiningOp()),
2582 falseMux = dyn_cast_or_null<MuxOp>(op.getFalseValue().getDefiningOp());
2583 trueMux && falseMux &&
2584 trueMux.getTrueValue() == falseMux.getTrueValue() &&
2585 trueMux.getFalseValue() == falseMux.getFalseValue() && trueMux != op &&
2587 auto subMux = rewriter.create<
MuxOp>(
2588 rewriter.getFusedLoc(
2589 {op.getLoc(), trueMux.getLoc(), falseMux.getLoc()}),
2590 op.getCond(), trueMux.getCond(), falseMux.getCond());
2591 replaceOpWithNewOpAndCopyName<MuxOp>(
2592 rewriter, op, subMux, trueMux.getTrueValue(), trueMux.getFalseValue(),
2593 op.getTwoStateAttr());
2605 if (Operation *trueOp = op.getTrueValue().getDefiningOp())
2606 if (Operation *falseOp = op.getFalseValue().getDefiningOp())
2607 if (trueOp->getName() == falseOp->getName())
2624 if (op.getInputs().empty() || op.isUniform())
2626 auto inputs = op.getInputs();
2627 if (inputs.size() <= 1)
2632 auto first = inputs[0].getDefiningOp<
comb::MuxOp>();
2637 for (
size_t i = 1, n = inputs.size(); i < n; ++i) {
2638 auto input = inputs[i].getDefiningOp<
comb::MuxOp>();
2639 if (!input || first.getCond() != input.getCond())
2644 SmallVector<Value> trues{first.getTrueValue()};
2645 SmallVector<Value> falses{first.getFalseValue()};
2646 SmallVector<Location> locs{first->getLoc()};
2647 bool isTwoState =
true;
2648 for (
size_t i = 1, n = inputs.size(); i < n; ++i) {
2649 auto input = inputs[i].getDefiningOp<
comb::MuxOp>();
2650 trues.push_back(input.getTrueValue());
2651 falses.push_back(input.getFalseValue());
2652 locs.push_back(input->getLoc());
2653 if (!input.getTwoState())
2658 auto loc = FusedLoc::get(op.getContext(), locs);
2662 auto arrayTy = op.getType();
2665 rewriter.replaceOpWithNewOp<
comb::MuxOp>(op, arrayTy, first.getCond(),
2666 trueValues, falseValues, isTwoState);
2671 using OpRewritePattern::OpRewritePattern;
2674 PatternRewriter &rewriter)
const override {
2678 if (foldArrayOfMuxes(op, rewriter))
2686void MuxOp::getCanonicalizationPatterns(RewritePatternSet &results,
2687 MLIRContext *context) {
2688 results.insert<MuxRewriter, ArrayRewriter>(context);
2699 switch (predicate) {
2700 case ICmpPredicate::eq:
2702 case ICmpPredicate::ne:
2704 case ICmpPredicate::slt:
2705 return lhs.slt(rhs);
2706 case ICmpPredicate::sle:
2707 return lhs.sle(rhs);
2708 case ICmpPredicate::sgt:
2709 return lhs.sgt(rhs);
2710 case ICmpPredicate::sge:
2711 return lhs.sge(rhs);
2712 case ICmpPredicate::ult:
2713 return lhs.ult(rhs);
2714 case ICmpPredicate::ule:
2715 return lhs.ule(rhs);
2716 case ICmpPredicate::ugt:
2717 return lhs.ugt(rhs);
2718 case ICmpPredicate::uge:
2719 return lhs.uge(rhs);
2720 case ICmpPredicate::ceq:
2722 case ICmpPredicate::cne:
2724 case ICmpPredicate::weq:
2726 case ICmpPredicate::wne:
2729 llvm_unreachable(
"unknown comparison predicate");
2735 switch (predicate) {
2736 case ICmpPredicate::eq:
2737 case ICmpPredicate::sle:
2738 case ICmpPredicate::sge:
2739 case ICmpPredicate::ule:
2740 case ICmpPredicate::uge:
2741 case ICmpPredicate::ceq:
2742 case ICmpPredicate::weq:
2744 case ICmpPredicate::ne:
2745 case ICmpPredicate::slt:
2746 case ICmpPredicate::sgt:
2747 case ICmpPredicate::ult:
2748 case ICmpPredicate::ugt:
2749 case ICmpPredicate::cne:
2750 case ICmpPredicate::wne:
2753 llvm_unreachable(
"unknown comparison predicate");
2756OpFoldResult ICmpOp::fold(FoldAdaptor adaptor) {
2762 if (getLhs() == getRhs()) {
2764 return IntegerAttr::get(getType(), val);
2768 if (
auto lhs = dyn_cast_or_null<IntegerAttr>(adaptor.getLhs())) {
2769 if (
auto rhs = dyn_cast_or_null<IntegerAttr>(adaptor.getRhs())) {
2772 return IntegerAttr::get(getType(), val);
2780template <
typename Range>
2782 size_t commonPrefixLength = 0;
2783 auto ia = a.begin();
2784 auto ib = b.begin();
2786 for (; ia != a.end() && ib != b.end(); ia++, ib++, commonPrefixLength++) {
2792 return commonPrefixLength;
2796 size_t totalWidth = 0;
2797 for (
auto operand : operands) {
2800 ssize_t width = operand.getType().getIntOrFloatBitWidth();
2802 totalWidth += width;
2812 PatternRewriter &rewriter) {
2816 SmallVector<Value> lhsOperands, rhsOperands;
2819 ArrayRef<Value> lhsOperandsRef = lhsOperands, rhsOperandsRef = rhsOperands;
2821 auto formCatOrReplicate = [&](Location loc,
2822 ArrayRef<Value> operands) -> Value {
2823 assert(!operands.empty());
2824 Value sameElement = operands[0];
2825 for (
size_t i = 1, e = operands.size(); i != e && sameElement; ++i)
2826 if (sameElement != operands[i])
2827 sameElement = Value();
2829 return rewriter.createOrFold<ReplicateOp>(loc, sameElement,
2831 return rewriter.createOrFold<
ConcatOp>(loc, operands);
2834 auto replaceWith = [&](ICmpPredicate predicate, Value lhs,
2835 Value rhs) -> LogicalResult {
2836 replaceOpWithNewOpAndCopyName<ICmpOp>(rewriter, op, predicate, lhs, rhs,
2841 size_t commonPrefixLength =
2843 if (commonPrefixLength == lhsOperands.size()) {
2846 replaceOpWithNewOpAndCopyName<hw::ConstantOp>(rewriter, op,
2852 llvm::reverse(lhsOperandsRef), llvm::reverse(rhsOperandsRef));
2854 size_t commonPrefixTotalWidth =
2855 getTotalWidth(lhsOperandsRef.take_front(commonPrefixLength));
2856 size_t commonSuffixTotalWidth =
2857 getTotalWidth(lhsOperandsRef.take_back(commonSuffixLength));
2858 auto lhsOnly = lhsOperandsRef.drop_front(commonPrefixLength)
2859 .drop_back(commonSuffixLength);
2860 auto rhsOnly = rhsOperandsRef.drop_front(commonPrefixLength)
2861 .drop_back(commonSuffixLength);
2863 auto replaceWithoutReplicatingSignBit = [&]() {
2864 auto newLhs = formCatOrReplicate(lhs->getLoc(), lhsOnly);
2865 auto newRhs = formCatOrReplicate(rhs->getLoc(), rhsOnly);
2866 return replaceWith(op.getPredicate(), newLhs, newRhs);
2869 auto replaceWithReplicatingSignBit = [&]() {
2870 auto firstNonEmptyValue = lhsOperands[0];
2871 auto firstNonEmptyElemWidth =
2872 firstNonEmptyValue.getType().getIntOrFloatBitWidth();
2873 Value signBit = rewriter.createOrFold<
ExtractOp>(
2874 op.getLoc(), firstNonEmptyValue, firstNonEmptyElemWidth - 1, 1);
2876 auto newLhs = rewriter.
create<
ConcatOp>(lhs->getLoc(), signBit, lhsOnly);
2877 auto newRhs = rewriter.create<
ConcatOp>(rhs->getLoc(), signBit, rhsOnly);
2878 return replaceWith(op.getPredicate(), newLhs, newRhs);
2881 if (ICmpOp::isPredicateSigned(op.getPredicate())) {
2883 if (commonPrefixTotalWidth == 0 && commonSuffixTotalWidth > 0)
2884 return replaceWithoutReplicatingSignBit();
2890 if (commonPrefixTotalWidth > 1 || commonSuffixTotalWidth > 0)
2891 return replaceWithReplicatingSignBit();
2893 }
else if (commonPrefixTotalWidth > 0 || commonSuffixTotalWidth > 0) {
2895 return replaceWithoutReplicatingSignBit();
2909 ICmpOp cmpOp,
const KnownBits &bitAnalysis,
const APInt &rhsCst,
2910 PatternRewriter &rewriter) {
2914 APInt bitsKnown = bitAnalysis.Zero | bitAnalysis.One;
2915 if ((bitsKnown & rhsCst) != bitAnalysis.One) {
2918 bool result = cmpOp.getPredicate() == ICmpPredicate::ne;
2919 replaceOpWithNewOpAndCopyName<hw::ConstantOp>(rewriter, cmpOp,
2927 SmallVector<Value> newConcatOperands;
2928 auto newConstant = APInt::getZeroWidth();
2933 unsigned knownMSB = bitsKnown.countLeadingOnes();
2935 Value operand = cmpOp.getLhs();
2940 while (knownMSB != bitsKnown.getBitWidth()) {
2943 bitsKnown = bitsKnown.trunc(bitsKnown.getBitWidth() - knownMSB);
2946 unsigned unknownBits = bitsKnown.countLeadingZeros();
2947 unsigned lowBit = bitsKnown.getBitWidth() - unknownBits;
2948 auto spanOperand = rewriter.createOrFold<
ExtractOp>(
2949 operand.getLoc(), operand, lowBit,
2951 auto spanConstant = rhsCst.lshr(lowBit).trunc(unknownBits);
2954 newConcatOperands.push_back(spanOperand);
2957 if (newConstant.getBitWidth() != 0)
2958 newConstant = newConstant.concat(spanConstant);
2960 newConstant = spanConstant;
2963 unsigned newWidth = bitsKnown.getBitWidth() - unknownBits;
2964 bitsKnown = bitsKnown.trunc(newWidth);
2965 knownMSB = bitsKnown.countLeadingOnes();
2971 if (newConcatOperands.empty()) {
2972 bool result = cmpOp.getPredicate() == ICmpPredicate::eq;
2973 replaceOpWithNewOpAndCopyName<hw::ConstantOp>(rewriter, cmpOp,
2979 Value concatResult =
2980 rewriter.createOrFold<
ConcatOp>(operand.getLoc(), newConcatOperands);
2984 cmpOp.getOperand(1).getLoc(), newConstant);
2986 replaceOpWithNewOpAndCopyName<ICmpOp>(rewriter, cmpOp, cmpOp.getPredicate(),
2987 concatResult, newConstantOp,
2988 cmpOp.getTwoState());
2994 PatternRewriter &rewriter) {
2995 auto ip = rewriter.saveInsertionPoint();
2996 rewriter.setInsertionPoint(xorOp);
2998 auto xorRHS = xorOp.getOperands().back().getDefiningOp<
hw::ConstantOp>();
3000 xorRHS.getValue() ^ rhs);
3002 switch (xorOp.getNumOperands()) {
3006 APInt::getZero(rhs.getBitWidth()));
3010 newLHS = xorOp.getOperand(0);
3014 SmallVector<Value> newOperands(xorOp.getOperands());
3015 newOperands.pop_back();
3016 newLHS = rewriter.create<
XorOp>(xorOp.getLoc(), newOperands,
false);
3020 bool xorMultipleUses = !xorOp->hasOneUse();
3024 if (xorMultipleUses)
3025 replaceOpWithNewOpAndCopyName<XorOp>(rewriter, xorOp, newLHS, xorRHS,
3029 rewriter.restoreInsertionPoint(ip);
3030 replaceOpWithNewOpAndCopyName<ICmpOp>(rewriter, cmpOp, cmpOp.getPredicate(),
3031 newLHS, newRHS,
false);
3034LogicalResult ICmpOp::canonicalize(ICmpOp op, PatternRewriter &rewriter) {
3041 if (matchPattern(op.getLhs(), m_ConstantInt(&lhs))) {
3042 assert(!matchPattern(op.getRhs(), m_ConstantInt(&rhs)) &&
3043 "Should be folded");
3044 replaceOpWithNewOpAndCopyName<ICmpOp>(
3045 rewriter, op, ICmpOp::getFlippedPredicate(op.getPredicate()),
3046 op.getRhs(), op.getLhs(), op.getTwoState());
3051 if (matchPattern(op.getRhs(), m_ConstantInt(&rhs))) {
3053 return rewriter.create<
hw::ConstantOp>(op.getLoc(), std::move(constant));
3056 auto replaceWith = [&](ICmpPredicate predicate, Value lhs,
3057 Value rhs) -> LogicalResult {
3058 replaceOpWithNewOpAndCopyName<ICmpOp>(rewriter, op, predicate, lhs, rhs,
3063 auto replaceWithConstantI1 = [&](
bool constant) -> LogicalResult {
3064 replaceOpWithNewOpAndCopyName<hw::ConstantOp>(rewriter, op,
3065 APInt(1, constant));
3069 switch (op.getPredicate()) {
3070 case ICmpPredicate::slt:
3072 if (rhs.isMaxSignedValue())
3073 return replaceWith(ICmpPredicate::ne, op.getLhs(), op.getRhs());
3075 if (rhs.isMinSignedValue())
3076 return replaceWithConstantI1(0);
3078 if ((rhs - 1).isMinSignedValue())
3079 return replaceWith(ICmpPredicate::eq, op.getLhs(),
3082 case ICmpPredicate::sgt:
3084 if (rhs.isMinSignedValue())
3085 return replaceWith(ICmpPredicate::ne, op.getLhs(), op.getRhs());
3087 if (rhs.isMaxSignedValue())
3088 return replaceWithConstantI1(0);
3090 if ((rhs + 1).isMaxSignedValue())
3091 return replaceWith(ICmpPredicate::eq, op.getLhs(),
3094 case ICmpPredicate::ult:
3096 if (rhs.isAllOnes())
3097 return replaceWith(ICmpPredicate::ne, op.getLhs(), op.getRhs());
3100 return replaceWithConstantI1(0);
3102 if ((rhs - 1).isZero())
3103 return replaceWith(ICmpPredicate::eq, op.getLhs(),
3107 if (rhs.countLeadingOnes() + rhs.countTrailingZeros() ==
3108 rhs.getBitWidth()) {
3109 auto numOnes = rhs.countLeadingOnes();
3110 auto smaller = rewriter.create<
ExtractOp>(
3111 op.getLoc(), op.getLhs(), rhs.getBitWidth() - numOnes, numOnes);
3112 return replaceWith(ICmpPredicate::ne, smaller,
3117 case ICmpPredicate::ugt:
3120 return replaceWith(ICmpPredicate::ne, op.getLhs(), op.getRhs());
3122 if (rhs.isAllOnes())
3123 return replaceWithConstantI1(0);
3125 if ((rhs + 1).isAllOnes())
3126 return replaceWith(ICmpPredicate::eq, op.getLhs(),
3130 if ((rhs + 1).isPowerOf2()) {
3131 auto numOnes = rhs.countTrailingOnes();
3132 auto newWidth = rhs.getBitWidth() - numOnes;
3133 auto smaller = rewriter.create<
ExtractOp>(op.getLoc(), op.getLhs(),
3135 return replaceWith(ICmpPredicate::ne, smaller,
3140 case ICmpPredicate::sle:
3142 if (rhs.isMaxSignedValue())
3143 return replaceWithConstantI1(1);
3145 return replaceWith(ICmpPredicate::slt, op.getLhs(),
getConstant(rhs + 1));
3146 case ICmpPredicate::sge:
3148 if (rhs.isMinSignedValue())
3149 return replaceWithConstantI1(1);
3151 return replaceWith(ICmpPredicate::sgt, op.getLhs(),
getConstant(rhs - 1));
3152 case ICmpPredicate::ule:
3154 if (rhs.isAllOnes())
3155 return replaceWithConstantI1(1);
3157 return replaceWith(ICmpPredicate::ult, op.getLhs(),
getConstant(rhs + 1));
3158 case ICmpPredicate::uge:
3161 return replaceWithConstantI1(1);
3163 return replaceWith(ICmpPredicate::ugt, op.getLhs(),
getConstant(rhs - 1));
3164 case ICmpPredicate::eq:
3165 if (rhs.getBitWidth() == 1) {
3168 replaceOpWithNewOpAndCopyName<XorOp>(rewriter, op, op.getLhs(),
3173 if (rhs.isAllOnes()) {
3180 case ICmpPredicate::ne:
3181 if (rhs.getBitWidth() == 1) {
3187 if (rhs.isAllOnes()) {
3189 replaceOpWithNewOpAndCopyName<XorOp>(rewriter, op, op.getLhs(),
3196 case ICmpPredicate::ceq:
3197 case ICmpPredicate::cne:
3198 case ICmpPredicate::weq:
3199 case ICmpPredicate::wne:
3205 if (op.getPredicate() == ICmpPredicate::eq ||
3206 op.getPredicate() == ICmpPredicate::ne) {
3211 if (!knownBits.isUnknown())
3218 if (
auto xorOp = op.getLhs().getDefiningOp<
XorOp>())
3225 if (
auto replicateOp = op.getLhs().getDefiningOp<ReplicateOp>())
3226 if (rhs.isAllOnes() || rhs.isZero()) {
3227 auto width = replicateOp.getInput().getType().getIntOrFloatBitWidth();
3229 op.getLoc(), rhs.isAllOnes() ? APInt::getAllOnes(width)
3230 : APInt::getZero(width));
3231 replaceOpWithNewOpAndCopyName<ICmpOp>(rewriter, op, op.getPredicate(),
3232 replicateOp.getInput(), cst,
3242 if (Operation *opLHS = op.getLhs().getDefiningOp())
3243 if (Operation *opRHS = op.getRhs().getDefiningOp())
3244 if (isa<ConcatOp, ReplicateOp>(opLHS) &&
3245 isa<ConcatOp, ReplicateOp>(opRHS)) {
assert(baseType &&"element must be base type")
static SmallVector< T > concat(const SmallVectorImpl< T > &a, const SmallVectorImpl< T > &b)
Returns a new vector containing the concatenation of vectors a and b.
static KnownBits computeKnownBits(Value v, unsigned depth)
Given an integer SSA value, check to see if we know anything about the result of the computation.
static bool foldMuxOfUniformArrays(MuxOp op, PatternRewriter &rewriter)
static Attribute constFoldAssociativeOp(ArrayRef< Attribute > operands, hw::PEO paramOpcode)
static Attribute constFoldBinaryOp(ArrayRef< Attribute > operands, hw::PEO paramOpcode)
Performs constant folding with element-wise behavior on the two attributes in operands and ...
static bool applyCmpPredicateToEqualOperands(ICmpPredicate predicate)
static ComplementMatcher< SubType > m_Complement(const SubType &subExpr)
static bool canonicalizeLogicalCstWithConcat(Operation *logicalOp, size_t concatIdx, const APInt &cst, PatternRewriter &rewriter)
When we find a logical operation (and, or, xor) with a constant e.g.
static bool hasOperandsOutsideOfBlock(Operation *op)
In comb, we assume no knowledge of the semantics of cross-block dataflow.
static bool narrowOperationWidth(OpTy op, bool narrowTrailingBits, PatternRewriter &rewriter)
static OpFoldResult foldDiv(Op op, ArrayRef< Attribute > constants)
static Value getCommonOperand(Op op)
Returns a single common operand that all inputs of the operation op can be traced back to,...
static bool canCombineOppositeBinCmpIntoConstant(OperandRange operands)
static void getConcatOperands(Value v, SmallVectorImpl< Value > &result)
Flatten concat and replicate operands into a vector.
static OpTy replaceOpWithNewOpAndCopyName(PatternRewriter &rewriter, Operation *op, Args &&...args)
A wrapper of PatternRewriter::replaceOpWithNewOp to propagate "sv.namehint" attribute.
static Value extractOperandFromFullyAssociative(Operation *fullyAssoc, size_t operandNo, PatternRewriter &rewriter)
Given a fully associative variadic operation like (a+b+c+d), break the expression into two parts,...
static bool getMuxChainCondConstant(Value cond, Value indexValue, bool isInverted, std::function< void(hw::ConstantOp)> constantFn)
Check to see if the condition to the specified mux is an equality comparison indexValue and one or mo...
static TypedAttr getIntAttr(const APInt &value, MLIRContext *context)
static bool shouldBeFlattened(Operation *op)
Return true if the op will be flattened afterwards.
static void canonicalizeXorIcmpTrue(XorOp op, unsigned icmpOperand, PatternRewriter &rewriter)
static bool extractFromReplicate(ExtractOp op, ReplicateOp replicate, PatternRewriter &rewriter)
static void combineEqualityICmpWithXorOfConstant(ICmpOp cmpOp, XorOp xorOp, const APInt &rhs, PatternRewriter &rewriter)
static size_t getTotalWidth(ArrayRef< Value > operands)
static bool foldCommonMuxOperation(MuxOp mux, Operation *trueOp, Operation *falseOp, PatternRewriter &rewriter)
This function is invoked when we find a mux with true/false operations that have the same opcode.
static std::pair< size_t, size_t > getLowestBitAndHighestBitRequired(Operation *op, bool narrowTrailingBits, size_t originalOpWidth)
static bool tryFlatteningOperands(Operation *op, PatternRewriter &rewriter)
Flattens a single input in op if hasOneUse is true and it can be defined as an Op.
static bool canonicalizeIdempotentInputs(Op op, PatternRewriter &rewriter)
Canonicalize an idempotent operation op so that only one input of any kind occurs.
static bool applyCmpPredicate(ICmpPredicate predicate, const APInt &lhs, const APInt &rhs)
static void combineEqualityICmpWithKnownBitsAndConstant(ICmpOp cmpOp, const KnownBits &bitAnalysis, const APInt &rhsCst, PatternRewriter &rewriter)
Given an equality comparison with a constant value and some operand that has known bits,...
static bool foldMuxChain(MuxOp rootMux, bool isFalseSide, PatternRewriter &rewriter)
Given a mux, check to see if the "on true" value (or "on false" value if isFalseSide=true) is a mux t...
static bool hasSVAttributes(Operation *op)
static LogicalResult extractConcatToConcatExtract(ExtractOp op, ConcatOp innerCat, PatternRewriter &rewriter)
static OpFoldResult foldMod(Op op, ArrayRef< Attribute > constants)
static size_t computeCommonPrefixLength(const Range &a, const Range &b)
static bool foldCommonMuxValue(MuxOp op, bool isTrueOperand, PatternRewriter &rewriter)
Fold things like mux(cond, x|y|z|a, a) -> (x|y|z)&replicate(cond)|a and mux(cond, a,...
static LogicalResult matchAndRewriteCompareConcat(ICmpOp op, Operation *lhs, Operation *rhs, PatternRewriter &rewriter)
Reduce the strength of icmp(concat(...), concat(...)) by doing an element-wise comparison on common prefi...
static Value createGenericOp(Location loc, OperationName name, ArrayRef< Value > operands, OpBuilder &builder)
Create a new instance of a generic operation that only has value operands, and has a single result va...
static void replaceOpAndCopyName(PatternRewriter &rewriter, Operation *op, Value newValue)
A wrapper of PatternRewriter::replaceOp to propagate "sv.namehint" attribute.
static TypedAttr getIntAttr(MLIRContext *ctx, Type t, const APInt &value)
static std::optional< APSInt > getConstant(Attribute operand)
Determine the value of a constant operand for the sake of constant folding.
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
Value createOrFoldNot(Location loc, Value value, OpBuilder &builder, bool twoState=false)
Create a `Not` gate on a value.
KnownBits computeKnownBits(Value value)
Compute "known bits" information about the specified value - the set of bits that are guaranteed to a...
uint64_t getWidth(Type t)
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.