20 #include "mlir/IR/Matchers.h"
21 #include "mlir/IR/PatternMatch.h"
22 #include "llvm/ADT/APSInt.h"
23 #include "llvm/ADT/SmallPtrSet.h"
24 #include "llvm/ADT/StringExtras.h"
25 #include "llvm/ADT/TypeSwitch.h"
27 using namespace circt;
28 using namespace firrtl;
// dropWrite: first gathers the users of `old` into a set, then walks that set
// and erases the users that are FConnectLike writes. Collecting first
// presumably keeps erasure from disturbing iteration over old's use list.
// NOTE(review): this excerpt omits lines. Not visible: the remaining
// parameter(s), the loop body that fills `users` (presumably an insert), any
// extra condition between the cast and the erase, and the return value.
32 static Value
dropWrite(PatternRewriter &rewriter, OpResult old,
34 SmallPtrSet<Operation *, 8> users;
35 for (
auto *user : old.getUsers())
// Second pass over the collected users; erasure happens only here.
37 for (Operation *user : users)
38 if (
auto connect = dyn_cast<FConnectLike>(user))
40 rewriter.eraseOp(user);
// Fragment of a separate helper whose header is not visible here. It copies a
// non-empty "name" attribute from the op owning `old` onto the op that
// defines `passthrough`.
49 Operation *op = passthrough.getDefiningOp();
// getDefiningOp() is null for block arguments; this helper requires an op.
52 assert(op &&
"passthrough must be an operation");
53 Operation *oldOp = old.getOwner();
54 auto name = oldOp->getAttrOfType<StringAttr>(
"name");
// Missing or empty names are not propagated.
55 if (name && !name.getValue().empty())
56 op->setAttr(
"name", name);
64 #include "circt/Dialect/FIRRTL/FIRRTLCanonicalization.h.inc"
// Fragment (header not visible): checks that the first result and every
// operand of `op` are IntTypes with a known width. The early-exit
// statements for the failing cases are on lines not shown here.
72 auto resultType = type_cast<IntType>(op->getResult(0).getType());
73 if (!resultType.hasWidth())
75 for (Value operand : op->getOperands())
76 if (!type_cast<IntType>(operand.getType()).hasWidth())
// Fragment (header not visible): tests whether `type` is a UIntType whose
// width is known and equal to 1. NOTE(review): this is presumably the
// `isUInt1` helper used by the reduction folds; confirm.
83 auto t = type_dyn_cast<UIntType>(type);
84 if (!t || !t.hasWidth() || t.getWidth() != 1)
// updateName: merges the name hint `name` into op's "name" attribute. The
// attribute is rewritten through the rewriter only when it actually changes.
// Instances are excluded by assertion.
// NOTE(review): some parameters and statements are on lines not visible here.
91 static void updateName(PatternRewriter &rewriter, Operation *op,
94 assert(!isa<InstanceOp>(op));
// Nothing to merge when the incoming name is missing or empty.
95 if (!name || name.getValue().empty())
97 auto newName = name.getValue();
98 auto newOpName = op->getAttrOfType<StringAttr>(
"name");
// When `op` already has a name, chooseName picks between the existing name
// and the incoming one. The guarding condition is on an elided line.
101 newName =
chooseName(newOpName.getValue(), name.getValue());
// Only modify the op when the resulting name differs, so no no-op
// modification is reported to the rewriter.
103 if (!newOpName || newOpName.getValue() != newName)
104 rewriter.modifyOpInPlace(
105 op, [&] { op->setAttr(
"name", rewriter.getStringAttr(newName)); });
// Fragment (header not visible): when `newValue` is produced by an op, op's
// "name" is read, presumably to transfer it onto that op. Then `op` is
// replaced by `newValue`.
113 if (
auto *newOp = newValue.getDefiningOp()) {
114 auto name = op->getAttrOfType<StringAttr>(
"name");
117 rewriter.replaceOp(op, newValue);
// Presumably replaceOpWithNewOpAndCopyName, judging by its call sites later
// in this file. It saves op's "name", then replaces `op` with a new OpTy built
// from the forwarded `args`. Applying the saved name happens on an elided
// line; TODO confirm.
123 template <
typename OpTy,
typename... Args>
125 Operation *op, Args &&...args) {
126 auto name = op->getAttrOfType<StringAttr>(
"name");
128 rewriter.replaceOpWithNewOp<OpTy>(op, std::forward<Args>(args)...);
// Fragment (header not visible): ops implementing FNamableOp report whether
// their name may be dropped. Other ops fall through to an elided line.
136 if (
auto namableOp = dyn_cast<firrtl::FNamableOp>(op))
137 return namableOp.hasDroppableName();
// getExtendedConstant (name taken from its assert message): yields the
// constant value of an integer-typed `operand` as an APSInt of `destWidth`
// bits. A zero-width operand yields a `destWidth`-bit zero with the operand's
// signedness.
// NOTE(review): the function name/parameters and the extension applied to
// `result` are on lines not visible here.
148 static std::optional<APSInt>
150 assert(type_cast<IntType>(operand.getType()) &&
151 "getExtendedConstant is limited to integer types");
158 if (IntegerAttr result = dyn_cast_or_null<IntegerAttr>(constant))
// Zero-width values have no bits to extend, so a zero is built directly.
163 if (type_cast<IntType>(operand.getType()).getWidth() == 0)
164 return APSInt(destWidth,
165 type_cast<IntType>(operand.getType()).isUnsigned());
// Fragment (header not visible): converts a constant attribute to an APSInt.
// A BoolAttr becomes a 1-bit value; an IntegerAttr yields its own APSInt.
173 if (
auto attr = dyn_cast<BoolAttr>(operand))
174 return APSInt(APInt(1, attr.getValue()));
175 if (
auto attr = dyn_cast<IntegerAttr>(operand))
176 return attr.getAPSInt();
// Fragment (header not visible): true when the constant `cst` is zero.
184 return cst->isZero();
// Fragment of the shared binary-op constant folder; its name is on an elided
// line. It picks an operand width, evaluates `calculate` on the two constant
// operands as APSInts, and narrows the result to the op's result width.
// NOTE(review): not visible here are the branches choosing `operandWidth` by
// `opKind`, the extraction of `lhs`/`rhs`, and the final return.
211 Operation *op, ArrayRef<Attribute> operands,
BinOpKind opKind,
212 const function_ref<APInt(
const APSInt &,
const APSInt &)> &calculate) {
213 assert(operands.size() == 2 &&
"binary op takes two operands");
216 auto resultType = type_cast<IntType>(op->getResult(0).getType());
// An unknown result width (negative sentinel) is handled on an elided line.
217 if (resultType.getWidthOrSentinel() < 0)
// A zero-width result is always the 0-bit constant.
221 if (resultType.getWidthOrSentinel() == 0)
222 return getIntAttr(resultType, APInt(0, 0, resultType.isSigned()));
// Operand widths start from the declared types. They grow to cover any
// constant attribute wider than the declared width.
228 type_cast<IntType>(op->getOperand(0).getType()).getWidthOrSentinel();
230 type_cast<IntType>(op->getOperand(1).getType()).getWidthOrSentinel();
231 if (
auto lhs = dyn_cast_or_null<IntegerAttr>(operands[0]))
232 lhsWidth = std::max<int32_t>(lhsWidth, lhs.getValue().getBitWidth());
233 if (
auto rhs = dyn_cast_or_null<IntegerAttr>(operands[1]))
234 rhsWidth = std::max<int32_t>(rhsWidth, rhs.getValue().getBitWidth());
// The assignments below are alternatives selected on elided lines, presumably
// by `opKind`:
//   - use the result width;
//   - use the wider operand, but at least 1 bit;
//   - use the widest of the operands and the result.
238 int32_t operandWidth;
241 operandWidth = resultType.getWidthOrSentinel();
246 operandWidth = std::max(1, std::max(lhsWidth, rhsWidth));
250 std::max(std::max(lhsWidth, rhsWidth), resultType.getWidthOrSentinel());
261 APInt resultValue = calculate(*lhs, *rhs);
// Narrow to the declared result width; the guarding condition is elided.
266 resultValue = resultValue.trunc(resultType.getWidthOrSentinel());
268 assert((
unsigned)resultType.getWidthOrSentinel() ==
269 resultValue.getBitWidth());
// Fragment of the shared prim-op canonicalization driver; its name is on an
// elided line. It gathers constant operand attributes and calls
// `canonicalize`. It then materializes the result and pads/casts it so the
// replacement has exactly the original result type.
// NOTE(review): not visible here are several early exits, the declarations of
// `attr`/`result`/`resultValue`, and the final replacement.
282 Operation *op, PatternRewriter &rewriter,
283 const function_ref<OpFoldResult(ArrayRef<Attribute>)> &
canonicalize) {
// Only single-result ops are handled.
285 if (op->getNumResults() != 1)
287 auto type = type_dyn_cast<FIRRTLBaseType>(op->getResult(0).getType());
292 auto width = type.getBitWidthOrSentinel();
// Operands defined by ConstantOp/SpecialConstantOp contribute their value
// attribute. Every other operand contributes `attr` as initialized on the
// elided line (presumably a null Attribute).
297 SmallVector<Attribute, 3> constOperands;
298 constOperands.reserve(op->getNumOperands());
299 for (
auto operand : op->getOperands()) {
301 if (
auto *defOp = operand.getDefiningOp())
302 TypeSwitch<Operation *>(defOp).Case<ConstantOp, SpecialConstantOp>(
303 [&](
auto op) { attr = op.getValueAttr(); });
304 constOperands.push_back(attr);
// An Attribute result is materialized as a constant of the result type.
// Otherwise the result already holds a Value.
313 if (
auto cst = dyn_cast<Attribute>(result))
314 resultValue = op->getDialect()
315 ->materializeConstant(rewriter, cst, type, op->getLoc())
318 resultValue = result.get<Value>();
// The comparison starts on an elided line. Presumably, a replacement narrower
// than `width` is padded back up to it.
322 type_cast<FIRRTLBaseType>(resultValue.getType()).getBitWidthOrSentinel())
323 resultValue = rewriter.create<PadPrimOp>(op->getLoc(), resultValue,
width);
// Restore the original signedness with an explicit reinterpreting cast.
326 if (type_isa<SIntType>(type) && type_isa<UIntType>(resultValue.getType()))
327 resultValue = rewriter.create<AsSIntPrimOp>(op->getLoc(), resultValue);
328 else if (type_isa<UIntType>(type) &&
329 type_isa<SIntType>(resultValue.getType()))
330 resultValue = rewriter.create<AsUIntPrimOp>(op->getLoc(), resultValue);
332 assert(type == resultValue.getType() &&
"canonicalization changed type");
// Fragments of three small helpers whose headers are not visible. They return
// the unsigned maximum, signed minimum and signed maximum representable in
// `bitWidth` bits. For bitWidth == 0 each returns a default-constructed APInt.
// NOTE(review): LLVM documents the default APInt constructor as a 1-bit zero,
// not a 0-bit value. Confirm that callers expect this for zero-width types.
340 return bitWidth > 0 ? APInt::getMaxValue(bitWidth) : APInt();
346 return bitWidth > 0 ? APInt::getSignedMinValue(bitWidth) : APInt();
352 return bitWidth > 0 ? APInt::getSignedMaxValue(bitWidth) : APInt();
// Constant-like ops fold to their own stored attribute. They take no
// operands, which each fold asserts.
359 OpFoldResult ConstantOp::fold(FoldAdaptor adaptor) {
360 assert(adaptor.getOperands().empty() &&
"constant has no operands");
361 return getValueAttr();
364 OpFoldResult SpecialConstantOp::fold(FoldAdaptor adaptor) {
365 assert(adaptor.getOperands().empty() &&
"constant has no operands");
366 return getValueAttr();
// Aggregate constants fold to their fields attribute, not a single value.
369 OpFoldResult AggregateConstantOp::fold(FoldAdaptor adaptor) {
370 assert(adaptor.getOperands().empty() &&
"constant has no operands");
371 return getFieldsAttr();
374 OpFoldResult StringConstantOp::fold(FoldAdaptor adaptor) {
375 assert(adaptor.getOperands().empty() &&
"constant has no operands");
376 return getValueAttr();
379 OpFoldResult FIntegerConstantOp::fold(FoldAdaptor adaptor) {
380 assert(adaptor.getOperands().empty() &&
"constant has no operands");
381 return getValueAttr();
384 OpFoldResult BoolConstantOp::fold(FoldAdaptor adaptor) {
385 assert(adaptor.getOperands().empty() &&
"constant has no operands");
386 return getValueAttr();
389 OpFoldResult DoubleConstantOp::fold(FoldAdaptor adaptor) {
390 assert(adaptor.getOperands().empty() &&
"constant has no operands");
391 return getValueAttr();
// Add/Sub/Mul folds: each supplies a lambda that evaluates the operation on
// two APSInt constants. The call that receives the lambda is on an elided
// line.
398 OpFoldResult AddPrimOp::fold(FoldAdaptor adaptor) {
401 [=](
const APSInt &a,
const APSInt &b) { return a + b; });
// Registers the add-specific rewrite patterns.
404 void AddPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
405 MLIRContext *context) {
406 results.insert<patterns::moveConstAdd, patterns::AddOfZero,
407 patterns::AddOfSelf, patterns::AddOfPad>(context);
410 OpFoldResult SubPrimOp::fold(FoldAdaptor adaptor) {
413 [=](
const APSInt &a,
const APSInt &b) { return a - b; });
// Registers the sub-specific rewrite patterns.
416 void SubPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
417 MLIRContext *context) {
418 results.insert<patterns::SubOfZero, patterns::SubFromZeroSigned,
419 patterns::SubFromZeroUnsigned, patterns::SubOfSelf,
420 patterns::SubOfPadL, patterns::SubOfPadR>(context);
// The mul fold has additional logic on elided lines before this evaluator.
423 OpFoldResult MulPrimOp::fold(FoldAdaptor adaptor) {
435 [=](
const APSInt &a,
const APSInt &b) { return a * b; });
// div: x / x is special-cased using the result width (body elided). Division
// by the constant one is handled when lhs already has the result type
// (presumably folding to lhs). The constant evaluator returns 0 under a
// condition on an elided line (presumably a zero divisor).
438 OpFoldResult DivPrimOp::fold(FoldAdaptor adaptor) {
445 if (getLhs() == getRhs()) {
446 auto width = getType().base().getWidthOrSentinel();
468 if (
auto rhsCst = dyn_cast_or_null<IntegerAttr>(adaptor.getRhs()))
469 if (rhsCst.getValue().isOne() && getLhs().getType() == getType())
474 [=](
const APSInt &a,
const APSInt &b) -> APInt {
477 return APInt(a.getBitWidth(), 0);
// rem: x % x is special-cased (body elided). The constant evaluator returns 0
// under an elided condition (presumably a zero divisor).
481 OpFoldResult RemPrimOp::fold(FoldAdaptor adaptor) {
488 if (getLhs() == getRhs())
502 [=](
const APSInt &a,
const APSInt &b) -> APInt {
505 return APInt(a.getBitWidth(), 0);
// dshl / dshlw: constant evaluation is a left shift of `a` by `b`.
509 OpFoldResult DShlPrimOp::fold(FoldAdaptor adaptor) {
512 [=](
const APSInt &a,
const APSInt &b) -> APInt { return a.shl(b); });
515 OpFoldResult DShlwPrimOp::fold(FoldAdaptor adaptor) {
518 [=](
const APSInt &a,
const APSInt &b) -> APInt { return a.shl(b); });
// dshr: a logical shift when the result is unsigned or `a` is zero-width. The
// other branch of the conditional is on an elided line (presumably ashr).
521 OpFoldResult DShrPrimOp::fold(FoldAdaptor adaptor) {
524 [=](
const APSInt &a,
const APSInt &b) -> APInt {
525 return getType().base().isUnsigned() || !a.getBitWidth() ? a.lshr(b)
// and: handles the following identities (the resulting values are on elided
// lines):
//   - x & 0 and 0 & x;
//   - x & all-ones and all-ones & x, when both operands have the result type;
//   - x & x, when the operand has the result type.
// The lambda evaluates two constants.
531 OpFoldResult AndPrimOp::fold(FoldAdaptor adaptor) {
534 if (rhsCst->isZero())
538 if (rhsCst->isAllOnes() && getLhs().getType() == getType() &&
539 getRhs().getType() == getType())
545 if (lhsCst->isZero())
549 if (lhsCst->isAllOnes() && getLhs().getType() == getType() &&
550 getRhs().getType() == getType())
555 if (getLhs() == getRhs() && getRhs().getType() == getType())
560 [](
const APSInt &a,
const APSInt &b) -> APInt { return a & b; });
// Registers the and-specific rewrite patterns.
563 void AndPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
564 MLIRContext *context) {
566 .insert<patterns::extendAnd, patterns::moveConstAnd, patterns::AndOfZero,
567 patterns::AndOfAllOne, patterns::AndOfSelf, patterns::AndOfPad,
568 patterns::AndOfAsSIntL, patterns::AndOfAsSIntR>(context);
// or: handles the following identities, each guarded by the operand type(s)
// matching the result type:
//   - x | 0 and 0 | x;
//   - x | all-ones and all-ones | x;
//   - x | x.
571 OpFoldResult OrPrimOp::fold(FoldAdaptor adaptor) {
574 if (rhsCst->isZero() && getLhs().getType() == getType())
578 if (rhsCst->isAllOnes() && getRhs().getType() == getType() &&
579 getLhs().getType() == getType())
585 if (lhsCst->isZero() && getRhs().getType() == getType())
589 if (lhsCst->isAllOnes() && getLhs().getType() == getType() &&
590 getRhs().getType() == getType())
595 if (getLhs() == getRhs() && getRhs().getType() == getType())
600 [](
const APSInt &a,
const APSInt &b) -> APInt { return a | b; });
// Registers the or-specific rewrite patterns.
603 void OrPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
604 MLIRContext *context) {
605 results.insert<patterns::extendOr, patterns::moveConstOr, patterns::OrOfZero,
606 patterns::OrOfAllOne, patterns::OrOfSelf, patterns::OrOfPad,
607 patterns::OrOrr>(context);
// xor: handles x ^ 0 and 0 ^ x (remaining conditions elided). x ^ x becomes
// a zero of the result width; an unknown (negative-sentinel) width is
// clamped to 0.
610 OpFoldResult XorPrimOp::fold(FoldAdaptor adaptor) {
613 if (rhsCst->isZero() &&
619 if (lhsCst->isZero() &&
624 if (getLhs() == getRhs())
627 APInt(std::max(getType().base().getWidthOrSentinel(), 0), 0));
631 [](
const APSInt &a,
const APSInt &b) -> APInt { return a ^ b; });
// Registers the xor-specific rewrite patterns.
634 void XorPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
635 MLIRContext *context) {
636 results.insert<patterns::extendXor, patterns::moveConstXor,
637 patterns::XorOfZero, patterns::XorOfSelf, patterns::XorOfPad>(
// Ordered comparisons (leq/lt/geq/gt). Each op registers a pattern for a
// constant left-hand side. Each fold does three things:
//   - special-cases x-op-x;
//   - compares a constant right-hand side at a common width of at least one
//     bit (presumably against the extreme values of the lhs type);
//   - otherwise evaluates the comparison on two constants, producing a 1-bit
//     result.
// NOTE(review): most condition bodies, and where `width`/`rhsCst`/`isUnsigned`
// are defined in lt/geq/gt, are on lines not visible here.
641 void LEQPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
642 MLIRContext *context) {
643 results.insert<patterns::LEQWithConstLHS>(context);
646 OpFoldResult LEQPrimOp::fold(FoldAdaptor adaptor) {
647 bool isUnsigned = getLhs().getType().base().isUnsigned();
650 if (getLhs() == getRhs())
656 auto commonWidth = std::max<int32_t>(*
width, rhsCst->getBitWidth());
// Never compare at zero width.
657 commonWidth = std::max(commonWidth, 1);
// Unsigned comparisons zero-extend the constant to the common width.
668 if (isUnsigned && rhsCst->zext(commonWidth)
681 [=](
const APSInt &a,
const APSInt &b) -> APInt {
682 return APInt(1, a <= b);
686 void LTPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
687 MLIRContext *context) {
688 results.insert<patterns::LTWithConstLHS>(context);
691 OpFoldResult LTPrimOp::fold(FoldAdaptor adaptor) {
692 IntType lhsType = getLhs().getType();
696 if (getLhs() == getRhs())
708 auto commonWidth = std::max<int32_t>(*
width, rhsCst->getBitWidth());
709 commonWidth = std::max(commonWidth, 1);
720 if (isUnsigned && rhsCst->zext(commonWidth)
733 [=](
const APSInt &a,
const APSInt &b) -> APInt {
734 return APInt(1, a < b);
738 void GEQPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
739 MLIRContext *context) {
740 results.insert<patterns::GEQWithConstLHS>(context);
743 OpFoldResult GEQPrimOp::fold(FoldAdaptor adaptor) {
744 IntType lhsType = getLhs().getType();
748 if (getLhs() == getRhs())
// An unsigned x >= 0 always holds; the folded value is on an elided line.
753 if (rhsCst->isZero() && isUnsigned)
760 auto commonWidth = std::max<int32_t>(*
width, rhsCst->getBitWidth());
761 commonWidth = std::max(commonWidth, 1);
764 if (isUnsigned && rhsCst->zext(commonWidth)
785 [=](
const APSInt &a,
const APSInt &b) -> APInt {
786 return APInt(1, a >= b);
790 void GTPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
791 MLIRContext *context) {
792 results.insert<patterns::GTWithConstLHS>(context);
795 OpFoldResult GTPrimOp::fold(FoldAdaptor adaptor) {
796 IntType lhsType = getLhs().getType();
800 if (getLhs() == getRhs())
806 auto commonWidth = std::max<int32_t>(*
width, rhsCst->getBitWidth());
807 commonWidth = std::max(commonWidth, 1);
810 if (isUnsigned && rhsCst->zext(commonWidth)
831 [=](
const APSInt &a,
const APSInt &b) -> APInt {
832 return APInt(1, a > b);
// eq fold: x == x is special-cased. eq(x, all-ones) is handled when both
// operands already have the (1-bit) result type; presumably it folds to x.
// Otherwise two constants are compared, producing a 1-bit result.
836 OpFoldResult EQPrimOp::fold(FoldAdaptor adaptor) {
838 if (getLhs() == getRhs())
844 if (rhsCst->isAllOnes() && getLhs().getType() == getType() &&
845 getRhs().getType() == getType())
851 [=](
const APSInt &a,
const APSInt &b) -> APInt {
852 return APInt(1, a == b);
// eq canonicalization fragment (header not visible) with rewrites:
//   - eq(x, 0) with all types equal to the result type -> not(x);
//   - eq(x, 0) with width > 1 -> not(orr(x));
//   - eq(x, all-ones) with width > 1 and equal operand types -> andr(x).
858 op, rewriter, [&](ArrayRef<Attribute> operands) -> OpFoldResult {
860 auto width = op.getLhs().getType().getBitWidthOrSentinel();
863 if (rhsCst->isZero() && op.getLhs().getType() == op.getType() &&
864 op.getRhs().getType() == op.getType()) {
865 return rewriter.create<NotPrimOp>(op.getLoc(), op.getLhs())
870 if (rhsCst->isZero() &&
width > 1) {
871 auto orrOp = rewriter.create<OrRPrimOp>(op.getLoc(), op.getLhs());
872 return rewriter.create<NotPrimOp>(op.getLoc(), orrOp).getResult();
876 if (rhsCst->isAllOnes() &&
width > 1 &&
877 op.getLhs().getType() == op.getRhs().getType()) {
878 return rewriter.create<AndRPrimOp>(op.getLoc(), op.getLhs())
// neq fold: x != x is special-cased. neq(x, 0) is handled when both operands
// already have the (1-bit) result type; presumably it folds to x. Otherwise
// two constants are compared, producing a 1-bit result.
886 OpFoldResult NEQPrimOp::fold(FoldAdaptor adaptor) {
888 if (getLhs() == getRhs())
894 if (rhsCst->isZero() && getLhs().getType() == getType() &&
895 getRhs().getType() == getType())
901 [=](
const APSInt &a,
const APSInt &b) -> APInt {
902 return APInt(1, a != b);
// neq canonicalization fragment (header not visible) with rewrites:
//   - neq(x, all-ones) with all types equal to the result type -> not(x);
//   - neq(x, 0) with width > 1 -> orr(x);
//   - neq(x, all-ones) with width > 1 and equal operand types -> not(andr(x)).
908 op, rewriter, [&](ArrayRef<Attribute> operands) -> OpFoldResult {
910 auto width = op.getLhs().getType().getBitWidthOrSentinel();
913 if (rhsCst->isAllOnes() && op.getLhs().getType() == op.getType() &&
914 op.getRhs().getType() == op.getType()) {
915 return rewriter.create<NotPrimOp>(op.getLoc(), op.getLhs())
920 if (rhsCst->isZero() &&
width > 1) {
921 return rewriter.create<OrRPrimOp>(op.getLoc(), op.getLhs())
926 if (rhsCst->isAllOnes() &&
width > 1 &&
927 op.getLhs().getType() == op.getRhs().getType()) {
928 auto andrOp = rewriter.create<AndRPrimOp>(op.getLoc(), op.getLhs());
929 return rewriter.create<NotPrimOp>(op.getLoc(), andrOp).getResult();
// Integer property-op and intrinsic folds. Most bodies are on lines not
// visible in this excerpt.
937 OpFoldResult IntegerAddOp::fold(FoldAdaptor adaptor) {
943 OpFoldResult IntegerMulOp::fold(FoldAdaptor adaptor) {
949 OpFoldResult IntegerShrOp::fold(FoldAdaptor adaptor) {
// Two constants fold to lhs << rhs. A zero shift amount is handled
// separately (result elided).
955 OpFoldResult IntegerShlOp::fold(FoldAdaptor adaptor) {
962 lhsCst->shl(*rhsCst));
965 if (rhsCst->isZero())
// sizeof: inspects the input's type; the computation is elided.
976 OpFoldResult SizeOfIntrinsicOp::fold(FoldAdaptor) {
977 auto base = getInput().getType();
984 OpFoldResult IsXIntrinsicOp::fold(FoldAdaptor adaptor) {
// asSInt / asUInt: part of the fold applies only when the result width is
// known (body elided). Each op registers a pattern that cancels a round-trip
// cast pair.
991 OpFoldResult AsSIntPrimOp::fold(FoldAdaptor adaptor) {
999 if (getType().base().hasWidth())
1006 void AsSIntPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1007 MLIRContext *context) {
1008 results.insert<patterns::StoUtoS>(context);
1011 OpFoldResult AsUIntPrimOp::fold(FoldAdaptor adaptor) {
1019 if (getType().base().hasWidth())
1026 void AsUIntPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1027 MLIRContext *context) {
1028 results.insert<patterns::UtoStoU>(context);
// asAsyncReset / asClock: a cast to the input's own type is handled
// (presumably folding to the input).
1031 OpFoldResult AsAsyncResetPrimOp::fold(FoldAdaptor adaptor) {
1033 if (getInput().getType() == getType())
1043 OpFoldResult AsClockPrimOp::fold(FoldAdaptor adaptor) {
1045 if (getInput().getType() == getType())
// cvt: a constant input extended to the result width is handled (body
// elided). Patterns cover signed and unsigned inputs.
1055 OpFoldResult CvtPrimOp::fold(FoldAdaptor adaptor) {
1061 getType().base().getWidthOrSentinel()))
1067 void CvtPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1068 MLIRContext *context) {
1069 results.insert<patterns::CVTSigned, patterns::CVTUnSigned>(context);
// neg: a constant folds to (0 - cst) at the constant's bit width.
1072 OpFoldResult NegPrimOp::fold(FoldAdaptor adaptor) {
1079 getType().base().getWidthOrSentinel()))
1080 return getIntAttr(getType(), APInt((*cst).getBitWidth(), 0) - *cst);
// not: constant folding is on elided lines. The patterns push not through
// not and through the comparison ops.
1085 OpFoldResult NotPrimOp::fold(FoldAdaptor adaptor) {
1090 getType().base().getWidthOrSentinel()))
1096 void NotPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1097 MLIRContext *context) {
1098 results.insert<patterns::NotNot, patterns::NotEq, patterns::NotNeq,
1099 patterns::NotLeq, patterns::NotLt, patterns::NotGeq,
1100 patterns::NotGt>(context);
// Bitwise reductions. Each fold handles three cases:
//   - a zero-width input (result elided);
//   - a constant input, giving a 1-bit result;
//   - a uint<1> input, where the reduction is presumably the identity.
// For constants: andr tests "all ones", orr tests "any bit set", and xorr
// computes the parity of the population count.
1103 OpFoldResult AndRPrimOp::fold(FoldAdaptor adaptor) {
1107 if (getInput().getType().getBitWidthOrSentinel() == 0)
1112 return getIntAttr(getType(), APInt(1, cst->isAllOnes()));
1116 if (
isUInt1(getInput().getType()))
1122 void AndRPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1123 MLIRContext *context) {
1125 .insert<patterns::AndRasSInt, patterns::AndRasUInt, patterns::AndRPadU,
1126 patterns::AndRPadS, patterns::AndRCatOneL, patterns::AndRCatOneR,
1127 patterns::AndRCatZeroL, patterns::AndRCatZeroR,
1128 patterns::AndRCatAndR_left, patterns::AndRCatAndR_right>(context);
1131 OpFoldResult OrRPrimOp::fold(FoldAdaptor adaptor) {
1135 if (getInput().getType().getBitWidthOrSentinel() == 0)
1140 return getIntAttr(getType(), APInt(1, !cst->isZero()));
1144 if (
isUInt1(getInput().getType()))
1150 void OrRPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1151 MLIRContext *context) {
1152 results.insert<patterns::OrRasSInt, patterns::OrRasUInt, patterns::OrRPadU,
1153 patterns::OrRCatZeroH, patterns::OrRCatZeroL,
1154 patterns::OrRCatOrR_left, patterns::OrRCatOrR_right>(context);
1157 OpFoldResult XorRPrimOp::fold(FoldAdaptor adaptor) {
1161 if (getInput().getType().getBitWidthOrSentinel() == 0)
1166 return getIntAttr(getType(), APInt(1, cst->popcount() & 1));
1169 if (
isUInt1(getInput().getType()))
1175 void XorRPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1176 MLIRContext *context) {
1177 results.insert<patterns::XorRasSInt, patterns::XorRasUInt, patterns::XorRPadU,
1178 patterns::XorRCatZeroH, patterns::XorRCatZeroL,
1179 patterns::XorRCatXorR_left, patterns::XorRCatXorR_right>(
// cat: two constants fold to their concatenation. APInt::concat places `rhs`
// in the low bits and `lhs` in the high bits.
1187 OpFoldResult CatPrimOp::fold(FoldAdaptor adaptor) {
1191 IntType lhsType = getLhs().getType();
1192 IntType rhsType = getRhs().getType();
1204 return getIntAttr(getType(), lhs->concat(*rhs));
1209 void DShlPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1210 MLIRContext *context) {
1211 results.insert<patterns::DShlOfConstant>(context);
1214 void DShrPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1215 MLIRContext *context) {
1216 results.insert<patterns::DShrOfConstant>(context);
1219 void CatPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1220 MLIRContext *context) {
1221 results.insert<patterns::CatBitsBits, patterns::CatDoubleConst,
1222 patterns::CatCast>(context);
// bitcast: a cast to the input's own type folds to the input. A bitcast of a
// bitcast folds to the inner input when the outer result type equals it.
// `op` is bound on an elided line.
1225 OpFoldResult BitCastOp::fold(FoldAdaptor adaptor) {
1228 if (op.getType() == op.getInput().getType())
1229 return op.getInput();
1233 if (BitCastOp in = dyn_cast_or_null<BitCastOp>(op.getInput().getDefiningOp()))
1234 if (op.getType() == in.getInput().getType())
1235 return in.getInput();
// bits: extracting the whole of a known-width input is handled (presumably
// folding to the input). A constant folds to bits [hi:lo], which are
// hi - lo + 1 bits starting at lo.
1240 OpFoldResult BitsPrimOp::fold(FoldAdaptor adaptor) {
1241 IntType inputType = getInput().getType();
1242 IntType resultType = getType();
1244 if (inputType == getType() && resultType.
hasWidth())
1251 cst->extractBits(getHi() - getLo() + 1, getLo()));
1256 void BitsPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1257 MLIRContext *context) {
1259 .insert<patterns::BitsOfBits, patterns::BitsOfMux, patterns::BitsOfAsUInt,
1260 patterns::BitsOfAnd, patterns::BitsOfPad>(context);
// Fragment of a helper whose leading parameters are elided (presumably
// replaceWithBits, per its later call sites). It replaces `op` with bits
// [hiBit:loBit] of `value`, extracting only when the widths differ, and then
// casts to match the result type's signedness.
1267 unsigned loBit, PatternRewriter &rewriter) {
1268 auto resType = type_cast<IntType>(op->getResult(0).getType());
1269 if (type_cast<IntType>(value.getType()).getWidth() != resType.getWidth())
1270 value = rewriter.create<BitsPrimOp>(op->getLoc(), value, hiBit, loBit);
1272 if (resType.isSigned() && !type_cast<IntType>(value.getType()).isSigned()) {
1273 value = rewriter.createOrFold<AsSIntPrimOp>(op->getLoc(), resType, value);
1274 }
else if (resType.isUnsigned() &&
1275 !type_cast<IntType>(value.getType()).isUnsigned()) {
1276 value = rewriter.createOrFold<AsUIntPrimOp>(op->getLoc(), resType, value);
1278 rewriter.replaceOp(op, value);
// foldMux: shared fold for two-input mux-like ops. It handles these cases:
//   - a zero-width result becomes a 0-bit zero constant;
//   - identical arms fold to that arm, when its type matches;
//   - an unknown result width bails out on an elided line;
//   - a constant select picks low (zero) or high (nonzero), when that arm's
//     type matches;
//   - two constant arms are special-cased: equal constants of equal width,
//     and high == 1 / low == 0 when the result type equals the select's
//     type.
// NOTE(review): several return values and the binding of `cond` are on
// elided lines.
1281 template <
typename OpTy>
1282 static OpFoldResult
foldMux(OpTy op,
typename OpTy::FoldAdaptor adaptor) {
1284 if (op.getType().getBitWidthOrSentinel() == 0)
1286 APInt(0, 0, op.getType().isSignedInteger()));
1289 if (op.getHigh() == op.getLow() && op.getHigh().getType() == op.getType())
1290 return op.getHigh();
1295 if (op.getType().getBitWidthOrSentinel() < 0)
1300 if (cond->isZero() && op.getLow().getType() == op.getType())
1302 if (!cond->isZero() && op.getHigh().getType() == op.getType())
1303 return op.getHigh();
1307 if (
auto lowCst =
getConstant(adaptor.getLow())) {
1309 if (
auto highCst =
getConstant(adaptor.getHigh())) {
1311 if (highCst->getBitWidth() == lowCst->getBitWidth() &&
1312 *highCst == *lowCst)
1315 if (highCst->isOne() && lowCst->isZero() &&
1316 op.getType() == op.getSel().getType())
// mux and mux2cell share foldMux. mux4cell does not fold.
1329 OpFoldResult MuxPrimOp::fold(FoldAdaptor adaptor) {
1330 return foldMux(*
this, adaptor);
1333 OpFoldResult Mux2CellIntrinsicOp::fold(FoldAdaptor adaptor) {
1334 return foldMux(*
this, adaptor);
1337 OpFoldResult Mux4CellIntrinsicOp::fold(FoldAdaptor adaptor) {
return {}; }
// MuxPad: pads each data operand whose width is known but differs from the
// mux's width up to that width, then rebuilds the mux with the padded
// operands. When neither operand changes, the pattern stops on an elided
// line, presumably reporting failure.
1344 class MuxPad :
public mlir::RewritePattern {
1346 MuxPad(MLIRContext *context)
1347 : RewritePattern(MuxPrimOp::getOperationName(), 0, context) {}
1350 matchAndRewrite(Operation *op,
1351 mlir::PatternRewriter &rewriter)
const override {
1352 auto mux = cast<MuxPrimOp>(op);
1353 auto width = mux.getType().getBitWidthOrSentinel();
// Returns `input` unchanged when its width is unknown or already matches;
// otherwise returns a pad of it to `width`.
1357 auto pad = [&](Value input) -> Value {
1359 type_cast<FIRRTLBaseType>(input.getType()).getBitWidthOrSentinel();
1360 if (inputWidth < 0 ||
width == inputWidth)
1363 .create<PadPrimOp>(mux.getLoc(), mux.getType(), input,
width)
1367 auto newHigh = pad(mux.getHigh());
1368 auto newLow = pad(mux.getLow());
1369 if (newHigh == mux.getHigh() && newLow == mux.getLow())
1372 replaceOpWithNewOpAndCopyName<MuxPrimOp>(
1373 rewriter, op, mux.getType(), ValueRange{mux.getSel(), newHigh, newLow},
// MuxSharedCond: simplifies nested muxes that reuse the outer mux's select.
// Inside the high arm the select is known to be true, so a nested mux on the
// same select reduces to its high operand (tryCondTrue). Likewise, in the low
// arm it reduces to its low operand (tryCondFalse). The recursive search is
// bounded by depthLimit.
1381 class MuxSharedCond :
public mlir::RewritePattern {
1383 MuxSharedCond(MLIRContext *context)
1384 : RewritePattern(MuxPrimOp::getOperationName(), 0, context) {}
1386 static const int depthLimit = 5;
// Rewrites `mux` to use (high, low). If `updateInPlace` is set, its operands
// are modified in place. Otherwise a new mux is created right after it; the
// return statements are on elided lines.
1388 Value updateOrClone(MuxPrimOp mux, Value high, Value low,
1389 mlir::PatternRewriter &rewriter,
1390 bool updateInPlace)
const {
1391 if (updateInPlace) {
1392 rewriter.modifyOpInPlace(mux, [&] {
1393 mux.setOperand(1, high);
1394 mux.setOperand(2, low);
1398 rewriter.setInsertionPointAfter(mux);
1400 .create<MuxPrimOp>(mux.getLoc(), mux.getType(),
1401 ValueRange{mux.getSel(), high, low})
// Simplifies `op` assuming `cond` is true. A mux selected by `cond` yields
// its high operand. Otherwise the search recurses into both arms. The
// non-mux and depth-limit exits are on elided lines.
1406 Value tryCondTrue(Value op, Value cond, mlir::PatternRewriter &rewriter,
1407 bool updateInPlace,
int limit)
const {
1408 MuxPrimOp mux = op.getDefiningOp<MuxPrimOp>();
1411 if (mux.getSel() == cond)
1412 return mux.getHigh();
1413 if (limit > depthLimit)
// Mutating a mux with other users would change their values too, so in-place
// updates are allowed only while every mux on the path has one use.
1415 updateInPlace &= mux->hasOneUse();
1417 if (Value v = tryCondTrue(mux.getHigh(), cond, rewriter, updateInPlace,
1419 return updateOrClone(mux, v, mux.getLow(), rewriter, updateInPlace);
1422 tryCondTrue(mux.getLow(), cond, rewriter, updateInPlace, limit + 1))
1423 return updateOrClone(mux, mux.getHigh(), v, rewriter, updateInPlace);
// Mirror of tryCondTrue assuming `cond` is false: a mux selected by `cond`
// yields its low operand.
1428 Value tryCondFalse(Value op, Value cond, mlir::PatternRewriter &rewriter,
1429 bool updateInPlace,
int limit)
const {
1430 MuxPrimOp mux = op.getDefiningOp<MuxPrimOp>();
1433 if (mux.getSel() == cond)
1434 return mux.getLow();
1435 if (limit > depthLimit)
1437 updateInPlace &= mux->hasOneUse();
1439 if (Value v = tryCondFalse(mux.getHigh(), cond, rewriter, updateInPlace,
1441 return updateOrClone(mux, v, mux.getLow(), rewriter, updateInPlace);
1443 if (Value v = tryCondFalse(mux.getLow(), cond, rewriter, updateInPlace,
1445 return updateOrClone(mux, mux.getHigh(), v, rewriter, updateInPlace);
// Applies tryCondTrue to the high arm and tryCondFalse to the low arm, and
// installs any simplified operand in place. The width check and the returns
// are on elided lines.
1451 matchAndRewrite(Operation *op,
1452 mlir::PatternRewriter &rewriter)
const override {
1453 auto mux = cast<MuxPrimOp>(op);
1454 auto width = mux.getType().getBitWidthOrSentinel();
1458 if (Value v = tryCondTrue(mux.getHigh(), mux.getSel(), rewriter,
true, 0)) {
1459 rewriter.modifyOpInPlace(mux, [&] { mux.setOperand(1, v); });
1463 if (Value v = tryCondFalse(mux.getLow(), mux.getSel(), rewriter,
true, 0)) {
1464 rewriter.modifyOpInPlace(mux, [&] { mux.setOperand(2, v); });
// Registers the mux rewrite patterns: the local MuxPad/MuxSharedCond classes
// plus the generated patterns. The trailing argument of the add call is on an
// elided line. The cell intrinsics only register select-padding patterns.
1473 void MuxPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1474 MLIRContext *context) {
1476 .add<MuxPad, MuxSharedCond, patterns::MuxEQOperands,
1477 patterns::MuxEQOperandsSwapped, patterns::MuxNEQ, patterns::MuxNot,
1478 patterns::MuxSameTrue, patterns::MuxSameFalse,
1479 patterns::NarrowMuxLHS, patterns::NarrowMuxRHS, patterns::MuxPadSel>(
1483 void Mux2CellIntrinsicOp::getCanonicalizationPatterns(
1484 RewritePatternSet &results, MLIRContext *context) {
1485 results.add<patterns::Mux2PadSel>(context);
1488 void Mux4CellIntrinsicOp::getCanonicalizationPatterns(
1489 RewritePatternSet &results, MLIRContext *context) {
1490 results.add<patterns::Mux4PadSel>(context);
// pad: a pad to the input's own type is handled (presumably returning the
// input). An unknown destination width bails. A constant is sign-extended
// when the input is signed and nonzero-width; otherwise it is zero-extended.
1493 OpFoldResult PadPrimOp::fold(FoldAdaptor adaptor) {
1494 auto input = this->getInput();
1497 if (input.getType() == getType())
1501 auto inputType = input.getType().base();
1508 auto destWidth = getType().base().getWidthOrSentinel();
1509 if (destWidth == -1)
1512 if (inputType.
isSigned() && cst->getBitWidth())
1513 return getIntAttr(getType(), cst->sext(destWidth));
1514 return getIntAttr(getType(), cst->zext(destWidth));
// shl: a zero shift is handled on an elided line. A constant with known
// input width is widened to inputWidth + shift and then shifted left.
1520 OpFoldResult ShlPrimOp::fold(FoldAdaptor adaptor) {
1521 auto input = this->getInput();
1522 IntType inputType = input.getType();
1523 int shiftAmount = getAmount();
1526 if (shiftAmount == 0)
1532 if (inputWidth != -1) {
1533 auto resultWidth = inputWidth + shiftAmount;
1534 shiftAmount = std::min(shiftAmount, resultWidth);
1535 return getIntAttr(getType(), cst->zext(resultWidth).shl(shiftAmount));
// shr: a zero shift of a nonzero-width input is handled. Unknown and zero
// input widths are handled separately. Shifting an unsigned value by at least
// its width yields a 0-bit zero. For constants, signed values use an
// arithmetic shift capped at width-1 and unsigned values a logical shift. The
// result keeps at least one bit.
1541 OpFoldResult ShrPrimOp::fold(FoldAdaptor adaptor) {
1542 auto input = this->getInput();
1543 IntType inputType = input.getType();
1544 int shiftAmount = getAmount();
1550 if (shiftAmount == 0 && inputWidth > 0)
1553 if (inputWidth == -1)
1555 if (inputWidth == 0)
1560 if (shiftAmount >= inputWidth && inputType.
isUnsigned())
1561 return getIntAttr(getType(), APInt(0, 0,
false));
1567 value = cst->ashr(std::min(shiftAmount, inputWidth - 1));
1569 value = cst->lshr(std::min(shiftAmount, inputWidth));
1570 auto resultWidth = std::max(inputWidth - shiftAmount, 1);
1571 return getIntAttr(getType(), value.trunc(resultWidth));
// shr canonicalization fragment (header not visible): rewrites shr as
// bits(input, width-1, shift). Over-shifting an unsigned input is handled on
// elided lines. On the visible signed path, the shift is clamped to width-1,
// which keeps the sign bit.
1577 auto inputWidth = op.getInput().getType().base().getWidthOrSentinel();
1578 if (inputWidth <= 0)
1582 unsigned shiftAmount = op.getAmount();
1583 if (
int(shiftAmount) >= inputWidth) {
1585 if (op.getType().base().isUnsigned())
1591 shiftAmount = inputWidth - 1;
1594 replaceWithBits(op, op.getInput(), inputWidth - 1, shiftAmount, rewriter);
// head canonicalization fragment (header not visible): keeps the top
// `amount` bits as bits(input, width-1, width-amount).
1599 PatternRewriter &rewriter) {
1600 auto inputWidth = op.getInput().getType().base().getWidthOrSentinel();
1601 if (inputWidth <= 0)
1605 unsigned keepAmount = op.getAmount();
1607 replaceWithBits(op, op.getInput(), inputWidth - 1, inputWidth - keepAmount,
// head fold: a constant is shifted right by (width - amount) and truncated to
// `amount` bits.
1612 OpFoldResult HeadPrimOp::fold(FoldAdaptor adaptor) {
1616 getInput().getType().base().getWidthOrSentinel() - getAmount();
1617 return getIntAttr(getType(), cst->lshr(shiftAmount).trunc(getAmount()));
// tail fold: a constant is truncated to the result width.
1623 OpFoldResult TailPrimOp::fold(FoldAdaptor adaptor) {
1627 cst->trunc(getType().base().getWidthOrSentinel()));
// tail canonicalization fragment (header not visible). Dropping the entire
// width is treated differently; the branches are on elided lines.
1632 PatternRewriter &rewriter) {
1633 auto inputWidth = op.getInput().getType().base().getWidthOrSentinel();
1634 if (inputWidth <= 0)
1638 unsigned dropAmount = op.getAmount();
1639 if (dropAmount !=
unsigned(inputWidth))
1645 void SubaccessOp::getCanonicalizationPatterns(RewritePatternSet &results,
1646 MLIRContext *context) {
1647 results.add<patterns::SubaccessOfConstant>(context);
// multibit_mux fold. Inputs are stored in reverse index order: index i
// selects getInputs()[size - 1 - i]. A single input folds to that input
// (operand 1). A constant in-range index folds to the selected input.
1650 OpFoldResult MultibitMuxOp::fold(FoldAdaptor adaptor) {
1652 if (adaptor.getInputs().size() == 1)
1653 return getOperand(1);
1655 if (
auto constIndex =
getConstant(adaptor.getIndex())) {
1656 auto index = constIndex->getZExtValue();
1657 if (index < getInputs().size())
1658 return getInputs()[getInputs().size() - 1 - index];
// multibit_mux canonicalization fragment (header not visible).
// Identical inputs are handled first (action elided).
1665 PatternRewriter &rewriter) {
1669 if (llvm::all_of(op.getInputs().drop_front(), [&](
auto input) {
1670 return input == op.getInputs().front();
// A known index width w can address only 2^w inputs. The leading (highest
// index) inputs beyond that are unreachable and are erased.
1678 auto indexWidth = op.getIndex().getType().getBitWidthOrSentinel();
1679 uint64_t inputSize = op.getInputs().size();
1680 if (indexWidth >= 0 && indexWidth < 64 && 1ull << indexWidth < inputSize) {
1681 rewriter.modifyOpInPlace(op, [&]() {
1682 op.getInputsMutable().erase(0, inputSize - (1ull << indexWidth));
// If every input is subindex(v, k) of the same vector `v`, with k matching
// the reversed input order, the mux is a dynamic subaccess of `v`.
1689 if (
auto lastSubindex = op.getInputs().back().getDefiningOp<SubindexOp>()) {
1690 if (llvm::all_of(llvm::enumerate(op.getInputs()), [&](
auto e) {
1691 auto subindex = e.value().template getDefiningOp<SubindexOp>();
1692 return subindex && lastSubindex.getInput() == subindex.getInput() &&
1693 subindex.getIndex() + e.index() + 1 == op.getInputs().size();
1695 replaceOpWithNewOpAndCopyName<SubaccessOp>(
1696 rewriter, op, lastSubindex.getInput(), op.getIndex());
// Two inputs with a 1-bit index become an ordinary mux.
1702 if (op.getInputs().size() != 2)
1706 auto uintType = op.getIndex().getType();
1707 if (uintType.getBitWidthOrSentinel() != 1)
1711 replaceOpWithNewOpAndCopyName<MuxPrimOp>(
1712 rewriter, op, op.getIndex(), op.getInputs()[0], op.getInputs()[1]);
// Fragment (header not visible): scans the users of `value`. Attach and
// subfield/subaccess/subindex users are handled on an elided line
// (presumably rejecting). A connect that writes `value` must be the unique
// MatchingConnectOp into it and live in value's block; otherwise an elided
// exit is taken.
1732 for (Operation *user : value.getUsers()) {
1734 if (isa<AttachOp, SubfieldOp, SubaccessOp, SubindexOp>(user))
1737 if (
auto aConnect = dyn_cast<FConnectLike>(user))
1738 if (aConnect.getDest() == value) {
1739 auto matchingConnect = dyn_cast<MatchingConnectOp>(*aConnect);
1742 if (!matchingConnect || (
connect &&
connect != matchingConnect) ||
1743 matchingConnect->getBlock() != value.getParentBlock())
// Fragment (header not visible): a connect into a wire or register
// declaration. Forceable declarations are excluded (condition partly elided).
// After more checks, uses of the declaration are presumably replaced by the
// connected source. A defining op of the source that is not already first in
// the declaration's block is moved to the front, presumably so it dominates
// all uses. The connect is then erased.
// NOTE(review): the single-use, non-wire and constant-source branches have
// elided bodies; confirm their exact effect against the full file.
1754 PatternRewriter &rewriter) {
1757 Operation *connectedDecl = op.getDest().getDefiningOp();
1762 if (!isa<WireOp>(connectedDecl) && !isa<RegOp>(connectedDecl))
1767 cast<Forceable>(connectedDecl).isForceable())
1775 if (connectedDecl->hasOneUse())
1779 auto *declBlock = connectedDecl->getBlock();
1780 auto *srcValueOp = op.getSrc().getDefiningOp();
1783 if (!isa<WireOp>(connectedDecl))
1789 if (!isa<ConstantOp>(srcValueOp))
1791 if (srcValueOp->getBlock() != declBlock)
1797 auto replacement = op.getSrc();
1800 if (srcValueOp && srcValueOp != &declBlock->front())
1801 srcValueOp->moveBefore(&declBlock->front());
1808 rewriter.eraseOp(op);
// Registers the connect rewrite patterns.
1812 void ConnectOp::getCanonicalizationPatterns(RewritePatternSet &results,
1813 MLIRContext *context) {
1814 results.insert<patterns::ConnectExtension, patterns::ConnectSameType>(
// Fragment (header not visible): looks among the users of `value` for
// another AttachOp, other than `dominatedAttach`, that appears before it in
// the block. The action taken is on elided lines.
1819 PatternRewriter &rewriter) {
1836 for (
auto *user : value.getUsers()) {
1837 auto attach = dyn_cast<AttachOp>(user);
1838 if (!attach || attach == dominatedAttach)
1840 if (attach->isBeforeInBlock(dominatedAttach))
// Attach canonicalization fragment (header not visible). An attach with at
// most one operand connects nothing and is erased.
1848 if (op.getNumOperands() <= 1) {
1849 rewriter.eraseOp(op);
1853 for (
auto operand : op.getOperands()) {
// Merges another attach that shares `operand` into a single new attach with
// the union of operands (the shared operand is not duplicated). Both
// originals are erased.
1860 SmallVector<Value> newOperands(op.getOperands());
1861 for (
auto newOperand : attach.getOperands())
1862 if (newOperand != operand)
1863 newOperands.push_back(newOperand);
1864 rewriter.create<AttachOp>(op->getLoc(), newOperands);
1865 rewriter.eraseOp(attach);
1866 rewriter.eraseOp(op);
// A wire with no dont-touch, a single use (this attach) and no force
// capability contributes nothing. It is dropped from the attach and erased.
1874 if (
auto wire = dyn_cast_or_null<WireOp>(operand.getDefiningOp())) {
1875 if (!
hasDontTouch(wire.getOperation()) && wire->hasOneUse() &&
1876 !wire.isForceable()) {
1877 SmallVector<Value> newOperands;
1878 for (
auto newOperand : op.getOperands())
1879 if (newOperand != operand)
1880 newOperands.push_back(newOperand);
1882 rewriter.create<AttachOp>(op->getLoc(), newOperands);
1883 rewriter.eraseOp(op);
1884 rewriter.eraseOp(wire);
// Fragment (header not visible): splices the single block of `region` into
// the enclosing block immediately before `op`. The region must contain
// exactly one block, which the assertion checks.
// The entry block is passed by address, as inlineBlockBefore expects a
// Block *; the previous text had the `&` garbled into "®".
1895 assert(llvm::hasSingleElement(region) &&
"expected single-region block");
1896 rewriter.inlineBlockBefore(&region.front(), op, {});
// When canonicalization fragment (header not visible). For a constant
// condition: all-ones presumably keeps the then-region, otherwise a
// non-empty else-region presumably takes its place (both bodies elided).
// The when itself is then erased.
1900 if (
auto constant = op.getCondition().getDefiningOp<firrtl::ConstantOp>()) {
1901 if (constant.getValue().isAllOnes())
1903 else if (op.hasElseRegion() && !op.getElseRegion().empty())
1906 rewriter.eraseOp(op);
// A non-empty then-block with an empty else-block: the empty else-block is
// erased.
1912 if (!op.getThenBlock().empty() && op.hasElseRegion() &&
1913 op.getElseBlock().empty()) {
1914 rewriter.eraseBlock(&op.getElseBlock());
// With a non-empty then-block, an elided exit is taken. Otherwise, with no
// else-block or an empty one, the when does nothing and is erased.
1921 if (!op.getThenBlock().empty())
1925 if (!op.hasElseRegion() || op.getElseBlock().empty()) {
1926 rewriter.eraseOp(op);
1935 struct FoldNodeName :
public mlir::RewritePattern {
1936 FoldNodeName(MLIRContext *context)
1937 : RewritePattern(NodeOp::getOperationName(), 0, context) {}
1938 LogicalResult matchAndRewrite(Operation *op,
1939 PatternRewriter &rewriter)
const override {
1940 auto node = cast<NodeOp>(op);
1941 auto name = node.getNameAttr();
1942 if (!node.hasDroppableName() || node.getInnerSym() ||
1945 auto *newOp = node.getInput().getDefiningOp();
1947 if (newOp && !isa<InstanceOp>(newOp))
1949 rewriter.replaceOp(node, node.getInput());
1955 struct NodeBypass :
public mlir::RewritePattern {
1956 NodeBypass(MLIRContext *context)
1957 : RewritePattern(NodeOp::getOperationName(), 0, context) {}
1958 LogicalResult matchAndRewrite(Operation *op,
1959 PatternRewriter &rewriter)
const override {
1960 auto node = cast<NodeOp>(op);
1961 if (node.getInnerSym() || !
AnnotationSet(node).canBeDeleted() ||
1962 node.use_empty() || node.isForceable())
1964 rewriter.replaceAllUsesWith(node.getResult(), node.getInput());
1971 template <
typename OpTy>
1973 PatternRewriter &rewriter) {
1974 if (!op.isForceable() || !op.getDataRef().use_empty())
1982 LogicalResult NodeOp::fold(FoldAdaptor adaptor,
1983 SmallVectorImpl<OpFoldResult> &results) {
1988 if (getAnnotationsAttr() &&
1993 if (!adaptor.getInput())
1996 results.push_back(adaptor.getInput());
2000 void NodeOp::getCanonicalizationPatterns(RewritePatternSet &results,
2001 MLIRContext *context) {
2002 results.insert<FoldNodeName>(context);
2003 results.add(demoteForceableIfUnused<NodeOp>);
2009 struct AggOneShot :
public mlir::RewritePattern {
2010 AggOneShot(StringRef name, uint32_t weight, MLIRContext *context)
2011 : RewritePattern(name, 0, context) {}
2013 SmallVector<Value> getCompleteWrite(Operation *lhs)
const {
2014 auto lhsTy = lhs->getResult(0).getType();
2015 if (!type_isa<BundleType, FVectorType>(lhsTy))
2018 DenseMap<uint32_t, Value> fields;
2019 for (Operation *user : lhs->getResult(0).getUsers()) {
2020 if (user->getParentOp() != lhs->getParentOp())
2022 if (
auto aConnect = dyn_cast<MatchingConnectOp>(user)) {
2023 if (aConnect.getDest() == lhs->getResult(0))
2025 }
else if (
auto subField = dyn_cast<SubfieldOp>(user)) {
2026 for (Operation *subuser : subField.getResult().getUsers()) {
2027 if (
auto aConnect = dyn_cast<MatchingConnectOp>(subuser)) {
2028 if (aConnect.getDest() == subField) {
2029 if (subuser->getParentOp() != lhs->getParentOp())
2031 if (fields.count(subField.getFieldIndex()))
2033 fields[subField.getFieldIndex()] = aConnect.getSrc();
2039 }
else if (
auto subIndex = dyn_cast<SubindexOp>(user)) {
2040 for (Operation *subuser : subIndex.getResult().getUsers()) {
2041 if (
auto aConnect = dyn_cast<MatchingConnectOp>(subuser)) {
2042 if (aConnect.getDest() == subIndex) {
2043 if (subuser->getParentOp() != lhs->getParentOp())
2045 if (fields.count(subIndex.getIndex()))
2047 fields[subIndex.getIndex()] = aConnect.getSrc();
2058 SmallVector<Value> values;
2059 uint32_t total = type_isa<BundleType>(lhsTy)
2060 ? type_cast<BundleType>(lhsTy).getNumElements()
2061 : type_cast<FVectorType>(lhsTy).getNumElements();
2062 for (uint32_t i = 0; i < total; ++i) {
2063 if (!fields.count(i))
2065 values.push_back(fields[i]);
2070 LogicalResult matchAndRewrite(Operation *op,
2071 PatternRewriter &rewriter)
const override {
2072 auto values = getCompleteWrite(op);
2075 rewriter.setInsertionPointToEnd(op->getBlock());
2076 auto dest = op->getResult(0);
2077 auto destType = dest.getType();
2080 if (!type_cast<FIRRTLBaseType>(destType).isPassive())
2083 Value newVal = type_isa<BundleType>(destType)
2084 ? rewriter.createOrFold<BundleCreateOp>(op->getLoc(),
2086 : rewriter.createOrFold<VectorCreateOp>(
2087 op->getLoc(), destType, values);
2088 rewriter.createOrFold<MatchingConnectOp>(op->getLoc(), dest, newVal);
2089 for (Operation *user : dest.getUsers()) {
2090 if (
auto subIndex = dyn_cast<SubindexOp>(user)) {
2091 for (Operation *subuser :
2092 llvm::make_early_inc_range(subIndex.getResult().getUsers()))
2093 if (
auto aConnect = dyn_cast<MatchingConnectOp>(subuser))
2094 if (aConnect.getDest() == subIndex)
2095 rewriter.eraseOp(aConnect);
2096 }
else if (
auto subField = dyn_cast<SubfieldOp>(user)) {
2097 for (Operation *subuser :
2098 llvm::make_early_inc_range(subField.getResult().getUsers()))
2099 if (
auto aConnect = dyn_cast<MatchingConnectOp>(subuser))
2100 if (aConnect.getDest() == subField)
2101 rewriter.eraseOp(aConnect);
2108 struct WireAggOneShot :
public AggOneShot {
2109 WireAggOneShot(MLIRContext *context)
2110 : AggOneShot(WireOp::getOperationName(), 0, context) {}
2112 struct SubindexAggOneShot :
public AggOneShot {
2113 SubindexAggOneShot(MLIRContext *context)
2114 : AggOneShot(SubindexOp::getOperationName(), 0, context) {}
2116 struct SubfieldAggOneShot :
public AggOneShot {
2117 SubfieldAggOneShot(MLIRContext *context)
2118 : AggOneShot(SubfieldOp::getOperationName(), 0, context) {}
2122 void WireOp::getCanonicalizationPatterns(RewritePatternSet &results,
2123 MLIRContext *context) {
2124 results.insert<WireAggOneShot>(context);
2125 results.add(demoteForceableIfUnused<WireOp>);
2128 void SubindexOp::getCanonicalizationPatterns(RewritePatternSet &results,
2129 MLIRContext *context) {
2130 results.insert<SubindexAggOneShot>(context);
2133 OpFoldResult SubindexOp::fold(FoldAdaptor adaptor) {
2134 auto attr = dyn_cast_or_null<ArrayAttr>(adaptor.getInput());
2137 return attr[getIndex()];
2140 OpFoldResult SubfieldOp::fold(FoldAdaptor adaptor) {
2141 auto attr = dyn_cast_or_null<ArrayAttr>(adaptor.getInput());
2144 auto index = getFieldIndex();
2148 void SubfieldOp::getCanonicalizationPatterns(RewritePatternSet &results,
2149 MLIRContext *context) {
2150 results.insert<SubfieldAggOneShot>(context);
2154 ArrayRef<Attribute> operands) {
2155 for (
auto operand : operands)
2161 OpFoldResult BundleCreateOp::fold(FoldAdaptor adaptor) {
2164 if (getNumOperands() > 0)
2165 if (SubfieldOp first = getOperand(0).getDefiningOp<SubfieldOp>())
2166 if (first.getFieldIndex() == 0 &&
2167 first.getInput().getType() == getType() &&
2169 llvm::drop_begin(llvm::enumerate(getOperands())), [&](
auto elem) {
2171 elem.value().
template getDefiningOp<SubfieldOp>();
2172 return subindex && subindex.getInput() == first.getInput() &&
2173 subindex.getFieldIndex() == elem.index();
2175 return first.getInput();
2180 OpFoldResult VectorCreateOp::fold(FoldAdaptor adaptor) {
2183 if (getNumOperands() > 0)
2184 if (SubindexOp first = getOperand(0).getDefiningOp<SubindexOp>())
2185 if (first.getIndex() == 0 && first.getInput().getType() == getType() &&
2187 llvm::drop_begin(llvm::enumerate(getOperands())), [&](
auto elem) {
2189 elem.value().
template getDefiningOp<SubindexOp>();
2190 return subindex && subindex.getInput() == first.getInput() &&
2191 subindex.getIndex() == elem.index();
2193 return first.getInput();
2198 OpFoldResult UninferredResetCastOp::fold(FoldAdaptor adaptor) {
2199 if (getOperand().getType() == getType())
2200 return getOperand();
2207 struct FoldResetMux :
public mlir::RewritePattern {
2208 FoldResetMux(MLIRContext *context)
2209 : RewritePattern(RegResetOp::getOperationName(), 0, context) {}
2210 LogicalResult matchAndRewrite(Operation *op,
2211 PatternRewriter &rewriter)
const override {
2212 auto reg = cast<RegResetOp>(op);
2214 dyn_cast_or_null<ConstantOp>(
reg.getResetValue().getDefiningOp());
2223 auto mux = dyn_cast_or_null<MuxPrimOp>(con.getSrc().getDefiningOp());
2226 auto *high = mux.getHigh().getDefiningOp();
2227 auto *low = mux.getLow().getDefiningOp();
2228 auto constOp = dyn_cast_or_null<ConstantOp>(high);
2230 if (constOp && low !=
reg)
2232 if (dyn_cast_or_null<ConstantOp>(low) && high ==
reg)
2233 constOp = dyn_cast<ConstantOp>(low);
2235 if (!constOp || constOp.getType() != reset.getType() ||
2236 constOp.getValue() != reset.getValue())
2240 auto regTy =
reg.getResult().getType();
2241 if (con.getDest().getType() != regTy || con.getSrc().getType() != regTy ||
2242 mux.getHigh().getType() != regTy || mux.getLow().getType() != regTy ||
2243 regTy.getBitWidthOrSentinel() < 0)
2249 if (constOp != &con->getBlock()->front())
2250 constOp->moveBefore(&con->getBlock()->front());
2255 rewriter.eraseOp(con);
2262 if (
auto c = v.getDefiningOp<ConstantOp>())
2263 return c.getValue().isOne();
2264 if (
auto sc = v.getDefiningOp<SpecialConstantOp>())
2265 return sc.getValue();
2269 static LogicalResult
2276 replaceOpWithNewOpAndCopyName<NodeOp>(
2277 rewriter,
reg,
reg.getResetValue(),
reg.getNameAttr(),
reg.getNameKind(),
2278 reg.getAnnotationsAttr(),
reg.getInnerSymAttr(),
reg.getForceable());
2282 void RegResetOp::getCanonicalizationPatterns(RewritePatternSet &results,
2283 MLIRContext *context) {
2284 results.add<patterns::RegResetWithZeroReset, FoldResetMux>(context);
2286 results.add(demoteForceableIfUnused<RegResetOp>);
2291 auto portTy = type_cast<BundleType>(port.getType());
2292 auto fieldIndex = portTy.getElementIndex(name);
2293 assert(fieldIndex &&
"missing field on memory port");
2296 for (
auto *op : port.getUsers()) {
2297 auto portAccess = cast<SubfieldOp>(op);
2298 if (fieldIndex != portAccess.getFieldIndex())
2303 value = conn.getSrc();
2313 auto portConst = value.getDefiningOp<ConstantOp>();
2316 return portConst.getValue().isZero();
2321 auto portTy = type_cast<BundleType>(port.getType());
2322 auto fieldIndex = portTy.getElementIndex(
data);
2323 assert(fieldIndex &&
"missing enable flag on memory port");
2325 for (
auto *op : port.getUsers()) {
2326 auto portAccess = cast<SubfieldOp>(op);
2327 if (fieldIndex != portAccess.getFieldIndex())
2329 if (!portAccess.use_empty())
2338 StringRef name, Value value) {
2339 auto portTy = type_cast<BundleType>(port.getType());
2340 auto fieldIndex = portTy.getElementIndex(name);
2341 assert(fieldIndex &&
"missing field on memory port");
2343 for (
auto *op : llvm::make_early_inc_range(port.getUsers())) {
2344 auto portAccess = cast<SubfieldOp>(op);
2345 if (fieldIndex != portAccess.getFieldIndex())
2347 rewriter.replaceAllUsesWith(portAccess, value);
2348 rewriter.eraseOp(portAccess);
2353 static void erasePort(PatternRewriter &rewriter, Value port) {
2356 auto getClock = [&] {
2358 clock = rewriter.create<SpecialConstantOp>(
2367 for (
auto *op : port.getUsers()) {
2368 auto subfield = dyn_cast<SubfieldOp>(op);
2370 auto ty = port.getType();
2371 auto reg = rewriter.create<RegOp>(port.getLoc(), ty, getClock());
2372 rewriter.replaceAllUsesWith(port,
reg.getResult());
2381 for (
auto *accessOp : llvm::make_early_inc_range(port.getUsers())) {
2382 auto access = cast<SubfieldOp>(accessOp);
2383 for (
auto *user : llvm::make_early_inc_range(access->getUsers())) {
2384 auto connect = dyn_cast<FConnectLike>(user);
2386 rewriter.eraseOp(user);
2390 if (access.use_empty()) {
2391 rewriter.eraseOp(access);
2397 auto ty = access.getType();
2398 auto reg = rewriter.create<RegOp>(access.getLoc(), ty, getClock());
2399 rewriter.replaceOp(access,
reg.getResult());
2401 assert(port.use_empty() &&
"port should have no remaining uses");
2406 struct FoldZeroWidthMemory :
public mlir::RewritePattern {
2407 FoldZeroWidthMemory(MLIRContext *context)
2408 : RewritePattern(MemOp::getOperationName(), 0, context) {}
2409 LogicalResult matchAndRewrite(Operation *op,
2410 PatternRewriter &rewriter)
const override {
2411 MemOp mem = cast<MemOp>(op);
2415 if (!firrtl::type_isa<IntType>(mem.getDataType()) ||
2416 mem.getDataType().getBitWidthOrSentinel() != 0)
2420 for (
auto port : mem.getResults())
2421 for (
auto *user : port.getUsers())
2422 if (!isa<SubfieldOp>(user))
2427 for (
auto port : op->getResults()) {
2428 for (
auto *user : llvm::make_early_inc_range(port.getUsers())) {
2429 SubfieldOp sfop = cast<SubfieldOp>(user);
2430 StringRef fieldName = sfop.getFieldName();
2431 auto wire = replaceOpWithNewOpAndCopyName<WireOp>(
2432 rewriter, sfop, sfop.getResult().getType())
2434 if (fieldName.ends_with(
"data")) {
2436 auto zero = rewriter.create<firrtl::ConstantOp>(
2437 wire.getLoc(), firrtl::type_cast<IntType>(wire.getType()),
2439 rewriter.create<MatchingConnectOp>(wire.getLoc(), wire, zero);
2443 rewriter.eraseOp(op);
2449 struct FoldReadOrWriteOnlyMemory :
public mlir::RewritePattern {
2450 FoldReadOrWriteOnlyMemory(MLIRContext *context)
2451 : RewritePattern(MemOp::getOperationName(), 0, context) {}
2452 LogicalResult matchAndRewrite(Operation *op,
2453 PatternRewriter &rewriter)
const override {
2454 MemOp mem = cast<MemOp>(op);
2457 bool isRead =
false, isWritten =
false;
2458 for (
unsigned i = 0; i < mem.getNumResults(); ++i) {
2459 switch (mem.getPortKind(i)) {
2460 case MemOp::PortKind::Read:
2465 case MemOp::PortKind::Write:
2470 case MemOp::PortKind::Debug:
2471 case MemOp::PortKind::ReadWrite:
2474 llvm_unreachable(
"unknown port kind");
2476 assert((!isWritten || !isRead) &&
"memory is in use");
2481 if (isRead && mem.getInit())
2484 for (
auto port : mem.getResults())
2487 rewriter.eraseOp(op);
2493 struct FoldUnusedPorts :
public mlir::RewritePattern {
2494 FoldUnusedPorts(MLIRContext *context)
2495 : RewritePattern(MemOp::getOperationName(), 0, context) {}
2496 LogicalResult matchAndRewrite(Operation *op,
2497 PatternRewriter &rewriter)
const override {
2498 MemOp mem = cast<MemOp>(op);
2502 llvm::SmallBitVector deadPorts(mem.getNumResults());
2503 for (
auto [i, port] : llvm::enumerate(mem.getResults())) {
2505 if (!mem.getPortAnnotation(i).empty())
2509 auto kind = mem.getPortKind(i);
2510 if (kind == MemOp::PortKind::Debug)
2519 if (kind == MemOp::PortKind::Read &&
isPortUnused(port,
"data")) {
2524 if (deadPorts.none())
2528 SmallVector<Type> resultTypes;
2529 SmallVector<StringRef> portNames;
2530 SmallVector<Attribute> portAnnotations;
2531 for (
auto [i, port] : llvm::enumerate(mem.getResults())) {
2534 resultTypes.push_back(port.getType());
2535 portNames.push_back(mem.getPortName(i));
2536 portAnnotations.push_back(mem.getPortAnnotation(i));
2540 if (!resultTypes.empty())
2541 newOp = rewriter.create<MemOp>(
2542 mem.getLoc(), resultTypes, mem.getReadLatency(),
2543 mem.getWriteLatency(), mem.getDepth(), mem.getRuw(),
2544 rewriter.getStrArrayAttr(portNames), mem.getName(), mem.getNameKind(),
2545 mem.getAnnotations(), rewriter.getArrayAttr(portAnnotations),
2546 mem.getInnerSymAttr(), mem.getInitAttr(), mem.getPrefixAttr());
2549 unsigned nextPort = 0;
2550 for (
auto [i, port] : llvm::enumerate(mem.getResults())) {
2554 rewriter.replaceAllUsesWith(port, newOp.getResult(nextPort++));
2557 rewriter.eraseOp(op);
2563 struct FoldReadWritePorts :
public mlir::RewritePattern {
2564 FoldReadWritePorts(MLIRContext *context)
2565 : RewritePattern(MemOp::getOperationName(), 0, context) {}
2566 LogicalResult matchAndRewrite(Operation *op,
2567 PatternRewriter &rewriter)
const override {
2568 MemOp mem = cast<MemOp>(op);
2573 llvm::SmallBitVector deadReads(mem.getNumResults());
2574 for (
auto [i, port] : llvm::enumerate(mem.getResults())) {
2575 if (mem.getPortKind(i) != MemOp::PortKind::ReadWrite)
2577 if (!mem.getPortAnnotation(i).empty())
2584 if (deadReads.none())
2587 SmallVector<Type> resultTypes;
2588 SmallVector<StringRef> portNames;
2589 SmallVector<Attribute> portAnnotations;
2590 for (
auto [i, port] : llvm::enumerate(mem.getResults())) {
2592 resultTypes.push_back(
2593 MemOp::getTypeForPort(mem.getDepth(), mem.getDataType(),
2594 MemOp::PortKind::Write, mem.getMaskBits()));
2596 resultTypes.push_back(port.getType());
2598 portNames.push_back(mem.getPortName(i));
2599 portAnnotations.push_back(mem.getPortAnnotation(i));
2602 auto newOp = rewriter.create<MemOp>(
2603 mem.getLoc(), resultTypes, mem.getReadLatency(), mem.getWriteLatency(),
2604 mem.getDepth(), mem.getRuw(), rewriter.getStrArrayAttr(portNames),
2605 mem.getName(), mem.getNameKind(), mem.getAnnotations(),
2606 rewriter.getArrayAttr(portAnnotations), mem.getInnerSymAttr(),
2607 mem.getInitAttr(), mem.getPrefixAttr());
2609 for (
unsigned i = 0, n = mem.getNumResults(); i < n; ++i) {
2610 auto result = mem.getResult(i);
2611 auto newResult = newOp.getResult(i);
2613 auto resultPortTy = type_cast<BundleType>(result.getType());
2617 auto replace = [&](StringRef toName, StringRef fromName) {
2618 auto fromFieldIndex = resultPortTy.getElementIndex(fromName);
2619 assert(fromFieldIndex &&
"missing enable flag on memory port");
2621 auto toField = rewriter.create<SubfieldOp>(newResult.getLoc(),
2623 for (
auto *op : llvm::make_early_inc_range(result.getUsers())) {
2624 auto fromField = cast<SubfieldOp>(op);
2625 if (fromFieldIndex != fromField.getFieldIndex())
2627 rewriter.replaceOp(fromField, toField.getResult());
2631 replace(
"addr",
"addr");
2632 replace(
"en",
"en");
2633 replace(
"clk",
"clk");
2634 replace(
"data",
"wdata");
2635 replace(
"mask",
"wmask");
2638 auto wmodeFieldIndex = resultPortTy.getElementIndex(
"wmode");
2639 for (
auto *op : llvm::make_early_inc_range(result.getUsers())) {
2640 auto wmodeField = cast<SubfieldOp>(op);
2641 if (wmodeFieldIndex != wmodeField.getFieldIndex())
2643 rewriter.replaceOpWithNewOp<WireOp>(wmodeField, wmodeField.getType());
2646 rewriter.replaceAllUsesWith(result, newResult);
2649 rewriter.eraseOp(op);
2655 struct FoldUnusedBits :
public mlir::RewritePattern {
2656 FoldUnusedBits(MLIRContext *context)
2657 : RewritePattern(MemOp::getOperationName(), 0, context) {}
2659 LogicalResult matchAndRewrite(Operation *op,
2660 PatternRewriter &rewriter)
const override {
2661 MemOp mem = cast<MemOp>(op);
2666 const auto &summary = mem.getSummary();
2667 if (summary.isMasked || summary.isSeqMem())
2670 auto type = type_dyn_cast<IntType>(mem.getDataType());
2673 auto width = type.getBitWidthOrSentinel();
2677 llvm::SmallBitVector usedBits(
width);
2678 DenseMap<unsigned, unsigned> mapping;
2683 SmallVector<BitsPrimOp> readOps;
2684 auto findReadUsers = [&](Value port, StringRef field) {
2685 auto portTy = type_cast<BundleType>(port.getType());
2686 auto fieldIndex = portTy.getElementIndex(field);
2687 assert(fieldIndex &&
"missing data port");
2689 for (
auto *op : port.getUsers()) {
2690 auto portAccess = cast<SubfieldOp>(op);
2691 if (fieldIndex != portAccess.getFieldIndex())
2694 for (
auto *user : op->getUsers()) {
2695 auto bits = dyn_cast<BitsPrimOp>(user);
2701 usedBits.set(bits.getLo(), bits.getHi() + 1);
2702 mapping[bits.getLo()] = 0;
2703 readOps.push_back(bits);
2711 SmallVector<MatchingConnectOp> writeOps;
2712 auto findWriteUsers = [&](Value port, StringRef field) -> LogicalResult {
2713 auto portTy = type_cast<BundleType>(port.getType());
2714 auto fieldIndex = portTy.getElementIndex(field);
2715 assert(fieldIndex &&
"missing data port");
2717 for (
auto *op : port.getUsers()) {
2718 auto portAccess = cast<SubfieldOp>(op);
2719 if (fieldIndex != portAccess.getFieldIndex())
2726 writeOps.push_back(conn);
2732 for (
auto [i, port] : llvm::enumerate(mem.getResults())) {
2734 if (!mem.getPortAnnotation(i).empty())
2737 switch (mem.getPortKind(i)) {
2738 case MemOp::PortKind::Debug:
2741 case MemOp::PortKind::Write:
2742 if (failed(findWriteUsers(port,
"data")))
2745 case MemOp::PortKind::Read:
2746 findReadUsers(port,
"data");
2748 case MemOp::PortKind::ReadWrite:
2749 if (failed(findWriteUsers(port,
"wdata")))
2751 findReadUsers(port,
"rdata");
2754 llvm_unreachable(
"unknown port kind");
2759 if (usedBits.all() || usedBits.none())
2763 SmallVector<std::pair<unsigned, unsigned>> ranges;
2764 unsigned newWidth = 0;
2765 for (
int i = usedBits.find_first(); 0 <= i && i <
width;) {
2766 int e = usedBits.find_next_unset(i);
2769 for (
int idx = i; idx < e; ++idx, ++newWidth) {
2770 if (
auto it = mapping.find(idx); it != mapping.end()) {
2771 it->second = newWidth;
2774 ranges.emplace_back(i, e - 1);
2775 i = e !=
width ? usedBits.find_next(e) : e;
2779 auto newType =
IntType::get(op->getContext(), type.isSigned(), newWidth);
2780 SmallVector<Type> portTypes;
2781 for (
auto [i, port] : llvm::enumerate(mem.getResults())) {
2782 portTypes.push_back(
2783 MemOp::getTypeForPort(mem.getDepth(), newType, mem.getPortKind(i)));
2785 auto newMem = rewriter.replaceOpWithNewOp<MemOp>(
2786 mem, portTypes, mem.getReadLatency(), mem.getWriteLatency(),
2787 mem.getDepth(), mem.getRuw(), mem.getPortNames(), mem.getName(),
2788 mem.getNameKind(), mem.getAnnotations(), mem.getPortAnnotations(),
2789 mem.getInnerSymAttr(), mem.getInitAttr(), mem.getPrefixAttr());
2792 auto rewriteSubfield = [&](Value port, StringRef field) {
2793 auto portTy = type_cast<BundleType>(port.getType());
2794 auto fieldIndex = portTy.getElementIndex(field);
2795 assert(fieldIndex &&
"missing data port");
2797 rewriter.setInsertionPointAfter(newMem);
2798 auto newPortAccess =
2799 rewriter.create<SubfieldOp>(port.getLoc(), port, field);
2801 for (
auto *op : llvm::make_early_inc_range(port.getUsers())) {
2802 auto portAccess = cast<SubfieldOp>(op);
2803 if (op == newPortAccess || fieldIndex != portAccess.getFieldIndex())
2805 rewriter.replaceOp(portAccess, newPortAccess.getResult());
2810 for (
auto [i, port] : llvm::enumerate(newMem.getResults())) {
2811 switch (newMem.getPortKind(i)) {
2812 case MemOp::PortKind::Debug:
2813 llvm_unreachable(
"cannot rewrite debug port");
2814 case MemOp::PortKind::Write:
2815 rewriteSubfield(port,
"data");
2817 case MemOp::PortKind::Read:
2818 rewriteSubfield(port,
"data");
2820 case MemOp::PortKind::ReadWrite:
2821 rewriteSubfield(port,
"rdata");
2822 rewriteSubfield(port,
"wdata");
2825 llvm_unreachable(
"unknown port kind");
2829 for (
auto readOp : readOps) {
2830 rewriter.setInsertionPointAfter(readOp);
2831 auto it = mapping.find(readOp.getLo());
2832 assert(it != mapping.end() &&
"bit op mapping not found");
2833 rewriter.replaceOpWithNewOp<BitsPrimOp>(
2834 readOp, readOp.getInput(),
2835 readOp.getHi() - readOp.getLo() + it->second, it->second);
2839 for (
auto writeOp : writeOps) {
2840 Value source = writeOp.getSrc();
2841 rewriter.setInsertionPoint(writeOp);
2844 for (
auto &[start, end] : ranges) {
2846 rewriter.create<BitsPrimOp>(writeOp.getLoc(), source,
end, start);
2849 rewriter.create<CatPrimOp>(writeOp.getLoc(), slice, catOfSlices);
2851 catOfSlices = slice;
2854 rewriter.replaceOpWithNewOp<MatchingConnectOp>(writeOp, writeOp.getDest(),
2863 struct FoldRegMems :
public mlir::RewritePattern {
2864 FoldRegMems(MLIRContext *context)
2865 : RewritePattern(MemOp::getOperationName(), 0, context) {}
2866 LogicalResult matchAndRewrite(Operation *op,
2867 PatternRewriter &rewriter)
const override {
2868 MemOp mem = cast<MemOp>(op);
2869 const FirMemory &info = mem.getSummary();
2873 auto memModule = mem->getParentOfType<FModuleOp>();
2877 SmallPtrSet<Operation *, 8> connects;
2878 SmallVector<SubfieldOp> portAccesses;
2879 for (
auto [i, port] : llvm::enumerate(mem.getResults())) {
2880 if (!mem.getPortAnnotation(i).empty())
2883 auto collect = [&, port = port](ArrayRef<StringRef> fields) {
2884 auto portTy = type_cast<BundleType>(port.getType());
2885 for (
auto field : fields) {
2886 auto fieldIndex = portTy.getElementIndex(field);
2887 assert(fieldIndex &&
"missing field on memory port");
2889 for (
auto *op : port.getUsers()) {
2890 auto portAccess = cast<SubfieldOp>(op);
2891 if (fieldIndex != portAccess.getFieldIndex())
2893 portAccesses.push_back(portAccess);
2894 for (
auto *user : portAccess->getUsers()) {
2895 auto conn = dyn_cast<FConnectLike>(user);
2898 connects.insert(conn);
2905 switch (mem.getPortKind(i)) {
2906 case MemOp::PortKind::Debug:
2908 case MemOp::PortKind::Read:
2909 if (failed(collect({
"clk",
"en",
"addr"})))
2912 case MemOp::PortKind::Write:
2913 if (failed(collect({
"clk",
"en",
"addr",
"data",
"mask"})))
2916 case MemOp::PortKind::ReadWrite:
2917 if (failed(collect({
"clk",
"en",
"addr",
"wmode",
"wdata",
"wmask"})))
2923 if (!portClock || (clock && portClock != clock))
2929 auto ty = mem.getDataType();
2930 rewriter.setInsertionPointAfterValue(clock);
2931 auto reg = rewriter.create<RegOp>(mem.getLoc(), ty, clock, mem.getName())
2935 auto pipeline = [&](Value value, Value clock,
const Twine &name,
2937 for (
unsigned i = 0; i < latency; ++i) {
2938 std::string regName;
2940 llvm::raw_string_ostream os(regName);
2941 os << mem.getName() <<
"_" << name <<
"_" << i;
2945 .create<RegOp>(mem.getLoc(), value.getType(), clock,
2946 rewriter.getStringAttr(regName))
2948 rewriter.create<MatchingConnectOp>(value.getLoc(),
reg, value);
2959 SmallVector<std::tuple<Value, Value, Value>> writes;
2960 for (
auto [i, port] : llvm::enumerate(mem.getResults())) {
2962 StringRef name = mem.getPortName(i);
2964 auto portPipeline = [&, port = port](StringRef field,
unsigned stages) {
2967 rewriter.setInsertionPointAfterValue(value);
2968 return pipeline(value, portClock, name +
"_" + field, stages);
2971 switch (mem.getPortKind(i)) {
2972 case MemOp::PortKind::Debug:
2973 llvm_unreachable(
"unknown port kind");
2974 case MemOp::PortKind::Read: {
2979 rewriter.setInsertionPointAfterValue(
reg);
2983 case MemOp::PortKind::Write: {
2984 auto data = portPipeline(
"data", writeStages);
2985 auto en = portPipeline(
"en", writeStages);
2986 auto mask = portPipeline(
"mask", writeStages);
2987 writes.emplace_back(data, en, mask);
2990 case MemOp::PortKind::ReadWrite: {
2992 rewriter.setInsertionPointAfterValue(
reg);
2996 auto wdata = portPipeline(
"wdata", writeStages);
2997 auto wmask = portPipeline(
"wmask", writeStages);
3001 rewriter.setInsertionPointToEnd(memModule.getBodyBlock());
3003 auto wen = rewriter.create<AndPrimOp>(port.getLoc(),
en,
wmode);
3005 pipeline(wen, portClock, name +
"_wen", writeStages);
3006 writes.emplace_back(wdata, wenPipelined, wmask);
3013 rewriter.setInsertionPointToEnd(memModule.getBodyBlock());
3015 for (
auto &[data, en, mask] : writes) {
3020 Location loc = mem.getLoc();
3022 for (
unsigned i = 0; i < info.
maskBits; ++i) {
3023 unsigned hi = (i + 1) * maskGran - 1;
3024 unsigned lo = i * maskGran;
3026 auto dataPart = rewriter.createOrFold<BitsPrimOp>(loc,
data, hi, lo);
3027 auto nextPart = rewriter.createOrFold<BitsPrimOp>(loc, next, hi, lo);
3028 auto bit = rewriter.createOrFold<BitsPrimOp>(loc,
mask, i, i);
3029 auto chunk = rewriter.create<MuxPrimOp>(loc, bit, dataPart, nextPart);
3032 masked = rewriter.create<CatPrimOp>(loc, chunk, masked);
3038 next = rewriter.create<MuxPrimOp>(next.getLoc(),
en, masked, next);
3040 rewriter.create<MatchingConnectOp>(
reg.getLoc(),
reg, next);
3043 for (Operation *conn : connects)
3044 rewriter.eraseOp(conn);
3045 for (
auto portAccess : portAccesses)
3046 rewriter.eraseOp(portAccess);
3047 rewriter.eraseOp(mem);
3054 void MemOp::getCanonicalizationPatterns(RewritePatternSet &results,
3055 MLIRContext *context) {
3057 .insert<FoldZeroWidthMemory, FoldReadOrWriteOnlyMemory,
3058 FoldReadWritePorts, FoldUnusedPorts, FoldUnusedBits, FoldRegMems>(
3078 auto mux = dyn_cast_or_null<MuxPrimOp>(con.getSrc().getDefiningOp());
3081 auto *high = mux.getHigh().getDefiningOp();
3082 auto *low = mux.getLow().getDefiningOp();
3084 auto constOp = dyn_cast_or_null<ConstantOp>(high);
3091 bool constReg =
false;
3093 if (constOp && low ==
reg)
3095 else if (dyn_cast_or_null<ConstantOp>(low) && high ==
reg) {
3097 constOp = dyn_cast<ConstantOp>(low);
3104 if (!isa<BlockArgument>(mux.getSel()) && !constReg)
3108 auto regTy =
reg.getResult().getType();
3109 if (con.getDest().getType() != regTy || con.getSrc().getType() != regTy ||
3110 mux.getHigh().getType() != regTy || mux.getLow().getType() != regTy ||
3111 regTy.getBitWidthOrSentinel() < 0)
3117 if (constOp != &con->getBlock()->front())
3118 constOp->moveBefore(&con->getBlock()->front());
3121 SmallVector<NamedAttribute, 2> attrs(
reg->getDialectAttrs());
3122 auto newReg = replaceOpWithNewOpAndCopyName<RegResetOp>(
3123 rewriter,
reg,
reg.getResult().getType(),
reg.getClockVal(),
3124 mux.getSel(), mux.getHigh(),
reg.getNameAttr(),
reg.getNameKindAttr(),
3125 reg.getAnnotationsAttr(),
reg.getInnerSymAttr(),
3126 reg.getForceableAttr());
3127 newReg->setDialectAttrs(attrs);
3129 auto pt = rewriter.saveInsertionPoint();
3130 rewriter.setInsertionPoint(con);
3131 auto v = constReg ? (Value)constOp.getResult() : (Value)mux.getLow();
3132 replaceOpWithNewOpAndCopyName<ConnectOp>(rewriter, con, con.getDest(), v);
3133 rewriter.restoreInsertionPoint(pt);
3138 if (!
hasDontTouch(op.getOperation()) && !op.isForceable() &&
3154 PatternRewriter &rewriter,
3157 if (
auto constant = enable.getDefiningOp<firrtl::ConstantOp>()) {
3158 if (constant.getValue().isZero()) {
3159 rewriter.eraseOp(op);
3165 if (
auto constant = predicate.getDefiningOp<firrtl::ConstantOp>()) {
3166 if (constant.getValue().isZero() == eraseIfZero) {
3167 rewriter.eraseOp(op);
3175 template <
class Op,
bool EraseIfZero = false>
3177 PatternRewriter &rewriter) {
3182 void AssertOp::getCanonicalizationPatterns(RewritePatternSet &results,
3183 MLIRContext *context) {
3184 results.add(canonicalizeImmediateVerifOp<AssertOp>);
3187 void AssumeOp::getCanonicalizationPatterns(RewritePatternSet &results,
3188 MLIRContext *context) {
3189 results.add(canonicalizeImmediateVerifOp<AssumeOp>);
3192 void UnclockedAssumeIntrinsicOp::getCanonicalizationPatterns(
3193 RewritePatternSet &results, MLIRContext *context) {
3194 results.add(canonicalizeImmediateVerifOp<UnclockedAssumeIntrinsicOp>);
3197 void CoverOp::getCanonicalizationPatterns(RewritePatternSet &results,
3198 MLIRContext *context) {
3199 results.add(canonicalizeImmediateVerifOp<CoverOp, /* EraseIfZero = */ true>);
3207 PatternRewriter &rewriter) {
3209 if (op.use_empty()) {
3210 rewriter.eraseOp(op);
3217 if (op->hasOneUse() &&
3218 (isa<BitsPrimOp, HeadPrimOp, ShrPrimOp, TailPrimOp, SubfieldOp,
3219 SubindexOp, AsSIntPrimOp, AsUIntPrimOp, NotPrimOp, BitCastOp>(
3220 *op->user_begin()) ||
3221 (isa<CvtPrimOp>(*op->user_begin()) &&
3222 type_isa<SIntType>(op->user_begin()->getOperand(0).getType())) ||
3223 (isa<AndRPrimOp, XorRPrimOp, OrRPrimOp>(*op->user_begin()) &&
3224 type_cast<FIRRTLBaseType>(op->user_begin()->getOperand(0).getType())
3225 .getBitWidthOrSentinel() > 0))) {
3226 auto *modop = *op->user_begin();
3227 auto inv = rewriter.create<InvalidValueOp>(op.getLoc(),
3228 modop->getResult(0).getType());
3229 rewriter.replaceAllOpUsesWith(modop, inv);
3230 rewriter.eraseOp(modop);
3231 rewriter.eraseOp(op);
3237 OpFoldResult InvalidValueOp::fold(FoldAdaptor adaptor) {
3238 if (getType().getBitWidthOrSentinel() == 0 && isa<IntType>(getType()))
3239 return getIntAttr(getType(), APInt(0, 0, isa<SIntType>(getType())));
3247 OpFoldResult ClockGateIntrinsicOp::fold(FoldAdaptor adaptor) {
3266 PatternRewriter &rewriter) {
3268 if (
auto testEnable = op.getTestEnable()) {
3269 if (
auto constOp = testEnable.getDefiningOp<ConstantOp>()) {
3270 if (constOp.getValue().isZero()) {
3271 rewriter.modifyOpInPlace(op,
3272 [&] { op.getTestEnableMutable().clear(); });
3286 static LogicalResult
3288 auto forceable = op.getRef().getDefiningOp<Forceable>();
3289 if (!forceable || !forceable.isForceable() ||
3290 op.getRef() != forceable.getDataRef() ||
3291 op.getType() != forceable.getDataType())
3293 rewriter.replaceAllUsesWith(op, forceable.getData());
3297 void RefResolveOp::getCanonicalizationPatterns(RewritePatternSet &results,
3298 MLIRContext *context) {
3299 results.insert<patterns::RefResolveOfRefSend>(context);
3303 OpFoldResult RefCastOp::fold(FoldAdaptor adaptor) {
3305 if (getInput().getType() == getType())
3311 auto constOp = operand.getDefiningOp<ConstantOp>();
3312 return constOp && constOp.getValue().isZero();
3315 template <
typename Op>
3318 rewriter.eraseOp(op);
3324 void RefForceOp::getCanonicalizationPatterns(RewritePatternSet &results,
3325 MLIRContext *context) {
3326 results.add(eraseIfPredFalse<RefForceOp>);
3328 void RefForceInitialOp::getCanonicalizationPatterns(RewritePatternSet &results,
3329 MLIRContext *context) {
3330 results.add(eraseIfPredFalse<RefForceInitialOp>);
3332 void RefReleaseOp::getCanonicalizationPatterns(RewritePatternSet &results,
3333 MLIRContext *context) {
3334 results.add(eraseIfPredFalse<RefReleaseOp>);
3336 void RefReleaseInitialOp::getCanonicalizationPatterns(
3337 RewritePatternSet &results, MLIRContext *context) {
3338 results.add(eraseIfPredFalse<RefReleaseInitialOp>);
3345 OpFoldResult HasBeenResetIntrinsicOp::fold(FoldAdaptor adaptor) {
3351 if (adaptor.getReset())
3356 if (
isUInt1(getReset().getType()) && adaptor.getClock())
3369 [&](
auto ty) ->
bool {
return isTypeEmpty(ty.getElementType()); })
3370 .Case<BundleType>([&](
auto ty) ->
bool {
3371 for (
auto elem : ty.getElements())
3376 .Case<IntType>([&](
auto ty) {
return ty.getWidth() == 0; })
3377 .Default([](
auto) ->
bool {
return false; });
3381 PatternRewriter &rewriter) {
3382 auto firrtlTy = type_dyn_cast<FIRRTLType>(op.getInput().getType());
3389 rewriter.eraseOp(op);
assert(baseType &&"element must be base type")
static bool hasKnownWidthIntTypes(Operation *op)
Return true if this operation's operands and results all have a known width.
static std::optional< APSInt > getConstant(Attribute operand)
Determine the value of a constant operand for the sake of constant folding.
static LogicalResult canonicalizeImmediateVerifOp(Op op, PatternRewriter &rewriter)
static bool isDefinedByOneConstantOp(Value v)
static Attribute collectFields(MLIRContext *context, ArrayRef< Attribute > operands)
static LogicalResult canonicalizeSingleSetConnect(MatchingConnectOp op, PatternRewriter &rewriter)
static void erasePort(PatternRewriter &rewriter, Value port)
static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op, Region ®ion)
Replaces the given op with the contents of the given single-block region.
static Value getPortFieldValue(Value port, StringRef name)
static AttachOp getDominatingAttachUser(Value value, AttachOp dominatedAttach)
If the specified value has an AttachOp user strictly dominating by "dominatingAttach" then return it.
static OpTy replaceOpWithNewOpAndCopyName(PatternRewriter &rewriter, Operation *op, Args &&...args)
A wrapper of PatternRewriter::replaceOpWithNewOp to propagate "name" attribute.
static void updateName(PatternRewriter &rewriter, Operation *op, StringAttr name)
Set the name of an op based on the best of two names: The current name, and the name passed in.
static bool isTypeEmpty(FIRRTLType type)
static bool isUInt1(Type type)
Return true if this value is 1 bit UInt.
static LogicalResult demoteForceableIfUnused(OpTy op, PatternRewriter &rewriter)
static bool isPortDisabled(Value port)
static LogicalResult eraseIfZeroOrNotZero(Operation *op, Value predicate, Value enable, PatternRewriter &rewriter, bool eraseIfZero)
static APInt getMaxSignedValue(unsigned bitWidth)
Get the largest signed value of a given bit width.
static Value dropWrite(PatternRewriter &rewriter, OpResult old, Value passthrough)
static LogicalResult canonicalizePrimOp(Operation *op, PatternRewriter &rewriter, const function_ref< OpFoldResult(ArrayRef< Attribute >)> &canonicalize)
Applies the canonicalization function canonicalize to the given operation.
static void replaceWithBits(Operation *op, Value value, unsigned hiBit, unsigned loBit, PatternRewriter &rewriter)
Replace the specified operation with a 'bits' op from the specified hi/lo bits.
static LogicalResult canonicalizeRegResetWithOneReset(RegResetOp reg, PatternRewriter &rewriter)
static LogicalResult eraseIfPredFalse(Op op, PatternRewriter &rewriter)
static OpFoldResult foldMux(OpTy op, typename OpTy::FoldAdaptor adaptor)
static APInt getMaxUnsignedValue(unsigned bitWidth)
Get the largest unsigned value of a given bit width.
static void replacePortField(PatternRewriter &rewriter, Value port, StringRef name, Value value)
BinOpKind
This is the policy for folding, which depends on the sort of operator we're processing.
static bool isPortUnused(Value port, StringRef data)
static LogicalResult canonicalizeRefResolveOfForceable(RefResolveOp op, PatternRewriter &rewriter)
static Attribute constFoldFIRRTLBinaryOp(Operation *op, ArrayRef< Attribute > operands, BinOpKind opKind, const function_ref< APInt(const APSInt &, const APSInt &)> &calculate)
Applies the constant folding function calculate to the given operands.
static APInt getMinSignedValue(unsigned bitWidth)
Get the smallest signed value of a given bit width.
static LogicalResult foldHiddenReset(RegOp reg, PatternRewriter &rewriter)
static std::optional< APSInt > getExtendedConstant(Value operand, Attribute constant, int32_t destWidth)
Implicitly replace the operand to a constant folding operation with a const 0 in case the operand is ...
static Value moveNameHint(OpResult old, Value passthrough)
static void replaceOpAndCopyName(PatternRewriter &rewriter, Operation *op, Value newValue)
A wrapper of PatternRewriter::replaceOp to propagate "name" attribute.
This class provides a read-only projection over the MLIR attributes that represent a set of annotatio...
int32_t getBitWidthOrSentinel()
If this is an IntType, AnalogType, or sugar type for a single bit (Clock, Reset, etc) then return the...
This class implements the same functionality as TypeSwitch except that it uses firrtl::type_dyn_cast ...
FIRRTLTypeSwitch< T, ResultT > & Case(CallableT &&caseFn)
Add a case on the given type.
This is the common base class between SIntType and UIntType.
int32_t getWidthOrSentinel() const
Return the width of this type, or -1 if it has none specified.
static IntType get(MLIRContext *context, bool isSigned, int32_t widthOrSentinel=-1, bool isConst=false)
Return an SIntType or UIntType with the specified signedness, width, and constness.
std::optional< int32_t > getWidth() const
Return an optional containing the width, if the width is known (or empty if width is unknown).
bool hasWidth() const
Return true if this integer type has a known width.
static LogicalResult canonicalize(Op op, PatternRewriter &rewriter)
def connect(destination, source)
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
uint64_t getWidth(Type t)
Forceable replaceWithNewForceability(Forceable op, bool forceable, ::mlir::PatternRewriter *rewriter=nullptr)
Replace a Forceable op with equivalent, changing whether forceable.
bool areAnonymousTypesEquivalent(FIRRTLBaseType lhs, FIRRTLBaseType rhs)
Return true if anonymous types of given arguments are equivalent by pointer comparison.
IntegerAttr getIntAttr(Type type, const APInt &value)
Utiility for generating a constant attribute.
bool hasDontTouch(Value value)
Check whether a block argument ("port") or the operation defining a value has a DontTouch annotation,...
bool hasDroppableName(Operation *op)
Return true if the name is droppable.
MatchingConnectOp getSingleConnectUserOf(Value value)
Scan all the uses of the specified value, checking to see if there is exactly one connect that has th...
std::optional< int64_t > getBitWidth(FIRRTLBaseType type, bool ignoreFlip=false)
IntegerAttr getIntZerosAttr(Type type)
Utility for generating a constant zero attribute.
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
APSInt extOrTruncZeroWidth(APSInt value, unsigned width)
A safe version of APSInt::extOrTrunc that will NOT assert on zero-width signed APSInts.
APInt sextZeroWidth(APInt value, unsigned width)
A safe version of APInt::sext that will NOT assert on zero-width signed APSInts.
StringRef chooseName(StringRef a, StringRef b)
Choose a good name for an item from two options.
static bool isConstantZero(Attribute operand)
Determine whether a constant operand is a zero value.
static bool isConstantOne(Attribute operand)
Determine whether a constant operand is a one value.
def reg(value, clock, reset=None, reset_value=None, name=None, sym_name=None)