20 #include "mlir/IR/Matchers.h"
21 #include "mlir/IR/PatternMatch.h"
22 #include "llvm/ADT/APSInt.h"
23 #include "llvm/ADT/SmallPtrSet.h"
24 #include "llvm/ADT/StringExtras.h"
25 #include "llvm/ADT/TypeSwitch.h"
27 using namespace circt;
28 using namespace firrtl;
/// Helper that walks the users of `old` and erases connect-like users through
/// the rewriter.
/// NOTE(review): the end of the parameter list, the loop body that fills
/// `users`, any extra condition guarding the erase, and the return statement
/// are not visible in this view -- confirm against the full file.
32 static Value
dropWrite(PatternRewriter &rewriter, OpResult old,
// Local set of user operations; presumably filled by the loop below so the
// erase loop does not iterate the live use list while mutating it.
34 SmallPtrSet<Operation *, 8> users;
35 for (
auto *user : old.getUsers())
// Second pass: visit the collected users and erase the FConnectLike ones.
37 for (Operation *user : users)
38 if (
auto connect = dyn_cast<FConnectLike>(user))
40 rewriter.eraseOp(user);
49 Operation *op = passthrough.getDefiningOp();
52 assert(op &&
"passthrough must be an operation");
53 Operation *oldOp = old.getOwner();
54 auto name = oldOp->getAttrOfType<StringAttr>(
"name");
55 if (name && !name.getValue().empty())
56 op->setAttr(
"name", name);
64 #include "circt/Dialect/FIRRTL/FIRRTLCanonicalization.h.inc"
72 auto resultType = type_cast<IntType>(op->getResult(0).getType());
73 if (!resultType.hasWidth())
75 for (Value operand : op->getOperands())
76 if (!type_cast<IntType>(operand.getType()).hasWidth())
83 auto t = type_dyn_cast<UIntType>(type);
84 if (!t || !t.hasWidth() || t.getWidth() != 1)
/// Updates the "name" attribute of `op` from an incoming `name`, combining it
/// with the name the op already carries.
/// NOTE(review): the rest of the parameter list and a few statements (for
/// example the bodies of the guards) are not visible in this view.
91 static void updateName(PatternRewriter &rewriter, Operation *op,
// Callers must not pass an InstanceOp.
94 assert(!isa<InstanceOp>(op));
// Guard for a missing or empty incoming name.
95 if (!name || name.getValue().empty())
97 auto newName = name.getValue();
98 auto newOpName = op->getAttrOfType<StringAttr>(
"name");
// chooseName decides between the op's existing name and the incoming one; its
// selection policy is defined elsewhere.
101 newName =
chooseName(newOpName.getValue(), name.getValue());
// Only write the attribute when it is absent or would change, and do so via
// modifyOpInPlace so the rewriter is informed of the mutation.
103 if (!newOpName || newOpName.getValue() != newName)
104 rewriter.modifyOpInPlace(
105 op, [&] { op->setAttr(
"name", rewriter.getStringAttr(newName)); });
113 if (
auto *newOp = newValue.getDefiningOp()) {
114 auto name = op->getAttrOfType<StringAttr>(
"name");
117 rewriter.replaceOp(op, newValue);
/// Variadic helper that replaces `op` with a newly built `OpTy`.
/// NOTE(review): the line carrying the function name is not visible here;
/// presumably this is the replaceOpWithNewOpAndCopyName helper called by the
/// patterns later in this file -- confirm. What is done with `name` after the
/// replacement is also elided.
123 template <
typename OpTy,
typename... Args>
125 Operation *op, Args &&...args) {
// Read the "name" attribute from the op that is about to be replaced.
126 auto name = op->getAttrOfType<StringAttr>(
"name");
// Build the replacement, perfectly forwarding the builder arguments.
128 rewriter.replaceOpWithNewOp<OpTy>(op, std::forward<Args>(args)...);
136 if (
auto namableOp = dyn_cast<firrtl::FNamableOp>(op))
137 return namableOp.hasDroppableName();
/// Returns an optional APSInt for an integer-typed `operand`.
/// NOTE(review): the name/parameter line and several statements are not
/// visible in this view; the function name is taken from the assert message.
148 static std::optional<APSInt>
// Only FIRRTL integer types are supported.
150 assert(type_cast<IntType>(operand.getType()) &&
151 "getExtendedConstant is limited to integer types");
// Path taken when `constant` is a (non-null) IntegerAttr; body elided.
158 if (IntegerAttr result = dyn_cast_or_null<IntegerAttr>(constant))
// A zero-width operand yields an APSInt of `destWidth` bits that carries the
// operand type's signedness.
163 if (type_cast<IntType>(operand.getType()).getWidth() == 0)
164 return APSInt(destWidth,
165 type_cast<IntType>(operand.getType()).isUnsigned());
173 if (
auto attr = dyn_cast<BoolAttr>(operand))
174 return APSInt(APInt(1, attr.getValue()));
175 if (
auto attr = dyn_cast<IntegerAttr>(operand))
176 return attr.getAPSInt();
184 return cst->isZero();
211 Operation *op, ArrayRef<Attribute> operands,
BinOpKind opKind,
212 const function_ref<APInt(
const APSInt &,
const APSInt &)> &calculate) {
213 assert(operands.size() == 2 &&
"binary op takes two operands");
216 auto resultType = type_cast<IntType>(op->getResult(0).getType());
217 if (resultType.getWidthOrSentinel() < 0)
221 if (resultType.getWidthOrSentinel() == 0)
222 return getIntAttr(resultType, APInt(0, 0, resultType.isSigned()));
228 type_cast<IntType>(op->getOperand(0).getType()).getWidthOrSentinel();
230 type_cast<IntType>(op->getOperand(1).getType()).getWidthOrSentinel();
231 if (
auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>())
232 lhsWidth = std::max<int32_t>(lhsWidth, lhs.getValue().getBitWidth());
233 if (
auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>())
234 rhsWidth = std::max<int32_t>(rhsWidth, rhs.getValue().getBitWidth());
238 int32_t operandWidth;
241 operandWidth = resultType.getWidthOrSentinel();
246 operandWidth = std::max(1, std::max(lhsWidth, rhsWidth));
250 std::max(std::max(lhsWidth, rhsWidth), resultType.getWidthOrSentinel());
261 APInt resultValue = calculate(*lhs, *rhs);
266 resultValue = resultValue.trunc(resultType.getWidthOrSentinel());
268 assert((
unsigned)resultType.getWidthOrSentinel() ==
269 resultValue.getBitWidth());
282 Operation *op, PatternRewriter &rewriter,
283 const function_ref<OpFoldResult(ArrayRef<Attribute>)> &canonicalize) {
285 if (op->getNumResults() != 1)
287 auto type = type_dyn_cast<FIRRTLBaseType>(op->getResult(0).getType());
292 auto width = type.getBitWidthOrSentinel();
297 SmallVector<Attribute, 3> constOperands;
298 constOperands.reserve(op->getNumOperands());
299 for (
auto operand : op->getOperands()) {
301 if (
auto *defOp = operand.getDefiningOp())
302 TypeSwitch<Operation *>(defOp).Case<ConstantOp, SpecialConstantOp>(
303 [&](
auto op) { attr = op.getValueAttr(); });
304 constOperands.push_back(attr);
309 auto result = canonicalize(constOperands);
313 if (
auto cst = dyn_cast<Attribute>(result))
314 resultValue = op->getDialect()
315 ->materializeConstant(rewriter, cst, type, op->getLoc())
318 resultValue = result.get<Value>();
322 type_cast<FIRRTLBaseType>(resultValue.getType()).getBitWidthOrSentinel())
323 resultValue = rewriter.create<PadPrimOp>(op->getLoc(), resultValue,
width);
326 if (type_isa<SIntType>(type) && type_isa<UIntType>(resultValue.getType()))
327 resultValue = rewriter.create<AsSIntPrimOp>(op->getLoc(), resultValue);
328 else if (type_isa<UIntType>(type) &&
329 type_isa<SIntType>(resultValue.getType()))
330 resultValue = rewriter.create<AsUIntPrimOp>(op->getLoc(), resultValue);
332 assert(type == resultValue.getType() &&
"canonicalization changed type");
340 return bitWidth > 0 ? APInt::getMaxValue(bitWidth) : APInt();
346 return bitWidth > 0 ? APInt::getSignedMinValue(bitWidth) : APInt();
352 return bitWidth > 0 ? APInt::getSignedMaxValue(bitWidth) : APInt();
/// Folds a ConstantOp to its own value attribute. The op takes no operands,
/// which is asserted.
359 OpFoldResult ConstantOp::fold(FoldAdaptor adaptor) {
360 assert(adaptor.getOperands().empty() &&
"constant has no operands");
361 return getValueAttr();
/// Folds a SpecialConstantOp to its own value attribute. The op takes no
/// operands, which is asserted.
364 OpFoldResult SpecialConstantOp::fold(FoldAdaptor adaptor) {
365 assert(adaptor.getOperands().empty() &&
"constant has no operands");
366 return getValueAttr();
/// Folds an AggregateConstantOp to its fields attribute (note: getFieldsAttr,
/// not getValueAttr as in the scalar constant folds). The op takes no
/// operands, which is asserted.
369 OpFoldResult AggregateConstantOp::fold(FoldAdaptor adaptor) {
370 assert(adaptor.getOperands().empty() &&
"constant has no operands");
371 return getFieldsAttr();
/// Folds a StringConstantOp to its own value attribute. The op takes no
/// operands, which is asserted.
374 OpFoldResult StringConstantOp::fold(FoldAdaptor adaptor) {
375 assert(adaptor.getOperands().empty() &&
"constant has no operands");
376 return getValueAttr();
/// Folds an FIntegerConstantOp to its own value attribute. The op takes no
/// operands, which is asserted.
379 OpFoldResult FIntegerConstantOp::fold(FoldAdaptor adaptor) {
380 assert(adaptor.getOperands().empty() &&
"constant has no operands");
381 return getValueAttr();
/// Folds a BoolConstantOp to its own value attribute. The op takes no
/// operands, which is asserted.
384 OpFoldResult BoolConstantOp::fold(FoldAdaptor adaptor) {
385 assert(adaptor.getOperands().empty() &&
"constant has no operands");
386 return getValueAttr();
/// Folds a DoubleConstantOp to its own value attribute. The op takes no
/// operands, which is asserted.
389 OpFoldResult DoubleConstantOp::fold(FoldAdaptor adaptor) {
390 assert(adaptor.getOperands().empty() &&
"constant has no operands");
391 return getValueAttr();
/// Fold hook for AddPrimOp. The visible part is the callback computing `a + b`
/// on two APSInt operands; the call it is passed to is elided in this view.
398 OpFoldResult AddPrimOp::fold(FoldAdaptor adaptor) {
401 [=](
const APSInt &a,
const APSInt &b) { return a + b; });
/// Registers the canonicalization patterns for AddPrimOp into `results`:
/// moveConstAdd, AddOfZero, AddOfSelf and AddOfPad.
404 void AddPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
405 MLIRContext *context) {
406 results.insert<patterns::moveConstAdd, patterns::AddOfZero,
407 patterns::AddOfSelf, patterns::AddOfPad>(context);
/// Fold hook for SubPrimOp. The visible part is the callback computing `a - b`
/// on two APSInt operands; the call it is passed to is elided in this view.
410 OpFoldResult SubPrimOp::fold(FoldAdaptor adaptor) {
413 [=](
const APSInt &a,
const APSInt &b) { return a - b; });
/// Registers the canonicalization patterns for SubPrimOp into `results`:
/// SubOfZero, SubFromZeroSigned, SubFromZeroUnsigned, SubOfSelf, SubOfPadL and
/// SubOfPadR.
416 void SubPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
417 MLIRContext *context) {
418 results.insert<patterns::SubOfZero, patterns::SubFromZeroSigned,
419 patterns::SubFromZeroUnsigned, patterns::SubOfSelf,
420 patterns::SubOfPadL, patterns::SubOfPadR>(context);
/// Fold hook for MulPrimOp. The visible part is the callback computing `a * b`
/// on two APSInt operands; earlier statements of the body and the call the
/// callback is passed to are elided in this view.
423 OpFoldResult MulPrimOp::fold(FoldAdaptor adaptor) {
435 [=](
const APSInt &a,
const APSInt &b) { return a * b; });
/// Fold hook for DivPrimOp. Visible special cases: identical operands, and a
/// constant right-hand side equal to one when the lhs type already matches the
/// result type. The values produced by those cases are elided in this view.
438 OpFoldResult DivPrimOp::fold(FoldAdaptor adaptor) {
// Both operands are the same SSA value; the result width is looked up next.
445 if (getLhs() == getRhs()) {
446 auto width = getType().base().getWidthOrSentinel();
// Divisor is the constant one and no type change would be needed.
468 if (
auto rhsCst = adaptor.getRhs().dyn_cast_or_null<IntegerAttr>())
469 if (rhsCst.getValue().isOne() && getLhs().getType() == getType())
// Constant-evaluation callback. One of its paths returns zero at the width of
// `a`; the condition selecting that path is not visible here.
474 [=](
const APSInt &a,
const APSInt &b) -> APInt {
477 return APInt(a.getBitWidth(), 0);
/// Fold hook for RemPrimOp. Visible pieces: a special case for identical
/// operands and a constant-evaluation callback. The values produced are
/// elided in this view.
481 OpFoldResult RemPrimOp::fold(FoldAdaptor adaptor) {
// Both operands are the same SSA value.
488 if (getLhs() == getRhs())
// Constant-evaluation callback. One of its paths returns zero at the width of
// `a`; the condition selecting that path is not visible here.
502 [=](
const APSInt &a,
const APSInt &b) -> APInt {
505 return APInt(a.getBitWidth(), 0);
/// Fold hook for DShlPrimOp. The visible callback evaluates `a.shl(b)` on
/// constant operands; the call it is passed to is elided in this view.
509 OpFoldResult DShlPrimOp::fold(FoldAdaptor adaptor) {
512 [=](
const APSInt &a,
const APSInt &b) -> APInt { return a.shl(b); });
/// Fold hook for DShlwPrimOp. The visible callback evaluates `a.shl(b)` on
/// constant operands, the same expression as the DShlPrimOp fold; the call it
/// is passed to is elided in this view.
515 OpFoldResult DShlwPrimOp::fold(FoldAdaptor adaptor) {
518 [=](
const APSInt &a,
const APSInt &b) -> APInt { return a.shl(b); });
/// Fold hook for DShrPrimOp. The visible callback uses a logical shift right
/// when the result type is unsigned or `a` has zero bit width; the alternative
/// branch of the conditional is elided in this view.
521 OpFoldResult DShrPrimOp::fold(FoldAdaptor adaptor) {
524 [=](
const APSInt &a,
const APSInt &b) -> APInt {
525 return getType().base().isUnsigned() || !a.getBitWidth() ? a.lshr(b)
/// Fold hook for AndPrimOp. Visible checks, in order: constant zero or
/// all-ones on the right, the same on the left, identical operands, and
/// finally bitwise `a & b` on constants. The all-ones and identical-operand
/// cases also require the relevant operand types to equal the result type.
/// The value returned by each guard is elided in this view.
531 OpFoldResult AndPrimOp::fold(FoldAdaptor adaptor) {
534 if (rhsCst->isZero())
538 if (rhsCst->isAllOnes() && getLhs().getType() == getType() &&
539 getRhs().getType() == getType())
545 if (lhsCst->isZero())
549 if (lhsCst->isAllOnes() && getLhs().getType() == getType() &&
550 getRhs().getType() == getType())
// Same SSA value on both sides and no type change needed.
555 if (getLhs() == getRhs() && getRhs().getType() == getType())
// Constant-evaluation callback.
560 [](
const APSInt &a,
const APSInt &b) -> APInt { return a & b; });
/// Registers the canonicalization patterns for AndPrimOp: extendAnd,
/// moveConstAnd, AndOfZero, AndOfAllOne, AndOfSelf, AndOfPad, AndOfAsSIntL and
/// AndOfAsSIntR. The receiver of the `.insert` call is on a line not visible
/// in this view (presumably `results`).
563 void AndPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
564 MLIRContext *context) {
566 .insert<patterns::extendAnd, patterns::moveConstAnd, patterns::AndOfZero,
567 patterns::AndOfAllOne, patterns::AndOfSelf, patterns::AndOfPad,
568 patterns::AndOfAsSIntL, patterns::AndOfAsSIntR>(context);
/// Fold hook for OrPrimOp. Visible checks, in order: constant zero or all-ones
/// on the right, the same on the left, identical operands, and finally bitwise
/// `a | b` on constants. Each shortcut is additionally gated on operand types
/// matching the result type. The value returned by each guard is elided in
/// this view.
571 OpFoldResult OrPrimOp::fold(FoldAdaptor adaptor) {
574 if (rhsCst->isZero() && getLhs().getType() == getType())
578 if (rhsCst->isAllOnes() && getRhs().getType() == getType() &&
579 getLhs().getType() == getType())
585 if (lhsCst->isZero() && getRhs().getType() == getType())
589 if (lhsCst->isAllOnes() && getLhs().getType() == getType() &&
590 getRhs().getType() == getType())
// Same SSA value on both sides and no type change needed.
595 if (getLhs() == getRhs() && getRhs().getType() == getType())
// Constant-evaluation callback.
600 [](
const APSInt &a,
const APSInt &b) -> APInt { return a | b; });
/// Registers the canonicalization patterns for OrPrimOp into `results`:
/// extendOr, moveConstOr, OrOfZero, OrOfAllOne, OrOfSelf and OrOfPad. The
/// argument line of the call is not visible in this view.
603 void OrPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
604 MLIRContext *context) {
605 results.insert<patterns::extendOr, patterns::moveConstOr, patterns::OrOfZero,
606 patterns::OrOfAllOne, patterns::OrOfSelf, patterns::OrOfPad>(
/// Fold hook for XorPrimOp. Visible checks: constant zero on either side (the
/// rest of those conditions is elided), identical operands, and finally
/// bitwise `a ^ b` on constants.
610 OpFoldResult XorPrimOp::fold(FoldAdaptor adaptor) {
613 if (rhsCst->isZero() &&
619 if (lhsCst->isZero() &&
// Identical operands: a zero APInt is built at the result width, with an
// unknown (negative sentinel) width clamped to 0 bits.
624 if (getLhs() == getRhs())
627 APInt(std::max(getType().base().getWidthOrSentinel(), 0), 0));
// Constant-evaluation callback.
631 [](
const APSInt &a,
const APSInt &b) -> APInt { return a ^ b; });
/// Registers the canonicalization patterns for XorPrimOp into `results`:
/// extendXor, moveConstXor, XorOfZero, XorOfSelf and XorOfPad. The argument
/// line of the call is not visible in this view.
634 void XorPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
635 MLIRContext *context) {
636 results.insert<patterns::extendXor, patterns::moveConstXor,
637 patterns::XorOfZero, patterns::XorOfSelf, patterns::XorOfPad>(
/// Registers the single canonicalization pattern for LEQPrimOp,
/// LEQWithConstLHS, into `results`.
641 void LEQPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
642 MLIRContext *context) {
643 results.insert<patterns::LEQWithConstLHS>(context);
/// Fold hook for LEQPrimOp (`lhs <= rhs`). Visible pieces: an
/// identical-operand shortcut, a range check against a constant rhs at a
/// common width, and a generic constant comparison. The values returned by
/// the shortcuts are elided in this view.
646 OpFoldResult LEQPrimOp::fold(FoldAdaptor adaptor) {
// Signedness is taken from the lhs operand type.
647 bool isUnsigned = getLhs().getType().base().isUnsigned();
650 if (getLhs() == getRhs())
// Compare at the wider of the known lhs width and the rhs constant's width,
// and never at fewer than one bit.
656 auto commonWidth = std::max<int32_t>(*
width, rhsCst->getBitWidth());
657 commonWidth = std::max(commonWidth, 1);
// Unsigned case works on the zero-extended constant; rest of condition elided.
668 if (isUnsigned && rhsCst->zext(commonWidth)
// Generic constant evaluation: a 1-bit result holding `a <= b`.
681 [=](
const APSInt &a,
const APSInt &b) -> APInt {
682 return APInt(1, a <= b);
/// Registers the single canonicalization pattern for LTPrimOp,
/// LTWithConstLHS, into `results`.
686 void LTPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
687 MLIRContext *context) {
688 results.insert<patterns::LTWithConstLHS>(context);
/// Fold hook for LTPrimOp (`lhs < rhs`). Visible pieces: an identical-operand
/// shortcut, a range check against a constant rhs at a common width, and a
/// generic constant comparison. The values returned by the shortcuts, and the
/// definitions of `isUnsigned` and `width`, are elided in this view.
691 OpFoldResult LTPrimOp::fold(FoldAdaptor adaptor) {
692 IntType lhsType = getLhs().getType();
696 if (getLhs() == getRhs())
// Compare at the wider of the known lhs width and the rhs constant's width,
// and never at fewer than one bit.
708 auto commonWidth = std::max<int32_t>(*
width, rhsCst->getBitWidth());
709 commonWidth = std::max(commonWidth, 1);
// Unsigned case works on the zero-extended constant; rest of condition elided.
720 if (isUnsigned && rhsCst->zext(commonWidth)
// Generic constant evaluation: a 1-bit result holding `a < b`.
733 [=](
const APSInt &a,
const APSInt &b) -> APInt {
734 return APInt(1, a < b);
/// Registers the single canonicalization pattern for GEQPrimOp,
/// GEQWithConstLHS, into `results`.
738 void GEQPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
739 MLIRContext *context) {
740 results.insert<patterns::GEQWithConstLHS>(context);
/// Fold hook for GEQPrimOp (`lhs >= rhs`). Visible pieces: an
/// identical-operand shortcut, an unsigned compare against constant zero, a
/// range check against a constant rhs at a common width, and a generic
/// constant comparison. The values returned by the shortcuts, and the
/// definitions of `isUnsigned` and `width`, are elided in this view.
743 OpFoldResult GEQPrimOp::fold(FoldAdaptor adaptor) {
744 IntType lhsType = getLhs().getType();
748 if (getLhs() == getRhs())
// Unsigned value compared against the constant zero.
753 if (rhsCst->isZero() && isUnsigned)
// Compare at the wider of the known lhs width and the rhs constant's width,
// and never at fewer than one bit.
760 auto commonWidth = std::max<int32_t>(*
width, rhsCst->getBitWidth());
761 commonWidth = std::max(commonWidth, 1);
764 if (isUnsigned && rhsCst->zext(commonWidth)
// Generic constant evaluation: a 1-bit result holding `a >= b`.
785 [=](
const APSInt &a,
const APSInt &b) -> APInt {
786 return APInt(1, a >= b);
/// Registers the single canonicalization pattern for GTPrimOp,
/// GTWithConstLHS, into `results`.
790 void GTPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
791 MLIRContext *context) {
792 results.insert<patterns::GTWithConstLHS>(context);
/// Fold hook for GTPrimOp (`lhs > rhs`). Visible pieces: an identical-operand
/// shortcut, a range check against a constant rhs at a common width, and a
/// generic constant comparison. The values returned by the shortcuts, and the
/// definitions of `isUnsigned` and `width`, are elided in this view.
795 OpFoldResult GTPrimOp::fold(FoldAdaptor adaptor) {
796 IntType lhsType = getLhs().getType();
800 if (getLhs() == getRhs())
// Compare at the wider of the known lhs width and the rhs constant's width,
// and never at fewer than one bit.
806 auto commonWidth = std::max<int32_t>(*
width, rhsCst->getBitWidth());
807 commonWidth = std::max(commonWidth, 1);
810 if (isUnsigned && rhsCst->zext(commonWidth)
// Generic constant evaluation: a 1-bit result holding `a > b`.
831 [=](
const APSInt &a,
const APSInt &b) -> APInt {
832 return APInt(1, a > b);
/// Fold hook for EQPrimOp. Visible pieces: an identical-operand shortcut, a
/// shortcut for an all-ones constant rhs when both operand types equal the
/// result type, and a generic constant comparison. The values returned by the
/// shortcuts are elided in this view.
836 OpFoldResult EQPrimOp::fold(FoldAdaptor adaptor) {
838 if (getLhs() == getRhs())
844 if (rhsCst->isAllOnes() && getLhs().getType() == getType() &&
845 getRhs().getType() == getType())
// Generic constant evaluation: a 1-bit result holding `a == b`.
851 [=](
const APSInt &a,
const APSInt &b) -> APInt {
852 return APInt(1, a == b);
/// Canonicalizations for EQPrimOp that must create new operations. The
/// rewrites are expressed as a callback taking the constant operands; the
/// function that receives the callback is named on a line not visible here.
856 LogicalResult EQPrimOp::canonicalize(EQPrimOp op, PatternRewriter &rewriter) {
858 op, rewriter, [&](ArrayRef<Attribute> operands) -> OpFoldResult {
860 auto width = op.getLhs().getType().getBitWidthOrSentinel();
// Comparison with constant zero where both operand types equal the result
// type: becomes a NotPrimOp of the lhs.
863 if (rhsCst->isZero() && op.getLhs().getType() == op.getType() &&
864 op.getRhs().getType() == op.getType()) {
865 return rewriter.create<NotPrimOp>(op.getLoc(), op.getLhs())
// Comparison with constant zero on a multi-bit lhs: not(orr(lhs)).
870 if (rhsCst->isZero() &&
width > 1) {
871 auto orrOp = rewriter.create<OrRPrimOp>(op.getLoc(), op.getLhs());
872 return rewriter.create<NotPrimOp>(op.getLoc(), orrOp).getResult();
// Comparison with all-ones on a multi-bit lhs of the same type as rhs:
// andr(lhs).
876 if (rhsCst->isAllOnes() &&
width > 1 &&
877 op.getLhs().getType() == op.getRhs().getType()) {
878 return rewriter.create<AndRPrimOp>(op.getLoc(), op.getLhs())
/// Fold hook for NEQPrimOp. Visible pieces: an identical-operand shortcut, a
/// shortcut for a constant-zero rhs when both operand types equal the result
/// type, and a generic constant comparison. The values returned by the
/// shortcuts are elided in this view.
887 OpFoldResult NEQPrimOp::fold(FoldAdaptor adaptor) {
889 if (getLhs() == getRhs())
895 if (rhsCst->isZero() && getLhs().getType() == getType() &&
896 getRhs().getType() == getType())
// Generic constant evaluation: a 1-bit result holding `a != b`.
902 [=](
const APSInt &a,
const APSInt &b) -> APInt {
903 return APInt(1, a != b);
/// Canonicalizations for NEQPrimOp that must create new operations. The
/// rewrites are expressed as a callback taking the constant operands; the
/// function that receives the callback is named on a line not visible here.
907 LogicalResult NEQPrimOp::canonicalize(NEQPrimOp op, PatternRewriter &rewriter) {
909 op, rewriter, [&](ArrayRef<Attribute> operands) -> OpFoldResult {
911 auto width = op.getLhs().getType().getBitWidthOrSentinel();
// Comparison with all-ones where both operand types equal the result type:
// becomes a NotPrimOp of the lhs.
914 if (rhsCst->isAllOnes() && op.getLhs().getType() == op.getType() &&
915 op.getRhs().getType() == op.getType()) {
916 return rewriter.create<NotPrimOp>(op.getLoc(), op.getLhs())
// Comparison with constant zero on a multi-bit lhs: orr(lhs).
921 if (rhsCst->isZero() &&
width > 1) {
922 return rewriter.create<OrRPrimOp>(op.getLoc(), op.getLhs())
// Comparison with all-ones on a multi-bit lhs of the same type as rhs:
// not(andr(lhs)).
927 if (rhsCst->isAllOnes() &&
width > 1 &&
928 op.getLhs().getType() == op.getRhs().getType()) {
929 auto andrOp = rewriter.create<AndRPrimOp>(op.getLoc(), op.getLhs());
930 return rewriter.create<NotPrimOp>(op.getLoc(), andrOp).getResult();
938 OpFoldResult IntegerAddOp::fold(FoldAdaptor adaptor) {
944 OpFoldResult IntegerMulOp::fold(FoldAdaptor adaptor) {
950 OpFoldResult IntegerShrOp::fold(FoldAdaptor adaptor) {
/// Fold hook for SizeOfIntrinsicOp. It reads the bit width (or a sentinel when
/// unknown) of the input's type; how `w` becomes the fold result is elided in
/// this view. The adaptor is unused, hence the unnamed parameter.
960 OpFoldResult SizeOfIntrinsicOp::fold(FoldAdaptor) {
961 auto base = getInput().getType();
962 auto w = base.getBitWidthOrSentinel();
968 OpFoldResult IsXIntrinsicOp::fold(FoldAdaptor adaptor) {
/// Fold hook for AsSIntPrimOp. Only one guard is visible in this view: a path
/// taken when the result type has a known width.
975 OpFoldResult AsSIntPrimOp::fold(FoldAdaptor adaptor) {
983 if (getType().base().hasWidth())
/// Registers the single canonicalization pattern for AsSIntPrimOp, StoUtoS,
/// into `results`.
990 void AsSIntPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
991 MLIRContext *context) {
992 results.insert<patterns::StoUtoS>(context);
/// Fold hook for AsUIntPrimOp. Only one guard is visible in this view: a path
/// taken when the result type has a known width.
995 OpFoldResult AsUIntPrimOp::fold(FoldAdaptor adaptor) {
1003 if (getType().base().hasWidth())
/// Registers the single canonicalization pattern for AsUIntPrimOp, UtoStoU,
/// into `results`.
1010 void AsUIntPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1011 MLIRContext *context) {
1012 results.insert<patterns::UtoStoU>(context);
/// Fold hook for AsAsyncResetPrimOp. The visible guard detects a cast whose
/// input type already equals the result type; the value returned, and any
/// further folds, are elided in this view.
1015 OpFoldResult AsAsyncResetPrimOp::fold(FoldAdaptor adaptor) {
1017 if (getInput().getType() == getType())
/// Fold hook for AsClockPrimOp. The visible guard detects a cast whose input
/// type already equals the result type; the value returned, and any further
/// folds, are elided in this view.
1027 OpFoldResult AsClockPrimOp::fold(FoldAdaptor adaptor) {
1029 if (getInput().getType() == getType())
/// Fold hook for CvtPrimOp. The only visible fragment passes the result
/// width (or sentinel) as the last argument of a call whose beginning is
/// elided in this view.
1039 OpFoldResult CvtPrimOp::fold(FoldAdaptor adaptor) {
1045 getType().base().getWidthOrSentinel()))
/// Registers the canonicalization patterns for CvtPrimOp into `results`:
/// CVTSigned and CVTUnSigned.
1051 void CvtPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1052 MLIRContext *context) {
1053 results.insert<patterns::CVTSigned, patterns::CVTUnSigned>(context);
/// Fold hook for NegPrimOp. When a constant `cst` is available (obtained by a
/// call whose beginning is elided, given the result width), the negation is
/// computed as `0 - cst` at the constant's bit width.
1056 OpFoldResult NegPrimOp::fold(FoldAdaptor adaptor) {
1063 getType().base().getWidthOrSentinel()))
1064 return getIntAttr(getType(), APInt((*cst).getBitWidth(), 0) - *cst);
/// Fold hook for NotPrimOp. The only visible fragment passes the result
/// width (or sentinel) as the last argument of a call whose beginning is
/// elided in this view.
1069 OpFoldResult NotPrimOp::fold(FoldAdaptor adaptor) {
1074 getType().base().getWidthOrSentinel()))
/// Registers the single canonicalization pattern for NotPrimOp, NotNot, into
/// `results`.
1080 void NotPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1081 MLIRContext *context) {
1082 results.insert<patterns::NotNot>(context);
/// Fold hook for AndRPrimOp (and-reduction). Visible pieces: a zero-width
/// input special case, constant evaluation, and a 1-bit unsigned input special
/// case. The values returned by the special cases are elided in this view.
1085 OpFoldResult AndRPrimOp::fold(FoldAdaptor adaptor) {
1089 if (getInput().getType().getBitWidthOrSentinel() == 0)
// Constant input: the result is 1 exactly when every bit is set.
1094 return getIntAttr(getType(), APInt(1, cst->isAllOnes()));
// Input is already a single unsigned bit.
1098 if (
isUInt1(getInput().getType()))
/// Registers the canonicalization patterns for AndRPrimOp: AndRasSInt,
/// AndRasUInt, AndRPadU, AndRPadS, AndRCatOneL, AndRCatOneR, AndRCatZeroL and
/// AndRCatZeroR. The receiver of the `.insert` call is on a line not visible
/// in this view (presumably `results`).
1104 void AndRPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1105 MLIRContext *context) {
1107 .insert<patterns::AndRasSInt, patterns::AndRasUInt, patterns::AndRPadU,
1108 patterns::AndRPadS, patterns::AndRCatOneL, patterns::AndRCatOneR,
1109 patterns::AndRCatZeroL, patterns::AndRCatZeroR>(context);
/// Fold hook for OrRPrimOp (or-reduction). Visible pieces: a zero-width input
/// special case, constant evaluation, and a 1-bit unsigned input special case.
/// The values returned by the special cases are elided in this view.
1112 OpFoldResult OrRPrimOp::fold(FoldAdaptor adaptor) {
1116 if (getInput().getType().getBitWidthOrSentinel() == 0)
// Constant input: the result is 1 exactly when the constant is non-zero.
1121 return getIntAttr(getType(), APInt(1, !cst->isZero()));
// Input is already a single unsigned bit.
1125 if (
isUInt1(getInput().getType()))
/// Registers the canonicalization patterns for OrRPrimOp into `results`:
/// OrRasSInt, OrRasUInt, OrRPadU, OrRCatZeroH and OrRCatZeroL.
1131 void OrRPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1132 MLIRContext *context) {
1133 results.insert<patterns::OrRasSInt, patterns::OrRasUInt, patterns::OrRPadU,
1134 patterns::OrRCatZeroH, patterns::OrRCatZeroL>(context);
/// Fold hook for XorRPrimOp (xor-reduction). Visible pieces: a zero-width
/// input special case, constant evaluation, and a 1-bit unsigned input special
/// case. The values returned by the special cases are elided in this view.
1137 OpFoldResult XorRPrimOp::fold(FoldAdaptor adaptor) {
1141 if (getInput().getType().getBitWidthOrSentinel() == 0)
// Constant input: the result is the parity (low bit of the population count).
1146 return getIntAttr(getType(), APInt(1, cst->popcount() & 1));
// Input is already a single unsigned bit.
1149 if (
isUInt1(getInput().getType()))
/// Registers the canonicalization patterns for XorRPrimOp into `results`:
/// XorRasSInt, XorRasUInt, XorRPadU, XorRCatZeroH and XorRCatZeroL.
1155 void XorRPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1156 MLIRContext *context) {
1157 results.insert<patterns::XorRasSInt, patterns::XorRasUInt, patterns::XorRPadU,
1158 patterns::XorRCatZeroH, patterns::XorRCatZeroL>(context);
/// Fold hook for CatPrimOp. When both operands are constants the result is
/// `lhs->concat(*rhs)`; how `lhs` and `rhs` are obtained, and any other folds,
/// are elided in this view.
1165 OpFoldResult CatPrimOp::fold(FoldAdaptor adaptor) {
1169 IntType lhsType = getLhs().getType();
1170 IntType rhsType = getRhs().getType();
1182 return getIntAttr(getType(), lhs->concat(*rhs));
/// Registers the single canonicalization pattern for DShlPrimOp,
/// DShlOfConstant, into `results`.
1187 void DShlPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1188 MLIRContext *context) {
1189 results.insert<patterns::DShlOfConstant>(context);
/// Registers the single canonicalization pattern for DShrPrimOp,
/// DShrOfConstant, into `results`.
1192 void DShrPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1193 MLIRContext *context) {
1194 results.insert<patterns::DShrOfConstant>(context);
/// Rewrite pattern on CatPrimOp: when both operands are BitsPrimOp extractions
/// of the same input and the left range sits directly above the right range
/// (`lhs.lo - 1 == rhs.hi`), the cat is replaced by a single BitsPrimOp
/// starting at `lhs.hi`. The remaining builder arguments and the return
/// statements are elided in this view.
1200 struct CatBitsBits :
public mlir::RewritePattern {
// Match any CatPrimOp, with pattern benefit 0.
1201 CatBitsBits(MLIRContext *context)
1202 : RewritePattern(CatPrimOp::getOperationName(), 0, context) {}
1203 LogicalResult matchAndRewrite(Operation *op,
1204 PatternRewriter &rewriter)
const override {
1205 auto cat = cast<CatPrimOp>(op);
// dyn_cast_or_null tolerates operands with no defining op.
1207 dyn_cast_or_null<BitsPrimOp>(cat.getLhs().getDefiningOp())) {
1209 dyn_cast_or_null<BitsPrimOp>(cat.getRhs().getDefiningOp())) {
// Same source value and adjacent bit ranges.
1210 if (lhsBits.getInput() == rhsBits.getInput() &&
1211 lhsBits.getLo() - 1 == rhsBits.getHi()) {
1212 replaceOpWithNewOpAndCopyName<BitsPrimOp>(
1213 rewriter, cat, cat.getType(), lhsBits.getInput(), lhsBits.getHi(),
/// Registers the canonicalization patterns for CatPrimOp into `results`: the
/// CatBitsBits pattern plus patterns::CatDoubleConst.
1224 void CatPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1225 MLIRContext *context) {
1226 results.insert<CatBitsBits, patterns::CatDoubleConst>(context);
/// Fold hook for BitCastOp. Two folds are visible: a cast to the input's own
/// type folds to the input, and a cast of a cast folds to the inner cast's
/// input when that restores the original type. The definition of `op` is
/// elided in this view.
1229 OpFoldResult BitCastOp::fold(FoldAdaptor adaptor) {
// Identity cast.
1232 if (op.getType() == op.getInput().getType())
1233 return op.getInput();
// Round-trip cast: bitcast(bitcast(x)) where the outer type equals x's type.
1237 if (BitCastOp in = dyn_cast_or_null<BitCastOp>(op.getInput().getDefiningOp()))
1238 if (op.getType() == in.getInput().getType())
1239 return in.getInput();
/// Fold hook for BitsPrimOp. Visible pieces: a guard for the case where the
/// input type equals the result type and the width is known, and constant
/// evaluation. The value returned by the guard is elided in this view.
1244 OpFoldResult BitsPrimOp::fold(FoldAdaptor adaptor) {
1245 IntType inputType = getInput().getType();
1246 IntType resultType = getType();
1248 if (inputType == getType() && resultType.
hasWidth())
// Constant input: extract `hi - lo + 1` bits starting at bit `lo`.
1255 cst->extractBits(getHi() - getLo() + 1, getLo()));
/// Registers the canonicalization patterns for BitsPrimOp: BitsOfBits,
/// BitsOfMux, BitsOfAsUInt, BitsOfAnd and BitsOfPad. The receiver of the
/// `.insert` call is on a line not visible in this view (presumably
/// `results`).
1260 void BitsPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1261 MLIRContext *context) {
1263 .insert<patterns::BitsOfBits, patterns::BitsOfMux, patterns::BitsOfAsUInt,
1264 patterns::BitsOfAnd, patterns::BitsOfPad>(context);
1271 unsigned loBit, PatternRewriter &rewriter) {
1272 auto resType = type_cast<IntType>(op->getResult(0).getType());
1273 if (type_cast<IntType>(value.getType()).getWidth() != resType.getWidth())
1274 value = rewriter.create<BitsPrimOp>(op->getLoc(), value, hiBit, loBit);
1276 if (resType.isSigned() && !type_cast<IntType>(value.getType()).isSigned()) {
1277 value = rewriter.createOrFold<AsSIntPrimOp>(op->getLoc(), resType, value);
1278 }
else if (resType.isUnsigned() &&
1279 !type_cast<IntType>(value.getType()).isUnsigned()) {
1280 value = rewriter.createOrFold<AsUIntPrimOp>(op->getLoc(), resType, value);
1282 rewriter.replaceOp(op, value);
/// Shared fold implementation for mux-like ops; `OpTy` supplies getSel,
/// getHigh, getLow and its own FoldAdaptor. Visible checks, in order:
/// zero-width result, identical branches, unknown result width, constant
/// condition, and constant branches. Several return statements and the
/// definition of `cond` are elided in this view.
1285 template <
typename OpTy>
1286 static OpFoldResult
foldMux(OpTy op,
typename OpTy::FoldAdaptor adaptor) {
// Zero-width result: a zero-width constant with the result's signedness is
// built; the surrounding return is elided.
1288 if (op.getType().getBitWidthOrSentinel() == 0)
1290 APInt(0, 0, op.getType().isSignedInteger()));
// mux(c, x, x) -> x.
1293 if (op.getHigh() == op.getLow())
1294 return op.getHigh();
// Result width not yet known (negative sentinel).
1299 if (op.getType().getBitWidthOrSentinel() < 0)
// Constant condition: pick a branch, but only if its type already equals the
// result type.
1304 if (cond->isZero() && op.getLow().getType() == op.getType())
1306 if (!cond->isZero() && op.getHigh().getType() == op.getType())
1307 return op.getHigh();
// Both branches constant.
1311 if (
auto lowCst =
getConstant(adaptor.getLow())) {
1313 if (
auto highCst =
getConstant(adaptor.getHigh())) {
// Equal constants of equal width.
1315 if (highCst->getBitWidth() == lowCst->getBitWidth() &&
1316 *highCst == *lowCst)
// mux(sel, 1, 0) where the result type equals the selector type.
1319 if (highCst->isOne() && lowCst->isZero() &&
1320 op.getType() == op.getSel().getType())
/// Fold hook for MuxPrimOp; delegates to the shared foldMux helper.
1333 OpFoldResult MuxPrimOp::fold(FoldAdaptor adaptor) {
1334 return foldMux(*
this, adaptor);
/// Fold hook for Mux2CellIntrinsicOp; delegates to the shared foldMux helper,
/// exactly like MuxPrimOp.
1337 OpFoldResult Mux2CellIntrinsicOp::fold(FoldAdaptor adaptor) {
1338 return foldMux(*
this, adaptor);
/// Fold hook for Mux4CellIntrinsicOp. No folds are implemented: an empty
/// OpFoldResult is always returned and `adaptor` is unused.
1341 OpFoldResult Mux4CellIntrinsicOp::fold(FoldAdaptor adaptor) {
return {}; }
/// Rewrite pattern on MuxPrimOp that pads the high/low operands to the mux
/// result width, so the mux is rebuilt with operands created by PadPrimOp.
/// Access specifiers, some guards and the return statements are elided in this
/// view.
1348 class MuxPad :
public mlir::RewritePattern {
// Match any MuxPrimOp, with pattern benefit 0.
1350 MuxPad(MLIRContext *context)
1351 : RewritePattern(MuxPrimOp::getOperationName(), 0, context) {}
1354 matchAndRewrite(Operation *op,
1355 mlir::PatternRewriter &rewriter)
const override {
1356 auto mux = cast<MuxPrimOp>(op);
1357 auto width = mux.getType().getBitWidthOrSentinel();
// Local helper: leaves inputs of unknown width or already-matching width
// alone (return elided), otherwise wraps the input in a PadPrimOp to `width`.
1361 auto pad = [&](Value input) -> Value {
1363 type_cast<FIRRTLBaseType>(input.getType()).getBitWidthOrSentinel();
1364 if (inputWidth < 0 ||
width == inputWidth)
1367 .create<PadPrimOp>(mux.getLoc(), mux.getType(), input,
width)
1371 auto newHigh = pad(mux.getHigh());
1372 auto newLow = pad(mux.getLow());
// Neither operand needed padding.
1373 if (newHigh == mux.getHigh() && newLow == mux.getLow())
// Rebuild the mux over the padded operands, keeping the selector.
1376 replaceOpWithNewOpAndCopyName<MuxPrimOp>(
1377 rewriter, op, mux.getType(), ValueRange{mux.getSel(), newHigh, newLow},
/// Rewrite pattern on MuxPrimOp that simplifies nested muxes sharing the
/// outer mux's selector: inside the high branch the selector is known true, so
/// an inner mux on the same selector reduces to its high operand, and inside
/// the low branch it reduces to its low operand. The search recurses through
/// nested mux operands up to `depthLimit`. Access specifiers, null checks and
/// return statements are elided in this view.
1385 class MuxSharedCond :
public mlir::RewritePattern {
// Match any MuxPrimOp, with pattern benefit 0.
1387 MuxSharedCond(MLIRContext *context)
1388 : RewritePattern(MuxPrimOp::getOperationName(), 0, context) {}
// Bound on the recursion depth of tryCondTrue / tryCondFalse.
1390 static const int depthLimit = 5;
// Either rewrites `mux`'s high/low operands in place (operand 1 = high,
// operand 2 = low) or creates a fresh mux with the same selector right after
// it, depending on `updateInPlace`.
1392 Value updateOrClone(MuxPrimOp mux, Value high, Value low,
1393 mlir::PatternRewriter &rewriter,
1394 bool updateInPlace)
const {
1395 if (updateInPlace) {
1396 rewriter.modifyOpInPlace(mux, [&] {
1397 mux.setOperand(1, high);
1398 mux.setOperand(2, low);
1402 rewriter.setInsertionPointAfter(mux);
1404 .create<MuxPrimOp>(mux.getLoc(), mux.getType(),
1405 ValueRange{mux.getSel(), high, low})
// Simplifies `op` under the assumption that `cond` is true.
1410 Value tryCondTrue(Value op, Value cond, mlir::PatternRewriter &rewriter,
1411 bool updateInPlace,
int limit)
const {
1412 MuxPrimOp mux = op.getDefiningOp<MuxPrimOp>();
// Same selector: under a true condition the mux is just its high operand.
1415 if (mux.getSel() == cond)
1416 return mux.getHigh();
1417 if (limit > depthLimit)
// In-place mutation is only allowed while every mux on the path has one use.
1419 updateInPlace &= mux->hasOneUse();
// Recurse into the high operand, then the low operand.
1421 if (Value v = tryCondTrue(mux.getHigh(), cond, rewriter, updateInPlace,
1423 return updateOrClone(mux, v, mux.getLow(), rewriter, updateInPlace);
1426 tryCondTrue(mux.getLow(), cond, rewriter, updateInPlace, limit + 1))
1427 return updateOrClone(mux, mux.getHigh(), v, rewriter, updateInPlace);
// Simplifies `op` under the assumption that `cond` is false.
1432 Value tryCondFalse(Value op, Value cond, mlir::PatternRewriter &rewriter,
1433 bool updateInPlace,
int limit)
const {
1434 MuxPrimOp mux = op.getDefiningOp<MuxPrimOp>();
// Same selector: under a false condition the mux is just its low operand.
1437 if (mux.getSel() == cond)
1438 return mux.getLow();
1439 if (limit > depthLimit)
// In-place mutation is only allowed while every mux on the path has one use.
1441 updateInPlace &= mux->hasOneUse();
// Recurse into the high operand, then the low operand.
1443 if (Value v = tryCondFalse(mux.getHigh(), cond, rewriter, updateInPlace,
1445 return updateOrClone(mux, v, mux.getLow(), rewriter, updateInPlace);
1447 if (Value v = tryCondFalse(mux.getLow(), cond, rewriter, updateInPlace,
1449 return updateOrClone(mux, mux.getHigh(), v, rewriter, updateInPlace);
1455 matchAndRewrite(Operation *op,
1456 mlir::PatternRewriter &rewriter)
const override {
1457 auto mux = cast<MuxPrimOp>(op);
1458 auto width = mux.getType().getBitWidthOrSentinel();
// High branch: the selector is true there; install the simplified value.
1462 if (Value v = tryCondTrue(mux.getHigh(), mux.getSel(), rewriter,
true, 0)) {
1463 rewriter.modifyOpInPlace(mux, [&] { mux.setOperand(1, v); });
// Low branch: the selector is false there; install the simplified value.
1467 if (Value v = tryCondFalse(mux.getLow(), mux.getSel(), rewriter,
true, 0)) {
1468 rewriter.modifyOpInPlace(mux, [&] { mux.setOperand(2, v); });
/// Registers the canonicalization patterns for MuxPrimOp into `results`: the
/// MuxPad and MuxSharedCond classes, plus patterns::MuxEQOperands,
/// MuxEQOperandsSwapped, MuxNEQ, MuxNot, MuxSameTrue, MuxSameFalse,
/// NarrowMuxLHS and NarrowMuxRHS.
1477 void MuxPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1478 MLIRContext *context) {
1479 results.add<MuxPad, MuxSharedCond, patterns::MuxEQOperands,
1480 patterns::MuxEQOperandsSwapped, patterns::MuxNEQ,
1481 patterns::MuxNot, patterns::MuxSameTrue, patterns::MuxSameFalse,
1482 patterns::NarrowMuxLHS, patterns::NarrowMuxRHS>(context);
/// Fold hook for PadPrimOp. Visible pieces: a guard for a pad that does not
/// change the type, and constant evaluation that extends the constant to the
/// destination width. The value returned by the guards and the definition of
/// `cst` are elided in this view.
1485 OpFoldResult PadPrimOp::fold(FoldAdaptor adaptor) {
1486 auto input = this->getInput();
// Pad to the same type.
1489 if (input.getType() == getType())
1493 auto inputType = input.getType().base();
1500 auto destWidth = getType().base().getWidthOrSentinel();
// -1 is the sentinel for an unknown destination width.
1501 if (destWidth == -1)
// Sign-extend signed constants that have at least one bit; zero-extend
// everything else (including zero-width signed constants).
1504 if (inputType.
isSigned() && cst->getBitWidth())
1505 return getIntAttr(getType(), cst->sext(destWidth));
1506 return getIntAttr(getType(), cst->zext(destWidth));
/// Fold hook for ShlPrimOp (shift left by a static amount). Visible pieces: a
/// guard for a zero shift, and constant evaluation when the input width is
/// known. The value returned by the guard and the definitions of `cst` and
/// `inputWidth` are elided in this view.
1512 OpFoldResult ShlPrimOp::fold(FoldAdaptor adaptor) {
1513 auto input = this->getInput();
1514 IntType inputType = input.getType();
1515 int shiftAmount = getAmount();
1518 if (shiftAmount == 0)
// -1 is the sentinel for an unknown input width.
1524 if (inputWidth != -1) {
// The result grows by the shift amount: extend first, then shift.
1525 auto resultWidth = inputWidth + shiftAmount;
1526 shiftAmount = std::min(shiftAmount, resultWidth);
1527 return getIntAttr(getType(), cst->zext(resultWidth).shl(shiftAmount));
/// Fold hook for ShrPrimOp (shift right by a static amount). Visible pieces:
/// guards for a zero shift, unknown width and zero width, an unsigned
/// shift-out-everything case, and constant evaluation. The values returned by
/// several guards, the condition choosing between ashr and lshr, and the
/// definitions of `cst`, `inputWidth` and `value` are elided in this view.
1533 OpFoldResult ShrPrimOp::fold(FoldAdaptor adaptor) {
1534 auto input = this->getInput();
1535 IntType inputType = input.getType();
1536 int shiftAmount = getAmount();
1542 if (shiftAmount == 0 && inputWidth > 0)
// -1 is the sentinel for an unknown input width.
1545 if (inputWidth == -1)
1547 if (inputWidth == 0)
// Unsigned value with every bit shifted out: a zero-width zero constant.
1552 if (shiftAmount >= inputWidth && inputType.
isUnsigned())
1553 return getIntAttr(getType(), APInt(0, 0,
false));
// Arithmetic shift is clamped to width-1, logical shift to the full width.
1559 value = cst->ashr(std::min(shiftAmount, inputWidth - 1));
1561 value = cst->lshr(std::min(shiftAmount, inputWidth));
// The truncated result keeps at least one bit.
1562 auto resultWidth = std::max(inputWidth - shiftAmount, 1);
1563 return getIntAttr(getType(), value.trunc(resultWidth));
/// Canonicalizes ShrPrimOp into a bit extraction via replaceWithBits, taking
/// bits [inputWidth-1 : shiftAmount] of the input. Return statements and the
/// unsigned over-shift handling are elided in this view.
1568 LogicalResult ShrPrimOp::canonicalize(ShrPrimOp op, PatternRewriter &rewriter) {
1569 auto inputWidth = op.getInput().getType().base().getWidthOrSentinel();
// Unknown (negative sentinel) or zero input width.
1570 if (inputWidth <= 0)
1574 unsigned shiftAmount = op.getAmount();
// Shifting by at least the full width.
1575 if (
int(shiftAmount) >= inputWidth) {
1577 if (op.getType().base().isUnsigned())
// Otherwise clamp the shift so the top bit is kept.
1583 shiftAmount = inputWidth - 1;
1586 replaceWithBits(op, op.getInput(), inputWidth - 1, shiftAmount, rewriter);
/// Canonicalizes HeadPrimOp into a bit extraction via replaceWithBits: keeping
/// the top `keepAmount` bits is bits [inputWidth-1 : inputWidth-keepAmount].
/// Return statements and the trailing call argument are elided in this view.
1590 LogicalResult HeadPrimOp::canonicalize(HeadPrimOp op,
1591 PatternRewriter &rewriter) {
1592 auto inputWidth = op.getInput().getType().base().getWidthOrSentinel();
// Unknown (negative sentinel) or zero input width.
1593 if (inputWidth <= 0)
1597 unsigned keepAmount = op.getAmount();
1599 replaceWithBits(op, op.getInput(), inputWidth - 1, inputWidth - keepAmount,
/// Fold hook for HeadPrimOp on a constant input: shift right by
/// (input width - amount), then truncate to `amount` bits. The guard
/// establishing `cst` and the declaration receiving the shift amount are
/// elided in this view.
1604 OpFoldResult HeadPrimOp::fold(FoldAdaptor adaptor) {
1608 getInput().getType().base().getWidthOrSentinel() - getAmount();
1609 return getIntAttr(getType(), cst->lshr(shiftAmount).trunc(getAmount()));
/// Fold hook for TailPrimOp on a constant input: the constant is truncated to
/// the result type's width. The guard establishing `cst` and the start of the
/// returning call are elided in this view.
1615 OpFoldResult TailPrimOp::fold(FoldAdaptor adaptor) {
1619 cst->trunc(getType().base().getWidthOrSentinel()));
/// Canonicalization for TailPrimOp. Visible pieces: a guard on unknown or
/// zero input width, and a branch taken when the drop amount differs from the
/// input width. The rewrite performed and the return statements are elided in
/// this view.
1623 LogicalResult TailPrimOp::canonicalize(TailPrimOp op,
1624 PatternRewriter &rewriter) {
1625 auto inputWidth = op.getInput().getType().base().getWidthOrSentinel();
1626 if (inputWidth <= 0)
1630 unsigned dropAmount = op.getAmount();
1631 if (dropAmount !=
unsigned(inputWidth))
/// Registers the single canonicalization pattern for SubaccessOp,
/// SubaccessOfConstant, into `results`.
1637 void SubaccessOp::getCanonicalizationPatterns(RewritePatternSet &results,
1638 MLIRContext *context) {
1639 results.add<patterns::SubaccessOfConstant>(context);
/// Fold hook for MultibitMuxOp. A mux with a single input folds to that input,
/// and a constant in-range index selects the matching input directly. The
/// fallthrough return is elided in this view.
1642 OpFoldResult MultibitMuxOp::fold(FoldAdaptor adaptor) {
// Only one data input: operand 1 is returned (operand 0 is presumably the
// index -- confirm against the op definition).
1644 if (adaptor.getInputs().size() == 1)
1645 return getOperand(1);
1647 if (
auto constIndex =
getConstant(adaptor.getIndex())) {
1648 auto index = constIndex->getZExtValue();
// Inputs are addressed from the back: index 0 selects the last input.
1649 if (index < getInputs().size())
1650 return getInputs()[getInputs().size() - 1 - index];
/// Canonicalizations for MultibitMuxOp. Visible cases: all inputs identical; all
/// inputs being subindex reads of one vector in matching order (becomes a
/// SubaccessOp); and a two-input mux with a 1-bit index (becomes a MuxPrimOp).
/// The rewrite for the first case and all return statements are elided in this
/// view.
1656 LogicalResult MultibitMuxOp::canonicalize(MultibitMuxOp op,
1657 PatternRewriter &rewriter) {
// Every input equals the first one.
1661 if (llvm::all_of(op.getInputs().drop_front(), [&](
auto input) {
1662 return input == op.getInputs().front();
// Every input e at position i is `vec[size - 1 - i]` of the same `vec`:
// the mux is just a dynamic access into that vector.
1670 if (
auto lastSubindex = op.getInputs().back().getDefiningOp<SubindexOp>()) {
1671 if (llvm::all_of(llvm::enumerate(op.getInputs()), [&](
auto e) {
1672 auto subindex = e.value().template getDefiningOp<SubindexOp>();
1673 return subindex && lastSubindex.getInput() == subindex.getInput() &&
1674 subindex.getIndex() + e.index() + 1 == op.getInputs().size();
1676 replaceOpWithNewOpAndCopyName<SubaccessOp>(
1677 rewriter, op, lastSubindex.getInput(), op.getIndex());
// The remaining rewrite only applies to exactly two inputs...
1683 if (op.getInputs().size() != 2)
// ...selected by a 1-bit index.
1687 auto uintType = op.getIndex().getType();
1688 if (uintType.getBitWidthOrSentinel() != 1)
// Index becomes the mux selector; inputs[0] is the high (selector = 1) value.
1692 replaceOpWithNewOpAndCopyName<MuxPrimOp>(
1693 rewriter, op, op.getIndex(), op.getInputs()[0], op.getInputs()[1]);
1713 for (Operation *user : value.getUsers()) {
1715 if (isa<AttachOp, SubfieldOp, SubaccessOp, SubindexOp>(user))
1718 if (
auto aConnect = dyn_cast<FConnectLike>(user))
1719 if (aConnect.getDest() == value) {
1720 auto strictConnect = dyn_cast<StrictConnectOp>(*aConnect);
1724 strictConnect->getBlock() != value.getParentBlock())
1735 PatternRewriter &rewriter) {
1738 Operation *connectedDecl = op.getDest().getDefiningOp();
1743 if (!isa<WireOp>(connectedDecl) && !isa<RegOp>(connectedDecl))
1748 cast<Forceable>(connectedDecl).isForceable())
1756 if (connectedDecl->hasOneUse())
1760 auto *declBlock = connectedDecl->getBlock();
1761 auto *srcValueOp = op.getSrc().getDefiningOp();
1764 if (!isa<WireOp>(connectedDecl))
1770 if (!isa<ConstantOp>(srcValueOp))
1772 if (srcValueOp->getBlock() != declBlock)
1778 auto replacement = op.getSrc();
1781 if (srcValueOp && srcValueOp != &declBlock->front())
1782 srcValueOp->moveBefore(&declBlock->front());
1789 rewriter.eraseOp(op);
/// Registers the canonicalization patterns for ConnectOp into `results`:
/// ConnectExtension and ConnectSameType. The argument line of the call is not
/// visible in this view.
1793 void ConnectOp::getCanonicalizationPatterns(RewritePatternSet &results,
1794 MLIRContext *context) {
1795 results.insert<patterns::ConnectExtension, patterns::ConnectSameType>(
1799 LogicalResult StrictConnectOp::canonicalize(StrictConnectOp op,
1800 PatternRewriter &rewriter) {
1817 for (
auto *user : value.getUsers()) {
1818 auto attach = dyn_cast<AttachOp>(user);
1819 if (!attach || attach == dominatedAttach)
1821 if (attach->isBeforeInBlock(dominatedAttach))
/// Canonicalizations for AttachOp. Visible cases: an attach with at most one
/// operand is erased; an attach sharing an operand with another attach is
/// merged with it; and an operand that is an otherwise-unused wire is dropped.
/// How the other `attach` is located and the return statements are elided in
/// this view.
1827 LogicalResult AttachOp::canonicalize(AttachOp op, PatternRewriter &rewriter) {
// Zero or one operand: nothing is attached to anything.
1829 if (op.getNumOperands() <= 1) {
1830 rewriter.eraseOp(op);
1834 for (
auto operand : op.getOperands()) {
// Merge: start from this op's operands and append the other attach's
// operands except the shared one, then replace both ops by one new attach.
1841 SmallVector<Value> newOperands(op.getOperands());
1842 for (
auto newOperand : attach.getOperands())
1843 if (newOperand != operand)
1844 newOperands.push_back(newOperand);
1845 rewriter.create<AttachOp>(op->getLoc(), newOperands);
1846 rewriter.eraseOp(attach);
1847 rewriter.eraseOp(op);
// Wire operand that is not marked don't-touch, not forceable, and whose only
// use is this attach: rebuild the attach without it and erase the wire.
1855 if (
auto wire = dyn_cast_or_null<WireOp>(operand.getDefiningOp())) {
1856 if (!
hasDontTouch(wire.getOperation()) && wire->hasOneUse() &&
1857 !wire.isForceable()) {
1858 SmallVector<Value> newOperands;
1859 for (
auto newOperand : op.getOperands())
1860 if (newOperand != operand)
1861 newOperands.push_back(newOperand);
1863 rewriter.create<AttachOp>(op->getLoc(), newOperands);
1864 rewriter.eraseOp(op);
1865 rewriter.eraseOp(wire);
// The region must contain exactly one block; the message now matches the
// condition being checked (a single-block region).
1876 assert(llvm::hasSingleElement(region) &&
"expected single-block region");
// Splice that block's operations in front of `op`, passing no block-argument
// replacement values. (`&region` was corrupted to an HTML entity in the
// listing and is restored here.)
1877 rewriter.inlineBlockBefore(&region.front(), op, {});
/// Canonicalizes a WhenOp. The visible cases handle a constant condition, an
/// empty else block, and an op whose blocks are empty. The statements executed
/// in several branches are missing from this listing.
1880 LogicalResult WhenOp::canonicalize(WhenOp op, PatternRewriter &rewriter) {
// Constant condition: branch on whether the constant is all ones, then erase
// the when op.
1881 if (
auto constant = op.getCondition().getDefiningOp<firrtl::ConstantOp>()) {
1882 if (constant.getValue().isAllOnes())
1884 else if (op.hasElseRegion() && !op.getElseRegion().empty())
1887 rewriter.eraseOp(op);
// Non-empty then block with an else block that exists but is empty: erase the
// else block.
1893 if (!op.getThenBlock().empty() && op.hasElseRegion() &&
1894 op.getElseBlock().empty()) {
1895 rewriter.eraseBlock(&op.getElseBlock());
1902 if (!op.getThenBlock().empty())
// No else region, or an empty else block: erase the op.
1906 if (!op.hasElseRegion() || op.getElseBlock().empty()) {
1907 rewriter.eraseOp(op);
/// Rewrite pattern rooted on NodeOp (benefit 0) that replaces a node with its
/// input value. Part of the guard condition and the statement under the
/// `newOp` check are missing from this listing.
1916 struct FoldNodeName :
public mlir::RewritePattern {
1917 FoldNodeName(MLIRContext *context)
1918 : RewritePattern(NodeOp::getOperationName(), 0, context) {}
1919 LogicalResult matchAndRewrite(Operation *op,
1920 PatternRewriter &rewriter)
const override {
1921 auto node = cast<NodeOp>(op);
1922 auto name = node.getNameAttr();
// Guard: involves a non-droppable name or the presence of an inner symbol
// (the rest of the condition is not shown).
1923 if (!node.hasDroppableName() || node.getInnerSym() ||
// Op defining the node's input; null when the input has no defining op.
1926 auto *newOp = node.getInput().getDefiningOp();
// Instance ops are excluded from whatever is done with `newOp` here.
1928 if (newOp && !isa<InstanceOp>(newOp))
1930 rewriter.replaceOp(node, node.getInput());
/// Rewrite pattern rooted on NodeOp (benefit 0) that redirects all uses of a
/// node's result to the node's input, leaving the node itself in place.
1936 struct NodeBypass :
public mlir::RewritePattern {
1937 NodeBypass(MLIRContext *context)
1938 : RewritePattern(NodeOp::getOperationName(), 0, context) {}
1939 LogicalResult matchAndRewrite(Operation *op,
1940 PatternRewriter &rewriter)
const override {
1941 auto node = cast<NodeOp>(op);
// Guard: inner symbol present, annotations that cannot be deleted, no uses,
// or a forceable node (the guarded statement is not shown).
1942 if (node.getInnerSym() || !
AnnotationSet(node).canBeDeleted() ||
1943 node.use_empty() || node.isForceable())
// The use replacement is bracketed by start/finalize notifications on the
// rewriter for `node`.
1945 rewriter.startOpModification(node);
1946 node.getResult().replaceAllUsesWith(node.getInput());
1947 rewriter.finalizeOpModification(node);
/// Templated helper over a forceable op type; it is registered below as
/// demoteForceableIfUnused<...>. Only the guard is visible in this listing.
1954 template <
typename OpTy>
1956 PatternRewriter &rewriter) {
// Guard: the op is not forceable, or its data reference still has uses.
1957 if (!op.isForceable() || !op.getDataRef().use_empty())
/// Multi-result fold hook for NodeOp. On the visible path the constant input
/// attribute is appended to `results`. Earlier guards are only partly shown.
1965 LogicalResult NodeOp::fold(FoldAdaptor adaptor,
1966 SmallVectorImpl<OpFoldResult> &results) {
// Guard involving the presence of an annotations attribute (remainder of the
// condition is missing).
1971 if (getAnnotationsAttr() &&
// Guard on the input not being a known constant attribute.
1976 if (!adaptor.getInput())
1979 results.push_back(adaptor.getInput());
/// Registers canonicalizations for NodeOp: the FoldNodeName pattern and the
/// demoteForceableIfUnused<NodeOp> function pattern.
1983 void NodeOp::getCanonicalizationPatterns(RewritePatternSet &results,
1984 MLIRContext *context) {
1985 results.insert<FoldNodeName>(context);
1986 results.add(demoteForceableIfUnused<NodeOp>);
/// Base rewrite pattern that looks for an aggregate (bundle or vector) value
/// whose elements are each written by a separate StrictConnectOp, and replaces
/// those element writes with one connect of a freshly created aggregate.
/// Subclasses choose the root operation name. Some statements are missing
/// from this listing.
1992 struct AggOneShot :
public mlir::RewritePattern {
/// `name` is the root operation name and `weight` the pattern benefit.
/// `weight` is now forwarded to the base class instead of being ignored; all
/// subclasses in this file pass 0, so their behavior is unchanged.
1993 AggOneShot(StringRef name, uint32_t weight, MLIRContext *context)
1994 : RewritePattern(name, weight, context) {}
/// Collects, per element index, the source value connected to each subfield /
/// subindex of `lhs`'s first result.
1996 SmallVector<Value> getCompleteWrite(Operation *lhs)
const {
1997 auto lhsTy = lhs->getResult(0).getType();
// Only bundle and vector types are considered.
1998 if (!type_isa<BundleType, FVectorType>(lhsTy))
// Element index -> value written to that element.
2001 DenseMap<uint32_t, Value> fields;
2002 for (Operation *user : lhs->getResult(0).getUsers()) {
// Test for users whose parent op differs from `lhs`'s parent op.
2003 if (user->getParentOp() != lhs->getParentOp())
2005 if (
auto aConnect = dyn_cast<StrictConnectOp>(user)) {
// Test for a connect that writes the whole aggregate directly.
2006 if (aConnect.getDest() == lhs->getResult(0))
2008 }
else if (
auto subField = dyn_cast<SubfieldOp>(user)) {
2009 for (Operation *subuser : subField.getResult().getUsers()) {
2010 if (
auto aConnect = dyn_cast<StrictConnectOp>(subuser)) {
2011 if (aConnect.getDest() == subField) {
2012 if (subuser->getParentOp() != lhs->getParentOp())
// Test for a field that already has a recorded write.
2014 if (fields.count(subField.getFieldIndex()))
2016 fields[subField.getFieldIndex()] = aConnect.getSrc();
2022 }
else if (
auto subIndex = dyn_cast<SubindexOp>(user)) {
2023 for (Operation *subuser : subIndex.getResult().getUsers()) {
2024 if (
auto aConnect = dyn_cast<StrictConnectOp>(subuser)) {
2025 if (aConnect.getDest() == subIndex) {
2026 if (subuser->getParentOp() != lhs->getParentOp())
// Test for an index that already has a recorded write.
2028 if (fields.count(subIndex.getIndex()))
2030 fields[subIndex.getIndex()] = aConnect.getSrc();
// Emit the collected values in element order; the number of elements comes
// from the bundle or vector type.
2041 SmallVector<Value> values;
2042 uint32_t total = type_isa<BundleType>(lhsTy)
2043 ? type_cast<BundleType>(lhsTy).getNumElements()
2044 : type_cast<FVectorType>(lhsTy).getNumElements();
2045 for (uint32_t i = 0; i < total; ++i) {
// Test for an element that has no recorded write.
2046 if (!fields.count(i))
2048 values.push_back(fields[i]);
2053 LogicalResult matchAndRewrite(Operation *op,
2054 PatternRewriter &rewriter)
const override {
2055 auto values = getCompleteWrite(op);
// New ops are created at the end of the block that contains `op`.
2058 rewriter.setInsertionPointToEnd(op->getBlock());
2059 auto dest = op->getResult(0);
2060 auto destType = dest.getType();
// Test for destination types that are not passive.
2063 if (!type_cast<FIRRTLBaseType>(destType).isPassive())
// Build the aggregate value: a bundle create or a vector create.
2066 Value newVal = type_isa<BundleType>(destType)
2067 ? rewriter.createOrFold<BundleCreateOp>(op->getLoc(),
2069 : rewriter.createOrFold<VectorCreateOp>(
2070 op->getLoc(), destType, values);
// Single connect of the whole aggregate.
2071 rewriter.createOrFold<StrictConnectOp>(op->getLoc(), dest, newVal);
// Erase the per-element connects; early-inc ranges are used because users are
// erased while iterating.
2072 for (Operation *user : dest.getUsers()) {
2073 if (
auto subIndex = dyn_cast<SubindexOp>(user)) {
2074 for (Operation *subuser :
2075 llvm::make_early_inc_range(subIndex.getResult().getUsers()))
2076 if (
auto aConnect = dyn_cast<StrictConnectOp>(subuser))
2077 if (aConnect.getDest() == subIndex)
2078 rewriter.eraseOp(aConnect);
2079 }
else if (
auto subField = dyn_cast<SubfieldOp>(user)) {
2080 for (Operation *subuser :
2081 llvm::make_early_inc_range(subField.getResult().getUsers()))
2082 if (
auto aConnect = dyn_cast<StrictConnectOp>(subuser))
2083 if (aConnect.getDest() == subField)
2084 rewriter.eraseOp(aConnect);
/// AggOneShot specialization rooted on WireOp, constructed with weight 0.
2091 struct WireAggOneShot :
public AggOneShot {
2092 WireAggOneShot(MLIRContext *context)
2093 : AggOneShot(WireOp::getOperationName(), 0, context) {}
/// AggOneShot specialization rooted on SubindexOp, constructed with weight 0.
2095 struct SubindexAggOneShot :
public AggOneShot {
2096 SubindexAggOneShot(MLIRContext *context)
2097 : AggOneShot(SubindexOp::getOperationName(), 0, context) {}
/// AggOneShot specialization rooted on SubfieldOp, constructed with weight 0.
2099 struct SubfieldAggOneShot :
public AggOneShot {
2100 SubfieldAggOneShot(MLIRContext *context)
2101 : AggOneShot(SubfieldOp::getOperationName(), 0, context) {}
/// Registers canonicalizations for WireOp: the WireAggOneShot pattern and the
/// demoteForceableIfUnused<WireOp> function pattern.
2105 void WireOp::getCanonicalizationPatterns(RewritePatternSet &results,
2106 MLIRContext *context) {
2107 results.insert<WireAggOneShot>(context);
2108 results.add(demoteForceableIfUnused<WireOp>);
/// Registers the SubindexAggOneShot pattern for SubindexOp.
2111 void SubindexOp::getCanonicalizationPatterns(RewritePatternSet &results,
2112 MLIRContext *context) {
2113 results.insert<SubindexAggOneShot>(context);
/// Folds a subindex of a constant aggregate: when the input folds to an
/// ArrayAttr, the element at this op's index is returned. The handling of a
/// non-array input sits on lines not shown here.
2116 OpFoldResult SubindexOp::fold(FoldAdaptor adaptor) {
// Null when the input is not constant or is not an ArrayAttr.
2117 auto attr = adaptor.getInput().dyn_cast_or_null<ArrayAttr>();
2120 return attr[getIndex()];
/// Fold hook for SubfieldOp. The visible lines read the input as an ArrayAttr
/// and fetch this op's field index; the return statements are not shown.
2123 OpFoldResult SubfieldOp::fold(FoldAdaptor adaptor) {
// Null when the input is not constant or is not an ArrayAttr.
2124 auto attr = adaptor.getInput().dyn_cast_or_null<ArrayAttr>();
2127 auto index = getFieldIndex();
/// Registers the SubfieldAggOneShot pattern for SubfieldOp.
2131 void SubfieldOp::getCanonicalizationPatterns(RewritePatternSet &results,
2132 MLIRContext *context) {
2133 results.insert<SubfieldAggOneShot>(context);
// Second line of a helper's signature followed by a loop over the operand
// attributes; the per-operand check and the return value are not shown.
2137 ArrayRef<Attribute> operands) {
2138 for (
auto operand : operands)
/// Fold hook for BundleCreateOp. The visible case recognizes a bundle rebuilt
/// field-by-field from subfields of one value of the same type and returns
/// that value. Other folds, if any, are on lines not shown.
2144 OpFoldResult BundleCreateOp::fold(FoldAdaptor adaptor) {
2147 if (getNumOperands() > 0)
// The first operand must be field 0 of a value whose type equals the result
// type.
2148 if (SubfieldOp first = getOperand(0).getDefiningOp<SubfieldOp>())
2149 if (first.getFieldIndex() == 0 &&
2150 first.getInput().getType() == getType() &&
// Every remaining operand i must be field i of that same input.
2152 llvm::drop_begin(llvm::enumerate(getOperands())), [&](
auto elem) {
2154 elem.value().
template getDefiningOp<SubfieldOp>();
2155 return subindex && subindex.getInput() == first.getInput() &&
2156 subindex.getFieldIndex() == elem.index();
2158 return first.getInput();
/// Fold hook for VectorCreateOp. The visible case recognizes a vector rebuilt
/// element-by-element from subindexes of one value of the same type and
/// returns that value. Other folds, if any, are on lines not shown.
2163 OpFoldResult VectorCreateOp::fold(FoldAdaptor adaptor) {
2166 if (getNumOperands() > 0)
// The first operand must be index 0 of a value whose type equals the result
// type.
2167 if (SubindexOp first = getOperand(0).getDefiningOp<SubindexOp>())
2168 if (first.getIndex() == 0 && first.getInput().getType() == getType() &&
// Every remaining operand i must be index i of that same input.
2170 llvm::drop_begin(llvm::enumerate(getOperands())), [&](
auto elem) {
2172 elem.value().
template getDefiningOp<SubindexOp>();
2173 return subindex && subindex.getInput() == first.getInput() &&
2174 subindex.getIndex() == elem.index();
2176 return first.getInput();
/// Folds a cast whose operand already has the result type to the operand
/// itself. The fall-through return is not shown in this listing.
2181 OpFoldResult UninferredResetCastOp::fold(FoldAdaptor adaptor) {
2182 if (getOperand().getType() == getType())
2183 return getOperand();
/// Rewrite pattern rooted on RegResetOp (benefit 0). It looks at the mux that
/// drives the register's connect and at a constant reset value. Several lines
/// (including the definitions of `reset` and `con` and the final rewrite) are
/// missing from this listing.
2190 struct FoldResetMux :
public mlir::RewritePattern {
2191 FoldResetMux(MLIRContext *context)
2192 : RewritePattern(RegResetOp::getOperationName(), 0, context) {}
2193 LogicalResult matchAndRewrite(Operation *op,
2194 PatternRewriter &rewriter)
const override {
2195 auto reg = cast<RegResetOp>(op);
// The reset value's defining op, if it is a ConstantOp.
2197 dyn_cast_or_null<ConstantOp>(
reg.getResetValue().getDefiningOp());
// The connect source must be produced by a mux.
2206 auto mux = dyn_cast_or_null<MuxPrimOp>(con.getSrc().getDefiningOp());
2209 auto *high = mux.getHigh().getDefiningOp();
2210 auto *low = mux.getLow().getDefiningOp();
2211 auto constOp = dyn_cast_or_null<ConstantOp>(high);
// Test: constant on the high side but the low side is not the register.
2213 if (constOp && low !=
reg)
// Mirror case: constant on the low side and the register on the high side.
2215 if (dyn_cast_or_null<ConstantOp>(low) && high ==
reg)
2216 constOp = dyn_cast<ConstantOp>(low);
// The mux constant must match the reset constant in both type and value.
2218 if (!constOp || constOp.getType() != reset.getType() ||
2219 constOp.getValue() != reset.getValue())
// All involved values must have the register's type, and that type must have
// a non-negative bit width.
2223 auto regTy =
reg.getResult().getType();
2224 if (con.getDest().getType() != regTy || con.getSrc().getType() != regTy ||
2225 mux.getHigh().getType() != regTy || mux.getLow().getType() != regTy ||
2226 regTy.getBitWidthOrSentinel() < 0)
// Move the constant to the start of the connect's block unless it is already
// the first op there.
2232 if (constOp != &con->getBlock()->front())
2233 constOp->moveBefore(&con->getBlock()->front());
2238 rewriter.eraseOp(con);
// A ConstantOp counts when its integer value is one.
2245 if (
auto c = v.getDefiningOp<ConstantOp>())
2246 return c.getValue().isOne();
// A SpecialConstantOp yields its own value directly.
2247 if (
auto sc = v.getDefiningOp<SpecialConstantOp>())
2248 return sc.getValue();
// Static helper returning LogicalResult (its name and guard are missing from
// this listing). The visible rewrite replaces `reg` with a NodeOp fed by the
// register's reset value, carrying over the register's name, name kind,
// annotations, inner symbol and forceable flag.
2252 static LogicalResult
2259 replaceOpWithNewOpAndCopyName<NodeOp>(
2260 rewriter,
reg,
reg.getResetValue(),
reg.getNameAttr(),
reg.getNameKind(),
2261 reg.getAnnotationsAttr(),
reg.getInnerSymAttr(),
reg.getForceable());
/// Registers canonicalizations for RegResetOp: patterns::RegResetWithZeroReset,
/// FoldResetMux, and demoteForceableIfUnused<RegResetOp>. A line between the
/// two visible registrations is missing from this listing.
2265 void RegResetOp::getCanonicalizationPatterns(RewritePatternSet &results,
2266 MLIRContext *context) {
2267 results.add<patterns::RegResetWithZeroReset, FoldResetMux>(context);
2269 results.add(demoteForceableIfUnused<RegResetOp>);
// Resolve the index of field `name` in the port's bundle type; the field is
// required to exist.
2274 auto portTy = type_cast<BundleType>(port.getType());
2275 auto fieldIndex = portTy.getElementIndex(name);
2276 assert(fieldIndex &&
"missing field on memory port");
// Every user of the port is cast to SubfieldOp (cast<> asserts otherwise).
2279 for (
auto *op : port.getUsers()) {
2280 auto portAccess = cast<SubfieldOp>(op);
// Test for accesses to a different field.
2281 if (fieldIndex != portAccess.getFieldIndex())
// Record the source of a connect (`conn` is defined on lines not shown).
2286 value = conn.getSrc();
// Fragment of a helper: the result is whether `value` is defined by a
// ConstantOp whose value is zero. The null check on `portConst` is presumably
// on a line not shown here — TODO confirm.
2296 auto portConst = value.getDefiningOp<ConstantOp>();
2299 return portConst.getValue().isZero();
// Resolve the index of the requested field in the port's bundle type. The
// assert message now describes the generic lookup (the old text mentioned
// the enable flag, which is not what is being looked up), matching the
// sibling port helpers.
2304 auto portTy = type_cast<BundleType>(port.getType());
2305 auto fieldIndex = portTy.getElementIndex(
data);
2306 assert(fieldIndex &&
"missing field on memory port");
// Every user of the port is cast to SubfieldOp (cast<> asserts otherwise).
2308 for (
auto *op : port.getUsers()) {
2309 auto portAccess = cast<SubfieldOp>(op);
// Test for accesses to a different field.
2310 if (fieldIndex != portAccess.getFieldIndex())
// Test for an access to the requested field that still has uses.
2312 if (!portAccess.use_empty())
// Tail of a helper's signature and its body: every subfield access to field
// `name` of `port` has its uses redirected to `value` and is then erased.
2321 StringRef name, Value value) {
2322 auto portTy = type_cast<BundleType>(port.getType());
2323 auto fieldIndex = portTy.getElementIndex(name);
2324 assert(fieldIndex &&
"missing field on memory port");
// Early-inc iteration because accesses are erased inside the loop.
2326 for (
auto *op : llvm::make_early_inc_range(port.getUsers())) {
2327 auto portAccess = cast<SubfieldOp>(op);
// Test for accesses to a different field.
2328 if (fieldIndex != portAccess.getFieldIndex())
2330 rewriter.replaceAllUsesWith(portAccess, value);
2331 rewriter.eraseOp(portAccess);
/// Removes all uses of a memory port. Connect users of the port's subfields
/// are erased and remaining subfield uses are redirected to newly created
/// registers. Several lines are missing from this listing.
2336 static void erasePort(PatternRewriter &rewriter, Value port) {
// Accessor that creates a SpecialConstantOp to serve as the clock of the
// replacement registers (its remaining lines are not shown).
2339 auto getClock = [&] {
2341 clock = rewriter.create<SpecialConstantOp>(
// First pass over the port's users: a RegOp of the port's type is created and
// the port's uses are redirected to it (the triggering condition is on a line
// not shown).
2350 for (
auto *op : port.getUsers()) {
2351 auto subfield = dyn_cast<SubfieldOp>(op);
2353 auto ty = port.getType();
2354 auto reg = rewriter.create<RegOp>(port.getLoc(), ty, getClock());
2355 port.replaceAllUsesWith(
reg.getResult());
// Second pass: handle each subfield access; early-inc ranges are used because
// ops are erased while iterating.
2364 for (
auto *accessOp : llvm::make_early_inc_range(port.getUsers())) {
2365 auto access = cast<SubfieldOp>(accessOp);
2366 for (
auto *user : llvm::make_early_inc_range(access->getUsers())) {
2367 auto connect = dyn_cast<FConnectLike>(user);
2369 rewriter.eraseOp(user);
// An access left without uses is simply erased.
2373 if (access.use_empty()) {
2374 rewriter.eraseOp(access);
// Otherwise the access is replaced by a register of the same type.
2380 auto ty = access.getType();
2381 auto reg = rewriter.create<RegOp>(access.getLoc(), ty, getClock());
2382 rewriter.replaceOp(access,
reg.getResult());
2384 assert(port.use_empty() &&
"port should have no remaining uses");
/// Rewrite pattern rooted on MemOp (benefit 0) for memories whose data type is
/// an integer of bit width 0. Each port subfield is replaced by a wire and the
/// memory is erased. Some lines are missing from this listing.
2389 struct FoldZeroWidthMemory :
public mlir::RewritePattern {
2390 FoldZeroWidthMemory(MLIRContext *context)
2391 : RewritePattern(MemOp::getOperationName(), 0, context) {}
2392 LogicalResult matchAndRewrite(Operation *op,
2393 PatternRewriter &rewriter)
const override {
2394 MemOp mem = cast<MemOp>(op);
// Guard: data type is not an IntType, or its bit width is not 0.
2398 if (!firrtl::type_isa<IntType>(mem.getDataType()) ||
2399 mem.getDataType().getBitWidthOrSentinel() != 0)
// Check for any port user that is not a SubfieldOp.
2403 for (
auto port : mem.getResults())
2404 for (
auto *user : port.getUsers())
2405 if (!isa<SubfieldOp>(user))
// Replace every subfield with a wire of the same type (name is propagated by
// the helper); early-inc iteration because users are replaced in the loop.
2410 for (
auto port : op->getResults()) {
2411 for (
auto *user : llvm::make_early_inc_range(port.getUsers())) {
2412 SubfieldOp sfop = cast<SubfieldOp>(user);
2413 StringRef fieldName = sfop.getFieldName();
2414 auto wire = replaceOpWithNewOpAndCopyName<WireOp>(
2415 rewriter, sfop, sfop.getResult().getType())
// Fields whose name ends in "data" additionally get a constant connected to
// the wire (the constant's value argument is on a line not shown).
2417 if (fieldName.ends_with(
"data")) {
2419 auto zero = rewriter.create<firrtl::ConstantOp>(
2420 wire.getLoc(), firrtl::type_cast<IntType>(wire.getType()),
2422 rewriter.create<StrictConnectOp>(wire.getLoc(), wire, zero);
2426 rewriter.eraseOp(op);
/// Rewrite pattern rooted on MemOp (benefit 0) that classifies the memory's
/// ports by kind and, on the visible path, erases the memory. The per-case
/// statements and the per-port loop body are missing from this listing.
2432 struct FoldReadOrWriteOnlyMemory :
public mlir::RewritePattern {
2433 FoldReadOrWriteOnlyMemory(MLIRContext *context)
2434 : RewritePattern(MemOp::getOperationName(), 0, context) {}
2435 LogicalResult matchAndRewrite(Operation *op,
2436 PatternRewriter &rewriter)
const override {
2437 MemOp mem = cast<MemOp>(op);
// Flags tracking whether any port reads or writes the memory.
2440 bool isRead =
false, isWritten =
false;
2441 for (
unsigned i = 0; i < mem.getNumResults(); ++i) {
2442 switch (mem.getPortKind(i)) {
2443 case MemOp::PortKind::Read:
2448 case MemOp::PortKind::Write:
2453 case MemOp::PortKind::Debug:
2454 case MemOp::PortKind::ReadWrite:
2457 llvm_unreachable(
"unknown port kind");
// By this point the memory must not be both read and written.
2459 assert((!isWritten || !isRead) &&
"memory is in use");
// Guard on a read memory that has an initializer.
2464 if (isRead && mem.getInit())
2467 for (
auto port : mem.getResults())
2470 rewriter.eraseOp(op);
/// Rewrite pattern rooted on MemOp (benefit 0) that identifies dead ports,
/// rebuilds the memory with the remaining ports, and erases the original.
/// Several lines (including where `deadPorts` bits are set) are missing from
/// this listing.
2476 struct FoldUnusedPorts :
public mlir::RewritePattern {
2477 FoldUnusedPorts(MLIRContext *context)
2478 : RewritePattern(MemOp::getOperationName(), 0, context) {}
2479 LogicalResult matchAndRewrite(Operation *op,
2480 PatternRewriter &rewriter)
const override {
2481 MemOp mem = cast<MemOp>(op);
// One bit per port result.
2485 llvm::SmallBitVector deadPorts(mem.getNumResults());
2486 for (
auto [i, port] : llvm::enumerate(mem.getResults())) {
// Test for ports that carry annotations.
2488 if (!mem.getPortAnnotation(i).empty())
// Test for debug ports.
2492 auto kind = mem.getPortKind(i);
2493 if (kind == MemOp::PortKind::Debug)
// Read ports whose "data" field is unused.
2502 if (kind == MemOp::PortKind::Read &&
isPortUnused(port,
"data")) {
// Test for the case where no port was marked dead.
2507 if (deadPorts.none())
// Gather type, name and annotation of the ports being kept.
2511 SmallVector<Type> resultTypes;
2512 SmallVector<StringRef> portNames;
2513 SmallVector<Attribute> portAnnotations;
2514 for (
auto [i, port] : llvm::enumerate(mem.getResults())) {
2517 resultTypes.push_back(port.getType());
2518 portNames.push_back(mem.getPortName(i));
2519 portAnnotations.push_back(mem.getPortAnnotation(i));
// A replacement memory is only created when at least one port remains.
2523 if (!resultTypes.empty())
2524 newOp = rewriter.create<MemOp>(
2525 mem.getLoc(), resultTypes, mem.getReadLatency(),
2526 mem.getWriteLatency(), mem.getDepth(), mem.getRuw(),
2527 rewriter.getStrArrayAttr(portNames), mem.getName(), mem.getNameKind(),
2528 mem.getAnnotations(), rewriter.getArrayAttr(portAnnotations),
2529 mem.getInnerSymAttr(), mem.getInitAttr(), mem.getPrefixAttr());
// Redirect surviving ports to consecutive results of the new memory.
2532 unsigned nextPort = 0;
2533 for (
auto [i, port] : llvm::enumerate(mem.getResults())) {
2537 port.replaceAllUsesWith(newOp.getResult(nextPort++));
2540 rewriter.eraseOp(op);
/// Rewrite pattern rooted on MemOp (benefit 0) that rebuilds a memory so that
/// selected read-write ports become write ports, remapping the subfield
/// accesses of each converted port. Several lines (including where
/// `deadReads` bits are set) are missing from this listing.
2546 struct FoldReadWritePorts :
public mlir::RewritePattern {
2547 FoldReadWritePorts(MLIRContext *context)
2548 : RewritePattern(MemOp::getOperationName(), 0, context) {}
2549 LogicalResult matchAndRewrite(Operation *op,
2550 PatternRewriter &rewriter)
const override {
2551 MemOp mem = cast<MemOp>(op);
// One bit per port result.
2556 llvm::SmallBitVector deadReads(mem.getNumResults());
2557 for (
auto [i, port] : llvm::enumerate(mem.getResults())) {
// Test for ports that are not read-write.
2558 if (mem.getPortKind(i) != MemOp::PortKind::ReadWrite)
// Test for ports that carry annotations.
2560 if (!mem.getPortAnnotation(i).empty())
// Test for the case where no port was marked.
2567 if (deadReads.none())
// Build the new port list: a write-port type or the original port type is
// pushed per port (the selecting condition is on lines not shown).
2570 SmallVector<Type> resultTypes;
2571 SmallVector<StringRef> portNames;
2572 SmallVector<Attribute> portAnnotations;
2573 for (
auto [i, port] : llvm::enumerate(mem.getResults())) {
2575 resultTypes.push_back(
2576 MemOp::getTypeForPort(mem.getDepth(), mem.getDataType(),
2577 MemOp::PortKind::Write, mem.getMaskBits()));
2579 resultTypes.push_back(port.getType());
2581 portNames.push_back(mem.getPortName(i));
2582 portAnnotations.push_back(mem.getPortAnnotation(i));
// New memory with the same attributes and the updated port types.
2585 auto newOp = rewriter.create<MemOp>(
2586 mem.getLoc(), resultTypes, mem.getReadLatency(), mem.getWriteLatency(),
2587 mem.getDepth(), mem.getRuw(), rewriter.getStrArrayAttr(portNames),
2588 mem.getName(), mem.getNameKind(), mem.getAnnotations(),
2589 rewriter.getArrayAttr(portAnnotations), mem.getInnerSymAttr(),
2590 mem.getInitAttr(), mem.getPrefixAttr());
2592 for (
unsigned i = 0, n = mem.getNumResults(); i < n; ++i) {
2593 auto result = mem.getResult(i);
2594 auto newResult = newOp.getResult(i);
2596 auto resultPortTy = type_cast<BundleType>(result.getType());
// Redirects accesses of field `fromName` on the old port to a new subfield
// (`toName`) on the new port. The assert message describes the generic field
// lookup; it previously mentioned the enable flag regardless of field.
2600 auto replace = [&](StringRef toName, StringRef fromName) {
2601 auto fromFieldIndex = resultPortTy.getElementIndex(fromName);
2602 assert(fromFieldIndex &&
"missing field on memory port");
2604 auto toField = rewriter.create<SubfieldOp>(newResult.getLoc(),
// Early-inc iteration because accesses are replaced inside the loop.
2606 for (
auto *op : llvm::make_early_inc_range(result.getUsers())) {
2607 auto fromField = cast<SubfieldOp>(op);
2608 if (fromFieldIndex != fromField.getFieldIndex())
2610 rewriter.replaceOp(fromField, toField.getResult());
// Field mapping from the old port to the new port; the data and mask fields
// are renamed from wdata/wmask.
2614 replace(
"addr",
"addr");
2615 replace(
"en",
"en");
2616 replace(
"clk",
"clk");
2617 replace(
"data",
"wdata");
2618 replace(
"mask",
"wmask");
// Accesses to `wmode` have no counterpart and are replaced by wires.
2621 auto wmodeFieldIndex = resultPortTy.getElementIndex(
"wmode");
2622 for (
auto *op : llvm::make_early_inc_range(result.getUsers())) {
2623 auto wmodeField = cast<SubfieldOp>(op);
2624 if (wmodeFieldIndex != wmodeField.getFieldIndex())
2626 rewriter.replaceOpWithNewOp<WireOp>(wmodeField, wmodeField.getType());
// Whole-port replacement (the branch this belongs to is on lines not shown).
2629 result.replaceAllUsesWith(newResult);
2632 rewriter.eraseOp(op);
/// Rewrite pattern rooted on MemOp (benefit 0) that narrows a memory's data
/// width to the bit ranges actually extracted by readers, rewriting the reads
/// and writes to match. Several lines are missing from this listing.
2638 struct FoldUnusedBits :
public mlir::RewritePattern {
2639 FoldUnusedBits(MLIRContext *context)
2640 : RewritePattern(MemOp::getOperationName(), 0, context) {}
2642 LogicalResult matchAndRewrite(Operation *op,
2643 PatternRewriter &rewriter)
const override {
2644 MemOp mem = cast<MemOp>(op);
// Guard on memories whose summary reports masking or a sequential memory.
2649 const auto &summary = mem.getSummary();
2650 if (summary.isMasked || summary.isSeqMem())
// Data type as an integer type (null if it is not one).
2653 auto type = type_dyn_cast<IntType>(mem.getDataType());
2656 auto width = type.getBitWidthOrSentinel();
// `usedBits` marks which data bits are read; `mapping` is keyed by the low
// bit of each read and later holds that bit's position in the new data type.
2660 llvm::SmallBitVector usedBits(
width);
2661 DenseMap<unsigned, unsigned> mapping;
// Collects the BitsPrimOp users of the given read-data field.
2666 SmallVector<BitsPrimOp> readOps;
2667 auto findReadUsers = [&](Value port, StringRef field) {
2668 auto portTy = type_cast<BundleType>(port.getType());
2669 auto fieldIndex = portTy.getElementIndex(field);
2670 assert(fieldIndex &&
"missing data port");
2672 for (
auto *op : port.getUsers()) {
2673 auto portAccess = cast<SubfieldOp>(op);
2674 if (fieldIndex != portAccess.getFieldIndex())
2677 for (
auto *user : op->getUsers()) {
2678 auto bits = dyn_cast<BitsPrimOp>(user);
// Mark bits [lo, hi] as used (SmallBitVector::set takes a half-open range).
2684 usedBits.set(bits.getLo(), bits.getHi() + 1);
// Placeholder entry; the real new position is filled in further down.
2685 mapping[bits.getLo()] = 0;
2686 readOps.push_back(bits);
// Collects the connects that write the given write-data field.
2694 SmallVector<StrictConnectOp> writeOps;
2695 auto findWriteUsers = [&](Value port, StringRef field) -> LogicalResult {
2696 auto portTy = type_cast<BundleType>(port.getType());
2697 auto fieldIndex = portTy.getElementIndex(field);
2698 assert(fieldIndex &&
"missing data port");
2700 for (
auto *op : port.getUsers()) {
2701 auto portAccess = cast<SubfieldOp>(op);
2702 if (fieldIndex != portAccess.getFieldIndex())
2709 writeOps.push_back(conn);
// Visit every port, dispatching on its kind to gather readers and writers.
2715 for (
auto [i, port] : llvm::enumerate(mem.getResults())) {
// Test for ports that carry annotations.
2717 if (!mem.getPortAnnotation(i).empty())
2720 switch (mem.getPortKind(i)) {
2721 case MemOp::PortKind::Debug:
2724 case MemOp::PortKind::Write:
2725 if (failed(findWriteUsers(port,
"data")))
2728 case MemOp::PortKind::Read:
2729 findReadUsers(port,
"data");
2731 case MemOp::PortKind::ReadWrite:
2732 if (failed(findWriteUsers(port,
"wdata")))
2734 findReadUsers(port,
"rdata");
2737 llvm_unreachable(
"unknown port kind");
// Test for the cases where every bit is used or no bit is used.
2742 if (usedBits.all() || usedBits.none())
// Walk the maximal runs of used bits, recording each [first, last] range and
// assigning consecutive new bit positions.
2746 SmallVector<std::pair<unsigned, unsigned>> ranges;
2747 unsigned newWidth = 0;
2748 for (
int i = usedBits.find_first(); 0 <= i && i <
width;) {
2749 int e = usedBits.find_next_unset(i);
2752 for (
int idx = i; idx < e; ++idx, ++newWidth) {
// Only bits that start a read have an entry in `mapping`.
2753 if (
auto it = mapping.find(idx); it != mapping.end()) {
2754 it->second = newWidth;
2757 ranges.emplace_back(i, e - 1);
2758 i = e !=
width ? usedBits.find_next(e) : e;
// Replace the memory with one whose data type has the reduced width and the
// original signedness.
2762 auto newType =
IntType::get(op->getContext(), type.isSigned(), newWidth);
2763 SmallVector<Type> portTypes;
2764 for (
auto [i, port] : llvm::enumerate(mem.getResults())) {
2765 portTypes.push_back(
2766 MemOp::getTypeForPort(mem.getDepth(), newType, mem.getPortKind(i)));
2768 auto newMem = rewriter.replaceOpWithNewOp<MemOp>(
2769 mem, portTypes, mem.getReadLatency(), mem.getWriteLatency(),
2770 mem.getDepth(), mem.getRuw(), mem.getPortNames(), mem.getName(),
2771 mem.getNameKind(), mem.getAnnotations(), mem.getPortAnnotations(),
2772 mem.getInnerSymAttr(), mem.getInitAttr(), mem.getPrefixAttr());
// Recreates the subfield access for `field` right after the new memory and
// replaces the existing accesses with it.
2775 auto rewriteSubfield = [&](Value port, StringRef field) {
2776 auto portTy = type_cast<BundleType>(port.getType());
2777 auto fieldIndex = portTy.getElementIndex(field);
2778 assert(fieldIndex &&
"missing data port");
2780 rewriter.setInsertionPointAfter(newMem);
2781 auto newPortAccess =
2782 rewriter.create<SubfieldOp>(port.getLoc(), port, field);
2784 for (
auto *op : llvm::make_early_inc_range(port.getUsers())) {
2785 auto portAccess = cast<SubfieldOp>(op);
// Test that filters on the freshly created access itself and on other fields.
2786 if (op == newPortAccess || fieldIndex != portAccess.getFieldIndex())
2788 rewriter.replaceOp(portAccess, newPortAccess.getResult());
// Rewrite the data subfields of every port of the new memory.
2793 for (
auto [i, port] : llvm::enumerate(newMem.getResults())) {
2794 switch (newMem.getPortKind(i)) {
2795 case MemOp::PortKind::Debug:
2796 llvm_unreachable(
"cannot rewrite debug port");
2797 case MemOp::PortKind::Write:
2798 rewriteSubfield(port,
"data");
2800 case MemOp::PortKind::Read:
2801 rewriteSubfield(port,
"data");
2803 case MemOp::PortKind::ReadWrite:
2804 rewriteSubfield(port,
"rdata");
2805 rewriteSubfield(port,
"wdata");
2808 llvm_unreachable(
"unknown port kind");
// Re-issue each bit extraction at its new position; the extracted width
// (hi - lo) is preserved.
2812 for (
auto readOp : readOps) {
2813 rewriter.setInsertionPointAfter(readOp);
2814 auto it = mapping.find(readOp.getLo());
2815 assert(it != mapping.end() &&
"bit op mapping not found");
2816 rewriter.replaceOpWithNewOp<BitsPrimOp>(
2817 readOp, readOp.getInput(),
2818 readOp.getHi() - readOp.getLo() + it->second, it->second);
// For each write, slice the kept ranges out of the source and concatenate
// them, then replace the connect.
2822 for (
auto writeOp : writeOps) {
2823 Value source = writeOp.getSrc();
2824 rewriter.setInsertionPoint(writeOp);
2827 for (
auto &[start, end] : ranges) {
2829 rewriter.create<BitsPrimOp>(writeOp.getLoc(), source,
end, start);
2832 rewriter.create<CatPrimOp>(writeOp.getLoc(), slice, catOfSlices);
2834 catOfSlices = slice;
2837 rewriter.replaceOpWithNewOp<StrictConnectOp>(writeOp, writeOp.getDest(),
/// Rewrite pattern rooted on MemOp (benefit 0) that replaces a memory with a
/// register plus pipeline registers and mux logic for its writes. Many lines
/// (guards, and the definitions of several locals such as `clock`,
/// `portClock`, `writeStages`, `maskGran`, `next` and `masked`) are missing
/// from this listing.
2846 struct FoldRegMems :
public mlir::RewritePattern {
2847 FoldRegMems(MLIRContext *context)
2848 : RewritePattern(MemOp::getOperationName(), 0, context) {}
2849 LogicalResult matchAndRewrite(Operation *op,
2850 PatternRewriter &rewriter)
const override {
2851 MemOp mem = cast<MemOp>(op);
2852 const FirMemory &info = mem.getSummary();
2856 auto memModule = mem->getParentOfType<FModuleOp>();
// Connects and subfield accesses gathered here are erased at the end.
2860 SmallPtrSet<Operation *, 8> connects;
2861 SmallVector<SubfieldOp> portAccesses;
2862 for (
auto [i, port] : llvm::enumerate(mem.getResults())) {
// Test for ports that carry annotations.
2863 if (!mem.getPortAnnotation(i).empty())
// Gathers the accesses to the listed fields and the connect-like users of
// those accesses.
2866 auto collect = [&, port = port](ArrayRef<StringRef> fields) {
2867 auto portTy = type_cast<BundleType>(port.getType());
2868 for (
auto field : fields) {
2869 auto fieldIndex = portTy.getElementIndex(field);
2870 assert(fieldIndex &&
"missing field on memory port");
2872 for (
auto *op : port.getUsers()) {
2873 auto portAccess = cast<SubfieldOp>(op);
2874 if (fieldIndex != portAccess.getFieldIndex())
2876 portAccesses.push_back(portAccess);
2877 for (
auto *user : portAccess->getUsers()) {
2878 auto conn = dyn_cast<FConnectLike>(user);
2881 connects.insert(conn);
// The set of input fields collected depends on the port kind.
2888 switch (mem.getPortKind(i)) {
2889 case MemOp::PortKind::Debug:
2891 case MemOp::PortKind::Read:
2892 if (failed(collect({
"clk",
"en",
"addr"})))
2895 case MemOp::PortKind::Write:
2896 if (failed(collect({
"clk",
"en",
"addr",
"data",
"mask"})))
2899 case MemOp::PortKind::ReadWrite:
2900 if (failed(collect({
"clk",
"en",
"addr",
"wmode",
"wdata",
"wmask"})))
// Test for a port with no clock, or a clock differing from the one already
// recorded.
2906 if (!portClock || (clock && portClock != clock))
// The register standing in for the memory is created right after the clock
// value and takes the memory's data type and name.
2912 auto ty = mem.getDataType();
2913 rewriter.setInsertionPointAfterValue(clock);
2914 auto reg = rewriter.create<RegOp>(mem.getLoc(), ty, clock, mem.getName())
// Builds a chain of `latency` registers named "<mem>_<name>_<i>", each
// connected from the previous stage's value.
2918 auto pipeline = [&](Value value, Value clock,
const Twine &name,
2920 for (
unsigned i = 0; i < latency; ++i) {
2921 std::string regName;
2923 llvm::raw_string_ostream os(regName);
2924 os << mem.getName() <<
"_" << name <<
"_" << i;
2928 .create<RegOp>(mem.getLoc(), value.getType(), clock,
2929 rewriter.getStringAttr(regName))
2931 rewriter.create<StrictConnectOp>(value.getLoc(),
reg, value);
// One (data, enable, mask) triple per write source.
2942 SmallVector<std::tuple<Value, Value, Value>> writes;
2943 for (
auto [i, port] : llvm::enumerate(mem.getResults())) {
2945 StringRef name = mem.getPortName(i);
// Pipelines one field of this port, inserting after the field's value.
2947 auto portPipeline = [&, port = port](StringRef field,
unsigned stages) {
2950 rewriter.setInsertionPointAfterValue(value);
2951 return pipeline(value, portClock, name +
"_" + field, stages);
2954 switch (mem.getPortKind(i)) {
2955 case MemOp::PortKind::Debug:
2956 llvm_unreachable(
"unknown port kind");
2957 case MemOp::PortKind::Read: {
2962 rewriter.setInsertionPointAfterValue(
reg);
2966 case MemOp::PortKind::Write: {
2967 auto data = portPipeline(
"data", writeStages);
2968 auto en = portPipeline(
"en", writeStages);
2969 auto mask = portPipeline(
"mask", writeStages);
2970 writes.emplace_back(data, en, mask);
2973 case MemOp::PortKind::ReadWrite: {
2975 rewriter.setInsertionPointAfterValue(
reg);
2979 auto wdata = portPipeline(
"wdata", writeStages);
2980 auto wmask = portPipeline(
"wmask", writeStages);
// The write enable of a read-write port is `en AND wmode`, created at the
// end of the module body and then pipelined.
2984 rewriter.setInsertionPointToEnd(memModule.getBodyBlock());
2986 auto wen = rewriter.create<AndPrimOp>(port.getLoc(),
en,
wmode);
2988 pipeline(wen, portClock, name +
"_wen", writeStages);
2989 writes.emplace_back(wdata, wenPipelined, wmask);
// The write logic is emitted at the end of the module body.
2996 rewriter.setInsertionPointToEnd(memModule.getBodyBlock());
2998 for (
auto &[data, en, mask] : writes) {
3003 Location loc = mem.getLoc();
// Per mask bit: pick the new data chunk or the current chunk, and concatenate
// the chunks into `masked`.
3005 for (
unsigned i = 0; i < info.
maskBits; ++i) {
3006 unsigned hi = (i + 1) * maskGran - 1;
3007 unsigned lo = i * maskGran;
3009 auto dataPart = rewriter.createOrFold<BitsPrimOp>(loc,
data, hi, lo);
3010 auto nextPart = rewriter.createOrFold<BitsPrimOp>(loc, next, hi, lo);
3011 auto bit = rewriter.createOrFold<BitsPrimOp>(loc,
mask, i, i);
3012 auto chunk = rewriter.create<MuxPrimOp>(loc, bit, dataPart, nextPart);
3015 masked = rewriter.create<CatPrimOp>(loc, chunk, masked);
// The write only takes effect when its enable is set.
3021 next = rewriter.create<MuxPrimOp>(next.getLoc(),
en, masked, next);
3023 rewriter.create<StrictConnectOp>(
reg.getLoc(),
reg, next);
// Remove the old connects, the subfield accesses, and finally the memory.
3026 for (Operation *conn : connects)
3027 rewriter.eraseOp(conn);
3028 for (
auto portAccess : portAccesses)
3029 rewriter.eraseOp(portAccess);
3030 rewriter.eraseOp(mem);
/// Registers the MemOp canonicalization patterns: FoldZeroWidthMemory,
/// FoldReadOrWriteOnlyMemory, FoldReadWritePorts, FoldUnusedPorts,
/// FoldUnusedBits and FoldRegMems. The receiver expression and the call's
/// argument are on lines not shown.
3037 void MemOp::getCanonicalizationPatterns(RewritePatternSet &results,
3038 MLIRContext *context) {
3040 .insert<FoldZeroWidthMemory, FoldReadOrWriteOnlyMemory,
3041 FoldReadWritePorts, FoldUnusedPorts, FoldUnusedBits, FoldRegMems>(
// Fragment of a helper that turns a register driven through a mux with a
// constant arm into a RegResetOp. The signature, the definitions of `reg` and
// `con`, and several guards are on lines not shown.
3061 auto mux = dyn_cast_or_null<MuxPrimOp>(con.getSrc().getDefiningOp());
3064 auto *high = mux.getHigh().getDefiningOp();
3065 auto *low = mux.getLow().getDefiningOp();
3067 auto constOp = dyn_cast_or_null<ConstantOp>(high);
// Records which mux arrangement was matched (assignments are not shown).
3074 bool constReg =
false;
// Case: constant on the high side and the register itself on the low side.
3076 if (constOp && low ==
reg)
// Case: constant on the low side and the register itself on the high side.
3078 else if (dyn_cast_or_null<ConstantOp>(low) && high ==
reg) {
3080 constOp = dyn_cast<ConstantOp>(low);
// Test on the mux select not being a block argument when `constReg` is false.
3087 if (!isa<BlockArgument>(mux.getSel()) && !constReg)
// All involved values must have the register's type, and that type must have
// a non-negative bit width.
3091 auto regTy =
reg.getResult().getType();
3092 if (con.getDest().getType() != regTy || con.getSrc().getType() != regTy ||
3093 mux.getHigh().getType() != regTy || mux.getLow().getType() != regTy ||
3094 regTy.getBitWidthOrSentinel() < 0)
// Move the constant to the start of the connect's block unless it is already
// the first op there.
3100 if (constOp != &con->getBlock()->front())
3101 constOp->moveBefore(&con->getBlock()->front());
// Replace the register with a RegResetOp using the mux select as reset signal
// and the mux high operand as reset value; dialect attributes are saved
// beforehand and re-applied to the new op.
3104 SmallVector<NamedAttribute, 2> attrs(
reg->getDialectAttrs());
3105 auto newReg = replaceOpWithNewOpAndCopyName<RegResetOp>(
3106 rewriter,
reg,
reg.getResult().getType(),
reg.getClockVal(),
3107 mux.getSel(), mux.getHigh(),
reg.getNameAttr(),
reg.getNameKindAttr(),
3108 reg.getAnnotationsAttr(),
reg.getInnerSymAttr(),
3109 reg.getForceableAttr());
3110 newReg->setDialectAttrs(attrs);
// Replace the connect with one driven by the constant (when `constReg`) or by
// the mux low operand, restoring the insertion point afterwards.
3112 auto pt = rewriter.saveInsertionPoint();
3113 rewriter.setInsertionPoint(con);
3114 auto v = constReg ? (Value)constOp.getResult() : (Value)mux.getLow();
3115 replaceOpWithNewOpAndCopyName<ConnectOp>(rewriter, con, con.getDest(), v);
3116 rewriter.restoreInsertionPoint(pt);
/// Canonicalization entry point for RegOp. The visible condition requires the
/// op to have no dont-touch marker and not be forceable; the rest of the
/// condition and the body are not shown.
3120 LogicalResult RegOp::canonicalize(RegOp op, PatternRewriter &rewriter) {
3121 if (!
hasDontTouch(op.getOperation()) && !op.isForceable() &&
// Shared helper for immediate verification ops (first signature line and
// return statements are not shown).
3137 PatternRewriter &rewriter,
// An op whose enable is the constant zero is erased.
3140 if (
auto constant = enable.getDefiningOp<firrtl::ConstantOp>()) {
3141 if (constant.getValue().isZero()) {
3142 rewriter.eraseOp(op);
// With a constant predicate, the op is erased when "predicate is zero" equals
// the `eraseIfZero` flag.
3148 if (
auto constant = predicate.getDefiningOp<firrtl::ConstantOp>()) {
3149 if (constant.getValue().isZero() == eraseIfZero) {
3150 rewriter.eraseOp(op);
/// Templated wrapper parameterized by the verification op type and an
/// `EraseIfZero` flag that defaults to false. Its body is not shown here.
3158 template <
class Op,
bool EraseIfZero = false>
3160 PatternRewriter &rewriter) {
/// Registers canonicalizeImmediateVerifOp<AssertOp> (default EraseIfZero).
3165 void AssertOp::getCanonicalizationPatterns(RewritePatternSet &results,
3166 MLIRContext *context) {
3167 results.add(canonicalizeImmediateVerifOp<AssertOp>);
/// Registers canonicalizeImmediateVerifOp<AssumeOp> (default EraseIfZero).
3170 void AssumeOp::getCanonicalizationPatterns(RewritePatternSet &results,
3171 MLIRContext *context) {
3172 results.add(canonicalizeImmediateVerifOp<AssumeOp>);
/// Registers canonicalizeImmediateVerifOp<UnclockedAssumeIntrinsicOp>
/// (default EraseIfZero).
3175 void UnclockedAssumeIntrinsicOp::getCanonicalizationPatterns(
3176 RewritePatternSet &results, MLIRContext *context) {
3177 results.add(canonicalizeImmediateVerifOp<UnclockedAssumeIntrinsicOp>);
/// Registers canonicalizeImmediateVerifOp for CoverOp with EraseIfZero set to
/// true, unlike the assert/assume registrations which use the default.
3180 void CoverOp::getCanonicalizationPatterns(RewritePatternSet &results,
3181 MLIRContext *context) {
3182 results.add(canonicalizeImmediateVerifOp<CoverOp, /* EraseIfZero = */ true>);
/// Canonicalizes InvalidValueOp. The visible case erases an op that has no
/// uses; any further cases are on lines not shown.
3189 LogicalResult InvalidValueOp::canonicalize(InvalidValueOp op,
3190 PatternRewriter &rewriter) {
3192 if (op.use_empty()) {
3193 rewriter.eraseOp(op);
/// Fold hook for ClockGateIntrinsicOp. Only the signature is present in this
/// listing.
3203 OpFoldResult ClockGateIntrinsicOp::fold(FoldAdaptor adaptor) {
/// Canonicalizes ClockGateIntrinsicOp. The visible rewrite drops the optional
/// test-enable operand when it is the constant zero.
3221 LogicalResult ClockGateIntrinsicOp::canonicalize(ClockGateIntrinsicOp op,
3222 PatternRewriter &rewriter) {
3224 if (
auto testEnable = op.getTestEnable()) {
3225 if (
auto constOp = testEnable.getDefiningOp<ConstantOp>()) {
3226 if (constOp.getValue().isZero()) {
// The operand list is cleared inside modifyOpInPlace so the rewriter is
// notified of the change.
3227 rewriter.modifyOpInPlace(op,
3228 [&] { op.getTestEnableMutable().clear(); });
// Static helper returning LogicalResult (name and signature not shown). It
// handles a reference produced by a forceable declaration.
3242 static LogicalResult
3244 auto forceable = op.getRef().getDefiningOp<Forceable>();
// Guard: not defined by a forceable op, not actually forceable, the ref is
// not that op's data ref, or the types differ.
3245 if (!forceable || !forceable.isForceable() ||
3246 op.getRef() != forceable.getDataRef() ||
3247 op.getType() != forceable.getDataType())
// Uses of `op` are redirected to the declaration's data value.
3249 rewriter.replaceAllUsesWith(op, forceable.getData());
/// Registers patterns::RefResolveOfRefSend for RefResolveOp. Any further
/// registrations are on lines not shown.
3253 void RefResolveOp::getCanonicalizationPatterns(RewritePatternSet &results,
3254 MLIRContext *context) {
3255 results.insert<patterns::RefResolveOfRefSend>(context);
/// Fold hook for RefCastOp. The visible test checks whether the input already
/// has the result type; the returns are not shown.
3259 OpFoldResult RefCastOp::fold(FoldAdaptor adaptor) {
3261 if (getInput().getType() == getType())
// True when `operand` is defined by a ConstantOp whose value is zero.
3267 auto constOp = operand.getDefiningOp<ConstantOp>();
3268 return constOp && constOp.getValue().isZero();
/// Templated helper, registered below as eraseIfPredFalse<...>, that erases
/// `op` on its visible path. The signature and guard are not shown.
3271 template <
typename Op>
3274 rewriter.eraseOp(op);
/// Registers eraseIfPredFalse<RefForceOp>.
3280 void RefForceOp::getCanonicalizationPatterns(RewritePatternSet &results,
3281 MLIRContext *context) {
3282 results.add(eraseIfPredFalse<RefForceOp>);
/// Registers eraseIfPredFalse<RefForceInitialOp>.
3284 void RefForceInitialOp::getCanonicalizationPatterns(RewritePatternSet &results,
3285 MLIRContext *context) {
3286 results.add(eraseIfPredFalse<RefForceInitialOp>);
/// Registers eraseIfPredFalse<RefReleaseOp>.
3288 void RefReleaseOp::getCanonicalizationPatterns(RewritePatternSet &results,
3289 MLIRContext *context) {
3290 results.add(eraseIfPredFalse<RefReleaseOp>);
/// Registers eraseIfPredFalse<RefReleaseInitialOp>.
3292 void RefReleaseInitialOp::getCanonicalizationPatterns(
3293 RewritePatternSet &results, MLIRContext *context) {
3294 results.add(eraseIfPredFalse<RefReleaseInitialOp>);
/// Fold hook for HasBeenResetIntrinsicOp. The visible tests cover a constant
/// reset, and a constant clock combined with a 1-bit UInt reset type; the
/// values returned are on lines not shown.
3301 OpFoldResult HasBeenResetIntrinsicOp::fold(FoldAdaptor adaptor) {
3307 if (adaptor.getReset())
3312 if (
isUInt1(getReset().getType()) && adaptor.getClock())
// Fragment of a type-dispatch chain. The first visible case recurses on the
// element type of the matched type.
3325 [&](
auto ty) ->
bool {
return isTypeEmpty(ty.getElementType()); })
// Bundles: iterate the elements (the per-element test is not shown).
3326 .Case<BundleType>([&](
auto ty) ->
bool {
3327 for (
auto elem : ty.getElements())
// Integer types are empty exactly when their width is 0.
3332 .Case<IntType>([&](
auto ty) {
return ty.getWidth() == 0; })
// Any other type is not empty.
3333 .Default([](
auto) ->
bool {
return false; });
3336 LogicalResult FPGAProbeIntrinsicOp::canonicalize(FPGAProbeIntrinsicOp op,
3337 PatternRewriter &rewriter) {
3338 auto firrtlTy = type_dyn_cast<FIRRTLType>(op.getInput().getType());
3345 rewriter.eraseOp(op);
assert(baseType &&"element must be base type")
static bool hasKnownWidthIntTypes(Operation *op)
Return true if this operation's operands and results all have a known width.
static LogicalResult canonicalizeSingleSetConnect(StrictConnectOp op, PatternRewriter &rewriter)
static std::optional< APSInt > getConstant(Attribute operand)
Determine the value of a constant operand for the sake of constant folding.
static LogicalResult canonicalizeImmediateVerifOp(Op op, PatternRewriter &rewriter)
static bool isDefinedByOneConstantOp(Value v)
static Attribute collectFields(MLIRContext *context, ArrayRef< Attribute > operands)
static void erasePort(PatternRewriter &rewriter, Value port)
static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op, Region &region)
Replaces the given op with the contents of the given single-block region.
static Value getPortFieldValue(Value port, StringRef name)
static AttachOp getDominatingAttachUser(Value value, AttachOp dominatedAttach)
If the specified value has an AttachOp user that strictly dominates "dominatedAttach", then return it.
static OpTy replaceOpWithNewOpAndCopyName(PatternRewriter &rewriter, Operation *op, Args &&...args)
A wrapper of PatternRewriter::replaceOpWithNewOp to propagate "name" attribute.
static void updateName(PatternRewriter &rewriter, Operation *op, StringAttr name)
Set the name of an op based on the best of two names: The current name, and the name passed in.
static bool isTypeEmpty(FIRRTLType type)
static bool isUInt1(Type type)
Return true if this type is a 1-bit UInt.
static LogicalResult demoteForceableIfUnused(OpTy op, PatternRewriter &rewriter)
static bool isPortDisabled(Value port)
static LogicalResult eraseIfZeroOrNotZero(Operation *op, Value predicate, Value enable, PatternRewriter &rewriter, bool eraseIfZero)
static APInt getMaxSignedValue(unsigned bitWidth)
Get the largest signed value of a given bit width.
static Value dropWrite(PatternRewriter &rewriter, OpResult old, Value passthrough)
static LogicalResult canonicalizePrimOp(Operation *op, PatternRewriter &rewriter, const function_ref< OpFoldResult(ArrayRef< Attribute >)> &canonicalize)
Applies the canonicalization function canonicalize to the given operation.
static void replaceWithBits(Operation *op, Value value, unsigned hiBit, unsigned loBit, PatternRewriter &rewriter)
Replace the specified operation with a 'bits' op from the specified hi/lo bits.
static LogicalResult canonicalizeRegResetWithOneReset(RegResetOp reg, PatternRewriter &rewriter)
static LogicalResult eraseIfPredFalse(Op op, PatternRewriter &rewriter)
static OpFoldResult foldMux(OpTy op, typename OpTy::FoldAdaptor adaptor)
static APInt getMaxUnsignedValue(unsigned bitWidth)
Get the largest unsigned value of a given bit width.
static void replacePortField(PatternRewriter &rewriter, Value port, StringRef name, Value value)
BinOpKind
This is the policy for folding, which depends on the sort of operator we're processing.
static bool isPortUnused(Value port, StringRef data)
static LogicalResult canonicalizeRefResolveOfForceable(RefResolveOp op, PatternRewriter &rewriter)
static Attribute constFoldFIRRTLBinaryOp(Operation *op, ArrayRef< Attribute > operands, BinOpKind opKind, const function_ref< APInt(const APSInt &, const APSInt &)> &calculate)
Applies the constant folding function calculate to the given operands.
static APInt getMinSignedValue(unsigned bitWidth)
Get the smallest signed value of a given bit width.
static LogicalResult foldHiddenReset(RegOp reg, PatternRewriter &rewriter)
static std::optional< APSInt > getExtendedConstant(Value operand, Attribute constant, int32_t destWidth)
Implicitly replace the operand to a constant folding operation with a const 0 in case the operand is ...
static Value moveNameHint(OpResult old, Value passthrough)
static void replaceOpAndCopyName(PatternRewriter &rewriter, Operation *op, Value newValue)
A wrapper of PatternRewriter::replaceOp to propagate "name" attribute.
This class provides a read-only projection over the MLIR attributes that represent a set of annotatio...
int32_t getBitWidthOrSentinel()
If this is an IntType, AnalogType, or sugar type for a single bit (Clock, Reset, etc) then return the...
This class implements the same functionality as TypeSwitch except that it uses firrtl::type_dyn_cast ...
FIRRTLTypeSwitch< T, ResultT > & Case(CallableT &&caseFn)
Add a case on the given type.
This is the common base class between SIntType and UIntType.
int32_t getWidthOrSentinel() const
Return the width of this type, or -1 if it has none specified.
static IntType get(MLIRContext *context, bool isSigned, int32_t widthOrSentinel=-1, bool isConst=false)
Return an SIntType or UIntType with the specified signedness, width, and constness.
std::optional< int32_t > getWidth() const
Return an optional containing the width, if the width is known (or empty if width is unknown).
bool hasWidth() const
Return true if this integer type has a known width.
def connect(destination, source)
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
uint64_t getWidth(Type t)
Forceable replaceWithNewForceability(Forceable op, bool forceable, ::mlir::PatternRewriter *rewriter=nullptr)
Replace a Forceable op with equivalent, changing whether forceable.
bool areAnonymousTypesEquivalent(FIRRTLBaseType lhs, FIRRTLBaseType rhs)
Return true if anonymous types of given arguments are equivalent by pointer comparison.
IntegerAttr getIntAttr(Type type, const APInt &value)
Utility for generating a constant attribute.
bool hasDontTouch(Value value)
Check whether a block argument ("port") or the operation defining a value has a DontTouch annotation,...
bool hasDroppableName(Operation *op)
Return true if the name is droppable.
StrictConnectOp getSingleConnectUserOf(Value value)
Scan all the uses of the specified value, checking to see if there is exactly one connect that has th...
IntegerAttr getIntZerosAttr(Type type)
Utility for generating a constant zero attribute.
This file defines an intermediate representation for circuits acting as an abstraction for constraint...
APSInt extOrTruncZeroWidth(APSInt value, unsigned width)
A safe version of APSInt::extOrTrunc that will NOT assert on zero-width signed APSInts.
APInt sextZeroWidth(APInt value, unsigned width)
A safe version of APInt::sext that will NOT assert on zero-width signed APSInts.
StringRef chooseName(StringRef a, StringRef b)
Choose a good name for an item from two options.
static bool isConstantZero(Attribute operand)
Determine whether a constant operand is a zero value.
static bool isConstantOne(Attribute operand)
Determine whether a constant operand is a one value.
def reg(value, clock, reset=None, reset_value=None, name=None, sym_name=None)