12 #include "mlir/IR/Matchers.h"
13 #include "mlir/IR/PatternMatch.h"
14 #include "llvm/ADT/SetVector.h"
15 #include "llvm/ADT/SmallBitVector.h"
16 #include "llvm/ADT/TypeSwitch.h"
17 #include "llvm/Support/KnownBits.h"
20 using namespace circt;
22 using namespace matchers;
33 Block *thisBlock = op->getBlock();
34 return llvm::any_of(op->getOperands(), [&](Value operand) {
35 return operand.getParentBlock() != thisBlock;
45 ArrayRef<Value> operands, OpBuilder &
builder) {
46 OperationState state(loc, name);
47 state.addOperands(operands);
48 state.addTypes(operands[0].getType());
49 return builder.create(state)->getResult(0);
52 static TypedAttr
getIntAttr(
const APInt &value, MLIRContext *context) {
60 for (
auto op :
concat.getOperands())
62 }
else if (
auto repl = v.getDefiningOp<ReplicateOp>()) {
63 for (
size_t i = 0, e = repl.getMultiple(); i != e; ++i)
75 if (
auto *newOp = newValue.getDefiningOp()) {
76 auto name = op->getAttrOfType<StringAttr>(
"sv.namehint");
77 if (name && !newOp->hasAttr(
"sv.namehint"))
78 rewriter.modifyOpInPlace(newOp,
79 [&] { newOp->setAttr(
"sv.namehint", name); });
81 rewriter.replaceOp(op, newValue);
87 template <
typename OpTy,
typename... Args>
89 Operation *op, Args &&...args) {
90 auto name = op->getAttrOfType<StringAttr>(
"sv.namehint");
92 rewriter.replaceOpWithNewOp<OpTy>(op, std::forward<Args>(args)...);
93 if (name && !newOp->hasAttr(
"sv.namehint"))
94 rewriter.modifyOpInPlace(newOp,
95 [&] { newOp->setAttr(
"sv.namehint", name); });
104 return op->hasAttr(
"sv.attributes");
108 template <
typename SubType>
109 struct ComplementMatcher {
111 ComplementMatcher(SubType lhs) : lhs(std::move(lhs)) {}
112 bool match(Operation *op) {
113 auto xorOp = dyn_cast<XorOp>(op);
114 return xorOp && xorOp.isBinaryNot() && lhs.match(op->getOperand(0));
119 template <
typename SubType>
120 static inline ComplementMatcher<SubType>
m_Complement(
const SubType &subExpr) {
121 return ComplementMatcher<SubType>(subExpr);
127 assert((isa<AndOp, OrOp, XorOp, AddOp, MulOp>(op) &&
128 "must be commutative operations"));
129 if (op->hasOneUse()) {
130 auto *user = *op->getUsers().begin();
131 return user->getName() == op->getName() &&
132 op->getAttrOfType<UnitAttr>(
"twoState") ==
133 user->getAttrOfType<UnitAttr>(
"twoState");
148 auto inputs = op->getOperands();
150 SmallVector<Value, 4> newOperands;
151 SmallVector<Location, 4> newLocations{op->getLoc()};
152 newOperands.reserve(
inputs.size());
154 decltype(
inputs.begin()) current, end;
157 SmallVector<Element> worklist;
159 bool binFlag = op->hasAttrOfType<UnitAttr>(
"twoState");
160 bool changed =
false;
161 while (!worklist.empty()) {
162 auto &element = worklist.back();
165 if (element.current == element.end) {
170 Value value = *element.current++;
171 auto *flattenOp = value.getDefiningOp();
172 if (!flattenOp || flattenOp->getName() != op->getName() ||
173 flattenOp == op || binFlag != op->hasAttrOfType<UnitAttr>(
"twoState")) {
174 newOperands.push_back(value);
179 if (!value.hasOneUse()) {
187 if (flattenOp->getNumOperands() != 2 || !isa<AndOp, OrOp, XorOp>(op) ||
190 newOperands.push_back(value);
198 auto flattenOpInputs = flattenOp->getOperands();
199 worklist.push_back({flattenOpInputs.begin(), flattenOpInputs.end()});
200 newLocations.push_back(flattenOp->getLoc());
207 op->getName(), newOperands, rewriter);
209 result.getDefiningOp()->setAttr(
"twoState", rewriter.getUnitAttr());
217 static std::pair<size_t, size_t>
219 size_t originalOpWidth) {
220 auto users = op->getUsers();
222 "getLowestBitAndHighestBitRequired cannot operate on "
223 "a empty list of uses.");
227 size_t lowestBitRequired = narrowTrailingBits ? originalOpWidth - 1 : 0;
228 size_t highestBitRequired = 0;
230 for (
auto *user : users) {
231 if (
auto extractOp = dyn_cast<ExtractOp>(user)) {
232 size_t lowBit = extractOp.getLowBit();
234 cast<IntegerType>(extractOp.getType()).getWidth() + lowBit - 1;
235 highestBitRequired = std::max(highestBitRequired, highBit);
236 lowestBitRequired = std::min(lowestBitRequired, lowBit);
240 highestBitRequired = originalOpWidth - 1;
241 lowestBitRequired = 0;
245 return {lowestBitRequired, highestBitRequired};
248 template <
class OpTy>
250 PatternRewriter &rewriter) {
251 IntegerType opType = dyn_cast<IntegerType>(op.getResult().getType());
257 if (range.second + 1 == opType.getWidth() && range.first == 0)
260 SmallVector<Value> args;
261 auto newType = rewriter.getIntegerType(range.second - range.first + 1);
262 for (
auto inop : op.getOperands()) {
264 if (inop.getType() != op.getType())
265 args.push_back(inop);
267 args.push_back(rewriter.createOrFold<
ExtractOp>(inop.getLoc(), newType,
270 Value newop = rewriter.createOrFold<OpTy>(op.getLoc(), newType, args);
271 newop.getDefiningOp()->setDialectAttrs(op->getDialectAttrs());
273 newop = rewriter.createOrFold<
ConcatOp>(
276 APInt::getZero(range.first)));
277 if (range.second + 1 < opType.getWidth())
278 newop = rewriter.createOrFold<
ConcatOp>(
281 op.getLoc(), APInt::getZero(opType.getWidth() - range.second - 1)),
283 rewriter.replaceOp(op, newop);
291 OpFoldResult ReplicateOp::fold(FoldAdaptor adaptor) {
296 if (cast<IntegerType>(getType()).
getWidth() ==
297 getInput().getType().getIntOrFloatBitWidth())
301 if (
auto input = dyn_cast_or_null<IntegerAttr>(adaptor.getInput())) {
302 if (input.getValue().getBitWidth() == 1) {
303 if (input.getValue().isZero())
305 APInt::getZero(cast<IntegerType>(getType()).
getWidth()),
308 APInt::getAllOnes(cast<IntegerType>(getType()).
getWidth()),
312 APInt result = APInt::getZeroWidth();
313 for (
auto i = getMultiple(); i != 0; --i)
314 result = result.concat(input.getValue());
321 OpFoldResult ParityOp::fold(FoldAdaptor adaptor) {
326 if (
auto input = dyn_cast_or_null<IntegerAttr>(adaptor.getInput()))
327 return getIntAttr(APInt(1, input.getValue().popcount() & 1), getContext());
339 hw::PEO paramOpcode) {
340 assert(operands.size() == 2 &&
"binary op takes two operands");
341 if (!operands[0] || !operands[1])
347 cast<TypedAttr>(operands[1]));
350 OpFoldResult ShlOp::fold(FoldAdaptor adaptor) {
354 if (
auto rhs = dyn_cast_or_null<IntegerAttr>(adaptor.getRhs())) {
355 unsigned shift = rhs.getValue().getZExtValue();
356 unsigned width = getType().getIntOrFloatBitWidth();
358 return getOperand(0);
366 LogicalResult ShlOp::canonicalize(
ShlOp op, PatternRewriter &rewriter) {
372 if (!matchPattern(op.getRhs(), m_ConstantInt(&value)))
375 unsigned width = cast<IntegerType>(op.getLhs().getType()).getWidth();
376 unsigned shift = value.getZExtValue();
379 if (
width <= shift || shift == 0)
383 rewriter.create<
hw::ConstantOp>(op.getLoc(), APInt::getZero(shift));
389 replaceOpWithNewOpAndCopyName<ConcatOp>(rewriter, op, extract, zeros);
393 OpFoldResult ShrUOp::fold(FoldAdaptor adaptor) {
397 if (
auto rhs = dyn_cast_or_null<IntegerAttr>(adaptor.getRhs())) {
398 unsigned shift = rhs.getValue().getZExtValue();
400 return getOperand(0);
402 unsigned width = getType().getIntOrFloatBitWidth();
409 LogicalResult ShrUOp::canonicalize(
ShrUOp op, PatternRewriter &rewriter) {
415 if (!matchPattern(op.getRhs(), m_ConstantInt(&value)))
418 unsigned width = cast<IntegerType>(op.getLhs().getType()).getWidth();
419 unsigned shift = value.getZExtValue();
422 if (
width <= shift || shift == 0)
426 rewriter.create<
hw::ConstantOp>(op.getLoc(), APInt::getZero(shift));
429 auto extract = rewriter.
create<
ExtractOp>(op.getLoc(), op.getLhs(), shift,
432 replaceOpWithNewOpAndCopyName<ConcatOp>(rewriter, op, zeros, extract);
436 OpFoldResult ShrSOp::fold(FoldAdaptor adaptor) {
440 if (
auto rhs = dyn_cast_or_null<IntegerAttr>(adaptor.getRhs())) {
441 if (rhs.getValue().getZExtValue() == 0)
442 return getOperand(0);
447 LogicalResult ShrSOp::canonicalize(
ShrSOp op, PatternRewriter &rewriter) {
453 if (!matchPattern(op.getRhs(), m_ConstantInt(&value)))
456 unsigned width = cast<IntegerType>(op.getLhs().getType()).getWidth();
457 unsigned shift = value.getZExtValue();
460 rewriter.createOrFold<
ExtractOp>(op.getLoc(), op.getLhs(),
width - 1, 1);
461 auto sext = rewriter.createOrFold<ReplicateOp>(op.getLoc(), topbit, shift);
463 if (
width <= shift) {
468 auto extract = rewriter.
create<
ExtractOp>(op.getLoc(), op.getLhs(), shift,
471 replaceOpWithNewOpAndCopyName<ConcatOp>(rewriter, op, sext, extract);
479 OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
484 if (getInput().getType() == getType())
488 if (
auto input = dyn_cast_or_null<IntegerAttr>(adaptor.getInput())) {
489 unsigned dstWidth = cast<IntegerType>(getType()).getWidth();
490 return getIntAttr(input.getValue().lshr(getLowBit()).trunc(dstWidth),
501 PatternRewriter &rewriter) {
502 auto reversedConcatArgs = llvm::reverse(innerCat.getInputs());
503 size_t beginOfFirstRelevantElement = 0;
504 auto it = reversedConcatArgs.begin();
505 size_t lowBit = op.getLowBit();
508 for (; it != reversedConcatArgs.end(); it++) {
509 assert(beginOfFirstRelevantElement <= lowBit &&
510 "incorrectly moved past an element that lowBit has coverage over");
513 size_t operandWidth = operand.getType().getIntOrFloatBitWidth();
514 if (lowBit < beginOfFirstRelevantElement + operandWidth) {
538 beginOfFirstRelevantElement += operandWidth;
540 assert(it != reversedConcatArgs.end() &&
541 "incorrectly failed to find an element which contains coverage of "
544 SmallVector<Value> reverseConcatArgs;
545 size_t widthRemaining = cast<IntegerType>(op.getType()).getWidth();
546 size_t extractLo = lowBit - beginOfFirstRelevantElement;
551 for (; widthRemaining != 0 && it != reversedConcatArgs.end(); it++) {
552 auto concatArg = *it;
553 size_t operandWidth = concatArg.getType().getIntOrFloatBitWidth();
554 size_t widthToConsume = std::min(widthRemaining, operandWidth - extractLo);
556 if (widthToConsume == operandWidth && extractLo == 0) {
557 reverseConcatArgs.push_back(concatArg);
560 reverseConcatArgs.push_back(
561 rewriter.create<
ExtractOp>(op.getLoc(), resultType, *it, extractLo));
564 widthRemaining -= widthToConsume;
570 if (reverseConcatArgs.size() == 1) {
573 replaceOpWithNewOpAndCopyName<ConcatOp>(
574 rewriter, op, SmallVector<Value>(llvm::reverse(reverseConcatArgs)));
581 PatternRewriter &rewriter) {
582 auto extractResultWidth = cast<IntegerType>(op.getType()).getWidth();
583 auto replicateEltWidth =
584 replicate.getOperand().getType().getIntOrFloatBitWidth();
588 if (op.getLowBit() % replicateEltWidth == 0 &&
589 extractResultWidth % replicateEltWidth == 0) {
590 replaceOpWithNewOpAndCopyName<ReplicateOp>(rewriter, op, op.getType(),
591 replicate.getOperand());
597 if (op.getLowBit() % replicateEltWidth + extractResultWidth <=
599 replaceOpWithNewOpAndCopyName<ExtractOp>(
600 rewriter, op, op.getType(), replicate.getOperand(),
601 op.getLowBit() % replicateEltWidth);
610 LogicalResult ExtractOp::canonicalize(
ExtractOp op, PatternRewriter &rewriter) {
614 auto *inputOp = op.getInput().getDefiningOp();
621 .extractBits(cast<IntegerType>(op.getType()).getWidth(),
623 if (knownBits.isConstant()) {
624 replaceOpWithNewOpAndCopyName<hw::ConstantOp>(rewriter, op,
625 knownBits.getConstant());
631 if (
auto innerExtract = dyn_cast_or_null<ExtractOp>(inputOp)) {
632 replaceOpWithNewOpAndCopyName<ExtractOp>(
633 rewriter, op, op.getType(), innerExtract.getInput(),
634 innerExtract.getLowBit() + op.getLowBit());
639 if (
auto innerCat = dyn_cast_or_null<ConcatOp>(inputOp))
643 if (
auto replicate = dyn_cast_or_null<ReplicateOp>(inputOp))
649 if (inputOp && inputOp->getNumOperands() == 2 &&
650 isa<AndOp, OrOp, XorOp>(inputOp)) {
651 if (
auto cstRHS = inputOp->getOperand(1).getDefiningOp<
hw::ConstantOp>()) {
652 auto extractedCst = cstRHS.getValue().extractBits(
653 cast<IntegerType>(op.getType()).getWidth(), op.getLowBit());
654 if (isa<OrOp, XorOp>(inputOp) && extractedCst.isZero()) {
655 replaceOpWithNewOpAndCopyName<ExtractOp>(
656 rewriter, op, op.getType(), inputOp->getOperand(0), op.getLowBit());
664 if (isa<AndOp>(inputOp)) {
667 unsigned lz = extractedCst.countLeadingZeros();
668 unsigned tz = extractedCst.countTrailingZeros();
669 unsigned pop = extractedCst.popcount();
670 if (extractedCst.getBitWidth() - lz - tz == pop) {
671 auto resultTy = rewriter.getIntegerType(pop);
672 SmallVector<Value> resultElts;
675 op.getLoc(), APInt::getZero(lz)));
676 resultElts.push_back(rewriter.createOrFold<
ExtractOp>(
677 op.getLoc(), resultTy, inputOp->getOperand(0),
678 op.getLowBit() + tz));
681 op.getLoc(), APInt::getZero(tz)));
682 replaceOpWithNewOpAndCopyName<ConcatOp>(rewriter, op, resultElts);
691 if (cast<IntegerType>(op.getType()).getWidth() == 1 && inputOp)
692 if (
auto shlOp = dyn_cast<ShlOp>(inputOp))
693 if (
auto lhsCst = shlOp.getOperand(0).getDefiningOp<
hw::ConstantOp>())
694 if (lhsCst.getValue().isOne()) {
697 APInt(lhsCst.getValue().getBitWidth(), op.getLowBit()));
698 replaceOpWithNewOpAndCopyName<ICmpOp>(rewriter, op, ICmpPredicate::eq,
699 shlOp->getOperand(1), newCst,
714 hw::PEO paramOpcode) {
715 assert(operands.size() > 1 &&
"caller should handle one-operand case");
718 if (!operands[1] || !operands[0])
722 if (llvm::all_of(operands.drop_front(2),
723 [&](Attribute in) { return !!in; })) {
724 SmallVector<mlir::TypedAttr> typedOperands;
725 typedOperands.reserve(operands.size());
726 for (
auto operand : operands) {
727 if (
auto typedOperand = dyn_cast<mlir::TypedAttr>(operand))
728 typedOperands.push_back(typedOperand);
732 if (typedOperands.size() == operands.size())
749 size_t concatIdx,
const APInt &cst,
750 PatternRewriter &rewriter) {
751 auto concatOp = logicalOp->getOperand(concatIdx).getDefiningOp<
ConcatOp>();
752 assert((isa<AndOp, OrOp, XorOp>(logicalOp) && concatOp));
757 llvm::any_of(concatOp->getOperands(), [&](Value operand) ->
bool {
758 auto *operandOp = operand.getDefiningOp();
763 if (isa<hw::ConstantOp>(operandOp))
767 return operandOp->getName() == logicalOp->getName() &&
768 operandOp->hasOneUse() && operandOp->getNumOperands() != 0 &&
769 operandOp->getOperands().back().getDefiningOp<hw::ConstantOp>();
777 auto createLogicalOp = [&](ArrayRef<Value> operands) -> Value {
778 return createGenericOp(logicalOp->getLoc(), logicalOp->getName(), operands,
785 SmallVector<Value> newConcatOperands;
786 newConcatOperands.reserve(concatOp->getNumOperands());
789 size_t nextOperandBit = concatOp.getType().getIntOrFloatBitWidth();
790 for (Value operand : concatOp->getOperands()) {
791 size_t operandWidth = operand.getType().getIntOrFloatBitWidth();
792 nextOperandBit -= operandWidth;
795 logicalOp->getLoc(), cst.lshr(nextOperandBit).trunc(operandWidth));
797 newConcatOperands.push_back(createLogicalOp({operand, eltCst}));
806 if (logicalOp->getNumOperands() > 2) {
807 auto origOperands = logicalOp->getOperands();
808 SmallVector<Value> operands;
810 operands.append(origOperands.begin(), origOperands.begin() + concatIdx);
812 operands.append(origOperands.begin() + concatIdx + 1,
813 origOperands.begin() + (origOperands.size() - 1));
815 operands.push_back(newResult);
816 newResult = createLogicalOp(operands);
826 llvm::SmallDenseSet<std::tuple<ICmpPredicate, Value, Value>> seenPredicates;
828 for (
auto op : operands) {
829 if (
auto icmpOp = op.getDefiningOp<ICmpOp>();
830 icmpOp && icmpOp.getTwoState()) {
831 auto predicate = icmpOp.getPredicate();
832 auto lhs = icmpOp.getLhs();
833 auto rhs = icmpOp.getRhs();
834 if (seenPredicates.contains(
835 {ICmpOp::getNegatedPredicate(predicate), lhs, rhs}))
838 seenPredicates.insert({predicate, lhs, rhs});
844 OpFoldResult AndOp::fold(FoldAdaptor adaptor) {
848 APInt value = APInt::getAllOnes(cast<IntegerType>(getType()).
getWidth());
850 auto inputs = adaptor.getInputs();
853 for (
auto operand :
inputs) {
856 value &= cast<IntegerAttr>(operand).getValue();
863 cast<IntegerAttr>(
inputs[1]).getValue().isAllOnes())
864 return getInputs()[0];
867 if (llvm::all_of(getInputs(),
868 [&](
auto in) {
return in == this->getInputs()[0]; }))
869 return getInputs()[0];
872 for (Value arg : getInputs()) {
875 for (Value arg2 : getInputs())
878 APInt::getZero(cast<IntegerType>(getType()).
getWidth()),
899 template <
typename Op>
901 if (!op.getType().isInteger(1))
904 auto inputs = op.getInputs();
905 size_t size =
inputs.size();
907 auto sourceOp =
inputs[0].template getDefiningOp<ExtractOp>();
910 Value source = sourceOp.getOperand();
913 if (size != source.getType().getIntOrFloatBitWidth())
917 llvm::BitVector bits(size);
918 bits.set(sourceOp.getLowBit());
920 for (
size_t i = 1; i != size; ++i) {
921 auto extractOp =
inputs[i].template getDefiningOp<ExtractOp>();
922 if (!extractOp || extractOp.getOperand() != source)
924 bits.set(extractOp.getLowBit());
927 return bits.all() ? source : Value();
934 template <
typename Op>
936 auto inputs = op.getInputs();
938 llvm::SmallSetVector<Value, 8> uniqueInputs(
inputs.begin(),
inputs.end());
939 llvm::SmallDenseSet<Value, 8> checked;
942 llvm::SmallVector<Value, 8> worklist;
943 for (
auto input :
inputs) {
945 worklist.push_back(input);
948 while (!worklist.empty()) {
949 auto element = worklist.pop_back_val();
951 if (
auto idempotentOp = element.getDefiningOp<Op>()) {
952 for (
auto input : idempotentOp.getInputs()) {
953 uniqueInputs.remove(input);
955 if (checked.insert(input).second)
956 worklist.push_back(input);
961 if (uniqueInputs.size() <
inputs.size()) {
962 replaceOpWithNewOpAndCopyName<Op>(rewriter, op, op.getType(),
963 uniqueInputs.getArrayRef());
970 LogicalResult AndOp::canonicalize(
AndOp op, PatternRewriter &rewriter) {
974 auto inputs = op.getInputs();
975 auto size =
inputs.size();
976 assert(size > 1 &&
"expected 2 or more operands, `fold` should handle this");
990 if (matchPattern(
inputs.back(), m_ConstantInt(&value))) {
992 if (value.isAllOnes()) {
993 replaceOpWithNewOpAndCopyName<AndOp>(rewriter, op, op.getType(),
994 inputs.drop_back(),
false);
1002 if (matchPattern(
inputs[size - 2], m_ConstantInt(&value2))) {
1003 auto cst = rewriter.create<
hw::ConstantOp>(op.getLoc(), value & value2);
1004 SmallVector<Value, 4> newOperands(
inputs.drop_back(2));
1005 newOperands.push_back(cst);
1006 replaceOpWithNewOpAndCopyName<AndOp>(rewriter, op, op.getType(),
1007 newOperands,
false);
1012 if (size == 2 && value.isPowerOf2()) {
1017 if (
auto replicate =
inputs[0].getDefiningOp<ReplicateOp>()) {
1018 auto replicateOperand = replicate.getOperand();
1019 if (replicateOperand.getType().isInteger(1)) {
1020 unsigned resultWidth = op.getType().getIntOrFloatBitWidth();
1021 auto trailingZeros = value.countTrailingZeros();
1024 SmallVector<Value, 3> concatOperands;
1025 if (trailingZeros != resultWidth - 1) {
1027 op.getLoc(), APInt::getZero(resultWidth - trailingZeros - 1));
1028 concatOperands.push_back(highZeros);
1030 concatOperands.push_back(replicateOperand);
1031 if (trailingZeros != 0) {
1033 op.getLoc(), APInt::getZero(trailingZeros));
1034 concatOperands.push_back(lowZeros);
1036 replaceOpWithNewOpAndCopyName<ConcatOp>(rewriter, op, op.getType(),
1044 if (
auto extractOp =
inputs[0].getDefiningOp<ExtractOp>()) {
1047 (value.countLeadingZeros() || value.countTrailingZeros())) {
1048 unsigned lz = value.countLeadingZeros();
1049 unsigned tz = value.countTrailingZeros();
1052 auto smallTy = rewriter.getIntegerType(value.getBitWidth() - lz - tz);
1053 Value smallElt = rewriter.createOrFold<
ExtractOp>(
1054 extractOp.getLoc(), smallTy, extractOp->getOperand(0),
1055 extractOp.getLowBit() + tz);
1057 APInt smallMask = value.extractBits(smallTy.getWidth(), tz);
1058 if (!smallMask.isAllOnes()) {
1059 auto loc =
inputs.back().getLoc();
1060 smallElt = rewriter.createOrFold<
AndOp>(
1067 SmallVector<Value> resultElts;
1069 resultElts.push_back(
1070 rewriter.create<
hw::ConstantOp>(op.getLoc(), APInt::getZero(lz)));
1071 resultElts.push_back(smallElt);
1073 resultElts.push_back(
1074 rewriter.create<
hw::ConstantOp>(op.getLoc(), APInt::getZero(tz)));
1075 replaceOpWithNewOpAndCopyName<ConcatOp>(rewriter, op, resultElts);
1083 for (
size_t i = 0; i < size - 1; ++i) {
1097 rewriter.create<
hw::ConstantOp>(op.getLoc(), APInt::getAllOnes(size));
1098 replaceOpWithNewOpAndCopyName<ICmpOp>(rewriter, op, ICmpPredicate::eq,
1099 source, cmpAgainst);
1107 OpFoldResult OrOp::fold(FoldAdaptor adaptor) {
1111 auto value = APInt::getZero(cast<IntegerType>(getType()).
getWidth());
1112 auto inputs = adaptor.getInputs();
1114 for (
auto operand :
inputs) {
1117 value |= cast<IntegerAttr>(operand).getValue();
1118 if (value.isAllOnes())
1124 cast<IntegerAttr>(
inputs[1]).getValue().isZero())
1125 return getInputs()[0];
1128 if (llvm::all_of(getInputs(),
1129 [&](
auto in) {
return in == this->getInputs()[0]; }))
1130 return getInputs()[0];
1133 for (Value arg : getInputs()) {
1135 if (matchPattern(arg,
m_Complement(m_Any(&subExpr)))) {
1136 for (Value arg2 : getInputs())
1137 if (arg2 == subExpr)
1139 APInt::getAllOnes(cast<IntegerType>(getType()).
getWidth()),
1149 APInt::getAllOnes(cast<IntegerType>(getType()).
getWidth()),
1173 PatternRewriter &rewriter) {
1174 assert(concatIdx1 < concatIdx2 &&
"concatIdx1 must be < concatIdx2");
1176 auto inputs = op.getInputs();
1180 assert(concat1 && concat2 &&
"expected indexes to point to ConcatOps");
1183 bool hasConstantOp1 =
1184 llvm::any_of(concat1->getOperands(), [&](Value operand) ->
bool {
1185 return operand.getDefiningOp<hw::ConstantOp>();
1187 if (!hasConstantOp1) {
1188 bool hasConstantOp2 =
1189 llvm::any_of(concat2->getOperands(), [&](Value operand) ->
bool {
1190 return operand.getDefiningOp<hw::ConstantOp>();
1192 if (!hasConstantOp2)
1196 SmallVector<Value> newConcatOperands;
1201 auto operands1 = concat1->getOperands();
1202 auto operands2 = concat2->getOperands();
1204 unsigned consumedWidth1 = 0;
1205 unsigned consumedWidth2 = 0;
1206 for (
auto it1 = operands1.begin(), end1 = operands1.end(),
1207 it2 = operands2.begin(), end2 = operands2.end();
1208 it1 != end1 && it2 != end2;) {
1209 auto operand1 = *it1;
1210 auto operand2 = *it2;
1212 unsigned remainingWidth1 =
1214 unsigned remainingWidth2 =
1216 unsigned widthToConsume = std::min(remainingWidth1, remainingWidth2);
1217 auto narrowedType = rewriter.getIntegerType(widthToConsume);
1219 auto extract1 = rewriter.createOrFold<
ExtractOp>(
1220 op.getLoc(), narrowedType, operand1, remainingWidth1 - widthToConsume);
1221 auto extract2 = rewriter.createOrFold<
ExtractOp>(
1222 op.getLoc(), narrowedType, operand2, remainingWidth2 - widthToConsume);
1224 newConcatOperands.push_back(
1225 rewriter.createOrFold<
OrOp>(op.getLoc(), extract1, extract2,
false));
1227 consumedWidth1 += widthToConsume;
1228 consumedWidth2 += widthToConsume;
1230 if (widthToConsume == remainingWidth1) {
1234 if (widthToConsume == remainingWidth2) {
1240 ConcatOp newOp = rewriter.create<
ConcatOp>(op.getLoc(), newConcatOperands);
1244 SmallVector<Value> newOrOperands;
1245 newOrOperands.append(
inputs.begin(),
inputs.begin() + concatIdx1);
1246 newOrOperands.append(
inputs.begin() + concatIdx1 + 1,
1247 inputs.begin() + concatIdx2);
1248 newOrOperands.append(
inputs.begin() + concatIdx2 + 1,
1250 newOrOperands.push_back(newOp);
1252 replaceOpWithNewOpAndCopyName<OrOp>(rewriter, op, op.getType(),
1257 LogicalResult OrOp::canonicalize(
OrOp op, PatternRewriter &rewriter) {
1261 auto inputs = op.getInputs();
1262 auto size =
inputs.size();
1263 assert(size > 1 &&
"expected 2 or more operands");
1277 if (matchPattern(
inputs.back(), m_ConstantInt(&value))) {
1279 if (value.isZero()) {
1280 replaceOpWithNewOpAndCopyName<OrOp>(rewriter, op, op.getType(),
1287 if (matchPattern(
inputs[size - 2], m_ConstantInt(&value2))) {
1288 auto cst = rewriter.create<
hw::ConstantOp>(op.getLoc(), value | value2);
1289 SmallVector<Value, 4> newOperands(
inputs.drop_back(2));
1290 newOperands.push_back(cst);
1291 replaceOpWithNewOpAndCopyName<OrOp>(rewriter, op, op.getType(),
1299 for (
size_t i = 0; i < size - 1; ++i) {
1308 for (
size_t i = 0; i < size - 1; ++i) {
1310 for (
size_t j = i + 1; j < size; ++j)
1323 rewriter.create<
hw::ConstantOp>(op.getLoc(), APInt::getZero(size));
1324 replaceOpWithNewOpAndCopyName<ICmpOp>(rewriter, op, ICmpPredicate::ne,
1325 source, cmpAgainst);
1331 if (
auto firstMux = op.getOperand(0).getDefiningOp<
comb::MuxOp>()) {
1333 if (op.getTwoState() && firstMux.getTwoState() &&
1334 matchPattern(firstMux.getFalseValue(), m_ConstantInt(&value)) &&
1336 SmallVector<Value> conditions{firstMux.getCond()};
1337 auto check = [&](Value v) {
1341 conditions.push_back(mux.getCond());
1342 return mux.getTwoState() &&
1343 firstMux.getTrueValue() == mux.getTrueValue() &&
1344 firstMux.getFalseValue() == mux.getFalseValue();
1346 if (llvm::all_of(op.getOperands().drop_front(), check)) {
1347 auto cond = rewriter.create<
comb::OrOp>(op.getLoc(), conditions,
true);
1348 replaceOpWithNewOpAndCopyName<comb::MuxOp>(
1349 rewriter, op, cond, firstMux.getTrueValue(),
1350 firstMux.getFalseValue(),
true);
1360 OpFoldResult XorOp::fold(FoldAdaptor adaptor) {
1364 auto size = getInputs().size();
1365 auto inputs = adaptor.getInputs();
1369 return getInputs()[0];
1372 if (size == 2 && getInputs()[0] == getInputs()[1])
1377 cast<IntegerAttr>(
inputs[1]).getValue().isZero())
1378 return getInputs()[0];
1382 if (isBinaryNot()) {
1384 if (matchPattern(getOperand(0),
m_Complement(m_Any(&subExpr))) &&
1385 subExpr != getResult())
1395 PatternRewriter &rewriter) {
1396 auto icmp = op.getOperand(icmpOperand).getDefiningOp<ICmpOp>();
1397 auto negatedPred = ICmpOp::getNegatedPredicate(icmp.getPredicate());
1400 rewriter.create<ICmpOp>(icmp.getLoc(), negatedPred, icmp.getOperand(0),
1401 icmp.getOperand(1), icmp.getTwoState());
1404 if (op.getNumOperands() > 2) {
1405 SmallVector<Value, 4> newOperands(op.getOperands());
1406 newOperands.pop_back();
1407 newOperands.erase(newOperands.begin() + icmpOperand);
1408 newOperands.push_back(result);
1409 result = rewriter.create<
XorOp>(op.getLoc(), newOperands, op.getTwoState());
1415 LogicalResult XorOp::canonicalize(
XorOp op, PatternRewriter &rewriter) {
1419 auto inputs = op.getInputs();
1420 auto size =
inputs.size();
1421 assert(size > 1 &&
"expected 2 or more operands");
1426 "expected idempotent case for 2 elements handled already.");
1427 replaceOpWithNewOpAndCopyName<XorOp>(rewriter, op, op.getType(),
1428 inputs.drop_back(2),
false);
1434 if (matchPattern(
inputs.back(), m_ConstantInt(&value))) {
1436 if (value.isZero()) {
1437 replaceOpWithNewOpAndCopyName<XorOp>(rewriter, op, op.getType(),
1438 inputs.drop_back(),
false);
1444 if (matchPattern(
inputs[size - 2], m_ConstantInt(&value2))) {
1445 auto cst = rewriter.create<
hw::ConstantOp>(op.getLoc(), value ^ value2);
1446 SmallVector<Value, 4> newOperands(
inputs.drop_back(2));
1447 newOperands.push_back(cst);
1448 replaceOpWithNewOpAndCopyName<XorOp>(rewriter, op, op.getType(),
1449 newOperands,
false);
1453 bool isSingleBit = value.getBitWidth() == 1;
1456 for (
size_t i = 0; i < size - 1; ++i) {
1457 Value operand =
inputs[i];
1468 if (isSingleBit && operand.hasOneUse()) {
1469 assert(value == 1 &&
"single bit constant has to be one if not zero");
1470 if (
auto icmp = operand.getDefiningOp<ICmpOp>())
1486 replaceOpWithNewOpAndCopyName<ParityOp>(rewriter, op, source);
1493 OpFoldResult SubOp::fold(FoldAdaptor adaptor) {
1498 if (getRhs() == getLhs())
1500 APInt::getZero(getLhs().getType().getIntOrFloatBitWidth()),
1503 if (adaptor.getRhs()) {
1505 if (adaptor.getLhs()) {
1508 APInt::getAllOnes(getLhs().getType().getIntOrFloatBitWidth()),
1511 hw::PEO::Mul, cast<TypedAttr>(adaptor.getRhs()), negOne);
1513 cast<TypedAttr>(adaptor.getLhs()), rhsNeg);
1517 if (
auto rhsC = dyn_cast<IntegerAttr>(adaptor.getRhs())) {
1518 if (rhsC.getValue().isZero())
1526 LogicalResult SubOp::canonicalize(
SubOp op, PatternRewriter &rewriter) {
1532 if (matchPattern(op.getRhs(), m_ConstantInt(&value))) {
1533 auto negCst = rewriter.create<
hw::ConstantOp>(op.getLoc(), -value);
1534 replaceOpWithNewOpAndCopyName<AddOp>(rewriter, op, op.getLhs(), negCst,
1546 OpFoldResult AddOp::fold(FoldAdaptor adaptor) {
1550 auto size = getInputs().size();
1554 return getInputs()[0];
1560 LogicalResult AddOp::canonicalize(
AddOp op, PatternRewriter &rewriter) {
1564 auto inputs = op.getInputs();
1565 auto size =
inputs.size();
1566 assert(size > 1 &&
"expected 2 or more operands");
1568 APInt value, value2;
1571 if (matchPattern(
inputs.back(), m_ConstantInt(&value)) && value.isZero()) {
1572 replaceOpWithNewOpAndCopyName<AddOp>(rewriter, op, op.getType(),
1573 inputs.drop_back(),
false);
1578 if (matchPattern(
inputs[size - 1], m_ConstantInt(&value)) &&
1579 matchPattern(
inputs[size - 2], m_ConstantInt(&value2))) {
1580 auto cst = rewriter.create<
hw::ConstantOp>(op.getLoc(), value + value2);
1581 SmallVector<Value, 4> newOperands(
inputs.drop_back(2));
1582 newOperands.push_back(cst);
1583 replaceOpWithNewOpAndCopyName<AddOp>(rewriter, op, op.getType(),
1584 newOperands,
false);
1590 SmallVector<Value, 4> newOperands(
inputs.drop_back(2));
1592 auto one = rewriter.create<
hw::ConstantOp>(op.getLoc(), op.getType(), 1);
1596 newOperands.push_back(shiftLeftOp);
1597 replaceOpWithNewOpAndCopyName<AddOp>(rewriter, op, op.getType(),
1598 newOperands,
false);
1604 if (shlOp && shlOp.getLhs() ==
inputs[size - 2] &&
1605 matchPattern(shlOp.getRhs(), m_ConstantInt(&value))) {
1607 APInt one(value.getBitWidth(), 1,
false);
1609 rewriter.create<
hw::ConstantOp>(op.getLoc(), (one << value) + one);
1611 std::array<Value, 2> factors = {shlOp.getLhs(), rhs};
1612 auto mulOp = rewriter.create<
comb::MulOp>(op.getLoc(), factors,
false);
1614 SmallVector<Value, 4> newOperands(
inputs.drop_back(2));
1615 newOperands.push_back(mulOp);
1616 replaceOpWithNewOpAndCopyName<AddOp>(rewriter, op, op.getType(),
1617 newOperands,
false);
1623 if (mulOp && mulOp.getInputs().size() == 2 &&
1624 mulOp.getInputs()[0] ==
inputs[size - 2] &&
1625 matchPattern(mulOp.getInputs()[1], m_ConstantInt(&value))) {
1627 APInt one(value.getBitWidth(), 1,
false);
1628 auto rhs = rewriter.create<
hw::ConstantOp>(op.getLoc(), value + one);
1629 std::array<Value, 2> factors = {mulOp.getInputs()[0], rhs};
1632 SmallVector<Value, 4> newOperands(
inputs.drop_back(2));
1633 newOperands.push_back(newMulOp);
1634 replaceOpWithNewOpAndCopyName<AddOp>(rewriter, op, op.getType(),
1635 newOperands,
false);
1649 if (addOp && addOp.getInputs().size() == 2 &&
1650 matchPattern(addOp.getInputs()[1], m_ConstantInt(&value2)) &&
1651 inputs.size() == 2 && matchPattern(
inputs[1], m_ConstantInt(&value))) {
1653 auto rhs = rewriter.create<
hw::ConstantOp>(op.getLoc(), value + value2);
1654 replaceOpWithNewOpAndCopyName<AddOp>(
1655 rewriter, op, op.getType(), ArrayRef<Value>{addOp.getInputs()[0], rhs},
1656 op.getTwoState() && addOp.getTwoState());
1663 OpFoldResult MulOp::fold(FoldAdaptor adaptor) {
1667 auto size = getInputs().size();
1668 auto inputs = adaptor.getInputs();
1672 return getInputs()[0];
1674 auto width = cast<IntegerType>(getType()).getWidth();
1675 APInt value(
width, 1,
false);
1678 for (
auto operand :
inputs) {
1681 value *= cast<IntegerAttr>(operand).getValue();
1690 LogicalResult MulOp::canonicalize(
MulOp op, PatternRewriter &rewriter) {
1694 auto inputs = op.getInputs();
1695 auto size =
inputs.size();
1696 assert(size > 1 &&
"expected 2 or more operands");
1698 APInt value, value2;
1701 if (size == 2 && matchPattern(
inputs.back(), m_ConstantInt(&value)) &&
1702 value.isPowerOf2()) {
1703 auto shift = rewriter.create<
hw::ConstantOp>(op.getLoc(), op.getType(),
1704 value.exactLogBase2());
1708 replaceOpWithNewOpAndCopyName<MulOp>(rewriter, op, op.getType(),
1709 ArrayRef<Value>(shlOp),
false);
1714 if (matchPattern(
inputs.back(), m_ConstantInt(&value)) && value.isOne()) {
1715 replaceOpWithNewOpAndCopyName<MulOp>(rewriter, op, op.getType(),
1721 if (matchPattern(
inputs[size - 1], m_ConstantInt(&value)) &&
1722 matchPattern(
inputs[size - 2], m_ConstantInt(&value2))) {
1723 auto cst = rewriter.create<
hw::ConstantOp>(op.getLoc(), value * value2);
1724 SmallVector<Value, 4> newOperands(
inputs.drop_back(2));
1725 newOperands.push_back(cst);
1726 replaceOpWithNewOpAndCopyName<MulOp>(rewriter, op, op.getType(),
1742 template <
class Op,
bool isSigned>
1743 static OpFoldResult
foldDiv(Op op, ArrayRef<Attribute> constants) {
1744 if (
auto rhsValue = dyn_cast_or_null<IntegerAttr>(constants[1])) {
1746 if (rhsValue.getValue() == 1)
1750 if (rhsValue.getValue().isZero())
1757 OpFoldResult DivUOp::fold(FoldAdaptor adaptor) {
1761 return foldDiv<
DivUOp,
false>(*
this, adaptor.getOperands());
1764 OpFoldResult DivSOp::fold(FoldAdaptor adaptor) {
1771 template <
class Op,
bool isSigned>
1772 static OpFoldResult
foldMod(Op op, ArrayRef<Attribute> constants) {
1773 if (
auto rhsValue = dyn_cast_or_null<IntegerAttr>(constants[1])) {
1775 if (rhsValue.getValue() == 1)
1776 return getIntAttr(APInt::getZero(op.getType().getIntOrFloatBitWidth()),
1780 if (rhsValue.getValue().isZero())
1784 if (
auto lhsValue = dyn_cast_or_null<IntegerAttr>(constants[0])) {
1786 if (lhsValue.getValue().isZero())
1787 return getIntAttr(APInt::getZero(op.getType().getIntOrFloatBitWidth()),
1794 OpFoldResult ModUOp::fold(FoldAdaptor adaptor) {
1798 return foldMod<
ModUOp,
false>(*
this, adaptor.getOperands());
1801 OpFoldResult ModSOp::fold(FoldAdaptor adaptor) {
1812 OpFoldResult ConcatOp::fold(FoldAdaptor adaptor) {
1816 if (getNumOperands() == 1)
1817 return getOperand(0);
1820 for (
auto attr : adaptor.getInputs())
1821 if (!attr || !isa<IntegerAttr>(attr))
1825 unsigned resultWidth = getType().getIntOrFloatBitWidth();
1826 APInt result(resultWidth, 0);
1828 unsigned nextInsertion = resultWidth;
1830 for (
auto attr : adaptor.getInputs()) {
1831 auto chunk = cast<IntegerAttr>(attr).getValue();
1832 nextInsertion -= chunk.getBitWidth();
1833 result.insertBits(chunk, nextInsertion);
1839 LogicalResult ConcatOp::canonicalize(
ConcatOp op, PatternRewriter &rewriter) {
1843 auto inputs = op.getInputs();
1844 auto size =
inputs.size();
1845 assert(size > 1 &&
"expected 2 or more operands");
1850 auto flattenConcat = [&](
size_t firstOpIndex,
size_t lastOpIndex,
1851 ValueRange replacements) -> LogicalResult {
1852 SmallVector<Value, 4> newOperands;
1853 newOperands.append(
inputs.begin(),
inputs.begin() + firstOpIndex);
1854 newOperands.append(replacements.begin(), replacements.end());
1855 newOperands.append(
inputs.begin() + lastOpIndex + 1,
inputs.end());
1856 if (newOperands.size() == 1)
1859 replaceOpWithNewOpAndCopyName<ConcatOp>(rewriter, op, op.getType(),
1864 Value commonOperand =
inputs[0];
1865 for (
size_t i = 0; i != size; ++i) {
1867 if (
inputs[i] != commonOperand)
1868 commonOperand = Value();
1872 if (
auto subConcat =
inputs[i].getDefiningOp<ConcatOp>())
1873 return flattenConcat(i, i, subConcat->getOperands());
1878 if (
auto cst =
inputs[i].getDefiningOp<hw::ConstantOp>()) {
1879 if (
auto prevCst =
inputs[i - 1].getDefiningOp<hw::ConstantOp>()) {
1880 unsigned prevWidth = prevCst.getValue().getBitWidth();
1881 unsigned thisWidth = cst.getValue().getBitWidth();
1882 auto resultCst = cst.getValue().zext(prevWidth + thisWidth);
1883 resultCst |= prevCst.getValue().zext(prevWidth + thisWidth)
1887 return flattenConcat(i - 1, i, replacement);
1894 rewriter.createOrFold<ReplicateOp>(op.getLoc(),
inputs[i], 2);
1895 return flattenConcat(i - 1, i, replacement);
1900 if (
auto repl =
inputs[i].getDefiningOp<ReplicateOp>()) {
1902 if (repl.getOperand() ==
inputs[i - 1]) {
1903 Value replacement = rewriter.createOrFold<ReplicateOp>(
1904 op.getLoc(), repl.getOperand(), repl.getMultiple() + 1);
1905 return flattenConcat(i - 1, i, replacement);
1908 if (
auto prevRepl =
inputs[i - 1].getDefiningOp<ReplicateOp>()) {
1909 if (prevRepl.getOperand() == repl.getOperand()) {
1910 Value replacement = rewriter.createOrFold<ReplicateOp>(
1911 op.getLoc(), repl.getOperand(),
1912 repl.getMultiple() + prevRepl.getMultiple());
1913 return flattenConcat(i - 1, i, replacement);
1919 if (
auto repl =
inputs[i - 1].getDefiningOp<ReplicateOp>()) {
1920 if (repl.getOperand() ==
inputs[i]) {
1921 Value replacement = rewriter.createOrFold<ReplicateOp>(
1922 op.getLoc(),
inputs[i], repl.getMultiple() + 1);
1923 return flattenConcat(i - 1, i, replacement);
1929 if (
auto extract =
inputs[i].getDefiningOp<ExtractOp>()) {
1930 if (
auto prevExtract =
inputs[i - 1].getDefiningOp<ExtractOp>()) {
1931 if (extract.getInput() == prevExtract.getInput()) {
1932 auto thisWidth = cast<IntegerType>(extract.getType()).getWidth();
1933 if (prevExtract.getLowBit() == extract.getLowBit() + thisWidth) {
1934 auto prevWidth = prevExtract.getType().getIntOrFloatBitWidth();
1935 auto resType = rewriter.getIntegerType(thisWidth + prevWidth);
1936 Value replacement = rewriter.create<
ExtractOp>(
1937 op.getLoc(), resType, extract.getInput(),
1938 extract.getLowBit());
1939 return flattenConcat(i - 1, i, replacement);
1952 static std::optional<ArraySlice>
get(Value value) {
1953 assert(isa<IntegerType>(value.getType()) &&
"expected integer type");
1955 return ArraySlice{arrayGet.getInput(), arrayGet.getIndex(), 1};
1958 if (
auto arraySlice =
1961 arraySlice.getInput(), arraySlice.getLowIndex(),
1962 hw::type_cast<hw::ArrayType>(arraySlice.getType())
1964 return std::nullopt;
1970 if (prevExtractOpt->index.getType() == extractOpt->index.getType() &&
1971 prevExtractOpt->input == extractOpt->input &&
1973 extractOpt->width)) {
1975 hw::type_cast<hw::ArrayType>(prevExtractOpt->input.getType())
1977 extractOpt->width + prevExtractOpt->width);
1980 op.getLoc(), resIntType,
1982 prevExtractOpt->input,
1983 extractOpt->index));
1984 return flattenConcat(i - 1, i, replacement);
1992 if (commonOperand) {
1993 replaceOpWithNewOpAndCopyName<ReplicateOp>(rewriter, op, op.getType(),
2005 OpFoldResult MuxOp::fold(FoldAdaptor adaptor) {
2010 if (getTrueValue() == getFalseValue())
2011 return getTrueValue();
2012 if (
auto tv = adaptor.getTrueValue())
2013 if (tv == adaptor.getFalseValue())
2018 if (
auto pred = dyn_cast_or_null<IntegerAttr>(adaptor.getCond())) {
2019 if (pred.getValue().isZero())
2020 return getFalseValue();
2021 return getTrueValue();
2025 if (
auto tv = dyn_cast_or_null<IntegerAttr>(adaptor.getTrueValue()))
2026 if (
auto fv = dyn_cast_or_null<IntegerAttr>(adaptor.getFalseValue()))
2027 if (tv.getValue().isOne() && fv.getValue().isZero() &&
2044 if (
auto cmp = cond.getDefiningOp<ICmpOp>()) {
2046 auto requiredPredicate =
2047 (isInverted ? ICmpPredicate::eq : ICmpPredicate::ne);
2048 if (cmp.getLhs() == indexValue && cmp.getPredicate() == requiredPredicate) {
2058 if (
auto orOp = cond.getDefiningOp<
OrOp>()) {
2061 for (
auto operand : orOp.getOperands())
2068 if (
auto andOp = cond.getDefiningOp<
AndOp>()) {
2071 for (
auto operand : andOp.getOperands())
2089 PatternRewriter &rewriter) {
2092 auto rootCmp = rootMux.getCond().getDefiningOp<ICmpOp>();
2095 Value indexValue = rootCmp.getLhs();
2098 auto getCaseValue = [&](
MuxOp mux) -> Value {
2099 return mux.getOperand(1 +
unsigned(!isFalseSide));
2104 auto getTreeValue = [&](
MuxOp mux) -> Value {
2105 return mux.getOperand(1 +
unsigned(isFalseSide));
2110 SmallVector<Location> locationsFound;
2111 SmallVector<std::pair<hw::ConstantOp, Value>, 4> valuesFound;
2115 auto collectConstantValues = [&](
MuxOp mux) ->
bool {
2117 mux.getCond(), indexValue, isFalseSide, [&](
hw::ConstantOp cst) {
2118 valuesFound.push_back({cst, getCaseValue(mux)});
2119 locationsFound.push_back(mux.getCond().getLoc());
2120 locationsFound.push_back(mux->getLoc());
2125 if (!collectConstantValues(rootMux))
2129 if (rootMux->hasOneUse()) {
2130 if (
auto userMux = dyn_cast<MuxOp>(*rootMux->user_begin())) {
2131 if (getTreeValue(userMux) == rootMux.getResult() &&
2139 auto nextTreeValue = getTreeValue(rootMux);
2141 auto nextMux = nextTreeValue.getDefiningOp<
MuxOp>();
2142 if (!nextMux || !nextMux->hasOneUse())
2144 if (!collectConstantValues(nextMux))
2146 nextTreeValue = getTreeValue(nextMux);
2152 if (valuesFound.size() < 3)
2157 auto indexWidth = cast<IntegerType>(indexValue.getType()).getWidth();
2158 if (indexWidth >= 9)
2164 uint64_t tableSize = 1ULL << indexWidth;
2165 if (valuesFound.size() < (tableSize * 5) / 8)
2170 SmallVector<Value, 8> table(tableSize, nextTreeValue);
2175 for (
auto &elt : llvm::reverse(valuesFound)) {
2176 uint64_t idx = elt.first.getValue().getZExtValue();
2177 assert(idx < table.size() &&
"constant should be same bitwidth as index");
2178 table[idx] = elt.second;
2183 std::reverse(table.begin(), table.end());
2186 auto fusedLoc = rewriter.getFusedLoc(locationsFound);
2188 replaceOpWithNewOpAndCopyName<hw::ArrayGetOp>(rewriter, rootMux, array,
2203 PatternRewriter &rewriter) {
2204 assert(fullyAssoc->getNumOperands() >= 2 &&
"cannot split up unary ops");
2205 assert(operandNo < fullyAssoc->getNumOperands() &&
"Invalid operand #");
2209 if (fullyAssoc->getNumOperands() == 2)
2210 return fullyAssoc->getOperand(operandNo ^ 1);
2213 if (fullyAssoc->hasOneUse()) {
2214 fullyAssoc->eraseOperand(operandNo);
2215 return fullyAssoc->getResult(0);
2219 SmallVector<Value> operands;
2220 operands.append(fullyAssoc->getOperands().begin(),
2221 fullyAssoc->getOperands().begin() + operandNo);
2222 operands.append(fullyAssoc->getOperands().begin() + operandNo + 1,
2223 fullyAssoc->getOperands().end());
2225 fullyAssoc->getLoc(), fullyAssoc->getName(), operands, rewriter);
2226 Value excluded = fullyAssoc->getOperand(operandNo);
2230 ArrayRef<Value>{opWithoutExcluded, excluded}, rewriter);
2232 return opWithoutExcluded;
2242 PatternRewriter &rewriter) {
2245 Operation *subExpr =
2246 (isTrueOperand ? op.getFalseValue() : op.getTrueValue()).getDefiningOp();
2247 if (!subExpr || subExpr->getNumOperands() < 2)
2251 if (!isa<AndOp, XorOp, OrOp, MuxOp>(subExpr))
2256 Value commonValue = isTrueOperand ? op.getTrueValue() : op.getFalseValue();
2257 size_t opNo = 0, e = subExpr->getNumOperands();
2258 while (opNo != e && subExpr->getOperand(opNo) != commonValue)
2264 Value cond = op.getCond();
2270 if (
auto subMux = dyn_cast<MuxOp>(subExpr)) {
2272 Value subCond = subMux.getCond();
2275 if (subMux.getTrueValue() == commonValue)
2276 otherValue = subMux.getFalseValue();
2277 else if (subMux.getFalseValue() == commonValue) {
2278 otherValue = subMux.getTrueValue();
2288 cond = rewriter.createOrFold<
OrOp>(op.getLoc(), cond, subCond,
false);
2289 replaceOpWithNewOpAndCopyName<MuxOp>(rewriter, op, cond, commonValue,
2290 otherValue, op.getTwoState());
2296 bool isaAndOp = isa<AndOp>(subExpr);
2297 if (isTrueOperand ^ isaAndOp)
2301 rewriter.createOrFold<ReplicateOp>(op.getLoc(), op.getType(), cond);
2304 bool isaXorOp = isa<XorOp>(subExpr);
2305 bool isaOrOp = isa<OrOp>(subExpr);
2314 if (isaOrOp || isaXorOp) {
2315 auto masked = rewriter.createOrFold<
AndOp>(op.getLoc(), extendedCond,
2316 restOfAssoc,
false);
2318 replaceOpWithNewOpAndCopyName<XorOp>(rewriter, op, masked, commonValue,
2321 replaceOpWithNewOpAndCopyName<OrOp>(rewriter, op, masked, commonValue,
2327 assert(isaAndOp &&
"unexpected operation here");
2328 auto masked = rewriter.createOrFold<
OrOp>(op.getLoc(), extendedCond,
2329 restOfAssoc,
false);
2330 replaceOpWithNewOpAndCopyName<AndOp>(rewriter, op, masked, commonValue,
2341 PatternRewriter &rewriter) {
2344 if (!isa<ConcatOp>(trueOp))
2348 SmallVector<Value> trueOperands, falseOperands;
2352 size_t numTrueOperands = trueOperands.size();
2353 size_t numFalseOperands = falseOperands.size();
2355 if (!numTrueOperands || !numFalseOperands ||
2356 (trueOperands.front() != falseOperands.front() &&
2357 trueOperands.back() != falseOperands.back()))
2361 if (trueOperands.front() == falseOperands.front()) {
2362 SmallVector<Value> operands;
2364 for (i = 0; i < numTrueOperands; ++i) {
2365 Value trueOperand = trueOperands[i];
2366 if (trueOperand == falseOperands[i])
2367 operands.push_back(trueOperand);
2371 if (i == numTrueOperands) {
2378 if (llvm::all_of(operands, [&](Value v) {
return v == operands.front(); }))
2379 sharedMSB = rewriter.createOrFold<ReplicateOp>(
2380 mux->getLoc(), operands.front(), operands.size());
2382 sharedMSB = rewriter.createOrFold<
ConcatOp>(mux->getLoc(), operands);
2386 operands.append(trueOperands.begin() + i, trueOperands.end());
2387 Value trueLSB = rewriter.createOrFold<
ConcatOp>(trueOp->getLoc(), operands);
2389 operands.append(falseOperands.begin() + i, falseOperands.end());
2391 rewriter.createOrFold<
ConcatOp>(falseOp->getLoc(), operands);
2394 Value lsb = rewriter.createOrFold<
MuxOp>(
2395 mux->getLoc(), mux.getCond(), trueLSB, falseLSB, mux.getTwoState());
2396 replaceOpWithNewOpAndCopyName<ConcatOp>(rewriter, mux, sharedMSB, lsb);
2401 if (trueOperands.back() == falseOperands.back()) {
2402 SmallVector<Value> operands;
2405 Value trueOperand = trueOperands[numTrueOperands - i - 1];
2406 if (trueOperand == falseOperands[numFalseOperands - i - 1])
2407 operands.push_back(trueOperand);
2411 std::reverse(operands.begin(), operands.end());
2412 Value sharedLSB = rewriter.createOrFold<
ConcatOp>(mux->getLoc(), operands);
2416 operands.append(trueOperands.begin(), trueOperands.end() - i);
2417 Value trueMSB = rewriter.createOrFold<
ConcatOp>(trueOp->getLoc(), operands);
2419 operands.append(falseOperands.begin(), falseOperands.end() - i);
2421 rewriter.createOrFold<
ConcatOp>(falseOp->getLoc(), operands);
2423 Value msb = rewriter.createOrFold<
MuxOp>(
2424 mux->getLoc(), mux.getCond(), trueMSB, falseMSB, mux.getTwoState());
2425 replaceOpWithNewOpAndCopyName<ConcatOp>(rewriter, mux, msb, sharedLSB);
2437 if (!trueVec || !falseVec)
2439 if (!trueVec.isUniform() || !falseVec.isUniform())
2443 op.getLoc(), op.getCond(), trueVec.getUniformElement(),
2444 falseVec.getUniformElement(), op.getTwoState());
2446 SmallVector<Value> values(trueVec.getInputs().size(), mux);
2453 using OpRewritePattern::OpRewritePattern;
2455 LogicalResult matchAndRewrite(
MuxOp op,
2456 PatternRewriter &rewriter)
const override;
2459 LogicalResult MuxRewriter::matchAndRewrite(
MuxOp op,
2460 PatternRewriter &rewriter)
const {
2469 if (matchPattern(op.getTrueValue(), m_ConstantInt(&value))) {
2470 if (value.getBitWidth() == 1) {
2472 if (value.isZero()) {
2474 replaceOpWithNewOpAndCopyName<AndOp>(rewriter, op, notCond,
2475 op.getFalseValue(),
false);
2480 replaceOpWithNewOpAndCopyName<OrOp>(rewriter, op, op.getCond(),
2481 op.getFalseValue(),
false);
2487 if (matchPattern(op.getFalseValue(), m_ConstantInt(&value2))) {
2492 APInt xorValue = value ^ value2;
2493 if (xorValue.isPowerOf2()) {
2494 unsigned leadingZeros = xorValue.countLeadingZeros();
2495 unsigned trailingZeros = value.getBitWidth() - leadingZeros - 1;
2496 SmallVector<Value, 3> operands;
2504 if (leadingZeros > 0)
2505 operands.push_back(rewriter.createOrFold<
ExtractOp>(
2506 op.getLoc(), op.getTrueValue(), trailingZeros + 1, leadingZeros));
2510 auto v1 = rewriter.createOrFold<
ExtractOp>(
2511 op.getLoc(), op.getTrueValue(), trailingZeros, 1);
2512 auto v2 = rewriter.createOrFold<
ExtractOp>(
2513 op.getLoc(), op.getFalseValue(), trailingZeros, 1);
2514 operands.push_back(rewriter.createOrFold<
MuxOp>(
2515 op.getLoc(), op.getCond(), v1, v2,
false));
2517 if (trailingZeros > 0)
2518 operands.push_back(rewriter.createOrFold<
ExtractOp>(
2519 op.getLoc(), op.getTrueValue(), 0, trailingZeros));
2521 replaceOpWithNewOpAndCopyName<ConcatOp>(rewriter, op, op.getType(),
2528 if (value.isAllOnes() && value2.isZero()) {
2529 replaceOpWithNewOpAndCopyName<ReplicateOp>(rewriter, op, op.getType(),
2536 if (matchPattern(op.getFalseValue(), m_ConstantInt(&value)) &&
2537 value.getBitWidth() == 1) {
2539 if (value.isZero()) {
2540 replaceOpWithNewOpAndCopyName<AndOp>(rewriter, op, op.getCond(),
2541 op.getTrueValue(),
false);
2548 auto notCond = rewriter.createOrFold<
XorOp>(op.getLoc(), op.getCond(),
2549 op.getFalseValue(),
false);
2550 replaceOpWithNewOpAndCopyName<OrOp>(rewriter, op, notCond,
2551 op.getTrueValue(),
false);
2557 Operation *condOp = op.getCond().getDefiningOp();
2558 if (condOp && matchPattern(condOp,
m_Complement(m_Any(&subExpr))) &&
2560 replaceOpWithNewOpAndCopyName<MuxOp>(rewriter, op, op.getType(), subExpr,
2561 op.getFalseValue(), op.getTrueValue(),
2569 if (condOp && condOp->hasOneUse()) {
2570 SmallVector<Value> invertedOperands;
2574 auto getInvertedOperands = [&]() ->
bool {
2575 for (Value operand : condOp->getOperands()) {
2576 if (matchPattern(operand,
m_Complement(m_Any(&subExpr))))
2577 invertedOperands.push_back(subExpr);
2584 if (isa<AndOp>(condOp) && getInvertedOperands()) {
2586 rewriter.createOrFold<
OrOp>(op.getLoc(), invertedOperands,
false);
2587 replaceOpWithNewOpAndCopyName<MuxOp>(rewriter, op, newOr,
2589 op.getTrueValue(), op.getTwoState());
2592 if (isa<OrOp>(condOp) && getInvertedOperands()) {
2594 rewriter.createOrFold<
AndOp>(op.getLoc(), invertedOperands,
false);
2595 replaceOpWithNewOpAndCopyName<MuxOp>(rewriter, op, newAnd,
2597 op.getTrueValue(), op.getTwoState());
2603 dyn_cast_or_null<MuxOp>(op.getFalseValue().getDefiningOp())) {
2605 if (op.getCond() == falseMux.getCond()) {
2606 replaceOpWithNewOpAndCopyName<MuxOp>(
2607 rewriter, op, op.getCond(), op.getTrueValue(),
2608 falseMux.getFalseValue(), op.getTwoStateAttr());
2618 dyn_cast_or_null<MuxOp>(op.getTrueValue().getDefiningOp())) {
2620 if (op.getCond() == trueMux.getCond()) {
2621 replaceOpWithNewOpAndCopyName<MuxOp>(
2622 rewriter, op, op.getCond(), trueMux.getTrueValue(),
2623 op.getFalseValue(), op.getTwoStateAttr());
2633 if (
auto trueMux = dyn_cast_or_null<MuxOp>(op.getTrueValue().getDefiningOp()),
2634 falseMux = dyn_cast_or_null<MuxOp>(op.getFalseValue().getDefiningOp());
2635 trueMux && falseMux && trueMux.getCond() == falseMux.getCond() &&
2636 trueMux.getTrueValue() == falseMux.getTrueValue()) {
2637 auto subMux = rewriter.create<
MuxOp>(
2638 rewriter.getFusedLoc({trueMux.getLoc(), falseMux.getLoc()}),
2639 op.getCond(), trueMux.getFalseValue(), falseMux.getFalseValue());
2640 replaceOpWithNewOpAndCopyName<MuxOp>(rewriter, op, trueMux.getCond(),
2641 trueMux.getTrueValue(), subMux,
2642 op.getTwoStateAttr());
2647 if (
auto trueMux = dyn_cast_or_null<MuxOp>(op.getTrueValue().getDefiningOp()),
2648 falseMux = dyn_cast_or_null<MuxOp>(op.getFalseValue().getDefiningOp());
2649 trueMux && falseMux && trueMux.getCond() == falseMux.getCond() &&
2650 trueMux.getFalseValue() == falseMux.getFalseValue()) {
2651 auto subMux = rewriter.create<
MuxOp>(
2652 rewriter.getFusedLoc({trueMux.getLoc(), falseMux.getLoc()}),
2653 op.getCond(), trueMux.getTrueValue(), falseMux.getTrueValue());
2654 replaceOpWithNewOpAndCopyName<MuxOp>(rewriter, op, trueMux.getCond(),
2655 subMux, trueMux.getFalseValue(),
2656 op.getTwoStateAttr());
2661 if (
auto trueMux = dyn_cast_or_null<MuxOp>(op.getTrueValue().getDefiningOp()),
2662 falseMux = dyn_cast_or_null<MuxOp>(op.getFalseValue().getDefiningOp());
2663 trueMux && falseMux &&
2664 trueMux.getTrueValue() == falseMux.getTrueValue() &&
2665 trueMux.getFalseValue() == falseMux.getFalseValue()) {
2666 auto subMux = rewriter.create<
MuxOp>(
2667 rewriter.getFusedLoc(
2668 {op.getLoc(), trueMux.getLoc(), falseMux.getLoc()}),
2669 op.getCond(), trueMux.getCond(), falseMux.getCond());
2670 replaceOpWithNewOpAndCopyName<MuxOp>(
2671 rewriter, op, subMux, trueMux.getTrueValue(), trueMux.getFalseValue(),
2672 op.getTwoStateAttr());
2684 if (Operation *trueOp = op.getTrueValue().getDefiningOp())
2685 if (Operation *falseOp = op.getFalseValue().getDefiningOp())
2686 if (trueOp->getName() == falseOp->getName())
2703 if (op.getInputs().empty() || op.isUniform())
2705 auto inputs = op.getInputs();
2716 for (
size_t i = 1, n =
inputs.size(); i < n; ++i) {
2718 if (!input || first.getCond() != input.getCond())
2723 SmallVector<Value> trues{first.getTrueValue()};
2724 SmallVector<Value> falses{first.getFalseValue()};
2725 SmallVector<Location> locs{first->getLoc()};
2726 bool isTwoState =
true;
2727 for (
size_t i = 1, n =
inputs.size(); i < n; ++i) {
2729 trues.push_back(input.getTrueValue());
2730 falses.push_back(input.getFalseValue());
2731 locs.push_back(input->getLoc());
2732 if (!input.getTwoState())
2741 auto arrayTy = op.getType();
2744 rewriter.replaceOpWithNewOp<
comb::MuxOp>(op, arrayTy, first.getCond(),
2745 trueValues, falseValues, isTwoState);
2750 using OpRewritePattern::OpRewritePattern;
2753 PatternRewriter &rewriter)
const override {
2757 if (foldArrayOfMuxes(op, rewriter))
2765 void MuxOp::getCanonicalizationPatterns(RewritePatternSet &results,
2766 MLIRContext *context) {
2767 results.insert<MuxRewriter, ArrayRewriter>(context);
2778 switch (predicate) {
2779 case ICmpPredicate::eq:
2781 case ICmpPredicate::ne:
2783 case ICmpPredicate::slt:
2784 return lhs.slt(rhs);
2785 case ICmpPredicate::sle:
2786 return lhs.sle(rhs);
2787 case ICmpPredicate::sgt:
2788 return lhs.sgt(rhs);
2789 case ICmpPredicate::sge:
2790 return lhs.sge(rhs);
2791 case ICmpPredicate::ult:
2792 return lhs.ult(rhs);
2793 case ICmpPredicate::ule:
2794 return lhs.ule(rhs);
2795 case ICmpPredicate::ugt:
2796 return lhs.ugt(rhs);
2797 case ICmpPredicate::uge:
2798 return lhs.uge(rhs);
2799 case ICmpPredicate::ceq:
2801 case ICmpPredicate::cne:
2803 case ICmpPredicate::weq:
2805 case ICmpPredicate::wne:
2808 llvm_unreachable(
"unknown comparison predicate");
2814 switch (predicate) {
2815 case ICmpPredicate::eq:
2816 case ICmpPredicate::sle:
2817 case ICmpPredicate::sge:
2818 case ICmpPredicate::ule:
2819 case ICmpPredicate::uge:
2820 case ICmpPredicate::ceq:
2821 case ICmpPredicate::weq:
2823 case ICmpPredicate::ne:
2824 case ICmpPredicate::slt:
2825 case ICmpPredicate::sgt:
2826 case ICmpPredicate::ult:
2827 case ICmpPredicate::ugt:
2828 case ICmpPredicate::cne:
2829 case ICmpPredicate::wne:
2832 llvm_unreachable(
"unknown comparison predicate");
2835 OpFoldResult ICmpOp::fold(FoldAdaptor adaptor) {
2841 if (getLhs() == getRhs()) {
2847 if (
auto lhs = dyn_cast_or_null<IntegerAttr>(adaptor.getLhs())) {
2848 if (
auto rhs = dyn_cast_or_null<IntegerAttr>(adaptor.getRhs())) {
2859 template <
typename Range>
2861 size_t commonPrefixLength = 0;
2862 auto ia = a.begin();
2863 auto ib = b.begin();
2865 for (; ia != a.end() && ib != b.end(); ia++, ib++, commonPrefixLength++) {
2871 return commonPrefixLength;
2875 size_t totalWidth = 0;
2876 for (
auto operand : operands) {
2879 ssize_t
width = operand.getType().getIntOrFloatBitWidth();
2881 totalWidth +=
width;
2891 PatternRewriter &rewriter) {
2895 SmallVector<Value> lhsOperands, rhsOperands;
2898 ArrayRef<Value> lhsOperandsRef = lhsOperands, rhsOperandsRef = rhsOperands;
2900 auto formCatOrReplicate = [&](Location loc,
2901 ArrayRef<Value> operands) -> Value {
2902 assert(!operands.empty());
2903 Value sameElement = operands[0];
2904 for (
size_t i = 1, e = operands.size(); i != e && sameElement; ++i)
2905 if (sameElement != operands[i])
2906 sameElement = Value();
2908 return rewriter.createOrFold<ReplicateOp>(loc, sameElement,
2910 return rewriter.createOrFold<
ConcatOp>(loc, operands);
2913 auto replaceWith = [&](ICmpPredicate predicate, Value lhs,
2914 Value rhs) -> LogicalResult {
2915 replaceOpWithNewOpAndCopyName<ICmpOp>(rewriter, op, predicate, lhs, rhs,
2920 size_t commonPrefixLength =
2922 if (commonPrefixLength == lhsOperands.size()) {
2925 replaceOpWithNewOpAndCopyName<hw::ConstantOp>(rewriter, op,
2931 llvm::reverse(lhsOperandsRef), llvm::reverse(rhsOperandsRef));
2933 size_t commonPrefixTotalWidth =
2934 getTotalWidth(lhsOperandsRef.take_front(commonPrefixLength));
2935 size_t commonSuffixTotalWidth =
2936 getTotalWidth(lhsOperandsRef.take_back(commonSuffixLength));
2937 auto lhsOnly = lhsOperandsRef.drop_front(commonPrefixLength)
2938 .drop_back(commonSuffixLength);
2939 auto rhsOnly = rhsOperandsRef.drop_front(commonPrefixLength)
2940 .drop_back(commonSuffixLength);
2942 auto replaceWithoutReplicatingSignBit = [&]() {
2943 auto newLhs = formCatOrReplicate(lhs->getLoc(), lhsOnly);
2944 auto newRhs = formCatOrReplicate(rhs->getLoc(), rhsOnly);
2945 return replaceWith(op.getPredicate(), newLhs, newRhs);
2948 auto replaceWithReplicatingSignBit = [&]() {
2949 auto firstNonEmptyValue = lhsOperands[0];
2950 auto firstNonEmptyElemWidth =
2951 firstNonEmptyValue.getType().getIntOrFloatBitWidth();
2952 Value signBit = rewriter.createOrFold<
ExtractOp>(
2953 op.getLoc(), firstNonEmptyValue, firstNonEmptyElemWidth - 1, 1);
2955 auto newLhs = rewriter.
create<
ConcatOp>(lhs->getLoc(), signBit, lhsOnly);
2956 auto newRhs = rewriter.create<
ConcatOp>(rhs->getLoc(), signBit, rhsOnly);
2957 return replaceWith(op.getPredicate(), newLhs, newRhs);
2960 if (ICmpOp::isPredicateSigned(op.getPredicate())) {
2962 if (commonPrefixTotalWidth == 0 && commonSuffixTotalWidth > 0)
2963 return replaceWithoutReplicatingSignBit();
2969 if (commonPrefixTotalWidth > 1 || commonSuffixTotalWidth > 0)
2970 return replaceWithReplicatingSignBit();
2972 }
else if (commonPrefixTotalWidth > 0 || commonSuffixTotalWidth > 0) {
2974 return replaceWithoutReplicatingSignBit();
2988 ICmpOp cmpOp,
const KnownBits &bitAnalysis,
const APInt &rhsCst,
2989 PatternRewriter &rewriter) {
2993 APInt bitsKnown = bitAnalysis.Zero | bitAnalysis.One;
2994 if ((bitsKnown & rhsCst) != bitAnalysis.One) {
2997 bool result = cmpOp.getPredicate() == ICmpPredicate::ne;
2998 replaceOpWithNewOpAndCopyName<hw::ConstantOp>(rewriter, cmpOp,
3006 SmallVector<Value> newConcatOperands;
3007 auto newConstant = APInt::getZeroWidth();
3012 unsigned knownMSB = bitsKnown.countLeadingOnes();
3014 Value operand = cmpOp.getLhs();
3019 while (knownMSB != bitsKnown.getBitWidth()) {
3022 bitsKnown = bitsKnown.trunc(bitsKnown.getBitWidth() - knownMSB);
3025 unsigned unknownBits = bitsKnown.countLeadingZeros();
3026 unsigned lowBit = bitsKnown.getBitWidth() - unknownBits;
3027 auto spanOperand = rewriter.createOrFold<
ExtractOp>(
3028 operand.getLoc(), operand, lowBit,
3030 auto spanConstant = rhsCst.lshr(lowBit).trunc(unknownBits);
3033 newConcatOperands.push_back(spanOperand);
3036 if (newConstant.getBitWidth() != 0)
3037 newConstant = newConstant.concat(spanConstant);
3039 newConstant = spanConstant;
3042 unsigned newWidth = bitsKnown.getBitWidth() - unknownBits;
3043 bitsKnown = bitsKnown.trunc(newWidth);
3044 knownMSB = bitsKnown.countLeadingOnes();
3050 if (newConcatOperands.empty()) {
3051 bool result = cmpOp.getPredicate() == ICmpPredicate::eq;
3052 replaceOpWithNewOpAndCopyName<hw::ConstantOp>(rewriter, cmpOp,
3058 Value concatResult =
3059 rewriter.createOrFold<
ConcatOp>(operand.getLoc(), newConcatOperands);
3063 cmpOp.getOperand(1).getLoc(), newConstant);
3065 replaceOpWithNewOpAndCopyName<ICmpOp>(rewriter, cmpOp, cmpOp.getPredicate(),
3066 concatResult, newConstantOp,
3067 cmpOp.getTwoState());
3073 PatternRewriter &rewriter) {
3074 auto ip = rewriter.saveInsertionPoint();
3075 rewriter.setInsertionPoint(xorOp);
3077 auto xorRHS = xorOp.getOperands().back().getDefiningOp<
hw::ConstantOp>();
3079 xorRHS.getValue() ^ rhs);
3081 switch (xorOp.getNumOperands()) {
3085 APInt::getZero(rhs.getBitWidth()));
3089 newLHS = xorOp.getOperand(0);
3093 SmallVector<Value> newOperands(xorOp.getOperands());
3094 newOperands.pop_back();
3095 newLHS = rewriter.create<
XorOp>(xorOp.getLoc(), newOperands,
false);
3099 bool xorMultipleUses = !xorOp->hasOneUse();
3103 if (xorMultipleUses)
3104 replaceOpWithNewOpAndCopyName<XorOp>(rewriter, xorOp, newLHS, xorRHS,
3108 rewriter.restoreInsertionPoint(ip);
3109 replaceOpWithNewOpAndCopyName<ICmpOp>(rewriter, cmpOp, cmpOp.getPredicate(),
3110 newLHS, newRHS,
false);
3113 LogicalResult ICmpOp::canonicalize(ICmpOp op, PatternRewriter &rewriter) {
3120 if (matchPattern(op.getLhs(), m_ConstantInt(&lhs))) {
3121 assert(!matchPattern(op.getRhs(), m_ConstantInt(&rhs)) &&
3122 "Should be folded");
3123 replaceOpWithNewOpAndCopyName<ICmpOp>(
3124 rewriter, op, ICmpOp::getFlippedPredicate(op.getPredicate()),
3125 op.getRhs(), op.getLhs(), op.getTwoState());
3130 if (matchPattern(op.getRhs(), m_ConstantInt(&rhs))) {
3132 return rewriter.create<
hw::ConstantOp>(op.getLoc(), std::move(constant));
3135 auto replaceWith = [&](ICmpPredicate predicate, Value lhs,
3136 Value rhs) -> LogicalResult {
3137 replaceOpWithNewOpAndCopyName<ICmpOp>(rewriter, op, predicate, lhs, rhs,
3142 auto replaceWithConstantI1 = [&](
bool constant) -> LogicalResult {
3143 replaceOpWithNewOpAndCopyName<hw::ConstantOp>(rewriter, op,
3144 APInt(1, constant));
3148 switch (op.getPredicate()) {
3149 case ICmpPredicate::slt:
3151 if (rhs.isMaxSignedValue())
3152 return replaceWith(ICmpPredicate::ne, op.getLhs(), op.getRhs());
3154 if (rhs.isMinSignedValue())
3155 return replaceWithConstantI1(0);
3157 if ((rhs - 1).isMinSignedValue())
3158 return replaceWith(ICmpPredicate::eq, op.getLhs(),
3161 case ICmpPredicate::sgt:
3163 if (rhs.isMinSignedValue())
3164 return replaceWith(ICmpPredicate::ne, op.getLhs(), op.getRhs());
3166 if (rhs.isMaxSignedValue())
3167 return replaceWithConstantI1(0);
3169 if ((rhs + 1).isMaxSignedValue())
3170 return replaceWith(ICmpPredicate::eq, op.getLhs(),
3173 case ICmpPredicate::ult:
3175 if (rhs.isAllOnes())
3176 return replaceWith(ICmpPredicate::ne, op.getLhs(), op.getRhs());
3179 return replaceWithConstantI1(0);
3181 if ((rhs - 1).isZero())
3182 return replaceWith(ICmpPredicate::eq, op.getLhs(),
3186 if (rhs.countLeadingOnes() + rhs.countTrailingZeros() ==
3187 rhs.getBitWidth()) {
3188 auto numOnes = rhs.countLeadingOnes();
3189 auto smaller = rewriter.create<
ExtractOp>(
3190 op.getLoc(), op.getLhs(), rhs.getBitWidth() - numOnes, numOnes);
3191 return replaceWith(ICmpPredicate::ne, smaller,
3196 case ICmpPredicate::ugt:
3199 return replaceWith(ICmpPredicate::ne, op.getLhs(), op.getRhs());
3201 if (rhs.isAllOnes())
3202 return replaceWithConstantI1(0);
3204 if ((rhs + 1).isAllOnes())
3205 return replaceWith(ICmpPredicate::eq, op.getLhs(),
3209 if ((rhs + 1).isPowerOf2()) {
3210 auto numOnes = rhs.countTrailingOnes();
3211 auto newWidth = rhs.getBitWidth() - numOnes;
3212 auto smaller = rewriter.create<
ExtractOp>(op.getLoc(), op.getLhs(),
3214 return replaceWith(ICmpPredicate::ne, smaller,
3219 case ICmpPredicate::sle:
3221 if (rhs.isMaxSignedValue())
3222 return replaceWithConstantI1(1);
3224 return replaceWith(ICmpPredicate::slt, op.getLhs(),
getConstant(rhs + 1));
3225 case ICmpPredicate::sge:
3227 if (rhs.isMinSignedValue())
3228 return replaceWithConstantI1(1);
3230 return replaceWith(ICmpPredicate::sgt, op.getLhs(),
getConstant(rhs - 1));
3231 case ICmpPredicate::ule:
3233 if (rhs.isAllOnes())
3234 return replaceWithConstantI1(1);
3236 return replaceWith(ICmpPredicate::ult, op.getLhs(),
getConstant(rhs + 1));
3237 case ICmpPredicate::uge:
3240 return replaceWithConstantI1(1);
3242 return replaceWith(ICmpPredicate::ugt, op.getLhs(),
getConstant(rhs - 1));
3243 case ICmpPredicate::eq:
3244 if (rhs.getBitWidth() == 1) {
3247 replaceOpWithNewOpAndCopyName<XorOp>(rewriter, op, op.getLhs(),
3252 if (rhs.isAllOnes()) {
3259 case ICmpPredicate::ne:
3260 if (rhs.getBitWidth() == 1) {
3266 if (rhs.isAllOnes()) {
3268 replaceOpWithNewOpAndCopyName<XorOp>(rewriter, op, op.getLhs(),
3275 case ICmpPredicate::ceq:
3276 case ICmpPredicate::cne:
3277 case ICmpPredicate::weq:
3278 case ICmpPredicate::wne:
3284 if (op.getPredicate() == ICmpPredicate::eq ||
3285 op.getPredicate() == ICmpPredicate::ne) {
3290 if (!knownBits.isUnknown())
3297 if (
auto xorOp = op.getLhs().getDefiningOp<
XorOp>())
3304 if (
auto replicateOp = op.getLhs().getDefiningOp<ReplicateOp>())
3305 if (rhs.isAllOnes() || rhs.isZero()) {
3306 auto width = replicateOp.getInput().getType().getIntOrFloatBitWidth();
3308 op.getLoc(), rhs.isAllOnes() ? APInt::getAllOnes(
width)
3309 : APInt::getZero(
width));
3310 replaceOpWithNewOpAndCopyName<ICmpOp>(rewriter, op, op.getPredicate(),
3311 replicateOp.getInput(), cst,
3321 if (Operation *opLHS = op.getLhs().getDefiningOp())
3322 if (Operation *opRHS = op.getRhs().getDefiningOp())
3323 if (isa<ConcatOp, ReplicateOp>(opLHS) &&
3324 isa<ConcatOp, ReplicateOp>(opRHS)) {
assert(baseType &&"element must be base type")
static SmallVector< T > concat(const SmallVectorImpl< T > &a, const SmallVectorImpl< T > &b)
Returns a new vector containing the concatenation of vectors a and b.
static KnownBits computeKnownBits(Value v, unsigned depth)
Given an integer SSA value, check to see if we know anything about the result of the computation.
static bool foldMuxOfUniformArrays(MuxOp op, PatternRewriter &rewriter)
static Attribute constFoldAssociativeOp(ArrayRef< Attribute > operands, hw::PEO paramOpcode)
static Attribute constFoldBinaryOp(ArrayRef< Attribute > operands, hw::PEO paramOpcode)
Performs constant folding calculate with element-wise behavior on the two attributes in operands and ...
static bool applyCmpPredicateToEqualOperands(ICmpPredicate predicate)
static bool canonicalizeLogicalCstWithConcat(Operation *logicalOp, size_t concatIdx, const APInt &cst, PatternRewriter &rewriter)
When we find a logical operation (and, or, xor) with a constant e.g.
static bool hasOperandsOutsideOfBlock(Operation *op)
In comb, we assume no knowledge of the semantics of cross-block dataflow.
static bool narrowOperationWidth(OpTy op, bool narrowTrailingBits, PatternRewriter &rewriter)
static OpFoldResult foldDiv(Op op, ArrayRef< Attribute > constants)
static Value getCommonOperand(Op op)
Returns a single common operand that all inputs of the operation op can be traced back to,...
static bool canCombineOppositeBinCmpIntoConstant(OperandRange operands)
static void getConcatOperands(Value v, SmallVectorImpl< Value > &result)
Flatten concat and replicate operands into a vector.
static OpTy replaceOpWithNewOpAndCopyName(PatternRewriter &rewriter, Operation *op, Args &&...args)
A wrapper of PatternRewriter::replaceOpWithNewOp to propagate "sv.namehint" attribute.
static Value extractOperandFromFullyAssociative(Operation *fullyAssoc, size_t operandNo, PatternRewriter &rewriter)
Given a fully associative variadic operation like (a+b+c+d), break the expression into two parts,...
static bool getMuxChainCondConstant(Value cond, Value indexValue, bool isInverted, std::function< void(hw::ConstantOp)> constantFn)
Check to see if the condition to the specified mux is an equality comparison indexValue and one or mo...
static TypedAttr getIntAttr(const APInt &value, MLIRContext *context)
static bool shouldBeFlattened(Operation *op)
Return true if the op will be flattened afterwards.
static std::pair< size_t, size_t > getLowestBitAndHighestBitRequired(Operation *op, bool narrowTrailingBits, size_t originalOpWidth)
static void canonicalizeXorIcmpTrue(XorOp op, unsigned icmpOperand, PatternRewriter &rewriter)
static bool extractFromReplicate(ExtractOp op, ReplicateOp replicate, PatternRewriter &rewriter)
static bool canonicalizeOrOfConcatsWithCstOperands(OrOp op, size_t concatIdx1, size_t concatIdx2, PatternRewriter &rewriter)
Simplify concat ops in an or op when a constant operand is present in either concat.
static void combineEqualityICmpWithXorOfConstant(ICmpOp cmpOp, XorOp xorOp, const APInt &rhs, PatternRewriter &rewriter)
static size_t getTotalWidth(ArrayRef< Value > operands)
static bool foldCommonMuxOperation(MuxOp mux, Operation *trueOp, Operation *falseOp, PatternRewriter &rewriter)
This function is invoked when we find a mux with true/false operations that have the same opcode.
static bool tryFlatteningOperands(Operation *op, PatternRewriter &rewriter)
Flattens a single input in op if hasOneUse is true and it can be defined as an Op.
static bool canonicalizeIdempotentInputs(Op op, PatternRewriter &rewriter)
Canonicalize an idempotent operation op so that only one input of any kind occurs.
static bool applyCmpPredicate(ICmpPredicate predicate, const APInt &lhs, const APInt &rhs)
static void combineEqualityICmpWithKnownBitsAndConstant(ICmpOp cmpOp, const KnownBits &bitAnalysis, const APInt &rhsCst, PatternRewriter &rewriter)
Given an equality comparison with a constant value and some operand that has known bits,...
static bool foldMuxChain(MuxOp rootMux, bool isFalseSide, PatternRewriter &rewriter)
Given a mux, check to see if the "on true" value (or "on false" value if isFalseSide=true) is a mux t...
static ComplementMatcher< SubType > m_Complement(const SubType &subExpr)
static bool hasSVAttributes(Operation *op)
static LogicalResult extractConcatToConcatExtract(ExtractOp op, ConcatOp innerCat, PatternRewriter &rewriter)
static OpFoldResult foldMod(Op op, ArrayRef< Attribute > constants)
static size_t computeCommonPrefixLength(const Range &a, const Range &b)
static bool foldCommonMuxValue(MuxOp op, bool isTrueOperand, PatternRewriter &rewriter)
Fold things like mux(cond, x|y|z|a, a) -> (x|y|z)&replicate(cond)|a and mux(cond, a,...
static LogicalResult matchAndRewriteCompareConcat(ICmpOp op, Operation *lhs, Operation *rhs, PatternRewriter &rewriter)
Reduce the strength of icmp(concat(...), concat(...)) by doing an element-wise comparison on common prefi...
static Value createGenericOp(Location loc, OperationName name, ArrayRef< Value > operands, OpBuilder &builder)
Create a new instance of a generic operation that only has value operands, and has a single result va...
static void replaceOpAndCopyName(PatternRewriter &rewriter, Operation *op, Value newValue)
A wrapper of PatternRewriter::replaceOp to propagate "sv.namehint" attribute.
static std::optional< APSInt > getConstant(Attribute operand)
Determine the value of a constant operand for the sake of constant folding.
llvm::SmallVector< StringAttr > inputs
def create(data_type, value)
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
Value createOrFoldNot(Location loc, Value value, OpBuilder &builder, bool twoState=false)
Create a `Not` gate on a value.
uint64_t getWidth(Type t)
std::optional< int64_t > getBitWidth(FIRRTLBaseType type, bool ignoreFlip=false)
bool isOffset(Value base, Value index, uint64_t offset)
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.