20#include "mlir/IR/Matchers.h"
21#include "mlir/IR/PatternMatch.h"
22#include "llvm/ADT/APSInt.h"
23#include "llvm/ADT/STLExtras.h"
24#include "llvm/ADT/SmallPtrSet.h"
25#include "llvm/ADT/StringExtras.h"
26#include "llvm/ADT/TypeSwitch.h"
29using namespace firrtl;
33static Value
dropWrite(PatternRewriter &rewriter, OpResult old,
35 SmallPtrSet<Operation *, 8> users;
36 for (
auto *user : old.getUsers())
38 for (Operation *user : users)
39 if (
auto connect = dyn_cast<FConnectLike>(user))
40 if (connect.getDest() == old)
41 rewriter.eraseOp(user);
51 if (op->getNumRegions() != 0)
53 return mlir::isPure(op) || isa<NodeOp, WireOp>(op);
61 Operation *op = passthrough.getDefiningOp();
64 assert(op &&
"passthrough must be an operation");
65 Operation *oldOp = old.getOwner();
66 auto name = oldOp->getAttrOfType<StringAttr>(
"name");
68 op->setAttr(
"name", name);
76#include "circt/Dialect/FIRRTL/FIRRTLCanonicalization.h.inc"
84 auto resultType = type_cast<IntType>(op->getResult(0).getType());
85 if (!resultType.hasWidth())
87 for (Value operand : op->getOperands())
88 if (!type_cast<IntType>(operand.getType()).hasWidth())
95 auto t = type_dyn_cast<UIntType>(type);
96 if (!t || !t.hasWidth() || t.getWidth() != 1)
103static void updateName(PatternRewriter &rewriter, Operation *op,
108 assert((!isa<InstanceOp, RegOp, RegResetOp>(op)) &&
"Should never rename");
109 auto newName = name.getValue();
110 auto newOpName = op->getAttrOfType<StringAttr>(
"name");
113 newName =
chooseName(newOpName.getValue(), name.getValue());
115 if (!newOpName || newOpName.getValue() != newName)
116 rewriter.modifyOpInPlace(
117 op, [&] { op->setAttr(
"name", rewriter.getStringAttr(newName)); });
125 if (
auto *newOp = newValue.getDefiningOp()) {
126 auto name = op->getAttrOfType<StringAttr>(
"name");
129 rewriter.replaceOp(op, newValue);
135template <
typename OpTy,
typename... Args>
137 Operation *op, Args &&...args) {
138 auto name = op->getAttrOfType<StringAttr>(
"name");
140 rewriter.replaceOpWithNewOp<OpTy>(op, std::forward<Args>(args)...);
148 if (
auto namableOp = dyn_cast<firrtl::FNamableOp>(op))
149 return namableOp.hasDroppableName();
160static std::optional<APSInt>
162 assert(type_cast<IntType>(operand.getType()) &&
163 "getExtendedConstant is limited to integer types");
170 if (IntegerAttr result = dyn_cast_or_null<IntegerAttr>(constant))
175 if (type_cast<IntType>(operand.getType()).getWidth() == 0)
176 return APSInt(destWidth,
177 type_cast<IntType>(operand.getType()).isUnsigned());
185 if (
auto attr = dyn_cast<BoolAttr>(operand))
186 return APSInt(APInt(1, attr.getValue()));
187 if (
auto attr = dyn_cast<IntegerAttr>(operand))
188 return attr.getAPSInt();
196 return cst->isZero();
223 Operation *op, ArrayRef<Attribute> operands,
BinOpKind opKind,
224 const function_ref<APInt(
const APSInt &,
const APSInt &)> &calculate) {
225 assert(operands.size() == 2 &&
"binary op takes two operands");
228 auto resultType = type_cast<IntType>(op->getResult(0).getType());
229 if (resultType.getWidthOrSentinel() < 0)
233 if (resultType.getWidthOrSentinel() == 0)
234 return getIntAttr(resultType, APInt(0, 0, resultType.isSigned()));
240 type_cast<IntType>(op->getOperand(0).getType()).getWidthOrSentinel();
242 type_cast<IntType>(op->getOperand(1).getType()).getWidthOrSentinel();
243 if (
auto lhs = dyn_cast_or_null<IntegerAttr>(operands[0]))
244 lhsWidth = std::max<int32_t>(lhsWidth, lhs.getValue().getBitWidth());
245 if (
auto rhs = dyn_cast_or_null<IntegerAttr>(operands[1]))
246 rhsWidth = std::max<int32_t>(rhsWidth, rhs.getValue().getBitWidth());
250 int32_t operandWidth;
253 operandWidth = resultType.getWidthOrSentinel();
258 operandWidth = std::max(1, std::max(lhsWidth, rhsWidth));
262 std::max(std::max(lhsWidth, rhsWidth), resultType.getWidthOrSentinel());
273 APInt resultValue = calculate(*lhs, *rhs);
278 resultValue = resultValue.trunc(resultType.getWidthOrSentinel());
280 assert((
unsigned)resultType.getWidthOrSentinel() ==
281 resultValue.getBitWidth());
294 Operation *op, PatternRewriter &rewriter,
295 const function_ref<OpFoldResult(ArrayRef<Attribute>)> &canonicalize) {
297 if (op->getNumResults() != 1)
299 auto type = type_dyn_cast<FIRRTLBaseType>(op->getResult(0).getType());
304 auto width = type.getBitWidthOrSentinel();
309 SmallVector<Attribute, 3> constOperands;
310 constOperands.reserve(op->getNumOperands());
311 for (
auto operand : op->getOperands()) {
313 if (
auto *defOp = operand.getDefiningOp())
314 TypeSwitch<Operation *>(defOp).Case<ConstantOp, SpecialConstantOp>(
315 [&](
auto op) { attr = op.getValueAttr(); });
316 constOperands.push_back(attr);
321 auto result = canonicalize(constOperands);
325 if (
auto cst = dyn_cast<Attribute>(result))
326 resultValue = op->getDialect()
327 ->materializeConstant(rewriter, cst, type, op->getLoc())
330 resultValue = cast<Value>(result);
334 type_cast<FIRRTLBaseType>(resultValue.getType()).getBitWidthOrSentinel())
335 resultValue = PadPrimOp::create(rewriter, op->getLoc(), resultValue, width);
338 if (type_isa<SIntType>(type) && type_isa<UIntType>(resultValue.getType()))
339 resultValue = AsSIntPrimOp::create(rewriter, op->getLoc(), resultValue);
340 else if (type_isa<UIntType>(type) &&
341 type_isa<SIntType>(resultValue.getType()))
342 resultValue = AsUIntPrimOp::create(rewriter, op->getLoc(), resultValue);
344 assert(type == resultValue.getType() &&
"canonicalization changed type");
352 return bitWidth > 0 ? APInt::getMaxValue(bitWidth) : APInt();
358 return bitWidth > 0 ? APInt::getSignedMinValue(bitWidth) : APInt();
364 return bitWidth > 0 ? APInt::getSignedMaxValue(bitWidth) : APInt();
371OpFoldResult ConstantOp::fold(FoldAdaptor adaptor) {
372 assert(adaptor.getOperands().empty() &&
"constant has no operands");
373 return getValueAttr();
376OpFoldResult SpecialConstantOp::fold(FoldAdaptor adaptor) {
377 assert(adaptor.getOperands().empty() &&
"constant has no operands");
378 return getValueAttr();
381OpFoldResult AggregateConstantOp::fold(FoldAdaptor adaptor) {
382 assert(adaptor.getOperands().empty() &&
"constant has no operands");
383 return getFieldsAttr();
386OpFoldResult StringConstantOp::fold(FoldAdaptor adaptor) {
387 assert(adaptor.getOperands().empty() &&
"constant has no operands");
388 return getValueAttr();
391OpFoldResult FIntegerConstantOp::fold(FoldAdaptor adaptor) {
392 assert(adaptor.getOperands().empty() &&
"constant has no operands");
393 return getValueAttr();
396OpFoldResult BoolConstantOp::fold(FoldAdaptor adaptor) {
397 assert(adaptor.getOperands().empty() &&
"constant has no operands");
398 return getValueAttr();
401OpFoldResult DoubleConstantOp::fold(FoldAdaptor adaptor) {
402 assert(adaptor.getOperands().empty() &&
"constant has no operands");
403 return getValueAttr();
410OpFoldResult AddPrimOp::fold(FoldAdaptor adaptor) {
413 [=](
const APSInt &a,
const APSInt &b) { return a + b; });
416void AddPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
418 results.insert<patterns::moveConstAdd, patterns::AddOfZero,
419 patterns::AddOfSelf, patterns::AddOfPad>(
context);
422OpFoldResult SubPrimOp::fold(FoldAdaptor adaptor) {
425 [=](
const APSInt &a,
const APSInt &b) { return a - b; });
428void SubPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
430 results.insert<patterns::SubOfZero, patterns::SubFromZeroSigned,
431 patterns::SubFromZeroUnsigned, patterns::SubOfSelf,
432 patterns::SubOfPadL, patterns::SubOfPadR>(
context);
435OpFoldResult MulPrimOp::fold(FoldAdaptor adaptor) {
447 [=](
const APSInt &a,
const APSInt &b) { return a * b; });
450OpFoldResult DivPrimOp::fold(FoldAdaptor adaptor) {
457 if (getLhs() == getRhs()) {
458 auto width = getType().base().getWidthOrSentinel();
463 return getIntAttr(getType(), APInt(width, 1));
480 if (
auto rhsCst = dyn_cast_or_null<IntegerAttr>(adaptor.getRhs()))
481 if (rhsCst.getValue().isOne() && getLhs().getType() == getType())
486 [=](
const APSInt &a,
const APSInt &b) -> APInt {
489 return APInt(a.getBitWidth(), 0);
493OpFoldResult RemPrimOp::fold(FoldAdaptor adaptor) {
500 if (getLhs() == getRhs())
514 [=](
const APSInt &a,
const APSInt &b) -> APInt {
517 return APInt(a.getBitWidth(), 0);
521OpFoldResult DShlPrimOp::fold(FoldAdaptor adaptor) {
524 [=](
const APSInt &a,
const APSInt &b) -> APInt { return a.shl(b); });
527OpFoldResult DShlwPrimOp::fold(FoldAdaptor adaptor) {
530 [=](
const APSInt &a,
const APSInt &b) -> APInt { return a.shl(b); });
533OpFoldResult DShrPrimOp::fold(FoldAdaptor adaptor) {
536 [=](
const APSInt &a,
const APSInt &b) -> APInt {
537 return getType().base().isUnsigned() || !a.getBitWidth() ? a.lshr(b)
543OpFoldResult AndPrimOp::fold(FoldAdaptor adaptor) {
546 if (rhsCst->isZero())
550 if (rhsCst->isAllOnes() && getLhs().getType() == getType() &&
551 getRhs().getType() == getType())
557 if (lhsCst->isZero())
561 if (lhsCst->isAllOnes() && getLhs().getType() == getType() &&
562 getRhs().getType() == getType())
567 if (getLhs() == getRhs() && getRhs().getType() == getType())
572 [](
const APSInt &a,
const APSInt &b) -> APInt { return a & b; });
575void AndPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
578 .insert<patterns::extendAnd, patterns::moveConstAnd, patterns::AndOfZero,
579 patterns::AndOfAllOne, patterns::AndOfSelf, patterns::AndOfPad,
580 patterns::AndOfAsSIntL, patterns::AndOfAsSIntR>(
context);
583OpFoldResult OrPrimOp::fold(FoldAdaptor adaptor) {
586 if (rhsCst->isZero() && getLhs().getType() == getType())
590 if (rhsCst->isAllOnes() && getRhs().getType() == getType() &&
591 getLhs().getType() == getType())
597 if (lhsCst->isZero() && getRhs().getType() == getType())
601 if (lhsCst->isAllOnes() && getLhs().getType() == getType() &&
602 getRhs().getType() == getType())
607 if (getLhs() == getRhs() && getRhs().getType() == getType())
612 [](
const APSInt &a,
const APSInt &b) -> APInt { return a | b; });
615void OrPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
617 results.insert<patterns::extendOr, patterns::moveConstOr, patterns::OrOfZero,
618 patterns::OrOfAllOne, patterns::OrOfSelf, patterns::OrOfPad,
622OpFoldResult XorPrimOp::fold(FoldAdaptor adaptor) {
625 if (rhsCst->isZero() &&
631 if (lhsCst->isZero() &&
636 if (getLhs() == getRhs())
639 APInt(std::max(getType().base().getWidthOrSentinel(), 0), 0));
643 [](
const APSInt &a,
const APSInt &b) -> APInt { return a ^ b; });
646void XorPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
648 results.insert<patterns::extendXor, patterns::moveConstXor,
649 patterns::XorOfZero, patterns::XorOfSelf, patterns::XorOfPad>(
653void LEQPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
655 results.insert<patterns::LEQWithConstLHS>(
context);
658OpFoldResult LEQPrimOp::fold(FoldAdaptor adaptor) {
659 bool isUnsigned = getLhs().getType().base().isUnsigned();
662 if (getLhs() == getRhs())
666 if (
auto width = getLhs().getType().base().
getWidth()) {
668 auto commonWidth = std::max<int32_t>(*width, rhsCst->getBitWidth());
669 commonWidth = std::max(commonWidth, 1);
680 if (isUnsigned && rhsCst->zext(commonWidth)
693 [=](
const APSInt &a,
const APSInt &b) -> APInt {
694 return APInt(1, a <= b);
698void LTPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
700 results.insert<patterns::LTWithConstLHS>(
context);
703OpFoldResult LTPrimOp::fold(FoldAdaptor adaptor) {
704 IntType lhsType = getLhs().getType();
708 if (getLhs() == getRhs())
718 if (
auto width = lhsType.
getWidth()) {
720 auto commonWidth = std::max<int32_t>(*width, rhsCst->getBitWidth());
721 commonWidth = std::max(commonWidth, 1);
732 if (isUnsigned && rhsCst->zext(commonWidth)
745 [=](
const APSInt &a,
const APSInt &b) -> APInt {
746 return APInt(1, a < b);
750void GEQPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
752 results.insert<patterns::GEQWithConstLHS>(
context);
755OpFoldResult GEQPrimOp::fold(FoldAdaptor adaptor) {
756 IntType lhsType = getLhs().getType();
760 if (getLhs() == getRhs())
765 if (rhsCst->isZero() && isUnsigned)
770 if (
auto width = lhsType.
getWidth()) {
772 auto commonWidth = std::max<int32_t>(*width, rhsCst->getBitWidth());
773 commonWidth = std::max(commonWidth, 1);
776 if (isUnsigned && rhsCst->zext(commonWidth)
797 [=](
const APSInt &a,
const APSInt &b) -> APInt {
798 return APInt(1, a >= b);
802void GTPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
804 results.insert<patterns::GTWithConstLHS>(
context);
807OpFoldResult GTPrimOp::fold(FoldAdaptor adaptor) {
808 IntType lhsType = getLhs().getType();
812 if (getLhs() == getRhs())
816 if (
auto width = lhsType.
getWidth()) {
818 auto commonWidth = std::max<int32_t>(*width, rhsCst->getBitWidth());
819 commonWidth = std::max(commonWidth, 1);
822 if (isUnsigned && rhsCst->zext(commonWidth)
843 [=](
const APSInt &a,
const APSInt &b) -> APInt {
844 return APInt(1, a > b);
848OpFoldResult EQPrimOp::fold(FoldAdaptor adaptor) {
850 if (getLhs() == getRhs())
856 if (rhsCst->isAllOnes() && getLhs().getType() == getType() &&
857 getRhs().getType() == getType())
863 [=](
const APSInt &a,
const APSInt &b) -> APInt {
864 return APInt(1, a == b);
868LogicalResult EQPrimOp::canonicalize(EQPrimOp op, PatternRewriter &rewriter) {
870 op, rewriter, [&](ArrayRef<Attribute> operands) -> OpFoldResult {
872 auto width = op.getLhs().getType().getBitWidthOrSentinel();
875 if (rhsCst->isZero() && op.getLhs().getType() == op.getType() &&
876 op.getRhs().getType() == op.getType()) {
877 return NotPrimOp::create(rewriter, op.getLoc(), op.getLhs())
882 if (rhsCst->isZero() && width > 1) {
883 auto orrOp = OrRPrimOp::create(rewriter, op.getLoc(), op.getLhs());
884 return NotPrimOp::create(rewriter, op.getLoc(), orrOp).getResult();
888 if (rhsCst->isAllOnes() && width > 1 &&
889 op.getLhs().getType() == op.getRhs().getType()) {
890 return AndRPrimOp::create(rewriter, op.getLoc(), op.getLhs())
898OpFoldResult NEQPrimOp::fold(FoldAdaptor adaptor) {
900 if (getLhs() == getRhs())
906 if (rhsCst->isZero() && getLhs().getType() == getType() &&
907 getRhs().getType() == getType())
913 [=](
const APSInt &a,
const APSInt &b) -> APInt {
914 return APInt(1, a != b);
918LogicalResult NEQPrimOp::canonicalize(NEQPrimOp op, PatternRewriter &rewriter) {
920 op, rewriter, [&](ArrayRef<Attribute> operands) -> OpFoldResult {
922 auto width = op.getLhs().getType().getBitWidthOrSentinel();
925 if (rhsCst->isAllOnes() && op.getLhs().getType() == op.getType() &&
926 op.getRhs().getType() == op.getType()) {
927 return NotPrimOp::create(rewriter, op.getLoc(), op.getLhs())
932 if (rhsCst->isZero() && width > 1) {
933 return OrRPrimOp::create(rewriter, op.getLoc(), op.getLhs())
938 if (rhsCst->isAllOnes() && width > 1 &&
939 op.getLhs().getType() == op.getRhs().getType()) {
941 AndRPrimOp::create(rewriter, op.getLoc(), op.getLhs());
942 return NotPrimOp::create(rewriter, op.getLoc(), andrOp).getResult();
950OpFoldResult IntegerAddOp::fold(FoldAdaptor adaptor) {
956OpFoldResult IntegerMulOp::fold(FoldAdaptor adaptor) {
962OpFoldResult IntegerShrOp::fold(FoldAdaptor adaptor) {
966 return IntegerAttr::get(IntegerType::get(getContext(),
967 lhsCst->getBitWidth(),
968 IntegerType::Signed),
969 lhsCst->ashr(*rhsCst));
972 if (rhsCst->isZero())
979OpFoldResult IntegerShlOp::fold(FoldAdaptor adaptor) {
984 return IntegerAttr::get(IntegerType::get(getContext(),
985 lhsCst->getBitWidth(),
986 IntegerType::Signed),
987 lhsCst->shl(*rhsCst));
990 if (rhsCst->isZero())
1001OpFoldResult SizeOfIntrinsicOp::fold(FoldAdaptor) {
1002 auto base = getInput().getType();
1009OpFoldResult IsXIntrinsicOp::fold(FoldAdaptor adaptor) {
1016OpFoldResult AsSIntPrimOp::fold(FoldAdaptor adaptor) {
1024 if (getType().base().hasWidth())
1031void AsSIntPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1033 results.insert<patterns::StoUtoS>(
context);
1036OpFoldResult AsUIntPrimOp::fold(FoldAdaptor adaptor) {
1044 if (getType().base().hasWidth())
1051void AsUIntPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1053 results.insert<patterns::UtoStoU>(
context);
1056OpFoldResult AsAsyncResetPrimOp::fold(FoldAdaptor adaptor) {
1058 if (getInput().getType() == getType())
1063 return BoolAttr::get(getContext(), cst->getBoolValue());
1068OpFoldResult AsResetPrimOp::fold(FoldAdaptor adaptor) {
1070 return BoolAttr::get(getContext(), cst->getBoolValue());
1074OpFoldResult AsClockPrimOp::fold(FoldAdaptor adaptor) {
1076 if (getInput().getType() == getType())
1081 return BoolAttr::get(getContext(), cst->getBoolValue());
1086OpFoldResult CvtPrimOp::fold(FoldAdaptor adaptor) {
1092 getType().base().getWidthOrSentinel()))
1098void CvtPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1100 results.insert<patterns::CVTSigned, patterns::CVTUnSigned>(
context);
1103OpFoldResult NegPrimOp::fold(FoldAdaptor adaptor) {
1110 getType().base().getWidthOrSentinel()))
1111 return getIntAttr(getType(), APInt((*cst).getBitWidth(), 0) - *cst);
1116OpFoldResult NotPrimOp::fold(FoldAdaptor adaptor) {
1121 getType().base().getWidthOrSentinel()))
1127void NotPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1129 results.insert<patterns::NotNot, patterns::NotEq, patterns::NotNeq,
1130 patterns::NotLeq, patterns::NotLt, patterns::NotGeq,
1137 : RewritePattern(opName, 0,
context) {}
1143 ConstantOp constantOp,
1144 SmallVectorImpl<Value> &remaining)
const = 0;
1151 mlir::PatternRewriter &rewriter)
const override {
1153 auto catOp = op->getOperand(0).getDefiningOp<CatPrimOp>();
1157 SmallVector<Value> nonConstantOperands;
1160 for (
auto operand : catOp.getInputs()) {
1161 if (
auto constantOp = operand.getDefiningOp<ConstantOp>()) {
1163 if (
handleConstant(rewriter, op, constantOp, nonConstantOperands))
1167 nonConstantOperands.push_back(operand);
1172 if (nonConstantOperands.empty()) {
1173 replaceOpWithNewOpAndCopyName<ConstantOp>(
1174 rewriter, op, cast<IntType>(op->getResult(0).getType()),
1180 if (nonConstantOperands.size() == 1) {
1181 rewriter.modifyOpInPlace(
1182 op, [&] { op->setOperand(0, nonConstantOperands.front()); });
1187 if (catOp->hasOneUse() &&
1188 nonConstantOperands.size() < catOp->getNumOperands()) {
1189 replaceOpWithNewOpAndCopyName<CatPrimOp>(rewriter, catOp,
1190 nonConstantOperands);
1203 SmallVectorImpl<Value> &remaining)
const override {
1204 if (value.getValue().isZero())
1207 replaceOpWithNewOpAndCopyName<ConstantOp>(
1208 rewriter, op, cast<IntType>(op->getResult(0).getType()),
1221 SmallVectorImpl<Value> &remaining)
const override {
1222 if (value.getValue().isAllOnes())
1225 replaceOpWithNewOpAndCopyName<ConstantOp>(
1226 rewriter, op, cast<IntType>(op->getResult(0).getType()),
1239 SmallVectorImpl<Value> &remaining)
const override {
1240 if (value.getValue().isZero())
1242 remaining.push_back(value);
1248OpFoldResult AndRPrimOp::fold(FoldAdaptor adaptor) {
1252 if (getInput().getType().getBitWidthOrSentinel() == 0)
1257 return getIntAttr(getType(), APInt(1, cst->isAllOnes()));
1261 if (
isUInt1(getInput().getType()))
1267void AndRPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1269 results.insert<patterns::AndRasSInt, patterns::AndRasUInt, patterns::AndRPadU,
1270 patterns::AndRPadS, patterns::AndRCatAndR_left,
1274OpFoldResult OrRPrimOp::fold(FoldAdaptor adaptor) {
1278 if (getInput().getType().getBitWidthOrSentinel() == 0)
1283 return getIntAttr(getType(), APInt(1, !cst->isZero()));
1287 if (
isUInt1(getInput().getType()))
1293void OrRPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1295 results.insert<patterns::OrRasSInt, patterns::OrRasUInt, patterns::OrRPadU,
1296 patterns::OrRCatOrR_left, patterns::OrRCatOrR_right,
OrRCat>(
1300OpFoldResult XorRPrimOp::fold(FoldAdaptor adaptor) {
1304 if (getInput().getType().getBitWidthOrSentinel() == 0)
1309 return getIntAttr(getType(), APInt(1, cst->popcount() & 1));
1312 if (
isUInt1(getInput().getType()))
1318void XorRPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1321 .insert<patterns::XorRasSInt, patterns::XorRasUInt, patterns::XorRPadU,
1322 patterns::XorRCatXorR_left, patterns::XorRCatXorR_right,
XorRCat>(
1330OpFoldResult CatPrimOp::fold(FoldAdaptor adaptor) {
1331 auto inputs = getInputs();
1332 auto inputAdaptors = adaptor.getInputs();
1339 if (inputs.size() == 1 && inputs[0].getType() == getType())
1347 SmallVector<Value> nonZeroInputs;
1348 SmallVector<Attribute> nonZeroAttributes;
1349 bool allConstant =
true;
1350 for (
auto [input, attr] :
llvm::zip(inputs, inputAdaptors)) {
1351 auto inputType = type_cast<IntType>(input.getType());
1352 if (inputType.getBitWidthOrSentinel() != 0) {
1353 nonZeroInputs.push_back(input);
1355 allConstant =
false;
1356 if (nonZeroInputs.size() > 1 && !allConstant)
1362 if (nonZeroInputs.empty())
1366 if (nonZeroInputs.size() == 1 && nonZeroInputs[0].getType() == getType())
1367 return nonZeroInputs[0];
1373 SmallVector<APInt> constants;
1374 for (
auto inputAdaptor : inputAdaptors) {
1376 constants.push_back(*cst);
1381 assert(!constants.empty());
1383 APInt result = constants[0];
1384 for (
size_t i = 1; i < constants.size(); ++i)
1385 result = result.concat(constants[i]);
1390void DShlPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1392 results.insert<patterns::DShlOfConstant>(
context);
1395void DShrPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1397 results.insert<patterns::DShrOfConstant>(
context);
1405 using OpRewritePattern::OpRewritePattern;
1408 matchAndRewrite(CatPrimOp cat,
1409 mlir::PatternRewriter &rewriter)
const override {
1411 cat.getType().getBitWidthOrSentinel() == 0)
1415 if (cat->hasOneUse() && isa<CatPrimOp>(*cat->getUsers().begin()))
1419 SmallVector<Value> operands;
1420 SmallVector<Value> worklist;
1421 auto pushOperands = [&worklist](CatPrimOp op) {
1422 for (
auto operand :
llvm::reverse(op.getInputs()))
1423 worklist.push_back(operand);
1426 bool hasSigned =
false, hasUnsigned =
false;
1427 while (!worklist.empty()) {
1428 auto value = worklist.pop_back_val();
1429 auto catOp = value.getDefiningOp<CatPrimOp>();
1431 operands.push_back(value);
1432 (type_isa<UIntType>(value.getType()) ? hasUnsigned : hasSigned) =
true;
1436 pushOperands(catOp);
1441 auto castToUIntIfSigned = [&](Value value) -> Value {
1442 if (type_isa<UIntType>(value.getType()))
1444 return AsUIntPrimOp::create(rewriter, value.getLoc(), value);
1447 assert(operands.size() >= 1 &&
"zero width cast must be rejected");
1449 if (operands.size() == 1) {
1450 rewriter.replaceOp(cat, castToUIntIfSigned(operands[0]));
1454 if (operands.size() == cat->getNumOperands())
1458 if (hasSigned && hasUnsigned)
1459 for (
auto &operand : operands)
1460 operand = castToUIntIfSigned(operand);
1462 replaceOpWithNewOpAndCopyName<CatPrimOp>(rewriter, cat, cat.getType(),
1471 using OpRewritePattern::OpRewritePattern;
1474 matchAndRewrite(CatPrimOp cat,
1475 mlir::PatternRewriter &rewriter)
const override {
1479 SmallVector<Value> operands;
1481 for (
size_t i = 0; i < cat->getNumOperands(); ++i) {
1482 auto cst = cat.getInputs()[i].getDefiningOp<ConstantOp>();
1484 operands.push_back(cat.getInputs()[i]);
1487 APSInt value = cst.getValue();
1489 for (; j < cat->getNumOperands(); ++j) {
1490 auto nextCst = cat.getInputs()[j].getDefiningOp<ConstantOp>();
1493 value = value.concat(nextCst.getValue());
1498 operands.push_back(cst);
1501 operands.push_back(ConstantOp::create(rewriter, cat.getLoc(), value));
1507 if (operands.size() == cat->getNumOperands())
1510 replaceOpWithNewOpAndCopyName<CatPrimOp>(rewriter, cat, cat.getType(),
1519void CatPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1521 results.insert<patterns::CatBitsBits, patterns::CatDoubleConst,
1522 patterns::CatCast, FlattenCat, CatOfConstant>(
context);
1529OpFoldResult StringConcatOp::fold(FoldAdaptor adaptor) {
1531 if (getInputs().size() == 1)
1532 return getInputs()[0];
1535 if (!llvm::all_of(adaptor.getInputs(), [](Attribute operand) {
1536 return isa_and_nonnull<StringAttr>(operand);
1541 SmallString<64> result;
1542 for (
auto operand : adaptor.getInputs())
1543 result += cast<StringAttr>(operand).getValue();
1545 return StringAttr::get(getContext(), result);
1553 using OpRewritePattern::OpRewritePattern;
1556 matchAndRewrite(StringConcatOp concat,
1557 mlir::PatternRewriter &rewriter)
const override {
1561 bool hasNestedConcat = llvm::any_of(concat.getInputs(), [](Value operand) {
1562 auto nestedConcat = operand.getDefiningOp<StringConcatOp>();
1563 return nestedConcat && operand.hasOneUse();
1566 if (!hasNestedConcat)
1570 SmallVector<Value> flatOperands;
1571 for (
auto input : concat.getInputs()) {
1572 if (
auto nestedConcat = input.getDefiningOp<StringConcatOp>();
1573 nestedConcat && input.hasOneUse())
1574 llvm::append_range(flatOperands, nestedConcat.getInputs());
1576 flatOperands.push_back(input);
1579 rewriter.modifyOpInPlace(concat,
1580 [&]() { concat->setOperands(flatOperands); });
1587class MergeAdjacentStringConstants
1590 using OpRewritePattern::OpRewritePattern;
1593 matchAndRewrite(StringConcatOp concat,
1594 mlir::PatternRewriter &rewriter)
const override {
1596 SmallVector<Value> newOperands;
1597 SmallString<64> accumulatedLit;
1598 SmallVector<StringConstantOp> accumulatedOps;
1599 bool changed =
false;
1601 auto flushLiterals = [&]() {
1602 if (accumulatedOps.empty())
1606 if (accumulatedOps.size() == 1) {
1607 newOperands.push_back(accumulatedOps[0]);
1610 auto newLit = rewriter.createOrFold<StringConstantOp>(
1611 concat.getLoc(), StringAttr::get(getContext(), accumulatedLit));
1612 newOperands.push_back(newLit);
1615 accumulatedLit.clear();
1616 accumulatedOps.clear();
1619 for (
auto operand : concat.getInputs()) {
1620 if (
auto litOp = operand.getDefiningOp<StringConstantOp>()) {
1622 if (litOp.getValue().empty()) {
1626 accumulatedLit += litOp.getValue();
1627 accumulatedOps.push_back(litOp);
1630 newOperands.push_back(operand);
1641 if (newOperands.empty())
1642 return rewriter.replaceOpWithNewOp<StringConstantOp>(
1643 concat, StringAttr::get(getContext(),
"")),
1647 rewriter.modifyOpInPlace(concat,
1648 [&]() { concat->setOperands(newOperands); });
1655void StringConcatOp::getCanonicalizationPatterns(RewritePatternSet &results,
1657 results.insert<FlattenStringConcat, MergeAdjacentStringConstants>(
context);
1660OpFoldResult BitCastOp::fold(FoldAdaptor adaptor) {
1663 if (op.getType() == op.getInput().getType())
1664 return op.getInput();
1668 if (BitCastOp in = dyn_cast_or_null<BitCastOp>(op.getInput().getDefiningOp()))
1669 if (op.getType() == in.getInput().getType())
1670 return in.getInput();
1675OpFoldResult BitsPrimOp::fold(FoldAdaptor adaptor) {
1676 IntType inputType = getInput().getType();
1677 IntType resultType = getType();
1679 if (inputType == getType() && resultType.
hasWidth())
1686 cst->extractBits(getHi() - getLo() + 1, getLo()));
1692 using OpRewritePattern::OpRewritePattern;
1696 mlir::PatternRewriter &rewriter)
const override {
1697 auto cat = bits.getInput().getDefiningOp<CatPrimOp>();
1700 int32_t bitPos = bits.getLo();
1701 auto resultWidth = type_cast<UIntType>(bits.getType()).getWidthOrSentinel();
1702 if (resultWidth < 0)
1704 for (
auto operand : llvm::reverse(cat.getInputs())) {
1706 type_cast<IntType>(operand.getType()).getWidthOrSentinel();
1707 if (operandWidth < 0)
1709 if (bitPos < operandWidth) {
1710 if (bitPos + resultWidth <= operandWidth) {
1711 auto newBits = rewriter.createOrFold<BitsPrimOp>(
1712 bits.getLoc(), operand, bitPos + resultWidth - 1, bitPos);
1718 bitPos -= operandWidth;
1724void BitsPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1727 .insert<patterns::BitsOfBits, patterns::BitsOfMux, patterns::BitsOfAsUInt,
1735 unsigned loBit, PatternRewriter &rewriter) {
1736 auto resType = type_cast<IntType>(op->getResult(0).getType());
1737 if (type_cast<IntType>(value.getType()).getWidth() != resType.getWidth())
1738 value = BitsPrimOp::create(rewriter, op->getLoc(), value, hiBit, loBit);
1740 if (resType.isSigned() && !type_cast<IntType>(value.getType()).isSigned()) {
1741 value = rewriter.createOrFold<AsSIntPrimOp>(op->getLoc(), resType, value);
1742 }
else if (resType.isUnsigned() &&
1743 !type_cast<IntType>(value.getType()).isUnsigned()) {
1744 value = rewriter.createOrFold<AsUIntPrimOp>(op->getLoc(), resType, value);
1746 rewriter.replaceOp(op, value);
1749template <
typename OpTy>
1750static OpFoldResult
foldMux(OpTy op,
typename OpTy::FoldAdaptor adaptor) {
1752 if (op.getType().getBitWidthOrSentinel() == 0)
1754 APInt(0, 0, op.getType().isSignedInteger()));
1757 if (op.getHigh() == op.getLow() && op.getHigh().getType() == op.getType())
1758 return op.getHigh();
1763 if (op.getType().getBitWidthOrSentinel() < 0)
1768 if (cond->isZero() && op.getLow().getType() == op.getType())
1770 if (!cond->isZero() && op.getHigh().getType() == op.getType())
1771 return op.getHigh();
1775 if (
auto lowCst =
getConstant(adaptor.getLow())) {
1777 if (
auto highCst =
getConstant(adaptor.getHigh())) {
1779 if (highCst->getBitWidth() == lowCst->getBitWidth() &&
1780 *highCst == *lowCst)
1783 if (highCst->isOne() && lowCst->isZero() &&
1784 op.getType() == op.getSel().getType())
1797OpFoldResult MuxPrimOp::fold(FoldAdaptor adaptor) {
1798 return foldMux(*
this, adaptor);
1801OpFoldResult Mux2CellIntrinsicOp::fold(FoldAdaptor adaptor) {
1802 return foldMux(*
this, adaptor);
1805OpFoldResult Mux4CellIntrinsicOp::fold(FoldAdaptor adaptor) {
return {}; }
1814 using OpRewritePattern::OpRewritePattern;
1817 matchAndRewrite(MuxPrimOp mux,
1818 mlir::PatternRewriter &rewriter)
const override {
1819 auto width = mux.getType().getBitWidthOrSentinel();
1823 auto pad = [&](Value input) -> Value {
1825 type_cast<FIRRTLBaseType>(input.getType()).getBitWidthOrSentinel();
1826 if (inputWidth < 0 || width == inputWidth)
1828 return PadPrimOp::create(rewriter, mux.getLoc(), mux.getType(), input,
1833 auto newHigh = pad(mux.getHigh());
1834 auto newLow = pad(mux.getLow());
1835 if (newHigh == mux.getHigh() && newLow == mux.getLow())
1838 replaceOpWithNewOpAndCopyName<MuxPrimOp>(
1839 rewriter, mux, mux.getType(), ValueRange{mux.getSel(), newHigh, newLow},
1849 using OpRewritePattern::OpRewritePattern;
1851 static const int depthLimit = 5;
1853 Value updateOrClone(MuxPrimOp mux, Value high, Value low,
1854 mlir::PatternRewriter &rewriter,
1855 bool updateInPlace)
const {
1856 if (updateInPlace) {
1857 rewriter.modifyOpInPlace(mux, [&] {
1858 mux.setOperand(1, high);
1859 mux.setOperand(2, low);
1863 rewriter.setInsertionPointAfter(mux);
1864 return MuxPrimOp::create(rewriter, mux.getLoc(), mux.getType(),
1865 ValueRange{mux.getSel(), high, low})
1870 Value tryCondTrue(Value op, Value cond, mlir::PatternRewriter &rewriter,
1871 bool updateInPlace,
int limit)
const {
1872 MuxPrimOp mux = op.getDefiningOp<MuxPrimOp>();
1875 if (mux.getSel() == cond)
1876 return mux.getHigh();
1877 if (limit > depthLimit)
1879 updateInPlace &= mux->hasOneUse();
1881 if (Value v = tryCondTrue(mux.getHigh(), cond, rewriter, updateInPlace,
1883 return updateOrClone(mux, v, mux.getLow(), rewriter, updateInPlace);
1886 tryCondTrue(mux.getLow(), cond, rewriter, updateInPlace, limit + 1))
1887 return updateOrClone(mux, mux.getHigh(), v, rewriter, updateInPlace);
1892 Value tryCondFalse(Value op, Value cond, mlir::PatternRewriter &rewriter,
1893 bool updateInPlace,
int limit)
const {
1894 MuxPrimOp mux = op.getDefiningOp<MuxPrimOp>();
1897 if (mux.getSel() == cond)
1898 return mux.getLow();
1899 if (limit > depthLimit)
1901 updateInPlace &= mux->hasOneUse();
1903 if (Value v = tryCondFalse(mux.getHigh(), cond, rewriter, updateInPlace,
1905 return updateOrClone(mux, v, mux.getLow(), rewriter, updateInPlace);
1907 if (Value v = tryCondFalse(mux.getLow(), cond, rewriter, updateInPlace,
1909 return updateOrClone(mux, mux.getHigh(), v, rewriter, updateInPlace);
1915 matchAndRewrite(MuxPrimOp mux,
1916 mlir::PatternRewriter &rewriter)
const override {
1917 auto width = mux.getType().getBitWidthOrSentinel();
1921 if (Value v = tryCondTrue(mux.getHigh(), mux.getSel(), rewriter,
true, 0)) {
1922 rewriter.modifyOpInPlace(mux, [&] { mux.setOperand(1, v); });
1926 if (Value v = tryCondFalse(mux.getLow(), mux.getSel(), rewriter,
true, 0)) {
1927 rewriter.modifyOpInPlace(mux, [&] { mux.setOperand(2, v); });
1936void MuxPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1939 .add<MuxPad, MuxSharedCond, patterns::MuxEQOperands,
1940 patterns::MuxEQOperandsSwapped, patterns::MuxNEQ, patterns::MuxNot,
1941 patterns::MuxSameTrue, patterns::MuxSameFalse,
1942 patterns::NarrowMuxLHS, patterns::NarrowMuxRHS, patterns::MuxPadSel>(
1946void Mux2CellIntrinsicOp::getCanonicalizationPatterns(
1947 RewritePatternSet &results, MLIRContext *
context) {
1948 results.add<patterns::Mux2PadSel>(
context);
1951void Mux4CellIntrinsicOp::getCanonicalizationPatterns(
1952 RewritePatternSet &results, MLIRContext *
context) {
1953 results.add<patterns::Mux4PadSel>(
context);
1956OpFoldResult PadPrimOp::fold(FoldAdaptor adaptor) {
1957 auto input = this->getInput();
1960 if (input.getType() == getType())
1964 auto inputType = input.getType().base();
1971 auto destWidth = getType().base().getWidthOrSentinel();
1972 if (destWidth == -1)
1975 if (inputType.
isSigned() && cst->getBitWidth())
1976 return getIntAttr(getType(), cst->sext(destWidth));
1977 return getIntAttr(getType(), cst->zext(destWidth));
1983OpFoldResult ShlPrimOp::fold(FoldAdaptor adaptor) {
1984 auto input = this->getInput();
1985 IntType inputType = input.getType();
1986 int shiftAmount = getAmount();
1989 if (shiftAmount == 0)
1995 if (inputWidth != -1) {
1996 auto resultWidth = inputWidth + shiftAmount;
1997 shiftAmount = std::min(shiftAmount, resultWidth);
1998 return getIntAttr(getType(), cst->zext(resultWidth).shl(shiftAmount));
2004OpFoldResult ShrPrimOp::fold(FoldAdaptor adaptor) {
2005 auto input = this->getInput();
2006 IntType inputType = input.getType();
2007 int shiftAmount = getAmount();
2013 if (shiftAmount == 0 && inputWidth > 0)
2016 if (inputWidth == -1)
2018 if (inputWidth == 0)
2023 if (shiftAmount >= inputWidth && inputType.
isUnsigned())
2024 return getIntAttr(getType(), APInt(0, 0,
false));
2030 value = cst->ashr(std::min(shiftAmount, inputWidth - 1));
2032 value = cst->lshr(std::min(shiftAmount, inputWidth));
2033 auto resultWidth = std::max(inputWidth - shiftAmount, 1);
2034 return getIntAttr(getType(), value.trunc(resultWidth));
2039LogicalResult ShrPrimOp::canonicalize(ShrPrimOp op, PatternRewriter &rewriter) {
2040 auto inputWidth = op.getInput().getType().base().getWidthOrSentinel();
2041 if (inputWidth <= 0)
2045 unsigned shiftAmount = op.getAmount();
2046 if (
int(shiftAmount) >= inputWidth) {
2048 if (op.getType().base().isUnsigned())
2054 shiftAmount = inputWidth - 1;
2057 replaceWithBits(op, op.getInput(), inputWidth - 1, shiftAmount, rewriter);
2061LogicalResult HeadPrimOp::canonicalize(HeadPrimOp op,
2062 PatternRewriter &rewriter) {
2063 auto inputWidth = op.getInput().getType().base().getWidthOrSentinel();
2064 if (inputWidth <= 0)
2068 unsigned keepAmount = op.getAmount();
2070 replaceWithBits(op, op.getInput(), inputWidth - 1, inputWidth - keepAmount,
2075OpFoldResult HeadPrimOp::fold(FoldAdaptor adaptor) {
2079 getInput().getType().base().getWidthOrSentinel() - getAmount();
2080 return getIntAttr(getType(), cst->lshr(shiftAmount).trunc(getAmount()));
2086OpFoldResult TailPrimOp::fold(FoldAdaptor adaptor) {
2090 cst->trunc(getType().base().getWidthOrSentinel()));
2094LogicalResult TailPrimOp::canonicalize(TailPrimOp op,
2095 PatternRewriter &rewriter) {
2096 auto inputWidth = op.getInput().getType().base().getWidthOrSentinel();
2097 if (inputWidth <= 0)
2101 unsigned dropAmount = op.getAmount();
2102 if (dropAmount !=
unsigned(inputWidth))
2108void SubaccessOp::getCanonicalizationPatterns(RewritePatternSet &results,
2110 results.add<patterns::SubaccessOfConstant>(
context);
2113OpFoldResult MultibitMuxOp::fold(FoldAdaptor adaptor) {
2115 if (adaptor.getInputs().size() == 1)
2116 return getOperand(1);
2118 if (
auto constIndex =
getConstant(adaptor.getIndex())) {
2119 auto index = constIndex->getZExtValue();
2120 if (index < getInputs().size())
2121 return getInputs()[getInputs().size() - 1 - index];
2127LogicalResult MultibitMuxOp::canonicalize(MultibitMuxOp op,
2128 PatternRewriter &rewriter) {
2132 if (llvm::all_of(op.getInputs().drop_front(), [&](
auto input) {
2133 return input == op.getInputs().front();
2141 auto indexWidth = op.getIndex().getType().getBitWidthOrSentinel();
2142 uint64_t inputSize = op.getInputs().size();
2143 if (indexWidth >= 0 && indexWidth < 64 && 1ull << indexWidth < inputSize) {
2144 rewriter.modifyOpInPlace(op, [&]() {
2145 op.getInputsMutable().erase(0, inputSize - (1ull << indexWidth));
2152 if (
auto lastSubindex = op.getInputs().back().getDefiningOp<SubindexOp>()) {
2153 if (llvm::all_of(llvm::enumerate(op.getInputs()), [&](
auto e) {
2154 auto subindex = e.value().template getDefiningOp<SubindexOp>();
2155 return subindex && lastSubindex.getInput() == subindex.getInput() &&
2156 subindex.getIndex() + e.index() + 1 == op.getInputs().size();
2158 replaceOpWithNewOpAndCopyName<SubaccessOp>(
2159 rewriter, op, lastSubindex.getInput(), op.getIndex());
2165 if (op.getInputs().size() != 2)
2169 auto uintType = op.getIndex().getType();
2170 if (uintType.getBitWidthOrSentinel() != 1)
2174 replaceOpWithNewOpAndCopyName<MuxPrimOp>(
2175 rewriter, op, op.getIndex(), op.getInputs()[0], op.getInputs()[1]);
2194 MatchingConnectOp connect;
2195 for (Operation *user : value.getUsers()) {
2197 if (isa<AttachOp, SubfieldOp, SubaccessOp, SubindexOp>(user))
2200 if (
auto aConnect = dyn_cast<FConnectLike>(user))
2201 if (aConnect.getDest() == value) {
2202 auto matchingConnect = dyn_cast<MatchingConnectOp>(*aConnect);
2205 if (!matchingConnect || (connect && connect != matchingConnect) ||
2206 matchingConnect->getBlock() != value.getParentBlock())
2208 connect = matchingConnect;
2216 PatternRewriter &rewriter) {
2219 Operation *connectedDecl = op.getDest().getDefiningOp();
2224 if (!isa<WireOp>(connectedDecl) && !isa<RegOp>(connectedDecl))
2228 cast<Forceable>(connectedDecl).isForceable())
2236 if (connectedDecl->hasOneUse())
2240 auto *declBlock = connectedDecl->getBlock();
2241 auto *srcValueOp = op.getSrc().getDefiningOp();
2244 if (!isa<WireOp>(connectedDecl))
2250 if (!isa<ConstantOp>(srcValueOp))
2252 if (srcValueOp->getBlock() != declBlock)
2258 auto replacement = op.getSrc();
2261 if (srcValueOp && srcValueOp != &declBlock->front())
2262 srcValueOp->moveBefore(&declBlock->front());
2269 rewriter.eraseOp(op);
2273void ConnectOp::getCanonicalizationPatterns(RewritePatternSet &results,
2275 results.insert<patterns::ConnectExtension, patterns::ConnectSameType>(
2279LogicalResult MatchingConnectOp::canonicalize(MatchingConnectOp op,
2280 PatternRewriter &rewriter) {
2297 for (
auto *user : value.getUsers()) {
2298 auto attach = dyn_cast<AttachOp>(user);
2299 if (!attach || attach == dominatedAttach)
2301 if (attach->isBeforeInBlock(dominatedAttach))
2307LogicalResult AttachOp::canonicalize(AttachOp op, PatternRewriter &rewriter) {
2309 if (op.getNumOperands() <= 1) {
2310 rewriter.eraseOp(op);
2314 for (
auto operand : op.getOperands()) {
2321 SmallVector<Value> newOperands(op.getOperands());
2322 for (
auto newOperand : attach.getOperands())
2323 if (newOperand != operand)
2324 newOperands.push_back(newOperand);
2325 AttachOp::create(rewriter, op->getLoc(), newOperands);
2326 rewriter.eraseOp(attach);
2327 rewriter.eraseOp(op);
2335 if (
auto wire = dyn_cast_or_null<WireOp>(operand.getDefiningOp())) {
2336 if (!
hasDontTouch(wire.getOperation()) && wire->hasOneUse() &&
2337 !wire.isForceable()) {
2338 SmallVector<Value> newOperands;
2339 for (
auto newOperand : op.getOperands())
2340 if (newOperand != operand)
2341 newOperands.push_back(newOperand);
2343 AttachOp::create(rewriter, op->getLoc(), newOperands);
2344 rewriter.eraseOp(op);
2345 rewriter.eraseOp(wire);
2356 assert(llvm::hasSingleElement(region) &&
"expected single-region block");
2357 rewriter.inlineBlockBefore(®ion.front(), op, {});
2360LogicalResult WhenOp::canonicalize(WhenOp op, PatternRewriter &rewriter) {
2361 if (
auto constant = op.getCondition().getDefiningOp<firrtl::ConstantOp>()) {
2362 if (constant.getValue().isAllOnes())
2364 else if (op.hasElseRegion() && !op.getElseRegion().empty())
2367 rewriter.eraseOp(op);
2373 if (!op.getThenBlock().empty() && op.hasElseRegion() &&
2374 op.getElseBlock().empty()) {
2375 rewriter.eraseBlock(&op.getElseBlock());
2382 if (!op.getThenBlock().empty())
2386 if (!op.hasElseRegion() || op.getElseBlock().empty()) {
2387 rewriter.eraseOp(op);
2397 using OpRewritePattern::OpRewritePattern;
2398 LogicalResult matchAndRewrite(NodeOp node,
2399 PatternRewriter &rewriter)
const override {
2400 auto name = node.getNameAttr();
2401 if (!node.hasDroppableName() || node.getInnerSym() ||
2404 auto *newOp = node.getInput().getDefiningOp();
2407 rewriter.replaceOp(node, node.getInput());
2414 using OpRewritePattern::OpRewritePattern;
2415 LogicalResult matchAndRewrite(NodeOp node,
2416 PatternRewriter &rewriter)
const override {
2418 node.use_empty() || node.isForceable())
2420 rewriter.replaceAllUsesWith(node.getResult(), node.getInput());
2427template <
typename OpTy>
2429 PatternRewriter &rewriter) {
2430 if (!op.isForceable() || !op.getDataRef().use_empty())
2438LogicalResult NodeOp::fold(FoldAdaptor adaptor,
2439 SmallVectorImpl<OpFoldResult> &results) {
2448 if (!adaptor.getInput())
2451 results.push_back(adaptor.getInput());
2455void NodeOp::getCanonicalizationPatterns(RewritePatternSet &results,
2457 results.insert<FoldNodeName>(
context);
2458 results.add(demoteForceableIfUnused<NodeOp>);
2464struct AggOneShot :
public mlir::RewritePattern {
2465 AggOneShot(StringRef name, uint32_t weight, MLIRContext *
context)
2466 : RewritePattern(name, 0,
context) {}
2468 SmallVector<Value> getCompleteWrite(Operation *lhs)
const {
2469 auto lhsTy = lhs->getResult(0).getType();
2470 if (!type_isa<BundleType, FVectorType>(lhsTy))
2473 DenseMap<uint32_t, Value> fields;
2474 for (Operation *user : lhs->getResult(0).getUsers()) {
2475 if (user->getParentOp() != lhs->getParentOp())
2477 if (
auto aConnect = dyn_cast<MatchingConnectOp>(user)) {
2478 if (aConnect.getDest() == lhs->getResult(0))
2480 }
else if (
auto subField = dyn_cast<SubfieldOp>(user)) {
2481 for (Operation *subuser : subField.getResult().getUsers()) {
2482 if (
auto aConnect = dyn_cast<MatchingConnectOp>(subuser)) {
2483 if (aConnect.getDest() == subField) {
2484 if (subuser->getParentOp() != lhs->getParentOp())
2486 if (fields.count(subField.getFieldIndex()))
2488 fields[subField.getFieldIndex()] = aConnect.getSrc();
2494 }
else if (
auto subIndex = dyn_cast<SubindexOp>(user)) {
2495 for (Operation *subuser : subIndex.getResult().getUsers()) {
2496 if (
auto aConnect = dyn_cast<MatchingConnectOp>(subuser)) {
2497 if (aConnect.getDest() == subIndex) {
2498 if (subuser->getParentOp() != lhs->getParentOp())
2500 if (fields.count(subIndex.getIndex()))
2502 fields[subIndex.getIndex()] = aConnect.getSrc();
2513 SmallVector<Value> values;
2514 uint32_t total = type_isa<BundleType>(lhsTy)
2515 ? type_cast<BundleType>(lhsTy).getNumElements()
2516 : type_cast<FVectorType>(lhsTy).getNumElements();
2517 for (uint32_t i = 0; i < total; ++i) {
2518 if (!fields.count(i))
2520 values.push_back(fields[i]);
2525 LogicalResult matchAndRewrite(Operation *op,
2526 PatternRewriter &rewriter)
const override {
2527 auto values = getCompleteWrite(op);
2530 rewriter.setInsertionPointToEnd(op->getBlock());
2531 auto dest = op->getResult(0);
2532 auto destType = dest.getType();
2535 if (!type_cast<FIRRTLBaseType>(destType).isPassive())
2538 Value newVal = type_isa<BundleType>(destType)
2539 ? rewriter.createOrFold<BundleCreateOp>(op->getLoc(),
2541 : rewriter.createOrFold<VectorCreateOp>(
2542 op->
getLoc(), destType, values);
2543 rewriter.createOrFold<MatchingConnectOp>(op->getLoc(), dest, newVal);
2544 for (Operation *user : dest.getUsers()) {
2545 if (
auto subIndex = dyn_cast<SubindexOp>(user)) {
2546 for (Operation *subuser :
2547 llvm::make_early_inc_range(subIndex.getResult().getUsers()))
2548 if (auto aConnect = dyn_cast<MatchingConnectOp>(subuser))
2549 if (aConnect.getDest() == subIndex)
2550 rewriter.eraseOp(aConnect);
2551 }
else if (
auto subField = dyn_cast<SubfieldOp>(user)) {
2552 for (Operation *subuser :
2553 llvm::make_early_inc_range(subField.getResult().getUsers()))
2554 if (auto aConnect = dyn_cast<MatchingConnectOp>(subuser))
2555 if (aConnect.getDest() == subField)
2556 rewriter.eraseOp(aConnect);
2563struct WireAggOneShot :
public AggOneShot {
2564 WireAggOneShot(MLIRContext *
context)
2565 : AggOneShot(WireOp::getOperationName(), 0,
context) {}
2567struct SubindexAggOneShot :
public AggOneShot {
2568 SubindexAggOneShot(MLIRContext *
context)
2569 : AggOneShot(SubindexOp::getOperationName(), 0,
context) {}
2571struct SubfieldAggOneShot :
public AggOneShot {
2572 SubfieldAggOneShot(MLIRContext *
context)
2573 : AggOneShot(SubfieldOp::getOperationName(), 0,
context) {}
2577void WireOp::getCanonicalizationPatterns(RewritePatternSet &results,
2579 results.insert<WireAggOneShot>(
context);
2580 results.add(demoteForceableIfUnused<WireOp>);
2583void SubindexOp::getCanonicalizationPatterns(RewritePatternSet &results,
2585 results.insert<SubindexAggOneShot>(
context);
2588OpFoldResult SubindexOp::fold(FoldAdaptor adaptor) {
2589 auto attr = dyn_cast_or_null<ArrayAttr>(adaptor.getInput());
2592 return attr[getIndex()];
2595OpFoldResult SubfieldOp::fold(FoldAdaptor adaptor) {
2596 auto attr = dyn_cast_or_null<ArrayAttr>(adaptor.getInput());
2599 auto index = getFieldIndex();
2603void SubfieldOp::getCanonicalizationPatterns(RewritePatternSet &results,
2605 results.insert<SubfieldAggOneShot>(
context);
2609 ArrayRef<Attribute> operands) {
2610 for (
auto operand : operands)
2613 return ArrayAttr::get(
context, operands);
2616OpFoldResult BundleCreateOp::fold(FoldAdaptor adaptor) {
2619 if (getNumOperands() > 0)
2620 if (SubfieldOp first = getOperand(0).getDefiningOp<SubfieldOp>())
2621 if (first.getFieldIndex() == 0 &&
2622 first.getInput().getType() == getType() &&
2624 llvm::drop_begin(llvm::enumerate(getOperands())), [&](
auto elem) {
2626 elem.value().
template getDefiningOp<SubfieldOp>();
2627 return subindex && subindex.getInput() == first.getInput() &&
2628 subindex.getFieldIndex() == elem.index();
2630 return first.getInput();
2635OpFoldResult VectorCreateOp::fold(FoldAdaptor adaptor) {
2638 if (getNumOperands() > 0)
2639 if (SubindexOp first = getOperand(0).getDefiningOp<SubindexOp>())
2640 if (first.getIndex() == 0 && first.getInput().getType() == getType() &&
2642 llvm::drop_begin(llvm::enumerate(getOperands())), [&](
auto elem) {
2644 elem.value().
template getDefiningOp<SubindexOp>();
2645 return subindex && subindex.getInput() == first.getInput() &&
2646 subindex.getIndex() == elem.index();
2648 return first.getInput();
2653OpFoldResult UninferredResetCastOp::fold(FoldAdaptor adaptor) {
2654 if (getOperand().getType() == getType())
2655 return getOperand();
2663 using OpRewritePattern::OpRewritePattern;
2664 LogicalResult matchAndRewrite(RegResetOp reg,
2665 PatternRewriter &rewriter)
const override {
2667 dyn_cast_or_null<ConstantOp>(
reg.getResetValue().getDefiningOp());
2676 auto mux = dyn_cast_or_null<MuxPrimOp>(con.getSrc().getDefiningOp());
2679 auto *high = mux.getHigh().getDefiningOp();
2680 auto *low = mux.getLow().getDefiningOp();
2681 auto constOp = dyn_cast_or_null<ConstantOp>(high);
2683 if (constOp && low != reg)
2685 if (dyn_cast_or_null<ConstantOp>(low) && high == reg)
2686 constOp = dyn_cast<ConstantOp>(low);
2688 if (!constOp || constOp.getType() != reset.getType() ||
2689 constOp.getValue() != reset.getValue())
2693 auto regTy =
reg.getResult().getType();
2694 if (con.getDest().getType() != regTy || con.getSrc().getType() != regTy ||
2695 mux.getHigh().getType() != regTy || mux.getLow().getType() != regTy ||
2696 regTy.getBitWidthOrSentinel() < 0)
2702 if (constOp != &con->getBlock()->front())
2703 constOp->moveBefore(&con->getBlock()->front());
2708 rewriter.eraseOp(con);
2715 if (
auto c = v.getDefiningOp<ConstantOp>())
2716 return c.getValue().isOne();
2717 if (
auto sc = v.getDefiningOp<SpecialConstantOp>())
2718 return sc.getValue();
2727 auto resetValue = reg.getResetValue();
2728 if (reg.getType(0) != resetValue.getType())
2732 (void)
dropWrite(rewriter, reg->getResult(0), {});
2733 replaceOpWithNewOpAndCopyName<NodeOp>(
2734 rewriter, reg, resetValue, reg.getNameAttr(), reg.getNameKind(),
2735 reg.getAnnotationsAttr(), reg.getInnerSymAttr(), reg.getForceable());
2739void RegResetOp::getCanonicalizationPatterns(RewritePatternSet &results,
2741 results.add<patterns::RegResetWithZeroReset, FoldResetMux>(
context);
2743 results.add(demoteForceableIfUnused<RegResetOp>);
2748 auto portTy = type_cast<BundleType>(port.getType());
2749 auto fieldIndex = portTy.getElementIndex(name);
2750 assert(fieldIndex &&
"missing field on memory port");
2753 for (
auto *op : port.getUsers()) {
2754 auto portAccess = cast<SubfieldOp>(op);
2755 if (fieldIndex != portAccess.getFieldIndex())
2760 value = conn.getSrc();
2770 auto portConst = value.getDefiningOp<ConstantOp>();
2773 return portConst.getValue().isZero();
2778 auto portTy = type_cast<BundleType>(port.getType());
2779 auto fieldIndex = portTy.getElementIndex(
data);
2780 assert(fieldIndex &&
"missing enable flag on memory port");
2782 for (
auto *op : port.getUsers()) {
2783 auto portAccess = cast<SubfieldOp>(op);
2784 if (fieldIndex != portAccess.getFieldIndex())
2786 if (!portAccess.use_empty())
2795 StringRef name, Value value) {
2796 auto portTy = type_cast<BundleType>(port.getType());
2797 auto fieldIndex = portTy.getElementIndex(name);
2798 assert(fieldIndex &&
"missing field on memory port");
2800 for (
auto *op : llvm::make_early_inc_range(port.getUsers())) {
2801 auto portAccess = cast<SubfieldOp>(op);
2802 if (fieldIndex != portAccess.getFieldIndex())
2804 rewriter.replaceAllUsesWith(portAccess, value);
2805 rewriter.eraseOp(portAccess);
2810static void erasePort(PatternRewriter &rewriter, Value port) {
2813 auto getClock = [&] {
2815 clock = SpecialConstantOp::create(rewriter, port.getLoc(),
2816 ClockType::get(rewriter.getContext()),
2825 for (
auto *op : port.getUsers()) {
2826 auto subfield = dyn_cast<SubfieldOp>(op);
2828 auto ty = port.getType();
2829 auto reg = RegOp::create(rewriter, port.getLoc(), ty, getClock());
2830 rewriter.replaceAllUsesWith(port, reg.getResult());
2839 for (
auto *accessOp : llvm::make_early_inc_range(port.getUsers())) {
2840 auto access = cast<SubfieldOp>(accessOp);
2841 for (
auto *user : llvm::make_early_inc_range(access->getUsers())) {
2842 auto connect = dyn_cast<FConnectLike>(user);
2843 if (connect && connect.getDest() == access) {
2844 rewriter.eraseOp(user);
2848 if (access.use_empty()) {
2849 rewriter.eraseOp(access);
2855 auto ty = access.getType();
2856 auto reg = RegOp::create(rewriter, access.getLoc(), ty, getClock());
2857 rewriter.replaceOp(access, reg.getResult());
2859 assert(port.use_empty() &&
"port should have no remaining uses");
2865 using OpRewritePattern::OpRewritePattern;
2866 LogicalResult matchAndRewrite(MemOp mem,
2867 PatternRewriter &rewriter)
const override {
2871 if (!firrtl::type_isa<IntType>(mem.getDataType()) ||
2872 mem.getDataType().getBitWidthOrSentinel() != 0)
2876 for (
auto port : mem.getResults())
2877 for (auto *user : port.getUsers())
2878 if (!isa<SubfieldOp>(user))
2883 for (
auto port : mem.getResults()) {
2884 for (
auto *user :
llvm::make_early_inc_range(port.getUsers())) {
2885 SubfieldOp sfop = cast<SubfieldOp>(user);
2886 StringRef fieldName = sfop.getFieldName();
2887 auto wire = replaceOpWithNewOpAndCopyName<WireOp>(
2888 rewriter, sfop, sfop.getResult().getType())
2890 if (fieldName.ends_with(
"data")) {
2892 auto zero = firrtl::ConstantOp::create(
2893 rewriter, wire.getLoc(),
2894 firrtl::type_cast<IntType>(wire.getType()), APInt::getZero(0));
2895 MatchingConnectOp::create(rewriter, wire.getLoc(), wire, zero);
2899 rewriter.eraseOp(mem);
2906 using OpRewritePattern::OpRewritePattern;
2907 LogicalResult matchAndRewrite(MemOp mem,
2908 PatternRewriter &rewriter)
const override {
2911 bool isRead =
false, isWritten =
false;
2912 for (
unsigned i = 0; i < mem.getNumResults(); ++i) {
2913 switch (mem.getPortKind(i)) {
2914 case MemOp::PortKind::Read:
2919 case MemOp::PortKind::Write:
2924 case MemOp::PortKind::Debug:
2925 case MemOp::PortKind::ReadWrite:
2928 llvm_unreachable(
"unknown port kind");
2930 assert((!isWritten || !isRead) &&
"memory is in use");
2935 if (isRead && mem.getInit())
2938 for (
auto port : mem.getResults())
2941 rewriter.eraseOp(mem);
2948 using OpRewritePattern::OpRewritePattern;
2949 LogicalResult matchAndRewrite(MemOp mem,
2950 PatternRewriter &rewriter)
const override {
2954 llvm::SmallBitVector deadPorts(mem.getNumResults());
2955 for (
auto [i, port] :
llvm::enumerate(mem.getResults())) {
2957 if (!mem.getPortAnnotation(i).empty())
2961 auto kind = mem.getPortKind(i);
2962 if (kind == MemOp::PortKind::Debug)
2971 if (kind == MemOp::PortKind::Read &&
isPortUnused(port,
"data")) {
2976 if (deadPorts.none())
2980 SmallVector<Type> resultTypes;
2981 SmallVector<StringRef> portNames;
2982 SmallVector<Attribute> portAnnotations;
2983 for (
auto [i, port] :
llvm::enumerate(mem.getResults())) {
2986 resultTypes.push_back(port.getType());
2987 portNames.push_back(mem.getPortName(i));
2988 portAnnotations.push_back(mem.getPortAnnotation(i));
2992 if (!resultTypes.empty())
2993 newOp = MemOp::create(
2994 rewriter, mem.getLoc(), resultTypes, mem.getReadLatency(),
2995 mem.getWriteLatency(), mem.getDepth(), mem.getRuw(),
2996 rewriter.getStrArrayAttr(portNames), mem.getName(), mem.getNameKind(),
2997 mem.getAnnotations(), rewriter.getArrayAttr(portAnnotations),
2998 mem.getInnerSymAttr(), mem.getInitAttr(), mem.getPrefixAttr());
3001 unsigned nextPort = 0;
3002 for (
auto [i, port] :
llvm::enumerate(mem.getResults())) {
3006 rewriter.replaceAllUsesWith(port, newOp.getResult(nextPort++));
3009 rewriter.eraseOp(mem);
3016 using OpRewritePattern::OpRewritePattern;
3017 LogicalResult matchAndRewrite(MemOp mem,
3018 PatternRewriter &rewriter)
const override {
3023 llvm::SmallBitVector deadReads(mem.getNumResults());
3024 for (
auto [i, port] :
llvm::enumerate(mem.getResults())) {
3025 if (mem.getPortKind(i) != MemOp::PortKind::ReadWrite)
3027 if (!mem.getPortAnnotation(i).empty())
3034 if (deadReads.none())
3037 SmallVector<Type> resultTypes;
3038 SmallVector<StringRef> portNames;
3039 SmallVector<Attribute> portAnnotations;
3040 for (
auto [i, port] :
llvm::enumerate(mem.getResults())) {
3042 resultTypes.push_back(
3043 MemOp::getTypeForPort(mem.getDepth(), mem.getDataType(),
3044 MemOp::PortKind::Write, mem.getMaskBits()));
3046 resultTypes.push_back(port.getType());
3048 portNames.push_back(mem.getPortName(i));
3049 portAnnotations.push_back(mem.getPortAnnotation(i));
3052 auto newOp = MemOp::create(
3053 rewriter, mem.getLoc(), resultTypes, mem.getReadLatency(),
3054 mem.getWriteLatency(), mem.getDepth(), mem.getRuw(),
3055 rewriter.getStrArrayAttr(portNames), mem.getName(), mem.getNameKind(),
3056 mem.getAnnotations(), rewriter.getArrayAttr(portAnnotations),
3057 mem.getInnerSymAttr(), mem.getInitAttr(), mem.getPrefixAttr());
3059 for (
unsigned i = 0, n = mem.getNumResults(); i < n; ++i) {
3060 auto result = mem.getResult(i);
3061 auto newResult = newOp.getResult(i);
3063 auto resultPortTy = type_cast<BundleType>(result.getType());
3067 auto replace = [&](StringRef toName, StringRef fromName) {
3068 auto fromFieldIndex = resultPortTy.getElementIndex(fromName);
3069 assert(fromFieldIndex &&
"missing enable flag on memory port");
3071 auto toField = SubfieldOp::create(rewriter, newResult.getLoc(),
3073 for (
auto *op :
llvm::make_early_inc_range(result.getUsers())) {
3074 auto fromField = cast<SubfieldOp>(op);
3075 if (fromFieldIndex != fromField.getFieldIndex())
3077 rewriter.replaceOp(fromField, toField.getResult());
3081 replace(
"addr",
"addr");
3082 replace(
"en",
"en");
3083 replace(
"clk",
"clk");
3084 replace(
"data",
"wdata");
3085 replace(
"mask",
"wmask");
3088 auto wmodeFieldIndex = resultPortTy.getElementIndex(
"wmode");
3089 for (
auto *op :
llvm::make_early_inc_range(result.getUsers())) {
3090 auto wmodeField = cast<SubfieldOp>(op);
3091 if (wmodeFieldIndex != wmodeField.getFieldIndex())
3093 rewriter.replaceOpWithNewOp<WireOp>(wmodeField, wmodeField.getType());
3096 rewriter.replaceAllUsesWith(result, newResult);
3099 rewriter.eraseOp(mem);
3106 using OpRewritePattern::OpRewritePattern;
3108 LogicalResult matchAndRewrite(MemOp mem,
3109 PatternRewriter &rewriter)
const override {
3114 const auto &summary = mem.getSummary();
3115 if (summary.isMasked || summary.isSeqMem())
3118 auto type = type_dyn_cast<IntType>(mem.getDataType());
3121 auto width = type.getBitWidthOrSentinel();
3125 llvm::SmallBitVector usedBits(width);
3126 DenseMap<unsigned, unsigned> mapping;
3131 SmallVector<BitsPrimOp> readOps;
3132 auto findReadUsers = [&](Value port, StringRef field) -> LogicalResult {
3133 auto portTy = type_cast<BundleType>(port.getType());
3134 auto fieldIndex = portTy.getElementIndex(field);
3135 assert(fieldIndex &&
"missing data port");
3137 for (
auto *op : port.getUsers()) {
3138 auto portAccess = cast<SubfieldOp>(op);
3139 if (fieldIndex != portAccess.getFieldIndex())
3142 for (
auto *user : op->getUsers()) {
3143 auto bits = dyn_cast<BitsPrimOp>(user);
3147 usedBits.set(bits.getLo(), bits.getHi() + 1);
3151 mapping[bits.getLo()] = 0;
3152 readOps.push_back(bits);
3162 SmallVector<MatchingConnectOp> writeOps;
3163 auto findWriteUsers = [&](Value port, StringRef field) -> LogicalResult {
3164 auto portTy = type_cast<BundleType>(port.getType());
3165 auto fieldIndex = portTy.getElementIndex(field);
3166 assert(fieldIndex &&
"missing data port");
3168 for (
auto *op : port.getUsers()) {
3169 auto portAccess = cast<SubfieldOp>(op);
3170 if (fieldIndex != portAccess.getFieldIndex())
3177 writeOps.push_back(conn);
3183 for (
auto [i, port] :
llvm::enumerate(mem.getResults())) {
3185 if (!mem.getPortAnnotation(i).empty())
3188 switch (mem.getPortKind(i)) {
3189 case MemOp::PortKind::Debug:
3192 case MemOp::PortKind::Write:
3193 if (failed(findWriteUsers(port,
"data")))
3196 case MemOp::PortKind::Read:
3197 if (failed(findReadUsers(port,
"data")))
3200 case MemOp::PortKind::ReadWrite:
3201 if (failed(findWriteUsers(port,
"wdata")))
3203 if (failed(findReadUsers(port,
"rdata")))
3207 llvm_unreachable(
"unknown port kind");
3211 if (usedBits.none())
3215 SmallVector<std::pair<unsigned, unsigned>> ranges;
3216 unsigned newWidth = 0;
3217 for (
int i = usedBits.find_first(); 0 <= i && i < width;) {
3218 int e = usedBits.find_next_unset(i);
3221 for (
int idx = i; idx < e; ++idx, ++newWidth) {
3222 if (
auto it = mapping.find(idx); it != mapping.end()) {
3223 it->second = newWidth;
3226 ranges.emplace_back(i, e - 1);
3227 i = e != width ? usedBits.find_next(e) : e;
3231 auto newType =
IntType::get(mem->getContext(), type.isSigned(), newWidth);
3232 SmallVector<Type> portTypes;
3233 for (
auto [i, port] :
llvm::enumerate(mem.getResults())) {
3234 portTypes.push_back(
3235 MemOp::getTypeForPort(mem.getDepth(), newType, mem.getPortKind(i)));
3237 auto newMem = rewriter.replaceOpWithNewOp<MemOp>(
3238 mem, portTypes, mem.getReadLatency(), mem.getWriteLatency(),
3239 mem.getDepth(), mem.getRuw(), mem.getPortNames(), mem.getName(),
3240 mem.getNameKind(), mem.getAnnotations(), mem.getPortAnnotations(),
3241 mem.getInnerSymAttr(), mem.getInitAttr(), mem.getPrefixAttr());
3244 auto rewriteSubfield = [&](Value port, StringRef field) {
3245 auto portTy = type_cast<BundleType>(port.getType());
3246 auto fieldIndex = portTy.getElementIndex(field);
3247 assert(fieldIndex &&
"missing data port");
3249 rewriter.setInsertionPointAfter(newMem);
3250 auto newPortAccess =
3251 SubfieldOp::create(rewriter, port.getLoc(), port, field);
3253 for (
auto *op :
llvm::make_early_inc_range(port.getUsers())) {
3254 auto portAccess = cast<SubfieldOp>(op);
3255 if (op == newPortAccess || fieldIndex != portAccess.getFieldIndex())
3257 rewriter.replaceOp(portAccess, newPortAccess.getResult());
3262 for (
auto [i, port] :
llvm::enumerate(newMem.getResults())) {
3263 switch (newMem.getPortKind(i)) {
3264 case MemOp::PortKind::Debug:
3265 llvm_unreachable(
"cannot rewrite debug port");
3266 case MemOp::PortKind::Write:
3267 rewriteSubfield(port,
"data");
3269 case MemOp::PortKind::Read:
3270 rewriteSubfield(port,
"data");
3272 case MemOp::PortKind::ReadWrite:
3273 rewriteSubfield(port,
"rdata");
3274 rewriteSubfield(port,
"wdata");
3277 llvm_unreachable(
"unknown port kind");
3281 for (
auto readOp : readOps) {
3282 rewriter.setInsertionPointAfter(readOp);
3283 auto it = mapping.find(readOp.getLo());
3284 assert(it != mapping.end() &&
"bit op mapping not found");
3287 auto newReadValue = rewriter.createOrFold<BitsPrimOp>(
3288 readOp.getLoc(), readOp.getInput(),
3289 readOp.getHi() - readOp.getLo() + it->second, it->second);
3290 rewriter.replaceAllUsesWith(readOp, newReadValue);
3291 rewriter.eraseOp(readOp);
3295 for (
auto writeOp : writeOps) {
3296 Value source = writeOp.getSrc();
3297 rewriter.setInsertionPoint(writeOp);
3299 SmallVector<Value> slices;
3300 for (
auto &[start, end] :
llvm::reverse(ranges)) {
3301 Value slice = rewriter.createOrFold<BitsPrimOp>(writeOp.getLoc(),
3302 source,
end, start);
3303 slices.push_back(slice);
3307 rewriter.createOrFold<CatPrimOp>(writeOp.getLoc(), slices);
3313 if (type.isSigned())
3315 rewriter.createOrFold<AsSIntPrimOp>(writeOp.getLoc(), catOfSlices);
3317 rewriter.replaceOpWithNewOp<MatchingConnectOp>(writeOp, writeOp.getDest(),
3327 using OpRewritePattern::OpRewritePattern;
3328 LogicalResult matchAndRewrite(MemOp mem,
3329 PatternRewriter &rewriter)
const override {
3334 auto ty = mem.getDataType();
3335 auto loc = mem.getLoc();
3336 auto *block = mem->getBlock();
3340 SmallPtrSet<Operation *, 8> connects;
3341 SmallVector<SubfieldOp> portAccesses;
3342 for (
auto [i, port] :
llvm::enumerate(mem.getResults())) {
3343 if (!mem.getPortAnnotation(i).empty())
3346 auto collect = [&, port = port](ArrayRef<StringRef> fields) {
3347 auto portTy = type_cast<BundleType>(port.getType());
3348 for (
auto field : fields) {
3349 auto fieldIndex = portTy.getElementIndex(field);
3350 assert(fieldIndex &&
"missing field on memory port");
3352 for (
auto *op : port.getUsers()) {
3353 auto portAccess = cast<SubfieldOp>(op);
3354 if (fieldIndex != portAccess.getFieldIndex())
3356 portAccesses.push_back(portAccess);
3357 for (
auto *user : portAccess->getUsers()) {
3358 auto conn = dyn_cast<FConnectLike>(user);
3361 connects.insert(conn);
3368 switch (mem.getPortKind(i)) {
3369 case MemOp::PortKind::Debug:
3371 case MemOp::PortKind::Read:
3372 if (failed(collect({
"clk",
"en",
"addr"})))
3375 case MemOp::PortKind::Write:
3376 if (failed(collect({
"clk",
"en",
"addr",
"data",
"mask"})))
3379 case MemOp::PortKind::ReadWrite:
3380 if (failed(collect({
"clk",
"en",
"addr",
"wmode",
"wdata",
"wmask"})))
3386 if (!portClock || (clock && portClock != clock))
3392 rewriter.setInsertionPointAfter(mem);
3393 auto memWire = WireOp::create(rewriter, loc, ty).getResult();
3399 rewriter.setInsertionPointToEnd(block);
3401 RegOp::create(rewriter, loc, ty, clock, mem.getName()).getResult();
3404 MatchingConnectOp::create(rewriter, loc, memWire, memReg);
3408 auto pipeline = [&](Value value, Value clock,
const Twine &name,
3410 for (
unsigned i = 0; i < latency; ++i) {
3411 std::string regName;
3413 llvm::raw_string_ostream os(regName);
3414 os << mem.getName() <<
"_" << name <<
"_" << i;
3416 auto reg = RegOp::create(rewriter, mem.getLoc(), value.getType(), clock,
3417 rewriter.getStringAttr(regName))
3419 MatchingConnectOp::create(rewriter, value.getLoc(), reg, value);
3425 const unsigned writeStages =
info.writeLatency - 1;
3430 SmallVector<std::tuple<Value, Value, Value>> writes;
3431 for (
auto [i, port] :
llvm::enumerate(mem.getResults())) {
3433 StringRef name = mem.getPortName(i);
3435 auto portPipeline = [&, port = port](StringRef field,
unsigned stages) {
3438 return pipeline(value, portClock, name +
"_" + field, stages);
3441 switch (mem.getPortKind(i)) {
3442 case MemOp::PortKind::Debug:
3443 llvm_unreachable(
"unknown port kind");
3444 case MemOp::PortKind::Read: {
3452 case MemOp::PortKind::Write: {
3453 auto data = portPipeline(
"data", writeStages);
3454 auto en = portPipeline(
"en", writeStages);
3455 auto mask = portPipeline(
"mask", writeStages);
3459 case MemOp::PortKind::ReadWrite: {
3464 auto wdata = portPipeline(
"wdata", writeStages);
3465 auto wmask = portPipeline(
"wmask", writeStages);
3470 auto wen = AndPrimOp::create(rewriter, port.getLoc(),
en,
wmode);
3472 pipeline(wen, portClock, name +
"_wen", writeStages);
3473 writes.emplace_back(
wdata, wenPipelined,
wmask);
3480 Value next = memReg;
3486 Location loc = mem.getLoc();
3487 unsigned maskGran =
info.dataWidth /
info.maskBits;
3488 SmallVector<Value> chunks;
3489 for (
unsigned i = 0; i <
info.maskBits; ++i) {
3490 unsigned hi = (i + 1) * maskGran - 1;
3491 unsigned lo = i * maskGran;
3493 auto dataPart = rewriter.createOrFold<BitsPrimOp>(loc,
data, hi, lo);
3494 auto nextPart = rewriter.createOrFold<BitsPrimOp>(loc, next, hi, lo);
3495 auto bit = rewriter.createOrFold<BitsPrimOp>(loc,
mask, i, i);
3496 auto chunk = MuxPrimOp::create(rewriter, loc, bit, dataPart, nextPart);
3497 chunks.push_back(chunk);
3500 std::reverse(chunks.begin(), chunks.end());
3501 masked = rewriter.createOrFold<CatPrimOp>(loc, chunks);
3502 next = MuxPrimOp::create(rewriter, next.getLoc(),
en, masked, next);
3504 Value typedNext = rewriter.createOrFold<BitCastOp>(next.getLoc(), ty, next);
3505 MatchingConnectOp::create(rewriter, memReg.getLoc(), memReg, typedNext);
3508 for (Operation *conn : connects)
3509 rewriter.eraseOp(
conn);
3510 for (
auto portAccess : portAccesses)
3511 rewriter.eraseOp(portAccess);
3512 rewriter.eraseOp(mem);
3519void MemOp::getCanonicalizationPatterns(RewritePatternSet &results,
3522 .insert<FoldZeroWidthMemory, FoldReadOrWriteOnlyMemory,
3523 FoldReadWritePorts, FoldUnusedPorts, FoldUnusedBits, FoldRegMems>(
3543 auto mux = dyn_cast_or_null<MuxPrimOp>(con.getSrc().getDefiningOp());
3546 auto *high = mux.getHigh().getDefiningOp();
3547 auto *low = mux.getLow().getDefiningOp();
3549 auto constOp = dyn_cast_or_null<ConstantOp>(high);
3556 bool constReg =
false;
3558 if (constOp && low == reg)
3560 else if (dyn_cast_or_null<ConstantOp>(low) && high == reg) {
3562 constOp = dyn_cast<ConstantOp>(low);
3569 if (!isa<BlockArgument>(mux.getSel()) && !constReg)
3573 auto regTy = reg.getResult().getType();
3574 if (con.getDest().getType() != regTy || con.getSrc().getType() != regTy ||
3575 mux.getHigh().getType() != regTy || mux.getLow().getType() != regTy ||
3576 regTy.getBitWidthOrSentinel() < 0)
3582 if (constOp != &con->getBlock()->front())
3583 constOp->moveBefore(&con->getBlock()->front());
3586 SmallVector<NamedAttribute, 2> attrs(reg->getDialectAttrs());
3587 auto newReg = replaceOpWithNewOpAndCopyName<RegResetOp>(
3588 rewriter, reg, reg.getResult().getType(), reg.getClockVal(),
3589 mux.getSel(), mux.getHigh(), reg.getNameAttr(), reg.getNameKindAttr(),
3590 reg.getAnnotationsAttr(), reg.getInnerSymAttr(),
3591 reg.getForceableAttr());
3592 newReg->setDialectAttrs(attrs);
3594 auto pt = rewriter.saveInsertionPoint();
3595 rewriter.setInsertionPoint(con);
3596 auto v = constReg ? (Value)constOp.getResult() : (Value)mux.getLow();
3597 replaceOpWithNewOpAndCopyName<ConnectOp>(rewriter, con, con.getDest(), v);
3598 rewriter.restoreInsertionPoint(pt);
3602LogicalResult RegOp::canonicalize(RegOp op, PatternRewriter &rewriter) {
3603 if (!
hasDontTouch(op.getOperation()) && !op.isForceable() &&
3619 PatternRewriter &rewriter,
3622 if (
auto constant = enable.getDefiningOp<firrtl::ConstantOp>()) {
3623 if (constant.getValue().isZero()) {
3624 rewriter.eraseOp(op);
3630 if (
auto constant = predicate.getDefiningOp<firrtl::ConstantOp>()) {
3631 if (constant.getValue().isZero() == eraseIfZero) {
3632 rewriter.eraseOp(op);
3640template <
class Op,
bool EraseIfZero = false>
3642 PatternRewriter &rewriter) {
3647void AssertOp::getCanonicalizationPatterns(RewritePatternSet &results,
3649 results.add(canonicalizeImmediateVerifOp<AssertOp>);
3650 results.add<patterns::AssertXWhenX>(
context);
3653void AssumeOp::getCanonicalizationPatterns(RewritePatternSet &results,
3655 results.add(canonicalizeImmediateVerifOp<AssumeOp>);
3656 results.add<patterns::AssumeXWhenX>(
context);
3659void UnclockedAssumeIntrinsicOp::getCanonicalizationPatterns(
3660 RewritePatternSet &results, MLIRContext *
context) {
3661 results.add(canonicalizeImmediateVerifOp<UnclockedAssumeIntrinsicOp>);
3662 results.add<patterns::UnclockedAssumeIntrinsicXWhenX>(
context);
3665void CoverOp::getCanonicalizationPatterns(RewritePatternSet &results,
3667 results.add(canonicalizeImmediateVerifOp<CoverOp, /* EraseIfZero = */ true>);
3674LogicalResult InvalidValueOp::canonicalize(InvalidValueOp op,
3675 PatternRewriter &rewriter) {
3677 if (op.use_empty()) {
3678 rewriter.eraseOp(op);
3685 if (op->hasOneUse() &&
3686 (isa<BitsPrimOp, HeadPrimOp, ShrPrimOp, TailPrimOp, SubfieldOp,
3687 SubindexOp, AsSIntPrimOp, AsUIntPrimOp, NotPrimOp, BitCastOp>(
3688 *op->user_begin()) ||
3689 (isa<CvtPrimOp>(*op->user_begin()) &&
3690 type_isa<SIntType>(op->user_begin()->getOperand(0).getType())) ||
3691 (isa<AndRPrimOp, XorRPrimOp, OrRPrimOp>(*op->user_begin()) &&
3692 type_cast<FIRRTLBaseType>(op->user_begin()->getOperand(0).getType())
3693 .getBitWidthOrSentinel() > 0))) {
3694 auto *modop = *op->user_begin();
3695 auto inv = InvalidValueOp::create(rewriter, op.getLoc(),
3696 modop->getResult(0).getType());
3697 rewriter.replaceAllOpUsesWith(modop, inv);
3698 rewriter.eraseOp(modop);
3699 rewriter.eraseOp(op);
3705OpFoldResult InvalidValueOp::fold(FoldAdaptor adaptor) {
3706 if (getType().getBitWidthOrSentinel() == 0 && isa<IntType>(getType()))
3707 return getIntAttr(getType(), APInt(0, 0, isa<SIntType>(getType())));
3715OpFoldResult ClockGateIntrinsicOp::fold(FoldAdaptor adaptor) {
3724 return BoolAttr::get(getContext(),
false);
3728 return BoolAttr::get(getContext(),
false);
3733LogicalResult ClockGateIntrinsicOp::canonicalize(ClockGateIntrinsicOp op,
3734 PatternRewriter &rewriter) {
3736 if (
auto testEnable = op.getTestEnable()) {
3737 if (
auto constOp = testEnable.getDefiningOp<ConstantOp>()) {
3738 if (constOp.getValue().isZero()) {
3739 rewriter.modifyOpInPlace(op,
3740 [&] { op.getTestEnableMutable().clear(); });
3756 auto forceable = op.getRef().getDefiningOp<Forceable>();
3757 if (!forceable || !forceable.isForceable() ||
3758 op.getRef() != forceable.getDataRef() ||
3759 op.getType() != forceable.getDataType())
3761 rewriter.replaceAllUsesWith(op, forceable.getData());
3765void RefResolveOp::getCanonicalizationPatterns(RewritePatternSet &results,
3767 results.insert<patterns::RefResolveOfRefSend>(
context);
3771OpFoldResult RefCastOp::fold(FoldAdaptor adaptor) {
3773 if (getInput().getType() == getType())
3779 auto constOp = operand.getDefiningOp<ConstantOp>();
3780 return constOp && constOp.getValue().isZero();
3783template <
typename Op>
3786 rewriter.eraseOp(op);
3792void RefForceOp::getCanonicalizationPatterns(RewritePatternSet &results,
3794 results.add(eraseIfPredFalse<RefForceOp>);
3796void RefForceInitialOp::getCanonicalizationPatterns(RewritePatternSet &results,
3798 results.add(eraseIfPredFalse<RefForceInitialOp>);
3800void RefReleaseOp::getCanonicalizationPatterns(RewritePatternSet &results,
3802 results.add(eraseIfPredFalse<RefReleaseOp>);
3804void RefReleaseInitialOp::getCanonicalizationPatterns(
3805 RewritePatternSet &results, MLIRContext *
context) {
3806 results.add(eraseIfPredFalse<RefReleaseInitialOp>);
3813OpFoldResult HasBeenResetIntrinsicOp::fold(FoldAdaptor adaptor) {
3819 if (adaptor.getReset())
3824 if (
isUInt1(getReset().getType()) && adaptor.getClock())
3837 [&](
auto ty) ->
bool {
return isTypeEmpty(ty.getElementType()); })
3838 .Case<BundleType>([&](
auto ty) ->
bool {
3839 for (
auto elem : ty.getElements())
3844 .Case<IntType>([&](
auto ty) {
return ty.getWidth() == 0; })
3845 .Default([](
auto) ->
bool {
return false; });
3848LogicalResult FPGAProbeIntrinsicOp::canonicalize(FPGAProbeIntrinsicOp op,
3849 PatternRewriter &rewriter) {
3850 auto firrtlTy = type_dyn_cast<FIRRTLType>(op.getInput().getType());
3857 rewriter.eraseOp(op);
3865LogicalResult LayerBlockOp::canonicalize(LayerBlockOp op,
3866 PatternRewriter &rewriter) {
3869 if (op.getBody()->empty()) {
3870 rewriter.eraseOp(op);
3881OpFoldResult UnsafeDomainCastOp::fold(FoldAdaptor adaptor) {
3883 if (getDomains().
empty())
assert(baseType &&"element must be base type")
static std::unique_ptr< Context > context
static bool hasKnownWidthIntTypes(Operation *op)
Return true if this operation's operands and results all have a known width.
static LogicalResult canonicalizeImmediateVerifOp(Op op, PatternRewriter &rewriter)
static bool isDefinedByOneConstantOp(Value v)
static Attribute collectFields(MLIRContext *context, ArrayRef< Attribute > operands)
static LogicalResult canonicalizeSingleSetConnect(MatchingConnectOp op, PatternRewriter &rewriter)
static void erasePort(PatternRewriter &rewriter, Value port)
static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op, Region ®ion)
Replaces the given op with the contents of the given single-block region.
static std::optional< APSInt > getExtendedConstant(Value operand, Attribute constant, int32_t destWidth)
Implicitly replace the operand to a constant folding operation with a const 0 in case the operand is ...
static Value getPortFieldValue(Value port, StringRef name)
static AttachOp getDominatingAttachUser(Value value, AttachOp dominatedAttach)
If the specified value has an AttachOp user strictly dominating by "dominatingAttach" then return it.
static OpTy replaceOpWithNewOpAndCopyName(PatternRewriter &rewriter, Operation *op, Args &&...args)
A wrapper of PatternRewriter::replaceOpWithNewOp to propagate "name" attribute.
static void updateName(PatternRewriter &rewriter, Operation *op, StringAttr name)
Set the name of an op based on the best of two names: The current name, and the name passed in.
static bool isTypeEmpty(FIRRTLType type)
static bool isUInt1(Type type)
Return true if this value is 1 bit UInt.
static LogicalResult demoteForceableIfUnused(OpTy op, PatternRewriter &rewriter)
static bool isPortDisabled(Value port)
static LogicalResult eraseIfZeroOrNotZero(Operation *op, Value predicate, Value enable, PatternRewriter &rewriter, bool eraseIfZero)
static APInt getMaxSignedValue(unsigned bitWidth)
Get the largest signed value of a given bit width.
static Value dropWrite(PatternRewriter &rewriter, OpResult old, Value passthrough)
static LogicalResult canonicalizePrimOp(Operation *op, PatternRewriter &rewriter, const function_ref< OpFoldResult(ArrayRef< Attribute >)> &canonicalize)
Applies the canonicalization function canonicalize to the given operation.
static void replaceWithBits(Operation *op, Value value, unsigned hiBit, unsigned loBit, PatternRewriter &rewriter)
Replace the specified operation with a 'bits' op from the specified hi/lo bits.
static LogicalResult canonicalizeRegResetWithOneReset(RegResetOp reg, PatternRewriter &rewriter)
static LogicalResult eraseIfPredFalse(Op op, PatternRewriter &rewriter)
static OpFoldResult foldMux(OpTy op, typename OpTy::FoldAdaptor adaptor)
static APInt getMaxUnsignedValue(unsigned bitWidth)
Get the largest unsigned value of a given bit width.
static std::optional< APSInt > getConstant(Attribute operand)
Determine the value of a constant operand for the sake of constant folding.
static void replacePortField(PatternRewriter &rewriter, Value port, StringRef name, Value value)
BinOpKind
This is the policy for folding, which depends on the sort of operator we're processing.
static bool isPortUnused(Value port, StringRef data)
static bool isOkToPropagateName(Operation *op)
static LogicalResult canonicalizeRefResolveOfForceable(RefResolveOp op, PatternRewriter &rewriter)
static Attribute constFoldFIRRTLBinaryOp(Operation *op, ArrayRef< Attribute > operands, BinOpKind opKind, const function_ref< APInt(const APSInt &, const APSInt &)> &calculate)
Applies the constant folding function calculate to the given operands.
static APInt getMinSignedValue(unsigned bitWidth)
Get the smallest signed value of a given bit width.
static LogicalResult foldHiddenReset(RegOp reg, PatternRewriter &rewriter)
static Value moveNameHint(OpResult old, Value passthrough)
static void replaceOpAndCopyName(PatternRewriter &rewriter, Operation *op, Value newValue)
A wrapper of PatternRewriter::replaceOp to propagate "name" attribute.
static Location getLoc(DefSlot slot)
static InstancePath empty
AndRCat(MLIRContext *context)
bool handleConstant(mlir::PatternRewriter &rewriter, Operation *op, ConstantOp value, SmallVectorImpl< Value > &remaining) const override
Handle a constant operand in the cat operation.
bool getIdentityValue() const override
Return the unit value for this reduction operation:
bool handleConstant(mlir::PatternRewriter &rewriter, Operation *op, ConstantOp value, SmallVectorImpl< Value > &remaining) const override
Handle a constant operand in the cat operation.
OrRCat(MLIRContext *context)
bool getIdentityValue() const override
Return the unit value for this reduction operation:
virtual bool getIdentityValue() const =0
Return the unit value for this reduction operation:
virtual bool handleConstant(mlir::PatternRewriter &rewriter, Operation *op, ConstantOp constantOp, SmallVectorImpl< Value > &remaining) const =0
Handle a constant operand in the cat operation.
LogicalResult matchAndRewrite(Operation *op, mlir::PatternRewriter &rewriter) const override
ReductionCat(MLIRContext *context, llvm::StringLiteral opName)
XorRCat(MLIRContext *context)
bool handleConstant(mlir::PatternRewriter &rewriter, Operation *op, ConstantOp value, SmallVectorImpl< Value > &remaining) const override
Handle a constant operand in the cat operation.
bool getIdentityValue() const override
Return the unit value for this reduction operation:
This class provides a read-only projection over the MLIR attributes that represent a set of annotatio...
This class implements the same functionality as TypeSwitch except that it uses firrtl::type_dyn_cast ...
FIRRTLTypeSwitch< T, ResultT > & Case(CallableT &&caseFn)
Add a case on the given type.
This is the common base class between SIntType and UIntType.
int32_t getWidthOrSentinel() const
Return the width of this type, or -1 if it has none specified.
static IntType get(MLIRContext *context, bool isSigned, int32_t widthOrSentinel=-1, bool isConst=false)
Return an SIntType or UIntType with the specified signedness, width, and constness.
bool hasWidth() const
Return true if this integer type has a known width.
std::optional< int32_t > getWidth() const
Return an optional containing the width, if the width is known (or empty if width is unknown).
uint64_t getWidth(Type t)
Forceable replaceWithNewForceability(Forceable op, bool forceable, ::mlir::PatternRewriter *rewriter=nullptr)
Replace a Forceable op with equivalent, changing whether forceable.
bool areAnonymousTypesEquivalent(FIRRTLBaseType lhs, FIRRTLBaseType rhs)
Return true if anonymous types of given arguments are equivalent by pointer comparison.
IntegerAttr getIntAttr(Type type, const APInt &value)
Utiility for generating a constant attribute.
bool hasDontTouch(Value value)
Check whether a block argument ("port") or the operation defining a value has a DontTouch annotation,...
bool hasDroppableName(Operation *op)
Return true if the name is droppable.
MatchingConnectOp getSingleConnectUserOf(Value value)
Scan all the uses of the specified value, checking to see if there is exactly one connect that has th...
std::optional< int64_t > getBitWidth(FIRRTLBaseType type, bool ignoreFlip=false)
IntegerAttr getIntZerosAttr(Type type)
Utility for generating a constant zero attribute.
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
APSInt extOrTruncZeroWidth(APSInt value, unsigned width)
A safe version of APSInt::extOrTrunc that will NOT assert on zero-width signed APSInts.
APInt sextZeroWidth(APInt value, unsigned width)
A safe version of APInt::sext that will NOT assert on zero-width signed APSInts.
StringRef chooseName(StringRef a, StringRef b)
Choose a good name for an item from two options.
static bool isConstantZero(Attribute operand)
Determine whether a constant operand is a zero value.
static bool isConstantOne(Attribute operand)
Determine whether a constant operand is a one value.
reg(value, clock, reset=None, reset_value=None, name=None, sym_name=None)
LogicalResult matchAndRewrite(BitsPrimOp bits, mlir::PatternRewriter &rewriter) const override