20#include "mlir/IR/Matchers.h"
21#include "mlir/IR/PatternMatch.h"
22#include "llvm/ADT/APSInt.h"
23#include "llvm/ADT/SmallPtrSet.h"
24#include "llvm/ADT/StringExtras.h"
25#include "llvm/ADT/TypeSwitch.h"
28using namespace firrtl;
32static Value
dropWrite(PatternRewriter &rewriter, OpResult old,
34 SmallPtrSet<Operation *, 8> users;
35 for (
auto *user : old.getUsers())
37 for (Operation *user : users)
38 if (
auto connect = dyn_cast<FConnectLike>(user))
39 if (connect.getDest() == old)
40 rewriter.eraseOp(user);
32static Value
dropWrite(PatternRewriter &rewriter, OpResult old, {
…}
49 Operation *op = passthrough.getDefiningOp();
52 assert(op &&
"passthrough must be an operation");
53 Operation *oldOp = old.getOwner();
54 auto name = oldOp->getAttrOfType<StringAttr>(
"name");
55 if (name && !name.getValue().empty())
56 op->setAttr(
"name", name);
64#include "circt/Dialect/FIRRTL/FIRRTLCanonicalization.h.inc"
72 auto resultType = type_cast<IntType>(op->getResult(0).getType());
73 if (!resultType.hasWidth())
75 for (Value operand : op->getOperands())
76 if (!type_cast<IntType>(operand.getType()).hasWidth())
83 auto t = type_dyn_cast<UIntType>(type);
84 if (!t || !t.hasWidth() || t.getWidth() != 1)
91static void updateName(PatternRewriter &rewriter, Operation *op,
94 assert(!isa<InstanceOp>(op));
95 if (!name || name.getValue().empty())
97 auto newName = name.getValue();
98 auto newOpName = op->getAttrOfType<StringAttr>(
"name");
101 newName =
chooseName(newOpName.getValue(), name.getValue());
103 if (!newOpName || newOpName.getValue() != newName)
104 rewriter.modifyOpInPlace(
105 op, [&] { op->setAttr(
"name", rewriter.getStringAttr(newName)); });
91static void updateName(PatternRewriter &rewriter, Operation *op, {
…}
113 if (
auto *newOp = newValue.getDefiningOp()) {
114 auto name = op->getAttrOfType<StringAttr>(
"name");
117 rewriter.replaceOp(op, newValue);
123template <
typename OpTy,
typename... Args>
125 Operation *op, Args &&...args) {
126 auto name = op->getAttrOfType<StringAttr>(
"name");
128 rewriter.replaceOpWithNewOp<OpTy>(op, std::forward<Args>(args)...);
136 if (
auto namableOp = dyn_cast<firrtl::FNamableOp>(op))
137 return namableOp.hasDroppableName();
148static std::optional<APSInt>
150 assert(type_cast<IntType>(operand.getType()) &&
151 "getExtendedConstant is limited to integer types");
158 if (IntegerAttr result = dyn_cast_or_null<IntegerAttr>(constant))
163 if (type_cast<IntType>(operand.getType()).getWidth() == 0)
164 return APSInt(destWidth,
165 type_cast<IntType>(operand.getType()).isUnsigned());
173 if (
auto attr = dyn_cast<BoolAttr>(operand))
174 return APSInt(APInt(1, attr.getValue()));
175 if (
auto attr = dyn_cast<IntegerAttr>(operand))
176 return attr.getAPSInt();
184 return cst->isZero();
211 Operation *op, ArrayRef<Attribute> operands,
BinOpKind opKind,
212 const function_ref<APInt(
const APSInt &,
const APSInt &)> &calculate) {
213 assert(operands.size() == 2 &&
"binary op takes two operands");
216 auto resultType = type_cast<IntType>(op->getResult(0).getType());
217 if (resultType.getWidthOrSentinel() < 0)
221 if (resultType.getWidthOrSentinel() == 0)
222 return getIntAttr(resultType, APInt(0, 0, resultType.isSigned()));
228 type_cast<IntType>(op->getOperand(0).getType()).getWidthOrSentinel();
230 type_cast<IntType>(op->getOperand(1).getType()).getWidthOrSentinel();
231 if (
auto lhs = dyn_cast_or_null<IntegerAttr>(operands[0]))
232 lhsWidth = std::max<int32_t>(lhsWidth, lhs.getValue().getBitWidth());
233 if (
auto rhs = dyn_cast_or_null<IntegerAttr>(operands[1]))
234 rhsWidth = std::max<int32_t>(rhsWidth, rhs.getValue().getBitWidth());
238 int32_t operandWidth;
241 operandWidth = resultType.getWidthOrSentinel();
246 operandWidth = std::max(1, std::max(lhsWidth, rhsWidth));
250 std::max(std::max(lhsWidth, rhsWidth), resultType.getWidthOrSentinel());
261 APInt resultValue = calculate(*lhs, *rhs);
266 resultValue = resultValue.trunc(resultType.getWidthOrSentinel());
268 assert((
unsigned)resultType.getWidthOrSentinel() ==
269 resultValue.getBitWidth());
282 Operation *op, PatternRewriter &rewriter,
283 const function_ref<OpFoldResult(ArrayRef<Attribute>)> &canonicalize) {
285 if (op->getNumResults() != 1)
287 auto type = type_dyn_cast<FIRRTLBaseType>(op->getResult(0).getType());
292 auto width = type.getBitWidthOrSentinel();
297 SmallVector<Attribute, 3> constOperands;
298 constOperands.reserve(op->getNumOperands());
299 for (
auto operand : op->getOperands()) {
301 if (
auto *defOp = operand.getDefiningOp())
302 TypeSwitch<Operation *>(defOp).Case<ConstantOp, SpecialConstantOp>(
303 [&](
auto op) { attr = op.getValueAttr(); });
304 constOperands.push_back(attr);
309 auto result = canonicalize(constOperands);
313 if (
auto cst = dyn_cast<Attribute>(result))
314 resultValue = op->getDialect()
315 ->materializeConstant(rewriter, cst, type, op->getLoc())
318 resultValue = cast<Value>(result);
322 type_cast<FIRRTLBaseType>(resultValue.getType()).getBitWidthOrSentinel())
323 resultValue = rewriter.create<PadPrimOp>(op->getLoc(), resultValue, width);
326 if (type_isa<SIntType>(type) && type_isa<UIntType>(resultValue.getType()))
327 resultValue = rewriter.create<AsSIntPrimOp>(op->getLoc(), resultValue);
328 else if (type_isa<UIntType>(type) &&
329 type_isa<SIntType>(resultValue.getType()))
330 resultValue = rewriter.create<AsUIntPrimOp>(op->getLoc(), resultValue);
332 assert(type == resultValue.getType() &&
"canonicalization changed type");
340 return bitWidth > 0 ? APInt::getMaxValue(bitWidth) : APInt();
346 return bitWidth > 0 ? APInt::getSignedMinValue(bitWidth) : APInt();
352 return bitWidth > 0 ? APInt::getSignedMaxValue(bitWidth) : APInt();
359OpFoldResult ConstantOp::fold(FoldAdaptor adaptor) {
360 assert(adaptor.getOperands().empty() &&
"constant has no operands");
361 return getValueAttr();
364OpFoldResult SpecialConstantOp::fold(FoldAdaptor adaptor) {
365 assert(adaptor.getOperands().empty() &&
"constant has no operands");
366 return getValueAttr();
369OpFoldResult AggregateConstantOp::fold(FoldAdaptor adaptor) {
370 assert(adaptor.getOperands().empty() &&
"constant has no operands");
371 return getFieldsAttr();
374OpFoldResult StringConstantOp::fold(FoldAdaptor adaptor) {
375 assert(adaptor.getOperands().empty() &&
"constant has no operands");
376 return getValueAttr();
379OpFoldResult FIntegerConstantOp::fold(FoldAdaptor adaptor) {
380 assert(adaptor.getOperands().empty() &&
"constant has no operands");
381 return getValueAttr();
384OpFoldResult BoolConstantOp::fold(FoldAdaptor adaptor) {
385 assert(adaptor.getOperands().empty() &&
"constant has no operands");
386 return getValueAttr();
389OpFoldResult DoubleConstantOp::fold(FoldAdaptor adaptor) {
390 assert(adaptor.getOperands().empty() &&
"constant has no operands");
391 return getValueAttr();
398OpFoldResult AddPrimOp::fold(FoldAdaptor adaptor) {
401 [=](
const APSInt &a,
const APSInt &b) { return a + b; });
404void AddPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
405 MLIRContext *context) {
406 results.insert<patterns::moveConstAdd, patterns::AddOfZero,
407 patterns::AddOfSelf, patterns::AddOfPad>(context);
410OpFoldResult SubPrimOp::fold(FoldAdaptor adaptor) {
413 [=](
const APSInt &a,
const APSInt &b) { return a - b; });
416void SubPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
417 MLIRContext *context) {
418 results.insert<patterns::SubOfZero, patterns::SubFromZeroSigned,
419 patterns::SubFromZeroUnsigned, patterns::SubOfSelf,
420 patterns::SubOfPadL, patterns::SubOfPadR>(context);
423OpFoldResult MulPrimOp::fold(FoldAdaptor adaptor) {
435 [=](
const APSInt &a,
const APSInt &b) { return a * b; });
438OpFoldResult DivPrimOp::fold(FoldAdaptor adaptor) {
445 if (getLhs() == getRhs()) {
446 auto width = getType().base().getWidthOrSentinel();
451 return getIntAttr(getType(), APInt(width, 1));
468 if (
auto rhsCst = dyn_cast_or_null<IntegerAttr>(adaptor.getRhs()))
469 if (rhsCst.getValue().isOne() && getLhs().getType() == getType())
474 [=](
const APSInt &a,
const APSInt &b) -> APInt {
477 return APInt(a.getBitWidth(), 0);
481OpFoldResult RemPrimOp::fold(FoldAdaptor adaptor) {
488 if (getLhs() == getRhs())
502 [=](
const APSInt &a,
const APSInt &b) -> APInt {
505 return APInt(a.getBitWidth(), 0);
509OpFoldResult DShlPrimOp::fold(FoldAdaptor adaptor) {
512 [=](
const APSInt &a,
const APSInt &b) -> APInt { return a.shl(b); });
515OpFoldResult DShlwPrimOp::fold(FoldAdaptor adaptor) {
518 [=](
const APSInt &a,
const APSInt &b) -> APInt { return a.shl(b); });
521OpFoldResult DShrPrimOp::fold(FoldAdaptor adaptor) {
524 [=](
const APSInt &a,
const APSInt &b) -> APInt {
525 return getType().base().isUnsigned() || !a.getBitWidth() ? a.lshr(b)
531OpFoldResult AndPrimOp::fold(FoldAdaptor adaptor) {
534 if (rhsCst->isZero())
538 if (rhsCst->isAllOnes() && getLhs().getType() == getType() &&
539 getRhs().getType() == getType())
545 if (lhsCst->isZero())
549 if (lhsCst->isAllOnes() && getLhs().getType() == getType() &&
550 getRhs().getType() == getType())
555 if (getLhs() == getRhs() && getRhs().getType() == getType())
560 [](
const APSInt &a,
const APSInt &b) -> APInt { return a & b; });
563void AndPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
564 MLIRContext *context) {
566 .insert<patterns::extendAnd, patterns::moveConstAnd, patterns::AndOfZero,
567 patterns::AndOfAllOne, patterns::AndOfSelf, patterns::AndOfPad,
568 patterns::AndOfAsSIntL, patterns::AndOfAsSIntR>(context);
571OpFoldResult OrPrimOp::fold(FoldAdaptor adaptor) {
574 if (rhsCst->isZero() && getLhs().getType() == getType())
578 if (rhsCst->isAllOnes() && getRhs().getType() == getType() &&
579 getLhs().getType() == getType())
585 if (lhsCst->isZero() && getRhs().getType() == getType())
589 if (lhsCst->isAllOnes() && getLhs().getType() == getType() &&
590 getRhs().getType() == getType())
595 if (getLhs() == getRhs() && getRhs().getType() == getType())
600 [](
const APSInt &a,
const APSInt &b) -> APInt { return a | b; });
603void OrPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
604 MLIRContext *context) {
605 results.insert<patterns::extendOr, patterns::moveConstOr, patterns::OrOfZero,
606 patterns::OrOfAllOne, patterns::OrOfSelf, patterns::OrOfPad,
607 patterns::OrOrr>(context);
610OpFoldResult XorPrimOp::fold(FoldAdaptor adaptor) {
613 if (rhsCst->isZero() &&
619 if (lhsCst->isZero() &&
624 if (getLhs() == getRhs())
627 APInt(std::max(getType().base().getWidthOrSentinel(), 0), 0));
631 [](
const APSInt &a,
const APSInt &b) -> APInt { return a ^ b; });
634void XorPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
635 MLIRContext *context) {
636 results.insert<patterns::extendXor, patterns::moveConstXor,
637 patterns::XorOfZero, patterns::XorOfSelf, patterns::XorOfPad>(
641void LEQPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
642 MLIRContext *context) {
643 results.insert<patterns::LEQWithConstLHS>(context);
646OpFoldResult LEQPrimOp::fold(FoldAdaptor adaptor) {
647 bool isUnsigned = getLhs().getType().base().isUnsigned();
650 if (getLhs() == getRhs())
654 if (
auto width = getLhs().getType().base().
getWidth()) {
656 auto commonWidth = std::max<int32_t>(*width, rhsCst->getBitWidth());
657 commonWidth = std::max(commonWidth, 1);
668 if (isUnsigned && rhsCst->zext(commonWidth)
681 [=](
const APSInt &a,
const APSInt &b) -> APInt {
682 return APInt(1, a <= b);
686void LTPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
687 MLIRContext *context) {
688 results.insert<patterns::LTWithConstLHS>(context);
691OpFoldResult LTPrimOp::fold(FoldAdaptor adaptor) {
692 IntType lhsType = getLhs().getType();
696 if (getLhs() == getRhs())
706 if (
auto width = lhsType.
getWidth()) {
708 auto commonWidth = std::max<int32_t>(*width, rhsCst->getBitWidth());
709 commonWidth = std::max(commonWidth, 1);
720 if (isUnsigned && rhsCst->zext(commonWidth)
733 [=](
const APSInt &a,
const APSInt &b) -> APInt {
734 return APInt(1, a < b);
738void GEQPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
739 MLIRContext *context) {
740 results.insert<patterns::GEQWithConstLHS>(context);
743OpFoldResult GEQPrimOp::fold(FoldAdaptor adaptor) {
744 IntType lhsType = getLhs().getType();
748 if (getLhs() == getRhs())
753 if (rhsCst->isZero() && isUnsigned)
758 if (
auto width = lhsType.
getWidth()) {
760 auto commonWidth = std::max<int32_t>(*width, rhsCst->getBitWidth());
761 commonWidth = std::max(commonWidth, 1);
764 if (isUnsigned && rhsCst->zext(commonWidth)
785 [=](
const APSInt &a,
const APSInt &b) -> APInt {
786 return APInt(1, a >= b);
790void GTPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
791 MLIRContext *context) {
792 results.insert<patterns::GTWithConstLHS>(context);
795OpFoldResult GTPrimOp::fold(FoldAdaptor adaptor) {
796 IntType lhsType = getLhs().getType();
800 if (getLhs() == getRhs())
804 if (
auto width = lhsType.
getWidth()) {
806 auto commonWidth = std::max<int32_t>(*width, rhsCst->getBitWidth());
807 commonWidth = std::max(commonWidth, 1);
810 if (isUnsigned && rhsCst->zext(commonWidth)
831 [=](
const APSInt &a,
const APSInt &b) -> APInt {
832 return APInt(1, a > b);
836OpFoldResult EQPrimOp::fold(FoldAdaptor adaptor) {
838 if (getLhs() == getRhs())
844 if (rhsCst->isAllOnes() && getLhs().getType() == getType() &&
845 getRhs().getType() == getType())
851 [=](
const APSInt &a,
const APSInt &b) -> APInt {
852 return APInt(1, a == b);
856LogicalResult EQPrimOp::canonicalize(EQPrimOp op, PatternRewriter &rewriter) {
858 op, rewriter, [&](ArrayRef<Attribute> operands) -> OpFoldResult {
860 auto width = op.getLhs().getType().getBitWidthOrSentinel();
863 if (rhsCst->isZero() && op.getLhs().getType() == op.getType() &&
864 op.getRhs().getType() == op.getType()) {
865 return rewriter.create<NotPrimOp>(op.getLoc(), op.getLhs())
870 if (rhsCst->isZero() && width > 1) {
871 auto orrOp = rewriter.create<OrRPrimOp>(op.getLoc(), op.getLhs());
872 return rewriter.create<NotPrimOp>(op.getLoc(), orrOp).getResult();
876 if (rhsCst->isAllOnes() && width > 1 &&
877 op.getLhs().getType() == op.getRhs().getType()) {
878 return rewriter.create<AndRPrimOp>(op.getLoc(), op.getLhs())
886OpFoldResult NEQPrimOp::fold(FoldAdaptor adaptor) {
888 if (getLhs() == getRhs())
894 if (rhsCst->isZero() && getLhs().getType() == getType() &&
895 getRhs().getType() == getType())
901 [=](
const APSInt &a,
const APSInt &b) -> APInt {
902 return APInt(1, a != b);
906LogicalResult NEQPrimOp::canonicalize(NEQPrimOp op, PatternRewriter &rewriter) {
908 op, rewriter, [&](ArrayRef<Attribute> operands) -> OpFoldResult {
910 auto width = op.getLhs().getType().getBitWidthOrSentinel();
913 if (rhsCst->isAllOnes() && op.getLhs().getType() == op.getType() &&
914 op.getRhs().getType() == op.getType()) {
915 return rewriter.create<NotPrimOp>(op.getLoc(), op.getLhs())
920 if (rhsCst->isZero() && width > 1) {
921 return rewriter.create<OrRPrimOp>(op.getLoc(), op.getLhs())
926 if (rhsCst->isAllOnes() && width > 1 &&
927 op.getLhs().getType() == op.getRhs().getType()) {
928 auto andrOp = rewriter.create<AndRPrimOp>(op.getLoc(), op.getLhs());
929 return rewriter.create<NotPrimOp>(op.getLoc(), andrOp).getResult();
937OpFoldResult IntegerAddOp::fold(FoldAdaptor adaptor) {
943OpFoldResult IntegerMulOp::fold(FoldAdaptor adaptor) {
949OpFoldResult IntegerShrOp::fold(FoldAdaptor adaptor) {
955OpFoldResult IntegerShlOp::fold(FoldAdaptor adaptor) {
960 return IntegerAttr::get(
961 IntegerType::get(getContext(), lhsCst->getBitWidth()),
962 lhsCst->shl(*rhsCst));
965 if (rhsCst->isZero())
976OpFoldResult SizeOfIntrinsicOp::fold(FoldAdaptor) {
977 auto base = getInput().getType();
984OpFoldResult IsXIntrinsicOp::fold(FoldAdaptor adaptor) {
991OpFoldResult AsSIntPrimOp::fold(FoldAdaptor adaptor) {
999 if (getType().base().hasWidth())
1006void AsSIntPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1007 MLIRContext *context) {
1008 results.insert<patterns::StoUtoS>(context);
1011OpFoldResult AsUIntPrimOp::fold(FoldAdaptor adaptor) {
1019 if (getType().base().hasWidth())
1026void AsUIntPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1027 MLIRContext *context) {
1028 results.insert<patterns::UtoStoU>(context);
1031OpFoldResult AsAsyncResetPrimOp::fold(FoldAdaptor adaptor) {
1033 if (getInput().getType() == getType())
1038 return BoolAttr::get(getContext(), cst->getBoolValue());
1043OpFoldResult AsClockPrimOp::fold(FoldAdaptor adaptor) {
1045 if (getInput().getType() == getType())
1050 return BoolAttr::get(getContext(), cst->getBoolValue());
1055OpFoldResult CvtPrimOp::fold(FoldAdaptor adaptor) {
1061 getType().base().getWidthOrSentinel()))
1067void CvtPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1068 MLIRContext *context) {
1069 results.insert<patterns::CVTSigned, patterns::CVTUnSigned>(context);
1072OpFoldResult NegPrimOp::fold(FoldAdaptor adaptor) {
1079 getType().base().getWidthOrSentinel()))
1080 return getIntAttr(getType(), APInt((*cst).getBitWidth(), 0) - *cst);
1085OpFoldResult NotPrimOp::fold(FoldAdaptor adaptor) {
1090 getType().base().getWidthOrSentinel()))
1096void NotPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1097 MLIRContext *context) {
1098 results.insert<patterns::NotNot, patterns::NotEq, patterns::NotNeq,
1099 patterns::NotLeq, patterns::NotLt, patterns::NotGeq,
1100 patterns::NotGt>(context);
1103OpFoldResult AndRPrimOp::fold(FoldAdaptor adaptor) {
1107 if (getInput().getType().getBitWidthOrSentinel() == 0)
1112 return getIntAttr(getType(), APInt(1, cst->isAllOnes()));
1116 if (
isUInt1(getInput().getType()))
1122void AndRPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1123 MLIRContext *context) {
1125 .insert<patterns::AndRasSInt, patterns::AndRasUInt, patterns::AndRPadU,
1126 patterns::AndRPadS, patterns::AndRCatOneL, patterns::AndRCatOneR,
1127 patterns::AndRCatZeroL, patterns::AndRCatZeroR,
1128 patterns::AndRCatAndR_left, patterns::AndRCatAndR_right>(context);
1131OpFoldResult OrRPrimOp::fold(FoldAdaptor adaptor) {
1135 if (getInput().getType().getBitWidthOrSentinel() == 0)
1140 return getIntAttr(getType(), APInt(1, !cst->isZero()));
1144 if (
isUInt1(getInput().getType()))
1150void OrRPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1151 MLIRContext *context) {
1152 results.insert<patterns::OrRasSInt, patterns::OrRasUInt, patterns::OrRPadU,
1153 patterns::OrRCatZeroH, patterns::OrRCatZeroL,
1154 patterns::OrRCatOrR_left, patterns::OrRCatOrR_right>(context);
1157OpFoldResult XorRPrimOp::fold(FoldAdaptor adaptor) {
1161 if (getInput().getType().getBitWidthOrSentinel() == 0)
1166 return getIntAttr(getType(), APInt(1, cst->popcount() & 1));
1169 if (
isUInt1(getInput().getType()))
1175void XorRPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1176 MLIRContext *context) {
1177 results.insert<patterns::XorRasSInt, patterns::XorRasUInt, patterns::XorRPadU,
1178 patterns::XorRCatZeroH, patterns::XorRCatZeroL,
1179 patterns::XorRCatXorR_left, patterns::XorRCatXorR_right>(
1187OpFoldResult CatPrimOp::fold(FoldAdaptor adaptor) {
1191 IntType lhsType = getLhs().getType();
1192 IntType rhsType = getRhs().getType();
1204 return getIntAttr(getType(), lhs->concat(*rhs));
1209void DShlPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1210 MLIRContext *context) {
1211 results.insert<patterns::DShlOfConstant>(context);
1214void DShrPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1215 MLIRContext *context) {
1216 results.insert<patterns::DShrOfConstant>(context);
1219void CatPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1220 MLIRContext *context) {
1221 results.insert<patterns::CatBitsBits, patterns::CatDoubleConst,
1222 patterns::CatCast>(context);
1225OpFoldResult BitCastOp::fold(FoldAdaptor adaptor) {
1228 if (op.getType() == op.getInput().getType())
1229 return op.getInput();
1233 if (BitCastOp in = dyn_cast_or_null<BitCastOp>(op.getInput().getDefiningOp()))
1234 if (op.getType() == in.getInput().getType())
1235 return in.getInput();
1240OpFoldResult BitsPrimOp::fold(FoldAdaptor adaptor) {
1241 IntType inputType = getInput().getType();
1242 IntType resultType = getType();
1244 if (inputType == getType() && resultType.
hasWidth())
1251 cst->extractBits(getHi() - getLo() + 1, getLo()));
1256void BitsPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1257 MLIRContext *context) {
1259 .insert<patterns::BitsOfBits, patterns::BitsOfMux, patterns::BitsOfAsUInt,
1260 patterns::BitsOfAnd, patterns::BitsOfPad>(context);
1267 unsigned loBit, PatternRewriter &rewriter) {
1268 auto resType = type_cast<IntType>(op->getResult(0).getType());
1269 if (type_cast<IntType>(value.getType()).getWidth() != resType.getWidth())
1270 value = rewriter.create<BitsPrimOp>(op->getLoc(), value, hiBit, loBit);
1272 if (resType.isSigned() && !type_cast<IntType>(value.getType()).isSigned()) {
1273 value = rewriter.createOrFold<AsSIntPrimOp>(op->getLoc(), resType, value);
1274 }
else if (resType.isUnsigned() &&
1275 !type_cast<IntType>(value.getType()).isUnsigned()) {
1276 value = rewriter.createOrFold<AsUIntPrimOp>(op->getLoc(), resType, value);
1278 rewriter.replaceOp(op, value);
1281template <
typename OpTy>
1282static OpFoldResult
foldMux(OpTy op,
typename OpTy::FoldAdaptor adaptor) {
1284 if (op.getType().getBitWidthOrSentinel() == 0)
1286 APInt(0, 0, op.getType().isSignedInteger()));
1289 if (op.getHigh() == op.getLow() && op.getHigh().getType() == op.getType())
1290 return op.getHigh();
1295 if (op.getType().getBitWidthOrSentinel() < 0)
1300 if (cond->isZero() && op.getLow().getType() == op.getType())
1302 if (!cond->isZero() && op.getHigh().getType() == op.getType())
1303 return op.getHigh();
1307 if (
auto lowCst =
getConstant(adaptor.getLow())) {
1309 if (
auto highCst =
getConstant(adaptor.getHigh())) {
1311 if (highCst->getBitWidth() == lowCst->getBitWidth() &&
1312 *highCst == *lowCst)
1315 if (highCst->isOne() && lowCst->isZero() &&
1316 op.getType() == op.getSel().getType())
1282static OpFoldResult
foldMux(OpTy op,
typename OpTy::FoldAdaptor adaptor) {
…}
1329OpFoldResult MuxPrimOp::fold(FoldAdaptor adaptor) {
1330 return foldMux(*
this, adaptor);
1333OpFoldResult Mux2CellIntrinsicOp::fold(FoldAdaptor adaptor) {
1334 return foldMux(*
this, adaptor);
1337OpFoldResult Mux4CellIntrinsicOp::fold(FoldAdaptor adaptor) {
return {}; }
1344class MuxPad :
public mlir::RewritePattern {
1346 MuxPad(MLIRContext *context)
1347 : RewritePattern(MuxPrimOp::getOperationName(), 0, context) {}
1350 matchAndRewrite(Operation *op,
1351 mlir::PatternRewriter &rewriter)
const override {
1352 auto mux = cast<MuxPrimOp>(op);
1353 auto width = mux.getType().getBitWidthOrSentinel();
1357 auto pad = [&](Value input) -> Value {
1359 type_cast<FIRRTLBaseType>(input.getType()).getBitWidthOrSentinel();
1360 if (inputWidth < 0 || width == inputWidth)
1363 .create<PadPrimOp>(mux.getLoc(), mux.getType(), input, width)
1367 auto newHigh = pad(mux.getHigh());
1368 auto newLow = pad(mux.getLow());
1369 if (newHigh == mux.getHigh() && newLow == mux.getLow())
1372 replaceOpWithNewOpAndCopyName<MuxPrimOp>(
1373 rewriter, op, mux.getType(), ValueRange{mux.getSel(), newHigh, newLow},
1381class MuxSharedCond :
public mlir::RewritePattern {
1383 MuxSharedCond(MLIRContext *context)
1384 : RewritePattern(MuxPrimOp::getOperationName(), 0, context) {}
1386 static const int depthLimit = 5;
1388 Value updateOrClone(MuxPrimOp mux, Value high, Value low,
1389 mlir::PatternRewriter &rewriter,
1390 bool updateInPlace)
const {
1391 if (updateInPlace) {
1392 rewriter.modifyOpInPlace(mux, [&] {
1393 mux.setOperand(1, high);
1394 mux.setOperand(2, low);
1398 rewriter.setInsertionPointAfter(mux);
1400 .create<MuxPrimOp>(mux.getLoc(), mux.getType(),
1401 ValueRange{mux.getSel(), high, low})
1406 Value tryCondTrue(Value op, Value cond, mlir::PatternRewriter &rewriter,
1407 bool updateInPlace,
int limit)
const {
1408 MuxPrimOp mux = op.getDefiningOp<MuxPrimOp>();
1411 if (mux.getSel() == cond)
1412 return mux.getHigh();
1413 if (limit > depthLimit)
1415 updateInPlace &= mux->hasOneUse();
1417 if (Value v = tryCondTrue(mux.getHigh(), cond, rewriter, updateInPlace,
1419 return updateOrClone(mux, v, mux.getLow(), rewriter, updateInPlace);
1422 tryCondTrue(mux.getLow(), cond, rewriter, updateInPlace, limit + 1))
1423 return updateOrClone(mux, mux.getHigh(), v, rewriter, updateInPlace);
1428 Value tryCondFalse(Value op, Value cond, mlir::PatternRewriter &rewriter,
1429 bool updateInPlace,
int limit)
const {
1430 MuxPrimOp mux = op.getDefiningOp<MuxPrimOp>();
1433 if (mux.getSel() == cond)
1434 return mux.getLow();
1435 if (limit > depthLimit)
1437 updateInPlace &= mux->hasOneUse();
1439 if (Value v = tryCondFalse(mux.getHigh(), cond, rewriter, updateInPlace,
1441 return updateOrClone(mux, v, mux.getLow(), rewriter, updateInPlace);
1443 if (Value v = tryCondFalse(mux.getLow(), cond, rewriter, updateInPlace,
1445 return updateOrClone(mux, mux.getHigh(), v, rewriter, updateInPlace);
1451 matchAndRewrite(Operation *op,
1452 mlir::PatternRewriter &rewriter)
const override {
1453 auto mux = cast<MuxPrimOp>(op);
1454 auto width = mux.getType().getBitWidthOrSentinel();
1458 if (Value v = tryCondTrue(mux.getHigh(), mux.getSel(), rewriter,
true, 0)) {
1459 rewriter.modifyOpInPlace(mux, [&] { mux.setOperand(1, v); });
1463 if (Value v = tryCondFalse(mux.getLow(), mux.getSel(), rewriter,
true, 0)) {
1464 rewriter.modifyOpInPlace(mux, [&] { mux.setOperand(2, v); });
1473void MuxPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1474 MLIRContext *context) {
1476 .add<MuxPad, MuxSharedCond, patterns::MuxEQOperands,
1477 patterns::MuxEQOperandsSwapped, patterns::MuxNEQ, patterns::MuxNot,
1478 patterns::MuxSameTrue, patterns::MuxSameFalse,
1479 patterns::NarrowMuxLHS, patterns::NarrowMuxRHS, patterns::MuxPadSel>(
1483void Mux2CellIntrinsicOp::getCanonicalizationPatterns(
1484 RewritePatternSet &results, MLIRContext *context) {
1485 results.add<patterns::Mux2PadSel>(context);
1488void Mux4CellIntrinsicOp::getCanonicalizationPatterns(
1489 RewritePatternSet &results, MLIRContext *context) {
1490 results.add<patterns::Mux4PadSel>(context);
1493OpFoldResult PadPrimOp::fold(FoldAdaptor adaptor) {
1494 auto input = this->getInput();
1497 if (input.getType() == getType())
1501 auto inputType = input.getType().base();
1508 auto destWidth = getType().base().getWidthOrSentinel();
1509 if (destWidth == -1)
1512 if (inputType.
isSigned() && cst->getBitWidth())
1513 return getIntAttr(getType(), cst->sext(destWidth));
1514 return getIntAttr(getType(), cst->zext(destWidth));
1520OpFoldResult ShlPrimOp::fold(FoldAdaptor adaptor) {
1521 auto input = this->getInput();
1522 IntType inputType = input.getType();
1523 int shiftAmount = getAmount();
1526 if (shiftAmount == 0)
1532 if (inputWidth != -1) {
1533 auto resultWidth = inputWidth + shiftAmount;
1534 shiftAmount = std::min(shiftAmount, resultWidth);
1535 return getIntAttr(getType(), cst->zext(resultWidth).shl(shiftAmount));
1541OpFoldResult ShrPrimOp::fold(FoldAdaptor adaptor) {
1542 auto input = this->getInput();
1543 IntType inputType = input.getType();
1544 int shiftAmount = getAmount();
1550 if (shiftAmount == 0 && inputWidth > 0)
1553 if (inputWidth == -1)
1555 if (inputWidth == 0)
1560 if (shiftAmount >= inputWidth && inputType.
isUnsigned())
1561 return getIntAttr(getType(), APInt(0, 0,
false));
1567 value = cst->ashr(std::min(shiftAmount, inputWidth - 1));
1569 value = cst->lshr(std::min(shiftAmount, inputWidth));
1570 auto resultWidth = std::max(inputWidth - shiftAmount, 1);
1571 return getIntAttr(getType(), value.trunc(resultWidth));
1576LogicalResult ShrPrimOp::canonicalize(ShrPrimOp op, PatternRewriter &rewriter) {
1577 auto inputWidth = op.getInput().getType().base().getWidthOrSentinel();
1578 if (inputWidth <= 0)
1582 unsigned shiftAmount = op.getAmount();
1583 if (
int(shiftAmount) >= inputWidth) {
1585 if (op.getType().base().isUnsigned())
1591 shiftAmount = inputWidth - 1;
1594 replaceWithBits(op, op.getInput(), inputWidth - 1, shiftAmount, rewriter);
1598LogicalResult HeadPrimOp::canonicalize(HeadPrimOp op,
1599 PatternRewriter &rewriter) {
1600 auto inputWidth = op.getInput().getType().base().getWidthOrSentinel();
1601 if (inputWidth <= 0)
1605 unsigned keepAmount = op.getAmount();
1607 replaceWithBits(op, op.getInput(), inputWidth - 1, inputWidth - keepAmount,
1612OpFoldResult HeadPrimOp::fold(FoldAdaptor adaptor) {
1616 getInput().getType().base().getWidthOrSentinel() - getAmount();
1617 return getIntAttr(getType(), cst->lshr(shiftAmount).trunc(getAmount()));
1623OpFoldResult TailPrimOp::fold(FoldAdaptor adaptor) {
1627 cst->trunc(getType().base().getWidthOrSentinel()));
1631LogicalResult TailPrimOp::canonicalize(TailPrimOp op,
1632 PatternRewriter &rewriter) {
1633 auto inputWidth = op.getInput().getType().base().getWidthOrSentinel();
1634 if (inputWidth <= 0)
1638 unsigned dropAmount = op.getAmount();
1639 if (dropAmount !=
unsigned(inputWidth))
1645void SubaccessOp::getCanonicalizationPatterns(RewritePatternSet &results,
1646 MLIRContext *context) {
1647 results.add<patterns::SubaccessOfConstant>(context);
1650OpFoldResult MultibitMuxOp::fold(FoldAdaptor adaptor) {
1652 if (adaptor.getInputs().size() == 1)
1653 return getOperand(1);
1655 if (
auto constIndex =
getConstant(adaptor.getIndex())) {
1656 auto index = constIndex->getZExtValue();
1657 if (index < getInputs().size())
1658 return getInputs()[getInputs().size() - 1 - index];
1664LogicalResult MultibitMuxOp::canonicalize(MultibitMuxOp op,
1665 PatternRewriter &rewriter) {
1669 if (llvm::all_of(op.getInputs().drop_front(), [&](
auto input) {
1670 return input == op.getInputs().front();
1678 auto indexWidth = op.getIndex().getType().getBitWidthOrSentinel();
1679 uint64_t inputSize = op.getInputs().size();
1680 if (indexWidth >= 0 && indexWidth < 64 && 1ull << indexWidth < inputSize) {
1681 rewriter.modifyOpInPlace(op, [&]() {
1682 op.getInputsMutable().erase(0, inputSize - (1ull << indexWidth));
1689 if (
auto lastSubindex = op.getInputs().back().getDefiningOp<SubindexOp>()) {
1690 if (llvm::all_of(llvm::enumerate(op.getInputs()), [&](
auto e) {
1691 auto subindex = e.value().template getDefiningOp<SubindexOp>();
1692 return subindex && lastSubindex.getInput() == subindex.getInput() &&
1693 subindex.getIndex() + e.index() + 1 == op.getInputs().size();
1695 replaceOpWithNewOpAndCopyName<SubaccessOp>(
1696 rewriter, op, lastSubindex.getInput(), op.getIndex());
1702 if (op.getInputs().size() != 2)
1706 auto uintType = op.getIndex().getType();
1707 if (uintType.getBitWidthOrSentinel() != 1)
1711 replaceOpWithNewOpAndCopyName<MuxPrimOp>(
1712 rewriter, op, op.getIndex(), op.getInputs()[0], op.getInputs()[1]);
1731 MatchingConnectOp connect;
1732 for (Operation *user : value.getUsers()) {
1734 if (isa<AttachOp, SubfieldOp, SubaccessOp, SubindexOp>(user))
1737 if (
auto aConnect = dyn_cast<FConnectLike>(user))
1738 if (aConnect.getDest() == value) {
1739 auto matchingConnect = dyn_cast<MatchingConnectOp>(*aConnect);
1742 if (!matchingConnect || (connect && connect != matchingConnect) ||
1743 matchingConnect->getBlock() != value.getParentBlock())
1746 connect = matchingConnect;
1754 PatternRewriter &rewriter) {
1757 Operation *connectedDecl = op.getDest().getDefiningOp();
1762 if (!isa<WireOp>(connectedDecl) && !isa<RegOp>(connectedDecl))
1766 cast<Forceable>(connectedDecl).isForceable())
1774 if (connectedDecl->hasOneUse())
1778 auto *declBlock = connectedDecl->getBlock();
1779 auto *srcValueOp = op.getSrc().getDefiningOp();
1782 if (!isa<WireOp>(connectedDecl))
1788 if (!isa<ConstantOp>(srcValueOp))
1790 if (srcValueOp->getBlock() != declBlock)
1796 auto replacement = op.getSrc();
1799 if (srcValueOp && srcValueOp != &declBlock->front())
1800 srcValueOp->moveBefore(&declBlock->front());
1807 rewriter.eraseOp(op);
1811void ConnectOp::getCanonicalizationPatterns(RewritePatternSet &results,
1812 MLIRContext *context) {
1813 results.insert<patterns::ConnectExtension, patterns::ConnectSameType>(
1817LogicalResult MatchingConnectOp::canonicalize(MatchingConnectOp op,
1818 PatternRewriter &rewriter) {
1835 for (
auto *user : value.getUsers()) {
1836 auto attach = dyn_cast<AttachOp>(user);
1837 if (!attach || attach == dominatedAttach)
1839 if (attach->isBeforeInBlock(dominatedAttach))
1845LogicalResult AttachOp::canonicalize(AttachOp op, PatternRewriter &rewriter) {
1847 if (op.getNumOperands() <= 1) {
1848 rewriter.eraseOp(op);
1852 for (
auto operand : op.getOperands()) {
1859 SmallVector<Value> newOperands(op.getOperands());
1860 for (
auto newOperand : attach.getOperands())
1861 if (newOperand != operand)
1862 newOperands.push_back(newOperand);
1863 rewriter.create<AttachOp>(op->getLoc(), newOperands);
1864 rewriter.eraseOp(attach);
1865 rewriter.eraseOp(op);
1873 if (
auto wire = dyn_cast_or_null<WireOp>(operand.getDefiningOp())) {
1874 if (!
hasDontTouch(wire.getOperation()) && wire->hasOneUse() &&
1875 !wire.isForceable()) {
1876 SmallVector<Value> newOperands;
1877 for (
auto newOperand : op.getOperands())
1878 if (newOperand != operand)
1879 newOperands.push_back(newOperand);
1881 rewriter.create<AttachOp>(op->getLoc(), newOperands);
1882 rewriter.eraseOp(op);
1883 rewriter.eraseOp(wire);
1894 assert(llvm::hasSingleElement(region) &&
"expected single-region block");
1895 rewriter.inlineBlockBefore(®ion.front(), op, {});
1898LogicalResult WhenOp::canonicalize(WhenOp op, PatternRewriter &rewriter) {
1899 if (
auto constant = op.getCondition().getDefiningOp<firrtl::ConstantOp>()) {
1900 if (constant.getValue().isAllOnes())
1902 else if (op.hasElseRegion() && !op.getElseRegion().empty())
1905 rewriter.eraseOp(op);
1911 if (!op.getThenBlock().empty() && op.hasElseRegion() &&
1912 op.getElseBlock().empty()) {
1913 rewriter.eraseBlock(&op.getElseBlock());
1920 if (!op.getThenBlock().empty())
1924 if (!op.hasElseRegion() || op.getElseBlock().empty()) {
1925 rewriter.eraseOp(op);
1934struct FoldNodeName :
public mlir::RewritePattern {
1935 FoldNodeName(MLIRContext *context)
1936 : RewritePattern(NodeOp::getOperationName(), 0, context) {}
1937 LogicalResult matchAndRewrite(Operation *op,
1938 PatternRewriter &rewriter)
const override {
1939 auto node = cast<NodeOp>(op);
1940 auto name = node.getNameAttr();
1941 if (!node.hasDroppableName() || node.getInnerSym() ||
1944 auto *newOp = node.getInput().getDefiningOp();
1946 if (newOp && !isa<InstanceOp>(newOp))
1948 rewriter.replaceOp(node, node.getInput());
1954struct NodeBypass :
public mlir::RewritePattern {
1955 NodeBypass(MLIRContext *context)
1956 : RewritePattern(NodeOp::getOperationName(), 0, context) {}
1957 LogicalResult matchAndRewrite(Operation *op,
1958 PatternRewriter &rewriter)
const override {
1959 auto node = cast<NodeOp>(op);
1961 node.use_empty() || node.isForceable())
1963 rewriter.replaceAllUsesWith(node.getResult(), node.getInput());
1970template <
typename OpTy>
1972 PatternRewriter &rewriter) {
1973 if (!op.isForceable() || !op.getDataRef().use_empty())
1981LogicalResult NodeOp::fold(FoldAdaptor adaptor,
1982 SmallVectorImpl<OpFoldResult> &results) {
1991 if (!adaptor.getInput())
1994 results.push_back(adaptor.getInput());
1998void NodeOp::getCanonicalizationPatterns(RewritePatternSet &results,
1999 MLIRContext *context) {
2000 results.insert<FoldNodeName>(context);
2001 results.add(demoteForceableIfUnused<NodeOp>);
2007struct AggOneShot :
public mlir::RewritePattern {
2008 AggOneShot(StringRef name, uint32_t weight, MLIRContext *context)
2009 : RewritePattern(name, 0, context) {}
2011 SmallVector<Value> getCompleteWrite(Operation *lhs)
const {
2012 auto lhsTy = lhs->getResult(0).getType();
2013 if (!type_isa<BundleType, FVectorType>(lhsTy))
2016 DenseMap<uint32_t, Value> fields;
2017 for (Operation *user : lhs->getResult(0).getUsers()) {
2018 if (user->getParentOp() != lhs->getParentOp())
2020 if (
auto aConnect = dyn_cast<MatchingConnectOp>(user)) {
2021 if (aConnect.getDest() == lhs->getResult(0))
2023 }
else if (
auto subField = dyn_cast<SubfieldOp>(user)) {
2024 for (Operation *subuser : subField.getResult().getUsers()) {
2025 if (
auto aConnect = dyn_cast<MatchingConnectOp>(subuser)) {
2026 if (aConnect.getDest() == subField) {
2027 if (subuser->getParentOp() != lhs->getParentOp())
2029 if (fields.count(subField.getFieldIndex()))
2031 fields[subField.getFieldIndex()] = aConnect.getSrc();
2037 }
else if (
auto subIndex = dyn_cast<SubindexOp>(user)) {
2038 for (Operation *subuser : subIndex.getResult().getUsers()) {
2039 if (
auto aConnect = dyn_cast<MatchingConnectOp>(subuser)) {
2040 if (aConnect.getDest() == subIndex) {
2041 if (subuser->getParentOp() != lhs->getParentOp())
2043 if (fields.count(subIndex.getIndex()))
2045 fields[subIndex.getIndex()] = aConnect.getSrc();
2056 SmallVector<Value> values;
2057 uint32_t total = type_isa<BundleType>(lhsTy)
2058 ? type_cast<BundleType>(lhsTy).getNumElements()
2059 : type_cast<FVectorType>(lhsTy).getNumElements();
2060 for (uint32_t i = 0; i < total; ++i) {
2061 if (!fields.count(i))
2063 values.push_back(fields[i]);
2068 LogicalResult matchAndRewrite(Operation *op,
2069 PatternRewriter &rewriter)
const override {
2070 auto values = getCompleteWrite(op);
2073 rewriter.setInsertionPointToEnd(op->getBlock());
2074 auto dest = op->getResult(0);
2075 auto destType = dest.getType();
2078 if (!type_cast<FIRRTLBaseType>(destType).isPassive())
2081 Value newVal = type_isa<BundleType>(destType)
2082 ? rewriter.createOrFold<BundleCreateOp>(op->getLoc(),
2084 : rewriter.createOrFold<VectorCreateOp>(
2085 op->
getLoc(), destType, values);
2086 rewriter.createOrFold<MatchingConnectOp>(op->getLoc(), dest, newVal);
2087 for (Operation *user : dest.getUsers()) {
2088 if (
auto subIndex = dyn_cast<SubindexOp>(user)) {
2089 for (Operation *subuser :
2090 llvm::make_early_inc_range(subIndex.getResult().getUsers()))
2091 if (auto aConnect = dyn_cast<MatchingConnectOp>(subuser))
2092 if (aConnect.getDest() == subIndex)
2093 rewriter.eraseOp(aConnect);
2094 }
else if (
auto subField = dyn_cast<SubfieldOp>(user)) {
2095 for (Operation *subuser :
2096 llvm::make_early_inc_range(subField.getResult().getUsers()))
2097 if (auto aConnect = dyn_cast<MatchingConnectOp>(subuser))
2098 if (aConnect.getDest() == subField)
2099 rewriter.eraseOp(aConnect);
2106struct WireAggOneShot :
public AggOneShot {
2107 WireAggOneShot(MLIRContext *context)
2108 : AggOneShot(WireOp::getOperationName(), 0, context) {}
2110struct SubindexAggOneShot :
public AggOneShot {
2111 SubindexAggOneShot(MLIRContext *context)
2112 : AggOneShot(SubindexOp::getOperationName(), 0, context) {}
2114struct SubfieldAggOneShot :
public AggOneShot {
2115 SubfieldAggOneShot(MLIRContext *context)
2116 : AggOneShot(SubfieldOp::getOperationName(), 0, context) {}
2120void WireOp::getCanonicalizationPatterns(RewritePatternSet &results,
2121 MLIRContext *context) {
2122 results.insert<WireAggOneShot>(context);
2123 results.add(demoteForceableIfUnused<WireOp>);
2126void SubindexOp::getCanonicalizationPatterns(RewritePatternSet &results,
2127 MLIRContext *context) {
2128 results.insert<SubindexAggOneShot>(context);
2131OpFoldResult SubindexOp::fold(FoldAdaptor adaptor) {
2132 auto attr = dyn_cast_or_null<ArrayAttr>(adaptor.getInput());
2135 return attr[getIndex()];
2138OpFoldResult SubfieldOp::fold(FoldAdaptor adaptor) {
2139 auto attr = dyn_cast_or_null<ArrayAttr>(adaptor.getInput());
2142 auto index = getFieldIndex();
2146void SubfieldOp::getCanonicalizationPatterns(RewritePatternSet &results,
2147 MLIRContext *context) {
2148 results.insert<SubfieldAggOneShot>(context);
2152 ArrayRef<Attribute> operands) {
2153 for (
auto operand : operands)
2156 return ArrayAttr::get(context, operands);
2159OpFoldResult BundleCreateOp::fold(FoldAdaptor adaptor) {
2162 if (getNumOperands() > 0)
2163 if (SubfieldOp first = getOperand(0).getDefiningOp<SubfieldOp>())
2164 if (first.getFieldIndex() == 0 &&
2165 first.getInput().getType() == getType() &&
2167 llvm::drop_begin(llvm::enumerate(getOperands())), [&](
auto elem) {
2169 elem.value().
template getDefiningOp<SubfieldOp>();
2170 return subindex && subindex.getInput() == first.getInput() &&
2171 subindex.getFieldIndex() == elem.index();
2173 return first.getInput();
2178OpFoldResult VectorCreateOp::fold(FoldAdaptor adaptor) {
2181 if (getNumOperands() > 0)
2182 if (SubindexOp first = getOperand(0).getDefiningOp<SubindexOp>())
2183 if (first.getIndex() == 0 && first.getInput().getType() == getType() &&
2185 llvm::drop_begin(llvm::enumerate(getOperands())), [&](
auto elem) {
2187 elem.value().
template getDefiningOp<SubindexOp>();
2188 return subindex && subindex.getInput() == first.getInput() &&
2189 subindex.getIndex() == elem.index();
2191 return first.getInput();
2196OpFoldResult UninferredResetCastOp::fold(FoldAdaptor adaptor) {
2197 if (getOperand().getType() == getType())
2198 return getOperand();
2205struct FoldResetMux :
public mlir::RewritePattern {
2206 FoldResetMux(MLIRContext *context)
2207 : RewritePattern(RegResetOp::getOperationName(), 0, context) {}
2208 LogicalResult matchAndRewrite(Operation *op,
2209 PatternRewriter &rewriter)
const override {
2210 auto reg = cast<RegResetOp>(op);
2212 dyn_cast_or_null<ConstantOp>(
reg.getResetValue().getDefiningOp());
2221 auto mux = dyn_cast_or_null<MuxPrimOp>(con.getSrc().getDefiningOp());
2224 auto *high = mux.getHigh().getDefiningOp();
2225 auto *low = mux.getLow().getDefiningOp();
2226 auto constOp = dyn_cast_or_null<ConstantOp>(high);
2228 if (constOp && low != reg)
2230 if (dyn_cast_or_null<ConstantOp>(low) && high == reg)
2231 constOp = dyn_cast<ConstantOp>(low);
2233 if (!constOp || constOp.getType() != reset.getType() ||
2234 constOp.getValue() != reset.getValue())
2238 auto regTy =
reg.getResult().getType();
2239 if (con.getDest().getType() != regTy || con.getSrc().getType() != regTy ||
2240 mux.getHigh().getType() != regTy || mux.getLow().getType() != regTy ||
2241 regTy.getBitWidthOrSentinel() < 0)
2247 if (constOp != &con->getBlock()->front())
2248 constOp->moveBefore(&con->getBlock()->front());
2253 rewriter.eraseOp(con);
2260 if (
auto c = v.getDefiningOp<ConstantOp>())
2261 return c.getValue().isOne();
2262 if (
auto sc = v.getDefiningOp<SpecialConstantOp>())
2263 return sc.getValue();
2272 auto resetValue = reg.getResetValue();
2273 if (reg.getType(0) != resetValue.getType())
2277 (void)
dropWrite(rewriter, reg->getResult(0), {});
2278 replaceOpWithNewOpAndCopyName<NodeOp>(
2279 rewriter, reg, resetValue, reg.getNameAttr(), reg.getNameKind(),
2280 reg.getAnnotationsAttr(), reg.getInnerSymAttr(), reg.getForceable());
2284void RegResetOp::getCanonicalizationPatterns(RewritePatternSet &results,
2285 MLIRContext *context) {
2286 results.add<patterns::RegResetWithZeroReset, FoldResetMux>(context);
2288 results.add(demoteForceableIfUnused<RegResetOp>);
2293 auto portTy = type_cast<BundleType>(port.getType());
2294 auto fieldIndex = portTy.getElementIndex(name);
2295 assert(fieldIndex &&
"missing field on memory port");
2298 for (
auto *op : port.getUsers()) {
2299 auto portAccess = cast<SubfieldOp>(op);
2300 if (fieldIndex != portAccess.getFieldIndex())
2305 value = conn.getSrc();
2315 auto portConst = value.getDefiningOp<ConstantOp>();
2318 return portConst.getValue().isZero();
2323 auto portTy = type_cast<BundleType>(port.getType());
2324 auto fieldIndex = portTy.getElementIndex(
data);
2325 assert(fieldIndex &&
"missing enable flag on memory port");
2327 for (
auto *op : port.getUsers()) {
2328 auto portAccess = cast<SubfieldOp>(op);
2329 if (fieldIndex != portAccess.getFieldIndex())
2331 if (!portAccess.use_empty())
2340 StringRef name, Value value) {
2341 auto portTy = type_cast<BundleType>(port.getType());
2342 auto fieldIndex = portTy.getElementIndex(name);
2343 assert(fieldIndex &&
"missing field on memory port");
2345 for (
auto *op : llvm::make_early_inc_range(port.getUsers())) {
2346 auto portAccess = cast<SubfieldOp>(op);
2347 if (fieldIndex != portAccess.getFieldIndex())
2349 rewriter.replaceAllUsesWith(portAccess, value);
2350 rewriter.eraseOp(portAccess);
2355static void erasePort(PatternRewriter &rewriter, Value port) {
2358 auto getClock = [&] {
2360 clock = rewriter.create<SpecialConstantOp>(
2361 port.getLoc(), ClockType::get(rewriter.getContext()),
false);
2369 for (
auto *op : port.getUsers()) {
2370 auto subfield = dyn_cast<SubfieldOp>(op);
2372 auto ty = port.getType();
2373 auto reg = rewriter.create<RegOp>(port.getLoc(), ty, getClock());
2374 rewriter.replaceAllUsesWith(port, reg.getResult());
2383 for (
auto *accessOp : llvm::make_early_inc_range(port.getUsers())) {
2384 auto access = cast<SubfieldOp>(accessOp);
2385 for (
auto *user : llvm::make_early_inc_range(access->getUsers())) {
2386 auto connect = dyn_cast<FConnectLike>(user);
2387 if (connect && connect.getDest() == access) {
2388 rewriter.eraseOp(user);
2392 if (access.use_empty()) {
2393 rewriter.eraseOp(access);
2399 auto ty = access.getType();
2400 auto reg = rewriter.create<RegOp>(access.getLoc(), ty, getClock());
2401 rewriter.replaceOp(access, reg.getResult());
2403 assert(port.use_empty() &&
"port should have no remaining uses");
2408struct FoldZeroWidthMemory :
public mlir::RewritePattern {
2409 FoldZeroWidthMemory(MLIRContext *context)
2410 : RewritePattern(MemOp::getOperationName(), 0, context) {}
2411 LogicalResult matchAndRewrite(Operation *op,
2412 PatternRewriter &rewriter)
const override {
2413 MemOp mem = cast<MemOp>(op);
2417 if (!firrtl::type_isa<IntType>(mem.getDataType()) ||
2418 mem.getDataType().getBitWidthOrSentinel() != 0)
2422 for (
auto port : mem.getResults())
2423 for (auto *user : port.getUsers())
2424 if (!isa<SubfieldOp>(user))
2429 for (
auto port : op->getResults()) {
2430 for (
auto *user :
llvm::make_early_inc_range(port.getUsers())) {
2431 SubfieldOp sfop = cast<SubfieldOp>(user);
2432 StringRef fieldName = sfop.getFieldName();
2433 auto wire = replaceOpWithNewOpAndCopyName<WireOp>(
2434 rewriter, sfop, sfop.getResult().getType())
2436 if (fieldName.ends_with(
"data")) {
2438 auto zero = rewriter.create<firrtl::ConstantOp>(
2439 wire.getLoc(), firrtl::type_cast<IntType>(wire.getType()),
2441 rewriter.create<MatchingConnectOp>(wire.getLoc(), wire, zero);
2445 rewriter.eraseOp(op);
2451struct FoldReadOrWriteOnlyMemory :
public mlir::RewritePattern {
2452 FoldReadOrWriteOnlyMemory(MLIRContext *context)
2453 : RewritePattern(MemOp::getOperationName(), 0, context) {}
2454 LogicalResult matchAndRewrite(Operation *op,
2455 PatternRewriter &rewriter)
const override {
2456 MemOp mem = cast<MemOp>(op);
2459 bool isRead =
false, isWritten =
false;
2460 for (
unsigned i = 0; i < mem.getNumResults(); ++i) {
2461 switch (mem.getPortKind(i)) {
2462 case MemOp::PortKind::Read:
2467 case MemOp::PortKind::Write:
2472 case MemOp::PortKind::Debug:
2473 case MemOp::PortKind::ReadWrite:
2476 llvm_unreachable(
"unknown port kind");
2478 assert((!isWritten || !isRead) &&
"memory is in use");
2483 if (isRead && mem.getInit())
2486 for (
auto port : mem.getResults())
2489 rewriter.eraseOp(op);
2495struct FoldUnusedPorts :
public mlir::RewritePattern {
2496 FoldUnusedPorts(MLIRContext *context)
2497 : RewritePattern(MemOp::getOperationName(), 0, context) {}
2498 LogicalResult matchAndRewrite(Operation *op,
2499 PatternRewriter &rewriter)
const override {
2500 MemOp mem = cast<MemOp>(op);
2504 llvm::SmallBitVector deadPorts(mem.getNumResults());
2505 for (
auto [i, port] :
llvm::enumerate(mem.getResults())) {
2507 if (!mem.getPortAnnotation(i).empty())
2511 auto kind = mem.getPortKind(i);
2512 if (kind == MemOp::PortKind::Debug)
2521 if (kind == MemOp::PortKind::Read &&
isPortUnused(port,
"data")) {
2526 if (deadPorts.none())
2530 SmallVector<Type> resultTypes;
2531 SmallVector<StringRef> portNames;
2532 SmallVector<Attribute> portAnnotations;
2533 for (
auto [i, port] :
llvm::enumerate(mem.getResults())) {
2536 resultTypes.push_back(port.getType());
2537 portNames.push_back(mem.getPortName(i));
2538 portAnnotations.push_back(mem.getPortAnnotation(i));
2542 if (!resultTypes.empty())
2543 newOp = rewriter.create<MemOp>(
2544 mem.getLoc(), resultTypes, mem.getReadLatency(),
2545 mem.getWriteLatency(), mem.getDepth(), mem.getRuw(),
2546 rewriter.getStrArrayAttr(portNames), mem.getName(), mem.getNameKind(),
2547 mem.getAnnotations(), rewriter.getArrayAttr(portAnnotations),
2548 mem.getInnerSymAttr(), mem.getInitAttr(), mem.getPrefixAttr());
2551 unsigned nextPort = 0;
2552 for (
auto [i, port] :
llvm::enumerate(mem.getResults())) {
2556 rewriter.replaceAllUsesWith(port, newOp.getResult(nextPort++));
2559 rewriter.eraseOp(op);
2565struct FoldReadWritePorts :
public mlir::RewritePattern {
2566 FoldReadWritePorts(MLIRContext *context)
2567 : RewritePattern(MemOp::getOperationName(), 0, context) {}
2568 LogicalResult matchAndRewrite(Operation *op,
2569 PatternRewriter &rewriter)
const override {
2570 MemOp mem = cast<MemOp>(op);
2575 llvm::SmallBitVector deadReads(mem.getNumResults());
2576 for (
auto [i, port] :
llvm::enumerate(mem.getResults())) {
2577 if (mem.getPortKind(i) != MemOp::PortKind::ReadWrite)
2579 if (!mem.getPortAnnotation(i).empty())
2586 if (deadReads.none())
2589 SmallVector<Type> resultTypes;
2590 SmallVector<StringRef> portNames;
2591 SmallVector<Attribute> portAnnotations;
2592 for (
auto [i, port] :
llvm::enumerate(mem.getResults())) {
2594 resultTypes.push_back(
2595 MemOp::getTypeForPort(mem.getDepth(), mem.getDataType(),
2596 MemOp::PortKind::Write, mem.getMaskBits()));
2598 resultTypes.push_back(port.getType());
2600 portNames.push_back(mem.getPortName(i));
2601 portAnnotations.push_back(mem.getPortAnnotation(i));
2604 auto newOp = rewriter.create<MemOp>(
2605 mem.getLoc(), resultTypes, mem.getReadLatency(), mem.getWriteLatency(),
2606 mem.getDepth(), mem.getRuw(), rewriter.getStrArrayAttr(portNames),
2607 mem.getName(), mem.getNameKind(), mem.getAnnotations(),
2608 rewriter.getArrayAttr(portAnnotations), mem.getInnerSymAttr(),
2609 mem.getInitAttr(), mem.getPrefixAttr());
2611 for (
unsigned i = 0, n = mem.getNumResults(); i < n; ++i) {
2612 auto result = mem.getResult(i);
2613 auto newResult = newOp.getResult(i);
2615 auto resultPortTy = type_cast<BundleType>(result.getType());
2619 auto replace = [&](StringRef toName, StringRef fromName) {
2620 auto fromFieldIndex = resultPortTy.getElementIndex(fromName);
2621 assert(fromFieldIndex &&
"missing enable flag on memory port");
2623 auto toField = rewriter.create<SubfieldOp>(newResult.getLoc(),
2625 for (
auto *op :
llvm::make_early_inc_range(result.getUsers())) {
2626 auto fromField = cast<SubfieldOp>(op);
2627 if (fromFieldIndex != fromField.getFieldIndex())
2629 rewriter.replaceOp(fromField, toField.getResult());
2633 replace(
"addr",
"addr");
2634 replace(
"en",
"en");
2635 replace(
"clk",
"clk");
2636 replace(
"data",
"wdata");
2637 replace(
"mask",
"wmask");
2640 auto wmodeFieldIndex = resultPortTy.getElementIndex(
"wmode");
2641 for (
auto *op :
llvm::make_early_inc_range(result.getUsers())) {
2642 auto wmodeField = cast<SubfieldOp>(op);
2643 if (wmodeFieldIndex != wmodeField.getFieldIndex())
2645 rewriter.replaceOpWithNewOp<WireOp>(wmodeField, wmodeField.getType());
2648 rewriter.replaceAllUsesWith(result, newResult);
2651 rewriter.eraseOp(op);
2657struct FoldUnusedBits :
public mlir::RewritePattern {
2658 FoldUnusedBits(MLIRContext *context)
2659 : RewritePattern(MemOp::getOperationName(), 0, context) {}
2661 LogicalResult matchAndRewrite(Operation *op,
2662 PatternRewriter &rewriter)
const override {
2663 MemOp mem = cast<MemOp>(op);
2668 const auto &summary = mem.getSummary();
2669 if (summary.isMasked || summary.isSeqMem())
2672 auto type = type_dyn_cast<IntType>(mem.getDataType());
2675 auto width = type.getBitWidthOrSentinel();
2679 llvm::SmallBitVector usedBits(width);
2680 DenseMap<unsigned, unsigned> mapping;
2685 SmallVector<BitsPrimOp> readOps;
2686 auto findReadUsers = [&](Value port, StringRef field) -> LogicalResult {
2687 auto portTy = type_cast<BundleType>(port.getType());
2688 auto fieldIndex = portTy.getElementIndex(field);
2689 assert(fieldIndex &&
"missing data port");
2691 for (
auto *op : port.getUsers()) {
2692 auto portAccess = cast<SubfieldOp>(op);
2693 if (fieldIndex != portAccess.getFieldIndex())
2696 for (
auto *user : op->getUsers()) {
2697 auto bits = dyn_cast<BitsPrimOp>(user);
2701 usedBits.set(bits.getLo(), bits.getHi() + 1);
2705 mapping[bits.getLo()] = 0;
2706 readOps.push_back(bits);
2716 SmallVector<MatchingConnectOp> writeOps;
2717 auto findWriteUsers = [&](Value port, StringRef field) -> LogicalResult {
2718 auto portTy = type_cast<BundleType>(port.getType());
2719 auto fieldIndex = portTy.getElementIndex(field);
2720 assert(fieldIndex &&
"missing data port");
2722 for (
auto *op : port.getUsers()) {
2723 auto portAccess = cast<SubfieldOp>(op);
2724 if (fieldIndex != portAccess.getFieldIndex())
2731 writeOps.push_back(conn);
2737 for (
auto [i, port] :
llvm::enumerate(mem.getResults())) {
2739 if (!mem.getPortAnnotation(i).empty())
2742 switch (mem.getPortKind(i)) {
2743 case MemOp::PortKind::Debug:
2746 case MemOp::PortKind::Write:
2747 if (failed(findWriteUsers(port,
"data")))
2750 case MemOp::PortKind::Read:
2751 if (failed(findReadUsers(port,
"data")))
2754 case MemOp::PortKind::ReadWrite:
2755 if (failed(findWriteUsers(port,
"wdata")))
2757 if (failed(findReadUsers(port,
"rdata")))
2761 llvm_unreachable(
"unknown port kind");
2765 if (usedBits.none())
2769 SmallVector<std::pair<unsigned, unsigned>> ranges;
2770 unsigned newWidth = 0;
2771 for (
int i = usedBits.find_first(); 0 <= i && i < width;) {
2772 int e = usedBits.find_next_unset(i);
2775 for (
int idx = i; idx < e; ++idx, ++newWidth) {
2776 if (
auto it = mapping.find(idx); it != mapping.end()) {
2777 it->second = newWidth;
2780 ranges.emplace_back(i, e - 1);
2781 i = e != width ? usedBits.find_next(e) : e;
2785 auto newType =
IntType::get(op->getContext(), type.isSigned(), newWidth);
2786 SmallVector<Type> portTypes;
2787 for (
auto [i, port] :
llvm::enumerate(mem.getResults())) {
2788 portTypes.push_back(
2789 MemOp::getTypeForPort(mem.getDepth(), newType, mem.getPortKind(i)));
2791 auto newMem = rewriter.replaceOpWithNewOp<MemOp>(
2792 mem, portTypes, mem.getReadLatency(), mem.getWriteLatency(),
2793 mem.getDepth(), mem.getRuw(), mem.getPortNames(), mem.getName(),
2794 mem.getNameKind(), mem.getAnnotations(), mem.getPortAnnotations(),
2795 mem.getInnerSymAttr(), mem.getInitAttr(), mem.getPrefixAttr());
2798 auto rewriteSubfield = [&](Value port, StringRef field) {
2799 auto portTy = type_cast<BundleType>(port.getType());
2800 auto fieldIndex = portTy.getElementIndex(field);
2801 assert(fieldIndex &&
"missing data port");
2803 rewriter.setInsertionPointAfter(newMem);
2804 auto newPortAccess =
2805 rewriter.create<SubfieldOp>(port.getLoc(), port, field);
2807 for (
auto *op :
llvm::make_early_inc_range(port.getUsers())) {
2808 auto portAccess = cast<SubfieldOp>(op);
2809 if (op == newPortAccess || fieldIndex != portAccess.getFieldIndex())
2811 rewriter.replaceOp(portAccess, newPortAccess.getResult());
2816 for (
auto [i, port] :
llvm::enumerate(newMem.getResults())) {
2817 switch (newMem.getPortKind(i)) {
2818 case MemOp::PortKind::Debug:
2819 llvm_unreachable(
"cannot rewrite debug port");
2820 case MemOp::PortKind::Write:
2821 rewriteSubfield(port,
"data");
2823 case MemOp::PortKind::Read:
2824 rewriteSubfield(port,
"data");
2826 case MemOp::PortKind::ReadWrite:
2827 rewriteSubfield(port,
"rdata");
2828 rewriteSubfield(port,
"wdata");
2831 llvm_unreachable(
"unknown port kind");
2835 for (
auto readOp : readOps) {
2836 rewriter.setInsertionPointAfter(readOp);
2837 auto it = mapping.find(readOp.getLo());
2838 assert(it != mapping.end() &&
"bit op mapping not found");
2841 auto newReadValue = rewriter.createOrFold<BitsPrimOp>(
2842 readOp.getLoc(), readOp.getInput(),
2843 readOp.getHi() - readOp.getLo() + it->second, it->second);
2844 rewriter.replaceAllUsesWith(readOp, newReadValue);
2845 rewriter.eraseOp(readOp);
2849 for (
auto writeOp : writeOps) {
2850 Value source = writeOp.getSrc();
2851 rewriter.setInsertionPoint(writeOp);
2854 for (
auto &[start, end] : ranges) {
2855 Value slice = rewriter.createOrFold<BitsPrimOp>(writeOp.getLoc(),
2856 source,
end, start);
2858 catOfSlices = rewriter.createOrFold<CatPrimOp>(writeOp.getLoc(),
2859 slice, catOfSlices);
2861 catOfSlices = slice;
2869 if (type.isSigned())
2871 rewriter.createOrFold<AsSIntPrimOp>(writeOp.getLoc(), catOfSlices);
2873 rewriter.replaceOpWithNewOp<MatchingConnectOp>(writeOp, writeOp.getDest(),
2882struct FoldRegMems :
public mlir::RewritePattern {
2883 FoldRegMems(MLIRContext *context)
2884 : RewritePattern(MemOp::getOperationName(), 0, context) {}
2885 LogicalResult matchAndRewrite(Operation *op,
2886 PatternRewriter &rewriter)
const override {
2887 MemOp mem = cast<MemOp>(op);
2892 auto ty = mem.getDataType();
2893 auto loc = mem.getLoc();
2894 auto *block = mem->getBlock();
2898 SmallPtrSet<Operation *, 8> connects;
2899 SmallVector<SubfieldOp> portAccesses;
2900 for (
auto [i, port] :
llvm::enumerate(mem.getResults())) {
2901 if (!mem.getPortAnnotation(i).empty())
2904 auto collect = [&, port = port](ArrayRef<StringRef> fields) {
2905 auto portTy = type_cast<BundleType>(port.getType());
2906 for (
auto field : fields) {
2907 auto fieldIndex = portTy.getElementIndex(field);
2908 assert(fieldIndex &&
"missing field on memory port");
2910 for (
auto *op : port.getUsers()) {
2911 auto portAccess = cast<SubfieldOp>(op);
2912 if (fieldIndex != portAccess.getFieldIndex())
2914 portAccesses.push_back(portAccess);
2915 for (
auto *user : portAccess->getUsers()) {
2916 auto conn = dyn_cast<FConnectLike>(user);
2919 connects.insert(conn);
2926 switch (mem.getPortKind(i)) {
2927 case MemOp::PortKind::Debug:
2929 case MemOp::PortKind::Read:
2930 if (failed(collect({
"clk",
"en",
"addr"})))
2933 case MemOp::PortKind::Write:
2934 if (failed(collect({
"clk",
"en",
"addr",
"data",
"mask"})))
2937 case MemOp::PortKind::ReadWrite:
2938 if (failed(collect({
"clk",
"en",
"addr",
"wmode",
"wdata",
"wmask"})))
2944 if (!portClock || (clock && portClock != clock))
2950 rewriter.setInsertionPointAfter(mem);
2951 auto memWire = rewriter.create<WireOp>(loc, ty).getResult();
2957 rewriter.setInsertionPointToEnd(block);
2959 rewriter.create<RegOp>(loc, ty, clock, mem.getName()).getResult();
2962 rewriter.create<MatchingConnectOp>(loc, memWire, memReg);
2966 auto pipeline = [&](Value value, Value clock,
const Twine &name,
2968 for (
unsigned i = 0; i < latency; ++i) {
2969 std::string regName;
2971 llvm::raw_string_ostream os(regName);
2972 os << mem.getName() <<
"_" << name <<
"_" << i;
2975 .create<RegOp>(mem.getLoc(), value.getType(), clock,
2976 rewriter.getStringAttr(regName))
2978 rewriter.create<MatchingConnectOp>(value.getLoc(),
reg, value);
2984 const unsigned writeStages =
info.writeLatency - 1;
2989 SmallVector<std::tuple<Value, Value, Value>> writes;
2990 for (
auto [i, port] :
llvm::enumerate(mem.getResults())) {
2992 StringRef name = mem.getPortName(i);
2994 auto portPipeline = [&, port = port](StringRef field,
unsigned stages) {
2997 return pipeline(value, portClock, name +
"_" + field, stages);
3000 switch (mem.getPortKind(i)) {
3001 case MemOp::PortKind::Debug:
3002 llvm_unreachable(
"unknown port kind");
3003 case MemOp::PortKind::Read: {
3011 case MemOp::PortKind::Write: {
3012 auto data = portPipeline(
"data", writeStages);
3013 auto en = portPipeline(
"en", writeStages);
3014 auto mask = portPipeline(
"mask", writeStages);
3018 case MemOp::PortKind::ReadWrite: {
3023 auto wdata = portPipeline(
"wdata", writeStages);
3024 auto wmask = portPipeline(
"wmask", writeStages);
3029 auto wen = rewriter.create<AndPrimOp>(port.getLoc(),
en,
wmode);
3031 pipeline(wen, portClock, name +
"_wen", writeStages);
3032 writes.emplace_back(
wdata, wenPipelined,
wmask);
3039 Value next = memReg;
3045 Location loc = mem.getLoc();
3046 unsigned maskGran =
info.dataWidth /
info.maskBits;
3047 for (
unsigned i = 0; i <
info.maskBits; ++i) {
3048 unsigned hi = (i + 1) * maskGran - 1;
3049 unsigned lo = i * maskGran;
3051 auto dataPart = rewriter.createOrFold<BitsPrimOp>(loc,
data, hi, lo);
3052 auto nextPart = rewriter.createOrFold<BitsPrimOp>(loc, next, hi, lo);
3053 auto bit = rewriter.createOrFold<BitsPrimOp>(loc,
mask, i, i);
3054 auto chunk = rewriter.create<MuxPrimOp>(loc, bit, dataPart, nextPart);
3057 masked = rewriter.create<CatPrimOp>(loc, chunk, masked);
3063 next = rewriter.create<MuxPrimOp>(next.getLoc(),
en, masked, next);
3065 Value typedNext = rewriter.createOrFold<BitCastOp>(next.getLoc(), ty, next);
3066 rewriter.create<MatchingConnectOp>(memReg.getLoc(), memReg, typedNext);
3069 for (Operation *conn : connects)
3070 rewriter.eraseOp(conn);
3071 for (
auto portAccess : portAccesses)
3072 rewriter.eraseOp(portAccess);
3073 rewriter.eraseOp(mem);
3080void MemOp::getCanonicalizationPatterns(RewritePatternSet &results,
3081 MLIRContext *context) {
3083 .insert<FoldZeroWidthMemory, FoldReadOrWriteOnlyMemory,
3084 FoldReadWritePorts, FoldUnusedPorts, FoldUnusedBits, FoldRegMems>(
3104 auto mux = dyn_cast_or_null<MuxPrimOp>(con.getSrc().getDefiningOp());
3107 auto *high = mux.getHigh().getDefiningOp();
3108 auto *low = mux.getLow().getDefiningOp();
3110 auto constOp = dyn_cast_or_null<ConstantOp>(high);
3117 bool constReg =
false;
3119 if (constOp && low == reg)
3121 else if (dyn_cast_or_null<ConstantOp>(low) && high == reg) {
3123 constOp = dyn_cast<ConstantOp>(low);
3130 if (!isa<BlockArgument>(mux.getSel()) && !constReg)
3134 auto regTy = reg.getResult().getType();
3135 if (con.getDest().getType() != regTy || con.getSrc().getType() != regTy ||
3136 mux.getHigh().getType() != regTy || mux.getLow().getType() != regTy ||
3137 regTy.getBitWidthOrSentinel() < 0)
3143 if (constOp != &con->getBlock()->front())
3144 constOp->moveBefore(&con->getBlock()->front());
3147 SmallVector<NamedAttribute, 2> attrs(reg->getDialectAttrs());
3148 auto newReg = replaceOpWithNewOpAndCopyName<RegResetOp>(
3149 rewriter, reg, reg.getResult().getType(), reg.getClockVal(),
3150 mux.getSel(), mux.getHigh(), reg.getNameAttr(), reg.getNameKindAttr(),
3151 reg.getAnnotationsAttr(), reg.getInnerSymAttr(),
3152 reg.getForceableAttr());
3153 newReg->setDialectAttrs(attrs);
3155 auto pt = rewriter.saveInsertionPoint();
3156 rewriter.setInsertionPoint(con);
3157 auto v = constReg ? (Value)constOp.getResult() : (Value)mux.getLow();
3158 replaceOpWithNewOpAndCopyName<ConnectOp>(rewriter, con, con.getDest(), v);
3159 rewriter.restoreInsertionPoint(pt);
3163LogicalResult RegOp::canonicalize(RegOp op, PatternRewriter &rewriter) {
3164 if (!
hasDontTouch(op.getOperation()) && !op.isForceable() &&
3180 PatternRewriter &rewriter,
3183 if (
auto constant = enable.getDefiningOp<firrtl::ConstantOp>()) {
3184 if (constant.getValue().isZero()) {
3185 rewriter.eraseOp(op);
3191 if (
auto constant = predicate.getDefiningOp<firrtl::ConstantOp>()) {
3192 if (constant.getValue().isZero() == eraseIfZero) {
3193 rewriter.eraseOp(op);
3201template <
class Op,
bool EraseIfZero = false>
3203 PatternRewriter &rewriter) {
3208void AssertOp::getCanonicalizationPatterns(RewritePatternSet &results,
3209 MLIRContext *context) {
3210 results.add(canonicalizeImmediateVerifOp<AssertOp>);
3211 results.add<patterns::AssertXWhenX>(context);
3214void AssumeOp::getCanonicalizationPatterns(RewritePatternSet &results,
3215 MLIRContext *context) {
3216 results.add(canonicalizeImmediateVerifOp<AssumeOp>);
3217 results.add<patterns::AssumeXWhenX>(context);
3220void UnclockedAssumeIntrinsicOp::getCanonicalizationPatterns(
3221 RewritePatternSet &results, MLIRContext *context) {
3222 results.add(canonicalizeImmediateVerifOp<UnclockedAssumeIntrinsicOp>);
3223 results.add<patterns::UnclockedAssumeIntrinsicXWhenX>(context);
3226void CoverOp::getCanonicalizationPatterns(RewritePatternSet &results,
3227 MLIRContext *context) {
3228 results.add(canonicalizeImmediateVerifOp<CoverOp, /* EraseIfZero = */ true>);
3235LogicalResult InvalidValueOp::canonicalize(InvalidValueOp op,
3236 PatternRewriter &rewriter) {
3238 if (op.use_empty()) {
3239 rewriter.eraseOp(op);
3246 if (op->hasOneUse() &&
3247 (isa<BitsPrimOp, HeadPrimOp, ShrPrimOp, TailPrimOp, SubfieldOp,
3248 SubindexOp, AsSIntPrimOp, AsUIntPrimOp, NotPrimOp, BitCastOp>(
3249 *op->user_begin()) ||
3250 (isa<CvtPrimOp>(*op->user_begin()) &&
3251 type_isa<SIntType>(op->user_begin()->getOperand(0).getType())) ||
3252 (isa<AndRPrimOp, XorRPrimOp, OrRPrimOp>(*op->user_begin()) &&
3253 type_cast<FIRRTLBaseType>(op->user_begin()->getOperand(0).getType())
3254 .getBitWidthOrSentinel() > 0))) {
3255 auto *modop = *op->user_begin();
3256 auto inv = rewriter.create<InvalidValueOp>(op.getLoc(),
3257 modop->getResult(0).getType());
3258 rewriter.replaceAllOpUsesWith(modop, inv);
3259 rewriter.eraseOp(modop);
3260 rewriter.eraseOp(op);
3266OpFoldResult InvalidValueOp::fold(FoldAdaptor adaptor) {
3267 if (getType().getBitWidthOrSentinel() == 0 && isa<IntType>(getType()))
3268 return getIntAttr(getType(), APInt(0, 0, isa<SIntType>(getType())));
3276OpFoldResult ClockGateIntrinsicOp::fold(FoldAdaptor adaptor) {
3285 return BoolAttr::get(getContext(),
false);
3289 return BoolAttr::get(getContext(),
false);
3294LogicalResult ClockGateIntrinsicOp::canonicalize(ClockGateIntrinsicOp op,
3295 PatternRewriter &rewriter) {
3297 if (
auto testEnable = op.getTestEnable()) {
3298 if (
auto constOp = testEnable.getDefiningOp<ConstantOp>()) {
3299 if (constOp.getValue().isZero()) {
3300 rewriter.modifyOpInPlace(op,
3301 [&] { op.getTestEnableMutable().clear(); });
3317 auto forceable = op.getRef().getDefiningOp<Forceable>();
3318 if (!forceable || !forceable.isForceable() ||
3319 op.getRef() != forceable.getDataRef() ||
3320 op.getType() != forceable.getDataType())
3322 rewriter.replaceAllUsesWith(op, forceable.getData());
3326void RefResolveOp::getCanonicalizationPatterns(RewritePatternSet &results,
3327 MLIRContext *context) {
3328 results.insert<patterns::RefResolveOfRefSend>(context);
3332OpFoldResult RefCastOp::fold(FoldAdaptor adaptor) {
3334 if (getInput().getType() == getType())
3340 auto constOp = operand.getDefiningOp<ConstantOp>();
3341 return constOp && constOp.getValue().isZero();
3344template <
typename Op>
3347 rewriter.eraseOp(op);
3353void RefForceOp::getCanonicalizationPatterns(RewritePatternSet &results,
3354 MLIRContext *context) {
3355 results.add(eraseIfPredFalse<RefForceOp>);
3357void RefForceInitialOp::getCanonicalizationPatterns(RewritePatternSet &results,
3358 MLIRContext *context) {
3359 results.add(eraseIfPredFalse<RefForceInitialOp>);
3361void RefReleaseOp::getCanonicalizationPatterns(RewritePatternSet &results,
3362 MLIRContext *context) {
3363 results.add(eraseIfPredFalse<RefReleaseOp>);
3365void RefReleaseInitialOp::getCanonicalizationPatterns(
3366 RewritePatternSet &results, MLIRContext *context) {
3367 results.add(eraseIfPredFalse<RefReleaseInitialOp>);
3374OpFoldResult HasBeenResetIntrinsicOp::fold(FoldAdaptor adaptor) {
3380 if (adaptor.getReset())
3385 if (
isUInt1(getReset().getType()) && adaptor.getClock())
3398 [&](
auto ty) ->
bool {
return isTypeEmpty(ty.getElementType()); })
3399 .Case<BundleType>([&](
auto ty) ->
bool {
3400 for (
auto elem : ty.getElements())
3405 .Case<IntType>([&](
auto ty) {
return ty.getWidth() == 0; })
3406 .Default([](
auto) ->
bool {
return false; });
3409LogicalResult FPGAProbeIntrinsicOp::canonicalize(FPGAProbeIntrinsicOp op,
3410 PatternRewriter &rewriter) {
3411 auto firrtlTy = type_dyn_cast<FIRRTLType>(op.getInput().getType());
3418 rewriter.eraseOp(op);
3426LogicalResult LayerBlockOp::canonicalize(LayerBlockOp op,
3427 PatternRewriter &rewriter) {
3430 if (op.getBody()->empty()) {
3431 rewriter.eraseOp(op);
assert(baseType &&"element must be base type")
static bool hasKnownWidthIntTypes(Operation *op)
Return true if this operation's operands and results all have a known width.
static LogicalResult canonicalizeImmediateVerifOp(Op op, PatternRewriter &rewriter)
static bool isDefinedByOneConstantOp(Value v)
static Attribute collectFields(MLIRContext *context, ArrayRef< Attribute > operands)
static LogicalResult canonicalizeSingleSetConnect(MatchingConnectOp op, PatternRewriter &rewriter)
static void erasePort(PatternRewriter &rewriter, Value port)
static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op, Region ®ion)
Replaces the given op with the contents of the given single-block region.
static std::optional< APSInt > getExtendedConstant(Value operand, Attribute constant, int32_t destWidth)
Implicitly replace the operand to a constant folding operation with a const 0 in case the operand is ...
static Value getPortFieldValue(Value port, StringRef name)
static AttachOp getDominatingAttachUser(Value value, AttachOp dominatedAttach)
If the specified value has an AttachOp user strictly dominating by "dominatingAttach" then return it.
static OpTy replaceOpWithNewOpAndCopyName(PatternRewriter &rewriter, Operation *op, Args &&...args)
A wrapper of PatternRewriter::replaceOpWithNewOp to propagate "name" attribute.
static void updateName(PatternRewriter &rewriter, Operation *op, StringAttr name)
Set the name of an op based on the best of two names: The current name, and the name passed in.
static bool isTypeEmpty(FIRRTLType type)
static bool isUInt1(Type type)
Return true if this value is 1 bit UInt.
static LogicalResult demoteForceableIfUnused(OpTy op, PatternRewriter &rewriter)
static bool isPortDisabled(Value port)
static LogicalResult eraseIfZeroOrNotZero(Operation *op, Value predicate, Value enable, PatternRewriter &rewriter, bool eraseIfZero)
static APInt getMaxSignedValue(unsigned bitWidth)
Get the largest signed value of a given bit width.
static Value dropWrite(PatternRewriter &rewriter, OpResult old, Value passthrough)
static LogicalResult canonicalizePrimOp(Operation *op, PatternRewriter &rewriter, const function_ref< OpFoldResult(ArrayRef< Attribute >)> &canonicalize)
Applies the canonicalization function canonicalize to the given operation.
static void replaceWithBits(Operation *op, Value value, unsigned hiBit, unsigned loBit, PatternRewriter &rewriter)
Replace the specified operation with a 'bits' op from the specified hi/lo bits.
static LogicalResult canonicalizeRegResetWithOneReset(RegResetOp reg, PatternRewriter &rewriter)
static LogicalResult eraseIfPredFalse(Op op, PatternRewriter &rewriter)
static OpFoldResult foldMux(OpTy op, typename OpTy::FoldAdaptor adaptor)
static APInt getMaxUnsignedValue(unsigned bitWidth)
Get the largest unsigned value of a given bit width.
static std::optional< APSInt > getConstant(Attribute operand)
Determine the value of a constant operand for the sake of constant folding.
static void replacePortField(PatternRewriter &rewriter, Value port, StringRef name, Value value)
BinOpKind
This is the policy for folding, which depends on the sort of operator we're processing.
static bool isPortUnused(Value port, StringRef data)
static LogicalResult canonicalizeRefResolveOfForceable(RefResolveOp op, PatternRewriter &rewriter)
static Attribute constFoldFIRRTLBinaryOp(Operation *op, ArrayRef< Attribute > operands, BinOpKind opKind, const function_ref< APInt(const APSInt &, const APSInt &)> &calculate)
Applies the constant folding function calculate to the given operands.
static APInt getMinSignedValue(unsigned bitWidth)
Get the smallest signed value of a given bit width.
static LogicalResult foldHiddenReset(RegOp reg, PatternRewriter &rewriter)
static Value moveNameHint(OpResult old, Value passthrough)
static void replaceOpAndCopyName(PatternRewriter &rewriter, Operation *op, Value newValue)
A wrapper of PatternRewriter::replaceOp to propagate "name" attribute.
static InstancePath empty
static Location getLoc(DefSlot slot)
This class provides a read-only projection over the MLIR attributes that represent a set of annotatio...
int32_t getBitWidthOrSentinel()
If this is an IntType, AnalogType, or sugar type for a single bit (Clock, Reset, etc) then return the...
This class implements the same functionality as TypeSwitch except that it uses firrtl::type_dyn_cast ...
FIRRTLTypeSwitch< T, ResultT > & Case(CallableT &&caseFn)
Add a case on the given type.
This is the common base class between SIntType and UIntType.
int32_t getWidthOrSentinel() const
Return the width of this type, or -1 if it has none specified.
static IntType get(MLIRContext *context, bool isSigned, int32_t widthOrSentinel=-1, bool isConst=false)
Return an SIntType or UIntType with the specified signedness, width, and constness.
bool hasWidth() const
Return true if this integer type has a known width.
std::optional< int32_t > getWidth() const
Return an optional containing the width, if the width is known (or empty if width is unknown).
uint64_t getWidth(Type t)
Forceable replaceWithNewForceability(Forceable op, bool forceable, ::mlir::PatternRewriter *rewriter=nullptr)
Replace a Forceable op with equivalent, changing whether forceable.
bool areAnonymousTypesEquivalent(FIRRTLBaseType lhs, FIRRTLBaseType rhs)
Return true if anonymous types of given arguments are equivalent by pointer comparison.
IntegerAttr getIntAttr(Type type, const APInt &value)
Utiility for generating a constant attribute.
bool hasDontTouch(Value value)
Check whether a block argument ("port") or the operation defining a value has a DontTouch annotation,...
bool hasDroppableName(Operation *op)
Return true if the name is droppable.
MatchingConnectOp getSingleConnectUserOf(Value value)
Scan all the uses of the specified value, checking to see if there is exactly one connect that has th...
std::optional< int64_t > getBitWidth(FIRRTLBaseType type, bool ignoreFlip=false)
IntegerAttr getIntZerosAttr(Type type)
Utility for generating a constant zero attribute.
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
APSInt extOrTruncZeroWidth(APSInt value, unsigned width)
A safe version of APSInt::extOrTrunc that will NOT assert on zero-width signed APSInts.
APInt sextZeroWidth(APInt value, unsigned width)
A safe version of APInt::sext that will NOT assert on zero-width signed APSInts.
StringRef chooseName(StringRef a, StringRef b)
Choose a good name for an item from two options.
static bool isConstantZero(Attribute operand)
Determine whether a constant operand is a zero value.
static bool isConstantOne(Attribute operand)
Determine whether a constant operand is a one value.
reg(value, clock, reset=None, reset_value=None, name=None, sym_name=None)