20#include "mlir/IR/Matchers.h"
21#include "mlir/IR/PatternMatch.h"
22#include "llvm/ADT/APSInt.h"
23#include "llvm/ADT/STLExtras.h"
24#include "llvm/ADT/SmallPtrSet.h"
25#include "llvm/ADT/StringExtras.h"
26#include "llvm/ADT/TypeSwitch.h"
29using namespace firrtl;
33static Value
dropWrite(PatternRewriter &rewriter, OpResult old,
35 SmallPtrSet<Operation *, 8> users;
36 for (
auto *user : old.getUsers())
38 for (Operation *user : users)
39 if (
auto connect = dyn_cast<FConnectLike>(user))
40 if (connect.getDest() == old)
41 rewriter.eraseOp(user);
51 if (op->getNumRegions() != 0)
53 return mlir::isPure(op) || isa<NodeOp, WireOp>(op);
61 Operation *op = passthrough.getDefiningOp();
64 assert(op &&
"passthrough must be an operation");
65 Operation *oldOp = old.getOwner();
66 auto name = oldOp->getAttrOfType<StringAttr>(
"name");
68 op->setAttr(
"name", name);
76#include "circt/Dialect/FIRRTL/FIRRTLCanonicalization.h.inc"
84 auto resultType = type_cast<IntType>(op->getResult(0).getType());
85 if (!resultType.hasWidth())
87 for (Value operand : op->getOperands())
88 if (!type_cast<IntType>(operand.getType()).hasWidth())
95 auto t = type_dyn_cast<UIntType>(type);
96 if (!t || !t.hasWidth() || t.getWidth() != 1)
103static void updateName(PatternRewriter &rewriter, Operation *op,
108 assert((!isa<InstanceOp, RegOp, RegResetOp>(op)) &&
"Should never rename");
109 auto newName = name.getValue();
110 auto newOpName = op->getAttrOfType<StringAttr>(
"name");
113 newName =
chooseName(newOpName.getValue(), name.getValue());
115 if (!newOpName || newOpName.getValue() != newName)
116 rewriter.modifyOpInPlace(
117 op, [&] { op->setAttr(
"name", rewriter.getStringAttr(newName)); });
125 if (
auto *newOp = newValue.getDefiningOp()) {
126 auto name = op->getAttrOfType<StringAttr>(
"name");
129 rewriter.replaceOp(op, newValue);
135template <
typename OpTy,
typename... Args>
137 Operation *op, Args &&...args) {
138 auto name = op->getAttrOfType<StringAttr>(
"name");
140 rewriter.replaceOpWithNewOp<OpTy>(op, std::forward<Args>(args)...);
148 if (
auto namableOp = dyn_cast<firrtl::FNamableOp>(op))
149 return namableOp.hasDroppableName();
160static std::optional<APSInt>
162 assert(type_cast<IntType>(operand.getType()) &&
163 "getExtendedConstant is limited to integer types");
170 if (IntegerAttr result = dyn_cast_or_null<IntegerAttr>(constant))
175 if (type_cast<IntType>(operand.getType()).getWidth() == 0)
176 return APSInt(destWidth,
177 type_cast<IntType>(operand.getType()).isUnsigned());
185 if (
auto attr = dyn_cast<BoolAttr>(operand))
186 return APSInt(APInt(1, attr.getValue()));
187 if (
auto attr = dyn_cast<IntegerAttr>(operand))
188 return attr.getAPSInt();
196 return cst->isZero();
223 Operation *op, ArrayRef<Attribute> operands,
BinOpKind opKind,
224 const function_ref<APInt(
const APSInt &,
const APSInt &)> &calculate) {
225 assert(operands.size() == 2 &&
"binary op takes two operands");
228 auto resultType = type_cast<IntType>(op->getResult(0).getType());
229 if (resultType.getWidthOrSentinel() < 0)
233 if (resultType.getWidthOrSentinel() == 0)
234 return getIntAttr(resultType, APInt(0, 0, resultType.isSigned()));
240 type_cast<IntType>(op->getOperand(0).getType()).getWidthOrSentinel();
242 type_cast<IntType>(op->getOperand(1).getType()).getWidthOrSentinel();
243 if (
auto lhs = dyn_cast_or_null<IntegerAttr>(operands[0]))
244 lhsWidth = std::max<int32_t>(lhsWidth, lhs.getValue().getBitWidth());
245 if (
auto rhs = dyn_cast_or_null<IntegerAttr>(operands[1]))
246 rhsWidth = std::max<int32_t>(rhsWidth, rhs.getValue().getBitWidth());
250 int32_t operandWidth;
253 operandWidth = resultType.getWidthOrSentinel();
258 operandWidth = std::max(1, std::max(lhsWidth, rhsWidth));
262 std::max(std::max(lhsWidth, rhsWidth), resultType.getWidthOrSentinel());
273 APInt resultValue = calculate(*lhs, *rhs);
278 resultValue = resultValue.trunc(resultType.getWidthOrSentinel());
280 assert((
unsigned)resultType.getWidthOrSentinel() ==
281 resultValue.getBitWidth());
294 Operation *op, PatternRewriter &rewriter,
295 const function_ref<OpFoldResult(ArrayRef<Attribute>)> &canonicalize) {
297 if (op->getNumResults() != 1)
299 auto type = type_dyn_cast<FIRRTLBaseType>(op->getResult(0).getType());
304 auto width = type.getBitWidthOrSentinel();
309 SmallVector<Attribute, 3> constOperands;
310 constOperands.reserve(op->getNumOperands());
311 for (
auto operand : op->getOperands()) {
313 if (
auto *defOp = operand.getDefiningOp())
314 TypeSwitch<Operation *>(defOp).Case<ConstantOp, SpecialConstantOp>(
315 [&](
auto op) { attr = op.getValueAttr(); });
316 constOperands.push_back(attr);
321 auto result = canonicalize(constOperands);
325 if (
auto cst = dyn_cast<Attribute>(result))
326 resultValue = op->getDialect()
327 ->materializeConstant(rewriter, cst, type, op->getLoc())
330 resultValue = cast<Value>(result);
334 type_cast<FIRRTLBaseType>(resultValue.getType()).getBitWidthOrSentinel())
335 resultValue = PadPrimOp::create(rewriter, op->getLoc(), resultValue, width);
338 if (type_isa<SIntType>(type) && type_isa<UIntType>(resultValue.getType()))
339 resultValue = AsSIntPrimOp::create(rewriter, op->getLoc(), resultValue);
340 else if (type_isa<UIntType>(type) &&
341 type_isa<SIntType>(resultValue.getType()))
342 resultValue = AsUIntPrimOp::create(rewriter, op->getLoc(), resultValue);
344 assert(type == resultValue.getType() &&
"canonicalization changed type");
352 return bitWidth > 0 ? APInt::getMaxValue(bitWidth) : APInt();
358 return bitWidth > 0 ? APInt::getSignedMinValue(bitWidth) : APInt();
364 return bitWidth > 0 ? APInt::getSignedMaxValue(bitWidth) : APInt();
371OpFoldResult ConstantOp::fold(FoldAdaptor adaptor) {
372 assert(adaptor.getOperands().empty() &&
"constant has no operands");
373 return getValueAttr();
376OpFoldResult SpecialConstantOp::fold(FoldAdaptor adaptor) {
377 assert(adaptor.getOperands().empty() &&
"constant has no operands");
378 return getValueAttr();
381OpFoldResult AggregateConstantOp::fold(FoldAdaptor adaptor) {
382 assert(adaptor.getOperands().empty() &&
"constant has no operands");
383 return getFieldsAttr();
386OpFoldResult StringConstantOp::fold(FoldAdaptor adaptor) {
387 assert(adaptor.getOperands().empty() &&
"constant has no operands");
388 return getValueAttr();
391OpFoldResult FIntegerConstantOp::fold(FoldAdaptor adaptor) {
392 assert(adaptor.getOperands().empty() &&
"constant has no operands");
393 return getValueAttr();
396OpFoldResult BoolConstantOp::fold(FoldAdaptor adaptor) {
397 assert(adaptor.getOperands().empty() &&
"constant has no operands");
398 return getValueAttr();
401OpFoldResult DoubleConstantOp::fold(FoldAdaptor adaptor) {
402 assert(adaptor.getOperands().empty() &&
"constant has no operands");
403 return getValueAttr();
410OpFoldResult AddPrimOp::fold(FoldAdaptor adaptor) {
413 [=](
const APSInt &a,
const APSInt &b) { return a + b; });
416void AddPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
418 results.insert<patterns::moveConstAdd, patterns::AddOfZero,
419 patterns::AddOfSelf, patterns::AddOfPad>(
context);
422OpFoldResult SubPrimOp::fold(FoldAdaptor adaptor) {
425 [=](
const APSInt &a,
const APSInt &b) { return a - b; });
428void SubPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
430 results.insert<patterns::SubOfZero, patterns::SubFromZeroSigned,
431 patterns::SubFromZeroUnsigned, patterns::SubOfSelf,
432 patterns::SubOfPadL, patterns::SubOfPadR>(
context);
435OpFoldResult MulPrimOp::fold(FoldAdaptor adaptor) {
447 [=](
const APSInt &a,
const APSInt &b) { return a * b; });
450OpFoldResult DivPrimOp::fold(FoldAdaptor adaptor) {
457 if (getLhs() == getRhs()) {
458 auto width = getType().base().getWidthOrSentinel();
463 return getIntAttr(getType(), APInt(width, 1));
480 if (
auto rhsCst = dyn_cast_or_null<IntegerAttr>(adaptor.getRhs()))
481 if (rhsCst.getValue().isOne() && getLhs().getType() == getType())
486 [=](
const APSInt &a,
const APSInt &b) -> APInt {
489 return APInt(a.getBitWidth(), 0);
493OpFoldResult RemPrimOp::fold(FoldAdaptor adaptor) {
500 if (getLhs() == getRhs())
514 [=](
const APSInt &a,
const APSInt &b) -> APInt {
517 return APInt(a.getBitWidth(), 0);
521OpFoldResult DShlPrimOp::fold(FoldAdaptor adaptor) {
524 [=](
const APSInt &a,
const APSInt &b) -> APInt { return a.shl(b); });
527OpFoldResult DShlwPrimOp::fold(FoldAdaptor adaptor) {
530 [=](
const APSInt &a,
const APSInt &b) -> APInt { return a.shl(b); });
533OpFoldResult DShrPrimOp::fold(FoldAdaptor adaptor) {
536 [=](
const APSInt &a,
const APSInt &b) -> APInt {
537 return getType().base().isUnsigned() || !a.getBitWidth() ? a.lshr(b)
543OpFoldResult AndPrimOp::fold(FoldAdaptor adaptor) {
546 if (rhsCst->isZero())
550 if (rhsCst->isAllOnes() && getLhs().getType() == getType() &&
551 getRhs().getType() == getType())
557 if (lhsCst->isZero())
561 if (lhsCst->isAllOnes() && getLhs().getType() == getType() &&
562 getRhs().getType() == getType())
567 if (getLhs() == getRhs() && getRhs().getType() == getType())
572 [](
const APSInt &a,
const APSInt &b) -> APInt { return a & b; });
575void AndPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
578 .insert<patterns::extendAnd, patterns::moveConstAnd, patterns::AndOfZero,
579 patterns::AndOfAllOne, patterns::AndOfSelf, patterns::AndOfPad,
580 patterns::AndOfAsSIntL, patterns::AndOfAsSIntR>(
context);
583OpFoldResult OrPrimOp::fold(FoldAdaptor adaptor) {
586 if (rhsCst->isZero() && getLhs().getType() == getType())
590 if (rhsCst->isAllOnes() && getRhs().getType() == getType() &&
591 getLhs().getType() == getType())
597 if (lhsCst->isZero() && getRhs().getType() == getType())
601 if (lhsCst->isAllOnes() && getLhs().getType() == getType() &&
602 getRhs().getType() == getType())
607 if (getLhs() == getRhs() && getRhs().getType() == getType())
612 [](
const APSInt &a,
const APSInt &b) -> APInt { return a | b; });
615void OrPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
617 results.insert<patterns::extendOr, patterns::moveConstOr, patterns::OrOfZero,
618 patterns::OrOfAllOne, patterns::OrOfSelf, patterns::OrOfPad,
622OpFoldResult XorPrimOp::fold(FoldAdaptor adaptor) {
625 if (rhsCst->isZero() &&
631 if (lhsCst->isZero() &&
636 if (getLhs() == getRhs())
639 APInt(std::max(getType().base().getWidthOrSentinel(), 0), 0));
643 [](
const APSInt &a,
const APSInt &b) -> APInt { return a ^ b; });
646void XorPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
648 results.insert<patterns::extendXor, patterns::moveConstXor,
649 patterns::XorOfZero, patterns::XorOfSelf, patterns::XorOfPad>(
653void LEQPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
655 results.insert<patterns::LEQWithConstLHS>(
context);
658OpFoldResult LEQPrimOp::fold(FoldAdaptor adaptor) {
659 bool isUnsigned = getLhs().getType().base().isUnsigned();
662 if (getLhs() == getRhs())
666 if (
auto width = getLhs().getType().base().
getWidth()) {
668 auto commonWidth = std::max<int32_t>(*width, rhsCst->getBitWidth());
669 commonWidth = std::max(commonWidth, 1);
680 if (isUnsigned && rhsCst->zext(commonWidth)
693 [=](
const APSInt &a,
const APSInt &b) -> APInt {
694 return APInt(1, a <= b);
698void LTPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
700 results.insert<patterns::LTWithConstLHS>(
context);
703OpFoldResult LTPrimOp::fold(FoldAdaptor adaptor) {
704 IntType lhsType = getLhs().getType();
708 if (getLhs() == getRhs())
718 if (
auto width = lhsType.
getWidth()) {
720 auto commonWidth = std::max<int32_t>(*width, rhsCst->getBitWidth());
721 commonWidth = std::max(commonWidth, 1);
732 if (isUnsigned && rhsCst->zext(commonWidth)
745 [=](
const APSInt &a,
const APSInt &b) -> APInt {
746 return APInt(1, a < b);
750void GEQPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
752 results.insert<patterns::GEQWithConstLHS>(
context);
755OpFoldResult GEQPrimOp::fold(FoldAdaptor adaptor) {
756 IntType lhsType = getLhs().getType();
760 if (getLhs() == getRhs())
765 if (rhsCst->isZero() && isUnsigned)
770 if (
auto width = lhsType.
getWidth()) {
772 auto commonWidth = std::max<int32_t>(*width, rhsCst->getBitWidth());
773 commonWidth = std::max(commonWidth, 1);
776 if (isUnsigned && rhsCst->zext(commonWidth)
797 [=](
const APSInt &a,
const APSInt &b) -> APInt {
798 return APInt(1, a >= b);
802void GTPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
804 results.insert<patterns::GTWithConstLHS>(
context);
807OpFoldResult GTPrimOp::fold(FoldAdaptor adaptor) {
808 IntType lhsType = getLhs().getType();
812 if (getLhs() == getRhs())
816 if (
auto width = lhsType.
getWidth()) {
818 auto commonWidth = std::max<int32_t>(*width, rhsCst->getBitWidth());
819 commonWidth = std::max(commonWidth, 1);
822 if (isUnsigned && rhsCst->zext(commonWidth)
843 [=](
const APSInt &a,
const APSInt &b) -> APInt {
844 return APInt(1, a > b);
848OpFoldResult EQPrimOp::fold(FoldAdaptor adaptor) {
850 if (getLhs() == getRhs())
856 if (rhsCst->isAllOnes() && getLhs().getType() == getType() &&
857 getRhs().getType() == getType())
863 [=](
const APSInt &a,
const APSInt &b) -> APInt {
864 return APInt(1, a == b);
868LogicalResult EQPrimOp::canonicalize(EQPrimOp op, PatternRewriter &rewriter) {
870 op, rewriter, [&](ArrayRef<Attribute> operands) -> OpFoldResult {
872 auto width = op.getLhs().getType().getBitWidthOrSentinel();
875 if (rhsCst->isZero() && op.getLhs().getType() == op.getType() &&
876 op.getRhs().getType() == op.getType()) {
877 return NotPrimOp::create(rewriter, op.getLoc(), op.getLhs())
882 if (rhsCst->isZero() && width > 1) {
883 auto orrOp = OrRPrimOp::create(rewriter, op.getLoc(), op.getLhs());
884 return NotPrimOp::create(rewriter, op.getLoc(), orrOp).getResult();
888 if (rhsCst->isAllOnes() && width > 1 &&
889 op.getLhs().getType() == op.getRhs().getType()) {
890 return AndRPrimOp::create(rewriter, op.getLoc(), op.getLhs())
898OpFoldResult NEQPrimOp::fold(FoldAdaptor adaptor) {
900 if (getLhs() == getRhs())
906 if (rhsCst->isZero() && getLhs().getType() == getType() &&
907 getRhs().getType() == getType())
913 [=](
const APSInt &a,
const APSInt &b) -> APInt {
914 return APInt(1, a != b);
918LogicalResult NEQPrimOp::canonicalize(NEQPrimOp op, PatternRewriter &rewriter) {
920 op, rewriter, [&](ArrayRef<Attribute> operands) -> OpFoldResult {
922 auto width = op.getLhs().getType().getBitWidthOrSentinel();
925 if (rhsCst->isAllOnes() && op.getLhs().getType() == op.getType() &&
926 op.getRhs().getType() == op.getType()) {
927 return NotPrimOp::create(rewriter, op.getLoc(), op.getLhs())
932 if (rhsCst->isZero() && width > 1) {
933 return OrRPrimOp::create(rewriter, op.getLoc(), op.getLhs())
938 if (rhsCst->isAllOnes() && width > 1 &&
939 op.getLhs().getType() == op.getRhs().getType()) {
941 AndRPrimOp::create(rewriter, op.getLoc(), op.getLhs());
942 return NotPrimOp::create(rewriter, op.getLoc(), andrOp).getResult();
950OpFoldResult IntegerAddOp::fold(FoldAdaptor adaptor) {
956OpFoldResult IntegerMulOp::fold(FoldAdaptor adaptor) {
962OpFoldResult IntegerShrOp::fold(FoldAdaptor adaptor) {
966 return IntegerAttr::get(IntegerType::get(getContext(),
967 lhsCst->getBitWidth(),
968 IntegerType::Signed),
969 lhsCst->ashr(*rhsCst));
972 if (rhsCst->isZero())
979OpFoldResult IntegerShlOp::fold(FoldAdaptor adaptor) {
984 return IntegerAttr::get(IntegerType::get(getContext(),
985 lhsCst->getBitWidth(),
986 IntegerType::Signed),
987 lhsCst->shl(*rhsCst));
990 if (rhsCst->isZero())
1001OpFoldResult SizeOfIntrinsicOp::fold(FoldAdaptor) {
1002 auto base = getInput().getType();
1009OpFoldResult IsXIntrinsicOp::fold(FoldAdaptor adaptor) {
1016OpFoldResult AsSIntPrimOp::fold(FoldAdaptor adaptor) {
1024 if (getType().base().hasWidth())
1031void AsSIntPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1033 results.insert<patterns::StoUtoS>(
context);
1036OpFoldResult AsUIntPrimOp::fold(FoldAdaptor adaptor) {
1044 if (getType().base().hasWidth())
1051void AsUIntPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1053 results.insert<patterns::UtoStoU>(
context);
1056OpFoldResult AsAsyncResetPrimOp::fold(FoldAdaptor adaptor) {
1058 if (getInput().getType() == getType())
1063 return BoolAttr::get(getContext(), cst->getBoolValue());
1068OpFoldResult AsResetPrimOp::fold(FoldAdaptor adaptor) {
1070 return BoolAttr::get(getContext(), cst->getBoolValue());
1074OpFoldResult AsClockPrimOp::fold(FoldAdaptor adaptor) {
1076 if (getInput().getType() == getType())
1081 return BoolAttr::get(getContext(), cst->getBoolValue());
1086OpFoldResult CvtPrimOp::fold(FoldAdaptor adaptor) {
1092 getType().base().getWidthOrSentinel()))
1098void CvtPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1100 results.insert<patterns::CVTSigned, patterns::CVTUnSigned>(
context);
1103OpFoldResult NegPrimOp::fold(FoldAdaptor adaptor) {
1110 getType().base().getWidthOrSentinel()))
1111 return getIntAttr(getType(), APInt((*cst).getBitWidth(), 0) - *cst);
1116OpFoldResult NotPrimOp::fold(FoldAdaptor adaptor) {
1121 getType().base().getWidthOrSentinel()))
1127void NotPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1129 results.insert<patterns::NotNot, patterns::NotEq, patterns::NotNeq,
1130 patterns::NotLeq, patterns::NotLt, patterns::NotGeq,
1137 : RewritePattern(opName, 0,
context) {}
1143 ConstantOp constantOp,
1144 SmallVectorImpl<Value> &remaining)
const = 0;
1151 mlir::PatternRewriter &rewriter)
const override {
1153 auto catOp = op->getOperand(0).getDefiningOp<CatPrimOp>();
1157 SmallVector<Value> nonConstantOperands;
1160 for (
auto operand : catOp.getInputs()) {
1161 if (
auto constantOp = operand.getDefiningOp<ConstantOp>()) {
1163 if (
handleConstant(rewriter, op, constantOp, nonConstantOperands))
1167 nonConstantOperands.push_back(operand);
1172 if (nonConstantOperands.empty()) {
1173 replaceOpWithNewOpAndCopyName<ConstantOp>(
1174 rewriter, op, cast<IntType>(op->getResult(0).getType()),
1180 if (nonConstantOperands.size() == 1) {
1181 rewriter.modifyOpInPlace(
1182 op, [&] { op->setOperand(0, nonConstantOperands.front()); });
1187 if (catOp->hasOneUse() &&
1188 nonConstantOperands.size() < catOp->getNumOperands()) {
1189 replaceOpWithNewOpAndCopyName<CatPrimOp>(rewriter, catOp,
1190 nonConstantOperands);
1203 SmallVectorImpl<Value> &remaining)
const override {
1204 if (value.getValue().isZero())
1207 replaceOpWithNewOpAndCopyName<ConstantOp>(
1208 rewriter, op, cast<IntType>(op->getResult(0).getType()),
1221 SmallVectorImpl<Value> &remaining)
const override {
1222 if (value.getValue().isAllOnes())
1225 replaceOpWithNewOpAndCopyName<ConstantOp>(
1226 rewriter, op, cast<IntType>(op->getResult(0).getType()),
1239 SmallVectorImpl<Value> &remaining)
const override {
1240 if (value.getValue().isZero())
1242 remaining.push_back(value);
1248OpFoldResult AndRPrimOp::fold(FoldAdaptor adaptor) {
1252 if (getInput().getType().getBitWidthOrSentinel() == 0)
1257 return getIntAttr(getType(), APInt(1, cst->isAllOnes()));
1261 if (
isUInt1(getInput().getType()))
1267void AndRPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1269 results.insert<patterns::AndRasSInt, patterns::AndRasUInt, patterns::AndRPadU,
1270 patterns::AndRPadS, patterns::AndRCatAndR_left,
1274OpFoldResult OrRPrimOp::fold(FoldAdaptor adaptor) {
1278 if (getInput().getType().getBitWidthOrSentinel() == 0)
1283 return getIntAttr(getType(), APInt(1, !cst->isZero()));
1287 if (
isUInt1(getInput().getType()))
1293void OrRPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1295 results.insert<patterns::OrRasSInt, patterns::OrRasUInt, patterns::OrRPadU,
1296 patterns::OrRCatOrR_left, patterns::OrRCatOrR_right,
OrRCat>(
1300OpFoldResult XorRPrimOp::fold(FoldAdaptor adaptor) {
1304 if (getInput().getType().getBitWidthOrSentinel() == 0)
1309 return getIntAttr(getType(), APInt(1, cst->popcount() & 1));
1312 if (
isUInt1(getInput().getType()))
1318void XorRPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1321 .insert<patterns::XorRasSInt, patterns::XorRasUInt, patterns::XorRPadU,
1322 patterns::XorRCatXorR_left, patterns::XorRCatXorR_right,
XorRCat>(
1330OpFoldResult CatPrimOp::fold(FoldAdaptor adaptor) {
1331 auto inputs = getInputs();
1332 auto inputAdaptors = adaptor.getInputs();
1339 if (inputs.size() == 1 && inputs[0].getType() == getType())
1347 SmallVector<Value> nonZeroInputs;
1348 SmallVector<Attribute> nonZeroAttributes;
1349 bool allConstant =
true;
1350 for (
auto [input, attr] :
llvm::zip(inputs, inputAdaptors)) {
1351 auto inputType = type_cast<IntType>(input.getType());
1352 if (inputType.getBitWidthOrSentinel() != 0) {
1353 nonZeroInputs.push_back(input);
1355 allConstant =
false;
1356 if (nonZeroInputs.size() > 1 && !allConstant)
1362 if (nonZeroInputs.empty())
1366 if (nonZeroInputs.size() == 1 && nonZeroInputs[0].getType() == getType())
1367 return nonZeroInputs[0];
1373 SmallVector<APInt> constants;
1374 for (
auto inputAdaptor : inputAdaptors) {
1376 constants.push_back(*cst);
1381 assert(!constants.empty());
1383 APInt result = constants[0];
1384 for (
size_t i = 1; i < constants.size(); ++i)
1385 result = result.concat(constants[i]);
1390void DShlPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1392 results.insert<patterns::DShlOfConstant>(
context);
1395void DShrPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1397 results.insert<patterns::DShrOfConstant>(
context);
1405 using OpRewritePattern::OpRewritePattern;
1408 matchAndRewrite(CatPrimOp cat,
1409 mlir::PatternRewriter &rewriter)
const override {
1411 cat.getType().getBitWidthOrSentinel() == 0)
1415 if (cat->hasOneUse() && isa<CatPrimOp>(*cat->getUsers().begin()))
1419 SmallVector<Value> operands;
1420 SmallVector<Value> worklist;
1421 auto pushOperands = [&worklist](CatPrimOp op) {
1422 for (
auto operand :
llvm::reverse(op.getInputs()))
1423 worklist.push_back(operand);
1426 bool hasSigned =
false, hasUnsigned =
false;
1427 while (!worklist.empty()) {
1428 auto value = worklist.pop_back_val();
1429 auto catOp = value.getDefiningOp<CatPrimOp>();
1431 operands.push_back(value);
1432 (type_isa<UIntType>(value.getType()) ? hasUnsigned : hasSigned) =
true;
1436 pushOperands(catOp);
1441 auto castToUIntIfSigned = [&](Value value) -> Value {
1442 if (type_isa<UIntType>(value.getType()))
1444 return AsUIntPrimOp::create(rewriter, value.getLoc(), value);
1447 assert(operands.size() >= 1 &&
"zero width cast must be rejected");
1449 if (operands.size() == 1) {
1450 rewriter.replaceOp(cat, castToUIntIfSigned(operands[0]));
1454 if (operands.size() == cat->getNumOperands())
1458 if (hasSigned && hasUnsigned)
1459 for (
auto &operand : operands)
1460 operand = castToUIntIfSigned(operand);
1462 replaceOpWithNewOpAndCopyName<CatPrimOp>(rewriter, cat, cat.getType(),
1471 using OpRewritePattern::OpRewritePattern;
1474 matchAndRewrite(CatPrimOp cat,
1475 mlir::PatternRewriter &rewriter)
const override {
1479 SmallVector<Value> operands;
1481 for (
size_t i = 0; i < cat->getNumOperands(); ++i) {
1482 auto cst = cat.getInputs()[i].getDefiningOp<ConstantOp>();
1484 operands.push_back(cat.getInputs()[i]);
1487 APSInt value = cst.getValue();
1489 for (; j < cat->getNumOperands(); ++j) {
1490 auto nextCst = cat.getInputs()[j].getDefiningOp<ConstantOp>();
1493 value = value.concat(nextCst.getValue());
1498 operands.push_back(cst);
1501 operands.push_back(ConstantOp::create(rewriter, cat.getLoc(), value));
1507 if (operands.size() == cat->getNumOperands())
1510 replaceOpWithNewOpAndCopyName<CatPrimOp>(rewriter, cat, cat.getType(),
1519void CatPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1521 results.insert<patterns::CatBitsBits, patterns::CatDoubleConst,
1522 patterns::CatCast, FlattenCat, CatOfConstant>(
context);
1529OpFoldResult StringConcatOp::fold(FoldAdaptor adaptor) {
1531 if (getInputs().size() == 1)
1532 return getInputs()[0];
1535 if (!llvm::all_of(adaptor.getInputs(), [](Attribute operand) {
1536 return isa_and_nonnull<StringAttr>(operand);
1541 SmallString<64> result;
1542 for (
auto operand : adaptor.getInputs())
1543 result += cast<StringAttr>(operand).getValue();
1545 return StringAttr::get(getContext(), result);
1553 using OpRewritePattern::OpRewritePattern;
1556 matchAndRewrite(StringConcatOp concat,
1557 mlir::PatternRewriter &rewriter)
const override {
1561 bool hasNestedConcat = llvm::any_of(concat.getInputs(), [](Value operand) {
1562 auto nestedConcat = operand.getDefiningOp<StringConcatOp>();
1563 return nestedConcat && operand.hasOneUse();
1566 if (!hasNestedConcat)
1570 SmallVector<Value> flatOperands;
1571 for (
auto input : concat.getInputs()) {
1572 if (
auto nestedConcat = input.getDefiningOp<StringConcatOp>();
1573 nestedConcat && input.hasOneUse())
1574 llvm::append_range(flatOperands, nestedConcat.getInputs());
1576 flatOperands.push_back(input);
1579 rewriter.modifyOpInPlace(concat,
1580 [&]() { concat->setOperands(flatOperands); });
1587class MergeAdjacentStringConstants
1590 using OpRewritePattern::OpRewritePattern;
1593 matchAndRewrite(StringConcatOp concat,
1594 mlir::PatternRewriter &rewriter)
const override {
1596 SmallVector<Value> newOperands;
1597 SmallString<64> accumulatedLit;
1598 SmallVector<StringConstantOp> accumulatedOps;
1599 bool changed =
false;
1601 auto flushLiterals = [&]() {
1602 if (accumulatedOps.empty())
1606 if (accumulatedOps.size() == 1) {
1607 newOperands.push_back(accumulatedOps[0]);
1610 auto newLit = rewriter.createOrFold<StringConstantOp>(
1611 concat.getLoc(), StringAttr::get(getContext(), accumulatedLit));
1612 newOperands.push_back(newLit);
1615 accumulatedLit.clear();
1616 accumulatedOps.clear();
1619 for (
auto operand : concat.getInputs()) {
1620 if (
auto litOp = operand.getDefiningOp<StringConstantOp>()) {
1622 if (litOp.getValue().empty()) {
1626 accumulatedLit += litOp.getValue();
1627 accumulatedOps.push_back(litOp);
1630 newOperands.push_back(operand);
1641 if (newOperands.empty())
1642 return rewriter.replaceOpWithNewOp<StringConstantOp>(
1643 concat, StringAttr::get(getContext(),
"")),
1647 rewriter.modifyOpInPlace(concat,
1648 [&]() { concat->setOperands(newOperands); });
1655void StringConcatOp::getCanonicalizationPatterns(RewritePatternSet &results,
1657 results.insert<FlattenStringConcat, MergeAdjacentStringConstants>(
context);
1664OpFoldResult PropEqOp::fold(FoldAdaptor adaptor) {
1665 auto lhsAttr = adaptor.getLhs();
1666 auto rhsAttr = adaptor.getRhs();
1667 if (!lhsAttr || !rhsAttr)
1670 return BoolAttr::get(getContext(), lhsAttr == rhsAttr);
1673OpFoldResult BitCastOp::fold(FoldAdaptor adaptor) {
1676 if (op.getType() == op.getInput().getType())
1677 return op.getInput();
1681 if (BitCastOp in = dyn_cast_or_null<BitCastOp>(op.getInput().getDefiningOp()))
1682 if (op.getType() == in.getInput().getType())
1683 return in.getInput();
1688OpFoldResult BitsPrimOp::fold(FoldAdaptor adaptor) {
1689 IntType inputType = getInput().getType();
1690 IntType resultType = getType();
1692 if (inputType == getType() && resultType.
hasWidth())
1699 cst->extractBits(getHi() - getLo() + 1, getLo()));
1705 using OpRewritePattern::OpRewritePattern;
1709 mlir::PatternRewriter &rewriter)
const override {
1710 auto cat = bits.getInput().getDefiningOp<CatPrimOp>();
1713 int32_t bitPos = bits.getLo();
1714 auto resultWidth = type_cast<UIntType>(bits.getType()).getWidthOrSentinel();
1715 if (resultWidth < 0)
1717 for (
auto operand : llvm::reverse(cat.getInputs())) {
1719 type_cast<IntType>(operand.getType()).getWidthOrSentinel();
1720 if (operandWidth < 0)
1722 if (bitPos < operandWidth) {
1723 if (bitPos + resultWidth <= operandWidth) {
1724 auto newBits = rewriter.createOrFold<BitsPrimOp>(
1725 bits.getLoc(), operand, bitPos + resultWidth - 1, bitPos);
1731 bitPos -= operandWidth;
1737void BitsPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1740 .insert<patterns::BitsOfBits, patterns::BitsOfMux, patterns::BitsOfAsUInt,
1748 unsigned loBit, PatternRewriter &rewriter) {
1749 auto resType = type_cast<IntType>(op->getResult(0).getType());
1750 if (type_cast<IntType>(value.getType()).getWidth() != resType.getWidth())
1751 value = BitsPrimOp::create(rewriter, op->getLoc(), value, hiBit, loBit);
1753 if (resType.isSigned() && !type_cast<IntType>(value.getType()).isSigned()) {
1754 value = rewriter.createOrFold<AsSIntPrimOp>(op->getLoc(), resType, value);
1755 }
else if (resType.isUnsigned() &&
1756 !type_cast<IntType>(value.getType()).isUnsigned()) {
1757 value = rewriter.createOrFold<AsUIntPrimOp>(op->getLoc(), resType, value);
1759 rewriter.replaceOp(op, value);
1762template <
typename OpTy>
1763static OpFoldResult
foldMux(OpTy op,
typename OpTy::FoldAdaptor adaptor) {
1765 if (op.getType().getBitWidthOrSentinel() == 0)
1767 APInt(0, 0, op.getType().isSignedInteger()));
1770 if (op.getHigh() == op.getLow() && op.getHigh().getType() == op.getType())
1771 return op.getHigh();
1776 if (op.getType().getBitWidthOrSentinel() < 0)
1781 if (cond->isZero() && op.getLow().getType() == op.getType())
1783 if (!cond->isZero() && op.getHigh().getType() == op.getType())
1784 return op.getHigh();
1788 if (
auto lowCst =
getConstant(adaptor.getLow())) {
1790 if (
auto highCst =
getConstant(adaptor.getHigh())) {
1792 if (highCst->getBitWidth() == lowCst->getBitWidth() &&
1793 *highCst == *lowCst)
1796 if (highCst->isOne() && lowCst->isZero() &&
1797 op.getType() == op.getSel().getType())
1810OpFoldResult MuxPrimOp::fold(FoldAdaptor adaptor) {
1811 return foldMux(*
this, adaptor);
1814OpFoldResult Mux2CellIntrinsicOp::fold(FoldAdaptor adaptor) {
1815 return foldMux(*
this, adaptor);
1818OpFoldResult Mux4CellIntrinsicOp::fold(FoldAdaptor adaptor) {
return {}; }
1827 using OpRewritePattern::OpRewritePattern;
1830 matchAndRewrite(MuxPrimOp mux,
1831 mlir::PatternRewriter &rewriter)
const override {
1832 auto width = mux.getType().getBitWidthOrSentinel();
1836 auto pad = [&](Value input) -> Value {
1838 type_cast<FIRRTLBaseType>(input.getType()).getBitWidthOrSentinel();
1839 if (inputWidth < 0 || width == inputWidth)
1841 return PadPrimOp::create(rewriter, mux.getLoc(), mux.getType(), input,
1846 auto newHigh = pad(mux.getHigh());
1847 auto newLow = pad(mux.getLow());
1848 if (newHigh == mux.getHigh() && newLow == mux.getLow())
1851 replaceOpWithNewOpAndCopyName<MuxPrimOp>(
1852 rewriter, mux, mux.getType(), ValueRange{mux.getSel(), newHigh, newLow},
1862 using OpRewritePattern::OpRewritePattern;
1864 static const int depthLimit = 5;
1866 Value updateOrClone(MuxPrimOp mux, Value high, Value low,
1867 mlir::PatternRewriter &rewriter,
1868 bool updateInPlace)
const {
1869 if (updateInPlace) {
1870 rewriter.modifyOpInPlace(mux, [&] {
1871 mux.setOperand(1, high);
1872 mux.setOperand(2, low);
1876 rewriter.setInsertionPointAfter(mux);
1877 return MuxPrimOp::create(rewriter, mux.getLoc(), mux.getType(),
1878 ValueRange{mux.getSel(), high, low})
1883 Value tryCondTrue(Value op, Value cond, mlir::PatternRewriter &rewriter,
1884 bool updateInPlace,
int limit)
const {
1885 MuxPrimOp mux = op.getDefiningOp<MuxPrimOp>();
1888 if (mux.getSel() == cond)
1889 return mux.getHigh();
1890 if (limit > depthLimit)
1892 updateInPlace &= mux->hasOneUse();
1894 if (Value v = tryCondTrue(mux.getHigh(), cond, rewriter, updateInPlace,
1896 return updateOrClone(mux, v, mux.getLow(), rewriter, updateInPlace);
1899 tryCondTrue(mux.getLow(), cond, rewriter, updateInPlace, limit + 1))
1900 return updateOrClone(mux, mux.getHigh(), v, rewriter, updateInPlace);
1905 Value tryCondFalse(Value op, Value cond, mlir::PatternRewriter &rewriter,
1906 bool updateInPlace,
int limit)
const {
1907 MuxPrimOp mux = op.getDefiningOp<MuxPrimOp>();
1910 if (mux.getSel() == cond)
1911 return mux.getLow();
1912 if (limit > depthLimit)
1914 updateInPlace &= mux->hasOneUse();
1916 if (Value v = tryCondFalse(mux.getHigh(), cond, rewriter, updateInPlace,
1918 return updateOrClone(mux, v, mux.getLow(), rewriter, updateInPlace);
1920 if (Value v = tryCondFalse(mux.getLow(), cond, rewriter, updateInPlace,
1922 return updateOrClone(mux, mux.getHigh(), v, rewriter, updateInPlace);
1928 matchAndRewrite(MuxPrimOp mux,
1929 mlir::PatternRewriter &rewriter)
const override {
1930 auto width = mux.getType().getBitWidthOrSentinel();
1934 if (Value v = tryCondTrue(mux.getHigh(), mux.getSel(), rewriter,
true, 0)) {
1935 rewriter.modifyOpInPlace(mux, [&] { mux.setOperand(1, v); });
1939 if (Value v = tryCondFalse(mux.getLow(), mux.getSel(), rewriter,
true, 0)) {
1940 rewriter.modifyOpInPlace(mux, [&] { mux.setOperand(2, v); });
1949void MuxPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1952 .add<MuxPad, MuxSharedCond, patterns::MuxEQOperands,
1953 patterns::MuxEQOperandsSwapped, patterns::MuxNEQ, patterns::MuxNot,
1954 patterns::MuxSameTrue, patterns::MuxSameFalse,
1955 patterns::NarrowMuxLHS, patterns::NarrowMuxRHS, patterns::MuxPadSel>(
1959void Mux2CellIntrinsicOp::getCanonicalizationPatterns(
1960 RewritePatternSet &results, MLIRContext *
context) {
1961 results.add<patterns::Mux2PadSel>(
context);
1964void Mux4CellIntrinsicOp::getCanonicalizationPatterns(
1965 RewritePatternSet &results, MLIRContext *
context) {
1966 results.add<patterns::Mux4PadSel>(
context);
1969OpFoldResult PadPrimOp::fold(FoldAdaptor adaptor) {
1970 auto input = this->getInput();
1973 if (input.getType() == getType())
1977 auto inputType = input.getType().base();
1984 auto destWidth = getType().base().getWidthOrSentinel();
1985 if (destWidth == -1)
1988 if (inputType.
isSigned() && cst->getBitWidth())
1989 return getIntAttr(getType(), cst->sext(destWidth));
1990 return getIntAttr(getType(), cst->zext(destWidth));
1996OpFoldResult ShlPrimOp::fold(FoldAdaptor adaptor) {
1997 auto input = this->getInput();
1998 IntType inputType = input.getType();
1999 int shiftAmount = getAmount();
2002 if (shiftAmount == 0)
2008 if (inputWidth != -1) {
2009 auto resultWidth = inputWidth + shiftAmount;
2010 shiftAmount = std::min(shiftAmount, resultWidth);
2011 return getIntAttr(getType(), cst->zext(resultWidth).shl(shiftAmount));
2017OpFoldResult ShrPrimOp::fold(FoldAdaptor adaptor) {
2018 auto input = this->getInput();
2019 IntType inputType = input.getType();
2020 int shiftAmount = getAmount();
2026 if (shiftAmount == 0 && inputWidth > 0)
2029 if (inputWidth == -1)
2031 if (inputWidth == 0)
2036 if (shiftAmount >= inputWidth && inputType.
isUnsigned())
2037 return getIntAttr(getType(), APInt(0, 0,
false));
2043 value = cst->ashr(std::min(shiftAmount, inputWidth - 1));
2045 value = cst->lshr(std::min(shiftAmount, inputWidth));
2046 auto resultWidth = std::max(inputWidth - shiftAmount, 1);
2047 return getIntAttr(getType(), value.trunc(resultWidth));
2052LogicalResult ShrPrimOp::canonicalize(ShrPrimOp op, PatternRewriter &rewriter) {
2053 auto inputWidth = op.getInput().getType().base().getWidthOrSentinel();
2054 if (inputWidth <= 0)
2058 unsigned shiftAmount = op.getAmount();
2059 if (
int(shiftAmount) >= inputWidth) {
2061 if (op.getType().base().isUnsigned())
2067 shiftAmount = inputWidth - 1;
2070 replaceWithBits(op, op.getInput(), inputWidth - 1, shiftAmount, rewriter);
2074LogicalResult HeadPrimOp::canonicalize(HeadPrimOp op,
2075 PatternRewriter &rewriter) {
2076 auto inputWidth = op.getInput().getType().base().getWidthOrSentinel();
2077 if (inputWidth <= 0)
2081 unsigned keepAmount = op.getAmount();
2083 replaceWithBits(op, op.getInput(), inputWidth - 1, inputWidth - keepAmount,
2088OpFoldResult HeadPrimOp::fold(FoldAdaptor adaptor) {
2092 getInput().getType().base().getWidthOrSentinel() - getAmount();
2093 return getIntAttr(getType(), cst->lshr(shiftAmount).trunc(getAmount()));
2099OpFoldResult TailPrimOp::fold(FoldAdaptor adaptor) {
2103 cst->trunc(getType().base().getWidthOrSentinel()));
2107LogicalResult TailPrimOp::canonicalize(TailPrimOp op,
2108 PatternRewriter &rewriter) {
2109 auto inputWidth = op.getInput().getType().base().getWidthOrSentinel();
2110 if (inputWidth <= 0)
2114 unsigned dropAmount = op.getAmount();
2115 if (dropAmount !=
unsigned(inputWidth))
2121void SubaccessOp::getCanonicalizationPatterns(RewritePatternSet &results,
2123 results.add<patterns::SubaccessOfConstant>(
context);
2126OpFoldResult MultibitMuxOp::fold(FoldAdaptor adaptor) {
2128 if (adaptor.getInputs().size() == 1)
2129 return getOperand(1);
2131 if (
auto constIndex =
getConstant(adaptor.getIndex())) {
2132 auto index = constIndex->getZExtValue();
2133 if (index < getInputs().size())
2134 return getInputs()[getInputs().size() - 1 - index];
2140LogicalResult MultibitMuxOp::canonicalize(MultibitMuxOp op,
2141 PatternRewriter &rewriter) {
2145 if (llvm::all_of(op.getInputs().drop_front(), [&](
auto input) {
2146 return input == op.getInputs().front();
2154 auto indexWidth = op.getIndex().getType().getBitWidthOrSentinel();
2155 uint64_t inputSize = op.getInputs().size();
2156 if (indexWidth >= 0 && indexWidth < 64 && 1ull << indexWidth < inputSize) {
2157 rewriter.modifyOpInPlace(op, [&]() {
2158 op.getInputsMutable().erase(0, inputSize - (1ull << indexWidth));
2165 if (
auto lastSubindex = op.getInputs().back().getDefiningOp<SubindexOp>()) {
2166 if (llvm::all_of(llvm::enumerate(op.getInputs()), [&](
auto e) {
2167 auto subindex = e.value().template getDefiningOp<SubindexOp>();
2168 return subindex && lastSubindex.getInput() == subindex.getInput() &&
2169 subindex.getIndex() + e.index() + 1 == op.getInputs().size();
2171 replaceOpWithNewOpAndCopyName<SubaccessOp>(
2172 rewriter, op, lastSubindex.getInput(), op.getIndex());
2178 if (op.getInputs().size() != 2)
2182 auto uintType = op.getIndex().getType();
2183 if (uintType.getBitWidthOrSentinel() != 1)
2187 replaceOpWithNewOpAndCopyName<MuxPrimOp>(
2188 rewriter, op, op.getIndex(), op.getInputs()[0], op.getInputs()[1]);
2207 MatchingConnectOp connect;
2208 for (Operation *user : value.getUsers()) {
2210 if (isa<AttachOp, SubfieldOp, SubaccessOp, SubindexOp>(user))
2213 if (
auto aConnect = dyn_cast<FConnectLike>(user))
2214 if (aConnect.getDest() == value) {
2215 auto matchingConnect = dyn_cast<MatchingConnectOp>(*aConnect);
2218 if (!matchingConnect || (connect && connect != matchingConnect) ||
2219 matchingConnect->getBlock() != value.getParentBlock())
2221 connect = matchingConnect;
2229 PatternRewriter &rewriter) {
2232 Operation *connectedDecl = op.getDest().getDefiningOp();
2237 if (!isa<WireOp>(connectedDecl) && !isa<RegOp>(connectedDecl))
2241 cast<Forceable>(connectedDecl).isForceable())
2249 if (connectedDecl->hasOneUse())
2253 auto *declBlock = connectedDecl->getBlock();
2254 auto *srcValueOp = op.getSrc().getDefiningOp();
2257 if (!isa<WireOp>(connectedDecl))
2263 if (!isa<ConstantOp>(srcValueOp))
2265 if (srcValueOp->getBlock() != declBlock)
2271 auto replacement = op.getSrc();
2274 if (srcValueOp && srcValueOp != &declBlock->front())
2275 srcValueOp->moveBefore(&declBlock->front());
2282 rewriter.eraseOp(op);
2286void ConnectOp::getCanonicalizationPatterns(RewritePatternSet &results,
2288 results.insert<patterns::ConnectExtension, patterns::ConnectSameType>(
2292LogicalResult MatchingConnectOp::canonicalize(MatchingConnectOp op,
2293 PatternRewriter &rewriter) {
2310 for (
auto *user : value.getUsers()) {
2311 auto attach = dyn_cast<AttachOp>(user);
2312 if (!attach || attach == dominatedAttach)
2314 if (attach->isBeforeInBlock(dominatedAttach))
2320LogicalResult AttachOp::canonicalize(AttachOp op, PatternRewriter &rewriter) {
2322 if (op.getNumOperands() <= 1) {
2323 rewriter.eraseOp(op);
2327 for (
auto operand : op.getOperands()) {
2334 SmallVector<Value> newOperands(op.getOperands());
2335 for (
auto newOperand : attach.getOperands())
2336 if (newOperand != operand)
2337 newOperands.push_back(newOperand);
2338 AttachOp::create(rewriter, op->getLoc(), newOperands);
2339 rewriter.eraseOp(attach);
2340 rewriter.eraseOp(op);
2348 if (
auto wire = dyn_cast_or_null<WireOp>(operand.getDefiningOp())) {
2349 if (!
hasDontTouch(wire.getOperation()) && wire->hasOneUse() &&
2350 !wire.isForceable()) {
2351 SmallVector<Value> newOperands;
2352 for (
auto newOperand : op.getOperands())
2353 if (newOperand != operand)
2354 newOperands.push_back(newOperand);
2356 AttachOp::create(rewriter, op->getLoc(), newOperands);
2357 rewriter.eraseOp(op);
2358 rewriter.eraseOp(wire);
2369 assert(llvm::hasSingleElement(region) &&
"expected single-region block");
2370 rewriter.inlineBlockBefore(®ion.front(), op, {});
2373LogicalResult WhenOp::canonicalize(WhenOp op, PatternRewriter &rewriter) {
2374 if (
auto constant = op.getCondition().getDefiningOp<firrtl::ConstantOp>()) {
2375 if (constant.getValue().isAllOnes())
2377 else if (op.hasElseRegion() && !op.getElseRegion().empty())
2380 rewriter.eraseOp(op);
2386 if (!op.getThenBlock().empty() && op.hasElseRegion() &&
2387 op.getElseBlock().empty()) {
2388 rewriter.eraseBlock(&op.getElseBlock());
2395 if (!op.getThenBlock().empty())
2399 if (!op.hasElseRegion() || op.getElseBlock().empty()) {
2400 rewriter.eraseOp(op);
2410 using OpRewritePattern::OpRewritePattern;
2411 LogicalResult matchAndRewrite(NodeOp node,
2412 PatternRewriter &rewriter)
const override {
2413 auto name = node.getNameAttr();
2414 if (!node.hasDroppableName() || node.getInnerSym() ||
2417 auto *newOp = node.getInput().getDefiningOp();
2420 rewriter.replaceOp(node, node.getInput());
2427 using OpRewritePattern::OpRewritePattern;
2428 LogicalResult matchAndRewrite(NodeOp node,
2429 PatternRewriter &rewriter)
const override {
2431 node.use_empty() || node.isForceable())
2433 rewriter.replaceAllUsesWith(node.getResult(), node.getInput());
2440template <
typename OpTy>
2442 PatternRewriter &rewriter) {
2443 if (!op.isForceable() || !op.getDataRef().use_empty())
2451LogicalResult NodeOp::fold(FoldAdaptor adaptor,
2452 SmallVectorImpl<OpFoldResult> &results) {
2461 if (!adaptor.getInput())
2464 results.push_back(adaptor.getInput());
2468void NodeOp::getCanonicalizationPatterns(RewritePatternSet &results,
2470 results.insert<FoldNodeName>(
context);
2471 results.add(demoteForceableIfUnused<NodeOp>);
2477struct AggOneShot :
public mlir::RewritePattern {
2478 AggOneShot(StringRef name, uint32_t weight, MLIRContext *
context)
2479 : RewritePattern(name, 0,
context) {}
2481 SmallVector<Value> getCompleteWrite(Operation *lhs)
const {
2482 auto lhsTy = lhs->getResult(0).getType();
2483 if (!type_isa<BundleType, FVectorType>(lhsTy))
2486 DenseMap<uint32_t, Value> fields;
2487 for (Operation *user : lhs->getResult(0).getUsers()) {
2488 if (user->getParentOp() != lhs->getParentOp())
2490 if (
auto aConnect = dyn_cast<MatchingConnectOp>(user)) {
2491 if (aConnect.getDest() == lhs->getResult(0))
2493 }
else if (
auto subField = dyn_cast<SubfieldOp>(user)) {
2494 for (Operation *subuser : subField.getResult().getUsers()) {
2495 if (
auto aConnect = dyn_cast<MatchingConnectOp>(subuser)) {
2496 if (aConnect.getDest() == subField) {
2497 if (subuser->getParentOp() != lhs->getParentOp())
2499 if (fields.count(subField.getFieldIndex()))
2501 fields[subField.getFieldIndex()] = aConnect.getSrc();
2507 }
else if (
auto subIndex = dyn_cast<SubindexOp>(user)) {
2508 for (Operation *subuser : subIndex.getResult().getUsers()) {
2509 if (
auto aConnect = dyn_cast<MatchingConnectOp>(subuser)) {
2510 if (aConnect.getDest() == subIndex) {
2511 if (subuser->getParentOp() != lhs->getParentOp())
2513 if (fields.count(subIndex.getIndex()))
2515 fields[subIndex.getIndex()] = aConnect.getSrc();
2526 SmallVector<Value> values;
2527 uint32_t total = type_isa<BundleType>(lhsTy)
2528 ? type_cast<BundleType>(lhsTy).getNumElements()
2529 : type_cast<FVectorType>(lhsTy).getNumElements();
2530 for (uint32_t i = 0; i < total; ++i) {
2531 if (!fields.count(i))
2533 values.push_back(fields[i]);
2538 LogicalResult matchAndRewrite(Operation *op,
2539 PatternRewriter &rewriter)
const override {
2540 auto values = getCompleteWrite(op);
2543 rewriter.setInsertionPointToEnd(op->getBlock());
2544 auto dest = op->getResult(0);
2545 auto destType = dest.getType();
2548 if (!type_cast<FIRRTLBaseType>(destType).isPassive())
2551 Value newVal = type_isa<BundleType>(destType)
2552 ? rewriter.createOrFold<BundleCreateOp>(op->getLoc(),
2554 : rewriter.createOrFold<VectorCreateOp>(
2555 op->
getLoc(), destType, values);
2556 rewriter.createOrFold<MatchingConnectOp>(op->getLoc(), dest, newVal);
2557 for (Operation *user : dest.getUsers()) {
2558 if (
auto subIndex = dyn_cast<SubindexOp>(user)) {
2559 for (Operation *subuser :
2560 llvm::make_early_inc_range(subIndex.getResult().getUsers()))
2561 if (auto aConnect = dyn_cast<MatchingConnectOp>(subuser))
2562 if (aConnect.getDest() == subIndex)
2563 rewriter.eraseOp(aConnect);
2564 }
else if (
auto subField = dyn_cast<SubfieldOp>(user)) {
2565 for (Operation *subuser :
2566 llvm::make_early_inc_range(subField.getResult().getUsers()))
2567 if (auto aConnect = dyn_cast<MatchingConnectOp>(subuser))
2568 if (aConnect.getDest() == subField)
2569 rewriter.eraseOp(aConnect);
2576struct WireAggOneShot :
public AggOneShot {
2577 WireAggOneShot(MLIRContext *
context)
2578 : AggOneShot(WireOp::getOperationName(), 0,
context) {}
2580struct SubindexAggOneShot :
public AggOneShot {
2581 SubindexAggOneShot(MLIRContext *
context)
2582 : AggOneShot(SubindexOp::getOperationName(), 0,
context) {}
2584struct SubfieldAggOneShot :
public AggOneShot {
2585 SubfieldAggOneShot(MLIRContext *
context)
2586 : AggOneShot(SubfieldOp::getOperationName(), 0,
context) {}
2590void WireOp::getCanonicalizationPatterns(RewritePatternSet &results,
2592 results.insert<WireAggOneShot>(
context);
2593 results.add(demoteForceableIfUnused<WireOp>);
2596void SubindexOp::getCanonicalizationPatterns(RewritePatternSet &results,
2598 results.insert<SubindexAggOneShot>(
context);
2601OpFoldResult SubindexOp::fold(FoldAdaptor adaptor) {
2602 auto attr = dyn_cast_or_null<ArrayAttr>(adaptor.getInput());
2605 return attr[getIndex()];
2608OpFoldResult SubfieldOp::fold(FoldAdaptor adaptor) {
2609 auto attr = dyn_cast_or_null<ArrayAttr>(adaptor.getInput());
2612 auto index = getFieldIndex();
2616void SubfieldOp::getCanonicalizationPatterns(RewritePatternSet &results,
2618 results.insert<SubfieldAggOneShot>(
context);
2622 ArrayRef<Attribute> operands) {
2623 for (
auto operand : operands)
2626 return ArrayAttr::get(
context, operands);
2629OpFoldResult BundleCreateOp::fold(FoldAdaptor adaptor) {
2632 if (getNumOperands() > 0)
2633 if (SubfieldOp first = getOperand(0).getDefiningOp<SubfieldOp>())
2634 if (first.getFieldIndex() == 0 &&
2635 first.getInput().getType() == getType() &&
2637 llvm::drop_begin(llvm::enumerate(getOperands())), [&](
auto elem) {
2639 elem.value().
template getDefiningOp<SubfieldOp>();
2640 return subindex && subindex.getInput() == first.getInput() &&
2641 subindex.getFieldIndex() == elem.index();
2643 return first.getInput();
2648OpFoldResult VectorCreateOp::fold(FoldAdaptor adaptor) {
2651 if (getNumOperands() > 0)
2652 if (SubindexOp first = getOperand(0).getDefiningOp<SubindexOp>())
2653 if (first.getIndex() == 0 && first.getInput().getType() == getType() &&
2655 llvm::drop_begin(llvm::enumerate(getOperands())), [&](
auto elem) {
2657 elem.value().
template getDefiningOp<SubindexOp>();
2658 return subindex && subindex.getInput() == first.getInput() &&
2659 subindex.getIndex() == elem.index();
2661 return first.getInput();
2666OpFoldResult UninferredResetCastOp::fold(FoldAdaptor adaptor) {
2667 if (getOperand().getType() == getType())
2668 return getOperand();
2676 using OpRewritePattern::OpRewritePattern;
2677 LogicalResult matchAndRewrite(RegResetOp reg,
2678 PatternRewriter &rewriter)
const override {
2680 dyn_cast_or_null<ConstantOp>(
reg.getResetValue().getDefiningOp());
2689 auto mux = dyn_cast_or_null<MuxPrimOp>(con.getSrc().getDefiningOp());
2692 auto *high = mux.getHigh().getDefiningOp();
2693 auto *low = mux.getLow().getDefiningOp();
2694 auto constOp = dyn_cast_or_null<ConstantOp>(high);
2696 if (constOp && low != reg)
2698 if (dyn_cast_or_null<ConstantOp>(low) && high == reg)
2699 constOp = dyn_cast<ConstantOp>(low);
2701 if (!constOp || constOp.getType() != reset.getType() ||
2702 constOp.getValue() != reset.getValue())
2706 auto regTy =
reg.getResult().getType();
2707 if (con.getDest().getType() != regTy || con.getSrc().getType() != regTy ||
2708 mux.getHigh().getType() != regTy || mux.getLow().getType() != regTy ||
2709 regTy.getBitWidthOrSentinel() < 0)
2715 if (constOp != &con->getBlock()->front())
2716 constOp->moveBefore(&con->getBlock()->front());
2721 rewriter.eraseOp(con);
2728 if (
auto c = v.getDefiningOp<ConstantOp>())
2729 return c.getValue().isOne();
2730 if (
auto sc = v.getDefiningOp<SpecialConstantOp>())
2731 return sc.getValue();
2740 auto resetValue = reg.getResetValue();
2741 if (reg.getType(0) != resetValue.getType())
2745 (void)
dropWrite(rewriter, reg->getResult(0), {});
2746 replaceOpWithNewOpAndCopyName<NodeOp>(
2747 rewriter, reg, resetValue, reg.getNameAttr(), reg.getNameKind(),
2748 reg.getAnnotationsAttr(), reg.getInnerSymAttr(), reg.getForceable());
2752void RegResetOp::getCanonicalizationPatterns(RewritePatternSet &results,
2754 results.add<patterns::RegResetWithZeroReset, FoldResetMux>(
context);
2756 results.add(demoteForceableIfUnused<RegResetOp>);
2761 auto portTy = type_cast<BundleType>(port.getType());
2762 auto fieldIndex = portTy.getElementIndex(name);
2763 assert(fieldIndex &&
"missing field on memory port");
2766 for (
auto *op : port.getUsers()) {
2767 auto portAccess = cast<SubfieldOp>(op);
2768 if (fieldIndex != portAccess.getFieldIndex())
2773 value = conn.getSrc();
2783 auto portConst = value.getDefiningOp<ConstantOp>();
2786 return portConst.getValue().isZero();
2791 auto portTy = type_cast<BundleType>(port.getType());
2792 auto fieldIndex = portTy.getElementIndex(
data);
2793 assert(fieldIndex &&
"missing enable flag on memory port");
2795 for (
auto *op : port.getUsers()) {
2796 auto portAccess = cast<SubfieldOp>(op);
2797 if (fieldIndex != portAccess.getFieldIndex())
2799 if (!portAccess.use_empty())
2808 StringRef name, Value value) {
2809 auto portTy = type_cast<BundleType>(port.getType());
2810 auto fieldIndex = portTy.getElementIndex(name);
2811 assert(fieldIndex &&
"missing field on memory port");
2813 for (
auto *op : llvm::make_early_inc_range(port.getUsers())) {
2814 auto portAccess = cast<SubfieldOp>(op);
2815 if (fieldIndex != portAccess.getFieldIndex())
2817 rewriter.replaceAllUsesWith(portAccess, value);
2818 rewriter.eraseOp(portAccess);
2823static void erasePort(PatternRewriter &rewriter, Value port) {
2826 auto getClock = [&] {
2828 clock = SpecialConstantOp::create(rewriter, port.getLoc(),
2829 ClockType::get(rewriter.getContext()),
2838 for (
auto *op : port.getUsers()) {
2839 auto subfield = dyn_cast<SubfieldOp>(op);
2841 auto ty = port.getType();
2842 auto reg = RegOp::create(rewriter, port.getLoc(), ty, getClock());
2843 rewriter.replaceAllUsesWith(port, reg.getResult());
2852 for (
auto *accessOp : llvm::make_early_inc_range(port.getUsers())) {
2853 auto access = cast<SubfieldOp>(accessOp);
2854 for (
auto *user : llvm::make_early_inc_range(access->getUsers())) {
2855 auto connect = dyn_cast<FConnectLike>(user);
2856 if (connect && connect.getDest() == access) {
2857 rewriter.eraseOp(user);
2861 if (access.use_empty()) {
2862 rewriter.eraseOp(access);
2868 auto ty = access.getType();
2869 auto reg = RegOp::create(rewriter, access.getLoc(), ty, getClock());
2870 rewriter.replaceOp(access, reg.getResult());
2872 assert(port.use_empty() &&
"port should have no remaining uses");
2878 using OpRewritePattern::OpRewritePattern;
2879 LogicalResult matchAndRewrite(MemOp mem,
2880 PatternRewriter &rewriter)
const override {
2884 if (!firrtl::type_isa<IntType>(mem.getDataType()) ||
2885 mem.getDataType().getBitWidthOrSentinel() != 0)
2889 for (
auto port : mem.getResults())
2890 for (auto *user : port.getUsers())
2891 if (!isa<SubfieldOp>(user))
2896 for (
auto port : mem.getResults()) {
2897 for (
auto *user :
llvm::make_early_inc_range(port.getUsers())) {
2898 SubfieldOp sfop = cast<SubfieldOp>(user);
2899 StringRef fieldName = sfop.getFieldName();
2900 auto wire = replaceOpWithNewOpAndCopyName<WireOp>(
2901 rewriter, sfop, sfop.getResult().getType())
2903 if (fieldName.ends_with(
"data")) {
2905 auto zero = firrtl::ConstantOp::create(
2906 rewriter, wire.getLoc(),
2907 firrtl::type_cast<IntType>(wire.getType()), APInt::getZero(0));
2908 MatchingConnectOp::create(rewriter, wire.getLoc(), wire, zero);
2912 rewriter.eraseOp(mem);
2919 using OpRewritePattern::OpRewritePattern;
2920 LogicalResult matchAndRewrite(MemOp mem,
2921 PatternRewriter &rewriter)
const override {
2924 bool isRead =
false, isWritten =
false;
2925 for (
unsigned i = 0; i < mem.getNumResults(); ++i) {
2926 switch (mem.getPortKind(i)) {
2927 case MemOp::PortKind::Read:
2932 case MemOp::PortKind::Write:
2937 case MemOp::PortKind::Debug:
2938 case MemOp::PortKind::ReadWrite:
2941 llvm_unreachable(
"unknown port kind");
2943 assert((!isWritten || !isRead) &&
"memory is in use");
2948 if (isRead && mem.getInit())
2951 for (
auto port : mem.getResults())
2954 rewriter.eraseOp(mem);
2961 using OpRewritePattern::OpRewritePattern;
2962 LogicalResult matchAndRewrite(MemOp mem,
2963 PatternRewriter &rewriter)
const override {
2967 llvm::SmallBitVector deadPorts(mem.getNumResults());
2968 for (
auto [i, port] :
llvm::enumerate(mem.getResults())) {
2970 if (!mem.getPortAnnotation(i).empty())
2974 auto kind = mem.getPortKind(i);
2975 if (kind == MemOp::PortKind::Debug)
2984 if (kind == MemOp::PortKind::Read &&
isPortUnused(port,
"data")) {
2989 if (deadPorts.none())
2993 SmallVector<Type> resultTypes;
2994 SmallVector<StringRef> portNames;
2995 SmallVector<Attribute> portAnnotations;
2996 for (
auto [i, port] :
llvm::enumerate(mem.getResults())) {
2999 resultTypes.push_back(port.getType());
3000 portNames.push_back(mem.getPortName(i));
3001 portAnnotations.push_back(mem.getPortAnnotation(i));
3005 if (!resultTypes.empty())
3006 newOp = MemOp::create(
3007 rewriter, mem.getLoc(), resultTypes, mem.getReadLatency(),
3008 mem.getWriteLatency(), mem.getDepth(), mem.getRuw(),
3009 rewriter.getStrArrayAttr(portNames), mem.getName(), mem.getNameKind(),
3010 mem.getAnnotations(), rewriter.getArrayAttr(portAnnotations),
3011 mem.getInnerSymAttr(), mem.getInitAttr(), mem.getPrefixAttr());
3014 unsigned nextPort = 0;
3015 for (
auto [i, port] :
llvm::enumerate(mem.getResults())) {
3019 rewriter.replaceAllUsesWith(port, newOp.getResult(nextPort++));
3022 rewriter.eraseOp(mem);
3029 using OpRewritePattern::OpRewritePattern;
3030 LogicalResult matchAndRewrite(MemOp mem,
3031 PatternRewriter &rewriter)
const override {
3036 llvm::SmallBitVector deadReads(mem.getNumResults());
3037 for (
auto [i, port] :
llvm::enumerate(mem.getResults())) {
3038 if (mem.getPortKind(i) != MemOp::PortKind::ReadWrite)
3040 if (!mem.getPortAnnotation(i).empty())
3047 if (deadReads.none())
3050 SmallVector<Type> resultTypes;
3051 SmallVector<StringRef> portNames;
3052 SmallVector<Attribute> portAnnotations;
3053 for (
auto [i, port] :
llvm::enumerate(mem.getResults())) {
3055 resultTypes.push_back(
3056 MemOp::getTypeForPort(mem.getDepth(), mem.getDataType(),
3057 MemOp::PortKind::Write, mem.getMaskBits()));
3059 resultTypes.push_back(port.getType());
3061 portNames.push_back(mem.getPortName(i));
3062 portAnnotations.push_back(mem.getPortAnnotation(i));
3065 auto newOp = MemOp::create(
3066 rewriter, mem.getLoc(), resultTypes, mem.getReadLatency(),
3067 mem.getWriteLatency(), mem.getDepth(), mem.getRuw(),
3068 rewriter.getStrArrayAttr(portNames), mem.getName(), mem.getNameKind(),
3069 mem.getAnnotations(), rewriter.getArrayAttr(portAnnotations),
3070 mem.getInnerSymAttr(), mem.getInitAttr(), mem.getPrefixAttr());
3072 for (
unsigned i = 0, n = mem.getNumResults(); i < n; ++i) {
3073 auto result = mem.getResult(i);
3074 auto newResult = newOp.getResult(i);
3076 auto resultPortTy = type_cast<BundleType>(result.getType());
3080 auto replace = [&](StringRef toName, StringRef fromName) {
3081 auto fromFieldIndex = resultPortTy.getElementIndex(fromName);
3082 assert(fromFieldIndex &&
"missing enable flag on memory port");
3084 auto toField = SubfieldOp::create(rewriter, newResult.getLoc(),
3086 for (
auto *op :
llvm::make_early_inc_range(result.getUsers())) {
3087 auto fromField = cast<SubfieldOp>(op);
3088 if (fromFieldIndex != fromField.getFieldIndex())
3090 rewriter.replaceOp(fromField, toField.getResult());
3094 replace(
"addr",
"addr");
3095 replace(
"en",
"en");
3096 replace(
"clk",
"clk");
3097 replace(
"data",
"wdata");
3098 replace(
"mask",
"wmask");
3101 auto wmodeFieldIndex = resultPortTy.getElementIndex(
"wmode");
3102 for (
auto *op :
llvm::make_early_inc_range(result.getUsers())) {
3103 auto wmodeField = cast<SubfieldOp>(op);
3104 if (wmodeFieldIndex != wmodeField.getFieldIndex())
3106 rewriter.replaceOpWithNewOp<WireOp>(wmodeField, wmodeField.getType());
3109 rewriter.replaceAllUsesWith(result, newResult);
3112 rewriter.eraseOp(mem);
3119 using OpRewritePattern::OpRewritePattern;
3121 LogicalResult matchAndRewrite(MemOp mem,
3122 PatternRewriter &rewriter)
const override {
3127 const auto &summary = mem.getSummary();
3128 if (summary.isMasked || summary.isSeqMem())
3131 auto type = type_dyn_cast<IntType>(mem.getDataType());
3134 auto width = type.getBitWidthOrSentinel();
3138 llvm::SmallBitVector usedBits(width);
3139 DenseMap<unsigned, unsigned> mapping;
3144 SmallVector<BitsPrimOp> readOps;
3145 auto findReadUsers = [&](Value port, StringRef field) -> LogicalResult {
3146 auto portTy = type_cast<BundleType>(port.getType());
3147 auto fieldIndex = portTy.getElementIndex(field);
3148 assert(fieldIndex &&
"missing data port");
3150 for (
auto *op : port.getUsers()) {
3151 auto portAccess = cast<SubfieldOp>(op);
3152 if (fieldIndex != portAccess.getFieldIndex())
3155 for (
auto *user : op->getUsers()) {
3156 auto bits = dyn_cast<BitsPrimOp>(user);
3160 usedBits.set(bits.getLo(), bits.getHi() + 1);
3164 mapping[bits.getLo()] = 0;
3165 readOps.push_back(bits);
3175 SmallVector<MatchingConnectOp> writeOps;
3176 auto findWriteUsers = [&](Value port, StringRef field) -> LogicalResult {
3177 auto portTy = type_cast<BundleType>(port.getType());
3178 auto fieldIndex = portTy.getElementIndex(field);
3179 assert(fieldIndex &&
"missing data port");
3181 for (
auto *op : port.getUsers()) {
3182 auto portAccess = cast<SubfieldOp>(op);
3183 if (fieldIndex != portAccess.getFieldIndex())
3190 writeOps.push_back(conn);
3196 for (
auto [i, port] :
llvm::enumerate(mem.getResults())) {
3198 if (!mem.getPortAnnotation(i).empty())
3201 switch (mem.getPortKind(i)) {
3202 case MemOp::PortKind::Debug:
3205 case MemOp::PortKind::Write:
3206 if (failed(findWriteUsers(port,
"data")))
3209 case MemOp::PortKind::Read:
3210 if (failed(findReadUsers(port,
"data")))
3213 case MemOp::PortKind::ReadWrite:
3214 if (failed(findWriteUsers(port,
"wdata")))
3216 if (failed(findReadUsers(port,
"rdata")))
3220 llvm_unreachable(
"unknown port kind");
3224 if (usedBits.none())
3228 SmallVector<std::pair<unsigned, unsigned>> ranges;
3229 unsigned newWidth = 0;
3230 for (
int i = usedBits.find_first(); 0 <= i && i < width;) {
3231 int e = usedBits.find_next_unset(i);
3234 for (
int idx = i; idx < e; ++idx, ++newWidth) {
3235 if (
auto it = mapping.find(idx); it != mapping.end()) {
3236 it->second = newWidth;
3239 ranges.emplace_back(i, e - 1);
3240 i = e != width ? usedBits.find_next(e) : e;
3244 auto newType =
IntType::get(mem->getContext(), type.isSigned(), newWidth);
3245 SmallVector<Type> portTypes;
3246 for (
auto [i, port] :
llvm::enumerate(mem.getResults())) {
3247 portTypes.push_back(
3248 MemOp::getTypeForPort(mem.getDepth(), newType, mem.getPortKind(i)));
3250 auto newMem = rewriter.replaceOpWithNewOp<MemOp>(
3251 mem, portTypes, mem.getReadLatency(), mem.getWriteLatency(),
3252 mem.getDepth(), mem.getRuw(), mem.getPortNames(), mem.getName(),
3253 mem.getNameKind(), mem.getAnnotations(), mem.getPortAnnotations(),
3254 mem.getInnerSymAttr(), mem.getInitAttr(), mem.getPrefixAttr());
3257 auto rewriteSubfield = [&](Value port, StringRef field) {
3258 auto portTy = type_cast<BundleType>(port.getType());
3259 auto fieldIndex = portTy.getElementIndex(field);
3260 assert(fieldIndex &&
"missing data port");
3262 rewriter.setInsertionPointAfter(newMem);
3263 auto newPortAccess =
3264 SubfieldOp::create(rewriter, port.getLoc(), port, field);
3266 for (
auto *op :
llvm::make_early_inc_range(port.getUsers())) {
3267 auto portAccess = cast<SubfieldOp>(op);
3268 if (op == newPortAccess || fieldIndex != portAccess.getFieldIndex())
3270 rewriter.replaceOp(portAccess, newPortAccess.getResult());
3275 for (
auto [i, port] :
llvm::enumerate(newMem.getResults())) {
3276 switch (newMem.getPortKind(i)) {
3277 case MemOp::PortKind::Debug:
3278 llvm_unreachable(
"cannot rewrite debug port");
3279 case MemOp::PortKind::Write:
3280 rewriteSubfield(port,
"data");
3282 case MemOp::PortKind::Read:
3283 rewriteSubfield(port,
"data");
3285 case MemOp::PortKind::ReadWrite:
3286 rewriteSubfield(port,
"rdata");
3287 rewriteSubfield(port,
"wdata");
3290 llvm_unreachable(
"unknown port kind");
3294 for (
auto readOp : readOps) {
3295 rewriter.setInsertionPointAfter(readOp);
3296 auto it = mapping.find(readOp.getLo());
3297 assert(it != mapping.end() &&
"bit op mapping not found");
3300 auto newReadValue = rewriter.createOrFold<BitsPrimOp>(
3301 readOp.getLoc(), readOp.getInput(),
3302 readOp.getHi() - readOp.getLo() + it->second, it->second);
3303 rewriter.replaceAllUsesWith(readOp, newReadValue);
3304 rewriter.eraseOp(readOp);
3308 for (
auto writeOp : writeOps) {
3309 Value source = writeOp.getSrc();
3310 rewriter.setInsertionPoint(writeOp);
3312 SmallVector<Value> slices;
3313 for (
auto &[start, end] :
llvm::reverse(ranges)) {
3314 Value slice = rewriter.createOrFold<BitsPrimOp>(writeOp.getLoc(),
3315 source,
end, start);
3316 slices.push_back(slice);
3320 rewriter.createOrFold<CatPrimOp>(writeOp.getLoc(), slices);
3326 if (type.isSigned())
3328 rewriter.createOrFold<AsSIntPrimOp>(writeOp.getLoc(), catOfSlices);
3330 rewriter.replaceOpWithNewOp<MatchingConnectOp>(writeOp, writeOp.getDest(),
3340 using OpRewritePattern::OpRewritePattern;
3341 LogicalResult matchAndRewrite(MemOp mem,
3342 PatternRewriter &rewriter)
const override {
3347 auto ty = mem.getDataType();
3348 auto loc = mem.getLoc();
3349 auto *block = mem->getBlock();
3353 SmallPtrSet<Operation *, 8> connects;
3354 SmallVector<SubfieldOp> portAccesses;
3355 for (
auto [i, port] :
llvm::enumerate(mem.getResults())) {
3356 if (!mem.getPortAnnotation(i).empty())
3359 auto collect = [&, port = port](ArrayRef<StringRef> fields) {
3360 auto portTy = type_cast<BundleType>(port.getType());
3361 for (
auto field : fields) {
3362 auto fieldIndex = portTy.getElementIndex(field);
3363 assert(fieldIndex &&
"missing field on memory port");
3365 for (
auto *op : port.getUsers()) {
3366 auto portAccess = cast<SubfieldOp>(op);
3367 if (fieldIndex != portAccess.getFieldIndex())
3369 portAccesses.push_back(portAccess);
3370 for (
auto *user : portAccess->getUsers()) {
3371 auto conn = dyn_cast<FConnectLike>(user);
3374 connects.insert(conn);
3381 switch (mem.getPortKind(i)) {
3382 case MemOp::PortKind::Debug:
3384 case MemOp::PortKind::Read:
3385 if (failed(collect({
"clk",
"en",
"addr"})))
3388 case MemOp::PortKind::Write:
3389 if (failed(collect({
"clk",
"en",
"addr",
"data",
"mask"})))
3392 case MemOp::PortKind::ReadWrite:
3393 if (failed(collect({
"clk",
"en",
"addr",
"wmode",
"wdata",
"wmask"})))
3399 if (!portClock || (clock && portClock != clock))
3405 rewriter.setInsertionPointAfter(mem);
3406 auto memWire = WireOp::create(rewriter, loc, ty).getResult();
3412 rewriter.setInsertionPointToEnd(block);
3414 RegOp::create(rewriter, loc, ty, clock, mem.getName()).getResult();
3417 MatchingConnectOp::create(rewriter, loc, memWire, memReg);
3421 auto pipeline = [&](Value value, Value clock,
const Twine &name,
3423 for (
unsigned i = 0; i < latency; ++i) {
3424 std::string regName;
3426 llvm::raw_string_ostream os(regName);
3427 os << mem.getName() <<
"_" << name <<
"_" << i;
3429 auto reg = RegOp::create(rewriter, mem.getLoc(), value.getType(), clock,
3430 rewriter.getStringAttr(regName))
3432 MatchingConnectOp::create(rewriter, value.getLoc(), reg, value);
3438 const unsigned writeStages =
info.writeLatency - 1;
3443 SmallVector<std::tuple<Value, Value, Value>> writes;
3444 for (
auto [i, port] :
llvm::enumerate(mem.getResults())) {
3446 StringRef name = mem.getPortName(i);
3448 auto portPipeline = [&, port = port](StringRef field,
unsigned stages) {
3451 return pipeline(value, portClock, name +
"_" + field, stages);
3454 switch (mem.getPortKind(i)) {
3455 case MemOp::PortKind::Debug:
3456 llvm_unreachable(
"unknown port kind");
3457 case MemOp::PortKind::Read: {
3465 case MemOp::PortKind::Write: {
3466 auto data = portPipeline(
"data", writeStages);
3467 auto en = portPipeline(
"en", writeStages);
3468 auto mask = portPipeline(
"mask", writeStages);
3472 case MemOp::PortKind::ReadWrite: {
3477 auto wdata = portPipeline(
"wdata", writeStages);
3478 auto wmask = portPipeline(
"wmask", writeStages);
3483 auto wen = AndPrimOp::create(rewriter, port.getLoc(),
en,
wmode);
3485 pipeline(wen, portClock, name +
"_wen", writeStages);
3486 writes.emplace_back(
wdata, wenPipelined,
wmask);
3493 Value next = memReg;
3499 Location loc = mem.getLoc();
3500 unsigned maskGran =
info.dataWidth /
info.maskBits;
3501 SmallVector<Value> chunks;
3502 for (
unsigned i = 0; i <
info.maskBits; ++i) {
3503 unsigned hi = (i + 1) * maskGran - 1;
3504 unsigned lo = i * maskGran;
3506 auto dataPart = rewriter.createOrFold<BitsPrimOp>(loc,
data, hi, lo);
3507 auto nextPart = rewriter.createOrFold<BitsPrimOp>(loc, next, hi, lo);
3508 auto bit = rewriter.createOrFold<BitsPrimOp>(loc,
mask, i, i);
3509 auto chunk = MuxPrimOp::create(rewriter, loc, bit, dataPart, nextPart);
3510 chunks.push_back(chunk);
3513 std::reverse(chunks.begin(), chunks.end());
3514 masked = rewriter.createOrFold<CatPrimOp>(loc, chunks);
3515 next = MuxPrimOp::create(rewriter, next.getLoc(),
en, masked, next);
3517 Value typedNext = rewriter.createOrFold<BitCastOp>(next.getLoc(), ty, next);
3518 MatchingConnectOp::create(rewriter, memReg.getLoc(), memReg, typedNext);
3521 for (Operation *conn : connects)
3522 rewriter.eraseOp(
conn);
3523 for (
auto portAccess : portAccesses)
3524 rewriter.eraseOp(portAccess);
3525 rewriter.eraseOp(mem);
3532void MemOp::getCanonicalizationPatterns(RewritePatternSet &results,
3535 .insert<FoldZeroWidthMemory, FoldReadOrWriteOnlyMemory,
3536 FoldReadWritePorts, FoldUnusedPorts, FoldUnusedBits, FoldRegMems>(
3556 auto mux = dyn_cast_or_null<MuxPrimOp>(con.getSrc().getDefiningOp());
3559 auto *high = mux.getHigh().getDefiningOp();
3560 auto *low = mux.getLow().getDefiningOp();
3562 auto constOp = dyn_cast_or_null<ConstantOp>(high);
3569 bool constReg =
false;
3571 if (constOp && low == reg)
3573 else if (dyn_cast_or_null<ConstantOp>(low) && high == reg) {
3575 constOp = dyn_cast<ConstantOp>(low);
3582 if (!isa<BlockArgument>(mux.getSel()) && !constReg)
3586 auto regTy = reg.getResult().getType();
3587 if (con.getDest().getType() != regTy || con.getSrc().getType() != regTy ||
3588 mux.getHigh().getType() != regTy || mux.getLow().getType() != regTy ||
3589 regTy.getBitWidthOrSentinel() < 0)
3595 if (constOp != &con->getBlock()->front())
3596 constOp->moveBefore(&con->getBlock()->front());
3599 SmallVector<NamedAttribute, 2> attrs(reg->getDialectAttrs());
3600 auto newReg = replaceOpWithNewOpAndCopyName<RegResetOp>(
3601 rewriter, reg, reg.getResult().getType(), reg.getClockVal(),
3602 mux.getSel(), mux.getHigh(), reg.getNameAttr(), reg.getNameKindAttr(),
3603 reg.getAnnotationsAttr(), reg.getInnerSymAttr(),
3604 reg.getForceableAttr());
3605 newReg->setDialectAttrs(attrs);
3607 auto pt = rewriter.saveInsertionPoint();
3608 rewriter.setInsertionPoint(con);
3609 auto v = constReg ? (Value)constOp.getResult() : (Value)mux.getLow();
3610 replaceOpWithNewOpAndCopyName<ConnectOp>(rewriter, con, con.getDest(), v);
3611 rewriter.restoreInsertionPoint(pt);
3615LogicalResult RegOp::canonicalize(RegOp op, PatternRewriter &rewriter) {
3616 if (!
hasDontTouch(op.getOperation()) && !op.isForceable() &&
3632 PatternRewriter &rewriter,
3635 if (
auto constant = enable.getDefiningOp<firrtl::ConstantOp>()) {
3636 if (constant.getValue().isZero()) {
3637 rewriter.eraseOp(op);
3643 if (
auto constant = predicate.getDefiningOp<firrtl::ConstantOp>()) {
3644 if (constant.getValue().isZero() == eraseIfZero) {
3645 rewriter.eraseOp(op);
3653template <
class Op,
bool EraseIfZero = false>
3655 PatternRewriter &rewriter) {
3660void AssertOp::getCanonicalizationPatterns(RewritePatternSet &results,
3662 results.add(canonicalizeImmediateVerifOp<AssertOp>);
3663 results.add<patterns::AssertXWhenX>(
context);
3666void AssumeOp::getCanonicalizationPatterns(RewritePatternSet &results,
3668 results.add(canonicalizeImmediateVerifOp<AssumeOp>);
3669 results.add<patterns::AssumeXWhenX>(
context);
3672void UnclockedAssumeIntrinsicOp::getCanonicalizationPatterns(
3673 RewritePatternSet &results, MLIRContext *
context) {
3674 results.add(canonicalizeImmediateVerifOp<UnclockedAssumeIntrinsicOp>);
3675 results.add<patterns::UnclockedAssumeIntrinsicXWhenX>(
context);
3678void CoverOp::getCanonicalizationPatterns(RewritePatternSet &results,
3680 results.add(canonicalizeImmediateVerifOp<CoverOp, /* EraseIfZero = */ true>);
3687LogicalResult InvalidValueOp::canonicalize(InvalidValueOp op,
3688 PatternRewriter &rewriter) {
3690 if (op.use_empty()) {
3691 rewriter.eraseOp(op);
3698 if (op->hasOneUse() &&
3699 (isa<BitsPrimOp, HeadPrimOp, ShrPrimOp, TailPrimOp, SubfieldOp,
3700 SubindexOp, AsSIntPrimOp, AsUIntPrimOp, NotPrimOp, BitCastOp>(
3701 *op->user_begin()) ||
3702 (isa<CvtPrimOp>(*op->user_begin()) &&
3703 type_isa<SIntType>(op->user_begin()->getOperand(0).getType())) ||
3704 (isa<AndRPrimOp, XorRPrimOp, OrRPrimOp>(*op->user_begin()) &&
3705 type_cast<FIRRTLBaseType>(op->user_begin()->getOperand(0).getType())
3706 .getBitWidthOrSentinel() > 0))) {
3707 auto *modop = *op->user_begin();
3708 auto inv = InvalidValueOp::create(rewriter, op.getLoc(),
3709 modop->getResult(0).getType());
3710 rewriter.replaceAllOpUsesWith(modop, inv);
3711 rewriter.eraseOp(modop);
3712 rewriter.eraseOp(op);
3718OpFoldResult InvalidValueOp::fold(FoldAdaptor adaptor) {
3719 if (getType().getBitWidthOrSentinel() == 0 && isa<IntType>(getType()))
3720 return getIntAttr(getType(), APInt(0, 0, isa<SIntType>(getType())));
3728OpFoldResult ClockGateIntrinsicOp::fold(FoldAdaptor adaptor) {
3737 return BoolAttr::get(getContext(),
false);
3741 return BoolAttr::get(getContext(),
false);
3746LogicalResult ClockGateIntrinsicOp::canonicalize(ClockGateIntrinsicOp op,
3747 PatternRewriter &rewriter) {
3749 if (
auto testEnable = op.getTestEnable()) {
3750 if (
auto constOp = testEnable.getDefiningOp<ConstantOp>()) {
3751 if (constOp.getValue().isZero()) {
3752 rewriter.modifyOpInPlace(op,
3753 [&] { op.getTestEnableMutable().clear(); });
3769 auto forceable = op.getRef().getDefiningOp<Forceable>();
3770 if (!forceable || !forceable.isForceable() ||
3771 op.getRef() != forceable.getDataRef() ||
3772 op.getType() != forceable.getDataType())
3774 rewriter.replaceAllUsesWith(op, forceable.getData());
3778void RefResolveOp::getCanonicalizationPatterns(RewritePatternSet &results,
3780 results.insert<patterns::RefResolveOfRefSend>(
context);
3784OpFoldResult RefCastOp::fold(FoldAdaptor adaptor) {
3786 if (getInput().getType() == getType())
3792 auto constOp = operand.getDefiningOp<ConstantOp>();
3793 return constOp && constOp.getValue().isZero();
3796template <
typename Op>
3799 rewriter.eraseOp(op);
3805void RefForceOp::getCanonicalizationPatterns(RewritePatternSet &results,
3807 results.add(eraseIfPredFalse<RefForceOp>);
3809void RefForceInitialOp::getCanonicalizationPatterns(RewritePatternSet &results,
3811 results.add(eraseIfPredFalse<RefForceInitialOp>);
3813void RefReleaseOp::getCanonicalizationPatterns(RewritePatternSet &results,
3815 results.add(eraseIfPredFalse<RefReleaseOp>);
3817void RefReleaseInitialOp::getCanonicalizationPatterns(
3818 RewritePatternSet &results, MLIRContext *
context) {
3819 results.add(eraseIfPredFalse<RefReleaseInitialOp>);
3826OpFoldResult HasBeenResetIntrinsicOp::fold(FoldAdaptor adaptor) {
3832 if (adaptor.getReset())
3837 if (
isUInt1(getReset().getType()) && adaptor.getClock())
3850 [&](
auto ty) ->
bool {
return isTypeEmpty(ty.getElementType()); })
3851 .Case<BundleType>([&](
auto ty) ->
bool {
3852 for (
auto elem : ty.getElements())
3857 .Case<IntType>([&](
auto ty) {
return ty.getWidth() == 0; })
3858 .Default([](
auto) ->
bool {
return false; });
3861LogicalResult FPGAProbeIntrinsicOp::canonicalize(FPGAProbeIntrinsicOp op,
3862 PatternRewriter &rewriter) {
3863 auto firrtlTy = type_dyn_cast<FIRRTLType>(op.getInput().getType());
3870 rewriter.eraseOp(op);
3878LogicalResult LayerBlockOp::canonicalize(LayerBlockOp op,
3879 PatternRewriter &rewriter) {
3882 if (op.getBody()->empty()) {
3883 rewriter.eraseOp(op);
3894OpFoldResult UnsafeDomainCastOp::fold(FoldAdaptor adaptor) {
3896 if (getDomains().
empty())
assert(baseType &&"element must be base type")
static std::unique_ptr< Context > context
static bool hasKnownWidthIntTypes(Operation *op)
Return true if this operation's operands and results all have a known width.
static LogicalResult canonicalizeImmediateVerifOp(Op op, PatternRewriter &rewriter)
static bool isDefinedByOneConstantOp(Value v)
static Attribute collectFields(MLIRContext *context, ArrayRef< Attribute > operands)
static LogicalResult canonicalizeSingleSetConnect(MatchingConnectOp op, PatternRewriter &rewriter)
static void erasePort(PatternRewriter &rewriter, Value port)
static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op, Region ®ion)
Replaces the given op with the contents of the given single-block region.
static std::optional< APSInt > getExtendedConstant(Value operand, Attribute constant, int32_t destWidth)
Implicitly replace the operand to a constant folding operation with a const 0 in case the operand is ...
static Value getPortFieldValue(Value port, StringRef name)
static AttachOp getDominatingAttachUser(Value value, AttachOp dominatedAttach)
If the specified value has an AttachOp user strictly dominating by "dominatingAttach" then return it.
static OpTy replaceOpWithNewOpAndCopyName(PatternRewriter &rewriter, Operation *op, Args &&...args)
A wrapper of PatternRewriter::replaceOpWithNewOp to propagate "name" attribute.
static void updateName(PatternRewriter &rewriter, Operation *op, StringAttr name)
Set the name of an op based on the best of two names: The current name, and the name passed in.
static bool isTypeEmpty(FIRRTLType type)
static bool isUInt1(Type type)
Return true if this value is 1 bit UInt.
static LogicalResult demoteForceableIfUnused(OpTy op, PatternRewriter &rewriter)
static bool isPortDisabled(Value port)
static LogicalResult eraseIfZeroOrNotZero(Operation *op, Value predicate, Value enable, PatternRewriter &rewriter, bool eraseIfZero)
static APInt getMaxSignedValue(unsigned bitWidth)
Get the largest signed value of a given bit width.
static Value dropWrite(PatternRewriter &rewriter, OpResult old, Value passthrough)
static LogicalResult canonicalizePrimOp(Operation *op, PatternRewriter &rewriter, const function_ref< OpFoldResult(ArrayRef< Attribute >)> &canonicalize)
Applies the canonicalization function canonicalize to the given operation.
static void replaceWithBits(Operation *op, Value value, unsigned hiBit, unsigned loBit, PatternRewriter &rewriter)
Replace the specified operation with a 'bits' op from the specified hi/lo bits.
static LogicalResult canonicalizeRegResetWithOneReset(RegResetOp reg, PatternRewriter &rewriter)
static LogicalResult eraseIfPredFalse(Op op, PatternRewriter &rewriter)
static OpFoldResult foldMux(OpTy op, typename OpTy::FoldAdaptor adaptor)
static APInt getMaxUnsignedValue(unsigned bitWidth)
Get the largest unsigned value of a given bit width.
static std::optional< APSInt > getConstant(Attribute operand)
Determine the value of a constant operand for the sake of constant folding.
static void replacePortField(PatternRewriter &rewriter, Value port, StringRef name, Value value)
BinOpKind
This is the policy for folding, which depends on the sort of operator we're processing.
static bool isPortUnused(Value port, StringRef data)
static bool isOkToPropagateName(Operation *op)
static LogicalResult canonicalizeRefResolveOfForceable(RefResolveOp op, PatternRewriter &rewriter)
static Attribute constFoldFIRRTLBinaryOp(Operation *op, ArrayRef< Attribute > operands, BinOpKind opKind, const function_ref< APInt(const APSInt &, const APSInt &)> &calculate)
Applies the constant folding function calculate to the given operands.
static APInt getMinSignedValue(unsigned bitWidth)
Get the smallest signed value of a given bit width.
static LogicalResult foldHiddenReset(RegOp reg, PatternRewriter &rewriter)
static Value moveNameHint(OpResult old, Value passthrough)
static void replaceOpAndCopyName(PatternRewriter &rewriter, Operation *op, Value newValue)
A wrapper of PatternRewriter::replaceOp to propagate "name" attribute.
static Location getLoc(DefSlot slot)
static InstancePath empty
AndRCat(MLIRContext *context)
bool handleConstant(mlir::PatternRewriter &rewriter, Operation *op, ConstantOp value, SmallVectorImpl< Value > &remaining) const override
Handle a constant operand in the cat operation.
bool getIdentityValue() const override
Return the unit value for this reduction operation:
bool handleConstant(mlir::PatternRewriter &rewriter, Operation *op, ConstantOp value, SmallVectorImpl< Value > &remaining) const override
Handle a constant operand in the cat operation.
OrRCat(MLIRContext *context)
bool getIdentityValue() const override
Return the unit value for this reduction operation:
virtual bool getIdentityValue() const =0
Return the unit value for this reduction operation:
virtual bool handleConstant(mlir::PatternRewriter &rewriter, Operation *op, ConstantOp constantOp, SmallVectorImpl< Value > &remaining) const =0
Handle a constant operand in the cat operation.
LogicalResult matchAndRewrite(Operation *op, mlir::PatternRewriter &rewriter) const override
ReductionCat(MLIRContext *context, llvm::StringLiteral opName)
XorRCat(MLIRContext *context)
bool handleConstant(mlir::PatternRewriter &rewriter, Operation *op, ConstantOp value, SmallVectorImpl< Value > &remaining) const override
Handle a constant operand in the cat operation.
bool getIdentityValue() const override
Return the unit value for this reduction operation:
This class provides a read-only projection over the MLIR attributes that represent a set of annotatio...
This class implements the same functionality as TypeSwitch except that it uses firrtl::type_dyn_cast ...
FIRRTLTypeSwitch< T, ResultT > & Case(CallableT &&caseFn)
Add a case on the given type.
This is the common base class between SIntType and UIntType.
int32_t getWidthOrSentinel() const
Return the width of this type, or -1 if it has none specified.
static IntType get(MLIRContext *context, bool isSigned, int32_t widthOrSentinel=-1, bool isConst=false)
Return an SIntType or UIntType with the specified signedness, width, and constness.
bool hasWidth() const
Return true if this integer type has a known width.
std::optional< int32_t > getWidth() const
Return an optional containing the width, if the width is known (or empty if width is unknown).
uint64_t getWidth(Type t)
Forceable replaceWithNewForceability(Forceable op, bool forceable, ::mlir::PatternRewriter *rewriter=nullptr)
Replace a Forceable op with equivalent, changing whether forceable.
bool areAnonymousTypesEquivalent(FIRRTLBaseType lhs, FIRRTLBaseType rhs)
Return true if anonymous types of given arguments are equivalent by pointer comparison.
IntegerAttr getIntAttr(Type type, const APInt &value)
Utiility for generating a constant attribute.
bool hasDontTouch(Value value)
Check whether a block argument ("port") or the operation defining a value has a DontTouch annotation,...
bool hasDroppableName(Operation *op)
Return true if the name is droppable.
MatchingConnectOp getSingleConnectUserOf(Value value)
Scan all the uses of the specified value, checking to see if there is exactly one connect that has th...
std::optional< int64_t > getBitWidth(FIRRTLBaseType type, bool ignoreFlip=false)
IntegerAttr getIntZerosAttr(Type type)
Utility for generating a constant zero attribute.
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
APSInt extOrTruncZeroWidth(APSInt value, unsigned width)
A safe version of APSInt::extOrTrunc that will NOT assert on zero-width signed APSInts.
APInt sextZeroWidth(APInt value, unsigned width)
A safe version of APInt::sext that will NOT assert on zero-width signed APSInts.
StringRef chooseName(StringRef a, StringRef b)
Choose a good name for an item from two options.
static bool isConstantZero(Attribute operand)
Determine whether a constant operand is a zero value.
static bool isConstantOne(Attribute operand)
Determine whether a constant operand is a one value.
reg(value, clock, reset=None, reset_value=None, name=None, sym_name=None)
LogicalResult matchAndRewrite(BitsPrimOp bits, mlir::PatternRewriter &rewriter) const override