13#include "mlir/IR/Matchers.h"
14#include "mlir/IR/PatternMatch.h"
15#include "llvm/ADT/SetVector.h"
16#include "llvm/ADT/SmallBitVector.h"
17#include "llvm/ADT/TypeSwitch.h"
18#include "llvm/Support/KnownBits.h"
23using namespace matchers;
34 Block *thisBlock = op->getBlock();
35 return llvm::any_of(op->getOperands(), [&](Value operand) {
36 return operand.getParentBlock() != thisBlock;
46 ArrayRef<Value> operands, OpBuilder &builder) {
47 OperationState state(loc, name);
48 state.addOperands(operands);
49 state.addTypes(operands[0].getType());
50 return builder.create(state)->getResult(0);
53static TypedAttr
getIntAttr(
const APInt &value, MLIRContext *context) {
54 return IntegerAttr::get(IntegerType::get(context, value.getBitWidth()),
53static TypedAttr
getIntAttr(
const APInt &value, MLIRContext *context) {
…}
61 for (
auto op :
concat.getOperands())
63 }
else if (
auto repl = v.getDefiningOp<ReplicateOp>()) {
64 for (
size_t i = 0, e = repl.getMultiple(); i != e; ++i)
75 return op->hasAttr(
"sv.attributes");
79template <
typename SubType>
80struct ComplementMatcher {
82 ComplementMatcher(SubType lhs) : lhs(std::move(lhs)) {}
83 bool match(Operation *op) {
84 auto xorOp = dyn_cast<XorOp>(op);
85 return xorOp && xorOp.isBinaryNot() && lhs.match(op->getOperand(0));
90template <
typename SubType>
91static inline ComplementMatcher<SubType>
m_Complement(
const SubType &subExpr) {
92 return ComplementMatcher<SubType>(subExpr);
91static inline ComplementMatcher<SubType>
m_Complement(
const SubType &subExpr) {
…}
98 assert((isa<AndOp, OrOp, XorOp, AddOp, MulOp>(op) &&
99 "must be commutative operations"));
100 if (op->hasOneUse()) {
101 auto *user = *op->getUsers().begin();
102 return user->getName() == op->getName() &&
103 op->getAttrOfType<UnitAttr>(
"twoState") ==
104 user->getAttrOfType<UnitAttr>(
"twoState") &&
105 op->getBlock() == user->getBlock();
120 auto inputs = op->getOperands();
122 SmallVector<Value, 4> newOperands;
123 SmallVector<Location, 4> newLocations{op->getLoc()};
124 newOperands.reserve(inputs.size());
126 decltype(inputs.begin()) current, end;
129 SmallVector<Element> worklist;
130 worklist.push_back({inputs.begin(), inputs.end()});
131 bool binFlag = op->hasAttrOfType<UnitAttr>(
"twoState");
132 bool changed =
false;
133 while (!worklist.empty()) {
134 auto &element = worklist.back();
137 if (element.current == element.end) {
142 Value value = *element.current++;
143 auto *flattenOp = value.getDefiningOp();
146 if (!flattenOp || flattenOp->getName() != op->getName() ||
147 flattenOp == op || binFlag != op->hasAttrOfType<UnitAttr>(
"twoState") ||
148 flattenOp->getBlock() != op->getBlock()) {
149 newOperands.push_back(value);
154 if (!value.hasOneUse()) {
162 if (flattenOp->getNumOperands() != 2 || !isa<AndOp, OrOp, XorOp>(op) ||
165 newOperands.push_back(value);
173 auto flattenOpInputs = flattenOp->getOperands();
174 worklist.push_back({flattenOpInputs.begin(), flattenOpInputs.end()});
175 newLocations.push_back(flattenOp->getLoc());
181 Value result =
createGenericOp(FusedLoc::get(op->getContext(), newLocations),
182 op->getName(), newOperands, rewriter);
184 result.getDefiningOp()->setAttr(
"twoState", rewriter.getUnitAttr());
192static std::pair<size_t, size_t>
194 size_t originalOpWidth) {
195 auto users = op->getUsers();
197 "getLowestBitAndHighestBitRequired cannot operate on "
198 "a empty list of uses.");
202 size_t lowestBitRequired = narrowTrailingBits ? originalOpWidth - 1 : 0;
203 size_t highestBitRequired = 0;
205 for (
auto *user : users) {
206 if (
auto extractOp = dyn_cast<ExtractOp>(user)) {
207 size_t lowBit = extractOp.getLowBit();
209 cast<IntegerType>(extractOp.getType()).getWidth() + lowBit - 1;
210 highestBitRequired = std::max(highestBitRequired, highBit);
211 lowestBitRequired = std::min(lowestBitRequired, lowBit);
215 highestBitRequired = originalOpWidth - 1;
216 lowestBitRequired = 0;
220 return {lowestBitRequired, highestBitRequired};
225 PatternRewriter &rewriter) {
226 IntegerType opType = dyn_cast<IntegerType>(op.getResult().getType());
232 if (range.second + 1 == opType.getWidth() && range.first == 0)
235 SmallVector<Value> args;
236 auto newType = rewriter.getIntegerType(range.second - range.first + 1);
237 for (
auto inop : op.getOperands()) {
239 if (inop.getType() != op.getType())
240 args.push_back(inop);
242 args.push_back(rewriter.createOrFold<
ExtractOp>(inop.getLoc(), newType,
245 auto newop = rewriter.create<OpTy>(op.getLoc(), newType, args);
246 newop->setDialectAttrs(op->getDialectAttrs());
247 if (op.getTwoState())
248 newop.setTwoState(
true);
250 Value newResult = newop.getResult();
252 newResult = rewriter.createOrFold<
ConcatOp>(
253 op.getLoc(), newResult,
255 APInt::getZero(range.first)));
256 if (range.second + 1 < opType.getWidth())
257 newResult = rewriter.createOrFold<
ConcatOp>(
260 op.getLoc(), APInt::getZero(opType.getWidth() - range.second - 1)),
262 rewriter.replaceOp(op, newResult);
270OpFoldResult ReplicateOp::fold(FoldAdaptor adaptor) {
275 if (cast<IntegerType>(getType()).
getWidth() ==
276 getInput().getType().getIntOrFloatBitWidth())
280 if (
auto input = dyn_cast_or_null<IntegerAttr>(adaptor.getInput())) {
281 if (input.getValue().getBitWidth() == 1) {
282 if (input.getValue().isZero())
284 APInt::getZero(cast<IntegerType>(getType()).
getWidth()),
287 APInt::getAllOnes(cast<IntegerType>(getType()).
getWidth()),
291 APInt result = APInt::getZeroWidth();
292 for (
auto i = getMultiple(); i != 0; --i)
293 result = result.concat(input.getValue());
300OpFoldResult ParityOp::fold(FoldAdaptor adaptor) {
305 if (
auto input = dyn_cast_or_null<IntegerAttr>(adaptor.getInput()))
306 return getIntAttr(APInt(1, input.getValue().popcount() & 1), getContext());
318 hw::PEO paramOpcode) {
319 assert(operands.size() == 2 &&
"binary op takes two operands");
320 if (!operands[0] || !operands[1])
325 return hw::ParamExprAttr::get(paramOpcode, cast<TypedAttr>(operands[0]),
326 cast<TypedAttr>(operands[1]));
329OpFoldResult ShlOp::fold(FoldAdaptor adaptor) {
333 if (
auto rhs = dyn_cast_or_null<IntegerAttr>(adaptor.getRhs())) {
334 unsigned shift = rhs.getValue().getZExtValue();
335 unsigned width = getType().getIntOrFloatBitWidth();
337 return getOperand(0);
339 return getIntAttr(APInt::getZero(width), getContext());
345LogicalResult ShlOp::canonicalize(
ShlOp op, PatternRewriter &rewriter) {
351 if (!matchPattern(op.getRhs(), m_ConstantInt(&value)))
354 unsigned width = cast<IntegerType>(op.getLhs().getType()).getWidth();
355 unsigned shift = value.getZExtValue();
358 if (width <= shift || shift == 0)
362 rewriter.create<
hw::ConstantOp>(op.getLoc(), APInt::getZero(shift));
366 rewriter.
create<
ExtractOp>(op.getLoc(), op.getLhs(), 0, width - shift);
368 replaceOpWithNewOpAndCopyNamehint<ConcatOp>(rewriter, op, extract, zeros);
372OpFoldResult ShrUOp::fold(FoldAdaptor adaptor) {
376 if (
auto rhs = dyn_cast_or_null<IntegerAttr>(adaptor.getRhs())) {
377 unsigned shift = rhs.getValue().getZExtValue();
379 return getOperand(0);
381 unsigned width = getType().getIntOrFloatBitWidth();
383 return getIntAttr(APInt::getZero(width), getContext());
388LogicalResult ShrUOp::canonicalize(
ShrUOp op, PatternRewriter &rewriter) {
394 if (!matchPattern(op.getRhs(), m_ConstantInt(&value)))
397 unsigned width = cast<IntegerType>(op.getLhs().getType()).getWidth();
398 unsigned shift = value.getZExtValue();
401 if (width <= shift || shift == 0)
405 rewriter.create<
hw::ConstantOp>(op.getLoc(), APInt::getZero(shift));
408 auto extract = rewriter.
create<
ExtractOp>(op.getLoc(), op.getLhs(), shift,
411 replaceOpWithNewOpAndCopyNamehint<ConcatOp>(rewriter, op, zeros, extract);
415OpFoldResult ShrSOp::fold(FoldAdaptor adaptor) {
419 if (
auto rhs = dyn_cast_or_null<IntegerAttr>(adaptor.getRhs())) {
420 if (rhs.getValue().getZExtValue() == 0)
421 return getOperand(0);
426LogicalResult ShrSOp::canonicalize(
ShrSOp op, PatternRewriter &rewriter) {
432 if (!matchPattern(op.getRhs(), m_ConstantInt(&value)))
435 unsigned width = cast<IntegerType>(op.getLhs().getType()).getWidth();
436 unsigned shift = value.getZExtValue();
439 rewriter.createOrFold<
ExtractOp>(op.getLoc(), op.getLhs(), width - 1, 1);
440 auto sext = rewriter.createOrFold<ReplicateOp>(op.getLoc(), topbit, shift);
442 if (width <= shift) {
447 auto extract = rewriter.
create<
ExtractOp>(op.getLoc(), op.getLhs(), shift,
450 replaceOpWithNewOpAndCopyNamehint<ConcatOp>(rewriter, op, sext, extract);
458OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
463 if (getInput().getType() == getType())
467 if (
auto input = dyn_cast_or_null<IntegerAttr>(adaptor.getInput())) {
468 unsigned dstWidth = cast<IntegerType>(getType()).getWidth();
469 return getIntAttr(input.getValue().lshr(getLowBit()).trunc(dstWidth),
480 PatternRewriter &rewriter) {
481 auto reversedConcatArgs = llvm::reverse(innerCat.getInputs());
482 size_t beginOfFirstRelevantElement = 0;
483 auto it = reversedConcatArgs.begin();
484 size_t lowBit = op.getLowBit();
487 for (; it != reversedConcatArgs.end(); it++) {
488 assert(beginOfFirstRelevantElement <= lowBit &&
489 "incorrectly moved past an element that lowBit has coverage over");
492 size_t operandWidth = operand.getType().getIntOrFloatBitWidth();
493 if (lowBit < beginOfFirstRelevantElement + operandWidth) {
517 beginOfFirstRelevantElement += operandWidth;
519 assert(it != reversedConcatArgs.end() &&
520 "incorrectly failed to find an element which contains coverage of "
523 SmallVector<Value> reverseConcatArgs;
524 size_t widthRemaining = cast<IntegerType>(op.getType()).getWidth();
525 size_t extractLo = lowBit - beginOfFirstRelevantElement;
530 for (; widthRemaining != 0 && it != reversedConcatArgs.end(); it++) {
531 auto concatArg = *it;
532 size_t operandWidth = concatArg.getType().getIntOrFloatBitWidth();
533 size_t widthToConsume = std::min(widthRemaining, operandWidth - extractLo);
535 if (widthToConsume == operandWidth && extractLo == 0) {
536 reverseConcatArgs.push_back(concatArg);
538 auto resultType = IntegerType::get(rewriter.getContext(), widthToConsume);
539 reverseConcatArgs.push_back(
540 rewriter.create<
ExtractOp>(op.getLoc(), resultType, *it, extractLo));
543 widthRemaining -= widthToConsume;
549 if (reverseConcatArgs.size() == 1) {
552 replaceOpWithNewOpAndCopyNamehint<ConcatOp>(
553 rewriter, op, SmallVector<Value>(llvm::reverse(reverseConcatArgs)));
560 PatternRewriter &rewriter) {
561 auto extractResultWidth = cast<IntegerType>(op.getType()).getWidth();
562 auto replicateEltWidth =
563 replicate.getOperand().getType().getIntOrFloatBitWidth();
567 if (op.getLowBit() % replicateEltWidth == 0 &&
568 extractResultWidth % replicateEltWidth == 0) {
569 replaceOpWithNewOpAndCopyNamehint<ReplicateOp>(rewriter, op, op.getType(),
570 replicate.getOperand());
576 if (op.getLowBit() % replicateEltWidth + extractResultWidth <=
578 replaceOpWithNewOpAndCopyNamehint<ExtractOp>(
579 rewriter, op, op.getType(), replicate.getOperand(),
580 op.getLowBit() % replicateEltWidth);
589LogicalResult ExtractOp::canonicalize(
ExtractOp op, PatternRewriter &rewriter) {
593 auto *inputOp = op.getInput().getDefiningOp();
600 .extractBits(cast<IntegerType>(op.getType()).getWidth(),
602 if (knownBits.isConstant()) {
603 replaceOpWithNewOpAndCopyNamehint<hw::ConstantOp>(rewriter, op,
604 knownBits.getConstant());
610 if (
auto innerExtract = dyn_cast_or_null<ExtractOp>(inputOp)) {
611 replaceOpWithNewOpAndCopyNamehint<ExtractOp>(
612 rewriter, op, op.getType(), innerExtract.getInput(),
613 innerExtract.getLowBit() + op.getLowBit());
618 if (
auto innerCat = dyn_cast_or_null<ConcatOp>(inputOp))
622 if (
auto replicate = dyn_cast_or_null<ReplicateOp>(inputOp))
628 if (inputOp && inputOp->getNumOperands() == 2 &&
629 isa<AndOp, OrOp, XorOp>(inputOp)) {
630 if (
auto cstRHS = inputOp->getOperand(1).getDefiningOp<
hw::ConstantOp>()) {
631 auto extractedCst = cstRHS.getValue().extractBits(
632 cast<IntegerType>(op.getType()).getWidth(), op.getLowBit());
633 if (isa<OrOp, XorOp>(inputOp) && extractedCst.isZero()) {
634 replaceOpWithNewOpAndCopyNamehint<ExtractOp>(
635 rewriter, op, op.getType(), inputOp->getOperand(0), op.getLowBit());
643 if (isa<AndOp>(inputOp)) {
646 unsigned lz = extractedCst.countLeadingZeros();
647 unsigned tz = extractedCst.countTrailingZeros();
648 unsigned pop = extractedCst.popcount();
649 if (extractedCst.getBitWidth() - lz - tz == pop) {
650 auto resultTy = rewriter.getIntegerType(pop);
651 SmallVector<Value> resultElts;
654 op.getLoc(), APInt::getZero(lz)));
655 resultElts.push_back(rewriter.createOrFold<
ExtractOp>(
656 op.getLoc(), resultTy, inputOp->getOperand(0),
657 op.getLowBit() + tz));
660 op.getLoc(), APInt::getZero(tz)));
661 replaceOpWithNewOpAndCopyNamehint<ConcatOp>(rewriter, op, resultElts);
670 if (cast<IntegerType>(op.getType()).getWidth() == 1 && inputOp)
671 if (
auto shlOp = dyn_cast<ShlOp>(inputOp)) {
673 if (shlOp->hasOneUse())
675 if (lhsCst.getValue().isOne()) {
678 APInt(lhsCst.getValue().getBitWidth(), op.getLowBit()));
679 replaceOpWithNewOpAndCopyNamehint<ICmpOp>(
680 rewriter, op, ICmpPredicate::eq, shlOp->getOperand(1), newCst,
696 hw::PEO paramOpcode) {
697 assert(operands.size() > 1 &&
"caller should handle one-operand case");
700 if (!operands[1] || !operands[0])
704 if (llvm::all_of(operands.drop_front(2),
705 [&](Attribute in) { return !!in; })) {
706 SmallVector<mlir::TypedAttr> typedOperands;
707 typedOperands.reserve(operands.size());
708 for (
auto operand : operands) {
709 if (
auto typedOperand = dyn_cast<mlir::TypedAttr>(operand))
710 typedOperands.push_back(typedOperand);
714 if (typedOperands.size() == operands.size())
715 return hw::ParamExprAttr::get(paramOpcode, typedOperands);
731 size_t concatIdx,
const APInt &cst,
732 PatternRewriter &rewriter) {
733 auto concatOp = logicalOp->getOperand(concatIdx).getDefiningOp<
ConcatOp>();
734 assert((isa<AndOp, OrOp, XorOp>(logicalOp) && concatOp));
739 llvm::any_of(concatOp->getOperands(), [&](Value operand) ->
bool {
740 auto *operandOp = operand.getDefiningOp();
745 if (isa<hw::ConstantOp>(operandOp))
749 return operandOp->getName() == logicalOp->getName() &&
750 operandOp->hasOneUse() && operandOp->getNumOperands() != 0 &&
751 operandOp->getOperands().back().getDefiningOp<hw::ConstantOp>();
759 auto createLogicalOp = [&](ArrayRef<Value> operands) -> Value {
760 return createGenericOp(logicalOp->getLoc(), logicalOp->getName(), operands,
767 SmallVector<Value> newConcatOperands;
768 newConcatOperands.reserve(concatOp->getNumOperands());
771 size_t nextOperandBit = concatOp.getType().getIntOrFloatBitWidth();
772 for (Value operand : concatOp->getOperands()) {
773 size_t operandWidth = operand.getType().getIntOrFloatBitWidth();
774 nextOperandBit -= operandWidth;
777 logicalOp->getLoc(), cst.lshr(nextOperandBit).trunc(operandWidth));
779 newConcatOperands.push_back(createLogicalOp({operand, eltCst}));
788 if (logicalOp->getNumOperands() > 2) {
789 auto origOperands = logicalOp->getOperands();
790 SmallVector<Value> operands;
792 operands.append(origOperands.begin(), origOperands.begin() + concatIdx);
794 operands.append(origOperands.begin() + concatIdx + 1,
795 origOperands.begin() + (origOperands.size() - 1));
797 operands.push_back(newResult);
798 newResult = createLogicalOp(operands);
808 llvm::SmallDenseSet<std::tuple<ICmpPredicate, Value, Value>> seenPredicates;
810 for (
auto op : operands) {
811 if (
auto icmpOp = op.getDefiningOp<ICmpOp>();
812 icmpOp && icmpOp.getTwoState()) {
813 auto predicate = icmpOp.getPredicate();
814 auto lhs = icmpOp.getLhs();
815 auto rhs = icmpOp.getRhs();
816 if (seenPredicates.contains(
817 {ICmpOp::getNegatedPredicate(predicate), lhs, rhs}))
820 seenPredicates.insert({predicate, lhs, rhs});
826OpFoldResult AndOp::fold(FoldAdaptor adaptor) {
830 APInt value = APInt::getAllOnes(cast<IntegerType>(getType()).
getWidth());
832 auto inputs = adaptor.getInputs();
835 for (
auto operand : inputs) {
838 value &= cast<IntegerAttr>(operand).getValue();
844 if (inputs.size() == 2 && inputs[1] &&
845 cast<IntegerAttr>(inputs[1]).getValue().isAllOnes())
846 return getInputs()[0];
849 if (llvm::all_of(getInputs(),
850 [&](
auto in) {
return in == this->getInputs()[0]; }))
851 return getInputs()[0];
854 for (Value arg : getInputs()) {
857 for (Value arg2 : getInputs())
860 APInt::getZero(cast<IntegerType>(getType()).
getWidth()),
881template <
typename Op>
883 if (!op.getType().isInteger(1))
886 auto inputs = op.getInputs();
887 size_t size = inputs.size();
889 auto sourceOp = inputs[0].template getDefiningOp<ExtractOp>();
892 Value source = sourceOp.getOperand();
895 if (size != source.getType().getIntOrFloatBitWidth())
899 llvm::BitVector bits(size);
900 bits.set(sourceOp.getLowBit());
902 for (
size_t i = 1; i != size; ++i) {
903 auto extractOp = inputs[i].template getDefiningOp<ExtractOp>();
904 if (!extractOp || extractOp.getOperand() != source)
906 bits.set(extractOp.getLowBit());
909 return bits.all() ? source : Value();
916template <
typename Op>
919 constexpr unsigned limit = 3;
920 auto inputs = op.getInputs();
922 llvm::SmallSetVector<Value, 8> uniqueInputs(inputs.begin(), inputs.end());
923 llvm::SmallDenseSet<Op, 8> checked;
930 llvm::SmallVector<OpWithDepth, 8> worklist;
932 auto enqueue = [&worklist, &checked, &op](Value input,
unsigned depth) {
936 if (depth < limit && input.getParentBlock() == op->getBlock()) {
937 auto inputOp = input.template getDefiningOp<Op>();
938 if (inputOp && inputOp.getTwoState() == op.getTwoState() &&
939 checked.insert(inputOp).second)
940 worklist.push_back({inputOp, depth + 1});
944 for (
auto input : uniqueInputs)
947 while (!worklist.empty()) {
948 auto item = worklist.pop_back_val();
950 for (
auto input : item.op.getInputs()) {
951 uniqueInputs.remove(input);
952 enqueue(input, item.depth);
956 if (uniqueInputs.size() < inputs.size()) {
957 replaceOpWithNewOpAndCopyNamehint<Op>(rewriter, op, op.getType(),
958 uniqueInputs.getArrayRef(),
966LogicalResult AndOp::canonicalize(
AndOp op, PatternRewriter &rewriter) {
967 auto inputs = op.getInputs();
968 auto size = inputs.size();
982 assert(size > 1 &&
"expected 2 or more operands, `fold` should handle this");
986 if (matchPattern(inputs.back(), m_ConstantInt(&value))) {
988 if (value.isAllOnes()) {
989 replaceOpWithNewOpAndCopyNamehint<AndOp>(rewriter, op, op.getType(),
990 inputs.drop_back(),
false);
998 if (matchPattern(inputs[size - 2], m_ConstantInt(&value2))) {
999 auto cst = rewriter.create<
hw::ConstantOp>(op.getLoc(), value & value2);
1000 SmallVector<Value, 4> newOperands(inputs.drop_back(2));
1001 newOperands.push_back(cst);
1002 replaceOpWithNewOpAndCopyNamehint<AndOp>(rewriter, op, op.getType(),
1003 newOperands,
false);
1008 if (size == 2 && value.isPowerOf2()) {
1013 if (
auto replicate = inputs[0].getDefiningOp<ReplicateOp>()) {
1014 auto replicateOperand = replicate.getOperand();
1015 if (replicateOperand.getType().isInteger(1)) {
1016 unsigned resultWidth = op.getType().getIntOrFloatBitWidth();
1017 auto trailingZeros = value.countTrailingZeros();
1020 SmallVector<Value, 3> concatOperands;
1021 if (trailingZeros != resultWidth - 1) {
1023 op.getLoc(), APInt::getZero(resultWidth - trailingZeros - 1));
1024 concatOperands.push_back(highZeros);
1026 concatOperands.push_back(replicateOperand);
1027 if (trailingZeros != 0) {
1029 op.getLoc(), APInt::getZero(trailingZeros));
1030 concatOperands.push_back(lowZeros);
1032 replaceOpWithNewOpAndCopyNamehint<ConcatOp>(
1033 rewriter, op, op.getType(), concatOperands);
1040 if (
auto extractOp = inputs[0].getDefiningOp<ExtractOp>()) {
1043 (value.countLeadingZeros() || value.countTrailingZeros())) {
1044 unsigned lz = value.countLeadingZeros();
1045 unsigned tz = value.countTrailingZeros();
1048 auto smallTy = rewriter.getIntegerType(value.getBitWidth() - lz - tz);
1049 Value smallElt = rewriter.createOrFold<
ExtractOp>(
1050 extractOp.getLoc(), smallTy, extractOp->getOperand(0),
1051 extractOp.getLowBit() + tz);
1053 APInt smallMask = value.extractBits(smallTy.getWidth(), tz);
1054 if (!smallMask.isAllOnes()) {
1055 auto loc = inputs.back().getLoc();
1056 smallElt = rewriter.createOrFold<
AndOp>(
1063 SmallVector<Value> resultElts;
1065 resultElts.push_back(
1066 rewriter.create<
hw::ConstantOp>(op.getLoc(), APInt::getZero(lz)));
1067 resultElts.push_back(smallElt);
1069 resultElts.push_back(
1070 rewriter.create<
hw::ConstantOp>(op.getLoc(), APInt::getZero(tz)));
1071 replaceOpWithNewOpAndCopyNamehint<ConcatOp>(rewriter, op, resultElts);
1079 for (
size_t i = 0; i < size - 1; ++i) {
1080 if (
auto concat = inputs[i].getDefiningOp<ConcatOp>())
1093 rewriter.create<
hw::ConstantOp>(op.getLoc(), APInt::getAllOnes(size));
1094 replaceOpWithNewOpAndCopyNamehint<ICmpOp>(rewriter, op, ICmpPredicate::eq,
1095 source, cmpAgainst);
1103OpFoldResult OrOp::fold(FoldAdaptor adaptor) {
1107 auto value = APInt::getZero(cast<IntegerType>(getType()).
getWidth());
1108 auto inputs = adaptor.getInputs();
1110 for (
auto operand : inputs) {
1113 value |= cast<IntegerAttr>(operand).getValue();
1114 if (value.isAllOnes())
1119 if (inputs.size() == 2 && inputs[1] &&
1120 cast<IntegerAttr>(inputs[1]).getValue().isZero())
1121 return getInputs()[0];
1124 if (llvm::all_of(getInputs(),
1125 [&](
auto in) {
return in == this->getInputs()[0]; }))
1126 return getInputs()[0];
1129 for (Value arg : getInputs()) {
1131 if (matchPattern(arg,
m_Complement(m_Any(&subExpr)))) {
1132 for (Value arg2 : getInputs())
1133 if (arg2 == subExpr)
1135 APInt::getAllOnes(cast<IntegerType>(getType()).
getWidth()),
1145 APInt::getAllOnes(cast<IntegerType>(getType()).
getWidth()),
1152LogicalResult OrOp::canonicalize(
OrOp op, PatternRewriter &rewriter) {
1153 auto inputs = op.getInputs();
1154 auto size = inputs.size();
1168 assert(size > 1 &&
"expected 2 or more operands");
1172 if (matchPattern(inputs.back(), m_ConstantInt(&value))) {
1174 if (value.isZero()) {
1175 replaceOpWithNewOpAndCopyNamehint<OrOp>(rewriter, op, op.getType(),
1176 inputs.drop_back());
1182 if (matchPattern(inputs[size - 2], m_ConstantInt(&value2))) {
1183 auto cst = rewriter.create<
hw::ConstantOp>(op.getLoc(), value | value2);
1184 SmallVector<Value, 4> newOperands(inputs.drop_back(2));
1185 newOperands.push_back(cst);
1186 replaceOpWithNewOpAndCopyNamehint<OrOp>(rewriter, op, op.getType(),
1194 for (
size_t i = 0; i < size - 1; ++i) {
1195 if (
auto concat = inputs[i].getDefiningOp<ConcatOp>())
1208 rewriter.create<
hw::ConstantOp>(op.getLoc(), APInt::getZero(size));
1209 replaceOpWithNewOpAndCopyNamehint<ICmpOp>(rewriter, op, ICmpPredicate::ne,
1210 source, cmpAgainst);
1216 if (
auto firstMux = op.getOperand(0).getDefiningOp<
comb::MuxOp>()) {
1218 if (op.getTwoState() && firstMux.getTwoState() &&
1219 matchPattern(firstMux.getFalseValue(), m_ConstantInt(&value)) &&
1221 SmallVector<Value> conditions{firstMux.getCond()};
1222 auto check = [&](Value v) {
1226 conditions.push_back(mux.getCond());
1227 return mux.getTwoState() &&
1228 firstMux.getTrueValue() == mux.getTrueValue() &&
1229 firstMux.getFalseValue() == mux.getFalseValue();
1231 if (llvm::all_of(op.getOperands().drop_front(), check)) {
1232 auto cond = rewriter.create<
comb::OrOp>(op.getLoc(), conditions,
true);
1233 replaceOpWithNewOpAndCopyNamehint<comb::MuxOp>(
1234 rewriter, op, cond, firstMux.getTrueValue(),
1235 firstMux.getFalseValue(),
true);
1245OpFoldResult XorOp::fold(FoldAdaptor adaptor) {
1249 auto size = getInputs().size();
1250 auto inputs = adaptor.getInputs();
1254 return getInputs()[0];
1257 if (size == 2 && getInputs()[0] == getInputs()[1])
1258 return IntegerAttr::get(getType(), 0);
1261 if (inputs.size() == 2 && inputs[1] &&
1262 cast<IntegerAttr>(inputs[1]).getValue().isZero())
1263 return getInputs()[0];
1267 if (isBinaryNot()) {
1269 if (matchPattern(getOperand(0),
m_Complement(m_Any(&subExpr))) &&
1270 subExpr != getResult())
1280 PatternRewriter &rewriter) {
1281 auto icmp = op.getOperand(icmpOperand).getDefiningOp<ICmpOp>();
1282 auto negatedPred = ICmpOp::getNegatedPredicate(icmp.getPredicate());
1285 rewriter.create<ICmpOp>(icmp.getLoc(), negatedPred, icmp.getOperand(0),
1286 icmp.getOperand(1), icmp.getTwoState());
1289 if (op.getNumOperands() > 2) {
1290 SmallVector<Value, 4> newOperands(op.getOperands());
1291 newOperands.pop_back();
1292 newOperands.erase(newOperands.begin() + icmpOperand);
1293 newOperands.push_back(result);
1294 result = rewriter.create<
XorOp>(op.getLoc(), newOperands, op.getTwoState());
1300LogicalResult XorOp::canonicalize(
XorOp op, PatternRewriter &rewriter) {
1304 auto inputs = op.getInputs();
1305 auto size = inputs.size();
1306 assert(size > 1 &&
"expected 2 or more operands");
1309 if (inputs[size - 1] == inputs[size - 2]) {
1311 "expected idempotent case for 2 elements handled already.");
1312 replaceOpWithNewOpAndCopyNamehint<XorOp>(rewriter, op, op.getType(),
1313 inputs.drop_back(2),
false);
1319 if (matchPattern(inputs.back(), m_ConstantInt(&value))) {
1321 if (value.isZero()) {
1322 replaceOpWithNewOpAndCopyNamehint<XorOp>(rewriter, op, op.getType(),
1323 inputs.drop_back(),
false);
1329 if (matchPattern(inputs[size - 2], m_ConstantInt(&value2))) {
1330 auto cst = rewriter.create<
hw::ConstantOp>(op.getLoc(), value ^ value2);
1331 SmallVector<Value, 4> newOperands(inputs.drop_back(2));
1332 newOperands.push_back(cst);
1333 replaceOpWithNewOpAndCopyNamehint<XorOp>(rewriter, op, op.getType(),
1334 newOperands,
false);
1338 bool isSingleBit = value.getBitWidth() == 1;
1341 for (
size_t i = 0; i < size - 1; ++i) {
1342 Value operand = inputs[i];
1353 if (isSingleBit && operand.hasOneUse()) {
1354 assert(value == 1 &&
"single bit constant has to be one if not zero");
1355 if (
auto icmp = operand.getDefiningOp<ICmpOp>())
1371 replaceOpWithNewOpAndCopyNamehint<ParityOp>(rewriter, op, source);
1378OpFoldResult SubOp::fold(FoldAdaptor adaptor) {
1383 if (getRhs() == getLhs())
1385 APInt::getZero(getLhs().getType().getIntOrFloatBitWidth()),
1388 if (adaptor.getRhs()) {
1390 if (adaptor.getLhs()) {
1393 APInt::getAllOnes(getLhs().getType().getIntOrFloatBitWidth()),
1395 auto rhsNeg = hw::ParamExprAttr::get(
1396 hw::PEO::Mul, cast<TypedAttr>(adaptor.getRhs()), negOne);
1397 return hw::ParamExprAttr::get(hw::PEO::Add,
1398 cast<TypedAttr>(adaptor.getLhs()), rhsNeg);
1402 if (
auto rhsC = dyn_cast<IntegerAttr>(adaptor.getRhs())) {
1403 if (rhsC.getValue().isZero())
1411LogicalResult SubOp::canonicalize(
SubOp op, PatternRewriter &rewriter) {
1417 if (matchPattern(op.getRhs(), m_ConstantInt(&value))) {
1418 auto negCst = rewriter.create<
hw::ConstantOp>(op.getLoc(), -value);
1419 replaceOpWithNewOpAndCopyNamehint<AddOp>(rewriter, op, op.getLhs(), negCst,
1431OpFoldResult AddOp::fold(FoldAdaptor adaptor) {
1435 auto size = getInputs().size();
1439 return getInputs()[0];
1445LogicalResult AddOp::canonicalize(
AddOp op, PatternRewriter &rewriter) {
1449 auto inputs = op.getInputs();
1450 auto size = inputs.size();
1451 assert(size > 1 &&
"expected 2 or more operands");
1453 APInt value, value2;
1456 if (matchPattern(inputs.back(), m_ConstantInt(&value)) && value.isZero()) {
1457 replaceOpWithNewOpAndCopyNamehint<AddOp>(rewriter, op, op.getType(),
1458 inputs.drop_back(),
false);
1463 if (matchPattern(inputs[size - 1], m_ConstantInt(&value)) &&
1464 matchPattern(inputs[size - 2], m_ConstantInt(&value2))) {
1465 auto cst = rewriter.create<
hw::ConstantOp>(op.getLoc(), value + value2);
1466 SmallVector<Value, 4> newOperands(inputs.drop_back(2));
1467 newOperands.push_back(cst);
1468 replaceOpWithNewOpAndCopyNamehint<AddOp>(rewriter, op, op.getType(),
1469 newOperands,
false);
1474 if (inputs[size - 1] == inputs[size - 2]) {
1475 SmallVector<Value, 4> newOperands(inputs.drop_back(2));
1477 auto one = rewriter.create<
hw::ConstantOp>(op.getLoc(), op.getType(), 1);
1481 newOperands.push_back(shiftLeftOp);
1482 replaceOpWithNewOpAndCopyNamehint<AddOp>(rewriter, op, op.getType(),
1483 newOperands,
false);
1487 auto shlOp = inputs[size - 1].getDefiningOp<
comb::ShlOp>();
1489 if (shlOp && shlOp.getLhs() == inputs[size - 2] &&
1490 matchPattern(shlOp.getRhs(), m_ConstantInt(&value))) {
1492 APInt one(value.getBitWidth(), 1,
false);
1494 rewriter.create<
hw::ConstantOp>(op.getLoc(), (one << value) + one);
1496 std::array<Value, 2> factors = {shlOp.getLhs(), rhs};
1497 auto mulOp = rewriter.create<
comb::MulOp>(op.getLoc(), factors,
false);
1499 SmallVector<Value, 4> newOperands(inputs.drop_back(2));
1500 newOperands.push_back(mulOp);
1501 replaceOpWithNewOpAndCopyNamehint<AddOp>(rewriter, op, op.getType(),
1502 newOperands,
false);
1506 auto mulOp = inputs[size - 1].getDefiningOp<
comb::MulOp>();
1508 if (mulOp && mulOp.getInputs().size() == 2 &&
1509 mulOp.getInputs()[0] == inputs[size - 2] &&
1510 matchPattern(mulOp.getInputs()[1], m_ConstantInt(&value))) {
1512 APInt one(value.getBitWidth(), 1,
false);
1513 auto rhs = rewriter.create<
hw::ConstantOp>(op.getLoc(), value + one);
1514 std::array<Value, 2> factors = {mulOp.getInputs()[0], rhs};
1517 SmallVector<Value, 4> newOperands(inputs.drop_back(2));
1518 newOperands.push_back(newMulOp);
1519 replaceOpWithNewOpAndCopyNamehint<AddOp>(rewriter, op, op.getType(),
1520 newOperands,
false);
1533 auto addOp = inputs[0].getDefiningOp<
comb::AddOp>();
1534 if (addOp && addOp.getInputs().size() == 2 &&
1535 matchPattern(addOp.getInputs()[1], m_ConstantInt(&value2)) &&
1536 inputs.size() == 2 && matchPattern(inputs[1], m_ConstantInt(&value))) {
1538 auto rhs = rewriter.create<
hw::ConstantOp>(op.getLoc(), value + value2);
1539 replaceOpWithNewOpAndCopyNamehint<AddOp>(
1540 rewriter, op, op.getType(), ArrayRef<Value>{addOp.getInputs()[0], rhs},
1541 op.getTwoState() && addOp.getTwoState());
1548OpFoldResult MulOp::fold(FoldAdaptor adaptor) {
1552 auto size = getInputs().size();
1553 auto inputs = adaptor.getInputs();
1557 return getInputs()[0];
1559 auto width = cast<IntegerType>(getType()).getWidth();
1561 return getIntAttr(APInt::getZero(0), getContext());
1563 APInt value(width, 1,
false);
1566 for (
auto operand : inputs) {
1569 value *= cast<IntegerAttr>(operand).getValue();
1578LogicalResult MulOp::canonicalize(
MulOp op, PatternRewriter &rewriter) {
1582 auto inputs = op.getInputs();
1583 auto size = inputs.size();
1584 assert(size > 1 &&
"expected 2 or more operands");
1586 APInt value, value2;
1589 if (size == 2 && matchPattern(inputs.back(), m_ConstantInt(&value)) &&
1590 value.isPowerOf2()) {
1591 auto shift = rewriter.create<
hw::ConstantOp>(op.getLoc(), op.getType(),
1592 value.exactLogBase2());
1596 replaceOpWithNewOpAndCopyNamehint<MulOp>(rewriter, op, op.getType(),
1597 ArrayRef<Value>(shlOp),
false);
1602 if (matchPattern(inputs.back(), m_ConstantInt(&value)) && value.isOne()) {
1603 replaceOpWithNewOpAndCopyNamehint<MulOp>(rewriter, op, op.getType(),
1604 inputs.drop_back());
1609 if (matchPattern(inputs[size - 1], m_ConstantInt(&value)) &&
1610 matchPattern(inputs[size - 2], m_ConstantInt(&value2))) {
1611 auto cst = rewriter.create<
hw::ConstantOp>(op.getLoc(), value * value2);
1612 SmallVector<Value, 4> newOperands(inputs.drop_back(2));
1613 newOperands.push_back(cst);
1614 replaceOpWithNewOpAndCopyNamehint<MulOp>(rewriter, op, op.getType(),
1630template <
class Op,
bool isSigned>
1631static OpFoldResult
foldDiv(Op op, ArrayRef<Attribute> constants) {
1632 if (
auto rhsValue = dyn_cast_or_null<IntegerAttr>(constants[1])) {
1634 if (rhsValue.getValue() == 1)
1638 if (rhsValue.getValue().isZero())
1631static OpFoldResult
foldDiv(Op op, ArrayRef<Attribute> constants) {
…}
1645OpFoldResult DivUOp::fold(FoldAdaptor adaptor) {
1649 return foldDiv<
DivUOp,
false>(*
this, adaptor.getOperands());
1652OpFoldResult DivSOp::fold(FoldAdaptor adaptor) {
1659template <
class Op,
bool isSigned>
1660static OpFoldResult
foldMod(Op op, ArrayRef<Attribute> constants) {
1661 if (
auto rhsValue = dyn_cast_or_null<IntegerAttr>(constants[1])) {
1663 if (rhsValue.getValue() == 1)
1664 return getIntAttr(APInt::getZero(op.getType().getIntOrFloatBitWidth()),
1668 if (rhsValue.getValue().isZero())
1672 if (
auto lhsValue = dyn_cast_or_null<IntegerAttr>(constants[0])) {
1674 if (lhsValue.getValue().isZero())
1675 return getIntAttr(APInt::getZero(op.getType().getIntOrFloatBitWidth()),
1660static OpFoldResult
foldMod(Op op, ArrayRef<Attribute> constants) {
…}
1682OpFoldResult ModUOp::fold(FoldAdaptor adaptor) {
1686 return foldMod<
ModUOp,
false>(*
this, adaptor.getOperands());
1689OpFoldResult ModSOp::fold(FoldAdaptor adaptor) {
1700OpFoldResult ConcatOp::fold(FoldAdaptor adaptor) {
1704 if (getNumOperands() == 1)
1705 return getOperand(0);
1708 for (
auto attr : adaptor.getInputs())
1709 if (!attr || !isa<IntegerAttr>(attr))
1713 unsigned resultWidth = getType().getIntOrFloatBitWidth();
1714 APInt result(resultWidth, 0);
1716 unsigned nextInsertion = resultWidth;
1718 for (
auto attr : adaptor.getInputs()) {
1719 auto chunk = cast<IntegerAttr>(attr).getValue();
1720 nextInsertion -= chunk.getBitWidth();
1721 result.insertBits(chunk, nextInsertion);
1727LogicalResult ConcatOp::canonicalize(
ConcatOp op, PatternRewriter &rewriter) {
1731 auto inputs = op.getInputs();
1732 auto size = inputs.size();
1733 assert(size > 1 &&
"expected 2 or more operands");
1738 auto flattenConcat = [&](
size_t firstOpIndex,
size_t lastOpIndex,
1739 ValueRange replacements) -> LogicalResult {
1740 SmallVector<Value, 4> newOperands;
1741 newOperands.append(inputs.begin(), inputs.begin() + firstOpIndex);
1742 newOperands.append(replacements.begin(), replacements.end());
1743 newOperands.append(inputs.begin() + lastOpIndex + 1, inputs.end());
1744 if (newOperands.size() == 1)
1747 replaceOpWithNewOpAndCopyNamehint<ConcatOp>(rewriter, op, op.getType(),
1752 Value commonOperand = inputs[0];
1753 for (
size_t i = 0; i != size; ++i) {
1755 if (inputs[i] != commonOperand)
1756 commonOperand = Value();
1760 if (
auto subConcat = inputs[i].getDefiningOp<ConcatOp>())
1761 return flattenConcat(i, i, subConcat->getOperands());
1766 if (
auto cst = inputs[i].getDefiningOp<hw::ConstantOp>()) {
1767 if (
auto prevCst = inputs[i - 1].getDefiningOp<hw::ConstantOp>()) {
1768 unsigned prevWidth = prevCst.getValue().getBitWidth();
1769 unsigned thisWidth = cst.getValue().getBitWidth();
1770 auto resultCst = cst.getValue().zext(prevWidth + thisWidth);
1771 resultCst |= prevCst.getValue().zext(prevWidth + thisWidth)
1775 return flattenConcat(i - 1, i, replacement);
1780 if (inputs[i] == inputs[i - 1]) {
1782 rewriter.createOrFold<ReplicateOp>(op.getLoc(), inputs[i], 2);
1783 return flattenConcat(i - 1, i, replacement);
1788 if (
auto repl = inputs[i].getDefiningOp<ReplicateOp>()) {
1790 if (repl.getOperand() == inputs[i - 1]) {
1791 Value replacement = rewriter.createOrFold<ReplicateOp>(
1792 op.getLoc(), repl.getOperand(), repl.getMultiple() + 1);
1793 return flattenConcat(i - 1, i, replacement);
1796 if (
auto prevRepl = inputs[i - 1].getDefiningOp<ReplicateOp>()) {
1797 if (prevRepl.getOperand() == repl.getOperand()) {
1798 Value replacement = rewriter.createOrFold<ReplicateOp>(
1799 op.getLoc(), repl.getOperand(),
1800 repl.getMultiple() + prevRepl.getMultiple());
1801 return flattenConcat(i - 1, i, replacement);
1807 if (
auto repl = inputs[i - 1].getDefiningOp<ReplicateOp>()) {
1808 if (repl.getOperand() == inputs[i]) {
1809 Value replacement = rewriter.createOrFold<ReplicateOp>(
1810 op.getLoc(), inputs[i], repl.getMultiple() + 1);
1811 return flattenConcat(i - 1, i, replacement);
1817 if (
auto extract = inputs[i].getDefiningOp<ExtractOp>()) {
1818 if (
auto prevExtract = inputs[i - 1].getDefiningOp<ExtractOp>()) {
1819 if (extract.getInput() == prevExtract.getInput()) {
1820 auto thisWidth = cast<IntegerType>(extract.getType()).getWidth();
1821 if (prevExtract.getLowBit() == extract.getLowBit() + thisWidth) {
1822 auto prevWidth = prevExtract.getType().getIntOrFloatBitWidth();
1823 auto resType = rewriter.getIntegerType(thisWidth + prevWidth);
1824 Value replacement = rewriter.create<
ExtractOp>(
1825 op.getLoc(), resType, extract.getInput(),
1826 extract.getLowBit());
1827 return flattenConcat(i - 1, i, replacement);
1840 static std::optional<ArraySlice>
get(Value value) {
1841 assert(isa<IntegerType>(value.getType()) &&
"expected integer type");
1843 return ArraySlice{arrayGet.getInput(), arrayGet.getIndex(), 1};
1846 if (
auto arraySlice =
1849 arraySlice.getInput(), arraySlice.getLowIndex(),
1850 hw::type_cast<hw::ArrayType>(arraySlice.getType())
1852 return std::nullopt;
1855 if (
auto extractOpt = ArraySlice::get(inputs[i])) {
1856 if (
auto prevExtractOpt = ArraySlice::get(inputs[i - 1])) {
1858 if (prevExtractOpt->index.getType() == extractOpt->index.getType() &&
1859 prevExtractOpt->input == extractOpt->input &&
1860 hw::isOffset(extractOpt->index, prevExtractOpt->index,
1861 extractOpt->width)) {
1862 auto resType = hw::ArrayType::get(
1863 hw::type_cast<hw::ArrayType>(prevExtractOpt->input.getType())
1865 extractOpt->width + prevExtractOpt->width);
1866 auto resIntType = rewriter.getIntegerType(hw::getBitWidth(resType));
1868 op.getLoc(), resIntType,
1870 prevExtractOpt->input,
1871 extractOpt->index));
1872 return flattenConcat(i - 1, i, replacement);
1880 if (commonOperand) {
1881 replaceOpWithNewOpAndCopyNamehint<ReplicateOp>(rewriter, op, op.getType(),
1893OpFoldResult MuxOp::fold(FoldAdaptor adaptor) {
1898 if (getTrueValue() == getFalseValue() && getTrueValue() != getResult())
1899 return getTrueValue();
1900 if (
auto tv = adaptor.getTrueValue())
1901 if (tv == adaptor.getFalseValue())
1906 if (
auto pred = dyn_cast_or_null<IntegerAttr>(adaptor.getCond())) {
1907 if (pred.getValue().isZero())
1908 return getFalseValue();
1909 return getTrueValue();
1913 if (
auto tv = dyn_cast_or_null<IntegerAttr>(adaptor.getTrueValue()))
1914 if (
auto fv = dyn_cast_or_null<IntegerAttr>(adaptor.getFalseValue()))
1915 if (tv.getValue().isOne() && fv.getValue().isZero() &&
1916 hw::getBitWidth(getType()) == 1)
1932 if (
auto cmp = cond.getDefiningOp<ICmpOp>()) {
1934 auto requiredPredicate =
1935 (isInverted ? ICmpPredicate::eq : ICmpPredicate::ne);
1936 if (cmp.getLhs() == indexValue && cmp.getPredicate() == requiredPredicate) {
1946 if (
auto orOp = cond.getDefiningOp<
OrOp>()) {
1949 for (
auto operand : orOp.getOperands())
1956 if (
auto andOp = cond.getDefiningOp<
AndOp>()) {
1959 for (
auto operand : andOp.getOperands())
1977 PatternRewriter &rewriter) {
1980 auto rootCmp = rootMux.getCond().getDefiningOp<ICmpOp>();
1983 Value indexValue = rootCmp.getLhs();
1986 auto getCaseValue = [&](
MuxOp mux) -> Value {
1987 return mux.getOperand(1 +
unsigned(!isFalseSide));
1992 auto getTreeValue = [&](
MuxOp mux) -> Value {
1993 return mux.getOperand(1 +
unsigned(isFalseSide));
1998 SmallVector<Location> locationsFound;
1999 SmallVector<std::pair<hw::ConstantOp, Value>, 4> valuesFound;
2003 auto collectConstantValues = [&](
MuxOp mux) ->
bool {
2005 mux.getCond(), indexValue, isFalseSide, [&](
hw::ConstantOp cst) {
2006 valuesFound.push_back({cst, getCaseValue(mux)});
2007 locationsFound.push_back(mux.getCond().getLoc());
2008 locationsFound.push_back(mux->getLoc());
2013 if (!collectConstantValues(rootMux))
2017 if (rootMux->hasOneUse()) {
2018 if (
auto userMux = dyn_cast<MuxOp>(*rootMux->user_begin())) {
2019 if (getTreeValue(userMux) == rootMux.getResult() &&
2027 auto nextTreeValue = getTreeValue(rootMux);
2029 auto nextMux = nextTreeValue.getDefiningOp<
MuxOp>();
2030 if (!nextMux || !nextMux->hasOneUse())
2032 if (!collectConstantValues(nextMux))
2034 nextTreeValue = getTreeValue(nextMux);
2040 if (valuesFound.size() < 3)
2045 auto indexWidth = cast<IntegerType>(indexValue.getType()).getWidth();
2046 if (indexWidth >= 9)
2052 uint64_t tableSize = 1ULL << indexWidth;
2053 if (valuesFound.size() < (tableSize * 5) / 8)
2058 SmallVector<Value, 8> table(tableSize, nextTreeValue);
2063 for (
auto &elt :
llvm::reverse(valuesFound)) {
2064 uint64_t idx = elt.first.getValue().getZExtValue();
2065 assert(idx < table.size() &&
"constant should be same bitwidth as index");
2066 table[idx] = elt.second;
2071 std::reverse(table.begin(), table.end());
2074 auto fusedLoc = rewriter.getFusedLoc(locationsFound);
2076 replaceOpWithNewOpAndCopyNamehint<hw::ArrayGetOp>(rewriter, rootMux, array,
2091 PatternRewriter &rewriter) {
2092 assert(fullyAssoc->getNumOperands() >= 2 &&
"cannot split up unary ops");
2093 assert(operandNo < fullyAssoc->getNumOperands() &&
"Invalid operand #");
2097 if (fullyAssoc->getNumOperands() == 2)
2098 return fullyAssoc->getOperand(operandNo ^ 1);
2101 if (fullyAssoc->hasOneUse()) {
2102 rewriter.modifyOpInPlace(fullyAssoc,
2103 [&]() { fullyAssoc->eraseOperand(operandNo); });
2104 return fullyAssoc->getResult(0);
2108 SmallVector<Value> operands;
2109 operands.append(fullyAssoc->getOperands().begin(),
2110 fullyAssoc->getOperands().begin() + operandNo);
2111 operands.append(fullyAssoc->getOperands().begin() + operandNo + 1,
2112 fullyAssoc->getOperands().end());
2114 fullyAssoc->getLoc(), fullyAssoc->getName(), operands, rewriter);
2115 Value excluded = fullyAssoc->getOperand(operandNo);
2119 ArrayRef<Value>{opWithoutExcluded, excluded}, rewriter);
2121 return opWithoutExcluded;
2131 PatternRewriter &rewriter) {
2134 Operation *subExpr =
2135 (isTrueOperand ? op.getFalseValue() : op.getTrueValue()).getDefiningOp();
2136 if (!subExpr || subExpr->getNumOperands() < 2)
2140 if (!isa<AndOp, XorOp, OrOp, MuxOp>(subExpr))
2145 Value commonValue = isTrueOperand ? op.getTrueValue() : op.getFalseValue();
2146 size_t opNo = 0, e = subExpr->getNumOperands();
2147 while (opNo != e && subExpr->getOperand(opNo) != commonValue)
2153 Value cond = op.getCond();
2159 if (
auto subMux = dyn_cast<MuxOp>(subExpr)) {
2164 Value subCond = subMux.getCond();
2167 if (subMux.getTrueValue() == commonValue)
2168 otherValue = subMux.getFalseValue();
2169 else if (subMux.getFalseValue() == commonValue) {
2170 otherValue = subMux.getTrueValue();
2180 cond = rewriter.createOrFold<
OrOp>(op.getLoc(), cond, subCond,
false);
2181 replaceOpWithNewOpAndCopyNamehint<MuxOp>(rewriter, op, cond, commonValue,
2182 otherValue, op.getTwoState());
2188 bool isaAndOp = isa<AndOp>(subExpr);
2189 if (isTrueOperand ^ isaAndOp)
2193 rewriter.createOrFold<ReplicateOp>(op.getLoc(), op.getType(), cond);
2196 bool isaXorOp = isa<XorOp>(subExpr);
2197 bool isaOrOp = isa<OrOp>(subExpr);
2206 if (isaOrOp || isaXorOp) {
2207 auto masked = rewriter.createOrFold<
AndOp>(op.getLoc(), extendedCond,
2208 restOfAssoc,
false);
2210 replaceOpWithNewOpAndCopyNamehint<XorOp>(rewriter, op, masked,
2211 commonValue,
false);
2213 replaceOpWithNewOpAndCopyNamehint<OrOp>(rewriter, op, masked, commonValue,
2219 assert(isaAndOp &&
"unexpected operation here");
2220 auto masked = rewriter.createOrFold<
OrOp>(op.getLoc(), extendedCond,
2221 restOfAssoc,
false);
2222 replaceOpWithNewOpAndCopyNamehint<AndOp>(rewriter, op, masked, commonValue,
2233 PatternRewriter &rewriter) {
2236 if (!isa<ConcatOp>(trueOp))
2240 SmallVector<Value> trueOperands, falseOperands;
2244 size_t numTrueOperands = trueOperands.size();
2245 size_t numFalseOperands = falseOperands.size();
2247 if (!numTrueOperands || !numFalseOperands ||
2248 (trueOperands.front() != falseOperands.front() &&
2249 trueOperands.back() != falseOperands.back()))
2253 if (trueOperands.front() == falseOperands.front()) {
2254 SmallVector<Value> operands;
2256 for (i = 0; i < numTrueOperands; ++i) {
2257 Value trueOperand = trueOperands[i];
2258 if (trueOperand == falseOperands[i])
2259 operands.push_back(trueOperand);
2263 if (i == numTrueOperands) {
2270 if (llvm::all_of(operands, [&](Value v) {
return v == operands.front(); }))
2271 sharedMSB = rewriter.createOrFold<ReplicateOp>(
2272 mux->getLoc(), operands.front(), operands.size());
2274 sharedMSB = rewriter.createOrFold<
ConcatOp>(mux->getLoc(), operands);
2278 operands.append(trueOperands.begin() + i, trueOperands.end());
2279 Value trueLSB = rewriter.createOrFold<
ConcatOp>(trueOp->getLoc(), operands);
2281 operands.append(falseOperands.begin() + i, falseOperands.end());
2283 rewriter.createOrFold<
ConcatOp>(falseOp->getLoc(), operands);
2286 Value lsb = rewriter.createOrFold<
MuxOp>(
2287 mux->getLoc(), mux.getCond(), trueLSB, falseLSB, mux.getTwoState());
2288 replaceOpWithNewOpAndCopyNamehint<ConcatOp>(rewriter, mux, sharedMSB, lsb);
2293 if (trueOperands.back() == falseOperands.back()) {
2294 SmallVector<Value> operands;
2297 Value trueOperand = trueOperands[numTrueOperands - i - 1];
2298 if (trueOperand == falseOperands[numFalseOperands - i - 1])
2299 operands.push_back(trueOperand);
2303 std::reverse(operands.begin(), operands.end());
2304 Value sharedLSB = rewriter.createOrFold<
ConcatOp>(mux->getLoc(), operands);
2308 operands.append(trueOperands.begin(), trueOperands.end() - i);
2309 Value trueMSB = rewriter.createOrFold<
ConcatOp>(trueOp->getLoc(), operands);
2311 operands.append(falseOperands.begin(), falseOperands.end() - i);
2313 rewriter.createOrFold<
ConcatOp>(falseOp->getLoc(), operands);
2315 Value msb = rewriter.createOrFold<
MuxOp>(
2316 mux->getLoc(), mux.getCond(), trueMSB, falseMSB, mux.getTwoState());
2317 replaceOpWithNewOpAndCopyNamehint<ConcatOp>(rewriter, mux, msb, sharedLSB);
2329 if (!trueVec || !falseVec)
2331 if (!trueVec.isUniform() || !falseVec.isUniform())
2335 op.getLoc(), op.getCond(), trueVec.getUniformElement(),
2336 falseVec.getUniformElement(), op.getTwoState());
2338 SmallVector<Value> values(trueVec.getInputs().size(), mux);
2345 using OpRewritePattern::OpRewritePattern;
2347 LogicalResult matchAndRewrite(
MuxOp op,
2348 PatternRewriter &rewriter)
const override;
2351LogicalResult MuxRewriter::matchAndRewrite(
MuxOp op,
2352 PatternRewriter &rewriter)
const {
2361 if (matchPattern(op.getTrueValue(), m_ConstantInt(&value))) {
2362 if (value.getBitWidth() == 1) {
2364 if (value.isZero()) {
2366 replaceOpWithNewOpAndCopyNamehint<AndOp>(rewriter, op, notCond,
2367 op.getFalseValue(),
false);
2372 replaceOpWithNewOpAndCopyNamehint<OrOp>(rewriter, op, op.getCond(),
2373 op.getFalseValue(),
false);
2379 if (matchPattern(op.getFalseValue(), m_ConstantInt(&value2))) {
2384 APInt xorValue = value ^ value2;
2385 if (xorValue.isPowerOf2()) {
2386 unsigned leadingZeros = xorValue.countLeadingZeros();
2387 unsigned trailingZeros = value.getBitWidth() - leadingZeros - 1;
2388 SmallVector<Value, 3> operands;
2396 if (leadingZeros > 0)
2397 operands.push_back(rewriter.createOrFold<
ExtractOp>(
2398 op.getLoc(), op.getTrueValue(), trailingZeros + 1, leadingZeros));
2402 auto v1 = rewriter.createOrFold<
ExtractOp>(
2403 op.getLoc(), op.getTrueValue(), trailingZeros, 1);
2404 auto v2 = rewriter.createOrFold<
ExtractOp>(
2405 op.getLoc(), op.getFalseValue(), trailingZeros, 1);
2406 operands.push_back(rewriter.createOrFold<
MuxOp>(
2407 op.getLoc(), op.getCond(), v1, v2,
false));
2409 if (trailingZeros > 0)
2410 operands.push_back(rewriter.createOrFold<
ExtractOp>(
2411 op.getLoc(), op.getTrueValue(), 0, trailingZeros));
2413 replaceOpWithNewOpAndCopyNamehint<ConcatOp>(rewriter, op, op.getType(),
2420 if (value.isAllOnes() && value2.isZero()) {
2421 replaceOpWithNewOpAndCopyNamehint<ReplicateOp>(
2422 rewriter, op, op.getType(), op.getCond());
2428 if (matchPattern(op.getFalseValue(), m_ConstantInt(&value)) &&
2429 value.getBitWidth() == 1) {
2431 if (value.isZero()) {
2432 replaceOpWithNewOpAndCopyNamehint<AndOp>(rewriter, op, op.getCond(),
2433 op.getTrueValue(),
false);
2440 auto notCond = rewriter.createOrFold<
XorOp>(op.getLoc(), op.getCond(),
2441 op.getFalseValue(),
false);
2442 replaceOpWithNewOpAndCopyNamehint<OrOp>(rewriter, op, notCond,
2443 op.getTrueValue(),
false);
2449 Operation *condOp = op.getCond().getDefiningOp();
2450 if (condOp && matchPattern(condOp,
m_Complement(m_Any(&subExpr))) &&
2452 replaceOpWithNewOpAndCopyNamehint<MuxOp>(rewriter, op, op.getType(),
2453 subExpr, op.getFalseValue(),
2454 op.getTrueValue(),
true);
2461 if (condOp && condOp->hasOneUse()) {
2462 SmallVector<Value> invertedOperands;
2466 auto getInvertedOperands = [&]() ->
bool {
2467 for (Value operand : condOp->getOperands()) {
2468 if (matchPattern(operand,
m_Complement(m_Any(&subExpr))))
2469 invertedOperands.push_back(subExpr);
2476 if (isa<AndOp>(condOp) && getInvertedOperands()) {
2478 rewriter.createOrFold<
OrOp>(op.getLoc(), invertedOperands,
false);
2479 replaceOpWithNewOpAndCopyNamehint<MuxOp>(
2480 rewriter, op, newOr, op.getFalseValue(), op.getTrueValue(),
2484 if (isa<OrOp>(condOp) && getInvertedOperands()) {
2486 rewriter.createOrFold<
AndOp>(op.getLoc(), invertedOperands,
false);
2487 replaceOpWithNewOpAndCopyNamehint<MuxOp>(
2488 rewriter, op, newAnd, op.getFalseValue(), op.getTrueValue(),
2494 if (
auto falseMux = op.getFalseValue().getDefiningOp<
MuxOp>();
2495 falseMux && falseMux != op) {
2497 if (op.getCond() == falseMux.getCond()) {
2498 replaceOpWithNewOpAndCopyNamehint<MuxOp>(
2499 rewriter, op, op.getCond(), op.getTrueValue(),
2500 falseMux.getFalseValue(), op.getTwoStateAttr());
2509 if (
auto trueMux = op.getTrueValue().getDefiningOp<
MuxOp>();
2510 trueMux && trueMux != op) {
2512 if (op.getCond() == trueMux.getCond()) {
2513 replaceOpWithNewOpAndCopyNamehint<MuxOp>(
2514 rewriter, op, op.getCond(), trueMux.getTrueValue(),
2515 op.getFalseValue(), op.getTwoStateAttr());
2525 if (
auto trueMux = dyn_cast_or_null<MuxOp>(op.getTrueValue().getDefiningOp()),
2526 falseMux = dyn_cast_or_null<MuxOp>(op.getFalseValue().getDefiningOp());
2527 trueMux && falseMux && trueMux.getCond() == falseMux.getCond() &&
2528 trueMux.getTrueValue() == falseMux.getTrueValue() && trueMux != op &&
2530 auto subMux = rewriter.create<
MuxOp>(
2531 rewriter.getFusedLoc({trueMux.getLoc(), falseMux.getLoc()}),
2532 op.getCond(), trueMux.getFalseValue(), falseMux.getFalseValue());
2533 replaceOpWithNewOpAndCopyNamehint<MuxOp>(rewriter, op, trueMux.getCond(),
2534 trueMux.getTrueValue(), subMux,
2535 op.getTwoStateAttr());
2540 if (
auto trueMux = dyn_cast_or_null<MuxOp>(op.getTrueValue().getDefiningOp()),
2541 falseMux = dyn_cast_or_null<MuxOp>(op.getFalseValue().getDefiningOp());
2542 trueMux && falseMux && trueMux.getCond() == falseMux.getCond() &&
2543 trueMux.getFalseValue() == falseMux.getFalseValue() && trueMux != op &&
2545 auto subMux = rewriter.create<
MuxOp>(
2546 rewriter.getFusedLoc({trueMux.getLoc(), falseMux.getLoc()}),
2547 op.getCond(), trueMux.getTrueValue(), falseMux.getTrueValue());
2548 replaceOpWithNewOpAndCopyNamehint<MuxOp>(rewriter, op, trueMux.getCond(),
2549 subMux, trueMux.getFalseValue(),
2550 op.getTwoStateAttr());
2555 if (
auto trueMux = dyn_cast_or_null<MuxOp>(op.getTrueValue().getDefiningOp()),
2556 falseMux = dyn_cast_or_null<MuxOp>(op.getFalseValue().getDefiningOp());
2557 trueMux && falseMux &&
2558 trueMux.getTrueValue() == falseMux.getTrueValue() &&
2559 trueMux.getFalseValue() == falseMux.getFalseValue() && trueMux != op &&
2561 auto subMux = rewriter.create<
MuxOp>(
2562 rewriter.getFusedLoc(
2563 {op.getLoc(), trueMux.getLoc(), falseMux.getLoc()}),
2564 op.getCond(), trueMux.getCond(), falseMux.getCond());
2565 replaceOpWithNewOpAndCopyNamehint<MuxOp>(
2566 rewriter, op, subMux, trueMux.getTrueValue(), trueMux.getFalseValue(),
2567 op.getTwoStateAttr());
2579 if (Operation *trueOp = op.getTrueValue().getDefiningOp())
2580 if (Operation *falseOp = op.getFalseValue().getDefiningOp())
2581 if (trueOp->getName() == falseOp->getName())
2598 if (op.getInputs().empty() || op.isUniform())
2600 auto inputs = op.getInputs();
2601 if (inputs.size() <= 1)
2606 auto first = inputs[0].getDefiningOp<
comb::MuxOp>();
2611 for (
size_t i = 1, n = inputs.size(); i < n; ++i) {
2612 auto input = inputs[i].getDefiningOp<
comb::MuxOp>();
2613 if (!input || first.getCond() != input.getCond())
2618 SmallVector<Value> trues{first.getTrueValue()};
2619 SmallVector<Value> falses{first.getFalseValue()};
2620 SmallVector<Location> locs{first->getLoc()};
2621 bool isTwoState =
true;
2622 for (
size_t i = 1, n = inputs.size(); i < n; ++i) {
2623 auto input = inputs[i].getDefiningOp<
comb::MuxOp>();
2624 trues.push_back(input.getTrueValue());
2625 falses.push_back(input.getFalseValue());
2626 locs.push_back(input->getLoc());
2627 if (!input.getTwoState())
2632 auto loc = FusedLoc::get(op.getContext(), locs);
2636 auto arrayTy = op.getType();
2639 rewriter.replaceOpWithNewOp<
comb::MuxOp>(op, arrayTy, first.getCond(),
2640 trueValues, falseValues, isTwoState);
2645 using OpRewritePattern::OpRewritePattern;
2648 PatternRewriter &rewriter)
const override {
2652 if (foldArrayOfMuxes(op, rewriter))
2660void MuxOp::getCanonicalizationPatterns(RewritePatternSet &results,
2661 MLIRContext *context) {
2662 results.insert<MuxRewriter, ArrayRewriter>(context);
2673 switch (predicate) {
2674 case ICmpPredicate::eq:
2676 case ICmpPredicate::ne:
2678 case ICmpPredicate::slt:
2679 return lhs.slt(rhs);
2680 case ICmpPredicate::sle:
2681 return lhs.sle(rhs);
2682 case ICmpPredicate::sgt:
2683 return lhs.sgt(rhs);
2684 case ICmpPredicate::sge:
2685 return lhs.sge(rhs);
2686 case ICmpPredicate::ult:
2687 return lhs.ult(rhs);
2688 case ICmpPredicate::ule:
2689 return lhs.ule(rhs);
2690 case ICmpPredicate::ugt:
2691 return lhs.ugt(rhs);
2692 case ICmpPredicate::uge:
2693 return lhs.uge(rhs);
2694 case ICmpPredicate::ceq:
2696 case ICmpPredicate::cne:
2698 case ICmpPredicate::weq:
2700 case ICmpPredicate::wne:
2703 llvm_unreachable(
"unknown comparison predicate");
2709 switch (predicate) {
2710 case ICmpPredicate::eq:
2711 case ICmpPredicate::sle:
2712 case ICmpPredicate::sge:
2713 case ICmpPredicate::ule:
2714 case ICmpPredicate::uge:
2715 case ICmpPredicate::ceq:
2716 case ICmpPredicate::weq:
2718 case ICmpPredicate::ne:
2719 case ICmpPredicate::slt:
2720 case ICmpPredicate::sgt:
2721 case ICmpPredicate::ult:
2722 case ICmpPredicate::ugt:
2723 case ICmpPredicate::cne:
2724 case ICmpPredicate::wne:
2727 llvm_unreachable(
"unknown comparison predicate");
2730OpFoldResult ICmpOp::fold(FoldAdaptor adaptor) {
2736 if (getLhs() == getRhs()) {
2738 return IntegerAttr::get(getType(), val);
2742 if (
auto lhs = dyn_cast_or_null<IntegerAttr>(adaptor.getLhs())) {
2743 if (
auto rhs = dyn_cast_or_null<IntegerAttr>(adaptor.getRhs())) {
2746 return IntegerAttr::get(getType(), val);
2754template <
typename Range>
2756 size_t commonPrefixLength = 0;
2757 auto ia = a.begin();
2758 auto ib = b.begin();
2760 for (; ia != a.end() && ib != b.end(); ia++, ib++, commonPrefixLength++) {
2766 return commonPrefixLength;
2770 size_t totalWidth = 0;
2771 for (
auto operand : operands) {
2774 ssize_t width = operand.getType().getIntOrFloatBitWidth();
2776 totalWidth += width;
2786 PatternRewriter &rewriter) {
2790 SmallVector<Value> lhsOperands, rhsOperands;
2793 ArrayRef<Value> lhsOperandsRef = lhsOperands, rhsOperandsRef = rhsOperands;
2795 auto formCatOrReplicate = [&](Location loc,
2796 ArrayRef<Value> operands) -> Value {
2797 assert(!operands.empty());
2798 Value sameElement = operands[0];
2799 for (
size_t i = 1, e = operands.size(); i != e && sameElement; ++i)
2800 if (sameElement != operands[i])
2801 sameElement = Value();
2803 return rewriter.createOrFold<ReplicateOp>(loc, sameElement,
2805 return rewriter.createOrFold<
ConcatOp>(loc, operands);
2808 auto replaceWith = [&](ICmpPredicate predicate, Value lhs,
2809 Value rhs) -> LogicalResult {
2810 replaceOpWithNewOpAndCopyNamehint<ICmpOp>(rewriter, op, predicate, lhs, rhs,
2815 size_t commonPrefixLength =
2817 if (commonPrefixLength == lhsOperands.size()) {
2820 replaceOpWithNewOpAndCopyNamehint<hw::ConstantOp>(rewriter, op,
2826 llvm::reverse(lhsOperandsRef), llvm::reverse(rhsOperandsRef));
2828 size_t commonPrefixTotalWidth =
2829 getTotalWidth(lhsOperandsRef.take_front(commonPrefixLength));
2830 size_t commonSuffixTotalWidth =
2831 getTotalWidth(lhsOperandsRef.take_back(commonSuffixLength));
2832 auto lhsOnly = lhsOperandsRef.drop_front(commonPrefixLength)
2833 .drop_back(commonSuffixLength);
2834 auto rhsOnly = rhsOperandsRef.drop_front(commonPrefixLength)
2835 .drop_back(commonSuffixLength);
2837 auto replaceWithoutReplicatingSignBit = [&]() {
2838 auto newLhs = formCatOrReplicate(lhs->getLoc(), lhsOnly);
2839 auto newRhs = formCatOrReplicate(rhs->getLoc(), rhsOnly);
2840 return replaceWith(op.getPredicate(), newLhs, newRhs);
2843 auto replaceWithReplicatingSignBit = [&]() {
2844 auto firstNonEmptyValue = lhsOperands[0];
2845 auto firstNonEmptyElemWidth =
2846 firstNonEmptyValue.getType().getIntOrFloatBitWidth();
2847 Value signBit = rewriter.createOrFold<
ExtractOp>(
2848 op.getLoc(), firstNonEmptyValue, firstNonEmptyElemWidth - 1, 1);
2850 auto newLhs = rewriter.
create<
ConcatOp>(lhs->getLoc(), signBit, lhsOnly);
2851 auto newRhs = rewriter.create<
ConcatOp>(rhs->getLoc(), signBit, rhsOnly);
2852 return replaceWith(op.getPredicate(), newLhs, newRhs);
2855 if (ICmpOp::isPredicateSigned(op.getPredicate())) {
2857 if (commonPrefixTotalWidth == 0 && commonSuffixTotalWidth > 0)
2858 return replaceWithoutReplicatingSignBit();
2864 if (commonPrefixTotalWidth > 1 || commonSuffixTotalWidth > 0)
2865 return replaceWithReplicatingSignBit();
2867 }
else if (commonPrefixTotalWidth > 0 || commonSuffixTotalWidth > 0) {
2869 return replaceWithoutReplicatingSignBit();
2883 ICmpOp cmpOp,
const KnownBits &bitAnalysis,
const APInt &rhsCst,
2884 PatternRewriter &rewriter) {
2888 APInt bitsKnown = bitAnalysis.Zero | bitAnalysis.One;
2889 if ((bitsKnown & rhsCst) != bitAnalysis.One) {
2892 bool result = cmpOp.getPredicate() == ICmpPredicate::ne;
2893 replaceOpWithNewOpAndCopyNamehint<hw::ConstantOp>(rewriter, cmpOp,
2901 SmallVector<Value> newConcatOperands;
2902 auto newConstant = APInt::getZeroWidth();
2907 unsigned knownMSB = bitsKnown.countLeadingOnes();
2909 Value operand = cmpOp.getLhs();
2914 while (knownMSB != bitsKnown.getBitWidth()) {
2917 bitsKnown = bitsKnown.trunc(bitsKnown.getBitWidth() - knownMSB);
2920 unsigned unknownBits = bitsKnown.countLeadingZeros();
2921 unsigned lowBit = bitsKnown.getBitWidth() - unknownBits;
2922 auto spanOperand = rewriter.createOrFold<
ExtractOp>(
2923 operand.getLoc(), operand, lowBit,
2925 auto spanConstant = rhsCst.lshr(lowBit).trunc(unknownBits);
2928 newConcatOperands.push_back(spanOperand);
2931 if (newConstant.getBitWidth() != 0)
2932 newConstant = newConstant.concat(spanConstant);
2934 newConstant = spanConstant;
2937 unsigned newWidth = bitsKnown.getBitWidth() - unknownBits;
2938 bitsKnown = bitsKnown.trunc(newWidth);
2939 knownMSB = bitsKnown.countLeadingOnes();
2945 if (newConcatOperands.empty()) {
2946 bool result = cmpOp.getPredicate() == ICmpPredicate::eq;
2947 replaceOpWithNewOpAndCopyNamehint<hw::ConstantOp>(rewriter, cmpOp,
2953 Value concatResult =
2954 rewriter.createOrFold<
ConcatOp>(operand.getLoc(), newConcatOperands);
2958 cmpOp.getOperand(1).getLoc(), newConstant);
2960 replaceOpWithNewOpAndCopyNamehint<ICmpOp>(rewriter, cmpOp,
2961 cmpOp.getPredicate(), concatResult,
2962 newConstantOp, cmpOp.getTwoState());
2968 PatternRewriter &rewriter) {
2969 auto ip = rewriter.saveInsertionPoint();
2970 rewriter.setInsertionPoint(xorOp);
2972 auto xorRHS = xorOp.getOperands().back().getDefiningOp<
hw::ConstantOp>();
2974 xorRHS.getValue() ^ rhs);
2976 switch (xorOp.getNumOperands()) {
2980 APInt::getZero(rhs.getBitWidth()));
2984 newLHS = xorOp.getOperand(0);
2988 SmallVector<Value> newOperands(xorOp.getOperands());
2989 newOperands.pop_back();
2990 newLHS = rewriter.create<
XorOp>(xorOp.getLoc(), newOperands,
false);
2994 bool xorMultipleUses = !xorOp->hasOneUse();
2998 if (xorMultipleUses)
2999 replaceOpWithNewOpAndCopyNamehint<XorOp>(rewriter, xorOp, newLHS, xorRHS,
3003 rewriter.restoreInsertionPoint(ip);
3004 replaceOpWithNewOpAndCopyNamehint<ICmpOp>(
3005 rewriter, cmpOp, cmpOp.getPredicate(), newLHS, newRHS,
false);
3008LogicalResult ICmpOp::canonicalize(ICmpOp op, PatternRewriter &rewriter) {
3015 if (matchPattern(op.getLhs(), m_ConstantInt(&lhs))) {
3016 assert(!matchPattern(op.getRhs(), m_ConstantInt(&rhs)) &&
3017 "Should be folded");
3018 replaceOpWithNewOpAndCopyNamehint<ICmpOp>(
3019 rewriter, op, ICmpOp::getFlippedPredicate(op.getPredicate()),
3020 op.getRhs(), op.getLhs(), op.getTwoState());
3025 if (matchPattern(op.getRhs(), m_ConstantInt(&rhs))) {
3027 return rewriter.create<
hw::ConstantOp>(op.getLoc(), std::move(constant));
3030 auto replaceWith = [&](ICmpPredicate predicate, Value lhs,
3031 Value rhs) -> LogicalResult {
3032 replaceOpWithNewOpAndCopyNamehint<ICmpOp>(rewriter, op, predicate, lhs,
3033 rhs, op.getTwoState());
3037 auto replaceWithConstantI1 = [&](
bool constant) -> LogicalResult {
3038 replaceOpWithNewOpAndCopyNamehint<hw::ConstantOp>(rewriter, op,
3039 APInt(1, constant));
3043 switch (op.getPredicate()) {
3044 case ICmpPredicate::slt:
3046 if (rhs.isMaxSignedValue())
3047 return replaceWith(ICmpPredicate::ne, op.getLhs(), op.getRhs());
3049 if (rhs.isMinSignedValue())
3050 return replaceWithConstantI1(0);
3052 if ((rhs - 1).isMinSignedValue())
3053 return replaceWith(ICmpPredicate::eq, op.getLhs(),
3056 case ICmpPredicate::sgt:
3058 if (rhs.isMinSignedValue())
3059 return replaceWith(ICmpPredicate::ne, op.getLhs(), op.getRhs());
3061 if (rhs.isMaxSignedValue())
3062 return replaceWithConstantI1(0);
3064 if ((rhs + 1).isMaxSignedValue())
3065 return replaceWith(ICmpPredicate::eq, op.getLhs(),
3068 case ICmpPredicate::ult:
3070 if (rhs.isAllOnes())
3071 return replaceWith(ICmpPredicate::ne, op.getLhs(), op.getRhs());
3074 return replaceWithConstantI1(0);
3076 if ((rhs - 1).isZero())
3077 return replaceWith(ICmpPredicate::eq, op.getLhs(),
3081 if (rhs.countLeadingOnes() + rhs.countTrailingZeros() ==
3082 rhs.getBitWidth()) {
3083 auto numOnes = rhs.countLeadingOnes();
3084 auto smaller = rewriter.create<
ExtractOp>(
3085 op.getLoc(), op.getLhs(), rhs.getBitWidth() - numOnes, numOnes);
3086 return replaceWith(ICmpPredicate::ne, smaller,
3091 case ICmpPredicate::ugt:
3094 return replaceWith(ICmpPredicate::ne, op.getLhs(), op.getRhs());
3096 if (rhs.isAllOnes())
3097 return replaceWithConstantI1(0);
3099 if ((rhs + 1).isAllOnes())
3100 return replaceWith(ICmpPredicate::eq, op.getLhs(),
3104 if ((rhs + 1).isPowerOf2()) {
3105 auto numOnes = rhs.countTrailingOnes();
3106 auto newWidth = rhs.getBitWidth() - numOnes;
3107 auto smaller = rewriter.create<
ExtractOp>(op.getLoc(), op.getLhs(),
3109 return replaceWith(ICmpPredicate::ne, smaller,
3114 case ICmpPredicate::sle:
3116 if (rhs.isMaxSignedValue())
3117 return replaceWithConstantI1(1);
3119 return replaceWith(ICmpPredicate::slt, op.getLhs(),
getConstant(rhs + 1));
3120 case ICmpPredicate::sge:
3122 if (rhs.isMinSignedValue())
3123 return replaceWithConstantI1(1);
3125 return replaceWith(ICmpPredicate::sgt, op.getLhs(),
getConstant(rhs - 1));
3126 case ICmpPredicate::ule:
3128 if (rhs.isAllOnes())
3129 return replaceWithConstantI1(1);
3131 return replaceWith(ICmpPredicate::ult, op.getLhs(),
getConstant(rhs + 1));
3132 case ICmpPredicate::uge:
3135 return replaceWithConstantI1(1);
3137 return replaceWith(ICmpPredicate::ugt, op.getLhs(),
getConstant(rhs - 1));
3138 case ICmpPredicate::eq:
3139 if (rhs.getBitWidth() == 1) {
3142 replaceOpWithNewOpAndCopyNamehint<XorOp>(rewriter, op, op.getLhs(),
3147 if (rhs.isAllOnes()) {
3154 case ICmpPredicate::ne:
3155 if (rhs.getBitWidth() == 1) {
3161 if (rhs.isAllOnes()) {
3163 replaceOpWithNewOpAndCopyNamehint<XorOp>(rewriter, op, op.getLhs(),
3170 case ICmpPredicate::ceq:
3171 case ICmpPredicate::cne:
3172 case ICmpPredicate::weq:
3173 case ICmpPredicate::wne:
3179 if (op.getPredicate() == ICmpPredicate::eq ||
3180 op.getPredicate() == ICmpPredicate::ne) {
3185 if (!knownBits.isUnknown())
3192 if (
auto xorOp = op.getLhs().getDefiningOp<
XorOp>())
3199 if (
auto replicateOp = op.getLhs().getDefiningOp<ReplicateOp>())
3200 if (rhs.isAllOnes() || rhs.isZero()) {
3201 auto width = replicateOp.getInput().getType().getIntOrFloatBitWidth();
3203 op.getLoc(), rhs.isAllOnes() ? APInt::getAllOnes(width)
3204 : APInt::getZero(width));
3205 replaceOpWithNewOpAndCopyNamehint<ICmpOp>(
3206 rewriter, op, op.getPredicate(), replicateOp.getInput(), cst,
3216 if (Operation *opLHS = op.getLhs().getDefiningOp())
3217 if (Operation *opRHS = op.getRhs().getDefiningOp())
3218 if (isa<ConcatOp, ReplicateOp>(opLHS) &&
3219 isa<ConcatOp, ReplicateOp>(opRHS)) {
assert(baseType &&"element must be base type")
static SmallVector< T > concat(const SmallVectorImpl< T > &a, const SmallVectorImpl< T > &b)
Returns a new vector containing the concatenation of vectors a and b.
static KnownBits computeKnownBits(Value v, unsigned depth)
Given an integer SSA value, check to see if we know anything about the result of the computation.
static bool foldMuxOfUniformArrays(MuxOp op, PatternRewriter &rewriter)
static Attribute constFoldAssociativeOp(ArrayRef< Attribute > operands, hw::PEO paramOpcode)
static Attribute constFoldBinaryOp(ArrayRef< Attribute > operands, hw::PEO paramOpcode)
Performs constant folding calculate with element-wise behavior on the two attributes in operands and ...
static bool applyCmpPredicateToEqualOperands(ICmpPredicate predicate)
static ComplementMatcher< SubType > m_Complement(const SubType &subExpr)
static bool canonicalizeLogicalCstWithConcat(Operation *logicalOp, size_t concatIdx, const APInt &cst, PatternRewriter &rewriter)
When we find a logical operation (and, or, xor) with a constant e.g.
static bool hasOperandsOutsideOfBlock(Operation *op)
In comb, we assume no knowledge of the semantics of cross-block dataflow.
static bool narrowOperationWidth(OpTy op, bool narrowTrailingBits, PatternRewriter &rewriter)
static OpFoldResult foldDiv(Op op, ArrayRef< Attribute > constants)
static Value getCommonOperand(Op op)
Returns a single common operand that all inputs of the operation op can be traced back to,...
static bool canCombineOppositeBinCmpIntoConstant(OperandRange operands)
static void getConcatOperands(Value v, SmallVectorImpl< Value > &result)
Flatten concat and replicate operands into a vector.
static Value extractOperandFromFullyAssociative(Operation *fullyAssoc, size_t operandNo, PatternRewriter &rewriter)
Given a fully associative variadic operation like (a+b+c+d), break the expression into two parts,...
static bool getMuxChainCondConstant(Value cond, Value indexValue, bool isInverted, std::function< void(hw::ConstantOp)> constantFn)
Check to see if the condition to the specified mux is an equality comparison indexValue and one or mo...
static TypedAttr getIntAttr(const APInt &value, MLIRContext *context)
static bool shouldBeFlattened(Operation *op)
Return true if the op will be flattened afterwards.
static void canonicalizeXorIcmpTrue(XorOp op, unsigned icmpOperand, PatternRewriter &rewriter)
static bool extractFromReplicate(ExtractOp op, ReplicateOp replicate, PatternRewriter &rewriter)
static void combineEqualityICmpWithXorOfConstant(ICmpOp cmpOp, XorOp xorOp, const APInt &rhs, PatternRewriter &rewriter)
static size_t getTotalWidth(ArrayRef< Value > operands)
static bool foldCommonMuxOperation(MuxOp mux, Operation *trueOp, Operation *falseOp, PatternRewriter &rewriter)
This function is invoked when we find a mux with true/false operations that have the same opcode.
static std::pair< size_t, size_t > getLowestBitAndHighestBitRequired(Operation *op, bool narrowTrailingBits, size_t originalOpWidth)
static bool tryFlatteningOperands(Operation *op, PatternRewriter &rewriter)
Flattens a single input in op if hasOneUse is true and it can be defined as an Op.
static bool canonicalizeIdempotentInputs(Op op, PatternRewriter &rewriter)
Canonicalize an idempotent operation op so that only one input of any kind occurs.
static bool applyCmpPredicate(ICmpPredicate predicate, const APInt &lhs, const APInt &rhs)
static void combineEqualityICmpWithKnownBitsAndConstant(ICmpOp cmpOp, const KnownBits &bitAnalysis, const APInt &rhsCst, PatternRewriter &rewriter)
Given an equality comparison with a constant value and some operand that has known bits,...
static bool foldMuxChain(MuxOp rootMux, bool isFalseSide, PatternRewriter &rewriter)
Given a mux, check to see if the "on true" value (or "on false" value if isFalseSide=true) is a mux t...
static bool hasSVAttributes(Operation *op)
static LogicalResult extractConcatToConcatExtract(ExtractOp op, ConcatOp innerCat, PatternRewriter &rewriter)
static OpFoldResult foldMod(Op op, ArrayRef< Attribute > constants)
static size_t computeCommonPrefixLength(const Range &a, const Range &b)
static bool foldCommonMuxValue(MuxOp op, bool isTrueOperand, PatternRewriter &rewriter)
Fold things like mux(cond, x|y|z|a, a) -> (x|y|z)&replicate(cond)|a and mux(cond, a,...
static LogicalResult matchAndRewriteCompareConcat(ICmpOp op, Operation *lhs, Operation *rhs, PatternRewriter &rewriter)
Reduce the strength icmp(concat(...), concat(...)) by doing a element-wise comparison on common prefi...
static Value createGenericOp(Location loc, OperationName name, ArrayRef< Value > operands, OpBuilder &builder)
Create a new instance of a generic operation that only has value operands, and has a single result va...
static TypedAttr getIntAttr(MLIRContext *ctx, Type t, const APInt &value)
static std::optional< APSInt > getConstant(Attribute operand)
Determine the value of a constant operand for the sake of constant folding.
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
Value createOrFoldNot(Location loc, Value value, OpBuilder &builder, bool twoState=false)
Create a 'Not' gate on a value.
KnownBits computeKnownBits(Value value)
Compute "known bits" information about the specified value - the set of bits that are guaranteed to a...
uint64_t getWidth(Type t)
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
void replaceOpAndCopyNamehint(PatternRewriter &rewriter, Operation *op, Value newValue)
A wrapper of PatternRewriter::replaceOp to propagate "sv.namehint" attribute.