20#include "mlir/IR/Matchers.h"
21#include "mlir/IR/PatternMatch.h"
22#include "llvm/ADT/APSInt.h"
23#include "llvm/ADT/SmallPtrSet.h"
24#include "llvm/ADT/StringExtras.h"
25#include "llvm/ADT/TypeSwitch.h"
28using namespace firrtl;
// dropWrite: erases every FConnectLike user of `old` whose destination is
// `old` (i.e. the connects that write to it).
// NOTE(review): this excerpt is lossy -- the tail of the parameter list, the
// body of the first loop, closing braces and return statements are not
// visible. Comments below describe only the statements that are shown.
32static Value
dropWrite(PatternRewriter &rewriter, OpResult old,
// Users are first gathered into a set (the insert statement is not visible);
// presumably so that erasing connects below does not disturb iteration over
// old.getUsers() -- confirm against the full source.
34 SmallPtrSet<Operation *, 8> users;
35 for (
auto *user : old.getUsers())
// Erase only connect-like ops that drive `old`; other users are left alone.
37 for (Operation *user : users)
38 if (
auto connect = dyn_cast<FConnectLike>(user))
39 if (connect.getDest() == old)
40 rewriter.eraseOp(user);
// The following statements carry a name hint from the op that owns `old`
// over to the op that defines `passthrough`.
// NOTE(review): the header of the function containing these lines is not
// visible here; `passthrough` is declared outside the shown lines.
49 Operation *op = passthrough.getDefiningOp();
// A passthrough that is not produced by an operation is rejected.
52 assert(op &&
"passthrough must be an operation");
53 Operation *oldOp = old.getOwner();
54 auto name = oldOp->getAttrOfType<StringAttr>(
"name");
// Only propagate a name that is present and non-empty; setAttr replaces any
// "name" already on `op`.
55 if (name && !name.getValue().empty())
56 op->setAttr(
"name", name);
64#include "circt/Dialect/FIRRTL/FIRRTLCanonicalization.h.inc"
// NOTE(review): fragments of two small type-check helpers; their headers,
// return statements and closing braces are not visible in this excerpt.
// First fragment: checks that `op`'s first result and every operand are
// IntTypes with a known width (the action taken on failure is not visible).
72 auto resultType = type_cast<IntType>(op->getResult(0).getType());
73 if (!resultType.hasWidth())
75 for (Value operand : op->getOperands())
76 if (!type_cast<IntType>(operand.getType()).hasWidth())
// Second fragment: rejects anything that is not a UIntType whose known
// width is exactly 1.
83 auto t = type_dyn_cast<UIntType>(type);
84 if (!t || !t.hasWidth() || t.getWidth() != 1)
// updateName: sets `op`'s "name" attribute from the candidate `name`.
// NOTE(review): the remaining parameter line(s), some body lines and the
// closing brace are not visible; `name` is used as a StringAttr below.
91static void updateName(PatternRewriter &rewriter, Operation *op,
// Instance ops must not be renamed through this helper.
94 assert(!isa<InstanceOp>(op));
// A missing or empty candidate name presumably exits early (the return line
// is not visible).
95 if (!name || name.getValue().empty())
97 auto newName = name.getValue();
98 auto newOpName = op->getAttrOfType<StringAttr>(
"name");
// When `op` already has a name, chooseName selects between the existing name
// and the candidate (the guarding condition is not visible).
101 newName =
chooseName(newOpName.getValue(), name.getValue());
// Only touch the op when it has no name yet or the name actually changes;
// the update is performed through rewriter.modifyOpInPlace.
103 if (!newOpName || newOpName.getValue() != newName)
104 rewriter.modifyOpInPlace(
105 op, [&] { op->setAttr(
"name", rewriter.getStringAttr(newName)); });
// NOTE(review): fragments of name-preserving replacement helpers; function
// headers and several body lines are missing from this excerpt.
// Fragment 1: when the replacement `newValue` is produced by an op, read the
// "name" attribute of `op` (how it is applied is not visible), then replace
// `op` with `newValue`.
113 if (
auto *newOp = newValue.getDefiningOp()) {
114 auto name = op->getAttrOfType<StringAttr>(
"name");
117 rewriter.replaceOp(op, newValue);
// Fragment 2: variadic template that reads `op`'s "name" attribute and then
// replaces `op` with a newly built OpTy, forwarding `args` to the builder
// (the function-name line and the use of `name` are not visible).
123template <
typename OpTy,
typename... Args>
125 Operation *op, Args &&...args) {
126 auto name = op->getAttrOfType<StringAttr>(
"name");
128 rewriter.replaceOpWithNewOp<OpTy>(op, std::forward<Args>(args)...);
// Fragment 3: for ops implementing FNamableOp, report whether their name is
// droppable (the fallback for other ops is not visible).
136 if (
auto namableOp = dyn_cast<firrtl::FNamableOp>(op))
137 return namableOp.hasDroppableName();
// NOTE(review): fragments of constant-extraction helpers; function names,
// parameter lists and several body lines are not visible in this excerpt.
// Fragment 1 (named getExtendedConstant by its assertion message): produces an
// optional APSInt for an integer-typed `operand`, sized to `destWidth`.
148static std::optional<APSInt>
150 assert(type_cast<IntType>(operand.getType()) &&
151 "getExtendedConstant is limited to integer types");
// An IntegerAttr `constant` is handled here (the extension logic is not
// visible).
158 if (IntegerAttr result = dyn_cast_or_null<IntegerAttr>(constant))
// A zero-width operand yields a zero value of `destWidth` bits carrying the
// operand type's signedness.
163 if (type_cast<IntType>(operand.getType()).getWidth() == 0)
164 return APSInt(destWidth,
165 type_cast<IntType>(operand.getType()).isUnsigned());
// Fragment 2: converts an attribute to an APSInt. A BoolAttr becomes an
// unsigned 1-bit value; an IntegerAttr is returned via getAPSInt().
173 if (
auto attr = dyn_cast<BoolAttr>(operand))
174 return APSInt(APInt(1, attr.getValue()));
175 if (
auto attr = dyn_cast<IntegerAttr>(operand))
176 return attr.getAPSInt();
// Fragment 3: tests whether the constant `cst` is zero.
184 return cst->isZero();
// Constant-folds a two-operand integer op. `operands` holds the (possibly
// null) constant attributes of the two operands, `calculate` computes the
// folded value, and `opKind` takes part in choosing the working operand width.
// NOTE(review): the function's name line and many body lines (including the
// control flow that selects among the operandWidth assignments) are missing.
211 Operation *op, ArrayRef<Attribute> operands,
BinOpKind opKind,
212 const function_ref<APInt(
const APSInt &,
const APSInt &)> &calculate) {
// Exactly two operand attributes are expected.
213 assert(operands.size() == 2 &&
"binary op takes two operands");
// An unknown result width (negative sentinel) prevents folding (the return
// line is not visible).
216 auto resultType = type_cast<IntType>(op->getResult(0).getType());
217 if (resultType.getWidthOrSentinel() < 0)
// A zero-width result folds directly to a 0-bit integer attribute.
221 if (resultType.getWidthOrSentinel() == 0)
222 return getIntAttr(resultType, APInt(0, 0, resultType.isSigned()));
// Operand widths start from the operand types (sentinel if unknown) ...
228 type_cast<IntType>(op->getOperand(0).getType()).getWidthOrSentinel();
230 type_cast<IntType>(op->getOperand(1).getType()).getWidthOrSentinel();
// ... and are widened to cover the bit width of a constant IntegerAttr
// operand, if one is present.
231 if (
auto lhs = dyn_cast_or_null<IntegerAttr>(operands[0]))
232 lhsWidth = std::max<int32_t>(lhsWidth, lhs.getValue().getBitWidth());
233 if (
auto rhs = dyn_cast_or_null<IntegerAttr>(operands[1]))
234 rhsWidth = std::max<int32_t>(rhsWidth, rhs.getValue().getBitWidth());
// Working width for the computation. Three candidates are visible: the
// result width, the wider operand (at least 1 bit), or the widest of both
// operands and the result. The conditions choosing between them are not
// visible (presumably keyed on `opKind` -- confirm).
238 int32_t operandWidth;
241 operandWidth = resultType.getWidthOrSentinel();
246 operandWidth = std::max(1, std::max(lhsWidth, rhsWidth));
250 std::max(std::max(lhsWidth, rhsWidth), resultType.getWidthOrSentinel());
// Evaluate on the extended operand values, then truncate to the result
// width. The assertion checks that the folded value's width matches the
// result type.
261 APInt resultValue = calculate(*lhs, *rhs);
266 resultValue = resultValue.trunc(resultType.getWidthOrSentinel());
268 assert((
unsigned)resultType.getWidthOrSentinel() ==
269 resultValue.getBitWidth());
// Runs `canonicalize` on the constant operands of a single-result op and
// rewrites the op with the value it returns. Width and signedness are fixed
// up so that the replacement keeps the op's original result type.
// NOTE(review): the function's name line and several body lines are not
// visible in this excerpt: early returns, the declarations of `attr` and
// `resultValue`, and the final replacement.
282 Operation *op, PatternRewriter &rewriter,
283 const function_ref<OpFoldResult(ArrayRef<Attribute>)> &canonicalize) {
// Only single-result ops are handled.
285 if (op->getNumResults() != 1)
287 auto type = type_dyn_cast<FIRRTLBaseType>(op->getResult(0).getType());
292 auto width = type.getBitWidthOrSentinel();
// Build one entry per operand. An operand defined by a ConstantOp or
// SpecialConstantOp contributes its value attribute; any other operand
// contributes whatever `attr` was initialized to. That declaration is not
// visible (presumably a null Attribute).
297 SmallVector<Attribute, 3> constOperands;
298 constOperands.reserve(op->getNumOperands());
299 for (
auto operand : op->getOperands()) {
301 if (
auto *defOp = operand.getDefiningOp())
302 TypeSwitch<Operation *>(defOp).Case<ConstantOp, SpecialConstantOp>(
303 [&](
auto op) { attr = op.getValueAttr(); });
304 constOperands.push_back(attr);
309 auto result = canonicalize(constOperands);
// An Attribute result is materialized as a constant of the result type
// through the op's dialect. Otherwise the result is already a Value.
313 if (
auto cst = dyn_cast<Attribute>(result))
314 resultValue = op->getDialect()
315 ->materializeConstant(rewriter, cst, type, op->getLoc())
318 resultValue = cast<Value>(result);
// Pad the replacement when its bit width differs from the result width
// (the start of this condition is not visible).
322 type_cast<FIRRTLBaseType>(resultValue.getType()).getBitWidthOrSentinel())
323 resultValue = rewriter.create<PadPrimOp>(op->getLoc(), resultValue, width);
// Reconcile signedness with an asSInt/asUInt cast.
326 if (type_isa<SIntType>(type) && type_isa<UIntType>(resultValue.getType()))
327 resultValue = rewriter.create<AsSIntPrimOp>(op->getLoc(), resultValue);
328 else if (type_isa<UIntType>(type) &&
329 type_isa<SIntType>(resultValue.getType()))
330 resultValue = rewriter.create<AsUIntPrimOp>(op->getLoc(), resultValue);
332 assert(type == resultValue.getType() &&
"canonicalization changed type");
// NOTE(review): bodies of three width-parameterized limit helpers; their
// names and signatures are not visible. Each returns the relevant extreme
// value for `bitWidth` bits, or a default-constructed APInt when `bitWidth`
// is not positive.
// Unsigned maximum.
340 return bitWidth > 0 ? APInt::getMaxValue(bitWidth) : APInt();
// Signed minimum.
346 return bitWidth > 0 ? APInt::getSignedMinValue(bitWidth) : APInt();
// Signed maximum.
352 return bitWidth > 0 ? APInt::getSignedMaxValue(bitWidth) : APInt();
// Folders for constant-like ops. Each one takes no operands (checked by the
// assertion) and folds to the attribute it carries.
// NOTE(review): the closing braces of these functions are not visible in
// this excerpt.
359OpFoldResult ConstantOp::fold(FoldAdaptor adaptor) {
360 assert(adaptor.getOperands().empty() &&
"constant has no operands");
361 return getValueAttr();
364OpFoldResult SpecialConstantOp::fold(FoldAdaptor adaptor) {
365 assert(adaptor.getOperands().empty() &&
"constant has no operands");
366 return getValueAttr();
// Aggregate constants fold to their fields attribute.
369OpFoldResult AggregateConstantOp::fold(FoldAdaptor adaptor) {
370 assert(adaptor.getOperands().empty() &&
"constant has no operands");
371 return getFieldsAttr();
374OpFoldResult StringConstantOp::fold(FoldAdaptor adaptor) {
375 assert(adaptor.getOperands().empty() &&
"constant has no operands");
376 return getValueAttr();
379OpFoldResult FIntegerConstantOp::fold(FoldAdaptor adaptor) {
380 assert(adaptor.getOperands().empty() &&
"constant has no operands");
381 return getValueAttr();
384OpFoldResult BoolConstantOp::fold(FoldAdaptor adaptor) {
385 assert(adaptor.getOperands().empty() &&
"constant has no operands");
386 return getValueAttr();
389OpFoldResult DoubleConstantOp::fold(FoldAdaptor adaptor) {
390 assert(adaptor.getOperands().empty() &&
"constant has no operands");
391 return getValueAttr();
// Add/Sub folding and canonicalization-pattern registration.
// NOTE(review): the lines that pass these lambdas to the shared binary
// constant folder, and the closing braces, are not visible in this excerpt.
// Folds two constant operands to their sum.
398OpFoldResult AddPrimOp::fold(FoldAdaptor adaptor) {
401 [=](
const APSInt &a,
const APSInt &b) { return a + b; });
// Registers the AddPrimOp canonicalization patterns listed below.
404void AddPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
405 MLIRContext *context) {
406 results.insert<patterns::moveConstAdd, patterns::AddOfZero,
407 patterns::AddOfSelf, patterns::AddOfPad>(context);
// Folds two constant operands to their difference.
410OpFoldResult SubPrimOp::fold(FoldAdaptor adaptor) {
413 [=](
const APSInt &a,
const APSInt &b) { return a - b; });
// Registers the SubPrimOp canonicalization patterns listed below.
416void SubPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
417 MLIRContext *context) {
418 results.insert<patterns::SubOfZero, patterns::SubFromZeroSigned,
419 patterns::SubFromZeroUnsigned, patterns::SubOfSelf,
420 patterns::SubOfPadL, patterns::SubOfPadR>(context);
// Mul/Div/Rem and dynamic-shift folders.
// NOTE(review): several lines per function are not visible in this excerpt:
// guards, early returns, the calls that receive these lambdas, and closing
// braces.
// Folds two constant operands to their product.
423OpFoldResult MulPrimOp::fold(FoldAdaptor adaptor) {
435 [=](
const APSInt &a,
const APSInt &b) -> APInt { return a * b; });
// When both operands are the same value, the fold yields 1 at the result
// width. Some intervening guard lines are not visible.
438OpFoldResult DivPrimOp::fold(FoldAdaptor adaptor) {
445 if (getLhs() == getRhs()) {
446 auto width = getType().base().getWidthOrSentinel();
451 return getIntAttr(getType(), APInt(width, 1));
// Division by a constant one is special-cased when the lhs type already
// matches the result type (its return is not visible).
468 if (
auto rhsCst = dyn_cast_or_null<IntegerAttr>(adaptor.getRhs()))
469 if (rhsCst.getValue().isOne() && getLhs().getType() == getType())
// The constant-evaluation lambda returns a zero of `a`'s width under a
// condition that is not visible (presumably a zero divisor -- confirm).
474 [=](
const APSInt &a,
const APSInt &b) -> APInt {
477 return APInt(a.getBitWidth(), 0);
// x % x is special-cased (its result is not visible). As with division, the
// lambda returns a zero of `a`'s width under a condition that is not shown.
481OpFoldResult RemPrimOp::fold(FoldAdaptor adaptor) {
488 if (getLhs() == getRhs())
502 [=](
const APSInt &a,
const APSInt &b) -> APInt {
505 return APInt(a.getBitWidth(), 0);
// Dynamic left shifts fold by shifting the constant lhs left by the rhs.
509OpFoldResult DShlPrimOp::fold(FoldAdaptor adaptor) {
512 [=](
const APSInt &a,
const APSInt &b) -> APInt { return a.shl(b); });
515OpFoldResult DShlwPrimOp::fold(FoldAdaptor adaptor) {
518 [=](
const APSInt &a,
const APSInt &b) -> APInt { return a.shl(b); });
// Dynamic right shift uses a logical shift when the result is unsigned or
// the lhs is zero-width. The alternative branch is not visible.
521OpFoldResult DShrPrimOp::fold(FoldAdaptor adaptor) {
524 [=](
const APSInt &a,
const APSInt &b) -> APInt {
525 return getType().base().isUnsigned() || !a.getBitWidth() ? a.lshr(b)
// Bitwise And/Or/Xor folders and pattern registration.
// NOTE(review): the lines that extract rhsCst/lhsCst, the fold results of
// most identity checks, and closing braces are not visible in this excerpt.
// The visible guards show that most identity folds fire only when the
// relevant operand types equal the result type.
// And: special cases for a zero or all-ones constant on either side and for
// x & x; otherwise two constants are combined bitwise.
531OpFoldResult AndPrimOp::fold(FoldAdaptor adaptor) {
534 if (rhsCst->isZero())
538 if (rhsCst->isAllOnes() && getLhs().getType() == getType() &&
539 getRhs().getType() == getType())
545 if (lhsCst->isZero())
549 if (lhsCst->isAllOnes() && getLhs().getType() == getType() &&
550 getRhs().getType() == getType())
555 if (getLhs() == getRhs() && getRhs().getType() == getType())
560 [](
const APSInt &a,
const APSInt &b) -> APInt { return a & b; });
563void AndPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
564 MLIRContext *context) {
566 .insert<patterns::extendAnd, patterns::moveConstAnd, patterns::AndOfZero,
567 patterns::AndOfAllOne, patterns::AndOfSelf, patterns::AndOfPad,
568 patterns::AndOfAsSIntL, patterns::AndOfAsSIntR>(context);
// Or: special cases for a zero or all-ones constant on either side and for
// x | x; otherwise two constants are combined bitwise.
571OpFoldResult OrPrimOp::fold(FoldAdaptor adaptor) {
574 if (rhsCst->isZero() && getLhs().getType() == getType())
578 if (rhsCst->isAllOnes() && getRhs().getType() == getType() &&
579 getLhs().getType() == getType())
585 if (lhsCst->isZero() && getRhs().getType() == getType())
589 if (lhsCst->isAllOnes() && getLhs().getType() == getType() &&
590 getRhs().getType() == getType())
595 if (getLhs() == getRhs() && getRhs().getType() == getType())
600 [](
const APSInt &a,
const APSInt &b) -> APInt { return a | b; });
603void OrPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
604 MLIRContext *context) {
605 results.insert<patterns::extendOr, patterns::moveConstOr, patterns::OrOfZero,
606 patterns::OrOfAllOne, patterns::OrOfSelf, patterns::OrOfPad,
607 patterns::OrOrr>(context);
// Xor: a zero constant on either side is special-cased (the rest of those
// conditions is not visible). x ^ x folds to a zero constant whose width is
// the result width, with an unknown width clamped to 0 bits.
610OpFoldResult XorPrimOp::fold(FoldAdaptor adaptor) {
613 if (rhsCst->isZero() &&
619 if (lhsCst->isZero() &&
624 if (getLhs() == getRhs())
627 APInt(std::max(getType().base().getWidthOrSentinel(), 0), 0));
631 [](
const APSInt &a,
const APSInt &b) -> APInt { return a ^ b; });
634void XorPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
635 MLIRContext *context) {
636 results.insert<patterns::extendXor, patterns::moveConstXor,
637 patterns::XorOfZero, patterns::XorOfSelf, patterns::XorOfPad>(
// Relational comparison ops (<=, <, >=, >). Each one registers a pattern for
// a constant lhs. Each fold handles these cases:
//  - a value compared with itself (result not visible);
//  - a constant rhs against a lhs of known width, using a common width that
//    covers both and is at least 1 bit (the range checks are only partly
//    visible);
//  - two constants, producing a 1-bit result.
// NOTE(review): many lines of each fold are missing from this excerpt:
// rhsCst extraction, the isUnsigned declaration in some folds, range checks,
// returns, and closing braces.
641void LEQPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
642 MLIRContext *context) {
643 results.insert<patterns::LEQWithConstLHS>(context);
646OpFoldResult LEQPrimOp::fold(FoldAdaptor adaptor) {
647 bool isUnsigned = getLhs().getType().base().isUnsigned();
650 if (getLhs() == getRhs())
654 if (
auto width = getLhs().getType().base().
getWidth()) {
656 auto commonWidth = std::max<int32_t>(*width, rhsCst->getBitWidth());
657 commonWidth = std::max(commonWidth, 1);
668 if (isUnsigned && rhsCst->zext(commonWidth)
681 [=](
const APSInt &a,
const APSInt &b) -> APInt {
682 return APInt(1, a <= b);
686void LTPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
687 MLIRContext *context) {
688 results.insert<patterns::LTWithConstLHS>(context);
691OpFoldResult LTPrimOp::fold(FoldAdaptor adaptor) {
692 IntType lhsType = getLhs().getType();
696 if (getLhs() == getRhs())
706 if (
auto width = lhsType.
getWidth()) {
708 auto commonWidth = std::max<int32_t>(*width, rhsCst->getBitWidth());
709 commonWidth = std::max(commonWidth, 1);
720 if (isUnsigned && rhsCst->zext(commonWidth)
733 [=](
const APSInt &a,
const APSInt &b) -> APInt {
734 return APInt(1, a < b);
738void GEQPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
739 MLIRContext *context) {
740 results.insert<patterns::GEQWithConstLHS>(context);
743OpFoldResult GEQPrimOp::fold(FoldAdaptor adaptor) {
744 IntType lhsType = getLhs().getType();
748 if (getLhs() == getRhs())
// For >=, an unsigned comparison against constant zero is special-cased
// (result not visible).
753 if (rhsCst->isZero() && isUnsigned)
758 if (
auto width = lhsType.
getWidth()) {
760 auto commonWidth = std::max<int32_t>(*width, rhsCst->getBitWidth());
761 commonWidth = std::max(commonWidth, 1);
764 if (isUnsigned && rhsCst->zext(commonWidth)
785 [=](
const APSInt &a,
const APSInt &b) -> APInt {
786 return APInt(1, a >= b);
790void GTPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
791 MLIRContext *context) {
792 results.insert<patterns::GTWithConstLHS>(context);
795OpFoldResult GTPrimOp::fold(FoldAdaptor adaptor) {
796 IntType lhsType = getLhs().getType();
800 if (getLhs() == getRhs())
804 if (
auto width = lhsType.
getWidth()) {
806 auto commonWidth = std::max<int32_t>(*width, rhsCst->getBitWidth());
807 commonWidth = std::max(commonWidth, 1);
810 if (isUnsigned && rhsCst->zext(commonWidth)
831 [=](
const APSInt &a,
const APSInt &b) -> APInt {
832 return APInt(1, a > b);
// Equality / inequality folding and canonicalization.
// NOTE(review): the following are not visible: the rhsCst extraction,
// several return statements, the call that receives each canonicalize
// lambda, and closing braces.
// EQ fold: x == x is special-cased, and so is x == all-ones when both
// operand types equal the result type. Two constants compare to a 1-bit
// result.
836OpFoldResult EQPrimOp::fold(FoldAdaptor adaptor) {
838 if (getLhs() == getRhs())
844 if (rhsCst->isAllOnes() && getLhs().getType() == getType() &&
845 getRhs().getType() == getType())
851 [=](
const APSInt &a,
const APSInt &b) -> APInt {
852 return APInt(1, a == b);
// EQ canonicalize, for a constant rhs:
//  - x == 0, with lhs, rhs and result types all equal, becomes not(x);
//  - x == 0 with a lhs wider than one bit becomes not(orr(x));
//  - x == all-ones with a lhs wider than one bit and lhs/rhs of the same
//    type becomes andr(x).
856LogicalResult EQPrimOp::canonicalize(EQPrimOp op, PatternRewriter &rewriter) {
858 op, rewriter, [&](ArrayRef<Attribute> operands) -> OpFoldResult {
860 auto width = op.getLhs().getType().getBitWidthOrSentinel();
863 if (rhsCst->isZero() && op.getLhs().getType() == op.getType() &&
864 op.getRhs().getType() == op.getType()) {
865 return rewriter.create<NotPrimOp>(op.getLoc(), op.getLhs())
870 if (rhsCst->isZero() && width > 1) {
871 auto orrOp = rewriter.create<OrRPrimOp>(op.getLoc(), op.getLhs());
872 return rewriter.create<NotPrimOp>(op.getLoc(), orrOp).getResult();
876 if (rhsCst->isAllOnes() && width > 1 &&
877 op.getLhs().getType() == op.getRhs().getType()) {
878 return rewriter.create<AndRPrimOp>(op.getLoc(), op.getLhs())
// NEQ fold: x != x is special-cased, and so is x != 0 when both operand
// types equal the result type. Two constants compare to a 1-bit result.
886OpFoldResult NEQPrimOp::fold(FoldAdaptor adaptor) {
888 if (getLhs() == getRhs())
894 if (rhsCst->isZero() && getLhs().getType() == getType() &&
895 getRhs().getType() == getType())
901 [=](
const APSInt &a,
const APSInt &b) -> APInt {
902 return APInt(1, a != b);
// NEQ canonicalize, for a constant rhs:
//  - x != all-ones, with lhs, rhs and result types all equal, becomes not(x);
//  - x != 0 with a lhs wider than one bit becomes orr(x);
//  - x != all-ones with a lhs wider than one bit and lhs/rhs of the same
//    type becomes not(andr(x)).
906LogicalResult NEQPrimOp::canonicalize(NEQPrimOp op, PatternRewriter &rewriter) {
908 op, rewriter, [&](ArrayRef<Attribute> operands) -> OpFoldResult {
910 auto width = op.getLhs().getType().getBitWidthOrSentinel();
913 if (rhsCst->isAllOnes() && op.getLhs().getType() == op.getType() &&
914 op.getRhs().getType() == op.getType()) {
915 return rewriter.create<NotPrimOp>(op.getLoc(), op.getLhs())
920 if (rhsCst->isZero() && width > 1) {
921 return rewriter.create<OrRPrimOp>(op.getLoc(), op.getLhs())
926 if (rhsCst->isAllOnes() && width > 1 &&
927 op.getLhs().getType() == op.getRhs().getType()) {
928 auto andrOp = rewriter.create<AndRPrimOp>(op.getLoc(), op.getLhs());
929 return rewriter.create<NotPrimOp>(op.getLoc(), andrOp).getResult();
// Folders for the Integer* ops.
// NOTE(review): the bodies of the add, mul and shr folders are not visible
// in this excerpt.
937OpFoldResult IntegerAddOp::fold(FoldAdaptor adaptor) {
943OpFoldResult IntegerMulOp::fold(FoldAdaptor adaptor) {
949OpFoldResult IntegerShrOp::fold(FoldAdaptor adaptor) {
// Shl: two constants fold to an IntegerAttr with the lhs constant's bit
// width, holding lhs << rhs. A zero shift amount is special-cased (the
// result is not visible).
955OpFoldResult IntegerShlOp::fold(FoldAdaptor adaptor) {
960 return IntegerAttr::get(
961 IntegerType::get(getContext(), lhsCst->getBitWidth()),
962 lhsCst->shl(*rhsCst));
965 if (rhsCst->isZero())
// NOTE(review): only the first lines of these intrinsic folders are visible.
// SizeOf reads its input's type; the rest of the fold is not shown.
976OpFoldResult SizeOfIntrinsicOp::fold(FoldAdaptor) {
977 auto base = getInput().getType();
984OpFoldResult IsXIntrinsicOp::fold(FoldAdaptor adaptor) {
// Reinterpreting casts.
// NOTE(review): most of each fold body is not visible in this excerpt.
// asSInt / asUInt: part of each fold depends on the result type having a
// known width.
991OpFoldResult AsSIntPrimOp::fold(FoldAdaptor adaptor) {
999 if (getType().base().hasWidth())
1006void AsSIntPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1007 MLIRContext *context) {
1008 results.insert<patterns::StoUtoS>(context);
1011OpFoldResult AsUIntPrimOp::fold(FoldAdaptor adaptor) {
1019 if (getType().base().hasWidth())
1026void AsUIntPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1027 MLIRContext *context) {
1028 results.insert<patterns::UtoStoU>(context);
// asAsyncReset / asClock: a cast whose input already has the result type is
// special-cased (result not visible). A constant input folds to a BoolAttr
// that is true iff the constant is non-zero.
1031OpFoldResult AsAsyncResetPrimOp::fold(FoldAdaptor adaptor) {
1033 if (getInput().getType() == getType())
1038 return BoolAttr::get(getContext(), cst->getBoolValue());
1043OpFoldResult AsClockPrimOp::fold(FoldAdaptor adaptor) {
1045 if (getInput().getType() == getType())
1050 return BoolAttr::get(getContext(), cst->getBoolValue());
// Unary cvt / neg / not.
// NOTE(review): each fold's constant extraction is only partly visible. The
// trailing `getWidthOrSentinel()))` lines end a call that obtains the input
// constant at the result width; the callee is not visible.
1055OpFoldResult CvtPrimOp::fold(FoldAdaptor adaptor) {
1061 getType().base().getWidthOrSentinel()))
1067void CvtPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1068 MLIRContext *context) {
1069 results.insert<patterns::CVTSigned, patterns::CVTUnSigned>(context);
// Neg folds a constant to (0 - cst), computed at the constant's width.
1072OpFoldResult NegPrimOp::fold(FoldAdaptor adaptor) {
1079 getType().base().getWidthOrSentinel()))
1080 return getIntAttr(getType(), APInt((*cst).getBitWidth(), 0) - *cst);
1085OpFoldResult NotPrimOp::fold(FoldAdaptor adaptor) {
1090 getType().base().getWidthOrSentinel()))
// Registers the NotPrimOp patterns, including those that absorb a not into a
// comparison.
1096void NotPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1097 MLIRContext *context) {
1098 results.insert<patterns::NotNot, patterns::NotEq, patterns::NotNeq,
1099 patterns::NotLeq, patterns::NotLt, patterns::NotGeq,
1100 patterns::NotGt>(context);
// Reduction ops andr / orr / xorr. Each fold handles three cases:
//  - a zero-width input (result not visible);
//  - a constant input, folded to a 1-bit value: all-ones for andr,
//    non-zero for orr, odd population count for xorr;
//  - a 1-bit unsigned input (result not visible).
// NOTE(review): constant extraction, several returns and closing braces are
// not visible in this excerpt.
1103OpFoldResult AndRPrimOp::fold(FoldAdaptor adaptor) {
1107 if (getInput().getType().getBitWidthOrSentinel() == 0)
1112 return getIntAttr(getType(), APInt(1, cst->isAllOnes()));
1116 if (
isUInt1(getInput().getType()))
1122void AndRPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1123 MLIRContext *context) {
1125 .insert<patterns::AndRasSInt, patterns::AndRasUInt, patterns::AndRPadU,
1126 patterns::AndRPadS, patterns::AndRCatOneL, patterns::AndRCatOneR,
1127 patterns::AndRCatZeroL, patterns::AndRCatZeroR,
1128 patterns::AndRCatAndR_left, patterns::AndRCatAndR_right>(context);
1131OpFoldResult OrRPrimOp::fold(FoldAdaptor adaptor) {
1135 if (getInput().getType().getBitWidthOrSentinel() == 0)
1140 return getIntAttr(getType(), APInt(1, !cst->isZero()));
1144 if (
isUInt1(getInput().getType()))
1150void OrRPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1151 MLIRContext *context) {
1152 results.insert<patterns::OrRasSInt, patterns::OrRasUInt, patterns::OrRPadU,
1153 patterns::OrRCatZeroH, patterns::OrRCatZeroL,
1154 patterns::OrRCatOrR_left, patterns::OrRCatOrR_right>(context);
1157OpFoldResult XorRPrimOp::fold(FoldAdaptor adaptor) {
1161 if (getInput().getType().getBitWidthOrSentinel() == 0)
1166 return getIntAttr(getType(), APInt(1, cst->popcount() & 1));
1169 if (
isUInt1(getInput().getType()))
1175void XorRPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1176 MLIRContext *context) {
1177 results.insert<patterns::XorRasSInt, patterns::XorRasUInt, patterns::XorRPadU,
1178 patterns::XorRCatZeroH, patterns::XorRCatZeroL,
1179 patterns::XorRCatXorR_left, patterns::XorRCatXorR_right>(
// Cat folding plus pattern registration for dshl, dshr and cat.
// NOTE(review): the constant extraction and closing braces are not visible.
// Two constants fold to their concatenation, with lhs in the high bits and
// rhs in the low bits (APInt::concat appends its argument below *this).
1187OpFoldResult CatPrimOp::fold(FoldAdaptor adaptor) {
1191 IntType lhsType = getLhs().getType();
1192 IntType rhsType = getRhs().getType();
1204 return getIntAttr(getType(), lhs->concat(*rhs));
1209void DShlPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1210 MLIRContext *context) {
1211 results.insert<patterns::DShlOfConstant>(context);
1214void DShrPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1215 MLIRContext *context) {
1216 results.insert<patterns::DShrOfConstant>(context);
1219void CatPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1220 MLIRContext *context) {
1221 results.insert<patterns::CatBitsBits, patterns::CatDoubleConst,
1222 patterns::CatCast>(context);
// BitCast: an identity cast folds to its input. A cast of a cast back to the
// original type folds to the innermost input.
// NOTE(review): the declaration of `op` (presumably *this) and the closing
// braces are not visible in this excerpt.
1225OpFoldResult BitCastOp::fold(FoldAdaptor adaptor) {
1228 if (op.getType() == op.getInput().getType())
1229 return op.getInput();
1233 if (BitCastOp in = dyn_cast_or_null<BitCastOp>(op.getInput().getDefiningOp()))
1234 if (op.getType() == in.getInput().getType())
1235 return in.getInput();
// Bits: a known-width extract whose input type equals the result type is
// special-cased (result not visible). A constant input folds to bits
// [hi:lo], which is (hi - lo + 1) bits starting at lo.
1240OpFoldResult BitsPrimOp::fold(FoldAdaptor adaptor) {
1241 IntType inputType = getInput().getType();
1242 IntType resultType = getType();
1244 if (inputType == getType() && resultType.
hasWidth())
1251 cst->extractBits(getHi() - getLo() + 1, getLo()));
1256void BitsPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1257 MLIRContext *context) {
1259 .insert<patterns::BitsOfBits, patterns::BitsOfMux, patterns::BitsOfAsUInt,
1260 patterns::BitsOfAnd, patterns::BitsOfPad>(context);
// Replaces the single-result integer op `op` with bits [hiBit:loBit] of
// `value`, cast to match the result type's signedness.
// NOTE(review): the function's name line and its earlier parameters are not
// visible; `op`, `value` and `hiBit` are declared there.
1267 unsigned loBit, PatternRewriter &rewriter) {
1268 auto resType = type_cast<IntType>(op->getResult(0).getType());
// The bits extraction is skipped when `value` already has the result width.
1269 if (type_cast<IntType>(value.getType()).getWidth() != resType.getWidth())
1270 value = rewriter.create<BitsPrimOp>(op->getLoc(), value, hiBit, loBit);
// Cast to the result's signedness if it differs.
1272 if (resType.isSigned() && !type_cast<IntType>(value.getType()).isSigned()) {
1273 value = rewriter.createOrFold<AsSIntPrimOp>(op->getLoc(), resType, value);
1274 }
else if (resType.isUnsigned() &&
1275 !type_cast<IntType>(value.getType()).isUnsigned()) {
1276 value = rewriter.createOrFold<AsUIntPrimOp>(op->getLoc(), resType, value);
1278 rewriter.replaceOp(op, value);
// Shared folder for two-way mux ops; MuxPrimOp and Mux2CellIntrinsicOp
// delegate to it.
// NOTE(review): several lines are not visible in this excerpt: the start of
// the zero-width return, the extraction of `cond`, some returns, and closing
// braces.
1281template <
typename OpTy>
1282static OpFoldResult
foldMux(OpTy op,
typename OpTy::FoldAdaptor adaptor) {
// A zero-width mux folds to a 0-bit integer constant.
1284 if (op.getType().getBitWidthOrSentinel() == 0)
1286 APInt(0, 0, op.getType().isSignedInteger()));
// Identical arms that already have the result type fold to that arm.
1289 if (op.getHigh() == op.getLow() && op.getHigh().getType() == op.getType())
1290 return op.getHigh();
// An unknown result width stops further folding (return not visible).
1295 if (op.getType().getBitWidthOrSentinel() < 0)
// With a constant select, zero picks the low arm and non-zero the high arm,
// provided that arm already has the result type.
1300 if (cond->isZero() && op.getLow().getType() == op.getType())
1302 if (!cond->isZero() && op.getHigh().getType() == op.getType())
1303 return op.getHigh();
// With two constant arms, equal arms of equal width are special-cased.
// `sel ? 1 : 0` is also special-cased when the select's type equals the
// result type (results not visible).
1307 if (
auto lowCst =
getConstant(adaptor.getLow())) {
1309 if (
auto highCst =
getConstant(adaptor.getHigh())) {
1311 if (highCst->getBitWidth() == lowCst->getBitWidth() &&
1312 *highCst == *lowCst)
1315 if (highCst->isOne() && lowCst->isZero() &&
1316 op.getType() == op.getSel().getType())
1329OpFoldResult MuxPrimOp::fold(FoldAdaptor adaptor) {
1330 return foldMux(*
this, adaptor);
1333OpFoldResult Mux2CellIntrinsicOp::fold(FoldAdaptor adaptor) {
1334 return foldMux(*
this, adaptor);
// NOTE(review): the visible body returns no fold result; lines before the
// return may be missing from this excerpt.
1337OpFoldResult Mux4CellIntrinsicOp::fold(FoldAdaptor adaptor) {
return {}; }
// MuxPad: rewrites a mux so that each arm whose known width differs from the
// mux width is explicitly padded to that width.
// NOTE(review): several lines are not visible in this excerpt: access
// specifiers, the mux-width guard, return statements, and closing braces.
1344class MuxPad :
public mlir::RewritePattern {
1346 MuxPad(MLIRContext *context)
1347 : RewritePattern(MuxPrimOp::getOperationName(), 0, context) {}
1350 matchAndRewrite(Operation *op,
1351 mlir::PatternRewriter &rewriter)
const override {
1352 auto mux = cast<MuxPrimOp>(op);
1353 auto width = mux.getType().getBitWidthOrSentinel();
// Pad an arm to the mux width, unless its width is unknown or already equal
// (the unchanged-input return is not visible).
1357 auto pad = [&](Value input) -> Value {
1359 type_cast<FIRRTLBaseType>(input.getType()).getBitWidthOrSentinel();
1360 if (inputWidth < 0 || width == inputWidth)
1363 .create<PadPrimOp>(mux.getLoc(), mux.getType(), input, width)
1367 auto newHigh = pad(mux.getHigh());
1368 auto newLow = pad(mux.getLow());
// Nothing to do when neither arm changed (the failure return is not visible).
1369 if (newHigh == mux.getHigh() && newLow == mux.getLow())
// Rebuild the mux with the padded arms via replaceOpWithNewOpAndCopyName.
1372 replaceOpWithNewOpAndCopyName<MuxPrimOp>(
1373 rewriter, op, mux.getType(), ValueRange{mux.getSel(), newHigh, newLow},
// MuxSharedCond: simplifies nested muxes that test the same select signal.
// Inside the high arm of a mux on `c`, `c` is known true; inside the low arm
// it is known false. A nested mux on `c` in either position can therefore be
// replaced by its corresponding arm.
// NOTE(review): several lines are not visible in this excerpt: access
// specifiers, null checks on getDefiningOp results, failure returns, and
// closing braces.
1381class MuxSharedCond :
public mlir::RewritePattern {
1383 MuxSharedCond(MLIRContext *context)
1384 : RewritePattern(MuxPrimOp::getOperationName(), 0, context) {}
// Maximum recursion depth explored below a mux arm.
1386 static const int depthLimit = 5;
// Gives `mux` the new arms `high`/`low`. With `updateInPlace` the mux is
// mutated; otherwise a new mux with the same select is created right after
// it.
1388 Value updateOrClone(MuxPrimOp mux, Value high, Value low,
1389 mlir::PatternRewriter &rewriter,
1390 bool updateInPlace)
const {
1391 if (updateInPlace) {
1392 rewriter.modifyOpInPlace(mux, [&] {
1393 mux.setOperand(1, high);
1394 mux.setOperand(2, low);
1398 rewriter.setInsertionPointAfter(mux);
1400 .create<MuxPrimOp>(mux.getLoc(), mux.getType(),
1401 ValueRange{mux.getSel(), high, low})
// Simplifies `op` assuming `cond` is true. A mux selecting on `cond` reduces
// to its high arm. Otherwise both arms are searched recursively, up to
// depthLimit. In-place updates stay enabled only while every mux on the path
// has a single use; otherwise the mux is cloned.
1406 Value tryCondTrue(Value op, Value cond, mlir::PatternRewriter &rewriter,
1407 bool updateInPlace,
int limit)
const {
1408 MuxPrimOp mux = op.getDefiningOp<MuxPrimOp>();
1411 if (mux.getSel() == cond)
1412 return mux.getHigh();
1413 if (limit > depthLimit)
1415 updateInPlace &= mux->hasOneUse();
1417 if (Value v = tryCondTrue(mux.getHigh(), cond, rewriter, updateInPlace,
1419 return updateOrClone(mux, v, mux.getLow(), rewriter, updateInPlace);
1422 tryCondTrue(mux.getLow(), cond, rewriter, updateInPlace, limit + 1))
1423 return updateOrClone(mux, mux.getHigh(), v, rewriter, updateInPlace);
// Same as tryCondTrue, but with `cond` assumed false: a mux selecting on
// `cond` reduces to its low arm.
1428 Value tryCondFalse(Value op, Value cond, mlir::PatternRewriter &rewriter,
1429 bool updateInPlace,
int limit)
const {
1430 MuxPrimOp mux = op.getDefiningOp<MuxPrimOp>();
1433 if (mux.getSel() == cond)
1434 return mux.getLow();
1435 if (limit > depthLimit)
1437 updateInPlace &= mux->hasOneUse();
1439 if (Value v = tryCondFalse(mux.getHigh(), cond, rewriter, updateInPlace,
1441 return updateOrClone(mux, v, mux.getLow(), rewriter, updateInPlace);
1443 if (Value v = tryCondFalse(mux.getLow(), cond, rewriter, updateInPlace,
1445 return updateOrClone(mux, mux.getHigh(), v, rewriter, updateInPlace);
// Entry point: simplify the high arm assuming the select is true and the low
// arm assuming it is false, then update this mux's operands in place.
1451 matchAndRewrite(Operation *op,
1452 mlir::PatternRewriter &rewriter)
const override {
1453 auto mux = cast<MuxPrimOp>(op);
1454 auto width = mux.getType().getBitWidthOrSentinel();
1458 if (Value v = tryCondTrue(mux.getHigh(), mux.getSel(), rewriter,
true, 0)) {
1459 rewriter.modifyOpInPlace(mux, [&] { mux.setOperand(1, v); });
1463 if (Value v = tryCondFalse(mux.getLow(), mux.getSel(), rewriter,
true, 0)) {
1464 rewriter.modifyOpInPlace(mux, [&] { mux.setOperand(2, v); });
// Pattern registration for the mux ops. MuxPrimOp gets the MuxPad and
// MuxSharedCond patterns defined in this file plus the listed generated
// patterns. The cell intrinsics each get their select-padding pattern.
// NOTE(review): the `results` receiver line and closing braces are not all
// visible in this excerpt.
1473void MuxPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1474 MLIRContext *context) {
1476 .add<MuxPad, MuxSharedCond, patterns::MuxEQOperands,
1477 patterns::MuxEQOperandsSwapped, patterns::MuxNEQ, patterns::MuxNot,
1478 patterns::MuxSameTrue, patterns::MuxSameFalse,
1479 patterns::NarrowMuxLHS, patterns::NarrowMuxRHS, patterns::MuxPadSel>(
1483void Mux2CellIntrinsicOp::getCanonicalizationPatterns(
1484 RewritePatternSet &results, MLIRContext *context) {
1485 results.add<patterns::Mux2PadSel>(context);
1488void Mux4CellIntrinsicOp::getCanonicalizationPatterns(
1489 RewritePatternSet &results, MLIRContext *context) {
1490 results.add<patterns::Mux4PadSel>(context);
// Pad / static shift-left / static shift-right folders.
// NOTE(review): the following are not visible in this excerpt: constant
// extraction, the inputWidth declarations, several returns, and closing
// braces.
// Pad: a pad whose input already has the result type is special-cased. An
// unknown destination width stops folding. A constant is sign-extended when
// the input is signed and non-zero-width, otherwise zero-extended.
1493OpFoldResult PadPrimOp::fold(FoldAdaptor adaptor) {
1494 auto input = this->getInput();
1497 if (input.getType() == getType())
1501 auto inputType = input.getType().base();
1508 auto destWidth = getType().base().getWidthOrSentinel();
1509 if (destWidth == -1)
1512 if (inputType.
isSigned() && cst->getBitWidth())
1513 return getIntAttr(getType(), cst->sext(destWidth));
1514 return getIntAttr(getType(), cst->zext(destWidth));
// Shl: a zero shift is special-cased. When the input width is known, a
// constant is zero-extended to inputWidth + shiftAmount bits and then
// shifted left.
// NOTE(review): since resultWidth = inputWidth + shiftAmount, the std::min
// below leaves shiftAmount unchanged whenever inputWidth is non-negative.
1520OpFoldResult ShlPrimOp::fold(FoldAdaptor adaptor) {
1521 auto input = this->getInput();
1522 IntType inputType = input.getType();
1523 int shiftAmount = getAmount();
1526 if (shiftAmount == 0)
1532 if (inputWidth != -1) {
1533 auto resultWidth = inputWidth + shiftAmount;
1534 shiftAmount = std::min(shiftAmount, resultWidth);
1535 return getIntAttr(getType(), cst->zext(resultWidth).shl(shiftAmount));
// Shr fold: the visible cases are
//  - a zero shift of a positive-width input (special-cased);
//  - an unknown width, which stops folding;
//  - a zero-width input (special-cased);
//  - an unsigned shift by at least the width, which folds to a 0-bit zero.
// For other constants, an arithmetic shift is clamped to width - 1 so the
// sign bit survives, and a logical shift is clamped to the width. The
// condition choosing between them is not visible (presumably signedness).
// The result is truncated to max(width - shift, 1) bits.
1541OpFoldResult ShrPrimOp::fold(FoldAdaptor adaptor) {
1542 auto input = this->getInput();
1543 IntType inputType = input.getType();
1544 int shiftAmount = getAmount();
1550 if (shiftAmount == 0 && inputWidth > 0)
1553 if (inputWidth == -1)
1555 if (inputWidth == 0)
1560 if (shiftAmount >= inputWidth && inputType.
isUnsigned())
1561 return getIntAttr(getType(), APInt(0, 0,
false));
1567 value = cst->ashr(std::min(shiftAmount, inputWidth - 1));
1569 value = cst->lshr(std::min(shiftAmount, inputWidth));
1570 auto resultWidth = std::max(inputWidth - shiftAmount, 1);
1571 return getIntAttr(getType(), value.trunc(resultWidth));
// Lowering of shr/head/tail to bit extracts, plus head/tail constant folds.
// Each canonicalizer requires a known, positive input width.
// NOTE(review): failure returns, some control flow and closing braces are
// not visible in this excerpt.
// Shr -> bits(input, width - 1, shift). When the shift is at least the
// width, the unsigned case is handled separately (not visible). The visible
// clamp of the shift to width - 1 keeps the top (sign) bit; its exact
// nesting is not visible.
1576LogicalResult ShrPrimOp::canonicalize(ShrPrimOp op, PatternRewriter &rewriter) {
1577 auto inputWidth = op.getInput().getType().base().getWidthOrSentinel();
1578 if (inputWidth <= 0)
1582 unsigned shiftAmount = op.getAmount();
1583 if (
int(shiftAmount) >= inputWidth) {
1585 if (op.getType().base().isUnsigned())
1591 shiftAmount = inputWidth - 1;
1594 replaceWithBits(op, op.getInput(), inputWidth - 1, shiftAmount, rewriter);
// Head keeps the top `amount` bits: bits(input, width - 1, width - amount).
1598LogicalResult HeadPrimOp::canonicalize(HeadPrimOp op,
1599 PatternRewriter &rewriter) {
1600 auto inputWidth = op.getInput().getType().base().getWidthOrSentinel();
1601 if (inputWidth <= 0)
1605 unsigned keepAmount = op.getAmount();
1607 replaceWithBits(op, op.getInput(), inputWidth - 1, inputWidth - keepAmount,
// Head fold: shift the constant right by (width - amount) and keep `amount`
// bits.
1612OpFoldResult HeadPrimOp::fold(FoldAdaptor adaptor) {
1616 getInput().getType().base().getWidthOrSentinel() - getAmount();
1617 return getIntAttr(getType(), cst->lshr(shiftAmount).trunc(getAmount()));
// Tail fold: truncate the constant to the result width.
1623OpFoldResult TailPrimOp::fold(FoldAdaptor adaptor) {
1627 cst->trunc(getType().base().getWidthOrSentinel()));
// Tail canonicalize: the visible code branches on whether the drop amount
// differs from the input width (the branch bodies are not visible).
1631LogicalResult TailPrimOp::canonicalize(TailPrimOp op,
1632 PatternRewriter &rewriter) {
1633 auto inputWidth = op.getInput().getType().base().getWidthOrSentinel();
1634 if (inputWidth <= 0)
1638 unsigned dropAmount = op.getAmount();
1639 if (dropAmount !=
unsigned(inputWidth))
1645void SubaccessOp::getCanonicalizationPatterns(RewritePatternSet &results,
1646 MLIRContext *context) {
1647 results.add<patterns::SubaccessOfConstant>(context);
// MultibitMux fold. Inputs are stored in reverse index order: index i
// selects getInputs()[size - 1 - i].
// - With a single input, fold to operand 1 (presumably that sole input,
//   after the index operand -- confirm).
// - With an in-range constant index, fold to the selected input.
// NOTE(review): closing braces and some lines are not visible in this
// excerpt.
1650OpFoldResult MultibitMuxOp::fold(FoldAdaptor adaptor) {
1652 if (adaptor.getInputs().size() == 1)
1653 return getOperand(1);
1655 if (
auto constIndex =
getConstant(adaptor.getIndex())) {
1656 auto index = constIndex->getZExtValue();
1657 if (index < getInputs().size())
1658 return getInputs()[getInputs().size() - 1 - index];
// MultibitMux canonicalize.
// NOTE(review): the rewrite for the all-inputs-identical case, the failure
// returns and closing braces are not visible.
1664LogicalResult MultibitMuxOp::canonicalize(MultibitMuxOp op,
1665 PatternRewriter &rewriter) {
// All inputs identical.
1669 if (llvm::all_of(op.getInputs().drop_front(), [&](
auto input) {
1670 return input == op.getInputs().front();
// A known index width w < 64 can address only 2^w inputs. Because of the
// reverse ordering, the leading (inputSize - 2^w) inputs are unreachable and
// are erased in place.
1678 auto indexWidth = op.getIndex().getType().getBitWidthOrSentinel();
1679 uint64_t inputSize = op.getInputs().size();
1680 if (indexWidth >= 0 && indexWidth < 64 && 1ull << indexWidth < inputSize) {
1681 rewriter.modifyOpInPlace(op, [&]() {
1682 op.getInputsMutable().erase(0, inputSize - (1ull << indexWidth));
// Suppose every input at position k is subindex (size - 1 - k) of the same
// aggregate. Then the mux is exactly a dynamic subaccess of that aggregate
// by the index.
1689 if (
auto lastSubindex = op.getInputs().back().getDefiningOp<SubindexOp>()) {
1690 if (llvm::all_of(llvm::enumerate(op.getInputs()), [&](
auto e) {
1691 auto subindex = e.value().template getDefiningOp<SubindexOp>();
1692 return subindex && lastSubindex.getInput() == subindex.getInput() &&
1693 subindex.getIndex() + e.index() + 1 == op.getInputs().size();
1695 replaceOpWithNewOpAndCopyName<SubaccessOp>(
1696 rewriter, op, lastSubindex.getInput(), op.getIndex());
// Two inputs with a 1-bit index become mux(index, inputs[0], inputs[1]).
// Here inputs[0] corresponds to index 1 and inputs[1] to index 0.
1702 if (op.getInputs().size() != 2)
1706 auto uintType = op.getIndex().getType();
1707 if (uintType.getBitWidthOrSentinel() != 1)
1711 replaceOpWithNewOpAndCopyName<MuxPrimOp>(
1712 rewriter, op, op.getIndex(), op.getInputs()[0], op.getInputs()[1]);
// NOTE(review): two fragments whose function headers and several body lines
// are missing from this excerpt.
// Fragment 1 scans the users of `value` to find the single MatchingConnectOp
// that drives it:
//  - attach and sub-access users are handled specially (action not visible);
//  - every connect-like user whose destination is `value` must be that same
//    MatchingConnectOp, located in `value`'s parent block (otherwise the
//    return is not visible).
1731 MatchingConnectOp connect;
1732 for (Operation *user : value.getUsers()) {
1734 if (isa<AttachOp, SubfieldOp, SubaccessOp, SubindexOp>(user))
1737 if (
auto aConnect = dyn_cast<FConnectLike>(user))
1738 if (aConnect.getDest() == value) {
1739 auto matchingConnect = dyn_cast<MatchingConnectOp>(*aConnect);
1742 if (!matchingConnect || (connect && connect != matchingConnect) ||
1743 matchingConnect->getBlock() != value.getParentBlock())
1746 connect = matchingConnect;
// Fragment 2 is a rewrite that forwards a connect's source into the
// declaration being connected:
//  - the destination must be defined by a WireOp or RegOp;
//  - forceable declarations are rejected (condition partly visible);
//  - the single-use case is handled separately;
//  - for non-wire declarations, the visible checks require the source to be
//    a ConstantOp in the declaration's block (nesting partly not visible).
// The source op is then moved to the front of the declaration's block,
// presumably so it dominates the declaration's users, and the connect is
// erased.
1754 PatternRewriter &rewriter) {
1757 Operation *connectedDecl = op.getDest().getDefiningOp();
1762 if (!isa<WireOp>(connectedDecl) && !isa<RegOp>(connectedDecl))
1766 cast<Forceable>(connectedDecl).isForceable())
1774 if (connectedDecl->hasOneUse())
1778 auto *declBlock = connectedDecl->getBlock();
1779 auto *srcValueOp = op.getSrc().getDefiningOp();
1782 if (!isa<WireOp>(connectedDecl))
1788 if (!isa<ConstantOp>(srcValueOp))
1790 if (srcValueOp->getBlock() != declBlock)
1796 auto replacement = op.getSrc();
1799 if (srcValueOp && srcValueOp != &declBlock->front())
1800 srcValueOp->moveBefore(&declBlock->front());
1807 rewriter.eraseOp(op);
// Registers the ConnectOp canonicalization patterns listed below.
// NOTE(review): the end of this statement and the closing brace are not
// visible in this excerpt.
1811void ConnectOp::getCanonicalizationPatterns(RewritePatternSet &results,
1812 MLIRContext *context) {
1813 results.insert<patterns::ConnectExtension, patterns::ConnectSameType>(
// NOTE(review): the body of MatchingConnectOp::canonicalize is not visible in
// this excerpt.
1817LogicalResult MatchingConnectOp::canonicalize(MatchingConnectOp op,
1818 PatternRewriter &rewriter) {
// Fragment of a separate helper (its header is not visible). It scans the
// users of `value` for AttachOps other than `dominatedAttach` and tests
// whether one appears earlier in the same block (the resulting action is not
// visible).
1835 for (
auto *user : value.getUsers()) {
1836 auto attach = dyn_cast<AttachOp>(user);
1837 if (!attach || attach == dominatedAttach)
1839 if (attach->isBeforeInBlock(dominatedAttach))
1845LogicalResult AttachOp::canonicalize(AttachOp op, PatternRewriter &rewriter) {
1847 if (op.getNumOperands() <= 1) {
1848 rewriter.eraseOp(op);
1852 for (
auto operand : op.getOperands()) {
1859 SmallVector<Value> newOperands(op.getOperands());
1860 for (
auto newOperand : attach.getOperands())
1861 if (newOperand != operand)
1862 newOperands.push_back(newOperand);
1863 rewriter.create<AttachOp>(op->getLoc(), newOperands);
1864 rewriter.eraseOp(attach);
1865 rewriter.eraseOp(op);
1873 if (
auto wire = dyn_cast_or_null<WireOp>(operand.getDefiningOp())) {
1874 if (!
hasDontTouch(wire.getOperation()) && wire->hasOneUse() &&
1875 !wire.isForceable()) {
1876 SmallVector<Value> newOperands;
1877 for (
auto newOperand : op.getOperands())
1878 if (newOperand != operand)
1879 newOperands.push_back(newOperand);
1881 rewriter.create<AttachOp>(op->getLoc(), newOperands);
1882 rewriter.eraseOp(op);
1883 rewriter.eraseOp(wire);
1894 assert(llvm::hasSingleElement(region) &&
"expected single-region block");
1895 rewriter.inlineBlockBefore(®ion.front(), op, {});
1898LogicalResult WhenOp::canonicalize(WhenOp op, PatternRewriter &rewriter) {
1899 if (
auto constant = op.getCondition().getDefiningOp<firrtl::ConstantOp>()) {
1900 if (constant.getValue().isAllOnes())
1902 else if (op.hasElseRegion() && !op.getElseRegion().empty())
1905 rewriter.eraseOp(op);
1911 if (!op.getThenBlock().empty() && op.hasElseRegion() &&
1912 op.getElseBlock().empty()) {
1913 rewriter.eraseBlock(&op.getElseBlock());
1920 if (!op.getThenBlock().empty())
1924 if (!op.hasElseRegion() || op.getElseBlock().empty()) {
1925 rewriter.eraseOp(op);
1934struct FoldNodeName :
public mlir::RewritePattern {
1935 FoldNodeName(MLIRContext *context)
1936 : RewritePattern(NodeOp::getOperationName(), 0, context) {}
1937 LogicalResult matchAndRewrite(Operation *op,
1938 PatternRewriter &rewriter)
const override {
1939 auto node = cast<NodeOp>(op);
1940 auto name = node.getNameAttr();
1941 if (!node.hasDroppableName() || node.getInnerSym() ||
1944 auto *newOp = node.getInput().getDefiningOp();
1946 if (newOp && !isa<InstanceOp>(newOp))
1948 rewriter.replaceOp(node, node.getInput());
1954struct NodeBypass :
public mlir::RewritePattern {
1955 NodeBypass(MLIRContext *context)
1956 : RewritePattern(NodeOp::getOperationName(), 0, context) {}
1957 LogicalResult matchAndRewrite(Operation *op,
1958 PatternRewriter &rewriter)
const override {
1959 auto node = cast<NodeOp>(op);
1961 node.use_empty() || node.isForceable())
1963 rewriter.replaceAllUsesWith(node.getResult(), node.getInput());
1970template <
typename OpTy>
1972 PatternRewriter &rewriter) {
1973 if (!op.isForceable() || !op.getDataRef().use_empty())
1981LogicalResult NodeOp::fold(FoldAdaptor adaptor,
1982 SmallVectorImpl<OpFoldResult> &results) {
1991 if (!adaptor.getInput())
1994 results.push_back(adaptor.getInput());
1998void NodeOp::getCanonicalizationPatterns(RewritePatternSet &results,
1999 MLIRContext *context) {
2000 results.insert<FoldNodeName>(context);
2001 results.add(demoteForceableIfUnused<NodeOp>);
2007struct AggOneShot :
public mlir::RewritePattern {
2008 AggOneShot(StringRef name, uint32_t weight, MLIRContext *context)
2009 : RewritePattern(name, 0, context) {}
2011 SmallVector<Value> getCompleteWrite(Operation *lhs)
const {
2012 auto lhsTy = lhs->getResult(0).getType();
2013 if (!type_isa<BundleType, FVectorType>(lhsTy))
2016 DenseMap<uint32_t, Value> fields;
2017 for (Operation *user : lhs->getResult(0).getUsers()) {
2018 if (user->getParentOp() != lhs->getParentOp())
2020 if (
auto aConnect = dyn_cast<MatchingConnectOp>(user)) {
2021 if (aConnect.getDest() == lhs->getResult(0))
2023 }
else if (
auto subField = dyn_cast<SubfieldOp>(user)) {
2024 for (Operation *subuser : subField.getResult().getUsers()) {
2025 if (
auto aConnect = dyn_cast<MatchingConnectOp>(subuser)) {
2026 if (aConnect.getDest() == subField) {
2027 if (subuser->getParentOp() != lhs->getParentOp())
2029 if (fields.count(subField.getFieldIndex()))
2031 fields[subField.getFieldIndex()] = aConnect.getSrc();
2037 }
else if (
auto subIndex = dyn_cast<SubindexOp>(user)) {
2038 for (Operation *subuser : subIndex.getResult().getUsers()) {
2039 if (
auto aConnect = dyn_cast<MatchingConnectOp>(subuser)) {
2040 if (aConnect.getDest() == subIndex) {
2041 if (subuser->getParentOp() != lhs->getParentOp())
2043 if (fields.count(subIndex.getIndex()))
2045 fields[subIndex.getIndex()] = aConnect.getSrc();
2056 SmallVector<Value> values;
2057 uint32_t total = type_isa<BundleType>(lhsTy)
2058 ? type_cast<BundleType>(lhsTy).getNumElements()
2059 : type_cast<FVectorType>(lhsTy).getNumElements();
2060 for (uint32_t i = 0; i < total; ++i) {
2061 if (!fields.count(i))
2063 values.push_back(fields[i]);
2068 LogicalResult matchAndRewrite(Operation *op,
2069 PatternRewriter &rewriter)
const override {
2070 auto values = getCompleteWrite(op);
2073 rewriter.setInsertionPointToEnd(op->getBlock());
2074 auto dest = op->getResult(0);
2075 auto destType = dest.getType();
2078 if (!type_cast<FIRRTLBaseType>(destType).isPassive())
2081 Value newVal = type_isa<BundleType>(destType)
2082 ? rewriter.createOrFold<BundleCreateOp>(op->getLoc(),
2084 : rewriter.createOrFold<VectorCreateOp>(
2085 op->getLoc(), destType, values);
2086 rewriter.createOrFold<MatchingConnectOp>(op->getLoc(), dest, newVal);
2087 for (Operation *user : dest.getUsers()) {
2088 if (
auto subIndex = dyn_cast<SubindexOp>(user)) {
2089 for (Operation *subuser :
2090 llvm::make_early_inc_range(subIndex.getResult().getUsers()))
2091 if (auto aConnect = dyn_cast<MatchingConnectOp>(subuser))
2092 if (aConnect.getDest() == subIndex)
2093 rewriter.eraseOp(aConnect);
2094 }
else if (
auto subField = dyn_cast<SubfieldOp>(user)) {
2095 for (Operation *subuser :
2096 llvm::make_early_inc_range(subField.getResult().getUsers()))
2097 if (auto aConnect = dyn_cast<MatchingConnectOp>(subuser))
2098 if (aConnect.getDest() == subField)
2099 rewriter.eraseOp(aConnect);
2106struct WireAggOneShot :
public AggOneShot {
2107 WireAggOneShot(MLIRContext *context)
2108 : AggOneShot(WireOp::getOperationName(), 0, context) {}
2110struct SubindexAggOneShot :
public AggOneShot {
2111 SubindexAggOneShot(MLIRContext *context)
2112 : AggOneShot(SubindexOp::getOperationName(), 0, context) {}
2114struct SubfieldAggOneShot :
public AggOneShot {
2115 SubfieldAggOneShot(MLIRContext *context)
2116 : AggOneShot(SubfieldOp::getOperationName(), 0, context) {}
2120void WireOp::getCanonicalizationPatterns(RewritePatternSet &results,
2121 MLIRContext *context) {
2122 results.insert<WireAggOneShot>(context);
2123 results.add(demoteForceableIfUnused<WireOp>);
2126void SubindexOp::getCanonicalizationPatterns(RewritePatternSet &results,
2127 MLIRContext *context) {
2128 results.insert<SubindexAggOneShot>(context);
2131OpFoldResult SubindexOp::fold(FoldAdaptor adaptor) {
2132 auto attr = dyn_cast_or_null<ArrayAttr>(adaptor.getInput());
2135 return attr[getIndex()];
2138OpFoldResult SubfieldOp::fold(FoldAdaptor adaptor) {
2139 auto attr = dyn_cast_or_null<ArrayAttr>(adaptor.getInput());
2142 auto index = getFieldIndex();
2146void SubfieldOp::getCanonicalizationPatterns(RewritePatternSet &results,
2147 MLIRContext *context) {
2148 results.insert<SubfieldAggOneShot>(context);
2152 ArrayRef<Attribute> operands) {
2153 for (
auto operand : operands)
2156 return ArrayAttr::get(context, operands);
2159OpFoldResult BundleCreateOp::fold(FoldAdaptor adaptor) {
2162 if (getNumOperands() > 0)
2163 if (SubfieldOp first = getOperand(0).getDefiningOp<SubfieldOp>())
2164 if (first.getFieldIndex() == 0 &&
2165 first.getInput().getType() == getType() &&
2167 llvm::drop_begin(llvm::enumerate(getOperands())), [&](
auto elem) {
2169 elem.value().
template getDefiningOp<SubfieldOp>();
2170 return subindex && subindex.getInput() == first.getInput() &&
2171 subindex.getFieldIndex() == elem.index();
2173 return first.getInput();
2178OpFoldResult VectorCreateOp::fold(FoldAdaptor adaptor) {
2181 if (getNumOperands() > 0)
2182 if (SubindexOp first = getOperand(0).getDefiningOp<SubindexOp>())
2183 if (first.getIndex() == 0 && first.getInput().getType() == getType() &&
2185 llvm::drop_begin(llvm::enumerate(getOperands())), [&](
auto elem) {
2187 elem.value().
template getDefiningOp<SubindexOp>();
2188 return subindex && subindex.getInput() == first.getInput() &&
2189 subindex.getIndex() == elem.index();
2191 return first.getInput();
2196OpFoldResult UninferredResetCastOp::fold(FoldAdaptor adaptor) {
2197 if (getOperand().getType() == getType())
2198 return getOperand();
2205struct FoldResetMux :
public mlir::RewritePattern {
2206 FoldResetMux(MLIRContext *context)
2207 : RewritePattern(RegResetOp::getOperationName(), 0, context) {}
2208 LogicalResult matchAndRewrite(Operation *op,
2209 PatternRewriter &rewriter)
const override {
2210 auto reg = cast<RegResetOp>(op);
2212 dyn_cast_or_null<ConstantOp>(
reg.getResetValue().getDefiningOp());
2221 auto mux = dyn_cast_or_null<MuxPrimOp>(con.getSrc().getDefiningOp());
2224 auto *high = mux.getHigh().getDefiningOp();
2225 auto *low = mux.getLow().getDefiningOp();
2226 auto constOp = dyn_cast_or_null<ConstantOp>(high);
2228 if (constOp && low != reg)
2230 if (dyn_cast_or_null<ConstantOp>(low) && high == reg)
2231 constOp = dyn_cast<ConstantOp>(low);
2233 if (!constOp || constOp.getType() != reset.getType() ||
2234 constOp.getValue() != reset.getValue())
2238 auto regTy =
reg.getResult().getType();
2239 if (con.getDest().getType() != regTy || con.getSrc().getType() != regTy ||
2240 mux.getHigh().getType() != regTy || mux.getLow().getType() != regTy ||
2241 regTy.getBitWidthOrSentinel() < 0)
2247 if (constOp != &con->getBlock()->front())
2248 constOp->moveBefore(&con->getBlock()->front());
2253 rewriter.eraseOp(con);
2260 if (
auto c = v.getDefiningOp<ConstantOp>())
2261 return c.getValue().isOne();
2262 if (
auto sc = v.getDefiningOp<SpecialConstantOp>())
2263 return sc.getValue();
2273 (void)
dropWrite(rewriter, reg->getResult(0), {});
2274 replaceOpWithNewOpAndCopyName<NodeOp>(
2275 rewriter, reg, reg.getResetValue(), reg.getNameAttr(), reg.getNameKind(),
2276 reg.getAnnotationsAttr(), reg.getInnerSymAttr(), reg.getForceable());
2280void RegResetOp::getCanonicalizationPatterns(RewritePatternSet &results,
2281 MLIRContext *context) {
2282 results.add<patterns::RegResetWithZeroReset, FoldResetMux>(context);
2284 results.add(demoteForceableIfUnused<RegResetOp>);
2289 auto portTy = type_cast<BundleType>(port.getType());
2290 auto fieldIndex = portTy.getElementIndex(name);
2291 assert(fieldIndex &&
"missing field on memory port");
2294 for (
auto *op : port.getUsers()) {
2295 auto portAccess = cast<SubfieldOp>(op);
2296 if (fieldIndex != portAccess.getFieldIndex())
2301 value = conn.getSrc();
2311 auto portConst = value.getDefiningOp<ConstantOp>();
2314 return portConst.getValue().isZero();
2319 auto portTy = type_cast<BundleType>(port.getType());
2320 auto fieldIndex = portTy.getElementIndex(
data);
2321 assert(fieldIndex &&
"missing enable flag on memory port");
2323 for (
auto *op : port.getUsers()) {
2324 auto portAccess = cast<SubfieldOp>(op);
2325 if (fieldIndex != portAccess.getFieldIndex())
2327 if (!portAccess.use_empty())
2336 StringRef name, Value value) {
2337 auto portTy = type_cast<BundleType>(port.getType());
2338 auto fieldIndex = portTy.getElementIndex(name);
2339 assert(fieldIndex &&
"missing field on memory port");
2341 for (
auto *op : llvm::make_early_inc_range(port.getUsers())) {
2342 auto portAccess = cast<SubfieldOp>(op);
2343 if (fieldIndex != portAccess.getFieldIndex())
2345 rewriter.replaceAllUsesWith(portAccess, value);
2346 rewriter.eraseOp(portAccess);
2351static void erasePort(PatternRewriter &rewriter, Value port) {
2354 auto getClock = [&] {
2356 clock = rewriter.create<SpecialConstantOp>(
2357 port.getLoc(), ClockType::get(rewriter.getContext()),
false);
2365 for (
auto *op : port.getUsers()) {
2366 auto subfield = dyn_cast<SubfieldOp>(op);
2368 auto ty = port.getType();
2369 auto reg = rewriter.create<RegOp>(port.getLoc(), ty, getClock());
2370 rewriter.replaceAllUsesWith(port, reg.getResult());
2379 for (
auto *accessOp : llvm::make_early_inc_range(port.getUsers())) {
2380 auto access = cast<SubfieldOp>(accessOp);
2381 for (
auto *user : llvm::make_early_inc_range(access->getUsers())) {
2382 auto connect = dyn_cast<FConnectLike>(user);
2383 if (connect && connect.getDest() == access) {
2384 rewriter.eraseOp(user);
2388 if (access.use_empty()) {
2389 rewriter.eraseOp(access);
2395 auto ty = access.getType();
2396 auto reg = rewriter.create<RegOp>(access.getLoc(), ty, getClock());
2397 rewriter.replaceOp(access, reg.getResult());
2399 assert(port.use_empty() &&
"port should have no remaining uses");
2404struct FoldZeroWidthMemory :
public mlir::RewritePattern {
2405 FoldZeroWidthMemory(MLIRContext *context)
2406 : RewritePattern(MemOp::getOperationName(), 0, context) {}
2407 LogicalResult matchAndRewrite(Operation *op,
2408 PatternRewriter &rewriter)
const override {
2409 MemOp mem = cast<MemOp>(op);
2413 if (!firrtl::type_isa<IntType>(mem.getDataType()) ||
2414 mem.getDataType().getBitWidthOrSentinel() != 0)
2418 for (
auto port : mem.getResults())
2419 for (auto *user : port.getUsers())
2420 if (!isa<SubfieldOp>(user))
2425 for (
auto port : op->getResults()) {
2426 for (
auto *user :
llvm::make_early_inc_range(port.getUsers())) {
2427 SubfieldOp sfop = cast<SubfieldOp>(user);
2428 StringRef fieldName = sfop.getFieldName();
2429 auto wire = replaceOpWithNewOpAndCopyName<WireOp>(
2430 rewriter, sfop, sfop.getResult().getType())
2432 if (fieldName.ends_with(
"data")) {
2434 auto zero = rewriter.create<firrtl::ConstantOp>(
2435 wire.getLoc(), firrtl::type_cast<IntType>(wire.getType()),
2437 rewriter.create<MatchingConnectOp>(wire.getLoc(), wire, zero);
2441 rewriter.eraseOp(op);
2447struct FoldReadOrWriteOnlyMemory :
public mlir::RewritePattern {
2448 FoldReadOrWriteOnlyMemory(MLIRContext *context)
2449 : RewritePattern(MemOp::getOperationName(), 0, context) {}
2450 LogicalResult matchAndRewrite(Operation *op,
2451 PatternRewriter &rewriter)
const override {
2452 MemOp mem = cast<MemOp>(op);
2455 bool isRead =
false, isWritten =
false;
2456 for (
unsigned i = 0; i < mem.getNumResults(); ++i) {
2457 switch (mem.getPortKind(i)) {
2458 case MemOp::PortKind::Read:
2463 case MemOp::PortKind::Write:
2468 case MemOp::PortKind::Debug:
2469 case MemOp::PortKind::ReadWrite:
2472 llvm_unreachable(
"unknown port kind");
2474 assert((!isWritten || !isRead) &&
"memory is in use");
2479 if (isRead && mem.getInit())
2482 for (
auto port : mem.getResults())
2485 rewriter.eraseOp(op);
2491struct FoldUnusedPorts :
public mlir::RewritePattern {
2492 FoldUnusedPorts(MLIRContext *context)
2493 : RewritePattern(MemOp::getOperationName(), 0, context) {}
2494 LogicalResult matchAndRewrite(Operation *op,
2495 PatternRewriter &rewriter)
const override {
2496 MemOp mem = cast<MemOp>(op);
2500 llvm::SmallBitVector deadPorts(mem.getNumResults());
2501 for (
auto [i, port] :
llvm::enumerate(mem.getResults())) {
2503 if (!mem.getPortAnnotation(i).empty())
2507 auto kind = mem.getPortKind(i);
2508 if (kind == MemOp::PortKind::Debug)
2517 if (kind == MemOp::PortKind::Read &&
isPortUnused(port,
"data")) {
2522 if (deadPorts.none())
2526 SmallVector<Type> resultTypes;
2527 SmallVector<StringRef> portNames;
2528 SmallVector<Attribute> portAnnotations;
2529 for (
auto [i, port] :
llvm::enumerate(mem.getResults())) {
2532 resultTypes.push_back(port.getType());
2533 portNames.push_back(mem.getPortName(i));
2534 portAnnotations.push_back(mem.getPortAnnotation(i));
2538 if (!resultTypes.empty())
2539 newOp = rewriter.create<MemOp>(
2540 mem.getLoc(), resultTypes, mem.getReadLatency(),
2541 mem.getWriteLatency(), mem.getDepth(), mem.getRuw(),
2542 rewriter.getStrArrayAttr(portNames), mem.getName(), mem.getNameKind(),
2543 mem.getAnnotations(), rewriter.getArrayAttr(portAnnotations),
2544 mem.getInnerSymAttr(), mem.getInitAttr(), mem.getPrefixAttr());
2547 unsigned nextPort = 0;
2548 for (
auto [i, port] :
llvm::enumerate(mem.getResults())) {
2552 rewriter.replaceAllUsesWith(port, newOp.getResult(nextPort++));
2555 rewriter.eraseOp(op);
2561struct FoldReadWritePorts :
public mlir::RewritePattern {
2562 FoldReadWritePorts(MLIRContext *context)
2563 : RewritePattern(MemOp::getOperationName(), 0, context) {}
2564 LogicalResult matchAndRewrite(Operation *op,
2565 PatternRewriter &rewriter)
const override {
2566 MemOp mem = cast<MemOp>(op);
2571 llvm::SmallBitVector deadReads(mem.getNumResults());
2572 for (
auto [i, port] :
llvm::enumerate(mem.getResults())) {
2573 if (mem.getPortKind(i) != MemOp::PortKind::ReadWrite)
2575 if (!mem.getPortAnnotation(i).empty())
2582 if (deadReads.none())
2585 SmallVector<Type> resultTypes;
2586 SmallVector<StringRef> portNames;
2587 SmallVector<Attribute> portAnnotations;
2588 for (
auto [i, port] :
llvm::enumerate(mem.getResults())) {
2590 resultTypes.push_back(
2591 MemOp::getTypeForPort(mem.getDepth(), mem.getDataType(),
2592 MemOp::PortKind::Write, mem.getMaskBits()));
2594 resultTypes.push_back(port.getType());
2596 portNames.push_back(mem.getPortName(i));
2597 portAnnotations.push_back(mem.getPortAnnotation(i));
2600 auto newOp = rewriter.create<MemOp>(
2601 mem.getLoc(), resultTypes, mem.getReadLatency(), mem.getWriteLatency(),
2602 mem.getDepth(), mem.getRuw(), rewriter.getStrArrayAttr(portNames),
2603 mem.getName(), mem.getNameKind(), mem.getAnnotations(),
2604 rewriter.getArrayAttr(portAnnotations), mem.getInnerSymAttr(),
2605 mem.getInitAttr(), mem.getPrefixAttr());
2607 for (
unsigned i = 0, n = mem.getNumResults(); i < n; ++i) {
2608 auto result = mem.getResult(i);
2609 auto newResult = newOp.getResult(i);
2611 auto resultPortTy = type_cast<BundleType>(result.getType());
2615 auto replace = [&](StringRef toName, StringRef fromName) {
2616 auto fromFieldIndex = resultPortTy.getElementIndex(fromName);
2617 assert(fromFieldIndex &&
"missing enable flag on memory port");
2619 auto toField = rewriter.create<SubfieldOp>(newResult.getLoc(),
2621 for (
auto *op :
llvm::make_early_inc_range(result.getUsers())) {
2622 auto fromField = cast<SubfieldOp>(op);
2623 if (fromFieldIndex != fromField.getFieldIndex())
2625 rewriter.replaceOp(fromField, toField.getResult());
2629 replace(
"addr",
"addr");
2630 replace(
"en",
"en");
2631 replace(
"clk",
"clk");
2632 replace(
"data",
"wdata");
2633 replace(
"mask",
"wmask");
2636 auto wmodeFieldIndex = resultPortTy.getElementIndex(
"wmode");
2637 for (
auto *op :
llvm::make_early_inc_range(result.getUsers())) {
2638 auto wmodeField = cast<SubfieldOp>(op);
2639 if (wmodeFieldIndex != wmodeField.getFieldIndex())
2641 rewriter.replaceOpWithNewOp<WireOp>(wmodeField, wmodeField.getType());
2644 rewriter.replaceAllUsesWith(result, newResult);
2647 rewriter.eraseOp(op);
2653struct FoldUnusedBits :
public mlir::RewritePattern {
2654 FoldUnusedBits(MLIRContext *context)
2655 : RewritePattern(MemOp::getOperationName(), 0, context) {}
2657 LogicalResult matchAndRewrite(Operation *op,
2658 PatternRewriter &rewriter)
const override {
2659 MemOp mem = cast<MemOp>(op);
2664 const auto &summary = mem.getSummary();
2665 if (summary.isMasked || summary.isSeqMem())
2668 auto type = type_dyn_cast<IntType>(mem.getDataType());
2671 auto width = type.getBitWidthOrSentinel();
2675 llvm::SmallBitVector usedBits(width);
2676 DenseMap<unsigned, unsigned> mapping;
2681 SmallVector<BitsPrimOp> readOps;
2682 auto findReadUsers = [&](Value port, StringRef field) -> LogicalResult {
2683 auto portTy = type_cast<BundleType>(port.getType());
2684 auto fieldIndex = portTy.getElementIndex(field);
2685 assert(fieldIndex &&
"missing data port");
2687 for (
auto *op : port.getUsers()) {
2688 auto portAccess = cast<SubfieldOp>(op);
2689 if (fieldIndex != portAccess.getFieldIndex())
2692 for (
auto *user : op->getUsers()) {
2693 auto bits = dyn_cast<BitsPrimOp>(user);
2697 usedBits.set(bits.getLo(), bits.getHi() + 1);
2701 mapping[bits.getLo()] = 0;
2702 readOps.push_back(bits);
2712 SmallVector<MatchingConnectOp> writeOps;
2713 auto findWriteUsers = [&](Value port, StringRef field) -> LogicalResult {
2714 auto portTy = type_cast<BundleType>(port.getType());
2715 auto fieldIndex = portTy.getElementIndex(field);
2716 assert(fieldIndex &&
"missing data port");
2718 for (
auto *op : port.getUsers()) {
2719 auto portAccess = cast<SubfieldOp>(op);
2720 if (fieldIndex != portAccess.getFieldIndex())
2727 writeOps.push_back(conn);
2733 for (
auto [i, port] :
llvm::enumerate(mem.getResults())) {
2735 if (!mem.getPortAnnotation(i).empty())
2738 switch (mem.getPortKind(i)) {
2739 case MemOp::PortKind::Debug:
2742 case MemOp::PortKind::Write:
2743 if (failed(findWriteUsers(port,
"data")))
2746 case MemOp::PortKind::Read:
2747 if (failed(findReadUsers(port,
"data")))
2750 case MemOp::PortKind::ReadWrite:
2751 if (failed(findWriteUsers(port,
"wdata")))
2753 if (failed(findReadUsers(port,
"rdata")))
2757 llvm_unreachable(
"unknown port kind");
2761 if (usedBits.none())
2765 SmallVector<std::pair<unsigned, unsigned>> ranges;
2766 unsigned newWidth = 0;
2767 for (
int i = usedBits.find_first(); 0 <= i && i < width;) {
2768 int e = usedBits.find_next_unset(i);
2771 for (
int idx = i; idx < e; ++idx, ++newWidth) {
2772 if (
auto it = mapping.find(idx); it != mapping.end()) {
2773 it->second = newWidth;
2776 ranges.emplace_back(i, e - 1);
2777 i = e != width ? usedBits.find_next(e) : e;
2781 auto newType =
IntType::get(op->getContext(), type.isSigned(), newWidth);
2782 SmallVector<Type> portTypes;
2783 for (
auto [i, port] :
llvm::enumerate(mem.getResults())) {
2784 portTypes.push_back(
2785 MemOp::getTypeForPort(mem.getDepth(), newType, mem.getPortKind(i)));
2787 auto newMem = rewriter.replaceOpWithNewOp<MemOp>(
2788 mem, portTypes, mem.getReadLatency(), mem.getWriteLatency(),
2789 mem.getDepth(), mem.getRuw(), mem.getPortNames(), mem.getName(),
2790 mem.getNameKind(), mem.getAnnotations(), mem.getPortAnnotations(),
2791 mem.getInnerSymAttr(), mem.getInitAttr(), mem.getPrefixAttr());
2794 auto rewriteSubfield = [&](Value port, StringRef field) {
2795 auto portTy = type_cast<BundleType>(port.getType());
2796 auto fieldIndex = portTy.getElementIndex(field);
2797 assert(fieldIndex &&
"missing data port");
2799 rewriter.setInsertionPointAfter(newMem);
2800 auto newPortAccess =
2801 rewriter.create<SubfieldOp>(port.getLoc(), port, field);
2803 for (
auto *op :
llvm::make_early_inc_range(port.getUsers())) {
2804 auto portAccess = cast<SubfieldOp>(op);
2805 if (op == newPortAccess || fieldIndex != portAccess.getFieldIndex())
2807 rewriter.replaceOp(portAccess, newPortAccess.getResult());
2812 for (
auto [i, port] :
llvm::enumerate(newMem.getResults())) {
2813 switch (newMem.getPortKind(i)) {
2814 case MemOp::PortKind::Debug:
2815 llvm_unreachable(
"cannot rewrite debug port");
2816 case MemOp::PortKind::Write:
2817 rewriteSubfield(port,
"data");
2819 case MemOp::PortKind::Read:
2820 rewriteSubfield(port,
"data");
2822 case MemOp::PortKind::ReadWrite:
2823 rewriteSubfield(port,
"rdata");
2824 rewriteSubfield(port,
"wdata");
2827 llvm_unreachable(
"unknown port kind");
2831 for (
auto readOp : readOps) {
2832 rewriter.setInsertionPointAfter(readOp);
2833 auto it = mapping.find(readOp.getLo());
2834 assert(it != mapping.end() &&
"bit op mapping not found");
2837 auto newReadValue = rewriter.createOrFold<BitsPrimOp>(
2838 readOp.getLoc(), readOp.getInput(),
2839 readOp.getHi() - readOp.getLo() + it->second, it->second);
2840 rewriter.replaceAllUsesWith(readOp, newReadValue);
2841 rewriter.eraseOp(readOp);
2845 for (
auto writeOp : writeOps) {
2846 Value source = writeOp.getSrc();
2847 rewriter.setInsertionPoint(writeOp);
2850 for (
auto &[start, end] : ranges) {
2851 Value slice = rewriter.createOrFold<BitsPrimOp>(writeOp.getLoc(),
2852 source,
end, start);
2854 catOfSlices = rewriter.createOrFold<CatPrimOp>(writeOp.getLoc(),
2855 slice, catOfSlices);
2857 catOfSlices = slice;
2865 if (type.isSigned())
2867 rewriter.createOrFold<AsSIntPrimOp>(writeOp.getLoc(), catOfSlices);
2869 rewriter.replaceOpWithNewOp<MatchingConnectOp>(writeOp, writeOp.getDest(),
2878struct FoldRegMems :
public mlir::RewritePattern {
2879 FoldRegMems(MLIRContext *context)
2880 : RewritePattern(MemOp::getOperationName(), 0, context) {}
2881 LogicalResult matchAndRewrite(Operation *op,
2882 PatternRewriter &rewriter)
const override {
2883 MemOp mem = cast<MemOp>(op);
2884 const FirMemory &info = mem.getSummary();
2888 auto ty = mem.getDataType();
2889 auto loc = mem.getLoc();
2890 auto *block = mem->getBlock();
2894 SmallPtrSet<Operation *, 8> connects;
2895 SmallVector<SubfieldOp> portAccesses;
2896 for (
auto [i, port] :
llvm::enumerate(mem.getResults())) {
2897 if (!mem.getPortAnnotation(i).empty())
2900 auto collect = [&, port = port](ArrayRef<StringRef> fields) {
2901 auto portTy = type_cast<BundleType>(port.getType());
2902 for (
auto field : fields) {
2903 auto fieldIndex = portTy.getElementIndex(field);
2904 assert(fieldIndex &&
"missing field on memory port");
2906 for (
auto *op : port.getUsers()) {
2907 auto portAccess = cast<SubfieldOp>(op);
2908 if (fieldIndex != portAccess.getFieldIndex())
2910 portAccesses.push_back(portAccess);
2911 for (
auto *user : portAccess->getUsers()) {
2912 auto conn = dyn_cast<FConnectLike>(user);
2915 connects.insert(conn);
2922 switch (mem.getPortKind(i)) {
2923 case MemOp::PortKind::Debug:
2925 case MemOp::PortKind::Read:
2926 if (failed(collect({
"clk",
"en",
"addr"})))
2929 case MemOp::PortKind::Write:
2930 if (failed(collect({
"clk",
"en",
"addr",
"data",
"mask"})))
2933 case MemOp::PortKind::ReadWrite:
2934 if (failed(collect({
"clk",
"en",
"addr",
"wmode",
"wdata",
"wmask"})))
2940 if (!portClock || (clock && portClock != clock))
2946 rewriter.setInsertionPointAfter(mem);
2947 auto memWire = rewriter.create<WireOp>(loc, ty).getResult();
2953 rewriter.setInsertionPointToEnd(block);
2955 rewriter.create<RegOp>(loc, ty, clock, mem.getName()).getResult();
2958 rewriter.create<MatchingConnectOp>(loc, memWire, memReg);
2962 auto pipeline = [&](Value value, Value clock,
const Twine &name,
2964 for (
unsigned i = 0; i < latency; ++i) {
2965 std::string regName;
2967 llvm::raw_string_ostream os(regName);
2968 os << mem.getName() <<
"_" << name <<
"_" << i;
2971 .create<RegOp>(mem.getLoc(), value.getType(), clock,
2972 rewriter.getStringAttr(regName))
2974 rewriter.create<MatchingConnectOp>(value.getLoc(),
reg, value);
2985 SmallVector<std::tuple<Value, Value, Value>> writes;
2986 for (
auto [i, port] :
llvm::enumerate(mem.getResults())) {
2988 StringRef name = mem.getPortName(i);
2990 auto portPipeline = [&, port = port](StringRef field,
unsigned stages) {
2993 return pipeline(value, portClock, name +
"_" + field, stages);
2996 switch (mem.getPortKind(i)) {
2997 case MemOp::PortKind::Debug:
2998 llvm_unreachable(
"unknown port kind");
2999 case MemOp::PortKind::Read: {
3007 case MemOp::PortKind::Write: {
3008 auto data = portPipeline(
"data", writeStages);
3009 auto en = portPipeline(
"en", writeStages);
3010 auto mask = portPipeline(
"mask", writeStages);
3014 case MemOp::PortKind::ReadWrite: {
3019 auto wdata = portPipeline(
"wdata", writeStages);
3020 auto wmask = portPipeline(
"wmask", writeStages);
3025 auto wen = rewriter.create<AndPrimOp>(port.getLoc(),
en,
wmode);
3027 pipeline(wen, portClock, name +
"_wen", writeStages);
3028 writes.emplace_back(
wdata, wenPipelined,
wmask);
3035 Value next = memReg;
3041 Location loc = mem.getLoc();
3043 for (
unsigned i = 0; i < info.
maskBits; ++i) {
3044 unsigned hi = (i + 1) * maskGran - 1;
3045 unsigned lo = i * maskGran;
3047 auto dataPart = rewriter.createOrFold<BitsPrimOp>(loc,
data, hi, lo);
3048 auto nextPart = rewriter.createOrFold<BitsPrimOp>(loc, next, hi, lo);
3049 auto bit = rewriter.createOrFold<BitsPrimOp>(loc,
mask, i, i);
3050 auto chunk = rewriter.create<MuxPrimOp>(loc, bit, dataPart, nextPart);
3053 masked = rewriter.create<CatPrimOp>(loc, chunk, masked);
3059 next = rewriter.create<MuxPrimOp>(next.getLoc(),
en, masked, next);
3061 Value typedNext = rewriter.createOrFold<BitCastOp>(next.getLoc(), ty, next);
3062 rewriter.create<MatchingConnectOp>(memReg.getLoc(), memReg, typedNext);
3065 for (Operation *conn : connects)
3066 rewriter.eraseOp(conn);
3067 for (
auto portAccess : portAccesses)
3068 rewriter.eraseOp(portAccess);
3069 rewriter.eraseOp(mem);
3076void MemOp::getCanonicalizationPatterns(RewritePatternSet &results,
3077 MLIRContext *context) {
3079 .insert<FoldZeroWidthMemory, FoldReadOrWriteOnlyMemory,
3080 FoldReadWritePorts, FoldUnusedPorts, FoldUnusedBits, FoldRegMems>(
3100 auto mux = dyn_cast_or_null<MuxPrimOp>(con.getSrc().getDefiningOp());
3103 auto *high = mux.getHigh().getDefiningOp();
3104 auto *low = mux.getLow().getDefiningOp();
3106 auto constOp = dyn_cast_or_null<ConstantOp>(high);
3113 bool constReg =
false;
3115 if (constOp && low == reg)
3117 else if (dyn_cast_or_null<ConstantOp>(low) && high == reg) {
3119 constOp = dyn_cast<ConstantOp>(low);
3126 if (!isa<BlockArgument>(mux.getSel()) && !constReg)
3130 auto regTy = reg.getResult().getType();
3131 if (con.getDest().getType() != regTy || con.getSrc().getType() != regTy ||
3132 mux.getHigh().getType() != regTy || mux.getLow().getType() != regTy ||
3133 regTy.getBitWidthOrSentinel() < 0)
3139 if (constOp != &con->getBlock()->front())
3140 constOp->moveBefore(&con->getBlock()->front());
3143 SmallVector<NamedAttribute, 2> attrs(reg->getDialectAttrs());
3144 auto newReg = replaceOpWithNewOpAndCopyName<RegResetOp>(
3145 rewriter, reg, reg.getResult().getType(), reg.getClockVal(),
3146 mux.getSel(), mux.getHigh(), reg.getNameAttr(), reg.getNameKindAttr(),
3147 reg.getAnnotationsAttr(), reg.getInnerSymAttr(),
3148 reg.getForceableAttr());
3149 newReg->setDialectAttrs(attrs);
3151 auto pt = rewriter.saveInsertionPoint();
3152 rewriter.setInsertionPoint(con);
3153 auto v = constReg ? (Value)constOp.getResult() : (Value)mux.getLow();
3154 replaceOpWithNewOpAndCopyName<ConnectOp>(rewriter, con, con.getDest(), v);
3155 rewriter.restoreInsertionPoint(pt);
3159LogicalResult RegOp::canonicalize(RegOp op, PatternRewriter &rewriter) {
3160 if (!
hasDontTouch(op.getOperation()) && !op.isForceable() &&
3176 PatternRewriter &rewriter,
3179 if (
auto constant = enable.getDefiningOp<firrtl::ConstantOp>()) {
3180 if (constant.getValue().isZero()) {
3181 rewriter.eraseOp(op);
3187 if (
auto constant = predicate.getDefiningOp<firrtl::ConstantOp>()) {
3188 if (constant.getValue().isZero() == eraseIfZero) {
3189 rewriter.eraseOp(op);
3197template <
class Op,
bool EraseIfZero = false>
3199 PatternRewriter &rewriter) {
3204void AssertOp::getCanonicalizationPatterns(RewritePatternSet &results,
3205 MLIRContext *context) {
3206 results.add(canonicalizeImmediateVerifOp<AssertOp>);
3209void AssumeOp::getCanonicalizationPatterns(RewritePatternSet &results,
3210 MLIRContext *context) {
3211 results.add(canonicalizeImmediateVerifOp<AssumeOp>);
3214void UnclockedAssumeIntrinsicOp::getCanonicalizationPatterns(
3215 RewritePatternSet &results, MLIRContext *context) {
3216 results.add(canonicalizeImmediateVerifOp<UnclockedAssumeIntrinsicOp>);
3219void CoverOp::getCanonicalizationPatterns(RewritePatternSet &results,
3220 MLIRContext *context) {
3221 results.add(canonicalizeImmediateVerifOp<CoverOp, /* EraseIfZero = */ true>);
3228LogicalResult InvalidValueOp::canonicalize(InvalidValueOp op,
3229 PatternRewriter &rewriter) {
3231 if (op.use_empty()) {
3232 rewriter.eraseOp(op);
3239 if (op->hasOneUse() &&
3240 (isa<BitsPrimOp, HeadPrimOp, ShrPrimOp, TailPrimOp, SubfieldOp,
3241 SubindexOp, AsSIntPrimOp, AsUIntPrimOp, NotPrimOp, BitCastOp>(
3242 *op->user_begin()) ||
3243 (isa<CvtPrimOp>(*op->user_begin()) &&
3244 type_isa<SIntType>(op->user_begin()->getOperand(0).getType())) ||
3245 (isa<AndRPrimOp, XorRPrimOp, OrRPrimOp>(*op->user_begin()) &&
3246 type_cast<FIRRTLBaseType>(op->user_begin()->getOperand(0).getType())
3247 .getBitWidthOrSentinel() > 0))) {
3248 auto *modop = *op->user_begin();
3249 auto inv = rewriter.create<InvalidValueOp>(op.getLoc(),
3250 modop->getResult(0).getType());
3251 rewriter.replaceAllOpUsesWith(modop, inv);
3252 rewriter.eraseOp(modop);
3253 rewriter.eraseOp(op);
3259OpFoldResult InvalidValueOp::fold(FoldAdaptor adaptor) {
3260 if (getType().getBitWidthOrSentinel() == 0 && isa<IntType>(getType()))
3261 return getIntAttr(getType(), APInt(0, 0, isa<SIntType>(getType())));
3269OpFoldResult ClockGateIntrinsicOp::fold(FoldAdaptor adaptor) {
3278 return BoolAttr::get(getContext(),
false);
3282 return BoolAttr::get(getContext(),
false);
3287LogicalResult ClockGateIntrinsicOp::canonicalize(ClockGateIntrinsicOp op,
3288 PatternRewriter &rewriter) {
3290 if (
auto testEnable = op.getTestEnable()) {
3291 if (
auto constOp = testEnable.getDefiningOp<ConstantOp>()) {
3292 if (constOp.getValue().isZero()) {
3293 rewriter.modifyOpInPlace(op,
3294 [&] { op.getTestEnableMutable().clear(); });
3310 auto forceable = op.getRef().getDefiningOp<Forceable>();
3311 if (!forceable || !forceable.isForceable() ||
3312 op.getRef() != forceable.getDataRef() ||
3313 op.getType() != forceable.getDataType())
3315 rewriter.replaceAllUsesWith(op, forceable.getData());
3319void RefResolveOp::getCanonicalizationPatterns(RewritePatternSet &results,
3320 MLIRContext *context) {
3321 results.insert<patterns::RefResolveOfRefSend>(context);
3325OpFoldResult RefCastOp::fold(FoldAdaptor adaptor) {
3327 if (getInput().getType() == getType())
3333 auto constOp = operand.getDefiningOp<ConstantOp>();
3334 return constOp && constOp.getValue().isZero();
3337template <
typename Op>
3340 rewriter.eraseOp(op);
3346void RefForceOp::getCanonicalizationPatterns(RewritePatternSet &results,
3347 MLIRContext *context) {
3348 results.add(eraseIfPredFalse<RefForceOp>);
3350void RefForceInitialOp::getCanonicalizationPatterns(RewritePatternSet &results,
3351 MLIRContext *context) {
3352 results.add(eraseIfPredFalse<RefForceInitialOp>);
3354void RefReleaseOp::getCanonicalizationPatterns(RewritePatternSet &results,
3355 MLIRContext *context) {
3356 results.add(eraseIfPredFalse<RefReleaseOp>);
3358void RefReleaseInitialOp::getCanonicalizationPatterns(
3359 RewritePatternSet &results, MLIRContext *context) {
3360 results.add(eraseIfPredFalse<RefReleaseInitialOp>);
3367OpFoldResult HasBeenResetIntrinsicOp::fold(FoldAdaptor adaptor) {
3373 if (adaptor.getReset())
3378 if (
isUInt1(getReset().getType()) && adaptor.getClock())
3391 [&](
auto ty) ->
bool {
return isTypeEmpty(ty.getElementType()); })
3392 .Case<BundleType>([&](
auto ty) ->
bool {
3393 for (
auto elem : ty.getElements())
3398 .Case<IntType>([&](
auto ty) {
return ty.getWidth() == 0; })
3399 .Default([](
auto) ->
bool {
return false; });
3402LogicalResult FPGAProbeIntrinsicOp::canonicalize(FPGAProbeIntrinsicOp op,
3403 PatternRewriter &rewriter) {
3404 auto firrtlTy = type_dyn_cast<FIRRTLType>(op.getInput().getType());
3411 rewriter.eraseOp(op);
3419LogicalResult LayerBlockOp::canonicalize(LayerBlockOp op,
3420 PatternRewriter &rewriter) {
3423 if (op.getBody()->empty()) {
3424 rewriter.eraseOp(op);
assert(baseType &&"element must be base type")
static bool hasKnownWidthIntTypes(Operation *op)
Return true if this operation's operands and results all have a known width.
static LogicalResult canonicalizeImmediateVerifOp(Op op, PatternRewriter &rewriter)
static bool isDefinedByOneConstantOp(Value v)
static Attribute collectFields(MLIRContext *context, ArrayRef< Attribute > operands)
static LogicalResult canonicalizeSingleSetConnect(MatchingConnectOp op, PatternRewriter &rewriter)
static void erasePort(PatternRewriter &rewriter, Value port)
static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op, Region &region)
Replaces the given op with the contents of the given single-block region.
static std::optional< APSInt > getExtendedConstant(Value operand, Attribute constant, int32_t destWidth)
Implicitly replace the operand to a constant folding operation with a const 0 in case the operand is ...
static Value getPortFieldValue(Value port, StringRef name)
static AttachOp getDominatingAttachUser(Value value, AttachOp dominatedAttach)
If the specified value has an AttachOp user that strictly dominates `dominatedAttach`, then return it.
static OpTy replaceOpWithNewOpAndCopyName(PatternRewriter &rewriter, Operation *op, Args &&...args)
A wrapper of PatternRewriter::replaceOpWithNewOp to propagate "name" attribute.
static void updateName(PatternRewriter &rewriter, Operation *op, StringAttr name)
Set the name of an op based on the best of two names: The current name, and the name passed in.
static bool isTypeEmpty(FIRRTLType type)
static bool isUInt1(Type type)
Return true if this type is a 1-bit UInt.
static LogicalResult demoteForceableIfUnused(OpTy op, PatternRewriter &rewriter)
static bool isPortDisabled(Value port)
static LogicalResult eraseIfZeroOrNotZero(Operation *op, Value predicate, Value enable, PatternRewriter &rewriter, bool eraseIfZero)
static APInt getMaxSignedValue(unsigned bitWidth)
Get the largest signed value of a given bit width.
static Value dropWrite(PatternRewriter &rewriter, OpResult old, Value passthrough)
static LogicalResult canonicalizePrimOp(Operation *op, PatternRewriter &rewriter, const function_ref< OpFoldResult(ArrayRef< Attribute >)> &canonicalize)
Applies the canonicalization function canonicalize to the given operation.
static void replaceWithBits(Operation *op, Value value, unsigned hiBit, unsigned loBit, PatternRewriter &rewriter)
Replace the specified operation with a 'bits' op from the specified hi/lo bits.
static LogicalResult canonicalizeRegResetWithOneReset(RegResetOp reg, PatternRewriter &rewriter)
static LogicalResult eraseIfPredFalse(Op op, PatternRewriter &rewriter)
static OpFoldResult foldMux(OpTy op, typename OpTy::FoldAdaptor adaptor)
static APInt getMaxUnsignedValue(unsigned bitWidth)
Get the largest unsigned value of a given bit width.
static std::optional< APSInt > getConstant(Attribute operand)
Determine the value of a constant operand for the sake of constant folding.
static void replacePortField(PatternRewriter &rewriter, Value port, StringRef name, Value value)
BinOpKind
This is the policy for folding, which depends on the sort of operator we're processing.
static bool isPortUnused(Value port, StringRef data)
static LogicalResult canonicalizeRefResolveOfForceable(RefResolveOp op, PatternRewriter &rewriter)
static Attribute constFoldFIRRTLBinaryOp(Operation *op, ArrayRef< Attribute > operands, BinOpKind opKind, const function_ref< APInt(const APSInt &, const APSInt &)> &calculate)
Applies the constant folding function calculate to the given operands.
static APInt getMinSignedValue(unsigned bitWidth)
Get the smallest signed value of a given bit width.
static LogicalResult foldHiddenReset(RegOp reg, PatternRewriter &rewriter)
static Value moveNameHint(OpResult old, Value passthrough)
static void replaceOpAndCopyName(PatternRewriter &rewriter, Operation *op, Value newValue)
A wrapper of PatternRewriter::replaceOp to propagate "name" attribute.
static InstancePath empty
This class provides a read-only projection over the MLIR attributes that represent a set of annotatio...
int32_t getBitWidthOrSentinel()
If this is an IntType, AnalogType, or sugar type for a single bit (Clock, Reset, etc) then return the...
This class implements the same functionality as TypeSwitch except that it uses firrtl::type_dyn_cast ...
FIRRTLTypeSwitch< T, ResultT > & Case(CallableT &&caseFn)
Add a case on the given type.
This is the common base class between SIntType and UIntType.
int32_t getWidthOrSentinel() const
Return the width of this type, or -1 if it has none specified.
static IntType get(MLIRContext *context, bool isSigned, int32_t widthOrSentinel=-1, bool isConst=false)
Return an SIntType or UIntType with the specified signedness, width, and constness.
bool hasWidth() const
Return true if this integer type has a known width.
std::optional< int32_t > getWidth() const
Return an optional containing the width, if the width is known (or empty if width is unknown).
uint64_t getWidth(Type t)
Forceable replaceWithNewForceability(Forceable op, bool forceable, ::mlir::PatternRewriter *rewriter=nullptr)
Replace a Forceable op with an equivalent op, changing whether it is forceable.
bool areAnonymousTypesEquivalent(FIRRTLBaseType lhs, FIRRTLBaseType rhs)
Return true if anonymous types of given arguments are equivalent by pointer comparison.
IntegerAttr getIntAttr(Type type, const APInt &value)
Utility for generating a constant attribute.
bool hasDontTouch(Value value)
Check whether a block argument ("port") or the operation defining a value has a DontTouch annotation,...
bool hasDroppableName(Operation *op)
Return true if the name is droppable.
MatchingConnectOp getSingleConnectUserOf(Value value)
Scan all the uses of the specified value, checking to see if there is exactly one connect that has th...
std::optional< int64_t > getBitWidth(FIRRTLBaseType type, bool ignoreFlip=false)
IntegerAttr getIntZerosAttr(Type type)
Utility for generating a constant zero attribute.
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
APSInt extOrTruncZeroWidth(APSInt value, unsigned width)
A safe version of APSInt::extOrTrunc that will NOT assert on zero-width signed APSInts.
APInt sextZeroWidth(APInt value, unsigned width)
A safe version of APInt::sext that will NOT assert on zero-width signed APSInts.
StringRef chooseName(StringRef a, StringRef b)
Choose a good name for an item from two options.
static bool isConstantZero(Attribute operand)
Determine whether a constant operand is a zero value.
static bool isConstantOne(Attribute operand)
Determine whether a constant operand is a one value.
reg(value, clock, reset=None, reset_value=None, name=None, sym_name=None)