20 #include "mlir/IR/Matchers.h"
21 #include "mlir/IR/PatternMatch.h"
22 #include "llvm/ADT/APSInt.h"
23 #include "llvm/ADT/SmallPtrSet.h"
24 #include "llvm/ADT/StringExtras.h"
25 #include "llvm/ADT/TypeSwitch.h"
27 using namespace circt;
28 using namespace firrtl;
32 static Value
dropWrite(PatternRewriter &rewriter, OpResult old,
34 SmallPtrSet<Operation *, 8> users;
35 for (
auto *user : old.getUsers())
37 for (Operation *user : users)
38 if (
auto connect = dyn_cast<FConnectLike>(user))
40 rewriter.eraseOp(user);
49 Operation *op = passthrough.getDefiningOp();
52 assert(op &&
"passthrough must be an operation");
53 Operation *oldOp = old.getOwner();
54 auto name = oldOp->getAttrOfType<StringAttr>(
"name");
55 if (name && !name.getValue().empty())
56 op->setAttr(
"name", name);
64 #include "circt/Dialect/FIRRTL/FIRRTLCanonicalization.h.inc"
72 auto resultType = type_cast<IntType>(op->getResult(0).getType());
73 if (!resultType.hasWidth())
75 for (Value operand : op->getOperands())
76 if (!type_cast<IntType>(operand.getType()).hasWidth())
83 auto t = type_dyn_cast<UIntType>(type);
84 if (!t || !t.hasWidth() || t.getWidth() != 1)
91 static void updateName(PatternRewriter &rewriter, Operation *op,
94 assert(!isa<InstanceOp>(op));
95 if (!name || name.getValue().empty())
97 auto newName = name.getValue();
98 auto newOpName = op->getAttrOfType<StringAttr>(
"name");
101 newName =
chooseName(newOpName.getValue(), name.getValue());
103 if (!newOpName || newOpName.getValue() != newName)
104 rewriter.modifyOpInPlace(
105 op, [&] { op->setAttr(
"name", rewriter.getStringAttr(newName)); });
113 if (
auto *newOp = newValue.getDefiningOp()) {
114 auto name = op->getAttrOfType<StringAttr>(
"name");
117 rewriter.replaceOp(op, newValue);
123 template <
typename OpTy,
typename... Args>
125 Operation *op, Args &&...args) {
126 auto name = op->getAttrOfType<StringAttr>(
"name");
128 rewriter.replaceOpWithNewOp<OpTy>(op, std::forward<Args>(args)...);
136 if (
auto namableOp = dyn_cast<firrtl::FNamableOp>(op))
137 return namableOp.hasDroppableName();
148 static std::optional<APSInt>
150 assert(type_cast<IntType>(operand.getType()) &&
151 "getExtendedConstant is limited to integer types");
158 if (IntegerAttr result = dyn_cast_or_null<IntegerAttr>(constant))
163 if (type_cast<IntType>(operand.getType()).getWidth() == 0)
164 return APSInt(destWidth,
165 type_cast<IntType>(operand.getType()).isUnsigned());
173 if (
auto attr = dyn_cast<BoolAttr>(operand))
174 return APSInt(APInt(1, attr.getValue()));
175 if (
auto attr = dyn_cast<IntegerAttr>(operand))
176 return attr.getAPSInt();
184 return cst->isZero();
211 Operation *op, ArrayRef<Attribute> operands,
BinOpKind opKind,
212 const function_ref<APInt(
const APSInt &,
const APSInt &)> &calculate) {
213 assert(operands.size() == 2 &&
"binary op takes two operands");
216 auto resultType = type_cast<IntType>(op->getResult(0).getType());
217 if (resultType.getWidthOrSentinel() < 0)
221 if (resultType.getWidthOrSentinel() == 0)
222 return getIntAttr(resultType, APInt(0, 0, resultType.isSigned()));
228 type_cast<IntType>(op->getOperand(0).getType()).getWidthOrSentinel();
230 type_cast<IntType>(op->getOperand(1).getType()).getWidthOrSentinel();
231 if (
auto lhs = dyn_cast_or_null<IntegerAttr>(operands[0]))
232 lhsWidth = std::max<int32_t>(lhsWidth, lhs.getValue().getBitWidth());
233 if (
auto rhs = dyn_cast_or_null<IntegerAttr>(operands[1]))
234 rhsWidth = std::max<int32_t>(rhsWidth, rhs.getValue().getBitWidth());
238 int32_t operandWidth;
241 operandWidth = resultType.getWidthOrSentinel();
246 operandWidth = std::max(1, std::max(lhsWidth, rhsWidth));
250 std::max(std::max(lhsWidth, rhsWidth), resultType.getWidthOrSentinel());
261 APInt resultValue = calculate(*lhs, *rhs);
266 resultValue = resultValue.trunc(resultType.getWidthOrSentinel());
268 assert((
unsigned)resultType.getWidthOrSentinel() ==
269 resultValue.getBitWidth());
282 Operation *op, PatternRewriter &rewriter,
283 const function_ref<OpFoldResult(ArrayRef<Attribute>)> &
canonicalize) {
285 if (op->getNumResults() != 1)
287 auto type = type_dyn_cast<FIRRTLBaseType>(op->getResult(0).getType());
292 auto width = type.getBitWidthOrSentinel();
297 SmallVector<Attribute, 3> constOperands;
298 constOperands.reserve(op->getNumOperands());
299 for (
auto operand : op->getOperands()) {
301 if (
auto *defOp = operand.getDefiningOp())
302 TypeSwitch<Operation *>(defOp).Case<ConstantOp, SpecialConstantOp>(
303 [&](
auto op) { attr = op.getValueAttr(); });
304 constOperands.push_back(attr);
313 if (
auto cst = dyn_cast<Attribute>(result))
314 resultValue = op->getDialect()
315 ->materializeConstant(rewriter, cst, type, op->getLoc())
318 resultValue = result.get<Value>();
322 type_cast<FIRRTLBaseType>(resultValue.getType()).getBitWidthOrSentinel())
323 resultValue = rewriter.create<PadPrimOp>(op->getLoc(), resultValue,
width);
326 if (type_isa<SIntType>(type) && type_isa<UIntType>(resultValue.getType()))
327 resultValue = rewriter.create<AsSIntPrimOp>(op->getLoc(), resultValue);
328 else if (type_isa<UIntType>(type) &&
329 type_isa<SIntType>(resultValue.getType()))
330 resultValue = rewriter.create<AsUIntPrimOp>(op->getLoc(), resultValue);
332 assert(type == resultValue.getType() &&
"canonicalization changed type");
340 return bitWidth > 0 ? APInt::getMaxValue(bitWidth) : APInt();
346 return bitWidth > 0 ? APInt::getSignedMinValue(bitWidth) : APInt();
352 return bitWidth > 0 ? APInt::getSignedMaxValue(bitWidth) : APInt();
359 OpFoldResult ConstantOp::fold(FoldAdaptor adaptor) {
360 assert(adaptor.getOperands().empty() &&
"constant has no operands");
361 return getValueAttr();
364 OpFoldResult SpecialConstantOp::fold(FoldAdaptor adaptor) {
365 assert(adaptor.getOperands().empty() &&
"constant has no operands");
366 return getValueAttr();
369 OpFoldResult AggregateConstantOp::fold(FoldAdaptor adaptor) {
370 assert(adaptor.getOperands().empty() &&
"constant has no operands");
371 return getFieldsAttr();
374 OpFoldResult StringConstantOp::fold(FoldAdaptor adaptor) {
375 assert(adaptor.getOperands().empty() &&
"constant has no operands");
376 return getValueAttr();
379 OpFoldResult FIntegerConstantOp::fold(FoldAdaptor adaptor) {
380 assert(adaptor.getOperands().empty() &&
"constant has no operands");
381 return getValueAttr();
384 OpFoldResult BoolConstantOp::fold(FoldAdaptor adaptor) {
385 assert(adaptor.getOperands().empty() &&
"constant has no operands");
386 return getValueAttr();
389 OpFoldResult DoubleConstantOp::fold(FoldAdaptor adaptor) {
390 assert(adaptor.getOperands().empty() &&
"constant has no operands");
391 return getValueAttr();
398 OpFoldResult AddPrimOp::fold(FoldAdaptor adaptor) {
401 [=](
const APSInt &a,
const APSInt &b) { return a + b; });
404 void AddPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
405 MLIRContext *context) {
406 results.insert<patterns::moveConstAdd, patterns::AddOfZero,
407 patterns::AddOfSelf, patterns::AddOfPad>(context);
410 OpFoldResult SubPrimOp::fold(FoldAdaptor adaptor) {
413 [=](
const APSInt &a,
const APSInt &b) { return a - b; });
416 void SubPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
417 MLIRContext *context) {
418 results.insert<patterns::SubOfZero, patterns::SubFromZeroSigned,
419 patterns::SubFromZeroUnsigned, patterns::SubOfSelf,
420 patterns::SubOfPadL, patterns::SubOfPadR>(context);
423 OpFoldResult MulPrimOp::fold(FoldAdaptor adaptor) {
435 [=](
const APSInt &a,
const APSInt &b) { return a * b; });
438 OpFoldResult DivPrimOp::fold(FoldAdaptor adaptor) {
445 if (getLhs() == getRhs()) {
446 auto width = getType().base().getWidthOrSentinel();
468 if (
auto rhsCst = dyn_cast_or_null<IntegerAttr>(adaptor.getRhs()))
469 if (rhsCst.getValue().isOne() && getLhs().getType() == getType())
474 [=](
const APSInt &a,
const APSInt &b) -> APInt {
477 return APInt(a.getBitWidth(), 0);
481 OpFoldResult RemPrimOp::fold(FoldAdaptor adaptor) {
488 if (getLhs() == getRhs())
502 [=](
const APSInt &a,
const APSInt &b) -> APInt {
505 return APInt(a.getBitWidth(), 0);
509 OpFoldResult DShlPrimOp::fold(FoldAdaptor adaptor) {
512 [=](
const APSInt &a,
const APSInt &b) -> APInt { return a.shl(b); });
515 OpFoldResult DShlwPrimOp::fold(FoldAdaptor adaptor) {
518 [=](
const APSInt &a,
const APSInt &b) -> APInt { return a.shl(b); });
521 OpFoldResult DShrPrimOp::fold(FoldAdaptor adaptor) {
524 [=](
const APSInt &a,
const APSInt &b) -> APInt {
525 return getType().base().isUnsigned() || !a.getBitWidth() ? a.lshr(b)
531 OpFoldResult AndPrimOp::fold(FoldAdaptor adaptor) {
534 if (rhsCst->isZero())
538 if (rhsCst->isAllOnes() && getLhs().getType() == getType() &&
539 getRhs().getType() == getType())
545 if (lhsCst->isZero())
549 if (lhsCst->isAllOnes() && getLhs().getType() == getType() &&
550 getRhs().getType() == getType())
555 if (getLhs() == getRhs() && getRhs().getType() == getType())
560 [](
const APSInt &a,
const APSInt &b) -> APInt { return a & b; });
563 void AndPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
564 MLIRContext *context) {
566 .insert<patterns::extendAnd, patterns::moveConstAnd, patterns::AndOfZero,
567 patterns::AndOfAllOne, patterns::AndOfSelf, patterns::AndOfPad,
568 patterns::AndOfAsSIntL, patterns::AndOfAsSIntR>(context);
571 OpFoldResult OrPrimOp::fold(FoldAdaptor adaptor) {
574 if (rhsCst->isZero() && getLhs().getType() == getType())
578 if (rhsCst->isAllOnes() && getRhs().getType() == getType() &&
579 getLhs().getType() == getType())
585 if (lhsCst->isZero() && getRhs().getType() == getType())
589 if (lhsCst->isAllOnes() && getLhs().getType() == getType() &&
590 getRhs().getType() == getType())
595 if (getLhs() == getRhs() && getRhs().getType() == getType())
600 [](
const APSInt &a,
const APSInt &b) -> APInt { return a | b; });
603 void OrPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
604 MLIRContext *context) {
605 results.insert<patterns::extendOr, patterns::moveConstOr, patterns::OrOfZero,
606 patterns::OrOfAllOne, patterns::OrOfSelf, patterns::OrOfPad,
607 patterns::OrOrr>(context);
610 OpFoldResult XorPrimOp::fold(FoldAdaptor adaptor) {
613 if (rhsCst->isZero() &&
619 if (lhsCst->isZero() &&
624 if (getLhs() == getRhs())
627 APInt(std::max(getType().base().getWidthOrSentinel(), 0), 0));
631 [](
const APSInt &a,
const APSInt &b) -> APInt { return a ^ b; });
634 void XorPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
635 MLIRContext *context) {
636 results.insert<patterns::extendXor, patterns::moveConstXor,
637 patterns::XorOfZero, patterns::XorOfSelf, patterns::XorOfPad>(
641 void LEQPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
642 MLIRContext *context) {
643 results.insert<patterns::LEQWithConstLHS>(context);
646 OpFoldResult LEQPrimOp::fold(FoldAdaptor adaptor) {
647 bool isUnsigned = getLhs().getType().base().isUnsigned();
650 if (getLhs() == getRhs())
656 auto commonWidth = std::max<int32_t>(*
width, rhsCst->getBitWidth());
657 commonWidth = std::max(commonWidth, 1);
668 if (isUnsigned && rhsCst->zext(commonWidth)
681 [=](
const APSInt &a,
const APSInt &b) -> APInt {
682 return APInt(1, a <= b);
686 void LTPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
687 MLIRContext *context) {
688 results.insert<patterns::LTWithConstLHS>(context);
691 OpFoldResult LTPrimOp::fold(FoldAdaptor adaptor) {
692 IntType lhsType = getLhs().getType();
696 if (getLhs() == getRhs())
708 auto commonWidth = std::max<int32_t>(*
width, rhsCst->getBitWidth());
709 commonWidth = std::max(commonWidth, 1);
720 if (isUnsigned && rhsCst->zext(commonWidth)
733 [=](
const APSInt &a,
const APSInt &b) -> APInt {
734 return APInt(1, a < b);
738 void GEQPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
739 MLIRContext *context) {
740 results.insert<patterns::GEQWithConstLHS>(context);
743 OpFoldResult GEQPrimOp::fold(FoldAdaptor adaptor) {
744 IntType lhsType = getLhs().getType();
748 if (getLhs() == getRhs())
753 if (rhsCst->isZero() && isUnsigned)
760 auto commonWidth = std::max<int32_t>(*
width, rhsCst->getBitWidth());
761 commonWidth = std::max(commonWidth, 1);
764 if (isUnsigned && rhsCst->zext(commonWidth)
785 [=](
const APSInt &a,
const APSInt &b) -> APInt {
786 return APInt(1, a >= b);
790 void GTPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
791 MLIRContext *context) {
792 results.insert<patterns::GTWithConstLHS>(context);
795 OpFoldResult GTPrimOp::fold(FoldAdaptor adaptor) {
796 IntType lhsType = getLhs().getType();
800 if (getLhs() == getRhs())
806 auto commonWidth = std::max<int32_t>(*
width, rhsCst->getBitWidth());
807 commonWidth = std::max(commonWidth, 1);
810 if (isUnsigned && rhsCst->zext(commonWidth)
831 [=](
const APSInt &a,
const APSInt &b) -> APInt {
832 return APInt(1, a > b);
836 OpFoldResult EQPrimOp::fold(FoldAdaptor adaptor) {
838 if (getLhs() == getRhs())
844 if (rhsCst->isAllOnes() && getLhs().getType() == getType() &&
845 getRhs().getType() == getType())
851 [=](
const APSInt &a,
const APSInt &b) -> APInt {
852 return APInt(1, a == b);
858 op, rewriter, [&](ArrayRef<Attribute> operands) -> OpFoldResult {
860 auto width = op.getLhs().getType().getBitWidthOrSentinel();
863 if (rhsCst->isZero() && op.getLhs().getType() == op.getType() &&
864 op.getRhs().getType() == op.getType()) {
865 return rewriter.create<NotPrimOp>(op.getLoc(), op.getLhs())
870 if (rhsCst->isZero() &&
width > 1) {
871 auto orrOp = rewriter.create<OrRPrimOp>(op.getLoc(), op.getLhs());
872 return rewriter.create<NotPrimOp>(op.getLoc(), orrOp).getResult();
876 if (rhsCst->isAllOnes() &&
width > 1 &&
877 op.getLhs().getType() == op.getRhs().getType()) {
878 return rewriter.create<AndRPrimOp>(op.getLoc(), op.getLhs())
886 OpFoldResult NEQPrimOp::fold(FoldAdaptor adaptor) {
888 if (getLhs() == getRhs())
894 if (rhsCst->isZero() && getLhs().getType() == getType() &&
895 getRhs().getType() == getType())
901 [=](
const APSInt &a,
const APSInt &b) -> APInt {
902 return APInt(1, a != b);
908 op, rewriter, [&](ArrayRef<Attribute> operands) -> OpFoldResult {
910 auto width = op.getLhs().getType().getBitWidthOrSentinel();
913 if (rhsCst->isAllOnes() && op.getLhs().getType() == op.getType() &&
914 op.getRhs().getType() == op.getType()) {
915 return rewriter.create<NotPrimOp>(op.getLoc(), op.getLhs())
920 if (rhsCst->isZero() &&
width > 1) {
921 return rewriter.create<OrRPrimOp>(op.getLoc(), op.getLhs())
926 if (rhsCst->isAllOnes() &&
width > 1 &&
927 op.getLhs().getType() == op.getRhs().getType()) {
928 auto andrOp = rewriter.create<AndRPrimOp>(op.getLoc(), op.getLhs());
929 return rewriter.create<NotPrimOp>(op.getLoc(), andrOp).getResult();
937 OpFoldResult IntegerAddOp::fold(FoldAdaptor adaptor) {
943 OpFoldResult IntegerMulOp::fold(FoldAdaptor adaptor) {
949 OpFoldResult IntegerShrOp::fold(FoldAdaptor adaptor) {
959 OpFoldResult SizeOfIntrinsicOp::fold(FoldAdaptor) {
960 auto base = getInput().getType();
967 OpFoldResult IsXIntrinsicOp::fold(FoldAdaptor adaptor) {
974 OpFoldResult AsSIntPrimOp::fold(FoldAdaptor adaptor) {
982 if (getType().base().hasWidth())
989 void AsSIntPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
990 MLIRContext *context) {
991 results.insert<patterns::StoUtoS>(context);
994 OpFoldResult AsUIntPrimOp::fold(FoldAdaptor adaptor) {
1002 if (getType().base().hasWidth())
1009 void AsUIntPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1010 MLIRContext *context) {
1011 results.insert<patterns::UtoStoU>(context);
1014 OpFoldResult AsAsyncResetPrimOp::fold(FoldAdaptor adaptor) {
1016 if (getInput().getType() == getType())
1026 OpFoldResult AsClockPrimOp::fold(FoldAdaptor adaptor) {
1028 if (getInput().getType() == getType())
1038 OpFoldResult CvtPrimOp::fold(FoldAdaptor adaptor) {
1044 getType().base().getWidthOrSentinel()))
1050 void CvtPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1051 MLIRContext *context) {
1052 results.insert<patterns::CVTSigned, patterns::CVTUnSigned>(context);
1055 OpFoldResult NegPrimOp::fold(FoldAdaptor adaptor) {
1062 getType().base().getWidthOrSentinel()))
1063 return getIntAttr(getType(), APInt((*cst).getBitWidth(), 0) - *cst);
1068 OpFoldResult NotPrimOp::fold(FoldAdaptor adaptor) {
1073 getType().base().getWidthOrSentinel()))
1079 void NotPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1080 MLIRContext *context) {
1081 results.insert<patterns::NotNot, patterns::NotEq, patterns::NotNeq,
1082 patterns::NotLeq, patterns::NotLt, patterns::NotGeq,
1083 patterns::NotGt>(context);
1086 OpFoldResult AndRPrimOp::fold(FoldAdaptor adaptor) {
1090 if (getInput().getType().getBitWidthOrSentinel() == 0)
1095 return getIntAttr(getType(), APInt(1, cst->isAllOnes()));
1099 if (
isUInt1(getInput().getType()))
1105 void AndRPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1106 MLIRContext *context) {
1108 .insert<patterns::AndRasSInt, patterns::AndRasUInt, patterns::AndRPadU,
1109 patterns::AndRPadS, patterns::AndRCatOneL, patterns::AndRCatOneR,
1110 patterns::AndRCatZeroL, patterns::AndRCatZeroR,
1111 patterns::AndRCatAndR_left, patterns::AndRCatAndR_right>(context);
1114 OpFoldResult OrRPrimOp::fold(FoldAdaptor adaptor) {
1118 if (getInput().getType().getBitWidthOrSentinel() == 0)
1123 return getIntAttr(getType(), APInt(1, !cst->isZero()));
1127 if (
isUInt1(getInput().getType()))
1133 void OrRPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1134 MLIRContext *context) {
1135 results.insert<patterns::OrRasSInt, patterns::OrRasUInt, patterns::OrRPadU,
1136 patterns::OrRCatZeroH, patterns::OrRCatZeroL,
1137 patterns::OrRCatOrR_left, patterns::OrRCatOrR_right>(context);
1140 OpFoldResult XorRPrimOp::fold(FoldAdaptor adaptor) {
1144 if (getInput().getType().getBitWidthOrSentinel() == 0)
1149 return getIntAttr(getType(), APInt(1, cst->popcount() & 1));
1152 if (
isUInt1(getInput().getType()))
1158 void XorRPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1159 MLIRContext *context) {
1160 results.insert<patterns::XorRasSInt, patterns::XorRasUInt, patterns::XorRPadU,
1161 patterns::XorRCatZeroH, patterns::XorRCatZeroL,
1162 patterns::XorRCatXorR_left, patterns::XorRCatXorR_right>(
1170 OpFoldResult CatPrimOp::fold(FoldAdaptor adaptor) {
1174 IntType lhsType = getLhs().getType();
1175 IntType rhsType = getRhs().getType();
1187 return getIntAttr(getType(), lhs->concat(*rhs));
1192 void DShlPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1193 MLIRContext *context) {
1194 results.insert<patterns::DShlOfConstant>(context);
1197 void DShrPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1198 MLIRContext *context) {
1199 results.insert<patterns::DShrOfConstant>(context);
1202 void CatPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1203 MLIRContext *context) {
1204 results.insert<patterns::CatBitsBits, patterns::CatDoubleConst,
1205 patterns::CatCast>(context);
1208 OpFoldResult BitCastOp::fold(FoldAdaptor adaptor) {
1211 if (op.getType() == op.getInput().getType())
1212 return op.getInput();
1216 if (BitCastOp in = dyn_cast_or_null<BitCastOp>(op.getInput().getDefiningOp()))
1217 if (op.getType() == in.getInput().getType())
1218 return in.getInput();
1223 OpFoldResult BitsPrimOp::fold(FoldAdaptor adaptor) {
1224 IntType inputType = getInput().getType();
1225 IntType resultType = getType();
1227 if (inputType == getType() && resultType.
hasWidth())
1234 cst->extractBits(getHi() - getLo() + 1, getLo()));
1239 void BitsPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1240 MLIRContext *context) {
1242 .insert<patterns::BitsOfBits, patterns::BitsOfMux, patterns::BitsOfAsUInt,
1243 patterns::BitsOfAnd, patterns::BitsOfPad>(context);
1250 unsigned loBit, PatternRewriter &rewriter) {
1251 auto resType = type_cast<IntType>(op->getResult(0).getType());
1252 if (type_cast<IntType>(value.getType()).getWidth() != resType.getWidth())
1253 value = rewriter.create<BitsPrimOp>(op->getLoc(), value, hiBit, loBit);
1255 if (resType.isSigned() && !type_cast<IntType>(value.getType()).isSigned()) {
1256 value = rewriter.createOrFold<AsSIntPrimOp>(op->getLoc(), resType, value);
1257 }
else if (resType.isUnsigned() &&
1258 !type_cast<IntType>(value.getType()).isUnsigned()) {
1259 value = rewriter.createOrFold<AsUIntPrimOp>(op->getLoc(), resType, value);
1261 rewriter.replaceOp(op, value);
1264 template <
typename OpTy>
1265 static OpFoldResult
foldMux(OpTy op,
typename OpTy::FoldAdaptor adaptor) {
1267 if (op.getType().getBitWidthOrSentinel() == 0)
1269 APInt(0, 0, op.getType().isSignedInteger()));
1272 if (op.getHigh() == op.getLow())
1273 return op.getHigh();
1278 if (op.getType().getBitWidthOrSentinel() < 0)
1283 if (cond->isZero() && op.getLow().getType() == op.getType())
1285 if (!cond->isZero() && op.getHigh().getType() == op.getType())
1286 return op.getHigh();
1290 if (
auto lowCst =
getConstant(adaptor.getLow())) {
1292 if (
auto highCst =
getConstant(adaptor.getHigh())) {
1294 if (highCst->getBitWidth() == lowCst->getBitWidth() &&
1295 *highCst == *lowCst)
1298 if (highCst->isOne() && lowCst->isZero() &&
1299 op.getType() == op.getSel().getType())
1312 OpFoldResult MuxPrimOp::fold(FoldAdaptor adaptor) {
1313 return foldMux(*
this, adaptor);
1316 OpFoldResult Mux2CellIntrinsicOp::fold(FoldAdaptor adaptor) {
1317 return foldMux(*
this, adaptor);
1320 OpFoldResult Mux4CellIntrinsicOp::fold(FoldAdaptor adaptor) {
return {}; }
1327 class MuxPad :
public mlir::RewritePattern {
1329 MuxPad(MLIRContext *context)
1330 : RewritePattern(MuxPrimOp::getOperationName(), 0, context) {}
1333 matchAndRewrite(Operation *op,
1334 mlir::PatternRewriter &rewriter)
const override {
1335 auto mux = cast<MuxPrimOp>(op);
1336 auto width = mux.getType().getBitWidthOrSentinel();
1340 auto pad = [&](Value input) -> Value {
1342 type_cast<FIRRTLBaseType>(input.getType()).getBitWidthOrSentinel();
1343 if (inputWidth < 0 ||
width == inputWidth)
1346 .create<PadPrimOp>(mux.getLoc(), mux.getType(), input,
width)
1350 auto newHigh = pad(mux.getHigh());
1351 auto newLow = pad(mux.getLow());
1352 if (newHigh == mux.getHigh() && newLow == mux.getLow())
1355 replaceOpWithNewOpAndCopyName<MuxPrimOp>(
1356 rewriter, op, mux.getType(), ValueRange{mux.getSel(), newHigh, newLow},
1364 class MuxSharedCond :
public mlir::RewritePattern {
1366 MuxSharedCond(MLIRContext *context)
1367 : RewritePattern(MuxPrimOp::getOperationName(), 0, context) {}
1369 static const int depthLimit = 5;
1371 Value updateOrClone(MuxPrimOp mux, Value high, Value low,
1372 mlir::PatternRewriter &rewriter,
1373 bool updateInPlace)
const {
1374 if (updateInPlace) {
1375 rewriter.modifyOpInPlace(mux, [&] {
1376 mux.setOperand(1, high);
1377 mux.setOperand(2, low);
1381 rewriter.setInsertionPointAfter(mux);
1383 .create<MuxPrimOp>(mux.getLoc(), mux.getType(),
1384 ValueRange{mux.getSel(), high, low})
1389 Value tryCondTrue(Value op, Value cond, mlir::PatternRewriter &rewriter,
1390 bool updateInPlace,
int limit)
const {
1391 MuxPrimOp mux = op.getDefiningOp<MuxPrimOp>();
1394 if (mux.getSel() == cond)
1395 return mux.getHigh();
1396 if (limit > depthLimit)
1398 updateInPlace &= mux->hasOneUse();
1400 if (Value v = tryCondTrue(mux.getHigh(), cond, rewriter, updateInPlace,
1402 return updateOrClone(mux, v, mux.getLow(), rewriter, updateInPlace);
1405 tryCondTrue(mux.getLow(), cond, rewriter, updateInPlace, limit + 1))
1406 return updateOrClone(mux, mux.getHigh(), v, rewriter, updateInPlace);
1411 Value tryCondFalse(Value op, Value cond, mlir::PatternRewriter &rewriter,
1412 bool updateInPlace,
int limit)
const {
1413 MuxPrimOp mux = op.getDefiningOp<MuxPrimOp>();
1416 if (mux.getSel() == cond)
1417 return mux.getLow();
1418 if (limit > depthLimit)
1420 updateInPlace &= mux->hasOneUse();
1422 if (Value v = tryCondFalse(mux.getHigh(), cond, rewriter, updateInPlace,
1424 return updateOrClone(mux, v, mux.getLow(), rewriter, updateInPlace);
1426 if (Value v = tryCondFalse(mux.getLow(), cond, rewriter, updateInPlace,
1428 return updateOrClone(mux, mux.getHigh(), v, rewriter, updateInPlace);
1434 matchAndRewrite(Operation *op,
1435 mlir::PatternRewriter &rewriter)
const override {
1436 auto mux = cast<MuxPrimOp>(op);
1437 auto width = mux.getType().getBitWidthOrSentinel();
1441 if (Value v = tryCondTrue(mux.getHigh(), mux.getSel(), rewriter,
true, 0)) {
1442 rewriter.modifyOpInPlace(mux, [&] { mux.setOperand(1, v); });
1446 if (Value v = tryCondFalse(mux.getLow(), mux.getSel(), rewriter,
true, 0)) {
1447 rewriter.modifyOpInPlace(mux, [&] { mux.setOperand(2, v); });
1456 void MuxPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1457 MLIRContext *context) {
1459 .add<MuxPad, MuxSharedCond, patterns::MuxEQOperands,
1460 patterns::MuxEQOperandsSwapped, patterns::MuxNEQ, patterns::MuxNot,
1461 patterns::MuxSameTrue, patterns::MuxSameFalse,
1462 patterns::NarrowMuxLHS, patterns::NarrowMuxRHS, patterns::MuxPadSel>(
1466 void Mux2CellIntrinsicOp::getCanonicalizationPatterns(
1467 RewritePatternSet &results, MLIRContext *context) {
1468 results.add<patterns::Mux2PadSel>(context);
1471 void Mux4CellIntrinsicOp::getCanonicalizationPatterns(
1472 RewritePatternSet &results, MLIRContext *context) {
1473 results.add<patterns::Mux4PadSel>(context);
1476 OpFoldResult PadPrimOp::fold(FoldAdaptor adaptor) {
1477 auto input = this->getInput();
1480 if (input.getType() == getType())
1484 auto inputType = input.getType().base();
1491 auto destWidth = getType().base().getWidthOrSentinel();
1492 if (destWidth == -1)
1495 if (inputType.
isSigned() && cst->getBitWidth())
1496 return getIntAttr(getType(), cst->sext(destWidth));
1497 return getIntAttr(getType(), cst->zext(destWidth));
1503 OpFoldResult ShlPrimOp::fold(FoldAdaptor adaptor) {
1504 auto input = this->getInput();
1505 IntType inputType = input.getType();
1506 int shiftAmount = getAmount();
1509 if (shiftAmount == 0)
1515 if (inputWidth != -1) {
1516 auto resultWidth = inputWidth + shiftAmount;
1517 shiftAmount = std::min(shiftAmount, resultWidth);
1518 return getIntAttr(getType(), cst->zext(resultWidth).shl(shiftAmount));
1524 OpFoldResult ShrPrimOp::fold(FoldAdaptor adaptor) {
1525 auto input = this->getInput();
1526 IntType inputType = input.getType();
1527 int shiftAmount = getAmount();
1533 if (shiftAmount == 0 && inputWidth > 0)
1536 if (inputWidth == -1)
1538 if (inputWidth == 0)
1543 if (shiftAmount >= inputWidth && inputType.
isUnsigned())
1544 return getIntAttr(getType(), APInt(0, 0,
false));
1550 value = cst->ashr(std::min(shiftAmount, inputWidth - 1));
1552 value = cst->lshr(std::min(shiftAmount, inputWidth));
1553 auto resultWidth = std::max(inputWidth - shiftAmount, 1);
1554 return getIntAttr(getType(), value.trunc(resultWidth));
1560 auto inputWidth = op.getInput().getType().base().getWidthOrSentinel();
1561 if (inputWidth <= 0)
1565 unsigned shiftAmount = op.getAmount();
1566 if (
int(shiftAmount) >= inputWidth) {
1568 if (op.getType().base().isUnsigned())
1574 shiftAmount = inputWidth - 1;
1577 replaceWithBits(op, op.getInput(), inputWidth - 1, shiftAmount, rewriter);
1582 PatternRewriter &rewriter) {
1583 auto inputWidth = op.getInput().getType().base().getWidthOrSentinel();
1584 if (inputWidth <= 0)
1588 unsigned keepAmount = op.getAmount();
1590 replaceWithBits(op, op.getInput(), inputWidth - 1, inputWidth - keepAmount,
1595 OpFoldResult HeadPrimOp::fold(FoldAdaptor adaptor) {
1599 getInput().getType().base().getWidthOrSentinel() - getAmount();
1600 return getIntAttr(getType(), cst->lshr(shiftAmount).trunc(getAmount()));
1606 OpFoldResult TailPrimOp::fold(FoldAdaptor adaptor) {
1610 cst->trunc(getType().base().getWidthOrSentinel()));
1615 PatternRewriter &rewriter) {
1616 auto inputWidth = op.getInput().getType().base().getWidthOrSentinel();
1617 if (inputWidth <= 0)
1621 unsigned dropAmount = op.getAmount();
1622 if (dropAmount !=
unsigned(inputWidth))
1628 void SubaccessOp::getCanonicalizationPatterns(RewritePatternSet &results,
1629 MLIRContext *context) {
1630 results.add<patterns::SubaccessOfConstant>(context);
1633 OpFoldResult MultibitMuxOp::fold(FoldAdaptor adaptor) {
1635 if (adaptor.getInputs().size() == 1)
1636 return getOperand(1);
1638 if (
auto constIndex =
getConstant(adaptor.getIndex())) {
1639 auto index = constIndex->getZExtValue();
1640 if (index < getInputs().size())
1641 return getInputs()[getInputs().size() - 1 - index];
1648 PatternRewriter &rewriter) {
1652 if (llvm::all_of(op.getInputs().drop_front(), [&](
auto input) {
1653 return input == op.getInputs().front();
1661 if (
auto lastSubindex = op.getInputs().back().getDefiningOp<SubindexOp>()) {
1662 if (llvm::all_of(llvm::enumerate(op.getInputs()), [&](
auto e) {
1663 auto subindex = e.value().template getDefiningOp<SubindexOp>();
1664 return subindex && lastSubindex.getInput() == subindex.getInput() &&
1665 subindex.getIndex() + e.index() + 1 == op.getInputs().size();
1667 replaceOpWithNewOpAndCopyName<SubaccessOp>(
1668 rewriter, op, lastSubindex.getInput(), op.getIndex());
1674 if (op.getInputs().size() != 2)
1678 auto uintType = op.getIndex().getType();
1679 if (uintType.getBitWidthOrSentinel() != 1)
1683 replaceOpWithNewOpAndCopyName<MuxPrimOp>(
1684 rewriter, op, op.getIndex(), op.getInputs()[0], op.getInputs()[1]);
1704 for (Operation *user : value.getUsers()) {
1706 if (isa<AttachOp, SubfieldOp, SubaccessOp, SubindexOp>(user))
1709 if (
auto aConnect = dyn_cast<FConnectLike>(user))
1710 if (aConnect.getDest() == value) {
1711 auto matchingConnect = dyn_cast<MatchingConnectOp>(*aConnect);
1714 if (!matchingConnect || (
connect &&
connect != matchingConnect) ||
1715 matchingConnect->getBlock() != value.getParentBlock())
1726 PatternRewriter &rewriter) {
1729 Operation *connectedDecl = op.getDest().getDefiningOp();
1734 if (!isa<WireOp>(connectedDecl) && !isa<RegOp>(connectedDecl))
1739 cast<Forceable>(connectedDecl).isForceable())
1747 if (connectedDecl->hasOneUse())
1751 auto *declBlock = connectedDecl->getBlock();
1752 auto *srcValueOp = op.getSrc().getDefiningOp();
1755 if (!isa<WireOp>(connectedDecl))
1761 if (!isa<ConstantOp>(srcValueOp))
1763 if (srcValueOp->getBlock() != declBlock)
1769 auto replacement = op.getSrc();
1772 if (srcValueOp && srcValueOp != &declBlock->front())
1773 srcValueOp->moveBefore(&declBlock->front());
1780 rewriter.eraseOp(op);
1784 void ConnectOp::getCanonicalizationPatterns(RewritePatternSet &results,
1785 MLIRContext *context) {
1786 results.insert<patterns::ConnectExtension, patterns::ConnectSameType>(
1791 PatternRewriter &rewriter) {
1808 for (
auto *user : value.getUsers()) {
1809 auto attach = dyn_cast<AttachOp>(user);
1810 if (!attach || attach == dominatedAttach)
1812 if (attach->isBeforeInBlock(dominatedAttach))
1820 if (op.getNumOperands() <= 1) {
1821 rewriter.eraseOp(op);
1825 for (
auto operand : op.getOperands()) {
1832 SmallVector<Value> newOperands(op.getOperands());
1833 for (
auto newOperand : attach.getOperands())
1834 if (newOperand != operand)
1835 newOperands.push_back(newOperand);
1836 rewriter.create<AttachOp>(op->getLoc(), newOperands);
1837 rewriter.eraseOp(attach);
1838 rewriter.eraseOp(op);
1846 if (
auto wire = dyn_cast_or_null<WireOp>(operand.getDefiningOp())) {
1847 if (!
hasDontTouch(wire.getOperation()) && wire->hasOneUse() &&
1848 !wire.isForceable()) {
1849 SmallVector<Value> newOperands;
1850 for (
auto newOperand : op.getOperands())
1851 if (newOperand != operand)
1852 newOperands.push_back(newOperand);
1854 rewriter.create<AttachOp>(op->getLoc(), newOperands);
1855 rewriter.eraseOp(op);
1856 rewriter.eraseOp(wire);
1867 assert(llvm::hasSingleElement(region) &&
"expected single-region block");
1868 rewriter.inlineBlockBefore(®ion.front(), op, {});
1872 if (
auto constant = op.getCondition().getDefiningOp<firrtl::ConstantOp>()) {
1873 if (constant.getValue().isAllOnes())
1875 else if (op.hasElseRegion() && !op.getElseRegion().empty())
1878 rewriter.eraseOp(op);
1884 if (!op.getThenBlock().empty() && op.hasElseRegion() &&
1885 op.getElseBlock().empty()) {
1886 rewriter.eraseBlock(&op.getElseBlock());
1893 if (!op.getThenBlock().empty())
1897 if (!op.hasElseRegion() || op.getElseBlock().empty()) {
1898 rewriter.eraseOp(op);
1907 struct FoldNodeName :
public mlir::RewritePattern {
1908 FoldNodeName(MLIRContext *context)
1909 : RewritePattern(NodeOp::getOperationName(), 0, context) {}
1910 LogicalResult matchAndRewrite(Operation *op,
1911 PatternRewriter &rewriter)
const override {
1912 auto node = cast<NodeOp>(op);
1913 auto name = node.getNameAttr();
1914 if (!node.hasDroppableName() || node.getInnerSym() ||
1917 auto *newOp = node.getInput().getDefiningOp();
1919 if (newOp && !isa<InstanceOp>(newOp))
1921 rewriter.replaceOp(node, node.getInput());
1927 struct NodeBypass :
public mlir::RewritePattern {
1928 NodeBypass(MLIRContext *context)
1929 : RewritePattern(NodeOp::getOperationName(), 0, context) {}
1930 LogicalResult matchAndRewrite(Operation *op,
1931 PatternRewriter &rewriter)
const override {
1932 auto node = cast<NodeOp>(op);
1933 if (node.getInnerSym() || !
AnnotationSet(node).canBeDeleted() ||
1934 node.use_empty() || node.isForceable())
1936 rewriter.replaceAllUsesWith(node.getResult(), node.getInput());
1943 template <
typename OpTy>
1945 PatternRewriter &rewriter) {
1946 if (!op.isForceable() || !op.getDataRef().use_empty())
1954 LogicalResult NodeOp::fold(FoldAdaptor adaptor,
1955 SmallVectorImpl<OpFoldResult> &results) {
1960 if (getAnnotationsAttr() &&
1965 if (!adaptor.getInput())
1968 results.push_back(adaptor.getInput());
1972 void NodeOp::getCanonicalizationPatterns(RewritePatternSet &results,
1973 MLIRContext *context) {
1974 results.insert<FoldNodeName>(context);
1975 results.add(demoteForceableIfUnused<NodeOp>);
1981 struct AggOneShot :
public mlir::RewritePattern {
1982 AggOneShot(StringRef name, uint32_t weight, MLIRContext *context)
1983 : RewritePattern(name, 0, context) {}
1985 SmallVector<Value> getCompleteWrite(Operation *lhs)
const {
1986 auto lhsTy = lhs->getResult(0).getType();
1987 if (!type_isa<BundleType, FVectorType>(lhsTy))
1990 DenseMap<uint32_t, Value> fields;
1991 for (Operation *user : lhs->getResult(0).getUsers()) {
1992 if (user->getParentOp() != lhs->getParentOp())
1994 if (
auto aConnect = dyn_cast<MatchingConnectOp>(user)) {
1995 if (aConnect.getDest() == lhs->getResult(0))
1997 }
else if (
auto subField = dyn_cast<SubfieldOp>(user)) {
1998 for (Operation *subuser : subField.getResult().getUsers()) {
1999 if (
auto aConnect = dyn_cast<MatchingConnectOp>(subuser)) {
2000 if (aConnect.getDest() == subField) {
2001 if (subuser->getParentOp() != lhs->getParentOp())
2003 if (fields.count(subField.getFieldIndex()))
2005 fields[subField.getFieldIndex()] = aConnect.getSrc();
2011 }
else if (
auto subIndex = dyn_cast<SubindexOp>(user)) {
2012 for (Operation *subuser : subIndex.getResult().getUsers()) {
2013 if (
auto aConnect = dyn_cast<MatchingConnectOp>(subuser)) {
2014 if (aConnect.getDest() == subIndex) {
2015 if (subuser->getParentOp() != lhs->getParentOp())
2017 if (fields.count(subIndex.getIndex()))
2019 fields[subIndex.getIndex()] = aConnect.getSrc();
2030 SmallVector<Value> values;
2031 uint32_t total = type_isa<BundleType>(lhsTy)
2032 ? type_cast<BundleType>(lhsTy).getNumElements()
2033 : type_cast<FVectorType>(lhsTy).getNumElements();
2034 for (uint32_t i = 0; i < total; ++i) {
2035 if (!fields.count(i))
2037 values.push_back(fields[i]);
2042 LogicalResult matchAndRewrite(Operation *op,
2043 PatternRewriter &rewriter)
const override {
2044 auto values = getCompleteWrite(op);
2047 rewriter.setInsertionPointToEnd(op->getBlock());
2048 auto dest = op->getResult(0);
2049 auto destType = dest.getType();
2052 if (!type_cast<FIRRTLBaseType>(destType).isPassive())
2055 Value newVal = type_isa<BundleType>(destType)
2056 ? rewriter.createOrFold<BundleCreateOp>(op->getLoc(),
2058 : rewriter.createOrFold<VectorCreateOp>(
2059 op->getLoc(), destType, values);
2060 rewriter.createOrFold<MatchingConnectOp>(op->getLoc(), dest, newVal);
2061 for (Operation *user : dest.getUsers()) {
2062 if (
auto subIndex = dyn_cast<SubindexOp>(user)) {
2063 for (Operation *subuser :
2064 llvm::make_early_inc_range(subIndex.getResult().getUsers()))
2065 if (
auto aConnect = dyn_cast<MatchingConnectOp>(subuser))
2066 if (aConnect.getDest() == subIndex)
2067 rewriter.eraseOp(aConnect);
2068 }
else if (
auto subField = dyn_cast<SubfieldOp>(user)) {
2069 for (Operation *subuser :
2070 llvm::make_early_inc_range(subField.getResult().getUsers()))
2071 if (
auto aConnect = dyn_cast<MatchingConnectOp>(subuser))
2072 if (aConnect.getDest() == subField)
2073 rewriter.eraseOp(aConnect);
2080 struct WireAggOneShot :
public AggOneShot {
2081 WireAggOneShot(MLIRContext *context)
2082 : AggOneShot(WireOp::getOperationName(), 0, context) {}
2084 struct SubindexAggOneShot :
public AggOneShot {
2085 SubindexAggOneShot(MLIRContext *context)
2086 : AggOneShot(SubindexOp::getOperationName(), 0, context) {}
2088 struct SubfieldAggOneShot :
public AggOneShot {
2089 SubfieldAggOneShot(MLIRContext *context)
2090 : AggOneShot(SubfieldOp::getOperationName(), 0, context) {}
2094 void WireOp::getCanonicalizationPatterns(RewritePatternSet &results,
2095 MLIRContext *context) {
2096 results.insert<WireAggOneShot>(context);
2097 results.add(demoteForceableIfUnused<WireOp>);
2100 void SubindexOp::getCanonicalizationPatterns(RewritePatternSet &results,
2101 MLIRContext *context) {
2102 results.insert<SubindexAggOneShot>(context);
2105 OpFoldResult SubindexOp::fold(FoldAdaptor adaptor) {
2106 auto attr = dyn_cast_or_null<ArrayAttr>(adaptor.getInput());
2109 return attr[getIndex()];
2112 OpFoldResult SubfieldOp::fold(FoldAdaptor adaptor) {
2113 auto attr = dyn_cast_or_null<ArrayAttr>(adaptor.getInput());
2116 auto index = getFieldIndex();
2120 void SubfieldOp::getCanonicalizationPatterns(RewritePatternSet &results,
2121 MLIRContext *context) {
2122 results.insert<SubfieldAggOneShot>(context);
2126 ArrayRef<Attribute> operands) {
2127 for (
auto operand : operands)
2133 OpFoldResult BundleCreateOp::fold(FoldAdaptor adaptor) {
2136 if (getNumOperands() > 0)
2137 if (SubfieldOp first = getOperand(0).getDefiningOp<SubfieldOp>())
2138 if (first.getFieldIndex() == 0 &&
2139 first.getInput().getType() == getType() &&
2141 llvm::drop_begin(llvm::enumerate(getOperands())), [&](
auto elem) {
2143 elem.value().
template getDefiningOp<SubfieldOp>();
2144 return subindex && subindex.getInput() == first.getInput() &&
2145 subindex.getFieldIndex() == elem.index();
2147 return first.getInput();
2152 OpFoldResult VectorCreateOp::fold(FoldAdaptor adaptor) {
2155 if (getNumOperands() > 0)
2156 if (SubindexOp first = getOperand(0).getDefiningOp<SubindexOp>())
2157 if (first.getIndex() == 0 && first.getInput().getType() == getType() &&
2159 llvm::drop_begin(llvm::enumerate(getOperands())), [&](
auto elem) {
2161 elem.value().
template getDefiningOp<SubindexOp>();
2162 return subindex && subindex.getInput() == first.getInput() &&
2163 subindex.getIndex() == elem.index();
2165 return first.getInput();
2170 OpFoldResult UninferredResetCastOp::fold(FoldAdaptor adaptor) {
2171 if (getOperand().getType() == getType())
2172 return getOperand();
2179 struct FoldResetMux :
public mlir::RewritePattern {
2180 FoldResetMux(MLIRContext *context)
2181 : RewritePattern(RegResetOp::getOperationName(), 0, context) {}
2182 LogicalResult matchAndRewrite(Operation *op,
2183 PatternRewriter &rewriter)
const override {
2184 auto reg = cast<RegResetOp>(op);
2186 dyn_cast_or_null<ConstantOp>(
reg.getResetValue().getDefiningOp());
2195 auto mux = dyn_cast_or_null<MuxPrimOp>(con.getSrc().getDefiningOp());
2198 auto *high = mux.getHigh().getDefiningOp();
2199 auto *low = mux.getLow().getDefiningOp();
2200 auto constOp = dyn_cast_or_null<ConstantOp>(high);
2202 if (constOp && low !=
reg)
2204 if (dyn_cast_or_null<ConstantOp>(low) && high ==
reg)
2205 constOp = dyn_cast<ConstantOp>(low);
2207 if (!constOp || constOp.getType() != reset.getType() ||
2208 constOp.getValue() != reset.getValue())
2212 auto regTy =
reg.getResult().getType();
2213 if (con.getDest().getType() != regTy || con.getSrc().getType() != regTy ||
2214 mux.getHigh().getType() != regTy || mux.getLow().getType() != regTy ||
2215 regTy.getBitWidthOrSentinel() < 0)
2221 if (constOp != &con->getBlock()->front())
2222 constOp->moveBefore(&con->getBlock()->front());
2227 rewriter.eraseOp(con);
2234 if (
auto c = v.getDefiningOp<ConstantOp>())
2235 return c.getValue().isOne();
2236 if (
auto sc = v.getDefiningOp<SpecialConstantOp>())
2237 return sc.getValue();
2241 static LogicalResult
2248 replaceOpWithNewOpAndCopyName<NodeOp>(
2249 rewriter,
reg,
reg.getResetValue(),
reg.getNameAttr(),
reg.getNameKind(),
2250 reg.getAnnotationsAttr(),
reg.getInnerSymAttr(),
reg.getForceable());
2254 void RegResetOp::getCanonicalizationPatterns(RewritePatternSet &results,
2255 MLIRContext *context) {
2256 results.add<patterns::RegResetWithZeroReset, FoldResetMux>(context);
2258 results.add(demoteForceableIfUnused<RegResetOp>);
2263 auto portTy = type_cast<BundleType>(port.getType());
2264 auto fieldIndex = portTy.getElementIndex(name);
2265 assert(fieldIndex &&
"missing field on memory port");
2268 for (
auto *op : port.getUsers()) {
2269 auto portAccess = cast<SubfieldOp>(op);
2270 if (fieldIndex != portAccess.getFieldIndex())
2275 value = conn.getSrc();
2285 auto portConst = value.getDefiningOp<ConstantOp>();
2288 return portConst.getValue().isZero();
2293 auto portTy = type_cast<BundleType>(port.getType());
2294 auto fieldIndex = portTy.getElementIndex(
data);
2295 assert(fieldIndex &&
"missing enable flag on memory port");
2297 for (
auto *op : port.getUsers()) {
2298 auto portAccess = cast<SubfieldOp>(op);
2299 if (fieldIndex != portAccess.getFieldIndex())
2301 if (!portAccess.use_empty())
2310 StringRef name, Value value) {
2311 auto portTy = type_cast<BundleType>(port.getType());
2312 auto fieldIndex = portTy.getElementIndex(name);
2313 assert(fieldIndex &&
"missing field on memory port");
2315 for (
auto *op : llvm::make_early_inc_range(port.getUsers())) {
2316 auto portAccess = cast<SubfieldOp>(op);
2317 if (fieldIndex != portAccess.getFieldIndex())
2319 rewriter.replaceAllUsesWith(portAccess, value);
2320 rewriter.eraseOp(portAccess);
2325 static void erasePort(PatternRewriter &rewriter, Value port) {
2328 auto getClock = [&] {
2330 clock = rewriter.create<SpecialConstantOp>(
2339 for (
auto *op : port.getUsers()) {
2340 auto subfield = dyn_cast<SubfieldOp>(op);
2342 auto ty = port.getType();
2343 auto reg = rewriter.create<RegOp>(port.getLoc(), ty, getClock());
2344 rewriter.replaceAllUsesWith(port,
reg.getResult());
2353 for (
auto *accessOp : llvm::make_early_inc_range(port.getUsers())) {
2354 auto access = cast<SubfieldOp>(accessOp);
2355 for (
auto *user : llvm::make_early_inc_range(access->getUsers())) {
2356 auto connect = dyn_cast<FConnectLike>(user);
2358 rewriter.eraseOp(user);
2362 if (access.use_empty()) {
2363 rewriter.eraseOp(access);
2369 auto ty = access.getType();
2370 auto reg = rewriter.create<RegOp>(access.getLoc(), ty, getClock());
2371 rewriter.replaceOp(access,
reg.getResult());
2373 assert(port.use_empty() &&
"port should have no remaining uses");
2378 struct FoldZeroWidthMemory :
public mlir::RewritePattern {
2379 FoldZeroWidthMemory(MLIRContext *context)
2380 : RewritePattern(MemOp::getOperationName(), 0, context) {}
2381 LogicalResult matchAndRewrite(Operation *op,
2382 PatternRewriter &rewriter)
const override {
2383 MemOp mem = cast<MemOp>(op);
2387 if (!firrtl::type_isa<IntType>(mem.getDataType()) ||
2388 mem.getDataType().getBitWidthOrSentinel() != 0)
2392 for (
auto port : mem.getResults())
2393 for (
auto *user : port.getUsers())
2394 if (!isa<SubfieldOp>(user))
2399 for (
auto port : op->getResults()) {
2400 for (
auto *user : llvm::make_early_inc_range(port.getUsers())) {
2401 SubfieldOp sfop = cast<SubfieldOp>(user);
2402 StringRef fieldName = sfop.getFieldName();
2403 auto wire = replaceOpWithNewOpAndCopyName<WireOp>(
2404 rewriter, sfop, sfop.getResult().getType())
2406 if (fieldName.ends_with(
"data")) {
2408 auto zero = rewriter.create<firrtl::ConstantOp>(
2409 wire.getLoc(), firrtl::type_cast<IntType>(wire.getType()),
2411 rewriter.create<MatchingConnectOp>(wire.getLoc(), wire, zero);
2415 rewriter.eraseOp(op);
2421 struct FoldReadOrWriteOnlyMemory :
public mlir::RewritePattern {
2422 FoldReadOrWriteOnlyMemory(MLIRContext *context)
2423 : RewritePattern(MemOp::getOperationName(), 0, context) {}
2424 LogicalResult matchAndRewrite(Operation *op,
2425 PatternRewriter &rewriter)
const override {
2426 MemOp mem = cast<MemOp>(op);
2429 bool isRead =
false, isWritten =
false;
2430 for (
unsigned i = 0; i < mem.getNumResults(); ++i) {
2431 switch (mem.getPortKind(i)) {
2432 case MemOp::PortKind::Read:
2437 case MemOp::PortKind::Write:
2442 case MemOp::PortKind::Debug:
2443 case MemOp::PortKind::ReadWrite:
2446 llvm_unreachable(
"unknown port kind");
2448 assert((!isWritten || !isRead) &&
"memory is in use");
2453 if (isRead && mem.getInit())
2456 for (
auto port : mem.getResults())
2459 rewriter.eraseOp(op);
2465 struct FoldUnusedPorts :
public mlir::RewritePattern {
2466 FoldUnusedPorts(MLIRContext *context)
2467 : RewritePattern(MemOp::getOperationName(), 0, context) {}
2468 LogicalResult matchAndRewrite(Operation *op,
2469 PatternRewriter &rewriter)
const override {
2470 MemOp mem = cast<MemOp>(op);
2474 llvm::SmallBitVector deadPorts(mem.getNumResults());
2475 for (
auto [i, port] : llvm::enumerate(mem.getResults())) {
2477 if (!mem.getPortAnnotation(i).empty())
2481 auto kind = mem.getPortKind(i);
2482 if (kind == MemOp::PortKind::Debug)
2491 if (kind == MemOp::PortKind::Read &&
isPortUnused(port,
"data")) {
2496 if (deadPorts.none())
2500 SmallVector<Type> resultTypes;
2501 SmallVector<StringRef> portNames;
2502 SmallVector<Attribute> portAnnotations;
2503 for (
auto [i, port] : llvm::enumerate(mem.getResults())) {
2506 resultTypes.push_back(port.getType());
2507 portNames.push_back(mem.getPortName(i));
2508 portAnnotations.push_back(mem.getPortAnnotation(i));
2512 if (!resultTypes.empty())
2513 newOp = rewriter.create<MemOp>(
2514 mem.getLoc(), resultTypes, mem.getReadLatency(),
2515 mem.getWriteLatency(), mem.getDepth(), mem.getRuw(),
2516 rewriter.getStrArrayAttr(portNames), mem.getName(), mem.getNameKind(),
2517 mem.getAnnotations(), rewriter.getArrayAttr(portAnnotations),
2518 mem.getInnerSymAttr(), mem.getInitAttr(), mem.getPrefixAttr());
2521 unsigned nextPort = 0;
2522 for (
auto [i, port] : llvm::enumerate(mem.getResults())) {
2526 rewriter.replaceAllUsesWith(port, newOp.getResult(nextPort++));
2529 rewriter.eraseOp(op);
2535 struct FoldReadWritePorts :
public mlir::RewritePattern {
2536 FoldReadWritePorts(MLIRContext *context)
2537 : RewritePattern(MemOp::getOperationName(), 0, context) {}
2538 LogicalResult matchAndRewrite(Operation *op,
2539 PatternRewriter &rewriter)
const override {
2540 MemOp mem = cast<MemOp>(op);
2545 llvm::SmallBitVector deadReads(mem.getNumResults());
2546 for (
auto [i, port] : llvm::enumerate(mem.getResults())) {
2547 if (mem.getPortKind(i) != MemOp::PortKind::ReadWrite)
2549 if (!mem.getPortAnnotation(i).empty())
2556 if (deadReads.none())
2559 SmallVector<Type> resultTypes;
2560 SmallVector<StringRef> portNames;
2561 SmallVector<Attribute> portAnnotations;
2562 for (
auto [i, port] : llvm::enumerate(mem.getResults())) {
2564 resultTypes.push_back(
2565 MemOp::getTypeForPort(mem.getDepth(), mem.getDataType(),
2566 MemOp::PortKind::Write, mem.getMaskBits()));
2568 resultTypes.push_back(port.getType());
2570 portNames.push_back(mem.getPortName(i));
2571 portAnnotations.push_back(mem.getPortAnnotation(i));
2574 auto newOp = rewriter.create<MemOp>(
2575 mem.getLoc(), resultTypes, mem.getReadLatency(), mem.getWriteLatency(),
2576 mem.getDepth(), mem.getRuw(), rewriter.getStrArrayAttr(portNames),
2577 mem.getName(), mem.getNameKind(), mem.getAnnotations(),
2578 rewriter.getArrayAttr(portAnnotations), mem.getInnerSymAttr(),
2579 mem.getInitAttr(), mem.getPrefixAttr());
2581 for (
unsigned i = 0, n = mem.getNumResults(); i < n; ++i) {
2582 auto result = mem.getResult(i);
2583 auto newResult = newOp.getResult(i);
2585 auto resultPortTy = type_cast<BundleType>(result.getType());
2589 auto replace = [&](StringRef toName, StringRef fromName) {
2590 auto fromFieldIndex = resultPortTy.getElementIndex(fromName);
2591 assert(fromFieldIndex &&
"missing enable flag on memory port");
2593 auto toField = rewriter.create<SubfieldOp>(newResult.getLoc(),
2595 for (
auto *op : llvm::make_early_inc_range(result.getUsers())) {
2596 auto fromField = cast<SubfieldOp>(op);
2597 if (fromFieldIndex != fromField.getFieldIndex())
2599 rewriter.replaceOp(fromField, toField.getResult());
2603 replace(
"addr",
"addr");
2604 replace(
"en",
"en");
2605 replace(
"clk",
"clk");
2606 replace(
"data",
"wdata");
2607 replace(
"mask",
"wmask");
2610 auto wmodeFieldIndex = resultPortTy.getElementIndex(
"wmode");
2611 for (
auto *op : llvm::make_early_inc_range(result.getUsers())) {
2612 auto wmodeField = cast<SubfieldOp>(op);
2613 if (wmodeFieldIndex != wmodeField.getFieldIndex())
2615 rewriter.replaceOpWithNewOp<WireOp>(wmodeField, wmodeField.getType());
2618 rewriter.replaceAllUsesWith(result, newResult);
2621 rewriter.eraseOp(op);
2627 struct FoldUnusedBits :
public mlir::RewritePattern {
2628 FoldUnusedBits(MLIRContext *context)
2629 : RewritePattern(MemOp::getOperationName(), 0, context) {}
2631 LogicalResult matchAndRewrite(Operation *op,
2632 PatternRewriter &rewriter)
const override {
2633 MemOp mem = cast<MemOp>(op);
2638 const auto &summary = mem.getSummary();
2639 if (summary.isMasked || summary.isSeqMem())
2642 auto type = type_dyn_cast<IntType>(mem.getDataType());
2645 auto width = type.getBitWidthOrSentinel();
2649 llvm::SmallBitVector usedBits(
width);
2650 DenseMap<unsigned, unsigned> mapping;
2655 SmallVector<BitsPrimOp> readOps;
2656 auto findReadUsers = [&](Value port, StringRef field) {
2657 auto portTy = type_cast<BundleType>(port.getType());
2658 auto fieldIndex = portTy.getElementIndex(field);
2659 assert(fieldIndex &&
"missing data port");
2661 for (
auto *op : port.getUsers()) {
2662 auto portAccess = cast<SubfieldOp>(op);
2663 if (fieldIndex != portAccess.getFieldIndex())
2666 for (
auto *user : op->getUsers()) {
2667 auto bits = dyn_cast<BitsPrimOp>(user);
2673 usedBits.set(bits.getLo(), bits.getHi() + 1);
2674 mapping[bits.getLo()] = 0;
2675 readOps.push_back(bits);
2683 SmallVector<MatchingConnectOp> writeOps;
2684 auto findWriteUsers = [&](Value port, StringRef field) -> LogicalResult {
2685 auto portTy = type_cast<BundleType>(port.getType());
2686 auto fieldIndex = portTy.getElementIndex(field);
2687 assert(fieldIndex &&
"missing data port");
2689 for (
auto *op : port.getUsers()) {
2690 auto portAccess = cast<SubfieldOp>(op);
2691 if (fieldIndex != portAccess.getFieldIndex())
2698 writeOps.push_back(conn);
2704 for (
auto [i, port] : llvm::enumerate(mem.getResults())) {
2706 if (!mem.getPortAnnotation(i).empty())
2709 switch (mem.getPortKind(i)) {
2710 case MemOp::PortKind::Debug:
2713 case MemOp::PortKind::Write:
2714 if (failed(findWriteUsers(port,
"data")))
2717 case MemOp::PortKind::Read:
2718 findReadUsers(port,
"data");
2720 case MemOp::PortKind::ReadWrite:
2721 if (failed(findWriteUsers(port,
"wdata")))
2723 findReadUsers(port,
"rdata");
2726 llvm_unreachable(
"unknown port kind");
2731 if (usedBits.all() || usedBits.none())
2735 SmallVector<std::pair<unsigned, unsigned>> ranges;
2736 unsigned newWidth = 0;
2737 for (
int i = usedBits.find_first(); 0 <= i && i <
width;) {
2738 int e = usedBits.find_next_unset(i);
2741 for (
int idx = i; idx < e; ++idx, ++newWidth) {
2742 if (
auto it = mapping.find(idx); it != mapping.end()) {
2743 it->second = newWidth;
2746 ranges.emplace_back(i, e - 1);
2747 i = e !=
width ? usedBits.find_next(e) : e;
2751 auto newType =
IntType::get(op->getContext(), type.isSigned(), newWidth);
2752 SmallVector<Type> portTypes;
2753 for (
auto [i, port] : llvm::enumerate(mem.getResults())) {
2754 portTypes.push_back(
2755 MemOp::getTypeForPort(mem.getDepth(), newType, mem.getPortKind(i)));
2757 auto newMem = rewriter.replaceOpWithNewOp<MemOp>(
2758 mem, portTypes, mem.getReadLatency(), mem.getWriteLatency(),
2759 mem.getDepth(), mem.getRuw(), mem.getPortNames(), mem.getName(),
2760 mem.getNameKind(), mem.getAnnotations(), mem.getPortAnnotations(),
2761 mem.getInnerSymAttr(), mem.getInitAttr(), mem.getPrefixAttr());
2764 auto rewriteSubfield = [&](Value port, StringRef field) {
2765 auto portTy = type_cast<BundleType>(port.getType());
2766 auto fieldIndex = portTy.getElementIndex(field);
2767 assert(fieldIndex &&
"missing data port");
2769 rewriter.setInsertionPointAfter(newMem);
2770 auto newPortAccess =
2771 rewriter.create<SubfieldOp>(port.getLoc(), port, field);
2773 for (
auto *op : llvm::make_early_inc_range(port.getUsers())) {
2774 auto portAccess = cast<SubfieldOp>(op);
2775 if (op == newPortAccess || fieldIndex != portAccess.getFieldIndex())
2777 rewriter.replaceOp(portAccess, newPortAccess.getResult());
2782 for (
auto [i, port] : llvm::enumerate(newMem.getResults())) {
2783 switch (newMem.getPortKind(i)) {
2784 case MemOp::PortKind::Debug:
2785 llvm_unreachable(
"cannot rewrite debug port");
2786 case MemOp::PortKind::Write:
2787 rewriteSubfield(port,
"data");
2789 case MemOp::PortKind::Read:
2790 rewriteSubfield(port,
"data");
2792 case MemOp::PortKind::ReadWrite:
2793 rewriteSubfield(port,
"rdata");
2794 rewriteSubfield(port,
"wdata");
2797 llvm_unreachable(
"unknown port kind");
2801 for (
auto readOp : readOps) {
2802 rewriter.setInsertionPointAfter(readOp);
2803 auto it = mapping.find(readOp.getLo());
2804 assert(it != mapping.end() &&
"bit op mapping not found");
2805 rewriter.replaceOpWithNewOp<BitsPrimOp>(
2806 readOp, readOp.getInput(),
2807 readOp.getHi() - readOp.getLo() + it->second, it->second);
2811 for (
auto writeOp : writeOps) {
2812 Value source = writeOp.getSrc();
2813 rewriter.setInsertionPoint(writeOp);
2816 for (
auto &[start, end] : ranges) {
2818 rewriter.create<BitsPrimOp>(writeOp.getLoc(), source,
end, start);
2821 rewriter.create<CatPrimOp>(writeOp.getLoc(), slice, catOfSlices);
2823 catOfSlices = slice;
2826 rewriter.replaceOpWithNewOp<MatchingConnectOp>(writeOp, writeOp.getDest(),
2835 struct FoldRegMems :
public mlir::RewritePattern {
2836 FoldRegMems(MLIRContext *context)
2837 : RewritePattern(MemOp::getOperationName(), 0, context) {}
2838 LogicalResult matchAndRewrite(Operation *op,
2839 PatternRewriter &rewriter)
const override {
2840 MemOp mem = cast<MemOp>(op);
2841 const FirMemory &info = mem.getSummary();
2845 auto memModule = mem->getParentOfType<FModuleOp>();
2849 SmallPtrSet<Operation *, 8> connects;
2850 SmallVector<SubfieldOp> portAccesses;
2851 for (
auto [i, port] : llvm::enumerate(mem.getResults())) {
2852 if (!mem.getPortAnnotation(i).empty())
2855 auto collect = [&, port = port](ArrayRef<StringRef> fields) {
2856 auto portTy = type_cast<BundleType>(port.getType());
2857 for (
auto field : fields) {
2858 auto fieldIndex = portTy.getElementIndex(field);
2859 assert(fieldIndex &&
"missing field on memory port");
2861 for (
auto *op : port.getUsers()) {
2862 auto portAccess = cast<SubfieldOp>(op);
2863 if (fieldIndex != portAccess.getFieldIndex())
2865 portAccesses.push_back(portAccess);
2866 for (
auto *user : portAccess->getUsers()) {
2867 auto conn = dyn_cast<FConnectLike>(user);
2870 connects.insert(conn);
2877 switch (mem.getPortKind(i)) {
2878 case MemOp::PortKind::Debug:
2880 case MemOp::PortKind::Read:
2881 if (failed(collect({
"clk",
"en",
"addr"})))
2884 case MemOp::PortKind::Write:
2885 if (failed(collect({
"clk",
"en",
"addr",
"data",
"mask"})))
2888 case MemOp::PortKind::ReadWrite:
2889 if (failed(collect({
"clk",
"en",
"addr",
"wmode",
"wdata",
"wmask"})))
2895 if (!portClock || (clock && portClock != clock))
2901 auto ty = mem.getDataType();
2902 rewriter.setInsertionPointAfterValue(clock);
2903 auto reg = rewriter.create<RegOp>(mem.getLoc(), ty, clock, mem.getName())
2907 auto pipeline = [&](Value value, Value clock,
const Twine &name,
2909 for (
unsigned i = 0; i < latency; ++i) {
2910 std::string regName;
2912 llvm::raw_string_ostream os(regName);
2913 os << mem.getName() <<
"_" << name <<
"_" << i;
2917 .create<RegOp>(mem.getLoc(), value.getType(), clock,
2918 rewriter.getStringAttr(regName))
2920 rewriter.create<MatchingConnectOp>(value.getLoc(),
reg, value);
2931 SmallVector<std::tuple<Value, Value, Value>> writes;
2932 for (
auto [i, port] : llvm::enumerate(mem.getResults())) {
2934 StringRef name = mem.getPortName(i);
2936 auto portPipeline = [&, port = port](StringRef field,
unsigned stages) {
2939 rewriter.setInsertionPointAfterValue(value);
2940 return pipeline(value, portClock, name +
"_" + field, stages);
2943 switch (mem.getPortKind(i)) {
2944 case MemOp::PortKind::Debug:
2945 llvm_unreachable(
"unknown port kind");
2946 case MemOp::PortKind::Read: {
2951 rewriter.setInsertionPointAfterValue(
reg);
2955 case MemOp::PortKind::Write: {
2956 auto data = portPipeline(
"data", writeStages);
2957 auto en = portPipeline(
"en", writeStages);
2958 auto mask = portPipeline(
"mask", writeStages);
2959 writes.emplace_back(data, en, mask);
2962 case MemOp::PortKind::ReadWrite: {
2964 rewriter.setInsertionPointAfterValue(
reg);
2968 auto wdata = portPipeline(
"wdata", writeStages);
2969 auto wmask = portPipeline(
"wmask", writeStages);
2973 rewriter.setInsertionPointToEnd(memModule.getBodyBlock());
2975 auto wen = rewriter.create<AndPrimOp>(port.getLoc(),
en,
wmode);
2977 pipeline(wen, portClock, name +
"_wen", writeStages);
2978 writes.emplace_back(wdata, wenPipelined, wmask);
2985 rewriter.setInsertionPointToEnd(memModule.getBodyBlock());
2987 for (
auto &[data, en, mask] : writes) {
2992 Location loc = mem.getLoc();
2994 for (
unsigned i = 0; i < info.
maskBits; ++i) {
2995 unsigned hi = (i + 1) * maskGran - 1;
2996 unsigned lo = i * maskGran;
2998 auto dataPart = rewriter.createOrFold<BitsPrimOp>(loc,
data, hi, lo);
2999 auto nextPart = rewriter.createOrFold<BitsPrimOp>(loc, next, hi, lo);
3000 auto bit = rewriter.createOrFold<BitsPrimOp>(loc,
mask, i, i);
3001 auto chunk = rewriter.create<MuxPrimOp>(loc, bit, dataPart, nextPart);
3004 masked = rewriter.create<CatPrimOp>(loc, chunk, masked);
3010 next = rewriter.create<MuxPrimOp>(next.getLoc(),
en, masked, next);
3012 rewriter.create<MatchingConnectOp>(
reg.getLoc(),
reg, next);
3015 for (Operation *conn : connects)
3016 rewriter.eraseOp(conn);
3017 for (
auto portAccess : portAccesses)
3018 rewriter.eraseOp(portAccess);
3019 rewriter.eraseOp(mem);
3026 void MemOp::getCanonicalizationPatterns(RewritePatternSet &results,
3027 MLIRContext *context) {
3029 .insert<FoldZeroWidthMemory, FoldReadOrWriteOnlyMemory,
3030 FoldReadWritePorts, FoldUnusedPorts, FoldUnusedBits, FoldRegMems>(
3050 auto mux = dyn_cast_or_null<MuxPrimOp>(con.getSrc().getDefiningOp());
3053 auto *high = mux.getHigh().getDefiningOp();
3054 auto *low = mux.getLow().getDefiningOp();
3056 auto constOp = dyn_cast_or_null<ConstantOp>(high);
3063 bool constReg =
false;
3065 if (constOp && low ==
reg)
3067 else if (dyn_cast_or_null<ConstantOp>(low) && high ==
reg) {
3069 constOp = dyn_cast<ConstantOp>(low);
3076 if (!isa<BlockArgument>(mux.getSel()) && !constReg)
3080 auto regTy =
reg.getResult().getType();
3081 if (con.getDest().getType() != regTy || con.getSrc().getType() != regTy ||
3082 mux.getHigh().getType() != regTy || mux.getLow().getType() != regTy ||
3083 regTy.getBitWidthOrSentinel() < 0)
3089 if (constOp != &con->getBlock()->front())
3090 constOp->moveBefore(&con->getBlock()->front());
3093 SmallVector<NamedAttribute, 2> attrs(
reg->getDialectAttrs());
3094 auto newReg = replaceOpWithNewOpAndCopyName<RegResetOp>(
3095 rewriter,
reg,
reg.getResult().getType(),
reg.getClockVal(),
3096 mux.getSel(), mux.getHigh(),
reg.getNameAttr(),
reg.getNameKindAttr(),
3097 reg.getAnnotationsAttr(),
reg.getInnerSymAttr(),
3098 reg.getForceableAttr());
3099 newReg->setDialectAttrs(attrs);
3101 auto pt = rewriter.saveInsertionPoint();
3102 rewriter.setInsertionPoint(con);
3103 auto v = constReg ? (Value)constOp.getResult() : (Value)mux.getLow();
3104 replaceOpWithNewOpAndCopyName<ConnectOp>(rewriter, con, con.getDest(), v);
3105 rewriter.restoreInsertionPoint(pt);
3110 if (!
hasDontTouch(op.getOperation()) && !op.isForceable() &&
3126 PatternRewriter &rewriter,
3129 if (
auto constant = enable.getDefiningOp<firrtl::ConstantOp>()) {
3130 if (constant.getValue().isZero()) {
3131 rewriter.eraseOp(op);
3137 if (
auto constant = predicate.getDefiningOp<firrtl::ConstantOp>()) {
3138 if (constant.getValue().isZero() == eraseIfZero) {
3139 rewriter.eraseOp(op);
3147 template <
class Op,
bool EraseIfZero = false>
3149 PatternRewriter &rewriter) {
3154 void AssertOp::getCanonicalizationPatterns(RewritePatternSet &results,
3155 MLIRContext *context) {
3156 results.add(canonicalizeImmediateVerifOp<AssertOp>);
3159 void AssumeOp::getCanonicalizationPatterns(RewritePatternSet &results,
3160 MLIRContext *context) {
3161 results.add(canonicalizeImmediateVerifOp<AssumeOp>);
3164 void UnclockedAssumeIntrinsicOp::getCanonicalizationPatterns(
3165 RewritePatternSet &results, MLIRContext *context) {
3166 results.add(canonicalizeImmediateVerifOp<UnclockedAssumeIntrinsicOp>);
3169 void CoverOp::getCanonicalizationPatterns(RewritePatternSet &results,
3170 MLIRContext *context) {
3171 results.add(canonicalizeImmediateVerifOp<CoverOp, /* EraseIfZero = */ true>);
3179 PatternRewriter &rewriter) {
3181 if (op.use_empty()) {
3182 rewriter.eraseOp(op);
3189 if (op->hasOneUse() &&
3190 (isa<BitsPrimOp, HeadPrimOp, ShrPrimOp, TailPrimOp, SubfieldOp,
3191 SubindexOp, AsSIntPrimOp, AsUIntPrimOp, NotPrimOp, BitCastOp>(
3192 *op->user_begin()) ||
3193 (isa<CvtPrimOp>(*op->user_begin()) &&
3194 type_isa<SIntType>(op->user_begin()->getOperand(0).getType())) ||
3195 (isa<AndRPrimOp, XorRPrimOp, OrRPrimOp>(*op->user_begin()) &&
3196 type_cast<FIRRTLBaseType>(op->user_begin()->getOperand(0).getType())
3197 .getBitWidthOrSentinel() > 0))) {
3198 auto *modop = *op->user_begin();
3199 auto inv = rewriter.create<InvalidValueOp>(op.getLoc(),
3200 modop->getResult(0).getType());
3201 rewriter.replaceAllOpUsesWith(modop, inv);
3202 rewriter.eraseOp(modop);
3203 rewriter.eraseOp(op);
3209 OpFoldResult InvalidValueOp::fold(FoldAdaptor adaptor) {
3210 if (getType().getBitWidthOrSentinel() == 0 && isa<IntType>(getType()))
3211 return getIntAttr(getType(), APInt(0, 0, isa<SIntType>(getType())));
3219 OpFoldResult ClockGateIntrinsicOp::fold(FoldAdaptor adaptor) {
3238 PatternRewriter &rewriter) {
3240 if (
auto testEnable = op.getTestEnable()) {
3241 if (
auto constOp = testEnable.getDefiningOp<ConstantOp>()) {
3242 if (constOp.getValue().isZero()) {
3243 rewriter.modifyOpInPlace(op,
3244 [&] { op.getTestEnableMutable().clear(); });
3258 static LogicalResult
3260 auto forceable = op.getRef().getDefiningOp<Forceable>();
3261 if (!forceable || !forceable.isForceable() ||
3262 op.getRef() != forceable.getDataRef() ||
3263 op.getType() != forceable.getDataType())
3265 rewriter.replaceAllUsesWith(op, forceable.getData());
3269 void RefResolveOp::getCanonicalizationPatterns(RewritePatternSet &results,
3270 MLIRContext *context) {
3271 results.insert<patterns::RefResolveOfRefSend>(context);
3275 OpFoldResult RefCastOp::fold(FoldAdaptor adaptor) {
3277 if (getInput().getType() == getType())
3283 auto constOp = operand.getDefiningOp<ConstantOp>();
3284 return constOp && constOp.getValue().isZero();
3287 template <
typename Op>
3290 rewriter.eraseOp(op);
3296 void RefForceOp::getCanonicalizationPatterns(RewritePatternSet &results,
3297 MLIRContext *context) {
3298 results.add(eraseIfPredFalse<RefForceOp>);
3300 void RefForceInitialOp::getCanonicalizationPatterns(RewritePatternSet &results,
3301 MLIRContext *context) {
3302 results.add(eraseIfPredFalse<RefForceInitialOp>);
3304 void RefReleaseOp::getCanonicalizationPatterns(RewritePatternSet &results,
3305 MLIRContext *context) {
3306 results.add(eraseIfPredFalse<RefReleaseOp>);
3308 void RefReleaseInitialOp::getCanonicalizationPatterns(
3309 RewritePatternSet &results, MLIRContext *context) {
3310 results.add(eraseIfPredFalse<RefReleaseInitialOp>);
3317 OpFoldResult HasBeenResetIntrinsicOp::fold(FoldAdaptor adaptor) {
3323 if (adaptor.getReset())
3328 if (
isUInt1(getReset().getType()) && adaptor.getClock())
3341 [&](
auto ty) ->
bool {
return isTypeEmpty(ty.getElementType()); })
3342 .Case<BundleType>([&](
auto ty) ->
bool {
3343 for (
auto elem : ty.getElements())
3348 .Case<IntType>([&](
auto ty) {
return ty.getWidth() == 0; })
3349 .Default([](
auto) ->
bool {
return false; });
3353 PatternRewriter &rewriter) {
3354 auto firrtlTy = type_dyn_cast<FIRRTLType>(op.getInput().getType());
3361 rewriter.eraseOp(op);
assert(baseType &&"element must be base type")
static bool hasKnownWidthIntTypes(Operation *op)
Return true if this operation's operands and results all have a known width.
static std::optional< APSInt > getConstant(Attribute operand)
Determine the value of a constant operand for the sake of constant folding.
static LogicalResult canonicalizeImmediateVerifOp(Op op, PatternRewriter &rewriter)
static bool isDefinedByOneConstantOp(Value v)
static Attribute collectFields(MLIRContext *context, ArrayRef< Attribute > operands)
static LogicalResult canonicalizeSingleSetConnect(MatchingConnectOp op, PatternRewriter &rewriter)
static void erasePort(PatternRewriter &rewriter, Value port)
static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op, Region ®ion)
Replaces the given op with the contents of the given single-block region.
static Value getPortFieldValue(Value port, StringRef name)
static AttachOp getDominatingAttachUser(Value value, AttachOp dominatedAttach)
If the specified value has an AttachOp user strictly dominating by "dominatingAttach" then return it.
static OpTy replaceOpWithNewOpAndCopyName(PatternRewriter &rewriter, Operation *op, Args &&...args)
A wrapper of PatternRewriter::replaceOpWithNewOp to propagate "name" attribute.
static void updateName(PatternRewriter &rewriter, Operation *op, StringAttr name)
Set the name of an op based on the best of two names: The current name, and the name passed in.
static bool isTypeEmpty(FIRRTLType type)
static bool isUInt1(Type type)
Return true if this value is 1 bit UInt.
static LogicalResult demoteForceableIfUnused(OpTy op, PatternRewriter &rewriter)
static bool isPortDisabled(Value port)
static LogicalResult eraseIfZeroOrNotZero(Operation *op, Value predicate, Value enable, PatternRewriter &rewriter, bool eraseIfZero)
static APInt getMaxSignedValue(unsigned bitWidth)
Get the largest signed value of a given bit width.
static Value dropWrite(PatternRewriter &rewriter, OpResult old, Value passthrough)
static LogicalResult canonicalizePrimOp(Operation *op, PatternRewriter &rewriter, const function_ref< OpFoldResult(ArrayRef< Attribute >)> &canonicalize)
Applies the canonicalization function canonicalize to the given operation.
static void replaceWithBits(Operation *op, Value value, unsigned hiBit, unsigned loBit, PatternRewriter &rewriter)
Replace the specified operation with a 'bits' op from the specified hi/lo bits.
static LogicalResult canonicalizeRegResetWithOneReset(RegResetOp reg, PatternRewriter &rewriter)
static LogicalResult eraseIfPredFalse(Op op, PatternRewriter &rewriter)
static OpFoldResult foldMux(OpTy op, typename OpTy::FoldAdaptor adaptor)
static APInt getMaxUnsignedValue(unsigned bitWidth)
Get the largest unsigned value of a given bit width.
static void replacePortField(PatternRewriter &rewriter, Value port, StringRef name, Value value)
BinOpKind
This is the policy for folding, which depends on the sort of operator we're processing.
static bool isPortUnused(Value port, StringRef data)
static LogicalResult canonicalizeRefResolveOfForceable(RefResolveOp op, PatternRewriter &rewriter)
static Attribute constFoldFIRRTLBinaryOp(Operation *op, ArrayRef< Attribute > operands, BinOpKind opKind, const function_ref< APInt(const APSInt &, const APSInt &)> &calculate)
Applies the constant folding function calculate to the given operands.
static APInt getMinSignedValue(unsigned bitWidth)
Get the smallest signed value of a given bit width.
static LogicalResult foldHiddenReset(RegOp reg, PatternRewriter &rewriter)
static std::optional< APSInt > getExtendedConstant(Value operand, Attribute constant, int32_t destWidth)
Implicitly replace the operand to a constant folding operation with a const 0 in case the operand is ...
static Value moveNameHint(OpResult old, Value passthrough)
static void replaceOpAndCopyName(PatternRewriter &rewriter, Operation *op, Value newValue)
A wrapper of PatternRewriter::replaceOp to propagate "name" attribute.
This class provides a read-only projection over the MLIR attributes that represent a set of annotatio...
int32_t getBitWidthOrSentinel()
If this is an IntType, AnalogType, or sugar type for a single bit (Clock, Reset, etc) then return the...
This class implements the same functionality as TypeSwitch except that it uses firrtl::type_dyn_cast ...
FIRRTLTypeSwitch< T, ResultT > & Case(CallableT &&caseFn)
Add a case on the given type.
This is the common base class between SIntType and UIntType.
int32_t getWidthOrSentinel() const
Return the width of this type, or -1 if it has none specified.
static IntType get(MLIRContext *context, bool isSigned, int32_t widthOrSentinel=-1, bool isConst=false)
Return an SIntType or UIntType with the specified signedness, width, and constness.
std::optional< int32_t > getWidth() const
Return an optional containing the width, if the width is known (or empty if width is unknown).
bool hasWidth() const
Return true if this integer type has a known width.
static LogicalResult canonicalize(Op op, PatternRewriter &rewriter)
def connect(destination, source)
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
uint64_t getWidth(Type t)
Forceable replaceWithNewForceability(Forceable op, bool forceable, ::mlir::PatternRewriter *rewriter=nullptr)
Replace a Forceable op with equivalent, changing whether forceable.
bool areAnonymousTypesEquivalent(FIRRTLBaseType lhs, FIRRTLBaseType rhs)
Return true if anonymous types of given arguments are equivalent by pointer comparison.
IntegerAttr getIntAttr(Type type, const APInt &value)
Utiility for generating a constant attribute.
bool hasDontTouch(Value value)
Check whether a block argument ("port") or the operation defining a value has a DontTouch annotation,...
bool hasDroppableName(Operation *op)
Return true if the name is droppable.
MatchingConnectOp getSingleConnectUserOf(Value value)
Scan all the uses of the specified value, checking to see if there is exactly one connect that has th...
std::optional< int64_t > getBitWidth(FIRRTLBaseType type, bool ignoreFlip=false)
IntegerAttr getIntZerosAttr(Type type)
Utility for generating a constant zero attribute.
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
APSInt extOrTruncZeroWidth(APSInt value, unsigned width)
A safe version of APSInt::extOrTrunc that will NOT assert on zero-width signed APSInts.
APInt sextZeroWidth(APInt value, unsigned width)
A safe version of APInt::sext that will NOT assert on zero-width signed APSInts.
StringRef chooseName(StringRef a, StringRef b)
Choose a good name for an item from two options.
static bool isConstantZero(Attribute operand)
Determine whether a constant operand is a zero value.
static bool isConstantOne(Attribute operand)
Determine whether a constant operand is a one value.
def reg(value, clock, reset=None, reset_value=None, name=None, sym_name=None)